Coverage for qml_essentials / operations.py: 88%
607 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-05-16 10:19 +0000
1from typing import Callable, List, Optional, Tuple, Union
2from functools import lru_cache
3import string
4import numpy as np
6import jax
7import jax.numpy as jnp
9from qml_essentials.tape import active_tape, recording # noqa: F401 (re-export)
12def _cdtype():
13 """Return the active JAX complex dtype
14 (complex128 if x64 enabled, else complex64).
15 """
16 return jnp.complex128 if jax.config.x64_enabled else jnp.complex64
19@lru_cache(maxsize=256)
20def _einsum_subscript(
21 n: int,
22 k: int,
23 target_axes: Tuple[int, ...],
24) -> str:
25 """Build an ``einsum`` subscript that fuses contraction + axis restore.
27 Args:
28 n: Total rank of the state tensor (number of qubits for statevectors,
29 ``2 * n_qubits`` for density matrices).
30 k: Number of qubits the gate acts on.
31 target_axes: Tuple of k axis indices in the state tensor that the
32 gate contracts against.
34 Returns:
35 ``einsum`` subscript string, e.g. ``"ab,cBd->cad"`` for a 1-qubit
36 gate on wire 1 of a 3-qubit state.
37 """
38 letters = string.ascii_letters
39 # State indices: one letter per axis
40 state_idx = list(letters[:n])
41 # Contracted indices (the ones being replaced by the gate)
42 contracted = [state_idx[ax] for ax in target_axes]
43 # Gate indices: new output indices + contracted input indices
44 new_out = [letters[n + i] for i in range(k)] # fresh letters for output
45 gate_idx = new_out + contracted # gate shape: (out0, out1, ..., in0, in1, ...)
46 # Result indices: replace target axes with new output letters
47 result_idx = list(state_idx)
48 for i, ax in enumerate(target_axes):
49 result_idx[ax] = new_out[i]
50 return "".join(gate_idx) + "," + "".join(state_idx) + "->" + "".join(result_idx)
def _contract_and_restore(
    tensor: jnp.ndarray,
    gate: jnp.ndarray,
    k: int,
    target_axes: List[int],
) -> jnp.ndarray:
    """Apply a reshaped gate to selected axes of a tensor, keeping axis order.

    The subscript comes from the ``lru_cache``-decorated
    :func:`_einsum_subscript`. It is therefore built only once per distinct
    ``(tensor rank, k, target_axes)`` combination.

    Args:
        tensor: Rank-N tensor (``(2,)*n`` for statevectors, ``(2,)*2n`` for
            density matrices).
        gate: Gate in tensor form, shape ``(2,)*2k``.
        k: Number of qubits the gate acts on (equals ``len(target_axes)``).
        target_axes: Axes of *tensor* that the gate contracts against.

    Returns:
        A tensor of the same rank as *tensor*, with the gate's output axes
        placed where the contracted axes used to be.
    """
    # lru_cache needs a hashable key, so the axis list becomes a tuple.
    axes_key = tuple(target_axes)
    spec = _einsum_subscript(tensor.ndim, k, axes_key)
    return jnp.einsum(spec, gate, tensor)
