Coverage for qml_essentials / qoc.py: 78%

446 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-03-30 11:43 +0000

1import argparse 

2import csv 

3import itertools 

4import logging 

5import os 

6from typing import Callable, Dict, List, Optional, Tuple, Union 

7 

8import jax 

9from jax import numpy as jnp 

10import optax 

11 

12from qml_essentials.gates import Gates, PulseInformation, PulseEnvelope 

13from qml_essentials import operations as op 

14from qml_essentials import yaqsi as ys 

15from qml_essentials.math import phase_difference, fidelity 

16 

log = logging.getLogger(__name__)

# Turn on 64-bit floats globally: the cost functions below explicitly build
# ``jnp.float64`` arrays, which need x64 mode to be enabled in JAX.
jax.config.update("jax_enable_x64", True)

19 

20 

21class Cost: 

22 """Weighted wrapper around a cost function. 

23 

24 Combines a cost callable with a scalar or tuple weight and optional 

25 constant keyword arguments. Multiple ``Cost`` instances can be 

26 composed via the ``+`` operator to build a combined objective. 

27 

28 Args: 

29 cost: Callable ``(pulse_params, **ckwargs) -> scalar | tuple``. 

30 weight: Scalar or tuple of per-component weights. 

31 ckwargs: Constant keyword arguments injected into every call. 

32 """ 

33 

34 def __init__( 

35 self, 

36 cost: Callable, 

37 weight: Union[float, Tuple], 

38 ckwargs: Optional[dict] = None, 

39 ): 

40 self.cost = cost 

41 self.weight = weight 

42 self.ckwargs = ckwargs if ckwargs is not None else {} 

43 

44 def __call__(self, *args, **kwargs): 

45 """Evaluate the cost function with injected kwargs and apply weights.""" 

46 cost = self.cost(*args, **kwargs, **self.ckwargs) 

47 if isinstance(self.weight, tuple): 

48 return jnp.array( 

49 [c * w for c, w in zip(cost, self.weight, strict=True)] 

50 ).sum() 

51 return cost * self.weight 

52 

53 def __add__(self, other): 

54 """Compose two cost terms into a single callable that sums them.""" 

55 if other is None: 

56 return lambda *args, **kwargs: self(*args, **kwargs) 

57 if callable(other): 

58 return lambda *args, **kwargs: self(*args, **kwargs) + other( 

59 *args, **kwargs 

60 ) 

61 raise TypeError(f"Cannot add Cost and {type(other)}") 

62 

63 

def fidelity_cost_fn(
    pulse_params: jnp.ndarray,
    pulse_script: ys.Script,
    target_script: ys.Script,
    n_samples: int,
) -> Tuple[float, float]:
    """
    Mean infidelity and mean absolute phase difference between the pulse
    implementation and the target, averaged over ``n_samples`` evenly spaced
    rotation angles from 0 to 2\\pi (both endpoints included).

    Both scripts are executed once in batched (vmapped) mode: the angle axis
    is mapped, while ``pulse_params`` is broadcast to every angle.

    Args:
        pulse_params: Pulse parameters for evaluation.
        pulse_script: Yaqsi script with pulse parameters.
        target_script: Yaqsi script as target.
        n_samples: Number of parameter samples.

    Returns:
        Tuple of (abs_diff, phase_diff): ``mean(1 - fidelity)`` and
        ``mean(|phase_difference|)`` over the sampled angles.
    """
    angles = jnp.linspace(0, 2 * jnp.pi, n_samples)

    # Batched over angles; pulse_params is shared (in_axes=None).
    pulse_batch = pulse_script.execute(
        type="state",
        args=(angles, pulse_params),
        in_axes=(0, None),
    )
    target_batch = target_script.execute(
        type="state",
        args=(angles,),
        in_axes=(0,),
    )

    one = jnp.array(1.0, dtype=jnp.float64)
    infidelity = jnp.mean(one - fidelity(pulse_batch, target_batch))
    mean_abs_phase = jnp.mean(
        jnp.abs(phase_difference(pulse_batch, target_batch))
    )

    # TODO: in future we could consider some sort of log based loss for the small values
    # or utilize gradient ascent if we run into numerical limitations
    return infidelity, mean_abs_phase

111 

112 

def pulse_width_cost_fn(
    pulse_params: jnp.ndarray,
    envelope: str,
) -> jnp.ndarray:
    """
    Penalise the pulse width by returning it directly as the cost.

    The width is read from the last envelope parameter, i.e.
    ``pulse_params[n_envelope_params - 1]``. Envelopes with no envelope
    parameters (e.g. ``"general"``) have zero cost.

    NOTE(review): this assumes the last envelope parameter is the width for
    every envelope; confirm for envelopes with more than two shape params.

    Args:
        pulse_params: Pulse parameters for the gate.
        envelope: Name of the active pulse envelope.

    Returns:
        Scalar float64 pulse-width cost.
    """
    n_env = PulseEnvelope.get(envelope)["n_envelope_params"]
    width = pulse_params[n_env - 1] if n_env > 0 else 0
    return jnp.array(width, dtype=jnp.float64)

140 

141 

def evolution_time_cost_fn(
    pulse_params: jnp.ndarray,
    t_target: float,
) -> jnp.ndarray:
    """
    Penalise deviation of the evolution time from ``t_target``.

    The evolution time is the last entry of ``pulse_params``; the cost is
    its squared relative deviation::

        cost = ((t - t_target) / t_target) ** 2

    Pulling every independently optimised gate towards the same evolution
    time keeps the gates compatible when composed into a circuit.

    Args:
        pulse_params: Pulse parameters for the gate.
        t_target: Target evolution time (must be non-zero).

    Returns:
        Scalar evolution-time cost.
    """
    relative_deviation = (pulse_params[-1] - t_target) / t_target
    return relative_deviation**2

167 

168 

def spectral_density_cost_fn(
    pulse_params: jnp.ndarray,
    envelope: str,
    n_fft: int = 1024,
) -> jnp.ndarray:
    """
    Penalise the spectral width of a pulse envelope.

    The envelope is sampled at ``n_fft`` evenly spaced times in
    ``[0, t_evol]`` (``t_evol = pulse_params[-1]``, centred at
    ``t_evol / 2``). The normalised power spectrum of those samples
    (``|rfft|**2``) is treated as a distribution over frequencies mapped
    linearly onto ``[0, 1]`` (DC to Nyquist), and its standard deviation
    (the RMS bandwidth) is returned.

    Narrow-band envelopes (e.g. Gaussian, DRAG) therefore score low, and
    broad-band ones (e.g. rectangular) score high. Envelopes without shape
    parameters, or without an envelope function, cost zero.

    Args:
        pulse_params: Pulse parameters for the gate. Envelope parameters
            occupy ``pulse_params[:n_envelope_params]`` and the evolution
            time is ``pulse_params[-1]``.
        envelope: Name of the active pulse envelope.
        n_fft: Number of time-domain samples used for the FFT
            (default 1024).

    Returns:
        Scalar float64 RMS bandwidth in Nyquist-normalised units.
    """
    info = PulseEnvelope.get(envelope)
    n_env = info["n_envelope_params"]
    shape_fn = info["fn"]

    # No tuneable shape -> nothing to penalise.
    if n_env == 0 or shape_fn is None:
        return jnp.array(0.0, dtype=jnp.float64)

    shape_params = pulse_params[:n_env]
    duration = pulse_params[-1]
    centre = duration / 2.0

    def sample(t):
        return shape_fn(shape_params, t, centre)

    times = jnp.linspace(0.0, duration, n_fft)
    waveform = jax.vmap(sample)(times)

    power = jnp.abs(jnp.fft.rfft(waveform)) ** 2
    power = power / (jnp.sum(power) + 1e-12)  # turn into a distribution

    grid = jnp.linspace(0.0, 1.0, len(power))
    centroid = jnp.sum(grid * power)
    spread = jnp.sqrt(jnp.sum((grid - centroid) ** 2 * power))

    return jnp.array(spread, dtype=jnp.float64)


