Coverage for qml_essentials / pulses.py: 84%
532 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-05-16 10:19 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-05-16 10:19 +0000
1import os
2from contextlib import contextmanager
3from dataclasses import dataclass
4from typing import Optional, List, Union, Dict, Callable, Tuple
5import csv
6import jax.numpy as jnp
7import jax
9from qml_essentials import operations as op
10from qml_essentials import yaqsi as ys
11from qml_essentials.utils import safe_random_split
12from qml_essentials.tape import active_pulse_tape
13from qml_essentials.unitary import UnitaryGates
14import logging
16log = logging.getLogger(__name__)
@dataclass
class DecompositionStep:
    """One step in a composite pulse gate decomposition.

    A composite :class:`PulseParams` holds a list of these steps; each step
    names the child gate to apply, which wires it acts on, and how the
    parent's rotation angle is mapped onto the child.

    Attributes:
        gate: Child PulseParams object for this step.
        wire_fn: Wire selection - ``"all"``, ``"target"``, or ``"control"``.
        angle_fn: Maps parent angle(s) ``w`` to child angle.
            ``None`` means pass ``w`` through unchanged.
    """

    # Forward reference: ``PulseParams`` is defined further down the module.
    gate: "PulseParams"
    # Wire selector keyword; defaults to ``"all"``.
    wire_fn: str = "all"
    # Optional angle transform; ``None`` is the default.
    angle_fn: Optional[Callable] = None
@dataclass(frozen=True)
class PulseStateSnapshot:
    """Snapshot of the mutable global pulse configuration.

    ``frozen=True`` prevents re-assigning the fields of a snapshot; the
    ``leaf_params`` mapping itself is still an ordinary ``dict``.

    Attributes:
        envelope: Name of the pulse envelope.
        rwa: Rotating-wave-approximation flag.
        frame: Coefficient frame name.
        leaf_params: Pulse parameter arrays keyed by gate name.
    """

    envelope: str
    rwa: bool
    frame: str
    leaf_params: Dict[str, jnp.ndarray]
class PulseParams:
    """Container for hierarchical pulse parameters.

    Leaf nodes hold direct parameters; composite nodes hold a list of
    :class:`DecompositionStep` objects that describe how the gate is
    built from simpler gates.

    Attributes:
        name: Gate identifier (e.g. ``"RX"``, ``"H"``).
        decomposition: List of :class:`DecompositionStep` (composite only).
    """

    def __init__(
        self,
        name: str = "",
        params: Optional[jnp.ndarray] = None,
        decomposition: Optional[List["DecompositionStep"]] = None,
    ) -> None:
        """
        Args:
            name: Gate name.
            params: Direct pulse parameters (leaf gates).
                Mutually exclusive with *decomposition*.
            decomposition: List of :class:`DecompositionStep` (composite gates).
                Mutually exclusive with *params*.

        Raises:
            AssertionError: If both or neither of *params* and
                *decomposition* are given.
        """
        assert (params is None) != (decomposition is None), (
            "Exactly one of `params` or `decomposition` must be provided."
        )

        self.decomposition = decomposition
        # Derive _pulse_obj for backward compat with childs/leafs/split_params
        self._pulse_obj = (
            [step.gate for step in decomposition] if decomposition else None
        )

        if params is not None:
            self._params = params

        self.name = name

    def __len__(self) -> int:
        """
        Get the total number of pulse parameters.

        For composite gates, returns the accumulated count from all children.

        Returns:
            int: Total number of pulse parameters.
        """
        return len(self.params)

    def __getitem__(self, idx: int) -> Union[float, jnp.ndarray]:
        """
        Access pulse parameter(s) by index.

        For leaf gates, returns the parameter at the given index.
        For composite gates, returns parameters of the child at the given index.

        Args:
            idx (int): Index to access.

        Returns:
            Union[float, jnp.ndarray]: Parameter value or child parameters.
        """
        if self.is_leaf:
            return self.params[idx]
        return self.childs[idx].params

    def __str__(self) -> str:
        """Return string representation (gate name)."""
        return self.name

    def __repr__(self) -> str:
        """Return repr string (gate name)."""
        return self.name

    @property
    def is_leaf(self) -> bool:
        """Check if this is a leaf node (direct parameters, no children)."""
        return self._pulse_obj is None

    @property
    def size(self) -> int:
        """Get the total parameter count (alias for __len__)."""
        return len(self)

    @property
    def leafs(self) -> List["PulseParams"]:
        """
        Get all leaf nodes in the hierarchy.

        Recursively collects all leaf PulseParams objects in the tree.
        Duplicates (the same leaf object reached through several steps) are
        removed while keeping the order of first appearance, so the layout
        of :attr:`leaf_params` is deterministic.

        Returns:
            List[PulseParams]: List of unique leaf nodes.
        """
        if self.is_leaf:
            return [self]

        collected = []
        for child in self._pulse_obj:
            collected.extend(child.leafs)

        # ``dict.fromkeys`` dedupes by identity hash like ``set`` did, but
        # preserves insertion order instead of an arbitrary hash order.
        return list(dict.fromkeys(collected))

    @property
    def childs(self) -> List["PulseParams"]:
        """
        Get direct children of this node.

        Returns:
            List[PulseParams]: List of child PulseParams objects, or empty list
                if this is a leaf node.
        """
        if self.is_leaf:
            return []

        return self._pulse_obj

    @property
    def shape(self) -> List[Union[int, list]]:
        """
        Get the shape of pulse parameters.

        For leaf nodes, returns list with parameter count.
        For composite nodes, returns one entry per direct child: the
        parameter count for a leaf child, or the child's own (nested)
        shape list for a composite child.

        Returns:
            List[Union[int, list]]: Parameter shape specification.
        """
        if self.is_leaf:
            return [len(self.params)]

        entries = []
        for child in self.childs:
            child_shape = child.shape  # property, not a method
            entries.append(child_shape[0] if child.is_leaf else child_shape)

        return entries

    @property
    def params(self) -> jnp.ndarray:
        """
        Get or compute pulse parameters.

        For leaf nodes, returns internal pulse parameters.
        For composite nodes, returns concatenated parameters from all children.

        Returns:
            jnp.ndarray: Pulse parameters array.
        """
        if self.is_leaf:
            return self._params

        return jnp.concatenate(self.split_params(params=None, leafs=False))

    @params.setter
    def params(self, value: jnp.ndarray) -> None:
        """
        Set pulse parameters.

        For leaf nodes, sets internal parameters directly.
        For composite nodes, distributes values across children.

        Args:
            value (jnp.ndarray): Pulse parameters to set.

        Raises:
            AssertionError: If value is not jnp.ndarray for leaf nodes.
        """
        if self.is_leaf:
            assert isinstance(value, jnp.ndarray), "params must be a jnp.ndarray"
            self._params = value
            return

        start = 0
        for child in self.childs:
            stop = start + child.size
            child.params = value[start:stop]
            start = stop

    @property
    def leaf_params(self) -> jnp.ndarray:
        """
        Get parameters from all leaf nodes.

        Returns:
            jnp.ndarray: Concatenated parameters from all leaf nodes.
        """
        if self.is_leaf:
            return self._params

        return jnp.concatenate(self.split_params(None, leafs=True))

    @leaf_params.setter
    def leaf_params(self, value: jnp.ndarray) -> None:
        """
        Set parameters for all leaf nodes.

        Args:
            value (jnp.ndarray): Parameters to distribute across leaf nodes.
        """
        if self.is_leaf:
            self._params = value
            return

        start = 0
        for leaf in self.leafs:
            stop = start + leaf.size
            leaf.params = value[start:stop]
            start = stop

    def split_params(
        self,
        params: Optional[jnp.ndarray] = None,
        leafs: bool = False,
    ) -> List[jnp.ndarray]:
        """
        Split parameters into sub-arrays for children or leaves.

        For a leaf node nothing is split: the stored parameters (or the
        given *params*) are returned unchanged.

        Args:
            params (Optional[jnp.ndarray]): Parameters to split. If None,
                uses internal parameters.
            leafs (bool): If True, splits across leaf nodes; if False,
                splits across direct children. Defaults to False.

        Returns:
            List[jnp.ndarray]: List of parameter arrays for children or leaves.
        """
        if self.is_leaf:
            return self._params if params is None else params

        nodes = self.leafs if leafs else self.childs
        if params is None:
            return [node.params for node in nodes]

        chunks = []
        start = 0
        for node in nodes:
            stop = start + node.size
            chunks.append(params[start:stop])
            start = stop

        return chunks
