Coverage for qml_essentials / drawing.py: 90%
402 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-05-16 10:19 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-05-16 10:19 +0000
1from fractions import Fraction
2from typing import Any, Dict, List, Tuple, Union
3from dataclasses import dataclass
4from typing import Optional
5import jax.numpy as jnp
6import matplotlib.pyplot as plt
7import matplotlib.patches as mpatches
9from qml_essentials.operations import (
10 Operation,
11)
class TikzFigure:
    """Wrapper around a ``quantikz`` LaTeX string with export helpers."""

    def __init__(self, quantikz_str: str):
        self.quantikz_str = quantikz_str

    def __repr__(self):
        return self.quantikz_str

    def __str__(self):
        return self.quantikz_str

    def wrap_figure(self) -> str:
        """
        Wraps the quantikz string in a LaTeX figure environment.

        Returns:
            str: A formatted LaTeX string representing the TikZ figure containing
            the quantum circuit diagram.
        """
        return f"""
\\begin{{figure}}
    \\centering
    \\begin{{tikzpicture}}
        \\node[scale=0.85] {{
            \\begin{{quantikz}}
                {self.quantikz_str}
            \\end{{quantikz}}
        }};
    \\end{{tikzpicture}}
\\end{{figure}}"""

    def export(
        self, destination: str, full_document: bool = False, mode: str = "w"
    ) -> None:
        """
        Write the quantum circuit (stick notation) to a LaTeX file.

        Parameters
        ----------
        destination : str
            Path to the destination file.
        full_document : bool, optional
            If ``True``, write a standalone LaTeX document (article class,
            A3 landscape page) with the circuit wrapped in a figure
            environment, see :meth:`wrap_figure`. If ``False`` (default),
            write only the raw ``quantikz`` body followed by a newline.
        mode : str, optional
            Text file mode passed to :func:`open`, e.g. ``"w"`` (default)
            to overwrite or ``"a"`` to append.
        """
        if full_document:
            latex_code = f"""
\\documentclass{{article}}
\\usepackage{{quantikz}}
\\usepackage{{tikz}}
\\usetikzlibrary{{quantikz2}}
\\usepackage[a3paper, landscape, margin=0.5cm]{{geometry}}
\\begin{{document}}
{self.wrap_figure()}
\\end{{document}}"""
        else:
            latex_code = self.quantikz_str + "\n"

        # Gate labels may contain non-ASCII characters (e.g. "π"), so write
        # UTF-8 explicitly instead of relying on the platform default.
        with open(destination, mode, encoding="utf-8") as f:
            f.write(latex_code)
# Backwards-compatible alias so existing ``QuanTikz.TikzFigure`` references
# keep working without changes in downstream code.
class QuanTikz:
    """Compatibility namespace; ``QuanTikz.TikzFigure`` is :class:`TikzFigure`."""

    TikzFigure = TikzFigure
83def _ctrl_target_name(name: str) -> str:
84 """Strip the leading 'C' from a controlled gate name to get the target name."""
85 # CRX -> RX, CX -> X, etc.
86 return name.replace("C", "")
89def _format_param(val: float) -> str:
90 """Format a numeric parameter for text display.
92 Shows nice π-fractions when possible, otherwise 2 decimal places.
93 """
94 try:
95 frac = Fraction(float(val) / float(jnp.pi)).limit_denominator(100)
96 if frac.numerator == 0:
97 return "0"
98 if frac.denominator <= 12:
99 if frac == Fraction(1, 1):
100 return "π"
101 if frac.numerator == 1:
102 return f"π/{frac.denominator}"
103 if frac.denominator == 1:
104 return f"{frac.numerator}π"
105 return f"{frac.numerator}π/{frac.denominator}"
106 except (ValueError, ZeroDivisionError):
107 pass
108 return f"{float(val):.2f}"
def _gate_label(op: Operation) -> str:
    """Return a compact display label for *op*.

    Unparametrised gates are labelled by their name alone (``"H"``);
    parametrised gates append their formatted parameters in parentheses
    (``"RX(π/2)"``).
    """
    if not op.parameters:
        return op.name
    rendered = ", ".join(map(_format_param, op.parameters))
    return f"{op.name}({rendered})"
