Coverage for qml_essentials / operations.py: 89%
499 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-30 11:43 +0000
1from typing import Callable, List, Optional, Tuple, Union
2from functools import lru_cache
3import string
4import numpy as np
6import jax
7import jax.numpy as jnp
9from qml_essentials.tape import active_tape, recording # noqa: F401 (re-export)
12def _cdtype():
13 """Return the active JAX complex dtype
14 (complex128 if x64 enabled, else complex64).
15 """
16 return jnp.complex128 if jax.config.x64_enabled else jnp.complex64
19@lru_cache(maxsize=256)
20def _einsum_subscript(
21 n: int,
22 k: int,
23 target_axes: Tuple[int, ...],
24) -> str:
25 """Build an ``einsum`` subscript that fuses contraction + axis restore.
27 Args:
28 n: Total rank of the state tensor (number of qubits for statevectors,
29 ``2 * n_qubits`` for density matrices).
30 k: Number of qubits the gate acts on.
31 target_axes: Tuple of k axis indices in the state tensor that the
32 gate contracts against.
34 Returns:
35 ``einsum`` subscript string, e.g. ``"ab,cBd->cad"`` for a 1-qubit
36 gate on wire 1 of a 3-qubit state.
37 """
38 letters = string.ascii_letters
39 # State indices: one letter per axis
40 state_idx = list(letters[:n])
41 # Contracted indices (the ones being replaced by the gate)
42 contracted = [state_idx[ax] for ax in target_axes]
43 # Gate indices: new output indices + contracted input indices
44 new_out = [letters[n + i] for i in range(k)] # fresh letters for output
45 gate_idx = new_out + contracted # gate shape: (out0, out1, ..., in0, in1, ...)
46 # Result indices: replace target axes with new output letters
47 result_idx = list(state_idx)
48 for i, ax in enumerate(target_axes):
49 result_idx[ax] = new_out[i]
50 return "".join(gate_idx) + "," + "".join(state_idx) + "->" + "".join(result_idx)
def _contract_and_restore(
    tensor: jnp.ndarray,
    gate: jnp.ndarray,
    k: int,
    target_axes: List[int],
) -> jnp.ndarray:
    """Apply a reshaped gate to selected axes of a tensor in a single einsum.

    The subscript comes from the memoised :func:`_einsum_subscript`, so the
    string for a given ``(rank, k, target_axes)`` triple is built only once.

    Args:
        tensor: Rank-N tensor (e.g. ``(2,)*n`` for states or ``(2,)*2n``
            for density matrices).
        gate: Reshaped gate tensor of shape ``(2,)*2k``.
        k: Number of qubits the gate acts on (= ``len(target_axes)``).
        target_axes: The k axes of tensor to contract against.

    Returns:
        A tensor of the same rank as ``tensor``, with the gate's output legs
        placed at the positions of the contracted axes.
    """
    axes_key = tuple(target_axes)  # lru_cache needs a hashable key
    spec = _einsum_subscript(tensor.ndim, k, axes_key)
    return jnp.einsum(spec, gate, tensor)
