Coverage for qml_essentials / drawing.py: 90%

402 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-05-16 10:19 +0000

1from fractions import Fraction 

2from typing import Any, Dict, List, Tuple, Union 

3from dataclasses import dataclass 

4from typing import Optional 

5import jax.numpy as jnp 

6import matplotlib.pyplot as plt 

7import matplotlib.patches as mpatches 

8 

9from qml_essentials.operations import ( 

10 Operation, 

11) 

12 

13 

class TikzFigure:
    """Wrapper around a ``quantikz`` LaTeX string with export helpers."""

    def __init__(self, quantikz_str: str):
        # Raw quantikz body (rows separated by ``\\``), without any
        # surrounding environment.
        self.quantikz_str = quantikz_str

    def __repr__(self):
        return self.quantikz_str

    def __str__(self):
        return self.quantikz_str

    def wrap_figure(self) -> str:
        """
        Wraps the quantikz string in a LaTeX figure environment.

        Returns:
            str: A formatted LaTeX string representing the TikZ figure containing
            the quantum circuit diagram.
        """
        return f"""
\\begin{{figure}}
    \\centering
    \\begin{{tikzpicture}}
        \\node[scale=0.85] {{
            \\begin{{quantikz}}
                {self.quantikz_str}
            \\end{{quantikz}}
        }};
    \\end{{tikzpicture}}
\\end{{figure}}"""

    def export(
        self, destination: str, full_document: bool = False, mode: str = "w"
    ) -> None:
        """
        Write the circuit as LaTeX to a file.

        Parameters
        ----------
        destination : str
            Path to the destination file.
        full_document : bool
            If ``True``, write a complete stand-alone LaTeX document
            (preamble plus the figure produced by :meth:`wrap_figure`).
            If ``False`` (default), write only the raw quantikz string
            followed by a newline.
        mode : str
            File mode passed to ``open`` (e.g. ``"w"`` to overwrite,
            ``"a"`` to append). Must be a text mode.
        """
        if full_document:
            latex_code = f"""
\\documentclass{{article}}
\\usepackage{{quantikz}}
\\usepackage{{tikz}}
\\usetikzlibrary{{quantikz2}}
\\usepackage[a3paper, landscape, margin=0.5cm]{{geometry}}
\\begin{{document}}
{self.wrap_figure()}
\\end{{document}}"""
        else:
            latex_code = self.quantikz_str + "\n"

        # Explicit encoding: gate labels may contain non-ASCII characters and
        # the platform default encoding is not guaranteed to handle them.
        with open(destination, mode, encoding="utf-8") as f:
            f.write(latex_code)

75 

76 

# Legacy namespace: older downstream code refers to the figure class as
# ``QuanTikz.TikzFigure``; exposing it here keeps those references valid.
class QuanTikz:
    """Compatibility shim exposing :class:`TikzFigure` as a class attribute."""

    TikzFigure = TikzFigure

81 

82 

def _ctrl_target_name(name: str) -> str:
    """Strip the leading control prefix from a controlled gate name.

    Every leading ``'C'`` is removed (``CRX -> RX``, ``CX -> X``,
    ``CCX -> X``). Characters after the prefix are left untouched, so a
    ``'C'`` that belongs to the target gate's own name is preserved.

    Args:
        name: Name of the controlled gate.

    Returns:
        The name of the target gate.
    """
    # ``str.replace`` would also delete any 'C' inside the target name;
    # ``lstrip`` only touches the prefix.
    return name.lstrip("C")

87 

88 

def _format_param(val: float) -> str:
    """Format a numeric parameter for text display.

    The value is divided by π and approximated by a fraction with a
    denominator of at most 100. If that fraction is zero, ``"0"`` is
    returned; if its denominator is at most 12, a π-fraction such as
    ``π/2`` or ``3π/4`` is returned. Everything else, including NaN and
    infinite values, is rendered with two decimal places.

    Args:
        val: The parameter value (anything ``float()`` accepts).

    Returns:
        The display string.
    """
    try:
        frac = Fraction(float(val) / float(jnp.pi)).limit_denominator(100)
    except (ValueError, OverflowError, ZeroDivisionError):
        # Fraction() raises ValueError for NaN and OverflowError for ±inf;
        # neither has a π-fraction, so fall back to the decimal rendering.
        return f"{float(val):.2f}"

    if frac.numerator == 0:
        return "0"
    if frac.denominator <= 12:
        if frac == Fraction(1, 1):
            return "π"
        if frac.numerator == 1:
            return f"π/{frac.denominator}"
        if frac.denominator == 1:
            return f"{frac.numerator}π"
        return f"{frac.numerator}π/{frac.denominator}"
    return f"{float(val):.2f}"

109 

110 

def _gate_label(op: Operation) -> str:
    """Return a compact text label for *op*, e.g. ``RX(π/2)`` or ``H``.

    The gate name is used as-is; when the operation carries parameters
    they are formatted with :func:`_format_param` and appended in
    parentheses, separated by ``", "``.
    """
    params = op.parameters
    if not params:
        return op.name
    rendered = [_format_param(p) for p in params]
    return f"{op.name}({', '.join(rendered)})"