class Operation:
    """Base class for any quantum operation or observable.

    Further gates should inherit from this class to realise more specific
    operations. Generally, operations are created by instantiation inside a
    circuit function passed to :class:`Script`; the instance is
    automatically appended to the active tape.

    An ``Operation`` can also serve as an *observable*: its matrix is used to
    compute expectation values via ``apply_to_state`` / ``apply_to_density``.

    Matrix convention: for ``wires = [w_0, ..., w_{k-1}]`` the gate matrix is
    reshaped to ``(2,) * 2k``. Output axis ``i`` and input axis ``k + i`` are
    contracted against the state axis of wire ``w_i``, so ``w_0`` is the
    most significant tensor factor.

    Attributes:
        _matrix: Class-level default gate matrix. Subclasses set this to their
            fixed unitary. Instances may override it via the *matrix* argument
            to ``__init__``.
        _num_wires: Expected number of wires for this gate. Subclasses set
            this to enforce wire count validation. ``None`` means any number
            of wires is accepted.
        _param_names: Tuple of attribute names for the gate parameters.
            Used by :attr:`parameters` and :meth:`__repr__`.
        is_controlled: Class-level flag, set to ``True`` by controlled-gate
            subclasses.
    """

    # Subclasses should set this to the gate's unitary / matrix
    _matrix: Optional[jnp.ndarray] = None
    is_controlled = False
    _num_wires: Optional[int] = None
    _param_names: Tuple[str, ...] = ()

    def __init__(
        self,
        wires: Union[int, List[int]] = 0,
        matrix: Optional[jnp.ndarray] = None,
        record: bool = True,
        input_idx: int = -1,
        name: Optional[str] = None,
    ) -> None:
        """Initialise the operation and optionally register it on the active tape.

        Args:
            wires: Qubit index or list of qubit indices this operation acts on.
            matrix: Optional explicit gate matrix. When provided it overrides
                the class-level ``_matrix`` attribute.
            record: If ``True`` (default) and a tape is currently recording,
                append this operation to the tape. Set to ``False`` for
                auxiliary objects that should not appear in the circuit
                (e.g. Hamiltonians used only to build time-dependent
                evolutions).
            input_idx: Marks the operation as input with the corresponding
                input index, which is useful for the analytical Fourier
                coefficients computation, but has no effect otherwise.
            name: Optional explicit name for this operation. When ``None``
                (default), the class name is used (e.g. ``"RX"``).

        Raises:
            ValueError: If ``_num_wires`` is set and the number of wires
                doesn't match, or if duplicate wires are provided.
        """
        self.name = name or self.__class__.__name__
        # The ``wires`` property setter normalises a scalar to a 1-element list.
        self.wires = wires
        self.input_idx = input_idx

        if self._num_wires is not None and len(self.wires) != self._num_wires:
            raise ValueError(
                f"{self.name} expects {self._num_wires} wire(s), "
                f"got {len(self.wires)}: {self.wires}"
            )
        if len(self.wires) != len(set(self.wires)):
            raise ValueError(f"{self.name} received duplicate wires: {self.wires}")

        if matrix is not None:
            # Instance attribute shadows the class-level default matrix.
            self._matrix = matrix

        # If a tape is currently recording, append ourselves
        if record:
            tape = active_tape()
            if tape is not None:
                tape.append(self)

    @property
    def parameters(self) -> list:
        """Return the list of numeric parameters for this operation.

        Uses the declarative ``_param_names`` tuple to collect parameter
        values in a canonical order. Non-parametrized gates return an
        empty list.

        Returns:
            List of parameter values (floats or JAX arrays).
        """
        return [getattr(self, name) for name in self._param_names]

    def __repr__(self) -> str:
        """Return a human-readable representation of this operation.

        Returns:
            A string like ``"RX(0.5000, wires=[0])"`` or ``"CX(wires=[0, 1])"``.
        """
        params = self.parameters
        if params:
            param_str = ", ".join(
                (
                    f"{float(v):.4f}"
                    if isinstance(v, (float, np.floating, jnp.ndarray))
                    else str(v)
                )
                for v in params
            )
            return f"{self.name}({param_str}, wires={self.wires})"
        return f"{self.name}(wires={self.wires})"

    @property
    def matrix(self) -> jnp.ndarray:
        """Return the base matrix of this operation (before lifting).

        Returns:
            The gate matrix as a JAX array.

        Raises:
            NotImplementedError: If the subclass has not defined ``_matrix``.
        """
        if self._matrix is None:
            raise NotImplementedError(
                f"{self.__class__.__name__} does not define a matrix."
            )
        return self._matrix

    @property
    def wires(self) -> List[int]:
        """Qubit indices this operation acts on.

        Returns:
            List of integer qubit indices.
        """
        return self._wires

    @wires.setter
    def wires(self, wires: Union[int, List[int]]) -> None:
        """Set the qubit indices for this operation.

        Args:
            wires: A single qubit index or a list/tuple of qubit indices.
        """
        if isinstance(wires, (list, tuple)):
            self._wires = list(wires)
        else:
            self._wires = [wires]

    @property
    def input_idx(self) -> int:
        """The index of an input

        Returns:
            input_idx: Index of the input
        """
        return self._input_idx

    @input_idx.setter
    def input_idx(self, input_idx: int) -> None:
        """Setter for the input_idx flag

        Args:
            input_idx: Index of the input
        """
        self._input_idx = input_idx

    def _update_tape_operation(self, op: "Operation") -> None:
        """Put *op* on the active tape in place of ``self`` where possible.

        Used by :meth:`dagger`, :meth:`power` and scalar :meth:`__mul__`.
        If ``self`` is the last entry on the active tape (the typical case
        when chaining, e.g. ``Gate(...).dagger()``), it is replaced by *op*.
        Only the derived operation then appears on the tape, not both.
        Otherwise *op* is appended. Nothing happens when no tape is recording.

        Note that this should only be called immediately after ``self`` was
        created, so that ``self`` is still the last tape entry.

        Args:
            op (Operation): The derived operation to place on the tape.
        """
        tape = active_tape()
        if tape is None:
            return
        if tape and tape[-1] is self:
            tape[-1] = op
        else:
            tape.append(op)

    def dagger(self) -> "Operation":
        """Return a new operation, the conjugate transpose (``U\\dagger``)
        Usage inside a circuit function::

            RX(0.5, wires=0).dagger()

        Returns:
            A new :class:`Operation` with matrix ``U\\dagger`` acting on the same wires.

        Raises:
            NotImplementedError: If this operation has no matrix.
        """
        # ``self.matrix`` (not ``_matrix``) gives a clear error for matrix-less ops.
        mat = jnp.conj(self.matrix).T
        op = Operation(wires=self.wires, matrix=mat, record=False)

        self._update_tape_operation(op)

        return op

    def power(self, power) -> "Operation":
        """Return a new operation, the power (``U^power``)
        Usage inside a circuit function::

            PauliX(wires=0).power(2)

        Args:
            power: Integer exponent, forwarded to ``jnp.linalg.matrix_power``.

        Returns:
            A new :class:`Operation` with matrix ``U^power`` acting on the same wires.

        Raises:
            NotImplementedError: If this operation has no matrix.
        """
        # TODO: support fractional powers
        mat = jnp.linalg.matrix_power(self.matrix, power)
        op = Operation(wires=self.wires, matrix=mat, record=False)

        self._update_tape_operation(op)

        return op

    def __mul__(self, other: Union[float, "Operation"]) -> "Operation":
        """Return a new operation, the product between U and a scalar (``U*x``)
        or the composition of two operations.
        Usage inside a circuit function::

            PauliX(wires=0) * x
            PauliX(wires=0) * PauliZ(wires=0)

        Returns:
            A new :class:`Operation` with matrix ``U*x`` acting on the same wires,
            or the composed matrix acting on the appropriate wires.
        """
        if isinstance(other, Operation):
            return self.__matmul__(other)

        mat = other * self.matrix
        op = Operation(wires=self.wires, matrix=mat, record=False)

        self._update_tape_operation(op)

        return op

    # Also overwrite * for right operands
    __rmul__ = __mul__

    def __add__(self, other: "Operation") -> "Operation":
        """Element-wise addition of two operations on the same set of wires.

        The wires may be listed in a different order. The other operand's
        matrix is then permuted into ``self.wires`` order before adding, so
        e.g. ``CX([0, 1]) + CX([1, 0])`` sums the two distinct operators.

        Returns:
            A new :class:`Operation` on ``self.wires`` whose matrix is the sum
            of both matrices, or ``NotImplemented`` if *other* is not an
            :class:`Operation`.

        Raises:
            ValueError: If the wire sets differ.
        """
        if not isinstance(other, Operation):
            return NotImplemented
        if sorted(self.wires) != sorted(other.wires):
            raise ValueError(
                f"Can only add operations acting on the same set of wires, "
                f"got {self.wires} and {other.wires}"
            )

        other_mat = other.matrix
        if other.wires != self.wires:
            # Same wires, different order: reorder the other's tensor axes so
            # axis i (and k + i) refers to self.wires[i] before summing.
            k = len(self.wires)
            perm = [other.wires.index(w) for w in self.wires]
            other_mat = jnp.transpose(
                other_mat.reshape((2,) * 2 * k), perm + [p + k for p in perm]
            ).reshape(2**k, 2**k)

        return Operation(
            wires=self.wires,
            matrix=self.matrix + other_mat,
            record=False,
        )

    def prod(self, *ops: "Operation") -> "Operation":
        """Construct the generalized product (tensor or matrix)
        of this operation with others.

        The resulting operation acts on the union of all wire sets.
        If the wire sets are disjoint, this is a Kronecker product.
        If the wire sets overlap, the corresponding matrices are multiplied.

        Usage::

            res = op1.prod(op2, op3)
            # or
            res = Operation.prod(op1, op2, op3)

        Args:
            *ops: Variable number of :class:`Operation` instances.

        Returns:
            A new :class:`Operation` representing the composed operation.
        """
        if not ops:
            return self

        all_ops = (self,) + ops
        # Union of wires, in first-seen order.
        all_wires = []
        for op in all_ops:
            for w in op.wires:
                if w not in all_wires:
                    all_wires.append(w)

        n = len(all_wires)

        mat = _embed_matrix(all_ops[0].matrix, all_ops[0].wires, all_wires, n)
        for op in all_ops[1:]:
            mat_other = _embed_matrix(op.matrix, op.wires, all_wires, n)
            mat = mat @ mat_other

        op_names = "*".join(op.name for op in all_ops)
        return Operation(
            wires=all_wires, matrix=mat, name=f"Prod({op_names})", record=False
        )

    def __matmul__(self, other: "Operation") -> "Operation":
        """Tensor (Kronecker) product or matrix product of two operations.

        The resulting operation acts on the union of both wire sets.
        If the wire sets are disjoint, this is a Kronecker product.
        If the wire sets overlap, the corresponding matrices are multiplied.

        Returns:
            A new :class:`Operation` whose matrix represents the composed
            operation on the unified wire set.
        """
        if not isinstance(other, Operation):
            return NotImplemented

        return self.prod(other)

    def lifted_matrix(self, n_qubits: int) -> jnp.ndarray:
        """Return the full ``2**n x 2**n`` matrix embedding this gate.

        Embeds the ``k``-qubit gate matrix into the ``n``-qubit Hilbert space
        by applying it to the identity matrix via :meth:`apply_to_state`.
        This is useful for computing ``Tr(O·\\rho )`` directly without vmap.

        Args:
            n_qubits: Total number of qubits in the circuit.

        Returns:
            The ``(2**n, 2**n)`` matrix of this operation in the full space.
        """
        dim = 2**n_qubits
        # Row i of the vmap output is U e_i; transposing makes it column i.
        return jax.vmap(lambda col: self.apply_to_state(col, n_qubits))(
            jnp.eye(dim, dtype=_cdtype())
        ).T

    def apply_to_state(self, state: jnp.ndarray, n_qubits: int) -> jnp.ndarray:
        """Apply this gate to a statevector via tensor contraction.

        The statevector (shape ``(2**n,)``) is reshaped into a rank-n tensor
        of shape ``(2,)*n``. The gate (shape ``(2**k, 2**k)``) is reshaped to
        ``(2,)*2k`` and contracted against the k target wire axes.

        Memory footprint is O(2**n) and the operation supports arbitrary k.
        The implementation is fully differentiable through JAX.

        Args:
            state: Statevector of shape ``(2**n_qubits,)``.
            n_qubits: Total number of qubits in the circuit.

        Returns:
            Updated statevector of shape ``(2**n_qubits,)``.
        """
        k = len(self.wires)
        gate_tensor = self._gate_tensor(k)
        psi = state.reshape((2,) * n_qubits)
        psi_out = _contract_and_restore(psi, gate_tensor, k, self.wires)
        return psi_out.reshape(2**n_qubits)

    def apply_to_state_tensor(self, psi: jnp.ndarray, n_qubits: int) -> jnp.ndarray:
        """Apply this gate to a statevector already in tensor form.

        Like :meth:`apply_to_state` but expects the state in rank-n tensor
        form ``(2,)*n`` and returns the result in the same form. This avoids
        the ``reshape`` calls at the per-gate level when the simulation loop
        keeps the state in tensor form throughout.

        Args:
            psi: Statevector tensor of shape ``(2,)*n_qubits``.
            n_qubits: Total number of qubits in the circuit.

        Returns:
            Updated statevector tensor of shape ``(2,)*n_qubits``.
        """
        k = len(self.wires)
        gate_tensor = self._gate_tensor(k)
        return _contract_and_restore(psi, gate_tensor, k, self.wires)

    def _gate_tensor(self, k: int) -> jnp.ndarray:
        """Return the gate matrix reshaped to ``(2,)*2k`` tensor form.

        The result is cached on the instance so repeated calls (e.g. from
        density-matrix simulation which applies U and U*) avoid redundant
        reshape dispatch.

        Args:
            k: Number of qubits the gate acts on.

        Returns:
            Gate matrix as a rank-2k tensor of shape ``(2,)*2k``.
        """
        cached = getattr(self, "_cached_gate_tensor", None)
        if cached is not None:
            return cached
        gt = self.matrix.reshape((2,) * 2 * k)
        # Only cache for non-parametrized gates (whose matrix is a class attr)
        if self._matrix is self.__class__._matrix:
            object.__setattr__(self, "_cached_gate_tensor", gt)
        return gt

    def apply_to_density(self, rho: jnp.ndarray, n_qubits: int) -> jnp.ndarray:
        """Apply this gate to a density matrix via \\rho -> U\\rho U\\dagger.

        The density matrix (shape ``(2**n, 2**n)``) is treated as a rank-*2n*
        tensor with n "ket" axes (0..n-1) and n "bra" axes (n..2n-1).
        U acts on the ket half; U* acts on the bra half. Both contractions
        use the shared :func:`_contract_and_restore` helper, so no full
        ``2**n x 2**n`` unitary is ever built.

        Args:
            rho: Density matrix of shape ``(2**n_qubits, 2**n_qubits)``.
            n_qubits: Total number of qubits in the circuit.

        Returns:
            Updated density matrix of shape ``(2**n_qubits, 2**n_qubits)``.
        """
        k = len(self.wires)
        U = self._gate_tensor(k)
        U_conj = jnp.conj(U)

        rho_t = rho.reshape((2,) * 2 * n_qubits)

        # Apply U to ket axes, U\\dagger to bra axes
        rho_t = _contract_and_restore(rho_t, U, k, self.wires)
        bra_wires = [w + n_qubits for w in self.wires]
        rho_t = _contract_and_restore(rho_t, U_conj, k, bra_wires)

        return rho_t.reshape(2**n_qubits, 2**n_qubits)
class Hermitian(Operation):
    """A generic Hermitian observable or gate defined by an arbitrary matrix.

    The matrix is converted to the active complex dtype on construction; no
    Hermiticity check is performed.

    Example:
        >>> obs = Hermitian(matrix=my_matrix, wires=0)
    """

    def __init__(
        self,
        matrix: jnp.ndarray,
        wires: Union[int, List[int]] = 0,
        record: bool = True,
    ) -> None:
        """Initialise a Hermitian operator.

        Args:
            matrix: The Hermitian matrix defining this operator.
            wires: Qubit index or list of qubit indices this operator acts on.
            record: If ``True`` (default), record on the active tape. Set to
                ``False`` when using the Hermitian purely as a Hamiltonian
                component (e.g. for time-dependent evolution).
        """
        super().__init__(
            wires=wires,
            matrix=jnp.asarray(matrix, dtype=_cdtype()),
            record=record,
        )

    def __rmul__(
        self, coeff_fn: Callable
    ) -> Union["ParametrizedHamiltonian", "Operation"]:
        """Support ``coeff_fn * Hermitian`` and ``scalar * Hermitian``.

        A callable left operand builds a time-dependent Hamiltonian term. A
        numeric scalar behaves exactly like ``Hermitian * scalar`` (the
        inherited :meth:`Operation.__mul__`). This override would otherwise
        reject plain numbers.

        Args:
            coeff_fn (Callable): A callable ``(params, t) -> scalar`` giving the
                time-dependent coefficient, or a numeric scalar.

        Returns:
            A :class:`ParametrizedHamiltonian` pairing *coeff_fn* with this
            operator's matrix and wires when *coeff_fn* is callable; otherwise
            the scaled :class:`Operation`.

        Raises:
            TypeError: If *coeff_fn* is neither callable nor a numeric scalar.
        """
        if callable(coeff_fn):
            return ParametrizedHamiltonian(terms=[(coeff_fn, self.matrix, self.wires)])
        if isinstance(coeff_fn, (int, float, complex, np.number, jnp.ndarray)):
            # Same result (and tape handling) as ``self * coeff_fn``.
            return Operation.__mul__(self, coeff_fn)
        raise TypeError(
            f"Left operand of `* Hermitian` must be callable, got {type(coeff_fn)}"
        )
