Coverage for qml_essentials / drawing.py: 90%
402 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-30 11:43 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-30 11:43 +0000
1from fractions import Fraction
2from typing import Any, Dict, List, Tuple, Union
3from dataclasses import dataclass
4from typing import Optional
5import jax.numpy as jnp
6import matplotlib.pyplot as plt
7import matplotlib.patches as mpatches
9from qml_essentials.operations import (
10 Operation,
11)
class TikzFigure:
    """Wrapper around a ``quantikz`` LaTeX string with export helpers."""

    def __init__(self, quantikz_str: str):
        # Raw quantikz cell grid (rows separated by ``\\``), without any
        # surrounding ``quantikz`` environment.
        self.quantikz_str = quantikz_str

    def __repr__(self):
        return self.quantikz_str

    def __str__(self):
        return self.quantikz_str

    def wrap_figure(self) -> str:
        """
        Wrap the quantikz string in a LaTeX figure environment.

        Returns:
            str: A LaTeX ``figure`` containing a ``tikzpicture`` whose single
            (scaled) node holds the ``quantikz`` circuit.
        """
        return f"""
\\begin{{figure}}
    \\centering
    \\begin{{tikzpicture}}
        \\node[scale=0.85] {{
            \\begin{{quantikz}}
                {self.quantikz_str}
            \\end{{quantikz}}
        }};
    \\end{{tikzpicture}}
\\end{{figure}}"""

    def export(self, destination: str, full_document=False, mode="w") -> None:
        """
        Write the circuit to a LaTeX file.

        Parameters
        ----------
        destination : str
            Path to the destination file.
        full_document : bool, optional
            If ``True``, write a standalone LaTeX document (preamble plus the
            output of :meth:`wrap_figure`). If ``False`` (default), write only
            the raw quantikz string followed by a newline, suitable for
            embedding into an existing document.
        mode : str, optional
            Text file mode passed to :func:`open`; use ``"a"`` to append
            several circuits to the same file. Defaults to ``"w"``.
        """
        if full_document:
            latex_code = f"""
\\documentclass{{article}}
\\usepackage{{quantikz}}
\\usepackage{{tikz}}
\\usetikzlibrary{{quantikz2}}
\\usepackage[a3paper, landscape, margin=0.5cm]{{geometry}}
\\begin{{document}}
{self.wrap_figure()}
\\end{{document}}"""
        else:
            latex_code = self.quantikz_str + "\n"

        # Explicit UTF-8: gate labels may contain non-ASCII characters such
        # as "π", which fail to encode under a non-UTF-8 platform default.
        with open(destination, mode, encoding="utf-8") as f:
            f.write(latex_code)
# Backwards-compatible alias so existing ``QuanTikz.TikzFigure`` references
# keep working without changes in downstream code.
class QuanTikz:
    """Legacy namespace whose only member is :class:`TikzFigure`."""

    TikzFigure = TikzFigure
def _ctrl_target_name(name: str) -> str:
    """Strip the leading control prefix from a controlled gate name.

    Every *leading* ``"C"`` is removed, so multiply-controlled names reduce
    to their target as well (``CRX -> RX``, ``CX -> X``, ``CCX -> X``).
    A ``"C"`` elsewhere in the name belongs to the target gate and is kept.
    """
    return name.lstrip("C")
def _format_param(val: float) -> str:
    """Format a numeric parameter for text display.

    Values close to a multiple of π with a small denominator (≤ 12, after
    approximating ``val / π`` with denominators up to 100) are shown as
    π-fractions, e.g. ``π``, ``π/2``, ``-π/4``, ``3π/4``, ``2π``. Values that
    round to zero give ``"0"``. Everything else, including non-finite
    values, is shown with 2 decimal places.
    """
    try:
        frac = Fraction(float(val) / float(jnp.pi)).limit_denominator(100)
    except (ValueError, OverflowError, ZeroDivisionError):
        # NaN -> ValueError, ±inf -> OverflowError: fall back to plain text.
        return f"{float(val):.2f}"

    if frac.numerator == 0:
        return "0"
    if frac.denominator > 12:
        return f"{float(val):.2f}"

    # Render the sign separately so -π/2 reads "-π/2" rather than "-1π/2".
    sign = "-" if frac < 0 else ""
    magnitude = abs(frac.numerator)
    coeff = "" if magnitude == 1 else str(magnitude)
    if frac.denominator == 1:
        return f"{sign}{coeff}π"
    return f"{sign}{coeff}π/{frac.denominator}"
def _gate_label(op: Operation) -> str:
    """Return a compact display label for *op*.

    Gates without parameters are labelled by their bare name (``H``);
    parameterised gates append their formatted parameters (``RX(π/2)``).
    """
    if not op.parameters:
        return op.name
    args = ", ".join(map(_format_param, op.parameters))
    return f"{op.name}({args})"
