Coverage for qml_essentials / operations.py: 89%

499 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-03-30 11:43 +0000

1from typing import Callable, List, Optional, Tuple, Union 

2from functools import lru_cache 

3import string 

4import numpy as np 

5 

6import jax 

7import jax.numpy as jnp 

8 

9from qml_essentials.tape import active_tape, recording # noqa: F401 (re-export) 

10 

11 

12def _cdtype(): 

13 """Return the active JAX complex dtype 

14 (complex128 if x64 enabled, else complex64). 

15 """ 

16 return jnp.complex128 if jax.config.x64_enabled else jnp.complex64 

17 

18 

19@lru_cache(maxsize=256) 

20def _einsum_subscript( 

21 n: int, 

22 k: int, 

23 target_axes: Tuple[int, ...], 

24) -> str: 

25 """Build an ``einsum`` subscript that fuses contraction + axis restore. 

26 

27 Args: 

28 n: Total rank of the state tensor (number of qubits for statevectors, 

29 ``2 * n_qubits`` for density matrices). 

30 k: Number of qubits the gate acts on. 

31 target_axes: Tuple of k axis indices in the state tensor that the 

32 gate contracts against. 

33 

34 Returns: 

35 ``einsum`` subscript string, e.g. ``"ab,cBd->cad"`` for a 1-qubit 

36 gate on wire 1 of a 3-qubit state. 

37 """ 

38 letters = string.ascii_letters 

39 # State indices: one letter per axis 

40 state_idx = list(letters[:n]) 

41 # Contracted indices (the ones being replaced by the gate) 

42 contracted = [state_idx[ax] for ax in target_axes] 

43 # Gate indices: new output indices + contracted input indices 

44 new_out = [letters[n + i] for i in range(k)] # fresh letters for output 

45 gate_idx = new_out + contracted # gate shape: (out0, out1, ..., in0, in1, ...) 

46 # Result indices: replace target axes with new output letters 

47 result_idx = list(state_idx) 

48 for i, ax in enumerate(target_axes): 

49 result_idx[ax] = new_out[i] 

50 return "".join(gate_idx) + "," + "".join(state_idx) + "->" + "".join(result_idx) 

51 

52 

def _contract_and_restore(
    tensor: jnp.ndarray,
    gate: jnp.ndarray,
    k: int,
    target_axes: List[int],
) -> jnp.ndarray:
    """Contract a gate tensor into selected axes of a state tensor.

    A single ``jnp.einsum`` call performs the contraction and leaves every
    axis in its original position. The subscript string comes from the
    ``lru_cache``-backed :func:`_einsum_subscript`, keyed on
    ``(tensor.ndim, k, tuple(target_axes))``, so it is only built once per
    distinct layout.

    Args:
        tensor: Rank-N tensor (e.g. ``(2,)*n`` for states or ``(2,)*2n``
            for density matrices).
        gate: Reshaped gate tensor of shape ``(2,)*2k``.
        k: Number of qubits the gate acts on (= ``len(target_axes)``).
        target_axes: The k axes of tensor to contract against.

    Returns:
        Tensor of the same rank as ``tensor`` with the gate applied on
        ``target_axes``.
    """
    # lru_cache needs a hashable key, so lists become tuples here.
    axes_key = tuple(target_axes)
    spec = _einsum_subscript(tensor.ndim, k, axes_key)
    return jnp.einsum(spec, gate, tensor)

78 

79 