class PulseEnvelope:
    """Registry of pulse envelope shapes.

    Each envelope is a pure function ``(p, t, t_c) -> amplitude`` that
    computes the pulse envelope *without* carrier modulation.  The carrier
    ``cos(omega_c * t + phi_c)`` is applied separately in the coefficient
    functions built by :meth:`build_coeff_fns`.

    Attributes:
        REGISTRY: Mapping from envelope name to metadata dict containing
            ``fn`` (callable), ``n_envelope_params`` (int), and per-gate
            default parameter arrays.
    """

    @staticmethod
    def gaussian(p, t, t_c):
        """Gaussian envelope centred at ``t_c``. ``p = [A, sigma]``."""
        A, sigma = p[0], p[1]
        return A * jnp.exp(-0.5 * ((t - t_c) / sigma) ** 2)

    @staticmethod
    def square(p, t, t_c):
        """Rectangular envelope. ``p = [A, width]``.

        Equals ``A`` where ``|t - t_c| <= width / 2`` and zero elsewhere.
        """
        A, width = p[0], p[1]
        return A * (jnp.abs(t - t_c) <= width / 2)

    @staticmethod
    def cosine(p, t, t_c):
        """Raised cosine envelope. ``p = [A, width]``."""
        A, width = p[0], p[1]
        # Clipping to [-0.5, 0.5] pins the cosine argument to [-pi/2, pi/2],
        # so the envelope is exactly zero outside the window.
        x = jnp.clip((t - t_c) / width, -0.5, 0.5)
        return A * jnp.cos(jnp.pi * x)

    @staticmethod
    def drag(p, t, t_c):
        """DRAG (Derivative Removal by Adiabatic Gate). ``p = [A, beta, sigma]``.

        Returns the Gaussian ``g`` plus ``beta`` times its time derivative.
        """
        A, beta, sigma = p[0], p[1], p[2]
        g = A * jnp.exp(-0.5 * ((t - t_c) / sigma) ** 2)
        dg = g * (-(t - t_c) / sigma**2)
        return g + beta * dg

    @staticmethod
    def sech(p, t, t_c):
        """Hyperbolic secant envelope. ``p = [A, sigma]``."""
        A, sigma = p[0], p[1]
        return A / jnp.cosh((t - t_c) / sigma)

    # ``n_envelope_params`` counts only the envelope parameters (excluding
    # the evolution time ``t`` which is always the last element of the full
    # pulse parameter vector).  Each default array therefore has
    # ``n_envelope_params + 1`` entries for the envelope-based gates.
    # ``.__func__`` unwraps the staticmethod object so the registry stores
    # the plain function.
    REGISTRY = {
        "gaussian": {
            "fn": gaussian.__func__,
            "n_envelope_params": 2,
            "defaults": {
                "RX": jnp.array(
                    [0.38009941846766804, 1.631698142660167, 3.007403822238108]
                ),
                "RY": jnp.array(
                    [0.3836652338514791, 1.616595983505249, 2.9794135093698966]
                ),
            },
        },
        "square": {
            "fn": square.__func__,
            "n_envelope_params": 2,
            "defaults": {
                "RX": jnp.array(
                    [1.209655637514602, 0.8266815576721239, 1.1483122857413859]
                ),
                "RY": jnp.array(
                    [1.0287942142779052, 0.9860505130182093, 0.9720116870310977]
                ),
            },
        },
        "cosine": {
            "fn": cosine.__func__,
            "n_envelope_params": 2,
            "defaults": {
                "RX": jnp.array([1.0, 1.0, 1.0]),
                "RY": jnp.array([1.0, 1.0, 1.0]),
            },
        },
        "drag": {
            "fn": drag.__func__,
            "n_envelope_params": 3,
            "defaults": {
                "RX": jnp.array(
                    [
                        0.326562746114197,
                        0.4002767596709071,
                        5.3228107728890315,
                        3.141300761986467,
                    ]
                ),
                "RY": jnp.array(
                    [
                        0.323287924190616,
                        0.4065017233024265,
                        7.00299644871222,
                        3.139481229843545,
                    ]
                ),
            },
        },
        "sech": {
            "fn": sech.__func__,
            "n_envelope_params": 2,
            "defaults": {
                "RX": jnp.array([1.0, 1.0, 1.0]),
                "RY": jnp.array([1.0, 1.0, 1.0]),
            },
        },
        # Envelope-independent gates: no envelope function, only the
        # parameters of the RZ and CZ leaf gates.
        "general": {
            "fn": None,
            "n_envelope_params": 0,
            "defaults": {
                "RZ": jnp.array([0.5]),
                "CZ": jnp.array([0.3183098783513154]),
            },
        },
    }

    @staticmethod
    def available() -> List[str]:
        """Return list of registered envelope names."""
        return list(PulseEnvelope.REGISTRY.keys())

    @staticmethod
    def get(name: str) -> dict:
        """Look up envelope metadata by name.

        Args:
            name: Key into :attr:`REGISTRY`.

        Returns:
            The metadata dict registered under *name*.

        Raises:
            ValueError: If *name* is not registered.
        """
        if name not in PulseEnvelope.REGISTRY:
            raise ValueError(
                f"Unknown pulse envelope '{name}'. "
                f"Available: {PulseEnvelope.available()}"
            )
        return PulseEnvelope.REGISTRY[name]

    @staticmethod
    def build_coeff_fns(
        envelope_fn: Callable,
        omega_c: float,
        omega_q: float,
        rwa: bool = True,
        frame: str = "drive",
    ) -> Tuple[Callable, Callable, Callable, Callable]:
        """Build the four interaction-picture coefficient functions.

        The lab-frame Hamiltonian is

            H(t,Π) = H_static + Σ_j S_j(t;Π) H_j ,
            S_j(t;Π) = E_j(t;Π) · cos(ω_c·t + φ_c) ,

        and the interaction-picture transform with respect to
        ``H_static = (ω_q/2)·Z`` produces

            H̃_j(t) = exp(+i H_static t) H_j exp(-i H_static t) ,
            H_I(t)  = Σ_j S_j(t) H̃_j(t) .

        For a single qubit driven on X, ``H̃_X(t) = cos(ω_q·t) X
        − sin(ω_q·t) Y``, so

            H_I(t) = Ω(t) · cos(ω_c·t + φ) ·
                     [ cos(ω_q·t) · X − sin(ω_q·t) · Y ] .

        ``rwa=True`` (default) drops the fast (~2·ω_q on resonance) terms and
        keeps only the slow envelope, yielding the analytical RWA

            H_I^RWA(t) = (Ω(t)/2) · [ cos(φ) X + sin(φ) Y ] .

        For RX (``φ = 0``) this reduces to ``(Ω/2)·X``; for RY
        (``φ = +π/2``) to ``(Ω/2)·Y``.  This is dramatically cheaper to
        integrate (no fast oscillations → adaptive ODE solver takes
        large steps).

        ``rwa=False`` keeps **both** the slow and the fast
        counter-rotating components.

        Every nested definition below (one per branch and per
        (gate, component) pair) has its own ``__code__`` object.
        NOTE(review): Python reuses those code objects on every call of this
        method, so functions built for different envelopes within the same
        branch share a ``__code__``; any cache keyed on code objects has to
        be invalidated when the envelope changes — confirm against the
        yaqsi solver cache.

        The last element of the parameter array, ``p[-1]``, multiplies
        every coefficient; envelope parameters occupy ``p[:-1]``.
        NOTE(review): this docstring historically called ``p[-1]`` the
        rotation angle ``w`` while the registry comment calls the last
        element the evolution time — confirm which is intended.

        Args:
            envelope_fn: Pure envelope function ``(p, t, t_c) -> scalar``.
            omega_c: Carrier frequency.
            omega_q: Qubit frequency (interaction-picture rotation rate).
            rwa: When ``True``, return the RWA-truncated coefficients
                (no fast counter-rotating terms). Default ``True``
            frame: Algebraic representation of the exact (non-RWA)
                coefficients.  Mathematically equivalent options:

                * ``"drive"`` (default): applies the product-to-sum identity to
                  expose the slow ``(ω_c-ω_q)`` and fast ``(ω_c+ω_q)``
                  modes explicitly,
                  ``cos(ω_c t)cos(ω_q t) =
                  ½[cos((ω_c-ω_q)t) + cos((ω_c+ω_q)t)]``.  Algebraically
                  identical to ``"lab"`` (no RWA, no information lost).
                  Primary use: combined with the ``magnus2``/``magnus4``
                  yaqsi solvers, the explicit slow/fast decomposition
                  is sometimes numerically better-conditioned and lets
                  the user pick a fixed grid based on the slow
                  frequency alone (``Δ = |ω_c-ω_q|``) when the fast
                  ``(ω_c+ω_q)`` mode is well-resolved by the chosen
                  step.
                * ``"lab"``: the literal form
                  ``Ω(t) cos(ω_c t + φ) cos(ω_q t)`` (and the analogous
                  ``-sin`` term).  Two trig multiplications per call;
                  contains the sum and difference frequencies implicitly.

                Ignored when ``rwa=True``.

        Returns:
            Tuple ``(coeff_RX_X, coeff_RX_Y, coeff_RY_X, coeff_RY_Y)``
            of coefficient functions for the X- and Y-components of the
            RX and RY interaction-picture Hamiltonians.

        Raises:
            ValueError: If *frame* is neither ``"lab"`` nor ``"drive"``
                (checked even when ``rwa=True``).
        """
        if frame not in ("lab", "drive"):
            raise ValueError(f"Unknown frame {frame!r}; expected 'lab' or 'drive'.")
        if rwa:
            # RWA-truncated coefficients (no carrier, no fast factors).
            # H_I^RWA = (Ω(t)/2) [cos(φ) X + sin(φ) Y]; we keep the
            # ``p[-1]`` scaling so the calling
            # ParametrizedHamiltonian shape is unchanged.
            half = jnp.asarray(0.5)

            def _coeff_RX_X(p, t):
                t_c = t / 2
                env = envelope_fn(p, t, t_c)
                return half * env * p[-1]

            # The vanishing components still evaluate the envelope so the
            # returned zeros have the same shape/dtype as the live term.
            def _coeff_RX_Y(p, t):  # Y component vanishes for RX (φ=0)
                t_c = t / 2
                env = envelope_fn(p, t, t_c)
                return jnp.zeros_like(half * env * p[-1])

            def _coeff_RY_X(p, t):  # X component vanishes for RY (φ=π/2)
                t_c = t / 2
                env = envelope_fn(p, t, t_c)
                return jnp.zeros_like(half * env * p[-1])

            def _coeff_RY_Y(p, t):
                t_c = t / 2
                env = envelope_fn(p, t, t_c)
                return half * env * p[-1]

            return _coeff_RX_X, _coeff_RX_Y, _coeff_RY_X, _coeff_RY_Y

        if frame == "drive":
            # Drive-frame: same exact dynamics, expressed via the
            # product-to-sum identities so the slow ``Δ = ω_c - ω_q``
            # and fast ``Σ = ω_c + ω_q`` modes appear explicitly.
            # Mathematically identical to the ``lab`` branch below.
            #
            # Identities used:
            #   cos(ω_c t) cos(ω_q t) = ½[cos(Δ t) + cos(Σ t)]
            #   cos(ω_c t) sin(ω_q t) = ½[sin(Σ t) − sin(Δ t)]
            #   −sin(ω_c t) cos(ω_q t) = −½[sin(Σ t) + sin(Δ t)]
            #   −sin(ω_c t) sin(ω_q t) = ½[cos(Σ t) − cos(Δ t)]
            # (RY uses cos(ω_c t + π/2) = −sin(ω_c t).)
            # The Y components carry the extra minus sign of the
            # ``− sin(ω_q t) Y`` term, hence the leading ``-half`` below.
            omega_d = omega_c - omega_q
            omega_s = omega_c + omega_q
            half = jnp.asarray(0.5)

            def _coeff_RX_X(p, t):
                t_c = t / 2
                env = envelope_fn(p, t, t_c)
                mod = half * (jnp.cos(omega_d * t) + jnp.cos(omega_s * t))
                return env * mod * p[-1]

            def _coeff_RX_Y(p, t):
                t_c = t / 2
                env = envelope_fn(p, t, t_c)
                mod = -half * (jnp.sin(omega_s * t) - jnp.sin(omega_d * t))
                return env * mod * p[-1]

            def _coeff_RY_X(p, t):
                t_c = t / 2
                env = envelope_fn(p, t, t_c)
                mod = -half * (jnp.sin(omega_s * t) + jnp.sin(omega_d * t))
                return env * mod * p[-1]

            def _coeff_RY_Y(p, t):
                t_c = t / 2
                env = envelope_fn(p, t, t_c)
                mod = -half * (jnp.cos(omega_s * t) - jnp.cos(omega_d * t))
                return env * mod * p[-1]

            return _coeff_RX_X, _coeff_RX_Y, _coeff_RY_X, _coeff_RY_Y

        # Remaining case: ``frame == "lab"`` with ``rwa=False``.
        # RX uses carrier phase phi = 0 so that after RWA
        #     cos(ω_q τ)·cos(ω_q τ)  averages to +1/2  → drives +X
        #    -cos(ω_q τ)·sin(ω_q τ)  averages to  0    → Y cancels
        # giving H_I^RWA ≈ (Ω/2)·X  →  U ≈ exp(-iθ/2 X), matching op.RX.
        # The exact form below KEEPS the fast 2·ω_q components.
        def _coeff_RX_X(p, t):
            t_c = t / 2
            env = envelope_fn(p, t, t_c)
            carrier = jnp.cos(omega_c * t)
            return env * carrier * jnp.cos(omega_q * t) * p[-1]

        def _coeff_RX_Y(p, t):
            t_c = t / 2
            env = envelope_fn(p, t, t_c)
            carrier = jnp.cos(omega_c * t)
            return -env * carrier * jnp.sin(omega_q * t) * p[-1]

        # RY uses carrier phase phi = +pi/2 so the RWA component drives +Y.
        def _coeff_RY_X(p, t):
            t_c = t / 2
            env = envelope_fn(p, t, t_c)
            carrier = jnp.cos(omega_c * t + jnp.pi / 2)
            return env * carrier * jnp.cos(omega_q * t) * p[-1]

        def _coeff_RY_Y(p, t):
            t_c = t / 2
            env = envelope_fn(p, t, t_c)
            carrier = jnp.cos(omega_c * t + jnp.pi / 2)
            return -env * carrier * jnp.sin(omega_q * t) * p[-1]

        return _coeff_RX_X, _coeff_RX_Y, _coeff_RY_X, _coeff_RY_Y
