Coverage for qml_essentials / entanglement.py: 92%
236 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-05-16 10:19 +0000
1from typing import Optional, Any, Tuple
2import jax
3import jax.numpy as jnp
4import numpy as np
6from qml_essentials import yaqsi as ys
7from qml_essentials import operations as op
8from qml_essentials.math import logm_v
9from qml_essentials.model import Model
10import logging
# Module-level logger. The entanglement measures report warnings (e.g. for
# noisy circuits), the used parameter sample size and measure variances here.
log = logging.getLogger(__name__)
class Entanglement:
    """
    Entanglement measures for quantum circuit models.

    Every public method is a classmethod that evaluates a :class:`Model`
    (drawing fresh parameter samples if requested) and returns the mean of
    the respective measure over all parameter samples.
    """

    @staticmethod
    def _scale_n_samples(
        n_samples: Optional[int], n_qubits: int, scale: bool
    ) -> Optional[int]:
        """
        Optionally scale a sample count with the Hilbert-space dimension.

        Args:
            n_samples (Optional[int]): Requested number of samples.
            n_qubits (int): Number of qubits of the model.
            scale (bool): If True, multiply ``n_samples`` by ``2**n_qubits``.

        Returns:
            Optional[int]: ``2**n_qubits * n_samples`` if ``scale`` is set,
            otherwise ``n_samples``. ``None`` is always passed through, so
            "use the current model parameters" also works with ``scale=True``.
        """
        if scale and n_samples is not None:
            # Plain Python int arithmetic: no int32 overflow and usable as
            # an array shape / repeat count downstream.
            return 2**n_qubits * n_samples
        return n_samples

    @staticmethod
    def _prepare_params(
        model: Model,
        n_samples: Optional[int],
        random_key: Optional[jax.random.PRNGKey],
    ) -> Any:
        """
        Draw new model parameters or add a batch axis to the current ones.

        Args:
            model (Model): Model whose parameters are prepared (modified in
                place).
            n_samples (Optional[int]): If positive, this many parameter sets
                are drawn via ``model.initialize_params``. Otherwise the
                current ``model.params`` are kept; if they have at most two
                dimensions, a leading batch axis of size one is added.
            random_key (Optional[jax.random.PRNGKey]): Key forwarded to
                ``model.initialize_params``.

        Returns:
            Any: The value returned by ``model.initialize_params`` when new
            parameters were drawn, otherwise ``random_key`` unchanged.
        """
        if n_samples is not None and n_samples > 0:
            return model.initialize_params(random_key, repeat=n_samples)
        if len(model.params.shape) <= 2:
            model.params = model.params.reshape(1, *model.params.shape)
        else:
            log.info(f"Using sample size of model params: {model.params.shape[0]}")
        return random_key

    @classmethod
    def meyer_wallach(
        cls,
        model: Model,
        n_samples: Optional[int],
        random_key: Optional[jax.random.PRNGKey] = None,
        scale: bool = False,
        **kwargs: Any,
    ) -> float:
        """
        Calculates the entangling capacity of a given quantum circuit
        using Meyer-Wallach measure.

        Args:
            model (Model): The quantum circuit model.
            n_samples (Optional[int]): Number of samples per qubit.
                If None or <= 0, the current parameters of the model are used.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for
                parameter initialization. If None, uses the model's internal
                random key.
            scale (bool): Whether to scale the number of samples by
                ``2**n_qubits``. Ignored if ``n_samples`` is None.
            kwargs (Any): Additional keyword arguments for the model function.

        Returns:
            float: Entangling capacity of the given circuit, guaranteed
            to be between 0.0 and 1.0.
        """
        if "noise_params" in kwargs:
            log.warning(
                "Meyer-Wallach measure not suitable for noisy circuits. "
                "Consider 'concentratable entanglement' instead."
            )

        n_samples = cls._scale_n_samples(n_samples, model.n_qubits, scale)
        if n_samples is not None and n_samples > 0:
            model.initialize_params(random_key, repeat=n_samples)

        # implicitly set input to none in case it's not needed
        kwargs.setdefault("inputs", None)
        # explicitly set execution type because everything else won't work
        rhos = model(execution_type="density", **kwargs).reshape(
            -1, 2**model.n_qubits, 2**model.n_qubits
        )

        ent = cls._compute_meyer_wallach_meas(rhos, model.n_qubits)

        log.debug(f"Variance of measure: {ent.var()}")

        return ent.mean()

    @classmethod
    def _compute_meyer_wallach_meas(
        cls, rhos: jnp.ndarray, n_qubits: int
    ) -> jnp.ndarray:
        """
        Computes the Meyer-Wallach entangling capability measure for a given
        set of density matrices.

        Args:
            rhos (jnp.ndarray): Density matrices of the sample quantum states.
                The shape is (B_s, 2^n, 2^n), where B_s is the number of samples
                (batch) and n the number of qubits
            n_qubits (int): The number of qubits

        Returns:
            jnp.ndarray: Entangling capability for each sample, array with
                shape (B_s,)
        """
        qb = list(range(n_qubits))

        def _f(rhos):
            entropy = 0
            for j in range(n_qubits):
                # Formula 6 in https://doi.org/10.48550/arXiv.quant-ph/0305094
                # Trace out qubit j, keep all others
                keep = qb[:j] + qb[j + 1 :]
                density = ys.partial_trace(rhos, n_qubits, keep)
                # only real values, because imaginary part will be separate
                # in all following calculations anyway
                # entropy should be 1/2 <= entropy <= 1
                entropy += jnp.trace((density @ density).real, axis1=-2, axis2=-1)

            # inverse averaged entropy and scale to [0, 1]
            return 2 * (1 - entropy / n_qubits)

        return jax.vmap(_f)(rhos)

    @classmethod
    def bell_measurements(
        cls,
        model: Model,
        n_samples: int,
        random_key: Optional[jax.random.PRNGKey] = None,
        scale: bool = False,
        **kwargs: Any,
    ) -> float:
        """
        Compute the Bell measurement for a given model.

        Constructs a ``2 * n_qubits`` circuit that prepares two copies of
        the model state (on disjoint qubit registers), applies CNOTs and
        Hadamards, and measures probabilities on the first register.

        Args:
            model (Model): The quantum circuit model.
            n_samples (int): The number of samples to compute the measure for.
                If None or <= 0, the current parameters of the model are used.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for
                parameter initialization. If None, uses the model's internal
                random key.
            scale (bool): Whether to scale the number of samples
                according to the number of qubits.
            **kwargs (Any): Additional keyword arguments for the model function.

        Returns:
            float: The Bell measurement value, clipped to [0.0, 1.0].
        """
        if "noise_params" in kwargs:
            log.warning(
                "Bell Measurements not suitable for noisy circuits. "
                "Consider 'concentratable entanglement' instead."
            )

        n_samples = cls._scale_n_samples(n_samples, model.n_qubits, scale)

        n = model.n_qubits

        def _bell_circuit(params, inputs, pulse_params=None, random_key=None, **kw):
            """Bell measurement circuit on 2*n qubits."""
            from qml_essentials.tape import copy_to_tape

            def vari():
                model._variational(
                    params,
                    inputs,
                    pulse_params=pulse_params,
                    random_key=random_key,
                    **kw,
                )

            # First copy on wires 0..n-1
            vari()
            # Second copy on wires n..2n-1
            copy_to_tape(vari, offset=n)

            # Bell measurement: CNOT + H
            for q in range(n):
                op.CX(wires=[q, q + n])
                op.H(wires=q)

        bell_script = ys.Script(f=_bell_circuit, n_qubits=2 * n)

        # Unlike the other measures, the model's params are not modified when
        # the current parameters are used; only a local batched view is built.
        if n_samples is not None and n_samples > 0:
            random_key = model.initialize_params(random_key, repeat=n_samples)
            params = model.params
        elif len(model.params.shape) <= 2:
            params = model.params.reshape(1, *model.params.shape)
        else:
            log.info(f"Using sample size of model params: {model.params.shape[0]}")
            params = model.params

        n_samples = params.shape[0]
        # ``inputs`` is passed positionally to the circuit below; remove it
        # from ``kwargs`` so it is not forwarded a second time by keyword.
        inputs = model._inputs_validation(kwargs.pop("inputs", None))

        # Execute: vmap over batch dimension of params (axis 0)
        if n_samples > 1:
            from qml_essentials.utils import safe_random_split

            # NOTE(review): random_key may still be None here when the model's
            # current parameters are used; presumably safe_random_split
            # handles that - confirm.
            random_keys = safe_random_split(random_key, num=n_samples)
            result = bell_script.execute(
                type="probs",
                args=(params, inputs, model.pulse_params, random_keys),
                kwargs=kwargs,
                in_axes=(0, None, None, 0),
            )
        else:
            result = bell_script.execute(
                type="probs",
                args=(params, inputs, model.pulse_params, random_key),
                kwargs=kwargs,
            )

        # Marginalize: for each qubit q, keep wires [q, q+n] from the 2n-qubit probs
        # The last probability in each pair gives P(|11⟩) for that qubit pair
        per_qubit = [ys.marginalize_probs(result, 2 * n, [q, q + n]) for q in range(n)]
        # per_qubit[q] has shape (n_samples, 4) or (4,)
        exp = jnp.stack(per_qubit, axis=-2)  # (..., n, 4)
        exp = 1 - 2 * exp[..., -1]  # (..., n)

        if not jnp.isclose(jnp.sum(exp.imag), 0, atol=1e-6):
            log.warning("Imaginary part of probabilities detected")
            exp = jnp.abs(exp)

        # Average over qubits -> one measure per sample, shape (n_samples,) or ().
        # The overall mean equals the previous qubit-wise formulation, but
        # the logged variance is now taken over samples like the other measures.
        measure = 2 * (1 - exp.mean(axis=-1))
        entangling_capability = min(max(float(measure.mean()), 0.0), 1.0)
        log.debug(f"Variance of measure: {measure.var()}")

        return entangling_capability

    @classmethod
    def relative_entropy(
        cls,
        model: Model,
        n_samples: int,
        n_sigmas: int,
        random_key: Optional[jax.random.PRNGKey] = None,
        scale: bool = False,
        **kwargs: Any,
    ) -> float:
        """
        Calculates the relative entropy of entanglement of a given quantum
        circuit. This measure is also applicable to mixed states, although it
        might not be fully accurate in this simplified case.

        The relative entropy of entanglement is defined as the smallest
        relative entropy from the state in question to the set of separable
        states. However, as computing the nearest separable state is NP-hard,
        we select n_sigmas random separable states to compute the distance
        to, which are not necessarily the nearest. Thus, this measure of
        entanglement presents an upper bound of the entanglement.

        As the relative entropy is not necessarily between zero and one, this
        function also normalises by the relative entropy of the GHZ state.

        Args:
            model (Model): The quantum circuit model.
            n_samples (int): Number of samples per qubit.
                If None or <= 0, the current parameters of the model are used.
            n_sigmas (int): Number of random separable pure states to compare against.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for
                parameter initialization. If None, uses the model's internal
                random key.
            scale (bool): Whether to scale the number of samples (and sigmas)
                by ``2**n_qubits``.
            kwargs (Any): Additional keyword arguments for the model function.

        Returns:
            float: Entangling capacity of the given circuit, normalised by
            the GHZ state; expected to lie between 0.0 and 1.0 (the value is
            not clipped).
        """
        dim = 2**model.n_qubits
        n_samples = cls._scale_n_samples(n_samples, model.n_qubits, scale)
        n_sigmas = cls._scale_n_samples(n_sigmas, model.n_qubits, scale)

        if random_key is None:
            random_key = model.random_key

        # Random separable states, always as a batch (S, 2^n, 2^n)
        log_sigmas = sample_random_separable_states(
            model.n_qubits, n_samples=n_sigmas, random_key=random_key, take_log=True
        ).reshape(-1, dim, dim)

        random_key, _ = jax.random.split(random_key)

        cls._prepare_params(model, n_samples, random_key)

        rhos, log_rhos = cls._compute_log_density(model, **kwargs)

        # Shape (S, R): one row per sigma, one column per rho. Looping over the
        # sigmas keeps memory at O(R) matrices instead of O(S * R).
        rel_entropies = jnp.stack(
            [
                cls._compute_rel_entropies(rhos, log_rhos, log_sigma)
                for log_sigma in log_sigmas
            ]
        )

        # Entropy of GHZ states should be maximal
        ghz_model = Model(model.n_qubits, 1, "GHZ", data_reupload=False)
        rho_ghz, log_rho_ghz = cls._compute_log_density(ghz_model, **kwargs)
        ghz_entropies = cls._compute_rel_entropies(rho_ghz, log_rho_ghz, log_sigmas)

        normalised_entropies = rel_entropies / ghz_entropies

        # Closest sampled separable state per rho (minimum over sigmas),
        # then averaged over all parameter samples
        entangling_capability = normalised_entropies.T.min(axis=1)
        log.debug(f"Variance of measure: {entangling_capability.var()}")

        return entangling_capability.mean()

    @classmethod
    def _compute_log_density(
        cls, model: Model, **kwargs
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """
        Obtains the density matrix of a model and computes its logarithm.

        Args:
            model (Model): The model for which to compute the density matrix.
            kwargs: Additional keyword arguments for the model function.

        Returns:
            Tuple[jnp.ndarray, jnp.ndarray]:
                - jnp.ndarray: density matrices, shape (R, 2^n, 2^n).
                - jnp.ndarray: base-2 matrix logarithm of the density matrices.
        """
        # implicitly set input to none in case it's not needed
        kwargs.setdefault("inputs", None)
        # explicitly set execution type because everything else won't work
        rho = model(execution_type="density", **kwargs)
        rho = rho.reshape(-1, 2**model.n_qubits, 2**model.n_qubits)
        log_rho = logm_v(rho) / jnp.log(2)
        return rho, log_rho

    @classmethod
    def _compute_rel_entropies(
        cls,
        rhos: jnp.ndarray,
        log_rhos: jnp.ndarray,
        log_sigmas: jnp.ndarray,
    ) -> jnp.ndarray:
        """
        Compute the relative entropies |tr(rho (log rho - log sigma))|.

        Args:
            rhos (jnp.ndarray): Density matrix result of the circuit, has shape
                (R, 2^n, 2^n), with the batch size R and number of qubits n
            log_rhos (jnp.ndarray): Corresponding logarithm of the density
                matrix, has shape (R, 2^n, 2^n).
            log_sigmas (jnp.ndarray): Logarithm of the separable state(s),
                has shape (2^n, 2^n) if it's a single sigma or (S, 2^n, 2^n),
                with the batch size S (number of sigmas).

        Returns:
            jnp.ndarray: Relative entropy for each rho, shape (R,) for a
            single sigma (or S == 1), otherwise (S, R).
        """
        n_rhos = rhos.shape[0]
        if log_sigmas.ndim == 3:
            n_sigmas = log_sigmas.shape[0]
            # Pair every sigma with every rho. ``tile`` repeats the whole rho
            # batch once per sigma (flat index s * R + r -> rho r), ``repeat``
            # repeats each sigma R times (flat index s * R + r -> sigma s),
            # so the flat result reshapes to (S, R).
            rhos = jnp.tile(rhos, (n_sigmas, 1, 1))
            log_rhos = jnp.tile(log_rhos, (n_sigmas, 1, 1))
            log_sigmas = jnp.repeat(log_sigmas, n_rhos, axis=0)
        else:
            n_sigmas = 1
            log_sigmas = jnp.repeat(log_sigmas[jnp.newaxis, ...], n_rhos, axis=0)

        def _f(rho, log_rho, log_sigma):
            prod = jnp.einsum("ij,jk->ik", rho, log_rho - log_sigma)
            return jnp.abs(jnp.trace(prod, axis1=-2, axis2=-1))

        rel_entropies = jax.vmap(_f, in_axes=(0, 0, 0))(rhos, log_rhos, log_sigmas)

        if n_sigmas > 1:
            rel_entropies = rel_entropies.reshape(n_sigmas, n_rhos)
        return rel_entropies

    @classmethod
    def entanglement_of_formation(
        cls,
        model: Model,
        n_samples: int,
        random_key: Optional[jax.random.PRNGKey] = None,
        scale: bool = False,
        always_decompose: bool = False,
        **kwargs: Any,
    ) -> float:
        """
        This function implements the entanglement of formation for mixed
        quantum systems.
        In that a mixed state gets decomposed into pure states with respective
        probabilities using the eigendecomposition of the density matrix.
        Then, the Meyer-Wallach measure is computed for each pure state,
        weighted by the eigenvalue.
        See e.g. https://doi.org/10.48550/arXiv.quant-ph/0504163

        Note that the decomposition is *not unique*! Therefore, this measure
        presents the entanglement for *some* decomposition into pure states,
        not necessarily the one that is anticipated when applying the Kraus
        channels.
        If a pure state is provided, this results in the same value as the
        Entanglement.meyer_wallach function if `always_decompose` flag is not set.

        Args:
            model (Model): The quantum circuit model.
            n_samples (int): Number of samples per qubit.
                If None or <= 0, the current parameters of the model are used.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for
                parameter initialization. If None, uses the model's internal
                random key.
            scale (bool): Whether to scale the number of samples.
            always_decompose (bool): Whether to explicitly compute the
                entanglement of formation for the eigendecomposition of a pure
                state.
            kwargs (Any): Additional keyword arguments for the model function.

        Returns:
            float: Entangling capacity of the given circuit, guaranteed
            to be between 0.0 and 1.0.
        """
        n_samples = cls._scale_n_samples(n_samples, model.n_qubits, scale)
        cls._prepare_params(model, n_samples, random_key)

        # implicitly set input to none in case it's not needed
        kwargs.setdefault("inputs", None)
        rhos = model(execution_type="density", **kwargs)
        rhos = rhos.reshape(-1, 2**model.n_qubits, 2**model.n_qubits)
        ent = cls._compute_entanglement_of_formation(
            rhos, model.n_qubits, always_decompose
        )
        return ent.mean()

    @classmethod
    def _compute_entanglement_of_formation(
        cls,
        rhos: jnp.ndarray,
        n_qubits: int,
        always_decompose: bool,
    ) -> jnp.ndarray:
        """
        Computes the entanglement of formation for a given batch of density
        matrices.

        Args:
            rhos (jnp.ndarray): The density matrices, has shape (B_s, 2^n, 2^n),
                where B_s is the batch size and n the number of qubits.
            n_qubits (int): Number of qubits
            always_decompose (bool): Whether to explicitly compute the
                entanglement of formation for the eigendecomposition of a pure
                state.

        Returns:
            jnp.ndarray: Entanglement for the provided density matrices,
            shape (B_s,).
        """
        eigenvalues, eigenvectors = jnp.linalg.eigh(rhos)
        if not always_decompose and jnp.isclose(eigenvalues, 1.0).any(axis=-1).all():
            # every state has an eigenvalue ~1, i.e. is pure: use rho directly
            return cls._compute_meyer_wallach_meas(rhos, n_qubits)

        dim = 2**n_qubits
        # ``eigh`` returns the eigenvectors as *columns*:
        # eigenvectors[s, :, k] belongs to eigenvalues[s, k]. Build the
        # projector |v_k><v_k| for each of them, indexed [s, k, a, b].
        projectors = jnp.einsum("sak,sbk->skab", eigenvectors, eigenvectors.conj())
        measures = cls._compute_meyer_wallach_meas(
            projectors.reshape(-1, dim, dim), n_qubits
        )
        # weight the measure of each eigenstate with its eigenvalue
        return jnp.einsum("sk,sk->s", measures.reshape(-1, dim), eigenvalues)

    @classmethod
    def concentratable_entanglement(
        cls,
        model: Model,
        n_samples: int,
        random_key: Optional[jax.random.PRNGKey] = None,
        scale: bool = False,
        **kwargs: Any,
    ) -> float:
        """
        Computes the concentratable entanglement of a given model.

        This method utilizes the Concentratable Entanglement measure from
        https://arxiv.org/abs/2104.06923. The swap test is implemented
        directly in yaqsi using a ``3 * n_qubits`` circuit.

        Args:
            model (Model): The quantum circuit model.
            n_samples (int): The number of samples to compute the measure for.
                If None or <= 0, the current parameters of the model are used.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for
                parameter initialization. If None, uses the model's internal
                random key.
            scale (bool): Whether to scale the number of samples according to
                the number of qubits.
            **kwargs (Any): Additional keyword arguments for the model function.

        Returns:
            float: Entangling capability of the given circuit, guaranteed
            to be between 0.0 and 1.0.
        """
        n = model.n_qubits
        n_samples = cls._scale_n_samples(n_samples, n, scale)

        def _swap_test_circuit(
            params, inputs, pulse_params=None, random_key=None, **kw
        ):
            """Swap-test circuit on 3*n qubits."""
            from qml_essentials.tape import copy_to_tape

            def vari():
                model._variational(
                    params,
                    inputs,
                    pulse_params=pulse_params,
                    random_key=random_key,
                    **kw,
                )

            # First copy on wires n..2n-1
            copy_to_tape(vari, offset=n)
            # Second copy on wires 2n..3n-1
            copy_to_tape(vari, offset=2 * n)

            # Swap test: H on ancilla register (wires 0..n-1)
            for i in range(n):
                op.H(wires=i)

            for i in range(n):
                op.CSWAP(wires=[i, i + n, i + 2 * n])

            for i in range(n):
                op.H(wires=i)

        swap_script = ys.Script(f=_swap_test_circuit, n_qubits=3 * n)

        random_key = cls._prepare_params(model, n_samples, random_key)

        params = model.params
        # ``inputs`` is passed positionally to the circuit below; remove it
        # from ``kwargs`` so it is not forwarded a second time by keyword.
        inputs = model._inputs_validation(kwargs.pop("inputs", None))
        n_batch = params.shape[0]

        marg_probs = jax.jit(ys.marginalize_probs, static_argnums=(1, 2))

        if n_batch > 1:
            from qml_essentials.utils import safe_random_split

            random_keys = safe_random_split(random_key, num=n_batch)
            probs = swap_script.execute(
                type="probs",
                args=(params, inputs, model.pulse_params, random_keys),
                in_axes=(0, None, None, 0),
                kwargs=kwargs,
            )
        else:
            probs = swap_script.execute(
                type="probs",
                args=(params, inputs, model.pulse_params, random_key),
                kwargs=kwargs,
            )

        # Marginalize to the ancilla register (wires 0..n-1)
        probs = marg_probs(probs, 3 * n, tuple(range(n)))

        # CE = 1 - P(ancilla register in |0...0>)
        ent = 1 - probs[..., 0]

        log.debug(f"Variance of measure: {ent.var()}")

        return float(ent.mean())

    @classmethod
    def concentratable_entanglement_estimation(
        cls,
        model: Model,
        n_samples: int,
        random_key: Optional[jax.random.PRNGKey] = None,
        scale: bool = False,
        **kwargs: Any,
    ) -> float:
        """
        Computes the concentratable entanglement of a given model.

        This method utilizes the Concentratable Entanglement measure from
        https://arxiv.org/abs/2104.06923. Instead of a swap test, two copies
        of the state are prepared on a ``2 * n_qubits`` circuit, measured in
        the Bell basis (CNOT + H), and the expectation value of the product
        of ``(I + SWAP) / 2`` terms (SWAP expressed in the Bell basis) over
        all qubit pairs is evaluated.

        Args:
            model (Model): The quantum circuit model.
            n_samples (int): The number of samples to compute the measure for.
                If None or <= 0, the current parameters of the model are used.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for
                parameter initialization. If None, uses the model's internal
                random key.
            scale (bool): Whether to scale the number of samples according to
                the number of qubits.
            **kwargs (Any): Additional keyword arguments for the model function.

        Returns:
            float: Entangling capability of the given circuit, guaranteed
            to be between 0.0 and 1.0.
        """
        n = model.n_qubits
        N = 2**n
        n_samples = cls._scale_n_samples(n_samples, n, scale)

        def _bell_basis_measurement(
            params, inputs, pulse_params=None, random_key=None, **kw
        ):
            """Bell-basis measurement circuit on 2*n qubits."""
            from qml_essentials.tape import copy_to_tape

            def vari():
                model._variational(
                    params,
                    inputs,
                    pulse_params=pulse_params,
                    random_key=random_key,
                    **kw,
                )

            # First copy on wires 0..n-1
            copy_to_tape(vari, offset=0)
            # Second copy on wires n..2n-1
            copy_to_tape(vari, offset=n)

            for i in range(n):
                op.CX(wires=[i, i + n])
                op.H(wires=i)

        bell_basis_script = ys.Script(f=_bell_basis_measurement, n_qubits=2 * n)

        random_key = cls._prepare_params(model, n_samples, random_key)

        params = model.params
        # ``inputs`` is passed positionally to the circuit below; remove it
        # from ``kwargs`` so it is not forwarded a second time by keyword.
        inputs = model._inputs_validation(kwargs.pop("inputs", None))
        n_batch = params.shape[0]

        # SWAP operator in Bell-basis
        SWAP = jnp.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, -1]])
        # Construct observable for measuring CE: (1/2^n) * prod_i (I + SWAP_i)
        CE_observable = op.Id([0, n]) + op.Operation([0, n], SWAP)
        for i in range(1, n):
            CE_observable = CE_observable @ (
                op.Id([i, i + n]) + op.Operation([i, i + n], SWAP)
            )
        CE_observable = (1 / N) * CE_observable

        if n_batch > 1:
            from qml_essentials.utils import safe_random_split

            random_keys = safe_random_split(random_key, num=n_batch)
            expvals = bell_basis_script.execute(
                type="expval",
                obs=[CE_observable],
                args=(params, inputs, model.pulse_params, random_keys),
                in_axes=(0, None, None, 0),
                kwargs=kwargs,
            )
        else:
            expvals = bell_basis_script.execute(
                type="expval",
                obs=[CE_observable],
                args=(params, inputs, model.pulse_params, random_key),
                kwargs=kwargs,
            )

        ent = 1 - expvals
        log.debug(f"Variance of measure: {ent.var()}")
        return float(ent.mean())