121def _tikz_param_str(val: float, op_name: str) -> str:
122 """Format a rotation angle as a LaTeX string for quantikz gates."""
123 try:
124 frac = Fraction(float(val) / float(jnp.pi)).limit_denominator(100)
125 if frac.denominator > 12:
126 return f"\\gate{{{op_name}({float(val):.2f})}}"
127 if frac.denominator == 1 and frac.numerator == 1:
128 return f"\\gate{{{op_name}(\\pi)}}"
129 if frac.numerator == 0:
130 return f"\\gate{{{op_name}(0)}}"
131 if frac.denominator == 1:
132 return f"\\gate{{{op_name}({frac.numerator}\\pi)}}"
133 if frac.numerator == 1:
134 return (
135 f"\\gate{{{op_name}\\left("
136 f"\\frac{{\\pi}}{{{frac.denominator}}}\\right)}}"
137 )
138 return (
139 f"\\gate{{{op_name}\\left("
140 f"\\frac{{{frac.numerator}\\pi}}{{{frac.denominator}}}"
141 f"\\right)}}"
142 )
143 except (ValueError, ZeroDivisionError):
144 return f"\\gate{{{op_name}({float(val):.2f})}}"
147def _tikz_align_wires(circuit_tikz: List[List[str]], wires: List[int]) -> None:
148 """Pad all *wires* to the same column length in-place."""
149 max_len = max(len(circuit_tikz[w]) for w in wires)
150 for w in wires:
151 circuit_tikz[w].extend("" for _ in range(max_len - len(circuit_tikz[w])))
def _tikz_cell_controlled(
    op: Operation,
    circuit_tikz: List[List[str]],
    param_index: int,
    gate_values: bool,
) -> int:
    """Append cells for a 2-wire controlled gate; return updated param_index.

    All wires from the lowest to the highest wire of the gate are first
    aligned to a common column. The control wire then receives
    ``\\ctrl{d}``, where ``d`` is the signed offset from control to target.
    The target wire receives the target symbol, and every wire strictly in
    between receives an empty cell.

    Target symbols:
        * RX/RY/RZ with parameters: a rotation gate box, numeric or
          symbolic (``\\theta_i``) depending on ``gate_values``;
        * X: ``\\targ{}``;
        * Z: ``\\control{}`` (CZ is symmetric, so a second dot is used);
        * anything else, including Y: a labelled ``\\gate`` box.

    Args:
        op: Controlled operation whose ``wires`` are ``[control, target]``.
        circuit_tikz: Per-wire lists of quantikz cells, modified in place.
        param_index: Index of the next symbolic ``\\theta`` label.
        gate_values: If ``True``, typeset numeric angles instead of symbols.

    Returns:
        ``param_index``, incremented when a symbolic label was emitted.
    """
    ctrl_wire = op.wires[0]
    targ_wire = op.wires[1]
    distance = targ_wire - ctrl_wire
    target_name = _ctrl_target_name(op.name)

    # Build target cell
    if op.parameters and target_name in ("RX", "RY", "RZ"):
        if gate_values:
            targ_cell = _tikz_param_str(float(op.parameters[0]), target_name)
        else:
            targ_cell = f"\\gate{{{target_name}(\\theta_{{{param_index}}})}}"
            param_index += 1
    elif target_name == "X":
        targ_cell = "\\targ{}"
    elif target_name == "Z":
        targ_cell = "\\control{}"
    else:
        # Y (and any other target) has no dedicated quantikz symbol; drawing
        # it as a control dot would make CY indistinguishable from CZ.
        targ_cell = f"\\gate{{{target_name}}}"

    crossing = list(range(min(op.wires), max(op.wires) + 1))
    _tikz_align_wires(circuit_tikz, crossing)

    circuit_tikz[ctrl_wire].append(f"\\ctrl{{{distance}}}")
    circuit_tikz[targ_wire].append(targ_cell)

    # Pad intermediate wires so every crossed wire advances by one column.
    for w in crossing:
        if w != ctrl_wire and w != targ_wire:
            circuit_tikz[w].append("")

    return param_index
def _tikz_cell_single(
    op: Operation,
    circuit_tikz: List[List[str]],
    param_index: int,
    gate_values: bool,
) -> int:
    """Append one quantikz cell for a single-qubit gate.

    ``Hadamard`` is abbreviated to ``H``. Parametrised gates are shown
    with their first parameter, either as a numeric angle (``gate_values``)
    or as the next symbolic ``\\theta_i`` label.

    Returns:
        ``param_index``, incremented only when a symbolic label was used.
    """
    label = "H" if op.name == "Hadamard" else op.name

    if not op.parameters:
        cell = f"\\gate{{{label}}}"
    elif gate_values:
        cell = _tikz_param_str(float(op.parameters[0]), label)
    else:
        cell = f"\\gate{{{label}(\\theta_{{{param_index}}})}}"
        param_index += 1

    circuit_tikz[op.wires[0]].append(cell)
    return param_index
