Coverage for qml_essentials / qoc.py: 44%
894 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-05-16 10:19 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-05-16 10:19 +0000
1import argparse
2import csv
3import itertools
4import logging
5import os
6from typing import Callable, Dict, List, Optional, Tuple, Union
8import jax
9from jax import numpy as jnp
10import numpy as np
11import optax
13from qml_essentials.gates import Gates, PulseInformation, PulseEnvelope
14from qml_essentials import operations as op
15from qml_essentials import yaqsi as ys
16from qml_essentials.math import phase_difference, fidelity
# Enable 64-bit floats globally: the cost functions in this module build
# ``jnp.float64`` arrays, which JAX only honours in x64 mode.
jax.config.update("jax_enable_x64", True)
# Module-level logger used for training/progress info and warnings.
log = logging.getLogger(__name__)
def _build_optimizer(schedule, grad_clip: float):
    """Return the AdamW optimiser shared by stage-0 and stage-1.

    A global-norm gradient clip is chained in front of AdamW only when
    ``grad_clip`` is a finite, strictly positive number. Any other value
    (``None``, ``0``, negatives, ``inf``, ``nan``) yields plain AdamW.

    Args:
        schedule: Learning rate or optax schedule passed to ``adamw``.
        grad_clip: Maximum global gradient norm, or a disabling value.

    Returns:
        An optax ``GradientTransformation``.
    """
    clip_disabled = (
        not grad_clip
        or not (grad_clip > 0)
        or not jnp.isfinite(grad_clip)
    )
    if clip_disabled:
        return optax.adamw(schedule)
    return optax.chain(
        optax.clip_by_global_norm(grad_clip),
        optax.adamw(schedule),
    )
