Coverage for qml_essentials / pulses.py: 82%

387 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-03-30 11:43 +0000

1import os 

2from dataclasses import dataclass 

3from typing import Optional, List, Union, Dict, Callable 

4import csv 

5import jax.numpy as jnp 

6import jax 

7 

8from qml_essentials import operations as op 

9from qml_essentials import yaqsi as ys 

10from qml_essentials.utils import safe_random_split 

11from qml_essentials.tape import active_pulse_tape 

12from qml_essentials.unitary import UnitaryGates 

13import logging 

14 

# Logger for this module, named after the module itself.
log = logging.getLogger(name=__name__)

16 

17 

@dataclass
class DecompositionStep:
    """A single child gate inside a composite :class:`PulseParams` tree.

    ``gate`` is the child node, ``wire_fn`` picks which of the parent's
    wires the child acts on, and ``angle_fn`` turns the parent's angle(s)
    into the child's angle.
    """

    # Child PulseParams node applied in this step.
    gate: "PulseParams"
    # Wire selector: "all", "target", or "control".
    wire_fn: str = "all"
    # Maps the parent angle(s) ``w`` to the child angle; ``None`` passes
    # ``w`` through unchanged.
    angle_fn: Optional[Callable] = None

32 

33 

class PulseParams:
    """Container for hierarchical pulse parameters.

    Leaf nodes hold direct parameters; composite nodes hold a list of
    :class:`DecompositionStep` objects that describe how the gate is
    built from simpler gates. Values are stored only on leaf nodes; a
    composite node reads and writes them through its children, so a child
    object that appears in several steps shares its parameters.

    Attributes:
        name: Gate identifier (e.g. ``"RX"``, ``"H"``).
        decomposition: List of :class:`DecompositionStep` (composite only).
    """

    def __init__(
        self,
        name: str = "",
        params: Optional[jnp.ndarray] = None,
        decomposition: Optional[List["DecompositionStep"]] = None,
    ) -> None:
        """
        Args:
            name: Gate name.
            params: Direct pulse parameters (leaf gates).
                Mutually exclusive with *decomposition*.
            decomposition: List of :class:`DecompositionStep` (composite gates).
                Mutually exclusive with *params*.

        Raises:
            AssertionError: If both or neither of *params* and
                *decomposition* are given.
        """
        assert (params is None) != (
            decomposition is None
        ), "Exactly one of `params` or `decomposition` must be provided."

        self.decomposition = decomposition
        # Derive _pulse_obj for backward compat with childs/leafs/split_params
        self._pulse_obj = (
            [step.gate for step in decomposition] if decomposition else None
        )

        if params is not None:
            self._params = params

        self.name = name

    def __len__(self) -> int:
        """
        Get the total number of pulse parameters.

        For composite gates, returns the accumulated count from all children
        (a child used in several steps is counted once per step).

        Returns:
            int: Total number of pulse parameters.
        """
        return len(self.params)

    def __getitem__(self, idx: int) -> Union[float, jnp.ndarray]:
        """
        Access pulse parameter(s) by index.

        For leaf gates, returns the parameter at the given index.
        For composite gates, returns parameters of the child at the given index.

        Args:
            idx (int): Index to access.

        Returns:
            Union[float, jnp.ndarray]: Parameter value or child parameters.
        """
        if self.is_leaf:
            return self.params[idx]
        else:
            return self.childs[idx].params

    def __str__(self) -> str:
        """Return string representation (gate name)."""
        return self.name

    def __repr__(self) -> str:
        """Return repr string (gate name)."""
        return self.name

    @property
    def is_leaf(self) -> bool:
        """Check if this is a leaf node (direct parameters, no children)."""
        return self._pulse_obj is None

    @property
    def size(self) -> int:
        """Get the total parameter count (alias for __len__)."""
        return len(self)

    @property
    def leafs(self) -> List["PulseParams"]:
        """
        Get all leaf nodes in the hierarchy.

        Recursively collects all leaf PulseParams objects in the tree.
        Duplicates are removed, keeping the order of first appearance in a
        depth-first walk of the decomposition.

        Returns:
            List[PulseParams]: List of unique leaf nodes.
        """
        if self.is_leaf:
            return [self]

        leafs = []
        for obj in self._pulse_obj:
            leafs.extend(obj.leafs)

        # dict.fromkeys de-duplicates while preserving insertion order. This
        # gives leaf_params a stable layout; a set would order leaves by
        # object id, which can vary between runs.
        return list(dict.fromkeys(leafs))

    @property
    def childs(self) -> List["PulseParams"]:
        """
        Get direct children of this node.

        Returns:
            List[PulseParams]: List of child PulseParams objects, or empty list
                if this is a leaf node.
        """
        if self.is_leaf:
            return []

        return self._pulse_obj

    @property
    def shape(self) -> List[Union[int, list]]:
        """
        Get the shape of pulse parameters.

        For leaf nodes, returns a one-element list with the parameter count.
        For composite nodes, returns one entry per child: the parameter count
        (int) for leaf children, or the child's own nested shape list for
        composite children.

        Returns:
            List[Union[int, list]]: Parameter shape specification.
        """
        if self.is_leaf:
            return [len(self.params)]

        # ``shape`` is a property (not callable). Leaf children contribute a
        # plain count; composite children contribute their nested shape.
        return [len(obj) if obj.is_leaf else obj.shape for obj in self.childs]

    @property
    def params(self) -> jnp.ndarray:
        """
        Get or compute pulse parameters.

        For leaf nodes, returns internal pulse parameters.
        For composite nodes, returns concatenated parameters from all children.

        Returns:
            jnp.ndarray: Pulse parameters array.
        """
        if self.is_leaf:
            return self._params

        params = self.split_params(params=None, leafs=False)

        return jnp.concatenate(params)

    @params.setter
    def params(self, value: jnp.ndarray) -> None:
        """
        Set pulse parameters.

        For leaf nodes, sets internal parameters directly.
        For composite nodes, distributes values across children in order.

        Args:
            value (jnp.ndarray): Pulse parameters to set.

        Raises:
            AssertionError: If value is not jnp.ndarray for leaf nodes.
        """
        if self.is_leaf:
            assert isinstance(value, jnp.ndarray), "params must be a jnp.ndarray"
            self._params = value
            return

        idx = 0
        for obj in self.childs:
            nidx = idx + obj.size
            obj.params = value[idx:nidx]
            idx = nidx

    @property
    def leaf_params(self) -> jnp.ndarray:
        """
        Get parameters from all (unique) leaf nodes.

        Returns:
            jnp.ndarray: Concatenated parameters from all leaf nodes, in the
                order given by :attr:`leafs`.
        """
        if self.is_leaf:
            return self._params

        params = self.split_params(None, leafs=True)

        return jnp.concatenate(params)

    @leaf_params.setter
    def leaf_params(self, value: jnp.ndarray) -> None:
        """
        Set parameters for all leaf nodes.

        Args:
            value (jnp.ndarray): Parameters to distribute across leaf nodes,
                in the order given by :attr:`leafs`.
        """
        if self.is_leaf:
            self._params = value
            return

        idx = 0
        for obj in self.leafs:
            nidx = idx + obj.size
            obj.params = value[idx:nidx]
            idx = nidx

    def split_params(
        self,
        params: Optional[jnp.ndarray] = None,
        leafs: bool = False,
    ) -> List[jnp.ndarray]:
        """
        Split parameters into sub-arrays for children or leaves.

        Args:
            params (Optional[jnp.ndarray]): Parameters to split. If None,
                uses internal parameters.
            leafs (bool): If True, splits across leaf nodes; if False,
                splits across direct children. Defaults to False.

        Returns:
            List[jnp.ndarray]: List of parameter arrays for children or leaves.
                For a leaf node the (unsplit) parameter array itself is
                returned.
        """
        if params is None:
            if self.is_leaf:
                return self._params

            objs = self.leafs if leafs else self.childs
            s_params = []
            for obj in objs:
                s_params.append(obj.params)

            return s_params
        else:
            if self.is_leaf:
                return params

            objs = self.leafs if leafs else self.childs
            s_params = []
            idx = 0
            for obj in objs:
                nidx = idx + obj.size
                s_params.append(params[idx:nidx])
                idx = nidx

            return s_params