class Operation:
    """Base class for any quantum operation or observable.

    Further gates should inherit from this class to realise more specific
    operations. Generally, operations are created by instantiation inside a
    circuit function passed to :class:`Script`; the instance is
    automatically appended to the active tape.

    An ``Operation`` can also serve as an *observable*: its matrix is used to
    compute expectation values via ``apply_to_state`` / ``apply_to_density``.

    Attributes:
        _matrix: Class-level default gate matrix. Subclasses set this to their
            fixed unitary. Instances may override it via the *matrix* argument
            to ``__init__``.
        _num_wires: Expected number of wires for this gate. Subclasses set
            this to enforce wire count validation. ``None`` means any number
            of wires is accepted.
        _param_names: Tuple of attribute names for the gate parameters.
            Used by :attr:`parameters` and :meth:`__repr__`.
    """

    # Subclasses should set this to the gate's unitary / matrix
    _matrix: jnp.ndarray = None
    # Set to True by the controlled gates in this module (CX, CCX, CRX, ...).
    is_controlled = False
    _num_wires: Optional[int] = None
    _param_names: Tuple[str, ...] = ()

    def __init__(
        self,
        wires: Union[int, List[int]] = 0,
        matrix: Optional[jnp.ndarray] = None,
        record: bool = True,
        input_idx: int = -1,
        name: Optional[str] = None,
    ) -> None:
        """Initialise the operation and optionally register it on the active tape.

        Args:
            wires: Qubit index or list of qubit indices this operation acts on.
            matrix: Optional explicit gate matrix. When provided it overrides
                the class-level ``_matrix`` attribute.
            record: If ``True`` (default) and a tape is currently recording,
                append this operation to the tape. Set to ``False`` for
                auxiliary objects that should not appear in the circuit
                (e.g. Hamiltonians used only to build time-dependent
                evolutions).
            input_idx: Marks the operation as input with the corresponding
                input index, which is useful for the analytical Fourier
                coefficients computation, but has no effect otherwise.
            name: Optional explicit name for this operation. When ``None``
                (default), the class name is used (e.g. ``"RX"``).

        Raises:
            ValueError: If ``_num_wires`` is set and the number of wires
                doesn't match, or if duplicate wires are provided.
        """
        self.name = name or self.__class__.__name__
        # The ``wires`` setter normalises an int / tuple / list to a list.
        self.wires = wires
        self.input_idx = input_idx

        if self._num_wires is not None and len(self.wires) != self._num_wires:
            raise ValueError(
                f"{self.name} expects {self._num_wires} wire(s), "
                f"got {len(self.wires)}: {self.wires}"
            )
        if len(self.wires) != len(set(self.wires)):
            raise ValueError(f"{self.name} received duplicate wires: {self.wires}")

        if matrix is not None:
            self._matrix = matrix

        # If a tape is currently recording, append ourselves
        if record:
            tape = active_tape()
            if tape is not None:
                tape.append(self)

    @property
    def parameters(self) -> list:
        """Return the list of numeric parameters for this operation.

        Uses the declarative ``_param_names`` tuple to collect parameter
        values in a canonical order. Non-parametrized gates return an
        empty list.

        Returns:
            List of parameter values (floats or JAX arrays).
        """
        return [getattr(self, name) for name in self._param_names]

    @staticmethod
    def _format_param(value) -> str:
        """Format a single parameter value for :meth:`__repr__`.

        Floats and scalar arrays are printed with four decimals. Values that
        cannot be converted to a Python float fall back to ``str`` instead of
        raising. Examples are non-scalar arrays and JAX tracers inside
        ``jax.jit``.

        Args:
            value: A parameter value as returned by :attr:`parameters`.

        Returns:
            The formatted string.
        """
        if isinstance(value, (float, np.floating, jnp.ndarray)):
            try:
                return f"{float(value):.4f}"
            except TypeError:
                # Non-scalar arrays raise TypeError; traced values raise
                # jax.errors.ConcretizationTypeError, a TypeError subclass.
                return str(value)
        return str(value)

    def __repr__(self) -> str:
        """Return a human-readable representation of this operation.

        Returns:
            A string like ``"RX(0.5000, wires=[0])"`` or ``"CX(wires=[0, 1])"``.
        """
        params = self.parameters
        if params:
            param_str = ", ".join(self._format_param(v) for v in params)
            return f"{self.name}({param_str}, wires={self.wires})"
        return f"{self.name}(wires={self.wires})"

    @property
    def matrix(self) -> jnp.ndarray:
        """Return the base matrix of this operation (before lifting).

        Returns:
            The gate matrix as a JAX array.

        Raises:
            NotImplementedError: If the subclass has not defined ``_matrix``.
        """
        if self._matrix is None:
            raise NotImplementedError(
                f"{self.__class__.__name__} does not define a matrix."
            )
        return self._matrix

    @property
    def wires(self) -> List[int]:
        """Qubit indices this operation acts on.

        Returns:
            List of integer qubit indices.
        """
        return self._wires

    @wires.setter
    def wires(self, wires: Union[int, List[int]]) -> None:
        """Set the qubit indices for this operation.

        Args:
            wires: A single qubit index or a list/tuple of qubit indices;
                always stored as a list.
        """
        if isinstance(wires, (list, tuple)):
            self._wires = list(wires)
        else:
            self._wires = [wires]

    @property
    def input_idx(self) -> int:
        """The index of an input

        Returns:
            input_idx: Index of the input
        """
        return self._input_idx

    @input_idx.setter
    def input_idx(self, input_idx: int) -> None:
        """Setter for the input_idx flag

        Args:
            input_idx: Index of the input
        """
        self._input_idx = input_idx

    def _update_tape_operation(self, op: "Operation") -> None:
        """Put *op* on the active tape in place of ``self``.

        If the last entry of the active tape is ``self``, it is overwritten by
        *op*. This is the typical case when chaining ``Gate(...).dagger()`` or
        ``Gate(...).power(n)``, and it ensures that only the derived operation
        appears on the tape, not both. Otherwise *op* is appended. Nothing
        happens when no tape is recording.

        Note that this should only be called immediately after the tape is
        updated.

        Args:
            op (Operation): New operation to place on the tape.
        """
        tape = active_tape()
        if tape is not None:
            if tape and tape[-1] is self:
                tape[-1] = op
            else:
                tape.append(op)

    def dagger(self) -> "Operation":
        """Return a new operation, the conjugate transpose (``U\\dagger``)
        Usage inside a circuit function::

            RX(0.5, wires=0).dagger()

        Returns:
            A new :class:`Operation` with matrix ``U\\dagger`` acting on the same wires.

        Raises:
            NotImplementedError: If this operation defines no matrix.
        """
        # Go through the ``matrix`` property (not ``_matrix``) so subclasses
        # that compute or forbid a matrix are respected.
        mat = jnp.conj(self.matrix).T
        op = Operation(wires=self.wires, matrix=mat, record=False)

        self._update_tape_operation(op)

        return op

    def power(self, power) -> "Operation":
        """Return a new operation, the power (``U^power``)
        Usage inside a circuit function::

            PauliX(wires=0).power(2)

        Args:
            power: Integer exponent, passed to ``jnp.linalg.matrix_power``.

        Returns:
            A new :class:`Operation` with matrix ``U^power`` acting on the same wires.

        Raises:
            NotImplementedError: If this operation defines no matrix.
        """
        # TODO: support fractional powers
        mat = jnp.linalg.matrix_power(self.matrix, power)
        op = Operation(wires=self.wires, matrix=mat, record=False)

        self._update_tape_operation(op)

        return op

    def lifted_matrix(self, n_qubits: int) -> jnp.ndarray:
        """Return the full ``2**n x 2**n`` matrix embedding this gate.

        Embeds the ``k``-qubit gate matrix into the ``n``-qubit Hilbert space
        by applying it to the identity matrix via :meth:`apply_to_state`.
        This is useful for computing ``Tr(O·\\rho )`` directly without vmap.

        Args:
            n_qubits: Total number of qubits in the circuit.

        Returns:
            The ``(2**n, 2**n)`` matrix of this operation in the full space.
        """
        dim = 2**n_qubits
        # Row i of the vmapped result is U e_i (column i of U); transpose back.
        return jax.vmap(lambda col: self.apply_to_state(col, n_qubits))(
            jnp.eye(dim, dtype=_cdtype())
        ).T

    def apply_to_state(self, state: jnp.ndarray, n_qubits: int) -> jnp.ndarray:
        """Apply this gate to a statevector via tensor contraction.

        The statevector (shape ``(2**n,)``) is reshaped into a rank-n tensor
        of shape ``(2,)*n``. The gate (shape ``(2**k, 2**k)``) is reshaped to
        ``(2,)*2k`` and contracted against the k target wire axes.

        Memory footprint is O(2**n) and the operation supports arbitrary k.
        The implementation is fully differentiable through JAX.

        Args:
            state: Statevector of shape ``(2**n_qubits,)``.
            n_qubits: Total number of qubits in the circuit.

        Returns:
            Updated statevector of shape ``(2**n_qubits,)``.
        """
        k = len(self.wires)
        gate_tensor = self._gate_tensor(k)
        psi = state.reshape((2,) * n_qubits)
        psi_out = _contract_and_restore(psi, gate_tensor, k, self.wires)
        return psi_out.reshape(2**n_qubits)

    def apply_to_state_tensor(self, psi: jnp.ndarray, n_qubits: int) -> jnp.ndarray:
        """Apply this gate to a statevector already in tensor form.

        Like :meth:`apply_to_state` but expects the state in rank-n tensor
        form ``(2,)*n`` and returns the result in the same form. This avoids
        the ``reshape`` calls at the per-gate level when the simulation loop
        keeps the state in tensor form throughout.

        Args:
            psi: Statevector tensor of shape ``(2,)*n_qubits``.
            n_qubits: Total number of qubits in the circuit.

        Returns:
            Updated statevector tensor of shape ``(2,)*n_qubits``.
        """
        k = len(self.wires)
        gate_tensor = self._gate_tensor(k)
        return _contract_and_restore(psi, gate_tensor, k, self.wires)

    def _gate_tensor(self, k: int) -> jnp.ndarray:
        """Return the gate matrix reshaped to ``(2,)*2k`` tensor form.

        The result is cached on the instance so repeated calls (e.g. from
        density-matrix simulation which applies U and U*) avoid redundant
        reshape dispatch.

        Args:
            k: Number of qubits the gate acts on.

        Returns:
            Gate matrix as a rank-2k tensor of shape ``(2,)*2k``.
        """
        cached = getattr(self, "_cached_gate_tensor", None)
        if cached is not None:
            return cached
        gt = self.matrix.reshape((2,) * 2 * k)
        # Only cache for non-parametrized gates (whose matrix is a class attr)
        if self._matrix is self.__class__._matrix:
            object.__setattr__(self, "_cached_gate_tensor", gt)
        return gt

    def apply_to_density(self, rho: jnp.ndarray, n_qubits: int) -> jnp.ndarray:
        """Apply this gate to a density matrix via \\rho -> U\\rho U\\dagger.

        The density matrix (shape ``(2**n, 2**n)``) is treated as a rank-*2n*
        tensor with n "ket" axes (0..n-1) and n "bra" axes (n..2n-1).
        U acts on the ket half; U* acts on the bra half. Both contractions
        use the shared :func:`_contract_and_restore` helper, keeping the
        operation allocation-free with respect to building full unitaries.

        Args:
            rho: Density matrix of shape ``(2**n_qubits, 2**n_qubits)``.
            n_qubits: Total number of qubits in the circuit.

        Returns:
            Updated density matrix of shape ``(2**n_qubits, 2**n_qubits)``.
        """
        k = len(self.wires)
        U = self._gate_tensor(k)
        U_conj = jnp.conj(U)

        rho_t = rho.reshape((2,) * 2 * n_qubits)

        # Apply U to ket axes, U\\dagger to bra axes
        rho_t = _contract_and_restore(rho_t, U, k, self.wires)
        bra_wires = [w + n_qubits for w in self.wires]
        rho_t = _contract_and_restore(rho_t, U_conj, k, bra_wires)

        return rho_t.reshape(2**n_qubits, 2**n_qubits)

408 

409 

class Hermitian(Operation):
    """Observable or gate backed by a user-supplied matrix.

    The matrix is converted to the active complex dtype. It is not checked
    for Hermiticity.

    Example:
        >>> obs = Hermitian(matrix=my_matrix, wires=0)
    """

    def __init__(
        self,
        matrix: jnp.ndarray,
        wires: Union[int, List[int]] = 0,
        record: bool = True,
    ) -> None:
        """Create the operator from an explicit matrix.

        Args:
            matrix: The Hermitian matrix defining this operator; cast to the
                dtype returned by ``_cdtype()``.
            wires: Qubit index or list of qubit indices this operator acts on.
            record: When ``True`` (default) the operator is recorded on the
                active tape. Pass ``False`` to use it only as a Hamiltonian
                term (e.g. for time-dependent evolution).
        """
        complex_matrix = jnp.asarray(matrix, dtype=_cdtype())
        super().__init__(wires=wires, matrix=complex_matrix, record=record)

    def __rmul__(self, coeff_fn):
        """Build ``coeff_fn * Hermitian`` as a :class:`ParametrizedHamiltonian`.

        Args:
            coeff_fn: Callable ``(params, t) -> scalar`` providing the
                time-dependent prefactor.

        Returns:
            A :class:`ParametrizedHamiltonian` combining *coeff_fn* with this
            operator's matrix and wires.

        Raises:
            TypeError: When *coeff_fn* is not callable.
        """
        if callable(coeff_fn):
            return ParametrizedHamiltonian(coeff_fn, self.matrix, self.wires)
        raise TypeError(
            f"Left operand of `* Hermitian` must be callable, got {type(coeff_fn)}"
        )

