Coverage for qml_essentials / qoc.py: 78%
446 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-30 11:43 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-30 11:43 +0000
1import argparse
2import csv
3import itertools
4import logging
5import os
6from typing import Callable, Dict, List, Optional, Tuple, Union
8import jax
9from jax import numpy as jnp
10import optax
12from qml_essentials.gates import Gates, PulseInformation, PulseEnvelope
13from qml_essentials import operations as op
14from qml_essentials import yaqsi as ys
15from qml_essentials.math import phase_difference, fidelity
# Turn on double precision globally in JAX; the cost functions in this
# module build explicit float64 arrays and expect this flag to be set.
jax.config.update("jax_enable_x64", True)
log = logging.getLogger(name=__name__)
21class Cost:
22 """Weighted wrapper around a cost function.
24 Combines a cost callable with a scalar or tuple weight and optional
25 constant keyword arguments. Multiple ``Cost`` instances can be
26 composed via the ``+`` operator to build a combined objective.
28 Args:
29 cost: Callable ``(pulse_params, **ckwargs) -> scalar | tuple``.
30 weight: Scalar or tuple of per-component weights.
31 ckwargs: Constant keyword arguments injected into every call.
32 """
34 def __init__(
35 self,
36 cost: Callable,
37 weight: Union[float, Tuple],
38 ckwargs: Optional[dict] = None,
39 ):
40 self.cost = cost
41 self.weight = weight
42 self.ckwargs = ckwargs if ckwargs is not None else {}
44 def __call__(self, *args, **kwargs):
45 """Evaluate the cost function with injected kwargs and apply weights."""
46 cost = self.cost(*args, **kwargs, **self.ckwargs)
47 if isinstance(self.weight, tuple):
48 return jnp.array(
49 [c * w for c, w in zip(cost, self.weight, strict=True)]
50 ).sum()
51 return cost * self.weight
53 def __add__(self, other):
54 """Compose two cost terms into a single callable that sums them."""
55 if other is None:
56 return lambda *args, **kwargs: self(*args, **kwargs)
57 if callable(other):
58 return lambda *args, **kwargs: self(*args, **kwargs) + other(
59 *args, **kwargs
60 )
61 raise TypeError(f"Cannot add Cost and {type(other)}")
def fidelity_cost_fn(
    pulse_params: jnp.ndarray,
    pulse_script: ys.Script,
    target_script: ys.Script,
    n_samples: int,
) -> Tuple[float, float]:
    """Mean infidelity and mean absolute phase error over a sweep of angles.

    ``n_samples`` rotation angles, evenly spaced on ``[0, 2*pi]`` with both
    endpoints included, are passed to each script in a single batched
    ``execute`` call. ``in_axes`` maps over the angle axis, and the pulse
    parameters are shared across the batch. The function returns

    * the mean of ``1 - fidelity(pulse_state, target_state)``, and
    * the mean of ``|phase_difference(pulse_state, target_state)|``.

    Args:
        pulse_params: Pulse parameters under optimisation.
        pulse_script: Yaqsi script executed with ``(angle, pulse_params)``.
        target_script: Yaqsi script executed with ``(angle,)``; the target.
        n_samples: Number of rotation angles to evaluate.

    Returns:
        Tuple of (abs_diff, phase_diff).
    """
    angles = jnp.linspace(0, 2 * jnp.pi, n_samples)

    # One vectorised execution per script, batched over the angle axis.
    realised_states = pulse_script.execute(
        type="state", args=(angles, pulse_params), in_axes=(0, None)
    )
    ideal_states = target_script.execute(
        type="state", args=(angles,), in_axes=(0,)
    )

    # Keep an explicit float64 "1" so the subtraction promotes to float64.
    one = jnp.array(1.0, dtype=jnp.float64)
    infidelities = one - fidelity(realised_states, ideal_states)
    phase_errors = jnp.abs(phase_difference(realised_states, ideal_states))

    # TODO: a log-based loss for very small values (or gradient ascent)
    # could help if numerical limitations show up here.
    return (jnp.mean(infidelities), jnp.mean(phase_errors))
def pulse_width_cost_fn(
    pulse_params: jnp.ndarray,
    envelope: str,
) -> jnp.ndarray:
    """
    Cost function penalising the magnitude of the pulse width (sigma / width).

    The pulse width is taken as the last envelope parameter, i.e.
    ``pulse_params[n_envelope_params - 1]``. Its absolute value is
    penalised. Otherwise a minimiser could lower the cost without bound
    simply by driving that parameter negative. For envelopes with no
    envelope parameters (e.g. ``"general"``) the cost is zero.

    NOTE(review): for envelopes whose last envelope parameter is not a
    width (possibly DRAG, whose default ranges list four parameters),
    this penalises a different parameter. Confirm against PulseEnvelope.

    Args:
        pulse_params: Pulse parameters for the gate.
        envelope: Name of the active pulse envelope.

    Returns:
        Scalar, non-negative pulse-width cost (float64).
    """
    n_envelope_params = PulseEnvelope.get(envelope)["n_envelope_params"]

    if n_envelope_params <= 0:
        return jnp.array(0.0, dtype=jnp.float64)

    pulse_width = pulse_params[n_envelope_params - 1]
    return jnp.abs(jnp.asarray(pulse_width, dtype=jnp.float64))
def evolution_time_cost_fn(
    pulse_params: jnp.ndarray,
    t_target: float,
) -> jnp.ndarray:
    """
    Cost function penalising deviation of the evolution time from a target.

    The evolution time is always the last element of the pulse parameter
    vector. The cost is the squared relative deviation from ``t_target``:

        cost = ((t - t_target) / t_target) ** 2

    This encourages all independently optimized gates to converge towards a
    common evolution time, making them compatible when composed into a
    circuit.

    Args:
        pulse_params: Pulse parameters for the gate.
        t_target: Target evolution time (must be set and non-zero).

    Returns:
        Scalar evolution-time cost.

    Raises:
        ValueError: If ``t_target`` is ``None`` or a numeric zero.
    """
    if t_target is None:
        raise ValueError(
            "evolution_time_cost_fn requires a target evolution time "
            "(t_target), got None."
        )
    # t_target is a divisor; only checked for plain Python numbers so that
    # traced / array values are left untouched.
    if isinstance(t_target, (int, float)) and t_target == 0:
        raise ValueError("t_target must be non-zero.")

    t = pulse_params[-1]
    return ((t - t_target) / t_target) ** 2
