Coverage for qml_essentials / model.py: 90%

511 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-05-27 15:44 +0000

1from typing import Any, Dict, Optional, Tuple, Callable, Union, List 

2 

3from qml_essentials import operations as op 

4from qml_essentials import yaqsi as ys 

5import warnings 

6import jax.numpy as jnp 

7import numpy as np 

8from jax import random 

9 

10from qml_essentials.tape import recording 

11from qml_essentials.operations import KrausChannel 

12from qml_essentials.ansaetze import Ansaetze, Circuit, Encoding 

13from qml_essentials.gates import Gates, PulseInformation as pinfo 

14from qml_essentials.utils import safe_random_split 

15 

16import logging 

17 

18log = logging.getLogger(__name__) 

19 

20 

21class Model: 

22 """ 

23 A quantum circuit model. 

24 """ 

25 

def __init__(
    self,
    n_qubits: int,
    n_layers: int,
    circuit_type: Union[str, Circuit] = "No_Ansatz",
    data_reupload: Union[bool, List[List[bool]], List[List[List[bool]]]] = True,
    state_preparation: Union[
        str, Callable, List[Union[str, Callable]], None
    ] = None,
    encoding: Union[Encoding, str, Callable, List[Union[str, Callable]]] = Gates.RX,
    trainable_frequencies: bool = False,
    initialization: str = "random",
    initialization_domain: List[float] = [0, 2 * jnp.pi],
    output_qubit: Union[List[int], int] = -1,
    shots: Optional[int] = None,
    random_seed: int = 1000,
    remove_zero_encoding: bool = True,
    repeat_batch_axis: List[bool] = [True, True, True],
    pulse_shape: str = "gaussian",
) -> None:
    """
    Initialize the quantum circuit model.

    The variational parameters are created with shape
    ``(1, impl_n_layers, params_per_layer)``, where ``impl_n_layers`` is
    ``n_layers + 1`` if the model has data re-uploading (``has_dru``) and
    ``n_layers`` otherwise, and ``params_per_layer`` is given by the
    chosen ansatz.

    The model is initialized with the following defaults:
    - noise_params: None
    - execution_type: "expval"

    Args:
        n_qubits (int): The number of qubits in the circuit.
        n_layers (int): The number of layers in the circuit.
        circuit_type (str, Circuit): The type of quantum circuit to use.
            A string is looked up on ``Ansaetze``; anything else is
            called to construct the ansatz. ``None`` or an empty string
            falls back to "No_Ansatz".
        data_reupload (Union[bool, List[List[bool]], List[List[List[bool]]]]):
            Whether to re-upload data on each layer and qubit. Detailed
            instructions can be given as a list/array of 0/False and
            1/True with shape (n_layers, n_qubits) or
            (n_layers, n_qubits, n_input_feat). Defaults to True for
            applying data re-uploading to the full circuit.
        state_preparation (Union[str, Callable, List[...], None]):
            Gate(s) applied to every qubit before the first layer.
            Parsed by ``Gates.parse_gates``. Defaults to None.
        encoding (Union[Encoding, str, Callable, List[...]], optional):
            The unitary (or list of unitaries, one per input feature) used
            to encode the input data, or a ready-made ``Encoding``
            instance. Anything that is not an ``Encoding`` is wrapped in
            ``Encoding("hamming", encoding)``. Defaults to ``Gates.RX``.
        trainable_frequencies (bool, optional):
            Sets trainable encoding parameters for trainable frequencies.
            Defaults to False.
        initialization (str, optional): The strategy to initialize the
            parameters. Can be "random", "zeros", "zero-controlled", "pi",
            or "pi-controlled". Defaults to "random".
        initialization_domain (List[float], optional): ``[min, max]`` used
            for random initialization. Defaults to ``[0, 2 * pi]``.
        output_qubit (List[int], int, optional): The index of the output
            qubit (or qubits). When set to -1 all qubits are measured.
        shots (Optional[int], optional): The number of shots to use.
            Defaults to None.
        random_seed (int, optional): Seed for the JAX random key used for
            parameter initialization. Defaults to 1000.
        remove_zero_encoding (bool, optional): Whether to remove the zero
            encoding from the circuit. Defaults to True.
        repeat_batch_axis (List[bool], optional): Each boolean determines
            over which axes to parallelise computation. The axes
            correspond to [inputs, params, pulse_params]. Defaults to
            [True, True, True].
        pulse_shape (str, optional): Pulse envelope shape for pulse-level
            simulation, forwarded to ``PulseInformation.set_envelope``.
            Defaults to ``"gaussian"``.

    Raises:
        ValueError: If ``state_preparation`` cannot be parsed.

    Returns:
        None
    """
    # Initialize default parameters needed for circuit evaluation.
    # Order matters: the output_qubit and execution_type setters read
    # n_qubits, output_qubit and shots.
    self.n_qubits: int = n_qubits
    self.output_qubit = output_qubit
    self.n_layers: int = n_layers
    self.noise_params = None
    self.shots = shots
    self.remove_zero_encoding = remove_zero_encoding
    self.trainable_frequencies: bool = trainable_frequencies
    self.execution_type = "expval"
    # Copy: the default is a single list object shared by every call, so
    # storing it directly would let one model's mutation leak into others.
    self.repeat_batch_axis: List[bool] = list(repeat_batch_axis)

    # --- Pulse envelope ---
    pinfo.set_envelope(pulse_shape)

    # --- State Preparation ---
    try:
        self._sp = Gates.parse_gates(state_preparation, Gates)
    except ValueError as e:
        raise ValueError(f"Error parsing state preparation: {e}") from e

    # prepare corresponding pulse parameters (always optimized pulses)
    self.sp_pulse_params = []
    for sp in self._sp:
        sp_name = sp.__name__ if hasattr(sp, "__name__") else str(sp)
        pulse_gate = pinfo.gate_by_name(sp_name)
        # None marks a gate without pulse parametrization
        self.sp_pulse_params.append(
            pulse_gate.params if pulse_gate is not None else None
        )

    # --- Encoding ---
    if isinstance(encoding, Encoding):
        # user-supplied custom strategy
        self._enc = encoding
    else:
        # use hamming encoding by default
        self._enc = Encoding("hamming", encoding)

    if self._enc.is_golomb:
        self._enc._n_qubits = n_qubits

    # Number of possible inputs
    self.n_input_feat = len(self._enc)
    log.debug(f"Number of input features: {self.n_input_feat}")

    # Trainable frequencies, default initialization as in arXiv:2309.03279v2
    self.enc_params = jnp.ones((self.n_layers, self.n_qubits, self.n_input_feat))

    self._zero_inputs = False

    # --- Data-Reuploading ---
    # Setting this also updates self.degree, self.frequencies and, in
    # consequence, self.has_dru (see the data_reupload setter).
    self.data_reupload = data_reupload

    if self.has_dru:
        impl_n_layers: int = n_layers + 1  # we need L+1 according to Schuld et al.
    else:
        impl_n_layers = n_layers
    log.info(f"Number of implicit layers: {impl_n_layers}.")

    # --- Ansatz ---
    # only weak check for str. We trust the user to provide sth useful
    if circuit_type is None or isinstance(circuit_type, str):
        self.pqc: Callable[[Optional[jnp.ndarray], int], int] = getattr(
            Ansaetze, circuit_type or "No_Ansatz"
        )()
    else:
        self.pqc = circuit_type()
    log.info(f"Using Ansatz {circuit_type}.")

    # calculate the shape of the parameter vector here, we will re-use this in init.
    params_per_layer = self.pqc.n_params_per_layer(self.n_qubits)
    self._params_shape: Tuple[int, int] = (impl_n_layers, params_per_layer)
    log.info(f"Parameters per layer: {params_per_layer}")

    pulse_params_per_layer = self.pqc.n_pulse_params_per_layer(self.n_qubits)
    self._pulse_params_shape: Tuple[int, int] = (
        impl_n_layers,
        pulse_params_per_layer,
    )

    # initialize to None as we can't know this yet
    self._batch_shape = None

    # these are re-used by initialize_params when nothing is provided
    self._inialization_strategy = initialization
    # Copy for the same reason as repeat_batch_axis above.
    self._initialization_domain = list(initialization_domain)

    # initialize_params only requires a JAX random key and returns the
    # advanced key
    self.random_key = self.initialize_params(random.key(random_seed))

    # Initializing pulse params
    self.pulse_params = jnp.ones((1, *self._pulse_params_shape))

    log.info(f"Initialized pulse parameters with shape {self.pulse_params.shape}.")

    # Initialise the yaqsi Script that wraps _variational.
    # No device selection needed - yaqsi auto-routes between statevector
    # and density-matrix simulation based on whether noise channels are
    # present on the tape.
    self.script = ys.Script(f=self._variational, n_qubits=self.n_qubits)

