Coverage for qml_essentials / entanglement.py: 92%

236 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-05-16 10:19 +0000

1from typing import Optional, Any, Tuple 

2import jax 

3import jax.numpy as jnp 

4import numpy as np 

5 

6from qml_essentials import yaqsi as ys 

7from qml_essentials import operations as op 

8from qml_essentials.math import logm_v 

9from qml_essentials.model import Model 

10import logging 

11 

# Module-level logger, named after this module's import path via __name__.
log = logging.getLogger(__name__)

13 

14 

class Entanglement:
    """Entanglement measures for parameterised :class:`Model` circuits.

    Every public measure is a classmethod. Unless asked to reuse the current
    parameters (``n_samples`` is ``None`` or ``<= 0``), it re-initialises the
    model parameters with ``n_samples`` random draws, simulates the circuit
    and returns the measure averaged over those samples.
    """

    # ------------------------------------------------------------------
    # Shared private helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _scale_n_samples(
        n_samples: Optional[int], n_qubits: int, scale: bool
    ) -> Optional[int]:
        """Apply the optional ``2**n_qubits`` scaling to a sample count.

        Args:
            n_samples (Optional[int]): Requested number of samples. ``None``
                means "reuse the model's current parameters".
            n_qubits (int): Number of qubits of the model.
            scale (bool): Whether to multiply ``n_samples`` by ``2**n_qubits``.

        Returns:
            Optional[int]: The (possibly scaled) sample count. ``None`` is
            passed through untouched so ``scale=True`` does not crash on it;
            non-positive counts stay non-positive.
        """
        if scale and n_samples is not None:
            # Plain Python int (not a jax scalar) so it can safely be used as
            # a repeat count or array shape downstream.
            return 2**n_qubits * n_samples
        return n_samples

    @staticmethod
    def _prepare_params(
        model: Model,
        n_samples: Optional[int],
        random_key: Optional[jax.random.PRNGKey],
    ) -> Any:
        """Draw fresh batched parameters, or batch the existing ones.

        If ``n_samples`` is a positive integer, ``model.initialize_params`` is
        called with ``repeat=n_samples`` and its return value is returned
        (callers use it as the next random key). Otherwise the current
        ``model.params`` are reused: unbatched parameters (rank <= 2) get a
        leading batch axis of size 1 written back to ``model.params``, and
        ``random_key`` is returned unchanged.
        """
        if n_samples is not None and n_samples > 0:
            return model.initialize_params(random_key, repeat=n_samples)
        if len(model.params.shape) <= 2:
            model.params = model.params.reshape(1, *model.params.shape)
        else:
            log.info(f"Using sample size of model params: {model.params.shape[0]}")
        return random_key

    @staticmethod
    def _variational_thunk(model: Model, params, inputs, pulse_params, random_key, kw):
        """Return a zero-argument callable applying ``model._variational``.

        Calling it applies the model's variational circuit on its own wires;
        passing it to ``copy_to_tape`` replays it on offset wires.
        """

        def vari():
            model._variational(
                params,
                inputs,
                pulse_params=pulse_params,
                random_key=random_key,
                **kw,
            )

        return vari

    @staticmethod
    def _execute_batched(
        script, model: Model, params, inputs, random_key, circuit_kwargs, **execute_kwargs
    ):
        """Execute ``script`` once per parameter sample.

        With more than one sample (leading axis of ``params``), the random key
        is split per sample and the execution is vmapped over the parameter
        and key axes. Otherwise the script is executed once with the key as-is.
        ``execute_kwargs`` (e.g. ``type``, ``obs``) are forwarded to
        ``script.execute``.
        """
        n_batch = params.shape[0]
        if n_batch > 1:
            from qml_essentials.utils import safe_random_split

            random_keys = safe_random_split(random_key, num=n_batch)
            return script.execute(
                args=(params, inputs, model.pulse_params, random_keys),
                kwargs=circuit_kwargs,
                in_axes=(0, None, None, 0),
                **execute_kwargs,
            )
        return script.execute(
            args=(params, inputs, model.pulse_params, random_key),
            kwargs=circuit_kwargs,
            **execute_kwargs,
        )

    # ------------------------------------------------------------------
    # Meyer-Wallach
    # ------------------------------------------------------------------
    @classmethod
    def meyer_wallach(
        cls,
        model: Model,
        n_samples: Optional[int],
        random_key: Optional[jax.random.PRNGKey] = None,
        scale: bool = False,
        **kwargs: Any,
    ) -> float:
        """
        Calculates the entangling capacity of a given quantum circuit
        using the Meyer-Wallach measure.

        Args:
            model (Model): The quantum circuit model.
            n_samples (Optional[int]): Number of samples per qubit.
                If None or <= 0, the current parameters of the model are used.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for
                parameter initialization. If None, uses the model's internal
                random key.
            scale (bool): Whether to scale the number of samples by
                2**n_qubits (ignored when n_samples is None).
            kwargs (Any): Additional keyword arguments for the model function.

        Returns:
            float: Entangling capacity of the given circuit, averaged over
            all samples (nominally between 0.0 and 1.0).
        """
        if "noise_params" in kwargs:
            log.warning(
                "Meyer-Wallach measure not suitable for noisy circuits. "
                "Consider 'concentratable entanglement' instead."
            )

        n_samples = cls._scale_n_samples(n_samples, model.n_qubits, scale)
        if n_samples is not None and n_samples > 0:
            model.initialize_params(random_key, repeat=n_samples)

        # implicitly set input to none in case it's not needed
        kwargs.setdefault("inputs", None)
        dim = 2**model.n_qubits
        # explicitly set execution type because everything else won't work
        rhos = model(execution_type="density", **kwargs).reshape(-1, dim, dim)

        ent = cls._compute_meyer_wallach_meas(rhos, model.n_qubits)
        log.debug(f"Variance of measure: {ent.var()}")
        return ent.mean()

    @classmethod
    def _compute_meyer_wallach_meas(
        cls, rhos: jnp.ndarray, n_qubits: int
    ) -> jnp.ndarray:
        """
        Computes the Meyer-Wallach entangling capability measure for a given
        set of density matrices.

        Args:
            rhos (jnp.ndarray): Density matrices of the sample quantum states.
                The shape is (B_s, 2^n, 2^n), where B_s is the number of samples
                (batch) and n the number of qubits.
            n_qubits (int): The number of qubits.

        Returns:
            jnp.ndarray: Entangling capability for each sample, array with
            shape (B_s,).
        """
        qb = list(range(n_qubits))

        def _f(rho):
            purity_sum = 0
            for j in range(n_qubits):
                # Formula 6 in https://doi.org/10.48550/arXiv.quant-ph/0305094
                # Trace out qubit j, keep all others. For a pure global state
                # the purity of this complement equals that of qubit j alone
                # (Schmidt decomposition), so it lies in [1/2, 1].
                keep = qb[:j] + qb[j + 1 :]
                reduced = ys.partial_trace(rho, n_qubits, keep)
                # Tr(A^2) of a Hermitian matrix is real; drop the imaginary part
                purity_sum += jnp.trace((reduced @ reduced).real, axis1=-2, axis2=-1)

            # invert the averaged purity and scale to [0, 1]
            return 2 * (1 - purity_sum / n_qubits)

        return jax.vmap(_f)(rhos)

    # ------------------------------------------------------------------
    # Bell measurements
    # ------------------------------------------------------------------
    @classmethod
    def bell_measurements(
        cls,
        model: Model,
        n_samples: int,
        random_key: Optional[jax.random.PRNGKey] = None,
        scale: bool = False,
        **kwargs: Any,
    ) -> float:
        """
        Compute the Bell measurement for a given model.

        Constructs a ``2 * n_qubits`` circuit that prepares two copies of
        the model state (on disjoint qubit registers), applies CNOTs and
        Hadamards, and measures probabilities on the first register.

        Args:
            model (Model): The quantum circuit model.
            n_samples (int): The number of samples to compute the measure for.
                If None or <= 0, the current parameters of the model are used
                (without modifying ``model.params``).
            random_key (Optional[jax.random.PRNGKey]): JAX random key for
                parameter initialization. If None, uses the model's internal
                random key.
            scale (bool): Whether to scale the number of samples
                according to the number of qubits.
            **kwargs (Any): Additional keyword arguments for the model function.
                ``inputs`` is consumed here and passed to the circuit
                positionally.

        Returns:
            float: The Bell measurement value, clipped to [0.0, 1.0].
        """
        if "noise_params" in kwargs:
            log.warning(
                "Bell Measurements not suitable for noisy circuits. "
                "Consider 'concentratable entanglement' instead."
            )

        n = model.n_qubits
        n_samples = cls._scale_n_samples(n_samples, n, scale)

        def _bell_circuit(params, inputs, pulse_params=None, random_key=None, **kw):
            """Bell measurement circuit on 2*n qubits."""
            from qml_essentials.tape import copy_to_tape

            vari = cls._variational_thunk(
                model, params, inputs, pulse_params, random_key, kw
            )
            # First copy on wires 0..n-1
            vari()
            # Second copy on wires n..2n-1
            copy_to_tape(vari, offset=n)

            # Bell measurement: CNOT + H
            for q in range(n):
                op.CX(wires=[q, q + n])
                op.H(wires=q)

        bell_script = ys.Script(f=_bell_circuit, n_qubits=2 * n)

        # Unlike the other measures, the existing parameters are batched
        # locally here and model.params is left untouched.
        if n_samples is not None and n_samples > 0:
            random_key = model.initialize_params(random_key, repeat=n_samples)
            params = model.params
        elif len(model.params.shape) <= 2:
            params = model.params.reshape(1, *model.params.shape)
        else:
            log.info(f"Using sample size of model params: {model.params.shape[0]}")
            params = model.params

        # Pop (not get): the circuit already receives inputs positionally, so
        # leaving it in kwargs would pass it a second time as a keyword.
        inputs = model._inputs_validation(kwargs.pop("inputs", None))

        result = cls._execute_batched(
            bell_script, model, params, inputs, random_key, kwargs, type="probs"
        )

        # Marginalize: for each qubit q, keep wires [q, q+n] from the 2n-qubit
        # probs. The last entry of each pair is P(|11>), so 1 - 2 * P(|11>)
        # estimates the purity of qubit q.
        per_qubit = [ys.marginalize_probs(result, 2 * n, [q, q + n]) for q in range(n)]
        exp = jnp.stack(per_qubit, axis=-2)  # (..., n, 4)
        exp = 1 - 2 * exp[..., -1]  # (..., n)

        if not jnp.isclose(jnp.sum(exp.imag), 0, atol=1e-6):
            log.warning("Imaginary part of probabilities detected")
            exp = jnp.abs(exp)

        measure = 2 * (1 - exp.mean(axis=0))
        entangling_capability = min(max(float(measure.mean()), 0.0), 1.0)
        log.debug(f"Variance of measure: {measure.var()}")

        return entangling_capability

    # ------------------------------------------------------------------
    # Relative entropy of entanglement
    # ------------------------------------------------------------------
    @classmethod
    def relative_entropy(
        cls,
        model: Model,
        n_samples: int,
        n_sigmas: int,
        random_key: Optional[jax.random.PRNGKey] = None,
        scale: bool = False,
        **kwargs: Any,
    ) -> float:
        """
        Calculates the relative entropy of entanglement of a given quantum
        circuit. This measure is also applicable to mixed states, although it
        might not be fully accurate in this simplified case.

        The relative entropy of entanglement is defined as the smallest
        relative entropy from the state in question to the set of separable
        states. Because computing the nearest separable state is NP-hard, we
        select n_sigmas random separable states to compute the distance to,
        which are not necessarily the nearest. Thus, this measure presents an
        upper bound of the entanglement.

        As the relative entropy is not necessarily between zero and one, this
        function also normalises by the relative entropy of the GHZ state
        (per separable state).

        Args:
            model (Model): The quantum circuit model.
            n_samples (int): Number of samples per qubit.
                If None or <= 0, the current parameters of the model are used.
            n_sigmas (int): Number of random separable pure states to compare against.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for
                parameter initialization. If None, uses the model's internal
                random key.
            scale (bool): Whether to scale the number of samples (and sigmas)
                by 2**n_qubits.
            kwargs (Any): Additional keyword arguments for the model function.

        Returns:
            float: Entangling capacity of the given circuit, averaged over
            all samples.
        """
        n = model.n_qubits
        dim = 2**n
        n_samples = cls._scale_n_samples(n_samples, n, scale)
        if scale:
            n_sigmas = dim * n_sigmas

        if random_key is None:
            random_key = model.random_key

        # Independent subkeys: the original code sampled sigmas with
        # random_key and then also split random_key for the parameters,
        # i.e. reused a key. The parameter subkey is the same one used before.
        param_key, sigma_key = jax.random.split(random_key)

        # Random separable states, always shaped (n_sigmas, dim, dim)
        log_sigmas = sample_random_separable_states(
            n, n_samples=n_sigmas, random_key=sigma_key, take_log=True
        ).reshape(-1, dim, dim)

        cls._prepare_params(model, n_samples, param_key)

        rhos, log_rhos = cls._compute_log_density(model, **kwargs)

        # rel_entropies[s, r]: relative entropy of sample r w.r.t. sigma s
        rel_entropies = jnp.stack(
            [
                cls._compute_rel_entropies(rhos, log_rhos, log_sigma)
                for log_sigma in log_sigmas
            ]
        )

        # Entropy of GHZ states should be maximal; one value per sigma,
        # shaped (n_sigmas, 1) so it broadcasts over the samples axis.
        ghz_model = Model(n, 1, "GHZ", data_reupload=False)
        rho_ghz, log_rho_ghz = cls._compute_log_density(ghz_model, **kwargs)
        ghz_entropies = cls._compute_rel_entropies(
            rho_ghz, log_rho_ghz, log_sigmas
        ).reshape(log_sigmas.shape[0], -1)

        normalised_entropies = rel_entropies / ghz_entropies

        # For each sample keep the closest separable state (minimum over sigmas)
        entangling_capability = normalised_entropies.min(axis=0)
        log.debug(f"Variance of measure: {entangling_capability.var()}")

        return entangling_capability.mean()

    @classmethod
    def _compute_log_density(
        cls, model: Model, **kwargs
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """
        Obtains the density matrix of a model and computes its base-2 logarithm.

        Args:
            model (Model): The model for which to compute the density matrix.
            **kwargs: Additional keyword arguments for the model function.

        Returns:
            Tuple[jnp.ndarray, jnp.ndarray]:
                - jnp.ndarray: density matrices, shape (R, 2^n, 2^n).
                - jnp.ndarray: base-2 matrix logarithm of the density matrices.
        """
        # implicitly set input to none in case it's not needed
        kwargs.setdefault("inputs", None)
        dim = 2**model.n_qubits
        # explicitly set execution type because everything else won't work
        rho = model(execution_type="density", **kwargs).reshape(-1, dim, dim)
        log_rho = logm_v(rho) / jnp.log(2)
        return rho, log_rho

    @classmethod
    def _compute_rel_entropies(
        cls,
        rhos: jnp.ndarray,
        log_rhos: jnp.ndarray,
        log_sigmas: jnp.ndarray,
    ) -> jnp.ndarray:
        """
        Compute the relative entropies Tr[rho (log rho - log sigma)].

        Args:
            rhos (jnp.ndarray): Density matrix result of the circuit, has shape
                (R, 2^n, 2^n), with the batch size R and number of qubits n.
            log_rhos (jnp.ndarray): Corresponding logarithm of the density
                matrix, has shape (R, 2^n, 2^n).
            log_sigmas (jnp.ndarray): Logarithm(s) of the separable state(s),
                has shape (2^n, 2^n) for a single sigma or (S, 2^n, 2^n),
                with the batch size S (number of sigmas).

        Returns:
            jnp.ndarray: Relative entropy for every (sigma, rho) pair: shape
            (R,) for a single sigma (or S == 1), otherwise (S, R).
        """
        n_rhos = rhos.shape[0]
        if log_sigmas.ndim == 3:
            n_sigmas = log_sigmas.shape[0]
            # Build all S*R pairs with flat index k = s * R + r: rhos cycle
            # (tile) while each sigma is repeated R times.
            rhos = jnp.tile(rhos, (n_sigmas, 1, 1))
            log_rhos = jnp.tile(log_rhos, (n_sigmas, 1, 1))
            log_sigmas = jnp.repeat(log_sigmas, n_rhos, axis=0)
        else:
            n_sigmas = 1
            log_sigmas = jnp.repeat(log_sigmas[jnp.newaxis, ...], n_rhos, axis=0)

        def _f(rho, log_rho, log_sigma):
            prod = rho @ (log_rho - log_sigma)
            return jnp.abs(jnp.trace(prod, axis1=-2, axis2=-1))

        rel_entropies = jax.vmap(_f, in_axes=(0, 0, 0))(rhos, log_rhos, log_sigmas)

        if n_sigmas > 1:
            rel_entropies = rel_entropies.reshape(n_sigmas, n_rhos)
        return rel_entropies

    # ------------------------------------------------------------------
    # Entanglement of formation
    # ------------------------------------------------------------------
    @classmethod
    def entanglement_of_formation(
        cls,
        model: Model,
        n_samples: int,
        random_key: Optional[jax.random.PRNGKey] = None,
        scale: bool = False,
        always_decompose: bool = False,
        **kwargs: Any,
    ) -> float:
        """
        This function implements the entanglement of formation for mixed
        quantum systems.
        A mixed state is decomposed into pure states with respective
        probabilities using the eigendecomposition of the density matrix.
        Then, the Meyer-Wallach measure is computed for each pure state,
        weighted by the eigenvalue.
        See e.g. https://doi.org/10.48550/arXiv.quant-ph/0504163

        Note that the decomposition is *not unique*! Therefore, this measure
        presents the entanglement for *some* decomposition into pure states,
        not necessarily the one that is anticipated when applying the Kraus
        channels.
        If a pure state is provided, this results in the same value as the
        Entanglement.meyer_wallach function if `always_decompose` flag is not set.

        Args:
            model (Model): The quantum circuit model.
            n_samples (int): Number of samples per qubit.
                If None or <= 0, the current parameters of the model are used.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for
                parameter initialization. If None, uses the model's internal
                random key.
            scale (bool): Whether to scale the number of samples.
            always_decompose (bool): Whether to explicitly compute the
                entanglement of formation for the eigendecomposition of a pure
                state.
            kwargs (Any): Additional keyword arguments for the model function.

        Returns:
            float: Entangling capacity of the given circuit, averaged over
            all samples.
        """
        n_samples = cls._scale_n_samples(n_samples, model.n_qubits, scale)
        cls._prepare_params(model, n_samples, random_key)

        # implicitly set input to none in case it's not needed
        kwargs.setdefault("inputs", None)
        dim = 2**model.n_qubits
        rhos = model(execution_type="density", **kwargs).reshape(-1, dim, dim)
        ent = cls._compute_entanglement_of_formation(
            rhos, model.n_qubits, always_decompose
        )
        return ent.mean()

    @classmethod
    def _compute_entanglement_of_formation(
        cls,
        rhos: jnp.ndarray,
        n_qubits: int,
        always_decompose: bool,
    ) -> jnp.ndarray:
        """
        Computes the entanglement of formation for a given batch of density
        matrices.

        Args:
            rhos (jnp.ndarray): The density matrices, has shape (B_s, 2^n, 2^n),
                where B_s is the batch size and n the number of qubits.
            n_qubits (int): Number of qubits.
            always_decompose (bool): Whether to explicitly compute the
                entanglement of formation for the eigendecomposition of a pure
                state.

        Returns:
            jnp.ndarray: Entanglement for the provided density matrices,
            shape (B_s,).
        """
        dim = 2**n_qubits
        eigenvalues, eigenvectors = jnp.linalg.eigh(rhos)
        # A pure state has one eigenvalue equal to 1: MW applies directly.
        if not always_decompose and jnp.isclose(eigenvalues, 1.0).any(axis=-1).all():
            return cls._compute_meyer_wallach_meas(rhos, n_qubits)

        # eigh returns eigenvectors as COLUMNS: v_i = eigenvectors[s, :, i].
        # projectors[s, i] = v_i v_i^dagger, i.e. [j, k] = v_i[j] * conj(v_i[k]).
        projectors = jnp.einsum(
            "sji,ski->sijk", eigenvectors, eigenvectors.conjugate()
        )
        measures = cls._compute_meyer_wallach_meas(
            projectors.reshape(-1, dim, dim), n_qubits
        )
        # Weight each eigenstate's MW value by its eigenvalue (probability)
        return jnp.einsum("si,si->s", measures.reshape(-1, dim), eigenvalues)

    # ------------------------------------------------------------------
    # Concentratable entanglement
    # ------------------------------------------------------------------
    @classmethod
    def concentratable_entanglement(
        cls,
        model: Model,
        n_samples: int,
        random_key: Optional[jax.random.PRNGKey] = None,
        scale: bool = False,
        **kwargs: Any,
    ) -> float:
        """
        Computes the concentratable entanglement of a given model.

        This method utilizes the Concentratable Entanglement measure from
        https://arxiv.org/abs/2104.06923. The swap test is implemented
        directly in yaqsi using a ``3 * n_qubits`` circuit.

        Args:
            model (Model): The quantum circuit model.
            n_samples (int): The number of samples to compute the measure for.
                If None or <= 0, the current parameters of the model are used.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for
                parameter initialization. If None, uses the model's internal
                random key.
            scale (bool): Whether to scale the number of samples according to
                the number of qubits.
            **kwargs (Any): Additional keyword arguments for the model function.
                ``inputs`` is consumed here and passed to the circuit
                positionally.

        Returns:
            float: Entangling capability of the given circuit, averaged over
            all samples.
        """
        n = model.n_qubits
        n_samples = cls._scale_n_samples(n_samples, n, scale)

        def _swap_test_circuit(
            params, inputs, pulse_params=None, random_key=None, **kw
        ):
            """Swap-test circuit on 3*n qubits."""
            from qml_essentials.tape import copy_to_tape

            vari = cls._variational_thunk(
                model, params, inputs, pulse_params, random_key, kw
            )
            # First copy on wires n..2n-1
            copy_to_tape(vari, offset=n)
            # Second copy on wires 2n..3n-1
            copy_to_tape(vari, offset=2 * n)

            # Swap test: H on ancilla register (wires 0..n-1)
            for i in range(n):
                op.H(wires=i)
            for i in range(n):
                op.CSWAP(wires=[i, i + n, i + 2 * n])
            for i in range(n):
                op.H(wires=i)

        swap_script = ys.Script(f=_swap_test_circuit, n_qubits=3 * n)

        random_key = cls._prepare_params(model, n_samples, random_key)
        params = model.params
        # Pop (not get): inputs is passed positionally to the circuit.
        inputs = model._inputs_validation(kwargs.pop("inputs", None))

        probs = cls._execute_batched(
            swap_script, model, params, inputs, random_key, kwargs, type="probs"
        )

        # Marginalize to the ancilla register (wires 0..n-1)
        marg_probs = jax.jit(ys.marginalize_probs, static_argnums=(1, 2))
        probs = marg_probs(probs, 3 * n, tuple(range(n)))

        # CE = 1 - P(ancillas all in |0>)
        ent = 1 - probs[..., 0]
        log.debug(f"Variance of measure: {ent.var()}")

        return float(ent.mean())

    @classmethod
    def concentratable_entanglement_estimation(
        cls,
        model: Model,
        n_samples: int,
        random_key: Optional[jax.random.PRNGKey] = None,
        scale: bool = False,
        **kwargs: Any,
    ) -> float:
        """
        Computes the concentratable entanglement of a given model.

        This method utilizes the Concentratable Entanglement measure from
        https://arxiv.org/abs/2104.06923. Instead of an ancilla-based swap
        test, it prepares two copies of the state on a ``2 * n_qubits``
        circuit, rotates each qubit pair into the Bell basis and measures the
        expectation value of the product of (I + SWAP) operators.

        Args:
            model (Model): The quantum circuit model.
            n_samples (int): The number of samples to compute the measure for.
                If None or <= 0, the current parameters of the model are used.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for
                parameter initialization. If None, uses the model's internal
                random key.
            scale (bool): Whether to scale the number of samples according to
                the number of qubits.
            **kwargs (Any): Additional keyword arguments for the model function.
                ``inputs`` is consumed here and passed to the circuit
                positionally.

        Returns:
            float: Entangling capability of the given circuit, averaged over
            all samples.
        """
        n = model.n_qubits
        N = 2**n
        n_samples = cls._scale_n_samples(n_samples, n, scale)

        def _bell_basis_measurement(
            params, inputs, pulse_params=None, random_key=None, **kw
        ):
            """Bell-basis measurement circuit on 2*n qubits."""
            from qml_essentials.tape import copy_to_tape

            vari = cls._variational_thunk(
                model, params, inputs, pulse_params, random_key, kw
            )
            # First copy on wires 0..n-1
            copy_to_tape(vari, offset=0)
            # Second copy on wires n..2n-1
            copy_to_tape(vari, offset=n)

            for i in range(n):
                op.CX(wires=[i, i + n])
                op.H(wires=i)

        bell_basis_script = ys.Script(f=_bell_basis_measurement, n_qubits=2 * n)

        random_key = cls._prepare_params(model, n_samples, random_key)
        params = model.params
        # Pop (not get): inputs is passed positionally to the circuit.
        inputs = model._inputs_validation(kwargs.pop("inputs", None))

        # SWAP operator in Bell-basis
        SWAP = jnp.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, -1]])
        # Observable (1/2^n) * prod_i (I + SWAP_i) over qubit pairs (i, i+n)
        CE_observable = op.Id([0, n]) + op.Operation([0, n], SWAP)
        for i in range(1, n):
            CE_observable = CE_observable @ (
                op.Id([i, i + n]) + op.Operation([i, i + n], SWAP)
            )
        CE_observable = (1 / N) * CE_observable

        expvals = cls._execute_batched(
            bell_basis_script,
            model,
            params,
            inputs,
            random_key,
            kwargs,
            type="expval",
            obs=[CE_observable],
        )

        ent = 1 - expvals
        log.debug(f"Variance of measure: {ent.var()}")
        return float(ent.mean())

