Coverage for qml_essentials / yaqsi.py: 93%

400 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-03-30 11:43 +0000

1from functools import reduce 

2from typing import Any, Callable, List, Optional, Tuple, Union 

3import threading 

4 

5import diffrax 

6import jax 

7import jax.numpy as jnp 

8import jax.scipy.linalg 

9import numpy as np # needed to prevent jitting some operations 

10 

11from qml_essentials.operations import ( 

12 Barrier, 

13 Hermitian, 

14 ParametrizedHamiltonian, 

15 Operation, 

16 KrausChannel, 

17 PauliZ, 

18 _einsum_subscript, 

19 _cdtype, 

20) 

21from qml_essentials.tape import recording, pulse_recording 

22from qml_essentials.drawing import draw_text, draw_mpl, draw_tikz 

23 

24import logging 

25 

26log = logging.getLogger(__name__) 

27 

28 

29def _make_hashable(obj): 

30 """Recursively convert an object into a hashable form for cache keys. 

31 

32 - ``dict`` → sorted tuple of ``(key, _make_hashable(value))`` pairs 

33 - ``list`` → tuple of ``_make_hashable(element)`` 

34 - ``set`` → frozenset of ``_make_hashable(element)`` 

35 - everything else is returned as-is (assumed hashable) 

36 """ 

37 if isinstance(obj, dict): 

38 return tuple(sorted((k, _make_hashable(v)) for k, v in obj.items())) 

39 if isinstance(obj, (list, tuple)): 

40 return tuple(_make_hashable(x) for x in obj) 

41 if isinstance(obj, set): 

42 return frozenset(_make_hashable(x) for x in obj) 

43 return obj 

44 

45 