211 

212 @property 

213 def noise_params(self) -> Optional[Dict[str, Union[float, Dict[str, float]]]]: 

214 """ 

215 Gets the noise parameters of the model. 

216 

217 Returns: 

218 Optional[Dict[str, float]]: A dictionary of 

219 noise parameters or None if not set. 

220 """ 

221 return self._noise_params 

222 

223 @noise_params.setter 

224 def noise_params( 

225 self, kvs: Optional[Dict[str, Union[float, Dict[str, float]]]] 

226 ) -> None: 

227 """ 

228 Sets the noise parameters of the model. 

229 

230 Typically a "noise parameter" refers to the error probability. 

231 ThermalRelaxation is a special case, and supports a dict as value with 

232 structure: 

233 "ThermalRelaxation": 

234 { 

235 "t1": 2000, # relative t1 time. 

236 "t2": 1000, # relative t2 time 

237 "t_factor" 1: # relative gate time factor 

238 }, 

239 

240 Args: 

241 kvs (Optional[Dict[str, Union[float, Dict[str, float]]]]): A 

242 dictionary of noise parameters. If all values are 0.0, the noise 

243 parameters are set to None. 

244 

245 Returns: 

246 None 

247 """ 

248 # set to None if only zero values provided 

249 if kvs is not None and all(v == 0.0 for v in kvs.values()): 

250 kvs = None 

251 

252 # set default values 

253 if kvs is not None: 

254 defaults = { 

255 "BitFlip": 0.0, 

256 "PhaseFlip": 0.0, 

257 "Depolarizing": 0.0, 

258 "MultiQubitDepolarizing": 0.0, 

259 "AmplitudeDamping": 0.0, 

260 "PhaseDamping": 0.0, 

261 "GateError": 0.0, 

262 "ThermalRelaxation": None, 

263 "StatePreparation": 0.0, 

264 "Measurement": 0.0, 

265 } 

266 for key, default_val in defaults.items(): 

267 kvs.setdefault(key, default_val) 

268 

269 # check if there are any keys not supported 

270 for key in kvs.keys(): 

271 if key not in defaults: 

272 warnings.warn( 

273 f"Noise type {key} is not supported by this package", 

274 UserWarning, 

275 ) 

276 

277 # check valid params for thermal relaxation noise channel 

278 tr_params = kvs["ThermalRelaxation"] 

279 if isinstance(tr_params, dict): 

280 tr_params.setdefault("t1", 0.0) 

281 tr_params.setdefault("t2", 0.0) 

282 tr_params.setdefault("t_factor", 0.0) 

283 valid_tr_keys = {"t1", "t2", "t_factor"} 

284 for k in tr_params.keys(): 

285 if k not in valid_tr_keys: 

286 warnings.warn( 

287 f"Thermal Relaxation parameter {k} is not supported " 

288 f"by this package", 

289 UserWarning, 

290 ) 

291 if not all(tr_params.values()) or tr_params["t2"] > 2 * tr_params["t1"]: 

292 warnings.warn( 

293 "Received invalid values for Thermal Relaxation noise " 

294 "parameter. Thermal relaxation is not applied!", 

295 UserWarning, 

296 ) 

297 kvs["ThermalRelaxation"] = 0.0 

298 

299 self._noise_params = kvs 

300 

301 @property 

302 def output_qubit(self) -> List[int]: 

303 """Get the output qubit indices for measurement.""" 

304 return self._output_qubit 

305 

306 @output_qubit.setter 

307 def output_qubit(self, value: Union[int, List[int]]) -> None: 

308 """ 

309 Set the output qubit(s) for measurement. 

310 

311 Args: 

312 value: Qubit index or list of indices. Use -1 for all qubits. 

313 """ 

314 if isinstance(value, list): 

315 assert len(value) <= self.n_qubits, ( 

316 f"Size of output_qubit {len(value)} cannot be\ 

317 larger than number of qubits {self.n_qubits}." 

318 ) 

319 elif isinstance(value, int): 

320 if value == -1: 

321 value = list(range(self.n_qubits)) 

322 else: 

323 assert value < self.n_qubits, ( 

324 f"Output qubit {value} cannot be larger than {self.n_qubits}." 

325 ) 

326 value = [value] 

327 

328 self._output_qubit = value 

329 

330 @property 

331 def execution_type(self) -> str: 

332 """ 

333 Gets the execution type of the model. 

334 

335 Returns: 

336 str: The execution type, one of 'density', 'expval', or 'probs'. 

337 """ 

338 return self._execution_type 

339 

340 @execution_type.setter 

341 def execution_type(self, value: str) -> None: 

342 if value == "density": 

343 self._result_shape = ( 

344 2 ** len(self.output_qubit), 

345 2 ** len(self.output_qubit), 

346 ) 

347 elif value == "expval": 

348 # check if all qubits are used 

349 if len(self.output_qubit) == self.n_qubits: 

350 self._result_shape = (len(self.output_qubit),) 

351 # if not -> parity measurement with only 1D output per pair 

352 # or n_local measurement 

353 else: 

354 self._result_shape = (len(self.output_qubit),) 

355 elif value == "probs": 

356 # in case this is a list of parities, 

357 # each pair has 2^len(qubits) probabilities 

358 n_parity = ( 

359 (2,) * len(self.output_qubit) 

360 if isinstance(self.output_qubit, (Tuple, List)) 

361 else (2,) 

362 ) 

363 self._result_shape = n_parity 

364 elif value == "state": 

365 self._result_shape = (2 ** len(self.output_qubit),) 

366 else: 

367 raise ValueError(f"Invalid execution type: {value}.") 

368 

369 if value == "state" and not self.all_qubit_measurement: 

370 warnings.warn( 

371 f"{value} measurement does ignore output_qubit, which is " 

372 f"{self.output_qubit}.", 

373 UserWarning, 

374 ) 

375 

376 if value == "probs" and self.shots is None: 

377 warnings.warn( 

378 "Setting execution_type to probs without specifying shots.", 

379 UserWarning, 

380 ) 

381 

382 if value == "density" and self.shots is not None: 

383 raise ValueError("Setting execution_type to density with shots not None.") 

384 

385 self._execution_type = value 

386 

387 @property 

388 def shots(self) -> Optional[int]: 

389 """ 

390 Gets the number of shots to use for the quantum device. 

391 

392 Returns: 

393 Optional[int]: The number of shots. 

394 """ 

395 return self._shots 

396 

397 @shots.setter 

398 def shots(self, value: Optional[int]) -> None: 

399 """ 

400 Sets the number of shots to use for the quantum device. 

401 

402 Args: 

403 value (Optional[int]): The number of shots. 

404 If an integer less than or equal to 0 is provided, it is set to None. 

405 

406 Returns: 

407 None 

408 """ 

409 if type(value) is int and value <= 0: 

410 value = None 

411 self._shots = value 

412 

413 @property 

414 def params(self) -> jnp.ndarray: 

415 """Get the variational parameters of the model.""" 

416 return self._params 

417 

418 @params.setter 

419 def params(self, value: jnp.ndarray) -> None: 

420 """Set the variational parameters, ensuring batch dimension exists.""" 

421 if len(value.shape) == 2: 

422 value = value.reshape(1, *value.shape) 

423 

424 self._params = value 

425 

426 @property 

427 def enc_params(self) -> jnp.ndarray: 

428 """Get the encoding parameters used for input transformation.""" 

429 return self._enc_params 

430 

431 @enc_params.setter 

432 def enc_params(self, value: jnp.ndarray) -> None: 

433 """Set the encoding parameters.""" 

434 self._enc_params = value 

435 

436 @property 

437 def pulse_params(self) -> jnp.ndarray: 

438 """Get the pulse parameters for pulse-mode gate execution.""" 

439 return self._pulse_params 

440 

441 @pulse_params.setter 

442 def pulse_params(self, value: jnp.ndarray) -> None: 

443 """Set the pulse parameters.""" 

444 self._pulse_params = value 

445 

446 @property 

447 def data_reupload(self) -> jnp.ndarray: 

448 """Get the data reupload mask.""" 

449 return self._data_reupload 

450 

