Coverage for qml_essentials / gates.py: 91%

98 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-03-30 11:43 +0000

1from typing import List, Union, Callable 

2from contextlib import contextmanager 

3import numbers 

4import jax.numpy as jnp 

5import jax 

6 

7# Imports to keep the api `from gates import ...` 

8from qml_essentials.unitary import UnitaryGates 

9from qml_essentials.pulses import ( 

10 PulseGates, 

11 PulseParams, 

12 PulseEnvelope, 

13 PulseInformation, 

14 PulseParamManager, 

15) 

16from qml_essentials.operations import Barrier as BarrierOp 

17 

18import logging 

19 

20log = logging.getLogger(__name__) 

21 

22 

23# Meta class to avoid instantiating the Gates class 

24class GatesMeta(type): 

25 def __getattr__(cls, gate_name): 

26 def handler(*args, **kwargs): 

27 return Gates._inner_getattr(gate_name, *args, **kwargs) 

28 

29 # Dirty way to preserve information about the gate name 

30 handler.__name__ = gate_name 

31 return handler 

32 

33 

34def Barrier(wires: Union[int, List[int]], *args, **kwargs): 

35 """Thin wrapper for BarrierOp""" 

36 return BarrierOp(wires) 

37 

38 

39class Gates(metaclass=GatesMeta): 

40 """ 

41 Dynamic accessor for quantum Gates. 

42 

43 Routes calls like `Gates.RX(...)` to either `UnitaryGates` or `PulseGates` 

44 depending on the `gate_mode` keyword (defaults to 'unitary'). 

45 

46 During circuit building, the pulse manager can be activated via 

47 `pulse_manager_context`, which slices the global model pulse parameters 

48 and passes them to each gate. Model pulse parameters act as element-wise 

49 scalers on the gate's optimized pulse parameters. 

50 

51 Parameters 

52 ---------- 

53 gate_mode : str, optional 

54 Determines the backend. 'unitary' for UnitaryGates, 'pulse' for PulseGates. 

55 Defaults to 'unitary'. 

56 

57 Examples 

58 -------- 

59 >>> Gates.RX(w, wires) 

60 >>> Gates.RX(w, wires, gate_mode="unitary") 

61 >>> Gates.RX(w, wires, gate_mode="pulse") 

62 >>> Gates.RX(w, wires, pulse_params, gate_mode="pulse") 

63 """ 

64 

65 def __getattr__(self, gate_name): 

66 def handler(**kwargs): 

67 return self._inner_getattr(gate_name, **kwargs) 

68 

69 return handler 

70 

71 @staticmethod 

72 def _inner_getattr(gate_name, *args, **kwargs): 

73 if gate_name == "Barrier": 

74 return Barrier(*args, **kwargs) 

75 

76 gate_mode = kwargs.pop("gate_mode", "unitary") 

77 

78 # Backend selection and kwargs filtering 

79 allowed_args = [ 

80 "w", 

81 "wires", 

82 "phi", 

83 "theta", 

84 "omega", 

85 "input_idx", 

86 "noise_params", 

87 "random_key", 

88 ] 

89 if gate_mode == "unitary": 

90 gate_backend = UnitaryGates 

91 elif gate_mode == "pulse": 

92 gate_backend = PulseGates 

93 allowed_args += ["pulse_params"] 

94 else: 

95 raise ValueError( 

96 f"Unknown gate mode: {gate_mode}. Use 'unitary' or 'pulse'." 

97 ) 

98 

99 if len(kwargs.keys() - allowed_args) > 0: 

100 # TODO: pulse params are always provided? 

101 log.debug( 

102 f"Unsupported keyword arguments: {list(kwargs.keys() - allowed_args)}" 

103 ) 

104 

105 kwargs = {k: v for k, v in kwargs.items() if k in allowed_args} 

106 pulse_params = kwargs.get("pulse_params") 

107 pulse_mgr = getattr(Gates, "_pulse_mgr", None) 

108 

109 # TODO: rework this part to convert to valid PulseParams earlier 

110 # Type check on pulse parameters 

111 if pulse_params is not None: 

112 # flatten pulse parameters 

113 if isinstance(pulse_params, (list, tuple)): 

114 flat_params = pulse_params 