# Deprecated alias kept for callers that still use the old, misspelled name.
sepctral_density_cost_fn = spectral_density_cost_fn

231 

232 

class CostFnRegistry:
    """Registry of cost functions available for pulse optimisation.

    Use :meth:`register` to add new cost functions at runtime and
    :meth:`get` / :meth:`available` to query them. Each entry stores the
    callable (``fn``), its default weight (``default_weight``; a tuple when
    the function returns several components) and the names of the constant
    keyword arguments it expects (``ckwargs_keys``).
    """

    _REGISTRY: Dict[str, dict] = {
        "fidelity": {
            "fn": fidelity_cost_fn,
            "default_weight": (0.5, 0.5),
            "ckwargs_keys": ["pulse_script", "target_script", "n_samples"],
        },
        "pulse_width": {
            "fn": pulse_width_cost_fn,
            "default_weight": 1.0,
            "ckwargs_keys": ["envelope"],
        },
        "evolution_time": {
            "fn": evolution_time_cost_fn,
            "default_weight": 1.0,
            "ckwargs_keys": ["t_target"],
        },
        "spectral_density": {
            "fn": spectral_density_cost_fn,
            "default_weight": 1.0,
            "ckwargs_keys": ["envelope"],
        },
    }

    @classmethod
    def register(
        cls,
        name: str,
        fn: Callable,
        default_weight: Union[float, Tuple[float, ...]] = 1.0,
        ckwargs_keys: Optional[List[str]] = None,
        overwrite: bool = False,
    ) -> None:
        """Register a new cost function under ``name``.

        Args:
            name: Name used to select the cost function (e.g. on the CLI).
            fn: Callable ``(pulse_params, **ckwargs) -> scalar | tuple``.
            default_weight: Weight used when none is given. Use a tuple
                with one entry per component if ``fn`` returns a tuple.
            ckwargs_keys: Names of the constant keyword arguments ``fn``
                needs. :meth:`QOC.optimize` can supply ``pulse_script``,
                ``target_script``, ``envelope``, ``n_samples`` and
                ``t_target``.
            overwrite: Replace an existing entry with the same name instead
                of raising.

        Raises:
            TypeError: If ``fn`` is not callable.
            ValueError: If ``name`` is already registered and
                ``overwrite`` is False.
        """
        if not callable(fn):
            raise TypeError(f"Cost function '{name}' must be callable, got {type(fn)}")
        if name in cls._REGISTRY and not overwrite:
            raise ValueError(
                f"Cost function '{name}' is already registered. "
                f"Pass overwrite=True to replace it."
            )
        cls._REGISTRY[name] = {
            "fn": fn,
            "default_weight": default_weight,
            # Copy so later mutation of the caller's list has no effect.
            "ckwargs_keys": list(ckwargs_keys) if ckwargs_keys is not None else [],
        }

    @classmethod
    def available(cls) -> List[str]:
        """Return the names of all registered cost functions."""
        return list(cls._REGISTRY.keys())

    @classmethod
    def get(cls, name: str) -> dict:
        """Look up cost-function metadata by name.

        Args:
            name: Registered cost function name.

        Returns:
            Metadata dict with keys ``fn``,
            ``default_weight``, ``ckwargs_keys``.

        Raises:
            ValueError: If name is not registered.
        """
        if name not in cls._REGISTRY:
            raise ValueError(
                f"Unknown cost function '{name}'. " f"Available: {cls.available()}"
            )
        return cls._REGISTRY[name]

    @classmethod
    def parse_cost_arg(
        cls, spec: Union[str, Tuple]
    ) -> Tuple[str, Union[float, Tuple[float, ...]]]:
        """Parse a ``"name:w1,w2,..."`` CLI string into ``(name, weight)``.
        If a tuple is provided, it is returned directly (unvalidated).

        If the weight part is omitted the default weight from the registry
        is used. A single-component weight is returned as a float;
        multi-component weights are returned as a tuple of floats.
        Whitespace around the name is ignored.

        Args:
            spec: A string of the form ``"name"`` or ``"name:w1,w2,..."``.

        Returns:
            A tuple of ``(name, weight)``.

        Raises:
            ValueError: If the name is unknown, a weight is not a number, or
                the number of weight components does not match the ones in
                ``default_weight``.
        """
        if isinstance(spec, tuple):
            return spec

        if ":" in spec:
            name, weight_str = spec.split(":", 1)
            name = name.strip()
            parts = [float(x) for x in weight_str.split(",")]
            weight: Union[float, Tuple[float, ...]] = (
                parts[0] if len(parts) == 1 else tuple(parts)
            )
        else:
            name = spec.strip()
            weight = cls.get(name)["default_weight"]

        # Validate weight count against the registered default
        got = len(weight) if isinstance(weight, tuple) else 1
        default_weight = cls.get(name)["default_weight"]
        expected = len(default_weight) if isinstance(default_weight, tuple) else 1

        if got != expected:
            raise ValueError(
                f"Cost function '{name}' expects {expected} weight(s), " f"got {got}."
            )

        return name, weight

333 

334 

class QOC:
    """Quantum Optimal Control (QOC) for pulse-level gate synthesis.

    Tunes pulse parameters so that pulse-level implementations reproduce
    standard quantum gates. Optimisation proceeds in two stages: an
    optional coarse grid scan, then multi-restart gradient refinement.

    Attributes:
        GATES_1Q: Names of the supported single-qubit gates.
        GATES_2Q: Names of the supported two-qubit gates.
        DEFAULT_PARAM_RANGES: Default ``(lo, hi)`` grid-scan ranges, keyed
            by the number of pulse parameters of a gate.
    """

    GATES_1Q: List[str] = [
        "RX",
        "RY",
        "RZ",
        "Rot",
        "H",
    ]
    GATES_2Q: List[str] = [
        "CX",
        "CY",
        "CZ",
        "CRX",
        "CRY",
        "CRZ",
    ]

    # Key: number of pulse parameters. Value: one (lo, hi) range per
    # parameter, in parameter order.
    DEFAULT_PARAM_RANGES = {
        # evolution time
        1: [(0.05, 2.0)],
        # not typically used
        2: [(0.5, 2.0), (0.05, 2.0)],
        # A, σ, t
        3: [(0.5, 30.0), (0.05, 2.0), (0.05, 2.0)],
        # DRAG
        4: [(0.5, 30.0), (0.05, 2.0), (0.01, 0.5), (0.05, 2.0)],
    }