def _tikz_param_str(val: float, op_name: str) -> str:
    """Format a rotation angle as a quantikz ``\\gate`` cell.

    Angles close to a π-fraction with denominator ≤ 12 (after approximating
    ``val / π`` with denominators up to 100) are typeset symbolically, e.g.
    ``\\gate{RX(\\pi)}`` or ``\\gate{RX\\left(-\\frac{\\pi}{2}\\right)}``.
    All other values, including non-finite ones, are shown with 2 decimals.
    """
    try:
        frac = Fraction(float(val) / float(jnp.pi)).limit_denominator(100)
    except (ValueError, OverflowError, ZeroDivisionError):
        # NaN -> ValueError, ±inf -> OverflowError: fall back to numeric text.
        return f"\\gate{{{op_name}({float(val):.2f})}}"

    if frac.denominator > 12:
        return f"\\gate{{{op_name}({float(val):.2f})}}"
    if frac.numerator == 0:
        return f"\\gate{{{op_name}(0)}}"

    # Keep the sign outside the fraction: "-\frac{\pi}{2}", not "\frac{-1\pi}{2}".
    sign = "-" if frac < 0 else ""
    magnitude = abs(frac.numerator)
    coeff = "" if magnitude == 1 else str(magnitude)
    if frac.denominator == 1:
        return f"\\gate{{{op_name}({sign}{coeff}\\pi)}}"
    return (
        f"\\gate{{{op_name}\\left({sign}"
        f"\\frac{{{coeff}\\pi}}{{{frac.denominator}}}\\right)}}"
    )
def _tikz_align_wires(circuit_tikz: List[List[str]], wires: List[int]) -> None:
    """Pad the selected *wires* with empty cells to a common length, in place.

    The target length is that of the longest selected wire; wires not listed
    in *wires* are left untouched. An empty *wires* is a no-op.
    """
    max_len = max((len(circuit_tikz[w]) for w in wires), default=0)
    for w in wires:
        circuit_tikz[w].extend([""] * (max_len - len(circuit_tikz[w])))
def _tikz_cell_controlled(
    op: Operation,
    circuit_tikz: List[List[str]],
    param_index: int,
    gate_values: bool,
) -> int:
    """Append cells for a 2-wire controlled gate; return updated param_index.

    ``op.wires[0]`` is drawn as the control (``\\ctrl``) and ``op.wires[1]``
    as the target. All wires spanned by the gate are first aligned so the
    control line sits in a single column; wires strictly between control and
    target receive an empty cell.

    Target cell:
        * parameterised target: the gate name with either its first numeric
          parameter (``gate_values=True``) or the next symbolic ``\\theta_i``
          label, matching how single-qubit gates are labelled;
        * ``X``: the ``\\targ{}`` symbol;
        * ``Z``: a second control dot (``\\control{}``);
        * anything else (e.g. ``Y``, ``H``): a labelled ``\\gate`` box.

    Returns:
        ``param_index``, incremented by one if a symbolic label was used.
    """
    ctrl_wire = op.wires[0]
    targ_wire = op.wires[1]
    # ``\ctrl{n}`` takes the *relative* row offset to the target.
    distance = targ_wire - ctrl_wire
    target_name = _ctrl_target_name(op.name)

    if op.parameters:
        if gate_values:
            targ_cell = _tikz_param_str(float(op.parameters[0]), target_name)
        else:
            targ_cell = f"\\gate{{{target_name}(\\theta_{{{param_index}}})}}"
            param_index += 1
    elif target_name == "X":
        targ_cell = "\\targ{}"
    elif target_name == "Z":
        targ_cell = "\\control{}"
    else:
        # A dot would make e.g. CY indistinguishable from CZ: use a box.
        targ_cell = f"\\gate{{{target_name}}}"

    crossing = list(range(min(op.wires), max(op.wires) + 1))
    _tikz_align_wires(circuit_tikz, crossing)

    circuit_tikz[ctrl_wire].append(f"\\ctrl{{{distance}}}")
    circuit_tikz[targ_wire].append(targ_cell)

    # Pad intermediate wires so every spanned wire stays column-aligned.
    for w in crossing:
        if w != ctrl_wire and w != targ_wire:
            circuit_tikz[w].append("")

    return param_index
def _tikz_cell_single(
    op: Operation,
    circuit_tikz: List[List[str]],
    param_index: int,
    gate_values: bool,
) -> int:
    """Append the quantikz cell of a single-qubit gate to its wire.

    Parameterised gates are shown with their first parameter, either as a
    numeric angle (``gate_values=True``) or as the next symbolic
    ``\\theta_i`` label, in which case ``param_index`` advances by one.
    ``Hadamard`` is abbreviated to ``H``.

    Returns:
        The (possibly incremented) ``param_index``.
    """
    wire = op.wires[0]
    label = "H" if op.name == "Hadamard" else op.name

    if not op.parameters:
        cell = f"\\gate{{{label}}}"
    elif gate_values:
        cell = _tikz_param_str(float(op.parameters[0]), label)
    else:
        cell = f"\\gate{{{label}(\\theta_{{{param_index}}})}}"
        param_index += 1

    circuit_tikz[wire].append(cell)
    return param_index