115 

116 elif isinstance(pulse_params, jax.core.Tracer): 

117 flat_params = jnp.ravel(pulse_params) 

118 

119 elif isinstance(pulse_params, (jnp.ndarray, jnp.ndarray)): 

120 flat_params = pulse_params.flatten().tolist() 

121 elif isinstance(pulse_params, PulseParams): 

122 # extract the params in case a full object is given 

123 kwargs["pulse_params"] = pulse_params.params 

124 flat_params = pulse_params.params.flatten().tolist() 

125 

126 else: 

127 raise TypeError(f"Unsupported pulse_params type: {type(pulse_params)}") 

128 

129 # checks elements in flat parameters are real numbers or jax Tracer 

130 if not all( 

131 isinstance(x, (numbers.Real, jax.core.Tracer)) for x in flat_params 

132 ): 

133 raise TypeError( 

134 "All elements in pulse_params must be int or float, " 

135 f"got {pulse_params}, type {type(pulse_params)}. " 

136 ) 

137 

138 # Len check on pulse parameters 

139 if pulse_params is not None and not isinstance(pulse_mgr, PulseParamManager): 

140 n_params = PulseInformation.gate_by_name(gate_name).size 

141 if len(flat_params) != n_params: 

142 raise ValueError( 

143 f"Gate '{gate_name}' expects {n_params} pulse parameters, " 

144 f"got {len(flat_params)}" 

145 ) 

146 

147 # Pulse slicing + scaling 

148 if gate_mode == "pulse" and isinstance(pulse_mgr, PulseParamManager): 

149 n_params = PulseInformation.gate_by_name(gate_name).size 

150 scalers = pulse_mgr.get(n_params) 

151 base = PulseInformation.gate_by_name(gate_name).params 

152 kwargs["pulse_params"] = base * scalers 

153 

154 # Call the selected gate backend 

155 gate = getattr(gate_backend, gate_name, None) 

156 if gate is None: 

157 raise AttributeError( 

158 f"'{gate_backend.__class__.__name__}' object " 

159 f"has no attribute '{gate_name}'" 

160 ) 

161 

162 return gate(*args, **kwargs) 

163 

164 @staticmethod 

165 @contextmanager 

166 def pulse_manager_context(pulse_params: jnp.ndarray): 

167 """Temporarily set the global pulse manager for circuit building.""" 

168 Gates._pulse_mgr = PulseParamManager(pulse_params) 

169 try: 

170 yield 

171 finally: 

172 Gates._pulse_mgr = None 

173 

174 @staticmethod 

175 def parse_gates( 

176 gates: Union[str, Callable, List[Union[str, Callable]]], 

177 set_of_gates=None, 

178 ): 

179 set_of_gates = set_of_gates or Gates 

180 

181 if isinstance(gates, str): 

182 # if str, use the pennylane fct 

183 parsed_gates = [getattr(set_of_gates, f"{gates}")] 

184 elif isinstance(gates, list): 

185 parsed_gates = [] 

186 for enc in gates: 

187 # if list, check if str or callable 

188 if isinstance(enc, str): 

189 parsed_gates.append(getattr(set_of_gates, f"{enc}")) 

190 # check if callable 

191 elif callable(enc): 

192 parsed_gates.append(enc) 

193 else: 

194 raise ValueError( 

195 f"Operation {enc} is not a valid gate or callable.\ 

196 Got {type(enc)}" 

197 ) 

198 elif callable(gates): 

199 # default to callable 

200 parsed_gates = [gates] 

201 elif gates is None: 

202 parsed_gates = [lambda *args, **kwargs: None] 

203 else: 

204 raise ValueError( 

205 f"Operation {gates} is not a valid gate or callable or list of both." 

206 ) 

207 return parsed_gates 

208 

209 @staticmethod 

210 def is_rotational(gate): 

211 return gate.__name__ in [ 

212 "RX", 

213 "RY", 

214 "RZ", 

215 "Rot", 

216 "CRX", 

217 "CRY", 

218 "CRZ", 

219 ] 

220 

221 @staticmethod 

222 def is_entangling(gate): 

223 return gate.__name__ in ["CX", "CY", "CZ", "CRX", "CRY", "CRZ"]