Coverage for qml_essentials / operations.py: 90%

526 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-04-10 10:29 +0000

1from typing import Callable, List, Optional, Tuple, Union 

2from functools import lru_cache 

3import string 

4import numpy as np 

5 

6import jax 

7import jax.numpy as jnp 

8 

9from qml_essentials.tape import active_tape, recording # noqa: F401 (re-export) 

10 

11 

def _cdtype():
    """Pick the complex dtype that matches JAX's current precision mode.

    Returns:
        ``jnp.complex128`` when 64-bit mode is enabled
        (``jax.config.x64_enabled``), otherwise ``jnp.complex64``.
    """
    if jax.config.x64_enabled:
        return jnp.complex128
    return jnp.complex64

17 

18 

@lru_cache(maxsize=256)
def _einsum_subscript(
    n: int,
    k: int,
    target_axes: Tuple[int, ...],
) -> str:
    """Build an ``einsum`` subscript that fuses contraction + axis restore.

    Layout of the returned ``"<gate>,<state>-><result>"`` string:

    * the state tensor gets one letter per axis, taken in order from
      ``string.ascii_letters`` (``a``, ``b``, ...);
    * the gate tensor is indexed ``(out_0, ..., out_{k-1}, in_0, ..., in_{k-1})``,
      where the ``in`` letters are the state letters of ``target_axes`` and the
      ``out`` letters are the next k unused letters;
    * the result reuses the state letters, but each target axis is relabelled
      with its ``out`` letter, so the gate output lands back at the original
      axis position (no transpose needed afterwards).

    Args:
        n: Total rank of the state tensor (number of qubits for statevectors,
            ``2 * n_qubits`` for density matrices).
        k: Number of qubits the gate acts on.
        target_axes: Tuple of k axis indices in the state tensor that the
            gate contracts against.

    Returns:
        ``einsum`` subscript string, e.g. ``"db,abc->adc"`` for a 1-qubit
        gate on wire 1 of a 3-qubit state.

    Raises:
        ValueError: If ``n + k`` exceeds the 52 available index letters.
    """
    letters = string.ascii_letters
    # Fail loudly instead of with an opaque IndexError from letters[n + i]
    if n + k > len(letters):
        raise ValueError(
            f"einsum subscript needs {n + k} distinct labels but only "
            f"{len(letters)} are available (n={n}, k={k})"
        )
    # State indices: one letter per axis
    state_idx = list(letters[:n])
    # Contracted indices (the state axes being replaced by the gate)
    contracted = [state_idx[ax] for ax in target_axes]
    # Fresh letters for the gate outputs, directly after the state letters
    new_out = list(letters[n : n + k])
    # Gate shape: (out0, out1, ..., in0, in1, ...)
    gate_idx = new_out + contracted
    # Result: target axes take the new output letters, in place
    result_idx = list(state_idx)
    for ax, out_letter in zip(target_axes, new_out):
        result_idx[ax] = out_letter
    return "".join(gate_idx) + "," + "".join(state_idx) + "->" + "".join(result_idx)

51 

52 

def _contract_and_restore(
    tensor: jnp.ndarray,
    gate: jnp.ndarray,
    k: int,
    target_axes: List[int],
) -> jnp.ndarray:
    """Apply a k-qubit gate tensor to selected axes of a state tensor.

    One ``jnp.einsum`` call both contracts the gate's input axes against
    ``target_axes`` of ``tensor`` and puts the gate's output axes back at
    those same positions. The subscript string is obtained from the
    ``lru_cache``-decorated :func:`_einsum_subscript`, so it is only built
    once per distinct ``(tensor.ndim, k, target_axes)``.

    Args:
        tensor: Rank-N tensor (e.g. ``(2,)*n`` for states or ``(2,)*2n``
            for density matrices).
        gate: Reshaped gate tensor of shape ``(2,)*2k``.
        k: Number of qubits the gate acts on (= ``len(target_axes)``).
        target_axes: The k axes of tensor to contract against.

    Returns:
        Tensor of the same rank as ``tensor`` with the gate applied and the
        target axes in their original positions.
    """
    # The cached subscript builder needs a hashable key, so freeze the axes.
    axes_key = tuple(target_axes)
    return jnp.einsum(_einsum_subscript(tensor.ndim, k, axes_key), gate, tensor)

78 

79 

