Coverage for qml_essentials / operations.py: 90%
526 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-04-10 10:29 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-04-10 10:29 +0000
1from typing import Callable, List, Optional, Tuple, Union
2from functools import lru_cache
3import string
4import numpy as np
6import jax
7import jax.numpy as jnp
9from qml_essentials.tape import active_tape, recording # noqa: F401 (re-export)
def _cdtype():
    """Select the complex dtype matching JAX's current precision setting.

    Returns:
        ``jnp.complex128`` when ``jax.config.x64_enabled`` is true,
        otherwise ``jnp.complex64``.
    """
    if jax.config.x64_enabled:
        return jnp.complex128
    return jnp.complex64
@lru_cache(maxsize=256)
def _einsum_subscript(
    n: int,
    k: int,
    target_axes: Tuple[int, ...],
) -> str:
    """Build an ``einsum`` subscript that fuses contraction + axis restore.

    The gate operand is listed first and is labelled
    ``(out_0, ..., out_{k-1}, in_0, ..., in_{k-1})``.  Each ``in_i`` reuses
    the letter of the state axis ``target_axes[i]`` (so it is contracted),
    and each ``out_i`` is a fresh letter that takes the place of that axis
    in the result, which keeps the axis order of the state unchanged.

    Args:
        n: Total rank of the state tensor (number of qubits for statevectors,
            ``2 * n_qubits`` for density matrices).
        k: Number of qubits the gate acts on.
        target_axes: Tuple of k axis indices in the state tensor that the
            gate contracts against.

    Returns:
        ``einsum`` subscript string, e.g. ``"db,abc->adc"`` for a 1-qubit
        gate on axis 1 of a rank-3 state.

    Raises:
        ValueError: If ``len(target_axes) != k`` or if ``n + k`` exceeds the
            number of single-letter index labels available.
    """
    letters = string.ascii_letters
    if len(target_axes) != k:
        raise ValueError(
            f"Expected {k} target axes, got {len(target_axes)}: {target_axes}"
        )
    # One letter per state axis plus k fresh output letters are required;
    # running out would otherwise produce a truncated (wrong) subscript.
    if n + k > len(letters):
        raise ValueError(
            f"Cannot build einsum subscript: {n} state axes + {k} gate axes "
            f"exceed the {len(letters)} available index letters"
        )

    state_idx = letters[:n]
    new_out = letters[n : n + k]
    # Gate input indices share their letters with the contracted state axes
    contracted = "".join(state_idx[ax] for ax in target_axes)

    # Result indices: the state indices with each target axis relabelled
    result_idx = list(state_idx)
    for out_letter, ax in zip(new_out, target_axes):
        result_idx[ax] = out_letter

    return f"{new_out}{contracted},{state_idx}->{''.join(result_idx)}"
def _contract_and_restore(
    tensor: jnp.ndarray,
    gate: jnp.ndarray,
    k: int,
    target_axes: List[int],
) -> jnp.ndarray:
    """Apply ``gate`` to the given axes of ``tensor`` in a single einsum.

    The subscript string comes from :func:`_einsum_subscript`, which is
    memoised, so it is only assembled once for every distinct
    ``(tensor.ndim, k, target_axes)`` triple.

    Args:
        tensor: Rank-N tensor (e.g. ``(2,)*n`` for states or ``(2,)*2n``
            for density matrices).
        gate: Gate reshaped to a tensor of shape ``(2,)*2k``.
        k: Number of qubits the gate acts on (= ``len(target_axes)``).
        target_axes: The k axes of ``tensor`` the gate is contracted with.

    Returns:
        A tensor of the same rank as ``tensor`` in which the contracted axes
        sit at their original positions.
    """
    # lru_cache needs hashable arguments, hence the tuple conversion
    axes_key = tuple(target_axes)
    spec = _einsum_subscript(tensor.ndim, k, axes_key)
    return jnp.einsum(spec, gate, tensor)