class ParametrizedHamiltonian:
    """A time-dependent Hamiltonian as a sum of ``coeff * Hermitian`` terms.

    Mathematically::

        H(t) = \\sum_i f_i(params_i, t) * H_i

    Construction is always done from an explicit list of
    ``(coeff_fn, H_mat, wires)`` triples passed as ``terms``. The
    common single-term shorthand is the operator form
    ``coeff_fn * Hermitian(matrix, wires)`` (see
    :meth:`Hermitian.__rmul__`), which returns a one-term instance.
    Multi-term Hamiltonians are composed with ``+`` between
    :class:`ParametrizedHamiltonian` instances::

        H1 = coeff_x * Hermitian(X, wires=0)
        H2 = coeff_y * Hermitian(Y, wires=0)
        H_td = H1 + H2

        # evolve under the composite Hamiltonian; coeff_args is a list of
        # parameter sets, one per term, in the order the terms were added:
        evolve(H_td)([px, py], T=1.0)

    Attributes:
        coeff_fns: Tuple of callables ``(params, t) -> scalar``, one per term.
        H_mats: Tuple of static Hermitian matrices, one per term.
        wires: Wires this Hamiltonian acts on (union across all terms; for
            now all terms are required to share the same wire set).
    """

    def __init__(
        self,
        terms: List[Tuple[Callable, jnp.ndarray, Union[int, List[int]]]],
    ) -> None:
        """Build a (possibly multi-term) parametrized Hamiltonian.

        Args:
            terms: List of ``(coeff_fn, H_mat, wires)`` triples. Use the
                ``coeff_fn * Hermitian(...)`` shorthand to build a
                one-term instance; combine instances with ``+`` to add
                terms. ``wires`` may be a single integer (Python or NumPy)
                or an iterable of integers.

        Raises:
            ValueError: If the term list is empty, or if terms act on
                differing wire sets (multi-wire broadcasting is
                deferred — see :mod:`yaqsi`), or if term matrices have
                incompatible shapes.
        """
        if len(terms) == 0:
            raise ValueError("ParametrizedHamiltonian needs at least one term.")

        def _wlist(w):
            # Any integer scalar (incl. NumPy integers such as np.int64) is a
            # single wire; everything else is treated as an iterable of wires.
            return [w] if isinstance(w, (int, np.integer)) else list(w)

        # Validate that every term acts on the same wires (in the same order).
        first_wires = _wlist(terms[0][2])
        for _, _, w in terms[1:]:
            if _wlist(w) != first_wires:
                raise ValueError(
                    "All terms of a ParametrizedHamiltonian must currently "
                    "act on the same wires; got "
                    f"{_wlist(w)} vs. {first_wires}. "
                    "Multi-wire broadcasting across terms is not yet supported."
                )

        # Validate matrix shape compatibility across terms.
        first_dim = jnp.asarray(terms[0][1]).shape
        for _, H, _ in terms[1:]:
            shape = jnp.asarray(H).shape
            if shape != first_dim:
                raise ValueError(
                    f"All term matrices must have the same shape; got "
                    f"{shape} vs. {first_dim}."
                )

        self._terms: Tuple[Tuple[Callable, jnp.ndarray, List[int]], ...] = tuple(
            (fn, jnp.asarray(H, dtype=_cdtype()), _wlist(w)) for fn, H, w in terms
        )
        self.wires: List[int] = list(first_wires)

    # --- term accessors -------------------------------------------------

    @property
    def coeff_fns(self) -> Tuple[Callable, ...]:
        """Tuple of coefficient functions, one per term."""
        return tuple(fn for fn, _, _ in self._terms)

    @property
    def H_mats(self) -> Tuple[jnp.ndarray, ...]:
        """Tuple of Hermitian matrices, one per term."""
        return tuple(H for _, H, _ in self._terms)

    @property
    def n_terms(self) -> int:
        """Number of terms in the Hamiltonian."""
        return len(self._terms)

    # --- composition ---------------------------------------------------

    def __add__(self, other: "ParametrizedHamiltonian") -> "ParametrizedHamiltonian":
        """Concatenate term lists: ``H = H1 + H2``."""
        if not isinstance(other, ParametrizedHamiltonian):
            return NotImplemented
        return ParametrizedHamiltonian(terms=list(self._terms) + list(other._terms))

    def __neg__(self) -> "ParametrizedHamiltonian":
        """Negate every coefficient: ``-H`` = sum of ``(-f_i) * H_i``."""
        # The outer lambda binds each ``fn`` eagerly, avoiding the
        # late-binding closure pitfall inside the comprehension.
        new_terms = [
            ((lambda f: lambda p, t: -f(p, t))(fn), H, w) for fn, H, w in self._terms
        ]
        return ParametrizedHamiltonian(terms=new_terms)

    def __sub__(self, other: "ParametrizedHamiltonian") -> "ParametrizedHamiltonian":
        """Subtract term-wise: ``H1 - H2`` is ``H1 + (-H2)``."""
        if not isinstance(other, ParametrizedHamiltonian):
            return NotImplemented
        return self + (-other)
class Id(Operation):
    """Identity operation on one or more wires.

    The class-level matrix is the single-qubit identity. For *k* > 1 wires,
    the constructor substitutes the ``2**k x 2**k`` identity.
    """

    _num_wires = None  # no restriction on the wire count
    _matrix = jnp.eye(2, dtype=_cdtype())

    def __init__(self, wires: Union[int, List[int]] = 0, **kwargs) -> None:
        """Create an identity operation.

        Args:
            wires: A single qubit index, or a list/tuple of indices. With more
                than one wire the matrix grows to the matching
                ``2**k × 2**k`` identity.
            **kwargs: Passed through to :class:`Operation`.
        """
        n_wires = len(wires) if isinstance(wires, (list, tuple)) else 1
        if n_wires > 1:
            # Multi-wire identity: supply the enlarged matrix explicitly.
            kwargs["matrix"] = jnp.eye(2**n_wires, dtype=_cdtype())
        super().__init__(wires=wires, **kwargs)
class PauliX(Operation):
    """Pauli-X (bit-flip) gate, also usable as the \\sigma_x observable."""

    _num_wires = 1
    _matrix = jnp.array([[0, 1], [1, 0]], dtype=_cdtype())

    def __init__(self, wires: Union[int, List[int]] = 0, **kwargs) -> None:
        """Create a Pauli-X operation.

        Args:
            wires: Target qubit (a single index or a one-element list).
            **kwargs: Passed through to :class:`Operation` (e.g. ``record``).
        """
        super().__init__(wires=wires, **kwargs)
class PauliY(Operation):
    """Pauli-Y gate, also usable as the \\sigma_y observable."""

    _num_wires = 1
    _matrix = jnp.array([[0, -1j], [1j, 0]], dtype=_cdtype())

    def __init__(self, wires: Union[int, List[int]] = 0, **kwargs) -> None:
        """Create a Pauli-Y operation.

        Args:
            wires: Target qubit (a single index or a one-element list).
            **kwargs: Passed through to :class:`Operation` (e.g. ``record``).
        """
        super().__init__(wires=wires, **kwargs)
class PauliZ(Operation):
    """Pauli-Z (phase-flip) gate, also usable as the \\sigma_z observable."""

    _num_wires = 1
    _matrix = jnp.array([[1, 0], [0, -1]], dtype=_cdtype())

    def __init__(self, wires: Union[int, List[int]] = 0, **kwargs) -> None:
        """Create a Pauli-Z operation.

        Args:
            wires: Target qubit (a single index or a one-element list).
            **kwargs: Passed through to :class:`Operation` (e.g. ``record``).
        """
        super().__init__(wires=wires, **kwargs)
class H(Operation):
    """Single-qubit Hadamard gate, ``(X + Z) / sqrt(2)``."""

    _num_wires = 1
    _matrix = jnp.array([[1, 1], [1, -1]], dtype=_cdtype()) / jnp.sqrt(2)

    def __init__(self, wires: Union[int, List[int]] = 0, **kwargs) -> None:
        """Create a Hadamard operation.

        Args:
            wires: Target qubit (a single index or a one-element list).
            **kwargs: Passed through to :class:`Operation` (e.g. ``record``).
        """
        super().__init__(wires=wires, **kwargs)
class S(Operation):
    """S (phase) gate — a Clifford gate equal to \\sqrt{Z}.

    .. math::
        S = \\begin{pmatrix}1 & 0 \\\\ 0 & i\\end{pmatrix}
    """

    _matrix = jnp.array([[1, 0], [0, 1j]], dtype=_cdtype())
    _num_wires = 1

    def __init__(self, wires: Union[int, List[int]] = 0, **kwargs) -> None:
        """Initialise an S gate.

        Args:
            wires: Qubit index or list of qubit indices this gate acts on.
            **kwargs: Forwarded to :class:`Operation` (e.g. ``record``,
                ``name``), matching the other fixed single-qubit gates.
        """
        super().__init__(wires=wires, **kwargs)
