Coverage for qml_essentials / yaqsi.py: 93%
511 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-05-16 10:19 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-05-16 10:19 +0000
1from functools import reduce
2from typing import Any, Callable, List, Optional, Tuple, Union
3import threading
5import diffrax
6import jax
7import jax.numpy as jnp
8import jax.scipy.linalg
9import equinox as eqx
10import numpy as np # needed to prevent jitting some operations
12from qml_essentials.operations import (
13 Barrier,
14 Hermitian,
15 ParametrizedHamiltonian,
16 Operation,
17 KrausChannel,
18 PauliZ,
19 _einsum_subscript,
20 _cdtype,
21)
22from qml_essentials.tape import recording, pulse_recording
23from qml_essentials.drawing import draw_text, draw_mpl, draw_tikz
25import logging
27log = logging.getLogger(__name__)
30# def _args_contain_tracer(args) -> bool:
31# """Return True if any leaf in *args* is a JAX tracer.
33# Used by :meth:`Script._execute_batched` to detect that the call is
34# happening under an outer JAX transformation (``jit``/``vmap``/``grad``/
35# ``jacrev`` etc.). When that is the case the per-Script
36# ``_jit_cache`` must be bypassed: a previously cached
37# ``jax.jit(jax.vmap(...))`` was built under a different outer trace
38# and re-using it would leak that trace's tracers (raising
39# ``UnexpectedTracerError`` on the second transform). XLA compilation
40# artefacts are still cached at the JAX level by jaxpr signature, so
41# bypassing only the local Python wrapper has negligible runtime cost.
42# """
43# from jax.core import Tracer
44# for leaf in jax.tree_util.tree_leaves(args):
45# if isinstance(leaf, Tracer):
46# return True
47# return False
50def _make_hashable(obj):
51 """Recursively convert an object into a hashable form for cache keys.
53 - ``dict`` → sorted tuple of ``(key, _make_hashable(value))`` pairs
54 - ``list`` → tuple of ``_make_hashable(element)``
55 - ``set`` → frozenset of ``_make_hashable(element)``
56 - everything else is returned as-is (assumed hashable)
57 """
58 if isinstance(obj, dict):
59 return tuple(sorted((k, _make_hashable(v)) for k, v in obj.items()))
60 if isinstance(obj, (list, tuple)):
61 return tuple(_make_hashable(x) for x in obj)
62 if isinstance(obj, set):
63 return frozenset(_make_hashable(x) for x in obj)
64 return obj
class Yaqsi:
    """Namespace for Hamiltonian-evolution gate factories and measurement helpers.

    Every member is a class- or static method; the only state is the
    class-level solver cache and the solver defaults defined below.
    """

    # TODO: generally, I would like to merge this into operations or vice-versa
    # and only keep Script here. It's not clear how to do this though.

    # Class-level cache for JIT-compiled evolve solvers.  Keyed on
    # (coeff-fn code objects, dim, atol, rtol, max_steps, throw, solver,
    # magnus_steps) so that all evolve() calls with the same pulse shape
    # function and matrix size share one compiled XLA program.  This turns
    # O(n_gates) JIT compilations into O(n_distinct_pulse_shapes) during
    # pulse-mode circuit building.
    _evolve_solver_cache: dict = {}
    _evolve_solver_cache_lock = threading.Lock()

    # Whether to call ``jax.clear_caches()`` between memory-aware
    # chunks in :meth:`Script._execute_chunked`.  Default ``False``:
    # clearing caches between chunks forces XLA to recompile the same
    # batched program for every chunk, which is a major performance hit
    # when many chunks are needed.  Set ``True`` only if you observe
    # OOM growth across chunks.
    _clear_caches_between_chunks: bool = False

    # Default solver knobs for parametrized (time-dependent) evolution.
    # These can be overridden per-call via the **odeint_kwargs of
    # ``evolve()`` or globally via :meth:`set_solver_defaults`.
    #
    # ``max_steps`` is the hard cap on accepted ODE steps.  Pulse-level
    # workloads at on-resonance carriers (ω_c ≈ ω_q) require many more
    # steps than the diffrax default during JIT — 2**13 = 8192 is
    # large enough for realistic single- and two-qubit pulses while
    # remaining cheap to compile.
    #
    # ``throw`` controls whether diffrax raises on solver failure
    # (e.g. ``MaxStepsReached``).  When set to ``False`` the gate
    # factory instead returns a NaN-filled unitary so the calling
    # optimiser sees a well-defined (but useless) result and can
    # gracefully reject the candidate.
    #
    # ``solver`` selects the time-integration backend for the ODE
    # ``dU/dt = -i H(t) U``:
    #
    # * ``"dopri8"`` (default) — adaptive Dormand-Prince 8(7) via
    #   diffrax.  Robust but expensive on highly oscillatory drives
    #   because the step controller resolves every fast cycle.
    # * ``"dopri5"`` — adaptive Dormand-Prince 5(4) via diffrax
    #   (``diffrax.Dopri5``), driven by the same step-size controller.
    # * ``"magnus2"`` — commutator-free Magnus, 2nd order (midpoint
    #   rule) on a fixed ``magnus_steps`` grid via ``jax.lax.scan``.
    #   One ``expm`` per step.
    # * ``"magnus4"`` — commutator-free Magnus, 4th order (CFM4:2 of
    #   Blanes & Moan) on a fixed ``magnus_steps`` grid.  Two ``H``
    #   evaluations and two ``expm`` per step.
    #
    # ``magnus_steps`` is the number of fixed substeps for the Magnus
    # integrators (ignored by the adaptive ``dopri8`` / ``dopri5``
    # solvers).  Choose it so that ``h = T/N`` resolves the fastest
    # oscillation in ``H(t)`` (~few steps per period of the highest
    # frequency).
    _solver_defaults: dict = {
        "max_steps": 2**13,
        "throw": True,
        "solver": "dopri8",
        "magnus_steps": 256,
    }
    _valid_solvers = ("dopri8", "dopri5", "magnus2", "magnus4")

    @classmethod
    def set_solver_defaults(
        cls,
        max_steps: Optional[int] = None,
        throw: Optional[bool] = None,
        solver: Optional[str] = None,
        magnus_steps: Optional[int] = None,
    ) -> dict:
        """Update class-level solver defaults; return the previous values.

        The returned dictionary is suitable for restoring the previous
        defaults via ``set_solver_defaults(**prev)``.  All arguments are
        validated before anything is written, so a rejected call leaves
        the defaults untouched.

        Args:
            max_steps: New default for ``max_steps`` (ignored if ``None``).
            throw: New default for ``throw`` (ignored if ``None``).
            solver: New default solver name; must be one of
                ``_valid_solvers`` (ignored if ``None``).
            magnus_steps: New default number of fixed Magnus substeps;
                must be at least 1 (ignored if ``None``).

        Returns:
            Dictionary with the previous values of the updated keys.

        Raises:
            ValueError: If ``solver`` is unknown or ``magnus_steps`` < 1.
        """
        # Collect and validate every requested change first; only then
        # touch the shared defaults.
        updates: dict = {}
        if max_steps is not None:
            updates["max_steps"] = int(max_steps)
        if throw is not None:
            updates["throw"] = bool(throw)
        if solver is not None:
            if solver not in cls._valid_solvers:
                raise ValueError(
                    f"Unknown solver {solver!r}; expected one of {cls._valid_solvers}"
                )
            updates["solver"] = solver
        if magnus_steps is not None:
            n_magnus = int(magnus_steps)
            if n_magnus < 1:
                raise ValueError(f"magnus_steps must be >= 1, got {magnus_steps!r}")
            updates["magnus_steps"] = n_magnus

        prev = {key: cls._solver_defaults[key] for key in updates}
        cls._solver_defaults.update(updates)
        return prev

    @classmethod
    def _store_evolve_solver(cls, cache_key: tuple, solve: Callable) -> Callable:
        """Cache a compiled evolve solver unless another thread won the race.

        Args:
            cache_key: Key under which the solver is stored.
            solve: Freshly built solver callable.

        Returns:
            The solver that ends up in the cache for ``cache_key`` — either
            ``solve`` or the one stored earlier by another caller.
        """
        with cls._evolve_solver_cache_lock:
            # setdefault keeps an already-present entry and returns it.
            return cls._evolve_solver_cache.setdefault(cache_key, solve)

    @classmethod
    def clear_evolve_solver_cache(cls) -> None:
        """Drop every cached compiled evolve solver.

        Call this whenever the coefficient functions referenced by the
        cache keys are rebuilt (e.g. when :class:`PulseGates` swaps in
        a new pulse envelope, RWA flag or frame).  Without an explicit
        eviction the cache keeps the old code objects alive and would
        also retain XLA programs that no longer match any active
        coefficient function.
        """
        with cls._evolve_solver_cache_lock:
            cls._evolve_solver_cache.clear()

    @classmethod
    def _parse_evolve_solver_options(cls, odeint_kwargs: dict) -> tuple:
        """Pop and validate solver options from ``evolve(..., **odeint_kwargs)``.

        Recognised keys are removed from ``odeint_kwargs`` in place; missing
        keys fall back to :attr:`_solver_defaults` (or, for the tolerances,
        to ``1.0e-10`` in x64 mode and ``1.4e-8`` otherwise).

        Args:
            odeint_kwargs: Mutable dict of user-supplied solver options.

        Returns:
            Tuple ``(atol, rtol, max_steps, throw, solver_name, magnus_steps)``.

        Raises:
            ValueError: If the solver name is unknown, or a Magnus solver
                is requested with ``magnus_steps`` < 1.
        """
        default_tol = 1.0e-10 if jax.config.x64_enabled else 1.4e-8
        defaults = cls._solver_defaults

        atol = odeint_kwargs.pop("atol", default_tol)
        rtol = odeint_kwargs.pop("rtol", default_tol)
        max_steps = int(odeint_kwargs.pop("max_steps", defaults["max_steps"]))
        throw = bool(odeint_kwargs.pop("throw", defaults["throw"]))
        solver_name = str(odeint_kwargs.pop("solver", defaults["solver"]))
        if solver_name not in cls._valid_solvers:
            raise ValueError(
                f"Unknown solver {solver_name!r}; expected one of {cls._valid_solvers}"
            )
        magnus_steps = int(odeint_kwargs.pop("magnus_steps", defaults["magnus_steps"]))
        # The Magnus solvers divide the time span by ``magnus_steps``; a
        # non-positive value would silently produce a meaningless result.
        if solver_name in ("magnus2", "magnus4") and magnus_steps < 1:
            raise ValueError(f"magnus_steps must be >= 1, got {magnus_steps!r}")
        return atol, rtol, max_steps, throw, solver_name, magnus_steps

    @classmethod
    def _build_magnus_evolve_solver(
        cls,
        cache_key: tuple,
        coeff_fns: Tuple[Callable, ...],
        n_terms: int,
        dim: int,
        solver_name: str,
        magnus_steps: int,
    ) -> Callable:
        """Build and cache a fixed-step commutator-free Magnus solver.

        Args:
            cache_key: Key under which the solver is stored in the cache.
            coeff_fns: One coefficient function ``f_i(params_i, t)`` per term.
            n_terms: Number of Hamiltonian terms.
            dim: Matrix dimension of each term.
            solver_name: ``"magnus2"`` selects the midpoint rule; any other
                value selects the 4th-order CFM4:2 scheme.
            magnus_steps: Number of equal substeps on ``[t0, t1]``.

        Returns:
            The cached solver ``(neg_iH_split, params, t0, t1) -> U``.
        """
        import math  # local import: ``math`` is not imported at file level

        cdtype = jnp.complex128 if jax.config.x64_enabled else jnp.complex64
        n_steps = magnus_steps
        use_midpoint = solver_name == "magnus2"

        # CFM4:2 constants: Gauss-Legendre nodes (c1, c2) and the weights
        # (a1, a2) combining the two evaluations in each exponent.
        sqrt3 = math.sqrt(3.0)
        c1 = 0.5 - sqrt3 / 6.0
        c2 = 0.5 + sqrt3 / 6.0
        a1 = 0.25 + sqrt3 / 6.0
        a2 = 0.25 - sqrt3 / 6.0

        @eqx.filter_jit
        def _solve(neg_iH_split, params, t0, t1):
            # Reconstruct the per-term complex matrices ``-i H_i`` from their
            # split (Re, Im) representation so the coefficient sum is a single
            # complex tensordot.
            neg_iH = (neg_iH_split[:, 0] + 1j * neg_iH_split[:, 1]).astype(cdtype)

            h = (t1 - t0) / n_steps

            def generator_at(t):
                # sum_i f_i(params_i, t) * (-i H_i); each coefficient is
                # coerced to a true scalar before stacking.
                coeffs = jnp.stack(
                    [
                        jnp.asarray(coeff_fns[i](params[i], t)).reshape(())
                        for i in range(n_terms)
                    ]
                ).astype(cdtype)
                return jnp.tensordot(coeffs, neg_iH, axes=1)

            if use_midpoint:

                def step(U, n):
                    tn = t0 + n * h
                    Omega = h * generator_at(tn + 0.5 * h)
                    return jax.scipy.linalg.expm(Omega) @ U, None

            else:

                def step(U, n):
                    tn = t0 + n * h
                    G1 = generator_at(tn + c1 * h)
                    G2 = generator_at(tn + c2 * h)
                    Omega_a = h * (a1 * G1 + a2 * G2)
                    Omega_b = h * (a2 * G1 + a1 * G2)
                    # CFM4:2 ordering (Blanes & Moan 2006, Table II):
                    #   U_{n+1} = exp(Ω_b) · exp(Ω_a) · U_n.
                    U_next = (
                        jax.scipy.linalg.expm(Omega_b)
                        @ jax.scipy.linalg.expm(Omega_a)
                        @ U
                    )
                    return U_next, None

            U0 = jnp.eye(dim, dtype=cdtype)
            U_final, _ = jax.lax.scan(step, U0, jnp.arange(n_steps))
            return U_final

        return cls._store_evolve_solver(cache_key, _solve)

    @classmethod
    def _build_diffrax_evolve_solver(
        cls,
        cache_key: tuple,
        coeff_fns: Tuple[Callable, ...],
        n_terms: int,
        dim: int,
        atol: float,
        rtol: float,
        max_steps: int,
        throw: bool,
        solver_name: str,
        _rdtype,
    ) -> Callable:
        """Build and cache an adaptive diffrax-based evolve solver.

        Args:
            cache_key: Key under which the solver is stored in the cache.
            coeff_fns: One coefficient function ``f_i(params_i, t)`` per term.
            n_terms: Number of Hamiltonian terms.
            dim: Matrix dimension of each term.
            atol: Absolute tolerance of the PID step-size controller.
            rtol: Relative tolerance of the PID step-size controller.
            max_steps: Step budget forwarded to ``diffrax.diffeqsolve``.
            throw: Forwarded to ``diffrax.diffeqsolve``; when ``False`` a
                non-successful solve yields a NaN-filled matrix.
            solver_name: ``"dopri8"`` selects ``diffrax.Dopri8``; any other
                value selects ``diffrax.Dopri5``.
            _rdtype: Real dtype used for the initial state.

        Returns:
            The cached solver ``(neg_iH_split, params, t0, t1) -> U``.
        """
        solver = diffrax.Dopri8() if solver_name == "dopri8" else diffrax.Dopri5()
        stepsize_controller = diffrax.PIDController(atol=atol, rtol=rtol)

        @eqx.filter_jit
        def _solve(neg_iH_split, params, t0, t1):
            """Solve dU/dt = sum_i f_i(p_i, t) * (-iH_i) * U from t0 to t1.

            ``neg_iH_split`` has shape ``(n_terms, 2, dim, dim)`` with
            ``[:, 0]`` = Re(-iH_i) and ``[:, 1]`` = Im(-iH_i).
            ``params`` is a list/tuple of length ``n_terms`` carrying
            each term's coefficient parameters.  The state ``y`` has
            shape ``(2, dim, dim)`` with ``y[0] = Re(U)`` and
            ``y[1] = Im(U)``.
            """
            A_all = neg_iH_split[:, 0]
            B_all = neg_iH_split[:, 1]

            def rhs(t, y, args):
                # Each coefficient function must return a scalar value; some
                # call sites pass a shape-(1,) param array, so coerce to a
                # true scalar before stacking.
                coeffs = jnp.stack(
                    [
                        jnp.asarray(coeff_fns[i](args[i], t)).reshape(())
                        for i in range(n_terms)
                    ]
                )
                u_re, u_im = y[0], y[1]
                A_eff = jnp.tensordot(coeffs, A_all, axes=1)
                B_eff = jnp.tensordot(coeffs, B_all, axes=1)
                # (A + iB)(u_re + i u_im), split into real and imaginary parts.
                du_re = A_eff @ u_re - B_eff @ u_im
                du_im = A_eff @ u_im + B_eff @ u_re
                return jnp.stack([du_re, du_im], axis=0)

            # U(t0) = I, stored as (Re, Im).
            y0 = jnp.stack(
                [
                    jnp.eye(dim, dtype=_rdtype),
                    jnp.zeros((dim, dim), dtype=_rdtype),
                ],
                axis=0,
            )

            sol = diffrax.diffeqsolve(
                diffrax.ODETerm(rhs),
                solver,
                t0=t0,
                t1=t1,
                dt0=None,
                y0=y0,
                args=params,
                stepsize_controller=stepsize_controller,
                max_steps=max_steps,
                throw=throw,
            )

            y_final = sol.ys[0]
            U = y_final[0] + 1j * y_final[1]

            if not throw:
                # With throw=False a failed solve does not raise; mark the
                # result as unusable instead of returning a partial state.
                successful = sol.result == diffrax.RESULTS.successful
                U = jnp.where(successful, U, jnp.full_like(U, jnp.nan))
            return U

        return cls._store_evolve_solver(cache_key, _solve)

    @staticmethod
    def _partial_trace_single(
        rho: jnp.ndarray,
        n_qubits: int,
        keep: List[int],
    ) -> jnp.ndarray:
        """Partial trace of a single density matrix (no batch dimension).

        Args:
            rho: Density matrix reshaped internally to ``(2,) * (2 * n_qubits)``.
            n_qubits: Total number of qubits.
            keep: Qubit indices to keep.  Only membership matters; the kept
                qubits retain their original relative order.

        Returns:
            Reduced density matrix of shape ``(2**k, 2**k)``, ``k = len(keep)``.
        """
        rho_t = rho.reshape((2,) * (2 * n_qubits))

        trace_out = sorted(set(range(n_qubits)) - set(keep))

        # Trace the highest index first so the remaining (lower) qubit
        # indices still address the correct axes afterwards.
        for q in reversed(trace_out):
            n_remaining = rho_t.ndim // 2
            rho_t = jnp.trace(rho_t, axis1=q, axis2=q + n_remaining)

        dim = 2 ** len(keep)
        return rho_t.reshape(dim, dim)

    @classmethod
    def partial_trace(
        cls,
        rho: jnp.ndarray,
        n_qubits: int,
        keep: List[int],
    ) -> jnp.ndarray:
        """Partial trace of a density matrix, keeping only the specified qubits.

        Supports both single density matrices of shape ``(2**n, 2**n)`` and
        batched density matrices of shape ``(B, 2**n, 2**n)``.

        Args:
            rho: Density matrix of shape ``(2**n, 2**n)`` or ``(B, 2**n, 2**n)``.
            n_qubits: Total number of qubits.
            keep: List of qubit indices to *keep* (0-indexed).

        Returns:
            Reduced density matrix of shape ``(2**k, 2**k)`` or ``(B, 2**k, 2**k)``
            where *k* = ``len(keep)``.
        """
        dim = 2**n_qubits
        if rho.shape == (dim, dim):
            return cls._partial_trace_single(rho, n_qubits, keep)
        # Batched: shape (B, dim, dim)
        return jax.vmap(lambda r: cls._partial_trace_single(r, n_qubits, keep))(rho)

    @staticmethod
    def _marginalize_probs_single(
        probs: jnp.ndarray,
        target_shape: Tuple[int],
        trace_out: Tuple[int],
    ) -> jnp.ndarray:
        """Marginalize a single probability vector (no batch dimension).

        Args:
            probs: Probability vector, reshaped to ``target_shape``.
            target_shape: Per-qubit tensor shape.
            trace_out: Axes to sum over, applied one after another in the
                given order.

        Returns:
            Flattened marginal probability vector.
        """
        probs_t = probs.reshape(target_shape)

        for q in trace_out:
            probs_t = probs_t.sum(axis=q)

        return probs_t.ravel()

    @classmethod
    def marginalize_probs(
        cls,
        probs: jnp.ndarray,
        n_qubits: int,
        keep: Tuple[int],
    ) -> jnp.ndarray:
        """Marginalize a probability vector to keep only the specified qubits.

        Accepts a single probability vector of shape ``(2**n,)`` or a batch
        of shape ``(B, 2**n)``.  The input is always reshaped to
        ``(-1, 2**n)`` first, so the result always carries a leading batch
        axis — a single vector yields shape ``(1, 2**k)``.

        Args:
            probs: Probability vector of shape ``(2**n,)`` or ``(B, 2**n)``.
            n_qubits: Total number of qubits.
            keep: Qubit indices to *keep* (0-indexed).

        Returns:
            Marginalized probabilities of shape ``(B, 2**k)`` (``B = 1`` for
            an unbatched input) where *k* = ``len(keep)``.
        """
        dim = 2**n_qubits
        # Descending order: summing a higher axis first leaves the indices
        # of the lower axes unchanged.
        trace_out = tuple(q for q in range(n_qubits - 1, -1, -1) if q not in keep)
        target_shape = (2,) * n_qubits

        return jax.vmap(
            lambda p: cls._marginalize_probs_single(p, target_shape, trace_out)
        )(probs.reshape(-1, dim))

    @classmethod
    def build_parity_observable(
        cls,
        qubit_group: List[int],
    ) -> "Hermitian":
        """Build a multi-qubit parity observable.

        Args:
            qubit_group: List of qubit indices for the parity measurement.
                Must be non-empty (the Kronecker reduction has no initial
                value).

        Returns:
            A :class:`Hermitian` operation whose matrix is the Z parity
            tensor product and whose wires match the given qubits.
        """
        Z = PauliZ._matrix
        mat = reduce(jnp.kron, [Z] * len(qubit_group))
        return Hermitian(matrix=mat, wires=qubit_group, record=False)

    @classmethod
    def evolve(
        cls,
        hamiltonian: Union["Hermitian", "ParametrizedHamiltonian"],
        name: Optional[str] = None,
        **odeint_kwargs: Any,
    ) -> Callable:
        """Return a gate-factory for Hamiltonian time evolution.

        Supports two modes:

        Static — when *hamiltonian* is a :class:`Hermitian`::

            gate = evolve(Hermitian(H_mat, wires=0))
            gate(t=0.5)  # U = exp(-i*0.5*H)

        Time-dependent — when *hamiltonian* is a
        :class:`ParametrizedHamiltonian` (created via ``coeff_fn * Hermitian``)::

            H_td = coeff_fn * Hermitian(H_mat, wires=0)
            gate = evolve(H_td)
            gate([A, sigma], T)  # U via ODE: dU/dt = -i f(p,t) H * U

        The time-dependent case solves the Schrödinger equation numerically
        with the selected solver: an adaptive diffrax Runge-Kutta method
        (``"dopri8"`` by default, or ``"dopri5"``) or a fixed-step
        commutator-free Magnus integrator (``"magnus2"`` / ``"magnus4"``).

        Args:
            hamiltonian: Either a :class:`Hermitian` (static evolution) or a
                :class:`ParametrizedHamiltonian` (time-dependent evolution).
            name: Name passed to the :class:`Operation` objects the factory
                creates.
            **odeint_kwargs: Solver options; only used in the time-dependent
                mode.  Recognised keys:

                - ``atol``, ``rtol`` — tolerances for the adaptive step-size
                  controller (default ``1.4e-8`` in fp32 mode, ``1.0e-10``
                  in fp64 mode).
                - ``max_steps``, ``throw``, ``solver``, ``magnus_steps`` —
                  default to the values in ``Yaqsi._solver_defaults``.

        Returns:
            A callable gate factory.  Signature depends on the mode:

            - Static: ``(t, wires=0) -> Operation``
            - Time-dependent: ``(coeff_args, T) -> Operation``

        Raises:
            TypeError: If *hamiltonian* is neither ``Hermitian`` nor
                ``ParametrizedHamiltonian``.
        """
        if isinstance(hamiltonian, Hermitian):
            return cls._evolve_static(hamiltonian, name=name)
        if isinstance(hamiltonian, ParametrizedHamiltonian):
            return cls._evolve_parametrized(hamiltonian, name=name, **odeint_kwargs)
        raise TypeError(
            f"evolve() expects a Hermitian or ParametrizedHamiltonian, "
            f"got {type(hamiltonian)}"
        )

    @staticmethod
    def _evolve_static(hermitian: "Hermitian", name: Optional[str] = None) -> Callable:
        """Gate factory for static Hamiltonian evolution U = exp(-i t H).

        Args:
            hermitian: Operation whose ``matrix`` is the Hamiltonian ``H``.
            name: Name passed to the created :class:`Operation`.

        Returns:
            Callable ``(t, wires=0) -> Operation``.
        """
        H_mat = hermitian.matrix

        def _apply(t: float, wires: Union[int, List[int]] = 0) -> "Operation":
            U = jax.scipy.linalg.expm(-1j * t * H_mat)
            return Operation(wires=wires, matrix=U, name=name)

        return _apply

    @classmethod
    def _evolve_parametrized(
        cls,
        ph: "ParametrizedHamiltonian",
        name: Optional[str] = None,
        **odeint_kwargs: Any,
    ) -> Callable:
        """Gate factory for time-dependent (multi-term) Hamiltonian evolution.

        Solves the matrix ODE

            dU/dt = -i [sum_i f_i(params_i, t) * H_i] * U,    U(t0) = I

        with the selected solver (adaptive diffrax Dopri8/Dopri5, or a
        fixed-step Magnus integrator).  The Hamiltonian may contain one or
        more ``coeff_fn * Hermitian`` terms (see
        :class:`ParametrizedHamiltonian`); the single-term case is the
        usual ``coeff_fn * Hermitian`` and is fully backward compatible.

        Implementation notes:

        - To avoid diffrax's experimental complex dtype path, the diffrax
          solvers integrate in real arithmetic.  Writing
          ``-iH_i = A_i + i B_i`` and ``U = U_re + i U_im``, each term
          contributes::

              d(U_re)/dt += f_i(p_i,t) * (A_i @ U_re - B_i @ U_im)
              d(U_im)/dt += f_i(p_i,t) * (A_i @ U_im + B_i @ U_re)

        - ``-iH_i`` is precomputed once per term and stacked into a
          ``(n_terms, 2, dim, dim)`` real array, contracted via
          ``tensordot`` against the per-step coefficient vector
          ``c = [f_0(p_0,t), ..., f_{n-1}(p_{n-1},t)]``.

        - The JIT-compiled solver is cached per coefficient-function code
          tuple (and ``dim``, tolerances, solver options) so multiple
          ``evolve()`` calls with the same pulse shape — but different
          Hamiltonian matrices or parameters — reuse the same compiled
          XLA program.

        TODO: switch back once diffrax is stable with complex arithmetic.

        Args:
            ph: A :class:`ParametrizedHamiltonian` (one or more terms).
            name: Name passed to the created :class:`Operation`.
            **odeint_kwargs: Solver options.  Recognised keys:

                - ``atol``, ``rtol`` — absolute/relative tolerances for the
                  step-size controller (default ``1.4e-8`` in fp32 mode,
                  ``1.0e-10`` in fp64 mode).
                - ``max_steps`` — hard cap on accepted ODE steps
                  (default ``Yaqsi._solver_defaults['max_steps']``,
                  currently ``2**13``).  Increase this if the integrator
                  raises ``MaxStepsReached`` for a stiff/oscillatory
                  pulse Hamiltonian.
                - ``throw`` — whether to raise on solver failure
                  (default ``Yaqsi._solver_defaults['throw']``,
                  currently ``True``).  When ``False``, a failed
                  integration returns a NaN-filled unitary instead of
                  raising; this is the recommended setting for inner
                  loops of an optimiser (e.g. QOC Stage 0) so a single
                  pathological candidate cannot abort the whole run.
                - ``solver`` — one of ``Yaqsi._valid_solvers``.
                - ``magnus_steps`` — number of fixed substeps for the
                  Magnus solvers.

        Returns:
            Callable ``(coeff_args, T) -> Operation``.
        """
        coeff_fns = ph.coeff_fns  # tuple of callables
        H_mats = ph.H_mats  # tuple of (dim, dim)
        wires = ph.wires
        n_terms = ph.n_terms
        dim = H_mats[0].shape[0]

        # Pre-compute -i*H_i for each term and split into real / imaginary
        # parts so the ODE RHS uses only real arithmetic.  Final shape:
        # (n_terms, 2, dim, dim).
        neg_iH_split = jnp.stack(
            [
                jnp.stack([jnp.real(-1j * H_mat), jnp.imag(-1j * H_mat)], axis=0)
                for H_mat in H_mats
            ],
            axis=0,
        )

        # Real dtype matching the precision mode.
        _rdtype = jnp.float64 if jax.config.x64_enabled else jnp.float32

        atol, rtol, max_steps, throw, solver_name, magnus_steps = (
            cls._parse_evolve_solver_options(odeint_kwargs)
        )

        # Cache key: every coeff fn's code object (same shape of pulse
        # fns -> same JIT program) plus dim, tolerances, and solver
        # budget / throw flag (different budgets mean different XLA
        # programs).  We use the code object itself (hashable, identity-
        # equal) rather than ``id(fn.__code__)``: ids can be reused for
        # later code objects after the original is garbage-collected,
        # which would silently return a stale compiled solver for a
        # different pulse shape.  Holding the code object in the cache
        # keeps it alive for as long as the cached program is valid.
        # Callables without ``__code__`` (e.g. ``functools.partial``) are
        # keyed on the callable object itself.
        # NOTE(review): functions sharing a code object but closing over
        # different values map to the same entry; the cached solver keeps
        # using the first function's closure — confirm callers clear the
        # cache in that situation.
        cache_key = (
            tuple(getattr(fn, "__code__", fn) for fn in coeff_fns),
            dim,
            atol,
            rtol,
            max_steps,
            throw,
            solver_name,
            magnus_steps,
        )

        with cls._evolve_solver_cache_lock:
            _solve = cls._evolve_solver_cache.get(cache_key)
        if _solve is None:
            # Built outside the lock; the builders store the result via
            # ``_store_evolve_solver`` and return whichever solver won.
            if solver_name in ("magnus2", "magnus4"):
                _solve = cls._build_magnus_evolve_solver(
                    cache_key=cache_key,
                    coeff_fns=coeff_fns,
                    n_terms=n_terms,
                    dim=dim,
                    solver_name=solver_name,
                    magnus_steps=magnus_steps,
                )
            else:
                _solve = cls._build_diffrax_evolve_solver(
                    cache_key=cache_key,
                    coeff_fns=coeff_fns,
                    n_terms=n_terms,
                    dim=dim,
                    atol=atol,
                    rtol=rtol,
                    max_steps=max_steps,
                    throw=throw,
                    solver_name=solver_name,
                    _rdtype=_rdtype,
                )

        def _apply(coeff_args, T) -> "Operation":
            """Evolve under the (multi-term) time-dependent Hamiltonian.

            Args:
                coeff_args: List/tuple of parameter sets, one per term.
                    For single-term Hamiltonians the legacy form
                    ``[params]`` works unchanged; ``params`` is forwarded
                    to the sole coefficient function.
                T: Total evolution time.  Scalar -> integrate on
                    ``[0, T]``; 1-D with at least two entries -> integrate
                    on ``[T[0], T[1]]``.

            Returns:
                An :class:`Operation` wrapping the computed unitary.

            Raises:
                ValueError: If the number of parameter sets does not match
                    the number of terms, or ``T`` has an unsupported shape.
            """
            # Normalise to a tuple of length n_terms.  Accept a bare
            # single-term arg for backward compat.
            if isinstance(coeff_args, (list, tuple)):
                params = tuple(coeff_args)
            else:
                params = (coeff_args,)

            if len(params) != n_terms:
                raise ValueError(
                    f"Expected {n_terms} parameter set(s) for a "
                    f"{n_terms}-term ParametrizedHamiltonian, "
                    f"got {len(params)}."
                )

            # Build time span — resolved at Python level (shapes are static)
            # to avoid traced branching.
            T_arr = jnp.asarray(T, dtype=_rdtype)
            if T_arr.ndim == 0:
                t0 = _rdtype(0.0)
                t1 = T_arr
            elif T_arr.ndim == 1 and T_arr.shape[0] >= 2:
                t0 = T_arr[0]
                t1 = T_arr[1]
            else:
                # Out-of-range indexing would otherwise be clamped silently.
                raise ValueError(
                    "T must be a scalar or a 1-D sequence (t0, t1); "
                    f"got shape {T_arr.shape}."
                )

            U = _solve(neg_iH_split, params, t0, t1)

            return Operation(wires=wires, matrix=U, name=name)

        return _apply