def _tikz_cell_multiqubit(
    op: Operation,
    circuit_tikz: List[List[str]],
) -> None:
    """Draw a gate acting on more than two wires as one labelled box per wire.

    The gate's wires are aligned to a common column first, so all boxes
    land in the same column.
    """
    targets = list(op.wires)
    _tikz_align_wires(circuit_tikz, targets)
    cell = "\\gate{" + _gate_label(op) + "}"
    for wire in targets:
        circuit_tikz[wire].append(cell)
def _tikz_cell_barrier(
    op: Operation,
    circuit_tikz: List[List[str]],
) -> None:
    """Pad every wire to the same length so later gates start in one column.

    Nothing visible is drawn for the barrier itself; ``op`` is accepted only
    for a uniform call signature with the other cell builders.
    """
    _tikz_align_wires(circuit_tikz, [*range(len(circuit_tikz))])
240def _tikz_build_string(circuit_tikz: List[List[str]], n_qubits: int) -> str:
241 """Render the column grid to a quantikz LaTeX string."""
242 # Equalise wire lengths
243 max_len = max(len(wire) for wire in circuit_tikz)
244 for wire in circuit_tikz:
245 wire.extend("" for _ in range(max_len - len(wire)))
247 quantikz_str = ""
248 for wire_idx, wire_ops in enumerate(circuit_tikz):
249 for op_idx, cell in enumerate(wire_ops):
250 if op_idx < len(wire_ops) - 1:
251 quantikz_str += f"{cell} & "
252 else:
253 quantikz_str += f"{cell}"
254 if wire_idx < n_qubits - 1:
255 quantikz_str += " \\\\\n"
257 return quantikz_str
def draw_tikz(
    ops: List[Operation],
    n_qubits: int,
    gate_values: bool = False,
    **kwargs: Any,
) -> Any:
    """Render a circuit tape as LaTeX/TikZ ``quantikz`` code.

    Every wire starts with a ``|0>`` input label. Barriers only align the
    wires. Two-wire controlled gates are drawn as control/target pairs.
    Single-wire gates are drawn as (optionally parametrised) boxes, and
    every other gate as one labelled box per wire.

    Args:
        ops: Ordered list of gate operations (noise channels excluded).
        n_qubits: Total number of qubits.
        gate_values: If ``True``, show numeric angles; otherwise use
            symbolic \\theta_i labels, numbered in order of appearance.
        **kwargs: Accepted for interface compatibility; currently ignored.

    Returns:
        A :class:`~qml_essentials.drawing.TikzFigure` object.
    """
    circuit_tikz: List[List[str]] = [["\\lstick{\\ket{0}}"] for _ in range(n_qubits)]
    param_index = 0

    for op in ops:
        # Barriers are handled first: otherwise a barrier on a single wire
        # would fall into the single-qubit branch and be drawn as a box.
        if op.name == "Barrier":
            _tikz_cell_barrier(op, circuit_tikz)
        elif op.is_controlled and len(op.wires) == 2:
            param_index = _tikz_cell_controlled(
                op, circuit_tikz, param_index, gate_values
            )
        elif len(op.wires) == 1:
            param_index = _tikz_cell_single(op, circuit_tikz, param_index, gate_values)
        else:
            _tikz_cell_multiqubit(op, circuit_tikz)

    return TikzFigure(_tikz_build_string(circuit_tikz, n_qubits))
def draw_text(ops: List[Operation], n_qubits: int) -> str:
    """Render a circuit tape as an ASCII-art string.

    Operations are packed greedily into time-step columns. Each operation
    goes into the first column after the last one used on any wire it
    occupies. A controlled gate occupies every wire from its lowest to its
    highest wire, so no other gate is placed inside its span in the same
    column (it would otherwise look like part of the controlled gate).

    Args:
        ops: Ordered list of gate operations (noise channels excluded).
        n_qubits: Total number of qubits.

    Returns:
        Multi-line string with one row per qubit.
    """
    if not ops:
        lines = [f" q{q}: ───" for q in range(n_qubits)]
        return "\n".join(lines)

    # Each column is a dict mapping qubit -> display string.
    columns: List[Dict[int, str]] = []
    wire_busy: Dict[int, int] = {}  # qubit -> next free column index

    for op in ops:
        controlled = op.is_controlled and len(op.wires) >= 2
        if controlled:
            occupied = list(range(min(op.wires), max(op.wires) + 1))
        else:
            occupied = list(op.wires)
        start = max((wire_busy.get(w, 0) for w in occupied), default=0)

        # Ensure enough columns exist
        while len(columns) <= start:
            columns.append({})

        if controlled:
            ctrl_wires = op.wires[:-1]
            target_wire = op.wires[-1]
            target_name = _ctrl_target_name(op.name)

            # Build target label with parameters
            if op.parameters:
                param_str = ", ".join(_format_param(p) for p in op.parameters)
                target_label = f"{target_name}({param_str})"
            else:
                target_label = target_name

            for cw in ctrl_wires:
                columns[start][cw] = "●"
            columns[start][target_wire] = target_label
        else:
            label = _gate_label(op)
            for w in op.wires:
                columns[start][w] = label

        for w in occupied:
            wire_busy[w] = start + 1

    # Each column is as wide as its widest label (at least one character).
    col_widths = [
        max(max((len(v) for v in col.values()), default=1), 1) for col in columns
    ]

    lines = []
    for q in range(n_qubits):
        parts = [f" q{q}: "]
        for width, col in zip(col_widths, columns):
            cell = col[q].center(width) if q in col else "─" * width
            parts.append(f"─┤{cell}├")
        parts.append("─")
        lines.append("".join(parts))

    return "\n".join(lines)