457 

458 

class ParametrizedHamiltonian:
    """A time-dependent Hamiltonian ``H(t) = f(params, t) · H_mat``.

    Typically obtained by left-multiplying a :class:`Hermitian` operator with
    a callable coefficient::

        def coeff(p, t):
            return p[0] * jnp.exp(-0.5 * ((t - t_c) / p[1]) ** 2)

        H_td = coeff * Hermitian(matrix=sigma_x, wires=0)

    and then consumed by :func:`evolve`::

        evolve(H_td)(coeff_args=[A, sigma], T=1.0)

    Attributes:
        coeff_fn: Callable ``(params, t) -> scalar``.
        H_mat: Static Hermitian matrix (JAX array).
        wires: Qubit wire(s) this Hamiltonian acts on, stored exactly as
            passed (not normalised to a list).
    """

    def __init__(
        self,
        coeff_fn: Callable,
        H_mat: jnp.ndarray,
        wires: Union[int, List[int]],
    ) -> None:
        """Store the coefficient function, matrix and wires as given.

        Args:
            coeff_fn: Time-dependent scalar coefficient ``(params, t)``.
            H_mat: The static matrix multiplied by the coefficient.
            wires: Qubit wire(s) the Hamiltonian acts on.
        """
        self.coeff_fn, self.H_mat, self.wires = coeff_fn, H_mat, wires

489 

490 

class Id(Operation):
    """Single-qubit identity gate."""

    _matrix = jnp.eye(2, dtype=_cdtype())
    _num_wires = 1

    def __init__(self, wires: Union[int, List[int]] = 0, **extra) -> None:
        """Create an identity gate.

        Args:
            wires: The single qubit this gate acts on (an int or a
                one-element list).
            **extra: Forwarded unchanged to :class:`Operation` (e.g.
                ``record``, ``input_idx``, ``name``).
        """
        super().__init__(wires=wires, **extra)

504 

505 

class PauliX(Operation):
    """Pauli-X gate / observable (bit-flip, \\sigma_x)."""

    _matrix = jnp.array([[0.0, 1.0], [1.0, 0.0]], dtype=_cdtype())
    _num_wires = 1

    def __init__(self, wires: Union[int, List[int]] = 0, **extra) -> None:
        """Create a Pauli-X gate.

        Args:
            wires: The single qubit this gate acts on (an int or a
                one-element list).
            **extra: Forwarded unchanged to :class:`Operation` (e.g.
                ``record``, ``input_idx``, ``name``).
        """
        super().__init__(wires=wires, **extra)

519 

520 

class PauliY(Operation):
    """Pauli-Y gate / observable (\\sigma_y)."""

    _matrix = jnp.array([[0.0, -1j], [1j, 0.0]], dtype=_cdtype())
    _num_wires = 1

    def __init__(self, wires: Union[int, List[int]] = 0, **extra) -> None:
        """Create a Pauli-Y gate.

        Args:
            wires: The single qubit this gate acts on (an int or a
                one-element list).
            **extra: Forwarded unchanged to :class:`Operation` (e.g.
                ``record``, ``input_idx``, ``name``).
        """
        super().__init__(wires=wires, **extra)

534 

535 

class PauliZ(Operation):
    """Pauli-Z gate / observable (phase-flip, \\sigma_z)."""

    # diag(1, -1)
    _matrix = jnp.diag(jnp.array([1, -1], dtype=_cdtype()))
    _num_wires = 1

    def __init__(self, wires: Union[int, List[int]] = 0, **extra) -> None:
        """Create a Pauli-Z gate.

        Args:
            wires: The single qubit this gate acts on (an int or a
                one-element list).
            **extra: Forwarded unchanged to :class:`Operation` (e.g.
                ``record``, ``input_idx``, ``name``).
        """
        super().__init__(wires=wires, **extra)

549 

550 

class H(Operation):
    """Hadamard gate, ``(X + Z) / sqrt(2)``."""

    _matrix = jnp.array([[1, 1], [1, -1]], dtype=_cdtype()) / jnp.sqrt(2)
    _num_wires = 1

    def __init__(self, wires: Union[int, List[int]] = 0, **extra) -> None:
        """Create a Hadamard gate.

        Args:
            wires: The single qubit this gate acts on (an int or a
                one-element list).
            **extra: Forwarded unchanged to :class:`Operation` (e.g.
                ``record``, ``input_idx``, ``name``).
        """
        super().__init__(wires=wires, **extra)

564 

565 

class S(Operation):
    """S (phase) gate — a Clifford gate equal to \\sqrt Z.

    .. math::
        S = \\begin{pmatrix}1 & 0\\\\ 0 & i\\end{pmatrix}
    """

    _matrix = jnp.array([[1, 0], [0, 1j]], dtype=_cdtype())
    _num_wires = 1

    def __init__(self, wires: Union[int, List[int]] = 0, **kwargs) -> None:
        """Initialise an S gate.

        Args:
            wires: Qubit index or list of qubit indices this gate acts on.
            **kwargs: Forwarded to :class:`Operation` (``record``,
                ``input_idx``, ``name``). This matches the other fixed
                single-qubit gates, so e.g. ``S(0, record=False)`` works.
        """
        super().__init__(wires=wires, **kwargs)

583 

584 

class RandomUnitary(Operation):
    """Creates a random hermitian matrix and applies it as a gate.

    The matrix is ``(A + A^\\dagger) / 2`` for a complex Gaussian ``A``,
    rescaled to Frobenius norm ``scale``.
    """

    # NOTE(review): despite the class name, the matrix built here is Hermitian,
    # not unitary, so applying it does not preserve the state norm in general.
    # If a unitary was intended, ``expm(-1j * herm)`` would be needed — confirm
    # with callers before changing, since it alters every result.

    def __init__(self, wires, key, scale=1.0, record=True):
        """Initialise a random unitary gate.

        Args:
            wires: Qubit index or list of qubit indices this gate acts on.
            key: ``jax.random.PRNGKey`` used to draw the random matrix.
            scale: Frobenius norm of the generated matrix (default: 1.0).
            record: If ``True`` (default), record on the active tape.
        """
        # Accept a bare int like every other gate. ``len(wires)`` would
        # otherwise raise TypeError for e.g. ``wires=0``.
        wire_list = list(wires) if isinstance(wires, (list, tuple)) else [wires]
        dim = 2 ** len(wire_list)
        key_re, key_im = jax.random.split(key)

        A = (
            jax.random.normal(key=key_re, shape=(dim, dim))
            + 1j * jax.random.normal(key=key_im, shape=(dim, dim))
        ).astype(_cdtype())
        # Hermitian part of A (named ``herm`` to avoid shadowing the H gate).
        herm = (A + A.conj().T) / 2.0

        herm *= scale / jnp.linalg.norm(herm, ord="fro")

        super().__init__(wire_list, matrix=herm, record=record)

608 

609 

class Barrier(Operation):
    """Barrier operation — a no-op used for visual circuit separation.

    Applying a barrier leaves states and density matrices untouched; it is
    only recorded on the tape so drawing backends can insert a separator.
    """

    _matrix = None  # a barrier has no matrix representation

    def __init__(self, wires: Union[int, List[int]] = 0) -> None:
        """Create a barrier.

        Args:
            wires: Qubit index or list of qubit indices the barrier spans.
        """
        super().__init__(wires=wires)

    def apply_to_state(self, state: jnp.ndarray, n_qubits: int) -> jnp.ndarray:
        """Identity on statevectors: ``state`` is returned as-is."""
        return state

    def apply_to_state_tensor(self, psi: jnp.ndarray, n_qubits: int) -> jnp.ndarray:
        """Identity on tensor-form statevectors: ``psi`` is returned as-is."""
        return psi

    def apply_to_density(self, rho: jnp.ndarray, n_qubits: int) -> jnp.ndarray:
        """Identity on density matrices: ``rho`` is returned as-is."""
        return rho

