Coverage for qml_essentials/qoc.py: 68%

380 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-10-02 13:10 +0000

1import os 

2import csv 

3import jax 

4from jax import numpy as jnp 

5import optax 

6import pennylane as qml 

7from qml_essentials.ansaetze import Gates 

8import matplotlib.pyplot as plt 

9import warnings 

10import logging 

11 

# Module logger; no handlers or levels are configured in this module.
log = logging.getLogger(__name__)

# Turn on 64-bit (double precision) floats in JAX for all computations.
jax.config.update("jax_enable_x64", True)

14 

15 

class QOC:
    """
    Quantum optimal control for the pulse-level gates in ``Gates``.

    Each ``optimize_<gate>`` method fits the pulse parameters of the
    pulse-level implementation of a gate so that its output state matches the
    state produced by the corresponding unitary (PennyLane) gate. The result is
    appended to a CSV file, and the rotation is optionally plotted.

    NOTE: the optimize_*() methods differ only in their circuits; the shared
    optimize / save / plot pipeline lives in ``_optimize_gate``.
    """

    def __init__(
        self,
        make_plots=False,
        file_dir="qoc/results",
        fig_dir="qoc/figures",
        fig_points=70,
    ):
        """
        Initialize Quantum Optimal Control with Pulse-level Gates.

        Args:
            make_plots (bool): Whether to generate and save plots.
            file_dir (str | None): Directory to save optimization results.
                If None, results are not saved.
            fig_dir (str): Directory to save figures.
            fig_points (int): Number of rotation angles in [0, 2π] used when
                plotting rotations.
        """
        # Rotation angles swept by plot_rotation().
        self.ws = jnp.linspace(0, 2 * jnp.pi, fig_points)

        self.make_plots = make_plots
        self.file_dir = file_dir
        self.fig_dir = fig_dir

        # Name of the gate being processed; set by each optimize_*() method and
        # read by get_circuits(), plot_rotation() and save_results().
        self.current_gate = None

    def get_circuits(self):
        """
        Return pulse- and unitary-based circuits for the current gate.

        Both circuits take a rotation angle ``w``; the pulse circuit also
        takes the pulse parameters. Both return the X, Y and Z expectation
        values of the measured qubit: qubit 0 for single-qubit gates and
        qubit 1 for CZ/CX.

        Returns:
            tuple: (pulse_circuit, unitary_circuit, operation_str), where
            ``operation_str`` is the label used in plot titles.

        Raises:
            ValueError: If no plotting circuits are defined for
                ``self.current_gate`` (only RX, RY, RZ, H, CZ and CX are).
        """
        # Bind the gate now so the circuits don't depend on later changes
        # to self.current_gate.
        gate = self.current_gate
        if gate not in ("RX", "RY", "RZ", "H", "CZ", "CX"):
            raise ValueError(
                f"No plotting circuits defined for gate {gate!r}; "
                "supported gates are RX, RY, RZ, H, CZ and CX."
            )

        def measure(wire):
            # Expectation values in the X, Y and Z bases of `wire`.
            return [
                qml.expval(qml.PauliX(wire)),
                qml.expval(qml.PauliY(wire)),
                qml.expval(qml.PauliZ(wire)),
            ]

        n_wires = 2 if gate in ("CZ", "CX") else 1
        dev = qml.device("default.qubit", wires=n_wires)

        if gate in ("RX", "RY"):

            @qml.qnode(dev, interface="jax")
            def pulse_circuit(w, pulse_params=None):
                getattr(Gates, gate)(
                    w, 0, pulse_params=pulse_params, gate_mode="pulse"
                )
                return measure(0)

            @qml.qnode(dev)
            def unitary_circuit(w):
                getattr(qml, gate)(w, wires=0)
                return measure(0)

            operation = f"{gate}(w)"

        elif gate == "RZ":
            # RZ on |0> only changes the global phase, so RX(π/2) is applied
            # first to move the state off the Z axis and make the rotation
            # visible. The pulse-level RZ takes no pulse parameters.

            @qml.qnode(dev, interface="jax")
            def pulse_circuit(w, *_):
                qml.RX(jnp.pi / 2, wires=0)
                getattr(Gates, gate)(w, 0, gate_mode="pulse")
                return measure(0)

            @qml.qnode(dev)
            def unitary_circuit(w):
                qml.RX(jnp.pi / 2, wires=0)
                getattr(qml, gate)(w, wires=0)
                return measure(0)

            operation = f"RX(π / 2)·{gate}(w)"

        elif gate == "H":

            @qml.qnode(dev, interface="jax")
            def pulse_circuit(w, pulse_params=None):
                qml.RX(w, wires=0)
                getattr(Gates, gate)(
                    0, pulse_params=pulse_params, gate_mode="pulse"
                )
                return measure(0)

            @qml.qnode(dev)
            def unitary_circuit(w):
                qml.RX(w, wires=0)
                getattr(qml, gate)(wires=0)
                return measure(0)

            operation = f"RX(w)·{gate}"

        elif gate == "CZ":

            @qml.qnode(dev, interface="jax")
            def pulse_circuit(w, pulse_params=None):
                qml.RX(w, wires=0)
                qml.RX(w, wires=1)
                Gates.CZ([0, 1], pulse_params=pulse_params, gate_mode="pulse")
                qml.RX(-w, wires=1)
                qml.RX(-w, wires=0)
                return measure(1)

            @qml.qnode(dev)
            def unitary_circuit(w):
                qml.RX(w, wires=0)
                qml.RX(w, wires=1)
                qml.CZ(wires=[0, 1])
                qml.RX(-w, wires=1)
                qml.RX(-w, wires=0)
                return measure(1)

            operation = r"$RX_0(w)$·$RX_1(w)$·$CZ_{0, 1}$·$RX_1(-w)$·$RX_0(-w)$"

        else:  # gate == "CX"

            @qml.qnode(dev, interface="jax")
            def pulse_circuit(w, pulse_params=None):
                qml.RX(w, wires=0)
                Gates.CX([0, 1], pulse_params=pulse_params, gate_mode="pulse")
                return measure(1)

            @qml.qnode(dev)
            def unitary_circuit(w):
                qml.RX(w, wires=0)
                qml.CNOT(wires=[0, 1])
                return measure(1)

            operation = r"$RX_0(w)$·$CX_{0,1}$"

        return pulse_circuit, unitary_circuit, operation

    # TODO: Update method for new gates (Rot, CY, CRZ, CRY, CRX)
    def plot_rotation(self, pulse_params):
        """
        Plot expectation values of pulse- and unitary-based circuits for the
        current gate as a function of rotation angle.

        One subplot is drawn per measurement basis (X, Y, Z). The figure is
        saved to ``<fig_dir>/qoc_<gate>(w).png``.

        Args:
            pulse_params: pulse parameters of pulse level gate.

        Raises:
            ValueError: If the current gate has no plotting circuits
                (see ``get_circuits``).
        """
        pulse_circuit, unitary_circuit, operation = self.get_circuits()

        # Shape (len(self.ws), 3): one row per angle, one column per basis.
        pulse_expvals = jnp.array([pulse_circuit(w, pulse_params) for w in self.ws])
        ideal_expvals = jnp.array([unitary_circuit(w) for w in self.ws])

        xticks = [0, jnp.pi / 2, jnp.pi, 3 * jnp.pi / 2, 2 * jnp.pi]
        xtick_labels = ["0", "π/2", "π", "3π/2", "2π"]

        fig, axs = plt.subplots(3, 1, figsize=(6, 12))
        for i, (ax, basis) in enumerate(zip(axs, ["X", "Y", "Z"])):
            ax.plot(self.ws, pulse_expvals[:, i], label="Pulse-based")
            ax.plot(self.ws, ideal_expvals[:, i], "--", label="Unitary-based")
            ax.set_xlabel("Rotation angle w (rad)")
            ax.set_ylabel(f"{basis}")
            ax.set_title(f"{operation} in {basis}-basis")
            ax.grid(True)
            ax.legend()
            ax.set_xticks(xticks)
            ax.set_xticklabels(xtick_labels)

        fig.tight_layout()
        os.makedirs(self.fig_dir, exist_ok=True)
        fig.savefig(os.path.join(self.fig_dir, f"qoc_{self.current_gate}(w).png"))
        # Close this specific figure so repeated calls don't accumulate figures.
        plt.close(fig)

    def save_results(self, opt_pulse_params):
        """
        Append optimized pulse parameters to ``<file_dir>/qoc_results.csv``.

        A header row (``gate, param_1, ..., param_n``) is written only when the
        file does not exist yet; each call appends one row with the current
        gate name followed by the parameters as floats. Nothing is written when
        ``self.file_dir`` is None.

        Args:
            opt_pulse_params (list or array): Optimized parameters to save.
        """
        if self.file_dir is None:
            # Saving disabled. Previously this path raised NameError because
            # `filename` was never assigned.
            return

        header = ["gate"] + [f"param_{i + 1}" for i in range(len(opt_pulse_params))]
        os.makedirs(self.file_dir, exist_ok=True)
        filename = os.path.join(self.file_dir, "qoc_results.csv")
        file_exists = os.path.isfile(filename)

        with open(filename, mode="a", newline="") as f:
            writer = csv.writer(f)
            if not file_exists:
                writer.writerow(header)
            writer.writerow([self.current_gate] + list(map(float, opt_pulse_params)))

    def loss_fn(self, state, target_state):
        """
        Compute infidelity between two quantum states.

        Fidelity is ``|<target_state|state>|^2``; this assumes both state
        vectors are normalized.

        Args:
            state (array): Output state from pulse circuit.
            target_state (array): Target state from unitary circuit.

        Returns:
            float: Infidelity (1 - fidelity).
        """
        fidelity = jnp.abs(jnp.vdot(target_state, state)) ** 2
        return 1 - fidelity

    def cost_fn(self, pulse_params, circuit, w, target_state):
        """
        Compute cost for optimization by evaluating circuit and loss.

        Args:
            pulse_params: pulse parameters of pulse level gate.
            circuit (callable): QNode circuit accepting (w, pulse_params).
            w (float | tuple | None): Rotation angle(s) passed to ``circuit``.
            target_state (array): Target quantum state.

        Returns:
            float: Infidelity between the circuit output and ``target_state``.
        """
        state = circuit(w, pulse_params)
        return self.loss_fn(state, target_state)

    def run_optimization(
        self,
        cost,
        init_pulse_params,
        steps,
        patience,
        log_interval,
        *args,
        learning_rate=0.1,
    ):
        """
        Run Adam-based optimization on the given cost function.

        Args:
            cost (callable): Cost function to minimize, called as
                ``cost(pulse_params, *args)``.
            init_pulse_params (array): Initial parameters.
            steps (int): Maximum number of optimization steps.
            patience (int): Stop after this many consecutive steps without a
                new best loss.
            log_interval (int): Log the loss every ``log_interval`` steps;
                a falsy value disables logging.
            *args: Extra args for cost.
            learning_rate (float): Adam learning rate. Default: 0.1.

        Returns:
            tuple: (best parameters, best loss, list of per-step loss values).
            The best loss is the cost evaluated at the returned parameters.
        """
        optimizer = optax.adam(learning_rate)
        opt_state = optimizer.init(init_pulse_params)
        pulse_params = init_pulse_params
        losses = []

        best_loss = float("inf")
        best_pulse_params = pulse_params
        no_improve_counter = 0

        @jax.jit
        def opt_step(pulse_params, opt_state, *args):
            loss, grads = jax.value_and_grad(cost)(pulse_params, *args)
            updates, opt_state = optimizer.update(grads, opt_state)
            new_pulse_params = optax.apply_updates(pulse_params, updates)
            return new_pulse_params, opt_state, loss

        for step in range(steps):
            new_pulse_params, opt_state, loss = opt_step(
                pulse_params, opt_state, *args
            )
            losses.append(loss)

            # `loss` was evaluated at `pulse_params` (before this step's
            # update), so those are the parameters that achieved it.
            if loss < best_loss:
                best_loss = loss
                best_pulse_params = pulse_params
                no_improve_counter = 0
            else:
                no_improve_counter += 1

            pulse_params = new_pulse_params

            if log_interval and (step + 1) % log_interval == 0:
                log.info(f"Step {step + 1}/{steps}, Loss: {loss:.2e}")

            if no_improve_counter >= patience:
                log.info(f"Early stopping at step {step + 1} due to no improvement.")
                break

        return best_pulse_params, best_loss, losses

    def _optimize_gate(
        self,
        pulse_circuit,
        target,
        w,
        init_pulse_params,
        steps,
        patience,
        log_interval,
        plot=True,
    ):
        """
        Shared pipeline of the optimize_*() methods: optimize, save, plot.

        Args:
            pulse_circuit (callable): QNode accepting (w, pulse_params) and
                returning a state.
            target (array): Target state from the unitary circuit.
            w: Rotation angle(s) passed to ``pulse_circuit`` (None if unused).
            init_pulse_params (jnp.ndarray): Initial pulse parameters.
            steps (int): Maximum number of optimization steps.
            patience (int): Early stopping patience.
            log_interval (int): Frequency of logging the loss.
            plot (bool): Whether ``plot_rotation`` supports the current gate.
                If False and plotting is enabled, a "not implemented" warning
                is issued instead.

        Returns:
            tuple: (optimized parameters, best loss, list of loss values).
        """

        def cost(pulse_params):
            return self.cost_fn(
                pulse_params, circuit=pulse_circuit, w=w, target_state=target
            )

        pulse_params, loss, losses = self.run_optimization(
            cost, init_pulse_params, steps, patience, log_interval
        )

        self.save_results(pulse_params)

        if self.make_plots:
            if plot:
                log.info(f"Plotting {self.current_gate} rotation...")
                self.plot_rotation(pulse_params)
            else:
                warnings.warn("Plotting not implemented yet", UserWarning)

        return pulse_params, loss, losses

    def optimize_Rot(
        self,
        steps: int = 1000,
        patience: int = 100,
        phi: float = jnp.pi / 2,
        theta: float = jnp.pi / 2,
        omega: float = jnp.pi / 2,
        init_pulse_params: jnp.ndarray = jnp.array([0.5, 1.0, 15.0, 1.0, 0.5]),
        log_interval: int = 50,
    ):
        """
        Optimize pulse parameters for the Rot(phi, theta, omega) gate.

        Uses gradient-based optimization to minimize the infidelity between
        the output state of the pulse-based Rot(phi, theta, omega) circuit and
        that of the unitary-based Rot(phi, theta, omega).

        Args:
            steps (int): Number of optimization steps. Default: 1000.
            patience (int): Patience for early stopping. Default: 100.
            phi, theta, omega (float): Rotation angles for the Rot gate.
                Default: π / 2 for all three.
            init_pulse_params (jnp.ndarray): Initial pulse parameters.
                Default: [0.5, 1.0, 15.0, 1.0, 0.5].
            log_interval (int): Frequency of logging the loss. Default: 50.

        Returns:
            tuple: Optimized parameters (jnp.ndarray), best loss, and list of
            loss values during optimization.
        """
        self.current_gate = "Rot"
        w = (phi, theta, omega)

        dev = qml.device("default.qubit", wires=1)

        @qml.qnode(dev, interface="jax")
        def pulse_circuit(w, pulse_params):
            phi, theta, omega = w
            Gates.Rot(
                phi, theta, omega, 0, pulse_params=pulse_params, gate_mode="pulse"
            )
            return qml.state()

        @qml.qnode(dev)
        def unitary_circuit(w):
            phi, theta, omega = w
            qml.Rot(phi, theta, omega, wires=0)
            return qml.state()

        return self._optimize_gate(
            pulse_circuit,
            unitary_circuit(w),
            w,
            init_pulse_params,
            steps,
            patience,
            log_interval,
            plot=False,
        )

    def optimize_RX(
        self,
        steps: int = 1000,
        patience: int = 100,
        w: float = jnp.pi,
        init_pulse_params: jnp.ndarray = jnp.array([1.0, 15.0, 1.0]),
        log_interval: int = 50,
    ):
        """
        Optimize pulse parameters for the RX(w) gate to best approximate
        the unitary RX(w) gate.

        Uses gradient-based optimization to minimize the infidelity between
        the output state of the pulse-based RX(w) circuit and that of the
        unitary RX(w).

        Args:
            steps (int): Number of optimization steps. Default: 1000.
            patience (int): Amount of epochs without improvement before early
                stopping. Default: 100.
            w (float): Rotation angle in radians with which to run the
                optimization. Default: π.
            init_pulse_params (jnp.ndarray): Initial pulse parameters
                (A, sigma) and time. Default: [1.0, 15.0, 1.0].
            log_interval (int): Frequency of logging the loss. Default: 50.

        Returns:
            tuple: Optimized parameters (jnp.ndarray), best loss, and list of
            loss values during optimization.
        """
        self.current_gate = "RX"

        dev = qml.device("default.qubit", wires=1)

        @qml.qnode(dev, interface="jax")
        def pulse_circuit(w, pulse_params):
            Gates.RX(w, 0, pulse_params=pulse_params, gate_mode="pulse")
            return qml.state()

        @qml.qnode(dev)
        def unitary_circuit(w):
            qml.RX(w, wires=0)
            return qml.state()

        return self._optimize_gate(
            pulse_circuit,
            unitary_circuit(w),
            w,
            init_pulse_params,
            steps,
            patience,
            log_interval,
        )

    def optimize_RY(
        self,
        steps: int = 1000,
        patience: int = 100,
        w: float = jnp.pi,
        init_pulse_params: jnp.ndarray = jnp.array([1.0, 15.0, 1.0]),
        log_interval: int = 50,
    ):
        """
        Optimize pulse parameters for the RY(w) gate to best approximate
        the unitary RY(w) gate.

        Uses gradient-based optimization to minimize the infidelity between
        the output state of the pulse-based RY(w) circuit and that of the
        unitary RY(w).

        Args:
            steps (int): Number of optimization steps. Default: 1000.
            patience (int): Amount of epochs without improvement before early
                stopping. Default: 100.
            w (float): Rotation angle in radians with which to run the
                optimization. Default: π.
            init_pulse_params (jnp.ndarray): Initial pulse parameters
                (A, sigma) and time. Default: [1.0, 15.0, 1.0].
            log_interval (int): Frequency of logging the loss. Default: 50.

        Returns:
            tuple: Optimized parameters (jnp.ndarray), best loss, and list of
            loss values during optimization.
        """
        self.current_gate = "RY"

        dev = qml.device("default.qubit", wires=1)

        @qml.qnode(dev, interface="jax")
        def pulse_circuit(w, pulse_params):
            Gates.RY(w, 0, pulse_params=pulse_params, gate_mode="pulse")
            return qml.state()

        @qml.qnode(dev)
        def unitary_circuit(w):
            qml.RY(w, wires=0)
            return qml.state()

        return self._optimize_gate(
            pulse_circuit,
            unitary_circuit(w),
            w,
            init_pulse_params,
            steps,
            patience,
            log_interval,
        )

    def optimize_RZ(self):
        """
        Plot the pulse level RZ rotation on the X basis.

        Note:
            No actual optimization is performed because the pulse-level RZ
            gate takes no pulse parameters. The plot is only produced when
            ``make_plots`` is True.

        Returns:
            tuple: (None, None)
        """
        self.current_gate = "RZ"
        if self.make_plots:
            log.info("Plotting RZ rotation...")
            self.plot_rotation([])

        return None, None

    def optimize_H(
        self,
        steps: int = 1000,
        patience: int = 100,
        init_pulse_params: jnp.ndarray = jnp.array([1.0, 15.0, 1.0]),
        log_interval: int = 50,
    ):
        """
        Optimize pulse parameters for the Hadamard (H) gate to best approximate
        the unitary H gate.

        Uses gradient-based optimization to minimize the infidelity between
        the pulse-based H circuit output state and the target unitary H state.

        Args:
            steps (int): Number of optimization steps. Default: 1000.
            patience (int): Amount of epochs without improvement before early
                stopping. Default: 100.
            init_pulse_params (jnp.ndarray): Initial pulse parameters
                (A, sigma, t). Default: [1.0, 15.0, 1.0].
            log_interval (int): Frequency of logging the loss. Default: 50.

        Returns:
            tuple: Optimized parameters (jnp.ndarray), best loss (float), and
            list of loss values during optimization.
        """
        self.current_gate = "H"

        dev = qml.device("default.qubit", wires=1)

        @qml.qnode(dev, interface="jax")
        def pulse_circuit(w, pulse_params):
            Gates.H(0, pulse_params=pulse_params, gate_mode="pulse")
            return qml.state()

        @qml.qnode(dev)
        def unitary_circuit():
            qml.H(wires=0)
            return qml.state()

        return self._optimize_gate(
            pulse_circuit,
            unitary_circuit(),
            None,
            init_pulse_params,
            steps,
            patience,
            log_interval,
        )

    def optimize_CZ(
        self,
        steps: int = 1000,
        patience: int = 100,
        init_pulse_params: jnp.ndarray = jnp.array([1.0]),
        log_interval: int = 50,
    ):
        """
        Optimize pulse parameters for the CZ gate to best approximate
        the unitary CZ gate.

        Uses gradient-based optimization to minimize the infidelity between
        the output state of the pulse-based H_c · H_t · CZ circuit and that of
        the unitary-based H_c · H_t · CZ.

        Args:
            steps (int): Number of optimization steps. Default: 1000.
            patience (int): Amount of epochs without improvement before early
                stopping. Default: 100.
            init_pulse_params (jnp.ndarray): Initial pulse duration.
                Default: [1.0].
            log_interval (int): Frequency of logging the loss. Default: 50.

        Returns:
            tuple: Optimized parameters (jnp.ndarray), best loss, and list of
            loss values during optimization.
        """
        self.current_gate = "CZ"

        dev = qml.device("default.qubit", wires=2)

        @qml.qnode(dev, interface="jax")
        def pulse_circuit(w, pulse_params):
            qml.H(wires=0)
            qml.H(wires=1)
            Gates.CZ(wires=[0, 1], pulse_params=pulse_params, gate_mode="pulse")
            return qml.state()

        @qml.qnode(dev)
        def unitary_circuit():
            qml.H(wires=0)
            qml.H(wires=1)
            qml.CZ(wires=[0, 1])
            return qml.state()

        return self._optimize_gate(
            pulse_circuit,
            unitary_circuit(),
            None,
            init_pulse_params,
            steps,
            patience,
            log_interval,
        )

    def optimize_CY(
        self,
        steps: int = 1000,
        patience: int = 100,
        init_pulse_params: jnp.ndarray = jnp.array(
            [0.5, 15.0, 10.0, 1.0, 15.0, 10.0, 1.0, 1.0, 0.5]
        ),
        log_interval: int = 50,
    ):
        """
        Optimize pulse parameters for the CY gate to best approximate the
        unitary CY gate.

        Uses gradient-based optimization to minimize the infidelity between
        the output state of the pulse-based H_c · CY circuit and that of the
        unitary-based H_c · CY.

        Args:
            steps (int): Number of optimization steps. Default: 1000.
            patience (int): Amount of epochs without improvement before early
                stopping. Default: 100.
            init_pulse_params (jnp.ndarray): Initial pulse parameters.
                Default: [0.5, 15.0, 10.0, 1.0, 15.0, 10.0, 1.0, 1.0, 0.5].
            log_interval (int): Frequency of logging the loss. Default: 50.

        Returns:
            tuple: Optimized parameters (jnp.ndarray), best loss, and list of
            loss values during optimization.
        """
        self.current_gate = "CY"

        dev = qml.device("default.qubit", wires=2)

        @qml.qnode(dev, interface="jax")
        def pulse_circuit(w, pulse_params):
            qml.H(wires=0)
            Gates.CY(wires=[0, 1], pulse_params=pulse_params, gate_mode="pulse")
            return qml.state()

        @qml.qnode(dev)
        def unitary_circuit():
            qml.H(wires=0)
            qml.CY(wires=[0, 1])
            return qml.state()

        return self._optimize_gate(
            pulse_circuit,
            unitary_circuit(),
            None,
            init_pulse_params,
            steps,
            patience,
            log_interval,
            plot=False,
        )

    def optimize_CX(
        self,
        steps: int = 1000,
        patience: int = 100,
        init_pulse_params: jnp.ndarray = jnp.array(
            [1.0, 15.0, 1.0, 1.0, 1.0, 15.0, 1.0]
        ),
        log_interval: int = 50,
    ):
        """
        Optimize pulse parameters for the CX gate to best approximate the
        unitary CX gate.

        Uses gradient-based optimization to minimize the infidelity between
        the output state of the pulse-based H_c · CX circuit and that of the
        unitary-based H_c · CX.

        Args:
            steps (int): Number of optimization steps. Default: 1000.
            patience (int): Amount of epochs without improvement before early
                stopping. Default: 100.
            init_pulse_params (jnp.ndarray): Initial pulse parameters.
                Default: [1.0, 15.0, 1.0, 1.0, 1.0, 15.0, 1.0].
            log_interval (int): Frequency of logging the loss. Default: 50.

        Returns:
            tuple: Optimized parameters (jnp.ndarray), best loss, and list of
            loss values during optimization.
        """
        self.current_gate = "CX"

        dev = qml.device("default.qubit", wires=2)

        @qml.qnode(dev, interface="jax")
        def pulse_circuit(w, pulse_params):
            qml.H(wires=0)
            Gates.CX(wires=[0, 1], pulse_params=pulse_params, gate_mode="pulse")
            return qml.state()

        @qml.qnode(dev)
        def unitary_circuit():
            qml.H(wires=0)
            qml.CNOT(wires=[0, 1])
            return qml.state()

        return self._optimize_gate(
            pulse_circuit,
            unitary_circuit(),
            None,
            init_pulse_params,
            steps,
            patience,
            log_interval,
        )

    def optimize_CRX(
        self,
        steps: int = 1000,
        patience: int = 100,
        w: float = jnp.pi,
        init_pulse_params: jnp.ndarray = jnp.array(
            [10.0, 15.0, 1.0, 0.5, 1.0, 0.5, 10.0, 15.0, 1.0]
        ),
        log_interval: int = 50,
    ):
        """
        Optimize pulse parameters for the CRX(w) gate to best approximate
        the unitary CRX(w) gate.

        Uses gradient-based optimization to minimize the infidelity between
        the output state of the pulse-based H_c · CRX(w) circuit and that of
        the unitary-based H_c · CRX(w).

        Args:
            steps (int): Number of optimization steps. Default: 1000.
            patience (int): Amount of epochs without improvement before early
                stopping. Default: 100.
            w (float): Rotation angle. Default: π.
            init_pulse_params (jnp.ndarray): Initial pulse parameters.
                Default: [10.0, 15.0, 1.0, 0.5, 1.0, 0.5, 10.0, 15.0, 1.0].
            log_interval (int): Frequency of logging the loss. Default: 50.

        Returns:
            tuple: Optimized parameters (jnp.ndarray), best loss, and list of
            loss values.
        """
        self.current_gate = "CRX"

        dev = qml.device("default.qubit", wires=2)

        @qml.qnode(dev, interface="jax")
        def pulse_circuit(w, pulse_params):
            qml.H(wires=0)
            Gates.CRX(w, wires=[0, 1], pulse_params=pulse_params, gate_mode="pulse")
            return qml.state()

        @qml.qnode(dev)
        def unitary_circuit(w):
            qml.H(wires=0)
            qml.CRX(w, wires=[0, 1])
            return qml.state()

        return self._optimize_gate(
            pulse_circuit,
            unitary_circuit(w),
            w,
            init_pulse_params,
            steps,
            patience,
            log_interval,
            plot=False,
        )

    def optimize_CRY(
        self,
        steps: int = 1000,
        patience: int = 100,
        w: float = jnp.pi,
        init_pulse_params: jnp.ndarray = jnp.array(
            [10.0, 15.0, 1.0, 0.5, 1.0, 0.5, 10.0, 15.0, 1.0]
        ),
        log_interval: int = 50,
    ):
        """
        Optimize pulse parameters for the CRY(w) gate to best approximate
        the unitary CRY(w) gate.

        Uses gradient-based optimization to minimize the infidelity between
        the output state of the pulse-based H_c · CRY(w) circuit and that of
        the unitary-based H_c · CRY(w).

        Args:
            steps (int): Number of optimization steps. Default: 1000.
            patience (int): Amount of epochs without improvement before early
                stopping. Default: 100.
            w (float): Rotation angle. Default: π.
            init_pulse_params (jnp.ndarray): Initial pulse parameters.
                Default: [10.0, 15.0, 1.0, 0.5, 1.0, 0.5, 10.0, 15.0, 1.0].
            log_interval (int): Frequency of logging the loss. Default: 50.

        Returns:
            tuple: Optimized parameters (jnp.ndarray), best loss, and list of
            loss values.
        """
        self.current_gate = "CRY"

        dev = qml.device("default.qubit", wires=2)

        @qml.qnode(dev, interface="jax")
        def pulse_circuit(w, pulse_params):
            qml.H(wires=0)
            Gates.CRY(w, wires=[0, 1], pulse_params=pulse_params, gate_mode="pulse")
            return qml.state()

        @qml.qnode(dev)
        def unitary_circuit(w):
            qml.H(wires=0)
            qml.CRY(w, wires=[0, 1])
            return qml.state()

        return self._optimize_gate(
            pulse_circuit,
            unitary_circuit(w),
            w,
            init_pulse_params,
            steps,
            patience,
            log_interval,
            plot=False,
        )

    def optimize_CRZ(
        self,
        steps: int = 1000,
        patience: int = 100,
        w: float = jnp.pi,
        init_pulse_params: jnp.ndarray = jnp.array([0.5, 2.0, 0.5]),
        log_interval: int = 50,
    ):
        """
        Optimize pulse parameters for the CRZ(w) gate to best approximate
        the unitary CRZ(w) gate.

        Uses gradient-based optimization to minimize the infidelity between
        the output state of the pulse-based H_c · H_t · CRZ(w) circuit and
        that of the unitary-based H_c · H_t · CRZ(w).

        Args:
            steps (int): Number of optimization steps. Default: 1000.
            patience (int): Early stopping patience. Default: 100.
            w (float): Rotation angle. Default: π.
            init_pulse_params (jnp.ndarray): Initial pulse parameters.
                Default: [0.5, 2.0, 0.5].
            log_interval (int): Frequency of logging the loss. Default: 50.

        Returns:
            tuple: Optimized parameters (jnp.ndarray), best loss, and list of
            loss values.
        """
        self.current_gate = "CRZ"

        dev = qml.device("default.qubit", wires=2)

        @qml.qnode(dev, interface="jax")
        def pulse_circuit(w, pulse_params):
            qml.H(wires=0)
            qml.H(wires=1)
            Gates.CRZ(w, wires=[0, 1], pulse_params=pulse_params, gate_mode="pulse")
            return qml.state()

        @qml.qnode(dev)
        def unitary_circuit(w):
            qml.H(wires=0)
            qml.H(wires=1)
            qml.CRZ(w, wires=[0, 1])
            return qml.state()

        return self._optimize_gate(
            pulse_circuit,
            unitary_circuit(w),
            w,
            init_pulse_params,
            steps,
            patience,
            log_interval,
            plot=False,
        )