class Operation:
    """Base class for any quantum operation or observable.

    Further gates should inherit from this class to realise more specific
    operations. Generally, operations are created by instantiation inside a
    circuit function passed to :class:`Script`; the instance is
    automatically appended to the active tape.

    An ``Operation`` can also serve as an *observable*: its matrix is used to
    compute expectation values via ``apply_to_state`` / ``apply_to_density``.

    Attributes:
        _matrix: Class-level default gate matrix. Subclasses set this to their
            fixed unitary. Instances may override it via the *matrix* argument
            to ``__init__``.
        is_controlled: Flag set to ``True`` by controlled-gate subclasses.
        _num_wires: Expected number of wires for this gate. Subclasses set
            this to enforce wire count validation. ``None`` means any number
            of wires is accepted.
        _param_names: Tuple of attribute names for the gate parameters.
            Used by :attr:`parameters` and :meth:`__repr__`.
    """

    # Subclasses should set this to the gate's unitary / matrix
    _matrix: jnp.ndarray = None
    is_controlled = False
    _num_wires: Optional[int] = None
    _param_names: Tuple[str, ...] = ()

    def __init__(
        self,
        wires: Union[int, List[int]] = 0,
        matrix: Optional[jnp.ndarray] = None,
        record: bool = True,
        input_idx: int = -1,
        name: Optional[str] = None,
    ) -> None:
        """Initialise the operation and optionally register it on the active tape.

        Args:
            wires: Qubit index or list of qubit indices this operation acts on.
            matrix: Optional explicit gate matrix. When provided it overrides
                the class-level ``_matrix`` attribute.
            record: If ``True`` (default) and a tape is currently recording,
                append this operation to the tape. Set to ``False`` for
                auxiliary objects that should not appear in the circuit
                (e.g. Hamiltonians used only to build time-dependent
                evolutions).
            input_idx: Marks the operation as input with the corresponding
                input index, which is useful for the analytical Fourier
                coefficients computation, but has no effect otherwise.
            name: Optional explicit name for this operation. When ``None``
                (default), the class name is used (e.g. ``"RX"``).

        Raises:
            ValueError: If ``_num_wires`` is set and the number of wires
                doesn't match, or if duplicate wires are provided.
        """
        self.name = name or self.__class__.__name__
        self.wires = list(wires) if isinstance(wires, (list, tuple)) else [wires]
        self.input_idx = input_idx

        if self._num_wires is not None and len(self.wires) != self._num_wires:
            raise ValueError(
                f"{self.name} expects {self._num_wires} wire(s), "
                f"got {len(self.wires)}: {self.wires}"
            )
        if len(self.wires) != len(set(self.wires)):
            raise ValueError(f"{self.name} received duplicate wires: {self.wires}")

        if matrix is not None:
            self._matrix = matrix

        # If a tape is currently recording, append ourselves
        if record:
            tape = active_tape()
            if tape is not None:
                tape.append(self)

    @property
    def parameters(self) -> list:
        """Return the list of numeric parameters for this operation.

        Uses the declarative ``_param_names`` tuple to collect parameter
        values in a canonical order. Non-parametrized gates return an
        empty list.

        Returns:
            List of parameter values (floats or JAX arrays).
        """
        return [getattr(self, name) for name in self._param_names]

    def __repr__(self) -> str:
        """Return a human-readable representation of this operation.

        Returns:
            A string like ``"RX(0.5000, wires=[0])"`` or ``"CX(wires=[0, 1])"``.
        """
        params = self.parameters
        if params:
            param_str = ", ".join(
                (
                    f"{float(v):.4f}"
                    if isinstance(v, (float, np.floating, jnp.ndarray))
                    else str(v)
                )
                for v in params
            )
            return f"{self.name}({param_str}, wires={self.wires})"
        return f"{self.name}(wires={self.wires})"

    @property
    def matrix(self) -> jnp.ndarray:
        """Return the base matrix of this operation (before lifting).

        Returns:
            The gate matrix as a JAX array.

        Raises:
            NotImplementedError: If the subclass has not defined ``_matrix``.
        """
        if self._matrix is None:
            raise NotImplementedError(
                f"{self.__class__.__name__} does not define a matrix."
            )
        return self._matrix

    @property
    def wires(self) -> List[int]:
        """Qubit indices this operation acts on.

        Returns:
            List of integer qubit indices.
        """
        return self._wires

    @wires.setter
    def wires(self, wires: Union[int, List[int]]) -> None:
        """Set the qubit indices for this operation.

        Args:
            wires: A single qubit index or a list of qubit indices.
        """
        if isinstance(wires, (list, tuple)):
            self._wires = list(wires)
        else:
            self._wires = [wires]

    @property
    def input_idx(self) -> int:
        """The index of an input

        Returns:
            input_idx: Index of the input
        """
        return self._input_idx

    @input_idx.setter
    def input_idx(self, input_idx: int) -> None:
        """Setter for the input_idx flag

        Args:
            input_idx: Index of the input
        """
        self._input_idx = input_idx

    def _update_tape_operation(self, op: "Operation") -> None:
        """Record the derived operation ``op`` on the active tape in place of ``self``.

        Used by :meth:`dagger`, :meth:`power` and ``*``, which build ``op``
        with ``record=False``. If ``self`` is the most recent entry on the
        active tape (the typical case when chaining, e.g.
        ``Gate(...).dagger()``), that entry is replaced by ``op``, so only the
        derived operation appears on the tape, not both. Otherwise ``op`` is
        appended. Call this immediately after building ``op``, before any
        other operation is recorded.

        Args:
            op (Operation): The operation that should appear on the tape.
        """
        tape = active_tape()
        if tape is not None:
            if tape and tape[-1] is self:
                tape[-1] = op
            else:
                tape.append(op)

    def dagger(self) -> "Operation":
        """Return a new operation, the conjugate transpose (``U\\dagger``).

        Usage inside a circuit function::

            RX(0.5, wires=0).dagger()

        Returns:
            A new :class:`Operation` with matrix ``U\\dagger`` acting on the same wires.

        Raises:
            NotImplementedError: If this operation does not define a matrix.
        """
        # Go through ``matrix`` so a matrix-less op raises a clear error
        mat = jnp.conj(self.matrix).T
        op = Operation(wires=self.wires, matrix=mat, record=False)

        self._update_tape_operation(op)

        return op

    def power(self, power) -> "Operation":
        """Return a new operation, the integer matrix power (``U^power``).

        Usage inside a circuit function::

            PauliX(wires=0).power(2)

        Args:
            power: Integer exponent passed to ``jnp.linalg.matrix_power``.

        Returns:
            A new :class:`Operation` with matrix ``U^power`` acting on the same wires.

        Raises:
            NotImplementedError: If this operation does not define a matrix.
        """
        # TODO: support fractional powers
        mat = jnp.linalg.matrix_power(self.matrix, power)
        op = Operation(wires=self.wires, matrix=mat, record=False)

        self._update_tape_operation(op)

        return op

    def __mul__(self, factor: float) -> "Operation":
        """Return a new operation, the product between U and a scalar (``U*x``).

        Usage inside a circuit function::

            PauliX(wires=0) * x

        Returns:
            A new :class:`Operation` with matrix ``U*x`` acting on the same wires.

        Raises:
            NotImplementedError: If this operation does not define a matrix.
        """
        mat = factor * self.matrix
        op = Operation(wires=self.wires, matrix=mat, record=False)

        self._update_tape_operation(op)

        return op

    # Also overwrite * for right operands
    __rmul__ = __mul__

    def __add__(self, other: "Operation") -> "Operation":
        """Element-wise addition of two operations on the same wires.

        The wire *order* may differ (e.g. ``[0, 1]`` vs ``[1, 0]``); the
        second operand's matrix is then permuted into ``self``'s wire order
        before adding, so both matrices refer to the same qubit ordering.

        Returns:
            A new :class:`Operation` on ``self.wires`` whose matrix is the sum
            of both matrices.

        Raises:
            ValueError: If the wire sets differ.
        """
        if sorted(self.wires) != sorted(other.wires):
            raise ValueError(
                f"Can only add operations acting on the same set of wires, "
                f"got {self.wires} and {other.wires}"
            )

        other_matrix = other.matrix
        if other.wires != self.wires:
            # Re-index other's (out..., in...) tensor so axis j refers to self.wires[j]
            k = len(self.wires)
            perm = [other.wires.index(w) for w in self.wires]
            other_matrix = (
                other_matrix.reshape((2,) * 2 * k)
                .transpose(perm + [p + k for p in perm])
                .reshape(2**k, 2**k)
            )

        op = Operation(
            wires=self.wires,
            matrix=self.matrix + other_matrix,
            record=False,
        )
        return op

    def __matmul__(self, other: "Operation") -> "Operation":
        """Tensor (Kronecker) product of two operations.

        The resulting operation acts on the union of both wire sets and
        carries the Kronecker product of both matrices. Wire sets must
        be disjoint.

        Returns:
            A new :class:`Operation` whose matrix is ``self.matrix ⊗ other.matrix``
            and whose wires are the concatenation of both wire lists.

        Raises:
            ValueError: If the two operations share any wires.
        """
        if set(self.wires) & set(other.wires):
            raise ValueError(
                f"Cannot take tensor product: overlapping wires "
                f"{self.wires} and {other.wires}"
            )
        new_matrix = jnp.kron(self.matrix, other.matrix)
        new_wires = self.wires + other.wires
        op = Operation(wires=new_wires, matrix=new_matrix, record=False)
        return op

    def lifted_matrix(self, n_qubits: int) -> jnp.ndarray:
        """Return the full ``2**n x 2**n`` matrix embedding this gate.

        Embeds the ``k``-qubit gate matrix into the ``n``-qubit Hilbert space
        by applying it to the identity matrix via :meth:`apply_to_state`.
        This is useful for computing ``Tr(O·\\rho )`` directly without vmap.

        Args:
            n_qubits: Total number of qubits in the circuit.

        Returns:
            The ``(2**n, 2**n)`` matrix of this operation in the full space.
        """
        dim = 2**n_qubits
        # Row i of the vmap output is U e_i (= column i of U); transpose back.
        return jax.vmap(lambda col: self.apply_to_state(col, n_qubits))(
            jnp.eye(dim, dtype=_cdtype())
        ).T

    def apply_to_state(self, state: jnp.ndarray, n_qubits: int) -> jnp.ndarray:
        """Apply this gate to a statevector via tensor contraction.

        The statevector (shape ``(2**n,)``) is reshaped into a rank-n tensor
        of shape ``(2,)*n``. The gate (shape ``(2**k, 2**k)``) is reshaped to
        ``(2,)*2k`` and contracted against the k target wire axes.

        Memory footprint is O(2**n) and the operation supports arbitrary k.
        The implementation is fully differentiable through JAX.

        Args:
            state: Statevector of shape ``(2**n_qubits,)``.
            n_qubits: Total number of qubits in the circuit.

        Returns:
            Updated statevector of shape ``(2**n_qubits,)``.
        """
        k = len(self.wires)
        gate_tensor = self._gate_tensor(k)
        psi = state.reshape((2,) * n_qubits)
        psi_out = _contract_and_restore(psi, gate_tensor, k, self.wires)
        return psi_out.reshape(2**n_qubits)

    def apply_to_state_tensor(self, psi: jnp.ndarray, n_qubits: int) -> jnp.ndarray:
        """Apply this gate to a statevector already in tensor form.

        Like :meth:`apply_to_state` but expects the state in rank-n tensor
        form ``(2,)*n`` and returns the result in the same form. This avoids
        the ``reshape`` calls at the per-gate level when the simulation loop
        keeps the state in tensor form throughout.

        Args:
            psi: Statevector tensor of shape ``(2,)*n_qubits``.
            n_qubits: Total number of qubits in the circuit.

        Returns:
            Updated statevector tensor of shape ``(2,)*n_qubits``.
        """
        k = len(self.wires)
        gate_tensor = self._gate_tensor(k)
        return _contract_and_restore(psi, gate_tensor, k, self.wires)

    def _gate_tensor(self, k: int) -> jnp.ndarray:
        """Return the gate matrix reshaped to ``(2,)*2k`` tensor form.

        The result is cached on the instance so repeated calls (e.g. from
        density-matrix simulation which applies U and U*) avoid redundant
        reshape dispatch.

        Args:
            k: Number of qubits the gate acts on.

        Returns:
            Gate matrix as a rank-2k tensor of shape ``(2,)*2k``.
        """
        cached = getattr(self, "_cached_gate_tensor", None)
        if cached is not None:
            return cached
        gt = self.matrix.reshape((2,) * 2 * k)
        # Only cache for non-parametrized gates (whose matrix is a class attr)
        if self._matrix is self.__class__._matrix:
            object.__setattr__(self, "_cached_gate_tensor", gt)
        return gt

    def apply_to_density(self, rho: jnp.ndarray, n_qubits: int) -> jnp.ndarray:
        """Apply this gate to a density matrix via \\rho -> U\\rho U\\dagger.

        The density matrix (shape ``(2**n, 2**n)``) is treated as a rank-*2n*
        tensor with n "ket" axes (0..n-1) and n "bra" axes (n..2n-1).
        U acts on the ket half; U* acts on the bra half. Both contractions
        use the shared :func:`_contract_and_restore` helper, keeping the
        operation allocation-free with respect to building full unitaries.

        Args:
            rho: Density matrix of shape ``(2**n_qubits, 2**n_qubits)``.
            n_qubits: Total number of qubits in the circuit.

        Returns:
            Updated density matrix of shape ``(2**n_qubits, 2**n_qubits)``.
        """
        k = len(self.wires)
        U = self._gate_tensor(k)
        U_conj = jnp.conj(U)

        rho_t = rho.reshape((2,) * 2 * n_qubits)

        # Apply U to ket axes, U\\dagger to bra axes
        rho_t = _contract_and_restore(rho_t, U, k, self.wires)
        bra_wires = [w + n_qubits for w in self.wires]
        rho_t = _contract_and_restore(rho_t, U_conj, k, bra_wires)

        return rho_t.reshape(2**n_qubits, 2**n_qubits)