119 

120 

def _tikz_param_str(val: float, op_name: str) -> str:
    """Format a rotation angle as a quantikz ``\\gate{...}`` cell.

    The angle is divided by π and approximated by a fraction with a
    denominator of at most 100. Fractions with a denominator of at most 12
    are rendered symbolically (``\\pi``, ``2\\pi``, ``\\frac{\\pi}{2}``, ...);
    all other values, including NaN and infinite angles, are rendered with
    two decimal places.

    Args:
        val: The rotation angle.
        op_name: Gate name placed in front of the angle.

    Returns:
        The LaTeX source of the gate cell.
    """
    try:
        frac = Fraction(float(val) / float(jnp.pi)).limit_denominator(100)
    except (ValueError, OverflowError, ZeroDivisionError):
        # Fraction() raises ValueError for NaN and OverflowError for ±inf.
        return f"\\gate{{{op_name}({float(val):.2f})}}"

    num, den = frac.numerator, frac.denominator
    if den > 12:
        return f"\\gate{{{op_name}({float(val):.2f})}}"
    if den == 1 and num == 1:
        return f"\\gate{{{op_name}(\\pi)}}"
    if num == 0:
        return f"\\gate{{{op_name}(0)}}"
    if den == 1:
        return f"\\gate{{{op_name}({num}\\pi)}}"
    if num == 1:
        return f"\\gate{{{op_name}\\left(\\frac{{\\pi}}{{{den}}}\\right)}}"
    return (
        f"\\gate{{{op_name}\\left("
        f"\\frac{{{num}\\pi}}{{{den}}}"
        f"\\right)}}"
    )

145 

146 

def _tikz_align_wires(circuit_tikz: List[List[str]], wires: List[int]) -> None:
    """Pad all *wires* to the same column length in-place.

    Each listed row of ``circuit_tikz`` is extended with empty cells until
    it is as long as the longest listed row. Rows not listed are left
    untouched. An empty ``wires`` list is a no-op.

    Args:
        circuit_tikz: Column grid, one list of cells per wire.
        wires: Indices of the rows to align.
    """
    # ``default=0`` keeps ``max`` from raising on an empty wire selection.
    max_len = max((len(circuit_tikz[w]) for w in wires), default=0)
    for w in wires:
        row = circuit_tikz[w]
        row.extend([""] * (max_len - len(row)))

152 

153 

def _tikz_cell_controlled(
    op: Operation,
    circuit_tikz: List[List[str]],
    param_index: int,
    gate_values: bool,
) -> int:
    """Append cells for a 2-wire controlled gate; return updated param_index.

    ``op.wires[0]`` is treated as the control and ``op.wires[1]`` as the
    target. All wires between them (inclusive) are first padded to a common
    length, then the control cell, the target cell and empty cells for the
    crossed wires are appended so the gate occupies one column.

    Args:
        op: The controlled operation.
        circuit_tikz: Column grid, one list of cells per wire (mutated).
        param_index: Running index used for symbolic ``\\theta_i`` labels.
        gate_values: If ``True``, print the numeric angle instead of a
            symbolic label.

    Returns:
        ``param_index``, incremented by one when a symbolic label was used.
    """
    ctrl_wire = op.wires[0]
    targ_wire = op.wires[1]
    # Signed row offset from control to target, as ``\ctrl`` expects.
    distance = targ_wire - ctrl_wire
    target_name = _ctrl_target_name(op.name)

    # Build target cell
    if op.parameters and target_name in ("RX", "RY", "RZ"):
        if gate_values:
            targ_cell = _tikz_param_str(float(op.parameters[0]), target_name)
        else:
            targ_cell = f"\\gate{{{target_name}(\\theta_{{{param_index}}})}}"
            param_index += 1
    elif target_name == "X":
        targ_cell = "\\targ{}"
    elif target_name == "Z":
        targ_cell = "\\control{}"
    else:
        # Includes controlled-Y: drawing it as a bare control dot would make
        # it indistinguishable from CZ, so it gets a labelled gate box.
        targ_cell = f"\\gate{{{target_name}}}"

    crossing = list(range(min(op.wires), max(op.wires) + 1))
    _tikz_align_wires(circuit_tikz, crossing)

    circuit_tikz[ctrl_wire].append(f"\\ctrl{{{distance}}}")
    circuit_tikz[targ_wire].append(targ_cell)

    # Pad intermediate wires so they stay in step with the gate's column
    for w in crossing:
        if w != ctrl_wire and w != targ_wire:
            circuit_tikz[w].append("")

    return param_index

190 

191 