class Yaqsi:
    """Collection of pure-JAX simulation helpers.

    Provides partial traces, probability marginalization, parity
    observables, and (time-dependent) Hamiltonian evolution. All methods
    are static or class methods; the only class-level state is the cache
    of JIT-compiled ODE solvers used by :meth:`evolve`.
    """

    # TODO: generally, I would like to merge this into operations or vice-versa
    # and only keep Script here. It's not clear how to do this though.

    # Module-level cache for JIT-compiled ODE solvers. Keyed on
    # (coeff_fn_id, dim, atol, rtol) so that all evolve() calls with the
    # same pulse shape function and matrix size share one compiled XLA
    # program. This turns O(n_gates) JIT compilations into
    # O(n_distinct_pulse_shapes) during pulse-mode circuit building.
    _evolve_solver_cache: dict = {}
    # Guards reads/writes of _evolve_solver_cache across threads.
    _evolve_solver_cache_lock = threading.Lock()

    @staticmethod
    def _partial_trace_single(
        rho: jnp.ndarray,
        n_qubits: int,
        keep: List[int],
    ) -> jnp.ndarray:
        """Partial trace of a single density matrix (no batch dimension)."""
        # View rho as a rank-2n tensor: n "ket" axes followed by n "bra" axes.
        shape = (2,) * (2 * n_qubits)
        rho_t = rho.reshape(shape)

        trace_out = sorted(set(range(n_qubits)) - set(keep))

        for q in reversed(trace_out):
            # Each jnp.trace removes two axes (q and its conjugate partner
            # at q + n_remaining). Iterating in descending q order keeps
            # the remaining (smaller) axis indices valid after each removal.
            n_remaining = rho_t.ndim // 2
            rho_t = jnp.trace(rho_t, axis1=q, axis2=q + n_remaining)

        dim = 2 ** len(keep)
        return rho_t.reshape(dim, dim)

    @classmethod
    def partial_trace(
        cls,
        rho: jnp.ndarray,
        n_qubits: int,
        keep: List[int],
    ) -> jnp.ndarray:
        """Partial trace of a density matrix, keeping only the specified qubits.

        Supports both single density matrices of shape ``(2**n, 2**n)`` and
        batched density matrices of shape ``(B, 2**n, 2**n)``.

        Args:
            rho: Density matrix of shape ``(2**n, 2**n)`` or ``(B, 2**n, 2**n)``.
            n_qubits: Total number of qubits.
            keep: List of qubit indices to *keep* (0-indexed).

        Returns:
            Reduced density matrix of shape ``(2**k, 2**k)`` or ``(B, 2**k, 2**k)``
            where *k* = ``len(keep)``.
        """

        dim = 2**n_qubits
        if rho.shape == (dim, dim):
            return Yaqsi._partial_trace_single(rho, n_qubits, keep)
        # Batched: shape (B, dim, dim)
        return jax.vmap(lambda r: Yaqsi._partial_trace_single(r, n_qubits, keep))(rho)

    @staticmethod
    def _marginalize_probs_single(
        probs: jnp.ndarray,
        target_shape: Tuple[int],
        trace_out: Tuple[int],
    ) -> jnp.ndarray:
        """Marginalize a single probability vector (no batch dimension)."""
        probs_t = probs.reshape(target_shape)

        # Axes are summed in the given order; the caller passes them in
        # descending order so earlier sums do not shift later axis indices.
        for q in trace_out:
            probs_t = probs_t.sum(axis=q)

        return probs_t.ravel()

    @classmethod
    def marginalize_probs(
        cls,
        probs: jnp.ndarray,
        n_qubits: int,
        keep: Tuple[int],
    ) -> jnp.ndarray:
        """Marginalize a probability vector to keep only the specified qubits.

        Supports both single probability vectors of shape ``(2**n,)`` and
        batched vectors of shape ``(B, 2**n)``.

        Args:
            probs: Probability vector of shape ``(2**n,)`` or ``(B, 2**n)``.
            n_qubits: Total number of qubits.
            keep: List of qubit indices to *keep* (0-indexed).

        Returns:
            Marginalized probabilities of shape ``(B, 2**k)`` where *k* =
            ``len(keep)``. The input is always reshaped to a batch before
            the vmap (``B = 1`` for a 1-D input), so the result always
            carries a leading batch dimension.
        """

        dim = 2**n_qubits
        # Descending axis order — see _marginalize_probs_single.
        trace_out = tuple(q for q in range(n_qubits - 1, -1, -1) if q not in keep)
        target_shape = (2,) * n_qubits

        return jax.vmap(
            lambda p: Yaqsi._marginalize_probs_single(p, target_shape, trace_out)
        )(probs.reshape(-1, dim))

    @classmethod
    def build_parity_observable(
        cls,
        qubit_group: List[int],
    ) -> Hermitian:
        """Build a multi-qubit parity observable.

        Args:
            qubit_group: List of qubit indices for the parity measurement.

        Returns:
            A :class:`Hermitian` operation whose matrix is the Z parity
            tensor product and whose wires match the given qubits.
        """
        Z = PauliZ._matrix
        # Z ⊗ Z ⊗ ... (one factor per qubit in the group).
        mat = reduce(jnp.kron, [Z] * len(qubit_group))
        return Hermitian(matrix=mat, wires=qubit_group, record=False)

    @classmethod
    def evolve(cls, hamiltonian, name=None, **odeint_kwargs):
        """Return a gate-factory for Hamiltonian time evolution.

        Supports two modes:

        Static — when *hamiltonian* is a :class:`Hermitian`::

            gate = evolve(Hermitian(H_mat, wires=0))
            gate(t=0.5)  # U = exp(-i*0.5*H)

        Time-dependent — when *hamiltonian* is a
        :class:`ParametrizedHamiltonian` (created via ``coeff_fn * Hermitian``)::

            H_td = coeff_fn * Hermitian(H_mat, wires=0)
            gate = evolve(H_td)
            gate([A, sigma], T)  # U via ODE: dU/dt = -i f(p,t) H * U

        The time-dependent case solves the Schrödinger equation numerically
        using ``diffrax.diffeqsolve`` with a Dopri8 adaptive Runge-Kutta
        solver

        All computations are pure JAX and fully differentiable with
        ``jax.grad``.

        Args:
            hamiltonian: Either a :class:`Hermitian` (static evolution) or a
                :class:`ParametrizedHamiltonian` (time-dependent evolution).
            name: Optional display name attached to the :class:`Operation`
                instances produced by the returned factory.
            **odeint_kwargs: Extra keyword arguments. Recognised keys:

                - ``atol``, ``rtol`` — absolute/relative tolerances for the
                  adaptive step-size controller (default ``1.4e-8``).

        Returns:
            A callable gate factory. Signature depends on the mode:

            - Static: ``(t, wires=0) -> Operation``
            - Time-dependent: ``(coeff_args, T) -> Operation``

        Raises:
            TypeError: If *hamiltonian* is neither ``Hermitian`` nor
                ``ParametrizedHamiltonian``.
        """
        if isinstance(hamiltonian, Hermitian):
            return cls._evolve_static(hamiltonian, name=name)
        elif isinstance(hamiltonian, ParametrizedHamiltonian):
            return cls._evolve_parametrized(hamiltonian, name=name, **odeint_kwargs)
        else:
            raise TypeError(
                f"evolve() expects a Hermitian or ParametrizedHamiltonian, "
                f"got {type(hamiltonian)}"
            )

    @staticmethod
    def _evolve_static(hermitian: Hermitian, name=None) -> Callable:
        """Gate factory for static Hamiltonian evolution U = exp(-i t H)."""
        H_mat = hermitian.matrix

        def _apply(t: float, wires: Union[int, List[int]] = 0) -> Operation:
            U = jax.scipy.linalg.expm(-1j * t * H_mat)
            return Operation(wires=wires, matrix=U, name=name)

        return _apply

    @classmethod
    def _evolve_parametrized(
        cls, ph: ParametrizedHamiltonian, name=None, **odeint_kwargs
    ) -> Callable:
        """Gate factory for time-dependent Hamiltonian evolution.

        Solves the matrix ODE ``dU/dt = -i f(params, t) H * U`` with
        ``U(0) = I`` using ``diffrax.diffeqsolve`` (Dopri8 adaptive RK).

        Performance improvements over the previous ``jax.experimental.ode``
        implementation:

        Uses diffrax — a modern, well-maintained JAX ODE library with
        better XLA compilation, adjoint methods, and step-size control.
        The JIT-compiled solver is cached per coefficient function so
        that multiple ``evolve()`` calls with the same pulse shape (but
        different Hamiltonian matrices or parameters) reuse the same
        compiled XLA program. This avoids O(n_gates) JIT compilations
        during pulse-mode tape recording.
        Pre-computes ``-i*H`` once instead of multiplying at every RHS
        evaluation.
        Avoids dynamic ``jnp.where`` branching for the time span.

        Args:
            ph: A :class:`ParametrizedHamiltonian` holding the coefficient
                function, the Hamiltonian matrix, and wire indices.
            name: Optional display name for the produced :class:`Operation`.
            **odeint_kwargs: ``atol`` and ``rtol`` for the step-size controller.
        """
        H_mat = ph.H_mat
        coeff_fn = ph.coeff_fn
        wires = ph.wires
        dim = H_mat.shape[0]

        # Pre-compute -i*H once (avoids repeated multiplication in RHS)
        neg_iH = -1j * H_mat

        atol = odeint_kwargs.pop("atol", 1.4e-8)
        rtol = odeint_kwargs.pop("rtol", 1.4e-8)

        # Look up or build the JIT-compiled solver for this (coeff_fn, dim)
        # combination. All pulse gates with the same pulse shape function
        # (e.g. all RX gates share Sx, all RY share Sy) reuse a single
        # compiled XLA program, with neg_iH passed as a regular argument
        # rather than captured in the closure. This turns O(n_gates) JIT
        # compilations into O(n_distinct_pulse_shapes) compilations.
        #
        # We key on the function's __code__ object rather than id(coeff_fn)
        # because each pulse gate call creates a new closure (capturing the
        # rotation angle `w`), but the underlying code is identical. Under
        # JIT, the captured `w` is a tracer anyway, so the compiled program
        # is generic over `w`.
        #
        # NOTE(review): the cached _solve closes over the *first* coeff_fn
        # instance seen for this __code__ object. Anything coeff_fn captures
        # by closure (rather than receiving via `params`) is baked into the
        # compiled program as a constant when _solve is traced eagerly —
        # confirm all varying quantities flow through `params`/args.
        cache_key = (id(coeff_fn.__code__), dim, atol, rtol)

        with cls._evolve_solver_cache_lock:
            _solve = cls._evolve_solver_cache.get(cache_key)

        if _solve is None:
            solver = diffrax.Dopri8()
            stepsize_controller = diffrax.PIDController(atol=atol, rtol=rtol)

            @jax.jit
            def _solve(neg_iH, params, t0, t1):
                """Solve dU/dt = f(params,t) * (-iH) * U from t0 to t1."""

                def rhs(t, y, args):
                    return coeff_fn(args, t) * (neg_iH @ y)

                sol = diffrax.diffeqsolve(
                    diffrax.ODETerm(rhs),
                    solver,
                    t0=t0,
                    t1=t1,
                    dt0=None,  # let the controller choose the initial step
                    y0=jnp.eye(dim, dtype=_cdtype()),
                    args=params,
                    stepsize_controller=stepsize_controller,
                    max_steps=4096,
                )

                # sol.ys has shape (1, dim, dim) for SaveAt(t1=True) (default)
                return sol.ys[0]

            with cls._evolve_solver_cache_lock:
                # Double-check to avoid overwriting a concurrent build
                existing = cls._evolve_solver_cache.get(cache_key)
                if existing is not None:
                    _solve = existing
                else:
                    cls._evolve_solver_cache[cache_key] = _solve

        def _apply(coeff_args, T) -> Operation:
            """Evolve under the time-dependent Hamiltonian.

            Args:
                coeff_args: List of parameter sets, one per Hamiltonian term.
                    For static Hamiltonians, ``coeff_args[0]`` is
                    forwarded to ``coeff_fn(params, t)`` as the first argument.
                T: Total evolution time. If scalar, the ODE is solved on
                    ``[0, T]``. If a 2-element array, on ``[T[0], T[1]]``.

            Returns:
                An :class:`Operation` wrapping the computed unitary.
            """
            # PennyLane convention: coeff_args is a list of param-sets,
            # one per term. Single term -> unpack the first.
            params = (
                coeff_args[0] if isinstance(coeff_args, (list, tuple)) else coeff_args
            )

            # Build time span — resolve at Python level to avoid traced branching
            T_arr = jnp.asarray(T, dtype=jnp.float64)
            if T_arr.ndim == 0:
                t0 = jnp.float64(0.0)
                t1 = T_arr
            else:
                t0 = T_arr[0]
                t1 = T_arr[1]

            U = _solve(neg_iH, params, t0, t1)

            return Operation(wires=wires, matrix=U, name=name)

        return _apply

