Coverage for qml_essentials / model.py: 90%
501 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-30 11:43 +0000
1from typing import Any, Dict, Optional, Tuple, Callable, Union, List
3from qml_essentials import operations as op
4from qml_essentials import yaqsi as ys
5import warnings
6import jax.numpy as jnp
7import numpy as np
8from jax import random
10from qml_essentials.tape import recording
11from qml_essentials.operations import KrausChannel
12from qml_essentials.ansaetze import Ansaetze, Circuit, Encoding
13from qml_essentials.gates import Gates, PulseInformation as pinfo
14from qml_essentials.utils import safe_random_split
16import logging
# Module-level logger, named after this module.
log = logging.getLogger(name=__name__)
21class Model:
22 """
23 A quantum circuit model.
24 """
26 def __init__(
27 self,
28 n_qubits: int,
29 n_layers: int,
30 circuit_type: Union[str, Circuit] = "No_Ansatz",
31 data_reupload: Union[bool, List[List[bool]], List[List[List[bool]]]] = True,
32 state_preparation: Union[
33 str, Callable, List[Union[str, Callable]], None
34 ] = None,
35 encoding: Union[Encoding, str, Callable, List[Union[str, Callable]]] = Gates.RX,
36 trainable_frequencies: bool = False,
37 initialization: str = "random",
38 initialization_domain: List[float] = [0, 2 * jnp.pi],
39 output_qubit: Union[List[int], int] = -1,
40 shots: Optional[int] = None,
41 random_seed: int = 1000,
42 remove_zero_encoding: bool = True,
43 repeat_batch_axis: List[bool] = [True, True, True],
44 pulse_shape: str = "gaussian",
45 ) -> None:
46 """
47 Initialize the quantum circuit model.
48 Parameters will have the shape [impl_n_layers, parameters_per_layer]
49 where impl_n_layers is the number of layers provided and added by one
50 depending if data_reupload is True and parameters_per_layer is given by
51 the chosen ansatz.
53 The model is initialized with the following parameters as defaults:
54 - noise_params: None
55 - execution_type: "expval"
56 - shots: None
58 Args:
59 n_qubits (int): The number of qubits in the circuit.
60 n_layers (int): The number of layers in the circuit.
61 circuit_type (str, Circuit): The type of quantum circuit to use.
62 If None, defaults to "no_ansatz".
63 data_reupload (Union[bool, List[bool], List[List[bool]]], optional):
64 Whether to reupload data to the quantum device on each
65 layer and qubit. Detailed re-uploading instructions can be given
66 as a list/array of 0/False and 1/True with shape (n_qubits,
67 n_layers) to specify where to upload the data. Defaults to True
68 for applying data re-uploading to the full circuit.
69 encoding (Union[str, Callable, List[str], List[Callable]], optional):
70 The unitary to use for encoding the input data. Can be a string
71 (e.g. "RX") or a callable (e.g. op.RX). Defaults to op.RX.
72 If input is multidimensional it is assumed to be a list of
73 unitaries or a list of strings.
74 trainable_frequencies (bool, optional):
75 Sets trainable encoding parameters for trainable frequencies.
76 Defaults to False.
77 initialization (str, optional): The strategy to initialize the parameters.
78 Can be "random", "zeros", "zero-controlled", "pi", or "pi-controlled".
79 Defaults to "random".
80 output_qubit (List[int], int, optional): The index of the output
81 qubit (or qubits). When set to -1 all qubits are measured, or a
82 global measurement is conducted, depending on the execution
83 type.
84 shots (Optional[int], optional): The number of shots to use for
85 the quantum device. Defaults to None.
86 random_seed (int, optional): seed for the random number generator
87 in initialization is "random" and for random noise parameters.
88 Defaults to 1000.
89 remove_zero_encoding (bool, optional): whether to
90 remove the zero encoding from the circuit. Defaults to True.
91 repeat_batch_axis (List[bool], optional): Each boolean in the array
92 determines over which axes to parallelise computation. The axes
93 correspond to [inputs, params, pulse_params]. Defaults to
94 [True, True, True], meaning that batching is enabled over all
95 axes.
96 pulse_shape (str, optional): Pulse envelope shape for pulse-level
97 simulation. One of ``PulseEnvelope.available()``.
98 Defaults to ``"gaussian"``.
100 Returns:
101 None
102 """
103 # Initialize default parameters needed for circuit evaluation
104 self.n_qubits: int = n_qubits
105 self.output_qubit: Union[List[int], int] = output_qubit
106 self.n_layers: int = n_layers
107 self.noise_params: Optional[Dict[str, Union[float, Dict[str, float]]]] = None
108 self.shots = shots
109 self.remove_zero_encoding = remove_zero_encoding
110 self.trainable_frequencies: bool = trainable_frequencies
111 self.execution_type: str = "expval"
112 self.repeat_batch_axis: List[bool] = repeat_batch_axis
114 # --- Pulse envelope ---
115 pinfo.set_envelope(pulse_shape)
117 # --- State Preparation ---
118 try:
119 self._sp = Gates.parse_gates(state_preparation, Gates)
120 except ValueError as e:
121 raise ValueError(f"Error parsing encodings: {e}")
123 # prepare corresponding pulse parameters (always optimized pulses)
124 self.sp_pulse_params = []
125 for sp in self._sp:
126 sp_name = sp.__name__ if hasattr(sp, "__name__") else str(sp)
128 if pinfo.gate_by_name(sp_name) is not None:
129 self.sp_pulse_params.append(pinfo.gate_by_name(sp_name).params)
130 else:
131 # gate has no pulse parametrization
132 self.sp_pulse_params.append(None)
134 # --- Encoding ---
135 if isinstance(encoding, Encoding):
136 # user wants custom strategy? do it!
137 self._enc = encoding
138 else:
139 # use hammming encoding by default
140 self._enc = Encoding("hamming", encoding)
142 # Number of possible inputs
143 self.n_input_feat = len(self._enc)
144 log.debug(f"Number of input features: {self.n_input_feat}")
146 # Trainable frequencies, default initialization as in arXiv:2309.03279v2
147 self.enc_params = jnp.ones((self.n_layers, self.n_qubits, self.n_input_feat))
149 self._zero_inputs = False
151 # --- Data-Reuploading ---
153 # Keep as NumPy array (not JAX) so that ``if data_reupload[q, idx]``
154 # in _iec remains a concrete Python bool even under jax.jit tracing.
155 # note that setting this will also update self.degree and self.frequencies
156 # and in consequence also self.has_dru
157 self.data_reupload = data_reupload
159 # check for the highest degree among all input dimensions
160 if self.has_dru:
161 impl_n_layers: int = n_layers + 1 # we need L+1 according to Schuld et al.
162 else:
163 impl_n_layers = n_layers
164 log.info(f"Number of implicit layers: {impl_n_layers}.")
166 # --- Ansatz ---
167 # only weak check for str. We trust the user to provide sth useful
168 if isinstance(circuit_type, str):
169 self.pqc: Callable[[Optional[jnp.ndarray], int], int] = getattr(
170 Ansaetze, circuit_type or "No_Ansatz"
171 )()
172 else:
173 self.pqc = circuit_type()
174 log.info(f"Using Ansatz {circuit_type}.")
176 # calculate the shape of the parameter vector here, we will re-use this in init.
177 params_per_layer = self.pqc.n_params_per_layer(self.n_qubits)
178 self._params_shape: Tuple[int, int] = (impl_n_layers, params_per_layer)
179 log.info(f"Parameters per layer: {params_per_layer}")
181 pulse_params_per_layer = self.pqc.n_pulse_params_per_layer(self.n_qubits)
182 self._pulse_params_shape: Tuple[int, int] = (
183 impl_n_layers,
184 pulse_params_per_layer,
185 )
187 # intialize to None as we can't know this yet
188 self._batch_shape = None
190 # this will also be re-used in the init method,
191 # however, only if nothing is provided
192 self._inialization_strategy = initialization
193 self._initialization_domain = initialization_domain
195 # ..here! where we only require a JAX random key
196 self.random_key = self.initialize_params(random.key(random_seed))
198 # Initializing pulse params
199 self.pulse_params: jnp.ndarray = jnp.ones((1, *self._pulse_params_shape))
201 log.info(f"Initialized pulse parameters with shape {self.pulse_params.shape}.")
203 # Initialise the yaqsi Script that wraps _variational.
204 # No device selection needed - yaqsi auto-routes between statevector
205 # and density-matrix simulation based on whether noise channels are
206 # present on the tape.
207 self.script = ys.Script(f=self._variational, n_qubits=self.n_qubits)
209 @property
210 def noise_params(self) -> Optional[Dict[str, Union[float, Dict[str, float]]]]:
211 """
212 Gets the noise parameters of the model.
214 Returns:
215 Optional[Dict[str, float]]: A dictionary of
216 noise parameters or None if not set.
217 """
218 return self._noise_params
220 @noise_params.setter
221 def noise_params(
222 self, kvs: Optional[Dict[str, Union[float, Dict[str, float]]]]
223 ) -> None:
224 """
225 Sets the noise parameters of the model.
227 Typically a "noise parameter" refers to the error probability.
228 ThermalRelaxation is a special case, and supports a dict as value with
229 structure:
230 "ThermalRelaxation":
231 {
232 "t1": 2000, # relative t1 time.
233 "t2": 1000, # relative t2 time
234 "t_factor" 1: # relative gate time factor
235 },
237 Args:
238 kvs (Optional[Dict[str, Union[float, Dict[str, float]]]]): A
239 dictionary of noise parameters. If all values are 0.0, the noise
240 parameters are set to None.
242 Returns:
243 None
244 """
245 # set to None if only zero values provided
246 if kvs is not None and all(v == 0.0 for v in kvs.values()):
247 kvs = None
249 # set default values
250 if kvs is not None:
251 defaults = {
252 "BitFlip": 0.0,
253 "PhaseFlip": 0.0,
254 "Depolarizing": 0.0,
255 "MultiQubitDepolarizing": 0.0,
256 "AmplitudeDamping": 0.0,
257 "PhaseDamping": 0.0,
258 "GateError": 0.0,
259 "ThermalRelaxation": None,
260 "StatePreparation": 0.0,
261 "Measurement": 0.0,
262 }
263 for key, default_val in defaults.items():
264 kvs.setdefault(key, default_val)
266 # check if there are any keys not supported
267 for key in kvs.keys():
268 if key not in defaults:
269 warnings.warn(
270 f"Noise type {key} is not supported by this package",
271 UserWarning,
272 )
274 # check valid params for thermal relaxation noise channel
275 tr_params = kvs["ThermalRelaxation"]
276 if isinstance(tr_params, dict):
277 tr_params.setdefault("t1", 0.0)
278 tr_params.setdefault("t2", 0.0)
279 tr_params.setdefault("t_factor", 0.0)
280 valid_tr_keys = {"t1", "t2", "t_factor"}
281 for k in tr_params.keys():
282 if k not in valid_tr_keys:
283 warnings.warn(
284 f"Thermal Relaxation parameter {k} is not supported "
285 f"by this package",
286 UserWarning,
287 )
288 if not all(tr_params.values()) or tr_params["t2"] > 2 * tr_params["t1"]:
289 warnings.warn(
290 "Received invalid values for Thermal Relaxation noise "
291 "parameter. Thermal relaxation is not applied!",
292 UserWarning,
293 )
294 kvs["ThermalRelaxation"] = 0.0
296 self._noise_params = kvs
298 @property
299 def output_qubit(self) -> List[int]:
300 """Get the output qubit indices for measurement."""
301 return self._output_qubit
303 @output_qubit.setter
304 def output_qubit(self, value: Union[int, List[int]]) -> None:
305 """
306 Set the output qubit(s) for measurement.
308 Args:
309 value: Qubit index or list of indices. Use -1 for all qubits.
310 """
311 if isinstance(value, list):
312 assert (
313 len(value) <= self.n_qubits
314 ), f"Size of output_qubit {len(value)} cannot be\
315 larger than number of qubits {self.n_qubits}."
316 elif isinstance(value, int):
317 if value == -1:
318 value = list(range(self.n_qubits))
319 else:
320 assert (
321 value < self.n_qubits
322 ), f"Output qubit {value} cannot be larger than {self.n_qubits}."
323 value = [value]
325 self._output_qubit = value
327 @property
328 def execution_type(self) -> str:
329 """
330 Gets the execution type of the model.
332 Returns:
333 str: The execution type, one of 'density', 'expval', or 'probs'.
334 """
335 return self._execution_type
337 @execution_type.setter
338 def execution_type(self, value: str) -> None:
339 if value == "density":
340 self._result_shape = (
341 2 ** len(self.output_qubit),
342 2 ** len(self.output_qubit),
343 )
344 elif value == "expval":
345 # check if all qubits are used
346 if len(self.output_qubit) == self.n_qubits:
347 self._result_shape = (len(self.output_qubit),)
348 # if not -> parity measurement with only 1D output per pair
349 # or n_local measurement
350 else:
351 self._result_shape = (len(self.output_qubit),)
352 elif value == "probs":
353 # in case this is a list of parities,
354 # each pair has 2^len(qubits) probabilities
355 n_parity = (
356 (2,) * len(self.output_qubit)
357 if isinstance(self.output_qubit, (Tuple, List))
358 else (2,)
359 )
360 self._result_shape = n_parity
361 elif value == "state":
362 self._result_shape = (2 ** len(self.output_qubit),)
363 else:
364 raise ValueError(f"Invalid execution type: {value}.")
366 if value == "state" and not self.all_qubit_measurement:
367 warnings.warn(
368 f"{value} measurement does ignore output_qubit, which is "
369 f"{self.output_qubit}.",
370 UserWarning,
371 )
373 if value == "probs" and self.shots is None:
374 warnings.warn(
375 "Setting execution_type to probs without specifying shots.",
376 UserWarning,
377 )
379 if value == "density" and self.shots is not None:
380 raise ValueError("Setting execution_type to density with shots not None.")
382 self._execution_type = value
384 @property
385 def shots(self) -> Optional[int]:
386 """
387 Gets the number of shots to use for the quantum device.
389 Returns:
390 Optional[int]: The number of shots.
391 """
392 return self._shots
394 @shots.setter
395 def shots(self, value: Optional[int]) -> None:
396 """
397 Sets the number of shots to use for the quantum device.
399 Args:
400 value (Optional[int]): The number of shots.
401 If an integer less than or equal to 0 is provided, it is set to None.
403 Returns:
404 None
405 """
406 if type(value) is int and value <= 0:
407 value = None
408 self._shots = value
410 @property
411 def params(self) -> jnp.ndarray:
412 """Get the variational parameters of the model."""
413 return self._params
415 @params.setter
416 def params(self, value: jnp.ndarray) -> None:
417 """Set the variational parameters, ensuring batch dimension exists."""
418 if len(value.shape) == 2:
419 value = value.reshape(1, *value.shape)
421 self._params = value
423 @property
424 def enc_params(self) -> jnp.ndarray:
425 """Get the encoding parameters used for input transformation."""
426 return self._enc_params
428 @enc_params.setter
429 def enc_params(self, value: jnp.ndarray) -> None:
430 """Set the encoding parameters."""
431 self._enc_params = value
433 @property
434 def pulse_params(self) -> jnp.ndarray:
435 """Get the pulse parameters for pulse-mode gate execution."""
436 return self._pulse_params
438 @pulse_params.setter
439 def pulse_params(self, value: jnp.ndarray) -> None:
440 """Set the pulse parameters."""
441 self._pulse_params = value
443 @property
444 def data_reupload(self) -> jnp.ndarray:
445 """Get the data reupload mask."""
446 return self._data_reupload
448 @data_reupload.setter
449 def data_reupload(self, value: jnp.ndarray) -> None:
450 """Set the data reupload mask.
452 Always converts to a concrete NumPy boolean array so that
453 ``if data_reupload[q, idx]`` in :meth:`_iec` remains a plain
454 Python ``bool`` even inside JAX-traced functions (jit / grad / vmap).
455 """
456 # Process data reuploading strategy and set degree
457 if not isinstance(value, bool):
458 if not isinstance(value, np.ndarray):
459 value = np.array(value)
461 if len(value.shape) == 2:
462 assert value.shape == (
463 self.n_layers,
464 self.n_qubits,
465 ), f"Data reuploading array has wrong shape. \
466 Expected {(self.n_layers, self.n_qubits)} or\
467 {(self.n_layers, self.n_qubits, self.n_input_feat)},\
468 got {value.shape}."
469 value = value.reshape(*value.shape, 1)
470 value = np.repeat(value, self.n_input_feat, axis=2)
472 assert value.shape == (
473 self.n_layers,
474 self.n_qubits,
475 self.n_input_feat,
476 ), f"Data reuploading array has wrong shape. \
477 Expected {(self.n_layers, self.n_qubits, self.n_input_feat)},\
478 got {value.shape}."
480 log.debug(f"Data reuploading array:\n{value}")
481 else:
482 if value:
483 value = np.ones((self.n_layers, self.n_qubits, self.n_input_feat))
484 log.debug("Full data reuploading.")
485 else:
486 value = np.zeros((self.n_layers, self.n_qubits, self.n_input_feat))
487 value[0][0] = 1
488 log.debug("No data reuploading.")
490 # convert to boolean values
491 self._data_reupload = np.asarray(value).astype(bool)
493 self.degree: Tuple = tuple(
494 self._enc.get_n_freqs(np.count_nonzero(self.data_reupload[..., i]))
495 for i in range(self.n_input_feat)
496 )
498 self.frequencies: Tuple = tuple(
499 self._enc.get_spectrum(np.count_nonzero(self.data_reupload[..., i]))
500 for i in range(self.n_input_feat)
501 )
503 # Cache has_dru as a plain Python bool so that it can be used in
504 # Python ``if`` statements even inside JAX-traced functions.
505 self._has_dru: bool = bool(max(int(np.max(f)) for f in self._frequencies) > 1)
507 @property
508 def degree(self) -> Tuple:
509 """Get the degree of the model."""
510 return self._degree
512 @degree.setter
513 def degree(self, value: Tuple):
514 self._degree = value
516 @property
517 def frequencies(self) -> Tuple:
518 """Get the frequencies of the model."""
519 return self._frequencies
521 @frequencies.setter
522 def frequencies(self, value: Tuple):
523 self._frequencies = value
525 @property
526 def has_dru(self) -> bool:
527 """Check if the model has data reupload."""
528 return self._has_dru
530 @property
531 def all_qubit_measurement(self) -> bool:
532 """Check if measurement is performed on all qubits."""
533 return self.output_qubit == list(range(self.n_qubits))
535 @property
536 def batch_shape(self) -> Tuple[int, ...]:
537 """
538 Get the batch shape (B_I, B_P, B_R).
539 If the model was not called before,
540 it returns (1, 1, 1).
542 Returns:
543 Tuple[int, ...]: Tuple of (input_batch, param_batch, pulse_batch).
544 Returns (1, 1, 1) if model has not been called yet.
545 """
546 if self._batch_shape is None:
547 log.debug("Model was not called yet. Returning (1,1,1) as batch shape.")
548 return (1, 1, 1)
549 return self._batch_shape
551 @property
552 def eff_batch_shape(self) -> Tuple[int, ...]:
553 """
554 Get the effective batch shape after applying repeat_batch_axis mask.
556 Returns:
557 Tuple[int, ...]: Effective batch dimensions, excluding zeros.
558 """
559 batch_shape = np.array(self.batch_shape) * self.repeat_batch_axis
560 batch_shape = batch_shape[batch_shape != 0]
561 return batch_shape
563 def initialize_params(
564 self,
565 random_key: Optional[random.PRNGKey] = None,
566 repeat: int = 1,
567 initialization: Optional[str] = None,
568 initialization_domain: Optional[List[float]] = None,
569 ) -> random.PRNGKey:
570 """
571 Initialize the variational parameters of the model.
573 Args:
574 random_key (Optional[random.PRNGKey]): JAX random key for initialization.
575 If None, uses the model's internal random key.
576 repeat (int): Number of parameter sets to create (batch dimension).
577 Defaults to 1.
578 initialization (Optional[str]): Strategy for parameter initialization.
579 Options: "random", "zeros", "pi", "zero-controlled", "pi-controlled".
580 If None, uses the strategy specified in the constructor.
581 initialization_domain (Optional[List[float]]): Domain [min, max] for
582 random initialization. If None, uses the domain from constructor.
584 Returns:
585 random.PRNGKey: Updated random key after initialization.
587 Raises:
588 Exception: If an invalid initialization method is specified.
589 """
590 # Initializing params
591 params_shape = (repeat, *self._params_shape)
593 # use existing strategy if not specified
594 initialization = initialization or self._inialization_strategy
595 initialization_domain = initialization_domain or self._initialization_domain
597 random_key, sub_key = safe_random_split(
598 random_key if random_key is not None else self.random_key
599 )
601 def set_control_params(params: jnp.ndarray, value: float) -> jnp.ndarray:
602 indices = self.pqc.get_control_indices(self.n_qubits)
603 if indices is None:
604 warnings.warn(
605 f"Specified {initialization} but circuit\
606 does not contain controlled rotation gates.\
607 Parameters are intialized randomly.",
608 UserWarning,
609 )
610 else:
611 np_params = np.array(params)
612 np_params[:, :, indices[0] : indices[1] : indices[2]] = (
613 np.ones_like(params[:, :, indices[0] : indices[1] : indices[2]])
614 * value
615 )
616 params = jnp.array(np_params)
617 return params
619 if initialization == "random":
620 self.params: jnp.ndarray = random.uniform(
621 sub_key,
622 params_shape,
623 minval=initialization_domain[0],
624 maxval=initialization_domain[1],
625 )
626 elif initialization == "zeros":
627 self.params: jnp.ndarray = jnp.zeros(params_shape)
628 elif initialization == "pi":
629 self.params: jnp.ndarray = jnp.ones(params_shape) * jnp.pi
630 elif initialization == "zero-controlled":
631 self.params: jnp.ndarray = random.uniform(
632 sub_key,
633 params_shape,
634 minval=initialization_domain[0],
635 maxval=initialization_domain[1],
636 )
637 self.params = set_control_params(self.params, 0)
638 elif initialization == "pi-controlled":
639 self.params: jnp.ndarray = random.uniform(
640 sub_key,
641 params_shape,
642 minval=initialization_domain[0],
643 maxval=initialization_domain[1],
644 )
645 self.params = set_control_params(self.params, jnp.pi)
646 else:
647 raise Exception("Invalid initialization method")
649 log.info(
650 f"Initialized parameters with shape {self.params.shape}\
651 using strategy {initialization}."
652 )
654 return random_key
656 def transform_input(
657 self, inputs: jnp.ndarray, enc_params: jnp.ndarray
658 ) -> jnp.ndarray:
659 """
660 Transform input data by scaling with encoding parameters.
662 Implements the input transformation as described in arXiv:2309.03279v2,
663 where inputs are linearly scaled by encoding parameters before being
664 used in the quantum circuit.
666 Args:
667 inputs (jnp.ndarray): Input data point of shape (n_input_feat,) or
668 (batch_size, n_input_feat).
669 enc_params (jnp.ndarray): Encoding weight scalar or vector used to
670 scale the input.
672 Returns:
673 jnp.ndarray: Transformed input, element-wise product of inputs
674 and enc_params.
675 """
676 return inputs * enc_params
678 def _iec(
679 self,
680 inputs: jnp.ndarray,
681 data_reupload: jnp.ndarray,
682 enc: Encoding,
683 enc_params: jnp.ndarray,
684 noise_params: Optional[Dict[str, Union[float, Dict[str, float]]]] = None,
685 random_key: Optional[random.PRNGKey] = None,
686 ) -> None:
687 """
688 Apply Input Encoding Circuit (IEC) with angle encoding.
690 Encodes classical input data into the quantum circuit using rotation
691 gates (e.g., RX, RY, RZ). Supports data re-uploading at specified
692 positions in the circuit.
694 Args:
695 inputs (jnp.ndarray): Input data of shape (n_input_feat,) or
696 (batch_size, n_input_feat).
697 data_reupload (jnp.ndarray): Boolean array of shape (n_qubits, n_input_feat)
698 indicating where to apply encoding gates.
699 enc (Encoding): Encoding strategy containing the encoding gate functions.
700 enc_params (jnp.ndarray): Encoding parameters of shape
701 (n_qubits, n_input_feat) used to scale inputs.
702 noise_params (Optional[Dict[str, Union[float, Dict[str, float]]]]):
703 Noise parameters for gate-level noise simulation. Defaults to None.
704 random_key (Optional[random.PRNGKey]): JAX random key for stochastic
705 noise. Defaults to None.
707 Returns:
708 None: Gates are applied in-place to the quantum circuit.
709 """
710 # check for zero, because due to input validation, input cannot be none
711 if self.remove_zero_encoding and self._zero_inputs and self.batch_shape[0] == 1:
712 return
714 for q in range(self.n_qubits):
715 # use the last dimension of the inputs (feature dimension)
716 for idx in range(inputs.shape[-1]):
717 if data_reupload[q, idx]:
718 # use elipsis to indiex only the last dimension
719 # as inputs are generally *not* qubit dependent
720 random_key, sub_key = safe_random_split(random_key)
721 enc[idx](
722 self.transform_input(inputs[..., idx], enc_params[q, idx]),
723 wires=q,
724 noise_params=noise_params,
725 random_key=sub_key,
726 input_idx=idx,
727 )
729 def _variational(
730 self,
731 params: jnp.ndarray,
732 inputs: jnp.ndarray,
733 pulse_params: Optional[jnp.ndarray] = None,
734 random_key: Optional[random.PRNGKey] = None,
735 enc_params: Optional[jnp.ndarray] = None,
736 gate_mode: str = "unitary",
737 noise_params: Optional[Dict[str, Union[float, Dict[str, float]]]] = None,
738 ) -> None:
739 """
740 Build the variational quantum circuit structure.
742 Constructs the circuit by applying state preparation, alternating
743 variational ansatz layers with input encoding layers, and optional
744 noise channels.
746 The first five parameters (after ``self``) - ``params``, ``inputs``,
747 ``pulse_params``, ``random_key``, ``enc_params`` - are the batchable
748 positional arguments.
749 The remaining keyword arguments are broadcast across the batch.
751 Args:
752 params (jnp.ndarray): Variational parameters of shape
753 (n_layers, n_params_per_layer).
754 inputs (jnp.ndarray): Input data of shape (n_input_feat,).
755 pulse_params (Optional[jnp.ndarray]): Pulse parameter scalers of shape
756 (n_layers, n_pulse_params_per_layer) for pulse-mode execution.
757 Defaults to None (uses model's pulse_params).
758 random_key (Optional[random.PRNGKey]): JAX random key for stochastic
759 operations. Defaults to None.
760 enc_params (Optional[jnp.ndarray]): Encoding parameters of shape
761 (n_qubits, n_input_feat). Defaults to None (uses model's enc_params).
762 gate_mode (str): Gate execution mode, either "unitary" or "pulse".
763 Defaults to "unitary".
764 noise_params (Optional[Dict[str, Union[float, Dict[str, float]]]]):
765 Noise parameters for simulation. Defaults to None.
767 Returns:
768 None: Gates are applied in-place to the quantum circuit.
770 Note:
771 Issues RuntimeWarning if called directly without providing parameters
772 that would normally be passed through the forward method.
773 """
774 # TODO: rework and double check params shape
775 if len(params.shape) > 2 and params.shape[0] == 1:
776 params = params[0]
778 if len(inputs.shape) > 1 and inputs.shape[0] == 1:
779 inputs = inputs[0]
781 if enc_params is None:
782 # TODO: Raise warning if trainable frequencies is True, or similar. I.e., no
783 # warning if user does not care for frequencies or enc_params
784 if self.trainable_frequencies:
785 warnings.warn(
786 "Explicit call to `_circuit` or `_variational` detected: "
787 "`enc_params` is None, using `self.enc_params` instead.",
788 RuntimeWarning,
789 )
790 enc_params = self.enc_params
792 if pulse_params is None:
793 if gate_mode == "pulse":
794 warnings.warn(
795 "Explicit call to `_circuit` or `_variational` detected: "
796 "`pulse_params` is None, using `self.pulse_params` instead.",
797 RuntimeWarning,
798 )
799 pulse_params = self.pulse_params
801 # Squeeze batch dimension for pulse_params (batch-first convention)
802 if len(pulse_params.shape) > 2 and pulse_params.shape[0] == 1:
803 pulse_params = pulse_params[0]
805 if noise_params is None:
806 if self.noise_params is not None:
807 warnings.warn(
808 "Explicit call to `_circuit` or `_variational` detected: "
809 "`noise_params` is None, using `self.noise_params` instead.",
810 RuntimeWarning,
811 )
812 noise_params = self.noise_params
814 if noise_params is not None:
815 if random_key is None:
816 warnings.warn(
817 "Explicit call to `_circuit` or `_variational` detected: "
818 "`random_key` is None, using `random.PRNGKey(0)` instead.",
819 RuntimeWarning,
820 )
821 random_key = self.random_key
822 self._apply_state_prep_noise(noise_params=noise_params)
824 # state preparation
825 for q in range(self.n_qubits):
826 for _sp, sp_pulse_params in zip(self._sp, self.sp_pulse_params):
827 random_key, sub_key = safe_random_split(random_key)
828 _sp(
829 wires=q,
830 pulse_params=sp_pulse_params,
831 noise_params=noise_params,
832 random_key=sub_key,
833 gate_mode=gate_mode,
834 )
836 # circuit building
837 for layer in range(0, self.n_layers):
838 random_key, sub_key = safe_random_split(random_key)
839 # ansatz layers
840 self.pqc(
841 params[layer],
842 self.n_qubits,
843 pulse_params=pulse_params[layer],
844 noise_params=noise_params,
845 random_key=sub_key,
846 gate_mode=gate_mode,
847 )
849 random_key, sub_key = safe_random_split(random_key)
850 # encoding layers
851 self._iec(
852 inputs,
853 data_reupload=self.data_reupload[layer],
854 enc=self._enc,
855 enc_params=enc_params[layer],
856 noise_params=noise_params,
857 random_key=sub_key,
858 )
860 # visual barrier (no-op in yaqsi, purely cosmetic in PennyLane)
862 # final ansatz layer
863 if self.has_dru: # same check as in init
864 random_key, sub_key = safe_random_split(random_key)
865 self.pqc(
866 params[self.n_layers],
867 self.n_qubits,
868 pulse_params=pulse_params[-1],
869 noise_params=noise_params,
870 random_key=sub_key,
871 gate_mode=gate_mode,
872 )
874 # channel noise
875 if noise_params is not None:
876 self._apply_general_noise(noise_params=noise_params)
878 def _build_obs(self) -> Tuple[str, List[op.Operation]]:
879 """Build the yaqsi measurement type and observable list.
881 Translates the model's ``execution_type`` and ``output_qubit``
882 settings into parameters suitable for
883 :meth:`~qml_essentials.yaqsi.Script.execute`.
885 Returns:
886 Tuple ``(meas_type, obs)`` where *meas_type* is one of
887 ``"expval"``, ``"probs"``, ``"density"``, ``"state"`` and *obs*
888 is a (possibly empty) list of :class:`Operation` observables.
889 """
890 if self.execution_type == "density":
891 return "density", []
893 if self.execution_type == "state":
894 return "state", []
896 if self.execution_type == "expval":
897 obs: List[op.Operation] = []
898 for qubit_spec in self.output_qubit:
899 if isinstance(qubit_spec, int):
900 obs.append(op.PauliZ(wires=qubit_spec))
901 else:
902 # parity: Z \\otimes Z \\otimes …
903 obs.append(ys.build_parity_observable(list(qubit_spec)))
904 return "expval", obs
906 if self.execution_type == "probs":
907 # probs are computed on the full system; subsystem
908 # marginalisation is handled in _postprocess_res
909 return "probs", []
911 raise ValueError(f"Invalid execution_type: {self.execution_type}.")
913 def _apply_state_prep_noise(
914 self, noise_params: Dict[str, Union[float, Dict[str, float]]]
915 ) -> None:
916 """
917 Apply state preparation noise to all qubits.
919 Simulates imperfect state preparation by applying BitFlip errors
920 to each qubit with the specified probability.
922 Args:
923 noise_params (Dict[str, Union[float, Dict[str, float]]]): Dictionary
924 containing noise parameters. Uses the "StatePreparation" key
925 for the BitFlip probability.
927 Returns:
928 None: Noise channels are applied in-place to the circuit.
929 """
930 p = noise_params.get("StatePreparation", 0.0)
931 if p > 0:
932 for q in range(self.n_qubits):
933 op.BitFlip(p, wires=q)
935 def _apply_general_noise(
936 self, noise_params: Dict[str, Union[float, Dict[str, float]]]
937 ) -> None:
938 """
939 Apply general noise channels to all qubits.
941 Applies various decoherence and error channels after the circuit
942 execution, simulating environmental noise effects.
944 Args:
945 noise_params (Dict[str, Union[float, Dict[str, float]]]): Dictionary
946 containing noise parameters with the following supported keys:
947 - "AmplitudeDamping" (float): Probability for amplitude damping.
948 - "PhaseDamping" (float): Probability for phase damping.
949 - "Measurement" (float): Probability for measurement error (BitFlip).
950 - "ThermalRelaxation" (Dict): Dictionary with keys "t1", "t2",
951 "t_factor" for thermal relaxation simulation.
953 Returns:
954 None: Noise channels are applied in-place to the circuit.
956 Note:
957 Gate-level noise (e.g., GateError) is handled separately in the
958 Gates.Noise module and applied at the individual gate level.
959 """
960 amp_damp = noise_params.get("AmplitudeDamping", 0.0)
961 phase_damp = noise_params.get("PhaseDamping", 0.0)
962 thermal_relax = noise_params.get("ThermalRelaxation", 0.0)
963 meas = noise_params.get("Measurement", 0.0)
964 for q in range(self.n_qubits):
965 if amp_damp > 0:
966 op.AmplitudeDamping(amp_damp, wires=q)
967 if phase_damp > 0:
968 op.PhaseDamping(phase_damp, wires=q)
969 if meas > 0:
970 op.BitFlip(meas, wires=q)
971 if isinstance(thermal_relax, dict):
972 t1 = thermal_relax["t1"]
973 t2 = thermal_relax["t2"]
974 t_factor = thermal_relax["t_factor"]
975 circuit_depth = self._get_circuit_depth()
976 tg = circuit_depth * t_factor
977 op.ThermalRelaxationError(1.0, t1, t2, tg, q)
979 def _get_circuit_depth(self, inputs: Optional[jnp.ndarray] = None) -> int:
980 """
981 Calculate the depth of the quantum circuit.
983 Records the circuit onto a tape (without noise) and computes the
984 depth as the length of the critical path: each gate is scheduled
985 at the earliest time step after all of its qubits are free.
987 Args:
988 inputs (Optional[jnp.ndarray]): Input data for circuit evaluation.
989 If None, default zero inputs are used.
991 Returns:
992 int: The circuit depth (longest path of gates in the circuit).
993 """
994 # Return cached value if available
995 if hasattr(self, "_cached_circuit_depth"):
996 return self._cached_circuit_depth
998 inputs = self._inputs_validation(inputs)
1000 # Temporarily clear noise_params to prevent _variational from
1001 # picking them up (which would call _apply_general_noise ->
1002 # _get_circuit_depth again, causing infinite recursion).
1003 saved_noise = self._noise_params
1004 self._noise_params = None
1006 with recording() as tape:
1007 self._variational(
1008 self.params[0] if self.params.ndim == 3 else self.params,
1009 inputs[0] if inputs.ndim == 2 else inputs,
1010 noise_params=None,
1011 )
1013 self._noise_params = saved_noise
1015 # Filter out noise channels - only count unitary gates
1016 ops = [o for o in tape if not isinstance(o, KrausChannel)]
1018 if not ops:
1019 self._cached_circuit_depth = 0
1020 return 0
1022 # Schedule each gate at the earliest time step where all its wires
1023 # are free. ``wire_busy[q]`` tracks the next free time step for
1024 # qubit ``q``.
1025 wire_busy: Dict[int, int] = {}
1026 depth = 0
1027 for gate in ops:
1028 start = max((wire_busy.get(w, 0) for w in gate.wires), default=0)
1029 end = start + 1
1030 for w in gate.wires:
1031 wire_busy[w] = end
1032 depth = max(depth, end)
1034 self._cached_circuit_depth = depth
1035 return depth
1037 def draw(
1038 self,
1039 inputs: Optional[jnp.ndarray] = None,
1040 figure: str = "text",
1041 **kwargs: Any,
1042 ) -> Union[str, Any]:
1043 """Visualize the quantum circuit.
1045 Records the circuit tape (without noise) and renders the gate
1046 sequence using the requested backend.
1048 Args:
1049 inputs (Optional[jnp.ndarray]): Input data for the circuit.
1050 If ``None``, default zero inputs are used.
1051 figure (str): Rendering backend. One of:
1053 * ``"text"`` - ASCII art (returned as a ``str``).
1054 * ``"mpl"`` - Matplotlib figure (returns ``(fig, ax)``).
1055 * ``"tikz"`` - LaTeX/TikZ ``quantikz`` code (returns a
1056 :class:`TikzFigure`).
1057 * ``"pulse"`` - Pulse schedule (returns ``(fig, axes)``).
1058 Only meaningful for pulse-mode models.
1060 **kwargs: Extra options forwarded to the drawing backend
1061 (e.g. ``gate_values=True``).
1063 Returns:
1064 Depends on figure:
1066 * ``"text"`` -> ``str``
1067 * ``"mpl"`` -> ``(matplotlib.figure.Figure, matplotlib.axes.Axes)``
1068 * ``"tikz"`` -> :class:`TikzFigure`
1070 Raises:
1071 ValueError: If figure is not one of the supported modes.
1072 """
1073 inputs = self._inputs_validation(inputs)
1074 params = self.params[0] if self.params.ndim == 3 else self.params
1075 inp = inputs[0] if inputs.ndim == 2 else inputs
1077 if figure == "pulse":
1078 return self.draw_pulse(inputs=inputs, **kwargs)
1080 # Record without noise to get a clean circuit
1081 saved_noise = self._noise_params
1082 self._noise_params = None
1084 draw_script = ys.Script(f=self._variational, n_qubits=self.n_qubits)
1085 result = draw_script.draw(
1086 figure=figure,
1087 args=(params, inp),
1088 kwargs={"noise_params": None},
1089 **kwargs,
1090 )
1092 self._noise_params = saved_noise
1093 return result
1095 def draw_pulse(
1096 self,
1097 inputs: Optional[jnp.ndarray] = None,
1098 **kwargs: Any,
1099 ) -> Any:
1100 """Visualize the pulse schedule for the circuit.
1102 Records the circuit in pulse mode and collects PulseEvents
1103 automatically via the pulse-event tape, then renders them.
1105 Args:
1106 inputs: Input data. If ``None``, default zero inputs are used.
1107 **kwargs: Forwarded to
1108 :func:`~qml_essentials.drawing.draw_pulse_schedule`
1109 (e.g. ``show_carrier=True``, ``n_samples=300``).
1111 Returns:
1112 ``(fig, axes)`` — Matplotlib Figure and array of Axes.
1113 """
1114 inputs = self._inputs_validation(inputs)
1115 params = self.params[0] if self.params.ndim == 3 else self.params
1116 inp = inputs[0] if inputs.ndim == 2 else inputs
1118 draw_script = ys.Script(f=self._variational, n_qubits=self.n_qubits)
1119 return draw_script.draw(
1120 figure="pulse",
1121 args=(params, inp),
1122 kwargs={
1123 "gate_mode": "pulse",
1124 "noise_params": None,
1125 },
1126 **kwargs,
1127 )
1129 def __repr__(self) -> str:
1130 """Return text representation of the quantum circuit model."""
1131 return self.draw(figure="text")
1133 def __str__(self) -> str:
1134 """Return string representation of the quantum circuit model."""
1135 return self.draw(figure="text")
1137 def _params_validation(self, params: Optional[jnp.ndarray]) -> jnp.ndarray:
1138 """
1139 Validate and normalize variational parameters.
1141 Ensures parameters have the correct shape with a batch dimension,
1142 and updates the model's internal parameters if new ones are provided.
1144 Args:
1145 params (Optional[jnp.ndarray]): Variational parameters to validate.
1146 If None, returns the model's current parameters.
1148 Returns:
1149 jnp.ndarray: Validated parameters with shape
1150 (batch_size, n_layers, n_params_per_layer).
1151 """
1152 # append batch axis if not provided
1153 if params is not None:
1154 if len(params.shape) == 2:
1155 params = np.expand_dims(params, axis=0)
1157 self.params = params
1158 else:
1159 params = self.params
1161 return params
1163 def _pulse_params_validation(
1164 self, pulse_params: Optional[jnp.ndarray]
1165 ) -> jnp.ndarray:
1166 """
1167 Validate and normalize pulse parameters.
1169 Ensures pulse parameters are set, using model defaults if not provided.
1171 Args:
1172 pulse_params (Optional[jnp.ndarray]): Pulse parameter scalers.
1173 If None, returns the model's current pulse parameters.
1175 Returns:
1176 jnp.ndarray: Validated pulse parameters with shape
1177 (batch_size, n_layers, n_pulse_params_per_layer).
1178 """
1179 if pulse_params is None:
1180 pulse_params = self.pulse_params
1181 else:
1182 # ensure batch dimension exists (batch-first convention)
1183 if len(pulse_params.shape) == 2:
1184 pulse_params = jnp.expand_dims(pulse_params, axis=0)
1185 self.pulse_params = pulse_params
1187 return pulse_params
1189 def _enc_params_validation(self, enc_params: Optional[jnp.ndarray]) -> jnp.ndarray:
1190 """
1191 Validate and normalize encoding parameters.
1193 Ensures encoding parameters have the correct shape for the model's
1194 input feature dimensions.
1196 Args:
1197 enc_params (Optional[jnp.ndarray]): Encoding parameters to validate.
1198 If None, returns the model's current encoding parameters.
1200 Returns:
1201 jnp.ndarray: Validated encoding parameters with shape
1202 (n_qubits, n_input_feat).
1204 Raises:
1205 ValueError: If enc_params shape is incompatible with n_input_feat > 1.
1206 """
1207 if enc_params is None:
1208 enc_params = self.enc_params
1209 else:
1210 if self.trainable_frequencies:
1211 self.enc_params = enc_params
1212 else:
1213 self.enc_params = jnp.array(enc_params)
1215 if len(enc_params.shape) == 1 and self.n_input_feat == 1:
1216 enc_params = enc_params.reshape(-1, 1)
1217 elif len(enc_params.shape) == 1 and self.n_input_feat > 1:
1218 raise ValueError(
1219 f"Input dimension {self.n_input_feat} >1 but \
1220 `enc_params` has shape {enc_params.shape}"
1221 )
1223 return enc_params
1225 def _inputs_validation(
1226 self, inputs: Union[None, List, float, int, jnp.ndarray]
1227 ) -> jnp.ndarray:
1228 """
1229 Validate and normalize input data.
1231 Converts various input formats to a standardized 2D array shape
1232 suitable for batch processing in the quantum circuit.
1234 Args:
1235 inputs (Union[None, List, float, int, jnp.ndarray]): Input data in
1236 various formats:
1237 - None: Returns zeros with shape (1, n_input_feat)
1238 - float/int: Single scalar value
1239 - List: List of values or batched inputs
1240 - jnp.ndarray: NumPy/JAX array
1242 Returns:
1243 jnp.ndarray: Validated inputs with shape (batch_size, n_input_feat).
1245 Raises:
1246 ValueError: If input shape is incompatible with expected n_input_feat.
1248 Warns:
1249 UserWarning: If input is replicated to match n_input_feat.
1250 """
1251 self._zero_inputs = False
1252 if isinstance(inputs, List):
1253 inputs = jnp.array(np.stack(inputs))
1254 elif isinstance(inputs, float) or isinstance(inputs, int):
1255 inputs = jnp.array([inputs])
1256 elif inputs is None:
1257 inputs = jnp.array([[0] * self.n_input_feat])
1259 if not inputs.any():
1260 self._zero_inputs = True
1262 if len(inputs.shape) <= 1:
1263 if self.n_input_feat == 1:
1264 # add a batch dimension
1265 inputs = inputs.reshape(-1, 1)
1266 else:
1267 if inputs.shape[0] == self.n_input_feat:
1268 inputs = inputs.reshape(1, -1)
1269 else:
1270 inputs = inputs.reshape(-1, 1)
1271 inputs = inputs.repeat(self.n_input_feat, axis=1)
1272 warnings.warn(
1273 f"Expected {self.n_input_feat} inputs, but {inputs.shape[0]} "
1274 "was provided, replicating input for all input features.",
1275 UserWarning,
1276 )
1277 else:
1278 if inputs.shape[1] != self.n_input_feat:
1279 raise ValueError(
1280 f"Wrong number of inputs provided. Expected {self.n_input_feat} "
1281 f"inputs, but input has shape {inputs.shape}."
1282 )
1284 return inputs
1286 def _postprocess_res(self, result: Union[List, jnp.ndarray]) -> jnp.ndarray:
1287 """
1288 Post-process circuit execution results for uniform shape.
1290 Converts list outputs (from multiple measurements) to stacked arrays
1291 and reorders axes for consistent batch dimension placement.
1293 Args:
1294 result (Union[List, jnp.ndarray]): Raw circuit output, either a
1295 list of measurement results or a single array.
1297 Returns:
1298 jnp.ndarray: Uniformly shaped result array with batch dimension first.
1299 """
1300 if isinstance(result, list):
1301 # we use moveaxis here because in case of parity measure,
1302 # there is another dimension appended to the end and
1303 # simply transposing would result in a wrong shape
1304 result = jnp.stack(result)
1305 if len(result.shape) > 1:
1306 result = jnp.moveaxis(result, 0, 1)
1307 return result
1309 def _assimilate_batch(
1310 self,
1311 inputs: jnp.ndarray,
1312 params: jnp.ndarray,
1313 pulse_params: jnp.ndarray,
1314 ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
1315 """
1316 Align batch dimensions across inputs, parameters, and pulse parameters.
1318 Broadcasts and reshapes arrays to have compatible batch dimensions
1319 for vectorized circuit execution. Sets the internal batch_shape.
1321 Args:
1322 inputs (jnp.ndarray): Input data of shape (B_I, n_input_feat).
1323 params (jnp.ndarray): Parameters of shape (B_P, n_layers, n_params).
1324 pulse_params (jnp.ndarray): Pulse params of shape (B_R, n_layers, n_pulse).
1326 Returns:
1327 Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: Tuple containing:
1328 - inputs: Reshaped to (B, n_input_feat) where B = B_I * B_P * B_R
1329 - params: Reshaped to (B, n_layers, n_params)
1330 - pulse_params: Reshaped to (B, n_layers, n_pulse)
1332 Note:
1333 The effective batch shape depends on repeat_batch_axis configuration.
1334 This is the only method that sets self._batch_shape.
1335 """
1336 B_I = inputs.shape[0]
1337 # we check for the product because there is a chance that
1338 # there are no params. In this case we want B_P to be 1
1339 B_P = 1 if 0 in params.shape else params.shape[0]
1340 B_R = pulse_params.shape[0]
1342 # THIS is the only place where we set the batch shape
1343 self._batch_shape = (B_I, B_P, B_R)
1344 B = np.prod(self.eff_batch_shape)
1346 # [B_I, ...] -> [B_I, B_P, B_R, ...] -> [B, ...]
1347 if B_I > 1 and self.repeat_batch_axis[0]:
1348 if self.repeat_batch_axis[1]:
1349 inputs = jnp.repeat(inputs[:, None, None, ...], B_P, axis=1)
1350 if self.repeat_batch_axis[2]:
1351 inputs = jnp.repeat(inputs, B_R, axis=2)
1352 inputs = inputs.reshape(B, *inputs.shape[3:])
1354 # [B_P, ..., ...] -> [B_I, B_P, B_R, ..., ...] -> [B, ..., ...]
1355 if B_P > 1 and self.repeat_batch_axis[1]:
1356 # add B_I axis before first, and B_R axis after first batch dim
1357 params = params[None, :, None, ...] # [B_I(=1), B_P, B_R(=1), ...]
1358 if self.repeat_batch_axis[0]:
1359 params = jnp.repeat(params, B_I, axis=0) # [B_I, B_P, 1, ...]
1360 if self.repeat_batch_axis[2]:
1361 params = jnp.repeat(params, B_R, axis=2) # [B_I, B_P, B_R, ...]
1362 params = params.reshape(B, *params.shape[3:])
1364 # [B_R, ..., ...] -> [B_I, B_P, B_R, ..., ...] -> [B, ..., ...]
1365 if B_R > 1 and self.repeat_batch_axis[2]:
1366 # add B_I axis and B_P axis before B_R
1367 pulse_params = pulse_params[None, None, ...] # [B_I(=1), B_P(=1), B_R, ...]
1368 if self.repeat_batch_axis[0]:
1369 pulse_params = jnp.repeat(
1370 pulse_params, B_I, axis=0
1371 ) # [B_I, 1, B_R, ...]
1372 if self.repeat_batch_axis[1]:
1373 pulse_params = jnp.repeat(
1374 pulse_params, B_P, axis=1
1375 ) # [B_I, B_P, B_R, ...]
1376 pulse_params = pulse_params.reshape(B, *pulse_params.shape[3:])
1378 return inputs, params, pulse_params
1380 def _requires_density(self) -> bool:
1381 """
1382 Check if density matrix simulation is required.
1384 Determines whether the circuit must be executed with the mixed-state
1385 simulator based on execution type and noise configuration.
1387 Returns:
1388 bool: True if density matrix simulation is required, False otherwise.
1389 Returns True if:
1390 - execution_type is "density", or
1391 - Any non-coherent noise channel has non-zero probability
1392 """
1393 if self.execution_type == "density":
1394 return True
1396 if self.noise_params is None:
1397 return False
1399 coherent_noise = {"GateError"}
1400 for k, v in self.noise_params.items():
1401 if k in coherent_noise:
1402 continue
1403 if v is not None and v > 0:
1404 return True
1405 return False
1407 def __call__(
1408 self,
1409 params: Optional[jnp.ndarray] = None,
1410 inputs: Optional[jnp.ndarray] = None,
1411 pulse_params: Optional[jnp.ndarray] = None,
1412 enc_params: Optional[jnp.ndarray] = None,
1413 data_reupload: Union[bool, List[List[bool]], List[List[List[bool]]]] = None,
1414 noise_params: Optional[Dict[str, Union[float, Dict[str, float]]]] = None,
1415 execution_type: Optional[str] = None,
1416 force_mean: bool = False,
1417 gate_mode: str = "unitary",
1418 ) -> jnp.ndarray:
1419 """
1420 Execute the quantum circuit (callable interface).
1422 Provides a convenient callable interface for circuit execution,
1423 delegating to the _forward method.
1425 Args:
1426 params (Optional[jnp.ndarray]): Variational parameters of shape
1427 (n_layers, n_params_per_layer) or (batch, n_layers, n_params_per_layer).
1428 If None, uses model's internal parameters.
1429 inputs (Optional[jnp.ndarray]): Input data of shape
1430 (batch_size, n_input_feat). If None, uses zero inputs.
1431 pulse_params (Optional[jnp.ndarray]): Pulse parameter scalers for
1432 pulse-mode gate execution.
1433 enc_params (Optional[jnp.ndarray]): Encoding parameters of shape
1434 (n_qubits, n_input_feat). If None, uses model's encoding parameters.
1435 data_reupload (Union[bool, List[List[bool]], List[List[List[bool]]]]):
1436 Data reupload configuration. If None, uses previously set reupload
1437 configuration.
1438 noise_params (Optional[Dict[str, Union[float, Dict[str, float]]]]):
1439 Noise configuration. If None, uses previously set noise parameters.
1440 execution_type (Optional[str]): Measurement type: "expval", "density",
1441 "probs", or "state". If None, uses current execution_type setting.
1442 force_mean (bool): If True, averages results over measurement qubits.
1443 Defaults to False.
1444 gate_mode (str): Gate execution backend, "unitary" or "pulse".
1445 Defaults to "unitary".
1447 Returns:
1448 jnp.ndarray: Circuit output with shape depending on execution_type:
1449 - "expval": (n_output_qubits,) or scalar
1450 - "density": (2^n_output, 2^n_output)
1451 - "probs": (2^n_output,) or (n_pairs, 2^pair_size)
1452 - "state": (2^n_qubits,)
1453 """
1454 # Call forward method which handles the actual caching etc.
1455 return self._forward(
1456 params=params,
1457 inputs=inputs,
1458 pulse_params=pulse_params,
1459 enc_params=enc_params,
1460 data_reupload=data_reupload,
1461 noise_params=noise_params,
1462 execution_type=execution_type,
1463 force_mean=force_mean,
1464 gate_mode=gate_mode,
1465 )
    def _forward(
        self,
        params: Optional[jnp.ndarray] = None,
        inputs: Optional[jnp.ndarray] = None,
        pulse_params: Optional[jnp.ndarray] = None,
        enc_params: Optional[jnp.ndarray] = None,
        data_reupload: Union[bool, List[List[bool]], List[List[List[bool]]]] = None,
        noise_params: Optional[Dict[str, Union[float, Dict[str, float]]]] = None,
        execution_type: Optional[str] = None,
        force_mean: bool = False,
        gate_mode: str = "unitary",
    ) -> jnp.ndarray:
        """
        Execute the quantum circuit forward pass.

        Internal implementation of the forward pass. It validates the
        arguments, aligns batch dimensions, executes the circuit (batched or
        not), and post-processes partial-qubit measurements.

        Side effects: ``noise_params``, ``execution_type`` and
        ``data_reupload`` (when not None) and ``gate_mode`` (always) are
        stored on the instance and therefore persist into later calls. The
        validation helpers may also overwrite the stored params, pulse
        params and encoding params. ``self.random_key`` is advanced on every
        call.

        Args:
            params (Optional[jnp.ndarray]): Variational parameters of shape
                (n_layers, n_params_per_layer) or
                (batch, n_layers, n_params_per_layer).
                If None, uses model's internal parameters.
            inputs (Optional[jnp.ndarray]): Input data of shape
                (batch_size, n_input_feat).
                If None, uses zero inputs.
            pulse_params (Optional[jnp.ndarray]): Pulse parameter scalers for
                pulse-mode gate execution.
            enc_params (Optional[jnp.ndarray]): Encoding parameters. If None,
                uses model's encoding parameters.
            data_reupload (Union[bool, List[List[bool]], List[List[List[bool]]]]):
                Data reupload configuration. If None, uses previously set reupload
                configuration.
            noise_params (Optional[Dict[str, Union[float, Dict[str, float]]]]):
                Noise configuration. If None, uses previously set noise parameters.
            execution_type (Optional[str]): Measurement type: "expval", "density",
                "probs", or "state". If None, uses current execution_type setting.
            force_mean (bool): If True, averages results over the last axis
                (measurement qubits) for "expval"/"probs" when more than one
                output is produced. Defaults to False.
            gate_mode (str): Gate execution backend, "unitary" or "pulse".
                Defaults to "unitary".

        Returns:
            jnp.ndarray: Circuit output with shape depending on execution_type
            (all size-1 axes, including batch axes of size 1, are squeezed):
                - "expval": (n_output_qubits,) or scalar
                - "density": (2^n_output, 2^n_output)
                - "probs": (2^n_output,) or (n_pairs, 2^pair_size)
                - "state": (2^n_qubits,)

        Raises:
            ValueError: If pulse_params are provided while gate_mode is not
                "pulse". (No check is performed here for noise_params combined
                with pulse gate_mode.)
        """
        # set the parameters as object attributes
        if noise_params is not None:
            self.noise_params = noise_params
        if execution_type is not None:
            self.execution_type = execution_type
        self.gate_mode = gate_mode

        # consistency checks
        if pulse_params is not None and gate_mode != "pulse":
            raise ValueError(
                "pulse_params were provided but gate_mode is not 'pulse'. "
                "Either switch gate_mode='pulse' or do not pass pulse_params."
            )

        # TODO: add testing
        if data_reupload is not None:
            self.data_reupload = data_reupload

        params = self._params_validation(params)
        pulse_params = self._pulse_params_validation(pulse_params)
        inputs = self._inputs_validation(inputs)
        enc_params = self._enc_params_validation(enc_params)

        # Broadcast the three batch axes into one flat leading batch axis.
        inputs, params, pulse_params = self._assimilate_batch(
            inputs,
            params,
            pulse_params,
        )

        # split to generate a sub_key, required for actual execution
        self.random_key, sub_key = safe_random_split(self.random_key)

        # Build measurement type & observables from execution_type / output_qubit
        meas_type, obs = self._build_obs()

        # Yaqsi auto-routes between statevector and density-matrix simulation
        # based on whether noise channels appear on the tape, so a single
        # execute call below covers both cases.
        # B is the total flattened batch size.
        B = np.prod(self.eff_batch_shape)

        # kwargs are broadcast (not vmapped over)
        exec_kwargs = dict(
            noise_params=self.noise_params,
            gate_mode=self.gate_mode,
        )

        # Build a shot key from the random_key if shots are requested
        shot_key = None
        if self.shots is not None:
            # overwrite subkey and split shot_key
            sub_key, shot_key = safe_random_split(sub_key)

        if B > 1:
            # use random keys, derived from the subkey
            random_keys = safe_random_split(sub_key, num=B)

            # Per positional argument: 0 = batched along the leading axis,
            # None = shared by all batch entries (presumably jax.vmap-style
            # in_axes semantics — TODO confirm in ys.Script.execute).
            in_axes = (
                0 if self.batch_shape[1] > 1 else None,  # params
                0 if self.batch_shape[0] > 1 else None,  # inputs
                0 if self.batch_shape[2] > 1 else None,  # pulse_params
                0,  # random_keys
                None,  # enc_params (broadcast, not batched)
            )

            result = self.script.execute(
                type=meas_type,
                obs=obs,
                args=(params, inputs, pulse_params, random_keys, enc_params),
                kwargs=exec_kwargs,
                in_axes=in_axes,
                shots=self.shots,
                key=shot_key,
            )
        else:
            # use the subkey directly
            result = self.script.execute(
                type=meas_type,
                obs=obs,
                args=(params, inputs, pulse_params, sub_key, enc_params),
                kwargs=exec_kwargs,
                shots=self.shots,
                key=shot_key,
            )

        result = self._postprocess_res(result)

        # --- Post-processing for partial-qubit measurements ---------------
        if self.execution_type == "density" and not self.all_qubit_measurement:
            result = ys.partial_trace(result, self.n_qubits, self.output_qubit)

        if self.execution_type == "probs" and not self.all_qubit_measurement:
            if isinstance(self.output_qubit[0], (list, tuple)):
                # list of qubit groups - marginalize each independently
                result = jnp.stack(
                    [
                        ys.marginalize_probs(result, self.n_qubits, list(group))
                        for group in self.output_qubit
                    ]
                )
            else:
                result = ys.marginalize_probs(result, self.n_qubits, self.output_qubit)

        # Restore the (effective) batch axes, then drop every size-1 axis.
        result = jnp.asarray(result)
        result = result.reshape((*self.eff_batch_shape, *self._result_shape)).squeeze()

        if (
            self.execution_type in ("expval", "probs")
            and force_mean
            and len(result.shape) > 0
            and self._result_shape[0] > 1
        ):
            result = result.mean(axis=-1)

        return result