Coverage for qml_essentials/utils.py: 95%
316 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-29 14:55 +0000
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-29 14:55 +0000
1from __future__ import annotations
2from typing import List, Tuple
3import numpy as np
4import pennylane as qml
5from pennylane.operation import Operator
6from pennylane.tape import QuantumScript, QuantumScriptBatch, QuantumTape
7from pennylane.typing import PostprocessingFn
8import pennylane.numpy as pnp
9import pennylane.ops.op_math as qml_op
10from pennylane.drawer import drawable_layers, tape_text
11from fractions import Fraction
12from itertools import cycle
13from scipy.linalg import logm
14import dill
15import multiprocessing
16import os
# Gate types treated as Clifford gates when a circuit is rewritten as a
# Pauli-Clifford circuit. The short aliases qml.X / qml.Y / qml.Z are listed
# next to qml.PauliX / qml.PauliY / qml.PauliZ.
CLIFFORD_GATES = (
    qml.PauliX,
    qml.PauliY,
    qml.PauliZ,
    qml.X,
    qml.Y,
    qml.Z,
    qml.Hadamard,
    qml.S,
    qml.CNOT,
)

# Parameterised rotation gates that are kept as Pauli rotations.
PAULI_ROTATION_GATES = (qml.RX, qml.RY, qml.RZ, qml.PauliRot)

# Operations that are ignored when building the Pauli-Clifford circuit.
SKIPPABLE_OPERATIONS = (qml.Barrier,)
class MultiprocessingPool:
    """
    Run ``target`` in ``n_processes`` separate processes, with at most a
    bounded number of them running at the same time.

    Every process calls ``target(it, return_dict, *args, **kwargs)``, where
    ``it`` is the index of the process and ``return_dict`` is a dictionary
    created by a ``multiprocessing.Manager`` that is shared by all processes.
    The target is transferred to the child processes with ``dill`` (which can
    serialise e.g. lambdas and closures the standard pickle module cannot).
    """

    class DillProcess(multiprocessing.Process):
        """``multiprocessing.Process`` whose target is serialised with dill."""

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            # Save the target function as bytes, using dill
            self._target = dill.dumps(self._target)

        def run(self):
            if self._target:
                # Unpickle the target function before executing
                self._target = dill.loads(self._target)
                # Execute the target function
                return self._target(*self._args, **self._kwargs)

    def __init__(self, target, n_processes, cpu_scaler, *args, **kwargs):
        """
        Args:
            target: Callable run in every process as
                ``target(it, return_dict, *args, **kwargs)``.
            n_processes (int): Total number of processes to run.
            cpu_scaler (float): Fraction in [0, 1] of the available CPUs that
                may be busy at the same time. At least one process is always
                allowed to run.
            *args: Extra positional arguments forwarded to ``target``.
            **kwargs: Extra keyword arguments forwarded to ``target``.
        """
        self.target = target
        self.n_processes = n_processes
        self.cpu_scaler = cpu_scaler
        self.args = args
        self.kwargs = kwargs

        assert (
            self.cpu_scaler <= 1 and self.cpu_scaler >= 0
        ), f"cpu_scaler must in [0..1], got {self.cpu_scaler}"

    @staticmethod
    def _available_cpus() -> int:
        """Number of CPUs this process may use (portable fallback)."""
        # os.sched_getaffinity is not available on every platform
        # (e.g. macOS, Windows); fall back to the total CPU count there.
        if hasattr(os, "sched_getaffinity"):
            return len(os.sched_getaffinity(0))
        return os.cpu_count() or 1

    def _max_concurrent(self) -> int:
        """Maximum number of processes that may run at the same time."""
        return max(int(self._available_cpus() * self.cpu_scaler), 1)

    def spawn(self):
        """
        Start all processes batch-wise and wait until all of them finished.

        Returns:
            The manager-backed dictionary that was passed to every process as
            ``return_dict``.
        """
        manager = multiprocessing.Manager()
        return_dict = manager.dict()

        jobs = []
        n_procs = self._max_concurrent()
        c_procs = 0
        for it in range(self.n_processes):
            m = self.DillProcess(
                target=self.target,
                args=[it, return_dict, *self.args],
                kwargs=self.kwargs,
            )

            # append and start job
            jobs.append(m)
            m.start()
            c_procs += 1

            # Once the limit is reached (not exceeded), wait for the current
            # batch; ">" would allow n_procs + 1 concurrent processes.
            if c_procs >= n_procs:
                for j in jobs[-c_procs:]:
                    j.join()
                # then continue with the next batch
                c_procs = 0

        # wait for any remaining jobs
        for j in jobs:
            if j.is_alive():
                j.join()

        return return_dict
