Coverage for qml_essentials / pulses.py: 82%
387 statements
coverage.py v7.13.4, created at 2026-03-30 11:43 +0000
1import os
2from dataclasses import dataclass
3from typing import Optional, List, Union, Dict, Callable
4import csv
5import jax.numpy as jnp
6import jax
8from qml_essentials import operations as op
9from qml_essentials import yaqsi as ys
10from qml_essentials.utils import safe_random_split
11from qml_essentials.tape import active_pulse_tape
12from qml_essentials.unitary import UnitaryGates
13import logging
# Module-level logger, named after this module (``__name__``).
log = logging.getLogger(__name__)
@dataclass
class DecompositionStep:
    """A single child gate inside a composite pulse-gate decomposition.

    Attributes:
        gate: The child :class:`PulseParams` node applied at this step.
        wire_fn: Selector for the parent wires the child acts on; one of
            ``"all"`` (default), ``"target"`` or ``"control"``.
        angle_fn: Optional callable mapping the parent angle(s) ``w`` to the
            child's angle. When left as ``None`` the parent ``w`` is
            forwarded untouched.
    """

    # Forward reference: PulseParams is defined after this dataclass.
    gate: "PulseParams"
    # Resolved to concrete wire(s) by the composite-gate executor.
    wire_fn: str = "all"
    # ``None`` is the sentinel for "pass the parent angle straight through".
    angle_fn: Optional[Callable] = None
class PulseParams:
    """Container for hierarchical pulse parameters.

    Leaf nodes hold direct parameters; composite nodes hold a list of
    :class:`DecompositionStep` objects that describe how the gate is
    built from simpler gates.

    Attributes:
        name: Gate identifier (e.g. ``"RX"``, ``"H"``).
        decomposition: List of :class:`DecompositionStep` (composite only).
    """

    def __init__(
        self,
        name: str = "",
        params: Optional[jnp.ndarray] = None,
        decomposition: Optional[List["DecompositionStep"]] = None,
    ) -> None:
        """
        Args:
            name: Gate name.
            params: Direct pulse parameters (leaf gates).
                Mutually exclusive with *decomposition*.
            decomposition: List of :class:`DecompositionStep` (composite gates).
                Mutually exclusive with *params*.

        Raises:
            AssertionError: If both or neither of *params* and
                *decomposition* are given.
        """
        assert (params is None) != (
            decomposition is None
        ), "Exactly one of `params` or `decomposition` must be provided."

        self.decomposition = decomposition
        # Child gates in decomposition order (duplicates kept, e.g. H twice in
        # CX). Derived for backward compat with childs/leafs/split_params.
        self._pulse_obj = (
            [step.gate for step in decomposition] if decomposition else None
        )

        if params is not None:
            self._params = params

        self.name = name

    def __len__(self) -> int:
        """
        Get the total number of pulse parameters.

        For composite gates, returns the accumulated count from all children
        (duplicated children counted once per occurrence).

        Returns:
            int: Total number of pulse parameters.
        """
        if self.is_leaf:
            return len(self._params)
        # The sum of the child sizes equals the length of the concatenated
        # ``params`` vector, without materialising that vector.
        return sum(child.size for child in self.childs)

    def __getitem__(self, idx: int) -> Union[float, jnp.ndarray]:
        """
        Access pulse parameter(s) by index.

        For leaf gates, returns the parameter at the given index.
        For composite gates, returns parameters of the child at the given index.

        Args:
            idx (int): Index to access.

        Returns:
            Union[float, jnp.ndarray]: Parameter value or child parameters.
        """
        if self.is_leaf:
            return self.params[idx]
        return self.childs[idx].params

    def __str__(self) -> str:
        """Return string representation (gate name)."""
        return self.name

    def __repr__(self) -> str:
        """Return repr string (gate name)."""
        return self.name

    @property
    def is_leaf(self) -> bool:
        """Check if this is a leaf node (direct parameters, no children)."""
        return self._pulse_obj is None

    @property
    def size(self) -> int:
        """Get the total parameter count (alias for __len__)."""
        return len(self)

    @property
    def leafs(self) -> List["PulseParams"]:
        """
        Get all leaf nodes in the hierarchy.

        Recursively collects all leaf PulseParams objects in the tree.

        Returns:
            List[PulseParams]: Unique leaf nodes in first-seen (depth-first)
            order.
        """
        if self.is_leaf:
            return [self]

        collected = []
        for child in self._pulse_obj:
            collected.extend(child.leafs)

        # dict.fromkeys de-duplicates (by identity, PulseParams defines no
        # __eq__) while keeping first-seen order. A set() would order leaves by
        # object id, making the leaf_params layout differ between runs.
        return list(dict.fromkeys(collected))

    @property
    def childs(self) -> List["PulseParams"]:
        """
        Get direct children of this node.

        Returns:
            List[PulseParams]: List of child PulseParams objects, or empty list
            if this is a leaf node.
        """
        if self.is_leaf:
            return []
        return self._pulse_obj

    @property
    def shape(self) -> List[Union[int, list]]:
        """
        Get the shape of pulse parameters.

        For leaf nodes, returns a one-element list with the parameter count.
        For composite nodes, returns one entry per direct child: the child's
        parameter count if it is a leaf, otherwise the child's (nested) shape.

        Returns:
            List[Union[int, list]]: Parameter shape specification.
        """
        if self.is_leaf:
            return [len(self.params)]
        # ``shape`` is a property, not a method; leaf children contribute their
        # single count, composite children contribute their nested list.
        return [
            child.shape[0] if child.is_leaf else child.shape
            for child in self.childs
        ]

    @property
    def params(self) -> jnp.ndarray:
        """
        Get or compute pulse parameters.

        For leaf nodes, returns internal pulse parameters.
        For composite nodes, returns concatenated parameters from all children.

        Returns:
            jnp.ndarray: Pulse parameters array.
        """
        if self.is_leaf:
            return self._params
        return jnp.concatenate(self.split_params(params=None, leafs=False))

    @params.setter
    def params(self, value: jnp.ndarray) -> None:
        """
        Set pulse parameters.

        For leaf nodes, sets internal parameters directly.
        For composite nodes, distributes consecutive slices of *value* across
        the direct children (a child that appears twice is set twice; the
        later slice wins).

        Args:
            value (jnp.ndarray): Pulse parameters to set.

        Raises:
            AssertionError: If value is not jnp.ndarray for leaf nodes.
        """
        if self.is_leaf:
            assert isinstance(value, jnp.ndarray), "params must be a jnp.ndarray"
            self._params = value
            return

        start = 0
        for child in self.childs:
            stop = start + child.size
            child.params = value[start:stop]
            start = stop

    @property
    def leaf_params(self) -> jnp.ndarray:
        """
        Get parameters from all leaf nodes.

        Returns:
            jnp.ndarray: Concatenated parameters from all unique leaf nodes,
            in the order given by :attr:`leafs`.
        """
        if self.is_leaf:
            return self._params
        return jnp.concatenate(self.split_params(None, leafs=True))

    @leaf_params.setter
    def leaf_params(self, value: jnp.ndarray) -> None:
        """
        Set parameters for all leaf nodes.

        Args:
            value (jnp.ndarray): Parameters to distribute across leaf nodes,
                in the order given by :attr:`leafs`.
        """
        if self.is_leaf:
            self._params = value
            return

        start = 0
        for leaf in self.leafs:
            stop = start + leaf.size
            leaf.params = value[start:stop]
            start = stop

    def split_params(
        self,
        params: Optional[jnp.ndarray] = None,
        leafs: bool = False,
    ) -> List[jnp.ndarray]:
        """
        Split parameters into sub-arrays for children or leaves.

        Args:
            params (Optional[jnp.ndarray]): Parameters to split. If None,
                uses internal parameters.
            leafs (bool): If True, splits across leaf nodes; if False,
                splits across direct children. Defaults to False.

        Returns:
            List[jnp.ndarray]: List of parameter arrays for children or leaves.
            For a leaf node the (internal or given) array is returned as-is,
            not wrapped in a list.
        """
        if self.is_leaf:
            return self._params if params is None else params

        nodes = self.leafs if leafs else self.childs
        if params is None:
            return [node.params for node in nodes]

        pieces = []
        start = 0
        for node in nodes:
            stop = start + node.size
            pieces.append(params[start:stop])
            start = stop
        return pieces