def spectral_density_cost_fn(
    pulse_params: jnp.ndarray,
    envelope: str,
    n_fft: int = 1024,
) -> jnp.ndarray:
    """
    Cost function penalising the spectral width of a given pulse.

    Samples the pulse envelope in the time domain over ``[0, t_evol]``
    (where ``t_evol`` is the last element of pulse_params), computes its
    power spectral density via FFT, and returns the normalised RMS bandwidth
    (square root of the second central moment of the PSD).

    Pulses with narrow spectra (e.g. Gaussian, DRAG) receive a low cost,
    whereas pulses with wide spectra (e.g. rectangular) are penalised more
    heavily.

    For envelopes with no envelope parameters (e.g. ``"general"``), the
    cost is zero.

    Args:
        pulse_params: Pulse parameters for the gate. Envelope parameters
            occupy ``pulse_params[:n_envelope_params]`` and the evolution
            time is ``pulse_params[-1]``.
        envelope: Name of the active pulse envelope.
        n_fft: Number of time-domain samples used for the FFT
            (default 1024). Both even and odd values are handled exactly.

    Returns:
        Scalar spectral-width cost: the RMS bandwidth with frequencies in
        units of the Nyquist frequency. It lies in [0, 1]; as the standard
        deviation of a distribution on [0, 1] it is in fact at most 0.5.
    """
    envelope_info = PulseEnvelope.get(envelope)
    n_envelope_params = envelope_info["n_envelope_params"]
    envelope_fn = envelope_info["fn"]

    # Nothing to penalise for envelopes without tuneable shape params
    if n_envelope_params == 0 or envelope_fn is None:
        return jnp.array(0.0, dtype=jnp.float64)

    # Extract envelope parameters and evolution time
    env_params = pulse_params[:n_envelope_params]
    t_evol = pulse_params[-1]
    t_c = t_evol / 2.0

    t_samples = jnp.linspace(0.0, t_evol, n_fft)
    signal = jax.vmap(lambda t: envelope_fn(env_params, t, t_c))(t_samples)

    spectrum = jnp.fft.rfft(signal)
    psd = jnp.abs(spectrum) ** 2
    psd = psd / (jnp.sum(psd) + 1e-12)  # normalise to a distribution

    # True rfft bin frequencies are k / n_fft (cycles per sample); multiplying
    # by 2 expresses them in units of Nyquist. For even n_fft this equals
    # linspace(0, 1, n_fft // 2 + 1). For odd n_fft the top bin lies below
    # Nyquist, and a plain linspace would mislabel every bin.
    freqs = 2.0 * jnp.fft.rfftfreq(n_fft)

    mean_freq = jnp.sum(freqs * psd)
    rms_bw = jnp.sqrt(jnp.sum((freqs - mean_freq) ** 2 * psd))

    return jnp.asarray(rms_bw, dtype=jnp.float64)


# Backward-compatible alias for the old misspelled name
sepctral_density_cost_fn = spectral_density_cost_fn
class CostFnRegistry:
    """Registry of cost functions available for pulse optimisation.

    Use :meth:`register` to add new cost functions at runtime and
    :meth:`get` / :meth:`available` to query them. Each entry stores the
    callable (``fn``), its default weight (``default_weight``), and the
    names of the constant keyword arguments it consumes (``ckwargs_keys``).
    The registry is a class-level dict, so registrations are global.
    """

    _REGISTRY: Dict[str, dict] = {
        "fidelity": {
            "fn": fidelity_cost_fn,
            "default_weight": (0.5, 0.5),
            "ckwargs_keys": ["pulse_script", "target_script", "n_samples"],
        },
        "pulse_width": {
            "fn": pulse_width_cost_fn,
            "default_weight": 1.0,
            "ckwargs_keys": ["envelope"],
        },
        "evolution_time": {
            "fn": evolution_time_cost_fn,
            "default_weight": 1.0,
            "ckwargs_keys": ["t_target"],
        },
        "spectral_density": {
            "fn": spectral_density_cost_fn,
            "default_weight": 1.0,
            "ckwargs_keys": ["envelope"],
        },
    }

    @classmethod
    def register(
        cls,
        name: str,
        fn: Callable,
        default_weight: Union[float, Tuple[float, ...]] = 1.0,
        ckwargs_keys: Optional[List[str]] = None,
        overwrite: bool = False,
    ) -> None:
        """Register a new cost function under ``name``.

        Args:
            name: Key under which the cost function is stored.
            fn: Callable ``(pulse_params, **ckwargs) -> scalar | tuple``.
            default_weight: Default weight; use a tuple with one entry per
                returned component when ``fn`` returns a tuple.
            ckwargs_keys: Names of the constant keyword arguments that
                should be injected into ``fn``. Defaults to none.
            overwrite: Allow replacing an existing entry.

        Raises:
            TypeError: If ``fn`` is not callable.
            ValueError: If ``name`` is already registered and ``overwrite``
                is False.
        """
        if not callable(fn):
            raise TypeError(
                f"Cost function '{name}' must be callable, got {type(fn)}."
            )
        if name in cls._REGISTRY and not overwrite:
            raise ValueError(f"Cost function '{name}' is already registered.")
        cls._REGISTRY[name] = {
            "fn": fn,
            "default_weight": default_weight,
            "ckwargs_keys": list(ckwargs_keys) if ckwargs_keys is not None else [],
        }

    @classmethod
    def available(cls) -> List[str]:
        """Return the names of all registered cost functions."""
        return list(cls._REGISTRY.keys())

    @classmethod
    def get(cls, name: str) -> dict:
        """Look up cost-function metadata by name.

        Args:
            name: Registered cost function name.

        Returns:
            Metadata dict with keys ``fn``,
            ``default_weight``, ``ckwargs_keys``.

        Raises:
            ValueError: If name is not registered.
        """
        if name not in cls._REGISTRY:
            raise ValueError(
                f"Unknown cost function '{name}'. " f"Available: {cls.available()}"
            )
        return cls._REGISTRY[name]

    @classmethod
    def parse_cost_arg(
        cls, spec: Union[str, Tuple]
    ) -> Tuple[str, Union[float, Tuple[float, ...]]]:
        """Parse a ``"name:w1,w2,..."`` CLI string into ``(name, weight)``.
        If a tuple is provided, it is returned directly.

        Surrounding whitespace around the name is ignored. If the weight part
        is omitted, the default weight from the registry is used. A
        single-component weight is returned as a float; multi-component
        weights are returned as a tuple of floats.

        Args:
            spec: A string of the form ``"name"`` or ``"name:w1,w2,..."``.

        Returns:
            A tuple of ``(name, weight)``.

        Raises:
            ValueError: If the name is unknown or the number of weight
                components does not match the ones in ``default_weight``.
        """
        if isinstance(spec, tuple):
            return spec

        if ":" in spec:
            name, weight_str = spec.split(":", 1)
            name = name.strip()
            parts = [float(x) for x in weight_str.split(",")]
            weight: Union[float, Tuple[float, ...]] = (
                parts[0] if len(parts) == 1 else tuple(parts)
            )
        else:
            name = spec.strip()
            weight = cls.get(name)["default_weight"]

        # Validate weight count
        got = len(weight) if isinstance(weight, tuple) else 1
        default_weight = cls.get(name)["default_weight"]
        expected = len(default_weight) if isinstance(default_weight, tuple) else 1

        if got != expected:
            raise ValueError(
                f"Cost function '{name}' expects {expected} weight(s), " f"got {got}."
            )

        return name, weight
