Coverage for qml_essentials/utils.py: 33%
312 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-15 15:48 +0000
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-15 15:48 +0000
1from __future__ import annotations
2from typing import List, Tuple
3import numpy as np
4import pennylane as qml
5from pennylane.operation import Operator
6from pennylane.tape import QuantumScript, QuantumScriptBatch, QuantumTape
7from pennylane.typing import PostprocessingFn
8import pennylane.numpy as pnp
9import pennylane.ops.op_math as qml_op
10from pennylane.drawer import drawable_layers, tape_text
11from fractions import Fraction
12from itertools import cycle
13from scipy.linalg import logm
14import dill
15import multiprocessing
16import os
# Operator classes that PauliCircuit._is_clifford accepts (isinstance check).
CLIFFORD_GATES = (
    qml.PauliX, qml.PauliY, qml.PauliZ,
    qml.X, qml.Y, qml.Z,
    qml.Hadamard,
    qml.S,
    qml.CNOT,
)

# Operator classes that PauliCircuit._is_pauli_rotation accepts.
PAULI_ROTATION_GATES = (qml.RX, qml.RY, qml.RZ, qml.PauliRot)

# Operator classes that are dropped when building the Pauli-Clifford circuit.
SKIPPABLE_OPERATIONS = (qml.Barrier,)
class MultiprocessingPool:
    """
    Runs a target callable in several processes, in batches, and collects
    the results in a dictionary shared between the processes.

    The target is called as ``target(it, return_dict, *args, **kwargs)``,
    where ``it`` is the index of the process (``0 .. n_processes - 1``) and
    ``return_dict`` is the shared dictionary the target may write into.
    """

    class DillProcess(multiprocessing.Process):
        """
        Process that transports its target function serialised with dill.
        """

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            # Save the target function as bytes, using dill. A missing target
            # stays None, otherwise the serialised None would be truthy and
            # run() would try to call it.
            if self._target is not None:
                self._target = dill.dumps(self._target)

        def run(self):
            if self._target:
                # Unpickle the target function before executing
                self._target = dill.loads(self._target)
                # Execute the target function
                return self._target(*self._args, **self._kwargs)

    def __init__(self, target, n_processes, *args, **kwargs):
        """
        Args:
            target (Callable): Function executed by every process.
            n_processes (int): Total number of processes to run.
            *args: Positional arguments appended after ``it`` and
                ``return_dict`` when calling the target.
            **kwargs: Keyword arguments passed to the target.
        """
        self.target = target
        self.n_processes = n_processes
        self.args = args
        self.kwargs = kwargs

    @staticmethod
    def _max_concurrent_workers() -> int:
        """
        Number of processes that may run at the same time: half of the
        CPUs available to this process, but at least one.
        """
        try:
            n_cpus = len(os.sched_getaffinity(0))
        except AttributeError:
            # os.sched_getaffinity is not available on every platform
            n_cpus = os.cpu_count() or 1
        return max(n_cpus // 2, 1)

    def spawn(self):
        """
        Starts all processes in batches and waits until every one finished.

        Returns:
            The manager dictionary that was handed to every process.
        """
        manager = multiprocessing.Manager()
        return_dict = manager.dict()

        max_workers = self._max_concurrent_workers()
        jobs = []
        batch = []
        for it in range(self.n_processes):
            process = self.DillProcess(
                target=self.target,
                args=[it, return_dict, *self.args],
                kwargs=self.kwargs,
            )

            # start the job and remember it
            process.start()
            jobs.append(process)
            batch.append(process)

            # once the limit of concurrent jobs is reached, wait for the
            # current batch to finish before starting the next one
            if len(batch) >= max_workers:
                for job in batch:
                    job.join()
                batch = []

        # wait for any remaining jobs (joining a finished job is a no-op)
        for job in jobs:
            job.join()

        return return_dict
def logm_v(A, **kwargs):
    """
    Matrix logarithm for a single matrix or a batch of matrices.

    Args:
        A: Array of shape (n, n) or (batch, n, n).
        **kwargs: Forwarded to scipy.linalg.logm.

    Returns:
        The matrix logarithm of A. For batched input the logarithm is
        computed per matrix and the results are stacked along the first
        axis. The result dtype follows the output of logm, so complex
        logarithms of real input matrices keep their imaginary part.

    Raises:
        NotImplementedError: If A is neither two- nor three-dimensional.
    """
    # TODO: check warnings
    n_dims = len(A.shape)
    if n_dims == 2:
        return logm(A, **kwargs)
    if n_dims == 3:
        if A.shape[0] == 0:
            # nothing to compute, np.stack does not accept an empty list
            return np.zeros(A.shape, dtype=A.dtype)
        return np.stack([logm(A[i], **kwargs) for i in range(A.shape[0])])
    raise NotImplementedError("Unsupported shape of input matrix")
class PauliCircuit:
    """
    Wrapper for Pauli-Clifford Circuits described by Nemkov et al.
    (https://doi.org/10.1103/PhysRevA.108.032406). The code is inspired
    by the corresponding implementation: https://github.com/idnm/FourierVQA.

    A Pauli Circuit only consists of parameterised Pauli-rotations and Clifford
    gates, which is the default for the most common VQCs.
    """

    @staticmethod
    def from_parameterised_circuit(
        tape: QuantumScript,
    ) -> tuple[QuantumScriptBatch, PostprocessingFn]:
        """
        Transformation function (see also qml.transforms) to convert an ansatz
        into a Pauli-Clifford circuit.


        **Usage** (without using Model, Model provides a boolean argument
        "as_pauli_circuit" that internally uses the Pauli-Clifford):
        ```
        # initialise some QNode
        circuit = qml.QNode(
            circuit_fkt,  # function for your circuit definition
            qml.device("default.qubit", wires=5),
        )
        pauli_circuit = PauliCircuit.from_parameterised_circuit(circuit)

        # Call exactly the same as circuit
        some_input = [0.1, 0.2]

        circuit(some_input)
        pauli_circuit(some_input)

        # Both results should be equal!
        ```

        NOTE(review): the usage example above and the note on the ``tape``
        argument below disagree on whether passing a QNode directly works -
        confirm which one is correct.

        Args:
            tape (QuantumScript): The quantum tape for the operations in the
                ansatz. This is automatically passed, when initialising the
                transform function with a QNode. Note: directly calling
                `PauliCircuit.from_parameterised_circuit(circuit)` for a QNode
                circuit will fail, see usage above.

        Returns:
            tuple[QuantumScriptBatch, PostprocessingFn]:
                - A batch with one new quantum tape, containing the
                  operations of the Pauli-Clifford Circuit.
                - A postprocessing function that returns the result of that
                  single tape (the first entry of the batch results).
        """

        operations = PauliCircuit.get_clifford_pauli_gates(tape)

        pauli_gates, final_cliffords = PauliCircuit.commute_all_cliffords_to_the_end(
            operations
        )

        observables = PauliCircuit.cliffords_in_observable(
            final_cliffords, tape.observables
        )

        with QuantumTape() as tape_new:
            for op in pauli_gates:
                op.queue()
            for obs in observables:
                qml.expval(obs)

        def postprocess(res):
            return res[0]

        return [tape_new], postprocess

    @staticmethod
    def commute_all_cliffords_to_the_end(
        operations: List[Operator],
    ) -> Tuple[List[Operator], List[Operator]]:
        """
        This function moves all clifford gates to the end of the circuit,
        accounting for commutation rules. The given list is modified in place.

        Args:
            operations (List[Operator]): The operations in the tape of the
                circuit

        Returns:
            Tuple[List[Operator], List[Operator]]:
                - List of the resulting Pauli-rotations
                - List of the resulting Clifford gates
        """
        # Walk from the back to the front and push every Clifford gate to the
        # right, past all Pauli rotations that follow it.
        for start in range(len(operations) - 2, -1, -1):
            pos = start
            while (
                pos + 1 < len(operations)  # Clifford has not already reached the end
                and PauliCircuit._is_clifford(operations[pos])
                and PauliCircuit._is_pauli_rotation(operations[pos + 1])
            ):
                pauli, clifford = PauliCircuit._evolve_clifford_rotation(
                    operations[pos], operations[pos + 1]
                )
                operations[pos] = pauli
                operations[pos + 1] = clifford
                pos += 1

        # The Clifford gates now form the trailing run of the list. The split
        # index is determined from the list itself, so it is also correct if
        # no gate had to be moved (e.g. the circuit already ends in Cliffords)
        # and if the list is empty.
        split = len(operations)
        while split > 0 and PauliCircuit._is_clifford(operations[split - 1]):
            split -= 1

        # No Clifford gates are in the circuit
        if split == len(operations):
            return operations, []

        return operations[:split], operations[split:]

    @staticmethod
    def get_clifford_pauli_gates(tape: QuantumScript) -> List[Operator]:
        """
        This function decomposes all gates in the circuit to clifford and
        pauli-rotation gates

        Args:
            tape (QuantumScript): The tape of the circuit containing all
                operations.

        Returns:
            List[Operator]: A list of operations consisting only of clifford
                and Pauli-rotation gates.
        """
        operations = []
        for operation in tape.operations:
            if PauliCircuit._is_clifford(operation) or PauliCircuit._is_pauli_rotation(
                operation
            ):
                operations.append(operation)
            elif PauliCircuit._is_skippable(operation):
                continue
            else:
                # TODO: Maybe there is a prettier way to decompose a gate
                # We currently can not handle parametrised input gates, that
                # are not plain pauli rotations
                single_op_tape = QuantumScript([operation])
                decomposed_tape = qml.transforms.decompose(
                    single_op_tape, gate_set=PAULI_ROTATION_GATES + CLIFFORD_GATES
                )
                decomposed_ops = decomposed_tape[0][0].operations
                # Re-create the non-Clifford gates with their parameters
                # wrapped in a PennyLane tensor
                decomposed_ops = [
                    (
                        op
                        if PauliCircuit._is_clifford(op)
                        else op.__class__(pnp.tensor(op.parameters), op.wires)
                    )
                    for op in decomposed_ops
                ]
                operations.extend(decomposed_ops)

        return operations

    @staticmethod
    def _is_skippable(operation: Operator) -> bool:
        """
        Determines whether an operator can be ignored when building the Pauli
        Clifford circuit. Currently this only contains barriers.

        Args:
            operation (Operator): Gate operation

        Returns:
            bool: Whether the operation can be skipped.
        """
        return isinstance(operation, SKIPPABLE_OPERATIONS)

    @staticmethod
    def _is_clifford(operation: Operator) -> bool:
        """
        Determines whether an operator is a Clifford gate.

        Args:
            operation (Operator): Gate operation

        Returns:
            bool: Whether the operation is Clifford.
        """
        return isinstance(operation, CLIFFORD_GATES)

    @staticmethod
    def _is_pauli_rotation(operation: Operator) -> bool:
        """
        Determines whether an operator is a Pauli rotation gate.

        Args:
            operation (Operator): Gate operation

        Returns:
            bool: Whether the operation is a Pauli rotation.
        """
        return isinstance(operation, PAULI_ROTATION_GATES)

    @staticmethod
    def _evolve_clifford_rotation(
        clifford: Operator, pauli: Operator
    ) -> Tuple[Operator, Operator]:
        """
        This function computes the resulting operations, when switching a
        Clifford gate and a Pauli rotation in the circuit.

        **Example**:
        Consider a circuit consisting of the gate sequence
        ... --- H --- R_z --- ...
        This function computes the evolved Pauli Rotation, and moves the
        clifford (Hadamard) gate to the end:
        ... --- R_x --- H --- ...

        Args:
            clifford (Operator): Clifford gate to move.
            pauli (Operator): Pauli rotation gate to move the clifford past.

        Returns:
            Tuple[Operator, Operator]:
                - Evolved Pauli rotation operator
                - Resulting Clifford operator (the same as the input)
        """

        # Gates on disjoint wires commute, nothing has to be evolved
        if not any(p_c in clifford.wires for p_c in pauli.wires):
            return pauli, clifford

        gen = pauli.generator()
        param = pauli.parameters[0]
        # remember the trainability, it is restored on the new rotation below
        requires_grad = param.requires_grad if isinstance(param, pnp.tensor) else False
        param = pnp.tensor(param)

        evolved_gen, _ = PauliCircuit._evolve_clifford_pauli(
            clifford, gen, adjoint_left=False
        )
        qubits = evolved_gen.wires
        evolved_gen = qml.pauli_decompose(evolved_gen.matrix())
        pauli_str, param_factor = PauliCircuit._get_paulistring_from_generator(
            evolved_gen
        )
        pauli_str, qubits = PauliCircuit._remove_identities_from_paulistr(
            pauli_str, qubits
        )
        pauli = qml.PauliRot(param * param_factor, pauli_str, qubits)
        pauli.parameters[0].requires_grad = requires_grad

        return pauli, clifford

    @staticmethod
    def _remove_identities_from_paulistr(
        pauli_str: str, qubits: List[int]
    ) -> Tuple[str, List[int]]:
        """
        Removes identities from Pauli string and its corresponding qubits.

        Args:
            pauli_str (str): Pauli string
            qubits (List[int]): Corresponding qubit indices.

        Returns:
            Tuple[str, List[int]]:
                - Pauli string without identities
                - Qubits indices without the identities
        """

        kept = [(p, qubits[i]) for i, p in enumerate(pauli_str) if p != "I"]
        reduced_pauli_str = "".join(p for p, _ in kept)
        reduced_qubits = [q for _, q in kept]

        return reduced_pauli_str, reduced_qubits

    @staticmethod
    def _evolve_clifford_pauli(
        clifford: Operator, pauli: Operator, adjoint_left: bool = True
    ) -> Tuple[Operator, Operator]:
        """
        This function computes the resulting operation, when evolving a Pauli
        Operation with a Clifford operation.
        For a Clifford operator C and a Pauli operator P, this function computes:
            P' = C* P C

        Args:
            clifford (Operator): Clifford gate
            pauli (Operator): Pauli gate
            adjoint_left (bool, optional): If adjoint of the clifford gate is
                applied to the left. If this is set to True C* P C is computed,
                else C P C*. Defaults to True.

        Returns:
            Tuple[Operator, Operator]:
                - Evolved Pauli operator
                - Resulting Clifford operator (the same as the input)
        """
        if not any(p_c in clifford.wires for p_c in pauli.wires):
            return pauli, clifford

        if adjoint_left:
            # C* P C: the adjoint is only applied on the left-hand side,
            # this matters for Cliffords that are not self-adjoint (e.g. S)
            evolved_pauli = qml.adjoint(clifford) @ pauli @ clifford
        else:
            evolved_pauli = clifford @ pauli @ qml.adjoint(clifford)

        return evolved_pauli, clifford

    @staticmethod
    def _evolve_cliffords_list(cliffords: List[Operator], pauli: Operator) -> Operator:
        """
        This function evolves a Pauli operation according to a sequence of cliffords.

        Args:
            cliffords (List[Operator]): Clifford gates in the order in which
                they appear in the circuit.
            pauli (Operator): Pauli gate

        Returns:
            Operator: Evolved Pauli operator
        """
        # The Clifford that is applied last in the circuit acts first on the
        # operator, therefore the list is traversed in reverse.
        for clifford in cliffords[::-1]:
            pauli, _ = PauliCircuit._evolve_clifford_pauli(clifford, pauli)
            qubits = pauli.wires
            pauli = qml.pauli_decompose(pauli.matrix(), wire_order=qubits)

        pauli = qml.simplify(pauli)

        # remove coefficients
        # NOTE(review): this presumably also drops a sign picked up during
        # the evolution - confirm that this is intended.
        pauli = (
            pauli.terms()[1][0]
            if isinstance(pauli, (qml_op.Prod, qml_op.LinearCombination))
            else pauli
        )

        return pauli

    @staticmethod
    def _get_paulistring_from_generator(
        gen: qml_op.LinearCombination,
    ) -> Tuple[str, float]:
        """
        Compute a Paulistring, consisting of "X", "Y", "Z" and "I" from a
        generator.

        Args:
            gen (qml_op.LinearCombination): The generator operation created by
                Pennylane

        Returns:
            Tuple[str, float]:
                - The Paulistring
                - A factor with which to multiply a parameter to the rotation
                  gate.
        """
        factor, term = gen.terms()
        param_factor = -2 * factor  # Rotation is defined as exp(-0.5 theta G)
        # Only the first term is used; a single operator is wrapped in a list
        # so that it can be iterated like a product.
        pauli_term = term[0] if isinstance(term[0], qml_op.Prod) else [term[0]]
        pauli_str_list = ["I"] * len(pauli_term)
        for p in pauli_term:
            if "Pauli" in p.name:
                # the wire label is used as position in the Pauli string
                q = p.wires[0]
                pauli_str_list[q] = p.name[-1]
        pauli_str = "".join(pauli_str_list)
        return pauli_str, param_factor

    @staticmethod
    def cliffords_in_observable(
        operations: List[Operator], original_obs: List[Operator]
    ) -> List[Operator]:
        """
        Integrates Clifford gates in the observables of the original ansatz.

        Args:
            operations (List[Operator]): Clifford gates
            original_obs (List[Operator]): Original observables from the
                circuit

        Returns:
            List[Operator]: Observables with Clifford operations
        """
        return [
            PauliCircuit._evolve_cliffords_list(operations, ob) for ob in original_obs
        ]
class QuanTikz:
    """
    Helpers to draw a PennyLane circuit as quantikz (LaTeX/TikZ) code.
    """

    class TikzFigure:
        """
        Holds the quantikz code of a circuit and offers LaTeX export helpers.
        """

        def __init__(self, quantikz_str: str):
            self.quantikz_str = quantikz_str

        def __repr__(self):
            return self.quantikz_str

        def __str__(self):
            return self.quantikz_str

        def wrap_figure(self):
            """
            Wraps the quantikz string in a LaTeX figure environment.

            Returns:
                str: A formatted LaTeX string representing the TikZ figure containing
                the quantum circuit diagram.
            """
            return f"""
\\begin{{figure}}
    \\centering
    \\begin{{tikzpicture}}
        \\node[scale=0.85] {{
            \\begin{{quantikz}}
                {self.quantikz_str}
            \\end{{quantikz}}
        }};
    \\end{{tikzpicture}}
\\end{{figure}}"""

        def export(self, destination: str, full_document=False, mode="w") -> None:
            """
            Write the circuit as LaTeX code to a file.

            Parameters
            ----------
            destination : str
                Path to the destination file.
            full_document : bool, optional
                If True, a complete LaTeX document containing the wrapped
                figure is written. Otherwise only the quantikz string,
                followed by a newline, is written.
            mode : str, optional
                Mode used to open the destination file.
            """
            if full_document:
                latex_code = f"""
\\documentclass{{article}}
\\usepackage{{quantikz}}
\\usepackage{{tikz}}
\\usetikzlibrary{{quantikz2}}
\\usepackage{{quantikz}}
\\usepackage[a3paper, landscape, margin=0.5cm]{{geometry}}
\\begin{{document}}
{self.wrap_figure()}
\\end{{document}}"""
            else:
                latex_code = self.quantikz_str + "\n"

            with open(destination, mode) as f:
                f.write(latex_code)

    @staticmethod
    def ground_state() -> str:
        """
        Generate the LaTeX representation of the |0⟩ ground state in stick notation.

        Returns
        -------
        str
            LaTeX string for the |0⟩ state.
        """
        return "\\lstick{\\ket{0}}"

    @staticmethod
    def measure(op):
        """
        Generate LaTeX for a measurement.

        Parameters
        ----------
        op
            The measurement to represent.

        Returns
        -------
        str
            LaTeX string for the meter symbol.

        Raises
        ------
        NotImplementedError
            If the measurement acts on more than one wire.
        """
        if len(op.wires) > 1:
            raise NotImplementedError("Multi-wire measurements are not supported yet")
        return "\\meter{}"

    @staticmethod
    def search_pi_fraction(w, op_name):
        """
        Generate LaTeX for a gate whose angle is shown as a fraction of pi.

        The angle is approximated by a fraction of pi with a denominator of
        at most 100. If the denominator of this fraction is larger than 12,
        the numeric value with two decimals is shown instead.

        Parameters
        ----------
        w : float
            Rotation angle.
        op_name : str
            Name of the gate shown in the box.

        Returns
        -------
        str
            LaTeX string for the gate.
        """
        w_pi = Fraction(w / np.pi).limit_denominator(100)
        # Not a small nice Fraction
        if w_pi.denominator > 12:
            return f"\\gate{{{op_name}({w:.2f})}}"
        # Pi
        if w_pi.denominator == 1 and w_pi.numerator == 1:
            return f"\\gate{{{op_name}(\\pi)}}"
        # 0
        if w_pi.numerator == 0:
            return f"\\gate{{{op_name}(0)}}"
        # Multiple of Pi
        if w_pi.denominator == 1:
            return f"\\gate{{{op_name}({w_pi.numerator}\\pi)}}"
        # Nice Fraction of pi
        if w_pi.numerator == 1:
            return (
                f"\\gate{{{op_name}\\left("
                f"\\frac{{\\pi}}{{{w_pi.denominator}}}\\right)}}"
            )
        # Small nice Fraction
        return (
            f"\\gate{{{op_name}\\left("
            f"\\frac{{{w_pi.numerator}\\pi}}{{{w_pi.denominator}}}"
            f"\\right)}}"
        )

    @staticmethod
    def gate(op, index=None, gate_values=False, inputs_symbols="x") -> str:
        """
        Generate LaTeX for a quantum gate in stick notation.

        Parameters
        ----------
        op : qml.Operation
            The quantum gate to represent.
        index : int, optional
            Gate index in the circuit.
        gate_values : bool, optional
            Include gate values in the representation.
        inputs_symbols : str, optional
            Symbols for the inputs in the representation.

        Returns
        -------
        str
            LaTeX string for the gate.
        """
        # short display names, all other gates keep their PennyLane name
        op_name = {"Hadamard": "H", "Rot": "R"}.get(op.name, op.name)

        # Gate without any parameter: there is neither a value nor a theta
        # to show, independent of the given index
        if len(op.parameters) == 0:
            return f"\\gate{{{op_name}}}"

        if gate_values:
            w = float(op.parameters[0].item())
            return QuanTikz.search_pi_fraction(w, op_name)

        shape = op.parameters[0].shape
        # Is gate with parameter
        if shape == ():
            if index is None:
                return f"\\gate{{{op_name}}}"
            return f"\\gate{{{op_name}(\\theta_{{{index}}})}}"
        # Is gate with input
        if shape == (1,):
            return f"\\gate{{{op_name}({inputs_symbols})}}"
        # NOTE(review): parameters of any other shape are not handled and
        # None is returned implicitly - confirm the intended behaviour.

    @staticmethod
    def cgate(op, index=None, gate_values=False, inputs_symbols="x") -> Tuple[str, str]:
        """
        Generate LaTeX for a controlled quantum gate in stick notation.

        Parameters
        ----------
        op : qml.Operation
            The quantum gate operation to represent.
        index : int, optional
            Gate index in the circuit.
        gate_values : bool, optional
            Include gate values in the representation.
        inputs_symbols : str, optional
            Symbols for the inputs in the representation.

        Returns
        -------
        Tuple[str, str]
            - LaTeX string for the control gate
            - LaTeX string for the target gate
        """
        # default target symbol, used for every gate not handled below
        targ = "\\targ{}"
        if op.name in ("CRX", "CRY", "CRZ"):
            # name of the rotation without the leading "C"
            op_name = op.name[1:]
            if gate_values and len(op.parameters) > 0:
                w = float(op.parameters[0].item())
                targ = QuanTikz.search_pi_fraction(w, op_name)
            else:
                shape = op.parameters[0].shape
                # Is gate with parameter
                if shape == ():
                    if index is None:
                        targ = f"\\gate{{{op_name}}}"
                    else:
                        targ = f"\\gate{{{op_name}(\\theta_{{{index}}})}}"
                # Is gate with input
                elif shape == (1,):
                    targ = f"\\gate{{{op_name}({inputs_symbols})}}"
        elif op.name in ("CX", "CY", "CZ"):
            targ = "\\control{}"

        # the control is drawn relative to the wire it sits on
        distance = op.wires[1] - op.wires[0]
        return f"\\ctrl{{{distance}}}", targ

    @staticmethod
    def barrier(op) -> str:
        """
        Generate LaTeX for a barrier in stick notation.

        Parameters
        ----------
        op : qml.Operation
            The barrier operation to represent.

        Returns
        -------
        str
            LaTeX string for the barrier.
        """
        return (
            "\\slice[style={{draw=black, solid, double distance=2pt, "
            "line width=0.5pt}}]{{}}"
        )

    @staticmethod
    def _build_tikz_circuit(quantum_tape, gate_values=False, inputs_symbols="x"):
        """
        Builds a LaTeX representation of a quantum circuit in TikZ format.

        This static method constructs a TikZ circuit diagram from a given quantum
        tape. It processes the operations in the tape, including gates, controlled
        gates, barriers, and measurements. The resulting structure is a list of
        LaTeX strings, each representing a wire in the circuit.

        Parameters
        ----------
        quantum_tape : QuantumTape
            The quantum tape containing the operations of the circuit.
        gate_values : bool, optional
            If True, include gate parameter values in the representation.
        inputs_symbols : str or iterator of str, optional
            Symbols to represent the inputs in the circuit. A plain string is
            used for every gate, an iterator is advanced once per gate.

        Returns
        -------
        circuit_tikz : list of list of str
            A nested list where each inner list contains LaTeX strings representing
            the operations on a single wire of the circuit.

        Raises
        ------
        NotImplementedError
            If a gate acts on more than two wires.
        """
        # local import, only cycle is imported from itertools at file level
        from itertools import count

        # next() is called on the symbols below, a plain string is not an
        # iterator
        if isinstance(inputs_symbols, str):
            inputs_symbols = cycle([inputs_symbols])

        def pad(wire, length):
            # fill the wire with empty cells up to the given length
            circuit_tikz[wire].extend([""] * (length - len(circuit_tikz[wire])))

        circuit_tikz = [
            [QuanTikz.ground_state()] for _ in range(quantum_tape.num_wires)
        ]

        # running gate index; unbounded, so that tapes without trainable
        # parameters can be drawn as well
        index = count()
        for op in quantum_tape.circuit:
            # catch measurement operations
            if op._queue_category == "_measurements":
                # get the maximum length of all wires
                max_len = max(len(wire_ops) for wire_ops in circuit_tikz)
                if op.wires[0] != 0:
                    max_len -= 1
                # extend the wire by the number of missing operations
                pad(op.wires[0], max_len)
                circuit_tikz[op.wires[0]].append(QuanTikz.measure(op))
            # process all gates
            elif op._queue_category == "_ops":
                # catch barriers
                if op.name == "Barrier":
                    # get the maximum length of all wires
                    max_len = max(len(wire_ops) for wire_ops in circuit_tikz)

                    # extend the wires by the number of missing operations
                    for ow in range(len(circuit_tikz)):
                        pad(ow, max_len)

                    circuit_tikz[op.wires[0]][-1] += QuanTikz.barrier(op)
                # single qubit gate?
                elif len(op.wires) == 1:
                    # build and append standard gate
                    circuit_tikz[op.wires[0]].append(
                        QuanTikz.gate(
                            op,
                            index=next(index),
                            gate_values=gate_values,
                            inputs_symbols=next(inputs_symbols),
                        )
                    )
                # controlled gate?
                elif len(op.wires) == 2:
                    # build the controlled gate
                    if op.name in ["CRX", "CRY", "CRZ"]:
                        ctrl, targ = QuanTikz.cgate(
                            op,
                            index=next(index),
                            gate_values=gate_values,
                            inputs_symbols=next(inputs_symbols),
                        )
                    else:
                        ctrl, targ = QuanTikz.cgate(op)

                    # get the wires that this cgate spans over
                    crossing_wires = list(range(min(op.wires), max(op.wires) + 1))
                    # get the maximum length of all operations currently on
                    # these wires
                    max_len = max(len(circuit_tikz[cw]) for cw in crossing_wires)

                    # extend the affected wires by the number of missing operations
                    for ow in crossing_wires:
                        pad(ow, max_len)

                    # finally append the cgate operation
                    circuit_tikz[op.wires[0]].append(ctrl)
                    circuit_tikz[op.wires[1]].append(targ)

                    # the wires between control and target get an empty cell
                    for cw in crossing_wires:
                        if cw not in op.wires:
                            circuit_tikz[cw].append("")
                else:
                    raise NotImplementedError(">2-wire gates are not supported yet")

        return circuit_tikz

    @staticmethod
    def build(
        circuit: qml.QNode, params, inputs, gate_values=False, inputs_symbols="x"
    ) -> QuanTikz.TikzFigure:
        """
        Generate LaTeX for a quantum circuit in stick notation.

        Parameters
        ----------
        circuit : qml.QNode
            The quantum circuit to represent.
        params : array
            Weight parameters for the circuit.
        inputs : array
            Inputs for the circuit.
        gate_values : bool, optional
            Toggle for gate values or theta variables in the representation.
        inputs_symbols : str or list of str, optional
            Symbols for the inputs in the representation. A list must contain
            one symbol per input.

        Returns
        -------
        QuanTikz.TikzFigure
            Figure object holding the LaTeX string for the circuit.
        """
        quantum_tape = qml.workflow.construct_tape(circuit)(
            params=params, inputs=inputs
        )
        if isinstance(inputs_symbols, str) and inputs.size > 1:
            # one indexed symbol per input
            inputs_symbols = cycle(
                [f"{inputs_symbols}_{i}" for i in range(inputs.size)]
            )
        elif isinstance(inputs_symbols, list):
            assert len(inputs_symbols) == inputs.size, (
                f"The number of input symbols {len(inputs_symbols)} "
                f"must match the number of inputs {inputs.size}."
            )
            inputs_symbols = cycle(inputs_symbols)
        else:
            inputs_symbols = cycle([inputs_symbols])

        circuit_tikz = QuanTikz._build_tikz_circuit(
            quantum_tape, gate_values=gate_values, inputs_symbols=inputs_symbols
        )

        # get the maximum length of all wires
        max_len = max(len(wire_ops) for wire_ops in circuit_tikz)

        # extend the wires by the number of missing operations
        for wire_ops in circuit_tikz:
            wire_ops.extend([""] * (max_len - len(wire_ops)))

        # cells of a wire are separated by "&", wires by a LaTeX line break
        quantikz_str = " \\\\\n".join(
            " & ".join(f"{op}" for op in wire_ops) for wire_ops in circuit_tikz
        )

        return QuanTikz.TikzFigure(quantikz_str)