Coverage for qml_essentials / jaqsi.py: 100%
39 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-06-11 15:51 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-06-11 15:51 +0000
1"""Pulse/gate-independent entry point for building and simulating circuits.
3This module is the main interaction point for manually creating circuits. It
4exposes the :class:`~qml_essentials.script.Script` circuit container, the
5:func:`Hamiltonian` factory for time-evolution sources, and a few general
6(pulse/gate-independent) quantum-info utilities.
8Time evolution is invoked as a method on the Hamiltonian object::
10 H = Hamiltonian(matrix, wires=0) # static -> Hermitian
11 H_t = coeff_fn * Hamiltonian(matrix, 0) # time-dep -> ParametrizedHamiltonian
12 H_t.evolve(name="RX")([params], t) # gate factory
14The time-evolution engine itself lives in :mod:`qml_essentials.evolution` as
15:class:`Evolution`, which is re-exported here for solver configuration
16(``Evolution.set_solver_defaults`` / ``Evolution.clear_evolve_solver_cache``).
17"""
19from functools import reduce
20from typing import List, Tuple, Union
22import jax
23import jax.numpy as jnp
25from qml_essentials.script import Script # noqa: F401
26from qml_essentials.evolution import Evolution # noqa: F401
27from qml_essentials.operations import ( # noqa: F401
28 Hermitian,
29 ParametrizedHamiltonian,
30 PauliZ,
31)
def Hamiltonian(
    matrix: jnp.ndarray,
    wires: Union[int, List[int]] = 0,
    record: bool = False,
) -> Hermitian:
    """Return a time-independent Hamiltonian as a :class:`Hermitian` operator.

    No new operator type is introduced: the result is a plain
    :class:`Hermitian`. Scaling it by a coefficient function
    ``f(params, t)`` yields a time-dependent
    :class:`ParametrizedHamiltonian`; either object offers an
    :meth:`evolve` method that produces a gate factory.

    Args:
        matrix: Hermitian matrix that defines the Hamiltonian.
        wires: A single qubit index, or a list of qubit indices, that the
            Hamiltonian acts on. Defaults to qubit ``0``.
        record: If ``True``, the operator is recorded on the active tape.
            Off by default: a Hamiltonian used only as an evolution source
            should not show up as a gate. The operation that belongs on the
            tape is the one built by :meth:`evolve`.

    Returns:
        The constructed :class:`Hermitian` operator.
    """
    hamiltonian = Hermitian(matrix=matrix, wires=wires, record=record)
    return hamiltonian
def _partial_trace_single(
    rho: jnp.ndarray,
    n_qubits: int,
    keep: List[int],
) -> jnp.ndarray:
    """Partial trace of a single density matrix (no batch dimension).

    Args:
        rho: Density matrix of shape ``(2**n_qubits, 2**n_qubits)``.
        n_qubits: Total number of qubits.
        keep: Indices of the qubits to keep. Assumed to be distinct and
            within ``range(n_qubits)``; :func:`partial_trace` checks this.

    Returns:
        Reduced density matrix of shape ``(2**k, 2**k)`` with
        ``k = len(keep)``. The kept qubits appear in ascending index order,
        regardless of the order in which ``keep`` lists them.
    """
    # View rho as a rank-2n tensor. Axes 0..n-1 carry the row qubits and
    # axes n..2n-1 the matching column qubits. Because the reshape is
    # row-major, qubit 0 is the most significant bit of the basis index.
    rho_t = rho.reshape((2,) * (2 * n_qubits))

    trace_out = sorted(set(range(n_qubits)) - set(keep))

    # Trace out the highest index first. Contracting the pair
    # (q, q + n_remaining) leaves every smaller qubit q' with its row axis
    # still at q' and its column axis at q' + (new n_remaining) on the next
    # pass.
    for q in reversed(trace_out):
        n_remaining = rho_t.ndim // 2
        rho_t = jnp.trace(rho_t, axis1=q, axis2=q + n_remaining)

    dim = 2 ** len(keep)
    return rho_t.reshape(dim, dim)


def partial_trace(
    rho: jnp.ndarray,
    n_qubits: int,
    keep: List[int],
) -> jnp.ndarray:
    """Partial trace of a density matrix, keeping only the specified qubits.

    Supports single density matrices of shape ``(2**n, 2**n)`` and batched
    density matrices of shape ``(*batch, 2**n, 2**n)``. The batch may have
    any number of leading axes, e.g. ``(B, 2**n, 2**n)``.

    Args:
        rho: Density matrix of shape ``(2**n, 2**n)`` or
            ``(*batch, 2**n, 2**n)``.
        n_qubits: Total number of qubits.
        keep: Distinct qubit indices to *keep* (0-indexed). The result is
            ordered by ascending qubit index, independent of the order given
            here.

    Returns:
        Reduced density matrix of shape ``(2**k, 2**k)`` or
        ``(*batch, 2**k, 2**k)`` where *k* = ``len(keep)``.

    Raises:
        ValueError: If the trailing two axes of ``rho`` are not
            ``(2**n_qubits, 2**n_qubits)``, or if ``keep`` contains duplicate
            indices or indices outside ``range(n_qubits)``.
    """
    dim = 2**n_qubits
    if tuple(rho.shape[-2:]) != (dim, dim):
        raise ValueError(
            f"rho must have trailing shape ({dim}, {dim}) for "
            f"n_qubits={n_qubits}; got shape {tuple(rho.shape)}"
        )

    # Duplicate or out-of-range indices would otherwise surface later as an
    # opaque reshape error inside _partial_trace_single.
    kept = set(keep)
    if len(kept) != len(keep) or not kept.issubset(range(n_qubits)):
        raise ValueError(
            f"keep must contain distinct qubit indices in range({n_qubits}); "
            f"got {list(keep)}"
        )

    if rho.ndim == 2:
        return _partial_trace_single(rho, n_qubits, keep)

    # Batched: collapse all leading axes into one, vmap, then restore them.
    batch_shape = tuple(rho.shape[:-2])
    flat = rho.reshape((-1, dim, dim))
    reduced = jax.vmap(lambda r: _partial_trace_single(r, n_qubits, keep))(flat)
    return reduced.reshape(batch_shape + tuple(reduced.shape[1:]))
