Coverage for qml_essentials / script.py: 97%
146 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-06-11 15:51 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-06-11 15:51 +0000
1from typing import Any, Callable, List, NamedTuple, Optional, Tuple, Union
3import jax
4import jax.numpy as jnp
5import equinox as eqx
7from qml_essentials.operations import Operation, KrausChannel
8from qml_essentials.tape import recording, pulse_recording
9from qml_essentials.drawing import draw_text, draw_mpl, draw_tikz
10from qml_essentials.unitary import UnitaryGates
11from qml_essentials import simulation, memory
14def _make_hashable(obj):
15 """Recursively convert an object into a hashable form for cache keys.
17 - ``dict`` → sorted tuple of ``(key, _make_hashable(value))`` pairs
18 - ``list`` → tuple of ``_make_hashable(element)``
19 - ``set`` → frozenset of ``_make_hashable(element)``
20 - everything else is returned as-is (assumed hashable)
21 """
22 if isinstance(obj, dict):
23 return tuple(sorted((k, _make_hashable(v)) for k, v in obj.items()))
24 if isinstance(obj, (list, tuple)):
25 return tuple(_make_hashable(x) for x in obj)
26 if isinstance(obj, set):
27 return frozenset(_make_hashable(x) for x in obj)
28 return obj
class _BatchPlan(NamedTuple):
    """Compiled artefacts for one batched circuit signature.

    Cached in :attr:`Script._jit_cache` keyed on the signature ``cache_key``.
    ``batched_fn`` is deliberately the first field so callers (and tests) can
    unpack ``batched_fn, *_ = plan``.

    Attributes:
        batched_fn: ``eqx.filter_jit(jax.vmap(...))`` wrapper; always valid,
            including under an outer transform and in shot mode.
        plain_fn: AOT-eligible ``jax.jit(jax.vmap(...))`` wrapper, or ``None``
            when no concrete-array fast path applies (non-array argument, shot
            mode, or running under a transform).
        n_qubits: Qubit count derived from the recorded tape.
        use_density: Whether density-matrix simulation is required.
        n_ops: Number of operations on the tape (for memory estimation).
    """

    # Field order is part of the contract: positional construction and tuple
    # unpacking both rely on it.
    batched_fn: Callable
    plain_fn: Optional[Callable]
    n_qubits: int
    use_density: bool
    n_ops: int
