Coverage for qml_essentials / simulation.py: 95%
103 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-06-11 15:51 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-06-11 15:51 +0000
"""Pure simulation and measurement kernels for :class:`~qml_essentials.script.Script`.

These functions are stateless: they take a recorded tape (a list of
:class:`~qml_essentials.operations.Operation`) plus measurement parameters and
return JAX arrays. Keeping them as module-level free functions (rather than
static methods on ``Script``) makes the simulation engine independently testable
and keeps ``script.py`` focused on orchestration.
"""
10from typing import List, Optional
12import jax
13import jax.numpy as jnp
14import numpy as np # needed to prevent jitting some operations
16from qml_essentials.operations import (
17 Barrier,
18 Operation,
19 KrausChannel,
20 _einsum_subscript,
21 _cdtype,
22)
def infer_n_qubits(ops: List[Operation], obs: List[Operation]) -> int:
    """Infer the number of qubits from a list of operations and observables.

    Both collections are only iterated (never concatenated), so any mix of
    lists, tuples or other iterables of operations is accepted.

    Args:
        ops: Gate operations recorded on the tape.
        obs: Observable operations used for measurement.

    Returns:
        The smallest number of qubits that covers all wire indices, i.e.
        ``max(all_wires) + 1``, or ``1`` when no operation carries a wire.
    """
    # ``default=0`` only applies when no wire is found at all, which makes
    # the empty case evaluate to 1 qubit.
    highest_wire = max(
        (wire for group in (ops, obs) for op in group for wire in op.wires),
        default=0,
    )
    return highest_wire + 1
def uses_density(tape: List[Operation], type: str) -> bool:
    """Decide whether the circuit has to be simulated as a density matrix.

    A density-matrix simulation is selected in two situations: the caller
    asked for the ``"density"`` measurement type, or at least one entry of
    the tape is a :class:`~qml_essentials.operations.KrausChannel`.

    Args:
        tape: Ordered list of gate/channel operations.
        type: Requested measurement type.

    Returns:
        ``True`` if density-matrix simulation must be used.
    """
    noisy = False
    for step in tape:
        if isinstance(step, KrausChannel):
            # One channel is enough; no need to scan the rest of the tape.
            noisy = True
            break
    return noisy or type == "density"
def _stack_obs(obs: List[Operation], n_qubits: int) -> jnp.ndarray:
    """Lift every observable to *n_qubits* and stack the results.

    Each observable contributes ``ob.lifted_matrix(n_qubits)``; the matrices
    are stacked along a new leading axis, giving ``(n_obs, dim, dim)``.
    """
    lifted = [ob.lifted_matrix(n_qubits) for ob in obs]
    return jnp.stack(lifted, axis=0)
def simulate_pure(tape: List[Operation], n_qubits: int) -> jnp.ndarray:
    """Statevector simulation kernel.

    The register is initialised with amplitude 1 on basis index 0 (the
    |00…0⟩ state) and every non-barrier entry of *tape* is applied through
    one ``jnp.einsum`` contraction.  Between gates the state lives in
    rank-*n* tensor form ``(2,) * n_qubits``; it is reshaped only once on
    entry and once on exit.

    Gate tensors and einsum subscripts are collected in a first pass so the
    application loop consists of nothing but ``jnp.einsum`` calls.

    Args:
        tape: Ordered list of gate operations to apply.
        n_qubits: Total number of qubits.

    Returns:
        Statevector of shape ``(2**n_qubits,)``.
    """

    def _prepare(op):
        # One (gate tensor, einsum subscript) pair per applied operation.
        arity = len(op.wires)
        tensor = op._gate_tensor(arity)
        subscript = _einsum_subscript(n_qubits, arity, tuple(op.wires))
        return tensor, subscript

    # Barriers are dropped here and never reach the contraction loop.
    schedule = [_prepare(op) for op in tape if not isinstance(op, Barrier)]

    flat_dim = 2**n_qubits
    ket_zero = jnp.zeros(flat_dim, dtype=_cdtype()).at[0].set(1.0)
    amplitudes = ket_zero.reshape((2,) * n_qubits)
    for tensor, subscript in schedule:
        amplitudes = jnp.einsum(subscript, tensor, amplitudes)
    return amplitudes.reshape(flat_dim)
def simulate_mixed(tape: List[Operation], n_qubits: int) -> jnp.ndarray:
    """Density-matrix simulation kernel.

    The density matrix starts with a single 1 at entry ``[0, 0]`` (i.e.
    rho = |0⟩⟨0|) and is threaded through
    :meth:`~qml_essentials.operations.Operation.apply_to_density` of every
    tape entry, in order.  This kernel is the one used for tapes that
    contain Kraus channels.

    Args:
        tape: Ordered list of gate or channel operations to apply.
        n_qubits: Total number of qubits.

    Returns:
        Density matrix of shape ``(2**n_qubits, 2**n_qubits)``.
    """
    side = 2**n_qubits
    density = jnp.zeros((side, side), dtype=_cdtype()).at[0, 0].set(1.0)
    # NOTE(review): unlike ``simulate_pure``, Barrier entries are not filtered
    # out here -- presumably ``Barrier.apply_to_density`` is a no-op; confirm.
    for step in tape:
        density = step.apply_to_density(density, n_qubits)
    return density
