Coverage for qml_essentials / expressibility.py: 99%

73 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-05-16 10:19 +0000

1import jax.numpy as jnp 

2import jax 

3import numpy as np 

4from typing import Tuple, Any, Optional 

5from scipy import integrate 

6from scipy.linalg import sqrtm 

7from scipy.special import rel_entr 

8from qml_essentials.model import Model 

9import os 

10 

11 

class Expressibility:
    """
    Expressibility of parameterised quantum circuits following Sim et al.
    (https://arxiv.org/abs/1905.10876).

    The fidelities of pairs of randomly parameterised model states are
    histogrammed and compared against the analytic fidelity distribution of
    Haar-random states using the Kullback-Leibler divergence.
    """

    @classmethod
    def _sample_state_fidelities(
        cls,
        model: "Model",
        n_samples: int,
        random_key: Optional[jax.random.PRNGKey] = None,
        kwargs: Any = None,
    ) -> jnp.ndarray:
        """
        Compute the fidelities of ``n_samples`` pairs of random model states.

        ``2 * n_samples`` parameter sets are initialised via
        ``model.initialize_params`` and the model is evaluated once on
        ``model.params`` in density-matrix mode. The first ``n_samples``
        density matrices (rho) are paired element-wise with the last
        ``n_samples`` (sigma), and for each pair the Uhlmann fidelity
        ``F = Tr(sqrt(sqrt(rho) @ sigma @ sqrt(rho)))**2`` is computed.

        Args:
            model (Model): Model implementing the quantum circuit.
            n_samples (int): Number of state pairs (fidelities) to compute.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for
                parameter initialization. If None, uses the model's internal
                random key.
            kwargs (Any): Mapping of additional keyword arguments for the
                model call. ``None`` means no additional arguments.

        Returns:
            jnp.ndarray: Array of shape (n_samples,) containing the fidelities.
        """
        # Plain int so it is safe for ``repeat=`` and for slicing.
        n_samples = int(n_samples)
        # ``**None`` raises a TypeError, so treat a missing mapping as empty.
        model_kwargs = {} if kwargs is None else kwargs

        # Two parameter sets per fidelity: one for rho and one for sigma.
        model.initialize_params(random_key, repeat=n_samples * 2)

        # Density-matrix execution is needed for the matrix square roots
        # below; the output is treated as a stack of matrices (see the trace
        # over axes 1 and 2).
        states: jnp.ndarray = model(
            params=model.params,
            execution_type="density",
            **model_kwargs,
        )
        rho, sigma = states[:n_samples], states[n_samples:]

        # sqrt(rho), one matrix square root per sample
        sqrt_rho = jnp.array([sqrtm(m) for m in rho])

        # sqrt(rho) @ sigma @ sqrt(rho)
        inner = sqrt_rho @ sigma @ sqrt_rho

        # F = (Tr sqrt(sqrt(rho) sigma sqrt(rho)))^2
        fidelity = (
            jnp.trace(
                jnp.array([sqrtm(m) for m in inner]),
                axis1=1,
                axis2=2,
            )
            ** 2
        )

        # sqrtm may return complex matrices; take the magnitude to obtain a
        # real-valued fidelity.
        return jnp.abs(fidelity)

    @classmethod
    def state_fidelities(
        cls,
        n_samples: int,
        n_bins: int,
        model: "Model",
        random_key: Optional[jax.random.PRNGKey] = None,
        scale: bool = False,
        **kwargs: Any,
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """
        Sample state fidelities and histogram them into a 1-D probability
        distribution over [0, 1].

        Args:
            n_samples (int): Number of fidelities (state pairs) to sample.
            n_bins (int): Number of histogram bins.
            model (Model): Model implementing the quantum circuit.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for
                parameter initialization. If None, uses the model's internal
                random key.
            scale (bool): If True, multiply ``n_samples`` by
                ``2**model.n_qubits`` and ``n_bins`` by ``model.n_qubits``.
            kwargs (Any): Additional keyword arguments for the model call.

        Returns:
            Tuple[jnp.ndarray, jnp.ndarray]: The ``n_bins + 1`` bin edges and
            the ``n_bins`` histogram values (counts divided by ``n_samples``).
        """
        if scale:
            # Python ints (rather than jnp scalars) keep the sample count
            # usable for slicing and ``repeat=`` downstream.
            n_samples = 2**model.n_qubits * n_samples
            n_bins = model.n_qubits * n_bins

        fidelities = cls._sample_state_fidelities(
            n_samples=n_samples,
            random_key=random_key,
            model=model,
            kwargs=kwargs,
        )

        bin_edges: jnp.ndarray = jnp.linspace(0, 1, n_bins + 1)

        # Fidelities lie in [0, 1] analytically, but numerical noise can push
        # them slightly outside. Values beyond the outermost bin edges are not
        # counted by the histogram, which would make the probabilities sum to
        # less than one.
        fidelities = jnp.clip(fidelities, 0.0, 1.0)

        counts, _ = jnp.histogram(fidelities, bins=bin_edges)

        return bin_edges, counts / n_samples

    @classmethod
    def _haar_probability(cls, fidelity: float, n_qubits: int) -> float:
        """
        Theoretical probability density of the fidelity between two Haar
        random states, as proposed by Sim et al.
        (https://arxiv.org/abs/1905.10876):
        ``p(F) = (N - 1) * (1 - F)**(N - 2)`` with ``N = 2**n_qubits``.

        Args:
            fidelity (float): fidelity of two parameter assignments in [0, 1]
            n_qubits (int): number of qubits in the quantum system

        Returns:
            float: probability density at the given fidelity
        """
        N = 2**n_qubits

        return (N - 1) * (1 - fidelity) ** (N - 2)

    @classmethod
    def _sample_haar_integral(cls, n_qubits: int, n_bins: int) -> np.ndarray:
        """
        Integrate the Haar fidelity density (see ``_haar_probability``) over
        ``n_bins`` equally sized bins on [0, 1], yielding a 1-D histogram.

        The density integrates in closed form: its CDF is
        ``1 - (1 - F)**(N - 1)``, so the mass in the bin ``[v, u]`` is
        ``(1 - v)**(N - 1) - (1 - u)**(N - 1)``. This is exact and avoids a
        numerical quadrature per bin.

        Args:
            n_qubits (int): number of qubits in the quantum system
            n_bins (int): number of histogram bins

        Returns:
            np.ndarray: probability mass of each of the ``n_bins`` bins
        """
        N = 2**n_qubits

        # Bin edges idx / n_bins for idx = 0 .. n_bins.
        edges = np.arange(n_bins + 1) / n_bins
        # Complementary CDF (1 - F)^(N - 1) evaluated at each edge.
        survival = (1.0 - edges) ** (N - 1)

        return survival[:-1] - survival[1:]

    @classmethod
    def haar_integral(
        cls,
        n_qubits: int,
        n_bins: int,
        cache: bool = True,
        scale: bool = False,
        cache_dir: str = ".cache",
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """
        Theoretical fidelity distribution of Haar random states as proposed
        by Sim et al. (https://arxiv.org/abs/1905.10876), binned into a 1-D
        histogram.

        Args:
            n_qubits (int): number of qubits in the quantum system
            n_bins (int): number of histogram bins
            cache (bool): whether to load/store the histogram as a ``.npy``
                file in ``cache_dir``
            scale (bool): whether to multiply the number of bins by
                ``n_qubits``
            cache_dir (str): directory used for the cache files (created if
                missing). Defaults to ``.cache`` in the working directory.

        Returns:
            Tuple[jnp.ndarray, jnp.ndarray]:
                - x component: ``n_bins`` evenly spaced points in [0, 1]
                - y component: the Haar probability mass of each bin
        """
        if scale:
            n_bins = n_qubits * n_bins

        x = jnp.linspace(0, 1, n_bins)

        if not cache:
            return x, cls._sample_haar_integral(n_qubits, n_bins)

        # File name kept unchanged so existing cache files remain valid.
        name = f"haar_{n_qubits}q_{n_bins}s_{'scaled' if scale else ''}.npy"
        # exist_ok avoids the race between checking for and creating the dir.
        os.makedirs(cache_dir, exist_ok=True)
        file_path = os.path.join(cache_dir, name)

        if os.path.isfile(file_path):
            return x, jnp.load(file_path)

        y = cls._sample_haar_integral(n_qubits, n_bins)
        np.save(file_path, y)

        return x, y

    @classmethod
    def kullback_leibler_divergence(
        cls,
        vqc_prob_dist: jnp.ndarray,
        haar_dist: jnp.ndarray,
    ) -> np.ndarray:
        """
        Calculate the KL divergence between the fidelity distribution sampled
        from a VQC and the Haar probability distribution.

        Args:
            vqc_prob_dist (jnp.ndarray): VQC fidelity probability distribution,
                either of shape (n_bins,) or (n_input_samples, n_bins).
            haar_dist (jnp.ndarray): Haar probability distribution of shape
                (n_bins,).

        Returns:
            np.ndarray: KL divergence of each row of ``vqc_prob_dist`` with
            respect to ``haar_dist``; shape (1,) for 1-D input.

        Raises:
            AssertionError: if the bin dimension of ``vqc_prob_dist`` does not
                match the shape of ``haar_dist``.
        """
        vqc = np.asarray(vqc_prob_dist)
        haar = np.asarray(haar_dist)
        given_shape = vqc.shape

        if vqc.ndim < 2:
            vqc = vqc.reshape((1, -1))

        # Checked for 1-D input as well: otherwise a length-1 Haar array
        # would silently broadcast against every bin.
        assert vqc.shape[1:] == haar.shape, (
            "All probabilities for inputs should have the same shape as Haar. "
            f"Got {haar.shape} for Haar and {given_shape} for VQC"
        )

        # rel_entr is element-wise (broadcast over rows); summing over the
        # bin axis gives one KL divergence per row.
        return np.sum(rel_entr(vqc, haar), axis=1, dtype=np.float64)

    @staticmethod
    def kl_divergence_to_haar(
        model: "Model",
        n_samples: int,
        n_bins: int,
        random_key: Optional[jax.random.PRNGKey] = None,
        scale: bool = False,
        **kwargs: Any,
    ) -> np.ndarray:
        """
        Shortcut to compute the KL divergence between a model's fidelity
        distribution and the Haar distribution. The basic steps are:
            - Sample the state fidelities for randomly initialised parameters
              and histogram them.
            - Calculate the KL divergence between the sampled probability and
              the Haar probability distribution.

        Args:
            model (Model): Model implementing the quantum circuit.
            n_samples (int): Number of parameter sets to generate.
            n_bins (int): Number of histogram bins.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for
                parameter initialization. If None, uses the model's internal
                random key.
            scale (bool): Whether to scale the number of samples and bins.
            kwargs (Any): Additional keyword arguments for the model function.

        Returns:
            np.ndarray: Array of shape (1,) holding the KL divergence.
        """
        _, vqc_probs = Expressibility.state_fidelities(
            model=model,
            random_key=random_key,
            n_samples=n_samples,
            n_bins=n_bins,
            scale=scale,
            **kwargs,
        )
        _, haar_probs = Expressibility.haar_integral(
            model.n_qubits, n_bins=n_bins, scale=scale
        )
        return Expressibility.kullback_leibler_divergence(vqc_probs, haar_probs)