Coverage for qml_essentials / entanglement.py: 92%

237 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-04-10 10:29 +0000

1from typing import Optional, Any, Tuple 

2import jax 

3import jax.numpy as jnp 

4import numpy as np 

5 

6from qml_essentials import yaqsi as ys 

7from qml_essentials import operations as op 

8from qml_essentials.math import logm_v 

9from qml_essentials.model import Model 

10import logging 

11 

12from qml_essentials.operations import Hermitian 

13 

14log = logging.getLogger(__name__) 

15 

16 

17class Entanglement: 

@staticmethod
def meyer_wallach(
    model: Model,
    n_samples: Optional[int],
    random_key: Optional[jax.random.PRNGKey] = None,
    scale: bool = False,
    **kwargs: Any,
) -> float:
    """
    Calculates the entangling capacity of a given quantum circuit
    using Meyer-Wallach measure.

    Args:
        model (Model): The quantum circuit model.
        n_samples (Optional[int]): Number of samples per qubit.
            If None or <= 0, the current parameters of the model are used.
        random_key (Optional[jax.random.PRNGKey]): JAX random key for
            parameter initialization. If None, uses the model's internal
            random key.
        scale (bool): Whether to scale the number of samples by
            ``2 ** n_qubits``. Ignored when ``n_samples`` is None.
        kwargs (Any): Additional keyword arguments for the model function.

    Returns:
        float: Mean Meyer-Wallach measure over all evaluated samples.
    """
    if "noise_params" in kwargs:
        log.warning(
            "Meyer-Wallach measure not suitable for noisy circuits. "
            "Consider 'concentratable entanglement' instead."
        )

    # Only scale an actual sample count; multiplying ``None`` would raise.
    # A plain Python int keeps ``repeat`` a static value instead of a
    # JAX array.
    if scale and n_samples is not None:
        n_samples = 2**model.n_qubits * n_samples

    if n_samples is not None and n_samples > 0:
        random_key = model.initialize_params(random_key, repeat=n_samples)

    # implicitly set input to none in case it's not needed
    kwargs.setdefault("inputs", None)
    # explicitly set execution type because everything else won't work
    rhos = model(execution_type="density", **kwargs).reshape(
        -1, 2**model.n_qubits, 2**model.n_qubits
    )

    ent = Entanglement._compute_meyer_wallach_meas(rhos, model.n_qubits)

    log.debug(f"Variance of measure: {ent.var()}")

    return ent.mean()

68 

@staticmethod
def _compute_meyer_wallach_meas(rhos: jnp.ndarray, n_qubits: int) -> jnp.ndarray:
    """
    Evaluate the Meyer-Wallach measure for a batch of density matrices.

    Args:
        rhos (jnp.ndarray): Density matrices of the sampled quantum states.
            The shape is (B_s, 2^n, 2^n), where B_s is the number of samples
            (batch) and n the number of qubits
        n_qubits (int): The number of qubits

    Returns:
        jnp.ndarray: Entangling capability for each sample, array with
            shape (B_s,)
    """
    all_wires = list(range(n_qubits))

    def _single_state_measure(rho):
        purity_sum = 0
        for traced_out in all_wires:
            # Formula 6 in https://doi.org/10.48550/arXiv.quant-ph/0305094
            # Drop one qubit from the list of wires that are kept
            remaining = all_wires[:traced_out] + all_wires[traced_out + 1 :]
            reduced = ys.partial_trace(rho, n_qubits, remaining)
            # Tr(reduced^2); only the real part is accumulated
            squared = (reduced @ reduced).real
            purity_sum = purity_sum + jnp.trace(squared, axis1=-2, axis2=-1)

        # invert the averaged value and rescale by a factor of two
        averaged = purity_sum / n_qubits
        return 2 * (1 - averaged)

    batched_measure = jax.vmap(_single_state_measure)
    return batched_measure(rhos)

103 

