Coverage for qml_essentials / memory.py: 93%
85 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-06-11 15:51 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-06-11 15:51 +0000
1"""Memory estimation and memory-aware batch chunking.
3These helpers let :class:`~qml_essentials.script.Script` decide whether a batched
4simulation fits in available RAM and, if not, split it into chunks that do. They
5are pure functions (the estimates are plain Python arithmetic) so they add
6essentially zero overhead when the full batch fits.
7"""
9from typing import Callable, Tuple
11import jax
12import jax.numpy as jnp
14import logging
log = logging.getLogger(__name__)

# Whether to call ``jax.clear_caches()`` between memory-aware chunks in
# :func:`execute_chunked`. Default ``False``: clearing caches between chunks
# forces XLA to recompile the same batched program for every chunk, which is a
# major performance hit when many chunks are needed. Set ``True`` only if you
# observe OOM growth across chunks.
# NOTE(review): ``execute_chunked`` does not read this flag itself; presumably
# callers forward it through its ``clear_caches`` argument -- confirm.
CLEAR_CACHES_BETWEEN_CHUNKS: bool = False
def _element_sizes() -> Tuple[int, int]:
    """Return ``(complex_elem, real_elem)`` byte sizes for the active JAX dtype.

    JAX silently truncates complex128 to complex64 (and float64 to float32)
    when x64 mode is disabled (the default), halving memory usage.

    Returns:
        ``(16, 8)`` when ``jax.config.x64_enabled`` is true, else ``(8, 4)``.
    """
    # complex128 occupies 16 bytes, complex64 occupies 8.
    complex_bytes = 16 if jax.config.x64_enabled else 8
    # A real value is half the width of the matching complex value.
    return complex_bytes, complex_bytes // 2
def _output_bytes(
    type: str,
    batch_size: int,
    dim: int,
    elem: int,
    real_elem: int,
    n_obs: int,
) -> int:
    """Bytes of the returned/accumulated ``(batch_size, ...)`` measurement array.

    Args:
        type: Measurement type. ``"density"``, ``"expval"`` and ``"probs"``
            are handled explicitly; any other value is sized like a state.
        batch_size: Number of batch elements.
        dim: Hilbert-space dimension.
        elem: Size of one complex element in bytes.
        real_elem: Size of one real element in bytes.
        n_obs: Number of observables; values below 1 are treated as 1.

    Returns:
        Size of the output array in bytes.
    """
    if type == "density":
        per_sample = dim * dim * elem
    elif type == "expval":
        # At least one expectation value is always returned per sample.
        per_sample = max(n_obs, 1) * real_elem
    elif type == "probs":
        per_sample = dim * real_elem
    else:
        # "state" and any unrecognised type: one complex amplitude per basis state.
        per_sample = dim * elem
    return batch_size * per_sample
def estimate_peak_bytes(
    n_qubits: int,
    batch_size: int,
    type: str,
    use_density: bool,
    n_obs: int = 0,
    n_ops: int = 1,
) -> int:
    """Estimate peak memory (bytes) for a batched simulation.

    The estimate accounts for:

    - The batched statevector (always needed, even for density).
    - The batched output tensor (state / probs / density / expval).
    - Gate-tensor temporaries (the einsum buffers).  XLA frequently
      keeps several per-gate ``(B, dim)`` (or ``(B, dim, dim)`` for
      density) buffers alive simultaneously when fusion is not
      possible, so we multiply the per-element gate cost by *n_ops*
      (the number of operations on the recorded tape).

    Observable matrices are **not** counted: they are computed inside
    the JIT-compiled function and XLA manages their lifetime (reusing
    buffers between observables).  Similarly, the outer-product
    temporary for pure-circuit density mode is transient within XLA.

    Element size is determined dynamically from ``jax.config.x64_enabled``:
    when x64 mode is disabled (the JAX default), complex values are
    ``complex64`` (8 bytes) and floats are ``float32`` (4 bytes),
    halving memory usage compared to the x64 path.

    A 1.5x safety factor is applied to cover XLA compiler temporaries,
    padding, and other allocations not directly visible to Python.  The
    factor is applied with exact integer arithmetic, so the result stays
    exact (and cannot overflow a float) for arbitrarily large estimates.

    Apart from reading the x64 configuration flag this is plain Python
    arithmetic, so it adds essentially zero overhead.

    Args:
        n_qubits: Number of qubits in the circuit.
        batch_size: Number of batch elements.
        type: Measurement type (``"state"``, ``"probs"``, ``"expval"``,
            ``"density"``).
        use_density: Whether density-matrix simulation is used.
        n_obs: Number of observables (relevant for ``"expval"``).
        n_ops: Number of operations on the circuit tape.  Used to
            scale the per-gate intermediate buffers.  Defaults to 1
            (backwards-compatible single-buffer estimate).

    Returns:
        Estimated peak memory in bytes.
    """
    dim = 2**n_qubits
    # Detect actual element size: JAX silently truncates complex128
    # to complex64 when x64 mode is disabled (the default).
    elem, real_elem = _element_sizes()

    # Clamp n_ops to at least 1 so callers that omit the argument
    # reproduce the previous behaviour.
    n_ops = max(int(n_ops), 1)

    # Statevector: always allocated during simulation.
    sv_bytes = batch_size * dim * elem

    if use_density:
        # The full rho (dim x dim) must be held during gate evolution, even
        # if the final output is only probs or expval.  apply_to_density
        # contracts both U and U* against rho, so at least two intermediate
        # (dim x dim) buffers are alive simultaneously per applied operation.
        rho_bytes = batch_size * dim * dim * elem
        sim_bytes = 2 * n_ops * rho_bytes
        # Gate temporaries: one (B, dim, dim) einsum buffer per operation.
        gate_tmp = n_ops * rho_bytes
    else:
        sim_bytes = 0  # the statevector is already counted above
        # Gate temporaries: one (B, dim) einsum buffer per operation.
        gate_tmp = n_ops * sv_bytes

    # Output tensor: the *returned* array, not the simulation intermediate.
    # For probs/expval with density simulation the density matrix is reduced
    # to a small output before returning.
    out_bytes = _output_bytes(type, batch_size, dim, elem, real_elem, n_obs)

    # Peak = max(simulation phase, output phase).  During simulation the
    # intermediate + statevector + gate temporaries are alive; after
    # measurement only the output survives.
    raw = max(sv_bytes + sim_bytes + gate_tmp, out_bytes)

    # 1.5x safety factor, computed as an exact integer (floor of 1.5 * raw).
    return int(raw) * 3 // 2
