Coverage for qml_essentials / operations.py: 83%

822 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-06-11 15:51 +0000

1from typing import Callable, List, Optional, Tuple, Union 

2from functools import lru_cache 

3import string 

4import numpy as np 

5 

6import jax 

7import jax.numpy as jnp 

8 

9from qml_essentials.tape import active_tape, recording # noqa: F401 (re-export) 

10 

11 

12def _cdtype(): 

13 """Return the active JAX complex dtype 

14 (complex128 if x64 enabled, else complex64). 

15 """ 

16 return jnp.complex128 if jax.config.x64_enabled else jnp.complex64 

17 

18 

19@lru_cache(maxsize=256) 

20def _einsum_subscript( 

21 n: int, 

22 k: int, 

23 target_axes: Tuple[int, ...], 

24) -> str: 

25 """Build an ``einsum`` subscript that fuses contraction + axis restore. 

26 

27 Args: 

28 n: Total rank of the state tensor (number of qubits for statevectors, 

29 ``2 * n_qubits`` for density matrices). 

30 k: Number of qubits the gate acts on. 

31 target_axes: Tuple of k axis indices in the state tensor that the 

32 gate contracts against. 

33 

34 Returns: 

35 ``einsum`` subscript string, e.g. ``"ab,cBd->cad"`` for a 1-qubit 

36 gate on wire 1 of a 3-qubit state. 

37 """ 

38 letters = string.ascii_letters 

39 # State indices: one letter per axis 

40 state_idx = list(letters[:n]) 

41 # Contracted indices (the ones being replaced by the gate) 

42 contracted = [state_idx[ax] for ax in target_axes] 

43 # Gate indices: new output indices + contracted input indices 

44 new_out = [letters[n + i] for i in range(k)] # fresh letters for output 

45 gate_idx = new_out + contracted # gate shape: (out0, out1, ..., in0, in1, ...) 

46 # Result indices: replace target axes with new output letters 

47 result_idx = list(state_idx) 

48 for i, ax in enumerate(target_axes): 

49 result_idx[ax] = new_out[i] 

50 return "".join(gate_idx) + "," + "".join(state_idx) + "->" + "".join(result_idx) 

51 

52 

def _contract_and_restore(
    tensor: jnp.ndarray,
    gate: jnp.ndarray,
    k: int,
    target_axes: List[int],
) -> jnp.ndarray:
    """Apply ``gate`` to the ``target_axes`` of ``tensor`` in one einsum.

    The subscript is obtained from :func:`_einsum_subscript`, which is
    ``lru_cache``-d, so it is only built once per distinct
    ``(tensor.ndim, k, tuple(target_axes))`` key.

    Args:
        tensor: Rank-N tensor (e.g. ``(2,)*n`` for states or ``(2,)*2n``
            for density matrices).
        gate: Reshaped gate tensor of shape ``(2,)*2k``.
        k: Number of qubits the gate acts on (= ``len(target_axes)``).
        target_axes: The k axes of tensor to contract against.

    Returns:
        Tensor of the same rank as ``tensor`` whose target axes now hold the
        gate output, each in its original position.
    """
    # Tuples are hashable, which the cached subscript builder requires.
    axes_key = tuple(target_axes)
    spec = _einsum_subscript(tensor.ndim, k, axes_key)
    return jnp.einsum(spec, gate, tensor)

78 

79 

class Operation:
    """Base class for any quantum operation or observable.

    Further gates should inherit from this class to realise more specific
    operations. Generally, operations are created by instantiation inside a
    circuit function passed to :class:`Script`; the instance is
    automatically appended to the active tape.

    An ``Operation`` can also serve as an *observable*: its matrix is used to
    compute expectation values via ``apply_to_state`` / ``apply_to_density``.

    Attributes:
        _matrix: Class-level default gate matrix. Subclasses set this to their
            fixed unitary. Instances may override it via the *matrix* argument
            to ``__init__``.
        _num_wires: Expected number of wires for this gate. Subclasses set
            this to enforce wire count validation. ``None`` means any number
            of wires is accepted.
        _param_names: Tuple of attribute names for the gate parameters.
            Used by :attr:`parameters` and :meth:`__repr__`.
    """

    # Whether this is a controlled operation
    is_controlled = False
    # Whether this gate is a Clifford gate (normalises the Pauli group)
    is_clifford = False

    # Subclasses should set this to the gate's unitary / matrix
    _matrix: jnp.ndarray = None
    _num_wires: Optional[int] = None
    _param_names: Tuple[str, ...] = ()

    def __init__(
        self,
        wires: Union[int, List[int]] = 0,
        matrix: Optional[jnp.ndarray] = None,
        record: bool = True,
        name: Optional[str] = None,
    ) -> None:
        """Initialise the operation and optionally register it on the active tape.

        Args:
            wires: Qubit index or list of qubit indices this operation acts on.
            matrix: Optional explicit gate matrix. When provided it overrides
                the class-level ``_matrix`` attribute.
            record: If ``True`` (default) and a tape is currently recording,
                append this operation to the tape. Set to ``False`` for
                auxiliary objects that should not appear in the circuit
                (e.g. Hamiltonians used only to build time-dependent
                evolutions).
            name: Optional explicit name for this operation. When ``None``
                (default), the class name is used (e.g. ``"RX"``).

        Raises:
            ValueError: If ``_num_wires`` is set and the number of wires
                doesn't match, or if duplicate wires are provided.
        """
        self.name = name or self.__class__.__name__
        # The ``wires`` setter normalises an int / list / tuple to a list.
        self.wires = wires

        if self._num_wires is not None and len(self.wires) != self._num_wires:
            raise ValueError(
                f"{self.name} expects {self._num_wires} wire(s), "
                f"got {len(self.wires)}: {self.wires}"
            )
        if len(self.wires) != len(set(self.wires)):
            raise ValueError(f"{self.name} received duplicate wires: {self.wires}")

        if matrix is not None:
            self._matrix = matrix

        # If a tape is currently recording, append ourselves
        if record:
            tape = active_tape()
            if tape is not None:
                tape.append(self)

    @property
    def parameters(self) -> list:
        """Return the list of numeric parameters for this operation.

        Uses the declarative ``_param_names`` tuple to collect parameter
        values in a canonical order. Non-parametrized gates return an
        empty list.

        Returns:
            List of parameter values (floats or JAX arrays).
        """
        return [getattr(self, name) for name in self._param_names]

    def __repr__(self) -> str:
        """Return a human-readable representation of this operation.

        Returns:
            A string like ``"RX(0.5000, wires=[0])"`` or ``"CX(wires=[0, 1])"``.
        """

        def _fmt(value) -> str:
            # Numeric scalars get fixed 4-decimal formatting. Values that
            # ``float()`` rejects with TypeError/ValueError (e.g. non-scalar
            # arrays, or abstract tracers inside jit) fall back to ``str``
            # so that repr never raises.
            if isinstance(value, (float, np.floating, jnp.ndarray)):
                try:
                    return f"{float(value):.4f}"
                except (TypeError, ValueError):
                    pass
            return str(value)

        params = self.parameters
        if params:
            param_str = ", ".join(_fmt(v) for v in params)
            return f"{self.name}({param_str}, wires={self.wires})"
        return f"{self.name}(wires={self.wires})"

    @property
    def matrix(self) -> jnp.ndarray:
        """Return the base matrix of this operation (before lifting).

        Returns:
            The gate matrix as a JAX array.

        Raises:
            NotImplementedError: If the subclass has not defined ``_matrix``.
        """
        if self._matrix is None:
            raise NotImplementedError(
                f"{self.__class__.__name__} does not define a matrix."
            )
        return self._matrix

    def decompose(self) -> List["Operation"]:
        """Decompose this operation into a list of more primitive operations.

        The returned operations are created with ``record=False`` so the caller
        controls where they are placed. Reused e.g. by
        :meth:`~qml_essentials.pauli.PauliCircuit.get_clifford_pauli_gates` to
        express composite gates in terms of Clifford + Pauli-rotation primitives.

        Returns:
            List of :class:`Operation` instances equivalent to this gate.

        Raises:
            NotImplementedError: If the gate has no decomposition (it is itself
                primitive).
        """
        raise NotImplementedError(
            f"{self.__class__.__name__} does not define a decomposition."
        )

    @property
    def wires(self) -> List[int]:
        """Qubit indices this operation acts on.

        Returns:
            List of integer qubit indices.
        """
        return self._wires

    @wires.setter
    def wires(self, wires: Union[int, List[int]]) -> None:
        """Set the qubit indices for this operation.

        Args:
            wires: A single qubit index or a list of qubit indices.
        """
        if isinstance(wires, (list, tuple)):
            self._wires = list(wires)
        else:
            self._wires = [wires]

    def _update_tape_operation(self, op: "Operation") -> None:
        """Put *op* on the active tape in place of ``self``.

        If ``self`` is the most recent entry of the active tape (the typical
        case when chaining ``Gate(...).dagger()``), it is replaced by *op* so
        that only the derived operation (e.g. ``U\\dagger``) appears on the
        tape, not both ``U`` and ``U\\dagger``. Otherwise *op* is appended.
        This should therefore only be called immediately after ``self`` was
        recorded. Nothing happens when no tape is recording.

        Args:
            op (Operation): New operation to place on the tape.
        """
        tape = active_tape()
        if tape is not None:
            if tape and tape[-1] is self:
                tape[-1] = op
            else:
                tape.append(op)

    def dagger(self) -> "Operation":
        """Return a new operation, the conjugate transpose (``U\\dagger``).

        Usage inside a circuit function::

            RX(0.5, wires=0).dagger()

        Returns:
            A new :class:`Operation` with matrix ``U\\dagger`` acting on the
            same wires.

        Raises:
            NotImplementedError: If this operation defines no matrix.
        """
        # Use the ``matrix`` property (not ``_matrix``) so a missing matrix
        # raises the documented NotImplementedError instead of failing in jnp.
        mat = jnp.conj(self.matrix).T
        op = Operation(wires=self.wires, matrix=mat, record=False)

        self._update_tape_operation(op)

        return op

    def power(self, power) -> "Operation":
        """Return a new operation, the matrix power (``U^power``).

        Usage inside a circuit function::

            PauliX(wires=0).power(2)

        Args:
            power: Integer exponent passed to ``jnp.linalg.matrix_power``.

        Returns:
            A new :class:`Operation` with matrix ``U^power`` acting on the
            same wires.

        Raises:
            NotImplementedError: If this operation defines no matrix.
        """
        # TODO: support fractional powers
        mat = jnp.linalg.matrix_power(self.matrix, power)
        op = Operation(wires=self.wires, matrix=mat, record=False)

        self._update_tape_operation(op)

        return op

    def __mul__(self, other: Union[float, "Operation"]) -> "Operation":
        """Scale by a scalar (``U*x``) or compose with another operation.

        Usage inside a circuit function::

            PauliX(wires=0) * x
            PauliX(wires=0) * PauliZ(wires=0)

        Returns:
            A new :class:`Operation` with matrix ``U*x`` acting on the same
            wires, or the composed matrix acting on the appropriate wires.

        Raises:
            NotImplementedError: If this operation defines no matrix.
        """
        if isinstance(other, Operation):
            return self.__matmul__(other)

        mat = other * self.matrix
        op = Operation(wires=self.wires, matrix=mat, record=False)

        self._update_tape_operation(op)

        return op

    # Also overwrite * for right operands
    __rmul__ = __mul__

    def __add__(self, other: "Operation") -> "Operation":
        """Add two operations acting on the same set of wires.

        If *other* lists the same wires in a different order, its matrix is
        first permuted into ``self.wires`` order so that both matrices are
        expressed in the same basis before they are summed.

        Returns:
            A new :class:`Operation` on ``self.wires`` whose matrix is the
            sum of both matrices, or ``NotImplemented`` if *other* is not an
            :class:`Operation`.

        Raises:
            ValueError: If the wire sets differ.
        """
        if not isinstance(other, Operation):
            return NotImplemented
        if sorted(self.wires) != sorted(other.wires):
            raise ValueError(
                f"Can only add operations acting on the same set of wires, "
                f"got {self.wires} and {other.wires}"
            )

        other_mat = other.matrix
        if other.wires != self.wires:
            # Axis i of the (2,)*2k tensor belongs to wires[i] (out) and
            # axis k+i to wires[i] (in); reorder both halves to self.wires.
            k = len(self.wires)
            perm = [other.wires.index(w) for w in self.wires]
            other_mat = jnp.transpose(
                other_mat.reshape((2,) * 2 * k), perm + [p + k for p in perm]
            ).reshape(2**k, 2**k)

        return Operation(
            wires=self.wires,
            matrix=self.matrix + other_mat,
            record=False,
        )

    def prod(self, *ops: "Operation") -> "Operation":
        """Construct the generalized product (tensor or matrix)
        of this operation with others.

        The resulting operation acts on the union of all wire sets.
        If the wire sets are disjoint, this is a Kronecker product.
        If the wire sets overlap, the corresponding matrices are multiplied.

        Usage::

            res = op1.prod(op2, op3)
            # or
            res = Operation.prod(op1, op2, op3)

        Args:
            *ops: Variable number of :class:`Operation` instances.

        Returns:
            A new :class:`Operation` representing the composed operation.
        """
        if not ops:
            return self

        all_ops = (self,) + ops
        # Ordered union of wires (first occurrence wins)
        all_wires = []
        for op in all_ops:
            for w in op.wires:
                if w not in all_wires:
                    all_wires.append(w)

        n = len(all_wires)

        mat = _embed_matrix(all_ops[0].matrix, all_ops[0].wires, all_wires, n)
        for op in all_ops[1:]:
            mat_other = _embed_matrix(op.matrix, op.wires, all_wires, n)
            mat = mat @ mat_other

        op_names = "*".join(op.name for op in all_ops)
        return Operation(
            wires=all_wires, matrix=mat, name=f"Prod({op_names})", record=False
        )

    def __matmul__(self, other: "Operation") -> "Operation":
        """Tensor (Kronecker) product or matrix product of two operations.

        The resulting operation acts on the union of both wire sets.
        If the wire sets are disjoint, this is a Kronecker product.
        If the wire sets overlap, the corresponding matrices are multiplied.

        Returns:
            A new :class:`Operation` whose matrix represents the composed
            operation on the unified wire set.
        """
        if not isinstance(other, Operation):
            return NotImplemented

        return self.prod(other)

    def lifted_matrix(self, n_qubits: int) -> jnp.ndarray:
        """Return the full ``2**n x 2**n`` matrix embedding this gate.

        Embeds the ``k``-qubit gate matrix into the ``n``-qubit Hilbert space
        by applying it to the identity matrix via :meth:`apply_to_state`.
        This is useful for computing ``Tr(O·\\rho )`` directly without vmap.

        Args:
            n_qubits: Total number of qubits in the circuit.

        Returns:
            The ``(2**n, 2**n)`` matrix of this operation in the full space.
        """
        dim = 2**n_qubits
        # vmap maps each basis vector e_j to U e_j and stacks them as rows,
        # giving U^T; the final .T restores U.
        return jax.vmap(lambda col: self.apply_to_state(col, n_qubits))(
            jnp.eye(dim, dtype=_cdtype())
        ).T

    def apply_to_state(self, state: jnp.ndarray, n_qubits: int) -> jnp.ndarray:
        """Apply this gate to a statevector via tensor contraction.

        The statevector (shape ``(2**n,)``) is reshaped into a rank-n tensor
        of shape ``(2,)*n``. The gate (shape ``(2**k, 2**k)``) is reshaped to
        ``(2,)*2k`` and contracted against the k target wire axes.

        Memory footprint is O(2**n) and the operation supports arbitrary k.
        The implementation is fully differentiable through JAX.

        Args:
            state: Statevector of shape ``(2**n_qubits,)``.
            n_qubits: Total number of qubits in the circuit.

        Returns:
            Updated statevector of shape ``(2**n_qubits,)``.
        """
        k = len(self.wires)
        gate_tensor = self.matrix.reshape((2,) * 2 * k)
        psi = state.reshape((2,) * n_qubits)
        psi_out = _contract_and_restore(psi, gate_tensor, k, self.wires)
        return psi_out.reshape(2**n_qubits)

    def apply_to_state_tensor(self, psi: jnp.ndarray, n_qubits: int) -> jnp.ndarray:
        """Apply this gate to a statevector already in tensor form.

        Like :meth:`apply_to_state` but expects the state in rank-n tensor
        form ``(2,)*n`` and returns the result in the same form. This avoids
        the ``reshape`` calls at the per-gate level when the simulation loop
        keeps the state in tensor form throughout.

        Args:
            psi: Statevector tensor of shape ``(2,)*n_qubits``.
            n_qubits: Total number of qubits in the circuit.

        Returns:
            Updated statevector tensor of shape ``(2,)*n_qubits``.
        """
        k = len(self.wires)
        gate_tensor = self._gate_tensor(k)
        return _contract_and_restore(psi, gate_tensor, k, self.wires)

    def _gate_tensor(self, k: int) -> jnp.ndarray:
        """Return the gate matrix reshaped to ``(2,)*2k`` tensor form.

        The result is cached on the instance so repeated calls (e.g. from
        density-matrix simulation which applies U and U*) avoid redundant
        reshape dispatch.

        Args:
            k: Number of qubits the gate acts on.

        Returns:
            Gate matrix as a rank-2k tensor of shape ``(2,)*2k``.
        """
        cached = getattr(self, "_cached_gate_tensor", None)
        if cached is not None:
            return cached
        gt = self.matrix.reshape((2,) * 2 * k)
        # Only cache for non-parametrized gates whose matrix is the shared
        # class-level constant. A ``None`` class matrix means a subclass
        # computes ``matrix`` itself, so its result is not cached.
        if self._matrix is not None and self._matrix is self.__class__._matrix:
            object.__setattr__(self, "_cached_gate_tensor", gt)
        return gt

    def apply_to_density(self, rho: jnp.ndarray, n_qubits: int) -> jnp.ndarray:
        """Apply this gate to a density matrix via \\rho -> U\\rho U\\dagger.

        The density matrix (shape ``(2**n, 2**n)``) is treated as a rank-*2n*
        tensor with n "ket" axes (0..n-1) and n "bra" axes (n..2n-1).
        U acts on the ket half; U* acts on the bra half. Both contractions
        use the shared :func:`_contract_and_restore` helper, keeping the
        operation allocation-free with respect to building full unitaries.

        Args:
            rho: Density matrix of shape ``(2**n_qubits, 2**n_qubits)``.
            n_qubits: Total number of qubits in the circuit.

        Returns:
            Updated density matrix of shape ``(2**n_qubits, 2**n_qubits)``.
        """
        k = len(self.wires)
        U = self._gate_tensor(k)
        U_conj = jnp.conj(U)

        rho_t = rho.reshape((2,) * 2 * n_qubits)

        # Apply U to ket axes, U\\dagger to bra axes
        rho_t = _contract_and_restore(rho_t, U, k, self.wires)
        bra_wires = [w + n_qubits for w in self.wires]
        rho_t = _contract_and_restore(rho_t, U_conj, k, bra_wires)

        return rho_t.reshape(2**n_qubits, 2**n_qubits)

