Coverage for qml_essentials/gates.py: 86%
463 statements
« prev ^ index » next coverage.py v7.9.2, created at 2026-02-20 14:03 +0000
1import os
2from typing import Optional, List, Union, Dict, Callable, Tuple
3import numbers
4import csv
5import jax.numpy as np
6import pennylane as qml
7import jax
8import itertools
9from contextlib import contextmanager
10import logging
12from qml_essentials.utils import safe_random_split
# Module-wide side effect: switch JAX to 64-bit floats/complex numbers for
# every array created after this module is imported.
jax.config.update("jax_enable_x64", True)
# Module-level logger used by the pulse-parameter loading/shuffling helpers.
log = logging.getLogger(__name__)
class UnitaryGates:
    """Collection of unitary quantum gates with optional noise simulation."""

    # When True, an array-valued angle gets one independent Gaussian sample per
    # element in GateError; when False a single sample of shape (1,) is
    # broadcast over all elements.
    batch_gate_error = True

    @staticmethod
    def NQubitDepolarizingChannel(p: float, wires: List[int]) -> qml.QubitChannel:
        """
        Generate Kraus operators for n-qubit depolarizing channel.

        The n-qubit depolarizing channel models uniform depolarizing noise
        acting on n qubits simultaneously, useful for simulating realistic
        multi-qubit noise affecting entangling gates.

        Args:
            p (float): Total probability of depolarizing error (0 ≤ p ≤ 1).
            wires (List[int]): Qubit indices on which the channel acts.
                Must contain at least 2 qubits.

        Returns:
            qml.QubitChannel: PennyLane QubitChannel with Kraus operators
                representing the depolarizing noise channel.

        Raises:
            ValueError: If p is not in [0, 1] or if fewer than 2 qubits provided.
        """

        def n_qubit_depolarizing_kraus(p: float, n: int) -> List[np.ndarray]:
            if not (0.0 <= p <= 1.0):
                raise ValueError(f"Probability p must be between 0 and 1, got {p}")
            if n < 2:
                raise ValueError(f"Number of qubits must be >= 2, got {n}")

            # Explicit single-qubit Pauli matrices. Building them via
            # qml.matrix(qml.PauliX(0)) would instantiate operators, which
            # (inside a QNode) may be recorded as extra gates on wire 0.
            Id = np.eye(2)
            X = np.array([[0.0, 1.0], [1.0, 0.0]])
            Y = np.array([[0.0, -1j], [1j, 0.0]])
            Z = np.array([[1.0, 0.0], [0.0, -1.0]])
            paulis = [Id, X, Y, Z]

            dim = 2**n
            n_paulis = 4**n

            # All n-qubit Pauli tensor products; index 0 is Id^{⊗n}.
            all_ops = []
            for indices in itertools.product(range(4), repeat=n):
                P = np.eye(1)
                for idx in indices:
                    P = np.kron(P, paulis[idx])
                all_ops.append(P)

            # Identity term absorbs the remaining probability so that the
            # Kraus operators satisfy sum_k K_k^† K_k = I.
            K0 = np.sqrt(1 - p * (n_paulis - 1) / n_paulis) * np.eye(dim)
            weight = np.sqrt(p / n_paulis)
            # Skip all_ops[0] (the identity), already handled as K0.
            kraus_ops = [weight * P for P in all_ops[1:]]

            return [K0] + kraus_ops

        return qml.QubitChannel(n_qubit_depolarizing_kraus(p, len(wires)), wires=wires)

    @staticmethod
    def Noise(
        wires: Union[int, List[int]], noise_params: Optional[Dict[str, float]] = None
    ) -> None:
        """
        Apply noise channels to specified qubits.

        Applies various single-qubit and multi-qubit noise channels based on
        the provided noise parameters dictionary.

        Args:
            wires (Union[int, List[int]]): Qubit index or list of qubit indices
                to apply noise to. Any integral scalar (including NumPy
                integers) is treated as a single wire.
            noise_params (Optional[Dict[str, float]]): Dictionary of noise
                parameters. Supported keys:
                - "BitFlip" (float): Bit flip error probability
                - "PhaseFlip" (float): Phase flip error probability
                - "Depolarizing" (float): Single-qubit depolarizing probability
                - "MultiQubitDepolarizing" (float): Multi-qubit depolarizing
                  probability (applies if len(wires) > 1)
                All parameters default to 0.0 if not provided.

        Returns:
            None: Noise channels are applied in-place to the circuit.
        """
        if noise_params is None:
            return

        if isinstance(wires, numbers.Integral):
            wires = [wires]  # single qubit gate

        bf = noise_params.get("BitFlip", 0.0)
        pf = noise_params.get("PhaseFlip", 0.0)
        dp = noise_params.get("Depolarizing", 0.0)

        # noise on single qubits
        for wire in wires:
            if bf > 0:
                qml.BitFlip(bf, wires=wire)
            if pf > 0:
                qml.PhaseFlip(pf, wires=wire)
            if dp > 0:
                qml.DepolarizingChannel(dp, wires=wire)

        # noise acting jointly on all wires of a multi-qubit gate
        if len(wires) > 1:
            p = noise_params.get("MultiQubitDepolarizing", 0.0)
            if p > 0:
                UnitaryGates.NQubitDepolarizingChannel(p, wires)

    @staticmethod
    def GateError(
        w: Union[float, np.ndarray, List[float]],
        noise_params: Optional[Dict[str, float]] = None,
        random_key: Optional[jax.random.PRNGKey] = None,
    ) -> Tuple[np.ndarray, jax.random.PRNGKey]:
        """
        Apply gate error noise to rotation angle(s).

        Adds Gaussian noise to gate rotation angles to simulate imperfect
        gate implementations. The input is never modified in place.

        Args:
            w (Union[float, np.ndarray, List[float]]): Rotation angle(s) in radians.
                Lists/tuples are converted to an array before noise is added.
            noise_params (Optional[Dict[str, float]]): Dictionary with optional
                "GateError" key specifying standard deviation of Gaussian noise.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for
                stochastic noise generation.

        Returns:
            Tuple[np.ndarray, jax.random.PRNGKey]: Tuple containing:
                - Modified rotation angle(s) with applied noise (``w`` unchanged
                  when no "GateError" is configured)
                - Updated JAX random key

        Raises:
            AssertionError: If noise_params contains "GateError" but random_key is None.
        """
        if noise_params is None or noise_params.get("GateError", None) is None:
            return w, random_key

        assert (
            random_key is not None
        ), "A random_key must be provided when using GateError"

        # `list += array` would *extend* the list rather than add noise, so
        # sequences are turned into an array first.
        if isinstance(w, (list, tuple)):
            w = np.asarray(w)

        random_key, sub_key = safe_random_split(random_key)
        shape = (
            w.shape
            if isinstance(w, np.ndarray) and UnitaryGates.batch_gate_error
            else (1,)
        )
        # Out-of-place addition so an array owned by the caller is not mutated.
        w = w + noise_params["GateError"] * jax.random.normal(sub_key, shape)
        return w, random_key

    @staticmethod
    def Rot(
        phi: Union[float, np.ndarray, List[float]],
        theta: Union[float, np.ndarray, List[float]],
        omega: Union[float, np.ndarray, List[float]],
        wires: Union[int, List[int]],
        noise_params: Optional[Dict[str, float]] = None,
        random_key: Optional[jax.random.PRNGKey] = None,
    ) -> None:
        """
        Apply general rotation gate with optional noise.

        Applies a three-angle rotation Rot(phi, theta, omega) with optional
        gate errors and noise channels. Each angle receives its own gate-error
        sample, drawn with successively split random keys.

        Args:
            phi (Union[float, np.ndarray, List[float]]): First rotation angle.
            theta (Union[float, np.ndarray, List[float]]): Second rotation angle.
            omega (Union[float, np.ndarray, List[float]]): Third rotation angle.
            wires (Union[int, List[int]]): Qubit index or indices to apply rotation to.
            noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
                Supports BitFlip, PhaseFlip, Depolarizing, and GateError.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for noise.

        Returns:
            None: Gate and noise are applied in-place to the circuit.
        """
        # GateError is a no-op when no "GateError" is configured.
        phi, random_key = UnitaryGates.GateError(phi, noise_params, random_key)
        theta, random_key = UnitaryGates.GateError(theta, noise_params, random_key)
        omega, random_key = UnitaryGates.GateError(omega, noise_params, random_key)
        qml.Rot(phi, theta, omega, wires=wires)
        UnitaryGates.Noise(wires, noise_params)

    @staticmethod
    def RX(
        w: Union[float, np.ndarray, List[float]],
        wires: Union[int, List[int]],
        noise_params: Optional[Dict[str, float]] = None,
        random_key: Optional[jax.random.PRNGKey] = None,
    ) -> None:
        """
        Apply X-axis rotation with optional noise.

        Args:
            w (Union[float, np.ndarray, List[float]]): Rotation angle.
            wires (Union[int, List[int]]): Qubit index or indices.
            noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for noise.

        Returns:
            None: Gate and noise are applied in-place to the circuit.
        """
        w, random_key = UnitaryGates.GateError(w, noise_params, random_key)
        qml.RX(w, wires=wires)
        UnitaryGates.Noise(wires, noise_params)

    @staticmethod
    def RY(
        w: Union[float, np.ndarray, List[float]],
        wires: Union[int, List[int]],
        noise_params: Optional[Dict[str, float]] = None,
        random_key: Optional[jax.random.PRNGKey] = None,
    ) -> None:
        """
        Apply Y-axis rotation with optional noise.

        Args:
            w (Union[float, np.ndarray, List[float]]): Rotation angle.
            wires (Union[int, List[int]]): Qubit index or indices.
            noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for noise.

        Returns:
            None: Gate and noise are applied in-place to the circuit.
        """
        w, random_key = UnitaryGates.GateError(w, noise_params, random_key)
        qml.RY(w, wires=wires)
        UnitaryGates.Noise(wires, noise_params)

    @staticmethod
    def RZ(
        w: Union[float, np.ndarray, List[float]],
        wires: Union[int, List[int]],
        noise_params: Optional[Dict[str, float]] = None,
        random_key: Optional[jax.random.PRNGKey] = None,
    ) -> None:
        """
        Apply Z-axis rotation with optional noise.

        Args:
            w (Union[float, np.ndarray, List[float]]): Rotation angle.
            wires (Union[int, List[int]]): Qubit index or indices.
            noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for noise.

        Returns:
            None: Gate and noise are applied in-place to the circuit.
        """
        w, random_key = UnitaryGates.GateError(w, noise_params, random_key)
        qml.RZ(w, wires=wires)
        UnitaryGates.Noise(wires, noise_params)

    @staticmethod
    def CRX(
        w: Union[float, np.ndarray, List[float]],
        wires: Union[int, List[int]],
        noise_params: Optional[Dict[str, float]] = None,
        random_key: Optional[jax.random.PRNGKey] = None,
    ) -> None:
        """
        Apply controlled X-rotation with optional noise.

        Args:
            w (Union[float, np.ndarray, List[float]]): Rotation angle.
            wires (Union[int, List[int]]): Control and target qubit indices.
            noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for noise.

        Returns:
            None: Gate and noise are applied in-place to the circuit.
        """
        w, random_key = UnitaryGates.GateError(w, noise_params, random_key)
        qml.CRX(w, wires=wires)
        UnitaryGates.Noise(wires, noise_params)

    @staticmethod
    def CRY(
        w: Union[float, np.ndarray, List[float]],
        wires: Union[int, List[int]],
        noise_params: Optional[Dict[str, float]] = None,
        random_key: Optional[jax.random.PRNGKey] = None,
    ) -> None:
        """
        Apply controlled Y-rotation with optional noise.

        Args:
            w (Union[float, np.ndarray, List[float]]): Rotation angle.
            wires (Union[int, List[int]]): Control and target qubit indices.
            noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for noise.

        Returns:
            None: Gate and noise are applied in-place to the circuit.
        """
        w, random_key = UnitaryGates.GateError(w, noise_params, random_key)
        qml.CRY(w, wires=wires)
        UnitaryGates.Noise(wires, noise_params)

    @staticmethod
    def CRZ(
        w: Union[float, np.ndarray, List[float]],
        wires: Union[int, List[int]],
        noise_params: Optional[Dict[str, float]] = None,
        random_key: Optional[jax.random.PRNGKey] = None,
    ) -> None:
        """
        Apply controlled Z-rotation with optional noise.

        Args:
            w (Union[float, np.ndarray, List[float]]): Rotation angle.
            wires (Union[int, List[int]]): Control and target qubit indices.
            noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for noise.

        Returns:
            None: Gate and noise are applied in-place to the circuit.
        """
        w, random_key = UnitaryGates.GateError(w, noise_params, random_key)
        qml.CRZ(w, wires=wires)
        UnitaryGates.Noise(wires, noise_params)

    @staticmethod
    def CX(
        wires: Union[int, List[int]],
        noise_params: Optional[Dict[str, float]] = None,
        random_key: Optional[jax.random.PRNGKey] = None,
    ) -> None:
        """
        Apply controlled-NOT (CNOT) gate with optional noise.

        Args:
            wires (Union[int, List[int]]): Control and target qubit indices.
            noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for compatibility
                (not used in this gate).

        Returns:
            None: Gate and noise are applied in-place to the circuit.
        """
        qml.CNOT(wires=wires)
        UnitaryGates.Noise(wires, noise_params)

    @staticmethod
    def CY(
        wires: Union[int, List[int]],
        noise_params: Optional[Dict[str, float]] = None,
        random_key: Optional[jax.random.PRNGKey] = None,
    ) -> None:
        """
        Apply controlled-Y gate with optional noise.

        Args:
            wires (Union[int, List[int]]): Control and target qubit indices.
            noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for compatibility
                (not used in this gate).

        Returns:
            None: Gate and noise are applied in-place to the circuit.
        """
        qml.CY(wires=wires)
        UnitaryGates.Noise(wires, noise_params)

    @staticmethod
    def CZ(
        wires: Union[int, List[int]],
        noise_params: Optional[Dict[str, float]] = None,
        random_key: Optional[jax.random.PRNGKey] = None,
    ) -> None:
        """
        Apply controlled-Z gate with optional noise.

        Args:
            wires (Union[int, List[int]]): Control and target qubit indices.
            noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for compatibility
                (not used in this gate).

        Returns:
            None: Gate and noise are applied in-place to the circuit.
        """
        qml.CZ(wires=wires)
        UnitaryGates.Noise(wires, noise_params)

    @staticmethod
    def H(
        wires: Union[int, List[int]],
        noise_params: Optional[Dict[str, float]] = None,
        random_key: Optional[jax.random.PRNGKey] = None,
    ) -> None:
        """
        Apply Hadamard gate with optional noise.

        Args:
            wires (Union[int, List[int]]): Qubit index or indices.
            noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for compatibility
                (not used in this gate).

        Returns:
            None: Gate and noise are applied in-place to the circuit.
        """
        qml.Hadamard(wires=wires)
        UnitaryGates.Noise(wires, noise_params)
