Coverage for qml_essentials / pulses.py: 82%

554 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-06-11 15:51 +0000

1import os 

2from contextlib import contextmanager 

3from dataclasses import dataclass 

4from typing import Optional, List, Union, Dict, Callable, Tuple 

5import csv 

6import jax.numpy as jnp 

7import jax 

8 

9from qml_essentials import jaqsi as js 

10from qml_essentials.utils import safe_random_split 

11from qml_essentials.tape import active_pulse_tape 

12from qml_essentials.unitary import UnitaryGates 

13import logging 

14 

# Module-wide logger, named after this module's import path.
log = logging.getLogger(name=__name__)

16 

17 

@dataclass
class DecompositionStep:
    """A single stage in the decomposition of a composite pulse gate.

    Attributes:
        gate: The child ``PulseParams`` object executed in this stage.
        wire_fn: Which wires the child acts on: ``"all"``, ``"target"``
            or ``"control"``.
        angle_fn: Callable mapping the parent's angle(s) ``w`` to the
            child's angle. When left as ``None``, ``w`` is forwarded to
            the child unchanged.
    """

    gate: "PulseParams"
    wire_fn: str = "all"
    angle_fn: Union[Callable, None] = None

32 

33 

@dataclass(frozen=True)
class PulseStateSnapshot:
    """Immutable record of the global, otherwise mutable, pulse configuration.

    Instances are created by ``PulseInformation.snapshot_state`` and applied
    again by ``PulseInformation.restore_state``.
    """

    # Name of the active envelope shape (a ``PulseEnvelope`` registry key).
    envelope: str
    # Whether the rotating-wave approximation flag is on.
    rwa: bool
    # Coefficient frame identifier (``"lab"`` or ``"drive"``).
    frame: str
    # Parameter arrays of the leaf gates, keyed by gate name.
    leaf_params: Dict[str, jnp.ndarray]

42 

43 

class PulseParams:
    """Container for hierarchical pulse parameters.

    Leaf nodes hold direct parameters; composite nodes hold a list of
    :class:`DecompositionStep` objects that describe how the gate is
    built from simpler gates.

    Attributes:
        name: Gate identifier (e.g. ``"RX"``, ``"H"``).
        decomposition: List of :class:`DecompositionStep` (composite only).
    """

    def __init__(
        self,
        name: str = "",
        params: Optional[jnp.ndarray] = None,
        decomposition: Optional[List["DecompositionStep"]] = None,
    ) -> None:
        """
        Args:
            name: Gate name.
            params: Direct pulse parameters (leaf gates).
                Mutually exclusive with *decomposition*.
            decomposition: List of :class:`DecompositionStep` (composite
                gates). Mutually exclusive with *params*.

        Raises:
            AssertionError: If both or neither of *params* and
                *decomposition* are given.
        """
        assert (params is None) != (decomposition is None), (
            "Exactly one of `params` or `decomposition` must be provided."
        )

        self.decomposition = decomposition
        # Child gates extracted from the decomposition steps; ``None`` marks
        # a leaf. Kept for backward compat with childs/leafs/split_params.
        # NOTE(review): an empty ``decomposition`` list also yields ``None``
        # here, so such a node reports ``is_leaf`` without having ``_params``.
        self._pulse_obj = (
            [step.gate for step in decomposition] if decomposition else None
        )

        if params is not None:
            self._params = params

        self.name = name

    def __len__(self) -> int:
        """
        Get the total number of pulse parameters.

        For composite gates, returns the accumulated count from all children
        (a child that appears several times is counted each time).

        Returns:
            int: Total number of pulse parameters.
        """
        return len(self.params)

    def __getitem__(self, idx: int) -> Union[float, jnp.ndarray]:
        """
        Access pulse parameter(s) by index.

        For leaf gates, returns the parameter at the given index.
        For composite gates, returns parameters of the child at the given index.

        Args:
            idx (int): Index to access.

        Returns:
            Union[float, jnp.ndarray]: Parameter value or child parameters.
        """
        if self.is_leaf:
            return self.params[idx]
        else:
            return self.childs[idx].params

    def __str__(self) -> str:
        """Return string representation (gate name)."""
        return self.name

    def __repr__(self) -> str:
        """Return repr string (gate name)."""
        return self.name

    @property
    def is_leaf(self) -> bool:
        """Check if this is a leaf node (direct parameters, no children)."""
        return self._pulse_obj is None

    @property
    def size(self) -> int:
        """Get the total parameter count (alias for __len__)."""
        return len(self)

    @property
    def leafs(self) -> List["PulseParams"]:
        """
        Get all leaf nodes in the hierarchy.

        Recursively collects all leaf PulseParams objects in the tree.
        Each leaf appears once, in the order it is first encountered in a
        depth-first walk over the children.

        Returns:
            List[PulseParams]: List of unique leaf nodes.
        """
        if self.is_leaf:
            return [self]

        # ``dict.fromkeys`` de-duplicates while keeping first-seen order. A
        # ``set`` would order leaves by object hash (memory address), so the
        # layout of ``leaf_params`` would change from run to run.
        return list(
            dict.fromkeys(leaf for obj in self._pulse_obj for leaf in obj.leafs)
        )

    @property
    def childs(self) -> List["PulseParams"]:
        """
        Get direct children of this node.

        Returns:
            List[PulseParams]: List of child PulseParams objects, or empty list
                if this is a leaf node.
        """
        if self.is_leaf:
            return []

        return self._pulse_obj

    @property
    def shape(self) -> List[int]:
        """
        Get the shape of pulse parameters.

        For leaf nodes, returns a list holding the parameter count.
        For composite nodes, returns one entry per child: the parameter
        count for a leaf child, or the child's own (nested) shape list for
        a composite child.

        Returns:
            List[int]: Parameter shape specification (possibly nested).
        """
        if self.is_leaf:
            return [len(self.params)]

        shape = []
        for obj in self.childs:
            # ``shape`` is a property, not a method; leaf children contribute
            # their single count, composite children their nested list.
            child_shape = obj.shape
            shape.append(child_shape[0] if obj.is_leaf else child_shape)

        return shape

    @property
    def params(self) -> jnp.ndarray:
        """
        Get or compute pulse parameters.

        For leaf nodes, returns internal pulse parameters.
        For composite nodes, returns concatenated parameters from all children.

        Returns:
            jnp.ndarray: Pulse parameters array.
        """
        if self.is_leaf:
            return self._params

        params = self.split_params(params=None, leafs=False)

        return jnp.concatenate(params)

    @params.setter
    def params(self, value: jnp.ndarray) -> None:
        """
        Set pulse parameters.

        For leaf nodes, sets internal parameters directly.
        For composite nodes, distributes values across children.

        Args:
            value (jnp.ndarray): Pulse parameters to set.

        Raises:
            AssertionError: If value is not jnp.ndarray for leaf nodes.
        """
        if self.is_leaf:
            assert isinstance(value, jnp.ndarray), "params must be a jnp.ndarray"
            self._params = value
            return

        idx = 0
        for obj in self.childs:
            nidx = idx + obj.size
            obj.params = value[idx:nidx]
            idx = nidx

    @property
    def leaf_params(self) -> jnp.ndarray:
        """
        Get parameters from all leaf nodes.

        Returns:
            jnp.ndarray: Concatenated parameters from all unique leaf nodes,
                in the order given by :attr:`leafs`.
        """
        if self.is_leaf:
            return self._params

        params = self.split_params(None, leafs=True)

        return jnp.concatenate(params)

    @leaf_params.setter
    def leaf_params(self, value: jnp.ndarray) -> None:
        """
        Set parameters for all leaf nodes.

        Args:
            value (jnp.ndarray): Parameters to distribute across leaf nodes,
                in the order given by :attr:`leafs`.
        """
        if self.is_leaf:
            self._params = value
            return

        idx = 0
        for obj in self.leafs:
            nidx = idx + obj.size
            obj.params = value[idx:nidx]
            idx = nidx

    def split_params(
        self,
        params: Optional[jnp.ndarray] = None,
        leafs: bool = False,
    ) -> List[jnp.ndarray]:
        """
        Split parameters into sub-arrays for children or leaves.

        Args:
            params (Optional[jnp.ndarray]): Parameters to split. If None,
                uses internal parameters.
            leafs (bool): If True, splits across leaf nodes; if False,
                splits across direct children. Defaults to False.

        Returns:
            List[jnp.ndarray]: List of parameter arrays for children or leaves.
                For a leaf node the (unsplit) array itself is returned.
        """
        if params is None:
            if self.is_leaf:
                return self._params

            objs = self.leafs if leafs else self.childs
            s_params = []
            for obj in objs:
                s_params.append(obj.params)

            return s_params
        else:
            if self.is_leaf:
                return params

            objs = self.leafs if leafs else self.childs
            s_params = []
            idx = 0
            for obj in objs:
                nidx = idx + obj.size
                s_params.append(params[idx:nidx])
                idx = nidx

            return s_params

301 

302 

