Coverage for qml_essentials / model.py: 90%
511 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-05-27 15:44 +0000
1from typing import Any, Dict, Optional, Tuple, Callable, Union, List
3from qml_essentials import operations as op
4from qml_essentials import yaqsi as ys
5import warnings
6import jax.numpy as jnp
7import numpy as np
8from jax import random
10from qml_essentials.tape import recording
11from qml_essentials.operations import KrausChannel
12from qml_essentials.ansaetze import Ansaetze, Circuit, Encoding
13from qml_essentials.gates import Gates, PulseInformation as pinfo
14from qml_essentials.utils import safe_random_split
16import logging
# Module-level logger; the Model class below emits its debug/info messages through it.
log = logging.getLogger(__name__)
21class Model:
22 """
23 A quantum circuit model.
24 """
26 def __init__(
27 self,
28 n_qubits: int,
29 n_layers: int,
30 circuit_type: Union[str, Circuit] = "No_Ansatz",
31 data_reupload: Union[bool, List[List[bool]], List[List[List[bool]]]] = True,
32 state_preparation: Union[
33 str, Callable, List[Union[str, Callable]], None
34 ] = None,
35 encoding: Union[Encoding, str, Callable, List[Union[str, Callable]]] = Gates.RX,
36 trainable_frequencies: bool = False,
37 initialization: str = "random",
38 initialization_domain: List[float] = [0, 2 * jnp.pi],
39 output_qubit: Union[List[int], int] = -1,
40 shots: Optional[int] = None,
41 random_seed: int = 1000,
42 remove_zero_encoding: bool = True,
43 repeat_batch_axis: List[bool] = [True, True, True],
44 pulse_shape: str = "gaussian",
45 ) -> None:
46 """
47 Initialize the quantum circuit model.
48 Parameters will have the shape [impl_n_layers, parameters_per_layer]
49 where impl_n_layers is the number of layers provided and added by one
50 depending if data_reupload is True and parameters_per_layer is given by
51 the chosen ansatz.
53 The model is initialized with the following parameters as defaults:
54 - noise_params: None
55 - execution_type: "expval"
56 - shots: None
58 Args:
59 n_qubits (int): The number of qubits in the circuit.
60 n_layers (int): The number of layers in the circuit.
61 circuit_type (str, Circuit): The type of quantum circuit to use.
62 If None, defaults to "no_ansatz".
63 data_reupload (Union[bool, List[bool], List[List[bool]]], optional):
64 Whether to reupload data to the quantum device on each
65 layer and qubit. Detailed re-uploading instructions can be given
66 as a list/array of 0/False and 1/True with shape (n_qubits,
67 n_layers) to specify where to upload the data. Defaults to True
68 for applying data re-uploading to the full circuit.
69 encoding (Union[str, Callable, List[str], List[Callable]], optional):
70 The unitary to use for encoding the input data. Can be a string
71 (e.g. "RX") or a callable (e.g. op.RX). Defaults to op.RX.
72 If input is multidimensional it is assumed to be a list of
73 unitaries or a list of strings.
74 trainable_frequencies (bool, optional):
75 Sets trainable encoding parameters for trainable frequencies.
76 Defaults to False.
77 initialization (str, optional): The strategy to initialize the parameters.
78 Can be "random", "zeros", "zero-controlled", "pi", or "pi-controlled".
79 Defaults to "random".
80 output_qubit (List[int], int, optional): The index of the output
81 qubit (or qubits). When set to -1 all qubits are measured, or a
82 global measurement is conducted, depending on the execution
83 type.
84 shots (Optional[int], optional): The number of shots to use for
85 the quantum device. Defaults to None.
86 random_seed (int, optional): seed for the random number generator
87 in initialization is "random" and for random noise parameters.
88 Defaults to 1000.
89 remove_zero_encoding (bool, optional): whether to
90 remove the zero encoding from the circuit. Defaults to True.
91 repeat_batch_axis (List[bool], optional): Each boolean in the array
92 determines over which axes to parallelise computation. The axes
93 correspond to [inputs, params, pulse_params]. Defaults to
94 [True, True, True], meaning that batching is enabled over all
95 axes.
96 pulse_shape (str, optional): Pulse envelope shape for pulse-level
97 simulation. One of ``PulseEnvelope.available()``.
98 Defaults to ``"gaussian"``.
100 Returns:
101 None
102 """
103 # Initialize default parameters needed for circuit evaluation
104 self.n_qubits: int = n_qubits
105 self.output_qubit: Union[List[int], int] = output_qubit
106 self.n_layers: int = n_layers
107 self.noise_params: Optional[Dict[str, Union[float, Dict[str, float]]]] = None
108 self.shots = shots
109 self.remove_zero_encoding = remove_zero_encoding
110 self.trainable_frequencies: bool = trainable_frequencies
111 self.execution_type: str = "expval"
112 self.repeat_batch_axis: List[bool] = repeat_batch_axis
114 # --- Pulse envelope ---
115 pinfo.set_envelope(pulse_shape)
117 # --- State Preparation ---
118 try:
119 self._sp = Gates.parse_gates(state_preparation, Gates)
120 except ValueError as e:
121 raise ValueError(f"Error parsing encodings: {e}")
123 # prepare corresponding pulse parameters (always optimized pulses)
124 self.sp_pulse_params = []
125 for sp in self._sp:
126 sp_name = sp.__name__ if hasattr(sp, "__name__") else str(sp)
128 if pinfo.gate_by_name(sp_name) is not None:
129 self.sp_pulse_params.append(pinfo.gate_by_name(sp_name).params)
130 else:
131 # gate has no pulse parametrization
132 self.sp_pulse_params.append(None)
134 # --- Encoding ---
135 if isinstance(encoding, Encoding):
136 # user wants custom strategy? do it!
137 self._enc = encoding
138 else:
139 # use hammming encoding by default
140 self._enc = Encoding("hamming", encoding)
142 if self._enc.is_golomb:
143 self._enc._n_qubits = n_qubits
145 # Number of possible inputs
146 self.n_input_feat = len(self._enc)
147 log.debug(f"Number of input features: {self.n_input_feat}")
149 # Trainable frequencies, default initialization as in arXiv:2309.03279v2
150 self.enc_params = jnp.ones((self.n_layers, self.n_qubits, self.n_input_feat))
152 self._zero_inputs = False
154 # --- Data-Reuploading ---
156 # Keep as NumPy array (not JAX) so that ``if data_reupload[q, idx]``
157 # in _iec remains a concrete Python bool even under jax.jit tracing.
158 # note that setting this will also update self.degree and self.frequencies
159 # and in consequence also self.has_dru
160 self.data_reupload = data_reupload
162 # check for the highest degree among all input dimensions
163 if self.has_dru:
164 impl_n_layers: int = n_layers + 1 # we need L+1 according to Schuld et al.
165 else:
166 impl_n_layers = n_layers
167 log.info(f"Number of implicit layers: {impl_n_layers}.")
169 # --- Ansatz ---
170 # only weak check for str. We trust the user to provide sth useful
171 if isinstance(circuit_type, str):
172 self.pqc: Callable[[Optional[jnp.ndarray], int], int] = getattr(
173 Ansaetze, circuit_type or "No_Ansatz"
174 )()
175 else:
176 self.pqc = circuit_type()
177 log.info(f"Using Ansatz {circuit_type}.")
179 # calculate the shape of the parameter vector here, we will re-use this in init.
180 params_per_layer = self.pqc.n_params_per_layer(self.n_qubits)
181 self._params_shape: Tuple[int, int] = (impl_n_layers, params_per_layer)
182 log.info(f"Parameters per layer: {params_per_layer}")
184 pulse_params_per_layer = self.pqc.n_pulse_params_per_layer(self.n_qubits)
185 self._pulse_params_shape: Tuple[int, int] = (
186 impl_n_layers,
187 pulse_params_per_layer,
188 )
190 # intialize to None as we can't know this yet
191 self._batch_shape = None
193 # this will also be re-used in the init method,
194 # however, only if nothing is provided
195 self._inialization_strategy = initialization
196 self._initialization_domain = initialization_domain
198 # ..here! where we only require a JAX random key
199 self.random_key = self.initialize_params(random.key(random_seed))
201 # Initializing pulse params
202 self.pulse_params: jnp.ndarray = jnp.ones((1, *self._pulse_params_shape))
204 log.info(f"Initialized pulse parameters with shape {self.pulse_params.shape}.")
206 # Initialise the yaqsi Script that wraps _variational.
207 # No device selection needed - yaqsi auto-routes between statevector
208 # and density-matrix simulation based on whether noise channels are
209 # present on the tape.
210 self.script = ys.Script(f=self._variational, n_qubits=self.n_qubits)
212 @property
213 def noise_params(self) -> Optional[Dict[str, Union[float, Dict[str, float]]]]:
214 """
215 Gets the noise parameters of the model.
217 Returns:
218 Optional[Dict[str, float]]: A dictionary of
219 noise parameters or None if not set.
220 """
221 return self._noise_params
223 @noise_params.setter
224 def noise_params(
225 self, kvs: Optional[Dict[str, Union[float, Dict[str, float]]]]
226 ) -> None:
227 """
228 Sets the noise parameters of the model.
230 Typically a "noise parameter" refers to the error probability.
231 ThermalRelaxation is a special case, and supports a dict as value with
232 structure:
233 "ThermalRelaxation":
234 {
235 "t1": 2000, # relative t1 time.
236 "t2": 1000, # relative t2 time
237 "t_factor" 1: # relative gate time factor
238 },
240 Args:
241 kvs (Optional[Dict[str, Union[float, Dict[str, float]]]]): A
242 dictionary of noise parameters. If all values are 0.0, the noise
243 parameters are set to None.
245 Returns:
246 None
247 """
248 # set to None if only zero values provided
249 if kvs is not None and all(v == 0.0 for v in kvs.values()):
250 kvs = None
252 # set default values
253 if kvs is not None:
254 defaults = {
255 "BitFlip": 0.0,
256 "PhaseFlip": 0.0,
257 "Depolarizing": 0.0,
258 "MultiQubitDepolarizing": 0.0,
259 "AmplitudeDamping": 0.0,
260 "PhaseDamping": 0.0,
261 "GateError": 0.0,
262 "ThermalRelaxation": None,
263 "StatePreparation": 0.0,
264 "Measurement": 0.0,
265 }
266 for key, default_val in defaults.items():
267 kvs.setdefault(key, default_val)
269 # check if there are any keys not supported
270 for key in kvs.keys():
271 if key not in defaults:
272 warnings.warn(
273 f"Noise type {key} is not supported by this package",
274 UserWarning,
275 )
277 # check valid params for thermal relaxation noise channel
278 tr_params = kvs["ThermalRelaxation"]
279 if isinstance(tr_params, dict):
280 tr_params.setdefault("t1", 0.0)
281 tr_params.setdefault("t2", 0.0)
282 tr_params.setdefault("t_factor", 0.0)
283 valid_tr_keys = {"t1", "t2", "t_factor"}
284 for k in tr_params.keys():
285 if k not in valid_tr_keys:
286 warnings.warn(
287 f"Thermal Relaxation parameter {k} is not supported "
288 f"by this package",
289 UserWarning,
290 )
291 if not all(tr_params.values()) or tr_params["t2"] > 2 * tr_params["t1"]:
292 warnings.warn(
293 "Received invalid values for Thermal Relaxation noise "
294 "parameter. Thermal relaxation is not applied!",
295 UserWarning,
296 )
297 kvs["ThermalRelaxation"] = 0.0
299 self._noise_params = kvs
301 @property
302 def output_qubit(self) -> List[int]:
303 """Get the output qubit indices for measurement."""
304 return self._output_qubit
306 @output_qubit.setter
307 def output_qubit(self, value: Union[int, List[int]]) -> None:
308 """
309 Set the output qubit(s) for measurement.
311 Args:
312 value: Qubit index or list of indices. Use -1 for all qubits.
313 """
314 if isinstance(value, list):
315 assert len(value) <= self.n_qubits, (
316 f"Size of output_qubit {len(value)} cannot be\
317 larger than number of qubits {self.n_qubits}."
318 )
319 elif isinstance(value, int):
320 if value == -1:
321 value = list(range(self.n_qubits))
322 else:
323 assert value < self.n_qubits, (
324 f"Output qubit {value} cannot be larger than {self.n_qubits}."
325 )
326 value = [value]
328 self._output_qubit = value
330 @property
331 def execution_type(self) -> str:
332 """
333 Gets the execution type of the model.
335 Returns:
336 str: The execution type, one of 'density', 'expval', or 'probs'.
337 """
338 return self._execution_type
340 @execution_type.setter
341 def execution_type(self, value: str) -> None:
342 if value == "density":
343 self._result_shape = (
344 2 ** len(self.output_qubit),
345 2 ** len(self.output_qubit),
346 )
347 elif value == "expval":
348 # check if all qubits are used
349 if len(self.output_qubit) == self.n_qubits:
350 self._result_shape = (len(self.output_qubit),)
351 # if not -> parity measurement with only 1D output per pair
352 # or n_local measurement
353 else:
354 self._result_shape = (len(self.output_qubit),)
355 elif value == "probs":
356 # in case this is a list of parities,
357 # each pair has 2^len(qubits) probabilities
358 n_parity = (
359 (2,) * len(self.output_qubit)
360 if isinstance(self.output_qubit, (Tuple, List))
361 else (2,)
362 )
363 self._result_shape = n_parity
364 elif value == "state":
365 self._result_shape = (2 ** len(self.output_qubit),)
366 else:
367 raise ValueError(f"Invalid execution type: {value}.")
369 if value == "state" and not self.all_qubit_measurement:
370 warnings.warn(
371 f"{value} measurement does ignore output_qubit, which is "
372 f"{self.output_qubit}.",
373 UserWarning,
374 )
376 if value == "probs" and self.shots is None:
377 warnings.warn(
378 "Setting execution_type to probs without specifying shots.",
379 UserWarning,
380 )
382 if value == "density" and self.shots is not None:
383 raise ValueError("Setting execution_type to density with shots not None.")
385 self._execution_type = value
387 @property
388 def shots(self) -> Optional[int]:
389 """
390 Gets the number of shots to use for the quantum device.
392 Returns:
393 Optional[int]: The number of shots.
394 """
395 return self._shots
397 @shots.setter
398 def shots(self, value: Optional[int]) -> None:
399 """
400 Sets the number of shots to use for the quantum device.
402 Args:
403 value (Optional[int]): The number of shots.
404 If an integer less than or equal to 0 is provided, it is set to None.
406 Returns:
407 None
408 """
409 if type(value) is int and value <= 0:
410 value = None
411 self._shots = value
413 @property
414 def params(self) -> jnp.ndarray:
415 """Get the variational parameters of the model."""
416 return self._params
418 @params.setter
419 def params(self, value: jnp.ndarray) -> None:
420 """Set the variational parameters, ensuring batch dimension exists."""
421 if len(value.shape) == 2:
422 value = value.reshape(1, *value.shape)
424 self._params = value
426 @property
427 def enc_params(self) -> jnp.ndarray:
428 """Get the encoding parameters used for input transformation."""
429 return self._enc_params
431 @enc_params.setter
432 def enc_params(self, value: jnp.ndarray) -> None:
433 """Set the encoding parameters."""
434 self._enc_params = value
436 @property
437 def pulse_params(self) -> jnp.ndarray:
438 """Get the pulse parameters for pulse-mode gate execution."""
439 return self._pulse_params
441 @pulse_params.setter
442 def pulse_params(self, value: jnp.ndarray) -> None:
443 """Set the pulse parameters."""
444 self._pulse_params = value
446 @property
447 def data_reupload(self) -> jnp.ndarray:
448 """Get the data reupload mask."""
449 return self._data_reupload
451 @data_reupload.setter
452 def data_reupload(self, value: jnp.ndarray) -> None:
453 """Set the data reupload mask.
455 Always converts to a concrete NumPy boolean array so that
456 ``if data_reupload[q, idx]`` in :meth:`_iec` remains a plain
457 Python ``bool`` even inside JAX-traced functions (jit / grad / vmap).
458 """
459 # Process data reuploading strategy and set degree
460 if not isinstance(value, bool):
461 if not isinstance(value, np.ndarray):
462 value = np.array(value)
464 if len(value.shape) == 2:
465 assert value.shape == (
466 self.n_layers,
467 self.n_qubits,
468 ), (
469 f"Data reuploading array has wrong shape. \
470 Expected {(self.n_layers, self.n_qubits)} or\
471 {(self.n_layers, self.n_qubits, self.n_input_feat)},\
472 got {value.shape}."
473 )
474 value = value.reshape(*value.shape, 1)
475 value = np.repeat(value, self.n_input_feat, axis=2)
477 assert value.shape == (
478 self.n_layers,
479 self.n_qubits,
480 self.n_input_feat,
481 ), (
482 f"Data reuploading array has wrong shape. \
483 Expected {(self.n_layers, self.n_qubits, self.n_input_feat)},\
484 got {value.shape}."
485 )
487 log.debug(f"Data reuploading array:\n{value}")
488 else:
489 if value:
490 value = np.ones((self.n_layers, self.n_qubits, self.n_input_feat))
491 log.debug("Full data reuploading.")
492 else:
493 value = np.zeros((self.n_layers, self.n_qubits, self.n_input_feat))
494 value[0][0] = 1
495 log.debug("No data reuploading.")
497 # convert to boolean values
498 self._data_reupload = np.asarray(value).astype(bool)
500 self.degree: Tuple = tuple(
501 self._enc.get_n_freqs(np.count_nonzero(self.data_reupload[..., i]))
502 for i in range(self.n_input_feat)
503 )
505 self.frequencies: Tuple = tuple(
506 self._enc.get_spectrum(np.count_nonzero(self.data_reupload[..., i]))
507 for i in range(self.n_input_feat)
508 )
510 # Cache has_dru as a plain Python bool so that it can be used in
511 # Python ``if`` statements even inside JAX-traced functions.
512 self._has_dru: bool = bool(max(int(np.max(f)) for f in self._frequencies) > 1)
514 @property
515 def degree(self) -> Tuple:
516 """Get the degree of the model."""
517 return self._degree
519 @degree.setter
520 def degree(self, value: Tuple):
521 self._degree = value
523 @property
524 def frequencies(self) -> Tuple:
525 """Get the frequencies of the model."""
526 return self._frequencies
528 @frequencies.setter
529 def frequencies(self, value: Tuple):
530 self._frequencies = value
532 @property
533 def has_dru(self) -> bool:
534 """Check if the model has data reupload."""
535 return self._has_dru
537 @property
538 def all_qubit_measurement(self) -> bool:
539 """Check if measurement is performed on all qubits."""
540 return self.output_qubit == list(range(self.n_qubits))
542 @property
543 def batch_shape(self) -> Tuple[int, ...]:
544 """
545 Get the batch shape (B_I, B_P, B_R).
546 If the model was not called before,
547 it returns (1, 1, 1).
549 Returns:
550 Tuple[int, ...]: Tuple of (input_batch, param_batch, pulse_batch).
551 Returns (1, 1, 1) if model has not been called yet.
552 """
553 if self._batch_shape is None:
554 log.debug("Model was not called yet. Returning (1,1,1) as batch shape.")
555 return (1, 1, 1)
556 return self._batch_shape
558 @property
559 def eff_batch_shape(self) -> Tuple[int, ...]:
560 """
561 Get the effective batch shape after applying repeat_batch_axis mask.
563 Returns:
564 Tuple[int, ...]: Effective batch dimensions, excluding zeros.
565 """
566 batch_shape = np.array(self.batch_shape) * self.repeat_batch_axis
567 batch_shape = batch_shape[batch_shape != 0]
568 return batch_shape
570 def initialize_params(
571 self,
572 random_key: Optional[random.PRNGKey] = None,
573 repeat: int = 1,
574 initialization: Optional[str] = None,
575 initialization_domain: Optional[List[float]] = None,
576 ) -> random.PRNGKey:
577 """
578 Initialize the variational parameters of the model.
580 Args:
581 random_key (Optional[random.PRNGKey]): JAX random key for initialization.
582 If None, uses the model's internal random key.
583 repeat (int): Number of parameter sets to create (batch dimension).
584 Defaults to 1.
585 initialization (Optional[str]): Strategy for parameter initialization.
586 Options: "random", "zeros", "pi", "zero-controlled", "pi-controlled".
587 If None, uses the strategy specified in the constructor.
588 initialization_domain (Optional[List[float]]): Domain [min, max] for
589 random initialization. If None, uses the domain from constructor.
591 Returns:
592 random.PRNGKey: Updated random key after initialization.
594 Raises:
595 Exception: If an invalid initialization method is specified.
596 """
597 # Initializing params
598 params_shape = (repeat, *self._params_shape)
600 # use existing strategy if not specified
601 initialization = initialization or self._inialization_strategy
602 initialization_domain = initialization_domain or self._initialization_domain
604 random_key, sub_key = safe_random_split(
605 random_key if random_key is not None else self.random_key
606 )
608 def set_control_params(params: jnp.ndarray, value: float) -> jnp.ndarray:
609 indices = self.pqc.get_control_indices(self.n_qubits)
610 if indices is None:
611 warnings.warn(
612 f"Specified {initialization} but circuit\
613 does not contain controlled rotation gates.\
614 Parameters are intialized randomly.",
615 UserWarning,
616 )
617 else:
618 np_params = np.array(params)
619 np_params[:, :, indices[0] : indices[1] : indices[2]] = (
620 np.ones_like(params[:, :, indices[0] : indices[1] : indices[2]])
621 * value
622 )
623 params = jnp.array(np_params)
624 return params
626 if initialization == "random":
627 self.params: jnp.ndarray = random.uniform(
628 sub_key,
629 params_shape,
630 minval=initialization_domain[0],
631 maxval=initialization_domain[1],
632 )
633 elif initialization == "zeros":
634 self.params: jnp.ndarray = jnp.zeros(params_shape)
635 elif initialization == "pi":
636 self.params: jnp.ndarray = jnp.ones(params_shape) * jnp.pi
637 elif initialization == "zero-controlled":
638 self.params: jnp.ndarray = random.uniform(
639 sub_key,
640 params_shape,
641 minval=initialization_domain[0],
642 maxval=initialization_domain[1],
643 )
644 self.params = set_control_params(self.params, 0)
645 elif initialization == "pi-controlled":
646 self.params: jnp.ndarray = random.uniform(
647 sub_key,
648 params_shape,
649 minval=initialization_domain[0],
650 maxval=initialization_domain[1],
651 )
652 self.params = set_control_params(self.params, jnp.pi)
653 else:
654 raise Exception("Invalid initialization method")
656 log.info(
657 f"Initialized parameters with shape {self.params.shape}\
658 using strategy {initialization}."
659 )
661 return random_key
663 def transform_input(
664 self, inputs: jnp.ndarray, enc_params: jnp.ndarray
665 ) -> jnp.ndarray:
666 """
667 Transform input data by scaling with encoding parameters.
669 Implements the input transformation as described in arXiv:2309.03279v2,
670 where inputs are linearly scaled by encoding parameters before being
671 used in the quantum circuit.
673 Args:
674 inputs (jnp.ndarray): Input data point of shape (n_input_feat,) or
675 (batch_size, n_input_feat).
676 enc_params (jnp.ndarray): Encoding weight scalar or vector used to
677 scale the input.
679 Returns:
680 jnp.ndarray: Transformed input, element-wise product of inputs
681 and enc_params.
682 """
683 return inputs * enc_params
    def _iec(
        self,
        inputs: jnp.ndarray,
        data_reupload: jnp.ndarray,
        enc: Encoding,
        enc_params: jnp.ndarray,
        noise_params: Optional[Dict[str, Union[float, Dict[str, float]]]] = None,
        random_key: Optional[random.PRNGKey] = None,
    ) -> None:
        """
        Apply one Input Encoding Circuit (IEC) layer with angle encoding.

        Encodes classical input data into the quantum circuit using the gates
        held by ``enc`` (e.g. RX, RY, RZ). Only positions flagged in
        ``data_reupload`` receive an encoding gate. Each applied gate gets its
        own sub-key split off ``random_key``.

        For Golomb encoding, a single multi-qubit gate is applied to all
        qubits simultaneously instead of per-qubit rotation gates. It is
        scaled by the mean of the layer's encoding parameters for feature 0.

        The whole layer is skipped when ``remove_zero_encoding`` is set,
        ``self._zero_inputs`` is True and the input batch size is 1.

        Args:
            inputs (jnp.ndarray): Input data of shape (n_input_feat,) or
                (batch_size, n_input_feat); the last axis is the feature axis.
            data_reupload (jnp.ndarray): Boolean mask of shape
                (n_qubits, n_input_feat) for this layer indicating where to
                apply encoding gates.
            enc (Encoding): Encoding strategy; ``enc[idx]`` is the gate used
                for input feature ``idx``.
            enc_params (jnp.ndarray): Encoding parameters of shape
                (n_qubits, n_input_feat) used to scale inputs.
            noise_params (Optional[Dict[str, Union[float, Dict[str, float]]]]):
                Noise parameters for gate-level noise simulation. Defaults to None.
            random_key (Optional[random.PRNGKey]): JAX random key for stochastic
                noise. Defaults to None.

        Returns:
            None: Gates are applied in-place to the quantum circuit.
        """
        # Skip encoding when the inputs are flagged as all-zero (flag presumably
        # set by the caller before tracing) and only a single input is batched.
        if self.remove_zero_encoding and self._zero_inputs and self.batch_shape[0] == 1:
            return

        # --- Golomb encoding: single multi-qubit gate on all qubits --------
        if enc.is_golomb:
            idx = 0  # Golomb encoding supports a single input feature
            # Check if any qubit has re-uploading enabled for this layer
            if data_reupload[:, idx].any():
                random_key, sub_key = safe_random_split(random_key)
                # Use the mean of enc_params across qubits as scalar scaling
                # (Golomb acts on all qubits jointly)
                mean_enc_param = jnp.mean(enc_params[:, idx])
                all_wires = list(range(self.n_qubits))
                enc[idx](
                    self.transform_input(inputs[..., idx], mean_enc_param),
                    wires=all_wires,
                    noise_params=noise_params,
                    random_key=sub_key,
                    input_idx=idx,
                )
            return

        # --- Standard per-qubit encoding -----------------------------------
        for q in range(self.n_qubits):
            # use the last dimension of the inputs (feature dimension)
            for idx in range(inputs.shape[-1]):
                # data_reupload is a NumPy bool array, so this is a plain
                # Python condition even while JAX is tracing
                if data_reupload[q, idx]:
                    # use ellipsis to index only the last dimension
                    # as inputs are generally *not* qubit dependent
                    random_key, sub_key = safe_random_split(random_key)
                    enc[idx](
                        self.transform_input(inputs[..., idx], enc_params[q, idx]),
                        wires=q,
                        noise_params=noise_params,
                        random_key=sub_key,
                        input_idx=idx,
                    )