class PulseInformation:
    """Stores pulse parameter counts and optimized pulse parameters.

    Call :meth:`set_envelope` to switch the active pulse shape.  This
    rebuilds all :class:`PulseParams` trees so that parameter counts
    and defaults match the selected envelope.
    """

    DEFAULT_ENVELOPE: str = "drag"
    DEFAULT_RWA: bool = True
    DEFAULT_FRAME: str = "drive"
    LEAF_GATE_NAMES: Tuple[str, ...] = ("RX", "RY", "RZ", "CZ")

    # Optimized pulse parameters loaded by :meth:`update_params`, keyed by
    # the gate name found in the first CSV column.
    OPTIMIZED_PULSES: Dict[str, jnp.ndarray] = {}

    _envelope: str = DEFAULT_ENVELOPE
    # Whether to apply the rotating-wave approximation when building the
    # interaction-picture coefficient functions.
    # Default ``True`` (RWA on): the fast counter-rotating terms are
    # dropped, which is much faster to integrate.
    # Setting to ``False`` keeps the exact dynamics (no RWA).
    # See :meth:`PulseEnvelope.build_coeff_fns`.
    _rwa: bool = DEFAULT_RWA
    # Algebraic representation of the (non-RWA) coefficients.  Either
    # ``"lab"`` or ``"drive"`` (product-to-sum decomposition, the default).
    # Mathematically equivalent — see :meth:`PulseEnvelope.build_coeff_fns`
    # when ``"drive"`` is numerically advantageous (mainly with the Magnus solvers).
    _frame: str = DEFAULT_FRAME

    @classmethod
    def _build_leaf_gates(cls):
        """(Re-)create leaf PulseParams from the active envelope defaults."""
        defaults = PulseEnvelope.get(cls._envelope)["defaults"]
        general = PulseEnvelope.get("general")["defaults"]

        cls.RX = PulseParams(name="RX", params=defaults["RX"])
        cls.RY = PulseParams(name="RY", params=defaults["RY"])

        cls.RZ = PulseParams(name="RZ", params=general["RZ"])
        cls.CZ = PulseParams(name="CZ", params=general["CZ"])

    @classmethod
    def _build_composite_gates(cls):
        """(Re-)create composite PulseParams trees from current leaves."""
        cls.H = PulseParams(
            name="H",
            decomposition=[
                DecompositionStep(cls.RZ, "all", lambda w: jnp.pi),
                DecompositionStep(cls.RY, "all", lambda w: jnp.pi / 2),
            ],
        )
        cls.CX = PulseParams(
            name="CX",
            decomposition=[
                DecompositionStep(cls.H, "target", lambda w: 0.0),
                DecompositionStep(cls.CZ, "all", lambda w: 0.0),
                DecompositionStep(cls.H, "target", lambda w: 0.0),
            ],
        )
        cls.CY = PulseParams(
            name="CY",
            decomposition=[
                DecompositionStep(cls.RZ, "target", lambda w: -jnp.pi / 2),
                DecompositionStep(cls.CX, "all"),
                DecompositionStep(cls.RZ, "target", lambda w: jnp.pi / 2),
            ],
        )
        cls.CRX = PulseParams(
            name="CRX",
            decomposition=[
                DecompositionStep(cls.RZ, "target", lambda w: jnp.pi / 2),
                DecompositionStep(cls.RY, "target", lambda w: w / 2),
                DecompositionStep(cls.CX, "all", lambda w: 0.0),
                DecompositionStep(cls.RY, "target", lambda w: -w / 2),
                DecompositionStep(cls.CX, "all", lambda w: 0.0),
                DecompositionStep(cls.RZ, "target", lambda w: -jnp.pi / 2),
            ],
        )
        cls.CRY = PulseParams(
            name="CRY",
            decomposition=[
                DecompositionStep(cls.RY, "target", lambda w: w / 2),
                DecompositionStep(cls.CX, "all", lambda w: 0.0),
                DecompositionStep(cls.RY, "target", lambda w: -w / 2),
                DecompositionStep(cls.CX, "all", lambda w: 0.0),
            ],
        )
        cls.CRZ = PulseParams(
            name="CRZ",
            decomposition=[
                DecompositionStep(cls.RZ, "target", lambda w: w / 2),
                DecompositionStep(cls.CX, "all", lambda w: 0.0),
                DecompositionStep(cls.RZ, "target", lambda w: -w / 2),
                DecompositionStep(cls.CX, "all", lambda w: 0.0),
            ],
        )
        # TODO: check if we could just make this a basis gate instead
        cls.CPhase = PulseParams(
            name="CPhase",
            decomposition=[
                DecompositionStep(cls.RZ, "control", lambda w: w / 2),
                DecompositionStep(cls.RZ, "target", lambda w: w / 2),
                DecompositionStep(cls.CX, "all", lambda w: 0.0),
                DecompositionStep(cls.RZ, "target", lambda w: -w / 2),
                DecompositionStep(cls.CX, "all", lambda w: 0.0),
            ],
        )
        cls.Rot = PulseParams(
            name="Rot",
            decomposition=[
                DecompositionStep(cls.RZ, "all", lambda w: w[0]),
                DecompositionStep(cls.RY, "all", lambda w: w[1]),
                DecompositionStep(cls.RZ, "all", lambda w: w[2]),
            ],
        )
        cls.unique_gate_set = [cls.RX, cls.RY, cls.RZ, cls.CZ]

    @classmethod
    def set_envelope(
        cls,
        name: str,
        rwa: Optional[bool] = None,
        frame: Optional[str] = None,
    ) -> None:
        """Switch pulse envelope and rebuild all PulseParams trees.

        Also updates the coefficient functions used by :class:`PulseGates`.
        All arguments are validated before any class state is modified, so
        a rejected call leaves the previous configuration intact.

        Args:
            name: One of :meth:`PulseEnvelope.available`.
            rwa: If given, also update the RWA flag.  If ``None`` (the
                default), the current value of ``cls._rwa`` is kept.
                See :meth:`PulseEnvelope.build_coeff_fns` for the
                physical meaning of the flag.
            frame: If given, also update the coefficient frame
                (``"lab"`` or ``"drive"``).  ``None`` keeps the current
                value of ``cls._frame``.  Ignored when ``rwa=True`` or
                when the existing RWA flag is on.

        Raises:
            ValueError: If *frame* is not ``"lab"``/``"drive"`` or *name*
                is not a registered envelope.
        """
        # Validate everything up front; mutating ``_envelope``/``_rwa`` and
        # then failing on ``frame`` would leave a half-applied global state.
        if frame is not None and frame not in ("lab", "drive"):
            raise ValueError(f"Unknown frame {frame!r}; expected 'lab' or 'drive'.")
        info = PulseEnvelope.get(name)  # validates name

        cls._envelope = name
        if rwa is not None:
            cls._rwa = bool(rwa)
        if frame is not None:
            cls._frame = frame
        cls._build_leaf_gates()
        cls._build_composite_gates()

        # Rebuild interaction-picture coefficient functions on PulseGates.
        # Four functions: (RX_X, RX_Y, RY_X, RY_Y) — one per (gate, Pauli)
        # component of the proper interaction-picture drive Hamiltonian.
        rx_x, rx_y, ry_x, ry_y = PulseEnvelope.build_coeff_fns(
            info["fn"],
            PulseGates.omega_c,
            PulseGates.omega_q,
            rwa=cls._rwa,
            frame=cls._frame,
        )
        PulseGates._coeff_RX_X = staticmethod(rx_x)
        PulseGates._coeff_RX_Y = staticmethod(rx_y)
        PulseGates._coeff_RY_X = staticmethod(ry_x)
        PulseGates._coeff_RY_Y = staticmethod(ry_y)
        # Backward-compat aliases for older introspection (point at the
        # X-component which dominates RX, Y-component which dominates RY).
        PulseGates._coeff_Sx = staticmethod(rx_x)
        PulseGates._coeff_Sy = staticmethod(ry_y)
        PulseGates._active_envelope = name
        PulseGates._active_rwa = cls._rwa
        PulseGates._active_frame = cls._frame

        # Solvers compiled for the previous configuration must never be
        # served for the freshly built coefficient functions, and keeping
        # them would hold stale compiled programs alive; evict the whole
        # compiled-solver cache in ``Yaqsi``.
        # Lazy import to prevent circular imports.
        from qml_essentials.yaqsi import Yaqsi

        Yaqsi.clear_evolve_solver_cache()

        log.info(
            f"Pulse envelope set to '{name}' "
            f"(RWA {'on' if cls._rwa else 'off'}, frame={cls._frame})"
        )

    @classmethod
    def set_rwa(cls, rwa: bool) -> None:
        """Toggle the rotating-wave approximation for pulse coefficients.

        Rebuilds the coefficient functions for the currently active
        envelope so the change takes effect immediately.  The class
        default is ``DEFAULT_RWA`` (``True``, RWA on); pass ``False`` for
        the exact interaction picture.
        See :meth:`PulseEnvelope.build_coeff_fns` for details
        """
        cls.set_envelope(cls._envelope, rwa=bool(rwa))

    @classmethod
    def get_envelope(cls) -> str:
        """Return the name of the active pulse envelope."""
        return cls._envelope

    @classmethod
    def get_rwa(cls) -> bool:
        """Return whether the RWA flag is currently active."""
        return cls._rwa

    @classmethod
    def set_frame(cls, frame: str) -> None:
        """Switch the algebraic representation of the (non-RWA) coefficients.

        ``"lab"`` and ``"drive"`` (the class default, ``DEFAULT_FRAME``)
        are mathematically identical (no information lost, no RWA
        applied) — see :meth:`PulseEnvelope.build_coeff_fns` for when
        ``"drive"`` is useful.  Rebuilds the coefficient functions for the
        currently active envelope so the change takes effect immediately.

        Raises:
            ValueError: If *frame* is neither ``"lab"`` nor ``"drive"``.
        """
        cls.set_envelope(cls._envelope, frame=str(frame))

    @classmethod
    def get_frame(cls) -> str:
        """Return the active coefficient frame (``"lab"`` or ``"drive"``)."""
        return cls._frame

    @classmethod
    def snapshot_state(cls) -> "PulseStateSnapshot":
        """Return an immutable snapshot of the active pulse configuration.

        Leaf gates listed in ``LEAF_GATE_NAMES`` that have not been built
        yet are simply left out of the snapshot.
        """
        leaf_params = {}
        for name in cls.LEAF_GATE_NAMES:
            gate = getattr(cls, name, None)
            if gate is not None:
                # Copy so later in-place replacement of the gate's params
                # cannot alias the snapshot.
                leaf_params[name] = jnp.array(gate.params)

        return PulseStateSnapshot(
            envelope=cls._envelope,
            rwa=cls._rwa,
            frame=cls._frame,
            leaf_params=leaf_params,
        )

    @classmethod
    def restore_state(cls, snapshot: "PulseStateSnapshot") -> None:
        """Restore a snapshot produced by :meth:`snapshot_state`.

        Raises:
            ValueError: If the snapshot names a gate that is not an active
                leaf gate, or a stored array does not match the shape the
                active gate expects.
        """
        # Rebuilds every gate with envelope defaults; the stored leaf
        # parameters are written back afterwards.
        cls.set_envelope(snapshot.envelope, rwa=snapshot.rwa, frame=snapshot.frame)

        for name, params in snapshot.leaf_params.items():
            gate = cls.gate_by_name(name)
            if gate is None or not gate.is_leaf:
                raise ValueError(f"Cannot restore unknown leaf pulse gate {name!r}.")
            if gate.params.shape != params.shape:
                raise ValueError(
                    f"Snapshot for {name!r} has shape {params.shape}, "
                    f"but active gate expects {gate.params.shape}."
                )
            gate.params = params

    @classmethod
    @contextmanager
    def preserve_state(cls):
        """Temporarily preserve global pulse state across scoped mutations.

        Yields the snapshot taken on entry and restores it on exit, also
        when the body raises.
        """
        snapshot = cls.snapshot_state()
        try:
            yield snapshot
        finally:
            cls.restore_state(snapshot)

    @classmethod
    def reset_defaults(
        cls,
        envelope: Optional[str] = None,
        rwa: Optional[bool] = None,
        frame: Optional[str] = None,
    ) -> None:
        """Reset pulse globals to canonical defaults or explicit values.

        Every argument left at ``None`` falls back to the matching
        ``DEFAULT_*`` class constant.
        """
        cls.set_envelope(
            cls.DEFAULT_ENVELOPE if envelope is None else envelope,
            rwa=cls.DEFAULT_RWA if rwa is None else rwa,
            frame=cls.DEFAULT_FRAME if frame is None else frame,
        )

    @staticmethod
    def gate_by_name(gate):
        """Return the PulseParams registered for *gate*, or ``None``.

        Args:
            gate: Either the gate name as a string or an object whose
                ``__name__`` is the gate name.
        """
        if isinstance(gate, str):
            return getattr(PulseInformation, gate, None)
        return getattr(PulseInformation, gate.__name__, None)

    @staticmethod
    def num_params(gate):
        """Return the number of pulse parameters of *gate*."""
        return len(PulseInformation.gate_by_name(gate))

    @staticmethod
    def update_params(path=f"{os.getcwd()}/qml_essentials/qoc_results.csv"):
        """Load optimized pulse parameters from a CSV file.

        Each row is read as ``gate name, fidelity, param_0, param_1, ...``
        and the parameters are stored in :attr:`OPTIMIZED_PULSES` under the
        gate name.  A missing file is logged as an error, not raised.

        Args:
            path: CSV file to read.  The default is built from the working
                directory at import time.
        """
        if os.path.isfile(path):
            log.info(f"Loading optimized pulses from {path}")
            with open(path, "r") as f:
                reader = csv.reader(f)

                for row in reader:
                    if not row:
                        # csv.reader yields [] for blank lines
                        continue
                    log.debug(
                        f"Loading optimized pulses for {row[0]}"
                        f" (Fidelity: {float(row[1]):.5f}): {row[2:]}"
                    )
                    PulseInformation.OPTIMIZED_PULSES[row[0]] = jnp.array(
                        [float(x) for x in row[2:]]
                    )
        else:
            log.error(f"No optimized pulses found at {path}")

    @staticmethod
    def shuffle_params(random_key):
        """Overwrite the parameters of every leaf gate with uniform noise.

        Args:
            random_key: JAX PRNG key; split once per gate in
                ``unique_gate_set``.
        """
        log.info(
            f"Shuffling optimized pulses with random key {random_key}"
            f" of gates {PulseInformation.unique_gate_set}"
        )
        for gate in PulseInformation.unique_gate_set:
            random_key, sub_key = safe_random_split(random_key)
            gate.params = jax.random.uniform(sub_key, (len(gate),))
