Coverage for qml_essentials / model.py: 90%

501 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-03-30 11:43 +0000

1from typing import Any, Dict, Optional, Tuple, Callable, Union, List 

2 

3from qml_essentials import operations as op 

4from qml_essentials import yaqsi as ys 

5import warnings 

6import jax.numpy as jnp 

7import numpy as np 

8from jax import random 

9 

10from qml_essentials.tape import recording 

11from qml_essentials.operations import KrausChannel 

12from qml_essentials.ansaetze import Ansaetze, Circuit, Encoding 

13from qml_essentials.gates import Gates, PulseInformation as pinfo 

14from qml_essentials.utils import safe_random_split 

15 

16import logging 

17 

18log = logging.getLogger(__name__) 

19 

20 

21class Model: 

22 """ 

23 A quantum circuit model. 

24 """ 

25 

def __init__(
    self,
    n_qubits: int,
    n_layers: int,
    circuit_type: Union[str, Circuit, None] = "No_Ansatz",
    data_reupload: Union[bool, List[List[bool]], List[List[List[bool]]]] = True,
    state_preparation: Union[
        str, Callable, List[Union[str, Callable]], None
    ] = None,
    encoding: Union[Encoding, str, Callable, List[Union[str, Callable]]] = Gates.RX,
    trainable_frequencies: bool = False,
    initialization: str = "random",
    initialization_domain: Optional[List[float]] = None,
    output_qubit: Union[List[int], int] = -1,
    shots: Optional[int] = None,
    random_seed: int = 1000,
    remove_zero_encoding: bool = True,
    repeat_batch_axis: Optional[List[bool]] = None,
    pulse_shape: str = "gaussian",
) -> None:
    """
    Initialize the quantum circuit model.

    After initialization the parameters have the shape
    ``(1, impl_n_layers, parameters_per_layer)``. ``impl_n_layers`` equals
    ``n_layers``, plus one when the model uses data re-uploading
    (``has_dru``). ``parameters_per_layer`` is given by the chosen ansatz.

    The model is initialized with the following defaults:
        - noise_params: None
        - execution_type: "expval"

    Args:
        n_qubits (int): The number of qubits in the circuit.
        n_layers (int): The number of layers in the circuit.
        circuit_type (str, Circuit, None): The quantum circuit (ansatz) to
            use, either a name looked up on ``Ansaetze`` or a circuit class
            that is instantiated without arguments. None or an empty string
            falls back to "No_Ansatz".
        data_reupload (Union[bool, List[List[bool]], List[List[List[bool]]]],
            optional): Whether to re-upload data on each layer and qubit.
            Detailed instructions can be given as a nested list/array of
            0/False and 1/True with shape (n_layers, n_qubits) or
            (n_layers, n_qubits, n_input_feat). Defaults to True, which
            applies data re-uploading to the full circuit.
        state_preparation (Union[str, Callable, List, None], optional):
            Gate(s) applied to every qubit before the first layer, parsed
            by ``Gates.parse_gates``. Defaults to None.
        encoding (Union[Encoding, str, Callable, List], optional): The
            unitary used to encode the input data. Can be an ``Encoding``
            instance, a string (e.g. "RX") or a callable. Non-``Encoding``
            values are wrapped in a "hamming" ``Encoding``. Defaults to
            ``Gates.RX``. Multidimensional input expects a list of
            unitaries or strings.
        trainable_frequencies (bool, optional): Sets trainable encoding
            parameters for trainable frequencies. Defaults to False.
        initialization (str, optional): The strategy used to initialize
            the parameters. Can be "random", "zeros", "zero-controlled",
            "pi", or "pi-controlled". Defaults to "random".
        initialization_domain (Optional[List[float]], optional):
            [min, max] bounds for random initialization. Defaults to
            [0, 2 * pi].
        output_qubit (List[int], int, optional): The index of the output
            qubit (or qubits). When set to -1, all qubits are measured or
            a global measurement is conducted, depending on the execution
            type.
        shots (Optional[int], optional): The number of shots to use.
            Defaults to None.
        random_seed (int, optional): Seed for the random number generator
            used for "random" initialization and random noise parameters.
            Defaults to 1000.
        remove_zero_encoding (bool, optional): Whether to remove the zero
            encoding from the circuit. Defaults to True.
        repeat_batch_axis (Optional[List[bool]], optional): Each boolean
            determines whether computation is parallelised over that axis.
            The axes correspond to [inputs, params, pulse_params].
            Defaults to [True, True, True], i.e. batching over all axes.
        pulse_shape (str, optional): Pulse envelope shape for pulse-level
            simulation. One of ``PulseEnvelope.available()``. Defaults to
            ``"gaussian"``.

    Raises:
        ValueError: If ``state_preparation`` cannot be parsed.
    """
    # Build a fresh list per instance. A list literal used as a default
    # value would be one object shared (and mutable) across all models.
    if initialization_domain is None:
        initialization_domain = [0, 2 * jnp.pi]
    if repeat_batch_axis is None:
        repeat_batch_axis = [True, True, True]

    # Initialize default parameters needed for circuit evaluation.
    # Order matters: the output_qubit setter reads n_qubits, and the
    # execution_type setter reads output_qubit and shots.
    self.n_qubits: int = n_qubits
    self.output_qubit: Union[List[int], int] = output_qubit
    self.n_layers: int = n_layers
    self.noise_params: Optional[Dict[str, Union[float, Dict[str, float]]]] = None
    self.shots = shots
    self.remove_zero_encoding = remove_zero_encoding
    self.trainable_frequencies: bool = trainable_frequencies
    self.execution_type: str = "expval"
    self.repeat_batch_axis: List[bool] = repeat_batch_axis

    # --- Pulse envelope ---
    pinfo.set_envelope(pulse_shape)

    # --- State Preparation ---
    try:
        self._sp = Gates.parse_gates(state_preparation, Gates)
    except ValueError as e:
        raise ValueError(f"Error parsing state preparation: {e}") from e

    # Prepare the corresponding pulse parameters (always optimized pulses).
    # Gates without a pulse parametrization get None.
    self.sp_pulse_params = []
    for sp in self._sp:
        sp_name = sp.__name__ if hasattr(sp, "__name__") else str(sp)
        pulse_gate = pinfo.gate_by_name(sp_name)
        self.sp_pulse_params.append(
            pulse_gate.params if pulse_gate is not None else None
        )

    # --- Encoding ---
    if isinstance(encoding, Encoding):
        # user-provided custom encoding strategy
        self._enc = encoding
    else:
        # use hamming encoding by default
        self._enc = Encoding("hamming", encoding)

    # Number of possible inputs
    self.n_input_feat = len(self._enc)
    log.debug(f"Number of input features: {self.n_input_feat}")

    # Trainable frequencies, default initialization as in arXiv:2309.03279v2
    self.enc_params = jnp.ones((self.n_layers, self.n_qubits, self.n_input_feat))

    self._zero_inputs = False

    # --- Data-Reuploading ---
    # The setter stores a NumPy (not JAX) array, so that
    # ``if data_reupload[q, idx]`` in _iec stays a concrete Python bool under
    # jax.jit tracing. It also updates degree, frequencies and has_dru.
    self.data_reupload = data_reupload

    # With data re-uploading we need L+1 layers, according to Schuld et al.
    impl_n_layers: int = n_layers + 1 if self.has_dru else n_layers
    log.info(f"Number of implicit layers: {impl_n_layers}.")

    # --- Ansatz ---
    # Only a weak check for str; we trust the user to provide something useful.
    # None (and "") fall back to "No_Ansatz", as documented.
    if circuit_type is None or isinstance(circuit_type, str):
        self.pqc: Callable[[Optional[jnp.ndarray], int], int] = getattr(
            Ansaetze, circuit_type or "No_Ansatz"
        )()
    else:
        self.pqc = circuit_type()
    log.info(f"Using Ansatz {circuit_type}.")

    # Shape of the parameter vector; re-used by initialize_params.
    params_per_layer = self.pqc.n_params_per_layer(self.n_qubits)
    self._params_shape: Tuple[int, int] = (impl_n_layers, params_per_layer)
    log.info(f"Parameters per layer: {params_per_layer}")

    pulse_params_per_layer = self.pqc.n_pulse_params_per_layer(self.n_qubits)
    self._pulse_params_shape: Tuple[int, int] = (
        impl_n_layers,
        pulse_params_per_layer,
    )

    # The batch shape is unknown until the model is called.
    self._batch_shape = None

    # Re-used by initialize_params when no strategy/domain is passed there.
    # (The attribute name keeps its historical spelling.)
    self._inialization_strategy = initialization
    self._initialization_domain = initialization_domain

    # initialize_params sets self.params and returns the follow-up key.
    self.random_key = self.initialize_params(random.key(random_seed))

    # Initializing pulse params (leading batch axis of size 1)
    self.pulse_params: jnp.ndarray = jnp.ones((1, *self._pulse_params_shape))

    log.info(f"Initialized pulse parameters with shape {self.pulse_params.shape}.")

    # yaqsi Script wrapping _variational. No device selection is needed here.
    # Per the yaqsi design, simulation auto-routes between statevector and
    # density matrix depending on whether noise channels are on the tape.
    self.script = ys.Script(f=self._variational, n_qubits=self.n_qubits)

