Coverage for qml_essentials / expressibility.py: 99%
74 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-30 11:43 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-30 11:43 +0000
1import jax.numpy as jnp
2from jax import random
3import jax
4import numpy as np
5from typing import Tuple, List, Any, Optional
6from scipy import integrate
7from scipy.linalg import sqrtm
8from scipy.special import rel_entr
9from qml_essentials.model import Model
10import os
class Expressibility:
    """
    Static helpers for estimating the expressibility of a model by comparing
    its sampled state-fidelity distribution to the Haar distribution, as
    proposed by Sim et al. (https://arxiv.org/abs/1905.10876).
    """

    @staticmethod
    def _sample_state_fidelities(
        model: "Model",
        n_samples: int,
        random_key: Optional[jax.random.PRNGKey] = None,
        kwargs: Any = None,
    ) -> jnp.ndarray:
        """
        Compute the fidelities between pairs of randomly parameterised states.

        Args:
            model (Model): Model of the quantum circuit. It is re-initialised
                with ``2 * n_samples`` parameter sets and evaluated with
                ``execution_type="density"``.
            n_samples (int): Number of fidelity samples (state pairs).
            random_key (Optional[jax.random.PRNGKey]): JAX random key for
                parameter initialization. It is forwarded unchanged to
                ``model.initialize_params``.
            kwargs (Any): Mapping of additional keyword arguments for the
                model call. ``None`` is treated as "no extra arguments".

        Returns:
            jnp.ndarray: Array of shape (n_samples,) containing the fidelities.
        """
        # `**None` raises a TypeError, so normalise the default first.
        if kwargs is None:
            kwargs = {}

        # Generate random parameter sets.
        # Two sets are needed per sample, because each fidelity is computed
        # for a pair of random states.
        model.initialize_params(random_key, repeat=n_samples * 2)

        # Evaluate the model for all parameters.
        # Execution type is explicitly set to density; the result is sliced
        # along axis 0 and traced over axes 1 and 2 below, i.e. it is used
        # as a stack of 2 * n_samples square matrices.
        density_matrices: jnp.ndarray = model(
            params=model.params,
            execution_type="density",
            **kwargs,
        )
        rho = density_matrices[:n_samples]
        sigma = density_matrices[n_samples:]

        # $\sqrt{\rho}$
        sqrt_rho: jnp.ndarray = jnp.array([sqrtm(m) for m in rho])

        # $\sqrt{\rho} \sigma \sqrt{\rho}$
        inner_fidelity = sqrt_rho @ sigma @ sqrt_rho

        # $F = (\mathrm{Tr} \sqrt{\sqrt{\rho} \sigma \sqrt{\rho}})^2$
        sqrt_inner = jnp.array([sqrtm(m) for m in inner_fidelity])
        fidelity: jnp.ndarray = jnp.trace(sqrt_inner, axis1=1, axis2=2) ** 2

        # The matrix square roots may carry a complex dtype; keep the magnitude.
        return jnp.abs(fidelity)

    @staticmethod
    def state_fidelities(
        n_samples: int,
        n_bins: int,
        model: "Model",
        random_key: Optional[jax.random.PRNGKey] = None,
        scale: bool = False,
        **kwargs: Any,
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """
        Sample the state fidelities and bin them into a 1D histogram.

        Args:
            n_samples (int): Number of parameter sets to generate.
            n_bins (int): Number of histogram bins.
            model (Model): Model of the quantum circuit.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for
                parameter initialization. It is forwarded unchanged to the
                model's parameter initialization.
            scale (bool): Whether to scale the number of samples by
                ``2**n_qubits`` and the number of bins by ``n_qubits``.
            kwargs (Any): Additional keyword arguments for the model function.

        Returns:
            Tuple[jnp.ndarray, jnp.ndarray]: Tuple containing the bin edges
            (``n_bins + 1`` values in [0, 1]) and the histogram counts
            divided by the number of samples (``n_bins`` values).
        """
        if scale:
            # Plain integer arithmetic keeps n_samples an int, so it can be
            # used as a repeat count and as a slice bound downstream.
            n_samples = (2**model.n_qubits) * n_samples
            n_bins = model.n_qubits * n_bins

        fidelities = Expressibility._sample_state_fidelities(
            n_samples=n_samples,
            random_key=random_key,
            model=model,
            kwargs=kwargs,
        )

        bin_edges: jnp.ndarray = jnp.linspace(0, 1, n_bins + 1)
        counts, _ = jnp.histogram(fidelities, bins=bin_edges)

        # Normalise counts to relative frequencies.
        return bin_edges, counts / n_samples

    @staticmethod
    def _haar_probability(fidelity: float, n_qubits: int) -> float:
        """
        Calculates theoretical probability density function for random Haar states
        as proposed by Sim et al. (https://arxiv.org/abs/1905.10876).

        Args:
            fidelity (float): fidelity of two parameter assignments in [0, 1]
            n_qubits (int): number of qubits in the quantum system

        Returns:
            float: probability density for the given fidelity
        """
        # Hilbert space dimension
        N = 2**n_qubits
        return (N - 1) * (1 - fidelity) ** (N - 2)

    @staticmethod
    def _sample_haar_integral(n_qubits: int, n_bins: int) -> np.ndarray:
        """
        Integrates the theoretical probability density function for random
        Haar states as proposed by Sim et al.
        (https://arxiv.org/abs/1905.10876) over ``n_bins`` equally wide bins
        covering [0, 1].

        Args:
            n_qubits (int): number of qubits in the quantum system
            n_bins (int): number of histogram bins

        Returns:
            np.ndarray: array of shape (n_bins,) with the probability mass
            of each bin
        """
        dist = np.zeros(n_bins)
        for idx in range(n_bins):
            lower = idx / n_bins
            upper = (idx + 1) / n_bins
            # quad returns (value, error estimate); only the value is kept.
            dist[idx], _ = integrate.quad(
                Expressibility._haar_probability, lower, upper, args=(n_qubits,)
            )

        return dist

    @staticmethod
    def haar_integral(
        n_qubits: int,
        n_bins: int,
        cache: bool = True,
        scale: bool = False,
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """
        Calculates theoretical probability density function for random Haar states
        as proposed by Sim et al. (https://arxiv.org/abs/1905.10876) and bins it
        into a 1D histogram.

        Args:
            n_qubits (int): number of qubits in the quantum system
            n_bins (int): number of histogram bins
            cache (bool): whether to read/write the result from/to a ``.npy``
                file inside a ``.cache`` folder in the working directory
            scale (bool): whether to scale the number of bins by ``n_qubits``

        Returns:
            Tuple[jnp.ndarray, jnp.ndarray]:
                - x component (bins): ``n_bins`` evenly spaced points in
                  [0, 1]
                - y component (probabilities): the Haar probability mass of
                  each bin, either loaded from the cache file or freshly
                  computed
        """
        if scale:
            n_bins = n_qubits * n_bins

        x = jnp.linspace(0, 1, n_bins)

        file_path = None
        if cache:
            name = f"haar_{n_qubits}q_{n_bins}s_{'scaled' if scale else ''}.npy"

            cache_folder = ".cache"
            # exist_ok avoids the check-then-create race when several
            # processes populate the cache at the same time.
            os.makedirs(cache_folder, exist_ok=True)

            file_path = os.path.join(cache_folder, name)

            if os.path.isfile(file_path):
                return x, jnp.load(file_path)

        y = Expressibility._sample_haar_integral(n_qubits, n_bins)

        if cache:
            jnp.save(file_path, y)

        return x, y

    @staticmethod
    def kullback_leibler_divergence(
        vqc_prob_dist: jnp.ndarray,
        haar_dist: jnp.ndarray,
    ) -> np.ndarray:
        """
        Calculates the KL divergence between two probability distributions (Haar
        probability distribution and the fidelity distribution sampled from a VQC).

        Args:
            vqc_prob_dist (jnp.ndarray): VQC fidelity probability distribution.
                Should have shape (n_inputs_samples, n_bins) or (n_bins, )
            haar_dist (jnp.ndarray): Haar probability distribution.
                Should have shape (n_bins, )

        Returns:
            np.ndarray: Array with one KL-Divergence value per row of
            ``vqc_prob_dist`` (a single value for 1D input)

        Raises:
            AssertionError: If ``vqc_prob_dist`` has more than one dimension
                and a row's shape differs from the shape of ``haar_dist``.
        """
        if len(vqc_prob_dist.shape) > 1:
            assert all(haar_dist.shape == p.shape for p in vqc_prob_dist), (
                "All probabilities for inputs should have the same shape as Haar. "
                f"Got {haar_dist.shape} for Haar and {vqc_prob_dist.shape} for VQC"
            )
        else:
            # Promote a single distribution to a batch of one.
            vqc_prob_dist = vqc_prob_dist.reshape((1, -1))

        kl_divergence = np.zeros(vqc_prob_dist.shape[0])
        for idx, p in enumerate(vqc_prob_dist):
            kl_divergence[idx] = jnp.sum(rel_entr(p, haar_dist))

        return kl_divergence

    @staticmethod
    def kl_divergence_to_haar(
        model: "Model",
        n_samples: int,
        n_bins: int,
        random_key: Optional[jax.random.PRNGKey] = None,
        scale: bool = False,
        **kwargs: Any,
    ) -> np.ndarray:
        """
        Shortcut method to compute the KL-Divergence between a model and the
        Haar distribution. The basic steps are:
            - Sample the state fidelities for randomly initialised parameters.
            - Calculate the KL divergence between the sampled probability and
              the Haar probability distribution.

        The Haar distribution is obtained via ``haar_integral`` with its
        default caching behaviour.

        Args:
            model (Model): Model of the quantum circuit.
            n_samples (int): Number of parameter sets to generate.
            n_bins (int): Number of histogram bins.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for
                parameter initialization. It is forwarded unchanged to the
                model's parameter initialization.
            scale (bool): Whether to scale the number of samples and bins.
            kwargs (Any): Additional keyword arguments for the model function.

        Returns:
            np.ndarray: Array with a single element, the KL-Divergence between
            the sampled fidelity distribution and the Haar distribution.
        """
        _, fidelities = Expressibility.state_fidelities(
            model=model,
            random_key=random_key,
            n_samples=n_samples,
            n_bins=n_bins,
            scale=scale,
            **kwargs,
        )
        _, haar_probs = Expressibility.haar_integral(
            model.n_qubits, n_bins=n_bins, scale=scale
        )
        return Expressibility.kullback_leibler_divergence(fidelities, haar_probs)