@staticmethod
def bell_measurements(
    model: Model,
    n_samples: int,
    random_key: Optional[jax.random.PRNGKey] = None,
    scale: bool = False,
    **kwargs: Any,
) -> float:
    """
    Compute the Bell measurement for a given model.

    Constructs a ``2 * n_qubits`` circuit that prepares two copies of
    the model state (on disjoint qubit registers), applies CNOTs and
    Hadamards, and evaluates the probabilities of each qubit pair
    ``(q, q + n_qubits)``.

    Args:
        model (Model): The quantum circuit model.
        n_samples (int): The number of samples to compute the measure for.
            If None or <= 0, the current parameters of the model are used.
        random_key (Optional[jax.random.PRNGKey]): JAX random key for
            parameter initialization. If None, uses the model's internal
            random key.
        scale (bool): Whether to scale the number of samples
            according to the number of qubits. Ignored when
            ``n_samples`` is None.
        **kwargs (Any): Additional keyword arguments for the model function.

    Returns:
        float: The Bell measurement value, clamped to [0.0, 1.0].
    """
    if "noise_params" in kwargs:
        log.warning(
            "Bell Measurements not suitable for noisy circuits. "
            "Consider 'concentratable entanglement' instead."
        )

    n = model.n_qubits

    # Only scale an actual sample count; multiplying ``None`` would raise.
    if scale and n_samples is not None:
        n_samples = 2**n * n_samples

    def _bell_circuit(params, inputs, pulse_params=None, random_key=None, **kw):
        """Bell measurement circuit on 2*n qubits."""
        from qml_essentials.tape import copy_to_tape

        def vari():
            model._variational(
                params,
                inputs,
                pulse_params=pulse_params,
                random_key=random_key,
                **kw
            )

        # First copy on wires 0..n-1
        vari()
        # Second copy on wires n..2n-1
        copy_to_tape(vari, offset=n)

        # Bell measurement: CNOT + H
        for q in range(n):
            op.CX(wires=[q, q + n])
            op.H(wires=q)

    bell_script = ys.Script(f=_bell_circuit, n_qubits=2 * n)

    if n_samples is not None and n_samples > 0:
        random_key = model.initialize_params(random_key, repeat=n_samples)
        params = model.params
    else:
        if len(model.params.shape) <= 2:
            # Add a leading batch axis locally; the model itself is untouched
            params = model.params.reshape(1, *model.params.shape)
        else:
            log.info(f"Using sample size of model params: {model.params.shape[0]}")
            params = model.params

    # From here on the batch size is whatever the parameters provide
    n_samples = params.shape[0]
    inputs = model._inputs_validation(kwargs.get("inputs", None))

    # Execute: vmap over batch dimension of params (axis 0)
    if n_samples > 1:
        from qml_essentials.utils import safe_random_split

        random_keys = safe_random_split(random_key, num=n_samples)
        result = bell_script.execute(
            type="probs",
            args=(params, inputs, model.pulse_params, random_keys),
            kwargs=kwargs,
            in_axes=(0, None, None, 0),
        )
    else:
        result = bell_script.execute(
            type="probs",
            args=(params, inputs, model.pulse_params, random_key),
            kwargs=kwargs,
        )

    # Marginalize: for each qubit q, keep wires [q, q+n] from the 2n-qubit probs
    # The last probability in each pair gives P(|11>) for that qubit pair
    per_qubit = [
        ys.marginalize_probs(result, 2 * n, [q, q + n]) for q in range(n)
    ]
    # per_qubit[q] has shape (n_samples, 4) or (4,)
    exp = jnp.stack(per_qubit, axis=-2)  # (..., n, 4)
    exp = 1 - 2 * exp[..., -1]  # (..., n)

    if not jnp.isclose(jnp.sum(exp.imag), 0, atol=1e-6):
        log.warning("Imaginary part of probabilities detected")
        exp = jnp.abs(exp)

    measure = 2 * (1 - exp.mean(axis=0))
    # Clamp so that numerical noise cannot push the result out of [0, 1]
    entangling_capability = min(max(float(measure.mean()), 0.0), 1.0)
    log.debug(f"Variance of measure: {measure.var()}")

    return entangling_capability