def _tikz_cell_single(
    op: Operation,
    circuit_tikz: List[List[str]],
    param_index: int,
    gate_values: bool,
) -> int:
    """Add the quantikz cell of a one-wire gate to its wire's row.

    ``"Hadamard"`` is abbreviated to ``"H"``. A gate without parameters
    becomes ``\\gate{NAME}``. A parametrised gate shows either its first
    parameter as a number (``gate_values`` true) or a symbolic
    ``\\theta_i`` label that consumes one ``param_index``.

    Returns:
        The (possibly incremented) ``param_index``.
    """
    wire = op.wires[0]
    label = "H" if op.name == "Hadamard" else op.name

    if not op.parameters:
        cell = f"\\gate{{{label}}}"
    elif gate_values:
        cell = _tikz_param_str(float(op.parameters[0]), label)
    else:
        cell = f"\\gate{{{label}(\\theta_{{{param_index}}})}}"
        param_index += 1

    circuit_tikz[wire].append(cell)
    return param_index

214 

215 

def _tikz_cell_multiqubit(
    op: Operation,
    circuit_tikz: List[List[str]],
) -> None:
    """Place a labelled gate box on every wire of a multi-wire gate.

    The gate's wires are padded to a common length first so that the
    boxes end up in the same column; each wire then receives an identical
    ``\\gate{label}`` cell built from :func:`_gate_label`.
    """
    wires = list(op.wires)
    _tikz_align_wires(circuit_tikz, wires)
    cell = f"\\gate{{{_gate_label(op)}}}"
    for wire in op.wires:
        circuit_tikz[wire].append(cell)

225 

226 

def _tikz_cell_barrier(
    op: Operation,
    circuit_tikz: List[List[str]],
) -> None:
    """Bring every wire of the grid to the same length.

    Nothing is drawn for the barrier itself and ``op`` is not inspected:
    shorter rows are simply padded with empty cells, so whatever is
    appended next starts in the same column on all wires.
    """
    every_wire = list(range(len(circuit_tikz)))
    _tikz_align_wires(circuit_tikz, every_wire)

238 

239 

def _tikz_build_string(circuit_tikz: List[List[str]], n_qubits: int) -> str:
    """Render the column grid to a quantikz LaTeX string.

    All rows are first padded in place with empty cells to the length of
    the longest row. The cells of a row are joined with ``" & "``, and a
    ``" \\\\"`` row terminator plus newline is emitted after every row whose
    index is below ``n_qubits - 1``.

    Args:
        circuit_tikz: Column grid, one list of cells per wire (mutated).
        n_qubits: Number of qubits; controls where row terminators go.

    Returns:
        The quantikz body, or ``""`` for an empty grid.
    """
    if not circuit_tikz:
        # Nothing to render; ``max`` below would raise on an empty grid.
        return ""

    # Equalise wire lengths
    max_len = max(len(wire) for wire in circuit_tikz)
    for wire in circuit_tikz:
        wire.extend([""] * (max_len - len(wire)))

    row_terminator = " \\\\\n"
    pieces: List[str] = []
    for wire_idx, wire_ops in enumerate(circuit_tikz):
        pieces.append(" & ".join(wire_ops))
        if wire_idx < n_qubits - 1:
            pieces.append(row_terminator)

    return "".join(pieces)

258 

259 

def draw_tikz(
    ops: List[Operation],
    n_qubits: int,
    gate_values: bool = False,
    **kwargs: Any,
) -> Any:
    """Turn a circuit tape into ``quantikz`` LaTeX source.

    Every wire starts with a ``\\lstick{\\ket{0}}`` cell. Operations are
    then dispatched in order: controlled gates on exactly two wires,
    single-wire gates, operations named ``"Barrier"``, and finally
    everything else as a generic multi-wire gate.

    Args:
        ops: Ordered list of gate operations (noise channels excluded).
        n_qubits: Total number of qubits.
        gate_values: If ``True``, show numeric angles; otherwise use
            symbolic \\theta_i labels.
        **kwargs: Accepted but not used by this function.

    Returns:
        A :class:`~qml_essentials.drawing.TikzFigure` object.
    """
    grid: List[List[str]] = [["\\lstick{\\ket{0}}"] for _ in range(n_qubits)]
    theta_counter = 0

    for op in ops:
        if op.is_controlled and len(op.wires) == 2:
            theta_counter = _tikz_cell_controlled(
                op, grid, theta_counter, gate_values
            )
            continue
        if len(op.wires) == 1:
            theta_counter = _tikz_cell_single(op, grid, theta_counter, gate_values)
            continue
        if op.name == "Barrier":
            _tikz_cell_barrier(op, grid)
            continue
        _tikz_cell_multiqubit(op, grid)

    latex_body = _tikz_build_string(grid, n_qubits)
    return TikzFigure(latex_body)

293 

294 

