Coverage for qml_essentials / entanglement.py: 92%
237 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-04-10 10:29 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-04-10 10:29 +0000
1from typing import Optional, Any, Tuple
2import jax
3import jax.numpy as jnp
4import numpy as np
6from qml_essentials import yaqsi as ys
7from qml_essentials import operations as op
8from qml_essentials.math import logm_v
9from qml_essentials.model import Model
10import logging
12from qml_essentials.operations import Hermitian
14log = logging.getLogger(__name__)
17class Entanglement:
@staticmethod
def meyer_wallach(
    model: Model,
    n_samples: Optional[int | None],
    random_key: Optional[jax.random.PRNGKey] = None,
    scale: bool = False,
    **kwargs: Any,
) -> float:
    """
    Calculates the entangling capacity of a given quantum circuit
    using Meyer-Wallach measure.

    Args:
        model (Model): The quantum circuit model.
        n_samples (Optional[int]): Number of samples per qubit.
            If None or <= 0, the current parameters of the model are used.
        random_key (Optional[jax.random.PRNGKey]): JAX random key for
            parameter initialization. Forwarded to
            ``model.initialize_params`` when new parameters are sampled.
        scale (bool): Whether to multiply ``n_samples`` by ``2**n_qubits``.
            Has no effect when ``n_samples`` is None.
        kwargs (Any): Additional keyword arguments for the model function.

    Returns:
        float: Mean of the per-sample Meyer-Wallach measure.
    """
    if "noise_params" in kwargs:
        log.warning(
            "Meyer-Wallach measure not suitable for noisy circuits. "
            "Consider 'concentratable entanglement' instead."
        )

    dim = 2**model.n_qubits

    # `None` means "use the current parameters", so there is nothing to
    # scale in that case. Plain integer arithmetic keeps the sample count a
    # Python int instead of a device array.
    if scale and n_samples is not None:
        n_samples = dim * n_samples

    if n_samples is not None and n_samples > 0:
        random_key = model.initialize_params(random_key, repeat=n_samples)

    # implicitly set input to none in case it's not needed
    kwargs.setdefault("inputs", None)
    # explicitly set execution type because everything else won't work
    rhos = model(execution_type="density", **kwargs).reshape(-1, dim, dim)

    ent = Entanglement._compute_meyer_wallach_meas(rhos, model.n_qubits)

    log.debug(f"Variance of measure: {ent.var()}")

    return ent.mean()
@staticmethod
def _compute_meyer_wallach_meas(rhos: jnp.ndarray, n_qubits: int) -> jnp.ndarray:
    """
    Evaluates the Meyer-Wallach measure for every density matrix of a batch.

    Args:
        rhos (jnp.ndarray): Batch of density matrices with shape
            (B_s, 2^n, 2^n); B_s is the batch size, n the number of qubits.
        n_qubits (int): The number of qubits n.

    Returns:
        jnp.ndarray: One measure value per batch entry, shape (B_s,).
    """

    def _purity_without(rho, traced_out):
        # Formula 6 in https://doi.org/10.48550/arXiv.quant-ph/0305094
        # Reduced state on every wire except `traced_out`.
        remaining = [wire for wire in range(n_qubits) if wire != traced_out]
        reduced = ys.partial_trace(rho, n_qubits, remaining)
        # Only the real part of the squared reduced state enters the trace.
        squared = reduced @ reduced
        return jnp.trace(squared.real, axis1=-2, axis2=-1)

    def _measure(rho):
        purity_sum = sum(_purity_without(rho, wire) for wire in range(n_qubits))
        # Average over the qubits, invert and rescale by a factor of two.
        return 2 * (1 - purity_sum / n_qubits)

    return jax.vmap(_measure)(rhos)