class PulseParams:
    """
    Container for hierarchical pulse parameters.

    Manages pulse parameters for quantum gates, supporting both leaf nodes
    (gates with direct parameters) and composite nodes (gates decomposed
    into simpler gates). Enables hierarchical parameter access and
    manipulation.

    Composite nodes do not copy their children's parameters. Reading
    ``params`` concatenates the children's current values, and writing
    ``params`` slices the new value across the children in order. The same
    child object may appear several times in a composite (e.g. ``[H, CZ, H]``).
    In that case every occurrence reads the same values, and on write the
    last slice assigned to that child wins.

    Attributes:
        name (str): Name identifier for the gate.
        _params (np.ndarray): Direct pulse parameters (leaf nodes only).
        _pulse_obj (List): Child PulseParams objects (composite nodes only).
    """

    def __init__(
        self,
        name: str = "",
        params: Optional[np.ndarray] = None,
        pulse_obj: Optional[List] = None,
    ) -> None:
        """
        Initialize pulse parameters container.

        Args:
            name (str): Name identifier for the gate. Defaults to empty string.
            params (Optional[np.ndarray]): Direct pulse parameters for leaf gates.
                Mutually exclusive with pulse_obj.
            pulse_obj (Optional[List]): List of child PulseParams for composite
                gates. Mutually exclusive with params.

        Raises:
            AssertionError: If both or neither of params and pulse_obj are provided.
        """
        assert (params is None) != (
            pulse_obj is None
        ), "Exactly one of `params` or `pulse_obj` must be provided."

        self._pulse_obj = pulse_obj

        if params is not None:
            self._params = params

        self.name = name

    def __len__(self) -> int:
        """
        Get the total number of pulse parameters.

        For composite gates, returns the accumulated count from all children
        (a child listed several times is counted each time).

        Returns:
            int: Total number of pulse parameters.
        """
        return len(self.params)

    def __getitem__(self, idx: int) -> Union[float, np.ndarray]:
        """
        Access pulse parameter(s) by index.

        For leaf gates, returns the parameter at the given index.
        For composite gates, returns parameters of the child at the given index.

        Args:
            idx (int): Index to access.

        Returns:
            Union[float, np.ndarray]: Parameter value or child parameters.
        """
        if self.is_leaf:
            return self.params[idx]
        return self.childs[idx].params

    def __str__(self) -> str:
        """Return string representation (gate name)."""
        return self.name

    def __repr__(self) -> str:
        """Return repr string (gate name)."""
        return self.name

    @property
    def is_leaf(self) -> bool:
        """Check if this is a leaf node (direct parameters, no children)."""
        return self._pulse_obj is None

    @property
    def size(self) -> int:
        """Get the total parameter count (alias for __len__)."""
        return len(self)

    @property
    def leafs(self) -> List["PulseParams"]:
        """
        Get all leaf nodes in the hierarchy.

        Recursively collects all leaf PulseParams objects in the tree.

        Returns:
            List[PulseParams]: Unique leaf nodes in depth-first, first-seen
                order. The order is deterministic, so ``leaf_params`` has a
                stable layout across runs.
        """
        if self.is_leaf:
            return [self]

        leafs = []
        for obj in self._pulse_obj:
            leafs.extend(obj.leafs)

        # dict.fromkeys removes duplicates (shared children) while keeping
        # insertion order; a set would order by object id, i.e. arbitrarily.
        return list(dict.fromkeys(leafs))

    @property
    def childs(self) -> List["PulseParams"]:
        """
        Get direct children of this node.

        Returns:
            List[PulseParams]: List of child PulseParams objects, or empty list
                if this is a leaf node.
        """
        if self.is_leaf:
            return []

        return self._pulse_obj

    @property
    def shape(self) -> List:
        """
        Get the shape of pulse parameters.

        For leaf nodes, returns a one-element list with the parameter count.
        For composite nodes, returns one entry per child: an int for a leaf
        child and a nested shape list for a composite child.

        Returns:
            List: Parameter shape specification, e.g. ``[1, 3]`` for a node
                with two leaf children or ``[[1, 3], 1, [1, 3]]`` when some
                children are composites.
        """
        if self.is_leaf:
            return [len(self.params)]

        shape = []
        for obj in self.childs:
            # `shape` is a property, so it must not be called.
            shape.append(obj.shape[0] if obj.is_leaf else obj.shape)

        return shape

    @property
    def params(self) -> np.ndarray:
        """
        Get or compute pulse parameters.

        For leaf nodes, returns internal pulse parameters.
        For composite nodes, returns concatenated parameters from all children.

        Returns:
            np.ndarray: Pulse parameters array.
        """
        if self.is_leaf:
            return self._params

        return np.concatenate(self.split_params(params=None, leafs=False))

    @params.setter
    def params(self, value: np.ndarray) -> None:
        """
        Set pulse parameters.

        For leaf nodes, sets internal parameters directly.
        For composite nodes, distributes consecutive slices of ``value``
        across the direct children.

        Args:
            value (np.ndarray): Pulse parameters to set.

        Raises:
            AssertionError: If value is not np.ndarray for leaf nodes.
        """
        if self.is_leaf:
            assert isinstance(value, np.ndarray), "params must be a np.ndarray"
            self._params = value
            return

        idx = 0
        for obj in self.childs:
            nidx = idx + obj.size
            obj.params = value[idx:nidx]
            idx = nidx

    @property
    def leaf_params(self) -> np.ndarray:
        """
        Get parameters from all leaf nodes.

        Returns:
            np.ndarray: Concatenated parameters from all unique leaf nodes,
                in the order given by ``leafs``.
        """
        if self.is_leaf:
            return self._params

        return np.concatenate(self.split_params(None, leafs=True))

    @leaf_params.setter
    def leaf_params(self, value: np.ndarray) -> None:
        """
        Set parameters for all leaf nodes.

        Args:
            value (np.ndarray): Parameters to distribute across leaf nodes,
                laid out in the order given by ``leafs``.
        """
        if self.is_leaf:
            self._params = value
            return

        idx = 0
        for obj in self.leafs:
            nidx = idx + obj.size
            obj.params = value[idx:nidx]
            idx = nidx

    def split_params(
        self,
        params: Optional[np.ndarray] = None,
        leafs: bool = False,
    ) -> List[np.ndarray]:
        """
        Split parameters into sub-arrays for children or leaves.

        Args:
            params (Optional[np.ndarray]): Parameters to split. If None,
                uses internal parameters.
            leafs (bool): If True, splits across leaf nodes; if False,
                splits across direct children. Defaults to False.

        Returns:
            List[np.ndarray]: List of parameter arrays for children or leaves.
                For a leaf node the (unsplit) array itself is returned.
        """
        if self.is_leaf:
            return self._params if params is None else params

        objs = self.leafs if leafs else self.childs

        if params is None:
            return [obj.params for obj in objs]

        s_params = []
        idx = 0
        for obj in objs:
            nidx = idx + obj.size
            s_params.append(params[idx:nidx])
            idx = nidx

        return s_params
