Coverage for qml_essentials / entanglement.py: 94%
217 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-30 11:43 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-30 11:43 +0000
1from typing import Optional, Any, Tuple
2import jax
3import jax.numpy as jnp
4import numpy as np
6from qml_essentials import yaqsi as ys
7from qml_essentials import operations as op
8from qml_essentials.math import logm_v
9from qml_essentials.model import Model
10import logging
# Logger shared by every measure in this module, named after the module.
log = logging.getLogger(name=__name__)
class Entanglement:
    """
    Collection of entanglement measures for quantum circuit models.

    Every measure is a static method that either samples fresh parameters
    (``n_samples > 0``) or reuses the model's current parameters, evaluates
    the model, and returns the measure averaged over all parameter samples.
    """

    @staticmethod
    def meyer_wallach(
        model: Model,
        n_samples: Optional[int],
        random_key: Optional[jax.random.PRNGKey] = None,
        scale: bool = False,
        **kwargs: Any,
    ) -> float:
        """
        Calculates the entangling capacity of a given quantum circuit
        using the Meyer-Wallach measure.

        Args:
            model (Model): The quantum circuit model.
            n_samples (Optional[int]): Number of parameter samples.
                If None or <= 0, the current parameters of the model are used.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for
                parameter initialization. If None, uses the model's internal
                random key.
            scale (bool): Whether to multiply the number of samples by
                ``2**n_qubits`` (ignored if ``n_samples`` is None).
            kwargs (Any): Additional keyword arguments for the model function.

        Returns:
            float: Entangling capacity of the given circuit, averaged over
            all samples; between 0.0 and 1.0 for pure states.
        """
        if "noise_params" in kwargs:
            log.warning(
                "Meyer-Wallach measure not suitable for noisy circuits. "
                "Consider 'concentratable entanglement' instead."
            )

        # Python int arithmetic; also tolerates n_samples=None with scale=True
        if scale and n_samples is not None:
            n_samples = 2**model.n_qubits * n_samples

        if n_samples is not None and n_samples > 0:
            model.initialize_params(random_key, repeat=n_samples)

        dim = 2**model.n_qubits
        # implicitly set input to none in case it's not needed
        kwargs.setdefault("inputs", None)
        # explicitly set execution type because everything else won't work
        rhos = model(execution_type="density", **kwargs).reshape(-1, dim, dim)

        ent = Entanglement._compute_meyer_wallach_meas(rhos, model.n_qubits)

        log.debug(f"Variance of measure: {ent.var()}")

        return ent.mean()

    @staticmethod
    def _compute_meyer_wallach_meas(rhos: jnp.ndarray, n_qubits: int) -> jnp.ndarray:
        """
        Computes the Meyer-Wallach entangling capability measure for a given
        set of density matrices.

        Args:
            rhos (jnp.ndarray): Density matrices of the sample quantum states.
                The shape is (B_s, 2^n, 2^n), where B_s is the number of samples
                (batch) and n the number of qubits
            n_qubits (int): The number of qubits

        Returns:
            jnp.ndarray: Entangling capability for each sample, array with
                shape (B_s,)
        """
        qubits = list(range(n_qubits))

        def _single(rho):
            purity_sum = 0
            for j in qubits:
                # Formula 6 in https://doi.org/10.48550/arXiv.quant-ph/0305094
                # Trace out qubit j, keep all others. For a pure state the
                # purity of this complement equals the purity of the
                # single-qubit reduced state of qubit j (Schmidt decomposition).
                keep = qubits[:j] + qubits[j + 1 :]
                reduced = ys.partial_trace(rho, n_qubits, keep)
                # only real values, because imaginary part will be separate
                # in all following calculations anyway
                # purity should be 1/2 <= purity <= 1
                purity_sum += jnp.trace((reduced @ reduced).real, axis1=-2, axis2=-1)

            # inverse averaged purity and scale to [0, 1]
            return 2 * (1 - purity_sum / n_qubits)

        return jax.vmap(_single)(rhos)

    @staticmethod
    def _shift_and_append(tape_ops, offset: int) -> None:
        """
        Re-register copies of the recorded operations ``tape_ops`` on the
        currently active tape, with every wire index shifted by ``offset``.
        Does nothing if no tape is active.

        Args:
            tape_ops: Iterable of recorded operations (as yielded by the
                ``qml_essentials.tape.recording`` context).
            offset (int): Amount added to every wire index.
        """
        from qml_essentials.tape import active_tape as _active_tape

        current = _active_tape()
        if current is None:
            return
        for o in tape_ops:
            # shallow copy that bypasses __init__, then relocate its wires
            shifted = o.__class__.__new__(o.__class__)
            shifted.__dict__.update(o.__dict__)
            shifted._wires = [w + offset for w in o.wires]
            current.append(shifted)

    @staticmethod
    def bell_measurements(
        model: Model,
        n_samples: Optional[int],
        random_key: Optional[jax.random.PRNGKey] = None,
        scale: bool = False,
        **kwargs: Any,
    ) -> float:
        """
        Compute the Bell measurement for a given model.

        Constructs a ``2 * n_qubits`` circuit that prepares two copies of
        the model state (on disjoint qubit registers), applies CNOTs and
        Hadamards, and measures probabilities on the first register.

        Args:
            model (Model): The quantum circuit model.
            n_samples (Optional[int]): The number of samples to compute the
                measure for. If None or <= 0, the current parameters of the
                model are used.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for
                parameter initialization. If None, uses the model's internal
                random key.
            scale (bool): Whether to scale the number of samples
                according to the number of qubits.
            **kwargs (Any): Additional keyword arguments for the model function.

        Returns:
            float: The Bell measurement value, clipped to [0.0, 1.0].
        """
        if "noise_params" in kwargs:
            log.warning(
                "Bell Measurements not suitable for noisy circuits. "
                "Consider 'concentratable entanglement' instead."
            )

        if scale and n_samples is not None:
            n_samples = 2**model.n_qubits * n_samples

        n = model.n_qubits

        def _bell_circuit(params, inputs, pulse_params=None, random_key=None, **kw):
            """Bell measurement circuit on 2*n qubits."""
            # First copy on wires 0..n-1
            model._variational(
                params, inputs, pulse_params=pulse_params, random_key=random_key, **kw
            )

            # TODO: this is very user-unfriendly and we should find a better way

            # Second copy on wires n..2n-1: record the tape then shift wires
            from qml_essentials.tape import recording as _recording

            with _recording() as shifted_tape:
                model._variational(
                    params,
                    inputs,
                    pulse_params=pulse_params,
                    random_key=random_key,
                    **kw,
                )
            Entanglement._shift_and_append(shifted_tape, n)

            # Bell measurement: CNOT + H
            for q in range(n):
                op.CX(wires=[q, q + n])
                op.H(wires=q)

        bell_script = ys.Script(f=_bell_circuit, n_qubits=2 * n)

        if n_samples is not None and n_samples > 0:
            random_key = model.initialize_params(random_key, repeat=n_samples)
            params = model.params
        elif len(model.params.shape) <= 2:
            params = model.params.reshape(1, *model.params.shape)
        else:
            log.info(f"Using sample size of model params: {model.params.shape[0]}")
            params = model.params

        n_samples = params.shape[0]
        # ``inputs`` is handed to the circuit positionally; remove it from the
        # forwarded kwargs so it is not supplied a second time by keyword.
        inputs = model._inputs_validation(kwargs.pop("inputs", None))

        # Execute: vmap over batch dimension of params (axis 0)
        if n_samples > 1:
            from qml_essentials.utils import safe_random_split

            random_keys = safe_random_split(random_key, num=n_samples)
            result = bell_script.execute(
                type="probs",
                args=(params, inputs, model.pulse_params, random_keys),
                kwargs=kwargs,
                in_axes=(0, None, None, 0),
            )
        else:
            result = bell_script.execute(
                type="probs",
                args=(params, inputs, model.pulse_params, random_key),
                kwargs=kwargs,
            )

        # Marginalize: for each qubit q, keep wires [q, q+n] from the 2n-qubit
        # probs. The last entry of each marginal is P(|11>) for that pair.
        per_qubit = [
            ys.marginalize_probs(result, 2 * n, [q, q + n]) for q in range(n)
        ]
        # per_qubit[q] has shape (n_samples, 4) or (4,)
        exp = jnp.stack(per_qubit, axis=-2)  # (..., n, 4)
        # 1 - 2 * P(|11>) per qubit: the single-qubit purity entering the
        # Meyer-Wallach formula below
        exp = 1 - 2 * exp[..., -1]  # (..., n)

        if not jnp.isclose(jnp.sum(exp.imag), 0, atol=1e-6):
            log.warning("Imaginary part of probabilities detected")
            exp = jnp.abs(exp)

        # Average over qubits (last axis) so ``measure`` holds one value per
        # sample; its variance is then the variance across samples.
        measure = 2 * (1 - exp.mean(axis=-1))
        entangling_capability = min(max(float(measure.mean()), 0.0), 1.0)
        log.debug(f"Variance of measure: {measure.var()}")

        return entangling_capability

    @staticmethod
    def relative_entropy(
        model: Model,
        n_samples: Optional[int],
        n_sigmas: int,
        random_key: Optional[jax.random.PRNGKey] = None,
        scale: bool = False,
        **kwargs: Any,
    ) -> float:
        """
        Calculates the relative entropy of entanglement of a given quantum
        circuit. This measure is also applicable to mixed states, although it
        might not be fully accurate in this simplified case.

        The relative entropy of entanglement is defined as the smallest
        relative entropy from the state in question to the set of separable
        states. However, as computing the nearest separable state is NP-hard,
        we select n_sigmas random separable states to compute the distance
        to, which are not necessarily the nearest. Thus, this measure of
        entanglement presents an upper limit of entanglement.

        As the relative entropy is not necessarily between zero and one, this
        function also normalises by the relative entropy of the GHZ state.

        Args:
            model (Model): The quantum circuit model.
            n_samples (Optional[int]): Number of parameter samples.
                If None or <= 0, the current parameters of the model are used.
            n_sigmas (int): Number of random separable pure states to compare against.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for
                parameter initialization. If None, uses the model's internal
                random key.
            scale (bool): Whether to multiply the number of samples and the
                number of sigmas by ``2**n_qubits``.
            kwargs (Any): Additional keyword arguments for the model function.

        Returns:
            float: Entangling capacity of the given circuit, normalised by the
            GHZ state's relative entropy and averaged over all samples.
        """
        dim = 2**model.n_qubits
        if scale:
            if n_samples is not None:
                n_samples = dim * n_samples
            n_sigmas = dim * n_sigmas

        if random_key is None:
            random_key = model.random_key

        # Split *before* any use so the keys for the separable states and for
        # the model parameters are independent (a consumed key is not reused).
        random_key, sigma_key = jax.random.split(random_key)

        # Random separable states (base-2 matrix logarithms), (S, dim, dim)
        log_sigmas = sample_random_separable_states(
            model.n_qubits, n_samples=n_sigmas, random_key=sigma_key, take_log=True
        ).reshape(-1, dim, dim)

        if n_samples is not None and n_samples > 0:
            model.initialize_params(random_key, repeat=n_samples)
        elif len(model.params.shape) <= 2:
            model.params = model.params.reshape(1, *model.params.shape)
        else:
            log.info(f"Using sample size of model params: {model.params.shape[0]}")

        rhos, log_rhos = Entanglement._compute_log_density(model, **kwargs)

        # rel_entropies[i, r]: relative entropy of sample r w.r.t. sigma i
        rel_entropies = jnp.stack(
            [
                Entanglement._compute_rel_entropies(rhos, log_rhos, log_sigma)
                for log_sigma in log_sigmas
            ]
        )

        # Entropy of GHZ states should be maximal
        ghz_model = Model(model.n_qubits, 1, "GHZ", data_reupload=False)
        rho_ghz, log_rho_ghz = Entanglement._compute_log_density(ghz_model, **kwargs)
        ghz_entropies = Entanglement._compute_rel_entropies(
            rho_ghz, log_rho_ghz, log_sigmas
        ).reshape(log_sigmas.shape[0], -1)  # (S, 1), one value per sigma

        normalised_entropies = rel_entropies / ghz_entropies

        # For every sample keep the closest sigma (minimum over sigmas), then
        # average over samples.
        entangling_capability = normalised_entropies.min(axis=0)
        log.debug(f"Variance of measure: {entangling_capability.var()}")

        return entangling_capability.mean()

    @staticmethod
    def _compute_log_density(model: Model, **kwargs) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """
        Obtains the density matrix of a model and computes its logarithm.

        Args:
            model (Model): The model for which to compute the density matrix.
            **kwargs: Forwarded to the model call (``inputs`` defaults to None).

        Returns:
            Tuple[jnp.ndarray, jnp.ndarray]:
                - jnp.ndarray: density matrices, shape (R, 2^n, 2^n).
                - jnp.ndarray: base-2 matrix logarithms, same shape.
        """
        dim = 2**model.n_qubits
        # implicitly set input to none in case it's not needed
        kwargs.setdefault("inputs", None)
        # explicitly set execution type because everything else won't work
        rho = model(execution_type="density", **kwargs).reshape(-1, dim, dim)
        log_rho = logm_v(rho) / jnp.log(2)
        return rho, log_rho

    @staticmethod
    def _compute_rel_entropies(
        rhos: jnp.ndarray,
        log_rhos: jnp.ndarray,
        log_sigmas: jnp.ndarray,
    ) -> jnp.ndarray:
        """
        Compute the relative entropies Tr[rho (log rho - log sigma)].

        Args:
            rhos (jnp.ndarray): Density matrix result of the circuit, has shape
                (R, 2^n, 2^n), with the batch size R and number of qubits n
            log_rhos (jnp.ndarray): Corresponding logarithm of the density
                matrix, has shape (R, 2^n, 2^n).
            log_sigmas (jnp.ndarray): Logarithm of the separable state(s), has
                shape (2^n, 2^n) if it's a single sigma or (S, 2^n, 2^n),
                with the batch size S (number of sigmas).

        Returns:
            jnp.ndarray: Relative entropy for every (sigma, rho) pair, shape
                (S, R) if S > 1, otherwise (R,).
        """
        n_rhos = rhos.shape[0]
        if log_sigmas.ndim == 3:
            n_sigmas = log_sigmas.shape[0]
            # Pair every rho with every sigma: the rhos are tiled as
            # (r_0..r_{R-1}, r_0..r_{R-1}, ...) and each sigma is repeated R
            # times to line up, giving sigma-major order (S * R, ...).
            rhos = jnp.tile(rhos, (n_sigmas, 1, 1))
            log_rhos = jnp.tile(log_rhos, (n_sigmas, 1, 1))
            log_sigmas = jnp.repeat(log_sigmas, n_rhos, axis=0)
        else:
            n_sigmas = 1
            log_sigmas = jnp.repeat(log_sigmas[jnp.newaxis, ...], n_rhos, axis=0)

        def _f(rho, log_rho, log_sigma):
            prod = rho @ (log_rho - log_sigma)
            return jnp.abs(jnp.trace(prod, axis1=-2, axis2=-1))

        rel_entropies = jax.vmap(_f, in_axes=(0, 0, 0))(rhos, log_rhos, log_sigmas)

        if n_sigmas > 1:
            rel_entropies = rel_entropies.reshape(n_sigmas, n_rhos)
        return rel_entropies

    @staticmethod
    def entanglement_of_formation(
        model: Model,
        n_samples: Optional[int],
        random_key: Optional[jax.random.PRNGKey] = None,
        scale: bool = False,
        always_decompose: bool = False,
        **kwargs: Any,
    ) -> float:
        """
        This function implements the entanglement of formation for mixed
        quantum systems.
        In that a mixed state gets decomposed into pure states with respective
        probabilities using the eigendecomposition of the density matrix.
        Then, the Meyer-Wallach measure is computed for each pure state,
        weighted by the eigenvalue.
        See e.g. https://doi.org/10.48550/arXiv.quant-ph/0504163

        Note that the decomposition is *not unique*! Therefore, this measure
        presents the entanglement for *some* decomposition into pure states,
        not necessarily the one that is anticipated when applying the Kraus
        channels.
        If a pure state is provided, this results in the same value as the
        Entanglement.meyer_wallach function if `always_decompose` flag is not set.

        Args:
            model (Model): The quantum circuit model.
            n_samples (Optional[int]): Number of parameter samples.
                If None or <= 0, the current parameters of the model are used.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for
                parameter initialization. If None, uses the model's internal
                random key.
            scale (bool): Whether to multiply the number of samples by
                ``2**n_qubits``.
            always_decompose (bool): Whether to explicitly compute the
                entanglement of formation for the eigendecomposition of a pure
                state.
            kwargs (Any): Additional keyword arguments for the model function.

        Returns:
            float: Entangling capacity of the given circuit, averaged over
            all samples.
        """
        if scale and n_samples is not None:
            n_samples = 2**model.n_qubits * n_samples

        if n_samples is not None and n_samples > 0:
            model.initialize_params(random_key, repeat=n_samples)
        elif len(model.params.shape) <= 2:
            model.params = model.params.reshape(1, *model.params.shape)
        else:
            log.info(f"Using sample size of model params: {model.params.shape[0]}")

        dim = 2**model.n_qubits
        # implicitly set input to none in case it's not needed
        kwargs.setdefault("inputs", None)
        rhos = model(execution_type="density", **kwargs).reshape(-1, dim, dim)
        ent = Entanglement._compute_entanglement_of_formation(
            rhos, model.n_qubits, always_decompose
        )
        return ent.mean()

    @staticmethod
    def _compute_entanglement_of_formation(
        rhos: jnp.ndarray,
        n_qubits: int,
        always_decompose: bool,
    ) -> jnp.ndarray:
        """
        Computes the entanglement of formation for a given batch of density
        matrices.

        Args:
            rhos (jnp.ndarray): The density matrices, has shape (B_s, 2^n, 2^n),
                where B_s is the batch size and n the number of qubits.
            n_qubits (int): Number of qubits
            always_decompose (bool): Whether to explicitly compute the
                entanglement of formation for the eigendecomposition of a pure
                state.

        Returns:
            jnp.ndarray: Entanglement for the provided density matrices,
                shape (B_s,).
        """
        dim = 2**n_qubits
        eigenvalues, eigenvectors = jnp.linalg.eigh(rhos)

        # If every rho has an eigenvalue ~1 (i.e. is pure), the decomposition
        # is trivial and the measure reduces to Meyer-Wallach on rho itself.
        if not always_decompose and jnp.isclose(eigenvalues, 1.0).any(axis=-1).all():
            return Entanglement._compute_meyer_wallach_meas(rhos, n_qubits)

        # ``eigh`` returns eigenvectors as *columns*: eigenvectors[s, :, i]
        # belongs to eigenvalues[s, i]. Build projectors[s, i] = |v_i><v_i|.
        projectors = jnp.einsum("sji,ski->sijk", eigenvectors, eigenvectors.conj())
        measures = Entanglement._compute_meyer_wallach_meas(
            projectors.reshape(-1, dim, dim), n_qubits
        )
        # Weight each pure-state measure by its eigenvalue (probability)
        return jnp.einsum("si,si->s", measures.reshape(-1, dim), eigenvalues)

    @staticmethod
    def concentratable_entanglement(
        model: Model,
        n_samples: Optional[int],
        random_key: Optional[jax.random.PRNGKey] = None,
        scale: bool = False,
        **kwargs: Any,
    ) -> float:
        """
        Computes the concentratable entanglement of a given model.

        This method utilizes the Concentratable Entanglement measure from
        https://arxiv.org/abs/2104.06923. The swap test is implemented
        directly in yaqsi using a ``3 * n_qubits`` circuit.

        Args:
            model (Model): The quantum circuit model.
            n_samples (Optional[int]): The number of samples to compute the
                measure for. If None or <= 0, the current parameters of the
                model are used.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for
                parameter initialization. If None, uses the model's internal
                random key.
            scale (bool): Whether to scale the number of samples according to
                the number of qubits.
            **kwargs (Any): Additional keyword arguments for the model function.

        Returns:
            float: Entangling capability of the given circuit, averaged over
            all samples.
        """
        n = model.n_qubits
        N = 2**n

        if scale and n_samples is not None:
            n_samples = N * n_samples

        def _swap_test_circuit(
            params, inputs, pulse_params=None, random_key=None, **kw
        ):
            """Swap-test circuit on 3*n qubits."""
            from qml_essentials.tape import recording as _recording

            # Two copies of the model state: wires n..2n-1, then 2n..3n-1
            for offset in (n, 2 * n):
                with _recording() as copy_tape:
                    model._variational(
                        params,
                        inputs,
                        pulse_params=pulse_params,
                        random_key=random_key,
                        **kw,
                    )
                Entanglement._shift_and_append(copy_tape, offset)

            # Swap test: H on ancilla register (wires 0..n-1)
            for i in range(n):
                op.H(wires=i)

            for i in range(n):
                op.CSWAP(wires=[i, i + n, i + 2 * n])

            for i in range(n):
                op.H(wires=i)

        swap_script = ys.Script(f=_swap_test_circuit, n_qubits=3 * n)

        if n_samples is not None and n_samples > 0:
            random_key = model.initialize_params(random_key, repeat=n_samples)
        elif len(model.params.shape) <= 2:
            model.params = model.params.reshape(1, *model.params.shape)
        else:
            log.info(f"Using sample size of model params: {model.params.shape[0]}")

        params = model.params
        # ``inputs`` is handed to the circuit positionally; remove it from the
        # forwarded kwargs so it is not supplied a second time by keyword.
        inputs = model._inputs_validation(kwargs.pop("inputs", None))
        n_batch = params.shape[0]

        marg_probs = jax.jit(ys.marginalize_probs, static_argnums=(1, 2))

        if n_batch > 1:
            from qml_essentials.utils import safe_random_split

            random_keys = safe_random_split(random_key, num=n_batch)
            probs = swap_script.execute(
                type="probs",
                args=(params, inputs, model.pulse_params, random_keys),
                in_axes=(0, None, None, 0),
                kwargs=kwargs,
            )
        else:
            probs = swap_script.execute(
                type="probs",
                args=(params, inputs, model.pulse_params, random_key),
                kwargs=kwargs,
            )

        # Marginalize to the ancilla register (wires 0..n-1)
        probs = marg_probs(probs, 3 * n, tuple(range(n)))

        # CE = 1 - P(ancilla register measured in |0...0>)
        ent = 1 - probs[..., 0]

        log.debug(f"Variance of measure: {ent.var()}")

        return float(ent.mean())