638 

639 

def _make_rotation_gate(pauli_class: type, name: str) -> type:
    """Build a single-qubit Pauli rotation gate class (RX, RY or RZ).

    The generated gate has matrix
    ``R_P(\\theta) = cos(\\theta/2) I - i sin(\\theta/2) P``.

    Args:
        pauli_class: The Pauli gate class (PauliX, PauliY or PauliZ) whose
            matrix is the rotation generator.
        name: Name given to the generated class (e.g. ``"RX"``); its second
            character is used as the axis label in the docstring.

    Returns:
        A new :class:`Operation` subclass.
    """
    generator_matrix = pauli_class._matrix
    axis = name[1]

    class _RotationGate(Operation):
        # Generic docstring parametrised by the axis letter.
        __doc__ = (
            f"Rotation around the {axis} axis: {name}(\\theta) =\n"
            f"exp(-i \\theta/2 {axis}).\n"
        )
        _num_wires = 1
        _param_names = ("theta",)

        def __init__(
            self, theta: float, wires: Union[int, List[int]] = 0, **options
        ) -> None:
            self.theta = theta
            half_angle = theta / 2
            mat = (
                jnp.cos(half_angle) * Id._matrix
                - 1j * jnp.sin(half_angle) * generator_matrix
            )
            super().__init__(wires=wires, matrix=mat, **options)

        def generator(self) -> Operation:
            """Return the generator as the corresponding Pauli operation."""
            return pauli_class(wires=self.wires[0], record=False)

    _RotationGate.__name__ = name
    _RotationGate.__qualname__ = name
    return _RotationGate

679 

680 

# Single-qubit Pauli rotations R_P(theta) = exp(-i theta/2 P).
RX, RY, RZ = (
    _make_rotation_gate(pauli, label)
    for pauli, label in ((PauliX, "RX"), (PauliY, "RY"), (PauliZ, "RZ"))
)

# Projectors |0><0| and |1><1| used by the controlled-gate factories.
_P0 = jnp.diag(jnp.array([1, 0], dtype=_cdtype()))
_P1 = jnp.diag(jnp.array([0, 1], dtype=_cdtype()))

689 

690 

def _make_controlled_gate(target_class: type, name: str) -> type:
    """Factory for controlled Pauli gates CX, CY, CZ.

    Each gate has the form
    ``CP = |0><0| \\otimes I + |1><1| \\otimes P``,
    with the control on the first wire and the target on the second.

    Args:
        target_class: The single-qubit gate class (PauliX, PauliY, PauliZ).
        name: Class name for the generated gate (e.g. ``"CX"``).

    Returns:
        A new :class:`Operation` subclass.
    """
    target_mat = target_class._matrix

    class _ControlledGate(Operation):
        __doc__ = (
            f"Controlled-{target_class.__name__[5:]} gate.\n\n"
            f"Applies {target_class.__name__} on the target qubit conditioned "
            f"on the control qubit being in state |1\\rangle."
        )
        _matrix = jnp.kron(_P0, Id._matrix) + jnp.kron(_P1, target_mat)
        _num_wires = 2
        is_controlled = True

        def __init__(
            self, wires: Union[List[int], Tuple[int, ...]] = (0, 1), **kwargs
        ) -> None:
            # Immutable default instead of a shared ``[0, 1]`` list;
            # Operation normalises tuples to lists.
            super().__init__(wires=wires, **kwargs)

    _ControlledGate.__name__ = name
    _ControlledGate.__qualname__ = name
    return _ControlledGate

722 

723 

# Controlled Pauli gates (control = wires[0], target = wires[1]).
CX, CY, CZ = (
    _make_controlled_gate(target, label)
    for target, label in ((PauliX, "CX"), (PauliY, "CY"), (PauliZ, "CZ"))
)

727 

728 

class CCX(Operation):
    """Toffoli (CCX) gate.

    Flips the target qubit (``wires[2]``) when both controls (``wires[0]``
    and ``wires[1]``) are in state |1\\rangle.

    The 3-qubit Toffoli gate exercises the arbitrary-k-qubit path in
    :meth:`~Operation.apply_to_state` and cannot be expressed as a pair of
    2-qubit gates without ancilla, making it a good stress-test for the
    simulator.
    """

    # Identity except for a swap of basis states |110> and |111>.
    _matrix = jnp.array(
        [
            [1, 0, 0, 0, 0, 0, 0, 0],
            [0, 1, 0, 0, 0, 0, 0, 0],
            [0, 0, 1, 0, 0, 0, 0, 0],
            [0, 0, 0, 1, 0, 0, 0, 0],
            [0, 0, 0, 0, 1, 0, 0, 0],
            [0, 0, 0, 0, 0, 1, 0, 0],
            [0, 0, 0, 0, 0, 0, 0, 1],
            [0, 0, 0, 0, 0, 0, 1, 0],
        ],
        dtype=_cdtype(),
    )
    is_controlled = True
    _num_wires = 3

    def __init__(
        self, wires: Union[List[int], Tuple[int, ...]] = (0, 1, 2), **kwargs
    ) -> None:
        """Initialise a Toffoli (CCX) gate.

        Args:
            wires: Three-element sequence ``[control0, control1, target]``.
                Defaults to an immutable ``(0, 1, 2)`` rather than a shared
                list.
            **kwargs: Forwarded to :class:`Operation`.
        """
        super().__init__(wires=wires, **kwargs)

761 

762 

class CSWAP(Operation):
    """Controlled-SWAP (Fredkin) gate.

    Swaps the two target qubits conditioned on the control qubit being |1\\rangle.

    Args on construction:
        wires: ``[control, target0, target1]``.
    """

    # Identity except for a swap of basis states |101> and |110>.
    _matrix = jnp.array(
        [
            [1, 0, 0, 0, 0, 0, 0, 0],
            [0, 1, 0, 0, 0, 0, 0, 0],
            [0, 0, 1, 0, 0, 0, 0, 0],
            [0, 0, 0, 1, 0, 0, 0, 0],
            [0, 0, 0, 0, 1, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 1, 0],
            [0, 0, 0, 0, 0, 1, 0, 0],
            [0, 0, 0, 0, 0, 0, 0, 1],
        ],
        dtype=_cdtype(),
    )
    is_controlled = True
    _num_wires = 3

    def __init__(
        self, wires: Union[List[int], Tuple[int, ...]] = (0, 1, 2), **kwargs
    ) -> None:
        """Initialise a Controlled-SWAP (Fredkin) gate.

        Args:
            wires: Three-element sequence ``[control, target0, target1]``.
                Defaults to an immutable ``(0, 1, 2)`` rather than a shared
                list.
            **kwargs: Forwarded to :class:`Operation`.
        """
        super().__init__(wires=wires, **kwargs)

795 

796 

def _make_controlled_rotation_gate(pauli_class: type, name: str) -> type:
    """Factory for controlled rotation gates CRX, CRY, CRZ.

    Each gate has the form
    ``CR_P(\\theta) = |0><0| \\otimes I + |1><1| \\otimes R_P(\\theta)``,
    with the control on the first wire and the target on the second.

    Args:
        pauli_class: One of PauliX, PauliY, PauliZ.
        name: Class name for the generated gate (e.g. ``"CRX"``).

    Returns:
        A new :class:`Operation` subclass.
    """
    pauli_mat = pauli_class._matrix

    class _CRotationGate(Operation):
        __doc__ = (
            f"Controlled rotation around the {name[2]} axis.\n\n"
            f"Applies R{name[2]}(\\theta) on the target qubit conditioned on the "
            f"control qubit being in state |1\\rangle.\n\n"
            f".. math::\n"
            f"{name}(\\theta) = |0\\rangle\\langle 0| \\otimes I\n"
            f" + |1\\rangle\\langle 1| \\otimes R{name[2]}(\\theta)"
        )
        _num_wires = 2
        _param_names = ("theta",)
        is_controlled = True

        def __init__(
            self,
            theta: float,
            wires: Union[List[int], Tuple[int, ...]] = (0, 1),
            **kwargs,
        ) -> None:
            # Immutable default instead of a shared ``[0, 1]`` list.
            self.theta = theta
            c = jnp.cos(theta / 2)
            s = jnp.sin(theta / 2)
            # Target rotation R_P(theta) = cos(theta/2) I - i sin(theta/2) P
            rot = c * Id._matrix - 1j * s * pauli_mat
            mat = jnp.kron(_P0, Id._matrix) + jnp.kron(_P1, rot)
            super().__init__(wires=wires, matrix=mat, **kwargs)

    _CRotationGate.__name__ = name
    _CRotationGate.__qualname__ = name
    return _CRotationGate