513 

514 

class Hermitian(Operation):
    """Observable or gate backed by a user-supplied (Hermitian) matrix.

    The matrix is converted to the active complex dtype (see
    :func:`_cdtype`) when the operator is constructed.

    Example:
        >>> obs = Hermitian(matrix=my_matrix, wires=0)
    """

    def __init__(
        self,
        matrix: jnp.ndarray,
        wires: Union[int, List[int]] = 0,
        record: bool = True,
    ) -> None:
        """Create a Hermitian operator.

        Args:
            matrix: Matrix defining the operator; cast to the complex dtype.
            wires: Qubit index or list of qubit indices the operator acts on.
            record: When ``True`` (default) the operator is recorded on the
                active tape. Pass ``False`` when it only serves as a
                Hamiltonian component (e.g. for time-dependent evolution).
        """
        complex_matrix = jnp.asarray(matrix, dtype=_cdtype())
        super().__init__(wires=wires, matrix=complex_matrix, record=record)

    def __rmul__(self, coeff_fn: Callable) -> "ParametrizedHamiltonian":
        """Build ``coeff_fn * Hermitian`` as a one-term Hamiltonian.

        NOTE(review): unlike :meth:`Operation.__rmul__`, a non-callable left
        operand (including a plain scalar) raises ``TypeError`` here;
        ``Hermitian(...) * 2.0`` still goes through ``Operation.__mul__``.

        Args:
            coeff_fn (Callable): Callable ``(params, t) -> scalar`` giving the
                time-dependent coefficient.

        Returns:
            ParametrizedHamiltonian: One-term Hamiltonian pairing *coeff_fn*
            with this operator's matrix and wires.

        Raises:
            TypeError: If *coeff_fn* is not callable.
        """
        if callable(coeff_fn):
            term = (coeff_fn, self.matrix, self.wires)
            return ParametrizedHamiltonian(terms=[term])
        raise TypeError(
            f"Left operand of `* Hermitian` must be callable, got {type(coeff_fn)}"
        )

    def evolve(self, name: Optional[str] = None, **odeint_kwargs) -> Callable:
        """Return a gate factory for the static evolution ``U = exp(-i t H)``.

        Delegates to :meth:`qml_essentials.evolution.Evolution.evolve`.

        Args:
            name: Optional name for the produced :class:`Operation`.
            **odeint_kwargs: Forwarded unchanged; accepted so the signature
                matches :meth:`ParametrizedHamiltonian.evolve`.

        Returns:
            A callable gate factory ``(t, wires=0) -> Operation``.
        """
        # Imported lazily because qml_essentials.evolution imports this module.
        from qml_essentials.evolution import Evolution

        return Evolution.evolve(self, name=name, **odeint_kwargs)

579 

580 

class ParametrizedHamiltonian:
    """A time-dependent Hamiltonian as a sum of ``coeff * Hermitian`` terms.

    Mathematically::

        H(t) = \\sum_i f_i(params_i, t) * H_i

    Construction is always done from an explicit list of
    ``(coeff_fn, H_mat, wires)`` triples passed as ``terms``. The
    common single-term shorthand is the operator form
    ``coeff_fn * Hermitian(matrix, wires)`` (see
    :meth:`Hermitian.__rmul__`), which returns a one-term instance.
    Multi-term Hamiltonians are composed with ``+`` between
    :class:`ParametrizedHamiltonian` instances::

        H1 = coeff_x * Hermitian(X, wires=0)
        H2 = coeff_y * Hermitian(Y, wires=0)
        H_td = H1 + H2

        # evolve under the composite Hamiltonian; coeff_args is a list of
        # parameter sets, one per term, in the order the terms were added:
        H_td.evolve()([px, py], T=1.0)

    ``sum([H1, H2])`` is also supported (a leading literal ``0`` is treated
    as the additive identity).

    Attributes:
        coeff_fns: Tuple of callables ``(params, t) -> scalar``, one per term.
        H_mats: Tuple of static Hermitian matrices, one per term.
        wires: Wires this Hamiltonian acts on (union across all terms; for
            now all terms are required to share the same wire set).
    """

    def __init__(
        self,
        terms: List[Tuple[Callable, jnp.ndarray, Union[int, List[int]]]],
    ) -> None:
        """Build a (possibly multi-term) parametrized Hamiltonian.

        Args:
            terms: List of ``(coeff_fn, H_mat, wires)`` triples. ``wires`` is
                either a single integer (Python ``int`` or NumPy integer) or
                an iterable of wire indices. Use the
                ``coeff_fn * Hermitian(...)`` shorthand to build a one-term
                instance; combine instances with ``+`` to add terms.

        Raises:
            ValueError: If the term list is empty, or if terms act on
                differing wire sets (multi-wire broadcasting is
                deferred — see :mod:`jaqsi`), or if term matrices have
                incompatible shapes.
        """
        if len(terms) == 0:
            raise ValueError("ParametrizedHamiltonian needs at least one term.")

        def _wlist(w):
            # Single wire -> [w]; any other iterable -> list. NumPy integers
            # are treated like ints (``list(np.int64(0))`` would raise).
            return [w] if isinstance(w, (int, np.integer)) else list(w)

        first_wires = _wlist(terms[0][2])
        for _, _, w in terms[1:]:
            term_wires = _wlist(w)
            if term_wires != first_wires:
                raise ValueError(
                    "All terms of a ParametrizedHamiltonian must currently "
                    "act on the same wires; got "
                    f"{term_wires} vs. {first_wires}. "
                    "Multi-wire broadcasting across terms is not yet supported."
                )

        # Convert each term matrix once, then check shape compatibility.
        mats = [jnp.asarray(H, dtype=_cdtype()) for _, H, _ in terms]
        first_dim = mats[0].shape
        for H in mats[1:]:
            if H.shape != first_dim:
                raise ValueError(
                    f"All term matrices must have the same shape; got "
                    f"{H.shape} vs. {first_dim}."
                )

        self._terms: Tuple[Tuple[Callable, jnp.ndarray, List[int]], ...] = tuple(
            (fn, H, _wlist(w)) for (fn, _, w), H in zip(terms, mats)
        )
        self.wires: List[int] = list(first_wires)

    # --- term accessors -------------------------------------------------

    @property
    def coeff_fns(self) -> Tuple[Callable, ...]:
        """Tuple of coefficient functions, one per term."""
        return tuple(fn for fn, _, _ in self._terms)

    @property
    def H_mats(self) -> Tuple[jnp.ndarray, ...]:
        """Tuple of Hermitian matrices, one per term."""
        return tuple(H for _, H, _ in self._terms)

    @property
    def n_terms(self) -> int:
        """Number of terms in the Hamiltonian."""
        return len(self._terms)

    # --- composition ---------------------------------------------------

    def __add__(self, other: "ParametrizedHamiltonian") -> "ParametrizedHamiltonian":
        """Concatenate term lists: ``H = H1 + H2``."""
        if not isinstance(other, ParametrizedHamiltonian):
            return NotImplemented
        return ParametrizedHamiltonian(terms=list(self._terms) + list(other._terms))

    def __radd__(self, other) -> "ParametrizedHamiltonian":
        """Support ``0 + H`` so that built-in ``sum([...])`` works."""
        if isinstance(other, int) and not isinstance(other, bool) and other == 0:
            return self
        return NotImplemented

    def __neg__(self) -> "ParametrizedHamiltonian":
        """Negate every coefficient: ``-H`` = sum of ``(-f_i) * H_i``."""
        # The outer lambda binds each ``fn`` immediately, avoiding the
        # late-binding closure pitfall inside the comprehension.
        new_terms = [
            ((lambda f: lambda p, t: -f(p, t))(fn), H, w) for fn, H, w in self._terms
        ]
        return ParametrizedHamiltonian(terms=new_terms)

    def __sub__(self, other: "ParametrizedHamiltonian") -> "ParametrizedHamiltonian":
        """Subtract Hamiltonians: ``H1 - H2`` equals ``H1 + (-H2)``."""
        if not isinstance(other, ParametrizedHamiltonian):
            return NotImplemented
        return self + (-other)

    # --- evolution -----------------------------------------------------

    def evolve(self, name: Optional[str] = None, **odeint_kwargs) -> Callable:
        """Return a gate factory for time-dependent evolution.

        Solves ``dU/dt = -i [sum_i f_i(p_i, t) H_i] U``. Thin delegator to
        :meth:`qml_essentials.evolution.Evolution.evolve`.

        Args:
            name: Optional name for the produced :class:`Operation`.
            **odeint_kwargs: Solver options forwarded to ``Evolution.evolve``
                (``atol``, ``rtol``, ``max_steps``, ``throw``, ``solver``,
                ``magnus_steps``).

        Returns:
            A callable gate factory ``(coeff_args, T) -> Operation``.
        """
        from qml_essentials.evolution import Evolution  # deferred: circular import

        return Evolution.evolve(self, name=name, **odeint_kwargs)