class PulseEnvelope:
    """Registry of pulse envelope shapes.

    Each envelope is a pure function ``(p, t, t_c) -> amplitude`` that
    computes the pulse envelope *without* carrier modulation. The carrier
    ``cos(omega_c * t + phi_c)`` is applied separately in the coefficient
    functions built by :meth:`build_coeff_fns`.

    Attributes:
        REGISTRY: Mapping from envelope name to metadata dict containing
            ``fn`` (callable), ``n_envelope_params`` (int), and per-gate
            default parameter arrays.
    """

    @staticmethod
    def gaussian(p, t, t_c):
        """Gaussian envelope. ``p = [A, sigma]``."""
        A, sigma = p[0], p[1]
        return A * jnp.exp(-0.5 * ((t - t_c) / sigma) ** 2)

    @staticmethod
    def square(p, t, t_c):
        """Rectangular envelope. ``p = [A, width]``."""
        A, width = p[0], p[1]
        return A * (jnp.abs(t - t_c) <= width / 2)

    @staticmethod
    def cosine(p, t, t_c):
        """Raised cosine envelope. ``p = [A, width]``."""
        A, width = p[0], p[1]
        # Clipping makes the envelope (numerically) zero outside the window.
        x = jnp.clip((t - t_c) / width, -0.5, 0.5)
        return A * jnp.cos(jnp.pi * x)

    @staticmethod
    def drag(p, t, t_c):
        """DRAG (Derivative Removal by Adiabatic Gate). ``p = [A, beta, sigma]``."""
        A, beta, sigma = p[0], p[1], p[2]
        g = A * jnp.exp(-0.5 * ((t - t_c) / sigma) ** 2)
        # Analytic time derivative of the Gaussian g.
        dg = g * (-(t - t_c) / sigma**2)
        return g + beta * dg

    @staticmethod
    def sech(p, t, t_c):
        """Hyperbolic secant envelope. ``p = [A, sigma]``."""
        A, sigma = p[0], p[1]
        return A / jnp.cosh((t - t_c) / sigma)

    # ``n_envelope_params`` counts only the envelope parameters (excluding
    # the evolution time ``t`` which is always the last element of the full
    # pulse parameter vector).
    # NOTE(review): ``build_coeff_fns`` scales by ``p[-1]`` and documents it
    # as the rotation angle; confirm how callers assemble ``p``.
    REGISTRY = {
        "gaussian": {
            "fn": gaussian.__func__,
            "n_envelope_params": 2,
            "defaults": {
                "RX": jnp.array(
                    [0.38009941846766804, 1.631698142660167, 3.007403822238108]
                ),
                "RY": jnp.array(
                    [0.3836652338514791, 1.616595983505249, 2.9794135093698966]
                ),
            },
        },
        "square": {
            "fn": square.__func__,
            "n_envelope_params": 2,
            "defaults": {
                "RX": jnp.array(
                    [1.209655637514602, 0.8266815576721239, 1.1483122857413859]
                ),
                "RY": jnp.array(
                    [1.0287942142779052, 0.9860505130182093, 0.9720116870310977]
                ),
            },
        },
        "cosine": {
            "fn": cosine.__func__,
            "n_envelope_params": 2,
            "defaults": {
                "RX": jnp.array([1.0, 1.0, 1.0]),
                "RY": jnp.array([1.0, 1.0, 1.0]),
            },
        },
        "drag": {
            "fn": drag.__func__,
            "n_envelope_params": 3,
            "defaults": {
                "RX": jnp.array(
                    [
                        0.326562746114197,
                        0.4002767596709071,
                        5.3228107728890315,
                        3.141300761986467,
                    ]
                ),
                "RY": jnp.array(
                    [
                        0.323287924190616,
                        0.4065017233024265,
                        7.00299644871222,
                        3.139481229843545,
                    ]
                ),
            },
        },
        "sech": {
            "fn": sech.__func__,
            "n_envelope_params": 2,
            "defaults": {
                "RX": jnp.array([1.0, 1.0, 1.0]),
                "RY": jnp.array([1.0, 1.0, 1.0]),
            },
        },
        "general": {
            "fn": None,
            "n_envelope_params": 0,
            "defaults": {
                "RZ": jnp.array([0.5]),
                "CZ": jnp.array([0.3183098783513154]),
            },
        },
    }

    @staticmethod
    def available() -> List[str]:
        """Return list of registered envelope names."""
        return list(PulseEnvelope.REGISTRY.keys())

    @staticmethod
    def get(name: str) -> dict:
        """Look up envelope metadata by name.

        Raises:
            ValueError: If *name* is not registered.
        """
        if name not in PulseEnvelope.REGISTRY:
            raise ValueError(
                f"Unknown pulse envelope '{name}'. "
                f"Available: {PulseEnvelope.available()}"
            )
        return PulseEnvelope.REGISTRY[name]

    @staticmethod
    def build_coeff_fns(
        envelope_fn: Callable,
        omega_c: float,
        omega_q: float,
        rwa: bool = True,
        frame: str = "drive",
    ) -> Tuple[Callable, Callable, Callable, Callable]:
        """Build the four interaction-picture coefficient functions.

        The lab-frame Hamiltonian is

            H(t,Π) = H_static + Σ_j S_j(t;Π) H_j ,
            S_j(t;Π) = E_j(t;Π) · cos(ω_c·t + φ_c) ,

        and the interaction-picture transform with respect to
        ``H_static = (ω_q/2)·Z`` produces

            H̃_j(t) = exp(+i H_static t) H_j exp(-i H_static t) ,
            H_I(t)  = Σ_j S_j(t) H̃_j(t) .

        For a single qubit driven on X, ``H̃_X(t) = cos(ω_q·t) X
        − sin(ω_q·t) Y``, so

            H_I(t) = Ω(t) · cos(ω_c·t + φ) ·
                     [ cos(ω_q·t) · X − sin(ω_q·t) · Y ] .

        ``rwa=True`` (default) drops the fast (~2·ω_q on resonance) terms and
        keeps only the slow envelope, yielding the analytical RWA

            H_I^RWA(t) = (Ω(t)/2) · [ cos(φ) X + sin(φ) Y ] .

        For RX (``φ = 0``) this reduces to ``(Ω/2)·X``; for RY
        (``φ = +π/2``) to ``(Ω/2)·Y``. This is dramatically cheaper to
        integrate (no fast oscillations → adaptive ODE solver takes
        large steps). The RWA coefficients do not use ``omega_c`` or
        ``omega_q`` at all, i.e. they describe a resonant drive
        (``ω_c = ω_q``); any detuning is neglected.

        ``rwa=False`` keeps **both** the slow and the fast
        counter-rotating components.

        Each returned function has its own ``__code__`` object so the
        jaqsi solver cache assigns separate compiled XLA programs per
        envelope shape and per (gate, component) pair. This is why the
        four functions are written out individually instead of being
        produced by a shared factory (closures from one factory share a
        single code object).

        The rotation angle ``w`` is expected as the **last** element of
        the parameter array ``p`` (i.e. ``p[-1]``). Envelope parameters
        occupy ``p[:-1]``.

        Args:
            envelope_fn: Pure envelope function ``(p, t, t_c) -> scalar``.
            omega_c: Carrier frequency.
            omega_q: Qubit frequency (interaction-picture rotation rate).
            rwa: When ``True``, return the RWA-truncated coefficients
                (no fast counter-rotating terms). Default ``True``.
            frame: Algebraic representation of the exact (non-RWA)
                coefficients. The two options are mathematically
                equivalent:

                * ``"drive"`` (default): applies the product-to-sum
                  identities to expose the slow ``(ω_c-ω_q)`` and fast
                  ``(ω_c+ω_q)`` modes explicitly, e.g.
                  ``cos(ω_c t)cos(ω_q t) =
                  ½[cos((ω_c-ω_q)t) + cos((ω_c+ω_q)t)]``. Algebraically
                  identical to ``"lab"`` (no RWA, no information lost).
                  Primary use: combined with the ``magnus2``/``magnus4``
                  jaqsi solvers, the explicit slow/fast decomposition
                  is sometimes numerically better-conditioned and lets
                  the user pick a fixed grid based on the slow
                  frequency alone (``Δ = |ω_c-ω_q|``) when the fast
                  ``(ω_c+ω_q)`` mode is well-resolved by the chosen
                  step.
                * ``"lab"``: the literal form
                  ``Ω(t) cos(ω_c t + φ) cos(ω_q t)`` (and the analogous
                  ``-sin`` term). Two trig multiplications per call;
                  contains all four product frequencies implicitly.

                The value is validated even when ``rwa=True``, but
                otherwise has no effect in that case.

        Returns:
            Tuple ``(coeff_RX_X, coeff_RX_Y, coeff_RY_X, coeff_RY_Y)``
            of coefficient functions for the X- and Y-components of the
            RX and RY interaction-picture Hamiltonians.

        Raises:
            ValueError: If *frame* is neither ``"lab"`` nor ``"drive"``.
        """
        if frame not in ("lab", "drive"):
            raise ValueError(f"Unknown frame {frame!r}; expected 'lab' or 'drive'.")
        if rwa:
            # RWA-truncated coefficients (no carrier, no fast factors).
            # H_I^RWA = (Ω(t)/2) [cos(φ) X + sin(φ) Y]; we keep the
            # ``p[-1]`` rotation-angle scaling so the calling
            # ParametrizedHamiltonian shape is unchanged.
            half = jnp.asarray(0.5)

            def _coeff_RX_X(p, t):
                t_c = t / 2
                env = envelope_fn(p, t, t_c)
                return half * env * p[-1]

            def _coeff_RX_Y(p, t):  # Y component vanishes for RX (φ=0)
                t_c = t / 2
                env = envelope_fn(p, t, t_c)
                return jnp.zeros_like(half * env * p[-1])

            def _coeff_RY_X(p, t):  # X component vanishes for RY (φ=π/2)
                t_c = t / 2
                env = envelope_fn(p, t, t_c)
                return jnp.zeros_like(half * env * p[-1])

            def _coeff_RY_Y(p, t):
                t_c = t / 2
                env = envelope_fn(p, t, t_c)
                return half * env * p[-1]

            return _coeff_RX_X, _coeff_RX_Y, _coeff_RY_X, _coeff_RY_Y

        if frame == "drive":
            # Drive-frame: same exact dynamics, expressed via the
            # product-to-sum identities so the slow ``Δ = ω_c - ω_q``
            # and fast ``Σ = ω_c + ω_q`` modes appear explicitly.
            # Mathematically identical to the ``lab`` branch below.
            #
            # Identities used:
            #   cos(ω_c t) cos(ω_q t)  =  ½[cos(Δ t) + cos(Σ t)]
            #   cos(ω_c t) sin(ω_q t)  =  ½[sin(Σ t) − sin(Δ t)]
            #  −sin(ω_c t) cos(ω_q t)  = −½[sin(Σ t) + sin(Δ t)]
            #  −sin(ω_c t) sin(ω_q t)  =  ½[cos(Σ t) − cos(Δ t)]
            # (RY uses cos(ω_c t + π/2) = −sin(ω_c t).)
            omega_d = omega_c - omega_q
            omega_s = omega_c + omega_q
            half = jnp.asarray(0.5)

            def _coeff_RX_X(p, t):
                t_c = t / 2
                env = envelope_fn(p, t, t_c)
                mod = half * (jnp.cos(omega_d * t) + jnp.cos(omega_s * t))
                return env * mod * p[-1]

            def _coeff_RX_Y(p, t):
                t_c = t / 2
                env = envelope_fn(p, t, t_c)
                mod = -half * (jnp.sin(omega_s * t) - jnp.sin(omega_d * t))
                return env * mod * p[-1]

            def _coeff_RY_X(p, t):
                t_c = t / 2
                env = envelope_fn(p, t, t_c)
                mod = -half * (jnp.sin(omega_s * t) + jnp.sin(omega_d * t))
                return env * mod * p[-1]

            def _coeff_RY_Y(p, t):
                t_c = t / 2
                env = envelope_fn(p, t, t_c)
                mod = -half * (jnp.cos(omega_s * t) - jnp.cos(omega_d * t))
                return env * mod * p[-1]

            return _coeff_RX_X, _coeff_RX_Y, _coeff_RY_X, _coeff_RY_Y

        # Lab frame (literal products).
        # RX uses carrier phase phi = 0 so that after RWA
        #   cos(ω_q τ)·cos(ω_q τ) averages to +1/2 → drives +X
        #   -cos(ω_q τ)·sin(ω_q τ) averages to 0   → Y cancels
        # giving H_I^RWA ≈ (Ω/2)·X → U ≈ exp(-iθ/2 X), matching op.RX.
        # The exact form below KEEPS the fast 2·ω_q components.
        def _coeff_RX_X(p, t):
            t_c = t / 2
            env = envelope_fn(p, t, t_c)
            carrier = jnp.cos(omega_c * t)
            return env * carrier * jnp.cos(omega_q * t) * p[-1]

        def _coeff_RX_Y(p, t):
            t_c = t / 2
            env = envelope_fn(p, t, t_c)
            carrier = jnp.cos(omega_c * t)
            return -env * carrier * jnp.sin(omega_q * t) * p[-1]

        # RY uses carrier phase phi = +pi/2 so the RWA component drives +Y.
        def _coeff_RY_X(p, t):
            t_c = t / 2
            env = envelope_fn(p, t, t_c)
            carrier = jnp.cos(omega_c * t + jnp.pi / 2)
            return env * carrier * jnp.cos(omega_q * t) * p[-1]

        def _coeff_RY_Y(p, t):
            t_c = t / 2
            env = envelope_fn(p, t, t_c)
            carrier = jnp.cos(omega_c * t + jnp.pi / 2)
            return -env * carrier * jnp.sin(omega_q * t) * p[-1]

        return _coeff_RX_X, _coeff_RX_Y, _coeff_RY_X, _coeff_RY_Y

