Coverage for qml_essentials / operations.py: 88%

607 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-05-16 10:19 +0000

1from typing import Callable, List, Optional, Tuple, Union 

2from functools import lru_cache 

3import string 

4import numpy as np 

5 

6import jax 

7import jax.numpy as jnp 

8 

9from qml_essentials.tape import active_tape, recording # noqa: F401 (re-export) 

10 

11 

def _cdtype():
    """Pick the complex dtype matching JAX's current precision setting.

    Returns:
        ``jnp.complex128`` when ``jax.config.x64_enabled`` is truthy,
        otherwise ``jnp.complex64``.
    """
    if jax.config.x64_enabled:
        return jnp.complex128
    return jnp.complex64

17 

18 

@lru_cache(maxsize=256)
def _einsum_subscript(
    n: int,
    k: int,
    target_axes: Tuple[int, ...],
) -> str:
    """Build an ``einsum`` subscript that fuses contraction + axis restore.

    The gate operand comes first. Its indices are ``k`` fresh "output"
    letters followed by the letters of the state axes being contracted.
    In the result, every target axis is relabelled with the matching output
    letter, so the contracted axes end up at their original positions.

    Args:
        n: Total rank of the state tensor (number of qubits for statevectors,
            ``2 * n_qubits`` for density matrices).
        k: Number of qubits the gate acts on.
        target_axes: Tuple of k axis indices in the state tensor that the
            gate contracts against. Must be hashable (a tuple) because the
            result is memoised with ``lru_cache``.

    Returns:
        ``einsum`` subscript string, e.g. ``"db,abc->adc"`` for a 1-qubit
        gate on axis 1 of a rank-3 state.

    Raises:
        IndexError: If ``n + k`` exceeds the number of available index
            letters, or if a target axis lies outside the state rank.
    """
    letters = string.ascii_letters
    # einsum indices are single letters; fail loudly instead of silently
    # truncating the state indices when we would run out of them.
    if n + k > len(letters):
        raise IndexError(
            f"einsum subscript needs {n + k} index letters, "
            f"but only {len(letters)} are available"
        )

    # State indices: one letter per axis
    state_idx = letters[:n]
    # Fresh letters for the gate's output axes
    new_out = [letters[n + i] for i in range(k)]
    # Contracted indices (the state axes being replaced by the gate)
    contracted = [state_idx[ax] for ax in target_axes]

    # Result indices: replace target axes with the new output letters
    result_idx = list(state_idx)
    for i, ax in enumerate(target_axes):
        result_idx[ax] = new_out[i]

    # Gate index layout: (out0, out1, ..., in0, in1, ...)
    gate_idx = "".join(new_out) + "".join(contracted)
    return f"{gate_idx},{state_idx}->{''.join(result_idx)}"

51 

52 

def _contract_and_restore(
    tensor: jnp.ndarray,
    gate: jnp.ndarray,
    k: int,
    target_axes: List[int],
) -> jnp.ndarray:
    """Apply ``gate`` to selected axes of ``tensor`` in a single einsum.

    The subscript string comes from :func:`_einsum_subscript`, which is
    memoised, so it is only built once for each distinct
    ``(tensor.ndim, k, target_axes)`` combination.

    Args:
        tensor: Rank-N tensor (e.g. ``(2,)*n`` for states or ``(2,)*2n``
            for density matrices).
        gate: Reshaped gate tensor of shape ``(2,)*2k``.
        k: Number of qubits the gate acts on (= ``len(target_axes)``).
        target_axes: The k axes of tensor to contract against.

    Returns:
        Tensor of the same rank as the input, with the contracted axes
        back at their original positions.
    """
    # The cache key must be hashable, hence the tuple conversion.
    axes_key = tuple(target_axes)
    rank = tensor.ndim
    return jnp.einsum(_einsum_subscript(rank, k, axes_key), gate, tensor)

78 

79 