759 def _variational(
760 self,
761 params: jnp.ndarray,
762 inputs: jnp.ndarray,
763 pulse_params: Optional[jnp.ndarray] = None,
764 random_key: Optional[random.PRNGKey] = None,
765 enc_params: Optional[jnp.ndarray] = None,
766 gate_mode: str = "unitary",
767 noise_params: Optional[Dict[str, Union[float, Dict[str, float]]]] = None,
768 ) -> None:
769 """
770 Build the variational quantum circuit structure.
772 Constructs the circuit by applying state preparation, alternating
773 variational ansatz layers with input encoding layers, and optional
774 noise channels.
776 The first five parameters (after ``self``) - ``params``, ``inputs``,
777 ``pulse_params``, ``random_key``, ``enc_params`` - are the batchable
778 positional arguments.
779 The remaining keyword arguments are broadcast across the batch.
781 Args:
782 params (jnp.ndarray): Variational parameters of shape
783 (n_layers, n_params_per_layer).
784 inputs (jnp.ndarray): Input data of shape (n_input_feat,).
785 pulse_params (Optional[jnp.ndarray]): Pulse parameter scalers of shape
786 (n_layers, n_pulse_params_per_layer) for pulse-mode execution.
787 Defaults to None (uses model's pulse_params).
788 random_key (Optional[random.PRNGKey]): JAX random key for stochastic
789 operations. Defaults to None.
790 enc_params (Optional[jnp.ndarray]): Encoding parameters of shape
791 (n_qubits, n_input_feat). Defaults to None (uses model's enc_params).
792 gate_mode (str): Gate execution mode, either "unitary" or "pulse".
793 Defaults to "unitary".
794 noise_params (Optional[Dict[str, Union[float, Dict[str, float]]]]):
795 Noise parameters for simulation. Defaults to None.
797 Returns:
798 None: Gates are applied in-place to the quantum circuit.
800 Note:
801 Issues RuntimeWarning if called directly without providing parameters
802 that would normally be passed through the forward method.
803 """
804 # TODO: rework and double check params shape
805 if len(params.shape) > 2 and params.shape[0] == 1:
806 params = params[0]
808 if len(inputs.shape) > 1 and inputs.shape[0] == 1:
809 inputs = inputs[0]
811 if enc_params is None:
812 # TODO: Raise warning if trainable frequencies is True, or similar. I.e., no
813 # warning if user does not care for frequencies or enc_params
814 if self.trainable_frequencies:
815 warnings.warn(
816 "Explicit call to `_circuit` or `_variational` detected: "
817 "`enc_params` is None, using `self.enc_params` instead.",
818 RuntimeWarning,
819 )
820 enc_params = self.enc_params
822 if pulse_params is None:
823 if gate_mode == "pulse":
824 warnings.warn(
825 "Explicit call to `_circuit` or `_variational` detected: "
826 "`pulse_params` is None, using `self.pulse_params` instead.",
827 RuntimeWarning,
828 )
829 pulse_params = self.pulse_params
831 # Squeeze batch dimension for pulse_params (batch-first convention)
832 if len(pulse_params.shape) > 2 and pulse_params.shape[0] == 1:
833 pulse_params = pulse_params[0]
835 if noise_params is None:
836 if self.noise_params is not None:
837 warnings.warn(
838 "Explicit call to `_circuit` or `_variational` detected: "
839 "`noise_params` is None, using `self.noise_params` instead.",
840 RuntimeWarning,
841 )
842 noise_params = self.noise_params
844 if noise_params is not None:
845 if random_key is None:
846 warnings.warn(
847 "Explicit call to `_circuit` or `_variational` detected: "
848 "`random_key` is None, using `random.PRNGKey(0)` instead.",
849 RuntimeWarning,
850 )
851 random_key = self.random_key
852 self._apply_state_prep_noise(noise_params=noise_params)
854 # state preparation
855 for q in range(self.n_qubits):
856 for _sp, sp_pulse_params in zip(self._sp, self.sp_pulse_params):
857 random_key, sub_key = safe_random_split(random_key)
858 _sp(
859 wires=q,
860 pulse_params=sp_pulse_params,
861 noise_params=noise_params,
862 random_key=sub_key,
863 gate_mode=gate_mode,
864 )
866 # circuit building
867 for layer in range(0, self.n_layers):
868 random_key, sub_key = safe_random_split(random_key)
869 # ansatz layers
870 self.pqc(
871 params[layer],
872 self.n_qubits,
873 pulse_params=pulse_params[layer],
874 noise_params=noise_params,
875 random_key=sub_key,
876 gate_mode=gate_mode,
877 )
879 random_key, sub_key = safe_random_split(random_key)
880 # encoding layers
881 self._iec(
882 inputs,
883 data_reupload=self.data_reupload[layer],
884 enc=self._enc,
885 enc_params=enc_params[layer],
886 noise_params=noise_params,
887 random_key=sub_key,
888 )
890 # visual barrier (no-op in yaqsi, purely cosmetic in PennyLane)
892 # final ansatz layer
893 if self.has_dru: # same check as in init
894 random_key, sub_key = safe_random_split(random_key)
895 self.pqc(
896 params[self.n_layers],
897 self.n_qubits,
898 pulse_params=pulse_params[-1],
899 noise_params=noise_params,
900 random_key=sub_key,
901 gate_mode=gate_mode,
902 )
904 # channel noise
905 if noise_params is not None:
906 self._apply_general_noise(noise_params=noise_params)
908 def _build_obs(self) -> Tuple[str, List[op.Operation]]:
909 """Build the yaqsi measurement type and observable list.
911 Translates the model's ``execution_type`` and ``output_qubit``
912 settings into parameters suitable for
913 :meth:`~qml_essentials.yaqsi.Script.execute`.
915 Returns:
916 Tuple ``(meas_type, obs)`` where *meas_type* is one of
917 ``"expval"``, ``"probs"``, ``"density"``, ``"state"`` and *obs*
918 is a (possibly empty) list of :class:`Operation` observables.
919 """
920 if self.execution_type == "density":
921 return "density", []
923 if self.execution_type == "state":
924 return "state", []
926 if self.execution_type == "expval":
927 obs: List[op.Operation] = []
928 for qubit_spec in self.output_qubit:
929 if isinstance(qubit_spec, int):
930 obs.append(op.PauliZ(wires=qubit_spec))
931 else:
932 # parity: Z \\otimes Z \\otimes …
933 obs.append(ys.build_parity_observable(list(qubit_spec)))
934 return "expval", obs
936 if self.execution_type == "probs":
937 # probs are computed on the full system; subsystem
938 # marginalisation is handled in _postprocess_res
939 return "probs", []
941 raise ValueError(f"Invalid execution_type: {self.execution_type}.")
def _apply_state_prep_noise(
    self, noise_params: Dict[str, Union[float, Dict[str, float]]]
) -> None:
    """
    Model imperfect state preparation with a BitFlip on every qubit.

    Args:
        noise_params (Dict[str, Union[float, Dict[str, float]]]): Noise
            configuration. Only the "StatePreparation" entry (the BitFlip
            probability, default 0.0) is read.

    Returns:
        None: When the probability is positive, one ``op.BitFlip`` channel
            per qubit is applied to the circuit being built. Otherwise
            nothing happens.
    """
    flip_prob = noise_params.get("StatePreparation", 0.0)
    if not flip_prob > 0:
        return
    for wire in range(self.n_qubits):
        op.BitFlip(flip_prob, wires=wire)