631 

632 

class PulseInformation:
    """Stores pulse parameter counts and optimized pulse parameters.

    Call :meth:`set_envelope` to switch the active pulse shape. This
    rebuilds all :class:`PulseParams` trees so that parameter counts
    and defaults match the selected envelope.
    """

    DEFAULT_ENVELOPE: str = "drag"
    DEFAULT_RWA: bool = True
    DEFAULT_FRAME: str = "drive"
    LEAF_GATE_NAMES: Tuple[str, ...] = ("RX", "RY", "RZ", "CZ")

    # Optimized pulse parameters loaded by :meth:`update_params`, keyed by
    # gate name. Defined here so ``update_params`` has a target to write to.
    OPTIMIZED_PULSES: Dict[str, jnp.ndarray] = {}

    _envelope: str = DEFAULT_ENVELOPE
    # Whether to apply the rotating-wave approximation when building the
    # interaction-picture coefficient functions. ``True`` (the default,
    # ``DEFAULT_RWA``) drops the fast counter-rotating terms, which is much
    # cheaper to integrate; ``False`` keeps the exact dynamics.
    # See :meth:`PulseEnvelope.build_coeff_fns`.
    _rwa: bool = DEFAULT_RWA
    # Algebraic representation of the (non-RWA) coefficients: ``"lab"`` or
    # ``"drive"`` (product-to-sum decomposition). Both are mathematically
    # equivalent; see :meth:`PulseEnvelope.build_coeff_fns` for when
    # ``"drive"`` is numerically advantageous (mainly with the Magnus solvers).
    _frame: str = DEFAULT_FRAME

    @classmethod
    def _build_leaf_gates(cls):
        """(Re-)create leaf PulseParams from the active envelope defaults.

        RX/RY take their defaults from the active envelope; RZ/CZ always
        come from the ``"general"`` registry entry.
        """
        defaults = PulseEnvelope.get(cls._envelope)["defaults"]
        general = PulseEnvelope.get("general")["defaults"]

        cls.RX = PulseParams(name="RX", params=defaults["RX"])
        cls.RY = PulseParams(name="RY", params=defaults["RY"])

        cls.RZ = PulseParams(name="RZ", params=general["RZ"])
        cls.CZ = PulseParams(name="CZ", params=general["CZ"])

    @classmethod
    def _build_composite_gates(cls):
        """(Re-)create composite PulseParams trees from current leaves.

        Must run after :meth:`_build_leaf_gates`, since every tree is built
        from the current ``cls.RX``/``cls.RY``/``cls.RZ``/``cls.CZ`` objects.
        """
        cls.H = PulseParams(
            name="H",
            decomposition=[
                DecompositionStep(cls.RZ, "all", lambda w: jnp.pi),
                DecompositionStep(cls.RY, "all", lambda w: jnp.pi / 2),
            ],
        )
        cls.CX = PulseParams(
            name="CX",
            decomposition=[
                DecompositionStep(cls.H, "target", lambda w: 0.0),
                DecompositionStep(cls.CZ, "all", lambda w: 0.0),
                DecompositionStep(cls.H, "target", lambda w: 0.0),
            ],
        )
        cls.CY = PulseParams(
            name="CY",
            decomposition=[
                DecompositionStep(cls.RZ, "target", lambda w: -jnp.pi / 2),
                DecompositionStep(cls.CX, "all"),
                DecompositionStep(cls.RZ, "target", lambda w: jnp.pi / 2),
            ],
        )
        cls.CRX = PulseParams(
            name="CRX",
            decomposition=[
                DecompositionStep(cls.RZ, "target", lambda w: jnp.pi / 2),
                DecompositionStep(cls.RY, "target", lambda w: w / 2),
                DecompositionStep(cls.CX, "all", lambda w: 0.0),
                DecompositionStep(cls.RY, "target", lambda w: -w / 2),
                DecompositionStep(cls.CX, "all", lambda w: 0.0),
                DecompositionStep(cls.RZ, "target", lambda w: -jnp.pi / 2),
            ],
        )
        cls.CRY = PulseParams(
            name="CRY",
            decomposition=[
                DecompositionStep(cls.RY, "target", lambda w: w / 2),
                DecompositionStep(cls.CX, "all", lambda w: 0.0),
                DecompositionStep(cls.RY, "target", lambda w: -w / 2),
                DecompositionStep(cls.CX, "all", lambda w: 0.0),
            ],
        )
        cls.CRZ = PulseParams(
            name="CRZ",
            decomposition=[
                DecompositionStep(cls.RZ, "target", lambda w: w / 2),
                DecompositionStep(cls.CX, "all", lambda w: 0.0),
                DecompositionStep(cls.RZ, "target", lambda w: -w / 2),
                DecompositionStep(cls.CX, "all", lambda w: 0.0),
            ],
        )
        # TODO: check if we could just make this a basis gate instead
        cls.CPhase = PulseParams(
            name="CPhase",
            decomposition=[
                DecompositionStep(cls.RZ, "control", lambda w: w / 2),
                DecompositionStep(cls.RZ, "target", lambda w: w / 2),
                DecompositionStep(cls.CX, "all", lambda w: 0.0),
                DecompositionStep(cls.RZ, "target", lambda w: -w / 2),
                DecompositionStep(cls.CX, "all", lambda w: 0.0),
            ],
        )
        cls.RZZ = PulseParams(
            name="RZZ",
            decomposition=[
                DecompositionStep(cls.CX, "all", lambda w: 0.0),
                DecompositionStep(cls.RZ, "target", lambda w: w),
                DecompositionStep(cls.CX, "all", lambda w: 0.0),
            ],
        )
        cls.RXX = PulseParams(
            name="RXX",
            decomposition=[
                DecompositionStep(cls.H, "control", lambda w: 0.0),
                DecompositionStep(cls.H, "target", lambda w: 0.0),
                DecompositionStep(cls.CX, "all", lambda w: 0.0),
                DecompositionStep(cls.RZ, "target", lambda w: w),
                DecompositionStep(cls.CX, "all", lambda w: 0.0),
                DecompositionStep(cls.H, "control", lambda w: 0.0),
                DecompositionStep(cls.H, "target", lambda w: 0.0),
            ],
        )
        cls.RYY = PulseParams(
            name="RYY",
            decomposition=[
                DecompositionStep(cls.RX, "control", lambda w: jnp.pi / 2),
                DecompositionStep(cls.RX, "target", lambda w: jnp.pi / 2),
                DecompositionStep(cls.CX, "all", lambda w: 0.0),
                DecompositionStep(cls.RZ, "target", lambda w: w),
                DecompositionStep(cls.CX, "all", lambda w: 0.0),
                DecompositionStep(cls.RX, "control", lambda w: -jnp.pi / 2),
                DecompositionStep(cls.RX, "target", lambda w: -jnp.pi / 2),
            ],
        )
        cls.RZX = PulseParams(
            name="RZX",
            decomposition=[
                DecompositionStep(cls.H, "target", lambda w: 0.0),
                DecompositionStep(cls.CX, "all", lambda w: 0.0),
                DecompositionStep(cls.RZ, "target", lambda w: w),
                DecompositionStep(cls.CX, "all", lambda w: 0.0),
                DecompositionStep(cls.H, "target", lambda w: 0.0),
            ],
        )
        cls.Rot = PulseParams(
            name="Rot",
            decomposition=[
                DecompositionStep(cls.RZ, "all", lambda w: w[0]),
                DecompositionStep(cls.RY, "all", lambda w: w[1]),
                DecompositionStep(cls.RZ, "all", lambda w: w[2]),
            ],
        )
        cls.unique_gate_set = [cls.RX, cls.RY, cls.RZ, cls.CZ]

    @classmethod
    def set_envelope(
        cls,
        name: str,
        rwa: Optional[bool] = None,
        frame: Optional[str] = None,
    ) -> None:
        """Switch pulse envelope and rebuild all PulseParams trees.

        Also updates the coefficient functions used by :class:`PulseGates`
        and clears the ``jaqsi`` evolve-solver cache.

        Args:
            name: One of :meth:`PulseEnvelope.available`.
            rwa: If given, also update the RWA flag. If ``None`` (the
                default), the current value of ``cls._rwa`` is kept.
                See :meth:`PulseEnvelope.build_coeff_fns` for the
                physical meaning of the flag.
            frame: If given, also update the coefficient frame
                (``"lab"`` or ``"drive"``). ``None`` keeps the current
                value of ``cls._frame``. Has no effect on the dynamics
                while the RWA flag is on.

        Raises:
            ValueError: If *name* is not a registered envelope or *frame*
                is not ``"lab"``/``"drive"``. Validation happens before
                any class state is modified.
        """
        info = PulseEnvelope.get(name)  # validates name
        # Validate *all* arguments before mutating class state, so an invalid
        # frame cannot leave the envelope/RWA flags updated while the gate
        # trees and coefficient functions still reflect the old setup.
        if frame is not None and frame not in ("lab", "drive"):
            raise ValueError(f"Unknown frame {frame!r}; expected 'lab' or 'drive'.")

        cls._envelope = name
        if rwa is not None:
            cls._rwa = bool(rwa)
        if frame is not None:
            cls._frame = frame
        cls._build_leaf_gates()
        cls._build_composite_gates()

        # Rebuild interaction-picture coefficient functions on PulseGates.
        # Four functions: (RX_X, RX_Y, RY_X, RY_Y) — one per (gate, Pauli)
        # component of the proper interaction-picture drive Hamiltonian.
        rx_x, rx_y, ry_x, ry_y = PulseEnvelope.build_coeff_fns(
            info["fn"],
            PulseGates.omega_c,
            PulseGates.omega_q,
            rwa=cls._rwa,
            frame=cls._frame,
        )
        PulseGates._coeff_RX_X = staticmethod(rx_x)
        PulseGates._coeff_RX_Y = staticmethod(rx_y)
        PulseGates._coeff_RY_X = staticmethod(ry_x)
        PulseGates._coeff_RY_Y = staticmethod(ry_y)
        # Backward-compat aliases for older introspection (point at the
        # X-component which dominates RX, Y-component which dominates RY).
        PulseGates._coeff_Sx = staticmethod(rx_x)
        PulseGates._coeff_Sy = staticmethod(ry_y)
        PulseGates._active_envelope = name
        PulseGates._active_rwa = cls._rwa
        PulseGates._active_frame = cls._frame

        # The compiled-solver cache in ``Evolution`` is keyed on the code
        # objects of the coefficient functions. Rebuilding the coeff
        # fns above produced fresh code objects, so any cached solver
        # is now unreachable from the live coefficient functions and
        # must be evicted to avoid both (a) holding compiled programs
        # for a previous configuration alive forever and (b) returning
        # a stale program if ``id`` collisions ever leaked through.
        js.Evolution.clear_evolve_solver_cache()

        log.info(
            f"Pulse envelope set to '{name}' "
            f"(RWA {'on' if cls._rwa else 'off'}, frame={cls._frame})"
        )

    @classmethod
    def set_rwa(cls, rwa: bool) -> None:
        """Toggle the rotating-wave approximation for pulse coefficients.

        Rebuilds the coefficient functions for the currently active
        envelope so the change takes effect immediately. The class default
        is ``True`` (RWA on, see ``DEFAULT_RWA``); ``False`` selects the
        exact interaction picture.
        See :meth:`PulseEnvelope.build_coeff_fns` for details.
        """
        cls.set_envelope(cls._envelope, rwa=bool(rwa))

    @classmethod
    def get_envelope(cls) -> str:
        """Return the name of the active pulse envelope."""
        return cls._envelope

    @classmethod
    def get_rwa(cls) -> bool:
        """Return whether the RWA flag is currently active."""
        return cls._rwa

    @classmethod
    def set_frame(cls, frame: str) -> None:
        """Switch the algebraic representation of the (non-RWA) coefficients.

        ``"drive"`` (the class default, see ``DEFAULT_FRAME``) and ``"lab"``
        are mathematically identical (no information lost, no RWA applied);
        see :meth:`PulseEnvelope.build_coeff_fns` for when ``"drive"`` is
        useful. Rebuilds the coefficient functions for the currently
        active envelope so the change takes effect immediately.

        Raises:
            ValueError: If *frame* is not ``"lab"`` or ``"drive"``.
        """
        cls.set_envelope(cls._envelope, frame=str(frame))

    @classmethod
    def get_frame(cls) -> str:
        """Return the active coefficient frame (``"lab"`` or ``"drive"``)."""
        return cls._frame

    @classmethod
    def snapshot_state(cls) -> "PulseStateSnapshot":
        """Return an immutable snapshot of the active pulse configuration.

        Leaf-gate parameters are copied (``jnp.array``) so later in-place
        reassignment of gate parameters does not alter the snapshot.
        Leaf gates that have not been built yet are omitted.
        """
        leaf_params = {}
        for name in cls.LEAF_GATE_NAMES:
            gate = getattr(cls, name, None)
            if gate is not None:
                leaf_params[name] = jnp.array(gate.params)

        return PulseStateSnapshot(
            envelope=cls._envelope,
            rwa=cls._rwa,
            frame=cls._frame,
            leaf_params=leaf_params,
        )

    @classmethod
    def restore_state(cls, snapshot: "PulseStateSnapshot") -> None:
        """Restore a snapshot produced by :meth:`snapshot_state`.

        Raises:
            ValueError: If the snapshot names an unknown/non-leaf gate or a
                parameter array whose shape does not match the active gate.
        """
        cls.set_envelope(snapshot.envelope, rwa=snapshot.rwa, frame=snapshot.frame)

        for name, params in snapshot.leaf_params.items():
            gate = cls.gate_by_name(name)
            if gate is None or not gate.is_leaf:
                raise ValueError(f"Cannot restore unknown leaf pulse gate {name!r}.")
            if gate.params.shape != params.shape:
                raise ValueError(
                    f"Snapshot for {name!r} has shape {params.shape}, "
                    f"but active gate expects {gate.params.shape}."
                )
            gate.params = params

    @classmethod
    @contextmanager
    def preserve_state(cls):
        """Temporarily preserve global pulse state across scoped mutations.

        Yields the snapshot taken on entry; the state is restored on exit,
        including when the body raises.
        """
        snapshot = cls.snapshot_state()
        try:
            yield snapshot
        finally:
            cls.restore_state(snapshot)

    @classmethod
    def reset_defaults(
        cls,
        envelope: Optional[str] = None,
        rwa: Optional[bool] = None,
        frame: Optional[str] = None,
    ) -> None:
        """Reset pulse globals to canonical defaults or explicit values."""
        cls.set_envelope(
            cls.DEFAULT_ENVELOPE if envelope is None else envelope,
            rwa=cls.DEFAULT_RWA if rwa is None else rwa,
            frame=cls.DEFAULT_FRAME if frame is None else frame,
        )

    @staticmethod
    def gate_by_name(gate):
        """Return the PulseParams attribute for *gate*, or ``None``.

        *gate* may be a name string or any object with a ``__name__``
        attribute (e.g. a gate function or class).
        """
        if isinstance(gate, str):
            return getattr(PulseInformation, gate, None)
        else:
            return getattr(PulseInformation, gate.__name__, None)

    @staticmethod
    def num_params(gate):
        """Return the number of pulse parameters of *gate* (see ``gate_by_name``)."""
        return len(PulseInformation.gate_by_name(gate))

    @staticmethod
    def update_params(path=f"{os.getcwd()}/qml_essentials/qoc_results.csv"):
        """Load optimized pulse parameters from a CSV file.

        Each non-empty row is read as ``gate_name, fidelity, p0, p1, ...``
        and stored as a ``jnp.array`` in :attr:`OPTIMIZED_PULSES` under
        ``gate_name``. Blank rows are skipped. If *path* does not exist, an
        error is logged and nothing is loaded.

        Note: the default *path* is evaluated once, when this module is
        imported, relative to the working directory at that moment.
        """
        if not os.path.isfile(path):
            log.error(f"No optimized pulses found at {path}")
            return

        log.info(f"Loading optimized pulses from {path}")
        # The csv module documents ``newline=""`` for files passed to reader.
        with open(path, "r", newline="") as f:
            for row in csv.reader(f):
                if not row:
                    # Blank lines (e.g. trailing newlines) yield empty rows.
                    continue
                log.debug(
                    f"Loading optimized pulses for {row[0]} "
                    f"(Fidelity: {float(row[1]):.5f}): {row[2:]}"
                )
                PulseInformation.OPTIMIZED_PULSES[row[0]] = jnp.array(
                    [float(x) for x in row[2:]]
                )

    @staticmethod
    def shuffle_params(random_key):
        """Replace the params of every gate in ``unique_gate_set`` with
        uniform random values, drawing a fresh sub-key per gate."""
        log.info(
            f"Shuffling optimized pulses with random key {random_key} "
            f"of gates {PulseInformation.unique_gate_set}"
        )
        for gate in PulseInformation.unique_gate_set:
            random_key, sub_key = safe_random_split(random_key)
            gate.params = jax.random.uniform(sub_key, (len(gate),))

