Coverage for qml_essentials / math.py: 93%

70 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-05-16 10:19 +0000

1import jax.numpy as jnp 

2from qml_essentials.operations import _cdtype 

3from scipy.linalg import logm 

4 

5 

def logm_v(A: jnp.ndarray, **kwargs) -> jnp.ndarray:
    """
    Compute the logarithm of a matrix. If the provided matrix has an additional
    batch dimension, the logarithm of each matrix is computed.

    The logarithm itself is evaluated by :func:`scipy.linalg.logm`, one matrix
    at a time. NOTE(review): since SciPy needs concrete arrays, this is
    presumably not usable inside ``jax.jit`` -- confirm before tracing it.

    Args:
        A (jnp.ndarray): The (potentially batched) matrices of which to compute
            the logarithm, of shape ``(d, d)`` or ``(B, d, d)``.
        **kwargs: Forwarded unchanged to :func:`scipy.linalg.logm`.

    Returns:
        jnp.ndarray: The log matrices. For a single ``(d, d)`` input the
        result of :func:`scipy.linalg.logm` is returned as-is; for a batched
        ``(B, d, d)`` input the per-matrix results are stacked into an array
        of dtype ``_cdtype()``.

    Raises:
        NotImplementedError: If ``A`` is neither 2- nor 3-dimensional.
    """
    # TODO: check warnings
    if A.ndim == 2:
        return logm(A, **kwargs)

    if A.ndim == 3:
        if A.shape[0] == 0:
            # Nothing to stack; return an empty batch with the usual dtype.
            return jnp.zeros(A.shape, dtype=_cdtype())
        # Compute every log first and stack once. Updating a preallocated
        # array with ``.at[i].set(...)`` outside of ``jit`` copies the whole
        # batch on every iteration, i.e. quadratic work in the batch size.
        logs = [
            jnp.asarray(logm(A[i], **kwargs), dtype=_cdtype())
            for i in range(A.shape[0])
        ]
        return jnp.stack(logs)

    raise NotImplementedError("Unsupported shape of input matrix")

28 

29 

def _sqrt_matrix(density_matrix: jnp.ndarray) -> jnp.ndarray:
    r"""Compute the matrix square root of a density matrix.

    Uses eigendecomposition: if :math:`\rho = V \Lambda V^\dagger`, then
    :math:`\sqrt{\rho} = V \sqrt{\Lambda} V^\dagger`.

    Negative eigenvalues (numerical noise) are clamped to zero.

    Args:
        density_matrix: Hermitian matrix of shape ``(d, d)``, or a batch of
            them with any number of leading batch dimensions, e.g.
            ``(B, d, d)``.

    Returns:
        The matrix square root with the same shape as the input.
    """
    evs, vecs = jnp.linalg.eigh(density_matrix)
    evs = jnp.real(evs)
    evs = jnp.where(evs > 0.0, evs, 0.0)

    # V @ diag(s) equals V with column j scaled by s[j]. Broadcasting the
    # scaling over the last axis avoids building diagonal matrices, and
    # swapping only the last two axes (instead of ``.T``, which reverses
    # *all* axes) keeps this correct for any number of batch dimensions.
    scaled_vecs = vecs * jnp.sqrt(evs)[..., None, :]
    return scaled_vecs @ jnp.conj(jnp.swapaxes(vecs, -1, -2))

57 

58 

def _fidelity_statevector(
    state0: jnp.ndarray,
    state1: jnp.ndarray,
) -> jnp.ndarray:
    r"""Fidelity between two pure states (state vectors).

    .. math::

        F(\ket{\psi}, \ket{\phi}) = \left|\braket{\psi | \phi}\right|^2

    The inputs are normalised before the overlap is computed so that
    the result is always in :math:`[0, 1]`. An all-zero vector is left
    unscaled and therefore yields a fidelity of zero.

    Args:
        state0: State vector of shape ``(d,)`` or batched ``(..., d)``.
        state1: State vector of shape ``(d,)`` or batched ``(..., d)``; its
            leading batch dimensions must broadcast against *state0*'s.

    Returns:
        Fidelity with the broadcast batch shape (a scalar when neither
        input is batched).
    """
    # Normalise so that unnormalised inputs don't produce F > 1; guard the
    # division so that a zero vector does not produce NaNs.
    norm0 = jnp.linalg.norm(state0, axis=-1, keepdims=True)
    norm1 = jnp.linalg.norm(state1, axis=-1, keepdims=True)
    state0 = state0 / jnp.where(norm0 > 0, norm0, 1.0)
    state1 = state1 / jnp.where(norm1 > 0, norm1, 1.0)

    # <psi|phi>: contract over the amplitude (last) axis. Broadcasting covers
    # single/single, single/batched, batched/batched and extra batch dims.
    overlap = jnp.sum(jnp.conj(state0) * state1, axis=-1)
    return jnp.abs(overlap) ** 2

86 

87 

def _fidelity_dm(
    state0: jnp.ndarray,
    state1: jnp.ndarray,
) -> jnp.ndarray:
    r"""Fidelity between two mixed states (density matrices).

    Evaluates

    .. math::

        F(\rho, \sigma) = \left(\operatorname{Tr}
        \sqrt{\sqrt{\rho}\,\sigma\,\sqrt{\rho}}\right)^2

    by summing the square roots of the eigenvalues of
    :math:`\sqrt{\rho}\,\sigma\,\sqrt{\rho}`. Negative eigenvalues
    (numerical noise) are clamped to zero before the square root.

    Args:
        state0: Density matrix :math:`\rho`, shape ``(d, d)`` or ``(B, d, d)``.
        state1: Density matrix :math:`\sigma` with a compatible shape.

    Returns:
        Fidelity (scalar or shape ``(B,)``).
    """
    root_rho = _sqrt_matrix(state0)
    sandwiched = root_rho @ state1 @ root_rho

    # For Hermitian PSD inputs the sandwiched matrix is Hermitian PSD, so
    # ``eigvalsh`` applies; tiny negative values are rounding noise.
    eigenvalues = jnp.real(jnp.linalg.eigvalsh(sandwiched))
    eigenvalues = jnp.where(eigenvalues > 0.0, eigenvalues, 0.0)

    trace_of_root = jnp.sum(jnp.sqrt(eigenvalues), axis=-1)
    return trace_of_root ** 2