356 

357 def __init__( 

358 self, 

359 envelope: str, 

360 cost_fns: List[Tuple[str, Union[float, Tuple[float, ...]]]], 

361 t_target: float, 

362 n_steps: int, 

363 n_samples: int, 

364 learning_rate: float, 

365 log_interval: int = 50, 

366 file_dir: str = None, 

367 warmup_ratio: float = 0.0, 

368 end_lr_ratio: float = 1.0, 

369 n_restarts: int = 1, 

370 restart_noise_scale: float = 0.5, 

371 grad_clip: float = 1.0, 

372 random_seed: int = 42, 

373 scan_steps: int = 0, 

374 scan_grid_size: int = 5, 

375 scan_ranges: Optional[List[Tuple[float, float]]] = None, 

376 log_scale_params: Optional[List[int]] = None, 

377 ): 

378 """ 

379 Initialize Quantum Optimal Control with Pulse-level Gates. 

380 

381 Args: 

382 envelope (str): Pulse envelope shape to use for optimization. 

383 Must be one of the registered envelopes in PulseEnvelope 

384 (e.g. 'gaussian', 'square', 'cosine', 'drag', 'sech'). 

385 cost_fns (list): List of ``(name, weight)`` tuples that select 

386 which cost functions to use and their weights. name must 

387 be a key in :class:`CostFnRegistry`. *weight* is either a 

388 single float or a tuple of floats matching the number of 

389 return values of the cost function. 

390 t_target (float, optional): Target evolution time for the 

391 ``evolution_time`` cost function. Required when 

392 ``"evolution_time"`` is among the selected cost functions. 

393 n_steps (int): Number of steps in optimization. 

394 n_samples (int): Number of parameter samples per step. 

395 learning_rate (float): Peak learning rate for AdamW. When a 

396 warmup/decay schedule is active this is the maximum LR 

397 reached after the warmup phase. 

398 log_interval (int): Interval for logging. 

399 file_dir (str): Directory to save results. 

400 warmup_ratio (float): Fraction of ``n_steps`` used for linear 

401 warmup (0.0 - 1.0). Set to 0.0 to disable warmup and use 

402 a constant learning rate throughout. A value of e.g. 0.05 

403 means the first 5 % of steps linearly ramp the LR from 

404 ``end_lr_ratio * learning_rate`` to ``learning_rate``. 

405 end_lr_ratio (float): The final learning rate is 

406 ``end_lr_ratio * learning_rate``. Also used as the initial 

407 LR at the start of warmup. Set to 0.0 for full cosine 

408 decay to zero; set to 1.0 (together with 

409 ``warmup_ratio=0.0``) to recover a constant LR. 

410 n_restarts (int): Number of random restarts for the 

411 optimisation. The first run uses the initial parameters 

412 as-is; subsequent runs add scaled random perturbations. 

413 The best result across all restarts is kept. 

414 Set to 1 to disable restarts (default behaviour). 

415 restart_noise_scale (float): Standard deviation of the 

416 Gaussian noise added to the initial parameters for each 

417 restart (relative to the absolute value of each parameter). 

418 Defaults to 0.5 (50 % relative perturbation). 

419 grad_clip (float): Maximum global gradient norm. Gradients 

420 are clipped to this value before being passed to the 

421 optimiser, which stabilises training when the loss 

422 landscape has steep regions. Set to ``float('inf')`` or 

423 0.0 to disable. Defaults to 1.0. 

424 random_seed (int): Base random seed for generating restart 

425 perturbations. Defaults to 42. 

426 scan_steps (int): Number of short gradient-descent steps to 

427 run for each candidate in the coarse grid search 

428 (Stage 0). Set to 0 to disable the grid scan entirely 

429 and rely solely on restarts. A value of 20-50 is 

430 usually enough to identify promising basins. Defaults 

431 to 0. 

432 scan_grid_size (int): Number of points per parameter 

433 dimension in the coarse grid. The total number of 

434 candidates is ``scan_grid_size ** n_params``, so keep 

435 this small for high-dimensional parameter spaces. 

436 Defaults to 5. 

437 scan_ranges (Optional[List[Tuple[float, float]]]): Per- 

438 parameter ``(lo, hi)`` ranges for the grid scan. If 

439 ``None``, heuristic ranges are used based on the 

440 envelope type: amplitude in ``[0.5, 30]``, width/sigma 

441 in ``[0.05, 2]``, and evolution time in ``[0.05, 2]``. 

442 Must have length equal to the number of pulse parameters 

443 if provided. 

444 log_scale_params (Optional[List[int]]): Indices of pulse 

445 parameters that should be optimised in log-space. For 

446 these parameters the optimizer sees ``log(p)`` and the 

447 actual parameter used in the simulation is ``exp(log_p)``. 

448 This dramatically improves convergence when the optimal 

449 value may differ from the initial value by an order of 

450 magnitude (e.g. amplitude, evolution time). 

451 If ``None``, defaults to ``[0, -1]`` (amplitude and 

452 evolution time) for envelopes with ≥ 2 envelope params, 

453 or ``[]`` otherwise. 

454 """ 

455 self.envelope = envelope 

456 self.n_steps = n_steps 

457 self.n_samples = n_samples 

458 self.learning_rate = learning_rate 

459 self.warmup_ratio = warmup_ratio 

460 self.end_lr_ratio = end_lr_ratio 

461 self.log_interval = log_interval 

462 self.file_dir = ( 

463 file_dir if file_dir else os.path.dirname(os.path.realpath(__file__)) 

464 ) 

465 self.t_target = t_target 

466 self.n_restarts = max(1, n_restarts) 

467 self.restart_noise_scale = restart_noise_scale 

468 self.grad_clip = grad_clip 

469 self.random_key = jax.random.PRNGKey(random_seed) 

470 self.scan_steps = scan_steps 

471 self.scan_grid_size = scan_grid_size 

472 self.scan_ranges = scan_ranges 

473 

474 # Determine log-scale param indices 

475 envelope_info = PulseEnvelope.get(envelope) 

476 n_env = envelope_info["n_envelope_params"] 

477 if log_scale_params is not None: 

478 self.log_scale_params = log_scale_params 

479 elif n_env >= 2: 

480 # Default: amplitude (index 0) and evolution time (last) 

481 self.log_scale_params = [0, -1] 

482 else: 

483 self.log_scale_params = [] 

484 

