Coverage for qml_essentials / model.py: 91%
529 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-06-11 15:51 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-06-11 15:51 +0000
1from typing import Any, Dict, Optional, Tuple, Callable, Union, List
3import warnings
4import jax.numpy as jnp
5import numpy as np
6from jax import random
8from qml_essentials import jaqsi as js
9from qml_essentials import operations as op
10from qml_essentials.tape import recording
11from qml_essentials.operations import KrausChannel
12from qml_essentials.ansaetze import Ansaetze, Circuit, Encoding
13from qml_essentials.gates import Gates, PulseInformation as pinfo
14from qml_essentials.utils import safe_random_split
16import logging
# Module-level logger; Model methods emit their debug/info messages through it.
log = logging.getLogger(__name__)
class Model:
    """
    A quantum circuit model.

    Builds a circuit of alternating trainable ansatz layers (``self.pqc``)
    and input-encoding layers (``self._iec``) on ``n_qubits`` qubits,
    preceded by optional state-preparation gates on every qubit. The circuit
    function ``self._variational`` is wrapped in a ``jaqsi`` ``Script``
    (``self.script``), presumably used by the forward pass to execute it.
    """
26 def __init__(
27 self,
28 n_qubits: int,
29 n_layers: int,
30 circuit_type: Union[str, Circuit] = "No_Ansatz",
31 data_reupload: Union[bool, List[List[bool]], List[List[List[bool]]]] = True,
32 state_preparation: Union[
33 str, Callable, List[Union[str, Callable]], None
34 ] = None,
35 encoding: Union[Encoding, str, Callable, List[Union[str, Callable]]] = Gates.RX,
36 trainable_frequencies: bool = False,
37 initialization: str = "random",
38 initialization_domain: List[float] = [0, 2 * jnp.pi],
39 output_qubit: Union[List[int], int] = -1,
40 shots: Optional[int] = None,
41 random_seed: int = 1000,
42 remove_zero_encoding: bool = True,
43 repeat_batch_axis: List[bool] = [True, True, True],
44 pulse_shape: str = "gaussian",
45 ) -> None:
46 """
47 Initialize the quantum circuit model.
48 Parameters will have the shape [impl_n_layers, parameters_per_layer]
49 where impl_n_layers is the number of layers provided and added by one
50 depending if data_reupload is True and parameters_per_layer is given by
51 the chosen ansatz.
53 The model is initialized with the following parameters as defaults:
54 - noise_params: None
55 - execution_type: "expval"
56 - shots: None
58 Args:
59 n_qubits (int): The number of qubits in the circuit.
60 n_layers (int): The number of layers in the circuit.
61 circuit_type (str, Circuit): The type of quantum circuit to use.
62 If None, defaults to "no_ansatz".
63 data_reupload (Union[bool, List[bool], List[List[bool]]], optional):
64 Whether to reupload data to the quantum device on each
65 layer and qubit. Detailed re-uploading instructions can be given
66 as a list/array of 0/False and 1/True with shape (n_qubits,
67 n_layers) to specify where to upload the data. Defaults to True
68 for applying data re-uploading to the full circuit.
69 encoding (Union[str, Callable, List[str], List[Callable]], optional):
70 The unitary to use for encoding the input data. Can be a string
71 (e.g. "RX") or a callable (e.g. op.RX). Defaults to op.RX.
72 If input is multidimensional it is assumed to be a list of
73 unitaries or a list of strings.
74 trainable_frequencies (bool, optional):
75 Sets trainable encoding parameters for trainable frequencies.
76 Defaults to False.
77 initialization (str, optional): The strategy to initialize the parameters.
78 Can be "random", "zeros", "zero-controlled", "pi", or "pi-controlled".
79 Defaults to "random".
80 output_qubit (List[int], int, optional): The index of the output
81 qubit (or qubits). When set to -1 all qubits are measured, or a
82 global measurement is conducted, depending on the execution
83 type.
84 shots (Optional[int], optional): The number of shots to use for
85 the quantum device. Defaults to None.
86 random_seed (int, optional): seed for the random number generator
87 in initialization is "random" and for random noise parameters.
88 Defaults to 1000.
89 remove_zero_encoding (bool, optional): whether to
90 remove the zero encoding from the circuit. Defaults to True.
91 repeat_batch_axis (List[bool], optional): Each boolean in the array
92 determines over which axes to parallelise computation. The axes
93 correspond to [inputs, params, pulse_params]. Defaults to
94 [True, True, True], meaning that batching is enabled over all
95 axes.
96 pulse_shape (str, optional): Pulse envelope shape for pulse-level
97 simulation. One of ``PulseEnvelope.available()``.
98 Defaults to ``"gaussian"``.
100 Returns:
101 None
102 """
103 # Initialize default parameters needed for circuit evaluation
104 self.n_qubits: int = n_qubits
105 self.output_qubit: Union[List[int], int] = output_qubit
106 self.n_layers: int = n_layers
107 self.noise_params: Optional[Dict[str, Union[float, Dict[str, float]]]] = None
108 self.shots = shots
109 self.remove_zero_encoding = remove_zero_encoding
110 self.trainable_frequencies: bool = trainable_frequencies
111 self.execution_type: str = "expval"
112 self.repeat_batch_axis: List[bool] = repeat_batch_axis
114 # --- Pulse envelope ---
115 pinfo.set_envelope(pulse_shape)
117 # --- State Preparation ---
118 try:
119 self._sp = Gates.parse_gates(state_preparation, Gates)
120 except ValueError as e:
121 raise ValueError(f"Error parsing encodings: {e}")
123 # prepare corresponding pulse parameters (always optimized pulses)
124 self.sp_pulse_params = []
125 for sp in self._sp:
126 sp_name = sp.__name__ if hasattr(sp, "__name__") else str(sp)
128 if pinfo.gate_by_name(sp_name) is not None:
129 self.sp_pulse_params.append(pinfo.gate_by_name(sp_name).params)
130 else:
131 # gate has no pulse parametrization
132 self.sp_pulse_params.append(None)
134 # --- Encoding ---
135 if isinstance(encoding, Encoding):
136 # user wants custom strategy? do it!
137 self._enc = encoding
138 else:
139 # use hammming encoding by default
140 self._enc = Encoding("hamming", encoding)
142 if self._enc.is_golomb:
143 self._enc._n_qubits = n_qubits
145 # Number of possible inputs
146 self.n_input_feat = len(self._enc)
147 log.debug(f"Number of input features: {self.n_input_feat}")
149 # Trainable frequencies, default initialization as in arXiv:2309.03279v2
150 self.enc_params = jnp.ones((self.n_layers, self.n_qubits, self.n_input_feat))
152 self._zero_inputs = False
154 # --- Data-Reuploading ---
156 # Keep as NumPy array (not JAX) so that ``if data_reupload[q, idx]``
157 # in _iec remains a concrete Python bool even under jax.jit tracing.
158 # note that setting this will also update self.degree and self.frequencies
159 # and in consequence also self.has_dru
160 self.data_reupload = data_reupload
162 # check for the highest degree among all input dimensions
163 if self.has_dru:
164 impl_n_layers: int = n_layers + 1 # we need L+1 according to Schuld et al.
165 else:
166 impl_n_layers = n_layers
167 log.info(f"Number of implicit layers: {impl_n_layers}.")
169 # --- Ansatz ---
170 # only weak check for str. We trust the user to provide sth useful
171 if isinstance(circuit_type, str):
172 self.pqc: Callable[[Optional[jnp.ndarray], int], int] = getattr(
173 Ansaetze, circuit_type or "No_Ansatz"
174 )()
175 else:
176 self.pqc = circuit_type()
177 log.info(f"Using Ansatz {circuit_type}.")
179 # calculate the shape of the parameter vector here, we will re-use this in init.
180 params_per_layer = self.pqc.n_params_per_layer(self.n_qubits)
181 self._params_shape: Tuple[int, int] = (impl_n_layers, params_per_layer)
182 log.info(f"Parameters per layer: {params_per_layer}")
184 pulse_params_per_layer = self.pqc.n_pulse_params_per_layer(self.n_qubits)
185 self._pulse_params_shape: Tuple[int, int] = (
186 impl_n_layers,
187 pulse_params_per_layer,
188 )
190 # intialize to None as we can't know this yet
191 self._batch_shape = None
193 # this will also be re-used in the init method,
194 # however, only if nothing is provided
195 self._inialization_strategy = initialization
196 self._initialization_domain = initialization_domain
198 # ..here! where we only require a JAX random key
199 self.random_key = self.initialize_params(random.key(random_seed))
201 # Initializing pulse params
202 self.pulse_params: jnp.ndarray = jnp.ones((1, *self._pulse_params_shape))
204 log.info(f"Initialized pulse parameters with shape {self.pulse_params.shape}.")
206 # Initialise the jaqsi Script that wraps _variational.
207 # No device selection needed - jaqsi auto-routes between statevector
208 # and density-matrix simulation based on whether noise channels are
209 # present on the tape.
210 self.script = js.Script(f=self._variational, n_qubits=self.n_qubits)
212 @property
213 def noise_params(self) -> Optional[Dict[str, Union[float, Dict[str, float]]]]:
214 """
215 Gets the noise parameters of the model.
217 Returns:
218 Optional[Dict[str, float]]: A dictionary of
219 noise parameters or None if not set.
220 """
221 return self._noise_params
223 @noise_params.setter
224 def noise_params(
225 self, kvs: Optional[Dict[str, Union[float, Dict[str, float]]]]
226 ) -> None:
227 """
228 Sets the noise parameters of the model.
230 Typically a "noise parameter" refers to the error probability.
231 ThermalRelaxation is a special case, and supports a dict as value with
232 structure:
233 "ThermalRelaxation":
234 {
235 "t1": 2000, # relative t1 time.
236 "t2": 1000, # relative t2 time
237 "t_factor" 1: # relative gate time factor
238 },
240 Args:
241 kvs (Optional[Dict[str, Union[float, Dict[str, float]]]]): A
242 dictionary of noise parameters. If all values are 0.0, the noise
243 parameters are set to None.
245 Returns:
246 None
247 """
248 # set to None if only zero values provided
249 if kvs is not None and all(v == 0.0 for v in kvs.values()):
250 kvs = None
252 # set default values
253 if kvs is not None:
254 defaults = {
255 "BitFlip": 0.0,
256 "PhaseFlip": 0.0,
257 "Depolarizing": 0.0,
258 "MultiQubitDepolarizing": 0.0,
259 "AmplitudeDamping": 0.0,
260 "PhaseDamping": 0.0,
261 "GateError": 0.0,
262 "ThermalRelaxation": None,
263 "StatePreparation": 0.0,
264 "Measurement": 0.0,
265 }
266 for key, default_val in defaults.items():
267 kvs.setdefault(key, default_val)
269 # check if there are any keys not supported
270 for key in kvs.keys():
271 if key not in defaults:
272 warnings.warn(
273 f"Noise type {key} is not supported by this package",
274 UserWarning,
275 )
277 # check valid params for thermal relaxation noise channel
278 tr_params = kvs["ThermalRelaxation"]
279 if isinstance(tr_params, dict):
280 tr_params.setdefault("t1", 0.0)
281 tr_params.setdefault("t2", 0.0)
282 tr_params.setdefault("t_factor", 0.0)
283 valid_tr_keys = {"t1", "t2", "t_factor"}
284 for k in tr_params.keys():
285 if k not in valid_tr_keys:
286 warnings.warn(
287 f"Thermal Relaxation parameter {k} is not supported "
288 f"by this package",
289 UserWarning,
290 )
291 if not all(tr_params.values()) or tr_params["t2"] > 2 * tr_params["t1"]:
292 warnings.warn(
293 "Received invalid values for Thermal Relaxation noise "
294 "parameter. Thermal relaxation is not applied!",
295 UserWarning,
296 )
297 kvs["ThermalRelaxation"] = 0.0
299 self._noise_params = kvs
301 @property
302 def output_qubit(self) -> List[int]:
303 """Get the output qubit indices for measurement."""
304 return self._output_qubit
306 @output_qubit.setter
307 def output_qubit(self, value: Union[int, List[int]]) -> None:
308 """
309 Set the output qubit(s) for measurement.
311 Args:
312 value: Qubit index or list of indices. Use -1 for all qubits.
313 """
314 if isinstance(value, list):
315 assert len(value) <= self.n_qubits, (
316 f"Size of output_qubit {len(value)} cannot be\
317 larger than number of qubits {self.n_qubits}."
318 )
319 elif isinstance(value, int):
320 if value == -1:
321 value = list(range(self.n_qubits))
322 else:
323 assert value < self.n_qubits, (
324 f"Output qubit {value} cannot be larger than {self.n_qubits}."
325 )
326 value = [value]
328 self._output_qubit = value
330 @property
331 def execution_type(self) -> str:
332 """
333 Gets the execution type of the model.
335 Returns:
336 str: The execution type, one of 'density', 'expval', or 'probs'.
337 """
338 return self._execution_type
340 @execution_type.setter
341 def execution_type(self, value: str) -> None:
342 if value == "density":
343 self._result_shape = (
344 2 ** len(self.output_qubit),
345 2 ** len(self.output_qubit),
346 )
347 elif value == "expval":
348 # check if all qubits are used
349 if len(self.output_qubit) == self.n_qubits:
350 self._result_shape = (len(self.output_qubit),)
351 # if not -> parity measurement with only 1D output per pair
352 # or n_local measurement
353 else:
354 self._result_shape = (len(self.output_qubit),)
355 elif value == "probs":
356 # in case this is a list of parities,
357 # each pair has 2^len(qubits) probabilities
358 n_parity = (
359 (2,) * len(self.output_qubit)
360 if isinstance(self.output_qubit, (Tuple, List))
361 else (2,)
362 )
363 self._result_shape = n_parity
364 elif value == "state":
365 self._result_shape = (2 ** len(self.output_qubit),)
366 else:
367 raise ValueError(f"Invalid execution type: {value}.")
369 if value == "state" and not self.all_qubit_measurement:
370 warnings.warn(
371 f"{value} measurement does ignore output_qubit, which is "
372 f"{self.output_qubit}.",
373 UserWarning,
374 )
376 if value == "probs" and self.shots is None:
377 warnings.warn(
378 "Setting execution_type to probs without specifying shots.",
379 UserWarning,
380 )
382 if value == "density" and self.shots is not None:
383 raise ValueError("Setting execution_type to density with shots not None.")
385 self._execution_type = value
387 @property
388 def shots(self) -> Optional[int]:
389 """
390 Gets the number of shots to use for the quantum device.
392 Returns:
393 Optional[int]: The number of shots.
394 """
395 return self._shots
397 @shots.setter
398 def shots(self, value: Optional[int]) -> None:
399 """
400 Sets the number of shots to use for the quantum device.
402 Args:
403 value (Optional[int]): The number of shots.
404 If an integer less than or equal to 0 is provided, it is set to None.
406 Returns:
407 None
408 """
409 if type(value) is int and value <= 0:
410 value = None
411 self._shots = value
413 @property
414 def params(self) -> jnp.ndarray:
415 """Get the variational parameters of the model."""
416 return self._params
418 @params.setter
419 def params(self, value: jnp.ndarray) -> None:
420 """Set the variational parameters, ensuring batch dimension exists."""
421 if len(value.shape) == 2:
422 value = value.reshape(1, *value.shape)
424 self._params = value
426 @property
427 def enc_params(self) -> jnp.ndarray:
428 """Get the encoding parameters used for input transformation."""
429 return self._enc_params
431 @enc_params.setter
432 def enc_params(self, value: jnp.ndarray) -> None:
433 """Set the encoding parameters."""
434 self._enc_params = value
436 @property
437 def pulse_params(self) -> jnp.ndarray:
438 """Get the pulse parameters for pulse-mode gate execution."""
439 return self._pulse_params
441 @pulse_params.setter
442 def pulse_params(self, value: jnp.ndarray) -> None:
443 """Set the pulse parameters."""
444 self._pulse_params = value
446 @property
447 def data_reupload(self) -> jnp.ndarray:
448 """Get the data reupload mask."""
449 return self._data_reupload
451 @data_reupload.setter
452 def data_reupload(self, value: jnp.ndarray) -> None:
453 """Set the data reupload mask.
455 Always converts to a concrete NumPy boolean array so that
456 ``if data_reupload[q, idx]`` in :meth:`_iec` remains a plain
457 Python ``bool`` even inside JAX-traced functions (jit / grad / vmap).
458 """
459 # Process data reuploading strategy and set degree
460 if not isinstance(value, bool):
461 if not isinstance(value, np.ndarray):
462 value = np.array(value)
464 if len(value.shape) == 2:
465 assert value.shape == (
466 self.n_layers,
467 self.n_qubits,
468 ), (
469 f"Data reuploading array has wrong shape. \
470 Expected {(self.n_layers, self.n_qubits)} or\
471 {(self.n_layers, self.n_qubits, self.n_input_feat)},\
472 got {value.shape}."
473 )
474 value = value.reshape(*value.shape, 1)
475 value = np.repeat(value, self.n_input_feat, axis=2)
477 assert value.shape == (
478 self.n_layers,
479 self.n_qubits,
480 self.n_input_feat,
481 ), (
482 f"Data reuploading array has wrong shape. \
483 Expected {(self.n_layers, self.n_qubits, self.n_input_feat)},\
484 got {value.shape}."
485 )
487 log.debug(f"Data reuploading array:\n{value}")
488 else:
489 if value:
490 value = np.ones((self.n_layers, self.n_qubits, self.n_input_feat))
491 log.debug("Full data reuploading.")
492 else:
493 value = np.zeros((self.n_layers, self.n_qubits, self.n_input_feat))
494 value[0][0] = 1
495 log.debug("No data reuploading.")
497 # convert to boolean values
498 self._data_reupload = np.asarray(value).astype(bool)
500 self.degree: Tuple = tuple(
501 self._enc.get_n_freqs(np.count_nonzero(self.data_reupload[..., i]))
502 for i in range(self.n_input_feat)
503 )
505 self.frequencies: Tuple = tuple(
506 self._enc.get_spectrum(np.count_nonzero(self.data_reupload[..., i]))
507 for i in range(self.n_input_feat)
508 )
510 # Cache has_dru as a plain Python bool so that it can be used in
511 # Python ``if`` statements even inside JAX-traced functions.
512 self._has_dru: bool = bool(max(int(np.max(f)) for f in self._frequencies) > 1)
514 @property
515 def degree(self) -> Tuple:
516 """Get the degree of the model."""
517 return self._degree
519 @degree.setter
520 def degree(self, value: Tuple):
521 self._degree = value
523 @property
524 def frequencies(self) -> Tuple:
525 """Get the frequencies of the model."""
526 return self._frequencies
528 @frequencies.setter
529 def frequencies(self, value: Tuple):
530 self._frequencies = value
532 def exact_spectrum(self, method: str = "tree") -> Tuple[np.ndarray, ...]:
533 """Compute the exact per-feature Fourier spectrum via the FourierTree.
535 Unlike :attr:`frequencies` -- a naive per-feature estimate derived purely
536 from the encoding, which can *overestimate* the spectrum (some
537 coefficients are constrained to zero for all parameters) -- this builds
538 the analytical Fourier tree (Nemkov et al.) and returns, for each input
539 feature, the integer frequencies whose Fourier coefficient is not
540 identically zero. The result is always a subset of :attr:`frequencies`.
542 The support is derived purely symbolically (no parameter sampling): see
543 :meth:`~qml_essentials.coefficients.FourierTree.get_exact_support`.
544 With ``method="tree"`` (default), frequencies whose contributions cancel
545 identically across tree paths (e.g. two consecutive encodings combining
546 into a single rotation) are excluded exactly; this enumerates the
547 explicit tree, which can be infeasible for deep entangling circuits.
548 With ``method="dp"``, a merged-state dynamic program derives the support
549 without enumerating paths, which scales to deep circuits (single input
550 feature only) at the cost of not detecting identical cross-path
551 cancellations.
553 Requires a Clifford + Pauli-rotation ansatz (see
554 :class:`~qml_essentials.pauli.PauliCircuit`); other gate sets raise
555 ``NotImplementedError`` during tree construction.
557 Args:
558 method (str): ``"tree"`` (fully exact) or ``"dp"`` (scalable).
560 Returns:
561 Tuple[np.ndarray, ...]: One sorted integer frequency array per input
562 feature (same layout as :attr:`frequencies`).
563 """
564 from qml_essentials.coefficients import FourierTree # avoid circular imp.
566 tree = FourierTree(self)
568 # Position of each model feature within the tree's frequency vectors.
569 feature_pos = {feat: i for i, feat in enumerate(tree.features)}
571 # Union of the symbolic supports over all observables (roots).
572 support = set()
573 for freqs in tree.get_exact_support(method=method):
574 farr = np.asarray(freqs)
575 for k in range(farr.shape[0]):
576 key = (
577 (int(farr[k]),)
578 if farr.ndim == 1
579 else tuple(int(v) for v in farr[k])
580 )
581 support.add(key)
583 spectrum = []
584 for feat in range(self.n_input_feat):
585 if support and feat in feature_pos:
586 pos = feature_pos[feat]
587 vals = sorted({k[pos] for k in support})
588 else:
589 vals = [0]
590 spectrum.append(np.array(vals, dtype=int))
591 return tuple(spectrum)
593 @property
594 def has_dru(self) -> bool:
595 """Check if the model has data reupload."""
596 return self._has_dru
598 @property
599 def all_qubit_measurement(self) -> bool:
600 """Check if measurement is performed on all qubits."""
601 return self.output_qubit == list(range(self.n_qubits))
603 @property
604 def batch_shape(self) -> Tuple[int, ...]:
605 """
606 Get the batch shape (B_I, B_P, B_R).
607 If the model was not called before,
608 it returns (1, 1, 1).
610 Returns:
611 Tuple[int, ...]: Tuple of (input_batch, param_batch, pulse_batch).
612 Returns (1, 1, 1) if model has not been called yet.
613 """
614 if self._batch_shape is None:
615 log.debug("Model was not called yet. Returning (1,1,1) as batch shape.")
616 return (1, 1, 1)
617 return self._batch_shape
619 @property
620 def eff_batch_shape(self) -> Tuple[int, ...]:
621 """
622 Get the effective batch shape after applying repeat_batch_axis mask.
624 Returns:
625 Tuple[int, ...]: Effective batch dimensions, excluding zeros.
626 """
627 batch_shape = np.array(self.batch_shape) * self.repeat_batch_axis
628 batch_shape = batch_shape[batch_shape != 0]
629 return batch_shape
631 def initialize_params(
632 self,
633 random_key: Optional[random.PRNGKey] = None,
634 repeat: int = 1,
635 initialization: Optional[str] = None,
636 initialization_domain: Optional[List[float]] = None,
637 ) -> random.PRNGKey:
638 """
639 Initialize the variational parameters of the model.
641 Args:
642 random_key (Optional[random.PRNGKey]): JAX random key for initialization.
643 If None, uses the model's internal random key.
644 repeat (int): Number of parameter sets to create (batch dimension).
645 Defaults to 1.
646 initialization (Optional[str]): Strategy for parameter initialization.
647 Options: "random", "zeros", "pi", "zero-controlled", "pi-controlled".
648 If None, uses the strategy specified in the constructor.
649 initialization_domain (Optional[List[float]]): Domain [min, max] for
650 random initialization. If None, uses the domain from constructor.
652 Returns:
653 random.PRNGKey: Updated random key after initialization.
655 Raises:
656 Exception: If an invalid initialization method is specified.
657 """
658 # Initializing params
659 params_shape = (repeat, *self._params_shape)
661 # use existing strategy if not specified
662 initialization = initialization or self._inialization_strategy
663 initialization_domain = initialization_domain or self._initialization_domain
665 random_key, sub_key = safe_random_split(
666 random_key if random_key is not None else self.random_key
667 )
669 def set_control_params(params: jnp.ndarray, value: float) -> jnp.ndarray:
670 indices = self.pqc.get_control_indices(self.n_qubits)
671 if indices is None:
672 warnings.warn(
673 f"Specified {initialization} but circuit\
674 does not contain controlled rotation gates.\
675 Parameters are intialized randomly.",
676 UserWarning,
677 )
678 else:
679 np_params = np.array(params)
680 np_params[:, :, indices[0] : indices[1] : indices[2]] = (
681 np.ones_like(params[:, :, indices[0] : indices[1] : indices[2]])
682 * value
683 )
684 params = jnp.array(np_params)
685 return params
687 if initialization == "random":
688 self.params: jnp.ndarray = random.uniform(
689 sub_key,
690 params_shape,
691 minval=initialization_domain[0],
692 maxval=initialization_domain[1],
693 )
694 elif initialization == "zeros":
695 self.params: jnp.ndarray = jnp.zeros(params_shape)
696 elif initialization == "pi":
697 self.params: jnp.ndarray = jnp.ones(params_shape) * jnp.pi
698 elif initialization == "zero-controlled":
699 self.params: jnp.ndarray = random.uniform(
700 sub_key,
701 params_shape,
702 minval=initialization_domain[0],
703 maxval=initialization_domain[1],
704 )
705 self.params = set_control_params(self.params, 0)
706 elif initialization == "pi-controlled":
707 self.params: jnp.ndarray = random.uniform(
708 sub_key,
709 params_shape,
710 minval=initialization_domain[0],
711 maxval=initialization_domain[1],
712 )
713 self.params = set_control_params(self.params, jnp.pi)
714 else:
715 raise Exception("Invalid initialization method")
717 log.info(
718 f"Initialized parameters with shape {self.params.shape}\
719 using strategy {initialization}."
720 )
722 return random_key
724 def transform_input(
725 self, inputs: jnp.ndarray, enc_params: jnp.ndarray
726 ) -> jnp.ndarray:
727 """
728 Transform input data by scaling with encoding parameters.
730 Implements the input transformation as described in arXiv:2309.03279v2,
731 where inputs are linearly scaled by encoding parameters before being
732 used in the quantum circuit.
734 Args:
735 inputs (jnp.ndarray): Input data point of shape (n_input_feat,) or
736 (batch_size, n_input_feat).
737 enc_params (jnp.ndarray): Encoding weight scalar or vector used to
738 scale the input.
740 Returns:
741 jnp.ndarray: Transformed input, element-wise product of inputs
742 and enc_params.
743 """
744 return inputs * enc_params
746 def _iec(
747 self,
748 inputs: jnp.ndarray,
749 data_reupload: jnp.ndarray,
750 enc: Encoding,
751 enc_params: jnp.ndarray,
752 noise_params: Optional[Dict[str, Union[float, Dict[str, float]]]] = None,
753 random_key: Optional[random.PRNGKey] = None,
754 ) -> None:
755 """
756 Apply Input Encoding Circuit (IEC) with angle encoding.
758 Encodes classical input data into the quantum circuit using rotation
759 gates (e.g., RX, RY, RZ). Supports data re-uploading at specified
760 positions in the circuit.
762 For Golomb encoding, a single multi-qubit diagonal unitary is applied
763 to all qubits simultaneously instead of per-qubit rotation gates.
765 Args:
766 inputs (jnp.ndarray): Input data of shape (n_input_feat,) or
767 (batch_size, n_input_feat).
768 data_reupload (jnp.ndarray): Boolean array of shape (n_qubits, n_input_feat)
769 indicating where to apply encoding gates.
770 enc (Encoding): Encoding strategy containing the encoding gate functions.
771 enc_params (jnp.ndarray): Encoding parameters of shape
772 (n_qubits, n_input_feat) used to scale inputs.
773 noise_params (Optional[Dict[str, Union[float, Dict[str, float]]]]):
774 Noise parameters for gate-level noise simulation. Defaults to None.
775 random_key (Optional[random.PRNGKey]): JAX random key for stochastic
776 noise. Defaults to None.
778 Returns:
779 None: Gates are applied in-place to the quantum circuit.
780 """
781 # check for zero, because due to input validation, input cannot be none
782 if self.remove_zero_encoding and self._zero_inputs and self.batch_shape[0] == 1:
783 return
785 # --- Golomb encoding: single multi-qubit gate on all qubits --------
786 if enc.is_golomb:
787 idx = 0 # Golomb encoding supports a single input feature
788 # Check if any qubit has re-uploading enabled for this layer
789 if data_reupload[:, idx].any():
790 random_key, sub_key = safe_random_split(random_key)
791 # Use the mean of enc_params across qubits as scalar scaling
792 # (Golomb acts on all qubits jointly)
793 mean_enc_param = jnp.mean(enc_params[:, idx])
794 all_wires = list(range(self.n_qubits))
795 enc[idx](
796 self.transform_input(inputs[..., idx], mean_enc_param),
797 wires=all_wires,
798 noise_params=noise_params,
799 random_key=sub_key,
800 )
801 return
803 # --- Standard per-qubit encoding -----------------------------------
804 for q in range(self.n_qubits):
805 # use the last dimension of the inputs (feature dimension)
806 for idx in range(inputs.shape[-1]):
807 if data_reupload[q, idx]:
808 # use elipsis to indiex only the last dimension
809 # as inputs are generally *not* qubit dependent
810 random_key, sub_key = safe_random_split(random_key)
811 enc[idx](
812 self.transform_input(inputs[..., idx], enc_params[q, idx]),
813 wires=q,
814 noise_params=noise_params,
815 random_key=sub_key,
816 )
818 def _variational(
819 self,
820 params: jnp.ndarray,
821 inputs: jnp.ndarray,
822 pulse_params: Optional[jnp.ndarray] = None,
823 random_key: Optional[random.PRNGKey] = None,
824 enc_params: Optional[jnp.ndarray] = None,
825 gate_mode: str = "unitary",
826 noise_params: Optional[Dict[str, Union[float, Dict[str, float]]]] = None,
827 ) -> None:
828 """
829 Build the variational quantum circuit structure.
831 Constructs the circuit by applying state preparation, alternating
832 variational ansatz layers with input encoding layers, and optional
833 noise channels.
835 The first five parameters (after ``self``) - ``params``, ``inputs``,
836 ``pulse_params``, ``random_key``, ``enc_params`` - are the batchable
837 positional arguments.
838 The remaining keyword arguments are broadcast across the batch.
840 Args:
841 params (jnp.ndarray): Variational parameters of shape
842 (n_layers, n_params_per_layer).
843 inputs (jnp.ndarray): Input data of shape (n_input_feat,).
844 pulse_params (Optional[jnp.ndarray]): Pulse parameter scalers of shape
845 (n_layers, n_pulse_params_per_layer) for pulse-mode execution.
846 Defaults to None (uses model's pulse_params).
847 random_key (Optional[random.PRNGKey]): JAX random key for stochastic
848 operations. Defaults to None.
849 enc_params (Optional[jnp.ndarray]): Encoding parameters of shape
850 (n_qubits, n_input_feat). Defaults to None (uses model's enc_params).
851 gate_mode (str): Gate execution mode, either "unitary" or "pulse".
852 Defaults to "unitary".
853 noise_params (Optional[Dict[str, Union[float, Dict[str, float]]]]):
854 Noise parameters for simulation. Defaults to None.
856 Returns:
857 None: Gates are applied in-place to the quantum circuit.
859 Note:
860 Issues RuntimeWarning if called directly without providing parameters
861 that would normally be passed through the forward method.
862 """
863 # TODO: rework and double check params shape
864 if len(params.shape) > 2 and params.shape[0] == 1:
865 params = params[0]
867 if len(inputs.shape) > 1 and inputs.shape[0] == 1:
868 inputs = inputs[0]
870 if enc_params is None:
871 # TODO: Raise warning if trainable frequencies is True, or similar. I.e., no
872 # warning if user does not care for frequencies or enc_params
873 if self.trainable_frequencies:
874 warnings.warn(
875 "Explicit call to `_circuit` or `_variational` detected: "
876 "`enc_params` is None, using `self.enc_params` instead.",
877 RuntimeWarning,
878 )
879 enc_params = self.enc_params
881 if pulse_params is None:
882 if gate_mode == "pulse":
883 warnings.warn(
884 "Explicit call to `_circuit` or `_variational` detected: "
885 "`pulse_params` is None, using `self.pulse_params` instead.",
886 RuntimeWarning,
887 )
888 pulse_params = self.pulse_params
890 # Squeeze batch dimension for pulse_params (batch-first convention)
891 if len(pulse_params.shape) > 2 and pulse_params.shape[0] == 1:
892 pulse_params = pulse_params[0]
894 if noise_params is None:
895 if self.noise_params is not None:
896 warnings.warn(
897 "Explicit call to `_circuit` or `_variational` detected: "
898 "`noise_params` is None, using `self.noise_params` instead.",
899 RuntimeWarning,
900 )
901 noise_params = self.noise_params
903 if noise_params is not None:
904 if random_key is None:
905 warnings.warn(
906 "Explicit call to `_circuit` or `_variational` detected: "
907 "`random_key` is None, using `random.PRNGKey(0)` instead.",
908 RuntimeWarning,
909 )
910 random_key = self.random_key
911 self._apply_state_prep_noise(noise_params=noise_params)
913 # state preparation
914 for q in range(self.n_qubits):
915 for _sp, sp_pulse_params in zip(self._sp, self.sp_pulse_params):
916 random_key, sub_key = safe_random_split(random_key)
917 _sp(
918 wires=q,
919 pulse_params=sp_pulse_params,
920 noise_params=noise_params,
921 random_key=sub_key,
922 gate_mode=gate_mode,
923 )
925 # circuit building
926 for layer in range(0, self.n_layers):
927 random_key, sub_key = safe_random_split(random_key)
928 # ansatz layers
929 self.pqc(
930 params[layer],
931 self.n_qubits,
932 pulse_params=pulse_params[layer],
933 noise_params=noise_params,
934 random_key=sub_key,
935 gate_mode=gate_mode,
936 )
938 random_key, sub_key = safe_random_split(random_key)
939 # encoding layers
940 self._iec(
941 inputs,
942 data_reupload=self.data_reupload[layer],
943 enc=self._enc,
944 enc_params=enc_params[layer],
945 noise_params=noise_params,
946 random_key=sub_key,
947 )
949 # final ansatz layer
950 if self.has_dru: # same check as in init
951 random_key, sub_key = safe_random_split(random_key)
952 self.pqc(
953 params[self.n_layers],
954 self.n_qubits,
955 pulse_params=pulse_params[-1],
956 noise_params=noise_params,
957 random_key=sub_key,
958 gate_mode=gate_mode,
959 )
961 # channel noise
962 if noise_params is not None:
963 self._apply_general_noise(noise_params=noise_params)
def _build_obs(self) -> Tuple[str, List[op.Operation]]:
    """Map ``execution_type`` / ``output_qubit`` onto jaqsi execution settings.

    The returned pair is meant to be handed to
    :meth:`~qml_essentials.jaqsi.Script.execute`.

    Returns:
        A ``(meas_type, obs)`` tuple. ``meas_type`` is one of ``"expval"``,
        ``"probs"``, ``"density"`` or ``"state"``. ``obs`` is a list of
        :class:`Operation` observables and is only non-empty for
        ``"expval"``.

    Raises:
        ValueError: If ``execution_type`` is none of the four supported
            measurement types.
    """
    mode = self.execution_type

    if mode == "expval":
        # One observable per output spec: a bare wire index is measured with
        # PauliZ, a group of wires with a Z x Z x ... parity observable.
        observables: List[op.Operation] = [
            op.PauliZ(wires=spec)
            if isinstance(spec, int)
            else js.build_parity_observable(list(spec))
            for spec in self.output_qubit
        ]
        return "expval", observables

    # density / state return the simulated state itself; probs are taken on
    # the full register and marginalised onto the output qubits later in
    # ``_forward``. None of them needs observables.
    for kind in ("density", "state", "probs"):
        if mode == kind:
            return kind, []

    raise ValueError(f"Invalid execution_type: {self.execution_type}.")