class PulseInformation:
    """
    Stores pulse parameter counts and optimized pulse parameters for quantum Gates.

    The leaf gates RX, RY, RZ and CZ hold their parameters directly. Every
    other entry is a composite built from those same objects, so changing a
    leaf's parameters (e.g. via ``update_params`` or ``shuffle_params``) is
    visible through every composite that contains it.
    """

    RX = PulseParams(
        name="RX",
        params=np.array([15.863171563255692, 29.66617464185762, 0.7544382603281181]),
    )
    RY = PulseParams(
        name="RY",
        params=np.array([7.921864297441735, 22.038129802391797, 1.0940923114464387]),
    )
    RZ = PulseParams(name="RZ", params=np.array([0.5]))
    CZ = PulseParams(name="CZ", params=np.array([0.3183095268754836]))
    H = PulseParams(
        name="H",
        pulse_obj=[RZ, RY],
    )

    CX = PulseParams(name="CX", pulse_obj=[H, CZ, H])
    CY = PulseParams(name="CY", pulse_obj=[RZ, CX, RZ])

    CRX = PulseParams(name="CRX", pulse_obj=[RZ, RY, CX, RY, CX, RZ])
    CRY = PulseParams(name="CRY", pulse_obj=[RY, CX, RY, CX])
    CRZ = PulseParams(name="CRZ", pulse_obj=[RZ, CX, RZ, CX])

    Rot = PulseParams(name="Rot", pulse_obj=[RZ, RY, RZ])

    # Gates that own their parameters; all composites above are built from these.
    unique_gate_set = [
        RX,
        RY,
        RZ,
        CZ,
    ]

    @staticmethod
    def gate_by_name(gate):
        """
        Look up the PulseParams entry for a gate.

        Args:
            gate: Gate name (str) or any object with a ``__name__`` (e.g. a
                gate function) whose name matches a class attribute.

        Returns:
            Optional[PulseParams]: The matching entry, or None if there is no
                PulseParams attribute with that name.
        """
        name = gate if isinstance(gate, str) else gate.__name__
        candidate = getattr(PulseInformation, name, None)
        # Guard against names of non-gate attributes (methods, unique_gate_set).
        return candidate if isinstance(candidate, PulseParams) else None

    @staticmethod
    def num_params(gate):
        """Return the total number of pulse parameters of ``gate``."""
        return len(PulseInformation.gate_by_name(gate))

    @staticmethod
    def update_params(path=None):
        """
        Load optimized pulse parameters from a CSV file into the gate entries.

        Each row is expected as ``gate_name, fidelity, param_0, param_1, ...``.
        Rows for unknown gates, or with a parameter count that does not match
        the gate, are skipped with a warning.

        Args:
            path (Optional[str]): CSV file to read. Defaults to
                ``<current working directory>/qml_essentials/qoc_results.csv``,
                resolved at call time.
        """
        if path is None:
            path = f"{os.getcwd()}/qml_essentials/qoc_results.csv"

        if not os.path.isfile(path):
            log.error(f"No optimized pulses found at {path}")
            return

        log.info(f"Loading optimized pulses from {path}")
        with open(path, "r", newline="") as f:
            for row in csv.reader(f):
                if not row:
                    continue  # tolerate blank lines
                name, fidelity, *values = row

                gate = PulseInformation.gate_by_name(name)
                if gate is None:
                    log.warning(f"Skipping optimized pulses for unknown gate {name}")
                    continue
                if len(values) != len(gate):
                    log.warning(
                        f"Skipping optimized pulses for {name}: expected "
                        f"{len(gate)} parameters, got {len(values)}"
                    )
                    continue

                log.debug(
                    f"Loading optimized pulses for {name} "
                    f"(Fidelity: {float(fidelity):.5f}): {values}"
                )
                # The setter distributes values across children for composites.
                gate.params = np.array([float(x) for x in values])

    @staticmethod
    def shuffle_params(random_key):
        """
        Replace the parameters of every leaf gate with uniform random values.

        Args:
            random_key: JAX random key; a fresh sub-key is split off per gate.
        """
        log.info(
            f"Shuffling optimized pulses with random key {random_key} "
            f"of gates {PulseInformation.unique_gate_set}"
        )
        for gate in PulseInformation.unique_gate_set:
            random_key, sub_key = safe_random_split(random_key)
            gate.params = jax.random.uniform(sub_key, (len(gate),))
