Coverage for qml_essentials / model.py: 91%

529 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-06-11 15:51 +0000

1from typing import Any, Dict, Optional, Tuple, Callable, Union, List 

2 

3import warnings 

4import jax.numpy as jnp 

5import numpy as np 

6from jax import random 

7 

8from qml_essentials import jaqsi as js 

9from qml_essentials import operations as op 

10from qml_essentials.tape import recording 

11from qml_essentials.operations import KrausChannel 

12from qml_essentials.ansaetze import Ansaetze, Circuit, Encoding 

13from qml_essentials.gates import Gates, PulseInformation as pinfo 

14from qml_essentials.utils import safe_random_split 

15 

16import logging 

17 

18log = logging.getLogger(__name__) 

19 

20 

21class Model: 

22 """ 

23 A quantum circuit model. 

24 """ 

25 

def __init__(
    self,
    n_qubits: int,
    n_layers: int,
    circuit_type: Union[str, Circuit, None] = "No_Ansatz",
    data_reupload: Union[bool, List[List[bool]], List[List[List[bool]]]] = True,
    state_preparation: Union[
        str, Callable, List[Union[str, Callable]], None
    ] = None,
    encoding: Union[Encoding, str, Callable, List[Union[str, Callable]]] = Gates.RX,
    trainable_frequencies: bool = False,
    initialization: str = "random",
    initialization_domain: Optional[List[float]] = None,
    output_qubit: Union[List[int], int] = -1,
    shots: Optional[int] = None,
    random_seed: int = 1000,
    remove_zero_encoding: bool = True,
    repeat_batch_axis: Optional[List[bool]] = None,
    pulse_shape: str = "gaussian",
) -> None:
    """
    Initialize the quantum circuit model.

    Variational parameters are stored with shape
    ``(1, impl_n_layers, parameters_per_layer)``. ``impl_n_layers`` is
    ``n_layers + 1`` when the data re-uploading configuration yields a
    spectrum with frequencies above 1 (see ``has_dru``) and ``n_layers``
    otherwise. ``parameters_per_layer`` is given by the chosen ansatz.

    The model is initialized with the following defaults:
        - noise_params: None
        - execution_type: "expval"

    Args:
        n_qubits (int): The number of qubits in the circuit.
        n_layers (int): The number of layers in the circuit.
        circuit_type (str, Circuit, None): Name of an ``Ansaetze`` member,
            or a ``Circuit`` class that is instantiated without arguments.
            ``None`` or an empty string falls back to ``"No_Ansatz"``.
        data_reupload (Union[bool, List[List[bool]], List[List[List[bool]]]], optional):
            Whether to re-upload data on each layer and qubit. Detailed
            instructions can be given as a list/array of 0/False and 1/True
            with shape ``(n_layers, n_qubits)`` or
            ``(n_layers, n_qubits, n_input_feat)``. Defaults to True, which
            applies data re-uploading to the full circuit.
        state_preparation (Union[str, Callable, List[Union[str, Callable]], None], optional):
            Gate(s), given by name or as callables, parsed with
            ``Gates.parse_gates``. Defaults to None.
        encoding (Union[Encoding, str, Callable, List[Union[str, Callable]]], optional):
            The unitary to use for encoding the input data. Can be an
            ``Encoding`` instance (used as-is), a string (e.g. "RX") or a
            callable. Non-``Encoding`` values are wrapped in a "hamming"
            ``Encoding``. Defaults to ``Gates.RX``. Multidimensional input
            is given as a list of unitaries or a list of strings.
        trainable_frequencies (bool, optional):
            Sets trainable encoding parameters for trainable frequencies.
            Defaults to False.
        initialization (str, optional): The strategy to initialize the
            parameters. Can be "random", "zeros", "zero-controlled", "pi",
            or "pi-controlled". Defaults to "random".
        initialization_domain (Optional[List[float]], optional): ``[min, max]``
            range used by the random strategies. Defaults to ``[0, 2 * pi]``.
        output_qubit (List[int], int, optional): The index of the output
            qubit (or qubits). When set to -1, all qubits are measured or a
            global measurement is conducted, depending on the execution
            type.
        shots (Optional[int], optional): The number of shots to use for
            the quantum device. Defaults to None.
        random_seed (int, optional): Seed for the JAX random key used for
            parameter initialization and for random noise parameters.
            Defaults to 1000.
        remove_zero_encoding (bool, optional): Whether to remove the zero
            encoding from the circuit. Defaults to True.
        repeat_batch_axis (Optional[List[bool]], optional): Each boolean
            determines whether computation is parallelised over the
            corresponding axis of [inputs, params, pulse_params]. Defaults
            to ``[True, True, True]``, i.e. batching over all axes.
        pulse_shape (str, optional): Pulse envelope shape for pulse-level
            simulation. One of ``PulseEnvelope.available()``.
            Defaults to ``"gaussian"``.

    Raises:
        ValueError: If ``state_preparation`` cannot be parsed.

    Returns:
        None
    """
    # Resolve list defaults here instead of in the signature, so that
    # different Model instances never share (and mutate) the same list.
    if initialization_domain is None:
        initialization_domain = [0, 2 * jnp.pi]
    if repeat_batch_axis is None:
        repeat_batch_axis = [True, True, True]

    # Initialize default parameters needed for circuit evaluation
    self.n_qubits: int = n_qubits
    self.output_qubit: Union[List[int], int] = output_qubit
    self.n_layers: int = n_layers
    self.noise_params: Optional[Dict[str, Union[float, Dict[str, float]]]] = None
    self.shots = shots
    self.remove_zero_encoding = remove_zero_encoding
    self.trainable_frequencies: bool = trainable_frequencies
    # The execution_type setter reads output_qubit and shots, so those
    # must be assigned first.
    self.execution_type: str = "expval"
    self.repeat_batch_axis: List[bool] = repeat_batch_axis

    # --- Pulse envelope ---
    pinfo.set_envelope(pulse_shape)

    # --- State Preparation ---
    try:
        self._sp = Gates.parse_gates(state_preparation, Gates)
    except ValueError as e:
        raise ValueError(f"Error parsing state preparation: {e}") from e

    # Prepare the corresponding pulse parameters (always optimized pulses).
    # Gates without a pulse parametrization get None.
    self.sp_pulse_params = []
    for sp in self._sp:
        sp_name = sp.__name__ if hasattr(sp, "__name__") else str(sp)
        pulse_gate = pinfo.gate_by_name(sp_name)
        self.sp_pulse_params.append(
            pulse_gate.params if pulse_gate is not None else None
        )

    # --- Encoding ---
    if isinstance(encoding, Encoding):
        # The user supplied a custom strategy; use it as-is.
        self._enc = encoding
    else:
        # Use hamming encoding by default.
        self._enc = Encoding("hamming", encoding)

    if self._enc.is_golomb:
        self._enc._n_qubits = n_qubits

    # Number of possible inputs
    self.n_input_feat = len(self._enc)
    log.debug(f"Number of input features: {self.n_input_feat}")

    # Trainable frequencies, default initialization as in arXiv:2309.03279v2
    self.enc_params = jnp.ones((self.n_layers, self.n_qubits, self.n_input_feat))

    # Read by _iec for its remove_zero_encoding shortcut.
    self._zero_inputs = False

    # --- Data-Reuploading ---
    # The setter keeps the mask as a NumPy array (not JAX), so that
    # ``if data_reupload[q, idx]`` in _iec remains a concrete Python bool
    # even under jax.jit tracing. It also updates self.degree,
    # self.frequencies and, in consequence, self.has_dru. It needs
    # self._enc and self.n_input_feat, so it must run after them.
    self.data_reupload = data_reupload

    # One extra layer is needed when data is re-uploaded
    # (L+1 according to Schuld et al.).
    if self.has_dru:
        impl_n_layers: int = n_layers + 1
    else:
        impl_n_layers = n_layers
    log.info(f"Number of implicit layers: {impl_n_layers}.")

    # --- Ansatz ---
    # Only a weak check for str; we trust the user to provide something
    # useful. None is treated like a missing name.
    if circuit_type is None or isinstance(circuit_type, str):
        self.pqc: Callable[[Optional[jnp.ndarray], int], int] = getattr(
            Ansaetze, circuit_type or "No_Ansatz"
        )()
    else:
        self.pqc = circuit_type()
    log.info(f"Using Ansatz {circuit_type}.")

    # Shape of the parameter vector, re-used by initialize_params.
    params_per_layer = self.pqc.n_params_per_layer(self.n_qubits)
    self._params_shape: Tuple[int, int] = (impl_n_layers, params_per_layer)
    log.info(f"Parameters per layer: {params_per_layer}")

    pulse_params_per_layer = self.pqc.n_pulse_params_per_layer(self.n_qubits)
    self._pulse_params_shape: Tuple[int, int] = (
        impl_n_layers,
        pulse_params_per_layer,
    )

    # Initialize to None, as the batch shape is only known after a call.
    self._batch_shape = None

    # Re-used by initialize_params when no strategy/domain is provided.
    self._inialization_strategy = initialization
    self._initialization_domain = initialization_domain

    # initialize_params sets self.params and returns the advanced key.
    self.random_key = self.initialize_params(random.key(random_seed))

    # Initializing pulse params (leading batch axis of size 1)
    self.pulse_params: jnp.ndarray = jnp.ones((1, *self._pulse_params_shape))

    log.info(f"Initialized pulse parameters with shape {self.pulse_params.shape}.")

    # Initialise the jaqsi Script that wraps _variational.
    # No device selection needed - jaqsi auto-routes between statevector
    # and density-matrix simulation based on whether noise channels are
    # present on the tape.
    self.script = js.Script(f=self._variational, n_qubits=self.n_qubits)

