Coverage for qml_essentials / pulses.py: 82%
554 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-06-11 15:51 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-06-11 15:51 +0000
1import os
2from contextlib import contextmanager
3from dataclasses import dataclass
4from typing import Optional, List, Union, Dict, Callable, Tuple
5import csv
6import jax.numpy as jnp
7import jax
9from qml_essentials import jaqsi as js
10from qml_essentials.utils import safe_random_split
11from qml_essentials.tape import active_pulse_tape
12from qml_essentials.unitary import UnitaryGates
13import logging
15log = logging.getLogger(__name__)
@dataclass
class DecompositionStep:
    """One step in a composite pulse gate decomposition.

    A composite :class:`PulseParams` holds a list of these steps; each step
    names the child gate, the wire selection and the angle mapping to use.

    Attributes:
        gate: Child PulseParams object for this step.
        wire_fn: Wire selection - ``"all"``, ``"target"``, or ``"control"``.
        angle_fn: Maps parent angle(s) ``w`` to child angle.
            ``None`` means pass ``w`` through unchanged.
    """

    # String forward reference: PulseParams is defined further down the module.
    gate: "PulseParams"
    # Default selector is "all".
    wire_fn: str = "all"
    # Optional angle transform; ``None`` is the pass-through case described above.
    angle_fn: Optional[Callable] = None
@dataclass(frozen=True)
class PulseStateSnapshot:
    """Snapshot of the mutable global pulse configuration.

    Created by ``PulseInformation.snapshot_state`` and consumed by
    ``PulseInformation.restore_state``.  ``frozen=True`` prevents the fields
    from being rebound; the ``leaf_params`` dict object itself is still a
    regular (mutable) dict.
    """

    # Name of the active envelope (a key of ``PulseEnvelope.REGISTRY``).
    envelope: str
    # Value of the rotating-wave-approximation flag at snapshot time.
    rwa: bool
    # Coefficient frame at snapshot time (``"lab"`` or ``"drive"``).
    frame: str
    # Leaf gate name -> copy of that gate's parameter array.
    leaf_params: Dict[str, jnp.ndarray]
class PulseParams:
    """Container for hierarchical pulse parameters.

    Leaf nodes hold direct parameters; composite nodes hold a list of
    :class:`DecompositionStep` objects that describe how the gate is
    built from simpler gates.

    Attributes:
        name: Gate identifier (e.g. ``"RX"``, ``"H"``).
        decomposition: List of :class:`DecompositionStep` (composite only).
    """

    def __init__(
        self,
        name: str = "",
        params: Optional[jnp.ndarray] = None,
        decomposition: Optional[List["DecompositionStep"]] = None,
    ) -> None:
        """
        Args:
            name: Gate name.
            params: Direct pulse parameters (leaf gates).
                Mutually exclusive with *decomposition*.
            decomposition: List of :class:`DecompositionStep` (composite gates).
                Mutually exclusive with *params*.

        Raises:
            AssertionError: If both or neither of *params* and
                *decomposition* are given.
        """
        assert (params is None) != (decomposition is None), (
            "Exactly one of `params` or `decomposition` must be provided."
        )

        self.decomposition = decomposition
        # Derive _pulse_obj for backward compat with childs/leafs/split_params
        self._pulse_obj = (
            [step.gate for step in decomposition] if decomposition else None
        )

        if params is not None:
            self._params = params

        self.name = name

    def __len__(self) -> int:
        """
        Get the total number of pulse parameters.

        For composite gates, returns the accumulated count from all children.

        Returns:
            int: Total number of pulse parameters.
        """
        return len(self.params)

    def __getitem__(self, idx: int) -> Union[float, jnp.ndarray]:
        """
        Access pulse parameter(s) by index.

        For leaf gates, returns the parameter at the given index.
        For composite gates, returns parameters of the child at the given index.

        Args:
            idx (int): Index to access.

        Returns:
            Union[float, jnp.ndarray]: Parameter value or child parameters.
        """
        if self.is_leaf:
            return self.params[idx]
        return self.childs[idx].params

    def __str__(self) -> str:
        """Return string representation (gate name)."""
        return self.name

    def __repr__(self) -> str:
        """Return repr string (gate name)."""
        return self.name

    @property
    def is_leaf(self) -> bool:
        """Check if this is a leaf node (direct parameters, no children)."""
        return self._pulse_obj is None

    @property
    def size(self) -> int:
        """Get the total parameter count (alias for __len__)."""
        return len(self)

    @property
    def leafs(self) -> List["PulseParams"]:
        """
        Get all leaf nodes in the hierarchy.

        Recursively collects all leaf PulseParams objects in the tree.
        Each leaf appears once, in the order in which it is first
        encountered while walking the children left to right.

        Returns:
            List[PulseParams]: List of unique leaf nodes.
        """
        if self.is_leaf:
            return [self]

        collected = []
        for obj in self._pulse_obj:
            collected.extend(obj.leafs)

        # ``dict.fromkeys`` removes duplicates while keeping first-seen order.
        # A plain ``set`` would order the leaves by identity hash, which made
        # the layout of ``leaf_params`` arbitrary.
        return list(dict.fromkeys(collected))

    @property
    def childs(self) -> List["PulseParams"]:
        """
        Get direct children of this node.

        Returns:
            List[PulseParams]: List of child PulseParams objects, or empty list
                if this is a leaf node.
        """
        if self.is_leaf:
            return []

        return self._pulse_obj

    @property
    def shape(self) -> List[int]:
        """
        Get the shape of pulse parameters.

        For leaf nodes, returns a one-element list with the parameter count.
        For composite nodes, returns the flat concatenation of the child
        shapes, i.e. one parameter count per leaf occurrence in the order
        used by :attr:`params`.

        Returns:
            List[int]: Parameter shape specification.
        """
        if self.is_leaf:
            return [len(self.params)]

        shape = []
        for obj in self.childs:
            # ``shape`` is a property (not a method); extend keeps the list flat.
            shape.extend(obj.shape)

        return shape

    @property
    def params(self) -> jnp.ndarray:
        """
        Get or compute pulse parameters.

        For leaf nodes, returns internal pulse parameters.
        For composite nodes, returns concatenated parameters from all children.

        Returns:
            jnp.ndarray: Pulse parameters array.
        """
        if self.is_leaf:
            return self._params

        params = self.split_params(params=None, leafs=False)

        return jnp.concatenate(params)

    @params.setter
    def params(self, value: jnp.ndarray) -> None:
        """
        Set pulse parameters.

        For leaf nodes, sets internal parameters directly.
        For composite nodes, distributes values across children.

        Args:
            value (jnp.ndarray): Pulse parameters to set.

        Raises:
            AssertionError: If value is not jnp.ndarray for leaf nodes.
        """
        if self.is_leaf:
            assert isinstance(value, jnp.ndarray), "params must be a jnp.ndarray"
            self._params = value
            return

        # Hand each child the consecutive slice matching its own size.
        idx = 0
        for obj in self.childs:
            nidx = idx + obj.size
            obj.params = value[idx:nidx]
            idx = nidx

    @property
    def leaf_params(self) -> jnp.ndarray:
        """
        Get parameters from all leaf nodes.

        Returns:
            jnp.ndarray: Concatenated parameters from all leaf nodes, in the
                order given by :attr:`leafs`.
        """
        if self.is_leaf:
            return self._params

        params = self.split_params(None, leafs=True)

        return jnp.concatenate(params)

    @leaf_params.setter
    def leaf_params(self, value: jnp.ndarray) -> None:
        """
        Set parameters for all leaf nodes.

        Args:
            value (jnp.ndarray): Parameters to distribute across leaf nodes,
                laid out in the order given by :attr:`leafs`.
        """
        if self.is_leaf:
            self._params = value
            return

        idx = 0
        for obj in self.leafs:
            nidx = idx + obj.size
            obj.params = value[idx:nidx]
            idx = nidx

    def split_params(
        self,
        params: Optional[jnp.ndarray] = None,
        leafs: bool = False,
    ) -> List[jnp.ndarray]:
        """
        Split parameters into sub-arrays for children or leaves.

        Args:
            params (Optional[jnp.ndarray]): Parameters to split. If None,
                uses internal parameters.
            leafs (bool): If True, splits across leaf nodes; if False,
                splits across direct children. Defaults to False.

        Returns:
            List[jnp.ndarray]: List of parameter arrays for children or leaves.
                For a leaf node the parameters are returned as-is (not wrapped
                in a list).
        """
        if self.is_leaf:
            # Nothing to split on a leaf: pass the array straight through.
            return self._params if params is None else params

        objs = self.leafs if leafs else self.childs

        if params is None:
            return [obj.params for obj in objs]

        s_params = []
        idx = 0
        for obj in objs:
            nidx = idx + obj.size
            s_params.append(params[idx:nidx])
            idx = nidx

        return s_params