@data_reupload.setter
def data_reupload(self, value: jnp.ndarray) -> None:
    """Set the data reupload mask.

    Always converts to a concrete NumPy boolean array so that
    ``if data_reupload[q, idx]`` in :meth:`_iec` remains a plain
    Python ``bool`` even inside JAX-traced functions (jit / grad / vmap).

    Accepted inputs:
        - ``True``: upload everywhere.
        - ``False``: upload only on layer 0, qubit 0 (all features).
        - array-like of shape ``(n_layers, n_qubits)``: repeated for
          every input feature.
        - array-like of shape ``(n_layers, n_qubits, n_input_feat)``.

    As a side effect this also recomputes ``degree``, ``frequencies``
    and the cached ``has_dru`` flag from the encoding.

    Raises:
        AssertionError: If an array-like value has an unexpected shape.
    """
    full_shape = (self.n_layers, self.n_qubits, self.n_input_feat)

    # np.bool_ scalars are treated like Python bools; previously they
    # fell into the array branch and failed the shape assertion.
    if isinstance(value, (bool, np.bool_)):
        if value:
            value = np.ones(full_shape)
            log.debug("Full data reuploading.")
        else:
            value = np.zeros(full_shape)
            value[0][0] = 1
            log.debug("No data reuploading.")
    else:
        if not isinstance(value, np.ndarray):
            value = np.array(value)

        if len(value.shape) == 2:
            # Messages are built from adjacent literals; a backslash
            # continuation inside the string would embed the source
            # indentation in the message.
            assert value.shape == (self.n_layers, self.n_qubits), (
                "Data reuploading array has wrong shape. "
                f"Expected {(self.n_layers, self.n_qubits)} or "
                f"{full_shape}, got {value.shape}."
            )
            # broadcast the per-(layer, qubit) mask over all features
            value = value.reshape(*value.shape, 1)
            value = np.repeat(value, self.n_input_feat, axis=2)

        assert value.shape == full_shape, (
            "Data reuploading array has wrong shape. "
            f"Expected {full_shape}, got {value.shape}."
        )

        log.debug(f"Data reuploading array:\n{value}")

    # convert to boolean values
    self._data_reupload = np.asarray(value).astype(bool)

    # number of encoding positions per input feature drives both the
    # degree and the spectrum reported by the encoding
    self.degree = tuple(
        self._enc.get_n_freqs(np.count_nonzero(self.data_reupload[..., i]))
        for i in range(self.n_input_feat)
    )

    self.frequencies = tuple(
        self._enc.get_spectrum(np.count_nonzero(self.data_reupload[..., i]))
        for i in range(self.n_input_feat)
    )

    # Cache has_dru as a plain Python bool so that it can be used in
    # Python ``if`` statements even inside JAX-traced functions.
    self._has_dru: bool = bool(max(int(np.max(f)) for f in self._frequencies) > 1)

513 

514 @property 

515 def degree(self) -> Tuple: 

516 """Get the degree of the model.""" 

517 return self._degree 

518 

519 @degree.setter 

520 def degree(self, value: Tuple): 

521 self._degree = value 

522 

523 @property 

524 def frequencies(self) -> Tuple: 

525 """Get the frequencies of the model.""" 

526 return self._frequencies 

527 

528 @frequencies.setter 

529 def frequencies(self, value: Tuple): 

530 self._frequencies = value 

531 

532 @property 

533 def has_dru(self) -> bool: 

534 """Check if the model has data reupload.""" 

535 return self._has_dru 

536 

537 @property 

538 def all_qubit_measurement(self) -> bool: 

539 """Check if measurement is performed on all qubits.""" 

540 return self.output_qubit == list(range(self.n_qubits)) 

541 

542 @property 

543 def batch_shape(self) -> Tuple[int, ...]: 

544 """ 

545 Get the batch shape (B_I, B_P, B_R). 

546 If the model was not called before, 

547 it returns (1, 1, 1). 

548 

549 Returns: 

550 Tuple[int, ...]: Tuple of (input_batch, param_batch, pulse_batch). 

551 Returns (1, 1, 1) if model has not been called yet. 

552 """ 

553 if self._batch_shape is None: 

554 log.debug("Model was not called yet. Returning (1,1,1) as batch shape.") 

555 return (1, 1, 1) 

556 return self._batch_shape 

557 

558 @property 

559 def eff_batch_shape(self) -> Tuple[int, ...]: 

560 """ 

561 Get the effective batch shape after applying repeat_batch_axis mask. 

562 

563 Returns: 

564 Tuple[int, ...]: Effective batch dimensions, excluding zeros. 

565 """ 

566 batch_shape = np.array(self.batch_shape) * self.repeat_batch_axis 

567 batch_shape = batch_shape[batch_shape != 0] 

568 return batch_shape 

569 

def initialize_params(
    self,
    random_key: Optional[random.PRNGKey] = None,
    repeat: int = 1,
    initialization: Optional[str] = None,
    initialization_domain: Optional[List[float]] = None,
) -> random.PRNGKey:
    """
    Initialize the variational parameters of the model.

    The new parameters are written to ``self.params`` with shape
    ``(repeat, *self._params_shape)``.

    Args:
        random_key (Optional[random.PRNGKey]): JAX random key for initialization.
            If None, uses the model's internal random key.
        repeat (int): Number of parameter sets to create (batch dimension).
            Defaults to 1.
        initialization (Optional[str]): Strategy for parameter initialization.
            Options: "random", "zeros", "pi", "zero-controlled", "pi-controlled".
            If None (or empty), uses the strategy specified in the constructor.
        initialization_domain (Optional[List[float]]): Domain [min, max] for
            random initialization. If None (or empty), uses the domain from
            the constructor.

    Returns:
        random.PRNGKey: Updated random key after initialization.

    Raises:
        ValueError: If an invalid initialization method is specified.
    """
    params_shape = (repeat, *self._params_shape)

    # use existing strategy / domain if not specified
    initialization = initialization or self._inialization_strategy
    # explicit None / length test so array-like domains do not hit an
    # ambiguous truth-value error
    if initialization_domain is None or len(initialization_domain) == 0:
        initialization_domain = self._initialization_domain

    random_key, sub_key = safe_random_split(
        random_key if random_key is not None else self.random_key
    )

    def set_control_params(params: jnp.ndarray, value: float) -> jnp.ndarray:
        """Overwrite the controlled-rotation parameters with ``value``."""
        indices = self.pqc.get_control_indices(self.n_qubits)
        if indices is None:
            warnings.warn(
                f"Specified {initialization} but circuit "
                "does not contain controlled rotation gates. "
                "Parameters are initialized randomly.",
                UserWarning,
            )
            return params
        # indices is used as (start, stop, step) along the last axis
        control_slice = slice(indices[0], indices[1], indices[2])
        return params.at[:, :, control_slice].set(value)

    # strategies that start from a uniform random draw, mapped to the
    # value written into the controlled-rotation slots (None: keep random)
    random_based = {
        "random": None,
        "zero-controlled": 0,
        "pi-controlled": jnp.pi,
    }

    if initialization in random_based:
        params = random.uniform(
            sub_key,
            params_shape,
            minval=initialization_domain[0],
            maxval=initialization_domain[1],
        )
        control_value = random_based[initialization]
        if control_value is not None:
            params = set_control_params(params, control_value)
    elif initialization == "zeros":
        params = jnp.zeros(params_shape)
    elif initialization == "pi":
        params = jnp.ones(params_shape) * jnp.pi
    else:
        # ValueError is an Exception subclass, so existing handlers that
        # catch Exception keep working
        raise ValueError(f"Invalid initialization method: {initialization}")

    self.params = params

    log.info(
        f"Initialized parameters with shape {self.params.shape} "
        f"using strategy {initialization}."
    )

    return random_key

662 

def transform_input(
    self, inputs: jnp.ndarray, enc_params: jnp.ndarray
) -> jnp.ndarray:
    """
    Scale the inputs with the encoding parameters.

    Implements the input transformation as described in arXiv:2309.03279v2:
    the inputs are multiplied by the encoding parameters before they are
    used in the quantum circuit.

    Args:
        inputs (jnp.ndarray): Input data to be scaled.
        enc_params (jnp.ndarray): Encoding weight(s) used to scale the
            input.

    Returns:
        jnp.ndarray: The product ``inputs * enc_params``.
    """
    scaled_inputs = inputs * enc_params
    return scaled_inputs

684 