def draw_text(ops: List[Operation], n_qubits: int) -> str:
    """Draw a circuit tape as plain text, one line per qubit.

    Operations are packed into columns: each one goes into the first
    column in which all of its wires are free. A controlled gate (two or
    more wires) shows ``●`` on every wire but the last and the target
    label on the last wire, and blocks every wire between its lowest and
    highest wire for that column.

    Args:
        ops: Ordered list of gate operations (noise channels excluded).
        n_qubits: Total number of qubits.

    Returns:
        Multi-line string with one row per qubit.
    """
    if not ops:
        return "\n".join(f" q{q}: ───" for q in range(n_qubits))

    # columns[i] maps wire -> cell text; next_free[w] is the first column
    # index still available on wire w.
    columns: List[Dict[int, str]] = []
    next_free: Dict[int, int] = {}

    for op in ops:
        slot = max((next_free.get(w, 0) for w in op.wires), default=0)
        while len(columns) <= slot:
            columns.append({})
        column = columns[slot]

        if op.is_controlled and len(op.wires) >= 2:
            target_name = _ctrl_target_name(op.name)
            if op.parameters:
                rendered = ", ".join(_format_param(p) for p in op.parameters)
                target_label = f"{target_name}({rendered})"
            else:
                target_label = target_name

            for control in op.wires[:-1]:
                column[control] = "●"
            column[op.wires[-1]] = target_label

            # Wires crossed by the control line are blocked as well.
            occupied = range(min(op.wires), max(op.wires) + 1)
        else:
            label = _gate_label(op)
            for w in op.wires:
                column[w] = label
            occupied = op.wires

        for w in occupied:
            next_free[w] = slot + 1

    # Each column is as wide as its longest cell (at least one character).
    widths = [
        max(max((len(text) for text in col.values()), default=1), 1)
        for col in columns
    ]

    rows = []
    for q in range(n_qubits):
        cells = [
            "─┤" + (col[q].center(width) if q in col else "─" * width)
            for col, width in zip(columns, widths)
        ]
        rows.append(f" q{q}: " + "".join(cells) + "─")

    return "\n".join(rows)

369 

370 

371# Matplotlib drawing 

372 

373 

def draw_mpl(
    ops: List[Operation],
    n_qubits: int,
    **kwargs: Any,
) -> Tuple:
    """Render a circuit tape as a Matplotlib figure.

    Operations are packed into columns with the same rule as the text
    drawer: each one goes into the first column in which all of its wires
    are free. Controlled gates draw a dot on every wire but the last, a
    labelled box on the last wire and a vertical line connecting them.
    Several controlled gates may share a column; each gets its own line.

    Args:
        ops: Ordered list of gate operations (noise channels excluded).
        n_qubits: Total number of qubits.
        **kwargs: Reserved for future options.

    Returns:
        Tuple ``(fig, ax)`` — a Matplotlib ``Figure`` and ``Axes``.
    """

    # Schedule into columns (same logic as text)
    columns: List[Dict[int, str]] = []
    wire_busy: Dict[int, int] = {}
    # Per column: one metadata entry for EACH controlled gate placed there.
    # A single dict per column would let a later controlled gate overwrite
    # an earlier one in the same column and drop its connecting line.
    ctrl_info: List[List[Dict[str, Any]]] = []

    for op in ops:
        start = max((wire_busy.get(w, 0) for w in op.wires), default=0)
        while len(columns) <= start:
            columns.append({})
            ctrl_info.append([])

        if op.is_controlled and len(op.wires) >= 2:
            ctrl_wires = op.wires[:-1]
            target_wire = op.wires[-1]
            target_name = _ctrl_target_name(op.name)
            if op.parameters:
                param_str = ", ".join(_format_param(p) for p in op.parameters)
                target_label = f"{target_name}({param_str})"
            else:
                target_label = target_name

            for cw in ctrl_wires:
                columns[start][cw] = "●"
            columns[start][target_wire] = target_label

            ctrl_info[start].append(
                {
                    "ctrl": ctrl_wires,
                    "target": target_wire,
                }
            )

            # Block every wire the control line crosses.
            all_spanned = range(min(op.wires), max(op.wires) + 1)
            for w in all_spanned:
                wire_busy[w] = start + 1
        else:
            label = _gate_label(op)
            for w in op.wires:
                columns[start][w] = label
                wire_busy[w] = start + 1

    n_cols = len(columns) if columns else 1
    fig_width = max(3.0, 1.2 * (n_cols + 2))
    fig_height = max(2.0, 0.8 * n_qubits)

    fig, ax = plt.subplots(figsize=(fig_width, fig_height))
    ax.set_xlim(-0.5, n_cols + 0.5)
    ax.set_ylim(-0.5, n_qubits - 0.5)
    ax.invert_yaxis()  # qubit 0 at the top
    ax.set_aspect("equal")
    ax.axis("off")

    # Draw qubit wires
    for q in range(n_qubits):
        ax.plot([-0.3, n_cols + 0.3], [q, q], color="black", linewidth=0.8, zorder=0)
        ax.text(
            -0.5,
            q,
            "|0⟩",
            ha="right",
            va="center",
            fontsize=10,
            fontfamily="monospace",
        )

    # Draw gates
    gate_box_h = 0.6
    gate_box_w = 0.6

    for ci, col in enumerate(columns):
        x = ci + 0.5

        # Draw one control line per controlled gate in this column
        for ci_meta in ctrl_info[ci]:
            all_wires = list(ci_meta["ctrl"]) + [ci_meta["target"]]
            y_min = min(all_wires)
            y_max = max(all_wires)
            ax.plot([x, x], [y_min, y_max], color="black", linewidth=1.0, zorder=1)

        for q, label in col.items():
            if label == "●":
                # Control dot
                ax.plot(x, q, "o", color="black", markersize=6, zorder=3)
            else:
                # Gate box; long labels get a smaller font and a wider box
                fontsize = 9 if len(label) <= 6 else 7
                bw = max(gate_box_w, len(label) * 0.09 + 0.2)
                rect = mpatches.FancyBboxPatch(
                    (x - bw / 2, q - gate_box_h / 2),
                    bw,
                    gate_box_h,
                    boxstyle="round,pad=0.05",
                    facecolor="white",
                    edgecolor="black",
                    linewidth=1.0,
                    zorder=2,
                )
                ax.add_patch(rect)
                ax.text(
                    x,
                    q,
                    label,
                    ha="center",
                    va="center",
                    fontsize=fontsize,
                    zorder=4,
                )

    fig.tight_layout()
    return fig, ax