218 

@staticmethod
def relative_entropy(
    model: Model,
    n_samples: int,
    n_sigmas: int,
    random_key: Optional[jax.random.PRNGKey] = None,
    scale: bool = False,
    **kwargs: Any,
) -> float:
    """
    Calculates the relative entropy of entanglement of a given quantum
    circuit. This measure is also applicable to mixed states, although it
    might not be fully accurate in this simplified case.

    The relative entropy of entanglement is generally defined as the
    smallest relative entropy from the state in question to the set of
    separable states. However, as computing the nearest separable state
    is NP-hard, we select n_sigmas random separable states to compute the
    distance to, which does not necessarily include the nearest one.
    Thus, this measure of entanglement presents an upper limit of
    entanglement.

    As the relative entropy is not necessarily between zero and one, this
    function also normalises by the relative entropy to the GHZ state.

    Args:
        model (Model): The quantum circuit model.
        n_samples (int): Number of samples per qubit.
            If None or <= 0, the current parameters of the model are used.
        n_sigmas (int): Number of random separable pure states to compare against.
        random_key (Optional[jax.random.PRNGKey]): JAX random key for
            parameter initialization. If None, uses the model's internal
            random key.
        scale (bool): Whether to scale the number of samples (and sigmas)
            by ``2 ** n_qubits``.
        kwargs (Any): Additional keyword arguments for the model function.

    Returns:
        float: Mean over all samples of the smallest normalised relative
            entropy found among the sampled separable states.
    """
    # Plain Python int so that the scaled counts stay valid array shapes
    dim = 2**model.n_qubits
    if scale:
        # Only scale an actual sample count; multiplying ``None`` would raise
        if n_samples is not None:
            n_samples = dim * n_samples
        n_sigmas = dim * n_sigmas

    if random_key is None:
        random_key = model.random_key

    # Random separable states
    log_sigmas = sample_random_separable_states(
        model.n_qubits, n_samples=n_sigmas, random_key=random_key, take_log=True
    )

    # Use a fresh key for the model parameters
    random_key, _ = jax.random.split(random_key)

    if n_samples is not None and n_samples > 0:
        model.initialize_params(random_key, repeat=n_samples)
    else:
        if len(model.params.shape) <= 2:
            # Add a leading batch axis to the model's own parameters
            model.params = model.params.reshape(1, *model.params.shape)
        else:
            log.info(f"Using sample size of model params: {model.params.shape[0]}")

    rhos, log_rhos = Entanglement._compute_log_density(model, **kwargs)

    # One row per separable state, one column per model sample
    rel_entropies = jnp.zeros((n_sigmas, model.params.shape[0]))

    for i, log_sigma in enumerate(log_sigmas):
        rel_entropies = rel_entropies.at[i].set(
            Entanglement._compute_rel_entropies(rhos, log_rhos, log_sigma)
        )

    # Entropy of GHZ states should be maximal
    ghz_model = Model(model.n_qubits, 1, "GHZ", data_reupload=False)
    rho_ghz, log_rho_ghz = Entanglement._compute_log_density(ghz_model, **kwargs)
    ghz_entropies = Entanglement._compute_rel_entropies(
        rho_ghz, log_rho_ghz, log_sigmas
    )

    normalised_entropies = rel_entropies / ghz_entropies

    # For every sample keep the closest (minimal) separable state ...
    entangling_capability = normalised_entropies.T.min(axis=1)
    log.debug(f"Variance of measure: {entangling_capability.var()}")

    # ... and average over all samples
    return entangling_capability.mean()

304 

