Coverage for qml_essentials / evolution.py: 94%
170 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-06-11 15:51 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-06-11 15:51 +0000
"""Hamiltonian time-evolution machinery for pulse/gate construction.

This module houses the :class:`Evolution` class, which turns a (static or
time-dependent) Hamiltonian into a gate factory by solving the Schrödinger
equation ``dU/dt = -i H(t) U``. It is the pulse/gate-dependent counterpart to
the otherwise pulse-agnostic :mod:`qml_essentials.jaqsi` entry point.

For a static Hamiltonian the unitary is a single matrix exponential
``U = exp(-i t H)``. For a time-dependent (``ParametrizedHamiltonian``) one,
the ODE is integrated numerically, either with an adaptive Dormand-Prince
solver from ``diffrax`` (``"dopri8"`` / ``"dopri5"``) or with a fixed-step
commutator-free Magnus integrator (``"magnus2"`` / ``"magnus4"``). Compiled
solvers are cached per coefficient-function code so that gates sharing a
pulse shape reuse one XLA program.

The engine is normally reached through the :meth:`evolve` method on the
Hamiltonian object (``Hermitian`` / ``ParametrizedHamiltonian``), which delegates
to :meth:`Evolution.evolve`. :class:`Evolution` is also where solver defaults
live (:meth:`Evolution.set_solver_defaults`).
"""
14from typing import Any, Callable, List, Optional, Tuple, Union
15import math
16import threading
18import diffrax
19import jax
20import jax.numpy as jnp
21import jax.scipy.linalg
22import equinox as eqx
24from qml_essentials.operations import (
25 Hermitian,
26 ParametrizedHamiltonian,
27 Operation,
28)
class Evolution:
    """Build gate factories that evolve a Hamiltonian in time.

    All functionality is exposed through class methods. The class itself
    only carries shared state: the compiled-solver cache and the global
    solver defaults. The usual entry point is :meth:`evolve`.
    """

    # Class-level cache for JIT-compiled ODE solvers, shared by every
    # ``evolve()`` call in the process (see :meth:`_evolve_cache_key` for
    # the exact key). All calls with the same pulse-shape function code,
    # matrix size, precision mode and solver configuration share one
    # compiled XLA program. This turns O(n_gates) JIT compilations into
    # O(n_distinct_pulse_shapes) during pulse-mode circuit building.
    #
    # NOTE: the key identifies a coefficient function by its ``__code__``
    # only. Two closures created from the same ``def``/``lambda`` but
    # capturing different values therefore share one entry, and the cached
    # solver keeps calling the functions it was first built with. Call
    # :meth:`clear_evolve_solver_cache` whenever such closures are rebuilt.
    _evolve_solver_cache: dict = {}
    _evolve_solver_cache_lock = threading.Lock()

    # Default solver knobs for parametrized (time-dependent) evolution.
    # These can be overridden per call via the ``**odeint_kwargs`` of
    # ``evolve()`` or globally via :meth:`set_solver_defaults`.
    #
    # ``max_steps`` is the hard cap on accepted steps of the adaptive
    # (diffrax) solvers. Pulse-level workloads at on-resonance carriers
    # (ω_c ≈ ω_q) require many more steps than the diffrax default during
    # JIT — 2**13 = 8192 is large enough for realistic single- and
    # two-qubit pulses while remaining cheap to compile.
    #
    # ``throw`` controls whether diffrax raises on solver failure
    # (e.g. ``MaxStepsReached``). When set to ``False`` the gate factory
    # instead returns a NaN-filled unitary so the calling optimiser sees a
    # well-defined (but useless) result and can gracefully reject the
    # candidate. The fixed-step Magnus solvers have no failure mode of
    # this kind and ignore the flag.
    #
    # ``solver`` selects the time-integration backend for the
    # interaction-picture ODE ``dU/dt = -i H_I(t) U``:
    #
    # * ``"dopri8"`` (default) — adaptive Dormand-Prince 8(7) via
    #   diffrax. Robust but expensive on highly oscillatory drives
    #   because the step controller resolves every fast cycle.
    # * ``"dopri5"`` — adaptive Dormand-Prince 5(4) via diffrax. Cheaper
    #   per step than ``dopri8`` but lower order, so it typically needs
    #   more steps at tight tolerances.
    # * ``"magnus2"`` — commutator-free Magnus, 2nd order (midpoint
    #   rule) on a fixed ``magnus_steps`` grid via ``jax.lax.scan``.
    #   One ``expm`` per step. Preserves unitarity to machine
    #   precision and fuses into a single XLA program.
    # * ``"magnus4"`` — commutator-free Magnus, 4th order (CFM4:2 of
    #   Blanes & Moan) on a fixed ``magnus_steps`` grid. Two ``H``
    #   evaluations and two ``expm`` per step; typically the best
    #   accuracy/cost trade-off for smooth oscillatory pulse drives.
    #
    # ``magnus_steps`` is the number of fixed substeps for the Magnus
    # integrators (ignored by ``dopri8``/``dopri5``). Choose it so that
    # ``h = T/N`` resolves the fastest oscillation in ``H(t)`` (~few steps
    # per period of the highest frequency).
    _solver_defaults: dict = {
        "max_steps": 2**13,
        "throw": True,
        "solver": "dopri8",
        "magnus_steps": 256,
    }
    _valid_solvers = ("dopri8", "dopri5", "magnus2", "magnus4")
    # Subset of ``_valid_solvers`` handled by the fixed-step Magnus builder.
    _magnus_solvers = ("magnus2", "magnus4")

    @classmethod
    def _check_solver_name(cls, solver: str) -> str:
        """Return *solver* unchanged if it names a supported backend.

        Raises:
            ValueError: If *solver* is not one of :attr:`_valid_solvers`.
        """
        if solver not in cls._valid_solvers:
            raise ValueError(
                f"Unknown solver {solver!r}; expected one of {cls._valid_solvers}"
            )
        return solver

    @staticmethod
    def _check_positive_int(name: str, value: Any) -> int:
        """Coerce *value* to ``int`` and require it to be at least 1.

        Args:
            name: Option name used in the error message.
            value: Value to coerce.

        Returns:
            ``int(value)``.

        Raises:
            ValueError: If the coerced value is smaller than 1 (e.g.
                ``magnus_steps=0`` would otherwise divide by zero and
                silently return the identity).
        """
        value = int(value)
        if value < 1:
            raise ValueError(f"{name} must be a positive integer, got {value}")
        return value

    @staticmethod
    def _warn_ignored_kwargs(kwargs: dict, context: str) -> None:
        """Emit a ``UserWarning`` if *kwargs* still holds unused keys.

        Args:
            kwargs: Keyword arguments left over after all recognised keys
                were consumed.
            context: Short description of the caller for the message.
        """
        if kwargs:
            import warnings

            warnings.warn(
                f"{context} ignores unrecognised keyword argument(s): "
                f"{sorted(kwargs)}",
                UserWarning,
                stacklevel=3,
            )

    @classmethod
    def set_solver_defaults(
        cls,
        max_steps: Optional[int] = None,
        throw: Optional[bool] = None,
        solver: Optional[str] = None,
        magnus_steps: Optional[int] = None,
    ) -> dict:
        """Update class-level solver defaults; return the previous values.

        The returned dictionary is suitable for restoring the previous
        defaults via ``set_solver_defaults(**prev)``. All arguments are
        validated before anything is changed, so a rejected call leaves the
        defaults untouched.

        Args:
            max_steps: New default for ``max_steps`` (ignored if ``None``).
            throw: New default for ``throw`` (ignored if ``None``).
            solver: New default solver name, one of :attr:`_valid_solvers`
                (ignored if ``None``).
            magnus_steps: New default number of fixed Magnus substeps
                (ignored if ``None``).

        Returns:
            Dictionary with the previous values of the updated keys.

        Raises:
            ValueError: If *solver* is unknown or *max_steps* /
                *magnus_steps* is smaller than 1.
        """
        updates: dict = {}
        if max_steps is not None:
            updates["max_steps"] = cls._check_positive_int("max_steps", max_steps)
        if throw is not None:
            updates["throw"] = bool(throw)
        if solver is not None:
            updates["solver"] = cls._check_solver_name(solver)
        if magnus_steps is not None:
            updates["magnus_steps"] = cls._check_positive_int(
                "magnus_steps", magnus_steps
            )

        # Snapshot first, then apply in one go (validation already passed).
        prev = {key: cls._solver_defaults[key] for key in updates}
        cls._solver_defaults.update(updates)
        return prev

    @classmethod
    def _store_evolve_solver(cls, cache_key: tuple, solve: Callable) -> Callable:
        """Cache a compiled evolve solver unless another thread won the race.

        Returns:
            The solver now stored under *cache_key*: *solve* itself, or the
            one another thread inserted first.
        """
        with cls._evolve_solver_cache_lock:
            existing = cls._evolve_solver_cache.get(cache_key)
            if existing is not None:
                return existing
            cls._evolve_solver_cache[cache_key] = solve
            return solve

    @classmethod
    def clear_evolve_solver_cache(cls) -> None:
        """Drop every cached compiled evolve solver.

        Call this whenever the coefficient functions referenced by the
        cache keys are rebuilt (e.g. when :class:`PulseGates` swaps in
        a new pulse envelope, RWA flag or frame). Without an explicit
        eviction the cache keeps the old code objects alive and would
        also retain XLA programs that no longer match any active
        coefficient function.
        """
        with cls._evolve_solver_cache_lock:
            cls._evolve_solver_cache.clear()

    @classmethod
    def _parse_evolve_solver_options(cls, odeint_kwargs: dict) -> tuple:
        """Pop and validate solver options from ``evolve(..., **odeint_kwargs)``.

        Recognised keys are popped from *odeint_kwargs* (the dict is
        mutated). Any key left over afterwards triggers a ``UserWarning``.

        Args:
            odeint_kwargs: Keyword arguments passed to ``evolve()``.

        Returns:
            ``(atol, rtol, max_steps, throw, solver_name, magnus_steps)``.

        Raises:
            ValueError: If the solver name is unknown or a step count is
                smaller than 1.
        """
        default_tol = 1.0e-10 if jax.config.x64_enabled else 1.4e-8
        atol = odeint_kwargs.pop("atol", default_tol)
        rtol = odeint_kwargs.pop("rtol", default_tol)
        max_steps = cls._check_positive_int(
            "max_steps",
            odeint_kwargs.pop("max_steps", cls._solver_defaults["max_steps"]),
        )
        throw = bool(odeint_kwargs.pop("throw", cls._solver_defaults["throw"]))
        solver_name = cls._check_solver_name(
            str(odeint_kwargs.pop("solver", cls._solver_defaults["solver"]))
        )
        magnus_steps = cls._check_positive_int(
            "magnus_steps",
            odeint_kwargs.pop("magnus_steps", cls._solver_defaults["magnus_steps"]),
        )
        cls._warn_ignored_kwargs(
            odeint_kwargs, "evolve() of a ParametrizedHamiltonian"
        )
        return atol, rtol, max_steps, throw, solver_name, magnus_steps

    @classmethod
    def _evolve_cache_key(
        cls,
        coeff_fns: Tuple[Callable, ...],
        dim: int,
        atol: Any,
        rtol: Any,
        max_steps: int,
        throw: bool,
        solver_name: str,
        magnus_steps: int,
        x64: bool,
    ) -> tuple:
        """Return the compiled-solver cache key for one ``evolve()`` call.

        The key holds every coefficient function's code object (same pulse
        shape -> same JIT program), ``dim``, the x64 mode (the builders bake
        the working dtypes into the program) and only those solver knobs
        that actually change the compiled program. Magnus solvers ignore
        ``atol``/``rtol``/``max_steps``/``throw`` and diffrax solvers ignore
        ``magnus_steps``, so those entries are normalised to ``None``.

        The code objects themselves are used rather than
        ``id(fn.__code__)``: ids can be reused after garbage collection,
        which would silently return a stale solver for a different pulse
        shape. Holding the code objects also keeps them alive for as long
        as the cached program is valid.
        """
        if solver_name in cls._magnus_solvers:
            atol = rtol = max_steps = throw = None
        else:
            magnus_steps = None
        return (
            tuple(fn.__code__ for fn in coeff_fns),
            dim,
            bool(x64),
            atol,
            rtol,
            max_steps,
            throw,
            solver_name,
            magnus_steps,
        )

    @classmethod
    def _build_magnus_evolve_solver(
        cls,
        cache_key: tuple,
        coeff_fns: Tuple[Callable, ...],
        n_terms: int,
        dim: int,
        solver_name: str,
        magnus_steps: int,
    ) -> Callable:
        """Build and cache a fixed-step commutator-free Magnus solver.

        ``"magnus2"`` uses the midpoint rule (one ``expm`` per step); any
        other name selects the fourth-order two-exponential scheme (two
        ``H`` evaluations and two ``expm`` per step). The returned callable
        has the signature ``_solve(neg_iH_split, params, t0, t1) -> U``.
        """
        _coeff_fns = coeff_fns
        _cdtype_local = jnp.complex128 if jax.config.x64_enabled else jnp.complex64
        n_steps = magnus_steps
        solver_name_local = solver_name

        @eqx.filter_jit
        def _solve(neg_iH_split, params, t0, t1):
            # Reconstruct the per-term complex matrices ``-i H_i`` from their
            # split (Re, Im) representation so the coefficient sum is a single
            # complex tensordot.
            A_all = neg_iH_split[:, 0]
            B_all = neg_iH_split[:, 1]
            neg_iH = (A_all + 1j * B_all).astype(_cdtype_local)

            h = (t1 - t0) / n_steps

            def H_at(t):
                # Generator -i H(t) = sum_i f_i(p_i, t) * (-i H_i).
                c = jnp.stack(
                    [
                        jnp.asarray(_coeff_fns[i](params[i], t)).reshape(())
                        for i in range(n_terms)
                    ]
                ).astype(_cdtype_local)
                return jnp.tensordot(c, neg_iH, axes=1)

            if solver_name_local == "magnus2":

                def step(U, n):
                    tn = t0 + n * h
                    Omega = h * H_at(tn + 0.5 * h)
                    return jax.scipy.linalg.expm(Omega) @ U, None

            else:
                sqrt3 = math.sqrt(3.0)
                c1 = 0.5 - sqrt3 / 6.0  # Gauss-Legendre nodes on [0, 1]
                c2 = 0.5 + sqrt3 / 6.0
                a1 = 0.25 + sqrt3 / 6.0
                a2 = 0.25 - sqrt3 / 6.0

                def step(U, n):
                    tn = t0 + n * h
                    H1 = H_at(tn + c1 * h)
                    H2 = H_at(tn + c2 * h)
                    Omega_a = h * (a1 * H1 + a2 * H2)
                    Omega_b = h * (a2 * H1 + a1 * H2)
                    # CFM4:2 ordering (Blanes & Moan 2006, Table II):
                    # U_{n+1} = exp(Ω_b) · exp(Ω_a) · U_n.
                    U_next = (
                        jax.scipy.linalg.expm(Omega_b)
                        @ jax.scipy.linalg.expm(Omega_a)
                        @ U
                    )
                    return U_next, None

            U0 = jnp.eye(dim, dtype=_cdtype_local)
            U_final, _ = jax.lax.scan(step, U0, jnp.arange(n_steps))
            return U_final

        return cls._store_evolve_solver(cache_key, _solve)

    @classmethod
    def _build_diffrax_evolve_solver(
        cls,
        cache_key: tuple,
        coeff_fns: Tuple[Callable, ...],
        n_terms: int,
        dim: int,
        atol: float,
        rtol: float,
        max_steps: int,
        throw: bool,
        solver_name: str,
        _rdtype,
    ) -> Callable:
        """Build and cache an adaptive diffrax-based evolve solver.

        ``solver_name == "dopri8"`` selects ``diffrax.Dopri8``; any other
        name selects ``diffrax.Dopri5``. Step sizes are chosen by a
        ``PIDController`` with the given tolerances.
        """
        solver = diffrax.Dopri8() if solver_name == "dopri8" else diffrax.Dopri5()
        stepsize_controller = diffrax.PIDController(atol=atol, rtol=rtol)
        _coeff_fns = coeff_fns

        @eqx.filter_jit
        def _solve(neg_iH_split, params, t0, t1):
            """Solve dU/dt = sum_i f_i(p_i, t) * (-iH_i) * U from t0 to t1.

            ``neg_iH_split`` has shape ``(n_terms, 2, dim, dim)`` with
            ``[:, 0]`` = Re(-iH_i) and ``[:, 1]`` = Im(-iH_i).
            ``params`` is a list/tuple of length ``n_terms`` carrying
            each term's coefficient parameters. The state ``y`` has
            shape ``(2, dim, dim)`` with ``y[0] = Re(U)`` and
            ``y[1] = Im(U)``.
            """
            # TODO: switch to complex arithmetic once diffrax's complex
            # dtype support is stable.
            A_all = neg_iH_split[:, 0]
            B_all = neg_iH_split[:, 1]

            def rhs(t, y, args):
                # Each coefficient function must return a scalar value; some
                # call sites pass a shape-(1,) param array, so coerce to a
                # true scalar before stacking.
                c = jnp.stack(
                    [
                        jnp.asarray(_coeff_fns[i](args[i], t)).reshape(())
                        for i in range(n_terms)
                    ]
                )
                u_re = y[0]
                u_im = y[1]
                A_eff = jnp.tensordot(c, A_all, axes=1)
                B_eff = jnp.tensordot(c, B_all, axes=1)
                du_re = A_eff @ u_re - B_eff @ u_im
                du_im = A_eff @ u_im + B_eff @ u_re
                return jnp.stack([du_re, du_im], axis=0)

            y0 = jnp.stack(
                [
                    jnp.eye(dim, dtype=_rdtype),
                    jnp.zeros((dim, dim), dtype=_rdtype),
                ],
                axis=0,
            )

            sol = diffrax.diffeqsolve(
                diffrax.ODETerm(rhs),
                solver,
                t0=t0,
                t1=t1,
                dt0=None,
                y0=y0,
                args=params,
                stepsize_controller=stepsize_controller,
                max_steps=max_steps,
                throw=throw,
            )

            y_final = sol.ys[0]
            U = y_final[0] + 1j * y_final[1]

            if not throw:
                # Failed solves yield an all-NaN unitary instead of raising.
                successful = sol.result == diffrax.RESULTS.successful
                U = jnp.where(successful, U, jnp.full_like(U, jnp.nan))
            return U

        return cls._store_evolve_solver(cache_key, _solve)

    @classmethod
    def evolve(
        cls,
        hamiltonian: Union["Hermitian", "ParametrizedHamiltonian"],
        name: Optional[str] = None,
        **odeint_kwargs: Any,
    ) -> Callable:
        """Return a gate-factory for Hamiltonian time evolution.

        Engine for the :meth:`Hermitian.evolve` / :meth:`ParametrizedHamiltonian.evolve`
        methods (the usual entry point); it dispatches on the Hamiltonian type.

        Supports two modes:

        Static — when *hamiltonian* is a :class:`Hermitian`::

            gate = Hermitian(H_mat, wires=0).evolve()
            gate(t=0.5)  # U = exp(-i*0.5*H)

        Time-dependent — when *hamiltonian* is a
        :class:`ParametrizedHamiltonian` (created via ``coeff_fn * Hermitian``)::

            H_td = coeff_fn * Hermitian(H_mat, wires=0)
            gate = H_td.evolve()
            gate([params], T)  # one parameter set per term, e.g.
                               # params = jnp.array([A, sigma]);
                               # U via ODE: dU/dt = -i f(params, t) H * U

        The static case is a single matrix exponential. The time-dependent
        case integrates the Schrödinger equation numerically, either with an
        adaptive Dormand-Prince solver from ``diffrax`` (``"dopri8"``, the
        default, or ``"dopri5"``) or with a fixed-step commutator-free Magnus
        integrator (``"magnus2"`` / ``"magnus4"``).

        All computations are pure JAX and fully differentiable with
        ``jax.grad``.

        Args:
            hamiltonian: Either a :class:`Hermitian` (static evolution) or a
                :class:`ParametrizedHamiltonian` (time-dependent evolution).
            name: Name given to every :class:`Operation` the factory returns.
            **odeint_kwargs: Solver options for the time-dependent mode
                (not used in static mode). Recognised keys: ``atol``,
                ``rtol`` (default ``1.4e-8`` in fp32 mode, ``1.0e-10`` in
                fp64 mode), ``max_steps``, ``throw``, ``solver`` and
                ``magnus_steps`` (defaults from
                :meth:`set_solver_defaults`). Any other key is ignored with
                a ``UserWarning``.

        Returns:
            A callable gate factory. Signature depends on the mode:

            - Static: ``(t, wires=0) -> Operation``
            - Time-dependent: ``(coeff_args, T) -> Operation``

        Raises:
            TypeError: If *hamiltonian* is neither ``Hermitian`` nor
                ``ParametrizedHamiltonian``.
            ValueError: If a time-dependent solver option is invalid.
        """
        if isinstance(hamiltonian, Hermitian):
            return cls._evolve_static(hamiltonian, name=name)
        elif isinstance(hamiltonian, ParametrizedHamiltonian):
            return cls._evolve_parametrized(hamiltonian, name=name, **odeint_kwargs)
        else:
            raise TypeError(
                f"evolve() expects a Hermitian or ParametrizedHamiltonian, "
                f"got {type(hamiltonian)}"
            )

    @staticmethod
    def _evolve_static(hermitian: "Hermitian", name: Optional[str] = None) -> Callable:
        """Gate factory for static Hamiltonian evolution U = exp(-i t H).

        Args:
            hermitian: The time-independent Hamiltonian.
            name: Name given to each returned :class:`Operation`.

        Returns:
            ``_apply(t, wires=0) -> Operation`` computing ``expm(-i t H)``.
        """
        H_mat = hermitian.matrix

        def _apply(t: float, wires: Union[int, List[int]] = 0) -> "Operation":
            U = jax.scipy.linalg.expm(-1j * t * H_mat)
            return Operation(wires=wires, matrix=U, name=name)

        return _apply

    @classmethod
    def _evolve_parametrized(
        cls,
        ph: "ParametrizedHamiltonian",
        name: Optional[str] = None,
        **odeint_kwargs: Any,
    ) -> Callable:
        """Gate factory for time-dependent (multi-term) Hamiltonian evolution.

        Solves the matrix ODE

            dU/dt = -i [\\sum_i f_i(params_i, t) * H_i] * U,   U(t0) = I

        with the solver selected by ``solver`` (adaptive diffrax Dopri8 by
        default; Dopri5 or a fixed-step Magnus integrator on request). The
        Hamiltonian may contain one or more ``coeff_fn * Hermitian`` terms
        (see :class:`ParametrizedHamiltonian`); the single-term case is the
        usual ``coeff_fn * Hermitian`` and is fully backward compatible.

        Implementation notes:

        - ``-iH_i`` is precomputed once per term and stacked into a
          ``(n_terms, 2, dim, dim)`` real array (Re, Im), contracted via
          ``tensordot`` against the per-step coefficient vector
          ``c = [f_0(p_0,t), ..., f_{n-1}(p_{n-1},t)]``.

        - To avoid diffrax's experimental complex dtype path, the diffrax
          solvers integrate the ODE in real arithmetic. Writing
          ``-iH_i = A_i + i B_i`` and ``U = U_re + i U_im``, each term
          contributes::

              d(U_re)/dt += f_i(p_i,t) * (A_i @ U_re - B_i @ U_im)
              d(U_im)/dt += f_i(p_i,t) * (A_i @ U_im + B_i @ U_re)

          The Magnus solvers recombine the split into complex matrices and
          work with ``expm`` directly.

        - The JIT-compiled solver is cached (see :meth:`_evolve_cache_key`)
          so multiple ``evolve()`` calls with the same pulse shape — but
          different Hamiltonian matrices or parameters — reuse the same
          compiled XLA program.

        Args:
            ph: A :class:`ParametrizedHamiltonian` (one or more terms).
            name: Name given to each returned :class:`Operation`.
            **odeint_kwargs: Solver options. Recognised keys:

                - ``atol``, ``rtol`` — absolute/relative tolerances for the
                  adaptive step-size controller (default ``1.4e-8`` in fp32
                  mode, ``1.0e-10`` in fp64 mode).
                - ``max_steps`` — hard cap on accepted adaptive ODE steps
                  (default ``_solver_defaults['max_steps']``, ``2**13``
                  unless changed via :meth:`set_solver_defaults`). Increase
                  this if the integrator raises ``MaxStepsReached`` for a
                  stiff/oscillatory pulse Hamiltonian.
                - ``throw`` — whether to raise on solver failure
                  (default ``_solver_defaults['throw']``, ``True`` unless
                  changed). When ``False``, a failed integration returns a
                  NaN-filled unitary instead of raising; this is the
                  recommended setting for inner loops of an optimiser (e.g.
                  QOC Stage 0) so a single pathological candidate cannot
                  abort the whole run.
                - ``solver`` — one of ``"dopri8"``, ``"dopri5"``,
                  ``"magnus2"``, ``"magnus4"``.
                - ``magnus_steps`` — number of fixed Magnus substeps.

                Any other key is ignored with a ``UserWarning``; nothing is
                forwarded to ``diffrax.diffeqsolve``.

        Returns:
            ``_apply(coeff_args, T) -> Operation``.

        Raises:
            ValueError: If a solver option is invalid.
        """
        coeff_fns = ph.coeff_fns  # tuple of callables
        H_mats = ph.H_mats  # tuple of (dim, dim)
        wires = ph.wires
        n_terms = ph.n_terms
        dim = H_mats[0].shape[0]

        # Pre-compute -i*H_i for each term and split into real / imaginary
        # parts. Final shape: (n_terms, 2, dim, dim).
        neg_iH_split = jnp.stack(
            [
                jnp.stack([jnp.real(neg_iH), jnp.imag(neg_iH)], axis=0)
                for neg_iH in (-1j * H_mat for H_mat in H_mats)
            ],
            axis=0,
        )

        # Real dtype matching the precision mode.
        x64 = jax.config.x64_enabled
        _rdtype = jnp.float64 if x64 else jnp.float32

        atol, rtol, max_steps, throw, solver_name, magnus_steps = (
            cls._parse_evolve_solver_options(odeint_kwargs)
        )
        cache_key = cls._evolve_cache_key(
            coeff_fns=coeff_fns,
            dim=dim,
            atol=atol,
            rtol=rtol,
            max_steps=max_steps,
            throw=throw,
            solver_name=solver_name,
            magnus_steps=magnus_steps,
            x64=x64,
        )

        # Look up under the lock, but build outside it: the builders take
        # the same (non-reentrant) lock in ``_store_evolve_solver``.
        with cls._evolve_solver_cache_lock:
            _solve = cls._evolve_solver_cache.get(cache_key)
        if _solve is None:
            if solver_name in cls._magnus_solvers:
                _solve = cls._build_magnus_evolve_solver(
                    cache_key=cache_key,
                    coeff_fns=coeff_fns,
                    n_terms=n_terms,
                    dim=dim,
                    solver_name=solver_name,
                    magnus_steps=magnus_steps,
                )
            else:
                _solve = cls._build_diffrax_evolve_solver(
                    cache_key=cache_key,
                    coeff_fns=coeff_fns,
                    n_terms=n_terms,
                    dim=dim,
                    atol=atol,
                    rtol=rtol,
                    max_steps=max_steps,
                    throw=throw,
                    solver_name=solver_name,
                    _rdtype=_rdtype,
                )

        def _apply(coeff_args, T) -> "Operation":
            """Evolve under the (multi-term) time-dependent Hamiltonian.

            Args:
                coeff_args: List/tuple of parameter sets, one per term.
                    For single-term Hamiltonians the legacy form
                    ``[params]`` works unchanged; ``params`` is forwarded
                    to the sole coefficient function. A bare (non-list,
                    non-tuple) argument is treated as ``[coeff_args]``.
                T: Total evolution time. Scalar -> integrate on
                    ``[0, T]``; 2-element -> integrate on ``[T[0], T[1]]``.

            Returns:
                An :class:`Operation` wrapping the computed unitary.

            Raises:
                ValueError: If the number of parameter sets does not match
                    the number of terms, or *T* is neither a scalar nor a
                    2-element pair.
            """
            if isinstance(coeff_args, (list, tuple)):
                params = tuple(coeff_args)
            else:
                params = (coeff_args,)

            if len(params) != n_terms:
                raise ValueError(
                    f"Expected {n_terms} parameter set(s) for a "
                    f"{n_terms}-term ParametrizedHamiltonian, "
                    f"got {len(params)}."
                )

            # Resolve the time span at Python level (shapes are static even
            # under tracing) to avoid traced branching. JAX clamps
            # out-of-bounds indices, so other shapes must be rejected
            # explicitly rather than silently misread.
            T_arr = jnp.asarray(T, dtype=_rdtype)
            if T_arr.ndim == 0:
                t0 = _rdtype(0.0)
                t1 = T_arr
            elif T_arr.shape == (2,):
                t0, t1 = T_arr[0], T_arr[1]
            else:
                raise ValueError(
                    "T must be a scalar end time or a 2-element (t0, t1) "
                    f"pair, got shape {T_arr.shape}."
                )

            U = _solve(neg_iH_split, params, t0, t1)

            return Operation(wires=wires, matrix=U, name=name)

        return _apply