991 

992 

993class PulseGates: 

994 """Pulse-level implementations of quantum gates. 

995 

996 Implements quantum gates using time-dependent Hamiltonians and pulse 

997 sequences, following the approach from https://doi.org/10.5445/IR/1000184129. 

998 The active pulse envelope is selected via 

999 :meth:`PulseInformation.set_envelope`. 

1000 

1001 Attributes: 

1002 omega_q: Qubit frequency (10π). 

1003 omega_c: Carrier frequency (10π). 

1004 _active_envelope: Name of the currently active envelope shape. 

1005 """ 

1006 

1007 # NOTE: Implementation of S, RX, RY, RZ, CZ, CNOT/CX and H pulse level 

1008 # gates closely follow https://doi.org/10.5445/IR/1000184129 

1009 omega_q = 10 * jnp.pi 

1010 omega_c = 10 * jnp.pi 

1011 

1012 X = jnp.array([[0, 1], [1, 0]]) 

1013 Y = jnp.array([[0, -1j], [1j, 0]]) 

1014 Z = jnp.array([[1, 0], [0, -1]]) 

1015 

1016 Id = jnp.eye(2, dtype=jnp.complex64) 

1017 

1018 _H_CZ = (jnp.pi / 4) * ( 

1019 jnp.kron(Id, Id) - jnp.kron(Z, Id) - jnp.kron(Id, Z) + jnp.kron(Z, Z) 

1020 ) 