@staticmethod
def _compute_log_density(model: Model, **kwargs) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """
    Run a model in density mode and take the base-2 matrix logarithm.

    Args:
        model (Model): The model for which to compute the density matrix.
        **kwargs: Forwarded to the model call; ``inputs`` defaults to None.

    Returns:
        Tuple[jnp.ndarray, jnp.ndarray]:
            - jnp.ndarray: density matrices, reshaped to (-1, 2^n, 2^n).
            - jnp.ndarray: matrix logarithm of the density matrices,
              divided by log(2).
    """
    # fall back to "no inputs" unless the caller supplied some
    kwargs.setdefault("inputs", None)

    # the density execution type is mandatory for everything below
    raw_densities = model(execution_type="density", **kwargs)
    hilbert_dim = 2**model.n_qubits
    densities = raw_densities.reshape(-1, hilbert_dim, hilbert_dim)

    log_densities = logm_v(densities) / jnp.log(2)
    return densities, log_densities

325 

@staticmethod
def _compute_rel_entropies(
    rhos: jnp.ndarray,
    log_rhos: jnp.ndarray,
    log_sigmas: jnp.ndarray,
) -> jnp.ndarray:
    """
    Compute ``|Tr(rho (log rho - log sigma))|`` for every rho/sigma pair.

    Args:
        rhos (jnp.ndarray): Density matrix result of the circuit, has shape
            (R, 2^n, 2^n), with the batch size R and number of qubits n
        log_rhos (jnp.ndarray): Corresponding logarithm of the density
            matrix, has shape (R, 2^n, 2^n).
        log_sigmas (jnp.ndarray): Logarithm of the separable state(s),
            has shape (2^n, 2^n) if it's a single sigma or (S, 2^n, 2^n),
            with the batch size S (number of sigmas).

    Returns:
        jnp.ndarray: Relative entropies. Shape (S, R) when more than one
            sigma is given, otherwise shape (R,).
    """
    # Treat a single sigma as a batch of one so both cases share one path
    if len(log_sigmas.shape) != 3:
        log_sigmas = log_sigmas[jnp.newaxis, ...]
    n_sigmas = log_sigmas.shape[0]

    # Pair every rho with every sigma via broadcasting: (S, R, 2^n, 2^n).
    # Pairing by a shared leading axis only works for R == 1.
    log_diff = log_rhos[jnp.newaxis, ...] - log_sigmas[:, jnp.newaxis, ...]

    # Tr(A @ B) = sum_ij A_ij * B_ji, evaluated without forming A @ B
    rel_entropies = jnp.abs(jnp.einsum("rij,srji->sr", rhos, log_diff))

    if n_sigmas > 1:
        return rel_entropies
    # Single sigma: drop the sigma axis to keep the (R,) return shape
    return rel_entropies[0]

369 

