Coverage for qml_essentials / yaqsi.py: 93%

511 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-05-16 10:19 +0000

1from functools import reduce 

2from typing import Any, Callable, List, Optional, Tuple, Union 

3import threading 

4 

5import diffrax 

6import jax 

7import jax.numpy as jnp 

8import jax.scipy.linalg 

9import equinox as eqx 

10import numpy as np # needed to prevent jitting some operations 

11 

12from qml_essentials.operations import ( 

13 Barrier, 

14 Hermitian, 

15 ParametrizedHamiltonian, 

16 Operation, 

17 KrausChannel, 

18 PauliZ, 

19 _einsum_subscript, 

20 _cdtype, 

21) 

22from qml_essentials.tape import recording, pulse_recording 

23from qml_essentials.drawing import draw_text, draw_mpl, draw_tikz 

24 

25import logging 

26 

27log = logging.getLogger(__name__) 

28 

29 

30# def _args_contain_tracer(args) -> bool: 

31# """Return True if any leaf in *args* is a JAX tracer. 

32 

33# Used by :meth:`Script._execute_batched` to detect that the call is 

34# happening under an outer JAX transformation (``jit``/``vmap``/``grad``/ 

35# ``jacrev`` etc.). When that is the case the per-Script 

36# ``_jit_cache`` must be bypassed: a previously cached 

37# ``jax.jit(jax.vmap(...))`` was built under a different outer trace 

38# and re-using it would leak that trace's tracers (raising 

39# ``UnexpectedTracerError`` on the second transform). XLA compilation 

40# artefacts are still cached at the JAX level by jaxpr signature, so 

41# bypassing only the local Python wrapper has negligible runtime cost. 

42# """ 

43# from jax.core import Tracer 

44# for leaf in jax.tree_util.tree_leaves(args): 

45# if isinstance(leaf, Tracer): 

46# return True 

47# return False 

48 

49 

50def _make_hashable(obj): 

51 """Recursively convert an object into a hashable form for cache keys. 

52 

53 - ``dict`` → sorted tuple of ``(key, _make_hashable(value))`` pairs 

54 - ``list`` → tuple of ``_make_hashable(element)`` 

55 - ``set`` → frozenset of ``_make_hashable(element)`` 

56 - everything else is returned as-is (assumed hashable) 

57 """ 

58 if isinstance(obj, dict): 

59 return tuple(sorted((k, _make_hashable(v)) for k, v in obj.items())) 

60 if isinstance(obj, (list, tuple)): 

61 return tuple(_make_hashable(x) for x in obj) 

62 if isinstance(obj, set): 

63 return frozenset(_make_hashable(x) for x in obj) 

64 return obj 

65 

66 