717 

718 

class Id(Operation):
    """Identity gate on any number of wires.

    A single wire uses the shared ``2 x 2`` identity. For ``k > 1`` wires an
    instance-level ``2**k x 2**k`` identity is built instead.
    """

    is_clifford = True
    _num_wires = None  # no wire-count restriction
    _matrix = jnp.eye(2, dtype=_cdtype())

    def __init__(self, wires: Union[int, List[int]] = 0, **kwargs) -> None:
        """Create an identity gate.

        Args:
            wires: Qubit index or list of qubit indices to act on. With more
                than one wire the matrix is widened to the matching
                ``2**k × 2**k`` identity.
            **kwargs: Passed through to :class:`Operation`.
        """
        n_wires = len(wires) if isinstance(wires, (list, tuple)) else 1
        if n_wires > 1:
            kwargs["matrix"] = jnp.eye(2**n_wires, dtype=_cdtype())
        super().__init__(wires=wires, **kwargs)

744 

745 

class PauliX(Operation):
    """Pauli-X observable / gate: the bit flip \\sigma_x."""

    is_clifford = True
    _num_wires = 1
    _matrix = jnp.array([[0.0, 1.0], [1.0, 0.0]], dtype=_cdtype())

    def __init__(self, wires: Union[int, List[int]] = 0, **kwargs) -> None:
        """Create a Pauli-X gate.

        Args:
            wires: Single qubit index (or one-element list) to act on.
            **kwargs: Passed through to :class:`Operation`.
        """
        super().__init__(wires=wires, **kwargs)

760 

761 

class PauliY(Operation):
    """Pauli-Y observable / gate (\\sigma_y)."""

    is_clifford = True
    _num_wires = 1
    _matrix = jnp.array([[0j, -1j], [1j, 0j]], dtype=_cdtype())

    def __init__(self, wires: Union[int, List[int]] = 0, **kwargs) -> None:
        """Create a Pauli-Y gate.

        Args:
            wires: Single qubit index (or one-element list) to act on.
            **kwargs: Passed through to :class:`Operation`.
        """
        super().__init__(wires=wires, **kwargs)

776 

777 

class PauliZ(Operation):
    """Pauli-Z observable / gate: the phase flip \\sigma_z."""

    is_clifford = True
    _num_wires = 1
    _matrix = jnp.array([[1.0, 0.0], [0.0, -1.0]], dtype=_cdtype())

    def __init__(self, wires: Union[int, List[int]] = 0, **kwargs) -> None:
        """Create a Pauli-Z gate.

        Args:
            wires: Single qubit index (or one-element list) to act on.
            **kwargs: Passed through to :class:`Operation`.
        """
        super().__init__(wires=wires, **kwargs)

792 

793 

class H(Operation):
    """Hadamard gate, ``(1/sqrt(2)) [[1, 1], [1, -1]]``."""

    is_clifford = True
    _num_wires = 1
    _matrix = (
        jnp.array([[1, 1], [1, -1]], dtype=_cdtype())
        / jnp.sqrt(2)
    )

    def __init__(self, wires: Union[int, List[int]] = 0, **kwargs) -> None:
        """Create a Hadamard gate.

        Args:
            wires: Single qubit index (or one-element list) to act on.
            **kwargs: Passed through to :class:`Operation`.
        """
        super().__init__(wires=wires, **kwargs)

808 

809 

class S(Operation):
    """S (phase) gate — a Clifford gate equal to :math:`\\sqrt{Z}`.

    .. math::
        S = \\begin{pmatrix}1 & 0\\\\ 0 & i\\end{pmatrix}
    """

    _matrix = jnp.array([[1, 0], [0, 1j]], dtype=_cdtype())
    _num_wires = 1
    is_clifford = True

    def __init__(self, wires: Union[int, List[int]] = 0, **kwargs) -> None:
        """Initialise an S gate.

        Args:
            wires: Qubit index or list of qubit indices this gate acts on.
            **kwargs: Forwarded to :class:`Operation` (e.g. ``record`` or
                ``name``), matching the other fixed single-qubit gates.
        """
        super().__init__(wires=wires, **kwargs)

828 

829 

class SWAP(Operation):
    """SWAP gate: exchanges the states of its two wires."""

    is_clifford = True
    _num_wires = 2
    # Basis order |00>, |01>, |10>, |11>: the middle two states trade places.
    _matrix = jnp.array(
        [
            [1, 0, 0, 0],
            [0, 0, 1, 0],
            [0, 1, 0, 0],
            [0, 0, 0, 1],
        ],
        dtype=_cdtype(),
    )

    def __init__(self, wires: Union[int, List[int]] = 0, **kwargs) -> None:
        """Create a SWAP gate.

        Args:
            wires: The two qubit indices to swap.
            **kwargs: Passed through to :class:`Operation`.
        """
        super().__init__(wires=wires, **kwargs)

846 

847 

class RandomUnitary(Operation):
    """Creates a random hermitian matrix and applies it as a gate.

    NOTE(review): despite the class name, the matrix built here is the
    Hermitian part ``(A + A^\\dagger) / 2`` of a complex Gaussian matrix,
    rescaled to Frobenius norm *scale*. It is not exponentiated and is
    therefore generally not unitary. Confirm this is what callers expect.
    """

    def __init__(
        self,
        wires: Union[int, List[int]],
        key: jax.random.PRNGKey,
        scale: float = 1.0,
        record: bool = True,
    ) -> None:
        """Initialise a random gate.

        Args:
            wires (Union[int, List[int]]): Qubit index or list of qubit indices
                this gate acts on.
            key (jax.random.PRNGKey): PRNGKey for randomization.
            scale (float): Frobenius norm of the generated matrix
                (default: 1.0).
            record (bool): Whether to record this gate on the active tape.
        """
        # Accept a single int wire as documented by the type hint; previously
        # ``len(wires)`` raised TypeError for an int.
        wire_list = list(wires) if isinstance(wires, (list, tuple)) else [wires]
        dim = 2 ** len(wire_list)
        key_re, key_im = jax.random.split(key)

        A = (
            jax.random.normal(key=key_re, shape=(dim, dim))
            + 1j * jax.random.normal(key=key_im, shape=(dim, dim))
        ).astype(_cdtype())
        # Hermitian part of A (named ``herm`` to avoid shadowing class ``H``).
        herm = (A + A.conj().T) / 2.0
        herm = herm * (scale / jnp.linalg.norm(herm, ord="fro"))

        super().__init__(wires, matrix=herm, record=record)

879 

880 

881class DiagonalQubitUnitary(Operation): 

882 """A diagonal unitary gate specified by its diagonal entries. 

883 

884 Implements ``U = diag(d_0, d_1, ..., d_{2^k-1})`` where each ``d_i`` lies 

885 on the unit circle. This is the natural gate for data-encoding 

886 Hamiltonians of the form ``S(x) = exp(-i H x)`` where *H* is diagonal in 

887 the computational basis (see Peters et al., arXiv:2209.05523). 

888 

889 The Golomb encoding strategy uses this gate with diagonal entries 

890 ``exp(-i * golomb_marks * x)`` to achieve a maximally non-degenerate 

891 Fourier spectrum. 

892 

893 Args: 

894 diag: 1-D array of ``2**k`` complex values on the unit circle. 

895 wires: Qubit indices this gate acts on (s.t. ``2**len(wires) == len(diag)``). 

896 **kwargs: Forwarded to :class:`Operation`. 

897 """ 

898 

899 # Do NOT list "diag" in _param_names — the array is not a scalar 

900 # parameter and would break drawing helpers that call float(p). 

901 _param_names = () 

902 

903 def __init__( 

904 self, 

905 diag: jnp.ndarray, 

906 wires: Union[int, List[int]] = 0, 

907 **kwargs, 

908 ) -> None: 

909 self.diag = diag 

910 wires_list = list(wires) if isinstance(wires, (list, tuple)) else [wires] 

911 expected_dim = 2 ** len(wires_list) 

912 if diag.shape != (expected_dim,): 

913 raise ValueError( 

914 f"DiagonalQubitUnitary expects {expected_dim} diagonal entries " 

915 f"for {len(wires_list)} wire(s), got shape {diag.shape}" 

916 ) 

917 mat = jnp.diag(diag) 

918 # Use a descriptive name for drawing 

919 kwargs.setdefault("name", "DiagU") 

920 super().__init__(wires=wires, matrix=mat, **kwargs) 

921 

922 def apply_to_state(self, state: jnp.ndarray, n_qubits: int) -> jnp.ndarray: 

923 """Apply diagonal gate via element-wise multiplication. 

924 

925 For a diagonal unitary, the full ``2^n``-dimensional diagonal is 

926 constructed by appropriate Kronecker-product embedding and the gate 

927 is applied as an element-wise product, which is significantly cheaper 

928 than generic matrix contraction for large qubit counts. 

929 

930 Args: 

931 state: Statevector of shape ``(2**n_qubits,)``. 

932 n_qubits: Total number of qubits in the circuit. 

933 

934 Returns: 

935 Updated statevector of shape ``(2**n_qubits,)``. 

936 """ 

937 k = len(self.wires) 

938 if k == n_qubits and self.wires == list(range(n_qubits)): 

939 # Gate acts on all qubits in order — direct element-wise multiply 

940 return state * self.diag 

941 # Fall back to general tensor contraction for arbitrary wire subsets 

942 return super().apply_to_state(state, n_qubits) 

943 

944 def apply_to_density(self, rho: jnp.ndarray, n_qubits: int) -> jnp.ndarray: 

945 """Apply diagonal gate to density matrix: rho -> U rho U†. 

946 

947 For diagonal U the transformation is 

948 ``rho_ij -> d_i * conj(d_j) * rho_ij``. 

949 

950 Args: 

951 rho: Density matrix of shape ``(2**n_qubits, 2**n_qubits)``. 

952 n_qubits: Total number of qubits in the circuit. 

953 

954 Returns: 

955 Updated density matrix of shape ``(2**n_qubits, 2**n_qubits)``. 

956 """ 

957 k = len(self.wires) 

958 if k == n_qubits and self.wires == list(range(n_qubits)): 

959 d = self.diag 

960 return d[:, None] * jnp.conj(d)[None, :] * rho 

961 return super().apply_to_density(rho, n_qubits) 

962 

963 

964class Barrier(Operation): 

965 """Barrier operation — a no-op used for visual circuit separation. 

966 

967 The barrier does not change the quantum state. It is recorded on the 

968 tape so that drawing backends can insert a visual separator. 

969 """ 

970 

971 _matrix = None # not a real gate 

972 

973 def __init__(self, wires: Union[int, List[int]] = 0) -> None: 

974 """Initialise a Barrier. 

975 

976 Args: 

977 wires: Qubit index or list of qubit indices this barrier spans. 

978 """ 

979 super().__init__(wires=wires) 

980 

981 def apply_to_state(self, state: jnp.ndarray, n_qubits: int) -> jnp.ndarray: 

982 """No-op: return the state unchanged.""" 

983 return state 

984 

985 def apply_to_state_tensor(self, psi: jnp.ndarray, n_qubits: int) -> jnp.ndarray: 

986 """No-op: return the state tensor unchanged.""" 

987 return psi 

988 

989 def apply_to_density(self, rho: jnp.ndarray, n_qubits: int) -> jnp.ndarray: 

990 """No-op: return the density matrix unchanged.""" 

991 return rho 

992 

993 

994_PAULI_LABELS = ["I", "X", "Y", "Z"] 

995_PAULI_CLASSES = [Id, PauliX, PauliY, PauliZ] 

996_PAULI_MATRICES = { 

997 label: cls._matrix for label, cls in zip(_PAULI_LABELS, _PAULI_CLASSES) 

998} 

999_PAULI_MATS = [_PAULI_MATRICES[label] for label in _PAULI_LABELS] 

1000 

1001 

1002def _make_rotation_gate(pauli_class: type, name: str) -> type: 

1003 """Factory for single-qubit rotation gates RX, RY, RZ. 

1004 

1005 Each gate has the form ``R_P(\\theta) = cos(\\theta/2) I - i sin(\\theta/2) P``. 

1006 

1007 Args: 

1008 pauli_class: One of PauliX, PauliY, PauliZ. 

1009 name: Class name for the generated gate (e.g. ``"RX"``). 

1010 

1011 Returns: 

1012 A new :class:`Operation` subclass. 

1013 """ 

1014 pauli_mat = pauli_class._matrix 

1015 

1016 class _RotationGate(Operation): 

1017 # Fancy way of setting docstring to make it generic 

1018 __doc__ = ( 

1019 f"Rotation around the {name[1]} axis: {name}(\\theta) =\n" 

1020 f"exp(-i \\theta/2 {name[1]}).\n" 

1021 ) 

1022 _num_wires = 1 

1023 _param_names = ("theta",) 

1024 

