Coverage for qml_essentials / simulation.py: 95%

103 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-06-11 15:51 +0000

1"""Pure simulation and measurement kernels for :class:`~qml_essentials.script.Script`. 

2 

3These functions are stateless: they take a recorded tape (a list of 

4:class:`~qml_essentials.operations.Operation`) plus measurement parameters and 

5return JAX arrays. Keeping them as module-level free functions (rather than 

6static methods on ``Script``) makes the simulation engine independently testable 

7and keeps ``script.py`` focused on orchestration. 

8""" 

9 

10from typing import List, Optional 

11 

12import jax 

13import jax.numpy as jnp 

14import numpy as np # needed to prevent jitting some operations 

15 

16from qml_essentials.operations import ( 

17 Barrier, 

18 Operation, 

19 KrausChannel, 

20 _einsum_subscript, 

21 _cdtype, 

22) 

23 

24 

25def infer_n_qubits(ops: List[Operation], obs: List[Operation]) -> int: 

26 """Infer the number of qubits from a list of operations and observables. 

27 

28 Args: 

29 ops: Gate operations recorded on the tape. 

30 obs: Observable operations used for measurement. 

31 

32 Returns: 

33 The smallest number of qubits that covers all wire indices, i.e. 

34 ``max(all_wires) + 1`` (at least 1). 

35 """ 

36 all_wires: set[int] = set() 

37 for op in ops + obs: 

38 all_wires.update(op.wires) 

39 return max(all_wires) + 1 if all_wires else 1 

40 

41 

42def uses_density(tape: List[Operation], type: str) -> bool: 

43 """Return whether density-matrix simulation is required. 

44 

45 Density-matrix simulation is needed when the caller explicitly requests the 

46 ``"density"`` measurement type, or when the tape contains a noise channel 

47 (a :class:`~qml_essentials.operations.KrausChannel`). 

48 

49 Args: 

50 tape: Ordered list of gate/channel operations. 

51 type: Requested measurement type. 

52 

53 Returns: 

54 ``True`` if density-matrix simulation must be used. 

55 """ 

56 has_noise = any(isinstance(op, KrausChannel) for op in tape) 

57 return type == "density" or has_noise 

58 

59 

60def _stack_obs(obs: List[Operation], n_qubits: int) -> jnp.ndarray: 

61 """Stack lifted observable matrices into a single ``(n_obs, dim, dim)`` array.""" 

62 return jnp.stack([ob.lifted_matrix(n_qubits) for ob in obs], axis=0) 

63 

64 

def simulate_pure(tape: List[Operation], n_qubits: int) -> jnp.ndarray:
    """Statevector simulation kernel.

    Evolves the all-zeros basis state |0…0⟩ through every operation on
    *tape* (``Barrier`` entries are skipped) by contracting each gate tensor
    into the state with ``jnp.einsum``. During the gate loop the state is
    kept as a rank-``n_qubits`` tensor of shape ``(2,) * n_qubits``; it is
    flattened only once at the end.

    Gate tensors and einsum subscript strings are collected in a first pass,
    so the contraction loop itself only calls ``jnp.einsum``.

    Args:
        tape: Ordered list of gate operations to apply.
        n_qubits: Total number of qubits.

    Returns:
        Statevector of shape ``(2**n_qubits,)``.
    """
    # First pass: one (gate tensor, einsum subscript) pair per non-barrier op.
    steps = [
        (
            op._gate_tensor(len(op.wires)),
            _einsum_subscript(n_qubits, len(op.wires), tuple(op.wires)),
        )
        for op in tape
        if not isinstance(op, Barrier)
    ]

    # |0…0⟩ directly in tensor form: amplitude 1 at the all-zeros index.
    psi = jnp.zeros((2,) * n_qubits, dtype=_cdtype()).at[(0,) * n_qubits].set(1.0)
    for tensor, subscript in steps:
        psi = jnp.einsum(subscript, tensor, psi)
    return psi.reshape(2**n_qubits)

105 

106 

def simulate_mixed(tape: List[Operation], n_qubits: int) -> jnp.ndarray:
    """Density-matrix simulation kernel.

    Starts from ρ = |0…0⟩⟨0…0| and lets every operation on *tape* update the
    density matrix through its
    :meth:`~qml_essentials.operations.Operation.apply_to_density` method
    (ρ → UρU† for unitaries, ρ → Σ_k K_k ρ K_k† for Kraus channels).
    This is the path used for noisy circuits.

    Args:
        tape: Ordered list of gate or channel operations to apply.
        n_qubits: Total number of qubits.

    Returns:
        Density matrix of shape ``(2**n_qubits, 2**n_qubits)``.
    """
    size = 2**n_qubits
    rho = jnp.zeros((size, size), dtype=_cdtype())
    rho = rho.at[0, 0].set(1.0)
    for operation in tape:
        rho = operation.apply_to_density(rho, n_qubits)
    return rho

