Coverage for qml_essentials / unitary.py: 95%
166 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-05-16 10:19 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-05-16 10:19 +0000
1from typing import Optional, List, Union, Dict, Tuple
2import itertools
3import jax.numpy as jnp
4import jax
6from qml_essentials import operations as op
7import logging
9from qml_essentials.utils import safe_random_split
11log = logging.getLogger(__name__)
# Cache for computed rulers, keyed by the ruler order ``d``.
# It is filled lazily by ``golomb_ruler`` the first time an order is
# requested; nothing in this module evicts entries.
_GOLOMB_RULER_CACHE: Dict[int, Tuple[int, ...]] = {}
def _greedy_golomb(d: int) -> Tuple[int, ...]:
    """Build a Golomb ruler with *d* marks by greedy search.

    The ruler starts with the single mark 0.  Integers 1, 2, 3, ... are then
    tried in increasing order; an integer becomes the next mark as soon as
    none of its distances to the marks already chosen has been used by an
    earlier pair.  The search always terminates with a valid ruler, but the
    largest mark is not guaranteed to be minimal.

    Args:
        d: Order of the ruler (number of marks).

    Returns:
        Tuple of *d* non-negative integers forming a valid Golomb ruler.
        An empty tuple is returned when ``d <= 0``.
    """
    if d <= 0:
        return ()

    ruler = [0]
    used_distances: set = set()
    trial = 0
    while len(ruler) < d:
        trial += 1
        # The existing marks are pairwise distinct, so the distances from
        # ``trial`` to them are distinct as well; only a clash with
        # distances used by earlier pairs can reject the trial value.
        fresh = {trial - mark for mark in ruler}
        if used_distances.isdisjoint(fresh):
            ruler.append(trial)
            used_distances.update(fresh)
    return tuple(ruler)
def golomb_ruler(d: int) -> Tuple[int, ...]:
    """Look up (or build and memoise) a Golomb ruler of order *d*.

    A Golomb ruler is a set of *d* non-negative integers whose pairwise
    differences are all distinct.  When used as the diagonal of a
    data-encoding Hamiltonian ``H = diag(marks)``, the resulting Fourier
    spectrum ``\\Omega`` has ``|\\Omega| = d(d-1) + 1`` distinct frequencies
    with ``|R(k)| = 1`` for all ``k ≠ 0`` — the minimal possible degeneracy
    for any *d*-dimensional Hamiltonian.

    The ruler comes from the greedy construction in ``_greedy_golomb`` and
    is stored in ``_GOLOMB_RULER_CACHE`` so that each order is computed at
    most once.

    Args:
        d: Order of the ruler (number of marks, equal to the Hilbert
            space dimension ``2^n_qubits``).

    Returns:
        Tuple of *d* non-negative integers forming a Golomb ruler.

    Raises:
        ValueError: If ``d <= 0``.

    References:
        Peters et al., "Generalization despite overfitting in quantum
        machine learning models", arXiv:2209.05523, Appendix C.4.
    """
    if d <= 0:
        raise ValueError(f"Golomb ruler order must be positive, got {d}")
    try:
        return _GOLOMB_RULER_CACHE[d]
    except KeyError:
        # First request for this order: compute once and remember it.
        ruler = _GOLOMB_RULER_CACHE[d] = _greedy_golomb(d)
        return ruler