955class PulseGates:
956 """Pulse-level implementations of quantum gates.
958 Implements quantum gates using time-dependent Hamiltonians and pulse
959 sequences, following the approach from https://doi.org/10.5445/IR/1000184129.
960 The active pulse envelope is selected via
961 :meth:`PulseInformation.set_envelope`.
963 Attributes:
964 omega_q: Qubit frequency (10π).
965 omega_c: Carrier frequency (10π).
966 _active_envelope: Name of the currently active envelope shape.
967 """
969 # NOTE: Implementation of S, RX, RY, RZ, CZ, CNOT/CX and H pulse level
970 # gates closely follow https://doi.org/10.5445/IR/1000184129
971 omega_q = 10 * jnp.pi
972 omega_c = 10 * jnp.pi
974 X = jnp.array([[0, 1], [1, 0]])
975 Y = jnp.array([[0, -1j], [1j, 0]])
976 Z = jnp.array([[1, 0], [0, -1]])
978 Id = jnp.eye(2, dtype=jnp.complex64)
980 _H_CZ = (jnp.pi / 4) * (
981 jnp.kron(Id, Id) - jnp.kron(Z, Id) - jnp.kron(Id, Z) + jnp.kron(Z, Z)
982 )
984 _H_corr = jnp.pi / 2 * jnp.eye(2, dtype=jnp.complex64)
986 _active_envelope: str = "gaussian"
987 # Mirrors :attr:`PulseInformation._rwa`; kept here for introspection
988 # of which coefficient regime the active ``_coeff_*`` functions
989 # implement. Updated by :meth:`PulseInformation.set_envelope` /
990 # :meth:`PulseInformation.set_rwa`.
991 _active_rwa: bool = True
992 _active_frame: str = "drive"
994 # Default coefficient functions for the gaussian envelope; the active
995 # envelope's `set_envelope` will overwrite these. Each gate uses two
996 # coefficients (X- and Y-component of the proper interaction-picture
997 # drive Hamiltonian).
999 @staticmethod
1000 def _coeff_RX_X(p, t):
1001 """RX coefficient for the X term (gaussian default)."""
1002 t_c = t / 2
1003 env = PulseEnvelope.gaussian(p, t, t_c)
1004 carrier = jnp.cos(PulseGates.omega_c * t)
1005 return env * carrier * jnp.cos(PulseGates.omega_q * t) * p[-1]
1007 @staticmethod
1008 def _coeff_RX_Y(p, t):
1009 """RX coefficient for the Y term (gaussian default)."""
1010 t_c = t / 2
1011 env = PulseEnvelope.gaussian(p, t, t_c)
1012 carrier = jnp.cos(PulseGates.omega_c * t)
1013 return -env * carrier * jnp.sin(PulseGates.omega_q * t) * p[-1]
1015 @staticmethod
1016 def _coeff_RY_X(p, t):
1017 """RY coefficient for the X term (gaussian default)."""
1018 t_c = t / 2
1019 env = PulseEnvelope.gaussian(p, t, t_c)
1020 carrier = jnp.cos(PulseGates.omega_c * t + jnp.pi / 2)
1021 return env * carrier * jnp.cos(PulseGates.omega_q * t) * p[-1]
1023 @staticmethod
1024 def _coeff_RY_Y(p, t):
1025 """RY coefficient for the Y term (gaussian default)."""
1026 t_c = t / 2
1027 env = PulseEnvelope.gaussian(p, t, t_c)
1028 carrier = jnp.cos(PulseGates.omega_c * t + jnp.pi / 2)
1029 return -env * carrier * jnp.sin(PulseGates.omega_q * t) * p[-1]
1031 # Backward-compat aliases (resolve to the dominant component of each gate).
1032 _coeff_Sx = _coeff_RX_X
1033 _coeff_Sy = _coeff_RY_Y
1035 @staticmethod
1036 def _coeff_Sz(p, t):
1037 """Coefficient function for RZ pulse: p * w."""
1038 return p[0] * p[1]
    @staticmethod
    def _coeff_Sc(p, t):
        """Constant coefficient for H correction phase.

        Ignores both the parameters ``p`` and the time ``t`` and always
        returns ``-1.0``.
        """
        return -1.0