class Operation:
    """Base class for any quantum operation or observable.

    Further gates should inherit from this class to realise more specific
    operations. Generally, operations are created by instantiation inside a
    circuit function passed to :class:`Script`; the instance is
    automatically appended to the active tape.

    An ``Operation`` can also serve as an *observable*: its matrix is used to
    compute expectation values via ``apply_to_state`` / ``apply_to_density``.

    Attributes:
        _matrix: Class-level default gate matrix. Subclasses set this to their
            fixed unitary. Instances may override it via the *matrix* argument
            to ``__init__``.
        _num_wires: Expected number of wires for this gate. Subclasses set
            this to enforce wire count validation. ``None`` means any number
            of wires is accepted.
        _param_names: Tuple of attribute names for the gate parameters.
            Used by :attr:`parameters` and :meth:`__repr__`.
    """

    # Subclasses should set this to the gate's unitary / matrix
    _matrix: jnp.ndarray = None
    is_controlled = False
    _num_wires: Optional[int] = None
    _param_names: Tuple[str, ...] = ()

    def __init__(
        self,
        wires: Union[int, List[int]] = 0,
        matrix: Optional[jnp.ndarray] = None,
        record: bool = True,
        input_idx: int = -1,
        name: Optional[str] = None,
    ) -> None:
        """Initialise the operation and optionally register it on the active tape.

        Args:
            wires: Qubit index or list of qubit indices this operation acts on.
            matrix: Optional explicit gate matrix. When provided it overrides
                the class-level ``_matrix`` attribute.
            record: If ``True`` (default) and a tape is currently recording,
                append this operation to the tape. Set to ``False`` for
                auxiliary objects that should not appear in the circuit
                (e.g. Hamiltonians used only to build time-dependent
                evolutions).
            input_idx: Marks the operation as input with the corresponding
                input index, which is useful for the analytical Fourier
                coefficients computation, but has no effect otherwise.
            name: Optional explicit name for this operation. When ``None``
                (default), the class name is used (e.g. ``"RX"``).

        Raises:
            ValueError: If ``_num_wires`` is set and the number of wires
                doesn't match, or if duplicate wires are provided.
        """
        self.name = name or self.__class__.__name__
        self.wires = list(wires) if isinstance(wires, (list, tuple)) else [wires]
        self.input_idx = input_idx

        if self._num_wires is not None and len(self.wires) != self._num_wires:
            raise ValueError(
                f"{self.name} expects {self._num_wires} wire(s), "
                f"got {len(self.wires)}: {self.wires}"
            )
        if len(self.wires) != len(set(self.wires)):
            raise ValueError(f"{self.name} received duplicate wires: {self.wires}")

        if matrix is not None:
            self._matrix = matrix

        # If a tape is currently recording, append ourselves
        if record:
            tape = active_tape()
            if tape is not None:
                tape.append(self)

    @property
    def parameters(self) -> list:
        """Return the list of numeric parameters for this operation.

        Uses the declarative ``_param_names`` tuple to collect parameter
        values in a canonical order. Non-parametrized gates return an
        empty list.

        Returns:
            List of parameter values (floats or JAX arrays).
        """
        return [getattr(self, name) for name in self._param_names]

    def __repr__(self) -> str:
        """Return a human-readable representation of this operation.

        Returns:
            A string like ``"RX(0.5000, wires=[0])"`` or ``"CX(wires=[0, 1])"``.
        """

        def _fmt(value) -> str:
            # Format float-like scalars with four decimals; anything that
            # cannot be turned into a Python float falls back to ``str``
            # so that ``repr`` never raises.
            if isinstance(value, (float, np.floating, jnp.ndarray)):
                try:
                    return f"{float(value):.4f}"
                except TypeError:
                    return str(value)
            return str(value)

        params = self.parameters
        if params:
            param_str = ", ".join(_fmt(v) for v in params)
            return f"{self.name}({param_str}, wires={self.wires})"
        return f"{self.name}(wires={self.wires})"

    @property
    def matrix(self) -> jnp.ndarray:
        """Return the base matrix of this operation (before lifting).

        Returns:
            The gate matrix as a JAX array.

        Raises:
            NotImplementedError: If the subclass has not defined ``_matrix``.
        """
        if self._matrix is None:
            raise NotImplementedError(
                f"{self.__class__.__name__} does not define a matrix."
            )
        return self._matrix

    @property
    def wires(self) -> List[int]:
        """Qubit indices this operation acts on.

        Returns:
            List of integer qubit indices.
        """
        return self._wires

    @wires.setter
    def wires(self, wires: Union[int, List[int]]) -> None:
        """Set the qubit indices for this operation.

        Args:
            wires: A single qubit index or a list of qubit indices.
        """
        if isinstance(wires, (list, tuple)):
            self._wires = list(wires)
        else:
            self._wires = [wires]

    @property
    def input_idx(self) -> int:
        """The index of an input

        Returns:
            input_idx: Index of the input
        """
        return self._input_idx

    @input_idx.setter
    def input_idx(self, input_idx: int) -> None:
        """Setter for the input_idx flag

        Args:
            input_idx: Index of the input
        """
        self._input_idx = input_idx

    def _update_tape_operation(self, op: "Operation") -> None:
        """Put a derived operation on the active tape in place of ``self``.

        If ``self`` is the last entry on the active tape (the typical case
        when chaining ``Gate(...).dagger()``), it is replaced by ``op`` so
        that only the derived operation appears on the tape — not both.
        Otherwise ``op`` is appended. Nothing happens when no tape is active.
        Note that this should only be called immediately after the tape is
        updated.

        Args:
            op (Operation): New replaced operation on the tape
        """
        tape = active_tape()
        if tape is not None:
            if tape and tape[-1] is self:
                tape[-1] = op
            else:
                tape.append(op)

    def dagger(self) -> "Operation":
        """Return a new operation, the conjugate transpose (``U\\dagger``)
        Usage inside a circuit function::

            RX(0.5, wires=0).dagger()

        Returns:
            A new :class:`Operation` with matrix ``U\\dagger`` acting on the same wires.

        Raises:
            NotImplementedError: If this operation does not define a matrix.
        """
        # Go through the ``matrix`` property so a missing matrix is reported
        # as NotImplementedError, consistent with ``__add__`` and ``prod``.
        mat = jnp.conj(self.matrix).T
        op = Operation(wires=self.wires, matrix=mat, record=False)

        self._update_tape_operation(op)

        return op

    def power(self, power: int) -> "Operation":
        """Return a new operation, the power (``U^power``)
        Usage inside a circuit function::

            PauliX(wires=0).power(2)

        Args:
            power: Exponent passed to ``jnp.linalg.matrix_power``
                (fractional powers are not supported yet).

        Returns:
            A new :class:`Operation` with matrix ``U^power`` acting on the same wires.

        Raises:
            NotImplementedError: If this operation does not define a matrix.
        """
        # TODO: support fractional powers
        mat = jnp.linalg.matrix_power(self.matrix, power)
        op = Operation(wires=self.wires, matrix=mat, record=False)

        self._update_tape_operation(op)

        return op

    def __mul__(self, other: Union[float, "Operation"]) -> "Operation":
        """Return a new operation, the product between U and a scalar (``U*x``)
        or the composition of two operations.
        Usage inside a circuit function::

            PauliX(wires=0) * x
            PauliX(wires=0) * PauliZ(wires=0)

        Returns:
            A new :class:`Operation` with matrix ``U*x`` acting on the same wires,
            or the composed matrix acting on the appropriate wires.

        Raises:
            NotImplementedError: If this operation does not define a matrix.
        """
        if isinstance(other, Operation):
            return self.__matmul__(other)

        mat = other * self.matrix
        op = Operation(wires=self.wires, matrix=mat, record=False)

        self._update_tape_operation(op)

        return op

    # Also overwrite * for right operands
    __rmul__ = __mul__

    def __add__(self, other: "Operation") -> "Operation":
        """Element-wise addition of two operations on the same wires.

        The result uses the wire order of ``self``. If ``other`` lists the
        same wires in a different order, its matrix is permuted into the
        wire order of ``self`` before the sum is taken.

        Returns:
            A new :class:`Operation` whose matrix is the sum of both matrices,
            or ``NotImplemented`` if *other* is not an :class:`Operation`.

        Raises:
            ValueError: If the wire sets differ.
        """
        if not isinstance(other, Operation):
            return NotImplemented

        if sorted(self.wires) != sorted(other.wires):
            raise ValueError(
                f"Can only add operations acting on the same set of wires, "
                f"got {self.wires} and {other.wires}"
            )

        other_mat = other.matrix
        if other.wires != self.wires:
            # Axis i of the (2,)*2k gate tensor belongs to wires[i] (the
            # layout used by ``apply_to_state``), so reorder both the output
            # and the input axes of ``other`` to follow ``self.wires``.
            k = len(self.wires)
            perm = [other.wires.index(w) for w in self.wires]
            axes = perm + [k + p for p in perm]
            other_mat = (
                other_mat.reshape((2,) * 2 * k).transpose(axes).reshape(2**k, 2**k)
            )

        return Operation(
            wires=self.wires,
            matrix=self.matrix + other_mat,
            record=False,
        )

    def prod(self, *ops: "Operation") -> "Operation":
        """Construct the generalized product (tensor or matrix)
        of this operation with others.

        The resulting operation acts on the union of all wire sets.
        If the wire sets are disjoint, this is a Kronecker product.
        If the wire sets overlap, the corresponding matrices are multiplied.

        Usage::

            res = op1.prod(op2, op3)
            # or
            res = Operation.prod(op1, op2, op3)

        Args:
            *ops: Variable number of :class:`Operation` instances.

        Returns:
            A new :class:`Operation` representing the composed operation.
        """
        if not ops:
            return self

        all_ops = (self,) + ops
        # Ordered union of wires, keeping first-seen order
        all_wires = []
        for op in all_ops:
            for w in op.wires:
                if w not in all_wires:
                    all_wires.append(w)

        n = len(all_wires)

        mat = _embed_matrix(all_ops[0].matrix, all_ops[0].wires, all_wires, n)
        for op in all_ops[1:]:
            mat_other = _embed_matrix(op.matrix, op.wires, all_wires, n)
            mat = mat @ mat_other

        op_names = "*".join(op.name for op in all_ops)
        return Operation(
            wires=all_wires, matrix=mat, name=f"Prod({op_names})", record=False
        )

    def __matmul__(self, other: "Operation") -> "Operation":
        """Tensor (Kronecker) product or matrix product of two operations.

        The resulting operation acts on the union of both wire sets.
        If the wire sets are disjoint, this is a Kronecker product.
        If the wire sets overlap, the corresponding matrices are multiplied.

        Returns:
            A new :class:`Operation` whose matrix represents the composed
            operation on the unified wire set.
        """
        if not isinstance(other, Operation):
            return NotImplemented

        return self.prod(other)

    def lifted_matrix(self, n_qubits: int) -> jnp.ndarray:
        """Return the full ``2**n x 2**n`` matrix embedding this gate.

        Embeds the ``k``-qubit gate matrix into the ``n``-qubit Hilbert space
        by applying it to the identity matrix via :meth:`apply_to_state`.
        This is useful for computing ``Tr(O·\\rho )`` directly without vmap.

        Args:
            n_qubits: Total number of qubits in the circuit.

        Returns:
            The ``(2**n, 2**n)`` matrix of this operation in the full space.
        """
        dim = 2**n_qubits
        # vmap maps over the rows of the identity (the basis vectors); each
        # result is one column of the lifted matrix, hence the transpose.
        return jax.vmap(lambda col: self.apply_to_state(col, n_qubits))(
            jnp.eye(dim, dtype=_cdtype())
        ).T

    def apply_to_state(self, state: jnp.ndarray, n_qubits: int) -> jnp.ndarray:
        """Apply this gate to a statevector via tensor contraction.

        The statevector (shape ``(2**n,)``) is reshaped into a rank-n tensor
        of shape ``(2,)*n``. The gate (shape ``(2**k, 2**k)``) is reshaped to
        ``(2,)*2k`` and contracted against the k target wire axes.

        Memory footprint is O(2**n) and the operation supports arbitrary k.
        The implementation is fully differentiable through JAX.

        Args:
            state: Statevector of shape ``(2**n_qubits,)``.
            n_qubits: Total number of qubits in the circuit.

        Returns:
            Updated statevector of shape ``(2**n_qubits,)``.
        """
        k = len(self.wires)
        gate_tensor = self.matrix.reshape((2,) * 2 * k)
        psi = state.reshape((2,) * n_qubits)
        psi_out = _contract_and_restore(psi, gate_tensor, k, self.wires)
        return psi_out.reshape(2**n_qubits)

    def apply_to_state_tensor(self, psi: jnp.ndarray, n_qubits: int) -> jnp.ndarray:
        """Apply this gate to a statevector already in tensor form.

        Like :meth:`apply_to_state` but expects the state in rank-n tensor
        form ``(2,)*n`` and returns the result in the same form. This avoids
        the ``reshape`` calls at the per-gate level when the simulation loop
        keeps the state in tensor form throughout.

        Args:
            psi: Statevector tensor of shape ``(2,)*n_qubits``.
            n_qubits: Total number of qubits in the circuit.

        Returns:
            Updated statevector tensor of shape ``(2,)*n_qubits``.
        """
        k = len(self.wires)
        gate_tensor = self._gate_tensor(k)
        return _contract_and_restore(psi, gate_tensor, k, self.wires)

    def _gate_tensor(self, k: int) -> jnp.ndarray:
        """Return the gate matrix reshaped to ``(2,)*2k`` tensor form.

        The result is cached on the instance so repeated calls (e.g. from
        density-matrix simulation which applies U and U*) avoid redundant
        reshape dispatch.

        Args:
            k: Number of qubits the gate acts on.

        Returns:
            Gate matrix as a rank-2k tensor of shape ``(2,)*2k``.
        """
        cached = getattr(self, "_cached_gate_tensor", None)
        if cached is not None:
            return cached
        gt = self.matrix.reshape((2,) * 2 * k)
        # Only cache for non-parametrized gates (whose matrix is a class attr)
        if self._matrix is self.__class__._matrix:
            object.__setattr__(self, "_cached_gate_tensor", gt)
        return gt

    def apply_to_density(self, rho: jnp.ndarray, n_qubits: int) -> jnp.ndarray:
        """Apply this gate to a density matrix via \\rho -> U\\rho U\\dagger.

        The density matrix (shape ``(2**n, 2**n)``) is treated as a rank-*2n*
        tensor with n "ket" axes (0..n-1) and n "bra" axes (n..2n-1).
        U acts on the ket half; U* acts on the bra half. Both contractions
        use the shared :func:`_contract_and_restore` helper, keeping the
        operation allocation-free with respect to building full unitaries.

        Args:
            rho: Density matrix of shape ``(2**n_qubits, 2**n_qubits)``.
            n_qubits: Total number of qubits in the circuit.

        Returns:
            Updated density matrix of shape ``(2**n_qubits, 2**n_qubits)``.
        """
        k = len(self.wires)
        U = self._gate_tensor(k)
        U_conj = jnp.conj(U)

        rho_t = rho.reshape((2,) * 2 * n_qubits)

        # Apply U to ket axes, U* to bra axes (bra axes are offset by n_qubits)
        rho_t = _contract_and_restore(rho_t, U, k, self.wires)
        bra_wires = [w + n_qubits for w in self.wires]
        rho_t = _contract_and_restore(rho_t, U_conj, k, bra_wires)

        return rho_t.reshape(2**n_qubits, 2**n_qubits)

513 

514 