291 

292 

class PulseEnvelope:
    """Registry of pulse envelope shapes.

    Each envelope is a pure function ``(p, t, t_c) -> amplitude`` that
    computes the pulse envelope *without* carrier modulation. The carrier
    ``cos(omega_c * t + phi_c)`` is applied separately in the coefficient
    functions built by :meth:`build_coeff_fns`.

    Attributes:
        REGISTRY: Mapping from envelope name to metadata dict containing
            ``fn`` (callable), ``n_envelope_params`` (int), and per-gate
            default parameter arrays.
    """

    @staticmethod
    def gaussian(p, t, t_c):
        """Gaussian envelope. ``p = [A, sigma]``."""
        A, sigma = p[0], p[1]
        return A * jnp.exp(-0.5 * ((t - t_c) / sigma) ** 2)

    @staticmethod
    def square(p, t, t_c):
        """Rectangular envelope. ``p = [A, width]``."""
        A, width = p[0], p[1]
        return A * (jnp.abs(t - t_c) <= width / 2)

    @staticmethod
    def cosine(p, t, t_c):
        """Raised cosine envelope. ``p = [A, width]``."""
        A, width = p[0], p[1]
        x = jnp.clip((t - t_c) / width, -0.5, 0.5)
        return A * jnp.cos(jnp.pi * x)

    @staticmethod
    def drag(p, t, t_c):
        """DRAG (Derivative Removal by Adiabatic Gate). ``p = [A, beta, sigma]``."""
        A, beta, sigma = p[0], p[1], p[2]
        g = A * jnp.exp(-0.5 * ((t - t_c) / sigma) ** 2)
        dg = g * (-(t - t_c) / sigma**2)
        return g + beta * dg

    @staticmethod
    def sech(p, t, t_c):
        """Hyperbolic secant envelope. ``p = [A, sigma]``."""
        A, sigma = p[0], p[1]
        return A / jnp.cosh((t - t_c) / sigma)

    # ``n_envelope_params`` counts only the envelope parameters (excluding
    # the evolution time ``t`` which is always the last element of the full
    # pulse parameter vector).
    REGISTRY = {
        "gaussian": {
            "fn": gaussian.__func__,
            "n_envelope_params": 2,
            "defaults": {
                "RX": jnp.array(
                    [30.187402725219727, 0.32704535126686096, 0.320675790309906]
                ),
                "RY": jnp.array(
                    [10.794735903531707, 0.12725685459013134, 0.3157523181268348]
                ),
            },
        },
        "square": {
            "fn": square.__func__,
            "n_envelope_params": 2,
            "defaults": {
                "RX": jnp.array([1.0, 1.0, 1.0]),
                "RY": jnp.array([1.0, 1.0, 1.0]),
            },
        },
        "cosine": {
            "fn": cosine.__func__,
            "n_envelope_params": 2,
            "defaults": {
                "RX": jnp.array([1.0, 1.0, 1.0]),
                "RY": jnp.array([1.0, 1.0, 1.0]),
            },
        },
        "drag": {
            "fn": drag.__func__,
            "n_envelope_params": 3,
            "defaults": {
                "RX": jnp.array([1.0, 1.0, 0.1, 1.0]),
                "RY": jnp.array([1.0, 1.0, 0.1, 1.0]),
            },
        },
        "sech": {
            "fn": sech.__func__,
            "n_envelope_params": 2,
            "defaults": {
                "RX": jnp.array([1.0, 1.0, 1.0]),
                "RY": jnp.array([1.0, 1.0, 1.0]),
            },
        },
        "general": {
            "fn": None,
            "n_envelope_params": 0,
            "defaults": {
                "RZ": jnp.array([0.5]),
                "CZ": jnp.array([0.31831514835357666]),
            },
        },
    }

    # Memo of (coeff_Sx, coeff_Sy) pairs keyed by (envelope_fn, omega_c).
    # Entries live for the whole process, so their code objects (and ids)
    # are never recycled.
    _COEFF_CACHE: dict = {}

    @staticmethod
    def available() -> List[str]:
        """Return list of registered envelope names."""
        return list(PulseEnvelope.REGISTRY.keys())

    @staticmethod
    def get(name: str) -> dict:
        """Look up envelope metadata by name.

        Raises:
            ValueError: If *name* is not registered.
        """
        if name not in PulseEnvelope.REGISTRY:
            raise ValueError(
                f"Unknown pulse envelope '{name}'. "
                f"Available: {PulseEnvelope.available()}"
            )
        return PulseEnvelope.REGISTRY[name]

    @staticmethod
    def build_coeff_fns(envelope_fn, omega_c):
        """Build ``(coeff_Sx, coeff_Sy)`` for a given envelope function.

        Functions defined inside another function share a single
        ``__code__`` object across calls of the enclosing function. Each
        newly built pair therefore gets private copies of its code objects
        (``CodeType.replace()``). This gives every envelope shape its own
        ``__code__`` id, so the yaqsi JIT solver cache (keyed on
        ``id(coeff_fn.__code__)``) compiles a separate XLA program per
        envelope shape.

        Pairs are memoized per ``(envelope_fn, omega_c)``. Repeated calls
        for the same envelope return the same functions and reuse compiled
        programs. If ``omega_c`` is unhashable, memoization is skipped.

        The rotation angle ``w`` is expected as the **last** element of the
        parameter array ``p`` (i.e. ``p[-1]``). Envelope parameters occupy
        ``p[:-1]`` (excluding the evolution-time element that is passed
        separately to ``ys.evolve``).

        Args:
            envelope_fn: Pure envelope function ``(p, t, t_c) -> scalar``.
            omega_c: Carrier frequency.

        Returns:
            Tuple of ``(coeff_Sx, coeff_Sy)``.

        Raises:
            ValueError: If *envelope_fn* is ``None`` (e.g. the ``"general"``
                registry entry, which has no envelope shape).
        """
        if envelope_fn is None:
            raise ValueError(
                "build_coeff_fns requires an envelope function, got None"
            )

        key = (envelope_fn, omega_c)
        try:
            cached = PulseEnvelope._COEFF_CACHE.get(key)
        except TypeError:
            # Unhashable omega_c (e.g. a JAX array): build without memoizing.
            key, cached = None, None
        if cached is not None:
            return cached

        def _coeff_Sx(p, t):
            t_c = t / 2
            env = envelope_fn(p, t, t_c)
            carrier = jnp.cos(omega_c * t + jnp.pi)
            return env * carrier * p[-1]

        def _coeff_Sy(p, t):
            t_c = t / 2
            env = envelope_fn(p, t, t_c)
            carrier = jnp.cos(omega_c * t - jnp.pi / 2)
            return env * carrier * p[-1]

        # Give this pair private code objects; otherwise every envelope would
        # share the same __code__ (and the same id) for _coeff_Sx/_coeff_Sy.
        _coeff_Sx.__code__ = _coeff_Sx.__code__.replace()
        _coeff_Sy.__code__ = _coeff_Sy.__code__.replace()

        fns = (_coeff_Sx, _coeff_Sy)
        if key is not None:
            PulseEnvelope._COEFF_CACHE[key] = fns
        return fns