class PulseGates:
    """
    Pulse-level implementations of quantum gates.

    Implements quantum gates using time-dependent Hamiltonians and pulse
    sequences, following the approach from https://doi.org/10.5445/IR/1000184129.
    Gates are decomposed using shaped Gaussian pulses with carrier modulation.

    Attributes:
        omega_q (float): Qubit frequency (10π).
        omega_c (float): Carrier frequency (10π).
        H_static (np.ndarray): Diagonal 2x2 phase matrix
            diag(exp(i·omega_q/2), exp(-i·omega_q/2)); RX/RY use it to
            conjugate X/Y (H_static^† · P · H_static) into the drive frame.
        Id, X, Y, Z (np.ndarray): Identity and Pauli matrices for gate construction.
    """

    # NOTE: Implementation of S, RX, RY, RZ, CZ, CNOT/CX and H pulse level
    # gates closely follow https://doi.org/10.5445/IR/1000184129
    # TODO: Mention deviations from the above?
    omega_q = 10 * np.pi
    omega_c = 10 * np.pi

    H_static = np.array([[np.exp(1j * omega_q / 2), 0], [0, np.exp(-1j * omega_q / 2)]])

    # NOTE(review): Id is complex64 while X/Y/Z use default dtypes (x64 is
    # enabled module-wide); products with Id may be computed in reduced
    # precision — confirm this is intended.
    Id = np.eye(2, dtype=np.complex64)
    X = np.array([[0, 1], [1, 0]])
    Y = np.array([[0, -1j], [1j, 0]])
    Z = np.array([[1, 0], [0, -1]])