class Operation:
    """Base class for any quantum operation or observable.

    Further gates should inherit from this class to realise more specific
    operations. Generally, operations are created by instantiation inside a
    circuit function passed to :class:`Script`; the instance is
    automatically appended to the active tape.

    An ``Operation`` can also serve as an *observable*: its matrix is used to
    compute expectation values via ``apply_to_state`` / ``apply_to_density``.

    Attributes:
        _matrix: Class-level default gate matrix. Subclasses set this to their
            fixed unitary. Instances may override it via the *matrix* argument
            to ``__init__``.
        _num_wires: Expected number of wires for this gate. Subclasses set
            this to enforce wire count validation. ``None`` means any number
            of wires is accepted.
        _param_names: Tuple of attribute names for the gate parameters.
            Used by :attr:`parameters` and :meth:`__repr__`.
        is_controlled: Class-level flag, ``False`` here; controlled-gate
            subclasses set it to ``True``.
    """

    # Subclasses should set this to the gate's unitary / matrix
    _matrix: jnp.ndarray = None
    is_controlled = False
    _num_wires: Optional[int] = None
    _param_names: Tuple[str, ...] = ()

    def __init__(
        self,
        wires: Union[int, List[int]] = 0,
        matrix: Optional[jnp.ndarray] = None,
        record: bool = True,
        input_idx: int = -1,
        name: Optional[str] = None,
    ) -> None:
        """Initialise the operation and optionally register it on the active tape.

        Args:
            wires: Qubit index or list of qubit indices this operation acts on.
            matrix: Optional explicit gate matrix. When provided it overrides
                the class-level ``_matrix`` attribute.
            record: If ``True`` (default) and a tape is currently recording,
                append this operation to the tape. Set to ``False`` for
                auxiliary objects that should not appear in the circuit
                (e.g. Hamiltonians used only to build time-dependent
                evolutions).
            input_idx: Marks the operation as input with the corresponding
                input index, which is useful for the analytical Fourier
                coefficients computation, but has no effect otherwise.
            name: Optional explicit name for this operation. When ``None``
                (default), the class name is used (e.g. ``"RX"``).

        Raises:
            ValueError: If ``_num_wires`` is set and the number of wires
                doesn't match, or if duplicate wires are provided.
        """
        self.name = name or self.__class__.__name__
        # The ``wires`` setter normalises an int / tuple / list into a list.
        self.wires = wires
        self.input_idx = input_idx

        if self._num_wires is not None and len(self.wires) != self._num_wires:
            raise ValueError(
                f"{self.name} expects {self._num_wires} wire(s), "
                f"got {len(self.wires)}: {self.wires}"
            )
        if len(self.wires) != len(set(self.wires)):
            raise ValueError(f"{self.name} received duplicate wires: {self.wires}")

        if matrix is not None:
            # Instance attribute shadows the class-level default matrix.
            self._matrix = matrix

        # If a tape is currently recording, append ourselves
        if record:
            tape = active_tape()
            if tape is not None:
                tape.append(self)

    @property
    def parameters(self) -> list:
        """Return the list of numeric parameters for this operation.

        Uses the declarative ``_param_names`` tuple to collect parameter
        values in a canonical order. Non-parametrized gates return an
        empty list.

        Returns:
            List of parameter values (floats or JAX arrays).
        """
        return [getattr(self, name) for name in self._param_names]

    @staticmethod
    def _format_param(value) -> str:
        """Format one parameter for :meth:`__repr__`.

        Float-like scalars are printed with four decimals. Values that cannot
        be converted with ``float()`` — non-scalar arrays, or abstract JAX
        tracers when the op is printed inside ``jit``/``grad`` — fall back
        to ``str()`` instead of raising.
        """
        if isinstance(value, (float, np.floating, jnp.ndarray)):
            try:
                return f"{float(value):.4f}"
            except (TypeError, ValueError):
                return str(value)
        return str(value)

    def __repr__(self) -> str:
        """Return a human-readable representation of this operation.

        Returns:
            A string like ``"RX(0.5000, wires=[0])"`` or ``"CX(wires=[0, 1])"``.
        """
        params = self.parameters
        if params:
            param_str = ", ".join(self._format_param(v) for v in params)
            return f"{self.name}({param_str}, wires={self.wires})"
        return f"{self.name}(wires={self.wires})"

    @property
    def matrix(self) -> jnp.ndarray:
        """Return the base matrix of this operation (before lifting).

        Returns:
            The gate matrix as a JAX array.

        Raises:
            NotImplementedError: If the subclass has not defined ``_matrix``.
        """
        if self._matrix is None:
            raise NotImplementedError(
                f"{self.__class__.__name__} does not define a matrix."
            )
        return self._matrix

    @property
    def wires(self) -> List[int]:
        """Qubit indices this operation acts on.

        Returns:
            List of integer qubit indices.
        """
        return self._wires

    @wires.setter
    def wires(self, wires: Union[int, List[int]]) -> None:
        """Set the qubit indices for this operation.

        Args:
            wires: A single qubit index or a list/tuple of qubit indices.
        """
        if isinstance(wires, (list, tuple)):
            self._wires = list(wires)
        else:
            self._wires = [wires]

    @property
    def input_idx(self) -> int:
        """The input index this operation is marked with (``-1`` by default).

        Returns:
            input_idx: Index of the input
        """
        return self._input_idx

    @input_idx.setter
    def input_idx(self, input_idx: int) -> None:
        """Setter for the input_idx flag

        Args:
            input_idx: Index of the input
        """
        self._input_idx = input_idx

    def _update_tape_operation(self, op: "Operation") -> None:
        """Put ``op`` on the active tape in place of ``self`` where possible.

        Used by :meth:`dagger` and :meth:`power`. If ``self`` is the last
        entry on the active tape (the typical case when chaining, e.g.
        ``Gate(...).dagger()``), it is replaced by ``op`` so that only the
        derived operation appears on the tape — not both. Otherwise ``op``
        is appended. Nothing happens when no tape is recording.

        This should be called right after ``op`` has been constructed, while
        ``self`` is still the most recent tape entry.

        Args:
            op (Operation): Operation that replaces ``self`` on the tape.
        """
        tape = active_tape()
        if tape is None:
            return
        if tape and tape[-1] is self:
            tape[-1] = op
        else:
            tape.append(op)

    def dagger(self) -> "Operation":
        """Return a new operation, the conjugate transpose (``U\\dagger``)

        Usage inside a circuit function::

            RX(0.5, wires=0).dagger()

        Returns:
            A new :class:`Operation` with matrix ``U\\dagger`` acting on the
            same wires.

        Raises:
            NotImplementedError: If this operation defines no matrix (the
                error raised by :attr:`matrix`; subclasses may override it).
        """
        # Go through the ``matrix`` property (not ``_matrix``) so matrix-less
        # operations fail with their own explicit error instead of inside
        # ``jnp.conj(None)``.
        mat = jnp.conj(self.matrix).T
        op = Operation(wires=self.wires, matrix=mat, record=False)

        self._update_tape_operation(op)

        return op

    def power(self, power) -> "Operation":
        """Return a new operation, the power (``U^power``)

        Usage inside a circuit function::

            PauliX(wires=0).power(2)

        Args:
            power: Integer exponent forwarded to
                ``jnp.linalg.matrix_power``.

        Returns:
            A new :class:`Operation` with matrix ``U^power`` acting on the
            same wires.

        Raises:
            NotImplementedError: If this operation defines no matrix (the
                error raised by :attr:`matrix`; subclasses may override it).
        """
        # TODO: support fractional powers
        mat = jnp.linalg.matrix_power(self.matrix, power)
        op = Operation(wires=self.wires, matrix=mat, record=False)

        self._update_tape_operation(op)

        return op

    def lifted_matrix(self, n_qubits: int) -> jnp.ndarray:
        """Return the full ``2**n x 2**n`` matrix embedding this gate.

        Embeds the ``k``-qubit gate matrix into the ``n``-qubit Hilbert space
        by applying it to the identity matrix via :meth:`apply_to_state`.
        This is useful for computing ``Tr(O·\\rho )`` directly without vmap.

        Args:
            n_qubits: Total number of qubits in the circuit.

        Returns:
            The ``(2**n, 2**n)`` matrix of this operation in the full space.
        """
        dim = 2**n_qubits
        # vmap maps over the rows of the identity (basis vectors e_i), giving
        # rows U e_i; the transpose turns them back into columns.
        return jax.vmap(lambda col: self.apply_to_state(col, n_qubits))(
            jnp.eye(dim, dtype=_cdtype())
        ).T

    def apply_to_state(self, state: jnp.ndarray, n_qubits: int) -> jnp.ndarray:
        """Apply this gate to a statevector via tensor contraction.

        The statevector (shape ``(2**n,)``) is reshaped into a rank-n tensor
        of shape ``(2,)*n``. The gate (shape ``(2**k, 2**k)``) is reshaped to
        ``(2,)*2k`` and contracted against the k target wire axes.

        Memory footprint is O(2**n) and the operation supports arbitrary k.
        The implementation is fully differentiable through JAX.

        Args:
            state: Statevector of shape ``(2**n_qubits,)``.
            n_qubits: Total number of qubits in the circuit.

        Returns:
            Updated statevector of shape ``(2**n_qubits,)``.
        """
        k = len(self.wires)
        # Same (possibly cached) gate tensor as the other apply_* methods.
        gate_tensor = self._gate_tensor(k)
        psi = state.reshape((2,) * n_qubits)
        psi_out = _contract_and_restore(psi, gate_tensor, k, self.wires)
        return psi_out.reshape(2**n_qubits)

    def apply_to_state_tensor(self, psi: jnp.ndarray, n_qubits: int) -> jnp.ndarray:
        """Apply this gate to a statevector already in tensor form.

        Like :meth:`apply_to_state` but expects the state in rank-n tensor
        form ``(2,)*n`` and returns the result in the same form. This avoids
        the ``reshape`` calls at the per-gate level when the simulation loop
        keeps the state in tensor form throughout.

        Args:
            psi: Statevector tensor of shape ``(2,)*n_qubits``.
            n_qubits: Total number of qubits in the circuit.

        Returns:
            Updated statevector tensor of shape ``(2,)*n_qubits``.
        """
        k = len(self.wires)
        gate_tensor = self._gate_tensor(k)
        return _contract_and_restore(psi, gate_tensor, k, self.wires)

    def _gate_tensor(self, k: int) -> jnp.ndarray:
        """Return the gate matrix reshaped to ``(2,)*2k`` tensor form.

        The result is cached on the instance so repeated calls (e.g. from
        density-matrix simulation which applies U and U*) avoid redundant
        reshape dispatch. Only gates whose matrix is still the class-level
        default are cached; per-instance (parametrized) matrices are
        reshaped on every call.

        Args:
            k: Number of qubits the gate acts on.

        Returns:
            Gate matrix as a rank-2k tensor of shape ``(2,)*2k``.
        """
        cached = getattr(self, "_cached_gate_tensor", None)
        if cached is not None:
            return cached
        gt = self.matrix.reshape((2,) * 2 * k)
        # Only cache for non-parametrized gates (whose matrix is a class attr)
        if self._matrix is self.__class__._matrix:
            object.__setattr__(self, "_cached_gate_tensor", gt)
        return gt

    def apply_to_density(self, rho: jnp.ndarray, n_qubits: int) -> jnp.ndarray:
        """Apply this gate to a density matrix via \\rho -> U\\rho U\\dagger.

        The density matrix (shape ``(2**n, 2**n)``) is treated as a rank-*2n*
        tensor with n "ket" axes (0..n-1) and n "bra" axes (n..2n-1).
        U acts on the ket half; U* acts on the bra half. Both contractions
        use the shared :func:`_contract_and_restore` helper, keeping the
        operation allocation-free with respect to building full unitaries.

        Args:
            rho: Density matrix of shape ``(2**n_qubits, 2**n_qubits)``.
            n_qubits: Total number of qubits in the circuit.

        Returns:
            Updated density matrix of shape ``(2**n_qubits, 2**n_qubits)``.
        """
        k = len(self.wires)
        U = self._gate_tensor(k)
        U_conj = jnp.conj(U)

        rho_t = rho.reshape((2,) * 2 * n_qubits)

        # Apply U to ket axes, U\\dagger to bra axes
        rho_t = _contract_and_restore(rho_t, U, k, self.wires)
        bra_wires = [w + n_qubits for w in self.wires]
        rho_t = _contract_and_restore(rho_t, U_conj, k, bra_wires)

        return rho_t.reshape(2**n_qubits, 2**n_qubits)