354 

355 

# TODO adjust imports to use classmethods instead
# Backwards-compatible module-level aliases so callers can use the free
# functions without referencing the Yaqsi class directly.
partial_trace = Yaqsi.partial_trace
evolve = Yaqsi.evolve
marginalize_probs = Yaqsi.marginalize_probs
build_parity_observable = Yaqsi.build_parity_observable

361 

362 

363class Script: 

364 """Circuit container and executor backed by pure JAX kernels. 

365 

366 ``Script`` takes a callable *f* representing a quantum circuit. 

367 Within *f*, :class:`~qml_essentials.operations.Operation` objects are 

368 instantiated and automatically recorded onto a tape. The tape is then 

369 simulated using either a statevector or density-matrix kernel depending on 

370 whether noise channels are present. 

371 

372 Attributes: 

373 f: The circuit function whose body instantiates ``Operation`` objects. 

374 _n_qubits: Optionally pre-declared number of qubits. When ``None`` 

375 the qubit count is inferred from the operations recorded on the 

376 tape. 

377 

378 Example: 

379 >>> def circuit(theta): 

380 ... RX(theta, wires=0) 

381 ... PauliZ(wires=1) 

382 >>> script = Script(circuit, n_qubits=2) 

383 >>> result = script.execute(type="expval", obs=[PauliZ(0)]) 

384 """ 

385 

386 def __init__(self, f: Callable, n_qubits: Optional[int] = None) -> None: 

387 """Initialise a Script. 

388 

389 Args: 

390 f: A function whose body instantiates ``Operation`` objects. 

391 Signature: ``f(*args, **kwargs) -> None``. 

392 n_qubits: Number of qubits. If ``None``, inferred from the 

393 operations recorded on the tape. 

394 """ 

395 self.f = f 

396 self._n_qubits = n_qubits 

397 self._jit_cache: dict = {} # keyed on (type, in_axes, arg_shapes, gateError) 

398 

399 @staticmethod 

400 def _estimate_peak_bytes( 

401 n_qubits: int, 

402 batch_size: int, 

403 type: str, 

404 use_density: bool, 

405 n_obs: int = 0, 

406 ) -> int: 

407 """Estimate peak memory (bytes) for a batched simulation. 

408 

409 The estimate accounts for: 

410 

411 - The batched statevector (always needed, even for density). 

412 - The batched output tensor (state / probs / density / expval). 

413 - One gate-tensor temporary per batch element (the einsum buffer). 

414 

415 Observable matrices are **not** counted: they are computed inside 

416 the JIT-compiled function and XLA manages their lifetime (reusing 

417 buffers between observables). Similarly, the outer-product 

418 temporary for pure-circuit density mode is transient within XLA. 

419 

420 Element size is determined dynamically from ``jax.config.x64_enabled``: 

421 when x64 mode is disabled (the JAX default), complex values are 

422 ``complex64`` (8 bytes) and floats are ``float32`` (4 bytes), 

423 halving memory usage compared to the x64 path. 

424 

425 A 1.5× safety factor is applied to cover XLA compiler temporaries, 

426 padding, and other allocations not directly visible to Python. 

427 

428 This is a pure Python arithmetic calculation with no JAX calls — 

429 it adds essentially zero overhead. 

430 

431 Args: 

432 n_qubits: Number of qubits in the circuit. 

433 batch_size: Number of batch elements. 

434 type: Measurement type (``"state"``, ``"probs"``, ``"expval"``, 

435 ``"density"``). 

436 use_density: Whether density-matrix simulation is used. 

437 n_obs: Number of observables (relevant for ``"expval"``). 

438 

439 Returns: 

440 Estimated peak memory in bytes. 

441 """ 

442 dim = 2**n_qubits 

443 # Detect actual element size: JAX silently truncates complex128 

444 # to complex64 when x64 mode is disabled (the default). 

445 elem = 16 if jax.config.x64_enabled else 8 # complex128 vs complex64 

446 real_elem = elem // 2 # float64 vs float32 

447 

448 # Statevector: always allocated during simulation 

449 sv_bytes = batch_size * dim * elem 

450 

451 # Simulation intermediate: when density-matrix simulation is used, 

452 # the full rho (dim × dim) must be held during gate evolution — 

453 # even if the final output is only probs or expval. 

454 # apply_to_density contracts both U and U* against rho, so at least 

455 # two intermediate (dim × dim) buffers are alive simultaneously. 

456 if use_density: 

457 sim_bytes = 2 * batch_size * dim * dim * elem 

458 else: 

459 sim_bytes = 0 # statevector is already counted above 

460 

461 # Output tensor: this is the *returned* array, not the simulation 

462 # intermediate. For probs/expval with density simulation the 

463 # density matrix is reduced to a small output *before* returning, 

464 # so only the reduced output coexists with the next chunk. 

465 if type == "density": 

466 out_bytes = batch_size * dim * dim * elem 

467 elif type == "expval": 

468 out_bytes = batch_size * max(n_obs, 1) * real_elem 

469 elif type == "probs": 

470 out_bytes = batch_size * dim * real_elem 

471 else: # state 

472 out_bytes = batch_size * dim * elem 

473 

474 # Gate temporaries: einsum creates one (2,)*n buffer per batch elem 

475 gate_tmp = batch_size * dim * elem 

476 

477 # Peak = max(simulation phase, output phase). During simulation 

478 # the intermediate + statevector + gate temps are alive. After 

479 # measurement, only the output survives. So peak is whichever 

480 # phase is larger. 

481 sim_peak = sv_bytes + sim_bytes + gate_tmp 

482 out_peak = out_bytes 

483 raw = max(sim_peak, out_peak) 

484 

485 # 1.5× safety factor for XLA compiler temporaries, padding, etc. 

486 return int(raw * 1.5) 

487 

@staticmethod
def _available_memory_bytes() -> int:
    """Return available system memory in bytes.

    Uses ``psutil.virtual_memory().available`` for cross-platform
    support (Linux, macOS, Windows). Falls back to reading
    ``/proc/meminfo`` on Linux, and finally to a conservative 4 GiB
    default if neither approach succeeds.

    Returns:
        Available memory in bytes.
    """
    # Conservative default, returned only if every probe below fails.
    mem = 4 * 1024**3
    # Primary: psutil (works on Linux, macOS, Windows)
    try:
        import psutil

        mem = psutil.virtual_memory().available
    except Exception:
        log.debug("psutil not available. Fallback to /proc/meminfo")

        # Fallback: /proc/meminfo (Linux only). Only probed when psutil
        # is unavailable/failed, so a successful psutil reading is never
        # overwritten (matches the docstring and the log messages).
        try:
            with open("/proc/meminfo", "r") as f:
                for line in f:
                    if line.startswith("MemAvailable:"):
                        mem = int(line.split()[1]) * 1024  # kB → bytes
        except Exception:
            log.debug("Failed to read /proc/meminfo. Falling back to 4 GiB")

    log.debug(f"Available memory: {mem/1024**3:.1f} GB")
    return mem

520 