class UnitaryGates:
    """Collection of unitary quantum gates with optional noise simulation.

    Every gate is a static method following the same recipe: parametrised
    gates first pass their angle(s) through :meth:`GateError`, then the
    matching operation from the ``op`` module is called, and finally
    :meth:`Noise` adds the noise channels requested in ``noise_params``.
    """

    # Controls how GateError samples its Gaussian perturbation:
    #   True  -> the caller's random key is split and, for jnp array angles,
    #            one sample per array element is drawn;
    #   False -> a fixed key is used and a single scalar sample is drawn, so
    #            every batch element receives the same perturbation.
    batch_gate_error = True

    @staticmethod
    def NQubitDepolarizingChannel(p: float, wires: List[int]) -> op.QubitChannel:
        """
        Generate Kraus operators for n-qubit depolarizing channel.

        The n-qubit depolarizing channel models uniform depolarizing noise
        acting on n qubits simultaneously, useful for simulating realistic
        multi-qubit noise affecting entangling gates.

        The Kraus set consists of ``sqrt(1 - p (4^n - 1) / 4^n) * I`` followed
        by ``sqrt(p / 4^n) * P`` for each of the ``4^n - 1`` non-identity
        n-qubit Pauli strings ``P`` (in ``itertools.product`` order over
        ``I, X, Y, Z``).

        Args:
            p (float): Total probability of depolarizing error (0 ≤ p ≤ 1).
            wires (List[int]): Qubit indices on which the channel acts.
                Must contain at least 2 qubits.

        Returns:
            op.QubitChannel: QubitChannel with Kraus operators
                representing the depolarizing noise channel.

        Raises:
            ValueError: If p is not in [0, 1] or if fewer than 2 qubits provided.
        """

        def n_qubit_depolarizing_kraus(p: float, n: int) -> List[jnp.ndarray]:
            if not (0.0 <= p <= 1.0):
                raise ValueError(f"Probability p must be between 0 and 1, got {p}")
            if n < 2:
                raise ValueError(f"Number of qubits must be >= 2, got {n}")

            paulis = [
                jnp.eye(2),
                op.PauliX._matrix,
                op.PauliY._matrix,
                op.PauliZ._matrix,
            ]
            n_paulis = 4**n

            # The identity string gets its own weight so that the squared
            # weights of all 4^n operators sum to one.
            K0 = jnp.sqrt(1 - p * (n_paulis - 1) / n_paulis) * jnp.eye(2**n)
            error_scale = jnp.sqrt(p / n_paulis)

            index_tuples = itertools.product(range(4), repeat=n)
            # The first tuple is (0, ..., 0), i.e. Id^n, which is already
            # covered by K0 -- skip it instead of building and discarding it.
            next(index_tuples)

            kraus_ops = [K0]
            for indices in index_tuples:
                P = jnp.eye(1)
                for idx in indices:
                    P = jnp.kron(P, paulis[idx])
                kraus_ops.append(error_scale * P)
            return kraus_ops

        return op.QubitChannel(n_qubit_depolarizing_kraus(p, len(wires)), wires=wires)

    @staticmethod
    def Noise(
        wires: Union[int, List[int]], noise_params: Optional[Dict[str, float]] = None
    ) -> None:
        """
        Apply noise channels to specified qubits.

        Applies various single-qubit and multi-qubit noise channels based on
        the provided noise parameters dictionary.  Nothing happens when
        ``noise_params`` is ``None``.

        Args:
            wires (Union[int, List[int]]): Qubit index or list of qubit indices
                to apply noise to.
            noise_params (Optional[Dict[str, float]]): Dictionary of noise
                parameters. Supported keys:
                - "BitFlip" (float): Bit flip error probability
                - "PhaseFlip" (float): Phase flip error probability
                - "Depolarizing" (float): Single-qubit depolarizing probability
                - "MultiQubitDepolarizing" (float): Multi-qubit depolarizing
                  probability (applies if len(wires) > 1)
                All parameters default to 0.0 if not provided.

        Returns:
            None: Noise channels are applied in-place to the circuit.
        """
        if noise_params is None:
            return

        if isinstance(wires, int):
            wires = [wires]  # single qubit gate

        # The probabilities do not depend on the wire; read them once.
        bf = noise_params.get("BitFlip", 0.0)
        pf = noise_params.get("PhaseFlip", 0.0)
        dp = noise_params.get("Depolarizing", 0.0)

        # noise on single qubits (per wire: bit flip, phase flip, depolarizing)
        for wire in wires:
            if bf > 0:
                op.BitFlip(bf, wires=wire)
            if pf > 0:
                op.PhaseFlip(pf, wires=wire)
            if dp > 0:
                op.DepolarizingChannel(dp, wires=wire)

        # joint noise on multi-qubit gates
        if len(wires) > 1:
            p = noise_params.get("MultiQubitDepolarizing", 0.0)
            if p > 0:
                UnitaryGates.NQubitDepolarizingChannel(p, wires)

    @staticmethod
    def GateError(
        w: Union[float, jnp.ndarray, List[float]],
        noise_params: Optional[Dict[str, float]] = None,
        random_key: Optional[jax.random.PRNGKey] = None,
    ) -> Tuple[jnp.ndarray, jax.random.PRNGKey]:
        """
        Apply gate error noise to rotation angle(s).

        Adds Gaussian noise to gate rotation angles to simulate imperfect
        gate implementations.  The input ``w`` is never modified in place; a
        new value is returned.  When no "GateError" entry is configured,
        ``w`` and ``random_key`` are returned unchanged.

        Args:
            w (Union[float, jnp.ndarray, List[float]]): Rotation angle(s) in radians.
            noise_params (Optional[Dict[str, float]]): Dictionary with optional
                "GateError" key specifying standard deviation of Gaussian noise.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for
                stochastic noise generation.

        Returns:
            Tuple[jnp.ndarray, jax.random.PRNGKey]: Tuple containing:
                - Modified rotation angle(s) with applied noise
                - Updated JAX random key (only advanced when
                  ``batch_gate_error`` is True)

        Raises:
            AssertionError: If noise_params contains "GateError" but random_key is None.
        """
        if noise_params is None or noise_params.get("GateError", None) is None:
            return w, random_key

        # Raised explicitly (rather than via ``assert``) so the check is not
        # stripped when Python runs with -O; the exception type is unchanged.
        if random_key is None:
            raise AssertionError("A random_key must be provided when using GateError")

        per_element = UnitaryGates.batch_gate_error
        if per_element:
            random_key, sub_key = safe_random_split(random_key)
        else:
            # Use a fixed key so that every batch element (under vmap)
            # draws the same noise value, effectively broadcasting.
            sub_key = jax.random.key(0)

        # One sample per element only for jnp arrays in batch mode; any other
        # input gets a single scalar sample that broadcasts over ``w``.
        sample_shape = w.shape if per_element and isinstance(w, jnp.ndarray) else ()
        noise = noise_params["GateError"] * jax.random.normal(sub_key, sample_shape)

        # Out-of-place addition: ``w += noise`` would mutate a caller's NumPy
        # array and misbehave for Python lists (list.__iadd__).
        return jnp.asarray(w) + noise, random_key

    @staticmethod
    def Rot(
        phi: Union[float, jnp.ndarray, List[float]],
        theta: Union[float, jnp.ndarray, List[float]],
        omega: Union[float, jnp.ndarray, List[float]],
        wires: Union[int, List[int]],
        noise_params: Optional[Dict[str, float]] = None,
        random_key: Optional[jax.random.PRNGKey] = None,
        input_idx: int = -1,
    ) -> None:
        """
        Apply general rotation gate with optional noise.

        Applies a three-angle rotation Rot(phi, theta, omega) with optional
        gate errors and noise channels.  Each angle receives its own gate
        error sample; the random key is threaded through the three calls.

        Args:
            phi (Union[float, jnp.ndarray, List[float]]): First rotation angle.
            theta (Union[float, jnp.ndarray, List[float]]): Second rotation angle.
            omega (Union[float, jnp.ndarray, List[float]]): Third rotation angle.
            wires (Union[int, List[int]]): Qubit index or indices to apply rotation to.
            noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
                Supports BitFlip, PhaseFlip, Depolarizing, and GateError.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for noise.
            input_idx (int): Flag for the tape to track inputs; forwarded to
                ``op.Rot`` like in every other gate of this class.

        Returns:
            None: Gate and noise are applied in-place to the circuit.
        """
        if noise_params is not None and "GateError" in noise_params:
            phi, random_key = UnitaryGates.GateError(phi, noise_params, random_key)
            theta, random_key = UnitaryGates.GateError(theta, noise_params, random_key)
            omega, random_key = UnitaryGates.GateError(omega, noise_params, random_key)
        # Forward the caller's input_idx instead of a hard-coded ``False``
        # (which equals 0 and differs from the documented default of -1).
        op.Rot(phi, theta, omega, wires=wires, input_idx=input_idx)
        UnitaryGates.Noise(wires, noise_params)

    @staticmethod
    def PauliRot(
        theta: float,
        pauli: str,
        wires: Union[int, List[int]],
        noise_params: Optional[Dict[str, float]] = None,
        random_key: Optional[jax.random.PRNGKey] = None,
        input_idx: int = -1,
    ) -> None:
        """
        Apply a Pauli rotation gate with optional noise.

        Calls ``op.PauliRot(theta, pauli, ...)`` on the given wires, with
        optional gate error on ``theta`` and subsequent noise channels.

        Args:
            theta (float): Rotation angle.
            pauli (str): Pauli specification forwarded unchanged to
                ``op.PauliRot``.
            wires (Union[int, List[int]]): Qubit index or indices to apply rotation to.
            noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
                Supports BitFlip, PhaseFlip, Depolarizing, and GateError.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for noise.
            input_idx (int): Flag for the tape to track inputs

        Returns:
            None: Gate and noise are applied in-place to the circuit.
        """
        if noise_params is not None and "GateError" in noise_params:
            theta, random_key = UnitaryGates.GateError(theta, noise_params, random_key)
        op.PauliRot(theta, pauli, wires=wires, input_idx=input_idx)
        UnitaryGates.Noise(wires, noise_params)

    @staticmethod
    def RX(
        w: Union[float, jnp.ndarray, List[float]],
        wires: Union[int, List[int]],
        noise_params: Optional[Dict[str, float]] = None,
        random_key: Optional[jax.random.PRNGKey] = None,
        input_idx: int = -1,
    ) -> None:
        """
        Apply X-axis rotation with optional noise.

        Args:
            w (Union[float, jnp.ndarray, List[float]]): Rotation angle.
            wires (Union[int, List[int]]): Qubit index or indices.
            noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for noise.
            input_idx (int): Flag for the tape to track inputs

        Returns:
            None: Gate and noise are applied in-place to the circuit.
        """
        w, random_key = UnitaryGates.GateError(w, noise_params, random_key)
        op.RX(w, wires=wires, input_idx=input_idx)
        UnitaryGates.Noise(wires, noise_params)

    @staticmethod
    def RY(
        w: Union[float, jnp.ndarray, List[float]],
        wires: Union[int, List[int]],
        noise_params: Optional[Dict[str, float]] = None,
        random_key: Optional[jax.random.PRNGKey] = None,
        input_idx: int = -1,
    ) -> None:
        """
        Apply Y-axis rotation with optional noise.

        Args:
            w (Union[float, jnp.ndarray, List[float]]): Rotation angle.
            wires (Union[int, List[int]]): Qubit index or indices.
            noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for noise.
            input_idx (int): Flag for the tape to track inputs

        Returns:
            None: Gate and noise are applied in-place to the circuit.
        """
        w, random_key = UnitaryGates.GateError(w, noise_params, random_key)
        op.RY(w, wires=wires, input_idx=input_idx)
        UnitaryGates.Noise(wires, noise_params)

    @staticmethod
    def RZ(
        w: Union[float, jnp.ndarray, List[float]],
        wires: Union[int, List[int]],
        noise_params: Optional[Dict[str, float]] = None,
        random_key: Optional[jax.random.PRNGKey] = None,
        input_idx: int = -1,
    ) -> None:
        """
        Apply Z-axis rotation with optional noise.

        Args:
            w (Union[float, jnp.ndarray, List[float]]): Rotation angle.
            wires (Union[int, List[int]]): Qubit index or indices.
            noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for noise.
            input_idx (int): Flag for the tape to track inputs

        Returns:
            None: Gate and noise are applied in-place to the circuit.
        """
        w, random_key = UnitaryGates.GateError(w, noise_params, random_key)
        op.RZ(w, wires=wires, input_idx=input_idx)
        UnitaryGates.Noise(wires, noise_params)

    @staticmethod
    def CRX(
        w: Union[float, jnp.ndarray, List[float]],
        wires: Union[int, List[int]],
        noise_params: Optional[Dict[str, float]] = None,
        random_key: Optional[jax.random.PRNGKey] = None,
        input_idx: int = -1,
    ) -> None:
        """
        Apply controlled X-rotation with optional noise.

        Args:
            w (Union[float, jnp.ndarray, List[float]]): Rotation angle.
            wires (Union[int, List[int]]): Control and target qubit indices.
            noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for noise.
            input_idx (int): Flag for the tape to track inputs

        Returns:
            None: Gate and noise are applied in-place to the circuit.
        """
        w, random_key = UnitaryGates.GateError(w, noise_params, random_key)
        op.CRX(w, wires=wires, input_idx=input_idx)
        UnitaryGates.Noise(wires, noise_params)

    @staticmethod
    def CRY(
        w: Union[float, jnp.ndarray, List[float]],
        wires: Union[int, List[int]],
        noise_params: Optional[Dict[str, float]] = None,
        random_key: Optional[jax.random.PRNGKey] = None,
        input_idx: int = -1,
    ) -> None:
        """
        Apply controlled Y-rotation with optional noise.

        Args:
            w (Union[float, jnp.ndarray, List[float]]): Rotation angle.
            wires (Union[int, List[int]]): Control and target qubit indices.
            noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for noise.
            input_idx (int): Flag for the tape to track inputs

        Returns:
            None: Gate and noise are applied in-place to the circuit.
        """
        w, random_key = UnitaryGates.GateError(w, noise_params, random_key)
        op.CRY(w, wires=wires, input_idx=input_idx)
        UnitaryGates.Noise(wires, noise_params)

    @staticmethod
    def CRZ(
        w: Union[float, jnp.ndarray, List[float]],
        wires: Union[int, List[int]],
        noise_params: Optional[Dict[str, float]] = None,
        random_key: Optional[jax.random.PRNGKey] = None,
        input_idx: int = -1,
    ) -> None:
        """
        Apply controlled Z-rotation with optional noise.

        Args:
            w (Union[float, jnp.ndarray, List[float]]): Rotation angle.
            wires (Union[int, List[int]]): Control and target qubit indices.
            noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for noise.
            input_idx (int): Flag for the tape to track inputs

        Returns:
            None: Gate and noise are applied in-place to the circuit.
        """
        w, random_key = UnitaryGates.GateError(w, noise_params, random_key)
        op.CRZ(w, wires=wires, input_idx=input_idx)
        UnitaryGates.Noise(wires, noise_params)

    @staticmethod
    def CPhase(
        w: Union[float, jnp.ndarray, List[float]],
        wires: Union[int, List[int]],
        noise_params: Optional[Dict[str, float]] = None,
        random_key: Optional[jax.random.PRNGKey] = None,
        input_idx: int = -1,
    ) -> None:
        """
        Apply controlled phase shift gate with optional noise.

        This is a generalization of the CZ gate, applying a phase shift of
        exp(i*w) to the |11⟩ state. When w=π, this reduces to CZ.

        Args:
            w (Union[float, jnp.ndarray, List[float]]): Phase shift angle.
            wires (Union[int, List[int]]): Control and target qubit indices.
            noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for noise.
            input_idx (int): Flag for the tape to track inputs

        Returns:
            None: Gate and noise are applied in-place to the circuit.
        """
        w, random_key = UnitaryGates.GateError(w, noise_params, random_key)
        op.ControlledPhaseShift(w, wires=wires, input_idx=input_idx)
        UnitaryGates.Noise(wires, noise_params)

    @staticmethod
    def CX(
        wires: Union[int, List[int]],
        noise_params: Optional[Dict[str, float]] = None,
        random_key: Optional[jax.random.PRNGKey] = None,
    ) -> None:
        """
        Apply controlled-NOT (CNOT) gate with optional noise.

        Args:
            wires (Union[int, List[int]]): Control and target qubit indices.
            noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for compatibility
                (not used in this gate).

        Returns:
            None: Gate and noise are applied in-place to the circuit.
        """
        op.CX(wires=wires)
        UnitaryGates.Noise(wires, noise_params)

    @staticmethod
    def CY(
        wires: Union[int, List[int]],
        noise_params: Optional[Dict[str, float]] = None,
        random_key: Optional[jax.random.PRNGKey] = None,
    ) -> None:
        """
        Apply controlled-Y gate with optional noise.

        Args:
            wires (Union[int, List[int]]): Control and target qubit indices.
            noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for compatibility
                (not used in this gate).

        Returns:
            None: Gate and noise are applied in-place to the circuit.
        """
        op.CY(wires=wires)
        UnitaryGates.Noise(wires, noise_params)

    @staticmethod
    def CZ(
        wires: Union[int, List[int]],
        noise_params: Optional[Dict[str, float]] = None,
        random_key: Optional[jax.random.PRNGKey] = None,
    ) -> None:
        """
        Apply controlled-Z gate with optional noise.

        Args:
            wires (Union[int, List[int]]): Control and target qubit indices.
            noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for compatibility
                (not used in this gate).

        Returns:
            None: Gate and noise are applied in-place to the circuit.
        """
        op.CZ(wires=wires)
        UnitaryGates.Noise(wires, noise_params)

    @staticmethod
    def H(
        wires: Union[int, List[int]],
        noise_params: Optional[Dict[str, float]] = None,
        random_key: Optional[jax.random.PRNGKey] = None,
    ) -> None:
        """
        Apply Hadamard gate with optional noise.

        Args:
            wires (Union[int, List[int]]): Qubit index or indices.
            noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for compatibility
                (not used in this gate).

        Returns:
            None: Gate and noise are applied in-place to the circuit.
        """
        op.H(wires=wires)
        UnitaryGates.Noise(wires, noise_params)

    @staticmethod
    def GolombEncoding(
        w: Union[float, jnp.ndarray],
        wires: Union[int, List[int]],
        noise_params: Optional[Dict[str, float]] = None,
        random_key: Optional[jax.random.PRNGKey] = None,
        input_idx: int = -1,
    ) -> None:
        """Apply Golomb encoding as a diagonal unitary on all given wires.

        Implements ``S(x) = exp(-i H x)`` where
        ``H = diag(g_0, g_1, ..., g_{d-1})`` and the ``g_j`` are the marks
        of a Golomb ruler of order ``d = 2^len(wires)``. This produces a
        maximally non-degenerate Fourier spectrum with
        ``|\\Omega| = d(d-1) + 1`` distinct frequencies, each with degeneracy
        ``|R(k)| = 1``.

        See Peters et al., arXiv:2209.05523, Sec. 3.1 and Appendix C.4.

        Args:
            w: Scalar input value (the data point *x* to encode).
            wires: Qubit indices this encoding acts on. All qubits are
                acted upon simultaneously via a single multi-qubit diagonal
                gate. A single (non-list, non-tuple) wire is wrapped in a
                list.
            noise_params: Optional noise parameters dictionary.
            random_key: JAX random key for stochastic noise.
            input_idx: Flag for the tape to track inputs.

        Returns:
            None: Gate and noise are applied in-place to the circuit.
        """
        wires_list = list(wires) if isinstance(wires, (list, tuple)) else [wires]
        d = 2 ** len(wires_list)
        marks = jnp.array(golomb_ruler(d), dtype=float)

        # Apply gate error to the input angle
        w, random_key = UnitaryGates.GateError(w, noise_params, random_key)

        # Build diagonal: exp(-i * mark_j * x)
        diag = jnp.exp(-1j * marks * w)

        op.DiagonalQubitUnitary(diag, wires=wires_list, input_idx=input_idx)
        UnitaryGates.Noise(wires_list, noise_params)