Coverage for qml_essentials/utils.py: 95%
276 statements
« prev ^ index » next — coverage.py v7.9.2, created at 2026-02-20 14:03 +0000
1from __future__ import annotations
2from typing import List, Tuple
3import jax
4import jax.numpy as jnp
5import numpy as np
6import pennylane as qml
7from pennylane.operation import Operator
8from pennylane.tape import QuantumScript, QuantumTape
9import pennylane.ops.op_math as qml_op
10from fractions import Fraction
11from itertools import cycle
12from scipy.linalg import logm
# Gate classes treated as Clifford gates. PauliCircuit commutes them to the
# end of a circuit and absorbs them into the observables.
CLIFFORD_GATES = (
    # single-qubit Paulis (long and short names)
    qml.PauliX, qml.PauliY, qml.PauliZ,
    qml.X, qml.Y, qml.Z,
    # remaining Clifford gates
    qml.Hadamard, qml.S, qml.CNOT,
)

# Parameterised gates treated as Pauli rotations exp(-i * theta / 2 * P).
PAULI_ROTATION_GATES = (qml.RX, qml.RY, qml.RZ, qml.PauliRot)

# Operations that carry no unitary action and are dropped when building a
# Pauli-Clifford circuit.
SKIPPABLE_OPERATIONS = (qml.Barrier,)
def safe_random_split(random_key: jax.random.PRNGKey, *args, **kwargs):
    """
    Split a JAX PRNG key, tolerating a missing key.

    Args:
        random_key (jax.random.PRNGKey): Key to split, or None.
        *args, **kwargs: Forwarded unchanged to ``jax.random.split``.

    Returns:
        The result of ``jax.random.split(random_key, *args, **kwargs)``, or
        the pair ``(None, None)`` when ``random_key`` is None (regardless of
        the requested number of splits).
    """
    if random_key is not None:
        return jax.random.split(random_key, *args, **kwargs)
    return None, None
def logm_v(A: jnp.ndarray, **kwargs) -> jnp.ndarray:
    """
    Compute the logarithm of a matrix. If the provided matrix has an additional
    batch dimension, the logarithm of each matrix is computed.

    Args:
        A (jnp.ndarray): Either a single matrix (2-D) or a batch of matrices
            (3-D, batch axis first) of which to compute the logarithm.
        **kwargs: Forwarded to ``scipy.linalg.logm``.

    Returns:
        jnp.ndarray: For 2-D input, the result of ``scipy.linalg.logm``.
        For 3-D input, a JAX array of the same shape as ``A`` holding the
        logarithm of each matrix, cast to ``jnp.complex128`` (JAX downcasts
        this to complex64 unless x64 mode is enabled).

    Raises:
        NotImplementedError: If ``A`` is neither 2-D nor 3-D.
    """
    # TODO: check warnings
    ndim = len(A.shape)
    if ndim == 2:
        return logm(A, **kwargs)
    if ndim == 3:
        if A.shape[0] == 0:
            # Nothing to stack; keep the shape/dtype contract for empty batches
            return jnp.zeros(A.shape, dtype=jnp.complex128)
        # Collect all logs on the host and convert once. Updating a JAX array
        # with .at[i].set() in a loop copies the whole batch per iteration.
        logs = np.stack([logm(mat, **kwargs) for mat in A])
        return jnp.asarray(logs, dtype=jnp.complex128)
    raise NotImplementedError("Unsupported shape of input matrix")