211 

@property
def noise_params(self) -> Optional[Dict[str, Union[float, Dict[str, float]]]]:
    """
    Gets the noise parameters of the model.

    Returns:
        Optional[Dict[str, Union[float, Dict[str, float]]]]: A dictionary of
        noise parameters, or None if not set.
    """
    return self._noise_params

@noise_params.setter
def noise_params(
    self, kvs: Optional[Dict[str, Union[float, Dict[str, float]]]]
) -> None:
    """
    Sets the noise parameters of the model.

    Typically a "noise parameter" refers to the error probability.
    ThermalRelaxation is a special case, and supports a dict as value with
    structure:
        "ThermalRelaxation":
        {
            "t1": 2000,  # relative t1 time.
            "t2": 1000,  # relative t2 time
            "t_factor": 1,  # relative gate time factor
        },

    Missing noise types are filled with their defaults (0.0, or None for
    ThermalRelaxation). Unsupported keys trigger a UserWarning. Invalid
    ThermalRelaxation values (any zero entry, or t2 > 2 * t1) trigger a
    UserWarning and replace the entry with 0.0. The caller's dictionaries
    are copied and never modified.

    Args:
        kvs (Optional[Dict[str, Union[float, Dict[str, float]]]]): A
            dictionary of noise parameters. If it is empty or all values are
            0.0, the noise parameters are set to None.

    Returns:
        None
    """
    if kvs is not None:
        # Shallow copy: filling in defaults below must not leak into the
        # caller's dictionary.
        kvs = dict(kvs)

    # set to None if only zero values provided
    if kvs is not None and all(v == 0.0 for v in kvs.values()):
        kvs = None

    if kvs is not None:
        # set default values
        defaults = {
            "BitFlip": 0.0,
            "PhaseFlip": 0.0,
            "Depolarizing": 0.0,
            "MultiQubitDepolarizing": 0.0,
            "AmplitudeDamping": 0.0,
            "PhaseDamping": 0.0,
            "GateError": 0.0,
            "ThermalRelaxation": None,
            "StatePreparation": 0.0,
            "Measurement": 0.0,
        }
        for key, default_val in defaults.items():
            kvs.setdefault(key, default_val)

        # check if there are any keys not supported
        for key in kvs:
            if key not in defaults:
                warnings.warn(
                    f"Noise type {key} is not supported by this package",
                    UserWarning,
                )

        # check valid params for thermal relaxation noise channel
        tr_params = kvs["ThermalRelaxation"]
        if isinstance(tr_params, dict):
            # Copy the nested dict as well, for the same reason as above.
            tr_params = dict(tr_params)
            kvs["ThermalRelaxation"] = tr_params
            tr_params.setdefault("t1", 0.0)
            tr_params.setdefault("t2", 0.0)
            tr_params.setdefault("t_factor", 0.0)
            valid_tr_keys = {"t1", "t2", "t_factor"}
            for k in tr_params:
                if k not in valid_tr_keys:
                    warnings.warn(
                        f"Thermal Relaxation parameter {k} is not supported "
                        f"by this package",
                        UserWarning,
                    )
            if not all(tr_params.values()) or tr_params["t2"] > 2 * tr_params["t1"]:
                warnings.warn(
                    "Received invalid values for Thermal Relaxation noise "
                    "parameter. Thermal relaxation is not applied!",
                    UserWarning,
                )
                kvs["ThermalRelaxation"] = 0.0

    self._noise_params = kvs

300 

@property
def output_qubit(self) -> List[int]:
    """
    Get the output qubit indices used for measurement.

    Returns:
        List[int]: The indices as normalized by the setter.
    """
    return self._output_qubit

@output_qubit.setter
def output_qubit(self, value: Union[int, List[int]]) -> None:
    """
    Set the output qubit(s) for measurement.

    Args:
        value: A qubit index, or a list/tuple of indices. ``-1`` selects all
            qubits. A single integer (including NumPy integers) is wrapped
            into a one-element list. Lists and tuples are copied into a new
            list. Values of any other type are stored unchanged.

    Raises:
        AssertionError: If more indices than qubits are given, or if a
            single index is not smaller than ``n_qubits``.
    """
    if isinstance(value, (list, tuple)):
        # Copy, so that later mutation of the caller's sequence cannot
        # change the model. Tuples become lists because
        # ``all_qubit_measurement`` compares against a list.
        value = list(value)
        assert len(value) <= self.n_qubits, (
            f"Size of output_qubit {len(value)} cannot be "
            f"larger than number of qubits {self.n_qubits}."
        )
    elif isinstance(value, (int, np.integer)):
        value = int(value)
        if value == -1:
            value = list(range(self.n_qubits))
        else:
            assert value < self.n_qubits, (
                f"Output qubit {value} cannot be larger than {self.n_qubits}."
            )
            value = [value]

    self._output_qubit = value

329 

@property
def execution_type(self) -> str:
    """
    Gets the execution type of the model.

    Returns:
        str: The execution type, one of 'density', 'expval', 'probs' or
        'state'.
    """
    return self._execution_type

@execution_type.setter
def execution_type(self, value: str) -> None:
    """
    Sets the execution type and the matching result shape.

    With ``k = len(output_qubit)`` the result shape is:
        - "density": ``(2**k, 2**k)``
        - "expval":  ``(k,)``
        - "probs":   ``(2,) * k`` (``(2,)`` if output_qubit is not a list/tuple)
        - "state":   ``(2**k,)``

    All validation runs before any attribute is written, so a rejected
    value leaves both the execution type and the result shape unchanged.

    Args:
        value (str): One of "density", "expval", "probs" or "state".

    Raises:
        ValueError: If ``value`` is not a supported execution type, or if it
            is "density" while ``shots`` is not None.
    """
    if value == "density":
        dim = 2 ** len(self.output_qubit)
        result_shape = (dim, dim)
    elif value == "expval":
        # One expectation value per entry of output_qubit, whether all
        # qubits are measured or only a subset (parity / n-local).
        result_shape = (len(self.output_qubit),)
    elif value == "probs":
        # in case this is a list of parities,
        # each pair has 2^len(qubits) probabilities
        result_shape = (
            (2,) * len(self.output_qubit)
            if isinstance(self.output_qubit, (tuple, list))
            else (2,)
        )
    elif value == "state":
        result_shape = (2 ** len(self.output_qubit),)
    else:
        raise ValueError(f"Invalid execution type: {value}.")

    if value == "density" and self.shots is not None:
        raise ValueError("Setting execution_type to density with shots not None.")

    if value == "state" and not self.all_qubit_measurement:
        warnings.warn(
            f"{value} measurement does ignore output_qubit, which is "
            f"{self.output_qubit}.",
            UserWarning,
        )

    if value == "probs" and self.shots is None:
        warnings.warn(
            "Setting execution_type to probs without specifying shots.",
            UserWarning,
        )

    # Commit only after every check has passed.
    self._result_shape = result_shape
    self._execution_type = value

386 

@property
def shots(self) -> Optional[int]:
    """
    Number of measurement shots, or None for shot-free execution.

    Returns:
        Optional[int]: The stored shot count.
    """
    return self._shots

@shots.setter
def shots(self, value: Optional[int]) -> None:
    """
    Store the number of measurement shots.

    Args:
        value (Optional[int]): Shot count. A plain ``int`` that is zero or
            negative is normalized to None. Every other value (including
            None, bools and non-``int`` numbers) is stored unchanged.

    Returns:
        None
    """
    # Exact type check on purpose: bool and other int subclasses are not
    # normalized.
    non_positive_int = type(value) is int and value <= 0
    self._shots = None if non_positive_int else value

412 

@property
def params(self) -> jnp.ndarray:
    """Variational parameters, always carrying a leading batch axis."""
    return self._params

@params.setter
def params(self, value: jnp.ndarray) -> None:
    """
    Store the variational parameters.

    A 2-D array ``(layers, params_per_layer)`` is given a leading batch axis
    of size 1. Arrays of any other rank are stored as-is.
    """
    needs_batch_axis = len(value.shape) == 2
    self._params = (
        value.reshape((1,) + tuple(value.shape)) if needs_batch_axis else value
    )

425 

@property
def enc_params(self) -> jnp.ndarray:
    """
    Encoding weights that multiply the inputs before they enter the circuit.

    Returns:
        jnp.ndarray: The stored encoding weights.
    """
    return self._enc_params