class Operation:
    """Base class for any quantum operation or observable.

    Further gates should inherit from this class to realise more specific
    operations.  Generally, operations are created by instantiation inside a
    circuit function passed to :class:`Script`; the instance is
    automatically appended to the active tape.

    An ``Operation`` can also serve as an *observable*: its matrix is used to
    compute expectation values via ``apply_to_state`` / ``apply_to_density``.

    Attributes:
        _matrix: Class-level default gate matrix.  Subclasses set this to
            their fixed unitary.  Instances may override it via the *matrix*
            argument to ``__init__``.
        _num_wires: Expected number of wires for this gate.  Subclasses set
            this to enforce wire count validation.  ``None`` means any number
            of wires is accepted.
        _param_names: Tuple of attribute names for the gate parameters.
            Used by :attr:`parameters` and :meth:`__repr__`.
    """

    # Subclasses should set this to the gate's unitary / matrix
    _matrix: jnp.ndarray = None
    is_controlled = False
    _num_wires: Optional[int] = None
    _param_names: Tuple[str, ...] = ()

    def __init__(
        self,
        wires: Union[int, List[int]] = 0,
        matrix: Optional[jnp.ndarray] = None,
        record: bool = True,
        input_idx: int = -1,
        name: Optional[str] = None,
    ) -> None:
        """Initialise the operation and optionally register it on the active tape.

        Args:
            wires: Qubit index or list of qubit indices this operation acts on.
            matrix: Optional explicit gate matrix.  When provided it overrides
                the class-level ``_matrix`` attribute.
            record: If ``True`` (default) and a tape is currently recording,
                append this operation to the tape.  Set to ``False`` for
                auxiliary objects that should not appear in the circuit
                (e.g. Hamiltonians used only to build time-dependent
                evolutions).
            input_idx: Marks the operation as input with the corresponding
                input index, which is useful for the analytical Fourier
                coefficients computation, but has no effect otherwise.
            name: Optional explicit name for this operation.  When ``None``
                (default), the class name is used (e.g. ``"RX"``).

        Raises:
            ValueError: If ``_num_wires`` is set and the number of wires
                doesn't match, or if duplicate wires are provided.
        """
        self.name = name or self.__class__.__name__
        # The ``wires`` setter normalises ints / tuples / lists to a new list
        self.wires = wires
        self.input_idx = input_idx

        if self._num_wires is not None and len(self.wires) != self._num_wires:
            raise ValueError(
                f"{self.name} expects {self._num_wires} wire(s), "
                f"got {len(self.wires)}: {self.wires}"
            )
        if len(self.wires) != len(set(self.wires)):
            raise ValueError(f"{self.name} received duplicate wires: {self.wires}")

        if matrix is not None:
            self._matrix = matrix

        # If a tape is currently recording, append ourselves
        if record:
            tape = active_tape()
            if tape is not None:
                tape.append(self)

    @property
    def parameters(self) -> list:
        """Return the list of numeric parameters for this operation.

        Uses the declarative ``_param_names`` tuple to collect parameter
        values in a canonical order.  Non-parametrized gates return an
        empty list.

        Returns:
            List of the attribute values named in ``_param_names``.
        """
        return [getattr(self, name) for name in self._param_names]

    @staticmethod
    def _format_param(value) -> str:
        """Format a single parameter value for :meth:`__repr__`.

        Floats and scalar arrays are rendered with four decimals.  Values
        that cannot be converted with ``float()`` (non-scalar arrays, JAX
        tracers inside ``jit``/``grad``) fall back to ``str(value)`` so that
        ``repr`` never raises.
        """
        if isinstance(value, (float, np.floating, jnp.ndarray)):
            try:
                return f"{float(value):.4f}"
            except (TypeError, ValueError):
                return str(value)
        return str(value)

    def __repr__(self) -> str:
        """Return a human-readable representation of this operation.

        Returns:
            A string like ``"RX(0.5000, wires=[0])"`` or ``"CX(wires=[0, 1])"``.
        """
        params = self.parameters
        if params:
            param_str = ", ".join(self._format_param(v) for v in params)
            return f"{self.name}({param_str}, wires={self.wires})"
        return f"{self.name}(wires={self.wires})"

    @property
    def matrix(self) -> jnp.ndarray:
        """Return the base matrix of this operation (before lifting).

        Returns:
            The gate matrix stored in ``_matrix``.

        Raises:
            NotImplementedError: If ``_matrix`` is ``None``.
        """
        if self._matrix is None:
            raise NotImplementedError(
                f"{self.__class__.__name__} does not define a matrix."
            )
        return self._matrix

    @property
    def wires(self) -> List[int]:
        """Qubit indices this operation acts on.

        Returns:
            List of qubit indices.
        """
        return self._wires

    @wires.setter
    def wires(self, wires: Union[int, List[int]]) -> None:
        """Set the qubit indices for this operation.

        Args:
            wires: A single qubit index or a list/tuple of qubit indices.
                Lists and tuples are copied into a new list; anything else
                is wrapped in a one-element list.
        """
        if isinstance(wires, (list, tuple)):
            self._wires = list(wires)
        else:
            self._wires = [wires]

    @property
    def input_idx(self) -> int:
        """The index of an input.

        Returns:
            The stored input index.
        """
        return self._input_idx

    @input_idx.setter
    def input_idx(self, input_idx: int) -> None:
        """Setter for the input_idx flag.

        Args:
            input_idx: Index of the input.
        """
        self._input_idx = input_idx

    def _update_tape_operation(self, op: "Operation") -> None:
        """Put ``op`` on the active tape in place of ``self``.

        If ``self`` is the last entry of the active tape (the typical case
        when chaining ``Gate(...).dagger()``), it is replaced by ``op`` so
        that only the derived operation appears on the tape - not both.
        Otherwise ``op`` is appended.  Nothing happens when no tape is
        active.

        Note that this should only be called immediately after the tape has
        been updated.

        Args:
            op (Operation): Operation that should appear on the tape.
        """
        tape = active_tape()
        if tape is not None:
            if tape and tape[-1] is self:
                tape[-1] = op
            else:
                tape.append(op)

    def dagger(self) -> "Operation":
        """Return a new operation, the conjugate transpose (``U\\dagger``).

        Usage inside a circuit function::

            RX(0.5, wires=0).dagger()

        Returns:
            A new :class:`Operation` with matrix ``U\\dagger`` acting on the
            same wires.

        Raises:
            NotImplementedError: If this operation has no matrix.
        """
        # Go through the ``matrix`` property so a missing matrix is reported
        # with the documented error instead of failing inside jnp.
        mat = jnp.conj(self.matrix).T
        op = Operation(wires=self.wires, matrix=mat, record=False)

        self._update_tape_operation(op)

        return op

    def power(self, power: int) -> "Operation":
        """Return a new operation, the matrix power (``U^power``).

        Usage inside a circuit function::

            PauliX(wires=0).power(2)

        Args:
            power: Exponent passed to ``jnp.linalg.matrix_power``.

        Returns:
            A new :class:`Operation` with matrix ``U^power`` acting on the
            same wires.

        Raises:
            NotImplementedError: If this operation has no matrix.
        """
        # TODO: support fractional powers
        mat = jnp.linalg.matrix_power(self.matrix, power)
        op = Operation(wires=self.wires, matrix=mat, record=False)

        self._update_tape_operation(op)

        return op

    def __mul__(self, factor: float) -> "Operation":
        """Return a new operation, the product between U and a scalar (``U*x``).

        Usage inside a circuit function::

            PauliX(wires=0) * x

        Args:
            factor: Scalar the matrix is multiplied with.

        Returns:
            A new :class:`Operation` with matrix ``U*x`` acting on the same
            wires.

        Raises:
            NotImplementedError: If this operation has no matrix.
        """
        mat = factor * self.matrix
        op = Operation(wires=self.wires, matrix=mat, record=False)

        self._update_tape_operation(op)

        return op

    # Also overwrite * for right operands
    __rmul__ = __mul__

    @staticmethod
    def _reorder_matrix(
        matrix: jnp.ndarray, from_wires: List[int], to_wires: List[int]
    ) -> jnp.ndarray:
        """Re-express ``matrix`` for a different ordering of the same wires.

        The matrix is viewed as a rank-2k tensor with k output axes followed
        by k input axes (one pair per wire in ``from_wires`` order); both
        groups of axes are permuted into ``to_wires`` order.

        Args:
            matrix: ``(2**k, 2**k)`` matrix whose tensor factors follow
                ``from_wires``.
            from_wires: Wire order the matrix is currently expressed in.
            to_wires: Desired wire order (a permutation of ``from_wires``).

        Returns:
            The matrix with tensor factors ordered as ``to_wires``.
        """
        if list(from_wires) == list(to_wires):
            return matrix
        k = len(from_wires)
        perm = [from_wires.index(w) for w in to_wires]
        axes = perm + [p + k for p in perm]
        tensor = jnp.asarray(matrix).reshape((2,) * (2 * k))
        return tensor.transpose(axes).reshape(2**k, 2**k)

    def __add__(self, other: "Operation") -> "Operation":
        """Addition of two operations acting on the same set of wires.

        If ``other`` lists the wires in a different order, its matrix is
        first permuted into the wire order of ``self`` so that matching
        tensor factors are added.

        Returns:
            A new :class:`Operation` on ``self.wires`` whose matrix is the
            sum of both matrices.

        Raises:
            ValueError: If the wire sets differ.
        """
        if sorted(self.wires) != sorted(other.wires):
            raise ValueError(
                f"Can only add operations acting on the same set of wires, "
                f"got {self.wires} and {other.wires}"
            )

        other_matrix = self._reorder_matrix(other.matrix, other.wires, self.wires)
        op = Operation(
            wires=self.wires,
            matrix=self.matrix + other_matrix,
            record=False,
        )
        return op

    def __matmul__(self, other: "Operation") -> "Operation":
        """Tensor (Kronecker) product of two operations.

        The resulting operation acts on the union of both wire sets and
        carries the Kronecker product of both matrices.  Wire sets must
        be disjoint.

        Returns:
            A new :class:`Operation` whose matrix is
            ``kron(self.matrix, other.matrix)`` and whose wires are the
            concatenation of both wire lists.

        Raises:
            ValueError: If the two operations share any wires.
        """
        if set(self.wires) & set(other.wires):
            raise ValueError(
                f"Cannot take tensor product: overlapping wires "
                f"{self.wires} and {other.wires}"
            )
        new_matrix = jnp.kron(self.matrix, other.matrix)
        new_wires = self.wires + other.wires
        op = Operation(wires=new_wires, matrix=new_matrix, record=False)
        return op

    def lifted_matrix(self, n_qubits: int) -> jnp.ndarray:
        """Return the full ``2**n x 2**n`` matrix embedding this gate.

        Embeds the ``k``-qubit gate matrix into the ``n``-qubit Hilbert space
        by applying it to every row of the identity matrix via
        :meth:`apply_to_state` and transposing the stacked results.

        Args:
            n_qubits: Total number of qubits in the circuit.

        Returns:
            The ``(2**n, 2**n)`` matrix of this operation in the full space.
        """
        dim = 2**n_qubits
        # vmap stacks the images of the basis vectors as rows; the transpose
        # turns them into the columns of the lifted matrix.
        return jax.vmap(lambda col: self.apply_to_state(col, n_qubits))(
            jnp.eye(dim, dtype=_cdtype())
        ).T

    def apply_to_state(self, state: jnp.ndarray, n_qubits: int) -> jnp.ndarray:
        """Apply this gate to a statevector via tensor contraction.

        The statevector (shape ``(2**n,)``) is reshaped into a rank-n tensor
        of shape ``(2,)*n``.  The gate (shape ``(2**k, 2**k)``) is reshaped to
        ``(2,)*2k`` and contracted against the k target wire axes.

        Args:
            state: Statevector of shape ``(2**n_qubits,)``.
            n_qubits: Total number of qubits in the circuit.

        Returns:
            Updated statevector of shape ``(2**n_qubits,)``.
        """
        k = len(self.wires)
        gate_tensor = self.matrix.reshape((2,) * 2 * k)
        psi = state.reshape((2,) * n_qubits)
        psi_out = _contract_and_restore(psi, gate_tensor, k, self.wires)
        return psi_out.reshape(2**n_qubits)

    def apply_to_state_tensor(self, psi: jnp.ndarray, n_qubits: int) -> jnp.ndarray:
        """Apply this gate to a statevector already in tensor form.

        Like :meth:`apply_to_state` but expects the state in rank-n tensor
        form ``(2,)*n`` and returns the result in the same form, so no
        reshape of the state is performed here.

        Args:
            psi: Statevector tensor of shape ``(2,)*n_qubits``.
            n_qubits: Total number of qubits in the circuit (unused here;
                kept for a signature parallel to :meth:`apply_to_state`).

        Returns:
            Updated statevector tensor of shape ``(2,)*n_qubits``.
        """
        k = len(self.wires)
        gate_tensor = self._gate_tensor(k)
        return _contract_and_restore(psi, gate_tensor, k, self.wires)

    def _gate_tensor(self, k: int) -> jnp.ndarray:
        """Return the gate matrix reshaped to ``(2,)*2k`` tensor form.

        The result is cached on the instance when the instance still uses
        the class-level matrix, so repeated calls skip the reshape.

        Args:
            k: Number of qubits the gate acts on.

        Returns:
            Gate matrix as a rank-2k tensor of shape ``(2,)*2k``.
        """
        cached = getattr(self, "_cached_gate_tensor", None)
        if cached is not None:
            return cached
        gt = self.matrix.reshape((2,) * 2 * k)
        # Only cache for non-parametrized gates (whose matrix is a class attr)
        if self._matrix is self.__class__._matrix:
            object.__setattr__(self, "_cached_gate_tensor", gt)
        return gt

    def apply_to_density(self, rho: jnp.ndarray, n_qubits: int) -> jnp.ndarray:
        """Apply this gate to a density matrix via rho -> U rho U\\dagger.

        The density matrix (shape ``(2**n, 2**n)``) is treated as a rank-*2n*
        tensor with n "ket" axes (0..n-1) and n "bra" axes (n..2n-1).
        U acts on the ket half; the complex conjugate of U acts on the bra
        half.  Both contractions use :func:`_contract_and_restore`.

        Args:
            rho: Density matrix of shape ``(2**n_qubits, 2**n_qubits)``.
            n_qubits: Total number of qubits in the circuit.

        Returns:
            Updated density matrix of shape ``(2**n_qubits, 2**n_qubits)``.
        """
        k = len(self.wires)
        U = self._gate_tensor(k)
        U_conj = jnp.conj(U)

        rho_t = rho.reshape((2,) * 2 * n_qubits)

        # Apply U to the ket axes, conj(U) to the bra axes
        rho_t = _contract_and_restore(rho_t, U, k, self.wires)
        bra_wires = [w + n_qubits for w in self.wires]
        rho_t = _contract_and_restore(rho_t, U_conj, k, bra_wires)

        return rho_t.reshape(2**n_qubits, 2**n_qubits)