798 @staticmethod
799 def _S(
800 p: Union[List[float], np.ndarray],
801 t: Union[float, List[float], np.ndarray],
802 phi_c: float,
803 ) -> np.ndarray:
804 """
805 Generate shaped Gaussian pulse envelope with carrier modulation.
807 Internal helper function for creating time-dependent pulse shapes
808 used in rotation gates. Not intended for direct circuit use.
810 Args:
811 p (Union[List[float], np.ndarray]): Pulse parameters [A, sigma]:
812 - A (float): Amplitude of the Gaussian envelope
813 - sigma (float): Width (standard deviation) of the Gaussian
814 t (Union[float, List[float], np.ndarray]): Time or time interval
815 for pulse application. If sequence, center is computed as midpoint.
816 phi_c (float): Phase offset for the cosine carrier.
818 Returns:
819 np.ndarray: Shaped pulse amplitude at time(s) t.
820 """
821 A, sigma = p
822 t_c = (t[0] + t[1]) / 2 if isinstance(t, (list, tuple)) else t / 2
824 f = A * np.exp(-0.5 * ((t - t_c) / sigma) ** 2)
825 x = np.cos(PulseGates.omega_c * t + phi_c)
827 return f * x
829 @staticmethod
830 def Rot(
831 phi: float,
832 theta: float,
833 omega: float,
834 wires: Union[int, List[int]],
835 pulse_params: Optional[np.ndarray] = None,
836 ) -> None:
837 """
838 Apply general single-qubit rotation using pulse decomposition.
840 Decomposes a general rotation into RZ(phi) · RY(theta) · RZ(omega)
841 and applies each component using pulse-level implementations.
843 Args:
844 phi (float): First rotation angle.
845 theta (float): Second rotation angle.
846 omega (float): Third rotation angle.
847 wires (Union[int, List[int]]): Qubit index or indices to apply rotation to.
848 pulse_params (Optional[np.ndarray]): Pulse parameters for the
849 composing gates. If None, uses optimized parameters.
851 Returns:
852 None: Gates are applied in-place to the circuit.
853 """
854 params_RZ_1, params_RY, params_RZ_2 = PulseInformation.Rot.split_params(
855 pulse_params
856 )
858 PulseGates.RZ(phi, wires=wires, pulse_params=params_RZ_1)
859 PulseGates.RY(theta, wires=wires, pulse_params=params_RY)
860 PulseGates.RZ(omega, wires=wires, pulse_params=params_RZ_2)
862 @staticmethod
863 def RX(
864 w: float,
865 wires: Union[int, List[int]],
866 pulse_params: Optional[np.ndarray] = None,
867 ) -> None:
868 """
869 Apply X-axis rotation using pulse-level implementation.
871 Implements RX rotation using a shaped Gaussian pulse with optimized
872 envelope parameters.
874 Args:
875 w (float): Rotation angle in radians.
876 wires (Union[int, List[int]]): Qubit index or indices to apply rotation to.
877 pulse_params (Optional[np.ndarray]): Array containing pulse parameters
878 [A, sigma, t] for the Gaussian envelope. If None, uses optimized
879 parameters.
881 Returns:
882 None: Gate is applied in-place to the circuit.
883 """
884 pulse_params = PulseInformation.RX.split_params(pulse_params)
886 def Sx(p, t):
887 return PulseGates._S(p, t, phi_c=np.pi) * w
889 _H = PulseGates.H_static.conj().T @ PulseGates.X @ PulseGates.H_static
890 _H = qml.Hermitian(_H, wires=wires)
891 H_eff = Sx * _H
893 qml.evolve(H_eff)([pulse_params[0:2]], pulse_params[2])
895 @staticmethod
896 def RY(
897 w: float,
898 wires: Union[int, List[int]],
899 pulse_params: Optional[np.ndarray] = None,
900 ) -> None:
901 """
902 Apply Y-axis rotation using pulse-level implementation.
904 Implements RY rotation using a shaped Gaussian pulse with optimized
905 envelope parameters.
907 Args:
908 w (float): Rotation angle in radians.
909 wires (Union[int, List[int]]): Qubit index or indices to apply rotation to.
910 pulse_params (Optional[np.ndarray]): Array containing pulse parameters
911 [A, sigma, t] for the Gaussian envelope. If None, uses optimized
912 parameters.
914 Returns:
915 None: Gate is applied in-place to the circuit.
916 """
917 pulse_params = PulseInformation.RY.split_params(pulse_params)
919 def Sy(p, t):
920 return PulseGates._S(p, t, phi_c=-np.pi / 2) * w
922 _H = PulseGates.H_static.conj().T @ PulseGates.Y @ PulseGates.H_static
923 _H = qml.Hermitian(_H, wires=wires)
924 H_eff = Sy * _H
926 qml.evolve(H_eff)([pulse_params[0:2]], pulse_params[2])
928 @staticmethod
929 def RZ(
930 w: float, wires: Union[int, List[int]], pulse_params: Optional[float] = None
931 ) -> None:
932 """
933 Apply Z-axis rotation using pulse-level implementation.
935 Implements RZ rotation using virtual Z rotations (phase tracking)
936 without physical pulse application.
938 Args:
939 w (float): Rotation angle in radians.
940 wires (Union[int, List[int]]): Qubit index or indices to apply rotation to.
941 pulse_params (Optional[float]): Duration parameter for the pulse.
942 Rotation angle = w * 2 * pulse_params. Defaults to 0.5 if None.
944 Returns:
945 None: Gate is applied in-place to the circuit.
946 """
947 pulse_params = PulseInformation.RZ.split_params(pulse_params)
949 _H = qml.Hermitian(PulseGates.Z, wires=wires)
951 def Sz(p, t):
952 return p * w
954 H_eff = Sz * _H
956 qml.evolve(H_eff)([pulse_params], 1)
958 @staticmethod
959 def H(
960 wires: Union[int, List[int]], pulse_params: Optional[np.ndarray] = None
961 ) -> None:
962 """
963 Apply Hadamard gate using pulse decomposition.
965 Implements Hadamard as RZ(π) · RY(π/2) with a correction phase,
966 using pulse-level implementations for each component.
968 Args:
969 wires (Union[int, List[int]]): Qubit index or indices to apply gate to.
970 pulse_params (Optional[np.ndarray]): Pulse parameters for the
971 composing gates. If None, uses optimized parameters.
973 Returns:
974 None: Gate is applied in-place to the circuit.
975 """
976 pulse_params_RZ, pulse_params_RY = PulseInformation.H.split_params(pulse_params)
978 # qml.GlobalPhase(-np.pi / 2) # this could act as substitute to Sc
979 PulseGates.RZ(np.pi, wires=wires, pulse_params=pulse_params_RZ)
980 PulseGates.RY(np.pi / 2, wires=wires, pulse_params=pulse_params_RY)
982 def Sc(p, t):
983 return -1.0
985 _H = np.pi / 2 * np.eye(2, dtype=np.complex64)
986 _H = qml.Hermitian(_H, wires=wires)
987 H_corr = Sc * _H
989 qml.evolve(H_corr)([0], 1)
991 @staticmethod
992 def CX(wires: List[int], pulse_params: Optional[np.ndarray] = None) -> None:
993 """
994 Apply CNOT gate using pulse decomposition.
996 Implements CNOT as H_target · CZ · H_target, where H and CZ are
997 applied using their respective pulse-level implementations.
999 Args:
1000 wires (List[int]): Control and target qubit indices [control, target].
1001 pulse_params (Optional[np.ndarray]): Pulse parameters for the
1002 composing gates. If None, uses optimized parameters.
1004 Returns:
1005 None: Gate is applied in-place to the circuit.
1006 """
1007 params_H_1, params_CZ, params_H_2 = PulseInformation.CX.split_params(
1008 pulse_params
1009 )
1011 target = wires[1]
1013 PulseGates.H(wires=target, pulse_params=params_H_1)
1014 PulseGates.CZ(wires=wires, pulse_params=params_CZ)
1015 PulseGates.H(wires=target, pulse_params=params_H_2)
1017 @staticmethod
1018 def CY(wires: List[int], pulse_params: Optional[np.ndarray] = None) -> None:
1019 """
1020 Apply controlled-Y gate using pulse decomposition.
1022 Implements CY as RZ(-π/2)_target · CX · RZ(π/2)_target using
1023 pulse-level implementations.
1025 Args:
1026 wires (List[int]): Control and target qubit indices [control, target].
1027 pulse_params (Optional[np.ndarray]): Pulse parameters for the
1028 composing gates. If None, uses optimized parameters.
1030 Returns:
1031 None: Gate is applied in-place to the circuit.
1032 """
1033 params_RZ_1, params_CX, params_RZ_2 = PulseInformation.CY.split_params(
1034 pulse_params
1035 )
1037 target = wires[1]
1039 PulseGates.RZ(-np.pi / 2, wires=target, pulse_params=params_RZ_1)
1040 PulseGates.CX(wires=wires, pulse_params=params_CX)
1041 PulseGates.RZ(np.pi / 2, wires=target, pulse_params=params_RZ_2)
@staticmethod
def CZ(
    wires: List[int], pulse_params: Optional[Union[float, np.ndarray]] = None
) -> None:
    """
    Apply controlled-Z gate using pulse-level implementation.

    Builds the two-qubit Hamiltonian H = (π/4)(II - ZI - IZ + ZZ) on
    ``wires`` and evolves it for t = 1 with the time-independent envelope
    ``p * π``, where ``p`` is the pulse parameter. Assuming
    ``PulseGates.Z``/``PulseGates.Id`` are Pauli-Z and the 2x2 identity,
    H equals π·|11⟩⟨11|, so only the |11⟩ component acquires a phase.

    Args:
        wires (List[int]): Control and target qubit indices.
        pulse_params (Optional[Union[float, np.ndarray]]): Amplitude
            parameter ``p`` of the constant envelope. May be a scalar or a
            single-element array (e.g. when supplied by the pulse parameter
            manager). If None, ``PulseInformation.CZ.params`` is used.

    Returns:
        None: Gate is applied in-place to the circuit.
    """
    if pulse_params is None:
        pulse_params = PulseInformation.CZ.params

    I_I = np.kron(PulseGates.Id, PulseGates.Id)
    Z_I = np.kron(PulseGates.Z, PulseGates.Id)
    I_Z = np.kron(PulseGates.Id, PulseGates.Z)
    Z_Z = np.kron(PulseGates.Z, PulseGates.Z)

    def Scz(p, t):
        # Constant envelope: amplitude depends only on the parameter, not t.
        return p * np.pi

    _H = (np.pi / 4) * (I_I - Z_I - I_Z + Z_Z)
    _H = qml.Hermitian(_H, wires=wires)
    H_eff = Scz * _H

    # evolve(params, t): one parameter set for the single envelope, t = 1.
    qml.evolve(H_eff)([pulse_params], 1)
