Coverage for qml_essentials/model.py: 93%
478 statements
« prev ^ index » next coverage.py v7.9.2, created at 2026-02-20 14:03 +0000
« prev ^ index » next coverage.py v7.9.2, created at 2026-02-20 14:03 +0000
1from typing import Any, Dict, Optional, Tuple, Callable, Union, List
2import pennylane as qml
3import warnings
4from copy import deepcopy
5import jax
6import jax.numpy as jnp
7import numpy as np
8from jax import random
10from qml_essentials.ansaetze import Ansaetze, Circuit, Encoding
11from qml_essentials.gates import Gates
12from qml_essentials.gates import PulseInformation as pinfo
13from qml_essentials.utils import QuanTikz, safe_random_split
15import logging
17log = logging.getLogger(__name__)
20class Model:
21 """
22 A quantum circuit model.
23 """
25 lightning_threshold = 12
26 cpu_scaler = 0.9 # default cpu scaler, =1 means full CPU for MP
def __init__(
    self,
    n_qubits: int,
    n_layers: int,
    circuit_type: Union[str, Circuit] = "No_Ansatz",
    data_reupload: Union[bool, List[List[bool]], List[List[List[bool]]]] = True,
    state_preparation: Union[
        str, Callable, List[Union[str, Callable]], None
    ] = None,
    encoding: Union[Encoding, str, Callable, List[Union[str, Callable]]] = Gates.RX,
    trainable_frequencies: bool = False,
    initialization: str = "random",
    initialization_domain: List[float] = [0, 2 * jnp.pi],
    output_qubit: Union[List[int], int] = -1,
    shots: Optional[int] = None,
    random_seed: int = 1000,
    remove_zero_encoding: bool = True,
    use_multithreading: bool = False,
    repeat_batch_axis: List[bool] = [True, True, True],
) -> None:
    """
    Initialize the quantum circuit model.

    Parameters will have the shape [impl_n_layers, parameters_per_layer]
    (plus a trailing batch axis), where parameters_per_layer is given by
    the chosen ansatz and impl_n_layers equals n_layers + 1 if the largest
    value in the encoding spectrum exceeds 1, and n_layers otherwise.

    The model is initialized with the following parameters as defaults:
    - noise_params: None
    - execution_type: "expval"
    - shots: None

    Args:
        n_qubits (int): The number of qubits in the circuit.
        n_layers (int): The number of layers in the circuit.
        circuit_type (str, Circuit): The type of quantum circuit to use.
            If None (or an empty string), defaults to "No_Ansatz".
        data_reupload (Union[bool, List[List[bool]], List[List[List[bool]]]],
            optional): Whether to reupload data on each layer and qubit.
            Detailed re-uploading instructions can be given as a
            list/array of 0/False and 1/True with shape
            (n_layers, n_qubits) or (n_layers, n_qubits, n_input_feat).
            True enables encoding everywhere; False encodes only on the
            first qubit of the first layer. Defaults to True.
        state_preparation (Union[str, Callable, List, None], optional):
            Gate(s) handed to `Gates.parse_gates`; the resulting gates are
            applied to every qubit before the first layer.
            Defaults to None.
        encoding (Union[Encoding, str, Callable, List[str], List[Callable]],
            optional): The unitary to use for encoding the input data.
            If an `Encoding` instance is given it is used as-is, otherwise
            the value is wrapped in `Encoding("hamming", encoding)`.
            Defaults to Gates.RX.
        trainable_frequencies (bool, optional):
            Sets trainable encoding parameters for trainable frequencies.
            Defaults to False.
        initialization (str, optional): The strategy to initialize the
            parameters. Can be "random", "zeros", "zero-controlled", "pi",
            or "pi-controlled". Defaults to "random".
        initialization_domain (List[float], optional): [min, max] bounds
            for the uniform sampling of parameters.
            Defaults to [0, 2 * pi].
        output_qubit (List[int], int, optional): The index of the output
            qubit (or qubits). -1 selects all qubits. Defaults to -1.
        shots (Optional[int], optional): The number of shots to use for
            the quantum device. Defaults to None.
        random_seed (int, optional): Seed for the JAX random key used for
            parameter initialization. Defaults to 1000.
        remove_zero_encoding (bool, optional): Whether to
            remove the zero encoding from the circuit. Defaults to True.
        use_multithreading (bool, optional): Whether to use JAX
            multithreading to parallelise over batch dimension. Forced to
            False when the lightning device is selected.
        repeat_batch_axis (List[bool], optional): Mask over the three
            batch axes, used when computing the effective batch shape.
            Defaults to [True, True, True].

    Raises:
        ValueError: If `state_preparation` cannot be parsed.
        AssertionError: If `data_reupload` or `output_qubit` have an
            unsupported shape/size.

    Returns:
        None
    """
    # Initialize default parameters needed for circuit evaluation.
    # NOTE: the order matters, several of these are validating properties
    # that read previously assigned attributes (e.g. output_qubit -> n_qubits).
    self.n_qubits: int = n_qubits
    self.output_qubit: Union[List[int], int] = output_qubit
    self.n_layers: int = n_layers
    self.noise_params: Optional[Dict[str, Union[float, Dict[str, float]]]] = None
    self.shots = shots
    self.remove_zero_encoding = remove_zero_encoding
    self.use_multithreading = use_multithreading
    self.trainable_frequencies: bool = trainable_frequencies
    self.execution_type: str = "expval"
    # copy, so that instances never share (and mutate) the default list
    self.repeat_batch_axis: List[bool] = list(repeat_batch_axis)

    # --- State Preparation ---
    try:
        self._sp = Gates.parse_gates(state_preparation, Gates)
    except ValueError as e:
        raise ValueError(f"Error parsing state preparation: {e}") from e

    # prepare corresponding pulse parameters (always optimized pulses)
    self.sp_pulse_params = []
    for sp in self._sp:
        sp_name = sp.__name__ if hasattr(sp, "__name__") else str(sp)
        pulse_gate = pinfo.gate_by_name(sp_name)
        # None marks a gate without pulse parametrization
        self.sp_pulse_params.append(
            pulse_gate.params if pulse_gate is not None else None
        )

    # --- Encoding ---
    if isinstance(encoding, Encoding):
        # user wants custom strategy? do it!
        self._enc = encoding
    else:
        # use hamming encoding by default
        self._enc = Encoding("hamming", encoding)

    # Number of possible inputs
    self.n_input_feat = len(self._enc)
    log.debug(f"Number of input features: {self.n_input_feat}")

    # Trainable frequencies, default initialization as in arXiv:2309.03279v2
    self.enc_params = jnp.ones((self.n_qubits, self.n_input_feat))

    self._zero_inputs = False

    # --- Data-Reuploading ---
    # Process data reuploading strategy and set degree
    if not isinstance(data_reupload, (bool, np.bool_)):
        if not isinstance(data_reupload, np.ndarray):
            data_reupload = np.array(data_reupload)

        if len(data_reupload.shape) == 2:
            assert data_reupload.shape == (
                n_layers,
                n_qubits,
            ), f"Data reuploading array has wrong shape. \
                Expected {(n_layers, n_qubits)} or\
                {(n_layers, n_qubits, self.n_input_feat)},\
                got {data_reupload.shape}."
            # broadcast the per-(layer, qubit) mask over all input features
            data_reupload = data_reupload.reshape(*data_reupload.shape, 1)
            data_reupload = np.repeat(data_reupload, self.n_input_feat, axis=2)

        assert data_reupload.shape == (
            n_layers,
            n_qubits,
            self.n_input_feat,
        ), f"Data reuploading array has wrong shape. \
            Expected {(n_layers, n_qubits, self.n_input_feat)},\
            got {data_reupload.shape}."

        log.debug(f"Data reuploading array:\n{data_reupload}")
    else:
        if data_reupload:
            data_reupload = np.ones((n_layers, n_qubits, self.n_input_feat))
            log.debug("Full data reuploading.")
        else:
            # encode only once: first layer, first qubit, all features
            data_reupload = np.zeros((n_layers, n_qubits, self.n_input_feat))
            data_reupload[0][0] = 1
            log.debug("No data reuploading.")

    # convert to boolean values
    data_reupload = data_reupload.astype(bool)
    self.data_reupload = jnp.array(data_reupload)

    self.degree: Tuple = tuple(
        self._enc.get_n_freqs(jnp.count_nonzero(self.data_reupload[..., i]))
        for i in range(self.n_input_feat)
    )

    self.frequencies: Tuple = tuple(
        self._enc.get_spectrum(jnp.count_nonzero(self.data_reupload[..., i]))
        for i in range(self.n_input_feat)
    )

    self.has_dru = jnp.max(jnp.array([jnp.max(f) for f in self.frequencies])) > 1

    # check for the highest degree among all input dimensions
    if self.has_dru:
        impl_n_layers: int = n_layers + 1  # we need L+1 according to Schuld et al.
    else:
        impl_n_layers = n_layers
    log.info(f"Number of implicit layers: {impl_n_layers}.")

    # --- Ansatz ---
    # only weak check for str. We trust the user to provide sth useful
    if circuit_type is None or isinstance(circuit_type, str):
        self.pqc: Callable[[Optional[jnp.ndarray], int], int] = getattr(
            Ansaetze, circuit_type or "No_Ansatz"
        )()
    else:
        self.pqc = circuit_type()
    log.info(f"Using Ansatz {circuit_type}.")

    # calculate the shape of the parameter vector here, we will re-use this in init.
    params_per_layer = self.pqc.n_params_per_layer(self.n_qubits)
    self._params_shape: Tuple[int, int] = (impl_n_layers, params_per_layer)
    log.info(f"Parameters per layer: {params_per_layer}")

    pulse_params_per_layer = self.pqc.n_pulse_params_per_layer(self.n_qubits)
    self._pulse_params_shape: Tuple[int, int] = (
        impl_n_layers,
        pulse_params_per_layer,
    )

    # initialize to None as we can't know this yet
    self._batch_shape = None

    # this will also be re-used in the init method,
    # however, only if nothing is provided
    # (attribute name is misspelled on purpose: other methods read it)
    self._inialization_strategy = initialization
    # copy, so that instances never share (and mutate) the default list
    self._initialization_domain = list(initialization_domain)

    # ..here! where we only require a JAX random key
    self.random_key = self.initialize_params(random.key(random_seed))

    # Initializing pulse params
    self.pulse_params: jnp.ndarray = jnp.ones((*self._pulse_params_shape, 1))

    log.info(f"Initialized pulse parameters with shape {self.pulse_params.shape}.")

    # Initialize two circuits, one with the default device and
    # one with the mixed device
    # which allows us to later route depending on the state_vector flag
    if self.n_qubits < self.lightning_threshold:
        device = "default.qubit"
    else:
        device = "lightning.qubit"
        self.use_multithreading = False
    self.circuit: qml.QNode = qml.QNode(
        self._circuit,
        qml.device(
            device,
            shots=self.shots,
            wires=self.n_qubits,
        ),
        interface="jax-jit",
        diff_method="parameter-shift" if self.shots is not None else "best",
    )

    self.circuit_mixed: qml.QNode = qml.QNode(
        self._circuit,
        qml.device("default.mixed", shots=self.shots, wires=self.n_qubits),
        interface="jax-jit",
        diff_method="parameter-shift" if self.shots is not None else "best",
    )