485 log.info( 

486 f"Training parameters: {self.n_steps} steps, " 

487 f"{self.n_samples} samples, {self.learning_rate} learning rate" 

488 ) 

489 log.info( 

490 f"LR schedule: warmup_ratio={self.warmup_ratio}, " 

491 f"end_lr_ratio={self.end_lr_ratio}" 

492 ) 

493 

494 log.info(f"Envelope: {self.envelope}") 

495 log.info(f"Target evolution time: {self.t_target}") 

496 log.info( 

497 f"Restarts: {self.n_restarts}, noise_scale={self.restart_noise_scale}, " 

498 f"grad_clip={self.grad_clip}" 

499 ) 

500 log.info( 

501 f"Grid scan: scan_steps={self.scan_steps}, " 

502 f"scan_grid_size={self.scan_grid_size}, " 

503 f"log_scale_params={self.log_scale_params}" 

504 ) 

505 log.info(f"Using cost function(s) {cost_fns}") 

506 

507 # Validate each entry against the registry 

508 summed_weights = 0 

509 for name, _weight in cost_fns: 

510 CostFnRegistry.get(name) # raises ValueError if unknown 

511 summed_weights += sum(_weight) if isinstance(_weight, tuple) else _weight 

512 assert jnp.isclose( 

513 summed_weights, 1.0, rtol=1e-8 

514 ), f"Cost function weights must sum to 1. Got {summed_weights}" 

515 

516 self.cost_fns = cost_fns 

517 

518 # Configure the pulse system with the selected envelope 

519 PulseInformation.set_envelope(self.envelope) 

520 

521 def save_results(self, gate: str, fidelity: float, pulse_params) -> None: 

522 """Save optimised pulse parameters and fidelity for a gate to CSV. 

523 

524 If the gate already exists in the file, its entry is overwritten 

525 regardless of whether the new fidelity is higher. A warning is 

526 logged when the existing fidelity was better. 

527 

528 Args: 

529 gate: Name of the gate (e.g. ``"RX"``). 

530 fidelity: Achieved fidelity of the optimised pulse. 

531 pulse_params: Optimised pulse parameters for the gate. 

532 """ 

533 if self.file_dir is not None: 

534 os.makedirs(self.file_dir, exist_ok=True) 

535 filename = os.path.join(self.file_dir, "qoc_results.csv") 

536 

537 reader = None 

538 if os.path.isfile(filename): 

539 with open(filename, mode="r", newline="") as f: 

540 reader = csv.reader(f.readlines()) 

541 

542 entry = [gate] + [fidelity] + list(map(float, pulse_params)) 

543 

544 with open(filename, mode="w", newline="") as f: 

545 writer = csv.writer(f) 

546 match = False 

547 if reader is not None: 

548 for row in reader: 

549 # gate already exists 

550 if row[0] == gate: 

551 if fidelity <= float(row[1]): 

552 log.warning( 

553 f"Pulse parameters for {gate} already exist with " 

554 f"higher fidelity ({row[1]} >= {fidelity})" 

555 ) 

556 writer.writerow(entry) 

557 match = True 

558 # any other gate 

559 else: 

560 writer.writerow(row) 

561 # gate does not exist 

562 if not match: 

563 writer.writerow(entry) 

564 

565 def _to_log_space(self, params: jnp.ndarray) -> jnp.ndarray: 

566 """Convert selected parameters to log-space for optimisation. 

567 

568 Parameters at indices in ``self.log_scale_params`` are replaced 

569 by ``log(|p| + eps)`` so the optimiser operates on a 

570 logarithmic scale. All other parameters are left unchanged. 

571 

572 Args: 

573 params: Pulse parameters in physical space. 

574 

575 Returns: 

576 Parameters with selected entries in log-space. 

577 """ 

578 if not self.log_scale_params: 

579 return params 

580 n = len(params) 

581 log_params = params.copy() 

582 for idx in self.log_scale_params: 

583 # Normalise negative indices 

584 i = idx if idx >= 0 else n + idx 

585 log_params = log_params.at[i].set(jnp.log(jnp.abs(params[i]) + 1e-12)) 

586 return log_params 

587 

588 def _from_log_space(self, log_params: jnp.ndarray) -> jnp.ndarray: 

589 """Convert selected parameters back from log-space. 

590 

591 Inverse of :meth:`_to_log_space`. Parameters at indices in 

592 ``self.log_scale_params`` are exponentiated; all others are 

593 passed through unchanged. 

594 

595 Args: 

596 log_params: Parameters with selected entries in log-space. 

597 

598 Returns: 

599 Parameters in physical space (all positive for log-scaled 

600 entries). 

601 """ 

602 if not self.log_scale_params: 

603 return log_params 

604 n = len(log_params) 

605 params = log_params.copy() 

606 for idx in self.log_scale_params: 

607 i = idx if idx >= 0 else n + idx 

608 params = params.at[i].set(jnp.exp(log_params[i])) 

609 return params 

610 

611 def _build_scan_grid(self, n_params: int) -> jnp.ndarray: 

612 """Build a coarse parameter grid for the initial scan phase. 

613 

614 Uses either user-supplied ``scan_ranges`` or heuristic defaults 

615 based on typical Gaussian pulse parameter ranges. 

616 

617 Args: 

618 n_params: Number of pulse parameters. 

619 

620 Returns: 

621 Array of shape ``(n_candidates, n_params)`` with grid points. 

622 """ 

623 if self.scan_ranges is not None: 

624 ranges = self.scan_ranges 

625 assert len(ranges) == n_params, ( 

626 f"scan_ranges has {len(ranges)} entries but gate has " 

627 f"{n_params} parameters." 

628 ) 

629 else: 

630 # [amplitude, sigma/width, evolution_time] 

631 ranges = self.DEFAULT_PARAM_RANGES.get( 

632 n_params, 

633 [(0.1, 10.0)] * n_params, # fallback 

634 ) 

635 

636 # Build log-spaced grids for each parameter 

637 axes = [] 

638 for lo, hi in ranges: 

639 axes.append(jnp.logspace(jnp.log10(lo), jnp.log10(hi), self.scan_grid_size)) 

640 

641 # Cartesian product of all axes 

642 grid = jnp.array(list(itertools.product(*axes))) 

643 return grid 

644 

645 def stage_0_opt(self, init_pulse_params: jnp.ndarray, fidelity_only_cost): 

646 """Run the coarse grid-scan phase (Stage 0). 

647 

648 Evaluates a Cartesian grid of parameter candidates using only the 

649 fidelity cost (ignoring phase). Each candidate is refined with a 

650 few fast gradient steps. Returns the best-found parameters. 

651 

652 Args: 

653 init_pulse_params: Initial pulse parameters to compare against. 

654 fidelity_only_cost: Cost callable using fidelity only. 

655 

656 Returns: 

657 Best pulse parameters found during the scan. 

658 """ 

659 

660 def fidelity_only_cost_log(log_params, *args): 

661 return fidelity_only_cost(self._from_log_space(log_params), *args) 

