Coverage for qml_essentials / jaqsi.py: 100%

39 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-06-11 15:51 +0000

1"""Pulse/gate-independent entry point for building and simulating circuits. 

2 

3This module is the main interaction point for manually creating circuits. It 

4exposes the :class:`~qml_essentials.script.Script` circuit container, the 

5:func:`Hamiltonian` factory for time-evolution sources, and a few general 

6(pulse/gate-independent) quantum-info utilities. 

7 

8Time evolution is invoked as a method on the Hamiltonian object:: 

9 

10 H = Hamiltonian(matrix, wires=0) # static -> Hermitian 

11 H_t = coeff_fn * Hamiltonian(matrix, 0) # time-dep -> ParametrizedHamiltonian 

12 H_t.evolve(name="RX")([params], t) # gate factory 

13 

14The time-evolution engine itself lives in :mod:`qml_essentials.evolution` as 

15:class:`Evolution`, which is re-exported here for solver configuration 

16(``Evolution.set_solver_defaults`` / ``Evolution.clear_evolve_solver_cache``). 

17""" 

18 

19from functools import reduce 

20from typing import List, Tuple, Union 

21 

22import jax 

23import jax.numpy as jnp 

24 

25from qml_essentials.script import Script # noqa: F401 

26from qml_essentials.evolution import Evolution # noqa: F401 

27from qml_essentials.operations import ( # noqa: F401 

28 Hermitian, 

29 ParametrizedHamiltonian, 

30 PauliZ, 

31) 

32 

33 

def Hamiltonian(
    matrix: jnp.ndarray,
    wires: Union[int, List[int]] = 0,
    record: bool = False,
) -> Hermitian:
    """Build a static Hamiltonian, represented as a :class:`Hermitian` operator.

    No new operator type is introduced: the result is a plain
    :class:`Hermitian`. Multiplying it by a coefficient function
    ``f(params, t)`` yields a time-dependent :class:`ParametrizedHamiltonian`;
    both expose an :meth:`evolve` method that returns a gate factory.

    Args:
        matrix: Hermitian matrix that defines the Hamiltonian.
        wires: Qubit index, or list of qubit indices, the operator acts on.
        record: Whether the operator is recorded on the active tape.
            ``False`` by default, because a Hamiltonian used as an evolution
            source should not itself show up as a gate; the operation that
            gets recorded is the one returned by :meth:`evolve`.

    Returns:
        The newly constructed :class:`Hermitian` instance.
    """
    hamiltonian = Hermitian(matrix, wires=wires, record=record)
    return hamiltonian

58 

59 