class Yaqsi:
    """Stateless helpers for Hamiltonian evolution and measurement post-processing.

    Everything is exposed through class/static methods.  The class itself
    only carries shared configuration (solver defaults) and a class-level
    cache of compiled evolution solvers.
    """

    # TODO: generally, I would like to merge this into operations or vice-versa
    # and only keep Script here. It's not clear how to do this though.

    # Class-level cache for JIT-compiled evolve solvers.  Keyed on
    # (coeff-fn code objects, dim, atol, rtol, max_steps, throw, solver,
    # magnus_steps) so that all evolve() calls with the same pulse shape
    # function and matrix size share one compiled XLA program.  This turns
    # O(n_gates) JIT compilations into O(n_distinct_pulse_shapes) during
    # pulse-mode circuit building.
    _evolve_solver_cache: dict = {}
    _evolve_solver_cache_lock = threading.Lock()

    # Whether to call ``jax.clear_caches()`` between memory-aware
    # chunks in :meth:`Script._execute_chunked`.  Default ``False``:
    # clearing caches between chunks forces XLA to recompile the same
    # batched program for every chunk, which is a major performance hit
    # when many chunks are needed.  Set ``True`` only if you observe
    # OOM growth across chunks.
    _clear_caches_between_chunks: bool = False

    # Default solver knobs for parametrized (time-dependent) evolution.
    # These can be overridden per-call via the **odeint_kwargs of
    # ``evolve()`` or globally via :meth:`set_solver_defaults`.
    #
    # ``max_steps`` is the hard cap on accepted ODE steps.  Pulse-level
    # workloads at on-resonance carriers (ω_c ≈ ω_q) require many more
    # steps than the diffrax default during JIT — 2**13 = 8192 is
    # large enough for realistic single- and two-qubit pulses while
    # remaining cheap to compile.
    #
    # ``throw`` controls whether diffrax raises on solver failure
    # (e.g. ``MaxStepsReached``).  When set to ``False`` the gate
    # factory instead returns a NaN-filled unitary so the calling
    # optimiser sees a well-defined (but useless) result and can
    # gracefully reject the candidate.
    #
    # ``solver`` selects the time-integration backend for the
    # interaction-picture ODE ``dU/dt = -i H_I(t) U``:
    #
    # * ``"dopri8"`` (default) — adaptive Dormand-Prince 8(7) via
    #   diffrax.  Robust but expensive on highly oscillatory drives
    #   because the step controller resolves every fast cycle.
    # * ``"dopri5"`` — adaptive Dormand-Prince 5(4) via diffrax
    #   (``diffrax.Dopri5``); same adaptive machinery as ``dopri8``
    #   with a lower-order scheme.
    # * ``"magnus2"`` — commutator-free Magnus, 2nd order (midpoint
    #   rule) on a fixed ``magnus_steps`` grid via ``jax.lax.scan``.
    #   One ``expm`` per step.  Preserves unitarity to machine
    #   precision and fuses into a single XLA program.
    # * ``"magnus4"`` — commutator-free Magnus, 4th order (CFM4:2 of
    #   Blanes & Moan) on a fixed ``magnus_steps`` grid.  Two ``H``
    #   evaluations and two ``expm`` per step; typically the best
    #   accuracy/cost trade-off for smooth oscillatory pulse drives.
    #
    # ``magnus_steps`` is the number of fixed substeps for the Magnus
    # integrators (ignored for ``dopri8`` / ``dopri5``).  Choose it so
    # that ``h = T/N`` resolves the fastest oscillation in ``H(t)``
    # (~few steps per period of the highest frequency).
    _solver_defaults: dict = {
        "max_steps": 2**13,
        "throw": True,
        "solver": "dopri8",
        "magnus_steps": 256,
    }
    _valid_solvers = ("dopri8", "dopri5", "magnus2", "magnus4")

    @classmethod
    def set_solver_defaults(
        cls,
        max_steps: Optional[int] = None,
        throw: Optional[bool] = None,
        solver: Optional[str] = None,
        magnus_steps: Optional[int] = None,
    ) -> dict:
        """Update class-level solver defaults; return the previous values.

        The returned dictionary is suitable for restoring the previous
        defaults via ``set_solver_defaults(**prev)``.

        All arguments are validated and coerced before any default is
        changed, so a rejected call leaves the defaults untouched.

        Args:
            max_steps: New default for ``max_steps`` (ignored if ``None``).
            throw: New default for ``throw`` (ignored if ``None``).
            solver: New default solver name; must be one of
                ``_valid_solvers`` (ignored if ``None``).
            magnus_steps: New default number of fixed Magnus substeps
                (ignored if ``None``).

        Returns:
            Dictionary with the previous values of the updated keys.

        Raises:
            ValueError: If ``solver`` is not a recognised solver name.
        """
        if solver is not None and solver not in cls._valid_solvers:
            raise ValueError(
                f"Unknown solver {solver!r}; expected one of {cls._valid_solvers}"
            )

        # Collect the coerced values first; int()/bool() failures therefore
        # also happen before the shared defaults are modified.
        updates: dict = {}
        if max_steps is not None:
            updates["max_steps"] = int(max_steps)
        if throw is not None:
            updates["throw"] = bool(throw)
        if solver is not None:
            updates["solver"] = solver
        if magnus_steps is not None:
            updates["magnus_steps"] = int(magnus_steps)

        prev = {key: cls._solver_defaults[key] for key in updates}
        cls._solver_defaults.update(updates)
        return prev

    @classmethod
    def _store_evolve_solver(cls, cache_key: tuple, solve: Callable) -> Callable:
        """Cache a compiled evolve solver unless another thread won the race.

        Args:
            cache_key: Key under which the solver is stored.
            solve: Freshly built solver callable.

        Returns:
            The solver that ends up in the cache for ``cache_key``: the
            already-stored one if present, otherwise ``solve``.
        """
        with cls._evolve_solver_cache_lock:
            existing = cls._evolve_solver_cache.get(cache_key)
            if existing is not None:
                return existing
            cls._evolve_solver_cache[cache_key] = solve
            return solve

    @classmethod
    def clear_evolve_solver_cache(cls) -> None:
        """Drop every cached compiled evolve solver.

        Call this whenever the coefficient functions referenced by the
        cache keys are rebuilt (e.g. when :class:`PulseGates` swaps in
        a new pulse envelope, RWA flag or frame).  Without an explicit
        eviction the cache keeps the old code objects alive and would
        also retain XLA programs that no longer match any active
        coefficient function.
        """
        with cls._evolve_solver_cache_lock:
            cls._evolve_solver_cache.clear()

    @classmethod
    def _parse_evolve_solver_options(cls, odeint_kwargs: dict) -> tuple:
        """Pop and validate solver options from ``evolve(..., **odeint_kwargs)``.

        Recognised keys are removed from ``odeint_kwargs`` in place; any
        other keys are left in the dictionary.

        Args:
            odeint_kwargs: Mutable mapping of user-supplied options.

        Returns:
            ``(atol, rtol, max_steps, throw, solver_name, magnus_steps)``.
            Missing tolerances default to ``1.0e-10`` when x64 is enabled
            and ``1.4e-8`` otherwise; the remaining defaults come from
            ``_solver_defaults``.

        Raises:
            ValueError: If the requested solver name is unknown.
        """
        default_tol = 1.0e-10 if jax.config.x64_enabled else 1.4e-8
        atol = odeint_kwargs.pop("atol", default_tol)
        rtol = odeint_kwargs.pop("rtol", default_tol)
        max_steps = int(
            odeint_kwargs.pop("max_steps", cls._solver_defaults["max_steps"])
        )
        throw = bool(odeint_kwargs.pop("throw", cls._solver_defaults["throw"]))
        solver_name = str(odeint_kwargs.pop("solver", cls._solver_defaults["solver"]))
        if solver_name not in cls._valid_solvers:
            raise ValueError(
                f"Unknown solver {solver_name!r}; expected one of {cls._valid_solvers}"
            )
        magnus_steps = int(
            odeint_kwargs.pop("magnus_steps", cls._solver_defaults["magnus_steps"])
        )
        return atol, rtol, max_steps, throw, solver_name, magnus_steps

    @classmethod
    def _build_magnus_evolve_solver(
        cls,
        cache_key: tuple,
        coeff_fns: Tuple[Callable, ...],
        n_terms: int,
        dim: int,
        solver_name: str,
        magnus_steps: int,
    ) -> Callable:
        """Build and cache a fixed-step commutator-free Magnus solver.

        Args:
            cache_key: Key under which the solver is stored.
            coeff_fns: One coefficient function ``f_i(params_i, t)`` per term.
            n_terms: Number of Hamiltonian terms.
            dim: Hilbert-space dimension of the evolved unitary.
            solver_name: ``"magnus2"`` selects the midpoint rule; any other
                value selects the 4th-order two-exponential scheme.
            magnus_steps: Number of equal substeps on ``[t0, t1]``.

        Returns:
            The cached solver ``(neg_iH_split, params, t0, t1) -> U``.
        """
        import math

        _coeff_fns = coeff_fns
        _cdtype_local = jnp.complex128 if jax.config.x64_enabled else jnp.complex64
        n_steps = magnus_steps
        use_midpoint = solver_name == "magnus2"

        # Gauss-Legendre nodes (c1, c2) and weights (a1, a2) of the 4th-order
        # scheme.  They are plain Python constants, so they are computed once
        # here rather than inside the traced function.
        sqrt3 = math.sqrt(3.0)
        c1 = 0.5 - sqrt3 / 6.0
        c2 = 0.5 + sqrt3 / 6.0
        a1 = 0.25 + sqrt3 / 6.0
        a2 = 0.25 - sqrt3 / 6.0

        @eqx.filter_jit
        def _solve(neg_iH_split, params, t0, t1):
            # Reconstruct the per-term complex matrices ``-i H_i`` from their
            # split (Re, Im) representation so the coefficient sum is a single
            # complex tensordot.
            A_all = neg_iH_split[:, 0]
            B_all = neg_iH_split[:, 1]
            neg_iH = (A_all + 1j * B_all).astype(_cdtype_local)

            h = (t1 - t0) / n_steps

            def H_at(t):
                # Returns sum_i f_i(params_i, t) * (-i H_i); each coefficient
                # is coerced to a true scalar before stacking.
                c = jnp.stack(
                    [
                        jnp.asarray(_coeff_fns[i](params[i], t)).reshape(())
                        for i in range(n_terms)
                    ]
                ).astype(_cdtype_local)
                return jnp.tensordot(c, neg_iH, axes=1)

            if use_midpoint:

                def step(U, n):
                    tn = t0 + n * h
                    Omega = h * H_at(tn + 0.5 * h)
                    return jax.scipy.linalg.expm(Omega) @ U, None

            else:

                def step(U, n):
                    tn = t0 + n * h
                    H1 = H_at(tn + c1 * h)
                    H2 = H_at(tn + c2 * h)
                    Omega_a = h * (a1 * H1 + a2 * H2)
                    Omega_b = h * (a2 * H1 + a1 * H2)
                    # CFM4:2 ordering (Blanes & Moan 2006, Table II):
                    #   U_{n+1} = exp(Ω_b) · exp(Ω_a) · U_n.
                    U_next = (
                        jax.scipy.linalg.expm(Omega_b)
                        @ jax.scipy.linalg.expm(Omega_a)
                        @ U
                    )
                    return U_next, None

            U0 = jnp.eye(dim, dtype=_cdtype_local)
            U_final, _ = jax.lax.scan(step, U0, jnp.arange(n_steps))
            return U_final

        return cls._store_evolve_solver(cache_key, _solve)

    @classmethod
    def _build_diffrax_evolve_solver(
        cls,
        cache_key: tuple,
        coeff_fns: Tuple[Callable, ...],
        n_terms: int,
        dim: int,
        atol: float,
        rtol: float,
        max_steps: int,
        throw: bool,
        solver_name: str,
        _rdtype,
    ) -> Callable:
        """Build and cache an adaptive diffrax-based evolve solver.

        Args:
            cache_key: Key under which the solver is stored.
            coeff_fns: One coefficient function ``f_i(params_i, t)`` per term.
            n_terms: Number of Hamiltonian terms.
            dim: Hilbert-space dimension of the evolved unitary.
            atol: Absolute tolerance of the PID step-size controller.
            rtol: Relative tolerance of the PID step-size controller.
            max_steps: Step budget passed to ``diffrax.diffeqsolve``.
            throw: Passed to ``diffrax.diffeqsolve``; when ``False`` an
                unsuccessful solve yields a NaN-filled matrix.
            solver_name: ``"dopri8"`` selects ``diffrax.Dopri8``; any other
                value selects ``diffrax.Dopri5``.
            _rdtype: Real dtype used for the initial state.

        Returns:
            The cached solver ``(neg_iH_split, params, t0, t1) -> U``.
        """
        solver = diffrax.Dopri8() if solver_name == "dopri8" else diffrax.Dopri5()
        stepsize_controller = diffrax.PIDController(atol=atol, rtol=rtol)
        _coeff_fns = coeff_fns

        @eqx.filter_jit
        def _solve(neg_iH_split, params, t0, t1):
            """Solve dU/dt = sum_i f_i(p_i, t) * (-iH_i) * U from t0 to t1.

            ``neg_iH_split`` has shape ``(n_terms, 2, dim, dim)`` with
            ``[:, 0]`` = Re(-iH_i) and ``[:, 1]`` = Im(-iH_i).
            ``params`` is a list/tuple of length ``n_terms`` carrying
            each term's coefficient parameters.  The state ``y`` has
            shape ``(2, dim, dim)`` with ``y[0] = Re(U)`` and
            ``y[1] = Im(U)``.
            """
            A_all = neg_iH_split[:, 0]
            B_all = neg_iH_split[:, 1]

            def rhs(t, y, args):
                # Each coefficient function must return a scalar value; some
                # call sites pass a shape-(1,) param array, so coerce to a
                # true scalar before stacking.
                c = jnp.stack(
                    [
                        jnp.asarray(_coeff_fns[i](args[i], t)).reshape(())
                        for i in range(n_terms)
                    ]
                )
                u_re = y[0]
                u_im = y[1]
                A_eff = jnp.tensordot(c, A_all, axes=1)
                B_eff = jnp.tensordot(c, B_all, axes=1)
                # (A + iB)(U_re + iU_im) split into real and imaginary parts.
                du_re = A_eff @ u_re - B_eff @ u_im
                du_im = A_eff @ u_im + B_eff @ u_re
                return jnp.stack([du_re, du_im], axis=0)

            # U(t0) = I, stored as (Re, Im).
            y0 = jnp.stack(
                [
                    jnp.eye(dim, dtype=_rdtype),
                    jnp.zeros((dim, dim), dtype=_rdtype),
                ],
                axis=0,
            )

            sol = diffrax.diffeqsolve(
                diffrax.ODETerm(rhs),
                solver,
                t0=t0,
                t1=t1,
                dt0=None,
                y0=y0,
                args=params,
                stepsize_controller=stepsize_controller,
                max_steps=max_steps,
                throw=throw,
            )

            y_final = sol.ys[0]
            U = y_final[0] + 1j * y_final[1]

            if not throw:
                # Failure is reported through the values instead of an
                # exception: replace the whole matrix by NaN.
                successful = sol.result == diffrax.RESULTS.successful
                U = jnp.where(successful, U, jnp.full_like(U, jnp.nan))
            return U

        return cls._store_evolve_solver(cache_key, _solve)

    @staticmethod
    def _partial_trace_single(
        rho: jnp.ndarray,
        n_qubits: int,
        keep: List[int],
    ) -> jnp.ndarray:
        """Partial trace of a single density matrix (no batch dimension).

        Args:
            rho: Array reshapeable to ``(2,) * (2 * n_qubits)``.
            n_qubits: Total number of qubits.
            keep: Qubit indices that are *not* traced out.  Only membership
                matters; the kept qubits stay in ascending index order.

        Returns:
            Reduced matrix of shape ``(2**len(keep), 2**len(keep))``.
        """
        shape = (2,) * (2 * n_qubits)
        rho_t = rho.reshape(shape)

        trace_out = sorted(set(range(n_qubits)) - set(keep))

        # Trace from the highest index down so that the positions of the
        # qubits still to be traced are not shifted by earlier traces.
        for q in reversed(trace_out):
            n_remaining = rho_t.ndim // 2
            rho_t = jnp.trace(rho_t, axis1=q, axis2=q + n_remaining)

        dim = 2 ** len(keep)
        return rho_t.reshape(dim, dim)

    @classmethod
    def partial_trace(
        cls,
        rho: jnp.ndarray,
        n_qubits: int,
        keep: List[int],
    ) -> jnp.ndarray:
        """Partial trace of a density matrix, keeping only the specified qubits.

        Supports both single density matrices of shape ``(2**n, 2**n)`` and
        batched density matrices of shape ``(B, 2**n, 2**n)``.

        Args:
            rho: Density matrix of shape ``(2**n, 2**n)`` or ``(B, 2**n, 2**n)``.
            n_qubits: Total number of qubits.
            keep: List of qubit indices to *keep* (0-indexed).

        Returns:
            Reduced density matrix of shape ``(2**k, 2**k)`` or ``(B, 2**k, 2**k)``
            where *k* = ``len(keep)``.
        """
        dim = 2**n_qubits
        if rho.shape == (dim, dim):
            return cls._partial_trace_single(rho, n_qubits, keep)
        # Batched: shape (B, dim, dim)
        return jax.vmap(lambda r: cls._partial_trace_single(r, n_qubits, keep))(rho)

    @staticmethod
    def _marginalize_probs_single(
        probs: jnp.ndarray,
        target_shape: Tuple[int],
        trace_out: Tuple[int],
    ) -> jnp.ndarray:
        """Marginalize a single probability vector (no batch dimension).

        Args:
            probs: Probability vector reshapeable to ``target_shape``.
            target_shape: Per-qubit tensor shape the vector is viewed as.
            trace_out: Axes to sum over, applied in the given order.  The
                caller must order them so that earlier sums do not shift
                later axes (e.g. descending).

        Returns:
            Flattened marginal distribution.
        """
        probs_t = probs.reshape(target_shape)

        for q in trace_out:
            probs_t = probs_t.sum(axis=q)

        return probs_t.ravel()

    @classmethod
    def marginalize_probs(
        cls,
        probs: jnp.ndarray,
        n_qubits: int,
        keep: Tuple[int],
    ) -> jnp.ndarray:
        """Marginalize a probability vector to keep only the specified qubits.

        Accepts both single probability vectors of shape ``(2**n,)`` and
        batched vectors of shape ``(B, 2**n)``.

        Args:
            probs: Probability vector of shape ``(2**n,)`` or ``(B, 2**n)``.
            n_qubits: Total number of qubits.
            keep: Qubit indices to *keep* (0-indexed).

        Returns:
            Marginalized probabilities of shape ``(B, 2**k)`` where *k* =
            ``len(keep)``.  The input is always reshaped to ``(-1, 2**n)``,
            so a single vector comes back with a leading batch axis of
            size 1, i.e. shape ``(1, 2**k)``.
        """
        dim = 2**n_qubits
        # Descending order keeps the remaining axis indices valid while summing.
        trace_out = tuple(q for q in range(n_qubits - 1, -1, -1) if q not in keep)
        target_shape = (2,) * n_qubits

        return jax.vmap(
            lambda p: cls._marginalize_probs_single(p, target_shape, trace_out)
        )(probs.reshape(-1, dim))

    @classmethod
    def build_parity_observable(
        cls,
        qubit_group: List[int],
    ) -> "Hermitian":
        """Build a multi-qubit parity observable.

        Args:
            qubit_group: List of qubit indices for the parity measurement.
                Must be non-empty (the Kronecker fold has no initial value).

        Returns:
            A :class:`Hermitian` operation whose matrix is the Z parity
            tensor product and whose wires match the given qubits.
        """
        Z = PauliZ._matrix
        mat = reduce(jnp.kron, [Z] * len(qubit_group))
        return Hermitian(matrix=mat, wires=qubit_group, record=False)

    @classmethod
    def evolve(
        cls,
        hamiltonian: Union["Hermitian", "ParametrizedHamiltonian"],
        name: Optional[str] = None,
        **odeint_kwargs: Any,
    ) -> Callable:
        """Return a gate-factory for Hamiltonian time evolution.

        Supports two modes:

        Static — when *hamiltonian* is a :class:`Hermitian`::

            gate = evolve(Hermitian(H_mat, wires=0))
            gate(t=0.5)   # U = exp(-i*0.5*H)

        Time-dependent — when *hamiltonian* is a
        :class:`ParametrizedHamiltonian` (created via ``coeff_fn * Hermitian``)::

            H_td = coeff_fn * Hermitian(H_mat, wires=0)
            gate = evolve(H_td)
            gate([A, sigma], T)   # U via ODE: dU/dt = -i f(p,t) H * U

        The time-dependent case solves the Schrödinger equation numerically,
        by default using ``diffrax.diffeqsolve`` with a Dopri8 adaptive
        Runge-Kutta solver; the ``solver`` option selects one of the other
        supported integrators.

        Args:
            hamiltonian: Either a :class:`Hermitian` (static evolution) or a
                :class:`ParametrizedHamiltonian` (time-dependent evolution).
            name: Name passed on to the produced :class:`Operation`.
            **odeint_kwargs: Extra keyword arguments, used only in the
                time-dependent mode.  Recognised keys:

                - ``atol``, ``rtol`` — absolute/relative tolerances for the
                  adaptive step-size controller (default ``1.4e-8``, or
                  ``1.0e-10`` when x64 is enabled).
                - ``max_steps``, ``throw``, ``solver``, ``magnus_steps`` —
                  per-call overrides of the class-level solver defaults.

        Returns:
            A callable gate factory.  Signature depends on the mode:

            - Static: ``(t, wires=0) -> Operation``
            - Time-dependent: ``(coeff_args, T) -> Operation``

        Raises:
            TypeError: If *hamiltonian* is neither ``Hermitian`` nor
                ``ParametrizedHamiltonian``.
        """
        if isinstance(hamiltonian, Hermitian):
            return cls._evolve_static(hamiltonian, name=name)
        elif isinstance(hamiltonian, ParametrizedHamiltonian):
            return cls._evolve_parametrized(hamiltonian, name=name, **odeint_kwargs)
        else:
            raise TypeError(
                f"evolve() expects a Hermitian or ParametrizedHamiltonian, "
                f"got {type(hamiltonian)}"
            )

    @staticmethod
    def _evolve_static(
        hermitian: "Hermitian", name: Optional[str] = None
    ) -> Callable:
        """Gate factory for static Hamiltonian evolution U = exp(-i t H).

        Args:
            hermitian: Operation whose ``matrix`` attribute is used as ``H``.
            name: Name passed on to the produced :class:`Operation`.

        Returns:
            ``(t, wires=0) -> Operation`` computing ``expm(-1j * t * H)``.
        """
        H_mat = hermitian.matrix

        def _apply(t: float, wires: Union[int, List[int]] = 0) -> "Operation":
            U = jax.scipy.linalg.expm(-1j * t * H_mat)
            return Operation(wires=wires, matrix=U, name=name)

        return _apply

    @classmethod
    def _evolve_parametrized(
        cls,
        ph: "ParametrizedHamiltonian",
        name: Optional[str] = None,
        **odeint_kwargs: Any,
    ) -> Callable:
        """Gate factory for time-dependent (multi-term) Hamiltonian evolution.

        Solves the matrix ODE

            dU/dt = -i [\\sum_i f_i(params_i, t) * H_i] * U,   U(0) = I

        with the selected integrator: ``diffrax.diffeqsolve`` (Dopri8 or
        Dopri5 adaptive RK) or a fixed-step commutator-free Magnus scheme.
        The Hamiltonian may contain one or more ``coeff_fn * Hermitian``
        terms (see :class:`ParametrizedHamiltonian`); the single-term case
        is the usual ``coeff_fn * Hermitian`` and is fully backward
        compatible.

        Implementation notes:

        - To avoid diffrax's experimental complex dtype path, the diffrax
          solvers reformulate the ODE in real arithmetic.  Writing
          ``-iH_i = A_i + i B_i`` and ``U = U_re + i U_im``, each term
          contributes::

              d(U_re)/dt += f_i(p_i,t) * (A_i @ U_re - B_i @ U_im)
              d(U_im)/dt += f_i(p_i,t) * (A_i @ U_im + B_i @ U_re)

        - ``-iH_i`` is precomputed once per term and stacked into a
          ``(n_terms, 2, dim, dim)`` real array, contracted via
          ``tensordot`` against the per-step coefficient vector
          ``c = [f_0(p_0,t), ..., f_{n-1}(p_{n-1},t)]``.

        - The JIT-compiled solver is cached per coefficient-function code
          tuple (and ``dim``, tolerances, solver options) so multiple
          ``evolve()`` calls with the same pulse shape — but different
          Hamiltonian matrices or parameters — reuse the same compiled
          XLA program.

        - NOTE: the cached solver keeps using the coefficient functions it
          was first built with.  Functions that share a code object but
          differ only in captured closure values map to the same cache
          key; call :meth:`clear_evolve_solver_cache` after rebuilding
          such functions.

        TODO: switch back once diffrax is stable with complex arithmetic.

        Args:
            ph: A :class:`ParametrizedHamiltonian` (one or more terms).
            name: Name passed on to the produced :class:`Operation`.
            **odeint_kwargs: Solver options.  Unrecognised keys are
                ignored.  Recognised keys:

                - ``atol``, ``rtol`` — absolute/relative tolerances for the
                  step-size controller (default ``1.4e-8`` in fp32 mode,
                  ``1.0e-10`` in fp64 mode).
                - ``max_steps`` — hard cap on accepted ODE steps
                  (default :attr:`Yaqsi._solver_defaults['max_steps']`,
                  initially ``2**13``).  Increase this if the integrator
                  raises ``MaxStepsReached`` for a stiff/oscillatory
                  pulse Hamiltonian.
                - ``throw`` — whether to raise on solver failure
                  (default :attr:`Yaqsi._solver_defaults['throw']`,
                  initially ``True``).  When ``False``, a failed
                  integration returns a NaN-filled unitary instead of
                  raising; this is the recommended setting for inner
                  loops of an optimiser (e.g. QOC Stage 0) so a single
                  pathological candidate cannot abort the whole run.
                - ``solver`` — one of ``_valid_solvers``
                  (default :attr:`Yaqsi._solver_defaults['solver']`).
                - ``magnus_steps`` — number of fixed substeps for the
                  Magnus solvers
                  (default :attr:`Yaqsi._solver_defaults['magnus_steps']`).

        Returns:
            ``(coeff_args, T) -> Operation`` gate factory.

        Raises:
            ValueError: If the requested solver name is unknown.
        """
        coeff_fns = ph.coeff_fns  # tuple of callables
        H_mats = ph.H_mats  # tuple of (dim, dim)
        wires = ph.wires
        n_terms = ph.n_terms
        dim = H_mats[0].shape[0]

        # Pre-compute -i*H_i for each term and split into real / imaginary
        # parts so the ODE RHS uses only real arithmetic.  Final shape:
        # (n_terms, 2, dim, dim).
        neg_iH_split_per_term = []
        for H_mat in H_mats:
            neg_iH = -1j * H_mat
            neg_iH_split_per_term.append(
                jnp.stack([jnp.real(neg_iH), jnp.imag(neg_iH)], axis=0)
            )
        neg_iH_split = jnp.stack(neg_iH_split_per_term, axis=0)

        # Real dtype matching the precision mode
        # consider decreasing if no convergence
        _rdtype = jnp.float64 if jax.config.x64_enabled else jnp.float32

        # Pick tolerances according to precision + some headroom
        atol, rtol, max_steps, throw, solver_name, magnus_steps = (
            cls._parse_evolve_solver_options(odeint_kwargs)
        )

        # Cache key: every coeff fn's code object (same shape of pulse
        # fns -> same JIT program) plus dim, tolerances, and solver
        # budget / throw flag (different budgets mean different XLA
        # programs).  We use the code object itself (hashable, identity-
        # equal) rather than ``id(fn.__code__)``: ids can be reused for
        # later code objects after the original is garbage-collected,
        # which would silently return a stale compiled solver for a
        # different pulse shape.  Holding the code object in the cache
        # keeps it alive for as long as the cached program is valid.
        # Callables without ``__code__`` (e.g. ``functools.partial`` or
        # callable instances) are keyed by the callable object itself.
        cache_key = (
            tuple(getattr(fn, "__code__", fn) for fn in coeff_fns),
            dim,
            atol,
            rtol,
            max_steps,
            throw,
            solver_name,
            magnus_steps,
        )

        # Only the lookup is done under the lock; a concurrent duplicate
        # build is resolved inside _store_evolve_solver.
        with cls._evolve_solver_cache_lock:
            _solve = cls._evolve_solver_cache.get(cache_key)
        if _solve is None:
            if solver_name in ("magnus2", "magnus4"):
                _solve = cls._build_magnus_evolve_solver(
                    cache_key=cache_key,
                    coeff_fns=coeff_fns,
                    n_terms=n_terms,
                    dim=dim,
                    solver_name=solver_name,
                    magnus_steps=magnus_steps,
                )
            else:
                _solve = cls._build_diffrax_evolve_solver(
                    cache_key=cache_key,
                    coeff_fns=coeff_fns,
                    n_terms=n_terms,
                    dim=dim,
                    atol=atol,
                    rtol=rtol,
                    max_steps=max_steps,
                    throw=throw,
                    solver_name=solver_name,
                    _rdtype=_rdtype,
                )

        def _apply(coeff_args, T) -> "Operation":
            """Evolve under the (multi-term) time-dependent Hamiltonian.

            Args:
                coeff_args: List/tuple of parameter sets, one per term.
                    For single-term Hamiltonians the legacy form
                    ``[params]`` works unchanged; ``params`` is forwarded
                    to the sole coefficient function.  A value that is
                    neither list nor tuple is treated as the parameter
                    set of a single term.
                T: Total evolution time.  Scalar -> integrate on
                    ``[0, T]``; 2-element -> integrate on ``[T[0], T[1]]``.

            Returns:
                An :class:`Operation` wrapping the computed unitary.

            Raises:
                ValueError: If the number of parameter sets does not match
                    the number of Hamiltonian terms.
            """
            # Normalise to a tuple of length n_terms.  Accept a bare
            # single-term arg for backward compat.
            if isinstance(coeff_args, (list, tuple)):
                params = tuple(coeff_args)
            else:
                params = (coeff_args,)

            if len(params) != n_terms:
                raise ValueError(
                    f"Expected {n_terms} parameter set(s) for a "
                    f"{n_terms}-term ParametrizedHamiltonian, "
                    f"got {len(params)}."
                )

            # Build time span — resolve at Python level to avoid traced
            # branching.  ``T`` is either a Python scalar / 0-d array
            # (=> integrate on [0, T]) or a 2-element sequence/array
            # (=> integrate on [T[0], T[1]]).  Only the rank of the array
            # form is inspected here.
            T_arr = jnp.asarray(T, dtype=_rdtype)
            if T_arr.ndim == 0:
                t0 = _rdtype(0.0)
                t1 = T_arr
            else:
                t0 = T_arr[0]
                t1 = T_arr[1]

            U = _solve(neg_iH_split, params, t0, t1)

            return Operation(wires=wires, matrix=U, name=name)

        return _apply