class Hermitian(Operation):
    """Observable or gate backed by a user-supplied (Hermitian) matrix.

    The matrix is converted to the active complex dtype on construction.

    Example:
        >>> obs = Hermitian(matrix=my_matrix, wires=0)
    """

    def __init__(
        self,
        matrix: jnp.ndarray,
        wires: Union[int, List[int]] = 0,
        record: bool = True,
    ) -> None:
        """Build the operator from ``matrix``.

        Args:
            matrix: Matrix defining the operator.
            wires: Qubit index or list of qubit indices the operator acts on.
            record: Whether to append the operator to the active tape.
                Pass ``False`` when it only serves as a Hamiltonian term
                (e.g. for time-dependent evolution).
        """
        complex_matrix = jnp.asarray(matrix, dtype=_cdtype())
        super().__init__(wires=wires, matrix=complex_matrix, record=record)

    def __rmul__(self, coeff_fn):
        """Turn ``coeff_fn * Hermitian`` into a :class:`ParametrizedHamiltonian`.

        Args:
            coeff_fn: Callable ``(params, t) -> scalar`` giving the
                time-dependent coefficient.

        Returns:
            A :class:`ParametrizedHamiltonian` combining ``coeff_fn`` with
            this operator's matrix and wires.

        Raises:
            TypeError: When ``coeff_fn`` is not callable.
        """
        if callable(coeff_fn):
            return ParametrizedHamiltonian(coeff_fn, self.matrix, self.wires)
        raise TypeError(
            f"Left operand of `* Hermitian` must be callable, got {type(coeff_fn)}"
        )
class ParametrizedHamiltonian:
    """Time-dependent Hamiltonian of the form ``H(t) = f(params, t) · H_mat``.

    Usually obtained by multiplying a coefficient function with a
    :class:`Hermitian` operator::

        def coeff(p, t):
            return p[0] * jnp.exp(-0.5 * ((t - t_c) / p[1]) ** 2)

        H_td = coeff * Hermitian(matrix=sigma_x, wires=0)

    and then passed to :func:`evolve`::

        evolve(H_td)(coeff_args=[A, sigma], T=1.0)

    Attributes:
        coeff_fn: Callable ``(params, t) -> scalar``.
        H_mat: Static Hermitian matrix (JAX array).
        wires: Qubit wire(s) this Hamiltonian acts on, stored as given.
    """

    def __init__(
        self,
        coeff_fn: Callable,
        H_mat: jnp.ndarray,
        wires: Union[int, List[int]],
    ) -> None:
        # Plain container: values are stored unchanged.
        self.coeff_fn, self.H_mat, self.wires = coeff_fn, H_mat, wires
class Id(Operation):
    """Identity gate."""

    _matrix = jnp.identity(2, dtype=_cdtype())
    _num_wires = 1

    def __init__(self, wires: Union[int, List[int]] = 0, **kwargs) -> None:
        """Create an identity gate.

        Args:
            wires: Target qubit (index or one-element list).
            **kwargs: Forwarded unchanged to :class:`Operation`.
        """
        super().__init__(wires=wires, **kwargs)
class PauliX(Operation):
    """Pauli-X gate / observable (bit-flip, \\sigma_x)."""

    _matrix = jnp.asarray(((0, 1), (1, 0)), dtype=_cdtype())
    _num_wires = 1

    def __init__(self, wires: Union[int, List[int]] = 0, **kwargs) -> None:
        """Create a Pauli-X gate.

        Args:
            wires: Target qubit (index or one-element list).
            **kwargs: Forwarded unchanged to :class:`Operation`.
        """
        super().__init__(wires=wires, **kwargs)
class PauliY(Operation):
    """Pauli-Y gate / observable (\\sigma_y)."""

    _matrix = jnp.asarray(((0, -1j), (1j, 0)), dtype=_cdtype())
    _num_wires = 1

    def __init__(self, wires: Union[int, List[int]] = 0, **kwargs) -> None:
        """Create a Pauli-Y gate.

        Args:
            wires: Target qubit (index or one-element list).
            **kwargs: Forwarded unchanged to :class:`Operation`.
        """
        super().__init__(wires=wires, **kwargs)
class PauliZ(Operation):
    """Pauli-Z gate / observable (phase-flip, \\sigma_z)."""

    _matrix = jnp.diag(jnp.array([1, -1], dtype=_cdtype()))
    _num_wires = 1

    def __init__(self, wires: Union[int, List[int]] = 0, **kwargs) -> None:
        """Create a Pauli-Z gate.

        Args:
            wires: Target qubit (index or one-element list).
            **kwargs: Forwarded unchanged to :class:`Operation`.
        """
        super().__init__(wires=wires, **kwargs)
class H(Operation):
    """Hadamard gate."""

    _matrix = jnp.asarray(((1, 1), (1, -1)), dtype=_cdtype()) / jnp.sqrt(2)
    _num_wires = 1

    def __init__(self, wires: Union[int, List[int]] = 0, **kwargs) -> None:
        """Create a Hadamard gate.

        Args:
            wires: Target qubit (index or one-element list).
            **kwargs: Forwarded unchanged to :class:`Operation`.
        """
        super().__init__(wires=wires, **kwargs)
class S(Operation):
    r"""S (phase) gate — a Clifford gate equal to \sqrt{Z}.

    .. math::
        S = \begin{pmatrix}1 & 0\\ 0 & i\end{pmatrix}
    """

    _matrix = jnp.array([[1, 0], [0, 1j]], dtype=_cdtype())
    _num_wires = 1

    def __init__(self, wires: Union[int, List[int]] = 0, **kwargs) -> None:
        """Initialise an S gate.

        Args:
            wires: Qubit index or list of qubit indices this gate acts on.
            **kwargs: Forwarded to :class:`Operation` (e.g. ``record``,
                ``input_idx``, ``name``), matching the other fixed gates.
        """
        super().__init__(wires=wires, **kwargs)
class RandomUnitary(Operation):
    """Creates a random Hermitian matrix and applies it as a gate.

    The matrix is ``(A + A^\\dagger) / 2`` for a complex Gaussian ``A``,
    rescaled so its Frobenius norm equals ``scale``.

    NOTE(review): despite the class name, this matrix is Hermitian and is
    not unitary in general (applying it does not preserve the state norm).
    Confirm whether callers use it as an observable or expect a unitary
    such as ``expm(-1j * H)``.
    """

    def __init__(self, wires, key, scale=1.0, record=True):
        """Initialise a random Hermitian gate.

        Args:
            wires: Qubit index or list of qubit indices this gate acts on.
            key: ``jax.random.PRNGKey`` used for randomization.
            scale: Frobenius norm of the generated matrix (default: 1.0).
            record: If ``True`` (default), record on the active tape.
        """
        # Accept a bare int like every other gate (``len(int)`` would fail).
        wire_list = list(wires) if isinstance(wires, (list, tuple)) else [wires]
        dim = 2 ** len(wire_list)
        key_re, key_im = jax.random.split(key)

        A = (
            jax.random.normal(key=key_re, shape=(dim, dim))
            + 1j * jax.random.normal(key=key_im, shape=(dim, dim))
        ).astype(_cdtype())
        # Named ``herm`` rather than ``H`` to avoid shadowing the Hadamard class.
        herm = (A + A.conj().T) / 2.0

        herm *= scale / jnp.linalg.norm(herm, ord="fro")

        super().__init__(wire_list, matrix=herm, record=record)
class Barrier(Operation):
    """Barrier operation — a no-op used for visual circuit separation.

    The barrier does not change the quantum state. It is recorded on the
    tape so that drawing backends can insert a visual separator.
    """

    _matrix = None  # not a real gate

    def __init__(self, wires: Union[int, List[int]] = 0, **kwargs) -> None:
        """Initialise a Barrier.

        Args:
            wires: Qubit index or list of qubit indices this barrier spans.
            **kwargs: Forwarded to :class:`Operation` (e.g. ``record``,
                ``name``), consistent with the gate classes.
        """
        super().__init__(wires=wires, **kwargs)

    def apply_to_state(self, state: jnp.ndarray, n_qubits: int) -> jnp.ndarray:
        """No-op: return the state unchanged."""
        return state

    def apply_to_state_tensor(self, psi: jnp.ndarray, n_qubits: int) -> jnp.ndarray:
        """No-op: return the state tensor unchanged."""
        return psi

    def apply_to_density(self, rho: jnp.ndarray, n_qubits: int) -> jnp.ndarray:
        """No-op: return the density matrix unchanged."""
        return rho
