Coverage for qml_essentials/expressibility.py: 98%

84 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2026-02-20 14:03 +0000

1import jax.numpy as jnp 

2from jax import random 

3import numpy as np 

4from typing import Tuple, List, Any 

5from scipy import integrate 

6from scipy.linalg import sqrtm 

7from scipy.special import rel_entr 

8from qml_essentials.model import Model 

9import os 

10 

11 

class Expressibility:
    """
    Expressibility measures for quantum models after Sim et al.
    (https://arxiv.org/abs/1905.10876).

    Pairs of states produced by randomly parameterized circuits are compared
    via their fidelity. The resulting fidelity histogram is compared, using
    the Kullback-Leibler divergence, to the analytic fidelity distribution of
    Haar-random states.
    """

    @staticmethod
    def _sample_state_fidelities(
        model: "Model",
        x_samples: jnp.ndarray,
        n_samples: int,
        seed: int,
        kwargs: Any,
    ) -> jnp.ndarray:
        """
        Compute the fidelities for each pair of input samples and parameter sets.

        ``2 * n_samples`` random parameter sets are initialized once. For every
        input sample the model is evaluated in density-matrix mode, and the
        first ``n_samples`` states are paired with the last ``n_samples``
        states.

        Args:
            model (Model): Model of the quantum circuit. Its
                ``initialize_params`` is called with ``repeat=2 * n_samples``.
                Calling the model is expected to return a stack of
                ``2 * n_samples`` density matrices.
            x_samples (jnp.ndarray): Input samples. The model is evaluated once
                per entry along the first axis.
            n_samples (int): Number of fidelity samples (state pairs) per input.
            seed (int): Random number generator seed.
            kwargs (Any): Additional keyword arguments for the model function.

        Returns:
            jnp.ndarray: Array of shape (n_input_samples, n_samples)
            containing the fidelities, clipped to [0, 1].
        """
        random_key = random.key(seed)

        # Two parameter sets per fidelity sample, since each fidelity compares
        # a pair of random states.
        model.initialize_params(random_key, repeat=n_samples * 2)

        fidelities: jnp.ndarray = jnp.zeros((len(x_samples), n_samples))

        for idx, x_sample in enumerate(x_samples):
            # Execution type is forced to "density" because the fidelity
            # formula below operates on density matrices.
            rho: jnp.ndarray = model(
                inputs=x_sample,
                params=model.params,
                execution_type="density",
                **kwargs,
            )

            # Uhlmann fidelity between the two halves of the stack:
            # $F(\rho, \sigma) = (\mathrm{Tr}\sqrt{\sqrt{\rho}\sigma\sqrt{\rho}})^2$
            sqrt_rho = jnp.array([sqrtm(m) for m in rho[:n_samples]])
            inner = sqrt_rho @ rho[n_samples:] @ sqrt_rho
            fidelity = (
                jnp.trace(
                    jnp.array([sqrtm(m) for m in inner]),
                    axis1=1,
                    axis2=2,
                )
                ** 2
            )

            # sqrtm can return complex arrays, and numerical error can push the
            # value marginally above 1. Take the modulus and clip so that every
            # sample lands inside the [0, 1] histogram range used downstream.
            fidelities = fidelities.at[idx].set(
                jnp.clip(jnp.abs(fidelity), 0.0, 1.0)
            )

        return fidelities

    @staticmethod
    def state_fidelities(
        seed: int,
        n_samples: int,
        n_bins: int,
        model: "Model",
        n_input_samples: int = 0,
        input_domain: List[float] = None,
        scale: bool = False,
        **kwargs: Any,
    ) -> Tuple[jnp.ndarray, jnp.ndarray, np.ndarray]:
        """
        Sample the state fidelities and histogram them per input sample.

        Args:
            seed (int): Random number generator seed.
            n_samples (int): Number of fidelity samples (state pairs) per input.
            n_bins (int): Number of histogram bins over [0, 1].
            model (Model): Model of the quantum circuit.
            n_input_samples (int): Number of input samples. When 0/None (or
                when ``input_domain`` is None), a single zero input is used.
            input_domain (List[float]): ``[start, stop]`` of the input range,
                sampled with ``n_input_samples`` evenly spaced points.
            scale (bool): If True, ``n_samples`` is multiplied by
                ``2**n_qubits`` and ``n_bins`` by ``n_qubits``.
            kwargs (Any): Additional keyword arguments for the model function.

        Returns:
            Tuple[jnp.ndarray, jnp.ndarray, np.ndarray]: The input samples, the
            ``n_bins + 1`` bin edges, and the histogram values normalized by
            ``n_samples``. The histogram has shape (n_input_samples, n_bins),
            and is flattened to (n_bins,) when there is a single input sample.
        """
        if scale:
            # Keep these as plain Python ints: they are used as array shapes,
            # slice bounds and the parameter repeat count.
            n_samples = 2**model.n_qubits * n_samples
            n_bins = model.n_qubits * n_bins

        if input_domain is None or not n_input_samples:
            # No input sweep: evaluate the model at a single zero input.
            x = jnp.zeros(1)
            n_input_samples = 1
        else:
            x = jnp.linspace(*input_domain, n_input_samples)

        fidelities = Expressibility._sample_state_fidelities(
            x_samples=x,
            n_samples=n_samples,
            seed=seed,
            model=model,
            kwargs=kwargs,
        )

        y: jnp.ndarray = jnp.linspace(0, 1, n_bins + 1)
        z: np.ndarray = np.zeros((n_input_samples, n_bins))
        for i, f in enumerate(fidelities):
            z[i], _ = jnp.histogram(f, bins=y)

        # Turn counts into relative frequencies.
        z = z / n_samples

        if z.shape[0] == 1:
            z = z.flatten()

        return x, y, z

    @staticmethod
    def _haar_probability(fidelity: float, n_qubits: int) -> float:
        """
        Probability density of the fidelity between two Haar-random states,
        P(F) = (N - 1) * (1 - F)**(N - 2) with N = 2**n_qubits, as proposed by
        Sim et al. (https://arxiv.org/abs/1905.10876).

        Args:
            fidelity (float): fidelity of two parameter assignments in [0, 1]
            n_qubits (int): number of qubits in the quantum system

        Returns:
            float: probability density at the given fidelity
        """
        N = 2**n_qubits
        return (N - 1) * (1 - fidelity) ** (N - 2)

    @staticmethod
    def _sample_haar_integral(n_qubits: int, n_bins: int) -> np.ndarray:
        """
        Integrate the Haar fidelity density over ``n_bins`` equal-width bins
        of [0, 1], yielding a 1-D histogram of probability masses.

        Args:
            n_qubits (int): number of qubits in the quantum system
            n_bins (int): number of histogram bins

        Returns:
            np.ndarray: probability mass of each bin, shape (n_bins,)
        """
        dist = np.zeros(n_bins)
        for idx in range(n_bins):
            lower = idx / n_bins
            upper = (idx + 1) / n_bins
            dist[idx], _ = integrate.quad(
                Expressibility._haar_probability, lower, upper, args=(n_qubits,)
            )

        return dist

    @staticmethod
    def haar_integral(
        n_qubits: int,
        n_bins: int,
        cache: bool = True,
        scale: bool = False,
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """
        Bin the Haar fidelity density (Sim et al.,
        https://arxiv.org/abs/1905.10876) into a 1-D histogram, optionally
        caching the result as ``.cache/haar_*.npy`` relative to the current
        working directory.

        Args:
            n_qubits (int): number of qubits in the quantum system
            n_bins (int): number of histogram bins
            cache (bool): whether to read/write the histogram from/to the cache
            scale (bool): whether to multiply the number of bins by n_qubits

        Returns:
            Tuple[jnp.ndarray, jnp.ndarray]:
                - x component: ``n_bins`` evenly spaced points over [0, 1]
                - y component: probability mass of each fidelity bin for
                  random Haar states
        """
        if scale:
            n_bins = n_qubits * n_bins

        x = jnp.linspace(0, 1, n_bins)

        file_path = None
        if cache:
            name = f"haar_{n_qubits}q_{n_bins}s_{'scaled' if scale else ''}.npy"
            cache_folder = ".cache"
            # exist_ok avoids a race between the existence check and creation.
            os.makedirs(cache_folder, exist_ok=True)
            file_path = os.path.join(cache_folder, name)

            if os.path.isfile(file_path):
                return x, jnp.load(file_path)

        y = Expressibility._sample_haar_integral(n_qubits, n_bins)

        if file_path is not None:
            jnp.save(file_path, y)

        return x, y

    @staticmethod
    def kullback_leibler_divergence(
        vqc_prob_dist: jnp.ndarray,
        haar_dist: jnp.ndarray,
    ) -> np.ndarray:
        """
        Calculates the KL divergence between the fidelity distribution sampled
        from a VQC and the Haar probability distribution.

        Args:
            vqc_prob_dist (jnp.ndarray): VQC fidelity probability distribution
                with shape (n_input_samples, n_bins), or (n_bins,) for a single
                input.
            haar_dist (jnp.ndarray): Haar probability distribution with shape
                (n_bins,).

        Returns:
            np.ndarray: KL divergence for each row of ``vqc_prob_dist``
            (shape (n_input_samples,), or (1,) for 1-D input).

        Raises:
            AssertionError: if a row of ``vqc_prob_dist`` does not match the
                shape of ``haar_dist``.
        """
        vqc_prob_dist = np.asarray(vqc_prob_dist)
        haar_dist = np.asarray(haar_dist)

        if vqc_prob_dist.ndim <= 1:
            vqc_prob_dist = vqc_prob_dist.reshape((1, -1))

        # Validate after reshaping so 1-D inputs are checked as well.
        assert all(haar_dist.shape == p.shape for p in vqc_prob_dist), (
            "All probabilities for inputs should have the same shape as Haar. "
            f"Got {haar_dist.shape} for Haar and {vqc_prob_dist.shape} for VQC"
        )

        return np.array(
            [np.sum(rel_entr(p, haar_dist)) for p in vqc_prob_dist],
            dtype=np.float64,
        )

    @staticmethod
    def kl_divergence_to_haar(
        model: "Model",
        seed: int,
        n_samples: int,
        n_bins: int,
        n_input_samples: int = 0,
        input_domain: List[float] = None,
        scale: bool = False,
        **kwargs: Any,
    ) -> np.ndarray:
        """
        Shortcut method to compute the KL divergence between a model and the
        Haar distribution. The basic steps are:
            - Sample the state fidelities for randomly initialised parameters.
            - Calculate the KL divergence between the sampled probability and
              the Haar probability distribution (computed with
              ``haar_integral``'s default caching).

        Args:
            model (Model): Model of the quantum circuit.
            seed (int): Random number generator seed.
            n_samples (int): Number of parameter sets to generate.
            n_bins (int): Number of histogram bins.
            n_input_samples (int): Number of input samples.
            input_domain (List[float]): Input domain.
            scale (bool): Whether to scale the number of samples and bins.
            kwargs (Any): Additional keyword arguments for the model function.

        Returns:
            np.ndarray: KL divergence per input sample, shape
            (n_input_samples,), or (1,) when a single input is used.
        """
        _, _, fidelities = Expressibility.state_fidelities(
            model=model,
            seed=seed,
            n_samples=n_samples,
            n_bins=n_bins,
            n_input_samples=n_input_samples,
            input_domain=input_domain,
            scale=scale,
            **kwargs,
        )
        _, haar_probs = Expressibility.haar_integral(
            model.n_qubits, n_bins=n_bins, scale=scale
        )
        return Expressibility.kullback_leibler_divergence(fidelities, haar_probs)