Coverage for qml_essentials/entanglement.py: 94%
34 statements
« prev ^ index » next coverage.py v7.6.5, created at 2024-11-15 11:13 +0000
« prev ^ index » next coverage.py v7.6.5, created at 2024-11-15 11:13 +0000
1from typing import Optional, Any
2import pennylane as qml
3import pennylane.numpy as np
5from qml_essentials.model import Model
6import logging
8log = logging.getLogger(__name__)
class Entanglement:
    """Collection of entanglement measures for quantum circuit models."""

    @staticmethod
    def meyer_wallach(
        model: Model,
        n_samples: Optional[int],
        seed: Optional[int],
        **kwargs: Any,
    ) -> float:
        """
        Calculates the entangling capacity of a given quantum circuit
        using Meyer-Wallach measure.

        For every parameter sample, the model is evaluated with
        ``execution_type="density"`` and, for each qubit, all other qubits
        are traced out. The purities ``Tr(rho_j^2)`` of the single-qubit
        states are averaged and turned into ``2 * (1 - mean purity)``;
        the result is the mean of that value over all samples.

        Args:
            model (Model): Model that represents the quantum circuit.
                When ``n_samples > 0``, ``model.initialize_params`` is
                called on it, so the model's parameters are presumably
                replaced by freshly sampled ones (confirm against Model).
            n_samples (Optional[int]): Number of parameter samples to draw.
                If None or <= 0, the current parameters of the model are
                used (a trailing sample axis is added when the parameter
                array has at most two dimensions).
            seed (Optional[int]): Seed for the random number generator.
                Required when ``n_samples > 0``, ignored otherwise.
            kwargs (Any): Additional keyword arguments for the model call.
                ``inputs`` defaults to None if not given.

        Returns:
            float: Entangling capacity of the given circuit. It is guaranteed
                to be between 0.0 and 1.0.

        Raises:
            AssertionError: If ``n_samples > 0`` but no seed was provided.
        """
        if n_samples is not None and n_samples > 0:
            # Raised explicitly (instead of using an `assert` statement) so
            # the check is not stripped when Python runs with -O; the
            # exception type is kept for backward compatibility.
            if seed is None:
                raise AssertionError("Seed must be provided when samples > 0")
            # TODO: maybe switch to JAX rng
            rng = np.random.default_rng(seed)
            model.initialize_params(rng=rng, repeat=n_samples)
            params = model.params
        else:
            if seed is not None:
                log.warning("Seed is ignored when samples is 0")

            if len(model.params.shape) <= 2:
                # add a trailing sample axis of size one
                params = model.params.reshape(*model.params.shape, 1)
            else:
                log.info(
                    "Using sample size of model params: %s",
                    model.params.shape[-1],
                )
                params = model.params

        n_samples = params.shape[-1]
        n_qubits = model.n_qubits
        qubits = list(range(n_qubits))

        # implicitly set input to none in case it's not needed
        # (loop-invariant, so it is done once up front)
        kwargs.setdefault("inputs", None)

        # only real purities are stored, hence a real-valued buffer suffices
        mw_measure = np.zeros(n_samples)

        # TODO: vectorize in future iterations
        for i in range(n_samples):
            # explicitly set execution type because everything else won't work
            rho = model(params=params[:, :, i], execution_type="density", **kwargs)

            purity_sum = 0
            for j in range(n_qubits):
                # trace out every qubit except j
                reduced = qml.math.partial_trace(rho, qubits[:j] + qubits[j + 1 :])
                purity_sum += np.trace((reduced @ reduced).real)

            mw_measure[i] = 1 - purity_sum / n_qubits

        mw = 2 * np.sum(mw_measure) / n_samples

        # catch floating point errors
        entangling_capability = min(max(mw, 0.0), 1.0)

        return float(entangling_capability)