Coverage for qml_essentials / script.py: 97%

146 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-06-11 15:51 +0000

1from typing import Any, Callable, List, NamedTuple, Optional, Tuple, Union 

2 

3import jax 

4import jax.numpy as jnp 

5import equinox as eqx 

6 

7from qml_essentials.operations import Operation, KrausChannel 

8from qml_essentials.tape import recording, pulse_recording 

9from qml_essentials.drawing import draw_text, draw_mpl, draw_tikz 

10from qml_essentials.unitary import UnitaryGates 

11from qml_essentials import simulation, memory 

12 

13 

14def _make_hashable(obj): 

15 """Recursively convert an object into a hashable form for cache keys. 

16 

17 - ``dict`` → sorted tuple of ``(key, _make_hashable(value))`` pairs 

18 - ``list`` → tuple of ``_make_hashable(element)`` 

19 - ``set`` → frozenset of ``_make_hashable(element)`` 

20 - everything else is returned as-is (assumed hashable) 

21 """ 

22 if isinstance(obj, dict): 

23 return tuple(sorted((k, _make_hashable(v)) for k, v in obj.items())) 

24 if isinstance(obj, (list, tuple)): 

25 return tuple(_make_hashable(x) for x in obj) 

26 if isinstance(obj, set): 

27 return frozenset(_make_hashable(x) for x in obj) 

28 return obj 

29 

30 

31class _BatchPlan(NamedTuple): 

32 """Compiled artefacts for one batched circuit signature. 

33 

34 Cached in :attr:`Script._jit_cache` keyed on the signature ``cache_key``. 

35 ``batched_fn`` is deliberately the first field so callers (and tests) can 

36 unpack ``batched_fn, *_ = plan``. 

37 

38 Attributes: 

39 batched_fn: ``eqx.filter_jit(jax.vmap(...))`` wrapper; always valid, 

40 including under an outer transform and in shot mode. 

41 plain_fn: AOT-eligible ``jax.jit(jax.vmap(...))`` wrapper, or ``None`` 

42 when no concrete-array fast path applies (non-array argument, shot 

43 mode, or running under a transform). 

44 n_qubits: Qubit count derived from the recorded tape. 

45 use_density: Whether density-matrix simulation is required. 

46 n_ops: Number of operations on the tape (for memory estimation). 

47 """ 

48 

49 batched_fn: Callable 

50 plain_fn: Optional[Callable] 

51 n_qubits: int 

52 use_density: bool 

53 n_ops: int 

54 

55 