@enc_params.setter
def enc_params(self, value: jnp.ndarray) -> None:
    """Replace the encoding weights; the value is stored without validation."""
    self._enc_params = value

435 

@property
def pulse_params(self) -> jnp.ndarray:
    """
    Pulse parameter scalers used when gates run in pulse mode.

    Returns:
        jnp.ndarray: The stored pulse parameters.
    """
    return self._pulse_params

@pulse_params.setter
def pulse_params(self, value: jnp.ndarray) -> None:
    """Replace the pulse parameters; the value is stored without validation."""
    self._pulse_params = value

445 

@property
def data_reupload(self) -> np.ndarray:
    """
    Get the data re-uploading mask.

    Returns:
        np.ndarray: Boolean NumPy array of shape
        ``(n_layers, n_qubits, n_input_feat)``.
    """
    return self._data_reupload

@data_reupload.setter
def data_reupload(
    self,
    value: Union[bool, np.ndarray, List[List[bool]], List[List[List[bool]]]],
) -> None:
    """
    Set the data re-uploading mask.

    Accepted values:
        - ``True``: encode on every layer, qubit and feature.
        - ``False``: encode only once, on layer 0 / qubit 0 (all features).
        - an array-like of shape ``(n_layers, n_qubits)``, repeated over all
          input features, or of shape ``(n_layers, n_qubits, n_input_feat)``.

    The mask is always stored as a concrete NumPy boolean array, so that
    ``if data_reupload[q, idx]`` in :meth:`_iec` remains a plain Python
    ``bool`` even inside JAX-traced functions (jit / grad / vmap).

    Also recomputes ``degree`` and ``frequencies`` (via the encoding's
    ``get_n_freqs`` / ``get_spectrum``) and the cached ``has_dru`` flag.

    Raises:
        AssertionError: If an array-like mask has an unexpected shape.
    """
    expected_3d = (self.n_layers, self.n_qubits, self.n_input_feat)

    if isinstance(value, (bool, np.bool_)):
        if value:
            mask = np.ones(expected_3d)
            log.debug("Full data reuploading.")
        else:
            mask = np.zeros(expected_3d)
            mask[0, 0] = 1
            log.debug("No data reuploading.")
    else:
        mask = np.asarray(value)

        if mask.ndim == 2:
            expected_2d = (self.n_layers, self.n_qubits)
            assert mask.shape == expected_2d, (
                f"Data reuploading array has wrong shape. "
                f"Expected {expected_2d} or {expected_3d}, got {mask.shape}."
            )
            # Same re-uploading pattern for every input feature.
            mask = np.repeat(mask[..., np.newaxis], self.n_input_feat, axis=2)

        assert mask.shape == expected_3d, (
            f"Data reuploading array has wrong shape. "
            f"Expected {expected_3d}, got {mask.shape}."
        )

        log.debug(f"Data reuploading array:\n{mask}")

    # convert to boolean values (astype always returns a fresh array)
    self._data_reupload = mask.astype(bool)

    # Number of encodings per input feature across all layers and qubits.
    counts = [
        np.count_nonzero(self._data_reupload[..., i])
        for i in range(self.n_input_feat)
    ]
    degree = tuple(self._enc.get_n_freqs(c) for c in counts)
    frequencies = tuple(self._enc.get_spectrum(c) for c in counts)
    self.degree = degree
    self.frequencies = frequencies

    # Cache has_dru as a plain Python bool, so that it can be used in
    # Python ``if`` statements even inside JAX-traced functions.
    self._has_dru = bool(
        max((int(np.max(f)) for f in frequencies), default=0) > 1
    )

513 

@property
def degree(self) -> Tuple:
    """
    Per-feature degree of the model.

    The ``data_reupload`` setter fills this with the values returned by
    the encoding's ``get_n_freqs``.
    """
    return self._degree

@degree.setter
def degree(self, value: Tuple):
    """Overwrite the cached degree tuple (no validation)."""
    self._degree = value

522 

@property
def frequencies(self) -> Tuple:
    """
    Per-feature frequency spectra of the model.

    The ``data_reupload`` setter fills this with the values returned by
    the encoding's ``get_spectrum``.
    """
    return self._frequencies

@frequencies.setter
def frequencies(self, value: Tuple):
    """Overwrite the cached frequency spectra (no validation)."""
    self._frequencies = value

531 

def exact_spectrum(self, method: str = "tree") -> Tuple[np.ndarray, ...]:
    """Compute the exact per-feature Fourier spectrum via the FourierTree.

    :attr:`frequencies` is a naive per-feature estimate derived only from
    the encoding, and it can *overestimate* the spectrum, because some
    coefficients are constrained to zero for all parameters. This method
    instead builds the analytical Fourier tree (Nemkov et al.). For each
    input feature it returns the integer frequencies whose Fourier
    coefficient is not identically zero. The result is always a subset of
    :attr:`frequencies`.

    The support is derived symbolically (no parameter sampling); see
    :meth:`~qml_essentials.coefficients.FourierTree.get_exact_support`.

    ``method="tree"`` (default) enumerates the explicit tree. It excludes
    frequencies whose contributions cancel identically across tree paths
    (e.g. two consecutive encodings combining into a single rotation), but
    can be infeasible for deep entangling circuits.

    ``method="dp"`` uses a merged-state dynamic program that avoids path
    enumeration. It scales to deep circuits (single input feature only),
    at the cost of not detecting identical cross-path cancellations.

    Requires a Clifford + Pauli-rotation ansatz (see
    :class:`~qml_essentials.pauli.PauliCircuit`); other gate sets raise
    ``NotImplementedError`` during tree construction.

    Args:
        method (str): ``"tree"`` (fully exact) or ``"dp"`` (scalable).

    Returns:
        Tuple[np.ndarray, ...]: One sorted integer frequency array per input
        feature (same layout as :attr:`frequencies`). A feature absent from
        the tree, or an empty overall support, yields ``[0]``.
    """
    # Imported locally to avoid a circular import.
    from qml_essentials.coefficients import FourierTree

    tree = FourierTree(self)

    # Column of each model feature inside the tree's frequency vectors.
    column_of = {feature: col for col, feature in enumerate(tree.features)}

    # Union of the symbolic supports of all observables (tree roots); every
    # entry is a tuple holding one frequency per tree feature.
    support = set()
    for root_support in tree.get_exact_support(method=method):
        arr = np.asarray(root_support)
        rows = [arr[k] for k in range(arr.shape[0])]
        if arr.ndim == 1:
            support.update((int(r),) for r in rows)
        else:
            support.update(tuple(int(v) for v in r) for r in rows)

    def _feature_spectrum(feature: int) -> np.ndarray:
        # Project the joint support onto a single feature's column.
        if not support or feature not in column_of:
            return np.array([0], dtype=int)
        col = column_of[feature]
        return np.array(sorted({key[col] for key in support}), dtype=int)

    return tuple(_feature_spectrum(f) for f in range(self.n_input_feat))

592 

@property
def has_dru(self) -> bool:
    """
    Whether the model uses data re-uploading.

    Returns:
        bool: The flag cached by the ``data_reupload`` setter (True when
        some feature's spectrum reaches a frequency above 1).
    """
    return self._has_dru

597 

@property
def all_qubit_measurement(self) -> bool:
    """True iff ``output_qubit`` equals ``[0, 1, ..., n_qubits - 1]``."""
    every_qubit = list(range(self.n_qubits))
    return self.output_qubit == every_qubit

602 

@property
def batch_shape(self) -> Tuple[int, ...]:
    """
    Batch shape ``(B_I, B_P, B_R)`` of the last model call.

    Returns:
        Tuple[int, ...]: ``(input_batch, param_batch, pulse_batch)``, or the
        placeholder ``(1, 1, 1)`` if the model has not been called yet.
    """
    cached = self._batch_shape
    if cached is not None:
        return cached
    log.debug("Model was not called yet. Returning (1,1,1) as batch shape.")
    return (1, 1, 1)

618 

@property
def eff_batch_shape(self) -> Tuple[int, ...]:
    """
    Batch dimensions that remain after applying ``repeat_batch_axis``.

    Each entry of ``batch_shape`` is multiplied by the matching mask entry,
    and the zeros produced by disabled axes are dropped.

    Returns:
        Tuple[int, ...]: The surviving dimensions. Note that the value is
        actually returned as a 1-D NumPy array.
    """
    masked = np.array(self.batch_shape) * self.repeat_batch_axis
    keep = masked != 0
    return masked[keep]

630 