662 

663 best_scan_params = init_pulse_params 

664 best_scan_loss = fidelity_only_cost(init_pulse_params) 

665 

666 if self.scan_steps > 0: 

667 log.info( 

668 f"Stage 0: Grid scan with {self.scan_grid_size}^" 

669 f"{len(init_pulse_params)} candidates, " 

670 f"{self.scan_steps} steps each" 

671 ) 

672 

673 grid = self._build_scan_grid(len(init_pulse_params)) 

674 log.info(f" Total candidates: {len(grid)}") 

675 

676 # Use a fast, constant-LR Adam for the scan phase 

677 scan_optimizer = optax.chain( 

678 optax.clip_by_global_norm( 

679 self.grad_clip if self.grad_clip > 0 else 1.0 

680 ), 

681 optax.adam(self.learning_rate * 5), # aggressive LR 

682 ) 

683 

684 @jax.jit 

685 def scan_step(opt_state, log_params): 

686 loss, grads = jax.value_and_grad(fidelity_only_cost_log)(log_params) 

687 updates, opt_state = scan_optimizer.update(grads, opt_state, log_params) 

688 log_params = optax.apply_updates(log_params, updates) 

689 return log_params, opt_state, loss 

690 

691 for ci, candidate in enumerate(grid): 

692 log_candidate = self._to_log_space(candidate) 

693 opt_state = scan_optimizer.init(log_candidate) 

694 

695 log_p = log_candidate 

696 for _ in range(self.scan_steps): 

697 log_p, opt_state, loss = scan_step(opt_state, log_p) 

698 

699 # Evaluate final loss 

700 physical_p = self._from_log_space(log_p) 

701 loss = fidelity_only_cost(physical_p) 

702 

703 if loss < best_scan_loss: 

704 best_scan_loss = loss 

705 best_scan_params = physical_p 

706 log.info( 

707 f" Candidate {ci + 1}/{len(grid)}: " 

708 f"loss={loss.item():.3e} improved with " 

709 f"params={physical_p}" 

710 ) 

711 

712 log.info( 

713 f"Stage 0 complete. Best fidelity-only loss: " 

714 f"{best_scan_loss.item():.3e}, " 

715 f"params: {best_scan_params}" 

716 ) 

717 

718 return best_scan_params 

719 

720 def stage_1_opt(self, best_scan_params: jnp.ndarray, total_costs): 

721 """Run multi-restart gradient optimisation (Stage 1). 

722 

723 Performs ``n_restarts`` independent AdamW runs with the full 

724 (weighted) cost function. The first restart uses 

725 ``best_scan_params`` directly; subsequent restarts add random 

726 perturbations. Parameters specified in ``log_scale_params`` are 

727 optimised in log-space. 

728 

729 Args: 

730 best_scan_params: Starting parameters (typically from Stage 0). 

731 total_costs: Combined cost callable. 

732 

733 Returns: 

734 Tuple of ``(best_params, loss_history, best_loss)``. 

735 """ 

736 

737 # Wrap the cost function with log-space reparameterisation 

738 def total_costs_log(log_params, *args): 

739 return total_costs(self._from_log_space(log_params), *args) 

740 

741 # Build learning rate schedule 

742 warmup_steps = int(self.n_steps * self.warmup_ratio) 

743 end_value = self.learning_rate * self.end_lr_ratio 

744 

745 if warmup_steps > 0 or self.end_lr_ratio < 1.0: 

746 schedule = optax.warmup_cosine_decay_schedule( 

747 init_value=(end_value if warmup_steps > 0 else self.learning_rate), 

748 peak_value=self.learning_rate, 

749 warmup_steps=warmup_steps, 

750 decay_steps=self.n_steps, 

751 end_value=end_value, 

752 ) 

753 else: 

754 schedule = self.learning_rate 

755 

756 # Build optimiser chain with gradient clipping 

757 use_clip = ( 

758 self.grad_clip and self.grad_clip > 0 and jnp.isfinite(self.grad_clip) 

759 ) 

760 if use_clip: 

761 optimizer = optax.chain( 

762 optax.clip_by_global_norm(self.grad_clip), 

763 optax.adamw(schedule), 

764 ) 

765 else: 

766 optimizer = optax.adamw(schedule) 

767 

768 @jax.jit 

769 def opt_step(opt_state, log_params, *args): 

770 loss, grads = jax.value_and_grad(total_costs_log)(log_params, *args) 

771 updates, opt_state = optimizer.update(grads, opt_state, log_params) 

772 log_params = optax.apply_updates(log_params, updates) 

773 return log_params, opt_state, loss 

774 

775 # Use the best from grid scan as starting point 

776 start_params = best_scan_params 

777 

778 global_best_loss = jnp.inf 

779 global_best_params = start_params 

780 global_best_history = [] 

781 restart_key = self.random_key 

782 

783 for restart in range(self.n_restarts): 

784 if restart == 0: 

785 params = start_params 

786 else: 

787 # Perturb the starting point 

788 restart_key, sub_key = jax.random.split(restart_key) 

789 noise = jax.random.normal(sub_key, shape=start_params.shape) 

790 scale = ( 

791 jnp.maximum(jnp.abs(start_params), 0.1) * self.restart_noise_scale 

792 ) 

793 params = start_params + noise * scale 

794 # Ensure log-scaled params remain positive before 

795 # conversion (evolution time at index -1 is always 

796 # included since _to_log_space uses jnp.abs anyway, 

797 # but we keep values positive for readability). 

798 params = params.at[-1].set(jnp.abs(params[-1])) 

799 for idx in self.log_scale_params: 

800 i = idx if idx >= 0 else len(params) + idx 

801 params = params.at[i].set(jnp.abs(params[i])) 

802 log.info( 

803 f"Restart {restart + 1}/{self.n_restarts} " 

804 f"with perturbed params: {params}" 

805 ) 

806 

807 # Convert to log-space for optimisation 

808 log_params = self._to_log_space(params) 

809 opt_state = optimizer.init(log_params) 

810 

811 loss = total_costs(params) 

812 loss_history = [loss] 

813 best_loss = loss 

814 best_pulse_params = params 

815 

816 for step in range(self.n_steps): 

817 if step % self.log_interval == 0: 

818 restart_tag = ( 

819 f" [restart {restart + 1}/{self.n_restarts}]" 

820 if self.n_restarts > 1 

821 else "" 

822 ) 

823 log.info( 

824 f"Step {step}/{self.n_steps}, " 

825 f"Loss: {loss_history[-1].item():.3e}" 

826 f"{restart_tag}" 

827 ) 

828 

829 log_params, opt_state, loss = opt_step(opt_state, log_params) 

830 

831 if loss < best_loss: 

832 log.debug(f"Best set of params found at step {step}") 

833 best_loss = loss 

834 best_pulse_params = self._from_log_space(log_params) 

835 

836 loss_history.append(loss) 

837 

