Coverage for qml_essentials/utils.py: 95%

276 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2026-02-20 14:03 +0000

1from __future__ import annotations 

2from typing import List, Tuple 

3import jax 

4import jax.numpy as jnp 

5import numpy as np 

6import pennylane as qml 

7from pennylane.operation import Operator 

8from pennylane.tape import QuantumScript, QuantumTape 

9import pennylane.ops.op_math as qml_op 

10from fractions import Fraction 

11from itertools import cycle 

12from scipy.linalg import logm 

13 

# Gate classes treated as Clifford operations by PauliCircuit. Instances of
# these are commuted to the end of the circuit and absorbed into observables.
# NOTE(review): on recent PennyLane versions qml.X/Y/Z may simply be aliases of
# qml.PauliX/Y/Z; the duplicates are harmless for isinstance checks.
CLIFFORD_GATES = (
    qml.PauliX, qml.PauliY, qml.PauliZ,
    qml.X, qml.Y, qml.Z,
    qml.Hadamard, qml.S, qml.CNOT,
)

# Parameterised rotation gate classes kept (in evolved form) in the
# Pauli-Clifford circuit.
PAULI_ROTATION_GATES = (qml.RX, qml.RY, qml.RZ, qml.PauliRot)

# Operation classes dropped entirely when building the Pauli-Clifford circuit.
SKIPPABLE_OPERATIONS = (qml.Barrier,)

34 

35 

def safe_random_split(random_key: jax.random.PRNGKey, *args, **kwargs):
    """
    Split a JAX PRNG key, tolerating a missing key.

    If ``random_key`` is given, this forwards to ``jax.random.split`` with all
    extra arguments. If it is ``None`` (randomness disabled), a tuple of
    ``None`` placeholders is returned instead. The tuple has one entry per
    requested sub-key, so callers can unpack the result the same way in both
    cases.

    Args:
        random_key: The key to split, or ``None``.
        *args, **kwargs: Forwarded to ``jax.random.split``. The number of
            sub-keys is read from the first positional argument or ``num``.
            It defaults to 2, matching ``jax.random.split``.

    Returns:
        The split keys, or a tuple of ``None`` values when ``random_key`` is
        ``None``.
    """
    if random_key is not None:
        return jax.random.split(random_key, *args, **kwargs)

    num = kwargs.get("num", args[0] if args else 2)
    # Only an integer count can be mirrored by a flat tuple. Anything else
    # (e.g. a shape tuple) keeps the historical (None, None) result.
    if isinstance(num, (int, np.integer)):
        return (None,) * int(num)
    return None, None

41 

42 

def logm_v(A: jnp.ndarray, **kwargs) -> jnp.ndarray:
    """
    Compute the matrix logarithm of a (potentially batched) matrix.

    A 2-D input is passed straight to ``scipy.linalg.logm``. For inputs with
    one or more leading batch dimensions, the logarithm of every trailing
    square matrix is computed. The results are returned as a jax array of the
    same shape with complex128 dtype (jax may downcast this to complex64 when
    64-bit mode is disabled).

    Args:
        A (jnp.ndarray): The matrix, or batch of matrices with shape
            ``(..., n, n)``, of which to compute the logarithm.
        **kwargs: Forwarded to ``scipy.linalg.logm``.

    Returns:
        jnp.ndarray: The log matrices.

    Raises:
        NotImplementedError: If ``A`` has fewer than two dimensions.
    """
    # TODO: check warnings
    ndim = np.ndim(A)
    if ndim == 2:
        return logm(A, **kwargs)
    if ndim < 2:
        raise NotImplementedError("Unsupported shape of input matrix")

    shape = np.shape(A)
    # Flatten all leading batch dims so every entry is a single matrix.
    batch = np.asarray(A).reshape((-1,) + tuple(shape[-2:]))
    # Fill a mutable numpy buffer in place. Updating an immutable jax array
    # with .at[i].set would copy the whole batch on every iteration.
    logs = np.empty(batch.shape, dtype=np.complex128)
    for k, matrix in enumerate(batch):
        logs[k] = logm(matrix, **kwargs)
    return jnp.asarray(logs.reshape(shape), dtype=jnp.complex128)