def _apply_general_noise(
    self, noise_params: Dict[str, Union[float, Dict[str, float]]]
) -> None:
    """
    Apply general noise channels to all qubits.

    Applies various decoherence and error channels after the circuit
    execution, simulating environmental noise effects. For each qubit, the
    channels are applied in this order: AmplitudeDamping, PhaseDamping,
    Measurement (BitFlip), ThermalRelaxation.

    Args:
        noise_params (Dict[str, Union[float, Dict[str, float]]]): Dictionary
            containing noise parameters with the following supported keys:
            - "AmplitudeDamping" (float): Probability for amplitude damping.
            - "PhaseDamping" (float): Probability for phase damping.
            - "Measurement" (float): Probability for measurement error (BitFlip).
            - "ThermalRelaxation" (Dict): Dictionary with keys "t1", "t2",
              "t_factor" for thermal relaxation simulation. It is applied
              only when the value is a dict. The gate time is the circuit
              depth multiplied by "t_factor".

    Returns:
        None: Noise channels are applied in-place to the circuit.

    Note:
        Gate-level noise (e.g., GateError) is handled separately in the
        Gates.Noise module and applied at the individual gate level.
    """
    amp_damp = noise_params.get("AmplitudeDamping", 0.0)
    phase_damp = noise_params.get("PhaseDamping", 0.0)
    thermal_relax = noise_params.get("ThermalRelaxation", 0.0)
    meas = noise_params.get("Measurement", 0.0)

    # The thermal-relaxation arguments are the same for every qubit.
    # Resolve them (including the tape-recording depth computation) once,
    # instead of once per qubit inside the loop.
    thermal_args = None
    if isinstance(thermal_relax, dict):
        tg = self._get_circuit_depth() * thermal_relax["t_factor"]
        thermal_args = (thermal_relax["t1"], thermal_relax["t2"], tg)

    for q in range(self.n_qubits):
        if amp_damp > 0:
            op.AmplitudeDamping(amp_damp, wires=q)
        if phase_damp > 0:
            op.PhaseDamping(phase_damp, wires=q)
        if meas > 0:
            op.BitFlip(meas, wires=q)
        if thermal_args is not None:
            t1, t2, tg = thermal_args
            op.ThermalRelaxationError(1.0, t1, t2, tg, q)