838 log.info( 

839 f"Restart {restart + 1}/{self.n_restarts} finished " 

840 f"with best loss: {best_loss.item():.3e}" 

841 ) 

842 

843 if best_loss < global_best_loss: 

844 global_best_loss = best_loss 

845 global_best_params = best_pulse_params 

846 global_best_history = loss_history 

847 

848 return global_best_params, global_best_history, global_best_loss 

849 

850 def optimize(self, wires): 

851 """Decorator factory that optimises pulse parameters for a gate. 

852 

853 Usage:: 

854 

855 opt = qoc.optimize(wires=1) 

856 best_params, loss_history = opt(qoc.create_RX)() 

857 

858 Args: 

859 wires: Number of qubits the gate acts on. 

860 

861 Returns: 

862 A decorator that accepts a circuit-factory function and 

863 returns a callable ``(init_pulse_params=None) -> 

864 (best_params, loss_history)``. 

865 """ 

866 

867 def decorator(create_circuits): 

868 def wrapper(init_pulse_params: jnp.ndarray = None): 

869 """ 

870 Optimise pulse parameters for a quantum gate using a 

871 multi-phase strategy: 

872 

873 Stage 0 - Grid scan (if ``scan_steps > 0``): 

874 Evaluate a coarse grid of parameter candidates using 

875 only the fidelity cost (ignoring phase). Each 

876 candidate is refined with a few fast gradient steps. 

877 The best candidate becomes the starting point for 

878 Stage 1, unless the user-supplied init_pulse_params 

879 are already better. 

880 

881 Stage 1 - Multi-restart gradient optimisation: 

882 Run ``n_restarts`` independent Adam optimisation runs 

883 with the full cost function. The first restart uses 

884 the best point found so far; subsequent restarts add 

885 random perturbations. Parameters at indices in 

886 ``log_scale_params`` are optimised in log-space to 

887 handle order-of-magnitude differences in scale. 

888 

889 Args: 

890 init_pulse_params (array): Initial pulse parameters. 

891 If ``None``, uses the envelope defaults from 

892 :class:`PulseInformation`. 

893 

894 Returns: 

895 tuple: ``(best_params, loss_history)`` from the best 

896 restart. 

897 """ 

898 pulse_circuit, target_circuit = create_circuits() 

899 

900 pulse_script = ys.Script(pulse_circuit, n_qubits=wires) 

901 target_script = ys.Script(target_circuit, n_qubits=wires) 

902 

903 gate_name = create_circuits.__name__.split("_")[1] 

904 

905 if init_pulse_params is None: 

906 init_pulse_params = PulseInformation.gate_by_name(gate_name).params 

907 log.debug( 

908 f"Initial pulse parameters for {gate_name}: {init_pulse_params}" 

909 ) 

910 

911 all_ckwargs = { 

912 "pulse_script": pulse_script, 

913 "target_script": target_script, 

914 "envelope": self.envelope, 

915 "n_samples": self.n_samples, 

916 "t_target": self.t_target, 

917 } 

918 

919 def _build_cost(name, weight): 

920 """Build a Cost from a registry entry, filtering ckwargs.""" 

921 meta = CostFnRegistry.get(name) 

922 return Cost( 

923 cost=meta["fn"], 

924 weight=weight, 

925 ckwargs={ 

926 k: v 

927 for k, v in all_ckwargs.items() 

928 if k in meta["ckwargs_keys"] 

929 }, 

930 ) 

931 

932 total_costs = None 

933 for name, weight in self.cost_fns: 

934 total_costs = _build_cost(name, weight) + total_costs 

935 

936 fidelity_only_cost = _build_cost( 

937 "fidelity", (1.0, 0.0) # 100% fidelity, 0% phase 

938 ) 

939 

940 best_scan_params = self.stage_0_opt( 

941 init_pulse_params, 

942 fidelity_only_cost, 

943 ) 

944 

945 global_best_params, global_best_history, global_best_loss = ( 

946 self.stage_1_opt( 

947 best_scan_params, 

948 total_costs, 

949 ) 

950 ) 

951 self.save_results( 

952 gate=gate_name, 

953 fidelity=1 - global_best_loss.item(), 

954 pulse_params=global_best_params, 

955 ) 

956 

957 return global_best_params, global_best_history 

958 

959 return wrapper 

960 

961 return decorator 

962 

963 def create_RX(self): 

964 """Create pulse and target circuits for the RX gate.""" 

965 

966 def pulse_circuit(w, pulse_params): 

967 Gates.RX(w, 0, pulse_params=pulse_params, gate_mode="pulse") 

968 

969 def target_circuit(w): 

970 op.RX(w, wires=0) 

971 

972 return pulse_circuit, target_circuit 

973 

974 def create_RY(self): 

975 """Create pulse and target circuits for the RY gate.""" 

976 

977 def pulse_circuit(w, pulse_params): 

978 Gates.RY(w, 0, pulse_params=pulse_params, gate_mode="pulse") 

979 

980 def target_circuit(w): 

981 op.RY(w, wires=0) 

982 

983 return pulse_circuit, target_circuit 

984 

985 def create_RZ(self): 

986 """Create pulse and target circuits for the RZ gate. 

987 

988 Both circuits are sandwiched between Hadamard gates to make the 

989 RZ rotation observable in the computational basis. 

990 """ 

991 

992 def pulse_circuit(w, pulse_params): 

993 op.H(wires=0) 

994 Gates.RZ(w, 0, pulse_params=pulse_params, gate_mode="pulse") 

995 op.H(wires=0) 

996 

997 def target_circuit(w): 

998 op.H(wires=0) 

999 op.RZ(w, wires=0) 

1000 op.H(wires=0) 

1001 

1002 return pulse_circuit, target_circuit 

1003 

1004 def create_H(self): 

1005 """Create pulse and target circuits for the Hadamard gate. 

1006 

1007 An RY rotation is prepended to break symmetry. 

1008 """ 

1009 

1010 def pulse_circuit(w, pulse_params): 

1011 op.RY(w, wires=0) 

1012 Gates.H(0, pulse_params=pulse_params, gate_mode="pulse") 

1013 

1014 def target_circuit(w): 

1015 op.RY(w, wires=0) 

1016 op.H(wires=0) 

1017 

1018 return pulse_circuit, target_circuit 

1019 

1020 def create_Rot(self): 

1021 """Create pulse and target circuits for the general Rot gate.""" 

1022 

1023 def pulse_circuit(w, pulse_params): 

1024 op.H(wires=0) 

1025 Gates.Rot(w, w * 2, w * 3, 0, pulse_params=pulse_params, gate_mode="pulse") 

1026 

1027 def target_circuit(w): 

1028 op.H(wires=0) 

1029 op.Rot(w, w * 2, w * 3, wires=0) 

1030 

1031 return pulse_circuit, target_circuit 

1032 

1033 def create_CX(self): 

1034 """Create pulse and target circuits for the CX (CNOT) gate.""" 

1035 

1036 def pulse_circuit(w, pulse_params): 