725 

726 

# Module-level aliases: each name below is bound to the ``Yaqsi`` classmethod
# of the same name, so the helpers can be imported and called as plain
# functions.
# TODO adjust imports to use classmethods instead
partial_trace = Yaqsi.partial_trace
evolve = Yaqsi.evolve
marginalize_probs = Yaqsi.marginalize_probs
build_parity_observable = Yaqsi.build_parity_observable

732 

733 

734class Script: 

735 """Circuit container and executor backed by pure JAX kernels. 

736 

737 ``Script`` takes a callable *f* representing a quantum circuit. 

738 Within *f*, :class:`~qml_essentials.operations.Operation` objects are 

739 instantiated and automatically recorded onto a tape. The tape is then 

740 simulated using either a statevector or density-matrix kernel depending on 

741 whether noise channels are present. 

742 

743 Attributes: 

744 f: The circuit function whose body instantiates ``Operation`` objects. 

745 _n_qubits: Optionally pre-declared number of qubits. When ``None`` 

746 the qubit count is inferred from the operations recorded on the 

747 tape. 

748 

749 Example: 

750 >>> def circuit(theta): 

751 ... RX(theta, wires=0) 

752 ... PauliZ(wires=1) 

753 >>> script = Script(circuit, n_qubits=2) 

754 >>> result = script.execute(type="expval", obs=[PauliZ(0)]) 

755 """ 