def _get_circuit_depth(self, inputs: Optional[jnp.ndarray] = None) -> int:
    """
    Calculate the depth of the quantum circuit.

    Records the circuit onto a tape (without noise) and computes the
    depth as the length of the critical path: each gate is scheduled
    at the earliest time step after all of its qubits are free.

    The result is cached on the instance (``_cached_circuit_depth``)
    after the first call. Later calls return the cached value without
    re-recording, regardless of ``inputs``.

    Args:
        inputs (Optional[jnp.ndarray]): Input data for circuit evaluation.
            If None, default zero inputs are used.

    Returns:
        int: The circuit depth (longest path of gates in the circuit).
            Noise channels (``KrausChannel`` instances) are not counted.
    """
    # Return cached value if available
    if hasattr(self, "_cached_circuit_depth"):
        return self._cached_circuit_depth

    # NOTE(review): _inputs_validation also (re)sets self._zero_inputs as
    # a side effect.
    inputs = self._inputs_validation(inputs)

    # Temporarily clear noise_params to prevent _variational from
    # picking them up (which would call _apply_general_noise ->
    # _get_circuit_depth again, causing infinite recursion).
    # Restore in ``finally`` so a failure while recording cannot leave the
    # model without its noise configuration.
    saved_noise = self._noise_params
    self._noise_params = None
    try:
        with recording() as tape:
            self._variational(
                self.params[0] if self.params.ndim == 3 else self.params,
                inputs[0] if inputs.ndim == 2 else inputs,
                noise_params=None,
            )
    finally:
        self._noise_params = saved_noise

    # Filter out noise channels - only count unitary gates
    ops = [o for o in tape if not isinstance(o, KrausChannel)]

    if not ops:
        self._cached_circuit_depth = 0
        return 0

    # Schedule each gate at the earliest time step where all its wires
    # are free. ``wire_busy[q]`` tracks the next free time step for
    # qubit ``q``.
    wire_busy: Dict[int, int] = {}
    depth = 0
    for gate in ops:
        start = max((wire_busy.get(w, 0) for w in gate.wires), default=0)
        end = start + 1
        for w in gate.wires:
            wire_busy[w] = end
        depth = max(depth, end)

    self._cached_circuit_depth = depth
    return depth