@staticmethod
def CRX(
    w: float, wires: List[int], pulse_params: Optional[np.ndarray] = None
) -> None:
    """
    Apply a controlled-RX gate built from pulse-level primitives.

    Wraps the controlled-RY construction (RY(w/2) · CX · RY(-w/2) · CX on
    the target) in RZ(π/2) ... RZ(-π/2) basis changes, following
    arXiv:2408.01036. Gates are listed in application order.

    Args:
        w (float): Rotation angle in radians.
        wires (List[int]): ``[control, target]`` qubit indices.
        pulse_params (Optional[np.ndarray]): Concatenated pulse parameters
            for the six sub-gates, split with
            ``PulseInformation.CRX.split_params``. If None, optimized
            parameters are used.

    Returns:
        None: The gate is applied to the circuit being built.
    """
    (
        rz_pre,
        ry_forward,
        cx_first,
        ry_backward,
        cx_second,
        rz_post,
    ) = PulseInformation.CRX.split_params(pulse_params)
    tgt = wires[1]

    # RZ(-pi/2) · RY(theta) · RZ(pi/2) = RX(theta): the Z rotations rotate
    # the controlled-RY core into a controlled-RX.
    PulseGates.RZ(np.pi / 2, wires=tgt, pulse_params=rz_pre)
    PulseGates.RY(w / 2, wires=tgt, pulse_params=ry_forward)
    PulseGates.CX(wires=wires, pulse_params=cx_first)
    PulseGates.RY(-w / 2, wires=tgt, pulse_params=ry_backward)
    PulseGates.CX(wires=wires, pulse_params=cx_second)
    PulseGates.RZ(-np.pi / 2, wires=tgt, pulse_params=rz_post)
