Coverage for qml_essentials / memory.py: 93%

85 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-06-11 15:51 +0000

1"""Memory estimation and memory-aware batch chunking. 

2 

3These helpers let :class:`~qml_essentials.script.Script` decide whether a batched 

4simulation fits in available RAM and, if not, split it into chunks that do. The 

5size estimates are plain Python arithmetic with no JAX calls, so they add 

6essentially zero overhead when the full batch fits. 

7""" 

8 

9from typing import Callable, Tuple 

10 

11import jax 

12import jax.numpy as jnp 

13 

14import logging 

15 

log: logging.Logger = logging.getLogger(__name__)

# Opt-in switch for calling ``jax.clear_caches()`` between the memory-aware
# chunks run by :func:`execute_chunked`. Off by default: dropping the caches
# makes XLA recompile the same batched program for every chunk, which is
# expensive when many chunks are needed. Enable only if memory keeps growing
# from one chunk to the next.
# NOTE(review): execute_chunked itself only reads its ``clear_caches``
# argument; presumably callers forward this flag — confirm.
CLEAR_CACHES_BETWEEN_CHUNKS: bool = False

24 

25 

def _element_sizes() -> Tuple[int, int]:
    """Byte widths of one complex and one real element at JAX's active precision.

    With x64 mode off (JAX's default), complex values are complex64 (8 bytes)
    and reals are float32 (4 bytes). With x64 on, they are complex128
    (16 bytes) and float64 (8 bytes).

    Returns:
        ``(complex_bytes, real_bytes)``.
    """
    complex_bytes = 8  # complex64
    if jax.config.x64_enabled:
        complex_bytes = 16  # complex128
    # A complex number is a pair of reals, so a real is half as wide.
    real_bytes = complex_bytes // 2
    return complex_bytes, real_bytes

34 

35 

def _output_bytes(
    type: str,
    batch_size: int,
    dim: int,
    elem: int,
    real_elem: int,
    n_obs: int,
) -> int:
    """Size in bytes of the ``(batch_size, ...)`` measurement array.

    Args:
        type: ``"density"``, ``"expval"`` or ``"probs"``. Any other value is
            sized as a statevector (``"state"``).
        batch_size: Number of batch elements.
        dim: Hilbert-space dimension (``2**n_qubits``).
        elem: Bytes per complex element.
        real_elem: Bytes per real element.
        n_obs: Number of observables; only used for ``"expval"``, where it
            is treated as at least 1.

    Returns:
        ``batch_size`` times the byte size of one batch element's result.
    """
    if type == "density":
        per_item = dim * dim * elem  # (dim, dim) complex matrix
    elif type == "expval":
        per_item = max(n_obs, 1) * real_elem  # one real per observable
    elif type == "probs":
        per_item = dim * real_elem  # one real per basis state
    else:
        per_item = dim * elem  # state: one complex amplitude per basis state
    return batch_size * per_item

52 

53 

def estimate_peak_bytes(
    n_qubits: int,
    batch_size: int,
    type: str,
    use_density: bool,
    n_obs: int = 0,
    n_ops: int = 1,
) -> int:
    """Estimate the peak memory, in bytes, of one batched simulation.

    The simulation phase is modelled as the sum of the following:

    - **Statevector.** The batched statevector of ``batch_size * 2**n_qubits``
      complex values. It is counted even when density-matrix simulation is
      used.
    - **Evolution buffers (density-matrix simulation only).** Two
      ``(B, dim, dim)`` buffers per operation, because ``apply_to_density``
      contracts both U and U* against rho.
    - **Gate temporaries.** One einsum buffer per operation, of shape
      ``(B, dim)`` (or ``(B, dim, dim)`` with density simulation). XLA cannot
      always free these between consecutive ops, so they scale with *n_ops*.

    The returned measurement array (state / probs / density / expval) is
    sized separately. The larger of the two phases is taken and then
    multiplied by a 1.5x safety margin for XLA compiler temporaries and
    padding.

    Observable matrices and the pure-circuit outer-product temporary are
    deliberately not counted. They live inside the compiled function, and
    XLA manages their lifetime.

    Element widths follow ``jax.config.x64_enabled``:

    - x64 on: 16 / 8 bytes (complex128 / float64).
    - x64 off, JAX's default: 8 / 4 bytes (complex64 / float32).

    Only Python integer arithmetic is performed, so the call is essentially
    free.

    Args:
        n_qubits: Number of qubits in the circuit.
        batch_size: Number of batch elements.
        type: Measurement type (``"state"``, ``"probs"``, ``"expval"``,
            ``"density"``).
        use_density: Whether density-matrix simulation is used.
        n_obs: Number of observables (relevant for ``"expval"``).
        n_ops: Number of operations on the circuit tape. Values below 1 are
            treated as 1, which reproduces the single-buffer estimate.

    Returns:
        Estimated peak memory in bytes.
    """
    dim = 2**n_qubits
    complex_bytes, real_bytes = _element_sizes()
    ops = max(int(n_ops), 1)

    # One batched simulation buffer: (B, dim, dim) for density, else (B, dim).
    if use_density:
        buffer_bytes = batch_size * dim * dim * complex_bytes
    else:
        buffer_bytes = batch_size * dim * complex_bytes

    statevector_bytes = batch_size * dim * complex_bytes
    gate_bytes = ops * buffer_bytes
    # Density evolution keeps two extra (B, dim, dim) buffers per operation;
    # the pure-statevector path has no separate intermediate.
    evolution_bytes = 2 * gate_bytes if use_density else 0
    simulation_phase = statevector_bytes + evolution_bytes + gate_bytes

    # Only the reduced output survives once the simulation has finished.
    output_phase = _output_bytes(
        type, batch_size, dim, complex_bytes, real_bytes, n_obs
    )

    return int(max(simulation_phase, output_phase) * 1.5)

