Coverage for qml_essentials / gates.py: 91%
98 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-30 11:43 +0000
1from typing import List, Union, Callable
2from contextlib import contextmanager
3import numbers
4import jax.numpy as jnp
5import jax
7# Imports to keep the api `from gates import ...`
8from qml_essentials.unitary import UnitaryGates
9from qml_essentials.pulses import (
10 PulseGates,
11 PulseParams,
12 PulseEnvelope,
13 PulseInformation,
14 PulseParamManager,
15)
16from qml_essentials.operations import Barrier as BarrierOp
18import logging
20log = logging.getLogger(__name__)
# Meta class to avoid instantiating the Gates class
class GatesMeta(type):
    """Metaclass that turns unknown class attributes into gate handlers.

    Accessing ``Gates.<Name>`` for any attribute not defined on the class
    returns a function that forwards its arguments to
    ``cls._inner_getattr("<Name>", *args, **kwargs)``.

    Names starting with an underscore are rejected with ``AttributeError`` so
    that private/dunder probes (``getattr(cls, "_pulse_mgr", None)``,
    ``hasattr(cls, "__wrapped__")``, ...) behave like ordinary attribute
    misses instead of silently producing bogus gate handlers.
    """

    def __getattr__(cls, gate_name):
        if gate_name.startswith("_"):
            raise AttributeError(
                f"type object '{cls.__name__}' has no attribute '{gate_name}'"
            )

        def handler(*args, **kwargs):
            # Resolve through `cls` (not a hard-coded class) so the dispatch
            # also works for subclasses or other classes using this metaclass.
            return cls._inner_getattr(gate_name, *args, **kwargs)

        # Dirty way to preserve information about the gate name: code that
        # inspects `gate.__name__` (e.g. rotational/entangling checks) relies
        # on it to identify the gate behind the returned callable.
        handler.__name__ = gate_name
        handler.__qualname__ = f"{cls.__name__}.{gate_name}"
        return handler
def Barrier(wires: Union[int, List[int]], *args, **kwargs):
    """Build a ``BarrierOp`` acting on ``wires``.

    Any further positional or keyword arguments are accepted so that the
    barrier can be dispatched like any other gate (callers may forward
    options such as ``gate_mode``), but they are discarded: only ``wires``
    is handed to ``BarrierOp``.
    """
    del args, kwargs  # accepted for call-signature compatibility only
    barrier_op = BarrierOp(wires)
    return barrier_op