451 

452 

class PulseInformation:
    """Stores pulse parameter counts and optimized pulse parameters.

    Call :meth:`set_envelope` to switch the active pulse shape. This
    rebuilds all :class:`PulseParams` trees so that parameter counts
    and defaults match the selected envelope. The gate trees (``RX``,
    ``RY``, ``RZ``, ``CZ``, ``H``, ``CX``, ...) are created as class
    attributes by :meth:`_build_leaf_gates` and
    :meth:`_build_composite_gates`.
    """

    _envelope: str = "gaussian"

    # Optimized pulse parameters loaded by update_params, keyed by gate name.
    OPTIMIZED_PULSES: Dict[str, jnp.ndarray] = {}

    @classmethod
    def _build_leaf_gates(cls):
        """(Re-)create leaf PulseParams from the active envelope defaults."""
        defaults = PulseEnvelope.get(cls._envelope)["defaults"]
        general = PulseEnvelope.get("general")["defaults"]

        cls.RX = PulseParams(name="RX", params=defaults["RX"])
        cls.RY = PulseParams(name="RY", params=defaults["RY"])

        cls.RZ = PulseParams(name="RZ", params=general["RZ"])
        cls.CZ = PulseParams(name="CZ", params=general["CZ"])

    @classmethod
    def _build_composite_gates(cls):
        """(Re-)create composite PulseParams trees from current leaves."""
        cls.H = PulseParams(
            name="H",
            decomposition=[
                DecompositionStep(cls.RZ, "all", lambda w: jnp.pi),
                DecompositionStep(cls.RY, "all", lambda w: jnp.pi / 2),
            ],
        )
        cls.CX = PulseParams(
            name="CX",
            decomposition=[
                DecompositionStep(cls.H, "target", lambda w: 0.0),
                DecompositionStep(cls.CZ, "all", lambda w: 0.0),
                DecompositionStep(cls.H, "target", lambda w: 0.0),
            ],
        )
        cls.CY = PulseParams(
            name="CY",
            decomposition=[
                DecompositionStep(cls.RZ, "target", lambda w: -jnp.pi / 2),
                DecompositionStep(cls.CX, "all"),
                DecompositionStep(cls.RZ, "target", lambda w: jnp.pi / 2),
            ],
        )
        cls.CRX = PulseParams(
            name="CRX",
            decomposition=[
                DecompositionStep(cls.RZ, "target", lambda w: jnp.pi / 2),
                DecompositionStep(cls.RY, "target", lambda w: w / 2),
                DecompositionStep(cls.CX, "all", lambda w: 0.0),
                DecompositionStep(cls.RY, "target", lambda w: -w / 2),
                DecompositionStep(cls.CX, "all", lambda w: 0.0),
                DecompositionStep(cls.RZ, "target", lambda w: -jnp.pi / 2),
            ],
        )
        cls.CRY = PulseParams(
            name="CRY",
            decomposition=[
                DecompositionStep(cls.RY, "target", lambda w: w / 2),
                DecompositionStep(cls.CX, "all", lambda w: 0.0),
                DecompositionStep(cls.RY, "target", lambda w: -w / 2),
                DecompositionStep(cls.CX, "all", lambda w: 0.0),
            ],
        )
        cls.CRZ = PulseParams(
            name="CRZ",
            decomposition=[
                DecompositionStep(cls.RZ, "target", lambda w: w / 2),
                DecompositionStep(cls.CX, "all", lambda w: 0.0),
                DecompositionStep(cls.RZ, "target", lambda w: -w / 2),
                DecompositionStep(cls.CX, "all", lambda w: 0.0),
            ],
        )
        cls.Rot = PulseParams(
            name="Rot",
            decomposition=[
                DecompositionStep(cls.RZ, "all", lambda w: w[0]),
                DecompositionStep(cls.RY, "all", lambda w: w[1]),
                DecompositionStep(cls.RZ, "all", lambda w: w[2]),
            ],
        )
        cls.unique_gate_set = [cls.RX, cls.RY, cls.RZ, cls.CZ]

    @classmethod
    def set_envelope(cls, name: str) -> None:
        """Switch pulse envelope and rebuild all PulseParams trees.

        Also updates the coefficient functions used by :class:`PulseGates`.
        The envelope is validated before any class state is modified.

        Args:
            name: One of :meth:`PulseEnvelope.available`.

        Raises:
            ValueError: If *name* is unknown, or if the envelope has no
                envelope function or lacks ``RX``/``RY`` defaults (e.g.
                ``"general"``).
        """
        info = PulseEnvelope.get(name)  # validates name
        missing = [gate for gate in ("RX", "RY") if gate not in info["defaults"]]
        if info["fn"] is None or missing:
            # Reject before mutating _envelope, so a bad name cannot leave the
            # class pointing at an envelope whose gates were never built.
            raise ValueError(
                f"Pulse envelope '{name}' cannot drive RX/RY "
                f"(envelope function: {info['fn']!r}, missing defaults: {missing})"
            )

        cls._envelope = name
        cls._build_leaf_gates()
        cls._build_composite_gates()

        # Rebuild coefficient functions on PulseGates
        coeff_Sx, coeff_Sy = PulseEnvelope.build_coeff_fns(
            info["fn"], PulseGates.omega_c
        )
        PulseGates._coeff_Sx = staticmethod(coeff_Sx)
        PulseGates._coeff_Sy = staticmethod(coeff_Sy)
        PulseGates._active_envelope = name

        log.info(f"Pulse envelope set to '{name}'")

    @classmethod
    def get_envelope(cls) -> str:
        """Return the name of the active pulse envelope."""
        return cls._envelope

    @staticmethod
    def gate_by_name(gate):
        """Return the PulseParams node for *gate* (a name or an object with
        ``__name__``), or ``None`` if no such attribute exists."""
        if isinstance(gate, str):
            return getattr(PulseInformation, gate, None)
        else:
            return getattr(PulseInformation, gate.__name__, None)

    @staticmethod
    def num_params(gate):
        """Return the total number of pulse parameters of *gate*."""
        return len(PulseInformation.gate_by_name(gate))

    @staticmethod
    def update_params(path: Optional[str] = None):
        """Load optimized pulse parameters from a CSV file.

        Each non-empty row is ``gate_name, fidelity, p_0, p_1, ...``. The
        parameters are stored in :attr:`OPTIMIZED_PULSES` under the gate
        name. A missing file is logged as an error, not raised.

        Args:
            path: CSV file to read. Defaults to
                ``<cwd>/qml_essentials/qoc_results.csv``. The working
                directory is resolved at call time, not at import time.
        """
        if path is None:
            path = f"{os.getcwd()}/qml_essentials/qoc_results.csv"

        if not os.path.isfile(path):
            log.error(f"No optimized pulses found at {path}")
            return

        log.info(f"Loading optimized pulses from {path}")
        # newline="" is what the csv module expects from file objects.
        with open(path, "r", newline="") as f:
            for row in csv.reader(f):
                if not row:  # tolerate blank lines
                    continue
                log.debug(
                    "Loading optimized pulses for %s (Fidelity: %.5f): %s",
                    row[0],
                    float(row[1]),
                    row[2:],
                )
                PulseInformation.OPTIMIZED_PULSES[row[0]] = jnp.array(
                    [float(x) for x in row[2:]]
                )

    @staticmethod
    def shuffle_params(random_key):
        """Overwrite the parameters of every gate in ``unique_gate_set``
        with uniform random samples drawn via ``jax.random.uniform``.

        Args:
            random_key: JAX PRNG key, split once per gate.
        """
        log.info(
            "Shuffling optimized pulses with random key %s of gates %s",
            random_key,
            PulseInformation.unique_gate_set,
        )
        for gate in PulseInformation.unique_gate_set:
            random_key, sub_key = safe_random_split(random_key)
            gate.params = jax.random.uniform(sub_key, (len(gate),))