class Hermitian(Operation):
    """Generic observable or gate given by a user-supplied Hermitian matrix.

    Example:
        >>> obs = Hermitian(matrix=my_matrix, wires=0)
    """

    def __init__(
        self,
        matrix: jnp.ndarray,
        wires: Union[int, List[int]] = 0,
        record: bool = True,
    ) -> None:
        """Create the operator from an explicit matrix.

        Args:
            matrix: The Hermitian matrix defining this operator. It is
                converted to a JAX array of the active complex dtype.
            wires: Qubit index or list of qubit indices this operator acts on.
            record: If ``True`` (default), record on the active tape. Set to
                ``False`` when using the Hermitian purely as a Hamiltonian
                component (e.g. for time-dependent evolution).
        """
        complex_matrix = jnp.asarray(matrix, dtype=_cdtype())
        super().__init__(wires=wires, matrix=complex_matrix, record=record)

    def __rmul__(self, coeff_fn: Callable) -> "ParametrizedHamiltonian":
        """Support ``coeff_fn * Hermitian`` -> :class:`ParametrizedHamiltonian`.

        Args:
            coeff_fn (Callable): A callable ``(params, t) -> scalar`` giving the
                time-dependent coefficient.

        Returns:
            ParametrizedHamiltonian: A one-term :class:`ParametrizedHamiltonian`
                pairing *coeff_fn* with this operator's matrix and wires.

        Raises:
            TypeError: If *coeff_fn* is not callable.
        """
        if callable(coeff_fn):
            single_term = (coeff_fn, self.matrix, self.wires)
            return ParametrizedHamiltonian(terms=[single_term])
        raise TypeError(
            f"Left operand of `* Hermitian` must be callable, got {type(coeff_fn)}"
        )

562 

563 

class ParametrizedHamiltonian:
    """A time-dependent Hamiltonian as a sum of ``coeff * Hermitian`` terms.

    Mathematically::

        H(t) = \\sum_i f_i(params_i, t) * H_i

    Construction is always done from an explicit list of
    ``(coeff_fn, H_mat, wires)`` triples passed as ``terms``. The
    common single-term shorthand is the operator form
    ``coeff_fn * Hermitian(matrix, wires)`` (see
    :meth:`Hermitian.__rmul__`), which returns a one-term instance.
    Multi-term Hamiltonians are composed with ``+`` between
    :class:`ParametrizedHamiltonian` instances::

        H1 = coeff_x * Hermitian(X, wires=0)
        H2 = coeff_y * Hermitian(Y, wires=0)
        H_td = H1 + H2

        # evolve under the composite Hamiltonian; coeff_args is a list of
        # parameter sets, one per term, in the order the terms were added:
        evolve(H_td)([px, py], T=1.0)

    Attributes:
        coeff_fns: Tuple of callables ``(params, t) -> scalar``, one per term.
        H_mats: Tuple of static Hermitian matrices, one per term.
        wires: Wires this Hamiltonian acts on (union across all terms; for
            now all terms are required to share the same wire set).
    """

    def __init__(
        self,
        terms: List[Tuple[Callable, jnp.ndarray, Union[int, List[int]]]],
    ) -> None:
        """Build a (possibly multi-term) parametrized Hamiltonian.

        Args:
            terms: List of ``(coeff_fn, H_mat, wires)`` triples. Use the
                ``coeff_fn * Hermitian(...)`` shorthand to build a
                one-term instance; combine instances with ``+`` to add
                terms. ``wires`` may be a single integer (builtin or
                NumPy) or an iterable of wire indices.

        Raises:
            ValueError: If the term list is empty, or if terms act on
                differing wire sets (multi-wire broadcasting is
                deferred — see :mod:`yaqsi`), or if term matrices have
                incompatible shapes.
        """
        if len(terms) == 0:
            raise ValueError("ParametrizedHamiltonian needs at least one term.")

        # Normalise wires (single int -> [int]) and validate consistency.
        def _wlist(w):
            # NumPy integer scalars are not ``int`` instances and are not
            # iterable, so they must be recognised explicitly.
            return [w] if isinstance(w, (int, np.integer)) else list(w)

        wire_lists = [_wlist(w) for _, _, w in terms]
        first_wires = wire_lists[0]
        for term_wires in wire_lists[1:]:
            # Order-sensitive comparison: every term must list the wires
            # in the same order as the first term.
            if term_wires != first_wires:
                raise ValueError(
                    "All terms of a ParametrizedHamiltonian must currently "
                    "act on the same wires; got "
                    f"{term_wires} vs. {first_wires}. "
                    "Multi-wire broadcasting across terms is not yet supported."
                )

        # Convert every matrix exactly once, then validate shape
        # compatibility across terms on the converted arrays.
        dtype = _cdtype()
        mats = [jnp.asarray(H, dtype=dtype) for _, H, _ in terms]
        first_dim = mats[0].shape
        for mat in mats[1:]:
            if mat.shape != first_dim:
                raise ValueError(
                    f"All term matrices must have the same shape; got "
                    f"{mat.shape} vs. {first_dim}."
                )

        self._terms: Tuple[Tuple[Callable, jnp.ndarray, List[int]], ...] = tuple(
            (term[0], mat, term_wires)
            for term, mat, term_wires in zip(terms, mats, wire_lists)
        )
        self.wires: List[int] = list(first_wires)

    # --- term accessors -------------------------------------------------

    @property
    def coeff_fns(self) -> Tuple[Callable, ...]:
        """Tuple of coefficient functions, one per term."""
        return tuple(fn for fn, _, _ in self._terms)

    @property
    def H_mats(self) -> Tuple[jnp.ndarray, ...]:
        """Tuple of Hermitian matrices, one per term."""
        return tuple(H for _, H, _ in self._terms)

    @property
    def n_terms(self) -> int:
        """Number of terms in the Hamiltonian."""
        return len(self._terms)

    # --- composition ---------------------------------------------------

    def __add__(self, other: "ParametrizedHamiltonian") -> "ParametrizedHamiltonian":
        """Concatenate term lists: ``H = H1 + H2``."""
        if not isinstance(other, ParametrizedHamiltonian):
            return NotImplemented
        return ParametrizedHamiltonian(terms=list(self._terms) + list(other._terms))

    def __neg__(self) -> "ParametrizedHamiltonian":
        """Negate every coefficient: ``-H`` = sum of ``(-f_i) * H_i``."""

        def _negated(f: Callable) -> Callable:
            # A factory binds ``f`` per term and avoids the late-binding
            # closure pitfall of defining the lambda directly in the loop.
            return lambda p, t: -f(p, t)

        new_terms = [(_negated(fn), H, w) for fn, H, w in self._terms]
        return ParametrizedHamiltonian(terms=new_terms)

    def __sub__(self, other: "ParametrizedHamiltonian") -> "ParametrizedHamiltonian":
        """Subtract another Hamiltonian: ``H1 - H2`` is ``H1 + (-H2)``."""
        if not isinstance(other, ParametrizedHamiltonian):
            return NotImplemented
        return self + (-other)

679 

680 

class Id(Operation):
    """Identity gate on one or more wires.

    The class-level matrix is the single-qubit identity. For *k* > 1 wires
    the instance receives a ``2**k x 2**k`` identity instead.
    """

    _num_wires = None  # accept any number of wires
    _matrix = jnp.eye(2, dtype=_cdtype())

    def __init__(self, wires: Union[int, List[int]] = 0, **kwargs) -> None:
        """Create an identity gate.

        Args:
            wires: Qubit index or list of qubit indices this gate acts on.
                When multiple wires are given the matrix is automatically
                expanded to the matching ``2**k × 2**k`` identity.
            **kwargs: Forwarded to :class:`Operation`. A ``matrix`` entry is
                overwritten when more than one wire is given.
        """
        n_wires = len(list(wires)) if isinstance(wires, (list, tuple)) else 1
        if n_wires > 1:
            kwargs["matrix"] = jnp.eye(2**n_wires, dtype=_cdtype())
        super().__init__(wires=wires, **kwargs)

705 

706 

class PauliX(Operation):
    """Pauli-X operator (\\sigma_x): the bit-flip gate, also usable as observable."""

    _num_wires = 1
    _matrix = jnp.array(
        [
            [0, 1],
            [1, 0],
        ],
        dtype=_cdtype(),
    )

    def __init__(self, wires: Union[int, List[int]] = 0, **kwargs) -> None:
        """Create a Pauli-X gate.

        Args:
            wires: Qubit index or list of qubit indices this gate acts on.
            **kwargs: Passed through unchanged to the base-class initialiser.
        """
        super().__init__(wires=wires, **kwargs)

720 

721 

class PauliY(Operation):
    """Pauli-Y operator (\\sigma_y), usable as a gate or as an observable."""

    _num_wires = 1
    _matrix = jnp.array(
        [
            [0, -1j],
            [1j, 0],
        ],
        dtype=_cdtype(),
    )

    def __init__(self, wires: Union[int, List[int]] = 0, **kwargs) -> None:
        """Create a Pauli-Y gate.

        Args:
            wires: Qubit index or list of qubit indices this gate acts on.
            **kwargs: Passed through unchanged to the base-class initialiser.
        """
        super().__init__(wires=wires, **kwargs)

735 

736 

class PauliZ(Operation):
    """Pauli-Z operator (\\sigma_z): the phase-flip gate, also usable as observable."""

    _num_wires = 1
    _matrix = jnp.array(
        [
            [1, 0],
            [0, -1],
        ],
        dtype=_cdtype(),
    )

    def __init__(self, wires: Union[int, List[int]] = 0, **kwargs) -> None:
        """Create a Pauli-Z gate.

        Args:
            wires: Qubit index or list of qubit indices this gate acts on.
            **kwargs: Passed through unchanged to the base-class initialiser.
        """
        super().__init__(wires=wires, **kwargs)

750 

751 