1000 def _apply_state_prep_noise(
1001 self, noise_params: Dict[str, Union[float, Dict[str, float]]]
1002 ) -> None:
1003 """
1004 Apply state preparation noise to all qubits.
1006 Simulates imperfect state preparation by applying BitFlip errors
1007 to each qubit with the specified probability.
1009 Args:
1010 noise_params (Dict[str, Union[float, Dict[str, float]]]): Dictionary
1011 containing noise parameters. Uses the "StatePreparation" key
1012 for the BitFlip probability.
1014 Returns:
1015 None: Noise channels are applied in-place to the circuit.
1016 """
1017 p = noise_params.get("StatePreparation", 0.0)
1018 if p > 0:
1019 for q in range(self.n_qubits):
1020 op.BitFlip(p, wires=q)
1022 def _apply_general_noise(
1023 self, noise_params: Dict[str, Union[float, Dict[str, float]]]
1024 ) -> None:
1025 """
1026 Apply general noise channels to all qubits.
1028 Applies various decoherence and error channels after the circuit
1029 execution, simulating environmental noise effects.
1031 Args:
1032 noise_params (Dict[str, Union[float, Dict[str, float]]]): Dictionary
1033 containing noise parameters with the following supported keys:
1034 - "AmplitudeDamping" (float): Probability for amplitude damping.
1035 - "PhaseDamping" (float): Probability for phase damping.
1036 - "Measurement" (float): Probability for measurement error (BitFlip).
1037 - "ThermalRelaxation" (Dict): Dictionary with keys "t1", "t2",
1038 "t_factor" for thermal relaxation simulation.
1040 Returns:
1041 None: Noise channels are applied in-place to the circuit.
1043 Note:
1044 Gate-level noise (e.g., GateError) is handled separately in the
1045 Gates.Noise module and applied at the individual gate level.
1046 """
1047 amp_damp = noise_params.get("AmplitudeDamping", 0.0)
1048 phase_damp = noise_params.get("PhaseDamping", 0.0)
1049 thermal_relax = noise_params.get("ThermalRelaxation", 0.0)
1050 meas = noise_params.get("Measurement", 0.0)
1051 for q in range(self.n_qubits):
1052 if amp_damp > 0:
1053 op.AmplitudeDamping(amp_damp, wires=q)
1054 if phase_damp > 0:
1055 op.PhaseDamping(phase_damp, wires=q)
1056 if meas > 0:
1057 op.BitFlip(meas, wires=q)
1058 if isinstance(thermal_relax, dict):
1059 t1 = thermal_relax["t1"]
1060 t2 = thermal_relax["t2"]
1061 t_factor = thermal_relax["t_factor"]
1062 circuit_depth = self._get_circuit_depth()
1063 tg = circuit_depth * t_factor
1064 op.ThermalRelaxationError(1.0, t1, t2, tg, q)
1066 def _get_circuit_depth(self, inputs: Optional[jnp.ndarray] = None) -> int:
1067 """
1068 Calculate the depth of the quantum circuit.
1070 Records the circuit onto a tape (without noise) and computes the
1071 depth as the length of the critical path: each gate is scheduled
1072 at the earliest time step after all of its qubits are free.
1074 Args:
1075 inputs (Optional[jnp.ndarray]): Input data for circuit evaluation.
1076 If None, default zero inputs are used.
1078 Returns:
1079 int: The circuit depth (longest path of gates in the circuit).
1080 """
1081 # Return cached value if available
1082 if hasattr(self, "_cached_circuit_depth"):
1083 return self._cached_circuit_depth
1085 inputs = self._inputs_validation(inputs)
1087 # Temporarily clear noise_params to prevent _variational from
1088 # picking them up (which would call _apply_general_noise ->
1089 # _get_circuit_depth again, causing infinite recursion).
1090 saved_noise = self._noise_params
1091 self._noise_params = None
1093 with recording() as tape:
1094 self._variational(
1095 self.params[0] if self.params.ndim == 3 else self.params,
1096 inputs[0] if inputs.ndim == 2 else inputs,
1097 noise_params=None,
1098 )
1100 self._noise_params = saved_noise
1102 # Filter out noise channels - only count unitary gates
1103 ops = [o for o in tape if not isinstance(o, KrausChannel)]
1105 if not ops:
1106 self._cached_circuit_depth = 0
1107 return 0
1109 # Schedule each gate at the earliest time step where all its wires
1110 # are free. ``wire_busy[q]`` tracks the next free time step for
1111 # qubit ``q``.
1112 wire_busy: Dict[int, int] = {}
1113 depth = 0
1114 for gate in ops:
1115 start = max((wire_busy.get(w, 0) for w in gate.wires), default=0)
1116 end = start + 1
1117 for w in gate.wires:
1118 wire_busy[w] = end
1119 depth = max(depth, end)
1121 self._cached_circuit_depth = depth
1122 return depth
def draw(
    self,
    inputs: Optional[jnp.ndarray] = None,
    figure: str = "text",
    **kwargs: Any,
) -> Union[str, Any]:
    """Visualize the quantum circuit.

    Records the circuit tape (without noise) and renders the gate
    sequence using the requested backend.

    Args:
        inputs (Optional[jnp.ndarray]): Input data for the circuit.
            If ``None``, default zero inputs are used.
        figure (str): Rendering backend. One of:

            * ``"text"`` - ASCII art (returned as a ``str``).
            * ``"mpl"`` - Matplotlib figure (returns ``(fig, ax)``).
            * ``"tikz"`` - LaTeX/TikZ ``quantikz`` code (returns a
              :class:`TikzFigure`).
            * ``"pulse"`` - Pulse schedule (returns ``(fig, axes)``).
              Only meaningful for pulse-mode models; delegates to
              :meth:`draw_pulse`.

        **kwargs: Extra options forwarded to the drawing backend
            (e.g. ``gate_values=True``).

    Returns:
        Depends on figure:

        * ``"text"`` -> ``str``
        * ``"mpl"`` -> ``(matplotlib.figure.Figure, matplotlib.axes.Axes)``
        * ``"tikz"`` -> :class:`TikzFigure`
        * ``"pulse"`` -> ``(fig, axes)`` as returned by :meth:`draw_pulse`

    Raises:
        ValueError: If figure is not one of the supported modes
            (presumably raised by the jaqsi drawing backend).
    """
    if figure == "pulse":
        # draw_pulse validates the inputs itself.
        return self.draw_pulse(inputs=inputs, **kwargs)

    inputs = self._inputs_validation(inputs)
    params = self.params[0] if self.params.ndim == 3 else self.params
    inp = inputs[0] if inputs.ndim == 2 else inputs

    # Record without noise to get a clean circuit. Passing noise_params=None
    # alone is not enough, because _variational then falls back to the
    # stored noise. Always restore it, even if drawing fails.
    saved_noise = self._noise_params
    self._noise_params = None
    try:
        draw_script = js.Script(f=self._variational, n_qubits=self.n_qubits)
        return draw_script.draw(
            figure=figure,
            args=(params, inp),
            kwargs={"noise_params": None},
            **kwargs,
        )
    finally:
        self._noise_params = saved_noise