def draw(
    self,
    inputs: Optional[jnp.ndarray] = None,
    figure: str = "text",
    **kwargs: Any,
) -> Union[str, Any]:
    """Visualize the quantum circuit.

    Records the circuit tape (without noise) and renders the gate
    sequence using the requested backend. Only the first batch element
    of the parameters and inputs is drawn.

    Args:
        inputs (Optional[jnp.ndarray]): Input data for the circuit.
            If ``None``, default zero inputs are used.
        figure (str): Rendering backend. One of:

            * ``"text"`` - ASCII art (returned as a ``str``).
            * ``"mpl"`` - Matplotlib figure (returns ``(fig, ax)``).
            * ``"tikz"`` - LaTeX/TikZ ``quantikz`` code (returns a
              :class:`TikzFigure`).
            * ``"pulse"`` - Pulse schedule (returns ``(fig, axes)``).
              Only meaningful for pulse-mode models. Delegates to
              :meth:`draw_pulse`.

        **kwargs: Extra options forwarded to the drawing backend
            (e.g. ``gate_values=True``).

    Returns:
        Depends on figure:

        * ``"text"`` -> ``str``
        * ``"mpl"`` -> ``(matplotlib.figure.Figure, matplotlib.axes.Axes)``
        * ``"tikz"`` -> :class:`TikzFigure`
        * ``"pulse"`` -> ``(Figure, array of Axes)`` from :meth:`draw_pulse`

    Raises:
        ValueError: If figure is not one of the supported modes. This
            method does not validate ``figure`` itself; the check is
            presumably done by the yaqsi drawing backend.
    """
    inputs = self._inputs_validation(inputs)

    if figure == "pulse":
        return self.draw_pulse(inputs=inputs, **kwargs)

    params = self.params[0] if self.params.ndim == 3 else self.params
    inp = inputs[0] if inputs.ndim == 2 else inputs

    # Record without noise to get a clean circuit. Restore the noise
    # configuration even if drawing fails.
    saved_noise = self._noise_params
    self._noise_params = None
    try:
        draw_script = ys.Script(f=self._variational, n_qubits=self.n_qubits)
        return draw_script.draw(
            figure=figure,
            args=(params, inp),
            kwargs={"noise_params": None},
            **kwargs,
        )
    finally:
        self._noise_params = saved_noise
