Coverage for qml_essentials/expressibility.py: 100%

78 statements  

« prev     ^ index     » next       coverage.py v7.6.5, created at 2024-11-15 11:13 +0000

1import pennylane.numpy as np 

2from typing import Tuple, List, Any 

3from scipy import integrate 

4from scipy.special import rel_entr 

5import os 

6 

7from qml_essentials.model import Model 

8 

9 

class Expressibility:
    """
    Expressibility analysis of parametrised quantum circuits following
    Sim et al. (https://arxiv.org/abs/1905.10876).

    Fidelities between pairs of states generated with random parameter sets are
    histogrammed and compared, via the Kullback-Leibler divergence, against the
    theoretical fidelity distribution of Haar-random states.
    """

    @staticmethod
    def _psd_sqrt(matrices: np.ndarray) -> np.ndarray:
        """
        Principal square root of a stack of Hermitian positive semi-definite
        matrices (e.g. density matrices).

        Uses the eigendecomposition ``M = V diag(w) V^H`` and returns
        ``V diag(sqrt(w)) V^H``. Eigenvalues that are slightly negative because of
        floating-point round-off are clipped to zero, so rank-deficient (e.g. pure)
        states are handled without producing NaNs.

        Args:
            matrices (np.ndarray): Array of shape (..., d, d) holding Hermitian
                positive semi-definite matrices.

        Returns:
            np.ndarray: Array of the same shape holding the matrix square roots.
        """
        eigvals, eigvecs = np.linalg.eigh(matrices)
        sqrt_eigvals = np.sqrt(np.clip(eigvals, 0.0, None))
        # Scaling the eigenvector columns by sqrt(w) is V @ diag(sqrt(w)), batched
        scaled = eigvecs * sqrt_eigvals[..., None, :]
        return scaled @ np.conj(np.swapaxes(eigvecs, -1, -2))

    @staticmethod
    def _uhlmann_fidelity(rho: np.ndarray, sigma: np.ndarray) -> np.ndarray:
        """
        Uhlmann fidelity F = (Tr sqrt(sqrt(rho) sigma sqrt(rho)))^2 for stacks of
        density matrices.

        Args:
            rho (np.ndarray): Array of shape (n, d, d) of density matrices.
            sigma (np.ndarray): Array of shape (n, d, d) of density matrices.

        Returns:
            np.ndarray: Array of shape (n,) with real fidelities in [0, 1].
        """
        sqrt_rho = Expressibility._psd_sqrt(rho)
        inner = sqrt_rho @ sigma @ sqrt_rho
        # `inner` is Hermitian PSD in exact arithmetic; remove round-off
        # anti-Hermitian parts so eigh sees a Hermitian matrix.
        inner = 0.5 * (inner + np.conj(np.swapaxes(inner, -1, -2)))
        eigvals, _ = np.linalg.eigh(inner)
        # Tr sqrt(A) equals the sum of the square roots of A's eigenvalues
        fidelity = np.sum(np.sqrt(np.clip(eigvals, 0.0, None)), axis=-1) ** 2
        # Clip round-off (e.g. 1 + 1e-16) so boundary values are not dropped by
        # the [0, 1] histogram built in state_fidelities.
        return np.clip(fidelity, 0.0, 1.0)

    @staticmethod
    def _sample_state_fidelities(
        model: "Model",
        x_samples: np.ndarray,
        n_samples: int,
        seed: int,
        kwargs: Any,
    ) -> np.ndarray:
        """
        Compute the fidelities for each pair of input samples and parameter sets.

        Args:
            model (Model): Model of the quantum circuit.
            x_samples (np.ndarray): Array of shape (n_input_samples,)
                containing the input samples.
            n_samples (int): Number of parameter-set pairs (fidelity samples)
                to generate.
            seed (int): Random number generator seed.
            kwargs (Any): Additional keyword arguments for the model function.

        Returns:
            np.ndarray: Array of shape (n_input_samples, n_samples)
            containing the fidelities, each in [0, 1].
        """
        rng = np.random.default_rng(seed)

        # Two parameter sets per fidelity sample: the first n_samples sets produce
        # rho, the remaining n_samples sets produce sigma.
        model.initialize_params(rng=rng, repeat=n_samples * 2)

        n_x_samples = len(x_samples)

        fidelities: np.ndarray = np.zeros((n_x_samples, n_samples))

        # Every row holds all input samples, so column idx feeds the same input
        # to all 2 * n_samples parameter sets in one batched call.
        x_samples_batched: np.ndarray = x_samples.reshape(1, -1).repeat(
            n_samples * 2, axis=0
        )

        for idx in range(n_x_samples):
            # Density matrices are requested explicitly; the fidelity formula
            # below holds for mixed as well as pure states.
            # Assumes the model returns shape (2 * n_samples, d, d) - TODO confirm.
            sv: np.ndarray = model(
                inputs=x_samples_batched[:, idx],
                params=model.params,
                execution_type="density",
                **kwargs,
            )
            fidelities[idx] = Expressibility._uhlmann_fidelity(
                sv[:n_samples], sv[n_samples:]
            )

        return fidelities

    @staticmethod
    def state_fidelities(
        seed: int,
        n_samples: int,
        n_bins: int,
        n_input_samples: int,
        input_domain: List[float],
        model: "Model",
        scale: bool = False,
        **kwargs: Any,
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Sample the state fidelities and histogram them.

        Args:
            seed (int): Random number generator seed.
            n_samples (int): Number of parameter sets to generate.
            n_bins (int): Number of histogram bins.
            n_input_samples (int): Number of input samples. None or 0 means a
                single dummy input of 0 is used.
            input_domain (List[float]): (start, stop) of the input domain, or
                None to use a single dummy input of 0.
            model (Model): Model of the quantum circuit.
            scale (bool): Whether to scale the number of samples (by
                2**n_qubits) and bins (by n_qubits).
            kwargs (Any): Additional keyword arguments for the model function.

        Returns:
            Tuple[np.ndarray, np.ndarray, np.ndarray]: Tuple containing the
            input samples, the n_bins + 1 bin edges over [0, 1], and the
            relative histogram frequencies of shape (n_input_samples, n_bins)
            (flattened to (n_bins,) for a single input sample).
        """
        if scale:
            n_samples = np.power(2, model.n_qubits) * n_samples
            n_bins = model.n_qubits * n_bins

        if input_domain is None or not n_input_samples:
            x = np.zeros((1))
            n_input_samples = 1
        else:
            x = np.linspace(*input_domain, n_input_samples, requires_grad=False)

        fidelities = Expressibility._sample_state_fidelities(
            x_samples=x,
            n_samples=n_samples,
            seed=seed,
            model=model,
            kwargs=kwargs,
        )

        # n_bins equally wide bins covering [0, 1] (right edge inclusive)
        y: np.ndarray = np.linspace(0, 1, n_bins + 1)

        z: np.ndarray = np.zeros((n_input_samples, n_bins))
        for i, f in enumerate(fidelities):
            z[i], _ = np.histogram(f, bins=y)

        # Convert counts to relative frequencies
        z = z / n_samples

        if z.shape[0] == 1:
            z = z.flatten()

        return x, y, z

    @staticmethod
    def _haar_probability(fidelity: float, n_qubits: int) -> float:
        """
        Calculates theoretical probability density function for random Haar states
        as proposed by Sim et al. (https://arxiv.org/abs/1905.10876):
        P(F) = (N - 1) * (1 - F)^(N - 2) with N = 2**n_qubits.

        Args:
            fidelity (float): fidelity of two parameter assignments in [0, 1]
            n_qubits (int): number of qubits in the quantum system

        Returns:
            float: probability density for a given fidelity
        """
        N = 2**n_qubits

        return (N - 1) * (1 - fidelity) ** (N - 2)

    @staticmethod
    def _sample_haar_integral(n_qubits: int, n_bins: int) -> np.ndarray:
        """
        Integrates the Haar fidelity probability density (Sim et al.,
        https://arxiv.org/abs/1905.10876) over n_bins equally wide bins on [0, 1].

        Args:
            n_qubits (int): number of qubits in the quantum system
            n_bins (int): number of histogram bins

        Returns:
            np.ndarray: 1-D array of shape (n_bins,) with the probability mass
            of each bin.
        """
        width = 1 / n_bins
        dist = np.zeros(n_bins)
        for idx in range(n_bins):
            lower = width * idx
            dist[idx], _ = integrate.quad(
                Expressibility._haar_probability,
                lower,
                lower + width,
                args=(n_qubits,),
            )

        return dist

    @staticmethod
    def haar_integral(
        n_qubits: int,
        n_bins: int,
        cache: bool = True,
        scale: bool = False,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Calculates theoretical probability density function for random Haar states
        as proposed by Sim et al. (https://arxiv.org/abs/1905.10876) and bins it
        into a 1-D histogram.

        Args:
            n_qubits (int): number of qubits in the quantum system
            n_bins (int): number of histogram bins
            cache (bool): whether to cache the haar integral as a .npy file in
                the ".cache" folder relative to the current working directory
            scale (bool): whether to scale the number of bins by n_qubits

        Returns:
            Tuple[np.ndarray, np.ndarray]:
                - x component (bins): n_bins evenly spaced points on [0, 1]
                - y component (probabilities): the binned Haar probability
                  density function for random Haar states
        """
        if scale:
            n_bins = n_qubits * n_bins

        x = np.linspace(0, 1, n_bins)

        if not cache:
            return x, Expressibility._sample_haar_integral(n_qubits, n_bins)

        name = f"haar_{n_qubits}q_{n_bins}s_{'scaled' if scale else ''}.npy"

        cache_folder = ".cache"
        # exist_ok avoids the race between an existence check and mkdir
        os.makedirs(cache_folder, exist_ok=True)

        file_path = os.path.join(cache_folder, name)

        if os.path.isfile(file_path):
            return x, np.load(file_path)

        y = Expressibility._sample_haar_integral(n_qubits, n_bins)
        np.save(file_path, y)

        return x, y

    @staticmethod
    def kullback_leibler_divergence(
        vqc_prob_dist: np.ndarray,
        haar_dist: np.ndarray,
    ) -> np.ndarray:
        """
        Calculates the KL divergence between two probability distributions (Haar
        probability distribution and the fidelity distribution sampled from a VQC).

        Args:
            vqc_prob_dist (np.ndarray): VQC fidelity probability distribution.
                Should have shape (n_inputs_samples, n_bins) or (n_bins,).
            haar_dist (np.ndarray): Haar probability distribution.
                Should have shape (n_bins,).

        Returns:
            np.ndarray: Array of shape (n_inputs_samples,) with the KL divergence
            of each input's distribution from the Haar distribution (shape (1,)
            for a 1-D input).

        Raises:
            AssertionError: If the per-input distributions do not have the same
                shape as ``haar_dist``.
        """
        vqc_prob_dist = np.asarray(vqc_prob_dist)
        haar_dist = np.asarray(haar_dist)

        if vqc_prob_dist.ndim <= 1:
            # A single distribution is treated as one input sample
            vqc_prob_dist = vqc_prob_dist.reshape((1, -1))

        if vqc_prob_dist.shape[1:] != haar_dist.shape:
            # Raised explicitly (not via `assert`) so the check survives
            # `python -O`; AssertionError is kept for backward compatibility.
            raise AssertionError(
                "All probabilities for inputs should have the same shape as Haar. "
                f"Got {haar_dist.shape} for Haar and {vqc_prob_dist.shape} for VQC"
            )

        # rel_entr broadcasts haar_dist over the input axis; sum over the bins
        return np.sum(rel_entr(vqc_prob_dist, haar_dist), axis=1)