151 

152 

def available_memory_bytes() -> int:
    """Return available system memory in bytes.

    Sources are tried in order, and the first one that yields a value is
    used:

    1. ``psutil.virtual_memory().available`` (Linux, macOS, Windows).
    2. The ``MemAvailable`` entry of ``/proc/meminfo`` (Linux only).
    3. A conservative 4 GiB default.

    Returns:
        Available memory in bytes.
    """
    mem = None

    # Primary: psutil. Any failure (package missing, platform error) is
    # non-fatal and simply moves on to the next source.
    try:
        import psutil

        mem = int(psutil.virtual_memory().available)
    except Exception:
        log.debug("psutil not available. Fallback to /proc/meminfo")

    # Fallback: /proc/meminfo. This is consulted only when psutil produced no
    # value, so a successful psutil reading is never overwritten. It also
    # avoids logging a spurious failure on hosts without /proc.
    if mem is None:
        try:
            with open("/proc/meminfo", "r") as f:
                for line in f:
                    if line.startswith("MemAvailable:"):
                        mem = int(line.split()[1]) * 1024  # kB -> bytes
                        break
        except (OSError, ValueError, IndexError):
            mem = None  # unreadable or malformed; use the default below

    # Last resort. This also covers a /proc/meminfo without a MemAvailable line.
    if mem is None:
        log.debug("Failed to read /proc/meminfo. Falling back to 4 GiB")
        mem = 4 * 1024**3

    log.debug("Available memory: %.1f GB", mem / 1024**3)
    return mem

184 

185 