208 

@property
def noise_params(self) -> Optional[Dict[str, Union[float, Dict[str, float]]]]:
    """
    Return the currently configured noise parameters.

    Returns:
        Optional[Dict[str, Union[float, Dict[str, float]]]]: Mapping from
        noise type to its parameter, or None if no noise is set.
    """
    configured = self._noise_params
    return configured

219 

@noise_params.setter
def noise_params(
    self, kvs: Optional[Dict[str, Union[float, Dict[str, float]]]]
) -> None:
    """
    Sets the noise parameters of the model.

    Typically a "noise parameter" refers to the error probability.
    ThermalRelaxation is a special case and supports a dict as value, with
    structure::

        "ThermalRelaxation": {
            "t1": 2000,      # relative t1 time
            "t2": 1000,      # relative t2 time
            "t_factor": 1,   # relative gate time factor
        }

    Supported noise types that are missing are filled in with 0.0
    (ThermalRelaxation with None). The dictionary passed in, including a
    nested ThermalRelaxation dict, is copied and never modified.

    Args:
        kvs (Optional[Dict[str, Union[float, Dict[str, float]]]]): A
            dictionary of noise parameters. If all values are 0.0 (or the
            dict is empty), the noise parameters are set to None.

    Warns:
        UserWarning: For unsupported noise types, for unsupported
            ThermalRelaxation keys, and for invalid ThermalRelaxation values
            (any value zero, or t2 > 2 * t1). In the invalid-values case
            ThermalRelaxation is set to 0.0, i.e. not applied.

    Returns:
        None
    """
    if kvs is not None:
        # Work on shallow copies so that the defaults and the
        # ThermalRelaxation fix-ups below do not leak into the caller's dicts.
        kvs = dict(kvs)
        if isinstance(kvs.get("ThermalRelaxation"), dict):
            kvs["ThermalRelaxation"] = dict(kvs["ThermalRelaxation"])

        # set to None if only zero values provided
        if all(v == 0.0 for v in kvs.values()):
            kvs = None

    if kvs is not None:
        # set default values
        defaults = {
            "BitFlip": 0.0,
            "PhaseFlip": 0.0,
            "Depolarizing": 0.0,
            "MultiQubitDepolarizing": 0.0,
            "AmplitudeDamping": 0.0,
            "PhaseDamping": 0.0,
            "GateError": 0.0,
            "ThermalRelaxation": None,
            "StatePreparation": 0.0,
            "Measurement": 0.0,
        }
        for key, default_val in defaults.items():
            kvs.setdefault(key, default_val)

        # check if there are any keys not supported
        for key in kvs:
            if key not in defaults:
                warnings.warn(
                    f"Noise type {key} is not supported by this package",
                    UserWarning,
                )

        # check valid params for thermal relaxation noise channel
        tr_params = kvs["ThermalRelaxation"]
        if isinstance(tr_params, dict):
            valid_tr_keys = {"t1", "t2", "t_factor"}
            for k in ("t1", "t2", "t_factor"):
                tr_params.setdefault(k, 0.0)
            for k in tr_params:
                if k not in valid_tr_keys:
                    warnings.warn(
                        f"Thermal Relaxation parameter {k} is not supported "
                        f"by this package",
                        UserWarning,
                    )
            # Zero values are not allowed, and T2 may not exceed 2 * T1.
            if not all(tr_params.values()) or tr_params["t2"] > 2 * tr_params["t1"]:
                warnings.warn(
                    "Received invalid values for Thermal Relaxation noise "
                    "parameter. Thermal relaxation is not applied!",
                    UserWarning,
                )
                kvs["ThermalRelaxation"] = 0.0

    self._noise_params = kvs

297 

@property
def output_qubit(self) -> List[int]:
    """Return the qubit indices (or parity groups) selected for measurement."""
    selected = self._output_qubit
    return selected

302 

@output_qubit.setter
def output_qubit(self, value: Union[int, List[int]]) -> None:
    """
    Set the output qubit(s) for measurement.

    The value is normalised to a list:
        - -1 becomes ``[0, 1, ..., n_qubits - 1]`` (all qubits);
        - any other integer ``i`` (Python or NumPy) becomes ``[i]``;
        - a tuple becomes a list;
        - a list is kept as is. Its entries may be ints or groups of
          qubits, which ``_build_obs`` treats as parity measurements.

    Args:
        value: Qubit index, or list/tuple of indices. Use -1 for all qubits.

    Raises:
        AssertionError: If more entries than qubits are given, or a single
            index is not smaller than ``n_qubits``.
    """
    # Normalise common alternative input types. Previously a NumPy integer
    # was stored unwrapped, which broke ``len(self.output_qubit)`` later.
    if isinstance(value, tuple):
        value = list(value)
    elif isinstance(value, np.integer):
        value = int(value)

    if isinstance(value, list):
        assert len(value) <= self.n_qubits, (
            f"Size of output_qubit {len(value)} cannot be "
            f"larger than number of qubits {self.n_qubits}."
        )
    elif isinstance(value, int):
        if value == -1:
            value = list(range(self.n_qubits))
        else:
            assert (
                value < self.n_qubits
            ), f"Output qubit {value} cannot be larger than {self.n_qubits}."
            value = [value]

    self._output_qubit = value

326 

@property
def execution_type(self) -> str:
    """
    Return how the model output is measured.

    Returns:
        str: One of 'expval', 'probs', 'density' or 'state'.
    """
    mode = self._execution_type
    return mode

336 

@execution_type.setter
def execution_type(self, value: str) -> None:
    """
    Set the execution (measurement) type and the derived result shape.

    Args:
        value (str): One of "density", "expval", "probs" or "state".

    Raises:
        ValueError: If ``value`` is not a supported type, or if "density"
            is requested while ``shots`` is not None. The model state is
            left unchanged in both cases.

    Warns:
        UserWarning: For "state" when not all qubits are selected as
            output (the selection is ignored), and for "probs" without
            shots.
    """
    n_out = len(self.output_qubit)

    # Validate before touching any state, so that a rejected value does
    # not leave a stale ``_result_shape`` behind.
    if value == "density":
        if self.shots is not None:
            raise ValueError("Setting execution_type to density with shots not None.")
        result_shape = (2**n_out, 2**n_out)
    elif value == "expval":
        # one value per output qubit or per parity group
        result_shape = (n_out,)
    elif value == "probs":
        # for a list of parities, each pair has 2^len(qubits) probabilities
        result_shape = (
            (2,) * n_out if isinstance(self.output_qubit, (tuple, list)) else (2,)
        )
    elif value == "state":
        result_shape = (2**n_out,)
    else:
        raise ValueError(f"Invalid execution type: {value}.")

    if value == "state" and not self.all_qubit_measurement:
        warnings.warn(
            f"{value} measurement does ignore output_qubit, which is "
            f"{self.output_qubit}.",
            UserWarning,
        )

    if value == "probs" and self.shots is None:
        warnings.warn(
            "Setting execution_type to probs without specifying shots.",
            UserWarning,
        )

    self._result_shape = result_shape
    self._execution_type = value

383 

@property
def shots(self) -> Optional[int]:
    """
    Return the number of measurement shots.

    Returns:
        Optional[int]: Shot count, or None for analytic (exact) results.
    """
    n_shots = self._shots
    return n_shots

393 

@shots.setter
def shots(self, value: Optional[int]) -> None:
    """
    Sets the number of shots to use for the quantum device.

    Args:
        value (Optional[int]): The number of shots. A Python or NumPy
            integer less than or equal to 0 is replaced by None. Booleans
            are stored unchanged.

    Returns:
        None
    """
    # Include NumPy integers: e.g. np.int64(0) coming from an np.arange
    # sweep previously slipped through and was stored as 0 shots.
    is_integer = isinstance(value, (int, np.integer)) and not isinstance(
        value, (bool, np.bool_)
    )
    if is_integer and value <= 0:
        value = None
    self._shots = value