class PulseEnvelope:
    """Registry of pulse envelope shapes.

    Each envelope is a pure function ``(p, t, t_c) -> amplitude`` that
    computes the pulse envelope *without* carrier modulation. The carrier
    ``cos(omega_c * t + phi_c)`` is applied separately in the coefficient
    functions built by :meth:`build_coeff_fns`.

    Attributes:
        REGISTRY: Mapping from envelope name to metadata dict containing
            ``fn`` (callable, or ``None`` for the non-physical ``"general"``
            entry), ``n_envelope_params`` (int), and per-gate default
            parameter arrays under ``defaults``.
    """

    @staticmethod
    def gaussian(p, t, t_c):
        """Gaussian envelope. ``p = [A, sigma]``."""
        A, sigma = p[0], p[1]
        return A * jnp.exp(-0.5 * ((t - t_c) / sigma) ** 2)

    @staticmethod
    def square(p, t, t_c):
        """Rectangular envelope. ``p = [A, width]``."""
        A, width = p[0], p[1]
        return A * (jnp.abs(t - t_c) <= width / 2)

    @staticmethod
    def cosine(p, t, t_c):
        """Raised cosine envelope. ``p = [A, width]``."""
        A, width = p[0], p[1]
        # Clipping to [-0.5, 0.5] makes the envelope exactly 0 outside the window.
        x = jnp.clip((t - t_c) / width, -0.5, 0.5)
        return A * jnp.cos(jnp.pi * x)

    @staticmethod
    def drag(p, t, t_c):
        """DRAG (Derivative Removal by Adiabatic Gate). ``p = [A, beta, sigma]``."""
        A, beta, sigma = p[0], p[1], p[2]
        g = A * jnp.exp(-0.5 * ((t - t_c) / sigma) ** 2)
        # Analytic time derivative of the Gaussian g above.
        dg = g * (-(t - t_c) / sigma**2)
        return g + beta * dg

    @staticmethod
    def sech(p, t, t_c):
        """Hyperbolic secant envelope. ``p = [A, sigma]``."""
        A, sigma = p[0], p[1]
        return A / jnp.cosh((t - t_c) / sigma)

    # ``n_envelope_params`` counts only the envelope parameters (excluding
    # the evolution time ``t`` which is always the last element of the full
    # pulse parameter vector).
    REGISTRY = {
        "gaussian": {
            "fn": gaussian.__func__,
            "n_envelope_params": 2,
            "defaults": {
                "RX": jnp.array(
                    [30.187402725219727, 0.32704535126686096, 0.320675790309906]
                ),
                "RY": jnp.array(
                    [10.794735903531707, 0.12725685459013134, 0.3157523181268348]
                ),
            },
        },
        "square": {
            "fn": square.__func__,
            "n_envelope_params": 2,
            "defaults": {
                "RX": jnp.array([1.0, 1.0, 1.0]),
                "RY": jnp.array([1.0, 1.0, 1.0]),
            },
        },
        "cosine": {
            "fn": cosine.__func__,
            "n_envelope_params": 2,
            "defaults": {
                "RX": jnp.array([1.0, 1.0, 1.0]),
                "RY": jnp.array([1.0, 1.0, 1.0]),
            },
        },
        "drag": {
            "fn": drag.__func__,
            "n_envelope_params": 3,
            "defaults": {
                "RX": jnp.array([1.0, 1.0, 0.1, 1.0]),
                "RY": jnp.array([1.0, 1.0, 0.1, 1.0]),
            },
        },
        "sech": {
            "fn": sech.__func__,
            "n_envelope_params": 2,
            "defaults": {
                "RX": jnp.array([1.0, 1.0, 1.0]),
                "RY": jnp.array([1.0, 1.0, 1.0]),
            },
        },
        # Not a physical drive shape: holds defaults for the virtual RZ and
        # the CZ coupling gate only (no RX/RY entries, no envelope function).
        "general": {
            "fn": None,
            "n_envelope_params": 0,
            "defaults": {
                "RZ": jnp.array([0.5]),
                "CZ": jnp.array([0.31831514835357666]),
            },
        },
    }

    @staticmethod
    def available() -> List[str]:
        """Return list of registered envelope names."""
        return list(PulseEnvelope.REGISTRY.keys())

    @staticmethod
    def get(name: str) -> dict:
        """Look up envelope metadata by name.

        Raises:
            ValueError: If *name* is not registered.
        """
        if name not in PulseEnvelope.REGISTRY:
            raise ValueError(
                f"Unknown pulse envelope '{name}'. "
                f"Available: {PulseEnvelope.available()}"
            )
        return PulseEnvelope.REGISTRY[name]

    @staticmethod
    def build_coeff_fns(envelope_fn, omega_c):
        """Build ``(coeff_Sx, coeff_Sy)`` for a given envelope function.

        Nested functions created by the same ``def`` statement normally share
        a single ``__code__`` object across calls. Each function returned here
        instead gets its own copy, renamed after the envelope, so a solver
        cache keyed on ``id(coeff_fn.__code__)`` (as the yaqsi JIT solver
        cache is described to be) assigns a separate compiled program per
        envelope shape.

        The rotation angle ``w`` is expected as the **last** element of the
        parameter array ``p`` (i.e. ``p[-1]``). Envelope parameters occupy
        ``p[:-1]`` (excluding the evolution-time element that is passed
        separately to ``ys.evolve``).

        Args:
            envelope_fn: Pure envelope function ``(p, t, t_c) -> scalar``.
            omega_c: Carrier frequency.

        Returns:
            Tuple of ``(coeff_Sx, coeff_Sy)``.
        """

        def _coeff_Sx(p, t):
            t_c = t / 2
            env = envelope_fn(p, t, t_c)
            carrier = jnp.cos(omega_c * t + jnp.pi)
            return env * carrier * p[-1]

        def _coeff_Sy(p, t):
            t_c = t / 2
            env = envelope_fn(p, t, t_c)
            carrier = jnp.cos(omega_c * t - jnp.pi / 2)
            return env * carrier * p[-1]

        # Give each returned closure a private code object. Changing co_name
        # forces ``replace`` to build a new object; the free-variable layout is
        # unchanged, so reassigning ``__code__`` is valid.
        suffix = getattr(envelope_fn, "__name__", "envelope")
        for fn in (_coeff_Sx, _coeff_Sy):
            fn.__code__ = fn.__code__.replace(co_name=f"{fn.__name__}_{suffix}")

        return _coeff_Sx, _coeff_Sy