def _tikz_cell_multiqubit(
    op: Operation,
    circuit_tikz: List[List[str]],
) -> None:
    """Append one labelled ``\\gate`` box per wire for a multi-wire gate.

    The gate's wires are aligned first so all boxes share one column; each
    box carries the same label, including any numeric parameters.
    """
    _tikz_align_wires(circuit_tikz, list(op.wires))
    # ``_gate_label`` produces plain-text labels that may contain a Unicode
    # "π". Gate contents are typeset in math mode (compare the ``\pi`` used
    # by ``_tikz_param_str``), where pdfLaTeX rejects a raw "π".
    label = _gate_label(op).replace("π", "\\pi")
    for w in op.wires:
        circuit_tikz[w].append(f"\\gate{{{label}}}")
def _tikz_cell_barrier(
    op: Operation,
    circuit_tikz: List[List[str]],
) -> None:
    """Pad every wire to a common length so later gates start column-aligned.

    Nothing is drawn for the barrier itself, and ``op`` is not inspected.
    """
    _tikz_align_wires(circuit_tikz, [*range(len(circuit_tikz))])
def _tikz_build_string(circuit_tikz: List[List[str]], n_qubits: int) -> str:
    """Render the column grid to a quantikz LaTeX string.

    Every wire is first padded in place with empty cells to the length of
    the longest wire. The cells of a wire are joined with ``" & "``, and a
    ``\\\\`` row break plus newline follows every row whose index is below
    ``n_qubits - 1``. An empty grid renders as an empty string.
    """
    max_len = max((len(wire) for wire in circuit_tikz), default=0)
    for wire in circuit_tikz:
        wire.extend([""] * (max_len - len(wire)))

    pieces: List[str] = []
    for wire_idx, wire in enumerate(circuit_tikz):
        pieces.append(" & ".join(wire))
        if wire_idx < n_qubits - 1:
            pieces.append(" \\\\\n")
    return "".join(pieces)
def draw_tikz(
    ops: List[Operation],
    n_qubits: int,
    gate_values: bool = False,
    **kwargs: Any,
) -> Any:
    """Render a circuit tape as LaTeX/TikZ ``quantikz`` code.

    Args:
        ops: Ordered list of gate operations (noise channels excluded).
        n_qubits: Total number of qubits.
        gate_values: If ``True``, show numeric angles; otherwise use
            symbolic \\theta_i labels.
        **kwargs: Accepted for interface compatibility; currently unused.

    Returns:
        A :class:`~qml_essentials.drawing.TikzFigure` object.
    """
    circuit_tikz: List[List[str]] = [["\\lstick{\\ket{0}}"] for _ in range(n_qubits)]
    param_index = 0

    for op in ops:
        if op.name == "Barrier":
            # Checked first so a barrier on a single wire only aligns the
            # wires instead of being drawn as a "Barrier" gate box.
            _tikz_cell_barrier(op, circuit_tikz)
        elif op.is_controlled and len(op.wires) == 2:
            param_index = _tikz_cell_controlled(
                op, circuit_tikz, param_index, gate_values
            )
        elif len(op.wires) == 1:
            param_index = _tikz_cell_single(op, circuit_tikz, param_index, gate_values)
        else:
            _tikz_cell_multiqubit(op, circuit_tikz)

    return TikzFigure(_tikz_build_string(circuit_tikz, n_qubits))
def draw_text(ops: List[Operation], n_qubits: int) -> str:
    """Render a circuit tape as an ASCII-art string.

    Gates are packed into columns: each gate goes into the first column in
    which all of its wires are free. A controlled gate also reserves every
    wire between its outermost control and its target. Each cell is drawn
    as ``─┤label├`` centred to the column width; cells without a gate are
    filled with ``─``.

    Args:
        ops: Ordered list of gate operations (noise channels excluded).
        n_qubits: Total number of qubits.

    Returns:
        Multi-line string with one row per qubit.
    """
    if not ops:
        return "\n".join(f" q{q}: ───" for q in range(n_qubits))

    # Each column maps qubit -> display string.
    columns: List[Dict[int, str]] = []
    wire_busy: Dict[int, int] = {}  # qubit -> next free column index

    for op in ops:
        controlled = op.is_controlled and len(op.wires) >= 2
        # A controlled gate occupies its whole wire span. Choosing the column
        # from that same span keeps ``wire_busy`` monotonic. Otherwise a
        # crossed wire's cursor could move backwards, and a later gate would
        # overwrite a label already placed on it.
        if controlled:
            occupied = range(min(op.wires), max(op.wires) + 1)
        else:
            occupied = op.wires
        start = max((wire_busy.get(w, 0) for w in occupied), default=0)

        # Ensure enough columns exist
        while len(columns) <= start:
            columns.append({})

        if controlled:
            ctrl_wires = op.wires[:-1]
            target_wire = op.wires[-1]
            target_name = _ctrl_target_name(op.name)
            if op.parameters:
                param_str = ", ".join(_format_param(p) for p in op.parameters)
                target_label = f"{target_name}({param_str})"
            else:
                target_label = target_name

            for cw in ctrl_wires:
                columns[start][cw] = "●"
            columns[start][target_wire] = target_label
        else:
            label = _gate_label(op)
            for w in op.wires:
                columns[start][w] = label

        for w in occupied:
            wire_busy[w] = start + 1

    # Render the grid; every column is as wide as its longest label.
    col_widths = [
        max(max((len(v) for v in col.values()), default=1), 1) for col in columns
    ]

    lines = []
    for q in range(n_qubits):
        parts = [f" q{q}: "]
        for col, width in zip(columns, col_widths):
            if q in col:
                cell = col[q].center(width)
            else:
                cell = "─" * width
            parts.append(f"─┤{cell}├")
        parts.append("─")
        lines.append("".join(parts))

    return "\n".join(lines)
