Coverage for qml_essentials / unitary.py: 95%

125 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-03-30 11:43 +0000

1from typing import Optional, List, Union, Dict, Tuple 

2import itertools 

3import jax.numpy as jnp 

4import jax 

5 

6from qml_essentials import operations as op 

7import logging 

8 

9from qml_essentials.utils import safe_random_split 

10 

# Module-level logger, named after this module.
log = logging.getLogger(name=__name__)

12 

13 

class UnitaryGates:
    """Unitary quantum gate wrappers that can inject gate errors and noise channels."""

    # Selects how GateError samples angle noise:
    #   True  -> split the supplied random_key and, for array-valued angles,
    #            draw one noise value per element of the array.
    #   False -> always use the fixed key jax.random.key(0) and draw a single
    #            scalar, so every batch element (under vmap) sees the same noise.
    batch_gate_error = True

18 

@staticmethod
def NQubitDepolarizingChannel(p: float, wires: List[int]) -> op.QubitChannel:
    """
    Build an n-qubit depolarizing channel as a QubitChannel.

    Models uniform depolarizing noise acting jointly on all qubits in
    ``wires``, e.g. to mimic correlated noise after an entangling gate.

    With n = len(wires), the Kraus operators are:
        K_0 = sqrt(1 - p * (4**n - 1) / 4**n) * I_(2**n)
        K_P = sqrt(p / 4**n) * P, for each of the 4**n - 1 non-identity
              n-qubit Pauli strings P (tensor products of I, X, Y, Z).

    Args:
        p (float): Total probability of depolarizing error (0 <= p <= 1).
        wires (List[int]): Qubit indices the channel acts on; at least two.

    Returns:
        op.QubitChannel: Channel holding the Kraus operators listed above.

    Raises:
        ValueError: If p is outside [0, 1] or fewer than two wires are given.
    """
    num_qubits = len(wires)

    # Validate before building any matrices.
    if not (0.0 <= p <= 1.0):
        raise ValueError(f"Probability p must be between 0 and 1, got {p}")
    if num_qubits < 2:
        raise ValueError(f"Number of qubits must be >= 2, got {num_qubits}")

    single_qubit_paulis = (
        jnp.eye(2),
        op.PauliX._matrix,
        op.PauliY._matrix,
        op.PauliZ._matrix,
    )
    n_terms = 4**num_qubits

    def tensor(factors):
        # Left-to-right Kronecker product, seeded with a 1x1 identity.
        product = jnp.eye(1)
        for factor in factors:
            product = jnp.kron(product, factor)
        return product

    # product() yields the all-identity string first. Drop it, because its
    # weight is folded into the identity term below.
    non_identity_strings = itertools.islice(
        itertools.product(single_qubit_paulis, repeat=num_qubits), 1, None
    )

    identity_term = jnp.sqrt(1 - p * (n_terms - 1) / n_terms) * jnp.eye(
        2**num_qubits
    )
    weight = jnp.sqrt(p / n_terms)
    kraus = [identity_term] + [
        weight * tensor(factors) for factors in non_identity_strings
    ]
    return op.QubitChannel(kraus, wires=wires)

76 

77 @staticmethod 

78 def Noise( 

79 wires: Union[int, List[int]], noise_params: Optional[Dict[str, float]] = None 

80 ) -> None: 

81 """ 

82 Apply noise channels to specified qubits. 

83 

84 Applies various single-qubit and multi-qubit noise channels based on 

85 the provided noise parameters dictionary. 

86 

87 Args: 

88 wires (Union[int, List[int]]): Qubit index or list of qubit indices 

89 to apply noise to. 

90 noise_params (Optional[Dict[str, float]]): Dictionary of noise 

91 parameters. Supported keys: 

92 - "BitFlip" (float): Bit flip error probability 

93 - "PhaseFlip" (float): Phase flip error probability 

94 - "Depolarizing" (float): Single-qubit depolarizing probability 

95 - "MultiQubitDepolarizing" (float): Multi-qubit depolarizing 

96 probability (applies if len(wires) > 1) 

97 All parameters default to 0.0 if not provided. 

98 

99 Returns: 

100 None: Noise channels are applied in-place to the circuit. 

101 """ 