756 

757 def __init__(self, f: Callable[..., None], n_qubits: Optional[int] = None) -> None: 

758 """Initialise a Script. 

759 

760 Args: 

761 f: A function whose body instantiates ``Operation`` objects. 

762 Signature: ``f(*args, **kwargs) -> None``. 

763 n_qubits: Number of qubits. If ``None``, inferred from the 

764 operations recorded on the tape. 

765 """ 

766 self.f = f 

767 self._n_qubits = n_qubits 

768 self._jit_cache: dict = {} # keyed on (type, in_axes, arg_shapes, gateError) 

769 

770 @staticmethod 

771 def _estimate_peak_bytes( 

772 n_qubits: int, 

773 batch_size: int, 

774 type: str, 

775 use_density: bool, 

776 n_obs: int = 0, 

777 ) -> int: 

778 """Estimate peak memory (bytes) for a batched simulation. 

779 

780 The estimate accounts for: 

781 

782 - The batched statevector (always needed, even for density). 

783 - The batched output tensor (state / probs / density / expval). 

784 - One gate-tensor temporary per batch element (the einsum buffer). 

785 

786 Observable matrices are **not** counted: they are computed inside 

787 the JIT-compiled function and XLA manages their lifetime (reusing 

788 buffers between observables). Similarly, the outer-product 

789 temporary for pure-circuit density mode is transient within XLA. 

790 

791 Element size is determined dynamically from ``jax.config.x64_enabled``: 

792 when x64 mode is disabled (the JAX default), complex values are 

793 ``complex64`` (8 bytes) and floats are ``float32`` (4 bytes), 

794 halving memory usage compared to the x64 path. 

795 

796 A 1.5× safety factor is applied to cover XLA compiler temporaries, 

797 padding, and other allocations not directly visible to Python. 

798 

799 This is a pure Python arithmetic calculation with no JAX calls — 

800 it adds essentially zero overhead. 

801 

802 Args: 

803 n_qubits: Number of qubits in the circuit. 

804 batch_size: Number of batch elements. 

805 type: Measurement type (``"state"``, ``"probs"``, ``"expval"``, 

806 ``"density"``). 

807 use_density: Whether density-matrix simulation is used. 

808 n_obs: Number of observables (relevant for ``"expval"``). 

809 

810 Returns: 

811 Estimated peak memory in bytes. 

812 """ 

813 dim = 2**n_qubits 

814 # Detect actual element size: JAX silently truncates complex128 

815 # to complex64 when x64 mode is disabled (the default). 

816 elem = 16 if jax.config.x64_enabled else 8 # complex128 vs complex64 

817 real_elem = elem // 2 # float64 vs float32 

818 

819 # Statevector: always allocated during simulation 

820 sv_bytes = batch_size * dim * elem 

821 

822 # Simulation intermediate: when density-matrix simulation is used, 

823 # the full rho (dim × dim) must be held during gate evolution — 

824 # even if the final output is only probs or expval. 

825 # apply_to_density contracts both U and U* against rho, so at least 

826 # two intermediate (dim × dim) buffers are alive simultaneously. 

827 if use_density: 

828 sim_bytes = 2 * batch_size * dim * dim * elem 

829 else: 

830 sim_bytes = 0 # statevector is already counted above 

831 

832 # Output tensor: this is the *returned* array, not the simulation 

833 # intermediate. For probs/expval with density simulation the 

834 # density matrix is reduced to a small output *before* returning, 

