Coverage for qml_essentials/expressibility.py: 98%
84 statements
« prev ^ index » next coverage.py v7.9.2, created at 2026-02-20 14:03 +0000
« prev ^ index » next coverage.py v7.9.2, created at 2026-02-20 14:03 +0000
1import jax.numpy as jnp
2from jax import random
3import numpy as np
4from typing import Tuple, List, Any
5from scipy import integrate
6from scipy.linalg import sqrtm
7from scipy.special import rel_entr
8from qml_essentials.model import Model
9import os
class Expressibility:
    """
    Expressibility analysis of parameterised quantum circuits following
    Sim et al. (https://arxiv.org/abs/1905.10876).

    Fidelities between pairs of states produced by randomly drawn parameter
    sets are histogrammed and compared, via the Kullback-Leibler divergence,
    with the analytic fidelity distribution of Haar-random states.
    """

    @staticmethod
    def _sample_state_fidelities(
        model: "Model",
        x_samples: jnp.ndarray,
        n_samples: int,
        seed: int,
        kwargs: Any,
    ) -> jnp.ndarray:
        """
        Compute the fidelities for each pair of input samples and parameter sets.

        Side effect: re-initialises ``model.params`` with ``2 * n_samples``
        parameter sets derived from ``seed``.

        Args:
            model (Model): Model of the quantum circuit. It is called with
                ``execution_type="density"`` and must return one density
                matrix per parameter set.
            x_samples (jnp.ndarray): Array of shape (n_input_samples, n_features)
                containing the input samples.
            n_samples (int): Number of parameter-set pairs (fidelity samples).
            seed (int): Random number generator seed.
            kwargs (Any): Additional keyword arguments for the model function.

        Returns:
            jnp.ndarray: Array of shape (n_input_samples, n_samples)
                containing the fidelities.
        """
        random_key = random.key(seed)

        # Two parameter sets per fidelity sample: the first n_samples produce
        # rho, the remaining n_samples produce sigma.
        model.initialize_params(random_key, repeat=n_samples * 2)

        fidelities = jnp.zeros((len(x_samples), n_samples))

        for idx, x_sample in enumerate(x_samples):
            # Density matrices for all 2 * n_samples parameter sets
            # (execution type is explicitly forced to "density").
            states = model(
                inputs=x_sample,
                params=model.params,
                execution_type="density",
                **kwargs,
            )
            rho, sigma = states[:n_samples], states[n_samples:]

            # F(rho, sigma) = (Tr sqrt(sqrt(rho) sigma sqrt(rho)))^2
            # scipy's sqrtm is not batched, hence the per-matrix comprehensions.
            sqrt_rho = jnp.array([sqrtm(m) for m in rho])
            inner = sqrt_rho @ sigma @ sqrt_rho
            fidelity = (
                jnp.trace(
                    jnp.array([sqrtm(m) for m in inner]),
                    axis1=1,
                    axis2=2,
                )
                ** 2
            )

            # sqrtm may yield complex values; the fidelity is their magnitude.
            fidelities = fidelities.at[idx].set(jnp.abs(fidelity))

        return fidelities

    @staticmethod
    def state_fidelities(
        seed: int,
        n_samples: int,
        n_bins: int,
        model: "Model",
        n_input_samples: int = 0,
        input_domain: List[float] = None,
        scale: bool = False,
        **kwargs: Any,
    ) -> Tuple[jnp.ndarray, jnp.ndarray, np.ndarray]:
        """
        Sample the state fidelities and histogram them per input sample.

        Args:
            seed (int): Random number generator seed.
            n_samples (int): Number of parameter-set pairs to sample.
            n_bins (int): Number of histogram bins over [0, 1].
            model (Model): Model of the quantum circuit.
            n_input_samples (int): Number of input samples. If 0/None, or if
                ``input_domain`` is None, a single all-zero input is used.
            input_domain (List[float]): (start, stop) of the input domain.
            scale (bool): Whether to scale the number of samples by
                2**n_qubits and the number of bins by n_qubits.
            kwargs (Any): Additional keyword arguments for the model function.

        Returns:
            Tuple[jnp.ndarray, jnp.ndarray, np.ndarray]: The input samples, the
                n_bins + 1 bin edges, and the normalised histogram values with
                shape (n_input_samples, n_bins). The histogram is flattened to
                (n_bins,) when there is a single input sample.
        """
        if scale:
            # Keep these as plain Python ints: they size arrays and batches.
            n_samples = 2**model.n_qubits * n_samples
            n_bins = model.n_qubits * n_bins

        if input_domain is None or not n_input_samples:
            x = jnp.zeros(1)
            n_input_samples = 1
        else:
            x = jnp.linspace(*input_domain, n_input_samples)

        fidelities = Expressibility._sample_state_fidelities(
            x_samples=x,
            n_samples=n_samples,
            seed=seed,
            model=model,
            kwargs=kwargs,
        )

        y = jnp.linspace(0, 1, n_bins + 1)
        z = np.zeros((n_input_samples, n_bins))
        for i, f in enumerate(fidelities):
            z[i], _ = jnp.histogram(f, bins=y)

        # Turn counts into probabilities.
        z = z / n_samples

        if z.shape[0] == 1:
            z = z.flatten()

        return x, y, z

    @staticmethod
    def _haar_probability(fidelity: float, n_qubits: int) -> float:
        """
        Probability density of the fidelity between two Haar-random states,
        as given by Sim et al. (https://arxiv.org/abs/1905.10876):
        P(F) = (N - 1) * (1 - F)^(N - 2) with N = 2**n_qubits.

        Args:
            fidelity (float): Fidelity of two states, in [0, 1].
            n_qubits (int): Number of qubits in the quantum system.

        Returns:
            float: Probability density at the given fidelity.
        """
        dim = 2**n_qubits
        return (dim - 1) * (1 - fidelity) ** (dim - 2)

    @staticmethod
    def _sample_haar_integral(n_qubits: int, n_bins: int) -> np.ndarray:
        """
        Integrate the Haar fidelity density over n_bins equal-width bins of
        [0, 1], producing a 1-D histogram of Haar fidelity probabilities.

        Args:
            n_qubits (int): Number of qubits in the quantum system.
            n_bins (int): Number of histogram bins.

        Returns:
            np.ndarray: Probability mass of each bin, shape (n_bins,).
        """
        dist = np.zeros(n_bins)
        for idx in range(n_bins):
            lower = idx / n_bins
            upper = (idx + 1) / n_bins
            dist[idx], _ = integrate.quad(
                Expressibility._haar_probability, lower, upper, args=(n_qubits,)
            )
        return dist

    @staticmethod
    def haar_integral(
        n_qubits: int,
        n_bins: int,
        cache: bool = True,
        scale: bool = False,
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """
        Binned Haar fidelity distribution (Sim et al.,
        https://arxiv.org/abs/1905.10876), optionally cached on disk.

        Args:
            n_qubits (int): Number of qubits in the quantum system.
            n_bins (int): Number of histogram bins.
            cache (bool): Whether to load/store the result under ``.cache/``.
            scale (bool): Whether to scale the number of bins by n_qubits.

        Returns:
            Tuple[jnp.ndarray, jnp.ndarray]:
                - x component (bins): n_bins points evenly spaced over [0, 1]
                - y component (probabilities): per-bin Haar probability mass
        """
        if scale:
            n_bins = n_qubits * n_bins

        x = jnp.linspace(0, 1, n_bins)

        if not cache:
            return x, Expressibility._sample_haar_integral(n_qubits, n_bins)

        cache_folder = ".cache"
        # exist_ok avoids a check-then-create race between concurrent callers.
        os.makedirs(cache_folder, exist_ok=True)

        # File name kept unchanged so existing caches remain valid.
        name = f"haar_{n_qubits}q_{n_bins}s_{'scaled' if scale else ''}.npy"
        file_path = os.path.join(cache_folder, name)

        if os.path.isfile(file_path):
            return x, jnp.load(file_path)

        y = Expressibility._sample_haar_integral(n_qubits, n_bins)
        jnp.save(file_path, y)
        return x, y

    @staticmethod
    def kullback_leibler_divergence(
        vqc_prob_dist: jnp.ndarray,
        haar_dist: jnp.ndarray,
    ) -> np.ndarray:
        """
        KL divergence between the fidelity distribution sampled from a VQC and
        the Haar fidelity distribution.

        Args:
            vqc_prob_dist (jnp.ndarray): VQC fidelity probability distribution,
                shape (n_input_samples, n_bins) or (n_bins,).
            haar_dist (jnp.ndarray): Haar probability distribution,
                shape (n_bins,).

        Returns:
            np.ndarray: KL divergence D(vqc || haar) for each input sample,
                shape (n_input_samples,); (1,) for a 1-D ``vqc_prob_dist``.

        Raises:
            AssertionError: If a distribution row does not match the Haar shape.
        """
        original_shape = vqc_prob_dist.shape
        if len(original_shape) <= 1:
            vqc_prob_dist = vqc_prob_dist.reshape((1, -1))

        # Validate after normalising to 2-D so 1-D input is checked as well.
        assert vqc_prob_dist.shape[1:] == haar_dist.shape, (
            "All probabilities for inputs should have the same shape as Haar. "
            f"Got {haar_dist.shape} for Haar and {original_shape} for VQC"
        )

        kl_divergence = np.zeros(vqc_prob_dist.shape[0])
        for idx, p in enumerate(vqc_prob_dist):
            # np.sum keeps rel_entr's float64 precision (jnp.sum may downcast).
            kl_divergence[idx] = np.sum(rel_entr(p, haar_dist))

        return kl_divergence

    @staticmethod
    def kl_divergence_to_haar(
        model: "Model",
        seed: int,
        n_samples: int,
        n_bins: int,
        n_input_samples: int = 0,
        input_domain: List[float] = None,
        scale: bool = False,
        **kwargs: Any,
    ) -> np.ndarray:
        """
        Shortcut to compute the KL divergence between a model's fidelity
        distribution and the Haar distribution. The steps are:
            - Sample the state fidelities for randomly initialised parameters.
            - Compute the KL divergence between the sampled distribution and
              the Haar probability distribution.

        Args:
            model (Model): Model of the quantum circuit.
            seed (int): Random number generator seed.
            n_samples (int): Number of parameter-set pairs to sample.
            n_bins (int): Number of histogram bins.
            n_input_samples (int): Number of input samples.
            input_domain (List[float]): Input domain.
            scale (bool): Whether to scale the number of samples and bins.
            kwargs (Any): Additional keyword arguments for the model function.

        Returns:
            np.ndarray: KL divergence for each input sample, shape
                (n_input_samples,), or (1,) when a single input is used.
        """
        _, _, fidelities = Expressibility.state_fidelities(
            model=model,
            seed=seed,
            n_samples=n_samples,
            n_bins=n_bins,
            n_input_samples=n_input_samples,
            input_domain=input_domain,
            scale=scale,
            **kwargs,
        )
        _, haar_probs = Expressibility.haar_integral(
            model.n_qubits, n_bins=n_bins, scale=scale
        )
        return Expressibility.kullback_leibler_divergence(fidelities, haar_probs)