@staticmethod
def _coeff_Scz(p, t):
    """Constant-in-time coefficient of the CZ pulse.

    Args:
        p: Pulse parameter of the CZ gate.
        t: Evaluation time; not used.

    Returns:
        ``p`` multiplied by pi.
    """
    half_turn = jnp.pi
    return p * half_turn
@staticmethod
def _record_pulse_event(gate_name, w, wires, pulse_params, parent=None):
    """Log a leaf-gate pulse on the active pulse tape, if one is recording.

    Leaf gate methods (RX, RY, RZ, CZ) call this so that
    :func:`~qml_essentials.tape.pulse_recording` can collect events
    without the caller having to know about the tape.

    Args:
        gate_name: Name of the leaf gate; used to look up ``LEAF_META`` and
            the gate's entry in :class:`PulseInformation`.
        w: Rotation angle of the gate. For non-physical gates a ``list``
            is recorded as ``0.0``.
        wires: Wire index or iterable of wire indices.
        pulse_params: Pulse parameters, passed through the gate's
            ``split_params`` before being recorded.
        parent: Optional value stored as the event's ``parent``.

    Returns:
        None. Does nothing when no pulse tape is active.
    """
    tape = active_pulse_tape()
    if tape is None:
        # Nobody is recording.
        return

    from qml_essentials.drawing import PulseEvent, LEAF_META

    leaf_meta = LEAF_META.get(gate_name, {})
    wire_list = [wires] if isinstance(wires, int) else list(wires)

    if leaf_meta.get("physical", False):
        # Physical pulse: record the active envelope function, its
        # parameters (all but the last entry) and the duration (last entry).
        envelope_info = PulseEnvelope.get(PulseInformation.get_envelope())
        split = PulseInformation.gate_by_name(gate_name).split_params(pulse_params)
        envelope_part = split[:-1]
        pulse_duration = float(split[-1])
        details = dict(
            envelope_fn=envelope_info["fn"],
            envelope_params=jnp.array(envelope_part),
            w=float(w),
            duration=pulse_duration,
            carrier_phase=leaf_meta["carrier_phase"],
        )
    else:
        # No envelope: record the flattened parameters with a unit duration
        # and zero carrier phase.
        split = PulseInformation.gate_by_name(gate_name).split_params(pulse_params)
        details = dict(
            envelope_fn=None,
            envelope_params=jnp.ravel(jnp.asarray(split)),
            w=0.0 if isinstance(w, list) else float(w),
            duration=1.0,
            carrier_phase=0.0,
        )

    tape.append(
        PulseEvent(gate=gate_name, wires=wire_list, parent=parent, **details)
    )