@property
def noise_params(self) -> Optional[Dict[str, Union[float, Dict[str, float]]]]:
    """
    Return the currently configured noise parameters.

    Returns:
        Optional[Dict[str, Union[float, Dict[str, float]]]]: The stored
            noise dictionary, or None if no noise is configured.
    """
    return self._noise_params


@noise_params.setter
def noise_params(
    self, kvs: Optional[Dict[str, Union[float, Dict[str, float]]]]
) -> None:
    """
    Store the noise parameters of the model.

    Typically a "noise parameter" refers to the error probability.
    ThermalRelaxation is a special case, and supports a dict as value with
    structure:
        "ThermalRelaxation":
        {
            "t1": 2000, # relative t1 time.
            "t2": 1000, # relative t2 time
            "t_factor": 1, # relative gate time factor
        },

    The provided dictionary is completed in place: missing supported keys
    are added with their default value.

    Args:
        kvs (Optional[Dict[str, Union[float, Dict[str, float]]]]): A
            dictionary of noise parameters. If all values are 0.0, the noise
            parameters are set to None.

    Returns:
        None
    """
    # a dictionary carrying only zeros is equivalent to "no noise"
    if kvs is not None and all(v == 0.0 for v in kvs.values()):
        kvs = None

    if kvs is None:
        self._noise_params = kvs
        return

    supported = {
        "BitFlip": 0.0,
        "PhaseFlip": 0.0,
        "Depolarizing": 0.0,
        "MultiQubitDepolarizing": 0.0,
        "AmplitudeDamping": 0.0,
        "PhaseDamping": 0.0,
        "GateError": 0.0,
        "ThermalRelaxation": None,
        "StatePreparation": 0.0,
        "Measurement": 0.0,
    }
    # fill in whatever the user did not specify
    for noise_name, fallback in supported.items():
        kvs.setdefault(noise_name, fallback)

    # unknown noise types are kept, but the user gets notified
    for noise_name in kvs:
        if noise_name not in supported:
            warnings.warn(
                f"Noise type {noise_name} is not supported by this package",
                UserWarning,
            )

    # thermal relaxation needs a complete and physically valid parameter set
    thermal = kvs["ThermalRelaxation"]
    if isinstance(thermal, dict):
        expected_keys = ("t1", "t2", "t_factor")
        for tr_key in expected_keys:
            thermal.setdefault(tr_key, 0.0)
        for tr_key in thermal:
            if tr_key not in expected_keys:
                warnings.warn(
                    f"Thermal Relaxation parameter {tr_key} is not supported "
                    f"by this package",
                    UserWarning,
                )
        has_zero_entry = not all(thermal.values())
        if has_zero_entry or thermal["t2"] > 2 * thermal["t1"]:
            warnings.warn(
                "Received invalid values for Thermal Relaxation noise "
                "parameter. Thermal relaxation is not applied!",
                UserWarning,
            )
            kvs["ThermalRelaxation"] = 0.0

    self._noise_params = kvs
@property
def output_qubit(self) -> List[int]:
    """Return the stored output qubit indices used for measurement."""
    return self._output_qubit


@output_qubit.setter
def output_qubit(self, value: Union[int, List[int]]) -> None:
    """
    Store the output qubit(s) used for measurement.

    A list is stored as given (after a size check), a single int is
    wrapped into a one-element list and -1 expands to all qubit indices.
    Values of any other type are stored unchanged.

    Args:
        value: Qubit index or list of indices. Use -1 for all qubits.

    Raises:
        AssertionError: If the list is longer than the number of qubits or
            the index is not smaller than the number of qubits.
    """
    if isinstance(value, list):
        assert (
            len(value) <= self.n_qubits
        ), f"Size of output_qubit {len(value)} cannot be\
            larger than number of qubits {self.n_qubits}."
    elif isinstance(value, int):
        if value != -1:
            assert (
                value < self.n_qubits
            ), f"Output qubit {value} cannot be larger than {self.n_qubits}."
            value = [value]
        else:
            # -1 is the shorthand for "measure every qubit"
            value = list(range(self.n_qubits))

    self._output_qubit = value