@staticmethod
def entanglement_of_formation(
    model: Model,
    n_samples: int,
    random_key: Optional[jax.random.PRNGKey] = None,
    scale: bool = False,
    always_decompose: bool = False,
    **kwargs: Any,
) -> float:
    """
    This function implements the entanglement of formation for mixed
    quantum systems.
    In that a mixed state gets decomposed into pure states with respective
    probabilities using the eigendecomposition of the density matrix.
    Then, the Meyer-Wallach measure is computed for each pure state,
    weighted by the eigenvalue.
    See e.g. https://doi.org/10.48550/arXiv.quant-ph/0504163

    Note that the decomposition is *not unique*! Therefore, this measure
    presents the entanglement for *some* decomposition into pure states,
    not necessarily the one that is anticipated when applying the Kraus
    channels.
    If a pure state is provided, this results in the same value as the
    Entanglement.meyer_wallach function if `always_decompose` flag is not set.

    Args:
        model (Model): The quantum circuit model.
        n_samples (int): Number of samples per qubit.
            If None or <= 0, the current parameters of the model are used.
        random_key (Optional[jax.random.PRNGKey]): JAX random key for
            parameter initialization. If None, uses the model's internal
            random key.
        scale (bool): Whether to scale the number of samples by
            ``2 ** n_qubits``. Ignored when ``n_samples`` is None.
        always_decompose (bool): Whether to explicitly compute the
            entanglement of formation for the eigendecomposition of a pure
            state.
        kwargs (Any): Additional keyword arguments for the model function.

    Returns:
        float: Mean entanglement of formation over all evaluated samples.
    """
    # Only scale an actual sample count; multiplying ``None`` would raise.
    if scale and n_samples is not None:
        n_samples = 2**model.n_qubits * n_samples

    if n_samples is not None and n_samples > 0:
        model.initialize_params(random_key, repeat=n_samples)
    else:
        if len(model.params.shape) <= 2:
            # Add a leading batch axis to the model's own parameters
            model.params = model.params.reshape(1, *model.params.shape)
        else:
            log.info(f"Using sample size of model params: {model.params.shape[0]}")

    # implicitly set input to none in case it's not needed
    kwargs.setdefault("inputs", None)
    rhos = model(execution_type="density", **kwargs)
    rhos = rhos.reshape(-1, 2**model.n_qubits, 2**model.n_qubits)
    ent = Entanglement._compute_entanglement_of_formation(
        rhos, model.n_qubits, always_decompose
    )
    return ent.mean()

431 

@staticmethod
def _compute_entanglement_of_formation(
    rhos: jnp.ndarray,
    n_qubits: int,
    always_decompose: bool,
) -> jnp.ndarray:
    """
    Computes the entanglement of formation for a given batch of density
    matrices.

    Args:
        rhos (jnp.ndarray): The density matrices, has shape (B_s, 2^n, 2^n),
            where B_s is the batch size and n the number of qubits.
        n_qubits (int): Number of qubits
        always_decompose (bool): Whether to explicitly compute the
            entanglement of formation for the eigendecomposition of a pure
            state.

    Returns:
        jnp.ndarray: Entanglement for the provided density matrices,
            shape (B_s,).
    """
    eigenvalues, eigenvectors = jnp.linalg.eigh(rhos)
    # Every state has an eigenvalue of ~1, i.e. the whole batch is pure:
    # the plain Meyer-Wallach measure can be used directly.
    if not always_decompose and jnp.isclose(eigenvalues, 1.0).any(axis=-1).all():
        return Entanglement._compute_meyer_wallach_meas(rhos, n_qubits)

    dim = 2**n_qubits
    # ``eigh`` stores eigenvector i in column i (``eigenvectors[s, :, i]``),
    # so the projector |v_i><v_i| is V[s, j, i] * conj(V[s, k, i]).
    projectors = jnp.einsum(
        "sji,ski->sijk", eigenvectors, eigenvectors.conjugate()
    )
    measures = Entanglement._compute_meyer_wallach_meas(
        projectors.reshape(-1, dim, dim), n_qubits
    )
    # Weight the measure of each pure state with its eigenvalue
    ent = jnp.einsum("si,si->s", measures.reshape(-1, dim), eigenvalues)
    return ent

463 

