Coverage for qml_essentials / entanglement.py: 94%

217 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-03-30 11:43 +0000

1from typing import Optional, Any, Tuple 

2import jax 

3import jax.numpy as jnp 

4import numpy as np 

5 

6from qml_essentials import yaqsi as ys 

7from qml_essentials import operations as op 

8from qml_essentials.math import logm_v 

9from qml_essentials.model import Model 

10import logging 

11 

12log = logging.getLogger(__name__) 

13 

14 

class Entanglement:
    """
    Collection of entanglement measures for ``Model`` circuits.

    The public measures share one sampling convention: if ``n_samples`` is a
    positive integer, a fresh batch of parameters is drawn (multiplied by
    ``2**n_qubits`` first when ``scale=True``); otherwise the model's current
    parameters are used.
    """

    @staticmethod
    def meyer_wallach(
        model: Model,
        n_samples: Optional[int],
        random_key: Optional[jax.random.PRNGKey] = None,
        scale: bool = False,
        **kwargs: Any,
    ) -> float:
        """
        Calculates the entangling capacity of a given quantum circuit
        using the Meyer-Wallach measure.

        Args:
            model (Model): The quantum circuit model.
            n_samples (Optional[int]): Number of parameter samples.
                If None or <= 0, the current parameters of the model are used.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for
                parameter initialization. If None, uses the model's internal
                random key.
            scale (bool): Whether to multiply ``n_samples`` by ``2**n_qubits``.
                Has no effect if ``n_samples`` is None.
            kwargs (Any): Additional keyword arguments for the model function.

        Returns:
            float: Mean Meyer-Wallach measure over all samples (returned as a
            JAX scalar); between 0.0 and 1.0 for pure states.
        """
        if "noise_params" in kwargs:
            log.warning(
                "Meyer-Wallach measure not suitable for noisy circuits. "
                "Consider 'concentratable entanglement' instead."
            )

        if scale and n_samples is not None:
            n_samples = 2**model.n_qubits * n_samples

        if n_samples is not None and n_samples > 0:
            # Draws a fresh parameter batch that the model call below uses.
            model.initialize_params(random_key, repeat=n_samples)

        # implicitly set input to none in case it's not needed
        kwargs.setdefault("inputs", None)
        # explicitly set execution type because everything else won't work
        dim = 2**model.n_qubits
        rhos = model(execution_type="density", **kwargs).reshape(-1, dim, dim)

        ent = Entanglement._compute_meyer_wallach_meas(rhos, model.n_qubits)

        log.debug(f"Variance of measure: {ent.var()}")

        return ent.mean()

    @staticmethod
    def _compute_meyer_wallach_meas(rhos: jnp.ndarray, n_qubits: int) -> jnp.ndarray:
        """
        Computes the Meyer-Wallach entangling capability measure for a given
        set of density matrices.

        Args:
            rhos (jnp.ndarray): Density matrices of the sample quantum states.
                The shape is (B_s, 2^n, 2^n), where B_s is the number of samples
                (batch) and n the number of qubits.
            n_qubits (int): The number of qubits.

        Returns:
            jnp.ndarray: Entangling capability for each sample, array with
            shape (B_s,).
        """
        qubits = list(range(n_qubits))

        def _single(rho):
            purity_sum = 0
            for j in range(n_qubits):
                # Formula 6 in https://doi.org/10.48550/arXiv.quant-ph/0305094
                # Trace out qubit j and keep all others. For a pure global
                # state this marginal has the same purity as the single-qubit
                # marginal of qubit j, i.e. 1/2 <= purity <= 1.
                keep = qubits[:j] + qubits[j + 1 :]
                reduced = ys.partial_trace(rho, n_qubits, keep)
                # Only the real part is needed: Tr(rho^2) of a Hermitian
                # matrix is real.
                purity_sum += jnp.trace((reduced @ reduced).real, axis1=-2, axis2=-1)

            # inverse averaged purity, scaled to [0, 1]
            return 2 * (1 - purity_sum / n_qubits)

        return jax.vmap(_single)(rhos)

    @staticmethod
    def _append_shifted(tape_ops, offset: int) -> None:
        """
        Re-registers recorded operations on the currently active tape with
        every wire index shifted by ``offset``.

        Each operation is copied (a new instance with a copy of its
        ``__dict__``), so the recorded originals are left untouched. If no
        tape is active, nothing is appended.

        Args:
            tape_ops: Iterable of recorded operations (e.g. a recorded tape).
            offset (int): Amount added to every wire index.
        """
        from qml_essentials.tape import active_tape as _active_tape

        tape = _active_tape()
        if tape is None:
            return
        for o in tape_ops:
            shifted = o.__class__.__new__(o.__class__)
            shifted.__dict__.update(o.__dict__)
            shifted._wires = [w + offset for w in o.wires]
            tape.append(shifted)

    @staticmethod
    def bell_measurements(
        model: Model,
        n_samples: int,
        random_key: Optional[jax.random.PRNGKey] = None,
        scale: bool = False,
        **kwargs: Any,
    ) -> float:
        """
        Compute the Bell measurement for a given model.

        Constructs a ``2 * n_qubits`` circuit that prepares two copies of
        the model state (on disjoint qubit registers), applies CNOTs and
        Hadamards, and measures probabilities on the qubit pairs
        ``(q, q + n_qubits)``.

        Args:
            model (Model): The quantum circuit model.
            n_samples (int): The number of samples to compute the measure for.
                If None or <= 0, the current parameters of the model are used.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for
                parameter initialization. If None, uses the model's internal
                random key.
            scale (bool): Whether to scale the number of samples
                according to the number of qubits.
            **kwargs (Any): Additional keyword arguments for the model function.
                ``inputs`` is consumed here and passed to the circuit
                positionally.

        Returns:
            float: The Bell measurement value, clipped to [0.0, 1.0].
        """
        if "noise_params" in kwargs:
            log.warning(
                "Bell Measurements not suitable for noisy circuits. "
                "Consider 'concentratable entanglement' instead."
            )

        if scale and n_samples is not None:
            n_samples = 2**model.n_qubits * n_samples

        n = model.n_qubits

        def _bell_circuit(params, inputs, pulse_params=None, random_key=None, **kw):
            """Bell measurement circuit on 2*n qubits."""
            from qml_essentials.tape import recording as _recording

            # First copy on wires 0..n-1
            model._variational(
                params, inputs, pulse_params=pulse_params, random_key=random_key, **kw
            )

            # TODO: this is very user-unfriendly and we should find a better way
            # Second copy on wires n..2n-1: record the tape, then re-register
            # its operations with shifted wires.
            with _recording() as copy_tape:
                model._variational(
                    params,
                    inputs,
                    pulse_params=pulse_params,
                    random_key=random_key,
                    **kw,
                )
            Entanglement._append_shifted(copy_tape, n)

            # Bell measurement: CNOT + H
            for q in range(n):
                op.CX(wires=[q, q + n])
                op.H(wires=q)

        bell_script = ys.Script(f=_bell_circuit, n_qubits=2 * n)

        if n_samples is not None and n_samples > 0:
            random_key = model.initialize_params(random_key, repeat=n_samples)
            params = model.params
        elif len(model.params.shape) <= 2:
            params = model.params.reshape(1, *model.params.shape)
        else:
            log.info(f"Using sample size of model params: {model.params.shape[0]}")
            params = model.params

        n_batch = params.shape[0]
        # ``inputs`` goes to the circuit positionally; leaving it in kwargs
        # would pass it twice and raise a TypeError.
        inputs = model._inputs_validation(kwargs.pop("inputs", None))

        # Execute: vmap over batch dimension of params (axis 0)
        if n_batch > 1:
            from qml_essentials.utils import safe_random_split

            random_keys = safe_random_split(random_key, num=n_batch)
            result = bell_script.execute(
                type="probs",
                args=(params, inputs, model.pulse_params, random_keys),
                kwargs=kwargs,
                in_axes=(0, None, None, 0),
            )
        else:
            result = bell_script.execute(
                type="probs",
                args=(params, inputs, model.pulse_params, random_key),
                kwargs=kwargs,
            )

        # Marginalize: for each qubit q, keep wires [q, q+n] from the 2n-qubit
        # probs. The last probability in each pair gives P(|11⟩) for that pair.
        pair_probs = jnp.stack(
            [ys.marginalize_probs(result, 2 * n, [q, q + n]) for q in range(n)],
            axis=-2,
        )  # (..., n, 4)
        exp = 1 - 2 * pair_probs[..., -1]  # (..., n)

        if not jnp.isclose(jnp.sum(exp.imag), 0, atol=1e-6):
            log.warning("Imaginary part of probabilities detected")
            exp = jnp.abs(exp)

        # Average over qubits (last axis) to get one value per sample, so the
        # logged variance is across samples rather than across qubits.
        measure = 2 * (1 - exp.mean(axis=-1))
        entangling_capability = min(max(float(measure.mean()), 0.0), 1.0)
        log.debug(f"Variance of measure: {measure.var()}")

        return entangling_capability

    @staticmethod
    def relative_entropy(
        model: Model,
        n_samples: int,
        n_sigmas: int,
        random_key: Optional[jax.random.PRNGKey] = None,
        scale: bool = False,
        **kwargs: Any,
    ) -> float:
        """
        Calculates the relative entropy of entanglement of a given quantum
        circuit. This measure is also applicable to mixed states, although it
        might not be fully accurate in this simplified case.

        The relative entropy of entanglement is defined as the smallest
        relative entropy from the state in question to the set of separable
        states. As computing the nearest separable state is NP-hard, we
        instead compare against ``n_sigmas`` random separable states, none of
        which is necessarily the nearest one. Thus, this measure presents an
        upper bound of the entanglement.

        As the relative entropy is not necessarily between zero and one, this
        function also normalises by the relative entropy of the GHZ state to
        the same separable states.

        Args:
            model (Model): The quantum circuit model.
            n_samples (int): Number of parameter samples.
                If None or <= 0, the current parameters of the model are used.
            n_sigmas (int): Number of random separable pure states to compare
                against.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for
                parameter initialization. If None, uses the model's internal
                random key.
            scale (bool): Whether to multiply ``n_samples`` and ``n_sigmas``
                by ``2**n_qubits``.
            kwargs (Any): Additional keyword arguments for the model function.

        Returns:
            float: Mean normalised relative entropy over all samples
            (returned as a JAX scalar).
        """
        dim = 2**model.n_qubits
        if scale:
            if n_samples is not None:
                n_samples = dim * n_samples
            n_sigmas = dim * n_sigmas

        if random_key is None:
            random_key = model.random_key

        # Random separable states (base-2 matrix logarithms)
        log_sigmas = sample_random_separable_states(
            model.n_qubits, n_samples=n_sigmas, random_key=random_key, take_log=True
        )

        # Derive a new key for the parameter samples
        random_key, _ = jax.random.split(random_key)

        if n_samples is not None and n_samples > 0:
            model.initialize_params(random_key, repeat=n_samples)
        elif len(model.params.shape) <= 2:
            model.params = model.params.reshape(1, *model.params.shape)
        else:
            log.info(f"Using sample size of model params: {model.params.shape[0]}")

        rhos, log_rhos = Entanglement._compute_log_density(model, **kwargs)

        # rel_entropies[s, r]: relative entropy of sample r w.r.t. sigma s.
        # One sigma at a time keeps the intermediate arrays small.
        rel_entropies = jnp.stack(
            [
                Entanglement._compute_rel_entropies(rhos, log_rhos, log_sigma)
                for log_sigma in log_sigmas
            ]
        )

        # Reference: relative entropy of the GHZ state to the same sigmas
        ghz_model = Model(model.n_qubits, 1, "GHZ", data_reupload=False)
        rho_ghz, log_rho_ghz = Entanglement._compute_log_density(ghz_model, **kwargs)
        ghz_entropies = Entanglement._compute_rel_entropies(
            rho_ghz, log_rho_ghz, log_sigmas
        )

        normalised_entropies = rel_entropies / ghz_entropies

        # Closest sigma (minimum over the sigma axis) for each sample
        entangling_capability = normalised_entropies.min(axis=0)
        log.debug(f"Variance of measure: {entangling_capability.var()}")

        return entangling_capability.mean()

    @staticmethod
    def _compute_log_density(
        model: Model, **kwargs: Any
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """
        Obtains the density matrix of a model and computes its logarithm.

        Args:
            model (Model): The model for which to compute the density matrix.
            **kwargs (Any): Additional keyword arguments for the model call.

        Returns:
            Tuple[jnp.ndarray, jnp.ndarray]:
                - jnp.ndarray: density matrices, shape (R, 2^n, 2^n).
                - jnp.ndarray: base-2 matrix logarithms, same shape.
        """
        # implicitly set input to none in case it's not needed
        kwargs.setdefault("inputs", None)
        # explicitly set execution type because everything else won't work
        dim = 2**model.n_qubits
        rho = model(execution_type="density", **kwargs).reshape(-1, dim, dim)
        log_rho = logm_v(rho) / jnp.log(2)
        return rho, log_rho

    @staticmethod
    def _compute_rel_entropies(
        rhos: jnp.ndarray,
        log_rhos: jnp.ndarray,
        log_sigmas: jnp.ndarray,
    ) -> jnp.ndarray:
        """
        Computes |Tr(rho (log rho - log sigma))| for every (sigma, rho) pair.

        Args:
            rhos (jnp.ndarray): Density matrix result of the circuit, has shape
                (R, 2^n, 2^n), with the batch size R and number of qubits n.
            log_rhos (jnp.ndarray): Corresponding logarithm of the density
                matrix, has shape (R, 2^n, 2^n).
            log_sigmas (jnp.ndarray): Logarithm of the separable state(s), has
                shape (2^n, 2^n) for a single sigma or (S, 2^n, 2^n) for a
                batch of S sigmas.

        Returns:
            jnp.ndarray: Relative entropies with shape (R,) for a single sigma
            or a batch with S == 1, and shape (S, R) for S > 1.
        """
        n_rhos = rhos.shape[0]
        if log_sigmas.ndim == 3:
            n_sigmas = log_sigmas.shape[0]
            # Sigma-major layout: flat index s * R + r pairs sigma s with rho r,
            # so the reshape to (S, R) below lines up.
            rhos = jnp.tile(rhos, (n_sigmas, 1, 1))
            log_rhos = jnp.tile(log_rhos, (n_sigmas, 1, 1))
            log_sigmas = jnp.repeat(log_sigmas, n_rhos, axis=0)
        else:
            n_sigmas = 1
            # A single (2^n, 2^n) sigma broadcasts against every rho.

        prod = rhos @ (log_rhos - log_sigmas)
        rel_entropies = jnp.abs(jnp.trace(prod, axis1=-2, axis2=-1))

        if n_sigmas > 1:
            rel_entropies = rel_entropies.reshape(n_sigmas, n_rhos)
        return rel_entropies

    @staticmethod
    def entanglement_of_formation(
        model: Model,
        n_samples: int,
        random_key: Optional[jax.random.PRNGKey] = None,
        scale: bool = False,
        always_decompose: bool = False,
        **kwargs: Any,
    ) -> float:
        """
        This function implements the entanglement of formation for mixed
        quantum systems.
        A mixed state is decomposed into pure states with respective
        probabilities using the eigendecomposition of the density matrix.
        Then, the Meyer-Wallach measure is computed for each pure state,
        weighted by its eigenvalue.
        See e.g. https://doi.org/10.48550/arXiv.quant-ph/0504163

        Note that the decomposition is *not unique*! Therefore, this measure
        presents the entanglement for *some* decomposition into pure states,
        not necessarily the one that is anticipated when applying the Kraus
        channels.
        If a pure state is provided, this results in the same value as the
        Entanglement.meyer_wallach function if `always_decompose` flag is not set.

        Args:
            model (Model): The quantum circuit model.
            n_samples (int): Number of parameter samples.
                If None or <= 0, the current parameters of the model are used.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for
                parameter initialization. If None, uses the model's internal
                random key.
            scale (bool): Whether to multiply ``n_samples`` by ``2**n_qubits``.
            always_decompose (bool): Whether to explicitly compute the
                entanglement of formation from the eigendecomposition even
                for pure states.
            kwargs (Any): Additional keyword arguments for the model function.

        Returns:
            float: Mean entanglement over all samples (returned as a JAX
            scalar).
        """
        if scale and n_samples is not None:
            n_samples = 2**model.n_qubits * n_samples

        if n_samples is not None and n_samples > 0:
            model.initialize_params(random_key, repeat=n_samples)
        elif len(model.params.shape) <= 2:
            model.params = model.params.reshape(1, *model.params.shape)
        else:
            log.info(f"Using sample size of model params: {model.params.shape[0]}")

        # implicitly set input to none in case it's not needed
        kwargs.setdefault("inputs", None)
        dim = 2**model.n_qubits
        rhos = model(execution_type="density", **kwargs).reshape(-1, dim, dim)
        ent = Entanglement._compute_entanglement_of_formation(
            rhos, model.n_qubits, always_decompose
        )
        return ent.mean()

    @staticmethod
    def _compute_entanglement_of_formation(
        rhos: jnp.ndarray,
        n_qubits: int,
        always_decompose: bool,
    ) -> jnp.ndarray:
        """
        Computes the entanglement of formation for a given batch of density
        matrices.

        Args:
            rhos (jnp.ndarray): The density matrices, has shape (B_s, 2^n, 2^n),
                where B_s is the batch size and n the number of qubits.
            n_qubits (int): Number of qubits.
            always_decompose (bool): Whether to explicitly compute the
                entanglement of formation from the eigendecomposition even
                for pure states.

        Returns:
            jnp.ndarray: Entanglement for the provided density matrices,
            shape (B_s,).
        """
        eigenvalues, eigenvectors = jnp.linalg.eigh(rhos)
        if not always_decompose and jnp.isclose(eigenvalues, 1.0).any(axis=-1).all():
            # Every state is pure (one eigenvalue equal to 1): no decomposition.
            return Entanglement._compute_meyer_wallach_meas(rhos, n_qubits)

        # eigh returns eigenvectors as *columns*: eigenvectors[s, :, i]
        # belongs to eigenvalues[s, i]. Build the projectors |v_i><v_i|.
        projectors = jnp.einsum("sji,ski->sijk", eigenvectors, eigenvectors.conj())
        dim = 2**n_qubits
        measures = Entanglement._compute_meyer_wallach_meas(
            projectors.reshape(-1, dim, dim), n_qubits
        ).reshape(-1, dim)
        # Weight each pure component by its eigenvalue (probability)
        return jnp.einsum("si,si->s", measures, eigenvalues)

    @staticmethod
    def concentratable_entanglement(
        model: Model,
        n_samples: int,
        random_key: Optional[jax.random.PRNGKey] = None,
        scale: bool = False,
        **kwargs: Any,
    ) -> float:
        """
        Computes the concentratable entanglement of a given model.

        This method utilizes the Concentratable Entanglement measure from
        https://arxiv.org/abs/2104.06923. The swap test is implemented
        directly in yaqsi using a ``3 * n_qubits`` circuit.

        Args:
            model (Model): The quantum circuit model.
            n_samples (int): The number of samples to compute the measure for.
                If None or <= 0, the current parameters of the model are used.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for
                parameter initialization. If None, uses the model's internal
                random key.
            scale (bool): Whether to scale the number of samples according to
                the number of qubits.
            **kwargs (Any): Additional keyword arguments for the model function.
                ``inputs`` is consumed here and passed to the circuit
                positionally.

        Returns:
            float: Entangling capability of the given circuit, guaranteed
            to be between 0.0 and 1.0.
        """
        n = model.n_qubits

        if scale and n_samples is not None:
            n_samples = 2**n * n_samples

        def _swap_test_circuit(
            params, inputs, pulse_params=None, random_key=None, **kw
        ):
            """Swap-test circuit on 3*n qubits."""
            from qml_essentials.tape import recording as _recording

            # Two copies of the model state: wires n..2n-1 and 2n..3n-1.
            # Each copy is recorded on its own tape and re-registered with
            # shifted wires.
            for offset in (n, 2 * n):
                with _recording() as copy_tape:
                    model._variational(
                        params,
                        inputs,
                        pulse_params=pulse_params,
                        random_key=random_key,
                        **kw,
                    )
                Entanglement._append_shifted(copy_tape, offset)

            # Swap test: H on ancilla register (wires 0..n-1)
            for i in range(n):
                op.H(wires=i)

            for i in range(n):
                op.CSWAP(wires=[i, i + n, i + 2 * n])

            for i in range(n):
                op.H(wires=i)

        swap_script = ys.Script(f=_swap_test_circuit, n_qubits=3 * n)

        if n_samples is not None and n_samples > 0:
            random_key = model.initialize_params(random_key, repeat=n_samples)
        elif len(model.params.shape) <= 2:
            model.params = model.params.reshape(1, *model.params.shape)
        else:
            log.info(f"Using sample size of model params: {model.params.shape[0]}")

        params = model.params
        # ``inputs`` goes to the circuit positionally; leaving it in kwargs
        # would pass it twice and raise a TypeError.
        inputs = model._inputs_validation(kwargs.pop("inputs", None))
        n_batch = params.shape[0]

        if n_batch > 1:
            from qml_essentials.utils import safe_random_split

            random_keys = safe_random_split(random_key, num=n_batch)
            probs = swap_script.execute(
                type="probs",
                args=(params, inputs, model.pulse_params, random_keys),
                in_axes=(0, None, None, 0),
                kwargs=kwargs,
            )
        else:
            probs = swap_script.execute(
                type="probs",
                args=(params, inputs, model.pulse_params, random_key),
                kwargs=kwargs,
            )

        # Marginalize to the ancilla register (wires 0..n-1)
        marg_probs = jax.jit(ys.marginalize_probs, static_argnums=(1, 2))
        probs = marg_probs(probs, 3 * n, tuple(range(n)))

        # CE = 1 - P(ancilla register measured in |0...0>)
        ent = 1 - probs[..., 0]

        log.debug(f"Variance of measure: {ent.var()}")

        return float(ent.mean())

602 

603 

def sample_random_separable_states(
    n_qubits: int,
    n_samples: int,
    random_key: jax.random.PRNGKey,
    take_log: bool = False,
) -> jnp.ndarray:
    """
    Sample random separable states (density matrix).

    The states come from a single-layer "No_Entangling" model with freshly
    initialised parameters. That ansatz is assumed to contain no entangling
    gates, so each sample is a product state.

    Args:
        n_qubits (int): number of qubits in the state
        n_samples (int): number of states
        random_key (random.PRNGKey): JAX random key
        take_log (bool): if the (base-2) matrix logarithm of the density
            matrix should be returned instead.

    Returns:
        jnp.ndarray: Density matrices (or their logarithms) of shape
        (n_samples, 2**n_qubits, 2**n_qubits).
    """
    model = Model(n_qubits, 1, "No_Entangling", data_reupload=False)
    model.initialize_params(random_key, repeat=n_samples)
    dim = 2**n_qubits
    # explicitly set execution type because everything else won't work.
    # Flatten to (n_samples, dim, dim) so callers can always iterate over the
    # first axis, matching the documented return shape.
    sigmas = model(execution_type="density", inputs=None).reshape(-1, dim, dim)
    if take_log:
        sigmas = logm_v(sigmas) / jnp.log(2.0 + 0j)

    return sigmas