1025 def __init__( 

1026 self, theta: float, wires: Union[int, List[int]] = 0, **kwargs 

1027 ) -> None: 

1028 self.theta = theta 

1029 c = jnp.cos(theta / 2) 

1030 s = jnp.sin(theta / 2) 

1031 mat = c * Id._matrix - 1j * s * pauli_mat 

1032 super().__init__(wires=wires, matrix=mat, **kwargs) 

1033 

1034 def generator(self) -> Operation: 

1035 """Return the generator as the corresponding Pauli operation.""" 

1036 return pauli_class(wires=self.wires[0], record=False) 

1037 

1038 _RotationGate.__name__ = name 

1039 _RotationGate.__qualname__ = name 

1040 return _RotationGate 

1041 

1042 

1043RX = _make_rotation_gate(PauliX, "RX") 

1044RY = _make_rotation_gate(PauliY, "RY") 

1045RZ = _make_rotation_gate(PauliZ, "RZ") 

1046 

1047 

1048# Projectors used by controlled-gate factories 

1049_P0 = jnp.array([[1, 0], [0, 0]], dtype=_cdtype()) 

1050_P1 = jnp.array([[0, 0], [0, 1]], dtype=_cdtype()) 

1051 

1052 

1053def _make_controlled_gate(target_class: type, name: str) -> type: 

1054 """Factory for controlled Pauli gates CX, CY, CZ. 

1055 

1056 Each gate has the form 

1057 ``CP = |0><0| \\otimes I + |1\\langle\\rangle 1| \\otimes P``. 

1058 

1059 Args: 

1060 target_class: The single-qubit gate class (PauliX, PauliY, PauliZ). 

1061 name: Class name for the generated gate (e.g. ``"CX"``). 

1062 

1063 Returns: 

1064 A new :class:`Operation` subclass. 

1065 """ 

1066 target_mat = target_class._matrix 

1067 

1068 class _ControlledGate(Operation): 

1069 __doc__ = ( 

1070 f"Controlled-{target_class.__name__[5:]} gate.\n\n" 

1071 f"Applies {target_class.__name__} on the target qubit conditioned " 

1072 f"on the control qubit being in state |1\\rangle." 

1073 ) 

1074 _matrix = jnp.kron(_P0, Id._matrix) + jnp.kron(_P1, target_mat) 

1075 _num_wires = 2 

1076 is_controlled = True 

1077 is_clifford = True # CX, CY, CZ are all Clifford gates 

1078 

1079 def __init__(self, wires: List[int] = [0, 1], **kwargs) -> None: 

1080 super().__init__(wires=wires, **kwargs) 

1081 

1082 def decompose(self) -> List["Operation"]: 

1083 # CZ = (H on target) CX (H on target). CX/CY are primitive. 

1084 if name != "CZ": 

1085 return super().decompose() 

1086 c, t = self.wires 

1087 return [ 

1088 H(wires=t, record=False), 

1089 CX(wires=[c, t], record=False), 

1090 H(wires=t, record=False), 

1091 ] 

1092 

1093 _ControlledGate.__name__ = name 

1094 _ControlledGate.__qualname__ = name 

1095 return _ControlledGate 

1096 

1097 

1098CX = _make_controlled_gate(PauliX, "CX") 

1099CY = _make_controlled_gate(PauliY, "CY") 

1100CZ = _make_controlled_gate(PauliZ, "CZ") 

1101 

1102 

1103class CCX(Operation): 

1104 """Toffoli (CCX) gate. 

1105 

1106 The 3-qubit Toffoli gate exercises the arbitrary-k-qubit path in 

1107 :meth:`~Operation.apply_to_state` and cannot be expressed as a pair of 

1108 2-qubit gates without ancilla, making it a good stress-test for the 

1109 simulator. 

1110 """ 

1111 

1112 _matrix = jnp.array( 

1113 [ 

1114 [1, 0, 0, 0, 0, 0, 0, 0], 

1115 [0, 1, 0, 0, 0, 0, 0, 0], 

1116 [0, 0, 1, 0, 0, 0, 0, 0], 

1117 [0, 0, 0, 1, 0, 0, 0, 0], 

1118 [0, 0, 0, 0, 1, 0, 0, 0], 

1119 [0, 0, 0, 0, 0, 1, 0, 0], 

1120 [0, 0, 0, 0, 0, 0, 0, 1], 

1121 [0, 0, 0, 0, 0, 0, 1, 0], 

1122 ], 

1123 dtype=_cdtype(), 

1124 ) 

1125 is_controlled = True 

1126 _num_wires = 3 

1127 

1128 def __init__(self, wires: List[int] = [0, 1, 2], **kwargs) -> None: 

1129 """Initialise a Toffoli (CCX) gate. 

1130 

1131 Args: 

1132 wires: Three-element list ``[control0, control1, target]``. 

1133 """ 

1134 super().__init__(wires=wires, **kwargs) 

1135 

1136 

1137class CSWAP(Operation): 

1138 """Controlled-SWAP (Fredkin) gate. 

1139 

1140 Swaps the two target qubits conditioned on the control qubit being |1\\rangle. 

1141 

1142 Args on construction: 

1143 wires: ``[control, target0, target1]``. 

1144 """ 

1145 

1146 _matrix = jnp.array( 

1147 [ 

1148 [1, 0, 0, 0, 0, 0, 0, 0], 

1149 [0, 1, 0, 0, 0, 0, 0, 0], 

1150 [0, 0, 1, 0, 0, 0, 0, 0], 

1151 [0, 0, 0, 1, 0, 0, 0, 0], 

1152 [0, 0, 0, 0, 1, 0, 0, 0], 

1153 [0, 0, 0, 0, 0, 0, 1, 0], 

1154 [0, 0, 0, 0, 0, 1, 0, 0], 

1155 [0, 0, 0, 0, 0, 0, 0, 1], 

1156 ], 

1157 dtype=_cdtype(), 

1158 ) 

1159 is_controlled = True 

1160 _num_wires = 3 

1161 

1162 def __init__(self, wires: List[int] = [0, 1, 2], **kwargs) -> None: 

1163 """Initialise a Controlled-SWAP (Fredkin) gate. 

1164 

1165 Args: 

1166 wires: Three-element list ``[control, target0, target1]``. 

1167 """ 

1168 super().__init__(wires=wires, **kwargs) 

1169 

1170 

1171class ControlledPhaseShift(Operation): 

1172 r"""Controlled phase shift gate (CPhase). 

1173 

1174 Applies a phase shift of ``exp(i * phi)`` to the |11⟩ component of the 

1175 two-qubit state, leaving all other computational basis states unchanged. 

1176 This is a generalization of the CZ gate: when ``phi = \\pi`` the gate 

1177 reduces to CZ. 

1178 

1179 .. math:: 

1180 \text{CPhase}(\phi) = \text{diag}(1, 1, 1, e^{i\phi}) 

1181 

1182 which is equivalent to 

1183 ``|0⟩⟨0| \\otimes I + |1⟩⟨1| \\otimes P(phi)`` where 

1184 ``P(phi) = diag(1, exp(i*phi))``. 

1185 """ 

1186 

1187 _num_wires = 2 

1188 _param_names = ("phi",) 

1189 is_controlled = True 

1190 

1191 def __init__(self, phi: float, wires: List[int] = [0, 1], **kwargs) -> None: 

1192 """Initialise a controlled phase shift gate. 

1193 

1194 Args: 

1195 phi: Phase shift angle in radians. 

1196 wires: Two-element list ``[control, target]``. 

1197 """ 

1198 self.phi = phi 

1199 phase_gate = jnp.array([[1, 0], [0, jnp.exp(1j * phi)]], dtype=_cdtype()) 

1200 mat = jnp.kron(_P0, Id._matrix) + jnp.kron(_P1, phase_gate) 

1201 super().__init__(wires=wires, matrix=mat, **kwargs) 

1202 

1203 

1204class Rot(Operation): 

1205 """General single-qubit rotation: 

1206 Rot(\\phi, \\theta, \\omega) = RZ(\\omega) RY(\\theta) RZ(\\phi). 

1207 

1208 This is the most general SU(2) rotation (up to a global phase). It 

1209 decomposes into three successive rotations and has three free parameters. 

1210 """ 

1211 

1212 _num_wires = 1 

1213 _param_names = ("phi", "theta", "omega") 

1214 

1215 def __init__( 

1216 self, 

1217 phi: float, 

1218 theta: float, 

1219 omega: float, 

1220 wires: Union[int, List[int]] = 0, 

1221 **kwargs, 

1222 ) -> None: 

1223 """Initialise a general rotation gate. 

1224 

1225 Args: 

1226 phi: First RZ rotation angle (radians). 

1227 theta: RY rotation angle (radians). 

1228 omega: Second RZ rotation angle (radians). 

1229 wires: Qubit index or list of qubit indices this gate acts on. 

1230 """ 

1231 self.phi = phi 

1232 self.theta = theta 

1233 self.omega = omega 

1234 # Rot(\\phi, \theta, \\omega) = RZ(\\omega) @ RY(\theta) @ RZ(\\phi) 

1235 rz_phi = jnp.cos(phi / 2) * Id._matrix - 1j * jnp.sin(phi / 2) * PauliZ._matrix 

1236 ry_theta = ( 

1237 jnp.cos(theta / 2) * Id._matrix - 1j * jnp.sin(theta / 2) * PauliY._matrix 

1238 ) 

1239 rz_omega = ( 

1240 jnp.cos(omega / 2) * Id._matrix - 1j * jnp.sin(omega / 2) * PauliZ._matrix 

1241 ) 

1242 mat = rz_omega @ ry_theta @ rz_phi 

1243 super().__init__(wires=wires, matrix=mat, **kwargs) 

1244 

1245 def decompose(self) -> List["Operation"]: 

1246 """Decompose into ``RZ(phi) RY(theta) RZ(omega)`` (same wire).""" 

1247 w = self.wires[0] 

1248 return [ 

1249 RZ(self.phi, wires=w, record=False), 

1250 RY(self.theta, wires=w, record=False), 

1251 RZ(self.omega, wires=w, record=False), 

1252 ] 

1253 

1254 

1255class PauliRot(Operation): 

1256 """Multi-qubit Pauli rotation: exp(-i \\theta/2 P) for a Pauli word P. 

1257 

1258 The Pauli word is given as a string of ``'I'``, ``'X'``, ``'Y'``, ``'Z'`` 

1259 characters (one per qubit). The rotation matrix is computed as 

1260 ``cos(\\theta/2) I - i sin(\\theta/2) P`` where *P* is the tensor product of the 

1261 corresponding single-qubit Pauli matrices. 

1262 

1263 Example:: 

1264 

1265 PauliRot(0.5, "XY", wires=[0, 1]) 

1266 """ 

1267 

1268 _param_names = ("theta",) 

1269 

1270 # Map from character to 2x2 matrix (canonical single source of truth) 

1271 _PAULI_MAP = _PAULI_MATRICES 

1272 

1273 def __init__( 

1274 self, theta: float, pauli_word: str, wires: Union[int, List[int]] = 0, **kwargs 

1275 ) -> None: 

1276 """Initialise a PauliRot gate. 

1277 

1278 Args: 

1279 theta: Rotation angle in radians. 

1280 pauli_word: A string of ``'I'``, ``'X'``, ``'Y'``, ``'Z'`` 

1281 characters specifying the Pauli tensor product. 

1282 wires: Qubit index or list of qubit indices this gate acts on. 

1283 """ 

1284 from functools import reduce as _reduce 

1285 

1286 self.theta = theta 

1287 self.pauli_word = pauli_word 

1288 

1289 pauli_matrices = [self._PAULI_MAP[c] for c in pauli_word] 

1290 P = _reduce(jnp.kron, pauli_matrices) 

1291 dim = P.shape[0] 

1292 mat = ( 

1293 jnp.cos(theta / 2) * jnp.eye(dim, dtype=_cdtype()) 

1294 - 1j * jnp.sin(theta / 2) * P 

1295 ) 

1296 super().__init__(wires=wires, matrix=mat, **kwargs) 

1297 

1298 def generator(self) -> Operation: 

1299 """Return the generator Pauli tensor product as an :class:`Operation`. 

1300 

1301 The generator of ``PauliRot(\\theta, word, wires)`` is the tensor product 

1302 of single-qubit Pauli matrices specified by *word*. The returned 

1303 :class:`Hermitian` wraps that matrix and the gate's wires. 

1304 

1305 Returns: 

1306 :class:`Hermitian` operation representing the Pauli tensor product. 

1307 """ 

1308 from functools import reduce as _reduce 

1309 

1310 pauli_matrices = [self._PAULI_MAP[c] for c in self.pauli_word] 

1311 P = _reduce(jnp.kron, pauli_matrices) 

1312 return Hermitian(matrix=P, wires=self.wires, record=False) 

1313 

1314 

1315def _make_pauli_rotation_subclass(name: str, word: str) -> type: 

1316 """Build a thin :class:`PauliRot` subclass with the Pauli word fixed. 

1317 

1318 Used to expose multi-qubit Pauli rotations (``RXX``, ``RYY``, ``RZZ``, 

1319 ``RZX``, ...) as standalone classes while sharing PauliRot's matrix 

1320 construction and generator logic. 

1321 """ 

1322 

1323 sep = " \\otimes " 

1324 doc = ( 

1325 f"{name}(\\theta) = exp(-i \\theta/2\\, {sep.join(word)}).\n\n" 

1326 f"Thin :class:`PauliRot` subclass with ``pauli_word={word!r}``." 

1327 ) 

1328 

1329 class _PauliRotSubclass(PauliRot): 

1330 __doc__ = doc 

1331 _num_wires = len(word) 