# Matplotlib drawing


def draw_mpl(
    ops: List[Operation],
    n_qubits: int,
    **kwargs: Any,
) -> Tuple:
    """Render a circuit tape as a Matplotlib figure.

    Gates are packed into columns: each gate goes into the first column in
    which all of its wires are free. A controlled gate also reserves every
    wire between its outermost control and its target, so its vertical
    control line never passes through another gate's box.

    Args:
        ops: Ordered list of gate operations (noise channels excluded).
        n_qubits: Total number of qubits.
        **kwargs: Reserved for future options.

    Returns:
        Tuple ``(fig, ax)`` — a Matplotlib ``Figure`` and ``Axes``.
    """
    columns: List[Dict[int, str]] = []  # column -> {qubit: label}
    wire_busy: Dict[int, int] = {}  # qubit -> next free column index
    # Per column: (control wires, target wire) of every controlled gate placed
    # there. Controlled gates on disjoint wire ranges can share a column, so
    # a list is kept instead of a single entry that would be overwritten.
    ctrl_info: List[List[Tuple[List[int], int]]] = []

    for op in ops:
        controlled = op.is_controlled and len(op.wires) >= 2
        if controlled:
            occupied = range(min(op.wires), max(op.wires) + 1)
        else:
            occupied = op.wires
        start = max((wire_busy.get(w, 0) for w in occupied), default=0)
        while len(columns) <= start:
            columns.append({})
            ctrl_info.append([])

        if controlled:
            ctrl_wires = list(op.wires[:-1])
            target_wire = op.wires[-1]
            target_name = _ctrl_target_name(op.name)
            if op.parameters:
                param_str = ", ".join(_format_param(p) for p in op.parameters)
                target_label = f"{target_name}({param_str})"
            else:
                target_label = target_name

            for cw in ctrl_wires:
                columns[start][cw] = "●"
            columns[start][target_wire] = target_label
            ctrl_info[start].append((ctrl_wires, target_wire))
        else:
            label = _gate_label(op)
            for w in op.wires:
                columns[start][w] = label

        for w in occupied:
            wire_busy[w] = start + 1

    n_cols = len(columns) if columns else 1
    fig_width = max(3.0, 1.2 * (n_cols + 2))
    fig_height = max(2.0, 0.8 * n_qubits)

    fig, ax = plt.subplots(figsize=(fig_width, fig_height))
    ax.set_xlim(-0.5, n_cols + 0.5)
    ax.set_ylim(-0.5, n_qubits - 0.5)
    ax.invert_yaxis()
    ax.set_aspect("equal")
    ax.axis("off")

    # Draw qubit wires
    for q in range(n_qubits):
        ax.plot([-0.3, n_cols + 0.3], [q, q], color="black", linewidth=0.8, zorder=0)
        ax.text(
            -0.5,
            q,
            "|0⟩",
            ha="right",
            va="center",
            fontsize=10,
            fontfamily="monospace",
        )

    # Draw gates
    gate_box_h = 0.6
    gate_box_w = 0.6

    for ci, col in enumerate(columns):
        x = ci + 0.5

        # One vertical control line per controlled gate in this column.
        for ctrl_wires, target_wire in ctrl_info[ci]:
            span = ctrl_wires + [target_wire]
            ax.plot(
                [x, x], [min(span), max(span)], color="black", linewidth=1.0, zorder=1
            )

        for q, label in col.items():
            if label == "●":
                # Control dot
                ax.plot(x, q, "o", color="black", markersize=6, zorder=3)
            else:
                # Gate box, widened for long labels.
                fontsize = 9 if len(label) <= 6 else 7
                bw = max(gate_box_w, len(label) * 0.09 + 0.2)
                rect = mpatches.FancyBboxPatch(
                    (x - bw / 2, q - gate_box_h / 2),
                    bw,
                    gate_box_h,
                    boxstyle="round,pad=0.05",
                    facecolor="white",
                    edgecolor="black",
                    linewidth=1.0,
                    zorder=2,
                )
                ax.add_patch(rect)
                ax.text(
                    x,
                    q,
                    label,
                    ha="center",
                    va="center",
                    fontsize=fontsize,
                    zorder=4,
                )

    fig.tight_layout()
    return fig, ax