class PauliCircuit:
    """
    Wrapper for Pauli-Clifford Circuits described by Nemkov et al.
    (https://doi.org/10.1103/PhysRevA.108.032406). The code is inspired
    by the corresponding implementation: https://github.com/idnm/FourierVQA.

    A Pauli Circuit only consists of parameterised Pauli-rotations and Clifford
    gates, which is the default for the most common VQCs.
    """

    @staticmethod
    def from_parameterised_circuit(
        tape: QuantumScript,
    ) -> tuple[QuantumScript]:
        """
        Transforms the quantum tape of a circuit into a Pauli-Clifford circuit.

        Args:
            tape (QuantumScript): The quantum tape for the operations in the
                ansatz. This is automatically passed, when initialising the
                transform function with a QNode. Note: directly calling
                `PauliCircuit.from_parameterised_circuit(circuit)` for a QNode
                circuit will fail

        Returns:
            QuantumScript:
                - A new quantum tape, containing only the Pauli rotations of
                  the Pauli-Clifford circuit, with expectation values of the
                  observables that absorbed the trailing Clifford gates.
        """
        operations = PauliCircuit.get_clifford_pauli_gates(tape)

        pauli_gates, final_cliffords = PauliCircuit.commute_all_cliffords_to_the_end(
            operations
        )

        observables = PauliCircuit.cliffords_in_observable(
            final_cliffords, tape.observables
        )

        with QuantumTape() as tape_new:
            for op in pauli_gates:
                op.queue()
            for obs in observables:
                qml.expval(obs)

        return tape_new

    @staticmethod
    def commute_all_cliffords_to_the_end(
        operations: List[Operator],
    ) -> Tuple[List[Operator], List[Operator]]:
        """
        This function moves all clifford gates to the end of the circuit,
        accounting for commutation rules.

        Each Clifford gate is bubbled past the Pauli rotations that follow it;
        every swap replaces the rotation by its Clifford-evolved counterpart.
        Cliffords are processed from the back, so the suffix behind the
        current gate is always of the form [rotations..., cliffords...].

        Args:
            operations (List[Operator]): The operations in the tape of the
                circuit. The list itself is not modified.

        Returns:
            Tuple[List[Operator], List[Operator]]:
                - List of the resulting Pauli-rotations
                - List of the resulting Clifford gates (empty if there are none)
        """
        operations = list(operations)
        n_ops = len(operations)
        # Start index of the trailing Clifford block; n_ops means "no Cliffords"
        first_clifford = n_ops
        for i in range(n_ops - 1, -1, -1):
            if not PauliCircuit._is_clifford(operations[i]):
                continue
            j = i
            while j + 1 < n_ops and PauliCircuit._is_pauli_rotation(
                operations[j + 1]
            ):
                pauli, clifford = PauliCircuit._evolve_clifford_rotation(
                    operations[j], operations[j + 1]
                )
                operations[j] = pauli
                operations[j + 1] = clifford
                j += 1
            # This Clifford now sits directly in front of the Clifford block
            first_clifford = j

        return operations[:first_clifford], operations[first_clifford:]

    @staticmethod
    def get_clifford_pauli_gates(tape: QuantumScript) -> List[Operator]:
        """
        This function decomposes all gates in the circuit to clifford and
        pauli-rotation gates. Skippable operations (barriers) are dropped.

        Args:
            tape (QuantumScript): The tape of the circuit containing all
                operations.

        Returns:
            List[Operator]: A list of operations consisting only of clifford
                and Pauli-rotation gates.
        """
        operations = []
        for operation in tape.operations:
            if PauliCircuit._is_clifford(operation) or PauliCircuit._is_pauli_rotation(
                operation
            ):
                operations.append(operation)
            elif PauliCircuit._is_skippable(operation):
                continue
            else:
                # TODO: Maybe there is a prettier way to decompose a gate
                # We currently can not handle parametrised input gates, that
                # are not plain pauli rotations
                single_op_tape = QuantumScript([operation])
                decomposed_tapes, _ = qml.transforms.decompose(
                    single_op_tape, gate_set=PAULI_ROTATION_GATES + CLIFFORD_GATES
                )
                decomposed_ops = decomposed_tapes[0].operations
                # Rebuild rotations with their parameters wrapped in an array;
                # wires are passed by keyword and hyperparameters forwarded so
                # that e.g. PauliRot keeps its pauli_word.
                decomposed_ops = [
                    (
                        op
                        if PauliCircuit._is_clifford(op)
                        else op.__class__(
                            jnp.array(op.parameters),
                            wires=op.wires,
                            **op.hyperparameters,
                        )
                    )
                    for op in decomposed_ops
                ]
                operations.extend(decomposed_ops)

        return operations

    @staticmethod
    def _is_skippable(operation: Operator) -> bool:
        """
        Determines whether an operator can be ignored when building the Pauli
        Clifford circuit. Currently this only contains barriers.

        Args:
            operation (Operator): Gate operation

        Returns:
            bool: Whether the operation can be skipped.
        """
        return isinstance(operation, SKIPPABLE_OPERATIONS)

    @staticmethod
    def _is_clifford(operation: Operator) -> bool:
        """
        Determines whether an operator is a Clifford gate.

        Args:
            operation (Operator): Gate operation

        Returns:
            bool: Whether the operation is Clifford.
        """
        return isinstance(operation, CLIFFORD_GATES)

    @staticmethod
    def _is_pauli_rotation(operation: Operator) -> bool:
        """
        Determines whether an operator is a Pauli rotation gate.

        Args:
            operation (Operator): Gate operation

        Returns:
            bool: Whether the operation is a Pauli rotation.
        """
        return isinstance(operation, PAULI_ROTATION_GATES)

    @staticmethod
    def _evolve_clifford_rotation(
        clifford: Operator, pauli: Operator
    ) -> Tuple[Operator, Operator]:
        """
        This function computes the resulting operations, when switching a
        Clifford gate and a Pauli rotation in the circuit.

        **Example**:
        Consider a circuit consisting of the gate sequence
        ... --- H --- R_z --- ...
        This function computes the evolved Pauli Rotation, and moves the
        clifford (Hadamard) gate to the end:
        ... --- R_x --- H --- ...

        Since R_P(theta) C = C R_{P'}(theta) with P' = C^dagger P C, the
        generator is conjugated with the adjoint on the left.

        Args:
            clifford (Operator): Clifford gate to move.
            pauli (Operator): Pauli rotation gate to move the clifford past.

        Returns:
            Tuple[Operator, Operator]:
                - Evolved Pauli rotation operator
                - Clifford operator (unchanged)
        """

        if not any(p_c in clifford.wires for p_c in pauli.wires):
            return pauli, clifford

        gen = pauli.generator()
        param = pauli.parameters[0]

        evolved_gen, _ = PauliCircuit._evolve_clifford_pauli(
            clifford, gen, adjoint_left=True
        )
        qubits = evolved_gen.wires
        # pauli_decompose labels wires 0..n-1, i.e. positions within `qubits`
        evolved_gen = qml.pauli_decompose(evolved_gen.matrix())
        pauli_str, param_factor = PauliCircuit._get_paulistring_from_generator(
            evolved_gen
        )
        pauli_str, qubits = PauliCircuit._remove_identities_from_paulistr(
            pauli_str, qubits
        )
        pauli = qml.PauliRot(param * param_factor, pauli_str, qubits)

        return pauli, clifford

    @staticmethod
    def _remove_identities_from_paulistr(
        pauli_str: str, qubits: List[int]
    ) -> Tuple[str, List[int]]:
        """
        Removes identities from Pauli string and its corresponding qubits.

        Args:
            pauli_str (str): Pauli string
            qubits (List[int]): Corresponding qubit indices.

        Returns:
            Tuple[str, List[int]]:
                - Pauli string without identities
                - Qubits indices without the identities
        """

        reduced_qubits = []
        reduced_pauli_str = ""
        for i, p in enumerate(pauli_str):
            if p != "I":
                reduced_pauli_str += p
                reduced_qubits.append(qubits[i])

        return reduced_pauli_str, reduced_qubits

    @staticmethod
    def _evolve_clifford_pauli(
        clifford: Operator, pauli: Operator, adjoint_left: bool = True
    ) -> Tuple[Operator, Operator]:
        """
        This function computes the resulting operation, when evolving a Pauli
        Operation with a Clifford operation.
        For a Clifford operator C and a Pauli operator P, this function computes:
        P' = C* P C

        Args:
            clifford (Operator): Clifford gate
            pauli (Operator): Pauli gate
            adjoint_left (bool, optional): If adjoint of the clifford gate is
                applied to the left. If this is set to True C* P C is computed,
                else C P C*. Defaults to True.

        Returns:
            Tuple[Operator, Operator]:
                - Evolved Pauli operator
                - Resulting Clifford operator (should be the same as the input)
        """
        if not any(p_c in clifford.wires for p_c in pauli.wires):
            return pauli, clifford

        if adjoint_left:
            evolved_pauli = qml.adjoint(clifford) @ pauli @ clifford
        else:
            evolved_pauli = clifford @ pauli @ qml.adjoint(clifford)

        return evolved_pauli, clifford

    @staticmethod
    def _evolve_cliffords_list(cliffords: List[Operator], pauli: Operator) -> Operator:
        """
        This function evolves a Pauli operation according to a sequence of
        cliffords, i.e. computes C_1* ... C_k* P C_k ... C_1 for the cliffords
        C_1, ..., C_k in circuit order.

        Args:
            cliffords (List[Operator]): Clifford gates in circuit order
            pauli (Operator): Pauli operator (observable)

        Returns:
            Operator: Evolved Pauli operator
        """
        for clifford in cliffords[::-1]:
            pauli, _ = PauliCircuit._evolve_clifford_pauli(clifford, pauli)
            qubits = pauli.wires
            pauli = qml.pauli_decompose(pauli.matrix(), wire_order=qubits)

        pauli = qml.simplify(pauli)

        # remove coefficients
        # NOTE(review): this also drops a possible -1 sign of the evolved
        # observable; confirm that callers only rely on the Pauli word.
        pauli = (
            pauli.terms()[1][0]
            if isinstance(pauli, (qml_op.Prod, qml_op.LinearCombination))
            else pauli
        )

        return pauli

    @staticmethod
    def _get_paulistring_from_generator(
        gen: qml_op.LinearCombination,
    ) -> Tuple[str, float]:
        """
        Compute a Paulistring, consisting of "X", "Y", "Z" and "I" from a
        generator.

        Args:
            gen (qml_op.LinearCombination): The generator operation created by
                Pennylane

        Returns:
            Tuple[str, float]:
                - The Paulistring
                - A factor with which to multiply a parameter to the rotation
                  gate.
        """
        factor, term = gen.terms()
        param_factor = -2 * factor  # Rotation is defined as exp(-0.5 theta G)
        pauli_term = term[0] if isinstance(term[0], qml_op.Prod) else [term[0]]
        pauli_str_list = ["I"] * len(pauli_term)
        for p in pauli_term:
            if "Pauli" in p.name:
                # wire labels are assumed to be positions 0..n-1 (as produced
                # by pauli_decompose without wire_order)
                q = p.wires[0]
                pauli_str_list[q] = p.name[-1]
        pauli_str = "".join(pauli_str_list)
        return pauli_str, param_factor

    @staticmethod
    def cliffords_in_observable(
        operations: List[Operator], original_obs: List[Operator]
    ) -> List[Operator]:
        """
        Integrates Clifford gates in the observables of the original ansatz.

        Args:
            operations (List[Operator]): Clifford gates
            original_obs (List[Operator]): Original observables from the
                circuit

        Returns:
            List[Operator]: Observables with Clifford operations
        """
        observables = []
        for ob in original_obs:
            clifford_obs = PauliCircuit._evolve_cliffords_list(operations, ob)
            observables.append(clifford_obs)
        return observables
