Coverage for qml_essentials / pulses.py: 84%

532 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-05-16 10:19 +0000

1import os 

2from contextlib import contextmanager 

3from dataclasses import dataclass 

4from typing import Optional, List, Union, Dict, Callable, Tuple 

5import csv 

6import jax.numpy as jnp 

7import jax 

8 

9from qml_essentials import operations as op 

10from qml_essentials import yaqsi as ys 

11from qml_essentials.utils import safe_random_split 

12from qml_essentials.tape import active_pulse_tape 

13from qml_essentials.unitary import UnitaryGates 

14import logging 

15 

# Module-level logger used by the pulse classes in this file.
log = logging.getLogger(__name__)

17 

18 

@dataclass
class DecompositionStep:
    """A single stage of a composite pulse-gate decomposition.

    Attributes:
        gate: The child ``PulseParams`` node applied in this stage.
        wire_fn: Wire selector for the child: ``"all"``, ``"target"``
            or ``"control"``.
        angle_fn: Callable turning the parent angle(s) ``w`` into the
            child's angle. ``None`` forwards ``w`` unchanged.
    """

    gate: "PulseParams"  # child node applied in this stage
    wire_fn: str = "all"  # which wires the child acts on
    angle_fn: Optional[Callable] = None  # parent angle -> child angle

33 

34 

@dataclass(frozen=True)
class PulseStateSnapshot:
    """Immutable record of the global pulse settings at one point in time.

    Produced by ``PulseInformation.snapshot_state`` and consumed by
    ``PulseInformation.restore_state``. ``frozen=True`` forbids rebinding
    the fields; note that the ``leaf_params`` dict itself is still a
    regular (mutable) dict.
    """

    envelope: str  # active envelope name (a PulseEnvelope.REGISTRY key)
    rwa: bool  # rotating-wave approximation flag
    frame: str  # coefficient frame, "lab" or "drive"
    leaf_params: Dict[str, jnp.ndarray]  # leaf gate name -> parameter array

43 

44 

class PulseParams:
    """Container for hierarchical pulse parameters.

    Leaf nodes hold direct parameters; composite nodes hold a list of
    :class:`DecompositionStep` objects that describe how the gate is
    built from simpler gates.

    Attributes:
        name: Gate identifier (e.g. ``"RX"``, ``"H"``).
        decomposition: List of :class:`DecompositionStep` (composite only).
    """

    def __init__(
        self,
        name: str = "",
        params: Optional[jnp.ndarray] = None,
        decomposition: Optional[List["DecompositionStep"]] = None,
    ) -> None:
        """
        Args:
            name: Gate name.
            params: Direct pulse parameters (leaf gates).
                Mutually exclusive with *decomposition*.
            decomposition: Non-empty list of :class:`DecompositionStep`
                (composite gates). Mutually exclusive with *params*.

        Raises:
            AssertionError: If not exactly one of *params* /
                *decomposition* is given, or if *decomposition* is empty.
        """
        assert (params is None) != (decomposition is None), (
            "Exactly one of `params` or `decomposition` must be provided."
        )
        # An empty decomposition would produce a node that is neither a
        # usable leaf (no ``_params``) nor a usable composite (nothing to
        # concatenate), so reject it up front.
        assert decomposition is None or len(decomposition) > 0, (
            "`decomposition` must contain at least one step."
        )

        self.decomposition = decomposition
        # Derive _pulse_obj for backward compat with childs/leafs/split_params.
        # Compare against ``None`` (not truthiness) so composite-ness is
        # decided by what the caller passed.
        self._pulse_obj = (
            [step.gate for step in decomposition]
            if decomposition is not None
            else None
        )

        if params is not None:
            self._params = params

        self.name = name

    def __len__(self) -> int:
        """
        Get the total number of pulse parameters.

        For composite gates, returns the accumulated count from all children
        (a child used several times is counted once per use).

        Returns:
            int: Total number of pulse parameters.
        """
        return len(self.params)

    def __getitem__(self, idx: int) -> Union[float, jnp.ndarray]:
        """
        Access pulse parameter(s) by index.

        For leaf gates, returns the parameter at the given index.
        For composite gates, returns parameters of the child at the given index.

        Args:
            idx (int): Index to access.

        Returns:
            Union[float, jnp.ndarray]: Parameter value or child parameters.
        """
        if self.is_leaf:
            return self.params[idx]
        else:
            return self.childs[idx].params

    def __str__(self) -> str:
        """Return string representation (gate name)."""
        return self.name

    def __repr__(self) -> str:
        """Return repr string (gate name)."""
        return self.name

    @property
    def is_leaf(self) -> bool:
        """Check if this is a leaf node (direct parameters, no children)."""
        return self._pulse_obj is None

    @property
    def size(self) -> int:
        """Get the total parameter count (alias for __len__)."""
        return len(self)

    @property
    def leafs(self) -> List["PulseParams"]:
        """
        Get all leaf nodes in the hierarchy.

        Recursively collects all leaf PulseParams objects in the tree.

        Returns:
            List[PulseParams]: Unique leaf nodes, in order of first
            occurrence in a depth-first walk of the children.
        """
        if self.is_leaf:
            return [self]

        collected = []
        for obj in self._pulse_obj:
            collected.extend(obj.leafs)

        # ``dict.fromkeys`` drops duplicates (a leaf reachable through several
        # children) while keeping first-occurrence order. A ``set`` would
        # order the leaves by object hash (memory address), making the
        # ``leaf_params`` layout arbitrary and irreproducible across runs.
        return list(dict.fromkeys(collected))

    @property
    def childs(self) -> List["PulseParams"]:
        """
        Get direct children of this node.

        Returns:
            List[PulseParams]: List of child PulseParams objects, or empty list
                if this is a leaf node.
        """
        if self.is_leaf:
            return []

        return self._pulse_obj

    @property
    def shape(self) -> List[Union[int, list]]:
        """
        Get the shape of pulse parameters.

        For leaf nodes, returns list with parameter count.
        For composite nodes, returns one entry per child: the child's
        parameter count for leaf children, the child's (nested) shape list
        for composite children.

        Returns:
            List[Union[int, list]]: Parameter shape specification.
        """
        if self.is_leaf:
            return [len(self.params)]

        # ``shape`` is a property: read it, do not call it.
        return [
            child.shape[0] if child.is_leaf else child.shape
            for child in self.childs
        ]

    @property
    def params(self) -> jnp.ndarray:
        """
        Get or compute pulse parameters.

        For leaf nodes, returns internal pulse parameters.
        For composite nodes, returns concatenated parameters from all children.

        Returns:
            jnp.ndarray: Pulse parameters array.
        """
        if self.is_leaf:
            return self._params

        params = self.split_params(params=None, leafs=False)

        return jnp.concatenate(params)

    @params.setter
    def params(self, value: jnp.ndarray) -> None:
        """
        Set pulse parameters.

        For leaf nodes, sets internal parameters directly.
        For composite nodes, distributes values across children (a child
        used several times receives the slice of its last use).

        Args:
            value (jnp.ndarray): Pulse parameters to set.

        Raises:
            AssertionError: If value is not jnp.ndarray for leaf nodes.
        """
        if self.is_leaf:
            assert isinstance(value, jnp.ndarray), "params must be a jnp.ndarray"
            self._params = value
            return

        idx = 0
        for obj in self.childs:
            nidx = idx + obj.size
            obj.params = value[idx:nidx]
            idx = nidx

    @property
    def leaf_params(self) -> jnp.ndarray:
        """
        Get parameters from all leaf nodes.

        Returns:
            jnp.ndarray: Concatenated parameters from all unique leaf nodes,
            in :attr:`leafs` order.
        """
        if self.is_leaf:
            return self._params

        params = self.split_params(None, leafs=True)

        return jnp.concatenate(params)

    @leaf_params.setter
    def leaf_params(self, value: jnp.ndarray) -> None:
        """
        Set parameters for all leaf nodes.

        Args:
            value (jnp.ndarray): Parameters to distribute across leaf nodes,
                in :attr:`leafs` order.
        """
        if self.is_leaf:
            self._params = value
            return

        idx = 0
        for obj in self.leafs:
            nidx = idx + obj.size
            obj.params = value[idx:nidx]
            idx = nidx

    def split_params(
        self,
        params: Optional[jnp.ndarray] = None,
        leafs: bool = False,
    ) -> List[jnp.ndarray]:
        """
        Split parameters into sub-arrays for children or leaves.

        Args:
            params (Optional[jnp.ndarray]): Parameters to split. If None,
                uses internal parameters.
            leafs (bool): If True, splits across leaf nodes; if False,
                splits across direct children. Defaults to False.

        Returns:
            List[jnp.ndarray]: List of parameter arrays for children or leaves.
            For a leaf node the (unsplit) parameter array itself is returned.
        """
        if params is None:
            if self.is_leaf:
                return self._params

            objs = self.leafs if leafs else self.childs
            return [obj.params for obj in objs]

        if self.is_leaf:
            return params

        objs = self.leafs if leafs else self.childs
        s_params = []
        idx = 0
        for obj in objs:
            nidx = idx + obj.size
            s_params.append(params[idx:nidx])
            idx = nidx

        return s_params