835 # so only the reduced output coexists with the next chunk. 

836 if type == "density": 

837 out_bytes = batch_size * dim * dim * elem 

838 elif type == "expval": 

839 out_bytes = batch_size * max(n_obs, 1) * real_elem 

840 elif type == "probs": 

841 out_bytes = batch_size * dim * real_elem 

842 else: # state 

843 out_bytes = batch_size * dim * elem 

844 

845 # Gate temporaries: einsum creates one (2,)*n buffer per batch elem 

846 gate_tmp = batch_size * dim * elem 

847 

848 # Peak = max(simulation phase, output phase). During simulation 

849 # the intermediate + statevector + gate temps are alive. After 

850 # measurement, only the output survives. So peak is whichever 

851 # phase is larger. 

852 sim_peak = sv_bytes + sim_bytes + gate_tmp 

853 out_peak = out_bytes 

854 raw = max(sim_peak, out_peak) 

855 

856 # 1.5× safety factor for XLA compiler temporaries, padding, etc. 

857 return int(raw * 1.5) 

858 

@staticmethod
def _available_memory_bytes() -> int:
    """Return available system memory in bytes.

    Sources are tried in order and the first one that succeeds wins:

    1. ``psutil.virtual_memory().available`` (cross-platform).
    2. The ``MemAvailable`` entry of ``/proc/meminfo`` (Linux only).
    3. A conservative 4 GiB default.

    Returns:
        Available memory in bytes.
    """
    # Primary: psutil (works on Linux, macOS, Windows)
    try:
        import psutil

        mem = psutil.virtual_memory().available
    except Exception:
        # Best effort: covers a missing package as well as psutil failures.
        log.debug("psutil not available. Fallback to /proc/meminfo")
    else:
        log.debug(f"Available memory: {mem / 1024**3:.1f} GB")
        return mem

    # Fallback: /proc/meminfo (Linux only).  Only reached when psutil did
    # not deliver a value, so it can no longer overwrite a psutil reading.
    try:
        with open("/proc/meminfo", "r") as f:
            for line in f:
                if line.startswith("MemAvailable:"):
                    mem = int(line.split()[1]) * 1024  # kB → bytes
                    log.debug(f"Available memory: {mem / 1024**3:.1f} GB")
                    return mem
    except Exception:
        # Missing file (non-Linux) or malformed entry: use the default.
        pass

    log.debug("Failed to read /proc/meminfo. Falling back to 4 GiB")
    mem = 4 * 1024**3
    log.debug(f"Available memory: {mem / 1024**3:.1f} GB")
    return mem

891 