class Hermitian(Operation):
    """Observable or gate described by a user-supplied matrix.

    Example:
        >>> obs = Hermitian(matrix=my_matrix, wires=0)
    """

    def __init__(
        self,
        matrix: jnp.ndarray,
        wires: Union[int, List[int]] = 0,
        record: bool = True,
    ) -> None:
        """Create the operator from an explicit matrix.

        Args:
            matrix: Matrix defining this operator; it is converted to a JAX
                array with the active complex dtype.
            wires: Qubit index or list of qubit indices this operator acts on.
            record: If ``True`` (default), record on the active tape.  Set to
                ``False`` when using the Hermitian purely as a Hamiltonian
                component (e.g. for time-dependent evolution).
        """
        as_complex = jnp.asarray(matrix, dtype=_cdtype())
        super().__init__(wires=wires, matrix=as_complex, record=record)

    def __rmul__(self, coeff_fn):
        """Support ``coeff_fn * Hermitian`` -> :class:`ParametrizedHamiltonian`.

        Args:
            coeff_fn: A callable ``(params, t) -> scalar`` giving the
                time-dependent coefficient.

        Returns:
            A :class:`ParametrizedHamiltonian` built from *coeff_fn*, this
            operator's matrix and its wires.

        Raises:
            TypeError: If *coeff_fn* is not callable.
        """
        if callable(coeff_fn):
            return ParametrizedHamiltonian(coeff_fn, self.matrix, self.wires)
        raise TypeError(
            f"Left operand of `* Hermitian` must be callable, got {type(coeff_fn)}"
        )
class ParametrizedHamiltonian:
    """A time-dependent Hamiltonian ``H(t) = f(params, t) * H_mat``.

    Instances are produced by multiplying a callable coefficient function
    with a :class:`Hermitian` operator::

        def coeff(p, t):
            return p[0] * jnp.exp(-0.5 * ((t - t_c) / p[1]) ** 2)

        H_td = coeff * Hermitian(matrix=sigma_x, wires=0)

    The Hamiltonian is then used with :func:`evolve`::

        evolve(H_td)(coeff_args=[A, sigma], T=1.0)

    Attributes:
        coeff_fn: Callable ``(params, t) -> scalar``.
        H_mat: Static Hermitian matrix (JAX array).
        wires: Qubit wire(s) this Hamiltonian acts on.
    """

    def __init__(
        self,
        coeff_fn: Callable,
        H_mat: jnp.ndarray,
        wires: Union[int, List[int]],
    ) -> None:
        """Store the coefficient function, the matrix and the wires as given."""
        self.coeff_fn, self.H_mat, self.wires = coeff_fn, H_mat, wires
class Id(Operation):
    """Identity gate on one or more wires.

    A single wire uses the class-level ``2 x 2`` identity.  For *k* > 1
    wires the instance matrix is the ``2**k x 2**k`` identity.
    """

    _matrix = jnp.eye(2, dtype=_cdtype())
    _num_wires = None  # accept any number of wires

    def __init__(self, wires: Union[int, List[int]] = 0, **kwargs) -> None:
        """Initialise an identity gate.

        Args:
            wires: Qubit index or list of qubit indices this gate acts on.
                When multiple wires are given the matrix is automatically
                expanded to the matching ``2**k x 2**k`` identity.
            **kwargs: Forwarded to :class:`Operation`; a ``matrix`` entry is
                overwritten when more than one wire is given.
        """
        n_wires = len(wires) if isinstance(wires, (list, tuple)) else 1
        if n_wires > 1:
            kwargs["matrix"] = jnp.eye(2**n_wires, dtype=_cdtype())
        super().__init__(wires=wires, **kwargs)
class PauliX(Operation):
    """Pauli-X gate / observable (bit-flip, \\sigma_x)."""

    _matrix = jnp.array(
        [
            [0, 1],
            [1, 0],
        ],
        dtype=_cdtype(),
    )
    _num_wires = 1

    def __init__(self, wires: Union[int, List[int]] = 0, **kwargs) -> None:
        """Create a Pauli-X operation.

        Args:
            wires: Qubit index (or one-element list) this gate acts on.
            **kwargs: Forwarded unchanged to :class:`Operation`.
        """
        super().__init__(wires=wires, **kwargs)
class PauliY(Operation):
    """Pauli-Y gate / observable (\\sigma_y)."""

    _matrix = jnp.array(
        [
            [0, -1j],
            [1j, 0],
        ],
        dtype=_cdtype(),
    )
    _num_wires = 1

    def __init__(self, wires: Union[int, List[int]] = 0, **kwargs) -> None:
        """Create a Pauli-Y operation.

        Args:
            wires: Qubit index (or one-element list) this gate acts on.
            **kwargs: Forwarded unchanged to :class:`Operation`.
        """
        super().__init__(wires=wires, **kwargs)
class PauliZ(Operation):
    """Pauli-Z gate / observable (phase-flip, \\sigma_z)."""

    _matrix = jnp.array(
        [
            [1, 0],
            [0, -1],
        ],
        dtype=_cdtype(),
    )
    _num_wires = 1

    def __init__(self, wires: Union[int, List[int]] = 0, **kwargs) -> None:
        """Create a Pauli-Z operation.

        Args:
            wires: Qubit index (or one-element list) this gate acts on.
            **kwargs: Forwarded unchanged to :class:`Operation`.
        """
        super().__init__(wires=wires, **kwargs)
class H(Operation):
    """Hadamard gate: ``[[1, 1], [1, -1]] / sqrt(2)``."""

    _matrix = (
        jnp.array(
            [
                [1, 1],
                [1, -1],
            ],
            dtype=_cdtype(),
        )
        / jnp.sqrt(2)
    )
    _num_wires = 1

    def __init__(self, wires: Union[int, List[int]] = 0, **kwargs) -> None:
        """Create a Hadamard operation.

        Args:
            wires: Qubit index (or one-element list) this gate acts on.
            **kwargs: Forwarded unchanged to :class:`Operation`.
        """
        super().__init__(wires=wires, **kwargs)
class S(Operation):
    """S (phase) gate with matrix ``diag(1, i)``.

    .. math::
        S = \\begin{pmatrix}1 & 0\\\\ 0 & i\\end{pmatrix}
    """

    _matrix = jnp.array([[1, 0], [0, 1j]], dtype=_cdtype())
    _num_wires = 1

    def __init__(self, wires: Union[int, List[int]] = 0, **kwargs) -> None:
        """Initialise an S gate.

        Args:
            wires: Qubit index or list of qubit indices this gate acts on.
            **kwargs: Forwarded to :class:`Operation` (e.g. ``record``,
                ``name``), matching the other fixed single-qubit gates.
        """
        super().__init__(wires=wires, **kwargs)
class SWAP(Operation):
    """SWAP gate: exchanges the states of its two wires."""

    _matrix = jnp.array(
        [
            [1, 0, 0, 0],
            [0, 0, 1, 0],
            [0, 1, 0, 0],
            [0, 0, 0, 1],
        ],
        dtype=_cdtype(),
    )
    _num_wires = 2

    def __init__(self, wires: Union[int, List[int]] = 0, **kwargs) -> None:
        """Create a SWAP operation.

        Args:
            wires: The two qubit indices this gate acts on.  The default
                ``0`` is a single wire and is rejected by the wire-count
                check of :class:`Operation`.
            **kwargs: Forwarded unchanged to :class:`Operation`.
        """
        super().__init__(wires=wires, **kwargs)