1021 

1022 _H_corr = jnp.pi / 2 * jnp.eye(2, dtype=jnp.complex64) 

1023 

1024 _active_envelope: str = "gaussian" 

1025 # Mirrors :attr:`PulseInformation._rwa`; kept here for introspection 

1026 # of which coefficient regime the active ``_coeff_*`` functions 

1027 # implement. Updated by :meth:`PulseInformation.set_envelope` / 

1028 # :meth:`PulseInformation.set_rwa`. 

1029 _active_rwa: bool = True 

1030 _active_frame: str = "drive" 

1031 

1032 # Default coefficient functions for the gaussian envelope; the active 

1033 # envelope's `set_envelope` will overwrite these. Each gate uses two 

1034 # coefficients (X- and Y-component of the proper interaction-picture 

1035 # drive Hamiltonian). 

1036 

1037 @staticmethod 

1038 def _coeff_RX_X(p, t): 

1039 """RX coefficient for the X term (gaussian default).""" 

1040 t_c = t / 2 

1041 env = PulseEnvelope.gaussian(p, t, t_c) 

1042 carrier = jnp.cos(PulseGates.omega_c * t) 

1043 return env * carrier * jnp.cos(PulseGates.omega_q * t) * p[-1] 

1044 

1045 @staticmethod 

1046 def _coeff_RX_Y(p, t): 

1047 """RX coefficient for the Y term (gaussian default).""" 

1048 t_c = t / 2 

1049 env = PulseEnvelope.gaussian(p, t, t_c) 

1050 carrier = jnp.cos(PulseGates.omega_c * t) 

1051 return -env * carrier * jnp.sin(PulseGates.omega_q * t) * p[-1] 

1052 

1053 @staticmethod 

1054 def _coeff_RY_X(p, t): 

1055 """RY coefficient for the X term (gaussian default).""" 

1056 t_c = t / 2 

1057 env = PulseEnvelope.gaussian(p, t, t_c) 

1058 carrier = jnp.cos(PulseGates.omega_c * t + jnp.pi / 2) 

1059 return env * carrier * jnp.cos(PulseGates.omega_q * t) * p[-1] 

1060 

1061 @staticmethod 

1062 def _coeff_RY_Y(p, t): 

1063 """RY coefficient for the Y term (gaussian default).""" 

1064 t_c = t / 2 

1065 env = PulseEnvelope.gaussian(p, t, t_c) 

1066 carrier = jnp.cos(PulseGates.omega_c * t + jnp.pi / 2) 

1067 return -env * carrier * jnp.sin(PulseGates.omega_q * t) * p[-1] 

1068 

1069 # Backward-compat aliases (resolve to the dominant component of each gate). 

1070 _coeff_Sx = _coeff_RX_X 

1071 _coeff_Sy = _coeff_RY_Y 

1072 

1073 @staticmethod 

1074 def _coeff_Sz(p, t): 

1075 """Coefficient function for RZ pulse: p * w.""" 

1076 return p[0] * p[1] 

1077 

1078 @staticmethod 

1079 def _coeff_Sc(p, t): 

1080 """Constant coefficient for H correction phase.""" 

1081 return -1.0 

1082 

1083 @staticmethod 

1084 def _coeff_Scz(p, t): 

1085 """Coefficient function for CZ pulse.""" 

1086 return p * jnp.pi 

1087 

1088 @staticmethod 

1089 def _record_pulse_event(gate_name, w, wires, pulse_params, parent=None): 

1090 """Append a PulseEvent to the active pulse tape if recording. 

1091 

1092 This is called from leaf gate methods (RX, RY, RZ, CZ) so that 

1093 :func:`~qml_essentials.tape.pulse_recording` can collect events 

1094 without the caller needing to know about the tape. 

1095 """ 

1096 ptape = active_pulse_tape() 

1097 if ptape is None: 

1098 return 

1099 

1100 from qml_essentials.drawing import PulseEvent, LEAF_META 

1101 

1102 meta = LEAF_META.get(gate_name, {}) 

1103 wires_list = [wires] if isinstance(wires, int) else list(wires) 

1104 

1105 if meta.get("physical", False): 

1106 info = PulseEnvelope.get(PulseInformation.get_envelope()) 

1107 pp = PulseInformation.gate_by_name(gate_name).split_params(pulse_params) 

1108 env_p = pp[:-1] 

1109 dur = float(pp[-1]) 

1110 ptape.append( 

1111 PulseEvent( 

1112 gate=gate_name, 

1113 wires=wires_list, 

1114 envelope_fn=info["fn"], 

1115 envelope_params=jnp.array(env_p), 

1116 w=float(w), 

1117 duration=dur, 

1118 carrier_phase=meta["carrier_phase"], 

1119 parent=parent, 

1120 ) 

1121 ) 

1122 else: 

1123 pp = PulseInformation.gate_by_name(gate_name).split_params(pulse_params) 