836 

837 

# Controlled Pauli rotations (control = wires[0], target = wires[1]).
CRX, CRY, CRZ = (
    _make_controlled_rotation_gate(pauli, label)
    for pauli, label in ((PauliX, "CRX"), (PauliY, "CRY"), (PauliZ, "CRZ"))
)

841 

842 

class Rot(Operation):
    """General single-qubit rotation:
    Rot(\\phi, \\theta, \\omega) = RZ(\\omega) RY(\\theta) RZ(\\phi).

    Any SU(2) element (hence any single-qubit unitary up to global phase)
    can be written this way; the gate therefore carries three parameters.
    """

    _num_wires = 1
    _param_names = ("phi", "theta", "omega")

    def __init__(
        self,
        phi: float,
        theta: float,
        omega: float,
        wires: Union[int, List[int]] = 0,
        **kwargs,
    ) -> None:
        """Create a Z-Y-Z Euler rotation.

        Args:
            phi: Angle of the first-applied RZ (radians).
            theta: Angle of the RY (radians).
            omega: Angle of the last-applied RZ (radians).
            wires: Qubit index or list of qubit indices this gate acts on.
            **kwargs: Forwarded to :class:`Operation`.
        """
        self.phi, self.theta, self.omega = phi, theta, omega

        def _axis_rotation(angle, pauli):
            # R_P(angle) = cos(angle/2) I - i sin(angle/2) P
            return jnp.cos(angle / 2) * Id._matrix - 1j * jnp.sin(angle / 2) * pauli

        # The rightmost factor acts first: RZ(phi), then RY(theta), then RZ(omega).
        mat = (
            _axis_rotation(omega, PauliZ._matrix)
            @ _axis_rotation(theta, PauliY._matrix)
            @ _axis_rotation(phi, PauliZ._matrix)
        )
        super().__init__(wires=wires, matrix=mat, **kwargs)

883 

884 

class PauliRot(Operation):
    """Multi-qubit Pauli rotation: exp(-i \\theta/2 P) for a Pauli word P.

    The Pauli word is given as a string of ``'I'``, ``'X'``, ``'Y'``, ``'Z'``
    characters (one per qubit). The rotation matrix is computed as
    ``cos(\\theta/2) I - i sin(\\theta/2) P`` where *P* is the tensor product of the
    corresponding single-qubit Pauli matrices.

    Example::

        PauliRot(0.5, "XY", wires=[0, 1])
    """

    _param_names = ("theta",)

    # Map from character to 2x2 matrix
    _PAULI_MAP = {
        "I": Id._matrix,
        "X": PauliX._matrix,
        "Y": PauliY._matrix,
        "Z": PauliZ._matrix,
    }

    def __init__(
        self, theta: float, pauli_word: str, wires: Union[int, List[int]] = 0, **kwargs
    ) -> None:
        """Initialise a PauliRot gate.

        Args:
            theta: Rotation angle in radians.
            pauli_word: A string of ``'I'``, ``'X'``, ``'Y'``, ``'Z'``
                characters specifying the Pauli tensor product, one per wire.
            wires: Qubit index or list of qubit indices this gate acts on.
            **kwargs: Forwarded to :class:`Operation`.

        Raises:
            ValueError: If *pauli_word* is empty, contains characters other
                than ``I``/``X``/``Y``/``Z``, or its length differs from the
                number of wires.
        """
        # A length mismatch would otherwise only surface later as an obscure
        # reshape error when the gate is applied.
        n_wires = len(wires) if isinstance(wires, (list, tuple)) else 1
        if len(pauli_word) != n_wires:
            raise ValueError(
                f"PauliRot word {pauli_word!r} has {len(pauli_word)} character(s) "
                f"but {n_wires} wire(s) were given: {wires}"
            )

        self.theta = theta
        self.pauli_word = pauli_word

        P = self._pauli_matrix(pauli_word)
        dim = P.shape[0]
        mat = (
            jnp.cos(theta / 2) * jnp.eye(dim, dtype=_cdtype())
            - 1j * jnp.sin(theta / 2) * P
        )
        super().__init__(wires=wires, matrix=mat, **kwargs)

    @classmethod
    def _pauli_matrix(cls, pauli_word: str) -> jnp.ndarray:
        """Return the Kronecker product of the Paulis named in *pauli_word*.

        Args:
            pauli_word: String of ``'I'``, ``'X'``, ``'Y'``, ``'Z'``.

        Returns:
            The ``(2**len(word), 2**len(word))`` tensor-product matrix.

        Raises:
            ValueError: If the word is empty or has invalid characters.
        """
        from functools import reduce as _reduce

        if not pauli_word:
            raise ValueError("PauliRot requires a non-empty Pauli word.")
        invalid = sorted(set(pauli_word) - set(cls._PAULI_MAP))
        if invalid:
            raise ValueError(
                f"Invalid character(s) {invalid} in Pauli word {pauli_word!r}; "
                "expected only 'I', 'X', 'Y', 'Z'."
            )
        return _reduce(jnp.kron, [cls._PAULI_MAP[c] for c in pauli_word])

    def generator(self) -> Operation:
        """Return the generator Pauli tensor product as an :class:`Operation`.

        The generator of ``PauliRot(\\theta, word, wires)`` is the tensor product
        of single-qubit Pauli matrices specified by *word*. The returned
        :class:`Hermitian` wraps that matrix and the gate's wires.

        Returns:
            :class:`Hermitian` operation representing the Pauli tensor product.
        """
        P = self._pauli_matrix(self.pauli_word)
        return Hermitian(matrix=P, wires=self.wires, record=False)

948 

949 

class KrausChannel(Operation):
    r"""Base class for noise channels defined by a set of Kraus operators.

    A Kraus channel acts on a density matrix as
    :math:`\phi(\rho) = \sum_k K_k \rho K_k^\dagger` and is the most general
    physical operation on a quantum state. A pure unitary gate corresponds to
    a single operator :math:`K_0 = U` (so :math:`K_0^\dagger K_0 = I`); noisy
    channels use several operators.

    Subclasses must implement :meth:`kraus_matrices` and return a list of JAX
    arrays. The statevector entry points (:meth:`apply_to_state`,
    :meth:`apply_to_state_tensor`) and the :attr:`matrix` property always
    raise :class:`TypeError`: Kraus channels require a density-matrix
    representation and cannot be applied to a pure statevector in general.
    """

    def kraus_matrices(self) -> List[jnp.ndarray]:
        """Return the list of Kraus operators for this channel.

        Returns:
            List of 2-D JAX arrays, each of shape ``(2**k, 2**k)`` where k
            is the number of target qubits.

        Raises:
            NotImplementedError: Subclasses must override this method.
        """
        raise NotImplementedError

    @property
    def matrix(self) -> jnp.ndarray:
        """Raise TypeError: a noise channel has no single unitary matrix.

        Raises:
            TypeError: Always raised; use :meth:`apply_to_density` instead.
        """
        raise TypeError(
            f"{self.__class__.__name__} is a noise channel and has no single "
            "unitary matrix. Use apply_to_density() instead."
        )

    def _raise_statevector_error(self) -> None:
        """Raise the TypeError shared by every pure-statevector entry point."""
        raise TypeError(
            f"{self.__class__.__name__} is a noise channel and cannot be "
            "applied to a pure statevector. Use execute(type='density') instead."
        )

    def apply_to_state(self, state: jnp.ndarray, n_qubits: int) -> jnp.ndarray:
        """Raise TypeError: noise channels require density-matrix simulation.

        Args:
            state: Statevector (unused).
            n_qubits: Number of qubits (unused).

        Raises:
            TypeError: Always raised; use ``execute(type='density')`` instead.
        """
        self._raise_statevector_error()

    def apply_to_state_tensor(self, psi: jnp.ndarray, n_qubits: int) -> jnp.ndarray:
        """Raise TypeError: noise channels require density-matrix simulation.

        Args:
            psi: Statevector tensor (unused).
            n_qubits: Number of qubits (unused).

        Raises:
            TypeError: Always raised; use ``execute(type='density')`` instead.
        """
        self._raise_statevector_error()

    def apply_to_density(self, rho: jnp.ndarray, n_qubits: int) -> jnp.ndarray:
        r"""Apply :math:`\phi(\rho) = \sum_k K_k \rho K_k^\dagger` by tensor contraction.

        The density matrix is viewed as a rank-``2 * n_qubits`` tensor whose
        first ``n_qubits`` axes are the ket (row) qubits and whose last
        ``n_qubits`` axes are the bra (column) qubits. For each Kraus
        operator, the ket axes are contracted with ``K`` and the bra axes with
        ``conj(K)`` (equivalent to right-multiplying by ``K^dagger``). The
        contractions use the shared :func:`_contract_and_restore` helper, and
        the terms are summed over all Kraus operators.

        Args:
            rho: Density matrix of shape ``(2**n_qubits, 2**n_qubits)``.
            n_qubits: Total number of qubits in the circuit.

        Returns:
            Updated density matrix of shape ``(2**n_qubits, 2**n_qubits)``.
        """
        k = len(self.wires)
        dim = 2**n_qubits
        # Bra axis of qubit w sits n_qubits positions after its ket axis.
        bra_wires = [w + n_qubits for w in self.wires]
        # rho's tensor view is the same for every Kraus term: reshape it once.
        rho_tensor = rho.reshape((2,) * 2 * n_qubits)
        rho_out = jnp.zeros_like(rho)

        for K in self.kraus_matrices():
            K_t = K.reshape((2,) * 2 * k)
            term = _contract_and_restore(rho_tensor, K_t, k, self.wires)
            term = _contract_and_restore(term, jnp.conj(K_t), k, bra_wires)
            rho_out = rho_out + term.reshape(dim, dim)

        return rho_out