@staticmethod
def Rot(
    phi: float,
    theta: float,
    omega: float,
    wires: Union[int, List[int]],
    pulse_params: Optional[jnp.ndarray] = None,
    noise_params: Optional[Dict[str, float]] = None,
    random_key: Optional[jax.random.PRNGKey] = None,
) -> None:
    """Apply a general three-angle rotation through its pulse decomposition.

    The three angles are handed as ``[phi, theta, omega]`` to the
    decomposition registered under ``"Rot"``.

    Args:
        phi (float): First rotation angle.
        theta (float): Second rotation angle.
        omega (float): Third rotation angle.
        wires (Union[int, List[int]]): Qubit index or indices to act on.
        pulse_params (Optional[jnp.ndarray]): Pulse parameters for the
            composing gates. If None, the stored parameters are used.
        noise_params (Optional[Dict[str, float]]): Noise parameters
            dictionary. The angles are only perturbed when it contains
            the key ``"GateError"``.
        random_key (Optional[jax.random.PRNGKey]): JAX random key, threaded
            through the three gate-error calls in order.

    Returns:
        None: Gates are applied in-place to the circuit.
    """
    angles = [phi, theta, omega]
    if noise_params is not None and "GateError" in noise_params:
        perturbed = []
        for angle in angles:
            # Each call returns the follow-up key used for the next angle.
            angle, random_key = UnitaryGates.GateError(
                angle, noise_params, random_key
            )
            perturbed.append(angle)
        angles = perturbed
    PulseGates._execute_composite("Rot", angles, wires, pulse_params)
    UnitaryGates.Noise(wires, noise_params)
@staticmethod
def PauliRot(
    pauli: str,
    theta: float,
    wires: Union[int, List[int]],
    pulse_params: Optional[jnp.ndarray] = None,
    noise_params: Optional[Dict[str, float]] = None,
    random_key: Optional[jax.random.PRNGKey] = None,
) -> None:
    """Placeholder: there is no pulse-level PauliRot.

    Raises:
        NotImplementedError: Always, regardless of the arguments.
    """
    message = "PauliRot gate is not implemented as PulseGate"
    raise NotImplementedError(message)
@staticmethod
def RX(
    w: float,
    wires: Union[int, List[int]],
    pulse_params: Optional[jnp.ndarray] = None,
    noise_params: Optional[Dict[str, float]] = None,
    random_key: Optional[jax.random.PRNGKey] = None,
) -> None:
    """Rotate about the X axis with the active pulse envelope.

    Args:
        w: Rotation angle in radians. The pulse event is recorded with this
            value, before any gate error is applied to it.
        wires: Qubit index or indices.
        pulse_params: Envelope parameters ``[env_0, ..., env_n, t]``; the
            last entry is the evolution time. Resolved through
            ``PulseInformation.RX.split_params`` (also when ``None``).
        noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
        random_key (Optional[jax.random.PRNGKey]): JAX random key passed to
            the gate-error call.
    """
    resolved = PulseInformation.RX.split_params(pulse_params)

    PulseGates._record_pulse_event("RX", w, wires, resolved)
    duration = resolved[-1]

    # Interaction-picture drive Hamiltonian for RX:
    #   H_I(τ) = Ω(τ)·cos(ω_c·τ) · [ cos(ω_q·τ)·X − sin(ω_q·τ)·Y ]
    # On resonance the slow (RWA) part is +(Ω/2)·X; the 2·ω_q
    # counter-rotating part oscillates and cancels.
    x_op = op.Hermitian(PulseGates.X, wires=wires, record=False)
    y_op = op.Hermitian(PulseGates.Y, wires=wires, record=False)
    hamiltonian = PulseGates._coeff_RX_X * x_op + PulseGates._coeff_RX_Y * y_op

    noisy_w, _ = UnitaryGates.GateError(w, noise_params, random_key)
    # Coefficient parameters are [envelope_params..., w]; the trailing
    # duration is stripped here and passed to the solver separately.
    # A single concatenate keeps the traced graph compact.
    packed = jnp.concatenate(
        [jnp.ravel(resolved[:-1]), jnp.ravel(jnp.asarray(noisy_w))]
    )
    # The X and Y terms read the same parameter array.
    ys.evolve(hamiltonian, name="RX")([packed, packed], duration)
    UnitaryGates.Noise(wires, noise_params)