# Matplotlib drawing


def draw_mpl(
    ops: List[Operation],
    n_qubits: int,
    **kwargs: Any,
) -> Tuple:
    """Render a circuit tape as a Matplotlib figure.

    Operations are packed greedily into columns. Each operation goes into
    the first column after the last one used on any wire it occupies. A
    controlled gate occupies every wire from its lowest to its highest wire,
    because its vertical connector line crosses them; this keeps the line
    from running through other gate boxes.

    Args:
        ops: Ordered list of gate operations (noise channels excluded).
        n_qubits: Total number of qubits.
        **kwargs: Reserved for future options (currently ignored).

    Returns:
        Tuple ``(fig, ax)`` — a Matplotlib ``Figure`` and ``Axes``.
    """
    columns: List[Dict[int, str]] = []
    # One list per column holding a ``(control_wires, target_wire)`` entry
    # for every controlled gate placed there, so that several controlled
    # gates sharing a column each get their own connector line.
    ctrl_info: List[List[Tuple[List[int], int]]] = []
    wire_busy: Dict[int, int] = {}  # qubit -> next free column index

    for op in ops:
        controlled = op.is_controlled and len(op.wires) >= 2
        if controlled:
            occupied = list(range(min(op.wires), max(op.wires) + 1))
        else:
            occupied = list(op.wires)
        start = max((wire_busy.get(w, 0) for w in occupied), default=0)
        while len(columns) <= start:
            columns.append({})
            ctrl_info.append([])

        if controlled:
            ctrl_wires = list(op.wires[:-1])
            target_wire = op.wires[-1]
            target_name = _ctrl_target_name(op.name)
            if op.parameters:
                param_str = ", ".join(_format_param(p) for p in op.parameters)
                target_label = f"{target_name}({param_str})"
            else:
                target_label = target_name

            for cw in ctrl_wires:
                columns[start][cw] = "●"
            columns[start][target_wire] = target_label
            ctrl_info[start].append((ctrl_wires, target_wire))
        else:
            label = _gate_label(op)
            for w in op.wires:
                columns[start][w] = label

        for w in occupied:
            wire_busy[w] = start + 1

    n_cols = len(columns) if columns else 1
    fig_width = max(3.0, 1.2 * (n_cols + 2))
    fig_height = max(2.0, 0.8 * n_qubits)

    fig, ax = plt.subplots(figsize=(fig_width, fig_height))
    ax.set_xlim(-0.5, n_cols + 0.5)
    ax.set_ylim(-0.5, n_qubits - 0.5)
    ax.invert_yaxis()
    ax.set_aspect("equal")
    ax.axis("off")

    # Draw qubit wires
    for q in range(n_qubits):
        ax.plot([-0.3, n_cols + 0.3], [q, q], color="black", linewidth=0.8, zorder=0)
        ax.text(
            -0.5,
            q,
            "|0⟩",
            ha="right",
            va="center",
            fontsize=10,
            fontfamily="monospace",
        )

    # Draw gates
    gate_box_h = 0.6
    gate_box_w = 0.6

    for ci, col in enumerate(columns):
        x = ci + 0.5

        # Draw one control line per controlled gate in this column
        for c_wires, t_wire in ctrl_info[ci]:
            span = c_wires + [t_wire]
            ax.plot(
                [x, x], [min(span), max(span)], color="black", linewidth=1.0, zorder=1
            )

        for q, label in col.items():
            if label == "●":
                # Control dot
                ax.plot(x, q, "o", color="black", markersize=6, zorder=3)
            else:
                # Gate box, widened for long labels
                fontsize = 9 if len(label) <= 6 else 7
                bw = max(gate_box_w, len(label) * 0.09 + 0.2)
                rect = mpatches.FancyBboxPatch(
                    (x - bw / 2, q - gate_box_h / 2),
                    bw,
                    gate_box_h,
                    boxstyle="round,pad=0.05",
                    facecolor="white",
                    edgecolor="black",
                    linewidth=1.0,
                    zorder=2,
                )
                ax.add_patch(rect)
                ax.text(
                    x,
                    q,
                    label,
                    ha="center",
                    va="center",
                    fontsize=fontsize,
                    zorder=4,
                )

    fig.tight_layout()
    return fig, ax
