Coverage for qml_essentials/expressibility.py: 97%

78 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-04-15 15:48 +0000

1import pennylane.numpy as np 

2from typing import Tuple, List, Any 

3from scipy import integrate 

4from scipy.linalg import sqrtm 

5from scipy.special import rel_entr 

6from qml_essentials.model import Model 

7import os 

8 

9 

10class Expressibility: 

11 @staticmethod 

12 def _sample_state_fidelities( 

13 model: Model, 

14 x_samples: np.ndarray, 

15 n_samples: int, 

16 seed: int, 

17 kwargs: Any, 

18 ) -> np.ndarray: 

19 """ 

20 Compute the fidelities for each pair of input samples and parameter sets. 

21 

22 Args: 

23 model (Callable): Function that models the quantum circuit. 

24 x_samples (np.ndarray): Array of shape (n_input_samples, n_features) 

25 containing the input samples. 

26 n_samples (int): Number of parameter sets to generate. 

27 seed (int): Random number generator seed. 

28 kwargs (Any): Additional keyword arguments for the model function. 

29 

30 Returns: 

31 np.ndarray: Array of shape (n_input_samples, n_samples) 

32 containing the fidelities. 

33 """ 

34 rng = np.random.default_rng(seed) 

35 

36 # Generate random parameter sets 

37 # We need two sets of parameters, as we are computing fidelities for a 

38 # pair of random state vectors 

39 model.initialize_params(rng=rng, repeat=n_samples * 2) 

40 

41 # Initialize array to store fidelities 

42 fidelities: np.ndarray = np.zeros((len(x_samples), n_samples)) 

43 

44 # Compute the fidelity for each pair of input samples and parameters 

45 for idx, x_sample in enumerate(x_samples): 

46 

47 # Evaluate the model for the current pair of input samples and parameters 

48 # Execution type is explicitly set to density 

49 sv: np.ndarray = model( 

50 inputs=x_sample, 

51 params=model.params, 

52 execution_type="density", 

53 **kwargs, 

54 ) 

55 

56 # $\sqrt{\rho}$ 

57 sqrt_sv1: np.ndarray = np.array([sqrtm(m) for m in sv[:n_samples]]) 

58 

59 # $\sqrt{\rho} \sigma \sqrt{\rho}$ 

60 inner_fidelity = sqrt_sv1 @ sv[n_samples:] @ sqrt_sv1 

61 

62 # Compute the fidelity using the partial trace of the statevector 

63 fidelity: np.ndarray = ( 

64 np.trace( 

65 np.array([sqrtm(m) for m in inner_fidelity]), 

66 axis1=1, 

67 axis2=2, 

68 ) 

69 ** 2 

70 ) 

71 

72 fidelities[idx] = np.abs(fidelity) 

73 

74 return fidelities 

75 

76 @staticmethod 