1039 

1040 

class BitFlip(KrausChannel):
    r"""Bit-flip noise: Pauli-X is applied to one qubit with probability *p*.

    .. math::
        K_0 = \sqrt{1-p}\,I, \quad K_1 = \sqrt{p}\,X

    with flip probability :math:`p \in [0, 1]`.
    """

    _num_wires = 1
    _param_names = ("p",)

    def __init__(self, p: float, wires: Union[int, List[int]] = 0) -> None:
        """Validate and store the bit-flip probability.

        Args:
            p: Probability of a bit flip; must lie in [0, 1].
            wires: Qubit index or list of qubit indices the channel acts on.

        Raises:
            ValueError: If *p* lies outside [0, 1].
        """
        probability_ok = 0.0 <= p <= 1.0
        if not probability_ok:
            raise ValueError("p must be in [0, 1].")
        self.p = p
        super().__init__(wires=wires)

    def kraus_matrices(self) -> List[jnp.ndarray]:
        """Return ``[sqrt(1 - p) * I, sqrt(p) * X]``.

        Returns:
            The no-flip and flip Kraus operators, in that order.
        """
        keep_weight = jnp.sqrt(1 - self.p)
        flip_weight = jnp.sqrt(self.p)
        return [keep_weight * Id._matrix, flip_weight * PauliX._matrix]

1078 

1079 

class PhaseFlip(KrausChannel):
    r"""Phase-flip noise: Pauli-Z is applied to one qubit with probability *p*.

    .. math::
        K_0 = \sqrt{1-p}\,I, \quad K_1 = \sqrt{p}\,Z

    with flip probability :math:`p \in [0, 1]`.
    """

    _num_wires = 1
    _param_names = ("p",)

    def __init__(self, p: float, wires: Union[int, List[int]] = 0) -> None:
        """Validate and store the phase-flip probability.

        Args:
            p: Probability of a phase flip; must lie in [0, 1].
            wires: Qubit index or list of qubit indices the channel acts on.

        Raises:
            ValueError: If *p* lies outside [0, 1].
        """
        probability_ok = 0.0 <= p <= 1.0
        if not probability_ok:
            raise ValueError("p must be in [0, 1].")
        self.p = p
        super().__init__(wires=wires)

    def kraus_matrices(self) -> List[jnp.ndarray]:
        """Return ``[sqrt(1 - p) * I, sqrt(p) * Z]``.

        Returns:
            The no-flip and flip Kraus operators, in that order.
        """
        keep_weight = jnp.sqrt(1 - self.p)
        flip_weight = jnp.sqrt(self.p)
        return [keep_weight * Id._matrix, flip_weight * PauliZ._matrix]

1117 

1118 

class DepolarizingChannel(KrausChannel):
    r"""Single-qubit depolarizing noise.

    .. math::
        K_0 = \sqrt{1-p}\,I,\quad K_1 = \sqrt{p/3}\,X,\quad
        K_2 = \sqrt{p/3}\,Y,\quad K_3 = \sqrt{p/3}\,Z

    with :math:`p \in [0, 1]`. These operators give
    :math:`\rho \mapsto (1 - 4p/3)\rho + (2p/3) I`, so at :math:`p = 3/4` the
    output is the maximally mixed state.
    """

    _num_wires = 1
    _param_names = ("p",)

    def __init__(self, p: float, wires: Union[int, List[int]] = 0) -> None:
        """Validate and store the depolarization probability.

        Args:
            p: Depolarization probability; must lie in [0, 1].
            wires: Qubit index or list of qubit indices the channel acts on.

        Raises:
            ValueError: If *p* lies outside [0, 1].
        """
        probability_ok = 0.0 <= p <= 1.0
        if not probability_ok:
            raise ValueError("p must be in [0, 1].")
        self.p = p
        super().__init__(wires=wires)

    def kraus_matrices(self) -> List[jnp.ndarray]:
        """Return the I, X, Y, Z Kraus operators, in that order.

        Returns:
            ``[sqrt(1-p) I, sqrt(p/3) X, sqrt(p/3) Y, sqrt(p/3) Z]``.
        """
        p = self.p
        error_weight = jnp.sqrt(p / 3)
        identity_term = jnp.sqrt(1 - p) * Id._matrix
        error_terms = [
            error_weight * pauli._matrix for pauli in (PauliX, PauliY, PauliZ)
        ]
        return [identity_term, *error_terms]

1159 

1160 

class AmplitudeDamping(KrausChannel):
    r"""Single-qubit amplitude damping (energy relaxation) noise.

    .. math::
        K_0 = \begin{pmatrix}1 & 0\\ 0 & \sqrt{1-\gamma}\end{pmatrix},\quad
        K_1 = \begin{pmatrix}0 & \sqrt{\gamma}\\ 0 & 0\end{pmatrix}

    where :math:`\gamma \in [0, 1]` is the probability that
    :math:`|1\rangle` decays to :math:`|0\rangle`.
    """

    _num_wires = 1
    _param_names = ("gamma",)

    def __init__(self, gamma: float, wires: Union[int, List[int]] = 0) -> None:
        """Validate and store the decay probability.

        Args:
            gamma: Energy-loss probability; must lie in [0, 1].
            wires: Qubit index or list of qubit indices the channel acts on.

        Raises:
            ValueError: If *gamma* lies outside [0, 1].
        """
        gamma_ok = 0.0 <= gamma <= 1.0
        if not gamma_ok:
            raise ValueError("gamma must be in [0, 1].")
        self.gamma = gamma
        super().__init__(wires=wires)

    def kraus_matrices(self) -> List[jnp.ndarray]:
        """Return the no-decay and decay Kraus operators, in that order.

        Returns:
            ``[K0, K1]`` as given in the class docstring.
        """
        survive = jnp.sqrt(1 - self.gamma)
        decay = jnp.sqrt(self.gamma)
        ctype = _cdtype()
        no_jump = jnp.array([[1.0, 0.0], [0.0, survive]], dtype=ctype)
        jump = jnp.array([[0.0, decay], [0.0, 0.0]], dtype=ctype)
        return [no_jump, jump]

1200 

1201 