1332 

1333 def __init__( 

1334 self, 

1335 theta: float, 

1336 wires: Union[int, List[int]] = None, 

1337 **kwargs, 

1338 ) -> None: 

1339 if wires is None: 

1340 wires = list(range(len(word))) 

1341 super().__init__(theta, word, wires=wires, **kwargs) 

1342 

1343 _PauliRotSubclass.__name__ = name 

1344 _PauliRotSubclass.__qualname__ = name 

1345 return _PauliRotSubclass 

1346 

1347 

1348RXX = _make_pauli_rotation_subclass("RXX", "XX") 

1349RYY = _make_pauli_rotation_subclass("RYY", "YY") 

1350RZZ = _make_pauli_rotation_subclass("RZZ", "ZZ") 

1351RZX = _make_pauli_rotation_subclass("RZX", "ZX") 

1352 

1353 

1354# --- Controlled multi-qubit Pauli rotation --------------------------------- 

1355 

1356 

1357class ControlledPauliRot(Operation): 

1358 r"""Multi-controlled multi-qubit Pauli rotation. 

1359 

1360 Applies ``PauliRot(theta, pauli_word)`` on the *target* wires 

1361 conditioned on all *control* wires being in :math:`|1\rangle`. 

1362 

1363 For a single control wire and a single-character Pauli word this 

1364 reduces to the textbook controlled rotations ``CRX``, ``CRY``, 

1365 ``CRZ`` — these are exposed below as thin subclasses. 

1366 

1367 The wire layout is ``[control_0, ..., control_{n_controls-1}, 

1368 target_0, ..., target_{m-1}]`` where ``m = len(pauli_word)``. 

1369 """ 

1370 

1371 _param_names = ("theta",) 

1372 is_controlled = True 

1373 

1374 def __init__( 

1375 self, 

1376 theta: float, 

1377 pauli_word: str, 

1378 wires: List[int], 

1379 n_controls: int = 1, 

1380 **kwargs, 

1381 ) -> None: 

1382 from functools import reduce as _reduce 

1383 

1384 self.theta = theta 

1385 self.pauli_word = pauli_word 

1386 self.n_controls = n_controls 

1387 

1388 wires_list = [wires] if isinstance(wires, int) else list(wires) 

1389 n_targets = len(pauli_word) 

1390 if len(wires_list) != n_controls + n_targets: 

1391 raise ValueError( 

1392 f"ControlledPauliRot expects {n_controls + n_targets} wires " 

1393 f"({n_controls} control + {n_targets} target), got " 

1394 f"{len(wires_list)}." 

1395 ) 

1396 

1397 pauli_matrices = [PauliRot._PAULI_MAP[c] for c in pauli_word] 

1398 P = _reduce(jnp.kron, pauli_matrices) 

1399 d_t = P.shape[0] 

1400 R = ( 

1401 jnp.cos(theta / 2) * jnp.eye(d_t, dtype=_cdtype()) 

1402 - 1j * jnp.sin(theta / 2) * P 

1403 ) 

1404 

1405 d_c = 2**n_controls 

1406 dim = d_c * d_t 

1407 # All control patterns except |1...1> act trivially; the active 

1408 # block sits in the last d_t x d_t slot. 

1409 mat = jnp.eye(dim, dtype=_cdtype()) 

1410 start = (d_c - 1) * d_t 

1411 mat = mat.at[start : start + d_t, start : start + d_t].set(R) 

1412 

1413 super().__init__(wires=wires_list, matrix=mat, **kwargs) 

1414 

1415 def generator(self) -> Operation: 

1416 """Return the (Hermitian) generator on the full wire set.""" 

1417 from functools import reduce as _reduce 

1418 

1419 pauli_matrices = [PauliRot._PAULI_MAP[c] for c in self.pauli_word] 

1420 P = _reduce(jnp.kron, pauli_matrices) 

1421 d_t = P.shape[0] 

1422 d_c = 2**self.n_controls 

1423 dim = d_c * d_t 

1424 gen = jnp.zeros((dim, dim), dtype=_cdtype()) 

1425 start = (d_c - 1) * d_t 

1426 gen = gen.at[start : start + d_t, start : start + d_t].set(P) 

1427 return Hermitian(matrix=gen, wires=self.wires, record=False) 

1428 

1429 

1430def _make_controlled_rotation_subclass(name: str, axis: str) -> type: 

1431 """Build a single-control controlled single-qubit rotation subclass. 

1432 

1433 Reproduces the historical ``CRX``, ``CRY``, ``CRZ`` API as thin 

1434 :class:`ControlledPauliRot` subclasses. 

1435 """ 

1436 

1437 class _CRotation(ControlledPauliRot): 

1438 __doc__ = ( 

1439 f"Controlled rotation around the {axis} axis.\n\n" 

1440 f"Applies R{axis}(\\theta) on the target qubit conditioned on the " 

1441 f"control qubit being in state |1\\rangle.\n\n" 

1442 f".. math::\n" 

1443 f"{name}(\\theta) = |0\\rangle\\langle 0| \\otimes I\n" 

1444 f" + |1\\rangle\\langle 1| \\otimes R{axis}(\\theta)" 

1445 ) 

1446 _num_wires = 2 

1447 

1448 def __init__(self, theta: float, wires: List[int] = [0, 1], **kwargs) -> None: 

1449 super().__init__(theta, axis, wires=wires, n_controls=1, **kwargs) 

1450 

1451 def decompose(self) -> List["Operation"]: 

1452 """Decompose into Clifford + single-qubit Pauli rotations.""" 

1453 c, t = self.wires 

1454 theta = self.theta 

1455 if axis == "Z": 

1456 return [ 

1457 RZ(theta / 2, wires=t, record=False), 

1458 CX(wires=[c, t], record=False), 

1459 RZ(-theta / 2, wires=t, record=False), 

1460 CX(wires=[c, t], record=False), 

1461 ] 

1462 if axis == "X": 

1463 return [ 

1464 H(wires=t, record=False), 

1465 RZ(theta / 2, wires=t, record=False), 

1466 CX(wires=[c, t], record=False), 

1467 RZ(-theta / 2, wires=t, record=False), 

1468 CX(wires=[c, t], record=False), 

1469 H(wires=t, record=False), 

1470 ] 

1471 # axis == "Y" 

1472 return [ 

1473 RX(-jnp.pi / 2, wires=t, record=False), 

1474 RZ(theta / 2, wires=t, record=False), 

1475 CX(wires=[c, t], record=False), 

1476 RZ(-theta / 2, wires=t, record=False), 

1477 RX(jnp.pi / 2, wires=t, record=False), 

1478 ] 

1479 

1480 _CRotation.__name__ = name 

1481 _CRotation.__qualname__ = name 

1482 return _CRotation 

1483 

1484 

1485CRX = _make_controlled_rotation_subclass("CRX", "X") 

1486CRY = _make_controlled_rotation_subclass("CRY", "Y") 

1487CRZ = _make_controlled_rotation_subclass("CRZ", "Z") 

1488 

1489 

1490class KrausChannel(Operation): 

1491 """Base class for noise channels defined by a set of Kraus operators. 

1492 

1493 A Kraus channel \\phi(\\rho ) = \\sigma_k K_k \\rho K_k\\dagger 

1494 is the most general physical 

1495 operation on a quantum state. For a pure unitary gate there is a single 

1496 operator K_0 = U satisfying K_0\\daggerK_0 = I; for noisy channels there are 

1497 multiple operators. 

1498 

1499 Subclasses must implement :meth:`kraus_matrices` and return a list of JAX 

1500 arrays. :meth:`apply_to_state` is intentionally left unimplemented: 

1501 Kraus channels require a density-matrix representation and cannot be 

1502 applied to a pure statevector in general. 

1503 """ 

1504 

1505 def kraus_matrices(self) -> List[jnp.ndarray]: 

1506 """Return the list of Kraus operators for this channel. 

1507 

1508 Returns: 

1509 List of 2-D JAX arrays, each of shape ``(2**k, 2**k)`` where k 

1510 is the number of target qubits. 

1511 

1512 Raises: 

1513 NotImplementedError: Subclasses must override this method. 

1514 """ 

1515 raise NotImplementedError 

1516 

1517 @property 

1518 def matrix(self) -> jnp.ndarray: 

1519 """Raises TypeError — noise channels have no single unitary matrix. 

1520 

1521 Raises: 

1522 TypeError: Always raised; use :meth:`apply_to_density` instead. 

1523 """ 

1524 raise TypeError( 

1525 f"{self.__class__.__name__} is a noise channel and has no single " 

1526 "unitary matrix. Use apply_to_density() instead." 

1527 ) 

1528 

1529 def apply_to_state(self, state: jnp.ndarray, n_qubits: int) -> jnp.ndarray: 

1530 """Raises TypeError — noise channels require density-matrix simulation. 

1531 

1532 Args: 

1533 state: Statevector (unused). 

1534 n_qubits: Number of qubits (unused). 

1535 

1536 Raises: 

1537 TypeError: Always raised; use ``execute(type='density')`` instead. 

1538 """ 

1539 raise TypeError( 

1540 f"{self.__class__.__name__} is a noise channel and cannot be " 

1541 "applied to a pure statevector. Use execute(type='density') instead." 

1542 ) 

1543 

1544 def apply_to_state_tensor(self, psi: jnp.ndarray, n_qubits: int) -> jnp.ndarray: 

1545 """Raises TypeError — noise channels require density-matrix simulation.""" 

1546 raise TypeError( 

1547 f"{self.__class__.__name__} is a noise channel and cannot be " 

1548 "applied to a pure statevector. Use execute(type='density') instead." 

1549 ) 

1550 

1551 def apply_to_density(self, rho: jnp.ndarray, n_qubits: int) -> jnp.ndarray: 

1552 """Apply 

1553 \\phi(\\rho ) = \\sigma_k K_k \\rho K_k\\dagger using tensor-contraction. 

1554 

1555 Uses the shared :func:`_contract_and_restore` helper, summing the 

1556 result over all Kraus operators. 

1557 

1558 Args: 

1559 rho: Density matrix of shape ``(2**n_qubits, 2**n_qubits)``. 

1560 n_qubits: Total number of qubits in the circuit. 

1561 

1562 Returns: 

1563 Updated density matrix of shape ``(2**n_qubits, 2**n_qubits)``. 

1564 """ 

1565 k = len(self.wires) 

1566 dim = 2**n_qubits 

1567 bra_wires = [w + n_qubits for w in self.wires] 

1568 rho_out = jnp.zeros_like(rho) 

1569 

1570 for K in self.kraus_matrices(): 

1571 K_t = K.reshape((2,) * 2 * k) 

1572 K_conj_t = jnp.conj(K_t) 

1573 rho_t = rho.reshape((2,) * 2 * n_qubits) 

1574 rho_t = _contract_and_restore(rho_t, K_t, k, self.wires) 

1575 rho_t = _contract_and_restore(rho_t, K_conj_t, k, bra_wires) 

1576 rho_out = rho_out + rho_t.reshape(dim, dim) 

1577 

1578 return rho_out 

1579 

1580 

1581class BitFlip(KrausChannel): 

1582 r"""Single-qubit bit-flip (Pauli-X) error channel. 

1583 

1584 .. math:: 

1585 K_0 = \sqrt{1-p}\,I, \quad K_1 = \sqrt{p}\,X 

1586 

1587 where *p* \\in [0, 1] is the probability of a bit flip. 

1588 """ 

1589 

1590 _num_wires = 1 

1591 _param_names = ("p",) 

1592 

1593 def __init__(self, p: float, wires: Union[int, List[int]] = 0) -> None: 

1594 """Initialise a bit-flip channel. 

1595 

1596 Args: 

1597 p: Bit-flip probability, must be in [0, 1]. 

1598 wires: Qubit index or list of qubit indices this channel acts on. 

1599 

1600 Raises: 

1601 ValueError: If *p* is outside [0, 1]. 

1602 """ 

1603 if not 0.0 <= p <= 1.0: 

1604 raise ValueError("p must be in [0, 1].") 

1605 self.p = p 

1606 super().__init__(wires=wires) 

1607 

1608 def kraus_matrices(self) -> List[jnp.ndarray]: 

1609 """Return the two Kraus operators for the bit-flip channel. 

1610 

1611 Returns: 

1612 List ``[K0, K1]`` where K0 = \\sqrt (1-p)·I and K1 = \\sqrt p·X. 

1613 """ 

1614 p = self.p 

1615 K0 = jnp.sqrt(1 - p) * Id._matrix 

1616 K1 = jnp.sqrt(p) * PauliX._matrix 

1617 return [K0, K1] 

1618 

1619 

1620class PhaseFlip(KrausChannel): 

1621 r"""Single-qubit phase-flip (Pauli-Z) error channel. 

1622 

1623 .. math:: 

1624 K_0 = \sqrt{1-p}\,I, \quad K_1 = \sqrt{p}\,Z 

1625 

1626 where *p* \\in [0, 1] is the probability of a phase flip. 

1627 """ 

1628 

1629 _num_wires = 1 

1630 _param_names = ("p",) 

1631 

1632 def __init__(self, p: float, wires: Union[int, List[int]] = 0) -> None: 

1633 """Initialise a phase-flip channel. 

1634 

1635 Args: 

1636 p: Phase-flip probability, must be in [0, 1]. 

1637 wires: Qubit index or list of qubit indices this channel acts on. 

1638 

1639 Raises: 

1640 ValueError: If *p* is outside [0, 1]. 

1641 """ 

1642 if not 0.0 <= p <= 1.0: 

1643 raise ValueError("p must be in [0, 1].") 

1644 self.p = p 

1645 super().__init__(wires=wires) 