class PulseEnvelope:
    """Registry of pulse envelope shapes.

    Each envelope is a pure function ``(p, t, t_c) -> amplitude`` that
    computes the pulse envelope *without* carrier modulation.  The carrier
    ``cos(omega_c * t + phi_c)`` is applied separately in the coefficient
    functions built by :meth:`build_coeff_fns`.

    Attributes:
        REGISTRY: Mapping from envelope name to metadata dict containing
            ``fn`` (callable), ``n_envelope_params`` (int), and per-gate
            default parameter arrays.
    """

    @staticmethod
    def gaussian(p, t, t_c):
        """Gaussian envelope. ``p = [A, sigma]``.

        Returns ``A * exp(-0.5 * ((t - t_c) / sigma) ** 2)``.
        """
        A, sigma = p[0], p[1]
        return A * jnp.exp(-0.5 * ((t - t_c) / sigma) ** 2)

    @staticmethod
    def square(p, t, t_c):
        """Rectangular envelope. ``p = [A, width]``.

        Returns ``A`` where ``|t - t_c| <= width / 2`` and zero elsewhere.
        """
        A, width = p[0], p[1]
        # The comparison yields a boolean mask that scales the amplitude.
        return A * (jnp.abs(t - t_c) <= width / 2)

    @staticmethod
    def cosine(p, t, t_c):
        """Raised cosine envelope. ``p = [A, width]``.

        Returns ``A * cos(pi * x)`` with ``x = (t - t_c) / width`` clipped
        to ``[-0.5, 0.5]``.
        """
        A, width = p[0], p[1]
        # Clipping pins x to +-0.5 outside the window, where cos(pi * x) ~ 0.
        x = jnp.clip((t - t_c) / width, -0.5, 0.5)
        return A * jnp.cos(jnp.pi * x)

    @staticmethod
    def drag(p, t, t_c):
        """DRAG (Derivative Removal by Adiabatic Gate). ``p = [A, beta, sigma]``.

        Returns ``g + beta * dg`` where ``g`` is the Gaussian envelope and
        ``dg`` its derivative with respect to ``t``.
        """
        A, beta, sigma = p[0], p[1], p[2]
        g = A * jnp.exp(-0.5 * ((t - t_c) / sigma) ** 2)
        # d/dt of the Gaussian above.
        dg = g * (-(t - t_c) / sigma**2)
        return g + beta * dg

    @staticmethod
    def sech(p, t, t_c):
        """Hyperbolic secant envelope. ``p = [A, sigma]``.

        Returns ``A / cosh((t - t_c) / sigma)``.
        """
        A, sigma = p[0], p[1]
        return A / jnp.cosh((t - t_c) / sigma)

    # ``n_envelope_params`` counts only the envelope parameters (excluding
    # the evolution time ``t`` which is always the last element of the full
    # pulse parameter vector).
    # ``.__func__`` unwraps the staticmethod objects, which are not yet bound
    # to the class while the class body is still executing.
    REGISTRY = {
        "gaussian": {
            "fn": gaussian.__func__,
            "n_envelope_params": 2,
            "defaults": {
                "RX": jnp.array(
                    [0.38009941846766804, 1.631698142660167, 3.007403822238108]
                ),
                "RY": jnp.array(
                    [0.3836652338514791, 1.616595983505249, 2.9794135093698966]
                ),
            },
        },
        "square": {
            "fn": square.__func__,
            "n_envelope_params": 2,
            "defaults": {
                "RX": jnp.array(
                    [1.209655637514602, 0.8266815576721239, 1.1483122857413859]
                ),
                "RY": jnp.array(
                    [1.0287942142779052, 0.9860505130182093, 0.9720116870310977]
                ),
            },
        },
        "cosine": {
            "fn": cosine.__func__,
            "n_envelope_params": 2,
            "defaults": {
                "RX": jnp.array([1.0, 1.0, 1.0]),
                "RY": jnp.array([1.0, 1.0, 1.0]),
            },
        },
        "drag": {
            "fn": drag.__func__,
            "n_envelope_params": 3,
            "defaults": {
                "RX": jnp.array(
                    [
                        0.326562746114197,
                        0.4002767596709071,
                        5.3228107728890315,
                        3.141300761986467,
                    ]
                ),
                "RY": jnp.array(
                    [
                        0.323287924190616,
                        0.4065017233024265,
                        7.00299644871222,
                        3.139481229843545,
                    ]
                ),
            },
        },
        "sech": {
            "fn": sech.__func__,
            "n_envelope_params": 2,
            "defaults": {
                "RX": jnp.array([1.0, 1.0, 1.0]),
                "RY": jnp.array([1.0, 1.0, 1.0]),
            },
        },
        # Entry without an envelope function; it only carries the RZ / CZ
        # defaults (and has no RX / RY defaults).
        "general": {
            "fn": None,
            "n_envelope_params": 0,
            "defaults": {
                "RZ": jnp.array([0.5]),
                "CZ": jnp.array([0.3183098783513154]),
            },
        },
    }

    @staticmethod
    def available() -> List[str]:
        """Return list of registered envelope names."""
        return list(PulseEnvelope.REGISTRY.keys())

    @staticmethod
    def get(name: str) -> dict:
        """Look up envelope metadata by name.

        Args:
            name: Key of :attr:`REGISTRY`.

        Returns:
            The registry entry (``fn``, ``n_envelope_params``, ``defaults``).

        Raises:
            ValueError: If *name* is not registered.
        """
        if name not in PulseEnvelope.REGISTRY:
            raise ValueError(
                f"Unknown pulse envelope '{name}'. "
                f"Available: {PulseEnvelope.available()}"
            )
        return PulseEnvelope.REGISTRY[name]

    @staticmethod
    def build_coeff_fns(
        envelope_fn: Callable,
        omega_c: float,
        omega_q: float,
        rwa: bool = True,
        frame: str = "drive",
    ) -> Tuple[Callable, Callable, Callable, Callable]:
        """Build the four interaction-picture coefficient functions.

        The lab-frame Hamiltonian is

            H(t,Π) = H_static + Σ_j S_j(t;Π) H_j ,
            S_j(t;Π) = E_j(t;Π) · cos(ω_c·t + φ_c) ,

        and the interaction-picture transform with respect to
        ``H_static = (ω_q/2)·Z`` produces

            H̃_j(t) = exp(+i H_static t) H_j exp(-i H_static t) ,
            H_I(t)  = Σ_j S_j(t) H̃_j(t) .

        For a single qubit driven on X, ``H̃_X(t) = cos(ω_q·t) X
        − sin(ω_q·t) Y``, so

            H_I(t) = Ω(t) · cos(ω_c·t + φ) ·
                     [ cos(ω_q·t) · X − sin(ω_q·t) · Y ] .

        ``rwa=True`` (default) drops the fast (~2·ω_q on resonance) terms and
        keeps only the slow envelope, yielding the analytical RWA

            H_I^RWA(t) = (Ω(t)/2) · [ cos(φ) X + sin(φ) Y ] .

        For RX (``φ = 0``) this reduces to ``(Ω/2)·X``; for RY
        (``φ = +π/2``) to ``(Ω/2)·Y``.  This is dramatically cheaper to
        integrate (no fast oscillations → adaptive ODE solver takes
        large steps).

        ``rwa=False`` keeps **both** the slow and the fast
        counter-rotating components.

        Each returned function has a unique ``__code__`` object so the
        jaqsi solver cache assigns separate compiled XLA programs per
        envelope shape and per (gate, component) pair.  For that reason
        every coefficient below is written as its own ``def`` rather than
        being produced by a shared factory.

        The rotation angle ``w`` is expected as the **last** element of
        the parameter array ``p`` (i.e. ``p[-1]``).  Envelope parameters
        occupy ``p[:-1]``.

        Args:
            envelope_fn: Pure envelope function ``(p, t, t_c) -> scalar``.
            omega_c: Carrier frequency.
            omega_q: Qubit frequency (interaction-picture rotation rate).
            rwa: When ``True``, return the RWA-truncated coefficients
                (no fast counter-rotating terms).  Default ``True``
            frame: Algebraic representation of the exact (non-RWA)
                coefficients.  Mathematically equivalent options:

                * ``"drive"`` (default): applies the product-to-sum identity to
                  expose the slow ``(ω_c-ω_q)`` and fast ``(ω_c+ω_q)``
                  modes explicitly,
                  ``cos(ω_c t)cos(ω_q t) =
                  ½[cos((ω_c-ω_q)t) + cos((ω_c+ω_q)t)]``.  Algebraically
                  identical to ``"lab"`` (no RWA, no information lost).
                  Primary use: combined with the ``magnus2``/``magnus4``
                  jaqsi solvers, the explicit slow/fast decomposition
                  is sometimes numerically better-conditioned and lets
                  the user pick a fixed grid based on the slow
                  frequency alone (``Δ = |ω_c-ω_q|``) when the fast
                  ``(ω_c+ω_q)`` mode is well-resolved by the chosen
                  step.
                * ``"lab"``: the literal form
                  ``Ω(t) cos(ω_c t + φ) cos(ω_q t)`` (and the analogous
                  ``-sin`` term).  Two trig multiplications per call;
                  contains all four product frequencies implicitly.

                Ignored when ``rwa=True``.

        Returns:
            Tuple ``(coeff_RX_X, coeff_RX_Y, coeff_RY_X, coeff_RY_Y)``
            of coefficient functions for the X- and Y-components of the
            RX and RY interaction-picture Hamiltonians.

        Raises:
            ValueError: If *frame* is neither ``"lab"`` nor ``"drive"``
                (checked even when ``rwa=True``).
        """
        if frame not in ("lab", "drive"):
            raise ValueError(f"Unknown frame {frame!r}; expected 'lab' or 'drive'.")
        if rwa:
            # RWA-truncated coefficients (no carrier, no fast factors).
            # H_I^RWA = (Ω(t)/2) [cos(φ) X + sin(φ) Y]; we keep the
            # ``p[-1]`` rotation-angle scaling so the calling
            # ParametrizedHamiltonian shape is unchanged.
            # NOTE(review): here and in the branches below ``t_c`` is derived
            # from the same ``t`` that is passed to the envelope, so the
            # envelope sees ``t - t_c == t / 2`` — confirm this centring is
            # intended.
            half = jnp.asarray(0.5)

            def _coeff_RX_X(p, t):
                t_c = t / 2
                env = envelope_fn(p, t, t_c)
                return half * env * p[-1]

            def _coeff_RX_Y(p, t):  # Y component vanishes for RX (φ=0)
                t_c = t / 2
                env = envelope_fn(p, t, t_c)
                # zeros_like keeps the shape/dtype of the non-zero coefficient.
                return jnp.zeros_like(half * env * p[-1])

            def _coeff_RY_X(p, t):  # X component vanishes for RY (φ=π/2)
                t_c = t / 2
                env = envelope_fn(p, t, t_c)
                return jnp.zeros_like(half * env * p[-1])

            def _coeff_RY_Y(p, t):
                t_c = t / 2
                env = envelope_fn(p, t, t_c)
                return half * env * p[-1]

            return _coeff_RX_X, _coeff_RX_Y, _coeff_RY_X, _coeff_RY_Y

        if frame == "drive":
            # Drive-frame: same exact dynamics, expressed via the
            # product-to-sum identities so the slow ``Δ = ω_c - ω_q``
            # and fast ``Σ = ω_c + ω_q`` modes appear explicitly.
            # Mathematically identical to the ``lab`` branch below.
            #
            # Identities used:
            #   cos(ω_c t) cos(ω_q t) =  ½[cos(Δ t) + cos(Σ t)]
            #   cos(ω_c t) sin(ω_q t) =  ½[sin(Σ t) − sin(Δ t)]
            #  −sin(ω_c t) cos(ω_q t) = −½[sin(Σ t) + sin(Δ t)]
            #  −sin(ω_c t) sin(ω_q t) =  ½[cos(Σ t) − cos(Δ t)]
            # (RY uses cos(ω_c t + π/2) = −sin(ω_c t).)
            omega_d = omega_c - omega_q
            omega_s = omega_c + omega_q
            half = jnp.asarray(0.5)

            def _coeff_RX_X(p, t):
                t_c = t / 2
                env = envelope_fn(p, t, t_c)
                mod = half * (jnp.cos(omega_d * t) + jnp.cos(omega_s * t))
                return env * mod * p[-1]

            def _coeff_RX_Y(p, t):
                t_c = t / 2
                env = envelope_fn(p, t, t_c)
                mod = -half * (jnp.sin(omega_s * t) - jnp.sin(omega_d * t))
                return env * mod * p[-1]

            def _coeff_RY_X(p, t):
                t_c = t / 2
                env = envelope_fn(p, t, t_c)
                mod = -half * (jnp.sin(omega_s * t) + jnp.sin(omega_d * t))
                return env * mod * p[-1]

            def _coeff_RY_Y(p, t):
                t_c = t / 2
                env = envelope_fn(p, t, t_c)
                mod = -half * (jnp.cos(omega_s * t) - jnp.cos(omega_d * t))
                return env * mod * p[-1]

            return _coeff_RX_X, _coeff_RX_Y, _coeff_RY_X, _coeff_RY_Y

        # Remaining case: frame == "lab" with rwa=False.
        # RX uses carrier phase phi = 0 so that after RWA
        #   cos(ω_q τ)·cos(ω_q τ) averages to +1/2 → drives +X
        #   -cos(ω_q τ)·sin(ω_q τ) averages to 0  → Y cancels
        # giving H_I^RWA ≈ (Ω/2)·X → U ≈ exp(-iθ/2 X), matching op.RX.
        # The exact form below KEEPS the fast 2·ω_q components.
        def _coeff_RX_X(p, t):
            t_c = t / 2
            env = envelope_fn(p, t, t_c)
            carrier = jnp.cos(omega_c * t)
            return env * carrier * jnp.cos(omega_q * t) * p[-1]

        def _coeff_RX_Y(p, t):
            t_c = t / 2
            env = envelope_fn(p, t, t_c)
            carrier = jnp.cos(omega_c * t)
            return -env * carrier * jnp.sin(omega_q * t) * p[-1]

        # RY uses carrier phase phi = +pi/2 so the RWA component drives +Y.
        def _coeff_RY_X(p, t):
            t_c = t / 2
            env = envelope_fn(p, t, t_c)
            carrier = jnp.cos(omega_c * t + jnp.pi / 2)
            return env * carrier * jnp.cos(omega_q * t) * p[-1]

        def _coeff_RY_Y(p, t):
            t_c = t / 2
            env = envelope_fn(p, t, t_c)
            carrier = jnp.cos(omega_c * t + jnp.pi / 2)
            return -env * carrier * jnp.sin(omega_q * t) * p[-1]

        return _coeff_RX_X, _coeff_RX_Y, _coeff_RY_X, _coeff_RY_Y