def initialize_params(
    self,
    random_key: Optional[random.PRNGKey] = None,
    repeat: int = 1,
    initialization: Optional[str] = None,
    initialization_domain: Optional[List[float]] = None,
) -> random.PRNGKey:
    """
    Initialize the variational parameters of the model.

    Sets ``self.params`` to an array of shape ``(repeat, *self._params_shape)``.

    Args:
        random_key (Optional[random.PRNGKey]): JAX random key for initialization.
            If None, uses the model's internal random key.
        repeat (int): Number of parameter sets to create (batch dimension).
            Defaults to 1.
        initialization (Optional[str]): Strategy for parameter initialization.
            Options: "random", "zeros", "pi", "zero-controlled", "pi-controlled".
            The "*-controlled" strategies draw random values and then set the
            parameters selected by the ansatz's control indices to 0 or pi.
            If None, uses the strategy specified in the constructor.
        initialization_domain (Optional[List[float]]): Domain [min, max] for
            random initialization. If None, uses the domain from the constructor.

    Returns:
        random.PRNGKey: Updated random key after initialization.

    Raises:
        ValueError: If an invalid initialization method is specified. Because
            ValueError subclasses Exception, existing ``except Exception``
            handlers keep working.
    """
    # Leading axis is the parameter batch dimension.
    params_shape = (repeat, *self._params_shape)

    # use the existing strategy/domain if not specified
    initialization = initialization or self._inialization_strategy
    initialization_domain = initialization_domain or self._initialization_domain

    # The key is split for every strategy, so the returned key does not
    # depend on which strategy was chosen.
    random_key, sub_key = safe_random_split(
        random_key if random_key is not None else self.random_key
    )

    def uniform_params() -> jnp.ndarray:
        # Uniform samples in [min, max) drawn with the sub-key.
        return random.uniform(
            sub_key,
            params_shape,
            minval=initialization_domain[0],
            maxval=initialization_domain[1],
        )

    def set_control_params(params: jnp.ndarray, value: float) -> jnp.ndarray:
        # Overwrite the parameters selected by the ansatz-provided
        # (start, stop, step) slice over the last axis with ``value``.
        indices = self.pqc.get_control_indices(self.n_qubits)
        if indices is None:
            warnings.warn(
                f"Specified {initialization} but circuit does not contain "
                "controlled rotation gates. Parameters are initialized randomly.",
                UserWarning,
            )
            return params
        control = slice(indices[0], indices[1], indices[2])
        np_params = np.array(params)
        np_params[:, :, control] = value
        return jnp.array(np_params)

    if initialization == "random":
        self.params = uniform_params()
    elif initialization == "zeros":
        self.params = jnp.zeros(params_shape)
    elif initialization == "pi":
        self.params = jnp.ones(params_shape) * jnp.pi
    elif initialization == "zero-controlled":
        self.params = set_control_params(uniform_params(), 0)
    elif initialization == "pi-controlled":
        self.params = set_control_params(uniform_params(), jnp.pi)
    else:
        raise ValueError(f"Invalid initialization method: {initialization}")

    log.info(
        f"Initialized parameters with shape {self.params.shape} "
        f"using strategy {initialization}."
    )

    return random_key

723 

def transform_input(
    self, inputs: jnp.ndarray, enc_params: jnp.ndarray
) -> jnp.ndarray:
    """
    Scale the inputs by the encoding weights before encoding.

    This is the linear input transformation of arXiv:2309.03279v2: the value
    fed to an encoding gate is ``inputs * enc_params``, computed with
    ordinary (broadcasting) multiplication.

    Args:
        inputs (jnp.ndarray): Input value(s), e.g. a single feature column.
        enc_params (jnp.ndarray): Encoding weight(s), a scalar or an array
            broadcast-compatible with ``inputs``.

    Returns:
        jnp.ndarray: The element-wise product.
    """
    scaled = inputs * enc_params
    return scaled

745 

def _iec(
    self,
    inputs: jnp.ndarray,
    data_reupload: jnp.ndarray,
    enc: Encoding,
    enc_params: jnp.ndarray,
    noise_params: Optional[Dict[str, Union[float, Dict[str, float]]]] = None,
    random_key: Optional[random.PRNGKey] = None,
) -> None:
    """
    Apply one Input Encoding Circuit (IEC) layer.

    Nothing is applied when ``remove_zero_encoding`` is enabled, the inputs
    are flagged as all-zero (``_zero_inputs``), and the input batch size
    is 1.

    Golomb encodings apply a single multi-qubit gate for feature 0 on all
    qubits. The gate is scaled by the mean of that feature's encoding
    weights and applied only if any qubit has the feature enabled in
    ``data_reupload``.

    Otherwise, for every qubit ``q`` and feature ``idx`` with
    ``data_reupload[q, idx]`` set, the gate ``enc[idx]`` is applied to
    qubit ``q`` with input ``inputs[..., idx] * enc_params[q, idx]``. Each
    applied gate gets a fresh sub-key split off ``random_key``.

    Args:
        inputs (jnp.ndarray): Input data; the last axis indexes features.
        data_reupload (jnp.ndarray): Mask indexed as ``[qubit, feature]``
            selecting where encoding gates are applied.
        enc (Encoding): Encoding strategy; ``enc[idx]`` is the gate for
            feature ``idx``.
        enc_params (jnp.ndarray): Encoding weights indexed as
            ``[qubit, feature]``.
        noise_params (Optional[Dict[str, Union[float, Dict[str, float]]]]):
            Forwarded to each encoding gate. Defaults to None.
        random_key (Optional[random.PRNGKey]): Key from which per-gate
            sub-keys are split. Defaults to None.

    Returns:
        None: Gates are applied in-place to the quantum circuit.
    """
    # Inputs cannot be None after validation, so only the zero flag matters.
    skip_all = (
        self.remove_zero_encoding
        and self._zero_inputs
        and self.batch_shape[0] == 1
    )
    if skip_all:
        return

    if enc.is_golomb:
        # Golomb encoding supports a single input feature and acts on all
        # qubits jointly.
        feature = 0
        if not data_reupload[:, feature].any():
            return
        random_key, sub_key = safe_random_split(random_key)
        joint_scale = jnp.mean(enc_params[:, feature])
        enc[feature](
            self.transform_input(inputs[..., feature], joint_scale),
            wires=list(range(self.n_qubits)),
            noise_params=noise_params,
            random_key=sub_key,
        )
        return

    # Standard per-qubit encoding. The ellipsis indexes only the last
    # (feature) axis, as inputs are not qubit dependent.
    for qubit in range(self.n_qubits):
        for feature in range(inputs.shape[-1]):
            if not data_reupload[qubit, feature]:
                continue
            random_key, sub_key = safe_random_split(random_key)
            enc[feature](
                self.transform_input(
                    inputs[..., feature], enc_params[qubit, feature]
                ),
                wires=qubit,
                noise_params=noise_params,
                random_key=sub_key,
            )

817 

818 def _variational( 

819 self, 

820 params: jnp.ndarray, 

821 inputs: jnp.ndarray, 

822 pulse_params: Optional[jnp.ndarray] = None, 

823 random_key: Optional[random.PRNGKey] = None, 

824 enc_params: Optional[jnp.ndarray] = None, 

825 gate_mode: str = "unitary", 

826 noise_params: Optional[Dict[str, Union[float, Dict[str, float]]]] = None, 

827 ) -> None: 

828 """ 

829 Build the variational quantum circuit structure. 

830 

831 Constructs the circuit by applying state preparation, alternating 

832 variational ansatz layers with input encoding layers, and optional 

833 noise channels. 

834 

835 The first five parameters (after ``self``) - ``params``, ``inputs``, 

836 ``pulse_params``, ``random_key``, ``enc_params`` - are the batchable 

837 positional arguments. 

838 The remaining keyword arguments are broadcast across the batch. 

839 

840 Args: 

841 params (jnp.ndarray): Variational parameters of shape 

842 (n_layers, n_params_per_layer). 

843 inputs (jnp.ndarray): Input data of shape (n_input_feat,). 

844 pulse_params (Optional[jnp.ndarray]): Pulse parameter scalers of shape 

845 (n_layers, n_pulse_params_per_layer) for pulse-mode execution. 

846 Defaults to None (uses model's pulse_params). 

847 random_key (Optional[random.PRNGKey]): JAX random key for stochastic 

848 operations. Defaults to None. 

849 enc_params (Optional[jnp.ndarray]): Encoding parameters of shape 

850 (n_qubits, n_input_feat). Defaults to None (uses model's enc_params). 

851 gate_mode (str): Gate execution mode, either "unitary" or "pulse". 

852 Defaults to "unitary". 

853 noise_params (Optional[Dict[str, Union[float, Dict[str, float]]]]): 

854 Noise parameters for simulation. Defaults to None. 

855 

856 Returns: 

857 None: Gates are applied in-place to the quantum circuit. 

858 

859 Note: 

860 Issues RuntimeWarning if called directly without providing parameters 

861 that would normally be passed through the forward method. 

862 """ 

863 # TODO: rework and double check params shape 

864 if len(params.shape) > 2 and params.shape[0] == 1: 

865 params = params[0] 

866 

867 if len(inputs.shape) > 1 and inputs.shape[0] == 1: 

868 inputs = inputs[0] 

869 

870 if enc_params is None: 

871 # TODO: Raise warning if trainable frequencies is True, or similar. I.e., no 

872 # warning if user does not care for frequencies or enc_params 

873 if self.trainable_frequencies: 

874 warnings.warn( 

875 "Explicit call to `_circuit` or `_variational` detected: " 

876 "`enc_params` is None, using `self.enc_params` instead.", 

877 RuntimeWarning, 

878 ) 