class RandomUnitary(Operation):
    """Creates a random Hermitian matrix and stores it as the gate matrix.

    NOTE(review): despite the class name, the matrix built here is the
    Hermitian part ``(A + A\\dagger) / 2`` of a complex Gaussian matrix,
    rescaled to Frobenius norm ``scale``; it is not unitary in general.
    Presumably it is meant as a generator for a later evolution - confirm
    against the callers.
    """

    def __init__(self, wires, key, scale=1.0, record=True):
        """Initialise the random operation.

        Args:
            wires: Qubit index or list of qubit indices this gate acts on.
            key: ``jax.random`` PRNG key used to draw the matrix entries.
            scale: Frobenius norm of the resulting matrix (default: 1.0).
            record: Forwarded to :class:`Operation`; whether to record the
                operation on the active tape.
        """
        # Accept a bare qubit index just like the Operation base class does
        n_wires = len(wires) if isinstance(wires, (list, tuple)) else 1
        dim = 2**n_wires
        key_re, key_im = jax.random.split(key)

        gaussian = (
            jax.random.normal(key=key_re, shape=(dim, dim))
            + 1j * jax.random.normal(key=key_im, shape=(dim, dim))
        ).astype(_cdtype())
        hermitian = (gaussian + gaussian.conj().T) / 2.0

        hermitian = hermitian * (scale / jnp.linalg.norm(hermitian, ord="fro"))

        super().__init__(wires, matrix=hermitian, record=record)
class Barrier(Operation):
    """Barrier operation - a no-op used for visual circuit separation.

    The barrier does not change the quantum state.  It is recorded on the
    tape so that drawing backends can insert a visual separator.
    """

    _matrix = None  # not a real gate

    def __init__(self, wires: Union[int, List[int]] = 0, **kwargs) -> None:
        """Initialise a Barrier.

        Args:
            wires: Qubit index or list of qubit indices this barrier spans.
            **kwargs: Forwarded to :class:`Operation` (e.g. ``record``,
                ``name``), matching the other operations in this module.
        """
        super().__init__(wires=wires, **kwargs)

    def apply_to_state(self, state: jnp.ndarray, n_qubits: int) -> jnp.ndarray:
        """No-op: return the state unchanged."""
        return state

    def apply_to_state_tensor(self, psi: jnp.ndarray, n_qubits: int) -> jnp.ndarray:
        """No-op: return the state tensor unchanged."""
        return psi

    def apply_to_density(self, rho: jnp.ndarray, n_qubits: int) -> jnp.ndarray:
        """No-op: return the density matrix unchanged."""
        return rho
def _make_rotation_gate(pauli_class: type, name: str) -> type:
    """Build one of the single-qubit rotation gate classes RX, RY, RZ.

    The generated gate uses the matrix
    ``cos(theta/2) * I - i * sin(theta/2) * P`` where ``P`` is the matrix of
    ``pauli_class``.

    Args:
        pauli_class: One of PauliX, PauliY, PauliZ.
        name: Class name for the generated gate (e.g. ``"RX"``).

    Returns:
        A new :class:`Operation` subclass named ``name``.
    """
    generator_matrix = pauli_class._matrix

    class _RotationGate(Operation):
        # The docstring is assembled at class creation so it names the axis
        __doc__ = (
            f"Rotation around the {name[1]} axis: {name}(\\theta) =\n"
            f"exp(-i \\theta/2 {name[1]}).\n"
        )
        _num_wires = 1
        _param_names = ("theta",)

        def __init__(
            self, theta: float, wires: Union[int, List[int]] = 0, **kwargs
        ) -> None:
            self.theta = theta
            cos_half = jnp.cos(theta / 2)
            sin_half = jnp.sin(theta / 2)
            rotation = cos_half * Id._matrix - 1j * sin_half * generator_matrix
            super().__init__(wires=wires, matrix=rotation, **kwargs)

        def generator(self) -> Operation:
            """Return the generator as the corresponding Pauli operation."""
            target = self.wires[0]
            return pauli_class(wires=target, record=False)

    # Give the generated class its public name
    _RotationGate.__qualname__ = name
    _RotationGate.__name__ = name
    return _RotationGate
# Concrete single-qubit rotation gates, one per Pauli axis
RX, RY, RZ = (
    _make_rotation_gate(pauli, label)
    for pauli, label in ((PauliX, "RX"), (PauliY, "RY"), (PauliZ, "RZ"))
)
# Single-qubit projectors |0><0| and |1><1|, used by the controlled-gate
# factories to assemble block matrices
_P0 = jnp.array(
    [
        [1, 0],
        [0, 0],
    ],
    dtype=_cdtype(),
)
_P1 = jnp.array(
    [
        [0, 0],
        [0, 1],
    ],
    dtype=_cdtype(),
)
def _make_controlled_gate(target_class: type, name: str) -> type:
    """Factory for controlled Pauli gates CX, CY, CZ.

    Each gate has the form
    ``CP = |0><0| \\otimes I + |1><1| \\otimes P``.

    Args:
        target_class: The single-qubit gate class (PauliX, PauliY, PauliZ).
        name: Class name for the generated gate (e.g. ``"CX"``).

    Returns:
        A new :class:`Operation` subclass named ``name``.
    """
    target_mat = target_class._matrix

    class _ControlledGate(Operation):
        __doc__ = (
            f"Controlled-{target_class.__name__[5:]} gate.\n\n"
            f"Applies {target_class.__name__} on the target qubit conditioned "
            f"on the control qubit being in state |1\\rangle."
        )
        _matrix = jnp.kron(_P0, Id._matrix) + jnp.kron(_P1, target_mat)
        _num_wires = 2
        is_controlled = True

        # An immutable tuple default avoids a shared mutable default argument;
        # Operation turns it into a fresh list.
        def __init__(
            self, wires: Union[List[int], Tuple[int, ...]] = (0, 1), **kwargs
        ) -> None:
            super().__init__(wires=wires, **kwargs)

    _ControlledGate.__name__ = name
    _ControlledGate.__qualname__ = name
    return _ControlledGate
# Concrete controlled Pauli gates, one per Pauli target
CX, CY, CZ = (
    _make_controlled_gate(pauli, label)
    for pauli, label in ((PauliX, "CX"), (PauliY, "CY"), (PauliZ, "CZ"))
)
class CCX(Operation):
    """Toffoli (CCX) gate.

    The matrix is the 8 x 8 identity with the last two basis states
    exchanged, i.e. the target is flipped only when both controls are 1.
    Being a 3-qubit gate it exercises the arbitrary-k-qubit path in
    :meth:`~Operation.apply_to_state`.
    """

    _matrix = jnp.array(
        [
            [1, 0, 0, 0, 0, 0, 0, 0],
            [0, 1, 0, 0, 0, 0, 0, 0],
            [0, 0, 1, 0, 0, 0, 0, 0],
            [0, 0, 0, 1, 0, 0, 0, 0],
            [0, 0, 0, 0, 1, 0, 0, 0],
            [0, 0, 0, 0, 0, 1, 0, 0],
            [0, 0, 0, 0, 0, 0, 0, 1],
            [0, 0, 0, 0, 0, 0, 1, 0],
        ],
        dtype=_cdtype(),
    )
    is_controlled = True
    _num_wires = 3

    # An immutable tuple default avoids a shared mutable default argument;
    # Operation turns it into a fresh list.
    def __init__(
        self, wires: Union[List[int], Tuple[int, ...]] = (0, 1, 2), **kwargs
    ) -> None:
        """Initialise a Toffoli (CCX) gate.

        Args:
            wires: Three wires ``[control0, control1, target]``.
            **kwargs: Forwarded unchanged to :class:`Operation`.
        """
        super().__init__(wires=wires, **kwargs)
class CSWAP(Operation):
    """Controlled-SWAP (Fredkin) gate.

    Swaps the two target qubits conditioned on the control qubit being
    |1\\rangle: the matrix is the 8 x 8 identity with basis states 5 and 6
    exchanged.

    Args on construction:
        wires: ``[control, target0, target1]``.
    """

    _matrix = jnp.array(
        [
            [1, 0, 0, 0, 0, 0, 0, 0],
            [0, 1, 0, 0, 0, 0, 0, 0],
            [0, 0, 1, 0, 0, 0, 0, 0],
            [0, 0, 0, 1, 0, 0, 0, 0],
            [0, 0, 0, 0, 1, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 1, 0],
            [0, 0, 0, 0, 0, 1, 0, 0],
            [0, 0, 0, 0, 0, 0, 0, 1],
        ],
        dtype=_cdtype(),
    )
    is_controlled = True
    _num_wires = 3

    # An immutable tuple default avoids a shared mutable default argument;
    # Operation turns it into a fresh list.
    def __init__(
        self, wires: Union[List[int], Tuple[int, ...]] = (0, 1, 2), **kwargs
    ) -> None:
        """Initialise a Controlled-SWAP (Fredkin) gate.

        Args:
            wires: Three wires ``[control, target0, target1]``.
            **kwargs: Forwarded unchanged to :class:`Operation`.
        """
        super().__init__(wires=wires, **kwargs)