607 

608 

# Build the default (gaussian) PulseParams trees at import time. Leaves come
# first because the composite trees hold references to them.
PulseInformation._build_leaf_gates()
PulseInformation._build_composite_gates()

612 

613 

class PulseGates:
    """Pulse-level implementations of quantum gates.

    Gates are realised as time evolutions under time-dependent Hamiltonians,
    following https://doi.org/10.5445/IR/1000184129. The envelope used by the
    physical RX/RY drives is chosen with
    :meth:`PulseInformation.set_envelope`.

    Attributes:
        omega_q: Qubit frequency (10π).
        omega_c: Carrier frequency (10π).
        _active_envelope: Name of the currently active envelope shape.
    """

    # NOTE: Implementation of S, RX, RY, RZ, CZ, CNOT/CX and H pulse level
    # gates closely follow https://doi.org/10.5445/IR/1000184129
    omega_q = 10 * jnp.pi
    omega_c = 10 * jnp.pi

    # Identity and Pauli matrices (X and Z use jnp.array's default dtypes).
    Id = jnp.eye(2, dtype=jnp.complex64)
    X = jnp.array([[0, 1], [1, 0]])
    Y = jnp.array([[0, -1j], [1j, 0]])
    Z = jnp.array([[1, 0], [0, -1]])

    # diag(e^{i omega_q / 2}, e^{-i omega_q / 2}); the X/Y drive operators
    # are X and Y conjugated by this matrix.
    H_static = jnp.array(
        [[jnp.exp(1j * omega_q / 2), 0], [0, jnp.exp(-1j * omega_q / 2)]]
    )
    _H_X = H_static.conj().T @ X @ H_static
    _H_Y = H_static.conj().T @ Y @ H_static

    # Two-qubit CZ generator: (π/4)(II - ZI - IZ + ZZ) = diag(0, 0, 0, π).
    _H_CZ = (jnp.pi / 4) * (
        jnp.kron(Id, Id) - jnp.kron(Z, Id) - jnp.kron(Id, Z) + jnp.kron(Z, Z)
    )

    # Operator for the Hadamard correction phase.
    _H_corr = jnp.pi / 2 * jnp.eye(2, dtype=jnp.complex64)

    _active_envelope: str = "gaussian"