@staticmethod
def RY(
    w: float,
    wires: Union[int, List[int]],
    pulse_params: Optional[jnp.ndarray] = None,
    noise_params: Optional[Dict[str, float]] = None,
    random_key: Optional[jax.random.PRNGKey] = None,
) -> None:
    """Rotate about the Y axis with the active pulse envelope.

    Args:
        w: Rotation angle in radians. The pulse event is recorded with this
            value, before any gate error is applied to it.
        wires: Qubit index or indices.
        pulse_params: Envelope parameters ``[env_0, ..., env_n, t]``; the
            last entry is the evolution time. Resolved through
            ``PulseInformation.RY.split_params`` (also when ``None``).
        noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
        random_key (Optional[jax.random.PRNGKey]): JAX random key passed to
            the gate-error call.
    """
    resolved = PulseInformation.RY.split_params(pulse_params)

    PulseGates._record_pulse_event("RY", w, wires, resolved)
    duration = resolved[-1]

    # Same interaction-picture form as RX, but with carrier phase
    # ϕ = +π/2 so that the slow RWA component drives +Y.
    x_op = op.Hermitian(PulseGates.X, wires=wires, record=False)
    y_op = op.Hermitian(PulseGates.Y, wires=wires, record=False)
    hamiltonian = PulseGates._coeff_RY_X * x_op + PulseGates._coeff_RY_Y * y_op

    # w travels inside the parameter array rather than in a closure, which
    # lets all RY calls share one JIT-compiled solver.
    noisy_w, _ = UnitaryGates.GateError(w, noise_params, random_key)
    packed = jnp.concatenate(
        [jnp.ravel(resolved[:-1]), jnp.ravel(jnp.asarray(noisy_w))]
    )
    ys.evolve(hamiltonian, name="RY")([packed, packed], duration)
    UnitaryGates.Noise(wires, noise_params)
@staticmethod
def RZ(
    w: float,
    wires: Union[int, List[int]],
    pulse_params: Optional[float] = None,
    noise_params: Optional[Dict[str, float]] = None,
    random_key: Optional[jax.random.PRNGKey] = None,
) -> None:
    """Rotate about the Z axis at pulse level.

    The Z operator is evolved for a fixed time of 1 with the constant
    coefficient ``pulse_param * w`` (see ``_coeff_Sz``); no envelope is
    involved.

    Args:
        w (float): Rotation angle in radians. The pulse event is recorded
            with this value, before any gate error is applied to it.
        wires (Union[int, List[int]]): Qubit index or indices to act on.
        pulse_params (Optional[float]): Scaling parameter of the pulse.
            Resolved through ``PulseInformation.RZ.split_params`` (also
            when ``None``); only its first element is used.
        noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
        random_key (Optional[jax.random.PRNGKey]): JAX random key passed to
            the gate-error call.

    Returns:
        None: Gate is applied in-place to the circuit.
    """
    resolved = PulseInformation.RZ.split_params(pulse_params)

    PulseGates._record_pulse_event("RZ", w, wires, resolved)

    z_op = op.Hermitian(PulseGates.Z, wires=wires, record=False)
    hamiltonian = PulseGates._coeff_Sz * z_op

    # w travels inside the parameter array ([pulse_param, w]) rather than in
    # a closure, which lets all RZ calls share one JIT-compiled solver.
    noisy_w, _ = UnitaryGates.GateError(w, noise_params, random_key)
    # The resolved parameter may be a scalar or a 1-element array: flatten
    # it and keep just the first element.
    scale = jnp.ravel(jnp.asarray(resolved))[:1]
    run = ys.evolve(hamiltonian, name="RZ")
    packed = jnp.concatenate([scale, jnp.ravel(jnp.asarray(noisy_w))])
    run([packed], 1)

    UnitaryGates.Noise(wires, noise_params)
@staticmethod
def _resolve_wires(wire_fn, wires):
    """Translate a wire selector into the wire(s) a child gate acts on.

    Args:
        wire_fn: ``"all"``, ``"target"``, or ``"control"``.
        wires: Wire(s) of the parent gate (int or iterable of ints).

    Returns:
        ``"control"`` gives the first wire and ``"target"`` the last one.
        ``"all"`` gives ``wires`` unchanged when there are several wires,
        otherwise the single wire itself.

    Raises:
        ValueError: If ``wire_fn`` is not one of the three selectors.
    """
    as_list = [wires] if isinstance(wires, int) else list(wires)
    if wire_fn == "control":
        return as_list[0]
    if wire_fn == "target":
        # With a single wire, the last element is also the first.
        return as_list[-1]
    if wire_fn == "all":
        return as_list[0] if len(as_list) <= 1 else wires
    raise ValueError(f"Unknown wire_fn: {wire_fn!r}")
@staticmethod
def _execute_composite(gate_name, w, wires, pulse_params=None):
    """Run a composite gate by dispatching each of its decomposition steps.

    The :class:`DecompositionStep` list is taken from
    :class:`PulseInformation`; every step is forwarded to the
    ``PulseGates`` method named after the step's gate.

    Args:
        gate_name: Gate name (e.g. ``"H"``, ``"CX"``).
        w: Rotation angle(s) passed to the parent gate.
        wires: Wire(s) of the parent gate.
        pulse_params: Optional pulse parameters (split across children).
    """
    gate_info = PulseInformation.gate_by_name(gate_name)
    per_step_params = gate_info.split_params(pulse_params)

    # Leaf gates that take a rotation angle.
    angle_leaves = ("RX", "RY", "RZ")
    # Composite gates that take a rotation angle, provided they actually
    # carry a decomposition.
    angle_composites = ("CRX", "CRY", "CRZ", "CPhase")

    for step, step_params in zip(gate_info.decomposition, per_step_params):
        target_wires = PulseGates._resolve_wires(step.wire_fn, wires)
        angle = w if step.angle_fn is None else step.angle_fn(w)
        name = step.gate.name
        method = getattr(PulseGates, name)
        shared = dict(wires=target_wires, pulse_params=step_params)

        if name == "Rot":
            # Rot expects (phi, theta, omega, wires, ...).
            method(*angle, **shared)
            continue

        takes_angle = name in angle_leaves or (
            name in angle_composites and step.gate.decomposition is not None
        )
        if takes_angle:
            method(angle, **shared)
        else:
            # CZ and the remaining composites (H, CX, CY, ...) take no angle.
            method(**shared)
@staticmethod
def H(
    wires: Union[int, List[int]],
    pulse_params: Optional[jnp.ndarray] = None,
    noise_params: Optional[Dict[str, float]] = None,
    random_key: Optional[jax.random.PRNGKey] = None,
) -> None:
    """Apply a Hadamard gate through its pulse decomposition.

    Runs the decomposition registered under ``"H"`` and then evolves a
    correction term (``_H_corr`` with the constant coefficient
    ``_coeff_Sc``) for a fixed time of 1.

    Args:
        wires (Union[int, List[int]]): Qubit index or indices.
        pulse_params (Optional[jnp.ndarray]): Pulse parameters for the
            composing gates, forwarded to the decomposition.
        noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
        random_key (Optional[jax.random.PRNGKey]): JAX random key for
            compatibility (not used in this gate).
    """
    PulseGates._execute_composite("H", 0.0, wires, pulse_params)

    # Correction phase that only the H gate needs.
    correction_op = op.Hermitian(PulseGates._H_corr, wires=wires, record=False)
    correction = PulseGates._coeff_Sc * correction_op
    ys.evolve(correction, name="H")([0], 1)
    UnitaryGates.Noise(wires, noise_params)
