Coverage for qml_essentials / unitary.py: 95%
125 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-30 11:43 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-30 11:43 +0000
1from typing import Optional, List, Union, Dict, Tuple
2import itertools
3import jax.numpy as jnp
4import jax
6from qml_essentials import operations as op
7import logging
9from qml_essentials.utils import safe_random_split
# Module-level logger, named after this module.
log = logging.getLogger(name=__name__)
14class UnitaryGates:
15 """Collection of unitary quantum gates with optional noise simulation."""
17 batch_gate_error = True
19 @staticmethod
20 def NQubitDepolarizingChannel(p: float, wires: List[int]) -> op.QubitChannel:
21 """
22 Generate Kraus operators for n-qubit depolarizing channel.
24 The n-qubit depolarizing channel models uniform depolarizing noise
25 acting on n qubits simultaneously, useful for simulating realistic
26 multi-qubit noise affecting entangling gates.
28 Args:
29 p (float): Total probability of depolarizing error (0 ≤ p ≤ 1).
30 wires (List[int]): Qubit indices on which the channel acts.
31 Must contain at least 2 qubits.
33 Returns:
34 op.QubitChannel: QubitChannel with Kraus operators
35 representing the depolarizing noise channel.
37 Raises:
38 ValueError: If p is not in [0, 1] or if fewer than 2 qubits provided.
39 """
41 def n_qubit_depolarizing_kraus(p: float, n: int) -> List[jnp.ndarray]:
42 if not (0.0 <= p <= 1.0):
43 raise ValueError(f"Probability p must be between 0 and 1, got {p}")
44 if n < 2:
45 raise ValueError(f"Number of qubits must be >= 2, got {n}")
47 Id = jnp.eye(2)
48 X = op.PauliX._matrix
49 Y = op.PauliY._matrix
50 Z = op.PauliZ._matrix
51 paulis = [Id, X, Y, Z]
53 dim = 2**n
54 all_ops = []
56 # Generate all n-qubit Pauli tensor products:
57 for indices in itertools.product(range(4), repeat=n):
58 P = jnp.eye(1)
59 for idx in indices:
60 P = jnp.kron(P, paulis[idx])
61 all_ops.append(P)
63 # Identity operator corresponds to all zeros indices (Id^n)
64 K0 = jnp.sqrt(1 - p * (4**n - 1) / (4**n)) * jnp.eye(dim)
66 kraus_ops = []
67 for i, P in enumerate(all_ops):
68 if i == 0:
69 # Skip the identity, already handled as K0
70 continue
71 kraus_ops.append(jnp.sqrt(p / (4**n)) * P)
73 return [K0] + kraus_ops
75 return op.QubitChannel(n_qubit_depolarizing_kraus(p, len(wires)), wires=wires)
77 @staticmethod
78 def Noise(
79 wires: Union[int, List[int]], noise_params: Optional[Dict[str, float]] = None
80 ) -> None:
81 """
82 Apply noise channels to specified qubits.
84 Applies various single-qubit and multi-qubit noise channels based on
85 the provided noise parameters dictionary.
87 Args:
88 wires (Union[int, List[int]]): Qubit index or list of qubit indices
89 to apply noise to.
90 noise_params (Optional[Dict[str, float]]): Dictionary of noise
91 parameters. Supported keys:
92 - "BitFlip" (float): Bit flip error probability
93 - "PhaseFlip" (float): Phase flip error probability
94 - "Depolarizing" (float): Single-qubit depolarizing probability
95 - "MultiQubitDepolarizing" (float): Multi-qubit depolarizing
96 probability (applies if len(wires) > 1)
97 All parameters default to 0.0 if not provided.
99 Returns:
100 None: Noise channels are applied in-place to the circuit.
101 """
102 if noise_params is not None:
103 if isinstance(wires, int):
104 wires = [wires] # single qubit gate
106 # noise on single qubits
107 for wire in wires:
108 bf = noise_params.get("BitFlip", 0.0)
109 if bf > 0:
110 op.BitFlip(bf, wires=wire)
112 pf = noise_params.get("PhaseFlip", 0.0)
113 if pf > 0:
114 op.PhaseFlip(pf, wires=wire)
116 dp = noise_params.get("Depolarizing", 0.0)
117 if dp > 0:
118 op.DepolarizingChannel(dp, wires=wire)
120 # noise on two-qubits
121 if len(wires) > 1:
122 p = noise_params.get("MultiQubitDepolarizing", 0.0)
123 if p > 0:
124 UnitaryGates.NQubitDepolarizingChannel(p, wires)
126 @staticmethod
127 def GateError(
128 w: Union[float, jnp.ndarray, List[float]],
129 noise_params: Optional[Dict[str, float]] = None,
130 random_key: Optional[jax.random.PRNGKey] = None,
131 ) -> Tuple[jnp.ndarray, jax.random.PRNGKey]:
132 """
133 Apply gate error noise to rotation angle(s).
135 Adds Gaussian noise to gate rotation angles to simulate imperfect
136 gate implementations.
138 Args:
139 w (Union[float, jnp.ndarray, List[float]]): Rotation angle(s) in radians.
140 noise_params (Optional[Dict[str, float]]): Dictionary with optional
141 "GateError" key specifying standard deviation of Gaussian noise.
142 random_key (Optional[jax.random.PRNGKey]): JAX random key for
143 stochastic noise generation.
145 Returns:
146 Tuple[jnp.ndarray, jax.random.PRNGKey]: Tuple containing:
147 - Modified rotation angle(s) with applied noise
148 - Updated JAX random key
150 Raises:
151 AssertionError: If noise_params contains "GateError" but random_key is None.
152 """
153 if noise_params is not None and noise_params.get("GateError", None) is not None:
154 assert (
155 random_key is not None
156 ), "A random_key must be provided when using GateError"
158 if UnitaryGates.batch_gate_error:
159 random_key, sub_key = safe_random_split(random_key)
160 else:
161 # Use a fixed key so that every batch element (under vmap)
162 # draws the same noise value, effectively broadcasting.
163 sub_key = jax.random.key(0)
165 w += noise_params["GateError"] * jax.random.normal(
166 sub_key,
167 (
168 w.shape
169 if isinstance(w, jnp.ndarray) and UnitaryGates.batch_gate_error
170 else ()
171 ),
172 )
173 return w, random_key
175 @staticmethod
176 def Rot(
177 phi: Union[float, jnp.ndarray, List[float]],
178 theta: Union[float, jnp.ndarray, List[float]],
179 omega: Union[float, jnp.ndarray, List[float]],
180 wires: Union[int, List[int]],
181 noise_params: Optional[Dict[str, float]] = None,
182 random_key: Optional[jax.random.PRNGKey] = None,
183 input_idx: int = -1,
184 ) -> None:
185 """
186 Apply general rotation gate with optional noise.
188 Applies a three-angle rotation Rot(phi, theta, omega) with optional
189 gate errors and noise channels.
191 Args:
192 phi (Union[float, jnp.ndarray, List[float]]): First rotation angle.
193 theta (Union[float, jnp.ndarray, List[float]]): Second rotation angle.
194 omega (Union[float, jnp.ndarray, List[float]]): Third rotation angle.
195 wires (Union[int, List[int]]): Qubit index or indices to apply rotation to.
196 noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
197 Supports BitFlip, PhaseFlip, Depolarizing, and GateError.
198 random_key (Optional[jax.random.PRNGKey]): JAX random key for noise.
199 input_idx (int): Flag for the tape to track inputs
201 Returns:
202 None: Gate and noise are applied in-place to the circuit.
203 """
204 if noise_params is not None and "GateError" in noise_params:
205 phi, random_key = UnitaryGates.GateError(phi, noise_params, random_key)
206 theta, random_key = UnitaryGates.GateError(theta, noise_params, random_key)
207 omega, random_key = UnitaryGates.GateError(omega, noise_params, random_key)
208 op.Rot(phi, theta, omega, wires=wires, input_idx=False)
209 UnitaryGates.Noise(wires, noise_params)
211 @staticmethod
212 def PauliRot(
213 theta: float,
214 pauli: str,
215 wires: Union[int, List[int]],
216 noise_params: Optional[Dict[str, float]] = None,
217 random_key: Optional[jax.random.PRNGKey] = None,
218 input_idx: int = -1,
219 ) -> None:
220 """
221 Apply general rotation gate with optional noise.
223 Applies a three-angle rotation Rot(phi, theta, omega) with optional
224 gate errors and noise channels.
226 Args:
227 theta (Union[float, jnp.ndarray, List[float]]): Second rotation angle.
228 pauli (str): Pauli operator to apply. Must be "X", "Y", or "Z".
229 wires (Union[int, List[int]]): Qubit index or indices to apply rotation to.
230 noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
231 Supports BitFlip, PhaseFlip, Depolarizing, and GateError.
232 random_key (Optional[jax.random.PRNGKey]): JAX random key for noise.
233 input_idx (int): Flag for the tape to track inputs
235 Returns:
236 None: Gate and noise are applied in-place to the circuit.
237 """
238 if noise_params is not None and "GateError" in noise_params:
239 theta, random_key = UnitaryGates.GateError(theta, noise_params, random_key)
240 op.PauliRot(theta, pauli, wires=wires, input_idx=input_idx)
241 UnitaryGates.Noise(wires, noise_params)
243 @staticmethod
244 def RX(
245 w: Union[float, jnp.ndarray, List[float]],
246 wires: Union[int, List[int]],
247 noise_params: Optional[Dict[str, float]] = None,
248 random_key: Optional[jax.random.PRNGKey] = None,
249 input_idx: int = -1,
250 ) -> None:
251 """
252 Apply X-axis rotation with optional noise.
254 Args:
255 w (Union[float, jnp.ndarray, List[float]]): Rotation angle.
256 wires (Union[int, List[int]]): Qubit index or indices.
257 noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
258 random_key (Optional[jax.random.PRNGKey]): JAX random key for noise.
259 input_idx (int): Flag for the tape to track inputs
261 Returns:
262 None: Gate and noise are applied in-place to the circuit.
263 """
264 w, random_key = UnitaryGates.GateError(w, noise_params, random_key)
265 op.RX(w, wires=wires, input_idx=input_idx)
266 UnitaryGates.Noise(wires, noise_params)
268 @staticmethod
269 def RY(
270 w: Union[float, jnp.ndarray, List[float]],
271 wires: Union[int, List[int]],
272 noise_params: Optional[Dict[str, float]] = None,
273 random_key: Optional[jax.random.PRNGKey] = None,
274 input_idx: int = -1,
275 ) -> None:
276 """
277 Apply Y-axis rotation with optional noise.
279 Args:
280 w (Union[float, jnp.ndarray, List[float]]): Rotation angle.
281 wires (Union[int, List[int]]): Qubit index or indices.
282 noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
283 random_key (Optional[jax.random.PRNGKey]): JAX random key for noise.
284 input_idx (int): Flag for the tape to track inputs
286 Returns:
287 None: Gate and noise are applied in-place to the circuit.
288 """
289 w, random_key = UnitaryGates.GateError(w, noise_params, random_key)
290 op.RY(w, wires=wires, input_idx=input_idx)
291 UnitaryGates.Noise(wires, noise_params)
293 @staticmethod
294 def RZ(
295 w: Union[float, jnp.ndarray, List[float]],
296 wires: Union[int, List[int]],
297 noise_params: Optional[Dict[str, float]] = None,
298 random_key: Optional[jax.random.PRNGKey] = None,
299 input_idx: int = -1,
300 ) -> None:
301 """
302 Apply Z-axis rotation with optional noise.
304 Args:
305 w (Union[float, jnp.ndarray, List[float]]): Rotation angle.
306 wires (Union[int, List[int]]): Qubit index or indices.
307 noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
308 random_key (Optional[jax.random.PRNGKey]): JAX random key for noise.
309 input_idx (int): Flag for the tape to track inputs
311 Returns:
312 None: Gate and noise are applied in-place to the circuit.
313 """
314 w, random_key = UnitaryGates.GateError(w, noise_params, random_key)
315 op.RZ(w, wires=wires, input_idx=input_idx)
316 UnitaryGates.Noise(wires, noise_params)
318 @staticmethod
319 def CRX(
320 w: Union[float, jnp.ndarray, List[float]],
321 wires: Union[int, List[int]],
322 noise_params: Optional[Dict[str, float]] = None,
323 random_key: Optional[jax.random.PRNGKey] = None,
324 input_idx: int = -1,
325 ) -> None:
326 """
327 Apply controlled X-rotation with optional noise.
329 Args:
330 w (Union[float, jnp.ndarray, List[float]]): Rotation angle.
331 wires (Union[int, List[int]]): Control and target qubit indices.
332 noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
333 random_key (Optional[jax.random.PRNGKey]): JAX random key for noise.
334 input_idx (int): Flag for the tape to track inputs
336 Returns:
337 None: Gate and noise are applied in-place to the circuit.
338 """
339 w, random_key = UnitaryGates.GateError(w, noise_params, random_key)
340 op.CRX(w, wires=wires, input_idx=input_idx)
341 UnitaryGates.Noise(wires, noise_params)
343 @staticmethod
344 def CRY(
345 w: Union[float, jnp.ndarray, List[float]],
346 wires: Union[int, List[int]],
347 noise_params: Optional[Dict[str, float]] = None,
348 random_key: Optional[jax.random.PRNGKey] = None,
349 input_idx: int = -1,
350 ) -> None:
351 """
352 Apply controlled Y-rotation with optional noise.
354 Args:
355 w (Union[float, jnp.ndarray, List[float]]): Rotation angle.
356 wires (Union[int, List[int]]): Control and target qubit indices.
357 noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
358 random_key (Optional[jax.random.PRNGKey]): JAX random key for noise.
359 input_idx (int): Flag for the tape to track inputs
361 Returns:
362 None: Gate and noise are applied in-place to the circuit.
363 """
364 w, random_key = UnitaryGates.GateError(w, noise_params, random_key)
365 op.CRY(w, wires=wires, input_idx=input_idx)
366 UnitaryGates.Noise(wires, noise_params)
368 @staticmethod
369 def CRZ(
370 w: Union[float, jnp.ndarray, List[float]],
371 wires: Union[int, List[int]],
372 noise_params: Optional[Dict[str, float]] = None,
373 random_key: Optional[jax.random.PRNGKey] = None,
374 input_idx: int = -1,
375 ) -> None:
376 """
377 Apply controlled Z-rotation with optional noise.
379 Args:
380 w (Union[float, jnp.ndarray, List[float]]): Rotation angle.
381 wires (Union[int, List[int]]): Control and target qubit indices.
382 noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
383 random_key (Optional[jax.random.PRNGKey]): JAX random key for noise.
384 input_idx (int): Flag for the tape to track inputs
386 Returns:
387 None: Gate and noise are applied in-place to the circuit.
388 """
389 w, random_key = UnitaryGates.GateError(w, noise_params, random_key)
390 op.CRZ(w, wires=wires, input_idx=input_idx)
391 UnitaryGates.Noise(wires, noise_params)
393 @staticmethod
394 def CX(
395 wires: Union[int, List[int]],
396 noise_params: Optional[Dict[str, float]] = None,
397 random_key: Optional[jax.random.PRNGKey] = None,
398 ) -> None:
399 """
400 Apply controlled-NOT (CNOT) gate with optional noise.
402 Args:
403 wires (Union[int, List[int]]): Control and target qubit indices.
404 noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
405 random_key (Optional[jax.random.PRNGKey]): JAX random key for compatibility
406 (not used in this gate).
408 Returns:
409 None: Gate and noise are applied in-place to the circuit.
410 """
411 op.CX(wires=wires)
412 UnitaryGates.Noise(wires, noise_params)
414 @staticmethod
415 def CY(
416 wires: Union[int, List[int]],
417 noise_params: Optional[Dict[str, float]] = None,
418 random_key: Optional[jax.random.PRNGKey] = None,
419 ) -> None:
420 """
421 Apply controlled-Y gate with optional noise.
423 Args:
424 wires (Union[int, List[int]]): Control and target qubit indices.
425 noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
426 random_key (Optional[jax.random.PRNGKey]): JAX random key for compatibility
427 (not used in this gate).
429 Returns:
430 None: Gate and noise are applied in-place to the circuit.
431 """
432 op.CY(wires=wires)
433 UnitaryGates.Noise(wires, noise_params)
435 @staticmethod
436 def CZ(
437 wires: Union[int, List[int]],
438 noise_params: Optional[Dict[str, float]] = None,
439 random_key: Optional[jax.random.PRNGKey] = None,
440 ) -> None:
441 """
442 Apply controlled-Z gate with optional noise.
444 Args:
445 wires (Union[int, List[int]]): Control and target qubit indices.
446 noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
447 random_key (Optional[jax.random.PRNGKey]): JAX random key for compatibility
448 (not used in this gate).
450 Returns:
451 None: Gate and noise are applied in-place to the circuit.
452 """
453 op.CZ(wires=wires)
454 UnitaryGates.Noise(wires, noise_params)
456 @staticmethod
457 def H(
458 wires: Union[int, List[int]],
459 noise_params: Optional[Dict[str, float]] = None,
460 random_key: Optional[jax.random.PRNGKey] = None,
461 ) -> None:
462 """
463 Apply Hadamard gate with optional noise.
465 Args:
466 wires (Union[int, List[int]]): Qubit index or indices.
467 noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
468 random_key (Optional[jax.random.PRNGKey]): JAX random key for compatibility
469 (not used in this gate).
471 Returns:
472 None: Gate and noise are applied in-place to the circuit.
473 """
474 op.H(wires=wires)
475 UnitaryGates.Noise(wires, noise_params)