@property
def execution_type(self) -> str:
    """
    Gets the execution type of the model.

    Returns:
        str: The execution type, one of 'density', 'expval', 'probs'
            or 'state'.
    """
    return self._execution_type


@execution_type.setter
def execution_type(self, value: str) -> None:
    """
    Sets the execution type and derives the matching result shape.

    Args:
        value (str): One of 'density', 'expval', 'probs' or 'state'.

    Raises:
        ValueError: If `value` is not a supported execution type.

    Returns:
        None
    """
    if value == "density":
        dim = 2 ** len(self.output_qubit)
        self._result_shape = (dim, dim)
    elif value == "expval":
        # one expectation value per output entry, no matter whether an
        # entry is a single qubit or a parity group
        self._result_shape = (len(self.output_qubit),)
    elif value == "probs":
        # in case this is a list of parities,
        # each group has 2^len(qubits) probabilities.
        # Groups may be tuples or lists (the measurement accepts both).
        first = self.output_qubit[0] if len(self.output_qubit) > 0 else None
        n_parity = 2 ** len(first) if isinstance(first, (tuple, list)) else 2
        self._result_shape = (len(self.output_qubit), n_parity)
    elif value == "state":
        self._result_shape = (2 ** len(self.output_qubit),)
    else:
        raise ValueError(f"Invalid execution type: {value}.")

    if value == "state" and not self.all_qubit_measurement:
        warnings.warn(
            f"{value} measurement does ignore output_qubit, which is "
            f"{self.output_qubit}.",
            UserWarning,
        )

    if value == "probs" and self.shots is None:
        warnings.warn(
            "Setting execution_type to probs without specifying shots.",
            UserWarning,
        )

    if value == "density" and self.shots is not None:
        warnings.warn(
            "Setting execution_type to density with specified shots.",
            UserWarning,
        )

    self._execution_type = value
@property
def shots(self) -> Optional[int]:
    """
    Return the number of shots used for the quantum device.

    Returns:
        Optional[int]: The stored number of shots, None if unset.
    """
    return self._shots


@shots.setter
def shots(self, value: Optional[int]) -> None:
    """
    Store the number of shots used for the quantum device.

    Args:
        value (Optional[int]): The number of shots.
            A plain int less than or equal to 0 is replaced by None.

    Returns:
        None
    """
    # exact type match on purpose: only plain ints are normalised
    is_non_positive_int = type(value) is int and value <= 0
    self._shots = None if is_non_positive_int else value
@property
def params(self) -> jnp.ndarray:
    """Return the stored variational parameters."""
    return self._params


@params.setter
def params(self, value: jnp.ndarray) -> None:
    """
    Store the variational parameters.

    A two-dimensional array gets a trailing axis of size one appended;
    arrays of any other rank are stored unchanged.
    """
    is_unbatched = len(value.shape) == 2
    if is_unbatched:
        n_rows, n_cols = value.shape
        value = value.reshape(n_rows, n_cols, 1)

    self._params = value
@property
def enc_params(self) -> jnp.ndarray:
    """Return the stored encoding parameters (input scaling factors)."""
    return self._enc_params


@enc_params.setter
def enc_params(self, value: jnp.ndarray) -> None:
    """Store the encoding parameters without any validation."""
    self._enc_params = value
@property
def pulse_params(self) -> jnp.ndarray:
    """Return the stored pulse parameters used in pulse gate mode."""
    return self._pulse_params


@pulse_params.setter
def pulse_params(self, value: jnp.ndarray) -> None:
    """Store the pulse parameters without any validation."""
    self._pulse_params = value
@property
def all_qubit_measurement(self) -> bool:
    """Tell whether output_qubit equals the ordered list of all qubit indices."""
    measured = self.output_qubit
    every_qubit = list(range(self.n_qubits))
    return measured == every_qubit
@property
def batch_shape(self) -> Tuple[int, ...]:
    """
    Return the batch shape (B_I, B_P, B_R).

    Returns:
        Tuple[int, ...]: Tuple of (input_batch, param_batch, pulse_batch).
            Falls back to (1, 1, 1) as long as no batch shape was stored,
            i.e. the model was not called before.
    """
    stored = self._batch_shape
    if stored is not None:
        return stored

    log.debug("Model was not called yet. Returning (1,1,1) as batch shape.")
    return (1, 1, 1)
@property
def eff_batch_shape(self) -> Tuple[int, ...]:
    """
    Return the effective batch shape after applying the repeat_batch_axis mask.

    The batch shape is multiplied element-wise with repeat_batch_axis and
    all resulting zero entries are dropped.

    Returns:
        Tuple[int, ...]: Effective batch dimensions, excluding zeros.
            Note that the value is a NumPy array, not a Python tuple.
    """
    masked = np.array(self.batch_shape) * self.repeat_batch_axis
    keep = masked != 0
    return masked[keep]
def initialize_params(
    self,
    random_key: Optional[random.PRNGKey] = None,
    repeat: int = 1,
    initialization: Optional[str] = None,
    initialization_domain: Optional[List[float]] = None,
) -> random.PRNGKey:
    """
    Initialize the variational parameters of the model.

    The result is stored in `self.params` with shape
    (*self._params_shape, repeat).

    Args:
        random_key (Optional[random.PRNGKey]): JAX random key for initialization.
            If None, uses the model's internal random key.
        repeat (int): Number of parameter sets to create (batch dimension).
            Defaults to 1.
        initialization (Optional[str]): Strategy for parameter initialization.
            Options: "random", "zeros", "pi", "zero-controlled", "pi-controlled".
            If None, uses the strategy specified in the constructor.
        initialization_domain (Optional[List[float]]): Domain [min, max] for
            random initialization. If None, uses the domain from constructor.

    Returns:
        random.PRNGKey: Updated random key after initialization.

    Raises:
        ValueError: If an invalid initialization method is specified.
    """
    # Initializing params
    params_shape = (*self._params_shape, repeat)

    # use existing strategy if not specified
    initialization = initialization or self._inialization_strategy
    initialization_domain = initialization_domain or self._initialization_domain

    random_key, sub_key = safe_random_split(
        random_key if random_key is not None else self.random_key
    )

    def sample_uniform() -> jnp.ndarray:
        # single source for all strategies that start from random values
        return random.uniform(
            sub_key,
            params_shape,
            minval=initialization_domain[0],
            maxval=initialization_domain[1],
        )

    def set_control_params(params: jnp.ndarray, value: float) -> jnp.ndarray:
        # indices is expected to be a (start, stop, step) triple selecting the
        # parameters of the controlled rotation gates within a layer
        indices = self.pqc.get_control_indices(self.n_qubits)
        if indices is None:
            warnings.warn(
                f"Specified {initialization} but circuit\
                    does not contain controlled rotation gates.\
                    Parameters are intialized randomly.",
                UserWarning,
            )
        else:
            # JAX arrays are immutable: edit a NumPy copy and convert back
            np_params = np.array(params)
            np_params[:, indices[0] : indices[1] : indices[2]] = value
            params = jnp.array(np_params)
        return params

    if initialization == "random":
        self.params: jnp.ndarray = sample_uniform()
    elif initialization == "zeros":
        self.params: jnp.ndarray = jnp.zeros(params_shape)
    elif initialization == "pi":
        self.params: jnp.ndarray = jnp.ones(params_shape) * jnp.pi
    elif initialization == "zero-controlled":
        self.params: jnp.ndarray = set_control_params(sample_uniform(), 0)
    elif initialization == "pi-controlled":
        self.params: jnp.ndarray = set_control_params(sample_uniform(), jnp.pi)
    else:
        raise ValueError("Invalid initialization method")

    log.info(
        f"Initialized parameters with shape {self.params.shape}\
            using strategy {initialization}."
    )

    return random_key