def draw_pulse(
    self,
    inputs: Optional[jnp.ndarray] = None,
    **kwargs: Any,
) -> Any:
    """Render the pulse schedule of the circuit.

    The circuit is recorded with ``gate_mode="pulse"`` and without noise.
    The resulting pulse events are drawn by the yaqsi script's
    ``"pulse"`` backend. Only the first batch element of the parameters
    and inputs is used.

    Args:
        inputs: Input data. If ``None``, default zero inputs are used.
        **kwargs: Forwarded to
            :func:`~qml_essentials.drawing.draw_pulse_schedule`
            (e.g. ``show_carrier=True``, ``n_samples=300``).

    Returns:
        ``(fig, axes)`` — Matplotlib Figure and array of Axes.
    """
    validated = self._inputs_validation(inputs)
    first_params = self.params if self.params.ndim != 3 else self.params[0]
    first_inputs = validated if validated.ndim != 2 else validated[0]

    pulse_kwargs = {
        "gate_mode": "pulse",
        "noise_params": None,
    }
    script = ys.Script(f=self._variational, n_qubits=self.n_qubits)
    return script.draw(
        figure="pulse",
        args=(first_params, first_inputs),
        kwargs=pulse_kwargs,
        **kwargs,
    )
def __repr__(self) -> str:
    """Return the model rendered as an ASCII circuit diagram."""
    text_diagram = self.draw(figure="text")
    return text_diagram
def __str__(self) -> str:
    """Return the model rendered as an ASCII circuit diagram."""
    diagram = self.draw(figure="text")
    return diagram
def _params_validation(self, params: Optional[jnp.ndarray]) -> jnp.ndarray:
    """
    Validate and normalize variational parameters.

    Ensures parameters have the correct shape with a batch dimension,
    and updates the model's internal parameters if new ones are provided.

    Args:
        params (Optional[jnp.ndarray]): Variational parameters to validate.
            If None, returns the model's current parameters.

    Returns:
        jnp.ndarray: Validated parameters with shape
            (batch_size, n_layers, n_params_per_layer). A 2-D input gets a
            leading batch axis of size 1.
    """
    if params is None:
        return self.params

    # Append a batch axis if not provided. Use jnp (not np): np.expand_dims
    # converts its argument to a NumPy array, which fails on JAX tracers
    # (e.g. when the model is called under jax.jit / jax.grad / jacrev).
    if len(params.shape) == 2:
        params = jnp.expand_dims(params, axis=0)

    # NOTE(review): params are stored on ``self`` unconditionally, even
    # when they are a JAX tracer. Under an outer transform the tracer
    # becomes invalid once the transform returns, so a later read of
    # ``self.params`` could raise ``UnexpectedTracerError``. A guard
    # (skip storing when ``isinstance(params, jax.core.Tracer)``) is
    # currently disabled.
    self.params = params
    return params
def _pulse_params_validation(
    self, pulse_params: Optional[jnp.ndarray]
) -> jnp.ndarray:
    """
    Return batch-first pulse parameters, falling back to the model's own.

    Args:
        pulse_params (Optional[jnp.ndarray]): Pulse parameter scalers.
            ``None`` means "use ``self.pulse_params`` unchanged".

    Returns:
        jnp.ndarray: Pulse parameters with shape
            (batch_size, n_layers, n_pulse_params_per_layer). A 2-D input
            gets a leading batch axis of size 1 and is stored as
            ``self.pulse_params``.
    """
    if pulse_params is None:
        return self.pulse_params

    batched = (
        jnp.expand_dims(pulse_params, axis=0)
        if len(pulse_params.shape) == 2
        else pulse_params
    )
    # NOTE: stored on ``self`` unconditionally, even when it is a JAX
    # tracer under an outer transform; a guard against that is disabled.
    self.pulse_params = batched
    return batched
def _enc_params_validation(self, enc_params: Optional[jnp.ndarray]) -> jnp.ndarray:
    """
    Validate and normalize encoding parameters.

    Ensures encoding parameters have the correct shape for the model's
    input feature dimensions. When new parameters are given, they are
    stored as ``self.enc_params``. With non-trainable frequencies they are
    first converted to a JAX array, so nested lists are accepted too.

    Args:
        enc_params (Optional[jnp.ndarray]): Encoding parameters to validate.
            If None, returns the model's current encoding parameters.

    Returns:
        jnp.ndarray: Validated encoding parameters with shape
            (n_qubits, n_input_feat). A 1-D array is reshaped to a single
            column when ``n_input_feat == 1``.

    Raises:
        ValueError: If enc_params is 1-D while n_input_feat > 1.
    """
    if enc_params is None:
        enc_params = self.enc_params
    else:
        if not self.trainable_frequencies:
            # Fixed frequencies are plain constants. Normalize the value we
            # store and the value we validate/return, so both agree.
            enc_params = jnp.array(enc_params)
        # NOTE(review): stored on ``self`` unconditionally, even when it is
        # a JAX tracer under an outer transform; a guard is disabled.
        self.enc_params = enc_params

    if len(enc_params.shape) == 1:
        if self.n_input_feat == 1:
            enc_params = enc_params.reshape(-1, 1)
        elif self.n_input_feat > 1:
            # Implicit concatenation: a backslash continuation inside the
            # f-string would embed the source indentation in the message.
            raise ValueError(
                f"Input dimension {self.n_input_feat} >1 but "
                f"`enc_params` has shape {enc_params.shape}"
            )

    return enc_params