879 enc_params = self.enc_params 

880 

881 if pulse_params is None: 

882 if gate_mode == "pulse": 

883 warnings.warn( 

884 "Explicit call to `_circuit` or `_variational` detected: " 

885 "`pulse_params` is None, using `self.pulse_params` instead.", 

886 RuntimeWarning, 

887 ) 

888 pulse_params = self.pulse_params 

889 

890 # Squeeze batch dimension for pulse_params (batch-first convention) 

891 if len(pulse_params.shape) > 2 and pulse_params.shape[0] == 1: 

892 pulse_params = pulse_params[0] 

893 

894 if noise_params is None: 

895 if self.noise_params is not None: 

896 warnings.warn( 

897 "Explicit call to `_circuit` or `_variational` detected: " 

898 "`noise_params` is None, using `self.noise_params` instead.", 

899 RuntimeWarning, 

900 ) 

901 noise_params = self.noise_params 

902 

903 if noise_params is not None: 

904 if random_key is None: 

905 warnings.warn( 

906 "Explicit call to `_circuit` or `_variational` detected: " 

907 "`random_key` is None, using `random.PRNGKey(0)` instead.", 

908 RuntimeWarning, 

909 ) 

910 random_key = self.random_key 

911 self._apply_state_prep_noise(noise_params=noise_params) 

912 

913 # state preparation 

914 for q in range(self.n_qubits): 

915 for _sp, sp_pulse_params in zip(self._sp, self.sp_pulse_params): 

916 random_key, sub_key = safe_random_split(random_key) 

917 _sp( 

918 wires=q, 

919 pulse_params=sp_pulse_params, 

920 noise_params=noise_params, 

921 random_key=sub_key, 

922 gate_mode=gate_mode, 

923 ) 

924 

925 # circuit building 

926 for layer in range(0, self.n_layers): 

927 random_key, sub_key = safe_random_split(random_key) 

928 # ansatz layers 

929 self.pqc( 

930 params[layer], 

931 self.n_qubits, 

932 pulse_params=pulse_params[layer], 

933 noise_params=noise_params, 

934 random_key=sub_key, 

935 gate_mode=gate_mode, 

936 ) 

937 

938 random_key, sub_key = safe_random_split(random_key) 

939 # encoding layers 

940 self._iec( 

941 inputs, 

942 data_reupload=self.data_reupload[layer], 

943 enc=self._enc, 

944 enc_params=enc_params[layer], 

945 noise_params=noise_params, 

946 random_key=sub_key, 

947 ) 

948 

949 # final ansatz layer 

950 if self.has_dru: # same check as in init 

951 random_key, sub_key = safe_random_split(random_key) 

952 self.pqc( 

953 params[self.n_layers], 

954 self.n_qubits, 

955 pulse_params=pulse_params[-1], 

956 noise_params=noise_params, 

957 random_key=sub_key, 

958 gate_mode=gate_mode, 

959 ) 

960 

961 # channel noise 

962 if noise_params is not None: 

963 self._apply_general_noise(noise_params=noise_params) 

964 

def _build_obs(self) -> Tuple[str, List[op.Operation]]:
    """Translate ``execution_type``/``output_qubit`` into jaqsi measurement args.

    The returned pair is intended for
    :meth:`~qml_essentials.jaqsi.Script.execute`.

    Returns:
        Tuple ``(meas_type, obs)``:
            *meas_type* is one of ``"expval"``, ``"probs"``, ``"density"`` or
            ``"state"``.
            *obs* holds one observable per entry of ``output_qubit`` for
            ``"expval"``, and is empty for the other modes.

    Raises:
        ValueError: If ``execution_type`` is none of the four modes above.
    """
    mode = self.execution_type

    if mode == "expval":
        observables: List[op.Operation] = []
        for spec in self.output_qubit:
            if isinstance(spec, int):
                # single qubit -> Pauli-Z expectation value
                observables.append(op.PauliZ(wires=spec))
            else:
                # group of qubits -> parity observable Z x Z x ...
                observables.append(js.build_parity_observable(list(spec)))
        return "expval", observables

    # These modes act on the full register and need no observables.
    # For "probs", marginalisation onto subsystems is done after execution
    # (in ``_forward``).
    for full_register_mode in ("density", "state", "probs"):
        if mode == full_register_mode:
            return full_register_mode, []

    raise ValueError(f"Invalid execution_type: {self.execution_type}.")

999 

1000 def _apply_state_prep_noise( 

1001 self, noise_params: Dict[str, Union[float, Dict[str, float]]] 

1002 ) -> None: 

1003 """ 

1004 Apply state preparation noise to all qubits. 

1005 

1006 Simulates imperfect state preparation by applying BitFlip errors 

1007 to each qubit with the specified probability. 

1008 

1009 Args: 

1010 noise_params (Dict[str, Union[float, Dict[str, float]]]): Dictionary 

1011 containing noise parameters. Uses the "StatePreparation" key 

1012 for the BitFlip probability. 

1013 

1014 Returns: 

1015 None: Noise channels are applied in-place to the circuit. 

1016 """ 

1017 p = noise_params.get("StatePreparation", 0.0) 

1018 if p > 0: 

1019 for q in range(self.n_qubits): 

1020 op.BitFlip(p, wires=q) 

1021 

1022 def _apply_general_noise( 

1023 self, noise_params: Dict[str, Union[float, Dict[str, float]]] 

1024 ) -> None: 

1025 """ 

1026 Apply general noise channels to all qubits. 

1027 

1028 Applies various decoherence and error channels after the circuit 

1029 execution, simulating environmental noise effects. 

1030 

1031 Args: 

1032 noise_params (Dict[str, Union[float, Dict[str, float]]]): Dictionary 

1033 containing noise parameters with the following supported keys: 

1034 - "AmplitudeDamping" (float): Probability for amplitude damping. 

1035 - "PhaseDamping" (float): Probability for phase damping. 

1036 - "Measurement" (float): Probability for measurement error (BitFlip). 

1037 - "ThermalRelaxation" (Dict): Dictionary with keys "t1", "t2", 

1038 "t_factor" for thermal relaxation simulation. 

1039 

1040 Returns: 

1041 None: Noise channels are applied in-place to the circuit. 

1042 

1043 Note: 

1044 Gate-level noise (e.g., GateError) is handled separately in the 

1045 Gates.Noise module and applied at the individual gate level. 

1046 """ 

1047 amp_damp = noise_params.get("AmplitudeDamping", 0.0) 

1048 phase_damp = noise_params.get("PhaseDamping", 0.0) 

1049 thermal_relax = noise_params.get("ThermalRelaxation", 0.0) 

1050 meas = noise_params.get("Measurement", 0.0) 

1051 for q in range(self.n_qubits): 

1052 if amp_damp > 0: 

1053 op.AmplitudeDamping(amp_damp, wires=q) 

1054 if phase_damp > 0: 

1055 op.PhaseDamping(phase_damp, wires=q) 

1056 if meas > 0: 

1057 op.BitFlip(meas, wires=q) 

1058 if isinstance(thermal_relax, dict): 

1059 t1 = thermal_relax["t1"] 

1060 t2 = thermal_relax["t2"] 

1061 t_factor = thermal_relax["t_factor"] 

1062 circuit_depth = self._get_circuit_depth() 

1063 tg = circuit_depth * t_factor 

1064 op.ThermalRelaxationError(1.0, t1, t2, tg, q) 

1065 

1066 def _get_circuit_depth(self, inputs: Optional[jnp.ndarray] = None) -> int: 

1067 """ 

1068 Calculate the depth of the quantum circuit. 

1069 

1070 Records the circuit onto a tape (without noise) and computes the 

1071 depth as the length of the critical path: each gate is scheduled 

1072 at the earliest time step after all of its qubits are free. 

1073 

1074 Args: 

1075 inputs (Optional[jnp.ndarray]): Input data for circuit evaluation. 

1076 If None, default zero inputs are used. 

1077 

1078 Returns: 

1079 int: The circuit depth (longest path of gates in the circuit). 

1080 """ 

1081 # Return cached value if available 

1082 if hasattr(self, "_cached_circuit_depth"): 

1083 return self._cached_circuit_depth 

1084 

1085 inputs = self._inputs_validation(inputs) 

1086 

1087 # Temporarily clear noise_params to prevent _variational from 

1088 # picking them up (which would call _apply_general_noise -> 

1089 # _get_circuit_depth again, causing infinite recursion). 

1090 saved_noise = self._noise_params 

1091 self._noise_params = None 

1092 

1093 with recording() as tape: 

1094 self._variational( 

1095 self.params[0] if self.params.ndim == 3 else self.params, 

1096 inputs[0] if inputs.ndim == 2 else inputs, 

1097 noise_params=None, 

1098 ) 

1099 

1100 self._noise_params = saved_noise 

1101 

1102 # Filter out noise channels - only count unitary gates 

1103 ops = [o for o in tape if not isinstance(o, KrausChannel)] 

1104 

1105 if not ops: 

1106 self._cached_circuit_depth = 0 