129 

130 

def simulate_and_measure(
    tape: List[Operation],
    n_qubits: int,
    type: str,
    obs: List[Operation],
    use_density: bool,
    shots: Optional[int] = None,
    key: Optional[jnp.ndarray] = None,
) -> jnp.ndarray:
    """Simulate *tape* and measure the result in one call.

    Selects statevector or density-matrix simulation from *use_density* and
    then dispatches to the matching measurement routine, so single-sample and
    batched callers share one code path.

    If *shots* is given and *type* is ``"probs"`` or ``"expval"``, the exact
    computational-basis distribution is computed first and then handed to
    :func:`sample_shots`, which draws *shots* samples to produce a finite-shot
    estimate. For other types *shots* is ignored.

    When density output is requested but the tape holds no
    :class:`~qml_essentials.operations.KrausChannel`, the circuit is pure.
    In that case the statevector is simulated and ρ = |ψ⟩⟨ψ| is formed once
    at the end, instead of evolving the full ``2^n × 2^n`` matrix gate by gate.

    Args:
        tape: Ordered list of gate/channel operations to apply.
        n_qubits: Total number of qubits.
        type: Measurement type (``"state"``/``"probs"``/``"expval"``/
            ``"density"``).
        obs: Observables for ``"expval"`` measurements.
        use_density: If ``True``, use density-matrix simulation.
        shots: Number of measurement shots; ``None`` (default) returns exact
            analytic results.
        key: JAX PRNG key for shot sampling; required when *shots* is set.

    Returns:
        Measurement result (shape depends on *type*).
    """
    sampled = shots is not None and type in ("probs", "expval")

    if not use_density:
        psi = simulate_pure(tape, n_qubits)
        if sampled:
            return sample_shots(jnp.abs(psi) ** 2, n_qubits, type, obs, shots, key)
        return measure_state(psi, n_qubits, type, obs)

    if any(isinstance(op, KrausChannel) for op in tape):
        # Noise channels require genuine density-matrix evolution.
        rho = simulate_mixed(tape, n_qubits)
    else:
        # Pure circuit: evolve the statevector and build the projector once.
        psi = simulate_pure(tape, n_qubits)
        rho = jnp.outer(psi, jnp.conj(psi))

    if sampled:
        return sample_shots(jnp.real(jnp.diag(rho)), n_qubits, type, obs, shots, key)
    return measure_density(rho, n_qubits, type, obs)

202 

203 

204def measure_state( 

205 state: jnp.ndarray, 

206 n_qubits: int, 

207 type: str, 

208 obs: List[Operation], 

209) -> jnp.ndarray: 

210 """Apply the requested measurement to a pure statevector. 

211 

212 Args: 

213 state: Statevector of shape ``(2**n_qubits,)``. 

214 n_qubits: Total number of qubits. 

215 type: Measurement type — one of ``"state"``, ``"probs"``, 

216 or ``"expval"``. 

217 obs: Observables used when *type* is ``"expval"``. 

218 

219 Returns: 

220 Measurement result whose shape depends on *type*: 

221 

222 - ``"state"`` -> ``(2**n_qubits,)`` 

223 - ``"probs"`` -> ``(2**n_qubits,)`` 

224 - ``"expval"`` -> ``(len(obs),)`` 

225 

226 Raises: 

227 ValueError: If *type* is not a recognised measurement type. 

228 """ 

229 if type == "state": 

230 return state 

231 

232 if type == "probs": 

233 return jnp.abs(state) ** 2 

234 

235 if type == "expval": 

236 # Fast path for single-qubit diagonal observables (PauliZ, etc.) 

237 # where d0, d1 are the diagonal elements of the 2x2 observable. 

238 # This replaces n_obs tensor contractions with a single |ψ|² 

239 # and n_obs reductions over the probability vector. 

240 

241 def _is_single_qubit_diag(ob): 

242 m = ob.__class__._matrix 

243 if m is None or len(ob.wires) != 1: 

244 return False 

245 # Convert to NumPy to ensure concrete boolean evaluation 

246 m_np = np.asarray(m) 

247 return np.allclose(m_np - np.diag(np.diag(m_np)), 0) 

248 

249 all_single_qubit_diag = all(_is_single_qubit_diag(ob) for ob in obs) 

250 

251 if all_single_qubit_diag: 

252 probs = jnp.abs(state) ** 2 

253 psi_t = probs.reshape((2,) * n_qubits) 

254 results = [] 

255 for ob in obs: 

256 q = ob.wires[0] 

257 d = np.real(np.diag(np.asarray(ob.__class__._matrix))) 

258 # Sum probabilities over all axes except qubit q 

259 p_q = jnp.sum(psi_t, axis=tuple(i for i in range(n_qubits) if i != q)) 

260 results.append(d[0] * p_q[0] + d[1] * p_q[1]) 

261 return jnp.array(results) 

262 