class H(Operation):
    """Hadamard gate: ``[[1, 1], [1, -1]] / sqrt(2)``."""

    _num_wires = 1
    _matrix = jnp.array([[1, 1], [1, -1]], dtype=_cdtype()) / jnp.sqrt(2)

    def __init__(self, wires: Union[int, List[int]] = 0, **kwargs) -> None:
        """Create a Hadamard gate.

        Args:
            wires: Qubit index or list of qubit indices this gate acts on.
            **kwargs: Passed through unchanged to the base-class initialiser.
        """
        super().__init__(wires=wires, **kwargs)

765 

766 

class S(Operation):
    """S (phase) gate — a Clifford gate equal to \\sqrt Z.

    .. math::
        S = \\begin{pmatrix}1 & 0\\\\ 0 & i\\end{pmatrix}
    """

    _matrix = jnp.array([[1, 0], [0, 1j]], dtype=_cdtype())
    _num_wires = 1

    def __init__(self, wires: Union[int, List[int]] = 0, **kwargs) -> None:
        """Initialise an S gate.

        Args:
            wires: Qubit index or list of qubit indices this gate acts on.
            **kwargs: Forwarded to :class:`Operation` (e.g. ``record``,
                ``name``, ``input_idx``), matching the other fixed gates
                in this module.
        """
        super().__init__(wires=wires, **kwargs)

784 

785 

class SWAP(Operation):
    """Two-qubit SWAP gate (exchanges the states of its two wires)."""

    _num_wires = 2
    _matrix = jnp.array(
        [
            [1, 0, 0, 0],
            [0, 0, 1, 0],
            [0, 1, 0, 0],
            [0, 0, 0, 1],
        ],
        dtype=_cdtype(),
    )

    def __init__(self, wires: Union[int, List[int]] = 0, **kwargs) -> None:
        """Create a SWAP gate.

        Args:
            wires: Qubit index or list of qubit indices this gate acts on.
                The base class rejects anything but exactly two wires.
            **kwargs: Passed through unchanged to the base-class initialiser.
        """
        super().__init__(wires=wires, **kwargs)

801 

802 

class RandomUnitary(Operation):
    """Creates a random hermitian matrix and applies it as a gate.

    The matrix is ``(A + A^\\dagger) / 2`` for a complex Gaussian matrix
    ``A``, rescaled so that its Frobenius norm equals ``scale``.

    NOTE(review): despite the class name, the generated matrix is Hermitian
    and is not unitary in general — confirm this is what callers expect.
    """

    def __init__(
        self,
        wires: Union[int, List[int]],
        key: jax.random.PRNGKey,
        scale: float = 1.0,
        record: bool = True,
    ) -> None:
        """Initialise a random unitary gate.

        Args:
            wires (Union[int, List[int]]): Qubit index or list of qubit indices
                this gate acts on.
            key (jax.random.PRNGKey): PRNGKey for randomization.
            scale (float): Scale of the random unitary (default: 1.0).
            record (bool): Whether to record this gate on the active tape.
        """
        # A bare int wire has no ``len``; normalise it the same way
        # ``Operation`` does before deriving the matrix dimension.
        wire_list = list(wires) if isinstance(wires, (list, tuple)) else [wires]
        dim = 2 ** len(wire_list)

        # Independent keys for the real and imaginary parts
        key_real, key_imag = jax.random.split(key)

        A = (
            jax.random.normal(key=key_real, shape=(dim, dim))
            + 1j * jax.random.normal(key=key_imag, shape=(dim, dim))
        ).astype(_cdtype())
        # Hermitian part of A (named ``herm`` to avoid shadowing the
        # module-level Hadamard class ``H``).
        herm = (A + A.conj().T) / 2.0

        herm = herm * (scale / jnp.linalg.norm(herm, ord="fro"))

        super().__init__(wires, matrix=herm, record=record)

834 

835 

class DiagonalQubitUnitary(Operation):
    """A diagonal unitary gate specified by its diagonal entries.

    Implements ``U = diag(d_0, d_1, ..., d_{2^k-1})`` where each ``d_i`` lies
    on the unit circle. This is the natural gate for data-encoding
    Hamiltonians of the form ``S(x) = exp(-i H x)`` where *H* is diagonal in
    the computational basis (see Peters et al., arXiv:2209.05523).

    The Golomb encoding strategy uses this gate with diagonal entries
    ``exp(-i * golomb_marks * x)`` to achieve a maximally non-degenerate
    Fourier spectrum.

    Args:
        diag: 1-D array of ``2**k`` complex values on the unit circle.
        wires: Qubit indices this gate acts on (s.t. ``2**len(wires) == len(diag)``).
        **kwargs: Forwarded to :class:`Operation`.
    """

    # Do NOT list "diag" in _param_names — the array is not a scalar
    # parameter and would break drawing helpers that call float(p).
    _param_names = ()

    def __init__(
        self,
        diag: jnp.ndarray,
        wires: Union[int, List[int]] = 0,
        **kwargs,
    ) -> None:
        """Validate the diagonal against the wire count and build the gate.

        Raises:
            ValueError: If ``diag`` does not have shape ``(2**k,)`` for the
                ``k`` wires given.
        """
        self.diag = diag
        n_wires = len(list(wires)) if isinstance(wires, (list, tuple)) else 1
        expected_dim = 2**n_wires
        if diag.shape != (expected_dim,):
            raise ValueError(
                f"DiagonalQubitUnitary expects {expected_dim} diagonal entries "
                f"for {n_wires} wire(s), got shape {diag.shape}"
            )
        # Use a descriptive name for drawing unless the caller chose one
        kwargs.setdefault("name", "DiagU")
        super().__init__(wires=wires, matrix=jnp.diag(diag), **kwargs)

    def _spans_register_in_order(self, n_qubits: int) -> bool:
        """Tell whether the gate acts on wires ``0..n_qubits-1`` in that order."""
        return len(self.wires) == n_qubits and self.wires == list(range(n_qubits))

    def apply_to_state(self, state: jnp.ndarray, n_qubits: int) -> jnp.ndarray:
        """Apply diagonal gate via element-wise multiplication.

        When the gate covers all qubits in order, the state is multiplied
        element-wise by the diagonal. For any other wire selection the
        generic tensor contraction of the base class is used.

        Args:
            state: Statevector of shape ``(2**n_qubits,)``.
            n_qubits: Total number of qubits in the circuit.

        Returns:
            Updated statevector of shape ``(2**n_qubits,)``.
        """
        if not self._spans_register_in_order(n_qubits):
            # Arbitrary wire subsets: fall back to general tensor contraction
            return super().apply_to_state(state, n_qubits)
        # Gate acts on all qubits in order — direct element-wise multiply
        return state * self.diag

    def apply_to_density(self, rho: jnp.ndarray, n_qubits: int) -> jnp.ndarray:
        """Apply diagonal gate to density matrix: rho -> U rho U†.

        For diagonal U the transformation is
        ``rho_ij -> d_i * conj(d_j) * rho_ij``.

        Args:
            rho: Density matrix of shape ``(2**n_qubits, 2**n_qubits)``.
            n_qubits: Total number of qubits in the circuit.

        Returns:
            Updated density matrix of shape ``(2**n_qubits, 2**n_qubits)``.
        """
        if not self._spans_register_in_order(n_qubits):
            return super().apply_to_density(rho, n_qubits)
        entries = self.diag
        return entries[:, None] * jnp.conj(entries)[None, :] * rho

917 

918 

class Barrier(Operation):
    """Visual separator that leaves the quantum state untouched.

    A barrier is recorded like any other operation so that drawing
    backends can render a separator, but all three ``apply_*`` hooks
    hand their input back unchanged.
    """

    # A barrier has no matrix representation.
    _matrix = None

    def __init__(self, wires: Union[int, List[int]] = 0) -> None:
        """Create a barrier.

        Args:
            wires: Qubit index or list of qubit indices the barrier spans.
        """
        super().__init__(wires=wires)

    def apply_to_state(self, state: jnp.ndarray, n_qubits: int) -> jnp.ndarray:
        """Return ``state`` unchanged."""
        return state

    def apply_to_state_tensor(self, psi: jnp.ndarray, n_qubits: int) -> jnp.ndarray:
        """Return the state tensor ``psi`` unchanged."""
        return psi

    def apply_to_density(self, rho: jnp.ndarray, n_qubits: int) -> jnp.ndarray:
        """Return the density matrix ``rho`` unchanged."""
        return rho

947 

948 

def _make_rotation_gate(pauli_class: type, name: str) -> type:
    """Build the single-qubit rotation gate class for one Pauli axis.

    The generated gate has the matrix
    ``R_P(theta) = cos(theta/2) I - i sin(theta/2) P``.

    Args:
        pauli_class: One of PauliX, PauliY, PauliZ.
        name: Class name for the generated gate (e.g. ``"RX"``); its
            second character is used as the axis label in the docstring.

    Returns:
        A new :class:`Operation` subclass named *name*.
    """
    axis_matrix = pauli_class._matrix

    class _RotationGate(Operation):
        # Docstring is assembled at class-creation time so that every
        # generated gate documents its own axis.
        __doc__ = (
            f"Rotation around the {name[1]} axis: {name}(\\theta) =\n"
            f"exp(-i \\theta/2 {name[1]}).\n"
        )
        _num_wires = 1
        _param_names = ("theta",)

        def __init__(
            self, theta: float, wires: Union[int, List[int]] = 0, **kwargs
        ) -> None:
            self.theta = theta
            half_angle = theta / 2
            rotation = (
                jnp.cos(half_angle) * Id._matrix
                - 1j * jnp.sin(half_angle) * axis_matrix
            )
            super().__init__(wires=wires, matrix=rotation, **kwargs)

        def generator(self) -> Operation:
            """Return the matching Pauli operation on this gate's wire."""
            return pauli_class(wires=self.wires[0], record=False)

    # Give the generated class its public name.
    _RotationGate.__qualname__ = name
    _RotationGate.__name__ = name
    return _RotationGate