409 

@property
def params(self) -> jnp.ndarray:
    """Return the variational parameters (leading axis is the batch axis)."""
    current = self._params
    return current

414 

@params.setter
def params(self, value: jnp.ndarray) -> None:
    """
    Store the variational parameters.

    A 2-D array (one parameter set) gets a leading batch axis of size 1.
    Arrays of any other rank are stored unchanged.
    """
    needs_batch_axis = len(value.shape) == 2
    self._params = value.reshape(1, *value.shape) if needs_batch_axis else value

422 

@property
def enc_params(self) -> jnp.ndarray:
    """Return the encoding weights that scale the inputs (see transform_input)."""
    weights = self._enc_params
    return weights

427 

@enc_params.setter
def enc_params(self, value: jnp.ndarray) -> None:
    """Replace the encoding weights as given (no shape handling)."""
    new_weights = value
    self._enc_params = new_weights

432 

@property
def pulse_params(self) -> jnp.ndarray:
    """Return the pulse parameter scalers used when gates run in pulse mode."""
    scalers = self._pulse_params
    return scalers

437 

@pulse_params.setter
def pulse_params(self, value: jnp.ndarray) -> None:
    """Replace the pulse parameter scalers as given (no shape handling)."""
    new_scalers = value
    self._pulse_params = new_scalers

442 

@property
def data_reupload(self) -> jnp.ndarray:
    """Return the boolean data re-upload mask stored by the setter."""
    mask = self._data_reupload
    return mask

447 

@data_reupload.setter
def data_reupload(self, value: jnp.ndarray) -> None:
    """Set the data re-upload mask.

    Accepted values:
        - ``True``: upload every input feature on every qubit in every layer.
        - ``False``: upload each input feature only once, on qubit 0 of
          layer 0.
        - An array-like of shape ``(n_layers, n_qubits)``, which is repeated
          over the input features, or of shape
          ``(n_layers, n_qubits, n_input_feat)``.

    NumPy boolean scalars (``np.bool_``) are treated like Python bools.

    The mask is always stored as a concrete NumPy boolean array, so that
    ``if data_reupload[q, idx]`` in :meth:`_iec` remains a plain Python
    ``bool`` even inside JAX-traced functions (jit / grad / vmap). Setting
    the mask also recomputes ``degree``, ``frequencies`` and ``has_dru``.

    Raises:
        AssertionError: If an array-like mask has an unexpected shape.
    """
    full_shape = (self.n_layers, self.n_qubits, self.n_input_feat)

    # Process data reuploading strategy and set degree
    if isinstance(value, (bool, np.bool_)):
        if value:
            value = np.ones(full_shape)
            log.debug("Full data reuploading.")
        else:
            value = np.zeros(full_shape)
            value[0, 0] = 1
            log.debug("No data reuploading.")
    else:
        if not isinstance(value, np.ndarray):
            value = np.array(value)

        if value.ndim == 2:
            assert value.shape == (self.n_layers, self.n_qubits), (
                "Data reuploading array has wrong shape. "
                f"Expected {(self.n_layers, self.n_qubits)} or {full_shape}, "
                f"got {value.shape}."
            )
            # use the same mask for every input feature
            value = np.repeat(value[..., np.newaxis], self.n_input_feat, axis=2)

        assert value.shape == full_shape, (
            "Data reuploading array has wrong shape. "
            f"Expected {full_shape}, got {value.shape}."
        )

        log.debug(f"Data reuploading array:\n{value}")

    # convert to boolean values
    self._data_reupload = np.asarray(value).astype(bool)

    self.degree: Tuple = tuple(
        self._enc.get_n_freqs(np.count_nonzero(self._data_reupload[..., i]))
        for i in range(self.n_input_feat)
    )

    self.frequencies: Tuple = tuple(
        self._enc.get_spectrum(np.count_nonzero(self._data_reupload[..., i]))
        for i in range(self.n_input_feat)
    )

    # Cache has_dru as a plain Python bool so that it can be used in
    # Python ``if`` statements even inside JAX-traced functions.
    self._has_dru: bool = bool(max(int(np.max(f)) for f in self._frequencies) > 1)

506 

@property
def degree(self) -> Tuple:
    """Return the per-input-feature degree tuple computed from the encoding."""
    per_feature = self._degree
    return per_feature

511 

@degree.setter
def degree(self, value: Tuple):
    """Store the per-input-feature degree tuple."""
    per_feature = value
    self._degree = per_feature

515 

@property
def frequencies(self) -> Tuple:
    """Return the per-input-feature frequency spectra computed from the encoding."""
    spectra = self._frequencies
    return spectra

520 

@frequencies.setter
def frequencies(self, value: Tuple):
    """Store the per-input-feature frequency spectra."""
    spectra = value
    self._frequencies = spectra

524 

@property
def has_dru(self) -> bool:
    """Return whether data re-uploading is active (value cached by the data_reupload setter)."""
    cached_flag = self._has_dru
    return cached_flag

529 

@property
def all_qubit_measurement(self) -> bool:
    """Return True when the output selection is exactly ``[0, ..., n_qubits - 1]``."""
    every_qubit = [*range(self.n_qubits)]
    return self.output_qubit == every_qubit

534 

@property
def batch_shape(self) -> Tuple[int, ...]:
    """
    Return the batch shape (B_I, B_P, B_R) of the most recent call.

    Returns:
        Tuple[int, ...]: (input_batch, param_batch, pulse_batch), or
        (1, 1, 1) if the model has not been called yet.
    """
    recorded = self._batch_shape
    if recorded is not None:
        return recorded
    log.debug("Model was not called yet. Returning (1,1,1) as batch shape.")
    return (1, 1, 1)

550 

@property
def eff_batch_shape(self) -> Tuple[int, ...]:
    """
    Return the batch shape masked by ``repeat_batch_axis``.

    Axes whose ``repeat_batch_axis`` flag is False are multiplied by zero.
    All zero-sized entries are then dropped.

    Returns:
        Tuple[int, ...]: The remaining batch sizes. The value is actually a
        NumPy array, despite the annotation.
    """
    masked = np.array(self.batch_shape) * self.repeat_batch_axis
    nonzero = masked != 0
    return masked[nonzero]

562 

def initialize_params(
    self,
    random_key: Optional[random.PRNGKey] = None,
    repeat: int = 1,
    initialization: Optional[str] = None,
    initialization_domain: Optional[List[float]] = None,
) -> random.PRNGKey:
    """
    Initialize the variational parameters of the model (sets ``self.params``).

    Args:
        random_key (Optional[random.PRNGKey]): JAX random key for
            initialization. If None, uses the model's internal random key.
        repeat (int): Number of parameter sets to create (batch dimension).
            Defaults to 1.
        initialization (Optional[str]): Strategy for parameter
            initialization. Options: "random", "zeros", "pi",
            "zero-controlled", "pi-controlled". The "*-controlled"
            variants draw uniformly at random and then set the
            controlled-rotation parameters (as reported by
            ``self.pqc.get_control_indices``) to 0 or pi. If None, uses the
            strategy specified in the constructor.
        initialization_domain (Optional[List[float]]): Domain [min, max] for
            random initialization. If None, uses the domain from the
            constructor.

    Returns:
        random.PRNGKey: Updated random key after initialization.

    Raises:
        ValueError: If an invalid initialization method is specified
            (a subclass of ``Exception``, as raised previously).
    """
    # leading batch axis of size `repeat`
    params_shape = (repeat, *self._params_shape)

    # use existing strategy if not specified
    initialization = initialization or self._inialization_strategy
    initialization_domain = initialization_domain or self._initialization_domain

    valid_strategies = ("random", "zeros", "pi", "zero-controlled", "pi-controlled")
    if initialization not in valid_strategies:
        raise ValueError(f"Invalid initialization method: {initialization}")

    random_key, sub_key = safe_random_split(
        random_key if random_key is not None else self.random_key
    )

    def set_control_params(params: jnp.ndarray, value: float) -> jnp.ndarray:
        """Overwrite the controlled-rotation parameter slice with `value`."""
        indices = self.pqc.get_control_indices(self.n_qubits)
        if indices is None:
            warnings.warn(
                f"Specified {initialization} but circuit "
                "does not contain controlled rotation gates. "
                "Parameters are initialized randomly.",
                UserWarning,
            )
            return params
        # JAX arrays are immutable: edit a NumPy copy, then convert back.
        np_params = np.array(params)
        np_params[:, :, indices[0] : indices[1] : indices[2]] = value
        return jnp.array(np_params)

    if initialization == "zeros":
        params = jnp.zeros(params_shape)
    elif initialization == "pi":
        params = jnp.ones(params_shape) * jnp.pi
    else:
        # "random" and both "*-controlled" strategies start from uniform noise
        params = random.uniform(
            sub_key,
            params_shape,
            minval=initialization_domain[0],
            maxval=initialization_domain[1],
        )
        if initialization == "zero-controlled":
            params = set_control_params(params, 0)
        elif initialization == "pi-controlled":
            params = set_control_params(params, jnp.pi)

    self.params = params

    log.info(
        f"Initialized parameters with shape {self.params.shape} "
        f"using strategy {initialization}."
    )

    return random_key