@staticmethod
def bell_measurements(
    model: Model,
    n_samples: int,
    random_key: Optional[jax.random.PRNGKey] = None,
    scale: bool = False,
    **kwargs: Any,
) -> float:
    """
    Compute the Bell measurement for a given model.

    Constructs a ``2 * n_qubits`` circuit that prepares two copies of
    the model state (on disjoint qubit registers), applies CNOTs and
    Hadamards, and evaluates the probabilities of each wire pair
    ``(q, q + n_qubits)``.

    Args:
        model (Model): The quantum circuit model.
        n_samples (int): The number of samples to compute the measure for.
            If None or <= 0, the current parameters of the model are used.
        random_key (Optional[jax.random.PRNGKey]): JAX random key for
            parameter initialization. Forwarded to
            ``model.initialize_params`` when new parameters are sampled.
        scale (bool): Whether to multiply ``n_samples`` by ``2**n_qubits``.
            Has no effect when ``n_samples`` is None.
        **kwargs (Any): Additional keyword arguments for the model function.

    Returns:
        float: The Bell measurement value, clipped to [0.0, 1.0].
    """
    if "noise_params" in kwargs:
        log.warning(
            "Bell Measurements not suitable for noisy circuits. "
            "Consider 'concentratable entanglement' instead."
        )

    n = model.n_qubits

    # Integer scaling, skipped when no explicit sample count was given.
    if scale and n_samples is not None:
        n_samples = 2**n * n_samples

    def _bell_circuit(params, inputs, pulse_params=None, random_key=None, **kw):
        """Bell measurement circuit on 2*n qubits."""
        from qml_essentials.tape import copy_to_tape

        def vari():
            model._variational(
                params,
                inputs,
                pulse_params=pulse_params,
                random_key=random_key,
                **kw
            )

        # First copy on wires 0..n-1
        vari()
        # Second copy on wires n..2n-1
        copy_to_tape(vari, offset=n)

        # Bell measurement: CNOT + H
        for q in range(n):
            op.CX(wires=[q, q + n])
            op.H(wires=q)

    bell_script = ys.Script(f=_bell_circuit, n_qubits=2 * n)

    if n_samples is not None and n_samples > 0:
        random_key = model.initialize_params(random_key, repeat=n_samples)
        params = model.params
    elif len(model.params.shape) <= 2:
        # Add a leading batch axis locally; the model itself is not modified.
        params = model.params.reshape(1, *model.params.shape)
    else:
        log.info(f"Using sample size of model params: {model.params.shape[0]}")
        params = model.params

    # From here on the effective sample count is the parameter batch size.
    n_samples = params.shape[0]
    inputs = model._inputs_validation(kwargs.get("inputs", None))

    # Execute: vmap over batch dimension of params (axis 0)
    if n_samples > 1:
        from qml_essentials.utils import safe_random_split

        random_keys = safe_random_split(random_key, num=n_samples)
        result = bell_script.execute(
            type="probs",
            args=(params, inputs, model.pulse_params, random_keys),
            kwargs=kwargs,
            in_axes=(0, None, None, 0),
        )
    else:
        result = bell_script.execute(
            type="probs",
            args=(params, inputs, model.pulse_params, random_key),
            kwargs=kwargs,
        )

    # Marginalize: for each qubit q, keep wires [q, q+n] from the 2n-qubit
    # probs. The last entry of each marginal is taken as P(|11>) of the pair.
    per_qubit = [
        ys.marginalize_probs(result, 2 * n, [q, q + n]) for q in range(n)
    ]
    exp = jnp.stack(per_qubit, axis=-2)  # (..., n, 4)
    exp = 1 - 2 * exp[..., -1]  # (..., n)

    if not jnp.isclose(jnp.sum(exp.imag), 0, atol=1e-6):
        log.warning("Imaginary part of probabilities detected")
        exp = jnp.abs(exp)

    measure = 2 * (1 - exp.mean(axis=0))
    # Clip to the documented value range before converting to a Python float.
    entangling_capability = min(max(float(measure.mean()), 0.0), 1.0)
    log.debug(f"Variance of measure: {measure.var()}")

    return entangling_capability