def _inputs_validation(
    self, inputs: Union[None, List, float, int, jnp.ndarray]
) -> jnp.ndarray:
    """
    Coerce user-supplied inputs into a (batch_size, n_input_feat) array.

    Accepted forms:
        - None: zeros of shape (1, n_input_feat)
        - float/int: a single scalar value
        - list: stacked with ``np.stack`` (values or batched inputs)
        - array: NumPy/JAX array used as-is

    As a side effect, ``self._zero_inputs`` is set to True when every
    input value is zero, and to False otherwise.

    Args:
        inputs (Union[None, List, float, int, jnp.ndarray]): Input data.

    Returns:
        jnp.ndarray: Inputs with shape (batch_size, n_input_feat).

    Raises:
        ValueError: If a 2-D input's second axis differs from n_input_feat.

    Warns:
        UserWarning: If a 1-D input has to be replicated across all
            n_input_feat features.
    """
    self._zero_inputs = False

    if isinstance(inputs, list):
        inputs = jnp.array(np.stack(inputs))
    elif isinstance(inputs, (float, int)):
        inputs = jnp.array([inputs])
    elif inputs is None:
        inputs = jnp.array([[0] * self.n_input_feat])

    if not inputs.any():
        self._zero_inputs = True

    n_feat = self.n_input_feat

    # Already batched: only the feature axis needs checking.
    if len(inputs.shape) > 1:
        if inputs.shape[1] != n_feat:
            raise ValueError(
                f"Wrong number of inputs provided. Expected {self.n_input_feat} "
                f"inputs, but input has shape {inputs.shape}."
            )
        return inputs

    # Scalar / 1-D input from here on.
    if n_feat == 1:
        # every value is one sample -> add the feature axis
        return inputs.reshape(-1, 1)
    if inputs.shape[0] == n_feat:
        # exactly one sample carrying all features
        return inputs.reshape(1, -1)

    # Otherwise each value is a sample that is copied into every feature.
    inputs = inputs.reshape(-1, 1).repeat(n_feat, axis=1)
    warnings.warn(
        f"Expected {self.n_input_feat} inputs, but {inputs.shape[0]} "
        "was provided, replicating input for all input features.",
        UserWarning,
    )
    return inputs
def _postprocess_res(self, result: Union[List, jnp.ndarray]) -> jnp.ndarray:
    """
    Bring raw circuit output into a uniform, batch-first array.

    A list of per-measurement results is stacked along a new leading axis.
    If the stacked array has more than one dimension, that measurement
    axis is then swapped with the next one. Non-list results are returned
    untouched.

    Args:
        result (Union[List, jnp.ndarray]): Raw circuit output.

    Returns:
        jnp.ndarray: Result with the batch dimension first.
    """
    if not isinstance(result, list):
        return result

    stacked = jnp.stack(result)
    # moveaxis rather than a full transpose: parity measurements append an
    # extra trailing dimension, which a transpose would wrongly reorder.
    return jnp.moveaxis(stacked, 0, 1) if stacked.ndim > 1 else stacked