class SWAP(Operation):
    """SWAP gate: exchanges the states of its two wires."""

    _matrix = jnp.array(
        [[1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1]], dtype=_cdtype()
    )
    _num_wires = 2

    def __init__(
        self, wires: Union[int, List[int], Tuple[int, int]] = (0, 1), **kwargs
    ) -> None:
        """Initialise a SWAP gate.

        Args:
            wires: The two qubit indices to swap. Defaults to ``(0, 1)``. A
                single index is rejected, because SWAP needs exactly two wires.
            **kwargs: Forwarded to :class:`Operation` (e.g. ``record``).

        Raises:
            ValueError: If not exactly two distinct wires are given.
        """
        super().__init__(wires=wires, **kwargs)
class RandomUnitary(Operation):
    """Creates a random hermitian matrix and applies it as a gate.

    NOTE(review): despite the class name, the matrix built here is the
    Hermitian ``(A + A^\\dagger) / 2`` rescaled to Frobenius norm ``scale``.
    It is not exponentiated, so in general it is *not* unitary.
    """

    def __init__(
        self,
        wires: Union[int, List[int]],
        key: jax.random.PRNGKey,
        scale: float = 1.0,
        record: bool = True,
    ) -> None:
        """Initialise a random Hermitian gate.

        Args:
            wires (Union[int, List[int]]): Qubit index or list of qubit indices
                this gate acts on.
            key (jax.random.PRNGKey): PRNGKey for randomization.
            scale (float): Frobenius norm of the generated matrix (default: 1.0).
            record (bool): Whether to record this gate on the active tape.
        """
        # Accept a single int wire as the annotation promises (``len(int)``
        # would otherwise raise a TypeError).
        wire_list = list(wires) if isinstance(wires, (list, tuple)) else [wires]
        dim = 2 ** len(wire_list)
        key_re, key_im = jax.random.split(key)

        A = (
            jax.random.normal(key=key_re, shape=(dim, dim))
            + 1j * jax.random.normal(key=key_im, shape=(dim, dim))
        ).astype(_cdtype())
        # Hermitian part of A; named ``herm`` to avoid shadowing the H gate.
        herm = (A + A.conj().T) / 2.0

        herm = herm * (scale / jnp.linalg.norm(herm, ord="fro"))

        super().__init__(wires, matrix=herm, record=record)
class DiagonalQubitUnitary(Operation):
    """A diagonal unitary gate specified by its diagonal entries.

    Implements ``U = diag(d_0, d_1, ..., d_{2^k-1})`` where each ``d_i`` lies
    on the unit circle. This is the natural gate for data-encoding
    Hamiltonians of the form ``S(x) = exp(-i H x)`` where *H* is diagonal in
    the computational basis (see Peters et al., arXiv:2209.05523).

    The Golomb encoding strategy uses this gate with diagonal entries
    ``exp(-i * golomb_marks * x)`` to achieve a maximally non-degenerate
    Fourier spectrum.

    Args:
        diag: 1-D array-like of ``2**k`` complex values on the unit circle
            (unitarity is not checked).
        wires: Qubit indices this gate acts on (s.t. ``2**len(wires) == len(diag)``).
        **kwargs: Forwarded to :class:`Operation`.
    """

    # Do NOT list "diag" in _param_names — the array is not a scalar
    # parameter and would break drawing helpers that call float(p).
    _param_names = ()

    def __init__(
        self,
        diag: jnp.ndarray,
        wires: Union[int, List[int]] = 0,
        **kwargs,
    ) -> None:
        """Build the gate from its diagonal.

        Args:
            diag: Diagonal entries; lists, tuples and NumPy arrays are
                converted with ``jnp.asarray``.
            wires: Qubit index or list of qubit indices.
            **kwargs: Forwarded to :class:`Operation`; ``name`` defaults to
                ``"DiagU"``.

        Raises:
            ValueError: If ``diag`` does not have shape ``(2**len(wires),)``.
        """
        # ``.shape`` and the fast paths below need an array, not a list.
        diag = jnp.asarray(diag)
        self.diag = diag
        wires_list = list(wires) if isinstance(wires, (list, tuple)) else [wires]
        expected_dim = 2 ** len(wires_list)
        if diag.shape != (expected_dim,):
            raise ValueError(
                f"DiagonalQubitUnitary expects {expected_dim} diagonal entries "
                f"for {len(wires_list)} wire(s), got shape {diag.shape}"
            )
        mat = jnp.diag(diag)
        # Use a descriptive name for drawing
        kwargs.setdefault("name", "DiagU")
        super().__init__(wires=wires, matrix=mat, **kwargs)

    def apply_to_state(self, state: jnp.ndarray, n_qubits: int) -> jnp.ndarray:
        """Apply diagonal gate via element-wise multiplication.

        When the gate covers all qubits in order ``0..n-1``, its diagonal is
        already the full ``2^n`` diagonal. The gate is then applied as an
        element-wise product, which is cheaper than a generic contraction.
        Any other wire layout falls back to the generic tensor contraction.

        Args:
            state: Statevector of shape ``(2**n_qubits,)``.
            n_qubits: Total number of qubits in the circuit.

        Returns:
            Updated statevector of shape ``(2**n_qubits,)``.
        """
        k = len(self.wires)
        if k == n_qubits and self.wires == list(range(n_qubits)):
            # Gate acts on all qubits in order — direct element-wise multiply
            return state * self.diag
        # Fall back to general tensor contraction for arbitrary wire subsets
        return super().apply_to_state(state, n_qubits)

    def apply_to_density(self, rho: jnp.ndarray, n_qubits: int) -> jnp.ndarray:
        """Apply diagonal gate to density matrix: rho -> U rho U†.

        For diagonal U the transformation is
        ``rho_ij -> d_i * conj(d_j) * rho_ij``. This shortcut is used only
        when the gate covers all qubits in order. Otherwise the generic path
        is taken.

        Args:
            rho: Density matrix of shape ``(2**n_qubits, 2**n_qubits)``.
            n_qubits: Total number of qubits in the circuit.

        Returns:
            Updated density matrix of shape ``(2**n_qubits, 2**n_qubits)``.
        """
        k = len(self.wires)
        if k == n_qubits and self.wires == list(range(n_qubits)):
            d = self.diag
            return d[:, None] * jnp.conj(d)[None, :] * rho
        return super().apply_to_density(rho, n_qubits)
class Barrier(Operation):
    """Visual separator with no effect on the simulated state.

    A barrier is appended to the tape like any other operation, so drawing
    backends can render a divider. Every ``apply_*`` method returns its
    input untouched.
    """

    _matrix = None  # a barrier has no matrix representation

    def __init__(self, wires: Union[int, List[int]] = 0) -> None:
        """Create a barrier spanning *wires* (a single index or a list)."""
        super().__init__(wires=wires)

    def apply_to_density(self, rho: jnp.ndarray, n_qubits: int) -> jnp.ndarray:
        """Identity action on a density matrix: ``rho`` is returned as-is."""
        return rho

    def apply_to_state_tensor(self, psi: jnp.ndarray, n_qubits: int) -> jnp.ndarray:
        """Identity action on a tensor-form statevector: ``psi`` is returned as-is."""
        return psi

    def apply_to_state(self, state: jnp.ndarray, n_qubits: int) -> jnp.ndarray:
        """Identity action on a flat statevector: ``state`` is returned as-is."""
        return state
def _make_rotation_gate(pauli_class: type, name: str) -> type:
    """Build a single-qubit Pauli rotation gate class (RX, RY or RZ).

    The generated gate's matrix is
    ``R_P(\\theta) = cos(\\theta/2) I - i sin(\\theta/2) P``, where ``P`` is
    the class-level matrix of *pauli_class*.

    Args:
        pauli_class: The Pauli gate class providing ``P`` and used by
            ``generator()`` (PauliX, PauliY or PauliZ).
        name: Name given to the generated class, e.g. ``"RX"``. Its second
            character is used as the axis label in the docstring.

    Returns:
        A freshly created :class:`Operation` subclass.
    """
    generator_mat = pauli_class._matrix
    axis = name[1]

    class _RotationGate(Operation):
        # Generic docstring filled in per axis
        __doc__ = (
            f"Rotation around the {axis} axis: {name}(\\theta) =\n"
            f"exp(-i \\theta/2 {axis}).\n"
        )
        _param_names = ("theta",)
        _num_wires = 1

        def __init__(
            self, theta: float, wires: Union[int, List[int]] = 0, **kwargs
        ) -> None:
            self.theta = theta
            half_angle = theta / 2
            mat = jnp.cos(half_angle) * Id._matrix - 1j * jnp.sin(half_angle) * generator_mat
            super().__init__(wires=wires, matrix=mat, **kwargs)

        def generator(self) -> Operation:
            """Return the Pauli operation generating this rotation (unrecorded)."""
            return pauli_class(wires=self.wires[0], record=False)

    _RotationGate.__qualname__ = name
    _RotationGate.__name__ = name
    return _RotationGate
990RX = _make_rotation_gate(PauliX, "RX")
991RY = _make_rotation_gate(PauliY, "RY")
992RZ = _make_rotation_gate(PauliZ, "RZ")
# Computational-basis projectors |0><0| and |1><1| used by the
# controlled-gate factories and ControlledPhaseShift.
_P0 = jnp.diag(jnp.array([1, 0], dtype=_cdtype()))
_P1 = jnp.diag(jnp.array([0, 1], dtype=_cdtype()))
def _make_controlled_gate(target_class: type, name: str) -> type:
    """Factory for controlled Pauli gates CX, CY, CZ.

    Each gate has the form
    ``CP = |0><0| \\otimes I + |1><1| \\otimes P``,
    where the projector factor belongs to the control qubit and ``P`` is the
    matrix of *target_class* applied to the target qubit.

    Args:
        target_class: The single-qubit gate class (PauliX, PauliY, PauliZ).
        name: Class name for the generated gate (e.g. ``"CX"``).

    Returns:
        A new :class:`Operation` subclass.
    """
    target_mat = target_class._matrix

    class _ControlledGate(Operation):
        __doc__ = (
            f"Controlled-{target_class.__name__[5:]} gate.\n\n"
            f"Applies {target_class.__name__} on the target qubit conditioned "
            f"on the control qubit being in state |1\\rangle."
        )
        _matrix = jnp.kron(_P0, Id._matrix) + jnp.kron(_P1, target_mat)
        _num_wires = 2
        is_controlled = True

        def __init__(self, wires: Optional[List[int]] = None, **kwargs) -> None:
            """Initialise the controlled gate.

            Args:
                wires: Two-element list ``[control, target]``. Defaults to
                    ``[0, 1]``; a new list is built on every call so instances
                    never share a mutable default argument.
                **kwargs: Forwarded to :class:`Operation`.
            """
            if wires is None:
                wires = [0, 1]
            super().__init__(wires=wires, **kwargs)

    _ControlledGate.__name__ = name
    _ControlledGate.__qualname__ = name
    return _ControlledGate