def _make_rotation_gate(pauli_class: type, name: str) -> type:
    """Create a single-qubit rotation gate class (RX, RY or RZ).

    The generated gate implements
    ``R_P(\\theta) = cos(\\theta/2) I - i sin(\\theta/2) P``.

    Args:
        pauli_class: The Pauli class supplying ``P`` (PauliX/PauliY/PauliZ).
        name: Name given to the generated class, e.g. ``"RX"``. Its second
            character is used as the axis label in the class docstring.

    Returns:
        A new :class:`Operation` subclass.
    """
    axis = name[1]
    generator_matrix = pauli_class._matrix
    class_doc = (
        f"Rotation around the {axis} axis: {name}(\\theta) =\n"
        f"exp(-i \\theta/2 {axis}).\n"
    )

    class _RotationGate(Operation):
        __doc__ = class_doc
        _num_wires = 1
        _param_names = ("theta",)

        def __init__(
            self, theta: float, wires: Union[int, List[int]] = 0, **kwargs
        ) -> None:
            self.theta = theta
            half_angle = theta / 2
            rotation = (
                jnp.cos(half_angle) * Id._matrix
                - 1j * jnp.sin(half_angle) * generator_matrix
            )
            super().__init__(wires=wires, matrix=rotation, **kwargs)

        def generator(self) -> Operation:
            """Return the generator as the corresponding Pauli operation."""
            return pauli_class(wires=self.wires[0], record=False)

    _RotationGate.__name__ = name
    _RotationGate.__qualname__ = name
    return _RotationGate
# Single-qubit Pauli rotation gates.
RX, RY, RZ = (
    _make_rotation_gate(PauliX, "RX"),
    _make_rotation_gate(PauliY, "RY"),
    _make_rotation_gate(PauliZ, "RZ"),
)
# Single-qubit projectors |0><0| and |1><1|, shared by the controlled-gate
# factories below.
_P0 = jnp.diag(jnp.array([1, 0], dtype=_cdtype()))
_P1 = jnp.diag(jnp.array([0, 1], dtype=_cdtype()))
def _make_controlled_gate(target_class: type, name: str) -> type:
    """Factory for controlled Pauli gates CX, CY, CZ.

    Each gate has the form
    ``CP = |0\\rangle\\langle 0| \\otimes I + |1\\rangle\\langle 1| \\otimes P``.
    The projector sits in the first Kronecker factor, which maps onto
    ``wires[0]`` (the control); ``wires[1]`` is the target.

    Args:
        target_class: The single-qubit gate class (PauliX, PauliY, PauliZ).
        name: Class name for the generated gate (e.g. ``"CX"``).

    Returns:
        A new :class:`Operation` subclass.
    """
    target_mat = target_class._matrix

    class _ControlledGate(Operation):
        __doc__ = (
            f"Controlled-{target_class.__name__[5:]} gate.\n\n"
            f"Applies {target_class.__name__} on the target qubit conditioned "
            f"on the control qubit being in state |1\\rangle."
        )
        _matrix = jnp.kron(_P0, Id._matrix) + jnp.kron(_P1, target_mat)
        _num_wires = 2
        is_controlled = True

        # Immutable default: avoids sharing one list object across calls.
        def __init__(
            self,
            wires: Union[List[int], Tuple[int, ...]] = (0, 1),
            **kwargs,
        ) -> None:
            super().__init__(wires=wires, **kwargs)

    _ControlledGate.__name__ = name
    _ControlledGate.__qualname__ = name
    return _ControlledGate
# Controlled Pauli gates (control on wires[0], target on wires[1]).
CX, CY, CZ = (
    _make_controlled_gate(PauliX, "CX"),
    _make_controlled_gate(PauliY, "CY"),
    _make_controlled_gate(PauliZ, "CZ"),
)
class CCX(Operation):
    """Toffoli (CCX) gate.

    The 3-qubit Toffoli gate exercises the arbitrary-k-qubit path in
    :meth:`~Operation.apply_to_state` and cannot be expressed as a pair of
    2-qubit gates without ancilla, making it a good stress-test for the
    simulator.
    """

    # Identity except for swapping |110> and |111> (last two basis states).
    _matrix = jnp.array(
        [
            [1, 0, 0, 0, 0, 0, 0, 0],
            [0, 1, 0, 0, 0, 0, 0, 0],
            [0, 0, 1, 0, 0, 0, 0, 0],
            [0, 0, 0, 1, 0, 0, 0, 0],
            [0, 0, 0, 0, 1, 0, 0, 0],
            [0, 0, 0, 0, 0, 1, 0, 0],
            [0, 0, 0, 0, 0, 0, 0, 1],
            [0, 0, 0, 0, 0, 0, 1, 0],
        ],
        dtype=_cdtype(),
    )
    is_controlled = True
    _num_wires = 3

    def __init__(
        self,
        wires: Union[List[int], Tuple[int, ...]] = (0, 1, 2),
        **kwargs,
    ) -> None:
        """Initialise a Toffoli (CCX) gate.

        Args:
            wires: Three-element list ``[control0, control1, target]``.
                Defaults to an immutable tuple to avoid a shared mutable
                default; it is converted to a list by :class:`Operation`.
        """
        super().__init__(wires=wires, **kwargs)
class CSWAP(Operation):
    """Controlled-SWAP (Fredkin) gate.

    Swaps the two target qubits conditioned on the control qubit being
    |1\\rangle. Wires are given as ``[control, target0, target1]``.
    """

    # Identity except for swapping |101> and |110>.
    _matrix = jnp.array(
        [
            [1, 0, 0, 0, 0, 0, 0, 0],
            [0, 1, 0, 0, 0, 0, 0, 0],
            [0, 0, 1, 0, 0, 0, 0, 0],
            [0, 0, 0, 1, 0, 0, 0, 0],
            [0, 0, 0, 0, 1, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 1, 0],
            [0, 0, 0, 0, 0, 1, 0, 0],
            [0, 0, 0, 0, 0, 0, 0, 1],
        ],
        dtype=_cdtype(),
    )
    is_controlled = True
    _num_wires = 3

    def __init__(
        self,
        wires: Union[List[int], Tuple[int, ...]] = (0, 1, 2),
        **kwargs,
    ) -> None:
        """Initialise a Controlled-SWAP (Fredkin) gate.

        Args:
            wires: Three-element list ``[control, target0, target1]``.
                Defaults to an immutable tuple to avoid a shared mutable
                default; it is converted to a list by :class:`Operation`.
        """
        super().__init__(wires=wires, **kwargs)
def _make_controlled_rotation_gate(pauli_class: type, name: str) -> type:
    """Factory for controlled rotation gates CRX, CRY, CRZ.

    Each gate has the form
    ``CR_P(\\theta) = |0\\rangle\\langle 0| \\otimes I + |1\\rangle\\langle 1| \\otimes R_P(\\theta)``,
    with the control on ``wires[0]`` and the target on ``wires[1]``.

    Args:
        pauli_class: One of PauliX, PauliY, PauliZ.
        name: Class name for the generated gate (e.g. ``"CRX"``). Its third
            character is used as the axis label in the class docstring.

    Returns:
        A new :class:`Operation` subclass.
    """
    pauli_mat = pauli_class._matrix

    class _CRotationGate(Operation):
        __doc__ = (
            f"Controlled rotation around the {name[2]} axis.\n\n"
            f"Applies R{name[2]}(\\theta) on the target qubit conditioned on the "
            f"control qubit being in state |1\\rangle.\n\n"
            f".. math::\n\n"
            f"    {name}(\\theta) = |0\\rangle\\langle 0| \\otimes I\n"
            f"    + |1\\rangle\\langle 1| \\otimes R{name[2]}(\\theta)"
        )
        _num_wires = 2
        _param_names = ("theta",)
        is_controlled = True

        # Immutable default: avoids sharing one list object across calls.
        def __init__(
            self,
            theta: float,
            wires: Union[List[int], Tuple[int, ...]] = (0, 1),
            **kwargs,
        ) -> None:
            self.theta = theta
            c = jnp.cos(theta / 2)
            s = jnp.sin(theta / 2)
            rot = c * Id._matrix - 1j * s * pauli_mat
            mat = jnp.kron(_P0, Id._matrix) + jnp.kron(_P1, rot)
            super().__init__(wires=wires, matrix=mat, **kwargs)

    _CRotationGate.__name__ = name
    _CRotationGate.__qualname__ = name
    return _CRotationGate
