Coverage for qml_essentials / math.py: 93%
70 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-05-16 10:19 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-05-16 10:19 +0000
1import jax.numpy as jnp
2from qml_essentials.operations import _cdtype
3from scipy.linalg import logm
def logm_v(A: jnp.ndarray, **kwargs) -> jnp.ndarray:
    """
    Compute the logarithm of a matrix. If the provided matrix has an additional
    batch dimension, the logarithm of each matrix is computed.

    Args:
        A (jnp.ndarray): The (potentially batched) matrices of which to compute
            the logarithm, of shape ``(d, d)`` or ``(B, d, d)``.
        **kwargs: Forwarded unchanged to :func:`scipy.linalg.logm`.

    Returns:
        jnp.ndarray: The log matrices. A single ``(d, d)`` input is returned
        exactly as produced by :func:`scipy.linalg.logm` (not converted); a
        batched input yields an array of shape ``A.shape`` and dtype
        ``_cdtype()``.

    Raises:
        NotImplementedError: If ``A`` is neither 2- nor 3-dimensional.
    """
    # TODO: check warnings
    if A.ndim == 2:
        return logm(A, **kwargs)
    if A.ndim == 3:
        # scipy's logm runs eagerly on one matrix at a time. Collect all
        # results and convert once, instead of rebuilding the full output
        # array with ``.at[i].set`` on every iteration (O(B^2) copying).
        logs = [logm(mat, **kwargs) for mat in A]
        # reshape keeps an empty batch (B == 0) at shape (0, d, d).
        return jnp.asarray(logs, dtype=_cdtype()).reshape(A.shape)
    raise NotImplementedError("Unsupported shape of input matrix")
30def _sqrt_matrix(density_matrix: jnp.ndarray) -> jnp.ndarray:
31 r"""Compute the matrix square root of a density matrix.
33 Uses eigendecomposition: if :math:`\rho = V \Lambda V^\dagger`, then
34 :math:`\sqrt{\rho} = V \sqrt{\Lambda} V^\dagger`.
36 Negative eigenvalues (numerical noise) are clamped to zero.
38 Args:
39 density_matrix: Density matrix of shape ``(d, d)`` or ``(B, d, d)``.
41 Returns:
42 The matrix square root with the same shape as the input.
43 """
44 evs, vecs = jnp.linalg.eigh(density_matrix)
45 evs = jnp.real(evs)
46 evs = jnp.where(evs > 0.0, evs, 0.0)
48 if density_matrix.ndim == 3:
49 # batched: (B, d, d)
50 sqrt_evs = jnp.sqrt(evs)[:, :, None] * jnp.eye(
51 density_matrix.shape[-1], dtype=_cdtype()
52 )
53 return vecs @ sqrt_evs @ jnp.conj(jnp.transpose(vecs, (0, 2, 1)))
55 # single: (d, d)
56 return vecs @ jnp.diag(jnp.sqrt(evs)) @ jnp.conj(vecs.T)
59def _fidelity_statevector(
60 state0: jnp.ndarray,
61 state1: jnp.ndarray,
62) -> jnp.ndarray:
63 r"""Fidelity between two pure states (state vectors).
65 .. math::
67 F(\ket{\psi}, \ket{\phi}) = \left|\braket{\psi | \phi}\right|^2
68 The inputs are normalised before the overlap is computed so that
69 the result is always in :math:`[0, 1]`.
70 """
71 # Normalise so that unnormalised inputs don't produce F > 1.
72 norm0 = jnp.linalg.norm(state0, axis=-1, keepdims=True)
73 norm1 = jnp.linalg.norm(state1, axis=-1, keepdims=True)
74 state0 = state0 / jnp.where(norm0 > 0, norm0, 1.0)
75 state1 = state1 / jnp.where(norm1 > 0, norm1, 1.0)
77 batched0 = state0.ndim > 1
78 batched1 = state1.ndim > 1
80 idx0 = "ab" if batched0 else "b"
81 idx1 = "ab" if batched1 else "b"
82 target = "a" if (batched0 or batched1) else ""
84 overlap = jnp.einsum(f"{idx0},{idx1}->{target}", jnp.conj(state0), state1)
85 return jnp.abs(overlap) ** 2
def _fidelity_dm(
    state0: jnp.ndarray,
    state1: jnp.ndarray,
) -> jnp.ndarray:
    r"""Return the fidelity of two mixed states given as density matrices.

    .. math::

        F(\rho, \sigma) =
        \left(\operatorname{Tr}\sqrt{\sqrt{\rho}\,\sigma\,\sqrt{\rho}}\right)^2

    The trace of the square root is evaluated as the sum of the square roots
    of the eigenvalues of :math:`\sqrt{\rho}\,\sigma\,\sqrt{\rho}`; negative
    eigenvalues (numerical noise) are clamped to zero beforehand. The sum runs
    over the last axis, so batched inputs yield one value per batch entry.
    """
    root = _sqrt_matrix(state0)
    eigenvalues = jnp.real(jnp.linalg.eigvalsh(root @ state1 @ root))
    nonneg = jnp.where(eigenvalues > 0.0, eigenvalues, 0.0)
    trace_root = jnp.sqrt(nonneg).sum(axis=-1)
    return trace_root**2
