Coverage for qml_essentials/utils.py: 94%
320 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-10-02 13:10 +0000
« prev ^ index » next coverage.py v7.9.2, created at 2025-10-02 13:10 +0000
1from __future__ import annotations
2from typing import List, Tuple
3import numpy as np
4import pennylane as qml
5from pennylane.operation import Operator
6from pennylane.tape import QuantumScript, QuantumScriptBatch, QuantumTape
7from pennylane.typing import PostprocessingFn
8import pennylane.numpy as pnp
9import pennylane.ops.op_math as qml_op
10from pennylane.drawer import drawable_layers, tape_text
11from fractions import Fraction
12from itertools import cycle
13from scipy.linalg import logm
14import dill
15import multiprocessing
16import os
# Gate classes that PauliCircuit treats as Clifford gates: they are commuted to
# the end of the circuit and finally absorbed into the measured observables.
CLIFFORD_GATES = (
    # single-qubit Pauli gates
    qml.PauliX, qml.PauliY, qml.PauliZ,
    qml.X, qml.Y, qml.Z,
    # Hadamard, phase gate and CNOT
    qml.Hadamard, qml.S, qml.CNOT,
)

# Parametrised rotation gates that are kept (and rewritten) when Clifford
# gates are commuted past them.
PAULI_ROTATION_GATES = (qml.RX, qml.RY, qml.RZ, qml.PauliRot)

# Operations that are dropped when building a Pauli-Clifford circuit.
SKIPPABLE_OPERATIONS = (qml.Barrier,)
class MultiprocessingPool:
    """
    Run ``target`` in ``n_processes`` separate processes, a limited number of
    them at a time, and collect their results in a shared dictionary.

    Job ``it`` (``0 <= it < n_processes``) is started as
    ``target(it, return_dict, *args, **kwargs)``, where ``return_dict`` is a
    ``multiprocessing.Manager().dict()`` shared by all jobs into which the
    jobs can write their results.

    At most ``max(int(n_cpus * cpu_scaler), 1)`` processes run concurrently,
    where ``n_cpus`` is the number of CPUs usable by the current process.
    """

    class DillProcess(multiprocessing.Process):
        """
        ``multiprocessing.Process`` whose target is serialised with ``dill``.

        The target is stored as dill bytes in ``__init__`` and restored in
        ``run`` (presumably so that targets ``pickle`` cannot handle, such as
        lambdas or closures, can still be used — confirm).
        """

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            # Keep a missing target as None: dill.dumps(None) yields a truthy
            # byte string, and run() would then try to call None.
            if self._target is not None:
                self._target = dill.dumps(self._target)

        def run(self):
            if self._target:
                # Unpickle the target function before executing it
                self._target = dill.loads(self._target)
                return self._target(*self._args, **self._kwargs)

    def __init__(self, target, n_processes, cpu_scaler, *args, **kwargs):
        """
        Args:
            target: Callable executed in every job as
                ``target(it, return_dict, *args, **kwargs)``.
            n_processes (int): Total number of jobs to run.
            cpu_scaler (float): Fraction in ``[0, 1]`` of the usable CPUs that
                may be busy at the same time (at least one process always runs).
            *args: Extra positional arguments forwarded to every job.
            **kwargs: Extra keyword arguments forwarded to every job.

        Raises:
            AssertionError: If ``cpu_scaler`` is outside ``[0, 1]``.
        """
        self.target = target
        self.n_processes = n_processes
        self.cpu_scaler = cpu_scaler
        self.args = args
        self.kwargs = kwargs

        assert (
            self.cpu_scaler <= 1 and self.cpu_scaler >= 0
        ), f"cpu_scaler must in [0..1], got {self.cpu_scaler}"

    def _max_concurrent(self):
        """Number of processes allowed to run at the same time (>= 1)."""
        # Portable CPU detection: sched_getaffinity is not available everywhere
        try:
            n_cpus = len(os.sched_getaffinity(0))
        except AttributeError:
            n_cpus = os.cpu_count() or 1
        return max(int(n_cpus * self.cpu_scaler), 1)

    @staticmethod
    def _batches(n_jobs, batch_size):
        """Yield consecutive ranges of job indices with at most ``batch_size`` entries."""
        for start in range(0, n_jobs, batch_size):
            yield range(start, min(start + batch_size, n_jobs))

    def spawn(self):
        """
        Run all jobs and wait for them to finish.

        Jobs are started in batches of ``_max_concurrent()`` processes; each
        batch is joined before the next one starts, so no more than that many
        processes are alive at once.

        Returns:
            The manager-backed dict that was passed to every job.
        """
        manager = multiprocessing.Manager()
        return_dict = manager.dict()

        n_procs = self._max_concurrent()
        for batch in self._batches(self.n_processes, n_procs):
            jobs = [
                self.DillProcess(
                    target=self.target,
                    args=[it, return_dict, *self.args],
                    kwargs=self.kwargs,
                )
                for it in batch
            ]
            for job in jobs:
                job.start()
            # wait for the whole batch before starting the next one
            for job in jobs:
                job.join()

        return return_dict