def transform_input(
    self, inputs: jnp.ndarray, enc_params: jnp.ndarray
) -> jnp.ndarray:
    """
    Scale the input data with the encoding parameters.

    Implements the input transformation as described in arXiv:2309.03279v2:
    the inputs are multiplied with the encoding parameters before they are
    used in the quantum circuit.

    Args:
        inputs (jnp.ndarray): Input data to be scaled.
        enc_params (jnp.ndarray): Encoding weight scalar or vector used to
            scale the input.

    Returns:
        jnp.ndarray: The product `inputs * enc_params`.
    """
    scaled = inputs * enc_params
    return scaled
def _iec(
    self,
    inputs: jnp.ndarray,
    data_reupload: jnp.ndarray,
    enc: Encoding,
    enc_params: jnp.ndarray,
    noise_params: Optional[Dict[str, Union[float, Dict[str, float]]]] = None,
    random_key: Optional[random.PRNGKey] = None,
) -> None:
    """
    Apply the Input Encoding Circuit (IEC) for one layer.

    For every qubit and every input feature flagged in `data_reupload`,
    the encoding gate of that feature is applied to the qubit, using the
    feature value scaled by the matching encoding parameter.

    Args:
        inputs (jnp.ndarray): Input data; the last axis is the feature axis.
        data_reupload (jnp.ndarray): Boolean array of shape (n_qubits, n_input_feat)
            indicating where to apply encoding gates.
        enc (Encoding): Encoding strategy, indexable by feature to obtain the
            encoding gate function.
        enc_params (jnp.ndarray): Encoding parameters of shape
            (n_qubits, n_input_feat) used to scale inputs.
        noise_params (Optional[Dict[str, Union[float, Dict[str, float]]]]):
            Noise parameters forwarded to the encoding gates. Defaults to None.
        random_key (Optional[random.PRNGKey]): JAX random key, split once per
            applied gate. Defaults to None.

    Returns:
        None: Gates are applied in-place to the quantum circuit.
    """
    # check for zero, because due to input validation, input cannot be none
    skip_encoding = (
        self.remove_zero_encoding
        and self._zero_inputs
        and self.batch_shape[0] == 1
    )
    if skip_encoding:
        return

    # the last dimension of the inputs is the feature dimension
    n_features = inputs.shape[-1]
    for qubit in range(self.n_qubits):
        for feature in range(n_features):
            if not data_reupload[qubit, feature]:
                continue
            # the key is only advanced for gates that are actually applied
            random_key, gate_key = safe_random_split(random_key)
            # ellipsis indexes only the last dimension,
            # as inputs are generally *not* qubit dependent
            angle = self.transform_input(
                inputs[..., feature], enc_params[qubit, feature]
            )
            enc[feature](
                angle,
                wires=qubit,
                noise_params=noise_params,
                random_key=gate_key,
            )
705 def _circuit(
706 self,
707 params: jnp.ndarray,
708 inputs: jnp.ndarray,
709 pulse_params: Optional[jnp.ndarray] = None,
710 enc_params: Optional[jnp.ndarray] = None,
711 gate_mode: str = "unitary",
712 noise_params: Optional[Dict[str, Union[float, Dict[str, float]]]] = None,
713 random_key: Optional[random.PRNGKey] = None,
714 ) -> Union[float, jnp.ndarray]:
715 """
716 Build and execute the quantum circuit.
718 Constructs the full quantum circuit including variational layers and
719 encoding, then returns the measurement result based on the configured
720 execution type.
722 Args:
723 params (jnp.ndarray): Variational parameters of shape
724 (n_layers, n_params_per_layer).
725 inputs (jnp.ndarray): Input data of shape (n_input_feat,).
726 pulse_params (Optional[jnp.ndarray]): Pulse parameter scalers of shape
727 (n_layers, n_pulse_params_per_layer) for pulse-mode execution.
728 Defaults to None.
729 enc_params (Optional[jnp.ndarray]): Encoding parameters of shape
730 (n_qubits, n_input_feat). Defaults to None (uses model's enc_params).
731 gate_mode (str): Gate execution mode, either "unitary" or "pulse".
732 Defaults to "unitary".
733 noise_params (Optional[Dict[str, Union[float, Dict[str, float]]]]):
734 Noise parameters for simulation. Defaults to None.
735 random_key (Optional[random.PRNGKey]): JAX random key for stochastic
736 operations. Defaults to None.
738 Returns:
739 Union[float, jnp.ndarray]: Circuit output depending on execution_type:
740 - "expval": Expectation value(s) of the observable(s)
741 - "density": Density matrix of output qubits
742 - "probs": Measurement probabilities
743 - "state": Full quantum state vector
744 """
745 self._variational(
746 params=params,
747 inputs=inputs,
748 pulse_params=pulse_params,
749 enc_params=enc_params,
750 gate_mode=gate_mode,
751 noise_params=noise_params,
752 random_key=random_key,
753 )
754 return self._observable()
def _variational(
    self,
    params: jnp.ndarray,
    inputs: jnp.ndarray,
    pulse_params: Optional[jnp.ndarray] = None,
    enc_params: Optional[jnp.ndarray] = None,
    gate_mode: str = "unitary",
    noise_params: Optional[Dict[str, Union[float, Dict[str, float]]]] = None,
    random_key: Optional[random.PRNGKey] = None,
) -> None:
    """
    Build the variational quantum circuit structure.

    Constructs the circuit by applying optional state preparation noise,
    the state preparation gates, alternating variational ansatz layers with
    input encoding layers, an optional final ansatz layer and optional
    noise channels.

    Args:
        params (jnp.ndarray): Variational parameters of shape
            (n_layers, n_params_per_layer). A trailing axis of size one
            is stripped.
        inputs (jnp.ndarray): Input data of shape (n_input_feat,). A leading
            axis of size one is stripped.
        pulse_params (Optional[jnp.ndarray]): Pulse parameter scalers of shape
            (n_layers, n_pulse_params_per_layer) for pulse-mode execution.
            Defaults to None (uses model's pulse_params).
        enc_params (Optional[jnp.ndarray]): Encoding parameters of shape
            (n_qubits, n_input_feat). Defaults to None (uses model's enc_params).
        gate_mode (str): Gate execution mode, either "unitary" or "pulse".
            Defaults to "unitary".
        noise_params (Optional[Dict[str, Union[float, Dict[str, float]]]]):
            Noise parameters for simulation. Defaults to None (uses model's
            noise_params).
        random_key (Optional[random.PRNGKey]): JAX random key for stochastic
            operations. Defaults to None (uses model's random_key if noise
            is active).

    Returns:
        None: Gates are applied in-place to the quantum circuit.

    Note:
        Issues RuntimeWarning if called directly without providing parameters
        that would normally be passed through the forward method.
    """
    # TODO: rework and double check params shape
    if len(params.shape) > 2 and params.shape[2] == 1:
        params = params[:, :, 0]

    if len(inputs.shape) > 1 and inputs.shape[0] == 1:
        inputs = inputs[0]

    if enc_params is None:
        # TODO: Raise warning if trainable frequencies is True, or similar. I.e., no
        # warning if user does not care for frequencies or enc_params
        if self.trainable_frequencies:
            warnings.warn(
                "Explicit call to `_circuit` or `_variational` detected: "
                "`enc_params` is None, using `self.enc_params` instead.",
                RuntimeWarning,
            )
        enc_params = self.enc_params

    if pulse_params is None:
        if gate_mode == "pulse":
            warnings.warn(
                "Explicit call to `_circuit` or `_variational` detected: "
                "`pulse_params` is None, using `self.pulse_params` instead.",
                RuntimeWarning,
            )
        pulse_params = self.pulse_params

    if noise_params is None:
        if self.noise_params is not None:
            warnings.warn(
                "Explicit call to `_circuit` or `_variational` detected: "
                "`noise_params` is None, using `self.noise_params` instead.",
                RuntimeWarning,
            )
        noise_params = self.noise_params

    if noise_params is not None:
        if random_key is None:
            # the message names the fallback that is actually used below
            warnings.warn(
                "Explicit call to `_circuit` or `_variational` detected: "
                "`random_key` is None, using `self.random_key` instead.",
                RuntimeWarning,
            )
            random_key = self.random_key
        self._apply_state_prep_noise(noise_params=noise_params)

    # state preparation
    for q in range(self.n_qubits):
        for _sp, sp_pulse_params in zip(self._sp, self.sp_pulse_params):
            random_key, sub_key = safe_random_split(random_key)
            _sp(
                wires=q,
                pulse_params=sp_pulse_params,
                noise_params=noise_params,
                random_key=sub_key,
                gate_mode=gate_mode,
            )

    # NOTE(review): from here on the keys are split from (and written back
    # to) `self.random_key`, while the state preparation above uses the
    # local `random_key` argument - confirm that this is intended.
    # circuit building
    for layer in range(0, self.n_layers):
        self.random_key, sub_key = safe_random_split(self.random_key)
        # ansatz layers
        self.pqc(
            params[layer],
            self.n_qubits,
            pulse_params=pulse_params[layer],
            noise_params=noise_params,
            random_key=sub_key,
            gate_mode=gate_mode,
        )

        self.random_key, sub_key = safe_random_split(self.random_key)
        # encoding layers
        self._iec(
            inputs,
            data_reupload=self.data_reupload[layer],
            enc=self._enc,
            enc_params=enc_params,
            noise_params=noise_params,
            random_key=sub_key,
        )

        # visual barrier
        if self.has_dru:
            qml.Barrier(wires=list(range(self.n_qubits)), only_visual=True)

    # final ansatz layer
    if self.has_dru:  # same check as in init
        self.random_key, sub_key = safe_random_split(self.random_key)
        self.pqc(
            params[self.n_layers],
            self.n_qubits,
            pulse_params=pulse_params[-1],
            noise_params=noise_params,
            random_key=sub_key,
            gate_mode=gate_mode,
        )

    # channel noise
    if noise_params is not None:
        self._apply_general_noise(noise_params=noise_params)