class QOC:
    """Quantum Optimal Control for pulse-level gate synthesis.

    Tunes pulse parameters so that pulse-level gates reproduce the action
    of standard quantum gates. An optional coarse grid scan (Stage 0) is
    followed by multi-restart gradient optimisation (Stage 1).

    Attributes:
        GATES_1Q: Names of the supported single-qubit gates.
        GATES_2Q: Names of the supported two-qubit gates.
        DEFAULT_PARAM_RANGES: Default ``(lo, hi)`` scan ranges, keyed by the
            total number of pulse parameters.
    """

    GATES_1Q: List[str] = ["RX", "RY", "RZ", "Rot", "H"]
    GATES_2Q: List[str] = ["CX", "CY", "CZ", "CRX", "CRY", "CRZ"]

    # Key: number of pulse parameters. Value: one (lo, hi) range per
    # parameter, in parameter order.
    DEFAULT_PARAM_RANGES = {
        1: [(0.05, 2.0)],  # evolution time only
        2: [(0.5, 2.0), (0.05, 2.0)],  # not typically used
        3: [(0.5, 30.0), (0.05, 2.0), (0.05, 2.0)],  # A, sigma, t
        4: [(0.5, 30.0), (0.05, 2.0), (0.01, 0.5), (0.05, 2.0)],  # DRAG
    }