1009 

1010 

if __name__ == "__main__":
    # Show the optimizers' log.info progress messages when run as a script.
    logging.basicConfig(level=logging.INFO)

    qoc = QOC(
        make_plots=False,
        fig_points=40,
        fig_dir="docs/figures",
        file_dir="qml_essentials",
    )

    # Gates to process, mapped to the keyword arguments passed to the matching
    # ``QOC.optimize_<gate>`` method. Every entry is disabled by default;
    # uncomment the ones you want to run.
    gate_configs = {
        # "Rot": {},
        # "RX": {"w": jnp.pi, "init_pulse_params": jnp.array([1.0, 15.0, 1.0])},
        # "RY": {"w": jnp.pi, "init_pulse_params": jnp.array([1.0, 15.0, 1.0])},
        # "RZ": {},  # only plots (when make_plots=True); nothing is optimized
        # "H": {"init_pulse_params": jnp.array([1.0, 15.0, 1.0])},
        # "CZ": {"init_pulse_params": jnp.array([0.975]), "log_interval": 5},
        # "CY": {"log_interval": 50},
        # "CX": {"log_interval": 50},
        # "CRX": {"w": jnp.pi},
        # "CRY": {"w": jnp.pi},
        # "CRZ": {},
    }

    for gate, kwargs in gate_configs.items():
        log.info(f"Optimizing {gate} gate...")
        # optimize_RZ returns (None, None); the other methods return
        # (params, best_loss, losses).
        optimized_pulse_params, best_loss, *_ = getattr(qoc, f"optimize_{gate}")(
            **kwargs
        )
        if best_loss is not None:
            log.info(f"Optimized parameters for {gate}: {optimized_pulse_params}")
            log.info(f"Best achieved fidelity: {1 - best_loss}")
        log.info("-" * 20)