def logm_v(A, **kwargs):
    """
    Matrix logarithm of a single matrix or of a batch of matrices.

    Args:
        A: A square matrix of shape (n, n) or a batch of square matrices of
            shape (b, n, n).
        **kwargs: Passed on to ``scipy.linalg.logm``.

    Returns:
        The matrix logarithm. For a batch, the results of the individual
        ``logm`` calls are stacked along the first axis, so the result dtype is
        the one returned by ``logm`` (e.g. complex when a matrix has a negative
        eigenvalue) rather than ``A.dtype``.

    Raises:
        NotImplementedError: If ``A`` is neither 2- nor 3-dimensional.
    """
    # TODO: check warnings
    if len(A.shape) == 2:
        return logm(A, **kwargs)
    if len(A.shape) == 3:
        if A.shape[0] == 0:
            # Empty batch: nothing to compute, keep the input shape.
            return np.zeros(A.shape, dtype=A.dtype)
        # Stack the results instead of writing them into a buffer of A's
        # dtype: that would silently drop imaginary parts for real input and
        # truncate values for integer input.
        return np.stack([logm(matrix, **kwargs) for matrix in A])
    raise NotImplementedError("Unsupported shape of input matrix")
class PauliCircuit:
    """
    Wrapper for Pauli-Clifford Circuits described by Nemkov et al.
    (https://doi.org/10.1103/PhysRevA.108.032406). The code is inspired
    by the corresponding implementation: https://github.com/idnm/FourierVQA.

    A Pauli Circuit only consists of parameterised Pauli-rotations and Clifford
    gates, which is the default for the most common VQCs.
    """

    @staticmethod
    def from_parameterised_circuit(
        tape: QuantumScript,
    ) -> tuple[QuantumScriptBatch, PostprocessingFn]:
        """
        Transformation function (see also qml.transforms) to convert an ansatz
        into a Pauli-Clifford circuit.

        All gates are decomposed into Clifford gates and Pauli rotations, the
        Clifford gates are commuted to the end of the circuit and then absorbed
        into the observables. Every measurement of the resulting tape is an
        expectation value of such an evolved observable.

        **Usage** (without using Model, Model provides a boolean argument
        "as_pauli_circuit" that internally uses the Pauli-Clifford):
        ```
        # initialise some QNode
        circuit = qml.QNode(
            circuit_fkt,  # function for your circuit definition
            qml.device("default.qubit", wires=5),
        )
        pauli_circuit = qml.transform(PauliCircuit.from_parameterised_circuit)(
            circuit
        )

        # Call exactly the same as circuit
        some_input = [0.1, 0.2]

        circuit(some_input)
        pauli_circuit(some_input)

        # Both results should be equal!
        ```

        Args:
            tape (QuantumScript): The quantum tape for the operations in the
                ansatz. This is automatically passed, when initialising the
                transform function with a QNode. Note: directly calling
                `PauliCircuit.from_parameterised_circuit(circuit)` for a QNode
                circuit will fail, see usage above.

        Returns:
            tuple[QuantumScriptBatch, PostprocessingFn]:
                - A batch with one new quantum tape, containing the operations
                  of the Pauli-Clifford Circuit.
                - A postprocessing function that returns the result of that
                  single tape.
        """
        operations = PauliCircuit.get_clifford_pauli_gates(tape)

        pauli_gates, final_cliffords = PauliCircuit.commute_all_cliffords_to_the_end(
            operations
        )

        observables = PauliCircuit.cliffords_in_observable(
            final_cliffords, tape.observables
        )

        # Keep the shot configuration of the original tape, otherwise a
        # finite-shot execution would silently become an analytic one.
        with QuantumTape(shots=tape.shots) as tape_new:
            for op in pauli_gates:
                op.queue()
            for obs in observables:
                qml.expval(obs)

        def postprocess(res):
            # Only a single tape is produced; unwrap its result.
            return res[0]

        return [tape_new], postprocess

    @staticmethod
    def commute_all_cliffords_to_the_end(
        operations: List[Operator],
    ) -> Tuple[List[Operator], List[Operator]]:
        """
        This function moves all clifford gates to the end of the circuit,
        accounting for commutation rules. The list ``operations`` is modified
        in place.

        Args:
            operations (List[Operator]): The operations in the tape of the
                circuit

        Returns:
            Tuple[List[Operator], List[Operator]]:
                - List of the resulting Pauli-rotations
                - List of the resulting Clifford gates
        """
        # Walk from right to left; every Clifford is pushed to the right past
        # all Pauli rotations following it. Afterwards, the suffix starting at
        # each processed index has the form [rotations..., cliffords...].
        for i in range(len(operations) - 2, -1, -1):
            j = i
            while (
                j + 1 < len(operations)  # Clifford has not already reached the end
                and PauliCircuit._is_clifford(operations[j])
                and PauliCircuit._is_pauli_rotation(operations[j + 1])
            ):
                pauli, clifford = PauliCircuit._evolve_clifford_rotation(
                    operations[j], operations[j + 1]
                )
                operations[j] = pauli
                operations[j + 1] = clifford
                j += 1

        # Split at the start of the trailing run of Clifford gates. This is
        # computed after the commuting pass so that Cliffords which never had
        # to move (e.g. already at the end) are still counted as Cliffords.
        first_clifford = len(operations)
        while first_clifford > 0 and PauliCircuit._is_clifford(
            operations[first_clifford - 1]
        ):
            first_clifford -= 1

        return operations[:first_clifford], operations[first_clifford:]

    @staticmethod
    def get_clifford_pauli_gates(tape: QuantumScript) -> List[Operator]:
        """
        This function decomposes all gates in the circuit to clifford and
        pauli-rotation gates. Skippable operations (barriers) are dropped.

        Args:
            tape (QuantumScript): The tape of the circuit containing all
                operations.

        Returns:
            List[Operator]: A list of operations consisting only of clifford
                and Pauli-rotation gates.
        """
        operations = []
        for operation in tape.operations:
            if PauliCircuit._is_clifford(operation) or PauliCircuit._is_pauli_rotation(
                operation
            ):
                operations.append(operation)
            elif PauliCircuit._is_skippable(operation):
                continue
            else:
                # TODO: Maybe there is a prettier way to decompose a gate
                # We currently can not handle parametrised input gates, that
                # are not plain pauli rotations
                op_tape = QuantumScript([operation])
                decomposed_tape = qml.transforms.decompose(
                    op_tape, gate_set=PAULI_ROTATION_GATES + CLIFFORD_GATES
                )
                decomposed_ops = decomposed_tape[0][0].operations
                # Rebuild the rotations with their parameters wrapped in a
                # pennylane-numpy tensor; Cliffords are kept as they are.
                operations.extend(
                    (
                        op
                        if PauliCircuit._is_clifford(op)
                        else op.__class__(pnp.tensor(op.parameters), op.wires)
                    )
                    for op in decomposed_ops
                )

        return operations

    @staticmethod
    def _is_skippable(operation: Operator) -> bool:
        """
        Determines whether an operator can be ignored when building the Pauli
        Clifford circuit. Currently this only contains barriers.

        Args:
            operation (Operator): Gate operation

        Returns:
            bool: Whether the operation can be skipped.
        """
        return isinstance(operation, SKIPPABLE_OPERATIONS)

    @staticmethod
    def _is_clifford(operation: Operator) -> bool:
        """
        Determines whether an operator is a Clifford gate.

        Args:
            operation (Operator): Gate operation

        Returns:
            bool: Whether the operation is Clifford.
        """
        return isinstance(operation, CLIFFORD_GATES)

    @staticmethod
    def _is_pauli_rotation(operation: Operator) -> bool:
        """
        Determines whether an operator is a Pauli rotation gate.

        Args:
            operation (Operator): Gate operation

        Returns:
            bool: Whether the operation is a Pauli operation.
        """
        return isinstance(operation, PAULI_ROTATION_GATES)

    @staticmethod
    def _evolve_clifford_rotation(
        clifford: Operator, pauli: Operator
    ) -> Tuple[Operator, Operator]:
        """
        This function computes the resulting operations, when switching a
        Clifford gate and a Pauli rotation in the circuit.

        Since R(P) C = C R(C^dagger P C), moving the Clifford C to the right of
        a rotation with generator P turns the rotation generator into
        C^dagger P C.

        **Example**:
        Consider a circuit consisting of the gate sequence
        ... --- H --- R_z --- ...
        This function computes the evolved Pauli Rotation, and moves the
        clifford (Hadamard) gate to the end:
        ... --- R_x --- H --- ...

        Args:
            clifford (Operator): Clifford gate to move.
            pauli (Operator): Pauli rotation gate to move the clifford past.

        Returns:
            Tuple[Operator, Operator]:
                - Evolved Pauli rotation operator
                - Resulting Clifford operator (the same as the input)
        """
        # Gates on disjoint wires commute trivially.
        if not any(p_c in clifford.wires for p_c in pauli.wires):
            return pauli, clifford

        gen = pauli.generator()
        param = pauli.parameters[0]
        requires_grad = param.requires_grad if isinstance(param, pnp.tensor) else False
        param = pnp.tensor(param)

        # C^dagger G C (see the commutation identity in the docstring).
        evolved_gen, _ = PauliCircuit._evolve_clifford_pauli(
            clifford, gen, adjoint_left=True
        )
        qubits = evolved_gen.wires
        evolved_gen = qml.pauli_decompose(evolved_gen.matrix())
        pauli_str, param_factor = PauliCircuit._get_paulistring_from_generator(
            evolved_gen
        )
        pauli_str, qubits = PauliCircuit._remove_identities_from_paulistr(
            pauli_str, qubits
        )
        pauli = qml.PauliRot(param * param_factor, pauli_str, qubits)
        pauli.parameters[0].requires_grad = requires_grad

        return pauli, clifford

    @staticmethod
    def _remove_identities_from_paulistr(
        pauli_str: str, qubits: List[int]
    ) -> Tuple[str, List[int]]:
        """
        Removes identities from Pauli string and its corresponding qubits.

        Args:
            pauli_str (str): Pauli string
            qubits (List[int]): Corresponding qubit indices.

        Returns:
            Tuple[str, List[int]]:
                - Pauli string without identities
                - Qubits indices without the identities
        """
        kept = [(p, qubits[i]) for i, p in enumerate(pauli_str) if p != "I"]
        reduced_pauli_str = "".join(p for p, _ in kept)
        reduced_qubits = [q for _, q in kept]
        return reduced_pauli_str, reduced_qubits

    @staticmethod
    def _evolve_clifford_pauli(
        clifford: Operator, pauli: Operator, adjoint_left: bool = True
    ) -> Tuple[Operator, Operator]:
        """
        This function computes the resulting operation, when evolving a Pauli
        Operation with a Clifford operation.
        For a Clifford operator C and a Pauli operator P, this function
        computes:
        P' = C^dagger P C

        Args:
            clifford (Operator): Clifford gate
            pauli (Operator): Pauli gate
            adjoint_left (bool, optional): If adjoint of the clifford gate is
                applied to the left. If this is set to True C^dagger P C is
                computed, else C P C^dagger. Defaults to True.

        Returns:
            Tuple[Operator, Operator]:
                - Evolved Pauli operator
                - Resulting Clifford operator (the same as the input)
        """
        if not any(p_c in clifford.wires for p_c in pauli.wires):
            return pauli, clifford

        if adjoint_left:
            # Previously adjoint(C) @ P @ adjoint(C), which is only correct
            # for self-inverse Cliffords (e.g. not for S).
            evolved_pauli = qml.adjoint(clifford) @ pauli @ clifford
        else:
            evolved_pauli = clifford @ pauli @ qml.adjoint(clifford)

        return evolved_pauli, clifford

    @staticmethod
    def _evolve_cliffords_list(cliffords: List[Operator], pauli: Operator) -> Operator:
        """
        This function evolves a Pauli operation according to a sequence of
        cliffords, i.e. for cliffords [C1, ..., Ck] (applied in this order)
        it computes (Ck...C1)^dagger P (Ck...C1).

        Args:
            cliffords (List[Operator]): Clifford gates in circuit order
            pauli (Operator): Pauli gate

        Returns:
            Operator: Evolved Pauli operator
        """
        for clifford in cliffords[::-1]:
            pauli, _ = PauliCircuit._evolve_clifford_pauli(clifford, pauli)
            qubits = pauli.wires
            pauli = qml.pauli_decompose(pauli.matrix(), wire_order=qubits)

        pauli = qml.simplify(pauli)

        # Remove the coefficient produced by the decomposition, but keep its
        # sign: Clifford conjugation can map P to -P' (e.g. X Z X = -Z), and
        # dropping the sign would flip the measured expectation value.
        if isinstance(pauli, (qml_op.Prod, qml_op.LinearCombination)):
            coeffs, ops = pauli.terms()
            is_negative = qml.math.real(coeffs[0]) < 0
            pauli = ops[0]
            if is_negative:
                pauli = qml.s_prod(-1, pauli)

        return pauli

    @staticmethod
    def _get_paulistring_from_generator(
        gen: qml_op.LinearCombination,
    ) -> Tuple[str, float]:
        """
        Compute a Paulistring, consisting of "X", "Y", "Z" and "I" from a
        generator.

        Args:
            gen (qml_op.LinearCombination): The generator operation created by
                Pennylane

        Returns:
            Tuple[str, float]:
                - The Paulistring
                - A factor with which to multiply a parameter to the rotation
                  gate.
        """
        factor, term = gen.terms()
        param_factor = -2 * factor  # Rotation is defined as exp(-0.5 theta G)
        pauli_term = term[0] if isinstance(term[0], qml_op.Prod) else [term[0]]
        pauli_str_list = ["I"] * len(pauli_term)
        for p in pauli_term:
            if "Pauli" in p.name:
                # NOTE(review): the wire label is used as a list index; this
                # assumes labels 0..n-1 (as from qml.pauli_decompose without
                # wire_order) and one factor per wire, identities included.
                q = p.wires[0]
                pauli_str_list[q] = p.name[-1]
        pauli_str = "".join(pauli_str_list)
        return pauli_str, param_factor

    @staticmethod
    def cliffords_in_observable(
        operations: List[Operator], original_obs: List[Operator]
    ) -> List[Operator]:
        """
        Integrates Clifford gates in the observables of the original ansatz.

        Args:
            operations (List[Operator]): Clifford gates
            original_obs (List[Operator]): Original observables from the
                circuit

        Returns:
            List[Operator]: Observables with Clifford operations
        """
        return [
            PauliCircuit._evolve_cliffords_list(operations, ob) for ob in original_obs
        ]