@staticmethod
def CX(
    wires: List[int],
    pulse_params: Optional[jnp.ndarray] = None,
    noise_params: Optional[Dict[str, float]] = None,
    random_key: Optional[jax.random.PRNGKey] = None,
) -> None:
    """Apply a CNOT gate through the decomposition registered under ``"CX"``.

    Args:
        wires (List[int]): Control and target qubit indices [control, target].
        pulse_params (Optional[jnp.ndarray]): Pulse parameters for the
            composing gates. If None, the stored parameters are used.
        noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
        random_key (Optional[jax.random.PRNGKey]): JAX random key for
            compatibility (not used in this gate).

    Returns:
        None: Gate is applied in-place to the circuit.
    """
    # CX has no rotation angle; 0.0 is passed as a placeholder.
    PulseGates._execute_composite(
        gate_name="CX", w=0.0, wires=wires, pulse_params=pulse_params
    )
    UnitaryGates.Noise(wires, noise_params)
@staticmethod
def CY(
    wires: List[int],
    pulse_params: Optional[jnp.ndarray] = None,
    noise_params: Optional[Dict[str, float]] = None,
    random_key: Optional[jax.random.PRNGKey] = None,
) -> None:
    """Apply a controlled-Y gate through the decomposition registered under ``"CY"``.

    Args:
        wires (List[int]): Control and target qubit indices [control, target].
        pulse_params (Optional[jnp.ndarray]): Pulse parameters for the
            composing gates. If None, the stored parameters are used.
        noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
        random_key (Optional[jax.random.PRNGKey]): JAX random key for
            compatibility (not used in this gate).
    """
    # CY has no rotation angle; 0.0 is passed as a placeholder.
    PulseGates._execute_composite(
        gate_name="CY", w=0.0, wires=wires, pulse_params=pulse_params
    )
    UnitaryGates.Noise(wires, noise_params)
@staticmethod
def CZ(
    wires: List[int],
    pulse_params: Optional[float] = None,
    noise_params: Optional[Dict[str, float]] = None,
    random_key: Optional[jax.random.PRNGKey] = None,
) -> None:
    """Apply a controlled-Z gate by evolving the ``_H_CZ`` Hamiltonian.

    The Hamiltonian is weighted by ``_coeff_Scz`` and evolved for a fixed
    time of 1.

    Args:
        wires (List[int]): Control and target qubit indices.
        pulse_params (Optional[float]): Pulse parameter handed to the
            coefficient function. If None, ``PulseInformation.CZ.params``
            is used.
        noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
        random_key (Optional[jax.random.PRNGKey]): JAX random key for
            compatibility (not used in this gate).
    """
    params = PulseInformation.CZ.params if pulse_params is None else pulse_params

    # CZ has no rotation angle; the event is recorded with w = 0.0.
    PulseGates._record_pulse_event("CZ", 0.0, wires, params)

    coupling_op = op.Hermitian(PulseGates._H_CZ, wires=wires, record=False)
    hamiltonian = PulseGates._coeff_Scz * coupling_op
    ys.evolve(hamiltonian, name="CZ")([params], 1)
    UnitaryGates.Noise(wires, noise_params)
@staticmethod
def CRX(
    w: float,
    wires: List[int],
    pulse_params: Optional[jnp.ndarray] = None,
    noise_params: Optional[Dict[str, float]] = None,
    random_key: Optional[jax.random.PRNGKey] = None,
) -> None:
    """Apply controlled-RX through the decomposition registered under ``"CRX"``.

    Args:
        w (float): Rotation angle in radians.
        wires (List[int]): Control and target qubit indices [control, target].
        pulse_params (Optional[jnp.ndarray]): Pulse parameters for the
            composing gates. If None, the stored parameters are used.
        noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
        random_key (Optional[jax.random.PRNGKey]): JAX random key passed to
            the gate-error call.
    """
    # Perturb the angle once at the parent level, exactly as CRY, CRZ and
    # CPhase do. The decomposed child gates are dispatched without
    # noise_params, so without this call a configured gate error would
    # never reach CRX.
    w, random_key = UnitaryGates.GateError(w, noise_params, random_key)
    PulseGates._execute_composite("CRX", w, wires, pulse_params)
    UnitaryGates.Noise(wires, noise_params)
@staticmethod
def CRY(
    w: float,
    wires: List[int],
    pulse_params: Optional[jnp.ndarray] = None,
    noise_params: Optional[Dict[str, float]] = None,
    random_key: Optional[jax.random.PRNGKey] = None,
) -> None:
    """Apply controlled-RY through the decomposition registered under ``"CRY"``.

    Args:
        w (float): Rotation angle in radians.
        wires (List[int]): Control and target qubit indices [control, target].
        pulse_params (Optional[jnp.ndarray]): Pulse parameters for the
            composing gates. If None, the stored parameters are used.
        noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
        random_key (Optional[jax.random.PRNGKey]): JAX random key passed to
            the gate-error call.
    """
    # The angle is perturbed once here; children receive the result.
    noisy_w, _ = UnitaryGates.GateError(w, noise_params, random_key)
    PulseGates._execute_composite(
        gate_name="CRY", w=noisy_w, wires=wires, pulse_params=pulse_params
    )
    UnitaryGates.Noise(wires, noise_params)
@staticmethod
def CRZ(
    w: float,
    wires: List[int],
    pulse_params: Optional[jnp.ndarray] = None,
    noise_params: Optional[Dict[str, float]] = None,
    random_key: Optional[jax.random.PRNGKey] = None,
) -> None:
    """Apply controlled-RZ through the decomposition registered under ``"CRZ"``.

    Args:
        w (float): Rotation angle in radians.
        wires (List[int]): Control and target qubit indices [control, target].
        pulse_params (Optional[jnp.ndarray]): Pulse parameters for the
            composing gates. If None, the stored parameters are used.
        noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
        random_key (Optional[jax.random.PRNGKey]): JAX random key passed to
            the gate-error call.
    """
    # The angle is perturbed once here; children receive the result.
    noisy_w, _ = UnitaryGates.GateError(w, noise_params, random_key)
    PulseGates._execute_composite(
        gate_name="CRZ", w=noisy_w, wires=wires, pulse_params=pulse_params
    )
    UnitaryGates.Noise(wires, noise_params)
@staticmethod
def CPhase(
    w: float,
    wires: List[int],
    pulse_params: Optional[jnp.ndarray] = None,
    noise_params: Optional[Dict[str, float]] = None,
    random_key: Optional[jax.random.PRNGKey] = None,
) -> None:
    """Apply a controlled phase shift through its pulse decomposition.

    The steps come from the decomposition registered under ``"CPhase"``.

    Args:
        w (float): Phase shift angle in radians.
        wires (List[int]): Control and target qubit indices [control, target].
        pulse_params (Optional[jnp.ndarray]): Pulse parameters for the
            composing gates. If None, the stored parameters are used.
        noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
        random_key (Optional[jax.random.PRNGKey]): JAX random key passed to
            the gate-error call.
    """
    # The angle is perturbed once here; children receive the result.
    noisy_w, _ = UnitaryGates.GateError(w, noise_params, random_key)
    PulseGates._execute_composite(
        gate_name="CPhase", w=noisy_w, wires=wires, pulse_params=pulse_params
    )
    UnitaryGates.Noise(wires, noise_params)
class PulseParamManager:
    """Cursor over a pulse-parameter array that hands out consecutive slices."""

    def __init__(self, pulse_params: jnp.ndarray):
        """
        Args:
            pulse_params: Parameter array to consume from, front to back.
        """
        self.pulse_params = pulse_params
        self.idx = 0

    def get(self, n: int):
        """Return the next ``n`` parameters and move the cursor past them.

        Args:
            n: Number of parameters to take.

        Returns:
            The slice ``pulse_params[idx : idx + n]`` after ``squeeze()``.

        Raises:
            ValueError: If fewer than ``n`` parameters remain. The cursor
                is left untouched in that case.
        """
        start, stop = self.idx, self.idx + n
        if stop > len(self.pulse_params):
            raise ValueError("Not enough pulse parameters left for this gate")
        # TODO: the squeeze drops any extra hidden dimension
        chunk = self.pulse_params[start:stop].squeeze()
        self.idx = stop
        return chunk
# Initialise PulseInformation only now that PulseGates is defined, so that
# leaf defaults, composite trees, the flags mirrored onto PulseGates, and the
# coefficient functions all agree.
PulseInformation.reset_defaults()