def available_memory_bytes() -> int:
    """Return available system memory in bytes.

    Uses ``psutil.virtual_memory().available`` for cross-platform
    support (Linux, macOS, Windows).  Falls back to reading
    ``/proc/meminfo`` on Linux, and finally to a conservative 4 GiB
    default if neither approach succeeds.

    Returns:
        Available memory in bytes.
    """
    mem = 4 * 1024**3  # conservative default used when every probe fails

    # Primary: psutil (works on Linux, macOS, Windows).  Kept as a broad,
    # best-effort guard: both a missing package and a runtime failure of
    # the query should lead to the fallback rather than an exception.
    try:
        import psutil

        mem = psutil.virtual_memory().available
    except Exception:
        log.debug("psutil not available. Fallback to /proc/meminfo")

        # Fallback: /proc/meminfo (Linux only).
        try:
            with open("/proc/meminfo", "r") as f:
                for line in f:
                    if line.startswith("MemAvailable:"):
                        mem = int(line.split()[1]) * 1024  # kB -> bytes
                        break  # nothing further to read once found
        except (OSError, ValueError, IndexError):
            # Missing/unreadable file or an unexpected line format.
            log.debug("Failed to read /proc/meminfo. Falling back to 4 GiB")

    log.debug("Available memory: %.1f GB", mem / 1024**3)
    return mem