# Module-level aliases so these helpers can be imported directly from this
# module without going through the ``Yaqsi`` class.
# TODO adjust imports to use classmethods instead
partial_trace = Yaqsi.partial_trace
evolve = Yaqsi.evolve
marginalize_probs = Yaqsi.marginalize_probs
build_parity_observable = Yaqsi.build_parity_observable
734class Script:
735 """Circuit container and executor backed by pure JAX kernels.
737 ``Script`` takes a callable *f* representing a quantum circuit.
738 Within *f*, :class:`~qml_essentials.operations.Operation` objects are
739 instantiated and automatically recorded onto a tape. The tape is then
740 simulated using either a statevector or density-matrix kernel depending on
741 whether noise channels are present.
743 Attributes:
744 f: The circuit function whose body instantiates ``Operation`` objects.
745 _n_qubits: Optionally pre-declared number of qubits. When ``None``
746 the qubit count is inferred from the operations recorded on the
747 tape.
749 Example:
750 >>> def circuit(theta):
751 ... RX(theta, wires=0)
752 ... PauliZ(wires=1)
753 >>> script = Script(circuit, n_qubits=2)
754 >>> result = script.execute(type="expval", obs=[PauliZ(0)])
755 """
757 def __init__(self, f: Callable[..., None], n_qubits: Optional[int] = None) -> None:
758 """Initialise a Script.
760 Args:
761 f: A function whose body instantiates ``Operation`` objects.
762 Signature: ``f(*args, **kwargs) -> None``.
763 n_qubits: Number of qubits. If ``None``, inferred from the
764 operations recorded on the tape.
765 """
766 self.f = f
767 self._n_qubits = n_qubits
768 self._jit_cache: dict = {} # keyed on (type, in_axes, arg_shapes, gateError)
770 @staticmethod
771 def _estimate_peak_bytes(
772 n_qubits: int,
773 batch_size: int,
774 type: str,
775 use_density: bool,
776 n_obs: int = 0,
777 ) -> int:
778 """Estimate peak memory (bytes) for a batched simulation.
780 The estimate accounts for:
782 - The batched statevector (always needed, even for density).
783 - The batched output tensor (state / probs / density / expval).
784 - One gate-tensor temporary per batch element (the einsum buffer).
786 Observable matrices are **not** counted: they are computed inside
787 the JIT-compiled function and XLA manages their lifetime (reusing
788 buffers between observables). Similarly, the outer-product
789 temporary for pure-circuit density mode is transient within XLA.
791 Element size is determined dynamically from ``jax.config.x64_enabled``:
792 when x64 mode is disabled (the JAX default), complex values are
793 ``complex64`` (8 bytes) and floats are ``float32`` (4 bytes),
794 halving memory usage compared to the x64 path.
796 A 1.5× safety factor is applied to cover XLA compiler temporaries,
797 padding, and other allocations not directly visible to Python.
799 This is a pure Python arithmetic calculation with no JAX calls —
800 it adds essentially zero overhead.
802 Args:
803 n_qubits: Number of qubits in the circuit.
804 batch_size: Number of batch elements.
805 type: Measurement type (``"state"``, ``"probs"``, ``"expval"``,
806 ``"density"``).
807 use_density: Whether density-matrix simulation is used.
808 n_obs: Number of observables (relevant for ``"expval"``).
810 Returns:
811 Estimated peak memory in bytes.
812 """
813 dim = 2**n_qubits
814 # Detect actual element size: JAX silently truncates complex128
815 # to complex64 when x64 mode is disabled (the default).
816 elem = 16 if jax.config.x64_enabled else 8 # complex128 vs complex64
817 real_elem = elem // 2 # float64 vs float32
819 # Statevector: always allocated during simulation
820 sv_bytes = batch_size * dim * elem
822 # Simulation intermediate: when density-matrix simulation is used,
823 # the full rho (dim × dim) must be held during gate evolution —
824 # even if the final output is only probs or expval.
825 # apply_to_density contracts both U and U* against rho, so at least
826 # two intermediate (dim × dim) buffers are alive simultaneously.
827 if use_density:
828 sim_bytes = 2 * batch_size * dim * dim * elem
829 else:
830 sim_bytes = 0 # statevector is already counted above
832 # Output tensor: this is the *returned* array, not the simulation
833 # intermediate. For probs/expval with density simulation the
834 # density matrix is reduced to a small output *before* returning,
835 # so only the reduced output coexists with the next chunk.
836 if type == "density":
837 out_bytes = batch_size * dim * dim * elem
838 elif type == "expval":
839 out_bytes = batch_size * max(n_obs, 1) * real_elem
840 elif type == "probs":
841 out_bytes = batch_size * dim * real_elem
842 else: # state
843 out_bytes = batch_size * dim * elem
845 # Gate temporaries: einsum creates one (2,)*n buffer per batch elem
846 gate_tmp = batch_size * dim * elem
848 # Peak = max(simulation phase, output phase). During simulation
849 # the intermediate + statevector + gate temps are alive. After
850 # measurement, only the output survives. So peak is whichever
851 # phase is larger.
852 sim_peak = sv_bytes + sim_bytes + gate_tmp
853 out_peak = out_bytes
854 raw = max(sim_peak, out_peak)
856 # 1.5× safety factor for XLA compiler temporaries, padding, etc.
857 return int(raw * 1.5)
@staticmethod
def _available_memory_bytes() -> int:
    """Return available system memory in bytes.

    Tries ``psutil.virtual_memory().available`` first.  If that fails
    (e.g. psutil is not installed), the ``MemAvailable`` entry of
    ``/proc/meminfo`` is read instead; if that also fails, a
    conservative default of 4 GiB is returned.

    Returns:
        Available memory in bytes.
    """
    mem = 4 * 1024**3  # conservative default: 4 GiB
    try:
        import psutil

        mem = psutil.virtual_memory().available
    except Exception:
        # Best-effort probe: any psutil problem falls through to the
        # /proc/meminfo fallback rather than aborting the caller.
        log.debug("psutil not available. Fallback to /proc/meminfo")

        try:
            with open("/proc/meminfo", "r") as f:
                for line in f:
                    if line.startswith("MemAvailable:"):
                        mem = int(line.split()[1]) * 1024  # kB → bytes
                        break  # the entry occurs once; stop scanning
        except (OSError, ValueError, IndexError):
            log.debug("Failed to read /proc/meminfo. Falling back to 4 GiB")

    log.debug("Available memory: %.1f GB", mem / 1024**3)
    return mem