473 

474 

class Hermitian(Operation):
    """Observable or gate defined by an arbitrary user-supplied matrix.

    The matrix is converted to the active complex dtype on construction.

    Example:
        >>> obs = Hermitian(matrix=my_matrix, wires=0)
    """

    def __init__(
        self,
        matrix: jnp.ndarray,
        wires: Union[int, List[int]] = 0,
        record: bool = True,
    ) -> None:
        """Wrap ``matrix`` as an operation on ``wires``.

        Args:
            matrix: Matrix defining the operator; cast with ``jnp.asarray`` to
                the complex dtype returned by :func:`_cdtype`.
            wires: Qubit index or list of qubit indices the operator acts on.
            record: When ``True`` (default) the operator is recorded on the
                active tape; pass ``False`` when it only serves as a
                Hamiltonian term (e.g. for time-dependent evolution).
        """
        complex_matrix = jnp.asarray(matrix, dtype=_cdtype())
        super().__init__(wires=wires, matrix=complex_matrix, record=record)

    def __rmul__(self, coeff_fn):
        """Build ``coeff_fn * Hermitian`` as a :class:`ParametrizedHamiltonian`.

        Args:
            coeff_fn: Callable ``(params, t) -> scalar`` providing the
                time-dependent prefactor.

        Returns:
            A :class:`ParametrizedHamiltonian` combining *coeff_fn* with this
            operator's matrix and wires.

        Raises:
            TypeError: When *coeff_fn* is not callable.
        """
        if callable(coeff_fn):
            return ParametrizedHamiltonian(coeff_fn, self.matrix, self.wires)
        raise TypeError(
            f"Left operand of `* Hermitian` must be callable, got {type(coeff_fn)}"
        )

522 

523 

class ParametrizedHamiltonian:
    """Time-dependent Hamiltonian of the form ``H(t) = f(params, t) · H_mat``.

    Usually obtained by left-multiplying a :class:`Hermitian` operator with
    a callable coefficient::

        def coeff(p, t):
            return p[0] * jnp.exp(-0.5 * ((t - t_c) / p[1]) ** 2)

        H_td = coeff * Hermitian(matrix=sigma_x, wires=0)

    and then passed to :func:`evolve`::

        evolve(H_td)(coeff_args=[A, sigma], T=1.0)

    Attributes:
        coeff_fn: Callable ``(params, t) -> scalar``.
        H_mat: Static Hermitian matrix (JAX array).
        wires: Qubit wire(s) this Hamiltonian acts on, stored exactly as given.
    """

    def __init__(
        self,
        coeff_fn: Callable,
        H_mat: jnp.ndarray,
        wires: Union[int, List[int]],
    ) -> None:
        """Store the coefficient function, static matrix and target wires.

        Args:
            coeff_fn: Callable ``(params, t) -> scalar`` giving the prefactor.
            H_mat: The time-independent matrix multiplied by the prefactor.
            wires: Qubit wire(s) the Hamiltonian acts on (not normalised).
        """
        self.coeff_fn, self.H_mat, self.wires = coeff_fn, H_mat, wires

554 

555 

class Id(Operation):
    """Identity gate on any number of wires.

    With a single wire the class-level 2x2 identity is used; with *k* > 1
    wires a ``2**k x 2**k`` identity is built at construction time.
    """

    _matrix = jnp.eye(2, dtype=_cdtype())
    _num_wires = None  # accept any number of wires

    def __init__(self, wires: Union[int, List[int]] = 0, **kwargs) -> None:
        """Create an identity acting on ``wires``.

        Args:
            wires: Qubit index or list/tuple of qubit indices. For more than
                one wire the matrix is replaced by the matching
                ``2**k × 2**k`` identity (overriding any ``matrix`` kwarg).
            **kwargs: Forwarded to :class:`Operation`.
        """
        n_wires = len(wires) if isinstance(wires, (list, tuple)) else 1
        if n_wires > 1:
            # The 2x2 class default only fits one wire.
            kwargs["matrix"] = jnp.eye(2**n_wires, dtype=_cdtype())
        super().__init__(wires=wires, **kwargs)

580 

581 

class PauliX(Operation):
    """Pauli-X (bit-flip) gate, also usable as the \\sigma_x observable."""

    _num_wires = 1
    _matrix = jnp.array([[0, 1], [1, 0]], dtype=_cdtype())

    def __init__(self, wires: Union[int, List[int]] = 0, **kwargs) -> None:
        """Create a Pauli-X on a single qubit.

        Args:
            wires: Target qubit index (or a one-element list).
            **kwargs: Passed unchanged to :class:`Operation`.
        """
        super().__init__(wires, **kwargs)

595 

596 

class PauliY(Operation):
    """Pauli-Y gate, also usable as the \\sigma_y observable."""

    _num_wires = 1
    _matrix = jnp.array([[0, -1j], [1j, 0]], dtype=_cdtype())

    def __init__(self, wires: Union[int, List[int]] = 0, **kwargs) -> None:
        """Create a Pauli-Y on a single qubit.

        Args:
            wires: Target qubit index (or a one-element list).
            **kwargs: Passed unchanged to :class:`Operation`.
        """
        super().__init__(wires, **kwargs)

610 

611 

class PauliZ(Operation):
    """Pauli-Z (phase-flip) gate, also usable as the \\sigma_z observable."""

    _num_wires = 1
    _matrix = jnp.array([[1, 0], [0, -1]], dtype=_cdtype())

    def __init__(self, wires: Union[int, List[int]] = 0, **kwargs) -> None:
        """Create a Pauli-Z on a single qubit.

        Args:
            wires: Target qubit index (or a one-element list).
            **kwargs: Passed unchanged to :class:`Operation`.
        """
        super().__init__(wires, **kwargs)

625 

626 

class H(Operation):
    """Single-qubit Hadamard gate, ``(1/\\sqrt 2) [[1, 1], [1, -1]]``."""

    _num_wires = 1
    _matrix = jnp.array([[1, 1], [1, -1]], dtype=_cdtype()) / jnp.sqrt(2)

    def __init__(self, wires: Union[int, List[int]] = 0, **kwargs) -> None:
        """Create a Hadamard on a single qubit.

        Args:
            wires: Target qubit index (or a one-element list).
            **kwargs: Passed unchanged to :class:`Operation`.
        """
        super().__init__(wires, **kwargs)

640 

641 

class S(Operation):
    """S (phase) gate — a Clifford gate equal to \\sqrt Z.

    .. math::
        S = \\begin{pmatrix}1 & 0 \\\\ 0 & i\\end{pmatrix}
    """

    _matrix = jnp.array([[1, 0], [0, 1j]], dtype=_cdtype())
    _num_wires = 1

    def __init__(self, wires: Union[int, List[int]] = 0, **kwargs) -> None:
        """Initialise an S gate.

        Args:
            wires: Qubit index or list of qubit indices this gate acts on.
            **kwargs: Forwarded to :class:`Operation` (e.g. ``record``,
                ``input_idx``, ``name``), consistent with the other fixed
                single-qubit gates.
        """
        super().__init__(wires=wires, **kwargs)

659 

660 

class SWAP(Operation):
    """SWAP gate: exchanges the states of its two wires."""

    _matrix = jnp.array(
        [[1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1]], dtype=_cdtype()
    )
    _num_wires = 2

    def __init__(
        self, wires: Union[int, List[int], Tuple[int, ...]] = (0, 1), **kwargs
    ) -> None:
        """Initialise a SWAP gate.

        Args:
            wires: The two qubit indices to swap. Defaults to ``(0, 1)``;
                the previous default of a single wire could never pass the
                two-wire check in :class:`Operation`.
            **kwargs: Forwarded to :class:`Operation`.
        """
        super().__init__(wires=wires, **kwargs)

676 

677 

class RandomUnitary(Operation):
    """Creates a random hermitian matrix and applies it as a gate.

    A complex Gaussian matrix ``A`` is drawn and symmetrised to
    ``H = (A + A^\\dagger) / 2``, then rescaled to Frobenius norm *scale*.

    NOTE(review): despite the class name, the matrix used as the gate is
    Hermitian and in general *not* unitary. Confirm whether callers expect
    ``exp(-iH)`` instead.
    """

    def __init__(self, wires, key, scale=1.0, record=True):
        """Initialise a random Hermitian gate.

        Args:
            wires: Qubit index or list of qubit indices this gate acts on.
            key: ``jax.random.PRNGKey`` used to draw the random matrix.
            scale: Frobenius norm of the resulting matrix (default: 1.0).
            record: If ``True`` (default), record on the active tape.
        """
        # Accept a bare int like every other gate (len() would fail on it).
        n_wires = len(wires) if isinstance(wires, (list, tuple)) else 1
        dim = 2**n_wires
        key_a, key_b = jax.random.split(key)

        A = (
            jax.random.normal(key=key_a, shape=(dim, dim))
            + 1j * jax.random.normal(key=key_b, shape=(dim, dim))
        ).astype(_cdtype())
        herm = (A + A.conj().T) / 2.0

        herm *= scale / jnp.linalg.norm(herm, ord="fro")

        super().__init__(wires, matrix=herm, record=record)

