Coverage for qml_essentials/qoc.py: 40%
351 statements
« prev ^ index » next coverage.py v7.9.2, created at 2026-02-20 14:03 +0000
« prev ^ index » next coverage.py v7.9.2, created at 2026-02-20 14:03 +0000
1# qa: disable
3import os
4import csv
5import jax
6from jax import numpy as jnp
7import optax
8import pennylane as qml
9from qml_essentials.gates import Gates, PulseInformation
10import argparse
11from functools import partial
12from typing import List, Callable, Union
13import logging
# Module-level logger; the __main__ entry point attaches a handler to it.
log = logging.getLogger(__name__)

# Switch JAX to 64-bit (double precision) floats for the whole process.
jax.config.update("jax_enable_x64", True)
class QOC:
    """Quantum optimal control (QOC) for pulse-level gates.

    Optimizes the pulse parameters of the pulse-mode gates in
    ``qml_essentials.gates.Gates``. The goal is for their output states to
    match those of the corresponding ideal PennyLane gates over a sweep of
    rotation angles.
    """

    def __init__(
        self,
        observable: Union[Callable, List[Callable]] = qml.state,
        n_steps: int = 1000,
        n_loops: int = 1,
        n_samples: int = 8,
        learning_rate: float = 0.01,
        log_interval: int = 50,
        skip_on_fidelity: bool = True,
        file_dir: Union[str, None] = os.path.dirname(os.path.abspath(__file__)),
    ):
        """
        Initialize Quantum Optimal Control with Pulse-level Gates.

        Args:
            observable (Callable or list of Callable): Measurement used during
                optimization. `optimize_all` requires `qml.state`.
            n_steps (int): Number of optimizer steps per optimization run.
            n_loops (int): Number of optimization loops in `optimize_all`.
            n_samples (int): Number of rotation angles sampled per cost
                evaluation. Must be >= 1.
            learning_rate (float): Learning rate for Adam with weight decay
                (`optax.adamw`).
            log_interval (int): Log the loss every `log_interval` steps
                (0 disables step logging).
            skip_on_fidelity (bool): If True, an existing entry in
                qoc_results.csv is kept when its fidelity is >= the new one.
                If False, it is overwritten anyway.
            file_dir (str or None): Directory that receives qoc_results.csv.
                Defaults to the directory containing this module.
                NOTE(review): presumably this is where the gates module reads
                optimized parameters from; confirm. Pass None to disable
                saving.

        Raises:
            ValueError: If `n_samples` < 1.
        """
        if n_samples < 1:
            raise ValueError(f"n_samples must be >= 1, got {n_samples}")

        # Angles including the 2*pi endpoint. Kept as a public attribute for
        # compatibility; it is not used within this class. The cost functions
        # use `_sample_angles()`.
        self.ws = jnp.linspace(0, 2 * jnp.pi, n_samples)

        self.observable = observable
        self.n_steps = n_steps
        self.n_loops = n_loops
        self.n_samples = n_samples
        self.learning_rate = learning_rate
        self.log_interval = log_interval
        self.skip_on_fidelity = skip_on_fidelity
        self.file_dir = file_dir

        self.current_gate = None

    def _sample_angles(self) -> jnp.ndarray:
        """Return `n_samples` angles evenly spaced on [0, 2*pi).

        `linspace(..., endpoint=False)` always yields exactly `n_samples`
        points. A float-step `arange` can yield one extra point due to
        rounding, which would not match the `/ n_samples` normalization.
        """
        return jnp.linspace(0, 2 * jnp.pi, self.n_samples, endpoint=False)

    def save_results(self, gate, fidelity, pulse_params):
        """
        Save the optimized pulse parameters and fidelity of a gate to
        `<file_dir>/qoc_results.csv`.

        Each row has the form: gate name, fidelity, pulse parameters...

        Args:
            gate (str): Name of the gate.
            fidelity (float): Fidelity of the optimized pulse parameters.
            pulse_params (iterable of float): Optimized pulse parameters.

        Notes:
            If the gate already exists in the file and the new fidelity is
            higher, the existing entry is overwritten. If the new fidelity is
            lower or equal, the existing entry is kept, unless
            `skip_on_fidelity=False`. Rows of other gates are preserved.
            Nothing is written when `file_dir` is None.
        """
        # getattr: tolerate instances built without __init__ (e.g. __new__).
        file_dir = getattr(self, "file_dir", None)
        if file_dir is None:
            log.warning(f"No file_dir configured, not saving results for {gate}")
            return

        os.makedirs(file_dir, exist_ok=True)
        filename = os.path.join(file_dir, "qoc_results.csv")

        # Read everything first: the file is rewritten in place below.
        rows = []
        if os.path.isfile(filename):
            with open(filename, mode="r", newline="") as f:
                rows = list(csv.reader(f))

        entry = [gate, float(fidelity)] + [float(p) for p in pulse_params]

        with open(filename, mode="w", newline="") as f:
            writer = csv.writer(f)
            found = False
            for row in rows:
                if not row:  # skip blank lines instead of failing on row[0]
                    continue
                # any other gate: copy through unchanged
                if row[0] != gate:
                    writer.writerow(row)
                    continue

                # gate already exists
                found = True
                if fidelity > float(row[1]):
                    writer.writerow(entry)
                    continue

                log.warning(
                    f"Pulse parameters for {gate} already exist with "
                    f"higher fidelity ({row[1]} >= {fidelity})"
                )
                if not self.skip_on_fidelity:
                    log.info("Overwriting parameters anyway")
                    writer.writerow(entry)
                else:
                    writer.writerow(row)

            # gate does not exist yet
            if not found:
                writer.writerow(entry)

    def cost_fn(self, pulse_params, pulse_qnode, target_qnode) -> float:
        """
        Cost function for single-gate QOC optimization.

        For each sampled angle w, the pulse and target states are compared
        through their overlap d = <target|pulse>:
          * infidelity term: 1 - |d|^2 (0 if the states match),
          * phase term: |arg(d)| / pi (0 if there is no relative phase).
        Both terms are averaged over the samples. The loss is their mean.

        Args:
            pulse_params (list or array): Parameters for the pulse-based gate.
            pulse_qnode (callable): QNode (w, pulse_params) -> state.
            target_qnode (callable): QNode (w) -> state of the ideal gate.

        Returns:
            float: Cost function value (0 for a perfect match).
        """
        abs_diff = 0
        phase_diff = 0
        for w in self._sample_angles():
            pulse_state = pulse_qnode(w, pulse_params)
            target_state = target_qnode(w)
            dot_prod = jnp.vdot(target_state, pulse_state)
            abs_diff += 1 - jnp.abs(dot_prod) ** 2
            phase_diff += jnp.abs(jnp.angle(dot_prod)) / jnp.pi

        abs_diff /= self.n_samples
        phase_diff /= self.n_samples

        return (abs_diff + phase_diff) / 2  # loss

    @staticmethod
    def _assign_leaf_params(leafs, flat_params):
        """Slice `flat_params` in leaf order and assign each slice to `leaf.params`."""
        idx = 0
        for leaf in leafs:
            nidx = idx + leaf.size
            leaf.params = flat_params[idx:nidx]
            idx = nidx

    def multi_objective_cost_fn(
        self, pulse_params, pulse_qnodes, target_qnodes, leafs
    ) -> float:
        """
        Joint cost function for optimizing several gates that share leaves.

        The flat `pulse_params` vector is split across `leafs` (by each
        leaf's `size`) and written to `leaf.params`. Every pulse QNode is
        then called with `pulse_params=None`.
        NOTE(review): this presumes the pulse gates read their parameters
        from these leaves when given None; confirm in the gates module.

        The phase term is currently disabled (always 0). The returned loss is
        therefore half of the mean infidelity 1 - |<target|pulse>|^2,
        averaged over the samples and the QNode pairs.

        Args:
            pulse_params (array): Flat vector of all leaf parameters.
            pulse_qnodes (list of callable): Pulse-based QNodes (w, params).
            target_qnodes (list of callable): Matching ideal QNodes (w).
            leafs (list): Leaf parameter objects with `size` and `params`.

        Returns:
            float: Cost function value.

        Raises:
            ValueError: If `pulse_qnodes` is empty.
        """
        if not pulse_qnodes:
            raise ValueError("multi_objective_cost_fn needs at least one qnode")

        self._assign_leaf_params(leafs, pulse_params)

        abs_diff = 0
        phase_diff = 0  # phase term disabled, see docstring
        for pulse_qnode, target_qnode in zip(pulse_qnodes, target_qnodes):
            for w in self._sample_angles():
                pulse_state = pulse_qnode(w, None)
                target_state = target_qnode(w)
                dot_prod = jnp.vdot(target_state, pulse_state)
                abs_diff += 1 - jnp.abs(dot_prod) ** 2

        abs_diff /= self.n_samples
        phase_diff /= self.n_samples

        n_nodes = len(pulse_qnodes)
        return ((abs_diff + phase_diff) / 2) / n_nodes  # loss

    def run_optimization(
        self,
        cost,
        params,
        *args,
    ) -> tuple[jnp.ndarray, List]:
        """
        Run the optimization process with `optax.adamw`.

        Args:
            cost (callable): Cost function `cost(params, *args)`.
            params (list or array): Initial parameters.
            *args: Extra positional arguments for the cost function.

        Returns:
            tuple[jnp.ndarray, List]: The parameters that achieved the lowest
            evaluated loss, and the loss history. The history starts with the
            initial loss, followed by one entry per step.
        """
        optimizer = optax.adamw(self.learning_rate)
        opt_state = optimizer.init(params)

        best_loss = cost(params, *args).item()
        loss_history = [best_loss]
        best_pulse_params = params

        @jax.jit
        def opt_step(params, opt_state, *args):
            loss, grads = jax.value_and_grad(cost)(params, *args)
            updates, opt_state = optimizer.update(grads, opt_state, params)
            params = optax.apply_updates(params, updates)
            return params, opt_state, loss

        for step in range(self.n_steps):
            if self.log_interval and step % self.log_interval == 0:
                log.info(f"Step {step}/{self.n_steps}, Loss: {loss_history[-1]:.3e}")

            # `loss` returned by opt_step is evaluated at the params passed
            # IN, before the update. Remember them so that the best loss is
            # paired with the params that actually produced it.
            evaluated_params = params
            params, opt_state, loss = opt_step(params, opt_state, *args)
            loss = loss.item()

            if loss < best_loss:
                log.debug(f"Best set of params found at step {step}")
                best_loss = loss
                best_pulse_params = evaluated_params

            loss_history.append(loss)

        return best_pulse_params, loss_history

    def optimize(self, simulator, wires):
        """
        Build a decorator that optimizes the pulse parameters of one gate.

        Args:
            simulator (str): PennyLane device name, e.g. "default.qubit".
            wires (int): Number of wires of the device.

        Returns:
            callable: A decorator that takes a `create_<Gate>` function and
            returns `wrapper(init_pulse_params=None)`.
        """

        def decorator(create_circuits):
            def wrapper(init_pulse_params: jnp.ndarray = None):
                """
                Optimize the pulse parameters of the gate built by
                `create_circuits` and save the result.

                Args:
                    init_pulse_params (array): Initial pulse parameters. If
                        None, the gate's current parameters from
                        `PulseInformation` are used.

                Returns:
                    tuple: Optimized pulse parameters and list of loss values
                    at each iteration.
                """
                dev = qml.device(simulator, wires=wires)
                # The device is passed positionally, so the create_<Gate>
                # methods receive it as `init_pulse_params` (which they ignore).
                pulse_circuit, target_circuit = create_circuits(dev)

                pulse_qnode = qml.QNode(pulse_circuit, dev, interface="jax")
                target_qnode = qml.QNode(target_circuit, dev, interface="jax")

                # "create_RX" -> "RX"
                gate_name = create_circuits.__name__.split("_")[1]

                if init_pulse_params is None:
                    log.warning(
                        f"Using initial pulse parameters for {gate_name} "
                        "from `ansaetze.py`"
                    )
                    init_pulse_params = PulseInformation.gate_by_name(gate_name).params
                log.debug(
                    f"Initial pulse parameters for {gate_name}: {init_pulse_params}"
                )

                # Optimizing
                pulse_params, loss_history = self.run_optimization(
                    partial(
                        self.cost_fn,
                        pulse_qnode=pulse_qnode,
                        target_qnode=target_qnode,
                    ),
                    params=init_pulse_params,
                )

                self.save_results(
                    gate=gate_name,
                    fidelity=1 - min(loss_history),
                    pulse_params=pulse_params,
                )

                return pulse_params, loss_history

            return wrapper

        return decorator

    def optimize_multi_objective(self, simulator, wires):
        """
        Build a decorator that jointly optimizes several gates whose pulse
        parameters are stored in (possibly shared) leaf nodes.

        Args:
            simulator (str): PennyLane device name, e.g. "default.qubit".
            wires (int): Number of wires of the device.

        Returns:
            callable: A decorator that takes a list of `create_<Gate>`
            functions and returns `wrapper(init_pulse_params=None)`.
        """

        def decorator(create_circuits_array):
            def wrapper(init_pulse_params: jnp.ndarray = None):
                """
                Jointly optimize the leaf parameters of all gates built by
                `create_circuits_array`. The optimized parameters of every
                leaf are then saved.

                Args:
                    init_pulse_params (array): Currently unused. The initial
                        values are read from the gates' leaf nodes.

                Returns:
                    tuple: Optimized flat parameter vector and list of loss
                    values at each iteration.

                Raises:
                    ValueError: If no leaf parameters are found.
                """
                dev = qml.device(simulator, wires=wires)

                pulse_qnodes = []
                target_qnodes = []
                leafs = []
                for create_circuits in create_circuits_array:
                    pulse_circuit, target_circuit = create_circuits(dev)

                    pulse_qnodes.append(qml.QNode(pulse_circuit, dev, interface="jax"))
                    target_qnodes.append(
                        qml.QNode(target_circuit, dev, interface="jax")
                    )

                    gate_name = create_circuits.__name__.split("_")[1]
                    # `.leafs` is treated as a collection of leaf parameter
                    # objects (assumed; confirm against PulseInformation).
                    leafs.extend(PulseInformation.gate_by_name(gate_name).leafs)

                # Gates may share leaves. Deduplicate while keeping the
                # first-seen order, so the flat parameter layout is stable.
                leafs = list(dict.fromkeys(leafs))
                if not leafs:
                    raise ValueError("No pulse parameters found for the given gates")

                # Flatten each leaf to 1-D first: concatenate rejects 0-d arrays.
                params = jnp.concatenate(
                    [jnp.ravel(jnp.asarray(leaf.params)) for leaf in leafs]
                )

                # The cost function writes (traced) values into leaf.params.
                # Restore the original objects afterwards, so no jit tracers
                # leak into the shared gate state.
                original_leaf_params = [leaf.params for leaf in leafs]
                try:
                    pulse_params, loss_history = self.run_optimization(
                        partial(
                            self.multi_objective_cost_fn,
                            pulse_qnodes=pulse_qnodes,
                            target_qnodes=target_qnodes,
                            leafs=leafs,
                        ),
                        params=params,
                    )
                finally:
                    for leaf, original in zip(leafs, original_leaf_params):
                        leaf.params = original

                # Saving the optimized parameters of every leaf
                idx = 0
                for leaf in leafs:
                    nidx = idx + leaf.size
                    self.save_results(
                        gate=leaf.name,
                        fidelity=1 - min(loss_history),
                        pulse_params=pulse_params[idx:nidx],
                    )
                    idx = nidx

                return pulse_params, loss_history

            return wrapper

        return decorator

    def create_RX(self, init_pulse_params: jnp.ndarray = None):
        """Return (pulse, target) circuits comparing pulse-mode RX with qml.RX."""

        def pulse_circuit(w, pulse_params):
            Gates.RX(w, 0, pulse_params=pulse_params, gate_mode="pulse")
            return qml.state()

        def target_circuit(w):
            qml.RX(w, wires=0)
            return qml.state()

        return pulse_circuit, target_circuit

    def create_RY(self, init_pulse_params: jnp.ndarray = None):
        """Return (pulse, target) circuits comparing pulse-mode RY with qml.RY."""

        def pulse_circuit(w, pulse_params):
            Gates.RY(w, 0, pulse_params=pulse_params, gate_mode="pulse")
            return qml.state()

        def target_circuit(w):
            qml.RY(w, wires=0)
            return qml.state()

        return pulse_circuit, target_circuit

    def create_RZ(self, init_pulse_params: jnp.ndarray = None):
        """Return (pulse, target) circuits comparing pulse-mode RZ with qml.RZ.

        The gate is sandwiched between Hadamards. On |0> alone, RZ would only
        add a global phase.
        """

        def pulse_circuit(w, pulse_params):
            qml.H(wires=0)
            Gates.RZ(w, 0, pulse_params=pulse_params, gate_mode="pulse")
            qml.H(wires=0)
            return qml.state()

        def target_circuit(w):
            qml.H(wires=0)
            qml.RZ(w, wires=0)
            qml.H(wires=0)
            return qml.state()

        return pulse_circuit, target_circuit

    def create_H(self, init_pulse_params: jnp.ndarray = None):
        """Return (pulse, target) circuits comparing pulse-mode H with qml.H,
        applied after an RY(w) state preparation."""

        def pulse_circuit(w, pulse_params):
            qml.RY(w, wires=0)
            Gates.H(0, pulse_params=pulse_params, gate_mode="pulse")
            return qml.state()

        def target_circuit(w):
            qml.RY(w, wires=0)
            qml.H(wires=0)
            return qml.state()

        return pulse_circuit, target_circuit

    def create_Rot(self, init_pulse_params: jnp.ndarray = None):
        """Return (pulse, target) circuits comparing pulse-mode Rot(w, 2w, 3w)
        with qml.Rot, applied after a Hadamard."""

        def pulse_circuit(w, pulse_params):
            qml.H(wires=0)
            Gates.Rot(w, w * 2, w * 3, 0, pulse_params=pulse_params, gate_mode="pulse")
            return qml.state()

        def target_circuit(w):
            qml.H(wires=0)
            qml.Rot(w, w * 2, w * 3, wires=0)
            return qml.state()

        return pulse_circuit, target_circuit

    def create_CX(self, init_pulse_params: jnp.ndarray = None):
        """Return (pulse, target) circuits comparing pulse-mode CX with qml.CNOT."""

        def pulse_circuit(w, pulse_params):
            qml.RY(w, wires=0)
            qml.H(wires=1)
            Gates.CX(wires=[0, 1], pulse_params=pulse_params, gate_mode="pulse")
            return qml.state()

        def target_circuit(w):
            qml.RY(w, wires=0)
            qml.H(wires=1)
            qml.CNOT(wires=[0, 1])
            return qml.state()

        return pulse_circuit, target_circuit

    def create_CY(self, init_pulse_params: jnp.ndarray = None):
        """Return (pulse, target) circuits comparing pulse-mode CY with qml.CY."""

        def pulse_circuit(w, pulse_params):
            qml.RX(w, wires=0)
            qml.H(wires=1)
            Gates.CY(wires=[0, 1], pulse_params=pulse_params, gate_mode="pulse")
            return qml.state()

        def target_circuit(w):
            qml.RX(w, wires=0)
            qml.H(wires=1)
            qml.CY(wires=[0, 1])
            return qml.state()

        return pulse_circuit, target_circuit

    def create_CZ(self, init_pulse_params: jnp.ndarray = None):
        """Return (pulse, target) circuits comparing pulse-mode CZ with qml.CZ."""

        def pulse_circuit(w, pulse_params):
            qml.RY(w, wires=0)
            qml.H(wires=1)
            Gates.CZ(wires=[0, 1], pulse_params=pulse_params, gate_mode="pulse")
            return qml.state()

        def target_circuit(w):
            qml.RY(w, wires=0)
            qml.H(wires=1)
            qml.CZ(wires=[0, 1])
            return qml.state()

        return pulse_circuit, target_circuit

    def create_CRX(self, init_pulse_params: jnp.ndarray = None):
        """Return (pulse, target) circuits comparing pulse-mode CRX with qml.CRX."""

        def pulse_circuit(w, pulse_params):
            qml.H(wires=0)
            Gates.CRX(w, wires=[0, 1], pulse_params=pulse_params, gate_mode="pulse")
            return qml.state()

        def target_circuit(w):
            qml.H(wires=0)
            qml.CRX(w, wires=[0, 1])
            return qml.state()

        return pulse_circuit, target_circuit

    def create_CRY(self, init_pulse_params: jnp.ndarray = None):
        """Return (pulse, target) circuits comparing pulse-mode CRY with qml.CRY."""

        def pulse_circuit(w, pulse_params):
            qml.H(wires=0)
            Gates.CRY(w, wires=[0, 1], pulse_params=pulse_params, gate_mode="pulse")
            return qml.state()

        def target_circuit(w):
            qml.H(wires=0)
            qml.CRY(w, wires=[0, 1])
            return qml.state()

        return pulse_circuit, target_circuit

    def create_CRZ(self, init_pulse_params: jnp.ndarray = None):
        """Return (pulse, target) circuits comparing pulse-mode CRZ with qml.CRZ."""

        def pulse_circuit(w, pulse_params):
            qml.H(wires=0)
            qml.H(wires=1)
            Gates.CRZ(w, wires=[0, 1], pulse_params=pulse_params, gate_mode="pulse")
            return qml.state()

        def target_circuit(w):
            qml.H(wires=0)
            qml.H(wires=1)
            qml.CRZ(w, wires=[0, 1])
            return qml.state()

        return pulse_circuit, target_circuit

    @staticmethod
    def _parse_gate_selection(sel_gates) -> set:
        """Normalize a gate selection into a set of exact gate names.

        Accepts an iterable of names, or a string such as "RX",
        "RX,CZ", "RX CZ" or the str() of a list ("['RX', 'CZ']").
        Exact tokens are needed because a substring test would let "CRX"
        select "RX" as well.
        """
        if isinstance(sel_gates, str):
            text = sel_gates
        else:
            text = " ".join(str(gate) for gate in sel_gates)
        for char in "[]'\",":
            text = text.replace(char, " ")
        return set(text.split())

    def optimize_all(self, sel_gates, make_log):
        """
        Optimize every selected gate for `n_loops` loops.

        Args:
            sel_gates (str or iterable of str): Gate names to optimize
                (RX, RY, RZ, H, Rot, CX, CZ, CY, CRX, CRY, CRZ), or "all".
            make_log (bool): If True, write the concatenated loss histories
                per gate to "qml_essentials/qoc_logs.csv" (relative to cwd).

        Each loop after the first starts from the best parameters found for
        that gate in the previous loop.
        """
        assert (
            self.observable == qml.state
        ), "Observable must be qml.state when doing optimization"

        selected = self._parse_gate_selection(sel_gates)
        run_all = "all" in selected

        optimize_1q = self.optimize("default.qubit", wires=1)
        optimize_2q = self.optimize("default.qubit", wires=2)

        # (name, optimizer decorator, circuit factory), in optimization order
        gate_table = [
            ("RX", optimize_1q, self.create_RX),
            ("RY", optimize_1q, self.create_RY),
            ("RZ", optimize_1q, self.create_RZ),
            ("H", optimize_1q, self.create_H),
            ("Rot", optimize_1q, self.create_Rot),
            ("CX", optimize_2q, self.create_CX),
            ("CZ", optimize_2q, self.create_CZ),
            ("CY", optimize_2q, self.create_CY),
            ("CRX", optimize_2q, self.create_CRX),
            ("CRY", optimize_2q, self.create_CRY),
            ("CRZ", optimize_2q, self.create_CRZ),
        ]

        log_history = {}
        # Best params per gate from the previous loop. Without this warm
        # start, every loop would repeat the exact same optimization.
        warm_start = {}
        for loop in range(self.n_loops):
            log.info(f"Optimization loop {loop+1} of {self.n_loops}")

            for name, optimize, create_circuits in gate_table:
                if not (run_all or name in selected):
                    continue
                log.info(f"Optimizing {name} gate...")
                optimized_pulse_params, loss_history = optimize(create_circuits)(
                    warm_start.get(name)
                )
                warm_start[name] = optimized_pulse_params
                log.info(f"Optimized parameters for {name}: {optimized_pulse_params}")
                log.info(f"Best achieved fidelity: {(1 - min(loss_history))*100:.5f}%")
                log_history[name] = log_history.get(name, []) + loss_history

        if make_log:
            # write log history to file; newline="" as required by csv
            with open("qml_essentials/qoc_logs.csv", "w", newline="") as f:
                writer = csv.writer(f)
                writer.writerow(log_history.keys())
                writer.writerows(zip(*log_history.values()))