def sample_random_separable_states(
    n_qubits: int,
    n_samples: int,
    random_key: jax.random.PRNGKey,
    take_log: bool = False,
) -> jnp.ndarray:
    """
    Sample random separable states (density matrices).

    The states are produced by a single-layer ``"No_Entangling"`` model
    (without data re-uploading) whose parameters are initialised
    ``n_samples`` times from ``random_key``.

    Args:
        n_qubits (int): number of qubits in the state
        n_samples (int): number of states
        random_key (random.PRNGKey): JAX random key
        take_log (bool): if the base-2 matrix logarithm of the density
            matrix should be returned instead.

    Returns:
        jnp.ndarray: Density matrices (or their base-2 logarithms) of shape
            (n_samples, 2**n_qubits, 2**n_qubits), also for n_samples == 1.
    """
    dim = 2**n_qubits
    model = Model(n_qubits, 1, "No_Entangling", data_reupload=False)
    model.initialize_params(random_key, repeat=n_samples)
    # explicitly set execution type because everything else won't work
    sigmas = model(execution_type="density", inputs=None)
    # Enforce the documented batched shape (as done for other model outputs
    # in this module), so callers can always iterate over the first axis.
    sigmas = sigmas.reshape(-1, dim, dim)
    if take_log:
        # natural log -> base-2 log
        sigmas = logm_v(sigmas) / jnp.log(2.0 + 0j)

    return sigmas