def simulate_and_measure(
    tape: List[Operation],
    n_qubits: int,
    type: str,
    obs: List[Operation],
    use_density: bool,
    shots: Optional[int] = None,
    key: Optional[jnp.ndarray] = None,
) -> jnp.ndarray:
    """Run simulation and measurement in a single dispatch.

    Chooses statevector or density-matrix simulation based on
    *use_density*, then applies the appropriate measurement function.

    When *shots* is not ``None`` and *type* is ``"probs"`` or ``"expval"``,
    the exact probability distribution is computed first and ``shots``
    samples are drawn from it.  For any other *type* the *shots* argument
    is ignored and the exact result is returned.

    Pure-circuit density optimisation: when *use_density* is set but the
    tape holds no :class:`~qml_essentials.operations.KrausChannel`, the
    statevector is simulated and the density matrix is formed once as the
    outer product of the state with its conjugate, instead of evolving the
    full ``2^n x 2^n`` matrix gate by gate.

    Args:
        tape: Ordered list of gate/channel operations to apply.
        n_qubits: Total number of qubits.
        type: Measurement type (``"state"``/``"probs"``/``"expval"``/
            ``"density"``).
        obs: Observables for ``"expval"`` measurements.
        use_density: If ``True``, use density-matrix simulation.
        shots: Number of measurement shots.  If ``None`` (default),
            exact analytic results are returned.
        key: JAX PRNG key for shot sampling.  Required when shot sampling
            is actually performed.

    Returns:
        Measurement result (shape depends on *type*).

    Raises:
        ValueError: If shot sampling is requested without a PRNG *key*, or
            if the measurement function rejects *type*.
    """
    sampled = shots is not None and type in ("probs", "expval")
    if sampled and key is None:
        # Fail before the (potentially expensive) simulation instead of
        # surfacing an opaque error from the sampler afterwards.
        raise ValueError(
            "A JAX PRNG key is required when `shots` is set for "
            "'probs'/'expval' measurements, got key=None."
        )

    if use_density:
        if any(isinstance(o, KrausChannel) for o in tape):
            # Kraus channels need the full density-matrix evolution.
            rho = simulate_mixed(tape, n_qubits)
        else:
            # Noise-free tape: evolve the statevector and build the
            # density matrix with a single outer product at the end.
            state = simulate_pure(tape, n_qubits)
            rho = jnp.outer(state, jnp.conj(state))

        if sampled:
            exact_probs = jnp.real(jnp.diag(rho))
            return sample_shots(exact_probs, n_qubits, type, obs, shots, key)
        return measure_density(rho, n_qubits, type, obs)

    state = simulate_pure(tape, n_qubits)

    if sampled:
        exact_probs = jnp.abs(state) ** 2
        return sample_shots(exact_probs, n_qubits, type, obs, shots, key)
    return measure_state(state, n_qubits, type, obs)
def measure_state(
    state: jnp.ndarray,
    n_qubits: int,
    type: str,
    obs: List[Operation],
) -> jnp.ndarray:
    """Apply the requested measurement to a pure statevector.

    Args:
        state: Statevector of shape ``(2**n_qubits,)``.
        n_qubits: Total number of qubits.
        type: Measurement type -- one of ``"state"``, ``"probs"``,
            or ``"expval"``.
        obs: Observables used when *type* is ``"expval"``.

    Returns:
        Measurement result whose shape depends on *type*:

        - ``"state"`` -> ``(2**n_qubits,)``
        - ``"probs"`` -> ``(2**n_qubits,)``
        - ``"expval"`` -> ``(len(obs),)``

    Raises:
        ValueError: If *type* is not a recognised measurement type.
    """
    if type == "state":
        return state

    if type == "probs":
        return jnp.abs(state) ** 2

    if type != "expval":
        raise ValueError(f"Unknown measurement type: {type!r}")

    def _qualifies_for_fast_path(observable) -> bool:
        # Needs a class-level matrix, exactly one wire, and no off-diagonal
        # entries.  The check runs in NumPy so it yields a concrete bool.
        class_matrix = observable.__class__._matrix
        if class_matrix is None or len(observable.wires) != 1:
            return False
        dense = np.asarray(class_matrix)
        off_diagonal = dense - np.diag(np.diag(dense))
        return np.allclose(off_diagonal, 0)

    if all(_qualifies_for_fast_path(ob) for ob in obs):
        # Fast path: one probability tensor shared by all observables; each
        # expectation is the diagonal weighted by the target wire's marginal.
        prob_tensor = (jnp.abs(state) ** 2).reshape((2,) * n_qubits)
        expectations = []
        for ob in obs:
            target = ob.wires[0]
            diag_vals = np.real(np.diag(np.asarray(ob.__class__._matrix)))
            # Marginalise over every axis except the measured wire.
            other_axes = tuple(ax for ax in range(n_qubits) if ax != target)
            marginal = jnp.sum(prob_tensor, axis=other_axes)
            expectations.append(
                diag_vals[0] * marginal[0] + diag_vals[1] * marginal[1]
            )
        return jnp.array(expectations)

    # General path: transformed[o] = O_o |psi>, then <O_o> = Re(<psi|O_o|psi>),
    # evaluated for all observables with two batched contractions.
    stacked = _stack_obs(obs, n_qubits)  # (n_obs, dim, dim)
    transformed = jnp.einsum("oij,j->oi", stacked, state)
    return jnp.real(jnp.einsum("i,oi->o", jnp.conj(state), transformed))