# Controlled Pauli gates produced by the controlled-gate factory.
CX, CY, CZ = (
    _make_controlled_gate(PauliX, "CX"),
    _make_controlled_gate(PauliY, "CY"),
    _make_controlled_gate(PauliZ, "CZ"),
)
class CCX(Operation):
    """Toffoli (CCX) gate.

    Flips the target qubit when both controls are |1\\rangle: the matrix swaps
    the basis states |110\\rangle and |111\\rangle and leaves every other
    computational basis state unchanged.

    The 3-qubit Toffoli gate exercises the arbitrary-k-qubit path in
    :meth:`~Operation.apply_to_state` and cannot be expressed as a pair of
    2-qubit gates without ancilla, making it a good stress-test for the
    simulator.
    """

    _matrix = jnp.array(
        [
            [1, 0, 0, 0, 0, 0, 0, 0],
            [0, 1, 0, 0, 0, 0, 0, 0],
            [0, 0, 1, 0, 0, 0, 0, 0],
            [0, 0, 0, 1, 0, 0, 0, 0],
            [0, 0, 0, 0, 1, 0, 0, 0],
            [0, 0, 0, 0, 0, 1, 0, 0],
            [0, 0, 0, 0, 0, 0, 0, 1],
            [0, 0, 0, 0, 0, 0, 1, 0],
        ],
        dtype=_cdtype(),
    )
    is_controlled = True
    _num_wires = 3

    def __init__(self, wires: Optional[List[int]] = None, **kwargs) -> None:
        """Initialise a Toffoli (CCX) gate.

        Args:
            wires: Three-element list ``[control0, control1, target]``.
                Defaults to ``[0, 1, 2]``; a new list is built on every call
                so instances never share a mutable default argument.
            **kwargs: Forwarded to :class:`Operation`.
        """
        if wires is None:
            wires = [0, 1, 2]
        super().__init__(wires=wires, **kwargs)
class CSWAP(Operation):
    """Controlled-SWAP (Fredkin) gate.

    Swaps the two target qubits conditioned on the control qubit being
    |1\\rangle: the matrix exchanges the basis states |101\\rangle and
    |110\\rangle and leaves every other computational basis state unchanged.

    Construction:
        wires: ``[control, target0, target1]``.
    """

    _matrix = jnp.array(
        [
            [1, 0, 0, 0, 0, 0, 0, 0],
            [0, 1, 0, 0, 0, 0, 0, 0],
            [0, 0, 1, 0, 0, 0, 0, 0],
            [0, 0, 0, 1, 0, 0, 0, 0],
            [0, 0, 0, 0, 1, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 1, 0],
            [0, 0, 0, 0, 0, 1, 0, 0],
            [0, 0, 0, 0, 0, 0, 0, 1],
        ],
        dtype=_cdtype(),
    )
    is_controlled = True
    _num_wires = 3

    def __init__(self, wires: Optional[List[int]] = None, **kwargs) -> None:
        """Initialise a Controlled-SWAP (Fredkin) gate.

        Args:
            wires: Three-element list ``[control, target0, target1]``.
                Defaults to ``[0, 1, 2]``; a new list is built on every call
                so instances never share a mutable default argument.
            **kwargs: Forwarded to :class:`Operation`.
        """
        if wires is None:
            wires = [0, 1, 2]
        super().__init__(wires=wires, **kwargs)
def _make_controlled_rotation_gate(pauli_class: type, name: str) -> type:
    """Factory for controlled rotation gates CRX, CRY, CRZ.

    Each gate has the form
    ``CR_P(\\theta) = |0><0| \\otimes I + |1><1| \\otimes R_P(\\theta)``
    with ``R_P(\\theta) = cos(\\theta/2) I - i sin(\\theta/2) P``.

    Args:
        pauli_class: One of PauliX, PauliY, PauliZ.
        name: Class name for the generated gate (e.g. ``"CRX"``).

    Returns:
        A new :class:`Operation` subclass.
    """
    pauli_mat = pauli_class._matrix

    class _CRotationGate(Operation):
        __doc__ = (
            f"Controlled rotation around the {name[2]} axis.\n\n"
            f"Applies R{name[2]}(\\theta) on the target qubit conditioned on the "
            f"control qubit being in state |1\\rangle.\n\n"
            f".. math::\n"
            f"{name}(\\theta) = |0\\rangle\\langle 0| \\otimes I\n"
            f" + |1\\rangle\\langle 1| \\otimes R{name[2]}(\\theta)"
        )
        _num_wires = 2
        _param_names = ("theta",)
        is_controlled = True

        def __init__(
            self, theta: float, wires: Optional[List[int]] = None, **kwargs
        ) -> None:
            """Initialise the controlled rotation.

            Args:
                theta: Rotation angle in radians.
                wires: Two-element list ``[control, target]``. Defaults to
                    ``[0, 1]``; a new list is built on every call so instances
                    never share a mutable default argument.
                **kwargs: Forwarded to :class:`Operation`.
            """
            if wires is None:
                wires = [0, 1]
            self.theta = theta
            c = jnp.cos(theta / 2)
            s = jnp.sin(theta / 2)
            rot = c * Id._matrix - 1j * s * pauli_mat
            mat = jnp.kron(_P0, Id._matrix) + jnp.kron(_P1, rot)
            super().__init__(wires=wires, matrix=mat, **kwargs)

    _CRotationGate.__name__ = name
    _CRotationGate.__qualname__ = name
    return _CRotationGate
# Controlled rotation gates produced by the controlled-rotation factory.
CRX, CRY, CRZ = (
    _make_controlled_rotation_gate(PauliX, "CRX"),
    _make_controlled_rotation_gate(PauliY, "CRY"),
    _make_controlled_rotation_gate(PauliZ, "CRZ"),
)
class ControlledPhaseShift(Operation):
    r"""Controlled phase shift gate (CPhase).

    Applies a phase shift of ``exp(i * phi)`` to the |11⟩ component of the
    two-qubit state, leaving all other computational basis states unchanged.
    This is a generalization of the CZ gate: when ``phi = \pi`` the gate
    reduces to CZ.

    .. math::
        \text{CPhase}(\phi) = \text{diag}(1, 1, 1, e^{i\phi})

    which is equivalent to
    ``|0⟩⟨0| \otimes I + |1⟩⟨1| \otimes P(phi)`` where
    ``P(phi) = diag(1, exp(i*phi))``.
    """

    _num_wires = 2
    _param_names = ("phi",)
    is_controlled = True

    def __init__(self, phi: float, wires: Optional[List[int]] = None, **kwargs) -> None:
        """Initialise a controlled phase shift gate.

        Args:
            phi: Phase shift angle in radians.
            wires: Two-element list ``[control, target]``. Defaults to
                ``[0, 1]``; a new list is built on every call so instances
                never share a mutable default argument.
            **kwargs: Forwarded to :class:`Operation`.
        """
        if wires is None:
            wires = [0, 1]
        self.phi = phi
        phase_gate = jnp.array([[1, 0], [0, jnp.exp(1j * phi)]], dtype=_cdtype())
        mat = jnp.kron(_P0, Id._matrix) + jnp.kron(_P1, phase_gate)
        super().__init__(wires=wires, matrix=mat, **kwargs)
class Rot(Operation):
    """General single-qubit rotation:
    Rot(\\phi, \\theta, \\omega) = RZ(\\omega) RY(\\theta) RZ(\\phi).

    The three Euler angles cover every SU(2) element, i.e. every single-qubit
    unitary up to a global phase. ``phi`` acts on the state first and
    ``omega`` last.
    """

    _num_wires = 1
    _param_names = ("phi", "theta", "omega")

    def __init__(
        self,
        phi: float,
        theta: float,
        omega: float,
        wires: Union[int, List[int]] = 0,
        **kwargs,
    ) -> None:
        """Initialise a general rotation gate.

        Args:
            phi: Angle of the first RZ rotation (radians).
            theta: Angle of the middle RY rotation (radians).
            omega: Angle of the final RZ rotation (radians).
            wires: Qubit index or list of qubit indices this gate acts on.
            **kwargs: Forwarded to :class:`Operation`.
        """
        self.phi, self.theta, self.omega = phi, theta, omega

        def axis_rotation(angle, pauli):
            # exp(-i angle/2 P) = cos(angle/2) I - i sin(angle/2) P
            return jnp.cos(angle / 2) * Id._matrix - 1j * jnp.sin(angle / 2) * pauli

        first = axis_rotation(phi, PauliZ._matrix)
        middle = axis_rotation(theta, PauliY._matrix)
        last = axis_rotation(omega, PauliZ._matrix)
        # Matrix products compose right-to-left, so RZ(phi) is applied first.
        super().__init__(wires=wires, matrix=last @ middle @ first, **kwargs)