# Controlled Pauli rotations (control on wires[0], target on wires[1]).
CRX, CRY, CRZ = (
    _make_controlled_rotation_gate(PauliX, "CRX"),
    _make_controlled_rotation_gate(PauliY, "CRY"),
    _make_controlled_rotation_gate(PauliZ, "CRZ"),
)
class Rot(Operation):
    """General single-qubit rotation:
    Rot(\\phi, \\theta, \\omega) = RZ(\\omega) RY(\\theta) RZ(\\phi).

    Three successive axis rotations with three free angles, giving the most
    general SU(2) rotation up to a global phase.
    """

    _num_wires = 1
    _param_names = ("phi", "theta", "omega")

    def __init__(
        self,
        phi: float,
        theta: float,
        omega: float,
        wires: Union[int, List[int]] = 0,
        **kwargs,
    ) -> None:
        """Create a general rotation gate.

        Args:
            phi: Angle of the first (rightmost) RZ rotation, in radians.
            theta: Angle of the RY rotation, in radians.
            omega: Angle of the last (leftmost) RZ rotation, in radians.
            wires: Target qubit (index or one-element list).
            **kwargs: Forwarded unchanged to :class:`Operation`.
        """
        self.phi = phi
        self.theta = theta
        self.omega = omega

        def axis_rotation(angle, pauli):
            # cos(a/2) I - i sin(a/2) P
            return jnp.cos(angle / 2) * Id._matrix - 1j * jnp.sin(angle / 2) * pauli

        combined = (
            axis_rotation(omega, PauliZ._matrix)
            @ axis_rotation(theta, PauliY._matrix)
            @ axis_rotation(phi, PauliZ._matrix)
        )
        super().__init__(wires=wires, matrix=combined, **kwargs)
class PauliRot(Operation):
    """Multi-qubit Pauli rotation: exp(-i \\theta/2 P) for a Pauli word P.

    The Pauli word is given as a string of ``'I'``, ``'X'``, ``'Y'``, ``'Z'``
    characters (one per qubit). The rotation matrix is computed as
    ``cos(\\theta/2) I - i sin(\\theta/2) P`` where *P* is the tensor product of the
    corresponding single-qubit Pauli matrices.

    Example::

        PauliRot(0.5, "XY", wires=[0, 1])
    """

    _param_names = ("theta",)

    # Map from character to 2x2 matrix
    _PAULI_MAP = {
        "I": Id._matrix,
        "X": PauliX._matrix,
        "Y": PauliY._matrix,
        "Z": PauliZ._matrix,
    }

    def __init__(
        self, theta: float, pauli_word: str, wires: Union[int, List[int]] = 0, **kwargs
    ) -> None:
        """Initialise a PauliRot gate.

        Args:
            theta: Rotation angle in radians.
            pauli_word: A string of ``'I'``, ``'X'``, ``'Y'``, ``'Z'``
                characters specifying the Pauli tensor product.
            wires: Qubit index or list of qubit indices this gate acts on.

        Raises:
            ValueError: If ``len(pauli_word)`` differs from the number of
                wires (previously this only surfaced later as a reshape
                error when the gate was applied).
            KeyError: If ``pauli_word`` contains a character other than
                ``'I'``, ``'X'``, ``'Y'``, ``'Z'``.
        """
        n_wires = len(wires) if isinstance(wires, (list, tuple)) else 1
        if len(pauli_word) != n_wires:
            raise ValueError(
                f"{type(self).__name__}: Pauli word {pauli_word!r} has length "
                f"{len(pauli_word)} but {n_wires} wire(s) were given."
            )

        self.theta = theta
        self.pauli_word = pauli_word

        P = self._pauli_word_matrix(pauli_word)
        dim = P.shape[0]
        mat = (
            jnp.cos(theta / 2) * jnp.eye(dim, dtype=_cdtype())
            - 1j * jnp.sin(theta / 2) * P
        )
        super().__init__(wires=wires, matrix=mat, **kwargs)

    @classmethod
    def _pauli_word_matrix(cls, pauli_word: str) -> jnp.ndarray:
        """Return the Kronecker product of the Pauli matrices in ``pauli_word``."""
        from functools import reduce as _reduce

        return _reduce(jnp.kron, [cls._PAULI_MAP[c] for c in pauli_word])

    def generator(self) -> Operation:
        """Return the generator Pauli tensor product as an :class:`Operation`.

        The generator of ``PauliRot(\\theta, word, wires)`` is the tensor product
        of single-qubit Pauli matrices specified by *word*. The returned
        :class:`Hermitian` wraps that matrix and the gate's wires.

        Returns:
            :class:`Hermitian` operation representing the Pauli tensor product.
        """
        P = self._pauli_word_matrix(self.pauli_word)
        return Hermitian(matrix=P, wires=self.wires, record=False)
class KrausChannel(Operation):
    """Base class for noise channels defined by a set of Kraus operators.

    A Kraus channel \\phi(\\rho) = \\sum_k K_k \\rho K_k^\\dagger
    is the most general physical operation on a quantum state. For a pure
    unitary gate there is a single operator K_0 = U satisfying
    K_0^\\dagger K_0 = I; for noisy channels there are multiple operators
    (expected to satisfy \\sum_k K_k^\\dagger K_k = I — not checked here).

    Subclasses must implement :meth:`kraus_matrices` and return a list of JAX
    arrays. :meth:`apply_to_state` is intentionally left unimplemented:
    Kraus channels require a density-matrix representation and cannot be
    applied to a pure statevector in general.
    """

    def kraus_matrices(self) -> List[jnp.ndarray]:
        """Return the list of Kraus operators for this channel.

        Returns:
            List of 2-D JAX arrays, each of shape ``(2**k, 2**k)`` where k
            is the number of target qubits.

        Raises:
            NotImplementedError: Subclasses must override this method.
        """
        raise NotImplementedError

    @property
    def matrix(self) -> jnp.ndarray:
        """Raises TypeError — noise channels have no single unitary matrix.

        Raises:
            TypeError: Always raised; use :meth:`apply_to_density` instead.
        """
        raise TypeError(
            f"{self.__class__.__name__} is a noise channel and has no single "
            "unitary matrix. Use apply_to_density() instead."
        )

    def apply_to_state(self, state: jnp.ndarray, n_qubits: int) -> jnp.ndarray:
        """Raises TypeError — noise channels require density-matrix simulation.

        Args:
            state: Statevector (unused).
            n_qubits: Number of qubits (unused).

        Raises:
            TypeError: Always raised; use ``execute(type='density')`` instead.
        """
        raise TypeError(
            f"{self.__class__.__name__} is a noise channel and cannot be "
            "applied to a pure statevector. Use execute(type='density') instead."
        )

    def apply_to_state_tensor(self, psi: jnp.ndarray, n_qubits: int) -> jnp.ndarray:
        """Raises TypeError — noise channels require density-matrix simulation."""
        raise TypeError(
            f"{self.__class__.__name__} is a noise channel and cannot be "
            "applied to a pure statevector. Use execute(type='density') instead."
        )

    def apply_to_density(self, rho: jnp.ndarray, n_qubits: int) -> jnp.ndarray:
        """Apply \\phi(\\rho) = \\sum_k K_k \\rho K_k^\\dagger using tensor contraction.

        Each term contracts K_k against the ket axes and conj(K_k) against the
        bra axes with the shared :func:`_contract_and_restore` helper; the
        terms are summed over all Kraus operators.

        Args:
            rho: Density matrix of shape ``(2**n_qubits, 2**n_qubits)``.
            n_qubits: Total number of qubits in the circuit.

        Returns:
            Updated density matrix of shape ``(2**n_qubits, 2**n_qubits)``.
        """
        k = len(self.wires)
        dim = 2**n_qubits
        bra_wires = [w + n_qubits for w in self.wires]
        # The input tensor view is identical for every Kraus term, so build
        # it once outside the loop.
        rho_in = rho.reshape((2,) * 2 * n_qubits)
        rho_out = jnp.zeros_like(rho)

        for K in self.kraus_matrices():
            K_t = K.reshape((2,) * 2 * k)
            K_conj_t = jnp.conj(K_t)
            term = _contract_and_restore(rho_in, K_t, k, self.wires)
            term = _contract_and_restore(term, K_conj_t, k, bra_wires)
            rho_out = rho_out + term.reshape(dim, dim)

        return rho_out