def _iec(
    self,
    inputs: jnp.ndarray,
    data_reupload: jnp.ndarray,
    enc: Encoding,
    enc_params: jnp.ndarray,
    noise_params: Optional[Dict[str, Union[float, Dict[str, float]]]] = None,
    random_key: Optional[random.PRNGKey] = None,
) -> None:
    """
    Apply the Input Encoding Circuit (IEC) for one layer.

    For every (qubit, feature) position enabled in ``data_reupload`` the
    corresponding encoding gate ``enc[idx]`` is called with the scaled
    input. For Golomb encoding a single gate acting on all qubits is
    applied instead of per-qubit gates.

    Args:
        inputs (jnp.ndarray): Input data; the last axis is the feature
            axis.
        data_reupload (jnp.ndarray): Boolean mask indexed as
            ``[qubit, feature]`` indicating where to apply encoding gates.
        enc (Encoding): Encoding strategy providing the gate functions.
        enc_params (jnp.ndarray): Encoding parameters indexed as
            ``[qubit, feature]`` used to scale the inputs.
        noise_params (Optional[Dict[str, Union[float, Dict[str, float]]]]):
            Noise parameters forwarded to the encoding gates.
            Defaults to None.
        random_key (Optional[random.PRNGKey]): JAX random key; a fresh
            sub-key is split off for every applied gate. Defaults to None.

    Returns:
        None: Gates are applied through the calls to ``enc[idx]``.
    """
    # inputs cannot be None here (input validation), so zero inputs are
    # tracked via a flag instead
    skip_encoding = (
        self.remove_zero_encoding
        and self._zero_inputs
        and self.batch_shape[0] == 1
    )
    if skip_encoding:
        return

    # --- Golomb encoding: single multi-qubit gate on all qubits --------
    if enc.is_golomb:
        feature = 0  # Golomb encoding supports a single input feature
        # nothing to do unless some qubit has re-uploading enabled here
        if not data_reupload[:, feature].any():
            return

        random_key, gate_key = safe_random_split(random_key)
        # Golomb acts on all qubits jointly, so the per-qubit encoding
        # parameters are averaged into one scalar scaling factor
        joint_scale = jnp.mean(enc_params[:, feature])
        every_wire = list(range(self.n_qubits))
        enc[feature](
            self.transform_input(inputs[..., feature], joint_scale),
            wires=every_wire,
            noise_params=noise_params,
            random_key=gate_key,
            input_idx=feature,
        )
        return

    # --- Standard per-qubit encoding -----------------------------------
    n_features = inputs.shape[-1]  # feature axis is the last axis
    for qubit in range(self.n_qubits):
        for feature in range(n_features):
            if not data_reupload[qubit, feature]:
                continue
            # ellipsis indexes only the last axis, as inputs are
            # generally *not* qubit dependent
            random_key, gate_key = safe_random_split(random_key)
            scaled = self.transform_input(
                inputs[..., feature], enc_params[qubit, feature]
            )
            enc[feature](
                scaled,
                wires=qubit,
                noise_params=noise_params,
                random_key=gate_key,
                input_idx=feature,
            )

758 

def _variational(
    self,
    params: jnp.ndarray,
    inputs: jnp.ndarray,
    pulse_params: Optional[jnp.ndarray] = None,
    random_key: Optional[random.PRNGKey] = None,
    enc_params: Optional[jnp.ndarray] = None,
    gate_mode: str = "unitary",
    noise_params: Optional[Dict[str, Union[float, Dict[str, float]]]] = None,
) -> None:
    """
    Build the variational quantum circuit structure.

    Constructs the circuit by applying state preparation, alternating
    variational ansatz layers with input encoding layers, and optional
    noise channels.

    The first five parameters (after ``self``) - ``params``, ``inputs``,
    ``pulse_params``, ``random_key``, ``enc_params`` - are the batchable
    positional arguments.
    The remaining keyword arguments are broadcast across the batch.

    Args:
        params (jnp.ndarray): Variational parameters indexed by layer,
            i.e. shape (impl_n_layers, n_params_per_layer). A leading
            batch axis of size one is squeezed.
        inputs (jnp.ndarray): Input data with the feature axis last.
            A leading batch axis of size one is squeezed.
        pulse_params (Optional[jnp.ndarray]): Pulse parameter scalers indexed
            by layer for pulse-mode execution. A leading batch axis of
            size one is squeezed. Defaults to None (uses the model's
            pulse_params).
        random_key (Optional[random.PRNGKey]): JAX random key for stochastic
            operations. Defaults to None; if noise is active the model's
            ``random_key`` is used instead.
        enc_params (Optional[jnp.ndarray]): Encoding parameters indexed as
            ``[layer][qubit, feature]``, i.e. shape
            (n_layers, n_qubits, n_input_feat). Defaults to None (uses the
            model's enc_params).
        gate_mode (str): Gate execution mode, either "unitary" or "pulse".
            Defaults to "unitary".
        noise_params (Optional[Dict[str, Union[float, Dict[str, float]]]]):
            Noise parameters for simulation. Defaults to None (uses the
            model's noise_params).

    Returns:
        None: Gates are applied through the calls to the gate functions.

    Note:
        Issues RuntimeWarning if called directly without providing parameters
        that would normally be passed through the forward method.
    """
    # TODO: rework and double check params shape
    if len(params.shape) > 2 and params.shape[0] == 1:
        params = params[0]

    if len(inputs.shape) > 1 and inputs.shape[0] == 1:
        inputs = inputs[0]

    if enc_params is None:
        # Only warn when the user actually cares about trainable
        # frequencies; otherwise falling back is the expected behaviour.
        if self.trainable_frequencies:
            warnings.warn(
                "Explicit call to `_circuit` or `_variational` detected: "
                "`enc_params` is None, using `self.enc_params` instead.",
                RuntimeWarning,
            )
        enc_params = self.enc_params

    if pulse_params is None:
        if gate_mode == "pulse":
            warnings.warn(
                "Explicit call to `_circuit` or `_variational` detected: "
                "`pulse_params` is None, using `self.pulse_params` instead.",
                RuntimeWarning,
            )
        pulse_params = self.pulse_params

    # Squeeze batch dimension for pulse_params (batch-first convention)
    if len(pulse_params.shape) > 2 and pulse_params.shape[0] == 1:
        pulse_params = pulse_params[0]

    if noise_params is None:
        if self.noise_params is not None:
            warnings.warn(
                "Explicit call to `_circuit` or `_variational` detected: "
                "`noise_params` is None, using `self.noise_params` instead.",
                RuntimeWarning,
            )
        noise_params = self.noise_params

    if noise_params is not None:
        if random_key is None:
            # The message names the key that is really used below.
            warnings.warn(
                "Explicit call to `_circuit` or `_variational` detected: "
                "`random_key` is None, using `self.random_key` instead.",
                RuntimeWarning,
            )
            random_key = self.random_key
        self._apply_state_prep_noise(noise_params=noise_params)

    # state preparation: every configured gate on every qubit
    for q in range(self.n_qubits):
        for _sp, sp_pulse_params in zip(self._sp, self.sp_pulse_params):
            random_key, sub_key = safe_random_split(random_key)
            _sp(
                wires=q,
                pulse_params=sp_pulse_params,
                noise_params=noise_params,
                random_key=sub_key,
                gate_mode=gate_mode,
            )

    # circuit building: ansatz layer followed by encoding layer
    for layer in range(0, self.n_layers):
        random_key, sub_key = safe_random_split(random_key)
        # ansatz layers
        self.pqc(
            params[layer],
            self.n_qubits,
            pulse_params=pulse_params[layer],
            noise_params=noise_params,
            random_key=sub_key,
            gate_mode=gate_mode,
        )

        random_key, sub_key = safe_random_split(random_key)
        # encoding layers
        self._iec(
            inputs,
            data_reupload=self.data_reupload[layer],
            enc=self._enc,
            enc_params=enc_params[layer],
            noise_params=noise_params,
            random_key=sub_key,
        )

        # visual barrier (no-op in yaqsi, purely cosmetic in PennyLane)

    # final ansatz layer, only present when the extra (L+1) layer of
    # parameters was allocated
    if self.has_dru:  # same check as in init
        random_key, sub_key = safe_random_split(random_key)
        self.pqc(
            params[self.n_layers],
            self.n_qubits,
            pulse_params=pulse_params[-1],
            noise_params=noise_params,
            random_key=sub_key,
            gate_mode=gate_mode,
        )

    # channel noise
    if noise_params is not None:
        self._apply_general_noise(noise_params=noise_params)