@staticmethod
def relative_entropy(
    model: Model,
    n_samples: int,
    n_sigmas: int,
    random_key: Optional[jax.random.PRNGKey] = None,
    scale: bool = False,
    **kwargs: Any,
) -> float:
    """
    Calculates the relative entropy of entanglement of a given quantum
    circuit. This measure is also applicable to mixed states, albeit it
    might not be fully accurate in this simplified case.

    The relative entropy of entanglement is generally defined as the
    smallest relative entropy from the state in question to the set of
    separable states. However, as computing the nearest separable state is
    NP-hard, we select n_sigmas random separable states to compute the
    distance to, which does not necessarily include the nearest one. Thus,
    this measure of entanglement presents an upper limit of entanglement.

    As the relative entropy is not necessarily between zero and one, this
    function also normalises by the relative entropy of the GHZ state.

    Args:
        model (Model): The quantum circuit model.
        n_samples (int): Number of samples per qubit.
            If None or <= 0, the current parameters of the model are used.
        n_sigmas (int): Number of random separable pure states to compare
            against.
        random_key (Optional[jax.random.PRNGKey]): JAX random key for
            parameter initialization. If None, uses the model's internal
            random key.
        scale (bool): Whether to multiply ``n_samples`` and ``n_sigmas`` by
            ``2**n_qubits``.
        kwargs (Any): Additional keyword arguments for the model function.

    Returns:
        float: Mean over all samples of the smallest GHZ-normalised
            relative entropy found among the sampled separable states.
    """
    dim = 2**model.n_qubits
    if scale:
        # Integer scaling keeps both counts usable as array shapes.
        if n_samples is not None:
            n_samples = dim * n_samples
        n_sigmas = dim * n_sigmas

    if random_key is None:
        random_key = model.random_key

    # Random separable states; force a leading batch axis so that iterating
    # below always yields one (dim, dim) matrix per separable state.
    log_sigmas = sample_random_separable_states(
        model.n_qubits, n_samples=n_sigmas, random_key=random_key, take_log=True
    ).reshape(-1, dim, dim)

    random_key, _ = jax.random.split(random_key)

    if n_samples is not None and n_samples > 0:
        model.initialize_params(random_key, repeat=n_samples)
    elif len(model.params.shape) <= 2:
        model.params = model.params.reshape(1, *model.params.shape)
    else:
        log.info(f"Using sample size of model params: {model.params.shape[0]}")

    rhos, log_rhos = Entanglement._compute_log_density(model, **kwargs)

    # One row per separable state, one column per sampled circuit state.
    rel_entropies = jnp.stack(
        [
            Entanglement._compute_rel_entropies(rhos, log_rhos, log_sigma)
            for log_sigma in log_sigmas
        ]
    )

    # Entropy of GHZ states should be maximal
    ghz_model = Model(model.n_qubits, 1, "GHZ", data_reupload=False)
    rho_ghz, log_rho_ghz = Entanglement._compute_log_density(ghz_model, **kwargs)
    ghz_entropies = Entanglement._compute_rel_entropies(
        rho_ghz, log_rho_ghz, log_sigmas
    )

    normalised_entropies = rel_entropies / ghz_entropies

    # Per circuit state, keep the smallest value over all separable states
    entangling_capability = normalised_entropies.min(axis=0)
    log.debug(f"Variance of measure: {entangling_capability.var()}")

    return entangling_capability.mean()