@dataclass
class PulseEvent:
    """One leaf pulse of a pulse schedule, applied to one or more wires.

    Attributes:
        gate: Leaf gate label, e.g. ``"RX"`` or ``"CZ"``.
        wires: Qubit wire(s) the pulse acts on.
        envelope_fn: Envelope callable ``(p, t, t_c) -> amplitude``, or
            ``None`` for gates drawn without a physical envelope.
        envelope_params: Envelope-shape parameters, excluding the rotation
            angle ``w`` and the time ``t``.
        w: Rotation angle passed to the gate.
        duration: Pulse duration (evolution time).
        carrier_phase: Phase offset ``phi_c`` of the carrier
            ``cos(omega_c * t + phi_c)``.
        parent: Name of the composite gate this pulse was decomposed from,
            if any.
    """

    gate: str
    wires: List[int]
    envelope_fn: Any
    envelope_params: Any
    w: float
    duration: float
    carrier_phase: float = 0.0
    parent: Optional[str] = None
# Leaf-gate metadata used when drawing pulse schedules.
# "physical" gates are driven by a shaped envelope; the virtual ones
# (RZ, CZ) have none and carry a zero carrier phase.
LEAF_META = {
    gate: {"carrier_phase": phase, "physical": physical}
    for gate, phase, physical in (
        ("RX", float(jnp.pi), True),
        ("RY", float(-jnp.pi / 2), True),
        ("RZ", 0.0, False),
        ("CZ", 0.0, False),
    )
}
def _resolve_wires_for_drawing(wire_fn, wires_list):
    """Map a decomposition-step ``wire_fn`` keyword to the wires it acts on.

    * ``"all"``: every wire (the given list itself);
    * ``"target"``: the last wire, or the given list when it holds only one;
    * ``"control"``: the first wire.

    Raises:
        ValueError: for any other ``wire_fn``.
    """
    if wire_fn == "all":
        return wires_list
    if wire_fn == "target":
        if len(wires_list) > 1:
            return [wires_list[-1]]
        return wires_list
    if wire_fn == "control":
        return [wires_list[0]]
    raise ValueError(f"Unknown wire_fn: {wire_fn!r}")
def collect_pulse_events(
    gate_name: str,
    w: Union[float, List[float]],
    wires: Union[int, List[int]],
    pulse_params: Any = None,
    parent: Optional[str] = None,
) -> List[PulseEvent]:
    """Decompose a (possibly composite) pulse gate into leaf PulseEvents.

    Walks the :class:`DecompositionStep` tree stored on :class:`PulseParams`
    and collects timing / envelope information for drawing — no quantum
    operations are applied.

    Leaf gates listed in ``LEAF_META`` as physical (RX, RY) yield an event
    that carries an envelope function. Virtual leaf gates (RZ, CZ) yield an
    event with ``envelope_fn=None`` and a nominal duration of ``1.0``.
    Composite gates are expanded recursively, step by step.

    Args:
        gate_name: Name of the gate (``"RX"``, ``"H"``, ``"CX"``, etc.).
        w: Rotation angle(s).
        wires: Qubit index or ``[control, target]``.
        pulse_params: Pulse parameters or ``None`` for defaults.
        parent: Label of the enclosing composite gate.

    Returns:
        Ordered list of :class:`PulseEvent` objects.

    Raises:
        ValueError: If ``gate_name`` is not a known pulse gate, or is a leaf
            gate without an entry in ``LEAF_META``.
    """
    # Function-level import; NOTE(review): presumably avoids a circular
    # import with ``qml_essentials.gates`` — confirm.
    from qml_essentials.gates import PulseEnvelope, PulseInformation

    pp_obj = PulseInformation.gate_by_name(gate_name)
    if pp_obj is None:
        raise ValueError(f"Unknown pulse gate: {gate_name!r}")

    wires_list = [wires] if isinstance(wires, int) else list(wires)
    # Recursive calls pass this label down, so every leaf event is tagged
    # with the outermost gate that was requested.
    parent_label = parent or gate_name

    # --- Leaf gate ---
    if pp_obj.is_leaf:
        meta = LEAF_META.get(gate_name)
        if meta is None:
            raise ValueError(f"Unknown pulse gate: {gate_name!r}")

        # Envelope entry for the currently selected envelope; its "fn" item
        # is used as the event's envelope function.
        info = PulseEnvelope.get(PulseInformation.get_envelope())
        pp = pp_obj.split_params(pulse_params)

        if meta["physical"]:
            # The last leaf parameter is the pulse duration; the remaining
            # ones shape the envelope.
            env_p = pp[:-1]
            dur = float(pp[-1])
            return [
                PulseEvent(
                    gate=gate_name,
                    wires=wires_list,
                    envelope_fn=info["fn"],
                    envelope_params=jnp.array(env_p),
                    w=float(w),
                    duration=dur,
                    carrier_phase=meta["carrier_phase"],
                    parent=parent_label,
                )
            ]
        else:
            # Virtual gate (RZ, CZ) — no physical envelope; drawn with a
            # fixed nominal duration of 1.0. A list-valued ``w`` is recorded
            # as 0.0.
            return [
                PulseEvent(
                    gate=gate_name,
                    wires=wires_list,
                    envelope_fn=None,
                    envelope_params=jnp.ravel(jnp.asarray(pp)),
                    w=float(w) if not isinstance(w, list) else 0.0,
                    duration=1.0,
                    carrier_phase=0.0,
                    parent=parent_label,
                )
            ]

    # --- Composite gate ---
    # ``split_params`` is zipped with the decomposition steps, one parameter
    # chunk per step. NOTE(review): ``zip`` stops at the shorter sequence,
    # so a length mismatch would silently drop steps.
    parts = pp_obj.split_params(pulse_params)
    events = []
    for step, child_params in zip(pp_obj.decomposition, parts):
        child_wires = _resolve_wires_for_drawing(step.wire_fn, wires_list)
        # A step may transform the parent's angle; otherwise it is reused.
        child_w = step.angle_fn(w) if step.angle_fn is not None else w
        events += collect_pulse_events(
            step.gate.name,
            child_w,
            child_wires,
            child_params,
            parent=parent_label,
        )
    return events