class PauliRot(Operation):
    """Multi-qubit Pauli rotation: exp(-i \\theta/2 P) for a Pauli word P.

    The Pauli word is given as a string of ``'I'``, ``'X'``, ``'Y'``, ``'Z'``
    characters (one per qubit). The rotation matrix is computed as
    ``cos(\\theta/2) I - i sin(\\theta/2) P`` where *P* is the tensor product of the
    corresponding single-qubit Pauli matrices.

    Example::

        PauliRot(0.5, "XY", wires=[0, 1])
    """

    _param_names = ("theta",)

    # Map from character to 2x2 matrix
    _PAULI_MAP = {
        "I": Id._matrix,
        "X": PauliX._matrix,
        "Y": PauliY._matrix,
        "Z": PauliZ._matrix,
    }

    def __init__(
        self, theta: float, pauli_word: str, wires: Union[int, List[int]] = 0, **kwargs
    ) -> None:
        """Initialise a PauliRot gate.

        Args:
            theta: Rotation angle in radians.
            pauli_word: A string of ``'I'``, ``'X'``, ``'Y'``, ``'Z'``
                characters specifying the Pauli tensor product.
            wires: Qubit index or list of qubit indices this gate acts on.
            **kwargs: Forwarded to :class:`Operation`.
        """
        self.theta = theta
        self.pauli_word = pauli_word

        P = self._pauli_word_matrix(pauli_word)
        dim = P.shape[0]
        mat = (
            jnp.cos(theta / 2) * jnp.eye(dim, dtype=_cdtype())
            - 1j * jnp.sin(theta / 2) * P
        )
        super().__init__(wires=wires, matrix=mat, **kwargs)

    @classmethod
    def _pauli_word_matrix(cls, pauli_word: str) -> jnp.ndarray:
        """Return the Kronecker product of the Pauli matrices named by *pauli_word*.

        Shared by :meth:`__init__` and :meth:`generator` so both always build
        the same operator. A character outside ``IXYZ`` raises ``KeyError``;
        an empty word raises ``TypeError`` from :func:`functools.reduce`.
        """
        from functools import reduce as _reduce

        return _reduce(jnp.kron, [cls._PAULI_MAP[c] for c in pauli_word])

    def generator(self) -> Operation:
        """Return the generator Pauli tensor product as an :class:`Operation`.

        The generator of ``PauliRot(\\theta, word, wires)`` is the tensor product
        of single-qubit Pauli matrices specified by *word*. The returned
        :class:`Hermitian` wraps that matrix and the gate's wires.

        Returns:
            :class:`Hermitian` operation representing the Pauli tensor product.
        """
        P = self._pauli_word_matrix(self.pauli_word)
        return Hermitian(matrix=P, wires=self.wires, record=False)
class KrausChannel(Operation):
    """Base class for noise channels defined by a set of Kraus operators.

    A Kraus channel \\Phi(\\rho) = \\sum_k K_k \\rho K_k^\\dagger is the most
    general physical operation on a quantum state. For a pure unitary gate
    there is a single operator K_0 = U satisfying K_0^\\dagger K_0 = I; for
    noisy channels there are multiple operators.

    Subclasses must implement :meth:`kraus_matrices` and return a list of JAX
    arrays. Kraus channels require a density-matrix representation and
    cannot be applied to a pure statevector in general. For that reason
    :meth:`apply_to_state`, :meth:`apply_to_state_tensor` and the
    :attr:`matrix` property all raise :class:`TypeError`. Only
    :meth:`apply_to_density` is supported.
    """

    def kraus_matrices(self) -> List[jnp.ndarray]:
        """Return the list of Kraus operators for this channel.

        Returns:
            List of 2-D JAX arrays, each of shape ``(2**k, 2**k)`` where k
            is the number of target qubits.

        Raises:
            NotImplementedError: Subclasses must override this method.
        """
        raise NotImplementedError

    @property
    def matrix(self) -> jnp.ndarray:
        """Raises TypeError — noise channels have no single unitary matrix.

        Raises:
            TypeError: Always raised; use :meth:`apply_to_density` instead.
        """
        raise TypeError(
            f"{self.__class__.__name__} is a noise channel and has no single "
            "unitary matrix. Use apply_to_density() instead."
        )

    def apply_to_state(self, state: jnp.ndarray, n_qubits: int) -> jnp.ndarray:
        """Raises TypeError — noise channels require density-matrix simulation.

        Args:
            state: Statevector (unused).
            n_qubits: Number of qubits (unused).

        Raises:
            TypeError: Always raised; use ``execute(type='density')`` instead.
        """
        raise TypeError(
            f"{self.__class__.__name__} is a noise channel and cannot be "
            "applied to a pure statevector. Use execute(type='density') instead."
        )

    def apply_to_state_tensor(self, psi: jnp.ndarray, n_qubits: int) -> jnp.ndarray:
        """Raises TypeError — noise channels require density-matrix simulation."""
        raise TypeError(
            f"{self.__class__.__name__} is a noise channel and cannot be "
            "applied to a pure statevector. Use execute(type='density') instead."
        )

    def apply_to_density(self, rho: jnp.ndarray, n_qubits: int) -> jnp.ndarray:
        """Apply \\Phi(\\rho) = \\sum_k K_k \\rho K_k^\\dagger via tensor contraction.

        Uses the shared :func:`_contract_and_restore` helper, summing the
        result over all Kraus operators.

        Args:
            rho: Density matrix of shape ``(2**n_qubits, 2**n_qubits)``.
            n_qubits: Total number of qubits in the circuit.

        Returns:
            Updated density matrix of shape ``(2**n_qubits, 2**n_qubits)``.
        """
        k = len(self.wires)
        dim = 2**n_qubits
        # Row (ket) axes are 0..n_qubits-1; column (bra) axes are offset by n_qubits.
        bra_wires = [w + n_qubits for w in self.wires]
        # The tensor view of rho is the same for every Kraus term; build it once.
        rho_t = rho.reshape((2,) * 2 * n_qubits)
        rho_out = jnp.zeros_like(rho)

        for K in self.kraus_matrices():
            K_t = K.reshape((2,) * 2 * k)
            # K on the ket axes, then conj(K) on the bra axes: K rho K^dagger.
            term = _contract_and_restore(rho_t, K_t, k, self.wires)
            term = _contract_and_restore(term, jnp.conj(K_t), k, bra_wires)
            rho_out = rho_out + term.reshape(dim, dim)

        return rho_out
class BitFlip(KrausChannel):
    r"""Single-qubit bit-flip (Pauli-X) error channel.

    .. math::
        K_0 = \sqrt{1-p}\,I, \quad K_1 = \sqrt{p}\,X

    where *p* \in [0, 1] is the probability of a bit flip.
    """

    _num_wires = 1
    _param_names = ("p",)

    def __init__(self, p: float, wires: Union[int, List[int]] = 0, **kwargs) -> None:
        """Initialise a bit-flip channel.

        Args:
            p: Bit-flip probability, must be in [0, 1].
            wires: Qubit index or list of qubit indices this channel acts on.
            **kwargs: Forwarded to :class:`Operation` (e.g. ``record=False``),
                matching the gate constructors in this module.

        Raises:
            ValueError: If *p* is outside [0, 1].
        """
        if not 0.0 <= p <= 1.0:
            raise ValueError("p must be in [0, 1].")
        self.p = p
        super().__init__(wires=wires, **kwargs)

    def kraus_matrices(self) -> List[jnp.ndarray]:
        """Return the two Kraus operators for the bit-flip channel.

        Returns:
            List ``[K0, K1]`` where K0 = \sqrt{1-p} I and K1 = \sqrt{p} X.
        """
        p = self.p
        K0 = jnp.sqrt(1 - p) * Id._matrix
        K1 = jnp.sqrt(p) * PauliX._matrix
        return [K0, K1]
class PhaseFlip(KrausChannel):
    r"""Single-qubit phase-flip (Pauli-Z) error channel.

    .. math::
        K_0 = \sqrt{1-p}\,I, \quad K_1 = \sqrt{p}\,Z

    where *p* \in [0, 1] is the probability of a phase flip.
    """

    _num_wires = 1
    _param_names = ("p",)

    def __init__(self, p: float, wires: Union[int, List[int]] = 0, **kwargs) -> None:
        """Initialise a phase-flip channel.

        Args:
            p: Phase-flip probability, must be in [0, 1].
            wires: Qubit index or list of qubit indices this channel acts on.
            **kwargs: Forwarded to :class:`Operation` (e.g. ``record=False``),
                matching the gate constructors in this module.

        Raises:
            ValueError: If *p* is outside [0, 1].
        """
        if not 0.0 <= p <= 1.0:
            raise ValueError("p must be in [0, 1].")
        self.p = p
        super().__init__(wires=wires, **kwargs)

    def kraus_matrices(self) -> List[jnp.ndarray]:
        """Return the two Kraus operators for the phase-flip channel.

        Returns:
            List ``[K0, K1]`` where K0 = \sqrt{1-p} I and K1 = \sqrt{p} Z.
        """
        p = self.p
        K0 = jnp.sqrt(1 - p) * Id._matrix
        K1 = jnp.sqrt(p) * PauliZ._matrix
        return [K0, K1]
class DepolarizingChannel(KrausChannel):
    r"""Single-qubit depolarizing channel.

    .. math::
        K_0 = \sqrt{1-p}\,I,\quad K_1 = \sqrt{p/3}\,X,\quad
        K_2 = \sqrt{p/3}\,Y,\quad K_3 = \sqrt{p/3}\,Z

    where *p* \in [0, 1]. At p = 3/4 the channel is fully depolarizing.
    """

    _num_wires = 1
    _param_names = ("p",)

    def __init__(self, p: float, wires: Union[int, List[int]] = 0, **kwargs) -> None:
        """Initialise a depolarizing channel.

        Args:
            p: Depolarization probability, must be in [0, 1].
            wires: Qubit index or list of qubit indices this channel acts on.
            **kwargs: Forwarded to :class:`Operation` (e.g. ``record=False``),
                matching the gate constructors in this module.

        Raises:
            ValueError: If *p* is outside [0, 1].
        """
        if not 0.0 <= p <= 1.0:
            raise ValueError("p must be in [0, 1].")
        self.p = p
        super().__init__(wires=wires, **kwargs)

    def kraus_matrices(self) -> List[jnp.ndarray]:
        """Return the four Kraus operators for the depolarizing channel.

        Returns:
            List ``[K0, K1, K2, K3]`` corresponding to I, X, Y, Z components.
        """
        p = self.p
        K0 = jnp.sqrt(1 - p) * Id._matrix
        K1 = jnp.sqrt(p / 3) * PauliX._matrix
        K2 = jnp.sqrt(p / 3) * PauliY._matrix
        K3 = jnp.sqrt(p / 3) * PauliZ._matrix
        return [K0, K1, K2, K3]