@staticmethod
def _compute_log_density(model: Model, **kwargs) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """
    Evaluates a model in density mode and takes the base-2 matrix logarithm
    of the result.

    Args:
        model (Model): The model to evaluate.
        kwargs: Keyword arguments forwarded to the model call; ``inputs``
            defaults to None when not supplied.

    Returns:
        Tuple[jnp.ndarray, jnp.ndarray]:
            - jnp.ndarray: density matrices, reshaped to (-1, 2^n, 2^n).
            - jnp.ndarray: ``logm_v`` of the density matrices divided
              by log(2).
    """
    # The caller's value wins; otherwise the model is called without inputs.
    call_kwargs = {"inputs": None, **kwargs}
    dim = 2**model.n_qubits
    # Density execution is mandatory here, nothing else yields matrices.
    density = model(execution_type="density", **call_kwargs).reshape(-1, dim, dim)
    return density, logm_v(density) / jnp.log(2)
326 @staticmethod
327 def _compute_rel_entropies(
328 rhos: jnp.ndarray,
329 log_rhos: jnp.ndarray,
330 log_sigmas: jnp.ndarray,
331 ) -> jnp.ndarray:
332 """
333 Compute the relative entropy for a given model.
335 Args:
336 rhos (jnp.ndarray): Density matrix result of the circuit, has shape
337 (R, 2^n, 2^n), with the batch size R and number of qubits n
338 log_rhos (jnp.ndarray): Corresponding logarithm of the density
339 matrix, has shape (R, 2^n, 2^n).
340 log_sigmas (jnp.ndarray): Density matrix of next separable state,
341 has shape (2^n, 2^n) if it's a single sigma or (S, 2^n, 2^n),
342 with the batch size S (number of sigmas).
344 Returns:
345 jnp.ndarray: Relative Entropy for each sample
346 """
347 n_rhos = rhos.shape[0]
348 if len(log_sigmas.shape) == 3:
349 n_sigmas = log_sigmas.shape[0]
350 rhos = jnp.tile(rhos, (n_sigmas, 1, 1))
351 log_rhos = jnp.tile(log_rhos, (n_sigmas, 1, 1))
352 einsum_subscript = "ij,jk->ik"
353 else:
354 n_sigmas = 1
355 log_sigmas = log_sigmas[jnp.newaxis, ...].repeat(n_rhos, axis=0)
357 einsum_subscript = "ij,jk->ik"
359 def _f(rhos, log_rhos, log_sigmas):
360 prod = jnp.einsum(einsum_subscript, rhos, log_rhos - log_sigmas)
361 rel_entropies = jnp.abs(jnp.trace(prod, axis1=-2, axis2=-1))
362 return rel_entropies
364 rel_entropies = jax.vmap(_f, in_axes=(0, 0, 0))(rhos, log_rhos, log_sigmas)
366 if n_sigmas > 1:
367 rel_entropies = rel_entropies.reshape(n_sigmas, n_rhos)
368 return rel_entropies
@staticmethod
def entanglement_of_formation(
    model: Model,
    n_samples: int,
    random_key: Optional[jax.random.PRNGKey] = None,
    scale: bool = False,
    always_decompose: bool = False,
    **kwargs: Any,
) -> float:
    """
    This function implements the entanglement of formation for mixed
    quantum systems.
    In that a mixed state gets decomposed into pure states with respective
    probabilities using the eigendecomposition of the density matrix.
    Then, the Meyer-Wallach measure is computed for each pure state,
    weighted by the eigenvalue.
    See e.g. https://doi.org/10.48550/arXiv.quant-ph/0504163

    Note that the decomposition is *not unique*! Therefore, this measure
    presents the entanglement for *some* decomposition into pure states,
    not necessarily the one that is anticipated when applying the Kraus
    channels.
    If a pure state is provided and the `always_decompose` flag is not
    set, the Meyer-Wallach measure is evaluated on the state directly.

    Args:
        model (Model): The quantum circuit model.
        n_samples (int): Number of samples per qubit.
            If None or <= 0, the current parameters of the model are used.
        random_key (Optional[jax.random.PRNGKey]): JAX random key for
            parameter initialization. Forwarded to
            ``model.initialize_params`` when new parameters are sampled.
        scale (bool): Whether to multiply ``n_samples`` by ``2**n_qubits``.
            Has no effect when ``n_samples`` is None.
        always_decompose (bool): Whether to explicitly compute the
            entanglement of formation for the eigendecomposition of a pure
            state.
        kwargs (Any): Additional keyword arguments for the model function.

    Returns:
        float: Mean entanglement of formation over all evaluated samples.
    """
    dim = 2**model.n_qubits

    # Integer scaling, skipped when no explicit sample count was given.
    if scale and n_samples is not None:
        n_samples = dim * n_samples

    if n_samples is not None and n_samples > 0:
        model.initialize_params(random_key, repeat=n_samples)
    elif len(model.params.shape) <= 2:
        model.params = model.params.reshape(1, *model.params.shape)
    else:
        log.info(f"Using sample size of model params: {model.params.shape[0]}")

    # implicitly set input to none in case it's not needed
    kwargs.setdefault("inputs", None)
    rhos = model(execution_type="density", **kwargs).reshape(-1, dim, dim)
    ent = Entanglement._compute_entanglement_of_formation(
        rhos, model.n_qubits, always_decompose
    )
    return ent.mean()