263 # General path: stack observable matrices and use a single 

264 # batched matmul instead of a Python loop of tensor contractions. 

265 # O_states[i] = obs[i] |ψ⟩, then ⟨O_i⟩ = Re(⟨ψ|O_states[i]⟩). 

266 obs_mats = _stack_obs(obs, n_qubits) # (n_obs, dim, dim) 

267 # Batched matvec: (n_obs, dim, dim) @ (dim,) -> (n_obs, dim) 

268 O_states = jnp.einsum("oij,j->oi", obs_mats, state) 

269 return jnp.real(jnp.einsum("i,oi->o", jnp.conj(state), O_states)) 

270 

271 raise ValueError(f"Unknown measurement type: {type!r}") 

272 

273 

274def measure_density( 

275 rho: jnp.ndarray, 

276 n_qubits: int, 

277 type: str, 

278 obs: List[Operation], 

279) -> jnp.ndarray: 

280 """Apply the requested measurement to a density matrix. 

281 

282 Args: 

283 rho: Density matrix of shape ``(2**n_qubits, 2**n_qubits)``. 

284 n_qubits: Total number of qubits. 

285 type: Measurement type — one of ``"density"``, ``"probs"``, 

286 or ``"expval"``. 

287 obs: Observables used when *type* is ``"expval"``. 

288 

289 Returns: 

290 Measurement result whose shape depends on *type*: 

291 

292 - ``"density"`` -> ``(2**n_qubits, 2**n_qubits)`` 

293 - ``"probs"`` -> ``(2**n_qubits,)`` 

294 - ``"expval"`` -> ``(len(obs),)`` 

295 

296 Raises: 

297 ValueError: If *type* is ``"state"`` (not valid for mixed circuits) 

298 or another unrecognised type. 

299 """ 

300 if type == "density": 

301 return rho 

302 

303 if type == "probs": 

304 return jnp.real(jnp.diag(rho)) 

305 

306 if type == "expval": 

307 # Tr(O \\rho ) = \\Sigma_ij O_ij \\rho _ji 

308 # Stack all observable matrices and compute all traces in one 

309 # batched operation. 

310 obs_mats = _stack_obs(obs, n_qubits) # (n_obs, dim, dim) 

311 # einsum "oij,ji->o" computes Tr(O_o @ \\rho ) for each observable 

312 return jnp.real(jnp.einsum("oij,ji->o", obs_mats, rho)) 

313 

314 raise ValueError( 

315 "Measurement type 'state' is not defined for mixed (noisy) circuits. " 

316 "Use 'density' instead." 

317 ) 

318 

319 

320def sample_shots( 

321 probs: jnp.ndarray, 

322 n_qubits: int, 

323 type: str, 

324 obs: List[Operation], 

325 shots: int, 

326 key: jnp.ndarray, 

327) -> jnp.ndarray: 

328 """Convert exact probabilities into shot-sampled results. 

329 

330 Draws *shots* samples from the computational-basis probability 

331 distribution and returns either estimated probabilities or 

332 shot-based expectation values. 

333 

334 Args: 

335 probs: Exact probability vector of shape ``(2**n_qubits,)``. 

336 n_qubits: Total number of qubits. 

337 type: Measurement type — ``"probs"`` or ``"expval"``. 

338 obs: Observables used when *type* is ``"expval"``. 

339 shots: Number of measurement shots. 

340 key: JAX PRNG key for sampling. 

341 

342 Returns: 

343 Shot-sampled measurement result: 

344 

345 - ``"probs"`` → ``(2**n_qubits,)`` estimated probabilities. 

346 - ``"expval"`` → ``(len(obs),)`` estimated expectation values. 

347 """ 

348 dim = 2**n_qubits 

349 

350 # Draw `shots` samples from the computational basis. 

351 # Each sample is an integer in [0, dim) representing a basis state. 

352 samples = jax.random.choice(key, dim, shape=(shots,), p=probs) 

353 

354 # Build a histogram of counts for each basis state. 

355 counts = jnp.zeros(dim, dtype=jnp.int32) 

356 counts = counts.at[samples].add(1) 

357 estimated_probs = counts / shots 

358 

359 if type == "probs": 

360 return estimated_probs 

361 

362 if type == "expval": 

363 # For each observable, compute O from the shot-sampled 

364 # probabilities. For diagonal observables this is exact; 

365 # for general observables we use Tr(O · diag(estimated_probs)). 

366 results = [] 

367 for ob in obs: 

368 O_mat = ob.lifted_matrix(n_qubits) 

369 # diagonal approximation from 

370 # computational basis measurements, which is exact for 

371 # diagonal observables like PauliZ) 

372 results.append(jnp.real(jnp.dot(jnp.diag(O_mat), estimated_probs))) 

373 return jnp.array(results) 

374 

375 raise ValueError( 

376 f"Shot simulation is only supported for 'probs' and 'expval', got {type!r}." 

377 )