class AmplitudeDamping(KrausChannel):
    r"""Single-qubit amplitude damping channel.

    .. math::
        K_0 = \begin{pmatrix}1 & 0\\ 0 & \sqrt{1-\gamma}\end{pmatrix},\quad
        K_1 = \begin{pmatrix}0 & \sqrt{\gamma}\\ 0 & 0\end{pmatrix}

    where *\gamma* \in [0, 1] is the probability of
    energy loss (|1\rangle \to |0\rangle).
    """

    _num_wires = 1
    _param_names = ("gamma",)

    def __init__(
        self, gamma: float, wires: Union[int, List[int]] = 0, **kwargs
    ) -> None:
        """Initialise an amplitude damping channel.

        Args:
            gamma: Energy-loss probability, must be in [0, 1].
            wires: Qubit index or list of qubit indices this channel acts on.
            **kwargs: Forwarded to :class:`Operation` (e.g. ``record=False``),
                matching the gate constructors in this module.

        Raises:
            ValueError: If *gamma* is outside [0, 1].
        """
        if not 0.0 <= gamma <= 1.0:
            raise ValueError("gamma must be in [0, 1].")
        self.gamma = gamma
        super().__init__(wires=wires, **kwargs)

    def kraus_matrices(self) -> List[jnp.ndarray]:
        """Return the two Kraus operators for the amplitude damping channel.

        Returns:
            List ``[K0, K1]`` as defined in the class docstring.
        """
        g = self.gamma
        K0 = jnp.array([[1.0, 0.0], [0.0, jnp.sqrt(1 - g)]], dtype=_cdtype())
        K1 = jnp.array([[0.0, jnp.sqrt(g)], [0.0, 0.0]], dtype=_cdtype())
        return [K0, K1]
class PhaseDamping(KrausChannel):
    r"""Single-qubit phase damping (dephasing) channel.

    .. math::
        K_0 = \begin{pmatrix}1 & 0\\ 0 & \sqrt{1-\gamma}\end{pmatrix},\quad
        K_1 = \begin{pmatrix}0 & 0\\ 0 & \sqrt{\gamma}\end{pmatrix}

    where *\gamma* \in [0, 1] is the phase damping probability.
    """

    _num_wires = 1
    _param_names = ("gamma",)

    def __init__(
        self, gamma: float, wires: Union[int, List[int]] = 0, **kwargs
    ) -> None:
        """Initialise a phase damping channel.

        Args:
            gamma: Phase-damping probability, must be in [0, 1].
            wires: Qubit index or list of qubit indices this channel acts on.
            **kwargs: Forwarded to :class:`Operation` (e.g. ``record=False``),
                matching the gate constructors in this module.

        Raises:
            ValueError: If *gamma* is outside [0, 1].
        """
        if not 0.0 <= gamma <= 1.0:
            raise ValueError("gamma must be in [0, 1].")
        self.gamma = gamma
        super().__init__(wires=wires, **kwargs)

    def kraus_matrices(self) -> List[jnp.ndarray]:
        """Return the two Kraus operators for the phase damping channel.

        Returns:
            List ``[K0, K1]`` as defined in the class docstring.
        """
        g = self.gamma
        K0 = jnp.array([[1.0, 0.0], [0.0, jnp.sqrt(1 - g)]], dtype=_cdtype())
        K1 = jnp.array([[0.0, 0.0], [0.0, jnp.sqrt(g)]], dtype=_cdtype())
        return [K0, K1]
class ThermalRelaxationError(KrausChannel):
    r"""Single-qubit thermal relaxation error channel.

    Models simultaneous T_1 energy relaxation and T_2 dephasing. Two regimes
    are handled:

    T_2 <= T_1 (mixture of identity, phase flip and resets):
        Six Kraus operators built from p_z (phase-flip probability), p_r0
        (reset-to-|0\rangle probability) and p_r1 (reset-to-|1\rangle
        probability).

    T_2 > T_1 (Choi matrix decomposition):
        The Choi matrix is assembled from the relaxation/dephasing rates, then
        diagonalised; Kraus operators are K_i = \sqrt{\lambda_i} \cdot mat(v_i).

    Attributes:
        pe: Excited-state population (thermal population of |1\rangle).
        t1: T_1 longitudinal relaxation time.
        t2: T_2 transverse dephasing time.
        tg: Gate duration.
    """

    _num_wires = 1
    _param_names = ("pe", "t1", "t2", "tg")

    def __init__(
        self,
        pe: float,
        t1: float,
        t2: float,
        tg: float,
        wires: Union[int, List[int]] = 0,
        **kwargs,
    ) -> None:
        """Initialise a thermal relaxation error channel.

        Args:
            pe: Excited-state population (thermal population of |1\rangle), in [0, 1].
            t1: T_1 longitudinal relaxation time, must be > 0.
            t2: T_2 transverse dephasing time, must be > 0 and <= 2·T_1.
            tg: Gate duration, must be >= 0.
            wires: Qubit index or list of qubit indices this channel acts on.
            **kwargs: Forwarded to :class:`Operation` (e.g. ``record=False``),
                matching the gate constructors in this module.

        Raises:
            ValueError: If any parameter violates the stated constraints.
        """
        if not 0.0 <= pe <= 1.0:
            raise ValueError("pe must be in [0, 1].")
        if t1 <= 0:
            raise ValueError("t1 must be > 0.")
        if t2 <= 0:
            raise ValueError("t2 must be > 0.")
        if t2 > 2 * t1:
            raise ValueError("t2 must be <= 2·t1.")
        if tg < 0:
            raise ValueError("tg must be >= 0.")
        self.pe = pe
        self.t1 = t1
        self.t2 = t2
        self.tg = tg
        super().__init__(wires=wires, **kwargs)

    def kraus_matrices(self) -> List[jnp.ndarray]:
        r"""Return the Kraus operators for the thermal relaxation channel.

        The number of operators depends on the regime:

        * T_2 <= T_1: six operators (identity, phase-flip, two reset-to-|0\rangle,
          two reset-to-|1\rangle).
        * T_2 > T_1: four operators derived from the Choi matrix eigendecomposition.

        Returns:
            List of 2x2 JAX arrays representing the Kraus operators.
        """
        pe, t1, t2, tg = self.pe, self.t1, self.t2, self.tg

        eT1 = jnp.exp(-tg / t1)
        p_reset = 1.0 - eT1
        eT2 = jnp.exp(-tg / t2)

        # For t2 > t1 we have eT2 > eT1, so pz below would be negative and the
        # six-operator mixture is not valid; that regime uses the Choi route.
        if t2 <= t1:
            # --- Case T_2 <= T_1: six Kraus operators ---
            pz = (1.0 - p_reset) * (1.0 - eT2 / eT1) / 2.0
            pr0 = (1.0 - pe) * p_reset
            pr1 = pe * p_reset
            pid = 1.0 - pz - pr0 - pr1

            K0 = jnp.sqrt(pid) * jnp.eye(2, dtype=_cdtype())
            K1 = jnp.sqrt(pz) * jnp.array([[1, 0], [0, -1]], dtype=_cdtype())
            K2 = jnp.sqrt(pr0) * jnp.array([[1, 0], [0, 0]], dtype=_cdtype())
            K3 = jnp.sqrt(pr0) * jnp.array([[0, 1], [0, 0]], dtype=_cdtype())
            K4 = jnp.sqrt(pr1) * jnp.array([[0, 0], [1, 0]], dtype=_cdtype())
            K5 = jnp.sqrt(pr1) * jnp.array([[0, 0], [0, 1]], dtype=_cdtype())
            return [K0, K1, K2, K3, K4, K5]

        # --- Case T_2 > T_1: Choi matrix decomposition ---
        # Choi matrix (column-major / reshaping convention matching PennyLane)
        choi = jnp.array(
            [
                [1 - pe * p_reset, 0, 0, eT2],
                [0, pe * p_reset, 0, 0],
                [0, 0, (1 - pe) * p_reset, 0],
                [eT2, 0, 0, 1 - (1 - pe) * p_reset],
            ],
            dtype=_cdtype(),
        )
        eigenvalues, eigenvectors = jnp.linalg.eigh(choi)
        # Each eigenvector (column of eigenvectors) reshaped column-major as
        # 2x2 gives one Kraus operator; abs() guards against tiny negative
        # eigenvalues from floating-point round-off.
        kraus = []
        for i in range(4):
            lam = eigenvalues[i]
            vec = eigenvectors[:, i]
            mat = jnp.sqrt(jnp.abs(lam)) * vec.reshape(2, 2, order="F")
            kraus.append(mat.astype(_cdtype()))
        return kraus
class QubitChannel(KrausChannel):
    """Generic Kraus channel from a user-supplied list of Kraus operators.

    This replaces PennyLane's ``qml.QubitChannel`` and accepts an arbitrary set
    of Kraus matrices satisfying \\sum_k K_k^\\dagger K_k = I (completeness is
    not checked here).

    Example::

        kraus_ops = [jnp.sqrt(0.9) * jnp.eye(2), jnp.sqrt(0.1) * PauliX._matrix]
        QubitChannel(kraus_ops, wires=0)
    """

    def __init__(
        self, kraus_ops: List[jnp.ndarray], wires: Union[int, List[int]] = 0, **kwargs
    ) -> None:
        """Initialise a generic Kraus channel.

        Args:
            kraus_ops: List of Kraus matrices. Each must be a square 2D array
                of dimension ``2**k x 2**k`` where k = ``len(wires)``.
            wires: Qubit index or list of qubit indices this channel acts on.
            **kwargs: Forwarded to :class:`Operation` (e.g. ``record=False``),
                matching the gate constructors in this module.

        Raises:
            ValueError: If *kraus_ops* is empty (which would silently map every
                state to the zero matrix) or contains a non-square / non-2D
                array (which could otherwise be silently reshaped later).
        """
        self._kraus_ops = [jnp.asarray(K, dtype=_cdtype()) for K in kraus_ops]
        # Validate before super().__init__ so an invalid channel is never
        # handed to the Operation base class.
        if not self._kraus_ops:
            raise ValueError("kraus_ops must contain at least one Kraus operator.")
        for K in self._kraus_ops:
            if K.ndim != 2 or K.shape[0] != K.shape[1]:
                raise ValueError(
                    "Each Kraus operator must be a square 2D matrix, "
                    f"got shape {K.shape}."
                )
        super().__init__(wires=wires, **kwargs)

    def kraus_matrices(self) -> List[jnp.ndarray]:
        """Return the stored Kraus operators.

        Returns:
            List of Kraus operator matrices.
        """
        return self._kraus_ops
