Coverage for qml_essentials / math.py: 93%
71 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-30 11:43 +0000
1import jax
2import jax.numpy as jnp
3from qml_essentials.operations import _cdtype
4from scipy.linalg import logm
def logm_v(A: jnp.ndarray, **kwargs) -> jnp.ndarray:
    """
    Compute the logarithm of a matrix. If the provided matrix has an additional
    batch dimension, the logarithm of each matrix is computed.

    Args:
        A (jnp.ndarray): The (potentially batched) matrices of which to compute
            the logarithm, of shape ``(d, d)`` or ``(B, d, d)``.
        **kwargs: Forwarded unchanged to ``scipy.linalg.logm`` for every
            matrix.

    Returns:
        jnp.ndarray: The log matrices. For a single ``(d, d)`` input the
        result of ``scipy.linalg.logm`` is returned as-is; for a batched
        ``(B, d, d)`` input a JAX array of dtype ``_cdtype()`` is returned.

    Raises:
        NotImplementedError: If ``A`` is neither 2-D nor 3-D.
    """
    # TODO: check warnings
    if A.ndim == 2:
        return logm(A, **kwargs)
    if A.ndim == 3:
        if A.shape[0] == 0:
            # Nothing to stack; keep the empty (0, d, d) complex result.
            return jnp.zeros(A.shape, dtype=_cdtype())
        # Collect all results and stack once. A Python loop of
        # ``AV = AV.at[i].set(...)`` outside jit copies the whole (B, d, d)
        # buffer on every iteration, i.e. quadratic work in B.
        # NOTE(review): kwargs that make logm return a tuple (e.g. disp=False)
        # are not supported on the batched path.
        return jnp.stack(
            [jnp.asarray(logm(mat, **kwargs), dtype=_cdtype()) for mat in A]
        )
    raise NotImplementedError("Unsupported shape of input matrix")
def _sqrt_matrix(density_matrix: jnp.ndarray) -> jnp.ndarray:
    r"""Compute the matrix square root of a density matrix.

    Uses eigendecomposition: if :math:`\rho = V \Lambda V^\dagger`, then
    :math:`\sqrt{\rho} = V \sqrt{\Lambda} V^\dagger`.

    Negative eigenvalues (numerical noise) are clamped to zero.

    Args:
        density_matrix: Density matrix of shape ``(d, d)``, ``(B, d, d)``,
            or more generally ``(..., d, d)`` with any leading batch axes.

    Returns:
        The matrix square root with the same shape as the input.
    """
    evs, vecs = jnp.linalg.eigh(density_matrix)
    evs = jnp.real(evs)
    # Clamp tiny negative eigenvalues (numerical noise) before the sqrt.
    evs = jnp.where(evs > 0.0, evs, 0.0)
    sqrt_evs = jnp.sqrt(evs)

    # V @ diag(s) scales column j of V by s_j; broadcasting s over the last
    # axis does this for every leading batch axis without building diag(s).
    scaled_vecs = vecs * sqrt_evs[..., None, :]
    return scaled_vecs @ jnp.conj(jnp.swapaxes(vecs, -1, -2))
def _fidelity_statevector(
    state0: jnp.ndarray,
    state1: jnp.ndarray,
) -> jnp.ndarray:
    r"""Fidelity between two pure states (state vectors).

    .. math::

        F(\ket{\psi}, \ket{\phi}) = \left|\braket{\psi | \phi}\right|^2

    Each input is rescaled to unit norm before the overlap is taken (an
    all-zero vector is left as zeros), so unnormalised inputs cannot push
    the result above 1.

    Args:
        state0: State vector of shape ``(d,)`` or ``(B, d)``.
        state1: State vector of shape ``(d,)`` or ``(B, d)``.

    Returns:
        The fidelity: a scalar when neither input is batched, otherwise an
        array of shape ``(B,)``.
    """

    def _unit(vec: jnp.ndarray) -> jnp.ndarray:
        # Divide by the norm along the last axis; substitute 1 for a zero
        # norm so the zero vector does not become NaN.
        length = jnp.linalg.norm(vec, axis=-1, keepdims=True)
        return vec / jnp.where(length > 0, length, 1.0)

    bra = jnp.conj(_unit(state0))
    ket = _unit(state1)

    # Build the einsum subscripts: "a" labels the (optional) batch axis,
    # "b" the amplitude axis that is contracted away.
    lhs = "ab" if bra.ndim > 1 else "b"
    rhs = "ab" if ket.ndim > 1 else "b"
    out = "a" if "a" in lhs + rhs else ""

    amplitude = jnp.einsum(f"{lhs},{rhs}->{out}", bra, ket)
    return jnp.abs(amplitude) ** 2
def _fidelity_dm(
    state0: jnp.ndarray,
    state1: jnp.ndarray,
) -> jnp.ndarray:
    r"""Fidelity between two mixed states (density matrices).

    .. math::

        F(\rho, \sigma) =
        \left(\operatorname{Tr}\sqrt{\sqrt{\rho}\,\sigma\,\sqrt{\rho}}\right)^2

    The trace of the square root is evaluated as the sum of the square
    roots of the eigenvalues of :math:`\sqrt{\rho}\,\sigma\,\sqrt{\rho}`.
    Negative eigenvalues (numerical noise) are clamped to zero first.

    Args:
        state0: Density matrix :math:`\rho`, shape ``(d, d)`` or ``(B, d, d)``.
        state1: Density matrix :math:`\sigma`, shape ``(d, d)`` or
            ``(B, d, d)``.

    Returns:
        The fidelity: a scalar for unbatched input, otherwise shape ``(B,)``.
    """
    root_rho = _sqrt_matrix(state0)
    sandwiched = root_rho @ state1 @ root_rho

    spectrum = jnp.real(jnp.linalg.eigvalsh(sandwiched))
    spectrum = jnp.where(spectrum > 0.0, spectrum, 0.0)

    trace_of_root = jnp.sum(jnp.sqrt(spectrum), axis=-1)
    return trace_of_root ** 2