@staticmethod
def concentratable_entanglement(
    model: Model,
    n_samples: int,
    random_key: Optional[jax.random.PRNGKey] = None,
    scale: bool = False,
    **kwargs: Any,
) -> float:
    """
    Computes the concentratable entanglement of a given model.

    This method utilizes the Concentratable Entanglement measure from
    https://arxiv.org/abs/2104.06923. The swap test is implemented
    directly in yaqsi using a ``3 * n_qubits`` circuit.

    Args:
        model (Model): The quantum circuit model.
        n_samples (int): The number of samples to compute the measure for.
            If None or <= 0, the current parameters of the model are used.
        random_key (Optional[jax.random.PRNGKey]): JAX random key for
            parameter initialization. If None, uses the model's internal
            random key.
        scale (bool): Whether to scale the number of samples according to
            the number of qubits. Ignored when ``n_samples`` is None.
        **kwargs (Any): Additional keyword arguments for the model function.

    Returns:
        float: Mean over all samples of one minus the probability of
            measuring the all-zero outcome on the ancilla register.
    """
    n = model.n_qubits
    N = 2**n

    # Only scale an actual sample count; multiplying ``None`` would raise.
    if scale and n_samples is not None:
        n_samples = N * n_samples

    def _swap_test_circuit(
        params, inputs, pulse_params=None, random_key=None, **kw
    ):
        """Swap-test circuit on 3*n qubits."""
        from qml_essentials.tape import copy_to_tape

        def vari():
            model._variational(
                params,
                inputs,
                pulse_params=pulse_params,
                random_key=random_key,
                **kw
            )

        # First copy on wires n..2n-1
        copy_to_tape(vari, offset=n)
        # Second copy on wires 2n..3n-1
        copy_to_tape(vari, offset=2 * n)

        # Swap test: H on ancilla register (wires 0..n-1)
        for i in range(n):
            op.H(wires=i)

        for i in range(n):
            op.CSWAP(wires=[i, i + n, i + 2 * n])

        for i in range(n):
            op.H(wires=i)

    swap_script = ys.Script(f=_swap_test_circuit, n_qubits=3 * n)

    if n_samples is not None and n_samples > 0:
        random_key = model.initialize_params(random_key, repeat=n_samples)
    else:
        if len(model.params.shape) <= 2:
            # Add a leading batch axis to the model's own parameters
            model.params = model.params.reshape(1, *model.params.shape)
        else:
            log.info(f"Using sample size of model params: {model.params.shape[0]}")

    params = model.params
    inputs = model._inputs_validation(kwargs.get("inputs", None))
    n_batch = params.shape[0]

    # Qubit count and kept wires are static arguments for jit
    marg_probs = jax.jit(ys.marginalize_probs, static_argnums=(1, 2))

    if n_batch > 1:
        from qml_essentials.utils import safe_random_split

        random_keys = safe_random_split(random_key, num=n_batch)
        probs = swap_script.execute(
            type="probs",
            args=(params, inputs, model.pulse_params, random_keys),
            in_axes=(0, None, None, 0),
            kwargs=kwargs,
        )
    else:
        probs = swap_script.execute(
            type="probs",
            args=(params, inputs, model.pulse_params, random_key),
            kwargs=kwargs,
        )

    # Marginalize to the ancilla register (wires 0..n-1)
    probs = marg_probs(probs, 3 * n, tuple(range(n)))

    # First entry is the probability of the all-zero ancilla outcome
    ent = 1 - probs[..., 0]

    log.debug(f"Variance of measure: {ent.var()}")

    return float(ent.mean())

570 