65 

66 

class PauliCircuit:
    """
    Wrapper for Pauli-Clifford Circuits described by Nemkov et al.
    (https://doi.org/10.1103/PhysRevA.108.032406). The code is inspired
    by the corresponding implementation: https://github.com/idnm/FourierVQA.

    A Pauli Circuit only consists of parameterised Pauli-rotations and Clifford
    gates, which is the default for the most common VQCs. This class commutes
    all Clifford gates to the end of the circuit, evolving the rotations they
    pass. It then absorbs the trailing Cliffords into the measured observables,
    leaving a circuit made only of (evolved) Pauli rotations.
    """

    @staticmethod
    def from_parameterised_circuit(
        tape: QuantumScript,
    ) -> tuple[QuantumScript]:
        """
        Transforms the quantum tape of a circuit into a Pauli-Clifford circuit.

        Args:
            tape (QuantumScript): The quantum tape for the operations in the
                ansatz. This is automatically passed when initialising the
                transform function with a QNode. Note: directly calling
                `PauliCircuit.from_parameterised_circuit(circuit)` for a QNode
                circuit will fail.

        Returns:
            QuantumScript: A new quantum tape containing the Pauli rotations of
                the Pauli-Clifford circuit, followed by ``expval`` measurements
                of the Clifford-evolved observables.
        """
        operations = PauliCircuit.get_clifford_pauli_gates(tape)

        pauli_gates, final_cliffords = PauliCircuit.commute_all_cliffords_to_the_end(
            operations
        )

        observables = PauliCircuit.cliffords_in_observable(
            final_cliffords, tape.observables
        )

        # Queue the rotations and the evolved measurements on a fresh tape.
        with QuantumTape() as tape_new:
            for op in pauli_gates:
                op.queue()
            for obs in observables:
                qml.expval(obs)

        return tape_new

    @staticmethod
    def commute_all_cliffords_to_the_end(
        operations: List[Operator],
    ) -> Tuple[List[Operator], List[Operator]]:
        """
        This function moves all clifford gates to the end of the circuit,
        accounting for commutation rules.

        The input list is not modified.

        Args:
            operations (List[Operator]): The operations in the tape of the
                circuit (Clifford gates and Pauli rotations, in time order).

        Returns:
            Tuple[List[Operator], List[Operator]]:
                - List of the resulting Pauli-rotations
                - List of the resulting Clifford gates
        """
        operations = list(operations)

        # Walk from right to left. When index i is processed, the suffix
        # operations[i+1:] is already ordered as [rotations..., cliffords...].
        # A Clifford at i therefore bubbles past the rotations (evolving each
        # one) until it reaches that trailing Clifford block.
        for i in range(len(operations) - 2, -1, -1):
            j = i
            while (
                j + 1 < len(operations)  # Clifford has not already reached the end
                and PauliCircuit._is_clifford(operations[j])
                and PauliCircuit._is_pauli_rotation(operations[j + 1])
            ):
                pauli, clifford = PauliCircuit._evolve_clifford_rotation(
                    operations[j], operations[j + 1]
                )
                operations[j] = pauli
                operations[j + 1] = clifford
                j += 1

        # The Cliffords now form a contiguous trailing run. Its start is the
        # split point; this also handles an empty list and circuits without
        # any Cliffords.
        first_clifford = len(operations)
        while first_clifford > 0 and PauliCircuit._is_clifford(
            operations[first_clifford - 1]
        ):
            first_clifford -= 1

        return operations[:first_clifford], operations[first_clifford:]

    @staticmethod
    def get_clifford_pauli_gates(tape: QuantumScript) -> List[Operator]:
        """
        This function decomposes all gates in the circuit to clifford and
        pauli-rotation gates.

        Args:
            tape (QuantumScript): The tape of the circuit containing all
                operations.

        Returns:
            List[Operator]: A list of operations consisting only of clifford
                and Pauli-rotation gates (skippable operations are dropped).
        """
        operations = []
        for operation in tape.operations:
            if PauliCircuit._is_clifford(operation) or PauliCircuit._is_pauli_rotation(
                operation
            ):
                operations.append(operation)
            elif PauliCircuit._is_skippable(operation):
                continue
            else:
                # TODO: Maybe there is a prettier way to decompose a gate
                # We currently can not handle parametrised input gates, that
                # are not plain pauli rotations
                single_op_tape = QuantumScript([operation])
                decomposed_tapes, _ = qml.transforms.decompose(
                    single_op_tape, gate_set=PAULI_ROTATION_GATES + CLIFFORD_GATES
                )
                operations.extend(
                    PauliCircuit._rebuild_decomposed_op(op)
                    for op in decomposed_tapes[0].operations
                )

        return operations

    @staticmethod
    def _rebuild_decomposed_op(op: Operator) -> Operator:
        """
        Re-creates a gate produced by decomposition, with its parameters
        wrapped in a jax array. Clifford gates are returned unchanged.

        Args:
            op (Operator): A Clifford or Pauli-rotation gate from a
                decomposition.

        Returns:
            Operator: The (re-created) gate.
        """
        if PauliCircuit._is_clifford(op):
            return op
        params = jnp.array(op.parameters)
        if isinstance(op, qml.PauliRot):
            # PauliRot's signature is (theta, pauli_word, wires). The generic
            # cls(params, wires) form would pass the wires as the Pauli word.
            return qml.PauliRot(params, op.hyperparameters["pauli_word"], wires=op.wires)
        return op.__class__(params, op.wires)

    @staticmethod
    def _is_skippable(operation: Operator) -> bool:
        """
        Determines if an operator can be ignored when building the Pauli
        Clifford circuit. Currently this only contains barriers.

        Args:
            operation (Operator): Gate operation

        Returns:
            bool: Whether the operation can be skipped.
        """
        return isinstance(operation, SKIPPABLE_OPERATIONS)

    @staticmethod
    def _is_clifford(operation: Operator) -> bool:
        """
        Determines if an operator is a Clifford gate.

        Args:
            operation (Operator): Gate operation

        Returns:
            bool: Whether the operation is Clifford.
        """
        return isinstance(operation, CLIFFORD_GATES)

    @staticmethod
    def _is_pauli_rotation(operation: Operator) -> bool:
        """
        Determines if an operator is a Pauli rotation gate.

        Args:
            operation (Operator): Gate operation

        Returns:
            bool: Whether the operation is a Pauli operation.
        """
        return isinstance(operation, PAULI_ROTATION_GATES)

    @staticmethod
    def _evolve_clifford_rotation(
        clifford: Operator, pauli: Operator
    ) -> Tuple[Operator, Operator]:
        """
        This function computes the resulting operations when switching a
        Clifford gate and a Pauli rotation in the circuit.

        For the time-ordered sequence ``C, R(G)`` (i.e. unitary ``R C``), the
        equivalent sequence is ``R(C^dagger G C), C``, because
        ``R C = C (C^dagger R C)``.

        **Example**:
            Consider a circuit consisting of the gate sequence
            ... --- H --- R_z --- ...
            This function computes the evolved Pauli Rotation, and moves the
            clifford (Hadamard) gate to the end:
            ... --- R_x --- H --- ...

        Args:
            clifford (Operator): Clifford gate to move.
            pauli (Operator): Pauli rotation gate to move the clifford past.

        Returns:
            Tuple[Operator, Operator]:
                - Evolved Pauli rotation operator
                - Resulting Clifford operator (the same as the input)
        """

        # Gates on disjoint wires commute trivially.
        if not any(p_c in clifford.wires for p_c in pauli.wires):
            return pauli, clifford

        gen = pauli.generator()
        param = pauli.parameters[0]

        # C^dagger G C is required to move C past a later rotation.
        evolved_gen, _ = PauliCircuit._evolve_clifford_pauli(
            clifford, gen, adjoint_left=True
        )
        qubits = evolved_gen.wires
        evolved_gen = qml.pauli_decompose(evolved_gen.matrix())
        pauli_str, param_factor = PauliCircuit._get_paulistring_from_generator(
            evolved_gen
        )
        pauli_str, qubits = PauliCircuit._remove_identities_from_paulistr(
            pauli_str, qubits
        )
        pauli = qml.PauliRot(param * param_factor, pauli_str, qubits)

        return pauli, clifford

    @staticmethod
    def _remove_identities_from_paulistr(
        pauli_str: str, qubits: List[int]
    ) -> Tuple[str, List[int]]:
        """
        Removes identities from Pauli string and its corresponding qubits.

        Args:
            pauli_str (str): Pauli string
            qubits (List[int]): Corresponding qubit indices.

        Returns:
            Tuple[str, List[int]]:
                - Pauli string without identities
                - Qubits indices without the identities
        """

        reduced_qubits = []
        reduced_pauli_str = ""
        for i, p in enumerate(pauli_str):
            if p != "I":
                reduced_pauli_str += p
                reduced_qubits.append(qubits[i])

        return reduced_pauli_str, reduced_qubits

    @staticmethod
    def _evolve_clifford_pauli(
        clifford: Operator, pauli: Operator, adjoint_left: bool = True
    ) -> Tuple[Operator, Operator]:
        """
        This function computes the resulting operation when evolving a Pauli
        Operation with a Clifford operation.
        For a Clifford operator C and a Pauli operator P, this function computes:
        P' = C^dagger P C

        Args:
            clifford (Operator): Clifford gate
            pauli (Operator): Pauli gate
            adjoint_left (bool, optional): If the adjoint of the clifford gate
                is applied to the left. If this is set to True, C^dagger P C is
                computed, else C P C^dagger. Defaults to True.

        Returns:
            Tuple[Operator, Operator]:
                - Evolved Pauli operator
                - Resulting Clifford operator (the same as the input)
        """
        if not any(p_c in clifford.wires for p_c in pauli.wires):
            return pauli, clifford

        # `A @ B` builds the operator product A·B (B acts first).
        if adjoint_left:
            evolved_pauli = qml.adjoint(clifford) @ pauli @ clifford
        else:
            evolved_pauli = clifford @ pauli @ qml.adjoint(clifford)

        return evolved_pauli, clifford

    @staticmethod
    def _evolve_cliffords_list(cliffords: List[Operator], pauli: Operator) -> Operator:
        """
        This function evolves a Pauli operation according to a sequence of
        cliffords.

        The cliffords are given in time order. The observable is conjugated by
        the last clifford first, giving C_1^dagger ... C_k^dagger P C_k ... C_1.

        Args:
            cliffords (List[Operator]): Clifford gates (time order)
            pauli (Operator): Pauli gate

        Returns:
            Operator: Evolved Pauli operator (without its coefficient)
        """
        for clifford in cliffords[::-1]:
            pauli, _ = PauliCircuit._evolve_clifford_pauli(clifford, pauli)
            qubits = pauli.wires
            pauli = qml.pauli_decompose(pauli.matrix(), wire_order=qubits)

        pauli = qml.simplify(pauli)

        # remove coefficients
        # NOTE(review): this also drops the sign (e.g. -Z becomes Z), and only
        # the first term is kept for multi-term results. Presumably this is
        # intended because only the Pauli structure is needed downstream; confirm.
        pauli = (
            pauli.terms()[1][0]
            if isinstance(pauli, (qml_op.Prod, qml_op.LinearCombination))
            else pauli
        )

        return pauli

    @staticmethod
    def _get_paulistring_from_generator(
        gen: qml_op.LinearCombination,
    ) -> Tuple[str, float]:
        """
        Compute a Paulistring, consisting of "X", "Y", "Z" and "I" from a
        generator.

        Args:
            gen (qml_op.LinearCombination): The generator operation created by
                Pennylane

        Returns:
            Tuple[str, float]:
                - The Paulistring
                - A factor with which to multiply the parameter of the rotation
                  gate.
        """
        factor, term = gen.terms()
        param_factor = -2 * factor  # Rotation is defined as exp(-0.5 theta G)
        pauli_term = term[0] if isinstance(term[0], qml_op.Prod) else [term[0]]
        pauli_str_list = ["I"] * len(pauli_term)
        for p in pauli_term:
            if "Pauli" in p.name:
                # Wires come from pauli_decompose's default ordering, so they
                # are presumably 0..n-1 and usable as string positions.
                q = p.wires[0]
                pauli_str_list[q] = p.name[-1]
        pauli_str = "".join(pauli_str_list)
        return pauli_str, param_factor

    @staticmethod
    def cliffords_in_observable(
        operations: List[Operator], original_obs: List[Operator]
    ) -> List[Operator]:
        """
        Integrates Clifford gates in the observables of the original ansatz.

        Args:
            operations (List[Operator]): Clifford gates
            original_obs (List[Operator]): Original observables from the
                circuit

        Returns:
            List[Operator]: Observables with Clifford operations
        """
        return [
            PauliCircuit._evolve_cliffords_list(operations, ob) for ob in original_obs
        ]