def measure_density(
    rho: jnp.ndarray,
    n_qubits: int,
    type: str,
    obs: List[Operation],
) -> jnp.ndarray:
    """Apply the requested measurement to a density matrix.

    Args:
        rho: Density matrix of shape ``(2**n_qubits, 2**n_qubits)``.
        n_qubits: Total number of qubits.
        type: Measurement type -- one of ``"density"``, ``"probs"``,
            or ``"expval"``.
        obs: Observables used when *type* is ``"expval"``.

    Returns:
        Measurement result whose shape depends on *type*:

        - ``"density"`` -> ``(2**n_qubits, 2**n_qubits)``
        - ``"probs"`` -> ``(2**n_qubits,)``
        - ``"expval"`` -> ``(len(obs),)``

    Raises:
        ValueError: If *type* is ``"state"`` (not available on the
            density-matrix path) or any other unrecognised type.  The
            message names the offending type in the latter case.
    """
    if type == "density":
        return rho

    if type == "probs":
        # Populations are the real part of the diagonal.
        return jnp.real(jnp.diag(rho))

    if type == "expval":
        # Tr(O rho) = sum_ij O_ij rho_ji, evaluated for every observable in
        # one batched contraction over the stacked matrices.
        obs_mats = _stack_obs(obs, n_qubits)  # (n_obs, dim, dim)
        return jnp.real(jnp.einsum("oij,ji->o", obs_mats, rho))

    if type == "state":
        raise ValueError(
            "Measurement type 'state' is not defined for mixed (noisy) circuits. "
            "Use 'density' instead."
        )

    # Anything else is simply unknown; do not blame 'state' for it.
    raise ValueError(f"Unknown measurement type: {type!r}")
def sample_shots(
    probs: jnp.ndarray,
    n_qubits: int,
    type: str,
    obs: List[Operation],
    shots: int,
    key: jnp.ndarray,
) -> jnp.ndarray:
    """Convert exact probabilities into shot-sampled results.

    Draws *shots* samples from the computational-basis probability
    distribution and returns either estimated probabilities or
    shot-based expectation values.

    For ``"expval"`` only the diagonal of each lifted observable is used
    (``diag(O) . estimated_probs``).  This matches a computational-basis
    measurement for diagonal observables; off-diagonal entries of an
    observable do not contribute to the estimate at all.

    Args:
        probs: Exact probability vector of shape ``(2**n_qubits,)``.
        n_qubits: Total number of qubits.
        type: Measurement type -- ``"probs"`` or ``"expval"``.
        obs: Observables used when *type* is ``"expval"``.
        shots: Number of measurement shots (must be positive).
        key: JAX PRNG key for sampling.

    Returns:
        Shot-sampled measurement result:

        - ``"probs"`` -> ``(2**n_qubits,)`` estimated probabilities.
        - ``"expval"`` -> ``(len(obs),)`` estimated expectation values.

    Raises:
        ValueError: If *type* is neither ``"probs"`` nor ``"expval"``, if
            *key* is ``None``, or if *shots* is not positive.
    """
    # Validate everything before sampling so bad calls fail fast.
    if type not in ("probs", "expval"):
        raise ValueError(
            f"Shot simulation is only supported for 'probs' and 'expval', got {type!r}."
        )
    if key is None:
        raise ValueError("A JAX PRNG key is required for shot sampling, got key=None.")
    if shots <= 0:
        # Zero shots would otherwise yield 0/0 = NaN estimates.
        raise ValueError(f"`shots` must be a positive integer, got {shots!r}.")

    dim = 2**n_qubits

    # Each sample is an integer in [0, dim) identifying a basis state.
    samples = jax.random.choice(key, dim, shape=(shots,), p=probs)

    # Histogram of how often each basis state was drawn.
    counts = jnp.zeros(dim, dtype=jnp.int32).at[samples].add(1)
    estimated_probs = counts / shots

    if type == "probs":
        return estimated_probs

    # type == "expval": diagonal of each lifted observable weighted by the
    # sampled frequencies.
    return jnp.array(
        [
            jnp.real(jnp.dot(jnp.diag(ob.lifted_matrix(n_qubits)), estimated_probs))
            for ob in obs
        ]
    )