@staticmethod
def _compute_chunk_size(
    n_qubits: int,
    batch_size: int,
    type: str,
    use_density: bool,
    n_obs: int = 0,
    memory_fraction: float = 0.8,
) -> int:
    """Pick the largest per-chunk batch size that fits the memory budget.

    The budget is ``memory_fraction`` of the currently available memory.
    When the estimate for the whole batch fits, *batch_size* is returned
    and no chunking takes place.  Otherwise the size of the full output
    accumulator (the final ``(batch_size, ...)`` result array, which
    coexists with every chunk) is subtracted from the budget and the
    remainder is divided by the estimated cost of a single batch element.

    The result is clamped to the range ``[1, batch_size]``.

    Args:
        n_qubits: Number of qubits.
        batch_size: Total batch size.
        type: Measurement type.
        use_density: Whether density-matrix simulation is used.
        n_obs: Number of observables.
        memory_fraction: Fraction of available memory to target
            (default 0.8 = 80%).

    Returns:
        Chunk size (number of batch elements per sub-batch).
    """
    avail = int(Script._available_memory_bytes() * memory_fraction)
    full_est = Script._estimate_peak_bytes(
        n_qubits, batch_size, type, use_density, n_obs
    )

    # Guard clause: the whole batch fits, so no chunking is needed.
    if full_est <= avail:
        return batch_size

    # Size of the output accumulator that lives alongside every chunk.
    dim = 2**n_qubits
    complex_bytes = 16 if jax.config.x64_enabled else 8
    float_bytes = complex_bytes // 2
    if type == "density":
        per_result = dim * dim * complex_bytes
    elif type == "expval":
        per_result = max(n_obs, 1) * float_bytes
    elif type == "probs":
        per_result = dim * float_bytes
    else:
        per_result = dim * complex_bytes
    accum_bytes = batch_size * per_result

    # Never let the remaining budget drop below one element.
    avail_for_chunks = max(avail - accum_bytes, complex_bytes)

    # Estimated cost of simulating one batch element on its own.
    per_elem = Script._estimate_peak_bytes(n_qubits, 1, type, use_density, n_obs)
    if per_elem <= 0:
        return batch_size

    chunk = max(1, min(avail_for_chunks // per_elem, batch_size))

    if chunk == 1 and per_elem > avail:
        log.warning(
            f"A single batch element requires ~{per_elem / 1024**3:.2f} GB "
            f"but only ~{avail / 1024**3:.2f} GB is available. "
            f"Proceeding with chunk_size=1 but OOM is possible. "
            f"Consider reducing n_qubits or switching measurement type."
        )

    log.info(
        f"Computation requires ~{full_est / 1024**3:.2f} GB which "
        f"does not fit in ~{avail / 1024**3:.2f} GB. "
        f"Using chunk size {chunk}."
    )
    return chunk
@staticmethod
def _execute_chunked(
    batched_fn: Callable,
    args: tuple,
    in_axes: Tuple,
    batch_size: int,
    chunk_size: int,
) -> jnp.ndarray:
    """Run *batched_fn* over the batch in consecutive sub-batches.

    The batch dimension is cut into pieces of at most *chunk_size*
    elements.  Each piece is passed through *batched_fn* and its result is
    written into one output array of leading dimension *batch_size*, which
    is allocated when the first piece has been computed (its trailing
    shape and dtype are taken from that first result).

    References to the per-chunk arguments and result are dropped at the
    end of every iteration so that only the output array and the chunk
    currently being processed need to be alive.

    Args:
        batched_fn: A JIT-compiled, vmapped callable.
        args: Full-batch arguments (before slicing).
        in_axes: Per-argument batch axis specification.
        batch_size: Total number of batch elements.
        chunk_size: Maximum elements per chunk.

    Returns:
        Batched results with the same leading dimension as the
        full batch.
    """
    n_chunks = (batch_size + chunk_size - 1) // chunk_size
    log.debug(
        f"Memory-aware chunking: splitting batch of {batch_size} into "
        f"{n_chunks} chunks of <={chunk_size} elements."
    )

    def _window(value, axis, offset, length):
        """Return *value* restricted to ``[offset, offset + length)`` on *axis*."""
        if axis is None:
            # Broadcast argument: passed through unsliced.
            return value
        return jax.lax.dynamic_slice_in_dim(value, offset, length, axis=axis)

    gathered = None
    for idx in range(n_chunks):
        lo = idx * chunk_size
        hi = min(lo + chunk_size, batch_size)
        width = hi - lo

        chunk_args = tuple(
            _window(value, axis, lo, width) for value, axis in zip(args, in_axes)
        )
        chunk_result = batched_fn(*chunk_args)

        if gathered is None:
            # First chunk: now the per-element output shape is known.
            gathered = jnp.zeros(
                (batch_size, *chunk_result.shape[1:]), dtype=chunk_result.dtype
            )

        # JAX arrays are immutable, so this yields a new array; rebinding
        # the name releases the previous buffer.
        gathered = gathered.at[lo:hi].set(chunk_result)

        # Drop the per-chunk references before the next chunk is computed.
        del chunk_result, chunk_args

        # Opt-in cache clearing; off by default because it forces
        # ``batched_fn`` to be recompiled for every following chunk.
        if Yaqsi._clear_caches_between_chunks:
            jax.clear_caches()

    return gathered
def _record(self, *args, **kwargs) -> List[Operation]:
    """Call the circuit function and return the operations it recorded.

    The call happens inside a :func:`~qml_essentials.tape.recording`
    context, and the object yielded by that context manager is returned
    once the circuit function has finished.

    Args:
        *args: Positional arguments forwarded to the circuit function.
        **kwargs: Keyword arguments forwarded to the circuit function.

    Returns:
        List of :class:`~qml_essentials.operations.Operation` instances in
        the order they were instantiated.
    """
    with recording() as recorded_ops:
        self.f(*args, **kwargs)
    return recorded_ops
def pulse_events(self, *args, **kwargs) -> list:
    """Call the circuit function and return the pulse events it emitted.

    Two recording contexts are active during the call: the pulse-event
    recording (entered first, its yielded object is the return value) and
    the regular operation recording (entered second, its tape is
    discarded).

    Args:
        *args (Any): Forwarded to the circuit function.
        **kwargs (Any): Forwarded to the circuit function.

    Returns:
        List of :class:`~qml_essentials.drawing.PulseEvent`.
    """
    with pulse_recording() as emitted, recording():
        self.f(*args, **kwargs)
    return emitted
1095 @staticmethod
1096 def _infer_n_qubits(ops: List[Operation], obs: List[Operation]) -> int:
1097 """Infer the number of qubits from a list of operations and observables.
1099 Args:
1100 ops: Gate operations recorded on the tape.
1101 obs: Observable operations used for measurement.
1103 Returns:
1104 The smallest number of qubits that covers all wire indices, i.e.
1105 ``max(all_wires) + 1`` (at least 1).
1106 """
1107 all_wires: set[int] = set()
1108 for op in ops + obs:
1109 all_wires.update(op.wires)
1110 return max(all_wires) + 1 if all_wires else 1
@staticmethod
def _simulate_pure(tape: List[Operation], n_qubits: int) -> jnp.ndarray:
    """Statevector simulation kernel.

    The state starts as the first computational basis vector (amplitude 1
    at index 0) and is held as a rank-``n_qubits`` tensor of shape
    ``(2,) * n_qubits`` while the gates are applied.  It is flattened back
    to ``(2**n_qubits,)`` only at the end.

    For every non-:class:`Barrier` operation the gate tensor and the einsum
    subscript string are gathered up front, so the application loop
    consists of one ``jnp.einsum`` call per gate.

    Args:
        tape: Ordered list of gate operations to apply.
        n_qubits: Total number of qubits.

    Returns:
        Statevector of shape ``(2**n_qubits,)``.
    """
    dim = 2**n_qubits

    # (gate tensor, einsum subscripts) per gate; barriers are skipped.
    gate_plan = [
        (
            op._gate_tensor(len(op.wires)),
            _einsum_subscript(n_qubits, len(op.wires), tuple(op.wires)),
        )
        for op in tape
        if not isinstance(op, Barrier)
    ]

    flat_initial = jnp.zeros(dim, dtype=_cdtype()).at[0].set(1.0)
    psi = flat_initial.reshape((2,) * n_qubits)
    for tensor, subscripts in gate_plan:
        psi = jnp.einsum(subscripts, tensor, psi)
    return psi.reshape(dim)
@staticmethod
def _simulate_mixed(tape: List[Operation], n_qubits: int) -> jnp.ndarray:
    """Density-matrix simulation kernel.

    The density matrix starts with a single 1 in its top-left entry and is
    passed, in tape order, through each operation's
    :meth:`~qml_essentials.operations.Operation.apply_to_density`.

    Args:
        tape: Ordered list of gate or channel operations to apply.
        n_qubits: Total number of qubits.

    Returns:
        Density matrix of shape ``(2**n_qubits, 2**n_qubits)``.
    """
    dim = 2**n_qubits
    initial_rho = jnp.zeros((dim, dim), dtype=_cdtype()).at[0, 0].set(1.0)
    # Left fold over the tape: rho <- op.apply_to_density(rho, n_qubits).
    return reduce(
        lambda rho, op: op.apply_to_density(rho, n_qubits),
        tape,
        initial_rho,
    )
@staticmethod
def _simulate_and_measure(
    tape: List[Operation],
    n_qubits: int,
    type: str,
    obs: List[Operation],
    use_density: bool,
    shots: Optional[int] = None,
    key: Optional[jnp.ndarray] = None,
) -> jnp.ndarray:
    """Simulate *tape* and apply the requested measurement.

    Three simulation routes exist:

    * ``use_density`` is false: statevector simulation, measured with
      :meth:`_measure_state`.
    * ``use_density`` is true and the tape contains a
      :class:`KrausChannel`: full density-matrix evolution via
      :meth:`_simulate_mixed`.
    * ``use_density`` is true but the tape has no Kraus channel: the
      statevector is simulated and the density matrix is formed once as
      the outer product of the state with its conjugate, instead of
      evolving the full matrix gate by gate.

    When *shots* is given and *type* is ``"probs"`` or ``"expval"``, the
    exact computational-basis probabilities are handed to
    :meth:`_sample_shots` instead of being measured analytically.  For
    any other *type* the *shots* argument has no effect here.

    Args:
        tape: Ordered list of gate/channel operations to apply.
        n_qubits: Total number of qubits.
        type: Measurement type (``"state"``/``"probs"``/``"expval"``/
            ``"density"``).
        obs: Observables for ``"expval"`` measurements.
        use_density: If ``True``, use density-matrix simulation.
        shots: Number of measurement shots.  If ``None`` (default),
            exact analytic results are returned.
        key: JAX PRNG key for shot sampling.  Required when *shots*
            is not ``None``.

    Returns:
        Measurement result (shape depends on *type*).
    """
    sample = shots is not None and type in ("probs", "expval")

    if not use_density:
        psi = Script._simulate_pure(tape, n_qubits)
        if sample:
            basis_probs = jnp.abs(psi) ** 2
            return Script._sample_shots(
                basis_probs, n_qubits, type, obs, shots, key
            )
        return Script._measure_state(psi, n_qubits, type, obs)

    if any(isinstance(o, KrausChannel) for o in tape):
        # Kraus channels need the full density-matrix evolution.
        rho = Script._simulate_mixed(tape, n_qubits)
    else:
        # Noise-free tape: simulate the statevector and build the
        # density matrix from it in a single outer product.
        psi = Script._simulate_pure(tape, n_qubits)
        rho = jnp.outer(psi, jnp.conj(psi))

    if sample:
        basis_probs = jnp.real(jnp.diag(rho))
        return Script._sample_shots(basis_probs, n_qubits, type, obs, shots, key)
    return Script._measure_density(rho, n_qubits, type, obs)
1253 @staticmethod
1254 def _measure_state(
1255 state: jnp.ndarray,
1256 n_qubits: int,
1257 type: str,
1258 obs: List[Operation],
1259 ) -> jnp.ndarray:
1260 """Apply the requested measurement to a pure statevector.
1262 Args:
1263 state: Statevector of shape ``(2**n_qubits,)``.
1264 n_qubits: Total number of qubits.
1265 type: Measurement type — one of ``"state"``, ``"probs"``,
1266 or ``"expval"``.
1267 obs: Observables used when *type* is ``"expval"``.
1269 Returns:
1270 Measurement result whose shape depends on *type*:
1272 - ``"state"`` -> ``(2**n_qubits,)``
1273 - ``"probs"`` -> ``(2**n_qubits,)``
1274 - ``"expval"`` -> ``(len(obs),)``
1276 Raises:
1277 ValueError: If *type* is not a recognised measurement type.
1278 """
1279 if type == "state":
1280 return state
1282 if type == "probs":
1283 return jnp.abs(state) ** 2
1285 if type == "expval":
1286 # Fast path for single-qubit diagonal observables (PauliZ, etc.)
1287 # where d0, d1 are the diagonal elements of the 2x2 observable.
1288 # This replaces n_obs tensor contractions with a single |ψ|²
1289 # and n_obs reductions over the probability vector.
1291 def _is_single_qubit_diag(ob):
1292 m = ob.__class__._matrix
1293 if m is None or len(ob.wires) != 1:
1294 return False
1295 # Convert to NumPy to ensure concrete boolean evaluation
1296 m_np = np.asarray(m)
1297 return np.allclose(m_np - np.diag(np.diag(m_np)), 0)
1299 all_single_qubit_diag = all(_is_single_qubit_diag(ob) for ob in obs)
1301 if all_single_qubit_diag:
1302 probs = jnp.abs(state) ** 2
1303 psi_t = probs.reshape((2,) * n_qubits)
1304 results = []
1305 for ob in obs:
1306 q = ob.wires[0]
1307 d = np.real(np.diag(np.asarray(ob.__class__._matrix)))
1308 # Sum probabilities over all axes except qubit q
1309 p_q = jnp.sum(
1310 psi_t, axis=tuple(i for i in range(n_qubits) if i != q)
1311 )
1312 results.append(d[0] * p_q[0] + d[1] * p_q[1])
1313 return jnp.array(results)
1315 # General path: stack observable matrices and use a single
1316 # batched matmul instead of a Python loop of tensor contractions.
1317 # O_states[i] = obs[i] |ψ⟩, then ⟨O_i⟩ = Re(⟨ψ|O_states[i]⟩).
1318 obs_mats = jnp.stack(
1319 [ob.lifted_matrix(n_qubits) for ob in obs], axis=0
1320 ) # (n_obs, dim, dim)
1321 # Batched matvec: (n_obs, dim, dim) @ (dim,) -> (n_obs, dim)
1322 O_states = jnp.einsum("oij,j->oi", obs_mats, state)
1323 return jnp.real(jnp.einsum("i,oi->o", jnp.conj(state), O_states))
1325 raise ValueError(f"Unknown measurement type: {type!r}")
1327 @staticmethod
1328 def _measure_density(
1329 rho: jnp.ndarray,
1330 n_qubits: int,
1331 type: str,
1332 obs: List[Operation],
1333 ) -> jnp.ndarray:
1334 """Apply the requested measurement to a density matrix.
1336 Args:
1337 rho: Density matrix of shape ``(2**n_qubits, 2**n_qubits)``.
1338 n_qubits: Total number of qubits.
1339 type: Measurement type — one of ``"density"``, ``"probs"``,
1340 or ``"expval"``.
1341 obs: Observables used when *type* is ``"expval"``.
1343 Returns:
1344 Measurement result whose shape depends on *type*:
1346 - ``"density"`` -> ``(2**n_qubits, 2**n_qubits)``
1347 - ``"probs"`` -> ``(2**n_qubits,)``
1348 - ``"expval"`` -> ``(len(obs),)``
1350 Raises:
1351 ValueError: If *type* is ``"state"`` (not valid for mixed circuits)
1352 or another unrecognised type.
1353 """
1354 if type == "density":
1355 return rho
1357 if type == "probs":
1358 return jnp.real(jnp.diag(rho))
1360 if type == "expval":
1361 # Tr(O \\rho ) = \\Sigma_ij O_ij \\rho _ji
1362 # Stack all observable matrices and compute all traces in one
1363 # batched operation.
1364 obs_mats = jnp.stack(
1365 [ob.lifted_matrix(n_qubits) for ob in obs], axis=0
1366 ) # (n_obs, dim, dim)
1367 # einsum "oij,ji->o" computes Tr(O_o @ \\rho ) for each observable
1368 return jnp.real(jnp.einsum("oij,ji->o", obs_mats, rho))
1370 raise ValueError(
1371 "Measurement type 'state' is not defined for mixed (noisy) circuits. "
1372 "Use 'density' instead."
1373 )
1375 @staticmethod
1376 def _sample_shots(
1377 probs: jnp.ndarray,
1378 n_qubits: int,
1379 type: str,
1380 obs: List[Operation],
1381 shots: int,
1382 key: jnp.ndarray,
1383 ) -> jnp.ndarray:
1384 """Convert exact probabilities into shot-sampled results.
1386 Draws *shots* samples from the computational-basis probability
1387 distribution and returns either estimated probabilities or
1388 shot-based expectation values.
1390 Args:
1391 probs: Exact probability vector of shape ``(2**n_qubits,)``.
1392 n_qubits: Total number of qubits.
1393 type: Measurement type — ``"probs"`` or ``"expval"``.
1394 obs: Observables used when *type* is ``"expval"``.
1395 shots: Number of measurement shots.
1396 key: JAX PRNG key for sampling.
1398 Returns:
1399 Shot-sampled measurement result:
1401 - ``"probs"`` → ``(2**n_qubits,)`` estimated probabilities.
1402 - ``"expval"`` → ``(len(obs),)`` estimated expectation values.
1403 """
1404 dim = 2**n_qubits
1406 # Draw `shots` samples from the computational basis.
1407 # Each sample is an integer in [0, dim) representing a basis state.
1408 samples = jax.random.choice(key, dim, shape=(shots,), p=probs)
1410 # Build a histogram of counts for each basis state.
1411 counts = jnp.zeros(dim, dtype=jnp.int32)
1412 counts = counts.at[samples].add(1)
1413 estimated_probs = counts / shots
1415 if type == "probs":
1416 return estimated_probs
1418 if type == "expval":
1419 # For each observable, compute O from the shot-sampled
1420 # probabilities. For diagonal observables this is exact;
1421 # for general observables we use Tr(O · diag(estimated_probs)).
1422 results = []
1423 for ob in obs:
1424 O_mat = ob.lifted_matrix(n_qubits)
1425 # diagonal approximation from
1426 # computational basis measurements, which is exact for
1427 # diagonal observables like PauliZ)
1428 results.append(jnp.real(jnp.dot(jnp.diag(O_mat), estimated_probs)))
1429 return jnp.array(results)
1431 raise ValueError(
1432 f"Shot simulation is only supported for 'probs' and 'expval', got {type!r}."
1433 )
def execute(
    self,
    type: str = "expval",
    obs: Optional[List[Operation]] = None,
    *,
    args: tuple = (),
    kwargs: Optional[dict] = None,
    in_axes: Optional[Tuple] = None,
    shots: Optional[int] = None,
    key: Optional[jnp.ndarray] = None,
) -> jnp.ndarray:
    """Execute the circuit and return measurement results.

    Args:
        type: Measurement type.  One of:

            - ``"expval"`` — expectation value of each observable in
              *obs*.
            - ``"probs"`` — probability vector of shape ``(2**n,)``.
            - ``"state"`` — raw statevector of shape ``(2**n,)``.
            - ``"density"`` — full density matrix of shape
              ``(2**n, 2**n)``.

        obs: Observables required when type is ``"expval"``.
        args: Positional arguments forwarded to the circuit function f.
        kwargs: Keyword arguments forwarded to f.
        in_axes: Batch axes for each element of *args*, following the same
            convention as ``jax.vmap``:

            - An integer selects that axis of the corresponding array as
              the batch dimension.
            - ``None`` broadcasts the argument (no batching).

            When provided, the call is delegated to
            :meth:`_execute_batched` and the result carries a leading
            batch dimension.
        shots: Number of measurement shots for stochastic sampling.
            If ``None`` (default), exact analytic results are returned.
            Only supported for ``"probs"`` and ``"expval"`` measurement
            types.
        key: JAX PRNG key for shot sampling.  If ``None`` and *shots*
            is set, a default key ``jax.random.PRNGKey(0)`` is used.

    Returns:
        Without in_axes: shape determined by type.
        With in_axes: shape ``(B, ...)`` with a leading batch dimension.
    """
    obs = [] if obs is None else obs
    kwargs = {} if kwargs is None else kwargs
    if key is None and shots is not None:
        key = jax.random.PRNGKey(0)

    # TODO: we might want to unify the n_qubit stuff such that we can eliminate
    # the parameter to this method entirely
    if in_axes is None:
        # Single (unbatched) execution.
        recorded = self._record(*args, **kwargs)
        n_qubits = self._n_qubits or self._infer_n_qubits(recorded, obs)

        # Density-matrix simulation is needed for a density output or as
        # soon as the tape contains a Kraus channel.
        noisy = any(isinstance(op, KrausChannel) for op in recorded)
        use_density = type == "density" or noisy

        return self._simulate_and_measure(
            recorded,
            n_qubits,
            type,
            obs,
            use_density,
            shots=shots,
            key=key,
        )

    # Batched execution.
    return self._execute_batched(
        type=type,
        obs=obs,
        args=args,
        kwargs=kwargs,
        in_axes=in_axes,
        shots=shots,
        key=key,
    )
def _execute_batched(
    self,
    type: str,
    obs: List[Operation],
    args: tuple,
    kwargs: dict,
    in_axes: Tuple,
    shots: Optional[int] = None,
    key: Optional[jnp.ndarray] = None,
) -> jnp.ndarray:
    """Vectorise :meth:`execute` over a batch axis using ``jax.vmap``.

    The circuit function is recorded once in Python with the first slice
    of every batched argument to determine ``n_qubits`` and detect noise
    channels.  The simulation kernel is then vmapped over the requested
    axes and wrapped in ``eqx.filter_jit``.

    Memory-aware chunking — before running, :meth:`_compute_chunk_size`
    decides whether the whole batch is processed at once.  If not, the
    batch is handed to :meth:`_execute_chunked`, which processes it in
    sub-batches of that size.

    Caching — the compiled function is stored in ``self._jit_cache``
    together with ``n_qubits`` and ``use_density``.  The key consists of
    the measurement type, ``in_axes``, the shape/dtype (or class) of each
    positional argument, the non-array keyword arguments and
    ``UnitaryGates.batch_gate_error``.  Shot mode uses a separate key that
    additionally contains the number of shots.  The chunk size chosen for
    a given batch size is cached as well, so repeated analytic calls skip
    the memory query.

    Args:
        type: Measurement type (see :meth:`execute`).
        obs: Observables (see :meth:`execute`).
        args: Positional arguments for the circuit function.
        kwargs: Keyword arguments for the circuit function.
        in_axes: One entry per element of *args*.  Follows ``jax.vmap``
            convention: an int gives the batch axis; ``None`` broadcasts.
        shots: Number of measurement shots.  If ``None``, exact results.
        key: JAX PRNG key for shot sampling.

    Returns:
        Batched measurement results of shape ``(B, ...)`` where *B* is the
        size of the batch dimension.

    Raises:
        ValueError: If ``len(in_axes) != len(args)``.

    Note:
        The ``jax.vmap`` call at the end of this method is the exact
        boundary to replace with ``jax.shard_map`` for multi-device
        execution::

            from jax.sharding import PartitionSpec as P, Mesh
            result = jax.shard_map(
                _single_execute, mesh=mesh,
                in_specs=tuple(P(0) if ax is not None else P() for ax in in_axes),
                out_specs=P(0),
            )(*args)
    """
    if len(in_axes) != len(args):
        raise ValueError(
            f"in_axes has {len(in_axes)} entries but args has {len(args)}. "
            "Provide one in_axes entry per positional argument."
        )

    # ``in_axes`` becomes part of a dict key and is concatenated with a
    # tuple in shot mode; normalise it so a list works as well.
    in_axes = tuple(in_axes)

    # Determine batch size from the first batched arg
    batch_size = 1
    for a, ax in zip(args, in_axes):
        if ax is not None:
            batch_size = a.shape[ax]
            break

    # NOTE: the parameter ``type`` shadows the builtin inside this method,
    # so the class of a shapeless argument is read via ``__class__``
    # (calling ``type(a)`` here would call the measurement-type string).
    arg_shapes = tuple(
        (a.shape, a.dtype) if hasattr(a, "shape") else a.__class__ for a in args
    )
    cache_kwargs = _make_hashable(
        {k: v for k, v in kwargs.items() if not isinstance(v, jnp.ndarray)}
    )

    # TODO: we need to fix the dirty class-level `batch_gate_error` hack
    from qml_essentials.gates import UnitaryGates

    cache_key = (
        type,
        in_axes,
        arg_shapes,
        cache_kwargs,
        UnitaryGates.batch_gate_error,
    )

    # --- Cache-hit fast path (no shots) ---
    # NOTE(review): bypassing this cache when *args* contain JAX tracers
    # (call under an outer transform) was considered but is currently
    # disabled; the cached function is reused unconditionally.
    cached = self._jit_cache.get(cache_key)
    if cached is not None and shots is None:
        batched_fn, n_qubits, use_density = cached
        # Check if we already determined the chunk size for this
        # exact batch_size (avoids repeated memory queries).
        mem_key = ("_mem", cache_key, batch_size)
        cached_chunk = self._jit_cache.get(mem_key)
        if cached_chunk is not None:
            if cached_chunk >= batch_size:
                return batched_fn(*args)
            return self._execute_chunked(
                batched_fn, args, in_axes, batch_size, cached_chunk
            )
        chunk_size = self._compute_chunk_size(
            n_qubits, batch_size, type, use_density, len(obs)
        )
        self._jit_cache[mem_key] = chunk_size
        if chunk_size >= batch_size:
            return batched_fn(*args)
        return self._execute_chunked(
            batched_fn, args, in_axes, batch_size, chunk_size
        )

    # Record the tape once using the first slice of each batched arg.
    # This determines n_qubits and whether noise channels are present
    # without running the full batch.
    # Note that we use lax.index_in_dim instead of jnp.take because JAX
    # random key arrays do not support jnp.take.
    # TODO: fix once that is available in JAX
    def _slice_first(a, ax):
        """Take the first element along axis *ax*."""
        return jax.lax.index_in_dim(a, 0, axis=ax, keepdims=False)

    scalar_args = tuple(
        _slice_first(a, ax) if ax is not None else a for a, ax in zip(args, in_axes)
    )
    tape = self._record(*scalar_args, **kwargs)
    n_qubits = self._n_qubits or self._infer_n_qubits(tape, obs)
    has_noise = any(isinstance(op, KrausChannel) for op in tape)
    use_density = type == "density" or has_noise

    chunk_size = self._compute_chunk_size(
        n_qubits, batch_size, type, use_density, len(obs)
    )

    # Re-recording inside the per-element functions below is necessary:
    # the tape may contain operations whose matrices depend on the
    # batched argument (e.g. RX(theta) where theta is a JAX tracer).
    if shots is not None and type in ("probs", "expval"):
        # Shot mode: compute exact probabilities, then sample.
        # The shot key is appended as an extra vmapped argument.
        def _single_execute_shots(*single_args_and_key):
            *single_args, shot_key = single_args_and_key
            single_tape = self._record(*single_args, **kwargs)
            exact_result = self._simulate_and_measure(
                single_tape, n_qubits, "probs", obs, use_density
            )
            return Script._sample_shots(
                exact_result, n_qubits, type, obs, shots, shot_key
            )

        shot_keys = jax.random.split(key, batch_size)
        shot_in_axes = in_axes + (0,)  # key is batched over axis 0
        shot_args = args + (shot_keys,)

        # Shot mode uses a separate cache key that includes the shot
        # count.  The non-array kwargs are part of it, exactly as in the
        # analytic key, because the compiled function closes over them.
        shot_cache_key = (
            type,
            "shots",
            shots,
            in_axes,
            arg_shapes,
            cache_kwargs,
            UnitaryGates.batch_gate_error,
        )
        cached_shot = self._jit_cache.get(shot_cache_key)
        if cached_shot is not None:
            batched_fn = cached_shot[0]
        else:
            batched_fn = eqx.filter_jit(
                jax.vmap(_single_execute_shots, in_axes=shot_in_axes)
            )
            self._jit_cache[shot_cache_key] = (batched_fn, n_qubits, use_density)

        if chunk_size >= batch_size:
            return batched_fn(*shot_args)
        return self._execute_chunked(
            batched_fn, shot_args, shot_in_axes, batch_size, chunk_size
        )

    def _single_execute(*single_args):
        single_tape = self._record(*single_args, **kwargs)
        return self._simulate_and_measure(
            single_tape, n_qubits, type, obs, use_density
        )

    # Wrapping the vmapped function in eqx.filter_jit compiles the whole
    # batch into one program, and lets later calls with the same input
    # structure reuse it via the cache check at the top of this method.
    # NOTE: ``obs``, ``kwargs`` and other model properties are captured by
    # the closure but only partly reflected in the cache key, so altering
    # them between calls might not trigger a re-compile.
    # TODO: we might want to rework the data_reupload mechanism at some point
    batched_fn = eqx.filter_jit(jax.vmap(_single_execute, in_axes=in_axes))
    # Cache the function together with metadata for fast-path memory
    # checks on subsequent calls.
    self._jit_cache[cache_key] = (batched_fn, n_qubits, use_density)

    if chunk_size >= batch_size:
        return batched_fn(*args)
    return self._execute_chunked(batched_fn, args, in_axes, batch_size, chunk_size)
def draw(
    self,
    figure: str = "text",
    args: tuple = (),
    kwargs: Optional[dict] = None,
    **draw_kwargs: Any,
) -> Union[str, Any]:
    """Draw the quantum circuit.

    The circuit function is called with *args* / *kwargs* to record what
    it does, and the recording is passed to the selected rendering
    backend.

    Args:
        figure: Rendering backend.  One of:

            - ``"text"`` — ASCII art (returned as a ``str``).
            - ``"mpl"`` — Matplotlib figure (returns ``(fig, ax)``).
            - ``"tikz"`` — LaTeX/TikZ code via ``quantikz``
              (returns a :class:`TikzFigure`).
            - ``"pulse"`` — Pulse schedule plot (returns ``(fig, axes)``).

        args: Positional arguments forwarded to the circuit function
            to record the tape.
        kwargs: Keyword arguments forwarded to the circuit function.
        **draw_kwargs: Extra options forwarded to the rendering backend
            (not used by the ``"text"`` backend):

            - ``gate_values`` (bool): Show numeric gate angles instead of
              symbolic theta_i labels.  Default ``False``.
            - ``show_carrier`` (bool): For ``"pulse"`` mode, overlay the
              carrier-modulated waveform.  Default ``False``.

    Returns:
        Depends on *figure*:

        - ``"text"`` -> ``str``
        - ``"mpl"`` -> ``(matplotlib.figure.Figure, matplotlib.axes.Axes)``
        - ``"tikz"`` -> :class:`TikzFigure`
        - ``"pulse"`` -> ``(matplotlib.figure.Figure, numpy.ndarray)``

    Raises:
        ValueError: If *figure* is not one of the supported modes.
    """
    if figure not in ("text", "mpl", "tikz", "pulse"):
        raise ValueError(
            f"Invalid figure mode: {figure!r}. "
            "Must be 'text', 'mpl', 'tikz', or 'pulse'."
        )

    call_kwargs = {} if kwargs is None else kwargs

    if figure == "pulse":
        from qml_essentials.drawing import draw_pulse_schedule

        schedule = self.pulse_events(*args, **call_kwargs)
        # Without a fixed qubit count, use the highest wire seen in the
        # events (a schedule without events counts as one qubit).
        n_qubits = (
            self._n_qubits
            or max((w for ev in schedule for w in ev.wires), default=0) + 1
        )
        return draw_pulse_schedule(schedule, n_qubits, **draw_kwargs)

    recorded = self._record(*args, **call_kwargs)
    n_qubits = self._n_qubits or self._infer_n_qubits(recorded, [])

    # Noise channels are left out of the drawing — they clutter the diagram.
    gates = [op for op in recorded if not isinstance(op, KrausChannel)]

    if figure == "text":
        return draw_text(gates, n_qubits)

    renderer = draw_mpl if figure == "mpl" else draw_tikz
    return renderer(gates, n_qubits, **draw_kwargs)