@staticmethod
def concentratable_entanglement_estimation(
    model: Model,
    n_samples: int,
    random_key: Optional[jax.random.PRNGKey] = None,
    scale: bool = False,
    **kwargs: Any,
) -> float:
    """
    Estimates the concentratable entanglement of a given model.

    This method utilizes the Concentratable Entanglement measure from
    https://arxiv.org/abs/2104.06923. Instead of a swap test with an
    ancilla register, two copies of the model state are prepared on a
    ``2 * n_qubits`` circuit, each qubit pair ``(i, i + n_qubits)`` is
    rotated with CX + H, and the expectation value of an observable
    built from the SWAP operator in that basis is evaluated.

    Args:
        model (Model): The quantum circuit model.
        n_samples (int): The number of samples to compute the measure for.
            If None or <= 0, the current parameters of the model are used.
        random_key (Optional[jax.random.PRNGKey]): JAX random key for
            parameter initialization. If None, uses the model's internal
            random key.
        scale (bool): Whether to scale the number of samples according to
            the number of qubits. Ignored when ``n_samples`` is None.
        **kwargs (Any): Additional keyword arguments for the model function.

    Returns:
        float: Mean over all samples of one minus the measured
            expectation value.
    """
    n = model.n_qubits
    N = 2**n

    # Only scale an actual sample count; multiplying ``None`` would raise.
    if scale and n_samples is not None:
        n_samples = N * n_samples

    def _bell_basis_measurement(
        params, inputs, pulse_params=None, random_key=None, **kw
    ):
        """Bell-basis measurement circuit on 2*n qubits."""
        from qml_essentials.tape import copy_to_tape

        def vari():
            model._variational(
                params,
                inputs,
                pulse_params=pulse_params,
                random_key=random_key,
                **kw,
            )

        # First copy on wires 0..n-1
        copy_to_tape(vari, offset=0)
        # Second copy on wires n..2n-1
        copy_to_tape(vari, offset=n)

        for i in range(n):
            op.CX(wires=[i, i + n])
            op.H(wires=i)

    bell_basis_script = ys.Script(f=_bell_basis_measurement, n_qubits=2 * n)

    if n_samples is not None and n_samples > 0:
        random_key = model.initialize_params(random_key, repeat=n_samples)
    else:
        if len(model.params.shape) <= 2:
            # Add a leading batch axis to the model's own parameters
            model.params = model.params.reshape(1, *model.params.shape)
        else:
            log.info(f"Using sample size of model params: {model.params.shape[0]}")

    params = model.params
    inputs = model._inputs_validation(kwargs.get("inputs", None))
    n_batch = params.shape[0]

    # SWAP operator in Bell-basis
    SWAP = jnp.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, -1]])
    # Construct observable for measuring CE: product over all qubit pairs
    # of (Id + SWAP), scaled by 1 / 2^n
    CE_observable = op.Id([0, n]) + op.Operation([0, n], SWAP)
    for i in range(1, n):
        CE_observable = (CE_observable @
                         (op.Id([i, i + n]) + op.Operation([i, i + n], SWAP)))
    CE_observable = (1/N) * CE_observable

    if n_batch > 1:
        from qml_essentials.utils import safe_random_split

        random_keys = safe_random_split(random_key, num=n_batch)
        expvals = bell_basis_script.execute(
            type="expval",
            obs=[CE_observable],
            args=(params, inputs, model.pulse_params, random_keys),
            in_axes=(0, None, None, 0),
            kwargs=kwargs,
        )
    else:
        expvals = bell_basis_script.execute(
            type="expval",
            obs=[CE_observable],
            args=(params, inputs, model.pulse_params, random_key),
            kwargs=kwargs,
        )

    ent = 1 - expvals
    log.debug(f"Variance of measure: {ent.var()}")
    return float(ent.mean())

676 

677 

def sample_random_separable_states(
    n_qubits: int,
    n_samples: int,
    random_key: jax.random.PRNGKey,
    take_log: bool = False,
) -> jnp.ndarray:
    """
    Draw random density matrices from a circuit without entangling gates.

    Args:
        n_qubits (int): number of qubits in the state
        n_samples (int): number of states
        random_key (random.PRNGKey): JAX random key
        take_log (bool): if True, return the matrix logarithm of each
            density matrix divided by log(2) instead of the matrix itself.

    Returns:
        jnp.ndarray: Density matrices of shape (n_samples, 2**n_qubits, 2**n_qubits)
    """
    product_model = Model(n_qubits, 1, "No_Entangling", data_reupload=False)
    product_model.initialize_params(random_key, repeat=n_samples)

    # the density execution type is required here, nothing else works
    states = product_model(execution_type="density", inputs=None)

    if not take_log:
        return states
    return logm_v(states) / jnp.log(2.0 + 0j)

700 if take_log: 

701 sigmas = logm_v(sigmas) / jnp.log(2.0 + 0j) 

702 

703 return sigmas