@staticmethod
def _compute_entanglement_of_formation(
    rhos: jnp.ndarray,
    n_qubits: int,
    always_decompose: bool,
) -> jnp.ndarray:
    """
    Computes the entanglement of formation for a given batch of density
    matrices.

    Args:
        rhos (jnp.ndarray): The density matrices, has shape (B_s, 2^n, 2^n),
            where B_s is the batch size and n the number of qubits.
        n_qubits (int): Number of qubits
        always_decompose (bool): Whether to explicitly compute the
            entanglement of formation for the eigendecomposition of a pure
            state.

    Returns:
        jnp.ndarray: Entanglement for the provided density matrices,
            shape (B_s,).
    """
    dim = 2**n_qubits
    eigenvalues, eigenvectors = jnp.linalg.eigh(rhos)

    # A state counts as pure when one of its eigenvalues is (close to) one;
    # the shortcut is only taken if that holds for the whole batch.
    if not always_decompose and jnp.isclose(eigenvalues, 1.0).any(axis=-1).all():
        return Entanglement._compute_meyer_wallach_meas(rhos, n_qubits)

    # `eigh` stores eigenvectors column-wise: eigenvectors[s, :, i] belongs
    # to eigenvalues[s, i]. Build the projector |v_i><v_i| for each of them.
    projectors = jnp.einsum(
        "sji,ski->sijk", eigenvectors, eigenvectors.conjugate()
    )
    measures = Entanglement._compute_meyer_wallach_meas(
        projectors.reshape(-1, dim, dim), n_qubits
    )
    # Weight the measure of every pure component with its eigenvalue.
    return jnp.einsum("si,si->s", measures.reshape(-1, dim), eigenvalues)