988 

989 

# Concrete single-qubit rotation gates, one generated class per Pauli axis.
RX = _make_rotation_gate(PauliX, "RX")
RY = _make_rotation_gate(PauliY, "RY")
RZ = _make_rotation_gate(PauliZ, "RZ")

993 

994 

# Computational-basis projectors |0><0| and |1><1|, shared by the
# controlled-gate factories and classes below.
_P0 = jnp.diag(jnp.array([1, 0], dtype=_cdtype()))
_P1 = jnp.diag(jnp.array([0, 1], dtype=_cdtype()))

998 

999 

def _make_controlled_gate(target_class: type, name: str) -> type:
    """Build a two-qubit controlled gate class from a single-qubit gate.

    The generated gate has the fixed matrix
    ``CP = |0><0| (x) I + |1><1| (x) P`` where ``P`` is the matrix of
    *target_class*.

    Args:
        target_class: The single-qubit gate class (PauliX, PauliY, PauliZ).
        name: Class name for the generated gate (e.g. ``"CX"``).

    Returns:
        A new :class:`Operation` subclass named *name*.
    """
    applied_matrix = target_class._matrix

    class _ControlledGate(Operation):
        __doc__ = (
            f"Controlled-{target_class.__name__[5:]} gate.\n\n"
            f"Applies {target_class.__name__} on the target qubit conditioned "
            f"on the control qubit being in state |1\\rangle."
        )
        # Identity on the control-|0> block, target gate on the control-|1> block.
        _matrix = jnp.kron(_P0, Id._matrix) + jnp.kron(_P1, applied_matrix)
        _num_wires = 2
        is_controlled = True

        def __init__(self, wires: List[int] = [0, 1], **kwargs) -> None:
            super().__init__(wires=wires, **kwargs)

    # Give the generated class its public name.
    _ControlledGate.__qualname__ = name
    _ControlledGate.__name__ = name
    return _ControlledGate

1031 

1032 

# Concrete controlled-Pauli gates, one generated class per target Pauli.
CX = _make_controlled_gate(PauliX, "CX")
CY = _make_controlled_gate(PauliY, "CY")
CZ = _make_controlled_gate(PauliZ, "CZ")

1036 

1037 

class CCX(Operation):
    """Toffoli (CCX) gate on three qubits.

    The matrix is the 8x8 identity with the last two basis states
    (|110> and |111>) exchanged, i.e. the target is flipped only when
    both controls are |1>.
    """

    # Identity rows reordered so that rows 6 and 7 trade places.
    _matrix = jnp.eye(8, dtype=_cdtype())[jnp.array([0, 1, 2, 3, 4, 5, 7, 6])]
    is_controlled = True
    _num_wires = 3

    def __init__(self, wires: List[int] = [0, 1, 2], **kwargs) -> None:
        """Create a Toffoli (CCX) gate.

        Args:
            wires: Three-element list ``[control0, control1, target]``.
            **kwargs: Forwarded unchanged to :class:`Operation`.
        """
        super().__init__(wires=wires, **kwargs)

1070 

1071 

class CSWAP(Operation):
    """Controlled-SWAP (Fredkin) gate on three qubits.

    The matrix is the 8x8 identity with basis states |101> and |110>
    exchanged, i.e. the two targets are swapped only when the control
    is |1>.

    Args on construction:
        wires: ``[control, target0, target1]``.
    """

    # Identity rows reordered so that rows 5 and 6 trade places.
    _matrix = jnp.eye(8, dtype=_cdtype())[jnp.array([0, 1, 2, 3, 4, 6, 5, 7])]
    is_controlled = True
    _num_wires = 3

    def __init__(self, wires: List[int] = [0, 1, 2], **kwargs) -> None:
        """Create a Controlled-SWAP (Fredkin) gate.

        Args:
            wires: Three-element list ``[control, target0, target1]``.
            **kwargs: Forwarded unchanged to :class:`Operation`.
        """
        super().__init__(wires=wires, **kwargs)

1104 

1105 

def _make_controlled_rotation_gate(pauli_class: type, name: str) -> type:
    """Build a two-qubit controlled rotation gate class for one Pauli axis.

    The generated gate has the matrix
    ``CR_P(theta) = |0><0| (x) I + |1><1| (x) R_P(theta)`` with
    ``R_P(theta) = cos(theta/2) I - i sin(theta/2) P``.

    Args:
        pauli_class: One of PauliX, PauliY, PauliZ.
        name: Class name for the generated gate (e.g. ``"CRX"``); its
            third character is used as the axis label in the docstring.

    Returns:
        A new :class:`Operation` subclass named *name*.
    """
    axis_matrix = pauli_class._matrix

    class _CRotationGate(Operation):
        __doc__ = (
            f"Controlled rotation around the {name[2]} axis.\n\n"
            f"Applies R{name[2]}(\\theta) on the target qubit conditioned on the "
            f"control qubit being in state |1\\rangle.\n\n"
            f".. math::\n"
            f"{name}(\\theta) = |0\\rangle\\langle 0| \\otimes I\n"
            f" + |1\\rangle\\langle 1| \\otimes R{name[2]}(\\theta)"
        )
        _num_wires = 2
        _param_names = ("theta",)
        is_controlled = True

        def __init__(self, theta: float, wires: List[int] = [0, 1], **kwargs) -> None:
            self.theta = theta
            half_angle = theta / 2
            # Single-qubit rotation applied when the control is |1>.
            target_rotation = (
                jnp.cos(half_angle) * Id._matrix
                - 1j * jnp.sin(half_angle) * axis_matrix
            )
            controlled = jnp.kron(_P0, Id._matrix) + jnp.kron(_P1, target_rotation)
            super().__init__(wires=wires, matrix=controlled, **kwargs)

    # Give the generated class its public name.
    _CRotationGate.__qualname__ = name
    _CRotationGate.__name__ = name
    return _CRotationGate

1145 

1146 

# Concrete controlled rotation gates, one generated class per Pauli axis.
CRX = _make_controlled_rotation_gate(PauliX, "CRX")
CRY = _make_controlled_rotation_gate(PauliY, "CRY")
CRZ = _make_controlled_rotation_gate(PauliZ, "CRZ")

1150 

1151 

class ControlledPhaseShift(Operation):
    r"""Controlled phase shift gate (CPhase).

    Multiplies the |11> component of a two-qubit state by ``exp(i * phi)``
    and leaves the other computational basis states unchanged.  For
    ``phi = pi`` the matrix equals that of CZ.

    .. math::
        \text{CPhase}(\phi) = \text{diag}(1, 1, 1, e^{i\phi})

    Equivalently ``|0><0| (x) I + |1><1| (x) P(phi)`` with
    ``P(phi) = diag(1, exp(i*phi))``.
    """

    _num_wires = 2
    _param_names = ("phi",)
    is_controlled = True

    def __init__(self, phi: float, wires: List[int] = [0, 1], **kwargs) -> None:
        """Create a controlled phase shift gate.

        Args:
            phi: Phase shift angle in radians.
            wires: Two-element list ``[control, target]``.
            **kwargs: Forwarded unchanged to :class:`Operation`.
        """
        self.phi = phi
        # P(phi) = diag(1, e^{i phi}) acts on the target when the control is |1>.
        target_phase = jnp.diag(jnp.array([1, jnp.exp(1j * phi)], dtype=_cdtype()))
        controlled = jnp.kron(_P0, Id._matrix) + jnp.kron(_P1, target_phase)
        super().__init__(wires=wires, matrix=controlled, **kwargs)

1183 

1184 

class Rot(Operation):
    """General single-qubit rotation with three angles.

    The matrix is the product ``RZ(omega) @ RY(theta) @ RZ(phi)``, i.e.
    ``RZ(phi)`` acts on the state first.
    """

    _num_wires = 1
    _param_names = ("phi", "theta", "omega")

    def __init__(
        self,
        phi: float,
        theta: float,
        omega: float,
        wires: Union[int, List[int]] = 0,
        **kwargs,
    ) -> None:
        """Create a general rotation gate.

        Args:
            phi: First RZ rotation angle (radians).
            theta: RY rotation angle (radians).
            omega: Second RZ rotation angle (radians).
            wires: Qubit index or list of qubit indices this gate acts on.
            **kwargs: Forwarded unchanged to :class:`Operation`.
        """
        self.phi = phi
        self.theta = theta
        self.omega = omega

        def _axis_rotation(angle, axis_matrix):
            # cos(angle/2) I - i sin(angle/2) P for the given Pauli matrix P.
            half = angle / 2
            return jnp.cos(half) * Id._matrix - 1j * jnp.sin(half) * axis_matrix

        first_z = _axis_rotation(phi, PauliZ._matrix)
        middle_y = _axis_rotation(theta, PauliY._matrix)
        last_z = _axis_rotation(omega, PauliZ._matrix)
        # Right-most factor is applied to the state first.
        mat = last_z @ middle_y @ first_z
        super().__init__(wires=wires, matrix=mat, **kwargs)

1225 

1226 