class Script:
    """Circuit container and executor backed by pure JAX kernels.

    ``Script`` takes a callable *f* representing a quantum circuit.
    Within *f*, :class:`~qml_essentials.operations.Operation` objects are
    instantiated and automatically recorded onto a tape.  The tape is then
    simulated using either a statevector or density-matrix kernel depending on
    whether noise channels are present.

    The stateless simulation/measurement kernels live in
    :mod:`qml_essentials.simulation` and the memory-estimation/chunking helpers
    in :mod:`qml_essentials.memory`; this class orchestrates recording,
    batching, caching, and drawing around them.

    Attributes:
        f: The circuit function whose body instantiates ``Operation`` objects.
        _n_qubits: Optionally pre-declared number of qubits.  When ``None``
            the qubit count is inferred from the operations recorded on the
            tape.
        _jit_cache: Per-instance cache of batch plans, memoized chunk sizes
            and ahead-of-time compiled executables.

    Example:
        >>> def circuit(theta):
        ...     RX(theta, wires=0)
        ...     PauliZ(wires=1)
        >>> script = Script(circuit, n_qubits=2)
        >>> result = script.execute(type="expval", obs=[PauliZ(0)], args=(0.5,))
    """

    def __init__(self, f: Callable[..., None], n_qubits: Optional[int] = None) -> None:
        """Initialise a Script.

        Args:
            f: A function whose body instantiates ``Operation`` objects.
                Signature: ``f(*args, **kwargs) -> None``.
            n_qubits: Number of qubits.  If ``None``, inferred from the
                operations recorded on the tape.
        """
        self.f = f
        self._n_qubits = n_qubits
        # Three kinds of entries share this dict:
        #   cache_key                       -> _BatchPlan
        #   ("_mem", cache_key, batch_size) -> memoized chunk size
        #   ("_aot", cache_key, batch_size) -> AOT-compiled executable
        # where cache_key is built in :meth:`_execute_batched`.
        self._jit_cache: dict = {}

    def _record(self, *args, **kwargs) -> List[Operation]:
        """Run the circuit function and collect the recorded operations.

        Uses :func:`~qml_essentials.tape.recording` as a context manager so
        that the tape is always cleaned up — even if the circuit function
        raises — and nested recordings (e.g. from ``_execute_batched``) each
        get their own independent tape.

        Args:
            *args: Positional arguments forwarded to the circuit function.
            **kwargs: Keyword arguments forwarded to the circuit function.

        Returns:
            List of :class:`~qml_essentials.operations.Operation` instances in
            the order they were instantiated.
        """
        with recording() as tape:
            self.f(*args, **kwargs)
        return tape

    def pulse_events(self, *args, **kwargs) -> list:
        """Run the circuit and collect pulse events emitted by PulseGates.

        Activates both the normal operation tape (so gates execute) and
        a pulse-event tape that captures
        :class:`~qml_essentials.drawing.PulseEvent` objects from leaf
        pulse gates (RX, RY, RZ, CZ).

        Args:
            *args (Any): Forwarded to the circuit function.
            **kwargs (Any): Forwarded to the circuit function.

        Returns:
            List of :class:`~qml_essentials.drawing.PulseEvent`.
        """
        with pulse_recording() as events:
            # The operation tape is opened only so that gates can record
            # themselves; its contents are discarded.
            with recording():
                self.f(*args, **kwargs)
        return events

    def execute(
        self,
        type: str = "expval",
        obs: Optional[List[Operation]] = None,
        *,
        args: tuple = (),
        kwargs: Optional[dict] = None,
        in_axes: Optional[Tuple] = None,
        shots: Optional[int] = None,
        key: Optional[jnp.ndarray] = None,
    ) -> jnp.ndarray:
        """Execute the circuit and return measurement results.

        Args:
            type: Measurement type.  One of:

                - ``"expval"`` — expectation value
                  \\langle\\psi|O|\\psi\\rangle / Tr(O\\rho )
                  for each observable in *obs*.
                - ``"probs"`` — probability vector of shape ``(2**n,)``.
                - ``"state"`` — raw statevector of shape ``(2**n,)``.
                - ``"density"`` — full density matrix of shape
                  ``(2**n, 2**n)``.

            obs: Observables required when type is ``"expval"``.
            args: Positional arguments forwarded to the circuit function f.
            kwargs: Keyword arguments forwarded to f.
            in_axes: Batch axes for each element of *args*, following the same
                convention as ``jax.vmap``:

                - An integer selects that axis of the corresponding array as
                  the batch dimension.
                - ``None`` broadcasts the argument (no batching).

                When provided, :meth:`execute` calls ``jax.vmap`` over the
                pure simulation kernel and returns results with a leading
                batch dimension.
            shots: Number of measurement shots for stochastic sampling.
                If ``None`` (default), exact analytic results are returned.
                Only supported for ``"probs"`` and ``"expval"`` measurement
                types.
            key: JAX PRNG key for shot sampling.  If ``None`` and *shots*
                is set, a default key ``jax.random.PRNGKey(0)`` is used.

        Returns:
            Without in_axes: shape determined by type.
            With in_axes: shape ``(B, ...)`` with a leading batch dimension.

        Raises:
            ValueError: In batched mode, if ``len(in_axes) != len(args)``.
        """
        if obs is None:
            obs = []
        if kwargs is None:
            kwargs = {}
        if shots is not None and key is None:
            key = jax.random.PRNGKey(0)

        # Split single/ parallel execution
        # TODO: we might want to unify the n_qubit stuff such that we can eliminate
        #   the parameter to this method entirely
        if in_axes is not None:
            return self._execute_batched(
                type=type,
                obs=obs,
                args=args,
                kwargs=kwargs,
                in_axes=in_axes,
                shots=shots,
                key=key,
            )

        tape = self._record(*args, **kwargs)
        n_qubits = self._n_qubits or simulation.infer_n_qubits(tape, obs)
        use_density = simulation.uses_density(tape, type)

        return simulation.simulate_and_measure(
            tape,
            n_qubits,
            type,
            obs,
            use_density,
            shots=shots,
            key=key,
        )

    @staticmethod
    def _args_contain_tracer(args: tuple) -> bool:
        """Return ``True`` if any leaf of *args* is a JAX tracer.

        When :meth:`execute` runs under an outer transform (``jax.grad``,
        ``jax.jacrev``, an enclosing ``jax.jit``/``vmap``) the positional
        arguments are tracers rather than concrete arrays.  The tracer-tolerant
        ``eqx.filter_jit`` wrapper is still reused in that case, but the
        concrete-only fast path — the ahead-of-time-compiled XLA executable —
        is invalid for tracers and must be skipped.
        """
        return any(
            isinstance(x, jax.core.Tracer) for x in jax.tree_util.tree_leaves(args)
        )

    @staticmethod
    def _batch_size(args: tuple, in_axes: Tuple) -> int:
        """Size of the batch dimension, read from the first batched argument.

        Returns ``1`` when no entry of *in_axes* selects a batch axis.
        """
        for a, ax in zip(args, in_axes):
            if ax is not None:
                return a.shape[ax]
        return 1

    @staticmethod
    def _slice_first(a: Any, ax: int) -> Any:
        """Take the first element along axis *ax*.

        Uses ``jax.lax.index_in_dim`` rather than ``jnp.take`` because JAX
        random-key arrays do not support ``jnp.take``.
        """
        # TODO: fix once that is available in JAX
        return jax.lax.index_in_dim(a, 0, axis=ax, keepdims=False)

    @staticmethod
    def _scalar_args(args: tuple, in_axes: Tuple) -> tuple:
        """Replace every batched argument by its first slice along the batch axis.

        Arguments whose ``in_axes`` entry is ``None`` are passed through
        unchanged.
        """
        return tuple(
            Script._slice_first(a, ax) if ax is not None else a
            for a, ax in zip(args, in_axes)
        )

    def _record_metadata(
        self, scalar_args: tuple, kwargs: dict, obs: List[Operation], type: str
    ) -> Tuple[int, bool, int]:
        """Trace the tape from scalar slices to derive batch-invariant metadata.

        Recording once with scalar slices determines ``n_qubits`` and whether
        noise channels are present (forcing density-matrix simulation) without
        running the full batch.

        Args:
            scalar_args: Positional circuit arguments with the batch axis
                already removed.
            kwargs: Keyword arguments for the circuit function.
            obs: Observables, used when inferring the qubit count.
            type: Measurement type (see :meth:`execute`).

        Returns:
            ``(n_qubits, use_density, n_ops)``.
        """
        tape = self._record(*scalar_args, **kwargs)
        n_qubits = self._n_qubits or simulation.infer_n_qubits(tape, obs)
        use_density = simulation.uses_density(tape, type)
        return n_qubits, use_density, len(tape)

    def _build_plan(
        self,
        type: str,
        obs: List[Operation],
        args: tuple,
        kwargs: dict,
        in_axes: Tuple,
    ) -> _BatchPlan:
        """Trace the circuit once and build the cacheable execution plan.

        Records the tape from scalar slices of *args* (to derive
        ``n_qubits``/noise), then builds the vmapped ``eqx.filter_jit``
        wrapper.  When every positional argument is array-like (so plain
        ``jax.jit`` — which has no static-argument handling — is valid) an
        AOT-eligible plain ``jax.jit`` wrapper is built too; :meth:`_dispatch`
        lowers and compiles it lazily per batch size, and only with concrete
        args (the AOT path is gated off under a transform by the caller).

        Returns:
            The :class:`_BatchPlan` for this signature (not yet cached).
        """
        n_qubits, use_density, n_ops = self._record_metadata(
            self._scalar_args(args, in_axes), kwargs, obs, type
        )

        # Re-recording inside this closure is necessary: tape operations may
        # have matrices that depend on the batched argument (e.g. RX(theta)
        # with theta a tracer).  jax.vmap traces this once into a single XLA
        # computation spanning the whole batch.
        def _single_execute(*single_args):
            single_tape = self._record(*single_args, **kwargs)
            return simulation.simulate_and_measure(
                single_tape, n_qubits, type, obs, use_density
            )

        # Wrapping the vmapped function in eqx.filter_jit: (1) treats non-array
        # arguments as static, so circuit signatures mixing arrays and Python
        # values work; (2) lets the XLA program use intra-op CPU parallelism;
        # (3) caches compilation across calls with the same input shapes.
        # NOTE: when altering properties of the model, this might not get
        # re-compiled.
        # TODO: we might want to rework the data_reupload mechanism at some point
        batched_fn = eqx.filter_jit(jax.vmap(_single_execute, in_axes=in_axes))

        # AOT eligibility is a structural property of the signature: plain
        # ``jax.jit`` has no static-argument handling, so it is valid only when
        # every positional argument is array-like.  ``hasattr(a, "shape")`` is
        # true for concrete arrays, numpy arrays, and tracers, but false for
        # Python statics (str/None/dict).  The lower+compile happens lazily in
        # :meth:`_dispatch` and only with concrete args; under a transform its
        # use is gated off by the caller.
        plain_fn = None
        if all(hasattr(a, "shape") for a in args):
            plain_fn = jax.jit(jax.vmap(_single_execute, in_axes=in_axes))

        return _BatchPlan(batched_fn, plain_fn, n_qubits, use_density, n_ops)

    def _chunk_size(
        self,
        cache_key: tuple,
        plan: _BatchPlan,
        type: str,
        n_obs: int,
        batch_size: int,
    ) -> int:
        """Largest batch chunk that fits in memory, memoized per batch size.

        The result is cached under ``("_mem", cache_key, batch_size)`` to avoid
        repeated ``psutil`` syscalls across a tight repeated-call loop.

        Args:
            cache_key: Signature key of *plan* in :attr:`_jit_cache`.
            plan: The batch plan providing ``n_qubits``/``use_density``/``n_ops``.
            type: Measurement type (see :meth:`execute`).
            n_obs: Number of observables.
            batch_size: Size of the full batch.

        Returns:
            The chunk size reported by
            :func:`~qml_essentials.memory.compute_chunk_size`.
        """
        mem_key = ("_mem", cache_key, batch_size)
        chunk_size = self._jit_cache.get(mem_key)
        if chunk_size is None:
            chunk_size = memory.compute_chunk_size(
                plan.n_qubits,
                batch_size,
                type,
                plan.use_density,
                n_obs,
                n_ops=plan.n_ops,
            )
            self._jit_cache[mem_key] = chunk_size
        return chunk_size

    def _dispatch(
        self,
        aot_key: Optional[tuple],
        batched_fn: Callable,
        plain_fn: Optional[Callable],
        args: tuple,
        in_axes: Tuple,
        batch_size: int,
        chunk_size: int,
    ) -> jnp.ndarray:
        """Run a built plan through the leanest applicable path.

        - ``chunk_size < batch_size``: the full batch would not fit in memory,
          so execute it in memory-safe sub-batches via
          :func:`~qml_essentials.memory.execute_chunked`.
        - otherwise, when an AOT-eligible ``plain_fn`` exists, ahead-of-time
          lower+compile the vmapped kernel to an XLA executable (cached per
          ``aot_key``) and call it directly.  This skips both the per-call
          pytree partition/combine of :func:`eqx.filter_jit` and its
          just-in-time cache-key recomputation; for small circuits in a tight
          loop that dispatch overhead, not the XLA compute, dominates.
        - otherwise fall back to ``batched_fn`` (no ``plain_fn``: a non-array
          argument, shot mode, or running under a transform).

        Args:
            aot_key: Cache key for the compiled executable; only read when
                *plain_fn* is not ``None``.
            batched_fn: The ``eqx.filter_jit`` wrapper of the plan.
            plain_fn: The AOT-eligible wrapper, or ``None`` to disable the
                fast path.
            args: Positional arguments, including the batch axis.
            in_axes: Batch axes matching *args*.
            batch_size: Size of the full batch.
            chunk_size: Largest sub-batch that fits in memory.
        """
        if chunk_size < batch_size:
            return memory.execute_chunked(
                batched_fn,
                args,
                in_axes,
                batch_size,
                chunk_size,
                clear_caches=memory.CLEAR_CACHES_BETWEEN_CHUNKS,
            )
        if plain_fn is None:
            return batched_fn(*args)
        compiled = self._jit_cache.get(aot_key)
        if compiled is None:
            compiled = plain_fn.lower(*args).compile()
            self._jit_cache[aot_key] = compiled
        return compiled(*args)

    def _execute_batched(
        self,
        type: str,
        obs: List[Operation],
        args: tuple,
        kwargs: dict,
        in_axes: Tuple,
        shots: Optional[int] = None,
        key: Optional[jnp.ndarray] = None,
    ) -> jnp.ndarray:
        """Vectorise :meth:`execute` over a batch axis using ``jax.vmap``.

        The circuit function is traced once in Python with scalar slices to
        record the tape, determine ``n_qubits``, and detect noise.  The
        resulting pure simulation kernel is then vmapped over the requested
        axes.

        Memory-aware chunking — before launching the full vmap, the
        method estimates peak memory usage.  If the full batch would exceed
        available RAM (with a safety margin), the batch is automatically
        split into sub-batches that fit.  Each chunk is vmapped independently
        and the results are concatenated.  This trades a small amount of
        wall-clock time for guaranteed execution without OOM.

        The chunk size is memoized per signature and batch size, so repeated
        calls skip the memory estimate.

        Args:
            type: Measurement type (see :meth:`execute`).
            obs: Observables (see :meth:`execute`).
            args: Positional arguments for the circuit function.
            kwargs: Keyword arguments for the circuit function.
            in_axes: One entry per element of *args*.  Follows ``jax.vmap``
                convention: an int gives the batch axis; ``None`` broadcasts.
            shots: Number of measurement shots.  If ``None``, exact results.
                Shot sampling is applied only for ``"probs"`` and
                ``"expval"``.
            key: JAX PRNG key for shot sampling.  If ``None`` while shot
                sampling is active, ``jax.random.PRNGKey(0)`` is used.

        Returns:
            Batched measurement results of shape ``(B, ...)`` where *B* is the
            size of the batch dimension.

        Raises:
            ValueError: If ``len(in_axes) != len(args)``.

        Note:
            The ``jax.vmap`` call in :meth:`_build_plan` is the exact
            boundary to replace with ``jax.shard_map`` for multi-device
            execution::

                from jax.sharding import PartitionSpec as P, Mesh
                result = jax.shard_map(
                    _single_execute, mesh=mesh,
                    in_specs=tuple(
                        P(0) if ax is not None else P() for ax in in_axes
                    ),
                    out_specs=P(0),
                )(*args)
        """
        if len(in_axes) != len(args):
            raise ValueError(
                f"in_axes has {len(in_axes)} entries but args has {len(args)}. "
                "Provide one in_axes entry per positional argument."
            )
        # Normalise to a tuple: in_axes takes part in the cache keys (so it
        # must be hashable) and shot mode extends it by tuple concatenation.
        in_axes = tuple(in_axes)

        batch_size = self._batch_size(args, in_axes)

        # Running under an outer JAX transform (e.g. ``jax.jacrev``) makes
        # ``args`` tracers.  The tracer-tolerant ``batched_fn`` wrapper is still
        # cached and reused (see exact-mode dispatch below); only the AOT
        # ``plain_fn`` executable is gated off, as it cannot accept tracers.
        in_transform = self._args_contain_tracer(args)

        # The parameter ``type`` shadows the builtin inside this method, so
        # the class of a non-array argument is read via ``__class__``; calling
        # ``type(a)`` here would try to call the measurement-type string.
        arg_shapes = tuple(
            (a.shape, a.dtype) if hasattr(a, "shape") else a.__class__ for a in args
        )
        # TODO: we need to fix the dirty class-level `batch_gate_error` hack.
        # It is a global toggle that changes the compiled circuit, so it has to
        # take part in every cache key.
        gate_error = UnitaryGates.batch_gate_error

        # Non-array kwargs are captured by the plan's closure, so they have to
        # be part of the key in both shot and exact mode.
        # NOTE(review): the closures also capture ``obs`` and any array-valued
        # kwargs of the call that built the plan, yet neither takes part in
        # the key — a later call differing only in those would appear to reuse
        # the first call's values.  Confirm callers never vary them.
        cache_kwargs = _make_hashable(
            {k: v for k, v in kwargs.items() if not isinstance(v, jnp.ndarray)}
        )

        # --- Shot mode: compute exact probabilities, then sample. ---
        if shots is not None and type in ("probs", "expval"):
            if key is None:
                # Same default as documented on :meth:`execute`.
                key = jax.random.PRNGKey(0)

            shot_cache_key = (
                type,
                "shots",
                shots,
                in_axes,
                arg_shapes,
                cache_kwargs,
                gate_error,
            )
            shot_in_axes = in_axes + (0,)  # shot key batched over axis 0
            shot_args = args + (jax.random.split(key, batch_size),)

            plan = self._jit_cache.get(shot_cache_key)
            if plan is None:
                n_qubits, use_density, n_ops = self._record_metadata(
                    self._scalar_args(args, in_axes), kwargs, obs, type
                )

                # Re-recording inside the closure lets jax.vmap trace the whole
                # batch into one XLA program; the shot key is the extra vmapped
                # argument.
                def _single_execute_shots(*single_args_and_key):
                    *single_args, shot_key = single_args_and_key
                    single_tape = self._record(*single_args, **kwargs)
                    exact_result = simulation.simulate_and_measure(
                        single_tape, n_qubits, "probs", obs, use_density
                    )
                    return simulation.sample_shots(
                        exact_result, n_qubits, type, obs, shots, shot_key
                    )

                batched_fn = eqx.filter_jit(
                    jax.vmap(_single_execute_shots, in_axes=shot_in_axes)
                )
                plan = _BatchPlan(batched_fn, None, n_qubits, use_density, n_ops)
                self._jit_cache[shot_cache_key] = plan

            chunk_size = self._chunk_size(
                shot_cache_key, plan, type, len(obs), batch_size
            )
            # Shot mode never uses the AOT fast path (plain_fn is None).
            return self._dispatch(
                None,
                plan.batched_fn,
                None,
                shot_args,
                shot_in_axes,
                batch_size,
                chunk_size,
            )

        # --- Exact mode: reuse the cached plan or build it on a miss. ---
        cache_key = (type, in_axes, arg_shapes, cache_kwargs, gate_error)

        # The cached ``batched_fn`` (eqx.filter_jit wrapper) is reused across
        # calls including under an outer transform, so JAX can hit its
        # aval-keyed trace cache instead of re-tracing the circuit every call.
        # Only the AOT ``plain_fn`` (a compiled executable) is invalid for
        # tracers; its use is gated below by ``in_transform``.
        plan = self._jit_cache.get(cache_key)
        if plan is None:
            plan = self._build_plan(type, obs, args, kwargs, in_axes)
            self._jit_cache[cache_key] = plan

        chunk_size = self._chunk_size(cache_key, plan, type, len(obs), batch_size)
        return self._dispatch(
            ("_aot", cache_key, batch_size),
            plan.batched_fn,
            None if in_transform else plan.plain_fn,
            args,
            in_axes,
            batch_size,
            chunk_size,
        )

    def draw(
        self,
        figure: str = "text",
        args: tuple = (),
        kwargs: Optional[dict] = None,
        **draw_kwargs: Any,
    ) -> Union[str, Any]:
        """Draw the quantum circuit.

        Records the tape by calling the circuit function with the given
        arguments, then renders the resulting gate sequence.

        Args:
            figure: Rendering backend.  One of:

                - ``"text"`` — ASCII art (returned as a ``str``).
                - ``"mpl"`` — Matplotlib figure (returns ``(fig, ax)``).
                - ``"tikz"`` — LaTeX/TikZ code via ``quantikz``
                  (returns a :class:`TikzFigure`).
                - ``"pulse"`` — Pulse schedule plot (returns ``(fig, axes)``).

            args: Positional arguments forwarded to the circuit function
                to record the tape.
            kwargs: Keyword arguments forwarded to the circuit function.
            **draw_kwargs: Extra options forwarded to the rendering backend
                (not forwarded in ``"text"`` mode):

                - ``gate_values`` (bool): Show numeric gate angles instead of
                  symbolic \\theta_i labels.  Default ``False``.
                - ``show_carrier`` (bool): For ``"pulse"`` mode, overlay the
                  carrier-modulated waveform.  Default ``False``.

        Returns:
            Depends on *figure*:

            - ``"text"`` -> ``str``
            - ``"mpl"`` -> ``(matplotlib.figure.Figure, matplotlib.axes.Axes)``
            - ``"tikz"`` -> :class:`TikzFigure`
            - ``"pulse"`` -> ``(matplotlib.figure.Figure, numpy.ndarray)``

        Raises:
            ValueError: If *figure* is not one of the supported modes.
        """
        if figure not in ("text", "mpl", "tikz", "pulse"):
            raise ValueError(
                f"Invalid figure mode: {figure!r}. "
                "Must be 'text', 'mpl', 'tikz', or 'pulse'."
            )

        if kwargs is None:
            kwargs = {}

        if figure == "pulse":
            from qml_essentials.drawing import draw_pulse_schedule

            events = self.pulse_events(*args, **kwargs)
            # Without a declared qubit count, use the highest wire index seen
            # in the events (a single qubit when there are no events).
            n_qubits = (
                self._n_qubits
                or max((w for ev in events for w in ev.wires), default=0) + 1
            )
            return draw_pulse_schedule(events, n_qubits, **draw_kwargs)

        tape = self._record(*args, **kwargs)
        n_qubits = self._n_qubits or simulation.infer_n_qubits(tape, [])

        # Filter out noise channels for drawing — they clutter the diagram
        ops = [op for op in tape if not isinstance(op, KrausChannel)]

        if figure == "text":
            return draw_text(ops, n_qubits)
        elif figure == "mpl":
            return draw_mpl(ops, n_qubits, **draw_kwargs)
        else:  # tikz
            return draw_tikz(ops, n_qubits, **draw_kwargs)