Coverage for qml_essentials/expressibility.py: 95%

80 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2025-01-23 11:23 +0000

1import pennylane.numpy as np 

2from typing import Tuple, List, Any 

3from scipy import integrate 

4from scipy.linalg import sqrtm 

5from scipy.special import rel_entr 

6from qml_essentials.model import Model 

7import os 

8 

9 

class Expressibility:
    """
    Helpers for estimating the expressibility of a parameterised quantum
    circuit: sample state fidelities of the circuit, build the Haar reference
    distribution (Sim et al., https://arxiv.org/abs/1905.10876) and compare
    the two with the Kullback-Leibler divergence.
    """

    @staticmethod
    def _sample_state_fidelities(
        model: "Model",
        x_samples: np.ndarray,
        n_samples: int,
        seed: int,
        kwargs: Any,
    ) -> np.ndarray:
        """
        Compute the fidelities between pairs of randomly parameterised states
        for every input sample.

        Args:
            model (Model): Circuit model. It is called with the keyword
                arguments ``inputs``, ``params`` and ``execution_type`` and
                must offer ``initialize_params(rng=..., repeat=...)`` and a
                ``params`` attribute.
            x_samples (np.ndarray): Input samples. The array is flattened
                before use and only its first ``len(x_samples)`` entries are
                evaluated, so a 1-D array is what this routine handles
                consistently.
                NOTE(review): multi-feature inputs look unsupported - confirm.
            n_samples (int): Number of fidelity samples, i.e. number of
                state pairs. Twice as many parameter sets are drawn.
            seed (int): Random number generator seed.
            kwargs (Any): Additional keyword arguments forwarded to the model.

        Returns:
            np.ndarray: Array of shape (len(x_samples), n_samples)
                containing the fidelities.
        """
        rng = np.random.default_rng(seed)

        # Two parameter sets are needed per fidelity value (one for each state
        # of the pair), hence 2 * n_samples. This call presumably replaces the
        # parameters stored on the model, which are read back below.
        model.initialize_params(rng=rng, repeat=n_samples * 2)

        n_x_samples = len(x_samples)

        # One row of fidelities per input sample
        fidelities: np.ndarray = np.zeros((n_x_samples, n_samples))

        # Repeat the (flattened) inputs once per parameter set so that a
        # single model call evaluates the whole batch
        x_samples_batched: np.ndarray = x_samples.reshape(1, -1).repeat(
            n_samples * 2, axis=0
        )

        for idx in range(n_x_samples):
            # Execution type is explicitly set to density; the result is
            # treated as a batch of 2 * n_samples density matrices
            sv: np.ndarray = model(
                inputs=x_samples_batched[:, idx],
                params=model.params,
                execution_type="density",
                **kwargs,
            )

            # $\sqrt{\rho}$ for the first half of the batch
            sqrt_sv1: np.ndarray = np.array([sqrtm(m) for m in sv[:n_samples]])

            # $\sqrt{\rho} \sigma \sqrt{\rho}$ with sigma from the second half
            inner_fidelity = sqrt_sv1 @ sv[n_samples:] @ sqrt_sv1

            # Fidelity $(\mathrm{tr}\sqrt{\sqrt{\rho}\sigma\sqrt{\rho}})^2$:
            # trace of the matrix square root, squared
            fidelity: np.ndarray = (
                np.trace(
                    np.array([sqrtm(m) for m in inner_fidelity]),
                    axis1=1,
                    axis2=2,
                )
                ** 2
            )

            # sqrtm may return complex values; keep the magnitude only
            fidelities[idx] = np.abs(fidelity)

        return fidelities

    @staticmethod
    def state_fidelities(
        seed: int,
        n_samples: int,
        n_bins: int,
        model: "Model",
        n_input_samples: int = 0,
        input_domain: List[float] = None,
        scale: bool = False,
        **kwargs: Any,
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Sample the state fidelities of a model and histogram them per input
        sample.

        Args:
            seed (int): Random number generator seed.
            n_samples (int): Number of fidelity samples to draw.
            n_bins (int): Number of histogram bins.
            model (Model): Circuit model; ``model.n_qubits`` is read when
                ``scale`` is set.
            n_input_samples (int): Number of input samples. If 0 or None, a
                single input of value 0 is used.
            input_domain (List[float]): Arguments unpacked into
                ``np.linspace`` (start, stop) to build the inputs. If None, a
                single input of value 0 is used.
            scale (bool): If True, multiply ``n_samples`` by 2**n_qubits and
                ``n_bins`` by n_qubits.
            kwargs (Any): Additional keyword arguments for the model.

        Returns:
            Tuple[np.ndarray, np.ndarray, np.ndarray]:
                - x: the input samples
                - y: the n_bins + 1 bin edges on [0, 1]
                - z: histogram counts divided by n_samples, shape
                  (n_input_samples, n_bins), flattened to (n_bins,) when
                  there is only one input sample
        """
        if scale:
            n_samples = np.power(2, model.n_qubits) * n_samples
            n_bins = model.n_qubits * n_bins

        if input_domain is None or n_input_samples is None or n_input_samples == 0:
            # No input sweep requested: evaluate at a single zero input
            x = np.zeros((1))
            n_input_samples = 1
        else:
            x = np.linspace(*input_domain, n_input_samples, requires_grad=False)

        fidelities = Expressibility._sample_state_fidelities(
            x_samples=x,
            n_samples=n_samples,
            seed=seed,
            model=model,
            kwargs=kwargs,
        )
        z: np.ndarray = np.zeros((n_input_samples, n_bins))

        # Bin edges
        y: np.ndarray = np.linspace(0, 1, n_bins + 1)

        for i, f in enumerate(fidelities):
            # np.histogram ignores values outside the bin edges. Round-off in
            # the matrix square roots can push a fidelity marginally above 1,
            # so clip to the valid range to avoid silently losing samples.
            z[i], _ = np.histogram(np.clip(f, 0.0, 1.0), bins=y)

        # Convert counts to relative frequencies
        z = z / n_samples

        if z.shape[0] == 1:
            z = z.flatten()

        return x, y, z

    @staticmethod
    def _haar_probability(fidelity: float, n_qubits: int) -> float:
        """
        Evaluate the theoretical fidelity probability density of Haar random
        states as proposed by Sim et al. (https://arxiv.org/abs/1905.10876):
        (N - 1) * (1 - F)^(N - 2) with N = 2^n_qubits.

        Args:
            fidelity (float): fidelity of two parameter assignments in [0, 1]
            n_qubits (int): number of qubits in the quantum system

        Returns:
            float: probability density at the given fidelity
        """
        N = 2**n_qubits

        prob = (N - 1) * (1 - fidelity) ** (N - 2)
        return prob

    @staticmethod
    def _sample_haar_integral(n_qubits: int, n_bins: int) -> np.ndarray:
        """
        Integrate the Haar fidelity density (Sim et al.,
        https://arxiv.org/abs/1905.10876) over each of ``n_bins`` equally wide
        bins on [0, 1].

        Args:
            n_qubits (int): number of qubits in the quantum system
            n_bins (int): number of histogram bins

        Returns:
            np.ndarray: array of shape (n_bins,) with the probability mass of
                each bin
        """
        dist = np.zeros(n_bins)
        for idx in range(n_bins):
            # Lower and upper edge of the current bin
            v = idx / n_bins
            u = (idx + 1) / n_bins
            # quad returns (value, error estimate); only the value is kept
            dist[idx], _ = integrate.quad(
                Expressibility._haar_probability, v, u, args=(n_qubits,)
            )

        return dist

    @staticmethod
    def haar_integral(
        n_qubits: int,
        n_bins: int,
        cache: bool = True,
        scale: bool = False,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Return the binned Haar fidelity distribution (Sim et al.,
        https://arxiv.org/abs/1905.10876), optionally reading it from / writing
        it to a ``.cache`` folder in the current working directory.

        Args:
            n_qubits (int): number of qubits in the quantum system
            n_bins (int): number of histogram bins
            cache (bool): whether to cache the haar integral on disk
            scale (bool): whether to scale the number of bins by n_qubits

        Returns:
            Tuple[np.ndarray, np.ndarray]:
                - x component: n_bins evenly spaced points on [0, 1]
                - y component: probability mass of the Haar probability
                  density function in each bin, shape (n_bins,)
        """
        if scale:
            n_bins = n_qubits * n_bins

        x = np.linspace(0, 1, n_bins)

        file_path = None
        if cache:
            name = f"haar_{n_qubits}q_{n_bins}s_{'scaled' if scale else ''}.npy"

            cache_folder = ".cache"
            # exist_ok avoids the check-then-create race when several
            # processes start at the same time
            os.makedirs(cache_folder, exist_ok=True)

            file_path = os.path.join(cache_folder, name)

            if os.path.isfile(file_path):
                y = np.load(file_path)
                return x, y

        y = Expressibility._sample_haar_integral(n_qubits, n_bins)

        if file_path is not None:
            np.save(file_path, y)

        return x, y

    @staticmethod
    def kullback_leibler_divergence(
        vqc_prob_dist: np.ndarray,
        haar_dist: np.ndarray,
    ) -> np.ndarray:
        """
        Calculate the KL divergence between the fidelity distribution sampled
        from a VQC and the Haar probability distribution.

        Args:
            vqc_prob_dist (np.ndarray): VQC fidelity probability distribution.
                Should have shape (n_inputs_samples, n_bins) or (n_bins, )
            haar_dist (np.ndarray): Haar probability distribution.
                Should have shape (n_bins, )

        Returns:
            np.ndarray: Array with one KL divergence value per row of
                ``vqc_prob_dist`` (a single entry for 1-D input)

        Raises:
            AssertionError: if a VQC distribution does not have the same
                shape as the Haar distribution
        """
        shape_message = (
            "All probabilities for inputs should have the same shape as Haar. "
            f"Got {haar_dist.shape} for Haar and {vqc_prob_dist.shape} for VQC"
        )

        if len(vqc_prob_dist.shape) > 1:
            assert all(
                [haar_dist.shape == p.shape for p in vqc_prob_dist]
            ), shape_message
        else:
            # Validate the 1-D case as well; otherwise broadcasting inside
            # rel_entr could silently produce a wrong result
            assert haar_dist.shape == vqc_prob_dist.shape, shape_message
            vqc_prob_dist = vqc_prob_dist.reshape((1, -1))

        kl_divergence = np.zeros(vqc_prob_dist.shape[0])
        for idx, p in enumerate(vqc_prob_dist):
            # rel_entr gives the element-wise terms p * log(p / q)
            kl_divergence[idx] = np.sum(rel_entr(p, haar_dist))

        return kl_divergence

262 

263 return kl_divergence