class PulseInformation:
    """Stores pulse parameter counts and optimized pulse parameters.

    Call :meth:`set_envelope` to switch the active pulse shape.  This
    rebuilds all :class:`PulseParams` trees so that parameter counts
    and defaults match the selected envelope.

    All state lives on the class itself (class attributes / classmethods),
    so the configuration is global to the process.
    """

    DEFAULT_ENVELOPE: str = "drag"
    DEFAULT_RWA: bool = True
    DEFAULT_FRAME: str = "drive"
    LEAF_GATE_NAMES: Tuple[str, ...] = ("RX", "RY", "RZ", "CZ")

    _envelope: str = DEFAULT_ENVELOPE
    # Whether to apply the rotating-wave approximation when building the
    # interaction-picture coefficient functions.
    # Default ``True`` (see ``DEFAULT_RWA``): the fast counter-rotating
    # terms are dropped, which is much faster to integrate.
    # Setting to ``False`` keeps the exact dynamics (no RWA).
    # See :meth:`PulseEnvelope.build_coeff_fns`.
    _rwa: bool = DEFAULT_RWA
    # Algebraic representation of the (non-RWA) coefficients.  Either
    # ``"lab"`` or ``"drive"`` (product-to-sum decomposition).
    # Mathematically equivalent — see :meth:`PulseEnvelope.build_coeff_fns`
    # when ``"drive"`` is numerically advantageous (mainly with the Magnus solvers).
    _frame: str = DEFAULT_FRAME

    @classmethod
    def _build_leaf_gates(cls):
        """(Re-)create leaf PulseParams from the active envelope defaults.

        Raises:
            KeyError: If the active envelope has no ``RX``/``RY`` defaults
                (as is the case for the ``"general"`` registry entry).
        """
        defaults = PulseEnvelope.get(cls._envelope)["defaults"]
        general = PulseEnvelope.get("general")["defaults"]

        # Resolve every default before assigning any gate so that a missing
        # key cannot leave only some of the leaf gates rebuilt.
        rx_params, ry_params = defaults["RX"], defaults["RY"]
        rz_params, cz_params = general["RZ"], general["CZ"]

        cls.RX = PulseParams(name="RX", params=rx_params)
        cls.RY = PulseParams(name="RY", params=ry_params)

        cls.RZ = PulseParams(name="RZ", params=rz_params)
        cls.CZ = PulseParams(name="CZ", params=cz_params)

    @classmethod
    def _build_composite_gates(cls):
        """(Re-)create composite PulseParams trees from current leaves.

        Must run after :meth:`_build_leaf_gates`; the order of the
        assignments below matters because later gates reference earlier
        ones (e.g. ``CX`` uses ``H``).
        """
        cls.H = PulseParams(
            name="H",
            decomposition=[
                DecompositionStep(cls.RZ, "all", lambda w: jnp.pi),
                DecompositionStep(cls.RY, "all", lambda w: jnp.pi / 2),
            ],
        )
        cls.CX = PulseParams(
            name="CX",
            decomposition=[
                DecompositionStep(cls.H, "target", lambda w: 0.0),
                DecompositionStep(cls.CZ, "all", lambda w: 0.0),
                DecompositionStep(cls.H, "target", lambda w: 0.0),
            ],
        )
        cls.CY = PulseParams(
            name="CY",
            decomposition=[
                DecompositionStep(cls.RZ, "target", lambda w: -jnp.pi / 2),
                DecompositionStep(cls.CX, "all"),
                DecompositionStep(cls.RZ, "target", lambda w: jnp.pi / 2),
            ],
        )
        cls.CRX = PulseParams(
            name="CRX",
            decomposition=[
                DecompositionStep(cls.RZ, "target", lambda w: jnp.pi / 2),
                DecompositionStep(cls.RY, "target", lambda w: w / 2),
                DecompositionStep(cls.CX, "all", lambda w: 0.0),
                DecompositionStep(cls.RY, "target", lambda w: -w / 2),
                DecompositionStep(cls.CX, "all", lambda w: 0.0),
                DecompositionStep(cls.RZ, "target", lambda w: -jnp.pi / 2),
            ],
        )
        cls.CRY = PulseParams(
            name="CRY",
            decomposition=[
                DecompositionStep(cls.RY, "target", lambda w: w / 2),
                DecompositionStep(cls.CX, "all", lambda w: 0.0),
                DecompositionStep(cls.RY, "target", lambda w: -w / 2),
                DecompositionStep(cls.CX, "all", lambda w: 0.0),
            ],
        )
        cls.CRZ = PulseParams(
            name="CRZ",
            decomposition=[
                DecompositionStep(cls.RZ, "target", lambda w: w / 2),
                DecompositionStep(cls.CX, "all", lambda w: 0.0),
                DecompositionStep(cls.RZ, "target", lambda w: -w / 2),
                DecompositionStep(cls.CX, "all", lambda w: 0.0),
            ],
        )
        # TODO: check if we could just make this a basis gate instead
        cls.CPhase = PulseParams(
            name="CPhase",
            decomposition=[
                DecompositionStep(cls.RZ, "control", lambda w: w / 2),
                DecompositionStep(cls.RZ, "target", lambda w: w / 2),
                DecompositionStep(cls.CX, "all", lambda w: 0.0),
                DecompositionStep(cls.RZ, "target", lambda w: -w / 2),
                DecompositionStep(cls.CX, "all", lambda w: 0.0),
            ],
        )
        cls.RZZ = PulseParams(
            name="RZZ",
            decomposition=[
                DecompositionStep(cls.CX, "all", lambda w: 0.0),
                DecompositionStep(cls.RZ, "target", lambda w: w),
                DecompositionStep(cls.CX, "all", lambda w: 0.0),
            ],
        )
        cls.RXX = PulseParams(
            name="RXX",
            decomposition=[
                DecompositionStep(cls.H, "control", lambda w: 0.0),
                DecompositionStep(cls.H, "target", lambda w: 0.0),
                DecompositionStep(cls.CX, "all", lambda w: 0.0),
                DecompositionStep(cls.RZ, "target", lambda w: w),
                DecompositionStep(cls.CX, "all", lambda w: 0.0),
                DecompositionStep(cls.H, "control", lambda w: 0.0),
                DecompositionStep(cls.H, "target", lambda w: 0.0),
            ],
        )
        cls.RYY = PulseParams(
            name="RYY",
            decomposition=[
                DecompositionStep(cls.RX, "control", lambda w: jnp.pi / 2),
                DecompositionStep(cls.RX, "target", lambda w: jnp.pi / 2),
                DecompositionStep(cls.CX, "all", lambda w: 0.0),
                DecompositionStep(cls.RZ, "target", lambda w: w),
                DecompositionStep(cls.CX, "all", lambda w: 0.0),
                DecompositionStep(cls.RX, "control", lambda w: -jnp.pi / 2),
                DecompositionStep(cls.RX, "target", lambda w: -jnp.pi / 2),
            ],
        )
        cls.RZX = PulseParams(
            name="RZX",
            decomposition=[
                DecompositionStep(cls.H, "target", lambda w: 0.0),
                DecompositionStep(cls.CX, "all", lambda w: 0.0),
                DecompositionStep(cls.RZ, "target", lambda w: w),
                DecompositionStep(cls.CX, "all", lambda w: 0.0),
                DecompositionStep(cls.H, "target", lambda w: 0.0),
            ],
        )
        cls.Rot = PulseParams(
            name="Rot",
            decomposition=[
                DecompositionStep(cls.RZ, "all", lambda w: w[0]),
                DecompositionStep(cls.RY, "all", lambda w: w[1]),
                DecompositionStep(cls.RZ, "all", lambda w: w[2]),
            ],
        )
        cls.unique_gate_set = [cls.RX, cls.RY, cls.RZ, cls.CZ]

    @classmethod
    def set_envelope(
        cls,
        name: str,
        rwa: Optional[bool] = None,
        frame: Optional[str] = None,
    ) -> None:
        """Switch pulse envelope and rebuild all PulseParams trees.

        Also updates the coefficient functions used by :class:`PulseGates`.
        All arguments are validated before any class state is changed; if
        rebuilding the leaf gates fails, the previous envelope / RWA / frame
        settings are put back before the error propagates.

        Args:
            name: One of :meth:`PulseEnvelope.available`.
            rwa: If given, also update the RWA flag.  If ``None`` (the
                default), the current value of ``cls._rwa`` is kept.
                See :meth:`PulseEnvelope.build_coeff_fns` for the
                physical meaning of the flag.
            frame: If given, also update the coefficient frame
                (``"lab"`` or ``"drive"``).  ``None`` keeps the current
                value of ``cls._frame``.  Ignored when ``rwa=True`` or
                when the existing RWA flag is on.

        Raises:
            ValueError: If *name* is not a registered envelope or *frame*
                is not ``"lab"`` / ``"drive"``.
            KeyError: If the envelope provides no ``RX``/``RY`` defaults.
        """
        info = PulseEnvelope.get(name)  # validates name
        # Validate the frame up front: previously ``_envelope`` and ``_rwa``
        # were already overwritten by the time a bad frame was rejected.
        if frame is not None and frame not in ("lab", "drive"):
            raise ValueError(f"Unknown frame {frame!r}; expected 'lab' or 'drive'.")

        previous = (cls._envelope, cls._rwa, cls._frame)
        cls._envelope = name
        if rwa is not None:
            cls._rwa = bool(rwa)
        if frame is not None:
            cls._frame = frame
        try:
            cls._build_leaf_gates()
        except Exception:
            # Roll back so a failed switch leaves the old configuration active.
            cls._envelope, cls._rwa, cls._frame = previous
            raise
        cls._build_composite_gates()

        # Rebuild interaction-picture coefficient functions on PulseGates.
        # Four functions: (RX_X, RX_Y, RY_X, RY_Y) — one per (gate, Pauli)
        # component of the proper interaction-picture drive Hamiltonian.
        rx_x, rx_y, ry_x, ry_y = PulseEnvelope.build_coeff_fns(
            info["fn"],
            PulseGates.omega_c,
            PulseGates.omega_q,
            rwa=cls._rwa,
            frame=cls._frame,
        )
        PulseGates._coeff_RX_X = staticmethod(rx_x)
        PulseGates._coeff_RX_Y = staticmethod(rx_y)
        PulseGates._coeff_RY_X = staticmethod(ry_x)
        PulseGates._coeff_RY_Y = staticmethod(ry_y)
        # Backward-compat aliases for older introspection (point at the
        # X-component which dominates RX, Y-component which dominates RY).
        PulseGates._coeff_Sx = staticmethod(rx_x)
        PulseGates._coeff_Sy = staticmethod(ry_y)
        PulseGates._active_envelope = name
        PulseGates._active_rwa = cls._rwa
        PulseGates._active_frame = cls._frame

        # The compiled-solver cache in ``Evolution`` is keyed on the code
        # objects of the coefficient functions.  Rebuilding the coeff
        # fns above produced fresh code objects, so any cached solver
        # is now unreachable from the live coefficient functions and
        # must be evicted to avoid both (a) holding compiled programs
        # for a previous configuration alive forever and (b) returning
        # a stale program if ``id`` collisions ever leaked through.
        js.Evolution.clear_evolve_solver_cache()

        log.info(
            f"Pulse envelope set to '{name}' "
            f"(RWA {'on' if cls._rwa else 'off'}, frame={cls._frame})"
        )

    @classmethod
    def set_rwa(cls, rwa: bool) -> None:
        """Toggle the rotating-wave approximation for pulse coefficients.

        Rebuilds the coefficient functions for the currently active
        envelope so the change takes effect immediately.  The class default
        is ``DEFAULT_RWA`` (``True``, RWA on); pass ``False`` for the exact
        interaction picture.
        See :meth:`PulseEnvelope.build_coeff_fns` for details
        """
        cls.set_envelope(cls._envelope, rwa=bool(rwa))

    @classmethod
    def get_envelope(cls) -> str:
        """Return the name of the active pulse envelope."""
        return cls._envelope

    @classmethod
    def get_rwa(cls) -> bool:
        """Return whether the RWA flag is currently active."""
        return cls._rwa

    @classmethod
    def set_frame(cls, frame: str) -> None:
        """Switch the algebraic representation of the (non-RWA) coefficients.

        ``"drive"`` (default, see ``DEFAULT_FRAME``) and ``"lab"`` are
        mathematically identical (no information lost, no RWA applied) — see
        :meth:`PulseEnvelope.build_coeff_fns` for when ``"drive"`` is
        useful.  Rebuilds the coefficient functions for the currently
        active envelope so the change takes effect immediately.

        Raises:
            ValueError: If *frame* is not ``"lab"`` or ``"drive"``.
        """
        cls.set_envelope(cls._envelope, frame=str(frame))

    @classmethod
    def get_frame(cls) -> str:
        """Return the active coefficient frame (``"lab"`` or ``"drive"``)."""
        return cls._frame

    @classmethod
    def snapshot_state(cls) -> PulseStateSnapshot:
        """Return an immutable snapshot of the active pulse configuration.

        Leaf gates listed in ``LEAF_GATE_NAMES`` that do not exist yet on the
        class are skipped; existing ones have their parameters copied.
        """
        leaf_params = {}
        for name in cls.LEAF_GATE_NAMES:
            gate = getattr(cls, name, None)
            if gate is not None:
                # ``jnp.array`` builds a new array from the current parameters.
                leaf_params[name] = jnp.array(gate.params)

        return PulseStateSnapshot(
            envelope=cls._envelope,
            rwa=cls._rwa,
            frame=cls._frame,
            leaf_params=leaf_params,
        )

    @classmethod
    def restore_state(cls, snapshot: PulseStateSnapshot) -> None:
        """Restore a snapshot produced by :meth:`snapshot_state`.

        Raises:
            ValueError: If the snapshot names a gate that is not a known
                leaf, or its parameter shape does not match the active gate.
        """
        # Rebuilds every gate from defaults; the stored leaf parameters are
        # then written back on top.
        cls.set_envelope(snapshot.envelope, rwa=snapshot.rwa, frame=snapshot.frame)

        for name, params in snapshot.leaf_params.items():
            gate = cls.gate_by_name(name)
            if gate is None or not gate.is_leaf:
                raise ValueError(f"Cannot restore unknown leaf pulse gate {name!r}.")
            if gate.params.shape != params.shape:
                raise ValueError(
                    f"Snapshot for {name!r} has shape {params.shape}, "
                    f"but active gate expects {gate.params.shape}."
                )
            gate.params = params

    @classmethod
    @contextmanager
    def preserve_state(cls):
        """Temporarily preserve global pulse state across scoped mutations.

        Yields the snapshot taken on entry and restores it on exit, even if
        the body raises.
        """
        snapshot = cls.snapshot_state()
        try:
            yield snapshot
        finally:
            cls.restore_state(snapshot)

    @classmethod
    def reset_defaults(
        cls,
        envelope: Optional[str] = None,
        rwa: Optional[bool] = None,
        frame: Optional[str] = None,
    ) -> None:
        """Reset pulse globals to canonical defaults or explicit values.

        Any argument left as ``None`` falls back to the matching
        ``DEFAULT_*`` class constant.
        """
        cls.set_envelope(
            cls.DEFAULT_ENVELOPE if envelope is None else envelope,
            rwa=cls.DEFAULT_RWA if rwa is None else rwa,
            frame=cls.DEFAULT_FRAME if frame is None else frame,
        )

    @staticmethod
    def gate_by_name(gate):
        """Return the PulseParams registered for *gate*, or ``None``.

        Args:
            gate: Either the gate name as a string or an object whose
                ``__name__`` is the gate name.
        """
        if isinstance(gate, str):
            return getattr(PulseInformation, gate, None)
        return getattr(PulseInformation, gate.__name__, None)

    @staticmethod
    def num_params(gate):
        """Return the number of pulse parameters of *gate*.

        Raises:
            TypeError: If the gate is unknown (``len(None)``).
        """
        return len(PulseInformation.gate_by_name(gate))

    @staticmethod
    def update_params(path=f"{os.getcwd()}/qml_essentials/qoc_results.csv"):
        """Load optimized pulse parameters from a CSV file.

        Each non-empty row is read as ``gate name, fidelity, param, ...`` and
        the parameters are stored in ``PulseInformation.OPTIMIZED_PULSES``
        under the gate name.  A missing file is logged as an error and
        otherwise ignored.

        Args:
            path: CSV file to read.  Note that the default is computed from
                the working directory at the time this module is imported.
        """
        if not os.path.isfile(path):
            log.error(f"No optimized pulses found at {path}")
            return

        # The class body does not define OPTIMIZED_PULSES; create it on first
        # use so that loading does not fail with AttributeError.
        if not hasattr(PulseInformation, "OPTIMIZED_PULSES"):
            PulseInformation.OPTIMIZED_PULSES = {}

        log.info(f"Loading optimized pulses from {path}")
        with open(path, "r") as f:
            reader = csv.reader(f)

            for row in reader:
                if not row:
                    # csv.reader yields [] for blank lines; nothing to load.
                    continue
                log.debug(
                    f"Loading optimized pulses for {row[0]}\
                        (Fidelity: {float(row[1]):.5f}): {row[2:]}"
                )
                PulseInformation.OPTIMIZED_PULSES[row[0]] = jnp.array(
                    [float(x) for x in row[2:]]
                )

    @staticmethod
    def shuffle_params(random_key):
        """Overwrite the parameters of every gate in ``unique_gate_set``.

        Each gate receives ``jax.random.uniform`` samples of its own length,
        drawn with a fresh sub-key split off *random_key*.
        """
        log.info(
            f"Shuffling optimized pulses with random key {random_key}\
                of gates {PulseInformation.unique_gate_set}"
        )
        for gate in PulseInformation.unique_gate_set:
            random_key, sub_key = safe_random_split(random_key)
            gate.params = jax.random.uniform(sub_key, (len(gate),))