def _make_event_label(gate: str, parent: Optional[str]) -> str:
    """Return *gate*, suffixed with ``" (parent)"`` when a different, non-empty
    parent gate is given."""
    if not parent or parent == gate:
        return gate
    return f"{gate} ({parent})"
def _sample_envelope(ev: PulseEvent, t_lo: float, t_hi: float, n_samples: int):
    """Evaluate the scaled envelope ``w * envelope(t)`` on a uniform grid.

    The grid has ``n_samples`` points spanning ``[t_lo, t_hi]`` (local pulse
    time). The envelope is centred at ``duration / 2`` and evaluated in a
    single vectorised call.

    Returns:
        ``(t_arr, signal)`` — the sample times and the scaled envelope values.
    """
    times = jnp.linspace(t_lo, t_hi, n_samples)
    envelope = ev.envelope_fn(ev.envelope_params, times, ev.duration / 2)
    return times, envelope * ev.w
def _compute_display_window(
    ev: PulseEvent,
    n_samples: int,
    envelope_width: float = 1.0,
) -> Tuple[float, float, float]:
    """Compute the (t_lo, t_hi, amp_max) display window for a physical pulse.

    The display window is chosen adaptively, based on how much the envelope
    decays within the evolution window ``[0, duration]``. If the envelope is
    essentially zero at the edges (or at its centre), the evolution window
    is used as-is. Otherwise the window is widened until the envelope drops
    to ``edge_ratio ** 10`` of its central value, where ``edge_ratio`` is
    the amplitude at the window edge relative to the centre.

    ``envelope_width`` scales the resulting extension beyond the evolution
    window: ``1.0`` gives the default adaptive width, values ``> 1`` widen
    further, values in ``(0, 1)`` tighten, and ``0`` (or any negative value)
    clamps the display exactly to the evolution window ``[0, duration]``.

    Returns:
        ``(t_lo, t_hi, amp_max)`` — local time bounds and the peak absolute
        amplitude of the sampled signal with 10 % headroom.
    """
    t_c = ev.duration / 2

    if envelope_width <= 0:
        # A negative width would shrink (or even invert) the window below
        # the evolution interval, so treat it like 0 and clamp.
        t_lo, t_hi = 0.0, ev.duration
    else:
        val_center = float(ev.envelope_fn(ev.envelope_params, t_c, t_c))

        if abs(val_center) < 1e-12:
            # No meaningful edge/centre ratio when the centre vanishes.
            t_lo, t_hi = 0.0, ev.duration
        else:
            val_edge = float(ev.envelope_fn(ev.envelope_params, 0.0, t_c))
            edge_ratio = abs(val_edge / val_center)

            if edge_ratio < 0.01:
                # Already decayed inside the evolution window.
                t_lo, t_hi = 0.0, ev.duration
            else:
                target = edge_ratio**10
                # Bisect the offset from the centre at which the relative
                # amplitude drops to ``target``. ``lo`` starts at the window
                # edge (ratio == edge_ratio, above target when edge_ratio < 1);
                # ``hi`` assumes the envelope is below target by 50 x duration.
                lo, hi = ev.duration / 2, ev.duration * 50
                for _ in range(30):
                    mid = (lo + hi) / 2
                    val = float(ev.envelope_fn(ev.envelope_params, t_c + mid, t_c))
                    if abs(val / val_center) > target:
                        lo = mid
                    else:
                        hi = mid
                # Only the part beyond the evolution window is scaled.
                half_width = ev.duration / 2 + (hi - ev.duration / 2) * envelope_width
                t_lo = t_c - half_width
                t_hi = t_c + half_width

    _, signal = _sample_envelope(ev, t_lo, t_hi, n_samples)
    amp = float(jnp.max(jnp.abs(signal))) * 1.1
    return t_lo, t_hi, amp