112def logm_v(A, **kwargs):
113 # TODO: check warnings
114 if len(A.shape) == 2:
115 return logm(A, **kwargs)
116 elif len(A.shape) == 3:
117 AV = np.zeros(A.shape, dtype=A.dtype)
118 for i in range(A.shape[0]):
119 AV[i] = logm(A[i], **kwargs)
120 return AV
121 else:
122 raise NotImplementedError("Unsupported shape of input matrix")
class PauliCircuit:
    """
    Wrapper for Pauli-Clifford Circuits described by Nemkov et al.
    (https://doi.org/10.1103/PhysRevA.108.032406). The code is inspired
    by the corresponding implementation: https://github.com/idnm/FourierVQA.

    A Pauli Circuit only consists of parameterised Pauli-rotations and Clifford
    gates, which is the default for the most common VQCs.
    """

    @staticmethod
    def from_parameterised_circuit(
        tape: QuantumScript,
    ) -> tuple[QuantumScriptBatch, PostprocessingFn]:
        """
        Transformation function (see also qml.transforms) to convert an ansatz
        into a Pauli-Clifford circuit.

        All gates are decomposed into Clifford gates and Pauli rotations, the
        Clifford gates are commuted to the end of the circuit, and they are
        then absorbed into the observables. Every observable of the original
        tape is measured with ``qml.expval`` in the new tape.

        **Usage** (without using Model, Model provides a boolean argument
        "as_pauli_circuit" that internally uses the Pauli-Clifford):
        ```
        # initialise some QNode
        circuit = qml.QNode(
            circuit_fkt,  # function for your circuit definition
            qml.device("default.qubit", wires=5),
        )
        pauli_circuit = qml.transform(PauliCircuit.from_parameterised_circuit)(
            circuit
        )

        # Call exactly the same as circuit
        some_input = [0.1, 0.2]

        circuit(some_input)
        pauli_circuit(some_input)

        # Both results should be equal!
        ```

        Args:
            tape (QuantumScript): The quantum tape for the operations in the
                ansatz. This is automatically passed, when initialising the
                transform function with a QNode. Note: directly calling
                `PauliCircuit.from_parameterised_circuit(circuit)` for a QNode
                circuit will fail (it expects a tape); wrap it as a transform
                as shown in the usage above.

        Returns:
            tuple[QuantumScriptBatch, PostprocessingFn]:
                - A batch with one new quantum tape, containing the operations
                  of the Pauli-Clifford Circuit.
                - A postprocessing function returning the result of that
                  single tape.
        """
        operations = PauliCircuit.get_clifford_pauli_gates(tape)

        pauli_gates, final_cliffords = PauliCircuit.commute_all_cliffords_to_the_end(
            operations
        )

        observables = PauliCircuit.cliffords_in_observable(
            final_cliffords, tape.observables
        )

        with QuantumTape() as tape_new:
            for op in pauli_gates:
                op.queue()
            for obs in observables:
                qml.expval(obs)

        def postprocess(res):
            # unwrap the result of the single tape in the batch
            return res[0]

        return [tape_new], postprocess

    @staticmethod
    def commute_all_cliffords_to_the_end(
        operations: List[Operator],
    ) -> Tuple[List[Operator], List[Operator]]:
        """
        This function moves all clifford gates to the end of the circuit,
        accounting for commutation rules.

        Every Clifford gate is moved to the right past all Pauli rotations that
        follow it; each rotation it passes is replaced by the evolved rotation
        (see ``_evolve_clifford_rotation``). ``operations`` is modified in place.

        Args:
            operations (List[Operator]): The operations in the tape of the
                circuit

        Returns:
            Tuple[List[Operator], List[Operator]]:
                - List of the resulting Pauli-rotations
                - List of the resulting Clifford gates (the trailing block of
                  Clifford gates after reordering)
        """
        # Walk from right to left. When every gate is either a Clifford gate
        # or a Pauli rotation, the gates right of position i are already
        # ordered as "rotations, then Cliffords" when position i is processed.
        for i in range(len(operations) - 2, -1, -1):
            j = i
            while (
                j + 1 < len(operations)  # Clifford has not already reached the end
                and PauliCircuit._is_clifford(operations[j])
                and PauliCircuit._is_pauli_rotation(operations[j + 1])
            ):
                pauli, clifford = PauliCircuit._evolve_clifford_rotation(
                    operations[j], operations[j + 1]
                )
                operations[j] = pauli
                operations[j + 1] = clifford
                j += 1

        # Find where the trailing block of Clifford gates starts. Scanning the
        # suffix (instead of remembering the last bubbling position) is also
        # correct for circuits starting with a rotation, for Cliffords that
        # never had to move, and for an empty operation list.
        first_clifford = len(operations)
        while first_clifford > 0 and PauliCircuit._is_clifford(
            operations[first_clifford - 1]
        ):
            first_clifford -= 1

        return operations[:first_clifford], operations[first_clifford:]

    @staticmethod
    def get_clifford_pauli_gates(tape: QuantumScript) -> List[Operator]:
        """
        This function decomposes all gates in the circuit to clifford and
        pauli-rotation gates.

        Clifford gates and Pauli rotations are kept, skippable operations
        (``SKIPPABLE_OPERATIONS``) are dropped and every other gate is
        decomposed with ``qml.transforms.decompose`` into the gate set
        ``PAULI_ROTATION_GATES + CLIFFORD_GATES``.

        Args:
            tape (QuantumScript): The tape of the circuit containing all
                operations.

        Returns:
            List[Operator]: A list of operations consisting only of clifford
                and Pauli-rotation gates.
        """
        operations = []
        for operation in tape.operations:
            if PauliCircuit._is_clifford(operation) or PauliCircuit._is_pauli_rotation(
                operation
            ):
                operations.append(operation)
            elif PauliCircuit._is_skippable(operation):
                continue
            else:
                # TODO: Maybe there is a prettier way to decompose a gate
                # We currently can not handle parametrised input gates, that
                # are not plain pauli rotations
                single_op_tape = QuantumScript([operation])
                decomposed_tape = qml.transforms.decompose(
                    single_op_tape, gate_set=PAULI_ROTATION_GATES + CLIFFORD_GATES
                )
                decomposed_ops = decomposed_tape[0][0].operations
                # Rebuild the parametrised gates with their parameter list
                # wrapped in a pnp.tensor; Clifford gates are kept unchanged.
                # NOTE(review): this assumes the ``(params, wires)`` constructor
                # of RX/RY/RZ; a decomposed PauliRot would also need its Pauli
                # word. The parameter presumably ends up with shape (1,).
                decomposed_ops = [
                    (
                        op
                        if PauliCircuit._is_clifford(op)
                        else op.__class__(pnp.tensor(op.parameters), op.wires)
                    )
                    for op in decomposed_ops
                ]
                operations.extend(decomposed_ops)

        return operations

    @staticmethod
    def _is_skippable(operation: Operator) -> bool:
        """
        Determines whether an operator can be ignored when building the Pauli
        Clifford circuit. Currently this only contains barriers.

        Args:
            operation (Operator): Gate operation

        Returns:
            bool: Whether the operation can be skipped.
        """
        return isinstance(operation, SKIPPABLE_OPERATIONS)

    @staticmethod
    def _is_clifford(operation: Operator) -> bool:
        """
        Determines whether an operator is a Clifford gate.

        Args:
            operation (Operator): Gate operation

        Returns:
            bool: Whether the operation is Clifford.
        """
        return isinstance(operation, CLIFFORD_GATES)

    @staticmethod
    def _is_pauli_rotation(operation: Operator) -> bool:
        """
        Determines whether an operator is a Pauli rotation gate.

        Args:
            operation (Operator): Gate operation

        Returns:
            bool: Whether the operation is a Pauli rotation.
        """
        return isinstance(operation, PAULI_ROTATION_GATES)

    @staticmethod
    def _evolve_clifford_rotation(
        clifford: Operator, pauli: Operator
    ) -> Tuple[Operator, Operator]:
        """
        This function computes the resulting operations, when switching a
        Clifford gate and a Pauli rotation in the circuit.

        If C is applied before the rotation R(G), then R(G) C = C R(C^† G C).
        The rotation applied first after the swap therefore has the generator
        C^† G C.

        **Example**:
        Consider a circuit consisting of the gate sequence
        ... --- H --- R_z --- ...
        This function computes the evolved Pauli Rotation, and moves the
        clifford (Hadamard) gate to the end:
        ... --- R_x --- H --- ...

        Args:
            clifford (Operator): Clifford gate to move.
            pauli (Operator): Pauli rotation gate to move the clifford past.

        Returns:
            Tuple[Operator, Operator]:
                - Evolved Pauli rotation operator
                - Clifford operator (the unchanged input)
        """
        # Gates on disjoint wires commute: just swap them.
        if not any(p_c in clifford.wires for p_c in pauli.wires):
            return pauli, clifford

        gen = pauli.generator()
        param = pauli.parameters[0]
        requires_grad = param.requires_grad if isinstance(param, pnp.tensor) else False
        param = pnp.tensor(param)

        # R(G) C = C R(C^† G C): conjugate the generator with C^† on the left.
        evolved_gen, _ = PauliCircuit._evolve_clifford_pauli(
            clifford, gen, adjoint_left=True
        )
        qubits = evolved_gen.wires
        evolved_gen = qml.pauli_decompose(evolved_gen.matrix())
        pauli_str, param_factor = PauliCircuit._get_paulistring_from_generator(
            evolved_gen
        )
        pauli_str, qubits = PauliCircuit._remove_identities_from_paulistr(
            pauli_str, qubits
        )
        pauli = qml.PauliRot(param * param_factor, pauli_str, qubits)
        # keep the trainability flag of the original parameter
        pauli.parameters[0].requires_grad = requires_grad

        return pauli, clifford

    @staticmethod
    def _remove_identities_from_paulistr(
        pauli_str: str, qubits: List[int]
    ) -> Tuple[str, List[int]]:
        """
        Removes identities from Pauli string and its corresponding qubits.

        Args:
            pauli_str (str): Pauli string
            qubits (List[int]): Corresponding qubit indices.

        Returns:
            Tuple[str, List[int]]:
                - Pauli string without identities
                - Qubits indices without the identities
        """
        reduced_qubits = []
        reduced_pauli_str = ""
        for i, p in enumerate(pauli_str):
            if p != "I":
                reduced_pauli_str += p
                reduced_qubits.append(qubits[i])

        return reduced_pauli_str, reduced_qubits

    @staticmethod
    def _evolve_clifford_pauli(
        clifford: Operator, pauli: Operator, adjoint_left: bool = True
    ) -> Tuple[Operator, Operator]:
        """
        This function computes the resulting operation, when evolving a Pauli
        Operation with a Clifford operation.
        For a Clifford operator C and a Pauli operator P, this function computes:
        P' = C^† P C

        Args:
            clifford (Operator): Clifford gate
            pauli (Operator): Pauli gate
            adjoint_left (bool, optional): If adjoint of the clifford gate is
                applied to the left. If this is set to True C^† P C is
                computed, else C P C^†. Defaults to True.

        Returns:
            Tuple[Operator, Operator]:
                - Evolved Pauli operator
                - Resulting Clifford operator (should be the same as the input)
        """
        if not any(p_c in clifford.wires for p_c in pauli.wires):
            return pauli, clifford

        # ``A @ B`` is the operator product A·B (B acts first).
        if adjoint_left:
            evolved_pauli = qml.adjoint(clifford) @ pauli @ clifford
        else:
            evolved_pauli = clifford @ pauli @ qml.adjoint(clifford)

        return evolved_pauli, clifford

    @staticmethod
    def _evolve_cliffords_list(cliffords: List[Operator], pauli: Operator) -> Operator:
        """
        This function evolves a Pauli operation according to a sequence of cliffords.

        The Clifford gates are given in circuit order and are applied from the
        last to the first, i.e. the result is U^† P U for U = C_k ... C_1.

        Args:
            cliffords (List[Operator]): Clifford gates in circuit order
            pauli (Operator): Pauli operator (observable) to evolve

        Returns:
            Operator: Evolved Pauli operator
        """
        for clifford in cliffords[::-1]:
            pauli, _ = PauliCircuit._evolve_clifford_pauli(clifford, pauli)
            qubits = pauli.wires
            pauli = qml.pauli_decompose(pauli.matrix(), wire_order=qubits)

        pauli = qml.simplify(pauli)

        # remove coefficients
        # NOTE(review): this also drops a -1 sign of the evolved observable
        # (e.g. Y Z Y = -Z), which flips the sign of its expectation value —
        # confirm whether callers rely on a coefficient-free Pauli word.
        pauli = (
            pauli.terms()[1][0]
            if isinstance(pauli, (qml_op.Prod, qml_op.LinearCombination))
            else pauli
        )

        return pauli

    @staticmethod
    def _get_paulistring_from_generator(
        gen: qml_op.LinearCombination,
    ) -> Tuple[str, float]:
        """
        Compute a Paulistring, consisting of "X", "Y", "Z" and "I" from a
        generator.

        Args:
            gen (qml_op.LinearCombination): The generator operation created by
                Pennylane

        Returns:
            Tuple[str, float]:
                - The Paulistring
                - A factor with which to multiply a parameter to the rotation
                    gate.
        """
        factor, term = gen.terms()
        param_factor = -2 * factor  # Rotation is defined as exp(-0.5 theta G)
        # a single-qubit term is not a product: wrap it to iterate uniformly
        pauli_term = term[0] if isinstance(term[0], qml_op.Prod) else [term[0]]
        pauli_str_list = ["I"] * len(pauli_term)
        for p in pauli_term:
            if "Pauli" in p.name:
                q = p.wires[0]
                pauli_str_list[q] = p.name[-1]
        pauli_str = "".join(pauli_str_list)
        return pauli_str, param_factor

    @staticmethod
    def cliffords_in_observable(
        operations: List[Operator], original_obs: List[Operator]
    ) -> List[Operator]:
        """
        Integrates Clifford gates in the observables of the original ansatz.

        Args:
            operations (List[Operator]): Clifford gates
            original_obs (List[Operator]): Original observables from the
                circuit

        Returns:
            List[Operator]: Observables with Clifford operations
        """
        return [
            PauliCircuit._evolve_cliffords_list(operations, ob) for ob in original_obs
        ]
