Coverage for qml_essentials / qoc.py: 44%

894 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-05-16 10:19 +0000

1import argparse 

2import csv 

3import itertools 

4import logging 

5import os 

6from typing import Callable, Dict, List, Optional, Tuple, Union 

7 

8import jax 

9from jax import numpy as jnp 

10import numpy as np 

11import optax 

12 

13from qml_essentials.gates import Gates, PulseInformation, PulseEnvelope 

14from qml_essentials import operations as op 

15from qml_essentials import yaqsi as ys 

16from qml_essentials.math import phase_difference, fidelity 

17 

# Enable 64-bit floating point globally for JAX. This executes at import
# time and therefore affects every JAX computation in the process, not just
# this module. NOTE(review): presumably required so the explicit
# ``jnp.float64`` dtypes requested by the cost functions below are honoured.
jax.config.update("jax_enable_x64", True)
# Module-level logger used for optimisation progress and warnings.
log = logging.getLogger(__name__)

20 

21 

def _build_optimizer(schedule, grad_clip: float):
    """Return the AdamW optimiser shared by stage-0 and stage-1.

    When ``grad_clip`` is truthy, strictly positive and finite, AdamW is
    preceded by a global-norm gradient clip at that threshold. Any other
    value (``0``, ``None``, negative, ``inf`` or ``nan``) yields a bare
    AdamW transform.

    Args:
        schedule: Learning rate (or learning-rate schedule) for AdamW.
        grad_clip: Maximum global gradient norm.

    Returns:
        An ``optax`` gradient transformation.
    """
    clipping_requested = grad_clip and grad_clip > 0 and jnp.isfinite(grad_clip)
    if not clipping_requested:
        return optax.adamw(schedule)
    clip_step = optax.clip_by_global_norm(grad_clip)
    return optax.chain(clip_step, optax.adamw(schedule))

35 

36 

def _safe_eval(cost_fn: Callable, params: jnp.ndarray) -> jnp.ndarray:
    """Return ``cost_fn(params)`` with NaN and ±inf replaced by ``+inf``.

    A non-finite evaluation therefore ranks as the worst possible loss
    whenever losses are compared or minimised.
    """
    raw_loss = cost_fn(params)
    is_ok = jnp.isfinite(raw_loss)
    return jnp.where(is_ok, raw_loss, jnp.inf)

41 

42 

def _with_basis_prep(circuit_fn: Callable, k: int, n_wires: int) -> Callable:
    """Return a variant of ``circuit_fn`` that starts from basis state ``|k⟩``.

    ``k`` is read as an ``n_wires``-bit string with the most significant
    bit on wire 0. Every wire whose bit is 1 receives a ``PauliX`` before
    ``circuit_fn`` runs with the caller's arguments. Running the wrapped
    circuit for ``k = 0 .. 2**n_wires - 1`` gives the columns of the
    circuit's unitary, as consumed by :func:`unitary_cost_fn`.

    Args:
        circuit_fn: Circuit body to wrap.
        k: Index of the computational basis state to prepare.
        n_wires: Number of wires; fixes the bit width used to decode ``k``.

    Returns:
        The wrapping circuit (returns ``None``), named
        ``basis{k}_{circuit_fn.__name__}``.
    """
    # Decode once, at wrap time, which wires need a bit flip.
    flipped_wires = [
        wire for wire in range(n_wires) if (k >> (n_wires - 1 - wire)) & 1
    ]

    def prepared(*args, **kwargs):
        for wire in flipped_wires:
            op.PauliX(wires=wire)
        circuit_fn(*args, **kwargs)

    prepared.__name__ = f"basis{k}_{circuit_fn.__name__}"
    return prepared

61 

62 