302 

303 

class PulseEnvelope:
    """Registry of pulse envelope shapes.

    Each envelope is a pure function ``(p, t, t_c) -> amplitude`` that
    computes the pulse envelope *without* carrier modulation. The carrier
    ``cos(omega_c * t + phi_c)`` is applied separately in the coefficient
    functions built by :meth:`build_coeff_fns`.

    Attributes:
        REGISTRY: Mapping from envelope name to metadata dict containing
            ``fn`` (callable), ``n_envelope_params`` (int), and per-gate
            default parameter arrays.
    """

    @staticmethod
    def gaussian(p, t, t_c):
        """Gaussian envelope. ``p = [A, sigma]``."""
        A, sigma = p[0], p[1]
        return A * jnp.exp(-0.5 * ((t - t_c) / sigma) ** 2)

    @staticmethod
    def square(p, t, t_c):
        """Rectangular envelope. ``p = [A, width]``."""
        A, width = p[0], p[1]
        return A * (jnp.abs(t - t_c) <= width / 2)

    @staticmethod
    def cosine(p, t, t_c):
        """Raised cosine envelope. ``p = [A, width]``."""
        A, width = p[0], p[1]
        x = jnp.clip((t - t_c) / width, -0.5, 0.5)
        return A * jnp.cos(jnp.pi * x)

    @staticmethod
    def drag(p, t, t_c):
        """DRAG (Derivative Removal by Adiabatic Gate). ``p = [A, beta, sigma]``."""
        A, beta, sigma = p[0], p[1], p[2]
        g = A * jnp.exp(-0.5 * ((t - t_c) / sigma) ** 2)
        dg = g * (-(t - t_c) / sigma**2)
        return g + beta * dg

    @staticmethod
    def sech(p, t, t_c):
        """Hyperbolic secant envelope. ``p = [A, sigma]``."""
        A, sigma = p[0], p[1]
        return A / jnp.cosh((t - t_c) / sigma)

    # ``n_envelope_params`` counts only the envelope parameters (excluding
    # the evolution time ``t`` which is always the last element of the full
    # pulse parameter vector).
    # NOTE(review): ``build_coeff_fns`` multiplies by ``p[-1]`` and documents
    # it as the rotation angle ``w``; presumably the caller substitutes the
    # last element before evaluation — confirm against PulseGates.
    REGISTRY = {
        "gaussian": {
            "fn": gaussian.__func__,
            "n_envelope_params": 2,
            "defaults": {
                "RX": jnp.array(
                    [0.38009941846766804, 1.631698142660167, 3.007403822238108]
                ),
                "RY": jnp.array(
                    [0.3836652338514791, 1.616595983505249, 2.9794135093698966]
                ),
            },
        },
        "square": {
            "fn": square.__func__,
            "n_envelope_params": 2,
            "defaults": {
                "RX": jnp.array(
                    [1.209655637514602, 0.8266815576721239, 1.1483122857413859]
                ),
                "RY": jnp.array(
                    [1.0287942142779052, 0.9860505130182093, 0.9720116870310977]
                ),
            },
        },
        "cosine": {
            "fn": cosine.__func__,
            "n_envelope_params": 2,
            "defaults": {
                "RX": jnp.array([1.0, 1.0, 1.0]),
                "RY": jnp.array([1.0, 1.0, 1.0]),
            },
        },
        "drag": {
            "fn": drag.__func__,
            "n_envelope_params": 3,
            "defaults": {
                "RX": jnp.array(
                    [
                        0.326562746114197,
                        0.4002767596709071,
                        5.3228107728890315,
                        3.141300761986467,
                    ]
                ),
                "RY": jnp.array(
                    [
                        0.323287924190616,
                        0.4065017233024265,
                        7.00299644871222,
                        3.139481229843545,
                    ]
                ),
            },
        },
        "sech": {
            "fn": sech.__func__,
            "n_envelope_params": 2,
            "defaults": {
                "RX": jnp.array([1.0, 1.0, 1.0]),
                "RY": jnp.array([1.0, 1.0, 1.0]),
            },
        },
        "general": {
            "fn": None,
            "n_envelope_params": 0,
            "defaults": {
                "RZ": jnp.array([0.5]),
                "CZ": jnp.array([0.3183098783513154]),
            },
        },
    }

    @staticmethod
    def available() -> List[str]:
        """Return list of registered envelope names."""
        return list(PulseEnvelope.REGISTRY.keys())

    @staticmethod
    def get(name: str) -> dict:
        """Look up envelope metadata by name.

        Raises:
            ValueError: If *name* is not registered.
        """
        if name not in PulseEnvelope.REGISTRY:
            raise ValueError(
                f"Unknown pulse envelope '{name}'. "
                f"Available: {PulseEnvelope.available()}"
            )
        return PulseEnvelope.REGISTRY[name]

    @staticmethod
    def build_coeff_fns(
        envelope_fn: Callable,
        omega_c: float,
        omega_q: float,
        rwa: bool = True,
        frame: str = "drive",
    ) -> Tuple[Callable, Callable, Callable, Callable]:
        """Build the four interaction-picture coefficient functions.

        The lab-frame Hamiltonian is

            H(t,Π) = H_static + Σ_j S_j(t;Π) H_j ,
            S_j(t;Π) = E_j(t;Π) · cos(ω_c·t + φ_c) ,

        and the interaction-picture transform with respect to
        ``H_static = (ω_q/2)·Z`` produces

            H̃_j(t) = exp(+i H_static t) H_j exp(-i H_static t) ,
            H_I(t) = Σ_j S_j(t) H̃_j(t) .

        For a single qubit driven on X, ``H̃_X(t) = cos(ω_q·t) X
        − sin(ω_q·t) Y``, so

            H_I(t) = Ω(t) · cos(ω_c·t + φ) ·
                     [ cos(ω_q·t) · X − sin(ω_q·t) · Y ] .

        ``rwa=True`` (default) drops the fast ``Σ = ω_c+ω_q`` components
        and keeps only the slow ``Δ = ω_c−ω_q`` ones, yielding

            H_I^RWA(t) = (Ω(t)/2) · [ cos(Δ·t + φ) X + sin(Δ·t + φ) Y ] .

        On resonance (``ω_c = ω_q``) this is ``(Ω(t)/2)·[cos(φ) X + sin(φ) Y]``:
        ``(Ω/2)·X`` for RX (``φ = 0``) and ``(Ω/2)·Y`` for RY
        (``φ = +π/2``). This is dramatically cheaper to integrate (no fast
        oscillations → adaptive ODE solver takes large steps).

        ``rwa=False`` keeps **both** the slow and the fast
        counter-rotating components.

        Within one call the four functions come from distinct ``def`` sites
        (and each rwa/frame branch has its own), so their ``__code__``
        objects differ per (gate, component) and per branch. In CPython a
        nested function's code object is shared by every function built from
        the same ``def`` across calls, however: rebuilding for another
        envelope yields new closures, NOT new code objects. A solver cache
        keyed only on ``__code__`` must therefore be cleared whenever the
        envelope changes.

        The rotation angle ``w`` is expected as the **last** element of
        the parameter array ``p`` (i.e. ``p[-1]``). Envelope parameters
        occupy ``p[:-1]``.

        Args:
            envelope_fn: Pure envelope function ``(p, t, t_c) -> scalar``.
            omega_c: Carrier frequency.
            omega_q: Qubit frequency (interaction-picture rotation rate).
            rwa: When ``True``, return the RWA-truncated coefficients
                (no fast counter-rotating terms). Default ``True``.
            frame: Algebraic representation of the exact (non-RWA)
                coefficients. Mathematically equivalent options:

                * ``"drive"`` (default): applies the product-to-sum identity
                  to expose the slow ``(ω_c-ω_q)`` and fast ``(ω_c+ω_q)``
                  modes explicitly,
                  ``cos(ω_c t)cos(ω_q t) =
                  ½[cos((ω_c-ω_q)t) + cos((ω_c+ω_q)t)]``. Algebraically
                  identical to ``"lab"`` (no RWA, no information lost).
                  Primary use: combined with the ``magnus2``/``magnus4``
                  yaqsi solvers, the explicit slow/fast decomposition
                  is sometimes numerically better-conditioned and lets
                  the user pick a fixed grid based on the slow
                  frequency alone (``Δ = |ω_c-ω_q|``) when the fast
                  ``(ω_c+ω_q)`` mode is well-resolved by the chosen
                  step.
                * ``"lab"``: the literal form
                  ``Ω(t) cos(ω_c t + φ) cos(ω_q t)`` (and the analogous
                  ``-sin`` term). Two trig multiplications per call;
                  contains all four product frequencies implicitly.

                Validated always, but only used when ``rwa=False``.

        Returns:
            Tuple ``(coeff_RX_X, coeff_RX_Y, coeff_RY_X, coeff_RY_Y)``
            of coefficient functions for the X- and Y-components of the
            RX and RY interaction-picture Hamiltonians.

        Raises:
            ValueError: If *frame* is not ``"lab"`` or ``"drive"``.
        """
        if frame not in ("lab", "drive"):
            raise ValueError(f"Unknown frame {frame!r}; expected 'lab' or 'drive'.")
        if rwa:
            # RWA-truncated coefficients: the slow Δ = ω_c − ω_q parts of the
            # exact coefficients (fast Σ = ω_c + ω_q parts dropped):
            #   RX: ½cos(Δt) X + ½sin(Δt) Y
            #   RY: −½sin(Δt) X + ½cos(Δt) Y
            # On resonance (Δ = 0) the cross terms vanish and this is exactly
            # H_I^RWA = (Ω(t)/2)[cos(φ) X + sin(φ) Y]. We keep the ``p[-1]``
            # rotation-angle scaling so the calling ParametrizedHamiltonian
            # shape is unchanged.
            half = jnp.asarray(0.5)
            omega_d = omega_c - omega_q

            def _coeff_RX_X(p, t):
                t_c = t / 2
                env = envelope_fn(p, t, t_c)
                return half * jnp.cos(omega_d * t) * env * p[-1]

            def _coeff_RX_Y(p, t):  # vanishes on resonance (φ = 0)
                t_c = t / 2
                env = envelope_fn(p, t, t_c)
                return half * jnp.sin(omega_d * t) * env * p[-1]

            def _coeff_RY_X(p, t):  # vanishes on resonance (φ = π/2)
                t_c = t / 2
                env = envelope_fn(p, t, t_c)
                return -half * jnp.sin(omega_d * t) * env * p[-1]

            def _coeff_RY_Y(p, t):
                t_c = t / 2
                env = envelope_fn(p, t, t_c)
                return half * jnp.cos(omega_d * t) * env * p[-1]

            return _coeff_RX_X, _coeff_RX_Y, _coeff_RY_X, _coeff_RY_Y

        if frame == "drive":
            # Drive-frame: same exact dynamics, expressed via the
            # product-to-sum identities so the slow ``Δ = ω_c - ω_q``
            # and fast ``Σ = ω_c + ω_q`` modes appear explicitly.
            # Mathematically identical to the ``lab`` branch below.
            #
            # Identities used:
            #   cos(ω_c t) cos(ω_q t)  = ½[cos(Δ t) + cos(Σ t)]
            #   cos(ω_c t) sin(ω_q t)  = ½[sin(Σ t) − sin(Δ t)]
            #  −sin(ω_c t) cos(ω_q t)  = −½[sin(Σ t) + sin(Δ t)]
            #  −sin(ω_c t) sin(ω_q t)  = ½[cos(Σ t) − cos(Δ t)]
            # (RY uses cos(ω_c t + π/2) = −sin(ω_c t).)
            omega_d = omega_c - omega_q
            omega_s = omega_c + omega_q
            half = jnp.asarray(0.5)

            def _coeff_RX_X(p, t):
                t_c = t / 2
                env = envelope_fn(p, t, t_c)
                mod = half * (jnp.cos(omega_d * t) + jnp.cos(omega_s * t))
                return env * mod * p[-1]

            def _coeff_RX_Y(p, t):
                t_c = t / 2
                env = envelope_fn(p, t, t_c)
                mod = -half * (jnp.sin(omega_s * t) - jnp.sin(omega_d * t))
                return env * mod * p[-1]

            def _coeff_RY_X(p, t):
                t_c = t / 2
                env = envelope_fn(p, t, t_c)
                mod = -half * (jnp.sin(omega_s * t) + jnp.sin(omega_d * t))
                return env * mod * p[-1]

            def _coeff_RY_Y(p, t):
                t_c = t / 2
                env = envelope_fn(p, t, t_c)
                mod = -half * (jnp.cos(omega_s * t) - jnp.cos(omega_d * t))
                return env * mod * p[-1]

            return _coeff_RX_X, _coeff_RX_Y, _coeff_RY_X, _coeff_RY_Y

        # RX uses carrier phase phi = 0 so that after RWA
        #   cos(ω_q τ)·cos(ω_q τ) averages to +1/2 → drives +X
        #  -cos(ω_q τ)·sin(ω_q τ) averages to 0   → Y cancels
        # giving H_I^RWA ≈ (Ω/2)·X → U ≈ exp(-iθ/2 X), matching op.RX.
        # The exact form below KEEPS the fast 2·ω_q components.
        def _coeff_RX_X(p, t):
            t_c = t / 2
            env = envelope_fn(p, t, t_c)
            carrier = jnp.cos(omega_c * t)
            return env * carrier * jnp.cos(omega_q * t) * p[-1]

        def _coeff_RX_Y(p, t):
            t_c = t / 2
            env = envelope_fn(p, t, t_c)
            carrier = jnp.cos(omega_c * t)
            return -env * carrier * jnp.sin(omega_q * t) * p[-1]

        # RY uses carrier phase phi = +pi/2 so the RWA component drives +Y.
        def _coeff_RY_X(p, t):
            t_c = t / 2
            env = envelope_fn(p, t, t_c)
            carrier = jnp.cos(omega_c * t + jnp.pi / 2)
            return env * carrier * jnp.cos(omega_q * t) * p[-1]

        def _coeff_RY_Y(p, t):
            t_c = t / 2
            env = envelope_fn(p, t, t_c)
            carrier = jnp.cos(omega_c * t + jnp.pi / 2)
            return -env * carrier * jnp.sin(omega_q * t) * p[-1]

        return _coeff_RX_X, _coeff_RX_Y, _coeff_RY_X, _coeff_RY_Y