def draw_pulse(
    self,
    inputs: Optional[jnp.ndarray] = None,
    **kwargs: Any,
) -> Any:
    """Visualize the pulse schedule for the circuit.

    Records the circuit in pulse mode (without noise), collects the
    PulseEvents via the pulse-event tape, and renders them.

    Args:
        inputs: Input data. If ``None``, default zero inputs are used.
        **kwargs: Forwarded to
            :func:`~qml_essentials.drawing.draw_pulse_schedule`
            (e.g. ``show_carrier=True``, ``n_samples=300``).

    Returns:
        ``(fig, axes)`` — Matplotlib Figure and array of Axes.
    """
    inputs = self._inputs_validation(inputs)
    params = self.params[0] if self.params.ndim == 3 else self.params
    inp = inputs[0] if inputs.ndim == 2 else inputs

    # Passing ``noise_params=None`` alone does not disable noise:
    # _variational then falls back to ``self.noise_params``. Clear the
    # stored noise while recording and always restore it afterwards.
    saved_noise = self._noise_params
    self._noise_params = None
    try:
        draw_script = js.Script(f=self._variational, n_qubits=self.n_qubits)
        return draw_script.draw(
            figure="pulse",
            args=(params, inp),
            kwargs={
                "gate_mode": "pulse",
                "noise_params": None,
            },
            **kwargs,
        )
    finally:
        self._noise_params = saved_noise