652 

653 @staticmethod 

654 def _coeff_Sx(p, t): 

655 """Coefficient function for RX pulse (active envelope).""" 

656 t_c = t / 2 

657 env = PulseEnvelope.gaussian(p, t, t_c) 

658 carrier = jnp.cos(PulseGates.omega_c * t + jnp.pi) 

659 return env * carrier * p[-1] 

660 

661 @staticmethod 

662 def _coeff_Sy(p, t): 

663 """Coefficient function for RY pulse (active envelope).""" 

664 t_c = t / 2 

665 env = PulseEnvelope.gaussian(p, t, t_c) 

666 carrier = jnp.cos(PulseGates.omega_c * t - jnp.pi / 2) 

667 return env * carrier * p[-1] 

668 

669 @staticmethod 

670 def _coeff_Sz(p, t): 

671 """Coefficient function for RZ pulse: p * w.""" 

672 return p[0] * p[1] 

673 

674 @staticmethod 

675 def _coeff_Sc(p, t): 

676 """Constant coefficient for H correction phase.""" 

677 return -1.0 

678 

679 @staticmethod 

680 def _coeff_Scz(p, t): 

681 """Coefficient function for CZ pulse.""" 

682 return p * jnp.pi 

683 

684 @staticmethod 

685 def _record_pulse_event(gate_name, w, wires, pulse_params, parent=None): 

686 """Append a PulseEvent to the active pulse tape if recording. 

687 

688 This is called from leaf gate methods (RX, RY, RZ, CZ) so that 

689 :func:`~qml_essentials.tape.pulse_recording` can collect events 

690 without the caller needing to know about the tape. 

691 """ 

692 ptape = active_pulse_tape() 

693 if ptape is None: 

694 return 

695 

696 from qml_essentials.drawing import PulseEvent, LEAF_META 

697 

698 meta = LEAF_META.get(gate_name, {}) 

699 wires_list = [wires] if isinstance(wires, int) else list(wires) 

700 

701 if meta.get("physical", False): 

702 info = PulseEnvelope.get(PulseInformation.get_envelope()) 

703 pp = PulseInformation.gate_by_name(gate_name).split_params(pulse_params) 

704 env_p = pp[:-1] 

705 dur = float(pp[-1]) 

706 ptape.append( 

707 PulseEvent( 

708 gate=gate_name, 

709 wires=wires_list, 

710 envelope_fn=info["fn"], 

711 envelope_params=jnp.array(env_p), 

712 w=float(w), 

713 duration=dur, 

714 carrier_phase=meta["carrier_phase"], 

715 parent=parent, 

716 ) 

717 ) 

718 else: 

719 pp = PulseInformation.gate_by_name(gate_name).split_params(pulse_params) 

720 ptape.append( 

721 PulseEvent( 

722 gate=gate_name, 

723 wires=wires_list, 

724 envelope_fn=None, 

725 envelope_params=jnp.ravel(jnp.asarray(pp)), 

726 w=float(w) if not isinstance(w, list) else 0.0, 

727 duration=1.0, 

728 carrier_phase=0.0, 

729 parent=parent, 

730 ) 

731 ) 

732 

733 @staticmethod 

734 def Rot( 

735 phi: float, 

736 theta: float, 

737 omega: float, 

738 wires: Union[int, List[int]], 

739 pulse_params: Optional[jnp.ndarray] = None, 

740 noise_params: Optional[Dict[str, float]] = None, 

741 random_key: Optional[jax.random.PRNGKey] = None, 

742 ) -> None: 

743 """ 

744 Apply general rotation via decomposition: RZ(phi) · RY(theta) · RZ(omega). 

745 

746 Args: 

747 phi (float): First rotation angle. 

748 theta (float): Second rotation angle. 

749 omega (float): Third rotation angle. 

750 wires (Union[int, List[int]]): Qubit index or indices to apply rotation to. 

751 pulse_params (Optional[jnp.ndarray]): Pulse parameters for the 

752 composing gates. If None, uses optimized parameters. 

753 noise_params (Optional[Dict[str, float]]): Noise parameters dictionary. 

754 random_key (Optional[jax.random.PRNGKey]): JAX random key for compatibility 

755 

756 Returns: 

757 None: Gates are applied in-place to the circuit. 

758 """ 

759 if noise_params is not None and "GateError" in noise_params: 

760 phi, random_key = UnitaryGates.GateError(phi, noise_params, random_key) 

761 theta, random_key = UnitaryGates.GateError(theta, noise_params, random_key) 

762 omega, random_key = UnitaryGates.GateError(omega, noise_params, random_key) 

763 PulseGates._execute_composite("Rot", [phi, theta, omega], wires, pulse_params) 

764 UnitaryGates.Noise(wires, noise_params) 

765 

766 @staticmethod 