def sample_random_separable_states(
    n_qubits: int,
    n_samples: int,
    random_key: jax.random.PRNGKey,
    take_log: bool = False,
) -> jnp.ndarray:
    """
    Sample random separable states (density matrices).

    The states are produced by a ``"No_Entangling"`` model whose parameters
    are freshly initialised with ``random_key``.

    Args:
        n_qubits (int): number of qubits in the state
        n_samples (int): number of states
        random_key (random.PRNGKey): JAX random key
        take_log (bool): if the base-2 matrix logarithm of the density matrix
            should be returned instead of the density matrix itself.

    Returns:
        jnp.ndarray: Density matrices (or their logarithms) of shape
        (n_samples, 2**n_qubits, 2**n_qubits).
    """
    dim = 2**n_qubits
    model = Model(n_qubits, 1, "No_Entangling", data_reupload=False)
    model.initialize_params(random_key, repeat=n_samples)
    # explicitly set execution type because anything else won't work
    # Always return a leading batch axis (also for n_samples == 1) so the
    # documented shape holds and callers can iterate over the states.
    sigmas = model(execution_type="density", inputs=None).reshape(-1, dim, dim)
    if take_log:
        # divide by ln(2) to obtain the base-2 logarithm
        sigmas = logm_v(sigmas) / jnp.log(2.0 + 0j)

    return sigmas