class PulseGates:
    """Pulse-level implementations of quantum gates.

    Implements quantum gates using time-dependent Hamiltonians and pulse
    sequences, following the approach from https://doi.org/10.5445/IR/1000184129.
    The active pulse envelope is selected via
    :meth:`PulseInformation.set_envelope`.

    Attributes:
        omega_q: Qubit frequency (10π).
        omega_c: Carrier frequency (10π).
        _active_envelope: Name of the currently active envelope shape.
    """

    # NOTE: Implementation of S, RX, RY, RZ, CZ, CNOT/CX and H pulse level
    # gates closely follow https://doi.org/10.5445/IR/1000184129
    omega_q = 10 * jnp.pi
    omega_c = 10 * jnp.pi

    # Single-qubit Pauli matrices.
    X = jnp.array([[0, 1], [1, 0]])
    Y = jnp.array([[0, -1j], [1j, 0]])
    Z = jnp.array([[1, 0], [0, -1]])

    Id = jnp.eye(2, dtype=jnp.complex64)

    # Two-qubit generator used for CZ: (π/4)·(II − ZI − IZ + ZZ).
    _H_CZ = (jnp.pi / 4) * (
        jnp.kron(Id, Id) - jnp.kron(Z, Id) - jnp.kron(Id, Z) + jnp.kron(Z, Z)
    )

    # (π/2)·identity; presumably the phase-correction term paired with
    # ``_coeff_Sc`` for H — confirm against the gate implementations.
    _H_corr = jnp.pi / 2 * jnp.eye(2, dtype=jnp.complex64)

    _active_envelope: str = "gaussian"
    # Mirrors :attr:`PulseInformation._rwa`; kept here for introspection
    # of which coefficient regime the active ``_coeff_*`` functions
    # implement.  Updated by :meth:`PulseInformation.set_envelope` /
    # :meth:`PulseInformation.set_rwa`.
    # NOTE(review): the class-level default ``_coeff_*`` functions below keep
    # the carrier and cos/sin(omega_q t) factors (no RWA, lab form) while
    # these flags default to RWA on / "drive" — presumably reconciled by a
    # ``set_envelope`` call elsewhere; confirm.
    _active_rwa: bool = True
    _active_frame: str = "drive"

    # Default coefficient functions for the gaussian envelope; the active
    # envelope's `set_envelope` will overwrite these.  Each gate uses two
    # coefficients (X- and Y-component of the proper interaction-picture
    # drive Hamiltonian).

