Coverage for qml_essentials / gates.py: 86%
98 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-06-11 15:51 +0000
import logging
import numbers
from contextlib import contextmanager
from typing import Callable, List, Union

import jax
import jax.numpy as jnp
import numpy as np

# Imports to keep the api `from gates import ...`
from qml_essentials.unitary import UnitaryGates
from qml_essentials.pulses import (
    PulseGates,
    PulseParams,
    PulseEnvelope,  # noqa: F401
    PulseInformation,
    PulseParamManager,
)
from qml_essentials.operations import Barrier as BarrierOp
# Module-level logger; `Gates._inner_getattr` reports filtered-out keyword
# arguments through it at DEBUG level.
log = logging.getLogger(name=__name__)
# Meta class to avoid instantiating the Gates class
class GatesMeta(type):
    """Metaclass that turns unknown class attributes into gate handlers.

    When normal attribute lookup on a class using this metaclass fails for
    ``name``, ``SomeClass.name`` returns a function that forwards all of its
    positional and keyword arguments to
    ``SomeClass._inner_getattr(name, *args, **kwargs)``.

    Dunder names (``__x__``) are excluded and raise ``AttributeError``.
    """

    def __getattr__(cls, gate_name):
        # Dunder lookups (e.g. ``__wrapped__``, ``__deepcopy__``) are protocol
        # probes made by Python itself and by libraries, not gate requests.
        # Returning a handler for them would make ``hasattr`` report
        # capabilities the class does not actually have.
        if gate_name.startswith("__") and gate_name.endswith("__"):
            raise AttributeError(
                f"type object '{cls.__name__}' has no attribute '{gate_name}'"
            )

        def handler(*args, **kwargs):
            return cls._inner_getattr(gate_name, *args, **kwargs)

        # Dirty way to preserve information about the gate name
        # (e.g. used by name-based checks such as `Gates.is_rotational`).
        handler.__name__ = gate_name
        return handler
def Barrier(wires: Union[int, List[int]], *args, **kwargs):
    """Build a barrier operation on ``wires``.

    Thin wrapper for ``BarrierOp``: only ``wires`` is used. Any additional
    positional or keyword arguments are accepted and silently ignored, since
    generic callers (e.g. ``Gates._inner_getattr``) forward everything they
    receive.
    """
    barrier_op = BarrierOp(wires)
    return barrier_op