655 

def transform_input(
    self, inputs: jnp.ndarray, enc_params: jnp.ndarray
) -> jnp.ndarray:
    """
    Scale the input data by the encoding weights (arXiv:2309.03279v2).

    Args:
        inputs (jnp.ndarray): Input data of shape (n_input_feat,) or
            (batch_size, n_input_feat), or a per-feature slice thereof.
        enc_params (jnp.ndarray): Scalar or vector encoding weight(s).

    Returns:
        jnp.ndarray: Element-wise (broadcast) product of both arguments.
    """
    scaled = inputs * enc_params
    return scaled

677 

def _iec(
    self,
    inputs: jnp.ndarray,
    data_reupload: jnp.ndarray,
    enc: Encoding,
    enc_params: jnp.ndarray,
    noise_params: Optional[Dict[str, Union[float, Dict[str, float]]]] = None,
    random_key: Optional[random.PRNGKey] = None,
) -> None:
    """
    Apply one Input Encoding Circuit (IEC) layer.

    For every qubit ``q`` and input feature ``i`` whose mask entry
    ``data_reupload[q, i]`` is set, the encoding gate ``enc[i]`` is applied
    to ``q`` with angle ``transform_input(inputs[..., i], enc_params[q, i])``.
    A fresh sub-key is split off ``random_key`` for every applied gate.

    The whole layer is skipped when ``remove_zero_encoding`` is enabled,
    ``_zero_inputs`` is set and the input batch size is 1.

    Args:
        inputs (jnp.ndarray): Input data, features on the last axis.
        data_reupload (jnp.ndarray): Boolean mask of shape
            (n_qubits, n_input_feat).
        enc (Encoding): Encoding strategy; indexed per input feature.
        enc_params (jnp.ndarray): Encoding weights of shape
            (n_qubits, n_input_feat).
        noise_params (Optional[Dict[str, Union[float, Dict[str, float]]]]):
            Passed through to every encoding gate. Defaults to None.
        random_key (Optional[random.PRNGKey]): Key split for the gates.
            Defaults to None.

    Returns:
        None: Gates are recorded on the current circuit.
    """
    # Inputs cannot be None after validation, so "all zero" is the case
    # that can be skipped entirely.
    skip_layer = (
        self.remove_zero_encoding and self._zero_inputs and self.batch_shape[0] == 1
    )
    if skip_layer:
        return

    for wire in range(self.n_qubits):
        # features live on the last axis; inputs are not qubit dependent
        for feat in range(inputs.shape[-1]):
            if not data_reupload[wire, feat]:
                continue
            random_key, sub_key = safe_random_split(random_key)
            encoding_gate = enc[feat]
            encoding_gate(
                self.transform_input(inputs[..., feat], enc_params[wire, feat]),
                wires=wire,
                noise_params=noise_params,
                random_key=sub_key,
                input_idx=feat,
            )

728 

def _variational(
    self,
    params: jnp.ndarray,
    inputs: jnp.ndarray,
    pulse_params: Optional[jnp.ndarray] = None,
    random_key: Optional[random.PRNGKey] = None,
    enc_params: Optional[jnp.ndarray] = None,
    gate_mode: str = "unitary",
    noise_params: Optional[Dict[str, Union[float, Dict[str, float]]]] = None,
) -> None:
    """
    Build the variational quantum circuit structure.

    The circuit is built in this order:
        1. state-preparation noise (when noise is enabled);
        2. state preparation on every qubit;
        3. ``n_layers`` rounds of ansatz layer followed by encoding layer;
        4. a final ansatz layer when data re-uploading is used;
        5. general channel noise (when noise is enabled).

    The first five parameters after ``self`` (``params``, ``inputs``,
    ``pulse_params``, ``random_key``, ``enc_params``) are the batchable
    positional arguments. The remaining keyword arguments are broadcast
    across the batch.

    Args:
        params (jnp.ndarray): Variational parameters of shape
            (impl_n_layers, n_params_per_layer). A leading batch axis of
            size 1 is dropped.
        inputs (jnp.ndarray): Input data of shape (n_input_feat,). A leading
            batch axis of size 1 is dropped.
        pulse_params (Optional[jnp.ndarray]): Pulse parameter scalers of
            shape (impl_n_layers, n_pulse_params_per_layer). Defaults to
            None, which uses ``self.pulse_params``.
        random_key (Optional[random.PRNGKey]): JAX random key for
            stochastic operations. Defaults to None. With noise enabled,
            None falls back to ``self.random_key``.
        enc_params (Optional[jnp.ndarray]): Encoding parameters indexed per
            layer. Defaults to None, which uses ``self.enc_params``.
        gate_mode (str): Gate execution mode, either "unitary" or "pulse".
            Defaults to "unitary".
        noise_params (Optional[Dict[str, Union[float, Dict[str, float]]]]):
            Noise parameters. Defaults to None, which uses
            ``self.noise_params``.

    Returns:
        None: Gates are recorded on the current circuit.

    Note:
        A RuntimeWarning is issued when a fallback to a model attribute is
        used in a situation where the forward pass would normally have
        passed the value explicitly.
    """
    # Drop a leading batch axis of size 1.
    # TODO: rework and double check params shape
    if len(params.shape) > 2 and params.shape[0] == 1:
        params = params[0]

    if len(inputs.shape) > 1 and inputs.shape[0] == 1:
        inputs = inputs[0]

    if enc_params is None:
        # only warn if the user actually trains the encoding parameters
        if self.trainable_frequencies:
            warnings.warn(
                "Explicit call to `_circuit` or `_variational` detected: "
                "`enc_params` is None, using `self.enc_params` instead.",
                RuntimeWarning,
            )
        enc_params = self.enc_params

    if pulse_params is None:
        if gate_mode == "pulse":
            warnings.warn(
                "Explicit call to `_circuit` or `_variational` detected: "
                "`pulse_params` is None, using `self.pulse_params` instead.",
                RuntimeWarning,
            )
        pulse_params = self.pulse_params

    # Squeeze batch dimension for pulse_params (batch-first convention)
    if len(pulse_params.shape) > 2 and pulse_params.shape[0] == 1:
        pulse_params = pulse_params[0]

    if noise_params is None:
        if self.noise_params is not None:
            warnings.warn(
                "Explicit call to `_circuit` or `_variational` detected: "
                "`noise_params` is None, using `self.noise_params` instead.",
                RuntimeWarning,
            )
        noise_params = self.noise_params

    if noise_params is not None:
        # Fall back to the model key only when none was supplied, so that a
        # key passed by the caller is never overridden.
        if random_key is None:
            warnings.warn(
                "Explicit call to `_circuit` or `_variational` detected: "
                "`random_key` is None, using `self.random_key` instead.",
                RuntimeWarning,
            )
            random_key = self.random_key
        self._apply_state_prep_noise(noise_params=noise_params)

    # state preparation: every configured gate on every qubit
    for q in range(self.n_qubits):
        for sp_gate, sp_pulse_params in zip(self._sp, self.sp_pulse_params):
            random_key, sub_key = safe_random_split(random_key)
            sp_gate(
                wires=q,
                pulse_params=sp_pulse_params,
                noise_params=noise_params,
                random_key=sub_key,
                gate_mode=gate_mode,
            )

    # circuit building: alternating ansatz and encoding layers
    for layer in range(self.n_layers):
        random_key, sub_key = safe_random_split(random_key)
        self.pqc(
            params[layer],
            self.n_qubits,
            pulse_params=pulse_params[layer],
            noise_params=noise_params,
            random_key=sub_key,
            gate_mode=gate_mode,
        )

        random_key, sub_key = safe_random_split(random_key)
        self._iec(
            inputs,
            data_reupload=self.data_reupload[layer],
            enc=self._enc,
            enc_params=enc_params[layer],
            noise_params=noise_params,
            random_key=sub_key,
        )

    # final ansatz layer; the extra (L+1)-th layer exists only with
    # data re-uploading (same check as in __init__)
    if self.has_dru:
        random_key, sub_key = safe_random_split(random_key)
        self.pqc(
            params[self.n_layers],
            self.n_qubits,
            pulse_params=pulse_params[-1],
            noise_params=noise_params,
            random_key=sub_key,
            gate_mode=gate_mode,
        )

    # channel noise
    if noise_params is not None:
        self._apply_general_noise(noise_params=noise_params)