@dataclass
class PulseEvent:
    """Single pulse applied to one or more wires, as collected for drawing.

    Instances are built by :func:`collect_pulse_events` and rendered by
    :func:`draw_pulse_schedule`.

    Attributes:
        gate: Leaf gate label, e.g. ``"RX"``, ``"CZ"``.
        wires: Qubit wire(s) the pulse acts on.
        envelope_fn: Pure envelope function ``(p, t, t_c) -> amplitude``;
            ``None`` for virtual gates (RZ, CZ), which have no envelope.
        envelope_params: Envelope-shape parameters (excluding ``w`` and ``t``).
        w: Rotation angle passed to the gate; the plotted signal is the
            envelope multiplied by this value.
        duration: Pulse duration (evolution time).
        carrier_phase: Phase offset ``phi_c`` of the carrier
            ``cos(omega_c * t + phi_c)``.
        parent: Optional high-level gate name that decomposed into this event.
    """

    gate: str
    wires: List[int]
    envelope_fn: Any  # (p, t, t_c) -> scalar; None for virtual gates
    envelope_params: Any  # jnp array of envelope shape params
    w: float  # rotation angle
    duration: float  # evolution time
    carrier_phase: float = 0.0  # phi_c in cos(omega_c * t + phi_c)
    parent: Optional[str] = None  # composite gate that owns this pulse
# Drawing metadata for the leaf gates of a pulse decomposition.
# ``carrier_phase`` is the phase offset used for the carrier cosine;
# ``physical`` gates (RX, RY) have an envelope, virtual gates (RZ, CZ) do not.
LEAF_META = dict(
    RX=dict(carrier_phase=float(jnp.pi), physical=True),
    RY=dict(carrier_phase=float(-jnp.pi / 2), physical=True),
    RZ=dict(carrier_phase=0.0, physical=False),
    CZ=dict(carrier_phase=0.0, physical=False),
)
536def _resolve_wires_for_drawing(wire_fn, wires_list):
537 """Resolve a ``wire_fn`` string to concrete wire(s) for drawing."""
538 if wire_fn == "all":
539 return wires_list
540 if wire_fn == "target":
541 return [wires_list[-1]] if len(wires_list) > 1 else wires_list
542 if wire_fn == "control":
543 return [wires_list[0]]
544 raise ValueError(f"Unknown wire_fn: {wire_fn!r}")
def collect_pulse_events(
    gate_name: str,
    w: Union[float, List[float]],
    wires: Union[int, List[int]],
    pulse_params: Any = None,
    parent: Optional[str] = None,
) -> List[PulseEvent]:
    """Flatten a (possibly composite) pulse gate into leaf PulseEvents.

    The gate's :class:`DecompositionStep` tree (stored on its
    :class:`PulseParams`) is walked recursively. Only timing and envelope
    information is gathered for drawing; no quantum operation is applied.

    Args:
        gate_name: Name of the gate (``"RX"``, ``"H"``, ``"CX"``, etc.).
        w: Rotation angle(s).
        wires: Qubit index or ``[control, target]``.
        pulse_params: Pulse parameters or ``None`` for defaults.
        parent: Label of the enclosing composite gate; the top-level call
            uses ``gate_name`` itself.

    Returns:
        Leaf :class:`PulseEvent` objects in decomposition order.

    Raises:
        ValueError: If the gate is unknown, or is a leaf without an entry
            in ``LEAF_META``.
    """
    from qml_essentials.gates import PulseEnvelope, PulseInformation

    spec = PulseInformation.gate_by_name(gate_name)
    if spec is None:
        raise ValueError(f"Unknown pulse gate: {gate_name!r}")

    targets = [wires] if isinstance(wires, int) else list(wires)
    owner = parent or gate_name

    if not spec.is_leaf:
        # Composite gate: recurse into each decomposition step; every leaf
        # keeps the label of the outermost gate as its parent.
        collected: List[PulseEvent] = []
        step_params = spec.split_params(pulse_params)
        for step, sub_params in zip(spec.decomposition, step_params):
            sub_wires = _resolve_wires_for_drawing(step.wire_fn, targets)
            sub_w = w if step.angle_fn is None else step.angle_fn(w)
            collected.extend(
                collect_pulse_events(
                    step.gate.name,
                    sub_w,
                    sub_wires,
                    sub_params,
                    parent=owner,
                )
            )
        return collected

    meta = LEAF_META.get(gate_name)
    if meta is None:
        raise ValueError(f"Unknown pulse gate: {gate_name!r}")

    envelope = PulseEnvelope.get(PulseInformation.get_envelope())
    leaf_params = spec.split_params(pulse_params)

    if not meta["physical"]:
        # Virtual gate (RZ, CZ): no envelope, drawn in a unit-length slot.
        return [
            PulseEvent(
                gate=gate_name,
                wires=targets,
                envelope_fn=None,
                envelope_params=jnp.ravel(jnp.asarray(leaf_params)),
                w=0.0 if isinstance(w, list) else float(w),
                duration=1.0,
                carrier_phase=0.0,
                parent=owner,
            )
        ]

    # Physical gate: the last parameter is the duration, the rest shape
    # the envelope.
    shape_params = leaf_params[:-1]
    duration = float(leaf_params[-1])
    return [
        PulseEvent(
            gate=gate_name,
            wires=targets,
            envelope_fn=envelope["fn"],
            envelope_params=jnp.array(shape_params),
            w=float(w),
            duration=duration,
            carrier_phase=meta["carrier_phase"],
            parent=owner,
        )
    ]
