Coverage for qml_essentials/entanglement.py: 88%
67 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-07 14:54 +0000
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-07 14:54 +0000
1from typing import Optional, Any
2import pennylane as qml
3import pennylane.numpy as np
4from copy import deepcopy
6from qml_essentials.model import Model
7import logging
# Module-wide logger, registered under this module's dotted import path.
log = logging.getLogger(name=__name__)
class Entanglement:
    """Entanglement measures for parameterized quantum circuits built with ``Model``."""

    @staticmethod
    def meyer_wallach(
        model: Model,
        n_samples: Optional[int | None],
        seed: Optional[int],
        **kwargs: Any,
    ) -> float:
        """
        Calculates the entangling capacity of a given quantum circuit
        using Meyer-Wallach measure.

        Args:
            model (Model): Model wrapping the quantum circuit. If
                ``n_samples`` > 0, its ``params`` are overwritten by
                ``model.initialize_params``.
            n_samples (Optional[int]): Number of parameter samples.
                If None or <= 0, the current parameters of the model are used.
            seed (Optional[int]): Seed for the random number generator.
                Required when ``n_samples`` > 0.
            kwargs (Any): Additional keyword arguments for the model function.

        Returns:
            float: Entangling capacity of the given circuit. It is guaranteed
                to be between 0.0 and 1.0.
        """
        params = Entanglement._resolve_params(model, n_samples, seed)
        n_samples = params.shape[-1]
        mw_measure = np.zeros(n_samples)
        qb = list(range(model.n_qubits))

        # Implicitly set input to None in case it's not needed. `kwargs` is a
        # fresh dict for every call, so this default never leaks to the caller.
        kwargs.setdefault("inputs", None)

        # TODO: vectorize in future iterations
        for i in range(n_samples):
            # explicitly set execution type because everything else won't work
            U = model(params=params[:, :, i], execution_type="density", **kwargs)

            # Formula 6 in https://doi.org/10.48550/arXiv.quant-ph/0305094
            # Q = 2 * (1 - 1/n * sum_j Tr(rho_j^2)), rho_j = reduced state of qubit j
            purity = 0
            for j in range(model.n_qubits):
                # trace out every qubit except j
                density = qml.math.partial_trace(U, qb[:j] + qb[j + 1 :])
                # Tr(rho^2) is real for a density matrix; `.real` drops the
                # numerically irrelevant imaginary part. Per qubit the value
                # should satisfy 1/2 <= Tr(rho^2) <= 1.
                purity += np.trace((density @ density).real)

            # inverse averaged purity and scale to [0, 1]
            mw_measure[i] = 2 * (1 - purity / model.n_qubits)

        return Entanglement._aggregate(mw_measure)

    @staticmethod
    def bell_measurements(
        model: Model,
        n_samples: Optional[int],
        seed: Optional[int],
        **kwargs: Any,
    ) -> float:
        """
        Estimates the entangling capability from Bell-basis measurements on
        two copies of the model's variational circuit.

        Two copies of ``model._variational`` are prepared on wires
        ``0..n-1`` and ``n..2n-1``. Each pair ``(q, q + n)`` is rotated with
        CNOT + H and its two-qubit probabilities are measured. The
        probability of outcome ``|11>`` is mapped to a per-qubit purity via
        ``1 - 2 * p``. The purities are combined like the Meyer-Wallach
        measure: ``2 * (1 - mean purity)``.

        The model's ``circuit`` is replaced by the doubled measurement circuit
        only for the duration of this call and is restored afterwards, even
        if an exception is raised.

        Args:
            model (Model): Model wrapping the quantum circuit. If
                ``n_samples`` > 0, its ``params`` are overwritten by
                ``model.initialize_params``.
            n_samples (Optional[int]): Number of parameter samples.
                If None or <= 0, the current parameters of the model are used.
            seed (Optional[int]): Seed for the random number generator.
                Required when ``n_samples`` > 0.
            kwargs (Any): Additional keyword arguments for the model function.

        Returns:
            float: Entangling capability, clipped to [0.0, 1.0].
        """
        n_qubits = model.n_qubits

        def _circuit(params, inputs):
            """Two identical copies of the variational circuit plus Bell rotations."""
            # first copy on wires 0..n-1
            model._variational(params, inputs)
            # second copy, shifted onto wires n..2n-1
            qml.map_wires(
                model._variational,
                {i: i + n_qubits for i in range(n_qubits)},
            )(params, inputs)

            # rotate each pair (q, q + n) from the Bell basis to the
            # computational basis before measuring it
            for q in range(n_qubits):
                qml.CNOT(wires=[q, q + n_qubits])
                qml.H(q)

            return [qml.probs(wires=(q, q + n_qubits)) for q in range(n_qubits)]

        # Temporarily swap in the doubled circuit. Restoring the original in
        # `finally` keeps later calls on `model` (training, meyer_wallach, ...)
        # from being silently routed to this measurement circuit.
        original_circuit = model.circuit
        model.circuit = qml.QNode(
            _circuit,
            qml.device(
                "default.qubit",
                shots=model.shots,
                wires=n_qubits * 2,
            ),
        )
        try:
            params = Entanglement._resolve_params(model, n_samples, seed)
            n_samples = params.shape[-1]
            mw_measure = np.zeros(n_samples)

            # implicitly set input to none in case it's not needed
            kwargs.setdefault("inputs", None)

            for i in range(n_samples):
                # NOTE(review): assumes model(...) stacks the per-pair probs
                # into shape (n_qubits, 4) -- TODO confirm against Model.__call__.
                probs = model(params=params[:, :, i], **kwargs)
                # Column -1 is outcome |11>, the singlet projection after
                # CNOT + H. For two identical copies its probability is
                # (1 - Tr(rho_q^2)) / 2, so this recovers each qubit's purity.
                purity = 1 - 2 * probs[:, -1]
                mw_measure[i] = 2 * (1 - purity.mean())
        finally:
            model.circuit = original_circuit

        return Entanglement._aggregate(mw_measure)

    @staticmethod
    def _resolve_params(
        model: Model,
        n_samples: Optional[int],
        seed: Optional[int],
    ):
        """
        Returns the parameter sets an entanglement measure is averaged over.

        Args:
            model (Model): Model whose ``params`` are re-initialized or read.
            n_samples (Optional[int]): If > 0, parameters are drawn via
                ``model.initialize_params(rng=..., repeat=n_samples)``, which
                overwrites ``model.params``. If None or <= 0, the current
                ``model.params`` are used.
            seed (Optional[int]): Seed for the random number generator.
                Must be given when ``n_samples`` > 0; ignored (with a
                warning) otherwise.

        Returns:
            The parameters with the sample index on the last axis. A
            parameter array of rank <= 2 gets a trailing axis of length 1.
        """
        if n_samples is not None and n_samples > 0:
            assert seed is not None, "Seed must be provided when samples > 0"
            # TODO: maybe switch to JAX rng
            rng = np.random.default_rng(seed)
            model.initialize_params(rng=rng, repeat=n_samples)
            return model.params

        if seed is not None:
            log.warning("Seed is ignored when samples is 0")

        if len(model.params.shape) <= 2:
            # single parameter set: add a sample axis of length 1
            return model.params.reshape(*model.params.shape, 1)

        log.info(f"Using sample size of model params: {model.params.shape[-1]}")
        return model.params

    @staticmethod
    def _aggregate(mw_measure) -> float:
        """Average per-sample measures, log their variance, and clip to [0, 1]."""
        # Average all iterated states; clipping catches floating point errors
        # that push the mean marginally outside [0, 1].
        entangling_capability = min(max(mw_measure.mean(), 0.0), 1.0)
        log.debug(f"Variance of measure: {mw_measure.var()}")
        return float(entangling_capability)