877 

def _build_obs(self) -> Tuple[str, List[op.Operation]]:
    """Translate ``execution_type``/``output_qubit`` into yaqsi measurement arguments.

    The result is suitable for :meth:`~qml_essentials.yaqsi.Script.execute`.

    Returns:
        Tuple ``(meas_type, obs)``. *meas_type* is one of ``"expval"``,
        ``"probs"``, ``"density"`` or ``"state"``. *obs* is a list of
        observables, and is non-empty only for ``"expval"``: a PauliZ per
        integer output qubit, or a parity observable per qubit group.

    Raises:
        ValueError: If ``execution_type`` is not one of the above.
    """
    mode = self.execution_type

    # These measurement types need no observables. For "probs" the full
    # system is measured; subsystem marginalisation happens in
    # _postprocess_res.
    if mode in ("density", "state", "probs"):
        return mode, []

    if mode == "expval":
        observables: List[op.Operation] = [
            (
                op.PauliZ(wires=spec)
                if isinstance(spec, int)
                # parity: Z \\otimes Z \\otimes ...
                else ys.build_parity_observable(list(spec))
            )
            for spec in self.output_qubit
        ]
        return "expval", observables

    raise ValueError(f"Invalid execution_type: {self.execution_type}.")

912 

def _apply_state_prep_noise(
    self, noise_params: Dict[str, Union[float, Dict[str, float]]]
) -> None:
    """
    Apply state preparation noise to all qubits.

    Simulates imperfect state preparation by applying a BitFlip channel
    with probability ``noise_params["StatePreparation"]`` to each qubit.
    No channel is applied when the entry is missing, ``None`` or not
    strictly positive.

    Args:
        noise_params (Dict[str, Union[float, Dict[str, float]]]): Dictionary
            containing noise parameters. Only the "StatePreparation" key is
            read; it holds the BitFlip probability.

    Returns:
        None: Noise channels are applied in-place to the circuit.
    """
    p = noise_params.get("StatePreparation", 0.0)
    # A ``None`` entry means "disabled"; comparing it with ``> 0`` would
    # raise a TypeError, so it is filtered out first.
    if p is not None and p > 0:
        for q in range(self.n_qubits):
            op.BitFlip(p, wires=q)

934 

def _apply_general_noise(
    self, noise_params: Dict[str, Union[float, Dict[str, float]]]
) -> None:
    """
    Apply general noise channels to all qubits.

    Applies various decoherence and error channels after the circuit
    execution, simulating environmental noise effects. For every qubit the
    channels are emitted in the order AmplitudeDamping, PhaseDamping,
    BitFlip (measurement error), ThermalRelaxationError.

    Args:
        noise_params (Dict[str, Union[float, Dict[str, float]]]): Dictionary
            containing noise parameters with the following supported keys:
            - "AmplitudeDamping" (float): Probability for amplitude damping.
            - "PhaseDamping" (float): Probability for phase damping.
            - "Measurement" (float): Probability for measurement error (BitFlip).
            - "ThermalRelaxation" (Dict): Dictionary with keys "t1", "t2",
              "t_factor" for thermal relaxation simulation; the gate time
              is ``circuit_depth * t_factor``.
            Missing keys and ``None`` values disable the respective channel.

    Returns:
        None: Noise channels are applied in-place to the circuit.

    Note:
        Gate-level noise (e.g., GateError) is handled separately in the
        Gates.Noise module and applied at the individual gate level.
    """

    def _probability(key: str) -> float:
        # Missing or ``None`` entries mean "channel disabled"; ``None``
        # must not reach the ``> 0`` comparisons below (TypeError).
        value = noise_params.get(key, 0.0)
        return 0.0 if value is None else value

    amp_damp = _probability("AmplitudeDamping")
    phase_damp = _probability("PhaseDamping")
    meas = _probability("Measurement")
    thermal_relax = noise_params.get("ThermalRelaxation", 0.0)

    # The thermal-relaxation parameters are the same for every qubit, so
    # the circuit depth (and hence the gate time) is computed once here
    # instead of inside the per-qubit loop.
    thermal_args = None
    if isinstance(thermal_relax, dict):
        t1 = thermal_relax["t1"]
        t2 = thermal_relax["t2"]
        t_factor = thermal_relax["t_factor"]
        tg = self._get_circuit_depth() * t_factor
        thermal_args = (t1, t2, tg)

    for q in range(self.n_qubits):
        if amp_damp > 0:
            op.AmplitudeDamping(amp_damp, wires=q)
        if phase_damp > 0:
            op.PhaseDamping(phase_damp, wires=q)
        if meas > 0:
            op.BitFlip(meas, wires=q)
        if thermal_args is not None:
            op.ThermalRelaxationError(1.0, *thermal_args, q)

978 

def _get_circuit_depth(self, inputs: Optional[jnp.ndarray] = None) -> int:
    """
    Calculate the depth of the quantum circuit.

    Records the circuit onto a tape (without noise) and computes the
    depth as the length of the critical path: each non-noise operation is
    scheduled at the earliest time step after all of its wires are free.

    The result is cached in ``self._cached_circuit_depth``. Once cached,
    it is returned directly, regardless of *inputs*.

    Model state touched while recording (``_noise_params`` and the
    ``_zero_inputs`` flag written by ``_inputs_validation``) is restored
    afterwards, also when recording raises. This method can run in the
    middle of a forward pass, because ``_apply_general_noise`` calls it.

    Args:
        inputs (Optional[jnp.ndarray]): Input data for circuit evaluation.
            If None, default zero inputs are used. For batched inputs or
            parameters, only the first batch element is recorded.

    Returns:
        int: The circuit depth (longest path of gates in the circuit).
    """
    # Return cached value if available
    if hasattr(self, "_cached_circuit_depth"):
        return self._cached_circuit_depth

    saved_noise = self._noise_params
    # _inputs_validation overwrites ``_zero_inputs`` for the recording
    # inputs; keep the value set by the caller's own validation.
    saved_zero_inputs = getattr(self, "_zero_inputs", None)
    try:
        inputs = self._inputs_validation(inputs)

        # Temporarily clear noise_params to prevent _variational from
        # picking them up (which would call _apply_general_noise ->
        # _get_circuit_depth again, causing infinite recursion).
        self._noise_params = None

        with recording() as tape:
            self._variational(
                self.params[0] if self.params.ndim == 3 else self.params,
                inputs[0] if inputs.ndim == 2 else inputs,
                noise_params=None,
            )
    finally:
        self._noise_params = saved_noise
        if saved_zero_inputs is not None:
            self._zero_inputs = saved_zero_inputs

    # Filter out noise channels - only count unitary gates
    ops = [o for o in tape if not isinstance(o, KrausChannel)]

    if not ops:
        self._cached_circuit_depth = 0
        return 0

    # Schedule each gate at the earliest time step where all its wires
    # are free. ``wire_busy[q]`` tracks the next free time step for
    # qubit ``q``.
    wire_busy: Dict[int, int] = {}
    depth = 0
    for gate in ops:
        start = max((wire_busy.get(w, 0) for w in gate.wires), default=0)
        end = start + 1
        for w in gate.wires:
            wire_busy[w] = end
        depth = max(depth, end)

    self._cached_circuit_depth = depth
    return depth

1036 

