Coverage for qml_essentials/entanglement.py: 94%
205 statements
« prev ^ index » next coverage.py v7.9.2, created at 2026-02-20 14:03 +0000
« prev ^ index » next coverage.py v7.9.2, created at 2026-02-20 14:03 +0000
1from typing import Optional, Any, List, Tuple
2import pennylane as qml
3import jax
4import jax.numpy as jnp
5import numpy as np
7from qml_essentials.utils import logm_v
8from qml_essentials.model import Model
9import logging
11log = logging.getLogger(__name__)
class Entanglement:
    """Entanglement measures evaluated on the circuits of a ``Model``.

    All methods are static. The public measures take a ``Model``, optionally
    re-initialise its parameters with ``n_samples`` random samples, and return
    the mean measure over the resulting batch.
    """

    @staticmethod
    def meyer_wallach(
        model: Model,
        n_samples: Optional[int],
        seed: Optional[int],
        scale: bool = False,
        **kwargs: Any,
    ) -> float:
        """
        Calculates the entangling capacity of a given quantum circuit
        using the Meyer-Wallach measure.

        Args:
            model (Model): The quantum circuit model.
            n_samples (Optional[int]): Number of parameter samples.
                If None or <= 0, the current parameters of the model are used.
            seed (Optional[int]): Seed for the random number generator.
                Required when n_samples > 0, ignored otherwise.
            scale (bool): Whether to multiply n_samples by 2**n_qubits.
            kwargs (Any): Additional keyword arguments for the model function.

        Returns:
            float: Entangling capacity of the given circuit, guaranteed
            to be between 0.0 and 1.0.
        """
        if "noise_params" in kwargs:
            log.warning(
                "Meyer-Wallach measure not suitable for noisy circuits. "
                "Consider 'relative_entropy' instead."
            )

        if scale and n_samples is not None:
            n_samples = 2**model.n_qubits * n_samples

        if n_samples is not None and n_samples > 0:
            assert seed is not None, "Seed must be provided when samples > 0"
            # The key is only created when sampling: jax.random.key(None) raises.
            model.initialize_params(jax.random.key(seed), repeat=n_samples)
        elif seed is not None:
            log.warning("Seed is ignored when samples is 0")

        # implicitly set input to none in case it's not needed
        kwargs.setdefault("inputs", None)
        # explicitly set execution type because everything else won't work
        rhos = model(execution_type="density", **kwargs).reshape(
            -1, 2**model.n_qubits, 2**model.n_qubits
        )

        ent = Entanglement._compute_meyer_wallach_meas(
            rhos, model.n_qubits, model.use_multithreading
        )

        log.debug(f"Variance of measure: {ent.var()}")

        return ent.mean()

    @staticmethod
    def _compute_meyer_wallach_meas(
        rhos: jnp.ndarray, n_qubits: int, use_multithreading: bool = False
    ) -> jnp.ndarray:
        """
        Computes the Meyer-Wallach entangling capability measure for a given
        set of density matrices.

        Args:
            rhos (jnp.ndarray): Density matrices of the sample quantum states.
                The shape is (B_s, 2^n, 2^n), where B_s is the number of samples
                (batch) and n the number of qubits
            n_qubits (int): The number of qubits
            use_multithreading (bool): Whether to use JAX vectorisation
                (jax.vmap over the batch axis).

        Returns:
            jnp.ndarray: Entangling capability for each sample, array with
                shape (B_s,)
        """
        qubits = list(range(n_qubits))

        def _f(rhos):
            purity_sum = 0
            for j in range(n_qubits):
                # Formula 6 in https://doi.org/10.48550/arXiv.quant-ph/0305094
                # Reduced density matrix of qubit j (all other qubits traced out)
                density = qml.math.partial_trace(rhos, qubits[:j] + qubits[j + 1 :])
                # Purity tr(rho_j^2). Only the real part is kept; the imaginary
                # part is not used in any of the following calculations.
                # For a single qubit: 1/2 <= purity <= 1
                purity_sum += jnp.trace((density @ density).real, axis1=-2, axis2=-1)

            # inverse averaged purity, scaled to [0, 1]
            return 2 * (1 - purity_sum / n_qubits)

        if use_multithreading:
            return jax.vmap(_f)(rhos)
        return _f(rhos)

    @staticmethod
    def bell_measurements(
        model: Model, n_samples: int, seed: int, scale: bool = False, **kwargs: Any
    ) -> float:
        """
        Compute the Bell measurement for a given model.

        Two identical copies of the model's variational circuit are prepared on
        wires [0, n) and [n, 2n), followed by a CNOT/H pair per qubit pair.
        The model's ``output_qubit`` and ``circuit`` are temporarily replaced
        and are restored before returning, also when an exception is raised.

        Args:
            model (Model): The quantum circuit model.
            n_samples (int): The number of samples to compute the measure for.
                If None or <= 0, the current parameters of the model are used.
            seed (int): The seed for the random number generator.
                Required when n_samples > 0, ignored otherwise.
            scale (bool): Whether to scale the number of samples
                according to the number of qubits (multiply by 2**n_qubits).
            **kwargs (Any): Additional keyword arguments for the model function.

        Returns:
            float: The Bell measurement value, clipped to [0.0, 1.0].
        """
        if "noise_params" in kwargs:
            log.warning(
                "Bell Measurements not suitable for noisy circuits. "
                "Consider 'relative_entropy' instead."
            )

        if scale and n_samples is not None:
            n_samples = 2**model.n_qubits * n_samples

        def _circuit(
            params: jnp.ndarray, inputs: jnp.ndarray, **circuit_kwargs
        ) -> List[jnp.ndarray]:
            """
            Compute the Bell measurement circuit.

            Args:
                params (jnp.ndarray): The model parameters.
                inputs (jnp.ndarray): The input to the model.
                **circuit_kwargs: Forwarded unchanged to both copies of
                    ``model._variational``.

            Returns:
                List[jnp.ndarray]: The probabilities of the Bell measurement.
            """
            # first copy on wires [0, n)
            model._variational(params, inputs, **circuit_kwargs)

            # second copy on wires [n, 2n); it must receive the same kwargs,
            # otherwise the two copies are not identical
            qml.map_wires(
                model._variational,
                {i: i + model.n_qubits for i in range(model.n_qubits)},
            )(params, inputs, **circuit_kwargs)

            for q in range(model.n_qubits):
                qml.CNOT(wires=[q, q + model.n_qubits])
                qml.H(q)

            # look at the auxiliary qubits
            return model._observable()

        prev_output_qubit = model.output_qubit
        prev_circuit = model.circuit
        model.output_qubit = [(q, q + model.n_qubits) for q in range(model.n_qubits)]
        model.circuit = qml.QNode(
            _circuit,
            qml.device(
                "default.qubit",
                shots=model.shots,
                wires=model.n_qubits * 2,
            ),
        )

        try:
            if n_samples is not None and n_samples > 0:
                assert seed is not None, "Seed must be provided when samples > 0"
                model.initialize_params(jax.random.key(seed), repeat=n_samples)
                params = model.params
            else:
                if seed is not None:
                    log.warning("Seed is ignored when samples is 0")

                if len(model.params.shape) <= 2:
                    # add a trailing batch axis of size 1
                    params = model.params.reshape(*model.params.shape, 1)
                else:
                    log.info(
                        f"Using sample size of model params: {model.params.shape[-1]}"
                    )
                    params = model.params

            # implicitly set input to none in case it's not needed
            kwargs.setdefault("inputs", None)
            exp = model(params=params, execution_type="probs", **kwargs)
            # uses the last probability entry of each measured output
            exp = 1 - 2 * exp[..., -1]

            if not jnp.isclose(jnp.sum(exp.imag), 0, atol=1e-6):
                log.warning("Imaginary part of probabilities detected")
                exp = jnp.abs(exp)

            # NOTE(review): assumes axis 0 indexes the qubit pairs and the
            # remaining axis the samples — confirm against Model's output shape
            measure = 2 * (1 - exp.mean(axis=0))
            entangling_capability = min(max(measure.mean(), 0.0), 1.0)
            log.debug(f"Variance of measure: {measure.var()}")
        finally:
            # restore state, also when the evaluation failed
            model.output_qubit = prev_output_qubit
            model.circuit = prev_circuit

        return float(entangling_capability)

    @staticmethod
    def relative_entropy(
        model: Model,
        n_samples: int,
        n_sigmas: int,
        seed: Optional[int],
        scale: bool = False,
        **kwargs: Any,
    ) -> float:
        """
        Calculates the relative entropy of entanglement of a given quantum
        circuit. This measure is also applicable to mixed states, albeit it
        might not be fully accurate in this simplified case.

        The relative entropy of entanglement is generally defined as the
        smallest relative entropy from the state in question to the set of
        separable states. However, as computing the nearest separable state is
        NP-hard, we select n_sigmas random separable states to compute the
        distance to, which are not necessarily the nearest. Thus, this measure
        of entanglement presents an upper limit of entanglement.

        As the relative entropy is not necessarily between zero and one, this
        function also normalises by the relative entropy of the GHZ state.

        Args:
            model (Model): The quantum circuit model.
            n_samples (int): Number of parameter samples.
                If None or <= 0, the current parameters of the model are used.
            n_sigmas (int): Number of random separable pure states to compare
                against.
            seed (Optional[int]): Seed for the random number generator. It is
                always used to sample the separable reference states and must
                therefore not be None.
            scale (bool): Whether to multiply n_samples and n_sigmas by
                2**n_qubits.
            kwargs (Any): Additional keyword arguments for the model function.

        Returns:
            float: Entangling capacity of the given circuit, guaranteed
            to be between 0.0 and 1.0.
        """
        dim = 2**model.n_qubits
        if scale:
            if n_samples is not None:
                n_samples = dim * n_samples
            n_sigmas = dim * n_sigmas

        random_key = jax.random.key(seed)

        # Random separable states
        log_sigmas = sample_random_separable_states(
            model.n_qubits, n_samples=n_sigmas, random_key=random_key, take_log=True
        )

        random_key, _ = jax.random.split(random_key)

        if n_samples is not None and n_samples > 0:
            assert seed is not None, "Seed must be provided when samples > 0"
            model.initialize_params(random_key, repeat=n_samples)
        else:
            if seed is not None:
                log.warning("Seed is ignored when samples is 0")

            if len(model.params.shape) <= 2:
                # add a trailing batch axis of size 1
                model.params = model.params.reshape(*model.params.shape, 1)
            else:
                log.info(f"Using sample size of model params: {model.params.shape[-1]}")

        rhos, log_rhos = Entanglement._compute_log_density(model, **kwargs)

        # rel_entropies[i, b]: relative entropy of sample b to sigma i
        rel_entropies = jnp.zeros((n_sigmas, model.params.shape[-1]))

        for i, log_sigma in enumerate(log_sigmas):
            rel_entropies = rel_entropies.at[i].set(
                Entanglement._compute_rel_entropies(
                    rhos, log_rhos, log_sigma, model.use_multithreading
                )
            )

        # Entropy of GHZ states should be maximal
        ghz_model = Model(model.n_qubits, 1, "GHZ", data_reupload=False)
        rho_ghz, log_rho_ghz = Entanglement._compute_log_density(ghz_model, **kwargs)
        ghz_entropies = Entanglement._compute_rel_entropies(
            rho_ghz, log_rho_ghz, log_sigmas, use_multithreading=False
        )

        normalised_entropies = rel_entropies / ghz_entropies

        # For each sample keep the smallest relative entropy over all sigmas
        # (closest separable state found), then average over the samples
        entangling_capability = normalised_entropies.T.min(axis=1)
        log.debug(f"Variance of measure: {entangling_capability.var()}")

        return entangling_capability.mean()

    @staticmethod
    def _compute_log_density(model: Model, **kwargs) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """
        Obtains the density matrix of a model and computes its logarithm.

        Args:
            model (Model): The model for which to compute the density matrix.
            **kwargs: Additional keyword arguments for the model function.

        Returns:
            Tuple[jnp.ndarray, jnp.ndarray]:
                - jnp.ndarray: density matrices, shape (B, 2^n, 2^n).
                - jnp.ndarray: base-2 matrix logarithm of the density matrices.
        """
        # implicitly set input to none in case it's not needed
        kwargs.setdefault("inputs", None)
        # explicitly set execution type because everything else won't work
        rho = model(execution_type="density", **kwargs)
        rho = rho.reshape(-1, 2**model.n_qubits, 2**model.n_qubits)
        log_rho = logm_v(rho) / jnp.log(2)
        return rho, log_rho

    @staticmethod
    def _compute_rel_entropies(
        rhos: jnp.ndarray,
        log_rhos: jnp.ndarray,
        log_sigmas: jnp.ndarray,
        use_multithreading: bool,
    ) -> jnp.ndarray:
        """
        Compute the relative entropy for a given model.

        Args:
            rhos (jnp.ndarray): Density matrix result of the circuit, has shape
                (R, 2^n, 2^n), with the batch size R and number of qubits n
            log_rhos (jnp.ndarray): Corresponding logarithm of the density
                matrix, has shape (R, 2^n, 2^n).
            log_sigmas (jnp.ndarray): Logarithm of the separable state(s),
                has shape (2^n, 2^n) if it's a single sigma or (S, 2^n, 2^n),
                with the batch size S (number of sigmas).
            use_multithreading (bool): Whether to use JAX vectorisation.

        Returns:
            jnp.ndarray: Relative entropy for each sample; shape (R,) for a
                single sigma, (S, R) for S > 1 sigmas.
        """
        n_rhos = rhos.shape[0]
        if log_sigmas.ndim == 3:
            n_sigmas = log_sigmas.shape[0]
            # Pair every rho with every sigma; flat index is s * n_rhos + r,
            # matching the (n_sigmas, n_rhos) reshape below.
            rhos = jnp.tile(rhos, (n_sigmas, 1, 1))
            log_rhos = jnp.tile(log_rhos, (n_sigmas, 1, 1))
            log_sigmas = jnp.repeat(log_sigmas, n_rhos, axis=0)
        else:
            n_sigmas = 1
            log_sigmas = jnp.repeat(log_sigmas[jnp.newaxis, ...], n_rhos, axis=0)

        # vmap strips the batch axis, so the per-item einsum has no "s" index
        einsum_subscript = "ij,jk->ik" if use_multithreading else "sij,sjk->sik"

        def _f(rhos, log_rhos, log_sigmas):
            prod = jnp.einsum(einsum_subscript, rhos, log_rhos - log_sigmas)
            return jnp.abs(jnp.trace(prod, axis1=-2, axis2=-1))

        if use_multithreading:
            rel_entropies = jax.vmap(_f, in_axes=(0, 0, 0))(rhos, log_rhos, log_sigmas)
        else:
            rel_entropies = _f(rhos, log_rhos, log_sigmas)

        if n_sigmas > 1:
            rel_entropies = rel_entropies.reshape(n_sigmas, n_rhos)
        return rel_entropies

    @staticmethod
    def entanglement_of_formation(
        model: Model,
        n_samples: int,
        seed: Optional[int],
        scale: bool = False,
        always_decompose: bool = False,
        **kwargs: Any,
    ) -> float:
        """
        This function implements the entanglement of formation for mixed
        quantum systems.
        In that a mixed state gets decomposed into pure states with respective
        probabilities using the eigendecomposition of the density matrix.
        Then, the Meyer-Wallach measure is computed for each pure state,
        weighted by the eigenvalue.
        See e.g. https://doi.org/10.48550/arXiv.quant-ph/0504163

        Note that the decomposition is *not unique*! Therefore, this measure
        presents the entanglement for *some* decomposition into pure states,
        not necessarily the one that is anticipated when applying the Kraus
        channels.
        If a pure state is provided, this results in the same value as the
        Entanglement.meyer_wallach function if `always_decompose` flag is not set.

        Args:
            model (Model): The quantum circuit model.
            n_samples (int): Number of parameter samples.
                If None or <= 0, the current parameters of the model are used.
            seed (Optional[int]): Seed for the random number generator.
                Required when n_samples > 0, ignored otherwise.
            scale (bool): Whether to multiply n_samples by 2**n_qubits.
            always_decompose (bool): Whether to explicitly compute the
                entanglement of formation for the eigendecomposition of a pure
                state.
            kwargs (Any): Additional keyword arguments for the model function.

        Returns:
            float: Entangling capacity of the given circuit, guaranteed
            to be between 0.0 and 1.0.
        """
        if scale and n_samples is not None:
            n_samples = 2**model.n_qubits * n_samples

        if n_samples is not None and n_samples > 0:
            assert seed is not None, "Seed must be provided when samples > 0"
            model.initialize_params(jax.random.key(seed), repeat=n_samples)
        else:
            if seed is not None:
                log.warning("Seed is ignored when samples is 0")

            if len(model.params.shape) <= 2:
                # add a trailing batch axis of size 1
                model.params = model.params.reshape(*model.params.shape, 1)
            else:
                log.info(f"Using sample size of model params: {model.params.shape[-1]}")

        # implicitly set input to none in case it's not needed
        kwargs.setdefault("inputs", None)
        rhos = model(execution_type="density", **kwargs)
        rhos = rhos.reshape(-1, 2**model.n_qubits, 2**model.n_qubits)
        ent = Entanglement._compute_entanglement_of_formation(
            rhos, model.n_qubits, always_decompose, model.use_multithreading
        )
        return ent.mean()

    @staticmethod
    def _compute_entanglement_of_formation(
        rhos: jnp.ndarray,
        n_qubits: int,
        always_decompose: bool,
        use_multithreading: bool,
    ) -> jnp.ndarray:
        """
        Computes the entanglement of formation for a given batch of density
        matrices.

        Args:
            rhos (jnp.ndarray): The density matrices, has shape (B_s, 2^n, 2^n),
                where B_s is the batch size and n the number of qubits.
            n_qubits (int): Number of qubits
            always_decompose (bool): Whether to explicitly compute the
                entanglement of formation for the eigendecomposition of a pure
                state.
            use_multithreading (bool): Whether to use JAX vectorisation.

        Returns:
            jnp.ndarray: Entanglement for the provided density matrices,
                shape (B_s,).
        """
        eigenvalues, eigenvectors = jnp.linalg.eigh(rhos)
        # all states pure (one eigenvalue == 1) -> plain Meyer-Wallach
        if not always_decompose and jnp.isclose(eigenvalues, 1.0).any(axis=-1).all():
            return Entanglement._compute_meyer_wallach_meas(
                rhos, n_qubits, use_multithreading
            )

        dim = 2**n_qubits
        # eigh returns eigenvectors as COLUMNS: eigenvectors[s, :, i] belongs to
        # eigenvalues[s, i]. Build the projector |v_i><v_i| for each of them:
        # pure_states[s, i, j, k] = v_i[j] * conj(v_i[k])
        pure_states = jnp.einsum(
            "sji,ski->sijk", eigenvectors, eigenvectors.conjugate()
        )
        measures = Entanglement._compute_meyer_wallach_meas(
            pure_states.reshape(-1, dim, dim), n_qubits, use_multithreading
        )
        # weight each pure-state measure by its eigenvalue (probability)
        return jnp.einsum("si,si->s", measures.reshape(-1, dim), eigenvalues)

    @staticmethod
    def concentratable_entanglement(
        model: Model, n_samples: int, seed: int, scale: bool = False, **kwargs: Any
    ) -> float:
        """
        Computes the concentratable entanglement of a given model.

        This method utilizes the Concentratable Entanglement measure from
        https://arxiv.org/abs/2104.06923.

        Args:
            model (Model): The quantum circuit model.
            n_samples (int): The number of samples to compute the measure for.
                If None or <= 0, the current parameters of the model are used.
            seed (int): The seed for the random number generator.
                Required when n_samples > 0, ignored otherwise.
            scale (bool): Whether to scale the number of samples according to
                the number of qubits (multiply by 2**n_qubits).
            **kwargs (Any): Additional keyword arguments for the model function.
                An ``inputs`` entry is passed through
                ``model._inputs_validation``.

        Returns:
            float: Entangling capability of the given circuit, guaranteed
            to be between 0.0 and 1.0.
        """
        n = model.n_qubits
        N = 2**n

        if scale and n_samples is not None:
            n_samples = N * n_samples

        dev = qml.device(
            "default.mixed",
            shots=model.shots,
            wires=n * 3,
        )

        @qml.qnode(device=dev)
        def _swap_test(
            params: jnp.ndarray, inputs: jnp.ndarray, **circuit_kwargs
        ) -> jnp.ndarray:
            """
            Constructs the swap-test circuit: two copies of the model's state
            are prepared on wires [n, 2n) and [2n, 3n), with wires [0, n)
            acting as the control qubits.

            Args:
                params (jnp.ndarray): The model parameters.
                inputs (jnp.ndarray): The (validated) input to the model.
                **circuit_kwargs: Forwarded to both copies of
                    ``model._variational``.

            Returns:
                jnp.ndarray: Probabilities of the control wires [0, n).
            """
            qml.map_wires(model._variational, wire_map={o: o + n for o in range(n)})(
                params, inputs, **circuit_kwargs
            )
            qml.map_wires(
                model._variational, wire_map={o: o + 2 * n for o in range(n)}
            )(params, inputs, **circuit_kwargs)

            # Perform swap test
            for i in range(n):
                qml.H(i)

            for i in range(n):
                qml.CSWAP([i, i + n, i + 2 * n])

            for i in range(n):
                qml.H(i)

            return qml.probs(wires=list(range(n)))

        if n_samples is not None and n_samples > 0:
            assert seed is not None, "Seed must be provided when samples > 0"
            model.initialize_params(jax.random.key(seed), repeat=n_samples)
        else:
            if seed is not None:
                log.warning("Seed is ignored when samples is 0")

            if len(model.params.shape) <= 2:
                # add a trailing batch axis of size 1
                model.params = model.params.reshape(*model.params.shape, 1)
            else:
                log.info(f"Using sample size of model params: {model.params.shape[-1]}")

        # `inputs` is passed positionally to _swap_test; leaving it in kwargs
        # would raise "got multiple values for argument 'inputs'".
        inputs = model._inputs_validation(kwargs.pop("inputs", None))

        def _f(params):
            probs = _swap_test(params, inputs, **kwargs)
            # concentratable entanglement: 1 - probability of all controls in |0>
            return 1 - probs[..., 0]

        if model.use_multithreading:
            # params carry the sample batch on their last (third) axis
            ent = jax.vmap(_f, in_axes=2)(model.params)
        else:
            ent = _f(model.params)

        log.debug(f"Variance of measure: {ent.var()}")

        return ent.mean()