def _partial_trace_single(
    rho: jnp.ndarray,
    n_qubits: int,
    keep: List[int],
) -> jnp.ndarray:
    """Partial trace of a single density matrix (no batch dimension).

    Args:
        rho: Density matrix of shape ``(2**n_qubits, 2**n_qubits)``.
        n_qubits: Total number of qubits.
        keep: Qubit indices to keep (0-indexed). Duplicates are ignored, and
            the kept qubits appear in ascending index order in the result
            regardless of the order given here.

    Returns:
        Reduced density matrix of shape ``(2**k, 2**k)`` where *k* is the
        number of distinct qubits kept.

    Raises:
        ValueError: If ``keep`` contains an index outside ``[0, n_qubits)``.
    """
    kept = set(keep)
    invalid = sorted(q for q in kept if not 0 <= q < n_qubits)
    if invalid:
        raise ValueError(
            f"keep contains qubit indices outside [0, {n_qubits}): {invalid}"
        )

    # One tensor axis per qubit (C-order reshape, so qubit 0 is the most
    # significant bit): axes [0, n) index rows, axes [n, 2n) index columns.
    rho_t = rho.reshape((2,) * (2 * n_qubits))

    # Trace out the highest-indexed qubits first so the axis positions of the
    # lower-indexed qubits still to be traced out do not shift.
    for q in sorted(set(range(n_qubits)) - kept, reverse=True):
        n_remaining = rho_t.ndim // 2
        rho_t = jnp.trace(rho_t, axis1=q, axis2=q + n_remaining)

    # Size the output from the axes actually left over rather than from
    # ``len(keep)``, which over-counts duplicate indices.
    dim = 2 ** (rho_t.ndim // 2)
    return rho_t.reshape(dim, dim)


def partial_trace(
    rho: jnp.ndarray,
    n_qubits: int,
    keep: List[int],
) -> jnp.ndarray:
    """Partial trace of a density matrix, keeping only the specified qubits.

    Supports both single density matrices of shape ``(2**n, 2**n)`` and
    batched density matrices of shape ``(B, 2**n, 2**n)``. More generally,
    any number of leading batch dimensions ``(*batch, 2**n, 2**n)`` is
    accepted and preserved in the output.

    Args:
        rho: Density matrix of shape ``(2**n, 2**n)`` or
            ``(*batch, 2**n, 2**n)``.
        n_qubits: Total number of qubits.
        keep: List of qubit indices to *keep* (0-indexed). Kept qubits appear
            in ascending index order in the result; duplicates are ignored.

    Returns:
        Reduced density matrix of shape ``(2**k, 2**k)`` or
        ``(*batch, 2**k, 2**k)`` where *k* is the number of distinct qubits
        in ``keep``.

    Raises:
        ValueError: If the trailing two dimensions of ``rho`` are not
            ``(2**n_qubits, 2**n_qubits)``, or if ``keep`` contains an index
            outside ``[0, n_qubits)``.
    """
    dim = 2**n_qubits
    if rho.ndim < 2 or tuple(rho.shape[-2:]) != (dim, dim):
        raise ValueError(
            f"Expected rho with trailing shape ({dim}, {dim}), "
            f"got {tuple(rho.shape)}"
        )

    def reduce_one(r):
        return _partial_trace_single(r, n_qubits, keep)

    # Add one vmap per leading batch axis. A plain (dim, dim) input is not
    # wrapped at all, and a (B, dim, dim) input gets exactly one vmap.
    fn = reduce_one
    for _ in range(rho.ndim - 2):
        fn = jax.vmap(fn)
    return fn(rho)

104 

105 

def _marginalize_probs_single(
    probs: jnp.ndarray,
    target_shape: Tuple[int, ...],
    trace_out: Tuple[int, ...],
) -> jnp.ndarray:
    """Marginalize a single probability vector (no batch dimension).

    Args:
        probs: Flat probability vector whose size equals the product of
            ``target_shape``.
        target_shape: Per-qubit tensor shape the vector is reshaped to.
        trace_out: Axes (qubit indices) to sum over, in any order.

    Returns:
        Flattened marginal over the remaining axes, kept in ascending order.
    """
    probs_t = probs.reshape(target_shape)
    # Reduce every axis in a single call. Summing one axis at a time would
    # shift the positions of later axes and make the result depend on the
    # order of ``trace_out``.
    if trace_out:
        probs_t = probs_t.sum(axis=tuple(trace_out))
    return probs_t.ravel()


def marginalize_probs(
    probs: jnp.ndarray,
    n_qubits: int,
    keep: Tuple[int, ...],
) -> jnp.ndarray:
    """Marginalize a probability vector to keep only the specified qubits.

    Supports both single probability vectors of shape ``(2**n,)`` and
    batched vectors of shape ``(B, 2**n)``. Qubit 0 is the most significant
    bit of the flat basis-state index.

    Args:
        probs: Probability vector of shape ``(2**n,)`` or ``(B, 2**n)``.
        n_qubits: Total number of qubits.
        keep: Qubit indices to *keep* (0-indexed). Kept qubits appear in
            ascending index order in the result.

    Returns:
        Marginalized probability vector of shape ``(2**k,)`` or ``(B, 2**k)``
        where *k* is the number of distinct qubits in ``keep``.
    """
    dim = 2**n_qubits
    trace_out = tuple(q for q in range(n_qubits - 1, -1, -1) if q not in keep)
    target_shape = (2,) * n_qubits

    marginals = jax.vmap(
        lambda p: _marginalize_probs_single(p, target_shape, trace_out)
    )(probs.reshape(-1, dim))

    # A single vector was lifted to a batch of one for vmap. Drop that axis
    # so the result is (2**k,) as documented, not (1, 2**k).
    if probs.shape == (dim,):
        return marginals[0]
    return marginals

147 

148 

def build_parity_observable(
    qubit_group: List[int],
) -> Hermitian:
    """Create the Z-parity observable over a group of qubits.

    Args:
        qubit_group: Indices of the qubits whose joint parity is measured.

    Returns:
        A :class:`Hermitian` (not recorded on the tape) whose matrix is the
        Kronecker product of one Pauli-Z matrix per qubit in the group and
        whose wires are ``qubit_group``.
    """
    n_wires = len(qubit_group)
    # Z ⊗ Z ⊗ ... ⊗ Z, one factor per wire.
    parity_matrix = reduce(jnp.kron, (PauliZ._matrix for _ in range(n_wires)))
    observable = Hermitian(matrix=parity_matrix, wires=qubit_group, record=False)
    # Attach the Pauli string directly so symbolic consumers (PauliWord /
    # FourierTree) can use it without an O(4^n) matrix decomposition.
    observable._pauli_label = "Z" * n_wires
    return observable