685 

686 

def sample_random_separable_states(
    n_qubits: int,
    n_samples: int,
    random_key: jax.random.PRNGKey,
    take_log: bool = False,
) -> jnp.ndarray:
    """
    Sample random separable states (density matrix).

    A single-layer "No_Entangling" model is initialised with ``n_samples``
    random parameter sets and simulated as density matrices.

    Args:
        n_qubits (int): number of qubits in the state
        n_samples (int): number of states
        random_key (random.PRNGKey): JAX random key
        take_log (bool): if the base-2 matrix logarithm of the density matrix
            should be returned instead.

    Returns:
        jnp.ndarray: Density matrices (or their base-2 logarithms) of shape
        (n_samples, 2**n_qubits, 2**n_qubits).
    """
    dim = 2**n_qubits
    model = Model(n_qubits, 1, "No_Entangling", data_reupload=False)
    model.initialize_params(random_key, repeat=n_samples)
    # explicitly set execution type because anything else won't work
    sigmas = model(execution_type="density", inputs=None)
    # Enforce the documented batched shape, even if the model drops the batch
    # axis for a single sample; callers iterate over the leading axis.
    sigmas = sigmas.reshape(-1, dim, dim)
    if take_log:
        sigmas = logm_v(sigmas) / jnp.log(2.0 + 0j)

    return sigmas