420 

421 

class QuanTikz:
    """
    Helpers that render a PennyLane circuit as a quantikz (TikZ/LaTeX) diagram
    in stick notation.
    """

    class TikzFigure:
        """Holds the rendered quantikz body (the rows of the circuit)."""

        def __init__(self, quantikz_str: str):
            self.quantikz_str = quantikz_str

        def __repr__(self):
            return self.quantikz_str

        def __str__(self):
            return self.quantikz_str

        def wrap_figure(self):
            """
            Wraps the quantikz string in a LaTeX figure environment.

            Returns:
                str: A formatted LaTeX string representing the TikZ figure
                containing the quantum circuit diagram.
            """
            return f"""
\\begin{{figure}}
    \\centering
    \\begin{{tikzpicture}}
        \\node[scale=0.85] {{
            \\begin{{quantikz}}
                {self.quantikz_str}
            \\end{{quantikz}}
        }};
    \\end{{tikzpicture}}
\\end{{figure}}"""

        def export(self, destination: str, full_document=False, mode="w") -> None:
            """
            Export the quantum circuit in stick notation to a file.

            Parameters
            ----------
            destination : str
                Path to the destination file.
            full_document : bool, optional
                If True, write a complete LaTeX document (preamble plus the
                figure from ``wrap_figure``). Otherwise, only the quantikz
                body followed by a newline is written.
            mode : str, optional
                File mode passed to ``open`` (e.g. ``"a"`` to append).
            """
            if full_document:
                latex_code = f"""
\\documentclass{{article}}
\\usepackage{{quantikz}}
\\usepackage{{tikz}}
\\usetikzlibrary{{quantikz2}}
\\usepackage{{quantikz}}
\\usepackage[a3paper, landscape, margin=0.5cm]{{geometry}}
\\begin{{document}}
{self.wrap_figure()}
\\end{{document}}"""
            else:
                latex_code = self.quantikz_str + "\n"

            # Explicit encoding so input symbols survive on any platform.
            with open(destination, mode, encoding="utf-8") as f:
                f.write(latex_code)

    @staticmethod
    def ground_state() -> str:
        """
        Generate the LaTeX representation of the |0⟩ ground state in stick notation.

        Returns
        -------
        str
            LaTeX string for the |0⟩ state.
        """
        return "\\lstick{\\ket{0}}"

    @staticmethod
    def measure(op):
        """
        Generate LaTeX for a single-wire measurement.

        Raises
        ------
        NotImplementedError
            If the measurement acts on more than one wire.
        """
        if len(op.wires) > 1:
            raise NotImplementedError("Multi-wire measurements are not supported yet")
        return "\\meter{}"

    @staticmethod
    def search_pi_fraction(w, op_name):
        """
        Render a gate whose angle is shown as a small fraction of pi if possible.

        Parameters
        ----------
        w : float
            Rotation angle in radians.
        op_name : str
            Gate label.

        Returns
        -------
        str
            ``\\gate{...}`` with the angle as a multiple/fraction of pi when
            the denominator is at most 12, otherwise with two decimals.
        """
        w = float(w)
        w_pi = Fraction(w / np.pi).limit_denominator(100)
        num, den = w_pi.numerator, w_pi.denominator
        # Not a small nice Fraction
        if den > 12:
            return f"\\gate{{{op_name}({w:.2f})}}"
        # 0
        if num == 0:
            return f"\\gate{{{op_name}(0)}}"
        sign = "-" if num < 0 else ""
        if den == 1:
            # +/- Pi
            if abs(num) == 1:
                return f"\\gate{{{op_name}({sign}\\pi)}}"
            # Multiple of Pi
            return f"\\gate{{{op_name}({num}\\pi)}}"
        # Nice Fraction of pi, e.g. pi/2 or -pi/4
        if abs(num) == 1:
            return f"\\gate{{{op_name}\\left({sign}\\frac{{\\pi}}{{{den}}}\\right)}}"
        # Small nice Fraction
        return f"\\gate{{{op_name}\\left(\\frac{{{num}\\pi}}{{{den}}}\\right)}}"

    @staticmethod
    def _is_input_gate(op) -> bool:
        """Whether ``op``'s first parameter has shape (1,), the marker this
        module uses for input-encoding gates."""
        return len(op.parameters) > 0 and np.shape(op.parameters[0]) == (1,)

    @staticmethod
    def _next_input_symbol(op, inputs_symbols):
        """Advance the input-symbol iterator only for input-encoding gates."""
        return next(inputs_symbols) if QuanTikz._is_input_gate(op) else None

    @staticmethod
    def _param_gate(op_name, param, index, gate_values, inputs_symbols) -> str:
        """
        Render a parametrised gate label shared by ``gate`` and ``cgate``.

        Shows the numeric value if ``gate_values`` is set. Otherwise shows the
        input symbol for input gates (parameter shape ``(1,)``), ``theta_index``
        for scalar parameters, or the bare name when no index is given.
        """
        if gate_values:
            w = float(np.ravel(np.asarray(param))[0])
            return QuanTikz.search_pi_fraction(w, op_name)
        shape = np.shape(param)
        # Is gate with input
        if shape == (1,):
            return f"\\gate{{{op_name}({inputs_symbols})}}"
        # Is gate with (scalar) parameter
        if shape == () and index is not None:
            return f"\\gate{{{op_name}(\\theta_{{{index}}})}}"
        return f"\\gate{{{op_name}}}"

    @staticmethod
    def gate(op, index=None, gate_values=False, inputs_symbols="x") -> str:
        """
        Generate LaTeX for a quantum gate in stick notation.

        Parameters
        ----------
        op : qml.Operation
            The quantum gate to represent.
        index : int, optional
            Gate index in the circuit, used as the theta subscript for
            parametrised gates.
        gate_values : bool, optional
            Include gate values in the representation.
        inputs_symbols : str, optional
            Symbol shown for input-encoding gates.

        Returns
        -------
        str
            LaTeX string for the gate.
        """
        op_name = {"Hadamard": "H", "Rot": "R"}.get(op.name, op.name)
        # Gates without parameters never get a theta label.
        if not op.parameters:
            return f"\\gate{{{op_name}}}"
        return QuanTikz._param_gate(
            op_name, op.parameters[0], index, gate_values, inputs_symbols
        )

    @staticmethod
    def cgate(op, index=None, gate_values=False, inputs_symbols="x") -> tuple[str, str]:
        """
        Generate LaTeX for a controlled quantum gate in stick notation.

        Parameters
        ----------
        op : qml.Operation
            The quantum gate operation to represent (control on ``wires[0]``,
            target on ``wires[1]``).
        index : int, optional
            Gate index in the circuit.
        gate_values : bool, optional
            Include gate values in the representation.
        inputs_symbols : str, optional
            Symbol shown for input-encoding gates.

        Returns
        -------
        tuple[str, str]
            - LaTeX string for the control gate
            - LaTeX string for the target gate
        """
        name = op.name
        if name in ("CRX", "CRY", "CRZ") and op.parameters:
            targ = QuanTikz._param_gate(
                name[1:], op.parameters[0], index, gate_values, inputs_symbols
            )
        elif name == "CZ":
            targ = "\\control{}"
        elif name == "CY":
            targ = "\\gate{Y}"
        else:
            # CNOT / CX and any other two-wire gate
            targ = "\\targ{}"

        distance = op.wires[1] - op.wires[0]
        return f"\\ctrl{{{distance}}}", targ

    @staticmethod
    def barrier(op) -> str:
        """
        Generate LaTeX for a barrier in stick notation.

        Parameters
        ----------
        op : qml.Operation
            The barrier operation to represent.

        Returns
        -------
        str
            LaTeX string for the barrier.
        """
        # Plain (non f-)string: braces must not be doubled.
        return (
            "\\slice[style={draw=black, solid, double distance=2pt, "
            "line width=0.5pt}]{}"
        )

    @staticmethod
    def _pad_wires(circuit_tikz, wires, length):
        """Append empty cells so each listed wire holds at least ``length`` cells."""
        for w in wires:
            circuit_tikz[w].extend([""] * max(length - len(circuit_tikz[w]), 0))

    @staticmethod
    def _build_tikz_circuit(quantum_tape, gate_values=False, inputs_symbols="x"):
        """
        Builds a LaTeX representation of a quantum circuit in TikZ format.

        This static method constructs a TikZ circuit diagram from a given quantum
        tape. It processes the operations in the tape, including gates, controlled
        gates, barriers, and measurements. The resulting structure is a list of
        LaTeX strings, each representing a wire in the circuit.

        Parameters
        ----------
        quantum_tape : QuantumTape
            The quantum tape containing the operations of the circuit.
        gate_values : bool, optional
            If True, include gate parameter values in the representation.
        inputs_symbols : str, list or iterator, optional
            Symbols to represent the inputs in the circuit. A string or list is
            cycled. An iterator is consumed once per input-encoding gate.

        Returns
        -------
        circuit_tikz : list of list of str
            A nested list where each inner list contains LaTeX strings representing
            the operations on a single wire of the circuit.
        """
        if isinstance(inputs_symbols, str):
            inputs_symbols = cycle([inputs_symbols])
        elif isinstance(inputs_symbols, (list, tuple)):
            inputs_symbols = cycle(inputs_symbols)

        circuit_tikz = [
            [QuanTikz.ground_state()] for _ in range(quantum_tape.num_wires)
        ]
        all_wires = range(len(circuit_tikz))

        index = iter(range(len(quantum_tape.operations)))
        for op in quantum_tape.circuit:
            # catch measurement operations
            if op._queue_category == "_measurements":
                max_len = max(len(cells) for cells in circuit_tikz)
                # NOTE(review): presumably compensates for the meter already
                # appended to wire 0 by an earlier measurement; confirm.
                if op.wires[0] != 0:
                    max_len -= 1
                QuanTikz._pad_wires(circuit_tikz, [op.wires[0]], max_len)
                circuit_tikz[op.wires[0]].append(QuanTikz.measure(op))
            # process all gates
            elif op._queue_category == "_ops":
                if op.name == "Barrier":
                    max_len = max(len(cells) for cells in circuit_tikz)
                    QuanTikz._pad_wires(circuit_tikz, all_wires, max_len)
                    circuit_tikz[op.wires[0]][-1] += QuanTikz.barrier(op)
                elif len(op.wires) == 1:
                    circuit_tikz[op.wires[0]].append(
                        QuanTikz.gate(
                            op,
                            index=next(index),
                            gate_values=gate_values,
                            inputs_symbols=QuanTikz._next_input_symbol(
                                op, inputs_symbols
                            ),
                        )
                    )
                elif len(op.wires) == 2:
                    if op.name in ("CRX", "CRY", "CRZ"):
                        ctrl, targ = QuanTikz.cgate(
                            op,
                            index=next(index),
                            gate_values=gate_values,
                            inputs_symbols=QuanTikz._next_input_symbol(
                                op, inputs_symbols
                            ),
                        )
                    else:
                        ctrl, targ = QuanTikz.cgate(op)

                    # align all wires the gate spans over
                    crossing_wires = range(min(op.wires), max(op.wires) + 1)
                    max_len = max(len(circuit_tikz[cw]) for cw in crossing_wires)
                    QuanTikz._pad_wires(circuit_tikz, crossing_wires, max_len)

                    circuit_tikz[op.wires[0]].append(ctrl)
                    circuit_tikz[op.wires[1]].append(targ)

                    # keep the spanned-but-untouched wires in step
                    for cw in crossing_wires:
                        if cw not in op.wires:
                            circuit_tikz[cw].append("")
                else:
                    raise NotImplementedError(">2-wire gates are not supported yet")

        return circuit_tikz

    @staticmethod
    def build(
        circuit: "qml.QNode",
        params,
        inputs,
        enc_params=None,
        gate_values=False,
        inputs_symbols="x",
    ) -> "QuanTikz.TikzFigure":
        """
        Generate LaTeX for a quantum circuit in stick notation.

        Parameters
        ----------
        circuit : qml.QNode
            The quantum circuit to represent.
        params : array
            Weight parameters for the circuit.
        inputs : array
            Inputs for the circuit.
        enc_params : array
            Encoding weight parameters for the circuit.
        gate_values : bool, optional
            Toggle for gate values or theta variables in the representation.
        inputs_symbols : str or list, optional
            Symbols for the inputs in the representation. A list must contain
            one symbol per input element.

        Returns
        -------
        QuanTikz.TikzFigure
            The circuit; ``str()`` gives the quantikz body.
        """
        tape_kwargs = {"params": params, "inputs": inputs}
        if enc_params is not None:
            tape_kwargs["enc_params"] = enc_params
        quantum_tape = qml.workflow.construct_tape(circuit)(**tape_kwargs)

        n_inputs = np.size(inputs)
        if isinstance(inputs_symbols, str) and n_inputs > 1:
            symbols = cycle([f"{inputs_symbols}_{i}" for i in range(n_inputs)])
        elif isinstance(inputs_symbols, list):
            assert len(inputs_symbols) == n_inputs, (
                f"The number of input symbols {len(inputs_symbols)} "
                f"must match the number of inputs {n_inputs}."
            )
            symbols = cycle(inputs_symbols)
        else:
            symbols = cycle([inputs_symbols])

        circuit_tikz = QuanTikz._build_tikz_circuit(
            quantum_tape, gate_values=gate_values, inputs_symbols=symbols
        )

        # pad every wire to the same length so the columns line up
        max_len = max((len(cells) for cells in circuit_tikz), default=0)
        QuanTikz._pad_wires(circuit_tikz, range(len(circuit_tikz)), max_len)

        quantikz_str = " \\\\\n".join(" & ".join(cells) for cells in circuit_tikz)
        return QuanTikz.TikzFigure(quantikz_str)