37def _safe_eval(cost_fn: Callable, params: jnp.ndarray) -> jnp.ndarray:
38 """Evaluate ``cost_fn(params)``; map non-finite results to ``+inf``."""
39 loss = cost_fn(params)
40 return jnp.where(jnp.isfinite(loss), loss, jnp.inf)
43def _with_basis_prep(circuit_fn: Callable, k: int, n_wires: int) -> Callable:
44 """Wrap ``circuit_fn`` so it first prepares basis state ``|k⟩``.
46 The wrapped circuit applies ``PauliX`` on every wire whose bit in
47 ``k`` is set (MSB first) before delegating to ``circuit_fn``. Used
48 by both per-gate and joint optimisation paths to build the
49 column-stacked unitary required by :func:`unitary_cost_fn`.
50 """
51 bits = [(k >> (n_wires - 1 - i)) & 1 for i in range(n_wires)]
53 def prepared(*args, **kwargs):
54 for i, bit in enumerate(bits):
55 if bit:
56 op.PauliX(wires=i)
57 circuit_fn(*args, **kwargs)
59 prepared.__name__ = f"basis{k}_{circuit_fn.__name__}"
60 return prepared
63def _sample_rotation_angles(n_samples: int) -> jnp.ndarray:
64 """Boundary-biased sample of rotation angles in ``[0, 2π)``.
66 The pulse-vs-target residual scales roughly linearly with rotation
67 angle, so a uniform sample over ``[0, 2π)`` underweights the
68 high-residual band that dominates failing tests (typical large-w
69 test points: π/2, π). We stratify the samples into
71 * a uniform component covering the full ``[0, 2π)`` circle, and
72 * a focus component packed in ``[π/2, 3π/2]``
74 so the central band is sampled at roughly twice the density of the
75 tails. Returns at least one angle even for ``n_samples == 1``;
76 when ``n_samples == 1`` the legacy uniform behaviour is preserved
77 (single sample at ``w = 0``) to avoid surprising callers.
78 """
79 if n_samples <= 1:
80 return jnp.linspace(0.0, 2.0 * jnp.pi, max(n_samples, 1), endpoint=False)
81 # ~1/3 of samples in the central [π/2, 3π/2] band on top of a full
82 # uniform sweep. Sub-sample counts are rounded so both components
83 # are non-empty for any ``n_samples >= 2``.
84 k_focus = max(1, n_samples // 3)
85 k_uniform = n_samples - k_focus
86 ws_uniform = jnp.linspace(0.0, 2.0 * jnp.pi, k_uniform, endpoint=False)
87 ws_focus = jnp.linspace(0.5 * jnp.pi, 1.5 * jnp.pi, k_focus, endpoint=False)
88 return jnp.concatenate([ws_uniform, ws_focus])
91def _run_gate_stage(stage: Optional[Callable], w) -> None:
92 """Execute an optional gate-preparation stage."""
93 if stage is not None:
94 stage(w)
97def _chain_gate_stages(*stages: Callable) -> Callable:
98 """Build a stage that runs multiple preparation operations in sequence."""
100 def chained(w):
101 for stage in stages:
102 stage(w)
104 return chained
def _make_gate_pair(
    pulse_gate: Callable,
    target_gate: Callable,
    prep: Optional[Callable] = None,
    post: Optional[Callable] = None,
) -> Tuple[Callable, Callable]:
    """Create matching pulse-level and target circuits.

    Both circuits run the same optional ``prep`` stage, then their own
    gate, then the same optional ``post`` stage, all driven by the
    rotation angle ``w``.

    Args:
        pulse_gate: ``(w, pp) -> None`` pulse-level gate.
        target_gate: ``(w) -> None`` reference gate.
        prep: Optional stage executed before the gate.
        post: Optional stage executed after the gate.

    Returns:
        ``(pulse_circuit, target_circuit)``. The function names are kept
        stable because ``_with_basis_prep`` embeds ``__name__``.
    """

    def _sandwich(w, core: Callable) -> None:
        # Shared prep -> core -> post sequence used by both circuits.
        _run_gate_stage(prep, w)
        core()
        _run_gate_stage(post, w)

    def pulse_circuit(w, pp):
        _sandwich(w, lambda: pulse_gate(w, pp))

    def target_circuit(w):
        _sandwich(w, lambda: target_gate(w))

    return pulse_circuit, target_circuit
128class Cost:
129 """Weighted wrapper around a cost function.
131 Combines a cost callable with a scalar or tuple weight and optional
132 constant keyword arguments. Multiple ``Cost`` instances can be
133 composed via the ``+`` operator to build a combined objective.
135 Args:
136 cost: Callable ``(pulse_params, **ckwargs) -> scalar | tuple``.
137 weight: Scalar or tuple of per-component weights.
138 ckwargs: Constant keyword arguments injected into every call.
139 """
141 def __init__(
142 self,
143 cost: Callable,
144 weight: Union[float, Tuple],
145 ckwargs: Optional[dict] = None,
146 ):
147 self.cost = cost
148 self.weight = weight
149 self.ckwargs = ckwargs if ckwargs is not None else {}
151 def __call__(self, *args, **kwargs):
152 """Evaluate the cost function with injected kwargs and apply weights."""
153 cost = self.cost(*args, **kwargs, **self.ckwargs)
154 if isinstance(self.weight, tuple):
155 return jnp.array(
156 [c * w for c, w in zip(cost, self.weight, strict=True)]
157 ).sum()
158 return cost * self.weight
160 def __add__(self, other):
161 """Compose two cost terms into a single callable that sums them."""
162 if other is None:
163 return lambda *args, **kwargs: self(*args, **kwargs)
164 if callable(other):
165 return lambda *args, **kwargs: (
166 self(*args, **kwargs) + other(*args, **kwargs)
167 )
168 raise TypeError(f"Cannot add Cost and {type(other)}")
def fidelity_cost_fn(
    pulse_params: jnp.ndarray,
    pulse_scripts: Union[ys.Script, List[ys.Script]],
    target_scripts: Union[ys.Script, List[ys.Script]],
    n_samples: int,
) -> Tuple[float, float]:
    """
    State-level cost ``(1 - fidelity, 1 - cos(phase_difference))``.

    Both terms are averaged over ``n_samples`` rotation angles produced by
    ``_sample_rotation_angles``. That is a stratified sample of
    ``[0, 2π)`` with extra density in ``[π/2, 3π/2)``, not a uniform grid.
    The terms are also averaged across one or more (pulse, target) script
    pairs.

    Multiple script pairs let the optimiser probe sensitivity from
    multiple initial states (e.g. ``|0⟩`` and ``|+⟩``). This makes
    rotation-axis tilt observable to the cost. From ``|0⟩`` alone, an
    RX/RY pulse with a small Z-component is largely degenerate with the
    correct pulse. From ``|+⟩``, the same tilt produces a visible
    state-vector deviation.

    Execution is batched: each script evaluates all sampled angles in a
    single vectorised ``execute`` call (``in_axes=(0, None)`` for the
    pulse script, ``in_axes=(0,)`` for the target script).

    The phase term uses ``1 - cos(Δφ)`` rather than ``|Δφ|``. This keeps
    it differentiable everywhere, including at the optimum, and
    well-behaved at the ``±π`` wrap-around. That matters because Stage 0
    sees the same cost as Stage 1.

    Args:
        pulse_params: Pulse parameters for evaluation.
        pulse_scripts: One or a list of yaqsi scripts with pulse
            parameters. If a list is supplied, the cost is averaged
            element-wise with ``target_scripts`` (which must have the
            same length).
        target_scripts: One or a list of yaqsi target scripts.
        n_samples: Number of rotation-angle samples (see
            ``_sample_rotation_angles``).

    Returns:
        Tuple ``(infidelity, phase_loss)``, each averaged over sampled
        angles and then over script pairs.
    """
    if not isinstance(pulse_scripts, (list, tuple)):
        pulse_scripts = [pulse_scripts]
    if not isinstance(target_scripts, (list, tuple)):
        target_scripts = [target_scripts]
    assert len(pulse_scripts) == len(target_scripts), (
        f"pulse_scripts and target_scripts must have the same length "
        f"({len(pulse_scripts)} vs {len(target_scripts)})."
    )

    ws = _sample_rotation_angles(n_samples)
    # float64 constant hoisted out of the loop (x64 mode is enabled).
    one = jnp.array(1.0, dtype=jnp.float64)

    abs_diffs = []
    phase_diffs = []
    for p_script, t_script in zip(pulse_scripts, target_scripts):
        pulse_states = p_script.execute(
            type="state",
            args=(ws, pulse_params),
            in_axes=(0, None),
        )  # (n_samples, dim)

        target_states = t_script.execute(
            type="state",
            args=(ws,),
            in_axes=(0,),
        )  # (n_samples, dim)

        abs_diffs.append(jnp.mean(one - fidelity(pulse_states, target_states)))
        phase_diffs.append(
            jnp.mean(one - jnp.cos(phase_difference(pulse_states, target_states)))
        )

    abs_diff = jnp.mean(jnp.stack(abs_diffs))
    phase_diff = jnp.mean(jnp.stack(phase_diffs))

    # TODO: in future we could consider some sort of log based loss for the small values
    # or utilize gradient ascent if we run into numerical limitations

    return (abs_diff, phase_diff)
def unitary_cost_fn(
    pulse_params: jnp.ndarray,
    pulse_basis_scripts: List[ys.Script],
    target_basis_scripts: List[ys.Script],
    n_samples: int,
    n_qubits: int,
) -> Tuple[float, float]:
    """Process-level cost comparing the full pulse and target unitaries.

    For every sampled rotation angle, each unitary is assembled column by
    column. Column ``k`` is the state produced by the k-th basis script,
    which prepares ``|k⟩`` and then runs the circuit. With
    ``E = U_target† · U_pulse`` and ``d = 2**n_qubits`` the cost is

        (1 - |Tr(E)|² / d², 1 - cos(angle(Tr(E))))

    The first term is the process infidelity, which ignores global phase.
    The second term penalises the residual global phase, so the optimiser
    can tell ``U_pulse`` from ``e^{iα} U_pulse``. Otherwise composed
    gates (e.g. the H-CZ-H decomposition of CX) pick up systematic phase
    errors. Compared to the state-vector ``fidelity_cost_fn``, this cost
    also captures rotation-axis tilt and off-diagonal coherent error,
    independent of any probe-state choice.

    Args:
        pulse_params: Pulse parameters under optimisation.
        pulse_basis_scripts: ``d`` scripts. The k-th one prepares
            ``|k⟩`` (via ``PauliX``) and then applies the pulse circuit.
        target_basis_scripts: The same for the target circuit.
        n_samples: Number of rotation-angle samples in ``[0, 2π)``.
        n_qubits: Number of qubits the gate acts on.

    Returns:
        ``(process_loss, phase_loss)``, each averaged over the sampled
        angles.
    """
    dim = 2**n_qubits
    assert len(pulse_basis_scripts) == dim, (
        f"pulse_basis_scripts must have {dim} entries (one per basis "
        f"state); got {len(pulse_basis_scripts)}."
    )
    assert len(target_basis_scripts) == dim, (
        f"target_basis_scripts must have {dim} entries (one per basis "
        f"state); got {len(target_basis_scripts)}."
    )

    angles = _sample_rotation_angles(n_samples)

    # Execute pulse and target scripts alternately, one basis state at a time.
    pulse_cols: List[jnp.ndarray] = []
    target_cols: List[jnp.ndarray] = []
    for pulse_script, target_script in zip(pulse_basis_scripts, target_basis_scripts):
        pulse_cols.append(
            pulse_script.execute(
                type="state",
                args=(angles, pulse_params),
                in_axes=(0, None),
            )
        )  # (n_samples, dim)
        target_cols.append(
            target_script.execute(
                type="state",
                args=(angles,),
                in_axes=(0,),
            )
        )  # (n_samples, dim)

    # Shape (n_samples, dim, dim); u[s, :, k] is column k at angle s.
    u_pulse = jnp.stack(pulse_cols, axis=-1)
    u_target = jnp.stack(target_cols, axis=-1)

    overlap = jnp.einsum("sji,sjk->sik", jnp.conj(u_target), u_pulse)
    trace = jnp.einsum("sii->s", overlap)

    one = jnp.array(1.0, dtype=jnp.float64)
    process_loss = jnp.mean(one - jnp.abs(trace) ** 2 / float(dim) ** 2)
    phase_loss = jnp.mean(one - jnp.cos(jnp.angle(trace)))
    return (process_loss, phase_loss)
def joint_unitary_cost_fn(
    pulse_params: jnp.ndarray,
    gate_specs: List[dict],
    n_samples: int,
) -> Tuple[float, float]:
    """Weighted unitary cost over several gates sharing one parameter vector.

    Each entry of ``gate_specs`` describes one target gate built from the
    shared ``pulse_params``::

        {
            "name": str,                             # gate name (debug)
            "n_qubits": int,
            "weight": float,                         # per-gate weight
            "assembler": Callable,                   # theta -> per-gate flat params
            "pulse_basis_scripts": List[ys.Script],  # 2**n_qubits scripts
            "target_basis_scripts": List[ys.Script],
        }

    For each gate, :func:`unitary_cost_fn` is evaluated on the assembled
    per-gate parameters. The results are combined as ``Σ_g w_g · loss_g``,
    and normalised by ``Σ_g w_g`` when that sum is positive. Because the
    leaf parameters are shared, the optimum is pushed towards a basin
    that suits every use-site (standalone leaves and composites) rather
    than a "selfish" per-gate optimum.

    Args:
        pulse_params: Joint leaf parameter vector (theta).
        gate_specs: List of per-gate spec dicts (see above).
        n_samples: Number of rotation-angle samples per gate.

    Returns:
        ``(process_loss, phase_loss)``: the weighted combination of the
        per-gate losses.
    """
    proc_sum = jnp.array(0.0, dtype=jnp.float64)
    phase_sum = jnp.array(0.0, dtype=jnp.float64)
    weight_sum = 0.0

    for spec in gate_specs:
        gate_params = spec["assembler"](pulse_params)
        gate_proc, gate_phase = unitary_cost_fn(
            gate_params,
            spec["pulse_basis_scripts"],
            spec["target_basis_scripts"],
            n_samples,
            spec["n_qubits"],
        )
        gate_weight = spec["weight"]
        proc_sum = proc_sum + gate_weight * gate_proc
        phase_sum = phase_sum + gate_weight * gate_phase
        weight_sum += gate_weight

    # Normalise only when there is a positive total weight to divide by.
    if weight_sum > 0:
        proc_sum = proc_sum / weight_sum
        phase_sum = phase_sum / weight_sum

    return (proc_sum, phase_sum)
def pulse_width_cost_fn(
    pulse_params: jnp.ndarray,
    envelope: str,
) -> jnp.ndarray:
    """
    Penalty equal to the pulse width (sigma / width) parameter.

    The width is read from the last envelope parameter, i.e.
    ``pulse_params[n_envelope_params - 1]``. Envelopes without envelope
    parameters (e.g. ``"general"``) incur zero cost.

    Args:
        pulse_params: Pulse parameters for the gate.
        envelope: Name of the active pulse envelope.

    Returns:
        Scalar float64 pulse-width cost.
    """
    n_env = PulseEnvelope.get(envelope)["n_envelope_params"]
    width = pulse_params[n_env - 1] if n_env > 0 else 0
    return jnp.array(width, dtype=jnp.float64)
def evolution_time_cost_fn(
    pulse_params: jnp.ndarray,
    t_target: float,
) -> jnp.ndarray:
    """
    Squared relative deviation of the evolution time from ``t_target``.

    The evolution time is the last entry of ``pulse_params``, and the
    cost is::

        ((t - t_target) / t_target) ** 2

    Pulling every independently optimised gate towards the same
    evolution time keeps the gates compatible when composed in a circuit.

    Args:
        pulse_params: Pulse parameters for the gate.
        t_target: Target evolution time.

    Returns:
        Scalar evolution-time cost.
    """
    relative_deviation = (pulse_params[-1] - t_target) / t_target
    return relative_deviation**2
def spectral_density_cost_fn(
    pulse_params: jnp.ndarray,
    envelope: str,
    n_fft: int = 1024,
) -> jnp.ndarray:
    """
    Penalty on the spectral width (RMS bandwidth) of the pulse envelope.

    The steps are:

    1. Sample the envelope at ``n_fft`` points on ``[0, t_evol]``, centred
       at ``t_evol / 2``; ``t_evol`` is ``pulse_params[-1]``.
    2. Take the real FFT of the samples and normalise the power spectrum
       into a distribution.
    3. Return its standard deviation on a frequency axis scaled to
       ``[0, 1]``.

    Narrow-spectrum envelopes (e.g. Gaussian, DRAG) score low; broad
    ones (e.g. rectangular) score high. Envelopes with no envelope
    parameters, or without an envelope function, cost zero.

    Args:
        pulse_params: Pulse parameters. The envelope parameters are
            ``pulse_params[:n_envelope_params]`` and the evolution time
            is ``pulse_params[-1]``.
        envelope: Name of the active pulse envelope.
        n_fft: Number of time-domain samples fed to the FFT (default 1024).

    Returns:
        Scalar float64 RMS bandwidth. Because the frequency axis spans
        ``[0, 1]``, the value lies in ``[0, 0.5]``.
    """
    info = PulseEnvelope.get(envelope)
    n_env = info["n_envelope_params"]
    shape_fn = info["fn"]

    if n_env == 0 or shape_fn is None:
        return jnp.array(0.0, dtype=jnp.float64)

    shape_params = pulse_params[:n_env]
    duration = pulse_params[-1]
    centre = duration / 2.0

    times = jnp.linspace(0.0, duration, n_fft)
    samples = jax.vmap(lambda t: shape_fn(shape_params, t, centre))(times)

    power = jnp.abs(jnp.fft.rfft(samples)) ** 2
    power = power / (jnp.sum(power) + 1e-12)  # epsilon guards an all-zero signal

    norm_freqs = jnp.linspace(0.0, 1.0, len(power))
    centroid = jnp.sum(norm_freqs * power)
    spread = jnp.sum((norm_freqs - centroid) ** 2 * power)
    return jnp.array(jnp.sqrt(spread), dtype=jnp.float64)
class CostFnRegistry:
    """Registry of cost functions available for pulse optimisation.

    Use :meth:`register` to add new cost functions at runtime and
    :meth:`get` / :meth:`available` to query them. Each entry maps a name
    to a metadata dict with these keys:

    * ``fn``: the cost callable.
    * ``default_weight``: a scalar, or a tuple with one component per
      value the callable returns.
    * ``ckwargs_keys``: names of the constant keyword arguments the
      callable expects.
    """

    _REGISTRY: Dict[str, dict] = {
        "fidelity": {
            "fn": fidelity_cost_fn,
            "default_weight": (0.5, 0.5),
            "ckwargs_keys": ["pulse_scripts", "target_scripts", "n_samples"],
        },
        "unitary": {
            "fn": unitary_cost_fn,
            "default_weight": (0.5, 0.5),
            "ckwargs_keys": [
                "pulse_basis_scripts",
                "target_basis_scripts",
                "n_samples",
                "n_qubits",
            ],
        },
        "pulse_width": {
            "fn": pulse_width_cost_fn,
            "default_weight": 1.0,
            "ckwargs_keys": ["envelope"],
        },
        "evolution_time": {
            "fn": evolution_time_cost_fn,
            "default_weight": 1.0,
            "ckwargs_keys": ["t_target"],
        },
        "spectral_density": {
            "fn": spectral_density_cost_fn,
            "default_weight": 1.0,
            "ckwargs_keys": ["envelope"],
        },
    }

    @classmethod
    def register(
        cls,
        name: str,
        fn: Callable,
        default_weight: Union[float, Tuple[float, ...]] = 1.0,
        ckwargs_keys: Optional[List[str]] = None,
        overwrite: bool = False,
    ) -> None:
        """Register a new cost function under ``name``.

        Args:
            name: Registry key (used in ``"name:w1,w2"`` CLI specs).
            fn: Cost callable ``(pulse_params, **ckwargs) -> scalar | tuple``.
            default_weight: Scalar, or a tuple with one weight per returned
                cost component.
            ckwargs_keys: Names of the constant keyword arguments ``fn``
                expects. Defaults to none.
            overwrite: Allow replacing an existing entry of the same name.

        Raises:
            TypeError: If ``fn`` is not callable.
            ValueError: If ``name`` is already registered and ``overwrite``
                is ``False``.
        """
        if not callable(fn):
            raise TypeError(f"Cost function '{name}' must be callable, got {type(fn)}.")
        if name in cls._REGISTRY and not overwrite:
            raise ValueError(
                f"Cost function '{name}' is already registered; "
                f"pass overwrite=True to replace it."
            )
        cls._REGISTRY[name] = {
            "fn": fn,
            "default_weight": default_weight,
            "ckwargs_keys": list(ckwargs_keys) if ckwargs_keys is not None else [],
        }

    @classmethod
    def available(cls) -> List[str]:
        """Return the names of all registered cost functions."""
        return list(cls._REGISTRY.keys())

    @classmethod
    def get(cls, name: str) -> dict:
        """Look up cost-function metadata by name.

        Args:
            name: Registered cost function name.

        Returns:
            Metadata dict with keys ``fn``,
            ``default_weight``, ``ckwargs_keys``.

        Raises:
            ValueError: If name is not registered.
        """
        if name not in cls._REGISTRY:
            raise ValueError(
                f"Unknown cost function '{name}'. Available: {cls.available()}"
            )
        return cls._REGISTRY[name]

    @classmethod
    def parse_cost_arg(
        cls, spec: Union[str, Tuple]
    ) -> Tuple[str, Union[float, Tuple[float, ...]]]:
        """Parse a ``"name:w1,w2,..."`` CLI string into ``(name, weight)``.

        A tuple is returned as-is, without validation. Surrounding
        whitespace in the name is ignored. If the weight part is omitted,
        the registry's default weight is used. A single-component weight
        is returned as a float; multi-component weights are returned as a
        tuple of floats.

        Args:
            spec: A string of the form ``"name"`` or ``"name:w1,w2,..."``.

        Returns:
            A tuple of ``(name, weight)``.

        Raises:
            ValueError: If the name is unknown, a weight is not a number,
                or the number of weight components does not match
                ``default_weight``.
        """
        if isinstance(spec, tuple):
            return spec

        if ":" in spec:
            name, weight_str = spec.split(":", 1)
            name = name.strip()
            parts = [float(x) for x in weight_str.split(",")]
            weight: Union[float, Tuple[float, ...]] = (
                parts[0] if len(parts) == 1 else tuple(parts)
            )
        else:
            name = spec.strip()
            weight = cls.get(name)["default_weight"]

        # Validate weight count
        got = len(weight) if isinstance(weight, tuple) else 1
        default_weight = cls.get(name)["default_weight"]
        expected = len(default_weight) if isinstance(default_weight, tuple) else 1

        if got != expected:
            raise ValueError(
                f"Cost function '{name}' expects {expected} weight(s), got {got}."
            )

        return name, weight
class QOC:
    """Quantum Optimal Control for pulse-level gate synthesis.

    Optimises pulse parameters to reproduce the unitary of standard
    quantum gates using a two-stage strategy.

    Attributes:
        GATES_1Q: Names of supported single-qubit gates.
        GATES_2Q: Names of supported two-qubit gates.
        DEFAULT_PARAM_RANGES: Default ``(lo, hi)`` parameter ranges,
            keyed by the number of pulse parameters (see the inline
            comments, e.g. 3 -> ``[A, sigma, t]``).
    """

    GATES_1Q: List[str] = ["RX", "RY", "RZ", "Rot", "H"]
    GATES_2Q: List[str] = ["CX", "CY", "CZ", "CRX", "CRY", "CRZ"]

    # Key = length of the pulse parameter vector; value = one range per entry.
    DEFAULT_PARAM_RANGES: Dict[int, List[Tuple[float, float]]] = {
        1: [(0.05, 3.0)],  # evolution time
        2: [(0.05, 3.0), (0.05, 3.0)],  # not typically used
        3: [(0.05, 3.0), (0.05, 3.0), (0.05, 3.0)],  # [A, sigma, t]
        4: [(0.05, 3.0), (0.05, 3.0), (0.05, 3.0), (0.05, 3.0)],  # [A, beta, sigma, t]
    }
656 def __init__(
657 self,
658 envelope: str,
659 cost_fns: List[Tuple[str, Union[float, Tuple[float, ...]]]],
660 t_target: float,
661 n_steps: int,
662 n_samples: int,
663 learning_rate: float,
664 log_interval: int = 50,
665 file_dir: str = None,
666 warmup_ratio: float = 0.0,
667 end_lr_ratio: float = 1.0,
668 n_restarts: int = 1,
669 restart_noise_scale: float = 0.5,
670 grad_clip: float = 1.0,
671 random_seed: int = 42,
672 scan_steps: int = 0,
673 scan_grid_size: int = 5,
674 scan_ranges: Optional[List[Tuple[float, float]]] = None,
675 log_scale_params: Optional[List[int]] = None,
676 early_stop_patience: int = 0,
677 early_stop_min_delta: float = 0.0,
678 plot: bool = False,
679 ):
680 """
681 Initialize Quantum Optimal Control with Pulse-level Gates.
683 Args:
684 envelope (str): Pulse envelope shape to use for optimization.
685 Must be one of the registered envelopes in PulseEnvelope
686 (e.g. 'gaussian', 'square', 'cosine', 'drag', 'sech').
687 cost_fns (list): List of ``(name, weight)`` tuples that select
688 which cost functions to use and their weights. name must
689 be a key in :class:`CostFnRegistry`. *weight* is either a
690 single float or a tuple of floats matching the number of
691 return values of the cost function.
692 t_target (float, optional): Target evolution time for the
693 ``evolution_time`` cost function. Required when
694 ``"evolution_time"`` is among the selected cost functions.
695 n_steps (int): Number of steps in optimization.
696 n_samples (int): Number of parameter samples per step.
697 learning_rate (float): Peak learning rate for AdamW. When a
698 warmup/decay schedule is active this is the maximum LR
699 reached after the warmup phase.
700 log_interval (int): Interval for logging.
701 file_dir (str): Directory to save results.
702 warmup_ratio (float): Fraction of ``n_steps`` used for linear
703 warmup (0.0 - 1.0). Set to 0.0 to disable warmup and use
704 a constant learning rate throughout. A value of e.g. 0.05
705 means the first 5 % of steps linearly ramp the LR from
706 ``end_lr_ratio * learning_rate`` to ``learning_rate``.
707 end_lr_ratio (float): The final learning rate is
708 ``end_lr_ratio * learning_rate``. Also used as the initial
709 LR at the start of warmup. Set to 0.0 for full cosine
710 decay to zero; set to 1.0 (together with
711 ``warmup_ratio=0.0``) to recover a constant LR.
712 n_restarts (int): Number of random restarts for the
713 optimisation. The first run uses the initial parameters
714 as-is; subsequent runs add scaled random perturbations.
715 The best result across all restarts is kept.
716 Set to 1 to disable restarts (default behaviour).
717 restart_noise_scale (float): Standard deviation of the
718 Gaussian noise added to the initial parameters for each
719 restart (relative to the absolute value of each parameter).
720 Defaults to 0.5 (50 % relative perturbation). Note that
721 the package-level default in ``default_qoc_params`` is a
722 much smaller ``0.01`` because the QOC loss landscape is
723 highly sensitive to initial conditions and large
724 perturbations routinely move restarts into useless
725 basins; tune up only if you have reason to believe the
726 initial point is far from any good basin.
727 grad_clip (float): Maximum global gradient norm. Gradients
728 are clipped to this value before being passed to the
729 optimiser, which stabilises training when the loss
730 landscape has steep regions. Set to ``float('inf')`` or
731 0.0 to disable. Defaults to 1.0.
732 random_seed (int): Base random seed for generating restart
733 perturbations. Defaults to 42.
734 scan_steps (int): Number of short gradient-descent steps to
735 run for each candidate in the coarse grid search
736 (Stage 0). Set to 0 to disable the grid scan entirely
737 and rely solely on restarts. A value of 20-50 is
738 usually enough to identify promising basins. Defaults
739 to 0.
740 scan_grid_size (int): Number of points per parameter
741 dimension in the coarse grid. The total number of
742 candidates is ``scan_grid_size ** n_params``, so keep
743 this small for high-dimensional parameter spaces.
744 Defaults to 5.
745 scan_ranges (Optional[List[Tuple[float, float]]]): Per-
746 parameter ``(lo, hi)`` ranges for the grid scan. If
747 ``None``, heuristic ranges are used based on the
748 envelope type: amplitude in ``[0.5, 30]``, width/sigma
749 in ``[0.05, 2]``, and evolution time in ``[0.05, 2]``.
750 Must have length equal to the number of pulse parameters
751 if provided.
752 log_scale_params (Optional[List[int]]): Indices of pulse
753 parameters that should be optimised in log-space. For
754 these parameters the optimizer sees ``log(p)`` and the
755 actual parameter used in the simulation is ``exp(log_p)``.
756 This dramatically improves convergence when the optimal
757 value may differ from the initial value by an order of
758 magnitude (e.g. amplitude, evolution time).
759 If ``None``, defaults to ``[0, -1]`` (amplitude and
760 evolution time) for envelopes with ≥ 2 envelope params,
761 or ``[]`` otherwise.
762 early_stop_patience (int): Number of consecutive
763 Stage-1 steps with no improvement greater than
764 ``early_stop_min_delta`` after which optimisation
765 exits early. Set to ``0`` (default) to disable.
766 Only honoured in the single-restart (sequential)
767 path; when ``n_restarts > 1`` the parallel
768 vmap+scan path always runs the full ``n_steps``.
769 early_stop_min_delta (float): Minimum decrease in loss
770 that counts as an improvement for the early-stopping
771 patience counter. Defaults to ``0.0`` (any strict
772 improvement resets the counter).
773 plot (bool): If ``True``, save a loss-landscape figure after
774 Phase 0 and a loss-curve figure after Phase 1 to
775 ``file_dir``. Requires ``matplotlib`` to be installed.
776 Defaults to ``False``.
777 """
778 self.envelope = envelope
779 self.n_steps = n_steps
780 self.n_samples = n_samples
781 self.learning_rate = learning_rate
782 self.warmup_ratio = warmup_ratio
783 self.end_lr_ratio = end_lr_ratio
784 self.log_interval = log_interval
785 self.file_dir = (
786 file_dir if file_dir else os.path.dirname(os.path.realpath(__file__))
787 )
788 self.t_target = t_target
789 self.n_restarts = max(1, n_restarts)
790 self.restart_noise_scale = restart_noise_scale
791 self.grad_clip = grad_clip
792 self.random_key = jax.random.PRNGKey(random_seed)
793 self.scan_steps = scan_steps
794 self.scan_grid_size = scan_grid_size
795 self.scan_ranges = scan_ranges
797 # Determine log-scale param indices
798 envelope_info = PulseEnvelope.get(envelope)
799 n_env = envelope_info["n_envelope_params"]
800 if log_scale_params is not None:
801 self.log_scale_params = log_scale_params
802 elif n_env >= 2:
803 # Default: amplitude (index 0) and evolution time (last)
804 self.log_scale_params = [0, -1]
805 else:
806 self.log_scale_params = []
808 # Mask cache used by ``_to_log_space``/``_from_log_space``;
809 # rebuilt lazily because the mask length depends on the size of
810 # the param vector being converted (per-gate vs joint).
811 self._log_mask_cache: Dict[int, jnp.ndarray] = {}
813 self.early_stop_patience = max(0, int(early_stop_patience))
814 self.early_stop_min_delta = float(early_stop_min_delta)
816 self.plot = plot
818 log.info(
819 f"Training parameters: {self.n_steps} steps, "
820 f"{self.n_samples} samples, {self.learning_rate} learning rate"
821 )
822 log.info(
823 f"LR schedule: warmup_ratio={self.warmup_ratio}, "
824 f"end_lr_ratio={self.end_lr_ratio}"
825 )
827 log.info(f"Envelope: {self.envelope}")
828 log.info(f"Target evolution time: {self.t_target}")
829 log.info(
830 f"Restarts: {self.n_restarts}, noise_scale={self.restart_noise_scale}, "
831 f"grad_clip={self.grad_clip}"
832 )
833 if PulseInformation.get_rwa():
834 log.info("Using RWA. Rotating frame is ignored.")
835 else:
836 log.info(f"Using no RWA and {PulseInformation.get_frame()} frame.")
838 if self.early_stop_patience > 0:
839 log.info(
840 f"Early stopping: patience={self.early_stop_patience}, "
841 f"min_delta={self.early_stop_min_delta:g}"
842 )
843 log.info(
844 f"Grid scan: scan_steps={self.scan_steps}, "
845 f"scan_grid_size={self.scan_grid_size}, "
846 f"log_scale_params={self.log_scale_params}"
847 )
848 log.info(f"Using cost function(s) {cost_fns}")
850 # Validate each entry against the registry
851 summed_weights = 0
852 for name, _weight in cost_fns:
853 CostFnRegistry.get(name) # raises ValueError if unknown
854 summed_weights += sum(_weight) if isinstance(_weight, tuple) else _weight
855 assert jnp.isclose(summed_weights, 1.0, rtol=1e-8), (
856 f"Cost function weights must sum to 1. Got {summed_weights}"
857 )
859 self.cost_fns = cost_fns
861 # Configure the pulse system with the selected envelope
862 PulseInformation.set_envelope(self.envelope)
864 def save_results(self, gate: str, fidelity: float, pulse_params) -> None:
865 """Save optimised pulse parameters and fidelity for a gate to CSV.
867 If the gate already exists in the file, its entry is overwritten
868 regardless of whether the new fidelity is higher. A warning is
869 logged when the existing fidelity was better.
871 Args:
872 gate: Name of the gate (e.g. ``"RX"``).
873 fidelity: Achieved fidelity of the optimised pulse.
874 pulse_params (jnp.ndarray): Optimised pulse parameters for the gate.
875 """
876 if self.file_dir is not None:
877 os.makedirs(self.file_dir, exist_ok=True)
878 filename = os.path.join(self.file_dir, f"qoc_results_{self.envelope}.csv")
880 reader = None
881 if os.path.isfile(filename):
882 with open(filename, mode="r", newline="") as f:
883 reader = csv.reader(f.readlines())
885 entry = [gate] + [fidelity] + list(map(float, pulse_params))
887 with open(filename, mode="w", newline="") as f:
888 writer = csv.writer(f)
889 match = False
890 if reader is not None:
891 for row in reader:
892 # gate already exists
893 if row[0] == gate:
894 if fidelity <= float(row[1]):
895 log.warning(
896 f"Pulse parameters for {gate} already exist with "
897 f"higher fidelity ({row[1]} >= {fidelity})"
898 )
899 writer.writerow(entry)
900 match = True
901 # any other gate
902 else:
903 writer.writerow(row)
904 # gate does not exist
905 if not match:
906 writer.writerow(entry)
908 def _log_mask(self, n: int) -> jnp.ndarray:
909 """Return a boolean mask of length ``n`` marking log-scaled indices."""
910 cached = self._log_mask_cache.get(n)
911 if cached is not None and cached.shape[0] == n:
912 return cached
913 mask = np.zeros(n, dtype=bool)
914 for idx in self.log_scale_params:
915 i = idx if idx >= 0 else n + idx
916 if 0 <= i < n:
917 mask[i] = True
918 out = jnp.asarray(mask)
919 self._log_mask_cache[n] = out
920 return out
922 def _to_log_space(self, params: jnp.ndarray) -> jnp.ndarray:
923 """Convert selected parameters to log-space for optimisation.
925 Parameters at indices in ``self.log_scale_params`` are replaced
926 by ``log(|p| + eps)`` so the optimiser operates on a
927 logarithmic scale. All other parameters are left unchanged.
928 """
929 if not self.log_scale_params:
930 return params
931 mask = self._log_mask(params.shape[0])
932 log_vals = jnp.log(jnp.abs(params) + 1e-12)
933 return jnp.where(mask, log_vals, params)
935 def _from_log_space(self, log_params: jnp.ndarray) -> jnp.ndarray:
936 """Convert selected parameters back from log-space.
938 Inverse of :meth:`_to_log_space`. Parameters at indices in
939 ``self.log_scale_params`` are exponentiated; all others are
940 passed through unchanged.
941 """
942 if not self.log_scale_params:
943 return log_params
944 mask = self._log_mask(log_params.shape[0])
945 return jnp.where(mask, jnp.exp(log_params), log_params)
947 # Multiplicative factors used to build a centred grid around the
948 # supplied init parameters when no explicit ``scan_ranges`` are
949 # given. ``1.0`` is included so the init point itself is always a
950 # candidate (Stage 0 cannot otherwise re-evaluate it as a grid
951 # point — only as the baseline ``best_scan_loss``).
952 SCAN_REL_FACTORS: Tuple[float, ...] = (0.5, 0.75, 1.0, 1.25, 1.5)
954 def _build_scan_grid(
955 self,
956 n_params: int,
957 init_pulse_params: Optional[jnp.ndarray] = None,
958 ) -> Tuple[jnp.ndarray, List[jnp.ndarray]]:
959 """Build a coarse parameter grid for the initial scan phase.
961 If the user supplied ``scan_ranges`` they take precedence and
962 a log-spaced grid is built within those bounds. Otherwise, when
963 ``init_pulse_params`` is available, a **multiplicative grid
964 centred on the init point** is used (each axis spans
965 ``init * SCAN_REL_FACTORS``) so that already-optimised init
966 params are always re-evaluated and only their immediate
967 neighbourhood is explored. This avoids the failure mode where
968 the global ``DEFAULT_PARAM_RANGES`` brackets exclude the actual
969 optimum (the previous default range was ``(0.05, 3.0)`` per
970 axis, which clipped DRAG amplitudes around 3.1 and made the
971 scan systematically worse than the init point).
973 Args:
974 n_params: Number of pulse parameters.
975 init_pulse_params: Optional init params used to centre the
976 multiplicative grid when ``scan_ranges`` is ``None``.
978 Returns:
979 Tuple of:
980 - Array of shape ``(n_candidates, n_params)`` with grid points.
981 - List of 1-D arrays, one per parameter axis.
982 """
983 if self.scan_ranges is not None:
984 ranges = self.scan_ranges
985 assert len(ranges) == n_params, (
986 f"scan_ranges has {len(ranges)} entries but gate has "
987 f"{n_params} parameters."
988 )
989 # Build log-spaced grids for each parameter
990 axes = []
991 for lo, hi in ranges:
992 axes.append(
993 jnp.logspace(jnp.log10(lo), jnp.log10(hi), self.scan_grid_size)
994 )
995 elif init_pulse_params is not None:
996 # Multiplicative grid centred on init params. We pick
997 # ``scan_grid_size`` factors symmetric around 1.0. When
998 # ``scan_grid_size`` matches the static SCAN_REL_FACTORS
999 # length we use those; otherwise build a symmetric linspace.
1000 if self.scan_grid_size == len(self.SCAN_REL_FACTORS):
1001 factors = jnp.array(self.SCAN_REL_FACTORS, dtype=jnp.float64)
1002 else:
1003 half = (self.scan_grid_size - 1) / 2.0
1004 if half <= 0:
1005 factors = jnp.array([1.0], dtype=jnp.float64)
1006 else:
1007 factors = jnp.linspace(
1008 1.0 - 0.5,
1009 1.0 + 0.5,
1010 self.scan_grid_size,
1011 dtype=jnp.float64,
1012 )
1013 axes = [factors * float(p) for p in init_pulse_params]
1014 else:
1015 # Fall back to legacy log-spaced default ranges
1016 ranges = self.DEFAULT_PARAM_RANGES.get(
1017 n_params,
1018 [(0.1, 10.0)] * n_params,
1019 )
1020 axes = []
1021 for lo, hi in ranges:
1022 axes.append(
1023 jnp.logspace(jnp.log10(lo), jnp.log10(hi), self.scan_grid_size)
1024 )
1026 # Cartesian product of all axes
1027 grid = jnp.array(list(itertools.product(*axes)))
1028 return grid, axes
def stage_0_opt(
    self, init_pulse_params: jnp.ndarray, total_cost: Callable
) -> Tuple[jnp.ndarray, Optional[Tuple[List[jnp.ndarray], list]]]:
    """Run the coarse grid-scan phase (Stage 0).

    Evaluates a Cartesian grid of parameter candidates using the
    **full weighted cost** (fidelity + phase, plus any other
    registered terms) — the same objective as Stage 1. Each
    candidate is refined with ``scan_steps`` fast Adam steps, and the
    better of the raw and refined point is kept. Returns the best
    parameters found, or ``init_pulse_params`` if nothing beats them.

    Sharing the objective with Stage 1 prevents the grid scan from
    landing in a basin that has high fidelity but a biased phase,
    which Adam would then have to migrate out of.

    Robustness: candidates whose loss is non-finite (e.g. when the
    pulse drives the integrator into a NaN) are skipped and counted.
    For the duration of the scan, :class:`qml_essentials.yaqsi.Yaqsi`
    is switched into ``throw=False`` mode so a single bad candidate
    cannot abort the loop; the previous defaults are restored on exit,
    even if the loop raises.

    Args:
        init_pulse_params: Initial pulse parameters to compare against.
        total_cost: Combined cost callable (same as Stage 1).

    Returns:
        Tuple of:
            - Best pulse parameters found during the scan.
            - ``(grid_axes, landscape_data)`` if the grid scan ran, else
              ``None``. ``landscape_data`` is a list of
              ``(candidate_index, original_params, loss)`` tuples for
              every candidate with a finite loss.
    """

    def total_cost_log(log_params, *args):
        return total_cost(self._from_log_space(log_params), *args)

    best_scan_params = init_pulse_params
    best_scan_loss = _safe_eval(total_cost, init_pulse_params)
    if not jnp.isfinite(best_scan_loss):
        log.warning(
            "Stage 0: initial pulse parameters produced a non-finite "
            "loss; falling back to a placeholder loss of +inf."
        )

    landscape_data: list = []
    axes_out: Optional[List[jnp.ndarray]] = None
    # Bookkeeping is initialised up front so the summary logging below
    # never depends on the scan branch having run.
    n_candidates = 0
    n_skipped = 0
    n_raw_better = 0

    if self.scan_steps > 0:
        log.info(
            f"Stage 0: Grid scan with {self.scan_grid_size}^"
            f"{len(init_pulse_params)} candidates, "
            f"{self.scan_steps} steps each"
        )

        grid, axes_out = self._build_scan_grid(
            len(init_pulse_params),
            init_pulse_params=init_pulse_params,
        )
        n_candidates = len(grid)
        log.info(f" Total candidates: {n_candidates}")

        # Fast Adam for the scan phase. A 5x learning-rate multiplier
        # tended to push refined candidates *out* of good basins; 2x
        # keeps the refinement localised. ``grad_clip`` may be None or
        # non-positive (``_build_optimizer`` treats that as "no clip");
        # the scan always clips and falls back to a norm of 1.0.
        clip_norm = self.grad_clip if self.grad_clip and self.grad_clip > 0 else 1.0
        scan_optimizer = optax.chain(
            optax.clip_by_global_norm(clip_norm),
            optax.adam(self.learning_rate * 2),
        )

        @jax.jit
        def refine_candidate(log_candidate):
            """Run ``self.scan_steps`` Adam steps on a single candidate.

            Fused into a single ``jax.lax.scan`` so the whole refinement
            is one XLA program. Returns the final log-params and a
            scalar bool ``failed`` flag, set once any update produced a
            non-finite value (parameters are frozen from then on).
            """
            opt_state0 = scan_optimizer.init(log_candidate)

            def body(carry, _):
                log_p, opt_state, failed = carry
                loss, grads = jax.value_and_grad(total_cost_log)(log_p)
                updates, opt_state = scan_optimizer.update(grads, opt_state, log_p)
                new_log_p = optax.apply_updates(log_p, updates)
                new_failed = failed | (~jnp.all(jnp.isfinite(new_log_p)))
                # Freeze on failure so later steps cannot propagate NaNs.
                new_log_p = jnp.where(new_failed, log_p, new_log_p)
                return (new_log_p, opt_state, new_failed), loss

            (final_log_p, _, failed), _ = jax.lax.scan(
                body,
                (log_candidate, opt_state0, jnp.bool_(False)),
                None,
                length=self.scan_steps,
            )
            return final_log_p, failed

        # Non-throwing solver mode: candidates that exceed the step
        # budget yield NaN unitaries (hence +inf losses) instead of
        # aborting the whole grid loop.
        prev_solver_defaults = ys.Yaqsi.set_solver_defaults(throw=False)
        try:
            for ci, candidate in enumerate(grid):
                log_candidate = self._to_log_space(candidate)

                # Evaluate the raw (unrefined) candidate so an
                # over-aggressive refinement cannot discard an
                # already-good grid point.
                raw_loss = _safe_eval(total_cost, candidate)

                # Default to the raw candidate; replaced only when the
                # refinement succeeded with finite parameters.
                physical_p, loss = candidate, raw_loss
                try:
                    log_p, failed_flag = refine_candidate(log_candidate)
                except Exception as exc:  # pragma: no cover - defensive
                    log.debug(
                        f" Candidate {ci + 1}/{len(grid)} "
                        f"raised during refinement: {exc}; skipping."
                    )
                else:
                    if not bool(failed_flag):
                        refined_p = self._from_log_space(log_p)
                        if jnp.all(jnp.isfinite(refined_p)):
                            physical_p = refined_p
                            loss = _safe_eval(total_cost, refined_p)

                # Keep the better of (raw, refined) for this candidate.
                if jnp.isfinite(raw_loss) and (
                    not jnp.isfinite(loss) or raw_loss < loss
                ):
                    physical_p = candidate
                    loss = raw_loss
                    n_raw_better += 1

                if not jnp.isfinite(loss):
                    n_skipped += 1
                    continue

                landscape_data.append((ci, candidate, float(loss)))

                if loss < best_scan_loss:
                    best_scan_loss = loss
                    best_scan_params = physical_p
                    log.info(
                        f" Candidate {ci + 1}/{len(grid)}: "
                        f"loss={float(loss):.6e} improved with "
                        f"params={physical_p}"
                    )
        finally:
            # Always restore the previous solver defaults so other
            # callers (including Stage 1) are unaffected.
            if prev_solver_defaults:
                ys.Yaqsi.set_solver_defaults(**prev_solver_defaults)

    if n_skipped:
        log.warning(
            f"Stage 0: skipped {n_skipped}/{n_candidates} candidates "
            f"due to solver failure or non-finite loss "
            f"(typical for very narrow / very large-amplitude "
            f"DRAG pulses)."
        )
    if n_raw_better:
        log.info(
            f"Stage 0: {n_raw_better}/{n_candidates} candidates "
            f"were better unrefined than after the {self.scan_steps}-"
            f"step refinement; raw values were kept."
        )

    log.info(
        f"Stage 0 complete. Best loss: "
        f"{float(best_scan_loss):.6e}, "
        f"params: {best_scan_params}"
    )

    scan_data = (axes_out, landscape_data) if self.scan_steps > 0 else None
    return best_scan_params, scan_data
def stage_1_opt(
    self, best_scan_params: jnp.ndarray, total_costs: Callable
) -> Tuple[jnp.ndarray, list, jnp.ndarray]:
    """Run multi-restart gradient optimisation (Stage 1).

    Builds the AdamW optimiser (optionally with a warm-up + cosine
    learning-rate schedule and gradient clipping) and dispatches to
    either the single-restart path (``n_restarts <= 1``, a scan with
    early stopping) or the vmapped multi-restart path. Parameters
    listed in ``log_scale_params`` are optimised in log-space.

    The schedule is used whenever ``warmup_ratio * n_steps`` rounds
    down to a positive step count or ``end_lr_ratio < 1``. It warms up
    from ``learning_rate * end_lr_ratio`` (or starts directly at
    ``learning_rate`` with no warm-up) and decays to
    ``learning_rate * end_lr_ratio`` by ``n_steps``. Otherwise a
    constant ``learning_rate`` is used.

    Args:
        best_scan_params: Starting parameters (typically from Stage 0).
        total_costs: Combined cost callable.

    Returns:
        Tuple of ``(best_params, loss_history, best_loss)`` from the
        best restart.
    """

    def total_costs_log(log_params):
        # Optimise in log-space; evaluate the cost in physical space.
        return total_costs(self._from_log_space(log_params))

    n_warmup = int(self.n_steps * self.warmup_ratio)
    floor_lr = self.learning_rate * self.end_lr_ratio

    if n_warmup <= 0 and self.end_lr_ratio >= 1.0:
        schedule = self.learning_rate
    else:
        start_lr = floor_lr if n_warmup > 0 else self.learning_rate
        schedule = optax.warmup_cosine_decay_schedule(
            init_value=start_lr,
            peak_value=self.learning_rate,
            warmup_steps=n_warmup,
            decay_steps=self.n_steps,
            end_value=floor_lr,
        )

    optimizer = _build_optimizer(schedule, self.grad_clip)

    runner = (
        self._stage_1_sequential
        if self.n_restarts <= 1
        else self._stage_1_parallel
    )
    return runner(best_scan_params, total_costs, total_costs_log, optimizer)
1283 def _perturb_starts(self, start_params: jnp.ndarray) -> jnp.ndarray:
1284 """Pre-build the ``(n_restarts, n_params)`` matrix of restart starts.
1286 Restart 0 is the unperturbed start; subsequent restarts add
1287 Gaussian noise scaled by ``max(|start|, 0.1) *
1288 restart_noise_scale``. Indices that are optimised in
1289 log-space (plus the evolution time at index ``-1``) are kept
1290 positive via ``jnp.abs`` so the subsequent ``log`` is safe.
1291 """
1292 n_params = start_params.shape[0]
1293 keys = jax.random.split(self.random_key, self.n_restarts)
1294 # Shape (n_restarts, n_params); restart 0 is intentionally zero
1295 # noise so the unperturbed start is preserved.
1296 noise = jax.vmap(lambda k: jax.random.normal(k, shape=(n_params,)))(keys)
1297 noise = noise.at[0].set(0.0)
1298 scale = jnp.maximum(jnp.abs(start_params), 0.1) * self.restart_noise_scale
1299 starts = start_params[None, :] + noise * scale[None, :]
1301 # Keep the evolution time and any log-scaled indices positive.
1302 positive_mask = np.zeros(n_params, dtype=bool)
1303 positive_mask[-1] = True
1304 for idx in self.log_scale_params:
1305 i = idx if idx >= 0 else n_params + idx
1306 if 0 <= i < n_params:
1307 positive_mask[i] = True
1308 positive_mask_j = jnp.asarray(positive_mask)
1309 starts = jnp.where(positive_mask_j[None, :], jnp.abs(starts), starts)
1310 return starts
def _stage_1_sequential(
    self,
    start_params: jnp.ndarray,
    total_costs: Callable,
    total_costs_log: Callable,
    optimizer,
) -> Tuple[jnp.ndarray, list, jnp.ndarray]:
    """Single-restart Stage 1, fused into a single ``jax.lax.scan``.

    The whole optimisation loop (n_steps × optimiser updates) compiles
    to one XLA program, avoiding per-step Python overhead and per-step
    host/device syncs. Early stopping uses *masked updates*: once no
    improvement larger than ``early_stop_min_delta`` has been seen for
    ``early_stop_patience`` consecutive steps, later steps leave the
    parameters and optimiser state unchanged. Compute is not skipped,
    because ``lax.scan`` has a fixed length. ``early_stop_patience <= 0``
    disables early stopping.

    A non-finite initial loss is replaced by ``+inf`` for best-loss
    tracking. Every comparison against NaN is ``False``, so a NaN seed
    would otherwise block every later improvement and the unoptimised
    start would be returned. ``loss_history[0]`` still holds the raw
    initial loss.

    Args:
        start_params: Starting parameters in physical space.
        total_costs: Cost as a function of physical parameters.
        total_costs_log: Same cost as a function of log-space parameters.
        optimizer: Optax gradient transformation.

    Returns:
        ``(best_params, loss_history, best_loss)`` where
        ``loss_history`` is ``[initial_loss, loss_at_step_0, ...]`` and
        each step loss is evaluated at that step's pre-update params.
    """
    params = start_params
    log_params = self._to_log_space(params)
    opt_state = optimizer.init(log_params)

    init_loss = total_costs(params)
    # Seed the best-loss tracker with +inf for a non-finite start (see docstring).
    init_best_loss = jnp.where(jnp.isfinite(init_loss), init_loss, jnp.inf)
    min_delta = self.early_stop_min_delta
    patience = self.early_stop_patience
    # ``patience <= 0`` => early stopping disabled. Use a constant the
    # step counter can never reach so the masked path never triggers.
    eff_patience = patience if patience > 0 else self.n_steps + 1

    def scan_body(carry, _):
        (
            log_params,
            opt_state,
            best_loss,
            best_log_params,
            steps_since_improve,
            stopped_flag,
            stopped_step,
            step_idx,
        ) = carry

        loss, grads = jax.value_and_grad(total_costs_log)(log_params)
        updates, new_opt_state = optimizer.update(grads, opt_state, log_params)
        stepped_log_params = optax.apply_updates(log_params, updates)

        # ``loss`` belongs to the pre-update ``log_params``, so those
        # are the params remembered on improvement.
        improved = loss < best_loss - min_delta
        best_loss = jnp.where(improved, loss, best_loss)
        best_log_params = jnp.where(improved, log_params, best_log_params)
        steps_since_improve = jnp.where(
            improved, jnp.int32(0), steps_since_improve + jnp.int32(1)
        )

        # Latch the early-stop flag once it fires.
        trigger = steps_since_improve >= jnp.int32(eff_patience)
        new_stopped_flag = stopped_flag | trigger
        stopped_step = jnp.where(
            stopped_flag,
            stopped_step,
            jnp.where(trigger, step_idx + jnp.int32(1), stopped_step),
        )

        # Mask the update once stopped: freeze params/optimiser state.
        new_log_params = jnp.where(new_stopped_flag, log_params, stepped_log_params)
        new_opt_state_kept = jax.tree_util.tree_map(
            lambda new, old: jnp.where(new_stopped_flag, old, new),
            new_opt_state,
            opt_state,
        )

        new_carry = (
            new_log_params,
            new_opt_state_kept,
            best_loss,
            best_log_params,
            steps_since_improve,
            new_stopped_flag,
            stopped_step,
            step_idx + jnp.int32(1),
        )
        return new_carry, loss

    init_carry = (
        log_params,  # log_params
        opt_state,  # opt_state
        init_best_loss,  # best_loss
        log_params,  # best_log_params
        jnp.int32(0),  # steps_since_improve
        jnp.bool_(False),  # stopped_flag
        jnp.int32(self.n_steps),  # stopped_step (default = n_steps)
        jnp.int32(0),  # step_idx
    )

    @jax.jit
    def run_scan(carry):
        return jax.lax.scan(scan_body, carry, None, length=self.n_steps)

    final_carry, step_losses = run_scan(init_carry)
    _, _, best_loss, best_log_params, _, stopped_flag, stopped_step, _ = final_carry

    # One device->host transfer for everything needed by the logs.
    host_step_losses, host_best_loss, host_stopped, host_stopped_step = (
        jax.device_get((step_losses, best_loss, stopped_flag, stopped_step))
    )

    for step in range(0, self.n_steps, max(1, self.log_interval)):
        log.info(
            f"Step {step}/{self.n_steps}, Loss: {float(host_step_losses[step]):.3e}"
        )
    if bool(host_stopped):
        log.info(
            f"Early stop at step {int(host_stopped_step)}/{self.n_steps} "
            f"(no improvement > {min_delta:g} for "
            f"{self.early_stop_patience} steps)."
        )

    log.info(
        f"Restart 1/1 finished with best loss: {float(host_best_loss):.3e}"
        + (
            f" (early stopped at step {int(host_stopped_step)})"
            if bool(host_stopped)
            else ""
        )
    )

    # Leading entry is the raw initial loss, followed by one entry per
    # scan step; returned as a ``list`` for downstream plotting code.
    loss_history = [init_loss] + list(step_losses)

    best_pulse_params = self._from_log_space(best_log_params)
    return best_pulse_params, loss_history, best_loss
def _stage_1_parallel(
    self,
    start_params: jnp.ndarray,
    total_costs: Callable,
    total_costs_log: Callable,
    optimizer,
) -> Tuple[jnp.ndarray, list, jnp.ndarray]:
    """Vmap+scan Stage 1: all restarts × all steps in one XLA program.

    Always runs the full ``n_steps``. An early-stop break would need
    either a chunked scan (extra Python overhead) or masked updates
    (no compute saved), and every restart would have to plateau before
    a break could happen, so the win is small. Sequential mode
    (``n_restarts == 1``) does honour ``early_stop_patience``.

    Best-so-far tracking pairs each restart's best loss with the
    parameters that *produced* it, i.e. the pre-update params of that
    step. Non-finite initial losses are replaced by ``+inf`` for
    tracking only. Comparisons with NaN are always ``False``, and
    ``argmin`` treats NaN as the minimum, so a NaN-seeded restart would
    otherwise never improve yet still be picked as the winner.

    Args:
        start_params: Unperturbed starting parameters (restart 0).
        total_costs: Cost as a function of physical parameters.
        total_costs_log: Same cost as a function of log-space parameters.
        optimizer: Optax gradient transformation.

    Returns:
        ``(best_params, loss_history, best_loss)`` of the winning
        restart; ``loss_history`` is its raw initial loss followed by
        one loss per step.
    """
    # (n_restarts, n_params) starting points (restart 0 unperturbed).
    params_batch = self._perturb_starts(start_params)
    log.info(
        f"Stage 1 (parallel): vmapping {self.n_restarts} restarts × "
        f"{self.n_steps} steps in a single fused program."
    )
    if self.early_stop_patience > 0:
        log.info(
            "Note: early_stop_patience is ignored in the parallel "
            "(n_restarts > 1) path; the full n_steps will run."
        )

    log_params_batch = jax.vmap(self._to_log_space)(params_batch)
    opt_state_batch = jax.vmap(optimizer.init)(log_params_batch)

    # Raw per-restart initial losses (reported in the history) and a
    # NaN-safe copy used to seed the best-loss tracker.
    init_losses = jax.vmap(total_costs)(params_batch)
    init_best_losses = jnp.where(jnp.isfinite(init_losses), init_losses, jnp.inf)

    def opt_step(log_params, opt_state):
        loss, grads = jax.value_and_grad(total_costs_log)(log_params)
        updates, opt_state = optimizer.update(grads, opt_state, log_params)
        log_params = optax.apply_updates(log_params, updates)
        return log_params, opt_state, loss

    v_opt_step = jax.vmap(opt_step, in_axes=(0, 0))

    def scan_body(carry, _):
        log_params, opt_state, best_loss, best_log_params = carry
        new_log_params, new_opt_state, loss = v_opt_step(log_params, opt_state)
        # ``loss`` was evaluated at the *current, pre-update*
        # ``log_params``, so those are the params to remember.
        improved = loss < best_loss
        best_loss = jnp.where(improved, loss, best_loss)
        best_log_params = jnp.where(improved[:, None], log_params, best_log_params)
        new_carry = (new_log_params, new_opt_state, best_loss, best_log_params)
        return new_carry, loss

    init_carry = (
        log_params_batch,
        opt_state_batch,
        init_best_losses,
        log_params_batch,
    )

    @jax.jit
    def run_scan(carry):
        return jax.lax.scan(scan_body, carry, None, length=self.n_steps)

    final_carry, step_losses = run_scan(init_carry)
    # step_losses shape (n_steps, n_restarts); each row is the
    # cross-restart loss vector at one optimisation step.
    _, _, best_losses, best_log_params_batch = final_carry

    # Periodic batch summary; one device->host transfer for the matrix.
    host_step_losses = jax.device_get(step_losses)
    for step in range(0, self.n_steps, max(1, self.log_interval)):
        row = host_step_losses[step]
        log.info(
            f"Step {step}/{self.n_steps}, "
            f"loss min/mean/max: {float(row.min()):.3e} / "
            f"{float(row.mean()):.3e} / {float(row.max()):.3e}"
        )

    # Per-restart final summary (single sync for ``best_losses``).
    host_best_losses = jax.device_get(best_losses)
    for r in range(self.n_restarts):
        log.info(
            f"Restart {r + 1}/{self.n_restarts} finished "
            f"with best loss: {float(host_best_losses[r]):.3e}"
        )

    winner = int(jnp.argmin(best_losses))
    global_best_loss = best_losses[winner]
    global_best_params = self._from_log_space(best_log_params_batch[winner])

    # Per-step loss history of the winning restart, in the same shape
    # as the sequential path returns.
    winner_history = [init_losses[winner]]
    winner_history.extend(step_losses[:, winner])
    return global_best_params, winner_history, global_best_loss
def plot_loss_landscape(
    self,
    gate_name: str,
    grid_axes: List[jnp.ndarray],
    landscape_data: list,
) -> None:
    """Save a loss-landscape figure for the Phase-0 grid scan.

    The visualisation adapts to the number of pulse parameters:

    - **1 parameter**: scatter plot (param value vs. loss).
    - **2 parameters**: 2-D heatmap (param₀ × param₁, colour = loss);
      grid cells without a finite loss are shown in light grey.
    - **≥ 3 parameters**: horizontal scatter sorted by ascending loss
      with the best candidate highlighted.

    Logarithmic axes are only used when every plotted value on that
    axis is strictly positive (the multiplicative scan grid can contain
    negative parameter values). The figure is closed even if plotting
    fails, and is saved to ``{file_dir}/{gate_name}_loss_landscape.png``.

    Args:
        gate_name: Name of the gate being optimised (e.g. ``"RX"``).
        grid_axes: Per-parameter 1-D arrays that span the scan grid.
        landscape_data: List of ``(candidate_index, params, loss)``
            tuples for every successful scan candidate; indices follow
            ``itertools.product(*grid_axes)`` order.
    """
    import matplotlib.pyplot as plt  # lazy — matplotlib is dev-only

    if not landscape_data:
        log.warning("plot_loss_landscape: no landscape data to plot, skipping.")
        return

    os.makedirs(self.file_dir, exist_ok=True)
    n_params = len(grid_axes)
    indices, _params_list, losses = zip(*landscape_data)
    losses_arr = np.array(losses, dtype=float)

    def _all_positive(values) -> bool:
        """True when a log-scaled axis can represent every value."""
        return bool(np.all(np.asarray(values, dtype=float) > 0))

    fig, ax = plt.subplots(figsize=(8, 5))
    try:
        if n_params == 1:
            x = np.array([float(grid_axes[0][i]) for i in indices])
            sc = ax.scatter(
                x, losses_arr, c=losses_arr, cmap="viridis_r", s=60, zorder=3
            )
            fig.colorbar(sc, ax=ax, label="Loss")
            best_i = int(np.argmin(losses_arr))
            ax.scatter(
                x[best_i],
                losses_arr[best_i],
                marker="*",
                s=200,
                color="red",
                zorder=4,
                label="best",
            )
            ax.set_xlabel("Parameter value")
            if _all_positive(x):
                ax.set_xscale("log")
            if _all_positive(losses_arr):
                ax.set_yscale("log")
            ax.legend()

        elif n_params == 2:
            # Use the real axis lengths: ``scan_grid_size`` can differ
            # from them (e.g. a size <= 1 yields a single factor).
            n_rows, n_cols = len(grid_axes[0]), len(grid_axes[1])
            loss_grid = np.full((n_rows, n_cols), np.nan)
            for ci, _, loss in landscape_data:
                # product order: index = row * n_cols + col
                loss_grid[ci // n_cols, ci % n_cols] = loss
            masked = np.ma.masked_invalid(loss_grid)
            cmap = plt.cm.viridis_r.copy()
            cmap.set_bad(color="lightgrey")
            im = ax.imshow(
                masked,
                origin="lower",
                cmap=cmap,
                aspect="auto",
                extent=[
                    float(grid_axes[1][0]),
                    float(grid_axes[1][-1]),
                    float(grid_axes[0][0]),
                    float(grid_axes[0][-1]),
                ],
            )
            fig.colorbar(im, ax=ax, label="Loss")
            ax.set_xlabel("Parameter 1")
            ax.set_ylabel("Parameter 0")

        else:  # n_params >= 3: sorted scatter
            order = np.argsort(losses_arr)
            sorted_losses = losses_arr[order]
            sorted_indices = np.array(indices)[order]  # original trial numbers
            ranks = np.arange(len(sorted_losses))
            sc = ax.scatter(
                sorted_losses,
                ranks,
                c=sorted_indices,
                cmap="plasma",
                s=40,
                zorder=3,
            )
            fig.colorbar(sc, ax=ax, label="Trial number")
            ax.scatter(
                sorted_losses[0],
                ranks[0],
                marker="*",
                s=200,
                color="red",
                zorder=4,
                label="best",
            )
            ax.set_xlabel("Loss")
            ax.set_ylabel("Candidate rank (0 = best)")
            if _all_positive(sorted_losses):
                ax.set_xscale("log")
            ax.legend()

        ax.set_title(f"Loss Landscape (Phase 0) — {gate_name}")
        fig.tight_layout()
        path = os.path.join(self.file_dir, f"{gate_name}_loss_landscape.png")
        fig.savefig(path, dpi=150)
    finally:
        plt.close(fig)
    log.info(f"Loss landscape saved to {path}")
def plot_loss_curve(
    self,
    gate_name: str,
    loss_history: list,
) -> None:
    """Save a training-loss curve figure for the Phase-1 optimisation.

    Shows loss vs. optimisation step on a log y-scale with a dashed
    horizontal line at the minimum *finite* loss. The line is omitted
    when no entry is finite. The figure is closed even if plotting
    fails.

    The figure is saved to ``{file_dir}/{gate_name}_loss_curve.png``.

    Args:
        gate_name: Name of the gate being optimised (e.g. ``"RX"``).
        loss_history: Sequence of loss values, one per step (including
            the initial loss at index 0).
    """
    import matplotlib.pyplot as plt  # lazy — matplotlib is dev-only

    if not loss_history:
        log.warning("plot_loss_curve: empty loss history, skipping.")
        return

    os.makedirs(self.file_dir, exist_ok=True)
    losses = [float(v) for v in loss_history]
    # ``min`` over a list containing NaN is order-dependent (NaN
    # comparisons are always False, e.g. min([nan, 1.0]) is nan), so
    # the reference line is computed from finite values only.
    finite_losses = [v for v in losses if np.isfinite(v)]

    fig, ax = plt.subplots(figsize=(9, 4))
    try:
        ax.plot(losses, linewidth=1.2, label="Loss")
        if finite_losses:
            best = min(finite_losses)
            ax.axhline(
                best,
                color="red",
                linestyle="--",
                linewidth=1.0,
                label=f"Best: {best:.3e}",
            )
        ax.set_xlabel("Step")
        ax.set_ylabel("Loss")
        ax.set_yscale("log")
        ax.set_title(f"Training Loss (Phase 1) — {gate_name}")
        ax.legend()
        fig.tight_layout()
        path = os.path.join(self.file_dir, f"{gate_name}_loss_curve.png")
        fig.savefig(path, dpi=150)
    finally:
        plt.close(fig)
    log.info(f"Loss curve saved to {path}")
def optimize(self, wires: int) -> Callable:
    """Decorator factory that optimises pulse parameters for a gate.

    Usage::

        opt = qoc.optimize(wires=1)
        best_params, loss_history = opt(qoc.create_RX)()

    The decorated factory returns ``(pulse_circuit, target_circuit)``.
    The token after the first ``_`` in its ``__name__`` is used as the
    gate name (``create_RX`` -> ``"RX"``). That name selects the default
    init parameters and names the results CSV row and the plot files.

    Args:
        wires: Number of qubits the gate acts on.

    Returns:
        A decorator that accepts a circuit-factory function and
        returns a callable ``(init_pulse_params=None) ->
        (best_params, loss_history)``.
    """

    def decorator(create_circuits):
        def wrapper(init_pulse_params: jnp.ndarray = None):
            """
            Optimise pulse parameters for a quantum gate using a
            multi-phase strategy:

            Stage 0 - Grid scan (if ``scan_steps > 0``):
                Evaluate a coarse grid of parameter candidates using
                the same weighted cost as Stage 1. Each candidate
                is refined with a few fast gradient steps. The
                best candidate becomes the starting point for
                Stage 1, unless the user-supplied init_pulse_params
                are already better.

            Stage 1 - Multi-restart gradient optimisation:
                Run ``n_restarts`` independent Adam optimisation runs
                with the full cost function. The first restart uses
                the best point found so far; subsequent restarts add
                random perturbations. Parameters at indices in
                ``log_scale_params`` are optimised in log-space to
                handle order-of-magnitude differences in scale.

            Args:
                init_pulse_params (array): Initial pulse parameters.
                    If ``None``, uses the envelope defaults from
                    :class:`PulseInformation`.

            Returns:
                tuple: ``(best_params, loss_history)`` from the best
                restart.
            """
            pulse_circuit, target_circuit = create_circuits()

            def _with_plus_prep(circuit_fn):
                """Wrap ``circuit_fn`` so it starts from ``|+⟩^⊗n``.

                A second, non-collinear initial state exposes
                rotation-axis tilt: an RX/RY pulse with a residual Z
                component is partly degenerate from ``|0⟩`` alone.
                Pulse and target get the same preparation, so the
                target stays exact.
                """

                def prepared(*args, **kwargs):
                    for q in range(wires):
                        op.H(wires=q)
                    circuit_fn(*args, **kwargs)

                prepared.__name__ = f"plus_{circuit_fn.__name__}"
                return prepared

            def _state_scripts(circuit_fn):
                """Scripts for ``circuit_fn`` from ``|0…0⟩`` and ``|+…+⟩``."""
                return [
                    ys.Script(circuit_fn, n_qubits=wires),
                    ys.Script(_with_plus_prep(circuit_fn), n_qubits=wires),
                ]

            def _basis_scripts(circuit_fn):
                """One script per computational basis state ``|k⟩``."""
                return [
                    ys.Script(
                        _with_basis_prep(circuit_fn, k, wires), n_qubits=wires
                    )
                    for k in range(2**wires)
                ]

            pulse_scripts = _state_scripts(pulse_circuit)
            target_scripts = _state_scripts(target_circuit)
            pulse_basis_scripts = _basis_scripts(pulse_circuit)
            target_basis_scripts = _basis_scripts(target_circuit)

            gate_name = create_circuits.__name__.split("_")[1]

            if init_pulse_params is None:
                init_pulse_params = PulseInformation.gate_by_name(gate_name).params
            log.debug(
                f"Initial pulse parameters for {gate_name}: {init_pulse_params}"
            )

            all_ckwargs = {
                "pulse_scripts": pulse_scripts,
                "target_scripts": target_scripts,
                "pulse_basis_scripts": pulse_basis_scripts,
                "target_basis_scripts": target_basis_scripts,
                "envelope": self.envelope,
                "n_samples": self.n_samples,
                "n_qubits": wires,
                "t_target": self.t_target,
            }

            def _build_cost(name, weight):
                """Build a Cost from a registry entry, passing only the
                ckwargs that entry declares in ``ckwargs_keys``."""
                meta = CostFnRegistry.get(name)
                accepted = meta["ckwargs_keys"]
                return Cost(
                    cost=meta["fn"],
                    weight=weight,
                    ckwargs={
                        key: value
                        for key, value in all_ckwargs.items()
                        if key in accepted
                    },
                )

            # Fold the configured cost terms into one combined cost.
            total_costs = None
            for cost_name, cost_weight in self.cost_fns:
                total_costs = _build_cost(cost_name, cost_weight) + total_costs

            best_scan_params, scan_data = self.stage_0_opt(
                init_pulse_params, total_costs
            )
            best_params, best_history, best_loss = self.stage_1_opt(
                best_scan_params, total_costs
            )

            # NOTE(review): this stores ``1 - total weighted loss`` as the
            # "fidelity"; it only equals the gate fidelity if the combined
            # cost is exactly the infidelity — confirm against the
            # registered cost terms.
            self.save_results(
                gate=gate_name,
                fidelity=1 - best_loss.item(),
                pulse_params=best_params,
            )

            if self.plot:
                if scan_data is not None:
                    grid_axes, landscape_items = scan_data
                    self.plot_loss_landscape(gate_name, grid_axes, landscape_items)
                self.plot_loss_curve(gate_name, best_history)

            return best_params, best_history

        return wrapper

    return decorator
1895 # ------------------------------------------------------------------
1896 # Per-gate (pulse, target) circuit factories
1897 # ------------------------------------------------------------------
1898 #
1899 # Each entry maps a gate name to a ``(pulse_circuit, target_circuit)``
1900 # pair. The per-gate variants prepend a symmetry-breaking
1901 # preparation (e.g. ``op.H``/``op.RY``) so the *state-vector* cost
1902 # is sensitive to rotation-axis tilt. The joint-mode variants drop
1903 # those preps because the unitary cost already captures axis tilt
1904 # without probe-state trickery (see :meth:`_create_joint_pair_for`).
@staticmethod
def _gate_factories() -> Dict[str, Tuple[Callable, Callable]]:
    """Return the per-gate ``{gate_name: (pulse_fn, target_fn)}`` table.

    Every row below lists, in order:

    * the gate name (table key),
    * the pulse-level circuit ``(w, pp)``,
    * the analytic target circuit ``(w)``,
    * extra keyword arguments (``prep`` / ``post``) forwarded verbatim
      to ``_make_gate_pair``.

    The table is rebuilt on every call, and the lambdas look up
    ``Gates`` / ``op`` only when they are invoked.

    Returns:
        Dict mapping each gate name to the pair produced by
        ``_make_gate_pair``.
    """
    rows = (
        (
            "RX",
            lambda w, pp: Gates.RX(w, 0, pulse_params=pp, gate_mode="pulse"),
            lambda w: op.RX(w, wires=0),
            {},
        ),
        (
            "RY",
            lambda w, pp: Gates.RY(w, 0, pulse_params=pp, gate_mode="pulse"),
            lambda w: op.RY(w, wires=0),
            {},
        ),
        (
            "RZ",
            lambda w, pp: Gates.RZ(w, 0, pulse_params=pp, gate_mode="pulse"),
            lambda w: op.RZ(w, wires=0),
            {
                "prep": lambda w: op.H(wires=0),
                "post": lambda w: op.H(wires=0),
            },
        ),
        (
            "H",
            lambda w, pp: Gates.H(0, pulse_params=pp, gate_mode="pulse"),
            lambda w: op.H(wires=0),
            {"prep": lambda w: op.RY(w, wires=0)},
        ),
        (
            "Rot",
            lambda w, pp: Gates.Rot(
                w, w * 2, w * 3, 0, pulse_params=pp, gate_mode="pulse"
            ),
            lambda w: op.Rot(w, w * 2, w * 3, wires=0),
            {"prep": lambda w: op.H(wires=0)},
        ),
        (
            # NOTE(review): op.H(wires=1) puts the CX target in |+>, an
            # eigenstate of X, so CX acts trivially on this probe state;
            # the joint-table docstring raises the same concern.
            "CX",
            lambda w, pp: Gates.CX(
                wires=[0, 1], pulse_params=pp, gate_mode="pulse"
            ),
            lambda w: op.CX(wires=[0, 1]),
            {
                "prep": _chain_gate_stages(
                    lambda w: op.RY(w, wires=0),
                    lambda w: op.H(wires=1),
                )
            },
        ),
        (
            "CY",
            lambda w, pp: Gates.CY(
                wires=[0, 1], pulse_params=pp, gate_mode="pulse"
            ),
            lambda w: op.CY(wires=[0, 1]),
            {
                "prep": _chain_gate_stages(
                    lambda w: op.RX(w, wires=0),
                    lambda w: op.H(wires=1),
                )
            },
        ),
        (
            "CZ",
            lambda w, pp: Gates.CZ(
                wires=[0, 1], pulse_params=pp, gate_mode="pulse"
            ),
            lambda w: op.CZ(wires=[0, 1]),
            {
                "prep": _chain_gate_stages(
                    lambda w: op.RY(w, wires=0),
                    lambda w: op.H(wires=1),
                )
            },
        ),
        (
            "CRX",
            lambda w, pp: Gates.CRX(
                w, wires=[0, 1], pulse_params=pp, gate_mode="pulse"
            ),
            lambda w: op.CRX(w, wires=[0, 1]),
            {"prep": lambda w: op.H(wires=0)},
        ),
        (
            "CRY",
            lambda w, pp: Gates.CRY(
                w, wires=[0, 1], pulse_params=pp, gate_mode="pulse"
            ),
            lambda w: op.CRY(w, wires=[0, 1]),
            {"prep": lambda w: op.H(wires=0)},
        ),
        (
            "CRZ",
            lambda w, pp: Gates.CRZ(
                w, wires=[0, 1], pulse_params=pp, gate_mode="pulse"
            ),
            lambda w: op.CRZ(w, wires=[0, 1]),
            {
                "prep": _chain_gate_stages(
                    lambda w: op.H(wires=0),
                    lambda w: op.H(wires=1),
                )
            },
        ),
    )

    table: Dict[str, Tuple[Callable, Callable]] = {}
    for gate_name, pulse_fn, target_fn, extra_kwargs in rows:
        table[gate_name] = _make_gate_pair(pulse_fn, target_fn, **extra_kwargs)
    return table
@staticmethod
def _joint_gate_factories() -> Dict[str, Tuple[Callable, Callable]]:
    """Return prep-free ``(pulse, target)`` pairs for joint optimisation.

    Consumed by :meth:`_create_joint_pair_for`. In joint mode the
    unitary cost already reveals rotation-axis tilt without a probe
    state. Keeping the per-gate preps would actually *mask* some
    errors: e.g. ``op.H(wires=1)`` leaves the CX target qubit in a CX
    eigenstate, so the column-stacked unitary stops reacting to the
    pulse error.

    ``Rot`` and ``CY`` are deliberately missing because the joint
    optimiser never targets them.

    Returns:
        Dict mapping each gate name to the pair produced by
        ``_make_gate_pair`` (no ``prep`` / ``post`` stages).
    """
    specs = (
        (
            "RX",
            lambda w, pp: Gates.RX(w, wires=0, pulse_params=pp, gate_mode="pulse"),
            lambda w: op.RX(w, wires=0),
        ),
        (
            "RY",
            lambda w, pp: Gates.RY(w, wires=0, pulse_params=pp, gate_mode="pulse"),
            lambda w: op.RY(w, wires=0),
        ),
        (
            "RZ",
            lambda w, pp: Gates.RZ(w, wires=0, pulse_params=pp, gate_mode="pulse"),
            lambda w: op.RZ(w, wires=0),
        ),
        (
            "H",
            lambda w, pp: Gates.H(0, pulse_params=pp, gate_mode="pulse"),
            lambda w: op.H(wires=0),
        ),
        (
            "CZ",
            lambda w, pp: Gates.CZ(
                wires=[0, 1], pulse_params=pp, gate_mode="pulse"
            ),
            lambda w: op.CZ(wires=[0, 1]),
        ),
        (
            "CX",
            lambda w, pp: Gates.CX(
                wires=[0, 1], pulse_params=pp, gate_mode="pulse"
            ),
            lambda w: op.CX(wires=[0, 1]),
        ),
        (
            "CRX",
            lambda w, pp: Gates.CRX(
                w, wires=[0, 1], pulse_params=pp, gate_mode="pulse"
            ),
            lambda w: op.CRX(w, wires=[0, 1]),
        ),
        (
            "CRY",
            lambda w, pp: Gates.CRY(
                w, wires=[0, 1], pulse_params=pp, gate_mode="pulse"
            ),
            lambda w: op.CRY(w, wires=[0, 1]),
        ),
        (
            "CRZ",
            lambda w, pp: Gates.CRZ(
                w, wires=[0, 1], pulse_params=pp, gate_mode="pulse"
            ),
            lambda w: op.CRZ(w, wires=[0, 1]),
        ),
    )
    return {
        name: _make_gate_pair(pulse_fn, target_fn)
        for name, pulse_fn, target_fn in specs
    }
def _create_pair(self, gate_name: str) -> Tuple[Callable, Callable]:
    """Look up the per-gate ``(pulse, target)`` pair from the table.

    Args:
        gate_name: Key into :meth:`_gate_factories` (e.g. ``"RX"``).

    Returns:
        The ``(pulse_circuit, target_circuit)`` pair for that gate.

    Raises:
        ValueError: If the table has no entry for ``gate_name``.
    """
    # Build the table outside the ``try``: a KeyError raised while
    # constructing it is a real bug and must not be reported as a
    # missing gate.
    factories = self._gate_factories()
    try:
        return factories[gate_name]
    except KeyError as exc:
        raise ValueError(f"No factory for gate {gate_name!r}.") from exc
# Thin compatibility wrappers around :meth:`_create_pair` so existing
# code (and tests) that call ``qoc.create_<gate>`` keep working. Each
# one simply forwards its gate name to ``_create_pair``.
def create_RX(self):
    """Return the ``(pulse, target)`` circuit pair for RX."""
    return self._create_pair("RX")

def create_RY(self):
    """Return the ``(pulse, target)`` circuit pair for RY."""
    return self._create_pair("RY")

def create_RZ(self):
    """Return the ``(pulse, target)`` circuit pair for RZ."""
    return self._create_pair("RZ")

def create_H(self):
    """Return the ``(pulse, target)`` circuit pair for H."""
    return self._create_pair("H")

def create_Rot(self):
    """Return the ``(pulse, target)`` circuit pair for Rot."""
    return self._create_pair("Rot")

def create_CX(self):
    """Return the ``(pulse, target)`` circuit pair for CX."""
    return self._create_pair("CX")

def create_CY(self):
    """Return the ``(pulse, target)`` circuit pair for CY."""
    return self._create_pair("CY")

def create_CZ(self):
    """Return the ``(pulse, target)`` circuit pair for CZ."""
    return self._create_pair("CZ")

def create_CRX(self):
    """Return the ``(pulse, target)`` circuit pair for CRX."""
    return self._create_pair("CRX")

def create_CRY(self):
    """Return the ``(pulse, target)`` circuit pair for CRY."""
    return self._create_pair("CRY")

def create_CRZ(self):
    """Return the ``(pulse, target)`` circuit pair for CRZ."""
    return self._create_pair("CRZ")
def create_CPhase(self):
    """Create pulse and target circuits for the CPhase gate.

    Both circuits first apply ``op.H`` to wire 0 and then wire 1. The
    pulse circuit then runs ``Gates.CPhase`` in pulse mode; the target
    circuit runs the analytic ``op.ControlledPhaseShift``. Both use
    angle ``w`` on wires ``[0, 1]``.

    Returns:
        ``(pulse_circuit, target_circuit)``.
    """

    def _hadamard_both_wires():
        for wire in (0, 1):
            op.H(wires=wire)

    def pulse_circuit(w, pulse_params):
        _hadamard_both_wires()
        Gates.CPhase(w, wires=[0, 1], pulse_params=pulse_params, gate_mode="pulse")

    def target_circuit(w):
        _hadamard_both_wires()
        op.ControlledPhaseShift(w, wires=[0, 1])

    return pulse_circuit, target_circuit
def optimize_all(self, sel_gates: Union[str, List[str]], make_log: bool) -> None:
    """Optimise all selected gates and optionally write a log CSV.

    Args:
        sel_gates: Gate selection. Either a string of names separated
            by commas and/or whitespace (e.g. ``"RX,CRX"``), or an
            iterable of names (as produced by the ``--gates`` CLI
            flag). Including ``"all"`` selects every gate in
            ``GATES_1Q + GATES_2Q``. Names are matched exactly, so
            ``"CRX"`` does not also select ``"RX"``.
        make_log: If ``True``, write per-gate loss histories to
            ``qml_essentials/qoc_logs.csv``. The file has one column per
            gate; shorter histories are padded with empty cells.
    """
    # Joint mode (Round 3) is implemented in :meth:`optimize_joint`;
    # the `--joint` CLI flag selects it instead of this per-gate loop.

    # Tokenise the selection so membership is exact. A plain
    # ``gate in "CRX"`` substring test would also match "RX".
    items = [sel_gates] if isinstance(sel_gates, str) else list(sel_gates)
    requested = {
        token for item in items for token in str(item).replace(",", " ").split()
    }
    run_all = "all" in requested

    log_history: Dict[str, list] = {}
    for gate in self.GATES_1Q + self.GATES_2Q:
        if not (run_all or gate in requested):
            continue
        n_wires = 1 if gate in self.GATES_1Q else 2
        opt = self.optimize(wires=n_wires)
        gate_factory = getattr(self, f"create_{gate}")
        log.info(f"Optimizing {gate} gate...")
        optimized_pulse_params, loss_history = opt(gate_factory)()
        log.info(f"Optimized parameters for {gate}: {optimized_pulse_params}")
        losses = [float(loss) for loss in loss_history]
        # Guard against an empty history (``min`` would raise).
        if losses:
            best_fid = 1 - min(losses)
            log.info(f"Best achieved fidelity: {best_fid * 100:.5f}%")
        log_history[gate] = log_history.get(gate, []) + list(loss_history)

    if make_log:
        # ``newline=""`` is what the csv module requires to avoid spurious
        # blank rows on some platforms. zip_longest keeps every gate's
        # full history when the histories differ in length; plain zip
        # would truncate them all to the shortest one.
        with open("qml_essentials/qoc_logs.csv", "w", newline="") as f:
            writer = csv.writer(f)
            writer.writerow(log_history.keys())
            writer.writerows(
                itertools.zip_longest(*log_history.values(), fillvalue="")
            )
# ------------------------------------------------------------------
# Joint composite-aware optimisation (Round 3)
# ------------------------------------------------------------------

# Leaf gates whose parameters make up the joint vector ``theta``. The
# order fixes the layout of ``theta``. A leaf left out of this tuple
# keeps its current PulseInformation parameters (it is frozen) during
# joint optimisation.
JOINT_LEAVES_DEFAULT: Tuple[str, ...] = ("RX", "RY", "RZ", "CZ")

# Gates whose unitary cost is summed into the joint objective.
# Composite gates back-propagate into the shared leaves, while the
# leaf-gate terms keep standalone fidelity acceptable.
# CZ is left out. Its implementation is a static diagonal-Hamiltonian
# evolution (``H_CZ = π·|11⟩⟨11|``, t=1), which is structurally exact
# and cannot be improved by tuning leaf parameters. Including it would
# only dilute the averaged loss.
JOINT_TARGETS_DEFAULT: Tuple[str, ...] = (
    "RX", "RY", "RZ",
    "H", "CX",
    "CRX", "CRY", "CRZ",
)

# Per-target weights for the joint objective; they are normalised
# inside :func:`joint_unitary_cost_fn`.
# Composites are up-weighted for two reasons. They are what fails the
# tightened tests. And the standalone leaves already start
# near-perfect, so an unweighted average would be dominated by the
# cheap leaves and give the optimiser no reason to move.
# Among composites, the CR_ gates weigh most: they are the longest
# decompositions (2 CX + ~6 single-qubit gates), so leaf errors
# compound worst there.
JOINT_WEIGHTS_DEFAULT: Dict[str, float] = {
    **dict.fromkeys(("RX", "RY", "RZ"), 0.3),
    "H": 1.0,
    "CX": 2.0,
    **dict.fromkeys(("CRX", "CRY", "CRZ"), 3.0),
}

# Leaves that should share one envelope parameter slice in ``theta``.
# RX and RY differ only by a static carrier-phase offset: RX uses
# cos(ω_c t) and RY uses cos(ω_c t + π/2). Tying them here in the QOC
# layout, rather than in :mod:`pulses`, leaves the per-gate
# decomposition tree untouched while stopping joint optimisation from
# drifting the two envelopes apart.
# Per the original tuning notes, RY dominated the H/CX residuals when
# untied. The loss then settled where RX was well tuned but RY was
# roughly 3x worse.
JOINT_TIED_GROUPS_DEFAULT: Tuple[Tuple[str, ...], ...] = (("RX", "RY"),)
def _build_joint_layout(
    self,
    leaf_names: Tuple[str, ...],
    tied_groups: Optional[Tuple[Tuple[str, ...], ...]] = None,
) -> Tuple[jnp.ndarray, Dict[str, slice], List[int]]:
    """Build the joint parameter layout.

    Args:
        leaf_names: Ordered names of the leaf gates that take part in
            the joint optimisation. Duplicates are ignored; the first
            occurrence wins.
        tied_groups: Optional tuple of leaf-name groups whose
            parameters are forced to share one slice in ``theta``.
            Defaults to :pyattr:`JOINT_TIED_GROUPS_DEFAULT`, which ties
            RX and RY. Only members listed in ``leaf_names`` take part;
            a group with fewer than two present members does nothing.

    Returns:
        Tuple ``(init_theta, leaf_slices, log_scale_indices)``:

        * ``init_theta``: concatenated initial parameters from
          ``PulseInformation.<leaf>.params``, in ``leaf_names`` order.
          For a tied group, the representative is the member that
          appears *first in* ``leaf_names``. Its shared initial value
          is the element-wise mean of all present members' current
          params, so neither side dominates.
        * ``leaf_slices``: mapping from leaf name to a ``slice`` into
          ``init_theta``. Tied leaves map to the *same* slice.
        * ``log_scale_indices``: indices into ``init_theta`` that should
          be optimised in log-space. These are the amplitude and the
          evolution time of each RX/RY envelope leaf, mirroring the
          per-gate default ``[0, -1]`` rule.

    Raises:
        ValueError: If ``leaf_names`` is empty.
        AssertionError: If a slice-owning name is not a leaf gate.
            This is checked with ``assert``, so the check is stripped
            under ``python -O``.
    """
    if tied_groups is None:
        tied_groups = self.JOINT_TIED_GROUPS_DEFAULT

    # Dedupe while preserving order: a repeated name would otherwise
    # allocate a second, orphaned slice in ``theta``.
    leaf_names = tuple(dict.fromkeys(leaf_names))
    if not leaf_names:
        raise ValueError("_build_joint_layout: leaf_names must not be empty.")

    # Map each leaf to the representative that owns its slice. The
    # representative must be the group member that comes first in
    # ``leaf_names``, not first in the group tuple. Slices are allocated
    # while walking ``leaf_names``, and an aliasing member reads
    # ``leaf_slices[rep]``. Picking by group order raised KeyError for
    # e.g. ``leaf_names=("RY", "RX", ...)``.
    rep_of: Dict[str, str] = {n: n for n in leaf_names}
    for group in tied_groups:
        group_members = set(group)
        present = [n for n in leaf_names if n in group_members]
        if len(present) < 2:
            continue
        head = present[0]
        for member in present[1:]:
            rep_of[member] = head
            log.info(
                f" Joint layout: tying leaf {member!r} to {head!r} "
                f"(shared slice in theta)."
            )

    envelope_info = PulseEnvelope.get(self.envelope)
    n_env = envelope_info["n_envelope_params"]

    leaf_slices: Dict[str, slice] = {}
    init_chunks = []
    log_idx: List[int] = []
    offset = 0
    for name in leaf_names:
        rep = rep_of[name]
        if rep != name:
            # Tied member: reuse the representative's slice.
            leaf_slices[name] = leaf_slices[rep]
            continue

        pp = PulseInformation.gate_by_name(name)
        assert pp is not None and pp.is_leaf, (
            f"_build_joint_layout: {name!r} is not a leaf gate"
        )
        # For a tied group the shared initial value is the element-wise
        # mean over all present members, so the result does not depend
        # on which member is the representative.
        tied_members = [m for m in leaf_names if rep_of[m] == name]
        if len(tied_members) > 1:
            stacked = jnp.stack(
                [
                    jnp.asarray(
                        PulseInformation.gate_by_name(m).params,
                        dtype=jnp.float64,
                    )
                    for m in tied_members
                ]
            )
            chunk = jnp.mean(stacked, axis=0)
        else:
            chunk = jnp.asarray(pp.params, dtype=jnp.float64)
        n_p = chunk.shape[0]
        leaf_slices[name] = slice(offset, offset + n_p)
        init_chunks.append(chunk)
        # Log-scale rule per leaf: only the envelope leaves (RX, RY) get
        # log-scaled amplitude and time. RZ and CZ use the "general"
        # registry with a single tuning scalar and stay in linear space.
        if name in ("RX", "RY") and n_env >= 2:
            log_idx.append(offset)  # amplitude
            log_idx.append(offset + n_p - 1)  # evolution time
        offset += n_p

    init_theta = jnp.concatenate(init_chunks)
    return init_theta, leaf_slices, log_idx
@staticmethod
def _assemble_for_gate(
    theta: jnp.ndarray,
    pp_obj,
    leaf_slices: Dict[str, slice],
) -> jnp.ndarray:
    """Build a gate's flat ``pulse_params`` vector from ``theta``.

    Composite gates are handled recursively: each child is assembled
    in ``pp_obj.childs`` order and the results are concatenated. For a
    leaf, the result is its slice of ``theta`` when the leaf
    participates in the joint layout. Otherwise (the leaf is frozen)
    it is the leaf's current ``params`` converted to a float64 array.

    This mirrors the :pyattr:`PulseParams.params` getter, except that
    leaf values come from the joint vector rather than each leaf's own
    ``_params``.
    """
    if not pp_obj.is_leaf:
        child_vectors = [
            QOC._assemble_for_gate(theta, child, leaf_slices)
            for child in pp_obj.childs
        ]
        return jnp.concatenate(child_vectors)

    joint_slice = leaf_slices.get(pp_obj.name)
    return (
        jnp.asarray(pp_obj.params, dtype=jnp.float64)
        if joint_slice is None
        else theta[joint_slice]
    )
def _joint_stage_0_coord_descent(
    self,
    init_theta: jnp.ndarray,
    leaf_slices: Dict[str, slice],
    total_cost: Callable,
) -> jnp.ndarray:
    """Greedy coordinate-descent grid scan, one leaf block at a time.

    Leaves are visited in ``leaf_slices`` order. For each leaf, a
    candidate grid is built around its current values with
    :meth:`_build_scan_grid`, and each candidate is swapped into that
    leaf's block while all other entries stay at the current best. A
    candidate is accepted immediately whenever it lowers the loss.
    Tied leaves share one slice, and that slice is scanned only once.

    Compared with a Cartesian product over every leaf axis
    (``Π_i scan_grid_size**k_i`` evaluations), this needs only
    ``Σ_i scan_grid_size**k_i``.

    While scanning, the yaqsi solver defaults are switched to
    ``throw=False``. The previous defaults are restored afterwards,
    even if an exception occurs.

    Args:
        init_theta: Starting joint parameter vector.
        leaf_slices: Mapping from leaf name to a slice into
            ``init_theta``.
        total_cost: Joint cost callable taking ``theta`` and returning
            a scalar loss.

    Returns:
        Best joint parameter vector found (``init_theta`` unchanged if
        ``scan_steps <= 0``).
    """
    if self.scan_steps <= 0:
        log.info("Joint Stage 0: scan disabled (scan_steps=0); skipping.")
        return init_theta

    best_theta = init_theta
    best_loss = _safe_eval(total_cost, best_theta)
    log.info(
        f"Joint Stage 0: coordinate-descent over {len(leaf_slices)} leaves, "
        f"init_loss={float(best_loss):.6e}"
    )

    saved_defaults = ys.Yaqsi.set_solver_defaults(throw=False)
    try:
        scanned: set = set()
        for leaf_name, block in leaf_slices.items():
            span = (block.start, block.stop)
            if span in scanned:
                continue
            scanned.add(span)

            leaf_start = best_theta[block]
            width = int(leaf_start.shape[0])
            if width == 0:
                continue

            candidates, _ = self._build_scan_grid(
                width, init_pulse_params=leaf_start
            )
            improvements = 0
            for candidate in candidates:
                trial = best_theta.at[block].set(candidate)
                trial_loss = _safe_eval(total_cost, trial)
                if trial_loss < best_loss:
                    best_theta, best_loss = trial, trial_loss
                    improvements += 1
            log.info(
                f" Joint scan after leaf {leaf_name} "
                f"({len(candidates)} candidates, {improvements} improved): "
                f"best_loss={float(best_loss):.6e}"
            )
    finally:
        if saved_defaults:
            ys.Yaqsi.set_solver_defaults(**saved_defaults)

    return best_theta
def _create_joint_pair_for(self, gate_name: str):
    """Return a prep-free ``(pulse, target)`` pair for joint mode.

    The pair comes from :meth:`_joint_gate_factories` when that table
    has ``gate_name``. Otherwise a warning is logged and the result of
    :meth:`_create_pair_for` (the per-gate variant, which may include
    probe preps) is returned instead. The joint table's docstring
    explains why preps are dropped.
    """
    joint_table = self._joint_gate_factories()
    try:
        return joint_table[gate_name]
    except KeyError:
        pass

    log.warning(
        f"_create_joint_pair_for: no prep-free factory for {gate_name!r}; "
        f"falling back to create_{gate_name} (preps may hide errors)."
    )
    return self._create_pair_for(gate_name)
def _create_pair_for(self, gate_name: str):
    """Return ``(pulse_circuit, target_circuit)`` for a target gate.

    Dispatches through ``self.create_<gate_name>``, the same lookup
    :meth:`optimize_all` uses, so joint mode targets exactly the
    circuits of the per-gate mode. This includes gates that have a
    dedicated factory but no table entry (e.g. ``create_CPhase``) and
    any ``create_*`` overrides in subclasses. If no such method exists,
    falls back to :meth:`_create_pair`.

    Raises:
        ValueError: If neither ``create_<gate_name>`` nor a table entry
            exists for ``gate_name``.
    """
    factory = getattr(self, f"create_{gate_name}", None)
    if callable(factory):
        return factory()
    return self._create_pair(gate_name)
def optimize_joint(
    self,
    target_gates: Optional[List[str]] = None,
    leaf_names: Optional[List[str]] = None,
    weights: Optional[Dict[str, float]] = None,
) -> Tuple[jnp.ndarray, Dict[str, slice], list]:
    """Joint composite-aware optimisation of leaf pulse parameters.

    Optimises a single shared parameter vector ``theta`` against a
    weighted sum of unitary-cost terms over ``target_gates``. ``theta``
    holds the concatenated leaf params for ``leaf_names``.

    Composite gates back-propagate into the shared leaves, while the
    leaf terms keep standalone fidelity acceptable. CZ is omitted from
    the default targets: ``PulseGates.CZ`` is a static
    diagonal-Hamiltonian evolution (``H_CZ = π·|11⟩⟨11|``, t=1) that is
    structurally exact and unaffected by any leaf re-tuning.

    Args:
        target_gates: Gates whose unitary cost contributes to the joint
            objective. Defaults to :pyattr:`JOINT_TARGETS_DEFAULT`
            (RX, RY, RZ, H, CX, CRX, CRY, CRZ).
        leaf_names: Leaf gates whose parameters are jointly optimised.
            Defaults to :pyattr:`JOINT_LEAVES_DEFAULT` (RX, RY, RZ, CZ).
        weights: Optional mapping ``gate_name → weight``. It is merged
            on top of :pyattr:`JOINT_WEIGHTS_DEFAULT`, which up-weights
            composites and down-weights leaves. All weights are
            normalised inside the cost.

    Returns:
        ``(best_theta, leaf_slices, loss_history)``. Per-leaf results
        are also written via :meth:`save_results`, and each optimised
        leaf's ``PulseInformation`` params are updated in place.

    Raises:
        ValueError: If none of ``target_gates`` is known to
            ``PulseInformation``.
    """
    target_gates = (
        list(target_gates) if target_gates else list(self.JOINT_TARGETS_DEFAULT)
    )
    leaf_names = list(leaf_names) if leaf_names else list(self.JOINT_LEAVES_DEFAULT)

    # Merge user-provided weights on top of the class defaults, so
    # callers can override only the gates they care about.
    merged_weights: Dict[str, float] = dict(self.JOINT_WEIGHTS_DEFAULT)
    if weights:
        merged_weights.update({k: float(v) for k, v in weights.items()})
    weights = merged_weights

    log.info(f"Joint optimisation: leaves={leaf_names}, targets={target_gates}")

    init_theta, leaf_slices, joint_log_idx = self._build_joint_layout(
        tuple(leaf_names)
    )
    log.info(
        f" Joint theta size: {init_theta.shape[0]}; "
        f"log-scale indices: {joint_log_idx}"
    )

    # Build per-gate specs (assembler + basis-prep scripts).
    gate_specs: List[dict] = []
    for gname in target_gates:
        pp_obj = PulseInformation.gate_by_name(gname)
        if pp_obj is None:
            log.warning(f" Skipping unknown gate {gname!r}.")
            continue
        n_wires = 1 if gname in self.GATES_1Q else 2
        d_basis = 2**n_wires
        pulse_circuit, target_circuit = self._create_joint_pair_for(gname)

        pulse_basis_scripts = [
            ys.Script(_with_basis_prep(pulse_circuit, k, n_wires), n_qubits=n_wires)
            for k in range(d_basis)
        ]
        target_basis_scripts = [
            ys.Script(
                _with_basis_prep(target_circuit, k, n_wires), n_qubits=n_wires
            )
            for k in range(d_basis)
        ]

        # The default argument binds this iteration's pp_obj, avoiding
        # late-binding in the closure. leaf_slices is shared by all specs.
        def _make_assembler(pp_obj=pp_obj):
            def assemble(theta):
                return QOC._assemble_for_gate(theta, pp_obj, leaf_slices)

            return assemble

        gate_specs.append(
            {
                "name": gname,
                "n_qubits": n_wires,
                "weight": float(weights.get(gname, 1.0)),
                "assembler": _make_assembler(),
                "pulse_basis_scripts": pulse_basis_scripts,
                "target_basis_scripts": target_basis_scripts,
            }
        )
        log.info(
            f" Built spec for {gname}: n_qubits={n_wires}, "
            f"weight={gate_specs[-1]['weight']}"
        )

    if not gate_specs:
        # Fail loudly instead of optimising an empty objective and then
        # overwriting PulseInformation with its result.
        raise ValueError(
            f"optimize_joint: none of the target gates {target_gates} is "
            "known to PulseInformation; nothing to optimise."
        )

    # Reuse the (process_loss, phase_loss) weighting of the first
    # "unitary" entry in ``self.cost_fns`` so fidelity vs. phase keep the
    # same relative importance as in the per-gate path; fall back to
    # (0.5, 0.5) when no such entry exists. Wrapping in Cost shares the
    # weight-tuple collapsing logic with the per-gate path.
    weight_tuple = next(
        (w for n, w in self.cost_fns if n == "unitary"),
        (0.5, 0.5),
    )
    joint_cost = Cost(
        cost=joint_unitary_cost_fn,
        weight=weight_tuple,
        ckwargs={
            "gate_specs": gate_specs,
            "n_samples": self.n_samples,
        },
    )

    # Temporarily point log_scale_params at the joint-vector indices.
    # Both the Stage 0 grid building and the Stage 1 log-space
    # reparametrisation read ``self.log_scale_params``. The mask cache
    # is cleared on both sides of the swap, so the joint run sees the
    # joint indices and later per-gate runs revert cleanly.
    prev_log_scale = self.log_scale_params
    self.log_scale_params = joint_log_idx
    self._log_mask_cache.clear()
    try:
        best_scan_theta = self._joint_stage_0_coord_descent(
            init_theta, leaf_slices, joint_cost
        )

        global_best_theta, global_best_history, global_best_loss = self.stage_1_opt(
            best_scan_theta, joint_cost
        )
    finally:
        self.log_scale_params = prev_log_scale
        self._log_mask_cache.clear()

    log.info(f"Joint optimisation done. final loss={float(global_best_loss):.6e}")

    # Save per-leaf results, one row per leaf. The fidelity column holds
    # the *joint* fidelity, which serves only as a coarse quality
    # signal for that leaf.
    joint_fid = float(1.0 - global_best_loss)
    for leaf_name, sl in leaf_slices.items():
        leaf_params = global_best_theta[sl]
        self.save_results(
            gate=leaf_name,
            fidelity=joint_fid,
            pulse_params=leaf_params,
        )

    # Update PulseInformation in place so the new defaults are active
    # for the rest of this Python process.
    for leaf_name, sl in leaf_slices.items():
        pp = PulseInformation.gate_by_name(leaf_name)
        pp.params = global_best_theta[sl]

    return global_best_theta, leaf_slices, global_best_history
# Default keyword configuration for a QOC run (see the CLI below for
# what each knob controls).
default_qoc_params = dict(
    envelope="drag",
    # (name, weight) cost terms. The unitary-level cost (process
    # infidelity + trace-phase term) captures rotation-axis tilt and
    # global-phase residuals that a state-fidelity cost cannot see. It
    # is needed to keep the two-CX composites (CRX/CRY/CRZ) within the
    # tightened phase tolerances.
    # Earlier configurations used ("fidelity", (0.5, 0.5)) as the
    # legacy state-vector cost, plus ("pulse_width", 1.5e-8) and
    # ("evolution_time", 5e-9) as regularisers.
    cost_fns=[("unitary", (0.5, 0.5))],
    t_target=0.5,
    n_steps=800,
    n_samples=20,
    learning_rate=0.0001,
    warmup_ratio=0.05,
    end_lr_ratio=0.01,
    log_interval=50,
    file_dir=None,
    n_restarts=5,
    restart_noise_scale=0.01,
    grad_clip=1.0,
    random_seed=1000,
    scan_steps=20,
    scan_grid_size=4,
    scan_ranges=None,
    log_scale_params=None,
    early_stop_patience=0,
    early_stop_min_delta=0.0,
)
def profile_pulse_pipeline(
    gate: str = "RX",
    n_samples: int = 3,
    rwa: Optional[bool] = None,
    n_qubits: int = 1,
) -> dict:
    """Profile a single pulse gate's forward + ``value_and_grad`` pass.

    Diagnostic helper for the JIT pipeline. It builds a minimal
    :class:`Script` that applies the requested pulse gate, then times
    JIT compilation and steady-state evaluation of:

    1. one forward pass (``Script.execute(type="state", ...)``);
    2. one ``jax.value_and_grad`` of ``1 - |<target|state>|^2``
       against the analytic ``operations.<gate>`` target.

    Use it to measure the effect of the RWA toggle (``rwa=True``) and
    of the scan/sync refactors documented in the patch notes::

        from qml_essentials.qoc import profile_pulse_pipeline
        profile_pulse_pipeline("RX", rwa=False)
        profile_pulse_pipeline("RX", rwa=True)

    Args:
        gate: Gate name to profile (default ``"RX"``). Must be a
            single-qubit pulse-level gate (``RX`` / ``RY``).
        n_samples: Number of timed evaluations after warm-up.
        rwa: If not ``None``, temporarily switch the global RWA flag
            for the duration of the profile. ``None`` keeps the
            current setting.
        n_qubits: Width of the script (kept at 1 for the single-qubit
            pulse gates).

    Returns:
        Dict with keys ``gate``, ``rwa``, ``compile_fwd``, ``mean_fwd``,
        ``compile_grad``, ``mean_grad`` and ``loss``. Times are in
        seconds. The ``compile_*`` entries include the first
        (compiling) call. The ``mean_*`` entries are ``nan`` when
        ``n_samples <= 0``.
    """
    import time

    def _timed(fn, *args):
        """Call ``fn(*args)``, wait for the device, return ``(out, secs)``."""
        start = time.perf_counter()
        out = fn(*args)
        # block_until_ready accepts pytrees, so (loss, grads) tuples
        # are fully awaited too.
        jax.block_until_ready(out)
        return out, time.perf_counter() - start

    def _mean(samples):
        """Mean of ``samples``, or ``nan`` (without a numpy warning) if empty."""
        return float(np.mean(samples)) if samples else float("nan")

    with PulseInformation.preserve_state():
        if rwa is not None:
            PulseInformation.set_rwa(bool(rwa))
        from qml_essentials.pulses import PulseGates

        gate_op = getattr(op, gate)
        gate_pulse = getattr(PulseGates, gate)

        def pulse_circuit(theta, pp):
            gate_pulse(theta, wires=0, pulse_params=pp)

        def target_circuit(theta):
            gate_op(theta, wires=0)

        pulse_script = ys.Script(pulse_circuit, n_qubits=n_qubits)
        target_script = ys.Script(target_circuit, n_qubits=n_qubits)

        theta = jnp.asarray(jnp.pi / 4)
        pp = PulseInformation.gate_by_name(gate).params
        target_state = target_script.execute(type="state", args=(theta,))
        target_state = jax.lax.stop_gradient(target_state)

        @jax.jit
        def fwd(theta, pp):
            return pulse_script.execute(type="state", args=(theta, pp))

        @jax.jit
        def loss_and_grad(pp):
            def loss_fn(p):
                state = pulse_script.execute(type="state", args=(theta, p))
                return 1.0 - jnp.abs(jnp.vdot(target_state, state)) ** 2

            return jax.value_and_grad(loss_fn)(pp)

        # Warm-up calls: these timings include JIT compilation.
        _, compile_fwd = _timed(fwd, theta, pp)
        (loss, _), compile_grad = _timed(loss_and_grad, pp)

        fwd_t, grad_t = [], []
        for _sample in range(n_samples):
            _, elapsed = _timed(fwd, theta, pp)
            fwd_t.append(elapsed)
            (loss, _), elapsed = _timed(loss_and_grad, pp)
            grad_t.append(elapsed)

        result = {
            "gate": gate,
            "rwa": PulseInformation.get_rwa(),
            "compile_fwd": compile_fwd,
            "mean_fwd": _mean(fwd_t),
            "compile_grad": compile_grad,
            "mean_grad": _mean(grad_t),
            "loss": float(loss),
        }
        log.info(
            f"[profile] gate={gate} rwa={result['rwa']} "
            f"compile fwd/grad: {compile_fwd * 1e3:.1f}/"
            f"{compile_grad * 1e3:.1f} ms, "
            f"mean fwd/grad: {result['mean_fwd'] * 1e3:.1f}/"
            f"{result['mean_grad'] * 1e3:.1f} ms, "
            f"loss={result['loss']:.4e}"
        )
        return result
# ---------------------------------------------------------------------------
# Command-line interface
# ---------------------------------------------------------------------------


def _build_cli_parser() -> argparse.ArgumentParser:
    """Build the argument parser for the QOC command-line entry point.

    Most option defaults are read from ``default_qoc_params``; gate and
    envelope choices come from :class:`QOC` and :class:`PulseEnvelope`.

    Returns:
        A fully configured :class:`argparse.ArgumentParser`.
    """
    parser = argparse.ArgumentParser(
        description="Quantum Optimal Control — pulse-level gate synthesis."
    )
    parser.add_argument(
        "--gates",
        type=str,
        nargs="+",
        default=["RX", "RY", "RZ", "CZ"],
        choices=QOC.GATES_1Q + QOC.GATES_2Q + ["all"],
        help="Gate(s) to optimize.",
    )
    parser.add_argument(
        "--log",
        action="store_true",
        default=False,
        help="Log results to file (default: False).",
    )
    parser.add_argument(
        "--no-log",
        action="store_false",
        dest="log",
        help="Disable logging results to file.",
    )
    parser.add_argument(
        "--envelope",
        type=str,
        default=default_qoc_params["envelope"],
        choices=PulseEnvelope.available(),
        help="Pulse envelope shape to use for optimization.",
    )
    parser.add_argument(
        "--costs",
        type=str,
        nargs="+",
        default=default_qoc_params["cost_fns"],
        help=(
            "Cost functions and weights as 'name:w1,w2,...' strings. "
            "If weights are omitted the registry defaults are used. "
            f"Available: {CostFnRegistry.available()}. "
            "Example: --costs fidelity:0.5,0.3 pulse_width:0.2"
        ),
    )
    parser.add_argument(
        "--t_target",
        type=float,
        default=default_qoc_params["t_target"],
        help=(
            "Target evolution time for the 'evolution_time' cost function. "
            "All gates will be softly encouraged towards this common time."
        ),
    )
    parser.add_argument(
        "--n_steps",
        type=int,
        default=default_qoc_params["n_steps"],
        help="Number of optimisation steps per gate.",
    )
    parser.add_argument(
        "--n_samples",
        type=int,
        default=default_qoc_params["n_samples"],
        help="Number of parameter samples in [0, 2\\pi] for cost evaluation.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=default_qoc_params["learning_rate"],
        help="Peak learning rate for the AdamW optimiser.",
    )
    parser.add_argument(
        "--warmup_ratio",
        type=float,
        default=default_qoc_params["warmup_ratio"],
        help=(
            "Fraction of n_steps used for linear LR warmup (0.0-1.0). "
            "Set to 0 to start at the peak LR immediately."
        ),
    )
    parser.add_argument(
        "--end_lr_ratio",
        type=float,
        default=default_qoc_params["end_lr_ratio"],
        help=(
            "Final LR as a fraction of --learning_rate after cosine decay. "
            "Also used as the initial LR before warmup. "
            "Set to 1.0 (with --warmup_ratio 0) for a constant LR."
        ),
    )
    parser.add_argument(
        "--log_interval",
        type=int,
        default=default_qoc_params["log_interval"],
        help="Log the current loss every N steps.",
    )
    parser.add_argument(
        "--file_dir",
        type=str,
        default=default_qoc_params["file_dir"],
        help="Directory to save qoc_results_[envelope].csv. "
        "Defaults to the package directory.",
    )
    parser.add_argument(
        "--n_restarts",
        type=int,
        default=default_qoc_params["n_restarts"],
        help=(
            "Number of random restarts for the optimisation. "
            "The first run uses the initial parameters as-is; "
            "subsequent runs add random perturbations. "
            "The best result across all restarts is kept."
        ),
    )
    parser.add_argument(
        "--restart_noise_scale",
        type=float,
        default=default_qoc_params["restart_noise_scale"],
        help=(
            "Standard deviation of Gaussian noise added to the initial "
            "parameters for each restart, relative to parameter magnitude."
        ),
    )
    parser.add_argument(
        "--grad_clip",
        type=float,
        default=default_qoc_params["grad_clip"],
        help=(
            "Maximum global gradient norm. Gradients are clipped to this "
            "value before being passed to the optimiser. "
            "Set to 0 to disable."
        ),
    )
    parser.add_argument(
        "--random_seed",
        type=int,
        default=default_qoc_params["random_seed"],
        help="Base random seed for restart perturbations.",
    )
    parser.add_argument(
        "--scan_steps",
        type=int,
        default=default_qoc_params["scan_steps"],
        help=(
            "Number of short gradient-descent steps per candidate in the "
            "coarse grid scan (Stage 0). Set to 0 to disable the grid scan."
        ),
    )
    parser.add_argument(
        "--scan_grid_size",
        type=int,
        default=default_qoc_params["scan_grid_size"],
        help=(
            "Number of points per parameter dimension in the coarse grid. "
            "Total candidates = scan_grid_size^n_params."
        ),
    )
    parser.add_argument(
        "--scan_ranges",
        type=str,
        nargs="*",
        default=default_qoc_params["scan_ranges"],
        help=(
            "Per-parameter (lo,hi) ranges for the grid scan, given as "
            "'lo,hi' strings. One pair per pulse parameter. "
            "Example: --scan_ranges 0.5,30.0 0.05,2.0 0.05,2.0 "
            "If omitted, heuristic defaults are used."
        ),
    )
    parser.add_argument(
        "--plot",
        action="store_true",
        default=False,
        help=(
            "Save a loss-landscape plot (Phase 0) and a loss-curve plot "
            "(Phase 1) as PNG files in --file_dir for each optimised gate."
        ),
    )
    parser.add_argument(
        "--early_stop_patience",
        type=int,
        default=default_qoc_params["early_stop_patience"],
        help=(
            "Number of consecutive Stage-1 steps without improvement "
            "(> --early_stop_min_delta) after which optimisation exits "
            "early. 0 disables early stopping (default)."
        ),
    )
    parser.add_argument(
        "--early_stop_min_delta",
        type=float,
        default=default_qoc_params["early_stop_min_delta"],
        help=(
            "Minimum loss decrease that counts as an improvement for "
            "the --early_stop_patience counter (default 0.0)."
        ),
    )
    parser.add_argument(
        "--joint",
        action="store_true",
        default=False,
        help=(
            "Use composite-aware joint optimisation: a single shared "
            "leaf parameter vector is optimised against the unitary "
            "cost summed over leaf and composite gates "
            "(default targets: RX, RY, RZ, CZ, H, CX, CRX, CRY, CRZ). "
            "Pulls leaves into a basin that works well in *every* "
            "use-site instead of only standalone, fixing the "
            "selfish-basin failure mode of per-gate optimisation. "
            "Ignores --gates."
        ),
    )
    parser.add_argument(
        "--joint_targets",
        nargs="+",
        type=str,
        default=None,
        help=(
            "(Used only with --joint.) Override the list of target "
            "gates whose unitary cost contributes to the joint "
            "objective."
        ),
    )
    parser.add_argument(
        "--joint_leaves",
        nargs="+",
        type=str,
        default=None,
        help=(
            "(Used only with --joint.) Override the list of leaf "
            "gates whose parameters are jointly optimised. "
            "Default: RX RY RZ CZ."
        ),
    )
    parser.add_argument(
        "--joint_weights",
        nargs="+",
        type=str,
        default=None,
        help=(
            "(Used only with --joint.) Override per-target weights as "
            "'gate:weight' strings (e.g. --joint_weights CRX:5 CX:3). "
            "Merged on top of QOC.JOINT_WEIGHTS_DEFAULT, so unspecified "
            "gates keep their default weight."
        ),
    )
    parser.add_argument(
        "--rwa",
        action="store_true",
        default=False,
        help=(
            # Adjacent literals are concatenated verbatim, so each fragment
            # needs its own separating space.
            "Toggles RWA mode for pulse simulation. "
            "If this is set true, we utilize the rotating wave approximation "
            "instead of the exact interaction picture. "
            "While this makes the calculations less exact, it provides "
            "significant speedup."
        ),
    )
    parser.add_argument(
        "--drive",
        action="store_true",
        default=False,
        help="Uses drive hamiltonian instead of lab frame.",
    )
    return parser


def _parse_cli_scan_ranges(
    specs: Optional[List], parser: argparse.ArgumentParser
) -> Optional[List[Tuple[float, float]]]:
    """Convert ``--scan_ranges`` values into ``(lo, hi)`` float tuples.

    Each entry is either a ``"lo,hi"`` string (as typed on the command line)
    or an already-split two-element pair (e.g. a non-string default, which
    argparse does not pass through ``type=str``).

    Args:
        specs: The raw ``args.scan_ranges`` value, or ``None``.
        parser: Parser used to report malformed entries through
            :meth:`argparse.ArgumentParser.error`, which exits with status 2.

    Returns:
        ``None`` when ``specs`` is ``None`` or empty (the option was omitted,
        or given without any values), otherwise a list of
        ``(float(lo), float(hi))`` tuples in input order.
    """
    if not specs:
        # A bare ``--scan_ranges`` (nargs="*" -> []) is treated like an
        # omitted option, matching the help text: heuristic defaults apply.
        return None

    ranges = []
    for spec in specs:
        parts = spec.split(",") if isinstance(spec, str) else list(spec)
        if len(parts) != 2:
            parser.error(
                f"--scan_ranges expects 'lo,hi' pairs, got {spec!r}"
            )
        try:
            lo, hi = float(parts[0]), float(parts[1])
        except (TypeError, ValueError):
            parser.error(
                f"--scan_ranges values must be numeric 'lo,hi', got {spec!r}"
            )
        ranges.append((lo, hi))
    return ranges


def _parse_cli_joint_weights(
    specs: Optional[List[str]], parser: argparse.ArgumentParser
) -> Optional[Dict[str, float]]:
    """Convert ``--joint_weights`` ``'gate:weight'`` strings into a dict.

    Args:
        specs: The raw ``args.joint_weights`` value, or ``None``.
        parser: Parser used to report malformed entries through
            :meth:`argparse.ArgumentParser.error`, which exits with status 2.

    Returns:
        ``None`` when ``specs`` is ``None`` or empty, otherwise a mapping of
        stripped gate names to float weights (later duplicates win).
    """
    if not specs:
        return None

    weights = {}
    for spec in specs:
        gname, sep, raw_weight = spec.partition(":")
        gname = gname.strip()
        if not sep or not gname:
            parser.error(
                f"--joint_weights expects 'gate:weight' strings, got {spec!r}"
            )
        try:
            # ``float`` tolerates surrounding whitespace; a second ':' in
            # ``raw_weight`` makes it fail and is reported below.
            weights[gname] = float(raw_weight)
        except ValueError:
            parser.error(
                f"--joint_weights weight must be numeric, got {spec!r}"
            )
    return weights


if __name__ == "__main__":
    parser = _build_cli_parser()
    args = parser.parse_args()
    sel_gates = args.gates  # already a list from nargs="+"
    make_log = args.log

    # Convert structured CLI values up front so malformed input surfaces as a
    # usage error before any QOC set-up work is done.
    scan_ranges = _parse_cli_scan_ranges(args.scan_ranges, parser)
    joint_weights = (
        _parse_cli_joint_weights(args.joint_weights, parser)
        if args.joint
        else None
    )

    PulseInformation.set_rwa(args.rwa)
    PulseInformation.set_frame("drive" if args.drive else "lab")

    # Parse cost function specs from CLI
    cost_fns = [CostFnRegistry.parse_cost_arg(spec) for spec in args.costs]

    # When executed as a script, ``__name__`` is "__main__", so the module's
    # import-time ``log`` is not the "qml_essentials.qoc" logger. Rebinding
    # the global here makes this module's ``log.*`` calls use that named
    # logger, with an INFO level and a console handler.
    log = logging.getLogger("qml_essentials.qoc")
    log.setLevel(logging.INFO)
    log.addHandler(logging.StreamHandler())

    qoc = QOC(
        envelope=args.envelope,
        cost_fns=cost_fns,
        t_target=args.t_target,
        n_steps=args.n_steps,
        n_samples=args.n_samples,
        learning_rate=args.learning_rate,
        warmup_ratio=args.warmup_ratio,
        end_lr_ratio=args.end_lr_ratio,
        log_interval=args.log_interval,
        file_dir=args.file_dir,
        n_restarts=args.n_restarts,
        restart_noise_scale=args.restart_noise_scale,
        grad_clip=args.grad_clip,
        random_seed=args.random_seed,
        scan_steps=args.scan_steps,
        scan_grid_size=args.scan_grid_size,
        scan_ranges=scan_ranges,
        early_stop_patience=args.early_stop_patience,
        early_stop_min_delta=args.early_stop_min_delta,
        plot=args.plot,
    )

    if args.joint:
        qoc.optimize_joint(
            target_gates=args.joint_targets,
            leaf_names=args.joint_leaves,
            weights=joint_weights,
        )
    else:
        qoc.optimize_all(sel_gates=sel_gates, make_log=make_log)