class PulseInformation:
    """Stores pulse parameter counts and optimized pulse parameters.

    Call :meth:`set_envelope` to switch the active pulse shape. This
    rebuilds all :class:`PulseParams` trees so that parameter counts
    and defaults match the selected envelope.
    """

    _envelope: str = "gaussian"

    # Optimized pulse vectors loaded by :meth:`update_params`, keyed by the
    # gate name found in the first CSV column.
    OPTIMIZED_PULSES: Dict[str, jnp.ndarray] = {}

    @classmethod
    def _build_leaf_gates(cls):
        """(Re-)create leaf PulseParams from the active envelope defaults."""
        defaults = PulseEnvelope.get(cls._envelope)["defaults"]
        general = PulseEnvelope.get("general")["defaults"]

        cls.RX = PulseParams(name="RX", params=defaults["RX"])
        cls.RY = PulseParams(name="RY", params=defaults["RY"])

        # RZ and CZ are envelope-independent and always use "general" defaults.
        cls.RZ = PulseParams(name="RZ", params=general["RZ"])
        cls.CZ = PulseParams(name="CZ", params=general["CZ"])

    @classmethod
    def _build_composite_gates(cls):
        """(Re-)create composite PulseParams trees from current leaves."""
        cls.H = PulseParams(
            name="H",
            decomposition=[
                DecompositionStep(cls.RZ, "all", lambda w: jnp.pi),
                DecompositionStep(cls.RY, "all", lambda w: jnp.pi / 2),
            ],
        )
        cls.CX = PulseParams(
            name="CX",
            decomposition=[
                DecompositionStep(cls.H, "target", lambda w: 0.0),
                DecompositionStep(cls.CZ, "all", lambda w: 0.0),
                DecompositionStep(cls.H, "target", lambda w: 0.0),
            ],
        )
        cls.CY = PulseParams(
            name="CY",
            decomposition=[
                DecompositionStep(cls.RZ, "target", lambda w: -jnp.pi / 2),
                DecompositionStep(cls.CX, "all"),
                DecompositionStep(cls.RZ, "target", lambda w: jnp.pi / 2),
            ],
        )
        cls.CRX = PulseParams(
            name="CRX",
            decomposition=[
                DecompositionStep(cls.RZ, "target", lambda w: jnp.pi / 2),
                DecompositionStep(cls.RY, "target", lambda w: w / 2),
                DecompositionStep(cls.CX, "all", lambda w: 0.0),
                DecompositionStep(cls.RY, "target", lambda w: -w / 2),
                DecompositionStep(cls.CX, "all", lambda w: 0.0),
                DecompositionStep(cls.RZ, "target", lambda w: -jnp.pi / 2),
            ],
        )
        cls.CRY = PulseParams(
            name="CRY",
            decomposition=[
                DecompositionStep(cls.RY, "target", lambda w: w / 2),
                DecompositionStep(cls.CX, "all", lambda w: 0.0),
                DecompositionStep(cls.RY, "target", lambda w: -w / 2),
                DecompositionStep(cls.CX, "all", lambda w: 0.0),
            ],
        )
        cls.CRZ = PulseParams(
            name="CRZ",
            decomposition=[
                DecompositionStep(cls.RZ, "target", lambda w: w / 2),
                DecompositionStep(cls.CX, "all", lambda w: 0.0),
                DecompositionStep(cls.RZ, "target", lambda w: -w / 2),
                DecompositionStep(cls.CX, "all", lambda w: 0.0),
            ],
        )
        cls.Rot = PulseParams(
            name="Rot",
            decomposition=[
                DecompositionStep(cls.RZ, "all", lambda w: w[0]),
                DecompositionStep(cls.RY, "all", lambda w: w[1]),
                DecompositionStep(cls.RZ, "all", lambda w: w[2]),
            ],
        )
        cls.unique_gate_set = [cls.RX, cls.RY, cls.RZ, cls.CZ]

    @classmethod
    def set_envelope(cls, name: str) -> None:
        """Switch pulse envelope and rebuild all PulseParams trees.

        Also updates the coefficient functions used by :class:`PulseGates`.

        Args:
            name: One of :meth:`PulseEnvelope.available` that has an envelope
                function (i.e. not ``"general"``).

        Raises:
            ValueError: If *name* is unknown or has no envelope function.
                Raised before any state is modified.
        """
        info = PulseEnvelope.get(name)  # validates name
        if info["fn"] is None:
            # Such entries carry no RX/RY defaults; rebuilding would fail
            # half-way and leave ``_envelope`` pointing at an unusable shape.
            physical = [
                n for n in PulseEnvelope.available() if PulseEnvelope.get(n)["fn"]
            ]
            raise ValueError(
                f"Pulse envelope '{name}' has no envelope function and cannot "
                f"be used for RX/RY pulses. Choose one of: {physical}"
            )

        cls._envelope = name
        cls._build_leaf_gates()
        cls._build_composite_gates()

        # Rebuild coefficient functions on PulseGates
        coeff_Sx, coeff_Sy = PulseEnvelope.build_coeff_fns(
            info["fn"], PulseGates.omega_c
        )
        PulseGates._coeff_Sx = staticmethod(coeff_Sx)
        PulseGates._coeff_Sy = staticmethod(coeff_Sy)
        PulseGates._active_envelope = name

        log.info(f"Pulse envelope set to '{name}'")

    @classmethod
    def get_envelope(cls) -> str:
        """Return the name of the active pulse envelope."""
        return cls._envelope

    @staticmethod
    def gate_by_name(gate):
        """Return the PulseParams registered for *gate*, or ``None``.

        Args:
            gate: Gate name string, or an object whose ``__name__`` is the
                gate name (e.g. a ``PulseGates`` method).
        """
        if isinstance(gate, str):
            return getattr(PulseInformation, gate, None)
        return getattr(PulseInformation, gate.__name__, None)

    @staticmethod
    def num_params(gate):
        """Return the total pulse-parameter count of *gate*."""
        return len(PulseInformation.gate_by_name(gate))

    @staticmethod
    def update_params(path=None):
        """Load optimized pulse parameters from a CSV file.

        Each non-empty row is ``gate_name, fidelity, p_0, p_1, ...``; the
        parameter values are stored in :attr:`OPTIMIZED_PULSES` under the gate
        name. A missing file is logged as an error, not raised.

        Args:
            path: CSV file to read. Defaults to
                ``<cwd>/qml_essentials/qoc_results.csv``, with the working
                directory resolved at call time rather than import time.
        """
        if path is None:
            path = f"{os.getcwd()}/qml_essentials/qoc_results.csv"

        if not os.path.isfile(path):
            log.error(f"No optimized pulses found at {path}")
            return

        log.info(f"Loading optimized pulses from {path}")
        # newline="" is the csv-module recommended way to open input files.
        with open(path, "r", newline="") as f:
            for row in csv.reader(f):
                if not row:
                    continue  # tolerate blank lines
                gate_name, values = row[0], row[2:]
                log.debug(
                    f"Loading optimized pulses for {gate_name} "
                    f"(Fidelity: {float(row[1]):.5f}): {values}"
                )
                PulseInformation.OPTIMIZED_PULSES[gate_name] = jnp.array(
                    [float(x) for x in values]
                )

    @staticmethod
    def shuffle_params(random_key):
        """Replace every leaf gate's parameters with uniform [0, 1) samples.

        Args:
            random_key: JAX PRNG key; split once per gate in
                :attr:`unique_gate_set`.
        """
        log.info(
            f"Shuffling optimized pulses with random key {random_key} "
            f"of gates {PulseInformation.unique_gate_set}"
        )
        for gate in PulseInformation.unique_gate_set:
            random_key, sub_key = safe_random_split(random_key)
            gate.params = jax.random.uniform(sub_key, (len(gate),))