632 

633 

class PulseInformation:
    """Stores pulse parameter counts and optimized pulse parameters.

    Call :meth:`set_envelope` to switch the active pulse shape. This
    rebuilds all :class:`PulseParams` trees so that parameter counts
    and defaults match the selected envelope.
    """

    DEFAULT_ENVELOPE: str = "drag"
    DEFAULT_RWA: bool = True
    DEFAULT_FRAME: str = "drive"
    LEAF_GATE_NAMES: Tuple[str, ...] = ("RX", "RY", "RZ", "CZ")

    _envelope: str = DEFAULT_ENVELOPE
    # Whether to apply the rotating-wave approximation when building the
    # interaction-picture coefficient functions. Default ``True``: the fast
    # counter-rotating terms are dropped, which is much faster to integrate.
    # ``False`` keeps the exact (non-RWA) interaction-picture dynamics.
    # See :meth:`PulseEnvelope.build_coeff_fns`.
    _rwa: bool = DEFAULT_RWA
    # Algebraic representation of the (non-RWA) coefficients. Either
    # ``"lab"`` or ``"drive"`` (product-to-sum decomposition; the default).
    # Mathematically equivalent — see :meth:`PulseEnvelope.build_coeff_fns`
    # for when ``"drive"`` is numerically advantageous (mainly with the
    # Magnus solvers).
    _frame: str = DEFAULT_FRAME

    @classmethod
    def _build_leaf_gates(cls):
        """(Re-)create leaf PulseParams from the active envelope defaults."""
        defaults = PulseEnvelope.get(cls._envelope)["defaults"]
        general = PulseEnvelope.get("general")["defaults"]

        cls.RX = PulseParams(name="RX", params=defaults["RX"])
        cls.RY = PulseParams(name="RY", params=defaults["RY"])

        cls.RZ = PulseParams(name="RZ", params=general["RZ"])
        cls.CZ = PulseParams(name="CZ", params=general["CZ"])

    @classmethod
    def _build_composite_gates(cls):
        """(Re-)create composite PulseParams trees from current leaves."""
        cls.H = PulseParams(
            name="H",
            decomposition=[
                DecompositionStep(cls.RZ, "all", lambda w: jnp.pi),
                DecompositionStep(cls.RY, "all", lambda w: jnp.pi / 2),
            ],
        )
        cls.CX = PulseParams(
            name="CX",
            decomposition=[
                DecompositionStep(cls.H, "target", lambda w: 0.0),
                DecompositionStep(cls.CZ, "all", lambda w: 0.0),
                DecompositionStep(cls.H, "target", lambda w: 0.0),
            ],
        )
        cls.CY = PulseParams(
            name="CY",
            decomposition=[
                DecompositionStep(cls.RZ, "target", lambda w: -jnp.pi / 2),
                DecompositionStep(cls.CX, "all"),
                DecompositionStep(cls.RZ, "target", lambda w: jnp.pi / 2),
            ],
        )
        cls.CRX = PulseParams(
            name="CRX",
            decomposition=[
                DecompositionStep(cls.RZ, "target", lambda w: jnp.pi / 2),
                DecompositionStep(cls.RY, "target", lambda w: w / 2),
                DecompositionStep(cls.CX, "all", lambda w: 0.0),
                DecompositionStep(cls.RY, "target", lambda w: -w / 2),
                DecompositionStep(cls.CX, "all", lambda w: 0.0),
                DecompositionStep(cls.RZ, "target", lambda w: -jnp.pi / 2),
            ],
        )
        cls.CRY = PulseParams(
            name="CRY",
            decomposition=[
                DecompositionStep(cls.RY, "target", lambda w: w / 2),
                DecompositionStep(cls.CX, "all", lambda w: 0.0),
                DecompositionStep(cls.RY, "target", lambda w: -w / 2),
                DecompositionStep(cls.CX, "all", lambda w: 0.0),
            ],
        )
        cls.CRZ = PulseParams(
            name="CRZ",
            decomposition=[
                DecompositionStep(cls.RZ, "target", lambda w: w / 2),
                DecompositionStep(cls.CX, "all", lambda w: 0.0),
                DecompositionStep(cls.RZ, "target", lambda w: -w / 2),
                DecompositionStep(cls.CX, "all", lambda w: 0.0),
            ],
        )
        # TODO: check if we could just make this a basis gate instead
        cls.CPhase = PulseParams(
            name="CPhase",
            decomposition=[
                DecompositionStep(cls.RZ, "control", lambda w: w / 2),
                DecompositionStep(cls.RZ, "target", lambda w: w / 2),
                DecompositionStep(cls.CX, "all", lambda w: 0.0),
                DecompositionStep(cls.RZ, "target", lambda w: -w / 2),
                DecompositionStep(cls.CX, "all", lambda w: 0.0),
            ],
        )
        cls.Rot = PulseParams(
            name="Rot",
            decomposition=[
                DecompositionStep(cls.RZ, "all", lambda w: w[0]),
                DecompositionStep(cls.RY, "all", lambda w: w[1]),
                DecompositionStep(cls.RZ, "all", lambda w: w[2]),
            ],
        )
        cls.unique_gate_set = [cls.RX, cls.RY, cls.RZ, cls.CZ]

    @classmethod
    def set_envelope(
        cls,
        name: str,
        rwa: Optional[bool] = None,
        frame: Optional[str] = None,
    ) -> None:
        """Switch pulse envelope and rebuild all PulseParams trees.

        Also updates the coefficient functions used by :class:`PulseGates`.
        All arguments are validated before any class state is modified, so
        a rejected call leaves the previous configuration fully intact.

        Args:
            name: One of :meth:`PulseEnvelope.available` that defines an
                envelope function and RX/RY defaults.
            rwa: If given, also update the RWA flag. If ``None`` (the
                default), the current value of ``cls._rwa`` is kept.
                See :meth:`PulseEnvelope.build_coeff_fns` for the
                physical meaning of the flag.
            frame: If given, also update the coefficient frame
                (``"lab"`` or ``"drive"``). ``None`` keeps the current
                value of ``cls._frame``. The value is stored even while
                RWA is active, but only takes effect with RWA off.

        Raises:
            ValueError: If *frame* is invalid, *name* is unknown, or the
                envelope cannot drive RX/RY (e.g. ``"general"``).
        """
        # Validate everything first: previously ``_envelope``/``_rwa`` were
        # overwritten before the frame check, leaving them out of sync with
        # the built gates and coefficient functions when it raised.
        if frame is not None and frame not in ("lab", "drive"):
            raise ValueError(f"Unknown frame {frame!r}; expected 'lab' or 'drive'.")
        info = PulseEnvelope.get(name)  # validates name
        if info["fn"] is None or not all(
            gate in info["defaults"] for gate in ("RX", "RY")
        ):
            raise ValueError(
                f"Pulse envelope '{name}' cannot be used as the active envelope "
                "(it lacks an envelope function or RX/RY defaults)."
            )

        cls._envelope = name
        if rwa is not None:
            cls._rwa = bool(rwa)
        if frame is not None:
            cls._frame = frame
        cls._build_leaf_gates()
        cls._build_composite_gates()

        # Rebuild interaction-picture coefficient functions on PulseGates.
        # Four functions: (RX_X, RX_Y, RY_X, RY_Y) — one per (gate, Pauli)
        # component of the proper interaction-picture drive Hamiltonian.
        rx_x, rx_y, ry_x, ry_y = PulseEnvelope.build_coeff_fns(
            info["fn"],
            PulseGates.omega_c,
            PulseGates.omega_q,
            rwa=cls._rwa,
            frame=cls._frame,
        )
        PulseGates._coeff_RX_X = staticmethod(rx_x)
        PulseGates._coeff_RX_Y = staticmethod(rx_y)
        PulseGates._coeff_RY_X = staticmethod(ry_x)
        PulseGates._coeff_RY_Y = staticmethod(ry_y)
        # Backward-compat aliases for older introspection (point at the
        # X-component which dominates RX, Y-component which dominates RY).
        PulseGates._coeff_Sx = staticmethod(rx_x)
        PulseGates._coeff_Sy = staticmethod(ry_y)
        PulseGates._active_envelope = name
        PulseGates._active_rwa = cls._rwa
        PulseGates._active_frame = cls._frame

        # The compiled-solver cache in ``Yaqsi`` is keyed on the code
        # objects of the coefficient functions. Rebuilding does NOT produce
        # fresh code objects: in CPython every function created from the
        # same nested ``def`` shares one ``__code__``, only the closure
        # (envelope, frequencies) differs. Without eviction, a cached solver
        # compiled for the previous envelope could be returned for the new
        # one, and programs for old configurations would stay alive.
        # Lazy import to prevent circular imports.
        from qml_essentials.yaqsi import Yaqsi

        Yaqsi.clear_evolve_solver_cache()

        log.info(
            f"Pulse envelope set to '{name}' "
            f"(RWA {'on' if cls._rwa else 'off'}, frame={cls._frame})"
        )

    @classmethod
    def set_rwa(cls, rwa: bool) -> None:
        """Toggle the rotating-wave approximation for pulse coefficients.

        Rebuilds the coefficient functions for the currently active
        envelope so the change takes effect immediately. The class default
        is ``DEFAULT_RWA`` (``True``, RWA on); ``False`` selects the exact
        interaction picture.
        See :meth:`PulseEnvelope.build_coeff_fns` for details.
        """
        cls.set_envelope(cls._envelope, rwa=bool(rwa))

    @classmethod
    def get_envelope(cls) -> str:
        """Return the name of the active pulse envelope."""
        return cls._envelope

    @classmethod
    def get_rwa(cls) -> bool:
        """Return whether the RWA flag is currently active."""
        return cls._rwa

    @classmethod
    def set_frame(cls, frame: str) -> None:
        """Switch the algebraic representation of the (non-RWA) coefficients.

        ``"drive"`` (the class default, ``DEFAULT_FRAME``) and ``"lab"`` are
        mathematically identical (no information lost, no RWA applied) — see
        :meth:`PulseEnvelope.build_coeff_fns` for when ``"drive"`` is
        useful. Rebuilds the coefficient functions for the currently
        active envelope so the change takes effect immediately.

        Raises:
            ValueError: If *frame* is not ``"lab"`` or ``"drive"``.
        """
        cls.set_envelope(cls._envelope, frame=str(frame))

    @classmethod
    def get_frame(cls) -> str:
        """Return the active coefficient frame (``"lab"`` or ``"drive"``)."""
        return cls._frame

    @classmethod
    def snapshot_state(cls) -> "PulseStateSnapshot":
        """Return an immutable snapshot of the active pulse configuration."""
        leaf_params = {}
        for name in cls.LEAF_GATE_NAMES:
            gate = getattr(cls, name, None)
            if gate is not None:
                # Copy so later in-place edits cannot leak into the snapshot.
                leaf_params[name] = jnp.array(gate.params)

        return PulseStateSnapshot(
            envelope=cls._envelope,
            rwa=cls._rwa,
            frame=cls._frame,
            leaf_params=leaf_params,
        )

    @classmethod
    def restore_state(cls, snapshot: "PulseStateSnapshot") -> None:
        """Restore a snapshot produced by :meth:`snapshot_state`.

        Raises:
            ValueError: If the snapshot names an unknown/non-leaf gate or a
                parameter array whose shape differs from the active gate's.
        """
        cls.set_envelope(snapshot.envelope, rwa=snapshot.rwa, frame=snapshot.frame)

        for name, params in snapshot.leaf_params.items():
            gate = cls.gate_by_name(name)
            if gate is None or not gate.is_leaf:
                raise ValueError(f"Cannot restore unknown leaf pulse gate {name!r}.")
            if gate.params.shape != params.shape:
                raise ValueError(
                    f"Snapshot for {name!r} has shape {params.shape}, "
                    f"but active gate expects {gate.params.shape}."
                )
            gate.params = params

    @classmethod
    @contextmanager
    def preserve_state(cls):
        """Temporarily preserve global pulse state across scoped mutations."""
        snapshot = cls.snapshot_state()
        try:
            yield snapshot
        finally:
            cls.restore_state(snapshot)

    @classmethod
    def reset_defaults(
        cls,
        envelope: Optional[str] = None,
        rwa: Optional[bool] = None,
        frame: Optional[str] = None,
    ) -> None:
        """Reset pulse globals to canonical defaults or explicit values."""
        cls.set_envelope(
            cls.DEFAULT_ENVELOPE if envelope is None else envelope,
            rwa=cls.DEFAULT_RWA if rwa is None else rwa,
            frame=cls.DEFAULT_FRAME if frame is None else frame,
        )

    @staticmethod
    def gate_by_name(gate):
        """Return the PulseParams attribute named by *gate* (a string or an
        object with ``__name__``), or ``None`` if no such attribute exists."""
        if isinstance(gate, str):
            return getattr(PulseInformation, gate, None)
        else:
            return getattr(PulseInformation, gate.__name__, None)

    @staticmethod
    def num_params(gate):
        """Return the total pulse parameter count of *gate*."""
        return len(PulseInformation.gate_by_name(gate))

    @staticmethod
    def update_params(path=None):
        """Load optimized pulse parameters from a CSV file.

        Each non-empty row is read as ``gate_name, fidelity, p_0, p_1, ...``
        and stored as ``PulseInformation.OPTIMIZED_PULSES[gate_name]``.

        Args:
            path: CSV file path. Defaults to
                ``<cwd>/qml_essentials/qoc_results.csv``, where the working
                directory is resolved when this method is called (not when
                the module was imported).
        """
        if path is None:
            path = f"{os.getcwd()}/qml_essentials/qoc_results.csv"
        if not os.path.isfile(path):
            log.error(f"No optimized pulses found at {path}")
            return

        log.info(f"Loading optimized pulses from {path}")
        # ``OPTIMIZED_PULSES`` is not declared on the class body; create it
        # lazily instead of failing with AttributeError on first use.
        if not hasattr(PulseInformation, "OPTIMIZED_PULSES"):
            PulseInformation.OPTIMIZED_PULSES = {}
        store = PulseInformation.OPTIMIZED_PULSES

        # ``newline=""`` is the csv-module requirement for correct parsing.
        with open(path, "r", newline="") as f:
            for row in csv.reader(f):
                if not row:  # tolerate blank lines
                    continue
                log.debug(
                    f"Loading optimized pulses for {row[0]} "
                    f"(Fidelity: {float(row[1]):.5f}): {row[2:]}"
                )
                store[row[0]] = jnp.array([float(x) for x in row[2:]])

    @staticmethod
    def shuffle_params(random_key):
        """Replace every leaf gate's parameters with uniform random values
        in ``[0, 1)``, splitting *random_key* once per gate."""
        log.info(
            f"Shuffling optimized pulses with random key {random_key} "
            f"of gates {PulseInformation.unique_gate_set}"
        )
        for gate in PulseInformation.unique_gate_set:
            random_key, sub_key = safe_random_split(random_key)
            gate.params = jax.random.uniform(sub_key, (len(gate),))