102 if noise_params is not None: 

103 if isinstance(wires, int): 

104 wires = [wires] # single qubit gate 

105 

106 # noise on single qubits 

107 for wire in wires: 

108 bf = noise_params.get("BitFlip", 0.0) 

109 if bf > 0: 

110 op.BitFlip(bf, wires=wire) 

111 

112 pf = noise_params.get("PhaseFlip", 0.0) 

113 if pf > 0: 

114 op.PhaseFlip(pf, wires=wire) 

115 

116 dp = noise_params.get("Depolarizing", 0.0) 

117 if dp > 0: 

118 op.DepolarizingChannel(dp, wires=wire) 

119 

120 # noise on two-qubits 

121 if len(wires) > 1: 

122 p = noise_params.get("MultiQubitDepolarizing", 0.0) 

123 if p > 0: 

124 UnitaryGates.NQubitDepolarizingChannel(p, wires) 

125 

126 @staticmethod 

127 def GateError( 

128 w: Union[float, jnp.ndarray, List[float]], 

129 noise_params: Optional[Dict[str, float]] = None, 

130 random_key: Optional[jax.random.PRNGKey] = None, 

131 ) -> Tuple[jnp.ndarray, jax.random.PRNGKey]: 

132 """ 

133 Apply gate error noise to rotation angle(s). 

134 

135 Adds Gaussian noise to gate rotation angles to simulate imperfect 

136 gate implementations. 

137 

138 Args: 

139 w (Union[float, jnp.ndarray, List[float]]): Rotation angle(s) in radians. 

140 noise_params (Optional[Dict[str, float]]): Dictionary with optional 

141 "GateError" key specifying standard deviation of Gaussian noise. 

142 random_key (Optional[jax.random.PRNGKey]): JAX random key for 

143 stochastic noise generation. 

144 

145 Returns: 

146 Tuple[jnp.ndarray, jax.random.PRNGKey]: Tuple containing: 

147 - Modified rotation angle(s) with applied noise 

148 - Updated JAX random key 

149 

150 Raises: 

151 AssertionError: If noise_params contains "GateError" but random_key is None. 

152 """ 

153 if noise_params is not None and noise_params.get("GateError", None) is not None: 

154 assert ( 

155 random_key is not None 

156 ), "A random_key must be provided when using GateError" 

157 

158 if UnitaryGates.batch_gate_error: 

159 random_key, sub_key = safe_random_split(random_key) 

160 else: 

161 # Use a fixed key so that every batch element (under vmap) 

162 # draws the same noise value, effectively broadcasting. 

163 sub_key = jax.random.key(0) 

164 

165 w += noise_params["GateError"] * jax.random.normal( 

166 sub_key, 

167 ( 

168 w.shape 

169 if isinstance(w, jnp.ndarray) and UnitaryGates.batch_gate_error 

170 else () 

171 ), 

172 ) 

173 return w, random_key 

174 