357 def __init__(
358 self,
359 envelope: str,
360 cost_fns: List[Tuple[str, Union[float, Tuple[float, ...]]]],
361 t_target: float,
362 n_steps: int,
363 n_samples: int,
364 learning_rate: float,
365 log_interval: int = 50,
366 file_dir: str = None,
367 warmup_ratio: float = 0.0,
368 end_lr_ratio: float = 1.0,
369 n_restarts: int = 1,
370 restart_noise_scale: float = 0.5,
371 grad_clip: float = 1.0,
372 random_seed: int = 42,
373 scan_steps: int = 0,
374 scan_grid_size: int = 5,
375 scan_ranges: Optional[List[Tuple[float, float]]] = None,
376 log_scale_params: Optional[List[int]] = None,
377 ):
378 """
379 Initialize Quantum Optimal Control with Pulse-level Gates.
381 Args:
382 envelope (str): Pulse envelope shape to use for optimization.
383 Must be one of the registered envelopes in PulseEnvelope
384 (e.g. 'gaussian', 'square', 'cosine', 'drag', 'sech').
385 cost_fns (list): List of ``(name, weight)`` tuples that select
386 which cost functions to use and their weights. name must
387 be a key in :class:`CostFnRegistry`. *weight* is either a
388 single float or a tuple of floats matching the number of
389 return values of the cost function.
390 t_target (float, optional): Target evolution time for the
391 ``evolution_time`` cost function. Required when
392 ``"evolution_time"`` is among the selected cost functions.
393 n_steps (int): Number of steps in optimization.
394 n_samples (int): Number of parameter samples per step.
395 learning_rate (float): Peak learning rate for AdamW. When a
396 warmup/decay schedule is active this is the maximum LR
397 reached after the warmup phase.
398 log_interval (int): Interval for logging.
399 file_dir (str): Directory to save results.
400 warmup_ratio (float): Fraction of ``n_steps`` used for linear
401 warmup (0.0 - 1.0). Set to 0.0 to disable warmup and use
402 a constant learning rate throughout. A value of e.g. 0.05
403 means the first 5 % of steps linearly ramp the LR from
404 ``end_lr_ratio * learning_rate`` to ``learning_rate``.
405 end_lr_ratio (float): The final learning rate is
406 ``end_lr_ratio * learning_rate``. Also used as the initial
407 LR at the start of warmup. Set to 0.0 for full cosine
408 decay to zero; set to 1.0 (together with
409 ``warmup_ratio=0.0``) to recover a constant LR.
410 n_restarts (int): Number of random restarts for the
411 optimisation. The first run uses the initial parameters
412 as-is; subsequent runs add scaled random perturbations.
413 The best result across all restarts is kept.
414 Set to 1 to disable restarts (default behaviour).
415 restart_noise_scale (float): Standard deviation of the
416 Gaussian noise added to the initial parameters for each
417 restart (relative to the absolute value of each parameter).
418 Defaults to 0.5 (50 % relative perturbation).
419 grad_clip (float): Maximum global gradient norm. Gradients
420 are clipped to this value before being passed to the
421 optimiser, which stabilises training when the loss
422 landscape has steep regions. Set to ``float('inf')`` or
423 0.0 to disable. Defaults to 1.0.
424 random_seed (int): Base random seed for generating restart
425 perturbations. Defaults to 42.
426 scan_steps (int): Number of short gradient-descent steps to
427 run for each candidate in the coarse grid search
428 (Stage 0). Set to 0 to disable the grid scan entirely
429 and rely solely on restarts. A value of 20-50 is
430 usually enough to identify promising basins. Defaults
431 to 0.
432 scan_grid_size (int): Number of points per parameter
433 dimension in the coarse grid. The total number of
434 candidates is ``scan_grid_size ** n_params``, so keep
435 this small for high-dimensional parameter spaces.
436 Defaults to 5.
437 scan_ranges (Optional[List[Tuple[float, float]]]): Per-
438 parameter ``(lo, hi)`` ranges for the grid scan. If
439 ``None``, heuristic ranges are used based on the
440 envelope type: amplitude in ``[0.5, 30]``, width/sigma
441 in ``[0.05, 2]``, and evolution time in ``[0.05, 2]``.
442 Must have length equal to the number of pulse parameters
443 if provided.
444 log_scale_params (Optional[List[int]]): Indices of pulse
445 parameters that should be optimised in log-space. For
446 these parameters the optimizer sees ``log(p)`` and the
447 actual parameter used in the simulation is ``exp(log_p)``.
448 This dramatically improves convergence when the optimal
449 value may differ from the initial value by an order of
450 magnitude (e.g. amplitude, evolution time).
451 If ``None``, defaults to ``[0, -1]`` (amplitude and
452 evolution time) for envelopes with ≥ 2 envelope params,
453 or ``[]`` otherwise.
454 """
455 self.envelope = envelope
456 self.n_steps = n_steps
457 self.n_samples = n_samples
458 self.learning_rate = learning_rate
459 self.warmup_ratio = warmup_ratio
460 self.end_lr_ratio = end_lr_ratio
461 self.log_interval = log_interval
462 self.file_dir = (
463 file_dir if file_dir else os.path.dirname(os.path.realpath(__file__))
464 )
465 self.t_target = t_target
466 self.n_restarts = max(1, n_restarts)
467 self.restart_noise_scale = restart_noise_scale
468 self.grad_clip = grad_clip
469 self.random_key = jax.random.PRNGKey(random_seed)
470 self.scan_steps = scan_steps
471 self.scan_grid_size = scan_grid_size
472 self.scan_ranges = scan_ranges
474 # Determine log-scale param indices
475 envelope_info = PulseEnvelope.get(envelope)
476 n_env = envelope_info["n_envelope_params"]
477 if log_scale_params is not None:
478 self.log_scale_params = log_scale_params
479 elif n_env >= 2:
480 # Default: amplitude (index 0) and evolution time (last)
481 self.log_scale_params = [0, -1]
482 else:
483 self.log_scale_params = []
485 log.info(
486 f"Training parameters: {self.n_steps} steps, "
487 f"{self.n_samples} samples, {self.learning_rate} learning rate"
488 )
489 log.info(
490 f"LR schedule: warmup_ratio={self.warmup_ratio}, "
491 f"end_lr_ratio={self.end_lr_ratio}"
492 )
494 log.info(f"Envelope: {self.envelope}")
495 log.info(f"Target evolution time: {self.t_target}")
496 log.info(
497 f"Restarts: {self.n_restarts}, noise_scale={self.restart_noise_scale}, "
498 f"grad_clip={self.grad_clip}"
499 )
500 log.info(
501 f"Grid scan: scan_steps={self.scan_steps}, "
502 f"scan_grid_size={self.scan_grid_size}, "
503 f"log_scale_params={self.log_scale_params}"
504 )
505 log.info(f"Using cost function(s) {cost_fns}")
507 # Validate each entry against the registry
508 summed_weights = 0
509 for name, _weight in cost_fns:
510 CostFnRegistry.get(name) # raises ValueError if unknown
511 summed_weights += sum(_weight) if isinstance(_weight, tuple) else _weight
512 assert jnp.isclose(
513 summed_weights, 1.0, rtol=1e-8
514 ), f"Cost function weights must sum to 1. Got {summed_weights}"
516 self.cost_fns = cost_fns
518 # Configure the pulse system with the selected envelope
519 PulseInformation.set_envelope(self.envelope)
521 def save_results(self, gate: str, fidelity: float, pulse_params) -> None:
522 """Save optimised pulse parameters and fidelity for a gate to CSV.
524 If the gate already exists in the file, its entry is overwritten
525 regardless of whether the new fidelity is higher. A warning is
526 logged when the existing fidelity was better.
528 Args:
529 gate: Name of the gate (e.g. ``"RX"``).
530 fidelity: Achieved fidelity of the optimised pulse.
531 pulse_params: Optimised pulse parameters for the gate.
532 """
533 if self.file_dir is not None:
534 os.makedirs(self.file_dir, exist_ok=True)
535 filename = os.path.join(self.file_dir, "qoc_results.csv")
537 reader = None
538 if os.path.isfile(filename):
539 with open(filename, mode="r", newline="") as f:
540 reader = csv.reader(f.readlines())
542 entry = [gate] + [fidelity] + list(map(float, pulse_params))
544 with open(filename, mode="w", newline="") as f:
545 writer = csv.writer(f)
546 match = False
547 if reader is not None:
548 for row in reader:
549 # gate already exists
550 if row[0] == gate:
551 if fidelity <= float(row[1]):
552 log.warning(
553 f"Pulse parameters for {gate} already exist with "
554 f"higher fidelity ({row[1]} >= {fidelity})"
555 )
556 writer.writerow(entry)
557 match = True
558 # any other gate
559 else:
560 writer.writerow(row)
561 # gate does not exist
562 if not match:
563 writer.writerow(entry)
565 def _to_log_space(self, params: jnp.ndarray) -> jnp.ndarray:
566 """Convert selected parameters to log-space for optimisation.
568 Parameters at indices in ``self.log_scale_params`` are replaced
569 by ``log(|p| + eps)`` so the optimiser operates on a
570 logarithmic scale. All other parameters are left unchanged.
572 Args:
573 params: Pulse parameters in physical space.
575 Returns:
576 Parameters with selected entries in log-space.
577 """
578 if not self.log_scale_params:
579 return params
580 n = len(params)
581 log_params = params.copy()
582 for idx in self.log_scale_params:
583 # Normalise negative indices
584 i = idx if idx >= 0 else n + idx
585 log_params = log_params.at[i].set(jnp.log(jnp.abs(params[i]) + 1e-12))
586 return log_params
588 def _from_log_space(self, log_params: jnp.ndarray) -> jnp.ndarray:
589 """Convert selected parameters back from log-space.
591 Inverse of :meth:`_to_log_space`. Parameters at indices in
592 ``self.log_scale_params`` are exponentiated; all others are
593 passed through unchanged.
595 Args:
596 log_params: Parameters with selected entries in log-space.
598 Returns:
599 Parameters in physical space (all positive for log-scaled
600 entries).
601 """
602 if not self.log_scale_params:
603 return log_params
604 n = len(log_params)
605 params = log_params.copy()
606 for idx in self.log_scale_params:
607 i = idx if idx >= 0 else n + idx
608 params = params.at[i].set(jnp.exp(log_params[i]))
609 return params
611 def _build_scan_grid(self, n_params: int) -> jnp.ndarray:
612 """Build a coarse parameter grid for the initial scan phase.
614 Uses either user-supplied ``scan_ranges`` or heuristic defaults
615 based on typical Gaussian pulse parameter ranges.
617 Args:
618 n_params: Number of pulse parameters.
620 Returns:
621 Array of shape ``(n_candidates, n_params)`` with grid points.
622 """
623 if self.scan_ranges is not None:
624 ranges = self.scan_ranges
625 assert len(ranges) == n_params, (
626 f"scan_ranges has {len(ranges)} entries but gate has "
627 f"{n_params} parameters."
628 )
629 else:
630 # [amplitude, sigma/width, evolution_time]
631 ranges = self.DEFAULT_PARAM_RANGES.get(
632 n_params,
633 [(0.1, 10.0)] * n_params, # fallback
634 )
636 # Build log-spaced grids for each parameter
637 axes = []
638 for lo, hi in ranges:
639 axes.append(jnp.logspace(jnp.log10(lo), jnp.log10(hi), self.scan_grid_size))
641 # Cartesian product of all axes
642 grid = jnp.array(list(itertools.product(*axes)))
643 return grid
645 def stage_0_opt(self, init_pulse_params: jnp.ndarray, fidelity_only_cost):
646 """Run the coarse grid-scan phase (Stage 0).
648 Evaluates a Cartesian grid of parameter candidates using only the
649 fidelity cost (ignoring phase). Each candidate is refined with a
650 few fast gradient steps. Returns the best-found parameters.
652 Args:
653 init_pulse_params: Initial pulse parameters to compare against.
654 fidelity_only_cost: Cost callable using fidelity only.
656 Returns:
657 Best pulse parameters found during the scan.
658 """
660 def fidelity_only_cost_log(log_params, *args):
661 return fidelity_only_cost(self._from_log_space(log_params), *args)
663 best_scan_params = init_pulse_params
664 best_scan_loss = fidelity_only_cost(init_pulse_params)
666 if self.scan_steps > 0:
667 log.info(
668 f"Stage 0: Grid scan with {self.scan_grid_size}^"
669 f"{len(init_pulse_params)} candidates, "
670 f"{self.scan_steps} steps each"
671 )
673 grid = self._build_scan_grid(len(init_pulse_params))
674 log.info(f" Total candidates: {len(grid)}")
676 # Use a fast, constant-LR Adam for the scan phase
677 scan_optimizer = optax.chain(
678 optax.clip_by_global_norm(
679 self.grad_clip if self.grad_clip > 0 else 1.0
680 ),
681 optax.adam(self.learning_rate * 5), # aggressive LR
682 )
684 @jax.jit
685 def scan_step(opt_state, log_params):
686 loss, grads = jax.value_and_grad(fidelity_only_cost_log)(log_params)
687 updates, opt_state = scan_optimizer.update(grads, opt_state, log_params)
688 log_params = optax.apply_updates(log_params, updates)
689 return log_params, opt_state, loss
691 for ci, candidate in enumerate(grid):
692 log_candidate = self._to_log_space(candidate)
693 opt_state = scan_optimizer.init(log_candidate)
695 log_p = log_candidate
696 for _ in range(self.scan_steps):
697 log_p, opt_state, loss = scan_step(opt_state, log_p)
699 # Evaluate final loss
700 physical_p = self._from_log_space(log_p)
701 loss = fidelity_only_cost(physical_p)
703 if loss < best_scan_loss:
704 best_scan_loss = loss
705 best_scan_params = physical_p
706 log.info(
707 f" Candidate {ci + 1}/{len(grid)}: "
708 f"loss={loss.item():.3e} improved with "
709 f"params={physical_p}"
710 )
712 log.info(
713 f"Stage 0 complete. Best fidelity-only loss: "
714 f"{best_scan_loss.item():.3e}, "
715 f"params: {best_scan_params}"
716 )
718 return best_scan_params
720 def stage_1_opt(self, best_scan_params: jnp.ndarray, total_costs):
721 """Run multi-restart gradient optimisation (Stage 1).
723 Performs ``n_restarts`` independent AdamW runs with the full
724 (weighted) cost function. The first restart uses
725 ``best_scan_params`` directly; subsequent restarts add random
726 perturbations. Parameters specified in ``log_scale_params`` are
727 optimised in log-space.
729 Args:
730 best_scan_params: Starting parameters (typically from Stage 0).
731 total_costs: Combined cost callable.
733 Returns:
734 Tuple of ``(best_params, loss_history, best_loss)``.
735 """
737 # Wrap the cost function with log-space reparameterisation
738 def total_costs_log(log_params, *args):
739 return total_costs(self._from_log_space(log_params), *args)
741 # Build learning rate schedule
742 warmup_steps = int(self.n_steps * self.warmup_ratio)
743 end_value = self.learning_rate * self.end_lr_ratio
745 if warmup_steps > 0 or self.end_lr_ratio < 1.0:
746 schedule = optax.warmup_cosine_decay_schedule(
747 init_value=(end_value if warmup_steps > 0 else self.learning_rate),
748 peak_value=self.learning_rate,
749 warmup_steps=warmup_steps,
750 decay_steps=self.n_steps,
751 end_value=end_value,
752 )
753 else:
754 schedule = self.learning_rate
756 # Build optimiser chain with gradient clipping
757 use_clip = (
758 self.grad_clip and self.grad_clip > 0 and jnp.isfinite(self.grad_clip)
759 )
760 if use_clip:
761 optimizer = optax.chain(
762 optax.clip_by_global_norm(self.grad_clip),
763 optax.adamw(schedule),
764 )
765 else:
766 optimizer = optax.adamw(schedule)
768 @jax.jit
769 def opt_step(opt_state, log_params, *args):
770 loss, grads = jax.value_and_grad(total_costs_log)(log_params, *args)
771 updates, opt_state = optimizer.update(grads, opt_state, log_params)
772 log_params = optax.apply_updates(log_params, updates)
773 return log_params, opt_state, loss
775 # Use the best from grid scan as starting point
776 start_params = best_scan_params
778 global_best_loss = jnp.inf
779 global_best_params = start_params
780 global_best_history = []
781 restart_key = self.random_key
783 for restart in range(self.n_restarts):
784 if restart == 0:
785 params = start_params
786 else:
787 # Perturb the starting point
788 restart_key, sub_key = jax.random.split(restart_key)
789 noise = jax.random.normal(sub_key, shape=start_params.shape)
790 scale = (
791 jnp.maximum(jnp.abs(start_params), 0.1) * self.restart_noise_scale
792 )
793 params = start_params + noise * scale
794 # Ensure log-scaled params remain positive before
795 # conversion (evolution time at index -1 is always
796 # included since _to_log_space uses jnp.abs anyway,
797 # but we keep values positive for readability).
798 params = params.at[-1].set(jnp.abs(params[-1]))
799 for idx in self.log_scale_params:
800 i = idx if idx >= 0 else len(params) + idx
801 params = params.at[i].set(jnp.abs(params[i]))
802 log.info(
803 f"Restart {restart + 1}/{self.n_restarts} "
804 f"with perturbed params: {params}"
805 )
807 # Convert to log-space for optimisation
808 log_params = self._to_log_space(params)
809 opt_state = optimizer.init(log_params)
811 loss = total_costs(params)
812 loss_history = [loss]
813 best_loss = loss
814 best_pulse_params = params
816 for step in range(self.n_steps):
817 if step % self.log_interval == 0:
818 restart_tag = (
819 f" [restart {restart + 1}/{self.n_restarts}]"
820 if self.n_restarts > 1
821 else ""
822 )
823 log.info(
824 f"Step {step}/{self.n_steps}, "
825 f"Loss: {loss_history[-1].item():.3e}"
826 f"{restart_tag}"
827 )
829 log_params, opt_state, loss = opt_step(opt_state, log_params)
831 if loss < best_loss:
832 log.debug(f"Best set of params found at step {step}")
833 best_loss = loss
834 best_pulse_params = self._from_log_space(log_params)
836 loss_history.append(loss)
838 log.info(
839 f"Restart {restart + 1}/{self.n_restarts} finished "
840 f"with best loss: {best_loss.item():.3e}"
841 )
843 if best_loss < global_best_loss:
844 global_best_loss = best_loss
845 global_best_params = best_pulse_params
846 global_best_history = loss_history
848 return global_best_params, global_best_history, global_best_loss
850 def optimize(self, wires):
851 """Decorator factory that optimises pulse parameters for a gate.
853 Usage::
855 opt = qoc.optimize(wires=1)
856 best_params, loss_history = opt(qoc.create_RX)()
858 Args:
859 wires: Number of qubits the gate acts on.
861 Returns:
862 A decorator that accepts a circuit-factory function and
863 returns a callable ``(init_pulse_params=None) ->
864 (best_params, loss_history)``.
865 """
867 def decorator(create_circuits):
868 def wrapper(init_pulse_params: jnp.ndarray = None):
869 """
870 Optimise pulse parameters for a quantum gate using a
871 multi-phase strategy:
873 Stage 0 - Grid scan (if ``scan_steps > 0``):
874 Evaluate a coarse grid of parameter candidates using
875 only the fidelity cost (ignoring phase). Each
876 candidate is refined with a few fast gradient steps.
877 The best candidate becomes the starting point for
878 Stage 1, unless the user-supplied init_pulse_params
879 are already better.
881 Stage 1 - Multi-restart gradient optimisation:
882 Run ``n_restarts`` independent Adam optimisation runs
883 with the full cost function. The first restart uses
884 the best point found so far; subsequent restarts add
885 random perturbations. Parameters at indices in
886 ``log_scale_params`` are optimised in log-space to
887 handle order-of-magnitude differences in scale.
889 Args:
890 init_pulse_params (array): Initial pulse parameters.
891 If ``None``, uses the envelope defaults from
892 :class:`PulseInformation`.
894 Returns:
895 tuple: ``(best_params, loss_history)`` from the best
896 restart.
897 """
898 pulse_circuit, target_circuit = create_circuits()
900 pulse_script = ys.Script(pulse_circuit, n_qubits=wires)
901 target_script = ys.Script(target_circuit, n_qubits=wires)
903 gate_name = create_circuits.__name__.split("_")[1]
905 if init_pulse_params is None:
906 init_pulse_params = PulseInformation.gate_by_name(gate_name).params
907 log.debug(
908 f"Initial pulse parameters for {gate_name}: {init_pulse_params}"
909 )
911 all_ckwargs = {
912 "pulse_script": pulse_script,
913 "target_script": target_script,
914 "envelope": self.envelope,
915 "n_samples": self.n_samples,
916 "t_target": self.t_target,
917 }
919 def _build_cost(name, weight):
920 """Build a Cost from a registry entry, filtering ckwargs."""
921 meta = CostFnRegistry.get(name)
922 return Cost(
923 cost=meta["fn"],
924 weight=weight,
925 ckwargs={
926 k: v
927 for k, v in all_ckwargs.items()
928 if k in meta["ckwargs_keys"]
929 },
930 )
932 total_costs = None
933 for name, weight in self.cost_fns:
934 total_costs = _build_cost(name, weight) + total_costs
936 fidelity_only_cost = _build_cost(
937 "fidelity", (1.0, 0.0) # 100% fidelity, 0% phase
938 )
940 best_scan_params = self.stage_0_opt(
941 init_pulse_params,
942 fidelity_only_cost,
943 )
945 global_best_params, global_best_history, global_best_loss = (
946 self.stage_1_opt(
947 best_scan_params,
948 total_costs,
949 )
950 )
951 self.save_results(
952 gate=gate_name,
953 fidelity=1 - global_best_loss.item(),
954 pulse_params=global_best_params,
955 )
957 return global_best_params, global_best_history
959 return wrapper
961 return decorator
963 def create_RX(self):
964 """Create pulse and target circuits for the RX gate."""
966 def pulse_circuit(w, pulse_params):
967 Gates.RX(w, 0, pulse_params=pulse_params, gate_mode="pulse")
969 def target_circuit(w):
970 op.RX(w, wires=0)
972 return pulse_circuit, target_circuit
974 def create_RY(self):
975 """Create pulse and target circuits for the RY gate."""
977 def pulse_circuit(w, pulse_params):
978 Gates.RY(w, 0, pulse_params=pulse_params, gate_mode="pulse")
980 def target_circuit(w):
981 op.RY(w, wires=0)
983 return pulse_circuit, target_circuit
def create_RZ(self):
    """Build the circuit pair used to optimise the RZ gate.

    The rotation is wrapped as H - RZ(w) - H on wire 0 in both circuits,
    so that the Z rotation becomes visible in the computational basis.

    Returns:
        ``(pulse_circuit, target_circuit)``.
    """

    def target_circuit(w):
        op.H(wires=0)
        op.RZ(w, wires=0)
        op.H(wires=0)

    def pulse_circuit(w, pulse_params):
        op.H(wires=0)
        Gates.RZ(w, 0, pulse_params=pulse_params, gate_mode="pulse")
        op.H(wires=0)

    circuits = (pulse_circuit, target_circuit)
    return circuits