898 def _observable(self) -> Union[jnp.ndarray, List[jnp.ndarray]]:
899 """
900 Define and return the measurement observable(s) for the circuit.
902 Constructs the appropriate PennyLane measurement based on the
903 configured execution_type and output_qubit settings.
905 Returns:
906 Union[jnp.ndarray, List[jnp.ndarray]]: Measurement result(s):
907 - "density": qml.density_matrix for output qubits
908 - "state": Full quantum state via qml.state()
909 - "expval": Expectation value(s) of PauliZ observable(s)
910 - "probs": Measurement probabilities
912 Raises:
913 ValueError: If execution_type or output_qubit configuration is invalid.
914 """
915 # run mixed simulation and get density matrix
916 if self.execution_type == "density":
917 return qml.density_matrix(wires=self.output_qubit)
918 elif self.execution_type == "state":
919 return qml.state()
920 # run default simulation and get expectation value
921 elif self.execution_type == "expval":
922 # n-local measurement
923 if self.all_qubit_measurement:
924 return [qml.expval(qml.PauliZ(q)) for q in self.output_qubit]
925 # parity or local measurement(s)
926 elif isinstance(self.output_qubit, list):
927 ret = []
928 # list of parity pairs
929 for pair in self.output_qubit:
930 if isinstance(pair, int):
931 ret.append(qml.expval(qml.PauliZ(pair)))
932 else:
933 obs = qml.PauliZ(pair[0])
934 for q in pair[1:]:
935 obs = obs @ qml.PauliZ(q)
936 ret.append(qml.expval(obs))
937 return ret
938 else:
939 raise ValueError(
940 f"Invalid parameter `output_qubit`: {self.output_qubit}.\
941 Must be int, list or -1."
942 )
943 # run default simulation and get probs
944 elif self.execution_type == "probs":
945 # n-local measurement
946 if self.all_qubit_measurement:
947 return qml.probs(wires=self.output_qubit)
948 # parity or local measurement(s)
949 elif isinstance(self.output_qubit, list):
950 ret = []
951 # list of parity pairs
952 for pair in self.output_qubit:
953 if isinstance(pair, int):
954 ret.append(qml.probs(wires=[pair]))
955 else:
956 ret.append(qml.probs(wires=pair))
957 return ret
958 else:
959 raise ValueError(
960 f"Invalid parameter `output_qubit`: {self.output_qubit}.\
961 Must be int, list or -1."
962 )
963 else:
964 raise ValueError(f"Invalid execution_type: {self.execution_type}.")
966 def _apply_state_prep_noise(
967 self, noise_params: Dict[str, Union[float, Dict[str, float]]]
968 ) -> None:
969 """
970 Apply state preparation noise to all qubits.
972 Simulates imperfect state preparation by applying BitFlip errors
973 to each qubit with the specified probability.
975 Args:
976 noise_params (Dict[str, Union[float, Dict[str, float]]]): Dictionary
977 containing noise parameters. Uses the "StatePreparation" key
978 for the BitFlip probability.
980 Returns:
981 None: Noise channels are applied in-place to the circuit.
982 """
983 p = noise_params.get("StatePreparation", 0.0)
984 if p > 0:
985 for q in range(self.n_qubits):
986 qml.BitFlip(p, wires=q)
988 def _apply_general_noise(
989 self, noise_params: Dict[str, Union[float, Dict[str, float]]]
990 ) -> None:
991 """
992 Apply general noise channels to all qubits.
994 Applies various decoherence and error channels after the circuit
995 execution, simulating environmental noise effects.
997 Args:
998 noise_params (Dict[str, Union[float, Dict[str, float]]]): Dictionary
999 containing noise parameters with the following supported keys:
1000 - "AmplitudeDamping" (float): Probability for amplitude damping.
1001 - "PhaseDamping" (float): Probability for phase damping.
1002 - "Measurement" (float): Probability for measurement error (BitFlip).
1003 - "ThermalRelaxation" (Dict): Dictionary with keys "t1", "t2",
1004 "t_factor" for thermal relaxation simulation.
1006 Returns:
1007 None: Noise channels are applied in-place to the circuit.
1009 Note:
1010 Gate-level noise (e.g., GateError) is handled separately in the
1011 Gates.Noise module and applied at the individual gate level.
1012 """
1013 amp_damp = noise_params.get("AmplitudeDamping", 0.0)
1014 phase_damp = noise_params.get("PhaseDamping", 0.0)
1015 thermal_relax = noise_params.get("ThermalRelaxation", 0.0)
1016 meas = noise_params.get("Measurement", 0.0)
1017 for q in range(self.n_qubits):
1018 if amp_damp > 0:
1019 qml.AmplitudeDamping(amp_damp, wires=q)
1020 if phase_damp > 0:
1021 qml.PhaseDamping(phase_damp, wires=q)
1022 if meas > 0:
1023 qml.BitFlip(meas, wires=q)
1024 if isinstance(thermal_relax, dict):
1025 t1 = thermal_relax["t1"]
1026 t2 = thermal_relax["t2"]
1027 t_factor = thermal_relax["t_factor"]
1028 circuit_depth = self._get_circuit_depth()
1029 tg = circuit_depth * t_factor
1030 qml.ThermalRelaxationError(1.0, t1, t2, tg, q)
def _get_circuit_depth(self, inputs: Optional[jnp.ndarray] = None) -> int:
    """
    Calculate the depth of the quantum circuit.

    The depth is read from ``qml.specs`` evaluated on the circuit of a
    deep copy of this model whose ``noise_params`` are reset to ``None``.

    Args:
        inputs (Optional[jnp.ndarray]): Input data for circuit evaluation.
            Passed through ``_inputs_validation`` first.

    Returns:
        int: The ``depth`` attribute of the "resources" entry returned by
        ``qml.specs``.
    """
    validated_inputs = self._inputs_validation(inputs)

    # Evaluate the specs on a clone so that disabling the noise does not
    # touch this instance.
    noiseless_model = deepcopy(self)
    noiseless_model.noise_params = None

    resource_report = qml.specs(noiseless_model.circuit)(
        self.params, validated_inputs
    )
    return resource_report["resources"].depth