def _make_controlled_rotation_gate(pauli_class: type, name: str) -> type:
    """Factory for controlled rotation gates CRX, CRY, CRZ.

    Each gate has the form
    ``CR_P(\\theta) = |0><0| \\otimes I + |1><1| \\otimes R_P(\\theta)``
    with ``R_P(\\theta) = cos(\\theta/2) I - i sin(\\theta/2) P``.

    Args:
        pauli_class: One of PauliX, PauliY, PauliZ.
        name: Class name for the generated gate (e.g. ``"CRX"``).

    Returns:
        A new :class:`Operation` subclass named ``name``.
    """
    pauli_mat = pauli_class._matrix

    class _CRotationGate(Operation):
        __doc__ = (
            f"Controlled rotation around the {name[2]} axis.\n\n"
            f"Applies R{name[2]}(\\theta) on the target qubit conditioned on the "
            f"control qubit being in state |1\\rangle.\n\n"
            f".. math::\n"
            f"{name}(\\theta) = |0\\rangle\\langle 0| \\otimes I\n"
            f" + |1\\rangle\\langle 1| \\otimes R{name[2]}(\\theta)"
        )
        _num_wires = 2
        _param_names = ("theta",)
        is_controlled = True

        # An immutable tuple default avoids a shared mutable default argument;
        # Operation turns it into a fresh list.
        def __init__(
            self,
            theta: float,
            wires: Union[List[int], Tuple[int, ...]] = (0, 1),
            **kwargs,
        ) -> None:
            self.theta = theta
            c = jnp.cos(theta / 2)
            s = jnp.sin(theta / 2)
            # Single-qubit rotation placed in the |1><1| block of the control
            rot = c * Id._matrix - 1j * s * pauli_mat
            mat = jnp.kron(_P0, Id._matrix) + jnp.kron(_P1, rot)
            super().__init__(wires=wires, matrix=mat, **kwargs)

    _CRotationGate.__name__ = name
    _CRotationGate.__qualname__ = name
    return _CRotationGate
# Concrete controlled rotation gates, one per Pauli axis
CRX, CRY, CRZ = (
    _make_controlled_rotation_gate(pauli, label)
    for pauli, label in ((PauliX, "CRX"), (PauliY, "CRY"), (PauliZ, "CRZ"))
)
class Rot(Operation):
    """General single-qubit rotation built from three axis rotations.

    ``Rot(phi, theta, omega) = RZ(omega) @ RY(theta) @ RZ(phi)``, i.e. the
    ``phi`` rotation is applied first and the ``omega`` rotation last.
    """

    _num_wires = 1
    _param_names = ("phi", "theta", "omega")

    def __init__(
        self,
        phi: float,
        theta: float,
        omega: float,
        wires: Union[int, List[int]] = 0,
        **kwargs,
    ) -> None:
        """Initialise a general rotation gate.

        Args:
            phi: First RZ rotation angle (radians).
            theta: RY rotation angle (radians).
            omega: Second RZ rotation angle (radians).
            wires: Qubit index or list of qubit indices this gate acts on.
            **kwargs: Forwarded unchanged to :class:`Operation`.
        """
        self.phi, self.theta, self.omega = phi, theta, omega

        def _axis_rotation(angle, pauli):
            # cos(angle/2) I - i sin(angle/2) P for a single Pauli matrix P
            return jnp.cos(angle / 2) * Id._matrix - 1j * jnp.sin(angle / 2) * pauli

        # Matrix product reads right-to-left: RZ(phi), then RY(theta), then RZ(omega)
        combined = (
            _axis_rotation(omega, PauliZ._matrix)
            @ _axis_rotation(theta, PauliY._matrix)
            @ _axis_rotation(phi, PauliZ._matrix)
        )
        super().__init__(wires=wires, matrix=combined, **kwargs)
class PauliRot(Operation):
    """Multi-qubit Pauli rotation: exp(-i theta/2 P) for a Pauli word P.

    The Pauli word is given as a string of ``'I'``, ``'X'``, ``'Y'``, ``'Z'``
    characters (one per qubit).  The rotation matrix is computed as
    ``cos(theta/2) I - i sin(theta/2) P`` where *P* is the tensor product of
    the corresponding single-qubit Pauli matrices.

    Example::

        PauliRot(0.5, "XY", wires=[0, 1])
    """

    _param_names = ("theta",)

    # Map from character to 2x2 matrix
    _PAULI_MAP = {
        "I": Id._matrix,
        "X": PauliX._matrix,
        "Y": PauliY._matrix,
        "Z": PauliZ._matrix,
    }

    def __init__(
        self, theta: float, pauli_word: str, wires: Union[int, List[int]] = 0, **kwargs
    ) -> None:
        """Initialise a PauliRot gate.

        Args:
            theta: Rotation angle in radians.
            pauli_word: A string of ``'I'``, ``'X'``, ``'Y'``, ``'Z'``
                characters specifying the Pauli tensor product.
            wires: Qubit index or list of qubit indices this gate acts on.
            **kwargs: Forwarded unchanged to :class:`Operation`.

        Raises:
            ValueError: If the length of ``pauli_word`` differs from the
                number of wires.
            KeyError: If ``pauli_word`` contains a character other than
                ``I``, ``X``, ``Y``, ``Z``.
        """
        # A word/wire mismatch would otherwise only surface later as a
        # reshape error when the gate is applied.
        n_wires = len(wires) if isinstance(wires, (list, tuple)) else 1
        if len(pauli_word) != n_wires:
            raise ValueError(
                f"Pauli word '{pauli_word}' has {len(pauli_word)} character(s) "
                f"but {n_wires} wire(s) were given: {wires}"
            )

        self.theta = theta
        self.pauli_word = pauli_word

        P = self._pauli_product(pauli_word)
        dim = P.shape[0]
        mat = (
            jnp.cos(theta / 2) * jnp.eye(dim, dtype=_cdtype())
            - 1j * jnp.sin(theta / 2) * P
        )
        super().__init__(wires=wires, matrix=mat, **kwargs)

    @classmethod
    def _pauli_product(cls, pauli_word: str) -> jnp.ndarray:
        """Return the Kronecker product of the Pauli matrices named in the word."""
        from functools import reduce as _reduce

        return _reduce(jnp.kron, [cls._PAULI_MAP[c] for c in pauli_word])

    def generator(self) -> Operation:
        """Return the generator Pauli tensor product as an :class:`Operation`.

        The generator of ``PauliRot(theta, word, wires)`` is the tensor
        product of single-qubit Pauli matrices specified by *word*.  The
        returned :class:`Hermitian` wraps that matrix and the gate's wires
        and is not recorded on the tape.

        Returns:
            :class:`Hermitian` operation representing the Pauli tensor product.
        """
        P = self._pauli_product(self.pauli_word)
        return Hermitian(matrix=P, wires=self.wires, record=False)