def create_H(self):
    """Build the circuit pair used to optimise the Hadamard gate.

    Both circuits first apply RY(w) on wire 0 so that the input state
    varies with ``w`` (breaking symmetry), then apply the Hadamard.

    Returns:
        ``(pulse_circuit, target_circuit)``.
    """

    def prepare(w):
        # Symmetry-breaking, w-dependent input state.
        op.RY(w, wires=0)

    def pulse_circuit(w, pulse_params):
        prepare(w)
        Gates.H(0, pulse_params=pulse_params, gate_mode="pulse")

    def target_circuit(w):
        prepare(w)
        op.H(wires=0)

    return pulse_circuit, target_circuit
def create_Rot(self):
    """Build the circuit pair used to optimise the general Rot gate.

    After a Hadamard on wire 0, both circuits apply Rot with the three
    angles ``(w, 2w, 3w)`` derived from the single parameter ``w``.

    Returns:
        ``(pulse_circuit, target_circuit)``.
    """

    def angles(w):
        # All three Rot angles are tied to the one sampled parameter.
        return w, w * 2, w * 3

    def pulse_circuit(w, pulse_params):
        op.H(wires=0)
        Gates.Rot(*angles(w), 0, pulse_params=pulse_params, gate_mode="pulse")

    def target_circuit(w):
        op.H(wires=0)
        op.Rot(*angles(w), wires=0)

    return pulse_circuit, target_circuit