class PauliRot(Operation):
    """Multi-qubit Pauli rotation ``exp(-i theta/2 P)`` for a Pauli word P.

    The Pauli word is a string over ``'I'``, ``'X'``, ``'Y'``, ``'Z'``
    (one character per qubit).  The gate matrix is
    ``cos(theta/2) I - i sin(theta/2) P`` where *P* is the Kronecker
    product of the corresponding single-qubit matrices, taken in the
    order of the characters in the word.

    Example::

        PauliRot(0.5, "XY", wires=[0, 1])
    """

    _param_names = ("theta",)

    # Lookup from Pauli-word character to its 2x2 matrix.
    _PAULI_MAP = {
        "I": Id._matrix,
        "X": PauliX._matrix,
        "Y": PauliY._matrix,
        "Z": PauliZ._matrix,
    }

    def __init__(
        self, theta: float, pauli_word: str, wires: Union[int, List[int]] = 0, **kwargs
    ) -> None:
        """Create a PauliRot gate.

        Args:
            theta: Rotation angle in radians.
            pauli_word: A string of ``'I'``, ``'X'``, ``'Y'``, ``'Z'``
                characters specifying the Pauli tensor product.
            wires: Qubit index or list of qubit indices this gate acts on.
            **kwargs: Forwarded unchanged to :class:`Operation`.
        """
        self.theta = theta
        self.pauli_word = pauli_word

        word_matrix = self._pauli_word_matrix()
        identity = jnp.eye(word_matrix.shape[0], dtype=_cdtype())
        half_angle = theta / 2
        mat = jnp.cos(half_angle) * identity - 1j * jnp.sin(half_angle) * word_matrix
        super().__init__(wires=wires, matrix=mat, **kwargs)

    def _pauli_word_matrix(self) -> jnp.ndarray:
        """Return the Kronecker product of the matrices named by ``pauli_word``."""
        from functools import reduce as _reduce

        factors = [self._PAULI_MAP[symbol] for symbol in self.pauli_word]
        return _reduce(jnp.kron, factors)

    def generator(self) -> Operation:
        """Return the Pauli tensor product that generates this rotation.

        Returns:
            :class:`Hermitian` operation wrapping the Pauli-word matrix on
            this gate's wires (created with ``record=False``).
        """
        return Hermitian(
            matrix=self._pauli_word_matrix(), wires=self.wires, record=False
        )

1290 

1291 

class KrausChannel(Operation):
    """Base class for noise channels defined by a set of Kraus operators.

    A Kraus channel acts on a density matrix as
    ``Phi(rho) = sum_k K_k rho K_k^dagger``.  A pure unitary gate is the
    special case of a single operator ``K_0 = U`` with
    ``K_0^dagger K_0 = I``; noisy channels have several operators.

    Subclasses must implement :meth:`kraus_matrices` and return a list of
    JAX arrays.  :meth:`matrix`, :meth:`apply_to_state` and
    :meth:`apply_to_state_tensor` deliberately raise ``TypeError``: a
    channel has no single matrix and is only applied to density matrices.
    """

    def kraus_matrices(self) -> List[jnp.ndarray]:
        """Return the list of Kraus operators for this channel.

        Returns:
            List of 2-D JAX arrays, each of shape ``(2**k, 2**k)`` where k
            is the number of target qubits.

        Raises:
            NotImplementedError: Subclasses must override this method.
        """
        raise NotImplementedError

    @property
    def matrix(self) -> jnp.ndarray:
        """Always raise: a noise channel has no single unitary matrix.

        Raises:
            TypeError: Always raised; use :meth:`apply_to_density` instead.
        """
        raise TypeError(
            f"{self.__class__.__name__} is a noise channel and has no single "
            "unitary matrix. Use apply_to_density() instead."
        )

    def apply_to_state(self, state: jnp.ndarray, n_qubits: int) -> jnp.ndarray:
        """Always raise: noise channels require density-matrix simulation.

        Args:
            state: Statevector (unused).
            n_qubits: Number of qubits (unused).

        Raises:
            TypeError: Always raised; use ``execute(type='density')`` instead.
        """
        raise TypeError(
            f"{self.__class__.__name__} is a noise channel and cannot be "
            "applied to a pure statevector. Use execute(type='density') instead."
        )

    def apply_to_state_tensor(self, psi: jnp.ndarray, n_qubits: int) -> jnp.ndarray:
        """Always raise: noise channels require density-matrix simulation.

        Args:
            psi: State tensor (unused).
            n_qubits: Number of qubits (unused).

        Raises:
            TypeError: Always raised; use ``execute(type='density')`` instead.
        """
        raise TypeError(
            f"{self.__class__.__name__} is a noise channel and cannot be "
            "applied to a pure statevector. Use execute(type='density') instead."
        )

    def apply_to_density(self, rho: jnp.ndarray, n_qubits: int) -> jnp.ndarray:
        """Apply ``Phi(rho) = sum_k K_k rho K_k^dagger`` by tensor contraction.

        ``rho`` is viewed as a rank-``2 * n_qubits`` tensor.  For every
        Kraus operator, ``K`` is contracted via
        :func:`_contract_and_restore` on this channel's wires and
        ``conj(K)`` on the same wires offset by ``n_qubits``; the
        per-operator results are summed.

        Args:
            rho: Density matrix of shape ``(2**n_qubits, 2**n_qubits)``.
            n_qubits: Total number of qubits in the circuit.

        Returns:
            Updated density matrix of shape ``(2**n_qubits, 2**n_qubits)``.
        """
        k = len(self.wires)
        dim = 2**n_qubits
        bra_wires = [w + n_qubits for w in self.wires]

        # Loop-invariant: the tensor view of rho and the gate tensor shape
        # are the same for every Kraus operator, so build them once.
        rho_tensor = rho.reshape((2,) * (2 * n_qubits))
        kraus_shape = (2,) * (2 * k)

        rho_out = jnp.zeros_like(rho)
        for K in self.kraus_matrices():
            K_t = K.reshape(kraus_shape)
            # K on the first wire set, conj(K) on the offset wire set.
            term = _contract_and_restore(rho_tensor, K_t, k, self.wires)
            term = _contract_and_restore(term, jnp.conj(K_t), k, bra_wires)
            rho_out = rho_out + term.reshape(dim, dim)

        return rho_out

1381 

1382 

class BitFlip(KrausChannel):
    r"""Single-qubit bit-flip (Pauli-X) error channel.

    .. math::
        K_0 = \sqrt{1-p}\,I, \quad K_1 = \sqrt{p}\,X

    where *p* in [0, 1] is the probability of a bit flip.
    """

    _num_wires = 1
    _param_names = ("p",)

    def __init__(self, p: float, wires: Union[int, List[int]] = 0) -> None:
        """Create a bit-flip channel.

        Args:
            p: Bit-flip probability, must be in [0, 1].
            wires: Qubit index or list of qubit indices this channel acts on.

        Raises:
            ValueError: If *p* is outside [0, 1].
        """
        if not (0.0 <= p <= 1.0):
            raise ValueError("p must be in [0, 1].")
        self.p = p
        super().__init__(wires=wires)

    def kraus_matrices(self) -> List[jnp.ndarray]:
        """Return the two Kraus operators of the bit-flip channel.

        Returns:
            List ``[K0, K1]`` with K0 = sqrt(1-p)·I and K1 = sqrt(p)·X.
        """
        flip_prob = self.p
        return [
            jnp.sqrt(1 - flip_prob) * Id._matrix,
            jnp.sqrt(flip_prob) * PauliX._matrix,
        ]

1420 

1421 

class PhaseFlip(KrausChannel):
    r"""Single-qubit phase-flip (Pauli-Z) error channel.

    .. math::
        K_0 = \sqrt{1-p}\,I, \quad K_1 = \sqrt{p}\,Z

    where *p* in [0, 1] is the probability of a phase flip.
    """

    _num_wires = 1
    _param_names = ("p",)

    def __init__(self, p: float, wires: Union[int, List[int]] = 0) -> None:
        """Create a phase-flip channel.

        Args:
            p: Phase-flip probability, must be in [0, 1].
            wires: Qubit index or list of qubit indices this channel acts on.

        Raises:
            ValueError: If *p* is outside [0, 1].
        """
        if not (0.0 <= p <= 1.0):
            raise ValueError("p must be in [0, 1].")
        self.p = p
        super().__init__(wires=wires)

    def kraus_matrices(self) -> List[jnp.ndarray]:
        """Return the two Kraus operators of the phase-flip channel.

        Returns:
            List ``[K0, K1]`` with K0 = sqrt(1-p)·I and K1 = sqrt(p)·Z.
        """
        flip_prob = self.p
        return [
            jnp.sqrt(1 - flip_prob) * Id._matrix,
            jnp.sqrt(flip_prob) * PauliZ._matrix,
        ]

1459 

1460 

class DepolarizingChannel(KrausChannel):
    r"""Single-qubit depolarizing channel.

    .. math::
        K_0 = \sqrt{1-p}\,I,\quad K_1 = \sqrt{p/3}\,X,\quad
        K_2 = \sqrt{p/3}\,Y,\quad K_3 = \sqrt{p/3}\,Z

    where *p* is in [0, 1].  At p = 3/4 all four operators carry equal
    weight and the channel is fully depolarizing.
    """

    _num_wires = 1
    _param_names = ("p",)

    def __init__(self, p: float, wires: Union[int, List[int]] = 0) -> None:
        """Create a depolarizing channel.

        Args:
            p: Depolarization probability, must be in [0, 1].
            wires: Qubit index or list of qubit indices this channel acts on.

        Raises:
            ValueError: If *p* is outside [0, 1].
        """
        if not (0.0 <= p <= 1.0):
            raise ValueError("p must be in [0, 1].")
        self.p = p
        super().__init__(wires=wires)

    def kraus_matrices(self) -> List[jnp.ndarray]:
        """Return the four Kraus operators of the depolarizing channel.

        Returns:
            List ``[K0, K1, K2, K3]`` corresponding to I, X, Y, Z components.
        """
        prob = self.p
        identity_term = jnp.sqrt(1 - prob) * Id._matrix
        # X, Y and Z share the weight sqrt(p/3).
        error_terms = [
            jnp.sqrt(prob / 3) * pauli
            for pauli in (PauliX._matrix, PauliY._matrix, PauliZ._matrix)
        ]
        return [identity_term, *error_terms]

1501 

1502 

