Coverage for qml_essentials / yaqsi.py: 93%
400 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-30 11:43 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-30 11:43 +0000
1from functools import reduce
2from typing import Any, Callable, List, Optional, Tuple, Union
3import threading
5import diffrax
6import jax
7import jax.numpy as jnp
8import jax.scipy.linalg
9import numpy as np # needed to prevent jitting some operations
11from qml_essentials.operations import (
12 Barrier,
13 Hermitian,
14 ParametrizedHamiltonian,
15 Operation,
16 KrausChannel,
17 PauliZ,
18 _einsum_subscript,
19 _cdtype,
20)
21from qml_essentials.tape import recording, pulse_recording
22from qml_essentials.drawing import draw_text, draw_mpl, draw_tikz
24import logging
26log = logging.getLogger(__name__)
29def _make_hashable(obj):
30 """Recursively convert an object into a hashable form for cache keys.
32 - ``dict`` → sorted tuple of ``(key, _make_hashable(value))`` pairs
33 - ``list`` → tuple of ``_make_hashable(element)``
34 - ``set`` → frozenset of ``_make_hashable(element)``
35 - everything else is returned as-is (assumed hashable)
36 """
37 if isinstance(obj, dict):
38 return tuple(sorted((k, _make_hashable(v)) for k, v in obj.items()))
39 if isinstance(obj, (list, tuple)):
40 return tuple(_make_hashable(x) for x in obj)
41 if isinstance(obj, set):
42 return frozenset(_make_hashable(x) for x in obj)
43 return obj
class Yaqsi:
    """Pure-JAX quantum helpers: partial traces, probability marginals,
    parity observables, and (time-dependent) Hamiltonian evolution."""

    # TODO: generally, I would like to merge this into operations or vice-versa
    # and only keep Script here. It's not clear how to do this though.

    # Module-level cache for JIT-compiled ODE solvers. Keyed on
    # (coeff_fn.__code__, dim, atol, rtol) so that all evolve() calls with
    # the same pulse shape function and matrix size share one compiled XLA
    # program. This turns O(n_gates) JIT compilations into
    # O(n_distinct_pulse_shapes) during pulse-mode circuit building.
    _evolve_solver_cache: dict = {}
    _evolve_solver_cache_lock = threading.Lock()

    @staticmethod
    def _partial_trace_single(
        rho: jnp.ndarray,
        n_qubits: int,
        keep: List[int],
    ) -> jnp.ndarray:
        """Partial trace of a single density matrix (no batch dimension)."""
        shape = (2,) * (2 * n_qubits)
        rho_t = rho.reshape(shape)

        trace_out = sorted(set(range(n_qubits)) - set(keep))

        # Trace out the highest qubit first so the remaining (lower) axis
        # indices stay valid as the tensor shrinks.
        for q in reversed(trace_out):
            n_remaining = rho_t.ndim // 2
            rho_t = jnp.trace(rho_t, axis1=q, axis2=q + n_remaining)

        dim = 2 ** len(keep)
        return rho_t.reshape(dim, dim)

    @classmethod
    def partial_trace(
        cls,
        rho: jnp.ndarray,
        n_qubits: int,
        keep: List[int],
    ) -> jnp.ndarray:
        """Partial trace of a density matrix, keeping only the specified qubits.

        Supports both single density matrices of shape ``(2**n, 2**n)`` and
        batched density matrices of shape ``(B, 2**n, 2**n)``.

        Args:
            rho: Density matrix of shape ``(2**n, 2**n)`` or ``(B, 2**n, 2**n)``.
            n_qubits: Total number of qubits.
            keep: List of qubit indices to *keep* (0-indexed).

        Returns:
            Reduced density matrix of shape ``(2**k, 2**k)`` or ``(B, 2**k, 2**k)``
            where *k* = ``len(keep)``.
        """
        dim = 2**n_qubits
        if rho.shape == (dim, dim):
            return Yaqsi._partial_trace_single(rho, n_qubits, keep)
        # Batched: shape (B, dim, dim)
        return jax.vmap(lambda r: Yaqsi._partial_trace_single(r, n_qubits, keep))(rho)

    @staticmethod
    def _marginalize_probs_single(
        probs: jnp.ndarray,
        target_shape: Tuple[int],
        trace_out: Tuple[int],
    ) -> jnp.ndarray:
        """Marginalize a single probability vector (no batch dimension)."""
        probs_t = probs.reshape(target_shape)

        # trace_out is sorted descending, so summing in order keeps the
        # remaining axis indices valid.
        for q in trace_out:
            probs_t = probs_t.sum(axis=q)

        return probs_t.ravel()

    @classmethod
    def marginalize_probs(
        cls,
        probs: jnp.ndarray,
        n_qubits: int,
        keep: Tuple[int],
    ) -> jnp.ndarray:
        """Marginalize a probability vector to keep only the specified qubits.

        Supports both single probability vectors of shape ``(2**n,)`` and
        batched vectors of shape ``(B, 2**n)``.

        Args:
            probs: Probability vector of shape ``(2**n,)`` or ``(B, 2**n)``.
            n_qubits: Total number of qubits.
            keep: List of qubit indices to *keep* (0-indexed).

        Returns:
            Marginalized probability vector of shape ``(2**k,)`` or ``(B, 2**k)``
            where *k* = ``len(keep)``.
        """
        dim = 2**n_qubits
        trace_out = tuple(q for q in range(n_qubits - 1, -1, -1) if q not in keep)
        target_shape = (2,) * n_qubits

        # Single (unbatched) vector: return a 1-D result, mirroring
        # partial_trace(). (Previously a spurious leading batch axis of
        # size 1 was attached to unbatched inputs.)
        if probs.ndim == 1:
            return cls._marginalize_probs_single(probs, target_shape, trace_out)

        return jax.vmap(
            lambda p: Yaqsi._marginalize_probs_single(p, target_shape, trace_out)
        )(probs.reshape(-1, dim))

    @classmethod
    def build_parity_observable(
        cls,
        qubit_group: List[int],
    ) -> "Hermitian":
        """Build a multi-qubit parity observable.

        Args:
            qubit_group: List of qubit indices for the parity measurement.

        Returns:
            A :class:`Hermitian` operation whose matrix is the Z parity
            tensor product and whose wires match the given qubits.
        """
        Z = PauliZ._matrix
        mat = reduce(jnp.kron, [Z] * len(qubit_group))
        return Hermitian(matrix=mat, wires=qubit_group, record=False)

    @classmethod
    def evolve(cls, hamiltonian, name=None, **odeint_kwargs):
        """Return a gate-factory for Hamiltonian time evolution.

        Supports two modes:

        Static — when *hamiltonian* is a :class:`Hermitian`::

            gate = evolve(Hermitian(H_mat, wires=0))
            gate(t=0.5)  # U = exp(-i*0.5*H)

        Time-dependent — when *hamiltonian* is a
        :class:`ParametrizedHamiltonian` (created via ``coeff_fn * Hermitian``)::

            H_td = coeff_fn * Hermitian(H_mat, wires=0)
            gate = evolve(H_td)
            gate([A, sigma], T)  # U via ODE: dU/dt = -i f(p,t) H * U

        The time-dependent case solves the Schrödinger equation numerically
        using ``diffrax.diffeqsolve`` with a Dopri8 adaptive Runge-Kutta
        solver. All computations are pure JAX and fully differentiable with
        ``jax.grad``.

        Args:
            hamiltonian: Either a :class:`Hermitian` (static evolution) or a
                :class:`ParametrizedHamiltonian` (time-dependent evolution).
            name: Optional display name for the produced :class:`Operation`.
            **odeint_kwargs: Extra keyword arguments. Recognised keys:

                - ``atol``, ``rtol`` — absolute/relative tolerances for the
                  adaptive step-size controller (default ``1.4e-8``).

        Returns:
            A callable gate factory. Signature depends on the mode:

            - Static: ``(t, wires=0) -> Operation``
            - Time-dependent: ``(coeff_args, T) -> Operation``

        Raises:
            TypeError: If *hamiltonian* is neither ``Hermitian`` nor
                ``ParametrizedHamiltonian``.
        """
        if isinstance(hamiltonian, Hermitian):
            return cls._evolve_static(hamiltonian, name=name)
        elif isinstance(hamiltonian, ParametrizedHamiltonian):
            return cls._evolve_parametrized(hamiltonian, name=name, **odeint_kwargs)
        else:
            raise TypeError(
                f"evolve() expects a Hermitian or ParametrizedHamiltonian, "
                f"got {type(hamiltonian)}"
            )

    @staticmethod
    def _evolve_static(hermitian: "Hermitian", name=None) -> Callable:
        """Gate factory for static Hamiltonian evolution U = exp(-i t H)."""
        H_mat = hermitian.matrix

        def _apply(t: float, wires: Union[int, List[int]] = 0) -> Operation:
            U = jax.scipy.linalg.expm(-1j * t * H_mat)
            return Operation(wires=wires, matrix=U, name=name)

        return _apply

    @classmethod
    def _evolve_parametrized(
        cls, ph: "ParametrizedHamiltonian", name=None, **odeint_kwargs
    ) -> Callable:
        """Gate factory for time-dependent Hamiltonian evolution.

        Solves the matrix ODE ``dU/dt = -i f(params, t) H * U`` with
        ``U(0) = I`` using ``diffrax.diffeqsolve`` (Dopri8 adaptive RK).

        Performance notes:

        - The JIT-compiled solver is cached per coefficient function so
          that multiple ``evolve()`` calls with the same pulse shape (but
          different Hamiltonian matrices or parameters) reuse the same
          compiled XLA program. This avoids O(n_gates) JIT compilations
          during pulse-mode tape recording.
        - ``-i*H`` is pre-computed once instead of multiplied at every RHS
          evaluation.
        - The time span is resolved at Python level, avoiding dynamic
          ``jnp.where`` branching.

        Args:
            ph: A :class:`ParametrizedHamiltonian` holding the coefficient
                function, the Hamiltonian matrix, and wire indices.
            name: Optional display name for the produced :class:`Operation`.
            **odeint_kwargs: ``atol`` and ``rtol`` for the step-size controller.
        """
        H_mat = ph.H_mat
        coeff_fn = ph.coeff_fn
        wires = ph.wires
        dim = H_mat.shape[0]

        # Pre-compute -i*H once (avoids repeated multiplication in RHS)
        neg_iH = -1j * H_mat

        atol = odeint_kwargs.pop("atol", 1.4e-8)
        rtol = odeint_kwargs.pop("rtol", 1.4e-8)

        # Look up or build the JIT-compiled solver for this
        # (coeff_fn, dim, atol, rtol) combination. All pulse gates with the
        # same pulse shape function (e.g. all RX gates share Sx, all RY
        # share Sy) reuse a single compiled XLA program, with neg_iH passed
        # as a regular argument rather than captured in the closure.
        #
        # We key on the function's code object itself — not id(coeff_fn)
        # and not id(coeff_fn.__code__): each pulse gate call creates a new
        # closure (capturing the rotation angle `w`) so function ids differ
        # per call, and id() values can be recycled after garbage
        # collection, which could silently reuse a solver compiled for a
        # different pulse shape. Holding the code object in the cache keeps
        # it alive, making the key stable. Under JIT, the captured `w` is a
        # tracer anyway, so the compiled program is generic over `w`.
        cache_key = (coeff_fn.__code__, dim, atol, rtol)

        with cls._evolve_solver_cache_lock:
            _solve = cls._evolve_solver_cache.get(cache_key)

        if _solve is None:
            solver = diffrax.Dopri8()
            stepsize_controller = diffrax.PIDController(atol=atol, rtol=rtol)

            @jax.jit
            def _solve(neg_iH, params, t0, t1):
                """Solve dU/dt = f(params,t) * (-iH) * U from t0 to t1."""

                def rhs(t, y, args):
                    return coeff_fn(args, t) * (neg_iH @ y)

                sol = diffrax.diffeqsolve(
                    diffrax.ODETerm(rhs),
                    solver,
                    t0=t0,
                    t1=t1,
                    dt0=None,  # let the controller choose the initial step
                    y0=jnp.eye(dim, dtype=_cdtype()),
                    args=params,
                    stepsize_controller=stepsize_controller,
                    max_steps=4096,
                )

                # sol.ys has shape (1, dim, dim) for SaveAt(t1=True) (default)
                return sol.ys[0]

            with cls._evolve_solver_cache_lock:
                # Double-check to avoid overwriting a concurrent build
                existing = cls._evolve_solver_cache.get(cache_key)
                if existing is not None:
                    _solve = existing
                else:
                    cls._evolve_solver_cache[cache_key] = _solve

        def _apply(coeff_args, T) -> Operation:
            """Evolve under the time-dependent Hamiltonian.

            Args:
                coeff_args: List of parameter sets, one per Hamiltonian
                    term. For static Hamiltonians, ``coeff_args[0]`` is
                    forwarded to ``coeff_fn(params, t)`` as the first
                    argument.
                T: Total evolution time. If scalar, the ODE is solved on
                    ``[0, T]``. If a 2-element array, on ``[T[0], T[1]]``.

            Returns:
                An :class:`Operation` wrapping the computed unitary.
            """
            # PennyLane convention: coeff_args is a list of param-sets,
            # one per term. Single term -> unpack the first.
            params = (
                coeff_args[0] if isinstance(coeff_args, (list, tuple)) else coeff_args
            )

            # Build time span — resolve at Python level to avoid traced branching
            T_arr = jnp.asarray(T, dtype=jnp.float64)
            if T_arr.ndim == 0:
                t0 = jnp.float64(0.0)
                t1 = T_arr
            else:
                t0 = T_arr[0]
                t1 = T_arr[1]

            U = _solve(neg_iH, params, t0, t1)

            return Operation(wires=wires, matrix=U, name=name)

        return _apply