701 

702 

class Barrier(Operation):
    """Visual separator that leaves the simulated state untouched.

    A barrier carries no matrix. It is still recorded on the tape so that
    drawing backends can render a divider between circuit sections.
    """

    _matrix = None  # not a real gate

    def __init__(self, wires: Union[int, List[int]] = 0) -> None:
        """Create a barrier across ``wires``.

        Args:
            wires: Qubit index or list of qubit indices the barrier spans.
        """
        super().__init__(wires=wires)

    def apply_to_state(self, state: jnp.ndarray, n_qubits: int) -> jnp.ndarray:
        """Identity action on a flat statevector: hand ``state`` back as-is."""
        return state

    def apply_to_state_tensor(self, psi: jnp.ndarray, n_qubits: int) -> jnp.ndarray:
        """Identity action on a tensor-form statevector: hand ``psi`` back as-is."""
        return psi

    def apply_to_density(self, rho: jnp.ndarray, n_qubits: int) -> jnp.ndarray:
        """Identity action on a density matrix: hand ``rho`` back as-is."""
        return rho

731 

732 

def _make_rotation_gate(pauli_class: type, name: str) -> type:
    """Create a single-qubit Pauli rotation gate class (RX, RY or RZ).

    The generated gate's matrix is
    ``R_P(\\theta) = cos(\\theta/2) I - i sin(\\theta/2) P``, where ``P`` is
    the class-level matrix of *pauli_class*.

    Args:
        pauli_class: One of PauliX, PauliY, PauliZ.
        name: Class name for the generated gate (e.g. ``"RX"``); its second
            character is used as the axis label in the docstring.

    Returns:
        A new :class:`Operation` subclass named *name*.
    """
    pauli_mat = pauli_class._matrix
    axis = name[1]

    class _RotationGate(Operation):
        # Docstring is generated so each axis gets its own text.
        __doc__ = (
            f"Rotation around the {axis} axis: {name}(\\theta) =\n"
            f"exp(-i \\theta/2 {axis}).\n"
        )
        _num_wires = 1
        _param_names = ("theta",)

        def __init__(
            self, theta: float, wires: Union[int, List[int]] = 0, **kwargs
        ) -> None:
            self.theta = theta
            half = theta / 2
            mat = jnp.cos(half) * Id._matrix - 1j * jnp.sin(half) * pauli_mat
            super().__init__(wires=wires, matrix=mat, **kwargs)

        def generator(self) -> Operation:
            """Return the generator as the corresponding Pauli operation."""
            return pauli_class(wires=self.wires[0], record=False)

    _RotationGate.__name__ = _RotationGate.__qualname__ = name
    return _RotationGate


RX, RY, RZ = (
    _make_rotation_gate(pauli, label)
    for pauli, label in ((PauliX, "RX"), (PauliY, "RY"), (PauliZ, "RZ"))
)

777 

778 

# Projectors |0><0| and |1><1| used by controlled-gate factories
_P0 = jnp.array([[1, 0], [0, 0]], dtype=_cdtype())
_P1 = jnp.array([[0, 0], [0, 1]], dtype=_cdtype())


def _make_controlled_gate(target_class: type, name: str) -> type:
    """Factory for controlled Pauli gates CX, CY, CZ.

    Each gate has the form
    ``CP = |0><0| \\otimes I + |1><1| \\otimes P``.

    Args:
        target_class: The single-qubit gate class (PauliX, PauliY, PauliZ).
        name: Class name for the generated gate (e.g. ``"CX"``).

    Returns:
        A new :class:`Operation` subclass.
    """
    target_mat = target_class._matrix

    class _ControlledGate(Operation):
        __doc__ = (
            f"Controlled-{target_class.__name__[5:]} gate.\n\n"
            f"Applies {target_class.__name__} on the target qubit conditioned "
            f"on the control qubit being in state |1\\rangle."
        )
        _matrix = jnp.kron(_P0, Id._matrix) + jnp.kron(_P1, target_mat)
        _num_wires = 2
        is_controlled = True

        def __init__(
            self, wires: Union[List[int], Tuple[int, ...]] = (0, 1), **kwargs
        ) -> None:
            """Initialise the controlled gate.

            Args:
                wires: ``[control, target]`` qubit indices. The default is a
                    tuple to avoid a shared mutable default; it is copied to
                    a list by :class:`Operation`.
                **kwargs: Forwarded to :class:`Operation`.
            """
            super().__init__(wires=wires, **kwargs)

    _ControlledGate.__name__ = name
    _ControlledGate.__qualname__ = name
    return _ControlledGate


CX = _make_controlled_gate(PauliX, "CX")
CY = _make_controlled_gate(PauliY, "CY")
CZ = _make_controlled_gate(PauliZ, "CZ")

820 

821 

class CCX(Operation):
    """Toffoli (CCX) gate.

    The 3-qubit Toffoli gate exercises the arbitrary-k-qubit path in
    :meth:`~Operation.apply_to_state` and cannot be expressed as a pair of
    2-qubit gates without ancilla, making it a good stress-test for the
    simulator.
    """

    _matrix = jnp.array(
        [
            [1, 0, 0, 0, 0, 0, 0, 0],
            [0, 1, 0, 0, 0, 0, 0, 0],
            [0, 0, 1, 0, 0, 0, 0, 0],
            [0, 0, 0, 1, 0, 0, 0, 0],
            [0, 0, 0, 0, 1, 0, 0, 0],
            [0, 0, 0, 0, 0, 1, 0, 0],
            [0, 0, 0, 0, 0, 0, 0, 1],
            [0, 0, 0, 0, 0, 0, 1, 0],
        ],
        dtype=_cdtype(),
    )
    is_controlled = True
    _num_wires = 3

    def __init__(
        self, wires: Union[List[int], Tuple[int, ...]] = (0, 1, 2), **kwargs
    ) -> None:
        """Initialise a Toffoli (CCX) gate.

        Args:
            wires: Three-element sequence ``[control0, control1, target]``.
                The default is a tuple to avoid a shared mutable default;
                it is copied to a list by :class:`Operation`.
            **kwargs: Forwarded to :class:`Operation`.
        """
        super().__init__(wires=wires, **kwargs)

854 

855 

class CSWAP(Operation):
    """Controlled-SWAP (Fredkin) gate.

    Swaps the two target qubits conditioned on the control qubit being
    |1\\rangle. Wires are given as ``[control, target0, target1]``.
    """

    _matrix = jnp.array(
        [
            [1, 0, 0, 0, 0, 0, 0, 0],
            [0, 1, 0, 0, 0, 0, 0, 0],
            [0, 0, 1, 0, 0, 0, 0, 0],
            [0, 0, 0, 1, 0, 0, 0, 0],
            [0, 0, 0, 0, 1, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 1, 0],
            [0, 0, 0, 0, 0, 1, 0, 0],
            [0, 0, 0, 0, 0, 0, 0, 1],
        ],
        dtype=_cdtype(),
    )
    is_controlled = True
    _num_wires = 3

    def __init__(
        self, wires: Union[List[int], Tuple[int, ...]] = (0, 1, 2), **kwargs
    ) -> None:
        """Initialise a Controlled-SWAP (Fredkin) gate.

        Args:
            wires: Three-element sequence ``[control, target0, target1]``.
                The default is a tuple to avoid a shared mutable default;
                it is copied to a list by :class:`Operation`.
            **kwargs: Forwarded to :class:`Operation`.
        """
        super().__init__(wires=wires, **kwargs)

888 

889 