class Gates(metaclass=GatesMeta):
    """
    Dynamic accessor for quantum Gates.

    Routes calls like `Gates.RX(...)` to either `UnitaryGates` or `PulseGates`
    depending on the `gate_mode` keyword (defaults to 'unitary').

    During circuit building, the pulse manager can be activated via
    `pulse_manager_context`, which slices the global model pulse parameters
    and passes them to each gate. Model pulse parameters act as element-wise
    scalers on the gate's optimized pulse parameters.

    Parameters
    ----------
    gate_mode : str, optional
        Determines the backend. 'unitary' for UnitaryGates, 'pulse' for PulseGates.
        Defaults to 'unitary'.

    Examples
    --------
    >>> Gates.RX(w, wires)
    >>> Gates.RX(w, wires, gate_mode="unitary")
    >>> Gates.RX(w, wires, gate_mode="pulse")
    >>> Gates.RX(w, wires, pulse_params, gate_mode="pulse")
    """

    # Active PulseParamManager while inside `pulse_manager_context`, else None.
    # Declared on the class so lookups never fall through to
    # `GatesMeta.__getattr__`, which would otherwise hand back a gate handler.
    _pulse_mgr = None

    # Keyword arguments forwarded to every backend; 'pulse' mode additionally
    # accepts `pulse_params`.
    _COMMON_ARGS = frozenset(
        {
            "w",
            "wires",
            "phi",
            "theta",
            "omega",
            "input_idx",
            "noise_params",
            "random_key",
        }
    )

    # Gate names recognised by `is_rotational` / `is_entangling`.
    _ROTATIONAL_GATES = frozenset({"RX", "RY", "RZ", "Rot", "CRX", "CRY", "CRZ"})
    _ENTANGLING_GATES = frozenset({"CX", "CY", "CZ", "CRX", "CRY", "CRZ"})

    def __getattr__(self, gate_name):
        """Instance-level counterpart of ``GatesMeta.__getattr__``.

        Accepts positional arguments too, so ``Gates().RX(w, wires)`` behaves
        like ``Gates.RX(w, wires)``. Underscore-prefixed names raise
        ``AttributeError`` so protocol probes (copy, pickle, ...) are not
        mistaken for gates.
        """
        if gate_name.startswith("_"):
            raise AttributeError(
                f"'{type(self).__name__}' object has no attribute '{gate_name}'"
            )

        def handler(*args, **kwargs):
            return self._inner_getattr(gate_name, *args, **kwargs)

        # Keep the gate name discoverable, as the class-level handler does.
        handler.__name__ = gate_name
        return handler

    @staticmethod
    def _select_backend(gate_mode):
        """Return ``(backend_class, allowed_kwargs)`` for ``gate_mode``.

        Raises
        ------
        ValueError
            If ``gate_mode`` is neither 'unitary' nor 'pulse'.
        """
        if gate_mode == "unitary":
            return UnitaryGates, Gates._COMMON_ARGS
        if gate_mode == "pulse":
            return PulseGates, Gates._COMMON_ARGS | {"pulse_params"}
        raise ValueError(
            f"Unknown gate mode: {gate_mode}. Use 'unitary' or 'pulse'."
        )

    @staticmethod
    def _flatten_pulse_params(pulse_params):
        """Return a flat sequence of ``pulse_params`` for validation.

        - list / tuple: returned unchanged.
        - jax Tracer: raveled with ``jnp.ravel`` (no concrete values exist,
          so ``.tolist()`` is impossible while tracing).
        - numpy / jax array: flattened into a Python list.
        - ``PulseParams``: its ``params`` attribute is flattened recursively.

        Raises
        ------
        TypeError
            For any other type.
        """
        import numpy as np  # jax itself depends on numpy

        if isinstance(pulse_params, (list, tuple)):
            return pulse_params
        # Checked before the array branch: tracers cannot be materialized.
        if isinstance(pulse_params, jax.core.Tracer):
            return jnp.ravel(pulse_params)
        if isinstance(pulse_params, (np.ndarray, jnp.ndarray)):
            return pulse_params.flatten().tolist()
        if isinstance(pulse_params, PulseParams):
            return Gates._flatten_pulse_params(pulse_params.params)
        raise TypeError(f"Unsupported pulse_params type: {type(pulse_params)}")

    @staticmethod
    def _inner_getattr(gate_name, *args, **kwargs):
        """Validate arguments and dispatch ``gate_name`` to a gate backend.

        Parameters
        ----------
        gate_name : str
            Name of the gate, e.g. ``"RX"``. ``"Barrier"`` is special-cased
            and receives all arguments unchanged.
        *args
            Positional arguments forwarded to the backend gate.
        **kwargs
            Keyword arguments. ``gate_mode`` selects the backend; keywords
            the selected backend does not support are logged and dropped.

        Returns
        -------
        Whatever the selected backend gate returns.

        Raises
        ------
        ValueError
            Unknown ``gate_mode`` or wrong number of pulse parameters.
        TypeError
            Unsupported ``pulse_params`` container or non-real elements.
        AttributeError
            The selected backend does not define ``gate_name``.
        """
        if gate_name == "Barrier":
            return Barrier(*args, **kwargs)

        gate_mode = kwargs.pop("gate_mode", "unitary")
        gate_backend, allowed_args = Gates._select_backend(gate_mode)

        unsupported = kwargs.keys() - allowed_args
        if unsupported:
            # TODO: pulse params are always provided?
            log.debug("Unsupported keyword arguments: %s", list(unsupported))

        kwargs = {k: v for k, v in kwargs.items() if k in allowed_args}
        # Non-None only in 'pulse' mode, since unitary mode filters it out.
        pulse_params = kwargs.get("pulse_params")
        pulse_mgr = getattr(Gates, "_pulse_mgr", None)
        mgr_active = isinstance(pulse_mgr, PulseParamManager)

        # TODO: rework this part to convert to valid PulseParams earlier
        if pulse_params is not None:
            if isinstance(pulse_params, PulseParams):
                # extract the params in case a full object is given
                kwargs["pulse_params"] = pulse_params.params
            flat_params = Gates._flatten_pulse_params(pulse_params)

            # Elements must be real numbers, or tracers under jax tracing.
            if not all(
                isinstance(x, (numbers.Real, jax.core.Tracer)) for x in flat_params
            ):
                raise TypeError(
                    "All elements in pulse_params must be int or float, "
                    f"got {pulse_params}, type {type(pulse_params)}. "
                )

            # Explicit pulse params must match the gate's parameter count,
            # unless a pulse manager is active (it overrides them below).
            if not mgr_active:
                n_params = PulseInformation.gate_by_name(gate_name).size
                if len(flat_params) != n_params:
                    raise ValueError(
                        f"Gate '{gate_name}' expects {n_params} pulse parameters, "
                        f"got {len(flat_params)}"
                    )

        # Pulse slicing + scaling: the manager's slice of model pulse
        # parameters scales the gate's base pulse parameters element-wise.
        if gate_mode == "pulse" and mgr_active:
            gate_info = PulseInformation.gate_by_name(gate_name)
            scalers = pulse_mgr.get(gate_info.size)
            kwargs["pulse_params"] = gate_info.params * scalers

        # Call the selected gate backend
        gate = getattr(gate_backend, gate_name, None)
        if gate is None:
            # gate_backend is a class, so use its own name (not its metaclass').
            raise AttributeError(
                f"type object '{gate_backend.__name__}' "
                f"has no attribute '{gate_name}'"
            )

        return gate(*args, **kwargs)

    @staticmethod
    @contextmanager
    def pulse_manager_context(pulse_params: jnp.ndarray):
        """Temporarily set the global pulse manager for circuit building.

        The previously active manager (``None`` at top level) is restored on
        exit, so nested contexts do not clear the outer one.
        """
        previous = Gates._pulse_mgr
        Gates._pulse_mgr = PulseParamManager(pulse_params)
        try:
            yield
        finally:
            Gates._pulse_mgr = previous

    @staticmethod
    def parse_gates(
        gates: Union[str, Callable, List[Union[str, Callable]]],
        set_of_gates=None,
    ):
        """Normalize a gate specification into a list of callables.

        Parameters
        ----------
        gates : str, callable, list/tuple of str or callable, or None
            Strings are resolved as attributes of ``set_of_gates``; callables
            are kept as-is; ``None`` yields a single no-op callable.
        set_of_gates : object, optional
            Namespace used to resolve gate names. Defaults to ``Gates``.

        Returns
        -------
        list
            The resolved callables.

        Raises
        ------
        ValueError
            If ``gates`` (or one of its items) is neither a name nor callable.
        """
        set_of_gates = set_of_gates or Gates

        def _resolve(item):
            # Resolve a single list entry: gate name or ready-made callable.
            if isinstance(item, str):
                return getattr(set_of_gates, item)
            if callable(item):
                return item
            raise ValueError(
                f"Operation {item} is not a valid gate or callable. "
                f"Got {type(item)}"
            )

        if isinstance(gates, str):
            return [getattr(set_of_gates, gates)]
        if isinstance(gates, (list, tuple)):
            return [_resolve(item) for item in gates]
        if callable(gates):
            return [gates]
        if gates is None:
            return [lambda *args, **kwargs: None]
        raise ValueError(
            f"Operation {gates} is not a valid gate or callable or list of both."
        )

    @staticmethod
    def is_rotational(gate):
        """Return True if ``gate`` is named like a (controlled) rotation gate.

        Callables without a ``__name__`` are reported as non-rotational.
        """
        return getattr(gate, "__name__", None) in Gates._ROTATIONAL_GATES

    @staticmethod
    def is_entangling(gate):
        """Return True if ``gate`` is named like a two-qubit entangling gate.

        Callables without a ``__name__`` are reported as non-entangling.
        """
        return getattr(gate, "__name__", None) in Gates._ENTANGLING_GATES