Coverage for qml_essentials / math.py: 93%

71 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-03-30 11:43 +0000

1import jax 

2import jax.numpy as jnp 

3from qml_essentials.operations import _cdtype 

4from scipy.linalg import logm 

5 

6 

def logm_v(A: jnp.ndarray, **kwargs) -> jnp.ndarray:
    """
    Compute the matrix logarithm of one matrix or of a batch of matrices.

    Each logarithm is computed with :func:`scipy.linalg.logm`, one matrix at
    a time.

    Args:
        A (jnp.ndarray): Either a single matrix of shape ``(d, d)`` or a
            batch of matrices of shape ``(B, d, d)``.
        **kwargs: Forwarded unchanged to :func:`scipy.linalg.logm`.

    Returns:
        jnp.ndarray: For a 2-D input, the result of ``scipy.linalg.logm``
            returned as is. For a 3-D input, an array of shape ``(B, d, d)``
            with dtype ``_cdtype()`` holding the logarithm of each matrix.

    Raises:
        NotImplementedError: If ``A`` is neither 2- nor 3-dimensional.
    """
    # TODO: check warnings emitted by scipy.linalg.logm
    if A.ndim == 2:
        return logm(A, **kwargs)

    if A.ndim == 3:
        dtype = _cdtype()
        if A.shape[0] == 0:
            # jnp.stack cannot build an array from an empty list; keep the
            # (0, d, d) shape that callers would otherwise get.
            return jnp.zeros(A.shape, dtype=dtype)
        # Stack all results at once. Updating a preallocated buffer with
        # `.at[i].set` outside of jit copies the whole (B, d, d) array on
        # every iteration.
        return jnp.stack(
            [jnp.asarray(logm(matrix, **kwargs), dtype=dtype) for matrix in A]
        )

    raise NotImplementedError("Unsupported shape of input matrix")

29 

30 

31def _sqrt_matrix(density_matrix: jnp.ndarray) -> jnp.ndarray: 

32 r"""Compute the matrix square root of a density matrix. 

33 

34 Uses eigendecomposition: if :math:`\rho = V \Lambda V^\dagger`, then 

35 :math:`\sqrt{\rho} = V \sqrt{\Lambda} V^\dagger`. 

36 

37 Negative eigenvalues (numerical noise) are clamped to zero. 

38 

39 Args: 

40 density_matrix: Density matrix of shape ``(d, d)`` or ``(B, d, d)``. 

41 

42 Returns: 

43 The matrix square root with the same shape as the input. 

44 """ 

45 evs, vecs = jnp.linalg.eigh(density_matrix) 

46 evs = jnp.real(evs) 

47 evs = jnp.where(evs > 0.0, evs, 0.0) 

48 

49 if density_matrix.ndim == 3: 

50 # batched: (B, d, d) 

51 sqrt_evs = jnp.sqrt(evs)[:, :, None] * jnp.eye( 

52 density_matrix.shape[-1], dtype=_cdtype() 

53 ) 

54 return vecs @ sqrt_evs @ jnp.conj(jnp.transpose(vecs, (0, 2, 1))) 

55 

56 # single: (d, d) 

57 return vecs @ jnp.diag(jnp.sqrt(evs)) @ jnp.conj(vecs.T) 

58 

59 

60def _fidelity_statevector( 

61 state0: jnp.ndarray, 

62 state1: jnp.ndarray, 

63) -> jnp.ndarray: 

64 r"""Fidelity between two pure states (state vectors). 

65 

66 .. math:: 

67 

68 F(\ket{\psi}, \ket{\phi}) = \left|\braket{\psi | \phi}\right|^2 

69 The inputs are normalised before the overlap is computed so that 

70 the result is always in :math:`[0, 1]`. 

71 """ 

72 # Normalise so that unnormalised inputs don't produce F > 1. 

73 norm0 = jnp.linalg.norm(state0, axis=-1, keepdims=True) 

74 norm1 = jnp.linalg.norm(state1, axis=-1, keepdims=True) 

75 state0 = state0 / jnp.where(norm0 > 0, norm0, 1.0) 

76 state1 = state1 / jnp.where(norm1 > 0, norm1, 1.0) 

77 

78 batched0 = state0.ndim > 1 

79 batched1 = state1.ndim > 1 

80 

81 idx0 = "ab" if batched0 else "b" 

82 idx1 = "ab" if batched1 else "b" 

83 target = "a" if (batched0 or batched1) else "" 

84 

85 overlap = jnp.einsum(f"{idx0},{idx1}->{target}", jnp.conj(state0), state1) 

86 return jnp.abs(overlap) ** 2 

87 

88 

def _fidelity_dm(
    state0: jnp.ndarray,
    state1: jnp.ndarray,
) -> jnp.ndarray:
    r"""Uhlmann fidelity between two mixed states given as density matrices.

    .. math::

        F(\rho, \sigma) = \left(\operatorname{Tr}
            \sqrt{\sqrt{\rho}\,\sigma\,\sqrt{\rho}}\right)^2

    The trace of the square root is computed as the sum of the square roots
    of the eigenvalues of :math:`\sqrt{\rho}\,\sigma\,\sqrt{\rho}`. Negative
    eigenvalues (numerical noise) are clamped to zero first.
    """
    root_rho = _sqrt_matrix(state0)
    sandwiched = root_rho @ state1 @ root_rho

    # `eigvalsh` assumes the sandwiched product is Hermitian.
    spectrum = jnp.real(jnp.linalg.eigvalsh(sandwiched))
    spectrum = jnp.where(spectrum > 0.0, spectrum, 0.0)

    trace_of_root = jnp.sum(jnp.sqrt(spectrum), axis=-1)
    return trace_of_root ** 2