1107 return 0 

1108 

1109 # Schedule each gate at the earliest time step where all its wires 

1110 # are free. ``wire_busy[q]`` tracks the next free time step for 

1111 # qubit ``q``. 

1112 wire_busy: Dict[int, int] = {} 

1113 depth = 0 

1114 for gate in ops: 

1115 start = max((wire_busy.get(w, 0) for w in gate.wires), default=0) 

1116 end = start + 1 

1117 for w in gate.wires: 

1118 wire_busy[w] = end 

1119 depth = max(depth, end) 

1120 

1121 self._cached_circuit_depth = depth 

1122 return depth 

1123 

def draw(
    self,
    inputs: Optional[jnp.ndarray] = None,
    figure: str = "text",
    **kwargs: Any,
) -> Union[str, Any]:
    """Visualize the quantum circuit.

    Records the circuit tape (without noise) and renders the gate
    sequence using the requested backend.

    Args:
        inputs (Optional[jnp.ndarray]): Input data for the circuit.
            If ``None``, default zero inputs are used.
        figure (str): Rendering backend. One of:

            * ``"text"`` - ASCII art (returned as a ``str``).
            * ``"mpl"`` - Matplotlib figure (returns ``(fig, ax)``).
            * ``"tikz"`` - LaTeX/TikZ ``quantikz`` code (returns a
              :class:`TikzFigure`).
            * ``"pulse"`` - Pulse schedule (returns ``(fig, axes)``).
              Only meaningful for pulse-mode models; delegates to
              :meth:`draw_pulse`.

        **kwargs: Extra options forwarded to the drawing backend
            (e.g. ``gate_values=True``).

    Returns:
        Depends on figure:

        * ``"text"`` -> ``str``
        * ``"mpl"`` -> ``(matplotlib.figure.Figure, matplotlib.axes.Axes)``
        * ``"tikz"`` -> :class:`TikzFigure`
        * ``"pulse"`` -> whatever :meth:`draw_pulse` returns

    Raises:
        ValueError: If figure is not one of the supported modes (validation
            is delegated to the jaqsi drawing backend).
    """
    inputs = self._inputs_validation(inputs)

    if figure == "pulse":
        return self.draw_pulse(inputs=inputs, **kwargs)

    params = self.params[0] if self.params.ndim == 3 else self.params
    inp = inputs[0] if inputs.ndim == 2 else inputs

    # Record without noise to get a clean circuit. ``finally`` guarantees the
    # model's noise settings survive a failing draw.
    saved_noise = self._noise_params
    self._noise_params = None
    try:
        draw_script = js.Script(f=self._variational, n_qubits=self.n_qubits)
        return draw_script.draw(
            figure=figure,
            args=(params, inp),
            kwargs={"noise_params": None},
            **kwargs,
        )
    finally:
        self._noise_params = saved_noise

1181 

def draw_pulse(
    self,
    inputs: Optional[jnp.ndarray] = None,
    **kwargs: Any,
) -> Any:
    """Visualize the pulse schedule for the circuit.

    Records the circuit in pulse mode (without noise) and collects
    PulseEvents via the pulse-event tape, then renders them.

    Args:
        inputs: Input data. If ``None``, default zero inputs are used.
        **kwargs: Forwarded to
            :func:`~qml_essentials.drawing.draw_pulse_schedule`
            (e.g. ``show_carrier=True``, ``n_samples=300``).

    Returns:
        ``(fig, axes)`` - Matplotlib Figure and array of Axes.
    """
    inputs = self._inputs_validation(inputs)
    params = self.params[0] if self.params.ndim == 3 else self.params
    inp = inputs[0] if inputs.ndim == 2 else inputs

    # Passing ``noise_params=None`` alone is not enough: ``_variational``
    # falls back to ``self.noise_params`` (and warns) when it receives None.
    # Hide the model's noise while recording, and restore it even on error.
    saved_noise = self._noise_params
    self._noise_params = None
    try:
        draw_script = js.Script(f=self._variational, n_qubits=self.n_qubits)
        return draw_script.draw(
            figure="pulse",
            args=(params, inp),
            kwargs={
                "gate_mode": "pulse",
                "noise_params": None,
            },
            **kwargs,
        )
    finally:
        self._noise_params = saved_noise

1215 

1216 def __repr__(self) -> str: 

1217 """Return text representation of the quantum circuit model.""" 

1218 return self.draw(figure="text") 

1219 

1220 def __str__(self) -> str: 

1221 """Return string representation of the quantum circuit model.""" 

1222 return self.draw(figure="text") 

1223 

1224 def _params_validation(self, params: Optional[jnp.ndarray]) -> jnp.ndarray: 

1225 """ 

1226 Validate and normalize variational parameters. 

1227 

1228 Ensures parameters have the correct shape with a batch dimension, 

1229 and updates the model's internal parameters if new ones are provided. 

1230 

1231 Args: 

1232 params (Optional[jnp.ndarray]): Variational parameters to validate. 

1233 If None, returns the model's current parameters. 

1234 

1235 Returns: 

1236 jnp.ndarray: Validated parameters with shape 

1237 (batch_size, n_layers, n_params_per_layer). 

1238 """ 

1239 # append batch axis if not provided 

1240 if params is not None: 

1241 if len(params.shape) == 2: 

1242 params = np.expand_dims(params, axis=0) 

1243 

1244 # Avoid stashing JAX tracers on ``self``: under an outer 

1245 # transform (e.g. ``jacrev``) the tracer becomes invalid once 

1246 # the transform returns, and a subsequent read of 

1247 # ``self.params`` would feed a leaked tracer into the next 

1248 # call (raising ``UnexpectedTracerError``). 

1249 # if not isinstance(params, jax.core.Tracer): 

1250 # self.params = params 

1251 self.params = params 

1252 else: 

1253 params = self.params 

1254 

1255 return params 

1256 

1257 def _pulse_params_validation( 

1258 self, pulse_params: Optional[jnp.ndarray] 

1259 ) -> jnp.ndarray: 

1260 """ 

1261 Validate and normalize pulse parameters. 

1262 

1263 Ensures pulse parameters are set, using model defaults if not provided. 

1264 

1265 Args: 

1266 pulse_params (Optional[jnp.ndarray]): Pulse parameter scalers. 

1267 If None, returns the model's current pulse parameters. 

1268 

1269 Returns: 

1270 jnp.ndarray: Validated pulse parameters with shape 

1271 (batch_size, n_layers, n_pulse_params_per_layer). 

1272 """ 

1273 if pulse_params is None: 

1274 pulse_params = self.pulse_params 

1275 else: 

1276 # ensure batch dimension exists (batch-first convention) 

1277 if len(pulse_params.shape) == 2: 

1278 pulse_params = jnp.expand_dims(pulse_params, axis=0) 

1279 # See note in _params_validation: never stash JAX tracers on 

1280 # ``self``. 

1281 # if not isinstance(pulse_params, jax.core.Tracer): 

1282 # self.pulse_params = pulse_params 

1283 self.pulse_params = pulse_params 

1284 

1285 return pulse_params 

1286 

1287 def _enc_params_validation(self, enc_params: Optional[jnp.ndarray]) -> jnp.ndarray: 

1288 """ 

1289 Validate and normalize encoding parameters. 

1290 

1291 Ensures encoding parameters have the correct shape for the model's 

1292 input feature dimensions. 

1293 

1294 Args: 

1295 enc_params (Optional[jnp.ndarray]): Encoding parameters to validate. 

1296 If None, returns the model's current encoding parameters. 

1297 

1298 Returns: 

1299 jnp.ndarray: Validated encoding parameters with shape 

1300 (n_qubits, n_input_feat). 

1301 

1302 Raises: 

1303 ValueError: If enc_params shape is incompatible with n_input_feat > 1. 

1304 """ 

1305 if enc_params is None: 

1306 enc_params = self.enc_params 

1307 else: 

1308 # See note in _params_validation: never stash JAX tracers on 

1309 # ``self``. 

1310 # if not isinstance(enc_params, jax.core.Tracer): 

1311 # if self.trainable_frequencies: 

1312 # self.enc_params = enc_params 

1313 # else: 

1314 # self.enc_params = jnp.array(enc_params) 

1315 if self.trainable_frequencies: 

1316 self.enc_params = enc_params 

1317 else: 

1318 self.enc_params = jnp.array(enc_params) 

1319 

1320 if len(enc_params.shape) == 1 and self.n_input_feat == 1: 

1321 enc_params = enc_params.reshape(-1, 1) 

1322 elif len(enc_params.shape) == 1 and self.n_input_feat > 1: 

1323 raise ValueError( 

1324 f"Input dimension {self.n_input_feat} >1 but \ 

1325 `enc_params` has shape {enc_params.shape}" 

1326 ) 

1327 

1328 return enc_params 

1329 

1330 def _inputs_validation( 

1331 self, inputs: Union[None, List, float, int, jnp.ndarray] 

1332 ) -> jnp.ndarray: 