class PhaseDamping(KrausChannel):
    r"""Single-qubit phase damping (pure dephasing) noise.

    .. math::
        K_0 = \begin{pmatrix}1 & 0\\ 0 & \sqrt{1-\gamma}\end{pmatrix},\quad
        K_1 = \begin{pmatrix}0 & 0\\ 0 & \sqrt{\gamma}\end{pmatrix}

    where :math:`\gamma \in [0, 1]` is the phase damping probability.
    """

    _num_wires = 1
    _param_names = ("gamma",)

    def __init__(self, gamma: float, wires: Union[int, List[int]] = 0) -> None:
        """Validate and store the dephasing probability.

        Args:
            gamma: Phase-damping probability; must lie in [0, 1].
            wires: Qubit index or list of qubit indices the channel acts on.

        Raises:
            ValueError: If *gamma* lies outside [0, 1].
        """
        gamma_ok = 0.0 <= gamma <= 1.0
        if not gamma_ok:
            raise ValueError("gamma must be in [0, 1].")
        self.gamma = gamma
        super().__init__(wires=wires)

    def kraus_matrices(self) -> List[jnp.ndarray]:
        """Return the two dephasing Kraus operators.

        Returns:
            ``[K0, K1]`` as given in the class docstring.
        """
        coherent = jnp.sqrt(1 - self.gamma)
        dephase = jnp.sqrt(self.gamma)
        ctype = _cdtype()
        first = jnp.array([[1.0, 0.0], [0.0, coherent]], dtype=ctype)
        second = jnp.array([[0.0, 0.0], [0.0, dephase]], dtype=ctype)
        return [first, second]

1240 

1241 

class ThermalRelaxationError(KrausChannel):
    r"""Single-qubit thermal relaxation noise (simultaneous T_1 and T_2 decay).

    With ``p_reset = 1 - exp(-tg/T_1)`` and ``eT2 = exp(-tg/T_2)`` there are
    two regimes:

    T_2 <= T_1:
        The channel is a probabilistic mixture of identity, a phase flip
        (probability p_z), reset to :math:`|0\rangle` (p_r0) and reset to
        :math:`|1\rangle` (p_r1), giving six Kraus operators.

    T_2 > T_1:
        The mixture form would need a negative phase-flip probability, so the
        Kraus operators are read off the eigendecomposition of the channel's
        Choi matrix: :math:`K_i = \sqrt{|\lambda_i|}\,\mathrm{mat}(v_i)`.

    Attributes:
        pe: Excited-state population (thermal population of |1\rangle).
        t1: T_1 longitudinal relaxation time.
        t2: T_2 transverse dephasing time.
        tg: Gate duration.
    """

    _num_wires = 1
    _param_names = ("pe", "t1", "t2", "tg")

    def __init__(
        self,
        pe: float,
        t1: float,
        t2: float,
        tg: float,
        wires: Union[int, List[int]] = 0,
    ) -> None:
        """Validate and store the thermal relaxation parameters.

        Args:
            pe: Excited-state population, in [0, 1].
            t1: T_1 relaxation time; must be > 0.
            t2: T_2 dephasing time; must be > 0 and <= 2·T_1.
            tg: Gate duration; must be >= 0.
            wires: Qubit index or list of qubit indices the channel acts on.

        Raises:
            ValueError: On the first violated constraint, checked in the
                order pe, t1, t2, t2 vs. t1, tg.
        """
        pe_ok = 0.0 <= pe <= 1.0
        if not pe_ok:
            raise ValueError("pe must be in [0, 1].")
        if t1 <= 0:
            raise ValueError("t1 must be > 0.")
        if t2 <= 0:
            raise ValueError("t2 must be > 0.")
        if t2 > 2 * t1:
            raise ValueError("t2 must be <= 2·t1.")
        if tg < 0:
            raise ValueError("tg must be >= 0.")
        self.pe, self.t1, self.t2, self.tg = pe, t1, t2, tg
        super().__init__(wires=wires)

    def kraus_matrices(self) -> List[jnp.ndarray]:
        """Build the Kraus operators for the current parameter regime.

        Returns:
            Six 2x2 operators when T_2 <= T_1 (identity, phase flip, two
            reset-to-|0>, two reset-to-|1>), otherwise four operators obtained
            from the Choi matrix eigendecomposition.
        """
        pe, t1, t2, tg = self.pe, self.t1, self.t2, self.tg
        ctype = _cdtype()

        eT1 = jnp.exp(-tg / t1)
        eT2 = jnp.exp(-tg / t2)
        p_reset = 1.0 - eT1

        if t2 <= t1:
            p_z = (1.0 - p_reset) * (1.0 - eT2 / eT1) / 2.0
            p_r0 = (1.0 - pe) * p_reset
            p_r1 = pe * p_reset
            p_id = 1.0 - p_z - p_r0 - p_r1
            # (probability, operator pattern) pairs, in the canonical order.
            weighted_patterns = (
                (p_id, [[1, 0], [0, 1]]),
                (p_z, [[1, 0], [0, -1]]),
                (p_r0, [[1, 0], [0, 0]]),
                (p_r0, [[0, 1], [0, 0]]),
                (p_r1, [[0, 0], [1, 0]]),
                (p_r1, [[0, 0], [0, 1]]),
            )
            return [
                jnp.sqrt(prob) * jnp.array(pattern, dtype=ctype)
                for prob, pattern in weighted_patterns
            ]

        # Choi matrix sum_ij |i><j| (x) E(|i><j|), index 2*i + k.
        choi = jnp.array(
            [
                [1 - pe * p_reset, 0, 0, eT2],
                [0, pe * p_reset, 0, 0],
                [0, 0, (1 - pe) * p_reset, 0],
                [eT2, 0, 0, 1 - (1 - pe) * p_reset],
            ],
            dtype=ctype,
        )
        weights, vectors = jnp.linalg.eigh(choi)
        # Column i of `vectors` holds K[k, j] at entry 2*j + k -> Fortran-order reshape.
        return [
            (
                jnp.sqrt(jnp.abs(weights[i])) * vectors[:, i].reshape(2, 2, order="F")
            ).astype(ctype)
            for i in range(4)
        ]

1356 

1357 

class QubitChannel(KrausChannel):
    r"""Kraus channel built from an arbitrary user-supplied operator list.

    Stands in for PennyLane's ``qml.QubitChannel``. The operators are
    expected (not checked) to satisfy :math:`\sum_k K_k^\dagger K_k = I`.

    Example::

        kraus_ops = [jnp.sqrt(0.9) * jnp.eye(2), jnp.sqrt(0.1) * PauliX._matrix]
        QubitChannel(kraus_ops, wires=0)
    """

    def __init__(
        self, kraus_ops: List[jnp.ndarray], wires: Union[int, List[int]] = 0
    ) -> None:
        """Store the Kraus operators, converted to the active complex dtype.

        Args:
            kraus_ops: Kraus matrices, each a square ``2**k x 2**k`` array
                where k = ``len(wires)``.
            wires: Qubit index or list of qubit indices the channel acts on.
        """
        ctype = _cdtype()
        self._kraus_ops = [jnp.asarray(op, dtype=ctype) for op in kraus_ops]
        super().__init__(wires=wires)

    def kraus_matrices(self) -> List[jnp.ndarray]:
        """Return the Kraus operators captured at construction time.

        Returns:
            The stored list of complex Kraus matrices.
        """
        return self._kraus_ops

1390 

1391 

# Index-aligned single-qubit Pauli tables: position i of each list refers to
# the same Pauli (0 = I, 1 = X, 2 = Y, 3 = Z). _PAULI_MATS holds the bare 2x2
# arrays so hot loops avoid constructing Operation objects.
_PAULI_CLASSES = [Id, PauliX, PauliY, PauliZ]
_PAULI_MATS = [pauli_cls._matrix for pauli_cls in _PAULI_CLASSES]
_PAULI_LABELS = list("IXYZ")

1396 

1397 

def evolve_pauli_with_clifford(
    clifford: Operation,
    pauli: Operation,
    adjoint_left: bool = True,
) -> Operation:
    r"""Conjugate an operator by a Clifford gate.

    Both matrices are embedded into the space of the sorted union of their
    wires, then ``C^dagger P C`` (default) or ``C P C^dagger`` is formed and
    wrapped in a non-recorded :class:`Hermitian`.

    Args:
        clifford: The Clifford gate ``C``.
        pauli: The Pauli / Hermitian operator ``P``.
        adjoint_left: ``True`` for ``C^dagger P C``; ``False`` for ``C P C^dagger``.

    Returns:
        A :class:`Hermitian` over the union of wires holding the evolved matrix.
    """
    wires = sorted(set(clifford.wires).union(pauli.wires))
    n_wires = len(wires)

    U = _embed_matrix(clifford.matrix, clifford.wires, wires, n_wires)
    P = _embed_matrix(pauli.matrix, pauli.wires, wires, n_wires)
    U_dag = jnp.conj(U).T

    left, right = (U_dag, U) if adjoint_left else (U, U_dag)
    return Hermitian(matrix=left @ P @ right, wires=wires, record=False)

