Coverage for qml_essentials / gates.py: 86%

98 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-06-11 15:51 +0000

1from typing import List, Union, Callable 

2from contextlib import contextmanager 

3import numbers 

4import jax.numpy as jnp 

5import jax 

6 

7# Imports to keep the api `from gates import ...` 

8from qml_essentials.unitary import UnitaryGates 

9from qml_essentials.pulses import ( 

10 PulseGates, 

11 PulseParams, 

12 PulseEnvelope, # noqa: F401 

13 PulseInformation, 

14 PulseParamManager, 

15) 

16from qml_essentials.operations import Barrier as BarrierOp 

17 

18import logging 

19 

20log = logging.getLogger(__name__) 

21 

22 

23# Meta class to avoid instantiating the Gates class 

class GatesMeta(type):
    """Metaclass that turns unknown class attributes into gate handlers.

    Accessing ``Gates.<name>`` for a name that normal lookup does not find
    returns a function forwarding its arguments to
    ``cls._inner_getattr(<name>, *args, **kwargs)``. This lets gates be used
    as ``Gates.RX(...)`` without instantiating ``Gates``.
    """

    def __getattr__(cls, gate_name):
        # Only reached when normal attribute lookup fails. Private and dunder
        # names are never gates; resolving them to a handler would make probes
        # such as ``getattr(cls, "_pulse_mgr", None)`` or
        # ``hasattr(cls, "__wrapped__")`` silently succeed with a bogus function.
        if gate_name.startswith("_"):
            raise AttributeError(
                f"type object '{cls.__name__}' has no attribute '{gate_name}'"
            )

        def handler(*args, **kwargs):
            return cls._inner_getattr(gate_name, *args, **kwargs)

        # Preserve the gate name on the handler so callers (e.g.
        # ``Gates.is_rotational``) can identify the gate from the function.
        handler.__name__ = gate_name
        handler.__qualname__ = f"{cls.__name__}.{gate_name}"
        return handler

32 

33 

def Barrier(wires: Union[int, List[int]], *_extra_args, **_extra_kwargs):
    """Create a ``BarrierOp`` acting on ``wires``.

    Any further positional or keyword arguments (for instance the
    ``gate_mode`` keyword that ``Gates`` passes through unchanged) are
    accepted only for call compatibility with the other gates and are
    ignored.
    """
    return BarrierOp(wires)

37 

38 