class QuanTikz:
    """
    Helpers to render PennyLane circuits as quantikz (LaTeX) diagrams.
    """

    class TikzFigure:
        """Container for a quantikz circuit body with LaTeX export helpers."""

        def __init__(self, quantikz_str: str):
            self.quantikz_str = quantikz_str

        def __repr__(self):
            return self.quantikz_str

        def __str__(self):
            return self.quantikz_str

        def wrap_figure(self):
            """
            Wraps the quantikz string in a LaTeX figure environment.

            Returns:
                str: A formatted LaTeX string representing the TikZ figure containing
                the quantum circuit diagram.
            """
            return f"""
\\begin{{figure}}
    \\centering
    \\begin{{tikzpicture}}
        \\node[scale=0.85] {{
            \\begin{{quantikz}}
                {self.quantikz_str}
            \\end{{quantikz}}
        }};
    \\end{{tikzpicture}}
\\end{{figure}}"""

        def export(self, destination: str, full_document=False, mode="w") -> None:
            """
            Export a LaTeX document with a quantum circuit in stick notation.

            Parameters
            ----------
            destination : str
                Path to the destination file.
            full_document : bool, optional
                If True, write a standalone LaTeX document (preamble plus the
                circuit wrapped in a figure environment). Otherwise write only
                the quantikz string followed by a newline.
            mode : str, optional
                File mode passed to ``open`` (e.g. "w" to overwrite, "a" to
                append). Defaults to "w".
            """
            if full_document:
                latex_code = f"""
\\documentclass{{article}}
\\usepackage{{quantikz}}
\\usepackage{{tikz}}
\\usetikzlibrary{{quantikz2}}
\\usepackage{{quantikz}}
\\usepackage[a3paper, landscape, margin=0.5cm]{{geometry}}
\\begin{{document}}
{self.wrap_figure()}
\\end{{document}}"""
            else:
                latex_code = self.quantikz_str + "\n"

            with open(destination, mode, encoding="utf-8") as f:
                f.write(latex_code)

    @staticmethod
    def ground_state() -> str:
        """
        Generate the LaTeX representation of the |0⟩ ground state in stick notation.

        Returns
        -------
        str
            LaTeX string for the |0⟩ state.
        """
        return "\\lstick{\\ket{0}}"

    @staticmethod
    def measure(op):
        """
        Generate LaTeX for a single-wire measurement.

        Raises
        ------
        NotImplementedError
            If the measurement acts on more than one wire.
        """
        if len(op.wires) > 1:
            raise NotImplementedError("Multi-wire measurements are not supported yet")
        else:
            return "\\meter{}"

    @staticmethod
    def _is_input_gate(op) -> bool:
        """Whether ``op``'s first parameter has shape (1,), i.e. is a data input."""
        return len(op.parameters) > 0 and np.shape(op.parameters[0]) == (1,)

    @staticmethod
    def _first_param_value(op) -> float:
        """First parameter of ``op`` as a Python float (floats or size-1 arrays)."""
        return float(np.asarray(op.parameters[0]).item())

    @staticmethod
    def search_pi_fraction(w, op_name):
        """
        Render a gate label showing angle ``w`` as a small fraction of pi.

        Falls back to two decimals when ``w / pi`` is not (close to) a fraction
        with denominator at most 12.

        Parameters
        ----------
        w : float
            Gate angle in radians.
        op_name : str
            Name printed in the gate box.

        Returns
        -------
        str
            LaTeX string for the gate.
        """
        ratio = w / jnp.pi
        w_pi = Fraction(ratio).limit_denominator(100)
        # Not a small nice fraction, or the nearest small fraction is only a
        # rough approximation (e.g. 0.01 would otherwise be rendered as 0)
        if w_pi.denominator > 12 or abs(float(w_pi) - ratio) > 1e-5:
            return f"\\gate{{{op_name}({w:.2f})}}"

        sign = "-" if w_pi < 0 else ""
        num, den = abs(w_pi.numerator), w_pi.denominator
        # 0
        if num == 0:
            return f"\\gate{{{op_name}(0)}}"
        # Pi or integer multiple of pi
        if den == 1:
            coeff = "" if num == 1 else str(num)
            return f"\\gate{{{op_name}({sign}{coeff}\\pi)}}"
        # Nice fraction of pi
        if num == 1:
            return (
                f"\\gate{{{op_name}\\left("
                f"{sign}\\frac{{\\pi}}{{{den}}}\\right)}}"
            )
        # Small nice fraction
        return (
            f"\\gate{{{op_name}\\left("
            f"{sign}\\frac{{{num}\\pi}}{{{den}}}"
            f"\\right)}}"
        )

    @staticmethod
    def gate(op, index=None, gate_values=False, inputs_symbols="x") -> str:
        """
        Generate LaTeX for a quantum gate in stick notation.

        Parameters
        ----------
        op : qml.Operation
            The quantum gate to represent.
        index : int, optional
            Gate index in the circuit, used as theta subscript for gates with
            a scalar (trainable) parameter.
        gate_values : bool, optional
            Include gate values in the representation.
        inputs_symbols : str, optional
            Symbol used for gates whose parameter has shape (1,) (inputs).

        Returns
        -------
        str
            LaTeX string for the gate.
        """
        op_name = {"Hadamard": "H", "Rot": "R"}.get(op.name, op.name)

        if gate_values and len(op.parameters) > 0:
            return QuanTikz.search_pi_fraction(
                QuanTikz._first_param_value(op), op_name
            )

        # Gates without parameters never carry a theta label
        if len(op.parameters) == 0:
            return f"\\gate{{{op_name}}}"

        shape = np.shape(op.parameters[0])
        # Is gate with parameter
        if shape == ():
            if index is None:
                return f"\\gate{{{op_name}}}"
            return f"\\gate{{{op_name}(\\theta_{{{index}}})}}"
        # Is gate with input
        if shape == (1,):
            return f"\\gate{{{op_name}({inputs_symbols})}}"
        # Unrecognised parameter layout: draw a plain gate box
        return f"\\gate{{{op_name}}}"

    @staticmethod
    def cgate(op, index=None, gate_values=False, inputs_symbols="x") -> Tuple[str, str]:
        """
        Generate LaTeX for a controlled quantum gate in stick notation.

        Parameters
        ----------
        op : qml.Operation
            The quantum gate operation to represent.
        index : int, optional
            Gate index in the circuit.
        gate_values : bool, optional
            Include gate values in the representation.
        inputs_symbols : str, optional
            Symbols for the inputs in the representation.

        Returns
        -------
        Tuple[str, str]
            - LaTeX string for the control gate
            - LaTeX string for the target gate
        """
        if op.name in ("CRX", "CRY", "CRZ"):
            op_name = op.name[1:]
            if gate_values and len(op.parameters) > 0:
                targ = QuanTikz.search_pi_fraction(
                    QuanTikz._first_param_value(op), op_name
                )
            else:
                shape = np.shape(op.parameters[0])
                # Is gate with parameter
                if shape == ():
                    if index is None:
                        targ = f"\\gate{{{op_name}}}"
                    else:
                        targ = f"\\gate{{{op_name}(\\theta_{{{index}}})}}"
                # Is gate with input
                elif shape == (1,):
                    targ = f"\\gate{{{op_name}({inputs_symbols})}}"
                else:
                    targ = "\\targ{}"
        elif op.name == "CZ":
            # CZ is symmetric and drawn with a control dot on both wires
            targ = "\\control{}"
        elif op.name == "CY":
            targ = "\\gate{Y}"
        else:
            # CNOT / CX and other two-wire gates: X target
            targ = "\\targ{}"

        distance = op.wires[1] - op.wires[0]
        return f"\\ctrl{{{distance}}}", targ

    @staticmethod
    def barrier(op) -> str:
        """
        Generate LaTeX for a barrier in stick notation.

        Parameters
        ----------
        op : qml.Operation
            The barrier operation to represent.

        Returns
        -------
        str
            LaTeX string for the barrier.
        """
        # Plain string (not an f-string): braces must not be doubled
        return (
            "\\slice[style={draw=black, solid, double distance=2pt, "
            "line width=0.5pt}]{}"
        )

    @staticmethod
    def _build_tikz_circuit(quantum_tape, gate_values=False, inputs_symbols="x"):
        """
        Builds a LaTeX representation of a quantum circuit in TikZ format.

        This static method constructs a TikZ circuit diagram from a given quantum
        tape. It processes the operations in the tape, including gates, controlled
        gates, barriers, and measurements. The resulting structure is a list of
        LaTeX strings, each representing a wire in the circuit.

        Parameters
        ----------
        quantum_tape : QuantumTape
            The quantum tape containing the operations of the circuit.
        gate_values : bool, optional
            If True, include gate parameter values in the representation.
        inputs_symbols : str, list of str or iterator of str, optional
            Symbols to represent the inputs in the circuit. One symbol is
            consumed per input gate (parameter of shape (1,)).

        Returns
        -------
        circuit_tikz : list of list of str
            A nested list where each inner list contains LaTeX strings representing
            the operations on a single wire of the circuit.
        """
        # Accept a plain symbol or a list in addition to an iterator
        if isinstance(inputs_symbols, str):
            inputs_symbols = cycle([inputs_symbols])
        elif isinstance(inputs_symbols, (list, tuple)):
            inputs_symbols = cycle(inputs_symbols)

        circuit_tikz = [
            [QuanTikz.ground_state()] for _ in range(quantum_tape.num_wires)
        ]

        def longest():
            # length of the longest wire so far
            return max((len(ops) for ops in circuit_tikz), default=0)

        def pad(wires, length):
            # extend the given wires with empty cells up to `length`
            for w in wires:
                circuit_tikz[w].extend("" for _ in range(length - len(circuit_tikz[w])))

        def next_symbol(op):
            # only input gates consume a symbol, so symbols stay in input order
            return next(inputs_symbols) if QuanTikz._is_input_gate(op) else None

        index = iter(range(len(quantum_tape.operations)))
        # column at which all meters are aligned (fixed at the first one)
        measure_col = None
        for op in quantum_tape.circuit:
            # catch measurement operations
            if op._queue_category == "_measurements":
                if measure_col is None:
                    measure_col = longest()
                wire = op.wires[0]
                pad([wire], measure_col)
                circuit_tikz[wire].append(QuanTikz.measure(op))
            # process all gates
            elif op._queue_category == "_ops":
                # catch barriers
                if op.name == "Barrier":
                    pad(range(len(circuit_tikz)), longest())
                    circuit_tikz[op.wires[0]][-1] += QuanTikz.barrier(op)
                # single qubit gate?
                elif len(op.wires) == 1:
                    symbol = next_symbol(op)
                    circuit_tikz[op.wires[0]].append(
                        QuanTikz.gate(
                            op,
                            index=next(index),
                            gate_values=gate_values,
                            inputs_symbols=symbol,
                        )
                    )
                # controlled gate?
                elif len(op.wires) == 2:
                    if op.name in ["CRX", "CRY", "CRZ"]:
                        symbol = next_symbol(op)
                        ctrl, targ = QuanTikz.cgate(
                            op,
                            index=next(index),
                            gate_values=gate_values,
                            inputs_symbols=symbol,
                        )
                    else:
                        ctrl, targ = QuanTikz.cgate(op)

                    # wires the controlled gate spans over
                    crossing_wires = list(range(min(op.wires), max(op.wires) + 1))
                    pad(crossing_wires, max(len(circuit_tikz[w]) for w in crossing_wires))

                    circuit_tikz[op.wires[0]].append(ctrl)
                    circuit_tikz[op.wires[1]].append(targ)

                    # keep the wires in between aligned
                    for w in crossing_wires:
                        if w not in op.wires:
                            circuit_tikz[w].append("")
                else:
                    raise NotImplementedError(">2-wire gates are not supported yet")

        return circuit_tikz

    @staticmethod
    def build(
        circuit: qml.QNode,
        params,
        inputs,
        enc_params=None,
        gate_values=False,
        inputs_symbols="x",
    ) -> "QuanTikz.TikzFigure":
        """
        Generate LaTeX for a quantum circuit in stick notation.

        Parameters
        ----------
        circuit : qml.QNode
            The quantum circuit to represent.
        params : array
            Weight parameters for the circuit.
        inputs : array
            Inputs for the circuit.
        enc_params : array
            Encoding weight parameters for the circuit.
        gate_values : bool, optional
            Toggle for gate values or theta variables in the representation.
        inputs_symbols : str or list of str, optional
            Symbols for the inputs in the representation. A string is indexed
            (``x_0``, ``x_1``, ...) when there is more than one input; a list
            must have one symbol per input.

        Returns
        -------
        QuanTikz.TikzFigure
            Figure holding the LaTeX string for the circuit.
        """
        if enc_params is not None:
            quantum_tape = qml.workflow.construct_tape(circuit)(
                params=params, inputs=inputs, enc_params=enc_params
            )
        else:
            quantum_tape = qml.workflow.construct_tape(circuit)(
                params=params, inputs=inputs
            )

        n_inputs = np.size(inputs)
        if isinstance(inputs_symbols, str) and n_inputs > 1:
            symbols = cycle([f"{inputs_symbols}_{i}" for i in range(n_inputs)])
        elif isinstance(inputs_symbols, list):
            assert len(inputs_symbols) == n_inputs, (
                f"The number of input symbols {len(inputs_symbols)} "
                f"must match the number of inputs {n_inputs}."
            )
            symbols = cycle(inputs_symbols)
        else:
            symbols = cycle([inputs_symbols])

        circuit_tikz = QuanTikz._build_tikz_circuit(
            quantum_tape, gate_values=gate_values, inputs_symbols=symbols
        )

        # pad all wires to the same length
        max_len = max((len(wire_ops) for wire_ops in circuit_tikz), default=0)
        for wire_ops in circuit_tikz:
            wire_ops.extend("" for _ in range(max_len - len(wire_ops)))

        # cells separated by " & ", wires by " \\" + newline
        quantikz_str = " \\\\\n".join(" & ".join(wire_ops) for wire_ops in circuit_tikz)

        return QuanTikz.TikzFigure(quantikz_str)