@staticmethod
def _compute_chunk_size(
    n_qubits: int,
    batch_size: int,
    type: str,
    use_density: bool,
    n_obs: int = 0,
    memory_fraction: float = 0.8,
) -> int:
    """Pick the number of batch elements to process per sub-batch.

    The memory budget is ``memory_fraction`` times the value reported by
    :meth:`_available_memory_bytes`. When the estimate for the whole batch
    (from :meth:`_estimate_peak_bytes`) fits in that budget, *batch_size*
    is returned unchanged, i.e. no chunking.

    Otherwise the size of the full ``(batch_size, ...)`` result array is
    subtracted from the budget (it has to coexist with every chunk), and
    the remainder is divided by the estimated cost of a single batch
    element. The result is clamped to ``[1, batch_size]``.

    Args:
        n_qubits: Number of qubits.
        batch_size: Total batch size.
        type: Measurement type.
        use_density: Whether density-matrix simulation is used.
        n_obs: Number of observables.
        memory_fraction: Fraction of available memory to target
            (default 0.8 = 80%).

    Returns:
        Chunk size (number of batch elements per sub-batch).
    """
    budget = int(Script._available_memory_bytes() * memory_fraction)
    whole_batch_bytes = Script._estimate_peak_bytes(
        n_qubits, batch_size, type, use_density, n_obs
    )
    if whole_batch_bytes <= budget:
        # Everything fits: run the batch in one go.
        return batch_size

    # Size of one result entry, mirroring the output shapes per type.
    dim = 2**n_qubits
    complex_bytes = 16 if jax.config.x64_enabled else 8
    real_bytes = complex_bytes // 2
    if type == "density":
        entry_bytes = dim * dim * complex_bytes
    elif type == "expval":
        entry_bytes = max(n_obs, 1) * real_bytes
    elif type == "probs":
        entry_bytes = dim * real_bytes
    else:
        entry_bytes = dim * complex_bytes

    # The accumulator for all results stays alive next to each chunk.
    accumulator_bytes = batch_size * entry_bytes
    chunk_budget = max(budget - accumulator_bytes, complex_bytes)

    # Estimated cost of simulating exactly one batch element.
    single_bytes = Script._estimate_peak_bytes(
        n_qubits, 1, type, use_density, n_obs
    )
    if single_bytes <= 0:
        return batch_size

    chunk = max(1, min(chunk_budget // single_bytes, batch_size))

    if chunk == 1 and single_bytes > budget:
        log.warning(
            f"A single batch element requires ~{single_bytes / 1024**3:.2f} GB "
            f"but only ~{budget / 1024**3:.2f} GB is available. "
            f"Proceeding with chunk_size=1 but OOM is possible. "
            f"Consider reducing n_qubits or switching measurement type."
        )

    log.info(
        f"Computation requires ~{whole_batch_bytes / 1024**3:.2f} GB which "
        f"does not fit in ~{budget / 1024**3:.2f} GB. "
        f"Using chunk size {chunk}."
    )
    return chunk

974 

@staticmethod
def _execute_chunked(
    batched_fn: Callable,
    args: tuple,
    in_axes: Tuple,
    batch_size: int,
    chunk_size: int,
) -> jnp.ndarray:
    """Run *batched_fn* over the batch in consecutive sub-batches.

    The batch axis of every batched argument is cut into windows of at
    most *chunk_size* elements. Each window is passed to *batched_fn* and
    the result is written into an output array of leading size
    *batch_size* that is allocated once the first chunk's shape and dtype
    are known. References to the chunk inputs and outputs are dropped
    before the next window is processed.

    Args:
        batched_fn: Callable applied to each sub-batch.
        args: Full-batch arguments (before slicing).
        in_axes: Per-argument batch axis; ``None`` marks an argument that
            is passed through unsliced.
        batch_size: Total number of batch elements.
        chunk_size: Maximum elements per chunk.

    Returns:
        Results for the whole batch, with leading dimension *batch_size*.
    """
    n_chunks = (batch_size + chunk_size - 1) // chunk_size
    log.debug(
        f"Memory-aware chunking: splitting batch of {batch_size} into "
        f"{n_chunks} chunks of <={chunk_size} elements."
    )

    collected = None
    for lo in range(0, batch_size, chunk_size):
        hi = min(lo + chunk_size, batch_size)
        width = hi - lo

        # Cut the current window out of every batched argument.
        window = []
        for value, axis in zip(args, in_axes):
            if axis is None:
                window.append(value)
            else:
                window.append(
                    jax.lax.dynamic_slice_in_dim(value, lo, width, axis=axis)
                )
        window = tuple(window)

        partial = batched_fn(*window)

        if collected is None:
            # Shape/dtype of the result are only known after the first chunk.
            collected = jnp.zeros(
                (batch_size,) + partial.shape[1:], dtype=partial.dtype
            )

        # ``.at[].set`` returns a new array; rebinding releases the old one.
        collected = collected.at[lo:hi].set(partial)

        # Drop per-chunk references before starting the next window.
        del partial, window

        # Optional: clearing JAX caches is off by default; it is toggled
        # through ``Yaqsi._clear_caches_between_chunks``.
        if Yaqsi._clear_caches_between_chunks:
            jax.clear_caches()

    return collected

1054 

def _record(self, *args, **kwargs) -> List[Operation]:
    """Invoke the circuit function and return what was recorded.

    ``self.f`` is called inside a
    :func:`~qml_essentials.tape.recording` context; the object yielded by
    that context manager is returned once the context has been left.

    Args:
        *args: Positional arguments forwarded to the circuit function.
        **kwargs: Keyword arguments forwarded to the circuit function.

    Returns:
        The tape produced by the recording context (annotated as a list
        of :class:`~qml_essentials.operations.Operation`).
    """
    with recording() as recorded_ops:
        self.f(*args, **kwargs)

    return recorded_ops

1074 

def pulse_events(self, *args, **kwargs) -> list:
    """Invoke the circuit function and return the collected pulse events.

    The circuit function runs with two contexts active: an outer
    :func:`~qml_essentials.tape.pulse_recording` context, whose yielded
    object is returned, and an inner
    :func:`~qml_essentials.tape.recording` context, whose tape is
    discarded.

    Args:
        *args (Any): Forwarded to the circuit function.
        **kwargs (Any): Forwarded to the circuit function.

    Returns:
        The object yielded by ``pulse_recording()`` (annotated as a list).
    """
    # Entered left to right, exited right to left.
    with pulse_recording() as collected_events, recording():
        self.f(*args, **kwargs)

    return collected_events

1094 

1095 @staticmethod 

1096 def _infer_n_qubits(ops: List[Operation], obs: List[Operation]) -> int: 

1097 """Infer the number of qubits from a list of operations and observables. 

1098 

1099 Args: 

1100 ops: Gate operations recorded on the tape. 

1101 obs: Observable operations used for measurement. 

1102 

1103 Returns: 

1104 The smallest number of qubits that covers all wire indices, i.e. 

1105 ``max(all_wires) + 1`` (at least 1). 

1106 """ 

1107 all_wires: set[int] = set() 

1108 for op in ops + obs: 

1109 all_wires.update(op.wires) 

1110 return max(all_wires) + 1 if all_wires else 1 

1111 

@staticmethod
def _simulate_pure(tape: List[Operation], n_qubits: int) -> jnp.ndarray:
    """Statevector simulation kernel.

    The state starts with amplitude 1 on basis index 0 and is kept as a
    rank-``n_qubits`` tensor of shape ``(2,) * n_qubits`` while the gates
    are applied; it is flattened back to ``(2**n_qubits,)`` only at the
    end.

    Gate tensors and einsum subscripts are gathered for the whole tape
    before the application loop, so each loop iteration is a single
    ``jnp.einsum`` call. :class:`Barrier` entries are skipped.

    Args:
        tape: Ordered list of gate operations to apply.
        n_qubits: Total number of qubits.

    Returns:
        Statevector of shape ``(2**n_qubits,)``.
    """
    hilbert_dim = 2**n_qubits

    # (gate tensor, einsum subscript) per non-barrier operation, in order.
    plan = [
        (
            gate._gate_tensor(len(gate.wires)),
            _einsum_subscript(n_qubits, len(gate.wires), tuple(gate.wires)),
        )
        for gate in tape
        if not isinstance(gate, Barrier)
    ]

    basis_zero = jnp.zeros(hilbert_dim, dtype=_cdtype()).at[0].set(1.0)
    tensor_state = basis_zero.reshape((2,) * n_qubits)
    for gate_tensor, subscript in plan:
        tensor_state = jnp.einsum(subscript, gate_tensor, tensor_state)

    return tensor_state.reshape(hilbert_dim)

1153 

@staticmethod
def _simulate_mixed(tape: List[Operation], n_qubits: int) -> jnp.ndarray:
    """Density-matrix simulation kernel.

    Starts from a ``(2**n_qubits, 2**n_qubits)`` matrix whose only
    non-zero entry is a 1 at ``[0, 0]`` and threads it through
    :meth:`~qml_essentials.operations.Operation.apply_to_density` of each
    tape entry, in tape order.

    Every tape entry is applied; nothing is filtered out here.

    Args:
        tape: Ordered list of gate or channel operations to apply.
        n_qubits: Total number of qubits.

    Returns:
        Density matrix of shape ``(2**n_qubits, 2**n_qubits)``.
    """
    hilbert_dim = 2**n_qubits
    initial = jnp.zeros((hilbert_dim, hilbert_dim), dtype=_cdtype())
    initial = initial.at[0, 0].set(1.0)

    # Left fold: feed the running density matrix through each operation.
    return reduce(
        lambda current, operation: operation.apply_to_density(current, n_qubits),
        tape,
        initial,
    )

1177 

@staticmethod
def _simulate_and_measure(
    tape: List[Operation],
    n_qubits: int,
    type: str,
    obs: List[Operation],
    use_density: bool,
    shots: Optional[int] = None,
    key: Optional[jnp.ndarray] = None,
) -> jnp.ndarray:
    """Simulate *tape* and apply the requested measurement.

    Statevector path (``use_density`` false): the tape is simulated with
    :meth:`_simulate_pure` and measured with :meth:`_measure_state`.

    Density path (``use_density`` true): if the tape contains a
    :class:`KrausChannel`, the density matrix is evolved with
    :meth:`_simulate_mixed`. Otherwise the statevector is simulated and
    the density matrix is formed once as the outer product of the state
    with its conjugate, avoiding gate-by-gate evolution of the full
    matrix. The result is measured with :meth:`_measure_density`.

    Shot sampling: when *shots* is not ``None`` and *type* is ``"probs"``
    or ``"expval"``, the exact computational-basis probabilities are
    passed to :meth:`_sample_shots` instead of being measured
    analytically. For other types *shots* has no effect.

    Args:
        tape: Ordered list of gate/channel operations to apply.
        n_qubits: Total number of qubits.
        type: Measurement type (``"state"``/``"probs"``/``"expval"``/
            ``"density"``).
        obs: Observables for ``"expval"`` measurements.
        use_density: If ``True``, use the density-matrix path.
        shots: Number of measurement shots, or ``None`` for exact results.
        key: JAX PRNG key handed to :meth:`_sample_shots`.

    Returns:
        Measurement result (shape depends on *type*).
    """
    sample = shots is not None and type in ("probs", "expval")

    if not use_density:
        psi = Script._simulate_pure(tape, n_qubits)
        if sample:
            born_probs = jnp.abs(psi) ** 2
            return Script._sample_shots(born_probs, n_qubits, type, obs, shots, key)
        return Script._measure_state(psi, n_qubits, type, obs)

    if any(isinstance(entry, KrausChannel) for entry in tape):
        # Kraus channels require evolving the full density matrix.
        rho = Script._simulate_mixed(tape, n_qubits)
    else:
        # Noise-free tape: simulate the statevector and form the outer
        # product once instead of evolving the matrix gate by gate.
        psi = Script._simulate_pure(tape, n_qubits)
        rho = jnp.outer(psi, jnp.conj(psi))

    if sample:
        diagonal_probs = jnp.real(jnp.diag(rho))
        return Script._sample_shots(diagonal_probs, n_qubits, type, obs, shots, key)
    return Script._measure_density(rho, n_qubits, type, obs)

1252 

1253 @staticmethod 

1254 def _measure_state( 

1255 state: jnp.ndarray, 

1256 n_qubits: int, 

1257 type: str, 

1258 obs: List[Operation], 

1259 ) -> jnp.ndarray: 

1260 """Apply the requested measurement to a pure statevector. 

1261 

1262 Args: 

1263 state: Statevector of shape ``(2**n_qubits,)``. 

1264 n_qubits: Total number of qubits. 

1265 type: Measurement type — one of ``"state"``, ``"probs"``, 

1266 or ``"expval"``. 

1267 obs: Observables used when *type* is ``"expval"``. 

1268 

1269 Returns: 

1270 Measurement result whose shape depends on *type*: 

1271 

1272 - ``"state"`` -> ``(2**n_qubits,)`` 

1273 - ``"probs"`` -> ``(2**n_qubits,)`` 

1274 - ``"expval"`` -> ``(len(obs),)`` 

1275 

1276 Raises: 

1277 ValueError: If *type* is not a recognised measurement type. 

1278 """ 

1279 if type == "state": 

1280 return state 

1281 

1282 if type == "probs": 

1283 return jnp.abs(state) ** 2 

1284 

1285 if type == "expval": 

1286 # Fast path for single-qubit diagonal observables (PauliZ, etc.) 

1287 # where d0, d1 are the diagonal elements of the 2x2 observable. 

1288 # This replaces n_obs tensor contractions with a single |ψ|² 

1289 # and n_obs reductions over the probability vector. 

1290 

1291 def _is_single_qubit_diag(ob): 

1292 m = ob.__class__._matrix 

1293 if m is None or len(ob.wires) != 1: 

1294 return False 

1295 # Convert to NumPy to ensure concrete boolean evaluation 

1296 m_np = np.asarray(m) 

1297 return np.allclose(m_np - np.diag(np.diag(m_np)), 0) 

1298 

1299 all_single_qubit_diag = all(_is_single_qubit_diag(ob) for ob in obs) 

1300 

1301 if all_single_qubit_diag: 

1302 probs = jnp.abs(state) ** 2 

1303 psi_t = probs.reshape((2,) * n_qubits) 

1304 results = [] 

1305 for ob in obs: 

1306 q = ob.wires[0] 

1307 d = np.real(np.diag(np.asarray(ob.__class__._matrix))) 

1308 # Sum probabilities over all axes except qubit q 

1309 p_q = jnp.sum( 

1310 psi_t, axis=tuple(i for i in range(n_qubits) if i != q) 

1311 ) 

1312 results.append(d[0] * p_q[0] + d[1] * p_q[1]) 

1313 return jnp.array(results) 

1314 

1315 # General path: stack observable matrices and use a single 

1316 # batched matmul instead of a Python loop of tensor contractions. 

1317 # O_states[i] = obs[i] |ψ⟩, then ⟨O_i⟩ = Re(⟨ψ|O_states[i]⟩). 

1318 obs_mats = jnp.stack( 

1319 [ob.lifted_matrix(n_qubits) for ob in obs], axis=0 

1320 ) # (n_obs, dim, dim) 

1321 # Batched matvec: (n_obs, dim, dim) @ (dim,) -> (n_obs, dim) 

1322 O_states = jnp.einsum("oij,j->oi", obs_mats, state) 

1323 return jnp.real(jnp.einsum("i,oi->o", jnp.conj(state), O_states)) 

1324 

1325 raise ValueError(f"Unknown measurement type: {type!r}") 

1326 

1327 @staticmethod 

1328 def _measure_density( 

1329 rho: jnp.ndarray, 

1330 n_qubits: int, 

1331 type: str, 

1332 obs: List[Operation], 

1333 ) -> jnp.ndarray: 

1334 """Apply the requested measurement to a density matrix. 

1335 

1336 Args: 

1337 rho: Density matrix of shape ``(2**n_qubits, 2**n_qubits)``. 

1338 n_qubits: Total number of qubits. 

1339 type: Measurement type — one of ``"density"``, ``"probs"``, 

1340 or ``"expval"``. 

1341 obs: Observables used when *type* is ``"expval"``. 

1342 

1343 Returns: 

1344 Measurement result whose shape depends on *type*: 

1345 

1346 - ``"density"`` -> ``(2**n_qubits, 2**n_qubits)`` 

1347 - ``"probs"`` -> ``(2**n_qubits,)`` 

1348 - ``"expval"`` -> ``(len(obs),)`` 

1349 

1350 Raises: 

1351 ValueError: If *type* is ``"state"`` (not valid for mixed circuits) 

1352 or another unrecognised type. 

1353 """ 

1354 if type == "density": 

1355 return rho 

1356 

1357 if type == "probs": 

1358 return jnp.real(jnp.diag(rho)) 

1359 

1360 if type == "expval": 

1361 # Tr(O \\rho ) = \\Sigma_ij O_ij \\rho _ji 

1362 # Stack all observable matrices and compute all traces in one 

1363 # batched operation. 

1364 obs_mats = jnp.stack( 

1365 [ob.lifted_matrix(n_qubits) for ob in obs], axis=0 

1366 ) # (n_obs, dim, dim) 

1367 # einsum "oij,ji->o" computes Tr(O_o @ \\rho ) for each observable 

1368 return jnp.real(jnp.einsum("oij,ji->o", obs_mats, rho)) 

1369 

1370 raise ValueError( 

1371 "Measurement type 'state' is not defined for mixed (noisy) circuits. " 

1372 "Use 'density' instead." 

1373 ) 

1374 

1375 @staticmethod 

1376 def _sample_shots( 

1377 probs: jnp.ndarray, 

1378 n_qubits: int, 

1379 type: str, 

1380 obs: List[Operation], 

1381 shots: int, 

1382 key: jnp.ndarray, 

1383 ) -> jnp.ndarray: 

1384 """Convert exact probabilities into shot-sampled results. 

1385 

1386 Draws *shots* samples from the computational-basis probability 

1387 distribution and returns either estimated probabilities or 

1388 shot-based expectation values. 

1389 

1390 Args: 

1391 probs: Exact probability vector of shape ``(2**n_qubits,)``. 

1392 n_qubits: Total number of qubits. 

1393 type: Measurement type — ``"probs"`` or ``"expval"``. 

1394 obs: Observables used when *type* is ``"expval"``. 

1395 shots: Number of measurement shots. 

1396 key: JAX PRNG key for sampling. 

1397 

1398 Returns: 

1399 Shot-sampled measurement result: 

1400 

1401 - ``"probs"`` → ``(2**n_qubits,)`` estimated probabilities. 

1402 - ``"expval"`` → ``(len(obs),)`` estimated expectation values. 

1403 """ 

1404 dim = 2**n_qubits 

1405 

1406 # Draw `shots` samples from the computational basis. 

1407 # Each sample is an integer in [0, dim) representing a basis state. 

1408 samples = jax.random.choice(key, dim, shape=(shots,), p=probs) 

1409 

1410 # Build a histogram of counts for each basis state. 

1411 counts = jnp.zeros(dim, dtype=jnp.int32) 

1412 counts = counts.at[samples].add(1) 

1413 estimated_probs = counts / shots 

1414 

1415 if type == "probs": 

1416 return estimated_probs 

1417 

1418 if type == "expval": 

1419 # For each observable, compute O from the shot-sampled 

1420 # probabilities. For diagonal observables this is exact; 

1421 # for general observables we use Tr(O · diag(estimated_probs)). 

1422 results = [] 

1423 for ob in obs: 

1424 O_mat = ob.lifted_matrix(n_qubits) 

1425 # diagonal approximation from 

1426 # computational basis measurements, which is exact for 

1427 # diagonal observables like PauliZ) 

1428 results.append(jnp.real(jnp.dot(jnp.diag(O_mat), estimated_probs))) 

1429 return jnp.array(results) 

1430 

1431 raise ValueError( 

1432 f"Shot simulation is only supported for 'probs' and 'expval', got {type!r}." 

1433 ) 

1434 

def execute(
    self,
    type: str = "expval",
    obs: Optional[List[Operation]] = None,
    *,
    args: tuple = (),
    kwargs: Optional[dict] = None,
    in_axes: Optional[Tuple] = None,
    shots: Optional[int] = None,
    key: Optional[jnp.ndarray] = None,
) -> jnp.ndarray:
    """Execute the circuit and return measurement results.

    Args:
        type: Measurement type. One of:

            - ``"expval"``: expectation value of each observable in
              *obs*.
            - ``"probs"``: probability vector of shape ``(2**n,)``.
            - ``"state"``: raw statevector of shape ``(2**n,)``.
            - ``"density"``: full density matrix of shape
              ``(2**n, 2**n)``.

        obs: Observables required when type is ``"expval"``. ``None`` is
            treated as an empty list.
        args: Positional arguments forwarded to the circuit function f.
        kwargs: Keyword arguments forwarded to f. ``None`` is treated as
            an empty dict.
        in_axes: Batch axes for each element of *args*, following the
            ``jax.vmap`` convention (an integer selects the batch axis,
            ``None`` broadcasts). When given, the call is delegated to
            :meth:`_execute_batched` and the result carries a leading
            batch dimension.
        shots: Number of measurement shots for stochastic sampling.
            If ``None`` (default), exact analytic results are returned.
            Only supported for ``"probs"`` and ``"expval"`` measurement
            types.
        key: JAX PRNG key for shot sampling. If ``None`` and *shots*
            is set, ``jax.random.PRNGKey(0)`` is used.

    Returns:
        Without in_axes: shape determined by type.
        With in_axes: shape ``(B, ...)`` with a leading batch dimension.
    """
    obs = [] if obs is None else obs
    kwargs = {} if kwargs is None else kwargs
    if key is None and shots is not None:
        # Fixed default key so shot sampling works without an explicit key.
        key = jax.random.PRNGKey(0)

    # Batched execution is handled entirely by the vmapped path.
    if in_axes is not None:
        return self._execute_batched(
            type=type,
            obs=obs,
            args=args,
            kwargs=kwargs,
            in_axes=in_axes,
            shots=shots,
            key=key,
        )

    # Single execution: record once, then simulate and measure.
    recorded = self._record(*args, **kwargs)
    qubit_count = self._n_qubits or self._infer_n_qubits(recorded, obs)

    # Density simulation is needed for density output or any Kraus channel.
    noisy = any(isinstance(entry, KrausChannel) for entry in recorded)
    density_mode = type == "density" or noisy

    return self._simulate_and_measure(
        recorded,
        qubit_count,
        type,
        obs,
        density_mode,
        shots=shots,
        key=key,
    )

1518 

1519 def _execute_batched( 

1520 self, 

1521 type: str, 

1522 obs: List[Operation], 

1523 args: tuple, 

1524 kwargs: dict, 

1525 in_axes: Tuple, 

1526 shots: Optional[int] = None, 

1527 key: Optional[jnp.ndarray] = None, 

1528 ) -> jnp.ndarray: 

1529 """Vectorise :meth:`execute` over a batch axis using ``jax.vmap``. 

1530 

1531 The circuit function is traced once in Python with scalar slices to 

1532 record the tape, determine ``n_qubits``, and detect noise. The 

1533 resulting pure simulation kernel is then vmapped over the requested 

1534 axes. 

1535 

1536 Memory-aware chunking — before launching the full vmap, the 

1537 method estimates peak memory usage. If the full batch would exceed 

1538 available RAM (with a safety margin), the batch is automatically 

1539 split into sub-batches that fit. Each chunk is vmapped independently 

1540 and the results are concatenated. This trades a small amount of 

1541 wall-clock time for guaranteed execution without OOM. 

1542 

1543 When the full batch fits in memory, there is zero overhead — the 

1544 memory check is a pure Python arithmetic calculation (no JAX calls). 

1545 

1546 Args: 

1547 type: Measurement type (see :meth:`execute`). 

1548 obs: Observables (see :meth:`execute`). 

1549 args: Positional arguments for the circuit function. 

1550 kwargs: Keyword arguments for the circuit function. 

1551 in_axes: One entry per element of *args*. Follows ``jax.vmap`` 

1552 convention: an int gives the batch axis; ``None`` broadcasts. 

1553 shots: Number of measurement shots. If ``None``, exact results. 

1554 key: JAX PRNG key for shot sampling. 

1555 

1556 Returns: 

1557 Batched measurement results of shape ``(B, ...)`` where *B* is the 

1558 size of the batch dimension. 

1559 

1560 Raises: 

1561 ValueError: If ``len(in_axes) != len(args)``. 

1562 

1563 Note: 

1564 The ``jax.vmap`` call at the end of this method is the exact 

1565 boundary to replace with ``jax.shard_map`` for multi-device 

1566 execution:: 

1567 

1568 from jax.sharding import PartitionSpec as P, Mesh 

1569 result = jax.shard_map( 

1570 _single_execute, mesh=mesh, 

1571 in_specs=tuple(P(0) if ax is not None else P() for ax in in_axes), 

1572 out_specs=P(0), 

1573 )(*args) 

1574 """ 

1575 if len(in_axes) != len(args): 

1576 raise ValueError( 

1577 f"in_axes has {len(in_axes)} entries but args has {len(args)}. " 

1578 "Provide one in_axes entry per positional argument." 

1579 ) 

1580 

1581 # Determine batch size from the first batched arg 

1582 batch_size = 1 

1583 for a, ax in zip(args, in_axes): 

1584 if ax is not None: 

1585 batch_size = a.shape[ax] 

1586 break 

1587 

1588 arg_shapes = tuple( 

1589 (a.shape, a.dtype) if hasattr(a, "shape") else type(a) for a in args 

1590 ) 

1591 cache_kwargs = _make_hashable( 

1592 {k: v for k, v in kwargs.items() if not isinstance(v, jnp.ndarray)} 

1593 ) 

1594 

1595 # TODO: we need to fix the dirty class-level `batch_gate_error` hack 

1596 from qml_essentials.gates import UnitaryGates 

1597 

1598 cache_key = ( 

1599 type, 

1600 in_axes, 

1601 arg_shapes, 

1602 cache_kwargs, 

1603 UnitaryGates.batch_gate_error, 

1604 ) 

1605 

1606 # When called under an outer JAX transform (e.g. ``jacrev``) the 

1607 # cached ``batched_fn`` from a previous outer trace would leak that 

1608 # trace's tracers. Bypass the per-Script wrapper cache in that 

1609 # case; XLA-level compilation caching is unaffected. 

1610 # in_transform = _args_contain_tracer(args) 

1611 

1612 # --- Cache-hit fast path (no shots) --- 

1613 cached = self._jit_cache.get(cache_key) 

1614 # if cached is not None and shots is None and not in_transform: 

1615 if cached is not None and shots is None: 

1616 batched_fn, n_qubits, use_density = cached 

1617 # Check if we already determined the chunk size for this 

1618 # exact batch_size (avoids repeated psutil syscalls). 

1619 mem_key = ("_mem", cache_key, batch_size) 

1620 cached_chunk = self._jit_cache.get(mem_key) 

1621 if cached_chunk is not None: 

1622 if cached_chunk >= batch_size: 

1623 return batched_fn(*args) 

1624 return self._execute_chunked( 

1625 batched_fn, args, in_axes, batch_size, cached_chunk 

1626 ) 

1627 chunk_size = self._compute_chunk_size( 

1628 n_qubits, batch_size, type, use_density, len(obs) 

1629 ) 

1630 self._jit_cache[mem_key] = chunk_size 

1631 if chunk_size >= batch_size: 

1632 return batched_fn(*args) 

1633 return self._execute_chunked( 

1634 batched_fn, args, in_axes, batch_size, chunk_size 

1635 ) 

1636 

1637 # Record the tape once using scalar slices of each arg. 

1638 # This determines n_qubits and whether noise channels are present 

1639 # without running the full batch. 

1640 # Note, that we use lax.index_in_dim instead of jnp.take because JAX 

1641 # random key arrays do not support jnp.take. 

1642 # TODO: fix once that is available in JAX 

1643 def _slice_first(a, ax): 

1644 """Take the first element along axis *ax*.""" 

1645 return jax.lax.index_in_dim(a, 0, axis=ax, keepdims=False) 

1646 

1647 scalar_args = tuple( 

1648 _slice_first(a, ax) if ax is not None else a for a, ax in zip(args, in_axes) 

1649 ) 

1650 tape = self._record(*scalar_args, **kwargs) 

1651 n_qubits = self._n_qubits or self._infer_n_qubits(tape, obs) 

1652 has_noise = any(isinstance(op, KrausChannel) for op in tape) 

1653 use_density = type == "density" or has_noise 

1654 

1655 chunk_size = self._compute_chunk_size( 

1656 n_qubits, batch_size, type, use_density, len(obs) 

1657 ) 

1658 

1659 # Re-recording inside this function is necessary: the tape may 

1660 # contain operations whose matrices depend on the batched argument 

1661 # (e.g. RX(theta) where theta is a JAX tracer). jax.vmap traces 

1662 # this function once and generates a single XLA computation for 

1663 # the entire batch. 

1664 if shots is not None and type in ("probs", "expval"): 

1665 # Shot mode: compute exact probabilities, then sample. 

1666 # The shot key is appended as an extra vmapped argument. 

1667 def _single_execute_shots(*single_args_and_key): 

1668 *single_args, shot_key = single_args_and_key 

1669 single_tape = self._record(*single_args, **kwargs) 

1670 exact_result = self._simulate_and_measure( 

1671 single_tape, n_qubits, "probs", obs, use_density 

1672 ) 

1673 return Script._sample_shots( 

1674 exact_result, n_qubits, type, obs, shots, shot_key 

1675 ) 

1676 

1677 shot_keys = jax.random.split(key, batch_size) 

1678 shot_in_axes = in_axes + (0,) # key is batched over axis 0 

1679 shot_args = args + (shot_keys,) 

1680 

1681 # Shot-mode uses a separate cache key (includes shots) 

1682 shot_cache_key = ( 

1683 type, 

1684 "shots", 

1685 shots, 

1686 in_axes, 

1687 arg_shapes, 

1688 UnitaryGates.batch_gate_error, 

1689 ) 

1690 cached_shot = self._jit_cache.get(shot_cache_key) 

1691 if cached_shot is not None: 

1692 batched_fn = cached_shot[0] 

1693 else: 

1694 batched_fn = eqx.filter_jit( 

1695 jax.vmap(_single_execute_shots, in_axes=shot_in_axes) 

1696 ) 

1697 self._jit_cache[shot_cache_key] = (batched_fn, n_qubits, use_density) 

1698 

1699 if chunk_size >= batch_size: 

1700 return batched_fn(*shot_args) 

1701 return self._execute_chunked( 

1702 batched_fn, shot_args, shot_in_axes, batch_size, chunk_size 

1703 ) 

1704 

1705 def _single_execute(*single_args): 

1706 single_tape = self._record(*single_args, **kwargs) 

1707 return self._simulate_and_measure( 

1708 single_tape, n_qubits, type, obs, use_density 

1709 ) 

1710 

1711 # Wrapping the vmapped function in eqx.filter_jit has two effects: 

1712 # 1. Multi-core utilisation — the JIT-compiled XLA program can 

1713 # use intra-op parallelism to distribute independent SIMD lanes 

1714 # across CPU threads, unlike an eager vmap which runs 

1715 # single-threaded. 

1716 # 2. Compilation caching — subsequent calls with the same input 

1717 # shapes reuse the compiled kernel and skip all Python-level 

1718 # tracing, eliminating the O(B × circuit_depth) Python overhead.

1719 # 

1720 # The compiled function is cached on this Script instance, 

1721 # keyed on (type, in_axes, arg_shapes). Repeated calls with the 

1722 # same structure (e.g. training iterations) skip both Python-level 

1723 # tracing and XLA compilation entirely — they jump straight to the 

1724 # cache check at the top of this method. 

1725 # NOTE: when altering properties of the model, this might not get re-compiled 

1726 # TODO: we might want to rework the data_reupload mechanism at some point 

1727 batched_fn = eqx.filter_jit(jax.vmap(_single_execute, in_axes=in_axes)) 

1728 # Cache the function together with metadata for fast-path memory 

1729 # checks on subsequent calls. Skip caching when the call is under 

1730 # an outer JAX transform (the closure of ``_single_execute`` 

1731 # captures ``n_qubits``/``obs``/``kwargs`` of this trace; reusing 

1732 # the wrapper under a different outer trace would leak its 

1733 # tracers). 

1734 # if not in_transform: 

1735 # self._jit_cache[cache_key] = (batched_fn, n_qubits, use_density) 

1736 self._jit_cache[cache_key] = (batched_fn, n_qubits, use_density) 

1737 

1738 if chunk_size >= batch_size: 

1739 return batched_fn(*args) 

1740 return self._execute_chunked(batched_fn, args, in_axes, batch_size, chunk_size) 

1741 

def draw(
    self,
    figure: str = "text",
    args: tuple = (),
    kwargs: Optional[dict] = None,
    **draw_kwargs: Any,
) -> Union[str, Any]:
    """Render the circuit with the selected drawing backend.

    The circuit function is invoked with ``args`` / ``kwargs`` to obtain
    either its pulse events (``"pulse"`` mode) or its recorded tape (all
    other modes); the result is then handed to the matching renderer.

    Args:
        figure: Name of the rendering backend:

            - ``"text"`` — ASCII art, returned as a ``str``.
            - ``"mpl"`` — Matplotlib drawing, returns ``(fig, ax)``.
            - ``"tikz"`` — LaTeX/TikZ code via ``quantikz``
              (returns a :class:`TikzFigure`).
            - ``"pulse"`` — pulse schedule plot, returns ``(fig, axes)``.

        args: Positional arguments passed on to the circuit function.
        kwargs: Keyword arguments passed on to the circuit function;
            ``None`` is treated as an empty dict.
        **draw_kwargs: Options handed to the ``"mpl"``, ``"tikz"`` and
            ``"pulse"`` renderers (they are not forwarded in ``"text"``
            mode), for example:

            - ``gate_values`` (bool): Show numeric gate angles instead
              of symbolic \\theta_i labels. Default ``False``.
            - ``show_carrier`` (bool): For ``"pulse"`` mode, overlay the
              carrier-modulated waveform. Default ``False``.

    Returns:
        The renderer's output, which depends on *figure*:

        - ``"text"`` -> ``str``
        - ``"mpl"`` -> ``(matplotlib.figure.Figure, matplotlib.axes.Axes)``
        - ``"tikz"`` -> :class:`TikzFigure`
        - ``"pulse"`` -> ``(matplotlib.figure.Figure, numpy.ndarray)``

    Raises:
        ValueError: If *figure* is not one of the supported modes.
    """
    supported_modes = ("text", "mpl", "tikz", "pulse")
    if figure not in supported_modes:
        raise ValueError(
            f"Invalid figure mode: {figure!r}. "
            "Must be 'text', 'mpl', 'tikz', or 'pulse'."
        )

    call_kwargs = {} if kwargs is None else kwargs

    if figure == "pulse":
        # Imported lazily, only when a pulse plot is actually requested.
        from qml_essentials.drawing import draw_pulse_schedule

        events = self.pulse_events(*args, **call_kwargs)
        n_qubits = self._n_qubits
        if not n_qubits:
            # No explicit qubit count: size the plot from the highest wire
            # index that appears in any event (0 when there are no events).
            highest_wire = max(
                (wire for event in events for wire in event.wires),
                default=0,
            )
            n_qubits = highest_wire + 1
        return draw_pulse_schedule(events, n_qubits, **draw_kwargs)

    tape = self._record(*args, **call_kwargs)
    n_qubits = self._n_qubits
    if not n_qubits:
        n_qubits = self._infer_n_qubits(tape, [])

    # Noise channels are left out of the drawing — they clutter the diagram.
    visible_ops = [
        operation for operation in tape if not isinstance(operation, KrausChannel)
    ]

    if figure == "text":
        return draw_text(visible_ops, n_qubits)

    # Only "mpl" and "tikz" remain after the validation above.
    renderer = draw_mpl if figure == "mpl" else draw_tikz
    return renderer(visible_ops, n_qubits, **draw_kwargs)