# Build the default (gaussian) gate trees at import time. Leaf gates are
# created first because the composite trees hold references to them.
for _builder in (
    PulseInformation._build_leaf_gates,
    PulseInformation._build_composite_gates,
):
    _builder()
del _builder
class PulseGates:
    """Pulse-level implementations of quantum gates.

    Implements quantum gates using time-dependent Hamiltonians and pulse
    sequences, following the approach from https://doi.org/10.5445/IR/1000184129.
    The active pulse envelope is selected via
    :meth:`PulseInformation.set_envelope`.

    Attributes:
        omega_q: Qubit frequency (10π).
        omega_c: Carrier frequency (10π).
        _active_envelope: Name of the currently active envelope shape.
    """

    # NOTE: Implementation of S, RX, RY, RZ, CZ, CNOT/CX and H pulse level
    # gates closely follow https://doi.org/10.5445/IR/1000184129
    omega_q = 10 * jnp.pi
    omega_c = 10 * jnp.pi

    # Diagonal phase matrix diag(e^{i ω_q/2}, e^{-i ω_q/2}) used to conjugate
    # the X/Y drive operators below.
    H_static = jnp.array(
        [[jnp.exp(1j * omega_q / 2), 0], [0, jnp.exp(-1j * omega_q / 2)]]
    )

    Id = jnp.eye(2, dtype=jnp.complex64)
    X = jnp.array([[0, 1], [1, 0]])
    Y = jnp.array([[0, -1j], [1j, 0]])
    Z = jnp.array([[1, 0], [0, -1]])

    _H_X = H_static.conj().T @ X @ H_static
    _H_Y = H_static.conj().T @ Y @ H_static

    # (π/4)(I-Z)⊗(I-Z) = diag(0, 0, 0, π): only |11> picks up a phase.
    _H_CZ = (jnp.pi / 4) * (
        jnp.kron(Id, Id) - jnp.kron(Z, Id) - jnp.kron(Id, Z) + jnp.kron(Z, Z)
    )

    _H_corr = jnp.pi / 2 * jnp.eye(2, dtype=jnp.complex64)

    _active_envelope: str = "gaussian"

    @staticmethod
    def _coeff_Sx(p, t):
        """Coefficient function for RX pulse (active envelope).

        Default gaussian version; replaced by
        :meth:`PulseInformation.set_envelope`.
        """
        t_c = t / 2
        env = PulseEnvelope.gaussian(p, t, t_c)
        carrier = jnp.cos(PulseGates.omega_c * t + jnp.pi)
        return env * carrier * p[-1]

    @staticmethod
    def _coeff_Sy(p, t):
        """Coefficient function for RY pulse (active envelope).

        Default gaussian version; replaced by
        :meth:`PulseInformation.set_envelope`.
        """
        t_c = t / 2
        env = PulseEnvelope.gaussian(p, t, t_c)
        carrier = jnp.cos(PulseGates.omega_c * t - jnp.pi / 2)
        return env * carrier * p[-1]

    @staticmethod
    def _coeff_Sz(p, t):
        """Coefficient function for RZ pulse: p * w."""
        return p[0] * p[1]

    @staticmethod
    def _coeff_Sc(p, t):
        """Constant coefficient for H correction phase."""
        return -1.0

    @staticmethod
    def _coeff_Scz(p, t):
        """Coefficient function for CZ pulse."""
        return p * jnp.pi

    @staticmethod
    def _record_pulse_event(gate_name, w, wires, pulse_params, parent=None):
        """Append a PulseEvent to the active pulse tape if recording.

        This is called from leaf gate methods (RX, RY, RZ, CZ) so that
        :func:`~qml_essentials.tape.pulse_recording` can collect events
        without the caller needing to know about the tape. Does nothing when
        no pulse tape is active.
        """
        ptape = active_pulse_tape()
        if ptape is None:
            return

        # Imported lazily; presumably avoids an import cycle with drawing.
        from qml_essentials.drawing import PulseEvent, LEAF_META

        meta = LEAF_META.get(gate_name, {})
        wires_list = [wires] if isinstance(wires, int) else list(wires)

        if meta.get("physical", False):
            info = PulseEnvelope.get(PulseInformation.get_envelope())
            pp = PulseInformation.gate_by_name(gate_name).split_params(pulse_params)
            # Last element is the evolution time; the rest are envelope params.
            env_p = pp[:-1]
            dur = float(pp[-1])
            ptape.append(
                PulseEvent(
                    gate=gate_name,
                    wires=wires_list,
                    envelope_fn=info["fn"],
                    envelope_params=jnp.array(env_p),
                    w=float(w),
                    duration=dur,
                    carrier_phase=meta["carrier_phase"],
                    parent=parent,
                )
            )
        else:
            pp = PulseInformation.gate_by_name(gate_name).split_params(pulse_params)
            ptape.append(
                PulseEvent(
                    gate=gate_name,
                    wires=wires_list,
                    envelope_fn=None,
                    envelope_params=jnp.ravel(jnp.asarray(pp)),
                    w=float(w) if not isinstance(w, list) else 0.0,
                    duration=1.0,
                    carrier_phase=0.0,
                    parent=parent,
                )
            )

    @staticmethod
    def Rot(
        phi: float,
        theta: float,
        omega: float,
        wires: Union[int, List[int]],
        pulse_params: Optional[jnp.ndarray] = None,
        noise_params: Optional[Dict[str, float]] = None,
        random_key: Optional[jax.random.PRNGKey] = None,
    ) -> None:
        """
        Apply general rotation via decomposition: RZ(phi) · RY(theta) · RZ(omega).

        Args:
            phi (float): First rotation angle.
            theta (float): Second rotation angle.
            omega (float): Third rotation angle.
            wires (Union[int, List[int]]): Qubit index or indices to apply rotation to.
            pulse_params (Optional[jnp.ndarray]): Pulse parameters for the
                composing gates. If None, uses optimized parameters.
            noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for compatibility

        Returns:
            None: Gates are applied in-place to the circuit.
        """
        if noise_params is not None and "GateError" in noise_params:
            phi, random_key = UnitaryGates.GateError(phi, noise_params, random_key)
            theta, random_key = UnitaryGates.GateError(theta, noise_params, random_key)
            omega, random_key = UnitaryGates.GateError(omega, noise_params, random_key)
        PulseGates._execute_composite("Rot", [phi, theta, omega], wires, pulse_params)
        UnitaryGates.Noise(wires, noise_params)

    @staticmethod
    def PauliRot(
        pauli: str,
        theta: float,
        wires: Union[int, List[int]],
        pulse_params: Optional[jnp.ndarray] = None,
        noise_params: Optional[Dict[str, float]] = None,
        random_key: Optional[jax.random.PRNGKey] = None,
    ) -> None:
        """Not implemented as a PulseGate.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("PauliRot gate is not implemented as PulseGate")

    @staticmethod
    def RX(
        w: float,
        wires: Union[int, List[int]],
        pulse_params: Optional[jnp.ndarray] = None,
        noise_params: Optional[Dict[str, float]] = None,
        random_key: Optional[jax.random.PRNGKey] = None,
    ) -> None:
        """Apply X-axis rotation using the active pulse envelope.

        Args:
            w: Rotation angle in radians.
            wires: Qubit index or indices.
            pulse_params: Envelope parameters ``[env_0, ..., env_n, t]``.
                If ``None``, uses optimized defaults.
            noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for compatibility
        """
        pulse_params = PulseInformation.RX.split_params(pulse_params)

        PulseGates._record_pulse_event("RX", w, wires, pulse_params)

        _H = op.Hermitian(PulseGates._H_X, wires=wires, record=False)
        H_eff = PulseGates._coeff_Sx * _H

        # Pack: [envelope_params..., w] - evolution time is the last element
        # of pulse_params (pulse_params[-1]).
        w, random_key = UnitaryGates.GateError(w, noise_params, random_key)
        env_params = jnp.array([*pulse_params[:-1], w])
        ys.evolve(H_eff, name="RX")([env_params], pulse_params[-1])
        UnitaryGates.Noise(wires, noise_params)

    @staticmethod
    def RY(
        w: float,
        wires: Union[int, List[int]],
        pulse_params: Optional[jnp.ndarray] = None,
        noise_params: Optional[Dict[str, float]] = None,
        random_key: Optional[jax.random.PRNGKey] = None,
    ) -> None:
        """Apply Y-axis rotation using the active pulse envelope.

        Args:
            w: Rotation angle in radians.
            wires: Qubit index or indices.
            pulse_params: Envelope parameters ``[env_0, ..., env_n, t]``.
                If ``None``, uses optimized defaults.
            noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for compatibility
        """
        pulse_params = PulseInformation.RY.split_params(pulse_params)

        PulseGates._record_pulse_event("RY", w, wires, pulse_params)

        _H = op.Hermitian(PulseGates._H_Y, wires=wires, record=False)
        H_eff = PulseGates._coeff_Sy * _H

        # Pack w into the params so the coefficient function doesn't need
        # a closure - this enables JIT solver cache sharing across all RY calls.
        w, random_key = UnitaryGates.GateError(w, noise_params, random_key)
        env_params = jnp.array([*pulse_params[:-1], w])
        ys.evolve(H_eff, name="RY")([env_params], pulse_params[-1])
        UnitaryGates.Noise(wires, noise_params)

    @staticmethod
    def RZ(
        w: float,
        wires: Union[int, List[int]],
        pulse_params: Optional[float] = None,
        noise_params: Optional[Dict[str, float]] = None,
        random_key: Optional[jax.random.PRNGKey] = None,
    ) -> None:
        """
        Apply Z-axis rotation using pulse-level implementation.

        Implements RZ rotation using virtual Z rotations (phase tracking)
        without physical pulse application.

        Args:
            w (float): Rotation angle in radians.
            wires (Union[int, List[int]]): Qubit index or indices to apply rotation to.
            pulse_params (Optional[float]): Duration parameter for the pulse.
                Rotation angle = w * 2 * pulse_params. Defaults to 0.5 if None.
            noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for compatibility

        Returns:
            None: Gate is applied in-place to the circuit.
        """
        pulse_params = PulseInformation.RZ.split_params(pulse_params)

        PulseGates._record_pulse_event("RZ", w, wires, pulse_params)

        _H = op.Hermitian(PulseGates.Z, wires=wires, record=False)
        H_eff = PulseGates._coeff_Sz * _H

        # Pack w into the params so the coefficient function doesn't need
        # a closure - [pulse_param_scalar, w] enables JIT solver cache sharing.
        # pulse_params may be a 1-element array or scalar; ravel + index to
        # ensure a scalar for concatenation.
        w, random_key = UnitaryGates.GateError(w, noise_params, random_key)
        pp_scalar = jnp.ravel(jnp.asarray(pulse_params))[0]
        ys.evolve(H_eff, name="RZ")([jnp.array([pp_scalar, w])], 1)

        UnitaryGates.Noise(wires, noise_params)

    @staticmethod
    def _resolve_wires(wire_fn, wires):
        """Resolve a wire selector string to actual wire(s).

        Args:
            wire_fn: ``"all"``, ``"target"``, or ``"control"``.
            wires: Parent gate's wire(s) (int or list).

        Returns:
            Wire(s) for the child gate. A single wire is returned as a bare int.

        Raises:
            ValueError: If *wire_fn* is not a known selector.
        """
        wires_list = [wires] if isinstance(wires, int) else list(wires)
        if wire_fn == "all":
            return wires if len(wires_list) > 1 else wires_list[0]
        if wire_fn == "target":
            return wires_list[-1] if len(wires_list) > 1 else wires_list[0]
        if wire_fn == "control":
            return wires_list[0]
        raise ValueError(f"Unknown wire_fn: {wire_fn!r}")

    @staticmethod
    def _execute_composite(gate_name, w, wires, pulse_params=None):
        """Execute a composite gate by walking its decomposition.

        Reads the :class:`DecompositionStep` list from
        :class:`PulseInformation` and dispatches each step to the
        appropriate ``PulseGates`` method. Child gates are called without
        noise parameters; noise is applied once by the parent gate.

        Args:
            gate_name: Gate name (e.g. ``"H"``, ``"CX"``).
            w: Rotation angle(s) passed to the parent gate.
            wires: Wire(s) of the parent gate.
            pulse_params: Optional pulse parameters (split across children).
        """
        pp_obj = PulseInformation.gate_by_name(gate_name)
        parts = pp_obj.split_params(pulse_params)

        for step, child_params in zip(pp_obj.decomposition, parts):
            child_wires = PulseGates._resolve_wires(step.wire_fn, wires)
            child_w = step.angle_fn(w) if step.angle_fn is not None else w
            child_gate = getattr(PulseGates, step.gate.name)

            # Leaf gates that take a rotation angle
            if step.gate.name in ("RX", "RY", "RZ"):
                child_gate(child_w, wires=child_wires, pulse_params=child_params)
            # Leaf gates without a rotation angle
            elif step.gate.name in ("CZ",):
                child_gate(wires=child_wires, pulse_params=child_params)
            # Rot takes three separate angles (phi, theta, omega, wires, ...)
            elif step.gate.name in ("Rot",):
                child_gate(*child_w, wires=child_wires, pulse_params=child_params)
            # Composite gates with a single rotation angle
            elif step.gate.decomposition is not None and step.gate.name in (
                "CRX",
                "CRY",
                "CRZ",
            ):
                child_gate(child_w, wires=child_wires, pulse_params=child_params)
            # Other composite gates (H, CX, CY, ...)
            else:
                child_gate(wires=child_wires, pulse_params=child_params)

    @staticmethod
    def H(
        wires: Union[int, List[int]],
        pulse_params: Optional[jnp.ndarray] = None,
        noise_params: Optional[Dict[str, float]] = None,
        random_key: Optional[jax.random.PRNGKey] = None,
    ) -> None:
        """Apply Hadamard gate using pulse decomposition.

        Decomposes as RZ(π) · RY(π/2) followed by a correction phase.

        Args:
            wires (Union[int, List[int]]): Qubit index or indices.
            pulse_params (Optional[jnp.ndarray]): Pulse parameters for the
                composing gates. If None, uses optimized parameters.
            noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for compatibility
                (not used in this gate).
        """
        PulseGates._execute_composite("H", 0.0, wires, pulse_params)

        # Correction phase unique to the H gate
        _H = op.Hermitian(PulseGates._H_corr, wires=wires, record=False)
        H_corr = PulseGates._coeff_Sc * _H
        ys.evolve(H_corr, name="H")([0], 1)
        UnitaryGates.Noise(wires, noise_params)

    @staticmethod
    def CX(
        wires: List[int],
        pulse_params: Optional[jnp.ndarray] = None,
        noise_params: Optional[Dict[str, float]] = None,
        random_key: Optional[jax.random.PRNGKey] = None,
    ) -> None:
        """Apply CNOT gate via decomposition: H(target) · CZ · H(target).

        Args:
            wires (List[int]): Control and target qubit indices [control, target].
            pulse_params (Optional[jnp.ndarray]): Pulse parameters for the
                composing gates. If None, uses optimized parameters.
            noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for compatibility
                (not used in this gate).

        Returns:
            None: Gate is applied in-place to the circuit.
        """
        PulseGates._execute_composite("CX", 0.0, wires, pulse_params)
        UnitaryGates.Noise(wires, noise_params)

    @staticmethod
    def CY(
        wires: List[int],
        pulse_params: Optional[jnp.ndarray] = None,
        noise_params: Optional[Dict[str, float]] = None,
        random_key: Optional[jax.random.PRNGKey] = None,
    ) -> None:
        """Apply controlled-Y via decomposition: RZ(-π/2) · CX · RZ(π/2) on target.

        Args:
            wires (List[int]): Control and target qubit indices [control, target].
            pulse_params (Optional[jnp.ndarray]): Pulse parameters for the
                composing gates. If None, uses optimized parameters.
            noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for compatibility
                (not used in this gate).
        """
        PulseGates._execute_composite("CY", 0.0, wires, pulse_params)
        UnitaryGates.Noise(wires, noise_params)

    @staticmethod
    def CZ(
        wires: List[int],
        pulse_params: Optional[float] = None,
        noise_params: Optional[Dict[str, float]] = None,
        random_key: Optional[jax.random.PRNGKey] = None,
    ) -> None:
        """Apply controlled-Z using ZZ coupling Hamiltonian.

        Args:
            wires (List[int]): Control and target qubit indices.
            pulse_params (Optional[float]): Time or duration parameter for
                the pulse evolution. If None, uses optimized value.
            noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for compatibility
                (not used in this gate).
        """
        if pulse_params is None:
            pulse_params = PulseInformation.CZ.params

        PulseGates._record_pulse_event("CZ", 0.0, wires, pulse_params)

        _H = op.Hermitian(PulseGates._H_CZ, wires=wires, record=False)
        H_eff = PulseGates._coeff_Scz * _H
        ys.evolve(H_eff, name="CZ")([pulse_params], 1)
        UnitaryGates.Noise(wires, noise_params)

    @staticmethod
    def CRX(
        w: float,
        wires: List[int],
        pulse_params: Optional[jnp.ndarray] = None,
        noise_params: Optional[Dict[str, float]] = None,
        random_key: Optional[jax.random.PRNGKey] = None,
    ) -> None:
        """Apply controlled-RX via decomposition.

        Args:
            w (float): Rotation angle in radians.
            wires (List[int]): Control and target qubit indices [control, target].
            pulse_params (Optional[jnp.ndarray]): Pulse parameters for the
                composing gates. If None, uses optimized parameters.
            noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for compatibility
        """
        # Apply the gate-error model to the angle, as CRY and CRZ do; the
        # child gates run without noise_params and would not add it.
        w, random_key = UnitaryGates.GateError(w, noise_params, random_key)
        PulseGates._execute_composite("CRX", w, wires, pulse_params)
        UnitaryGates.Noise(wires, noise_params)

    @staticmethod
    def CRY(
        w: float,
        wires: List[int],
        pulse_params: Optional[jnp.ndarray] = None,
        noise_params: Optional[Dict[str, float]] = None,
        random_key: Optional[jax.random.PRNGKey] = None,
    ) -> None:
        """Apply controlled-RY via decomposition.

        Args:
            w (float): Rotation angle in radians.
            wires (List[int]): Control and target qubit indices [control, target].
            pulse_params (Optional[jnp.ndarray]): Pulse parameters for the
                composing gates. If None, uses optimized parameters.
            noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for compatibility
        """
        w, random_key = UnitaryGates.GateError(w, noise_params, random_key)
        PulseGates._execute_composite("CRY", w, wires, pulse_params)
        UnitaryGates.Noise(wires, noise_params)

    @staticmethod
    def CRZ(
        w: float,
        wires: List[int],
        pulse_params: Optional[jnp.ndarray] = None,
        noise_params: Optional[Dict[str, float]] = None,
        random_key: Optional[jax.random.PRNGKey] = None,
    ) -> None:
        """Apply controlled-RZ via decomposition.

        Args:
            w (float): Rotation angle in radians.
            wires (List[int]): Control and target qubit indices [control, target].
            pulse_params (Optional[jnp.ndarray]): Pulse parameters for the
                composing gates. If None, uses optimized parameters.
            noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for compatibility
        """
        w, random_key = UnitaryGates.GateError(w, noise_params, random_key)
        PulseGates._execute_composite("CRZ", w, wires, pulse_params)
        UnitaryGates.Noise(wires, noise_params)
class PulseParamManager:
    """Sequential reader over a flat pulse-parameter vector.

    Each :meth:`get` call hands out the next ``n`` entries and advances an
    internal cursor, so consumers can take their share of one shared vector
    in order.
    """

    def __init__(self, pulse_params: jnp.ndarray):
        """
        Args:
            pulse_params: Flat vector of pulse parameters to hand out.
        """
        self.pulse_params = pulse_params
        # Position of the next unread parameter.
        self.idx = 0

    def get(self, n: int):
        """Return the next n parameters and advance the cursor.

        Args:
            n: Number of parameters to take (must be non-negative).

        Returns:
            The slice ``pulse_params[idx:idx + n]`` with size-1 dimensions
            squeezed away (a single parameter comes back as a 0-d array).

        Raises:
            ValueError: If *n* is negative or fewer than *n* parameters remain.
        """
        if n < 0:
            # A negative count would silently move the cursor backwards.
            raise ValueError(
                f"Cannot take a negative number of pulse parameters (got {n})"
            )
        end = self.idx + n
        if end > len(self.pulse_params):
            raise ValueError("Not enough pulse parameters left for this gate")
        # TODO: we squeeze here to get rid of any extra hidden dimension
        params = self.pulse_params[self.idx : end].squeeze()
        self.idx = end
        return params