1646 

1647 def kraus_matrices(self) -> List[jnp.ndarray]: 

1648 """Return the two Kraus operators for the phase-flip channel. 

1649 

1650 Returns: 

1651 List ``[K0, K1]`` where K0 = \\sqrt (1-p)·I and K1 = \\sqrt p·Z. 

1652 """ 

1653 p = self.p 

1654 K0 = jnp.sqrt(1 - p) * Id._matrix 

1655 K1 = jnp.sqrt(p) * PauliZ._matrix 

1656 return [K0, K1] 

1657 

1658 

1659class DepolarizingChannel(KrausChannel): 

1660 r"""Single-qubit depolarizing channel. 

1661 

1662 .. math:: 

1663 K_0 = \sqrt{1-p}\,I,\quad K_1 = \sqrt{p/3}\,X,\quad 

1664 K_2 = \sqrt{p/3}\,Y,\quad K_3 = \sqrt{p/3}\,Z 

1665 

1666 where *p* \\in [0, 1]. At p = 3/4 the channel is fully depolarizing. 

1667 """ 

1668 

1669 _num_wires = 1 

1670 _param_names = ("p",) 

1671 

1672 def __init__(self, p: float, wires: Union[int, List[int]] = 0) -> None: 

1673 """Initialise a depolarizing channel. 

1674 

1675 Args: 

1676 p: Depolarization probability, must be in [0, 1]. 

1677 wires: Qubit index or list of qubit indices this channel acts on. 

1678 

1679 Raises: 

1680 ValueError: If *p* is outside [0, 1]. 

1681 """ 

1682 if not 0.0 <= p <= 1.0: 

1683 raise ValueError("p must be in [0, 1].") 

1684 self.p = p 

1685 super().__init__(wires=wires) 

1686 

1687 def kraus_matrices(self) -> List[jnp.ndarray]: 

1688 """Return the four Kraus operators for the depolarizing channel. 

1689 

1690 Returns: 

1691 List ``[K0, K1, K2, K3]`` corresponding to I, X, Y, Z components. 

1692 """ 

1693 p = self.p 

1694 K0 = jnp.sqrt(1 - p) * Id._matrix 

1695 K1 = jnp.sqrt(p / 3) * PauliX._matrix 

1696 K2 = jnp.sqrt(p / 3) * PauliY._matrix 

1697 K3 = jnp.sqrt(p / 3) * PauliZ._matrix 

1698 return [K0, K1, K2, K3] 

1699 

1700 

1701class AmplitudeDamping(KrausChannel): 

1702 r"""Single-qubit amplitude damping channel. 

1703 

1704 .. math:: 

1705 K_0 = \begin{pmatrix}1 & 0\\ 0 & \sqrt{1-\gamma}\end{pmatrix},\quad 

1706 K_1 = \begin{pmatrix}0 & \sqrt{\gamma}\\ 0 & 0\end{pmatrix} 

1707 

1708 where *\\gamma* \\in [0, 1] is the probability of 

1709 energy loss (|1\\rangle -> |0\\rangle). 

1710 """ 

1711 

1712 _num_wires = 1 

1713 _param_names = ("gamma",) 

1714 

1715 def __init__(self, gamma: float, wires: Union[int, List[int]] = 0) -> None: 

1716 """Initialise an amplitude damping channel. 

1717 

1718 Args: 

1719 gamma: Energy-loss probability, must be in [0, 1]. 

1720 wires: Qubit index or list of qubit indices this channel acts on. 

1721 

1722 Raises: 

1723 ValueError: If *gamma* is outside [0, 1]. 

1724 """ 

1725 if not 0.0 <= gamma <= 1.0: 

1726 raise ValueError("gamma must be in [0, 1].") 

1727 self.gamma = gamma 

1728 super().__init__(wires=wires) 

1729 

1730 def kraus_matrices(self) -> List[jnp.ndarray]: 

1731 """Return the two Kraus operators for the amplitude damping channel. 

1732 

1733 Returns: 

1734 List ``[K0, K1]`` as defined in the class docstring. 

1735 """ 

1736 g = self.gamma 

1737 K0 = jnp.array([[1.0, 0.0], [0.0, jnp.sqrt(1 - g)]], dtype=_cdtype()) 

1738 K1 = jnp.array([[0.0, jnp.sqrt(g)], [0.0, 0.0]], dtype=_cdtype()) 

1739 return [K0, K1] 

1740 

1741 

1742class PhaseDamping(KrausChannel): 

1743 r"""Single-qubit phase damping (dephasing) channel. 

1744 

1745 .. math:: 

1746 K_0 = \begin{pmatrix}1 & 0\\ 0 & \sqrt{1-\gamma}\end{pmatrix},\quad 

1747 K_1 = \begin{pmatrix}0 & 0\\ 0 & \sqrt{\gamma}\end{pmatrix} 

1748 

1749 where *\\gamma* \\in [0, 1] is the phase damping probability. 

1750 """ 

1751 

1752 _num_wires = 1 

1753 _param_names = ("gamma",) 

1754 

1755 def __init__(self, gamma: float, wires: Union[int, List[int]] = 0) -> None: 

1756 """Initialise a phase damping channel. 

1757 

1758 Args: 

1759 gamma: Phase-damping probability, must be in [0, 1]. 

1760 wires: Qubit index or list of qubit indices this channel acts on. 

1761 

1762 Raises: 

1763 ValueError: If *gamma* is outside [0, 1]. 

1764 """ 

1765 if not 0.0 <= gamma <= 1.0: 

1766 raise ValueError("gamma must be in [0, 1].") 

1767 self.gamma = gamma 

1768 super().__init__(wires=wires) 

1769 

1770 def kraus_matrices(self) -> List[jnp.ndarray]: 

1771 """Return the two Kraus operators for the phase damping channel. 

1772 

1773 Returns: 

1774 List ``[K0, K1]`` as defined in the class docstring. 

1775 """ 

1776 g = self.gamma 

1777 K0 = jnp.array([[1.0, 0.0], [0.0, jnp.sqrt(1 - g)]], dtype=_cdtype()) 

1778 K1 = jnp.array([[0.0, 0.0], [0.0, jnp.sqrt(g)]], dtype=_cdtype()) 

1779 return [K0, K1] 

1780 

1781 

class ThermalRelaxationError(KrausChannel):
    r"""Single-qubit thermal relaxation error channel.

    Models simultaneous T_1 energy relaxation and T_2 dephasing. Two regimes
    are handled:

    T_2 <= T_1 (Markovian dephasing + reset):
        Six Kraus operators built from p_z (phase-flip probability), p_r0
        (reset-to-|0\rangle probability) and p_r1 (reset-to-|1\rangle
        probability).

    T_2 > T_1 (non-Markovian; Choi matrix decomposition):
        The Choi matrix is assembled from the relaxation/dephasing rates, then
        diagonalised; Kraus operators are
        K_i = \sqrt{\lambda_i} \cdot \mathrm{mat}(v_i).

    Attributes:
        pe: Excited-state population (thermal population of |1\rangle).
        t1: T_1 longitudinal relaxation time.
        t2: T_2 transverse dephasing time.
        tg: Gate duration.
    """

    _num_wires = 1
    _param_names = ("pe", "t1", "t2", "tg")

    def __init__(
        self,
        pe: float,
        t1: float,
        t2: float,
        tg: float,
        wires: Union[int, List[int]] = 0,
    ) -> None:
        """Initialise a thermal relaxation error channel.

        Args:
            pe: Excited-state population (thermal population of |1\\rangle), in [0, 1].
            t1: T_1 longitudinal relaxation time, must be > 0.
            t2: T_2 transverse dephasing time, must be > 0 and <= 2·T_1.
            tg: Gate duration, must be >= 0.
            wires: Qubit index or list of qubit indices this channel acts on.

        Raises:
            ValueError: If any parameter violates the stated constraints.
        """
        if not 0.0 <= pe <= 1.0:
            raise ValueError("pe must be in [0, 1].")
        if t1 <= 0:
            raise ValueError("t1 must be > 0.")
        if t2 <= 0:
            raise ValueError("t2 must be > 0.")
        if t2 > 2 * t1:
            raise ValueError("t2 must be <= 2·t1.")
        if tg < 0:
            raise ValueError("tg must be >= 0.")
        self.pe = pe
        self.t1 = t1
        self.t2 = t2
        self.tg = tg
        super().__init__(wires=wires)

    def kraus_matrices(self) -> List[jnp.ndarray]:
        """Return the Kraus operators for the thermal relaxation channel.

        The number of operators depends on the regime:

        * T_2 <= T_1: six operators (identity, phase-flip, two reset-to-|0\\rangle,
          two reset-to-|1\\rangle).
        * T_2 > T_1: four operators derived from the Choi matrix eigendecomposition.

        Returns:
            List of 2x2 JAX arrays representing the Kraus operators.
        """
        pe, t1, t2, tg = self.pe, self.t1, self.t2, self.tg
        dtype = _cdtype()

        eT1 = jnp.exp(-tg / t1)
        p_reset = 1.0 - eT1
        eT2 = jnp.exp(-tg / t2)

        if t2 <= t1:
            # --- Case T_2 <= T_1: six Kraus operators ---
            # Using 1 - p_reset = eT1, the textbook expressions
            #   pz  = (1 - p_reset) * (1 - eT2 / eT1) / 2
            #   pid = 1 - pz - pr0 - pr1
            # simplify to the closed forms below. They avoid the 0/0 = NaN of
            # eT2 / eT1 when tg >> t1 underflows both exponentials, and keep
            # pid >= 0 so round-off cannot feed sqrt a negative number.
            # Since t2 <= t1 implies eT2 <= eT1, pz is also >= 0.
            pz = (eT1 - eT2) / 2.0
            pid = (eT1 + eT2) / 2.0
            pr0 = (1.0 - pe) * p_reset
            pr1 = pe * p_reset

            K0 = jnp.sqrt(pid) * jnp.eye(2, dtype=dtype)
            K1 = jnp.sqrt(pz) * jnp.array([[1, 0], [0, -1]], dtype=dtype)
            K2 = jnp.sqrt(pr0) * jnp.array([[1, 0], [0, 0]], dtype=dtype)
            K3 = jnp.sqrt(pr0) * jnp.array([[0, 1], [0, 0]], dtype=dtype)
            K4 = jnp.sqrt(pr1) * jnp.array([[0, 0], [1, 0]], dtype=dtype)
            K5 = jnp.sqrt(pr1) * jnp.array([[0, 0], [0, 1]], dtype=dtype)
            return [K0, K1, K2, K3, K4, K5]

        # --- Case T_2 > T_1: Choi matrix decomposition ---
        # Choi matrix in the column-stacking convention vec(K)[i + 2j] = K[i, j]
        # (hence order="F" in the reshape below). In this convention the six
        # T_2 <= T_1 Kraus operators above reproduce exactly this matrix, so
        # both branches describe the same channel family.
        choi = jnp.array(
            [
                [1 - pe * p_reset, 0, 0, eT2],
                [0, pe * p_reset, 0, 0],
                [0, 0, (1 - pe) * p_reset, 0],
                [eT2, 0, 0, 1 - (1 - pe) * p_reset],
            ],
            dtype=dtype,
        )
        eigenvalues, eigenvectors = jnp.linalg.eigh(choi)
        # Each eigenvector (column of `eigenvectors`) un-vectorised to 2x2 gives
        # one Kraus operator. The Choi matrix is positive semidefinite, so
        # negative eigenvalues can only be round-off; abs() keeps sqrt finite.
        kraus = []
        for i in range(4):
            weight = jnp.sqrt(jnp.abs(eigenvalues[i]))
            mat = weight * eigenvectors[:, i].reshape(2, 2, order="F")
            kraus.append(mat.astype(dtype))
        return kraus

1896 

1897 

class QubitChannel(KrausChannel):
    r"""Kraus channel defined by an explicit, user-supplied operator list.

    Drop-in replacement for PennyLane's ``qml.QubitChannel``. The operators
    are expected to satisfy the completeness relation
    :math:`\sum_k K_k^\dagger K_k = I`; this constructor itself does not
    check it.

    Example::

        kraus_ops = [jnp.sqrt(0.9) * jnp.eye(2), jnp.sqrt(0.1) * PauliX._matrix]
        QubitChannel(kraus_ops, wires=0)
    """

    def __init__(
        self, kraus_ops: List[jnp.ndarray], wires: Union[int, List[int]] = 0
    ) -> None:
        """Store the Kraus operators and register the target wires.

        Args:
            kraus_ops: Kraus matrices; each should be a square 2D array of
                dimension ``2**k x 2**k`` with k = ``len(wires)``. Every entry
                is converted to the active complex dtype.
            wires: Qubit index or list of qubit indices this channel acts on.
        """
        complex_dtype = _cdtype()
        self._kraus_ops = [jnp.asarray(op, dtype=complex_dtype) for op in kraus_ops]
        super().__init__(wires=wires)

    def kraus_matrices(self) -> List[jnp.ndarray]:
        """Return the Kraus operators captured at construction time.

        Returns:
            The stored list of Kraus operator matrices.
        """
        return self._kraus_ops

1930 

1931 