634def _make_event_label(gate: str, parent: Optional[str]) -> str:
635 """Build a display label, appending the parent gate if different."""
636 if parent and parent != gate:
637 return f"{gate} ({parent})"
638 return gate
def _sample_envelope(ev: PulseEvent, t_lo: float, t_hi: float, n_samples: int):
    """Sample the scaled envelope of *ev* on ``n_samples`` points in [t_lo, t_hi].

    The envelope is centred at ``duration / 2`` and evaluated on the whole
    time array in one vectorised call, then multiplied by the rotation
    angle ``ev.w``.

    Returns:
        ``(t_arr, signal)`` — the sample times and the scaled envelope.
    """
    center = ev.duration / 2
    times = jnp.linspace(t_lo, t_hi, n_samples)
    envelope = ev.envelope_fn(ev.envelope_params, times, center)
    return times, envelope * ev.w
def _compute_display_window(
    ev: PulseEvent,
    n_samples: int,
    envelope_width: float = 1.0,
) -> Tuple[float, float, float]:
    """Choose the local time window and peak amplitude for plotting a pulse.

    By default the window is the evolution window ``[0, duration]``. It is
    widened only when the envelope is still significant at the edges, i.e.
    when the edge-to-centre amplitude ratio ``edge_ratio`` is at least 1%.
    In that case a 30-step bisection on the half-width (between
    ``duration / 2`` and ``50 * duration``) finds where the envelope has
    decayed to ``edge_ratio ** 10`` of its centre value.

    ``envelope_width`` scales the extension beyond the evolution window:
    ``1.0`` keeps the adaptive width, ``> 1`` widens, ``< 1`` tightens, and
    ``0`` pins the window to ``[0, duration]``. A (near-)zero centre
    amplitude also pins it.

    Returns:
        ``(t_lo, t_hi, amp_max)`` — local window bounds and 1.1x the peak
        magnitude of the sampled signal.
    """
    center = ev.duration / 2

    def _window() -> Tuple[float, float]:
        # Any early return falls back to the plain evolution window.
        if envelope_width == 0:
            return 0.0, ev.duration
        peak = float(ev.envelope_fn(ev.envelope_params, center, center))
        if abs(peak) < 1e-12:
            return 0.0, ev.duration
        edge = float(ev.envelope_fn(ev.envelope_params, 0.0, center))
        ratio = abs(edge / peak)
        if ratio < 0.01:
            return 0.0, ev.duration

        threshold = ratio**10
        inner, outer = ev.duration / 2, ev.duration * 50
        for _ in range(30):
            probe = (inner + outer) / 2
            level = float(ev.envelope_fn(ev.envelope_params, center + probe, center))
            if abs(level / peak) > threshold:
                inner = probe
            else:
                outer = probe
        half = ev.duration / 2 + (outer - ev.duration / 2) * envelope_width
        return center - half, center + half

    t_lo, t_hi = _window()
    _, signal = _sample_envelope(ev, t_lo, t_hi, n_samples)
    peak_amp = float(jnp.max(jnp.abs(signal))) * 1.1
    return t_lo, t_hi, peak_amp