1431 

1432 

def _embed_matrix(
    mat: jnp.ndarray,
    op_wires: list,
    all_wires: list,
    n_total: int,
) -> jnp.ndarray:
    """Lift a gate matrix to the full register described by *all_wires*.

    The gate is first Kronecker-extended with one 2x2 identity per wire it
    does not touch, which yields the qubit layout ``[op_wires..., spectators...]``.
    That layout is then permuted into the order of *all_wires*. When the gate
    already acts on exactly *all_wires* in order, *mat* is returned unchanged.

    Args:
        mat: Gate matrix of shape ``(2**k, 2**k)`` with ``k = len(op_wires)``.
        op_wires: Wires the gate acts on.
        all_wires: Full ordered list of wires.
        n_total: ``len(all_wires)``.

    Returns:
        A ``(2**n_total, 2**n_total)`` matrix.
    """
    target = list(all_wires)
    if len(op_wires) == n_total and list(op_wires) == target:
        return mat

    spectators = [w for w in target if w not in op_wires]
    embedded = mat
    if spectators:
        identity = jnp.eye(2, dtype=_cdtype())
        for _ in spectators:
            embedded = jnp.kron(embedded, identity)

    layout = [*op_wires, *spectators]
    if layout == target:
        return embedded
    # Entry j: current position of the wire that must end up at position j.
    return _permute_matrix(embedded, [layout.index(w) for w in target], n_total)

1476 

1477 

def _permute_matrix(mat: jnp.ndarray, perm: list, n_qubits: int) -> jnp.ndarray:
    """Reorder the qubit (tensor-factor) layout of an operator matrix.

    The ``(2**n, 2**n)`` matrix is viewed as a rank-``2n`` tensor whose first
    ``n`` axes are row qubits and whose last ``n`` axes are column qubits.
    The same permutation is applied to both halves with ``jnp.transpose``
    semantics: position ``i`` of the result holds the qubit that sat at
    position ``perm[i]`` of *mat*. For example, with
    ``mat = kron(kron(A, B), C)`` and ``perm = [1, 2, 0]`` the result is
    ``kron(kron(B, C), A)``.

    Args:
        mat: Square matrix of dimension ``2**n_qubits``.
        perm: Permutation of ``range(n_qubits)`` (any sequence of ints,
            e.g. a list or tuple).
        n_qubits: Number of qubits.

    Returns:
        Permuted matrix of the same shape.
    """
    dim = 2**n_qubits
    # Normalize to a list of ints so tuples / numpy ints can be concatenated.
    row_axes = [int(p) for p in perm]
    col_axes = [p + n_qubits for p in row_axes]
    tensor = mat.reshape((2,) * (2 * n_qubits))
    tensor = jnp.transpose(tensor, row_axes + col_axes)
    return tensor.reshape(dim, dim)

1500 

1501 

def pauli_decompose(matrix: jnp.ndarray, wire_order: Optional[List[int]] = None):
    r"""Return the dominant Pauli tensor-product term of a Hermitian matrix.

    Every ``2**n x 2**n`` matrix ``M`` expands as ``M = sum_P c_P P`` over the
    ``4**n`` Pauli strings ``P``, with ``c_P = Tr(P · M) / 2**n``. This
    function evaluates every coefficient and returns only the term with the
    largest ``|c_P|``; on ties the first string in iteration order (I, X, Y,
    Z, last qubit varying fastest) wins. That is sufficient for the
    Fourier-tree algorithm, which only needs the single non-zero Pauli term
    produced by Clifford conjugation of a Pauli operator.

    A matrix with no non-zero coefficient (e.g. the zero matrix) yields the
    all-identity string with coefficient ``0.0`` for any number of qubits.

    Args:
        matrix: A ``(2**n, 2**n)`` Hermitian matrix with ``n >= 1``.
        wire_order: Optional list of wire indices. If ``None``, defaults
            to ``[0, 1, ..., n-1]``.

    Returns:
        A tuple ``(coeff, op)`` where *coeff* is the coefficient of the
        dominant term and *op* is the Pauli :class:`Operation` (PauliX,
        PauliY, PauliZ, I, or a :class:`Hermitian` for multi-qubit tensor
        products). *op* carries a ``_pauli_label`` string such as ``"XZ"``.

    Raises:
        ValueError: If *matrix* is not square with a power-of-two dimension
            of at least 2.
    """
    from itertools import product as _product
    from functools import reduce as _reduce

    if matrix.ndim != 2 or matrix.shape[0] != matrix.shape[1]:
        raise ValueError(
            f"pauli_decompose expects a square matrix, got shape {matrix.shape}."
        )
    dim = int(matrix.shape[0])
    if dim < 2 or dim & (dim - 1):
        raise ValueError(
            "pauli_decompose expects a 2**n x 2**n matrix with n >= 1, "
            f"got dimension {dim}."
        )
    n_qubits = dim.bit_length() - 1

    if wire_order is None:
        wire_order = list(range(n_qubits))

    # Single qubit: fast path over the four Pauli matrices.
    if n_qubits == 1:
        best_idx, best_coeff = 0, 0.0
        for idx, P in enumerate(_PAULI_MATS):
            coeff = jnp.trace(P @ matrix) / 2.0
            if jnp.abs(coeff) > jnp.abs(best_coeff):
                best_idx, best_coeff = idx, coeff
        result_op = _PAULI_CLASSES[best_idx](wires=wire_order[0], record=False)
        result_op._pauli_label = _PAULI_LABELS[best_idx]
        return best_coeff, result_op

    # Multi-qubit: scan all 4**n Pauli strings. Start from the all-identity
    # string so that a matrix with no non-zero coefficient still yields a
    # valid label (previously best_label stayed None and crashed below).
    best_label = (0,) * n_qubits
    best_coeff = 0.0
    for indices in _product(range(4), repeat=n_qubits):
        P = _reduce(jnp.kron, [_PAULI_MATS[i] for i in indices])
        coeff = jnp.trace(P @ matrix) / dim
        if jnp.abs(coeff) > jnp.abs(best_coeff):
            best_coeff = coeff
            best_label = indices

    non_identity = [q for q, idx in enumerate(best_label) if idx != 0]

    if len(non_identity) == 1:
        # Single non-trivial factor: return the plain Pauli on that wire.
        q = non_identity[0]
        idx = best_label[q]
        result_op = _PAULI_CLASSES[idx](wires=wire_order[q], record=False)
        result_op._pauli_label = _PAULI_LABELS[idx]
        return best_coeff, result_op

    if not non_identity:
        # All-identity string: a single Id stands in, label keeps full length.
        result_op = Id(wires=wire_order[0], record=False)
        result_op._pauli_label = "I" * n_qubits
        return best_coeff, result_op

    # Genuine multi-qubit tensor product -> Hermitian with the label attached.
    P = _reduce(jnp.kron, [_PAULI_MATS[i] for i in best_label])
    result_op = Hermitian(matrix=P, wires=wire_order, record=False)
    result_op._pauli_label = "".join(_PAULI_LABELS[i] for i in best_label)
    return best_coeff, result_op

1578 

1579 

def pauli_string_from_operation(op: Operation) -> str:
    """Return the Pauli word (e.g. ``"X"``, ``"ZZ"``) that an operation represents.

    Sources are tried in order: the ``pauli_word`` of a :class:`PauliRot`;
    a ``_pauli_label`` attached by :func:`pauli_decompose`; the operation's
    name for ``PauliX``/``PauliY``/``PauliZ``/``I``; and finally a
    :func:`pauli_decompose` of the operation's matrix.

    Args:
        op: A quantum operation.

    Returns:
        The Pauli word string.
    """
    if isinstance(op, PauliRot) and hasattr(op, "pauli_word"):
        return op.pauli_word

    missing = object()
    stored_label = getattr(op, "_pauli_label", missing)
    if stored_label is not missing:
        return stored_label

    letter = {"PauliX": "X", "PauliY": "Y", "PauliZ": "Z", "I": "I"}.get(op.name)
    if letter is not None:
        return letter

    # Last resort: find the dominant Pauli term of the matrix.
    return pauli_decompose(op.matrix, wire_order=op.wires)[1]._pauli_label