1124 ptape.append( 

1125 PulseEvent( 

1126 gate=gate_name, 

1127 wires=wires_list, 

1128 envelope_fn=None, 

1129 envelope_params=jnp.ravel(jnp.asarray(pp)), 

1130 w=float(w) if not isinstance(w, list) else 0.0, 

1131 duration=1.0, 

1132 carrier_phase=0.0, 

1133 parent=parent, 

1134 ) 

1135 ) 

1136 

1137 @staticmethod 

1138 def Rot( 

1139 phi: float, 

1140 theta: float, 

1141 omega: float, 

1142 wires: Union[int, List[int]], 

1143 pulse_params: Optional[jnp.ndarray] = None, 

1144 noise_params: Optional[Dict[str, float]] = None, 

1145 random_key: Optional[jax.random.PRNGKey] = None, 

1146 ) -> None: 

1147 """ 

1148 Apply general rotation via decomposition: RZ(phi) · RY(theta) · RZ(omega). 

1149 

1150 Args: 

1151 phi (float): First rotation angle. 

1152 theta (float): Second rotation angle. 

1153 omega (float): Third rotation angle. 

1154 wires (Union[int, List[int]]): Qubit index or indices to apply rotation to. 

1155 pulse_params (Optional[jnp.ndarray]): Pulse parameters for the 

1156 composing gates. If None, uses optimized parameters. 

1157 noise_params (Optional[Dict[str, float]]): Noise parameters dictionary. 

1158 random_key (Optional[jax.random.PRNGKey]): JAX random key for compatibility 

1159 

1160 Returns: 

1161 None: Gates are applied in-place to the circuit. 

1162 """ 

1163 if noise_params is not None and "GateError" in noise_params: 

1164 phi, random_key = UnitaryGates.GateError(phi, noise_params, random_key) 

1165 theta, random_key = UnitaryGates.GateError(theta, noise_params, random_key) 

1166 omega, random_key = UnitaryGates.GateError(omega, noise_params, random_key) 

1167 PulseGates._execute_composite("Rot", [phi, theta, omega], wires, pulse_params) 

1168 UnitaryGates.Noise(wires, noise_params) 

1169 

1170 @staticmethod 

1171 def PauliRot( 

1172 pauli: str, 

1173 theta: float, 

1174 wires: Union[int, List[int]], 

1175 pulse_params: Optional[jnp.ndarray] = None, 

1176 noise_params: Optional[Dict[str, float]] = None, 

1177 random_key: Optional[jax.random.PRNGKey] = None, 

1178 ) -> None: 

1179 """Not implemented as a PulseGate.""" 

1180 raise NotImplementedError("PauliRot gate is not implemented as PulseGate") 

1181 

1182 @staticmethod 

1183 def RX( 

1184 w: float, 

1185 wires: Union[int, List[int]], 

1186 pulse_params: Optional[jnp.ndarray] = None, 

1187 noise_params: Optional[Dict[str, float]] = None, 

1188 random_key: Optional[jax.random.PRNGKey] = None, 

1189 ) -> None: 

1190 """Apply X-axis rotation using the active pulse envelope. 

1191 

1192 Args: 

1193 w: Rotation angle in radians. 

1194 wires: Qubit index or indices. 

1195 pulse_params: Envelope parameters ``[env_0, ..., env_n, t]``. 

1196 If ``None``, uses optimized defaults. 

1197 noise_params (Optional[Dict[str, float]]): Noise parameters dictionary. 

1198 random_key (Optional[jax.random.PRNGKey]): JAX random key for compatibility 

1199 """ 

1200 pulse_params = PulseInformation.RX.split_params(pulse_params) 

1201 

1202 PulseGates._record_pulse_event("RX", w, wires, pulse_params) 

1203 t = pulse_params[-1] 

1204 

1205 # Proper interaction-picture drive Hamiltonian for RX: 

1206 # H_I(τ) = Ω(τ)·cos(ω_c·τ) · [ cos(ω_q·τ)·X − sin(ω_q·τ)·Y ] 

1207 # which on resonance averages (RWA) to +(Ω/2)·X while the 

1208 # 2·ω_q counter-rotating part oscillates and cancels. 

1209 H_X = js.Hamiltonian(PulseGates.X, wires=wires) 

1210 H_Y = js.Hamiltonian(PulseGates.Y, wires=wires) 

1211 H_eff = PulseGates._coeff_RX_X * H_X + PulseGates._coeff_RX_Y * H_Y 

1212 

1213 # Pack: [envelope_params..., w] - evolution time is the last element 

1214 # of pulse_params (pulse_params[-1]). 

1215 w, random_key = UnitaryGates.GateError(w, noise_params, random_key) 

1216 # Use jnp.concatenate over Python list-splat to keep the trace graph 

1217 # compact (no per-element unpacking + restack). 

1218 env_params = jnp.concatenate( 

1219 [jnp.ravel(pulse_params[:-1]), jnp.ravel(jnp.asarray(w))] 

1220 ) 

1221 # Both terms share the same parameter array. 

1222 H_eff.evolve(name="RX")([env_params, env_params], t) 

1223 UnitaryGates.Noise(wires, noise_params) 

1224 

1225 @staticmethod 

1226 def RY( 

1227 w: float, 

1228 wires: Union[int, List[int]], 

1229 pulse_params: Optional[jnp.ndarray] = None, 

1230 noise_params: Optional[Dict[str, float]] = None, 

1231 random_key: Optional[jax.random.PRNGKey] = None, 

1232 ) -> None: 

1233 """Apply Y-axis rotation using the active pulse envelope. 

1234 

1235 Args: 

1236 w: Rotation angle in radians. 

1237 wires: Qubit index or indices. 

1238 pulse_params: Envelope parameters ``[env_0, ..., env_n, t]``. 

1239 If ``None``, uses optimized defaults. 

1240 noise_params (Optional[Dict[str, float]]): Noise parameters dictionary. 

1241 random_key (Optional[jax.random.PRNGKey]): JAX random key for compatibility 

1242 """ 

1243 pulse_params = PulseInformation.RY.split_params(pulse_params) 

1244 

1245 PulseGates._record_pulse_event("RY", w, wires, pulse_params) 

1246 t = pulse_params[-1] 

1247 

1248 # See NOTE in RX: same proper interaction-picture form, with 

1249 # carrier phase ϕ = +π/2 so the slow RWA component drives +Y. 

1250 H_X = js.Hamiltonian(PulseGates.X, wires=wires) 

1251 H_Y = js.Hamiltonian(PulseGates.Y, wires=wires) 

1252 H_eff = PulseGates._coeff_RY_X * H_X + PulseGates._coeff_RY_Y * H_Y 

1253 

1254 # Pack w into the params so the coefficient function doesn't need 

1255 # a closure - this enables JIT solver cache sharing across all RY calls. 

1256 w, random_key = UnitaryGates.GateError(w, noise_params, random_key) 

1257 env_params = jnp.concatenate( 

1258 [jnp.ravel(pulse_params[:-1]), jnp.ravel(jnp.asarray(w))] 

1259 ) 

1260 H_eff.evolve(name="RY")([env_params, env_params], t) 

1261 UnitaryGates.Noise(wires, noise_params) 

1262 

1263 @staticmethod 

1264 def RZ( 

1265 w: float, 

1266 wires: Union[int, List[int]], 

1267 pulse_params: Optional[float] = None, 

1268 noise_params: Optional[Dict[str, float]] = None, 

1269 random_key: Optional[jax.random.PRNGKey] = None, 

1270 ) -> None: 

1271 """ 

1272 Apply Z-axis rotation using pulse-level implementation. 

1273 

1274 Implements RZ rotation using virtual Z rotations (phase tracking) 

1275 without physical pulse application. 

1276 

1277 Args: 

1278 w (float): Rotation angle in radians. 

1279 wires (Union[int, List[int]]): Qubit index or indices to apply rotation to. 

1280 pulse_params (Optional[float]): Duration parameter for the pulse. 

1281 Rotation angle = w * 2 * pulse_params. Defaults to 0.5 if None. 

1282 noise_params (Optional[Dict[str, float]]): Noise parameters dictionary. 

1283 random_key (Optional[jax.random.PRNGKey]): JAX random key for compatibility 

1284 

1285 Returns: 

1286 None: Gate is applied in-place to the circuit. 

1287 """ 

1288 pulse_params = PulseInformation.RZ.split_params(pulse_params) 

1289 

1290 PulseGates._record_pulse_event("RZ", w, wires, pulse_params) 

1291 

1292 _H = js.Hamiltonian(PulseGates.Z, wires=wires) 

1293 H_eff = PulseGates._coeff_Sz * _H 

1294 

1295 # Pack w into the params so the coefficient function doesn't need 

1296 # a closure - [pulse_param_scalar, w] enables JIT solver cache sharing. 

1297 # pulse_params may be a 1-element array or scalar; ravel + slice the first 

1298 # element to preserve the original semantics, then concatenate with w. 

1299 w, random_key = UnitaryGates.GateError(w, noise_params, random_key) 

1300 pp_flat = jnp.ravel(jnp.asarray(pulse_params)) 

1301 H_eff.evolve(name="RZ")( 

1302 [jnp.concatenate([pp_flat[:1], jnp.ravel(jnp.asarray(w))])], 

1303 1, 

1304 ) 