# Single-qubit Pauli classes, their plain matrices (no Operation overhead) and
# their one-letter labels, all indexed in parallel as 0=I, 1=X, 2=Y, 3=Z.
_PAULI_CLASSES = [Id, PauliX, PauliY, PauliZ]
_PAULI_MATS = [pauli_cls._matrix for pauli_cls in _PAULI_CLASSES]
_PAULI_LABELS = list("IXYZ")
def evolve_pauli_with_clifford(
    clifford: Operation,
    pauli: Operation,
    adjoint_left: bool = True,
) -> Operation:
    """Conjugate a Pauli operator by a Clifford gate.

    Computes C\\dagger P C when *adjoint_left* is true and C P C\\dagger
    otherwise, after embedding both operators on the sorted union of their
    wires.

    Args:
        clifford: A Clifford gate.
        pauli: A Pauli / Hermitian operator.
        adjoint_left: Selects C\\dagger P C (``True``) or C P C\\dagger.

    Returns:
        A :class:`Hermitian` (not recorded on the tape) acting on the joint
        wires and wrapping the conjugated matrix.
    """
    joint_wires = sorted(set(clifford.wires).union(pauli.wires))
    n_joint = len(joint_wires)

    c_full = _embed_matrix(clifford.matrix, clifford.wires, joint_wires, n_joint)
    p_full = _embed_matrix(pauli.matrix, pauli.wires, joint_wires, n_joint)
    c_dag = jnp.conj(c_full).T

    left, right = (c_dag, c_full) if adjoint_left else (c_full, c_dag)
    return Hermitian(matrix=left @ p_full @ right, wires=joint_wires, record=False)
def _embed_matrix(
    mat: jnp.ndarray,
    op_wires: list,
    all_wires: list,
    n_total: int,
) -> jnp.ndarray:
    """Lift a gate matrix onto a larger ordered set of wires.

    The gate is padded with 2x2 identities for every wire of *all_wires* it
    does not touch; the padded operator is then re-ordered so its qubit
    order follows *all_wires*. When the gate already spans exactly
    *all_wires* in that order, *mat* itself is returned.

    Args:
        mat: Gate matrix of shape ``(2**k, 2**k)`` with ``k = len(op_wires)``.
        op_wires: Wires the gate acts on.
        all_wires: Full ordered list of wires.
        n_total: ``len(all_wires)``.

    Returns:
        A ``(2**n_total, 2**n_total)`` matrix.
    """
    gate_wires = list(op_wires)
    target_order = list(all_wires)
    if len(op_wires) == n_total and gate_wires == target_order:
        return mat

    # Wires the gate does not act on get an identity factor appended.
    spectators = [w for w in all_wires if w not in op_wires]
    identity = jnp.eye(2, dtype=_cdtype())
    embedded = mat
    for _ in spectators:
        embedded = jnp.kron(embedded, identity)

    # After padding, qubits are laid out as [gate wires..., spectators...].
    layout = gate_wires + spectators
    if layout == target_order:
        return embedded
    perm = [layout.index(w) for w in target_order]
    return _permute_matrix(embedded, perm, n_total)
1820def _permute_matrix(mat: jnp.ndarray, perm: list, n_qubits: int) -> jnp.ndarray:
1821 """Permute the qubit ordering of a matrix.
1823 Given a ``(2**n, 2**n)`` matrix and a permutation of ``[0..n-1]``,
1824 reorder the qubits so that qubit ``i`` moves to position ``perm[i]``.
1826 Args:
1827 mat: Square matrix of dimension ``2**n_qubits``.
1828 perm: Permutation list.
1829 n_qubits: Number of qubits.
1831 Returns:
1832 Permuted matrix of the same shape.
1833 """
1834 dim = 2**n_qubits
1835 # Reshape to tensor, permute axes, reshape back
1836 tensor = mat.reshape([2] * (2 * n_qubits))
1837 # Axes: first n_qubits are row indices, last n_qubits are column indices
1838 row_perm = perm
1839 col_perm = [p + n_qubits for p in perm]
1840 tensor = jnp.transpose(tensor, row_perm + col_perm)
1841 return tensor.reshape(dim, dim)
def pauli_decompose(matrix: jnp.ndarray, wire_order: Optional[List[int]] = None):
    r"""Decompose a Hermitian matrix into a sum of Pauli tensor products.

    For an n-qubit matrix (``2**n x 2**n``), returns the dominant Pauli
    term (the one with the largest absolute coefficient), wrapped as an
    :class:`Operation`. This is sufficient for the Fourier-tree algorithm
    which only needs the single non-zero Pauli term produced by Clifford
    conjugation of a Pauli operator.

    The decomposition uses the trace formula:
    ``c_P = Tr(P · M) / 2**n``

    Candidates are scanned in ``I, X, Y, Z`` lexicographic order and only a
    strictly larger magnitude replaces the current best, so ties go to the
    earliest candidate. If every coefficient is zero (e.g. the zero matrix)
    the all-identity term with coefficient ``0.0`` is returned.

    Args:
        matrix: A ``(2**n, 2**n)`` Hermitian matrix.
        wire_order: Optional list of wire indices. If ``None``, defaults
            to ``[0, 1, ..., n-1]``.

    Returns:
        A tuple ``(coeff, op)`` where *coeff* is the complex coefficient and
        *op* is the Pauli :class:`Operation` (PauliX, PauliY, PauliZ, I, or
        a :class:`Hermitian` for multi-qubit tensor products).
    """
    from itertools import product as _product
    from functools import reduce as _reduce

    dim = matrix.shape[0]
    n_qubits = int(jnp.round(jnp.log2(dim)))

    if wire_order is None:
        wire_order = list(range(n_qubits))

    # For single qubit, fast path
    if n_qubits == 1:
        best_idx, best_coeff = 0, 0.0
        for idx, P in enumerate(_PAULI_MATS):
            coeff = jnp.trace(P @ matrix) / 2.0
            if jnp.abs(coeff) > jnp.abs(best_coeff):
                best_idx = idx
                best_coeff = coeff
        op_cls = _PAULI_CLASSES[best_idx]
        result_op = op_cls(wires=wire_order[0], record=False)
        result_op._pauli_label = _PAULI_LABELS[best_idx]
        return best_coeff, result_op

    # Multi-qubit: iterate over all Pauli tensor products.
    # Start from the all-identity word so that a matrix whose coefficients are
    # all zero still yields a valid term (previously best_label stayed None
    # and the label construction below crashed), mirroring the single-qubit
    # path which defaults to I.
    best_label = (0,) * n_qubits
    best_coeff = 0.0
    for indices in _product(range(4), repeat=n_qubits):
        P = _reduce(jnp.kron, [_PAULI_MATS[i] for i in indices])
        coeff = jnp.trace(P @ matrix) / dim
        if jnp.abs(coeff) > jnp.abs(best_coeff):
            best_coeff = coeff
            best_label = indices

    # Build the Pauli string label
    pauli_label = "".join(_PAULI_LABELS[i] for i in best_label)

    # Build the operation for the dominant term
    if sum(1 for i in best_label if i != 0) <= 1:
        # Single-qubit Pauli on one wire
        for q, idx in enumerate(best_label):
            if idx != 0:
                op_cls = _PAULI_CLASSES[idx]
                result_op = op_cls(wires=wire_order[q], record=False)
                result_op._pauli_label = _PAULI_LABELS[idx]
                return best_coeff, result_op
        # All identity
        result_op = Id(wires=wire_order[0], record=False)
        result_op._pauli_label = "I" * n_qubits
        return best_coeff, result_op

    # Multi-qubit tensor product -> Hermitian with pauli label attached
    P = _reduce(jnp.kron, [_PAULI_MATS[i] for i in best_label])
    result_op = Hermitian(matrix=P, wires=wire_order, record=False)
    result_op._pauli_label = pauli_label
    return best_coeff, result_op
def pauli_string_from_operation(op: Operation) -> str:
    """Return the Pauli word (e.g. ``"X"``, ``"ZZ"``) that *op* represents.

    Lookup order:

    1. A :class:`PauliRot` reports its stored ``pauli_word``.
    2. Any operation carrying a ``_pauli_label`` (as attached by
       :func:`pauli_decompose`) reports that label.
    3. ``PauliX``/``PauliY``/``PauliZ``/``I`` are mapped by name to
       ``"X"``/``"Y"``/``"Z"``/``"I"``.
    4. Otherwise the operation's matrix is decomposed with
       :func:`pauli_decompose` and the dominant term's label is returned.

    Args:
        op: A quantum operation.

    Returns:
        The Pauli word as a string.
    """
    if isinstance(op, PauliRot) and hasattr(op, "pauli_word"):
        return op.pauli_word

    try:
        return op._pauli_label
    except AttributeError:
        pass

    short_name = {"PauliX": "X", "PauliY": "Y", "PauliZ": "Z", "I": "I"}.get(op.name)
    if short_name is not None:
        return short_name

    dominant = pauli_decompose(op.matrix, wire_order=op.wires)[1]
    return dominant._pauli_label
def prod(*ops: Operation) -> Operation:
    """Compose several operations into one acting on the union of their wires.

    Operations on disjoint wires are combined as a Kronecker product; where
    the wire sets overlap, the matrices are multiplied. The work is
    delegated to the first operation's ``prod`` method.

    Args:
        *ops: One or more :class:`Operation` instances.

    Returns:
        A new :class:`Operation` representing the composition.

    Raises:
        ValueError: If called with no operations.
    """
    if len(ops) == 0:
        raise ValueError("At least one operation must be provided to prod().")
    head, *tail = ops
    return head.prod(*tail)