class BitFlip(KrausChannel):
    r"""Bit-flip noise on one qubit: a Pauli-X error with probability *p*.

    .. math::
        K_0 = \sqrt{1-p}\,I, \quad K_1 = \sqrt{p}\,X

    with :math:`p \in [0, 1]`.
    """

    _num_wires = 1
    _param_names = ("p",)

    def __init__(self, p: float, wires: Union[int, List[int]] = 0) -> None:
        """Create a bit-flip channel.

        Args:
            p: Probability of an X error; must lie in [0, 1].
            wires: Target qubit index, or a list of qubit indices.

        Raises:
            ValueError: If *p* lies outside [0, 1].
        """
        if not (0.0 <= p <= 1.0):
            raise ValueError("p must be in [0, 1].")
        self.p = p
        super().__init__(wires=wires)

    def kraus_matrices(self) -> List[jnp.ndarray]:
        """Build the no-error and bit-flip Kraus operators.

        Returns:
            ``[sqrt(1 - p) * I, sqrt(p) * X]``.
        """
        no_flip = jnp.sqrt(1 - self.p)
        flip = jnp.sqrt(self.p)
        return [no_flip * Id._matrix, flip * PauliX._matrix]
class PhaseFlip(KrausChannel):
    r"""Phase-flip noise on one qubit: a Pauli-Z error with probability *p*.

    .. math::
        K_0 = \sqrt{1-p}\,I, \quad K_1 = \sqrt{p}\,Z

    with :math:`p \in [0, 1]`.
    """

    _num_wires = 1
    _param_names = ("p",)

    def __init__(self, p: float, wires: Union[int, List[int]] = 0) -> None:
        """Create a phase-flip channel.

        Args:
            p: Probability of a Z error; must lie in [0, 1].
            wires: Target qubit index, or a list of qubit indices.

        Raises:
            ValueError: If *p* lies outside [0, 1].
        """
        if not (0.0 <= p <= 1.0):
            raise ValueError("p must be in [0, 1].")
        self.p = p
        super().__init__(wires=wires)

    def kraus_matrices(self) -> List[jnp.ndarray]:
        """Build the no-error and phase-flip Kraus operators.

        Returns:
            ``[sqrt(1 - p) * I, sqrt(p) * Z]``.
        """
        no_flip = jnp.sqrt(1 - self.p)
        flip = jnp.sqrt(self.p)
        return [no_flip * Id._matrix, flip * PauliZ._matrix]
class DepolarizingChannel(KrausChannel):
    r"""Depolarizing noise on one qubit.

    Each of X, Y and Z is applied with probability :math:`p/3`:

    .. math::
        K_0 = \sqrt{1-p}\,I,\quad K_1 = \sqrt{p/3}\,X,\quad
        K_2 = \sqrt{p/3}\,Y,\quad K_3 = \sqrt{p/3}\,Z

    with :math:`p \in [0, 1]`. At p = 3/4 the output is the maximally mixed state.
    """

    _num_wires = 1
    _param_names = ("p",)

    def __init__(self, p: float, wires: Union[int, List[int]] = 0) -> None:
        """Create a depolarizing channel.

        Args:
            p: Total depolarization probability; must lie in [0, 1].
            wires: Target qubit index, or a list of qubit indices.

        Raises:
            ValueError: If *p* lies outside [0, 1].
        """
        if not (0.0 <= p <= 1.0):
            raise ValueError("p must be in [0, 1].")
        self.p = p
        super().__init__(wires=wires)

    def kraus_matrices(self) -> List[jnp.ndarray]:
        """Build the identity, X, Y and Z Kraus operators.

        Returns:
            ``[sqrt(1 - p) * I, sqrt(p/3) * X, sqrt(p/3) * Y, sqrt(p/3) * Z]``.
        """
        p = self.p
        identity_weight = jnp.sqrt(1 - p)
        pauli_weight = jnp.sqrt(p / 3)
        paulis = (PauliX._matrix, PauliY._matrix, PauliZ._matrix)
        return [identity_weight * Id._matrix] + [pauli_weight * P for P in paulis]
class AmplitudeDamping(KrausChannel):
    r"""Amplitude damping (energy loss) on one qubit.

    .. math::
        K_0 = \begin{pmatrix}1 & 0\\ 0 & \sqrt{1-\gamma}\end{pmatrix},\quad
        K_1 = \begin{pmatrix}0 & \sqrt{\gamma}\\ 0 & 0\end{pmatrix}

    :math:`\gamma \in [0, 1]` is the probability that an excitation decays
    (:math:`|1\rangle \to |0\rangle`).
    """

    _num_wires = 1
    _param_names = ("gamma",)

    def __init__(self, gamma: float, wires: Union[int, List[int]] = 0) -> None:
        """Create an amplitude damping channel.

        Args:
            gamma: Decay probability; must lie in [0, 1].
            wires: Target qubit index, or a list of qubit indices.

        Raises:
            ValueError: If *gamma* lies outside [0, 1].
        """
        if not (0.0 <= gamma <= 1.0):
            raise ValueError("gamma must be in [0, 1].")
        self.gamma = gamma
        super().__init__(wires=wires)

    def kraus_matrices(self) -> List[jnp.ndarray]:
        """Build the two amplitude damping Kraus operators.

        Returns:
            ``[K0, K1]`` exactly as written in the class docstring, in the
            active complex dtype.
        """
        dtype = _cdtype()
        g = self.gamma
        survive = jnp.sqrt(1 - g)
        decay = jnp.sqrt(g)
        return [
            jnp.array([[1.0, 0.0], [0.0, survive]], dtype=dtype),
            jnp.array([[0.0, decay], [0.0, 0.0]], dtype=dtype),
        ]
class PhaseDamping(KrausChannel):
    r"""Phase damping (pure dephasing) on one qubit.

    .. math::
        K_0 = \begin{pmatrix}1 & 0\\ 0 & \sqrt{1-\gamma}\end{pmatrix},\quad
        K_1 = \begin{pmatrix}0 & 0\\ 0 & \sqrt{\gamma}\end{pmatrix}

    where :math:`\gamma \in [0, 1]` is the phase damping probability.
    """

    _num_wires = 1
    _param_names = ("gamma",)

    def __init__(self, gamma: float, wires: Union[int, List[int]] = 0) -> None:
        """Create a phase damping channel.

        Args:
            gamma: Dephasing probability; must lie in [0, 1].
            wires: Target qubit index, or a list of qubit indices.

        Raises:
            ValueError: If *gamma* lies outside [0, 1].
        """
        if not (0.0 <= gamma <= 1.0):
            raise ValueError("gamma must be in [0, 1].")
        self.gamma = gamma
        super().__init__(wires=wires)

    def kraus_matrices(self) -> List[jnp.ndarray]:
        """Build the two phase damping Kraus operators.

        Returns:
            ``[K0, K1]`` exactly as written in the class docstring, in the
            active complex dtype.
        """
        dtype = _cdtype()
        g = self.gamma
        coherent = jnp.sqrt(1 - g)
        damped = jnp.sqrt(g)
        return [
            jnp.array([[1.0, 0.0], [0.0, coherent]], dtype=dtype),
            jnp.array([[0.0, 0.0], [0.0, damped]], dtype=dtype),
        ]