def _draw_physical_pulse(
    ev: PulseEvent,
    t_start: float,
    t_lo: float,
    t_hi: float,
    axes,
    color: str,
    n_samples: int,
    omega_c: float,
    show_carrier: bool,
) -> None:
    """Plot a physical (RX/RY) pulse on every wire it acts on.

    The envelope is sampled over the local window ``[t_lo, t_hi]`` and
    shifted by ``t_start`` onto the global time axis. For each wire the
    following is drawn:

    * the filled envelope;
    * dashed lines at ``t_start`` and ``t_start + duration``;
    * optionally the carrier-modulated waveform;
    * the event label at the sample of largest magnitude.
    """
    t_local, signal = _sample_envelope(ev, t_lo, t_hi, n_samples)
    t_global = t_local + t_start

    for ax in (axes[wire] for wire in ev.wires):
        ax.fill_between(t_global, signal, alpha=0.12, color=color, zorder=2)
        ax.plot(t_global, signal, color=color, linewidth=1.4, alpha=0.85, zorder=3)

        # Dashed markers for the evolution window boundaries
        for edge in (t_start, t_start + ev.duration):
            ax.axvline(
                edge,
                color=color,
                linestyle="--",
                linewidth=0.8,
                alpha=0.7,
                zorder=4,
            )

        if show_carrier:
            # The carrier is evaluated on the pulse-local time axis.
            carrier = signal * jnp.cos(omega_c * t_local + ev.carrier_phase)
            ax.plot(t_global, carrier, color=color, linewidth=0.8, alpha=0.8, zorder=2)

        peak = jnp.argmax(jnp.abs(signal))
        above_axis = signal[peak] >= 0
        ax.annotate(
            _make_event_label(ev.gate, ev.parent),
            xy=(float(t_global[peak]), float(signal[peak])),
            fontsize=7,
            ha="center",
            va="bottom" if above_axis else "top",
            color=color,
            fontweight="bold",
            zorder=5,
        )
def _draw_virtual_z(
    ev: PulseEvent, t_start: float, axes, color: str, amp_max: float
) -> None:
    """Mark a virtual-Z gate with a dashed vertical line and a label.

    Both are placed at the midpoint of the event's time slot on each of its
    wires. The line spans ±60% of ``amp_max`` and the label sits at 85%.
    """
    t_mid = t_start + ev.duration / 2
    half_height = amp_max * 0.6
    label_y = amp_max * 0.85
    for wire in ev.wires:
        ax = axes[wire]
        ax.vlines(
            t_mid,
            -half_height,
            half_height,
            color=color,
            linestyle="--",
            linewidth=1.2,
            alpha=0.8,
            zorder=2,
        )
        ax.annotate(
            _make_event_label(ev.gate, ev.parent),
            xy=(t_mid, label_y),
            fontsize=7,
            ha="center",
            va="bottom",
            color=color,
            zorder=5,
        )
def _draw_block(
    ev: PulseEvent, t_start: float, axes, color: str, amp_max: float
) -> None:
    """Draw *ev* as a shaded rectangle covering its time slot on each wire.

    Each rectangle spans ±60% of ``amp_max`` vertically. The label is drawn
    once, above the rectangle on the first wire.
    """
    label = _make_event_label(ev.gate, ev.parent)
    corner = (t_start, -amp_max * 0.6)
    height = amp_max * 1.2
    for wire in ev.wires:
        axes[wire].add_patch(
            mpatches.Rectangle(
                corner,
                ev.duration,
                height,
                alpha=0.2,
                facecolor=color,
                edgecolor=color,
                linewidth=1.0,
                zorder=1,
            )
        )

    label_xy = (t_start + ev.duration / 2, amp_max * 0.7)
    axes[ev.wires[0]].annotate(
        label,
        xy=label_xy,
        fontsize=7,
        ha="center",
        va="bottom",
        color=color,
        fontweight="bold",
        zorder=5,
    )