def _make_controlled_rotation_gate(pauli_class: type, name: str) -> type:
    """Factory for controlled rotation gates CRX, CRY, CRZ.

    Each gate has the form
    ``CR_P(\\theta) = |0><0| \\otimes I + |1><1| \\otimes R_P(\\theta)``.

    Args:
        pauli_class: One of PauliX, PauliY, PauliZ.
        name: Class name for the generated gate (e.g. ``"CRX"``).

    Returns:
        A new :class:`Operation` subclass.
    """
    pauli_mat = pauli_class._matrix

    class _CRotationGate(Operation):
        __doc__ = (
            f"Controlled rotation around the {name[2]} axis.\n\n"
            f"Applies R{name[2]}(\\theta) on the target qubit conditioned on the "
            f"control qubit being in state |1\\rangle.\n\n"
            f".. math::\n"
            f"{name}(\\theta) = |0\\rangle\\langle 0| \\otimes I\n"
            f" + |1\\rangle\\langle 1| \\otimes R{name[2]}(\\theta)"
        )
        _num_wires = 2
        _param_names = ("theta",)
        is_controlled = True

        def __init__(
            self,
            theta: float,
            wires: Union[List[int], Tuple[int, ...]] = (0, 1),
            **kwargs,
        ) -> None:
            """Initialise the controlled rotation.

            Args:
                theta: Rotation angle in radians.
                wires: ``[control, target]`` qubit indices. The default is a
                    tuple to avoid a shared mutable default; it is copied to
                    a list by :class:`Operation`.
                **kwargs: Forwarded to :class:`Operation`.
            """
            self.theta = theta
            c = jnp.cos(theta / 2)
            s = jnp.sin(theta / 2)
            rot = c * Id._matrix - 1j * s * pauli_mat
            mat = jnp.kron(_P0, Id._matrix) + jnp.kron(_P1, rot)
            super().__init__(wires=wires, matrix=mat, **kwargs)

    _CRotationGate.__name__ = name
    _CRotationGate.__qualname__ = name
    return _CRotationGate


CRX = _make_controlled_rotation_gate(PauliX, "CRX")
CRY = _make_controlled_rotation_gate(PauliY, "CRY")
CRZ = _make_controlled_rotation_gate(PauliZ, "CRZ")

934 

935 

class Rot(Operation):
    """General single-qubit rotation:
    Rot(\\phi, \\theta, \\omega) = RZ(\\omega) RY(\\theta) RZ(\\phi).

    This is the most general SU(2) rotation (up to a global phase). It is
    built from three successive axis rotations and has three free angles.
    """

    _num_wires = 1
    _param_names = ("phi", "theta", "omega")

    def __init__(
        self,
        phi: float,
        theta: float,
        omega: float,
        wires: Union[int, List[int]] = 0,
        **kwargs,
    ) -> None:
        """Create the composite rotation on one qubit.

        Args:
            phi: Angle of the first (rightmost) RZ rotation, in radians.
            theta: Angle of the middle RY rotation, in radians.
            omega: Angle of the last (leftmost) RZ rotation, in radians.
            wires: Qubit index or list of qubit indices this gate acts on.
            **kwargs: Forwarded to :class:`Operation`.
        """
        self.phi, self.theta, self.omega = phi, theta, omega

        def _axis_rotation(angle, pauli):
            # cos(a/2) I - i sin(a/2) P, i.e. exp(-i a/2 P)
            return jnp.cos(angle / 2) * Id._matrix - 1j * jnp.sin(angle / 2) * pauli

        first = _axis_rotation(phi, PauliZ._matrix)
        middle = _axis_rotation(theta, PauliY._matrix)
        last = _axis_rotation(omega, PauliZ._matrix)
        super().__init__(wires=wires, matrix=last @ middle @ first, **kwargs)

976 

977 

class PauliRot(Operation):
    r"""Multi-qubit Pauli rotation :math:`\exp(-i\,\theta P / 2)` for a Pauli word *P*.

    The Pauli word is a string of ``'I'``, ``'X'``, ``'Y'``, ``'Z'``
    characters (one per qubit). Since :math:`P^2 = I`, the rotation matrix is
    ``cos(theta/2) I - i sin(theta/2) P`` where *P* is the Kronecker product
    of the single-qubit Pauli matrices named by the word (first character is
    the leftmost Kronecker factor).

    Example::

        PauliRot(0.5, "XY", wires=[0, 1])
    """

    _param_names = ("theta",)

    # Map from Pauli-word character to its 2x2 matrix.
    _PAULI_MAP = {
        "I": Id._matrix,
        "X": PauliX._matrix,
        "Y": PauliY._matrix,
        "Z": PauliZ._matrix,
    }

    @classmethod
    def _pauli_word_matrix(cls, pauli_word: str) -> jnp.ndarray:
        """Return the Kronecker product of the Pauli matrices named by *pauli_word*.

        Shared by :meth:`__init__` and :meth:`generator` so both build the
        exact same operator.

        Args:
            pauli_word: String of ``'I'``, ``'X'``, ``'Y'``, ``'Z'`` characters.

        Returns:
            A ``(2**len(pauli_word), 2**len(pauli_word))`` JAX array.

        Raises:
            ValueError: If *pauli_word* is empty or contains any character
                other than ``'I'``, ``'X'``, ``'Y'``, ``'Z'``.
        """
        from functools import reduce

        if not pauli_word:
            raise ValueError("pauli_word must contain at least one character.")
        invalid = sorted(set(pauli_word).difference(cls._PAULI_MAP))
        if invalid:
            raise ValueError(
                f"Invalid character(s) {invalid} in pauli_word {pauli_word!r}; "
                "expected only 'I', 'X', 'Y', 'Z'."
            )
        return reduce(jnp.kron, [cls._PAULI_MAP[c] for c in pauli_word])

    def __init__(
        self, theta: float, pauli_word: str, wires: Union[int, List[int]] = 0, **kwargs
    ) -> None:
        """Initialise a PauliRot gate.

        Args:
            theta: Rotation angle in radians.
            pauli_word: A string of ``'I'``, ``'X'``, ``'Y'``, ``'Z'``
                characters specifying the Pauli tensor product.
            wires: Qubit index or list of qubit indices this gate acts on.
            **kwargs: Forwarded unchanged to :class:`Operation`.

        Raises:
            ValueError: If *pauli_word* is empty or contains invalid characters.
        """
        self.theta = theta
        self.pauli_word = pauli_word

        P = self._pauli_word_matrix(pauli_word)
        dim = P.shape[0]
        mat = (
            jnp.cos(theta / 2) * jnp.eye(dim, dtype=_cdtype())
            - 1j * jnp.sin(theta / 2) * P
        )
        super().__init__(wires=wires, matrix=mat, **kwargs)

    def generator(self) -> Operation:
        """Return the generator Pauli tensor product as an :class:`Operation`.

        The generator of ``PauliRot(theta, word, wires)`` is the tensor product
        of single-qubit Pauli matrices specified by *word*.

        Returns:
            :class:`Hermitian` (not recorded on the tape) wrapping that matrix
            on this gate's wires.
        """
        P = self._pauli_word_matrix(self.pauli_word)
        return Hermitian(matrix=P, wires=self.wires, record=False)

1041 

1042 

class KrausChannel(Operation):
    r"""Base class for noise channels defined by a set of Kraus operators.

    A Kraus channel acts on a density matrix as

    .. math::
        \Phi(\rho) = \sum_k K_k \rho K_k^\dagger,

    the most general physical operation on a quantum state. A unitary gate is
    the special case of a single operator :math:`K_0 = U` with
    :math:`K_0^\dagger K_0 = I`; noisy channels use several operators.

    Subclasses must implement :meth:`kraus_matrices`, returning a list of JAX
    arrays. Since a channel can turn a pure state into a mixed one, the
    statevector entry points (:meth:`apply_to_state`,
    :meth:`apply_to_state_tensor`) and the :attr:`matrix` property raise
    ``TypeError``; use :meth:`apply_to_density` instead.
    """

    def kraus_matrices(self) -> List[jnp.ndarray]:
        """Return the list of Kraus operators for this channel.

        Returns:
            List of 2-D JAX arrays, each of shape ``(2**k, 2**k)`` where k
            is the number of target qubits.

        Raises:
            NotImplementedError: Subclasses must override this method.
        """
        raise NotImplementedError

    @property
    def matrix(self) -> jnp.ndarray:
        """Raise TypeError: a noise channel has no single unitary matrix.

        Raises:
            TypeError: Always raised; use :meth:`apply_to_density` instead.
        """
        raise TypeError(
            f"{self.__class__.__name__} is a noise channel and has no single "
            "unitary matrix. Use apply_to_density() instead."
        )

    def apply_to_state(self, state: jnp.ndarray, n_qubits: int) -> jnp.ndarray:
        """Raise TypeError: noise channels require density-matrix simulation.

        Args:
            state: Statevector (unused).
            n_qubits: Number of qubits (unused).

        Raises:
            TypeError: Always raised; use ``execute(type='density')`` instead.
        """
        raise TypeError(
            f"{self.__class__.__name__} is a noise channel and cannot be "
            "applied to a pure statevector. Use execute(type='density') instead."
        )

    def apply_to_state_tensor(self, psi: jnp.ndarray, n_qubits: int) -> jnp.ndarray:
        """Raise TypeError: noise channels require density-matrix simulation.

        Raises:
            TypeError: Always raised; use ``execute(type='density')`` instead.
        """
        raise TypeError(
            f"{self.__class__.__name__} is a noise channel and cannot be "
            "applied to a pure statevector. Use execute(type='density') instead."
        )

    def apply_to_density(self, rho: jnp.ndarray, n_qubits: int) -> jnp.ndarray:
        r"""Apply :math:`\Phi(\rho) = \sum_k K_k \rho K_k^\dagger` via tensor contraction.

        Uses the shared :func:`_contract_and_restore` helper for each Kraus
        operator and sums the resulting terms.

        Args:
            rho: Density matrix of shape ``(2**n_qubits, 2**n_qubits)``.
            n_qubits: Total number of qubits in the circuit.

        Returns:
            Updated density matrix of shape ``(2**n_qubits, 2**n_qubits)``.
        """
        k = len(self.wires)
        dim = 2**n_qubits
        # After reshaping, axes 0..n-1 come from the row (ket) index and axes
        # n..2n-1 from the column (bra) index, so bra axes are offset by n.
        bra_wires = [w + n_qubits for w in self.wires]
        # Loop-invariant: view rho as a rank-2n tensor once.
        rho_t = rho.reshape((2,) * 2 * n_qubits)
        rho_out = jnp.zeros_like(rho)

        for K in self.kraus_matrices():
            K_t = K.reshape((2,) * 2 * k)
            # K on the ket axes and conj(K) on the bra axes gives K rho K^dagger.
            term = _contract_and_restore(rho_t, K_t, k, self.wires)
            term = _contract_and_restore(term, jnp.conj(K_t), k, bra_wires)
            rho_out = rho_out + term.reshape(dim, dim)

        return rho_out