class KrausChannel(Operation):
    """Base class for noise channels described by a set of Kraus operators.

    A Kraus channel maps a density matrix as
    ``Phi(rho) = sum_k K_k rho K_k^dagger``.  A pure unitary gate is the
    special case of a single operator ``K_0 = U`` with
    ``K_0^dagger K_0 = I``; noisy channels use several operators.

    Subclasses override :meth:`kraus_matrices` to return a list of JAX
    arrays.  The statevector entry points (:meth:`apply_to_state`,
    :meth:`apply_to_state_tensor`) and :attr:`matrix` always raise
    ``TypeError``; only :meth:`apply_to_density` is usable.
    """

    def kraus_matrices(self) -> List[jnp.ndarray]:
        """Return the list of Kraus operators for this channel.

        Returns:
            List of 2-D JAX arrays, each of shape ``(2**k, 2**k)`` where k
            is the number of target qubits.

        Raises:
            NotImplementedError: Always, in this base class; subclasses must
                override the method.
        """
        raise NotImplementedError

    @property
    def matrix(self) -> jnp.ndarray:
        """Always raise: a noise channel has no single unitary matrix.

        Raises:
            TypeError: Always raised; use :meth:`apply_to_density` instead.
        """
        message = (
            f"{self.__class__.__name__} is a noise channel and has no single "
            "unitary matrix. Use apply_to_density() instead."
        )
        raise TypeError(message)

    def apply_to_state(self, state: jnp.ndarray, n_qubits: int) -> jnp.ndarray:
        """Always raise: noise channels need density-matrix simulation.

        Args:
            state: Statevector (unused).
            n_qubits: Number of qubits (unused).

        Raises:
            TypeError: Always raised; use ``execute(type='density')`` instead.
        """
        message = (
            f"{self.__class__.__name__} is a noise channel and cannot be "
            "applied to a pure statevector. Use execute(type='density') instead."
        )
        raise TypeError(message)

    def apply_to_state_tensor(self, psi: jnp.ndarray, n_qubits: int) -> jnp.ndarray:
        """Always raise: noise channels need density-matrix simulation.

        Args:
            psi: Statevector in tensor form (unused).
            n_qubits: Number of qubits (unused).

        Raises:
            TypeError: Always raised; use ``execute(type='density')`` instead.
        """
        message = (
            f"{self.__class__.__name__} is a noise channel and cannot be "
            "applied to a pure statevector. Use execute(type='density') instead."
        )
        raise TypeError(message)

    def apply_to_density(self, rho: jnp.ndarray, n_qubits: int) -> jnp.ndarray:
        """Apply ``rho -> sum_k K_k rho K_k^dagger`` by tensor contraction.

        For every Kraus operator the density matrix is viewed as a rank
        ``2 * n_qubits`` tensor.  The operator is contracted onto the axes
        listed in ``self.wires`` and its complex conjugate onto the same
        axes shifted by ``n_qubits``, both through
        :func:`_contract_and_restore`.  The per-operator results are summed.

        Args:
            rho: Density matrix of shape ``(2**n_qubits, 2**n_qubits)``.
            n_qubits: Total number of qubits in the circuit.

        Returns:
            Updated density matrix of shape ``(2**n_qubits, 2**n_qubits)``.
        """
        n_targets = len(self.wires)
        side = 2**n_qubits
        gate_shape = (2,) * (2 * n_targets)
        state_shape = (2,) * (2 * n_qubits)
        # The second half of the tensor axes mirrors the first half.
        shifted_wires = [n_qubits + wire for wire in self.wires]

        total = jnp.zeros_like(rho)
        for kraus_op in self.kraus_matrices():
            gate = kraus_op.reshape(gate_shape)
            gate_conj = jnp.conj(gate)
            term = _contract_and_restore(
                rho.reshape(state_shape), gate, n_targets, self.wires
            )
            term = _contract_and_restore(term, gate_conj, n_targets, shifted_wires)
            total = total + term.reshape(side, side)

        return total
class BitFlip(KrausChannel):
    r"""Single-qubit bit-flip (Pauli-X) error channel.

    .. math::
        K_0 = \sqrt{1-p}\,I, \quad K_1 = \sqrt{p}\,X

    where *p* in [0, 1] is the probability of a bit flip.
    """

    _num_wires = 1
    _param_names = ("p",)

    def __init__(self, p: float, wires: Union[int, List[int]] = 0) -> None:
        """Validate *p*, store it and register the channel on *wires*.

        Args:
            p: Bit-flip probability, must be in [0, 1].
            wires: Qubit index or list of qubit indices this channel acts on.

        Raises:
            ValueError: If *p* is outside [0, 1].
        """
        # Chained comparison on purpose: it is also False for NaN.
        in_range = 0.0 <= p <= 1.0
        if not in_range:
            raise ValueError("p must be in [0, 1].")
        self.p = p
        super().__init__(wires=wires)

    def kraus_matrices(self) -> List[jnp.ndarray]:
        """Return the two Kraus operators of the bit-flip channel.

        Returns:
            List ``[K0, K1]`` with ``K0 = sqrt(1 - p) * I`` and
            ``K1 = sqrt(p) * X``.
        """
        prob = self.p
        no_flip = jnp.sqrt(1 - prob) * Id._matrix
        flip = jnp.sqrt(prob) * PauliX._matrix
        return [no_flip, flip]
class PhaseFlip(KrausChannel):
    r"""Single-qubit phase-flip (Pauli-Z) error channel.

    .. math::
        K_0 = \sqrt{1-p}\,I, \quad K_1 = \sqrt{p}\,Z

    where *p* in [0, 1] is the probability of a phase flip.
    """

    _num_wires = 1
    _param_names = ("p",)

    def __init__(self, p: float, wires: Union[int, List[int]] = 0) -> None:
        """Validate *p*, store it and register the channel on *wires*.

        Args:
            p: Phase-flip probability, must be in [0, 1].
            wires: Qubit index or list of qubit indices this channel acts on.

        Raises:
            ValueError: If *p* is outside [0, 1].
        """
        # Chained comparison on purpose: it is also False for NaN.
        in_range = 0.0 <= p <= 1.0
        if not in_range:
            raise ValueError("p must be in [0, 1].")
        self.p = p
        super().__init__(wires=wires)

    def kraus_matrices(self) -> List[jnp.ndarray]:
        """Return the two Kraus operators of the phase-flip channel.

        Returns:
            List ``[K0, K1]`` with ``K0 = sqrt(1 - p) * I`` and
            ``K1 = sqrt(p) * Z``.
        """
        prob = self.p
        no_flip = jnp.sqrt(1 - prob) * Id._matrix
        flip = jnp.sqrt(prob) * PauliZ._matrix
        return [no_flip, flip]
class DepolarizingChannel(KrausChannel):
    r"""Single-qubit depolarizing channel.

    .. math::
        K_0 = \sqrt{1-p}\,I,\quad K_1 = \sqrt{p/3}\,X,\quad
        K_2 = \sqrt{p/3}\,Y,\quad K_3 = \sqrt{p/3}\,Z

    where *p* in [0, 1].  At p = 3/4 the channel is fully depolarizing.
    """

    _num_wires = 1
    _param_names = ("p",)

    def __init__(self, p: float, wires: Union[int, List[int]] = 0) -> None:
        """Validate *p*, store it and register the channel on *wires*.

        Args:
            p: Depolarization probability, must be in [0, 1].
            wires: Qubit index or list of qubit indices this channel acts on.

        Raises:
            ValueError: If *p* is outside [0, 1].
        """
        # Chained comparison on purpose: it is also False for NaN.
        in_range = 0.0 <= p <= 1.0
        if not in_range:
            raise ValueError("p must be in [0, 1].")
        self.p = p
        super().__init__(wires=wires)

    def kraus_matrices(self) -> List[jnp.ndarray]:
        """Return the four Kraus operators of the depolarizing channel.

        Returns:
            List ``[K0, K1, K2, K3]`` holding the I, X, Y and Z components,
            in that order.
        """
        prob = self.p
        keep_amp = jnp.sqrt(1 - prob)
        # The three Pauli errors share the same weight sqrt(p / 3).
        error_amp = jnp.sqrt(prob / 3)
        return [
            keep_amp * Id._matrix,
            error_amp * PauliX._matrix,
            error_amp * PauliY._matrix,
            error_amp * PauliZ._matrix,
        ]
class AmplitudeDamping(KrausChannel):
    r"""Single-qubit amplitude damping channel.

    .. math::
        K_0 = \begin{pmatrix}1 & 0\\ 0 & \sqrt{1-\gamma}\end{pmatrix},\quad
        K_1 = \begin{pmatrix}0 & \sqrt{\gamma}\\ 0 & 0\end{pmatrix}

    where *gamma* in [0, 1] is the probability of energy loss
    (|1> -> |0>).
    """

    _num_wires = 1
    _param_names = ("gamma",)

    def __init__(self, gamma: float, wires: Union[int, List[int]] = 0) -> None:
        """Validate *gamma*, store it and register the channel on *wires*.

        Args:
            gamma: Energy-loss probability, must be in [0, 1].
            wires: Qubit index or list of qubit indices this channel acts on.

        Raises:
            ValueError: If *gamma* is outside [0, 1].
        """
        # Chained comparison on purpose: it is also False for NaN.
        in_range = 0.0 <= gamma <= 1.0
        if not in_range:
            raise ValueError("gamma must be in [0, 1].")
        self.gamma = gamma
        super().__init__(wires=wires)

    def kraus_matrices(self) -> List[jnp.ndarray]:
        """Return the two Kraus operators of the amplitude damping channel.

        Returns:
            List ``[K0, K1]`` as defined in the class docstring.
        """
        gamma = self.gamma
        dtype = _cdtype()
        survive = jnp.sqrt(1 - gamma)
        decay = jnp.sqrt(gamma)
        no_decay_op = jnp.array([[1.0, 0.0], [0.0, survive]], dtype=dtype)
        decay_op = jnp.array([[0.0, decay], [0.0, 0.0]], dtype=dtype)
        return [no_decay_op, decay_op]