def draw_pulse_schedule(
    events: List[PulseEvent],
    n_qubits: int,
    n_samples: int = 200,
    show_carrier: bool = True,
    show_envelope: bool = True,
    envelope_width: float = 0.0,
    **kwargs: Any,
) -> Tuple:
    """Render a pulse schedule as a Matplotlib figure.

    Each qubit gets its own subplot row. The events are drawn as follows:

    * Physical pulses (RX, RY) are filled envelope shapes.
    * Virtual-Z gates are thin dashed vertical lines.
    * All other events are shaded rectangles on each of their wires. This
      includes CZ, and RX/RY when ``show_envelope`` is ``False``.

    Each event starts as soon as all of its wires are free, so events on
    disjoint wires may overlap in time.

    Args:
        events: Ordered list of :class:`PulseEvent` from
            :func:`collect_pulse_events`.
        n_qubits: Total number of qubits.
        n_samples: Number of time samples per pulse envelope.
        show_carrier: If ``True``, overlay the carrier-modulated waveform
            (envelope x cos) as a thin line.
        show_envelope: If ``True``, draw the full envelope shape for
            physical pulses. If ``False``, show them as simple
            rectangular blocks indicating duration only.
        envelope_width: Scales how far the displayed envelope extends
            beyond the evolution window. ``0`` (the default) clamps the
            envelope exactly to the pulse duration, ``1.0`` uses the
            adaptive width, values ``> 1`` widen it further and values
            between ``0`` and ``1`` tighten it.
        **kwargs: Only ``figsize`` is used (passed to ``plt.subplots``);
            any other keyword is accepted and ignored.

    Returns:
        ``(fig, axes)`` — Matplotlib Figure and array of Axes.
    """
    from qml_essentials.gates import PulseGates, PulseInformation

    omega_c = float(PulseGates.omega_c)

    # Assign start times: an event waits until all of its wires are free.
    wire_cursor: Dict[int, float] = {q: 0.0 for q in range(n_qubits)}
    scheduled: List[Tuple[PulseEvent, float]] = []  # (event, t_start)

    for ev in events:
        t_start = max(wire_cursor[w] for w in ev.wires)
        scheduled.append((ev, t_start))
        for w in ev.wires:
            wire_cursor[w] = t_start + ev.duration

    t_total = max(wire_cursor.values()) if wire_cursor else 1.0

    gate_colors = {
        "RX": "#1F78B4",
        "RY": "#E69F00",
        "RZ": "#009371",
        "CZ": "#ED665A",
    }

    fig, axes = plt.subplots(
        n_qubits,
        1,
        figsize=kwargs.pop("figsize", (max(8, t_total * 2.5), 1.8 * n_qubits)),
        sharex=True,
        squeeze=False,
    )
    axes = axes.ravel()

    for q in range(n_qubits):
        ax = axes[q]
        ax.set_ylabel(f"q{q}", rotation=0, labelpad=20, fontsize=11, va="center")
        ax.axhline(0, color="grey", linewidth=0.4, zorder=0)
        ax.set_yticks([])
        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)
        ax.spines["left"].set_visible(False)

    axes[-1].set_xlabel("Time", fontsize=11)

    # Pre-compute display windows, amplitude range, and x-limits
    display_windows: Dict[int, Tuple[float, float]] = {}
    amp_max = 1.0
    x_lo, x_hi = 0.0, t_total

    if show_envelope:
        for idx, (ev, t_start) in enumerate(scheduled):
            if ev.envelope_fn is None or ev.gate not in ("RX", "RY"):
                continue
            t_lo, t_hi, amp = _compute_display_window(ev, n_samples, envelope_width)
            display_windows[idx] = (t_lo, t_hi)
            amp_max = max(amp_max, amp)
            # Map local display bounds to global time coordinates
            x_lo = min(x_lo, t_lo + t_start)
            x_hi = max(x_hi, t_hi + t_start)

    if x_hi <= x_lo:
        # E.g. an empty schedule: give the shared x-axis a non-zero span so
        # Matplotlib is not asked for identical lower and upper limits.
        x_hi = x_lo + 1.0

    x_margin = (x_hi - x_lo) * 0.05
    for q in range(n_qubits):
        axes[q].set_xlim(x_lo - x_margin, x_hi + x_margin)
        axes[q].set_ylim(-amp_max, amp_max)

    # Draw events
    for idx, (ev, t_start) in enumerate(scheduled):
        color = gate_colors.get(ev.gate, "#bab0ac")

        if ev.gate in ("RX", "RY") and ev.envelope_fn is not None and show_envelope:
            t_lo, t_hi = display_windows[idx]
            _draw_physical_pulse(
                ev,
                t_start,
                t_lo,
                t_hi,
                axes,
                color,
                n_samples,
                omega_c,
                show_carrier,
            )
        elif ev.gate == "RZ":
            _draw_virtual_z(ev, t_start, axes, color, amp_max)
        else:
            _draw_block(ev, t_start, axes, color, amp_max)

    # Legend (only for gates that have a dedicated colour and were used)
    used_gates = {ev.gate for ev, _ in scheduled}
    handles = [
        mpatches.Patch(color=c, alpha=0.7, label=g)
        for g, c in gate_colors.items()
        if g in used_gates
    ]
    if handles:
        fig.legend(
            handles=handles,
            loc="lower right",
            ncol=len(handles),
            fontsize=8,
            framealpha=0.8,
        )

    fig.suptitle(
        f"Pulse Schedule ({PulseInformation.get_envelope()} envelope)",
        fontsize=12,
        fontweight="bold",
    )
    fig.tight_layout()
    return fig, axes