Coverage for qml_essentials / expressibility.py: 99%

74 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-03-30 11:43 +0000

1import jax.numpy as jnp 

2from jax import random 

3import jax 

4import numpy as np 

5from typing import Tuple, List, Any, Optional 

6from scipy import integrate 

7from scipy.linalg import sqrtm 

8from scipy.special import rel_entr 

9from qml_essentials.model import Model 

10import os 

11 

12 

class Expressibility:
    """
    Expressibility of parameterised quantum circuits as proposed by
    Sim et al. (https://arxiv.org/abs/1905.10876).

    Fidelities between pairs of output states obtained from randomly
    initialised parameters are histogrammed and compared with the analytic
    fidelity distribution of Haar-random states via the Kullback-Leibler
    divergence.
    """

    @staticmethod
    def _psd_sqrtm(mats: jnp.ndarray) -> jnp.ndarray:
        """
        Principal square root of a batch of Hermitian positive semi-definite
        matrices.

        An eigendecomposition is used instead of a general-purpose ``sqrtm``.
        Density matrices of pure states are singular, and a generic matrix
        square root is numerically unreliable there. ``eigh`` is exact up to
        rounding and processes the whole batch at once.

        Args:
            mats (jnp.ndarray): Array of shape (..., d, d) of Hermitian PSD
                matrices.

        Returns:
            jnp.ndarray: Array of the same shape holding the square roots.
        """
        eigvals, eigvecs = jnp.linalg.eigh(mats)
        # Rounding can leave tiny negative eigenvalues; clamp before sqrt.
        sqrt_vals = jnp.sqrt(jnp.maximum(eigvals, 0.0))
        # V diag(sqrt(w)) V^H, scaling column j of V by sqrt(w_j).
        return (eigvecs * sqrt_vals[..., None, :]) @ jnp.conj(
            jnp.swapaxes(eigvecs, -1, -2)
        )

    @staticmethod
    def _sample_state_fidelities(
        model: "Model",
        n_samples: int,
        random_key: Optional[jax.random.PRNGKey] = None,
        kwargs: Any = None,
    ) -> jnp.ndarray:
        """
        Compute the fidelities for each pair of randomly parameterised states.

        The model is initialised with ``2 * n_samples`` parameter sets and
        evaluated in density-matrix mode. The first half of the outputs (rho)
        is paired element-wise with the second half (sigma). For each pair the
        Uhlmann fidelity is computed:
        F = (Tr sqrt(sqrt(rho) sigma sqrt(rho)))^2.

        Args:
            model (Model): Model of the quantum circuit. It must provide
                ``initialize_params``, ``params`` and be callable with
                ``execution_type="density"``.
            n_samples (int): Number of state pairs (fidelities) to generate.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for
                parameter initialization. If None, uses the model's internal
                random key.
            kwargs (Optional[dict]): Additional keyword arguments for the
                model call. None means no extra arguments.

        Returns:
            jnp.ndarray: Array of shape (n_samples,) containing the fidelities,
            assuming the model returns density matrices of shape
            (2 * n_samples, d, d).
        """
        if kwargs is None:
            kwargs = {}

        # Two parameter sets per fidelity: one for rho, one for sigma.
        model.initialize_params(random_key, repeat=n_samples * 2)

        # Execution type is explicitly set to density.
        rho = jnp.asarray(
            model(
                params=model.params,
                execution_type="density",
                **kwargs,
            )
        )

        sqrt_rho = Expressibility._psd_sqrtm(rho[:n_samples])
        sigma = rho[n_samples:]

        # sqrt(rho) sigma sqrt(rho) is Hermitian PSD, so the trace of its
        # square root is the sum of the square roots of its eigenvalues.
        inner = sqrt_rho @ sigma @ sqrt_rho
        inner_eigvals = jnp.linalg.eigvalsh(inner)
        fidelity = (
            jnp.sum(jnp.sqrt(jnp.maximum(inner_eigvals, 0.0)), axis=-1) ** 2
        )

        return fidelity

    @staticmethod
    def state_fidelities(
        n_samples: int,
        n_bins: int,
        model: "Model",
        random_key: Optional[jax.random.PRNGKey] = None,
        scale: bool = False,
        **kwargs: Any,
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """
        Sample the state fidelities and histogram them.

        Args:
            n_samples (int): Number of parameter-set pairs to generate.
            n_bins (int): Number of histogram bins on [0, 1].
            model (Model): Model of the quantum circuit.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for
                parameter initialization. If None, uses the model's internal
                random key.
            scale (bool): If True, multiply ``n_samples`` by
                ``2**model.n_qubits`` and ``n_bins`` by ``model.n_qubits``.
            kwargs (Any): Additional keyword arguments for the model function.

        Returns:
            Tuple[jnp.ndarray, jnp.ndarray]: The bin edges, with shape
            (n_bins + 1,), and the histogram values normalised by the number
            of samples, with shape (n_bins,).
        """
        if scale:
            # Keep counts as Python ints; they are used as repeat counts and
            # slice bounds downstream.
            n_samples = 2**model.n_qubits * n_samples
            n_bins = model.n_qubits * n_bins

        fidelities = Expressibility._sample_state_fidelities(
            n_samples=n_samples,
            random_key=random_key,
            model=model,
            kwargs=kwargs,
        )
        # Rounding can push fidelities marginally outside [0, 1]. Such values
        # would silently fall outside the histogram range.
        fidelities = jnp.clip(fidelities, 0.0, 1.0)

        y: jnp.ndarray = jnp.linspace(0, 1, n_bins + 1)

        # The last bin is closed, so a fidelity of exactly 1 is counted.
        z, _ = jnp.histogram(fidelities, bins=y)

        z = z / n_samples

        return y, z

    @staticmethod
    def _haar_probability(fidelity: float, n_qubits: int) -> float:
        """
        Theoretical probability density of the fidelity between two random
        Haar states, as proposed by Sim et al.
        (https://arxiv.org/abs/1905.10876): P(F) = (N - 1) (1 - F)^(N - 2)
        with N = 2**n_qubits.

        Args:
            fidelity (float): fidelity of two parameter assignments in [0, 1]
            n_qubits (int): number of qubits in the quantum system

        Returns:
            float: probability density at the given fidelity
        """
        N = 2**n_qubits

        prob = (N - 1) * (1 - fidelity) ** (N - 2)
        return prob

    @staticmethod
    def _sample_haar_integral(n_qubits: int, n_bins: int) -> np.ndarray:
        """
        Probability mass of the Haar fidelity distribution (Sim et al.,
        https://arxiv.org/abs/1905.10876) in each of ``n_bins`` equal-width
        bins on [0, 1].

        The density (N - 1)(1 - F)^(N - 2) has the antiderivative
        -(1 - F)^(N - 1). The mass of bin [v, u] is therefore exactly
        (1 - v)^(N - 1) - (1 - u)^(N - 1), and no numerical quadrature is
        needed.

        Args:
            n_qubits (int): number of qubits in the quantum system
            n_bins (int): number of histogram bins

        Returns:
            np.ndarray: 1D float64 array of shape (n_bins,) with the
            probability of each bin.
        """
        N = 2**n_qubits
        edges = np.arange(n_bins + 1) / n_bins
        return -np.diff((1.0 - edges) ** (N - 1))

    @staticmethod
    def haar_integral(
        n_qubits: int,
        n_bins: int,
        cache: bool = True,
        scale: bool = False,
    ) -> Tuple[jnp.ndarray, np.ndarray]:
        """
        Binned theoretical fidelity distribution of random Haar states, as
        proposed by Sim et al. (https://arxiv.org/abs/1905.10876).

        Args:
            n_qubits (int): number of qubits in the quantum system
            n_bins (int): number of histogram bins
            cache (bool): whether to read and write the result in the
                ``.cache`` folder under the current working directory
            scale (bool): whether to multiply the number of bins by n_qubits

        Returns:
            Tuple[jnp.ndarray, np.ndarray]:
                - x component: ``n_bins`` evenly spaced points on [0, 1]
                - y component: probability of each fidelity bin for random
                  Haar states, as a NumPy array (cached or freshly computed)
        """
        if scale:
            n_bins = n_qubits * n_bins

        x = jnp.linspace(0, 1, n_bins)

        if not cache:
            return x, Expressibility._sample_haar_integral(n_qubits, n_bins)

        name = f"haar_{n_qubits}q_{n_bins}s_{'scaled' if scale else ''}.npy"

        cache_folder = ".cache"
        # exist_ok avoids the race between checking for and creating the dir.
        os.makedirs(cache_folder, exist_ok=True)

        file_path = os.path.join(cache_folder, name)

        if os.path.isfile(file_path):
            # np.load keeps float64 and matches the type of the fresh path.
            return x, np.load(file_path)

        y = Expressibility._sample_haar_integral(n_qubits, n_bins)
        np.save(file_path, y)

        return x, y

    @staticmethod
    def kullback_leibler_divergence(
        vqc_prob_dist: jnp.ndarray,
        haar_dist: jnp.ndarray,
    ) -> np.ndarray:
        """
        Calculates the KL divergence between the fidelity distribution sampled
        from a VQC and the Haar probability distribution.

        Args:
            vqc_prob_dist (jnp.ndarray): VQC fidelity probability distribution.
                Shape (n_inputs_samples, n_bins), or (n_bins,) for a single
                distribution.
            haar_dist (jnp.ndarray): Haar probability distribution.
                Should have shape (n_bins, )

        Returns:
            np.ndarray: float64 array of KL-divergence values, one per row of
            ``vqc_prob_dist`` (shape (1,) for a 1D input).

        Raises:
            AssertionError: if the distributions' bin counts differ.
        """
        vqc_prob_dist = np.asarray(vqc_prob_dist, dtype=np.float64)
        haar_dist = np.asarray(haar_dist, dtype=np.float64)

        if vqc_prob_dist.ndim <= 1:
            vqc_prob_dist = vqc_prob_dist.reshape((1, -1))

        # Checked for 1D input too; otherwise rel_entr would broadcast a
        # mismatched length silently.
        assert vqc_prob_dist.shape[1:] == haar_dist.shape, (
            "All probabilities for inputs should have the same shape as Haar. "
            f"Got {haar_dist.shape} for Haar and {vqc_prob_dist.shape} for VQC"
        )

        # rel_entr broadcasts haar_dist over every row.
        return np.sum(rel_entr(vqc_prob_dist, haar_dist), axis=1)

    @staticmethod
    def kl_divergence_to_haar(
        model: "Model",
        n_samples: int,
        n_bins: int,
        random_key: Optional[jax.random.PRNGKey] = None,
        scale: bool = False,
        **kwargs: Any,
    ) -> np.ndarray:
        """
        Shortcut method to compute the KL-Divergence between a model and the
        Haar distribution. The basic steps are:
        - Sample the state fidelities for randomly initialised parameters.
        - Calculate the KL divergence between the sampled probability and
          the Haar probability distribution.

        Args:
            model (Model): Model of the quantum circuit.
            n_samples (int): Number of parameter sets to generate.
            n_bins (int): Number of histogram bins.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for
                parameter initialization. If None, uses the model's internal
                random key.
            scale (bool): Whether to scale the number of samples and bins.
            kwargs (Any): Additional keyword arguments for the model function.

        Returns:
            np.ndarray: KL-divergence values as returned by
            ``kullback_leibler_divergence``; shape (1,) for the single
            histogram sampled here.
        """
        _, fidelities = Expressibility.state_fidelities(
            model=model,
            random_key=random_key,
            n_samples=n_samples,
            n_bins=n_bins,
            scale=scale,
            **kwargs,
        )
        _, haar_probs = Expressibility.haar_integral(
            model.n_qubits, n_bins=n_bins, scale=scale
        )
        return Expressibility.kullback_leibler_divergence(fidelities, haar_probs)