1333 """ 

1334 Validate and normalize input data. 

1335 

1336 Converts various input formats to a standardized 2D array shape 

1337 suitable for batch processing in the quantum circuit. 

1338 

1339 Args: 

1340 inputs (Union[None, List, float, int, jnp.ndarray]): Input data in 

1341 various formats: 

1342 - None: Returns zeros with shape (1, n_input_feat) 

1343 - float/int: Single scalar value 

1344 - List: List of values or batched inputs 

1345 - jnp.ndarray: NumPy/JAX array 

1346 

1347 Returns: 

1348 jnp.ndarray: Validated inputs with shape (batch_size, n_input_feat). 

1349 

1350 Raises: 

1351 ValueError: If input shape is incompatible with expected n_input_feat. 

1352 

1353 Warns: 

1354 UserWarning: If input is replicated to match n_input_feat. 

1355 """ 

1356 self._zero_inputs = False 

1357 if isinstance(inputs, List): 

1358 inputs = jnp.array(np.stack(inputs)) 

1359 elif isinstance(inputs, float) or isinstance(inputs, int): 

1360 inputs = jnp.array([inputs]) 

1361 elif inputs is None: 

1362 inputs = jnp.array([[0] * self.n_input_feat]) 

1363 

1364 if not inputs.any(): 

1365 self._zero_inputs = True 

1366 

1367 if len(inputs.shape) <= 1: 

1368 if self.n_input_feat == 1: 

1369 # add a batch dimension 

1370 inputs = inputs.reshape(-1, 1) 

1371 else: 

1372 if inputs.shape[0] == self.n_input_feat: 

1373 inputs = inputs.reshape(1, -1) 

1374 else: 

1375 inputs = inputs.reshape(-1, 1) 

1376 inputs = inputs.repeat(self.n_input_feat, axis=1) 

1377 warnings.warn( 

1378 f"Expected {self.n_input_feat} inputs, but {inputs.shape[0]} " 

1379 "was provided, replicating input for all input features.", 

1380 UserWarning, 

1381 ) 

1382 else: 

1383 if inputs.shape[1] != self.n_input_feat: 

1384 raise ValueError( 

1385 f"Wrong number of inputs provided. Expected {self.n_input_feat} " 

1386 f"inputs, but input has shape {inputs.shape}." 

1387 ) 

1388 

1389 return inputs 

1390 

1391 def _postprocess_res(self, result: Union[List, jnp.ndarray]) -> jnp.ndarray: 

1392 """ 

1393 Post-process circuit execution results for uniform shape. 

1394 

1395 Converts list outputs (from multiple measurements) to stacked arrays 

1396 and reorders axes for consistent batch dimension placement. 

1397 

1398 Args: 

1399 result (Union[List, jnp.ndarray]): Raw circuit output, either a 

1400 list of measurement results or a single array. 

1401 

1402 Returns: 

1403 jnp.ndarray: Uniformly shaped result array with batch dimension first. 

1404 """ 

1405 if isinstance(result, list): 

1406 # we use moveaxis here because in case of parity measure, 

1407 # there is another dimension appended to the end and 

1408 # simply transposing would result in a wrong shape 

1409 result = jnp.stack(result) 

1410 if len(result.shape) > 1: 

1411 result = jnp.moveaxis(result, 0, 1) 

1412 return result 

1413 