77 def state_fidelities( 

78 seed: int, 

79 n_samples: int, 

80 n_bins: int, 

81 model: Model, 

82 n_input_samples: int = 0, 

83 input_domain: List[float] = None, 

84 scale: bool = False, 

85 **kwargs: Any, 

86 ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: 

87 """ 

88 Sample the state fidelities and histogram them into a 2D array. 

89 

90 Args: 

91 seed (int): Random number generator seed. 

92 n_samples (int): Number of parameter sets to generate. 

93 n_bins (int): Number of histogram bins. 

94 n_input_samples (int): Number of input samples. 

95 input_domain (List[float]): Input domain. 

96 model (Callable): Function that models the quantum circuit. 

97 scale (bool): Whether to scale the number of samples and bins. 

98 kwargs (Any): Additional keyword arguments for the model function. 

99 

100 Returns: 

101 Tuple[np.ndarray, np.ndarray, np.ndarray]: Tuple containing the 

102 input samples, bin edges, and histogram values. 

103 """ 

104 if scale: 

105 n_samples = np.power(2, model.n_qubits) * n_samples 

106 n_bins = model.n_qubits * n_bins 

107 

108 if input_domain is None or n_input_samples is None or n_input_samples == 0: 

109 x = np.zeros((1)) 

110 n_input_samples = 1 

111 else: 

112 x = np.linspace(*input_domain, n_input_samples, requires_grad=False) 

113 

114 fidelities = Expressibility._sample_state_fidelities( 

115 x_samples=x, 

116 n_samples=n_samples, 

117 seed=seed, 

118 model=model, 

119 kwargs=kwargs, 

120 ) 

121 z: np.ndarray = np.zeros((n_input_samples, n_bins)) 

122 

123 y: np.ndarray = np.linspace(0, 1, n_bins + 1) 

124 

125 for i, f in enumerate(fidelities): 

126 z[i], _ = np.histogram(f, bins=y) 

127 

128 z = z / n_samples 

129 

130 if z.shape[0] == 1: 

131 z = z.flatten() 

132 

133 return x, y, z 

134 

135 @staticmethod 

136 def _haar_probability(fidelity: float, n_qubits: int) -> float: 

137 """ 

138 Calculates theoretical probability density function for random Haar states 

139 as proposed by Sim et al. (https://arxiv.org/abs/1905.10876). 

140 

141 Args: 

142 fidelity (float): fidelity of two parameter assignments in [0, 1] 

143 n_qubits (int): number of qubits in the quantum system 

144 

145 Returns: 

146 float: probability for a given fidelity 

147 """ 

148 N = 2**n_qubits 

149 

150 prob = (N - 1) * (1 - fidelity) ** (N - 2) 

151 return prob 

152 

153 @staticmethod 

154 def _sample_haar_integral(n_qubits: int, n_bins: int) -> np.ndarray: 

155 """ 

156 Calculates theoretical probability density function for random Haar states 

157 as proposed by Sim et al. (https://arxiv.org/abs/1905.10876) and bins it 

158 into a 2D-histogram. 

159 

160 Args: 

161 n_qubits (int): number of qubits in the quantum system 

162 n_bins (int): number of histogram bins 

163 

164 Returns: 

165 np.ndarray: probability distribution for all fidelities 

166 """ 

167 dist = np.zeros(n_bins) 

168 for idx in range(n_bins): 

169 v = idx / n_bins 

170 u = (idx + 1) / n_bins 

171 dist[idx], _ = integrate.quad( 

172 Expressibility._haar_probability, v, u, args=(n_qubits,) 

173 ) 

174 

175 return dist 

176 

177 @staticmethod 

178 def haar_integral( 

179 n_qubits: int, 

180 n_bins: int, 

181 cache: bool = True, 

182 scale: bool = False, 

183 ) -> Tuple[np.ndarray, np.ndarray]: 

184 """ 

185 Calculates theoretical probability density function for random Haar states 

186 as proposed by Sim et al. (https://arxiv.org/abs/1905.10876) and bins it 

187 into a 3D-histogram. 

188 

189 Args: 

190 n_qubits (int): number of qubits in the quantum system 

191 n_bins (int): number of histogram bins 

192 cache (bool): whether to cache the haar integral 

193 scale (bool): whether to scale the number of bins 

194 

195 Returns: 

196 Tuple[np.ndarray, np.ndarray]: 

197 - x component (bins): the input domain 

198 - y component (probabilities): the haar probability density 

199 funtion for random Haar states 

200 """ 

201 if scale: 

202 n_bins = n_qubits * n_bins 

203 

204 x = np.linspace(0, 1, n_bins) 

205 

206 if cache: 

207 name = f"haar_{n_qubits}q_{n_bins}s_{'scaled' if scale else ''}.npy" 

208 

209 cache_folder = ".cache" 

210 if not os.path.exists(cache_folder): 

211 os.mkdir(cache_folder) 

212 

213 file_path = os.path.join(cache_folder, name) 

214 

215 if os.path.isfile(file_path): 

216 y = np.load(file_path) 

217 return x, y 

218 

219 y = Expressibility._sample_haar_integral(n_qubits, n_bins) 

220 

221 if cache: 

222 np.save(file_path, y) 

223 

224 return x, y 

225 

226 @staticmethod 

227 def kullback_leibler_divergence( 

228 vqc_prob_dist: np.ndarray, 

229 haar_dist: np.ndarray, 

230 ) -> np.ndarray: 

231 """ 

232 Calculates the KL divergence between two probability distributions (Haar 

233 probability distribution and the fidelity distribution sampled from a VQC). 

234 

235 Args: 

236 vqc_prob_dist (np.ndarray): VQC fidelity probability distribution. 

237 Should have shape (n_inputs_samples, n_bins) 

238 haar_dist (np.ndarray): Haar probability distribution with shape. 

239 Should have shape (n_bins, ) 

240 

241 Returns: 

242 np.ndarray: Array of KL-Divergence values for all values in axis 1 

243 """ 

244 if len(vqc_prob_dist.shape) > 1: 

245 assert all([haar_dist.shape == p.shape for p in vqc_prob_dist]), ( 

246 "All probabilities for inputs should have the same shape as Haar. " 

247 f"Got {haar_dist.shape} for Haar and {vqc_prob_dist.shape} for VQC" 

248 ) 

249 else: 

250 vqc_prob_dist = vqc_prob_dist.reshape((1, -1)) 

251 

252 kl_divergence = np.zeros(vqc_prob_dist.shape[0]) 

253 for idx, p in enumerate(vqc_prob_dist): 

254 kl_divergence[idx] = np.sum(rel_entr(p, haar_dist)) 

255 

256 return kl_divergence