if __name__ == "__main__":

    def _str2bool(value: str) -> bool:
        """Parse a textual boolean; bool("False") would be True."""
        normalized = value.strip().lower()
        if normalized in ("1", "true", "t", "yes", "y", "on"):
            return True
        if normalized in ("0", "false", "f", "no", "n", "off"):
            return False
        raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}")

    # argparse the selected gates
    parser = argparse.ArgumentParser(
        description="Optimize the pulse parameters of selected gates."
    )
    parser.add_argument(
        "--gates",
        type=str,
        nargs="+",
        default=["RX", "RY", "RZ", "CZ"],
        help='Gate names (space- or comma-separated), or "all".',
    )
    parser.add_argument(
        "--log",
        type=_str2bool,
        default=True,
        help="Write loss histories to qml_essentials/qoc_logs.csv (true/false).",
    )
    # TODO: add more arguments that take e.g. n_steps etc for initialization

    args = parser.parse_args()
    # Split "RX,CZ" style tokens so every entry is an exact gate name.
    sel_gates = [
        gate for token in args.gates for gate in token.replace(",", " ").split()
    ]
    make_log = args.log

    log.setLevel(logging.INFO)
    log.addHandler(logging.StreamHandler())

    qoc = QOC(
        observable=qml.state,
    )

    qoc.optimize_all(sel_gates=sel_gates, make_log=make_log)