953 

954 

class PulseGates:
    """Pulse-level implementations of quantum gates.

    Implements quantum gates using time-dependent Hamiltonians and pulse
    sequences, following the approach from https://doi.org/10.5445/IR/1000184129.
    The active pulse envelope is selected via
    :meth:`PulseInformation.set_envelope`.

    Leaf gates (RX, RY, RZ, CZ) evolve a Hamiltonian directly via
    ``ys.evolve``. Composite gates (H, CX, CY, CRX, CRY, CRZ, CPhase, Rot)
    walk the decomposition stored on :class:`PulseInformation` through
    :meth:`_execute_composite`.

    Attributes:
        omega_q: Qubit frequency (10π).
        omega_c: Carrier frequency (10π).
        _active_envelope: Name of the currently active envelope shape.
        _active_rwa: Mirrors ``PulseInformation._rwa``; records which
            coefficient regime the active ``_coeff_*`` functions implement.
        _active_frame: Name of the active frame.
    """

    # NOTE: Implementation of S, RX, RY, RZ, CZ, CNOT/CX and H pulse level
    # gates closely follow https://doi.org/10.5445/IR/1000184129
    omega_q = 10 * jnp.pi
    omega_c = 10 * jnp.pi

    X = jnp.array([[0, 1], [1, 0]])
    Y = jnp.array([[0, -1j], [1j, 0]])
    Z = jnp.array([[1, 0], [0, -1]])

    Id = jnp.eye(2, dtype=jnp.complex64)

    _H_CZ = (jnp.pi / 4) * (
        jnp.kron(Id, Id) - jnp.kron(Z, Id) - jnp.kron(Id, Z) + jnp.kron(Z, Z)
    )

    _H_corr = jnp.pi / 2 * jnp.eye(2, dtype=jnp.complex64)

    _active_envelope: str = "gaussian"
    # Mirrors :attr:`PulseInformation._rwa`; kept here for introspection
    # of which coefficient regime the active ``_coeff_*`` functions
    # implement. Updated by :meth:`PulseInformation.set_envelope` /
    # :meth:`PulseInformation.set_rwa`.
    _active_rwa: bool = True
    _active_frame: str = "drive"

    # Default coefficient functions for the gaussian envelope; the active
    # envelope's `set_envelope` will overwrite these. Each gate uses two
    # coefficients (X- and Y-component of the proper interaction-picture
    # drive Hamiltonian).

    @staticmethod
    def _coeff_RX_X(p, t):
        """RX coefficient for the X term (gaussian default)."""
        t_c = t / 2
        env = PulseEnvelope.gaussian(p, t, t_c)
        carrier = jnp.cos(PulseGates.omega_c * t)
        return env * carrier * jnp.cos(PulseGates.omega_q * t) * p[-1]

    @staticmethod
    def _coeff_RX_Y(p, t):
        """RX coefficient for the Y term (gaussian default)."""
        t_c = t / 2
        env = PulseEnvelope.gaussian(p, t, t_c)
        carrier = jnp.cos(PulseGates.omega_c * t)
        return -env * carrier * jnp.sin(PulseGates.omega_q * t) * p[-1]

    @staticmethod
    def _coeff_RY_X(p, t):
        """RY coefficient for the X term (gaussian default)."""
        t_c = t / 2
        env = PulseEnvelope.gaussian(p, t, t_c)
        carrier = jnp.cos(PulseGates.omega_c * t + jnp.pi / 2)
        return env * carrier * jnp.cos(PulseGates.omega_q * t) * p[-1]

    @staticmethod
    def _coeff_RY_Y(p, t):
        """RY coefficient for the Y term (gaussian default)."""
        t_c = t / 2
        env = PulseEnvelope.gaussian(p, t, t_c)
        carrier = jnp.cos(PulseGates.omega_c * t + jnp.pi / 2)
        return -env * carrier * jnp.sin(PulseGates.omega_q * t) * p[-1]

    # Backward-compat aliases (resolve to the dominant component of each gate).
    _coeff_Sx = _coeff_RX_X
    _coeff_Sy = _coeff_RY_Y

    @staticmethod
    def _coeff_Sz(p, t):
        """Coefficient function for RZ pulse: p * w."""
        return p[0] * p[1]

    @staticmethod
    def _coeff_Sc(p, t):
        """Constant coefficient for H correction phase."""
        return -1.0

    @staticmethod
    def _coeff_Scz(p, t):
        """Coefficient function for CZ pulse."""
        return p * jnp.pi

    @staticmethod
    def _record_pulse_event(gate_name, w, wires, pulse_params, parent=None):
        """Append a PulseEvent to the active pulse tape if recording.

        This is called from leaf gate methods (RX, RY, RZ, CZ) so that
        :func:`~qml_essentials.tape.pulse_recording` can collect events
        without the caller needing to know about the tape. Gates flagged
        ``"physical"`` in ``LEAF_META`` record their envelope function,
        envelope params and duration (last pulse param). All other gates
        record the raw params with duration 1.0 and no envelope.
        """
        ptape = active_pulse_tape()
        if ptape is None:
            return

        # Imported lazily; only needed while a tape is recording.
        from qml_essentials.drawing import PulseEvent, LEAF_META

        meta = LEAF_META.get(gate_name, {})
        wires_list = [wires] if isinstance(wires, int) else list(wires)

        if meta.get("physical", False):
            info = PulseEnvelope.get(PulseInformation.get_envelope())
            pp = PulseInformation.gate_by_name(gate_name).split_params(pulse_params)
            env_p = pp[:-1]
            dur = float(pp[-1])
            ptape.append(
                PulseEvent(
                    gate=gate_name,
                    wires=wires_list,
                    envelope_fn=info["fn"],
                    envelope_params=jnp.array(env_p),
                    w=float(w),
                    duration=dur,
                    carrier_phase=meta["carrier_phase"],
                    parent=parent,
                )
            )
        else:
            pp = PulseInformation.gate_by_name(gate_name).split_params(pulse_params)
            ptape.append(
                PulseEvent(
                    gate=gate_name,
                    wires=wires_list,
                    envelope_fn=None,
                    envelope_params=jnp.ravel(jnp.asarray(pp)),
                    w=float(w) if not isinstance(w, list) else 0.0,
                    duration=1.0,
                    carrier_phase=0.0,
                    parent=parent,
                )
            )

    @staticmethod
    def Rot(
        phi: float,
        theta: float,
        omega: float,
        wires: Union[int, List[int]],
        pulse_params: Optional[jnp.ndarray] = None,
        noise_params: Optional[Dict[str, float]] = None,
        random_key: Optional[jax.random.PRNGKey] = None,
    ) -> None:
        """
        Apply general rotation via decomposition: RZ(phi) · RY(theta) · RZ(omega).

        Args:
            phi (float): First rotation angle.
            theta (float): Second rotation angle.
            omega (float): Third rotation angle.
            wires (Union[int, List[int]]): Qubit index or indices to apply rotation to.
            pulse_params (Optional[jnp.ndarray]): Pulse parameters for the
                composing gates. If None, uses optimized parameters.
            noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
            random_key (Optional[jax.random.PRNGKey]): JAX random key used for
                gate-error sampling.

        Returns:
            None: Gates are applied in-place to the circuit.
        """
        if noise_params is not None and "GateError" in noise_params:
            phi, random_key = UnitaryGates.GateError(phi, noise_params, random_key)
            theta, random_key = UnitaryGates.GateError(theta, noise_params, random_key)
            omega, random_key = UnitaryGates.GateError(omega, noise_params, random_key)
        PulseGates._execute_composite("Rot", [phi, theta, omega], wires, pulse_params)
        UnitaryGates.Noise(wires, noise_params)

    @staticmethod
    def PauliRot(
        pauli: str,
        theta: float,
        wires: Union[int, List[int]],
        pulse_params: Optional[jnp.ndarray] = None,
        noise_params: Optional[Dict[str, float]] = None,
        random_key: Optional[jax.random.PRNGKey] = None,
    ) -> None:
        """Not implemented as a PulseGate."""
        raise NotImplementedError("PauliRot gate is not implemented as PulseGate")

    @staticmethod
    def RX(
        w: float,
        wires: Union[int, List[int]],
        pulse_params: Optional[jnp.ndarray] = None,
        noise_params: Optional[Dict[str, float]] = None,
        random_key: Optional[jax.random.PRNGKey] = None,
    ) -> None:
        """Apply X-axis rotation using the active pulse envelope.

        Args:
            w: Rotation angle in radians.
            wires: Qubit index or indices.
            pulse_params: Envelope parameters ``[env_0, ..., env_n, t]``.
                If ``None``, uses optimized defaults.
            noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
            random_key (Optional[jax.random.PRNGKey]): JAX random key used for
                gate-error sampling.
        """
        pulse_params = PulseInformation.RX.split_params(pulse_params)

        PulseGates._record_pulse_event("RX", w, wires, pulse_params)
        t = pulse_params[-1]

        # Proper interaction-picture drive Hamiltonian for RX:
        #   H_I(τ) = Ω(τ)·cos(ω_c·τ) · [ cos(ω_q·τ)·X − sin(ω_q·τ)·Y ]
        # which on resonance averages (RWA) to +(Ω/2)·X while the
        # 2·ω_q counter-rotating part oscillates and cancels.
        H_X = op.Hermitian(PulseGates.X, wires=wires, record=False)
        H_Y = op.Hermitian(PulseGates.Y, wires=wires, record=False)
        H_eff = PulseGates._coeff_RX_X * H_X + PulseGates._coeff_RX_Y * H_Y

        # Pack: [envelope_params..., w] - evolution time is the last element
        # of pulse_params (pulse_params[-1]).
        w, random_key = UnitaryGates.GateError(w, noise_params, random_key)
        # Use jnp.concatenate over Python list-splat to keep the trace graph
        # compact (no per-element unpacking + restack).
        env_params = jnp.concatenate(
            [jnp.ravel(pulse_params[:-1]), jnp.ravel(jnp.asarray(w))]
        )
        # Both terms share the same parameter array.
        ys.evolve(H_eff, name="RX")([env_params, env_params], t)
        UnitaryGates.Noise(wires, noise_params)

    @staticmethod
    def RY(
        w: float,
        wires: Union[int, List[int]],
        pulse_params: Optional[jnp.ndarray] = None,
        noise_params: Optional[Dict[str, float]] = None,
        random_key: Optional[jax.random.PRNGKey] = None,
    ) -> None:
        """Apply Y-axis rotation using the active pulse envelope.

        Args:
            w: Rotation angle in radians.
            wires: Qubit index or indices.
            pulse_params: Envelope parameters ``[env_0, ..., env_n, t]``.
                If ``None``, uses optimized defaults.
            noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
            random_key (Optional[jax.random.PRNGKey]): JAX random key used for
                gate-error sampling.
        """
        pulse_params = PulseInformation.RY.split_params(pulse_params)

        PulseGates._record_pulse_event("RY", w, wires, pulse_params)
        t = pulse_params[-1]

        # See NOTE in RX: same proper interaction-picture form, with
        # carrier phase ϕ = +π/2 so the slow RWA component drives +Y.
        H_X = op.Hermitian(PulseGates.X, wires=wires, record=False)
        H_Y = op.Hermitian(PulseGates.Y, wires=wires, record=False)
        H_eff = PulseGates._coeff_RY_X * H_X + PulseGates._coeff_RY_Y * H_Y

        # Pack w into the params so the coefficient function doesn't need
        # a closure - this enables JIT solver cache sharing across all RY calls.
        w, random_key = UnitaryGates.GateError(w, noise_params, random_key)
        env_params = jnp.concatenate(
            [jnp.ravel(pulse_params[:-1]), jnp.ravel(jnp.asarray(w))]
        )
        ys.evolve(H_eff, name="RY")([env_params, env_params], t)
        UnitaryGates.Noise(wires, noise_params)

    @staticmethod
    def RZ(
        w: float,
        wires: Union[int, List[int]],
        pulse_params: Optional[float] = None,
        noise_params: Optional[Dict[str, float]] = None,
        random_key: Optional[jax.random.PRNGKey] = None,
    ) -> None:
        """
        Apply Z-axis rotation using pulse-level implementation.

        Implements RZ rotation using virtual Z rotations (phase tracking)
        without physical pulse application.

        Args:
            w (float): Rotation angle in radians.
            wires (Union[int, List[int]]): Qubit index or indices to apply rotation to.
            pulse_params (Optional[float]): Duration parameter for the pulse.
                Rotation angle = w * 2 * pulse_params. If None, the value
                from ``PulseInformation.RZ.split_params`` is used.
            noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
            random_key (Optional[jax.random.PRNGKey]): JAX random key used for
                gate-error sampling.

        Returns:
            None: Gate is applied in-place to the circuit.
        """
        pulse_params = PulseInformation.RZ.split_params(pulse_params)

        PulseGates._record_pulse_event("RZ", w, wires, pulse_params)

        _H = op.Hermitian(PulseGates.Z, wires=wires, record=False)
        H_eff = PulseGates._coeff_Sz * _H

        # Pack w into the params so the coefficient function doesn't need
        # a closure - [pulse_param_scalar, w] enables JIT solver cache sharing.
        # pulse_params may be a 1-element array or scalar; ravel + slice the first
        # element to preserve the original semantics, then concatenate with w.
        w, random_key = UnitaryGates.GateError(w, noise_params, random_key)
        pp_flat = jnp.ravel(jnp.asarray(pulse_params))
        ys.evolve(H_eff, name="RZ")(
            [jnp.concatenate([pp_flat[:1], jnp.ravel(jnp.asarray(w))])],
            1,
        )

        UnitaryGates.Noise(wires, noise_params)

    @staticmethod
    def _resolve_wires(wire_fn, wires):
        """Resolve a wire selector string to actual wire(s).

        Args:
            wire_fn: ``"all"``, ``"target"``, or ``"control"``.
            wires: Parent gate's wire(s) (int or list).

        Returns:
            Wire(s) for the child gate. ``"all"`` returns ``wires`` unchanged
            for multi-wire parents and the single wire otherwise, ``"target"``
            returns the last wire and ``"control"`` the first.

        Raises:
            ValueError: If ``wire_fn`` is not one of the selectors above.
        """
        wires_list = [wires] if isinstance(wires, int) else list(wires)
        if wire_fn == "all":
            return wires if len(wires_list) > 1 else wires_list[0]
        if wire_fn == "target":
            return wires_list[-1] if len(wires_list) > 1 else wires_list[0]
        if wire_fn == "control":
            return wires_list[0]
        raise ValueError(f"Unknown wire_fn: {wire_fn!r}")

    @staticmethod
    def _execute_composite(gate_name, w, wires, pulse_params=None):
        """Execute a composite gate by walking its decomposition.

        Reads the :class:`DecompositionStep` list from
        :class:`PulseInformation` and dispatches each step to the
        appropriate ``PulseGates`` method. Child gates are called without
        ``noise_params``, so any angle noise must be applied by the caller
        before this is invoked.

        Args:
            gate_name: Gate name (e.g. ``"H"``, ``"CX"``).
            w: Rotation angle(s) passed to the parent gate.
            wires: Wire(s) of the parent gate.
            pulse_params: Optional pulse parameters (split across children).
        """
        pp_obj = PulseInformation.gate_by_name(gate_name)
        parts = pp_obj.split_params(pulse_params)

        for step, child_params in zip(pp_obj.decomposition, parts):
            child_wires = PulseGates._resolve_wires(step.wire_fn, wires)
            child_w = step.angle_fn(w) if step.angle_fn is not None else w
            child_gate = getattr(PulseGates, step.gate.name)

            # Leaf gates that take a rotation angle
            if step.gate.name in ("RX", "RY", "RZ"):
                child_gate(child_w, wires=child_wires, pulse_params=child_params)
            # Leaf gates without a rotation angle
            elif step.gate.name in ("CZ",):
                child_gate(wires=child_wires, pulse_params=child_params)
            # Rot takes three angles: (phi, theta, omega, wires, ...)
            elif step.gate.name in ("Rot",):
                child_gate(*child_w, wires=child_wires, pulse_params=child_params)
            # Composite gates with a single rotation angle
            elif step.gate.decomposition is not None and step.gate.name in (
                "CRX",
                "CRY",
                "CRZ",
                "CPhase",
            ):
                child_gate(child_w, wires=child_wires, pulse_params=child_params)
            # Other composite gates without an angle (H, CX, CY, ...)
            else:
                child_gate(wires=child_wires, pulse_params=child_params)

    @staticmethod
    def H(
        wires: Union[int, List[int]],
        pulse_params: Optional[jnp.ndarray] = None,
        noise_params: Optional[Dict[str, float]] = None,
        random_key: Optional[jax.random.PRNGKey] = None,
    ) -> None:
        """Apply Hadamard gate using pulse decomposition.

        Decomposes as RZ(π) · RY(π/2) followed by a correction phase.

        Args:
            wires (Union[int, List[int]]): Qubit index or indices.
            pulse_params (Optional[jnp.ndarray]): Pulse parameters for the
                composing gates. If None, uses optimized parameters.
            noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for compatibility
                (not used in this gate).
        """
        PulseGates._execute_composite("H", 0.0, wires, pulse_params)

        # Correction phase unique to the H gate
        _H = op.Hermitian(PulseGates._H_corr, wires=wires, record=False)
        H_corr = PulseGates._coeff_Sc * _H
        ys.evolve(H_corr, name="H")([0], 1)
        UnitaryGates.Noise(wires, noise_params)

    @staticmethod
    def CX(
        wires: List[int],
        pulse_params: Optional[jnp.ndarray] = None,
        noise_params: Optional[Dict[str, float]] = None,
        random_key: Optional[jax.random.PRNGKey] = None,
    ) -> None:
        """Apply CNOT gate via decomposition: H(target) · CZ · H(target).

        Args:
            wires (List[int]): Control and target qubit indices [control, target].
            pulse_params (Optional[jnp.ndarray]): Pulse parameters for the
                composing gates. If None, uses optimized parameters.
            noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for compatibility
                (not used in this gate).

        Returns:
            None: Gate is applied in-place to the circuit.
        """
        PulseGates._execute_composite("CX", 0.0, wires, pulse_params)
        UnitaryGates.Noise(wires, noise_params)

    @staticmethod
    def CY(
        wires: List[int],
        pulse_params: Optional[jnp.ndarray] = None,
        noise_params: Optional[Dict[str, float]] = None,
        random_key: Optional[jax.random.PRNGKey] = None,
    ) -> None:
        """Apply controlled-Y via decomposition.

        Args:
            wires (List[int]): Control and target qubit indices [control, target].
            pulse_params (Optional[jnp.ndarray]): Pulse parameters for the
                composing gates. If None, uses optimized parameters.
            noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for compatibility
                (not used in this gate).
        """
        PulseGates._execute_composite("CY", 0.0, wires, pulse_params)
        UnitaryGates.Noise(wires, noise_params)

    @staticmethod
    def CZ(
        wires: List[int],
        pulse_params: Optional[float] = None,
        noise_params: Optional[Dict[str, float]] = None,
        random_key: Optional[jax.random.PRNGKey] = None,
    ) -> None:
        """Apply controlled-Z using ZZ coupling Hamiltonian.

        Args:
            wires (List[int]): Control and target qubit indices.
            pulse_params (Optional[float]): Time or duration parameter for
                the pulse evolution. If None, uses ``PulseInformation.CZ.params``.
            noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for compatibility
                (not used in this gate).
        """
        if pulse_params is None:
            pulse_params = PulseInformation.CZ.params

        PulseGates._record_pulse_event("CZ", 0.0, wires, pulse_params)

        _H = op.Hermitian(PulseGates._H_CZ, wires=wires, record=False)
        H_eff = PulseGates._coeff_Scz * _H
        ys.evolve(H_eff, name="CZ")([pulse_params], 1)
        UnitaryGates.Noise(wires, noise_params)

    @staticmethod
    def CRX(
        w: float,
        wires: List[int],
        pulse_params: Optional[jnp.ndarray] = None,
        noise_params: Optional[Dict[str, float]] = None,
        random_key: Optional[jax.random.PRNGKey] = None,
    ) -> None:
        """Apply controlled-RX via decomposition.

        Args:
            w (float): Rotation angle in radians.
            wires (List[int]): Control and target qubit indices [control, target].
            pulse_params (Optional[jnp.ndarray]): Pulse parameters for the
                composing gates. If None, uses optimized parameters.
            noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
            random_key (Optional[jax.random.PRNGKey]): JAX random key used for
                gate-error sampling.
        """
        # Apply gate-error noise to the angle here, as CRY/CRZ/CPhase do:
        # the child gates run without noise_params and would otherwise never
        # see it.
        w, random_key = UnitaryGates.GateError(w, noise_params, random_key)
        PulseGates._execute_composite("CRX", w, wires, pulse_params)
        UnitaryGates.Noise(wires, noise_params)

    @staticmethod
    def CRY(
        w: float,
        wires: List[int],
        pulse_params: Optional[jnp.ndarray] = None,
        noise_params: Optional[Dict[str, float]] = None,
        random_key: Optional[jax.random.PRNGKey] = None,
    ) -> None:
        """Apply controlled-RY via decomposition.

        Args:
            w (float): Rotation angle in radians.
            wires (List[int]): Control and target qubit indices [control, target].
            pulse_params (Optional[jnp.ndarray]): Pulse parameters for the
                composing gates. If None, uses optimized parameters.
            noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
            random_key (Optional[jax.random.PRNGKey]): JAX random key used for
                gate-error sampling.
        """
        w, random_key = UnitaryGates.GateError(w, noise_params, random_key)
        PulseGates._execute_composite("CRY", w, wires, pulse_params)
        UnitaryGates.Noise(wires, noise_params)

    @staticmethod
    def CRZ(
        w: float,
        wires: List[int],
        pulse_params: Optional[jnp.ndarray] = None,
        noise_params: Optional[Dict[str, float]] = None,
        random_key: Optional[jax.random.PRNGKey] = None,
    ) -> None:
        """Apply controlled-RZ via decomposition.

        Args:
            w (float): Rotation angle in radians.
            wires (List[int]): Control and target qubit indices [control, target].
            pulse_params (Optional[jnp.ndarray]): Pulse parameters for the
                composing gates. If None, uses optimized parameters.
            noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
            random_key (Optional[jax.random.PRNGKey]): JAX random key used for
                gate-error sampling.
        """
        w, random_key = UnitaryGates.GateError(w, noise_params, random_key)
        PulseGates._execute_composite("CRZ", w, wires, pulse_params)
        UnitaryGates.Noise(wires, noise_params)

    @staticmethod
    def CPhase(
        w: float,
        wires: List[int],
        pulse_params: Optional[jnp.ndarray] = None,
        noise_params: Optional[Dict[str, float]] = None,
        random_key: Optional[jax.random.PRNGKey] = None,
    ) -> None:
        """Apply controlled phase shift via decomposition.

        Decomposes CPhase(φ) into RZ and CX gates:
        RZ(φ/2) on control, RZ(φ/2) on target, CX, RZ(-φ/2) on target, CX.

        Args:
            w (float): Phase shift angle in radians.
            wires (List[int]): Control and target qubit indices [control, target].
            pulse_params (Optional[jnp.ndarray]): Pulse parameters for the
                composing gates. If None, uses optimized parameters.
            noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
            random_key (Optional[jax.random.PRNGKey]): JAX random key used for
                gate-error sampling.
        """
        w, random_key = UnitaryGates.GateError(w, noise_params, random_key)
        PulseGates._execute_composite("CPhase", w, wires, pulse_params)
        UnitaryGates.Noise(wires, noise_params)