@staticmethod
def concentratable_entanglement(
    model: Model,
    n_samples: int,
    random_key: Optional[jax.random.PRNGKey] = None,
    scale: bool = False,
    **kwargs: Any,
) -> float:
    """
    Evaluates the concentratable entanglement of a model via a swap test.

    The measure follows https://arxiv.org/abs/2104.06923. A circuit on
    ``3 * n_qubits`` wires is built in yaqsi: wires ``0..n-1`` form the
    ancilla register, the remaining two registers each hold one copy of
    the model circuit.

    Args:
        model (Model): The quantum circuit model.
        n_samples (int): The number of samples to compute the measure for.
            If None or <= 0, the current parameters of the model are used.
        random_key (Optional[jax.random.PRNGKey]): JAX random key for
            parameter initialization, forwarded to
            ``model.initialize_params`` when new parameters are sampled.
        scale (bool): Whether to multiply ``n_samples`` by ``2**n_qubits``.
        **kwargs (Any): Additional keyword arguments for the model function.

    Returns:
        float: One minus the probability of the first ancilla outcome,
            averaged over all samples.
    """
    n = model.n_qubits
    ancilla_wires = tuple(range(n))

    if scale:
        n_samples = 2**n * n_samples

    def _swap_test_circuit(
        params, inputs, pulse_params=None, random_key=None, **kw
    ):
        """Swap-test circuit on 3*n qubits."""
        from qml_essentials.tape import copy_to_tape

        def vari():
            model._variational(
                params,
                inputs,
                pulse_params=pulse_params,
                random_key=random_key,
                **kw
            )

        # Two copies of the model circuit: wires n..2n-1 and 2n..3n-1
        for register in (1, 2):
            copy_to_tape(vari, offset=register * n)

        # Swap test, ancillas are wires 0..n-1: H layer, CSWAPs, H layer
        for anc in ancilla_wires:
            op.H(wires=anc)

        for anc in ancilla_wires:
            op.CSWAP(wires=[anc, anc + n, anc + 2 * n])

        for anc in ancilla_wires:
            op.H(wires=anc)

    swap_script = ys.Script(f=_swap_test_circuit, n_qubits=3 * n)

    if n_samples is not None and n_samples > 0:
        random_key = model.initialize_params(random_key, repeat=n_samples)
    elif len(model.params.shape) <= 2:
        model.params = model.params.reshape(1, *model.params.shape)
    else:
        log.info(f"Using sample size of model params: {model.params.shape[0]}")

    params = model.params
    inputs = model._inputs_validation(kwargs.get("inputs", None))
    n_batch = params.shape[0]

    marg_probs = jax.jit(ys.marginalize_probs, static_argnums=(1, 2))

    exec_options = {"type": "probs", "kwargs": kwargs}
    if n_batch > 1:
        from qml_essentials.utils import safe_random_split

        # One key per batch entry; params and keys are mapped over axis 0.
        batch_keys = safe_random_split(random_key, num=n_batch)
        exec_args = (params, inputs, model.pulse_params, batch_keys)
        exec_options["in_axes"] = (0, None, None, 0)
    else:
        exec_args = (params, inputs, model.pulse_params, random_key)

    all_probs = swap_script.execute(args=exec_args, **exec_options)

    # Marginalize to the ancilla register (wires 0..n-1)
    ancilla_probs = marg_probs(all_probs, 3 * n, ancilla_wires)

    ent = 1 - ancilla_probs[..., 0]

    log.debug(f"Variance of measure: {ent.var()}")

    return float(ent.mean())