class PhaseDamping(KrausChannel):
    r"""Single-qubit phase damping (dephasing) channel.

    .. math::
        K_0 = \begin{pmatrix}1 & 0\\ 0 & \sqrt{1-\gamma}\end{pmatrix},\quad
        K_1 = \begin{pmatrix}0 & 0\\ 0 & \sqrt{\gamma}\end{pmatrix}

    where *gamma* in [0, 1] is the phase damping probability.
    """

    _num_wires = 1
    _param_names = ("gamma",)

    def __init__(self, gamma: float, wires: Union[int, List[int]] = 0) -> None:
        """Validate *gamma*, store it and register the channel on *wires*.

        Args:
            gamma: Phase-damping probability, must be in [0, 1].
            wires: Qubit index or list of qubit indices this channel acts on.

        Raises:
            ValueError: If *gamma* is outside [0, 1].
        """
        # Chained comparison on purpose: it is also False for NaN.
        in_range = 0.0 <= gamma <= 1.0
        if not in_range:
            raise ValueError("gamma must be in [0, 1].")
        self.gamma = gamma
        super().__init__(wires=wires)

    def kraus_matrices(self) -> List[jnp.ndarray]:
        """Return the two Kraus operators of the phase damping channel.

        Returns:
            List ``[K0, K1]`` as defined in the class docstring.
        """
        gamma = self.gamma
        dtype = _cdtype()
        survive = jnp.sqrt(1 - gamma)
        damp = jnp.sqrt(gamma)
        undamped_op = jnp.array([[1.0, 0.0], [0.0, survive]], dtype=dtype)
        damped_op = jnp.array([[0.0, 0.0], [0.0, damp]], dtype=dtype)
        return [undamped_op, damped_op]
class ThermalRelaxationError(KrausChannel):
    """Single-qubit thermal relaxation error channel.

    Models simultaneous T_1 energy relaxation and T_2 dephasing.  Two
    regimes are handled:

    T_2 <= T_1:
        Six Kraus operators built from p_z (phase-flip probability), p_r0
        (reset-to-|0> probability) and p_r1 (reset-to-|1> probability).

    T_2 > T_1:
        A 4x4 Choi matrix is assembled from the relaxation/dephasing
        factors and diagonalised; the Kraus operators are
        ``K_i = sqrt(|lambda_i|) * mat(v_i)``.

    Attributes:
        pe: Excited-state population (thermal population of |1>).
        t1: T_1 longitudinal relaxation time.
        t2: T_2 transverse dephasing time.
        tg: Gate duration.
    """

    _num_wires = 1
    _param_names = ("pe", "t1", "t2", "tg")

    def __init__(
        self,
        pe: float,
        t1: float,
        t2: float,
        tg: float,
        wires: Union[int, List[int]] = 0,
    ) -> None:
        """Validate the parameters, store them and register the channel.

        Args:
            pe: Excited-state population (thermal population of |1>),
                in [0, 1].
            t1: T_1 longitudinal relaxation time, must be > 0.
            t2: T_2 transverse dephasing time, must be > 0 and <= 2 * T_1.
            tg: Gate duration, must be >= 0.
            wires: Qubit index or list of qubit indices this channel acts on.

        Raises:
            ValueError: If any parameter violates the stated constraints.
        """
        # Checks run in a fixed order so the first violated one is reported.
        pe_in_range = 0.0 <= pe <= 1.0
        if not pe_in_range:
            raise ValueError("pe must be in [0, 1].")
        if t1 <= 0:
            raise ValueError("t1 must be > 0.")
        if t2 <= 0:
            raise ValueError("t2 must be > 0.")
        if t2 > 2 * t1:
            raise ValueError("t2 must be <= 2·t1.")
        if tg < 0:
            raise ValueError("tg must be >= 0.")

        self.pe, self.t1, self.t2, self.tg = pe, t1, t2, tg
        super().__init__(wires=wires)

    def kraus_matrices(self) -> List[jnp.ndarray]:
        """Return the Kraus operators for the thermal relaxation channel.

        The number of operators depends on the regime:

        * T_2 <= T_1: six operators (identity, phase-flip, two
          reset-to-|0>, two reset-to-|1>).
        * T_2 > T_1: four operators from the eigendecomposition of the
          Choi matrix.

        Returns:
            List of 2x2 JAX arrays representing the Kraus operators.
        """
        pe, t1, t2, tg = self.pe, self.t1, self.t2, self.tg
        dtype = _cdtype()

        eT1 = jnp.exp(-tg / t1)
        p_reset = 1.0 - eT1
        eT2 = jnp.exp(-tg / t2)

        if t2 <= t1:
            # Regime T_2 <= T_1: weighted identity, Z and reset operators.
            pz = (1.0 - p_reset) * (1.0 - eT2 / eT1) / 2.0
            pr0 = (1.0 - pe) * p_reset
            pr1 = pe * p_reset
            pid = 1.0 - pz - pr0 - pr1

            def weighted(prob, entries):
                return jnp.sqrt(prob) * jnp.array(entries, dtype=dtype)

            return [
                jnp.sqrt(pid) * jnp.eye(2, dtype=dtype),
                weighted(pz, [[1, 0], [0, -1]]),
                weighted(pr0, [[1, 0], [0, 0]]),
                weighted(pr0, [[0, 1], [0, 0]]),
                weighted(pr1, [[0, 0], [1, 0]]),
                weighted(pr1, [[0, 0], [0, 1]]),
            ]

        # Regime T_2 > T_1: Choi matrix decomposition.
        # NOTE(review): the layout is said to follow PennyLane's column-major
        # reshaping convention - confirm against the PennyLane implementation.
        choi = jnp.array(
            [
                [1 - pe * p_reset, 0, 0, eT2],
                [0, pe * p_reset, 0, 0],
                [0, 0, (1 - pe) * p_reset, 0],
                [eT2, 0, 0, 1 - (1 - pe) * p_reset],
            ],
            dtype=dtype,
        )
        eigenvalues, eigenvectors = jnp.linalg.eigh(choi)

        # Column i of ``eigenvectors``, reshaped column-major to 2x2 and
        # scaled by sqrt(|lambda_i|), is one Kraus operator.
        return [
            (
                jnp.sqrt(jnp.abs(eigenvalues[i]))
                * eigenvectors[:, i].reshape(2, 2, order="F")
            ).astype(dtype)
            for i in range(4)
        ]
class QubitChannel(KrausChannel):
    """Generic Kraus channel built from user-supplied Kraus operators.

    This replaces PennyLane's ``qml.QubitChannel``.  The operators are
    expected to satisfy ``sum_k K_k^dagger K_k = I``; this class stores them
    as given and does not check that condition.

    Example::

        kraus_ops = [jnp.sqrt(0.9) * jnp.eye(2), jnp.sqrt(0.1) * PauliX._matrix]
        QubitChannel(kraus_ops, wires=0)
    """

    def __init__(
        self, kraus_ops: List[jnp.ndarray], wires: Union[int, List[int]] = 0
    ) -> None:
        """Convert the operators to the active complex dtype and store them.

        Args:
            kraus_ops: List of Kraus matrices.  Each must be a square 2D array
                of dimension ``2**k x 2**k`` where k = ``len(wires)``.
            wires: Qubit index or list of qubit indices this channel acts on.
        """
        dtype = _cdtype()
        self._kraus_ops = [jnp.asarray(op, dtype=dtype) for op in kraus_ops]
        super().__init__(wires=wires)

    def kraus_matrices(self) -> List[jnp.ndarray]:
        """Return the stored Kraus operators.

        Returns:
            The list of Kraus operator matrices built in ``__init__`` (the
            same list object, not a copy).
        """
        return self._kraus_ops
# Single-qubit Pauli lookup tables.  The three lists are index-aligned
# (0 -> I, 1 -> X, 2 -> Y, 3 -> Z); the matrices are the plain ``_matrix``
# arrays of the gate classes, so no Operation objects are created.
_PAULI_CLASSES = [Id, PauliX, PauliY, PauliZ]
_PAULI_LABELS = ["I", "X", "Y", "Z"]
_PAULI_MATS = [pauli_cls._matrix for pauli_cls in _PAULI_CLASSES]
def evolve_pauli_with_clifford(
    clifford: Operation,
    pauli: Operation,
    adjoint_left: bool = True,
) -> Operation:
    """Conjugate ``pauli`` by ``clifford`` and wrap the result.

    Both matrices are first embedded, via :func:`_embed_matrix`, into the
    space spanned by the sorted union of the two wire sets.  The product
    ``C^dagger P C`` (or ``C P C^dagger``) is then returned as a
    :class:`Hermitian` created with ``record=False``.

    Args:
        clifford: A Clifford gate.
        pauli: A Pauli / Hermitian operator.
        adjoint_left: If ``True``, compute ``C^dagger P C``; otherwise
            ``C P C^dagger``.

    Returns:
        A :class:`Hermitian` wrapping the evolved matrix on the joint wires.
    """
    joint_wires = sorted(set(clifford.wires).union(pauli.wires))
    n_joint = len(joint_wires)

    cliff_full = _embed_matrix(clifford.matrix, clifford.wires, joint_wires, n_joint)
    pauli_full = _embed_matrix(pauli.matrix, pauli.wires, joint_wires, n_joint)
    cliff_dagger = jnp.conj(cliff_full).T

    # Choose on which side of P the adjoint sits.
    if adjoint_left:
        left, right = cliff_dagger, cliff_full
    else:
        left, right = cliff_full, cliff_dagger

    return Hermitian(matrix=left @ pauli_full @ right, wires=joint_wires, record=False)