102 

103 

def fidelity(
    state0: jnp.ndarray,
    state1: jnp.ndarray,
) -> jnp.ndarray:
    r"""Compute the fidelity between two quantum states.

    Both arguments must be of the same kind: either state vectors
    (``(d,)`` or ``(B, d)``) or density matrices (``(d, d)`` or
    ``(B, d, d)``). The kind is inferred from the array shapes.

    Note that a 2-D input whose two dimensions are equal is treated as a
    density matrix. A batch of ``d`` state vectors of dimension ``d`` is
    therefore indistinguishable from a single density matrix.

    Args:
        state0: State vector or density matrix.
        state1: State vector or density matrix (same kind as *state0*).

    Returns:
        Fidelity (scalar or shape ``(B,)``).

    Raises:
        ValueError: If the last dimensions of the two states differ, or if
            one state is a vector and the other a density matrix.
    """
    state0, state1 = (jnp.asarray(s, dtype=_cdtype()) for s in (state0, state1))

    if state1.shape[-1] != state0.shape[-1]:
        raise ValueError("The two states must have the same number of wires.")

    def _looks_like_vector(state: jnp.ndarray) -> bool:
        # 1-D arrays are single state vectors; non-square 2-D arrays are
        # batches of state vectors; everything else is a density matrix.
        if state.ndim == 1:
            return True
        return state.ndim == 2 and state.shape[-2] != state.shape[-1]

    vector0 = _looks_like_vector(state0)
    vector1 = _looks_like_vector(state1)
    if vector0 != vector1:
        raise ValueError(
            "Both states must be of the same kind "
            "(both state vectors or both density matrices)."
        )

    kernel = _fidelity_statevector if vector0 else _fidelity_dm
    return kernel(state0, state1)

145 

146 

def trace_distance(
    state0: jnp.ndarray,
    state1: jnp.ndarray,
) -> jnp.ndarray:
    r"""Compute the trace distance between two density matrices.

    .. math::

        T(\rho, \sigma) = \tfrac{1}{2} \operatorname{Tr}\left|\rho - \sigma\right|
                        = \tfrac{1}{2} \sum_i \left|\lambda_i\right|

    where :math:`\lambda_i` are the eigenvalues of the Hermitian difference
    :math:`\rho - \sigma`.

    Args:
        state0: Density matrix of shape ``(2**N, 2**N)`` or ``(B, 2**N, 2**N)``.
        state1: Density matrix of shape ``(2**N, 2**N)`` or ``(B, 2**N, 2**N)``.

    Returns:
        Trace distance (scalar or shape ``(B,)``).

    Raises:
        ValueError: If the last dimensions of the two states differ.
    """
    rho = jnp.asarray(state0, dtype=_cdtype())
    sigma = jnp.asarray(state1, dtype=_cdtype())

    if rho.shape[-1] != sigma.shape[-1]:
        raise ValueError("The two states must have the same number of wires.")

    spectrum = jnp.linalg.eigvalsh(rho - sigma)
    return jnp.abs(spectrum).sum(axis=-1) / 2

171 

172 

def phase_difference(
    state0: jnp.ndarray,
    state1: jnp.ndarray,
) -> jnp.ndarray:
    r"""Compute the phase difference between two state vectors.

    The result is the argument of the overlap:

    .. math::

        \Delta\varphi = \arg \braket{\psi | \phi} \in [-\pi, \pi]

    A value of zero means the overlap is real and non-negative, e.g. when the
    states differ only by a positive real factor. States that differ by a
    negative real factor give :math:`\pm\pi`. Orthogonal states (zero
    overlap) also yield ``0``, so zero alone does not imply equal phases.

    Supports single state vectors of shape ``(2**N,)`` and batched state
    vectors of shape ``(B, 2**N)``. Leading batch dimensions of the two
    inputs broadcast against each other.

    Args:
        state0: State vector of shape ``(2**N,)`` or ``(B, 2**N)``.
        state1: State vector of shape ``(2**N,)`` or ``(B, 2**N)``.

    Returns:
        Phase difference (scalar or shape ``(B,)``).

    Raises:
        ValueError: If the last dimensions of the two states differ.
    """
    state0 = jnp.asarray(state0, dtype=_cdtype())
    state1 = jnp.asarray(state1, dtype=_cdtype())

    if state0.shape[-1] != state1.shape[-1]:
        raise ValueError("The two states must have the same number of wires.")

    # Contract only the last (amplitude) axis; leading batch axes broadcast,
    # covering single/batched combinations and multiple batch dimensions.
    inner = jnp.sum(jnp.conj(state0) * state1, axis=-1)
    return jnp.angle(inner)