class QuanTikz:
    """
    Helpers to render a PennyLane circuit as a LaTeX ``quantikz`` diagram.
    """

    # Short labels for some gates; all other gates are labelled by ``op.name``.
    _GATE_LABELS = {"Hadamard": "H", "Rot": "R"}

    class TikzFigure:
        """Holds the body of a quantikz environment (rows separated by ``\\\\``)."""

        def __init__(self, quantikz_str: str):
            self.quantikz_str = quantikz_str

        def __repr__(self):
            return self.quantikz_str

        def __str__(self):
            return self.quantikz_str

        def wrap_figure(self):
            """
            Wraps the quantikz string in a LaTeX figure environment.

            Returns:
                str: A formatted LaTeX string representing the TikZ figure containing
                the quantum circuit diagram.
            """
            return f"""
\\begin{{figure}}
    \\centering
    \\begin{{tikzpicture}}
        \\node[scale=0.85] {{
            \\begin{{quantikz}}
                {self.quantikz_str}
            \\end{{quantikz}}
        }};
    \\end{{tikzpicture}}
\\end{{figure}}"""

        def export(self, destination: str, full_document=False, mode="w") -> None:
            """
            Export the quantum circuit in stick notation to a LaTeX file.

            Parameters
            ----------
            destination : str
                Path to the destination file.
            full_document : bool, optional
                If True, write a complete compilable LaTeX document containing
                the circuit wrapped in a figure; otherwise write only the
                quantikz body followed by a newline.
            mode : str, optional
                Mode passed to ``open`` (e.g. "w" to overwrite, "a" to append).
            """
            if full_document:
                latex_code = f"""
\\documentclass{{article}}
\\usepackage{{quantikz}}
\\usepackage{{tikz}}
\\usetikzlibrary{{quantikz2}}
\\usepackage{{quantikz}}
\\usepackage[a3paper, landscape, margin=0.5cm]{{geometry}}
\\begin{{document}}
{self.wrap_figure()}
\\end{{document}}"""
            else:
                latex_code = self.quantikz_str + "\n"

            # Explicit encoding so user-provided symbols are written the same
            # way regardless of the platform's default locale.
            with open(destination, mode, encoding="utf-8") as f:
                f.write(latex_code)

    @staticmethod
    def ground_state() -> str:
        """
        Generate the LaTeX representation of the |0⟩ ground state in stick notation.

        Returns
        -------
        str
            LaTeX string for the |0⟩ state.
        """
        return "\\lstick{\\ket{0}}"

    @staticmethod
    def measure(op):
        """
        Generate LaTeX for a single-wire measurement.

        Raises
        ------
        NotImplementedError
            If the measurement acts on more than one wire.
        """
        if len(op.wires) > 1:
            raise NotImplementedError("Multi-wire measurements are not supported yet")
        return "\\meter{}"

    @staticmethod
    def search_pi_fraction(w, op_name):
        """
        Render a gate with angle ``w``, as a fraction of pi where this is a
        small fraction (denominator <= 12), otherwise with two decimals.
        """
        w_pi = Fraction(w / np.pi).limit_denominator(100)
        # Not a small nice Fraction
        if w_pi.denominator > 12:
            return f"\\gate{{{op_name}({w:.2f})}}"
        # Pi
        if w_pi.denominator == 1 and w_pi.numerator == 1:
            return f"\\gate{{{op_name}(\\pi)}}"
        # 0
        if w_pi.numerator == 0:
            return f"\\gate{{{op_name}(0)}}"
        # Multiple of Pi
        if w_pi.denominator == 1:
            return f"\\gate{{{op_name}({w_pi.numerator}\\pi)}}"
        # Nice Fraction of pi
        if w_pi.numerator == 1:
            return (
                f"\\gate{{{op_name}\\left("
                f"\\frac{{\\pi}}{{{w_pi.denominator}}}\\right)}}"
            )
        # Small nice Fraction
        return (
            f"\\gate{{{op_name}\\left("
            f"\\frac{{{w_pi.numerator}\\pi}}{{{w_pi.denominator}}}"
            f"\\right)}}"
        )

    @staticmethod
    def _param_shape(op):
        """
        Shape of the first parameter of ``op``, or None without parameters.

        A shape of ``()`` is rendered as a trainable weight (theta), a shape
        of ``(1,)`` as an input symbol.
        """
        if len(op.parameters) == 0:
            return None
        return np.shape(op.parameters[0])

    @staticmethod
    def _param_value(op) -> float:
        """First parameter of ``op`` as a Python float."""
        param = op.parameters[0]
        return float(param.item() if hasattr(param, "item") else param)

    @staticmethod
    def gate(op, index=None, gate_values=False, inputs_symbols="x") -> str:
        """
        Generate LaTeX for a quantum gate in stick notation.

        Parameters
        ----------
        op : qml.Operation
            The quantum gate to represent.
        index : int, optional
            Gate index in the circuit.
        gate_values : bool, optional
            Include gate values in the representation.
        inputs_symbols : str, optional
            Symbols for the inputs in the representation.

        Returns
        -------
        str
            LaTeX string for the gate.
        """
        op_name = QuanTikz._GATE_LABELS.get(op.name, op.name)

        if gate_values and len(op.parameters) > 0:
            return QuanTikz.search_pi_fraction(QuanTikz._param_value(op), op_name)

        shape = QuanTikz._param_shape(op)
        # Gate without parameter or with a trainable parameter
        if shape is None or shape == ():
            if index is None:
                return f"\\gate{{{op_name}}}"
            return f"\\gate{{{op_name}(\\theta_{{{index}}})}}"
        # Gate with input
        if shape == (1,):
            return f"\\gate{{{op_name}({inputs_symbols})}}"
        # Unknown parameter layout: render the plain gate instead of returning
        # None (which would end up as the text "None" in the LaTeX output).
        return f"\\gate{{{op_name}}}"

    @staticmethod
    def cgate(op, index=None, gate_values=False, inputs_symbols="x") -> Tuple[str, str]:
        """
        Generate LaTeX for a controlled quantum gate in stick notation.

        Parameters
        ----------
        op : qml.Operation
            The quantum gate operation to represent.
        index : int, optional
            Gate index in the circuit.
        gate_values : bool, optional
            Include gate values in the representation.
        inputs_symbols : str, optional
            Symbols for the inputs in the representation.

        Returns
        -------
        Tuple[str, str]
            - LaTeX string for the control gate
            - LaTeX string for the target gate
        """
        if op.name in ("CRX", "CRY", "CRZ"):
            op_name = op.name[1:]
            if gate_values and len(op.parameters) > 0:
                targ = QuanTikz.search_pi_fraction(QuanTikz._param_value(op), op_name)
            else:
                shape = QuanTikz._param_shape(op)
                if shape == () and index is not None:
                    # Gate with trainable parameter
                    targ = f"\\gate{{{op_name}(\\theta_{{{index}}})}}"
                elif shape == (1,):
                    # Gate with input
                    targ = f"\\gate{{{op_name}({inputs_symbols})}}"
                else:
                    targ = f"\\gate{{{op_name}}}"
        elif op.name == "CZ":
            # CZ is symmetric: drawn with a control dot on both wires.
            targ = "\\control{}"
        elif op.name == "CY":
            targ = "\\gate{Y}"
        else:
            # CNOT / CX; other two-wire gates fall back to this symbol as well.
            targ = "\\targ{}"

        distance = op.wires[1] - op.wires[0]
        return f"\\ctrl{{{distance}}}", targ

    @staticmethod
    def barrier(op) -> str:
        """
        Generate LaTeX for a barrier in stick notation.

        Parameters
        ----------
        op : qml.Operation
            The barrier operation to represent.

        Returns
        -------
        str
            LaTeX string for the barrier.
        """
        # NOTE(review): this is a plain (non f-) string, so the doubled braces
        # are emitted literally into the LaTeX output - confirm this is intended.
        return (
            "\\slice[style={{draw=black, solid, double distance=2pt, "
            "line width=0.5pt}}]{{}}"
        )

    @staticmethod
    def _build_tikz_circuit(quantum_tape, gate_values=False, inputs_symbols="x"):
        """
        Builds a LaTeX representation of a quantum circuit in TikZ format.

        This static method constructs a TikZ circuit diagram from a given quantum
        tape. It processes the operations in the tape, including gates, controlled
        gates, barriers, and measurements. The resulting structure is a list of
        LaTeX strings, each representing a wire in the circuit.

        Trainable gates (scalar parameter) are numbered theta_0, theta_1, ...
        in circuit order; input gates (parameter of shape (1,)) take the next
        symbol from ``inputs_symbols``. Gates without parameters consume
        neither.

        Parameters
        ----------
        quantum_tape : QuantumTape
            The quantum tape containing the operations of the circuit.
        gate_values : bool, optional
            If True, include gate parameter values in the representation.
        inputs_symbols : str or iterator of str, optional
            Symbol, or iterator over symbols, representing the inputs.

        Returns
        -------
        circuit_tikz : list of list of str
            A nested list where each inner list contains LaTeX strings representing
            the operations on a single wire of the circuit.
        """
        circuit_tikz = [
            [QuanTikz.ground_state()] for _ in range(quantum_tape.num_wires)
        ]

        # A single string is used for every input gate.
        if isinstance(inputs_symbols, str):
            symbols = cycle([inputs_symbols])
        else:
            symbols = iter(inputs_symbols)

        theta_counter = 0
        # Column of the measurements, fixed by the first measurement so all
        # meters line up independent of which wire is measured first.
        measurement_col = None

        def pad(wires, length):
            # Extend the given wires with empty cells up to ``length``.
            for w in wires:
                circuit_tikz[w].extend("" for _ in range(length - len(circuit_tikz[w])))

        def label_kwargs(op):
            # Only parametrised gates consume a theta index / input symbol.
            nonlocal theta_counter
            kwargs = {"gate_values": gate_values}
            shape = QuanTikz._param_shape(op)
            if shape == ():
                kwargs["index"] = theta_counter
                theta_counter += 1
            elif shape == (1,):
                kwargs["inputs_symbols"] = next(symbols)
            return kwargs

        for op in quantum_tape.circuit:
            # catch measurement operations
            if op._queue_category == "_measurements":
                if measurement_col is None:
                    measurement_col = max(len(wire) for wire in circuit_tikz)
                pad([op.wires[0]], measurement_col)
                circuit_tikz[op.wires[0]].append(QuanTikz.measure(op))
            # process all gates
            elif op._queue_category == "_ops":
                if op.name == "Barrier":
                    # align all wires, then attach the slice to the last cell
                    pad(
                        range(len(circuit_tikz)),
                        max(len(wire) for wire in circuit_tikz),
                    )
                    circuit_tikz[op.wires[0]][-1] += QuanTikz.barrier(op)
                elif len(op.wires) == 1:
                    circuit_tikz[op.wires[0]].append(
                        QuanTikz.gate(op, **label_kwargs(op))
                    )
                elif len(op.wires) == 2:
                    if op.name in ("CRX", "CRY", "CRZ"):
                        ctrl, targ = QuanTikz.cgate(op, **label_kwargs(op))
                    else:
                        ctrl, targ = QuanTikz.cgate(op)

                    # wires that this controlled gate spans over
                    crossing_wires = range(min(op.wires), max(op.wires) + 1)
                    pad(
                        crossing_wires,
                        max(len(circuit_tikz[w]) for w in crossing_wires),
                    )

                    circuit_tikz[op.wires[0]].append(ctrl)
                    circuit_tikz[op.wires[1]].append(targ)

                    # keep the wires in between aligned with the gate column
                    for w in crossing_wires:
                        if w not in op.wires:
                            circuit_tikz[w].append("")
                else:
                    raise NotImplementedError(">2-wire gates are not supported yet")

        return circuit_tikz

    @staticmethod
    def build(
        circuit: qml.QNode,
        params,
        inputs,
        enc_params=None,
        gate_values=False,
        inputs_symbols="x",
    ) -> QuanTikz.TikzFigure:
        """
        Generate LaTeX for a quantum circuit in stick notation.

        Parameters
        ----------
        circuit : qml.QNode
            The quantum circuit to represent.
        params : array
            Weight parameters for the circuit.
        inputs : array
            Inputs for the circuit.
        enc_params : array
            Encoding weight parameters for the circuit.
        gate_values : bool, optional
            Toggle for gate values or theta variables in the representation.
        inputs_symbols : str or list of str, optional
            Symbols for the inputs in the representation. A single string is
            suffixed with ``_i`` when there is more than one input; a list
            must contain one symbol per input.

        Returns
        -------
        QuanTikz.TikzFigure
            Figure holding the LaTeX string for the circuit.
        """
        tape_kwargs = {"params": params, "inputs": inputs}
        if enc_params is not None:
            tape_kwargs["enc_params"] = enc_params
        quantum_tape = qml.workflow.construct_tape(circuit)(**tape_kwargs)

        if isinstance(inputs_symbols, str) and inputs.size > 1:
            inputs_symbols = cycle(
                [f"{inputs_symbols}_{i}" for i in range(inputs.size)]
            )
        elif isinstance(inputs_symbols, list):
            assert len(inputs_symbols) == inputs.size, (
                f"The number of input symbols {len(inputs_symbols)} "
                f"must match the number of inputs {inputs.size}."
            )
            inputs_symbols = cycle(inputs_symbols)
        else:
            inputs_symbols = cycle([inputs_symbols])

        circuit_tikz = QuanTikz._build_tikz_circuit(
            quantum_tape, gate_values=gate_values, inputs_symbols=inputs_symbols
        )

        # extend all wires to the same length
        max_len = max(len(wire) for wire in circuit_tikz)
        for wire in circuit_tikz:
            wire.extend("" for _ in range(max_len - len(wire)))

        # cells are separated by " & ", wires by " \\" + newline
        quantikz_str = " \\\\\n".join(
            " & ".join(f"{cell}" for cell in wire) for wire in circuit_tikz
        )

        return QuanTikz.TikzFigure(quantikz_str)