767 def PauliRot( 

768 pauli: str, 

769 theta: float, 

770 wires: Union[int, List[int]], 

771 pulse_params: Optional[jnp.ndarray] = None, 

772 noise_params: Optional[Dict[str, float]] = None, 

773 random_key: Optional[jax.random.PRNGKey] = None, 

774 ) -> None: 

775 """Not implemented as a PulseGate.""" 

776 raise NotImplementedError("PauliRot gate is not implemented as PulseGate") 

777 

778 @staticmethod 

779 def RX( 

780 w: float, 

781 wires: Union[int, List[int]], 

782 pulse_params: Optional[jnp.ndarray] = None, 

783 noise_params: Optional[Dict[str, float]] = None, 

784 random_key: Optional[jax.random.PRNGKey] = None, 

785 ) -> None: 

786 """Apply X-axis rotation using the active pulse envelope. 

787 

788 Args: 

789 w: Rotation angle in radians. 

790 wires: Qubit index or indices. 

791 pulse_params: Envelope parameters ``[env_0, ..., env_n, t]``. 

792 If ``None``, uses optimized defaults. 

793 noise_params (Optional[Dict[str, float]]): Noise parameters dictionary. 

794 random_key (Optional[jax.random.PRNGKey]): JAX random key for compatibility 

795 """ 

796 pulse_params = PulseInformation.RX.split_params(pulse_params) 

797 

798 PulseGates._record_pulse_event("RX", w, wires, pulse_params) 

799 

800 _H = op.Hermitian(PulseGates._H_X, wires=wires, record=False) 

801 H_eff = PulseGates._coeff_Sx * _H 

802 

803 # Pack: [envelope_params..., w] - evolution time is the last element 

804 # of pulse_params (pulse_params[-1]). 

805 w, random_key = UnitaryGates.GateError(w, noise_params, random_key) 

806 env_params = jnp.array([*pulse_params[:-1], w]) 

807 ys.evolve(H_eff, name="RX")([env_params], pulse_params[-1]) 

808 UnitaryGates.Noise(wires, noise_params) 

809 

810 @staticmethod 

811 def RY( 

812 w: float, 

813 wires: Union[int, List[int]], 

814 pulse_params: Optional[jnp.ndarray] = None, 

815 noise_params: Optional[Dict[str, float]] = None, 

816 random_key: Optional[jax.random.PRNGKey] = None, 

817 ) -> None: 

818 """Apply Y-axis rotation using the active pulse envelope. 

819 

820 Args: 

821 w: Rotation angle in radians. 

822 wires: Qubit index or indices. 

823 pulse_params: Envelope parameters ``[env_0, ..., env_n, t]``. 

824 If ``None``, uses optimized defaults. 

825 noise_params (Optional[Dict[str, float]]): Noise parameters dictionary. 

826 random_key (Optional[jax.random.PRNGKey]): JAX random key for compatibility 

827 """ 

828 pulse_params = PulseInformation.RY.split_params(pulse_params) 

829 

830 PulseGates._record_pulse_event("RY", w, wires, pulse_params) 

831 

832 _H = op.Hermitian(PulseGates._H_Y, wires=wires, record=False) 

833 H_eff = PulseGates._coeff_Sy * _H 

834 

835 # Pack w into the params so the coefficient function doesn't need 

836 # a closure - this enables JIT solver cache sharing across all RY calls. 

837 w, random_key = UnitaryGates.GateError(w, noise_params, random_key) 

838 env_params = jnp.array([*pulse_params[:-1], w]) 

839 ys.evolve(H_eff, name="RY")([env_params], pulse_params[-1]) 

840 UnitaryGates.Noise(wires, noise_params) 

841 

842 @staticmethod 

843 def RZ( 

844 w: float, 

845 wires: Union[int, List[int]], 

846 pulse_params: Optional[float] = None, 

847 noise_params: Optional[Dict[str, float]] = None, 

848 random_key: Optional[jax.random.PRNGKey] = None, 

849 ) -> None: 

850 """ 

851 Apply Z-axis rotation using pulse-level implementation. 

852 

853 Implements RZ rotation using virtual Z rotations (phase tracking) 

854 without physical pulse application. 

855 

856 Args: 

857 w (float): Rotation angle in radians. 

858 wires (Union[int, List[int]]): Qubit index or indices to apply rotation to. 

859 pulse_params (Optional[float]): Duration parameter for the pulse. 

860 Rotation angle = w * 2 * pulse_params. Defaults to 0.5 if None. 

861 noise_params (Optional[Dict[str, float]]): Noise parameters dictionary. 

862 random_key (Optional[jax.random.PRNGKey]): JAX random key for compatibility 

863 

864 Returns: 

865 None: Gate is applied in-place to the circuit. 

866 """ 

867 pulse_params = PulseInformation.RZ.split_params(pulse_params) 

868 

869 PulseGates._record_pulse_event("RZ", w, wires, pulse_params) 

870 

871 _H = op.Hermitian(PulseGates.Z, wires=wires, record=False) 

872 H_eff = PulseGates._coeff_Sz * _H 

873 

874 # Pack w into the params so the coefficient function doesn't need 

875 # a closure - [pulse_param_scalar, w] enables JIT solver cache sharing. 

876 # pulse_params may be a 1-element array or scalar; ravel + index to 

877 # ensure a scalar for concatenation. 

878 w, random_key = UnitaryGates.GateError(w, noise_params, random_key) 

879 pp_scalar = jnp.ravel(jnp.asarray(pulse_params))[0] 

880 ys.evolve(H_eff, name="RZ")([jnp.array([pp_scalar, w])], 1) 

881 

882 UnitaryGates.Noise(wires, noise_params) 

883 

884 @staticmethod 

885 def _resolve_wires(wire_fn, wires): 

886 """Resolve a wire selector string to actual wire(s). 

887 

888 Args: 

889 wire_fn: ``"all"``, ``"target"``, or ``"control"``. 

890 wires: Parent gate's wire(s) (int or list). 

891 

892 Returns: 

893 Wire(s) for the child gate. 

894 """ 

895 wires_list = [wires] if isinstance(wires, int) else list(wires) 

896 if wire_fn == "all": 

897 return wires if len(wires_list) > 1 else wires_list[0] 