@staticmethod
def _compute_chunk_size(
    n_qubits: int,
    batch_size: int,
    type: str,
    use_density: bool,
    n_obs: int = 0,
    memory_fraction: float = 0.8,
) -> int:
    """Pick the largest sub-batch size that fits in available memory.

    When the whole batch fits, *batch_size* is returned unchanged (no
    chunking). Otherwise the result is the largest chunk size such that
    one chunk's computation **plus** the full output accumulator stays
    within ``memory_fraction`` of the available RAM.

    The accumulator is the final ``(batch_size, ...)`` result array; it
    must coexist with each active chunk, so its footprint is reserved
    before sizing the chunks. The minimum chunk size is 1 (fully
    serialised).

    Args:
        n_qubits: Number of qubits.
        batch_size: Total batch size.
        type: Measurement type.
        use_density: Whether density-matrix simulation is used.
        n_obs: Number of observables.
        memory_fraction: Fraction of available memory to target
            (default 0.8 = 80%).

    Returns:
        Chunk size (number of batch elements per sub-batch).
    """
    avail = int(Script._available_memory_bytes() * memory_fraction)
    full_est = Script._estimate_peak_bytes(
        n_qubits, batch_size, type, use_density, n_obs
    )

    # Fast path: the whole batch fits — no chunking required.
    if full_est <= avail:
        return batch_size

    dim = 2**n_qubits
    elem = 16 if jax.config.x64_enabled else 8
    real_elem = elem // 2

    # Reserve room for the full output accumulator, which lives alongside
    # every chunk's computation.
    accum_by_type = {
        "density": batch_size * dim * dim * elem,
        "expval": batch_size * max(n_obs, 1) * real_elem,
        "probs": batch_size * dim * real_elem,
    }
    accum_bytes = accum_by_type.get(type, batch_size * dim * elem)
    avail_for_chunks = max(avail - accum_bytes, elem)  # at least 1 element

    # Cost of computing a single batch element.
    per_elem = Script._estimate_peak_bytes(n_qubits, 1, type, use_density, n_obs)
    if per_elem <= 0:
        return batch_size

    chunk = max(1, min(avail_for_chunks // per_elem, batch_size))

    if chunk == 1 and per_elem > avail:
        log.warning(
            f"A single batch element requires ~{per_elem / 1024**3:.2f} GB "
            f"but only ~{avail / 1024**3:.2f} GB is available. "
            f"Proceeding with chunk_size=1 but OOM is possible. "
            f"Consider reducing n_qubits or switching measurement type."
        )

    log.info(
        f"Computation requires ~{full_est / 1024**3:.2f} GB which "
        f"does not fit in ~{avail / 1024**3:.2f} GB. "
        f"Using chunk size {chunk}."
    )
    return chunk

603 

@staticmethod
def _execute_chunked(
    batched_fn: Callable,
    args: tuple,
    in_axes: Tuple,
    batch_size: int,
    chunk_size: int,
) -> jnp.ndarray:
    """Execute a vmapped function in memory-safe chunks.

    Splits the batch dimension into sub-batches of at most *chunk_size*
    elements, runs each through the JIT-compiled *batched_fn*, and
    writes results into a pre-allocated output array.

    Only one chunk's intermediate result is alive at a time: each
    chunk is computed, copied into the output buffer, and then its
    reference is dropped — allowing JAX/XLA to reclaim the memory
    before the next chunk starts. This keeps peak memory at roughly
    ``output_buffer + one_chunk_computation`` rather than the sum of
    all chunk outputs.

    Args:
        batched_fn: A JIT-compiled, vmapped callable.
        args: Full-batch arguments (before slicing).
        in_axes: Per-argument batch axis specification (``None`` marks
            an unbatched argument that is passed through unsliced).
        batch_size: Total number of batch elements.
        chunk_size: Maximum elements per chunk.

    Returns:
        Batched results with the same leading dimension as the
        full batch.
    """
    n_chunks = (batch_size + chunk_size - 1) // chunk_size
    log.debug(
        f"Memory-aware chunking: splitting batch of {batch_size} into "
        f"{n_chunks} chunks of <={chunk_size} elements."
    )

    # Allocated lazily on the first chunk, once the result shape/dtype
    # are known.
    output = None
    for chunk_idx in range(n_chunks):
        start = chunk_idx * chunk_size
        end = min(start + chunk_size, batch_size)
        size = end - start

        # Slice each batched argument along its batch axis
        chunk_args = tuple(
            (
                jax.lax.dynamic_slice_in_dim(a, start, size, axis=ax)
                if ax is not None
                else a
            )
            for a, ax in zip(args, in_axes)
        )

        chunk_result = batched_fn(*chunk_args)

        if output is None:
            # Pre-allocate the full output buffer on first chunk
            out_shape = (batch_size,) + chunk_result.shape[1:]
            output = jnp.zeros(out_shape, dtype=chunk_result.dtype)

        # Copy chunk into the output buffer; the slice assignment
        # creates a new array (JAX arrays are immutable) but the old
        # `output` reference is immediately replaced, letting XLA
        # reclaim it.
        output = output.at[start:end].set(chunk_result)

        # Explicitly drop the chunk reference so XLA can free the
        # chunk's device memory before computing the next one.
        del chunk_result, chunk_args
        # Trigger garbage collection to release device buffers
        # NOTE(review): jax.clear_caches() also clears JAX's compilation
        # caches, so `batched_fn` will be re-traced/re-compiled on the
        # next chunk. Confirm that paying recompilation per chunk is the
        # intended trade-off for freeing device buffers here.
        jax.clear_caches()

    return output

678 

def _record(self, *args, **kwargs) -> List[Operation]:
    """Execute the circuit function and return the recorded operations.

    :func:`~qml_essentials.tape.recording` is used as a context manager,
    which guarantees the tape is torn down even when the circuit
    function raises, and gives nested recordings (e.g. from
    ``_execute_batched``) their own independent tapes.

    Args:
        *args: Positional arguments forwarded to the circuit function.
        **kwargs: Keyword arguments forwarded to the circuit function.

    Returns:
        List of :class:`~qml_essentials.operations.Operation` instances
        in the order they were instantiated.
    """
    with recording() as ops:
        self.f(*args, **kwargs)
    return ops

698 

def pulse_events(self, *args, **kwargs) -> list:
    """Run the circuit and collect pulse events emitted by PulseGates.

    Both tapes are active at once: the normal operation tape (so the
    gates actually execute) and a pulse-event tape that captures
    :class:`~qml_essentials.drawing.PulseEvent` objects from leaf
    pulse gates (RX, RY, RZ, CZ).

    Args:
        *args: Forwarded to the circuit function.
        **kwargs: Forwarded to the circuit function.

    Returns:
        List of :class:`~qml_essentials.drawing.PulseEvent`.
    """
    with pulse_recording() as collected, recording():
        self.f(*args, **kwargs)
    return collected

718 

719 @staticmethod 

720 def _infer_n_qubits(ops: List[Operation], obs: List[Operation]) -> int: 

721 """Infer the number of qubits from a list of operations and observables. 

722 

723 Args: 

724 ops: Gate operations recorded on the tape. 

725 obs: Observable operations used for measurement. 

726 

727 Returns: 

728 The smallest number of qubits that covers all wire indices, i.e. 

729 ``max(all_wires) + 1`` (at least 1). 

730 """ 

731 all_wires: set[int] = set() 

732 for op in ops + obs: 

733 all_wires.update(op.wires) 

734 return max(all_wires) + 1 if all_wires else 1 

735 

@staticmethod
def _simulate_pure(tape: List[Operation], n_qubits: int) -> jnp.ndarray:
    """Statevector simulation kernel.

    Starts from |00…0⟩ and applies every gate on *tape* via tensor
    contraction. The state stays in rank-*n* tensor form ``(2,)*n``
    for the whole gate loop; only the initial and final conversions
    to/from the flat ``(2**n,)`` representation reshape.

    The contraction plan (gate tensor + einsum subscript per gate) is
    built before the loop, so each iteration is a single ``jnp.einsum``
    call with no extra Python dispatch, property access, or cache
    lookups.

    Args:
        tape: Ordered list of gate operations to apply.
        n_qubits: Total number of qubits.

    Returns:
        Statevector of shape ``(2**n_qubits,)``.
    """
    dim = 2**n_qubits

    # Pre-extract gate tensors and subscripts; Barrier ops are skipped.
    plan = [
        (
            op._gate_tensor(len(op.wires)),
            _einsum_subscript(n_qubits, len(op.wires), tuple(op.wires)),
        )
        for op in tape
        if not isinstance(op, Barrier)
    ]

    amp = jnp.zeros(dim, dtype=_cdtype()).at[0].set(1.0).reshape((2,) * n_qubits)
    for tensor, spec in plan:
        amp = jnp.einsum(spec, tensor, amp)
    return amp.reshape(dim)

777 

@staticmethod
def _simulate_mixed(tape: List[Operation], n_qubits: int) -> jnp.ndarray:
    """Density-matrix simulation kernel.

    Starting from \\rho = \\vert 0\\rangle\\langle 0\\vert, folds every
    operation on *tape* through
    :meth:`~qml_essentials.operations.Operation.apply_to_density`
    (\\rho -> U\\rho U† for unitaries, \\Sigma_k K_k \\rho K_k\\dagger
    for Kraus channels). Required for noisy circuits.

    Args:
        tape: Ordered list of gate or channel operations to apply.
        n_qubits: Total number of qubits.

    Returns:
        Density matrix of shape ``(2**n_qubits, 2**n_qubits)``.
    """
    dim = 2**n_qubits
    rho0 = jnp.zeros((dim, dim), dtype=_cdtype()).at[0, 0].set(1.0)
    # Fold the tape through apply_to_density, one operation at a time.
    return reduce(lambda rho, op: op.apply_to_density(rho, n_qubits), tape, rho0)

801 

@staticmethod
def _simulate_and_measure(
    tape: List[Operation],
    n_qubits: int,
    type: str,
    obs: List[Operation],
    use_density: bool,
    shots: Optional[int] = None,
    key: Optional[jnp.ndarray] = None,
) -> jnp.ndarray:
    """Run simulation and measurement in a single dispatch.

    Selects statevector or density-matrix simulation according to
    *use_density*, then applies the matching measurement function —
    keeping the branching in one place instead of duplicating it in
    the single-sample and batched execution paths.

    When *shots* is not ``None`` and *type* is ``"probs"`` or
    ``"expval"``, the exact probability distribution is computed first
    and *shots* samples are drawn from it to produce a noisy estimate.

    Pure-circuit density optimisation — when ``type == "density"``
    but the tape contains no noise channels, the density matrix is
    obtained from a statevector simulation followed by one outer
    product ``\\rho = \\vert\\psi\\rangle\\langle\\psi\\vert``
    instead of evolving the full ``2^n\\times 2^n`` matrix gate by
    gate. This drops the per-gate cost from O(4^n) to O(2^n) — a
    significant speed-up at medium qubit counts (~4x for 5 qubits).

    Args:
        tape: Ordered list of gate/channel operations to apply.
        n_qubits: Total number of qubits.
        type: Measurement type (``"state"``/``"probs"``/``"expval"``/
            ``"density"``).
        obs: Observables for ``"expval"`` measurements.
        use_density: If ``True``, use density-matrix simulation.
        shots: Number of measurement shots. If ``None`` (default),
            exact analytic results are returned.
        key: JAX PRNG key for shot sampling. Required when *shots*
            is not ``None``.

    Returns:
        Measurement result (shape depends on *type*).
    """
    sampling = shots is not None and type in ("probs", "expval")

    if not use_density:
        psi = Script._simulate_pure(tape, n_qubits)
        if sampling:
            return Script._sample_shots(
                jnp.abs(psi) ** 2, n_qubits, type, obs, shots, key
            )
        return Script._measure_state(psi, n_qubits, type, obs)

    # Kraus channels force full density-matrix evolution; a pure tape
    # can instead use the cheaper statevector + outer-product path
    # (O(depth x 2^n) plus a single O(4^n) outer product).
    if any(isinstance(op, KrausChannel) for op in tape):
        rho = Script._simulate_mixed(tape, n_qubits)
    else:
        psi = Script._simulate_pure(tape, n_qubits)
        rho = jnp.outer(psi, jnp.conj(psi))

    if sampling:
        return Script._sample_shots(
            jnp.real(jnp.diag(rho)), n_qubits, type, obs, shots, key
        )
    return Script._measure_density(rho, n_qubits, type, obs)

876 

877 @staticmethod 

878 def _measure_state( 

879 state: jnp.ndarray, 

880 n_qubits: int, 

881 type: str, 

882 obs: List[Operation], 

883 ) -> jnp.ndarray: 

884 """Apply the requested measurement to a pure statevector. 

885 

886 Args: 

887 state: Statevector of shape ``(2**n_qubits,)``. 

888 n_qubits: Total number of qubits. 

889 type: Measurement type — one of ``"state"``, ``"probs"``, 

890 or ``"expval"``. 

891 obs: Observables used when *type* is ``"expval"``. 

892 

893 Returns: 

894 Measurement result whose shape depends on *type*: 

895 

896 - ``"state"`` -> ``(2**n_qubits,)`` 

897 - ``"probs"`` -> ``(2**n_qubits,)`` 

898 - ``"expval"`` -> ``(len(obs),)`` 

899 

900 Raises: 

901 ValueError: If *type* is not a recognised measurement type. 

902 """ 

903 if type == "state": 

904 return state 

905 

906 if type == "probs": 

907 return jnp.abs(state) ** 2 

908 

909 if type == "expval": 

910 # Fast path for single-qubit diagonal observables (PauliZ, etc.) 

911 # where d0, d1 are the diagonal elements of the 2x2 observable. 

912 # This replaces n_obs tensor contractions with a single |ψ|² 

913 # and n_obs reductions over the probability vector. 

914 

915 def _is_single_qubit_diag(ob): 

916 m = ob.__class__._matrix 

917 if m is None or len(ob.wires) != 1: 

918 return False 

919 # Convert to NumPy to ensure concrete boolean evaluation 

920 m_np = np.asarray(m) 

921 return np.allclose(m_np - np.diag(np.diag(m_np)), 0) 

922 

923 all_single_qubit_diag = all(_is_single_qubit_diag(ob) for ob in obs) 

924 

925 if all_single_qubit_diag: 

926 probs = jnp.abs(state) ** 2 

927 psi_t = probs.reshape((2,) * n_qubits) 

928 results = [] 

929 for ob in obs: 

930 q = ob.wires[0] 

931 d = np.real(np.diag(np.asarray(ob.__class__._matrix))) 

932 # Sum probabilities over all axes except qubit q 

933 p_q = jnp.sum( 

934 psi_t, axis=tuple(i for i in range(n_qubits) if i != q) 

935 ) 

936 results.append(d[0] * p_q[0] + d[1] * p_q[1]) 

937 return jnp.array(results) 

938 

939 # General path: stack observable matrices and use a single 

940 # batched matmul instead of a Python loop of tensor contractions. 

941 # O_states[i] = obs[i] |ψ⟩, then ⟨O_i⟩ = Re(⟨ψ|O_states[i]⟩). 

942 obs_mats = jnp.stack( 

943 [ob.lifted_matrix(n_qubits) for ob in obs], axis=0 

944 ) # (n_obs, dim, dim) 

945 # Batched matvec: (n_obs, dim, dim) @ (dim,) -> (n_obs, dim) 

946 O_states = jnp.einsum("oij,j->oi", obs_mats, state) 

947 return jnp.real(jnp.einsum("i,oi->o", jnp.conj(state), O_states)) 

948 

949 raise ValueError(f"Unknown measurement type: {type!r}") 

950 

951 @staticmethod 

952 def _measure_density( 

953 rho: jnp.ndarray, 

954 n_qubits: int, 

955 type: str, 

956 obs: List[Operation], 

957 ) -> jnp.ndarray: 

958 """Apply the requested measurement to a density matrix. 

959 

960 Args: 

961 rho: Density matrix of shape ``(2**n_qubits, 2**n_qubits)``. 

962 n_qubits: Total number of qubits. 

963 type: Measurement type — one of ``"density"``, ``"probs"``, 

964 or ``"expval"``. 

965 obs: Observables used when *type* is ``"expval"``. 

966 

967 Returns: 

968 Measurement result whose shape depends on *type*: 

969 

970 - ``"density"`` -> ``(2**n_qubits, 2**n_qubits)`` 

971 - ``"probs"`` -> ``(2**n_qubits,)`` 

972 - ``"expval"`` -> ``(len(obs),)`` 

973 

974 Raises: 

975 ValueError: If *type* is ``"state"`` (not valid for mixed circuits) 

976 or another unrecognised type. 

977 """ 

978 if type == "density": 

979 return rho 

980 

981 if type == "probs": 

982 return jnp.real(jnp.diag(rho)) 

983 

984 if type == "expval": 

985 # Tr(O \\rho ) = \\Sigma_ij O_ij \\rho _ji 

986 # Stack all observable matrices and compute all traces in one 

987 # batched operation. 

988 obs_mats = jnp.stack( 

989 [ob.lifted_matrix(n_qubits) for ob in obs], axis=0 

990 ) # (n_obs, dim, dim) 

991 # einsum "oij,ji->o" computes Tr(O_o @ \\rho ) for each observable 

992 return jnp.real(jnp.einsum("oij,ji->o", obs_mats, rho)) 

993 

994 raise ValueError( 

995 "Measurement type 'state' is not defined for mixed (noisy) circuits. " 

996 "Use 'density' instead." 

997 ) 

998 

999 @staticmethod 

1000 def _sample_shots( 

1001 probs: jnp.ndarray, 

1002 n_qubits: int, 

1003 type: str, 

1004 obs: List[Operation], 

1005 shots: int, 

1006 key: jnp.ndarray, 

1007 ) -> jnp.ndarray: 

1008 """Convert exact probabilities into shot-sampled results. 

1009 

1010 Draws *shots* samples from the computational-basis probability 

1011 distribution and returns either estimated probabilities or 

1012 shot-based expectation values. 

1013 

1014 Args: 

1015 probs: Exact probability vector of shape ``(2**n_qubits,)``. 

1016 n_qubits: Total number of qubits. 

1017 type: Measurement type — ``"probs"`` or ``"expval"``. 

1018 obs: Observables used when *type* is ``"expval"``. 

1019 shots: Number of measurement shots. 

1020 key: JAX PRNG key for sampling. 

1021 

1022 Returns: 

1023 Shot-sampled measurement result: 

1024 

1025 - ``"probs"`` → ``(2**n_qubits,)`` estimated probabilities. 

1026 - ``"expval"`` → ``(len(obs),)`` estimated expectation values. 

1027 """ 

1028 dim = 2**n_qubits 

1029 

1030 # Draw `shots` samples from the computational basis. 

1031 # Each sample is an integer in [0, dim) representing a basis state. 

1032 samples = jax.random.choice(key, dim, shape=(shots,), p=probs) 

1033 

1034 # Build a histogram of counts for each basis state. 

1035 counts = jnp.zeros(dim, dtype=jnp.int32) 

1036 counts = counts.at[samples].add(1) 

1037 estimated_probs = counts / shots 

1038 

1039 if type == "probs": 

1040 return estimated_probs 

1041 

1042 if type == "expval": 

1043 # For each observable, compute O from the shot-sampled 

1044 # probabilities. For diagonal observables this is exact; 

1045 # for general observables we use Tr(O · diag(estimated_probs)). 

1046 results = [] 

1047 for ob in obs: 

1048 O_mat = ob.lifted_matrix(n_qubits) 

1049 # diagonal approximation from 

1050 # computational basis measurements, which is exact for 

1051 # diagonal observables like PauliZ) 

1052 results.append(jnp.real(jnp.dot(jnp.diag(O_mat), estimated_probs))) 

1053 return jnp.array(results) 

1054 

1055 raise ValueError( 

1056 f"Shot simulation is only supported for 'probs' and 'expval', " 

1057 f"got {type!r}." 

1058 ) 

1059 

1060 def execute( 

1061 self, 

1062 type: str = "expval", 

1063 obs: Optional[List[Operation]] = None, 

1064 *, 

1065 args: tuple = (), 

1066 kwargs: Optional[dict] = None, 

1067 in_axes: Optional[Tuple] = None, 

1068 shots: Optional[int] = None, 

1069 key: Optional[jnp.ndarray] = None, 

1070 ) -> jnp.ndarray: 

1071 """Execute the circuit and return measurement results. 

1072 

1073 Args: 

1074 type: Measurement type. One of: 

1075 

1076 - ``"expval"`` — expectation value ⟨ψ|O|ψ⟩ / Tr(O\\rho ) for 

1077 each observable in *obs*. 

1078 - ``"probs"`` — probability vector of shape ``(2**n,)``. 

1079 - ``"state"`` — raw statevector of shape ``(2**n,)``. 

1080 - ``"density"`` — full density matrix of shape 

1081 ``(2**n, 2**n)``. 

1082 

1083 obs: Observables required when type is ``"expval"``. 

1084 args: Positional arguments forwarded to the circuit function f. 

1085 kwargs: Keyword arguments forwarded to f. 

1086 in_axes: Batch axes for each element of *args*, following the same 

1087 convention as ``jax.vmap``: 

1088 

1089 - An integer selects that axis of the corresponding array as 

1090 the batch dimension. 

1091 - ``None`` broadcasts the argument (no batching). 

1092 

1093 When provided, :meth:`execute` calls ``jax.vmap`` over the 

1094 pure simulation kernel and returns results with a leading 

1095 batch dimension. 

1096 shots: Number of measurement shots for stochastic sampling. 

1097 If ``None`` (default), exact analytic results are returned. 

1098 Only supported for ``"probs"`` and ``"expval"`` measurement 

1099 types. 

1100 key: JAX PRNG key for shot sampling. If ``None`` and *shots* 

1101 is set, a default key ``jax.random.PRNGKey(0)`` is used. 

1102 

1103 Returns: 

1104 Without in_axes: shape determined by type. 

1105 With in_axes: shape ``(B, ...)`` with a leading batch dimension. 

1106 """ 

1107 if obs is None: 

1108 obs = [] 

1109 if kwargs is None: 

1110 kwargs = {} 

1111 if shots is not None and key is None: 

1112 key = jax.random.PRNGKey(0) 

1113 

1114 # Split single/ parallel execution 

1115 # TODO: we might want to unify the n_qubit stuff such that we can eliminate 

1116 # the parameter to this method entirely 

1117 if in_axes is not None: 

1118 return self._execute_batched( 

1119 type=type, 

1120 obs=obs, 

1121 args=args, 

1122 kwargs=kwargs, 

1123 in_axes=in_axes, 

1124 shots=shots, 

1125 key=key, 

1126 ) 

1127 else: 

1128 tape = self._record(*args, **kwargs) 

1129 n_qubits = self._n_qubits or self._infer_n_qubits(tape, obs) 

1130 

1131 has_noise = any(isinstance(op, KrausChannel) for op in tape) 

1132 use_density = type == "density" or has_noise 

1133 

1134 return self._simulate_and_measure( 

1135 tape, 

1136 n_qubits, 

1137 type, 

1138 obs, 

1139 use_density, 

1140 shots=shots, 

1141 key=key, 

1142 ) 

1143 

1144 def _execute_batched( 

1145 self, 

1146 type: str, 

1147 obs: List[Operation], 

1148 args: tuple, 

1149 kwargs: dict, 

1150 in_axes: Tuple, 

1151 shots: Optional[int] = None, 

1152 key: Optional[jnp.ndarray] = None, 

1153 ) -> jnp.ndarray: 

1154 """Vectorise :meth:`execute` over a batch axis using ``jax.vmap``. 

1155 

1156 The circuit function is traced once in Python with scalar slices to 

1157 record the tape, determine ``n_qubits``, and detect noise. The 

1158 resulting pure simulation kernel is then vmapped over the requested 

1159 axes. 

1160 

1161 Memory-aware chunking — before launching the full vmap, the 

1162 method estimates peak memory usage. If the full batch would exceed 

1163 available RAM (with a safety margin), the batch is automatically 

1164 split into sub-batches that fit. Each chunk is vmapped independently 

1165 and the results are concatenated. This trades a small amount of 

1166 wall-clock time for guaranteed execution without OOM. 

1167 

1168 When the full batch fits in memory, there is zero overhead — the 

1169 memory check is a pure Python arithmetic calculation (no JAX calls). 

1170 

1171 Args: 

1172 type: Measurement type (see :meth:`execute`). 

1173 obs: Observables (see :meth:`execute`). 

1174 args: Positional arguments for the circuit function. 

1175 kwargs: Keyword arguments for the circuit function. 

1176 in_axes: One entry per element of *args*. Follows ``jax.vmap`` 

1177 convention: an int gives the batch axis; ``None`` broadcasts. 

1178 shots: Number of measurement shots. If ``None``, exact results. 

1179 key: JAX PRNG key for shot sampling. 

1180 

1181 Returns: 

1182 Batched measurement results of shape ``(B, ...)`` where *B* is the 

1183 size of the batch dimension. 

1184 

1185 Raises: 

1186 ValueError: If ``len(in_axes) != len(args)``. 

1187 

1188 Note: 

1189 The ``jax.vmap`` call at the end of this method is the exact 

1190 boundary to replace with ``jax.shard_map`` for multi-device 

1191 execution:: 

1192 

1193 from jax.sharding import PartitionSpec as P, Mesh 

1194 result = jax.shard_map( 

1195 _single_execute, mesh=mesh, 

1196 in_specs=tuple(P(0) if ax is not None else P() for ax in in_axes), 

1197 out_specs=P(0), 

1198 )(*args) 

1199 """ 

1200 if len(in_axes) != len(args): 

1201 raise ValueError( 

1202 f"in_axes has {len(in_axes)} entries but args has {len(args)}. " 

1203 "Provide one in_axes entry per positional argument." 

1204 ) 

1205 

1206 # Determine batch size from the first batched arg 

1207 batch_size = 1 

1208 for a, ax in zip(args, in_axes): 

1209 if ax is not None: 

1210 batch_size = a.shape[ax] 

1211 break 

1212 

1213 arg_shapes = tuple( 

1214 (a.shape, a.dtype) if hasattr(a, "shape") else type(a) for a in args 

1215 ) 

1216 cache_kwargs = _make_hashable( 

1217 {k: v for k, v in kwargs.items() if not isinstance(v, jnp.ndarray)} 

1218 ) 

1219 

1220 # TODO: we need to fix the dirty class-level `batch_gate_error` hack 

1221 from qml_essentials.gates import UnitaryGates 

1222 

1223 cache_key = ( 

1224 type, 

1225 in_axes, 

1226 arg_shapes, 

1227 cache_kwargs, 

1228 UnitaryGates.batch_gate_error, 

1229 ) 

1230 

1231 # --- Cache-hit fast path (no shots) --- 

1232 cached = self._jit_cache.get(cache_key) 

1233 if cached is not None and shots is None: 

1234 batched_fn, n_qubits, use_density = cached 

1235 # Check if we already determined the chunk size for this 

1236 # exact batch_size (avoids repeated psutil syscalls). 

1237 mem_key = ("_mem", cache_key, batch_size) 

1238 cached_chunk = self._jit_cache.get(mem_key) 

1239 if cached_chunk is not None: 

1240 if cached_chunk >= batch_size: 

1241 return batched_fn(*args) 

1242 return self._execute_chunked( 

1243 batched_fn, args, in_axes, batch_size, cached_chunk 

1244 ) 

1245 chunk_size = self._compute_chunk_size( 

1246 n_qubits, batch_size, type, use_density, len(obs) 

1247 ) 

1248 self._jit_cache[mem_key] = chunk_size 

1249 if chunk_size >= batch_size: 

1250 return batched_fn(*args) 

1251 return self._execute_chunked( 

1252 batched_fn, args, in_axes, batch_size, chunk_size 

1253 ) 

1254 

1255 # Record the tape once using scalar slices of each arg. 

1256 # This determines n_qubits and whether noise channels are present 

1257 # without running the full batch. 

1258 # Note, that we use lax.index_in_dim instead of jnp.take because JAX 

1259 # random key arrays do not support jnp.take. 

1260 # TODO: fix once that is available in JAX 

1261 def _slice_first(a, ax): 

1262 """Take the first element along axis *ax*.""" 

1263 return jax.lax.index_in_dim(a, 0, axis=ax, keepdims=False) 

1264 

1265 scalar_args = tuple( 

1266 _slice_first(a, ax) if ax is not None else a for a, ax in zip(args, in_axes) 

1267 ) 

1268 tape = self._record(*scalar_args, **kwargs) 

1269 n_qubits = self._n_qubits or self._infer_n_qubits(tape, obs) 

1270 has_noise = any(isinstance(op, KrausChannel) for op in tape) 

1271 use_density = type == "density" or has_noise 

1272 

1273 chunk_size = self._compute_chunk_size( 

1274 n_qubits, batch_size, type, use_density, len(obs) 

1275 ) 

1276 

1277 # Re-recording inside this function is necessary: the tape may 

1278 # contain operations whose matrices depend on the batched argument 

1279 # (e.g. RX(theta) where theta is a JAX tracer). jax.vmap traces 

1280 # this function once and generates a single XLA computation for 

1281 # the entire batch. 

1282 if shots is not None and type in ("probs", "expval"): 

1283 # Shot mode: compute exact probabilities, then sample. 

1284 # The shot key is appended as an extra vmapped argument. 

1285 def _single_execute_shots(*single_args_and_key): 

1286 *single_args, shot_key = single_args_and_key 

1287 single_tape = self._record(*single_args, **kwargs) 

1288 exact_result = self._simulate_and_measure( 

1289 single_tape, n_qubits, "probs", obs, use_density 

1290 ) 

1291 return Script._sample_shots( 

1292 exact_result, n_qubits, type, obs, shots, shot_key 

1293 ) 

1294 

1295 shot_keys = jax.random.split(key, batch_size) 

1296 shot_in_axes = in_axes + (0,) # key is batched over axis 0 

1297 shot_args = args + (shot_keys,) 

1298 

1299 # Shot-mode uses a separate cache key (includes shots) 

1300 shot_cache_key = ( 

1301 type, 

1302 "shots", 

1303 shots, 

1304 in_axes, 

1305 arg_shapes, 

1306 UnitaryGates.batch_gate_error, 

1307 ) 

1308 cached_shot = self._jit_cache.get(shot_cache_key) 

1309 if cached_shot is not None: 

1310 batched_fn = cached_shot[0] 

1311 else: 

1312 batched_fn = jax.jit( 

1313 jax.vmap(_single_execute_shots, in_axes=shot_in_axes) 

1314 ) 

1315 self._jit_cache[shot_cache_key] = (batched_fn, n_qubits, use_density) 

1316 

1317 if chunk_size >= batch_size: 

1318 return batched_fn(*shot_args) 

1319 return self._execute_chunked( 

1320 batched_fn, shot_args, shot_in_axes, batch_size, chunk_size 

1321 ) 

1322 

1323 def _single_execute(*single_args): 

1324 single_tape = self._record(*single_args, **kwargs) 

1325 return self._simulate_and_measure( 

1326 single_tape, n_qubits, type, obs, use_density 

1327 ) 

1328 

1329 # Wrapping the vmapped function in jax.jit has two effects: 

1330 # 1. Multi-core utilisation — the JIT-compiled XLA program can 

1331 # use intra-op parallelism to distribute independent SIMD lanes 

1332 # across CPU threads, unlike an eager vmap which runs 

1333 # single-threaded. 

1334 # 2. Compilation caching — subsequent calls with the same input 

1335 # shapes reuse the compiled kernel and skip all Python-level 

1336 # tracing, eliminating the O(B\\times circuit_depth) Python overhead. 

1337 # 

1338 # The compiled function is cached on this Script instance, 

1339 # keyed on (type, in_axes, arg_shapes). Repeated calls with the 

1340 # same structure (e.g. training iterations) skip both Python-level 

1341 # tracing and XLA compilation entirely — they jump straight to the 

1342 # cache check at the top of this method. 

1343 # NOTE: when altering properties of the model, this might not get re-compiled 

1344 # TODO: we might want to rework the data_reupload mechanism at some point 

1345 batched_fn = jax.jit(jax.vmap(_single_execute, in_axes=in_axes)) 

1346 # Cache the function together with metadata for fast-path memory 

1347 # checks on subsequent calls. 

1348 self._jit_cache[cache_key] = (batched_fn, n_qubits, use_density) 

1349 

1350 if chunk_size >= batch_size: 

1351 return batched_fn(*args) 

1352 return self._execute_chunked(batched_fn, args, in_axes, batch_size, chunk_size) 

1353 

1354 def draw( 

1355 self, 

1356 figure: str = "text", 

1357 args: tuple = (), 

1358 kwargs: Optional[dict] = None, 

1359 **draw_kwargs: Any, 

1360 ) -> Union[str, Any]: 

1361 """Draw the quantum circuit. 

1362 

1363 Records the tape by calling the circuit function with the given 

1364 arguments, then renders the resulting gate sequence. 

1365 

1366 Args: 

1367 figure: Rendering backend. One of: 

1368 

1369 - ``"text"`` — ASCII art (returned as a ``str``). 

1370 - ``"mpl"`` — Matplotlib figure (returns ``(fig, ax)``). 

1371 - ``"tikz"`` — LaTeX/TikZ code via ``quantikz`` 

1372 (returns a :class:`TikzFigure`). 

1373 - ``"pulse"`` — Pulse schedule plot (returns ``(fig, axes)``). 

1374 

1375 args: Positional arguments forwarded to the circuit function 

1376 to record the tape. 

1377 kwargs: Keyword arguments forwarded to the circuit function. 

1378 **draw_kwargs: Extra options forwarded to the rendering backend: 

1379 

1380 - ``gate_values`` (bool): Show numeric gate angles instead of 

1381 symbolic \\theta_i labels. Default ``False``. 

1382 - ``show_carrier`` (bool): For ``"pulse"`` mode, overlay the 

1383 carrier-modulated waveform. Default ``False``. 

1384 

1385 Returns: 

1386 Depends on *figure*: 

1387 

1388 - ``"text"`` -> ``str`` 

1389 - ``"mpl"`` -> ``(matplotlib.figure.Figure, matplotlib.axes.Axes)`` 

1390 - ``"tikz"`` -> :class:`TikzFigure` 

1391 - ``"pulse"`` -> ``(matplotlib.figure.Figure, numpy.ndarray)`` 

1392 

1393 Raises: 

1394 ValueError: If *figure* is not one of the supported modes. 

1395 """ 

1396 if figure not in ("text", "mpl", "tikz", "pulse"): 

1397 raise ValueError( 

1398 f"Invalid figure mode: {figure!r}. " 

1399 "Must be 'text', 'mpl', 'tikz', or 'pulse'." 

1400 ) 

1401 

1402 if kwargs is None: 

1403 kwargs = {} 

1404 

1405 if figure == "pulse": 

1406 from qml_essentials.drawing import draw_pulse_schedule 

1407 

1408 events = self.pulse_events(*args, **kwargs) 

1409 n_qubits = ( 

1410 self._n_qubits 

1411 or max((w for ev in events for w in ev.wires), default=0) + 1 

1412 ) 

1413 return draw_pulse_schedule(events, n_qubits, **draw_kwargs) 

1414 

1415 tape = self._record(*args, **kwargs) 

1416 n_qubits = self._n_qubits or self._infer_n_qubits(tape, []) 

1417 

1418 # Filter out noise channels for drawing — they clutter the diagram 

1419 ops = [op for op in tape if not isinstance(op, KrausChannel)] 

1420 

1421 if figure == "text": 

1422 return draw_text(ops, n_qubits) 

1423 elif figure == "mpl": 

1424 return draw_mpl(ops, n_qubits, **draw_kwargs) 

1425 else: # tikz 

1426 return draw_tikz(ops, n_qubits, **draw_kwargs)