1132 

1133 

class BitFlip(KrausChannel):
    r"""Single-qubit bit-flip (Pauli-X) error channel.

    .. math::
        K_0 = \sqrt{1-p}\,I, \quad K_1 = \sqrt{p}\,X

    where :math:`p \in [0, 1]` is the probability of a bit flip.
    """

    _num_wires = 1
    _param_names = ("p",)

    def __init__(self, p: float, wires: Union[int, List[int]] = 0, **kwargs) -> None:
        """Initialise a bit-flip channel.

        Args:
            p: Bit-flip probability, must be in [0, 1].
            wires: Qubit index or list of qubit indices this channel acts on.
            **kwargs: Forwarded to the base :class:`Operation` initialiser
                (e.g. ``record=False``, as the gate classes in this module
                already allow).

        Raises:
            ValueError: If *p* is outside [0, 1].
        """
        if not 0.0 <= p <= 1.0:
            raise ValueError("p must be in [0, 1].")
        self.p = p
        super().__init__(wires=wires, **kwargs)

    def kraus_matrices(self) -> List[jnp.ndarray]:
        r"""Return the two Kraus operators for the bit-flip channel.

        Returns:
            List ``[K0, K1]`` with :math:`K_0 = \sqrt{1-p}\,I` and
            :math:`K_1 = \sqrt{p}\,X`.
        """
        p = self.p
        K0 = jnp.sqrt(1 - p) * Id._matrix
        K1 = jnp.sqrt(p) * PauliX._matrix
        return [K0, K1]

1171 

1172 

class PhaseFlip(KrausChannel):
    r"""Single-qubit phase-flip (Pauli-Z) error channel.

    .. math::
        K_0 = \sqrt{1-p}\,I, \quad K_1 = \sqrt{p}\,Z

    where :math:`p \in [0, 1]` is the probability of a phase flip.
    """

    _num_wires = 1
    _param_names = ("p",)

    def __init__(self, p: float, wires: Union[int, List[int]] = 0, **kwargs) -> None:
        """Initialise a phase-flip channel.

        Args:
            p: Phase-flip probability, must be in [0, 1].
            wires: Qubit index or list of qubit indices this channel acts on.
            **kwargs: Forwarded to the base :class:`Operation` initialiser
                (e.g. ``record=False``, as the gate classes in this module
                already allow).

        Raises:
            ValueError: If *p* is outside [0, 1].
        """
        if not 0.0 <= p <= 1.0:
            raise ValueError("p must be in [0, 1].")
        self.p = p
        super().__init__(wires=wires, **kwargs)

    def kraus_matrices(self) -> List[jnp.ndarray]:
        r"""Return the two Kraus operators for the phase-flip channel.

        Returns:
            List ``[K0, K1]`` with :math:`K_0 = \sqrt{1-p}\,I` and
            :math:`K_1 = \sqrt{p}\,Z`.
        """
        p = self.p
        K0 = jnp.sqrt(1 - p) * Id._matrix
        K1 = jnp.sqrt(p) * PauliZ._matrix
        return [K0, K1]

1210 

1211 

class DepolarizingChannel(KrausChannel):
    r"""Single-qubit depolarizing channel.

    .. math::
        K_0 = \sqrt{1-p}\,I,\quad K_1 = \sqrt{p/3}\,X,\quad
        K_2 = \sqrt{p/3}\,Y,\quad K_3 = \sqrt{p/3}\,Z

    where :math:`p \in [0, 1]`. At :math:`p = 3/4` the channel is fully
    depolarizing (every input is mapped to :math:`I/2`).
    """

    _num_wires = 1
    _param_names = ("p",)

    def __init__(self, p: float, wires: Union[int, List[int]] = 0, **kwargs) -> None:
        """Initialise a depolarizing channel.

        Args:
            p: Depolarization probability, must be in [0, 1].
            wires: Qubit index or list of qubit indices this channel acts on.
            **kwargs: Forwarded to the base :class:`Operation` initialiser
                (e.g. ``record=False``, as the gate classes in this module
                already allow).

        Raises:
            ValueError: If *p* is outside [0, 1].
        """
        if not 0.0 <= p <= 1.0:
            raise ValueError("p must be in [0, 1].")
        self.p = p
        super().__init__(wires=wires, **kwargs)

    def kraus_matrices(self) -> List[jnp.ndarray]:
        """Return the four Kraus operators for the depolarizing channel.

        Returns:
            List ``[K0, K1, K2, K3]`` corresponding to the I, X, Y, Z components.
        """
        p = self.p
        K0 = jnp.sqrt(1 - p) * Id._matrix
        K1 = jnp.sqrt(p / 3) * PauliX._matrix
        K2 = jnp.sqrt(p / 3) * PauliY._matrix
        K3 = jnp.sqrt(p / 3) * PauliZ._matrix
        return [K0, K1, K2, K3]

1252 

1253 

class AmplitudeDamping(KrausChannel):
    r"""Single-qubit amplitude damping channel.

    .. math::
        K_0 = \begin{pmatrix}1 & 0\\ 0 & \sqrt{1-\gamma}\end{pmatrix},\quad
        K_1 = \begin{pmatrix}0 & \sqrt{\gamma}\\ 0 & 0\end{pmatrix}

    where :math:`\gamma \in [0, 1]` is the probability of energy loss
    (:math:`|1\rangle \to |0\rangle`).
    """

    _num_wires = 1
    _param_names = ("gamma",)

    def __init__(
        self, gamma: float, wires: Union[int, List[int]] = 0, **kwargs
    ) -> None:
        """Initialise an amplitude damping channel.

        Args:
            gamma: Energy-loss probability, must be in [0, 1].
            wires: Qubit index or list of qubit indices this channel acts on.
            **kwargs: Forwarded to the base :class:`Operation` initialiser
                (e.g. ``record=False``, as the gate classes in this module
                already allow).

        Raises:
            ValueError: If *gamma* is outside [0, 1].
        """
        if not 0.0 <= gamma <= 1.0:
            raise ValueError("gamma must be in [0, 1].")
        self.gamma = gamma
        super().__init__(wires=wires, **kwargs)

    def kraus_matrices(self) -> List[jnp.ndarray]:
        """Return the two Kraus operators for the amplitude damping channel.

        Returns:
            List ``[K0, K1]`` as defined in the class docstring.
        """
        g = self.gamma
        K0 = jnp.array([[1.0, 0.0], [0.0, jnp.sqrt(1 - g)]], dtype=_cdtype())
        K1 = jnp.array([[0.0, jnp.sqrt(g)], [0.0, 0.0]], dtype=_cdtype())
        return [K0, K1]

1293 

1294 

class PhaseDamping(KrausChannel):
    r"""Single-qubit phase damping (dephasing) channel.

    .. math::
        K_0 = \begin{pmatrix}1 & 0\\ 0 & \sqrt{1-\gamma}\end{pmatrix},\quad
        K_1 = \begin{pmatrix}0 & 0\\ 0 & \sqrt{\gamma}\end{pmatrix}

    where :math:`\gamma \in [0, 1]` is the phase damping probability.
    """

    _num_wires = 1
    _param_names = ("gamma",)

    def __init__(
        self, gamma: float, wires: Union[int, List[int]] = 0, **kwargs
    ) -> None:
        """Initialise a phase damping channel.

        Args:
            gamma: Phase-damping probability, must be in [0, 1].
            wires: Qubit index or list of qubit indices this channel acts on.
            **kwargs: Forwarded to the base :class:`Operation` initialiser
                (e.g. ``record=False``, as the gate classes in this module
                already allow).

        Raises:
            ValueError: If *gamma* is outside [0, 1].
        """
        if not 0.0 <= gamma <= 1.0:
            raise ValueError("gamma must be in [0, 1].")
        self.gamma = gamma
        super().__init__(wires=wires, **kwargs)

    def kraus_matrices(self) -> List[jnp.ndarray]:
        """Return the two Kraus operators for the phase damping channel.

        Returns:
            List ``[K0, K1]`` as defined in the class docstring.
        """
        g = self.gamma
        K0 = jnp.array([[1.0, 0.0], [0.0, jnp.sqrt(1 - g)]], dtype=_cdtype())
        K1 = jnp.array([[0.0, 0.0], [0.0, jnp.sqrt(g)]], dtype=_cdtype())
        return [K0, K1]