907 

def _build_obs(self) -> Tuple[str, List[op.Operation]]:
    """Derive the yaqsi measurement type and observable list.

    Maps ``self.execution_type`` (and, for expectation values,
    ``self.output_qubit``) onto the ``(type, obs)`` pair passed to
    :meth:`~qml_essentials.yaqsi.Script.execute`.

    Returns:
        Tuple ``(meas_type, obs)``. *meas_type* is one of ``"expval"``,
        ``"probs"``, ``"density"``, ``"state"``. *obs* is empty for every
        type except ``"expval"``, where it holds one observable per entry
        of ``self.output_qubit``.

    Raises:
        ValueError: If ``self.execution_type`` is not one of the four
            supported measurement types.
    """
    mode = self.execution_type

    # These measurement types need no observables. For "probs" the full
    # system is measured here; subsystem marginalisation happens after
    # execution.
    for plain_type in ("density", "state", "probs"):
        if mode == plain_type:
            return plain_type, []

    if mode != "expval":
        raise ValueError(f"Invalid execution_type: {mode}.")

    # An int entry is measured with Pauli-Z on that wire; any other entry
    # is treated as a group of wires and measured through the parity
    # observable built by yaqsi.
    observables: List[op.Operation] = [
        (
            op.PauliZ(wires=spec)
            if isinstance(spec, int)
            else ys.build_parity_observable(list(spec))
        )
        for spec in self.output_qubit
    ]
    return "expval", observables

942 

943 def _apply_state_prep_noise( 

944 self, noise_params: Dict[str, Union[float, Dict[str, float]]] 

945 ) -> None: 

946 """ 

947 Apply state preparation noise to all qubits. 

948 

949 Simulates imperfect state preparation by applying BitFlip errors 

950 to each qubit with the specified probability. 

951 

952 Args: 

953 noise_params (Dict[str, Union[float, Dict[str, float]]]): Dictionary 

954 containing noise parameters. Uses the "StatePreparation" key 

955 for the BitFlip probability. 

956 

957 Returns: 

958 None: Noise channels are applied in-place to the circuit. 

959 """ 

960 p = noise_params.get("StatePreparation", 0.0) 

961 if p > 0: 

962 for q in range(self.n_qubits): 

963 op.BitFlip(p, wires=q) 

964 

965 def _apply_general_noise( 

966 self, noise_params: Dict[str, Union[float, Dict[str, float]]] 

967 ) -> None: 

968 """ 

969 Apply general noise channels to all qubits. 

970 

971 Applies various decoherence and error channels after the circuit 

972 execution, simulating environmental noise effects. 

973 

974 Args: 

975 noise_params (Dict[str, Union[float, Dict[str, float]]]): Dictionary 

976 containing noise parameters with the following supported keys: 

977 - "AmplitudeDamping" (float): Probability for amplitude damping. 

978 - "PhaseDamping" (float): Probability for phase damping. 

979 - "Measurement" (float): Probability for measurement error (BitFlip). 

980 - "ThermalRelaxation" (Dict): Dictionary with keys "t1", "t2", 

981 "t_factor" for thermal relaxation simulation. 

982 

983 Returns: 

984 None: Noise channels are applied in-place to the circuit. 

985 

986 Note: 

987 Gate-level noise (e.g., GateError) is handled separately in the 

988 Gates.Noise module and applied at the individual gate level. 

989 """ 

990 amp_damp = noise_params.get("AmplitudeDamping", 0.0) 

991 phase_damp = noise_params.get("PhaseDamping", 0.0) 

992 thermal_relax = noise_params.get("ThermalRelaxation", 0.0) 

993 meas = noise_params.get("Measurement", 0.0) 

994 for q in range(self.n_qubits): 

995 if amp_damp > 0: 

996 op.AmplitudeDamping(amp_damp, wires=q) 

997 if phase_damp > 0: 

998 op.PhaseDamping(phase_damp, wires=q) 

999 if meas > 0: 

1000 op.BitFlip(meas, wires=q) 

1001 if isinstance(thermal_relax, dict): 

1002 t1 = thermal_relax["t1"] 

1003 t2 = thermal_relax["t2"] 

1004 t_factor = thermal_relax["t_factor"] 

1005 circuit_depth = self._get_circuit_depth() 

1006 tg = circuit_depth * t_factor 

1007 op.ThermalRelaxationError(1.0, t1, t2, tg, q) 

1008 

1009 def _get_circuit_depth(self, inputs: Optional[jnp.ndarray] = None) -> int: 

1010 """ 

1011 Calculate the depth of the quantum circuit. 

1012 

1013 Records the circuit onto a tape (without noise) and computes the 

1014 depth as the length of the critical path: each gate is scheduled 

1015 at the earliest time step after all of its qubits are free. 

1016 

1017 Args: 

1018 inputs (Optional[jnp.ndarray]): Input data for circuit evaluation. 

1019 If None, default zero inputs are used. 

1020 

1021 Returns: 

1022 int: The circuit depth (longest path of gates in the circuit). 

1023 """ 

1024 # Return cached value if available 

1025 if hasattr(self, "_cached_circuit_depth"): 

1026 return self._cached_circuit_depth 

1027 

1028 inputs = self._inputs_validation(inputs) 

1029 

1030 # Temporarily clear noise_params to prevent _variational from 

1031 # picking them up (which would call _apply_general_noise -> 

1032 # _get_circuit_depth again, causing infinite recursion). 

1033 saved_noise = self._noise_params 

1034 self._noise_params = None 

1035 

1036 with recording() as tape: 

1037 self._variational( 

1038 self.params[0] if self.params.ndim == 3 else self.params, 

1039 inputs[0] if inputs.ndim == 2 else inputs, 

1040 noise_params=None, 

1041 ) 

1042 

1043 self._noise_params = saved_noise 

1044 

1045 # Filter out noise channels - only count unitary gates 

1046 ops = [o for o in tape if not isinstance(o, KrausChannel)] 

1047 

1048 if not ops: 

1049 self._cached_circuit_depth = 0 

1050 return 0 

1051 

1052 # Schedule each gate at the earliest time step where all its wires 

1053 # are free. ``wire_busy[q]`` tracks the next free time step for 

1054 # qubit ``q``. 

1055 wire_busy: Dict[int, int] = {} 

1056 depth = 0 

1057 for gate in ops: 

1058 start = max((wire_busy.get(w, 0) for w in gate.wires), default=0) 

1059 end = start + 1 

1060 for w in gate.wires: 

1061 wire_busy[w] = end 

1062 depth = max(depth, end) 

1063 

1064 self._cached_circuit_depth = depth 

1065 return depth 

1066 

def draw(
    self,
    inputs: Optional[jnp.ndarray] = None,
    figure: str = "text",
    **kwargs: Any,
) -> Union[str, Any]:
    """Visualize the quantum circuit.

    Records the circuit tape (without noise) and renders the gate
    sequence using the requested backend. ``figure="pulse"`` is
    delegated to :meth:`draw_pulse`.

    Args:
        inputs (Optional[jnp.ndarray]): Input data for the circuit,
            normalised through ``_inputs_validation``. Only the first
            sample of a 2D batch is drawn.
        figure (str): Rendering backend. One of:

            * ``"text"`` - ASCII art (returned as a ``str``).
            * ``"mpl"`` - Matplotlib figure (returns ``(fig, ax)``).
            * ``"tikz"`` - LaTeX/TikZ ``quantikz`` code (returns a
              :class:`TikzFigure`).
            * ``"pulse"`` - Pulse schedule (returns ``(fig, axes)``).
              Only meaningful for pulse-mode models.

        **kwargs: Extra options forwarded to the drawing backend
            (e.g. ``gate_values=True``).

    Returns:
        Depends on figure:

        * ``"text"`` -> ``str``
        * ``"mpl"`` -> ``(matplotlib.figure.Figure, matplotlib.axes.Axes)``
        * ``"tikz"`` -> :class:`TikzFigure`
        * ``"pulse"`` -> whatever :meth:`draw_pulse` returns

    Raises:
        ValueError: If figure is not one of the supported modes
            (presumably raised by the drawing backend; no check is made
            in this method).
    """
    inputs = self._inputs_validation(inputs)
    params = self.params[0] if self.params.ndim == 3 else self.params
    inp = inputs[0] if inputs.ndim == 2 else inputs

    if figure == "pulse":
        return self.draw_pulse(inputs=inputs, **kwargs)

    # Record without noise to get a clean circuit. Passing
    # ``noise_params=None`` alone is not enough because the circuit
    # function falls back to the model-level noise configuration.
    saved_noise = self._noise_params
    self._noise_params = None
    try:
        draw_script = ys.Script(f=self._variational, n_qubits=self.n_qubits)
        return draw_script.draw(
            figure=figure,
            args=(params, inp),
            kwargs={"noise_params": None},
            **kwargs,
        )
    finally:
        # Restore the noise configuration even if drawing fails, so an
        # exception here cannot silently disable noise on the model.
        self._noise_params = saved_noise