@staticmethod
def Rot(
    phi: Union[float, jnp.ndarray, List[float]],
    theta: Union[float, jnp.ndarray, List[float]],
    omega: Union[float, jnp.ndarray, List[float]],
    wires: Union[int, List[int]],
    noise_params: Optional[Dict[str, float]] = None,
    random_key: Optional[jax.random.PRNGKey] = None,
    input_idx: int = -1,
) -> None:
    """
    Apply general rotation gate with optional noise.

    Applies a three-angle rotation Rot(phi, theta, omega). When noise_params
    contains "GateError", each angle is first perturbed via GateError. The
    noise channels from Noise are applied afterwards.

    Args:
        phi (Union[float, jnp.ndarray, List[float]]): First rotation angle.
        theta (Union[float, jnp.ndarray, List[float]]): Second rotation angle.
        omega (Union[float, jnp.ndarray, List[float]]): Third rotation angle.
        wires (Union[int, List[int]]): Qubit index or indices to apply rotation to.
        noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
            Supports BitFlip, PhaseFlip, Depolarizing, and GateError.
        random_key (Optional[jax.random.PRNGKey]): JAX random key for noise.
        input_idx (int): Flag for the tape to track inputs. It is forwarded
            to op.Rot unchanged.

    Returns:
        None: Gate and noise are applied in-place to the circuit.
    """
    if noise_params is not None and "GateError" in noise_params:
        # Thread the key through all three calls. GateError advances it
        # when UnitaryGates.batch_gate_error is set.
        phi, random_key = UnitaryGates.GateError(phi, noise_params, random_key)
        theta, random_key = UnitaryGates.GateError(theta, noise_params, random_key)
        omega, random_key = UnitaryGates.GateError(omega, noise_params, random_key)
    # Forward the caller's input_idx, like every other parametrised gate.
    # Previously the literal False (== 0) was passed, which discarded it.
    op.Rot(phi, theta, omega, wires=wires, input_idx=input_idx)
    UnitaryGates.Noise(wires, noise_params)

210 

@staticmethod
def PauliRot(
    theta: Union[float, jnp.ndarray, List[float]],
    pauli: str,
    wires: Union[int, List[int]],
    noise_params: Optional[Dict[str, float]] = None,
    random_key: Optional[jax.random.PRNGKey] = None,
    input_idx: int = -1,
) -> None:
    """
    Apply a Pauli rotation gate with optional noise.

    Applies a single-angle rotation generated by ``pauli``, via op.PauliRot.
    When noise_params contains "GateError", the angle is first perturbed via
    GateError. The noise channels from Noise are applied afterwards.

    Args:
        theta (Union[float, jnp.ndarray, List[float]]): Rotation angle.
        pauli (str): Pauli label forwarded to op.PauliRot, e.g. "X", "Y" or "Z".
        wires (Union[int, List[int]]): Qubit index or indices to apply rotation to.
        noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
            Supports BitFlip, PhaseFlip, Depolarizing, and GateError.
        random_key (Optional[jax.random.PRNGKey]): JAX random key for noise.
        input_idx (int): Flag for the tape to track inputs.

    Returns:
        None: Gate and noise are applied in-place to the circuit.
    """
    if noise_params is not None and "GateError" in noise_params:
        theta, random_key = UnitaryGates.GateError(theta, noise_params, random_key)
    op.PauliRot(theta, pauli, wires=wires, input_idx=input_idx)
    UnitaryGates.Noise(wires, noise_params)

242 

@staticmethod
def RX(
    w: Union[float, jnp.ndarray, List[float]],
    wires: Union[int, List[int]],
    noise_params: Optional[Dict[str, float]] = None,
    random_key: Optional[jax.random.PRNGKey] = None,
    input_idx: int = -1,
) -> None:
    """
    Rotate about the X axis, optionally perturbing the angle and adding noise.

    Args:
        w (Union[float, jnp.ndarray, List[float]]): Rotation angle.
        wires (Union[int, List[int]]): Qubit index or indices.
        noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
        random_key (Optional[jax.random.PRNGKey]): JAX random key for noise.
        input_idx (int): Flag for the tape to track inputs

    Returns:
        None: The gate and any noise channels are applied in-place to the circuit.
    """
    # The updated key returned by GateError is not needed after this call.
    angle, _ = UnitaryGates.GateError(w, noise_params, random_key)
    op.RX(angle, wires=wires, input_idx=input_idx)
    UnitaryGates.Noise(wires, noise_params)

267 