@staticmethod
def CRY(
    w: float, wires: List[int], pulse_params: Optional[np.ndarray] = None
) -> None:
    """
    Apply a controlled-RY gate built from pulse-level primitives.

    Applies RY(w/2), CX, RY(-w/2), CX to the target (application order),
    following arXiv:2408.01036.

    Args:
        w (float): Rotation angle in radians.
        wires (List[int]): ``[control, target]`` qubit indices.
        pulse_params (Optional[np.ndarray]): Concatenated pulse parameters
            for the four sub-gates, split with
            ``PulseInformation.CRY.split_params``. If None, optimized
            parameters are used.

    Returns:
        None: The gate is applied to the circuit being built.
    """
    ry_forward, cx_first, ry_backward, cx_second = (
        PulseInformation.CRY.split_params(pulse_params)
    )
    tgt = wires[1]

    # Control |0>: the two half rotations cancel. Control |1>: the X gates
    # flip the sign of the middle rotation, so the halves add up to RY(w).
    PulseGates.RY(w / 2, wires=tgt, pulse_params=ry_forward)
    PulseGates.CX(wires=wires, pulse_params=cx_first)
    PulseGates.RY(-w / 2, wires=tgt, pulse_params=ry_backward)
    PulseGates.CX(wires=wires, pulse_params=cx_second)
@staticmethod
def CRZ(
    w: float, wires: List[int], pulse_params: Optional[np.ndarray] = None
) -> None:
    """
    Apply a controlled-RZ gate built from pulse-level primitives.

    Applies RZ(w/2), CX, RZ(-w/2), CX to the target (application order),
    following arXiv:2408.01036.

    Args:
        w (float): Rotation angle in radians.
        wires (List[int]): ``[control, target]`` qubit indices.
        pulse_params (Optional[np.ndarray]): Concatenated pulse parameters
            for the four sub-gates, split with
            ``PulseInformation.CRZ.split_params``. If None, optimized
            parameters are used.

    Returns:
        None: The gate is applied to the circuit being built.
    """
    rz_forward, cx_first, rz_backward, cx_second = (
        PulseInformation.CRZ.split_params(pulse_params)
    )
    tgt = wires[1]

    # Same cancellation trick as CRY: X · RZ(-a) · X = RZ(a), so the two
    # halves cancel for control |0> and combine into RZ(w) for control |1>.
    PulseGates.RZ(w / 2, wires=tgt, pulse_params=rz_forward)
    PulseGates.CX(wires=wires, pulse_params=cx_first)
    PulseGates.RZ(-w / 2, wires=tgt, pulse_params=rz_backward)
    PulseGates.CX(wires=wires, pulse_params=cx_second)
# Metaclass so gates can be accessed directly on the Gates class
# (e.g. `Gates.RX(...)`) without ever instantiating it.
class GatesMeta(type):
    """Metaclass that turns unknown class attributes into gate dispatchers."""

    def __getattr__(cls, gate_name):
        """
        Return a callable forwarding to ``Gates._inner_getattr``.

        Only invoked for names not found through normal attribute lookup.

        Args:
            gate_name (str): Name of the requested gate, e.g. ``"RX"``.

        Returns:
            Callable: ``handler(*args, **kwargs)`` that dispatches the gate;
            its ``__name__`` is set to ``gate_name``.

        Raises:
            AttributeError: For dunder names, so protocol and introspection
                lookups (``__wrapped__``, ``__deepcopy__``, ...) behave
                normally instead of receiving a fake gate handler.
        """
        if gate_name.startswith("__") and gate_name.endswith("__"):
            raise AttributeError(
                f"type object '{cls.__name__}' has no attribute '{gate_name}'"
            )

        def handler(*args, **kwargs):
            return Gates._inner_getattr(gate_name, *args, **kwargs)

        # Expose the gate name on the handler: Gates.is_rotational and
        # Gates.is_entangling classify gates by `__name__`.
        handler.__name__ = gate_name
        return handler