def compute_chunk_size(
    n_qubits: int,
    batch_size: int,
    type: str,
    use_density: bool,
    n_obs: int = 0,
    memory_fraction: float = 0.8,
    n_ops: int = 1,
) -> int:
    """Determine the largest chunk size that fits in available memory.

    If the full batch fits within ``memory_fraction`` of available RAM,
    *batch_size* is returned (no chunking). Otherwise the result is the
    largest chunk size for which one chunk's computation **plus** the full
    output accumulator fits in that budget.

    The output accumulator is the final ``(batch_size, ...)`` array that
    holds all results. It coexists with every chunk's computation, so its
    size is subtracted from the budget before chunks are sized. If even a
    single element does not fit in what remains, a warning is logged and 1 is
    returned.

    The minimum chunk size is 1 (fully serialised).

    Args:
        n_qubits: Number of qubits.
        batch_size: Total batch size.
        type: Measurement type.
        use_density: Whether density-matrix simulation is used.
        n_obs: Number of observables.
        memory_fraction: Fraction of available memory to target
            (default 0.8 = 80%).
        n_ops: Number of operations on the recorded tape. Forwarded
            to :func:`estimate_peak_bytes`. Defaults to 1.

    Returns:
        Chunk size (number of batch elements per sub-batch).
    """
    avail = int(available_memory_bytes() * memory_fraction)
    full_est = estimate_peak_bytes(
        n_qubits, batch_size, type, use_density, n_obs, n_ops=n_ops
    )

    if full_est <= avail:
        return batch_size  # everything fits — no chunking

    # The output accumulator must coexist with each chunk's computation, so
    # reserve its size before sizing chunks.
    dim = 2**n_qubits
    elem, real_elem = _element_sizes()
    accum_bytes = _output_bytes(type, batch_size, dim, elem, real_elem, n_obs)
    # Keep the budget strictly positive. If the accumulator alone exhausts
    # memory, the clamp below still yields chunk_size=1.
    avail_for_chunks = max(avail - accum_bytes, elem)

    # Per-element cost: the memory for computing a single batch element.
    per_elem = estimate_peak_bytes(n_qubits, 1, type, use_density, n_obs, n_ops=n_ops)

    if per_elem <= 0:
        return batch_size

    chunk = max(1, min(avail_for_chunks // per_elem, batch_size))

    # Compare against the budget that is actually left for chunk computation.
    # Comparing against the whole budget would miss cases where the
    # accumulator leaves too little room.
    if chunk == 1 and per_elem > avail_for_chunks:
        log.warning(
            f"A single batch element requires ~{per_elem / 1024**3:.2f} GB "
            f"but only ~{avail_for_chunks / 1024**3:.2f} GB is available "
            f"after reserving ~{accum_bytes / 1024**3:.2f} GB for the output. "
            f"Proceeding with chunk_size=1 but OOM is possible. "
            f"Consider reducing n_qubits or switching measurement type."
        )

    log.info(
        f"Computation requires ~{full_est / 1024**3:.2f} GB which "
        f"does not fit in ~{avail / 1024**3:.2f} GB. "
        f"Using chunk size {chunk}."
    )
    return chunk

262 

263 

def execute_chunked(
    batched_fn: Callable,
    args: tuple,
    in_axes: Tuple,
    batch_size: int,
    chunk_size: int,
    clear_caches: bool = False,
) -> jnp.ndarray:
    """Execute a vmapped function in memory-safe chunks.

    The batch dimension is split into consecutive sub-batches of at most
    *chunk_size* elements. Each one is run through *batched_fn*, and the
    per-chunk results are concatenated along axis 0 once at the end.

    Results are collected and concatenated once instead of being written
    into a pre-allocated buffer with ``.at[...].set``. Outside ``jit`` that
    update copies the entire buffer for every chunk.

    Memory use has two phases:

    - **During the loop.** The retained chunk outputs (together never larger
      than the final output) coexist with one chunk's computation.
    - **During the final concatenation.** The chunk outputs and the
      concatenated result are briefly alive at the same time.

    Assumes *batched_fn* returns results with the batch along axis 0.

    Args:
        batched_fn: A JIT-compiled, vmapped callable.
        args: Full-batch arguments (before slicing).
        in_axes: Per-argument batch axis specification. Arguments whose axis
            is ``None`` are passed unsliced to every chunk.
        batch_size: Total number of batch elements.
        chunk_size: Maximum elements per chunk. Values below 1 are treated
            as 1.
        clear_caches: When ``True``, call ``jax.clear_caches()`` after each
            chunk to release device buffers. Disabled by default because it
            forces full recompilation of *batched_fn* on every subsequent
            chunk.

    Returns:
        Batched results with the same leading dimension as the full batch.
        A single-chunk result is returned as-is. Pytree outputs are
        concatenated leaf-wise. Returns ``None`` if *batch_size* is not
        positive, since then no chunk is run.
    """
    # A non-positive chunk size would make the chunk count divide by zero
    # (or silently run nothing), so process one element at a time instead.
    chunk_size = max(int(chunk_size), 1)
    n_chunks = (batch_size + chunk_size - 1) // chunk_size
    log.debug(
        f"Memory-aware chunking: splitting batch of {batch_size} into "
        f"{n_chunks} chunks of <={chunk_size} elements."
    )

    results = []
    for start in range(0, batch_size, chunk_size):
        end = min(start + chunk_size, batch_size)

        # Slice each batched argument along its own batch axis.
        chunk_args = tuple(
            (
                jax.lax.dynamic_slice_in_dim(a, start, end - start, axis=ax)
                if ax is not None
                else a
            )
            for a, ax in zip(args, in_axes)
        )
        results.append(batched_fn(*chunk_args))

        # Drop the sliced inputs before the next chunk is sliced and run.
        del chunk_args
        # Optional and off by default: clearing caches forces recompilation
        # of ``batched_fn`` for every subsequent chunk.
        # NOTE(review): the module-level CLEAR_CACHES_BETWEEN_CHUNKS flag is
        # not read here; presumably callers forward it — confirm.
        if clear_caches:
            jax.clear_caches()

    if not results:
        return None  # batch_size <= 0: nothing was computed
    if len(results) == 1:
        return results[0]  # a single chunk is already the full result
    return jax.tree_util.tree_map(
        lambda *parts: jnp.concatenate(parts, axis=0), *results
    )