@staticmethod
def RY(
    w: Union[float, jnp.ndarray, List[float]],
    wires: Union[int, List[int]],
    noise_params: Optional[Dict[str, float]] = None,
    random_key: Optional[jax.random.PRNGKey] = None,
    input_idx: int = -1,
) -> None:
    """
    Rotate about the Y axis, optionally perturbing the angle and adding noise.

    Args:
        w (Union[float, jnp.ndarray, List[float]]): Rotation angle.
        wires (Union[int, List[int]]): Qubit index or indices.
        noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
        random_key (Optional[jax.random.PRNGKey]): JAX random key for noise.
        input_idx (int): Flag for the tape to track inputs

    Returns:
        None: The gate and any noise channels are applied in-place to the circuit.
    """
    # The updated key returned by GateError is not needed after this call.
    angle, _ = UnitaryGates.GateError(w, noise_params, random_key)
    op.RY(angle, wires=wires, input_idx=input_idx)
    UnitaryGates.Noise(wires, noise_params)

292 

@staticmethod
def RZ(
    w: Union[float, jnp.ndarray, List[float]],
    wires: Union[int, List[int]],
    noise_params: Optional[Dict[str, float]] = None,
    random_key: Optional[jax.random.PRNGKey] = None,
    input_idx: int = -1,
) -> None:
    """
    Rotate about the Z axis, optionally perturbing the angle and adding noise.

    Args:
        w (Union[float, jnp.ndarray, List[float]]): Rotation angle.
        wires (Union[int, List[int]]): Qubit index or indices.
        noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
        random_key (Optional[jax.random.PRNGKey]): JAX random key for noise.
        input_idx (int): Flag for the tape to track inputs

    Returns:
        None: The gate and any noise channels are applied in-place to the circuit.
    """
    # The updated key returned by GateError is not needed after this call.
    angle, _ = UnitaryGates.GateError(w, noise_params, random_key)
    op.RZ(angle, wires=wires, input_idx=input_idx)
    UnitaryGates.Noise(wires, noise_params)

317 

@staticmethod
def CRX(
    w: Union[float, jnp.ndarray, List[float]],
    wires: Union[int, List[int]],
    noise_params: Optional[Dict[str, float]] = None,
    random_key: Optional[jax.random.PRNGKey] = None,
    input_idx: int = -1,
) -> None:
    """
    Apply a controlled X-rotation, optionally perturbing the angle and adding noise.

    Args:
        w (Union[float, jnp.ndarray, List[float]]): Rotation angle.
        wires (Union[int, List[int]]): Control and target qubit indices.
        noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
        random_key (Optional[jax.random.PRNGKey]): JAX random key for noise.
        input_idx (int): Flag for the tape to track inputs

    Returns:
        None: The gate and any noise channels are applied in-place to the circuit.
    """
    # The updated key returned by GateError is not needed after this call.
    angle, _ = UnitaryGates.GateError(w, noise_params, random_key)
    op.CRX(angle, wires=wires, input_idx=input_idx)
    UnitaryGates.Noise(wires, noise_params)

342 

@staticmethod
def CRY(
    w: Union[float, jnp.ndarray, List[float]],
    wires: Union[int, List[int]],
    noise_params: Optional[Dict[str, float]] = None,
    random_key: Optional[jax.random.PRNGKey] = None,
    input_idx: int = -1,
) -> None:
    """
    Apply a controlled Y-rotation, optionally perturbing the angle and adding noise.

    Args:
        w (Union[float, jnp.ndarray, List[float]]): Rotation angle.
        wires (Union[int, List[int]]): Control and target qubit indices.
        noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
        random_key (Optional[jax.random.PRNGKey]): JAX random key for noise.
        input_idx (int): Flag for the tape to track inputs

    Returns:
        None: The gate and any noise channels are applied in-place to the circuit.
    """
    # The updated key returned by GateError is not needed after this call.
    angle, _ = UnitaryGates.GateError(w, noise_params, random_key)
    op.CRY(angle, wires=wires, input_idx=input_idx)
    UnitaryGates.Noise(wires, noise_params)

367 