1124 

def draw_pulse(
    self,
    inputs: Optional[jnp.ndarray] = None,
    **kwargs: Any,
) -> Any:
    """Render the pulse schedule of the circuit.

    Wraps ``self._variational`` in a yaqsi ``Script`` and asks it to draw
    with ``figure="pulse"``, running the circuit function with
    ``gate_mode="pulse"``.

    Args:
        inputs: Input data, normalised through ``_inputs_validation``.
            Only the first sample of a 2D batch is drawn.
        **kwargs: Forwarded to the drawing backend
            (e.g. ``show_carrier=True``, ``n_samples=300``).

    Returns:
        ``(fig, axes)`` — Matplotlib Figure and array of Axes.
    """
    batch = self._inputs_validation(inputs)

    # Drop the batch axis: a single parameter set and a single sample
    # are enough to draw one circuit.
    weights = self.params
    if weights.ndim == 3:
        weights = weights[0]
    sample = batch[0] if batch.ndim == 2 else batch

    # NOTE(review): unlike ``draw``, the model-level ``_noise_params``
    # are not cleared here - confirm that is intended.
    script_kwargs = {
        "gate_mode": "pulse",
        "noise_params": None,
    }
    pulse_script = ys.Script(f=self._variational, n_qubits=self.n_qubits)
    return pulse_script.draw(
        figure="pulse",
        args=(weights, sample),
        kwargs=script_kwargs,
        **kwargs,
    )

1158 

1159 def __repr__(self) -> str: 

1160 """Return text representation of the quantum circuit model.""" 

1161 return self.draw(figure="text") 

1162 

1163 def __str__(self) -> str: 

1164 """Return string representation of the quantum circuit model.""" 

1165 return self.draw(figure="text") 

1166 

1167 def _params_validation(self, params: Optional[jnp.ndarray]) -> jnp.ndarray: 

1168 """ 

1169 Validate and normalize variational parameters. 

1170 

1171 Ensures parameters have the correct shape with a batch dimension, 

1172 and updates the model's internal parameters if new ones are provided. 

1173 

1174 Args: 

1175 params (Optional[jnp.ndarray]): Variational parameters to validate. 

1176 If None, returns the model's current parameters. 

1177 

1178 Returns: 

1179 jnp.ndarray: Validated parameters with shape 

1180 (batch_size, n_layers, n_params_per_layer). 

1181 """ 

1182 # append batch axis if not provided 

1183 if params is not None: 

1184 if len(params.shape) == 2: 

1185 params = np.expand_dims(params, axis=0) 

1186 

1187 # Avoid stashing JAX tracers on ``self``: under an outer 

1188 # transform (e.g. ``jacrev``) the tracer becomes invalid once 

1189 # the transform returns, and a subsequent read of 

1190 # ``self.params`` would feed a leaked tracer into the next 

1191 # call (raising ``UnexpectedTracerError``). 

1192 # if not isinstance(params, jax.core.Tracer): 

1193 # self.params = params 

1194 self.params = params 

1195 else: 

1196 params = self.params 

1197 

1198 return params 

1199 

1200 def _pulse_params_validation( 

1201 self, pulse_params: Optional[jnp.ndarray] 

1202 ) -> jnp.ndarray: 

1203 """ 

1204 Validate and normalize pulse parameters. 

1205 

1206 Ensures pulse parameters are set, using model defaults if not provided. 

1207 

1208 Args: 

1209 pulse_params (Optional[jnp.ndarray]): Pulse parameter scalers. 

1210 If None, returns the model's current pulse parameters. 

1211 

1212 Returns: 

1213 jnp.ndarray: Validated pulse parameters with shape 

1214 (batch_size, n_layers, n_pulse_params_per_layer). 

1215 """ 

1216 if pulse_params is None: 

1217 pulse_params = self.pulse_params 

1218 else: 

1219 # ensure batch dimension exists (batch-first convention) 

1220 if len(pulse_params.shape) == 2: 

1221 pulse_params = jnp.expand_dims(pulse_params, axis=0) 

1222 # See note in _params_validation: never stash JAX tracers on 

1223 # ``self``. 

1224 # if not isinstance(pulse_params, jax.core.Tracer): 

1225 # self.pulse_params = pulse_params 

1226 self.pulse_params = pulse_params 

1227 

1228 return pulse_params 

1229 

1230 def _enc_params_validation(self, enc_params: Optional[jnp.ndarray]) -> jnp.ndarray: 

1231 """ 

1232 Validate and normalize encoding parameters. 

1233 

1234 Ensures encoding parameters have the correct shape for the model's 

1235 input feature dimensions. 

1236 

1237 Args: 

1238 enc_params (Optional[jnp.ndarray]): Encoding parameters to validate. 

1239 If None, returns the model's current encoding parameters. 

1240 

1241 Returns: 

1242 jnp.ndarray: Validated encoding parameters with shape 

1243 (n_qubits, n_input_feat). 

1244 

1245 Raises: 

1246 ValueError: If enc_params shape is incompatible with n_input_feat > 1. 

1247 """ 

1248 if enc_params is None: 

1249 enc_params = self.enc_params 

1250 else: 

1251 # See note in _params_validation: never stash JAX tracers on 

1252 # ``self``. 

1253 # if not isinstance(enc_params, jax.core.Tracer): 

1254 # if self.trainable_frequencies: 

1255 # self.enc_params = enc_params 

1256 # else: 

1257 # self.enc_params = jnp.array(enc_params) 

1258 if self.trainable_frequencies: 

1259 self.enc_params = enc_params 

1260 else: 

1261 self.enc_params = jnp.array(enc_params) 

1262 

1263 if len(enc_params.shape) == 1 and self.n_input_feat == 1: 

1264 enc_params = enc_params.reshape(-1, 1) 

1265 elif len(enc_params.shape) == 1 and self.n_input_feat > 1: 

1266 raise ValueError( 

1267 f"Input dimension {self.n_input_feat} >1 but \ 

1268 `enc_params` has shape {enc_params.shape}" 

1269 ) 

1270 

1271 return enc_params 

1272 

1273 def _inputs_validation( 

1274 self, inputs: Union[None, List, float, int, jnp.ndarray] 

1275 ) -> jnp.ndarray: 

1276 """ 

1277 Validate and normalize input data. 

1278 

1279 Converts various input formats to a standardized 2D array shape 

1280 suitable for batch processing in the quantum circuit. 

1281 

1282 Args: 

1283 inputs (Union[None, List, float, int, jnp.ndarray]): Input data in 

1284 various formats: 

1285 - None: Returns zeros with shape (1, n_input_feat) 

1286 - float/int: Single scalar value 

1287 - List: List of values or batched inputs 

1288 - jnp.ndarray: NumPy/JAX array 

1289 

1290 Returns: 

1291 jnp.ndarray: Validated inputs with shape (batch_size, n_input_feat). 

1292 

1293 Raises: 

1294 ValueError: If input shape is incompatible with expected n_input_feat. 

1295 

1296 Warns: 

1297 UserWarning: If input is replicated to match n_input_feat. 

1298 """ 

1299 self._zero_inputs = False 

1300 if isinstance(inputs, List): 

1301 inputs = jnp.array(np.stack(inputs)) 

1302 elif isinstance(inputs, float) or isinstance(inputs, int): 

1303 inputs = jnp.array([inputs]) 

1304 elif inputs is None: 

1305 inputs = jnp.array([[0] * self.n_input_feat]) 

1306 

1307 if not inputs.any(): 

1308 self._zero_inputs = True 

1309 

1310 if len(inputs.shape) <= 1: 

1311 if self.n_input_feat == 1: 

1312 # add a batch dimension 

1313 inputs = inputs.reshape(-1, 1) 

1314 else: 

1315 if inputs.shape[0] == self.n_input_feat: 

1316 inputs = inputs.reshape(1, -1) 

1317 else: 