class QuanTikz:
    """
    Render PennyLane circuits as ``quantikz`` LaTeX diagrams.

    ``build`` is the entry point; the remaining static methods produce the
    LaTeX snippet for a single cell (initial state, gate, controlled gate,
    barrier or measurement).
    """

    class TikzFigure:
        """
        Holds the body of a ``quantikz`` environment and can wrap it in a LaTeX
        figure or write it to a file.
        """

        def __init__(self, quantikz_str: str):
            self.quantikz_str = quantikz_str

        def __repr__(self):
            return self.quantikz_str

        def __str__(self):
            return self.quantikz_str

        def wrap_figure(self):
            """
            Wraps the quantikz string in a LaTeX figure environment.

            Returns:
                str: A formatted LaTeX string representing the TikZ figure
                containing the quantum circuit diagram.
            """
            return f"""
\\begin{{figure}}
    \\centering
    \\begin{{tikzpicture}}
        \\node[scale=0.85] {{
            \\begin{{quantikz}}
                {self.quantikz_str}
            \\end{{quantikz}}
        }};
    \\end{{tikzpicture}}
\\end{{figure}}"""

        def export(self, destination: str, full_document=False, mode="w") -> None:
            """
            Export the quantum circuit in stick notation as LaTeX to a file.

            Parameters
            ----------
            destination : str
                Path to the destination file.
            full_document : bool, optional
                If True, write a complete LaTeX document (article class,
                quantikz/tikz packages, A3 landscape page) with the circuit
                wrapped in a figure. Otherwise only the quantikz body followed
                by a newline is written. Defaults to False.
            mode : str, optional
                Text file mode passed to ``open`` (e.g. ``"w"`` to overwrite,
                ``"a"`` to append). Defaults to ``"w"``.
            """
            if full_document:
                latex_code = f"""
\\documentclass{{article}}
\\usepackage{{quantikz}}
\\usepackage{{tikz}}
\\usetikzlibrary{{quantikz2}}
\\usepackage{{quantikz}}
\\usepackage[a3paper, landscape, margin=0.5cm]{{geometry}}
\\begin{{document}}
{self.wrap_figure()}
\\end{{document}}"""
            else:
                latex_code = self.quantikz_str + "\n"

            # explicit UTF-8 so non-ASCII input symbols work on every platform
            with open(destination, mode, encoding="utf-8") as f:
                f.write(latex_code)

    @staticmethod
    def ground_state() -> str:
        """
        Generate the LaTeX representation of the |0⟩ ground state in stick notation.

        Returns
        -------
        str
            LaTeX string for the |0⟩ state.
        """
        return "\\lstick{\\ket{0}}"

    @staticmethod
    def measure(op):
        """
        Generate the LaTeX meter symbol for a single-wire measurement.

        Raises
        ------
        NotImplementedError
            If the measurement acts on more than one wire.
        """
        if len(op.wires) > 1:
            raise NotImplementedError("Multi-wire measurements are not supported yet")
        else:
            return "\\meter{}"

    @staticmethod
    def search_pi_fraction(w, op_name):
        """
        Render a gate label with angle ``w``, written as a small fraction of pi
        when possible (denominator <= 12), otherwise with two decimals.
        """
        w_pi = Fraction(w / np.pi).limit_denominator(100)
        # Not a small nice Fraction
        if w_pi.denominator > 12:
            return f"\\gate{{{op_name}({w:.2f})}}"
        # Pi
        elif w_pi.denominator == 1 and w_pi.numerator == 1:
            return f"\\gate{{{op_name}(\\pi)}}"
        # 0
        elif w_pi.numerator == 0:
            return f"\\gate{{{op_name}(0)}}"
        # Multiple of Pi
        elif w_pi.denominator == 1:
            return f"\\gate{{{op_name}({w_pi.numerator}\\pi)}}"
        # Nice Fraction of pi
        elif w_pi.numerator == 1:
            return (
                f"\\gate{{{op_name}\\left("
                f"\\frac{{\\pi}}{{{w_pi.denominator}}}\\right)}}"
            )
        # Small nice Fraction
        else:
            return (
                f"\\gate{{{op_name}\\left("
                f"\\frac{{{w_pi.numerator}\\pi}}{{{w_pi.denominator}}}"
                f"\\right)}}"
            )

    @staticmethod
    def _first_param_shape(op):
        """Shape of the first parameter of ``op``, or None if it has none."""
        if len(op.parameters) == 0:
            return None
        return np.shape(op.parameters[0])

    @staticmethod
    def _first_param_value(op) -> float:
        """First parameter of ``op`` as a Python float (tensor or plain number)."""
        param = op.parameters[0]
        return float(param.item() if hasattr(param, "item") else param)

    @staticmethod
    def gate(op, index=None, gate_values=False, inputs_symbols="x") -> str:
        """
        Generate LaTeX for a quantum gate in stick notation.

        A gate with a scalar parameter is labelled with ``theta_index`` (or
        just its name if ``index`` is None), a gate with a shape-(1,)
        parameter is labelled with ``inputs_symbols``, and every other gate
        (including parameter-free gates) is labelled with its name only.

        Parameters
        ----------
        op : qml.Operation
            The quantum gate to represent.
        index : int, optional
            Gate index in the circuit.
        gate_values : bool, optional
            Include gate values in the representation.
        inputs_symbols : str, optional
            Symbols for the inputs in the representation.

        Returns
        -------
        str
            LaTeX string for the gate.
        """
        op_name = {"Hadamard": "H", "Rot": "R"}.get(op.name, op.name)

        if gate_values and len(op.parameters) > 0:
            return QuanTikz.search_pi_fraction(QuanTikz._first_param_value(op), op_name)

        shape = QuanTikz._first_param_shape(op)
        # Is gate with parameter
        if shape == () and index is not None:
            return f"\\gate{{{op_name}(\\theta_{{{index}}})}}"
        # Is gate with input
        if shape == (1,):
            return f"\\gate{{{op_name}({inputs_symbols})}}"
        # Parameter-free gate, unlabelled parameter or unsupported shape
        return f"\\gate{{{op_name}}}"

    @staticmethod
    def cgate(op, index=None, gate_values=False, inputs_symbols="x") -> Tuple[str, str]:
        """
        Generate LaTeX for a controlled quantum gate in stick notation.

        The first wire of ``op`` is drawn as the control and the second as
        the target: CRX/CRY/CRZ get a labelled rotation gate, CZ a control dot,
        CY a Y gate and all other gates (e.g. CNOT) the X target symbol.

        Parameters
        ----------
        op : qml.Operation
            The quantum gate operation to represent.
        index : int, optional
            Gate index in the circuit.
        gate_values : bool, optional
            Include gate values in the representation.
        inputs_symbols : str, optional
            Symbols for the inputs in the representation.

        Returns
        -------
        Tuple[str, str]
            - LaTeX string for the control gate
            - LaTeX string for the target gate
        """
        targ = "\\targ{}"
        if op.name in ("CRX", "CRY", "CRZ"):
            op_name = op.name[1:]
            if gate_values and len(op.parameters) > 0:
                targ = QuanTikz.search_pi_fraction(
                    QuanTikz._first_param_value(op), op_name
                )
            else:
                shape = QuanTikz._first_param_shape(op)
                # Is gate with parameter
                if shape == () and index is not None:
                    targ = f"\\gate{{{op_name}(\\theta_{{{index}}})}}"
                # Is gate with input
                elif shape == (1,):
                    targ = f"\\gate{{{op_name}({inputs_symbols})}}"
                else:
                    targ = f"\\gate{{{op_name}}}"
        elif op.name == "CZ":
            targ = "\\control{}"
        elif op.name == "CY":
            targ = "\\gate{Y}"

        distance = op.wires[1] - op.wires[0]
        return f"\\ctrl{{{distance}}}", targ

    @staticmethod
    def barrier(op) -> str:
        """
        Generate LaTeX for a barrier in stick notation.

        Parameters
        ----------
        op : qml.Operation
            The barrier operation to represent.

        Returns
        -------
        str
            LaTeX string for the barrier.
        """
        # Plain string (not an f-string), so braces are written single.
        return (
            "\\slice[style={draw=black, solid, double distance=2pt, "
            "line width=0.5pt}]{}"
        )

    @staticmethod
    def _draw_labels(op, theta_index, inputs_symbols):
        """
        Take the next theta index / input symbol only if ``op`` displays one.

        Gates with a scalar parameter consume a theta index, gates with a
        shape-(1,) parameter consume an input symbol, all others consume
        nothing, so labels stay aligned with what is actually drawn.
        """
        shape = QuanTikz._first_param_shape(op)
        if shape == ():
            return next(theta_index), None
        if shape == (1,):
            return None, next(inputs_symbols)
        return None, None

    @staticmethod
    def _build_tikz_circuit(quantum_tape, gate_values=False, inputs_symbols="x"):
        """
        Builds a LaTeX representation of a quantum circuit in TikZ format.

        This static method constructs a TikZ circuit diagram from a given quantum
        tape. It processes the operations in the tape, including gates, controlled
        gates, barriers, and measurements. The resulting structure is a list of
        LaTeX strings, each representing a wire in the circuit.

        Parameters
        ----------
        quantum_tape : QuantumTape
            The quantum tape containing the operations of the circuit.
        gate_values : bool, optional
            If True, include gate parameter values in the representation.
        inputs_symbols : str or iterator of str, optional
            Symbols to represent the inputs in the circuit; an iterator yields
            the symbol of each input gate in circuit order, a single string is
            used for every input gate.

        Returns
        -------
        circuit_tikz : list of list of str
            A nested list where each inner list contains LaTeX strings representing
            the operations on a single wire of the circuit.
        """
        from itertools import count

        if isinstance(inputs_symbols, str):
            inputs_symbols = cycle([inputs_symbols])

        circuit_tikz = [
            [QuanTikz.ground_state()] for _ in range(quantum_tape.num_wires)
        ]

        def pad(wires, length):
            # extend the given wires with empty cells up to ``length`` columns
            for w in wires:
                circuit_tikz[w].extend("" for _ in range(length - len(circuit_tikz[w])))

        theta_index = count()
        # all meters are placed in the same column, right of every gate
        measure_col = None
        for op in quantum_tape.circuit:
            # catch measurement operations
            if op._queue_category == "_measurements":
                if measure_col is None:
                    measure_col = max(len(wire) for wire in circuit_tikz)
                pad([op.wires[0]], measure_col)
                circuit_tikz[op.wires[0]].append(QuanTikz.measure(op))
            # process all gates
            elif op._queue_category == "_ops":
                # catch barriers
                if op.name == "Barrier":
                    pad(range(len(circuit_tikz)), max(len(w) for w in circuit_tikz))
                    circuit_tikz[op.wires[0]][-1] += QuanTikz.barrier(op)
                # single qubit gate?
                elif len(op.wires) == 1:
                    idx, symbol = QuanTikz._draw_labels(op, theta_index, inputs_symbols)
                    circuit_tikz[op.wires[0]].append(
                        QuanTikz.gate(
                            op,
                            index=idx,
                            gate_values=gate_values,
                            inputs_symbols=symbol,
                        )
                    )
                # controlled gate?
                elif len(op.wires) == 2:
                    if op.name in ("CRX", "CRY", "CRZ"):
                        idx, symbol = QuanTikz._draw_labels(
                            op, theta_index, inputs_symbols
                        )
                        ctrl, targ = QuanTikz.cgate(
                            op,
                            index=idx,
                            gate_values=gate_values,
                            inputs_symbols=symbol,
                        )
                    else:
                        ctrl, targ = QuanTikz.cgate(op)

                    # align all wires the gate spans over
                    crossing_wires = range(min(op.wires), max(op.wires) + 1)
                    pad(crossing_wires, max(len(circuit_tikz[w]) for w in crossing_wires))

                    circuit_tikz[op.wires[0]].append(ctrl)
                    circuit_tikz[op.wires[1]].append(targ)

                    # wires crossed by the vertical line get an empty cell
                    for w in crossing_wires:
                        if w not in op.wires:
                            circuit_tikz[w].append("")
                else:
                    raise NotImplementedError(">2-wire gates are not supported yet")

        return circuit_tikz

    @staticmethod
    def build(
        circuit: qml.QNode,
        params,
        inputs,
        enc_params=None,
        gate_values=False,
        inputs_symbols="x",
    ) -> "QuanTikz.TikzFigure":
        """
        Generate LaTeX for a quantum circuit in stick notation.

        Parameters
        ----------
        circuit : qml.QNode
            The quantum circuit to represent.
        params : array
            Weight parameters for the circuit.
        inputs : array
            Inputs for the circuit.
        enc_params : array
            Encoding weight parameters for the circuit.
        gate_values : bool, optional
            Toggle for gate values or theta variables in the representation.
        inputs_symbols : str or list of str, optional
            Symbols for the inputs in the representation. A string with more
            than one input is expanded to ``{symbol}_0, {symbol}_1, ...``; a
            list must contain one symbol per input.

        Returns
        -------
        QuanTikz.TikzFigure
            The quantikz body of the circuit.
        """
        if enc_params is not None:
            quantum_tape = qml.workflow.construct_tape(circuit)(
                params=params, inputs=inputs, enc_params=enc_params
            )
        else:
            quantum_tape = qml.workflow.construct_tape(circuit)(
                params=params, inputs=inputs
            )

        n_inputs = 0 if inputs is None else np.size(inputs)
        if isinstance(inputs_symbols, str) and n_inputs > 1:
            inputs_symbols = cycle([f"{inputs_symbols}_{i}" for i in range(n_inputs)])
        elif isinstance(inputs_symbols, list):
            assert len(inputs_symbols) == n_inputs, (
                f"The number of input symbols {len(inputs_symbols)} "
                f"must match the number of inputs {n_inputs}."
            )
            inputs_symbols = cycle(inputs_symbols)
        else:
            inputs_symbols = cycle([inputs_symbols])

        circuit_tikz = QuanTikz._build_tikz_circuit(
            quantum_tape, gate_values=gate_values, inputs_symbols=inputs_symbols
        )

        # extend all wires to the same number of columns
        n_cols = max(len(wire) for wire in circuit_tikz)
        for wire in circuit_tikz:
            wire.extend("" for _ in range(n_cols - len(wire)))

        # cells are separated by " & ", wires by " \\" and a newline
        quantikz_str = " \\\\\n".join(
            " & ".join(str(cell) for cell in wire) for wire in circuit_tikz
        )

        return QuanTikz.TikzFigure(quantikz_str)