1305 

1306 UnitaryGates.Noise(wires, noise_params) 

1307 

1308 @staticmethod 

1309 def _resolve_wires(wire_fn, wires): 

1310 """Resolve a wire selector string to actual wire(s). 

1311 

1312 Args: 

1313 wire_fn: ``"all"``, ``"target"``, or ``"control"``. 

1314 wires: Parent gate's wire(s) (int or list). 

1315 

1316 Returns: 

1317 Wire(s) for the child gate. 

1318 """ 

1319 wires_list = [wires] if isinstance(wires, int) else list(wires) 

1320 if wire_fn == "all": 

1321 return wires if len(wires_list) > 1 else wires_list[0] 

1322 if wire_fn == "target": 

1323 return wires_list[-1] if len(wires_list) > 1 else wires_list[0] 

1324 if wire_fn == "control": 

1325 return wires_list[0] 

1326 raise ValueError(f"Unknown wire_fn: {wire_fn!r}") 

1327 

1328 @staticmethod 

1329 def _execute_composite(gate_name, w, wires, pulse_params=None): 

1330 """Execute a composite gate by walking its decomposition. 

1331 

1332 Reads the :class:`DecompositionStep` list from 

1333 :class:`PulseInformation` and dispatches each step to the 

1334 appropriate ``PulseGates`` method. 

1335 

1336 Args: 

1337 gate_name: Gate name (e.g. ``"H"``, ``"CX"``). 

1338 w: Rotation angle(s) passed to the parent gate. 

1339 wires: Wire(s) of the parent gate. 

1340 pulse_params: Optional pulse parameters (split across children). 

1341 """ 

1342 pp_obj = PulseInformation.gate_by_name(gate_name) 

1343 parts = pp_obj.split_params(pulse_params) 

1344 

1345 for step, child_params in zip(pp_obj.decomposition, parts): 

1346 child_wires = PulseGates._resolve_wires(step.wire_fn, wires) 

1347 child_w = step.angle_fn(w) if step.angle_fn is not None else w 

1348 child_gate = getattr(PulseGates, step.gate.name) 

1349 

1350 # Leaf gates that take a rotation angle 

1351 if step.gate.name in ("RX", "RY", "RZ"): 

1352 child_gate(child_w, wires=child_wires, pulse_params=child_params) 

1353 # Leaf gates without a rotation angle 

1354 elif step.gate.name in ("CZ",): 

1355 child_gate(wires=child_wires, pulse_params=child_params) 

1356 # Composite gates with a rotation angle (CRX, CRY, CRZ, Rot, ...) 

1357 elif step.gate.name in ("Rot",): 

1358 # Rot expects (phi, theta, omega, wires, ...) 

1359 child_gate(*child_w, wires=child_wires, pulse_params=child_params) 

1360 elif step.gate.decomposition is not None and step.gate.name in ( 

1361 "CRX", 

1362 "CRY", 

1363 "CRZ", 

1364 "CPhase", 

1365 "RXX", 

1366 "RYY", 

1367 "RZZ", 

1368 "RZX", 

1369 ): 

1370 child_gate(child_w, wires=child_wires, pulse_params=child_params) 

1371 # Other composite gates (H, CX, CY, ...) 

1372 else: 

1373 child_gate(wires=child_wires, pulse_params=child_params) 

1374 

1375 @staticmethod 

1376 def H( 

1377 wires: Union[int, List[int]], 

1378 pulse_params: Optional[jnp.ndarray] = None, 

1379 noise_params: Optional[Dict[str, float]] = None, 

1380 random_key: Optional[jax.random.PRNGKey] = None, 

1381 ) -> None: 

1382 """Apply Hadamard gate using pulse decomposition. 

1383 

1384 Decomposes as RZ(π) · RY(π/2) followed by a correction phase. 

1385 

1386 Args: 

1387 wires (Union[int, List[int]]): Qubit index or indices. 

1388 noise_params (Optional[Dict[str, float]]): Noise parameters dictionary. 

1389 random_key (Optional[jax.random.PRNGKey]): JAX random key for compatibility 

1390 (not used in this gate). 

1391 """ 

1392 PulseGates._execute_composite("H", 0.0, wires, pulse_params) 

1393 

1394 # Correction phase unique to the H gate 

1395 _H = js.Hamiltonian(PulseGates._H_corr, wires=wires) 

1396 H_corr = PulseGates._coeff_Sc * _H 

1397 H_corr.evolve(name="H")([0], 1) 

1398 UnitaryGates.Noise(wires, noise_params) 

1399 

1400 @staticmethod 

1401 def CX( 

1402 wires: List[int], 

1403 pulse_params: Optional[jnp.ndarray] = None, 

1404 noise_params: Optional[Dict[str, float]] = None, 

1405 random_key: Optional[jax.random.PRNGKey] = None, 

1406 ) -> None: 

1407 """Apply CNOT gate via decomposition: H(target) · CZ · H(target). 

1408 

1409 Args: 

1410 wires (List[int]): Control and target qubit indices [control, target]. 

1411 pulse_params (Optional[jnp.ndarray]): Pulse parameters for the 

1412 composing gates. If None, uses optimized parameters. 

1413 noise_params (Optional[Dict[str, float]]): Noise parameters dictionary. 

1414 random_key (Optional[jax.random.PRNGKey]): JAX random key for compatibility 

1415 (not used in this gate). 

1416 

1417 Returns: 

1418 None: Gate is applied in-place to the circuit. 

1419 """ 

1420 PulseGates._execute_composite("CX", 0.0, wires, pulse_params) 

1421 UnitaryGates.Noise(wires, noise_params) 

1422 

1423 @staticmethod 

1424 def CY( 

1425 wires: List[int], 

1426 pulse_params: Optional[jnp.ndarray] = None, 

1427 noise_params: Optional[Dict[str, float]] = None, 

1428 random_key: Optional[jax.random.PRNGKey] = None, 

1429 ) -> None: 

1430 """Apply controlled-Y via decomposition. 

1431 

1432 Args: 

1433 wires (List[int]): Control and target qubit indices [control, target]. 

1434 pulse_params (Optional[jnp.ndarray]): Pulse parameters for the 

1435 composing gates. If None, uses optimized parameters. 

1436 noise_params (Optional[Dict[str, float]]): Noise parameters dictionary. 

1437 random_key (Optional[jax.random.PRNGKey]): JAX random key for compatibility 

1438 (not used in this gate). 

1439 

1440 """ 

1441 PulseGates._execute_composite("CY", 0.0, wires, pulse_params) 

1442 UnitaryGates.Noise(wires, noise_params) 

1443 

1444 @staticmethod 

1445 def CZ( 

1446 wires: List[int], 

1447 pulse_params: Optional[float] = None, 

1448 noise_params: Optional[Dict[str, float]] = None, 

1449 random_key: Optional[jax.random.PRNGKey] = None, 

1450 ) -> None: 

1451 """Apply controlled-Z using ZZ coupling Hamiltonian. 

1452 

1453 Args: 

1454 wires (List[int]): Control and target qubit indices. 

1455 pulse_params (Optional[float]): Time or duration parameter for 

1456 the pulse evolution. If None, uses optimized value. 

1457 noise_params (Optional[Dict[str, float]]): Noise parameters dictionary. 

1458 random_key (Optional[jax.random.PRNGKey]): JAX random key for compatibility 

1459 (not used in this gate). 

1460 

1461 """ 

1462 if pulse_params is None: 

1463 pulse_params = PulseInformation.CZ.params 

1464 

1465 PulseGates._record_pulse_event("CZ", 0.0, wires, pulse_params) 

1466 

1467 _H = js.Hamiltonian(PulseGates._H_CZ, wires=wires) 

1468 H_eff = PulseGates._coeff_Scz * _H 

1469 H_eff.evolve(name="CZ")([pulse_params], 1) 

1470 UnitaryGates.Noise(wires, noise_params) 

1471 

1472 @staticmethod 

1473 def CRX( 

1474 w: float, 

1475 wires: List[int], 

1476 pulse_params: Optional[jnp.ndarray] = None, 

1477 noise_params: Optional[Dict[str, float]] = None, 

1478 random_key: Optional[jax.random.PRNGKey] = None, 

1479 ) -> None: 

1480 """Apply controlled-RX via decomposition. 

1481 

1482 Args: 

1483 w (float): Rotation angle in radians. 

1484 wires (List[int]): Control and target qubit indices [control, target]. 

1485 pulse_params (Optional[jnp.ndarray]): Pulse parameters for the 

1486 composing gates. If None, uses optimized parameters. 

1487 noise_params (Optional[Dict[str, float]]): Noise parameters dictionary. 

1488 random_key (Optional[jax.random.PRNGKey]): JAX random key for compatibility 

1489 (not used in this gate). 

1490 """ 

1491 PulseGates._execute_composite("CRX", w, wires, pulse_params) 

1492 UnitaryGates.Noise(wires, noise_params) 

1493 

1494 @staticmethod 

1495 def CRY( 

1496 w: float, 

1497 wires: List[int], 

1498 pulse_params: Optional[jnp.ndarray] = None, 

1499 noise_params: Optional[Dict[str, float]] = None, 

1500 random_key: Optional[jax.random.PRNGKey] = None, 

1501 ) -> None: 