def compute_chunk_size(
    n_qubits: int,
    batch_size: int,
    type: str,
    use_density: bool,
    n_obs: int = 0,
    memory_fraction: float = 0.8,
    n_ops: int = 1,
) -> int:
    """Determine the largest chunk size that fits in available memory.

    If the full batch fits, returns *batch_size* (i.e. no chunking).
    Otherwise, returns the largest chunk size such that the computation
    of one chunk **plus** the full output accumulator fits within
    ``memory_fraction`` of available RAM.

    The output accumulator is the final ``(batch_size, ...)`` array that
    holds all results.  When chunking, this array must coexist with the
    active chunk computation, so its size is subtracted from available
    memory before computing how many elements fit per chunk.

    The minimum chunk size is 1 (fully serialised).  A warning is logged
    when even a single element exceeds the memory left over after
    reserving the accumulator.

    Args:
        n_qubits: Number of qubits.
        batch_size: Total batch size.
        type: Measurement type.
        use_density: Whether density-matrix simulation is used.
        n_obs: Number of observables.
        memory_fraction: Fraction of available memory to target
            (default 0.8 = 80%).
        n_ops: Number of operations on the recorded tape.  Forwarded
            to :func:`estimate_peak_bytes`.  Defaults to 1.

    Returns:
        Chunk size (number of batch elements per sub-batch).
    """
    avail = int(available_memory_bytes() * memory_fraction)
    full_est = estimate_peak_bytes(
        n_qubits, batch_size, type, use_density, n_obs, n_ops=n_ops
    )
    if full_est <= avail:
        return batch_size  # everything fits -- no chunking

    # The output accumulator (the final (batch_size, ...) result array)
    # must coexist with each chunk's computation, so subtract its size
    # from available memory before sizing chunks.
    dim = 2**n_qubits
    elem, real_elem = _element_sizes()
    accum_bytes = _output_bytes(type, batch_size, dim, elem, real_elem, n_obs)
    avail_for_chunks = max(avail - accum_bytes, elem)  # at least 1 element

    # Per-element cost: the memory for computing a single batch element.
    per_elem = estimate_peak_bytes(
        n_qubits, 1, type, use_density, n_obs, n_ops=n_ops
    )
    if per_elem <= 0:
        return batch_size

    chunk = max(1, min(avail_for_chunks // per_elem, batch_size))

    # Compare against the budget that is actually left for a chunk (after
    # the accumulator), not the total budget: otherwise a chunk size that
    # was forced up to 1 could exceed the remaining memory without warning.
    if per_elem > avail_for_chunks:
        log.warning(
            "A single batch element requires ~%.2f GB "
            "but only ~%.2f GB is available. "
            "Proceeding with chunk_size=1 but OOM is possible. "
            "Consider reducing n_qubits or switching measurement type.",
            per_elem / 1024**3,
            avail_for_chunks / 1024**3,
        )

    log.info(
        "Computation requires ~%.2f GB which "
        "does not fit in ~%.2f GB. "
        "Using chunk size %s.",
        full_est / 1024**3,
        avail / 1024**3,
        chunk,
    )
    return chunk
def execute_chunked(
    batched_fn: Callable,
    args: tuple,
    in_axes: Tuple,
    batch_size: int,
    chunk_size: int,
    clear_caches: bool = False,
) -> jnp.ndarray:
    """Execute a vmapped function in memory-safe chunks.

    Splits the batch dimension into sub-batches of at most *chunk_size*
    elements, runs each through *batched_fn*, and writes the results into
    an output array that is allocated once the first chunk's shape and
    dtype are known.

    Each chunk is computed, copied into the output buffer, and then its
    reference is dropped before the next chunk starts.  The intent is to
    keep peak memory at roughly ``output_buffer + one_chunk_computation``
    rather than the sum of all chunk outputs.

    Args:
        batched_fn: A JIT-compiled, vmapped callable returning a single
            array whose leading axis is the batch axis.
        args: Full-batch arguments (before slicing).
        in_axes: Per-argument batch axis specification; ``None`` marks an
            argument that is passed to every chunk unsliced.
        batch_size: Total number of batch elements.
        chunk_size: Maximum elements per chunk.  Must be at least 1.
        clear_caches: When ``True``, call ``jax.clear_caches()`` after each
            chunk to release device buffers.  Disabled by default because it
            forces full recompilation of *batched_fn* on every subsequent chunk.

    Returns:
        Batched results with the same leading dimension as the full batch.
        If *batch_size* is 0 no chunk is executed and ``None`` is returned,
        because the output shape is only known after running a chunk.

    Raises:
        ValueError: If *chunk_size* is smaller than 1.
    """
    # Reject invalid chunk sizes up front: 0 would divide by zero below and
    # a negative value would silently produce no chunks at all.
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")

    n_chunks = (batch_size + chunk_size - 1) // chunk_size  # ceil division
    log.debug(
        "Memory-aware chunking: splitting batch of %s into "
        "%s chunks of <=%s elements.",
        batch_size,
        n_chunks,
        chunk_size,
    )

    output = None
    for chunk_idx in range(n_chunks):
        start = chunk_idx * chunk_size
        end = min(start + chunk_size, batch_size)
        size = end - start  # the last chunk may be shorter

        # Slice each batched argument along its batch axis; arguments with
        # a ``None`` axis are shared across chunks and passed through as-is.
        chunk_args = tuple(
            (
                jax.lax.dynamic_slice_in_dim(a, start, size, axis=ax)
                if ax is not None
                else a
            )
            for a, ax in zip(args, in_axes)
        )

        chunk_result = batched_fn(*chunk_args)

        if output is None:
            # Allocate the full output buffer once the per-element shape
            # and dtype are known from the first chunk.
            out_shape = (batch_size,) + chunk_result.shape[1:]
            output = jnp.zeros(out_shape, dtype=chunk_result.dtype)

        # The indexed update returns a new array (JAX arrays are immutable);
        # rebinding ``output`` immediately drops the reference to the old one.
        output = output.at[start:end].set(chunk_result)

        # Explicitly drop the chunk references so their memory can be
        # released before the next chunk is computed.
        del chunk_result, chunk_args

        # Optional cache clear -- disabled by default because it forces
        # full recompilation of ``batched_fn`` on every subsequent chunk.
        if clear_caches:
            jax.clear_caches()  # TODO: confirm to remove

    return output