1037 op.RY(w, wires=0) 

1038 op.H(wires=1) 

1039 Gates.CX(wires=[0, 1], pulse_params=pulse_params, gate_mode="pulse") 

1040 

1041 def target_circuit(w): 

1042 op.RY(w, wires=0) 

1043 op.H(wires=1) 

1044 op.CX(wires=[0, 1]) 

1045 

1046 return pulse_circuit, target_circuit 

1047 

1048 def create_CY(self): 

1049 """Create pulse and target circuits for the CY gate.""" 

1050 

1051 def pulse_circuit(w, pulse_params): 

1052 op.RX(w, wires=0) 

1053 op.H(wires=1) 

1054 Gates.CY(wires=[0, 1], pulse_params=pulse_params, gate_mode="pulse") 

1055 

1056 def target_circuit(w): 

1057 op.RX(w, wires=0) 

1058 op.H(wires=1) 

1059 op.CY(wires=[0, 1]) 

1060 

1061 return pulse_circuit, target_circuit 

1062 

1063 def create_CZ(self): 

1064 """Create pulse and target circuits for the CZ gate.""" 

1065 

1066 def pulse_circuit(w, pulse_params): 

1067 op.RY(w, wires=0) 

1068 op.H(wires=1) 

1069 Gates.CZ(wires=[0, 1], pulse_params=pulse_params, gate_mode="pulse") 

1070 

1071 def target_circuit(w): 

1072 op.RY(w, wires=0) 

1073 op.H(wires=1) 

1074 op.CZ(wires=[0, 1]) 

1075 

1076 return pulse_circuit, target_circuit 

1077 

1078 def create_CRX(self): 

1079 """Create pulse and target circuits for the CRX gate.""" 

1080 

1081 def pulse_circuit(w, pulse_params): 

1082 op.H(wires=0) 

1083 Gates.CRX(w, wires=[0, 1], pulse_params=pulse_params, gate_mode="pulse") 

1084 

1085 def target_circuit(w): 

1086 op.H(wires=0) 

1087 op.CRX(w, wires=[0, 1]) 

1088 

1089 return pulse_circuit, target_circuit 

1090 

1091 def create_CRY(self): 

1092 """Create pulse and target circuits for the CRY gate.""" 

1093 

1094 def pulse_circuit(w, pulse_params): 

1095 op.H(wires=0) 

1096 Gates.CRY(w, wires=[0, 1], pulse_params=pulse_params, gate_mode="pulse") 

1097 

1098 def target_circuit(w): 

1099 op.H(wires=0) 

1100 op.CRY(w, wires=[0, 1]) 

1101 

1102 return pulse_circuit, target_circuit 

1103 

1104 def create_CRZ(self): 

1105 """Create pulse and target circuits for the CRZ gate.""" 

1106 

1107 def pulse_circuit(w, pulse_params): 

1108 op.H(wires=0) 

1109 op.H(wires=1) 

1110 Gates.CRZ(w, wires=[0, 1], pulse_params=pulse_params, gate_mode="pulse") 

1111 

1112 def target_circuit(w): 

1113 op.H(wires=0) 

1114 op.H(wires=1) 

1115 op.CRZ(w, wires=[0, 1]) 

1116 

1117 return pulse_circuit, target_circuit 

1118 

1119 def optimize_all(self, sel_gates: str, make_log: bool) -> None: 

1120 """Optimise all selected gates and optionally write a log CSV. 

1121 

1122 Args: 

1123 sel_gates: Comma-separated gate names or ``"all"``. 

1124 make_log: If ``True``, write per-gate loss histories to 

1125 ``qml_essentials/qoc_logs.csv``. 

1126 """ 

1127 log_history: Dict[str, list] = {} 

1128 

1129 for gate in self.GATES_1Q + self.GATES_2Q: 

1130 if gate in sel_gates or "all" in sel_gates: 

1131 n_wires = 1 if gate in self.GATES_1Q else 2 

1132 opt = self.optimize(wires=n_wires) 

1133 gate_factory = getattr(self, f"create_{gate}") 

1134 log.info(f"Optimizing {gate} gate...") 

1135 optimized_pulse_params, loss_history = opt(gate_factory)() 

1136 log.info(f"Optimized parameters for {gate}: {optimized_pulse_params}") 

1137 best_fid = 1 - min(float(loss) for loss in loss_history) 

1138 log.info(f"Best achieved fidelity: {best_fid * 100:.5f}%") 

1139 log_history[gate] = log_history.get(gate, []) + loss_history 

1140 

1141 if make_log: 

1142 # write log history to file 

1143 with open("qml_essentials/qoc_logs.csv", "w") as f: 

1144 writer = csv.writer(f) 

1145 writer.writerow(log_history.keys()) 

1146 writer.writerows(zip(*log_history.values())) 

1147 

1148 

1149default_qoc_params = { 

1150 "envelope": "gaussian", 

1151 "cost_fns": [ 

1152 ("fidelity", (0.49999999, 0.49999999)), 

1153 ("pulse_width", 0.000000015), 

1154 ("evolution_time", 0.000000005), 

1155 ], 

1156 "t_target": 0.5, 

1157 "n_steps": 1500, 

1158 "n_samples": 20, 

1159 "learning_rate": 0.0001, 

1160 "warmup_ratio": 0.05, 

1161 "end_lr_ratio": 0.01, 

1162 "log_interval": 50, 

1163 "file_dir": None, 

1164 "n_restarts": 3, 

1165 "restart_noise_scale": 0.5, 

1166 "grad_clip": 1.0, 

1167 "random_seed": 1000, 

1168 "scan_steps": 30, 

1169 "scan_grid_size": 5, 

1170 "scan_ranges": None, 

1171 "log_scale_params": None, 

1172} 

1173 

1174if __name__ == "__main__": 

1175 # argparse the selected gate 

1176 parser = argparse.ArgumentParser( 

1177 description="Quantum Optimal Control — pulse-level gate synthesis." 

1178 ) 

1179 parser.add_argument( 

1180 "--gates", 

1181 type=str, 

1182 nargs="+", 

1183 default=["RX", "RY", "RZ", "CZ"], 

1184 choices=QOC.GATES_1Q + QOC.GATES_2Q + ["all"], 

1185 help="Gate(s) to optimize.", 

1186 ) 

1187 parser.add_argument( 

1188 "--log", 

1189 action="store_true", 

1190 default=True, 

1191 help="Log results to file (default: True).", 

1192 ) 

1193 parser.add_argument( 

1194 "--no-log", 

1195 action="store_false", 

1196 dest="log", 

1197 help="Disable logging results to file.", 

1198 ) 

1199 parser.add_argument( 

1200 "--envelope", 

1201 type=str, 

1202 default=default_qoc_params["envelope"], 

1203 choices=PulseEnvelope.available(), 

1204 help="Pulse envelope shape to use for optimization.", 

1205 ) 