def _assimilate_batch(
    self,
    inputs: jnp.ndarray,
    params: jnp.ndarray,
    pulse_params: jnp.ndarray,
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """
    Align batch dimensions across inputs, parameters, and pulse parameters.

    Broadcasts and reshapes arrays to have compatible batch dimensions
    for vectorized circuit execution. Sets the internal batch_shape.

    Args:
        inputs (jnp.ndarray): Input data of shape (B_I, n_input_feat).
        params (jnp.ndarray): Parameters of shape (B_P, n_layers, n_params).
        pulse_params (jnp.ndarray): Pulse params of shape (B_R, n_layers, n_pulse).

    Returns:
        Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: Tuple containing:
            - inputs: Reshaped to (B, n_input_feat)
            - params: Reshaped to (B, n_layers, n_params)
            - pulse_params: Reshaped to (B, n_layers, n_pulse)
            where B = prod(self.eff_batch_shape) (B_I * B_P * B_R when all
            axes are repeated). An array is left untouched when its own
            batch size is 1 or its axis is not repeated.

    Note:
        The effective batch shape depends on repeat_batch_axis configuration.
        This is the only method that sets self._batch_shape.
    """
    B_I = inputs.shape[0]
    # An ansatz without variational parameters yields a params array with
    # a zero-sized dimension; treat that as a single (B_P = 1) batch entry.
    B_P = 1 if 0 in params.shape else params.shape[0]
    B_R = pulse_params.shape[0]

    # THIS is the only place where we set the batch shape
    self._batch_shape = (B_I, B_P, B_R)
    B = np.prod(self.eff_batch_shape)

    # [B_I, ...] -> [B_I, B_P, B_R, ...] -> [B, ...]
    if B_I > 1 and self.repeat_batch_axis[0]:
        # Always insert the B_P and B_R axes (as the params / pulse_params
        # branches do), so the repeat along axis 2 and the final reshape
        # stay valid even when the B_P axis is not repeated.
        inputs = inputs[:, None, None, ...]  # [B_I, 1, 1, ...]
        if self.repeat_batch_axis[1]:
            inputs = jnp.repeat(inputs, B_P, axis=1)  # [B_I, B_P, 1, ...]
        if self.repeat_batch_axis[2]:
            inputs = jnp.repeat(inputs, B_R, axis=2)  # [B_I, B_P, B_R, ...]
        inputs = inputs.reshape(B, *inputs.shape[3:])

    # [B_P, ..., ...] -> [B_I, B_P, B_R, ..., ...] -> [B, ..., ...]
    if B_P > 1 and self.repeat_batch_axis[1]:
        # add B_I axis before first, and B_R axis after first batch dim
        params = params[None, :, None, ...]  # [B_I(=1), B_P, B_R(=1), ...]
        if self.repeat_batch_axis[0]:
            params = jnp.repeat(params, B_I, axis=0)  # [B_I, B_P, 1, ...]
        if self.repeat_batch_axis[2]:
            params = jnp.repeat(params, B_R, axis=2)  # [B_I, B_P, B_R, ...]
        params = params.reshape(B, *params.shape[3:])

    # [B_R, ..., ...] -> [B_I, B_P, B_R, ..., ...] -> [B, ..., ...]
    if B_R > 1 and self.repeat_batch_axis[2]:
        # add B_I axis and B_P axis before B_R
        pulse_params = pulse_params[None, None, ...]  # [B_I(=1), B_P(=1), B_R, ...]
        if self.repeat_batch_axis[0]:
            pulse_params = jnp.repeat(
                pulse_params, B_I, axis=0
            )  # [B_I, 1, B_R, ...]
        if self.repeat_batch_axis[1]:
            pulse_params = jnp.repeat(
                pulse_params, B_P, axis=1
            )  # [B_I, B_P, B_R, ...]
        pulse_params = pulse_params.reshape(B, *pulse_params.shape[3:])

    return inputs, params, pulse_params
def _requires_density(self) -> bool:
    """
    Check if density matrix simulation is required.

    Determines whether the circuit must be executed with the mixed-state
    simulator based on execution type and noise configuration.

    Returns:
        bool: True if density matrix simulation is required, False otherwise.
            Returns True if:
            - execution_type is "density", or
            - any non-coherent noise channel is active: a numeric value
              > 0, or a dict-valued configuration (e.g.
              "ThermalRelaxation", which _apply_general_noise applies
              whenever its value is a dict).
    """
    if self.execution_type == "density":
        return True

    if self.noise_params is None:
        return False

    coherent_noise = {"GateError"}
    for name, value in self.noise_params.items():
        if name in coherent_noise or value is None:
            continue
        # Dict-valued entries (ThermalRelaxation) cannot be compared with
        # ``> 0`` (TypeError); their presence alone means the channel is on.
        if isinstance(value, dict):
            return True
        if value > 0:
            return True
    return False
def __call__(
    self,
    params: Optional[jnp.ndarray] = None,
    inputs: Optional[jnp.ndarray] = None,
    pulse_params: Optional[jnp.ndarray] = None,
    enc_params: Optional[jnp.ndarray] = None,
    data_reupload: Union[bool, List[List[bool]], List[List[List[bool]]]] = None,
    noise_params: Optional[Dict[str, Union[float, Dict[str, float]]]] = None,
    execution_type: Optional[str] = None,
    force_mean: bool = False,
    gate_mode: str = "unitary",
) -> jnp.ndarray:
    """
    Run the model; a thin callable front-end for ``_forward``.

    Every argument is handed to ``_forward`` unchanged, by keyword.

    Args:
        params (Optional[jnp.ndarray]): Variational parameters of shape
            (n_layers, n_params_per_layer) or (batch, n_layers, n_params_per_layer).
            None means the model's stored parameters.
        inputs (Optional[jnp.ndarray]): Input data of shape
            (batch_size, n_input_feat). None means zero inputs.
        pulse_params (Optional[jnp.ndarray]): Pulse parameter scalers,
            only valid together with gate_mode="pulse".
        enc_params (Optional[jnp.ndarray]): Encoding parameters of shape
            (n_qubits, n_input_feat). None means the model's stored ones.
        data_reupload (Union[bool, List[List[bool]], List[List[List[bool]]]]):
            Data re-upload configuration. None keeps the current one.
        noise_params (Optional[Dict[str, Union[float, Dict[str, float]]]]):
            Noise configuration. None keeps the current one.
        execution_type (Optional[str]): "expval", "density", "probs" or
            "state". None keeps the current setting.
        force_mean (bool): Average the result over measurement qubits.
            Defaults to False.
        gate_mode (str): "unitary" or "pulse". Defaults to "unitary".

    Returns:
        jnp.ndarray: Whatever ``_forward`` returns; its shape depends on
            execution_type:
            - "expval": (n_output_qubits,) or scalar
            - "density": (2^n_output, 2^n_output)
            - "probs": (2^n_output,) or (n_pairs, 2^pair_size)
            - "state": (2^n_qubits,)
    """
    forward_kwargs = dict(
        gate_mode=gate_mode,
        force_mean=force_mean,
        execution_type=execution_type,
        noise_params=noise_params,
        data_reupload=data_reupload,
        enc_params=enc_params,
        pulse_params=pulse_params,
        inputs=inputs,
        params=params,
    )
    return self._forward(**forward_kwargs)
def _forward(
    self,
    params: Optional[jnp.ndarray] = None,
    inputs: Optional[jnp.ndarray] = None,
    pulse_params: Optional[jnp.ndarray] = None,
    enc_params: Optional[jnp.ndarray] = None,
    data_reupload: Union[bool, List[List[bool]], List[List[List[bool]]]] = None,
    noise_params: Optional[Dict[str, Union[float, Dict[str, float]]]] = None,
    execution_type: Optional[str] = None,
    force_mean: bool = False,
    gate_mode: str = "unitary",
) -> jnp.ndarray:
    """
    Execute the quantum circuit forward pass.

    Internal implementation of the forward pass that handles parameter
    validation, batch alignment, and circuit execution routing.

    Side effects: non-None ``noise_params``, ``execution_type`` and
    ``data_reupload`` are stored on the instance. ``gate_mode`` is always
    stored. The validation helpers store any newly passed params, pulse
    params and encoding params. ``self.random_key`` is advanced on every
    call.

    Args:
        params (Optional[jnp.ndarray]): Variational parameters of shape
            (n_layers, n_params_per_layer) or
            (batch, n_layers, n_params_per_layer).
            If None, uses model's internal parameters.
        inputs (Optional[jnp.ndarray]): Input data of shape
            (batch_size, n_input_feat).
            If None, uses zero inputs.
        pulse_params (Optional[jnp.ndarray]): Pulse parameter scalers for
            pulse-mode gate execution.
        enc_params (Optional[jnp.ndarray]): Encoding parameters of shape
            (n_qubits, n_input_feat). If None, uses model's encoding parameters.
        data_reupload (Union[bool, List[List[bool]], List[List[List[bool]]]]):
            Data reupload configuration. If None, uses previously set reupload
            configuration.
        noise_params (Optional[Dict[str, Union[float, Dict[str, float]]]]):
            Noise configuration. If None, uses previously set noise parameters.
        execution_type (Optional[str]): Measurement type: "expval", "density",
            "probs", or "state". If None, uses current execution_type setting.
        force_mean (bool): If True, averages results over measurement qubits.
            Defaults to False.
        gate_mode (str): Gate execution backend, "unitary" or "pulse".
            Defaults to "unitary".

    Returns:
        jnp.ndarray: Circuit output with shape depending on execution_type:
            - "expval": (n_output_qubits,) or scalar
            - "density": (2^n_output, 2^n_output)
            - "probs": (2^n_output,) or (n_pairs, 2^pair_size)
            - "state": (2^n_qubits,)

    Raises:
        ValueError: If pulse_params are provided while gate_mode is not
            "pulse". ValueError is also raised by the validation helpers
            (inputs or enc_params with an incompatible shape) and by
            ``_build_obs`` (unknown execution_type).
    """
    # set the parameters as object attributes
    if noise_params is not None:
        self.noise_params = noise_params
    if execution_type is not None:
        self.execution_type = execution_type
    self.gate_mode = gate_mode

    # consistency checks
    if pulse_params is not None and gate_mode != "pulse":
        raise ValueError(
            "pulse_params were provided but gate_mode is not 'pulse'. "
            "Either switch gate_mode='pulse' or do not pass pulse_params."
        )

    # TODO: add testing
    if data_reupload is not None:
        self.data_reupload = data_reupload

    params = self._params_validation(params)
    pulse_params = self._pulse_params_validation(pulse_params)
    inputs = self._inputs_validation(inputs)
    enc_params = self._enc_params_validation(enc_params)

    # Flatten the (inputs, params, pulse_params) batch axes into one
    # common batch axis; this also sets self._batch_shape.
    inputs, params, pulse_params = self._assimilate_batch(
        inputs,
        params,
        pulse_params,
    )

    # split to generate a sub_key, required for actual execution
    self.random_key, sub_key = safe_random_split(self.random_key)

    # Build measurement type & observables from execution_type / output_qubit
    meas_type, obs = self._build_obs()

    # Yaqsi auto-routes between statevector and density-matrix simulation
    # based on whether noise channels appear on the tape, so no explicit
    # simulator choice is passed to ``execute`` below.
    # B is the total (flattened) effective batch size.
    B = np.prod(self.eff_batch_shape)

    # kwargs are broadcast (not vmapped over)
    exec_kwargs = dict(
        noise_params=self.noise_params,
        gate_mode=self.gate_mode,
    )

    # Build a shot key from the random_key if shots are requested
    shot_key = None
    if self.shots is not None:
        # overwrite subkey and split shot_key
        sub_key, shot_key = safe_random_split(sub_key)

    if B > 1:
        # use random keys, derived from the subkey (one per batch element)
        random_keys = safe_random_split(sub_key, num=B)

        # One entry per element of ``args`` below, in the same order.
        # NOTE(review): the decision uses the raw batch_shape, not
        # eff_batch_shape; an array whose axis is not repeated keeps its
        # own leading batch size.
        in_axes = (
            0 if self.batch_shape[1] > 1 else None,  # params
            0 if self.batch_shape[0] > 1 else None,  # inputs
            0 if self.batch_shape[2] > 1 else None,  # pulse_params
            0,  # random_keys
            None,  # enc_params (broadcast, not batched)
        )

        result = self.script.execute(
            type=meas_type,
            obs=obs,
            args=(params, inputs, pulse_params, random_keys, enc_params),
            kwargs=exec_kwargs,
            in_axes=in_axes,
            shots=self.shots,
            key=shot_key,
        )
    else:
        # use the subkey directly
        result = self.script.execute(
            type=meas_type,
            obs=obs,
            args=(params, inputs, pulse_params, sub_key, enc_params),
            kwargs=exec_kwargs,
            shots=self.shots,
            key=shot_key,
        )

    result = self._postprocess_res(result)

    # --- Post-processing for partial-qubit measurements ---------------
    if self.execution_type == "density" and not self.all_qubit_measurement:
        result = ys.partial_trace(result, self.n_qubits, self.output_qubit)

    if self.execution_type == "probs" and not self.all_qubit_measurement:
        if isinstance(self.output_qubit[0], (list, tuple)):
            # list of qubit groups - marginalize each independently
            result = jnp.stack(
                [
                    ys.marginalize_probs(result, self.n_qubits, list(group))
                    for group in self.output_qubit
                ]
            )
        else:
            result = ys.marginalize_probs(result, self.n_qubits, self.output_qubit)

    # Restore the (effective) batch axes, then drop all size-1 axes.
    result = jnp.asarray(result)
    result = result.reshape((*self.eff_batch_shape, *self._result_shape)).squeeze()

    # Average over the last axis only when there is more than one output
    # to average.
    if (
        self.execution_type in ("expval", "probs")
        and force_mean
        and len(result.shape) > 0
        and self._result_shape[0] > 1
    ):
        result = result.mean(axis=-1)

    return result