1414 def _assimilate_batch( 

1415 self, 

1416 inputs: jnp.ndarray, 

1417 params: jnp.ndarray, 

1418 pulse_params: jnp.ndarray, 

1419 ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: 

1420 """ 

1421 Align batch dimensions across inputs, parameters, and pulse parameters. 

1422 

1423 Broadcasts and reshapes arrays to have compatible batch dimensions 

1424 for vectorized circuit execution. Sets the internal batch_shape. 

1425 

1426 Args: 

1427 inputs (jnp.ndarray): Input data of shape (B_I, n_input_feat). 

1428 params (jnp.ndarray): Parameters of shape (B_P, n_layers, n_params). 

1429 pulse_params (jnp.ndarray): Pulse params of shape (B_R, n_layers, n_pulse). 

1430 

1431 Returns: 

1432 Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: Tuple containing: 

1433 - inputs: Reshaped to (B, n_input_feat) where B = B_I * B_P * B_R 

1434 - params: Reshaped to (B, n_layers, n_params) 

1435 - pulse_params: Reshaped to (B, n_layers, n_pulse) 

1436 

1437 Note: 

1438 The effective batch shape depends on repeat_batch_axis configuration. 

1439 This is the only method that sets self._batch_shape. 

1440 """ 

1441 B_I = inputs.shape[0] 

1442 # we check for the product because there is a chance that 

1443 # there are no params. In this case we want B_P to be 1 

1444 B_P = 1 if 0 in params.shape else params.shape[0] 

1445 B_R = pulse_params.shape[0] 

1446 

1447 # THIS is the only place where we set the batch shape 

1448 self._batch_shape = (B_I, B_P, B_R) 

1449 B = np.prod(self.eff_batch_shape) 

1450 

1451 # [B_I, ...] -> [B_I, B_P, B_R, ...] -> [B, ...] 

1452 if B_I > 1 and self.repeat_batch_axis[0]: 

1453 if self.repeat_batch_axis[1]: 

1454 inputs = jnp.repeat(inputs[:, None, None, ...], B_P, axis=1) 

1455 if self.repeat_batch_axis[2]: 

1456 inputs = jnp.repeat(inputs, B_R, axis=2) 

1457 inputs = inputs.reshape(B, *inputs.shape[3:]) 

1458 

1459 # [B_P, ..., ...] -> [B_I, B_P, B_R, ..., ...] -> [B, ..., ...] 

1460 if B_P > 1 and self.repeat_batch_axis[1]: 

1461 # add B_I axis before first, and B_R axis after first batch dim 

1462 params = params[None, :, None, ...] # [B_I(=1), B_P, B_R(=1), ...] 

1463 if self.repeat_batch_axis[0]: 

1464 params = jnp.repeat(params, B_I, axis=0) # [B_I, B_P, 1, ...] 

1465 if self.repeat_batch_axis[2]: 

1466 params = jnp.repeat(params, B_R, axis=2) # [B_I, B_P, B_R, ...] 

1467 params = params.reshape(B, *params.shape[3:]) 

1468 

1469 # [B_R, ..., ...] -> [B_I, B_P, B_R, ..., ...] -> [B, ..., ...] 

1470 if B_R > 1 and self.repeat_batch_axis[2]: 

1471 # add B_I axis and B_P axis before B_R 

1472 pulse_params = pulse_params[None, None, ...] # [B_I(=1), B_P(=1), B_R, ...] 

1473 if self.repeat_batch_axis[0]: 

1474 pulse_params = jnp.repeat( 

1475 pulse_params, B_I, axis=0 

1476 ) # [B_I, 1, B_R, ...] 

1477 if self.repeat_batch_axis[1]: 

1478 pulse_params = jnp.repeat( 

1479 pulse_params, B_P, axis=1 

1480 ) # [B_I, B_P, B_R, ...] 

1481 pulse_params = pulse_params.reshape(B, *pulse_params.shape[3:]) 

1482 

1483 return inputs, params, pulse_params 

1484 

1485 def _requires_density(self) -> bool: 

1486 """ 

1487 Check if density matrix simulation is required. 

1488 

1489 Determines whether the circuit must be executed with the mixed-state 

1490 simulator based on execution type and noise configuration. 

1491 

1492 Returns: 

1493 bool: True if density matrix simulation is required, False otherwise. 

1494 Returns True if: 

1495 - execution_type is "density", or 

1496 - Any non-coherent noise channel has non-zero probability 

1497 """ 

1498 if self.execution_type == "density": 

1499 return True 

1500 

1501 if self.noise_params is None: 

1502 return False 

1503 

1504 coherent_noise = {"GateError"} 

1505 for k, v in self.noise_params.items(): 

1506 if k in coherent_noise: 

1507 continue 

1508 if v is not None and v > 0: 

1509 return True 

1510 return False 

1511 

1512 def __call__( 

1513 self, 

1514 params: Optional[jnp.ndarray] = None, 

1515 inputs: Optional[jnp.ndarray] = None, 

1516 pulse_params: Optional[jnp.ndarray] = None, 

1517 enc_params: Optional[jnp.ndarray] = None, 

1518 data_reupload: Union[bool, List[List[bool]], List[List[List[bool]]]] = None, 

1519 noise_params: Optional[Dict[str, Union[float, Dict[str, float]]]] = None, 

1520 execution_type: Optional[str] = None, 

1521 force_mean: bool = False, 

1522 gate_mode: str = "unitary", 

1523 ) -> jnp.ndarray: 

1524 """ 

1525 Execute the quantum circuit (callable interface). 

1526 

1527 Provides a convenient callable interface for circuit execution, 

1528 delegating to the _forward method. 

1529 

1530 Args: 

1531 params (Optional[jnp.ndarray]): Variational parameters of shape 

1532 (n_layers, n_params_per_layer) or (batch, n_layers, n_params_per_layer). 

1533 If None, uses model's internal parameters. 

1534 inputs (Optional[jnp.ndarray]): Input data of shape 

1535 (batch_size, n_input_feat). If None, uses zero inputs. 

1536 pulse_params (Optional[jnp.ndarray]): Pulse parameter scalers for 

1537 pulse-mode gate execution. 

1538 enc_params (Optional[jnp.ndarray]): Encoding parameters of shape 

1539 (n_qubits, n_input_feat). If None, uses model's encoding parameters. 

1540 data_reupload (Union[bool, List[List[bool]], List[List[List[bool]]]]): 

1541 Data reupload configuration. If None, uses previously set reupload 

1542 configuration. 

1543 noise_params (Optional[Dict[str, Union[float, Dict[str, float]]]]): 

1544 Noise configuration. If None, uses previously set noise parameters. 

1545 execution_type (Optional[str]): Measurement type: "expval", "density", 

1546 "probs", or "state". If None, uses current execution_type setting. 

1547 force_mean (bool): If True, averages results over measurement qubits. 

1548 Defaults to False. 

1549 gate_mode (str): Gate execution backend, "unitary" or "pulse". 

1550 Defaults to "unitary". 

1551 

1552 Returns: 

1553 jnp.ndarray: Circuit output with shape depending on execution_type: 

1554 - "expval": (n_output_qubits,) or scalar 

1555 - "density": (2^n_output, 2^n_output) 

1556 - "probs": (2^n_output,) or (n_pairs, 2^pair_size) 

1557 - "state": (2^n_qubits,) 

1558 """ 

1559 # Call forward method which handles the actual caching etc. 

1560 return self._forward( 

1561 params=params, 

1562 inputs=inputs, 

1563 pulse_params=pulse_params, 

1564 enc_params=enc_params, 

1565 data_reupload=data_reupload, 

1566 noise_params=noise_params, 

1567 execution_type=execution_type, 

1568 force_mean=force_mean, 

1569 gate_mode=gate_mode, 

1570 ) 

1571 

1572 def _forward( 

1573 self, 

1574 params: Optional[jnp.ndarray] = None, 

1575 inputs: Optional[jnp.ndarray] = None, 

1576 pulse_params: Optional[jnp.ndarray] = None, 

1577 enc_params: Optional[jnp.ndarray] = None, 

1578 data_reupload: Union[bool, List[List[bool]], List[List[List[bool]]]] = None, 

1579 noise_params: Optional[Dict[str, Union[float, Dict[str, float]]]] = None, 

1580 execution_type: Optional[str] = None, 

1581 force_mean: bool = False, 

1582 gate_mode: str = "unitary", 

1583 ) -> jnp.ndarray: 

1584 """ 

1585 Execute the quantum circuit forward pass. 

1586 

1587 Internal implementation of the forward pass that handles parameter 

1588 validation, batch alignment, and circuit execution routing. 

1589 

1590 Args: 

1591 params (Optional[jnp.ndarray]): Variational parameters of shape 

1592 (n_layers, n_params_per_layer) or 

1593 (batch, n_layers, n_params_per_layer). 

1594 If None, uses model's internal parameters. 

1595 inputs (Optional[jnp.ndarray]): Input data of shape 

1596 (batch_size, n_input_feat). 

1597 If None, uses zero inputs. 

1598 pulse_params (Optional[jnp.ndarray]): Pulse parameter scalers for 

1599 pulse-mode gate execution. 

1600 enc_params (Optional[jnp.ndarray]): Encoding parameters of shape 

1601 (n_qubits, n_input_feat). If None, uses model's encoding parameters. 

1602 data_reupload (Union[bool, List[List[bool]], List[List[List[bool]]]]): 

1603 Data reupload configuration. If None, uses previously set reupload 

1604 configuration. 

1605 noise_params (Optional[Dict[str, Union[float, Dict[str, float]]]]): 

1606 Noise configuration. If None, uses previously set noise parameters. 

1607 execution_type (Optional[str]): Measurement type: "expval", "density", 

1608 "probs", or "state". If None, uses current execution_type setting. 

1609 force_mean (bool): If True, averages results over measurement qubits. 

1610 Defaults to False. 

1611 gate_mode (str): Gate execution backend, "unitary" or "pulse". 

1612 Defaults to "unitary". 

1613 

1614 Returns: 

1615 jnp.ndarray: Circuit output with shape depending on execution_type: 

1616 - "expval": (n_output_qubits,) or scalar 

1617 - "density": (2^n_output, 2^n_output) 

1618 - "probs": (2^n_output,) or (n_pairs, 2^pair_size) 

1619 - "state": (2^n_qubits,) 

1620 

1621 Raises: 

1622 ValueError: If pulse_params provided without pulse gate_mode, or 

1623 if noise_params provided with pulse gate_mode. 

1624 """ 

1625 # set the parameters as object attributes 

1626 if noise_params is not None: 

1627 self.noise_params = noise_params 

1628 if execution_type is not None: 

1629 self.execution_type = execution_type 

1630 self.gate_mode = gate_mode 

1631 

1632 # consistency checks 

1633 if pulse_params is not None and gate_mode != "pulse": 

1634 raise ValueError( 

1635 "pulse_params were provided but gate_mode is not 'pulse'. " 

1636 "Either switch gate_mode='pulse' or do not pass pulse_params." 

1637 ) 

1638 

1639 # TODO: add testing 

1640 if data_reupload is not None: 

1641 self.data_reupload = data_reupload 

1642 

1643 params = self._params_validation(params) 

1644 pulse_params = self._pulse_params_validation(pulse_params) 

1645 inputs = self._inputs_validation(inputs) 

1646 enc_params = self._enc_params_validation(enc_params) 

1647 

1648 inputs, params, pulse_params = self._assimilate_batch( 

1649 inputs, 

1650 params, 

1651 pulse_params, 

1652 ) 

1653 

1654 # split to generate a sub_key, required for actual execution 

1655 self.random_key, sub_key = safe_random_split(self.random_key) 

1656 

1657 # Build measurement type & observables from execution_type / output_qubit 

1658 meas_type, obs = self._build_obs() 

1659 

1660 # Jaqsi auto-routes between statevector and density-matrix simulation 

1661 # based on whether noise channels appear on the tape, so a single 

1662 B = np.prod(self.eff_batch_shape) 

1663 

1664 # kwargs are broadcast (not vmapped over) 

1665 exec_kwargs = dict( 

1666 noise_params=self.noise_params, 

1667 gate_mode=self.gate_mode, 

1668 ) 

1669 

1670 # Build a shot key from the random_key if shots are requested 

1671 shot_key = None 

1672 if self.shots is not None: 

1673 # overwrite subkey and split shot_key 

1674 sub_key, shot_key = safe_random_split(sub_key) 

1675 

1676 if B > 1: 

1677 # use random keys, derived from the subkey 

1678 random_keys = safe_random_split(sub_key, num=B) 

1679 

1680 in_axes = ( 

1681 0 if self.batch_shape[1] > 1 else None, # params 

1682 0 if self.batch_shape[0] > 1 else None, # inputs 

1683 0 if self.batch_shape[2] > 1 else None, # pulse_params 

1684 0, # random_keys 

1685 None, # enc_params (broadcast, not batched) 

1686 ) 

1687 

1688 result = self.script.execute( 

1689 type=meas_type, 

1690 obs=obs, 

1691 args=(params, inputs, pulse_params, random_keys, enc_params), 

1692 kwargs=exec_kwargs, 

1693 in_axes=in_axes, 

1694 shots=self.shots, 

1695 key=shot_key, 

1696 ) 

1697 else: 

1698 # use the subkey directly 

1699 result = self.script.execute( 

1700 type=meas_type, 

1701 obs=obs, 

1702 args=(params, inputs, pulse_params, sub_key, enc_params), 

1703 kwargs=exec_kwargs, 

1704 shots=self.shots, 

1705 key=shot_key, 

1706 ) 

1707 

1708 result = self._postprocess_res(result) 

1709 

1710 # --- Post-processing for partial-qubit measurements --------------- 

1711 if self.execution_type == "density" and not self.all_qubit_measurement: 

1712 result = js.partial_trace(result, self.n_qubits, self.output_qubit) 

1713 

1714 if self.execution_type == "probs" and not self.all_qubit_measurement: 

1715 if isinstance(self.output_qubit[0], (list, tuple)): 

1716 # list of qubit groups - marginalize each independently 

1717 result = jnp.stack( 

1718 [ 

1719 js.marginalize_probs(result, self.n_qubits, list(group)) 

1720 for group in self.output_qubit 

1721 ] 

1722 ) 

1723 else: 

1724 result = js.marginalize_probs(result, self.n_qubits, self.output_qubit) 

1725 

1726 result = jnp.asarray(result) 

1727 result = result.reshape((*self.eff_batch_shape, *self._result_shape)).squeeze() 

1728 

1729 if ( 

1730 self.execution_type in ("expval", "probs") 

1731 and force_mean 

1732 and len(result.shape) > 0 

1733 and self._result_shape[0] > 1 

1734 ): 

1735 result = result.mean(axis=-1) 

1736 

1737 return result