def _sample_rotation_angles(n_samples: int) -> jnp.ndarray:
    """Boundary-biased sample of rotation angles in ``[0, 2π)``.

    The pulse-vs-target residual grows roughly linearly with the rotation
    angle, so a plain uniform grid over ``[0, 2π)`` underweights the
    high-residual band around π/2 … π. The returned array is the
    concatenation of

    * ``n_samples - n_samples // 3`` points on a uniform grid over the full
      ``[0, 2π)`` circle, followed by
    * ``max(1, n_samples // 3)`` points on a uniform grid over the central
      ``[π/2, 3π/2)`` band,

    so the central band ends up with roughly twice the sample density of
    the tails. Points of the two grids may coincide (e.g. ``n_samples=6``
    repeats π/2 and π), which simply weights them double.

    For ``n_samples <= 1`` a single angle ``w = 0`` is returned.

    Args:
        n_samples: Requested number of angles.

    Returns:
        1-D array of rotation angles.
    """
    two_pi = 2.0 * jnp.pi
    if n_samples <= 1:
        # Degenerate request: keep the legacy single sample at w = 0.
        return jnp.linspace(0.0, two_pi, max(n_samples, 1), endpoint=False)

    # Both sub-grids are non-empty for any n_samples >= 2.
    n_focus = max(1, n_samples // 3)
    n_uniform = n_samples - n_focus
    full_circle = jnp.linspace(0.0, two_pi, n_uniform, endpoint=False)
    central_band = jnp.linspace(
        0.5 * jnp.pi, 1.5 * jnp.pi, n_focus, endpoint=False
    )
    return jnp.concatenate([full_circle, central_band])

89 

90 

def _run_gate_stage(stage: Optional[Callable], w) -> None:
    """Call ``stage(w)`` when a stage is given; do nothing for ``None``."""
    if stage is None:
        return
    stage(w)

95 

96 

def _chain_gate_stages(*stages: Callable) -> Callable:
    """Combine several preparation stages into a single stage.

    The returned callable invokes every stage with the same rotation angle
    ``w``, in the order the stages were passed, and returns ``None``.
    """
    ordered = tuple(stages)

    def chained(w):
        for current in ordered:
            current(w)

    return chained

105 

106 

def _make_gate_pair(
    pulse_gate: Callable,
    target_gate: Callable,
    prep: Optional[Callable] = None,
    post: Optional[Callable] = None,
) -> Tuple[Callable, Callable]:
    """Build a matching (pulse, target) circuit pair.

    Both circuits run the same optional ``prep`` stage before, and the same
    optional ``post`` stage after, their core gate. Each stage receives
    only the rotation angle ``w``.

    Args:
        pulse_gate: Pulse-level gate, called as ``pulse_gate(w, pp)``.
        target_gate: Reference gate, called as ``target_gate(w)``.
        prep: Optional stage executed before the core gate.
        post: Optional stage executed after the core gate.

    Returns:
        ``(pulse_circuit(w, pp), target_circuit(w))``.
    """

    def pulse_circuit(w, pp):
        if prep is not None:
            prep(w)
        pulse_gate(w, pp)
        if post is not None:
            post(w)

    def target_circuit(w):
        if prep is not None:
            prep(w)
        target_gate(w)
        if post is not None:
            post(w)

    return pulse_circuit, target_circuit

126 

127 

128class Cost: 

129 """Weighted wrapper around a cost function. 

130 

131 Combines a cost callable with a scalar or tuple weight and optional 

132 constant keyword arguments. Multiple ``Cost`` instances can be 

133 composed via the ``+`` operator to build a combined objective. 

134 

135 Args: 

136 cost: Callable ``(pulse_params, **ckwargs) -> scalar | tuple``. 

137 weight: Scalar or tuple of per-component weights. 

138 ckwargs: Constant keyword arguments injected into every call. 

139 """ 

140 

141 def __init__( 

142 self, 

143 cost: Callable, 

144 weight: Union[float, Tuple], 

145 ckwargs: Optional[dict] = None, 

146 ): 

147 self.cost = cost 

148 self.weight = weight 

149 self.ckwargs = ckwargs if ckwargs is not None else {} 

150 

151 def __call__(self, *args, **kwargs): 

152 """Evaluate the cost function with injected kwargs and apply weights.""" 

153 cost = self.cost(*args, **kwargs, **self.ckwargs) 

154 if isinstance(self.weight, tuple): 

155 return jnp.array( 

156 [c * w for c, w in zip(cost, self.weight, strict=True)] 

157 ).sum() 

158 return cost * self.weight 

159 

160 def __add__(self, other): 

161 """Compose two cost terms into a single callable that sums them.""" 

162 if other is None: 

163 return lambda *args, **kwargs: self(*args, **kwargs) 

164 if callable(other): 

165 return lambda *args, **kwargs: ( 

166 self(*args, **kwargs) + other(*args, **kwargs) 

167 ) 

168 raise TypeError(f"Cannot add Cost and {type(other)}") 

169 

170 

def fidelity_cost_fn(
    pulse_params: jnp.ndarray,
    pulse_scripts: Union[ys.Script, List[ys.Script]],
    target_scripts: Union[ys.Script, List[ys.Script]],
    n_samples: int,
) -> Tuple[float, float]:
    """
    Cost function returning ``(1 - fidelity, 1 - cos(phase_difference))``
    averaged over ``n_samples`` rotation angles in ``[0, 2π)`` and across
    one or more (pulse, target) script pairs.

    The angles come from :func:`_sample_rotation_angles`, which is *not* a
    uniform grid. About one third of the samples (``max(1, n_samples // 3)``)
    are placed in the central ``[π/2, 3π/2)`` band on top of a uniform
    sweep of the full circle. ``n_samples <= 1`` yields the single angle
    ``w = 0``.

    Multiple script pairs let the optimiser probe sensitivity from
    multiple initial states (e.g. ``|0⟩`` and ``|+⟩``). This makes
    rotation-axis tilt observable to the cost: from ``|0⟩`` alone an
    RX/RY pulse with a small Z-component is largely degenerate with
    the correct pulse, but from ``|+⟩`` the same tilt produces a
    visible state-vector deviation.

    Uses batched (vmapped) circuit execution per script. All sampled
    rotation angles are evaluated in a single vectorised call per script
    (``in_axes=(0, None)`` for the pulse script, ``(0,)`` for the target),
    instead of one Python-level circuit execution per angle.

    The phase term uses ``1 - cos(Δφ)`` rather than ``|Δφ|`` so that
    it is differentiable everywhere (including at the optimum) and
    well-behaved at the ``±π`` wrap-around. This matters because Stage 0
    sees the same cost as Stage 1.

    Args:
        pulse_params: Pulse parameters for evaluation.
        pulse_scripts: One or a list of yaqsi scripts with pulse
            parameters. If a list is supplied, the cost is averaged
            element-wise with ``target_scripts`` (which must have the
            same length).
        target_scripts: One or a list of yaqsi target scripts.
        n_samples: Number of rotation-angle samples (see above for how
            they are distributed).

    Returns:
        Tuple of ``(abs_diff, phase_diff)`` averaged across script pairs.

    Raises:
        AssertionError: If the two script lists differ in length.
    """
    if not isinstance(pulse_scripts, (list, tuple)):
        pulse_scripts = [pulse_scripts]
    if not isinstance(target_scripts, (list, tuple)):
        target_scripts = [target_scripts]
    assert len(pulse_scripts) == len(target_scripts), (
        f"pulse_scripts and target_scripts must have the same length "
        f"({len(pulse_scripts)} vs {len(target_scripts)})."
    )

    ws = _sample_rotation_angles(n_samples)
    one = jnp.array(1.0, dtype=jnp.float64)

    abs_diffs = []
    phase_diffs = []
    for p_script, t_script in zip(pulse_scripts, target_scripts):
        pulse_states = p_script.execute(
            type="state",
            args=(ws, pulse_params),
            in_axes=(0, None),
        )  # one state per sampled angle

        target_states = t_script.execute(
            type="state",
            args=(ws,),
            in_axes=(0,),
        )  # one state per sampled angle

        abs_diffs.append(jnp.mean(one - fidelity(pulse_states, target_states)))
        phase_diffs.append(
            jnp.mean(one - jnp.cos(phase_difference(pulse_states, target_states)))
        )

    abs_diff = jnp.mean(jnp.stack(abs_diffs))
    phase_diff = jnp.mean(jnp.stack(phase_diffs))

    # TODO: in future we could consider some sort of log based loss for the small values
    # or utilize gradient ascent if we run into numerical limitations

    return (abs_diff, phase_diff)

257 

258 

def unitary_cost_fn(
    pulse_params: jnp.ndarray,
    pulse_basis_scripts: List[ys.Script],
    target_basis_scripts: List[ys.Script],
    n_samples: int,
    n_qubits: int,
) -> Tuple[float, float]:
    """Unitary-level cost based on the process fidelity.

    Builds the full unitary of the pulse and target circuits at every
    sampled rotation angle by stacking ``2**n_qubits`` basis-state
    evolutions as columns (``U[:, k] = circuit(|k⟩)``). Returns

        (1 - |Tr(E)|² / d², 1 - cos(angle(Tr(E))))

    where ``E = U_target† · U_pulse`` and ``d = 2**n_qubits``.

    The first component uses the process (entanglement) fidelity
    ``F_pro = |Tr(E)|² / d²``. It is *global-phase invariant*, and it is
    related to the average gate fidelity by the affine map
    ``F_avg = (d · F_pro + 1) / (d + 1)``, so minimising either gives the
    same optimum. The second component captures the residual global
    phase between pulse and target. Without it the optimiser cannot
    distinguish ``U_pulse`` and ``e^{iα} U_pulse``, which leaves systematic
    phase errors in composed gates (e.g. the H-CZ-H decomposition of CX).

    Compared to the state-vector ``fidelity_cost_fn``, this cost
    captures rotation-axis tilt and off-diagonal coherent error in a
    single number, regardless of which probe state(s) one chooses.

    Args:
        pulse_params: Pulse parameters under optimisation.
        pulse_basis_scripts: List of ``d`` scripts; the k-th script
            prepares ``|k⟩`` (via ``PauliX`` gates) and then applies
            the pulse-level circuit.
        target_basis_scripts: Same for the target circuit.
        n_samples: Number of rotation-angle samples in ``[0, 2π)``, drawn
            by :func:`_sample_rotation_angles` (not a uniform grid).
        n_qubits: Number of qubits the gate acts on.

    Returns:
        Tuple ``(process_loss, phase_loss)`` averaged over rotation
        angles.

    Raises:
        AssertionError: If either script list does not have ``d`` entries.
    """
    d = 2**n_qubits
    assert len(pulse_basis_scripts) == d, (
        f"pulse_basis_scripts must have {d} entries (one per basis "
        f"state); got {len(pulse_basis_scripts)}."
    )
    assert len(target_basis_scripts) == d, (
        f"target_basis_scripts must have {d} entries (one per basis "
        f"state); got {len(target_basis_scripts)}."
    )

    ws = _sample_rotation_angles(n_samples)
    one = jnp.array(1.0, dtype=jnp.float64)

    # Column k of each unitary is the evolved basis state |k⟩; both lists
    # have exactly d entries (asserted above), so zip covers every column.
    pulse_cols = []
    target_cols = []
    for p_script, t_script in zip(pulse_basis_scripts, target_basis_scripts):
        pulse_cols.append(
            p_script.execute(
                type="state",
                args=(ws, pulse_params),
                in_axes=(0, None),
            )
        )
        target_cols.append(
            t_script.execute(
                type="state",
                args=(ws,),
                in_axes=(0,),
            )
        )

    # Stack basis-state outputs as columns of U at every sampled angle;
    # stacking on the last axis puts column k at U[s, :, k].
    U_pulse = jnp.stack(pulse_cols, axis=-1)
    U_target = jnp.stack(target_cols, axis=-1)

    # E[s] = U_target[s]^† U_pulse[s]
    E = jnp.einsum("sji,sjk->sik", jnp.conj(U_target), U_pulse)
    trE = jnp.einsum("sii->s", E)

    F_pro = jnp.abs(trE) ** 2 / float(d) ** 2
    process_loss = jnp.mean(one - F_pro)
    # NOTE(review): the phase of Tr(E) is ill-defined when Tr(E) ≈ 0; the
    # gradient of jnp.angle there may be non-finite. Confirm this cannot
    # occur in practice or guard it.
    phase_loss = jnp.mean(one - jnp.cos(jnp.angle(trE)))

    return (process_loss, phase_loss)

342 

343 

def joint_unitary_cost_fn(
    pulse_params: jnp.ndarray,
    gate_specs: List[dict],
    n_samples: int,
) -> Tuple[float, float]:
    """Weighted average of :func:`unitary_cost_fn` over several target gates.

    All gates share the joint parameter vector ``pulse_params`` (theta).
    Each entry of ``gate_specs`` describes one gate::

        {
            "name": str,                       # gate name (debug)
            "n_qubits": int,
            "weight": float,                   # per-gate weight
            "assembler": Callable,             # theta -> per-gate flat params
            "pulse_basis_scripts": List[ys.Script],  # 2**n_qubits scripts
            "target_basis_scripts": List[ys.Script],
        }

    Each returned component is ``Σ_g w_g · loss_g(theta)``, divided by
    ``Σ_g w_g`` when that sum is positive. Otherwise it is returned
    undivided, which gives ``(0.0, 0.0)`` for an empty spec list. Sharing
    the leaf parameters across every use-site (composite gates as well as
    standalone leaves) steers the optimum towards a basin that works
    everywhere. This avoids a "selfish" per-gate optimum for a leaf that
    breaks the composites containing it.

    Args:
        pulse_params: Joint leaf parameter vector (theta).
        gate_specs: Per-gate spec dicts as described above.
        n_samples: Number of rotation-angle samples per gate.

    Returns:
        Tuple ``(process_loss, phase_loss)`` averaged over angles and
        weight-averaged over gates.
    """
    weighted_proc = []
    weighted_phase = []
    weights = []

    for spec in gate_specs:
        gate_proc, gate_phase = unitary_cost_fn(
            spec["assembler"](pulse_params),
            spec["pulse_basis_scripts"],
            spec["target_basis_scripts"],
            n_samples,
            spec["n_qubits"],
        )
        gate_weight = spec["weight"]
        weighted_proc.append(gate_weight * gate_proc)
        weighted_phase.append(gate_weight * gate_phase)
        weights.append(gate_weight)

    # Accumulate left-to-right from float64 zeros.
    total_proc = sum(weighted_proc, jnp.array(0.0, dtype=jnp.float64))
    total_phase = sum(weighted_phase, jnp.array(0.0, dtype=jnp.float64))
    weight_sum = sum(weights, 0.0)

    if weight_sum > 0:
        return (total_proc / weight_sum, total_phase / weight_sum)
    return (total_proc, total_phase)

404 

405 

def pulse_width_cost_fn(
    pulse_params: jnp.ndarray,
    envelope: str,
) -> jnp.ndarray:
    """
    Cost term equal to the pulse width (sigma / width) of a gate.

    The width is read from ``pulse_params[n_envelope_params - 1]``, i.e.
    the last envelope parameter of the active envelope. Envelopes without
    envelope parameters (e.g. ``"general"``) contribute a cost of zero.

    Args:
        pulse_params: Pulse parameters for the gate.
        envelope: Name of the active pulse envelope.

    Returns:
        Scalar float64 pulse-width cost.
    """
    n_envelope_params = PulseEnvelope.get(envelope)["n_envelope_params"]
    pulse_width = (
        pulse_params[n_envelope_params - 1] if n_envelope_params > 0 else 0
    )
    return jnp.array(pulse_width, dtype=jnp.float64)

433 

434 

def evolution_time_cost_fn(
    pulse_params: jnp.ndarray,
    t_target: float,
) -> jnp.ndarray:
    """
    Squared relative deviation of the evolution time from ``t_target``.

    The evolution time ``t`` is the last entry of ``pulse_params``, and the
    cost is::

        ((t - t_target) / t_target) ** 2

    Pulling every independently optimised gate towards the same evolution
    time keeps the gates compatible when they are composed into a circuit.

    Args:
        pulse_params: Pulse parameters for the gate.
        t_target: Target evolution time.

    Returns:
        Scalar evolution-time cost.
    """
    relative_deviation = (pulse_params[-1] - t_target) / t_target
    return relative_deviation ** 2

460 

461 

def spectral_density_cost_fn(
    pulse_params: jnp.ndarray,
    envelope: str,
    n_fft: int = 1024,
) -> jnp.ndarray:
    """
    Cost term equal to the normalised RMS bandwidth of a pulse envelope.

    The envelope is sampled at ``n_fft`` evenly spaced times over
    ``[0, t_evol]``, centred at ``t_evol / 2``, where ``t_evol`` is
    ``pulse_params[-1]``. Its power spectral density (``|rfft|²``) is
    normalised to a distribution over frequencies rescaled to ``[0, 1]``,
    with 1 as the Nyquist frequency. The returned value is the standard
    deviation of that distribution, so it lies in ``[0, 1]`` (in fact at
    most 0.5).

    Narrow-spectrum shapes (e.g. Gaussian, DRAG) therefore cost little,
    while wide-spectrum ones (e.g. rectangular) are penalised more.
    Envelopes without envelope parameters (e.g. ``"general"``), or without
    an envelope function, cost zero.

    Args:
        pulse_params: Pulse parameters for the gate. The envelope
            parameters are ``pulse_params[:n_envelope_params]`` and the
            evolution time is ``pulse_params[-1]``.
        envelope: Name of the active pulse envelope.
        n_fft: Number of time-domain samples fed to the FFT (default 1024).

    Returns:
        Scalar float64 spectral-width cost.
    """
    info = PulseEnvelope.get(envelope)
    n_env = info["n_envelope_params"]
    shape_fn = info["fn"]

    if shape_fn is None or n_env == 0:
        return jnp.array(0.0, dtype=jnp.float64)

    shape_params = pulse_params[:n_env]
    duration = pulse_params[-1]
    centre = duration / 2.0

    times = jnp.linspace(0.0, duration, n_fft)
    samples = jax.vmap(lambda t: shape_fn(shape_params, t, centre))(times)

    power = jnp.abs(jnp.fft.rfft(samples)) ** 2
    power = power / (jnp.sum(power) + 1e-12)  # epsilon guards an all-zero signal

    norm_freqs = jnp.linspace(0.0, 1.0, len(power))
    centroid = jnp.sum(norm_freqs * power)
    bandwidth = jnp.sqrt(jnp.sum((norm_freqs - centroid) ** 2 * power))

    return jnp.array(bandwidth, dtype=jnp.float64)

520 

521 

class CostFnRegistry:
    """Registry of cost functions available for pulse optimisation.

    Use :meth:`register` to add new cost functions at runtime and
    :meth:`get` / :meth:`available` to query them. The registry is a
    class-level dict, so registrations are visible to every user of the
    class in the current process.
    """

    _REGISTRY: Dict[str, dict] = {
        "fidelity": {
            "fn": fidelity_cost_fn,
            "default_weight": (0.5, 0.5),
            "ckwargs_keys": ["pulse_scripts", "target_scripts", "n_samples"],
        },
        "unitary": {
            "fn": unitary_cost_fn,
            "default_weight": (0.5, 0.5),
            "ckwargs_keys": [
                "pulse_basis_scripts",
                "target_basis_scripts",
                "n_samples",
                "n_qubits",
            ],
        },
        "pulse_width": {
            "fn": pulse_width_cost_fn,
            "default_weight": 1.0,
            "ckwargs_keys": ["envelope"],
        },
        "evolution_time": {
            "fn": evolution_time_cost_fn,
            "default_weight": 1.0,
            "ckwargs_keys": ["t_target"],
        },
        "spectral_density": {
            "fn": spectral_density_cost_fn,
            "default_weight": 1.0,
            "ckwargs_keys": ["envelope"],
        },
    }

    @classmethod
    def register(
        cls,
        name: str,
        fn: Callable,
        default_weight: Union[float, Tuple[float, ...]] = 1.0,
        ckwargs_keys: Optional[List[str]] = None,
        *,
        overwrite: bool = False,
    ) -> None:
        """Register a new cost function under ``name``.

        Args:
            name: Registry key used to select the cost function.
            fn: Cost callable ``(pulse_params, **ckwargs) -> scalar | tuple``.
            default_weight: Weight used when none is given. Use a tuple
                with one entry per returned component if ``fn`` returns a
                tuple.
            ckwargs_keys: Names of the constant keyword arguments ``fn``
                expects. NOTE(review): the optimiser presumably supplies
                these by name; confirm it can provide every key you list.
            overwrite: Replace an existing entry instead of raising.

        Raises:
            TypeError: If ``fn`` is not callable.
            ValueError: If ``name`` is already registered and
                ``overwrite`` is ``False``.
        """
        if not callable(fn):
            raise TypeError(
                f"Cost function '{name}' must be callable, got {type(fn)}."
            )
        if name in cls._REGISTRY and not overwrite:
            raise ValueError(
                f"Cost function '{name}' is already registered; "
                f"pass overwrite=True to replace it."
            )
        cls._REGISTRY[name] = {
            "fn": fn,
            "default_weight": default_weight,
            "ckwargs_keys": list(ckwargs_keys) if ckwargs_keys is not None else [],
        }

    @classmethod
    def available(cls) -> List[str]:
        """Return the names of all registered cost functions."""
        return list(cls._REGISTRY.keys())

    @classmethod
    def get(cls, name: str) -> dict:
        """Look up cost-function metadata by name.

        Args:
            name: Registered cost function name.

        Returns:
            Metadata dict with keys ``fn``,
            ``default_weight``, ``ckwargs_keys``.

        Raises:
            ValueError: If name is not registered.
        """
        if name not in cls._REGISTRY:
            raise ValueError(
                f"Unknown cost function '{name}'. Available: {cls.available()}"
            )
        return cls._REGISTRY[name]

    @classmethod
    def parse_cost_arg(
        cls, spec: Union[str, Tuple]
    ) -> Tuple[str, Union[float, Tuple[float, ...]]]:
        """Parse a ``"name:w1,w2,..."`` CLI string into ``(name, weight)``.
        If a tuple is provided, it is returned directly (unvalidated).

        If the weight part is omitted, the default weight from the registry
        is used. A single-component weight is returned as a float;
        multi-component weights are returned as a tuple of floats.
        Whitespace around the name and the weights is ignored.

        Args:
            spec: A string of the form ``"name"`` or ``"name:w1,w2,..."``.

        Returns:
            A tuple of ``(name, weight)``.

        Raises:
            ValueError: If the name is unknown, a weight is not a number,
                or the number of weight components does not match the
                ones in ``default_weight``.
        """
        if isinstance(spec, tuple):
            return spec

        name, sep, weight_str = spec.partition(":")
        name = name.strip()
        default_weight = cls.get(name)["default_weight"]  # raises if unknown

        if sep:
            parts = [float(x) for x in weight_str.split(",")]
            weight: Union[float, Tuple[float, ...]] = (
                parts[0] if len(parts) == 1 else tuple(parts)
            )
        else:
            weight = default_weight

        # Validate weight count against the registered default.
        got = len(weight) if isinstance(weight, tuple) else 1
        expected = len(default_weight) if isinstance(default_weight, tuple) else 1

        if got != expected:
            raise ValueError(
                f"Cost function '{name}' expects {expected} weight(s), got {got}."
            )

        return name, weight

632 

633 

634class QOC: 

635 """Quantum Optimal Control for pulse-level gate synthesis. 

636 

637 Optimises pulse parameters to reproduce the unitary of standard 

638 quantum gates using a two-stage strategy. 

639 

640 Attributes: 

641 GATES_1Q: Names of supported single-qubit gates. 

642 GATES_2Q: Names of supported two-qubit gates. 

643 DEFAULT_PARAM_RANGES: Default parameter ranges for each gate. 

644 """ 

645 

646 GATES_1Q: List[str] = ["RX", "RY", "RZ", "Rot", "H"] 

647 GATES_2Q: List[str] = ["CX", "CY", "CZ", "CRX", "CRY", "CRZ"] 

648 

649 DEFAULT_PARAM_RANGES = { 

650 1: [(0.05, 3.0)], # evolution time 

651 2: [(0.05, 3.0), (0.05, 3.0)], # not typically used 

652 3: [(0.05, 3.0), (0.05, 3.0), (0.05, 3.0)], # [A, sigma, t] 

653 4: [(0.05, 3.0), (0.05, 3.0), (0.05, 3.0), (0.05, 3.0)], # [A, beta, sigma, t] 

654 } 

655 

656 def __init__( 

657 self, 

658 envelope: str, 

659 cost_fns: List[Tuple[str, Union[float, Tuple[float, ...]]]], 

660 t_target: float, 

661 n_steps: int, 

662 n_samples: int, 

663 learning_rate: float, 

664 log_interval: int = 50, 

665 file_dir: str = None, 

666 warmup_ratio: float = 0.0, 

667 end_lr_ratio: float = 1.0, 

668 n_restarts: int = 1, 

669 restart_noise_scale: float = 0.5, 

670 grad_clip: float = 1.0, 

671 random_seed: int = 42, 

672 scan_steps: int = 0, 

673 scan_grid_size: int = 5, 

674 scan_ranges: Optional[List[Tuple[float, float]]] = None, 

675 log_scale_params: Optional[List[int]] = None, 

676 early_stop_patience: int = 0, 

677 early_stop_min_delta: float = 0.0, 

678 plot: bool = False, 

679 ): 

680 """ 

681 Initialize Quantum Optimal Control with Pulse-level Gates. 

682 

683 Args: 

684 envelope (str): Pulse envelope shape to use for optimization. 

685 Must be one of the registered envelopes in PulseEnvelope 

686 (e.g. 'gaussian', 'square', 'cosine', 'drag', 'sech'). 

687 cost_fns (list): List of ``(name, weight)`` tuples that select 

688 which cost functions to use and their weights. name must 

689 be a key in :class:`CostFnRegistry`. *weight* is either a 

690 single float or a tuple of floats matching the number of 

691 return values of the cost function. 

692 t_target (float, optional): Target evolution time for the 

693 ``evolution_time`` cost function. Required when 

694 ``"evolution_time"`` is among the selected cost functions. 

695 n_steps (int): Number of steps in optimization. 

696 n_samples (int): Number of parameter samples per step. 

697 learning_rate (float): Peak learning rate for AdamW. When a 

698 warmup/decay schedule is active this is the maximum LR 

699 reached after the warmup phase. 

700 log_interval (int): Interval for logging. 

701 file_dir (str): Directory to save results. 

702 warmup_ratio (float): Fraction of ``n_steps`` used for linear 

703 warmup (0.0 - 1.0). Set to 0.0 to disable warmup and use 

704 a constant learning rate throughout. A value of e.g. 0.05 

705 means the first 5 % of steps linearly ramp the LR from 

706 ``end_lr_ratio * learning_rate`` to ``learning_rate``. 

707 end_lr_ratio (float): The final learning rate is 

708 ``end_lr_ratio * learning_rate``. Also used as the initial 

709 LR at the start of warmup. Set to 0.0 for full cosine 

710 decay to zero; set to 1.0 (together with 

711 ``warmup_ratio=0.0``) to recover a constant LR. 

712 n_restarts (int): Number of random restarts for the 

713 optimisation. The first run uses the initial parameters 

714 as-is; subsequent runs add scaled random perturbations. 

715 The best result across all restarts is kept. 

716 Set to 1 to disable restarts (default behaviour). 

717 restart_noise_scale (float): Standard deviation of the 

718 Gaussian noise added to the initial parameters for each 

719 restart (relative to the absolute value of each parameter). 

720 Defaults to 0.5 (50 % relative perturbation). Note that 

721 the package-level default in ``default_qoc_params`` is a 

722 much smaller ``0.01`` because the QOC loss landscape is 

723 highly sensitive to initial conditions and large 

724 perturbations routinely move restarts into useless 

725 basins; tune up only if you have reason to believe the 

726 initial point is far from any good basin. 

727 grad_clip (float): Maximum global gradient norm. Gradients 

728 are clipped to this value before being passed to the 

729 optimiser, which stabilises training when the loss 

730 landscape has steep regions. Set to ``float('inf')`` or 

731 0.0 to disable. Defaults to 1.0. 

732 random_seed (int): Base random seed for generating restart 

733 perturbations. Defaults to 42. 

734 scan_steps (int): Number of short gradient-descent steps to 

735 run for each candidate in the coarse grid search 

736 (Stage 0). Set to 0 to disable the grid scan entirely 

737 and rely solely on restarts. A value of 20-50 is 

738 usually enough to identify promising basins. Defaults 

739 to 0. 

740 scan_grid_size (int): Number of points per parameter 

741 dimension in the coarse grid. The total number of 

742 candidates is ``scan_grid_size ** n_params``, so keep 

743 this small for high-dimensional parameter spaces. 

744 Defaults to 5. 

745 scan_ranges (Optional[List[Tuple[float, float]]]): Per- 

746 parameter ``(lo, hi)`` ranges for the grid scan. If 

747 ``None``, heuristic ranges are used based on the 

748 envelope type: amplitude in ``[0.5, 30]``, width/sigma 

749 in ``[0.05, 2]``, and evolution time in ``[0.05, 2]``. 

750 Must have length equal to the number of pulse parameters 

751 if provided. 

752 log_scale_params (Optional[List[int]]): Indices of pulse 

753 parameters that should be optimised in log-space. For 

754 these parameters the optimizer sees ``log(p)`` and the 

755 actual parameter used in the simulation is ``exp(log_p)``. 

756 This dramatically improves convergence when the optimal 

757 value may differ from the initial value by an order of 

758 magnitude (e.g. amplitude, evolution time). 

759 If ``None``, defaults to ``[0, -1]`` (amplitude and 

760 evolution time) for envelopes with ≥ 2 envelope params, 

761 or ``[]`` otherwise. 

762 early_stop_patience (int): Number of consecutive 

763 Stage-1 steps with no improvement greater than 

764 ``early_stop_min_delta`` after which optimisation 

765 exits early. Set to ``0`` (default) to disable. 

766 Only honoured in the single-restart (sequential) 

767 path; when ``n_restarts > 1`` the parallel 

768 vmap+scan path always runs the full ``n_steps``. 

769 early_stop_min_delta (float): Minimum decrease in loss 

770 that counts as an improvement for the early-stopping 

771 patience counter. Defaults to ``0.0`` (any strict 

772 improvement resets the counter). 

773 plot (bool): If ``True``, save a loss-landscape figure after 

774 Phase 0 and a loss-curve figure after Phase 1 to 

775 ``file_dir``. Requires ``matplotlib`` to be installed. 

776 Defaults to ``False``. 

777 """ 

778 self.envelope = envelope 

779 self.n_steps = n_steps 

780 self.n_samples = n_samples 

781 self.learning_rate = learning_rate 

782 self.warmup_ratio = warmup_ratio 

783 self.end_lr_ratio = end_lr_ratio 

784 self.log_interval = log_interval 

785 self.file_dir = ( 

786 file_dir if file_dir else os.path.dirname(os.path.realpath(__file__)) 

787 ) 

788 self.t_target = t_target 

789 self.n_restarts = max(1, n_restarts) 

790 self.restart_noise_scale = restart_noise_scale 

791 self.grad_clip = grad_clip 

792 self.random_key = jax.random.PRNGKey(random_seed) 

793 self.scan_steps = scan_steps 

794 self.scan_grid_size = scan_grid_size 

795 self.scan_ranges = scan_ranges 

796 

797 # Determine log-scale param indices 

798 envelope_info = PulseEnvelope.get(envelope) 

799 n_env = envelope_info["n_envelope_params"] 

800 if log_scale_params is not None: 

801 self.log_scale_params = log_scale_params 

802 elif n_env >= 2: 

803 # Default: amplitude (index 0) and evolution time (last) 

804 self.log_scale_params = [0, -1] 

805 else: 

806 self.log_scale_params = [] 

807 

808 # Mask cache used by ``_to_log_space``/``_from_log_space``; 

809 # rebuilt lazily because the mask length depends on the size of 

810 # the param vector being converted (per-gate vs joint). 

811 self._log_mask_cache: Dict[int, jnp.ndarray] = {} 

812 

813 self.early_stop_patience = max(0, int(early_stop_patience)) 

814 self.early_stop_min_delta = float(early_stop_min_delta) 

815 

816 self.plot = plot 

817 

818 log.info( 

819 f"Training parameters: {self.n_steps} steps, " 

820 f"{self.n_samples} samples, {self.learning_rate} learning rate" 

821 ) 

822 log.info( 

823 f"LR schedule: warmup_ratio={self.warmup_ratio}, " 

824 f"end_lr_ratio={self.end_lr_ratio}" 

825 ) 

826 

827 log.info(f"Envelope: {self.envelope}") 

828 log.info(f"Target evolution time: {self.t_target}") 

829 log.info( 

830 f"Restarts: {self.n_restarts}, noise_scale={self.restart_noise_scale}, " 

831 f"grad_clip={self.grad_clip}" 

832 ) 

833 if PulseInformation.get_rwa(): 

834 log.info("Using RWA. Rotating frame is ignored.") 

835 else: 

836 log.info(f"Using no RWA and {PulseInformation.get_frame()} frame.") 

837 

838 if self.early_stop_patience > 0: 

839 log.info( 

840 f"Early stopping: patience={self.early_stop_patience}, " 

841 f"min_delta={self.early_stop_min_delta:g}" 

842 ) 

843 log.info( 

844 f"Grid scan: scan_steps={self.scan_steps}, " 

845 f"scan_grid_size={self.scan_grid_size}, " 

846 f"log_scale_params={self.log_scale_params}" 

847 ) 

848 log.info(f"Using cost function(s) {cost_fns}") 

849 

850 # Validate each entry against the registry 

851 summed_weights = 0 

852 for name, _weight in cost_fns: 

853 CostFnRegistry.get(name) # raises ValueError if unknown 

854 summed_weights += sum(_weight) if isinstance(_weight, tuple) else _weight 

855 assert jnp.isclose(summed_weights, 1.0, rtol=1e-8), ( 

856 f"Cost function weights must sum to 1. Got {summed_weights}" 

857 ) 

858 

859 self.cost_fns = cost_fns 

860 

861 # Configure the pulse system with the selected envelope 

862 PulseInformation.set_envelope(self.envelope) 

863 

def save_results(self, gate: str, fidelity: float, pulse_params) -> None:
    """Save optimised pulse parameters and fidelity for a gate to CSV.

    Results are stored in ``{file_dir}/qoc_results_{envelope}.csv`` with
    one row per gate: ``gate, fidelity, *pulse_params``. If the gate
    already exists in the file, its entry is overwritten regardless of
    whether the new fidelity is higher; a warning is logged when the
    stored fidelity was at least as good. Rows for other gates keep their
    order and a new gate is appended at the end. Blank lines are dropped,
    and if the gate appears more than once only the first occurrence is
    replaced (later duplicates are removed).

    The complete new file contents are assembled in memory first and then
    written to a temporary sibling file that replaces the original via
    :func:`os.replace`. Parsing the existing rows therefore can no longer
    fail while the results file is already truncated, so a malformed row
    cannot wipe out previously saved results.

    Nothing is written when ``self.file_dir`` is ``None``.

    Args:
        gate: Name of the gate (e.g. ``"RX"``).
        fidelity: Achieved fidelity of the optimised pulse.
        pulse_params (jnp.ndarray): Optimised pulse parameters for the gate.
    """
    if self.file_dir is None:
        # No output directory configured; ``os.path.join(None, ...)``
        # would raise a TypeError.
        return
    os.makedirs(self.file_dir, exist_ok=True)
    filename = os.path.join(self.file_dir, f"qoc_results_{self.envelope}.csv")

    existing_rows = []
    if os.path.isfile(filename):
        with open(filename, mode="r", newline="") as f:
            existing_rows = list(csv.reader(f))

    # ``float`` keeps the fidelity column parseable by ``float(row[1])``
    # on the next save, whatever scalar type the caller passed in.
    entry = [gate, float(fidelity)] + [float(p) for p in pulse_params]

    rows = []
    replaced = False
    for row in existing_rows:
        if not row:
            # Blank CSV line: ``row[0]`` would raise IndexError.
            continue
        if row[0] != gate:
            rows.append(row)
            continue
        if replaced:
            # Stale duplicate entry for the same gate.
            continue
        try:
            stored_fidelity = float(row[1])
        except (IndexError, ValueError):
            # Malformed stored entry: overwrite it without comparison.
            stored_fidelity = None
        if stored_fidelity is not None and fidelity <= stored_fidelity:
            log.warning(
                f"Pulse parameters for {gate} already exist with "
                f"higher fidelity ({row[1]} >= {fidelity})"
            )
        rows.append(entry)
        replaced = True
    if not replaced:
        # Gate does not exist yet.
        rows.append(entry)

    tmp_filename = f"{filename}.tmp"
    try:
        with open(tmp_filename, mode="w", newline="") as f:
            csv.writer(f).writerows(rows)
        os.replace(tmp_filename, filename)
    except BaseException:
        # Never leave a half-written temporary file behind.
        if os.path.exists(tmp_filename):
            os.remove(tmp_filename)
        raise

907 

908 def _log_mask(self, n: int) -> jnp.ndarray: 

909 """Return a boolean mask of length ``n`` marking log-scaled indices.""" 

910 cached = self._log_mask_cache.get(n) 

911 if cached is not None and cached.shape[0] == n: 

912 return cached 

913 mask = np.zeros(n, dtype=bool) 

914 for idx in self.log_scale_params: 

915 i = idx if idx >= 0 else n + idx 

916 if 0 <= i < n: 

917 mask[i] = True 

918 out = jnp.asarray(mask) 

919 self._log_mask_cache[n] = out 

920 return out 

921 

922 def _to_log_space(self, params: jnp.ndarray) -> jnp.ndarray: 

923 """Convert selected parameters to log-space for optimisation. 

924 

925 Parameters at indices in ``self.log_scale_params`` are replaced 

926 by ``log(|p| + eps)`` so the optimiser operates on a 

927 logarithmic scale. All other parameters are left unchanged. 

928 """ 

929 if not self.log_scale_params: 

930 return params 

931 mask = self._log_mask(params.shape[0]) 

932 log_vals = jnp.log(jnp.abs(params) + 1e-12) 

933 return jnp.where(mask, log_vals, params) 

934 

935 def _from_log_space(self, log_params: jnp.ndarray) -> jnp.ndarray: 

936 """Convert selected parameters back from log-space. 

937 

938 Inverse of :meth:`_to_log_space`. Parameters at indices in 

939 ``self.log_scale_params`` are exponentiated; all others are 

940 passed through unchanged. 

941 """ 

942 if not self.log_scale_params: 

943 return log_params 

944 mask = self._log_mask(log_params.shape[0]) 

945 return jnp.where(mask, jnp.exp(log_params), log_params) 

946 

# Relative multipliers applied to each init parameter to form the
# default Stage-0 grid when no explicit ``scan_ranges`` are supplied.
# The neutral factor ``1.0`` is part of the set so the unmodified init
# point is itself a grid candidate (and is refined like every other
# candidate), rather than only serving as the baseline
# ``best_scan_loss``.
SCAN_REL_FACTORS: Tuple[float, ...] = (
    0.5,
    0.75,
    1.0,
    1.25,
    1.5,
)

953 

954 def _build_scan_grid( 

955 self, 

956 n_params: int, 

957 init_pulse_params: Optional[jnp.ndarray] = None, 

958 ) -> Tuple[jnp.ndarray, List[jnp.ndarray]]: 

959 """Build a coarse parameter grid for the initial scan phase. 

960 

961 If the user supplied ``scan_ranges`` they take precedence and 

962 a log-spaced grid is built within those bounds. Otherwise, when 

963 ``init_pulse_params`` is available, a **multiplicative grid 

964 centred on the init point** is used (each axis spans 

965 ``init * SCAN_REL_FACTORS``) so that already-optimised init 

966 params are always re-evaluated and only their immediate 

967 neighbourhood is explored. This avoids the failure mode where 

968 the global ``DEFAULT_PARAM_RANGES`` brackets exclude the actual 

969 optimum (the previous default range was ``(0.05, 3.0)`` per 

970 axis, which clipped DRAG amplitudes around 3.1 and made the 

971 scan systematically worse than the init point). 

972 

973 Args: 

974 n_params: Number of pulse parameters. 

975 init_pulse_params: Optional init params used to centre the 

976 multiplicative grid when ``scan_ranges`` is ``None``. 

977 

978 Returns: 

979 Tuple of: 

980 - Array of shape ``(n_candidates, n_params)`` with grid points. 

981 - List of 1-D arrays, one per parameter axis. 

982 """ 

983 if self.scan_ranges is not None: 

984 ranges = self.scan_ranges 

985 assert len(ranges) == n_params, ( 

986 f"scan_ranges has {len(ranges)} entries but gate has " 

987 f"{n_params} parameters." 

988 ) 

989 # Build log-spaced grids for each parameter 

990 axes = [] 

991 for lo, hi in ranges: 

992 axes.append( 

993 jnp.logspace(jnp.log10(lo), jnp.log10(hi), self.scan_grid_size) 

994 ) 

995 elif init_pulse_params is not None: 

996 # Multiplicative grid centred on init params. We pick 

997 # ``scan_grid_size`` factors symmetric around 1.0. When 

998 # ``scan_grid_size`` matches the static SCAN_REL_FACTORS 

999 # length we use those; otherwise build a symmetric linspace. 

1000 if self.scan_grid_size == len(self.SCAN_REL_FACTORS): 

1001 factors = jnp.array(self.SCAN_REL_FACTORS, dtype=jnp.float64) 

1002 else: 

1003 half = (self.scan_grid_size - 1) / 2.0 

1004 if half <= 0: 

1005 factors = jnp.array([1.0], dtype=jnp.float64) 

1006 else: 

1007 factors = jnp.linspace( 

1008 1.0 - 0.5, 

1009 1.0 + 0.5, 

1010 self.scan_grid_size, 

1011 dtype=jnp.float64, 

1012 ) 

1013 axes = [factors * float(p) for p in init_pulse_params] 

1014 else: 

1015 # Fall back to legacy log-spaced default ranges 

1016 ranges = self.DEFAULT_PARAM_RANGES.get( 

1017 n_params, 

1018 [(0.1, 10.0)] * n_params, 

1019 ) 

1020 axes = [] 

1021 for lo, hi in ranges: 

1022 axes.append( 

1023 jnp.logspace(jnp.log10(lo), jnp.log10(hi), self.scan_grid_size) 

1024 ) 

1025 

1026 # Cartesian product of all axes 

1027 grid = jnp.array(list(itertools.product(*axes))) 

1028 return grid, axes 

1029 

def stage_0_opt(
    self, init_pulse_params: jnp.ndarray, total_cost: Callable
) -> Tuple[jnp.ndarray, Optional[Tuple[List[jnp.ndarray], list]]]:
    """Run the coarse grid-scan phase (Stage 0).

    Evaluates a Cartesian grid of parameter candidates (from
    :meth:`_build_scan_grid`) using the **full weighted cost** (fidelity
    + phase, plus any other registered terms) — the same objective as
    Stage 1. Each candidate is refined with ``scan_steps`` fast,
    gradient-clipped Adam steps in (partially) log-scaled coordinates.
    For every candidate the better of the raw and the refined point is
    kept, and the overall best is returned. The init params act as the
    baseline: they are returned unless some candidate achieves a strictly
    lower loss.

    Sharing the objective with Stage 1 prevents the grid scan from
    landing in a basin that has high fidelity but a biased phase which
    Adam then has to migrate out of (the previous fidelity-only scan
    caused exactly this failure mode for RX/RY, whose phase residuals
    compounded in the CRX decomposition).

    Robustness: candidates that produce a non-finite loss (e.g. when the
    underlying pulse drives the integrator into a NaN — typical for very
    narrow DRAG envelopes) are skipped and counted in a warning. For the
    duration of the scan loop, :class:`qml_essentials.yaqsi.Yaqsi` is
    switched into ``throw=False`` mode so a single bad candidate cannot
    abort the loop with ``MaxStepsReached``; the previous defaults are
    restored on exit.

    Args:
        init_pulse_params: Initial pulse parameters to compare against.
        total_cost: Combined cost callable (same as Stage 1).

    Returns:
        Tuple of:
        - Best pulse parameters found during the scan.
        - ``(grid_axes, landscape_data)`` if the grid scan ran, else
          ``None``. ``landscape_data`` is a list of
          ``(candidate_index, original_params, loss)`` tuples for
          every successful scan candidate, where ``loss`` is the better
          of the raw and refined losses for that candidate.
    """

    def total_cost_log(log_params, *args):
        # Cost on the optimisation variables: map the log-scaled entries
        # back to physical values before evaluating.
        return total_cost(self._from_log_space(log_params), *args)

    # Baseline: the init point is returned unless a candidate beats it.
    best_scan_params = init_pulse_params
    # NOTE(review): this baseline evaluation runs *before* the solver is
    # switched to ``throw=False`` below, so a solver failure here may
    # raise rather than map to +inf via ``_safe_eval`` — depends on
    # Yaqsi's defaults; confirm.
    best_scan_loss = _safe_eval(total_cost, init_pulse_params)
    if not jnp.isfinite(best_scan_loss):
        log.warning(
            "Stage 0: initial pulse parameters produced a non-finite "
            "loss; falling back to a placeholder loss of +inf."
        )

    landscape_data: list = []
    axes_out: Optional[List[jnp.ndarray]] = None

    if self.scan_steps > 0:
        log.info(
            f"Stage 0: Grid scan with {self.scan_grid_size}^"
            f"{len(init_pulse_params)} candidates, "
            f"{self.scan_steps} steps each"
        )

        grid, axes_out = self._build_scan_grid(
            len(init_pulse_params),
            init_pulse_params=init_pulse_params,
        )
        log.info(f" Total candidates: {len(grid)}")

        # Use a fast Adam for the scan phase. The aggressive 5×
        # multiplier originally used here tended to push refined
        # candidates *out* of good basins; 2× keeps the refinement
        # localised. Always-evaluate-the-raw-candidate below
        # additionally guards against this.
        # Clipping is always on here: a non-positive ``grad_clip`` falls
        # back to a global-norm clip of 1.0 for the scan phase only.
        scan_optimizer = optax.chain(
            optax.clip_by_global_norm(
                self.grad_clip if self.grad_clip > 0 else 1.0
            ),
            optax.adam(self.learning_rate * 2),
        )

        @jax.jit
        def refine_candidate(log_candidate):
            """Run ``self.scan_steps`` Adam steps on a single candidate.

            Fused into a single ``jax.lax.scan`` so the whole
            refinement is one XLA program — no per-step host
            syncs, no Python-loop dispatch. Returns the final
            log-params and a scalar bool ``failed`` flag (set if
            any intermediate update produced a non-finite value).
            """

            opt_state0 = scan_optimizer.init(log_candidate)

            def body(carry, _):
                log_p, opt_state, failed = carry
                loss, grads = jax.value_and_grad(total_cost_log)(log_p)
                updates, opt_state = scan_optimizer.update(grads, opt_state, log_p)
                new_log_p = optax.apply_updates(log_p, updates)
                new_failed = failed | (~jnp.all(jnp.isfinite(new_log_p)))
                # Freeze on failure so subsequent steps cannot
                # propagate NaNs further. Only the parameters are
                # frozen; the optimiser state keeps being updated.
                new_log_p = jnp.where(new_failed, log_p, new_log_p)
                return (new_log_p, opt_state, new_failed), loss

            (final_log_p, _, failed), _ = jax.lax.scan(
                body,
                (log_candidate, opt_state0, jnp.bool_(False)),
                None,
                length=self.scan_steps,
            )
            return final_log_p, failed

        # Switch the underlying ODE solver to non-throwing mode for
        # the duration of the scan so candidates that exceed the step
        # budget produce NaN unitaries (and therefore +inf losses)
        # rather than aborting the whole grid loop.
        prev_solver_defaults = ys.Yaqsi.set_solver_defaults(throw=False)
        n_skipped = 0
        n_raw_better = 0
        try:
            for ci, candidate in enumerate(grid):
                log_candidate = self._to_log_space(candidate)

                # Evaluate the raw (unrefined) candidate so an
                # over-aggressive refinement step cannot discard
                # an already-good grid point.
                raw_loss = _safe_eval(total_cost, candidate)

                try:
                    log_p, failed_flag = refine_candidate(log_candidate)
                except Exception as exc:  # pragma: no cover - defensive
                    log.debug(
                        f" Candidate {ci + 1}/{len(grid)} "
                        f"raised during refinement: {exc}; skipping."
                    )
                    physical_p = candidate
                    loss = raw_loss
                else:
                    if bool(failed_flag):
                        # Refinement hit a non-finite update: fall back
                        # to the raw candidate.
                        physical_p = candidate
                        loss = raw_loss
                    else:
                        physical_p = self._from_log_space(log_p)
                        if not jnp.all(jnp.isfinite(physical_p)):
                            physical_p = candidate
                            loss = raw_loss
                        else:
                            loss = _safe_eval(total_cost, physical_p)

                # Keep the better of (raw, refined) for this candidate.
                if jnp.isfinite(raw_loss) and (
                    not jnp.isfinite(loss) or raw_loss < loss
                ):
                    physical_p = candidate
                    loss = raw_loss
                    n_raw_better += 1

                if not jnp.isfinite(loss):
                    n_skipped += 1
                    continue

                # The landscape records the *original* grid point
                # together with the better of its raw/refined losses.
                landscape_data.append((ci, candidate, float(loss)))

                if loss < best_scan_loss:
                    best_scan_loss = loss
                    best_scan_params = physical_p
                    log.info(
                        f" Candidate {ci + 1}/{len(grid)}: "
                        f"loss={float(loss):.6e} improved with "
                        f"params={physical_p}"
                    )
        finally:
            # Always restore the previous solver defaults so other
            # callers (including Stage 1) are unaffected.
            if prev_solver_defaults:
                ys.Yaqsi.set_solver_defaults(**prev_solver_defaults)

        if n_skipped:
            log.warning(
                f"Stage 0: skipped {n_skipped}/{len(grid)} candidates "
                f"due to solver failure or non-finite loss "
                f"(typical for very narrow / very large-amplitude "
                f"DRAG pulses)."
            )
        if n_raw_better:
            log.info(
                f"Stage 0: {n_raw_better}/{len(grid)} candidates "
                f"were better unrefined than after the {self.scan_steps}-"
                f"step refinement; raw values were kept."
            )

    log.info(
        f"Stage 0 complete. Best loss: "
        f"{float(best_scan_loss):.6e}, "
        f"params: {best_scan_params}"
    )

    scan_data = (axes_out, landscape_data) if self.scan_steps > 0 else None
    return best_scan_params, scan_data

1226 

def stage_1_opt(
    self, best_scan_params: jnp.ndarray, total_costs: Callable
) -> Tuple[jnp.ndarray, list, jnp.ndarray]:
    """Run multi-restart gradient optimisation (Stage 1).

    Builds the Stage-1 optimiser — AdamW via ``_build_optimizer``, with a
    warmup + cosine-decay learning-rate schedule whenever there is a
    warmup phase or ``end_lr_ratio < 1``, and a constant learning rate
    otherwise — and dispatches to the runner matching ``n_restarts``:

    - ``n_restarts <= 1``: :meth:`_stage_1_sequential` (single fused
      ``lax.scan`` with early-stopping support).
    - ``n_restarts > 1``: :meth:`_stage_1_parallel` (restarts vmapped,
      steps fused with ``lax.scan``; restart 0 starts unperturbed).

    Parameters listed in ``log_scale_params`` are optimised in log-space;
    the cost itself is always evaluated on physical parameters.

    Args:
        best_scan_params: Starting parameters (typically from Stage 0).
        total_costs: Combined cost callable.

    Returns:
        Tuple of ``(best_params, loss_history, best_loss)`` from the
        best restart.
    """
    peak_lr = self.learning_rate
    n_warmup = int(self.n_steps * self.warmup_ratio)
    floor_lr = peak_lr * self.end_lr_ratio

    use_schedule = n_warmup > 0 or self.end_lr_ratio < 1.0
    if not use_schedule:
        schedule = peak_lr
    else:
        # Warmup (if any) ramps up from the final LR to the peak, then the
        # cosine decay runs over the full ``n_steps`` horizon.
        schedule = optax.warmup_cosine_decay_schedule(
            init_value=floor_lr if n_warmup > 0 else peak_lr,
            peak_value=peak_lr,
            warmup_steps=n_warmup,
            decay_steps=self.n_steps,
            end_value=floor_lr,
        )

    def total_costs_log(log_params):
        # Optimisation variables are (partially) log-scaled; evaluate the
        # cost on the corresponding physical parameters.
        return total_costs(self._from_log_space(log_params))

    optimizer = _build_optimizer(schedule, self.grad_clip)
    run = (
        self._stage_1_sequential
        if self.n_restarts <= 1
        else self._stage_1_parallel
    )
    return run(best_scan_params, total_costs, total_costs_log, optimizer)

1282 

1283 def _perturb_starts(self, start_params: jnp.ndarray) -> jnp.ndarray: 

1284 """Pre-build the ``(n_restarts, n_params)`` matrix of restart starts. 

1285 

1286 Restart 0 is the unperturbed start; subsequent restarts add 

1287 Gaussian noise scaled by ``max(|start|, 0.1) * 

1288 restart_noise_scale``. Indices that are optimised in 

1289 log-space (plus the evolution time at index ``-1``) are kept 

1290 positive via ``jnp.abs`` so the subsequent ``log`` is safe. 

1291 """ 

1292 n_params = start_params.shape[0] 

1293 keys = jax.random.split(self.random_key, self.n_restarts) 

1294 # Shape (n_restarts, n_params); restart 0 is intentionally zero 

1295 # noise so the unperturbed start is preserved. 

1296 noise = jax.vmap(lambda k: jax.random.normal(k, shape=(n_params,)))(keys) 

1297 noise = noise.at[0].set(0.0) 

1298 scale = jnp.maximum(jnp.abs(start_params), 0.1) * self.restart_noise_scale 

1299 starts = start_params[None, :] + noise * scale[None, :] 

1300 

1301 # Keep the evolution time and any log-scaled indices positive. 

1302 positive_mask = np.zeros(n_params, dtype=bool) 

1303 positive_mask[-1] = True 

1304 for idx in self.log_scale_params: 

1305 i = idx if idx >= 0 else n_params + idx 

1306 if 0 <= i < n_params: 

1307 positive_mask[i] = True 

1308 positive_mask_j = jnp.asarray(positive_mask) 

1309 starts = jnp.where(positive_mask_j[None, :], jnp.abs(starts), starts) 

1310 return starts 

1311 

def _stage_1_sequential(
    self,
    start_params: jnp.ndarray,
    total_costs: Callable,
    total_costs_log: Callable,
    optimizer,
) -> Tuple[jnp.ndarray, list, jnp.ndarray]:
    """Single-restart Stage 1, fused into a single ``jax.lax.scan``.

    The whole optimisation loop (``n_steps`` optimiser updates) compiles
    to one XLA program, so there is no per-step Python dispatch and no
    per-step host/device sync. Early stopping is emulated with *masked
    updates*: once the patience condition trips, later steps leave the
    parameters and optimiser state unchanged. Compute is not skipped
    (``lax.scan`` has a fixed length), only the trajectory freezes.

    Two pieces of bookkeeping are tracked independently:

    - **Best result**: ``best_loss`` / ``best_log_params`` follow the
      strict minimum of the per-step loss, so the returned parameters are
      always the ones that produced the returned loss (the same rule the
      parallel ``n_restarts > 1`` path uses). Tying the best-result update
      to ``early_stop_min_delta`` would let the returned loss lag the
      true minimum by up to ``min_delta``.
    - **Patience**: the no-improvement counter resets only when the loss
      beats the patience reference by more than ``early_stop_min_delta``.

    A non-finite initial loss is replaced by ``+inf`` for these
    comparisons; ``NaN`` never compares smaller, so otherwise no step
    could ever improve on it.

    Args:
        start_params: Physical starting parameters.
        total_costs: Cost on physical parameters (used for the initial loss).
        total_costs_log: Cost on the (partially) log-scaled optimisation
            variables; this is the function that gets differentiated.
        optimizer: Optax gradient transformation used for the updates.

    Returns:
        ``(best_params, loss_history, best_loss)``. ``loss_history`` is
        ``[initial_loss, loss_step_0, ..., loss_step_{n_steps-1}]`` where
        each per-step loss is measured at that step's pre-update params;
        the per-step entries are host-side scalars.
    """

    params = start_params
    log_params = self._to_log_space(params)
    opt_state = optimizer.init(log_params)

    init_loss = total_costs(params)
    init_ref = jnp.where(jnp.isfinite(init_loss), init_loss, jnp.inf)
    min_delta = self.early_stop_min_delta
    patience = self.early_stop_patience
    # ``patience <= 0`` ⇒ early stopping disabled. Use a threshold the
    # counter can never reach so the masked-update path never triggers.
    eff_patience = patience if patience > 0 else self.n_steps + 1

    def scan_body(carry, _):
        (
            log_params,
            opt_state,
            best_loss,
            best_log_params,
            ref_loss,
            steps_since_improve,
            stopped_flag,
            stopped_step,
            step_idx,
        ) = carry

        # ``loss`` is evaluated at the pre-update ``log_params``.
        loss, grads = jax.value_and_grad(total_costs_log)(log_params)
        updates, new_opt_state = optimizer.update(grads, opt_state, log_params)
        stepped_log_params = optax.apply_updates(log_params, updates)

        # Best-so-far: strict minimum, paired with the params that
        # produced it. ``improved`` is a scalar bool and broadcasts
        # against the 1-D params arrays.
        improved = loss < best_loss
        best_loss = jnp.where(improved, loss, best_loss)
        best_log_params = jnp.where(improved, log_params, best_log_params)

        # Patience: only improvements larger than ``min_delta`` count.
        progressed = loss < ref_loss - min_delta
        ref_loss = jnp.where(progressed, loss, ref_loss)
        steps_since_improve = jnp.where(
            progressed, jnp.int32(0), steps_since_improve + jnp.int32(1)
        )

        # Latch the early-stop flag once it fires.
        trigger = steps_since_improve >= jnp.int32(eff_patience)
        new_stopped_flag = stopped_flag | trigger
        stopped_step = jnp.where(
            stopped_flag,
            stopped_step,
            jnp.where(trigger, step_idx + jnp.int32(1), stopped_step),
        )

        # Mask the update once stopped: freeze params/optimiser.
        new_log_params = jnp.where(new_stopped_flag, log_params, stepped_log_params)
        new_opt_state_kept = jax.tree_util.tree_map(
            lambda new, old: jnp.where(new_stopped_flag, old, new),
            new_opt_state,
            opt_state,
        )

        new_carry = (
            new_log_params,
            new_opt_state_kept,
            best_loss,
            best_log_params,
            ref_loss,
            steps_since_improve,
            new_stopped_flag,
            stopped_step,
            step_idx + jnp.int32(1),
        )
        return new_carry, loss

    init_carry = (
        log_params,  # log_params
        opt_state,  # opt_state
        init_ref,  # best_loss (strict minimum)
        log_params,  # best_log_params
        init_ref,  # ref_loss (patience reference)
        jnp.int32(0),  # steps_since_improve
        jnp.bool_(False),  # stopped_flag
        jnp.int32(self.n_steps),  # stopped_step (default = n_steps)
        jnp.int32(0),  # step_idx
    )

    @jax.jit
    def run_scan(carry):
        return jax.lax.scan(scan_body, carry, None, length=self.n_steps)

    final_carry, step_losses = run_scan(init_carry)
    (
        _,
        _,
        best_loss,
        best_log_params,
        _,
        _,
        stopped_flag,
        stopped_step,
        _,
    ) = final_carry

    # One sync: pull just what we need for logging in a single
    # device->host transfer instead of a per-step ``.item()`` call.
    host_step_losses, host_best_loss, host_stopped, host_stopped_step = (
        jax.device_get((step_losses, best_loss, stopped_flag, stopped_step))
    )

    # Periodic progress log (cheap because step losses already live on
    # the host).
    for step in range(0, self.n_steps, max(1, self.log_interval)):
        log.info(
            f"Step {step}/{self.n_steps}, Loss: {float(host_step_losses[step]):.3e}"
        )
    if bool(host_stopped):
        log.info(
            f"Early stop at step {int(host_stopped_step)}/{self.n_steps} "
            f"(no improvement > {min_delta:g} for "
            f"{self.early_stop_patience} steps)."
        )

    log.info(
        f"Restart 1/1 finished with best loss: {float(host_best_loss):.3e}"
        + (
            f" (early stopped at step {int(host_stopped_step)})"
            if bool(host_stopped)
            else ""
        )
    )

    # Leading entry is the initial (pre-step-0) loss, followed by one
    # entry per scan step. Built from the host copy so no per-element
    # device slicing is needed; still a ``list`` for plotting code.
    loss_history = [init_loss] + list(host_step_losses)

    best_pulse_params = self._from_log_space(best_log_params)
    return best_pulse_params, loss_history, best_loss

1464 

def _stage_1_parallel(
    self,
    start_params: jnp.ndarray,
    total_costs: Callable,
    total_costs_log: Callable,
    optimizer,
) -> Tuple[jnp.ndarray, list, jnp.ndarray]:
    """Vmap+scan Stage 1: all restarts × all steps in one XLA program.

    Always runs the full ``n_steps``: an early-stop break would require
    either chunking the scan (extra Python overhead) or masking updates
    inside the scan (no compute saved), and because every restart would
    have to plateau before we could break, the win is small. Sequential
    mode (``n_restarts == 1``) does honour ``early_stop_patience``.

    Best-result tracking per restart stores the parameters that were
    actually evaluated for the improving loss, i.e. the pre-update params
    of that step. A restart whose initial loss is non-finite starts its
    best at ``+inf``, so it can still improve and is never selected as
    the winner on the strength of a ``NaN``.

    Args:
        start_params: Physical starting parameters (restart 0 uses them
            unperturbed; see :meth:`_perturb_starts`).
        total_costs: Cost on physical parameters (initial losses).
        total_costs_log: Cost on the (partially) log-scaled optimisation
            variables; this is what gets differentiated.
        optimizer: Optax gradient transformation, vmapped over restarts.

    Returns:
        ``(best_params, loss_history, best_loss)`` of the winning restart.
        ``loss_history`` is ``[initial_loss, loss_step_0, ...]`` for that
        restart; the per-step entries are host-side scalars.
    """

    # (n_restarts, n_params) starting points (restart 0 unperturbed).
    params_batch = self._perturb_starts(start_params)
    log.info(
        f"Stage 1 (parallel): vmapping {self.n_restarts} restarts × "
        f"{self.n_steps} steps in a single fused program."
    )
    if self.early_stop_patience > 0:
        log.info(
            "Note: early_stop_patience is ignored in the parallel "
            "(n_restarts > 1) path; the full n_steps will run."
        )

    log_params_batch = jax.vmap(self._to_log_space)(params_batch)
    opt_state_batch = jax.vmap(optimizer.init)(log_params_batch)

    # Initial losses (per-restart) so loss_history[0] matches the
    # per-restart sequential semantics.
    init_losses = jax.vmap(total_costs)(params_batch)
    # NaN never compares smaller, so a non-finite start would otherwise
    # never be improved on (and would win ``argmin`` below).
    init_best = jnp.where(jnp.isfinite(init_losses), init_losses, jnp.inf)

    def opt_step(log_params, opt_state):
        loss, grads = jax.value_and_grad(total_costs_log)(log_params)
        updates, opt_state = optimizer.update(grads, opt_state, log_params)
        log_params = optax.apply_updates(log_params, updates)
        return log_params, opt_state, loss

    v_opt_step = jax.vmap(opt_step, in_axes=(0, 0))

    def scan_body(carry, _):
        log_params, opt_state, best_loss, best_log_params = carry
        new_log_params, new_opt_state, loss = v_opt_step(log_params, opt_state)
        # ``loss`` was evaluated at this step's input ``log_params``, so
        # those (not the previous step's params) are what produced it.
        improved = loss < best_loss
        best_loss = jnp.where(improved, loss, best_loss)
        best_log_params = jnp.where(improved[:, None], log_params, best_log_params)
        new_carry = (new_log_params, new_opt_state, best_loss, best_log_params)
        return new_carry, loss

    init_carry = (
        log_params_batch,
        opt_state_batch,
        init_best,
        log_params_batch,
    )

    @jax.jit
    def run_scan(carry):
        return jax.lax.scan(scan_body, carry, None, length=self.n_steps)

    final_carry, step_losses = run_scan(init_carry)
    # step_losses shape (n_steps, n_restarts); each row is the
    # cross-restart loss vector at one optimisation step.
    _, _, best_losses, best_log_params_batch = final_carry

    # Pull the per-step loss matrix and per-restart bests to the host in
    # one transfer, then format without further device→host syncs.
    host_step_losses, host_best_losses = jax.device_get((step_losses, best_losses))
    for step in range(0, self.n_steps, max(1, self.log_interval)):
        row = host_step_losses[step]
        log.info(
            f"Step {step}/{self.n_steps}, "
            f"loss min/mean/max: {float(row.min()):.3e} / "
            f"{float(row.mean()):.3e} / {float(row.max()):.3e}"
        )

    for r in range(self.n_restarts):
        log.info(
            f"Restart {r + 1}/{self.n_restarts} finished "
            f"with best loss: {float(host_best_losses[r]):.3e}"
        )

    winner = int(np.argmin(host_best_losses))
    global_best_loss = best_losses[winner]
    global_best_params = self._from_log_space(best_log_params_batch[winner])

    # Per-step loss history for the winning restart so the downstream
    # API (and the loss-curve plot) keeps the same shape as before.
    winner_history = [init_losses[winner]]
    winner_history.extend(host_step_losses[:, winner])
    return global_best_params, winner_history, global_best_loss

1576 

def plot_loss_landscape(
    self,
    gate_name: str,
    grid_axes: List[jnp.ndarray],
    landscape_data: list,
) -> None:
    """Save a loss-landscape figure for the Phase-0 grid scan.

    The visualisation adapts to the number of pulse parameters:

    - **1 parameter**: scatter plot (param value vs. loss).
    - **2 parameters**: colour mesh with parameter 1 on x, parameter 0 on
      y and colour = loss; grid points without a finite loss are grey.
      Cells are centred on the actual axis values
      (``pcolormesh(..., shading="nearest")``), which keeps log-spaced or
      init-centred axes correctly placed, and candidate indices are
      decoded with the real axis lengths.
    - **≥ 3 parameters**: horizontal scatter sorted by ascending loss
      with the best candidate highlighted.

    Log-scaled axes are used only when every plotted value is finite and
    strictly positive; init-centred grids can contain zero or negative
    parameter values. Nothing is saved when ``file_dir`` is ``None``.

    The figure is saved to ``{file_dir}/{gate_name}_loss_landscape.png``.

    Args:
        gate_name: Name of the gate being optimised (e.g. ``"RX"``).
        grid_axes: Per-parameter 1-D arrays that span the scan grid.
        landscape_data: List of ``(candidate_index, params, loss)``
            tuples for every successful scan candidate.
    """
    import matplotlib.pyplot as plt  # lazy — matplotlib is dev-only

    if not landscape_data:
        log.warning("plot_loss_landscape: no landscape data to plot, skipping.")
        return
    if self.file_dir is None:
        log.warning("plot_loss_landscape: file_dir is None, skipping.")
        return

    os.makedirs(self.file_dir, exist_ok=True)
    n_params = len(grid_axes)
    indices, _params_list, losses = zip(*landscape_data)
    losses_arr = np.array(losses, dtype=float)

    def strictly_positive(values) -> bool:
        # Log axes are only meaningful for finite, strictly positive data.
        arr = np.asarray(values, dtype=float)
        return arr.size > 0 and bool(np.all(np.isfinite(arr) & (arr > 0)))

    fig, ax = plt.subplots(figsize=(8, 5))

    if n_params == 1:
        x = np.array([float(grid_axes[0][i]) for i in indices])
        sc = ax.scatter(
            x, losses_arr, c=losses_arr, cmap="viridis_r", s=60, zorder=3
        )
        fig.colorbar(sc, ax=ax, label="Loss")
        best_i = int(np.argmin(losses_arr))
        ax.scatter(
            x[best_i],
            losses_arr[best_i],
            marker="*",
            s=200,
            color="red",
            zorder=4,
            label="best",
        )
        ax.set_xlabel("Parameter value")
        if strictly_positive(x):
            ax.set_xscale("log")
        if strictly_positive(losses_arr):
            ax.set_yscale("log")
        ax.legend()

    elif n_params == 2:
        y_vals = np.array([float(v) for v in grid_axes[0]])
        x_vals = np.array([float(v) for v in grid_axes[1]])
        loss_grid = np.full((len(y_vals), len(x_vals)), np.nan)
        for ci, _, loss in landscape_data:
            # Candidates are enumerated in Cartesian-product order:
            # parameter 0 varies slowest, parameter 1 fastest.
            row, col = divmod(int(ci), len(x_vals))
            loss_grid[row, col] = loss
        masked = np.ma.masked_invalid(loss_grid)
        cmap = plt.cm.viridis_r.copy()
        cmap.set_bad(color="lightgrey")
        im = ax.pcolormesh(x_vals, y_vals, masked, cmap=cmap, shading="nearest")
        fig.colorbar(im, ax=ax, label="Loss")
        ax.set_xlabel("Parameter 1")
        ax.set_ylabel("Parameter 0")

    else:  # n_params >= 3: sorted scatter
        order = np.argsort(losses_arr)
        sorted_losses = losses_arr[order]
        sorted_indices = np.array(indices)[order]  # original trial numbers
        ranks = np.arange(len(sorted_losses))
        sc = ax.scatter(
            sorted_losses,
            ranks,
            c=sorted_indices,
            cmap="plasma",
            s=40,
            zorder=3,
        )
        fig.colorbar(sc, ax=ax, label="Trial number")
        ax.scatter(
            sorted_losses[0],
            ranks[0],
            marker="*",
            s=200,
            color="red",
            zorder=4,
            label="best",
        )
        ax.set_xlabel("Loss")
        ax.set_ylabel("Candidate rank (0 = best)")
        if strictly_positive(sorted_losses):
            ax.set_xscale("log")
        ax.legend()

    ax.set_title(f"Loss Landscape (Phase 0) — {gate_name}")
    fig.tight_layout()
    path = os.path.join(self.file_dir, f"{gate_name}_loss_landscape.png")
    fig.savefig(path, dpi=150)
    plt.close(fig)
    log.info(f"Loss landscape saved to {path}")

1694 

def plot_loss_curve(
    self,
    gate_name: str,
    loss_history: list,
) -> None:
    """Save a training-loss curve figure for the Phase-1 optimisation.

    Shows loss vs. optimisation step with a dashed horizontal line at the
    minimum *finite* loss. Non-finite entries (e.g. a ``NaN`` initial
    loss) are excluded from the minimum; plain ``min`` would return
    ``NaN`` whenever the first entry is ``NaN``. The y-axis is
    log-scaled only when all finite losses are strictly positive. Nothing
    is saved when ``file_dir`` is ``None``.

    The figure is saved to ``{file_dir}/{gate_name}_loss_curve.png``.

    Args:
        gate_name: Name of the gate being optimised (e.g. ``"RX"``).
        loss_history: Sequence of loss values, one per step (including
            the initial loss at index 0).
    """
    import matplotlib.pyplot as plt  # lazy — matplotlib is dev-only

    if not loss_history:
        log.warning("plot_loss_curve: empty loss history, skipping.")
        return
    if self.file_dir is None:
        log.warning("plot_loss_curve: file_dir is None, skipping.")
        return

    os.makedirs(self.file_dir, exist_ok=True)
    losses = [float(v) for v in loss_history]
    finite = [v for v in losses if np.isfinite(v)]

    fig, ax = plt.subplots(figsize=(9, 4))
    ax.plot(losses, linewidth=1.2, label="Loss")
    if finite:
        best = min(finite)
        ax.axhline(
            best, color="red", linestyle="--", linewidth=1.0, label=f"Best: {best:.3e}"
        )
    ax.set_xlabel("Step")
    ax.set_ylabel("Loss")
    if finite and all(v > 0 for v in finite):
        ax.set_yscale("log")
    ax.set_title(f"Training Loss (Phase 1) — {gate_name}")
    ax.legend()
    fig.tight_layout()
    path = os.path.join(self.file_dir, f"{gate_name}_loss_curve.png")
    fig.savefig(path, dpi=150)
    plt.close(fig)
    log.info(f"Loss curve saved to {path}")

1737 

def optimize(self, wires: int) -> Callable:
    """Decorator factory that optimises pulse parameters for a gate.

    Usage::

        opt = qoc.optimize(wires=1)
        best_params, loss_history = opt(qoc.create_RX)()

    The gate name is derived from the decorated factory's ``__name__``
    (``create_<gate>`` -> ``<gate>``), so factories must follow that
    naming convention.

    Args:
        wires: Number of qubits the gate acts on.

    Returns:
        A decorator that accepts a circuit-factory function and
        returns a callable ``(init_pulse_params=None) ->
        (best_params, loss_history)``.
    """

    def decorator(create_circuits):
        def wrapper(init_pulse_params: Optional[jnp.ndarray] = None):
            """
            Optimise pulse parameters for a quantum gate using a
            multi-phase strategy:

            Stage 0 - Grid scan (if ``scan_steps > 0``):
                Evaluate a coarse grid of parameter candidates using
                the same weighted cost as Stage 1. Each candidate
                is refined with a few fast gradient steps. The
                best candidate becomes the starting point for
                Stage 1, unless the user-supplied init_pulse_params
                are already better.

            Stage 1 - Multi-restart gradient optimisation:
                Run ``n_restarts`` independent AdamW optimisation runs
                with the full cost function. The first restart uses
                the best point found so far; subsequent restarts add
                random perturbations. Parameters at indices in
                ``log_scale_params`` are optimised in log-space to
                handle order-of-magnitude differences in scale.

            Args:
                init_pulse_params (array): Initial pulse parameters.
                    If ``None``, uses the envelope defaults from
                    :class:`PulseInformation`.

            Returns:
                tuple: ``(best_params, loss_history)`` from the best
                restart.

            Raises:
                ValueError: If ``self.cost_fns`` is empty, so no cost
                    function can be built.
            """
            pulse_circuit, target_circuit = create_circuits()

            # Build a second pair that prepends a Hadamard on every
            # wire so the cost is also evaluated from the
            # ``|+⟩^⊗n`` initial state. Probing two non-collinear
            # initial states exposes rotation-axis tilt to the
            # optimiser: an RX/RY pulse with a residual Z component
            # is partly degenerate from ``|0⟩`` alone but produces
            # a clearly distinguishable trajectory from ``|+⟩``.
            # Both circuits get the same preparation so the target
            # remains exact.
            def _with_plus_prep(circuit_fn):
                def prepared(*args, **kwargs):
                    for q in range(wires):
                        op.H(wires=q)
                    circuit_fn(*args, **kwargs)

                prepared.__name__ = f"plus_{circuit_fn.__name__}"
                return prepared

            pulse_circuit_plus = _with_plus_prep(pulse_circuit)
            target_circuit_plus = _with_plus_prep(target_circuit)

            pulse_scripts = [
                ys.Script(pulse_circuit, n_qubits=wires),
                ys.Script(pulse_circuit_plus, n_qubits=wires),
            ]
            target_scripts = [
                ys.Script(target_circuit, n_qubits=wires),
                ys.Script(target_circuit_plus, n_qubits=wires),
            ]

            # One script per computational-basis input |k⟩ so the
            # unitary cost can compare full column-stacked unitaries.
            d_basis = 2**wires
            pulse_basis_scripts = [
                ys.Script(_with_basis_prep(pulse_circuit, k, wires), n_qubits=wires)
                for k in range(d_basis)
            ]
            target_basis_scripts = [
                ys.Script(
                    _with_basis_prep(target_circuit, k, wires), n_qubits=wires
                )
                for k in range(d_basis)
            ]

            gate_name = create_circuits.__name__.split("_")[1]

            if init_pulse_params is None:
                init_pulse_params = PulseInformation.gate_by_name(gate_name).params
            log.debug(
                f"Initial pulse parameters for {gate_name}: {init_pulse_params}"
            )

            all_ckwargs = {
                "pulse_scripts": pulse_scripts,
                "target_scripts": target_scripts,
                "pulse_basis_scripts": pulse_basis_scripts,
                "target_basis_scripts": target_basis_scripts,
                "envelope": self.envelope,
                "n_samples": self.n_samples,
                "n_qubits": wires,
                "t_target": self.t_target,
            }

            def _build_cost(name, weight):
                """Build a Cost from a registry entry, filtering ckwargs."""
                meta = CostFnRegistry.get(name)
                return Cost(
                    cost=meta["fn"],
                    weight=weight,
                    ckwargs={
                        k: v
                        for k, v in all_ckwargs.items()
                        if k in meta["ckwargs_keys"]
                    },
                )

            total_costs = None
            for name, weight in self.cost_fns:
                total_costs = _build_cost(name, weight) + total_costs
            if total_costs is None:
                # Without this guard ``None`` would be handed to the
                # optimiser stages and fail far from the real cause.
                raise ValueError(
                    f"optimize({gate_name}): no cost functions configured "
                    "(self.cost_fns is empty)."
                )

            best_scan_params, scan_data = self.stage_0_opt(
                init_pulse_params,
                total_costs,
            )

            global_best_params, global_best_history, global_best_loss = (
                self.stage_1_opt(
                    best_scan_params,
                    total_costs,
                )
            )
            # ``float`` accepts both Python floats and scalar arrays,
            # matching how :meth:`optimize_joint` reads the same value.
            self.save_results(
                gate=gate_name,
                fidelity=1 - float(global_best_loss),
                pulse_params=global_best_params,
            )

            if self.plot:
                if scan_data is not None:
                    grid_axes, landscape_items = scan_data
                    self.plot_loss_landscape(gate_name, grid_axes, landscape_items)
                self.plot_loss_curve(gate_name, global_best_history)

            return global_best_params, global_best_history

        return wrapper

    return decorator

1894 

# ------------------------------------------------------------------
# Per-gate (pulse, target) circuit factories
# ------------------------------------------------------------------
#
# Every entry pairs a gate name with a ``(pulse_circuit, target_circuit)``
# tuple. The per-gate variants add a symmetry-breaking preparation
# (for example ``op.H`` or ``op.RY``) so that the *state-vector* cost
# can detect a tilted rotation axis. The joint-mode variants leave
# those preparations out, because the unitary cost already sees axis
# tilt without a probe state (see :meth:`_create_joint_pair_for`).

@staticmethod
def _gate_factories() -> Dict[str, Tuple[Callable, Callable]]:
    """Return the per-gate ``{gate_name: (pulse_fn, target_fn)}`` table.

    A fresh table is built on every call; nothing is cached. Each value
    is whatever ``_make_gate_pair`` produces from a pulse lambda, a
    target lambda and optional ``prep``/``post`` stages.
    """
    table: Dict[str, Tuple[Callable, Callable]] = {}

    # --- single-qubit gates -------------------------------------------
    table["RX"] = _make_gate_pair(
        lambda w, pp: Gates.RX(w, 0, pulse_params=pp, gate_mode="pulse"),
        lambda w: op.RX(w, wires=0),
    )
    table["RY"] = _make_gate_pair(
        lambda w, pp: Gates.RY(w, 0, pulse_params=pp, gate_mode="pulse"),
        lambda w: op.RY(w, wires=0),
    )
    # RZ gets Hadamard prep/post stages around the rotation.
    table["RZ"] = _make_gate_pair(
        lambda w, pp: Gates.RZ(w, 0, pulse_params=pp, gate_mode="pulse"),
        lambda w: op.RZ(w, wires=0),
        prep=lambda w: op.H(wires=0),
        post=lambda w: op.H(wires=0),
    )
    table["H"] = _make_gate_pair(
        lambda w, pp: Gates.H(0, pulse_params=pp, gate_mode="pulse"),
        lambda w: op.H(wires=0),
        prep=lambda w: op.RY(w, wires=0),
    )
    table["Rot"] = _make_gate_pair(
        lambda w, pp: Gates.Rot(
            w, w * 2, w * 3, 0, pulse_params=pp, gate_mode="pulse"
        ),
        lambda w: op.Rot(w, w * 2, w * 3, wires=0),
        prep=lambda w: op.H(wires=0),
    )

    # --- two-qubit gates ------------------------------------------------
    table["CX"] = _make_gate_pair(
        lambda w, pp: Gates.CX(wires=[0, 1], pulse_params=pp, gate_mode="pulse"),
        lambda w: op.CX(wires=[0, 1]),
        prep=_chain_gate_stages(
            lambda w: op.RY(w, wires=0),
            lambda w: op.H(wires=1),
        ),
    )
    table["CY"] = _make_gate_pair(
        lambda w, pp: Gates.CY(wires=[0, 1], pulse_params=pp, gate_mode="pulse"),
        lambda w: op.CY(wires=[0, 1]),
        prep=_chain_gate_stages(
            lambda w: op.RX(w, wires=0),
            lambda w: op.H(wires=1),
        ),
    )
    table["CZ"] = _make_gate_pair(
        lambda w, pp: Gates.CZ(wires=[0, 1], pulse_params=pp, gate_mode="pulse"),
        lambda w: op.CZ(wires=[0, 1]),
        prep=_chain_gate_stages(
            lambda w: op.RY(w, wires=0),
            lambda w: op.H(wires=1),
        ),
    )
    table["CRX"] = _make_gate_pair(
        lambda w, pp: Gates.CRX(w, wires=[0, 1], pulse_params=pp, gate_mode="pulse"),
        lambda w: op.CRX(w, wires=[0, 1]),
        prep=lambda w: op.H(wires=0),
    )
    table["CRY"] = _make_gate_pair(
        lambda w, pp: Gates.CRY(w, wires=[0, 1], pulse_params=pp, gate_mode="pulse"),
        lambda w: op.CRY(w, wires=[0, 1]),
        prep=lambda w: op.H(wires=0),
    )
    table["CRZ"] = _make_gate_pair(
        lambda w, pp: Gates.CRZ(w, wires=[0, 1], pulse_params=pp, gate_mode="pulse"),
        lambda w: op.CRZ(w, wires=[0, 1]),
        prep=_chain_gate_stages(
            lambda w: op.H(wires=0),
            lambda w: op.H(wires=1),
        ),
    )
    return table

1996 

@staticmethod
def _joint_gate_factories() -> Dict[str, Tuple[Callable, Callable]]:
    """Return prep-free ``(pulse, target)`` pairs for joint mode.

    Consumed by :meth:`_create_joint_pair_for`. The unitary cost used
    in joint mode already reveals rotation-axis tilt without a probe
    state, and keeping the per-gate preps can actively *mask* errors
    (for instance ``op.H(wires=1)`` leaves the CX target qubit in a CX
    eigenstate, making the column-stacked unitary blind to the pulse
    error). ``Rot`` and ``CY`` are deliberately missing because the
    joint optimiser does not target them.
    """
    table: Dict[str, Tuple[Callable, Callable]] = {}

    # --- single-qubit gates -------------------------------------------
    table["RX"] = _make_gate_pair(
        lambda w, pp: Gates.RX(w, wires=0, pulse_params=pp, gate_mode="pulse"),
        lambda w: op.RX(w, wires=0),
    )
    table["RY"] = _make_gate_pair(
        lambda w, pp: Gates.RY(w, wires=0, pulse_params=pp, gate_mode="pulse"),
        lambda w: op.RY(w, wires=0),
    )
    table["RZ"] = _make_gate_pair(
        lambda w, pp: Gates.RZ(w, wires=0, pulse_params=pp, gate_mode="pulse"),
        lambda w: op.RZ(w, wires=0),
    )
    table["H"] = _make_gate_pair(
        lambda w, pp: Gates.H(0, pulse_params=pp, gate_mode="pulse"),
        lambda w: op.H(wires=0),
    )

    # --- two-qubit gates ------------------------------------------------
    table["CZ"] = _make_gate_pair(
        lambda w, pp: Gates.CZ(wires=[0, 1], pulse_params=pp, gate_mode="pulse"),
        lambda w: op.CZ(wires=[0, 1]),
    )
    table["CX"] = _make_gate_pair(
        lambda w, pp: Gates.CX(wires=[0, 1], pulse_params=pp, gate_mode="pulse"),
        lambda w: op.CX(wires=[0, 1]),
    )
    table["CRX"] = _make_gate_pair(
        lambda w, pp: Gates.CRX(w, wires=[0, 1], pulse_params=pp, gate_mode="pulse"),
        lambda w: op.CRX(w, wires=[0, 1]),
    )
    table["CRY"] = _make_gate_pair(
        lambda w, pp: Gates.CRY(w, wires=[0, 1], pulse_params=pp, gate_mode="pulse"),
        lambda w: op.CRY(w, wires=[0, 1]),
    )
    table["CRZ"] = _make_gate_pair(
        lambda w, pp: Gates.CRZ(w, wires=[0, 1], pulse_params=pp, gate_mode="pulse"),
        lambda w: op.CRZ(w, wires=[0, 1]),
    )
    return table

2058 

2059 def _create_pair(self, gate_name: str) -> Tuple[Callable, Callable]: 

2060 """Look up the per-gate ``(pulse, target)`` pair from the table.""" 

2061 try: 

2062 return self._gate_factories()[gate_name] 

2063 except KeyError as exc: 

2064 raise ValueError(f"No factory for gate {gate_name!r}.") from exc 

2065 

# Compatibility shims over :meth:`_create_pair`: code and tests that
# call ``qoc.create_<gate>()`` keep working. :meth:`optimize` derives
# the gate name from the factory's ``__name__`` (``create_<gate>``),
# so these method names are load-bearing.
def create_RX(self):
    """Return the ``(pulse, target)`` circuit pair for ``RX``."""
    pair = self._create_pair("RX")
    return pair

def create_RY(self):
    """Return the ``(pulse, target)`` circuit pair for ``RY``."""
    pair = self._create_pair("RY")
    return pair

def create_RZ(self):
    """Return the ``(pulse, target)`` circuit pair for ``RZ``."""
    pair = self._create_pair("RZ")
    return pair

def create_H(self):
    """Return the ``(pulse, target)`` circuit pair for ``H``."""
    pair = self._create_pair("H")
    return pair

def create_Rot(self):
    """Return the ``(pulse, target)`` circuit pair for ``Rot``."""
    pair = self._create_pair("Rot")
    return pair

def create_CX(self):
    """Return the ``(pulse, target)`` circuit pair for ``CX``."""
    pair = self._create_pair("CX")
    return pair

def create_CY(self):
    """Return the ``(pulse, target)`` circuit pair for ``CY``."""
    pair = self._create_pair("CY")
    return pair

def create_CZ(self):
    """Return the ``(pulse, target)`` circuit pair for ``CZ``."""
    pair = self._create_pair("CZ")
    return pair

def create_CRX(self):
    """Return the ``(pulse, target)`` circuit pair for ``CRX``."""
    pair = self._create_pair("CRX")
    return pair

def create_CRY(self):
    """Return the ``(pulse, target)`` circuit pair for ``CRY``."""
    pair = self._create_pair("CRY")
    return pair

def create_CRZ(self):
    """Return the ``(pulse, target)`` circuit pair for ``CRZ``."""
    pair = self._create_pair("CRZ")
    return pair

2100 

def create_CPhase(self):
    """Create pulse and target circuits for the CPhase gate.

    Both circuits first apply a Hadamard to wire 0 and then to wire 1,
    and only then apply the (pulse or ideal) controlled-phase gate.
    Presumably this is so the phase is observable in the state-vector
    cost. A controlled phase acting on ``|00⟩`` would leave the state
    unchanged.

    Returns:
        ``(pulse_circuit, target_circuit)`` where
        ``pulse_circuit(w, pulse_params)`` uses the pulse-level
        ``Gates.CPhase`` and ``target_circuit(w)`` uses
        ``op.ControlledPhaseShift``.
    """

    def _hadamard_both_wires():
        # Shared preparation; order (wire 0, then wire 1) matches the gates.
        for wire in (0, 1):
            op.H(wires=wire)

    def pulse_circuit(w, pulse_params):
        _hadamard_both_wires()
        Gates.CPhase(w, wires=[0, 1], pulse_params=pulse_params, gate_mode="pulse")

    def target_circuit(w):
        _hadamard_both_wires()
        op.ControlledPhaseShift(w, wires=[0, 1])

    return pulse_circuit, target_circuit

2115 

def optimize_all(self, sel_gates: str, make_log: bool) -> None:
    """Optimise all selected gates and optionally write a log CSV.

    Args:
        sel_gates: Gate names separated by commas and/or whitespace
            (e.g. ``"RX,CRX"``), or ``"all"``. An iterable of gate names
            is also accepted. Names are matched exactly, so selecting
            ``"CRX"`` does not also select ``"RX"``.
        make_log: If ``True``, write per-gate loss histories to
            ``qml_essentials/qoc_logs.csv``. There is one column per
            gate, and shorter histories are padded with empty cells.
    """
    # Joint mode (Round 3) is now implemented in :meth:`optimize_joint`.
    # The `--joint` CLI flag selects it instead of this per-gate loop.

    # Tokenise the selection. A plain ``gate in sel_gates`` on a string
    # is a *substring* test ("RX" in "CRX" is True), which silently
    # optimised extra gates.
    if isinstance(sel_gates, str):
        selected = set(sel_gates.replace(",", " ").split())
    else:
        selected = {str(g).strip() for g in sel_gates}
    select_all = "all" in selected

    log_history: Dict[str, list] = {}

    for gate in self.GATES_1Q + self.GATES_2Q:
        if not (select_all or gate in selected):
            continue
        n_wires = 1 if gate in self.GATES_1Q else 2
        opt = self.optimize(wires=n_wires)
        gate_factory = getattr(self, f"create_{gate}")
        log.info(f"Optimizing {gate} gate...")
        optimized_pulse_params, loss_history = opt(gate_factory)()
        log.info(f"Optimized parameters for {gate}: {optimized_pulse_params}")
        # Plain floats: safe for ``min`` and written cleanly by ``csv``.
        losses = [float(loss) for loss in loss_history]
        if losses:
            best_fid = 1 - min(losses)
            log.info(f"Best achieved fidelity: {best_fid * 100:.5f}%")
        log_history.setdefault(gate, []).extend(losses)

    if make_log:
        # write log history to file; ``newline=""`` is required by csv,
        # and zip_longest keeps every row of longer histories instead of
        # truncating all columns to the shortest one.
        with open("qml_essentials/qoc_logs.csv", "w", newline="") as f:
            writer = csv.writer(f)
            writer.writerow(log_history.keys())
            writer.writerows(
                itertools.zip_longest(*log_history.values(), fillvalue="")
            )

2146 

# ------------------------------------------------------------------
# Joint composite-aware optimisation (Round 3)
# ------------------------------------------------------------------

# Leaf gates whose parameters are optimised together by default. The
# order fixes the layout of the joint parameter vector ``theta``; any
# leaf left out here stays frozen at its current PulseInformation value
# during joint optimisation.
JOINT_LEAVES_DEFAULT: Tuple[str, ...] = ("RX", "RY", "RZ", "CZ")

# Gates whose unitary cost is summed into the joint objective by
# default. Composite gates back-propagate into the shared leaves, and
# the leaf-gate terms keep standalone fidelity acceptable. CZ is left
# out: per the original design note it is a static diagonal-Hamiltonian
# evolution (``H_CZ = π·|11⟩⟨11|``, t=1) that leaf tuning cannot
# improve, so it would only add ballast to the averaged loss.
JOINT_TARGETS_DEFAULT: Tuple[str, ...] = (
    "RX", "RY", "RZ", "H", "CX", "CRX", "CRY", "CRZ",
)

# Default per-target weights (normalised inside
# :func:`joint_unitary_cost_fn`). Composites are up-weighted. They are
# what fails the tightened tests, and the standalone leaves start close
# to perfect, so equal weights would let the cheap leaves dominate the
# average and stall the optimiser. The CR* gates get the largest weight
# because they are the longest decompositions; the original note puts
# them at roughly 2 CX plus several single-qubit gates. Leaf-error
# compounding is therefore worst for them.
JOINT_WEIGHTS_DEFAULT: Dict[str, float] = dict(
    RX=0.3,
    RY=0.3,
    RZ=0.3,
    H=1.0,
    CX=2.0,
    CRX=3.0,
    CRY=3.0,
    CRZ=3.0,
)

# Leaf groups that should share one envelope parameter slice in the
# joint layout. RX and RY are described as identical up to a static
# carrier-phase offset (cos(ω_c t) vs cos(ω_c t + π/2)). Tying them here,
# rather than in :mod:`pulses`, keeps the per-gate decomposition tree
# intact while preventing joint optimisation from pulling their
# envelopes apart. According to the original note, leaving RY untied
# let the loss settle where RX was well tuned but RY was ~3x worse.
JOINT_TIED_GROUPS_DEFAULT: Tuple[Tuple[str, ...], ...] = (("RX", "RY"),)

2206 

def _build_joint_layout(
    self,
    leaf_names: Tuple[str, ...],
    tied_groups: Optional[Tuple[Tuple[str, ...], ...]] = None,
) -> Tuple[jnp.ndarray, Dict[str, slice], List[int]]:
    """Build the joint parameter layout.

    Args:
        leaf_names: Ordered names of the leaf gates that participate
            in the joint optimisation. Repeated names are ignored after
            their first occurrence.
        tied_groups: Optional tuple of leaf-name groups whose
            parameters are forced to share a single slice in
            ``theta``. Defaults to
            :pyattr:`JOINT_TIED_GROUPS_DEFAULT` (ties RX/RY). Only
            leaves that are present in ``leaf_names`` participate — a
            group becomes a no-op if fewer than two of its members are
            listed. Overlapping groups are merged.

    Returns:
        Tuple ``(init_theta, leaf_slices, log_scale_indices)``:
          * ``init_theta`` — concatenated init parameters from
            ``PulseInformation.<leaf>.params`` in ``leaf_names`` order.
            A tied group is represented by its member that appears
            *first in* ``leaf_names``. Its shared init is the
            element-wise mean of all members' current params, so neither
            side dominates.
          * ``leaf_slices`` — mapping leaf-name → ``slice`` into
            ``init_theta``. Tied leaves map to the *same* slice.
          * ``log_scale_indices`` — indices into ``init_theta`` that
            should be optimised in log-space (amplitude + evolution
            time per envelope leaf, mirroring the per-gate default
            ``[0, -1]`` rule).

    Raises:
        ValueError: If ``leaf_names`` is empty, or any name is unknown
            to :class:`PulseInformation` or is not a leaf gate.
    """
    if tied_groups is None:
        tied_groups = self.JOINT_TIED_GROUPS_DEFAULT

    # De-duplicate while keeping order: a repeated name would otherwise
    # allocate a second slice and overwrite the first mapping.
    leaf_names = tuple(dict.fromkeys(leaf_names))
    if not leaf_names:
        raise ValueError("_build_joint_layout: leaf_names must not be empty.")

    # Validate every leaf (tied members included) with a real exception;
    # ``assert`` would be stripped under ``python -O``.
    pulse_infos = {}
    for name in leaf_names:
        pp = PulseInformation.gate_by_name(name)
        if pp is None or not pp.is_leaf:
            raise ValueError(f"_build_joint_layout: {name!r} is not a leaf gate")
        pulse_infos[name] = pp

    # leaf_name -> tie pointer. The head of each group is the member
    # that comes first in ``leaf_names`` (not first in the group), so
    # the head's slice is always allocated before any member refers to
    # it. Previously a leaf order like ("RY", "RX") with group
    # ("RX", "RY") looked up a slice that did not exist yet.
    position = {n: i for i, n in enumerate(leaf_names)}
    rep_of: Dict[str, str] = {n: n for n in leaf_names}
    for group in tied_groups:
        present = sorted({n for n in group if n in position}, key=position.__getitem__)
        if len(present) < 2:
            continue
        head = present[0]
        for member in present[1:]:
            rep_of[member] = head
            log.info(
                f" Joint layout: tying leaf {member!r} to {head!r} "
                f"(shared slice in theta)."
            )

    def _root(n: str) -> str:
        # Each pointer targets a strictly earlier leaf, so this
        # terminates at the earliest member of the (merged) group.
        while rep_of[n] != n:
            n = rep_of[n]
        return n

    envelope_info = PulseEnvelope.get(self.envelope)
    n_env = envelope_info["n_envelope_params"]

    leaf_slices: Dict[str, slice] = {}
    init_chunks = []
    log_idx: List[int] = []
    offset = 0
    for name in leaf_names:
        rep = _root(name)
        if rep != name:
            # Tied member — point at the representative's slice.
            leaf_slices[name] = leaf_slices[rep]
            continue

        # Shared init for a tied group: element-wise mean over all
        # members, to avoid biasing toward the representative.
        tied_members = [m for m in leaf_names if _root(m) == name]
        if len(tied_members) > 1:
            stacked = jnp.stack(
                [
                    jnp.asarray(pulse_infos[m].params, dtype=jnp.float64)
                    for m in tied_members
                ]
            )
            chunk = jnp.mean(stacked, axis=0)
        else:
            chunk = jnp.asarray(pulse_infos[name].params, dtype=jnp.float64)
        n_p = chunk.shape[0]
        leaf_slices[name] = slice(offset, offset + n_p)
        init_chunks.append(chunk)
        # Log-scale rule per leaf: only leaves that come from the
        # *envelope* (RX, RY) get log-scaled amplitude+time. RZ
        # and CZ use the "general" registry with a single tuning
        # scalar — leave them in linear space.
        if name in ("RX", "RY") and n_env >= 2:
            log_idx.append(offset)  # amplitude
            log_idx.append(offset + n_p - 1)  # evolution time
        offset += n_p

    init_theta = jnp.concatenate(init_chunks)
    return init_theta, leaf_slices, log_idx

2309 

2310 @staticmethod 

2311 def _assemble_for_gate( 

2312 theta: jnp.ndarray, 

2313 pp_obj, 

2314 leaf_slices: Dict[str, slice], 

2315 ) -> jnp.ndarray: 

2316 """Assemble the per-gate flat ``pulse_params`` from ``theta``. 

2317 

2318 Walks the gate's decomposition tree (recursing through 

2319 composites) and concatenates the appropriate slice of ``theta`` 

2320 for each leaf occurrence. Mirrors :pyattr:`PulseParams.params` 

2321 getter logic but pulls leaf data from the joint vector 

2322 ``theta`` rather than the leaves' own ``_params``. 

2323 """ 

2324 if pp_obj.is_leaf: 

2325 sl = leaf_slices.get(pp_obj.name) 

2326 if sl is None: 

2327 # Leaf is frozen — use its current PulseInformation 

2328 # value directly. 

2329 return jnp.asarray(pp_obj.params, dtype=jnp.float64) 

2330 return theta[sl] 

2331 return jnp.concatenate( 

2332 [ 

2333 QOC._assemble_for_gate(theta, child, leaf_slices) 

2334 for child in pp_obj.childs 

2335 ] 

2336 ) 

2337 

def _joint_stage_0_coord_descent(
    self,
    init_theta: jnp.ndarray,
    leaf_slices: Dict[str, slice],
    total_cost: Callable,
) -> jnp.ndarray:
    """Coordinate-descent grid scan over leaf-axis blocks.

    For each leaf in ``leaf_slices`` (in order), sweep a centred
    multiplicative grid over that leaf's params (using the existing
    :meth:`_build_scan_grid` machinery) while holding the other
    leaves at the current best. Greedily accept any improvement.

    This avoids the combinatorial explosion of a Cartesian
    product over all leaf axes simultaneously: instead of
    ``Π_i scan_grid_size**k_i`` candidates, only ``Σ_i
    scan_grid_size**k_i`` are evaluated.

    The grid for a leaf is built over that leaf's ``n_p`` parameters,
    whereas ``self.log_scale_params`` holds indices into the *joint*
    vector. :meth:`optimize_joint` notes that grid building consults
    ``self.log_scale_params``. So while a leaf is scanned, that
    attribute is set to the leaf-local indices: the joint indices
    inside the leaf's slice, shifted by the slice start. The caller's
    value is restored and the log-mask cache is cleared on exit.

    Args:
        init_theta: Starting joint parameter vector.
        leaf_slices: Mapping leaf-name → slice into ``init_theta``.
        total_cost: Joint cost callable taking ``theta`` and
            returning a scalar loss.

    Returns:
        Best joint parameter vector found.
    """
    if self.scan_steps <= 0:
        log.info("Joint Stage 0: scan disabled (scan_steps=0); skipping.")
        return init_theta

    current = init_theta
    best_loss = _safe_eval(total_cost, current)
    log.info(
        f"Joint Stage 0: coordinate-descent over {len(leaf_slices)} leaves, "
        f"init_loss={float(best_loss):.6e}"
    )

    # Normalise the joint log-scale indices (negative indices count
    # from the end of theta) once, before any per-leaf rewriting.
    n_theta = int(init_theta.shape[0])
    prev_log_scale = self.log_scale_params
    joint_log_idx = (
        []
        if prev_log_scale is None
        else [int(i) + n_theta if int(i) < 0 else int(i) for i in prev_log_scale]
    )

    def _use_log_scale(indices):
        # Swap the active log-scale indices and drop any mask cached
        # for the previous indices.
        self.log_scale_params = indices
        cache = getattr(self, "_log_mask_cache", None)
        if cache is not None:
            cache.clear()

    prev_solver_defaults = ys.Yaqsi.set_solver_defaults(throw=False)
    try:
        seen_slices: set = set()
        for leaf_name, sl in leaf_slices.items():
            # Tied leaves share a slice — only scan the unique
            # (start, stop) range once to avoid wasted evaluations.
            key = (sl.start, sl.stop)
            if key in seen_slices:
                continue
            seen_slices.add(key)
            leaf_init = current[sl]
            n_p = int(leaf_init.shape[0])
            if n_p == 0:
                continue
            _use_log_scale(
                [i - sl.start for i in joint_log_idx if sl.start <= i < sl.stop]
            )
            grid, _ = self._build_scan_grid(n_p, init_pulse_params=leaf_init)
            n_better = 0
            for cand in grid:
                new_theta = current.at[sl].set(cand)
                loss = _safe_eval(total_cost, new_theta)
                if loss < best_loss:
                    best_loss = loss
                    current = new_theta
                    n_better += 1
            log.info(
                f" Joint scan after leaf {leaf_name} "
                f"({len(grid)} candidates, {n_better} improved): "
                f"best_loss={float(best_loss):.6e}"
            )
    finally:
        _use_log_scale(prev_log_scale)
        if prev_solver_defaults:
            ys.Yaqsi.set_solver_defaults(**prev_solver_defaults)

    return current

2409 

2410 def _create_joint_pair_for(self, gate_name: str): 

2411 """Return a prep-free ``(pulse, target)`` pair for joint mode. 

2412 

2413 Looks up :meth:`_joint_gate_factories` first; falls back to the 

2414 per-gate (preps included) variant via :meth:`_create_pair_for` 

2415 with a warning if the gate is not in the joint table. See the 

2416 joint-table docstring for why preps are dropped. 

2417 """ 

2418 table = self._joint_gate_factories() 

2419 if gate_name in table: 

2420 return table[gate_name] 

2421 log.warning( 

2422 f"_create_joint_pair_for: no prep-free factory for {gate_name!r}; " 

2423 f"falling back to create_{gate_name} (preps may hide errors)." 

2424 ) 

2425 return self._create_pair_for(gate_name) 

2426 

2427 def _create_pair_for(self, gate_name: str): 

2428 """Return ``(pulse_circuit, target_circuit)`` for a target gate. 

2429 

2430 Reuses :meth:`_create_pair` so the joint mode targets exactly 

2431 the same circuits as the per-gate mode. 

2432 """ 

2433 return self._create_pair(gate_name) 

2434 

def optimize_joint(
    self,
    target_gates: Optional[List[str]] = None,
    leaf_names: Optional[List[str]] = None,
    weights: Optional[Dict[str, float]] = None,
) -> Tuple[jnp.ndarray, Dict[str, slice], list]:
    """Joint composite-aware optimisation of leaf pulse parameters.

    Optimises a single shared parameter vector ``theta`` (containing
    the concatenated leaf params for ``leaf_names``) against a
    weighted sum of unitary-cost terms over ``target_gates``.
    Composite gates back-propagate into the shared leaves; leaf
    terms keep the standalone fidelity acceptable. CZ is omitted
    from the default targets because the ``PulseGates.CZ``
    implementation is a static diagonal-Hamiltonian evolution
    (``H_CZ = π·|11⟩⟨11|``, t=1) that is structurally exact and
    unaffected by any leaf re-tuning.

    Args:
        target_gates: Gates whose unitary cost contributes to the
            joint objective. Defaults to
            :pyattr:`JOINT_TARGETS_DEFAULT` (RX, RY, RZ, H, CX,
            CRX, CRY, CRZ). Names unknown to
            :class:`PulseInformation` are skipped with a warning.
        leaf_names: Leaf gates whose parameters are jointly
            optimised. Defaults to :pyattr:`JOINT_LEAVES_DEFAULT`
            (RX, RY, RZ, CZ).
        weights: Optional mapping ``gate_name → weight``. Merged
            on top of :pyattr:`JOINT_WEIGHTS_DEFAULT` (composites
            up-weighted; leaves down-weighted). All weights are
            normalised inside the cost.

    Returns:
        ``(best_theta, leaf_slices, loss_history)``. Per-leaf
        results are also passed to :meth:`save_results`, and the
        optimised leaf params are written back to
        :class:`PulseInformation` in-process.

    Raises:
        ValueError: If none of ``target_gates`` is known to
            :class:`PulseInformation`.
    """
    target_gates = (
        list(target_gates) if target_gates else list(self.JOINT_TARGETS_DEFAULT)
    )
    leaf_names = list(leaf_names) if leaf_names else list(self.JOINT_LEAVES_DEFAULT)

    # Merge user-provided weights on top of class defaults so callers
    # can override only the gates they care about.
    merged_weights: Dict[str, float] = dict(self.JOINT_WEIGHTS_DEFAULT)
    if weights:
        merged_weights.update({k: float(v) for k, v in weights.items()})
    weights = merged_weights

    log.info(f"Joint optimisation: leaves={leaf_names}, targets={target_gates}")

    init_theta, leaf_slices, joint_log_idx = self._build_joint_layout(
        tuple(leaf_names)
    )
    log.info(
        f" Joint theta size: {init_theta.shape[0]}; "
        f"log-scale indices: {joint_log_idx}"
    )

    # Build per-gate specs (assembler + basis-prep scripts).
    gate_specs: List[dict] = []
    for gname in target_gates:
        pp_obj = PulseInformation.gate_by_name(gname)
        if pp_obj is None:
            log.warning(f" Skipping unknown gate {gname!r}.")
            continue
        n_wires = 1 if gname in self.GATES_1Q else 2
        d_basis = 2**n_wires
        pulse_circuit, target_circuit = self._create_joint_pair_for(gname)

        pulse_basis_scripts = [
            ys.Script(_with_basis_prep(pulse_circuit, k, n_wires), n_qubits=n_wires)
            for k in range(d_basis)
        ]
        target_basis_scripts = [
            ys.Script(
                _with_basis_prep(target_circuit, k, n_wires), n_qubits=n_wires
            )
            for k in range(d_basis)
        ]

        # Bind pp_obj via a default argument so each spec keeps its own
        # gate (avoids the late-binding closure pitfall in this loop).
        def _make_assembler(pp_obj=pp_obj):
            def assemble(theta):
                return QOC._assemble_for_gate(theta, pp_obj, leaf_slices)

            return assemble

        gate_specs.append(
            {
                "name": gname,
                "n_qubits": n_wires,
                "weight": float(weights.get(gname, 1.0)),
                "assembler": _make_assembler(),
                "pulse_basis_scripts": pulse_basis_scripts,
                "target_basis_scripts": target_basis_scripts,
            }
        )
        log.info(
            f" Built spec for {gname}: n_qubits={n_wires}, "
            f"weight={gate_specs[-1]['weight']}"
        )

    if not gate_specs:
        # An empty spec list would produce a meaningless joint cost.
        raise ValueError(
            "optimize_joint: none of the target gates "
            f"{target_gates} is known to PulseInformation."
        )

    # Build the joint cost as a Cost wrapper (so weight-tuple
    # collapsing into a scalar is shared with the per-gate path).
    # Reuse the (process_loss, phase_loss) weighting configured for the
    # standalone "unitary" cost (first match wins) so the relative
    # importance of fidelity vs phase stays consistent; fall back to an
    # even split when no unitary cost is configured.
    weight_tuple = next(
        (w for n, w in self.cost_fns if n == "unitary"),
        (0.5, 0.5),
    )
    joint_cost = Cost(
        cost=joint_unitary_cost_fn,
        weight=weight_tuple,
        ckwargs={
            "gate_specs": gate_specs,
            "n_samples": self.n_samples,
        },
    )

    # Temporarily override log_scale_params to point at joint
    # vector indices (Stage 0 grid building + Stage 1 log-space
    # reparam both consult ``self.log_scale_params``). Invalidate
    # the mask cache on either side of the swap so the joint
    # vector picks up the joint indices and per-gate runs revert
    # cleanly afterwards.
    prev_log_scale = self.log_scale_params
    self.log_scale_params = joint_log_idx
    self._log_mask_cache.clear()
    try:
        best_scan_theta = self._joint_stage_0_coord_descent(
            init_theta, leaf_slices, joint_cost
        )

        global_best_theta, global_best_history, global_best_loss = self.stage_1_opt(
            best_scan_theta, joint_cost
        )
    finally:
        self.log_scale_params = prev_log_scale
        self._log_mask_cache.clear()

    log.info(f"Joint optimisation done. final loss={float(global_best_loss):.6e}")

    # Save per-leaf results (one row per leaf). The fidelity column
    # carries the *joint* fidelity; downstream readers can use it as a
    # coarse quality signal. Tied leaves share a slice and therefore
    # get identical params.
    joint_fid = float(1.0 - global_best_loss)
    for leaf_name, sl in leaf_slices.items():
        leaf_params = global_best_theta[sl]
        self.save_results(
            gate=leaf_name,
            fidelity=joint_fid,
            pulse_params=leaf_params,
        )

    # Update PulseInformation in-place so the new defaults are
    # active in this Python process (handy for diagnostic scripts
    # that import QOC and then evaluate the new gates).
    for leaf_name, sl in leaf_slices.items():
        pp = PulseInformation.gate_by_name(leaf_name)
        pp.params = global_best_theta[sl]

    return global_best_theta, leaf_slices, global_best_history

2606 

2607 

2608default_qoc_params = { 

2609 "envelope": "drag", 

2610 "cost_fns": [ 

2611 # Unitary-level cost (process infidelity + trace-phase term). 

2612 # Captures rotation-axis tilt and global-phase residual that 

2613 # the state-fidelity cost is blind to; required to keep two-CX 

2614 # composites (CRX/CRY/CRZ) within tightened phase tolerances. 

2615 ("unitary", (0.5, 0.5)), 

2616 # ("fidelity", (0.5, 0.5)), # legacy state-vector cost 

2617 # ("pulse_width", 0.000000015), 

2618 # ("evolution_time", 0.000000005), 

2619 ], 

2620 "t_target": 0.5, 

2621 "n_steps": 800, 

2622 "n_samples": 20, 

2623 "learning_rate": 0.0001, 

2624 "warmup_ratio": 0.05, 

2625 "end_lr_ratio": 0.01, 

2626 "log_interval": 50, 

2627 "file_dir": None, 

2628 "n_restarts": 5, 

2629 "restart_noise_scale": 0.01, 

2630 "grad_clip": 1.0, 

2631 "random_seed": 1000, 

2632 "scan_steps": 20, 

2633 "scan_grid_size": 4, 

2634 "scan_ranges": None, 

2635 "log_scale_params": None, 

2636 "early_stop_patience": 0, 

2637 "early_stop_min_delta": 0.0, 

2638} 

2639 

2640 

2641def profile_pulse_pipeline( 

2642 gate: str = "RX", 

2643 n_samples: int = 3, 

2644 rwa: Optional[bool] = None, 

2645 n_qubits: int = 1, 

2646) -> dict: 

2647 """Profile a single pulse gate's forward + ``value_and_grad`` pass. 

2648 

2649 Diagnostic helper for the JIT pipeline. Builds a minimal 

2650 :class:`Script` that applies the requested pulse gate, then 

2651 times JIT compilation and steady-state evaluation of: 

2652 

2653 1. one forward pass (``Script.execute(type="state", ...)``); 

2654 2. one ``jax.value_and_grad`` of the squared overlap with the 

2655 analytic ``operations.<gate>`` target. 

2656 

2657 Use this to measure the impact of the RWA toggle 

2658 (``rwa=True``) and of the scan/sync refactors documented in 

2659 the patch notes: 

2660 

2661 from qml_essentials.qoc import profile_pulse_pipeline 

2662 profile_pulse_pipeline("RX", rwa=False) 

2663 profile_pulse_pipeline("RX", rwa=True) 

2664 

2665 Args: 

2666 gate: Gate name to profile (default ``"RX"``). Must be a 

2667 single-qubit pulse-level gate (``RX`` / ``RY``). 

2668 n_samples: Number of timed evaluations after warm-up. 

2669 rwa: If not ``None``, temporarily switch the global RWA flag 

2670 for the duration of the profile. ``None`` keeps the 

2671 current setting. 

2672 n_qubits: Width of the script (kept at 1 for the single- 

2673 qubit pulse gates). 

2674 

2675 Returns: 

2676 Dict with keys ``compile_fwd``, ``mean_fwd``, ``compile_grad``, 

2677 ``mean_grad``, ``rwa``, ``loss``. 

2678 """ 

2679 import time 

2680 

2681 with PulseInformation.preserve_state(): 

2682 if rwa is not None: 

2683 PulseInformation.set_rwa(bool(rwa)) 

2684 from qml_essentials.pulses import PulseGates 

2685 

2686 gate_op = getattr(op, gate) 

2687 gate_pulse = getattr(PulseGates, gate) 

2688 

2689 def pulse_circuit(theta, pp): 

2690 gate_pulse(theta, wires=0, pulse_params=pp) 

2691 

2692 def target_circuit(theta): 

2693 gate_op(theta, wires=0) 

2694 

2695 pulse_script = ys.Script(pulse_circuit, n_qubits=n_qubits) 

2696 target_script = ys.Script(target_circuit, n_qubits=n_qubits) 

2697 

2698 theta = jnp.asarray(jnp.pi / 4) 

2699 pp = PulseInformation.gate_by_name(gate).params 

2700 target_state = target_script.execute(type="state", args=(theta,)) 

2701 target_state = jax.lax.stop_gradient(target_state) 

2702 

2703 @jax.jit 

2704 def fwd(theta, pp): 

2705 return pulse_script.execute(type="state", args=(theta, pp)) 

2706 

2707 @jax.jit 

2708 def loss_and_grad(pp): 

2709 def loss_fn(p): 

2710 state = pulse_script.execute(type="state", args=(theta, p)) 

2711 return 1.0 - jnp.abs(jnp.vdot(target_state, state)) ** 2 

2712 

2713 return jax.value_and_grad(loss_fn)(pp) 

2714 

2715 # Warm-up + compile timings. 

2716 t0 = time.perf_counter() 

2717 s = fwd(theta, pp) 

2718 jax.block_until_ready(s) 

2719 compile_fwd = time.perf_counter() - t0 

2720 

2721 t0 = time.perf_counter() 

2722 loss, grads = loss_and_grad(pp) 

2723 jax.block_until_ready(loss) 

2724 jax.block_until_ready(grads) 

2725 compile_grad = time.perf_counter() - t0 

2726 

2727 fwd_t, grad_t = [], [] 

2728 for _ in range(n_samples): 

2729 t0 = time.perf_counter() 

2730 s = fwd(theta, pp) 

2731 jax.block_until_ready(s) 

2732 fwd_t.append(time.perf_counter() - t0) 

2733 

2734 t0 = time.perf_counter() 

2735 loss, grads = loss_and_grad(pp) 

2736 jax.block_until_ready(loss) 

2737 jax.block_until_ready(grads) 

2738 grad_t.append(time.perf_counter() - t0) 

2739 

2740 result = { 

2741 "gate": gate, 

2742 "rwa": PulseInformation.get_rwa(), 

2743 "compile_fwd": compile_fwd, 

2744 "mean_fwd": float(np.mean(fwd_t)), 

2745 "compile_grad": compile_grad, 

2746 "mean_grad": float(np.mean(grad_t)), 

2747 "loss": float(loss), 

2748 } 

2749 log.info( 

2750 f"[profile] gate={gate} rwa={result['rwa']} " 

2751 f"compile fwd/grad: {compile_fwd * 1e3:.1f}/" 

2752 f"{compile_grad * 1e3:.1f} ms, " 

2753 f"mean fwd/grad: {result['mean_fwd'] * 1e3:.1f}/" 

2754 f"{result['mean_grad'] * 1e3:.1f} ms, " 

2755 f"loss={result['loss']:.4e}" 

2756 ) 

2757 return result 

2758 

2759 

2760if __name__ == "__main__": 

2761 # argparse the selected gate 

2762 parser = argparse.ArgumentParser( 

2763 description="Quantum Optimal Control — pulse-level gate synthesis." 

2764 ) 

2765 parser.add_argument( 

2766 "--gates", 

2767 type=str, 

2768 nargs="+", 

2769 default=["RX", "RY", "RZ", "CZ"], 

2770 choices=QOC.GATES_1Q + QOC.GATES_2Q + ["all"], 

2771 help="Gate(s) to optimize.", 

2772 ) 

2773 parser.add_argument( 

2774 "--log", 

2775 action="store_true", 

2776 default=False, 

2777 help="Log results to file (default: False).", 

2778 ) 

2779 parser.add_argument( 

2780 "--no-log", 

2781 action="store_false", 

2782 dest="log", 

2783 help="Disable logging results to file.", 

2784 ) 

2785 parser.add_argument( 

2786 "--envelope", 

2787 type=str, 

2788 default=default_qoc_params["envelope"], 

2789 choices=PulseEnvelope.available(), 

2790 help="Pulse envelope shape to use for optimization.", 

2791 ) 

2792 parser.add_argument( 

2793 "--costs", 

2794 type=str, 

2795 nargs="+", 

2796 default=default_qoc_params["cost_fns"], 

2797 help=( 

2798 "Cost functions and weights as 'name:w1,w2,...' strings. " 

2799 "If weights are omitted the registry defaults are used. " 

2800 f"Available: {CostFnRegistry.available()}. " 

2801 "Example: --costs fidelity:0.5,0.3 pulse_width:0.2" 

2802 ), 

2803 ) 

2804 parser.add_argument( 

2805 "--t_target", 

2806 type=float, 

2807 default=default_qoc_params["t_target"], 

2808 help=( 

2809 "Target evolution time for the 'evolution_time' cost function. " 

2810 "All gates will be softly encouraged towards this common time." 

2811 ), 

2812 ) 

2813 parser.add_argument( 

2814 "--n_steps", 

2815 type=int, 

2816 default=default_qoc_params["n_steps"], 

2817 help="Number of optimisation steps per gate.", 

2818 ) 

2819 parser.add_argument( 

2820 "--n_samples", 

2821 type=int, 

2822 default=default_qoc_params["n_samples"], 

2823 help="Number of parameter samples in [0, 2\\pi] for cost evaluation.", 

2824 ) 

2825 parser.add_argument( 

2826 "--learning_rate", 

2827 type=float, 

2828 default=default_qoc_params["learning_rate"], 

2829 help="Peak learning rate for the AdamW optimiser.", 

2830 ) 

2831 parser.add_argument( 

2832 "--warmup_ratio", 

2833 type=float, 

2834 default=default_qoc_params["warmup_ratio"], 

2835 help=( 

2836 "Fraction of n_steps used for linear LR warmup (0.0-1.0). " 

2837 "Set to 0 to start at the peak LR immediately." 

2838 ), 

2839 ) 

2840 parser.add_argument( 

2841 "--end_lr_ratio", 

2842 type=float, 

2843 default=default_qoc_params["end_lr_ratio"], 

2844 help=( 

2845 "Final LR as a fraction of --learning_rate after cosine decay. " 

2846 "Also used as the initial LR before warmup. " 

2847 "Set to 1.0 (with --warmup_ratio 0) for a constant LR." 

2848 ), 

2849 ) 

2850 parser.add_argument( 

2851 "--log_interval", 

2852 type=int, 

2853 default=default_qoc_params["log_interval"], 

2854 help="Log the current loss every N steps.", 

2855 ) 

2856 parser.add_argument( 

2857 "--file_dir", 

2858 type=str, 

2859 default=default_qoc_params["file_dir"], 

2860 help="Directory to save qoc_results_[envelope].csv. " 

2861 "Defaults to the package directory.", 

2862 ) 

2863 parser.add_argument( 

2864 "--n_restarts", 

2865 type=int, 

2866 default=default_qoc_params["n_restarts"], 

2867 help=( 

2868 "Number of random restarts for the optimisation. " 

2869 "The first run uses the initial parameters as-is; " 

2870 "subsequent runs add random perturbations. " 

2871 "The best result across all restarts is kept." 

2872 ), 

2873 ) 

2874 parser.add_argument( 

2875 "--restart_noise_scale", 

2876 type=float, 

2877 default=default_qoc_params["restart_noise_scale"], 

2878 help=( 

2879 "Standard deviation of Gaussian noise added to the initial " 

2880 "parameters for each restart, relative to parameter magnitude." 

2881 ), 

2882 ) 

2883 parser.add_argument( 

2884 "--grad_clip", 

2885 type=float, 

2886 default=default_qoc_params["grad_clip"], 

2887 help=( 

2888 "Maximum global gradient norm. Gradients are clipped to this " 

2889 "value before being passed to the optimiser. " 

2890 "Set to 0 to disable." 

2891 ), 

2892 ) 

2893 parser.add_argument( 

2894 "--random_seed", 

2895 type=int, 

2896 default=default_qoc_params["random_seed"], 

2897 help="Base random seed for restart perturbations.", 

2898 ) 

2899 parser.add_argument( 

2900 "--scan_steps", 

2901 type=int, 

2902 default=default_qoc_params["scan_steps"], 

2903 help=( 

2904 "Number of short gradient-descent steps per candidate in the " 

2905 "coarse grid scan (Stage 0). Set to 0 to disable the grid scan." 

2906 ), 

2907 ) 

2908 parser.add_argument( 

2909 "--scan_grid_size", 

2910 type=int, 

2911 default=default_qoc_params["scan_grid_size"], 

2912 help=( 

2913 "Number of points per parameter dimension in the coarse grid. " 

2914 "Total candidates = scan_grid_size^n_params." 

2915 ), 

2916 ) 

2917 parser.add_argument( 

2918 "--scan_ranges", 

2919 type=str, 

2920 nargs="*", 

2921 default=default_qoc_params["scan_ranges"], 

2922 help=( 

2923 "Per-parameter (lo,hi) ranges for the grid scan, given as " 

2924 "'lo,hi' strings. One pair per pulse parameter. " 

2925 "Example: --scan_ranges 0.5,30.0 0.05,2.0 0.05,2.0 " 

2926 "If omitted, heuristic defaults are used." 

2927 ), 

2928 ) 

2929 parser.add_argument( 

2930 "--plot", 

2931 action="store_true", 

2932 default=False, 

2933 help=( 

2934 "Save a loss-landscape plot (Phase 0) and a loss-curve plot " 

2935 "(Phase 1) as PNG files in --file_dir for each optimised gate." 

2936 ), 

2937 ) 

2938 parser.add_argument( 

2939 "--early_stop_patience", 

2940 type=int, 

2941 default=default_qoc_params["early_stop_patience"], 

2942 help=( 

2943 "Number of consecutive Stage-1 steps without improvement " 

2944 "(> --early_stop_min_delta) after which optimisation exits " 

2945 "early. 0 disables early stopping (default)." 

2946 ), 

2947 ) 

2948 parser.add_argument( 

2949 "--early_stop_min_delta", 

2950 type=float, 

2951 default=default_qoc_params["early_stop_min_delta"], 

2952 help=( 

2953 "Minimum loss decrease that counts as an improvement for " 

2954 "the --early_stop_patience counter (default 0.0)." 

2955 ), 

2956 ) 

2957 parser.add_argument( 

2958 "--joint", 

2959 action="store_true", 

2960 default=False, 

2961 help=( 

2962 "Use composite-aware joint optimisation: a single shared " 

2963 "leaf parameter vector is optimised against the unitary " 

2964 "cost summed over leaf and composite gates " 

2965 "(default targets: RX, RY, RZ, CZ, H, CX, CRX, CRY, CRZ). " 

2966 "Pulls leaves into a basin that works well in *every* " 

2967 "use-site instead of only standalone, fixing the " 

2968 "selfish-basin failure mode of per-gate optimisation. " 

2969 "Ignores --gates." 

2970 ), 

2971 ) 

2972 parser.add_argument( 

2973 "--joint_targets", 

2974 nargs="+", 

2975 type=str, 

2976 default=None, 

2977 help=( 

2978 "(Used only with --joint.) Override the list of target " 

2979 "gates whose unitary cost contributes to the joint " 

2980 "objective." 

2981 ), 

2982 ) 

2983 parser.add_argument( 

2984 "--joint_leaves", 

2985 nargs="+", 

2986 type=str, 

2987 default=None, 

2988 help=( 

2989 "(Used only with --joint.) Override the list of leaf " 

2990 "gates whose parameters are jointly optimised. " 

2991 "Default: RX RY RZ CZ." 

2992 ), 

2993 ) 

2994 parser.add_argument( 

2995 "--joint_weights", 

2996 nargs="+", 

2997 type=str, 

2998 default=None, 

2999 help=( 

3000 "(Used only with --joint.) Override per-target weights as " 

3001 "'gate:weight' strings (e.g. --joint_weights CRX:5 CX:3). " 

3002 "Merged on top of QOC.JOINT_WEIGHTS_DEFAULT, so unspecified " 

3003 "gates keep their default weight." 

3004 ), 

3005 ) 

3006 parser.add_argument( 

3007 "--rwa", 

3008 action="store_true", 

3009 default=False, 

3010 help=( 

3011 "Toggles RWA mode for pulse simulation." 

3012 "If this is set true, we utilize the rotating wave approximation " 

3013 "instead of the exact interaction picture." 

3014 "While this makes the calculations less exact, it provides" 

3015 "significant speedup." 

3016 ), 

3017 ) 

3018 parser.add_argument( 

3019 "--drive", 

3020 action="store_true", 

3021 default=False, 

3022 help=("Uses drive hamiltonian instead of lab frame."), 

3023 ) 

3024 

3025 args = parser.parse_args() 

3026 sel_gates = args.gates # already a list from nargs="+" 

3027 make_log = args.log 

3028 

3029 # Parse scan_ranges from CLI (list of "lo,hi" strings -> list of tuples) 

3030 scan_ranges = None 

3031 if args.scan_ranges is not None: 

3032 scan_ranges = [] 

3033 for pair in args.scan_ranges: 

3034 lo, hi = pair.split(",") 

3035 scan_ranges.append((float(lo), float(hi))) 

3036 

3037 PulseInformation.set_rwa(args.rwa) 

3038 PulseInformation.set_frame("drive" if args.drive else "lab") 

3039 

3040 # Parse cost function specs from CLI 

3041 cost_fns = [CostFnRegistry.parse_cost_arg(spec) for spec in args.costs] 

3042 

3043 # create logger 

3044 log = logging.getLogger("qml_essentials.qoc") 

3045 

3046 log.setLevel(logging.INFO) 

3047 log.addHandler(logging.StreamHandler()) 

3048 

3049 qoc = QOC( 

3050 envelope=args.envelope, 

3051 cost_fns=cost_fns, 

3052 t_target=args.t_target, 

3053 n_steps=args.n_steps, 

3054 n_samples=args.n_samples, 

3055 learning_rate=args.learning_rate, 

3056 warmup_ratio=args.warmup_ratio, 

3057 end_lr_ratio=args.end_lr_ratio, 

3058 log_interval=args.log_interval, 

3059 file_dir=args.file_dir, 

3060 n_restarts=args.n_restarts, 

3061 restart_noise_scale=args.restart_noise_scale, 

3062 grad_clip=args.grad_clip, 

3063 random_seed=args.random_seed, 

3064 scan_steps=args.scan_steps, 

3065 scan_grid_size=args.scan_grid_size, 

3066 scan_ranges=scan_ranges, 

3067 early_stop_patience=args.early_stop_patience, 

3068 early_stop_min_delta=args.early_stop_min_delta, 

3069 plot=args.plot, 

3070 ) 

3071 

3072 if args.joint: 

3073 joint_weights = None 

3074 if args.joint_weights: 

3075 joint_weights = {} 

3076 for spec in args.joint_weights: 

3077 gname, w = spec.split(":") 

3078 joint_weights[gname.strip()] = float(w) 

3079 qoc.optimize_joint( 

3080 target_gates=args.joint_targets, 

3081 leaf_names=args.joint_leaves, 

3082 weights=joint_weights, 

3083 ) 

3084 else: 

3085 qoc.optimize_all(sel_gates=sel_gates, make_log=make_log)