Coverage for qml_essentials / expressibility.py: 99%
73 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-05-16 10:19 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-05-16 10:19 +0000
1import jax.numpy as jnp
2import jax
3import numpy as np
4from typing import Tuple, Any, Optional
5from scipy import integrate
6from scipy.linalg import sqrtm
7from scipy.special import rel_entr
8from qml_essentials.model import Model
9import os
class Expressibility:
    """
    Expressibility analysis of parameterised quantum circuits as proposed by
    Sim et al. (https://arxiv.org/abs/1905.10876).

    The distribution of fidelities between pairs of randomly parameterised
    output states of a model is histogrammed and compared, via the
    Kullback-Leibler divergence, to the analytic distribution obtained for
    Haar-random states.
    """

    @classmethod
    def _sample_state_fidelities(
        cls,
        model: "Model",
        n_samples: int,
        random_key: Optional[jax.random.PRNGKey] = None,
        kwargs: Optional[dict] = None,
    ) -> jnp.ndarray:
        """
        Compute the fidelities for each pair of random parameter sets.

        ``2 * n_samples`` parameter sets are drawn. The first half yields the
        density matrices ``rho`` and the second half the matrices ``sigma``.
        For each pair the fidelity
        ``F(rho, sigma) = (Tr sqrt(sqrt(rho) sigma sqrt(rho)))^2`` is computed.

        Args:
            model (Model): Quantum circuit model. It is evaluated with
                ``execution_type="density"``.
            n_samples (int): Number of fidelity samples (pairs of parameter
                sets) to generate.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for
                parameter initialization. If None, uses the model's internal
                random key.
            kwargs (Optional[dict]): Additional keyword arguments for the model
                call. ``None`` means no additional arguments.

        Returns:
            jnp.ndarray: Array of shape (n_samples,) containing the fidelities,
            clipped to [0, 1].
        """
        # Guard the default: ``**None`` would raise a TypeError.
        if kwargs is None:
            kwargs = {}

        # Two parameter sets per sample, because each fidelity compares a
        # pair of random states.
        model.initialize_params(random_key, repeat=n_samples * 2)

        # NOTE(review): assumes the model returns one density matrix per
        # parameter set, i.e. leading dimension 2 * n_samples — confirm for
        # models that are also evaluated on multiple inputs.
        rho: jnp.ndarray = model(
            params=model.params,
            execution_type="density",
            **kwargs,
        )

        # $\sqrt{\rho}$ for the first half of the batch
        sqrt_rho: jnp.ndarray = jnp.array([sqrtm(m) for m in rho[:n_samples]])

        # $\sqrt{\rho} \sigma \sqrt{\rho}$, with sigma from the second half
        inner = sqrt_rho @ rho[n_samples:] @ sqrt_rho

        # $(\mathrm{Tr} \sqrt{\sqrt{\rho} \sigma \sqrt{\rho}})^2$ — full trace
        # over each matrix in the batch
        fidelity: jnp.ndarray = (
            jnp.trace(
                jnp.array([sqrtm(m) for m in inner]),
                axis1=1,
                axis2=2,
            )
            ** 2
        )

        # The trace may carry a tiny imaginary part, and rounding can push the
        # value marginally outside [0, 1]. Values just above 1 would otherwise
        # fall outside the last histogram bin and be silently dropped.
        return jnp.clip(jnp.abs(fidelity), 0.0, 1.0)

    @classmethod
    def state_fidelities(
        cls,
        n_samples: int,
        n_bins: int,
        model: "Model",
        random_key: Optional[jax.random.PRNGKey] = None,
        scale: bool = False,
        **kwargs: Any,
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """
        Sample the state fidelities and histogram them into a 1-D probability
        distribution over [0, 1].

        Args:
            n_samples (int): Number of parameter sets to generate.
            n_bins (int): Number of histogram bins.
            model (Model): Quantum circuit model.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for
                parameter initialization. If None, uses the model's internal
                random key.
            scale (bool): Whether to scale the number of samples
                (by 2**n_qubits) and bins (by n_qubits).
            kwargs (Any): Additional keyword arguments for the model function.

        Returns:
            Tuple[jnp.ndarray, jnp.ndarray]: The bin edges (n_bins + 1 values)
            and the histogram values (n_bins values, normalised by n_samples).
        """
        if scale:
            # Keep these as plain Python ints: they are used as a repeat
            # count and as slice bounds downstream.
            n_samples = int(2**model.n_qubits * n_samples)
            n_bins = int(model.n_qubits * n_bins)

        fidelities = cls._sample_state_fidelities(
            n_samples=n_samples,
            random_key=random_key,
            model=model,
            kwargs=kwargs,
        )

        bin_edges: jnp.ndarray = jnp.linspace(0, 1, n_bins + 1)
        counts, _ = jnp.histogram(fidelities, bins=bin_edges)

        return bin_edges, counts / n_samples

    @classmethod
    def _haar_probability(cls, fidelity: float, n_qubits: int) -> float:
        """
        Theoretical probability density of the fidelity between two random
        Haar states, as proposed by Sim et al.
        (https://arxiv.org/abs/1905.10876):
        ``P(F) = (N - 1) (1 - F)^(N - 2)`` with ``N = 2**n_qubits``.

        Args:
            fidelity (float): fidelity of two parameter assignments in [0, 1]
            n_qubits (int): number of qubits in the quantum system

        Returns:
            float: probability density for the given fidelity
        """
        N = 2**n_qubits

        return (N - 1) * (1 - fidelity) ** (N - 2)

    @classmethod
    def _sample_haar_integral(cls, n_qubits: int, n_bins: int) -> np.ndarray:
        """
        Integrate the Haar fidelity density (Sim et al.,
        https://arxiv.org/abs/1905.10876) over ``n_bins`` equal-width bins
        on [0, 1].

        The integral has a closed form,
        ``int_v^u (N-1)(1-F)^(N-2) dF = (1-v)^(N-1) - (1-u)^(N-1)``.
        It is evaluated exactly instead of by numerical quadrature.

        Args:
            n_qubits (int): number of qubits in the quantum system
            n_bins (int): number of histogram bins

        Returns:
            np.ndarray: probability mass of each of the n_bins bins
        """
        N = 2**n_qubits
        edges = np.linspace(0.0, 1.0, n_bins + 1)
        # Complementary CDF at each edge; consecutive differences give the
        # probability mass inside each bin.
        tail = (1.0 - edges) ** (N - 1)

        return tail[:-1] - tail[1:]

    @classmethod
    def haar_integral(
        cls,
        n_qubits: int,
        n_bins: int,
        cache: bool = True,
        scale: bool = False,
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """
        Binned theoretical fidelity distribution for random Haar states
        (Sim et al., https://arxiv.org/abs/1905.10876), returned as a 1-D
        histogram.

        Args:
            n_qubits (int): number of qubits in the quantum system
            n_bins (int): number of histogram bins
            cache (bool): whether to cache the haar integral in ``.cache``
                relative to the current working directory
            scale (bool): whether to scale the number of bins by n_qubits

        Returns:
            Tuple[jnp.ndarray, jnp.ndarray]:
                - x component (bins): n_bins evenly spaced points on [0, 1]
                - y component (probabilities): probability mass of each bin
                  for random Haar states
        """
        if scale:
            n_bins = n_qubits * n_bins

        x = jnp.linspace(0, 1, n_bins)

        file_path = None
        if cache:
            # File name kept unchanged so existing cache files are reused.
            name = f"haar_{n_qubits}q_{n_bins}s_{'scaled' if scale else ''}.npy"
            cache_folder = ".cache"
            # exist_ok avoids the race between an existence check and mkdir.
            os.makedirs(cache_folder, exist_ok=True)
            file_path = os.path.join(cache_folder, name)

            if os.path.isfile(file_path):
                return x, jnp.load(file_path)

        y = cls._sample_haar_integral(n_qubits, n_bins)

        if file_path is not None:
            np.save(file_path, y)

        # Return the same array type as the cache-hit path.
        return x, jnp.asarray(y)

    @classmethod
    def kullback_leibler_divergence(
        cls,
        vqc_prob_dist: jnp.ndarray,
        haar_dist: jnp.ndarray,
    ) -> np.ndarray:
        """
        Calculates the KL divergence between the fidelity distribution sampled
        from a VQC and the Haar probability distribution.

        Args:
            vqc_prob_dist (jnp.ndarray): VQC fidelity probability distribution.
                Should have shape (n_inputs_samples, n_bins) or (n_bins,).
            haar_dist (jnp.ndarray): Haar probability distribution.
                Should have shape (n_bins,).

        Returns:
            np.ndarray: Array of KL-Divergence values, one per row of
            vqc_prob_dist (a 1-D input is treated as a single row).

        Raises:
            ValueError: if a row of vqc_prob_dist does not have the shape of
                haar_dist.
        """
        # float64 avoids precision loss when summing the element-wise terms.
        vqc_prob_dist = np.asarray(vqc_prob_dist, dtype=np.float64)
        haar_dist = np.asarray(haar_dist, dtype=np.float64)

        if vqc_prob_dist.ndim <= 1:
            vqc_prob_dist = vqc_prob_dist.reshape((1, -1))

        # Explicit check (not ``assert``) so it is not stripped under -O.
        if vqc_prob_dist.shape[1:] != haar_dist.shape:
            raise ValueError(
                "All probabilities for inputs should have the same shape as Haar. "
                f"Got {haar_dist.shape} for Haar and {vqc_prob_dist.shape} for VQC"
            )

        # rel_entr broadcasts haar_dist across rows; sum each row.
        return np.sum(rel_entr(vqc_prob_dist, haar_dist), axis=1)

    @staticmethod
    def kl_divergence_to_haar(
        model: "Model",
        n_samples: int,
        n_bins: int,
        random_key: Optional[jax.random.PRNGKey] = None,
        scale: bool = False,
        **kwargs: Any,
    ) -> np.ndarray:
        """
        Shortcut method to compute the KL-Divergence between a model and the
        Haar distribution. The basic steps are:
        - Sample the state fidelities for randomly initialised parameters.
        - Calculate the KL divergence between the sampled probability and
          the Haar probability distribution.

        Args:
            model (Model): Quantum circuit model.
            n_samples (int): Number of parameter sets to generate.
            n_bins (int): Number of histogram bins.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for
                parameter initialization. If None, uses the model's internal
                random key.
            scale (bool): Whether to scale the number of samples and bins.
            kwargs (Any): Additional keyword arguments for the model function.

        Returns:
            np.ndarray: KL-Divergence value(s) between the sampled fidelity
            distribution and the Haar distribution.
        """
        _, vqc_probs = Expressibility.state_fidelities(
            model=model,
            random_key=random_key,
            n_samples=n_samples,
            n_bins=n_bins,
            scale=scale,
            **kwargs,
        )
        _, haar_probs = Expressibility.haar_integral(
            model.n_qubits, n_bins=n_bins, scale=scale
        )
        return Expressibility.kullback_leibler_divergence(vqc_probs, haar_probs)