def create_CX(self):
    """Create pulse and target circuits for the CX (CNOT) gate.

    The control (wire 0) is put into a ``w``-dependent superposition with
    RY(w), while the target (wire 1) is left in ``|0>``. The CX then
    entangles the two wires, so its action shows up in the output state.

    The target must NOT be prepared with a Hadamard: ``|+>`` is a +1
    eigenstate of X, so CX would act as the identity on every prepared
    input. The target circuit would then not depend on the CX at all, and
    the optimiser could not tell a correct CX pulse from an idle one.

    Returns:
        ``(pulse_circuit, target_circuit)``.
    """

    def pulse_circuit(w, pulse_params):
        op.RY(w, wires=0)
        Gates.CX(wires=[0, 1], pulse_params=pulse_params, gate_mode="pulse")

    def target_circuit(w):
        op.RY(w, wires=0)
        op.CX(wires=[0, 1])

    return pulse_circuit, target_circuit
def create_CY(self):
    """Build the circuit pair used to optimise the CY gate.

    Both circuits share the same input preparation: RX(w) on the control
    (wire 0) and a Hadamard on the target (wire 1).

    Returns:
        ``(pulse_circuit, target_circuit)``.
    """

    def prepare(w):
        op.RX(w, wires=0)
        op.H(wires=1)

    def pulse_circuit(w, pulse_params):
        prepare(w)
        Gates.CY(wires=[0, 1], pulse_params=pulse_params, gate_mode="pulse")

    def target_circuit(w):
        prepare(w)
        op.CY(wires=[0, 1])

    return pulse_circuit, target_circuit