1318 inputs = inputs.reshape(-1, 1) 

1319 inputs = inputs.repeat(self.n_input_feat, axis=1) 

1320 warnings.warn( 

1321 f"Expected {self.n_input_feat} inputs, but {inputs.shape[0]} " 

1322 "was provided, replicating input for all input features.", 

1323 UserWarning, 

1324 ) 

1325 else: 

1326 if inputs.shape[1] != self.n_input_feat: 

1327 raise ValueError( 

1328 f"Wrong number of inputs provided. Expected {self.n_input_feat} " 

1329 f"inputs, but input has shape {inputs.shape}." 

1330 ) 

1331 

1332 return inputs 

1333 

1334 def _postprocess_res(self, result: Union[List, jnp.ndarray]) -> jnp.ndarray: 

1335 """ 

1336 Post-process circuit execution results for uniform shape. 

1337 

1338 Converts list outputs (from multiple measurements) to stacked arrays 

1339 and reorders axes for consistent batch dimension placement. 

1340 

1341 Args: 

1342 result (Union[List, jnp.ndarray]): Raw circuit output, either a 

1343 list of measurement results or a single array. 

1344 

1345 Returns: 

1346 jnp.ndarray: Uniformly shaped result array with batch dimension first. 

1347 """ 

1348 if isinstance(result, list): 

1349 # we use moveaxis here because in case of parity measure, 

1350 # there is another dimension appended to the end and 

1351 # simply transposing would result in a wrong shape 

1352 result = jnp.stack(result) 

1353 if len(result.shape) > 1: 

1354 result = jnp.moveaxis(result, 0, 1) 

1355 return result 

1356 