class ThermalRelaxationError(KrausChannel):
    r"""Thermal relaxation noise on a single qubit.

    Combines T_1 energy relaxation with T_2 dephasing over a gate of duration
    ``tg``. The Kraus representation depends on the two times:

    T_2 <= T_1 (Markovian dephasing + reset):
        Six Kraus operators built from p_z (phase-flip probability), p_r0
        (reset-to-:math:`|0\rangle` probability) and p_r1
        (reset-to-:math:`|1\rangle` probability).

    T_2 > T_1 (Choi matrix decomposition):
        The Choi matrix is assembled from the relaxation/dephasing factors and
        diagonalised with ``eigh``. Each eigenpair gives
        :math:`K_i = \sqrt{|\lambda_i|}\,\mathrm{mat}(v_i)` (column-major reshape).

    Attributes:
        pe: Excited-state population (thermal population of :math:`|1\rangle`).
        t1: T_1 longitudinal relaxation time.
        t2: T_2 transverse dephasing time.
        tg: Gate duration.
    """

    _num_wires = 1
    _param_names = ("pe", "t1", "t2", "tg")

    def __init__(
        self,
        pe: float,
        t1: float,
        t2: float,
        tg: float,
        wires: Union[int, List[int]] = 0,
    ) -> None:
        """Create a thermal relaxation channel.

        Args:
            pe: Excited-state population, in [0, 1].
            t1: T_1 relaxation time, must be > 0.
            t2: T_2 dephasing time, must be > 0 and <= 2·T_1.
            tg: Gate duration, must be >= 0.
            wires: Target qubit index, or a list of qubit indices.

        Raises:
            ValueError: On the first parameter that violates its constraint,
                checked in the order pe, t1, t2, t2-vs-t1, tg.
        """
        if not (0.0 <= pe <= 1.0):
            raise ValueError("pe must be in [0, 1].")
        if t1 <= 0:
            raise ValueError("t1 must be > 0.")
        if t2 <= 0:
            raise ValueError("t2 must be > 0.")
        if t2 > 2 * t1:
            raise ValueError("t2 must be <= 2·t1.")
        if tg < 0:
            raise ValueError("tg must be >= 0.")
        self.pe, self.t1, self.t2, self.tg = pe, t1, t2, tg
        super().__init__(wires=wires)

    def kraus_matrices(self) -> List[jnp.ndarray]:
        """Build the Kraus operators for the current parameters.

        * T_2 <= T_1: six operators (identity, phase flip, two
          reset-to-:math:`|0\\rangle`, two reset-to-:math:`|1\\rangle`).
        * T_2 > T_1: four operators from the Choi matrix eigendecomposition.

        Returns:
            List of 2x2 JAX arrays in the active complex dtype.
        """
        pe, t1, t2, tg = self.pe, self.t1, self.t2, self.tg
        dtype = _cdtype()

        decay_t1 = jnp.exp(-tg / t1)
        p_reset = 1.0 - decay_t1
        decay_t2 = jnp.exp(-tg / t2)

        if t2 <= t1:
            p_z = (1.0 - p_reset) * (1.0 - decay_t2 / decay_t1) / 2.0
            p_r0 = (1.0 - pe) * p_reset
            p_r1 = pe * p_reset
            p_id = 1.0 - p_z - p_r0 - p_r1
            # (probability weight, operator) pairs: identity, Z, two resets to
            # |0>, two resets to |1>.
            weighted_ops = (
                (p_id, [[1, 0], [0, 1]]),
                (p_z, [[1, 0], [0, -1]]),
                (p_r0, [[1, 0], [0, 0]]),
                (p_r0, [[0, 1], [0, 0]]),
                (p_r1, [[0, 0], [1, 0]]),
                (p_r1, [[0, 0], [0, 1]]),
            )
            return [jnp.sqrt(w) * jnp.array(m, dtype=dtype) for w, m in weighted_ops]

        # T_2 > T_1: Choi matrix (column-major reshaping convention matching
        # PennyLane), then one Kraus operator per eigenpair.
        relax_up = pe * p_reset
        relax_down = (1 - pe) * p_reset
        choi = jnp.array(
            [
                [1 - relax_up, 0, 0, decay_t2],
                [0, relax_up, 0, 0],
                [0, 0, relax_down, 0],
                [decay_t2, 0, 0, 1 - relax_down],
            ],
            dtype=dtype,
        )
        eigvals, eigvecs = jnp.linalg.eigh(choi)
        # Rows of eigvecs.T are the eigenvectors; abs() guards tiny negative
        # eigenvalues from round-off.
        return [
            (jnp.sqrt(jnp.abs(lam)) * vec.reshape(2, 2, order="F")).astype(dtype)
            for lam, vec in zip(eigvals, eigvecs.T)
        ]
class QubitChannel(KrausChannel):
    r"""Generic Kraus channel from a user-supplied list of Kraus operators.

    This replaces PennyLane's ``qml.QubitChannel``. The operators are expected
    to satisfy :math:`\sum_k K_k^\dagger K_k = I`; that completeness relation
    is not checked here, but operator count and shapes are.

    Example::

        kraus_ops = [jnp.sqrt(0.9) * jnp.eye(2), jnp.sqrt(0.1) * PauliX._matrix]
        QubitChannel(kraus_ops, wires=0)
    """

    def __init__(
        self, kraus_ops: List[jnp.ndarray], wires: Union[int, List[int]] = 0
    ) -> None:
        """Initialise a generic Kraus channel.

        Args:
            kraus_ops: Non-empty list of Kraus matrices. Each must be a square
                2D array of shape ``(2**k, 2**k)`` where k = ``len(wires)``
                (k = 1 for a single integer wire).
            wires: Qubit index or list of qubit indices this channel acts on.

        Raises:
            ValueError: If *kraus_ops* is empty or any operator has the wrong
                shape for the number of wires.
        """
        ops = [jnp.asarray(K, dtype=_cdtype()) for K in kraus_ops]
        if len(ops) == 0:
            raise ValueError("kraus_ops must contain at least one Kraus operator.")
        # Validate before the base initialiser runs so it never receives a
        # malformed channel. A scalar wire means one qubit.
        n_wires = 1 if np.ndim(wires) == 0 else len(wires)
        expected = (2**n_wires, 2**n_wires)
        for i, K in enumerate(ops):
            if tuple(K.shape) != expected:
                raise ValueError(
                    f"Kraus operator {i} has shape {tuple(K.shape)}; expected "
                    f"{expected} for a channel on {n_wires} wire(s)."
                )
        self._kraus_ops = ops
        super().__init__(wires=wires)

    def kraus_matrices(self) -> List[jnp.ndarray]:
        """Return the stored Kraus operators.

        Returns:
            List of Kraus operator matrices, converted to the active complex
            dtype at construction time.
        """
        return self._kraus_ops
# Single-qubit Pauli classes in canonical I, X, Y, Z order. The matrix and
# label tables below are aligned with it by index; the matrices are plain
# arrays, so no Operation instances are created.
_PAULI_CLASSES = [Id, PauliX, PauliY, PauliZ]
_PAULI_MATS = [pauli_cls._matrix for pauli_cls in _PAULI_CLASSES]
_PAULI_LABELS = ["I", "X", "Y", "Z"]
def evolve_pauli_with_clifford(
    clifford: Operation,
    pauli: Operation,
    adjoint_left: bool = True,
) -> Operation:
    r"""Conjugate a Pauli operator by a Clifford gate.

    Computes :math:`C^\dagger P C` (default) or :math:`C P C^\dagger`. Both
    operators are first lifted to the register formed by the sorted union of
    their wires. The product is then wrapped in a :class:`Hermitian` so it can
    take part in further algebra.

    Args:
        clifford: A Clifford gate.
        pauli: A Pauli / Hermitian operator.
        adjoint_left: If ``True`` compute :math:`C^\dagger P C`, otherwise
            :math:`C P C^\dagger`.

    Returns:
        A :class:`Hermitian` wrapping the evolved matrix on the union of wires.
    """
    wires = sorted(set(clifford.wires).union(pauli.wires))
    n_wires = len(wires)

    c_full = _embed_matrix(clifford.matrix, clifford.wires, wires, n_wires)
    p_full = _embed_matrix(pauli.matrix, pauli.wires, wires, n_wires)
    c_dag = jnp.conj(c_full).T

    left, right = (c_dag, c_full) if adjoint_left else (c_full, c_dag)
    return Hermitian(matrix=left @ p_full @ right, wires=wires, record=False)