def evolve_pauli_with_clifford(
    clifford: Operation,
    pauli: Operation,
    adjoint_left: bool = True,
) -> Operation:
    r"""Conjugate a Pauli operator by a Clifford gate using dense matrices.

    Evaluates :math:`C^\dagger P C` when ``adjoint_left`` is ``True`` and
    :math:`C P C^\dagger` otherwise. Both operators are lifted onto the sorted
    union of their wires before multiplying, and the product is returned as a
    :class:`Hermitian` (constructed with ``record=False``).

    Args:
        clifford: A Clifford gate.
        pauli: A Pauli / Hermitian operator.
        adjoint_left: Selects :math:`C^\dagger P C` (``True``) or
            :math:`C P C^\dagger` (``False``).

    Returns:
        A :class:`Hermitian` wrapping the evolved matrix.
    """
    wires = sorted(set(clifford.wires) | set(pauli.wires))
    n_wires = len(wires)

    c_full = _embed_matrix(clifford.matrix, clifford.wires, wires, n_wires)
    p_full = _embed_matrix(pauli.matrix, pauli.wires, wires, n_wires)
    c_dag = jnp.conj(c_full).T

    # Pick which factor goes on the left of P.
    left, right = (c_dag, c_full) if adjoint_left else (c_full, c_dag)
    return Hermitian(matrix=left @ p_full @ right, wires=wires, record=False)

1965 

1966 

def _embed_matrix(
    mat: jnp.ndarray,
    op_wires: list,
    all_wires: list,
    n_total: int,
) -> jnp.ndarray:
    """Lift a gate matrix from *op_wires* onto the full register *all_wires*.

    The gate is padded with one 2x2 identity (``jnp.kron`` on the right) for
    every wire of *all_wires* it does not act on. The resulting qubit order is
    ``[op_wires..., idle wires...]``, which is then reordered with
    :func:`_permute_matrix` to follow *all_wires*. If *op_wires* already equals
    *all_wires*, *mat* is returned unchanged.

    Args:
        mat: The gate's matrix of shape ``(2**k, 2**k)`` where
            ``k = len(op_wires)``.
        op_wires: The wires the gate acts on.
        all_wires: The full ordered list of wires.
        n_total: ``len(all_wires)``.

    Returns:
        A ``(2**n_total, 2**n_total)`` matrix.
    """
    gate_wires = list(op_wires)
    target = list(all_wires)
    if len(gate_wires) == n_total and gate_wires == target:
        return mat

    # Wires the gate leaves untouched, in register order.
    idle = [w for w in target if w not in op_wires]

    embedded = mat
    eye2 = jnp.eye(2, dtype=_cdtype())
    for _ in idle:
        embedded = jnp.kron(embedded, eye2)

    # After padding, the qubit layout is [gate wires..., idle wires...].
    layout = gate_wires + idle
    if layout == target:
        return embedded
    perm = [layout.index(w) for w in target]
    return _permute_matrix(embedded, perm, n_total)

2010 

2011 

def _permute_matrix(mat: jnp.ndarray, perm: list, n_qubits: int) -> jnp.ndarray:
    """Reorder the qubits (tensor factors) of a square matrix.

    The matrix is viewed as a rank-``2 * n_qubits`` tensor whose first
    ``n_qubits`` axes are the row qubits and whose last ``n_qubits`` axes are
    the column qubits (qubit 0 is the most significant, matching
    ``jnp.kron`` ordering). Axes are permuted with
    :func:`jax.numpy.transpose` semantics: qubit ``i`` of the result is qubit
    ``perm[i]`` of the input. For example, with ``mat = A ⊗ B ⊗ C`` and
    ``perm = [1, 2, 0]`` the result is ``B ⊗ C ⊗ A``.

    Args:
        mat: Square matrix of dimension ``2**n_qubits``.
        perm: Permutation of ``range(n_qubits)``; any integer sequence
            (list, tuple, array) is accepted.
        n_qubits: Number of qubits.

    Returns:
        Permuted matrix of the same shape.
    """
    dim = 2**n_qubits
    # Normalise to a list of Python ints so tuples / NumPy arrays also work
    # (the row and column axis lists are concatenated below).
    row_axes = [int(p) for p in perm]
    col_axes = [p + n_qubits for p in row_axes]
    tensor = jnp.reshape(mat, (2,) * (2 * n_qubits))
    return jnp.transpose(tensor, row_axes + col_axes).reshape(dim, dim)

2034 

2035 

def _dominant_pauli_label(matrix: jnp.ndarray) -> Tuple[complex, str]:
    r"""Return the dominant Pauli term ``(coeff, label)`` of a matrix.

    Finds the Pauli tensor product :math:`P` (over ``I, X, Y, Z``) with the
    largest :math:`|c_P|`, where :math:`c_P = \mathrm{Tr}(P M) / 2^n`. Shared
    by :func:`pauli_decompose` and :meth:`PauliWord.from_matrix` so the
    brute-force search lives in one place. Ties keep the first candidate in
    ``itertools.product`` order.

    Args:
        matrix: A ``(2**n, 2**n)`` matrix.

    Returns:
        ``(coeff, label)`` with *label* a string over ``{I, X, Y, Z}``. For an
        all-zero matrix this is ``(0.0, "I" * n)``.
    """
    from itertools import product as _product
    from functools import reduce as _reduce

    dim = matrix.shape[0]
    n_qubits = int(jnp.round(jnp.log2(dim)))
    # Tr(P M) = sum_ij P_ij M_ji = sum(P * M^T). This elementwise form costs
    # O(dim^2) per candidate instead of an O(dim^3) matmul followed by a trace.
    matrix_t = jnp.transpose(matrix)

    best_label = "I" * n_qubits
    best_coeff = 0.0
    for indices in _product(range(4), repeat=n_qubits):
        P = _reduce(jnp.kron, [_PAULI_MATS[i] for i in indices])
        coeff = jnp.sum(P * matrix_t) / dim
        if jnp.abs(coeff) > jnp.abs(best_coeff):
            best_coeff = coeff
            best_label = "".join(_PAULI_LABELS[i] for i in indices)
    return best_coeff, best_label

2065 

2066 

def pauli_decompose(matrix: jnp.ndarray, wire_order: Optional[List[int]] = None):
    r"""Return the dominant Pauli term of a Hermitian matrix as an Operation.

    For an n-qubit matrix (``2**n x 2**n``), only the single Pauli tensor
    product with the largest absolute coefficient
    ``c_P = Tr(P · M) / 2**n`` is kept. That is all the Fourier-tree
    algorithm needs, because Clifford conjugation of a Pauli yields a single
    non-zero Pauli term.

    The returned operation carries a ``_pauli_label`` attribute:

    * all-identity term: an ``Id`` on the first wire, labelled ``"I" * n``;
    * one non-identity factor: the matching single-qubit Pauli class on that
      factor's wire, labelled with that one character;
    * several non-identity factors: a :class:`Hermitian` on all wires,
      labelled with the full Pauli string.

    Args:
        matrix: A ``(2**n, 2**n)`` Hermitian matrix.
        wire_order: Optional list of wire indices. If ``None``, defaults
            to ``[0, 1, ..., n-1]``.

    Returns:
        A tuple ``(coeff, op)`` where *coeff* is the complex coefficient and
        *op* is the Pauli :class:`Operation` described above.
    """
    from functools import reduce as _reduce

    n_qubits = int(jnp.round(jnp.log2(matrix.shape[0])))
    wires = list(range(n_qubits)) if wire_order is None else wire_order

    coeff, label = _dominant_pauli_label(matrix)
    # (position, character) of every non-identity factor in the label.
    active = [(q, ch) for q, ch in enumerate(label) if ch != "I"]

    if not active:
        op = Id(wires=wires[0], record=False)
        op._pauli_label = "I" * n_qubits
    elif len(active) == 1:
        q, ch = active[0]
        class_index = {lab: i for i, lab in enumerate(_PAULI_LABELS)}[ch]
        op = _PAULI_CLASSES[class_index](wires=wires[q], record=False)
        op._pauli_label = ch
    else:
        P = _reduce(jnp.kron, [_PAULI_MATRICES[ch] for ch in label])
        op = Hermitian(matrix=P, wires=wires, record=False)
        op._pauli_label = label
    return coeff, op

2119 

2120 

def pauli_string_from_operation(op: Operation) -> str:
    """Extract a Pauli word string from an operation.

    Lookup order:

    1. :class:`PauliRot`: its stored ``pauli_word``.
    2. Operations carrying a ``_pauli_label`` attribute (e.g. produced by
       :func:`pauli_decompose`): that label.
    3. ``PauliX`` -> ``"X"``, ``PauliY`` -> ``"Y"``, ``PauliZ`` -> ``"Z"``,
       ``I`` -> ``"I"``.
    4. Anything else: the dominant Pauli term of ``op.matrix``, returned as a
       full-length label with one character per wire of *op* (character ``q``
       acts on ``op.wires[q]``). For example, a two-wire operator equal to
       identity on the first wire and Z on the second gives ``"IZ"``, not just
       ``"Z"``, so callers pairing the label with ``op.wires`` place each
       factor on the right wire.

    Args:
        op: A quantum operation.

    Returns:
        A string like ``"X"``, ``"ZZ"``, etc.
    """
    if isinstance(op, PauliRot) and hasattr(op, "pauli_word"):
        return op.pauli_word
    # Check for label stored by pauli_decompose
    if hasattr(op, "_pauli_label"):
        return op._pauli_label
    name_map = {"PauliX": "X", "PauliY": "Y", "PauliZ": "Z", "I": "I"}
    if op.name in name_map:
        return name_map[op.name]
    # Fall back: decompose the matrix. Use the full label rather than the
    # label of the reduced operation built by pauli_decompose: the latter
    # drops identity positions, e.g. "IZ" -> "Z".
    _, label = _dominant_pauli_label(op.matrix)
    return label

2146 

2147 

def prod(*ops: Operation) -> Operation:
    """Compose several operations into a single operation.

    Delegates to the first operation's ``prod`` method with the remaining
    operations as arguments. The result acts on the union of all wire sets:
    disjoint wire sets combine as a Kronecker product, and overlapping ones
    are matrix-multiplied.

    Args:
        *ops: One or more :class:`Operation` instances.

    Returns:
        A new :class:`Operation` whose matrix represents the composed
        operation on the unified wire set.

    Raises:
        ValueError: If no operations are given.
    """
    if len(ops) == 0:
        raise ValueError("At least one operation must be provided to prod().")
    head, *tail = ops
    return head.prod(*tail)

2165 

2166 

# Single-qubit symplectic bits (x, z) <-> Pauli label. A Pauli word is stored
# as i^phase * prod_q X_q^{x_q} Z_q^{z_q}; since Y = i * X * Z, a lone Y is
# encoded with x = z = 1 and its factor i is carried by the phase.
_XZ_TO_LABEL = {(0, 0): "I", (1, 0): "X", (0, 1): "Z", (1, 1): "Y"}
# Inverse lookup, derived from the forward table so the two cannot drift apart.
_LABEL_TO_XZ = {label: bits for bits, label in _XZ_TO_LABEL.items()}

2172 

2173 