898 if wire_fn == "target": 

899 return wires_list[-1] if len(wires_list) > 1 else wires_list[0] 

900 if wire_fn == "control": 

901 return wires_list[0] 

902 raise ValueError(f"Unknown wire_fn: {wire_fn!r}") 

903 

904 @staticmethod 

905 def _execute_composite(gate_name, w, wires, pulse_params=None): 

906 """Execute a composite gate by walking its decomposition. 

907 

908 Reads the :class:`DecompositionStep` list from 

909 :class:`PulseInformation` and dispatches each step to the 

910 appropriate ``PulseGates`` method. 

911 

912 Args: 

913 gate_name: Gate name (e.g. ``"H"``, ``"CX"``). 

914 w: Rotation angle(s) passed to the parent gate. 

915 wires: Wire(s) of the parent gate. 

916 pulse_params: Optional pulse parameters (split across children). 

917 """ 

918 pp_obj = PulseInformation.gate_by_name(gate_name) 

919 parts = pp_obj.split_params(pulse_params) 

920 

921 for step, child_params in zip(pp_obj.decomposition, parts): 

922 child_wires = PulseGates._resolve_wires(step.wire_fn, wires) 

923 child_w = step.angle_fn(w) if step.angle_fn is not None else w 

924 child_gate = getattr(PulseGates, step.gate.name) 

925 

926 # Leaf gates that take a rotation angle 

927 if step.gate.name in ("RX", "RY", "RZ"): 

928 child_gate(child_w, wires=child_wires, pulse_params=child_params) 

929 # Leaf gates without a rotation angle 

930 elif step.gate.name in ("CZ",): 

931 child_gate(wires=child_wires, pulse_params=child_params) 

932 # Composite gates with a rotation angle (CRX, CRY, CRZ, Rot, ...) 

933 elif step.gate.name in ("Rot",): 

934 # Rot expects (phi, theta, omega, wires, ...) 

935 child_gate(*child_w, wires=child_wires, pulse_params=child_params) 

936 elif step.gate.decomposition is not None and step.gate.name in ( 

937 "CRX", 

938 "CRY", 

939 "CRZ", 

940 ): 

941 child_gate(child_w, wires=child_wires, pulse_params=child_params) 

942 # Other composite gates (H, CX, CY, ...) 

943 else: 

944 child_gate(wires=child_wires, pulse_params=child_params) 

945 

946 @staticmethod 

947 def H( 

948 wires: Union[int, List[int]], 

949 pulse_params: Optional[jnp.ndarray] = None, 

950 noise_params: Optional[Dict[str, float]] = None, 

951 random_key: Optional[jax.random.PRNGKey] = None, 

952 ) -> None: 

953 """Apply Hadamard gate using pulse decomposition. 

954 

955 Decomposes as RZ(π) · RY(π/2) followed by a correction phase. 

956 

957 Args: 

958 wires (Union[int, List[int]]): Qubit index or indices. 

959 noise_params (Optional[Dict[str, float]]): Noise parameters dictionary. 

960 random_key (Optional[jax.random.PRNGKey]): JAX random key for compatibility 

961 (not used in this gate). 

962 """ 

963 PulseGates._execute_composite("H", 0.0, wires, pulse_params) 

964 

965 # Correction phase unique to the H gate 

966 _H = op.Hermitian(PulseGates._H_corr, wires=wires, record=False) 

967 H_corr = PulseGates._coeff_Sc * _H 

968 ys.evolve(H_corr, name="H")([0], 1) 

969 UnitaryGates.Noise(wires, noise_params) 

970 

@staticmethod
def CX(
    wires: List[int],
    pulse_params: Optional[jnp.ndarray] = None,
    noise_params: Optional[Dict[str, float]] = None,
    random_key: Optional[jax.random.PRNGKey] = None,
) -> None:
    """Apply a pulse-level CNOT, decomposed as H(target) · CZ · H(target).

    Args:
        wires (List[int]): ``[control, target]`` qubit indices.
        pulse_params (Optional[jnp.ndarray]): Pulse parameters of the
            composing gates; None selects the optimized parameters.
        noise_params (Optional[Dict[str, float]]): Noise parameters applied
            after the gate via ``UnitaryGates.Noise``.
        random_key (Optional[jax.random.PRNGKey]): Accepted for interface
            compatibility; this gate does not use it.

    Returns:
        None: The gate is applied in-place to the circuit.
    """
    no_angle = 0.0  # CX is not parameterised by a rotation angle
    PulseGates._execute_composite("CX", no_angle, wires, pulse_params)
    UnitaryGates.Noise(wires, noise_params)

993 

@staticmethod
def CY(
    wires: List[int],
    pulse_params: Optional[jnp.ndarray] = None,
    noise_params: Optional[Dict[str, float]] = None,
    random_key: Optional[jax.random.PRNGKey] = None,
) -> None:
    """Apply a pulse-level controlled-Y through its registered decomposition.

    Args:
        wires (List[int]): ``[control, target]`` qubit indices.
        pulse_params (Optional[jnp.ndarray]): Pulse parameters of the
            composing gates; None selects the optimized parameters.
        noise_params (Optional[Dict[str, float]]): Noise parameters applied
            after the gate via ``UnitaryGates.Noise``.
        random_key (Optional[jax.random.PRNGKey]): Accepted for interface
            compatibility; this gate does not use it.
    """
    no_angle = 0.0  # CY is not parameterised by a rotation angle
    PulseGates._execute_composite("CY", no_angle, wires, pulse_params)
    UnitaryGates.Noise(wires, noise_params)

1014 