def _embed_matrix(
    mat: jnp.ndarray,
    op_wires: list,
    all_wires: list,
    n_total: int,
) -> jnp.ndarray:
    """Lift an operator on ``op_wires`` to the register ordered by ``all_wires``.

    The operator is tensored with the identity on every wire of ``all_wires``
    it does not touch. Its qubit axes are then reordered so that the
    tensor-product order follows ``all_wires``. If the operator already acts
    on exactly ``all_wires`` in that order, ``mat`` is returned unchanged.

    Args:
        mat: The gate's matrix of shape ``(2**k, 2**k)`` where
            ``k = len(op_wires)``.
        op_wires: The wires the gate acts on.
        all_wires: The full ordered list of wires.
        n_total: ``len(all_wires)``.

    Returns:
        A ``(2**n_total, 2**n_total)`` matrix.
    """
    target = list(all_wires)
    if len(op_wires) == n_total and list(op_wires) == target:
        return mat

    idle = [w for w in target if w not in op_wires]
    lifted = mat
    if idle:
        # I_2 (x) ... (x) I_2 over the idle wires is one identity of size 2**len(idle).
        lifted = jnp.kron(lifted, jnp.eye(2 ** len(idle), dtype=_cdtype()))

    # `lifted` is currently ordered as [op_wires..., idle...].
    order = list(op_wires) + idle
    if order == target:
        return lifted
    return _permute_matrix(lifted, [order.index(w) for w in target], n_total)
def _permute_matrix(mat: jnp.ndarray, perm: list, n_qubits: int) -> jnp.ndarray:
    """Reorder the qubits of a ``(2**n, 2**n)`` operator.

    Output qubit ``j`` is input qubit ``perm[j]``, the same convention as the
    ``axes`` argument of ``jnp.transpose``. The permutation is applied
    identically to the row and column indices, e.g. ``perm=[1, 0]`` maps
    ``kron(A, B)`` to ``kron(B, A)``.

    Args:
        mat: Square matrix of dimension ``2**n_qubits``.
        perm: Permutation of ``range(n_qubits)``; any sequence (list, tuple).
        n_qubits: Number of qubits.

    Returns:
        Permuted matrix of the same shape.
    """
    dim = 2**n_qubits
    # One tensor axis per qubit: the first n_qubits axes index rows, the last
    # n_qubits index columns.
    tensor = mat.reshape([2] * (2 * n_qubits))
    # list() so tuple permutations work too (tuple + list would raise).
    row_axes = list(perm)
    col_axes = [p + n_qubits for p in row_axes]
    tensor = jnp.transpose(tensor, row_axes + col_axes)
    return tensor.reshape(dim, dim)
def pauli_decompose(matrix: jnp.ndarray, wire_order: Optional[List[int]] = None):
    r"""Return the dominant Pauli-string term of an n-qubit Hermitian matrix.

    A ``2**n x 2**n`` matrix ``M`` expands over the n-qubit Pauli strings as
    ``M = sum_P c_P P`` with the trace formula ``c_P = Tr(P · M) / 2**n``.
    All ``4**n`` coefficients are evaluated and only the term with the
    largest ``|c_P|`` is returned. On ties, the first string in I < X < Y < Z
    lexicographic order wins. This is sufficient for the Fourier-tree
    algorithm, which only needs the single non-zero Pauli term produced by
    Clifford conjugation of a Pauli operator.

    Args:
        matrix: A ``(2**n, 2**n)`` Hermitian matrix with ``n >= 1``.
        wire_order: Optional list of wire indices. If ``None``, defaults
            to ``[0, 1, ..., n-1]``.

    Returns:
        A tuple ``(coeff, op)`` where *coeff* is the complex coefficient of the
        dominant term and *op* is one of:

        * ``Id`` on ``wire_order[0]`` labelled ``"I" * n``, when the dominant
          term is the identity string. This also covers an all-zero matrix,
          in which case *coeff* is ``0.0``.
        * ``PauliX`` / ``PauliY`` / ``PauliZ`` on the single non-identity
          wire, labelled with that one letter.
        * A :class:`Hermitian` holding the full tensor product on
          *wire_order*, labelled with the whole Pauli string.

        The label is stored on ``op._pauli_label``.

    Raises:
        ValueError: If *matrix* is not square with a power-of-two dimension
            of at least 2.
    """
    from itertools import product as _product
    from functools import reduce as _reduce

    if len(matrix.shape) != 2 or matrix.shape[0] != matrix.shape[1]:
        raise ValueError(
            f"matrix must be square with a power-of-two dimension; got shape "
            f"{tuple(matrix.shape)}."
        )
    dim = int(matrix.shape[0])
    n_qubits = dim.bit_length() - 1
    if dim < 2 or 2**n_qubits != dim:
        raise ValueError(
            f"matrix dimension must be a power of two >= 2; got {dim}."
        )

    if wire_order is None:
        wire_order = list(range(n_qubits))

    # Tr(P @ M) == sum(P * M.T): O(dim**2) per term instead of a matmul.
    matrix_t = matrix.T
    # Start from the all-identity string so an all-zero matrix still yields a
    # well-defined (0.0, identity) result.
    best_label = (0,) * n_qubits
    best_coeff = 0.0
    best_P = None
    for indices in _product(range(4), repeat=n_qubits):
        P = _reduce(jnp.kron, [_PAULI_MATS[i] for i in indices])
        coeff = jnp.sum(P * matrix_t) / dim
        if jnp.abs(coeff) > jnp.abs(best_coeff):
            best_coeff, best_label, best_P = coeff, indices, P

    non_identity = [(q, idx) for q, idx in enumerate(best_label) if idx != 0]
    if not non_identity:
        # Identity string: represented by Id on the first wire.
        result_op = Id(wires=wire_order[0], record=False)
        result_op._pauli_label = "I" * n_qubits
    elif len(non_identity) == 1:
        # Single-qubit Pauli on one wire.
        q, idx = non_identity[0]
        result_op = _PAULI_CLASSES[idx](wires=wire_order[q], record=False)
        result_op._pauli_label = _PAULI_LABELS[idx]
    else:
        # Multi-qubit tensor product -> Hermitian with the full Pauli label.
        # best_P is set together with best_label, so it is not None here.
        result_op = Hermitian(matrix=best_P, wires=wire_order, record=False)
        result_op._pauli_label = "".join(_PAULI_LABELS[i] for i in best_label)
    return best_coeff, result_op
def pauli_string_from_operation(op: Operation) -> str:
    """Return the Pauli word (e.g. ``"X"`` or ``"ZZ"``) described by *op*.

    Sources are tried in this order:

    1. A :class:`PauliRot` that has a ``pauli_word`` attribute -> that word.
    2. Any operation carrying ``_pauli_label`` (set by
       :func:`pauli_decompose`) -> that label.
    3. Operations named ``PauliX``, ``PauliY``, ``PauliZ`` or ``I`` -> the
       corresponding single letter.
    4. Otherwise ``op.matrix`` is passed to :func:`pauli_decompose`, and the
       label of the dominant term is returned.

    Args:
        op: A quantum operation.

    Returns:
        The Pauli word as a string.
    """
    if isinstance(op, PauliRot) and hasattr(op, "pauli_word"):
        return op.pauli_word
    if hasattr(op, "_pauli_label"):
        return op._pauli_label

    letter = {"PauliX": "X", "PauliY": "Y", "PauliZ": "Z", "I": "I"}.get(op.name)
    if letter is not None:
        return letter

    _, dominant = pauli_decompose(op.matrix, wire_order=op.wires)
    return dominant._pauli_label