1333 

1334 

class ThermalRelaxationError(KrausChannel):
    r"""Single-qubit thermal relaxation error channel.

    Models simultaneous :math:`T_1` energy relaxation and :math:`T_2`
    dephasing over a gate of duration :math:`t_g`. With
    :math:`p_{reset} = 1 - e^{-t_g/T_1}`, two regimes are handled:

    * :math:`T_2 \le T_1` (Markovian dephasing + reset): six Kraus operators
      built from the phase-flip probability :math:`p_z`, the
      reset-to-:math:`|0\rangle` probability :math:`p_{r0}` and the
      reset-to-:math:`|1\rangle` probability :math:`p_{r1}`.
    * :math:`T_2 > T_1` (Choi matrix decomposition): a 4x4 Choi matrix is
      assembled from the relaxation/dephasing factors and diagonalised; each
      eigenpair yields :math:`K_i = \sqrt{|\lambda_i|}\,\mathrm{mat}(v_i)`,
      where :math:`\mathrm{mat}` is a column-major 2x2 reshape.

    Attributes:
        pe: Excited-state population (thermal population of :math:`|1\rangle`).
        t1: :math:`T_1` longitudinal relaxation time.
        t2: :math:`T_2` transverse dephasing time.
        tg: Gate duration.
    """

    _num_wires = 1
    _param_names = ("pe", "t1", "t2", "tg")

    def __init__(
        self,
        pe: float,
        t1: float,
        t2: float,
        tg: float,
        wires: Union[int, List[int]] = 0,
        **kwargs,
    ) -> None:
        """Initialise a thermal relaxation error channel.

        Args:
            pe: Excited-state population (thermal population of |1>), in [0, 1].
            t1: T_1 longitudinal relaxation time, must be > 0.
            t2: T_2 transverse dephasing time, must be > 0 and <= 2·T_1.
            tg: Gate duration, must be >= 0.
            wires: Qubit index or list of qubit indices this channel acts on.
            **kwargs: Forwarded to the base :class:`Operation` initialiser
                (e.g. ``record=False``, as the gate classes in this module
                already allow).

        Raises:
            ValueError: If any parameter violates the stated constraints.
        """
        if not 0.0 <= pe <= 1.0:
            raise ValueError("pe must be in [0, 1].")
        if t1 <= 0:
            raise ValueError("t1 must be > 0.")
        if t2 <= 0:
            raise ValueError("t2 must be > 0.")
        if t2 > 2 * t1:
            raise ValueError("t2 must be <= 2·t1.")
        if tg < 0:
            raise ValueError("tg must be >= 0.")
        self.pe = pe
        self.t1 = t1
        self.t2 = t2
        self.tg = tg
        super().__init__(wires=wires, **kwargs)

    def kraus_matrices(self) -> List[jnp.ndarray]:
        """Return the Kraus operators for the thermal relaxation channel.

        The number of operators depends on the regime:

        * T_2 <= T_1: six operators (identity, phase-flip, two reset-to-|0>,
          two reset-to-|1>).
        * T_2 > T_1: four operators from the Choi matrix eigendecomposition.

        Returns:
            List of 2x2 JAX arrays representing the Kraus operators.
        """
        pe, t1, t2, tg = self.pe, self.t1, self.t2, self.tg

        eT1 = jnp.exp(-tg / t1)
        p_reset = 1.0 - eT1
        eT2 = jnp.exp(-tg / t2)

        if t2 <= t1:
            # --- Case T_2 <= T_1: six Kraus operators ---
            pz = (1.0 - p_reset) * (1.0 - eT2 / eT1) / 2.0
            pr0 = (1.0 - pe) * p_reset
            pr1 = pe * p_reset
            pid = 1.0 - pz - pr0 - pr1

            K0 = jnp.sqrt(pid) * jnp.eye(2, dtype=_cdtype())
            K1 = jnp.sqrt(pz) * jnp.array([[1, 0], [0, -1]], dtype=_cdtype())
            K2 = jnp.sqrt(pr0) * jnp.array([[1, 0], [0, 0]], dtype=_cdtype())
            K3 = jnp.sqrt(pr0) * jnp.array([[0, 1], [0, 0]], dtype=_cdtype())
            K4 = jnp.sqrt(pr1) * jnp.array([[0, 0], [1, 0]], dtype=_cdtype())
            K5 = jnp.sqrt(pr1) * jnp.array([[0, 0], [0, 1]], dtype=_cdtype())
            return [K0, K1, K2, K3, K4, K5]

        # --- Case T_2 > T_1: Choi matrix decomposition ---
        # Choi matrix (column-major reshaping convention matching PennyLane).
        choi = jnp.array(
            [
                [1 - pe * p_reset, 0, 0, eT2],
                [0, pe * p_reset, 0, 0],
                [0, 0, (1 - pe) * p_reset, 0],
                [eT2, 0, 0, 1 - (1 - pe) * p_reset],
            ],
            dtype=_cdtype(),
        )
        eigenvalues, eigenvectors = jnp.linalg.eigh(choi)
        # Each eigenvector (column) reshaped column-major to 2x2 -> one Kraus op.
        # abs() guards against tiny negative eigenvalues from round-off.
        kraus = []
        for i in range(4):
            lam = eigenvalues[i]
            vec = eigenvectors[:, i]
            mat = jnp.sqrt(jnp.abs(lam)) * vec.reshape(2, 2, order="F")
            kraus.append(mat.astype(_cdtype()))
        return kraus

1449 

1450 

class QubitChannel(KrausChannel):
    r"""Generic Kraus channel from a user-supplied list of Kraus operators.

    This replaces PennyLane's ``qml.QubitChannel`` and accepts an arbitrary set
    of Kraus matrices, which should satisfy
    :math:`\sum_k K_k^\dagger K_k = I` (not checked here).

    Example::

        kraus_ops = [jnp.sqrt(0.9) * jnp.eye(2), jnp.sqrt(0.1) * PauliX._matrix]
        QubitChannel(kraus_ops, wires=0)
    """

    def __init__(
        self, kraus_ops: List[jnp.ndarray], wires: Union[int, List[int]] = 0, **kwargs
    ) -> None:
        """Initialise a generic Kraus channel.

        Args:
            kraus_ops: Non-empty list of Kraus matrices. Each must be a square
                2D array of dimension ``2**k x 2**k`` where k = ``len(wires)``,
                and all must share the same shape.
            wires: Qubit index or list of qubit indices this channel acts on.
            **kwargs: Forwarded to the base :class:`Operation` initialiser
                (e.g. ``record=False``, as the gate classes in this module
                already allow).

        Raises:
            ValueError: If *kraus_ops* is empty, or the matrices are not square
                2D arrays of a common power-of-two dimension (>= 2).
        """
        ops = [jnp.asarray(K, dtype=_cdtype()) for K in kraus_ops]
        # An empty list would silently map every state to the zero matrix.
        if not ops:
            raise ValueError("kraus_ops must contain at least one Kraus operator.")
        shape = ops[0].shape
        dim = shape[0] if len(shape) == 2 else 0
        if len(shape) != 2 or shape[1] != dim or dim < 2 or dim & (dim - 1):
            raise ValueError(
                "Kraus operators must be square 2D arrays of dimension 2**k, "
                f"got shape {shape}."
            )
        if any(K.shape != shape for K in ops):
            raise ValueError("All Kraus operators must have the same shape.")
        self._kraus_ops = ops
        super().__init__(wires=wires, **kwargs)

    def kraus_matrices(self) -> List[jnp.ndarray]:
        """Return the stored Kraus operators.

        Returns:
            A new list of the Kraus operator matrices, so callers cannot
            mutate the channel's internal list.
        """
        return list(self._kraus_ops)

1483 

1484 

# Single-qubit Pauli basis in a fixed shared order (I, X, Y, Z): index i of
# each list below refers to the same operator. Plain matrices are kept so
# hot loops avoid constructing Operation objects.
_PAULI_CLASSES = [Id, PauliX, PauliY, PauliZ]
_PAULI_MATS = [pauli_cls._matrix for pauli_cls in _PAULI_CLASSES]
_PAULI_LABELS = list("IXYZ")

1489 

1490 