class PauliWord:
    r"""Symbolic n-qubit Pauli operator in the stabilizer-tableau (symplectic)
    representation.

    A Pauli word is stored as

    .. math::
        P = i^{\text{phase}} \prod_{q} X_q^{x_q} Z_q^{z_q},

    with bit arrays ``x, z \in \{0, 1\}^n`` and an integer ``phase`` taken mod 4
    (tracking the scalar ``i^{phase}``). Single-qubit Paulis map as
    ``I=(0,0)``, ``X=(1,0)``, ``Z=(0,1)``, ``Y=(1,1)`` (since ``Y = i X Z``).
    Distinct ``(x, z, phase)`` triples represent distinct operators, so
    :meth:`__eq__` is operator equality.

    This replaces the matrix-based Clifford conjugation
    (:func:`evolve_pauli_with_clifford` + :func:`pauli_decompose`) with
    symbolic updates for the supported gates, and is shared by both
    :class:`~qml_essentials.pauli.PauliCircuit` and the Fourier-tree algorithm.

    The symbolic algebra uses NumPy integer arithmetic. Only the dense
    conversions (:meth:`to_matrix`, :meth:`from_matrix`) and the matrix
    fallback for unsupported Cliffords touch JAX.
    """

    __slots__ = ("x", "z", "phase")

    def __init__(self, x: np.ndarray, z: np.ndarray, phase: int = 0) -> None:
        """Initialise a Pauli word.

        Args:
            x: Integer/boolean array of X-component bits, length ``n_qubits``.
            z: Integer/boolean array of Z-component bits, length ``n_qubits``.
            phase: Exponent of the global ``i^{phase}`` scalar (taken mod 4).
        """
        self.x = np.asarray(x, dtype=np.int8) & 1
        self.z = np.asarray(z, dtype=np.int8) & 1
        self.phase = int(phase) % 4

    # ---- constructors ---------------------------------------------------
    @classmethod
    def identity(cls, n_qubits: int) -> "PauliWord":
        """Return the identity Pauli word on *n_qubits*."""
        z = np.zeros(n_qubits, dtype=np.int8)
        return cls(z.copy(), z, 0)

    @classmethod
    def from_pauli_string(
        cls, pauli_string: str, wires: List[int], n_qubits: int
    ) -> "PauliWord":
        """Build a Pauli word from a Pauli string and its wires.

        Args:
            pauli_string: String over ``{'I', 'X', 'Y', 'Z'}``; one character
                per entry of *wires*.
            wires: Qubit indices the characters act on.
            n_qubits: Total number of qubits in the circuit.

        Returns:
            The corresponding :class:`PauliWord`.
        """
        x = np.zeros(n_qubits, dtype=np.int8)
        z = np.zeros(n_qubits, dtype=np.int8)
        n_y = 0
        for ch, w in zip(pauli_string, wires):
            xb, zb = _LABEL_TO_XZ[ch]
            x[w] = xb
            z[w] = zb
            if ch == "Y":
                n_y += 1
        # Each Y contributes a factor i (Y = i X Z), accumulated into phase.
        return cls(x, z, n_y % 4)

    @classmethod
    def from_operation(cls, op: "Operation", n_qubits: int) -> "PauliWord":
        """Build a Pauli word from a Pauli-like operation.

        Supports :class:`PauliX`/:class:`PauliY`/:class:`PauliZ`/:class:`Id`,
        :class:`PauliRot` (via its ``pauli_word``), ``RX``/``RY``/``RZ`` (via
        their generating Pauli), and any operation carrying a
        ``_pauli_label`` (e.g. produced by :func:`pauli_decompose`) or
        otherwise decomposable by :func:`pauli_string_from_operation`.

        Args:
            op: The operation to convert.
            n_qubits: Total number of qubits in the circuit.

        Returns:
            The corresponding :class:`PauliWord`.
        """
        # Cached symbolic word (e.g. attached to a Clifford-evolved observable).
        cached = getattr(op, "_pauli_word", None)
        if isinstance(cached, PauliWord) and cached.n_qubits == n_qubits:
            return cached
        if isinstance(op, PauliRot):
            return cls.from_pauli_string(op.pauli_word, op.wires, n_qubits)
        # Single-qubit Pauli rotations: generator is the corresponding Pauli.
        rot_to_label = {"RX": "X", "RY": "Y", "RZ": "Z"}
        if op.name in rot_to_label:
            return cls.from_pauli_string(rot_to_label[op.name], op.wires, n_qubits)
        name_to_label = {"PauliX": "X", "PauliY": "Y", "PauliZ": "Z", "I": "I"}
        if op.name in name_to_label:
            return cls.from_pauli_string(name_to_label[op.name], op.wires, n_qubits)
        pauli_str = pauli_string_from_operation(op)
        return cls.from_pauli_string(pauli_str, op.wires, n_qubits)

    @property
    def n_qubits(self) -> int:
        """Number of qubits this Pauli word spans."""
        return self.x.shape[0]

    @property
    def xy_mask(self) -> np.ndarray:
        """Boolean mask of qubits carrying an X or Y (i.e. ``x`` bits set)."""
        return self.x.astype(bool)

    @property
    def is_diagonal(self) -> bool:
        """Whether the word is diagonal (only I/Z, i.e. no X component)."""
        return not bool(self.x.any())

    # ---- algebra --------------------------------------------------------
    def commutes_with(self, other: "PauliWord") -> bool:
        """Return whether this Pauli word commutes with *other*.

        Two Paulis commute iff their symplectic inner product vanishes mod 2.
        The overlaps are counted with ``np.count_nonzero``, which yields exact
        Python integers. ``np.dot`` on the int8 bit arrays would accumulate in
        int8 and wrap (with overflow warnings) for large qubit counts.
        """
        sp = np.count_nonzero(self.x & other.z) + np.count_nonzero(self.z & other.x)
        return sp % 2 == 0

    def compose(self, other: "PauliWord") -> "PauliWord":
        r"""Return the operator product ``self @ other`` as a new Pauli word.

        Uses the exact symplectic product rule

        .. math::
            (X^{x_1} Z^{z_1})(X^{x_2} Z^{z_2})
            = (-1)^{z_1 \cdot x_2}\, X^{x_1 \oplus x_2} Z^{z_1 \oplus z_2},

        combined with the ``i^{phase}`` scalars (``-1 = i^2``).
        """
        new_x = self.x ^ other.x
        new_z = self.z ^ other.z
        # Exact overlap count (no int8 accumulation, see commutes_with).
        cross = int(np.count_nonzero(self.z & other.x))
        new_phase = (self.phase + other.phase + 2 * cross) % 4
        return PauliWord(new_x, new_z, new_phase)

    def conjugate_by_clifford(
        self, clifford: "Operation", adjoint_left: bool = False
    ) -> "PauliWord":
        r"""Return the Clifford conjugation of this Pauli word.

        Computes ``C P C^\dagger`` (``adjoint_left=False``) or
        ``C^\dagger P C`` (``adjoint_left=True``). The gates ``H, S, CX, CZ,
        SWAP`` and the Pauli gates ``PauliX/PauliY/PauliZ`` are handled
        symbolically. Any other Clifford falls back to exact dense-matrix
        conjugation, which costs O(4^n).

        For the symbolic gates, only the generators ``X_q``/``Z_q`` on the
        gate's support are substituted by their images and re-composed (in
        canonical X-then-Z order), so all phases are tracked exactly by
        :meth:`compose`. Factors on the other qubits are fixed by ``C`` and
        commute with everything on the support, so they are carried over
        unchanged together with the global phase.

        Args:
            clifford: The Clifford operation to conjugate by.
            adjoint_left: If ``True`` compute ``C^\dagger P C``; else
                ``C P C^\dagger``.

        Returns:
            The conjugated :class:`PauliWord`.
        """
        n = self.n_qubits
        name = clifford.name

        # Pauli gates: conjugation is just Q P Q (Q is Hermitian => Q^dagger=Q).
        if name in ("PauliX", "PauliY", "PauliZ"):
            q = PauliWord.from_operation(clifford, n)
            return q.compose(self).compose(q)

        try:
            support = self._support_generator_images(
                name, list(clifford.wires), adjoint_left, n
            )
        except NotImplementedError:
            # Any other Clifford (e.g. CY): fall back to the (exact) matrix
            # conjugation, which works for arbitrary Cliffords at O(2^n) cost.
            return self._conjugate_via_matrix(clifford, adjoint_left)

        # Start from the factors outside the gate support (they map to
        # themselves), then multiply in the images of the support generators.
        touched = [q for q, _, _ in support]
        x = self.x.copy()
        z = self.z.copy()
        x[touched] = 0
        z[touched] = 0
        result = PauliWord(x, z, self.phase)
        for q, image_x, image_z in support:
            if self.x[q]:
                result = result.compose(image_x)
            if self.z[q]:
                result = result.compose(image_z)
        return result

    def _conjugate_via_matrix(
        self, clifford: "Operation", adjoint_left: bool
    ) -> "PauliWord":
        """Matrix-based Clifford conjugation fallback (exact, any Clifford).

        Used by :meth:`conjugate_by_clifford` for Cliffords without a symbolic
        tableau rule. Reuses :meth:`to_matrix` / :meth:`from_matrix` and the
        gate's dense matrix.
        """
        n = self.n_qubits
        C = _embed_matrix(clifford.matrix, clifford.wires, list(range(n)), n)
        Cd = jnp.conj(C).T
        mat = self.to_matrix()
        result = (Cd @ mat @ C) if adjoint_left else (C @ mat @ Cd)
        return PauliWord.from_matrix(result)

    @staticmethod
    def _support_generator_images(
        name: str, wires: List[int], adjoint_left: bool, n: int
    ) -> List[Tuple[int, "PauliWord", "PauliWord"]]:
        r"""Images of ``X_q``/``Z_q`` for the qubits a Clifford gate acts on.

        Returns ``(q, image of X_q, image of Z_q)`` for every qubit ``q`` in
        the support of gate *name* on *wires*, i.e. ``C X_q C^\dagger`` /
        ``C Z_q C^\dagger`` (or ``C^\dagger X_q C`` / ``C^\dagger Z_q C`` when
        *adjoint_left*). Qubits not listed map to themselves.

        Raises:
            NotImplementedError: If *name* has no symbolic rule.
        """

        def single(label: str, q: int) -> "PauliWord":
            return PauliWord.from_pauli_string(label, [q], n)

        if name == "H":
            w = wires[0]
            return [(w, single("Z", w), single("X", w))]  # H X H = Z, H Z H = X
        if name == "S":
            w = wires[0]
            image_x = single("Y", w)  # S X S^dagger = Y ; S Z S^dagger = Z
            if adjoint_left:  # S^dagger X S = -Y ; S^dagger Z S = Z
                image_x = PauliWord(image_x.x, image_x.z, image_x.phase + 2)
            return [(w, image_x, single("Z", w))]
        if name == "CX":
            c, t = wires
            # X_c -> X_c X_t, Z_t -> Z_c Z_t; X_t, Z_c fixed. CX is Hermitian.
            return [
                (c, single("X", c).compose(single("X", t)), single("Z", c)),
                (t, single("X", t), single("Z", c).compose(single("Z", t))),
            ]
        if name == "CZ":
            c, t = wires
            # X_c -> X_c Z_t, X_t -> Z_c X_t; Z_c, Z_t fixed. CZ is Hermitian.
            return [
                (c, single("X", c).compose(single("Z", t)), single("Z", c)),
                (t, single("Z", c).compose(single("X", t)), single("Z", t)),
            ]
        if name == "SWAP":
            a, b = wires
            # Generators simply exchange supports.
            return [
                (a, single("X", b), single("Z", b)),
                (b, single("X", a), single("Z", a)),
            ]
        raise NotImplementedError(f"No symbolic Clifford rule for gate '{name}'.")

    @staticmethod
    def _clifford_generator_images(
        name: str, wires: List[int], adjoint_left: bool, n: int
    ) -> Tuple[List["PauliWord"], List["PauliWord"]]:
        """Images of single-qubit generators ``X_q``/``Z_q`` under a Clifford.

        Returns two lists (indexed by qubit) of :class:`PauliWord` giving
        ``C X_q C^\\dagger`` and ``C Z_q C^\\dagger`` (or the adjoint direction).
        Qubits outside the gate support map to themselves.

        Raises:
            NotImplementedError: If *name* has no symbolic rule.
        """
        images_x = [PauliWord.from_pauli_string("X", [q], n) for q in range(n)]
        images_z = [PauliWord.from_pauli_string("Z", [q], n) for q in range(n)]
        for q, image_x, image_z in PauliWord._support_generator_images(
            name, wires, adjoint_left, n
        ):
            images_x[q] = image_x
            images_z[q] = image_z
        return images_x, images_z

    # ---- expectation / conversions -------------------------------------
    def zero_expectation(self) -> complex:
        r"""Return ``<0|P|0>`` for the all-zero computational basis state.

        Non-zero only for diagonal words (I/Z only), in which case it equals the
        global phase ``i^{phase}``.
        """
        if not self.is_diagonal:
            return 0.0 + 0.0j
        return complex(1j**self.phase)

    def to_pauli_string(self) -> str:
        """Return the bare Pauli string (ignoring the global phase)."""
        return "".join(
            _XZ_TO_LABEL[(int(self.x[q]), int(self.z[q]))] for q in range(self.n_qubits)
        )

    def leading_phase(self) -> complex:
        r"""Return the scalar ``c`` such that ``P = c * (bare Pauli string)``.

        Because the bare string already contains ``i^{n_Y}`` from its Y factors,
        ``c = i^{phase - n_Y}``.
        """
        n_y = int(((self.x == 1) & (self.z == 1)).sum())
        return complex(1j ** ((self.phase - n_y) % 4))

    def to_pauli_string_and_phase(self) -> Tuple[str, complex]:
        """Return ``(bare Pauli string, leading scalar phase)``."""
        return self.to_pauli_string(), self.leading_phase()

    def to_matrix(self) -> jnp.ndarray:
        r"""Return the dense operator matrix ``i^{phase} \bigotimes_q X^{x_q} Z^{z_q}``.

        The per-qubit factor is the symplectic product ``X^{x} Z^{z}`` (so the
        ``(1, 1)`` factor is ``XZ = -iY``; the ``Y``-vs-``XZ`` phase is carried by
        ``i^{phase}``). Inverse of :meth:`from_matrix`.
        """
        ident = _PAULI_MATRICES["I"]
        xmat = _PAULI_MATRICES["X"]
        zmat = _PAULI_MATRICES["Z"]
        mat = jnp.array([[1.0 + 0.0j]], dtype=_cdtype())
        for q in range(self.n_qubits):
            factor = (xmat if self.x[q] else ident) @ (zmat if self.z[q] else ident)
            mat = jnp.kron(mat, factor)
        return (1j**self.phase) * mat

    @classmethod
    def from_matrix(cls, matrix: jnp.ndarray) -> "PauliWord":
        r"""Build a Pauli word from a matrix that is a single (signed) Pauli.

        Recovers the dominant Pauli string and folds its (unit) coefficient
        ``c = i^k`` into the word's phase. Intended for matrices that are
        exactly a Pauli up to a ``{\pm 1, \pm i}`` scalar (e.g. the result of
        Clifford conjugation of a Pauli); the dominant term is returned for
        general inputs.

        Args:
            matrix: A ``(2**n, 2**n)`` matrix proportional to a Pauli string.

        Returns:
            The corresponding :class:`PauliWord` on ``n`` qubits.
        """
        coeff, label = _dominant_pauli_label(matrix)
        n = len(label)
        word = cls.from_pauli_string(label, list(range(n)), n)
        # Fold the unit coefficient c = i^k into the phase.
        k = int(round(np.angle(complex(coeff)) / (np.pi / 2))) % 4
        word.phase = (word.phase + k) % 4
        return word

    def to_list_repr(self) -> np.ndarray:
        """Return the legacy int list representation (I=-1, X=0, Y=1, Z=2)."""
        code = {"I": -1, "X": 0, "Y": 1, "Z": 2}
        out = np.full(self.n_qubits, -1, dtype=int)
        for q in range(self.n_qubits):
            out[q] = code[_XZ_TO_LABEL[(int(self.x[q]), int(self.z[q]))]]
        return out

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, PauliWord):
            return NotImplemented
        return (
            self.phase == other.phase
            and np.array_equal(self.x, other.x)
            and np.array_equal(self.z, other.z)
        )

    def __repr__(self) -> str:
        phase_str = {0: "+", 1: "+i", 2: "-", 3: "-i"}[self.phase]
        return f"PauliWord({phase_str}{self.to_pauli_string()})"