def draw(
    self,
    inputs: Optional[jnp.ndarray] = None,
    figure: str = "text",
    **kwargs: Any,
) -> Union[str, Any]:
    """Visualize the quantum circuit.

    Records the circuit tape (without noise) and renders the gate
    sequence using the requested backend. Only the first element of a
    batch of inputs / parameters is drawn. The model's noise
    configuration is restored afterwards, even if rendering fails.

    Args:
        inputs (Optional[jnp.ndarray]): Input data for the circuit.
            If ``None``, default zero inputs are used.
        figure (str): Rendering backend. One of:

            * ``"text"`` - ASCII art (returned as a ``str``).
            * ``"mpl"`` - Matplotlib figure (returns ``(fig, ax)``).
            * ``"tikz"`` - LaTeX/TikZ ``quantikz`` code (returns a
              :class:`TikzFigure`).
            * ``"pulse"`` - Pulse schedule (returns ``(fig, axes)``).
              Delegated to :meth:`draw_pulse`; only meaningful for
              pulse-mode models.

        **kwargs: Extra options forwarded to the drawing backend
            (e.g. ``gate_values=True``).

    Returns:
        Depends on figure:

        * ``"text"`` -> ``str``
        * ``"mpl"`` -> ``(matplotlib.figure.Figure, matplotlib.axes.Axes)``
        * ``"tikz"`` -> :class:`TikzFigure`
        * ``"pulse"`` -> whatever :meth:`draw_pulse` returns

    Raises:
        ValueError: If figure is not one of the supported modes. This
            method does not check *figure* itself; presumably the error
            is raised by ``ys.Script.draw``.
    """
    inputs = self._inputs_validation(inputs)
    params = self.params[0] if self.params.ndim == 3 else self.params
    inp = inputs[0] if inputs.ndim == 2 else inputs

    if figure == "pulse":
        return self.draw_pulse(inputs=inputs, **kwargs)

    # Record without noise to get a clean circuit. The try/finally
    # guarantees the model's noise configuration survives a failing
    # backend (e.g. an unsupported figure type).
    saved_noise = self._noise_params
    self._noise_params = None
    try:
        draw_script = ys.Script(f=self._variational, n_qubits=self.n_qubits)
        return draw_script.draw(
            figure=figure,
            args=(params, inp),
            kwargs={"noise_params": None},
            **kwargs,
        )
    finally:
        self._noise_params = saved_noise

1094 

def draw_pulse(
    self,
    inputs: Optional[jnp.ndarray] = None,
    **kwargs: Any,
) -> Any:
    """Visualize the pulse schedule for the circuit.

    Records the circuit in pulse mode and collects PulseEvents
    automatically via the pulse-event tape, then renders them. Only the
    first element of a batch of inputs / parameters is drawn. The
    circuit is recorded without noise: ``noise_params=None`` is passed
    and ``self._noise_params`` is cleared for the duration of the call,
    then restored (also on failure).

    Args:
        inputs: Input data. If ``None``, default zero inputs are used.
        **kwargs: Forwarded to
            :func:`~qml_essentials.drawing.draw_pulse_schedule`
            (e.g. ``show_carrier=True``, ``n_samples=300``).

    Returns:
        ``(fig, axes)`` — Matplotlib Figure and array of Axes.
    """
    inputs = self._inputs_validation(inputs)
    params = self.params[0] if self.params.ndim == 3 else self.params
    inp = inputs[0] if inputs.ndim == 2 else inputs

    # Clear the stored noise configuration while recording, as the other
    # tape-recording paths of the model do, so the pulse schedule shows
    # a clean circuit. NOTE(review): this assumes _variational may fall
    # back to self._noise_params; confirm against its implementation.
    saved_noise = self._noise_params
    self._noise_params = None
    try:
        draw_script = ys.Script(f=self._variational, n_qubits=self.n_qubits)
        return draw_script.draw(
            figure="pulse",
            args=(params, inp),
            kwargs={
                "gate_mode": "pulse",
                "noise_params": None,
            },
            **kwargs,
        )
    finally:
        self._noise_params = saved_noise

1128 

def __repr__(self) -> str:
    """Show the model as its ASCII circuit diagram."""
    ascii_diagram = self.draw(figure="text")
    return ascii_diagram

1132 

def __str__(self) -> str:
    """Render the model as a human-readable ASCII circuit diagram."""
    ascii_diagram = self.draw(figure="text")
    return ascii_diagram

1136 

def _params_validation(self, params: Optional[jnp.ndarray]) -> jnp.ndarray:
    """
    Validate and normalize variational parameters.

    Ensures parameters have a leading batch dimension, and stores them on
    the model (``self.params``) when new ones are provided.

    Args:
        params (Optional[jnp.ndarray]): Variational parameters to validate,
            either unbatched (2-D) or batched (3-D).
            If None, returns the model's current parameters unchanged.

    Returns:
        jnp.ndarray: Validated parameters with shape
            (batch_size, n_layers, n_params_per_layer).
    """
    if params is None:
        return self.params

    # Append the batch axis if it is missing. ``jnp.expand_dims`` (rather
    # than ``np.expand_dims``) keeps JAX arrays and tracers intact, so the
    # model stays differentiable w.r.t. unbatched parameters; NumPy would
    # try to convert a tracer to a concrete array and fail.
    if len(params.shape) == 2:
        params = jnp.expand_dims(params, axis=0)

    self.params = params
    return params

1162 

def _pulse_params_validation(
    self, pulse_params: Optional[jnp.ndarray]
) -> jnp.ndarray:
    """
    Normalize pulse parameter scalers, falling back to the model's own.

    When *pulse_params* is given, a missing leading batch axis is added
    (batch-first convention). The result is then stored as
    ``self.pulse_params``.

    Args:
        pulse_params (Optional[jnp.ndarray]): Pulse parameter scalers, 2-D
            (unbatched) or already batched. If None, the model's current
            pulse parameters are returned as-is.

    Returns:
        jnp.ndarray: Pulse parameters with shape
            (batch_size, n_layers, n_pulse_params_per_layer).
    """
    if pulse_params is None:
        return self.pulse_params

    needs_batch_axis = len(pulse_params.shape) == 2
    if needs_batch_axis:
        pulse_params = jnp.expand_dims(pulse_params, axis=0)

    self.pulse_params = pulse_params
    return pulse_params

1188 

def _enc_params_validation(self, enc_params: Optional[jnp.ndarray]) -> jnp.ndarray:
    """
    Validate and normalize encoding parameters.

    Ensures encoding parameters have the correct shape for the model's
    input feature dimensions. Newly provided parameters are stored on
    ``self.enc_params``: as given when ``trainable_frequencies`` is set,
    otherwise converted with ``jnp.array``. In the latter case the
    converted array is also what gets validated and returned, so lists
    and tuples are accepted.

    Args:
        enc_params (Optional[jnp.ndarray]): Encoding parameters to validate.
            If None, the model's current encoding parameters are used.

    Returns:
        jnp.ndarray: Validated encoding parameters with shape
            (n_qubits, n_input_feat). A 1-D array is reshaped to a column
            when n_input_feat == 1.

    Raises:
        ValueError: If enc_params is 1-D while n_input_feat > 1.
    """
    if enc_params is None:
        enc_params = self.enc_params
    elif self.trainable_frequencies:
        self.enc_params = enc_params
    else:
        # Convert the local too, not only the stored attribute, so that
        # list/tuple input has a ``.shape`` for the checks below.
        enc_params = jnp.array(enc_params)
        self.enc_params = enc_params

    if len(enc_params.shape) == 1:
        if self.n_input_feat == 1:
            enc_params = enc_params.reshape(-1, 1)
        elif self.n_input_feat > 1:
            # Implicit concatenation instead of a backslash continuation
            # inside the literal, which embedded a run of spaces.
            raise ValueError(
                f"Input dimension {self.n_input_feat} >1 but "
                f"`enc_params` has shape {enc_params.shape}"
            )

    return enc_params

1224 