1206 parser.add_argument( 

1207 "--costs", 

1208 type=str, 

1209 nargs="+", 

1210 default=default_qoc_params["cost_fns"], 

1211 help=( 

1212 "Cost functions and weights as 'name:w1,w2,...' strings. " 

1213 "If weights are omitted the registry defaults are used. " 

1214 f"Available: {CostFnRegistry.available()}. " 

1215 "Example: --costs fidelity:0.5,0.3 pulse_width:0.2" 

1216 ), 

1217 ) 

1218 parser.add_argument( 

1219 "--t_target", 

1220 type=float, 

1221 default=default_qoc_params["t_target"], 

1222 help=( 

1223 "Target evolution time for the 'evolution_time' cost function. " 

1224 "All gates will be softly encouraged towards this common time." 

1225 ), 

1226 ) 

1227 parser.add_argument( 

1228 "--n_steps", 

1229 type=int, 

1230 default=default_qoc_params["n_steps"], 

1231 help="Number of optimisation steps per gate.", 

1232 ) 

1233 parser.add_argument( 

1234 "--n_samples", 

1235 type=int, 

1236 default=default_qoc_params["n_samples"], 

1237 help="Number of parameter samples in [0, 2\\pi] for cost evaluation.", 

1238 ) 

1239 parser.add_argument( 

1240 "--learning_rate", 

1241 type=float, 

1242 default=default_qoc_params["learning_rate"], 

1243 help="Peak learning rate for the AdamW optimiser.", 

1244 ) 

1245 parser.add_argument( 

1246 "--warmup_ratio", 

1247 type=float, 

1248 default=default_qoc_params["warmup_ratio"], 

1249 help=( 

1250 "Fraction of n_steps used for linear LR warmup (0.0-1.0). " 

1251 "Set to 0 to start at the peak LR immediately." 

1252 ), 

1253 ) 

1254 parser.add_argument( 

1255 "--end_lr_ratio", 

1256 type=float, 

1257 default=default_qoc_params["end_lr_ratio"], 

1258 help=( 

1259 "Final LR as a fraction of --learning_rate after cosine decay. " 

1260 "Also used as the initial LR before warmup. " 

1261 "Set to 1.0 (with --warmup_ratio 0) for a constant LR." 

1262 ), 

1263 ) 

1264 parser.add_argument( 

1265 "--log_interval", 

1266 type=int, 

1267 default=default_qoc_params["log_interval"], 

1268 help="Log the current loss every N steps.", 

1269 ) 

1270 parser.add_argument( 

1271 "--file_dir", 

1272 type=str, 

1273 default=default_qoc_params["file_dir"], 

1274 help="Directory to save qoc_results.csv. Defaults to the package directory.", 

1275 ) 

1276 parser.add_argument( 

1277 "--n_restarts", 

1278 type=int, 

1279 default=default_qoc_params["n_restarts"], 

1280 help=( 

1281 "Number of random restarts for the optimisation. " 

1282 "The first run uses the initial parameters as-is; " 

1283 "subsequent runs add random perturbations. " 

1284 "The best result across all restarts is kept." 

1285 ), 

1286 ) 

1287 parser.add_argument( 

1288 "--restart_noise_scale", 

1289 type=float, 

1290 default=default_qoc_params["restart_noise_scale"], 

1291 help=( 

1292 "Standard deviation of Gaussian noise added to the initial " 

1293 "parameters for each restart, relative to parameter magnitude." 

1294 ), 

1295 ) 

1296 parser.add_argument( 

1297 "--grad_clip", 

1298 type=float, 

1299 default=default_qoc_params["grad_clip"], 

1300 help=( 

1301 "Maximum global gradient norm. Gradients are clipped to this " 

1302 "value before being passed to the optimiser. " 

1303 "Set to 0 to disable." 

1304 ), 

1305 ) 

1306 parser.add_argument( 

1307 "--random_seed", 

1308 type=int, 

1309 default=default_qoc_params["random_seed"], 

1310 help="Base random seed for restart perturbations.", 

1311 ) 

1312 parser.add_argument( 

1313 "--scan_steps", 

1314 type=int, 

1315 default=default_qoc_params["scan_steps"], 

1316 help=( 

1317 "Number of short gradient-descent steps per candidate in the " 

1318 "coarse grid scan (Stage 0). Set to 0 to disable the grid scan." 

1319 ), 

1320 ) 

1321 parser.add_argument( 

1322 "--scan_grid_size", 

1323 type=int, 

1324 default=default_qoc_params["scan_grid_size"], 

1325 help=( 

1326 "Number of points per parameter dimension in the coarse grid. " 

1327 "Total candidates = scan_grid_size^n_params." 

1328 ), 

1329 ) 

1330 parser.add_argument( 

1331 "--scan_ranges", 

1332 type=str, 

1333 nargs="*", 

1334 default=default_qoc_params["scan_ranges"], 

1335 help=( 

1336 "Per-parameter (lo,hi) ranges for the grid scan, given as " 

1337 "'lo,hi' strings. One pair per pulse parameter. " 

1338 "Example: --scan_ranges 0.5,30.0 0.05,2.0 0.05,2.0 " 

1339 "If omitted, heuristic defaults are used." 

1340 ), 

1341 ) 

1342 

1343 args = parser.parse_args() 

1344 sel_gates = args.gates # already a list from nargs="+" 

1345 make_log = args.log 

1346 

1347 # Parse scan_ranges from CLI (list of "lo,hi" strings -> list of tuples) 

1348 scan_ranges = None 

1349 if args.scan_ranges is not None: 

1350 scan_ranges = [] 

1351 for pair in args.scan_ranges: 

1352 lo, hi = pair.split(",") 

1353 scan_ranges.append((float(lo), float(hi))) 

1354 

1355 # Parse cost function specs from CLI 

1356 cost_fns = [CostFnRegistry.parse_cost_arg(spec) for spec in args.costs] 

1357 

1358 # create logger 

1359 log = logging.getLogger("qml_essentials.qoc") 

1360 

1361 log.setLevel(logging.INFO) 

1362 log.addHandler(logging.StreamHandler()) 

1363 

1364 qoc = QOC( 

1365 envelope=args.envelope, 

1366 cost_fns=cost_fns, 

1367 t_target=args.t_target, 

1368 n_steps=args.n_steps, 

1369 n_samples=args.n_samples, 

1370 learning_rate=args.learning_rate, 

1371 warmup_ratio=args.warmup_ratio, 

1372 end_lr_ratio=args.end_lr_ratio, 

1373 log_interval=args.log_interval, 

1374 file_dir=args.file_dir, 

1375 n_restarts=args.n_restarts, 

1376 restart_noise_scale=args.restart_noise_scale, 

1377 grad_clip=args.grad_clip, 

1378 random_seed=args.random_seed, 

1379 scan_steps=args.scan_steps, 

1380 scan_grid_size=args.scan_grid_size, 

1381 scan_ranges=scan_ranges, 

1382 ) 

1383 

1384 qoc.optimize_all(sel_gates=sel_gates, make_log=make_log)