1520 

1521 

class PulseParamManager:
    """Sequential cursor over an array of pulse parameters.

    Each :meth:`get` call hands out the next ``n`` entries and advances the
    cursor, e.g. so that successive gates can each consume their slice of
    one shared parameter vector.

    Attributes:
        pulse_params: The parameter array being consumed.
        idx: Index of the next unconsumed entry.
    """

    def __init__(self, pulse_params: jnp.ndarray):
        self.pulse_params = pulse_params
        self.idx = 0

    def get(self, n: int):
        """Return the next n parameters and advance the cursor.

        Args:
            n: Number of entries to take; must be non-negative.

        Returns:
            ``pulse_params[idx:idx + n]`` with size-1 dimensions squeezed out
            (for a 1-D array, ``n == 1`` therefore yields a 0-d array).

        Raises:
            ValueError: If ``n`` is negative or fewer than ``n`` entries remain.
        """
        if n < 0:
            # A negative n would pass the bounds check below and move the
            # cursor backwards, silently re-issuing parameters.
            raise ValueError(
                f"Number of pulse parameters to take must be non-negative, got {n}"
            )
        if self.idx + n > len(self.pulse_params):
            raise ValueError("Not enough pulse parameters left for this gate")
        # TODO: we squeeze here to get rid of any extra hidden dimension
        params = self.pulse_params[self.idx : self.idx + n].squeeze()
        self.idx += n
        return params

1535 

1536 

# Initialise PulseInformation only after PulseGates exists so leaf defaults,
# composite trees, mirrored PulseGates flags, and coefficient functions agree.
# This runs once at import time; it must stay below the PulseGates class
# definition because reset_defaults() is expected to write PulseGates
# attributes (the mirrored flags and the active coefficient functions).
PulseInformation.reset_defaults()