def _inputs_validation(
    self, inputs: Union[None, List, float, int, jnp.ndarray]
) -> jnp.ndarray:
    """
    Validate and normalize input data.

    Converts various input formats to a standardized 2D array shape
    suitable for batch processing in the quantum circuit. As a side
    effect, ``self._zero_inputs`` records whether all inputs are zero.

    Args:
        inputs (Union[None, List, float, int, jnp.ndarray]): Input data in
            various formats:
            - None: Returns zeros with shape (1, n_input_feat)
            - float/int: Single scalar value
            - list/tuple: List of values or batched inputs (np.stack-ed)
            - jnp.ndarray / np.ndarray: arrays of rank <= 1 are reshaped
              (a 0-d array or NumPy scalar counts as one value); arrays of
              rank >= 2 must have n_input_feat as their second dimension.

    Returns:
        jnp.ndarray: Validated inputs with shape (batch_size, n_input_feat).

    Raises:
        ValueError: If input shape is incompatible with expected n_input_feat.

    Warns:
        UserWarning: If a 1-D input whose length differs from n_input_feat
            is replicated across all input features.
    """
    self._zero_inputs = False
    if isinstance(inputs, (list, tuple)):
        inputs = jnp.array(np.stack(inputs))
    elif isinstance(inputs, (float, int)):
        inputs = jnp.array([inputs])
    elif inputs is None:
        inputs = jnp.array([[0] * self.n_input_feat])

    if not inputs.any():
        self._zero_inputs = True

    if len(inputs.shape) == 0:
        # 0-d arrays (e.g. ``jnp.array(0.5)`` or ``np.float32(0.5)``) have
        # no ``shape[0]``; treat them like a single scalar value.
        inputs = inputs.reshape(1)

    if len(inputs.shape) == 1:
        if self.n_input_feat == 1:
            # add a batch dimension
            inputs = inputs.reshape(-1, 1)
        elif inputs.shape[0] == self.n_input_feat:
            inputs = inputs.reshape(1, -1)
        else:
            inputs = inputs.reshape(-1, 1)
            inputs = inputs.repeat(self.n_input_feat, axis=1)
            warnings.warn(
                f"Expected {self.n_input_feat} inputs, but {inputs.shape[0]} "
                "was provided, replicating input for all input features.",
                UserWarning,
            )
    elif inputs.shape[1] != self.n_input_feat:
        raise ValueError(
            f"Wrong number of inputs provided. Expected {self.n_input_feat} "
            f"inputs, but input has shape {inputs.shape}."
        )

    return inputs

1285 

def _postprocess_res(self, result: Union[List, jnp.ndarray]) -> jnp.ndarray:
    """
    Bring raw circuit output into a uniform, batch-first array.

    An array is returned untouched. A list (one entry per measurement) is
    stacked along a new leading axis. If the stacked result has more than
    one dimension, that leading measurement axis is swapped with the
    following one.

    Args:
        result (Union[List, jnp.ndarray]): Raw circuit output, either a
            list of per-measurement results or a single array.

    Returns:
        jnp.ndarray: The post-processed result.
    """
    if not isinstance(result, list):
        return result

    stacked = jnp.stack(result)
    if stacked.ndim <= 1:
        return stacked
    # moveaxis rather than a full transpose: parity measurements append
    # an extra trailing dimension that must stay where it is.
    return jnp.moveaxis(stacked, 0, 1)

1308 