101 

102 

def fidelity(
    state0: jnp.ndarray,
    state1: jnp.ndarray,
) -> jnp.ndarray:
    r"""Compute the fidelity between two quantum states.

    Accepts either state vectors or density matrices; both arguments must
    use the same representation. An input is treated as (batched) state
    vectors when it is 1-D, or 2-D with a non-square shape. Every other
    input is treated as (batched) density matrices. Consequently a 2-D
    square input of shape ``(d, d)`` is always read as a density matrix,
    even if it was meant as a batch of ``d`` state vectors.

    Args:
        state0: State vector or density matrix.
        state1: State vector or density matrix (same kind as *state0*).

    Returns:
        Fidelity (scalar or shape ``(B,)``).

    Raises:
        ValueError: If the two states have incompatible shapes or
            different representations (vector vs. matrix).
    """
    state0, state1 = (jnp.asarray(s, dtype=_cdtype()) for s in (state0, state1))

    if state0.shape[-1] != state1.shape[-1]:
        raise ValueError("The two states must have the same number of wires.")

    def _looks_like_vector(state):
        # 1-D, or 2-D but non-square -> (batched) state vector(s).
        if state.ndim == 1:
            return True
        return state.ndim == 2 and state.shape[-2] != state.shape[-1]

    vector0 = _looks_like_vector(state0)
    vector1 = _looks_like_vector(state1)

    if vector0 != vector1:
        raise ValueError(
            "Both states must be of the same kind "
            "(both state vectors or both density matrices)."
        )

    return (
        _fidelity_statevector(state0, state1)
        if vector0
        else _fidelity_dm(state0, state1)
    )

144 

145 

def trace_distance(
    state0: jnp.ndarray,
    state1: jnp.ndarray,
) -> jnp.ndarray:
    r"""Compute the trace distance between two quantum states.

    .. math::

        T(\rho, \sigma) = \tfrac{1}{2} \operatorname{Tr}\left|\rho - \sigma\right|
        = \tfrac{1}{2} \sum_i \left|\lambda_i(\rho - \sigma)\right|

    Supports single density matrices of shape ``(2**N, 2**N)`` and batched
    density matrices of shape ``(B, 2**N, 2**N)``.

    Args:
        state0: Density matrix of shape ``(2**N, 2**N)`` or ``(B, 2**N, 2**N)``.
        state1: Density matrix of shape ``(2**N, 2**N)`` or ``(B, 2**N, 2**N)``.

    Returns:
        Trace distance (scalar or shape ``(B,)``).

    Raises:
        ValueError: If the last dimensions of the two states differ.
    """
    rho = jnp.asarray(state0, dtype=_cdtype())
    sigma = jnp.asarray(state1, dtype=_cdtype())

    if rho.shape[-1] != sigma.shape[-1]:
        raise ValueError("The two states must have the same number of wires.")

    # The difference of two Hermitian matrices is Hermitian, so its
    # eigenvalues are real and their absolute values sum to Tr|rho - sigma|.
    abs_spectrum = jnp.abs(jnp.linalg.eigvalsh(rho - sigma))
    trace_norm = jnp.sum(abs_spectrum, axis=-1)
    return trace_norm / 2

170 

171 

def phase_difference(
    state0: jnp.ndarray,
    state1: jnp.ndarray,
) -> jnp.ndarray:
    r"""Compute the phase difference between two state vectors.

    The phase difference is the argument of the overlap,
    :math:`\arg \braket{\psi | \phi}`, and lies in :math:`[-\pi, \pi]`.
    A value of zero means the overlap is real and non-negative. In
    particular, it is zero when the states differ only by a *positive* real
    factor; a negative real factor gives :math:`\pm\pi`. Orthogonal states
    (zero overlap) also yield zero, so a zero result alone does not imply the
    states are related.

    Supports single state vectors of shape ``(2**N,)`` and batched state
    vectors of shape ``(B, 2**N)``.

    Args:
        state0: State vector of shape ``(2**N,)`` or ``(B, 2**N)``.
        state1: State vector of shape ``(2**N,)`` or ``(B, 2**N)``.

    Returns:
        Phase difference (scalar or shape ``(B,)``).

    Raises:
        ValueError: If the last dimensions of the two states differ.
    """
    state0 = jnp.asarray(state0, dtype=_cdtype())
    state1 = jnp.asarray(state1, dtype=_cdtype())

    if state0.shape[-1] != state1.shape[-1]:
        raise ValueError("The two states must have the same number of wires.")

    batched0 = state0.ndim > 1
    batched1 = state1.ndim > 1

    # Build einsum subscripts: "a" is the batch axis (if any), "b" the
    # amplitude axis that is contracted to form <psi|phi>.
    idx0 = "ab" if batched0 else "b"
    idx1 = "ab" if batched1 else "b"
    target = "a" if (batched0 or batched1) else ""

    inner = jnp.einsum(f"{idx0},{idx1}->{target}", jnp.conj(state0), state1)
    return jnp.angle(inner)