class Gates(metaclass=GatesMeta):
    """
    Dynamic accessor for quantum Gates.

    Routes calls like `Gates.RX(...)` to either `UnitaryGates` or `PulseGates`
    depending on the `gate_mode` keyword (defaults to 'unitary').

    During circuit building, the pulse manager can be activated via
    `pulse_manager_context`, which slices the global model pulse parameters
    and passes them to each gate. Model pulse parameters act as element-wise
    scalers on the gate's optimized pulse parameters.

    Parameters
    ----------
    gate_mode : str, optional
        Determines the backend. 'unitary' for UnitaryGates, 'pulse' for PulseGates.
        Defaults to 'unitary'.

    Examples
    --------
    >>> Gates.RX(w, wires)
    >>> Gates.RX(w, wires, gate_mode="unitary")
    >>> Gates.RX(w, wires, gate_mode="pulse")
    >>> Gates.RX(w, wires, pulse_params, gate_mode="pulse")
    """

    # Active pulse parameter manager while inside `pulse_manager_context`,
    # None otherwise. Declared explicitly: without it, reading `_pulse_mgr`
    # falls through to `GatesMeta.__getattr__`, which returns a gate handler
    # instead of None.
    _pulse_mgr = None

    # Keyword arguments forwarded to every backend. "pulse_params" is added
    # only in pulse mode; every other keyword is dropped.
    _COMMON_GATE_KWARGS = (
        "w",
        "wires",
        "phi",
        "theta",
        "omega",
        "noise_params",
        "random_key",
    )

    _ROTATIONAL_GATES = frozenset(
        {
            "RX",
            "RY",
            "RZ",
            "Rot",
            "CRX",
            "CRY",
            "CRZ",
            "GolombEncoding",
            "CPhase",
        }
    )

    _ENTANGLING_GATES = frozenset({"CX", "CY", "CZ", "CRX", "CRY", "CRZ", "CPhase"})

    def __getattr__(self, gate_name):
        # Instance access mirrors class access: forward positional AND keyword
        # arguments, and keep the gate name on the handler.
        def handler(*args, **kwargs):
            return self._inner_getattr(gate_name, *args, **kwargs)

        handler.__name__ = gate_name
        return handler

    @staticmethod
    def _flatten_pulse_params(pulse_params):
        """Validate user-supplied pulse parameters and flatten them.

        Parameters
        ----------
        pulse_params : list, tuple, jax Tracer, numpy/jax array or PulseParams
            Pulse parameters as passed by the caller (not None).

        Returns
        -------
        tuple
            ``(flat_params, forwarded)``. ``flat_params`` is a flat sequence
            used for element-type and length validation. ``forwarded`` is the
            value to hand to the backend as ``pulse_params``: the raw
            ``.params`` for a ``PulseParams`` object, otherwise the input
            unchanged.

        Raises
        ------
        TypeError
            If the container type or any element type is unsupported.
        """
        forwarded = pulse_params
        if isinstance(pulse_params, (list, tuple)):
            flat_params = pulse_params
        elif isinstance(pulse_params, jax.core.Tracer):
            # Keep traced values as arrays: `.tolist()` would require
            # concrete values. Checked before arrays because tracers also
            # pass an array isinstance check.
            flat_params = jnp.ravel(pulse_params)
        elif isinstance(pulse_params, (np.ndarray, jnp.ndarray)):
            flat_params = pulse_params.flatten().tolist()
        elif isinstance(pulse_params, PulseParams):
            # extract the params in case a full object is given
            forwarded = pulse_params.params
            flat_params = pulse_params.params.flatten().tolist()
        else:
            raise TypeError(f"Unsupported pulse_params type: {type(pulse_params)}")

        # checks elements in flat parameters are real numbers or jax Tracer
        if not all(
            isinstance(x, (numbers.Real, jax.core.Tracer)) for x in flat_params
        ):
            raise TypeError(
                "All elements in pulse_params must be int or float, "
                f"got {pulse_params}, type {type(pulse_params)}. "
            )
        return flat_params, forwarded

    @classmethod
    def _inner_getattr(cls, gate_name, *args, **kwargs):
        """Dispatch a gate call to the selected backend.

        Parameters
        ----------
        gate_name : str
            Name of the gate attribute looked up on the backend.
        *args
            Positional arguments forwarded unchanged to the gate.
        **kwargs
            ``gate_mode`` ('unitary' or 'pulse') selects the backend. Other
            keywords are filtered to the backend's supported set before the
            call.

        Returns
        -------
        Whatever the backend gate (or ``Barrier``) returns.

        Raises
        ------
        ValueError
            Unknown ``gate_mode``, or wrong number of pulse parameters.
        TypeError
            Unsupported ``pulse_params`` type or element type.
        AttributeError
            The backend has no gate called ``gate_name``.
        """
        if gate_name == "Barrier":
            return Barrier(*args, **kwargs)

        gate_mode = kwargs.pop("gate_mode", "unitary")

        # Backend selection and kwargs filtering
        allowed_args = set(cls._COMMON_GATE_KWARGS)
        if gate_mode == "unitary":
            gate_backend = UnitaryGates
        elif gate_mode == "pulse":
            gate_backend = PulseGates
            allowed_args.add("pulse_params")
        else:
            raise ValueError(
                f"Unknown gate mode: {gate_mode}. Use 'unitary' or 'pulse'."
            )

        unsupported = kwargs.keys() - allowed_args
        if unsupported:
            # TODO: pulse params are always provided?
            log.debug("Unsupported keyword arguments: %s", list(unsupported))

        kwargs = {k: v for k, v in kwargs.items() if k in allowed_args}
        pulse_params = kwargs.get("pulse_params")
        pulse_mgr = cls._pulse_mgr
        manager_active = isinstance(pulse_mgr, PulseParamManager)

        # TODO: rework this part to convert to valid PulseParams earlier
        # `pulse_params` can only be non-None in pulse mode: in unitary mode it
        # was filtered out of kwargs above.
        if pulse_params is not None:
            flat_params, kwargs["pulse_params"] = cls._flatten_pulse_params(
                pulse_params
            )

            # Len check only for directly supplied parameters; with an active
            # manager they are replaced by the scaled gate parameters below.
            if not manager_active:
                n_params = PulseInformation.gate_by_name(gate_name).size
                if len(flat_params) != n_params:
                    raise ValueError(
                        f"Gate '{gate_name}' expects {n_params} pulse parameters, "
                        f"got {len(flat_params)}"
                    )

        # Pulse slicing + scaling
        if gate_mode == "pulse" and manager_active:
            gate_info = PulseInformation.gate_by_name(gate_name)
            scalers = pulse_mgr.get(gate_info.size)
            kwargs["pulse_params"] = gate_info.params * scalers

        # Call the selected gate backend
        gate = getattr(gate_backend, gate_name, None)
        if gate is None:
            # gate_backend is a class, so report its own name, not its type's.
            raise AttributeError(
                f"type object '{gate_backend.__name__}' "
                f"has no attribute '{gate_name}'"
            )

        return gate(*args, **kwargs)

    @classmethod
    @contextmanager
    def pulse_manager_context(cls, pulse_params: jnp.ndarray):
        """Temporarily set the global pulse manager for circuit building.

        The manager is restored to its previous value on exit (including on
        exceptions), so nested contexts do not clear an outer manager.

        Parameters
        ----------
        pulse_params : jnp.ndarray
            Model pulse parameters wrapped in a ``PulseParamManager``.
        """
        previous = cls._pulse_mgr
        cls._pulse_mgr = PulseParamManager(pulse_params)
        try:
            yield
        finally:
            cls._pulse_mgr = previous

    @classmethod
    def parse_gates(
        cls,
        gates: Union[str, Callable, List[Union[str, Callable]]],
        set_of_gates=None,
    ):
        """Normalize a gate specification into a list of callables.

        Parameters
        ----------
        gates : str, callable, list/tuple of (str or callable), or None
            Gate names are resolved with ``getattr`` on ``set_of_gates``.
            Callables are used as-is. ``None`` yields a single no-op callable.
        set_of_gates : optional
            Object on which gate names are looked up. Defaults to this class.

        Returns
        -------
        list
            The resolved callables.

        Raises
        ------
        ValueError
            If ``gates`` (or one of its elements) is neither a string nor a
            callable.
        """
        set_of_gates = set_of_gates or cls

        if isinstance(gates, str):
            # if str, use the pennylane fct
            return [getattr(set_of_gates, gates)]

        if isinstance(gates, (list, tuple)):
            parsed_gates = []
            for enc in gates:
                # if list, check if str or callable
                if isinstance(enc, str):
                    parsed_gates.append(getattr(set_of_gates, enc))
                elif callable(enc):
                    parsed_gates.append(enc)
                else:
                    raise ValueError(
                        f"Operation {enc} is not a valid gate or callable. "
                        f"Got {type(enc)}"
                    )
            return parsed_gates

        if callable(gates):
            # default to callable
            return [gates]

        if gates is None:
            return [lambda *args, **kwargs: None]

        raise ValueError(
            f"Operation {gates} is not a valid gate or callable or list of both."
        )

    @classmethod
    def is_rotational(cls, gate):
        """Return True if ``gate.__name__`` names a parametrized rotation gate."""
        return gate.__name__ in cls._ROTATIONAL_GATES

    @classmethod
    def is_entangling(cls, gate):
        """Return True if ``gate.__name__`` names a multi-qubit entangling gate."""
        return gate.__name__ in cls._ENTANGLING_GATES