@staticmethod
def CZ(
    wires: List[int],
    pulse_params: Optional[float] = None,
    noise_params: Optional[Dict[str, float]] = None,
    random_key: Optional[jax.random.PRNGKey] = None,
) -> None:
    """Apply a controlled-Z by evolving under the ZZ coupling Hamiltonian.

    Args:
        wires (List[int]): Control and target qubit indices.
        pulse_params (Optional[float]): Time/duration parameter for the pulse
            evolution. None falls back to ``PulseInformation.CZ.params``.
        noise_params (Optional[Dict[str, float]]): Noise parameters applied
            after the gate via ``UnitaryGates.Noise``.
        random_key (Optional[jax.random.PRNGKey]): Accepted for interface
            compatibility; this gate does not use it.
    """
    duration = (
        PulseInformation.CZ.params if pulse_params is None else pulse_params
    )

    # CZ is a leaf gate, so it records its own pulse event.
    PulseGates._record_pulse_event("CZ", 0.0, wires, duration)

    zz_op = op.Hermitian(PulseGates._H_CZ, wires=wires, record=False)
    scaled_zz = PulseGates._coeff_Scz * zz_op
    ys.evolve(scaled_zz, name="CZ")([duration], 1)
    UnitaryGates.Noise(wires, noise_params)

1042 

@staticmethod
def CRX(
    w: float,
    wires: List[int],
    pulse_params: Optional[jnp.ndarray] = None,
    noise_params: Optional[Dict[str, float]] = None,
    random_key: Optional[jax.random.PRNGKey] = None,
) -> None:
    """Apply controlled-RX via decomposition.

    Args:
        w (float): Rotation angle in radians.
        wires (List[int]): Control and target qubit indices [control, target].
        pulse_params (Optional[jnp.ndarray]): Pulse parameters for the
            composing gates. If None, uses optimized parameters.
        noise_params (Optional[Dict[str, float]]): Noise parameters
            dictionary. Passed to ``UnitaryGates.GateError`` to perturb the
            angle and to ``UnitaryGates.Noise`` after the gate.
        random_key (Optional[jax.random.PRNGKey]): JAX random key passed to
            ``UnitaryGates.GateError`` together with ``noise_params``.
    """
    # Apply the gate-error perturbation to the angle before decomposing, as
    # CRY and CRZ do. The composite's child gates are called without
    # noise_params, so this is the only point where it can be applied.
    w, random_key = UnitaryGates.GateError(w, noise_params, random_key)
    PulseGates._execute_composite("CRX", w, wires, pulse_params)
    UnitaryGates.Noise(wires, noise_params)

1064 

@staticmethod
def CRY(
    w: float,
    wires: List[int],
    pulse_params: Optional[jnp.ndarray] = None,
    noise_params: Optional[Dict[str, float]] = None,
    random_key: Optional[jax.random.PRNGKey] = None,
) -> None:
    """Apply controlled-RY via decomposition.

    Args:
        w (float): Rotation angle in radians.
        wires (List[int]): Control and target qubit indices [control, target].
        pulse_params (Optional[jnp.ndarray]): Pulse parameters for the
            composing gates. If None, uses optimized parameters.
        noise_params (Optional[Dict[str, float]]): Noise parameters
            dictionary. Passed to ``UnitaryGates.GateError`` to perturb the
            angle and to ``UnitaryGates.Noise`` after the gate.
        random_key (Optional[jax.random.PRNGKey]): JAX random key passed to
            ``UnitaryGates.GateError`` together with ``noise_params``; the
            key it returns is not used further in this gate.
    """
    # Perturb the angle once at the composite level; the child gates are
    # executed without noise_params.
    w, random_key = UnitaryGates.GateError(w, noise_params, random_key)
    PulseGates._execute_composite("CRY", w, wires, pulse_params)
    UnitaryGates.Noise(wires, noise_params)

1086 

@staticmethod
def CRZ(
    w: float,
    wires: List[int],
    pulse_params: Optional[jnp.ndarray] = None,
    noise_params: Optional[Dict[str, float]] = None,
    random_key: Optional[jax.random.PRNGKey] = None,
) -> None:
    """Apply controlled-RZ via decomposition.

    Args:
        w (float): Rotation angle in radians.
        wires (List[int]): Control and target qubit indices [control, target].
        pulse_params (Optional[jnp.ndarray]): Pulse parameters for the
            composing gates. If None, uses optimized parameters.
        noise_params (Optional[Dict[str, float]]): Noise parameters
            dictionary. Passed to ``UnitaryGates.GateError`` to perturb the
            angle and to ``UnitaryGates.Noise`` after the gate.
        random_key (Optional[jax.random.PRNGKey]): JAX random key passed to
            ``UnitaryGates.GateError`` together with ``noise_params``; the
            key it returns is not used further in this gate.
    """
    # Perturb the angle once at the composite level; the child gates are
    # executed without noise_params.
    w, random_key = UnitaryGates.GateError(w, noise_params, random_key)
    PulseGates._execute_composite("CRZ", w, wires, pulse_params)
    UnitaryGates.Noise(wires, noise_params)

1108 

1109 

class PulseParamManager:
    """Sequential cursor over an array of pulse parameters.

    Each :meth:`get` call hands out the next ``n`` entries (along the first
    axis) and advances the cursor. Consecutive calls therefore receive
    disjoint, consecutive slices of ``pulse_params``.
    """

    def __init__(self, pulse_params: jnp.ndarray):
        """
        Args:
            pulse_params: Parameter array; slices are taken along its first
                axis.
        """
        self.pulse_params = pulse_params
        # Index of the next unconsumed entry along the first axis.
        self.idx = 0

    def get(self, n: int):
        """Return the next n parameters and advance the cursor.

        Args:
            n: Number of entries to take along the first axis. Must be
                non-negative.

        Returns:
            ``pulse_params[idx:idx + n]`` with all size-1 dimensions squeezed
            out. For a 1-D array, a single entry comes back as a 0-d array.

        Raises:
            ValueError: If ``n`` is negative, or fewer than ``n`` entries
                remain. The cursor is left unchanged in both cases.
        """
        # A negative n would produce an empty slice and silently move the
        # cursor backwards, so it is rejected explicitly.
        if n < 0:
            raise ValueError(
                f"Number of requested pulse parameters must be non-negative, got {n}"
            )
        if self.idx + n > len(self.pulse_params):
            raise ValueError("Not enough pulse parameters left for this gate")
        # TODO: we squeeze here to get rid of any extra hidden dimension
        params = self.pulse_params[self.idx : self.idx + n].squeeze()
        self.idx += n
        return params