class AmplitudeDamping(KrausChannel):
    r"""Single-qubit amplitude damping channel.

    .. math::
        K_0 = \begin{pmatrix}1 & 0\\ 0 & \sqrt{1-\gamma}\end{pmatrix},\quad
        K_1 = \begin{pmatrix}0 & \sqrt{\gamma}\\ 0 & 0\end{pmatrix}

    where *gamma* in [0, 1] is the probability of energy loss
    (|1> -> |0>).
    """

    _num_wires = 1
    _param_names = ("gamma",)

    def __init__(self, gamma: float, wires: Union[int, List[int]] = 0) -> None:
        """Create an amplitude damping channel.

        Args:
            gamma: Energy-loss probability, must be in [0, 1].
            wires: Qubit index or list of qubit indices this channel acts on.

        Raises:
            ValueError: If *gamma* is outside [0, 1].
        """
        if not (0.0 <= gamma <= 1.0):
            raise ValueError("gamma must be in [0, 1].")
        self.gamma = gamma
        super().__init__(wires=wires)

    def kraus_matrices(self) -> List[jnp.ndarray]:
        """Return the two Kraus operators of the amplitude damping channel.

        Returns:
            List ``[K0, K1]`` as defined in the class docstring.
        """
        gamma = self.gamma
        survive_amp = jnp.sqrt(1 - gamma)
        decay_amp = jnp.sqrt(gamma)
        no_decay = jnp.diag(jnp.array([1.0, survive_amp], dtype=_cdtype()))
        # Off-diagonal entry maps |1> to |0>.
        decay = jnp.array([[0.0, decay_amp], [0.0, 0.0]], dtype=_cdtype())
        return [no_decay, decay]

1542 

1543 

class PhaseDamping(KrausChannel):
    r"""Single-qubit phase damping (dephasing) channel.

    .. math::
        K_0 = \begin{pmatrix}1 & 0\\ 0 & \sqrt{1-\gamma}\end{pmatrix},\quad
        K_1 = \begin{pmatrix}0 & 0\\ 0 & \sqrt{\gamma}\end{pmatrix}

    where *gamma* in [0, 1] is the phase damping probability.
    """

    _num_wires = 1
    _param_names = ("gamma",)

    def __init__(self, gamma: float, wires: Union[int, List[int]] = 0) -> None:
        """Create a phase damping channel.

        Args:
            gamma: Phase-damping probability, must be in [0, 1].
            wires: Qubit index or list of qubit indices this channel acts on.

        Raises:
            ValueError: If *gamma* is outside [0, 1].
        """
        if not (0.0 <= gamma <= 1.0):
            raise ValueError("gamma must be in [0, 1].")
        self.gamma = gamma
        super().__init__(wires=wires)

    def kraus_matrices(self) -> List[jnp.ndarray]:
        """Return the two Kraus operators of the phase damping channel.

        Returns:
            List ``[K0, K1]`` as defined in the class docstring.
        """
        gamma = self.gamma
        # Both operators are diagonal, so build them from their diagonals.
        undamped = jnp.diag(jnp.array([1.0, jnp.sqrt(1 - gamma)], dtype=_cdtype()))
        damped = jnp.diag(jnp.array([0.0, jnp.sqrt(gamma)], dtype=_cdtype()))
        return [undamped, damped]

1582 

1583 

class ThermalRelaxationError(KrausChannel):
    r"""Single-qubit thermal relaxation error channel.

    Models simultaneous T_1 energy relaxation and T_2 dephasing.  Two
    regimes are handled:

    T_2 <= T_1:
        Six Kraus operators built from p_z (phase-flip probability), p_r0
        (reset-to-|0> probability) and p_r1 (reset-to-|1> probability).

    T_2 > T_1:
        A 4x4 Choi matrix is assembled from the relaxation/dephasing
        factors and diagonalised; the Kraus operators are
        K_i = sqrt(|lambda_i|) · mat(v_i).

    Attributes:
        pe: Excited-state population (thermal population of |1>).
        t1: T_1 longitudinal relaxation time.
        t2: T_2 transverse dephasing time.
        tg: Gate duration.
    """

    _num_wires = 1
    _param_names = ("pe", "t1", "t2", "tg")

    def __init__(
        self,
        pe: float,
        t1: float,
        t2: float,
        tg: float,
        wires: Union[int, List[int]] = 0,
    ) -> None:
        """Create a thermal relaxation error channel.

        Args:
            pe: Excited-state population (thermal population of |1>), in [0, 1].
            t1: T_1 longitudinal relaxation time, must be > 0.
            t2: T_2 transverse dephasing time, must be > 0 and <= 2·T_1.
            tg: Gate duration, must be >= 0.
            wires: Qubit index or list of qubit indices this channel acts on.

        Raises:
            ValueError: If any parameter violates the stated constraints.
        """
        # Checks run in a fixed order; the first violated one is reported.
        if not (0.0 <= pe <= 1.0):
            raise ValueError("pe must be in [0, 1].")
        if t1 <= 0:
            raise ValueError("t1 must be > 0.")
        if t2 <= 0:
            raise ValueError("t2 must be > 0.")
        if t2 > 2 * t1:
            raise ValueError("t2 must be <= 2·t1.")
        if tg < 0:
            raise ValueError("tg must be >= 0.")
        self.pe, self.t1, self.t2, self.tg = pe, t1, t2, tg
        super().__init__(wires=wires)

    def kraus_matrices(self) -> List[jnp.ndarray]:
        """Return the Kraus operators of the thermal relaxation channel.

        The number of operators depends on the regime:

        * T_2 <= T_1: six operators (identity, phase-flip, two reset-to-|0>,
          two reset-to-|1>).
        * T_2 > T_1: four operators from the Choi matrix eigendecomposition.

        Returns:
            List of 2x2 JAX arrays representing the Kraus operators.
        """
        pe, t1, t2, tg = self.pe, self.t1, self.t2, self.tg

        decay_t1 = jnp.exp(-tg / t1)
        decay_t2 = jnp.exp(-tg / t2)
        p_reset = 1.0 - decay_t1

        if t2 <= t1:
            # --- Case T_2 <= T_1: six weighted fixed operators ---
            pz = (1.0 - p_reset) * (1.0 - decay_t2 / decay_t1) / 2.0
            pr0 = (1.0 - pe) * p_reset
            pr1 = pe * p_reset
            pid = 1.0 - pz - pr0 - pr1

            weighted_ops = (
                (pid, [[1, 0], [0, 1]]),  # identity
                (pz, [[1, 0], [0, -1]]),  # phase flip (Z)
                (pr0, [[1, 0], [0, 0]]),  # reset to |0> from |0>
                (pr0, [[0, 1], [0, 0]]),  # reset to |0> from |1>
                (pr1, [[0, 0], [1, 0]]),  # reset to |1> from |0>
                (pr1, [[0, 0], [0, 1]]),  # reset to |1> from |1>
            )
            return [
                jnp.sqrt(weight) * jnp.array(entries, dtype=_cdtype())
                for weight, entries in weighted_ops
            ]

        # --- Case T_2 > T_1: Choi matrix decomposition ---
        # Choi matrix (column-major / reshaping convention matching PennyLane)
        choi = jnp.array(
            [
                [1 - pe * p_reset, 0, 0, decay_t2],
                [0, pe * p_reset, 0, 0],
                [0, 0, (1 - pe) * p_reset, 0],
                [decay_t2, 0, 0, 1 - (1 - pe) * p_reset],
            ],
            dtype=_cdtype(),
        )
        eigenvalues, eigenvectors = jnp.linalg.eigh(choi)
        # Column i of ``eigenvectors`` is reshaped column-major into a 2x2
        # matrix and scaled by sqrt(|eigenvalue i|) to give one Kraus operator.
        return [
            (
                jnp.sqrt(jnp.abs(eigenvalues[i]))
                * eigenvectors[:, i].reshape(2, 2, order="F")
            ).astype(_cdtype())
            for i in range(4)
        ]

1698 

1699 

class QubitChannel(KrausChannel):
    """Generic Kraus channel from a user-supplied list of Kraus operators.

    This replaces PennyLane's ``qml.QubitChannel``.  The operators are
    expected to satisfy ``sum_k K_k^dagger K_k = I``; this class converts
    them to the active complex dtype but does not check that condition.

    Example::

        kraus_ops = [jnp.sqrt(0.9) * jnp.eye(2), jnp.sqrt(0.1) * PauliX._matrix]
        QubitChannel(kraus_ops, wires=0)
    """

    def __init__(
        self, kraus_ops: List[jnp.ndarray], wires: Union[int, List[int]] = 0
    ) -> None:
        """Create a generic Kraus channel.

        Args:
            kraus_ops: List of Kraus matrices.  Each must be a square 2D array
                of dimension ``2**k x 2**k`` where k = ``len(wires)``.
            wires: Qubit index or list of qubit indices this channel acts on.
        """
        complex_dtype = _cdtype()
        self._kraus_ops = [
            jnp.asarray(operator, dtype=complex_dtype) for operator in kraus_ops
        ]
        super().__init__(wires=wires)

    def kraus_matrices(self) -> List[jnp.ndarray]:
        """Return the Kraus operators stored at construction time.

        Returns:
            List of Kraus operator matrices.
        """
        return self._kraus_ops

1732 

1733 

# Single-qubit Pauli lookup tables in I, X, Y, Z order.  The three lists are
# index-aligned; ``_PAULI_MATS`` holds plain arrays (no Operation overhead).
_PAULI_CLASSES = [Id, PauliX, PauliY, PauliZ]
_PAULI_MATS = [pauli_cls._matrix for pauli_cls in _PAULI_CLASSES]
_PAULI_LABELS = list("IXYZ")

1738 

1739 