def _assimilate_batch(
    self,
    inputs: jnp.ndarray,
    params: jnp.ndarray,
    pulse_params: jnp.ndarray,
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """
    Align batch dimensions across inputs, parameters, and pulse parameters.

    Broadcasts and reshapes arrays to have compatible batch dimensions
    for vectorized circuit execution. Sets the internal batch_shape.

    Each array with a batch size > 1 whose own ``repeat_batch_axis`` flag
    is set is expanded to ``[B_I, B_P, B_R, ...]``. It is repeated along
    every other axis whose flag is set, then flattened to ``[B, ...]``.
    Arrays that do not meet these conditions are returned unchanged.

    Args:
        inputs (jnp.ndarray): Input data of shape (B_I, n_input_feat).
        params (jnp.ndarray): Parameters of shape (B_P, n_layers, n_params).
        pulse_params (jnp.ndarray): Pulse params of shape (B_R, n_layers, n_pulse).

    Returns:
        Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: Tuple containing:
            - inputs: Reshaped to (B, n_input_feat)
            - params: Reshaped to (B, n_layers, n_params)
            - pulse_params: Reshaped to (B, n_layers, n_pulse)
            where B = prod(self.eff_batch_shape).

    Note:
        The effective batch shape depends on repeat_batch_axis configuration.
        This is the only method that sets self._batch_shape.
    """
    B_I = inputs.shape[0]
    # A params array with a zero-sized dimension means there are no
    # parameters at all; its batch size is treated as 1.
    B_P = 1 if 0 in params.shape else params.shape[0]
    B_R = pulse_params.shape[0]

    # THIS is the only place where we set the batch shape
    self._batch_shape = (B_I, B_P, B_R)
    B = np.prod(self.eff_batch_shape)

    # [B_I, ...] -> [B_I, B_P, B_R, ...] -> [B, ...]
    if B_I > 1 and self.repeat_batch_axis[0]:
        # Always add the B_P and B_R axes (as done for params and pulse
        # params below). Adding them only when repeat_batch_axis[1] is set
        # made the axis=2 repeat fail, or the final reshape drop the
        # feature axis, whenever the params axis was not repeated.
        inputs = inputs[:, None, None, ...]  # [B_I, 1, 1, ...]
        if self.repeat_batch_axis[1]:
            inputs = jnp.repeat(inputs, B_P, axis=1)  # [B_I, B_P, 1, ...]
        if self.repeat_batch_axis[2]:
            inputs = jnp.repeat(inputs, B_R, axis=2)  # [B_I, B_P, B_R, ...]
        inputs = inputs.reshape(B, *inputs.shape[3:])

    # [B_P, ..., ...] -> [B_I, B_P, B_R, ..., ...] -> [B, ..., ...]
    if B_P > 1 and self.repeat_batch_axis[1]:
        # add B_I axis before first, and B_R axis after first batch dim
        params = params[None, :, None, ...]  # [B_I(=1), B_P, B_R(=1), ...]
        if self.repeat_batch_axis[0]:
            params = jnp.repeat(params, B_I, axis=0)  # [B_I, B_P, 1, ...]
        if self.repeat_batch_axis[2]:
            params = jnp.repeat(params, B_R, axis=2)  # [B_I, B_P, B_R, ...]
        params = params.reshape(B, *params.shape[3:])

    # [B_R, ..., ...] -> [B_I, B_P, B_R, ..., ...] -> [B, ..., ...]
    if B_R > 1 and self.repeat_batch_axis[2]:
        # add B_I axis and B_P axis before B_R
        pulse_params = pulse_params[None, None, ...]  # [B_I(=1), B_P(=1), B_R, ...]
        if self.repeat_batch_axis[0]:
            pulse_params = jnp.repeat(pulse_params, B_I, axis=0)  # [B_I, 1, B_R, ...]
        if self.repeat_batch_axis[1]:
            pulse_params = jnp.repeat(pulse_params, B_P, axis=1)  # [B_I, B_P, B_R, ...]
        pulse_params = pulse_params.reshape(B, *pulse_params.shape[3:])

    return inputs, params, pulse_params

1379 

def _requires_density(self) -> bool:
    """
    Check if density matrix simulation is required.

    Determines whether the circuit must be executed with the mixed-state
    simulator based on execution type and noise configuration.

    Returns:
        bool: True if density matrix simulation is required, False otherwise.
            Returns True if:
            - execution_type is "density", or
            - any non-coherent noise channel is enabled, i.e. has a
              numeric value > 0 or a dict configuration (such as
              "ThermalRelaxation" with its "t1"/"t2"/"t_factor" keys).
            Coherent noise ("GateError") and ``None`` entries are ignored.
    """
    if self.execution_type == "density":
        return True

    if self.noise_params is None:
        return False

    coherent_noise = {"GateError"}
    for k, v in self.noise_params.items():
        if k in coherent_noise or v is None:
            continue
        if isinstance(v, dict):
            # Dict-valued channels (e.g. ThermalRelaxation) are enabled by
            # their mere presence; ``dict > 0`` would raise a TypeError.
            return True
        if v > 0:
            return True
    return False

1406 

def __call__(
    self,
    params: Optional[jnp.ndarray] = None,
    inputs: Optional[jnp.ndarray] = None,
    pulse_params: Optional[jnp.ndarray] = None,
    enc_params: Optional[jnp.ndarray] = None,
    data_reupload: Union[bool, List[List[bool]], List[List[List[bool]]]] = None,
    noise_params: Optional[Dict[str, Union[float, Dict[str, float]]]] = None,
    execution_type: Optional[str] = None,
    force_mean: bool = False,
    gate_mode: str = "unitary",
) -> jnp.ndarray:
    """
    Run the circuit; thin callable front-end for ``_forward``.

    Every argument is handed to ``_forward`` unchanged, and its result is
    returned as-is.

    Args:
        params (Optional[jnp.ndarray]): Variational parameters of shape
            (n_layers, n_params_per_layer) or (batch, n_layers, n_params_per_layer).
            If None, the model's internal parameters are used.
        inputs (Optional[jnp.ndarray]): Input data of shape
            (batch_size, n_input_feat). If None, zero inputs are used.
        pulse_params (Optional[jnp.ndarray]): Pulse parameter scalers for
            pulse-mode gate execution.
        enc_params (Optional[jnp.ndarray]): Encoding parameters of shape
            (n_qubits, n_input_feat). If None, the model's encoding
            parameters are used.
        data_reupload (Union[bool, List[List[bool]], List[List[List[bool]]]]):
            Data reupload configuration. If None, the previously set
            configuration is kept.
        noise_params (Optional[Dict[str, Union[float, Dict[str, float]]]]):
            Noise configuration. If None, the previously set noise
            parameters are kept.
        execution_type (Optional[str]): Measurement type: "expval",
            "density", "probs", or "state". If None, the current
            execution_type setting is kept.
        force_mean (bool): If True, results are averaged over the
            measurement qubits. Defaults to False.
        gate_mode (str): Gate execution backend, "unitary" or "pulse".
            Defaults to "unitary".

    Returns:
        jnp.ndarray: Circuit output whose shape depends on execution_type:
            - "expval": (n_output_qubits,) or scalar
            - "density": (2^n_output, 2^n_output)
            - "probs": (2^n_output,) or (n_pairs, 2^pair_size)
            - "state": (2^n_qubits,)
    """
    forward_kwargs = {
        "params": params,
        "inputs": inputs,
        "pulse_params": pulse_params,
        "enc_params": enc_params,
        "data_reupload": data_reupload,
        "noise_params": noise_params,
        "execution_type": execution_type,
        "force_mean": force_mean,
        "gate_mode": gate_mode,
    }
    # Validation, state updates and execution all live in _forward.
    return self._forward(**forward_kwargs)

1466 

def _forward(
    self,
    params: Optional[jnp.ndarray] = None,
    inputs: Optional[jnp.ndarray] = None,
    pulse_params: Optional[jnp.ndarray] = None,
    enc_params: Optional[jnp.ndarray] = None,
    data_reupload: Union[bool, List[List[bool]], List[List[List[bool]]]] = None,
    noise_params: Optional[Dict[str, Union[float, Dict[str, float]]]] = None,
    execution_type: Optional[str] = None,
    force_mean: bool = False,
    gate_mode: str = "unitary",
) -> jnp.ndarray:
    """
    Execute the quantum circuit forward pass.

    Internal implementation of the forward pass. It validates argument
    consistency, stores the given configuration on the model, validates
    and batch-aligns all arrays, and executes ``self.script``. The script
    call is vectorised over the batch when the effective batch size is
    larger than one. Finally, the result is post-processed (partial
    trace / marginalisation for partial-qubit measurements, reshaping,
    optional mean).

    Args:
        params (Optional[jnp.ndarray]): Variational parameters of shape
            (n_layers, n_params_per_layer) or
            (batch, n_layers, n_params_per_layer).
            If None, uses model's internal parameters.
        inputs (Optional[jnp.ndarray]): Input data of shape
            (batch_size, n_input_feat).
            If None, uses zero inputs.
        pulse_params (Optional[jnp.ndarray]): Pulse parameter scalers for
            pulse-mode gate execution.
        enc_params (Optional[jnp.ndarray]): Encoding parameters of shape
            (n_qubits, n_input_feat). If None, uses model's encoding parameters.
        data_reupload (Union[bool, List[List[bool]], List[List[List[bool]]]]):
            Data reupload configuration. If None, uses previously set reupload
            configuration.
        noise_params (Optional[Dict[str, Union[float, Dict[str, float]]]]):
            Noise configuration. If None, uses previously set noise parameters.
        execution_type (Optional[str]): Measurement type: "expval", "density",
            "probs", or "state". If None, uses current execution_type setting.
        force_mean (bool): If True, averages results over measurement qubits.
            Defaults to False.
        gate_mode (str): Gate execution backend, "unitary" or "pulse".
            Defaults to "unitary".

    Returns:
        jnp.ndarray: Circuit output with shape depending on execution_type:
            - "expval": (n_output_qubits,) or scalar
            - "density": (2^n_output, 2^n_output)
            - "probs": (2^n_output,) or (n_pairs, 2^pair_size)
            - "state": (2^n_qubits,)

    Raises:
        ValueError: If pulse_params are provided while gate_mode is not
            "pulse". The check runs before any model attribute is updated,
            so a rejected call leaves the model unchanged.
            NOTE(review): earlier documentation also claimed that
            noise_params combined with pulse gate_mode is rejected; this
            method performs no such check. Confirm whether it is enforced
            elsewhere.
    """
    # consistency checks -- done before any attribute is touched, so a
    # rejected call does not leave half-applied settings on the model
    if pulse_params is not None and gate_mode != "pulse":
        raise ValueError(
            "pulse_params were provided but gate_mode is not 'pulse'. "
            "Either switch gate_mode='pulse' or do not pass pulse_params."
        )

    # set the parameters as object attributes
    if noise_params is not None:
        self.noise_params = noise_params
    if execution_type is not None:
        self.execution_type = execution_type
    self.gate_mode = gate_mode

    # TODO: add testing
    if data_reupload is not None:
        self.data_reupload = data_reupload

    params = self._params_validation(params)
    pulse_params = self._pulse_params_validation(pulse_params)
    inputs = self._inputs_validation(inputs)
    enc_params = self._enc_params_validation(enc_params)

    inputs, params, pulse_params = self._assimilate_batch(
        inputs,
        params,
        pulse_params,
    )

    # split to generate a sub_key, required for actual execution
    self.random_key, sub_key = safe_random_split(self.random_key)

    # Build measurement type & observables from execution_type / output_qubit
    meas_type, obs = self._build_obs()

    # Yaqsi auto-routes between statevector and density-matrix simulation
    # based on whether noise channels appear on the tape, so the same
    # ``self.script`` is executed in every case. Only the effective batch
    # size decides between a vectorised and a single execution.
    B = np.prod(self.eff_batch_shape)

    # kwargs are broadcast (not vmapped over)
    exec_kwargs = dict(
        noise_params=self.noise_params,
        gate_mode=self.gate_mode,
    )

    # Build a shot key from the random_key if shots are requested
    shot_key = None
    if self.shots is not None:
        # overwrite subkey and split shot_key
        sub_key, shot_key = safe_random_split(sub_key)

    if B > 1:
        # use random keys, derived from the subkey
        random_keys = safe_random_split(sub_key, num=B)

        # only arguments that actually carry a batch (size > 1) are mapped
        in_axes = (
            0 if self.batch_shape[1] > 1 else None,  # params
            0 if self.batch_shape[0] > 1 else None,  # inputs
            0 if self.batch_shape[2] > 1 else None,  # pulse_params
            0,  # random_keys
            None,  # enc_params (broadcast, not batched)
        )

        result = self.script.execute(
            type=meas_type,
            obs=obs,
            args=(params, inputs, pulse_params, random_keys, enc_params),
            kwargs=exec_kwargs,
            in_axes=in_axes,
            shots=self.shots,
            key=shot_key,
        )
    else:
        # use the subkey directly
        result = self.script.execute(
            type=meas_type,
            obs=obs,
            args=(params, inputs, pulse_params, sub_key, enc_params),
            kwargs=exec_kwargs,
            shots=self.shots,
            key=shot_key,
        )

    result = self._postprocess_res(result)

    # --- Post-processing for partial-qubit measurements ---------------
    if self.execution_type == "density" and not self.all_qubit_measurement:
        result = ys.partial_trace(result, self.n_qubits, self.output_qubit)

    if self.execution_type == "probs" and not self.all_qubit_measurement:
        if isinstance(self.output_qubit[0], (list, tuple)):
            # list of qubit groups - marginalize each independently
            result = jnp.stack(
                [
                    ys.marginalize_probs(result, self.n_qubits, list(group))
                    for group in self.output_qubit
                ]
            )
        else:
            result = ys.marginalize_probs(result, self.n_qubits, self.output_qubit)

    result = jnp.asarray(result)
    result = result.reshape((*self.eff_batch_shape, *self._result_shape)).squeeze()

    # Averaging only makes sense with more than one measured output; the
    # ``len(self._result_shape)`` guard avoids an IndexError for an empty
    # result shape when the batch keeps ``result`` non-scalar.
    if (
        self.execution_type in ("expval", "probs")
        and force_mean
        and len(result.shape) > 0
        and len(self._result_shape) > 0
        and self._result_shape[0] > 1
    ):
        result = result.mean(axis=-1)

    return result