499 

500 

@dataclass
class PulseEvent:
    """One pulse acting on one or more wires of a pulse schedule.

    Attributes:
        gate: Gate label, e.g. ``"RX"``, ``"CZ"``.
        wires: Wire(s) the pulse is applied to.
        envelope_fn: Envelope callable invoked as ``(p, t, t_c)`` and
            returning the amplitude.
        envelope_params: Parameters describing the envelope shape
            (excluding ``w`` and ``t``).
        w: Rotation angle passed to the gate.
        duration: Pulse duration (evolution time).
        carrier_phase: Phase offset of the carrier cosine.
        parent: Name of the higher-level gate this pulse was decomposed
            from, if any.
    """

    gate: str
    wires: List[int]
    envelope_fn: Any  # callable (p, t, t_c) -> scalar
    envelope_params: Any  # jnp array holding the envelope shape parameters
    w: float  # rotation angle
    duration: float  # evolution time
    carrier_phase: float = 0.0  # phi_c in cos(omega_c * t + phi_c)
    parent: Optional[str] = None  # composite gate owning this pulse

524 

525 

# Per-leaf-gate metadata used when drawing pulse schedules.
# Gates flagged ``physical`` are drawn with an envelope; the others
# (RZ, CZ) are virtual and carry no envelope.
LEAF_META = dict(
    RX=dict(carrier_phase=float(jnp.pi), physical=True),
    RY=dict(carrier_phase=float(-jnp.pi / 2), physical=True),
    RZ=dict(carrier_phase=0.0, physical=False),
    CZ=dict(carrier_phase=0.0, physical=False),
)

534 

535 

def _resolve_wires_for_drawing(wire_fn, wires_list):
    """Map a ``wire_fn`` selector string onto concrete wires.

    ``"all"`` returns ``wires_list`` itself, ``"target"`` returns a
    one-element list with the last wire (or ``wires_list`` unchanged when
    it holds at most one wire), and ``"control"`` returns a one-element
    list with the first wire.

    Raises:
        ValueError: If ``wire_fn`` is none of the three selectors.
    """
    if wire_fn == "all":
        selected = wires_list
    elif wire_fn == "target":
        selected = wires_list if len(wires_list) <= 1 else [wires_list[-1]]
    elif wire_fn == "control":
        selected = [wires_list[0]]
    else:
        raise ValueError(f"Unknown wire_fn: {wire_fn!r}")
    return selected

545 

546 