1216 def __repr__(self) -> str:
1217 """Return text representation of the quantum circuit model."""
1218 return self.draw(figure="text")
1220 def __str__(self) -> str:
1221 """Return string representation of the quantum circuit model."""
1222 return self.draw(figure="text")
1224 def _params_validation(self, params: Optional[jnp.ndarray]) -> jnp.ndarray:
1225 """
1226 Validate and normalize variational parameters.
1228 Ensures parameters have the correct shape with a batch dimension,
1229 and updates the model's internal parameters if new ones are provided.
1231 Args:
1232 params (Optional[jnp.ndarray]): Variational parameters to validate.
1233 If None, returns the model's current parameters.
1235 Returns:
1236 jnp.ndarray: Validated parameters with shape
1237 (batch_size, n_layers, n_params_per_layer).
1238 """
1239 # append batch axis if not provided
1240 if params is not None:
1241 if len(params.shape) == 2:
1242 params = np.expand_dims(params, axis=0)
1244 # Avoid stashing JAX tracers on ``self``: under an outer
1245 # transform (e.g. ``jacrev``) the tracer becomes invalid once
1246 # the transform returns, and a subsequent read of
1247 # ``self.params`` would feed a leaked tracer into the next
1248 # call (raising ``UnexpectedTracerError``).
1249 # if not isinstance(params, jax.core.Tracer):
1250 # self.params = params
1251 self.params = params
1252 else:
1253 params = self.params
1255 return params
1257 def _pulse_params_validation(
1258 self, pulse_params: Optional[jnp.ndarray]
1259 ) -> jnp.ndarray:
1260 """
1261 Validate and normalize pulse parameters.
1263 Ensures pulse parameters are set, using model defaults if not provided.
1265 Args:
1266 pulse_params (Optional[jnp.ndarray]): Pulse parameter scalers.
1267 If None, returns the model's current pulse parameters.
1269 Returns:
1270 jnp.ndarray: Validated pulse parameters with shape
1271 (batch_size, n_layers, n_pulse_params_per_layer).
1272 """
1273 if pulse_params is None:
1274 pulse_params = self.pulse_params
1275 else:
1276 # ensure batch dimension exists (batch-first convention)
1277 if len(pulse_params.shape) == 2:
1278 pulse_params = jnp.expand_dims(pulse_params, axis=0)
1279 # See note in _params_validation: never stash JAX tracers on
1280 # ``self``.
1281 # if not isinstance(pulse_params, jax.core.Tracer):
1282 # self.pulse_params = pulse_params
1283 self.pulse_params = pulse_params
1285 return pulse_params
1287 def _enc_params_validation(self, enc_params: Optional[jnp.ndarray]) -> jnp.ndarray:
1288 """
1289 Validate and normalize encoding parameters.
1291 Ensures encoding parameters have the correct shape for the model's
1292 input feature dimensions.
1294 Args:
1295 enc_params (Optional[jnp.ndarray]): Encoding parameters to validate.
1296 If None, returns the model's current encoding parameters.
1298 Returns:
1299 jnp.ndarray: Validated encoding parameters with shape
1300 (n_qubits, n_input_feat).
1302 Raises:
1303 ValueError: If enc_params shape is incompatible with n_input_feat > 1.
1304 """
1305 if enc_params is None:
1306 enc_params = self.enc_params
1307 else:
1308 # See note in _params_validation: never stash JAX tracers on
1309 # ``self``.
1310 # if not isinstance(enc_params, jax.core.Tracer):
1311 # if self.trainable_frequencies:
1312 # self.enc_params = enc_params
1313 # else:
1314 # self.enc_params = jnp.array(enc_params)
1315 if self.trainable_frequencies:
1316 self.enc_params = enc_params
1317 else:
1318 self.enc_params = jnp.array(enc_params)
1320 if len(enc_params.shape) == 1 and self.n_input_feat == 1:
1321 enc_params = enc_params.reshape(-1, 1)
1322 elif len(enc_params.shape) == 1 and self.n_input_feat > 1:
1323 raise ValueError(
1324 f"Input dimension {self.n_input_feat} >1 but \
1325 `enc_params` has shape {enc_params.shape}"
1326 )
1328 return enc_params
1330 def _inputs_validation(
1331 self, inputs: Union[None, List, float, int, jnp.ndarray]
1332 ) -> jnp.ndarray:
1333 """
1334 Validate and normalize input data.
1336 Converts various input formats to a standardized 2D array shape
1337 suitable for batch processing in the quantum circuit.
1339 Args:
1340 inputs (Union[None, List, float, int, jnp.ndarray]): Input data in
1341 various formats:
1342 - None: Returns zeros with shape (1, n_input_feat)
1343 - float/int: Single scalar value
1344 - List: List of values or batched inputs
1345 - jnp.ndarray: NumPy/JAX array
1347 Returns:
1348 jnp.ndarray: Validated inputs with shape (batch_size, n_input_feat).
1350 Raises:
1351 ValueError: If input shape is incompatible with expected n_input_feat.
1353 Warns:
1354 UserWarning: If input is replicated to match n_input_feat.
1355 """
1356 self._zero_inputs = False
1357 if isinstance(inputs, List):
1358 inputs = jnp.array(np.stack(inputs))
1359 elif isinstance(inputs, float) or isinstance(inputs, int):
1360 inputs = jnp.array([inputs])
1361 elif inputs is None:
1362 inputs = jnp.array([[0] * self.n_input_feat])
1364 if not inputs.any():
1365 self._zero_inputs = True
1367 if len(inputs.shape) <= 1:
1368 if self.n_input_feat == 1:
1369 # add a batch dimension
1370 inputs = inputs.reshape(-1, 1)
1371 else:
1372 if inputs.shape[0] == self.n_input_feat:
1373 inputs = inputs.reshape(1, -1)
1374 else:
1375 inputs = inputs.reshape(-1, 1)
1376 inputs = inputs.repeat(self.n_input_feat, axis=1)
1377 warnings.warn(
1378 f"Expected {self.n_input_feat} inputs, but {inputs.shape[0]} "
1379 "was provided, replicating input for all input features.",
1380 UserWarning,
1381 )
1382 else:
1383 if inputs.shape[1] != self.n_input_feat:
1384 raise ValueError(
1385 f"Wrong number of inputs provided. Expected {self.n_input_feat} "
1386 f"inputs, but input has shape {inputs.shape}."
1387 )
1389 return inputs
1391 def _postprocess_res(self, result: Union[List, jnp.ndarray]) -> jnp.ndarray:
1392 """
1393 Post-process circuit execution results for uniform shape.
1395 Converts list outputs (from multiple measurements) to stacked arrays
1396 and reorders axes for consistent batch dimension placement.
1398 Args:
1399 result (Union[List, jnp.ndarray]): Raw circuit output, either a
1400 list of measurement results or a single array.
1402 Returns:
1403 jnp.ndarray: Uniformly shaped result array with batch dimension first.
1404 """
1405 if isinstance(result, list):
1406 # we use moveaxis here because in case of parity measure,
1407 # there is another dimension appended to the end and
1408 # simply transposing would result in a wrong shape
1409 result = jnp.stack(result)
1410 if len(result.shape) > 1:
1411 result = jnp.moveaxis(result, 0, 1)
1412 return result
1414 def _assimilate_batch(
1415 self,
1416 inputs: jnp.ndarray,
1417 params: jnp.ndarray,
1418 pulse_params: jnp.ndarray,
1419 ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
1420 """
1421 Align batch dimensions across inputs, parameters, and pulse parameters.
1423 Broadcasts and reshapes arrays to have compatible batch dimensions
1424 for vectorized circuit execution. Sets the internal batch_shape.
1426 Args:
1427 inputs (jnp.ndarray): Input data of shape (B_I, n_input_feat).
1428 params (jnp.ndarray): Parameters of shape (B_P, n_layers, n_params).
1429 pulse_params (jnp.ndarray): Pulse params of shape (B_R, n_layers, n_pulse).
1431 Returns:
1432 Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: Tuple containing:
1433 - inputs: Reshaped to (B, n_input_feat) where B = B_I * B_P * B_R
1434 - params: Reshaped to (B, n_layers, n_params)
1435 - pulse_params: Reshaped to (B, n_layers, n_pulse)
1437 Note:
1438 The effective batch shape depends on repeat_batch_axis configuration.
1439 This is the only method that sets self._batch_shape.
1440 """
1441 B_I = inputs.shape[0]
1442 # we check for the product because there is a chance that
1443 # there are no params. In this case we want B_P to be 1
1444 B_P = 1 if 0 in params.shape else params.shape[0]
1445 B_R = pulse_params.shape[0]
1447 # THIS is the only place where we set the batch shape
1448 self._batch_shape = (B_I, B_P, B_R)
1449 B = np.prod(self.eff_batch_shape)
1451 # [B_I, ...] -> [B_I, B_P, B_R, ...] -> [B, ...]
1452 if B_I > 1 and self.repeat_batch_axis[0]:
1453 if self.repeat_batch_axis[1]:
1454 inputs = jnp.repeat(inputs[:, None, None, ...], B_P, axis=1)
1455 if self.repeat_batch_axis[2]:
1456 inputs = jnp.repeat(inputs, B_R, axis=2)
1457 inputs = inputs.reshape(B, *inputs.shape[3:])
1459 # [B_P, ..., ...] -> [B_I, B_P, B_R, ..., ...] -> [B, ..., ...]
1460 if B_P > 1 and self.repeat_batch_axis[1]:
1461 # add B_I axis before first, and B_R axis after first batch dim
1462 params = params[None, :, None, ...] # [B_I(=1), B_P, B_R(=1), ...]
1463 if self.repeat_batch_axis[0]:
1464 params = jnp.repeat(params, B_I, axis=0) # [B_I, B_P, 1, ...]
1465 if self.repeat_batch_axis[2]:
1466 params = jnp.repeat(params, B_R, axis=2) # [B_I, B_P, B_R, ...]
1467 params = params.reshape(B, *params.shape[3:])
1469 # [B_R, ..., ...] -> [B_I, B_P, B_R, ..., ...] -> [B, ..., ...]
1470 if B_R > 1 and self.repeat_batch_axis[2]:
1471 # add B_I axis and B_P axis before B_R
1472 pulse_params = pulse_params[None, None, ...] # [B_I(=1), B_P(=1), B_R, ...]
1473 if self.repeat_batch_axis[0]:
1474 pulse_params = jnp.repeat(
1475 pulse_params, B_I, axis=0
1476 ) # [B_I, 1, B_R, ...]
1477 if self.repeat_batch_axis[1]:
1478 pulse_params = jnp.repeat(
1479 pulse_params, B_P, axis=1
1480 ) # [B_I, B_P, B_R, ...]
1481 pulse_params = pulse_params.reshape(B, *pulse_params.shape[3:])
1483 return inputs, params, pulse_params
1485 def _requires_density(self) -> bool:
1486 """
1487 Check if density matrix simulation is required.
1489 Determines whether the circuit must be executed with the mixed-state
1490 simulator based on execution type and noise configuration.
1492 Returns:
1493 bool: True if density matrix simulation is required, False otherwise.
1494 Returns True if:
1495 - execution_type is "density", or
1496 - Any non-coherent noise channel has non-zero probability
1497 """
1498 if self.execution_type == "density":
1499 return True
1501 if self.noise_params is None:
1502 return False
1504 coherent_noise = {"GateError"}
1505 for k, v in self.noise_params.items():
1506 if k in coherent_noise:
1507 continue
1508 if v is not None and v > 0:
1509 return True
1510 return False
1512 def __call__(
1513 self,
1514 params: Optional[jnp.ndarray] = None,
1515 inputs: Optional[jnp.ndarray] = None,
1516 pulse_params: Optional[jnp.ndarray] = None,
1517 enc_params: Optional[jnp.ndarray] = None,
1518 data_reupload: Union[bool, List[List[bool]], List[List[List[bool]]]] = None,
1519 noise_params: Optional[Dict[str, Union[float, Dict[str, float]]]] = None,
1520 execution_type: Optional[str] = None,
1521 force_mean: bool = False,
1522 gate_mode: str = "unitary",
1523 ) -> jnp.ndarray:
1524 """
1525 Execute the quantum circuit (callable interface).
1527 Provides a convenient callable interface for circuit execution,
1528 delegating to the _forward method.
1530 Args:
1531 params (Optional[jnp.ndarray]): Variational parameters of shape
1532 (n_layers, n_params_per_layer) or (batch, n_layers, n_params_per_layer).
1533 If None, uses model's internal parameters.
1534 inputs (Optional[jnp.ndarray]): Input data of shape
1535 (batch_size, n_input_feat). If None, uses zero inputs.
1536 pulse_params (Optional[jnp.ndarray]): Pulse parameter scalers for
1537 pulse-mode gate execution.
1538 enc_params (Optional[jnp.ndarray]): Encoding parameters of shape
1539 (n_qubits, n_input_feat). If None, uses model's encoding parameters.
1540 data_reupload (Union[bool, List[List[bool]], List[List[List[bool]]]]):
1541 Data reupload configuration. If None, uses previously set reupload
1542 configuration.
1543 noise_params (Optional[Dict[str, Union[float, Dict[str, float]]]]):
1544 Noise configuration. If None, uses previously set noise parameters.
1545 execution_type (Optional[str]): Measurement type: "expval", "density",
1546 "probs", or "state". If None, uses current execution_type setting.
1547 force_mean (bool): If True, averages results over measurement qubits.
1548 Defaults to False.
1549 gate_mode (str): Gate execution backend, "unitary" or "pulse".
1550 Defaults to "unitary".
1552 Returns:
1553 jnp.ndarray: Circuit output with shape depending on execution_type:
1554 - "expval": (n_output_qubits,) or scalar
1555 - "density": (2^n_output, 2^n_output)
1556 - "probs": (2^n_output,) or (n_pairs, 2^pair_size)
1557 - "state": (2^n_qubits,)
1558 """
1559 # Call forward method which handles the actual caching etc.
1560 return self._forward(
1561 params=params,
1562 inputs=inputs,
1563 pulse_params=pulse_params,
1564 enc_params=enc_params,
1565 data_reupload=data_reupload,
1566 noise_params=noise_params,
1567 execution_type=execution_type,
1568 force_mean=force_mean,
1569 gate_mode=gate_mode,
1570 )
1572 def _forward(
1573 self,
1574 params: Optional[jnp.ndarray] = None,
1575 inputs: Optional[jnp.ndarray] = None,
1576 pulse_params: Optional[jnp.ndarray] = None,
1577 enc_params: Optional[jnp.ndarray] = None,
1578 data_reupload: Union[bool, List[List[bool]], List[List[List[bool]]]] = None,
1579 noise_params: Optional[Dict[str, Union[float, Dict[str, float]]]] = None,
1580 execution_type: Optional[str] = None,
1581 force_mean: bool = False,
1582 gate_mode: str = "unitary",
1583 ) -> jnp.ndarray:
1584 """
1585 Execute the quantum circuit forward pass.
1587 Internal implementation of the forward pass that handles parameter
1588 validation, batch alignment, and circuit execution routing.
1590 Args:
1591 params (Optional[jnp.ndarray]): Variational parameters of shape
1592 (n_layers, n_params_per_layer) or
1593 (batch, n_layers, n_params_per_layer).
1594 If None, uses model's internal parameters.
1595 inputs (Optional[jnp.ndarray]): Input data of shape
1596 (batch_size, n_input_feat).
1597 If None, uses zero inputs.
1598 pulse_params (Optional[jnp.ndarray]): Pulse parameter scalers for
1599 pulse-mode gate execution.
1600 enc_params (Optional[jnp.ndarray]): Encoding parameters of shape
1601 (n_qubits, n_input_feat). If None, uses model's encoding parameters.
1602 data_reupload (Union[bool, List[List[bool]], List[List[List[bool]]]]):
1603 Data reupload configuration. If None, uses previously set reupload
1604 configuration.
1605 noise_params (Optional[Dict[str, Union[float, Dict[str, float]]]]):
1606 Noise configuration. If None, uses previously set noise parameters.
1607 execution_type (Optional[str]): Measurement type: "expval", "density",
1608 "probs", or "state". If None, uses current execution_type setting.
1609 force_mean (bool): If True, averages results over measurement qubits.
1610 Defaults to False.
1611 gate_mode (str): Gate execution backend, "unitary" or "pulse".
1612 Defaults to "unitary".
1614 Returns:
1615 jnp.ndarray: Circuit output with shape depending on execution_type:
1616 - "expval": (n_output_qubits,) or scalar
1617 - "density": (2^n_output, 2^n_output)
1618 - "probs": (2^n_output,) or (n_pairs, 2^pair_size)
1619 - "state": (2^n_qubits,)
1621 Raises:
1622 ValueError: If pulse_params provided without pulse gate_mode, or
1623 if noise_params provided with pulse gate_mode.
1624 """
1625 # set the parameters as object attributes
1626 if noise_params is not None:
1627 self.noise_params = noise_params
1628 if execution_type is not None:
1629 self.execution_type = execution_type
1630 self.gate_mode = gate_mode
1632 # consistency checks
1633 if pulse_params is not None and gate_mode != "pulse":
1634 raise ValueError(
1635 "pulse_params were provided but gate_mode is not 'pulse'. "
1636 "Either switch gate_mode='pulse' or do not pass pulse_params."
1637 )
1639 # TODO: add testing
1640 if data_reupload is not None:
1641 self.data_reupload = data_reupload
1643 params = self._params_validation(params)
1644 pulse_params = self._pulse_params_validation(pulse_params)
1645 inputs = self._inputs_validation(inputs)
1646 enc_params = self._enc_params_validation(enc_params)
1648 inputs, params, pulse_params = self._assimilate_batch(
1649 inputs,
1650 params,
1651 pulse_params,
1652 )
1654 # split to generate a sub_key, required for actual execution
1655 self.random_key, sub_key = safe_random_split(self.random_key)
1657 # Build measurement type & observables from execution_type / output_qubit
1658 meas_type, obs = self._build_obs()
1660 # Jaqsi auto-routes between statevector and density-matrix simulation
1661 # based on whether noise channels appear on the tape, so a single
1662 B = np.prod(self.eff_batch_shape)
1664 # kwargs are broadcast (not vmapped over)
1665 exec_kwargs = dict(
1666 noise_params=self.noise_params,
1667 gate_mode=self.gate_mode,
1668 )
1670 # Build a shot key from the random_key if shots are requested
1671 shot_key = None
1672 if self.shots is not None:
1673 # overwrite subkey and split shot_key
1674 sub_key, shot_key = safe_random_split(sub_key)
1676 if B > 1:
1677 # use random keys, derived from the subkey
1678 random_keys = safe_random_split(sub_key, num=B)
1680 in_axes = (
1681 0 if self.batch_shape[1] > 1 else None, # params
1682 0 if self.batch_shape[0] > 1 else None, # inputs
1683 0 if self.batch_shape[2] > 1 else None, # pulse_params
1684 0, # random_keys
1685 None, # enc_params (broadcast, not batched)
1686 )
1688 result = self.script.execute(
1689 type=meas_type,
1690 obs=obs,
1691 args=(params, inputs, pulse_params, random_keys, enc_params),
1692 kwargs=exec_kwargs,
1693 in_axes=in_axes,
1694 shots=self.shots,
1695 key=shot_key,
1696 )
1697 else:
1698 # use the subkey directly
1699 result = self.script.execute(
1700 type=meas_type,
1701 obs=obs,
1702 args=(params, inputs, pulse_params, sub_key, enc_params),
1703 kwargs=exec_kwargs,
1704 shots=self.shots,
1705 key=shot_key,
1706 )
1708 result = self._postprocess_res(result)
1710 # --- Post-processing for partial-qubit measurements ---------------
1711 if self.execution_type == "density" and not self.all_qubit_measurement:
1712 result = js.partial_trace(result, self.n_qubits, self.output_qubit)
1714 if self.execution_type == "probs" and not self.all_qubit_measurement:
1715 if isinstance(self.output_qubit[0], (list, tuple)):
1716 # list of qubit groups - marginalize each independently
1717 result = jnp.stack(
1718 [
1719 js.marginalize_probs(result, self.n_qubits, list(group))
1720 for group in self.output_qubit
1721 ]
1722 )
1723 else:
1724 result = js.marginalize_probs(result, self.n_qubits, self.output_qubit)
1726 result = jnp.asarray(result)
1727 result = result.reshape((*self.eff_batch_shape, *self._result_shape)).squeeze()
1729 if (
1730 self.execution_type in ("expval", "probs")
1731 and force_mean
1732 and len(result.shape) > 0
1733 and self._result_shape[0] > 1
1734 ):
1735 result = result.mean(axis=-1)
1737 return result