def _embed_matrix(
    mat: jnp.ndarray,
    op_wires: list,
    all_wires: list,
    n_total: int,
) -> jnp.ndarray:
    """Embed a gate matrix into the space spanned by ``all_wires``.

    If the gate already acts on exactly ``all_wires`` in the same order,
    ``mat`` is returned unchanged.  Otherwise ``mat`` is Kronecker-multiplied
    on the right by one 2x2 identity per wire of ``all_wires`` that is not in
    ``op_wires``, and the qubit axes are reordered with
    :func:`_permute_matrix` so that they follow ``all_wires``.

    Args:
        mat: The gate matrix of shape ``(2**k, 2**k)`` where
            ``k = len(op_wires)``.
        op_wires: The wires the gate acts on.
        all_wires: The full ordered list of wires.
        n_total: ``len(all_wires)``.

    Returns:
        A ``(2**n_total, 2**n_total)`` matrix.
    """
    gate_order = list(op_wires)
    target_order = list(all_wires)
    if len(op_wires) == n_total and gate_order == target_order:
        return mat

    # Pad with an identity for every wire the gate does not touch.
    untouched = [wire for wire in all_wires if wire not in op_wires]
    identity = jnp.eye(2, dtype=_cdtype())
    expanded = mat
    for _ in untouched:
        expanded = jnp.kron(expanded, identity)

    # After padding, the qubit order is the gate's wires then the padding.
    layout = gate_order + untouched
    if layout != target_order:
        axes = [layout.index(wire) for wire in all_wires]
        expanded = _permute_matrix(expanded, axes, n_total)

    return expanded
def _permute_matrix(mat: jnp.ndarray, perm: list, n_qubits: int) -> jnp.ndarray:
    """Reorder the qubits of a ``(2**n, 2**n)`` matrix.

    The matrix is viewed as a tensor with ``n_qubits`` row axes followed by
    ``n_qubits`` column axes, and the same axis permutation is applied to
    both groups with ``jnp.transpose``.  Qubit position ``i`` of the result
    therefore holds what was qubit ``perm[i]`` of the input.

    Args:
        mat: Square matrix of dimension ``2**n_qubits``.
        perm: Permutation of ``[0, ..., n_qubits - 1]`` as a list.
        n_qubits: Number of qubits.

    Returns:
        Permuted matrix of the same shape.
    """
    side = 2**n_qubits
    as_tensor = mat.reshape([2] * (2 * n_qubits))
    # Column axes sit n_qubits places after their matching row axes.
    column_axes = [n_qubits + axis for axis in perm]
    reordered = jnp.transpose(as_tensor, perm + column_axes)
    return reordered.reshape(side, side)
def pauli_decompose(matrix: jnp.ndarray, wire_order: Optional[List[int]] = None):
    r"""Return the dominant Pauli term of a Hermitian matrix.

    For an n-qubit matrix (``2**n x 2**n``) every Pauli tensor product *P*
    is scored with the trace formula ``c_P = Tr(P · M) / 2**n`` and the term
    with the largest absolute coefficient is returned, wrapped as an
    :class:`Operation`.  This is sufficient for the Fourier-tree algorithm,
    which only needs the single non-zero Pauli term produced by Clifford
    conjugation of a Pauli operator.

    If no coefficient has absolute value greater than zero (for example an
    all-zero matrix), the identity term is returned with coefficient ``0.0``,
    for any number of qubits.

    Args:
        matrix: A ``(2**n, 2**n)`` Hermitian matrix.
        wire_order: Optional list of wire indices.  If ``None``, defaults
            to ``[0, 1, ..., n-1]``.

    Returns:
        A tuple ``(coeff, op)`` where *coeff* is the complex coefficient and
        *op* is the Pauli :class:`Operation`.  A term with at most one
        non-identity factor is returned as a single-wire ``PauliX`` /
        ``PauliY`` / ``PauliZ`` / ``Id``; any other term is returned as a
        :class:`Hermitian` on all of ``wire_order``.  In every case the
        operation carries a ``_pauli_label`` attribute.
    """
    from itertools import product as _product
    from functools import reduce as _reduce

    dim = matrix.shape[0]
    n_qubits = int(jnp.round(jnp.log2(dim)))

    if wire_order is None:
        wire_order = list(range(n_qubits))

    # Single-qubit fast path: only the four Pauli matrices to score.
    if n_qubits == 1:
        best_idx, best_coeff = 0, 0.0
        for idx, P in enumerate(_PAULI_MATS):
            coeff = jnp.trace(P @ matrix) / 2.0
            if jnp.abs(coeff) > jnp.abs(best_coeff):
                best_idx = idx
                best_coeff = coeff
        op_cls = _PAULI_CLASSES[best_idx]
        result_op = op_cls(wires=wire_order[0], record=False)
        result_op._pauli_label = _PAULI_LABELS[best_idx]
        return best_coeff, result_op

    # Multi-qubit: score every Pauli tensor product.  Start from the
    # all-identity label (mirroring the single-qubit default index 0) so a
    # matrix with no non-zero coefficient falls back to the identity instead
    # of leaving the label unset.
    best_label = (0,) * n_qubits
    best_coeff = 0.0
    for indices in _product(range(4), repeat=n_qubits):
        P = _reduce(jnp.kron, [_PAULI_MATS[i] for i in indices])
        coeff = jnp.trace(P @ matrix) / dim
        if jnp.abs(coeff) > jnp.abs(best_coeff):
            best_coeff = coeff
            best_label = indices

    # Positions (qubit, Pauli index) of the non-identity factors.
    non_identity = [(q, idx) for q, idx in enumerate(best_label) if idx != 0]

    if not non_identity:
        # All identity: report it on the first wire, labelled per qubit.
        result_op = Id(wires=wire_order[0], record=False)
        result_op._pauli_label = "I" * n_qubits
        return best_coeff, result_op

    if len(non_identity) == 1:
        # Exactly one non-trivial factor: a plain single-wire Pauli.
        q, idx = non_identity[0]
        result_op = _PAULI_CLASSES[idx](wires=wire_order[q], record=False)
        result_op._pauli_label = _PAULI_LABELS[idx]
        return best_coeff, result_op

    # Genuine multi-qubit tensor product -> Hermitian with the full label.
    P = _reduce(jnp.kron, [_PAULI_MATS[i] for i in best_label])
    result_op = Hermitian(matrix=P, wires=wire_order, record=False)
    result_op._pauli_label = "".join(_PAULI_LABELS[i] for i in best_label)
    return best_coeff, result_op
def pauli_string_from_operation(op: Operation) -> str:
    """Return the Pauli word describing ``op``.

    Sources are tried in this order:

    1. ``op.pauli_word`` when ``op`` is a :class:`PauliRot`.
    2. ``op._pauli_label`` when that attribute exists (it is attached by
       :func:`pauli_decompose`).
    3. A fixed mapping from ``op.name``: ``PauliX`` -> ``"X"``,
       ``PauliY`` -> ``"Y"``, ``PauliZ`` -> ``"Z"``, ``I`` -> ``"I"``.
    4. Otherwise, the label of the dominant term found by running
       :func:`pauli_decompose` on ``op.matrix``.

    Args:
        op: A quantum operation.

    Returns:
        A string like ``"X"``, ``"ZZ"``, etc.
    """
    if isinstance(op, PauliRot) and hasattr(op, "pauli_word"):
        return op.pauli_word

    # Label stored earlier by pauli_decompose.
    if hasattr(op, "_pauli_label"):
        return op._pauli_label

    by_name = {"PauliX": "X", "PauliY": "Y", "PauliZ": "Z", "I": "I"}
    label = by_name.get(op.name)
    if label is not None:
        return label

    # Last resort: decompose the matrix and use the dominant term's label.
    dominant_op = pauli_decompose(op.matrix, wire_order=op.wires)[1]
    return dominant_op._pauli_label