Coverage for qml_essentials/qoc.py: 68%
380 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-10-02 13:10 +0000
1import os
2import csv
3import jax
4from jax import numpy as jnp
5import optax
6import pennylane as qml
7from qml_essentials.ansaetze import Gates
8import matplotlib.pyplot as plt
9import warnings
10import logging
# Module-level logger; all progress messages of this module go through it.
log = logging.getLogger(__name__)

# Turn on 64-bit (double precision) arithmetic in JAX for the whole process.
jax.config.update("jax_enable_x64", True)
class QOC:
    """
    Quantum optimal control for pulse-level gates.

    Each ``optimize_<GATE>`` method builds a pulse-level circuit (using
    ``Gates.<GATE>(..., gate_mode="pulse")``) and a matching unitary circuit,
    optimizes the pulse parameters with Adam to minimize the infidelity
    between the two output states, appends the result to a CSV file and,
    if requested, plots (or warns that plotting is not implemented).
    """

    # Gates for which ``get_circuits`` (and therefore ``plot_rotation``)
    # defines plotting circuits.
    _PLOTTABLE_GATES = ("RX", "RY", "RZ", "H", "CZ", "CX")

    def __init__(
        self,
        make_plots=False,
        file_dir="qoc/results",
        fig_dir="qoc/figures",
        fig_points=70,
    ):
        """
        Initialize Quantum Optimal Control with Pulse-level Gates.

        Args:
            make_plots (bool): Whether to generate and save plots.
            file_dir (str or None): Directory to save optimization results.
                If None, ``save_results`` writes nothing.
            fig_dir (str): Directory to save figures.
            fig_points (int): Number of rotation angles in [0, 2π] used
                when plotting rotations.
        """
        self.ws = jnp.linspace(0, 2 * jnp.pi, fig_points)
        self.make_plots = make_plots
        self.file_dir = file_dir
        self.fig_dir = fig_dir
        # Gate currently being optimized/plotted. It is set by the
        # optimize_* methods and read by get_circuits, plot_rotation and
        # save_results.
        self.current_gate = None

    @staticmethod
    def _measure_xyz(wire):
        """Return Pauli-X, -Y and -Z expectation-value measurements on ``wire``."""
        return [
            qml.expval(qml.PauliX(wire)),
            qml.expval(qml.PauliY(wire)),
            qml.expval(qml.PauliZ(wire)),
        ]

    def get_circuits(self):
        """
        Return pulse- and unitary-based plotting circuits for the current gate.

        Both circuits take the rotation angle ``w`` as their first argument.
        The pulse circuit additionally accepts the pulse parameters. Both
        return the Pauli X/Y/Z expectation values of the measured wire.

        Returns:
            tuple: (pulse_circuit, unitary_circuit, operation_str)

        Raises:
            ValueError: If ``current_gate`` has no plotting circuits
                (e.g. Rot, CY, CRX, CRY, CRZ, or None).
        """
        gate = self.current_gate
        # Validate first, so unsupported gates fail with a clear message
        # instead of an UnboundLocalError further down.
        if gate not in self._PLOTTABLE_GATES:
            raise ValueError(
                f"No plotting circuits defined for gate {gate!r}; "
                f"supported gates: {', '.join(self._PLOTTABLE_GATES)}"
            )

        n_wires = 2 if gate in ("CZ", "CX") else 1
        dev = qml.device("default.qubit", wires=n_wires)

        if gate in ("RX", "RY"):

            @qml.qnode(dev, interface="jax")
            def pulse_circuit(w, pulse_params=None):
                getattr(Gates, gate)(
                    w, 0, pulse_params=pulse_params, gate_mode="pulse"
                )
                return self._measure_xyz(0)

            @qml.qnode(dev)
            def unitary_circuit(w):
                getattr(qml, gate)(w, wires=0)
                return self._measure_xyz(0)

            operation = f"{gate}(w)"

        elif gate == "RZ":
            # RZ acting on |0> alone would leave every Pauli expectation
            # value unchanged, so RX(π/2) is applied first.

            @qml.qnode(dev, interface="jax")
            def pulse_circuit(w, *_):
                qml.RX(jnp.pi / 2, wires=0)
                Gates.RZ(w, 0, gate_mode="pulse")
                return self._measure_xyz(0)

            @qml.qnode(dev)
            def unitary_circuit(w):
                qml.RX(jnp.pi / 2, wires=0)
                qml.RZ(w, wires=0)
                return self._measure_xyz(0)

            operation = f"RX(π / 2)·{gate}(w)"

        elif gate == "H":

            @qml.qnode(dev, interface="jax")
            def pulse_circuit(w, pulse_params=None):
                qml.RX(w, wires=0)
                Gates.H(0, pulse_params=pulse_params, gate_mode="pulse")
                return self._measure_xyz(0)

            @qml.qnode(dev)
            def unitary_circuit(w):
                qml.RX(w, wires=0)
                qml.H(wires=0)
                return self._measure_xyz(0)

            operation = f"RX(w)·{gate}"

        elif gate == "CZ":

            @qml.qnode(dev, interface="jax")
            def pulse_circuit(w, pulse_params=None):
                qml.RX(w, wires=0)
                qml.RX(w, wires=1)
                Gates.CZ([0, 1], pulse_params=pulse_params, gate_mode="pulse")
                qml.RX(-w, wires=1)
                qml.RX(-w, wires=0)
                return self._measure_xyz(1)

            @qml.qnode(dev)
            def unitary_circuit(w):
                qml.RX(w, wires=0)
                qml.RX(w, wires=1)
                qml.CZ(wires=[0, 1])
                qml.RX(-w, wires=1)
                qml.RX(-w, wires=0)
                return self._measure_xyz(1)

            operation = r"$RX_0(w)$·$RX_1(w)$·$CZ_{0, 1}$·$RX_1(-w)$·$RX_0(-w)$"

        else:  # gate == "CX"

            @qml.qnode(dev, interface="jax")
            def pulse_circuit(w, pulse_params=None):
                qml.RX(w, wires=0)
                Gates.CX([0, 1], pulse_params=pulse_params, gate_mode="pulse")
                return self._measure_xyz(1)

            @qml.qnode(dev)
            def unitary_circuit(w):
                qml.RX(w, wires=0)
                qml.CNOT(wires=[0, 1])
                return self._measure_xyz(1)

            operation = r"$RX_0(w)$·$CX_{0,1}$"

        return pulse_circuit, unitary_circuit, operation

    # TODO: Update method for new gates (Rot, CY, CRZ, CRY, CRX)
    def plot_rotation(self, pulse_params):
        """
        Plot expectation values of pulse- and unitary-based circuits for the
        current gate as a function of rotation angle.

        The figure is saved to ``<fig_dir>/qoc_<gate>(w).png``.

        Args:
            pulse_params: pulse parameters of pulse level gate.

        Raises:
            ValueError: If the current gate has no plotting circuits.
        """
        pulse_circuit, unitary_circuit, operation = self.get_circuits()

        # Rows: rotation angles; columns: <X>, <Y>, <Z>.
        pulse_expvals = jnp.array([pulse_circuit(w, pulse_params) for w in self.ws])
        ideal_expvals = jnp.array([unitary_circuit(w) for w in self.ws])

        fig, axs = plt.subplots(3, 1, figsize=(6, 12))
        xticks = [0, jnp.pi / 2, jnp.pi, 3 * jnp.pi / 2, 2 * jnp.pi]
        xtick_labels = ["0", "π/2", "π", "3π/2", "2π"]

        for i, (ax, basis) in enumerate(zip(axs, ["X", "Y", "Z"])):
            ax.plot(self.ws, pulse_expvals[:, i], label="Pulse-based")
            ax.plot(self.ws, ideal_expvals[:, i], "--", label="Unitary-based")
            ax.set_xlabel("Rotation angle w (rad)")
            ax.set_ylabel(f"⟨{basis}⟩")
            ax.set_title(f"{operation} in {basis}-basis")
            ax.grid(True)
            ax.legend()
            ax.set_xticks(xticks)
            ax.set_xticklabels(xtick_labels)

        fig.tight_layout()
        os.makedirs(self.fig_dir, exist_ok=True)
        fig.savefig(os.path.join(self.fig_dir, f"qoc_{self.current_gate}(w).png"))
        plt.close(fig)

    def save_results(self, opt_pulse_params):
        """
        Append optimized pulse parameters for the current gate to
        ``<file_dir>/qoc_results.csv``.

        A header row (``gate, param_1, ..., param_n``) is written only when
        the file is created. Nothing is written when ``file_dir`` is None.

        NOTE: the header's parameter count comes from whichever gate created
        the file. Rows for gates with a different number of parameters will
        not line up with it.

        Args:
            opt_pulse_params (list or array): Optimized parameters to save.
        """
        if self.file_dir is None:
            return

        os.makedirs(self.file_dir, exist_ok=True)
        filename = os.path.join(self.file_dir, "qoc_results.csv")
        file_exists = os.path.isfile(filename)

        with open(filename, mode="a", newline="", encoding="utf-8") as f:
            writer = csv.writer(f)
            if not file_exists:
                writer.writerow(
                    ["gate"]
                    + [f"param_{i + 1}" for i in range(len(opt_pulse_params))]
                )
            writer.writerow(
                [self.current_gate] + [float(p) for p in opt_pulse_params]
            )

    def loss_fn(self, state, target_state):
        """
        Compute infidelity between two pure quantum states.

        Args:
            state (array): Output state from pulse circuit.
            target_state (array): Target state from unitary circuit.

        Returns:
            float: Infidelity ``1 - |<target|state>|^2``.
        """
        # jnp.vdot conjugates its first argument, giving <target|state>.
        fidelity = jnp.abs(jnp.vdot(target_state, state)) ** 2
        return 1 - fidelity

    def cost_fn(self, pulse_params, circuit, w, target_state):
        """
        Compute cost for optimization by evaluating circuit and loss.

        Args:
            pulse_params: pulse parameters of pulse level gate.
            circuit (callable): QNode circuit accepting (w, pulse_params).
            w (float or None): Rotation angle(s) passed through to ``circuit``.
            target_state (array): Target quantum state.

        Returns:
            float: Computed loss.
        """
        state = circuit(w, pulse_params)
        return self.loss_fn(state, target_state)

    def run_optimization(
        self,
        cost,
        init_pulse_params,
        steps,
        patience,
        log_interval,
        *args,
    ):
        """
        Minimize ``cost`` with Adam (learning rate 0.1) and early stopping.

        Args:
            cost (callable): Cost function ``cost(pulse_params, *args)`` to
                minimize. It must be differentiable with ``jax.value_and_grad``.
            init_pulse_params (array): Initial parameters.
            steps (int): Maximum number of optimization steps.
            patience (int): Stop after this many consecutive steps without
                improving on the best loss.
            log_interval (int): Log the loss every ``log_interval`` steps.
                A value <= 0 disables progress logging.
            *args: Extra positional arguments forwarded to ``cost``.

        Returns:
            tuple: (best parameters, best loss, list of per-step losses).
            The returned parameters are exactly those at which the best loss
            was evaluated.
        """
        optimizer = optax.adam(0.1)
        opt_state = optimizer.init(init_pulse_params)

        @jax.jit
        def opt_step(pulse_params, opt_state, *args):
            # ``loss`` is evaluated at the incoming params, i.e. before the
            # update below is applied.
            loss, grads = jax.value_and_grad(cost)(pulse_params, *args)
            updates, opt_state = optimizer.update(grads, opt_state)
            return optax.apply_updates(pulse_params, updates), opt_state, loss

        pulse_params = init_pulse_params
        losses = []
        best_loss = float("inf")
        best_pulse_params = pulse_params
        no_improve_counter = 0

        for step in range(steps):
            evaluated_params = pulse_params
            pulse_params, opt_state, loss = opt_step(pulse_params, opt_state, *args)
            losses.append(loss)

            if loss < best_loss:
                best_loss = loss
                # Keep the params that produced ``loss``, not the updated ones,
                # so the returned params and loss belong together.
                best_pulse_params = evaluated_params
                no_improve_counter = 0
            else:
                no_improve_counter += 1

            if log_interval > 0 and (step + 1) % log_interval == 0:
                log.info("Step %d/%d, Loss: %.2e", step + 1, steps, float(loss))

            if no_improve_counter >= patience:
                log.info("Early stopping at step %d due to no improvement.", step + 1)
                break

        return best_pulse_params, best_loss, losses

    def _optimize_and_save(
        self,
        pulse_circuit,
        target,
        w,
        init_pulse_params,
        steps,
        patience,
        log_interval,
    ):
        """
        Optimize ``pulse_circuit(w, pulse_params)`` towards ``target``, then
        append the best parameters to the results CSV.

        Returns:
            tuple: (optimized parameters, best loss, list of loss values).
        """

        def cost(pulse_params):
            return self.cost_fn(
                pulse_params, circuit=pulse_circuit, w=w, target_state=target
            )

        pulse_params, loss, losses = self.run_optimization(
            cost, init_pulse_params, steps, patience, log_interval
        )
        self.save_results(pulse_params)
        return pulse_params, loss, losses

    def optimize_Rot(
        self,
        steps: int = 1000,
        patience: int = 100,
        phi: float = jnp.pi / 2,
        theta: float = jnp.pi / 2,
        omega: float = jnp.pi / 2,
        init_pulse_params: jnp.ndarray = jnp.array([0.5, 1.0, 15.0, 1.0, 0.5]),
        log_interval: int = 50,
    ):
        """
        Optimize pulse parameters for the Rot(phi, theta, omega) gate.

        Uses gradient-based optimization to minimize the infidelity between
        the output state of the pulse-based Rot(phi, theta, omega) and that
        of the unitary-based Rot(phi, theta, omega).

        Args:
            steps (int): Number of optimization steps. Default: 1000.
            patience (int): Patience for early stopping. Default: 100.
            phi, theta, omega (float): Rotation angles for the Rot gate.
                Default: π / 2 for all three.
            init_pulse_params (jnp.ndarray): Initial pulse parameters.
                Default: [0.5, 1.0, 15.0, 1.0, 0.5].
            log_interval (int): Frequency of logging the loss. Default: 50.

        Returns:
            tuple: Optimized parameters, best loss and list of loss values.
        """
        self.current_gate = "Rot"
        w = (phi, theta, omega)
        dev = qml.device("default.qubit", wires=1)

        @qml.qnode(dev, interface="jax")
        def pulse_circuit(angles, pulse_params):
            a_phi, a_theta, a_omega = angles
            Gates.Rot(
                a_phi, a_theta, a_omega, 0, pulse_params=pulse_params, gate_mode="pulse"
            )
            return qml.state()

        @qml.qnode(dev)
        def unitary_circuit(angles):
            a_phi, a_theta, a_omega = angles
            qml.Rot(a_phi, a_theta, a_omega, wires=0)
            return qml.state()

        result = self._optimize_and_save(
            pulse_circuit,
            unitary_circuit(w),
            w,
            init_pulse_params,
            steps,
            patience,
            log_interval,
        )

        if self.make_plots:
            warnings.warn("Plotting not implemented yet", UserWarning)

        return result

    def optimize_RX(
        self,
        steps: int = 1000,
        patience: int = 100,
        w: float = jnp.pi,
        init_pulse_params: jnp.ndarray = jnp.array([1.0, 15.0, 1.0]),
        log_interval: int = 50,
    ):
        """
        Optimize pulse parameters for the RX(w) gate to best approximate
        the unitary RX(w) gate.

        Uses gradient-based optimization to minimize the infidelity between
        the output state of the pulse-based RX(w) and that of the unitary RX(w).

        Args:
            steps (int): Number of optimization steps. Default: 1000.
            patience (int): Amount of epochs without improvement before early
                stopping. Default: 100.
            w (float): Rotation angle in radians with which to run the
                optimization. Default: π.
            init_pulse_params (jnp.ndarray): Initial pulse parameters (A, sigma)
                and time. Default: [1.0, 15.0, 1.0].
            log_interval (int): Frequency of logging the loss. Default: 50.

        Returns:
            tuple: Optimized parameters (jnp.ndarray), best loss and list of
            loss values during optimization.
        """
        self.current_gate = "RX"
        dev = qml.device("default.qubit", wires=1)

        @qml.qnode(dev, interface="jax")
        def pulse_circuit(w, pulse_params):
            Gates.RX(w, 0, pulse_params=pulse_params, gate_mode="pulse")
            return qml.state()

        @qml.qnode(dev)
        def unitary_circuit(w):
            qml.RX(w, wires=0)
            return qml.state()

        pulse_params, loss, losses = self._optimize_and_save(
            pulse_circuit,
            unitary_circuit(w),
            w,
            init_pulse_params,
            steps,
            patience,
            log_interval,
        )

        if self.make_plots:
            log.info("Plotting RX rotation...")
            self.plot_rotation(pulse_params)

        return pulse_params, loss, losses

    def optimize_RY(
        self,
        steps: int = 1000,
        patience: int = 100,
        w: float = jnp.pi,
        init_pulse_params: jnp.ndarray = jnp.array([1.0, 15.0, 1.0]),
        log_interval: int = 50,
    ):
        """
        Optimize pulse parameters for the RY(w) gate to best approximate
        the unitary RY(w) gate.

        Uses gradient-based optimization to minimize the infidelity between
        the output state of the pulse-based RY(w) and that of the unitary RY(w).

        Args:
            steps (int): Number of optimization steps. Default: 1000.
            patience (int): Amount of epochs without improvement before early
                stopping. Default: 100.
            w (float): Rotation angle in radians with which to run the
                optimization. Default: π.
            init_pulse_params (jnp.ndarray): Initial pulse parameters (A, sigma)
                and time. Default: [1.0, 15.0, 1.0].
            log_interval (int): Frequency of logging the loss. Default: 50.

        Returns:
            tuple: Optimized parameters (jnp.ndarray), best loss and list of
            loss values during optimization.
        """
        self.current_gate = "RY"
        dev = qml.device("default.qubit", wires=1)

        @qml.qnode(dev, interface="jax")
        def pulse_circuit(w, pulse_params):
            Gates.RY(w, 0, pulse_params=pulse_params, gate_mode="pulse")
            return qml.state()

        @qml.qnode(dev)
        def unitary_circuit(w):
            qml.RY(w, wires=0)
            return qml.state()

        pulse_params, loss, losses = self._optimize_and_save(
            pulse_circuit,
            unitary_circuit(w),
            w,
            init_pulse_params,
            steps,
            patience,
            log_interval,
        )

        if self.make_plots:
            log.info("Plotting RY rotation...")
            self.plot_rotation(pulse_params)

        return pulse_params, loss, losses

    def optimize_RZ(self):
        """
        Plot the pulse level RZ rotation on the X basis.

        Note:
            No actual optimization is performed since the RZ gate
            does not have pulse parameters to optimize.

        Returns:
            tuple: (None, None)
        """
        self.current_gate = "RZ"
        if self.make_plots:
            log.info("Plotting RZ rotation...")
            self.plot_rotation([])

        return None, None

    def optimize_H(
        self,
        steps: int = 1000,
        patience: int = 100,
        init_pulse_params: jnp.ndarray = jnp.array([1.0, 15.0, 1.0]),
        log_interval: int = 50,
    ):
        """
        Optimize pulse parameters for the Hadamard (H) gate to best approximate
        the unitary H gate.

        Uses gradient-based optimization to minimize the infidelity between
        the pulse-based H output state and the target gate-based H state.

        Args:
            steps (int): Number of optimization steps. Default: 1000.
            patience (int): Amount of epochs without improvement before early
                stopping. Default: 100.
            init_pulse_params (jnp.ndarray): Initial pulse parameters (A, sigma, t).
                Default: [1.0, 15.0, 1.0].
            log_interval (int): Frequency of logging the loss. Default: 50.

        Returns:
            tuple: Optimized parameters (jnp.ndarray), best loss (float), and
            list of loss values during optimization.
        """
        self.current_gate = "H"
        dev = qml.device("default.qubit", wires=1)

        @qml.qnode(dev, interface="jax")
        def pulse_circuit(w, pulse_params):
            # ``w`` is unused: H has no rotation angle.
            Gates.H(0, pulse_params=pulse_params, gate_mode="pulse")
            return qml.state()

        @qml.qnode(dev)
        def unitary_circuit():
            qml.H(wires=0)
            return qml.state()

        pulse_params, loss, losses = self._optimize_and_save(
            pulse_circuit,
            unitary_circuit(),
            None,
            init_pulse_params,
            steps,
            patience,
            log_interval,
        )

        # Plotting the H rotation
        if self.make_plots:
            log.info("Plotting H rotation...")
            self.plot_rotation(pulse_params)

        return pulse_params, loss, losses

    def optimize_CZ(
        self,
        steps: int = 1000,
        patience: int = 100,
        init_pulse_params: jnp.ndarray = jnp.array([1.0]),
        log_interval: int = 50,
    ):
        """
        Optimize pulse parameters for the CZ gate to best approximate
        the unitary CZ gate.

        Uses gradient-based optimization to minimize the infidelity between
        the output state of the pulse-based H_c · H_t · CZ circuit and that of
        the unitary-based H_c · H_t · CZ.

        Args:
            steps (int): Number of optimization steps. Default: 1000.
            patience (int): Amount of epochs without improvement before early
                stopping. Default: 100.
            init_pulse_params (jnp.ndarray): Initial pulse duration. Default: [1.0].
            log_interval (int): Frequency of logging the loss. Default: 50.

        Returns:
            tuple: Optimized parameters (jnp.ndarray), best loss and list of
            loss values during optimization.
        """
        self.current_gate = "CZ"
        dev = qml.device("default.qubit", wires=2)

        @qml.qnode(dev, interface="jax")
        def pulse_circuit(w, pulse_params):
            qml.H(wires=0)
            qml.H(wires=1)
            Gates.CZ(wires=[0, 1], pulse_params=pulse_params, gate_mode="pulse")
            return qml.state()

        @qml.qnode(dev)
        def unitary_circuit():
            qml.H(wires=0)
            qml.H(wires=1)
            qml.CZ(wires=[0, 1])
            return qml.state()

        pulse_params, loss, losses = self._optimize_and_save(
            pulse_circuit,
            unitary_circuit(),
            None,
            init_pulse_params,
            steps,
            patience,
            log_interval,
        )

        if self.make_plots:
            log.info("Plotting CZ rotation...")
            self.plot_rotation(pulse_params)

        return pulse_params, loss, losses

    def optimize_CY(
        self,
        steps: int = 1000,
        patience: int = 100,
        init_pulse_params: jnp.ndarray = jnp.array(
            [0.5, 15.0, 10.0, 1.0, 15.0, 10.0, 1.0, 1.0, 0.5]
        ),
        log_interval: int = 50,
    ):
        """
        Optimize pulse parameters for the CY gate to best approximate the
        unitary CY gate.

        Uses gradient-based optimization to minimize the infidelity between
        the output state of the pulse-based H_c · CY circuit and that of the
        unitary-based H_c · CY.

        Args:
            steps (int): Number of optimization steps. Default: 1000.
            patience (int): Amount of epochs without improvement before early
                stopping. Default: 100.
            init_pulse_params (jnp.ndarray): Initial pulse parameters.
                Default: [0.5, 15.0, 10.0, 1.0, 15.0, 10.0, 1.0, 1.0, 0.5].
            log_interval (int): Frequency of logging the loss. Default: 50.

        Returns:
            tuple: Optimized parameters (jnp.ndarray), best loss and list of
            loss values during optimization.
        """
        self.current_gate = "CY"
        dev = qml.device("default.qubit", wires=2)

        @qml.qnode(dev, interface="jax")
        def pulse_circuit(w, pulse_params):
            qml.H(wires=0)
            Gates.CY(wires=[0, 1], pulse_params=pulse_params, gate_mode="pulse")
            return qml.state()

        @qml.qnode(dev)
        def unitary_circuit():
            qml.H(wires=0)
            qml.CY(wires=[0, 1])
            return qml.state()

        result = self._optimize_and_save(
            pulse_circuit,
            unitary_circuit(),
            None,
            init_pulse_params,
            steps,
            patience,
            log_interval,
        )

        if self.make_plots:
            warnings.warn("Plotting not implemented yet", UserWarning)

        return result

    def optimize_CX(
        self,
        steps: int = 1000,
        patience: int = 100,
        init_pulse_params: jnp.ndarray = jnp.array(
            [1.0, 15.0, 1.0, 1.0, 1.0, 15.0, 1.0]
        ),
        log_interval: int = 50,
    ):
        """
        Optimize pulse parameters for the CX gate to best approximate the
        unitary CX gate.

        Uses gradient-based optimization to minimize the infidelity between
        the output state of the pulse-based H_c · CX circuit and that of the
        unitary-based H_c · CX.

        Args:
            steps (int): Number of optimization steps. Default: 1000.
            patience (int): Amount of epochs without improvement before early
                stopping. Default: 100.
            init_pulse_params (jnp.ndarray): Initial pulse parameters.
                Default: [1.0, 15.0, 1.0, 1.0, 1.0, 15.0, 1.0].
            log_interval (int): Frequency of logging the loss. Default: 50.

        Returns:
            tuple: Optimized parameters (jnp.ndarray), best loss and list of
            loss values during optimization.
        """
        self.current_gate = "CX"
        dev = qml.device("default.qubit", wires=2)

        @qml.qnode(dev, interface="jax")
        def pulse_circuit(w, pulse_params):
            qml.H(wires=0)
            Gates.CX(wires=[0, 1], pulse_params=pulse_params, gate_mode="pulse")
            return qml.state()

        @qml.qnode(dev)
        def unitary_circuit():
            qml.H(wires=0)
            qml.CNOT(wires=[0, 1])
            return qml.state()

        pulse_params, loss, losses = self._optimize_and_save(
            pulse_circuit,
            unitary_circuit(),
            None,
            init_pulse_params,
            steps,
            patience,
            log_interval,
        )

        if self.make_plots:
            log.info("Plotting CX rotation...")
            self.plot_rotation(pulse_params)

        return pulse_params, loss, losses

    def optimize_CRX(
        self,
        steps: int = 1000,
        patience: int = 100,
        w: float = jnp.pi,
        init_pulse_params: jnp.ndarray = jnp.array(
            [10.0, 15.0, 1.0, 0.5, 1.0, 0.5, 10.0, 15.0, 1.0]
        ),
        log_interval: int = 50,
    ):
        """
        Optimize pulse parameters for the CRX(w) gate to best approximate
        the unitary CRX(w) gate.

        Uses gradient-based optimization to minimize the infidelity between
        the output state of the pulse-based H_c · CRX(w) circuit and that of
        the unitary-based H_c · CRX(w).

        Args:
            steps (int): Number of optimization steps. Default: 1000.
            patience (int): Amount of epochs without improvement before early
                stopping. Default: 100.
            w (float): Rotation angle. Default: π.
            init_pulse_params (jnp.ndarray): Initial pulse parameters.
            log_interval (int): Frequency of logging the loss. Default: 50.

        Returns:
            tuple: Optimized parameters (jnp.ndarray), best loss and list of
            loss values.
        """
        self.current_gate = "CRX"
        dev = qml.device("default.qubit", wires=2)

        @qml.qnode(dev, interface="jax")
        def pulse_circuit(w, pulse_params):
            qml.H(wires=0)
            Gates.CRX(w, wires=[0, 1], pulse_params=pulse_params, gate_mode="pulse")
            return qml.state()

        @qml.qnode(dev)
        def unitary_circuit(w):
            qml.H(wires=0)
            qml.CRX(w, wires=[0, 1])
            return qml.state()

        result = self._optimize_and_save(
            pulse_circuit,
            unitary_circuit(w),
            w,
            init_pulse_params,
            steps,
            patience,
            log_interval,
        )

        if self.make_plots:
            warnings.warn("Plotting not implemented yet", UserWarning)

        return result

    def optimize_CRY(
        self,
        steps: int = 1000,
        patience: int = 100,
        w: float = jnp.pi,
        init_pulse_params: jnp.ndarray = jnp.array(
            [10.0, 15.0, 1.0, 0.5, 1.0, 0.5, 10.0, 15.0, 1.0]
        ),
        log_interval: int = 50,
    ):
        """
        Optimize pulse parameters for the CRY(w) gate to best approximate
        the unitary CRY(w) gate.

        Uses gradient-based optimization to minimize the infidelity between
        the output state of the pulse-based H_c · CRY(w) circuit and that of
        the unitary-based H_c · CRY(w).

        Args:
            steps (int): Number of optimization steps. Default: 1000.
            patience (int): Amount of epochs without improvement before early
                stopping. Default: 100.
            w (float): Rotation angle. Default: π.
            init_pulse_params (jnp.ndarray): Initial pulse parameters.
            log_interval (int): Frequency of logging the loss. Default: 50.

        Returns:
            tuple: Optimized parameters (jnp.ndarray), best loss and list of
            loss values.
        """
        self.current_gate = "CRY"
        dev = qml.device("default.qubit", wires=2)

        @qml.qnode(dev, interface="jax")
        def pulse_circuit(w, pulse_params):
            qml.H(wires=0)
            Gates.CRY(w, wires=[0, 1], pulse_params=pulse_params, gate_mode="pulse")
            return qml.state()

        @qml.qnode(dev)
        def unitary_circuit(w):
            qml.H(wires=0)
            qml.CRY(w, wires=[0, 1])
            return qml.state()

        result = self._optimize_and_save(
            pulse_circuit,
            unitary_circuit(w),
            w,
            init_pulse_params,
            steps,
            patience,
            log_interval,
        )

        if self.make_plots:
            warnings.warn("Plotting not implemented yet", UserWarning)

        return result

    def optimize_CRZ(
        self,
        steps: int = 1000,
        patience: int = 100,
        w: float = jnp.pi,
        init_pulse_params: jnp.ndarray = jnp.array([0.5, 2.0, 0.5]),
        log_interval: int = 50,
    ):
        """
        Optimize pulse parameters for the CRZ(w) gate to best approximate
        the unitary CRZ(w) gate.

        Uses gradient-based optimization to minimize the infidelity between
        the output state of the pulse-based H_c · H_t · CRZ(w) circuit and
        that of the unitary-based H_c · H_t · CRZ(w).

        Args:
            steps (int): Number of optimization steps. Default: 1000.
            patience (int): Early stopping patience. Default: 100.
            w (float): Rotation angle. Default: π.
            init_pulse_params (jnp.ndarray): Initial pulse parameters.
                Default: [0.5, 2.0, 0.5].
            log_interval (int): Frequency of logging the loss. Default: 50.

        Returns:
            tuple: Optimized parameters (jnp.ndarray), best loss and list of
            loss values.
        """
        self.current_gate = "CRZ"
        dev = qml.device("default.qubit", wires=2)

        @qml.qnode(dev, interface="jax")
        def pulse_circuit(w, pulse_params):
            qml.H(wires=0)
            qml.H(wires=1)
            Gates.CRZ(w, wires=[0, 1], pulse_params=pulse_params, gate_mode="pulse")
            return qml.state()

        @qml.qnode(dev)
        def unitary_circuit(w):
            qml.H(wires=0)
            qml.H(wires=1)
            qml.CRZ(w, wires=[0, 1])
            return qml.state()

        result = self._optimize_and_save(
            pulse_circuit,
            unitary_circuit(w),
            w,
            init_pulse_params,
            steps,
            patience,
            log_interval,
        )

        if self.make_plots:
            warnings.warn("Plotting not implemented yet", UserWarning)

        return result