def _marginalize_probs_single(
    probs: jnp.ndarray,
    target_shape: Tuple[int],
    trace_out: Tuple[int],
) -> jnp.ndarray:
    """Marginalize a single probability vector (no batch dimension).

    Args:
        probs: Probability vector with ``prod(target_shape)`` entries.
        target_shape: Per-qubit tensor shape, ``(2,) * n_qubits``.
        trace_out: Qubit indices to sum over, in *descending* order.

    Returns:
        Flattened marginal probability vector over the remaining qubits.
    """
    probs_t = probs.reshape(target_shape)

    # Each sum drops one axis. Visiting indices in descending order keeps the
    # positions of all smaller, not-yet-summed qubit axes unchanged.
    for q in trace_out:
        probs_t = probs_t.sum(axis=q)

    return probs_t.ravel()


def marginalize_probs(
    probs: jnp.ndarray,
    n_qubits: int,
    keep: Tuple[int],
) -> jnp.ndarray:
    """Marginalize a probability vector to keep only the specified qubits.

    Supports single probability vectors of shape ``(2**n,)`` and batched
    vectors of shape ``(*batch, 2**n)``. The batch may have any number of
    leading axes, e.g. ``(B, 2**n)``.

    Qubit ``q`` corresponds to axis ``q`` of the row-major reshape to
    ``(2,) * n``, i.e. qubit 0 is the most significant bit of the
    basis-state index.

    Args:
        probs: Probability vector of shape ``(2**n,)`` or ``(*batch, 2**n)``.
        n_qubits: Total number of qubits.
        keep: Qubit indices to *keep* (0-indexed). The marginal is ordered
            by ascending qubit index. Indices outside ``range(n_qubits)``
            are ignored.

    Returns:
        Marginalized probability vector of shape ``(2**k,)`` or
        ``(*batch, 2**k)`` where *k* is the number of kept qubits.

    Raises:
        ValueError: If the last axis of ``probs`` does not have length
            ``2**n_qubits``.
    """
    dim = 2**n_qubits
    if tuple(probs.shape[-1:]) != (dim,):
        raise ValueError(
            f"probs must have a trailing axis of length {dim} for "
            f"n_qubits={n_qubits}; got shape {tuple(probs.shape)}"
        )

    # Descending order, as required by _marginalize_probs_single.
    trace_out = tuple(q for q in range(n_qubits - 1, -1, -1) if q not in keep)
    target_shape = (2,) * n_qubits

    # Collapse any leading batch axes into one for vmap, then restore them so
    # a single (2**n,) input yields a single (2**k,) output.
    batch_shape = tuple(probs.shape[:-1])
    marginals = jax.vmap(
        lambda p: _marginalize_probs_single(p, target_shape, trace_out)
    )(probs.reshape(-1, dim))
    return marginals.reshape(batch_shape + tuple(marginals.shape[1:]))
def build_parity_observable(
    qubit_group: List[int],
) -> Hermitian:
    """Build a multi-qubit Z-parity observable.

    The matrix is the tensor product ``Z ⊗ Z ⊗ ... ⊗ Z`` with one Pauli-Z
    factor per qubit in ``qubit_group``.

    Args:
        qubit_group: Qubit indices for the parity measurement. Must be
            non-empty.

    Returns:
        A :class:`Hermitian` operation whose matrix is the Z parity
        tensor product and whose wires match the given qubits. It is not
        recorded on the active tape.

    Raises:
        ValueError: If ``qubit_group`` is empty. Without this check,
            ``functools.reduce`` would fail with an unrelated
            "empty iterable" ``TypeError``.
    """
    if len(qubit_group) == 0:
        raise ValueError("qubit_group must contain at least one qubit index.")

    Z = PauliZ._matrix
    mat = reduce(jnp.kron, [Z] * len(qubit_group))
    obs = Hermitian(matrix=mat, wires=qubit_group, record=False)
    # Tag the Pauli string so symbolic consumers (PauliWord / FourierTree) can
    # read it without an O(4^n) matrix decomposition.
    obs._pauli_label = "Z" * len(qubit_group)
    return obs