def draw(
    self,
    inputs: Optional[jnp.ndarray] = None,
    figure: str = "text",
    *args: Any,
    **kwargs: Any,
) -> Union[str, Any]:
    """
    Visualize the quantum circuit.

    Args:
        inputs (Optional[jnp.ndarray]): Input data for the circuit. Passed
            through ``_inputs_validation``. Defaults to None.
        figure (str): Visualization format. Options:
            - "text": rendered with ``qml.draw``
            - "mpl": rendered with ``qml.draw_mpl``
            - "tikz": rendered with ``QuanTikz.build``
            Defaults to "text".
        *args (Any): Additional positional arguments, forwarded for the
            "mpl" and "tikz" formats only.
        **kwargs (Any): Additional keyword arguments, forwarded for the
            "mpl" and "tikz" formats only.

    Returns:
        Union[str, Any]: The output of the selected backend, or an empty
        string when ``self.circuit`` is not a ``qml.QNode``.

    Raises:
        AssertionError: If figure is not one of "text", "mpl", or "tikz".
    """
    # Only QNodes can be drawn. This guard runs before the `figure` check.
    if not isinstance(self.circuit, qml.QNode):
        # TODO: throws strange argument error if not catched
        return ""

    supported_figures = ["text", "mpl", "tikz"]
    assert (
        figure in supported_figures
    ), f"Invalid figure: {figure}. Must be 'text', 'mpl' or 'tikz'."

    inputs = self._inputs_validation(inputs)

    if figure == "text":
        text_drawer = qml.draw(self.circuit)
        return text_drawer(params=self.params, inputs=inputs)

    if figure == "mpl":
        mpl_drawer = qml.draw_mpl(self.circuit)
        return mpl_drawer(*args, params=self.params, inputs=inputs, **kwargs)

    # Remaining option: "tikz"
    return QuanTikz.build(
        self.circuit,
        *args,
        params=self.params,
        inputs=inputs,
        **kwargs,
    )