1357 def _assimilate_batch( 

1358 self, 

1359 inputs: jnp.ndarray, 

1360 params: jnp.ndarray, 

1361 pulse_params: jnp.ndarray, 

1362 ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: 

1363 """ 

1364 Align batch dimensions across inputs, parameters, and pulse parameters. 

1365 

1366 Broadcasts and reshapes arrays to have compatible batch dimensions 

1367 for vectorized circuit execution. Sets the internal batch_shape. 

1368 

1369 Args: 

1370 inputs (jnp.ndarray): Input data of shape (B_I, n_input_feat). 

1371 params (jnp.ndarray): Parameters of shape (B_P, n_layers, n_params). 

1372 pulse_params (jnp.ndarray): Pulse params of shape (B_R, n_layers, n_pulse). 

1373 

1374 Returns: 

1375 Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: Tuple containing: 

1376 - inputs: Reshaped to (B, n_input_feat) where B = B_I * B_P * B_R 

1377 - params: Reshaped to (B, n_layers, n_params) 

1378 - pulse_params: Reshaped to (B, n_layers, n_pulse) 

1379 

1380 Note: 

1381 The effective batch shape depends on repeat_batch_axis configuration. 

1382 This is the only method that sets self._batch_shape. 

1383 """ 

1384 B_I = inputs.shape[0] 

1385 # we check for the product because there is a chance that 

1386 # there are no params. In this case we want B_P to be 1 

1387 B_P = 1 if 0 in params.shape else params.shape[0] 

1388 B_R = pulse_params.shape[0] 

1389 

1390 # THIS is the only place where we set the batch shape 

1391 self._batch_shape = (B_I, B_P, B_R) 

1392 B = np.prod(self.eff_batch_shape) 

1393 

1394 # [B_I, ...] -> [B_I, B_P, B_R, ...] -> [B, ...] 

1395 if B_I > 1 and self.repeat_batch_axis[0]: 

1396 if self.repeat_batch_axis[1]: 

1397 inputs = jnp.repeat(inputs[:, None, None, ...], B_P, axis=1) 

1398 if self.repeat_batch_axis[2]: 

1399 inputs = jnp.repeat(inputs, B_R, axis=2) 

1400 inputs = inputs.reshape(B, *inputs.shape[3:]) 

1401 

1402 # [B_P, ..., ...] -> [B_I, B_P, B_R, ..., ...] -> [B, ..., ...] 

1403 if B_P > 1 and self.repeat_batch_axis[1]: 

1404 # add B_I axis before first, and B_R axis after first batch dim 

1405 params = params[None, :, None, ...] # [B_I(=1), B_P, B_R(=1), ...] 

1406 if self.repeat_batch_axis[0]: 

1407 params = jnp.repeat(params, B_I, axis=0) # [B_I, B_P, 1, ...] 

1408 if self.repeat_batch_axis[2]: 

1409 params = jnp.repeat(params, B_R, axis=2) # [B_I, B_P, B_R, ...] 

1410 params = params.reshape(B, *params.shape[3:]) 

1411 

1412 # [B_R, ..., ...] -> [B_I, B_P, B_R, ..., ...] -> [B, ..., ...] 

1413 if B_R > 1 and self.repeat_batch_axis[2]: 

1414 # add B_I axis and B_P axis before B_R 

1415 pulse_params = pulse_params[None, None, ...] # [B_I(=1), B_P(=1), B_R, ...] 

1416 if self.repeat_batch_axis[0]: 

1417 pulse_params = jnp.repeat( 

1418 pulse_params, B_I, axis=0 

1419 ) # [B_I, 1, B_R, ...] 

1420 if self.repeat_batch_axis[1]: 

1421 pulse_params = jnp.repeat( 

1422 pulse_params, B_P, axis=1 

1423 ) # [B_I, B_P, B_R, ...] 

1424 pulse_params = pulse_params.reshape(B, *pulse_params.shape[3:]) 

1425 

1426 return inputs, params, pulse_params 

1427 

1428 def _requires_density(self) -> bool: 

1429 """ 

1430 Check if density matrix simulation is required. 

1431 

1432 Determines whether the circuit must be executed with the mixed-state 

1433 simulator based on execution type and noise configuration. 

1434 

1435 Returns: 

1436 bool: True if density matrix simulation is required, False otherwise. 

1437 Returns True if: 

1438 - execution_type is "density", or 

1439 - Any non-coherent noise channel has non-zero probability 

1440 """ 

1441 if self.execution_type == "density": 

1442 return True 

1443 

1444 if self.noise_params is None: 

1445 return False 

1446 

1447 coherent_noise = {"GateError"} 

1448 for k, v in self.noise_params.items(): 

1449 if k in coherent_noise: 

1450 continue 

1451 if v is not None and v > 0: 

1452 return True 

1453 return False 

1454 

1455 def __call__( 

1456 self, 

1457 params: Optional[jnp.ndarray] = None, 

1458 inputs: Optional[jnp.ndarray] = None, 

1459 pulse_params: Optional[jnp.ndarray] = None, 

1460 enc_params: Optional[jnp.ndarray] = None, 

1461 data_reupload: Union[bool, List[List[bool]], List[List[List[bool]]]] = None, 

1462 noise_params: Optional[Dict[str, Union[float, Dict[str, float]]]] = None, 

1463 execution_type: Optional[str] = None, 

1464 force_mean: bool = False, 

1465 gate_mode: str = "unitary", 

1466 ) -> jnp.ndarray: 

1467 """ 

1468 Execute the quantum circuit (callable interface). 

1469 

1470 Provides a convenient callable interface for circuit execution, 

1471 delegating to the _forward method. 

1472 

1473 Args: 

1474 params (Optional[jnp.ndarray]): Variational parameters of shape 

1475 (n_layers, n_params_per_layer) or (batch, n_layers, n_params_per_layer). 

1476 If None, uses model's internal parameters. 

1477 inputs (Optional[jnp.ndarray]): Input data of shape 

1478 (batch_size, n_input_feat). If None, uses zero inputs. 

1479 pulse_params (Optional[jnp.ndarray]): Pulse parameter scalers for 

1480 pulse-mode gate execution. 

1481 enc_params (Optional[jnp.ndarray]): Encoding parameters of shape 

1482 (n_qubits, n_input_feat). If None, uses model's encoding parameters. 

1483 data_reupload (Union[bool, List[List[bool]], List[List[List[bool]]]]): 

1484 Data reupload configuration. If None, uses previously set reupload 

1485 configuration. 

1486 noise_params (Optional[Dict[str, Union[float, Dict[str, float]]]]): 

1487 Noise configuration. If None, uses previously set noise parameters. 

1488 execution_type (Optional[str]): Measurement type: "expval", "density", 

1489 "probs", or "state". If None, uses current execution_type setting. 

1490 force_mean (bool): If True, averages results over measurement qubits. 

1491 Defaults to False. 

1492 gate_mode (str): Gate execution backend, "unitary" or "pulse". 

1493 Defaults to "unitary". 

1494 

1495 Returns: 

1496 jnp.ndarray: Circuit output with shape depending on execution_type: 

1497 - "expval": (n_output_qubits,) or scalar 

1498 - "density": (2^n_output, 2^n_output) 

1499 - "probs": (2^n_output,) or (n_pairs, 2^pair_size) 

1500 - "state": (2^n_qubits,) 

1501 """ 

1502 # Call forward method which handles the actual caching etc. 

1503 return self._forward( 

1504 params=params, 

1505 inputs=inputs, 

1506 pulse_params=pulse_params, 

1507 enc_params=enc_params, 

1508 data_reupload=data_reupload, 

1509 noise_params=noise_params, 

1510 execution_type=execution_type, 

1511 force_mean=force_mean, 

1512 gate_mode=gate_mode, 

1513 ) 

1514 

def _forward(
    self,
    params: Optional[jnp.ndarray] = None,
    inputs: Optional[jnp.ndarray] = None,
    pulse_params: Optional[jnp.ndarray] = None,
    enc_params: Optional[jnp.ndarray] = None,
    data_reupload: Union[bool, List[List[bool]], List[List[List[bool]]]] = None,
    noise_params: Optional[Dict[str, Union[float, Dict[str, float]]]] = None,
    execution_type: Optional[str] = None,
    force_mean: bool = False,
    gate_mode: str = "unitary",
) -> jnp.ndarray:
    """
    Execute the quantum circuit forward pass.

    Internal implementation of the forward pass. In order, it:

    1. stores ``noise_params``, ``execution_type``, ``gate_mode`` and
       ``data_reupload`` on the model (the first two and the last only
       when they are not None),
    2. validates params, pulse params, inputs and encoding params
       (the validators also store provided values on the model),
    3. aligns the batch axes via ``_assimilate_batch``,
    4. splits ``self.random_key`` and executes ``self.script`` once,
       batched over the flattened batch axis when it is larger than 1,
    5. post-processes the result (partial trace / marginalisation for
       partial measurements, reshape, optional mean).

    Args:
        params (Optional[jnp.ndarray]): Variational parameters of shape
            (n_layers, n_params_per_layer) or
            (batch, n_layers, n_params_per_layer).
            If None, uses model's internal parameters.
        inputs (Optional[jnp.ndarray]): Input data of shape
            (batch_size, n_input_feat).
            If None, uses zero inputs.
        pulse_params (Optional[jnp.ndarray]): Pulse parameter scalers for
            pulse-mode gate execution.
        enc_params (Optional[jnp.ndarray]): Encoding parameters of shape
            (n_qubits, n_input_feat). If None, uses model's encoding parameters.
        data_reupload (Union[bool, List[List[bool]], List[List[List[bool]]]]):
            Data reupload configuration. If None, uses previously set reupload
            configuration.
        noise_params (Optional[Dict[str, Union[float, Dict[str, float]]]]):
            Noise configuration. If None, uses previously set noise parameters.
        execution_type (Optional[str]): Measurement type: "expval", "density",
            "probs", or "state". If None, uses current execution_type setting.
        force_mean (bool): If True, averages "expval"/"probs" results over
            their last axis (only when the result is not a scalar and
            ``self._result_shape[0] > 1``). Defaults to False.
        gate_mode (str): Gate execution backend, "unitary" or "pulse".
            Defaults to "unitary". Always stored on the model.

    Returns:
        jnp.ndarray: Circuit output, reshaped to
        ``(*eff_batch_shape, *_result_shape)`` with all singleton axes
        squeezed away. Per execution_type the unbatched shape is
        expected to be:
            - "expval": (n_output_qubits,) or scalar
            - "density": (2^n_output, 2^n_output)
            - "probs": (2^n_output,) or (n_pairs, 2^pair_size)
            - "state": (2^n_qubits,)

    Raises:
        ValueError: If pulse_params are provided while gate_mode is not
            "pulse".

    Note:
        NOTE(review): earlier documentation also listed a ValueError for
        noise_params combined with pulse gate_mode; no such check is
        performed in this method.
    """
    # set the parameters as object attributes
    if noise_params is not None:
        self.noise_params = noise_params
    if execution_type is not None:
        self.execution_type = execution_type
    self.gate_mode = gate_mode

    # consistency checks
    if pulse_params is not None and gate_mode != "pulse":
        raise ValueError(
            "pulse_params were provided but gate_mode is not 'pulse'. "
            "Either switch gate_mode='pulse' or do not pass pulse_params."
        )

    # TODO: add testing
    if data_reupload is not None:
        self.data_reupload = data_reupload

    # normalise all array arguments (None falls back to the stored values)
    params = self._params_validation(params)
    pulse_params = self._pulse_params_validation(pulse_params)
    inputs = self._inputs_validation(inputs)
    enc_params = self._enc_params_validation(enc_params)

    # flatten the three batch axes into one; also sets self._batch_shape
    inputs, params, pulse_params = self._assimilate_batch(
        inputs,
        params,
        pulse_params,
    )

    # split to generate a sub_key, required for actual execution; the
    # other half replaces the model's key so the next call differs
    self.random_key, sub_key = safe_random_split(self.random_key)

    # Build measurement type & observables from execution_type / output_qubit
    meas_type, obs = self._build_obs()

    # Per the original note, yaqsi auto-routes between statevector and
    # density-matrix simulation based on whether noise channels appear
    # on the tape, so no routing decision is made here.
    # B is the size of the flattened batch axis.
    B = np.prod(self.eff_batch_shape)

    # kwargs are broadcast (not vmapped over)
    exec_kwargs = dict(
        noise_params=self.noise_params,
        gate_mode=self.gate_mode,
    )

    # Build a shot key from the random_key if shots are requested
    shot_key = None
    if self.shots is not None:
        # overwrite subkey and split shot_key
        sub_key, shot_key = safe_random_split(sub_key)

    if B > 1:
        # use random keys, derived from the subkey (one per batch entry)
        random_keys = safe_random_split(sub_key, num=B)

        # Batch over an argument only if its own batch size is > 1;
        # the order matches the ``args`` tuple below.
        in_axes = (
            0 if self.batch_shape[1] > 1 else None,  # params
            0 if self.batch_shape[0] > 1 else None,  # inputs
            0 if self.batch_shape[2] > 1 else None,  # pulse_params
            0,  # random_keys
            None,  # enc_params (broadcast, not batched)
        )

        result = self.script.execute(
            type=meas_type,
            obs=obs,
            args=(params, inputs, pulse_params, random_keys, enc_params),
            kwargs=exec_kwargs,
            in_axes=in_axes,
            shots=self.shots,
            key=shot_key,
        )
    else:
        # use the subkey directly
        result = self.script.execute(
            type=meas_type,
            obs=obs,
            args=(params, inputs, pulse_params, sub_key, enc_params),
            kwargs=exec_kwargs,
            shots=self.shots,
            key=shot_key,
        )

    result = self._postprocess_res(result)

    # --- Post-processing for partial-qubit measurements ---------------
    if self.execution_type == "density" and not self.all_qubit_measurement:
        result = ys.partial_trace(result, self.n_qubits, self.output_qubit)

    if self.execution_type == "probs" and not self.all_qubit_measurement:
        if isinstance(self.output_qubit[0], (list, tuple)):
            # list of qubit groups - marginalize each independently
            # NOTE(review): the stack puts the group axis first, ahead of
            # any batch axis - confirm the reshape below accounts for
            # this when the call is batched.
            result = jnp.stack(
                [
                    ys.marginalize_probs(result, self.n_qubits, list(group))
                    for group in self.output_qubit
                ]
            )
        else:
            result = ys.marginalize_probs(result, self.n_qubits, self.output_qubit)

    # squeeze() removes every singleton axis, including batch axes of size 1
    result = jnp.asarray(result)
    result = result.reshape((*self.eff_batch_shape, *self._result_shape)).squeeze()

    if (
        self.execution_type in ("expval", "probs")
        and force_mean
        and len(result.shape) > 0
        and self._result_shape[0] > 1
    ):
        result = result.mean(axis=-1)

    return result

1671 

1672 if ( 

1673 self.execution_type in ("expval", "probs") 

1674 and force_mean 

1675 and len(result.shape) > 0 

1676 and self._result_shape[0] > 1 

1677 ): 

1678 result = result.mean(axis=-1) 

1679 

1680 return result