def collect_pulse_events(
    gate_name: str,
    w: Union[float, List[float]],
    wires: Union[int, List[int]],
    pulse_params: Any = None,
    parent: Optional[str] = None,
) -> List[PulseEvent]:
    """Flatten a pulse gate into the leaf :class:`PulseEvent` objects it consists of.

    The gate is looked up via ``PulseInformation.gate_by_name``. A leaf
    gate yields a single event; a composite gate is expanded recursively
    by walking its ``decomposition`` steps alongside the parameter chunks
    returned by ``split_params``. Only drawing information is gathered —
    no quantum operations are applied.

    Args:
        gate_name: Name of the gate (``"RX"``, ``"H"``, ``"CX"``, etc.).
        w: Rotation angle(s).
        wires: Qubit index or ``[control, target]``.
        pulse_params: Pulse parameters, handed to ``split_params``.
        parent: Label of the enclosing composite gate; defaults to
            ``gate_name``.

    Returns:
        Ordered list of :class:`PulseEvent` objects.

    Raises:
        ValueError: If the gate is unknown to ``PulseInformation`` or is a
            leaf gate without an entry in ``LEAF_META``.
    """
    from qml_essentials.gates import PulseEnvelope, PulseInformation

    gate_info = PulseInformation.gate_by_name(gate_name)
    if gate_info is None:
        raise ValueError(f"Unknown pulse gate: {gate_name!r}")

    wires_list = [wires] if isinstance(wires, int) else list(wires)
    parent_label = parent or gate_name

    # --- Composite gate: recurse into every decomposition step ---
    if not gate_info.is_leaf:
        parts = gate_info.split_params(pulse_params)
        collected = []
        for step, child_params in zip(gate_info.decomposition, parts):
            child_wires = _resolve_wires_for_drawing(step.wire_fn, wires_list)
            child_w = w if step.angle_fn is None else step.angle_fn(w)
            collected.extend(
                collect_pulse_events(
                    step.gate.name,
                    child_w,
                    child_wires,
                    child_params,
                    parent=parent_label,
                )
            )
        return collected

    # --- Leaf gate ---
    meta = LEAF_META.get(gate_name)
    if meta is None:
        raise ValueError(f"Unknown pulse gate: {gate_name!r}")

    info = PulseEnvelope.get(PulseInformation.get_envelope())
    pp = gate_info.split_params(pulse_params)

    if meta["physical"]:
        # All but the last entry describe the envelope; the last entry is
        # used as the pulse duration.
        env_p = pp[:-1]
        dur = float(pp[-1])
        event = PulseEvent(
            gate=gate_name,
            wires=wires_list,
            envelope_fn=info["fn"],
            envelope_params=jnp.array(env_p),
            w=float(w),
            duration=dur,
            carrier_phase=meta["carrier_phase"],
            parent=parent_label,
        )
    else:
        # Virtual gate (RZ, CZ): no envelope function, fixed unit duration.
        event = PulseEvent(
            gate=gate_name,
            wires=wires_list,
            envelope_fn=None,
            envelope_params=jnp.ravel(jnp.asarray(pp)),
            w=float(w) if not isinstance(w, list) else 0.0,
            duration=1.0,
            carrier_phase=0.0,
            parent=parent_label,
        )
    return [event]

632 

633 

def _make_event_label(gate: str, parent: Optional[str]) -> str:
    """Return the label shown for an event.

    The label is the gate name, followed by the parent in parentheses
    when a parent is given and differs from the gate.
    """
    has_distinct_parent = bool(parent) and parent != gate
    return f"{gate} ({parent})" if has_distinct_parent else gate

639 

640 

def _sample_envelope(ev: PulseEvent, t_lo: float, t_hi: float, n_samples: int):
    """Evaluate the event's scaled envelope on an evenly spaced time grid.

    ``n_samples`` points are laid out with ``jnp.linspace`` between
    ``t_lo`` and ``t_hi``; the envelope function is called once on the
    whole grid with the pulse centre ``duration / 2`` and the result is
    multiplied by ``ev.w``.

    Returns:
        ``(t_arr, signal)`` — the time grid and the scaled envelope.
    """
    times = jnp.linspace(t_lo, t_hi, n_samples)
    centre = ev.duration / 2
    envelope = ev.envelope_fn(ev.envelope_params, times, centre)
    return times, envelope * ev.w

651 

652 

def _compute_display_window(
    ev: PulseEvent,
    n_samples: int,
    envelope_width: float = 1.0,
) -> Tuple[float, float, float]:
    """Choose the local time window and amplitude used to draw a pulse.

    The window defaults to the evolution window ``[0, duration]``. It is
    kept that way when ``envelope_width`` is ``0``, when the envelope is
    (numerically) zero at the pulse centre, or when the envelope at
    ``t = 0`` is below 1% of its centre value.

    Otherwise a bisection (30 iterations, offsets between ``duration / 2``
    and ``50 * duration`` from the centre) searches for the offset at which
    the envelope, relative to its centre value, reaches
    ``edge_ratio ** 10``, where ``edge_ratio`` is the relative amplitude at
    ``t = 0``. The part of that offset beyond ``duration / 2`` is scaled by
    ``envelope_width`` and the window is placed symmetrically around the
    centre.

    Returns:
        ``(t_lo, t_hi, amp_max)`` — local time bounds and 1.1 times the
        largest absolute value of the sampled signal in that window.
    """
    t_c = ev.duration / 2

    def envelope_at(t):
        return float(ev.envelope_fn(ev.envelope_params, t, t_c))

    # Start from the plain evolution window; widen only if needed.
    t_lo, t_hi = 0.0, ev.duration

    if envelope_width != 0:
        centre_val = envelope_at(t_c)
        if not abs(centre_val) < 1e-12:
            edge_ratio = abs(envelope_at(0.0) / centre_val)
            if not edge_ratio < 0.01:
                threshold = edge_ratio**10
                lo, hi = ev.duration / 2, ev.duration * 50
                for _ in range(30):
                    mid = (lo + hi) / 2
                    if abs(envelope_at(t_c + mid) / centre_val) > threshold:
                        lo = mid
                    else:
                        hi = mid
                half_width = ev.duration / 2 + (hi - ev.duration / 2) * envelope_width
                t_lo = t_c - half_width
                t_hi = t_c + half_width

    _, signal = _sample_envelope(ev, t_lo, t_hi, n_samples)
    peak = float(jnp.max(jnp.abs(signal))) * 1.1
    return t_lo, t_hi, peak