def _draw_physical_pulse(
    ev: PulseEvent,
    t_start: float,
    t_lo: float,
    t_hi: float,
    axes,
    color: str,
    n_samples: int,
    omega_c: float,
    show_carrier: bool,
) -> None:
    """Draw a physical (RX/RY) pulse envelope on the axes of each of its wires.

    The envelope is sampled over the local window ``[t_lo, t_hi]`` and shifted
    by ``t_start`` onto the global time axis. Per wire, this draws the filled
    envelope, dashed markers at both ends of the evolution window, optionally
    the carrier-modulated waveform, and a label at the envelope's extremum.
    """
    t_local, signal = _sample_envelope(ev, t_lo, t_hi, n_samples)
    t_global = t_local + t_start
    window_edges = (t_start, t_start + ev.duration)

    for ax in (axes[wire] for wire in ev.wires):
        ax.fill_between(t_global, signal, alpha=0.12, color=color, zorder=2)
        ax.plot(t_global, signal, color=color, linewidth=1.4, alpha=0.85, zorder=3)

        for edge in window_edges:
            ax.axvline(
                edge,
                color=color,
                linestyle="--",
                linewidth=0.8,
                alpha=0.7,
                zorder=4,
            )

        if show_carrier:
            # The carrier phase is evaluated on local pulse time.
            carrier = jnp.cos(omega_c * t_local + ev.carrier_phase)
            ax.plot(
                t_global,
                signal * carrier,
                color=color,
                linewidth=0.8,
                alpha=0.8,
                zorder=2,
            )

        # Label above a positive extremum, below a negative one.
        peak = jnp.argmax(jnp.abs(signal))
        peak_value = signal[peak]
        ax.annotate(
            _make_event_label(ev.gate, ev.parent),
            xy=(float(t_global[peak]), float(peak_value)),
            fontsize=7,
            ha="center",
            va="bottom" if peak_value >= 0 else "top",
            color=color,
            fontweight="bold",
            zorder=5,
        )
def _draw_virtual_z(
    ev: PulseEvent, t_start: float, axes, color: str, amp_max: float
) -> None:
    """Mark a virtual-Z gate with a labelled dashed vertical line.

    The line sits at the midpoint of the event's time slot, spans
    ``±0.6 * amp_max`` and is labelled at ``0.85 * amp_max``.
    """
    t_mid = t_start + ev.duration / 2
    half_height = amp_max * 0.6
    label = _make_event_label(ev.gate, ev.parent)

    for wire in ev.wires:
        target_ax = axes[wire]
        target_ax.vlines(
            t_mid,
            -half_height,
            half_height,
            color=color,
            linestyle="--",
            linewidth=1.2,
            alpha=0.8,
            zorder=2,
        )
        target_ax.annotate(
            label,
            xy=(t_mid, amp_max * 0.85),
            fontsize=7,
            ha="center",
            va="bottom",
            color=color,
            zorder=5,
        )
def _draw_block(
    ev: PulseEvent, t_start: float, axes, color: str, amp_max: float
) -> None:
    """Draw a gate as a translucent rectangle on each of its wires.

    The rectangle covers the event's time slot and ``±0.6 * amp_max``
    vertically. The label is placed only once, on the first wire's axes.
    """
    label = _make_event_label(ev.gate, ev.parent)
    corner = (t_start, -amp_max * 0.6)
    height = amp_max * 1.2

    for wire in ev.wires:
        axes[wire].add_patch(
            mpatches.Rectangle(
                corner,
                ev.duration,
                height,
                alpha=0.2,
                facecolor=color,
                edgecolor=color,
                linewidth=1.0,
                zorder=1,
            )
        )

    first_ax = axes[ev.wires[0]]
    first_ax.annotate(
        label,
        xy=(t_start + ev.duration / 2, amp_max * 0.7),
        fontsize=7,
        ha="center",
        va="bottom",
        color=color,
        fontweight="bold",
        zorder=5,
    )