def sample_random_separable_states(
    n_qubits: int,
    n_samples: int,
    random_key: jax.Array,
    take_log: bool = False,
) -> jnp.ndarray:
    """
    Sample random separable states (density matrices).

    The states are produced by evaluating a "No_Entangling" ansatz with
    randomly initialised parameters (presumably an ansatz without entangling
    gates, as the name suggests).

    Args:
        n_qubits (int): number of qubits in the state
        n_samples (int): number of states
        random_key (jax.Array): JAX random key used to initialise the
            ansatz parameters
        take_log (bool): if the base-2 matrix logarithm of the density
            matrices should be returned instead.

    Returns:
        jnp.ndarray: Density matrices (or their base-2 logarithms) of shape
        (n_samples, 2**n_qubits, 2**n_qubits)
    """
    dim = 2**n_qubits
    model = Model(n_qubits, 1, "No_Entangling", data_reupload=False)
    model.initialize_params(random_key, repeat=n_samples)
    # explicitly set execution type because everything else won't work
    sigmas = model(execution_type="density", inputs=None)
    # Guarantee the documented leading batch axis (also for n_samples == 1),
    # so callers can always iterate over individual states.
    sigmas = sigmas.reshape(-1, dim, dim)
    if take_log:
        sigmas = logm_v(sigmas) / jnp.log(2.0 + 0j)

    return sigmas