def fidelity(
    state0: jnp.ndarray,
    state1: jnp.ndarray,
) -> jnp.ndarray:
    r"""Compute the fidelity between two quantum states.

    Accepts either state vectors or density matrices. State vectors have
    shape ``(d,)`` or ``(B, d)``; density matrices have shape ``(d, d)`` or
    ``(B, d, d)``. The batch dimensions of the two inputs must broadcast
    against each other.

    Note:
        The kind of a 2-D input is inferred from its shape: it is treated as a
        batch of state vectors only if it is *not* square. A ``(B, d)`` batch
        with ``B == d`` cannot be told apart from a single density matrix and
        is interpreted as one.

    Args:
        state0: State vector or density matrix.
        state1: State vector or density matrix (same kind as *state0*).

    Returns:
        Fidelity (scalar or shape ``(B,)``).

    Raises:
        ValueError: If the two states have incompatible shapes (0-d input,
            different last dimension, non-square density matrices or
            non-broadcastable batch dimensions) or different representations
            (vector vs. matrix).
    """
    state0 = jnp.asarray(state0, dtype=_cdtype())
    state1 = jnp.asarray(state1, dtype=_cdtype())

    if state0.ndim == 0 or state1.ndim == 0:
        raise ValueError("States must have at least one dimension.")

    if state0.shape[-1] != state1.shape[-1]:
        raise ValueError("The two states must have the same number of wires.")

    def _is_statevector(state: jnp.ndarray) -> bool:
        # 1-D, or 2-D and non-square (a batch of vectors). Everything else is
        # treated as a (possibly batched) density matrix.
        return state.ndim == 1 or (
            state.ndim == 2 and state.shape[-2] != state.shape[-1]
        )

    is_sv0 = _is_statevector(state0)
    is_sv1 = _is_statevector(state1)

    if is_sv0 != is_sv1:
        raise ValueError(
            "Both states must be of the same kind "
            "(both state vectors or both density matrices)."
        )

    if not is_sv0 and (
        state0.shape[-2] != state0.shape[-1]
        or state1.shape[-2] != state1.shape[-1]
    ):
        raise ValueError("Density matrices must be square.")

    # Trailing "core" axes: (d,) for vectors, (d, d) for matrices. Whatever
    # precedes them are batch dimensions, which must broadcast together.
    core = 1 if is_sv0 else 2
    try:
        jnp.broadcast_shapes(state0.shape[:-core], state1.shape[:-core])
    except ValueError as err:
        raise ValueError(
            "The batch dimensions of the two states are incompatible."
        ) from err

    if is_sv0:
        return _fidelity_statevector(state0, state1)
    return _fidelity_dm(state0, state1)
def trace_distance(
    state0: jnp.ndarray,
    state1: jnp.ndarray,
) -> jnp.ndarray:
    r"""Return the trace distance between two density matrices.

    .. math::

        T(\rho, \sigma) = \tfrac{1}{2}\operatorname{Tr}\left|\rho - \sigma\right|
                        = \tfrac{1}{2}\sum_i \left|\lambda_i(\rho - \sigma)\right|

    Works on single density matrices of shape ``(2**N, 2**N)`` as well as
    batches of shape ``(B, 2**N, 2**N)``.

    Args:
        state0: Density matrix of shape ``(2**N, 2**N)`` or ``(B, 2**N, 2**N)``.
        state1: Density matrix of shape ``(2**N, 2**N)`` or ``(B, 2**N, 2**N)``.

    Returns:
        Trace distance (scalar or shape ``(B,)``).

    Raises:
        ValueError: If the last dimensions of the two inputs differ.
    """
    rho = jnp.asarray(state0, dtype=_cdtype())
    sigma = jnp.asarray(state1, dtype=_cdtype())

    if rho.shape[-1] != sigma.shape[-1]:
        raise ValueError("The two states must have the same number of wires.")

    # rho - sigma is Hermitian, so its singular values are |eigenvalues|.
    spectrum = jnp.linalg.eigvalsh(rho - sigma)
    return jnp.abs(spectrum).sum(axis=-1) / 2
def phase_difference(
    state0: jnp.ndarray,
    state1: jnp.ndarray,
) -> jnp.ndarray:
    r"""Compute the phase difference between two state vectors.

    The result is the argument of the overlap
    :math:`\braket{\psi_0 | \psi_1}` as returned by :func:`jax.numpy.angle`,
    so it lies in :math:`[-\pi, \pi]`.

    A value of zero means the overlap is real and non-negative, e.g. when
    *state1* is a positive real multiple of *state0*. A negative real multiple
    gives :math:`\pm\pi`. Orthogonal states (zero overlap) may also yield
    zero, so a zero result alone does not prove that the states differ only
    by a global factor.

    Supports single state vectors of shape ``(2**N,)`` and batched state
    vectors of shape ``(B, 2**N)``.

    Args:
        state0: State vector of shape ``(2**N,)`` or ``(B, 2**N)``.
        state1: State vector of shape ``(2**N,)`` or ``(B, 2**N)``.

    Returns:
        Phase difference (scalar or shape ``(B,)``).

    Raises:
        ValueError: If the last dimensions of the two states differ.
    """
    state0 = jnp.asarray(state0, dtype=_cdtype())
    state1 = jnp.asarray(state1, dtype=_cdtype())

    if state0.shape[-1] != state1.shape[-1]:
        raise ValueError("The two states must have the same number of wires.")

    batched0 = state0.ndim > 1
    batched1 = state1.ndim > 1

    # "a" is the optional batch axis, "b" the amplitude axis.
    idx0 = "ab" if batched0 else "b"
    idx1 = "ab" if batched1 else "b"
    target = "a" if (batched0 or batched1) else ""

    inner = jnp.einsum(f"{idx0},{idx1}->{target}", jnp.conj(state0), state1)
    return jnp.angle(inner)