if __name__ == "__main__":
    import argparse

    # Keyword arguments for each gate's ``optimize_<GATE>`` call; these are
    # the settings previously kept as commented-out per-gate invocations.
    gate_kwargs = {
        "Rot": {},
        "RX": {"w": jnp.pi, "init_pulse_params": jnp.array([1.0, 15.0, 1.0])},
        "RY": {"w": jnp.pi, "init_pulse_params": jnp.array([1.0, 15.0, 1.0])},
        "RZ": {},
        "H": {"init_pulse_params": jnp.array([1.0, 15.0, 1.0])},
        "CZ": {"init_pulse_params": jnp.array([0.975]), "log_interval": 5},
        "CY": {"log_interval": 50},
        "CX": {"log_interval": 50},
        "CRX": {"w": jnp.pi},
        "CRY": {"w": jnp.pi},
        "CRZ": {},
    }

    parser = argparse.ArgumentParser(
        description="Optimize pulse parameters of pulse-level gates."
    )
    # No ``choices=`` here: combined with nargs="*" some Python versions
    # reject an empty argument list, so names are validated manually below.
    parser.add_argument(
        "gates",
        nargs="*",
        metavar="GATE",
        help=f"Gates to optimize, any of: {', '.join(gate_kwargs)}. "
        "Nothing is optimized when omitted.",
    )
    cli_args = parser.parse_args()
    unknown = [g for g in cli_args.gates if g not in gate_kwargs]
    if unknown:
        parser.error(f"unknown gate(s): {', '.join(unknown)}")

    # Without a handler configuration, log.info() output would be dropped.
    logging.basicConfig(level=logging.INFO)

    qoc = QOC(
        make_plots=False,
        fig_points=40,
        fig_dir="docs/figures",
        file_dir="qml_essentials",
    )

    for gate in cli_args.gates:
        method = getattr(qoc, f"optimize_{gate}")
        if gate == "RZ":
            # optimize_RZ has no pulse parameters; it only (optionally) plots
            # and returns (None, None).
            log.info("Plotting RZ gate rotation...")
            method(**gate_kwargs[gate])
            log.info("Plotted RZ gate rotation")
        else:
            log.info("Optimizing %s gate...", gate)
            optimized_pulse_params, best_loss, _ = method(**gate_kwargs[gate])
            log.info("Optimized parameters for %s: %s", gate, optimized_pulse_params)
            log.info("Best achieved fidelity: %s", 1 - best_loss)
        log.info("-" * 20)