707 

708 

def _draw_physical_pulse(
    ev: PulseEvent,
    t_start: float,
    t_lo: float,
    t_hi: float,
    axes,
    color: str,
    n_samples: int,
    omega_c: float,
    show_carrier: bool,
) -> None:
    """Plot the envelope of a physical pulse on each of its wires' axes.

    The envelope is sampled in local time over ``[t_lo, t_hi]`` and shifted
    by ``t_start`` for display. For every wire the filled envelope, dashed
    markers at the start and end of the evolution window, an optional
    carrier-modulated curve and a label at the envelope's peak are drawn.
    """
    t_arr, signal = _sample_envelope(ev, t_lo, t_hi, n_samples)
    t_display = t_arr + t_start
    window_edges = (t_start, t_start + ev.duration)
    edge_style = dict(linestyle="--", linewidth=0.8, alpha=0.7, zorder=4)

    for wire in ev.wires:
        ax = axes[wire]
        ax.fill_between(t_display, signal, alpha=0.12, color=color, zorder=2)
        ax.plot(t_display, signal, color=color, linewidth=1.4, alpha=0.85, zorder=3)

        # Dashed lines mark where the evolution window begins and ends
        for t_edge in window_edges:
            ax.axvline(t_edge, color=color, **edge_style)

        if show_carrier:
            # The carrier is evaluated on the local (unshifted) time grid
            carrier = jnp.cos(omega_c * t_arr + ev.carrier_phase)
            ax.plot(
                t_display,
                signal * carrier,
                color=color,
                linewidth=0.8,
                alpha=0.8,
                zorder=2,
            )

        peak_idx = jnp.argmax(jnp.abs(signal))
        # Label sits above a positive peak and below a negative one
        label_va = "bottom" if signal[peak_idx] >= 0 else "top"
        ax.annotate(
            _make_event_label(ev.gate, ev.parent),
            xy=(float(t_display[peak_idx]), float(signal[peak_idx])),
            fontsize=7,
            ha="center",
            va=label_va,
            color=color,
            fontweight="bold",
            zorder=5,
        )

757 

758 

def _draw_virtual_z(
    ev: PulseEvent, t_start: float, axes, color: str, amp_max: float
) -> None:
    """Mark a virtual-Z gate with a dashed vertical line and a label.

    On every wire of the event a line spanning ±60% of ``amp_max`` is
    drawn at the temporal midpoint of the event, with the event label
    placed above it at 85% of ``amp_max``.
    """
    t_mid = t_start + ev.duration / 2
    half_height = amp_max * 0.6
    label_pos = (t_mid, amp_max * 0.85)

    for wire in ev.wires:
        ax = axes[wire]
        ax.vlines(
            t_mid,
            -half_height,
            half_height,
            color=color,
            linestyle="--",
            linewidth=1.2,
            alpha=0.8,
            zorder=2,
        )
        ax.annotate(
            _make_event_label(ev.gate, ev.parent),
            xy=label_pos,
            fontsize=7,
            ha="center",
            va="bottom",
            color=color,
            zorder=5,
        )

786 

787 

def _draw_block(
    ev: PulseEvent, t_start: float, axes, color: str, amp_max: float
) -> None:
    """Draw an event as a shaded rectangle on each wire, labelled once.

    Each rectangle starts at ``t_start``, is ``ev.duration`` wide and
    covers ±60% of ``amp_max`` vertically. The label is placed only on the
    axes of the event's first wire, centred above the rectangle.
    """
    label = _make_event_label(ev.gate, ev.parent)
    corner = (t_start, -amp_max * 0.6)
    height = amp_max * 1.2

    for wire in ev.wires:
        box = mpatches.Rectangle(
            corner,
            ev.duration,
            height,
            alpha=0.2,
            facecolor=color,
            edgecolor=color,
            linewidth=1.0,
            zorder=1,
        )
        axes[wire].add_patch(box)

    first_ax = axes[ev.wires[0]]
    first_ax.annotate(
        label,
        xy=(t_start + ev.duration / 2, amp_max * 0.7),
        fontsize=7,
        ha="center",
        va="bottom",
        color=color,
        fontweight="bold",
        zorder=5,
    )

817 

818 