1120 def __repr__(self) -> str:
1121 """Return text representation of the quantum circuit."""
1122 return self.draw(figure="text")
1124 def __str__(self) -> str:
1125 """Return string representation of the quantum circuit."""
1126 return self.draw(figure="text")
1128 def _params_validation(self, params: Optional[jnp.ndarray]) -> jnp.ndarray:
1129 """
1130 Validate and normalize variational parameters.
1132 Ensures parameters have the correct shape with a batch dimension,
1133 and updates the model's internal parameters if new ones are provided.
1135 Args:
1136 params (Optional[jnp.ndarray]): Variational parameters to validate.
1137 If None, returns the model's current parameters.
1139 Returns:
1140 jnp.ndarray: Validated parameters with shape
1141 (n_layers, n_params_per_layer, batch_size).
1142 """
1143 # append batch axis if not provided
1144 if params is not None:
1145 if len(params.shape) == 2:
1146 params = np.expand_dims(params, axis=-1)
1148 self.params = params
1149 else:
1150 params = self.params
1152 return params
1154 def _pulse_params_validation(
1155 self, pulse_params: Optional[jnp.ndarray]
1156 ) -> jnp.ndarray:
1157 """
1158 Validate and normalize pulse parameters.
1160 Ensures pulse parameters are set, using model defaults if not provided.
1162 Args:
1163 pulse_params (Optional[jnp.ndarray]): Pulse parameter scalers.
1164 If None, returns the model's current pulse parameters.
1166 Returns:
1167 jnp.ndarray: Validated pulse parameters with shape
1168 (n_layers, n_pulse_params_per_layer, batch_size).
1169 """
1170 if pulse_params is None:
1171 pulse_params = self.pulse_params
1172 else:
1173 self.pulse_params = pulse_params
1175 return pulse_params
1177 def _enc_params_validation(self, enc_params: Optional[jnp.ndarray]) -> jnp.ndarray:
1178 """
1179 Validate and normalize encoding parameters.
1181 Ensures encoding parameters have the correct shape for the model's
1182 input feature dimensions.
1184 Args:
1185 enc_params (Optional[jnp.ndarray]): Encoding parameters to validate.
1186 If None, returns the model's current encoding parameters.
1188 Returns:
1189 jnp.ndarray: Validated encoding parameters with shape
1190 (n_qubits, n_input_feat).
1192 Raises:
1193 ValueError: If enc_params shape is incompatible with n_input_feat > 1.
1194 """
1195 if enc_params is None:
1196 enc_params = self.enc_params
1197 else:
1198 if self.trainable_frequencies:
1199 self.enc_params = enc_params
1200 else:
1201 self.enc_params = jnp.array(enc_params)
1203 if len(enc_params.shape) == 1 and self.n_input_feat == 1:
1204 enc_params = enc_params.reshape(-1, 1)
1205 elif len(enc_params.shape) == 1 and self.n_input_feat > 1:
1206 raise ValueError(
1207 f"Input dimension {self.n_input_feat} >1 but \
1208 `enc_params` has shape {enc_params.shape}"
1209 )
1211 return enc_params
1213 def _inputs_validation(
1214 self, inputs: Union[None, List, float, int, jnp.ndarray]
1215 ) -> jnp.ndarray:
1216 """
1217 Validate and normalize input data.
1219 Converts various input formats to a standardized 2D array shape
1220 suitable for batch processing in the quantum circuit.
1222 Args:
1223 inputs (Union[None, List, float, int, jnp.ndarray]): Input data in
1224 various formats:
1225 - None: Returns zeros with shape (1, n_input_feat)
1226 - float/int: Single scalar value
1227 - List: List of values or batched inputs
1228 - jnp.ndarray: NumPy/JAX array
1230 Returns:
1231 jnp.ndarray: Validated inputs with shape (batch_size, n_input_feat).
1233 Raises:
1234 ValueError: If input shape is incompatible with expected n_input_feat.
1236 Warns:
1237 UserWarning: If input is replicated to match n_input_feat.
1238 """
1239 self._zero_inputs = False
1240 if isinstance(inputs, List):
1241 inputs = jnp.array(np.stack(inputs))
1242 elif isinstance(inputs, float) or isinstance(inputs, int):
1243 inputs = jnp.array([inputs])
1244 elif inputs is None:
1245 inputs = jnp.array([[0] * self.n_input_feat])
1247 if not inputs.any():
1248 self._zero_inputs = True
1250 if len(inputs.shape) <= 1:
1251 if self.n_input_feat == 1:
1252 # add a batch dimension
1253 inputs = inputs.reshape(-1, 1)
1254 else:
1255 if inputs.shape[0] == self.n_input_feat:
1256 inputs = inputs.reshape(1, -1)
1257 else:
1258 inputs = inputs.reshape(-1, 1)
1259 inputs = inputs.repeat(self.n_input_feat, axis=1)
1260 warnings.warn(
1261 f"Expected {self.n_input_feat} inputs, but {inputs.shape[0]} "
1262 "was provided, replicating input for all input features.",
1263 UserWarning,
1264 )
1265 else:
1266 if inputs.shape[1] != self.n_input_feat:
1267 raise ValueError(
1268 f"Wrong number of inputs provided. Expected {self.n_input_feat} "
1269 f"inputs, but input has shape {inputs.shape}."
1270 )
1272 return inputs
1274 def _mp_executor(
1275 self,
1276 f: Callable,
1277 params: jnp.ndarray,
1278 pulse_params: jnp.ndarray,
1279 inputs: jnp.ndarray,
1280 enc_params: jnp.ndarray,
1281 noise_params: Optional[Dict[str, Union[float, Dict[str, float]]]],
1282 random_key: random.PRNGKey,
1283 gate_mode: str,
1284 ) -> jnp.ndarray:
1285 """
1286 Execute circuit function with optional parallelization over batches.
1288 Uses JAX's vmap for vectorized execution when batching over inputs,
1289 parameters, or pulse parameters. Falls back to sequential execution
1290 for single samples or when multithreading is disabled.
1292 Args:
1293 f (Callable): Circuit function to execute (circuit or circuit_mixed).
1294 params (jnp.ndarray): Variational parameters of shape
1295 (n_layers, n_params_per_layer, batch_size).
1296 pulse_params (jnp.ndarray): Pulse parameters of shape
1297 (n_layers, n_pulse_params_per_layer, batch_size).
1298 inputs (jnp.ndarray): Input data of shape (batch_size, n_input_feat).
1299 enc_params (jnp.ndarray): Encoding parameters of shape
1300 (n_qubits, n_input_feat).
1301 noise_params (Optional[Dict[str, Union[float, Dict[str, float]]]]):
1302 Noise configuration dictionary.
1303 random_key (random.PRNGKey): JAX random key for stochastic operations.
1304 gate_mode (str): Gate execution mode ("unitary" or "pulse").
1306 Returns:
1307 jnp.ndarray: Circuit execution results, post-processed for uniformity.
1308 """
1310 def _f(
1311 _params: jnp.ndarray,
1312 _inputs: jnp.ndarray,
1313 _pulse_params: jnp.ndarray,
1314 _random_key: random.PRNGKey,
1315 ) -> jnp.ndarray:
1316 return f(
1317 params=_params,
1318 inputs=_inputs,
1319 pulse_params=_pulse_params,
1320 random_key=_random_key,
1321 noise_params=noise_params,
1322 enc_params=enc_params,
1323 gate_mode=gate_mode,
1324 )
1326 B = np.prod(self.eff_batch_shape)
1327 if (gate_mode == "pulse" or self.use_multithreading) and B > 1:
1328 random_keys = safe_random_split(random_key, num=B)
1330 # wrapper to allow kwargs (not supported by jax)
1331 result = jax.vmap(
1332 _f,
1333 in_axes=(
1334 2 if self.batch_shape[1] > 1 else None, # params
1335 0 if self.batch_shape[0] > 1 else None, # inputs
1336 2 if self.batch_shape[2] > 1 else None, # pulse_params
1337 0, # random_keys
1338 ),
1339 )(
1340 params,
1341 inputs,
1342 pulse_params,
1343 random_keys,
1344 )
1345 else:
1346 result = _f(
1347 _params=params,
1348 _pulse_params=pulse_params,
1349 _inputs=inputs,
1350 _random_key=random_key,
1351 )
1353 return self._postprocess_res(result)
1355 def _postprocess_res(self, result: Union[List, jnp.ndarray]) -> jnp.ndarray:
1356 """
1357 Post-process circuit execution results for uniform shape.
1359 Converts list outputs (from multiple measurements) to stacked arrays
1360 and reorders axes for consistent batch dimension placement.
1362 Args:
1363 result (Union[List, jnp.ndarray]): Raw circuit output, either a
1364 list of measurement results or a single array.
1366 Returns:
1367 jnp.ndarray: Uniformly shaped result array with batch dimension first.
1368 """
1369 if isinstance(result, list):
1370 # we use moveaxis here because in case of parity measure,
1371 # there is another dimension appended to the end and
1372 # simply transposing would result in a wrong shape
1373 result = jnp.stack(result)
1374 if len(result.shape) > 1:
1375 result = jnp.moveaxis(result, 0, 1)
1376 return result
1378 def _assimilate_batch(
1379 self,
1380 inputs: jnp.ndarray,
1381 params: jnp.ndarray,
1382 pulse_params: jnp.ndarray,
1383 ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
1384 """
1385 Align batch dimensions across inputs, parameters, and pulse parameters.
1387 Broadcasts and reshapes arrays to have compatible batch dimensions
1388 for vectorized circuit execution. Sets the internal batch_shape.
1390 Args:
1391 inputs (jnp.ndarray): Input data of shape (B_I, n_input_feat).
1392 params (jnp.ndarray): Parameters of shape (n_layers, n_params, B_P).
1393 pulse_params (jnp.ndarray): Pulse params of shape (n_layers, n_pulse, B_R).
1395 Returns:
1396 Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: Tuple containing:
1397 - inputs: Reshaped to (B, n_input_feat) where B = B_I * B_P * B_R
1398 - params: Reshaped to (n_layers, n_params, B)
1399 - pulse_params: Reshaped to (n_layers, n_pulse, B)
1401 Note:
1402 The effective batch shape depends on repeat_batch_axis configuration.
1403 This is the only method that sets self._batch_shape.
1404 """
1405 B_I = inputs.shape[0]
1406 # we check for the product because there is a chance that
1407 # there are no params. In this case we want B_P to be 1
1408 B_P = 1 if 0 in params.shape else params.shape[-1]
1409 B_R = pulse_params.shape[-1]
1411 # THIS is the only place where we set the batch shape
1412 self._batch_shape = (B_I, B_P, B_R)
1413 B = np.prod(self.eff_batch_shape)
1415 # [B_I, ...] -> [B_I, B_P, B_R, ...] -> [B, ...]
1416 if B_I > 1 and self.repeat_batch_axis[0]:
1417 if self.repeat_batch_axis[1]:
1418 inputs = jnp.repeat(inputs[:, None, None, ...], B_P, axis=1)
1419 if self.repeat_batch_axis[2]:
1420 inputs = jnp.repeat(inputs, B_R, axis=2)
1421 inputs = inputs.reshape(B, *inputs.shape[3:])
1423 # [..., ..., B_P] -> [..., ..., B_I, B_P, B_R] -> [..., ..., B]
1424 if B_P > 1 and self.repeat_batch_axis[1]:
1425 # add B_I axis before last, and B_R axis after last
1426 params = params[..., None, :, None] # [..., B_I(=1), B_P, B_R(=1)]
1427 if self.repeat_batch_axis[0]:
1428 params = jnp.repeat(params, B_I, axis=-3) # [..., B_I, B_P, 1]
1429 if self.repeat_batch_axis[2]:
1430 params = jnp.repeat(params, B_R, axis=-1) # [..., B_I, B_P, B_R]
1431 params = params.reshape(*params.shape[:-3], B)
1433 # [..., ..., B_R] -> [..., ..., B_I, B_P, B_R] -> [..., ..., B]
1434 if B_R > 1 and self.repeat_batch_axis[2]:
1435 # add B_I axis before last, and B_P axis before last (after adding B_I)
1436 pulse_params = pulse_params[
1437 ..., None, None, :
1438 ] # [..., B_I(=1), B_P(=1), B_R]
1439 if self.repeat_batch_axis[0]:
1440 pulse_params = jnp.repeat(
1441 pulse_params, B_I, axis=-3
1442 ) # [..., B_I, 1, B_R]
1443 if self.repeat_batch_axis[1]:
1444 pulse_params = jnp.repeat(
1445 pulse_params, B_P, axis=-2
1446 ) # [..., B_I, B_P, B_R]
1447 pulse_params = pulse_params.reshape(*pulse_params.shape[:-3], B)
1449 return inputs, params, pulse_params
1451 def _requires_density(self) -> bool:
1452 """
1453 Check if density matrix simulation is required.
1455 Determines whether the circuit must be executed with the mixed-state
1456 simulator based on execution type and noise configuration.
1458 Returns:
1459 bool: True if density matrix simulation is required, False otherwise.
1460 Returns True if:
1461 - execution_type is "density", or
1462 - Any non-coherent noise channel has non-zero probability
1463 """
1464 if self.execution_type == "density":
1465 return True
1467 if self.noise_params is None:
1468 return False
1470 coherent_noise = {"GateError"}
1471 for k, v in self.noise_params.items():
1472 if k in coherent_noise:
1473 continue
1474 if v is not None and v > 0:
1475 return True
1476 return False
1478 def __call__(
1479 self,
1480 params: Optional[jnp.ndarray] = None,
1481 inputs: Optional[jnp.ndarray] = None,
1482 pulse_params: Optional[jnp.ndarray] = None,
1483 enc_params: Optional[jnp.ndarray] = None,
1484 noise_params: Optional[Dict[str, Union[float, Dict[str, float]]]] = None,
1485 execution_type: Optional[str] = None,
1486 force_mean: bool = False,
1487 gate_mode: str = "unitary",
1488 ) -> jnp.ndarray:
1489 """
1490 Execute the quantum circuit (callable interface).
1492 Provides a convenient callable interface for circuit execution,
1493 delegating to the _forward method.
1495 Args:
1496 params (Optional[jnp.ndarray]): Variational parameters of shape
1497 (n_layers, n_params_per_layer) or (n_layers, n_params_per_layer, batch).
1498 If None, uses model's internal parameters.
1499 inputs (Optional[jnp.ndarray]): Input data of shape
1500 (batch_size, n_input_feat). If None, uses zero inputs.
1501 pulse_params (Optional[jnp.ndarray]): Pulse parameter scalers for
1502 pulse-mode gate execution.
1503 enc_params (Optional[jnp.ndarray]): Encoding parameters of shape
1504 (n_qubits, n_input_feat). If None, uses model's encoding parameters.
1505 noise_params (Optional[Dict[str, Union[float, Dict[str, float]]]]):
1506 Noise configuration. If None, uses previously set noise parameters.
1507 execution_type (Optional[str]): Measurement type: "expval", "density",
1508 "probs", or "state". If None, uses current execution_type setting.
1509 force_mean (bool): If True, averages results over measurement qubits.
1510 Defaults to False.
1511 gate_mode (str): Gate execution backend, "unitary" or "pulse".
1512 Defaults to "unitary".
1514 Returns:
1515 jnp.ndarray: Circuit output with shape depending on execution_type:
1516 - "expval": (n_output_qubits,) or scalar
1517 - "density": (2^n_output, 2^n_output)
1518 - "probs": (2^n_output,) or (n_pairs, 2^pair_size)
1519 - "state": (2^n_qubits,)
1520 """
1521 # Call forward method which handles the actual caching etc.
1522 return self._forward(
1523 params=params,
1524 inputs=inputs,
1525 pulse_params=pulse_params,
1526 enc_params=enc_params,
1527 noise_params=noise_params,
1528 execution_type=execution_type,
1529 force_mean=force_mean,
1530 gate_mode=gate_mode,
1531 )
1533 def _forward(
1534 self,
1535 params: Optional[jnp.ndarray] = None,
1536 inputs: Optional[jnp.ndarray] = None,
1537 pulse_params: Optional[jnp.ndarray] = None,
1538 enc_params: Optional[jnp.ndarray] = None,
1539 noise_params: Optional[Dict[str, Union[float, Dict[str, float]]]] = None,
1540 execution_type: Optional[str] = None,
1541 force_mean: bool = False,
1542 gate_mode: str = "unitary",
1543 ) -> jnp.ndarray:
1544 """
1545 Execute the quantum circuit forward pass.
1547 Internal implementation of the forward pass that handles parameter
1548 validation, batch alignment, and circuit execution routing.
1550 Args:
1551 params (Optional[jnp.ndarray]): Variational parameters of shape
1552 (n_layers, n_params_per_layer) or
1553 (n_layers, n_params_per_layer, batch).
1554 If None, uses model's internal parameters.
1555 inputs (Optional[jnp.ndarray]): Input data of shape
1556 (batch_size, n_input_feat).
1557 If None, uses zero inputs.
1558 pulse_params (Optional[jnp.ndarray]): Pulse parameter scalers for
1559 pulse-mode gate execution.
1560 enc_params (Optional[jnp.ndarray]): Encoding parameters of shape
1561 (n_qubits, n_input_feat). If None, uses model's encoding parameters.
1562 noise_params (Optional[Dict[str, Union[float, Dict[str, float]]]]):
1563 Noise configuration. If None, uses previously set noise parameters.
1564 execution_type (Optional[str]): Measurement type: "expval", "density",
1565 "probs", or "state". If None, uses current execution_type setting.
1566 force_mean (bool): If True, averages results over measurement qubits.
1567 Defaults to False.
1568 gate_mode (str): Gate execution backend, "unitary" or "pulse".
1569 Defaults to "unitary".
1571 Returns:
1572 jnp.ndarray: Circuit output with shape depending on execution_type:
1573 - "expval": (n_output_qubits,) or scalar
1574 - "density": (2^n_output, 2^n_output)
1575 - "probs": (2^n_output,) or (n_pairs, 2^pair_size)
1576 - "state": (2^n_qubits,)
1578 Raises:
1579 ValueError: If pulse_params provided without pulse gate_mode, or
1580 if noise_params provided with pulse gate_mode.
1581 """
1582 # set the parameters as object attributes
1583 if noise_params is not None:
1584 self.noise_params = noise_params
1585 if execution_type is not None:
1586 self.execution_type = execution_type
1587 self.gate_mode = gate_mode
1589 # consistency checks
1590 if pulse_params is not None and gate_mode != "pulse":
1591 raise ValueError(
1592 "pulse_params were provided but gate_mode is not 'pulse'. "
1593 "Either switch gate_mode='pulse' or do not pass pulse_params."
1594 )
1596 if noise_params is not None and gate_mode == "pulse":
1597 raise ValueError(
1598 "Noise is not supported in 'pulse' gate_mode. "
1599 "Either remove noise_params or use gate_mode='unitary'."
1600 )
1602 params = self._params_validation(params)
1603 pulse_params = self._pulse_params_validation(pulse_params)
1604 inputs = self._inputs_validation(inputs)
1605 enc_params = self._enc_params_validation(enc_params)
1607 inputs, params, pulse_params = self._assimilate_batch(
1608 inputs,
1609 params,
1610 pulse_params,
1611 )
1613 result: Optional[jnp.ndarray] = None
1614 self.random_key, subkey = safe_random_split(self.random_key)
1616 # if density matrix requested or noise params used
1617 if self._requires_density():
1618 result = self._mp_executor(
1619 f=self.circuit_mixed,
1620 params=params,
1621 pulse_params=pulse_params,
1622 inputs=inputs,
1623 enc_params=enc_params,
1624 noise_params=self.noise_params,
1625 random_key=subkey,
1626 gate_mode=gate_mode,
1627 )
1628 else:
1629 if not isinstance(self.circuit, qml.QNode):
1630 result = self.circuit(
1631 inputs=inputs,
1632 )
1633 else:
1634 result = self._mp_executor(
1635 f=self.circuit,
1636 params=params,
1637 pulse_params=pulse_params,
1638 inputs=inputs,
1639 enc_params=enc_params,
1640 noise_params=self.noise_params,
1641 random_key=subkey,
1642 gate_mode=gate_mode,
1643 )
1645 result = result.reshape((*self.eff_batch_shape, *self._result_shape)).squeeze()
1647 if (
1648 self.execution_type in ("expval", "probs")
1649 and force_mean
1650 and len(result.shape) > 0
1651 and self._result_shape[0] > 1
1652 ):
1653 result = result.mean(axis=-1)
1655 return result