class Gates(metaclass=GatesMeta):
    """
    Dynamic accessor for quantum Gates.

    Routes calls like `Gates.RX(...)` to either `UnitaryGates` or `PulseGates`
    depending on the `gate_mode` keyword (defaults to 'unitary').

    During circuit building, the pulse manager can be activated via
    `pulse_manager_context`, which slices the global model pulse parameters
    and passes them to each gate. Model pulse parameters act as element-wise
    scalers on the gate's optimized pulse parameters.

    Parameters
    ----------
    gate_mode : str, optional
        Determines the backend. 'unitary' for UnitaryGates, 'pulse' for PulseGates.
        Defaults to 'unitary'.

    Notes
    -----
    Only keyword ``pulse_params`` are type/length checked and may be replaced
    by the pulse manager; positional arguments are forwarded unchanged.

    Examples
    --------
    >>> Gates.RX(w, wires)
    >>> Gates.RX(w, wires, gate_mode="unitary")
    >>> Gates.RX(w, wires, gate_mode="pulse")
    >>> Gates.RX(w, wires, pulse_params=pulse_params, gate_mode="pulse")
    """

    # Active PulseParamManager while inside `pulse_manager_context`, else None.
    # Declared on the class so lookups of this name never fall through to
    # GatesMeta.__getattr__, which would otherwise return a gate handler.
    _pulse_mgr: Optional["PulseParamManager"] = None

    def __getattr__(self, gate_name):
        """
        Instance-level counterpart of the metaclass gate lookup.

        Returns a handler forwarding positional and keyword arguments to
        :meth:`_inner_getattr`, with ``__name__`` set to the gate name so
        :meth:`is_rotational` / :meth:`is_entangling` work on it. Dunder
        names raise ``AttributeError`` so protocol lookups (copy, pickle,
        introspection) are not mistaken for gates.
        """
        if gate_name.startswith("__") and gate_name.endswith("__"):
            raise AttributeError(
                f"'{type(self).__name__}' object has no attribute '{gate_name}'"
            )

        def handler(*args, **kwargs):
            return self._inner_getattr(gate_name, *args, **kwargs)

        handler.__name__ = gate_name
        return handler

    @staticmethod
    def _inner_getattr(gate_name, *args, **kwargs):
        """
        Dispatch ``gate_name`` to the backend selected by ``gate_mode``.

        Args:
            gate_name (str): Name of the gate method on the backend class.
            *args: Positional arguments forwarded unchanged to the gate.
            **kwargs: Keyword arguments. ``gate_mode`` ('unitary' or 'pulse',
                default 'unitary') selects the backend and is consumed here.
                Keywords not accepted by the selected backend are dropped
                and logged at DEBUG level.

        Returns:
            Whatever the backend gate returns.

        Raises:
            ValueError: Unknown ``gate_mode`` or wrong number of pulse
                parameters.
            TypeError: ``pulse_params`` of an unsupported type or with
                non-real entries.
            AttributeError: The backend has no gate named ``gate_name``.
        """
        gate_mode = kwargs.pop("gate_mode", "unitary")

        # Backend selection and kwargs filtering
        allowed_args = ["w", "wires", "phi", "theta", "omega"]
        if gate_mode == "unitary":
            gate_backend = UnitaryGates
            allowed_args += ["noise_params", "random_key"]
        elif gate_mode == "pulse":
            gate_backend = PulseGates
            allowed_args += ["pulse_params"]
        else:
            raise ValueError(
                f"Unknown gate mode: {gate_mode}. Use 'unitary' or 'pulse'."
            )

        unsupported = kwargs.keys() - allowed_args
        if unsupported:
            # TODO: pulse params are always provided?
            log.debug(f"Unsupported keyword arguments: {list(unsupported)}")

        kwargs = {k: v for k, v in kwargs.items() if k in allowed_args}
        pulse_params = kwargs.get("pulse_params")
        pulse_mgr = getattr(Gates, "_pulse_mgr", None)

        # TODO: rework this part to convert to valid PulseParams earlier
        # Type check on pulse parameters
        if pulse_params is not None:
            # Host NumPy arrays are not instances of jax.Array, so check both.
            import numpy as onp

            # flatten pulse parameters
            if isinstance(pulse_params, (list, tuple)):
                flat_params = pulse_params
            elif isinstance(pulse_params, jax.core.Tracer):
                # traced values cannot be converted with .tolist()
                flat_params = np.ravel(pulse_params)
            elif isinstance(pulse_params, (np.ndarray, onp.ndarray)):
                flat_params = pulse_params.flatten().tolist()
            elif isinstance(pulse_params, PulseParams):
                # extract the params in case a full object is given
                kwargs["pulse_params"] = pulse_params.params
                flat_params = pulse_params.params.flatten().tolist()
            else:
                raise TypeError(f"Unsupported pulse_params type: {type(pulse_params)}")

            # checks elements in flat parameters are real numbers or jax Tracer
            if not all(
                isinstance(x, (numbers.Real, jax.core.Tracer)) for x in flat_params
            ):
                raise TypeError(
                    "All elements in pulse_params must be int or float, "
                    f"got {pulse_params}, type {type(pulse_params)}. "
                )

        # Len check on pulse parameters (skipped while a manager is active,
        # because the manager supplies the parameters below)
        if pulse_params is not None and not isinstance(pulse_mgr, PulseParamManager):
            n_params = PulseInformation.gate_by_name(gate_name).size
            if len(flat_params) != n_params:
                raise ValueError(
                    f"Gate '{gate_name}' expects {n_params} pulse parameters, "
                    f"got {len(flat_params)}"
                )

        # Pulse slicing + scaling: the next slice of model parameters scales
        # the gate's optimized parameters element-wise. This overwrites any
        # pulse_params passed by keyword.
        if gate_mode == "pulse" and isinstance(pulse_mgr, PulseParamManager):
            n_params = PulseInformation.gate_by_name(gate_name).size
            scalers = pulse_mgr.get(n_params)
            base = PulseInformation.gate_by_name(gate_name).params
            kwargs["pulse_params"] = base * scalers

        # Call the selected gate backend
        gate = getattr(gate_backend, gate_name, None)
        if gate is None:
            # gate_backend is a class, so report its own name (not 'type').
            raise AttributeError(
                f"type object '{gate_backend.__name__}' "
                f"has no attribute '{gate_name}'"
            )

        return gate(*args, **kwargs)

    @staticmethod
    @contextmanager
    def pulse_manager_context(pulse_params: np.ndarray):
        """
        Temporarily install a PulseParamManager for circuit building.

        While active, every pulse-mode gate dispatched through `Gates`
        consumes the next slice of `pulse_params`. On exit the previously
        active manager (None outside any context) is restored, so nested
        contexts do not clobber the outer one.

        Args:
            pulse_params (np.ndarray): Flat model pulse parameters to slice.
        """
        previous = Gates._pulse_mgr
        Gates._pulse_mgr = PulseParamManager(pulse_params)
        try:
            yield
        finally:
            Gates._pulse_mgr = previous

    @staticmethod
    def parse_gates(
        gates: Union[str, Callable, List[Union[str, Callable]], None],
        set_of_gates=None,
    ):
        """
        Normalise a gate specification into a list of callables.

        Args:
            gates: A gate name, a callable, a list mixing both, or None.
                Names are resolved with ``getattr`` on ``set_of_gates``.
                None yields a single no-op callable.
            set_of_gates: Object to resolve gate names on; defaults to
                ``Gates`` when falsy.

        Returns:
            List[Callable]: The parsed gates.

        Raises:
            ValueError: If an entry is neither a string nor a callable.
        """
        set_of_gates = set_of_gates or Gates

        if isinstance(gates, str):
            # if str, resolve the name on the gate set
            parsed_gates = [getattr(set_of_gates, gates)]
        elif isinstance(gates, list):
            parsed_gates = []
            for enc in gates:
                # if list, check if str or callable
                if isinstance(enc, str):
                    parsed_gates.append(getattr(set_of_gates, enc))
                elif callable(enc):
                    parsed_gates.append(enc)
                else:
                    raise ValueError(
                        f"Operation {enc} is not a valid gate or callable. "
                        f"Got {type(enc)}"
                    )
        elif callable(gates):
            # default to callable
            parsed_gates = [gates]
        elif gates is None:
            parsed_gates = [lambda *args, **kwargs: None]
        else:
            raise ValueError(
                f"Operation {gates} is not a valid gate or callable or list of both."
            )
        return parsed_gates

    @staticmethod
    def is_rotational(gate):
        """Return True if ``gate.__name__`` names a parametrised rotation gate."""
        return gate.__name__ in [
            "RX",
            "RY",
            "RZ",
            "Rot",
            "CRX",
            "CRY",
            "CRZ",
        ]

    @staticmethod
    def is_entangling(gate):
        """Return True if ``gate.__name__`` names a two-qubit entangling gate."""
        return gate.__name__ in ["CX", "CY", "CZ", "CRX", "CRY", "CRZ"]
class PulseParamManager:
    """
    Sequential cursor over a flat array of model pulse parameters.

    Each call to :meth:`get` hands out the next consecutive slice and moves
    the cursor (``idx``) past it.
    """

    def __init__(self, pulse_params: np.ndarray):
        self.pulse_params = pulse_params
        self.idx = 0

    def get(self, n: int):
        """
        Return the next ``n`` parameters and advance the cursor.

        Args:
            n (int): Number of parameters to take.

        Returns:
            The squeezed slice ``pulse_params[idx:idx + n]``.

        Raises:
            ValueError: If fewer than ``n`` parameters remain; the cursor is
                left unchanged in that case.
        """
        start = self.idx
        stop = start + n
        if stop > len(self.pulse_params):
            raise ValueError("Not enough pulse parameters left for this gate")
        # TODO: squeezing drops any extra hidden singleton dimension
        chunk = self.pulse_params[start:stop].squeeze()
        self.idx = stop
        return chunk