def create_CZ(self):
    """Build the circuit pair used to optimise the CZ gate.

    Both circuits share the same input preparation: RY(w) on the control
    (wire 0) and a Hadamard on the target (wire 1).

    Returns:
        ``(pulse_circuit, target_circuit)``.
    """

    def prepare(w):
        op.RY(w, wires=0)
        op.H(wires=1)

    def pulse_circuit(w, pulse_params):
        prepare(w)
        Gates.CZ(wires=[0, 1], pulse_params=pulse_params, gate_mode="pulse")

    def target_circuit(w):
        prepare(w)
        op.CZ(wires=[0, 1])

    return pulse_circuit, target_circuit
def create_CRX(self):
    """Build the circuit pair used to optimise the CRX gate.

    A Hadamard on the control (wire 0) precedes the controlled rotation
    CRX(w) on wires ``[0, 1]`` in both circuits.

    Returns:
        ``(pulse_circuit, target_circuit)``.
    """

    def target_circuit(w):
        op.H(wires=0)
        op.CRX(w, wires=[0, 1])

    def pulse_circuit(w, pulse_params):
        op.H(wires=0)
        Gates.CRX(w, wires=[0, 1], pulse_params=pulse_params, gate_mode="pulse")

    circuits = (pulse_circuit, target_circuit)
    return circuits
def create_CRY(self):
    """Build the circuit pair used to optimise the CRY gate.

    A Hadamard on the control (wire 0) precedes the controlled rotation
    CRY(w) on wires ``[0, 1]`` in both circuits.

    Returns:
        ``(pulse_circuit, target_circuit)``.
    """

    def target_circuit(w):
        op.H(wires=0)
        op.CRY(w, wires=[0, 1])

    def pulse_circuit(w, pulse_params):
        op.H(wires=0)
        Gates.CRY(w, wires=[0, 1], pulse_params=pulse_params, gate_mode="pulse")

    circuits = (pulse_circuit, target_circuit)
    return circuits
def create_CRZ(self):
    """Build the circuit pair used to optimise the CRZ gate.

    Both wires receive a Hadamard before the controlled rotation CRZ(w)
    on wires ``[0, 1]``, identically in both circuits.

    Returns:
        ``(pulse_circuit, target_circuit)``.
    """

    def prepare():
        op.H(wires=0)
        op.H(wires=1)

    def pulse_circuit(w, pulse_params):
        prepare()
        Gates.CRZ(w, wires=[0, 1], pulse_params=pulse_params, gate_mode="pulse")

    def target_circuit(w):
        prepare()
        op.CRZ(w, wires=[0, 1])

    return pulse_circuit, target_circuit
def optimize_all(
    self, sel_gates: Union[str, List[str]], make_log: bool
) -> None:
    """Optimise all selected gates and optionally write a log CSV.

    Args:
        sel_gates: Gates to optimise, either as a comma-separated string
            (e.g. ``"RX,CZ"``) or as a list of names (as the CLI passes).
            Including ``"all"`` selects every gate in
            ``GATES_1Q + GATES_2Q``. Names are matched exactly, so
            ``"CRX"`` does not also select ``"RX"``.
        make_log: If ``True``, write per-gate loss histories to
            ``qml_essentials/qoc_logs.csv``, one column per gate. Shorter
            histories are padded with empty cells.
    """
    # Normalise the selection to a set of exact names. Testing membership
    # against the raw string would be a substring check ("RX" in "CRX").
    if isinstance(sel_gates, str):
        requested = {name.strip() for name in sel_gates.split(",")}
    else:
        requested = {str(name).strip() for name in sel_gates}
    run_all = "all" in requested

    log_history: Dict[str, list] = {}

    for gate in self.GATES_1Q + self.GATES_2Q:
        if not (run_all or gate in requested):
            continue
        n_wires = 1 if gate in self.GATES_1Q else 2
        opt = self.optimize(wires=n_wires)
        gate_factory = getattr(self, f"create_{gate}")
        log.info(f"Optimizing {gate} gate...")
        optimized_pulse_params, loss_history = opt(gate_factory)()
        log.info(f"Optimized parameters for {gate}: {optimized_pulse_params}")

        # Materialise as a plain list so list- and array-typed histories
        # are handled uniformly below.
        loss_history = list(loss_history)
        if loss_history:
            best_fid = 1 - min(float(loss) for loss in loss_history)
            log.info(f"Best achieved fidelity: {best_fid * 100:.5f}%")
        else:
            # min() of an empty sequence would raise; report instead.
            log.warning(f"No loss history recorded for {gate}.")
        log_history.setdefault(gate, []).extend(loss_history)

    if make_log:
        # Histories may differ in length between gates. zip_longest pads the
        # shorter columns instead of truncating every column to the shortest
        # history, as plain zip would. newline="" is what the csv module
        # requires for files it writes.
        with open("qml_essentials/qoc_logs.csv", "w", newline="") as f:
            writer = csv.writer(f)
            writer.writerow(log_history.keys())
            writer.writerows(
                itertools.zip_longest(*log_history.values(), fillvalue="")
            )