def evolve_pauli_with_clifford(
    clifford: Operation,
    pauli: Operation,
    adjoint_left: bool = True,
) -> Operation:
    """Conjugate *pauli* by *clifford* and return the result as an Operation.

    Both matrices are first embedded (via :func:`_embed_matrix`) into the
    space spanned by the sorted union of the two wire sets.  The product
    ``C^dagger P C`` (or ``C P C^dagger``) is then wrapped in a
    :class:`Hermitian` on those wires.

    Args:
        clifford: A Clifford gate.
        pauli: A Pauli / Hermitian operator.
        adjoint_left: If ``True``, compute ``C^dagger P C``; otherwise
            ``C P C^dagger``.

    Returns:
        A :class:`Hermitian` (created with ``record=False``) wrapping the
        evolved matrix.
    """
    joint_wires = sorted({*clifford.wires, *pauli.wires})
    n_joint = len(joint_wires)

    gate = _embed_matrix(clifford.matrix, clifford.wires, joint_wires, n_joint)
    observable = _embed_matrix(pauli.matrix, pauli.wires, joint_wires, n_joint)
    gate_dagger = jnp.conj(gate).T

    evolved = (
        gate_dagger @ observable @ gate
        if adjoint_left
        else gate @ observable @ gate_dagger
    )
    return Hermitian(matrix=evolved, wires=joint_wires, record=False)

1773 

1774 

def _embed_matrix(
    mat: jnp.ndarray,
    op_wires: list,
    all_wires: list,
    n_total: int,
) -> jnp.ndarray:
    """Embed a gate matrix into the space of *all_wires*.

    If the gate already acts on exactly *all_wires* in that order, *mat* is
    returned as-is.  Otherwise one 2x2 identity is Kronecker-appended for
    each wire of *all_wires* that the gate does not touch, and the qubit
    axes are then permuted so that their order matches *all_wires*.

    Args:
        mat: The gate's matrix of shape ``(2**k, 2**k)`` where
            ``k = len(op_wires)``.
        op_wires: The wires the gate acts on.
        all_wires: The full ordered list of wires.
        n_total: ``len(all_wires)``.

    Returns:
        A ``(2**n_total, 2**n_total)`` matrix.
    """
    gate_wires = list(op_wires)
    target_order = list(all_wires)
    if len(gate_wires) == n_total and gate_wires == target_order:
        return mat

    # Wires present in the full space but not acted on by the gate.
    spectators = [w for w in target_order if w not in gate_wires]

    # mat (x) I (x) ... (x) I, one identity factor per spectator wire.
    single_identity = jnp.eye(2, dtype=_cdtype())
    embedded = mat
    for _ in spectators:
        embedded = jnp.kron(embedded, single_identity)

    # After padding, qubit axes are ordered [gate wires..., spectators...].
    padded_order = gate_wires + spectators
    if padded_order == target_order:
        return embedded

    # For each target position, look up where that wire currently sits.
    axes = [padded_order.index(w) for w in target_order]
    return _permute_matrix(embedded, axes, n_total)

1818 

1819 

1820def _permute_matrix(mat: jnp.ndarray, perm: list, n_qubits: int) -> jnp.ndarray: 

1821 """Permute the qubit ordering of a matrix. 

1822 

1823 Given a ``(2**n, 2**n)`` matrix and a permutation of ``[0..n-1]``, 

1824 reorder the qubits so that qubit ``i`` moves to position ``perm[i]``. 

1825 

1826 Args: 

1827 mat: Square matrix of dimension ``2**n_qubits``. 

1828 perm: Permutation list. 

1829 n_qubits: Number of qubits. 

1830 

1831 Returns: 

1832 Permuted matrix of the same shape. 

1833 """ 

1834 dim = 2**n_qubits 

1835 # Reshape to tensor, permute axes, reshape back 

1836 tensor = mat.reshape([2] * (2 * n_qubits)) 

1837 # Axes: first n_qubits are row indices, last n_qubits are column indices 

1838 row_perm = perm 

1839 col_perm = [p + n_qubits for p in perm] 

1840 tensor = jnp.transpose(tensor, row_perm + col_perm) 

1841 return tensor.reshape(dim, dim) 

1842 

1843 

def pauli_decompose(matrix: jnp.ndarray, wire_order: Optional[List[int]] = None):
    r"""Return the dominant Pauli term of a Hermitian matrix.

    For an n-qubit matrix (``2**n x 2**n``), every Pauli tensor product
    ``P`` is scored with the trace formula

        ``c_P = Tr(P · M) / 2**n``

    and the term with the largest ``|c_P|`` is returned, wrapped as an
    :class:`Operation`.  This is sufficient for the Fourier-tree algorithm,
    which only needs the single non-zero Pauli term produced by Clifford
    conjugation of a Pauli operator.

    Ties are resolved in favour of the first candidate in enumeration
    order (the comparison is strict).  If no coefficient has a magnitude
    above zero (e.g. an all-zero matrix), the identity term with
    coefficient ``0.0`` is returned for any number of qubits.

    Args:
        matrix: A ``(2**n, 2**n)`` Hermitian matrix.
        wire_order: Optional list of wire indices.  If ``None``, defaults
            to ``[0, 1, ..., n-1]``.

    Returns:
        A tuple ``(coeff, op)`` where *coeff* is the coefficient of the
        dominant term and *op* is the matching Pauli :class:`Operation`:
        a single-wire Pauli/identity operation when at most one factor is
        non-identity, otherwise a :class:`Hermitian` acting on all of
        ``wire_order``.  The operation carries a ``_pauli_label``
        attribute with its Pauli string.
    """
    from itertools import product as _product
    from functools import reduce as _reduce

    dim = matrix.shape[0]
    n_qubits = int(jnp.round(jnp.log2(dim)))

    if wire_order is None:
        wire_order = list(range(n_qubits))

    # Single-qubit fast path: only the four basis matrices to score.
    if n_qubits == 1:
        best_idx, best_coeff = 0, 0.0
        for idx, P in enumerate(_PAULI_MATS):
            coeff = jnp.trace(P @ matrix) / 2.0
            if jnp.abs(coeff) > jnp.abs(best_coeff):
                best_idx = idx
                best_coeff = coeff
        op_cls = _PAULI_CLASSES[best_idx]
        result_op = op_cls(wires=wire_order[0], record=False)
        result_op._pauli_label = _PAULI_LABELS[best_idx]
        return best_coeff, result_op

    # Multi-qubit: score every Pauli tensor product.
    # Start from the all-identity label so that a matrix with no non-zero
    # coefficient falls back to the identity (mirroring the single-qubit
    # path) instead of leaving the label unset and failing further down.
    best_label = (0,) * n_qubits
    best_coeff = 0.0
    for indices in _product(range(4), repeat=n_qubits):
        P = _reduce(jnp.kron, [_PAULI_MATS[i] for i in indices])
        coeff = jnp.trace(P @ matrix) / dim
        if jnp.abs(coeff) > jnp.abs(best_coeff):
            best_coeff = coeff
            best_label = indices

    # Positions (and basis indices) of the non-identity factors; index 0
    # is the identity entry of the Pauli tables.
    non_identity = [(q, idx) for q, idx in enumerate(best_label) if idx != 0]

    if len(non_identity) > 1:
        # Genuine multi-qubit tensor product -> Hermitian with the full
        # Pauli string attached.
        P = _reduce(jnp.kron, [_PAULI_MATS[i] for i in best_label])
        result_op = Hermitian(matrix=P, wires=wire_order, record=False)
        result_op._pauli_label = "".join(_PAULI_LABELS[i] for i in best_label)
        return best_coeff, result_op

    if non_identity:
        # Exactly one non-identity factor: a plain Pauli on that wire.
        q, idx = non_identity[0]
        result_op = _PAULI_CLASSES[idx](wires=wire_order[q], record=False)
        result_op._pauli_label = _PAULI_LABELS[idx]
        return best_coeff, result_op

    # All factors are the identity.
    result_op = Id(wires=wire_order[0], record=False)
    result_op._pauli_label = "I" * n_qubits
    return best_coeff, result_op

1920 

1921 

def pauli_string_from_operation(op: Operation) -> str:
    """Return the Pauli word associated with an operation.

    The lookup proceeds in this order:

    1. A :class:`PauliRot` exposing ``pauli_word`` returns that word.
    2. Any operation carrying a ``_pauli_label`` attribute (as set by
       :func:`pauli_decompose`) returns that label.
    3. Operations named ``PauliX``, ``PauliY``, ``PauliZ`` or ``I`` map to
       ``"X"``, ``"Y"``, ``"Z"`` and ``"I"`` respectively.
    4. Otherwise ``op.matrix`` is passed to :func:`pauli_decompose` and the
       label of the resulting operation is returned.

    Args:
        op: A quantum operation.

    Returns:
        A string like ``"X"``, ``"ZZ"``, etc.
    """
    if isinstance(op, PauliRot) and hasattr(op, "pauli_word"):
        return op.pauli_word

    # Label attached by pauli_decompose takes precedence over the name.
    if hasattr(op, "_pauli_label"):
        return op._pauli_label

    letter = {"PauliX": "X", "PauliY": "Y", "PauliZ": "Z", "I": "I"}.get(op.name)
    if letter is not None:
        return letter

    # Last resort: recover the word from the operation's matrix.
    dominant_op = pauli_decompose(op.matrix, wire_order=op.wires)[1]
    return dominant_op._pauli_label

1947 

1948 

def prod(*ops: Operation) -> Operation:
    """Compose several operations into a single generalized product.

    The work is delegated to the ``prod`` method of the first operation,
    which receives all remaining operations as positional arguments.
    The result acts on the union of all wire sets: disjoint wire sets
    give a Kronecker product, overlapping wire sets give a matrix
    product.

    Args:
        *ops: Variable number of :class:`Operation` instances.

    Returns:
        A new :class:`Operation` whose matrix represents the composed
        operation on the unified wire set.

    Raises:
        ValueError: If no operation is given.
    """
    if len(ops) == 0:
        raise ValueError("At least one operation must be provided to prod().")
    head, *tail = ops
    return head.prod(*tail)