def draw_pulse_schedule(
    events: List[PulseEvent],
    n_qubits: int,
    n_samples: int = 200,
    show_carrier: bool = True,
    show_envelope: bool = True,
    envelope_width: float = 0.0,
    **kwargs: Any,
) -> Tuple:
    """Render a pulse schedule as a Matplotlib figure.

    Each qubit gets its own subplot row. Physical pulses (RX, RY) are drawn
    as filled envelope shapes, virtual-Z gates as dashed vertical lines, and
    other gates (e.g. CZ) as shaded rectangles on each of their wires.

    Args:
        events: Ordered list of :class:`PulseEvent` from
            :func:`collect_pulse_events`.
        n_qubits: Total number of qubits.
        n_samples: Number of time samples per pulse envelope.
        show_carrier: If ``True``, overlay the carrier-modulated waveform
            (envelope x cos) as a thin line.
        show_envelope: If ``True``, draw the full envelope shape for
            physical pulses. If ``False``, show them as simple rectangular
            blocks indicating duration only.
        envelope_width: Scales how far the displayed envelope extends
            beyond the evolution window. ``0`` (default) clamps the envelope
            exactly to the pulse duration, ``1.0`` uses the adaptive width,
            ``> 1`` widens further and values in between tighten it.
        **kwargs: Only ``figsize`` is consumed (passed to ``plt.subplots``);
            any other keyword argument is currently ignored.

    Returns:
        ``(fig, axes)`` — Matplotlib Figure and array of Axes.
    """
    from qml_essentials.gates import PulseGates, PulseInformation

    omega_c = float(PulseGates.omega_c)

    # Each event starts as soon as all of its wires are free, so events on
    # disjoint wires may overlap in time.
    wire_cursor: Dict[int, float] = {q: 0.0 for q in range(n_qubits)}
    scheduled: List[Tuple[PulseEvent, float]] = []  # (event, t_start)

    for ev in events:
        t_start = max(wire_cursor[w] for w in ev.wires)
        scheduled.append((ev, t_start))
        for w in ev.wires:
            wire_cursor[w] = t_start + ev.duration

    # An empty schedule falls back to unit length so the time axis never
    # collapses to a single point.
    t_total = max(wire_cursor.values(), default=0.0) or 1.0

    gate_colors = {
        "RX": "#1F78B4",
        "RY": "#E69F00",
        "RZ": "#009371",
        "CZ": "#ED665A",
    }

    fig, axes = plt.subplots(
        n_qubits,
        1,
        figsize=kwargs.pop("figsize", (max(8, t_total * 2.5), 1.8 * n_qubits)),
        sharex=True,
        squeeze=False,
    )
    axes = axes.ravel()

    for q in range(n_qubits):
        ax = axes[q]
        ax.set_ylabel(f"q{q}", rotation=0, labelpad=20, fontsize=11, va="center")
        ax.axhline(0, color="grey", linewidth=0.4, zorder=0)
        ax.set_yticks([])
        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)
        ax.spines["left"].set_visible(False)

    axes[-1].set_xlabel("Time", fontsize=11)

    # Pre-compute display windows, amplitude range, and x-limits
    display_windows: Dict[int, Tuple[float, float]] = {}
    amp_max = 1.0
    x_lo, x_hi = 0.0, t_total

    if show_envelope:
        for idx, (ev, t_start) in enumerate(scheduled):
            if ev.envelope_fn is None or ev.gate not in ("RX", "RY"):
                continue
            t_lo, t_hi, amp = _compute_display_window(ev, n_samples, envelope_width)
            display_windows[idx] = (t_lo, t_hi)
            amp_max = max(amp_max, amp)
            # Map local display bounds to global time coordinates
            x_lo = min(x_lo, t_lo + t_start)
            x_hi = max(x_hi, t_hi + t_start)

    x_margin = (x_hi - x_lo) * 0.05
    for q in range(n_qubits):
        axes[q].set_xlim(x_lo - x_margin, x_hi + x_margin)
        axes[q].set_ylim(-amp_max, amp_max)

    # Draw events
    for idx, (ev, t_start) in enumerate(scheduled):
        color = gate_colors.get(ev.gate, "#bab0ac")

        # Same condition as the pre-compute loop, so display_windows[idx]
        # is guaranteed to exist here.
        if ev.gate in ("RX", "RY") and ev.envelope_fn is not None and show_envelope:
            t_lo, t_hi = display_windows[idx]
            _draw_physical_pulse(
                ev,
                t_start,
                t_lo,
                t_hi,
                axes,
                color,
                n_samples,
                omega_c,
                show_carrier,
            )
        elif ev.gate == "RZ":
            _draw_virtual_z(ev, t_start, axes, color, amp_max)
        else:
            _draw_block(ev, t_start, axes, color, amp_max)

    # Legend: one entry per gate type that actually occurs.
    used_gates = {ev.gate for ev, _ in scheduled}
    handles = [
        mpatches.Patch(color=c, alpha=0.7, label=g)
        for g, c in gate_colors.items()
        if g in used_gates
    ]
    if handles:
        fig.legend(
            handles=handles,
            loc="lower right",
            ncol=len(handles),
            fontsize=8,
            framealpha=0.8,
        )

    fig.suptitle(
        f"Pulse Schedule ({PulseInformation.get_envelope()} envelope)",
        fontsize=12,
        fontweight="bold",
    )
    fig.tight_layout()
    return fig, axes