def fidelity(
    state0: jnp.ndarray,
    state1: jnp.ndarray,
) -> jnp.ndarray:
    r"""Compute the fidelity between two quantum states.

    Both arguments must be of the same kind: either state vectors (pure
    states) or density matrices (mixed states). The kind is inferred from
    the shape. A 1-D array, or a 2-D array whose last two axes differ in
    length, is treated as a state vector; anything else is treated as a
    density matrix.

    NOTE(review): a batch of ``B`` state vectors of dimension ``d`` with
    ``B == d`` has a square shape and is therefore classified as a density
    matrix.

    Args:
        state0: State vector or density matrix.
        state1: State vector or density matrix (same kind as *state0*).

    Returns:
        Fidelity (scalar or shape ``(B,)``).

    Raises:
        ValueError: If the two states have incompatible shapes or
            different representations (vector vs. matrix).
    """
    first = jnp.asarray(state0, dtype=_cdtype())
    second = jnp.asarray(state1, dtype=_cdtype())

    if first.shape[-1] != second.shape[-1]:
        raise ValueError("The two states must have the same number of wires.")

    def _looks_like_vector(arr: jnp.ndarray) -> bool:
        # 1-D: single state vector. 2-D non-square: batch of state vectors.
        if arr.ndim == 1:
            return True
        return arr.ndim == 2 and arr.shape[-2] != arr.shape[-1]

    first_is_vector = _looks_like_vector(first)
    if first_is_vector != _looks_like_vector(second):
        raise ValueError(
            "Both states must be of the same kind "
            "(both state vectors or both density matrices)."
        )

    kernel = _fidelity_statevector if first_is_vector else _fidelity_dm
    return kernel(first, second)
def trace_distance(
    state0: jnp.ndarray,
    state1: jnp.ndarray,
) -> jnp.ndarray:
    r"""Compute the trace distance between two quantum states.

    .. math::

        T(\rho, \sigma) = \tfrac{1}{2}\operatorname{Tr}\left|\rho - \sigma\right|

    This is evaluated as half the sum of the absolute eigenvalues of
    :math:`\rho - \sigma`, computed with a Hermitian eigensolver.

    Supports single density matrices of shape ``(2**N, 2**N)`` and batched
    density matrices of shape ``(B, 2**N, 2**N)``.

    NOTE(review): the inputs are not checked to be square or Hermitian; the
    eigensolver assumes a Hermitian difference.

    Args:
        state0: Density matrix of shape ``(2**N, 2**N)`` or ``(B, 2**N, 2**N)``.
        state1: Density matrix of shape ``(2**N, 2**N)`` or ``(B, 2**N, 2**N)``.

    Returns:
        Trace distance (scalar or shape ``(B,)``).

    Raises:
        ValueError: If the last dimensions of the two states differ.
    """
    rho = jnp.asarray(state0, dtype=_cdtype())
    sigma = jnp.asarray(state1, dtype=_cdtype())

    if rho.shape[-1] != sigma.shape[-1]:
        raise ValueError("The two states must have the same number of wires.")

    difference_spectrum = jnp.linalg.eigvalsh(rho - sigma)
    total = jnp.sum(jnp.abs(difference_spectrum), axis=-1)
    return total / 2
def phase_difference(
    state0: jnp.ndarray,
    state1: jnp.ndarray,
) -> jnp.ndarray:
    r"""Compute the phase difference between two state vectors.

    Returns :math:`\arg\braket{\psi | \phi}`, the complex argument of the
    inner product of ``state0`` (conjugated, the bra) with ``state1`` (the
    ket). If :math:`\ket{\phi} = e^{i\theta}\ket{\psi}`, this is the global
    phase :math:`\theta`, wrapped into :math:`[-\pi, \pi]`.

    A value of zero means the overlap is real and non-negative. It does not
    by itself imply that the states differ only by a real factor; in
    particular, orthogonal states (zero overlap) also yield zero.

    Supports single state vectors of shape ``(2**N,)`` and batched state
    vectors of shape ``(B, 2**N)``.

    Args:
        state0: State vector of shape ``(2**N,)`` or ``(B, 2**N)``.
        state1: State vector of shape ``(2**N,)`` or ``(B, 2**N)``.

    Returns:
        Phase difference in :math:`[-\pi, \pi]` (scalar or shape ``(B,)``).

    Raises:
        ValueError: If the last dimensions of the two states differ.
    """
    state0 = jnp.asarray(state0, dtype=_cdtype())
    state1 = jnp.asarray(state1, dtype=_cdtype())

    if state0.shape[-1] != state1.shape[-1]:
        raise ValueError("The two states must have the same number of wires.")

    batched0 = state0.ndim > 1
    batched1 = state1.ndim > 1

    # "a" labels the optional batch axis, "b" the contracted amplitude axis.
    idx0 = "ab" if batched0 else "b"
    idx1 = "ab" if batched1 else "b"
    target = "a" if (batched0 or batched1) else ""

    inner = jnp.einsum(f"{idx0},{idx1}->{target}", jnp.conj(state0), state1)
    # jnp.angle is atan2(imag, real), so the result lies in [-pi, pi].
    return jnp.angle(inner)