def draw_pulse_schedule(
    events: List[PulseEvent],
    n_qubits: int,
    n_samples: int = 200,
    show_carrier: bool = True,
    show_envelope: bool = True,
    envelope_width: float = 0.0,
    **kwargs: Any,
) -> Tuple:
    """Render a pulse schedule as a Matplotlib figure.

    Every qubit gets its own subplot row. Events are placed one after
    another: an event starts once all of its wires are free and occupies
    them for its ``duration``. RX/RY events that have an envelope function
    are drawn as filled envelope shapes when ``show_envelope`` is set;
    RZ events are drawn as dashed vertical lines; every other event is
    drawn as a shaded rectangle on each of its wires.

    Args:
        events: Ordered list of :class:`PulseEvent` from
            :func:`collect_pulse_events`.
        n_qubits: Total number of qubits.
        n_samples: Number of time samples per pulse envelope.
        show_carrier: If ``True``, overlay the carrier-modulated waveform
            (envelope x cos) as a thin line.
        show_envelope: If ``True``, draw the full envelope shape for
            physical pulses. If ``False``, show them as simple
            rectangular blocks indicating duration only.
        envelope_width: Scales how far the displayed envelope extends
            beyond the evolution window. ``0`` (the default) clamps the
            envelope exactly to the pulse duration, ``1.0`` uses the
            adaptive width, ``> 1`` widens further and ``< 1`` tightens.
        **kwargs: Only ``figsize`` is read (and passed to
            ``plt.subplots``); other entries are ignored.

    Returns:
        ``(fig, axes)`` — Matplotlib Figure and array of Axes.
    """
    from qml_essentials.gates import PulseGates, PulseInformation

    omega_c = float(PulseGates.omega_c)

    # Sequential placement: each wire remembers when it becomes free again.
    wire_cursor: Dict[int, float] = dict.fromkeys(range(n_qubits), 0.0)
    scheduled: List[Tuple[PulseEvent, float]] = []  # (event, t_start)

    for ev in events:
        t_start = max(wire_cursor[w] for w in ev.wires)
        scheduled.append((ev, t_start))
        t_end = t_start + ev.duration
        for w in ev.wires:
            wire_cursor[w] = t_end

    t_total = max(wire_cursor.values()) if wire_cursor else 1.0

    gate_colors = {
        "RX": "#1F78B4",
        "RY": "#E69F00",
        "RZ": "#009371",
        "CZ": "#ED665A",
    }

    default_figsize = (max(8, t_total * 2.5), 1.8 * n_qubits)
    fig, axes = plt.subplots(
        n_qubits,
        1,
        figsize=kwargs.pop("figsize", default_figsize),
        sharex=True,
        squeeze=False,
    )
    axes = axes.ravel()

    for q, ax in enumerate(axes[:n_qubits]):
        ax.set_ylabel(f"q{q}", rotation=0, labelpad=20, fontsize=11, va="center")
        ax.axhline(0, color="grey", linewidth=0.4, zorder=0)
        ax.set_yticks([])
        for side in ("top", "right", "left"):
            ax.spines[side].set_visible(False)

    axes[-1].set_xlabel("Time", fontsize=11)

    # Display windows (local time) per scheduled index, plus the global
    # amplitude and x-range needed to fit every drawn envelope.
    display_windows: Dict[int, Tuple[float, float]] = {}
    amp_max = 1.0
    x_lo, x_hi = 0.0, t_total

    if show_envelope:
        for idx, (ev, t_start) in enumerate(scheduled):
            is_physical = ev.envelope_fn is not None and ev.gate in ("RX", "RY")
            if not is_physical:
                continue
            t_lo, t_hi, amp = _compute_display_window(ev, n_samples, envelope_width)
            display_windows[idx] = (t_lo, t_hi)
            amp_max = max(amp_max, amp)
            # Shift the local bounds into global time
            x_lo = min(x_lo, t_lo + t_start)
            x_hi = max(x_hi, t_hi + t_start)

    x_margin = (x_hi - x_lo) * 0.05
    for q in range(n_qubits):
        axes[q].set_xlim(x_lo - x_margin, x_hi + x_margin)
        axes[q].set_ylim(-amp_max, amp_max)

    # Draw events
    for idx, (ev, t_start) in enumerate(scheduled):
        color = gate_colors.get(ev.gate, "#bab0ac")
        # A window exists exactly for the RX/RY events with an envelope
        # function when ``show_envelope`` is set.
        window = display_windows.get(idx)

        if window is not None:
            t_lo, t_hi = window
            _draw_physical_pulse(
                ev,
                t_start,
                t_lo,
                t_hi,
                axes,
                color,
                n_samples,
                omega_c,
                show_carrier,
            )
        elif ev.gate == "RZ":
            _draw_virtual_z(ev, t_start, axes, color, amp_max)
        else:
            _draw_block(ev, t_start, axes, color, amp_max)

    # Legend: one entry per known gate type that actually occurs
    used_gates = {ev.gate for ev, _ in scheduled}
    handles = [
        mpatches.Patch(color=shade, alpha=0.7, label=gate)
        for gate, shade in gate_colors.items()
        if gate in used_gates
    ]
    if handles:
        fig.legend(
            handles=handles,
            loc="lower right",
            ncol=len(handles),
            fontsize=8,
            framealpha=0.8,
        )

    fig.suptitle(
        f"Pulse Schedule ({PulseInformation.get_envelope()} envelope)",
        fontsize=12,
        fontweight="bold",
    )
    fig.tight_layout()
    return fig, axes