class Gates(metaclass=GatesMeta):
    """
    Dynamic accessor for quantum Gates.

    Routes calls like `Gates.RX(...)` to either `UnitaryGates` or `PulseGates`
    depending on the `gate_mode` keyword (defaults to 'unitary').

    During circuit building, the pulse manager can be activated via
    `pulse_manager_context`, which slices the global model pulse parameters
    and passes them to each gate. Model pulse parameters act as element-wise
    scalers on the gate's optimized pulse parameters.

    Parameters
    ----------
    gate_mode : str, optional
        Determines the backend. 'unitary' for UnitaryGates, 'pulse' for
        PulseGates. Defaults to 'unitary'.
    pulse_params : list, tuple, jax array or PulseParams, optional
        Only forwarded in 'pulse' mode. When passed as a keyword, its element
        types are validated, and (while no pulse manager is active) its length
        is checked against the gate. Pulse parameters passed positionally are
        forwarded untouched and skip these checks.

    Examples
    --------
    >>> Gates.RX(w, wires)
    >>> Gates.RX(w, wires, gate_mode="unitary")
    >>> Gates.RX(w, wires, gate_mode="pulse")
    >>> Gates.RX(w, wires, pulse_params=pulse_params, gate_mode="pulse")
    """

    # PulseParamManager installed by `pulse_manager_context`, None otherwise.
    # Declared on the class so the lookup is a plain class attribute and never
    # falls through to the metaclass' dynamic `__getattr__`.
    _pulse_mgr = None

    # Keyword arguments forwarded to either backend; 'pulse' mode additionally
    # forwards "pulse_params".
    _GATE_KWARGS = frozenset(
        {"w", "wires", "phi", "theta", "omega", "noise_params", "random_key"}
    )

    # Gate names reported by `is_rotational` / `is_entangling`.
    _ROTATIONAL_GATES = frozenset(
        {"RX", "RY", "RZ", "Rot", "CRX", "CRY", "CRZ", "GolombEncoding", "CPhase"}
    )
    _ENTANGLING_GATES = frozenset({"CX", "CY", "CZ", "CRX", "CRY", "CRZ", "CPhase"})

    def __getattr__(self, gate_name):
        """Instance-level counterpart of the metaclass gate lookup.

        Returns a handler forwarding positional and keyword arguments to
        ``_inner_getattr``, so ``Gates().RX(w, wires)`` dispatches the same
        way as ``Gates.RX(w, wires)``.
        """
        # Private/dunder names are never gates (keeps copy/pickle/inspection
        # probes from receiving a gate handler).
        if gate_name.startswith("_"):
            raise AttributeError(
                f"'{type(self).__name__}' object has no attribute '{gate_name}'"
            )

        def handler(*args, **kwargs):
            return self._inner_getattr(gate_name, *args, **kwargs)

        handler.__name__ = gate_name
        return handler

    @classmethod
    def _inner_getattr(cls, gate_name, *args, **kwargs):
        """Dispatch ``gate_name`` to the backend selected by ``gate_mode``.

        Parameters
        ----------
        gate_name : str
            Name of the gate attribute looked up on the backend.
        *args
            Forwarded positionally to the backend gate.
        **kwargs
            ``gate_mode`` (default 'unitary') selects the backend. Keywords the
            selected mode does not accept are dropped and logged at DEBUG.

        Returns
        -------
        Whatever the backend gate (or ``Barrier``) returns.

        Raises
        ------
        ValueError
            Unknown ``gate_mode``, or ``pulse_params`` of the wrong length.
        TypeError
            ``pulse_params`` of an unsupported type or with non-real elements.
        AttributeError
            The backend has no attribute called ``gate_name``.
        """
        # Barrier is backend-independent; every argument (gate_mode included)
        # is handed to it unchanged.
        if gate_name == "Barrier":
            return Barrier(*args, **kwargs)

        gate_mode = kwargs.pop("gate_mode", "unitary")

        # Backend selection and kwargs filtering
        if gate_mode == "unitary":
            gate_backend = UnitaryGates
            allowed_args = cls._GATE_KWARGS
        elif gate_mode == "pulse":
            gate_backend = PulseGates
            allowed_args = cls._GATE_KWARGS | {"pulse_params"}
        else:
            raise ValueError(
                f"Unknown gate mode: {gate_mode}. Use 'unitary' or 'pulse'."
            )

        unsupported = kwargs.keys() - allowed_args
        if unsupported:
            # TODO: pulse_params ends up here in unitary mode whenever callers
            # pass it unconditionally — confirm whether that should be silent.
            log.debug("Unsupported keyword arguments: %s", list(unsupported))

        kwargs = {k: v for k, v in kwargs.items() if k in allowed_args}
        pulse_params = kwargs.get("pulse_params")
        pulse_mgr = cls._pulse_mgr
        mgr_active = isinstance(pulse_mgr, PulseParamManager)

        # pulse_params can only survive the filter above in 'pulse' mode.
        if pulse_params is not None:
            flat_params, kwargs["pulse_params"] = cls._flatten_pulse_params(
                pulse_params
            )
            # With an active manager the params are replaced below, so their
            # length only has to match when no manager is installed.
            if not mgr_active:
                n_params = PulseInformation.gate_by_name(gate_name).size
                if len(flat_params) != n_params:
                    raise ValueError(
                        f"Gate '{gate_name}' expects {n_params} pulse parameters, "
                        f"got {len(flat_params)}"
                    )

        # Pulse slicing + scaling: the manager hands out this gate's slice of
        # the model pulse parameters, used as element-wise scalers on the
        # gate's base pulse parameters.
        if gate_mode == "pulse" and mgr_active:
            gate_info = PulseInformation.gate_by_name(gate_name)
            scalers = pulse_mgr.get(gate_info.size)
            kwargs["pulse_params"] = gate_info.params * scalers

        # Call the selected gate backend
        gate = getattr(gate_backend, gate_name, None)
        if gate is None:
            # gate_backend may be a class (whose __class__ is only its
            # metaclass) or an instance; report whichever name identifies it.
            backend_name = getattr(
                gate_backend, "__name__", type(gate_backend).__name__
            )
            raise AttributeError(
                f"'{backend_name}' has no attribute '{gate_name}'"
            )

        return gate(*args, **kwargs)

    @staticmethod
    def _flatten_pulse_params(pulse_params):
        """Validate user ``pulse_params`` and return ``(flat_params, params)``.

        ``flat_params`` is a 1-D sequence used for element and length checks.
        ``params`` is the value to forward to the backend: the ``.params``
        array when a ``PulseParams`` object is given, otherwise
        ``pulse_params`` unchanged.

        Raises
        ------
        TypeError
            If the container type is unsupported or an element is neither a
            real number nor a JAX tracer.
        """
        params = pulse_params
        if isinstance(pulse_params, (list, tuple)):
            flat_params = pulse_params
        elif isinstance(pulse_params, jax.core.Tracer):
            # Keep values traced: `.tolist()` would require concrete values,
            # which are unavailable under tracing.
            flat_params = jnp.ravel(pulse_params)
        elif isinstance(pulse_params, jnp.ndarray):
            flat_params = pulse_params.flatten().tolist()
        elif isinstance(pulse_params, PulseParams):
            # Extract the raw params in case a full object is given.
            params = pulse_params.params
            flat_params = params.flatten().tolist()
        else:
            raise TypeError(f"Unsupported pulse_params type: {type(pulse_params)}")

        if not all(
            isinstance(x, (numbers.Real, jax.core.Tracer)) for x in flat_params
        ):
            raise TypeError(
                "All elements in pulse_params must be int or float, "
                f"got {pulse_params}, type {type(pulse_params)}. "
            )
        return flat_params, params

    @classmethod
    @contextmanager
    def pulse_manager_context(cls, pulse_params: jnp.ndarray):
        """Temporarily install a ``PulseParamManager`` for circuit building.

        While active, every pulse-mode gate call replaces its ``pulse_params``
        with the gate's base parameters scaled element-wise by the slice
        returned from the manager's ``get``. The previously installed manager
        (``None`` outside any context) is restored on exit, so contexts can be
        nested.

        The manager is stored on the class, so it is shared by every caller
        of ``Gates`` while the context is open.
        """
        previous_mgr = cls._pulse_mgr
        cls._pulse_mgr = PulseParamManager(pulse_params)
        try:
            yield
        finally:
            cls._pulse_mgr = previous_mgr

    @classmethod
    def parse_gates(
        cls,
        gates: Union[str, Callable, List[Union[str, Callable]], None],
        set_of_gates=None,
    ) -> List[Callable]:
        """Normalize a gate specification into a list of callables.

        Parameters
        ----------
        gates : str, callable, list/tuple of str or callable, or None
            Strings are resolved as attributes of ``set_of_gates``; callables
            are used as-is; ``None`` yields a single no-op callable.
        set_of_gates : optional
            Namespace used to resolve gate names; falls back to ``cls`` when
            falsy.

        Returns
        -------
        list of callable

        Raises
        ------
        ValueError
            If ``gates`` (or one of its elements) is neither a string nor a
            callable.
        """
        set_of_gates = set_of_gates or cls

        if isinstance(gates, str):
            # Resolve the name on the gate namespace (e.g. Gates.RX).
            return [getattr(set_of_gates, f"{gates}")]

        if isinstance(gates, (list, tuple)):
            parsed_gates = []
            for enc in gates:
                if isinstance(enc, str):
                    parsed_gates.append(getattr(set_of_gates, f"{enc}"))
                elif callable(enc):
                    parsed_gates.append(enc)
                else:
                    raise ValueError(
                        f"Operation {enc} is not a valid gate or callable. "
                        f"Got {type(enc)}"
                    )
            return parsed_gates

        if callable(gates):
            return [gates]

        if gates is None:
            return [lambda *args, **kwargs: None]

        raise ValueError(
            f"Operation {gates} is not a valid gate or callable or list of both."
        )

    @classmethod
    def is_rotational(cls, gate) -> bool:
        """Return True if ``gate.__name__`` is one of ``_ROTATIONAL_GATES``."""
        return gate.__name__ in cls._ROTATIONAL_GATES

    @classmethod
    def is_entangling(cls, gate) -> bool:
        """Return True if ``gate.__name__`` is one of ``_ENTANGLING_GATES``."""
        return gate.__name__ in cls._ENTANGLING_GATES