1037 @staticmethod
1038 def _coeff_RX_X(p, t):
1039 """RX coefficient for the X term (gaussian default)."""
1040 t_c = t / 2
1041 env = PulseEnvelope.gaussian(p, t, t_c)
1042 carrier = jnp.cos(PulseGates.omega_c * t)
1043 return env * carrier * jnp.cos(PulseGates.omega_q * t) * p[-1]
1045 @staticmethod
1046 def _coeff_RX_Y(p, t):
1047 """RX coefficient for the Y term (gaussian default)."""
1048 t_c = t / 2
1049 env = PulseEnvelope.gaussian(p, t, t_c)
1050 carrier = jnp.cos(PulseGates.omega_c * t)
1051 return -env * carrier * jnp.sin(PulseGates.omega_q * t) * p[-1]
1053 @staticmethod
1054 def _coeff_RY_X(p, t):
1055 """RY coefficient for the X term (gaussian default)."""
1056 t_c = t / 2
1057 env = PulseEnvelope.gaussian(p, t, t_c)
1058 carrier = jnp.cos(PulseGates.omega_c * t + jnp.pi / 2)
1059 return env * carrier * jnp.cos(PulseGates.omega_q * t) * p[-1]
1061 @staticmethod
1062 def _coeff_RY_Y(p, t):
1063 """RY coefficient for the Y term (gaussian default)."""
1064 t_c = t / 2
1065 env = PulseEnvelope.gaussian(p, t, t_c)
1066 carrier = jnp.cos(PulseGates.omega_c * t + jnp.pi / 2)
1067 return -env * carrier * jnp.sin(PulseGates.omega_q * t) * p[-1]
1069 # Backward-compat aliases (resolve to the dominant component of each gate).
1070 _coeff_Sx = _coeff_RX_X
1071 _coeff_Sy = _coeff_RY_Y
1073 @staticmethod
1074 def _coeff_Sz(p, t):
1075 """Coefficient function for RZ pulse: p * w."""
1076 return p[0] * p[1]
1078 @staticmethod
1079 def _coeff_Sc(p, t):
1080 """Constant coefficient for H correction phase."""
1081 return -1.0
1083 @staticmethod
1084 def _coeff_Scz(p, t):
1085 """Coefficient function for CZ pulse."""
1086 return p * jnp.pi
1088 @staticmethod
1089 def _record_pulse_event(gate_name, w, wires, pulse_params, parent=None):
1090 """Append a PulseEvent to the active pulse tape if recording.
1092 This is called from leaf gate methods (RX, RY, RZ, CZ) so that
1093 :func:`~qml_essentials.tape.pulse_recording` can collect events
1094 without the caller needing to know about the tape.
1095 """
1096 ptape = active_pulse_tape()
1097 if ptape is None:
1098 return
1100 from qml_essentials.drawing import PulseEvent, LEAF_META
1102 meta = LEAF_META.get(gate_name, {})
1103 wires_list = [wires] if isinstance(wires, int) else list(wires)
1105 if meta.get("physical", False):
1106 info = PulseEnvelope.get(PulseInformation.get_envelope())
1107 pp = PulseInformation.gate_by_name(gate_name).split_params(pulse_params)
1108 env_p = pp[:-1]
1109 dur = float(pp[-1])
1110 ptape.append(
1111 PulseEvent(
1112 gate=gate_name,
1113 wires=wires_list,
1114 envelope_fn=info["fn"],
1115 envelope_params=jnp.array(env_p),
1116 w=float(w),
1117 duration=dur,
1118 carrier_phase=meta["carrier_phase"],
1119 parent=parent,
1120 )
1121 )
1122 else:
1123 pp = PulseInformation.gate_by_name(gate_name).split_params(pulse_params)
1124 ptape.append(
1125 PulseEvent(
1126 gate=gate_name,
1127 wires=wires_list,
1128 envelope_fn=None,
1129 envelope_params=jnp.ravel(jnp.asarray(pp)),
1130 w=float(w) if not isinstance(w, list) else 0.0,
1131 duration=1.0,
1132 carrier_phase=0.0,
1133 parent=parent,
1134 )
1135 )
1137 @staticmethod
1138 def Rot(
1139 phi: float,
1140 theta: float,
1141 omega: float,
1142 wires: Union[int, List[int]],
1143 pulse_params: Optional[jnp.ndarray] = None,
1144 noise_params: Optional[Dict[str, float]] = None,
1145 random_key: Optional[jax.random.PRNGKey] = None,
1146 ) -> None:
1147 """
1148 Apply general rotation via decomposition: RZ(phi) · RY(theta) · RZ(omega).
1150 Args:
1151 phi (float): First rotation angle.
1152 theta (float): Second rotation angle.
1153 omega (float): Third rotation angle.
1154 wires (Union[int, List[int]]): Qubit index or indices to apply rotation to.
1155 pulse_params (Optional[jnp.ndarray]): Pulse parameters for the
1156 composing gates. If None, uses optimized parameters.
1157 noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
1158 random_key (Optional[jax.random.PRNGKey]): JAX random key for compatibility
1160 Returns:
1161 None: Gates are applied in-place to the circuit.
1162 """
1163 if noise_params is not None and "GateError" in noise_params:
1164 phi, random_key = UnitaryGates.GateError(phi, noise_params, random_key)
1165 theta, random_key = UnitaryGates.GateError(theta, noise_params, random_key)
1166 omega, random_key = UnitaryGates.GateError(omega, noise_params, random_key)
1167 PulseGates._execute_composite("Rot", [phi, theta, omega], wires, pulse_params)
1168 UnitaryGates.Noise(wires, noise_params)
1170 @staticmethod
1171 def PauliRot(
1172 pauli: str,
1173 theta: float,
1174 wires: Union[int, List[int]],
1175 pulse_params: Optional[jnp.ndarray] = None,
1176 noise_params: Optional[Dict[str, float]] = None,
1177 random_key: Optional[jax.random.PRNGKey] = None,
1178 ) -> None:
1179 """Not implemented as a PulseGate."""
1180 raise NotImplementedError("PauliRot gate is not implemented as PulseGate")
1182 @staticmethod
1183 def RX(
1184 w: float,
1185 wires: Union[int, List[int]],
1186 pulse_params: Optional[jnp.ndarray] = None,
1187 noise_params: Optional[Dict[str, float]] = None,
1188 random_key: Optional[jax.random.PRNGKey] = None,
1189 ) -> None:
1190 """Apply X-axis rotation using the active pulse envelope.
1192 Args:
1193 w: Rotation angle in radians.
1194 wires: Qubit index or indices.
1195 pulse_params: Envelope parameters ``[env_0, ..., env_n, t]``.
1196 If ``None``, uses optimized defaults.
1197 noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
1198 random_key (Optional[jax.random.PRNGKey]): JAX random key for compatibility
1199 """
1200 pulse_params = PulseInformation.RX.split_params(pulse_params)
1202 PulseGates._record_pulse_event("RX", w, wires, pulse_params)
1203 t = pulse_params[-1]
1205 # Proper interaction-picture drive Hamiltonian for RX:
1206 # H_I(τ) = Ω(τ)·cos(ω_c·τ) · [ cos(ω_q·τ)·X − sin(ω_q·τ)·Y ]
1207 # which on resonance averages (RWA) to +(Ω/2)·X while the
1208 # 2·ω_q counter-rotating part oscillates and cancels.
1209 H_X = js.Hamiltonian(PulseGates.X, wires=wires)
1210 H_Y = js.Hamiltonian(PulseGates.Y, wires=wires)
1211 H_eff = PulseGates._coeff_RX_X * H_X + PulseGates._coeff_RX_Y * H_Y
1213 # Pack: [envelope_params..., w] - evolution time is the last element
1214 # of pulse_params (pulse_params[-1]).
1215 w, random_key = UnitaryGates.GateError(w, noise_params, random_key)
1216 # Use jnp.concatenate over Python list-splat to keep the trace graph
1217 # compact (no per-element unpacking + restack).
1218 env_params = jnp.concatenate(
1219 [jnp.ravel(pulse_params[:-1]), jnp.ravel(jnp.asarray(w))]
1220 )
1221 # Both terms share the same parameter array.
1222 H_eff.evolve(name="RX")([env_params, env_params], t)
1223 UnitaryGates.Noise(wires, noise_params)
1225 @staticmethod
1226 def RY(
1227 w: float,
1228 wires: Union[int, List[int]],
1229 pulse_params: Optional[jnp.ndarray] = None,
1230 noise_params: Optional[Dict[str, float]] = None,
1231 random_key: Optional[jax.random.PRNGKey] = None,
1232 ) -> None:
1233 """Apply Y-axis rotation using the active pulse envelope.
1235 Args:
1236 w: Rotation angle in radians.
1237 wires: Qubit index or indices.
1238 pulse_params: Envelope parameters ``[env_0, ..., env_n, t]``.
1239 If ``None``, uses optimized defaults.
1240 noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
1241 random_key (Optional[jax.random.PRNGKey]): JAX random key for compatibility
1242 """
1243 pulse_params = PulseInformation.RY.split_params(pulse_params)
1245 PulseGates._record_pulse_event("RY", w, wires, pulse_params)
1246 t = pulse_params[-1]
1248 # See NOTE in RX: same proper interaction-picture form, with
1249 # carrier phase ϕ = +π/2 so the slow RWA component drives +Y.
1250 H_X = js.Hamiltonian(PulseGates.X, wires=wires)
1251 H_Y = js.Hamiltonian(PulseGates.Y, wires=wires)
1252 H_eff = PulseGates._coeff_RY_X * H_X + PulseGates._coeff_RY_Y * H_Y
1254 # Pack w into the params so the coefficient function doesn't need
1255 # a closure - this enables JIT solver cache sharing across all RY calls.
1256 w, random_key = UnitaryGates.GateError(w, noise_params, random_key)
1257 env_params = jnp.concatenate(
1258 [jnp.ravel(pulse_params[:-1]), jnp.ravel(jnp.asarray(w))]
1259 )
1260 H_eff.evolve(name="RY")([env_params, env_params], t)
1261 UnitaryGates.Noise(wires, noise_params)
1263 @staticmethod
1264 def RZ(
1265 w: float,
1266 wires: Union[int, List[int]],
1267 pulse_params: Optional[float] = None,
1268 noise_params: Optional[Dict[str, float]] = None,
1269 random_key: Optional[jax.random.PRNGKey] = None,
1270 ) -> None:
1271 """
1272 Apply Z-axis rotation using pulse-level implementation.
1274 Implements RZ rotation using virtual Z rotations (phase tracking)
1275 without physical pulse application.
1277 Args:
1278 w (float): Rotation angle in radians.
1279 wires (Union[int, List[int]]): Qubit index or indices to apply rotation to.
1280 pulse_params (Optional[float]): Duration parameter for the pulse.
1281 Rotation angle = w * 2 * pulse_params. Defaults to 0.5 if None.
1282 noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
1283 random_key (Optional[jax.random.PRNGKey]): JAX random key for compatibility
1285 Returns:
1286 None: Gate is applied in-place to the circuit.
1287 """
1288 pulse_params = PulseInformation.RZ.split_params(pulse_params)
1290 PulseGates._record_pulse_event("RZ", w, wires, pulse_params)
1292 _H = js.Hamiltonian(PulseGates.Z, wires=wires)
1293 H_eff = PulseGates._coeff_Sz * _H
1295 # Pack w into the params so the coefficient function doesn't need
1296 # a closure - [pulse_param_scalar, w] enables JIT solver cache sharing.
1297 # pulse_params may be a 1-element array or scalar; ravel + slice the first
1298 # element to preserve the original semantics, then concatenate with w.
1299 w, random_key = UnitaryGates.GateError(w, noise_params, random_key)
1300 pp_flat = jnp.ravel(jnp.asarray(pulse_params))
1301 H_eff.evolve(name="RZ")(
1302 [jnp.concatenate([pp_flat[:1], jnp.ravel(jnp.asarray(w))])],
1303 1,
1304 )
1306 UnitaryGates.Noise(wires, noise_params)
1308 @staticmethod
1309 def _resolve_wires(wire_fn, wires):
1310 """Resolve a wire selector string to actual wire(s).
1312 Args:
1313 wire_fn: ``"all"``, ``"target"``, or ``"control"``.
1314 wires: Parent gate's wire(s) (int or list).
1316 Returns:
1317 Wire(s) for the child gate.
1318 """
1319 wires_list = [wires] if isinstance(wires, int) else list(wires)
1320 if wire_fn == "all":
1321 return wires if len(wires_list) > 1 else wires_list[0]
1322 if wire_fn == "target":
1323 return wires_list[-1] if len(wires_list) > 1 else wires_list[0]
1324 if wire_fn == "control":
1325 return wires_list[0]
1326 raise ValueError(f"Unknown wire_fn: {wire_fn!r}")
1328 @staticmethod
1329 def _execute_composite(gate_name, w, wires, pulse_params=None):
1330 """Execute a composite gate by walking its decomposition.
1332 Reads the :class:`DecompositionStep` list from
1333 :class:`PulseInformation` and dispatches each step to the
1334 appropriate ``PulseGates`` method.
1336 Args:
1337 gate_name: Gate name (e.g. ``"H"``, ``"CX"``).
1338 w: Rotation angle(s) passed to the parent gate.
1339 wires: Wire(s) of the parent gate.
1340 pulse_params: Optional pulse parameters (split across children).
1341 """
1342 pp_obj = PulseInformation.gate_by_name(gate_name)
1343 parts = pp_obj.split_params(pulse_params)
1345 for step, child_params in zip(pp_obj.decomposition, parts):
1346 child_wires = PulseGates._resolve_wires(step.wire_fn, wires)
1347 child_w = step.angle_fn(w) if step.angle_fn is not None else w
1348 child_gate = getattr(PulseGates, step.gate.name)
1350 # Leaf gates that take a rotation angle
1351 if step.gate.name in ("RX", "RY", "RZ"):
1352 child_gate(child_w, wires=child_wires, pulse_params=child_params)
1353 # Leaf gates without a rotation angle
1354 elif step.gate.name in ("CZ",):
1355 child_gate(wires=child_wires, pulse_params=child_params)
1356 # Composite gates with a rotation angle (CRX, CRY, CRZ, Rot, ...)
1357 elif step.gate.name in ("Rot",):
1358 # Rot expects (phi, theta, omega, wires, ...)
1359 child_gate(*child_w, wires=child_wires, pulse_params=child_params)
1360 elif step.gate.decomposition is not None and step.gate.name in (
1361 "CRX",
1362 "CRY",
1363 "CRZ",
1364 "CPhase",
1365 "RXX",
1366 "RYY",
1367 "RZZ",
1368 "RZX",
1369 ):
1370 child_gate(child_w, wires=child_wires, pulse_params=child_params)
1371 # Other composite gates (H, CX, CY, ...)
1372 else:
1373 child_gate(wires=child_wires, pulse_params=child_params)
1375 @staticmethod
1376 def H(
1377 wires: Union[int, List[int]],
1378 pulse_params: Optional[jnp.ndarray] = None,
1379 noise_params: Optional[Dict[str, float]] = None,
1380 random_key: Optional[jax.random.PRNGKey] = None,
1381 ) -> None:
1382 """Apply Hadamard gate using pulse decomposition.
1384 Decomposes as RZ(π) · RY(π/2) followed by a correction phase.
1386 Args:
1387 wires (Union[int, List[int]]): Qubit index or indices.
1388 noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
1389 random_key (Optional[jax.random.PRNGKey]): JAX random key for compatibility
1390 (not used in this gate).
1391 """
1392 PulseGates._execute_composite("H", 0.0, wires, pulse_params)
1394 # Correction phase unique to the H gate
1395 _H = js.Hamiltonian(PulseGates._H_corr, wires=wires)
1396 H_corr = PulseGates._coeff_Sc * _H
1397 H_corr.evolve(name="H")([0], 1)
1398 UnitaryGates.Noise(wires, noise_params)
1400 @staticmethod
1401 def CX(
1402 wires: List[int],
1403 pulse_params: Optional[jnp.ndarray] = None,
1404 noise_params: Optional[Dict[str, float]] = None,
1405 random_key: Optional[jax.random.PRNGKey] = None,
1406 ) -> None:
1407 """Apply CNOT gate via decomposition: H(target) · CZ · H(target).
1409 Args:
1410 wires (List[int]): Control and target qubit indices [control, target].
1411 pulse_params (Optional[jnp.ndarray]): Pulse parameters for the
1412 composing gates. If None, uses optimized parameters.
1413 noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
1414 random_key (Optional[jax.random.PRNGKey]): JAX random key for compatibility
1415 (not used in this gate).
1417 Returns:
1418 None: Gate is applied in-place to the circuit.
1419 """
1420 PulseGates._execute_composite("CX", 0.0, wires, pulse_params)
1421 UnitaryGates.Noise(wires, noise_params)
1423 @staticmethod
1424 def CY(
1425 wires: List[int],
1426 pulse_params: Optional[jnp.ndarray] = None,
1427 noise_params: Optional[Dict[str, float]] = None,
1428 random_key: Optional[jax.random.PRNGKey] = None,
1429 ) -> None:
1430 """Apply controlled-Y via decomposition.
1432 Args:
1433 wires (List[int]): Control and target qubit indices [control, target].
1434 pulse_params (Optional[jnp.ndarray]): Pulse parameters for the
1435 composing gates. If None, uses optimized parameters.
1436 noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
1437 random_key (Optional[jax.random.PRNGKey]): JAX random key for compatibility
1438 (not used in this gate).
1440 """
1441 PulseGates._execute_composite("CY", 0.0, wires, pulse_params)
1442 UnitaryGates.Noise(wires, noise_params)
1444 @staticmethod
1445 def CZ(
1446 wires: List[int],
1447 pulse_params: Optional[float] = None,
1448 noise_params: Optional[Dict[str, float]] = None,
1449 random_key: Optional[jax.random.PRNGKey] = None,
1450 ) -> None:
1451 """Apply controlled-Z using ZZ coupling Hamiltonian.
1453 Args:
1454 wires (List[int]): Control and target qubit indices.
1455 pulse_params (Optional[float]): Time or duration parameter for
1456 the pulse evolution. If None, uses optimized value.
1457 noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
1458 random_key (Optional[jax.random.PRNGKey]): JAX random key for compatibility
1459 (not used in this gate).
1461 """
1462 if pulse_params is None:
1463 pulse_params = PulseInformation.CZ.params
1465 PulseGates._record_pulse_event("CZ", 0.0, wires, pulse_params)
1467 _H = js.Hamiltonian(PulseGates._H_CZ, wires=wires)
1468 H_eff = PulseGates._coeff_Scz * _H
1469 H_eff.evolve(name="CZ")([pulse_params], 1)
1470 UnitaryGates.Noise(wires, noise_params)
1472 @staticmethod
1473 def CRX(
1474 w: float,
1475 wires: List[int],
1476 pulse_params: Optional[jnp.ndarray] = None,
1477 noise_params: Optional[Dict[str, float]] = None,
1478 random_key: Optional[jax.random.PRNGKey] = None,
1479 ) -> None:
1480 """Apply controlled-RX via decomposition.
1482 Args:
1483 w (float): Rotation angle in radians.
1484 wires (List[int]): Control and target qubit indices [control, target].
1485 pulse_params (Optional[jnp.ndarray]): Pulse parameters for the
1486 composing gates. If None, uses optimized parameters.
1487 noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
1488 random_key (Optional[jax.random.PRNGKey]): JAX random key for compatibility
1489 (not used in this gate).
1490 """
1491 PulseGates._execute_composite("CRX", w, wires, pulse_params)
1492 UnitaryGates.Noise(wires, noise_params)
1494 @staticmethod
1495 def CRY(
1496 w: float,
1497 wires: List[int],
1498 pulse_params: Optional[jnp.ndarray] = None,
1499 noise_params: Optional[Dict[str, float]] = None,
1500 random_key: Optional[jax.random.PRNGKey] = None,
1501 ) -> None:
1502 """Apply controlled-RY via decomposition.
1504 Args:
1505 w (float): Rotation angle in radians.
1506 wires (List[int]): Control and target qubit indices [control, target].
1507 pulse_params (Optional[jnp.ndarray]): Pulse parameters for the
1508 composing gates. If None, uses optimized parameters.
1509 noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
1510 random_key (Optional[jax.random.PRNGKey]): JAX random key for compatibility
1511 """
1512 w, random_key = UnitaryGates.GateError(w, noise_params, random_key)
1513 PulseGates._execute_composite("CRY", w, wires, pulse_params)
1514 UnitaryGates.Noise(wires, noise_params)
1516 @staticmethod
1517 def CRZ(
1518 w: float,
1519 wires: List[int],
1520 pulse_params: Optional[jnp.ndarray] = None,
1521 noise_params: Optional[Dict[str, float]] = None,
1522 random_key: Optional[jax.random.PRNGKey] = None,
1523 ) -> None:
1524 """Apply controlled-RZ via decomposition.
1526 Args:
1527 w (float): Rotation angle in radians.
1528 wires (List[int]): Control and target qubit indices [control, target].
1529 pulse_params (Optional[jnp.ndarray]): Pulse parameters for the
1530 composing gates. If None, uses optimized parameters.
1531 noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
1532 random_key (Optional[jax.random.PRNGKey]): JAX random key for compatibility
1533 """
1534 w, random_key = UnitaryGates.GateError(w, noise_params, random_key)
1535 PulseGates._execute_composite("CRZ", w, wires, pulse_params)
1536 UnitaryGates.Noise(wires, noise_params)
1538 @staticmethod
1539 def CPhase(
1540 w: float,
1541 wires: List[int],
1542 pulse_params: Optional[jnp.ndarray] = None,
1543 noise_params: Optional[Dict[str, float]] = None,
1544 random_key: Optional[jax.random.PRNGKey] = None,
1545 ) -> None:
1546 """Apply controlled phase shift via decomposition.
1548 Decomposes CPhase(φ) into RZ and CX gates:
1549 RZ(φ/2) on control, RZ(φ/2) on target, CX, RZ(-φ/2) on target, CX.
1551 Args:
1552 w (float): Phase shift angle in radians.
1553 wires (List[int]): Control and target qubit indices [control, target].
1554 pulse_params (Optional[jnp.ndarray]): Pulse parameters for the
1555 composing gates. If None, uses optimized parameters.
1556 noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
1557 random_key (Optional[jax.random.PRNGKey]): JAX random key for compatibility
1558 """
1559 w, random_key = UnitaryGates.GateError(w, noise_params, random_key)
1560 PulseGates._execute_composite("CPhase", w, wires, pulse_params)
1561 UnitaryGates.Noise(wires, noise_params)
1563 @staticmethod
1564 def RXX(
1565 w: float,
1566 wires: List[int],
1567 pulse_params: Optional[jnp.ndarray] = None,
1568 noise_params: Optional[Dict[str, float]] = None,
1569 random_key: Optional[jax.random.PRNGKey] = None,
1570 ) -> None:
1571 """Apply two-qubit RXX rotation via decomposition.
1573 Implements ``RXX(theta) = exp(-i theta/2 X ⊗ X)`` as
1574 ``(H ⊗ H) · RZZ(theta) · (H ⊗ H)``.
1576 Args:
1577 w (float): Rotation angle in radians.
1578 wires (List[int]): Two qubit indices.
1579 pulse_params (Optional[jnp.ndarray]): Pulse parameters for the
1580 composing gates. If None, uses optimized parameters.
1581 noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
1582 random_key (Optional[jax.random.PRNGKey]): JAX random key for noise.
1583 """
1584 w, random_key = UnitaryGates.GateError(w, noise_params, random_key)
1585 PulseGates._execute_composite("RXX", w, wires, pulse_params)
1586 UnitaryGates.Noise(wires, noise_params)
1588 @staticmethod
1589 def RYY(
1590 w: float,
1591 wires: List[int],
1592 pulse_params: Optional[jnp.ndarray] = None,
1593 noise_params: Optional[Dict[str, float]] = None,
1594 random_key: Optional[jax.random.PRNGKey] = None,
1595 ) -> None:
1596 """Apply two-qubit RYY rotation via decomposition.
1598 Implements ``RYY(theta) = exp(-i theta/2 Y ⊗ Y)`` by conjugating the
1599 RZZ skeleton with ``RX(pi/2)`` rotations on both wires.
1601 Args:
1602 w (float): Rotation angle in radians.
1603 wires (List[int]): Two qubit indices.
1604 pulse_params (Optional[jnp.ndarray]): Pulse parameters for the
1605 composing gates. If None, uses optimized parameters.
1606 noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
1607 random_key (Optional[jax.random.PRNGKey]): JAX random key for noise.
1608 """
1609 w, random_key = UnitaryGates.GateError(w, noise_params, random_key)
1610 PulseGates._execute_composite("RYY", w, wires, pulse_params)
1611 UnitaryGates.Noise(wires, noise_params)
1613 @staticmethod
1614 def RZZ(
1615 w: float,
1616 wires: List[int],
1617 pulse_params: Optional[jnp.ndarray] = None,
1618 noise_params: Optional[Dict[str, float]] = None,
1619 random_key: Optional[jax.random.PRNGKey] = None,
1620 ) -> None:
1621 """Apply two-qubit RZZ rotation via decomposition.
1623 Implements ``RZZ(theta) = exp(-i theta/2 Z ⊗ Z)`` as
1624 ``CX · RZ(theta)_target · CX``.
1626 Args:
1627 w (float): Rotation angle in radians.
1628 wires (List[int]): Two qubit indices.
1629 pulse_params (Optional[jnp.ndarray]): Pulse parameters for the
1630 composing gates. If None, uses optimized parameters.
1631 noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
1632 random_key (Optional[jax.random.PRNGKey]): JAX random key for noise.
1633 """
1634 w, random_key = UnitaryGates.GateError(w, noise_params, random_key)
1635 PulseGates._execute_composite("RZZ", w, wires, pulse_params)
1636 UnitaryGates.Noise(wires, noise_params)
1638 @staticmethod
1639 def RZX(
1640 w: float,
1641 wires: List[int],
1642 pulse_params: Optional[jnp.ndarray] = None,
1643 noise_params: Optional[Dict[str, float]] = None,
1644 random_key: Optional[jax.random.PRNGKey] = None,
1645 ) -> None:
1646 """Apply two-qubit RZX rotation via decomposition.
1648 Implements ``RZX(theta) = exp(-i theta/2 Z ⊗ X)`` (Z on the first
1649 wire, X on the second) by conjugating the RZZ skeleton with a
1650 Hadamard on the target wire.
1652 Args:
1653 w (float): Rotation angle in radians.
1654 wires (List[int]): Two qubit indices ``[zwire, xwire]``.
1655 pulse_params (Optional[jnp.ndarray]): Pulse parameters for the
1656 composing gates. If None, uses optimized parameters.
1657 noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
1658 random_key (Optional[jax.random.PRNGKey]): JAX random key for noise.
1659 """
1660 w, random_key = UnitaryGates.GateError(w, noise_params, random_key)
1661 PulseGates._execute_composite("RZX", w, wires, pulse_params)
1662 UnitaryGates.Noise(wires, noise_params)
class PulseParamManager:
    """Sequential cursor over a flat collection of pulse parameters.

    Attributes:
        pulse_params: The parameter array being consumed.
        idx: Index of the next unconsumed parameter.
    """

    def __init__(self, pulse_params: jnp.ndarray):
        """
        Args:
            pulse_params: Parameters to hand out chunk by chunk via :meth:`get`.
        """
        self.pulse_params = pulse_params
        self.idx = 0

    def get(self, n: int):
        """Return the next n parameters and advance the cursor.

        Args:
            n: Number of parameters to consume; must be non-negative.

        Returns:
            The next ``n`` parameters with singleton dimensions squeezed
            away (so ``n == 1`` yields a 0-d value).

        Raises:
            ValueError: If ``n`` is negative or fewer than ``n`` parameters
                remain. The cursor is left untouched in both cases.
        """
        if n < 0:
            # A negative count would slip past the bounds check below and
            # move the cursor backwards.
            raise ValueError(
                f"Cannot request a negative number of pulse parameters: {n}"
            )
        end = self.idx + n
        if end > len(self.pulse_params):
            raise ValueError("Not enough pulse parameters left for this gate")
        # TODO: we squeeze here to get rid of any extra hidden dimension
        params = self.pulse_params[self.idx : end].squeeze()
        self.idx = end
        return params
# Initialise PulseInformation only after PulseGates has been defined, so that
# leaf defaults, composite trees, the mirrored PulseGates flags, and the
# coefficient functions all agree. This runs once, at module import time.
PulseInformation.reset_defaults()