# TODO adjust imports to use classmethods instead
# Module-level aliases: expose the Yaqsi classmethods as a functional API
# (e.g. ``from qml_essentials.yaqsi import evolve``) so callers need not
# reference the Yaqsi class directly.
partial_trace = Yaqsi.partial_trace
evolve = Yaqsi.evolve
marginalize_probs = Yaqsi.marginalize_probs
build_parity_observable = Yaqsi.build_parity_observable
363class Script:
364 """Circuit container and executor backed by pure JAX kernels.
366 ``Script`` takes a callable *f* representing a quantum circuit.
367 Within *f*, :class:`~qml_essentials.operations.Operation` objects are
368 instantiated and automatically recorded onto a tape. The tape is then
369 simulated using either a statevector or density-matrix kernel depending on
370 whether noise channels are present.
372 Attributes:
373 f: The circuit function whose body instantiates ``Operation`` objects.
374 _n_qubits: Optionally pre-declared number of qubits. When ``None``
375 the qubit count is inferred from the operations recorded on the
376 tape.
378 Example:
379 >>> def circuit(theta):
380 ... RX(theta, wires=0)
381 ... PauliZ(wires=1)
382 >>> script = Script(circuit, n_qubits=2)
383 >>> result = script.execute(type="expval", obs=[PauliZ(0)])
384 """
386 def __init__(self, f: Callable, n_qubits: Optional[int] = None) -> None:
387 """Initialise a Script.
389 Args:
390 f: A function whose body instantiates ``Operation`` objects.
391 Signature: ``f(*args, **kwargs) -> None``.
392 n_qubits: Number of qubits. If ``None``, inferred from the
393 operations recorded on the tape.
394 """
395 self.f = f
396 self._n_qubits = n_qubits
397 self._jit_cache: dict = {} # keyed on (type, in_axes, arg_shapes, gateError)
399 @staticmethod
400 def _estimate_peak_bytes(
401 n_qubits: int,
402 batch_size: int,
403 type: str,
404 use_density: bool,
405 n_obs: int = 0,
406 ) -> int:
407 """Estimate peak memory (bytes) for a batched simulation.
409 The estimate accounts for:
411 - The batched statevector (always needed, even for density).
412 - The batched output tensor (state / probs / density / expval).
413 - One gate-tensor temporary per batch element (the einsum buffer).
415 Observable matrices are **not** counted: they are computed inside
416 the JIT-compiled function and XLA manages their lifetime (reusing
417 buffers between observables). Similarly, the outer-product
418 temporary for pure-circuit density mode is transient within XLA.
420 Element size is determined dynamically from ``jax.config.x64_enabled``:
421 when x64 mode is disabled (the JAX default), complex values are
422 ``complex64`` (8 bytes) and floats are ``float32`` (4 bytes),
423 halving memory usage compared to the x64 path.
425 A 1.5× safety factor is applied to cover XLA compiler temporaries,
426 padding, and other allocations not directly visible to Python.
428 This is a pure Python arithmetic calculation with no JAX calls —
429 it adds essentially zero overhead.
431 Args:
432 n_qubits: Number of qubits in the circuit.
433 batch_size: Number of batch elements.
434 type: Measurement type (``"state"``, ``"probs"``, ``"expval"``,
435 ``"density"``).
436 use_density: Whether density-matrix simulation is used.
437 n_obs: Number of observables (relevant for ``"expval"``).
439 Returns:
440 Estimated peak memory in bytes.
441 """
442 dim = 2**n_qubits
443 # Detect actual element size: JAX silently truncates complex128
444 # to complex64 when x64 mode is disabled (the default).
445 elem = 16 if jax.config.x64_enabled else 8 # complex128 vs complex64
446 real_elem = elem // 2 # float64 vs float32
448 # Statevector: always allocated during simulation
449 sv_bytes = batch_size * dim * elem
451 # Simulation intermediate: when density-matrix simulation is used,
452 # the full rho (dim × dim) must be held during gate evolution —
453 # even if the final output is only probs or expval.
454 # apply_to_density contracts both U and U* against rho, so at least
455 # two intermediate (dim × dim) buffers are alive simultaneously.
456 if use_density:
457 sim_bytes = 2 * batch_size * dim * dim * elem
458 else:
459 sim_bytes = 0 # statevector is already counted above
461 # Output tensor: this is the *returned* array, not the simulation
462 # intermediate. For probs/expval with density simulation the
463 # density matrix is reduced to a small output *before* returning,
464 # so only the reduced output coexists with the next chunk.
465 if type == "density":
466 out_bytes = batch_size * dim * dim * elem
467 elif type == "expval":
468 out_bytes = batch_size * max(n_obs, 1) * real_elem
469 elif type == "probs":
470 out_bytes = batch_size * dim * real_elem
471 else: # state
472 out_bytes = batch_size * dim * elem
474 # Gate temporaries: einsum creates one (2,)*n buffer per batch elem
475 gate_tmp = batch_size * dim * elem
477 # Peak = max(simulation phase, output phase). During simulation
478 # the intermediate + statevector + gate temps are alive. After
479 # measurement, only the output survives. So peak is whichever
480 # phase is larger.
481 sim_peak = sv_bytes + sim_bytes + gate_tmp
482 out_peak = out_bytes
483 raw = max(sim_peak, out_peak)
485 # 1.5× safety factor for XLA compiler temporaries, padding, etc.
486 return int(raw * 1.5)
488 @staticmethod
489 def _available_memory_bytes() -> int:
490 """Return available system memory in bytes.
492 Uses ``psutil.virtual_memory().available`` for cross-platform
493 support (Linux, macOS, Windows). Falls back to reading
494 ``/proc/meminfo`` on Linux, and finally to a conservative 4 GiB
495 default if neither approach succeeds.
497 Returns:
498 Available memory in bytes.
499 """
500 mem = 4 * 1024**3
501 # Primary: psutil (works on Linux, macOS, Windows)
502 try:
503 import psutil
505 mem = psutil.virtual_memory().available
506 except Exception:
507 log.debug("psutil not available. Fallback to /proc/meminfo")
509 # Fallback: /proc/meminfo (Linux only)
510 try:
511 with open("/proc/meminfo", "r") as f:
512 for line in f:
513 if line.startswith("MemAvailable:"):
514 mem = int(line.split()[1]) * 1024 # kB → bytes
515 except Exception:
516 log.debug("Failed to read /proc/meminfo. Falling back to 4 GiB")
518 log.debug(f"Available memory: {mem/1024**3:.1f} GB")
519 return mem
521 @staticmethod
522 def _compute_chunk_size(
523 n_qubits: int,
524 batch_size: int,
525 type: str,
526 use_density: bool,
527 n_obs: int = 0,
528 memory_fraction: float = 0.8,
529 ) -> int:
530 """Determine the largest chunk size that fits in available memory.
532 If the full batch fits, returns *batch_size* (i.e. no chunking).
533 Otherwise, returns the largest chunk size such that the computation
534 of one chunk **plus** the full output accumulator fits within
535 ``memory_fraction`` of available RAM.
537 The output accumulator is the final ``(batch_size, ...)`` array that
538 holds all results. When chunking, this array must coexist with the
539 active chunk computation, so its size is subtracted from available
540 memory before computing how many elements fit per chunk.
542 The minimum chunk size is 1 (fully serialised).
544 Args:
545 n_qubits: Number of qubits.
546 batch_size: Total batch size.
547 type: Measurement type.
548 use_density: Whether density-matrix simulation is used.
549 n_obs: Number of observables.
550 memory_fraction: Fraction of available memory to target
551 (default 0.8 = 80%).
553 Returns:
554 Chunk size (number of batch elements per sub-batch).
555 """
556 avail = int(Script._available_memory_bytes() * memory_fraction)
557 full_est = Script._estimate_peak_bytes(
558 n_qubits, batch_size, type, use_density, n_obs
559 )
561 if full_est <= avail:
562 return batch_size # everything fits — no chunking
564 # The output accumulator (the final (batch_size, ...) result array)
565 # must coexist with each chunk's computation, so subtract its size
566 # from available memory before sizing chunks.
567 dim = 2**n_qubits
568 elem = 16 if jax.config.x64_enabled else 8
569 real_elem = elem // 2
570 if type == "density":
571 accum_bytes = batch_size * dim * dim * elem
572 elif type == "expval":
573 accum_bytes = batch_size * max(n_obs, 1) * real_elem
574 elif type == "probs":
575 accum_bytes = batch_size * dim * real_elem
576 else:
577 accum_bytes = batch_size * dim * elem
578 avail_for_chunks = max(avail - accum_bytes, elem) # at least 1 element
580 # Per-element cost: the memory for computing a single batch element.
581 per_elem = Script._estimate_peak_bytes(n_qubits, 1, type, use_density, n_obs)
583 if per_elem <= 0:
584 return batch_size
586 chunk = avail_for_chunks // per_elem
587 chunk = max(1, min(chunk, batch_size))
589 if chunk == 1 and per_elem > avail:
590 log.warning(
591 f"A single batch element requires ~{per_elem / 1024**3:.2f} GB "
592 f"but only ~{avail / 1024**3:.2f} GB is available. "
593 f"Proceeding with chunk_size=1 but OOM is possible. "
594 f"Consider reducing n_qubits or switching measurement type."
595 )
597 log.info(
598 f"Computation requires ~{full_est / 1024**3:.2f} GB which "
599 f"does not fit in ~{avail / 1024**3:.2f} GB. "
600 f"Using chunk size {chunk}."
601 )
602 return chunk
604 @staticmethod
605 def _execute_chunked(
606 batched_fn: Callable,
607 args: tuple,
608 in_axes: Tuple,
609 batch_size: int,
610 chunk_size: int,
611 ) -> jnp.ndarray:
612 """Execute a vmapped function in memory-safe chunks.
614 Splits the batch dimension into sub-batches of at most *chunk_size*
615 elements, runs each through the JIT-compiled *batched_fn*, and
616 writes results into a pre-allocated output array.
618 Only one chunk's intermediate result is alive at a time: each
619 chunk is computed, copied into the output buffer, and then its
620 reference is dropped — allowing JAX/XLA to reclaim the memory
621 before the next chunk starts. This keeps peak memory at roughly
622 ``output_buffer + one_chunk_computation`` rather than the sum of
623 all chunk outputs.
625 Args:
626 batched_fn: A JIT-compiled, vmapped callable.
627 args: Full-batch arguments (before slicing).
628 in_axes: Per-argument batch axis specification.
629 batch_size: Total number of batch elements.
630 chunk_size: Maximum elements per chunk.
632 Returns:
633 Batched results with the same leading dimension as the
634 full batch.
635 """
636 n_chunks = (batch_size + chunk_size - 1) // chunk_size
637 log.debug(
638 f"Memory-aware chunking: splitting batch of {batch_size} into "
639 f"{n_chunks} chunks of <={chunk_size} elements."
640 )
642 output = None
643 for chunk_idx in range(n_chunks):
644 start = chunk_idx * chunk_size
645 end = min(start + chunk_size, batch_size)
646 size = end - start
648 # Slice each batched argument along its batch axis
649 chunk_args = tuple(
650 (
651 jax.lax.dynamic_slice_in_dim(a, start, size, axis=ax)
652 if ax is not None
653 else a
654 )
655 for a, ax in zip(args, in_axes)
656 )
658 chunk_result = batched_fn(*chunk_args)
660 if output is None:
661 # Pre-allocate the full output buffer on first chunk
662 out_shape = (batch_size,) + chunk_result.shape[1:]
663 output = jnp.zeros(out_shape, dtype=chunk_result.dtype)
665 # Copy chunk into the output buffer; the slice assignment
666 # creates a new array (JAX arrays are immutable) but the old
667 # `output` reference is immediately replaced, letting XLA
668 # reclaim it.
669 output = output.at[start:end].set(chunk_result)
671 # Explicitly drop the chunk reference so XLA can free the
672 # chunk's device memory before computing the next one.
673 del chunk_result, chunk_args
674 # Trigger garbage collection to release device buffers
675 jax.clear_caches()
677 return output
679 def _record(self, *args, **kwargs) -> List[Operation]:
680 """Run the circuit function and collect the recorded operations.
682 Uses :func:`~qml_essentials.tape.recording` as a context manager so
683 that the tape is always cleaned up — even if the circuit function
684 raises — and nested recordings (e.g. from ``_execute_batched``) each
685 get their own independent tape.
687 Args:
688 *args: Positional arguments forwarded to the circuit function.
689 **kwargs: Keyword arguments forwarded to the circuit function.
691 Returns:
692 List of :class:`~qml_essentials.operations.Operation` instances in
693 the order they were instantiated.
694 """
695 with recording() as tape:
696 self.f(*args, **kwargs)
697 return tape
699 def pulse_events(self, *args, **kwargs) -> list:
700 """Run the circuit and collect pulse events emitted by PulseGates.
702 Activates both the normal operation tape (so gates execute) and
703 a pulse-event tape that captures
704 :class:`~qml_essentials.drawing.PulseEvent` objects from leaf
705 pulse gates (RX, RY, RZ, CZ).
707 Args:
708 *args: Forwarded to the circuit function.
709 **kwargs: Forwarded to the circuit function.
711 Returns:
712 List of :class:`~qml_essentials.drawing.PulseEvent`.
713 """
714 with pulse_recording() as events:
715 with recording():
716 self.f(*args, **kwargs)
717 return events
719 @staticmethod
720 def _infer_n_qubits(ops: List[Operation], obs: List[Operation]) -> int:
721 """Infer the number of qubits from a list of operations and observables.
723 Args:
724 ops: Gate operations recorded on the tape.
725 obs: Observable operations used for measurement.
727 Returns:
728 The smallest number of qubits that covers all wire indices, i.e.
729 ``max(all_wires) + 1`` (at least 1).
730 """
731 all_wires: set[int] = set()
732 for op in ops + obs:
733 all_wires.update(op.wires)
734 return max(all_wires) + 1 if all_wires else 1
736 @staticmethod
737 def _simulate_pure(tape: List[Operation], n_qubits: int) -> jnp.ndarray:
738 """Statevector simulation kernel.
740 Starts from |00…0⟩ and applies each gate in *tape* via tensor
741 contraction. The state is kept in rank-*n* tensor form ``(2,)*n``
742 throughout the gate loop to avoid per-gate ``reshape`` dispatch;
743 only the initial and final conversions to/from the flat ``(2**n,)``
744 representation incur a reshape.
746 All gate tensors and einsum subscript strings are pre-extracted from
747 the tape before the simulation loop so that each iteration performs
748 only a single ``jnp.einsum`` call with zero additional Python
749 overhead (no method dispatch, no property access, no cache lookup).
751 Args:
752 tape: Ordered list of gate operations to apply.
753 n_qubits: Total number of qubits.
755 Returns:
756 Statevector of shape ``(2**n_qubits,)``.
757 """
758 dim = 2**n_qubits
760 # Pre-extract gate tensors and einsum subscripts — eliminates all
761 # per-gate Python overhead (method calls, property lookups, cache
762 # hits on _einsum_subscript) from the hot loop.
763 compiled = []
764 for op in tape:
765 if isinstance(op, Barrier):
766 continue
767 k = len(op.wires)
768 gt = op._gate_tensor(k)
769 sub = _einsum_subscript(n_qubits, k, tuple(op.wires))
770 compiled.append((gt, sub))
772 state = jnp.zeros(dim, dtype=_cdtype()).at[0].set(1.0)
773 psi = state.reshape((2,) * n_qubits)
774 for gt, sub in compiled:
775 psi = jnp.einsum(sub, gt, psi)
776 return psi.reshape(dim)
778 @staticmethod
779 def _simulate_mixed(tape: List[Operation], n_qubits: int) -> jnp.ndarray:
780 """Density-matrix simulation kernel.
782 Starts from \\rho = \\vert 0\\rangle\\langle 0\\vert and
783 applies each gate in *tape* via
784 :meth:`~qml_essentials.operations.Operation.apply_to_density`
785 (\\rho -> U\\rho U† for unitaries, \\Sigma_k K_k \\rho K_k\\dagger
786 for Kraus channels).
787 Required for noisy circuits.
789 Args:
790 tape: Ordered list of gate or channel operations to apply.
791 n_qubits: Total number of qubits.
793 Returns:
794 Density matrix of shape ``(2**n_qubits, 2**n_qubits)``.
795 """
796 dim = 2**n_qubits
797 rho = jnp.zeros((dim, dim), dtype=_cdtype()).at[0, 0].set(1.0)
798 for op in tape:
799 rho = op.apply_to_density(rho, n_qubits)
800 return rho
802 @staticmethod
803 def _simulate_and_measure(
804 tape: List[Operation],
805 n_qubits: int,
806 type: str,
807 obs: List[Operation],
808 use_density: bool,
809 shots: Optional[int] = None,
810 key: Optional[jnp.ndarray] = None,
811 ) -> jnp.ndarray:
812 """Run simulation and measurement in a single dispatch.
814 Chooses statevector or density-matrix simulation based on
815 *use_density*, then applies the appropriate measurement function.
816 This eliminates duplicated branching logic in single-sample and
817 batched execution paths.
819 When *shots* is not ``None``, the exact probability distribution is
820 first computed, then ``shots`` samples are drawn from it to produce
821 a noisy estimate of the requested measurement (``"probs"`` or
822 ``"expval"``).
824 Pure-circuit density optimisation — when ``type == "density"``
825 but no noise channels are present on the tape, the density matrix
826 is computed via statevector simulation followed by an outer product
827 ``\\rho = \\vert\\psi\\rangle\\langle\\psi\\vert``
828 instead of evolving the full ``2^n\\times 2^n`` matrix
829 gate by gate. This reduces the per-gate cost from O(4^n) to
830 O(2^n), giving a significant speed-up for medium qubit counts
831 (~4x for 5 qubits).
833 Args:
834 tape: Ordered list of gate/channel operations to apply.
835 n_qubits: Total number of qubits.
836 type: Measurement type (``"state"``/``"probs"``/``"expval"``/
837 ``"density"``).
838 obs: Observables for ``"expval"`` measurements.
839 use_density: If ``True``, use density-matrix simulation.
840 shots: Number of measurement shots. If ``None`` (default),
841 exact analytic results are returned.
842 key: JAX PRNG key for shot sampling. Required when *shots*
843 is not ``None``.
845 Returns:
846 Measurement result (shape depends on *type*).
847 """
848 if use_density:
849 # Check if any operation is actually a noise channel.
850 has_noise = any(isinstance(o, KrausChannel) for o in tape)
851 if has_noise:
852 # Must do full density-matrix evolution for Kraus channels.
853 rho = Script._simulate_mixed(tape, n_qubits)
854 else:
855 # Pure circuit requesting density output: simulate the
856 # statevector (O(depth\times 2^n)) and form # noqa: W605
857 # \rho = \vert\psi\rangle\langle\psi\vert once # noqa: W605
858 # (O(4^n)). This avoids the O(depth\times 4^n) cost of # noqa: W605
859 # evolving the full density matrix gate by gate.
860 state = Script._simulate_pure(tape, n_qubits)
861 rho = jnp.outer(state, jnp.conj(state))
863 if shots is not None and type in ("probs", "expval"):
864 exact_probs = jnp.real(jnp.diag(rho))
865 return Script._sample_shots(
866 exact_probs, n_qubits, type, obs, shots, key
867 )
868 return Script._measure_density(rho, n_qubits, type, obs)
870 state = Script._simulate_pure(tape, n_qubits)
872 if shots is not None and type in ("probs", "expval"):
873 exact_probs = jnp.abs(state) ** 2
874 return Script._sample_shots(exact_probs, n_qubits, type, obs, shots, key)
875 return Script._measure_state(state, n_qubits, type, obs)
877 @staticmethod
878 def _measure_state(
879 state: jnp.ndarray,
880 n_qubits: int,
881 type: str,
882 obs: List[Operation],
883 ) -> jnp.ndarray:
884 """Apply the requested measurement to a pure statevector.
886 Args:
887 state: Statevector of shape ``(2**n_qubits,)``.
888 n_qubits: Total number of qubits.
889 type: Measurement type — one of ``"state"``, ``"probs"``,
890 or ``"expval"``.
891 obs: Observables used when *type* is ``"expval"``.
893 Returns:
894 Measurement result whose shape depends on *type*:
896 - ``"state"`` -> ``(2**n_qubits,)``
897 - ``"probs"`` -> ``(2**n_qubits,)``
898 - ``"expval"`` -> ``(len(obs),)``
900 Raises:
901 ValueError: If *type* is not a recognised measurement type.
902 """
903 if type == "state":
904 return state
906 if type == "probs":
907 return jnp.abs(state) ** 2
909 if type == "expval":
910 # Fast path for single-qubit diagonal observables (PauliZ, etc.)
911 # where d0, d1 are the diagonal elements of the 2x2 observable.
912 # This replaces n_obs tensor contractions with a single |ψ|²
913 # and n_obs reductions over the probability vector.
915 def _is_single_qubit_diag(ob):
916 m = ob.__class__._matrix
917 if m is None or len(ob.wires) != 1:
918 return False
919 # Convert to NumPy to ensure concrete boolean evaluation
920 m_np = np.asarray(m)
921 return np.allclose(m_np - np.diag(np.diag(m_np)), 0)
923 all_single_qubit_diag = all(_is_single_qubit_diag(ob) for ob in obs)
925 if all_single_qubit_diag:
926 probs = jnp.abs(state) ** 2
927 psi_t = probs.reshape((2,) * n_qubits)
928 results = []
929 for ob in obs:
930 q = ob.wires[0]
931 d = np.real(np.diag(np.asarray(ob.__class__._matrix)))
932 # Sum probabilities over all axes except qubit q
933 p_q = jnp.sum(
934 psi_t, axis=tuple(i for i in range(n_qubits) if i != q)
935 )
936 results.append(d[0] * p_q[0] + d[1] * p_q[1])
937 return jnp.array(results)
939 # General path: stack observable matrices and use a single
940 # batched matmul instead of a Python loop of tensor contractions.
941 # O_states[i] = obs[i] |ψ⟩, then ⟨O_i⟩ = Re(⟨ψ|O_states[i]⟩).
942 obs_mats = jnp.stack(
943 [ob.lifted_matrix(n_qubits) for ob in obs], axis=0
944 ) # (n_obs, dim, dim)
945 # Batched matvec: (n_obs, dim, dim) @ (dim,) -> (n_obs, dim)
946 O_states = jnp.einsum("oij,j->oi", obs_mats, state)
947 return jnp.real(jnp.einsum("i,oi->o", jnp.conj(state), O_states))
949 raise ValueError(f"Unknown measurement type: {type!r}")
951 @staticmethod
952 def _measure_density(
953 rho: jnp.ndarray,
954 n_qubits: int,
955 type: str,
956 obs: List[Operation],
957 ) -> jnp.ndarray:
958 """Apply the requested measurement to a density matrix.
960 Args:
961 rho: Density matrix of shape ``(2**n_qubits, 2**n_qubits)``.
962 n_qubits: Total number of qubits.
963 type: Measurement type — one of ``"density"``, ``"probs"``,
964 or ``"expval"``.
965 obs: Observables used when *type* is ``"expval"``.
967 Returns:
968 Measurement result whose shape depends on *type*:
970 - ``"density"`` -> ``(2**n_qubits, 2**n_qubits)``
971 - ``"probs"`` -> ``(2**n_qubits,)``
972 - ``"expval"`` -> ``(len(obs),)``
974 Raises:
975 ValueError: If *type* is ``"state"`` (not valid for mixed circuits)
976 or another unrecognised type.
977 """
978 if type == "density":
979 return rho
981 if type == "probs":
982 return jnp.real(jnp.diag(rho))
984 if type == "expval":
985 # Tr(O \\rho ) = \\Sigma_ij O_ij \\rho _ji
986 # Stack all observable matrices and compute all traces in one
987 # batched operation.
988 obs_mats = jnp.stack(
989 [ob.lifted_matrix(n_qubits) for ob in obs], axis=0
990 ) # (n_obs, dim, dim)
991 # einsum "oij,ji->o" computes Tr(O_o @ \\rho ) for each observable
992 return jnp.real(jnp.einsum("oij,ji->o", obs_mats, rho))
994 raise ValueError(
995 "Measurement type 'state' is not defined for mixed (noisy) circuits. "
996 "Use 'density' instead."
997 )
999 @staticmethod
1000 def _sample_shots(
1001 probs: jnp.ndarray,
1002 n_qubits: int,
1003 type: str,
1004 obs: List[Operation],
1005 shots: int,
1006 key: jnp.ndarray,
1007 ) -> jnp.ndarray:
1008 """Convert exact probabilities into shot-sampled results.
1010 Draws *shots* samples from the computational-basis probability
1011 distribution and returns either estimated probabilities or
1012 shot-based expectation values.
1014 Args:
1015 probs: Exact probability vector of shape ``(2**n_qubits,)``.
1016 n_qubits: Total number of qubits.
1017 type: Measurement type — ``"probs"`` or ``"expval"``.
1018 obs: Observables used when *type* is ``"expval"``.
1019 shots: Number of measurement shots.
1020 key: JAX PRNG key for sampling.
1022 Returns:
1023 Shot-sampled measurement result:
1025 - ``"probs"`` → ``(2**n_qubits,)`` estimated probabilities.
1026 - ``"expval"`` → ``(len(obs),)`` estimated expectation values.
1027 """
1028 dim = 2**n_qubits
1030 # Draw `shots` samples from the computational basis.
1031 # Each sample is an integer in [0, dim) representing a basis state.
1032 samples = jax.random.choice(key, dim, shape=(shots,), p=probs)
1034 # Build a histogram of counts for each basis state.
1035 counts = jnp.zeros(dim, dtype=jnp.int32)
1036 counts = counts.at[samples].add(1)
1037 estimated_probs = counts / shots
1039 if type == "probs":
1040 return estimated_probs
1042 if type == "expval":
1043 # For each observable, compute O from the shot-sampled
1044 # probabilities. For diagonal observables this is exact;
1045 # for general observables we use Tr(O · diag(estimated_probs)).
1046 results = []
1047 for ob in obs:
1048 O_mat = ob.lifted_matrix(n_qubits)
1049 # diagonal approximation from
1050 # computational basis measurements, which is exact for
1051 # diagonal observables like PauliZ)
1052 results.append(jnp.real(jnp.dot(jnp.diag(O_mat), estimated_probs)))
1053 return jnp.array(results)
1055 raise ValueError(
1056 f"Shot simulation is only supported for 'probs' and 'expval', "
1057 f"got {type!r}."
1058 )
1060 def execute(
1061 self,
1062 type: str = "expval",
1063 obs: Optional[List[Operation]] = None,
1064 *,
1065 args: tuple = (),
1066 kwargs: Optional[dict] = None,
1067 in_axes: Optional[Tuple] = None,
1068 shots: Optional[int] = None,
1069 key: Optional[jnp.ndarray] = None,
1070 ) -> jnp.ndarray:
1071 """Execute the circuit and return measurement results.
1073 Args:
1074 type: Measurement type. One of:
1076 - ``"expval"`` — expectation value ⟨ψ|O|ψ⟩ / Tr(O\\rho ) for
1077 each observable in *obs*.
1078 - ``"probs"`` — probability vector of shape ``(2**n,)``.
1079 - ``"state"`` — raw statevector of shape ``(2**n,)``.
1080 - ``"density"`` — full density matrix of shape
1081 ``(2**n, 2**n)``.
1083 obs: Observables required when type is ``"expval"``.
1084 args: Positional arguments forwarded to the circuit function f.
1085 kwargs: Keyword arguments forwarded to f.
1086 in_axes: Batch axes for each element of *args*, following the same
1087 convention as ``jax.vmap``:
1089 - An integer selects that axis of the corresponding array as
1090 the batch dimension.
1091 - ``None`` broadcasts the argument (no batching).
1093 When provided, :meth:`execute` calls ``jax.vmap`` over the
1094 pure simulation kernel and returns results with a leading
1095 batch dimension.
1096 shots: Number of measurement shots for stochastic sampling.
1097 If ``None`` (default), exact analytic results are returned.
1098 Only supported for ``"probs"`` and ``"expval"`` measurement
1099 types.
1100 key: JAX PRNG key for shot sampling. If ``None`` and *shots*
1101 is set, a default key ``jax.random.PRNGKey(0)`` is used.
1103 Returns:
1104 Without in_axes: shape determined by type.
1105 With in_axes: shape ``(B, ...)`` with a leading batch dimension.
1106 """
1107 if obs is None:
1108 obs = []
1109 if kwargs is None:
1110 kwargs = {}
1111 if shots is not None and key is None:
1112 key = jax.random.PRNGKey(0)
1114 # Split single/ parallel execution
1115 # TODO: we might want to unify the n_qubit stuff such that we can eliminate
1116 # the parameter to this method entirely
1117 if in_axes is not None:
1118 return self._execute_batched(
1119 type=type,
1120 obs=obs,
1121 args=args,
1122 kwargs=kwargs,
1123 in_axes=in_axes,
1124 shots=shots,
1125 key=key,
1126 )
1127 else:
1128 tape = self._record(*args, **kwargs)
1129 n_qubits = self._n_qubits or self._infer_n_qubits(tape, obs)
1131 has_noise = any(isinstance(op, KrausChannel) for op in tape)
1132 use_density = type == "density" or has_noise
1134 return self._simulate_and_measure(
1135 tape,
1136 n_qubits,
1137 type,
1138 obs,
1139 use_density,
1140 shots=shots,
1141 key=key,
1142 )
1144 def _execute_batched(
1145 self,
1146 type: str,
1147 obs: List[Operation],
1148 args: tuple,
1149 kwargs: dict,
1150 in_axes: Tuple,
1151 shots: Optional[int] = None,
1152 key: Optional[jnp.ndarray] = None,
1153 ) -> jnp.ndarray:
1154 """Vectorise :meth:`execute` over a batch axis using ``jax.vmap``.
1156 The circuit function is traced once in Python with scalar slices to
1157 record the tape, determine ``n_qubits``, and detect noise. The
1158 resulting pure simulation kernel is then vmapped over the requested
1159 axes.
1161 Memory-aware chunking — before launching the full vmap, the
1162 method estimates peak memory usage. If the full batch would exceed
1163 available RAM (with a safety margin), the batch is automatically
1164 split into sub-batches that fit. Each chunk is vmapped independently
1165 and the results are concatenated. This trades a small amount of
1166 wall-clock time for guaranteed execution without OOM.
1168 When the full batch fits in memory, there is zero overhead — the
1169 memory check is a pure Python arithmetic calculation (no JAX calls).
1171 Args:
1172 type: Measurement type (see :meth:`execute`).
1173 obs: Observables (see :meth:`execute`).
1174 args: Positional arguments for the circuit function.
1175 kwargs: Keyword arguments for the circuit function.
1176 in_axes: One entry per element of *args*. Follows ``jax.vmap``
1177 convention: an int gives the batch axis; ``None`` broadcasts.
1178 shots: Number of measurement shots. If ``None``, exact results.
1179 key: JAX PRNG key for shot sampling.
1181 Returns:
1182 Batched measurement results of shape ``(B, ...)`` where *B* is the
1183 size of the batch dimension.
1185 Raises:
1186 ValueError: If ``len(in_axes) != len(args)``.
1188 Note:
1189 The ``jax.vmap`` call at the end of this method is the exact
1190 boundary to replace with ``jax.shard_map`` for multi-device
1191 execution::
1193 from jax.sharding import PartitionSpec as P, Mesh
1194 result = jax.shard_map(
1195 _single_execute, mesh=mesh,
1196 in_specs=tuple(P(0) if ax is not None else P() for ax in in_axes),
1197 out_specs=P(0),
1198 )(*args)
1199 """
1200 if len(in_axes) != len(args):
1201 raise ValueError(
1202 f"in_axes has {len(in_axes)} entries but args has {len(args)}. "
1203 "Provide one in_axes entry per positional argument."
1204 )
1206 # Determine batch size from the first batched arg
1207 batch_size = 1
1208 for a, ax in zip(args, in_axes):
1209 if ax is not None:
1210 batch_size = a.shape[ax]
1211 break
1213 arg_shapes = tuple(
1214 (a.shape, a.dtype) if hasattr(a, "shape") else type(a) for a in args
1215 )
1216 cache_kwargs = _make_hashable(
1217 {k: v for k, v in kwargs.items() if not isinstance(v, jnp.ndarray)}
1218 )
1220 # TODO: we need to fix the dirty class-level `batch_gate_error` hack
1221 from qml_essentials.gates import UnitaryGates
1223 cache_key = (
1224 type,
1225 in_axes,
1226 arg_shapes,
1227 cache_kwargs,
1228 UnitaryGates.batch_gate_error,
1229 )
1231 # --- Cache-hit fast path (no shots) ---
1232 cached = self._jit_cache.get(cache_key)
1233 if cached is not None and shots is None:
1234 batched_fn, n_qubits, use_density = cached
1235 # Check if we already determined the chunk size for this
1236 # exact batch_size (avoids repeated psutil syscalls).
1237 mem_key = ("_mem", cache_key, batch_size)
1238 cached_chunk = self._jit_cache.get(mem_key)
1239 if cached_chunk is not None:
1240 if cached_chunk >= batch_size:
1241 return batched_fn(*args)
1242 return self._execute_chunked(
1243 batched_fn, args, in_axes, batch_size, cached_chunk
1244 )
1245 chunk_size = self._compute_chunk_size(
1246 n_qubits, batch_size, type, use_density, len(obs)
1247 )
1248 self._jit_cache[mem_key] = chunk_size
1249 if chunk_size >= batch_size:
1250 return batched_fn(*args)
1251 return self._execute_chunked(
1252 batched_fn, args, in_axes, batch_size, chunk_size
1253 )
1255 # Record the tape once using scalar slices of each arg.
1256 # This determines n_qubits and whether noise channels are present
1257 # without running the full batch.
1258 # Note, that we use lax.index_in_dim instead of jnp.take because JAX
1259 # random key arrays do not support jnp.take.
1260 # TODO: fix once that is available in JAX
1261 def _slice_first(a, ax):
1262 """Take the first element along axis *ax*."""
1263 return jax.lax.index_in_dim(a, 0, axis=ax, keepdims=False)
1265 scalar_args = tuple(
1266 _slice_first(a, ax) if ax is not None else a for a, ax in zip(args, in_axes)
1267 )
1268 tape = self._record(*scalar_args, **kwargs)
1269 n_qubits = self._n_qubits or self._infer_n_qubits(tape, obs)
1270 has_noise = any(isinstance(op, KrausChannel) for op in tape)
1271 use_density = type == "density" or has_noise
1273 chunk_size = self._compute_chunk_size(
1274 n_qubits, batch_size, type, use_density, len(obs)
1275 )
1277 # Re-recording inside this function is necessary: the tape may
1278 # contain operations whose matrices depend on the batched argument
1279 # (e.g. RX(theta) where theta is a JAX tracer). jax.vmap traces
1280 # this function once and generates a single XLA computation for
1281 # the entire batch.
1282 if shots is not None and type in ("probs", "expval"):
1283 # Shot mode: compute exact probabilities, then sample.
1284 # The shot key is appended as an extra vmapped argument.
1285 def _single_execute_shots(*single_args_and_key):
1286 *single_args, shot_key = single_args_and_key
1287 single_tape = self._record(*single_args, **kwargs)
1288 exact_result = self._simulate_and_measure(
1289 single_tape, n_qubits, "probs", obs, use_density
1290 )
1291 return Script._sample_shots(
1292 exact_result, n_qubits, type, obs, shots, shot_key
1293 )
1295 shot_keys = jax.random.split(key, batch_size)
1296 shot_in_axes = in_axes + (0,) # key is batched over axis 0
1297 shot_args = args + (shot_keys,)
1299 # Shot-mode uses a separate cache key (includes shots)
1300 shot_cache_key = (
1301 type,
1302 "shots",
1303 shots,
1304 in_axes,
1305 arg_shapes,
1306 UnitaryGates.batch_gate_error,
1307 )
1308 cached_shot = self._jit_cache.get(shot_cache_key)
1309 if cached_shot is not None:
1310 batched_fn = cached_shot[0]
1311 else:
1312 batched_fn = jax.jit(
1313 jax.vmap(_single_execute_shots, in_axes=shot_in_axes)
1314 )
1315 self._jit_cache[shot_cache_key] = (batched_fn, n_qubits, use_density)
1317 if chunk_size >= batch_size:
1318 return batched_fn(*shot_args)
1319 return self._execute_chunked(
1320 batched_fn, shot_args, shot_in_axes, batch_size, chunk_size
1321 )
1323 def _single_execute(*single_args):
1324 single_tape = self._record(*single_args, **kwargs)
1325 return self._simulate_and_measure(
1326 single_tape, n_qubits, type, obs, use_density
1327 )
1329 # Wrapping the vmapped function in jax.jit has two effects:
1330 # 1. Multi-core utilisation — the JIT-compiled XLA program can
1331 # use intra-op parallelism to distribute independent SIMD lanes
1332 # across CPU threads, unlike an eager vmap which runs
1333 # single-threaded.
1334 # 2. Compilation caching — subsequent calls with the same input
1335 # shapes reuse the compiled kernel and skip all Python-level
1336 # tracing, eliminating the O(B\\times circuit_depth) Python overhead.
1337 #
1338 # The compiled function is cached on this Script instance,
1339 # keyed on (type, in_axes, arg_shapes). Repeated calls with the
1340 # same structure (e.g. training iterations) skip both Python-level
1341 # tracing and XLA compilation entirely — they jump straight to the
1342 # cache check at the top of this method.
1343 # NOTE: when altering properties of the model, this might not get re-compiled
1344 # TODO: we might want to rework the data_reupload mechanism at some point
1345 batched_fn = jax.jit(jax.vmap(_single_execute, in_axes=in_axes))
1346 # Cache the function together with metadata for fast-path memory
1347 # checks on subsequent calls.
1348 self._jit_cache[cache_key] = (batched_fn, n_qubits, use_density)
1350 if chunk_size >= batch_size:
1351 return batched_fn(*args)
1352 return self._execute_chunked(batched_fn, args, in_axes, batch_size, chunk_size)
1354 def draw(
1355 self,
1356 figure: str = "text",
1357 args: tuple = (),
1358 kwargs: Optional[dict] = None,
1359 **draw_kwargs: Any,
1360 ) -> Union[str, Any]:
1361 """Draw the quantum circuit.
1363 Records the tape by calling the circuit function with the given
1364 arguments, then renders the resulting gate sequence.
1366 Args:
1367 figure: Rendering backend. One of:
1369 - ``"text"`` — ASCII art (returned as a ``str``).
1370 - ``"mpl"`` — Matplotlib figure (returns ``(fig, ax)``).
1371 - ``"tikz"`` — LaTeX/TikZ code via ``quantikz``
1372 (returns a :class:`TikzFigure`).
1373 - ``"pulse"`` — Pulse schedule plot (returns ``(fig, axes)``).
1375 args: Positional arguments forwarded to the circuit function
1376 to record the tape.
1377 kwargs: Keyword arguments forwarded to the circuit function.
1378 **draw_kwargs: Extra options forwarded to the rendering backend:
1380 - ``gate_values`` (bool): Show numeric gate angles instead of
1381 symbolic \\theta_i labels. Default ``False``.
1382 - ``show_carrier`` (bool): For ``"pulse"`` mode, overlay the
1383 carrier-modulated waveform. Default ``False``.
1385 Returns:
1386 Depends on *figure*:
1388 - ``"text"`` -> ``str``
1389 - ``"mpl"`` -> ``(matplotlib.figure.Figure, matplotlib.axes.Axes)``
1390 - ``"tikz"`` -> :class:`TikzFigure`
1391 - ``"pulse"`` -> ``(matplotlib.figure.Figure, numpy.ndarray)``
1393 Raises:
1394 ValueError: If *figure* is not one of the supported modes.
1395 """
1396 if figure not in ("text", "mpl", "tikz", "pulse"):
1397 raise ValueError(
1398 f"Invalid figure mode: {figure!r}. "
1399 "Must be 'text', 'mpl', 'tikz', or 'pulse'."
1400 )
1402 if kwargs is None:
1403 kwargs = {}
1405 if figure == "pulse":
1406 from qml_essentials.drawing import draw_pulse_schedule
1408 events = self.pulse_events(*args, **kwargs)
1409 n_qubits = (
1410 self._n_qubits
1411 or max((w for ev in events for w in ev.wires), default=0) + 1
1412 )
1413 return draw_pulse_schedule(events, n_qubits, **draw_kwargs)
1415 tape = self._record(*args, **kwargs)
1416 n_qubits = self._n_qubits or self._infer_n_qubits(tape, [])
1418 # Filter out noise channels for drawing — they clutter the diagram
1419 ops = [op for op in tape if not isinstance(op, KrausChannel)]
1421 if figure == "text":
1422 return draw_text(ops, n_qubits)
1423 elif figure == "mpl":
1424 return draw_mpl(ops, n_qubits, **draw_kwargs)
1425 else: # tikz
1426 return draw_tikz(ops, n_qubits, **draw_kwargs)