# Default QOC configuration; the CLI below reads its argument defaults
# from this mapping.
default_qoc_params = dict(
    # Pulse envelope shape.
    envelope="gaussian",
    # (cost name, weight) pairs. The "fidelity" weight is a 2-tuple of
    # (fidelity component, phase component).
    cost_fns=[
        ("fidelity", (0.49999999, 0.49999999)),
        ("pulse_width", 1.5e-8),
        ("evolution_time", 5e-9),
    ],
    # Cost evaluation.
    t_target=0.5,
    n_samples=20,
    # Optimiser and learning-rate schedule.
    n_steps=1500,
    learning_rate=1e-4,
    warmup_ratio=0.05,
    end_lr_ratio=0.01,
    log_interval=50,
    file_dir=None,
    # Random restarts and gradient clipping.
    n_restarts=3,
    restart_noise_scale=0.5,
    grad_clip=1.0,
    random_seed=1000,
    # Coarse grid scan.
    scan_steps=30,
    scan_grid_size=5,
    scan_ranges=None,
    log_scale_params=None,
)
if __name__ == "__main__":
    # argparse the selected gate
    parser = argparse.ArgumentParser(
        description="Quantum Optimal Control — pulse-level gate synthesis."
    )
    parser.add_argument(
        "--gates",
        type=str,
        nargs="+",
        default=["RX", "RY", "RZ", "CZ"],
        choices=QOC.GATES_1Q + QOC.GATES_2Q + ["all"],
        help="Gate(s) to optimize.",
    )
    # NOTE: with store_true and default=True this flag cannot change the
    # value; it is kept for CLI compatibility. Use --no-log to disable.
    parser.add_argument(
        "--log",
        action="store_true",
        default=True,
        help="Log results to file (default: True).",
    )
    parser.add_argument(
        "--no-log",
        action="store_false",
        dest="log",
        help="Disable logging results to file.",
    )
    parser.add_argument(
        "--envelope",
        type=str,
        default=default_qoc_params["envelope"],
        choices=PulseEnvelope.available(),
        help="Pulse envelope shape to use for optimization.",
    )
    parser.add_argument(
        "--costs",
        type=str,
        nargs="+",
        default=default_qoc_params["cost_fns"],
        help=(
            "Cost functions and weights as 'name:w1,w2,...' strings. "
            "If weights are omitted the registry defaults are used. "
            f"Available: {CostFnRegistry.available()}. "
            "Example: --costs fidelity:0.5,0.3 pulse_width:0.2"
        ),
    )
    parser.add_argument(
        "--t_target",
        type=float,
        default=default_qoc_params["t_target"],
        help=(
            "Target evolution time for the 'evolution_time' cost function. "
            "All gates will be softly encouraged towards this common time."
        ),
    )
    parser.add_argument(
        "--n_steps",
        type=int,
        default=default_qoc_params["n_steps"],
        help="Number of optimisation steps per gate.",
    )
    parser.add_argument(
        "--n_samples",
        type=int,
        default=default_qoc_params["n_samples"],
        help="Number of parameter samples in [0, 2\\pi] for cost evaluation.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=default_qoc_params["learning_rate"],
        help="Peak learning rate for the AdamW optimiser.",
    )
    parser.add_argument(
        "--warmup_ratio",
        type=float,
        default=default_qoc_params["warmup_ratio"],
        help=(
            "Fraction of n_steps used for linear LR warmup (0.0-1.0). "
            "Set to 0 to start at the peak LR immediately."
        ),
    )
    parser.add_argument(
        "--end_lr_ratio",
        type=float,
        default=default_qoc_params["end_lr_ratio"],
        help=(
            "Final LR as a fraction of --learning_rate after cosine decay. "
            "Also used as the initial LR before warmup. "
            "Set to 1.0 (with --warmup_ratio 0) for a constant LR."
        ),
    )
    parser.add_argument(
        "--log_interval",
        type=int,
        default=default_qoc_params["log_interval"],
        help="Log the current loss every N steps.",
    )
    parser.add_argument(
        "--file_dir",
        type=str,
        default=default_qoc_params["file_dir"],
        help="Directory to save qoc_results.csv. Defaults to the package directory.",
    )
    parser.add_argument(
        "--n_restarts",
        type=int,
        default=default_qoc_params["n_restarts"],
        help=(
            "Number of random restarts for the optimisation. "
            "The first run uses the initial parameters as-is; "
            "subsequent runs add random perturbations. "
            "The best result across all restarts is kept."
        ),
    )
    parser.add_argument(
        "--restart_noise_scale",
        type=float,
        default=default_qoc_params["restart_noise_scale"],
        help=(
            "Standard deviation of Gaussian noise added to the initial "
            "parameters for each restart, relative to parameter magnitude."
        ),
    )
    parser.add_argument(
        "--grad_clip",
        type=float,
        default=default_qoc_params["grad_clip"],
        help=(
            "Maximum global gradient norm. Gradients are clipped to this "
            "value before being passed to the optimiser. "
            "Set to 0 to disable."
        ),
    )
    parser.add_argument(
        "--random_seed",
        type=int,
        default=default_qoc_params["random_seed"],
        help="Base random seed for restart perturbations.",
    )
    parser.add_argument(
        "--scan_steps",
        type=int,
        default=default_qoc_params["scan_steps"],
        help=(
            "Number of short gradient-descent steps per candidate in the "
            "coarse grid scan (Stage 0). Set to 0 to disable the grid scan."
        ),
    )
    parser.add_argument(
        "--scan_grid_size",
        type=int,
        default=default_qoc_params["scan_grid_size"],
        help=(
            "Number of points per parameter dimension in the coarse grid. "
            "Total candidates = scan_grid_size^n_params."
        ),
    )
    parser.add_argument(
        "--scan_ranges",
        type=str,
        nargs="*",
        default=default_qoc_params["scan_ranges"],
        help=(
            "Per-parameter (lo,hi) ranges for the grid scan, given as "
            "'lo,hi' strings. One pair per pulse parameter. "
            "Example: --scan_ranges 0.5,30.0 0.05,2.0 0.05,2.0 "
            "If omitted, heuristic defaults are used."
        ),
    )

    args = parser.parse_args()
    sel_gates = args.gates  # already a list from nargs="+"
    make_log = args.log

    # Parse scan_ranges from CLI (list of "lo,hi" strings -> list of tuples).
    # A bare "--scan_ranges" (nargs="*" -> empty list) is treated like
    # omitting the flag, so the heuristic defaults apply as the help says.
    scan_ranges = None
    if args.scan_ranges:
        scan_ranges = []
        for pair in args.scan_ranges:
            try:
                # Raises ValueError for a wrong item count or a non-number.
                lo, hi = (float(value) for value in pair.split(","))
            except ValueError:
                parser.error(
                    f"--scan_ranges expects 'lo,hi' pairs of numbers, got {pair!r}"
                )
            scan_ranges.append((lo, hi))

    # Parse cost function specs from CLI. argparse only converts string
    # defaults, so when --costs is omitted the entries are still the
    # (name, weight) tuples from default_qoc_params and are used as-is.
    # Only user-supplied 'name:w1,w2' strings go through parse_cost_arg.
    cost_fns = [
        spec if isinstance(spec, tuple) else CostFnRegistry.parse_cost_arg(spec)
        for spec in args.costs
    ]

    # create logger
    log = logging.getLogger("qml_essentials.qoc")

    log.setLevel(logging.INFO)
    log.addHandler(logging.StreamHandler())

    qoc = QOC(
        envelope=args.envelope,
        cost_fns=cost_fns,
        t_target=args.t_target,
        n_steps=args.n_steps,
        n_samples=args.n_samples,
        learning_rate=args.learning_rate,
        warmup_ratio=args.warmup_ratio,
        end_lr_ratio=args.end_lr_ratio,
        log_interval=args.log_interval,
        file_dir=args.file_dir,
        n_restarts=args.n_restarts,
        restart_noise_scale=args.restart_noise_scale,
        grad_clip=args.grad_clip,
        random_seed=args.random_seed,
        scan_steps=args.scan_steps,
        scan_grid_size=args.scan_grid_size,
        scan_ranges=scan_ranges,
    )

    qoc.optimize_all(sel_gates=sel_gates, make_log=make_log)