@staticmethod
def CRZ(
    w: Union[float, jnp.ndarray, List[float]],
    wires: Union[int, List[int]],
    noise_params: Optional[Dict[str, float]] = None,
    random_key: Optional[jax.random.PRNGKey] = None,
    input_idx: int = -1,
) -> None:
    """
    Apply a controlled Z-rotation, optionally perturbing the angle and adding noise.

    Args:
        w (Union[float, jnp.ndarray, List[float]]): Rotation angle.
        wires (Union[int, List[int]]): Control and target qubit indices.
        noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
        random_key (Optional[jax.random.PRNGKey]): JAX random key for noise.
        input_idx (int): Flag for the tape to track inputs

    Returns:
        None: The gate and any noise channels are applied in-place to the circuit.
    """
    # The updated key returned by GateError is not needed after this call.
    angle, _ = UnitaryGates.GateError(w, noise_params, random_key)
    op.CRZ(angle, wires=wires, input_idx=input_idx)
    UnitaryGates.Noise(wires, noise_params)

392 

@staticmethod
def CX(
    wires: Union[int, List[int]],
    noise_params: Optional[Dict[str, float]] = None,
    random_key: Optional[jax.random.PRNGKey] = None,
) -> None:
    """
    Apply a CNOT (controlled-X) gate followed by any configured noise.

    Args:
        wires (Union[int, List[int]]): Control and target qubit indices.
        noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
        random_key (Optional[jax.random.PRNGKey]): Accepted only so the
            signature matches the parametrised gates. It is not used, because
            this gate has no rotation angle.

    Returns:
        None: The gate and any noise channels are applied in-place to the circuit.
    """
    del random_key  # no angle, so no gate-error sampling
    op.CX(wires=wires)
    UnitaryGates.Noise(wires, noise_params)

413 

@staticmethod
def CY(
    wires: Union[int, List[int]],
    noise_params: Optional[Dict[str, float]] = None,
    random_key: Optional[jax.random.PRNGKey] = None,
) -> None:
    """
    Apply a controlled-Y gate followed by any configured noise.

    Args:
        wires (Union[int, List[int]]): Control and target qubit indices.
        noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
        random_key (Optional[jax.random.PRNGKey]): Accepted only so the
            signature matches the parametrised gates. It is not used, because
            this gate has no rotation angle.

    Returns:
        None: The gate and any noise channels are applied in-place to the circuit.
    """
    del random_key  # no angle, so no gate-error sampling
    op.CY(wires=wires)
    UnitaryGates.Noise(wires, noise_params)

434 

@staticmethod
def CZ(
    wires: Union[int, List[int]],
    noise_params: Optional[Dict[str, float]] = None,
    random_key: Optional[jax.random.PRNGKey] = None,
) -> None:
    """
    Apply a controlled-Z gate followed by any configured noise.

    Args:
        wires (Union[int, List[int]]): Control and target qubit indices.
        noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
        random_key (Optional[jax.random.PRNGKey]): Accepted only so the
            signature matches the parametrised gates. It is not used, because
            this gate has no rotation angle.

    Returns:
        None: The gate and any noise channels are applied in-place to the circuit.
    """
    del random_key  # no angle, so no gate-error sampling
    op.CZ(wires=wires)
    UnitaryGates.Noise(wires, noise_params)

455 

@staticmethod
def H(
    wires: Union[int, List[int]],
    noise_params: Optional[Dict[str, float]] = None,
    random_key: Optional[jax.random.PRNGKey] = None,
) -> None:
    """
    Apply a Hadamard gate followed by any configured noise.

    Args:
        wires (Union[int, List[int]]): Qubit index or indices.
        noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
        random_key (Optional[jax.random.PRNGKey]): Accepted only so the
            signature matches the parametrised gates. It is not used, because
            this gate has no rotation angle.

    Returns:
        None: The gate and any noise channels are applied in-place to the circuit.
    """
    del random_key  # no angle, so no gate-error sampling
    op.H(wires=wires)
    UnitaryGates.Noise(wires, noise_params)