@staticmethod
def concentratable_entanglement_estimation(
    model: Model,
    n_samples: int,
    random_key: Optional[jax.random.PRNGKey] = None,
    scale: bool = False,
    **kwargs: Any,
) -> float:
    """
    Estimates the concentratable entanglement of a given model from a
    Bell-basis measurement.

    This method utilizes the Concentratable Entanglement measure from
    https://arxiv.org/abs/2104.06923. Instead of a swap test with ancilla
    qubits, two copies of the model circuit are prepared on a
    ``2 * n_qubits`` circuit, each wire pair ``(i, i + n_qubits)`` is
    rotated by a CNOT and a Hadamard, and the expectation value of
    ``(1 / 2**n_qubits) * prod_i (Id + SWAP_i)`` is evaluated, where
    ``SWAP_i`` is the diagonal matrix ``diag(1, 1, 1, -1)`` acting on
    pair ``i``.

    Args:
        model (Model): The quantum circuit model.
        n_samples (int): The number of samples to compute the measure for.
            If None or <= 0, the current parameters of the model are used.
        random_key (Optional[jax.random.PRNGKey]): JAX random key for
            parameter initialization, forwarded to
            ``model.initialize_params`` when new parameters are sampled.
        scale (bool): Whether to multiply ``n_samples`` by ``2**n_qubits``.
        **kwargs (Any): Additional keyword arguments for the model function.

    Returns:
        float: One minus the expectation value of the observable, averaged
            over all samples.
    """
    n = model.n_qubits
    N = 2**n

    if scale:
        n_samples = N * n_samples

    def _bell_basis_measurement(
        params, inputs, pulse_params=None, random_key=None, **kw
    ):
        """Bell-basis measurement circuit on 2*n qubits."""
        from qml_essentials.tape import copy_to_tape

        def vari():
            model._variational(
                params,
                inputs,
                pulse_params=pulse_params,
                random_key=random_key,
                **kw,
            )

        # First copy on wires 0..n-1
        copy_to_tape(vari, offset=0)
        # Second copy on wires n..2n-1
        copy_to_tape(vari, offset=n)

        # Basis change on every wire pair (i, i + n): CNOT + H
        for i in range(n):
            op.CX(wires=[i, i + n])
            op.H(wires=i)

    bell_basis_script = ys.Script(f=_bell_basis_measurement, n_qubits=2 * n)

    if n_samples is not None and n_samples > 0:
        random_key = model.initialize_params(random_key, repeat=n_samples)
    elif len(model.params.shape) <= 2:
        model.params = model.params.reshape(1, *model.params.shape)
    else:
        log.info(f"Using sample size of model params: {model.params.shape[0]}")

    params = model.params
    inputs = model._inputs_validation(kwargs.get("inputs", None))
    n_batch = params.shape[0]

    # SWAP operator in Bell-basis
    SWAP = jnp.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, -1]])
    # Construct observable for measuring CE: (1/N) * prod_i (Id + SWAP_i)
    CE_observable = op.Id([0, n]) + op.Operation([0, n], SWAP)
    for i in range(1, n):
        CE_observable = CE_observable @ (
            op.Id([i, i + n]) + op.Operation([i, i + n], SWAP)
        )
    CE_observable = (1 / N) * CE_observable

    if n_batch > 1:
        from qml_essentials.utils import safe_random_split

        random_keys = safe_random_split(random_key, num=n_batch)
        expvals = bell_basis_script.execute(
            type="expval",
            obs=[CE_observable],
            args=(params, inputs, model.pulse_params, random_keys),
            in_axes=(0, None, None, 0),
            kwargs=kwargs,
        )
    else:
        expvals = bell_basis_script.execute(
            type="expval",
            obs=[CE_observable],
            args=(params, inputs, model.pulse_params, random_key),
            kwargs=kwargs,
        )

    ent = 1 - expvals
    log.debug(f"Variance of measure: {ent.var()}")
    return float(ent.mean())
def sample_random_separable_states(
    n_qubits: int,
    n_samples: int,
    random_key: jax.random.PRNGKey,
    take_log: bool = False,
) -> jnp.ndarray:
    """
    Sample random separable states (density matrix).

    The states are produced by a single-layer "No_Entangling" model whose
    parameters are initialized ``n_samples`` times from ``random_key``.

    Args:
        n_qubits (int): number of qubits in the state
        n_samples (int): number of states
        random_key (random.PRNGKey): JAX random key
        take_log (bool): if the matrix logarithm of the density matrix should
            be taken (divided by log(2)).

    Returns:
        jnp.ndarray: Density matrices (or their logarithms) of shape
            (n_samples, 2**n_qubits, 2**n_qubits)
    """
    dim = 2**n_qubits
    model = Model(n_qubits, 1, "No_Entangling", data_reupload=False)
    model.initialize_params(random_key, repeat=n_samples)
    # explicitly set execution type because anything else won't work
    sigmas = model(execution_type="density", inputs=None)
    # Enforce the documented batch layout so callers can always iterate over
    # the leading axis, also for a single sample.
    sigmas = sigmas.reshape(-1, dim, dim)
    if take_log:
        sigmas = logm_v(sigmas) / jnp.log(2.0 + 0j)

    return sigmas