1502 """Apply controlled-RY via decomposition. 

1503 

1504 Args: 

1505 w (float): Rotation angle in radians. 

1506 wires (List[int]): Control and target qubit indices [control, target]. 

1507 pulse_params (Optional[jnp.ndarray]): Pulse parameters for the 

1508 composing gates. If None, uses optimized parameters. 

1509 noise_params (Optional[Dict[str, float]]): Noise parameters dictionary. 

1510 random_key (Optional[jax.random.PRNGKey]): JAX random key for compatibility 

1511 """ 

1512 w, random_key = UnitaryGates.GateError(w, noise_params, random_key) 

1513 PulseGates._execute_composite("CRY", w, wires, pulse_params) 

1514 UnitaryGates.Noise(wires, noise_params) 

1515 

1516 @staticmethod 

1517 def CRZ( 

1518 w: float, 

1519 wires: List[int], 

1520 pulse_params: Optional[jnp.ndarray] = None, 

1521 noise_params: Optional[Dict[str, float]] = None, 

1522 random_key: Optional[jax.random.PRNGKey] = None, 

1523 ) -> None: 

1524 """Apply controlled-RZ via decomposition. 

1525 

1526 Args: 

1527 w (float): Rotation angle in radians. 

1528 wires (List[int]): Control and target qubit indices [control, target]. 

1529 pulse_params (Optional[jnp.ndarray]): Pulse parameters for the 

1530 composing gates. If None, uses optimized parameters. 

1531 noise_params (Optional[Dict[str, float]]): Noise parameters dictionary. 

1532 random_key (Optional[jax.random.PRNGKey]): JAX random key for compatibility 

1533 """ 

1534 w, random_key = UnitaryGates.GateError(w, noise_params, random_key) 

1535 PulseGates._execute_composite("CRZ", w, wires, pulse_params) 

1536 UnitaryGates.Noise(wires, noise_params) 

1537 

1538 @staticmethod 

1539 def CPhase( 

1540 w: float, 

1541 wires: List[int], 

1542 pulse_params: Optional[jnp.ndarray] = None, 

1543 noise_params: Optional[Dict[str, float]] = None, 

1544 random_key: Optional[jax.random.PRNGKey] = None, 

1545 ) -> None: 

1546 """Apply controlled phase shift via decomposition. 

1547 

1548 Decomposes CPhase(φ) into RZ and CX gates: 

1549 RZ(φ/2) on control, RZ(φ/2) on target, CX, RZ(-φ/2) on target, CX. 

1550 

1551 Args: 

1552 w (float): Phase shift angle in radians. 

1553 wires (List[int]): Control and target qubit indices [control, target]. 

1554 pulse_params (Optional[jnp.ndarray]): Pulse parameters for the 

1555 composing gates. If None, uses optimized parameters. 

1556 noise_params (Optional[Dict[str, float]]): Noise parameters dictionary. 

1557 random_key (Optional[jax.random.PRNGKey]): JAX random key for compatibility 

1558 """ 

1559 w, random_key = UnitaryGates.GateError(w, noise_params, random_key) 

1560 PulseGates._execute_composite("CPhase", w, wires, pulse_params) 

1561 UnitaryGates.Noise(wires, noise_params) 

1562 

1563 @staticmethod 

1564 def RXX( 

1565 w: float, 

1566 wires: List[int], 

1567 pulse_params: Optional[jnp.ndarray] = None, 

1568 noise_params: Optional[Dict[str, float]] = None, 

1569 random_key: Optional[jax.random.PRNGKey] = None, 

1570 ) -> None: 

1571 """Apply two-qubit RXX rotation via decomposition. 

1572 

1573 Implements ``RXX(theta) = exp(-i theta/2 X ⊗ X)`` as 

1574 ``(H ⊗ H) · RZZ(theta) · (H ⊗ H)``. 

1575 

1576 Args: 

1577 w (float): Rotation angle in radians. 

1578 wires (List[int]): Two qubit indices. 

1579 pulse_params (Optional[jnp.ndarray]): Pulse parameters for the 

1580 composing gates. If None, uses optimized parameters. 

1581 noise_params (Optional[Dict[str, float]]): Noise parameters dictionary. 

1582 random_key (Optional[jax.random.PRNGKey]): JAX random key for noise. 

1583 """ 

1584 w, random_key = UnitaryGates.GateError(w, noise_params, random_key) 

1585 PulseGates._execute_composite("RXX", w, wires, pulse_params) 

1586 UnitaryGates.Noise(wires, noise_params) 

1587 

1588 @staticmethod 

1589 def RYY( 

1590 w: float, 

1591 wires: List[int], 

1592 pulse_params: Optional[jnp.ndarray] = None, 

1593 noise_params: Optional[Dict[str, float]] = None, 

1594 random_key: Optional[jax.random.PRNGKey] = None, 

1595 ) -> None: 

1596 """Apply two-qubit RYY rotation via decomposition. 

1597 

1598 Implements ``RYY(theta) = exp(-i theta/2 Y ⊗ Y)`` by conjugating the 

1599 RZZ skeleton with ``RX(pi/2)`` rotations on both wires. 

1600 

1601 Args: 

1602 w (float): Rotation angle in radians. 

1603 wires (List[int]): Two qubit indices. 

1604 pulse_params (Optional[jnp.ndarray]): Pulse parameters for the 

1605 composing gates. If None, uses optimized parameters. 

1606 noise_params (Optional[Dict[str, float]]): Noise parameters dictionary. 

1607 random_key (Optional[jax.random.PRNGKey]): JAX random key for noise. 

1608 """ 

1609 w, random_key = UnitaryGates.GateError(w, noise_params, random_key) 

1610 PulseGates._execute_composite("RYY", w, wires, pulse_params) 

1611 UnitaryGates.Noise(wires, noise_params) 

1612 

1613 @staticmethod 

1614 def RZZ( 

1615 w: float, 

1616 wires: List[int], 

1617 pulse_params: Optional[jnp.ndarray] = None, 

1618 noise_params: Optional[Dict[str, float]] = None, 

1619 random_key: Optional[jax.random.PRNGKey] = None, 

1620 ) -> None: 

1621 """Apply two-qubit RZZ rotation via decomposition. 

1622 

1623 Implements ``RZZ(theta) = exp(-i theta/2 Z ⊗ Z)`` as 

1624 ``CX · RZ(theta)_target · CX``. 

1625 

1626 Args: 

1627 w (float): Rotation angle in radians. 

1628 wires (List[int]): Two qubit indices. 

1629 pulse_params (Optional[jnp.ndarray]): Pulse parameters for the 

1630 composing gates. If None, uses optimized parameters. 

1631 noise_params (Optional[Dict[str, float]]): Noise parameters dictionary. 

1632 random_key (Optional[jax.random.PRNGKey]): JAX random key for noise. 

1633 """ 

1634 w, random_key = UnitaryGates.GateError(w, noise_params, random_key) 

1635 PulseGates._execute_composite("RZZ", w, wires, pulse_params) 

1636 UnitaryGates.Noise(wires, noise_params) 

1637 

1638 @staticmethod 

1639 def RZX( 

1640 w: float, 

1641 wires: List[int], 

1642 pulse_params: Optional[jnp.ndarray] = None, 

1643 noise_params: Optional[Dict[str, float]] = None, 

1644 random_key: Optional[jax.random.PRNGKey] = None, 

1645 ) -> None: 

1646 """Apply two-qubit RZX rotation via decomposition. 

1647 

1648 Implements ``RZX(theta) = exp(-i theta/2 Z ⊗ X)`` (Z on the first 

1649 wire, X on the second) by conjugating the RZZ skeleton with a 

1650 Hadamard on the target wire. 

1651 

1652 Args: 

1653 w (float): Rotation angle in radians. 

1654 wires (List[int]): Two qubit indices ``[zwire, xwire]``. 

1655 pulse_params (Optional[jnp.ndarray]): Pulse parameters for the 

1656 composing gates. If None, uses optimized parameters. 

1657 noise_params (Optional[Dict[str, float]]): Noise parameters dictionary. 

1658 random_key (Optional[jax.random.PRNGKey]): JAX random key for noise. 

1659 """ 

1660 w, random_key = UnitaryGates.GateError(w, noise_params, random_key) 

1661 PulseGates._execute_composite("RZX", w, wires, pulse_params) 

1662 UnitaryGates.Noise(wires, noise_params) 

1663 

1664 

1665class PulseParamManager: 

1666 def __init__(self, pulse_params: jnp.ndarray): 

1667 self.pulse_params = pulse_params 

1668 self.idx = 0 

1669 

1670 def get(self, n: int): 

1671 """Return the next n parameters and advance the cursor.""" 

1672 if self.idx + n > len(self.pulse_params): 

1673 raise ValueError("Not enough pulse parameters left for this gate") 

1674 # TODO: we squeeze here to get rid of any extra hidden dimension 

1675 params = self.pulse_params[self.idx : self.idx + n].squeeze() 

1676 self.idx += n 

1677 return params 

1678 

1679 

1680# Initialise PulseInformation after PulseGates exists so leaf defaults, 

1681# composite trees, mirrored PulseGates flags, and coefficient functions agree. 

1682PulseInformation.reset_defaults()