class Script:
    """Circuit container and executor backed by pure JAX kernels.

    ``Script`` takes a callable *f* representing a quantum circuit.
    Within *f*, :class:`~qml_essentials.operations.Operation` objects are
    instantiated and automatically recorded onto a tape.  The tape is then
    simulated using either a statevector or density-matrix kernel depending on
    whether noise channels are present.

    The stateless simulation/measurement kernels live in
    :mod:`qml_essentials.simulation` and the memory-estimation/chunking helpers
    in :mod:`qml_essentials.memory`; this class orchestrates recording,
    batching, caching, and drawing around them.

    Attributes:
        f: The circuit function whose body instantiates ``Operation`` objects.
        _n_qubits: Optionally pre-declared number of qubits.  When ``None``
            the qubit count is inferred from the operations recorded on the
            tape.
        _jit_cache: Per-instance cache of batch plans, memoized chunk sizes
            and ahead-of-time compiled executables.

    Example:
        >>> def circuit(theta):
        ...     RX(theta, wires=0)
        ...     PauliZ(wires=1)
        >>> script = Script(circuit, n_qubits=2)
        >>> result = script.execute(type="expval", obs=[PauliZ(0)])
    """

    def __init__(self, f: Callable[..., None], n_qubits: Optional[int] = None) -> None:
        """Initialise a Script.

        Args:
            f: A function whose body instantiates ``Operation`` objects.
                Signature: ``f(*args, **kwargs) -> None``.
            n_qubits: Number of qubits.  If ``None``, inferred from the
                operations recorded on the tape.
        """
        self.f = f
        self._n_qubits = n_qubits
        # A single dict holds several kinds of entries, told apart by key shape:
        #   (type, in_axes, arg_shapes, cache_kwargs, gate_error)
        #       -> _BatchPlan (exact mode)
        #   (type, "shots", shots, in_axes, arg_shapes, cache_kwargs, gate_error)
        #       -> _BatchPlan (shot mode)
        #   ("_mem", cache_key, batch_size) -> memoized chunk size
        #   ("_aot", cache_key, batch_size) -> AOT-compiled executable
        self._jit_cache: dict = {}

    def _record(self, *args, **kwargs) -> List[Operation]:
        """Run the circuit function and collect the recorded operations.

        Uses :func:`~qml_essentials.tape.recording` as a context manager so
        that the tape is always cleaned up — even if the circuit function
        raises — and nested recordings (e.g. from ``_execute_batched``) each
        get their own independent tape.

        Args:
            *args: Positional arguments forwarded to the circuit function.
            **kwargs: Keyword arguments forwarded to the circuit function.

        Returns:
            List of :class:`~qml_essentials.operations.Operation` instances in
            the order they were instantiated.
        """
        with recording() as tape:
            self.f(*args, **kwargs)
        return tape

    def pulse_events(self, *args, **kwargs) -> list:
        """Run the circuit and collect pulse events emitted by PulseGates.

        Activates both the normal operation tape (so gates execute) and
        a pulse-event tape that captures
        :class:`~qml_essentials.drawing.PulseEvent` objects from leaf
        pulse gates (RX, RY, RZ, CZ).

        Args:
            *args (Any): Forwarded to the circuit function.
            **kwargs (Any): Forwarded to the circuit function.

        Returns:
            List of :class:`~qml_essentials.drawing.PulseEvent`.
        """
        # The pulse tape is entered first and exited last; the operation tape
        # is only needed while the circuit function runs.
        with pulse_recording() as events:
            with recording():
                self.f(*args, **kwargs)
        return events

    def execute(
        self,
        type: str = "expval",
        obs: Optional[List[Operation]] = None,
        *,
        args: tuple = (),
        kwargs: Optional[dict] = None,
        in_axes: Optional[Tuple] = None,
        shots: Optional[int] = None,
        key: Optional[jnp.ndarray] = None,
    ) -> jnp.ndarray:
        """Execute the circuit and return measurement results.

        Args:
            type: Measurement type.  One of:

                - ``"expval"`` — expectation value
                  \\langle\\psi|O|\\psi\\rangle / Tr(O\\rho )
                  for each observable in *obs*.
                - ``"probs"`` — probability vector of shape ``(2**n,)``.
                - ``"state"`` — raw statevector of shape ``(2**n,)``.
                - ``"density"`` — full density matrix of shape
                  ``(2**n, 2**n)``.

            obs: Observables required when type is ``"expval"``.
            args: Positional arguments forwarded to the circuit function f.
            kwargs: Keyword arguments forwarded to f.
            in_axes: Batch axes for each element of *args*, following the same
                convention as ``jax.vmap``:

                - An integer selects that axis of the corresponding array as
                  the batch dimension.
                - ``None`` broadcasts the argument (no batching).

                When provided, :meth:`execute` calls ``jax.vmap`` over the
                pure simulation kernel and returns results with a leading
                batch dimension.
            shots: Number of measurement shots for stochastic sampling.
                If ``None`` (default), exact analytic results are returned.
                Only supported for ``"probs"`` and ``"expval"`` measurement
                types.
            key: JAX PRNG key for shot sampling.  If ``None`` and *shots*
                is set, a default key ``jax.random.PRNGKey(0)`` is used.

        Returns:
            Without in_axes: shape determined by type.
            With in_axes: shape ``(B, ...)`` with a leading batch dimension.
        """
        if obs is None:
            obs = []
        if kwargs is None:
            kwargs = {}
        if shots is not None and key is None:
            key = jax.random.PRNGKey(0)

        # Split single/ parallel execution
        # TODO: we might want to unify the n_qubit stuff such that we can eliminate
        # the parameter to this method entirely
        if in_axes is not None:
            return self._execute_batched(
                type=type,
                obs=obs,
                args=args,
                kwargs=kwargs,
                in_axes=in_axes,
                shots=shots,
                key=key,
            )

        tape = self._record(*args, **kwargs)
        n_qubits = self._n_qubits or simulation.infer_n_qubits(tape, obs)
        use_density = simulation.uses_density(tape, type)

        return simulation.simulate_and_measure(
            tape,
            n_qubits,
            type,
            obs,
            use_density,
            shots=shots,
            key=key,
        )

    @staticmethod
    def _args_contain_tracer(args: tuple) -> bool:
        """Return ``True`` if any leaf of *args* is a JAX tracer.

        When :meth:`execute` runs under an outer transform (``jax.grad``,
        ``jax.jacrev``, an enclosing ``jax.jit``/``vmap``) the positional
        arguments are tracers rather than concrete arrays.  The tracer-tolerant
        ``eqx.filter_jit`` wrapper is still reused in that case (its closure
        captures only concrete metadata), but the concrete-only fast path — the
        ahead-of-time-compiled XLA executable — is invalid for tracers and must
        be skipped.
        """
        return any(
            isinstance(x, jax.core.Tracer) for x in jax.tree_util.tree_leaves(args)
        )

    @staticmethod
    def _arg_signature(args: tuple) -> tuple:
        """Describe each positional argument for use in a cache key.

        Array-likes (anything with a ``shape`` attribute) contribute
        ``(shape, dtype)``; every other argument contributes its Python type.

        This lives in its own function on purpose: the execution methods take
        a parameter named ``type`` (the measurement type string), which
        shadows the builtin inside their bodies.  Calling ``type(a)`` there
        would try to call that string and raise ``TypeError``.
        """
        return tuple(
            (a.shape, a.dtype) if hasattr(a, "shape") else type(a) for a in args
        )

    @staticmethod
    def _batch_size(args: tuple, in_axes: Tuple) -> int:
        """Size of the batch dimension, read from the first batched argument.

        Returns ``1`` when no entry of *in_axes* selects a batch axis.
        """
        for a, ax in zip(args, in_axes):
            if ax is not None:
                return a.shape[ax]
        return 1

    @staticmethod
    def _slice_first(a: Any, ax: int) -> Any:
        """Take the first element along axis *ax*.

        Uses ``jax.lax.index_in_dim`` rather than ``jnp.take`` because JAX
        random-key arrays do not support ``jnp.take``.
        """
        # TODO: fix once that is available in JAX
        return jax.lax.index_in_dim(a, 0, axis=ax, keepdims=False)

    def _scalar_args(self, args: tuple, in_axes: Tuple) -> tuple:
        """Replace every batched argument by its first slice along its batch axis.

        Arguments whose ``in_axes`` entry is ``None`` are passed through
        unchanged.  The result is what a single (un-batched) circuit call
        would receive, and is used to record the tape for metadata.
        """
        return tuple(
            self._slice_first(a, ax) if ax is not None else a
            for a, ax in zip(args, in_axes)
        )

    def _record_metadata(
        self, scalar_args: tuple, kwargs: dict, obs: List[Operation], type: str
    ) -> Tuple[int, bool, int]:
        """Trace the tape from scalar slices to derive batch-invariant metadata.

        Recording once with scalar slices determines ``n_qubits`` and whether
        noise channels are present (forcing density-matrix simulation) without
        running the full batch.

        Returns:
            ``(n_qubits, use_density, n_ops)``.
        """
        tape = self._record(*scalar_args, **kwargs)
        n_qubits = self._n_qubits or simulation.infer_n_qubits(tape, obs)
        use_density = simulation.uses_density(tape, type)
        return n_qubits, use_density, len(tape)

    def _build_plan(
        self,
        type: str,
        obs: List[Operation],
        args: tuple,
        kwargs: dict,
        in_axes: Tuple,
    ) -> _BatchPlan:
        """Trace the circuit once and build the cacheable execution plan.

        Records the tape from scalar slices of *args* (to derive
        ``n_qubits``/noise), then builds the vmapped ``eqx.filter_jit``
        wrapper.  When every positional argument is array-like (so plain
        ``jax.jit`` — which has no static-argument handling — is valid) an
        AOT-eligible plain ``jax.jit`` wrapper is built too; :meth:`_dispatch`
        lowers and compiles it lazily per batch size, and only with concrete
        args (the AOT path is gated off under a transform by the caller).
        """
        scalar_args = self._scalar_args(args, in_axes)
        n_qubits, use_density, n_ops = self._record_metadata(
            scalar_args, kwargs, obs, type
        )

        # Re-recording inside this closure is necessary: tape operations may
        # have matrices that depend on the batched argument (e.g. RX(theta)
        # with theta a tracer).  jax.vmap traces this once into a single XLA
        # computation spanning the whole batch.
        def _single_execute(*single_args):
            single_tape = self._record(*single_args, **kwargs)
            return simulation.simulate_and_measure(
                single_tape, n_qubits, type, obs, use_density
            )

        # Wrapping the vmapped function in eqx.filter_jit: (1) treats non-array
        # arguments as static, so circuit signatures mixing arrays and Python
        # values work; (2) lets the XLA program use intra-op CPU parallelism;
        # (3) caches compilation across calls with the same input shapes.
        # NOTE: when altering properties of the model, this might not get
        # re-compiled.
        # TODO: we might want to rework the data_reupload mechanism at some point
        batched_fn = eqx.filter_jit(jax.vmap(_single_execute, in_axes=in_axes))

        # AOT eligibility is a structural property of the signature: plain
        # ``jax.jit`` has no static-argument handling, so it is valid only when
        # every positional argument is array-like.  ``hasattr(a, "shape")`` is
        # true for concrete arrays, numpy arrays, and tracers, but false for
        # Python statics (str/None/dict).  Building the wrapper is pure (it
        # traces nothing); the lower+compile happens lazily in :meth:`_dispatch`
        # and only with concrete args, so this is safe to build under a
        # transform — its use is gated off there by the caller.
        plain_fn = None
        if all(hasattr(a, "shape") for a in args):
            plain_fn = jax.jit(jax.vmap(_single_execute, in_axes=in_axes))

        return _BatchPlan(batched_fn, plain_fn, n_qubits, use_density, n_ops)

    def _chunk_size(
        self,
        cache_key: tuple,
        plan: _BatchPlan,
        type: str,
        n_obs: int,
        batch_size: int,
    ) -> int:
        """Largest batch chunk that fits in memory, memoized per batch size.

        The result is cached under ``("_mem", cache_key, batch_size)`` to avoid
        repeated ``psutil`` syscalls across a tight repeated-call loop.
        """
        mem_key = ("_mem", cache_key, batch_size)
        chunk_size = self._jit_cache.get(mem_key)
        if chunk_size is None:
            chunk_size = memory.compute_chunk_size(
                plan.n_qubits,
                batch_size,
                type,
                plan.use_density,
                n_obs,
                n_ops=plan.n_ops,
            )
            self._jit_cache[mem_key] = chunk_size
        return chunk_size

    def _dispatch(
        self,
        aot_key: Optional[tuple],
        batched_fn: Callable,
        plain_fn: Optional[Callable],
        args: tuple,
        in_axes: Tuple,
        batch_size: int,
        chunk_size: int,
    ) -> jnp.ndarray:
        """Run a built plan through the leanest applicable path.

        - ``chunk_size < batch_size``: the full batch would not fit in memory,
          so execute it in memory-safe sub-batches via
          :func:`~qml_essentials.memory.execute_chunked`.
        - otherwise, when an AOT-eligible ``plain_fn`` exists, ahead-of-time
          lower+compile the vmapped kernel to an XLA executable (cached per
          ``aot_key``) and call it directly.  This skips both the per-call
          pytree partition/combine of :func:`eqx.filter_jit` and its
          just-in-time cache-key recomputation; for small circuits in a tight
          loop that dispatch overhead, not the XLA compute, dominates.
        - otherwise fall back to ``batched_fn`` (no ``plain_fn``: a non-array
          argument, shot mode, or running under a transform).

        Args:
            aot_key: Cache key for the compiled executable; only read when
                *plain_fn* is not ``None``.
            batched_fn: Tracer-tolerant ``eqx.filter_jit`` wrapper.
            plain_fn: AOT-eligible ``jax.jit`` wrapper, or ``None``.
            args: Positional arguments for the batched function.
            in_axes: Batch axes matching *args*.
            batch_size: Size of the batch dimension.
            chunk_size: Largest sub-batch that fits in memory.
        """
        if chunk_size < batch_size:
            return memory.execute_chunked(
                batched_fn,
                args,
                in_axes,
                batch_size,
                chunk_size,
                clear_caches=memory.CLEAR_CACHES_BETWEEN_CHUNKS,
            )
        if plain_fn is None:
            return batched_fn(*args)
        compiled = self._jit_cache.get(aot_key)
        if compiled is None:
            compiled = plain_fn.lower(*args).compile()
            self._jit_cache[aot_key] = compiled
        return compiled(*args)

    def _execute_batched(
        self,
        type: str,
        obs: List[Operation],
        args: tuple,
        kwargs: dict,
        in_axes: Tuple,
        shots: Optional[int] = None,
        key: Optional[jnp.ndarray] = None,
    ) -> jnp.ndarray:
        """Vectorise :meth:`execute` over a batch axis using ``jax.vmap``.

        The circuit function is traced once in Python with scalar slices to
        record the tape, determine ``n_qubits``, and detect noise.  The
        resulting pure simulation kernel is then vmapped over the requested
        axes.

        Memory-aware chunking — before launching the full vmap, the
        method estimates peak memory usage.  If the full batch would exceed
        available RAM (with a safety margin), the batch is automatically
        split into sub-batches that fit.  Each chunk is vmapped independently
        and the results are concatenated.  This trades a small amount of
        wall-clock time for guaranteed execution without OOM.

        When the full batch fits in memory, there is zero overhead — the
        memory check is a pure Python arithmetic calculation (no JAX calls).

        Args:
            type: Measurement type (see :meth:`execute`).
            obs: Observables (see :meth:`execute`).
            args: Positional arguments for the circuit function.
            kwargs: Keyword arguments for the circuit function.
            in_axes: One entry per element of *args*.  Follows ``jax.vmap``
                convention: an int gives the batch axis; ``None`` broadcasts.
                A list is accepted and treated like a tuple.
            shots: Number of measurement shots.  If ``None``, exact results.
            key: JAX PRNG key for shot sampling.

        Returns:
            Batched measurement results of shape ``(B, ...)`` where *B* is the
            size of the batch dimension.

        Raises:
            ValueError: If ``len(in_axes) != len(args)``.

        Note:
            The ``jax.vmap`` call in :meth:`_build_plan` is the exact
            boundary to replace with ``jax.shard_map`` for multi-device
            execution::

                from jax.sharding import PartitionSpec as P, Mesh
                result = jax.shard_map(
                    _single_execute, mesh=mesh,
                    in_specs=tuple(P(0) if ax is not None else P() for ax in in_axes),
                    out_specs=P(0),
                )(*args)
        """
        # ``in_axes`` takes part in cache keys (must be hashable) and is
        # extended with ``+ (0,)`` in shot mode, so canonicalise a list to a
        # tuple up front.
        in_axes = tuple(in_axes)

        if len(in_axes) != len(args):
            raise ValueError(
                f"in_axes has {len(in_axes)} entries but args has {len(args)}. "
                "Provide one in_axes entry per positional argument."
            )

        batch_size = self._batch_size(args, in_axes)

        # Running under an outer JAX transform (e.g. ``jax.jacrev``) makes
        # ``args`` tracers.  The tracer-tolerant ``batched_fn`` wrapper is still
        # cached and reused (see exact-mode dispatch below); only the AOT
        # ``plain_fn`` executable is gated off, as it cannot accept tracers.
        in_transform = self._args_contain_tracer(args)

        # Computed in a helper because the builtin ``type`` is shadowed by this
        # method's ``type`` parameter.
        arg_shapes = self._arg_signature(args)

        # TODO: we need to fix the dirty class-level `batch_gate_error` hack.
        # It is a global toggle that changes the compiled circuit, so it has to
        # take part in every cache key.
        gate_error = UnitaryGates.batch_gate_error

        # Non-array kwargs are captured by the cached closures, so they must
        # distinguish cache entries in BOTH modes.
        # NOTE(review): array-valued kwargs and ``obs`` are also captured by
        # the closures but are not part of the key — a cached plan appears to
        # keep using the values from the call that built it; confirm callers
        # never vary them for an otherwise identical signature.
        cache_kwargs = _make_hashable(
            {k: v for k, v in kwargs.items() if not isinstance(v, jnp.ndarray)}
        )

        # --- Shot mode: compute exact probabilities, then sample. ---
        # NOTE(review): ``shots`` combined with any other measurement type
        # falls through to exact mode below and is ignored there.
        if shots is not None and type in ("probs", "expval"):
            shot_cache_key = (
                type,
                "shots",
                shots,
                in_axes,
                arg_shapes,
                cache_kwargs,
                gate_error,
            )
            shot_in_axes = in_axes + (0,)  # shot key batched over axis 0
            shot_args = args + (jax.random.split(key, batch_size),)

            plan = self._jit_cache.get(shot_cache_key)
            if plan is None:
                scalar_args = self._scalar_args(args, in_axes)
                n_qubits, use_density, n_ops = self._record_metadata(
                    scalar_args, kwargs, obs, type
                )

                # Re-recording inside the closure lets jax.vmap trace the whole
                # batch into one XLA program; the shot key is the extra vmapped
                # argument.
                def _single_execute_shots(*single_args_and_key):
                    *single_args, shot_key = single_args_and_key
                    single_tape = self._record(*single_args, **kwargs)
                    exact_result = simulation.simulate_and_measure(
                        single_tape, n_qubits, "probs", obs, use_density
                    )
                    return simulation.sample_shots(
                        exact_result, n_qubits, type, obs, shots, shot_key
                    )

                batched_fn = eqx.filter_jit(
                    jax.vmap(_single_execute_shots, in_axes=shot_in_axes)
                )
                plan = _BatchPlan(batched_fn, None, n_qubits, use_density, n_ops)
                self._jit_cache[shot_cache_key] = plan

            chunk_size = self._chunk_size(
                shot_cache_key, plan, type, len(obs), batch_size
            )
            # Shot mode never uses the AOT fast path (plain_fn is None).
            return self._dispatch(
                None,
                plan.batched_fn,
                None,
                shot_args,
                shot_in_axes,
                batch_size,
                chunk_size,
            )

        # --- Exact mode: reuse the cached plan or build it on a miss. ---
        cache_key = (type, in_axes, arg_shapes, cache_kwargs, gate_error)

        # The cached ``batched_fn`` (eqx.filter_jit wrapper) is reused across
        # calls including under an outer transform; reusing one wrapper lets
        # JAX hit its aval-keyed trace cache instead of re-tracing the circuit
        # every call.  Only the AOT ``plain_fn`` (a compiled executable) is
        # invalid for tracers; its use is gated below by ``in_transform``.
        plan = self._jit_cache.get(cache_key)
        if plan is None:
            plan = self._build_plan(type, obs, args, kwargs, in_axes)
            self._jit_cache[cache_key] = plan

        chunk_size = self._chunk_size(cache_key, plan, type, len(obs), batch_size)
        return self._dispatch(
            ("_aot", cache_key, batch_size),
            plan.batched_fn,
            None if in_transform else plan.plain_fn,
            args,
            in_axes,
            batch_size,
            chunk_size,
        )

    def draw(
        self,
        figure: str = "text",
        args: tuple = (),
        kwargs: Optional[dict] = None,
        **draw_kwargs: Any,
    ) -> Union[str, Any]:
        """Draw the quantum circuit.

        Records the tape by calling the circuit function with the given
        arguments, then renders the resulting gate sequence.

        Args:
            figure: Rendering backend.  One of:

                - ``"text"`` — ASCII art (returned as a ``str``).
                - ``"mpl"`` — Matplotlib figure (returns ``(fig, ax)``).
                - ``"tikz"`` — LaTeX/TikZ code via ``quantikz``
                  (returns a :class:`TikzFigure`).
                - ``"pulse"`` — Pulse schedule plot (returns ``(fig, axes)``).

            args: Positional arguments forwarded to the circuit function
                to record the tape.
            kwargs: Keyword arguments forwarded to the circuit function.
            **draw_kwargs: Extra options forwarded to the rendering backend:

                - ``gate_values`` (bool): Show numeric gate angles instead of
                  symbolic \\theta_i labels.  Default ``False``.
                - ``show_carrier`` (bool): For ``"pulse"`` mode, overlay the
                  carrier-modulated waveform.  Default ``False``.

        Returns:
            Depends on *figure*:

            - ``"text"`` -> ``str``
            - ``"mpl"`` -> ``(matplotlib.figure.Figure, matplotlib.axes.Axes)``
            - ``"tikz"`` -> :class:`TikzFigure`
            - ``"pulse"`` -> ``(matplotlib.figure.Figure, numpy.ndarray)``

        Raises:
            ValueError: If *figure* is not one of the supported modes.
        """
        if figure not in ("text", "mpl", "tikz", "pulse"):
            raise ValueError(
                f"Invalid figure mode: {figure!r}. "
                "Must be 'text', 'mpl', 'tikz', or 'pulse'."
            )

        if kwargs is None:
            kwargs = {}

        if figure == "pulse":
            from qml_essentials.drawing import draw_pulse_schedule

            events = self.pulse_events(*args, **kwargs)
            # Without a declared qubit count, use the highest wire index seen
            # on any pulse event (``default=0`` covers an empty event list).
            n_qubits = (
                self._n_qubits
                or max((w for ev in events for w in ev.wires), default=0) + 1
            )
            return draw_pulse_schedule(events, n_qubits, **draw_kwargs)

        tape = self._record(*args, **kwargs)
        n_qubits = self._n_qubits or simulation.infer_n_qubits(tape, [])

        # Filter out noise channels for drawing — they clutter the diagram
        ops = [op for op in tape if not isinstance(op, KrausChannel)]

        if figure == "text":
            # The text backend receives no ``draw_kwargs``.
            return draw_text(ops, n_qubits)
        elif figure == "mpl":
            return draw_mpl(ops, n_qubits, **draw_kwargs)
        else:  # tikz
            return draw_tikz(ops, n_qubits, **draw_kwargs)