def evolve_pauli_with_clifford(
    clifford: Operation,
    pauli: Operation,
    adjoint_left: bool = True,
) -> Operation:
    r"""Conjugate a Pauli operator by a Clifford gate.

    Computes :math:`C^\dagger P C` (default) or :math:`C P C^\dagger`. Both
    operators are first embedded into the Hilbert space of the sorted union
    of their wires, and the product is wrapped in a :class:`Hermitian` (not
    recorded on the tape) for further algebra.

    Args:
        clifford: A Clifford gate C.
        pauli: A Pauli / Hermitian operator P.
        adjoint_left: If ``True``, compute C^dagger P C; otherwise C P C^dagger.

    Returns:
        A :class:`Hermitian` wrapping the evolved matrix on the union of wires.
    """
    wires = sorted({*clifford.wires, *pauli.wires})
    n_wires = len(wires)

    c_full = _embed_matrix(clifford.matrix, clifford.wires, wires, n_wires)
    p_full = _embed_matrix(pauli.matrix, pauli.wires, wires, n_wires)
    c_dag = jnp.conj(c_full).T

    left, right = (c_dag, c_full) if adjoint_left else (c_full, c_dag)
    return Hermitian(matrix=left @ p_full @ right, wires=wires, record=False)

1524 

1525 

1526def _embed_matrix( 

1527 mat: jnp.ndarray, 

1528 op_wires: list, 

1529 all_wires: list, 

1530 n_total: int, 

1531) -> jnp.ndarray: 

1532 """Embed a gate matrix into a larger Hilbert space via tensor products. 

1533 

1534 If the gate already acts on all wires, the matrix is returned as-is. 

1535 Otherwise the gate matrix is tensored with identities on the missing 

1536 wires, and the resulting matrix rows/columns are permuted so that qubit 

1537 ordering matches *all_wires*. 

1538 

1539 Args: 

1540 mat: The gate's unitary matrix of shape ``(2**k, 2**k)`` where 

1541 ``k = len(op_wires)``. 

1542 op_wires: The wires the gate acts on. 

1543 all_wires: The full ordered list of wires. 

1544 n_total: ``len(all_wires)``. 

1545 

1546 Returns: 

1547 A ``(2**n_total, 2**n_total)`` matrix. 

1548 """ 

1549 k = len(op_wires) 

1550 if k == n_total and list(op_wires) == list(all_wires): 

1551 return mat 

1552 

1553 # Build the full-space matrix by tensoring with identities 

1554 # Strategy: tensor I on missing wires, then permute 

1555 missing = [w for w in all_wires if w not in op_wires] 

1556 # Full matrix = mat \\otimes I_{missing} 

1557 full_mat = mat 

1558 for _ in missing: 

1559 full_mat = jnp.kron(full_mat, jnp.eye(2, dtype=_cdtype())) 

1560 

1561 # The current ordering is [op_wires..., missing...] 

1562 # We need to permute to match all_wires ordering 

1563 current_order = list(op_wires) + missing 

1564 if current_order != list(all_wires): 

1565 perm = [current_order.index(w) for w in all_wires] 

1566 full_mat = _permute_matrix(full_mat, perm, n_total) 

1567 

1568 return full_mat 

1569 

1570 

1571def _permute_matrix(mat: jnp.ndarray, perm: list, n_qubits: int) -> jnp.ndarray: 

1572 """Permute the qubit ordering of a matrix. 

1573 

1574 Given a ``(2**n, 2**n)`` matrix and a permutation of ``[0..n-1]``, 

1575 reorder the qubits so that qubit ``i`` moves to position ``perm[i]``. 

1576 

1577 Args: 

1578 mat: Square matrix of dimension ``2**n_qubits``. 

1579 perm: Permutation list. 

1580 n_qubits: Number of qubits. 

1581 

1582 Returns: 

1583 Permuted matrix of the same shape. 

1584 """ 

1585 dim = 2**n_qubits 

1586 # Reshape to tensor, permute axes, reshape back 

1587 tensor = mat.reshape([2] * (2 * n_qubits)) 

1588 # Axes: first n_qubits are row indices, last n_qubits are column indices 

1589 row_perm = perm 

1590 col_perm = [p + n_qubits for p in perm] 

1591 tensor = jnp.transpose(tensor, row_perm + col_perm) 

1592 return tensor.reshape(dim, dim) 

1593 

1594 

def pauli_decompose(matrix: jnp.ndarray, wire_order: Optional[List[int]] = None):
    r"""Find the dominant Pauli term of a Hermitian matrix.

    For an n-qubit matrix (``2**n x 2**n``), scans every Pauli tensor product
    and returns the one with the largest absolute coefficient (ties keep the
    first term in I < X < Y < Z lexicographic order). This is sufficient for
    the Fourier-tree algorithm, which only needs the single non-zero Pauli
    term produced by Clifford conjugation of a Pauli operator.

    The coefficients use the trace formula ``c_P = Tr(P · M) / 2**n``.

    Args:
        matrix: A ``(2**n, 2**n)`` Hermitian matrix with ``n >= 1``.
        wire_order: Optional list of ``n`` wire indices. If ``None``,
            defaults to ``[0, 1, ..., n-1]``.

    Returns:
        A tuple ``(coeff, op)`` where *coeff* is the complex coefficient of
        the dominant term and *op* is the corresponding Pauli
        :class:`Operation` (PauliX, PauliY, PauliZ, Id, or a
        :class:`Hermitian` for multi-qubit tensor products), created with
        ``record=False`` and carrying a ``_pauli_label`` string. If every
        coefficient is zero, returns ``(0.0, Id)`` labelled ``"I" * n``.

    Raises:
        ValueError: If *matrix* is not a square ``2**n x 2**n`` array with
            ``n >= 1``, or *wire_order* does not have ``n`` entries.
    """
    from itertools import product as _product
    from functools import reduce as _reduce

    if matrix.ndim != 2 or matrix.shape[0] != matrix.shape[1]:
        raise ValueError(f"matrix must be square, got shape {matrix.shape}.")
    dim = int(matrix.shape[0])
    if dim < 2 or dim & (dim - 1):
        raise ValueError(f"matrix dimension must be 2**n with n >= 1, got {dim}.")
    n_qubits = dim.bit_length() - 1

    if wire_order is None:
        wire_order = list(range(n_qubits))
    if len(wire_order) != n_qubits:
        raise ValueError(
            f"wire_order has {len(wire_order)} entries but matrix acts on "
            f"{n_qubits} qubit(s)."
        )

    # Start from the all-identity word with coefficient 0 so that a zero
    # matrix still yields a well-defined label (previously best_label stayed
    # None and the label construction crashed for n >= 2).
    best_label = (0,) * n_qubits
    best_coeff = 0.0
    for indices in _product(range(4), repeat=n_qubits):
        P = _reduce(jnp.kron, [_PAULI_MATS[i] for i in indices])
        coeff = jnp.trace(P @ matrix) / dim
        if jnp.abs(coeff) > jnp.abs(best_coeff):
            best_coeff = coeff
            best_label = indices

    non_identity = [(q, idx) for q, idx in enumerate(best_label) if idx != 0]

    if not non_identity:
        # All identity: represented by Id on the first wire.
        result_op = Id(wires=wire_order[0], record=False)
        result_op._pauli_label = "I" * n_qubits
    elif len(non_identity) == 1:
        # Single non-trivial Pauli on one wire.
        q, idx = non_identity[0]
        result_op = _PAULI_CLASSES[idx](wires=wire_order[q], record=False)
        result_op._pauli_label = _PAULI_LABELS[idx]
    else:
        # Multi-qubit tensor product -> Hermitian with the full Pauli word.
        P = _reduce(jnp.kron, [_PAULI_MATS[i] for i in best_label])
        result_op = Hermitian(matrix=P, wires=wire_order, record=False)
        result_op._pauli_label = "".join(_PAULI_LABELS[i] for i in best_label)

    return best_coeff, result_op

1671 

1672 

def pauli_string_from_operation(op: Operation) -> str:
    """Return the Pauli word (e.g. ``"X"``, ``"ZZ"``) describing *op*.

    Lookup order:

    1. :class:`PauliRot` instances return their ``pauli_word``.
    2. Operations carrying a ``_pauli_label`` attribute (as set by
       :func:`pauli_decompose`) return that label.
    3. Operations named ``PauliX``/``PauliY``/``PauliZ``/``I`` map to
       ``"X"``/``"Y"``/``"Z"``/``"I"``.
    4. Anything else is decomposed with :func:`pauli_decompose` and the
       dominant term's label is returned.

    Args:
        op: A quantum operation.

    Returns:
        The Pauli word as a string.
    """
    if isinstance(op, PauliRot) and hasattr(op, "pauli_word"):
        return op.pauli_word

    missing = object()
    stored_label = getattr(op, "_pauli_label", missing)
    if stored_label is not missing:
        return stored_label

    single_qubit_labels = {"PauliX": "X", "PauliY": "Y", "PauliZ": "Z", "I": "I"}
    named_label = single_qubit_labels.get(op.name)
    if named_label is not None:
        return named_label

    # Last resort: recover the label from the operator's matrix.
    dominant = pauli_decompose(op.matrix, wire_order=op.wires)[1]
    return dominant._pauli_label

1697 return pauli_op._pauli_label