Skip to content

References

Ansaetze

from qml_essentials.ansaetze import Ansaetze
Source code in qml_essentials/ansaetze.py
class Ansaetze:
    """
    Registry of the parameterized circuit templates (ansaetze) in this module.

    Every nested class derives from ``DeclarativeCircuit``. Most of them only
    override ``structure()``, which returns a tuple of ``Block`` specifications:
    a gate plus optional ``topology``/``offset``/``span``/``wrap``/``reverse``/
    ``mirror``/``modulo`` options. How those options map to concrete wires is
    decided by ``Block``/``Topology``, which are defined elsewhere.
    """

    @staticmethod
    def get_available():
        """
        Return the ansatz classes exposed by this registry.

        Declared as a ``staticmethod`` so it works both as
        ``Ansaetze.get_available()`` and when called on an instance.
        Without the decorator, the instance call raises ``TypeError``.
        Note that ``Circuit_11`` and ``Circuit_12`` are not defined here.

        Returns:
            list: The nested ``DeclarativeCircuit`` subclasses, in a fixed order.
        """
        return [
            Ansaetze.No_Ansatz,
            Ansaetze.Circuit_1,
            Ansaetze.Circuit_2,
            Ansaetze.Circuit_3,
            Ansaetze.Circuit_4,
            Ansaetze.Circuit_5,
            Ansaetze.Circuit_6,
            Ansaetze.Circuit_7,
            Ansaetze.Circuit_8,
            Ansaetze.Circuit_9,
            Ansaetze.Circuit_10,
            Ansaetze.Circuit_13,
            Ansaetze.Circuit_14,
            Ansaetze.Circuit_15,
            Ansaetze.Circuit_16,
            Ansaetze.Circuit_17,
            Ansaetze.Circuit_18,
            Ansaetze.Circuit_19,
            Ansaetze.Circuit_20,
            Ansaetze.No_Entangling,
            Ansaetze.Strongly_Entangling,
            Ansaetze.Hardware_Efficient,
            Ansaetze.GHZ,
        ]

    class No_Ansatz(DeclarativeCircuit):
        """Empty ansatz: its structure contains no blocks."""

        @staticmethod
        def structure():
            return ()

    class GHZ(DeclarativeCircuit):
        """GHZ-style preparation: a Hadamard followed by a chain of CX gates."""

        @staticmethod
        def structure():
            # NOTE(review): this lists a plain Block(gate=Gates.H) while `build`
            # below applies H to wire 0 only; confirm which one callers rely on.
            return (
                Block(gate=Gates.H),
                Block(gate=Gates.CX, topology=Topology.stairs, reverse=True),
            )

        @staticmethod
        def build(w: np.ndarray, n_qubits: int, **kwargs):
            """
            Apply H on wire 0, then CX(q, q + 1) for q = 0 .. n_qubits - 2.

            Args:
                w: Unused here; kept for interface compatibility.
                n_qubits: Number of wires in the chain.
                **kwargs: Forwarded unchanged to every gate call.
            """
            Gates.H(wires=0, **kwargs)
            for q in range(n_qubits - 1):
                Gates.CX(wires=[q, q + 1], **kwargs)

        @staticmethod
        def n_pulse_params_per_layer(n_qubits: int) -> int:
            """
            Return the number of pulse parameters one GHZ layer needs.

            The count is the parameters of one H plus (n_qubits - 1) CX gates.
            """
            # NOTE(review): "H" is passed as a string but CX as a gate object;
            # confirm that PulseInformation.num_params accepts both forms.
            n_params = PulseInformation.num_params("H")  # only 1 H
            n_params += (n_qubits - 1) * PulseInformation.num_params(Gates.CX)
            return n_params

    class Circuit_1(DeclarativeCircuit):
        """RX and RZ rotation layers, no entanglers."""

        @staticmethod
        def structure():
            return (
                Block(gate=Gates.RX),
                Block(gate=Gates.RZ),
            )

    class Circuit_2(DeclarativeCircuit):
        """RX, RZ, then CX gates arranged with the stairs topology."""

        @staticmethod
        def structure():
            return (
                Block(gate=Gates.RX),
                Block(gate=Gates.RZ),
                Block(
                    gate=Gates.CX,
                    topology=Topology.stairs,
                ),
            )

    class Circuit_3(DeclarativeCircuit):
        """RX, RZ, then CRZ gates arranged with the stairs topology."""

        @staticmethod
        def structure():
            return (
                Block(gate=Gates.RX),
                Block(gate=Gates.RZ),
                Block(gate=Gates.CRZ, topology=Topology.stairs),
            )

    class Circuit_4(DeclarativeCircuit):
        """RX, RZ, then CRX gates arranged with the stairs topology."""

        @staticmethod
        def structure():
            return (
                Block(gate=Gates.RX),
                Block(gate=Gates.RZ),
                Block(gate=Gates.CRX, topology=Topology.stairs),
            )

    class Circuit_5(DeclarativeCircuit):
        """RX/RZ, all-to-all CRZ, then RX/RZ again."""

        @staticmethod
        def structure():
            return (
                Block(gate=Gates.RX),
                Block(gate=Gates.RZ),
                Block(gate=Gates.CRZ, topology=Topology.all_to_all),
                Block(gate=Gates.RX),
                Block(gate=Gates.RZ),
            )

    class Circuit_6(DeclarativeCircuit):
        """RX/RZ, all-to-all CRX, then RX/RZ again."""

        @staticmethod
        def structure():
            return (
                Block(gate=Gates.RX),
                Block(gate=Gates.RZ),
                Block(gate=Gates.CRX, topology=Topology.all_to_all),
                Block(gate=Gates.RX),
                Block(gate=Gates.RZ),
            )

    class Circuit_7(DeclarativeCircuit):
        """Two RX/RZ + bricks-CRZ rounds; the second entangler uses offset=1."""

        @staticmethod
        def structure():
            return (
                Block(gate=Gates.RX),
                Block(gate=Gates.RZ),
                Block(
                    gate=Gates.CRZ,
                    topology=Topology.bricks,
                ),
                Block(gate=Gates.RX),
                Block(gate=Gates.RZ),
                Block(
                    gate=Gates.CRZ,
                    topology=Topology.bricks,
                    offset=1,
                ),
            )

    class Circuit_8(DeclarativeCircuit):
        """Two RX/RZ + bricks-CRX rounds; the second entangler uses offset=1."""

        @staticmethod
        def structure():
            return (
                Block(gate=Gates.RX),
                Block(gate=Gates.RZ),
                Block(
                    gate=Gates.CRX,
                    topology=Topology.bricks,
                ),
                Block(gate=Gates.RX),
                Block(gate=Gates.RZ),
                Block(
                    gate=Gates.CRX,
                    topology=Topology.bricks,
                    offset=1,
                ),
            )

    class Circuit_9(DeclarativeCircuit):
        """H layer, stairs CZ, then RX."""

        @staticmethod
        def structure():
            # The entangler is given by name ("CZ") rather than Gates.CZ;
            # presumably Block resolves both forms.
            return (
                Block(gate=Gates.H),
                Block(gate="CZ", topology=Topology.stairs),
                Block(gate=Gates.RX),
            )

    class Circuit_10(DeclarativeCircuit):
        """RY, wrapped stairs CZ with offset -1, then RY."""

        @staticmethod
        def structure():
            return (
                Block(gate=Gates.RY),
                Block(gate="CZ", topology=Topology.stairs, offset=-1, wrap=True),
                Block(gate=Gates.RY),
            )

    class Circuit_13(DeclarativeCircuit):
        """RY + wrapped reversed CRZ stairs, RY + CRZ stairs with span 3."""

        @staticmethod
        def structure():
            return (
                Block(gate=Gates.RY),
                Block(
                    gate=Gates.CRZ,
                    topology=Topology.stairs,
                    wrap=True,
                    reverse=True,
                    mirror=False,
                ),
                Block(gate=Gates.RY),
                Block(
                    gate=Gates.CRZ,
                    topology=Topology.stairs,
                    reverse=False,
                    mirror=False,
                    # offset is a function of the qubit count n
                    offset=lambda n: n - 1,
                    span=3,
                    wrap=True,
                ),
            )

    class Circuit_14(DeclarativeCircuit):
        """Like Circuit_13 but with CRX entanglers."""

        @staticmethod
        def structure():
            return (
                Block(gate=Gates.RY),
                Block(
                    gate=Gates.CRX,
                    topology=Topology.stairs,
                    wrap=True,
                    reverse=True,
                    mirror=False,
                ),
                Block(gate=Gates.RY),
                Block(
                    gate=Gates.CRX,
                    topology=Topology.stairs,
                    reverse=False,
                    mirror=False,
                    offset=lambda n: n - 1,
                    span=3,
                    wrap=True,
                ),
            )

    class Circuit_15(DeclarativeCircuit):
        """Like Circuit_13 but with CX entanglers."""

        @staticmethod
        def structure():
            return (
                Block(gate=Gates.RY),
                Block(
                    gate=Gates.CX,
                    topology=Topology.stairs,
                    wrap=True,
                    reverse=True,
                    mirror=False,
                ),
                Block(gate=Gates.RY),
                Block(
                    gate=Gates.CX,
                    topology=Topology.stairs,
                    reverse=False,
                    mirror=False,
                    offset=lambda n: n - 1,
                    span=3,
                    wrap=True,
                ),
            )

    class Circuit_16(DeclarativeCircuit):
        """RX, RZ, then two bricks-CRZ layers (second with offset=1)."""

        @staticmethod
        def structure():
            return (
                Block(gate=Gates.RX),
                Block(gate=Gates.RZ),
                Block(
                    gate=Gates.CRZ,
                    topology=Topology.bricks,
                ),
                Block(
                    gate=Gates.CRZ,
                    topology=Topology.bricks,
                    offset=1,
                ),
            )

    class Circuit_17(DeclarativeCircuit):
        """RX, RZ, then two bricks-CRX layers (second with offset=1)."""

        @staticmethod
        def structure():
            return (
                Block(gate=Gates.RX),
                Block(gate=Gates.RZ),
                Block(
                    gate=Gates.CRX,
                    topology=Topology.bricks,
                ),
                Block(
                    gate=Gates.CRX,
                    topology=Topology.bricks,
                    offset=1,
                ),
            )

    class Circuit_18(DeclarativeCircuit):
        """RX, RZ, then wrapped CRZ stairs."""

        @staticmethod
        def structure():
            return (
                Block(gate=Gates.RX),
                Block(gate=Gates.RZ),
                Block(
                    gate=Gates.CRZ,
                    topology=Topology.stairs,
                    wrap=True,
                    mirror=False,
                ),
            )

    class Circuit_19(DeclarativeCircuit):
        """RX, RZ, then wrapped CRX stairs."""

        @staticmethod
        def structure():
            return (
                Block(gate=Gates.RX),
                Block(gate=Gates.RZ),
                Block(
                    gate=Gates.CRX,
                    topology=Topology.stairs,
                    wrap=True,
                    mirror=False,
                ),
            )

    class Circuit_20(DeclarativeCircuit):
        """RY + wrapped reversed CX stairs, RY + CX stairs with span 1."""

        @staticmethod
        def structure():
            return (
                Block(gate=Gates.RY),
                Block(
                    gate=Gates.CX,
                    topology=Topology.stairs,
                    wrap=True,
                    reverse=True,
                    mirror=False,
                ),
                Block(gate=Gates.RY),
                Block(
                    gate=Gates.CX,
                    topology=Topology.stairs,
                    reverse=False,
                    offset=lambda n: n - 2,
                    span=1,
                    wrap=True,
                ),
            )

    class No_Entangling(DeclarativeCircuit):
        """A single Rot layer with no entangling gates."""

        @staticmethod
        def structure():
            return (Block(gate=Gates.Rot),)

    class Hardware_Efficient(DeclarativeCircuit):
        """RY-RZ-RY rotations followed by two bricks-CX layers."""

        @staticmethod
        def structure():
            return (
                Block(gate=Gates.RY),
                Block(gate=Gates.RZ),
                Block(gate=Gates.RY),
                Block(
                    gate=Gates.CX,
                    topology=Topology.bricks,
                    mirror=False,
                ),
                Block(
                    gate=Gates.CX,
                    topology=Topology.bricks,
                    offset=-1,
                    modulo=True,
                    wrap=True,
                    mirror=False,
                ),
            )

    class Strongly_Entangling(DeclarativeCircuit):
        """Two Rot + wrapped CX-stairs rounds; the second uses span n // 2."""

        @staticmethod
        def structure():
            return (
                Block(gate=Gates.Rot),
                Block(
                    gate=Gates.CX,
                    topology=Topology.stairs,
                    wrap=True,
                    reverse=False,
                    mirror=False,
                ),
                Block(gate=Gates.Rot),
                Block(
                    gate=Gates.CX,
                    topology=Topology.stairs,
                    reverse=False,
                    # span is a function of the qubit count n
                    span=lambda n: n // 2,
                    wrap=True,
                    mirror=False,
                ),
            )

Gates

Because the relationship between the classes that implement pulse and unitary gates can be confusing, the following diagram may help:

Gate Structure (diagram)

from qml_essentials.gates import Gates

Dynamic accessor for quantum Gates.

Routes calls like Gates.RX(...) to either UnitaryGates or PulseGates depending on the gate_mode keyword (defaults to 'unitary').

During circuit building, the pulse manager can be activated via pulse_manager_context, which slices the global model pulse parameters and passes them to each gate. Model pulse parameters act as element-wise scaling factors on the gate's optimized pulse parameters.

Parameters

gate_mode (str, optional): Determines the backend: 'unitary' for UnitaryGates, 'pulse' for PulseGates. Defaults to 'unitary'.

Examples

Gates.RX(w, wires)
Gates.RX(w, wires, gate_mode="unitary")
Gates.RX(w, wires, gate_mode="pulse")
Gates.RX(w, wires, pulse_params, gate_mode="pulse")

Source code in qml_essentials/gates.py
class Gates(metaclass=GatesMeta):
    """
    Dynamic accessor for quantum Gates.

    Routes calls like `Gates.RX(...)` to either `UnitaryGates` or `PulseGates`
    depending on the `gate_mode` keyword (defaults to 'unitary').

    During circuit building, the pulse manager can be activated via
    `pulse_manager_context`, which slices the global model pulse parameters
    and passes them to each gate. Model pulse parameters act as element-wise
    scaling factors on the gate's optimized pulse parameters.

    Parameters
    ----------
    gate_mode : str, optional
        Determines the backend. 'unitary' for UnitaryGates, 'pulse' for PulseGates.
        Defaults to 'unitary'.

    Examples
    --------
    >>> Gates.RX(w, wires)
    >>> Gates.RX(w, wires, gate_mode="unitary")
    >>> Gates.RX(w, wires, gate_mode="pulse")
    >>> Gates.RX(w, wires, pulse_params, gate_mode="pulse")
    """

    def __getattr__(self, gate_name):
        """
        Return a dispatcher for ``gate_name`` when looked up on an instance.

        Both positional and keyword arguments are forwarded to
        `_inner_getattr`, so the documented positional form
        ``RX(w, wires)`` also works through an instance.
        """

        def handler(*args, **kwargs):
            return self._inner_getattr(gate_name, *args, **kwargs)

        # `is_rotational` / `is_entangling` inspect `__name__`; name the
        # dispatcher after the gate so those checks recognise it.
        handler.__name__ = gate_name
        return handler

    @staticmethod
    def _inner_getattr(gate_name, *args, **kwargs):
        """
        Dispatch ``gate_name`` to the backend selected by ``gate_mode``.

        Args:
            gate_name (str): Gate attribute looked up on the backend class.
            *args: Forwarded positionally to the backend gate.
            **kwargs: Keyword arguments. ``gate_mode`` ('unitary' or 'pulse')
                is consumed here. Keys outside the backend's allow-list are
                dropped and logged at debug level.

        Returns:
            Whatever the backend gate returns.

        Raises:
            ValueError: Unknown ``gate_mode``, or a wrong number of pulse
                parameters when no pulse manager is active.
            TypeError: Unsupported ``pulse_params`` container or element type.
            AttributeError: The backend has no gate named ``gate_name``.
        """
        gate_mode = kwargs.pop("gate_mode", "unitary")

        # Backend selection and kwargs filtering
        allowed_args = {"w", "wires", "phi", "theta", "omega"}
        if gate_mode == "unitary":
            gate_backend = UnitaryGates
            allowed_args |= {"noise_params", "random_key"}
        elif gate_mode == "pulse":
            gate_backend = PulseGates
            allowed_args |= {"pulse_params"}
        else:
            raise ValueError(
                f"Unknown gate mode: {gate_mode}. Use 'unitary' or 'pulse'."
            )

        unsupported = kwargs.keys() - allowed_args
        if unsupported:
            # TODO: pulse params are always provided?
            log.debug(f"Unsupported keyword arguments: {list(unsupported)}")

        kwargs = {k: v for k, v in kwargs.items() if k in allowed_args}
        pulse_params = kwargs.get("pulse_params")
        pulse_mgr = getattr(Gates, "_pulse_mgr", None)

        # TODO: rework this part to convert to valid PulseParams earlier
        # Type check on pulse parameters
        if pulse_params is not None:
            # flatten pulse parameters
            if isinstance(pulse_params, (list, tuple)):
                flat_params = pulse_params

            elif isinstance(pulse_params, jax.core.Tracer):
                # Tracers are checked before ndarray so they are never
                # converted to Python lists (which would fail under tracing).
                flat_params = np.ravel(pulse_params)

            elif isinstance(pulse_params, np.ndarray):
                flat_params = pulse_params.flatten().tolist()
            elif isinstance(pulse_params, PulseParams):
                # extract the params in case a full object is given
                kwargs["pulse_params"] = pulse_params.params
                flat_params = pulse_params.params.flatten().tolist()

            else:
                raise TypeError(f"Unsupported pulse_params type: {type(pulse_params)}")

            # checks elements in flat parameters are real numbers or jax Tracer
            if not all(
                isinstance(x, (numbers.Real, jax.core.Tracer)) for x in flat_params
            ):
                raise TypeError(
                    "All elements in pulse_params must be int or float, "
                    f"got {pulse_params}, type {type(pulse_params)}. "
                )

        # Len check on pulse parameters (skipped when a manager supplies them)
        if pulse_params is not None and not isinstance(pulse_mgr, PulseParamManager):
            n_params = PulseInformation.gate_by_name(gate_name).size
            if len(flat_params) != n_params:
                raise ValueError(
                    f"Gate '{gate_name}' expects {n_params} pulse parameters, "
                    f"got {len(flat_params)}"
                )

        # Pulse slicing + scaling: the manager's slice scales the gate's
        # stored base parameters element-wise, replacing any user value.
        if gate_mode == "pulse" and isinstance(pulse_mgr, PulseParamManager):
            n_params = PulseInformation.gate_by_name(gate_name).size
            scalers = pulse_mgr.get(n_params)
            base = PulseInformation.gate_by_name(gate_name).params
            kwargs["pulse_params"] = base * scalers

        # Call the selected gate backend
        gate = getattr(gate_backend, gate_name, None)
        if gate is None:
            # gate_backend is a class, so report its own name
            # (its __class__ would just be `type`).
            raise AttributeError(
                f"type object '{gate_backend.__name__}' "
                f"has no attribute '{gate_name}'"
            )

        return gate(*args, **kwargs)

    @staticmethod
    @contextmanager
    def pulse_manager_context(pulse_params: np.ndarray):
        """
        Temporarily set the global pulse manager for circuit building.

        The manager active before entry is restored on exit. Nested
        contexts therefore no longer reset an enclosing context's manager
        to ``None``.

        Args:
            pulse_params (np.ndarray): Model pulse parameters wrapped in a
                ``PulseParamManager`` for the duration of the context.
        """
        previous = getattr(Gates, "_pulse_mgr", None)
        if not isinstance(previous, PulseParamManager):
            previous = None
        Gates._pulse_mgr = PulseParamManager(pulse_params)
        try:
            yield
        finally:
            Gates._pulse_mgr = previous

    @staticmethod
    def parse_gates(
        gates: Union[str, Callable, List[Union[str, Callable]]],
        set_of_gates=None,
    ):
        """
        Normalise a gate specification into a list of callables.

        Args:
            gates: A gate name, a callable, a list mixing both, or ``None``.
                ``None`` yields a single no-op callable.
            set_of_gates: Object that names are resolved on. Defaults to
                ``Gates``.

        Returns:
            list: The resolved callables.

        Raises:
            ValueError: If ``gates`` or a list element is not a str or callable.
        """
        set_of_gates = set_of_gates or Gates

        if isinstance(gates, str):
            # if str, use the pennylane fct
            parsed_gates = [getattr(set_of_gates, gates)]
        elif isinstance(gates, list):
            parsed_gates = []
            for enc in gates:
                # if list, check if str or callable
                if isinstance(enc, str):
                    parsed_gates.append(getattr(set_of_gates, enc))
                # check if callable
                elif callable(enc):
                    parsed_gates.append(enc)
                else:
                    raise ValueError(
                        f"Operation {enc} is not a valid gate or callable. "
                        f"Got {type(enc)}"
                    )
        elif callable(gates):
            # default to callable
            parsed_gates = [gates]
        elif gates is None:
            parsed_gates = [lambda *args, **kwargs: None]
        else:
            raise ValueError(
                f"Operation {gates} is not a valid gate or callable or list of both."
            )
        return parsed_gates

    @staticmethod
    def is_rotational(gate):
        """Return True if ``gate.__name__`` names a parameterized rotation gate."""
        return gate.__name__ in {
            "RX",
            "RY",
            "RZ",
            "Rot",
            "CRX",
            "CRY",
            "CRZ",
        }

    @staticmethod
    def is_entangling(gate):
        """Return True if ``gate.__name__`` names a two-qubit (controlled) gate."""
        return gate.__name__ in {"CX", "CY", "CZ", "CRX", "CRY", "CRZ"}

pulse_manager_context(pulse_params) (staticmethod)

Temporarily set the global pulse manager for circuit building.

Source code in qml_essentials/gates.py
@staticmethod
@contextmanager
def pulse_manager_context(pulse_params: np.ndarray):
    """
    Temporarily set the global pulse manager for circuit building.

    The manager active before entry is restored on exit. Nested contexts
    therefore no longer reset an enclosing context's manager to ``None``.

    Args:
        pulse_params (np.ndarray): Model pulse parameters wrapped in a
            ``PulseParamManager`` for the duration of the context.
    """
    previous = getattr(Gates, "_pulse_mgr", None)
    if not isinstance(previous, PulseParamManager):
        previous = None
    Gates._pulse_mgr = PulseParamManager(pulse_params)
    try:
        yield
    finally:
        Gates._pulse_mgr = previous
from qml_essentials.gates import UnitaryGates

Unitary Gates

Collection of unitary quantum gates with optional noise simulation.

Source code in qml_essentials/gates.py
class UnitaryGates:
    """Collection of unitary quantum gates with optional noise simulation."""

    batch_gate_error = True

    @staticmethod
    def NQubitDepolarizingChannel(p: float, wires: List[int]) -> qml.QubitChannel:
        """
        Build an n-qubit depolarizing channel acting jointly on ``wires``.

        The Kraus set is:
        - sqrt(1 - p * (4**n - 1) / 4**n) * I
        - sqrt(p / 4**n) * P for every non-identity n-qubit Pauli string P

        Together these realise rho -> (1 - p) * rho + p * I / 2**n, where n is
        ``len(wires)``.

        Args:
            p (float): Total depolarizing probability, 0 <= p <= 1.
            wires (List[int]): Wires the channel acts on (at least two).

        Returns:
            qml.QubitChannel: Channel built from the Kraus operators above.

        Raises:
            ValueError: If p lies outside [0, 1] or fewer than two wires are given.
        """

        def build_kraus(prob: float, n_wires: int) -> List[np.ndarray]:
            if not (0.0 <= prob <= 1.0):
                raise ValueError(f"Probability p must be between 0 and 1, got {prob}")
            if n_wires < 2:
                raise ValueError(f"Number of qubits must be >= 2, got {n_wires}")

            single_qubit_paulis = [
                np.eye(2),
                qml.matrix(qml.PauliX(0)),
                qml.matrix(qml.PauliY(0)),
                qml.matrix(qml.PauliZ(0)),
            ]
            n_terms = 4**n_wires
            identity_weight = np.sqrt(1 - prob * (n_terms - 1) / n_terms)
            pauli_weight = np.sqrt(prob / n_terms)

            kraus = [identity_weight * np.eye(2**n_wires)]
            pauli_strings = itertools.product(single_qubit_paulis, repeat=n_wires)
            # The first product is I x ... x I, already covered by kraus[0].
            next(pauli_strings)
            for factors in pauli_strings:
                pauli = np.eye(1)
                for factor in factors:
                    pauli = np.kron(pauli, factor)
                kraus.append(pauli_weight * pauli)
            return kraus

        return qml.QubitChannel(build_kraus(p, len(wires)), wires=wires)

    @staticmethod
    def Noise(
        wires: Union[int, List[int]], noise_params: Optional[Dict[str, float]] = None
    ) -> None:
        """
        Append noise channels after a gate on the given wires.

        For each wire, a BitFlip, PhaseFlip and DepolarizingChannel is
        inserted, each only when its probability is positive. When more than
        one wire is given, an n-qubit depolarizing channel on all of them
        follows.

        Args:
            wires (Union[int, List[int]]): A single wire or a list of wires.
            noise_params (Optional[Dict[str, float]]): Probabilities keyed by
                "BitFlip", "PhaseFlip", "Depolarizing" and
                "MultiQubitDepolarizing". Missing keys count as 0.0.
                ``None`` disables noise entirely.

        Returns:
            None: The channels are queued on the active circuit.
        """
        if noise_params is None:
            return

        targets = [wires] if isinstance(wires, int) else wires

        # The probabilities do not depend on the wire, so read them once.
        bit_flip = noise_params.get("BitFlip", 0.0)
        phase_flip = noise_params.get("PhaseFlip", 0.0)
        depolarizing = noise_params.get("Depolarizing", 0.0)

        for target in targets:
            if bit_flip > 0:
                qml.BitFlip(bit_flip, wires=target)
            if phase_flip > 0:
                qml.PhaseFlip(phase_flip, wires=target)
            if depolarizing > 0:
                qml.DepolarizingChannel(depolarizing, wires=target)

        if len(targets) <= 1:
            return

        multi_depolarizing = noise_params.get("MultiQubitDepolarizing", 0.0)
        if multi_depolarizing > 0:
            UnitaryGates.NQubitDepolarizingChannel(multi_depolarizing, targets)

    @staticmethod
    def GateError(
        w: Union[float, np.ndarray, List[float]],
        noise_params: Optional[Dict[str, float]] = None,
        random_key: Optional[jax.random.PRNGKey] = None,
    ) -> Tuple[np.ndarray, jax.random.PRNGKey]:
        """
        Apply gate error noise to rotation angle(s).

        Adds Gaussian noise, scaled by ``noise_params["GateError"]``, to the
        rotation angle(s) to simulate imperfect gate implementations. For an
        ndarray ``w`` (with ``batch_gate_error`` enabled), one sample is drawn
        per element. Otherwise a single sample of shape (1,) is drawn.

        The caller's ``w`` is never modified in place. For a list input, the
        noise is added element-wise instead of being appended to the list.

        Args:
            w (Union[float, np.ndarray, List[float]]): Rotation angle(s) in radians.
            noise_params (Optional[Dict[str, float]]): Dictionary with optional
                "GateError" key specifying standard deviation of Gaussian noise.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for
                stochastic noise generation.

        Returns:
            Tuple[np.ndarray, jax.random.PRNGKey]: Tuple containing:
                - Modified rotation angle(s) with applied noise
                - Updated JAX random key

        Raises:
            AssertionError: If noise_params contains "GateError" but random_key is None.
        """
        gate_error = None if noise_params is None else noise_params.get("GateError", None)
        if gate_error is None:
            return w, random_key

        if random_key is None:
            # Explicit raise instead of `assert` so the check survives
            # `python -O`; the exception type is unchanged for callers.
            raise AssertionError("A random_key must be provided when using GateError")

        random_key, sub_key = safe_random_split(random_key)
        shape = (
            w.shape
            if isinstance(w, np.ndarray) and UnitaryGates.batch_gate_error
            else (1,)
        )
        # Rebind rather than `+=`. Plain `+=` would mutate a caller's
        # NumPy array in place, and on a Python list it would extend the
        # list with the noise samples.
        w = w + gate_error * jax.random.normal(sub_key, shape)
        return w, random_key

    @staticmethod
    def Rot(
        phi: Union[float, np.ndarray, List[float]],
        theta: Union[float, np.ndarray, List[float]],
        omega: Union[float, np.ndarray, List[float]],
        wires: Union[int, List[int]],
        noise_params: Optional[Dict[str, float]] = None,
        random_key: Optional[jax.random.PRNGKey] = None,
    ) -> None:
        """
        Apply the three-angle rotation Rot(phi, theta, omega), optionally noisy.

        Each angle passes through `GateError` in the order phi, theta, omega.
        The random key is threaded through the calls. GateError leaves an
        angle untouched unless noise_params carries a "GateError" value.
        Noise channels from `Noise` are appended afterwards.

        Args:
            phi (Union[float, np.ndarray, List[float]]): First rotation angle.
            theta (Union[float, np.ndarray, List[float]]): Second rotation angle.
            omega (Union[float, np.ndarray, List[float]]): Third rotation angle.
            wires (Union[int, List[int]]): Target wire(s).
            noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for noise.

        Returns:
            None: The gate and any noise are queued on the active circuit.
        """
        perturbed_angles = []
        for angle in (phi, theta, omega):
            angle, random_key = UnitaryGates.GateError(angle, noise_params, random_key)
            perturbed_angles.append(angle)
        qml.Rot(*perturbed_angles, wires=wires)
        UnitaryGates.Noise(wires, noise_params)

    @staticmethod
    def RX(
        w: Union[float, np.ndarray, List[float]],
        wires: Union[int, List[int]],
        noise_params: Optional[Dict[str, float]] = None,
        random_key: Optional[jax.random.PRNGKey] = None,
    ) -> None:
        """
        Rotate about the X axis, with optional gate error and noise channels.

        Args:
            w (Union[float, np.ndarray, List[float]]): Rotation angle.
            wires (Union[int, List[int]]): Target wire(s).
            noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for noise.

        Returns:
            None: The gate and any noise are queued on the active circuit.
        """
        noisy_angle = UnitaryGates.GateError(w, noise_params, random_key)[0]
        qml.RX(noisy_angle, wires=wires)
        UnitaryGates.Noise(wires, noise_params)

    @staticmethod
    def RY(
        w: Union[float, np.ndarray, List[float]],
        wires: Union[int, List[int]],
        noise_params: Optional[Dict[str, float]] = None,
        random_key: Optional[jax.random.PRNGKey] = None,
    ) -> None:
        """
        Rotate about the Y axis, with optional gate error and noise channels.

        Args:
            w (Union[float, np.ndarray, List[float]]): Rotation angle.
            wires (Union[int, List[int]]): Target wire(s).
            noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for noise.

        Returns:
            None: The gate and any noise are queued on the active circuit.
        """
        noisy_angle = UnitaryGates.GateError(w, noise_params, random_key)[0]
        qml.RY(noisy_angle, wires=wires)
        UnitaryGates.Noise(wires, noise_params)

    @staticmethod
    def RZ(
        w: Union[float, np.ndarray, List[float]],
        wires: Union[int, List[int]],
        noise_params: Optional[Dict[str, float]] = None,
        random_key: Optional[jax.random.PRNGKey] = None,
    ) -> None:
        """
        Rotate about the Z axis, with optional gate error and noise channels.

        Args:
            w (Union[float, np.ndarray, List[float]]): Rotation angle.
            wires (Union[int, List[int]]): Target wire(s).
            noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for noise.

        Returns:
            None: The gate and any noise are queued on the active circuit.
        """
        noisy_angle = UnitaryGates.GateError(w, noise_params, random_key)[0]
        qml.RZ(noisy_angle, wires=wires)
        UnitaryGates.Noise(wires, noise_params)

    @staticmethod
    def CRX(
        w: Union[float, np.ndarray, List[float]],
        wires: Union[int, List[int]],
        noise_params: Optional[Dict[str, float]] = None,
        random_key: Optional[jax.random.PRNGKey] = None,
    ) -> None:
        """
        Apply a controlled X rotation, with optional gate error and noise.

        Args:
            w (Union[float, np.ndarray, List[float]]): Rotation angle.
            wires (Union[int, List[int]]): Control and target wires.
            noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for noise.

        Returns:
            None: The gate and any noise are queued on the active circuit.
        """
        noisy_angle = UnitaryGates.GateError(w, noise_params, random_key)[0]
        qml.CRX(noisy_angle, wires=wires)
        UnitaryGates.Noise(wires, noise_params)

    @staticmethod
    def CRY(
        w: Union[float, np.ndarray, List[float]],
        wires: Union[int, List[int]],
        noise_params: Optional[Dict[str, float]] = None,
        random_key: Optional[jax.random.PRNGKey] = None,
    ) -> None:
        """
        Queue a controlled Y-rotation, then the configured noise channels.

        The angle is first routed through ``UnitaryGates.GateError`` (which
        only perturbs it when ``noise_params`` carries a "GateError" entry),
        the PennyLane ``CRY`` operation is queued on ``wires``, and finally
        ``UnitaryGates.Noise`` appends any noise channels on the same wires.

        Args:
            w (Union[float, np.ndarray, List[float]]): Rotation angle.
            wires (Union[int, List[int]]): Control and target qubit indices.
            noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for noise.

        Returns:
            None: Gate and noise are applied in-place to the circuit.
        """
        # GateError hands back (angle, advanced key); the key is not needed here.
        noisy_w, _ = UnitaryGates.GateError(w, noise_params, random_key)
        qml.CRY(noisy_w, wires=wires)
        UnitaryGates.Noise(wires, noise_params)

    @staticmethod
    def CRZ(
        w: Union[float, np.ndarray, List[float]],
        wires: Union[int, List[int]],
        noise_params: Optional[Dict[str, float]] = None,
        random_key: Optional[jax.random.PRNGKey] = None,
    ) -> None:
        """
        Queue a controlled Z-rotation, then the configured noise channels.

        The angle is first routed through ``UnitaryGates.GateError`` (which
        only perturbs it when ``noise_params`` carries a "GateError" entry),
        the PennyLane ``CRZ`` operation is queued on ``wires``, and finally
        ``UnitaryGates.Noise`` appends any noise channels on the same wires.

        Args:
            w (Union[float, np.ndarray, List[float]]): Rotation angle.
            wires (Union[int, List[int]]): Control and target qubit indices.
            noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for noise.

        Returns:
            None: Gate and noise are applied in-place to the circuit.
        """
        # GateError hands back (angle, advanced key); the key is not needed here.
        noisy_w, _ = UnitaryGates.GateError(w, noise_params, random_key)
        qml.CRZ(noisy_w, wires=wires)
        UnitaryGates.Noise(wires, noise_params)

    @staticmethod
    def CX(
        wires: Union[int, List[int]],
        noise_params: Optional[Dict[str, float]] = None,
        random_key: Optional[jax.random.PRNGKey] = None,
    ) -> None:
        """
        Queue a CNOT (controlled-X) gate followed by the configured noise.

        Args:
            wires (Union[int, List[int]]): Control and target qubit indices.
            noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
            random_key (Optional[jax.random.PRNGKey]): Accepted only so every
                gate shares one call signature; unused by this fixed gate.

        Returns:
            None: Gate and noise are applied in-place to the circuit.
        """
        # Fixed gate: no angle to perturb, so GateError is skipped and
        # random_key is intentionally ignored.
        qml.CNOT(wires=wires)
        UnitaryGates.Noise(wires, noise_params)

    @staticmethod
    def CY(
        wires: Union[int, List[int]],
        noise_params: Optional[Dict[str, float]] = None,
        random_key: Optional[jax.random.PRNGKey] = None,
    ) -> None:
        """
        Queue a controlled-Y gate followed by the configured noise.

        Args:
            wires (Union[int, List[int]]): Control and target qubit indices.
            noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
            random_key (Optional[jax.random.PRNGKey]): Accepted only so every
                gate shares one call signature; unused by this fixed gate.

        Returns:
            None: Gate and noise are applied in-place to the circuit.
        """
        # Fixed gate: no angle to perturb, so GateError is skipped and
        # random_key is intentionally ignored.
        qml.CY(wires=wires)
        UnitaryGates.Noise(wires, noise_params)

    @staticmethod
    def CZ(
        wires: Union[int, List[int]],
        noise_params: Optional[Dict[str, float]] = None,
        random_key: Optional[jax.random.PRNGKey] = None,
    ) -> None:
        """
        Queue a controlled-Z gate followed by the configured noise.

        Args:
            wires (Union[int, List[int]]): Control and target qubit indices.
            noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
            random_key (Optional[jax.random.PRNGKey]): Accepted only so every
                gate shares one call signature; unused by this fixed gate.

        Returns:
            None: Gate and noise are applied in-place to the circuit.
        """
        # Fixed gate: no angle to perturb, so GateError is skipped and
        # random_key is intentionally ignored.
        qml.CZ(wires=wires)
        UnitaryGates.Noise(wires, noise_params)

    @staticmethod
    def H(
        wires: Union[int, List[int]],
        noise_params: Optional[Dict[str, float]] = None,
        random_key: Optional[jax.random.PRNGKey] = None,
    ) -> None:
        """
        Queue a Hadamard gate followed by the configured noise.

        Args:
            wires (Union[int, List[int]]): Qubit index or indices.
            noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
            random_key (Optional[jax.random.PRNGKey]): Accepted only so every
                gate shares one call signature; unused by this fixed gate.

        Returns:
            None: Gate and noise are applied in-place to the circuit.
        """
        # Fixed gate: no angle to perturb, so GateError is skipped and
        # random_key is intentionally ignored.
        qml.Hadamard(wires=wires)
        UnitaryGates.Noise(wires, noise_params)

CRX(w, wires, noise_params=None, random_key=None) staticmethod #

Apply controlled X-rotation with optional noise.

Parameters:

Name Type Description Default
w Union[float, ndarray, List[float]]

Rotation angle.

required
wires Union[int, List[int]]

Control and target qubit indices.

required
noise_params Optional[Dict[str, float]]

Noise parameters dictionary.

None
random_key Optional[PRNGKey]

JAX random key for noise.

None

Returns:

Name Type Description
None None

Gate and noise are applied in-place to the circuit.

Source code in qml_essentials/gates.py
@staticmethod
def CRX(
    w: Union[float, np.ndarray, List[float]],
    wires: Union[int, List[int]],
    noise_params: Optional[Dict[str, float]] = None,
    random_key: Optional[jax.random.PRNGKey] = None,
) -> None:
    """
    Queue a controlled X-rotation, then the configured noise channels.

    The angle is first routed through ``UnitaryGates.GateError`` (which
    only perturbs it when ``noise_params`` carries a "GateError" entry),
    the PennyLane ``CRX`` operation is queued on ``wires``, and finally
    ``UnitaryGates.Noise`` appends any noise channels on the same wires.

    Args:
        w (Union[float, np.ndarray, List[float]]): Rotation angle.
        wires (Union[int, List[int]]): Control and target qubit indices.
        noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
        random_key (Optional[jax.random.PRNGKey]): JAX random key for noise.

    Returns:
        None: Gate and noise are applied in-place to the circuit.
    """
    # GateError hands back (angle, advanced key); the key is not needed here.
    noisy_w, _ = UnitaryGates.GateError(w, noise_params, random_key)
    qml.CRX(noisy_w, wires=wires)
    UnitaryGates.Noise(wires, noise_params)

CRY(w, wires, noise_params=None, random_key=None) staticmethod #

Apply controlled Y-rotation with optional noise.

Parameters:

Name Type Description Default
w Union[float, ndarray, List[float]]

Rotation angle.

required
wires Union[int, List[int]]

Control and target qubit indices.

required
noise_params Optional[Dict[str, float]]

Noise parameters dictionary.

None
random_key Optional[PRNGKey]

JAX random key for noise.

None

Returns:

Name Type Description
None None

Gate and noise are applied in-place to the circuit.

Source code in qml_essentials/gates.py
@staticmethod
def CRY(
    w: Union[float, np.ndarray, List[float]],
    wires: Union[int, List[int]],
    noise_params: Optional[Dict[str, float]] = None,
    random_key: Optional[jax.random.PRNGKey] = None,
) -> None:
    """
    Queue a controlled Y-rotation, then the configured noise channels.

    The angle is first routed through ``UnitaryGates.GateError`` (which
    only perturbs it when ``noise_params`` carries a "GateError" entry),
    the PennyLane ``CRY`` operation is queued on ``wires``, and finally
    ``UnitaryGates.Noise`` appends any noise channels on the same wires.

    Args:
        w (Union[float, np.ndarray, List[float]]): Rotation angle.
        wires (Union[int, List[int]]): Control and target qubit indices.
        noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
        random_key (Optional[jax.random.PRNGKey]): JAX random key for noise.

    Returns:
        None: Gate and noise are applied in-place to the circuit.
    """
    # GateError hands back (angle, advanced key); the key is not needed here.
    noisy_w, _ = UnitaryGates.GateError(w, noise_params, random_key)
    qml.CRY(noisy_w, wires=wires)
    UnitaryGates.Noise(wires, noise_params)

CRZ(w, wires, noise_params=None, random_key=None) staticmethod #

Apply controlled Z-rotation with optional noise.

Parameters:

Name Type Description Default
w Union[float, ndarray, List[float]]

Rotation angle.

required
wires Union[int, List[int]]

Control and target qubit indices.

required
noise_params Optional[Dict[str, float]]

Noise parameters dictionary.

None
random_key Optional[PRNGKey]

JAX random key for noise.

None

Returns:

Name Type Description
None None

Gate and noise are applied in-place to the circuit.

Source code in qml_essentials/gates.py
@staticmethod
def CRZ(
    w: Union[float, np.ndarray, List[float]],
    wires: Union[int, List[int]],
    noise_params: Optional[Dict[str, float]] = None,
    random_key: Optional[jax.random.PRNGKey] = None,
) -> None:
    """
    Queue a controlled Z-rotation, then the configured noise channels.

    The angle is first routed through ``UnitaryGates.GateError`` (which
    only perturbs it when ``noise_params`` carries a "GateError" entry),
    the PennyLane ``CRZ`` operation is queued on ``wires``, and finally
    ``UnitaryGates.Noise`` appends any noise channels on the same wires.

    Args:
        w (Union[float, np.ndarray, List[float]]): Rotation angle.
        wires (Union[int, List[int]]): Control and target qubit indices.
        noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
        random_key (Optional[jax.random.PRNGKey]): JAX random key for noise.

    Returns:
        None: Gate and noise are applied in-place to the circuit.
    """
    # GateError hands back (angle, advanced key); the key is not needed here.
    noisy_w, _ = UnitaryGates.GateError(w, noise_params, random_key)
    qml.CRZ(noisy_w, wires=wires)
    UnitaryGates.Noise(wires, noise_params)

CX(wires, noise_params=None, random_key=None) staticmethod #

Apply controlled-NOT (CNOT) gate with optional noise.

Parameters:

Name Type Description Default
wires Union[int, List[int]]

Control and target qubit indices.

required
noise_params Optional[Dict[str, float]]

Noise parameters dictionary.

None
random_key Optional[PRNGKey]

JAX random key for compatibility (not used in this gate).

None

Returns:

Name Type Description
None None

Gate and noise are applied in-place to the circuit.

Source code in qml_essentials/gates.py
@staticmethod
def CX(
    wires: Union[int, List[int]],
    noise_params: Optional[Dict[str, float]] = None,
    random_key: Optional[jax.random.PRNGKey] = None,
) -> None:
    """
    Queue a CNOT (controlled-X) gate followed by the configured noise.

    Args:
        wires (Union[int, List[int]]): Control and target qubit indices.
        noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
        random_key (Optional[jax.random.PRNGKey]): Accepted only so every
            gate shares one call signature; unused by this fixed gate.

    Returns:
        None: Gate and noise are applied in-place to the circuit.
    """
    # Fixed gate: no angle to perturb, so GateError is skipped and
    # random_key is intentionally ignored.
    qml.CNOT(wires=wires)
    UnitaryGates.Noise(wires, noise_params)

CY(wires, noise_params=None, random_key=None) staticmethod #

Apply controlled-Y gate with optional noise.

Parameters:

Name Type Description Default
wires Union[int, List[int]]

Control and target qubit indices.

required
noise_params Optional[Dict[str, float]]

Noise parameters dictionary.

None
random_key Optional[PRNGKey]

JAX random key for compatibility (not used in this gate).

None

Returns:

Name Type Description
None None

Gate and noise are applied in-place to the circuit.

Source code in qml_essentials/gates.py
@staticmethod
def CY(
    wires: Union[int, List[int]],
    noise_params: Optional[Dict[str, float]] = None,
    random_key: Optional[jax.random.PRNGKey] = None,
) -> None:
    """
    Queue a controlled-Y gate followed by the configured noise.

    Args:
        wires (Union[int, List[int]]): Control and target qubit indices.
        noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
        random_key (Optional[jax.random.PRNGKey]): Accepted only so every
            gate shares one call signature; unused by this fixed gate.

    Returns:
        None: Gate and noise are applied in-place to the circuit.
    """
    # Fixed gate: no angle to perturb, so GateError is skipped and
    # random_key is intentionally ignored.
    qml.CY(wires=wires)
    UnitaryGates.Noise(wires, noise_params)

CZ(wires, noise_params=None, random_key=None) staticmethod #

Apply controlled-Z gate with optional noise.

Parameters:

Name Type Description Default
wires Union[int, List[int]]

Control and target qubit indices.

required
noise_params Optional[Dict[str, float]]

Noise parameters dictionary.

None
random_key Optional[PRNGKey]

JAX random key for compatibility (not used in this gate).

None

Returns:

Name Type Description
None None

Gate and noise are applied in-place to the circuit.

Source code in qml_essentials/gates.py
@staticmethod
def CZ(
    wires: Union[int, List[int]],
    noise_params: Optional[Dict[str, float]] = None,
    random_key: Optional[jax.random.PRNGKey] = None,
) -> None:
    """
    Queue a controlled-Z gate followed by the configured noise.

    Args:
        wires (Union[int, List[int]]): Control and target qubit indices.
        noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
        random_key (Optional[jax.random.PRNGKey]): Accepted only so every
            gate shares one call signature; unused by this fixed gate.

    Returns:
        None: Gate and noise are applied in-place to the circuit.
    """
    # Fixed gate: no angle to perturb, so GateError is skipped and
    # random_key is intentionally ignored.
    qml.CZ(wires=wires)
    UnitaryGates.Noise(wires, noise_params)

GateError(w, noise_params=None, random_key=None) staticmethod #

Apply gate error noise to rotation angle(s).

Adds Gaussian noise to gate rotation angles to simulate imperfect gate implementations.

Parameters:

Name Type Description Default
w Union[float, ndarray, List[float]]

Rotation angle(s) in radians.

required
noise_params Optional[Dict[str, float]]

Dictionary with optional "GateError" key specifying standard deviation of Gaussian noise.

None
random_key Optional[PRNGKey]

JAX random key for stochastic noise generation.

None

Returns:

Type Description
Tuple[ndarray, PRNGKey]

Tuple[np.ndarray, jax.random.PRNGKey]: A tuple containing the modified rotation angle(s) with applied noise and the updated JAX random key.

Raises:

Type Description
AssertionError

If noise_params contains "GateError" but random_key is None.

Source code in qml_essentials/gates.py
@staticmethod
def GateError(
    w: Union[float, np.ndarray, List[float]],
    noise_params: Optional[Dict[str, float]] = None,
    random_key: Optional[jax.random.PRNGKey] = None,
) -> Tuple[np.ndarray, jax.random.PRNGKey]:
    """
    Apply gate error noise to rotation angle(s).

    Adds Gaussian noise to gate rotation angles to simulate imperfect
    gate implementations. The input ``w`` is never modified in place; a new
    value is returned instead.

    When ``w`` is an ``np.ndarray`` and ``UnitaryGates.batch_gate_error`` is
    set, one noise sample is drawn per element of ``w``. Otherwise a single
    sample of shape ``(1,)`` is drawn and broadcast onto ``w``. A list or
    tuple of angles is converted to an array before the noise is added.

    Args:
        w (Union[float, np.ndarray, List[float]]): Rotation angle(s) in radians.
        noise_params (Optional[Dict[str, float]]): Dictionary with optional
            "GateError" key specifying standard deviation of Gaussian noise.
            A missing key or a ``None`` value disables gate errors.
        random_key (Optional[jax.random.PRNGKey]): JAX random key for
            stochastic noise generation.

    Returns:
        Tuple[np.ndarray, jax.random.PRNGKey]: Tuple containing:
            - Modified rotation angle(s) with applied noise (``w`` unchanged
              when no gate error is configured)
            - Updated JAX random key (``random_key`` unchanged in that case)

    Raises:
        AssertionError: If noise_params contains "GateError" but random_key is None.
    """
    if noise_params is None or noise_params.get("GateError", None) is None:
        return w, random_key

    # Explicit raise instead of ``assert`` so the check survives ``python -O``;
    # AssertionError is kept because callers may already rely on that type.
    if random_key is None:
        raise AssertionError("A random_key must be provided when using GateError")

    random_key, sub_key = safe_random_split(random_key)
    noise_shape = (
        w.shape
        if isinstance(w, np.ndarray) and UnitaryGates.batch_gate_error
        else (1,)
    )

    # ``list += array`` would call list.extend and append a noise element
    # instead of perturbing the angles, so convert sequences to arrays first.
    if isinstance(w, (list, tuple)):
        w = np.asarray(w)

    # Out-of-place addition: ``w += ...`` would mutate a caller-owned mutable
    # ndarray in place and permanently corrupt the circuit parameters.
    w = w + noise_params["GateError"] * jax.random.normal(sub_key, noise_shape)
    return w, random_key

H(wires, noise_params=None, random_key=None) staticmethod #

Apply Hadamard gate with optional noise.

Parameters:

Name Type Description Default
wires Union[int, List[int]]

Qubit index or indices.

required
noise_params Optional[Dict[str, float]]

Noise parameters dictionary.

None
random_key Optional[PRNGKey]

JAX random key for compatibility (not used in this gate).

None

Returns:

Name Type Description
None None

Gate and noise are applied in-place to the circuit.

Source code in qml_essentials/gates.py
@staticmethod
def H(
    wires: Union[int, List[int]],
    noise_params: Optional[Dict[str, float]] = None,
    random_key: Optional[jax.random.PRNGKey] = None,
) -> None:
    """
    Queue a Hadamard gate followed by the configured noise.

    Args:
        wires (Union[int, List[int]]): Qubit index or indices.
        noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
        random_key (Optional[jax.random.PRNGKey]): Accepted only so every
            gate shares one call signature; unused by this fixed gate.

    Returns:
        None: Gate and noise are applied in-place to the circuit.
    """
    # Fixed gate: no angle to perturb, so GateError is skipped and
    # random_key is intentionally ignored.
    qml.Hadamard(wires=wires)
    UnitaryGates.Noise(wires, noise_params)

NQubitDepolarizingChannel(p, wires) staticmethod #

Generate Kraus operators for n-qubit depolarizing channel.

The n-qubit depolarizing channel models uniform depolarizing noise acting on n qubits simultaneously, useful for simulating realistic multi-qubit noise affecting entangling gates.

Parameters:

Name Type Description Default
p float

Total probability of depolarizing error (0 ≤ p ≤ 1).

required
wires List[int]

Qubit indices on which the channel acts. Must contain at least 2 qubits.

required

Returns:

Type Description
QubitChannel

qml.QubitChannel: PennyLane QubitChannel with Kraus operators representing the depolarizing noise channel.

Raises:

Type Description
ValueError

If p is not in [0, 1] or if fewer than 2 qubits provided.

Source code in qml_essentials/gates.py
@staticmethod
def NQubitDepolarizingChannel(p: float, wires: List[int]) -> qml.QubitChannel:
    """
    Build an n-qubit depolarizing channel as a PennyLane ``QubitChannel``.

    The channel is expressed through Kraus operators over every n-fold
    tensor product of single-qubit Paulis (I, X, Y, Z). The all-identity
    product gets weight ``sqrt(1 - p * (4**n - 1) / 4**n)``. Each of the
    remaining ``4**n - 1`` products gets weight ``sqrt(p / 4**n)``, so the
    squared weights sum to one. This models uniform depolarizing noise acting
    on several qubits at once, e.g. after an entangling gate.

    Args:
        p (float): Total probability of depolarizing error (0 ≤ p ≤ 1).
        wires (List[int]): Qubit indices on which the channel acts.
            Must contain at least 2 qubits.

    Returns:
        qml.QubitChannel: PennyLane QubitChannel with Kraus operators
            representing the depolarizing noise channel.

    Raises:
        ValueError: If p is not in [0, 1] or if fewer than 2 qubits provided.
    """

    def _kraus_operators(prob: float, n_wires: int) -> List[np.ndarray]:
        # Validate first: the operator count below grows as 4**n_wires.
        if not (0.0 <= prob <= 1.0):
            raise ValueError(f"Probability p must be between 0 and 1, got {prob}")
        if n_wires < 2:
            raise ValueError(f"Number of qubits must be >= 2, got {n_wires}")

        single_qubit_paulis = (
            np.eye(2),
            qml.matrix(qml.PauliX(0)),
            qml.matrix(qml.PauliY(0)),
            qml.matrix(qml.PauliZ(0)),
        )
        n_strings = 4**n_wires

        # All n-fold Pauli tensor products in itertools.product order; the
        # first one produced is the all-identity string.
        pauli_strings = []
        for factors in itertools.product(single_qubit_paulis, repeat=n_wires):
            operator = np.eye(1)
            for factor in factors:
                operator = np.kron(operator, factor)
            pauli_strings.append(operator)

        no_error_weight = np.sqrt(1 - prob * (n_strings - 1) / n_strings)
        error_weight = np.sqrt(prob / n_strings)

        # The identity string is folded into the "no error" operator, so it
        # is skipped among the error operators.
        no_error_op = no_error_weight * np.eye(2**n_wires)
        return [no_error_op] + [error_weight * op for op in pauli_strings[1:]]

    return qml.QubitChannel(_kraus_operators(p, len(wires)), wires=wires)

Noise(wires, noise_params=None) staticmethod #

Apply noise channels to specified qubits.

Applies various single-qubit and multi-qubit noise channels based on the provided noise parameters dictionary.

Parameters:

Name Type Description Default
wires Union[int, List[int]]

Qubit index or list of qubit indices to apply noise to.

required
noise_params Optional[Dict[str, float]]

Dictionary of noise parameters. Supported keys: "BitFlip" (float), the bit-flip error probability; "PhaseFlip" (float), the phase-flip error probability; "Depolarizing" (float), the single-qubit depolarizing probability; "MultiQubitDepolarizing" (float), the multi-qubit depolarizing probability (applied only if len(wires) > 1). All parameters default to 0.0 if not provided.

None

Returns:

Name Type Description
None None

Noise channels are applied in-place to the circuit.

Source code in qml_essentials/gates.py
@staticmethod
def Noise(
    wires: Union[int, List[int]], noise_params: Optional[Dict[str, float]] = None
) -> None:
    """
    Apply noise channels to specified qubits.

    Applies various single-qubit and multi-qubit noise channels based on
    the provided noise parameters dictionary. Single-qubit channels are
    applied per wire in the order BitFlip, PhaseFlip, Depolarizing. The
    multi-qubit depolarizing channel is applied once over all wires when
    more than one wire is given.

    Args:
        wires (Union[int, List[int]]): Qubit index (any integral type, e.g.
            ``int`` or a numpy integer) or list of qubit indices to apply
            noise to.
        noise_params (Optional[Dict[str, float]]): Dictionary of noise
            parameters. Supported keys:
            - "BitFlip" (float): Bit flip error probability
            - "PhaseFlip" (float): Phase flip error probability
            - "Depolarizing" (float): Single-qubit depolarizing probability
            - "MultiQubitDepolarizing" (float): Multi-qubit depolarizing
              probability (applies if len(wires) > 1)
            All parameters default to 0.0 if not provided or set to ``None``.

    Returns:
        None: Noise channels are applied in-place to the circuit.
    """
    import numbers

    if noise_params is None:
        return

    def _probability(key: str) -> float:
        # A missing key or an explicit None disables the channel, mirroring
        # how GateError treats a None "GateError" entry.
        value = noise_params.get(key)
        return 0.0 if value is None else value

    # numbers.Integral also covers numpy integer scalars, which a plain
    # ``isinstance(wires, int)`` check would miss.
    if isinstance(wires, numbers.Integral):
        wires = [wires]  # single qubit gate

    # These lookups do not depend on the wire, so read them once.
    bf = _probability("BitFlip")
    pf = _probability("PhaseFlip")
    dp = _probability("Depolarizing")

    # noise on single qubits
    for wire in wires:
        if bf > 0:
            qml.BitFlip(bf, wires=wire)
        if pf > 0:
            qml.PhaseFlip(pf, wires=wire)
        if dp > 0:
            qml.DepolarizingChannel(dp, wires=wire)

    # correlated noise across all wires of a multi-qubit gate
    if len(wires) > 1:
        p = _probability("MultiQubitDepolarizing")
        if p > 0:
            UnitaryGates.NQubitDepolarizingChannel(p, wires)

RX(w, wires, noise_params=None, random_key=None) staticmethod #

Apply X-axis rotation with optional noise.

Parameters:

Name Type Description Default
w Union[float, ndarray, List[float]]

Rotation angle.

required
wires Union[int, List[int]]

Qubit index or indices.

required
noise_params Optional[Dict[str, float]]

Noise parameters dictionary.

None
random_key Optional[PRNGKey]

JAX random key for noise.

None

Returns:

Name Type Description
None None

Gate and noise are applied in-place to the circuit.

Source code in qml_essentials/gates.py
@staticmethod
def RX(
    w: Union[float, np.ndarray, List[float]],
    wires: Union[int, List[int]],
    noise_params: Optional[Dict[str, float]] = None,
    random_key: Optional[jax.random.PRNGKey] = None,
) -> None:
    """
    Queue an X-axis rotation, then the configured noise channels.

    The angle is first routed through ``UnitaryGates.GateError`` (which
    only perturbs it when ``noise_params`` carries a "GateError" entry),
    the PennyLane ``RX`` operation is queued on ``wires``, and finally
    ``UnitaryGates.Noise`` appends any noise channels on the same wires.

    Args:
        w (Union[float, np.ndarray, List[float]]): Rotation angle.
        wires (Union[int, List[int]]): Qubit index or indices.
        noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
        random_key (Optional[jax.random.PRNGKey]): JAX random key for noise.

    Returns:
        None: Gate and noise are applied in-place to the circuit.
    """
    # GateError hands back (angle, advanced key); the key is not needed here.
    noisy_w, _ = UnitaryGates.GateError(w, noise_params, random_key)
    qml.RX(noisy_w, wires=wires)
    UnitaryGates.Noise(wires, noise_params)

RY(w, wires, noise_params=None, random_key=None) staticmethod #

Apply Y-axis rotation with optional noise.

Parameters:

Name Type Description Default
w Union[float, ndarray, List[float]]

Rotation angle.

required
wires Union[int, List[int]]

Qubit index or indices.

required
noise_params Optional[Dict[str, float]]

Noise parameters dictionary.

None
random_key Optional[PRNGKey]

JAX random key for noise.

None

Returns:

Name Type Description
None None

Gate and noise are applied in-place to the circuit.

Source code in qml_essentials/gates.py
@staticmethod
def RY(
    w: Union[float, np.ndarray, List[float]],
    wires: Union[int, List[int]],
    noise_params: Optional[Dict[str, float]] = None,
    random_key: Optional[jax.random.PRNGKey] = None,
) -> None:
    """
    Queue a Y-axis rotation, then the configured noise channels.

    The angle is first routed through ``UnitaryGates.GateError`` (which
    only perturbs it when ``noise_params`` carries a "GateError" entry),
    the PennyLane ``RY`` operation is queued on ``wires``, and finally
    ``UnitaryGates.Noise`` appends any noise channels on the same wires.

    Args:
        w (Union[float, np.ndarray, List[float]]): Rotation angle.
        wires (Union[int, List[int]]): Qubit index or indices.
        noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
        random_key (Optional[jax.random.PRNGKey]): JAX random key for noise.

    Returns:
        None: Gate and noise are applied in-place to the circuit.
    """
    # GateError hands back (angle, advanced key); the key is not needed here.
    noisy_w, _ = UnitaryGates.GateError(w, noise_params, random_key)
    qml.RY(noisy_w, wires=wires)
    UnitaryGates.Noise(wires, noise_params)

RZ(w, wires, noise_params=None, random_key=None) staticmethod #

Apply Z-axis rotation with optional noise.

Parameters:

Name Type Description Default
w Union[float, ndarray, List[float]]

Rotation angle.

required
wires Union[int, List[int]]

Qubit index or indices.

required
noise_params Optional[Dict[str, float]]

Noise parameters dictionary.

None
random_key Optional[PRNGKey]

JAX random key for noise.

None

Returns:

Name Type Description
None None

Gate and noise are applied in-place to the circuit.

Source code in qml_essentials/gates.py
@staticmethod
def RZ(
    w: Union[float, np.ndarray, List[float]],
    wires: Union[int, List[int]],
    noise_params: Optional[Dict[str, float]] = None,
    random_key: Optional[jax.random.PRNGKey] = None,
) -> None:
    """
    Queue a Z-axis rotation, then the configured noise channels.

    The angle is first routed through ``UnitaryGates.GateError`` (which
    only perturbs it when ``noise_params`` carries a "GateError" entry),
    the PennyLane ``RZ`` operation is queued on ``wires``, and finally
    ``UnitaryGates.Noise`` appends any noise channels on the same wires.

    Args:
        w (Union[float, np.ndarray, List[float]]): Rotation angle.
        wires (Union[int, List[int]]): Qubit index or indices.
        noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
        random_key (Optional[jax.random.PRNGKey]): JAX random key for noise.

    Returns:
        None: Gate and noise are applied in-place to the circuit.
    """
    # GateError hands back (angle, advanced key); the key is not needed here.
    noisy_w, _ = UnitaryGates.GateError(w, noise_params, random_key)
    qml.RZ(noisy_w, wires=wires)
    UnitaryGates.Noise(wires, noise_params)

Rot(phi, theta, omega, wires, noise_params=None, random_key=None) staticmethod #

Apply general rotation gate with optional noise.

Applies a three-angle rotation Rot(phi, theta, omega) with optional gate errors and noise channels.

Parameters:

Name Type Description Default
phi Union[float, ndarray, List[float]]

First rotation angle.

required
theta Union[float, ndarray, List[float]]

Second rotation angle.

required
omega Union[float, ndarray, List[float]]

Third rotation angle.

required
wires Union[int, List[int]]

Qubit index or indices to apply rotation to.

required
noise_params Optional[Dict[str, float]]

Noise parameters dictionary. Supports BitFlip, PhaseFlip, Depolarizing, and GateError.

None
random_key Optional[PRNGKey]

JAX random key for noise.

None

Returns:

Name Type Description
None None

Gate and noise are applied in-place to the circuit.

Source code in qml_essentials/gates.py
@staticmethod
def Rot(
    phi: Union[float, np.ndarray, List[float]],
    theta: Union[float, np.ndarray, List[float]],
    omega: Union[float, np.ndarray, List[float]],
    wires: Union[int, List[int]],
    noise_params: Optional[Dict[str, float]] = None,
    random_key: Optional[jax.random.PRNGKey] = None,
) -> None:
    """
    Queue a general three-angle rotation, then the configured noise.

    When ``noise_params`` has a "GateError" entry, each of the angles
    (phi, theta, omega) is passed through ``UnitaryGates.GateError`` in turn.
    The random key returned by one call is fed into the next, so every angle
    receives its own noise draw. The PennyLane ``Rot`` operation is then queued
    and ``UnitaryGates.Noise`` appends any noise channels.

    Args:
        phi (Union[float, np.ndarray, List[float]]): First rotation angle.
        theta (Union[float, np.ndarray, List[float]]): Second rotation angle.
        omega (Union[float, np.ndarray, List[float]]): Third rotation angle.
        wires (Union[int, List[int]]): Qubit index or indices to apply rotation to.
        noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
            Supports BitFlip, PhaseFlip, Depolarizing, and GateError.
        random_key (Optional[jax.random.PRNGKey]): JAX random key for noise.

    Returns:
        None: Gate and noise are applied in-place to the circuit.
    """
    angles = (phi, theta, omega)
    if noise_params is not None and "GateError" in noise_params:
        perturbed = []
        for angle in angles:
            # Thread the advanced key so each angle draws independent noise.
            angle, random_key = UnitaryGates.GateError(
                angle, noise_params, random_key
            )
            perturbed.append(angle)
        angles = tuple(perturbed)
    qml.Rot(*angles, wires=wires)
    UnitaryGates.Noise(wires, noise_params)

Pulse Gates#

from qml_essentials.gates import PulseGates

Pulse-level implementations of quantum gates.

Implements quantum gates using time-dependent Hamiltonians and pulse sequences, following the approach from https://doi.org/10.5445/IR/1000184129. Gates are decomposed using shaped Gaussian pulses with carrier modulation.

Attributes:

Name Type Description
omega_q float

Qubit frequency (10π).

omega_c float

Carrier frequency (10π).

H_static ndarray

Static Hamiltonian in qubit rotating frame.

Id, X, Y, Z ndarray

Pauli matrices for gate construction.

Source code in qml_essentials/gates.py
 770
 771
 772
 773
 774
 775
 776
 777
 778
 779
 780
 781
 782
 783
 784
 785
 786
 787
 788
 789
 790
 791
 792
 793
 794
 795
 796
 797
 798
 799
 800
 801
 802
 803
 804
 805
 806
 807
 808
 809
 810
 811
 812
 813
 814
 815
 816
 817
 818
 819
 820
 821
 822
 823
 824
 825
 826
 827
 828
 829
 830
 831
 832
 833
 834
 835
 836
 837
 838
 839
 840
 841
 842
 843
 844
 845
 846
 847
 848
 849
 850
 851
 852
 853
 854
 855
 856
 857
 858
 859
 860
 861
 862
 863
 864
 865
 866
 867
 868
 869
 870
 871
 872
 873
 874
 875
 876
 877
 878
 879
 880
 881
 882
 883
 884
 885
 886
 887
 888
 889
 890
 891
 892
 893
 894
 895
 896
 897
 898
 899
 900
 901
 902
 903
 904
 905
 906
 907
 908
 909
 910
 911
 912
 913
 914
 915
 916
 917
 918
 919
 920
 921
 922
 923
 924
 925
 926
 927
 928
 929
 930
 931
 932
 933
 934
 935
 936
 937
 938
 939
 940
 941
 942
 943
 944
 945
 946
 947
 948
 949
 950
 951
 952
 953
 954
 955
 956
 957
 958
 959
 960
 961
 962
 963
 964
 965
 966
 967
 968
 969
 970
 971
 972
 973
 974
 975
 976
 977
 978
 979
 980
 981
 982
 983
 984
 985
 986
 987
 988
 989
 990
 991
 992
 993
 994
 995
 996
 997
 998
 999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
class PulseGates:
    """
    Pulse-level implementations of quantum gates.

    Implements quantum gates using time-dependent Hamiltonians and pulse
    sequences, following the approach from https://doi.org/10.5445/IR/1000184129.
    Gates are decomposed using shaped Gaussian pulses with carrier modulation.

    Attributes:
        omega_q (float): Qubit frequency (10π).
        omega_c (float): Carrier frequency (10π).
        H_static (np.ndarray): Diagonal phase matrix
            diag(exp(i*omega_q/2), exp(-i*omega_q/2)); the X/Y drive operators
            are conjugated by it (H_static^† @ P @ H_static) to express them
            in the qubit rotating frame.
        Id, X, Y, Z (np.ndarray): Identity and Pauli matrices for gate
            construction.
    """

    # NOTE: Implementation of S, RX, RY, RZ, CZ, CNOT/CX and H pulse level
    #   gates closely follow https://doi.org/10.5445/IR/1000184129
    # TODO: Mention deviations from the above?
    omega_q = 10 * np.pi
    omega_c = 10 * np.pi

    H_static = np.array([[np.exp(1j * omega_q / 2), 0], [0, np.exp(-1j * omega_q / 2)]])

    Id = np.eye(2, dtype=np.complex64)
    X = np.array([[0, 1], [1, 0]])
    Y = np.array([[0, -1j], [1j, 0]])
    Z = np.array([[1, 0], [0, -1]])

    @staticmethod
    def _S(
        p: Union[List[float], np.ndarray],
        t: Union[float, List[float], np.ndarray],
        phi_c: float,
    ) -> np.ndarray:
        """
        Generate shaped Gaussian pulse envelope with carrier modulation.

        Internal helper function for creating time-dependent pulse shapes
        used in rotation gates. Not intended for direct circuit use.

        The result is ``A * exp(-0.5 * ((t - t_c) / sigma) ** 2)`` (envelope)
        multiplied by ``cos(omega_c * t + phi_c)`` (carrier).

        Args:
            p (Union[List[float], np.ndarray]): Pulse parameters [A, sigma]:
                - A (float): Amplitude of the Gaussian envelope
                - sigma (float): Width (standard deviation) of the Gaussian
            t (Union[float, List[float], np.ndarray]): Time(s) at which the
                pulse is evaluated. For a list/tuple, the pulse is evaluated at
                every entry and the center ``t_c`` is the midpoint of the first
                and last entry. Otherwise (a scalar such as the time passed by
                ``qml.evolve``, or an array) the center is ``t / 2``.
            phi_c (float): Phase offset for the cosine carrier.

        Returns:
            np.ndarray: Shaped pulse amplitude at time(s) t.
        """
        A, sigma = p
        if isinstance(t, (list, tuple)):
            # Plain Python sequences don't support arithmetic with floats;
            # previously this branch always raised TypeError below.
            t_c = (t[0] + t[-1]) / 2
            t = np.asarray(t)
        else:
            # NOTE(review): the center moves with t here (t_c = t / 2); the
            # optimized pulse parameters were presumably calibrated against
            # this form, so it is intentionally left unchanged.
            t_c = t / 2

        f = A * np.exp(-0.5 * ((t - t_c) / sigma) ** 2)
        x = np.cos(PulseGates.omega_c * t + phi_c)

        return f * x

    @staticmethod
    def Rot(
        phi: float,
        theta: float,
        omega: float,
        wires: Union[int, List[int]],
        pulse_params: Optional[np.ndarray] = None,
    ) -> None:
        """
        Apply general single-qubit rotation using pulse decomposition.

        Applies, in time order, RZ(phi), RY(theta), RZ(omega) using the
        pulse-level implementations of each component.

        Args:
            phi (float): First rotation angle.
            theta (float): Second rotation angle.
            omega (float): Third rotation angle.
            wires (Union[int, List[int]]): Qubit index or indices to apply rotation to.
            pulse_params (Optional[np.ndarray]): Pulse parameters for the
                composing gates, split by ``PulseInformation.Rot.split_params``.
                If None, uses optimized parameters.

        Returns:
            None: Gates are applied in-place to the circuit.
        """
        params_RZ_1, params_RY, params_RZ_2 = PulseInformation.Rot.split_params(
            pulse_params
        )

        PulseGates.RZ(phi, wires=wires, pulse_params=params_RZ_1)
        PulseGates.RY(theta, wires=wires, pulse_params=params_RY)
        PulseGates.RZ(omega, wires=wires, pulse_params=params_RZ_2)

    @staticmethod
    def RX(
        w: float,
        wires: Union[int, List[int]],
        pulse_params: Optional[np.ndarray] = None,
    ) -> None:
        """
        Apply X-axis rotation using pulse-level implementation.

        Evolves under ``S(p, t) * w * (H_static^† X H_static)``, where ``S`` is
        the Gaussian-enveloped carrier from ``_S`` with carrier phase pi.

        Args:
            w (float): Rotation angle in radians (scales the drive amplitude).
            wires (Union[int, List[int]]): Qubit index or indices to apply rotation to.
            pulse_params (Optional[np.ndarray]): Array [A, sigma, t]; A and
                sigma shape the Gaussian envelope, t is the evolution time
                passed to ``qml.evolve``. If None, uses optimized parameters.

        Returns:
            None: Gate is applied in-place to the circuit.
        """
        pulse_params = PulseInformation.RX.split_params(pulse_params)

        def Sx(p, t):
            return PulseGates._S(p, t, phi_c=np.pi) * w

        # Drive operator X expressed in the rotating frame.
        _H = PulseGates.H_static.conj().T @ PulseGates.X @ PulseGates.H_static
        _H = qml.Hermitian(_H, wires=wires)
        H_eff = Sx * _H

        qml.evolve(H_eff)([pulse_params[0:2]], pulse_params[2])

    @staticmethod
    def RY(
        w: float,
        wires: Union[int, List[int]],
        pulse_params: Optional[np.ndarray] = None,
    ) -> None:
        """
        Apply Y-axis rotation using pulse-level implementation.

        Evolves under ``S(p, t) * w * (H_static^† Y H_static)``, where ``S`` is
        the Gaussian-enveloped carrier from ``_S`` with carrier phase -pi/2.

        Args:
            w (float): Rotation angle in radians (scales the drive amplitude).
            wires (Union[int, List[int]]): Qubit index or indices to apply rotation to.
            pulse_params (Optional[np.ndarray]): Array [A, sigma, t]; A and
                sigma shape the Gaussian envelope, t is the evolution time
                passed to ``qml.evolve``. If None, uses optimized parameters.

        Returns:
            None: Gate is applied in-place to the circuit.
        """
        pulse_params = PulseInformation.RY.split_params(pulse_params)

        def Sy(p, t):
            return PulseGates._S(p, t, phi_c=-np.pi / 2) * w

        # Drive operator Y expressed in the rotating frame.
        _H = PulseGates.H_static.conj().T @ PulseGates.Y @ PulseGates.H_static
        _H = qml.Hermitian(_H, wires=wires)
        H_eff = Sy * _H

        qml.evolve(H_eff)([pulse_params[0:2]], pulse_params[2])

    @staticmethod
    def RZ(
        w: float, wires: Union[int, List[int]], pulse_params: Optional[float] = None
    ) -> None:
        """
        Apply Z-axis rotation using pulse-level implementation.

        Evolves for unit time under the constant Hamiltonian ``p * w * Z``
        (modelling a virtual Z rotation), i.e. ``exp(-i * p * w * Z)``.

        Args:
            w (float): Rotation angle in radians.
            wires (Union[int, List[int]]): Qubit index or indices to apply rotation to.
            pulse_params (Optional[float]): Scale factor p; the effective
                rotation angle is ``w * 2 * p``. If None, the default from
                ``PulseInformation.RZ.split_params`` is used (documented as
                0.5, which gives an exact RZ(w)).

        Returns:
            None: Gate is applied in-place to the circuit.
        """
        pulse_params = PulseInformation.RZ.split_params(pulse_params)

        _H = qml.Hermitian(PulseGates.Z, wires=wires)

        def Sz(p, t):
            return p * w

        H_eff = Sz * _H

        qml.evolve(H_eff)([pulse_params], 1)

    @staticmethod
    def H(
        wires: Union[int, List[int]], pulse_params: Optional[np.ndarray] = None
    ) -> None:
        """
        Apply Hadamard gate using pulse decomposition.

        Applies RZ(π) then RY(π/2), followed by a global-phase correction:
        for ideal rotations RY(π/2)·RZ(π) = -i·H, and the correction
        evolution multiplies by exp(iπ/2) = i.

        Args:
            wires (Union[int, List[int]]): Qubit index or indices to apply gate to.
            pulse_params (Optional[np.ndarray]): Pulse parameters for the
                composing gates. If None, uses optimized parameters.

        Returns:
            None: Gate is applied in-place to the circuit.
        """
        pulse_params_RZ, pulse_params_RY = PulseInformation.H.split_params(pulse_params)

        # qml.GlobalPhase(-np.pi / 2)  # this could act as substitute to Sc
        PulseGates.RZ(np.pi, wires=wires, pulse_params=pulse_params_RZ)
        PulseGates.RY(np.pi / 2, wires=wires, pulse_params=pulse_params_RY)

        def Sc(p, t):
            return -1.0

        # Coefficient -1 times (π/2)·I for unit time -> global phase exp(iπ/2).
        _H = np.pi / 2 * np.eye(2, dtype=np.complex64)
        _H = qml.Hermitian(_H, wires=wires)
        H_corr = Sc * _H

        qml.evolve(H_corr)([0], 1)

    @staticmethod
    def CX(wires: List[int], pulse_params: Optional[np.ndarray] = None) -> None:
        """
        Apply CNOT gate using pulse decomposition.

        Implements CNOT as H_target · CZ · H_target, where H and CZ are
        applied using their respective pulse-level implementations.

        Args:
            wires (List[int]): Control and target qubit indices [control, target].
            pulse_params (Optional[np.ndarray]): Pulse parameters for the
                composing gates. If None, uses optimized parameters.

        Returns:
            None: Gate is applied in-place to the circuit.
        """
        params_H_1, params_CZ, params_H_2 = PulseInformation.CX.split_params(
            pulse_params
        )

        target = wires[1]

        PulseGates.H(wires=target, pulse_params=params_H_1)
        PulseGates.CZ(wires=wires, pulse_params=params_CZ)
        PulseGates.H(wires=target, pulse_params=params_H_2)

    @staticmethod
    def CY(wires: List[int], pulse_params: Optional[np.ndarray] = None) -> None:
        """
        Apply controlled-Y gate using pulse decomposition.

        Applies, in time order, RZ(-π/2) on the target, CX, then RZ(π/2) on
        the target; the net unitary RZ(π/2)·CX·RZ(-π/2) equals CY because
        RZ(π/2)·X·RZ(-π/2) = Y.

        Args:
            wires (List[int]): Control and target qubit indices [control, target].
            pulse_params (Optional[np.ndarray]): Pulse parameters for the
                composing gates. If None, uses optimized parameters.

        Returns:
            None: Gate is applied in-place to the circuit.
        """
        params_RZ_1, params_CX, params_RZ_2 = PulseInformation.CY.split_params(
            pulse_params
        )

        target = wires[1]

        PulseGates.RZ(-np.pi / 2, wires=target, pulse_params=params_RZ_1)
        PulseGates.CX(wires=wires, pulse_params=params_CX)
        PulseGates.RZ(np.pi / 2, wires=target, pulse_params=params_RZ_2)

    @staticmethod
    def CZ(wires: List[int], pulse_params: Optional[float] = None) -> None:
        """
        Apply controlled-Z gate using pulse-level implementation.

        Evolves for unit time under ``p * π * H``, with
        ``H = (π/4)(II - ZI - IZ + ZZ) = π |11⟩⟨11|``. Only |11⟩ picks up a
        phase, ``exp(-i p π²)``; an exact CZ corresponds to ``p = 1/π``.

        Args:
            wires (List[int]): Control and target qubit indices.
            pulse_params (Optional[float]): Scale factor p of the interaction
                (equivalent to a duration, since H is constant). If None, uses
                ``PulseInformation.CZ.params``.

        Returns:
            None: Gate is applied in-place to the circuit.
        """
        if pulse_params is None:
            pulse_params = PulseInformation.CZ.params

        I_I = np.kron(PulseGates.Id, PulseGates.Id)
        Z_I = np.kron(PulseGates.Z, PulseGates.Id)
        I_Z = np.kron(PulseGates.Id, PulseGates.Z)
        Z_Z = np.kron(PulseGates.Z, PulseGates.Z)

        def Scz(p, t):
            return p * np.pi

        _H = (np.pi / 4) * (I_I - Z_I - I_Z + Z_Z)
        _H = qml.Hermitian(_H, wires=wires)
        H_eff = Scz * _H

        qml.evolve(H_eff)([pulse_params], 1)

    @staticmethod
    def CRX(
        w: float, wires: List[int], pulse_params: Optional[np.ndarray] = None
    ) -> None:
        """
        Apply controlled-RX gate using pulse decomposition.

        Applies to the target qubit, in time order, RZ(π/2), RY(w/2), CX,
        RY(-w/2), CX, RZ(-π/2), following arXiv:2408.01036. On the control-|1⟩
        subspace this composes to RZ(-π/2)·RY(w)·RZ(π/2) = RX(w).

        Args:
            w (float): Rotation angle in radians.
            wires (List[int]): Control and target qubit indices [control, target].
            pulse_params (Optional[np.ndarray]): Pulse parameters for the
                composing gates. If None, uses optimized parameters.

        Returns:
            None: Gate is applied in-place to the circuit.
        """
        (
            params_RZ_1,
            params_RY_1,
            params_CX_1,
            params_RY_2,
            params_CX_2,
            params_RZ_2,
        ) = PulseInformation.CRX.split_params(pulse_params)

        target = wires[1]

        PulseGates.RZ(np.pi / 2, wires=target, pulse_params=params_RZ_1)
        PulseGates.RY(w / 2, wires=target, pulse_params=params_RY_1)
        PulseGates.CX(wires=wires, pulse_params=params_CX_1)
        PulseGates.RY(-w / 2, wires=target, pulse_params=params_RY_2)
        PulseGates.CX(wires=wires, pulse_params=params_CX_2)
        PulseGates.RZ(-np.pi / 2, wires=target, pulse_params=params_RZ_2)

    @staticmethod
    def CRY(
        w: float, wires: List[int], pulse_params: Optional[np.ndarray] = None
    ) -> None:
        """
        Apply controlled-RY gate using pulse decomposition.

        Applies to the target qubit, in time order, RY(w/2), CX, RY(-w/2), CX,
        following arXiv:2408.01036. Since X·RY(-w/2)·X = RY(w/2), the
        control-|1⟩ block is RY(w) and the control-|0⟩ block is identity.

        Args:
            w (float): Rotation angle in radians.
            wires (List[int]): Control and target qubit indices [control, target].
            pulse_params (Optional[np.ndarray]): Pulse parameters for the
                composing gates. If None, uses optimized parameters.

        Returns:
            None: Gate is applied in-place to the circuit.
        """
        params_RY_1, params_CX_1, params_RY_2, params_CX_2 = (
            PulseInformation.CRY.split_params(pulse_params)
        )

        target = wires[1]

        PulseGates.RY(w / 2, wires=target, pulse_params=params_RY_1)
        PulseGates.CX(wires=wires, pulse_params=params_CX_1)
        PulseGates.RY(-w / 2, wires=target, pulse_params=params_RY_2)
        PulseGates.CX(wires=wires, pulse_params=params_CX_2)

    @staticmethod
    def CRZ(
        w: float, wires: List[int], pulse_params: Optional[np.ndarray] = None
    ) -> None:
        """
        Apply controlled-RZ gate using pulse decomposition.

        Applies to the target qubit, in time order, RZ(w/2), CX, RZ(-w/2), CX,
        following arXiv:2408.01036. Since X·RZ(-w/2)·X = RZ(w/2), the
        control-|1⟩ block is RZ(w) and the control-|0⟩ block is identity.

        Args:
            w (float): Rotation angle in radians.
            wires (List[int]): Control and target qubit indices [control, target].
            pulse_params (Optional[np.ndarray]): Pulse parameters for the
                composing gates. If None, uses optimized parameters.

        Returns:
            None: Gate is applied in-place to the circuit.
        """
        params_RZ_1, params_CX_1, params_RZ_2, params_CX_2 = (
            PulseInformation.CRZ.split_params(pulse_params)
        )

        target = wires[1]

        PulseGates.RZ(w / 2, wires=target, pulse_params=params_RZ_1)
        PulseGates.CX(wires=wires, pulse_params=params_CX_1)
        PulseGates.RZ(-w / 2, wires=target, pulse_params=params_RZ_2)
        PulseGates.CX(wires=wires, pulse_params=params_CX_2)

CRX(w, wires, pulse_params=None) staticmethod #

Apply controlled-RX gate using pulse decomposition.

Implements CRX(w) as RZ(π/2) · RY(w/2) · CX · RY(-w/2) · CX · RZ(-π/2) applied to the target qubit, following arXiv:2408.01036.

Parameters:

Name Type Description Default
w float

Rotation angle in radians.

required
wires List[int]

Control and target qubit indices [control, target].

required
pulse_params Optional[ndarray]

Pulse parameters for the composing gates. If None, uses optimized parameters.

None

Returns:

Name Type Description
None None

Gate is applied in-place to the circuit.

Source code in qml_essentials/gates.py
@staticmethod
def CRX(
    w: float, wires: List[int], pulse_params: Optional[np.ndarray] = None
) -> None:
    """
    Controlled-RX rotation assembled from pulse-level primitives.

    Following arXiv:2408.01036, the target qubit receives, in time order:
    RZ(π/2), RY(w/2), CX, RY(-w/2), CX, RZ(-π/2). On the control-|1⟩
    subspace this composes to RZ(-π/2)·RY(w)·RZ(π/2) = RX(w); on the
    control-|0⟩ subspace the single-qubit rotations cancel.

    Args:
        w (float): Rotation angle in radians.
        wires (List[int]): ``[control, target]`` qubit indices.
        pulse_params (Optional[np.ndarray]): Pulse parameters for the six
            composing gates, split by ``PulseInformation.CRX.split_params``.
            If None, uses optimized parameters.

    Returns:
        None: Gates are applied in-place to the circuit.
    """
    (
        rz_pre,
        ry_pre,
        cx_first,
        ry_post,
        cx_second,
        rz_post,
    ) = PulseInformation.CRX.split_params(pulse_params)
    tgt = wires[1]

    PulseGates.RZ(np.pi / 2, wires=tgt, pulse_params=rz_pre)
    PulseGates.RY(w / 2, wires=tgt, pulse_params=ry_pre)
    PulseGates.CX(wires=wires, pulse_params=cx_first)
    PulseGates.RY(-w / 2, wires=tgt, pulse_params=ry_post)
    PulseGates.CX(wires=wires, pulse_params=cx_second)
    PulseGates.RZ(-np.pi / 2, wires=tgt, pulse_params=rz_post)

CRY(w, wires, pulse_params=None) staticmethod #

Apply controlled-RY gate using pulse decomposition.

Implements CRY(w) as RY(w/2) · CX · RY(-w/2) · CX applied to the target qubit, following arXiv:2408.01036.

Parameters:

Name Type Description Default
w float

Rotation angle in radians.

required
wires List[int]

Control and target qubit indices [control, target].

required
pulse_params Optional[ndarray]

Pulse parameters for the composing gates. If None, uses optimized parameters.

None

Returns:

Name Type Description
None None

Gate is applied in-place to the circuit.

Source code in qml_essentials/gates.py
@staticmethod
def CRY(
    w: float, wires: List[int], pulse_params: Optional[np.ndarray] = None
) -> None:
    """
    Controlled-RY rotation assembled from pulse-level primitives.

    Following arXiv:2408.01036, the target qubit receives, in time order:
    RY(w/2), CX, RY(-w/2), CX. Because X·RY(-w/2)·X = RY(w/2), the
    control-|1⟩ block becomes RY(w) while the control-|0⟩ block is identity.

    Args:
        w (float): Rotation angle in radians.
        wires (List[int]): ``[control, target]`` qubit indices.
        pulse_params (Optional[np.ndarray]): Pulse parameters for the four
            composing gates, split by ``PulseInformation.CRY.split_params``.
            If None, uses optimized parameters.

    Returns:
        None: Gates are applied in-place to the circuit.
    """
    ry_pre, cx_first, ry_post, cx_second = PulseInformation.CRY.split_params(
        pulse_params
    )
    tgt = wires[1]

    PulseGates.RY(w / 2, wires=tgt, pulse_params=ry_pre)
    PulseGates.CX(wires=wires, pulse_params=cx_first)
    PulseGates.RY(-w / 2, wires=tgt, pulse_params=ry_post)
    PulseGates.CX(wires=wires, pulse_params=cx_second)

CRZ(w, wires, pulse_params=None) staticmethod #

Apply controlled-RZ gate using pulse decomposition.

Implements CRZ(w) as RZ(w/2) · CX · RZ(-w/2) · CX applied to the target qubit, following arXiv:2408.01036.

Parameters:

Name Type Description Default
w float

Rotation angle in radians.

required
wires List[int]

Control and target qubit indices [control, target].

required
pulse_params Optional[ndarray]

Pulse parameters for the composing gates. If None, uses optimized parameters.

None

Returns:

Name Type Description
None None

Gate is applied in-place to the circuit.

Source code in qml_essentials/gates.py
@staticmethod
def CRZ(
    w: float, wires: List[int], pulse_params: Optional[np.ndarray] = None
) -> None:
    """
    Controlled-RZ rotation assembled from pulse-level primitives.

    Following arXiv:2408.01036, the target qubit receives, in time order:
    RZ(w/2), CX, RZ(-w/2), CX. Because X·RZ(-w/2)·X = RZ(w/2), the
    control-|1⟩ block becomes RZ(w) while the control-|0⟩ block is identity.

    Args:
        w (float): Rotation angle in radians.
        wires (List[int]): ``[control, target]`` qubit indices.
        pulse_params (Optional[np.ndarray]): Pulse parameters for the four
            composing gates, split by ``PulseInformation.CRZ.split_params``.
            If None, uses optimized parameters.

    Returns:
        None: Gates are applied in-place to the circuit.
    """
    rz_pre, cx_first, rz_post, cx_second = PulseInformation.CRZ.split_params(
        pulse_params
    )
    tgt = wires[1]

    PulseGates.RZ(w / 2, wires=tgt, pulse_params=rz_pre)
    PulseGates.CX(wires=wires, pulse_params=cx_first)
    PulseGates.RZ(-w / 2, wires=tgt, pulse_params=rz_post)
    PulseGates.CX(wires=wires, pulse_params=cx_second)

CX(wires, pulse_params=None) staticmethod #

Apply CNOT gate using pulse decomposition.

Implements CNOT as H_target · CZ · H_target, where H and CZ are applied using their respective pulse-level implementations.

Parameters:

Name Type Description Default
wires List[int]

Control and target qubit indices [control, target].

required
pulse_params Optional[ndarray]

Pulse parameters for the composing gates. If None, uses optimized parameters.

None

Returns:

Name Type Description
None None

Gate is applied in-place to the circuit.

Source code in qml_essentials/gates.py
@staticmethod
def CX(wires: List[int], pulse_params: Optional[np.ndarray] = None) -> None:
    """
    CNOT built by sandwiching a pulse-level CZ between two Hadamards.

    In time order: H on the target, CZ on ``wires``, H on the target.

    Args:
        wires (List[int]): ``[control, target]`` qubit indices.
        pulse_params (Optional[np.ndarray]): Pulse parameters for the three
            composing gates, split by ``PulseInformation.CX.split_params``.
            If None, uses optimized parameters.

    Returns:
        None: Gates are applied in-place to the circuit.
    """
    h_pre, cz_params, h_post = PulseInformation.CX.split_params(pulse_params)
    tgt = wires[1]

    PulseGates.H(wires=tgt, pulse_params=h_pre)
    PulseGates.CZ(wires=wires, pulse_params=cz_params)
    PulseGates.H(wires=tgt, pulse_params=h_post)

CY(wires, pulse_params=None) staticmethod #

Apply controlled-Y gate using pulse decomposition.

Implements CY as RZ(-π/2)_target · CX · RZ(π/2)_target using pulse-level implementations.

Parameters:

Name Type Description Default
wires List[int]

Control and target qubit indices [control, target].

required
pulse_params Optional[ndarray]

Pulse parameters for the composing gates. If None, uses optimized parameters.

None

Returns:

Name Type Description
None None

Gate is applied in-place to the circuit.

Source code in qml_essentials/gates.py
@staticmethod
def CY(wires: List[int], pulse_params: Optional[np.ndarray] = None) -> None:
    """
    Controlled-Y built by conjugating a pulse-level CX with Z rotations.

    In time order: RZ(-π/2) on the target, CX on ``wires``, RZ(π/2) on the
    target. The net unitary RZ(π/2)·CX·RZ(-π/2) equals CY because
    RZ(π/2)·X·RZ(-π/2) = Y.

    Args:
        wires (List[int]): ``[control, target]`` qubit indices.
        pulse_params (Optional[np.ndarray]): Pulse parameters for the three
            composing gates, split by ``PulseInformation.CY.split_params``.
            If None, uses optimized parameters.

    Returns:
        None: Gates are applied in-place to the circuit.
    """
    rz_pre, cx_params, rz_post = PulseInformation.CY.split_params(pulse_params)
    tgt = wires[1]

    PulseGates.RZ(-np.pi / 2, wires=tgt, pulse_params=rz_pre)
    PulseGates.CX(wires=wires, pulse_params=cx_params)
    PulseGates.RZ(np.pi / 2, wires=tgt, pulse_params=rz_post)

CZ(wires, pulse_params=None) staticmethod #

Apply controlled-Z gate using pulse-level implementation.

Implements CZ using a two-qubit interaction Hamiltonian based on ZZ coupling.

Parameters:

Name Type Description Default
wires List[int]

Control and target qubit indices.

required
pulse_params Optional[float]

Time or duration parameter for the pulse evolution. If None, uses optimized value.

None

Returns:

Name Type Description
None None

Gate is applied in-place to the circuit.

Source code in qml_essentials/gates.py
@staticmethod
def CZ(wires: List[int], pulse_params: Optional[float] = None) -> None:
    """
    Apply controlled-Z gate using pulse-level implementation.

    Evolves for unit time under ``p * π * H``, with
    ``H = (π/4)(II - ZI - IZ + ZZ) = π |11⟩⟨11|``. Only |11⟩ picks up a
    phase, ``exp(-i p π²)``; an exact CZ corresponds to ``p = 1/π``.

    Args:
        wires (List[int]): Control and target qubit indices.
        pulse_params (Optional[float]): Scale factor p of the interaction
            (equivalent to a duration, since H is constant). If None, uses
            ``PulseInformation.CZ.params``.

    Returns:
        None: Gate is applied in-place to the circuit.
    """
    if pulse_params is None:
        pulse_params = PulseInformation.CZ.params

    I_I = np.kron(PulseGates.Id, PulseGates.Id)
    Z_I = np.kron(PulseGates.Z, PulseGates.Id)
    I_Z = np.kron(PulseGates.Id, PulseGates.Z)
    Z_Z = np.kron(PulseGates.Z, PulseGates.Z)

    def Scz(p, t):
        return p * np.pi

    # (II - ZI - IZ + ZZ) = 4|11><11|, so _H = π|11><11|.
    _H = (np.pi / 4) * (I_I - Z_I - I_Z + Z_Z)
    _H = qml.Hermitian(_H, wires=wires)
    H_eff = Scz * _H

    qml.evolve(H_eff)([pulse_params], 1)

H(wires, pulse_params=None) staticmethod #

Apply Hadamard gate using pulse decomposition.

Implements Hadamard as RZ(π) · RY(π/2) with a correction phase, using pulse-level implementations for each component.

Parameters:

Name Type Description Default
wires Union[int, List[int]]

Qubit index or indices to apply gate to.

required
pulse_params Optional[ndarray]

Pulse parameters for the composing gates. If None, uses optimized parameters.

None

Returns:

Name Type Description
None None

Gate is applied in-place to the circuit.

Source code in qml_essentials/gates.py
@staticmethod
def H(
    wires: Union[int, List[int]], pulse_params: Optional[np.ndarray] = None
) -> None:
    """
    Hadamard realised as two pulse-level rotations plus a phase fix.

    In time order: RZ(π), then RY(π/2), then a unit-time evolution under
    ``-1 * (π/2)·I``. For ideal rotations RY(π/2)·RZ(π) = -i·H, and the
    final evolution multiplies by exp(iπ/2) = i, yielding H.

    Args:
        wires (Union[int, List[int]]): Qubit index or indices to apply gate to.
        pulse_params (Optional[np.ndarray]): Pulse parameters for the RZ and
            RY components, split by ``PulseInformation.H.split_params``.
            If None, uses optimized parameters.

    Returns:
        None: Gate is applied in-place to the circuit.
    """
    rz_params, ry_params = PulseInformation.H.split_params(pulse_params)

    PulseGates.RZ(np.pi, wires=wires, pulse_params=rz_params)
    PulseGates.RY(np.pi / 2, wires=wires, pulse_params=ry_params)

    # Global-phase correction; qml.GlobalPhase(-np.pi / 2) would be equivalent.
    def phase_coeff(p, t):
        return -1.0

    scaled_identity = qml.Hermitian(
        np.pi / 2 * np.eye(2, dtype=np.complex64), wires=wires
    )
    qml.evolve(phase_coeff * scaled_identity)([0], 1)

RX(w, wires, pulse_params=None) staticmethod #

Apply X-axis rotation using pulse-level implementation.

Implements RX rotation using a shaped Gaussian pulse with optimized envelope parameters.

Parameters:

Name Type Description Default
w float

Rotation angle in radians.

required
wires Union[int, List[int]]

Qubit index or indices to apply rotation to.

required
pulse_params Optional[ndarray]

Array containing pulse parameters [A, sigma, t] for the Gaussian envelope. If None, uses optimized parameters.

None

Returns:

Name Type Description
None None

Gate is applied in-place to the circuit.

Source code in qml_essentials/gates.py
@staticmethod
def RX(
    w: float,
    wires: Union[int, List[int]],
    pulse_params: Optional[np.ndarray] = None,
) -> None:
    """
    X rotation driven by a Gaussian-enveloped carrier pulse.

    The drive coefficient is ``_S(p, t, phi_c=π) * w`` and the drive operator
    is X expressed in the rotating frame (``H_static^† X H_static``).

    Args:
        w (float): Rotation angle in radians (scales the drive amplitude).
        wires (Union[int, List[int]]): Qubit index or indices to apply rotation to.
        pulse_params (Optional[np.ndarray]): Array [A, sigma, t]; A and sigma
            shape the envelope, t is the evolution time handed to
            ``qml.evolve``. If None, uses optimized parameters.

    Returns:
        None: Gate is applied in-place to the circuit.
    """
    params = PulseInformation.RX.split_params(pulse_params)

    def drive(p, t):
        return PulseGates._S(p, t, phi_c=np.pi) * w

    frame = PulseGates.H_static
    x_rotating = frame.conj().T @ PulseGates.X @ frame
    hamiltonian = drive * qml.Hermitian(x_rotating, wires=wires)

    envelope, duration = params[0:2], params[2]
    qml.evolve(hamiltonian)([envelope], duration)

RY(w, wires, pulse_params=None) staticmethod #

Apply Y-axis rotation using pulse-level implementation.

Implements RY rotation using a shaped Gaussian pulse with optimized envelope parameters.

Parameters:

Name Type Description Default
w float

Rotation angle in radians.

required
wires Union[int, List[int]]

Qubit index or indices to apply rotation to.

required
pulse_params Optional[ndarray]

Array containing pulse parameters [A, sigma, t] for the Gaussian envelope. If None, uses optimized parameters.

None

Returns:

Name Type Description
None None

Gate is applied in-place to the circuit.

Source code in qml_essentials/gates.py
@staticmethod
def RY(
    w: float,
    wires: Union[int, List[int]],
    pulse_params: Optional[np.ndarray] = None,
) -> None:
    """
    Y rotation driven by a Gaussian-enveloped carrier pulse.

    The drive coefficient is ``_S(p, t, phi_c=-π/2) * w`` and the drive
    operator is Y expressed in the rotating frame (``H_static^† Y H_static``).

    Args:
        w (float): Rotation angle in radians (scales the drive amplitude).
        wires (Union[int, List[int]]): Qubit index or indices to apply rotation to.
        pulse_params (Optional[np.ndarray]): Array [A, sigma, t]; A and sigma
            shape the envelope, t is the evolution time handed to
            ``qml.evolve``. If None, uses optimized parameters.

    Returns:
        None: Gate is applied in-place to the circuit.
    """
    params = PulseInformation.RY.split_params(pulse_params)

    def drive(p, t):
        return PulseGates._S(p, t, phi_c=-np.pi / 2) * w

    frame = PulseGates.H_static
    y_rotating = frame.conj().T @ PulseGates.Y @ frame
    hamiltonian = drive * qml.Hermitian(y_rotating, wires=wires)

    envelope, duration = params[0:2], params[2]
    qml.evolve(hamiltonian)([envelope], duration)

RZ(w, wires, pulse_params=None) staticmethod #

Apply Z-axis rotation using pulse-level implementation.

Implements RZ rotation using virtual Z rotations (phase tracking) without physical pulse application.

Parameters:

Name Type Description Default
w float

Rotation angle in radians.

required
wires Union[int, List[int]]

Qubit index or indices to apply rotation to.

required
pulse_params Optional[float]

Duration parameter for the pulse. Rotation angle = w * 2 * pulse_params. Defaults to 0.5 if None.

None

Returns:

Name Type Description
None None

Gate is applied in-place to the circuit.

Source code in qml_essentials/gates.py
@staticmethod
def RZ(
    w: float, wires: Union[int, List[int]], pulse_params: Optional[float] = None
) -> None:
    """
    Z rotation as unit-time evolution under a constant Z Hamiltonian.

    The evolution is ``exp(-i * p * w * Z)``, modelling a virtual Z rotation;
    the effective rotation angle is therefore ``w * 2 * p``.

    Args:
        w (float): Rotation angle in radians.
        wires (Union[int, List[int]]): Qubit index or indices to apply rotation to.
        pulse_params (Optional[float]): Scale factor p. If None, the default
            from ``PulseInformation.RZ.split_params`` is used (documented as
            0.5, which gives an exact RZ(w)).

    Returns:
        None: Gate is applied in-place to the circuit.
    """
    scale = PulseInformation.RZ.split_params(pulse_params)
    z_operator = qml.Hermitian(PulseGates.Z, wires=wires)

    def coefficient(p, t):
        return p * w

    qml.evolve(coefficient * z_operator)([scale], 1)

Rot(phi, theta, omega, wires, pulse_params=None) staticmethod #

Apply general single-qubit rotation using pulse decomposition.

Decomposes a general rotation into RZ(phi) · RY(theta) · RZ(omega) and applies each component using pulse-level implementations.

Parameters:

Name Type Description Default
phi float

First rotation angle.

required
theta float

Second rotation angle.

required
omega float

Third rotation angle.

required
wires Union[int, List[int]]

Qubit index or indices to apply rotation to.

required
pulse_params Optional[ndarray]

Pulse parameters for the composing gates. If None, uses optimized parameters.

None

Returns:

Name Type Description
None None

Gates are applied in-place to the circuit.

Source code in qml_essentials/gates.py
@staticmethod
def Rot(
    phi: float,
    theta: float,
    omega: float,
    wires: Union[int, List[int]],
    pulse_params: Optional[np.ndarray] = None,
) -> None:
    """
    General single-qubit rotation from three pulse-level Euler rotations.

    Applies, in time order, RZ(phi), RY(theta), RZ(omega) on ``wires``.

    Args:
        phi (float): First rotation angle.
        theta (float): Second rotation angle.
        omega (float): Third rotation angle.
        wires (Union[int, List[int]]): Qubit index or indices to apply rotation to.
        pulse_params (Optional[np.ndarray]): Pulse parameters for the three
            components, split by ``PulseInformation.Rot.split_params``.
            If None, uses optimized parameters.

    Returns:
        None: Gates are applied in-place to the circuit.
    """
    first_z, middle_y, last_z = PulseInformation.Rot.split_params(pulse_params)

    PulseGates.RZ(phi, wires=wires, pulse_params=first_z)
    PulseGates.RY(theta, wires=wires, pulse_params=middle_y)
    PulseGates.RZ(omega, wires=wires, pulse_params=last_z)

Pulse Structure#

from qml_essentials.gates import PulseParams

Container for hierarchical pulse parameters.

Manages pulse parameters for quantum gates, supporting both leaf nodes (gates with direct parameters) and composite nodes (gates decomposed into simpler gates). Enables hierarchical parameter access and manipulation.

Attributes:

Name Type Description
name str

Name identifier for the gate.

_params ndarray

Direct pulse parameters (leaf nodes only).

_pulse_obj List

Child PulseParams objects (composite nodes only).

Source code in qml_essentials/gates.py
class PulseParams:
    """
    Container for hierarchical pulse parameters.

    Manages pulse parameters for quantum gates, supporting both leaf nodes
    (gates with direct parameters) and composite nodes (gates decomposed
    into simpler gates). Enables hierarchical parameter access and
    manipulation.

    Attributes:
        name (str): Name identifier for the gate.
        _params (np.ndarray): Direct pulse parameters (leaf nodes only).
        _pulse_obj (List): Child PulseParams objects (composite nodes only).
    """

    def __init__(
        self,
        name: str = "",
        params: Optional[np.ndarray] = None,
        pulse_obj: Optional[List] = None,
    ) -> None:
        """
        Initialize pulse parameters container.

        Args:
            name (str): Name identifier for the gate. Defaults to empty string.
            params (Optional[np.ndarray]): Direct pulse parameters for leaf gates.
                Mutually exclusive with pulse_obj.
            pulse_obj (Optional[List]): List of child PulseParams for composite
                gates. Mutually exclusive with params.

        Raises:
            AssertionError: If both or neither of params and pulse_obj are provided.
        """
        # Explicit check instead of ``assert`` so the validation is not
        # stripped under ``python -O``; AssertionError is kept so existing
        # callers that catch it keep working.
        if (params is None) == (pulse_obj is None):
            raise AssertionError(
                "Exactly one of `params` or `pulse_obj` must be provided."
            )

        self._pulse_obj = pulse_obj

        if params is not None:
            self._params = params

        self.name = name

    def __len__(self) -> int:
        """
        Get the total number of pulse parameters.

        For composite gates, returns the accumulated count from all children.

        Returns:
            int: Total number of pulse parameters.
        """
        return len(self.params)

    def __getitem__(self, idx: int) -> Union[float, np.ndarray]:
        """
        Access pulse parameter(s) by index.

        For leaf gates, returns the parameter at the given index.
        For composite gates, returns parameters of the child at the given index.

        Args:
            idx (int): Index to access.

        Returns:
            Union[float, np.ndarray]: Parameter value or child parameters.
        """
        if self.is_leaf:
            return self.params[idx]
        else:
            return self.childs[idx].params

    def __str__(self) -> str:
        """Return string representation (gate name)."""
        return self.name

    def __repr__(self) -> str:
        """Return repr string (gate name)."""
        return self.name

    @property
    def is_leaf(self) -> bool:
        """Check if this is a leaf node (direct parameters, no children)."""
        return self._pulse_obj is None

    @property
    def size(self) -> int:
        """Get the total parameter count (alias for ``__len__``)."""
        return len(self)

    @property
    def leafs(self) -> List["PulseParams"]:
        """
        Get all leaf nodes in the hierarchy.

        Recursively collects all leaf PulseParams objects in the tree.
        A leaf object referenced several times is returned only once, at the
        position of its first occurrence (depth-first, left to right).

        Returns:
            List[PulseParams]: List of unique leaf nodes in a deterministic order.
        """
        if self.is_leaf:
            return [self]

        collected: List["PulseParams"] = []
        for obj in self._pulse_obj:
            collected.extend(obj.leafs)

        # dict.fromkeys deduplicates by identity (no __eq__/__hash__ override)
        # while preserving first-occurrence order. A plain ``set`` yields an
        # id-dependent order, which makes the ``leaf_params`` layout differ
        # between runs.
        return list(dict.fromkeys(collected))

    @property
    def childs(self) -> List["PulseParams"]:
        """
        Get direct children of this node.

        Returns:
            List[PulseParams]: List of child PulseParams objects, or empty list
                if this is a leaf node.
        """
        if self.is_leaf:
            return []

        return self._pulse_obj

    @property
    def shape(self) -> List[Union[int, list]]:
        """
        Get the shape of pulse parameters.

        For leaf nodes, returns list with parameter count.
        For composite nodes, returns a list with one entry per child: the
        parameter count for a leaf child, or the child's own (nested) shape
        list for a composite child.

        Returns:
            List[Union[int, list]]: Parameter shape specification.
        """
        if self.is_leaf:
            return [len(self.params)]

        shape = []
        for obj in self.childs:
            # ``shape`` is a property, so it is accessed, not called.
            child_shape = obj.shape
            shape.append(child_shape[0] if obj.is_leaf else child_shape)

        return shape

    @property
    def params(self) -> np.ndarray:
        """
        Get or compute pulse parameters.

        For leaf nodes, returns internal pulse parameters.
        For composite nodes, returns concatenated parameters from all children.

        Returns:
            np.ndarray: Pulse parameters array.
        """
        if self.is_leaf:
            return self._params

        params = self.split_params(params=None, leafs=False)

        return np.concatenate(params)

    @params.setter
    def params(self, value: np.ndarray) -> None:
        """
        Set pulse parameters.

        For leaf nodes, sets internal parameters directly.
        For composite nodes, distributes values across children.

        Args:
            value (np.ndarray): Pulse parameters to set.

        Raises:
            AssertionError: If value is not np.ndarray for leaf nodes.
        """
        if self.is_leaf:
            # Explicit raise so the type check also holds under ``python -O``.
            if not isinstance(value, np.ndarray):
                raise AssertionError("params must be a np.ndarray")
            self._params = value
            return

        idx = 0
        for obj in self.childs:
            nidx = idx + obj.size
            obj.params = value[idx:nidx]
            idx = nidx

    @property
    def leaf_params(self) -> np.ndarray:
        """
        Get parameters from all leaf nodes.

        Returns:
            np.ndarray: Concatenated parameters from all (unique) leaf nodes,
                in the order given by ``leafs``.
        """
        if self.is_leaf:
            return self._params

        params = self.split_params(None, leafs=True)

        return np.concatenate(params)

    @leaf_params.setter
    def leaf_params(self, value: np.ndarray) -> None:
        """
        Set parameters for all leaf nodes.

        Args:
            value (np.ndarray): Parameters to distribute across leaf nodes,
                laid out in the order given by ``leafs``.
        """
        if self.is_leaf:
            self._params = value
            return

        idx = 0
        for obj in self.leafs:
            nidx = idx + obj.size
            obj.params = value[idx:nidx]
            idx = nidx

    def split_params(
        self,
        params: Optional[np.ndarray] = None,
        leafs: bool = False,
    ) -> List[np.ndarray]:
        """
        Split parameters into sub-arrays for children or leaves.

        Args:
            params (Optional[np.ndarray]): Parameters to split. If None,
                uses internal parameters.
            leafs (bool): If True, splits across leaf nodes; if False,
                splits across direct children. Defaults to False.

        Returns:
            List[np.ndarray]: List of parameter arrays for children or leaves.
                For a leaf node the (internal or given) array is returned
                unsplit.
        """
        if params is None:
            if self.is_leaf:
                return self._params

            objs = self.leafs if leafs else self.childs
            s_params = []
            for obj in objs:
                s_params.append(obj.params)

            return s_params
        else:
            if self.is_leaf:
                return params

            objs = self.leafs if leafs else self.childs
            s_params = []
            idx = 0
            for obj in objs:
                nidx = idx + obj.size
                s_params.append(params[idx:nidx])
                idx = nidx

            return s_params

childs property #

Get direct children of this node.

Returns:

Type Description
List[PulseParams]

List[PulseParams]: List of child PulseParams objects, or empty list if this is a leaf node.

is_leaf property #

Check if this is a leaf node (direct parameters, no children).

leaf_params property writable #

Get parameters from all leaf nodes.

Returns:

Type Description
ndarray

np.ndarray: Concatenated parameters from all leaf nodes.

leafs property #

Get all leaf nodes in the hierarchy.

Recursively collects all leaf PulseParams objects in the tree.

Returns:

Type Description
List[PulseParams]

List[PulseParams]: List of unique leaf nodes.

params property writable #

Get or compute pulse parameters.

For leaf nodes, returns internal pulse parameters. For composite nodes, returns concatenated parameters from all children.

Returns:

Type Description
ndarray

np.ndarray: Pulse parameters array.

shape property #

Get the shape of pulse parameters.

For leaf nodes, returns list with parameter count. For composite nodes, returns nested list of child shapes.

Returns:

Type Description
List[int]

List[int]: Parameter shape specification.

size property #

Get the total parameter count (alias for `__len__`).

__getitem__(idx) #

Access pulse parameter(s) by index.

For leaf gates, returns the parameter at the given index. For composite gates, returns parameters of the child at the given index.

Parameters:

Name Type Description Default
idx int

Index to access.

required

Returns:

Type Description
Union[float, ndarray]

Union[float, np.ndarray]: Parameter value or child parameters.

Source code in qml_essentials/gates.py
def __getitem__(self, idx: int) -> Union[float, np.ndarray]:
    """
    Index into the pulse parameters.

    A leaf node is indexed directly into its own parameter array; a
    composite node instead returns the full parameters of the child at
    position ``idx``.

    Args:
        idx (int): Position to look up.

    Returns:
        Union[float, np.ndarray]: A single parameter (leaf) or the selected
            child's parameters (composite).
    """
    if not self.is_leaf:
        return self.childs[idx].params
    return self.params[idx]

__init__(name='', params=None, pulse_obj=None) #

Initialize pulse parameters container.

Parameters:

Name Type Description Default
name str

Name identifier for the gate. Defaults to empty string.

''
params Optional[ndarray]

Direct pulse parameters for leaf gates. Mutually exclusive with pulse_obj.

None
pulse_obj Optional[List]

List of child PulseParams for composite gates. Mutually exclusive with params.

None

Raises:

Type Description
AssertionError

If both or neither of params and pulse_obj are provided.

Source code in qml_essentials/gates.py
def __init__(
    self,
    name: str = "",
    params: Optional[np.ndarray] = None,
    pulse_obj: Optional[List] = None,
) -> None:
    """
    Initialize pulse parameters container.

    Args:
        name (str): Name identifier for the gate. Defaults to empty string.
        params (Optional[np.ndarray]): Direct pulse parameters for leaf gates.
            Mutually exclusive with pulse_obj.
        pulse_obj (Optional[List]): List of child PulseParams for composite
            gates. Mutually exclusive with params.

    Raises:
        AssertionError: If both or neither of params and pulse_obj are provided.
    """
    # Explicit check instead of ``assert`` so the validation is not stripped
    # under ``python -O``; AssertionError is kept for existing callers.
    if (params is None) == (pulse_obj is None):
        raise AssertionError(
            "Exactly one of `params` or `pulse_obj` must be provided."
        )

    self._pulse_obj = pulse_obj

    if params is not None:
        self._params = params

    self.name = name

__len__() #

Get the total number of pulse parameters.

For composite gates, returns the accumulated count from all children.

Returns:

Name Type Description
int int

Total number of pulse parameters.

Source code in qml_essentials/gates.py
def __len__(self) -> int:
    """
    Count the pulse parameters held by this node.

    Composite nodes report the size of their concatenated child parameters.

    Returns:
        int: Number of pulse parameters.
    """
    all_params = self.params
    return len(all_params)

__repr__() #

Return repr string (gate name).

Source code in qml_essentials/gates.py
def __repr__(self) -> str:
    """Use the gate name as the repr string."""
    gate_name = self.name
    return gate_name

__str__() #

Return string representation (gate name).

Source code in qml_essentials/gates.py
def __str__(self) -> str:
    """Use the gate name as the string form."""
    gate_name = self.name
    return gate_name

split_params(params=None, leafs=False) #

Split parameters into sub-arrays for children or leaves.

Parameters:

Name Type Description Default
params Optional[ndarray]

Parameters to split. If None, uses internal parameters.

None
leafs bool

If True, splits across leaf nodes; if False, splits across direct children. Defaults to False.

False

Returns:

Type Description
List[ndarray]

List[np.ndarray]: List of parameter arrays for children or leaves.

Source code in qml_essentials/gates.py
def split_params(
    self,
    params: Optional[np.ndarray] = None,
    leafs: bool = False,
) -> List[np.ndarray]:
    """
    Divide a parameter vector among this node's children or leaves.

    Args:
        params (Optional[np.ndarray]): Vector to divide. When None, each
            target's own ``params`` are collected instead of slicing.
        leafs (bool): Divide over the leaf nodes when True, over the direct
            children when False. Defaults to False.

    Returns:
        List[np.ndarray]: One entry per target node. A leaf node returns its
            internal parameters (or the given ``params``) unchanged.
    """
    if self.is_leaf:
        return self._params if params is None else params

    targets = self.leafs if leafs else self.childs

    if params is None:
        return [target.params for target in targets]

    # Consecutive slices whose lengths follow each target's ``size``.
    chunks = []
    start = 0
    for target in targets:
        stop = start + target.size
        chunks.append(params[start:stop])
        start = stop
    return chunks

Model#

from qml_essentials.model import Model

A quantum circuit model.

Source code in qml_essentials/model.py
  20
  21
  22
  23
  24
  25
  26
  27
  28
  29
  30
  31
  32
  33
  34
  35
  36
  37
  38
  39
  40
  41
  42
  43
  44
  45
  46
  47
  48
  49
  50
  51
  52
  53
  54
  55
  56
  57
  58
  59
  60
  61
  62
  63
  64
  65
  66
  67
  68
  69
  70
  71
  72
  73
  74
  75
  76
  77
  78
  79
  80
  81
  82
  83
  84
  85
  86
  87
  88
  89
  90
  91
  92
  93
  94
  95
  96
  97
  98
  99
 100
 101
 102
 103
 104
 105
 106
 107
 108
 109
 110
 111
 112
 113
 114
 115
 116
 117
 118
 119
 120
 121
 122
 123
 124
 125
 126
 127
 128
 129
 130
 131
 132
 133
 134
 135
 136
 137
 138
 139
 140
 141
 142
 143
 144
 145
 146
 147
 148
 149
 150
 151
 152
 153
 154
 155
 156
 157
 158
 159
 160
 161
 162
 163
 164
 165
 166
 167
 168
 169
 170
 171
 172
 173
 174
 175
 176
 177
 178
 179
 180
 181
 182
 183
 184
 185
 186
 187
 188
 189
 190
 191
 192
 193
 194
 195
 196
 197
 198
 199
 200
 201
 202
 203
 204
 205
 206
 207
 208
 209
 210
 211
 212
 213
 214
 215
 216
 217
 218
 219
 220
 221
 222
 223
 224
 225
 226
 227
 228
 229
 230
 231
 232
 233
 234
 235
 236
 237
 238
 239
 240
 241
 242
 243
 244
 245
 246
 247
 248
 249
 250
 251
 252
 253
 254
 255
 256
 257
 258
 259
 260
 261
 262
 263
 264
 265
 266
 267
 268
 269
 270
 271
 272
 273
 274
 275
 276
 277
 278
 279
 280
 281
 282
 283
 284
 285
 286
 287
 288
 289
 290
 291
 292
 293
 294
 295
 296
 297
 298
 299
 300
 301
 302
 303
 304
 305
 306
 307
 308
 309
 310
 311
 312
 313
 314
 315
 316
 317
 318
 319
 320
 321
 322
 323
 324
 325
 326
 327
 328
 329
 330
 331
 332
 333
 334
 335
 336
 337
 338
 339
 340
 341
 342
 343
 344
 345
 346
 347
 348
 349
 350
 351
 352
 353
 354
 355
 356
 357
 358
 359
 360
 361
 362
 363
 364
 365
 366
 367
 368
 369
 370
 371
 372
 373
 374
 375
 376
 377
 378
 379
 380
 381
 382
 383
 384
 385
 386
 387
 388
 389
 390
 391
 392
 393
 394
 395
 396
 397
 398
 399
 400
 401
 402
 403
 404
 405
 406
 407
 408
 409
 410
 411
 412
 413
 414
 415
 416
 417
 418
 419
 420
 421
 422
 423
 424
 425
 426
 427
 428
 429
 430
 431
 432
 433
 434
 435
 436
 437
 438
 439
 440
 441
 442
 443
 444
 445
 446
 447
 448
 449
 450
 451
 452
 453
 454
 455
 456
 457
 458
 459
 460
 461
 462
 463
 464
 465
 466
 467
 468
 469
 470
 471
 472
 473
 474
 475
 476
 477
 478
 479
 480
 481
 482
 483
 484
 485
 486
 487
 488
 489
 490
 491
 492
 493
 494
 495
 496
 497
 498
 499
 500
 501
 502
 503
 504
 505
 506
 507
 508
 509
 510
 511
 512
 513
 514
 515
 516
 517
 518
 519
 520
 521
 522
 523
 524
 525
 526
 527
 528
 529
 530
 531
 532
 533
 534
 535
 536
 537
 538
 539
 540
 541
 542
 543
 544
 545
 546
 547
 548
 549
 550
 551
 552
 553
 554
 555
 556
 557
 558
 559
 560
 561
 562
 563
 564
 565
 566
 567
 568
 569
 570
 571
 572
 573
 574
 575
 576
 577
 578
 579
 580
 581
 582
 583
 584
 585
 586
 587
 588
 589
 590
 591
 592
 593
 594
 595
 596
 597
 598
 599
 600
 601
 602
 603
 604
 605
 606
 607
 608
 609
 610
 611
 612
 613
 614
 615
 616
 617
 618
 619
 620
 621
 622
 623
 624
 625
 626
 627
 628
 629
 630
 631
 632
 633
 634
 635
 636
 637
 638
 639
 640
 641
 642
 643
 644
 645
 646
 647
 648
 649
 650
 651
 652
 653
 654
 655
 656
 657
 658
 659
 660
 661
 662
 663
 664
 665
 666
 667
 668
 669
 670
 671
 672
 673
 674
 675
 676
 677
 678
 679
 680
 681
 682
 683
 684
 685
 686
 687
 688
 689
 690
 691
 692
 693
 694
 695
 696
 697
 698
 699
 700
 701
 702
 703
 704
 705
 706
 707
 708
 709
 710
 711
 712
 713
 714
 715
 716
 717
 718
 719
 720
 721
 722
 723
 724
 725
 726
 727
 728
 729
 730
 731
 732
 733
 734
 735
 736
 737
 738
 739
 740
 741
 742
 743
 744
 745
 746
 747
 748
 749
 750
 751
 752
 753
 754
 755
 756
 757
 758
 759
 760
 761
 762
 763
 764
 765
 766
 767
 768
 769
 770
 771
 772
 773
 774
 775
 776
 777
 778
 779
 780
 781
 782
 783
 784
 785
 786
 787
 788
 789
 790
 791
 792
 793
 794
 795
 796
 797
 798
 799
 800
 801
 802
 803
 804
 805
 806
 807
 808
 809
 810
 811
 812
 813
 814
 815
 816
 817
 818
 819
 820
 821
 822
 823
 824
 825
 826
 827
 828
 829
 830
 831
 832
 833
 834
 835
 836
 837
 838
 839
 840
 841
 842
 843
 844
 845
 846
 847
 848
 849
 850
 851
 852
 853
 854
 855
 856
 857
 858
 859
 860
 861
 862
 863
 864
 865
 866
 867
 868
 869
 870
 871
 872
 873
 874
 875
 876
 877
 878
 879
 880
 881
 882
 883
 884
 885
 886
 887
 888
 889
 890
 891
 892
 893
 894
 895
 896
 897
 898
 899
 900
 901
 902
 903
 904
 905
 906
 907
 908
 909
 910
 911
 912
 913
 914
 915
 916
 917
 918
 919
 920
 921
 922
 923
 924
 925
 926
 927
 928
 929
 930
 931
 932
 933
 934
 935
 936
 937
 938
 939
 940
 941
 942
 943
 944
 945
 946
 947
 948
 949
 950
 951
 952
 953
 954
 955
 956
 957
 958
 959
 960
 961
 962
 963
 964
 965
 966
 967
 968
 969
 970
 971
 972
 973
 974
 975
 976
 977
 978
 979
 980
 981
 982
 983
 984
 985
 986
 987
 988
 989
 990
 991
 992
 993
 994
 995
 996
 997
 998
 999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
1520
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
1553
1554
1555
1556
1557
1558
1559
1560
1561
1562
1563
1564
1565
1566
1567
1568
1569
1570
1571
1572
1573
1574
1575
1576
1577
1578
1579
1580
1581
1582
1583
1584
1585
1586
1587
1588
1589
1590
1591
1592
1593
1594
1595
1596
1597
1598
1599
1600
1601
1602
1603
1604
1605
1606
1607
1608
1609
1610
1611
1612
1613
1614
1615
1616
1617
1618
1619
1620
1621
1622
1623
1624
1625
1626
1627
1628
1629
1630
1631
1632
1633
1634
1635
1636
1637
1638
1639
1640
1641
1642
1643
1644
1645
1646
1647
1648
1649
1650
1651
1652
1653
1654
1655
class Model:
    """
    A quantum circuit model.
    """

    # Circuits with at least this many qubits are simulated on "lightning.qubit"
    # instead of "default.qubit"; multithreading is then disabled (see __init__).
    lightning_threshold = 12
    cpu_scaler = 0.9  # default cpu scaler, =1 means full CPU for MP

    def __init__(
        self,
        n_qubits: int,
        n_layers: int,
        circuit_type: Union[str, Circuit, None] = "No_Ansatz",
        data_reupload: Union[bool, List[List[bool]], List[List[List[bool]]]] = True,
        state_preparation: Union[
            str, Callable, List[Union[str, Callable]], None
        ] = None,
        encoding: Union[Encoding, str, Callable, List[Union[str, Callable]]] = Gates.RX,
        trainable_frequencies: bool = False,
        initialization: str = "random",
        initialization_domain: List[float] = [0, 2 * jnp.pi],
        output_qubit: Union[List[int], int] = -1,
        shots: Optional[int] = None,
        random_seed: int = 1000,
        remove_zero_encoding: bool = True,
        use_multithreading: bool = False,
        repeat_batch_axis: List[bool] = [True, True, True],
    ) -> None:
        """
        Initialize the quantum circuit model.
        Parameters will have the shape [impl_n_layers, parameters_per_layer]
        where impl_n_layers is the number of layers provided and added by one
        if the resulting encoding spectrum contains frequencies larger than 1
        (i.e. data re-uploading is effectively used), and parameters_per_layer
        is given by the chosen ansatz.

        The model is initialized with the following parameters as defaults:
        - noise_params: None
        - execution_type: "expval"
        - shots: None

        Args:
            n_qubits (int): The number of qubits in the circuit.
            n_layers (int): The number of layers in the circuit.
            circuit_type (str, Circuit, None): The type of quantum circuit to use,
                either the name of an attribute of `Ansaetze` or a circuit class.
                If None (or an empty string), defaults to "No_Ansatz".
            data_reupload (Union[bool, List[List[bool]], List[List[List[bool]]]],
                optional): Whether to reupload data to the quantum device on each
                layer and qubit. Detailed re-uploading instructions can be given
                as a list/array of 0/False and 1/True with shape
                (n_layers, n_qubits) or (n_layers, n_qubits, n_input_feat) to
                specify where to upload the data. Defaults to True for applying
                data re-uploading to the full circuit. If False, the data is
                encoded only on the first qubit of the first layer.
            state_preparation (Union[str, Callable, List[...], None], optional):
                Gate(s) applied to every qubit before the first layer. Defaults
                to None.
            encoding (Union[Encoding, str, Callable, List[str], List[Callable]],
                optional): The unitary to use for encoding the input data. Can be
                an `Encoding` instance, a string (e.g. "RX") or a callable
                (e.g. qml.RX). Defaults to Gates.RX. If input is
                multidimensional it is assumed to be a list of unitaries or a
                list of strings. Non-`Encoding` values are wrapped in a
                "hamming" `Encoding`.
            trainable_frequencies (bool, optional):
                Sets trainable encoding parameters for trainable frequencies.
                Defaults to False.
            initialization (str, optional): The strategy to initialize the parameters.
                Can be "random", "zeros", "zero-controlled", "pi", or "pi-controlled".
                Defaults to "random".
            initialization_domain (List[float], optional): [min, max] range used
                by the random initialization strategies. Defaults to [0, 2*pi].
            output_qubit (List[int], int, optional): The index of the output
                qubit (or qubits). When set to -1 all qubits are measured, or a
                global measurement is conducted, depending on the execution
                type.
            shots (Optional[int], optional): The number of shots to use for
                the quantum device. Defaults to None.
            random_seed (int, optional): seed for the random number generator
                in initialization is "random" and for random noise parameters.
                Defaults to 1000.
            remove_zero_encoding (bool, optional): whether to
                remove the zero encoding from the circuit. Defaults to True.
            use_multithreading (bool, optional): whether to use JAX
                multithreading to parallelise over batch dimension. Forced to
                False when the "lightning.qubit" device is selected.
            repeat_batch_axis (List[bool], optional): mask applied to the
                (input, param, pulse) batch shape in `eff_batch_shape`.
                The list is copied, so the caller's object is not shared.
                Defaults to [True, True, True].

        Raises:
            ValueError: If `state_preparation` cannot be parsed.
            AssertionError: If an explicit `data_reupload` array has the wrong shape.

        Returns:
            None
        """
        # Initialize default parameters needed for circuit evaluation
        self.n_qubits: int = n_qubits
        # the output_qubit setter reads self.n_qubits, so it must be set first
        self.output_qubit: Union[List[int], int] = output_qubit
        self.n_layers: int = n_layers
        self.noise_params: Optional[Dict[str, Union[float, Dict[str, float]]]] = None
        self.shots = shots
        self.remove_zero_encoding = remove_zero_encoding
        self.use_multithreading = use_multithreading
        self.trainable_frequencies: bool = trainable_frequencies
        # the execution_type setter reads output_qubit and shots (set above)
        self.execution_type: str = "expval"
        # copy, so instances never alias the (mutable) default argument
        self.repeat_batch_axis: List[bool] = list(repeat_batch_axis)

        # --- State Preparation ---
        try:
            self._sp = Gates.parse_gates(state_preparation, Gates)
        except ValueError as e:
            raise ValueError(f"Error parsing state preparation: {e}") from e

        # prepare corresponding pulse parameters (always optimized pulses)
        self.sp_pulse_params = []
        for sp in self._sp:
            sp_name = sp.__name__ if hasattr(sp, "__name__") else str(sp)
            sp_info = pinfo.gate_by_name(sp_name)
            # None marks gates that have no pulse parametrization
            self.sp_pulse_params.append(
                sp_info.params if sp_info is not None else None
            )

        # --- Encoding ---
        if isinstance(encoding, Encoding):
            # user wants custom strategy? do it!
            self._enc = encoding
        else:
            # use hamming encoding by default
            self._enc = Encoding("hamming", encoding)

        # Number of possible inputs
        self.n_input_feat = len(self._enc)
        log.debug(f"Number of input features: {self.n_input_feat}")

        # Trainable frequencies, default initialization as in arXiv:2309.03279v2
        self.enc_params = jnp.ones((self.n_qubits, self.n_input_feat))

        self._zero_inputs = False

        # --- Data-Reuploading ---
        # Normalise the re-uploading strategy to a (n_layers, n_qubits,
        # n_input_feat) mask.
        if isinstance(data_reupload, bool):
            if data_reupload:
                data_reupload = np.ones((n_layers, n_qubits, self.n_input_feat))
                log.debug("Full data reuploading.")
            else:
                data_reupload = np.zeros((n_layers, n_qubits, self.n_input_feat))
                # encode once: first layer, first qubit, all input features
                data_reupload[0][0] = 1
                log.debug("No data reuploading.")
        else:
            if not isinstance(data_reupload, np.ndarray):
                data_reupload = np.array(data_reupload)

            if len(data_reupload.shape) == 2:
                assert data_reupload.shape == (
                    n_layers,
                    n_qubits,
                ), f"Data reuploading array has wrong shape. \
                    Expected {(n_layers, n_qubits)} or\
                    {(n_layers, n_qubits, self.n_input_feat)},\
                    got {data_reupload.shape}."
                # the same mask applies to every input feature
                data_reupload = data_reupload.reshape(*data_reupload.shape, 1)
                data_reupload = np.repeat(data_reupload, self.n_input_feat, axis=2)

            assert data_reupload.shape == (
                n_layers,
                n_qubits,
                self.n_input_feat,
            ), f"Data reuploading array has wrong shape. \
                Expected {(n_layers, n_qubits, self.n_input_feat)},\
                got {data_reupload.shape}."

            log.debug(f"Data reuploading array:\n{data_reupload}")

        # convert to boolean values
        data_reupload = data_reupload.astype(bool)
        self.data_reupload = jnp.array(data_reupload)

        self.degree: Tuple = tuple(
            self._enc.get_n_freqs(jnp.count_nonzero(self.data_reupload[..., i]))
            for i in range(self.n_input_feat)
        )

        self.frequencies: Tuple = tuple(
            self._enc.get_spectrum(jnp.count_nonzero(self.data_reupload[..., i]))
            for i in range(self.n_input_feat)
        )

        self.has_dru = jnp.max(jnp.array([jnp.max(f) for f in self.frequencies])) > 1

        # the number of implicit layers depends only on the resulting spectrum
        if self.has_dru:
            impl_n_layers: int = n_layers + 1  # we need L+1 according to Schuld et al.
        else:
            impl_n_layers = n_layers
        log.info(f"Number of implicit layers: {impl_n_layers}.")

        # --- Ansatz ---
        # only weak check for str. We trust the user to provide sth useful.
        # None (and "") fall back to "No_Ansatz", as documented.
        if circuit_type is None or isinstance(circuit_type, str):
            self.pqc: Callable[[Optional[jnp.ndarray], int], int] = getattr(
                Ansaetze, circuit_type or "No_Ansatz"
            )()
        else:
            self.pqc = circuit_type()
        log.info(f"Using Ansatz {circuit_type}.")

        # calculate the shape of the parameter vector here, we will re-use this in init.
        params_per_layer = self.pqc.n_params_per_layer(self.n_qubits)
        self._params_shape: Tuple[int, int] = (impl_n_layers, params_per_layer)
        log.info(f"Parameters per layer: {params_per_layer}")

        pulse_params_per_layer = self.pqc.n_pulse_params_per_layer(self.n_qubits)
        self._pulse_params_shape: Tuple[int, int] = (
            impl_n_layers,
            pulse_params_per_layer,
        )

        # initialize to None as we can't know this yet
        self._batch_shape = None

        # this will also be re-used in the init method,
        # however, only if nothing is provided
        self._inialization_strategy = initialization
        self._initialization_domain = initialization_domain

        # ..here! where we only require a JAX random key
        self.random_key = self.initialize_params(random.key(random_seed))

        # Initializing pulse params
        self.pulse_params: jnp.ndarray = jnp.ones((*self._pulse_params_shape, 1))

        log.info(f"Initialized pulse parameters with shape {self.pulse_params.shape}.")

        # Initialize two circuits, one with the default (state-vector) device
        # and one with the mixed device, which allows us to later route
        # depending on the execution requirements.
        if self.n_qubits < self.lightning_threshold:
            device = "default.qubit"
        else:
            device = "lightning.qubit"
            self.use_multithreading = False
        self.circuit: qml.QNode = qml.QNode(
            self._circuit,
            qml.device(
                device,
                shots=self.shots,
                wires=self.n_qubits,
            ),
            interface="jax-jit",
            diff_method="parameter-shift" if self.shots is not None else "best",
        )

        self.circuit_mixed: qml.QNode = qml.QNode(
            self._circuit,
            qml.device("default.mixed", shots=self.shots, wires=self.n_qubits),
            interface="jax-jit",
            diff_method="parameter-shift" if self.shots is not None else "best",
        )

    @property
    def noise_params(self) -> Optional[Dict[str, Union[float, Dict[str, float]]]]:
        """
        Noise configuration currently attached to the model.

        Returns:
            Optional[Dict[str, Union[float, Dict[str, float]]]]: Mapping from
            noise channel name to its parameter, or None when no noise is set.
        """
        current_noise = self._noise_params
        return current_noise

    @noise_params.setter
    def noise_params(
        self, kvs: Optional[Dict[str, Union[float, Dict[str, float]]]]
    ) -> None:
        """
        Sets the noise parameters of the model.

        Typically a "noise parameter" refers to the error probability.
        ThermalRelaxation is a special case, and supports a dict as value with
        structure:
            "ThermalRelaxation":
            {
                "t1": 2000,  # relative t1 time
                "t2": 1000,  # relative t2 time
                "t_factor": 1,  # relative gate time factor
            },

        Missing noise types are filled in with defaults (0.0, or None for
        ThermalRelaxation); unsupported keys only trigger a UserWarning.
        The given dictionary (and a nested "ThermalRelaxation" dict) is copied
        before defaults are filled in, so the caller's objects are not modified.

        Args:
            kvs (Optional[Dict[str, Union[float, Dict[str, float]]]]): A
            dictionary of noise parameters. If all values are 0.0, the noise
            parameters are set to None.

        Returns:
            None
        """
        # set to None if only zero values provided
        if kvs is not None and all(v == 0.0 for v in kvs.values()):
            kvs = None

        if kvs is not None:
            # work on a copy: filling in defaults must not mutate the caller's dict
            kvs = dict(kvs)

            # set default values (user keys keep their original order first)
            defaults = {
                "BitFlip": 0.0,
                "PhaseFlip": 0.0,
                "Depolarizing": 0.0,
                "MultiQubitDepolarizing": 0.0,
                "AmplitudeDamping": 0.0,
                "PhaseDamping": 0.0,
                "GateError": 0.0,
                "ThermalRelaxation": None,
                "StatePreparation": 0.0,
                "Measurement": 0.0,
            }
            for key, default_val in defaults.items():
                kvs.setdefault(key, default_val)

            # check if there are any keys not supported
            for key in kvs:
                if key not in defaults:
                    warnings.warn(
                        f"Noise type {key} is not supported by this package",
                        UserWarning,
                    )

            # check valid params for thermal relaxation noise channel
            tr_params = kvs["ThermalRelaxation"]
            if isinstance(tr_params, dict):
                # copy the nested dict too, and store the completed copy
                tr_params = dict(tr_params)
                kvs["ThermalRelaxation"] = tr_params
                tr_params.setdefault("t1", 0.0)
                tr_params.setdefault("t2", 0.0)
                tr_params.setdefault("t_factor", 0.0)
                valid_tr_keys = {"t1", "t2", "t_factor"}
                for k in tr_params:
                    if k not in valid_tr_keys:
                        warnings.warn(
                            f"Thermal Relaxation parameter {k} is not supported "
                            f"by this package",
                            UserWarning,
                        )
                # all parameters must be non-zero and satisfy t2 <= 2 * t1
                if not all(tr_params.values()) or tr_params["t2"] > 2 * tr_params["t1"]:
                    warnings.warn(
                        "Received invalid values for Thermal Relaxation noise "
                        "parameter. Thermal relaxation is not applied!",
                        UserWarning,
                    )
                    kvs["ThermalRelaxation"] = 0.0

        self._noise_params = kvs

    @property
    def output_qubit(self) -> List[int]:
        """Qubit indices used for measurement, as normalised by the setter."""
        measured = self._output_qubit
        return measured

    @output_qubit.setter
    def output_qubit(self, value: Union[int, List[int]]) -> None:
        """
        Store the qubit(s) that should be measured.

        Args:
            value: A single qubit index, a list of indices, or -1 as shorthand
                for every qubit. Integers are stored as a list; other values
                are stored as given.

        Raises:
            AssertionError: If a list has more entries than there are qubits,
                or an integer index (other than -1) is not below n_qubits.
        """
        if isinstance(value, int):
            if value == -1:
                # -1 expands to all qubits
                value = list(range(self.n_qubits))
            else:
                assert value < self.n_qubits, (
                    f"Output qubit {value} cannot be larger than {self.n_qubits}."
                )
                value = [value]
        elif isinstance(value, list):
            assert len(value) <= self.n_qubits, (
                f"Size of output_qubit {len(value)} cannot be"
                f"            larger than number of qubits {self.n_qubits}."
            )

        self._output_qubit = value

    @property
    def execution_type(self) -> str:
        """
        Measurement mode used when evaluating the circuit.

        Returns:
            str: One of "expval", "density", "probs" or "state".
        """
        mode = self._execution_type
        return mode

    @execution_type.setter
    def execution_type(self, value: str) -> None:
        """
        Set the execution type and derive the expected result shape.

        Args:
            value (str): One of "density", "expval", "probs" or "state".

        Raises:
            ValueError: If `value` is not a supported execution type.

        Warns:
            UserWarning: For "state" when not all qubits are measured, for
                "probs" without shots, and for "density" with shots.
        """
        if value == "density":
            # reduced density matrix over the output qubits
            dim = 2 ** len(self.output_qubit)
            self._result_shape = (dim, dim)
        elif value == "expval":
            # one value per entry of output_qubit, both for all-qubit
            # measurements and for parity / n-local measurements
            self._result_shape = (len(self.output_qubit),)
        elif value == "probs":
            # in case this is a list of parities,
            # each pair has 2^len(qubits) probabilities
            n_parity = (
                2 ** len(self.output_qubit[0])
                if isinstance(self.output_qubit[0], tuple)
                else 2
            )
            self._result_shape = (len(self.output_qubit), n_parity)
        elif value == "state":
            # NOTE(review): the warning below says "state" ignores output_qubit,
            # yet the shape is derived from it - confirm against consumers of
            # _result_shape whether 2 ** n_qubits is intended here.
            self._result_shape = (2 ** len(self.output_qubit),)
        else:
            raise ValueError(f"Invalid execution type: {value}.")

        if value == "state" and not self.all_qubit_measurement:
            warnings.warn(
                f"{value} measurement does ignore output_qubit, which is "
                f"{self.output_qubit}.",
                UserWarning,
            )

        if value == "probs" and self.shots is None:
            warnings.warn(
                "Setting execution_type to probs without specifying shots.",
                UserWarning,
            )

        if value == "density" and self.shots is not None:
            warnings.warn(
                "Setting execution_type to density with specified shots.",
                UserWarning,
            )

        self._execution_type = value

    @property
    def shots(self) -> Optional[int]:
        """
        Shot count used by the quantum devices.

        Returns:
            Optional[int]: Number of shots, or None for analytic evaluation.
        """
        n_shots = self._shots
        return n_shots

    @shots.setter
    def shots(self, value: Optional[int]) -> None:
        """
        Store the shot count used by the quantum devices.

        Args:
            value (Optional[int]): Number of shots. A plain ``int`` that is
                zero or negative is normalised to None (analytic mode); any
                other value is stored unchanged.

        Returns:
            None
        """
        non_positive_int = type(value) is int and value <= 0
        self._shots = None if non_positive_int else value

    @property
    def params(self) -> jnp.ndarray:
        """Variational parameters, stored with a trailing batch axis."""
        variational = self._params
        return variational

    @params.setter
    def params(self, value: jnp.ndarray) -> None:
        """Store variational parameters; a 2-D array gets a trailing axis of size 1."""
        needs_batch_axis = len(value.shape) == 2
        self._params = (
            value.reshape(tuple(value.shape) + (1,)) if needs_batch_axis else value
        )

    @property
    def enc_params(self) -> jnp.ndarray:
        """Per-qubit, per-feature scaling applied to inputs before encoding."""
        scaling = self._enc_params
        return scaling

    @enc_params.setter
    def enc_params(self, value: jnp.ndarray) -> None:
        """Replace the encoding scaling parameters (stored as given)."""
        self._enc_params = value

    @property
    def pulse_params(self) -> jnp.ndarray:
        """Pulse parameter scalers used when gates run in pulse mode."""
        scalers = self._pulse_params
        return scalers

    @pulse_params.setter
    def pulse_params(self, value: jnp.ndarray) -> None:
        """Replace the pulse parameter scalers (stored as given)."""
        self._pulse_params = value

    @property
    def all_qubit_measurement(self) -> bool:
        """True when output_qubit equals [0, 1, ..., n_qubits - 1]."""
        every_qubit = list(range(self.n_qubits))
        return self.output_qubit == every_qubit

    @property
    def batch_shape(self) -> Tuple[int, ...]:
        """
        Batch shape (B_I, B_P, B_R) recorded by the last model call.

        Returns:
            Tuple[int, ...]: (input_batch, param_batch, pulse_batch), or the
                placeholder (1, 1, 1) when the model has not been called yet.
        """
        recorded = self._batch_shape
        if recorded is not None:
            return recorded
        log.debug("Model was not called yet. Returning (1,1,1) as batch shape.")
        return (1, 1, 1)

    @property
    def eff_batch_shape(self) -> Tuple[int, ...]:
        """
        Batch shape with the repeat_batch_axis mask applied.

        Each entry of batch_shape is multiplied by the matching mask entry and
        zero entries are dropped.

        Returns:
            Tuple[int, ...]: The remaining non-zero batch dimensions (returned
                as a NumPy array).
        """
        masked = np.multiply(np.array(self.batch_shape), self.repeat_batch_axis)
        return masked[np.nonzero(masked)]

    def initialize_params(
        self,
        random_key: Optional[random.PRNGKey] = None,
        repeat: int = 1,
        initialization: Optional[str] = None,
        initialization_domain: Optional[List[float]] = None,
    ) -> random.PRNGKey:
        """
        Initialize the variational parameters of the model.

        The result is stored in `self.params` with shape
        (*self._params_shape, repeat).

        Args:
            random_key (Optional[random.PRNGKey]): JAX random key for initialization.
                If None, uses the model's internal random key.
            repeat (int): Number of parameter sets to create (batch dimension).
                Defaults to 1.
            initialization (Optional[str]): Strategy for parameter initialization.
                Options: "random", "zeros", "pi", "zero-controlled", "pi-controlled".
                The "*-controlled" strategies draw random values and then set the
                parameters at the ansatz's control indices to 0 / pi.
                If None, uses the strategy specified in the constructor.
            initialization_domain (Optional[List[float]]): Domain [min, max] for
                random initialization. If None, uses the domain from constructor.

        Returns:
            random.PRNGKey: Updated random key after initialization.

        Raises:
            ValueError: If an invalid initialization method is specified
                (a subclass of Exception, as previously documented).
        """
        params_shape = (*self._params_shape, repeat)

        # use existing strategy if not specified
        initialization = initialization or self._inialization_strategy
        initialization_domain = initialization_domain or self._initialization_domain

        random_key, sub_key = safe_random_split(
            random_key if random_key is not None else self.random_key
        )

        def sample_uniform() -> jnp.ndarray:
            # uniform samples in [min, max) of the configured domain
            return random.uniform(
                sub_key,
                params_shape,
                minval=initialization_domain[0],
                maxval=initialization_domain[1],
            )

        def set_control_params(params: jnp.ndarray, value: float) -> jnp.ndarray:
            # overwrite the controlled-rotation parameters (if any) with `value`
            indices = self.pqc.get_control_indices(self.n_qubits)
            if indices is None:
                warnings.warn(
                    f"Specified {initialization} but circuit does not contain "
                    "controlled rotation gates. Parameters are initialized randomly.",
                    UserWarning,
                )
                return params
            start, stop, step = indices[0], indices[1], indices[2]
            # JAX arrays are immutable: edit a NumPy copy and convert back
            np_params = np.array(params)
            np_params[:, start:stop:step] = value
            return jnp.array(np_params)

        if initialization == "random":
            self.params = sample_uniform()
        elif initialization == "zeros":
            self.params = jnp.zeros(params_shape)
        elif initialization == "pi":
            self.params = jnp.ones(params_shape) * jnp.pi
        elif initialization == "zero-controlled":
            self.params = set_control_params(sample_uniform(), 0)
        elif initialization == "pi-controlled":
            self.params = set_control_params(sample_uniform(), jnp.pi)
        else:
            raise ValueError("Invalid initialization method")

        log.info(
            f"Initialized parameters with shape {self.params.shape} "
            f"using strategy {initialization}."
        )

        return random_key

    def transform_input(
        self, inputs: jnp.ndarray, enc_params: jnp.ndarray
    ) -> jnp.ndarray:
        """
        Scale input data by the encoding parameters.

        This is the linear input transformation from arXiv:2309.03279v2: the
        encoding angle is the input multiplied by a (trainable) weight.

        Args:
            inputs (jnp.ndarray): Input value(s) to be encoded.
            enc_params (jnp.ndarray): Scalar or array weight, broadcast against
                `inputs`.

        Returns:
            jnp.ndarray: The element-wise product ``inputs * enc_params``.
        """
        scaled_inputs = inputs * enc_params
        return scaled_inputs

    def _iec(
        self,
        inputs: jnp.ndarray,
        data_reupload: jnp.ndarray,
        enc: Encoding,
        enc_params: jnp.ndarray,
        noise_params: Optional[Dict[str, Union[float, Dict[str, float]]]] = None,
        random_key: Optional[random.PRNGKey] = None,
    ) -> None:
        """
        Apply the Input Encoding Circuit (IEC).

        For every (qubit, feature) pair whose `data_reupload[qubit, feature]`
        entry is truthy, the feature's encoding gate `enc[feature]` is applied
        to that qubit with the input scaled by `enc_params[qubit, feature]`.
        Each applied gate receives a fresh sub-key split from `random_key`.

        The whole stage is skipped when zero-encoding removal is enabled, the
        model flagged all-zero inputs, and the input batch size is 1.

        Args:
            inputs (jnp.ndarray): Input data; the last axis indexes features.
            data_reupload (jnp.ndarray): Mask indexed as [qubit, feature].
            enc (Encoding): Encoding providing one gate callable per feature.
            enc_params (jnp.ndarray): Scaling weights indexed as [qubit, feature].
            noise_params (Optional[Dict[str, Union[float, Dict[str, float]]]]):
                Noise parameters forwarded to each gate. Defaults to None.
            random_key (Optional[random.PRNGKey]): Key that is split per applied
                gate. Defaults to None.

        Returns:
            None: Gates are applied in-place to the quantum circuit.
        """
        # check for zero, because due to input validation, input cannot be none
        skip_encoding = (
            self.remove_zero_encoding
            and self._zero_inputs
            and self.batch_shape[0] == 1
        )
        if skip_encoding:
            return

        # qubit-major traversal over the feature (last) axis of the inputs
        positions = (
            (qubit, feat)
            for qubit in range(self.n_qubits)
            for feat in range(inputs.shape[-1])
        )
        for qubit, feat in positions:
            if not data_reupload[qubit, feat]:
                continue
            random_key, gate_key = safe_random_split(random_key)
            encode_gate = enc[feat]
            # ellipsis indexing: inputs are not qubit dependent
            angle = self.transform_input(inputs[..., feat], enc_params[qubit, feat])
            encode_gate(
                angle,
                wires=qubit,
                noise_params=noise_params,
                random_key=gate_key,
            )

    def _circuit(
        self,
        params: jnp.ndarray,
        inputs: jnp.ndarray,
        pulse_params: Optional[jnp.ndarray] = None,
        enc_params: Optional[jnp.ndarray] = None,
        gate_mode: str = "unitary",
        noise_params: Optional[Dict[str, Union[float, Dict[str, float]]]] = None,
        random_key: Optional[random.PRNGKey] = None,
    ) -> Union[float, jnp.ndarray]:
        """
        Quantum function wrapped by the model's QNodes.

        Builds the circuit via `_variational` (forwarding every argument
        unchanged) and then returns whatever `_observable` produces.

        Args:
            params (jnp.ndarray): Variational parameters.
            inputs (jnp.ndarray): Input data.
            pulse_params (Optional[jnp.ndarray]): Pulse parameter scalers for
                pulse-mode execution. Defaults to None.
            enc_params (Optional[jnp.ndarray]): Encoding parameters. Defaults to
                None.
            gate_mode (str): Gate execution mode, "unitary" or "pulse".
                Defaults to "unitary".
            noise_params (Optional[Dict[str, Union[float, Dict[str, float]]]]):
                Noise parameters for simulation. Defaults to None.
            random_key (Optional[random.PRNGKey]): JAX random key for stochastic
                operations. Defaults to None.

        Returns:
            Union[float, jnp.ndarray]: The measurement returned by `_observable`,
                which depends on the configured execution type.
        """
        forwarded = dict(
            pulse_params=pulse_params,
            enc_params=enc_params,
            gate_mode=gate_mode,
            noise_params=noise_params,
            random_key=random_key,
        )
        self._variational(params=params, inputs=inputs, **forwarded)
        measurement = self._observable()
        return measurement

    def _variational(
        self,
        params: jnp.ndarray,
        inputs: jnp.ndarray,
        pulse_params: Optional[jnp.ndarray] = None,
        enc_params: Optional[jnp.ndarray] = None,
        gate_mode: str = "unitary",
        noise_params: Optional[Dict[str, Union[float, Dict[str, float]]]] = None,
        random_key: Optional[random.PRNGKey] = None,
    ) -> None:
        """
        Build the variational quantum circuit structure.

        Constructs the circuit by applying state preparation, alternating
        variational ansatz layers with input encoding layers, and optional
        noise channels.

        Args:
            params (jnp.ndarray): Variational parameters of shape
                (n_layers, n_params_per_layer).
            inputs (jnp.ndarray): Input data of shape (n_input_feat,).
            pulse_params (Optional[jnp.ndarray]): Pulse parameter scalers of shape
                (n_layers, n_pulse_params_per_layer) for pulse-mode execution.
                Defaults to None (uses model's pulse_params).
            enc_params (Optional[jnp.ndarray]): Encoding parameters of shape
                (n_qubits, n_input_feat). Defaults to None (uses model's enc_params).
            gate_mode (str): Gate execution mode, either "unitary" or "pulse".
                Defaults to "unitary".
            noise_params (Optional[Dict[str, Union[float, Dict[str, float]]]]):
                Noise parameters for simulation. Defaults to None.
            random_key (Optional[random.PRNGKey]): JAX random key for stochastic
                operations. Defaults to None.

        Returns:
            None: Gates are applied in-place to the quantum circuit.

        Note:
            Issues RuntimeWarning if called directly without providing parameters
            that would normally be passed through the forward method.
        """
        # TODO: rework and double check params shape
        if len(params.shape) > 2 and params.shape[2] == 1:
            params = params[:, :, 0]

        if len(inputs.shape) > 1 and inputs.shape[0] == 1:
            inputs = inputs[0]

        if enc_params is None:
            # TODO: Raise warning if trainable frequencies is True, or similar. I.e., no
            #   warning if user does not care for frequencies or enc_params
            if self.trainable_frequencies:
                warnings.warn(
                    "Explicit call to `_circuit` or `_variational` detected: "
                    "`enc_params` is None, using `self.enc_params` instead.",
                    RuntimeWarning,
                )
            enc_params = self.enc_params

        if pulse_params is None:
            if gate_mode == "pulse":
                warnings.warn(
                    "Explicit call to `_circuit` or `_variational` detected: "
                    "`pulse_params` is None, using `self.pulse_params` instead.",
                    RuntimeWarning,
                )
            pulse_params = self.pulse_params

        if noise_params is None:
            if self.noise_params is not None:
                warnings.warn(
                    "Explicit call to `_circuit` or `_variational` detected: "
                    "`noise_params` is None, using `self.noise_params` instead.",
                    RuntimeWarning,
                )
                noise_params = self.noise_params

        if noise_params is not None:
            if random_key is None:
                warnings.warn(
                    "Explicit call to `_circuit` or `_variational` detected: "
                    "`random_key` is None, using `random.PRNGKey(0)` instead.",
                    RuntimeWarning,
                )
                random_key = self.random_key
            self._apply_state_prep_noise(noise_params=noise_params)

        # state preparation
        for q in range(self.n_qubits):
            for _sp, sp_pulse_params in zip(self._sp, self.sp_pulse_params):
                random_key, sub_key = safe_random_split(random_key)
                _sp(
                    wires=q,
                    pulse_params=sp_pulse_params,
                    noise_params=noise_params,
                    random_key=sub_key,
                    gate_mode=gate_mode,
                )

        # circuit building
        for layer in range(0, self.n_layers):
            self.random_key, sub_key = safe_random_split(self.random_key)
            # ansatz layers
            self.pqc(
                params[layer],
                self.n_qubits,
                pulse_params=pulse_params[layer],
                noise_params=noise_params,
                random_key=sub_key,
                gate_mode=gate_mode,
            )

            self.random_key, sub_key = safe_random_split(self.random_key)
            # encoding layers
            self._iec(
                inputs,
                data_reupload=self.data_reupload[layer],
                enc=self._enc,
                enc_params=enc_params,
                noise_params=noise_params,
                random_key=sub_key,
            )

            # visual barrier
            if self.has_dru:
                qml.Barrier(wires=list(range(self.n_qubits)), only_visual=True)

        # final ansatz layer
        if self.has_dru:  # same check as in init
            self.random_key, sub_key = safe_random_split(self.random_key)
            self.pqc(
                params[self.n_layers],
                self.n_qubits,
                pulse_params=pulse_params[-1],
                noise_params=noise_params,
                random_key=sub_key,
                gate_mode=gate_mode,
            )

        # channel noise
        if noise_params is not None:
            self._apply_general_noise(noise_params=noise_params)

    def _observable(self) -> Union[jnp.ndarray, List[jnp.ndarray]]:
        """
        Define and return the measurement observable(s) for the circuit.

        Constructs the appropriate PennyLane measurement based on the
        configured execution_type and output_qubit settings.

        Returns:
            Union[jnp.ndarray, List[jnp.ndarray]]: Measurement result(s):
                - "density": qml.density_matrix for output qubits
                - "state": Full quantum state via qml.state()
                - "expval": Expectation value(s) of PauliZ observable(s)
                - "probs": Measurement probabilities

        Raises:
            ValueError: If execution_type or output_qubit configuration is invalid.
        """
        # run mixed simulation and get density matrix
        if self.execution_type == "density":
            return qml.density_matrix(wires=self.output_qubit)
        elif self.execution_type == "state":
            return qml.state()
        # run default simulation and get expectation value
        elif self.execution_type == "expval":
            # n-local measurement
            if self.all_qubit_measurement:
                return [qml.expval(qml.PauliZ(q)) for q in self.output_qubit]
            # parity or local measurement(s)
            elif isinstance(self.output_qubit, list):
                ret = []
                # list of parity pairs
                for pair in self.output_qubit:
                    if isinstance(pair, int):
                        ret.append(qml.expval(qml.PauliZ(pair)))
                    else:
                        obs = qml.PauliZ(pair[0])
                        for q in pair[1:]:
                            obs = obs @ qml.PauliZ(q)
                        ret.append(qml.expval(obs))
                return ret
            else:
                raise ValueError(
                    f"Invalid parameter `output_qubit`: {self.output_qubit}.\
                        Must be int, list or -1."
                )
        # run default simulation and get probs
        elif self.execution_type == "probs":
            # n-local measurement
            if self.all_qubit_measurement:
                return qml.probs(wires=self.output_qubit)
            # parity or local measurement(s)
            elif isinstance(self.output_qubit, list):
                ret = []
                # list of parity pairs
                for pair in self.output_qubit:
                    if isinstance(pair, int):
                        ret.append(qml.probs(wires=[pair]))
                    else:
                        ret.append(qml.probs(wires=pair))
                return ret
            else:
                raise ValueError(
                    f"Invalid parameter `output_qubit`: {self.output_qubit}.\
                        Must be int, list or -1."
                )
        else:
            raise ValueError(f"Invalid execution_type: {self.execution_type}.")

    def _apply_state_prep_noise(
        self, noise_params: Dict[str, Union[float, Dict[str, float]]]
    ) -> None:
        """
        Apply state preparation noise to all qubits.

        Simulates imperfect state preparation by applying BitFlip errors
        to each qubit with the specified probability.

        Args:
            noise_params (Dict[str, Union[float, Dict[str, float]]]): Dictionary
                containing noise parameters. Uses the "StatePreparation" key
                for the BitFlip probability.

        Returns:
            None: Noise channels are applied in-place to the circuit.
        """
        p = noise_params.get("StatePreparation", 0.0)
        if p > 0:
            for q in range(self.n_qubits):
                qml.BitFlip(p, wires=q)

    def _apply_general_noise(
        self, noise_params: Dict[str, Union[float, Dict[str, float]]]
    ) -> None:
        """
        Apply general noise channels to all qubits.

        Applies various decoherence and error channels after the circuit
        execution, simulating environmental noise effects.

        Args:
            noise_params (Dict[str, Union[float, Dict[str, float]]]): Dictionary
                containing noise parameters with the following supported keys:
                - "AmplitudeDamping" (float): Probability for amplitude damping.
                - "PhaseDamping" (float): Probability for phase damping.
                - "Measurement" (float): Probability for measurement error (BitFlip).
                - "ThermalRelaxation" (Dict): Dictionary with keys "t1", "t2",
                  "t_factor" for thermal relaxation simulation.

        Returns:
            None: Noise channels are applied in-place to the circuit.

        Note:
            Gate-level noise (e.g., GateError) is handled separately in the
            Gates.Noise module and applied at the individual gate level.
        """
        amp_damp = noise_params.get("AmplitudeDamping", 0.0)
        phase_damp = noise_params.get("PhaseDamping", 0.0)
        thermal_relax = noise_params.get("ThermalRelaxation", 0.0)
        meas = noise_params.get("Measurement", 0.0)
        for q in range(self.n_qubits):
            if amp_damp > 0:
                qml.AmplitudeDamping(amp_damp, wires=q)
            if phase_damp > 0:
                qml.PhaseDamping(phase_damp, wires=q)
            if meas > 0:
                qml.BitFlip(meas, wires=q)
            if isinstance(thermal_relax, dict):
                t1 = thermal_relax["t1"]
                t2 = thermal_relax["t2"]
                t_factor = thermal_relax["t_factor"]
                circuit_depth = self._get_circuit_depth()
                tg = circuit_depth * t_factor
                qml.ThermalRelaxationError(1.0, t1, t2, tg, q)

    def _get_circuit_depth(self, inputs: Optional[jnp.ndarray] = None) -> int:
        """
        Return the depth of the (noise-free) quantum circuit.

        The inputs are normalised first. A deep copy of the model with its
        noise parameters cleared is then traced with PennyLane's ``qml.specs``
        using the model's current parameters, and the resulting depth is
        reported.

        Args:
            inputs (Optional[jnp.ndarray]): Input data used for tracing.
                None means default zero inputs.

        Returns:
            int: The circuit depth as reported by ``specs["resources"].depth``.
        """
        validated_inputs = self._inputs_validation(inputs)

        # work on a clone so the real model keeps its noise configuration
        noiseless_model = deepcopy(self)
        noiseless_model.noise_params = None

        circuit_specs = qml.specs(noiseless_model.circuit)(self.params, validated_inputs)
        resources = circuit_specs["resources"]
        return resources.depth

    def draw(
        self,
        inputs: Optional[jnp.ndarray] = None,
        figure: str = "text",
        *args: Any,
        **kwargs: Any,
    ) -> Union[str, Any]:
        """
        Visualize the quantum circuit.

        Generates a visual representation of the circuit using the specified
        rendering method.

        Args:
            inputs (Optional[jnp.ndarray]): Input data for the circuit. If None,
                default zero inputs are used. Defaults to None.
            figure (str): Visualization format. Options:
                - "text": ASCII text representation
                - "mpl": Matplotlib figure
                - "tikz": TikZ/LaTeX code for publication-quality figures
                Defaults to "text".
            *args (Any): Additional positional arguments passed to the
                visualization backend.
            **kwargs (Any): Additional keyword arguments passed to the
                visualization backend (for every figure type). May include
                pulse_params, gate_mode, enc_params, or noise_params.

        Returns:
            Union[str, Any]: Visualization output:
                - "text": String with ASCII circuit diagram
                - "mpl": Matplotlib Figure and Axes objects
                - "tikz": TikZ code string
                An empty string is returned if ``self.circuit`` is not a QNode.

        Raises:
            AssertionError: If figure is not one of "text", "mpl", or "tikz".
        """

        if not isinstance(self.circuit, qml.QNode):
            # TODO: throws strange argument error if not catched
            return ""

        # Explicit raise instead of `assert`: the check must survive `python -O`.
        # AssertionError is kept so existing callers catching it still work.
        if figure not in ("text", "mpl", "tikz"):
            raise AssertionError(
                f"Invalid figure: {figure}. Must be 'text', 'mpl' or 'tikz'."
            )

        inputs = self._inputs_validation(inputs)

        if figure == "mpl":
            return qml.draw_mpl(self.circuit)(
                params=self.params,
                inputs=inputs,
                *args,
                **kwargs,
            )
        elif figure == "tikz":
            return QuanTikz.build(
                self.circuit,
                params=self.params,
                inputs=inputs,
                *args,
                **kwargs,
            )
        else:
            # forward the extra arguments here too (e.g. gate_mode="pulse"),
            # consistent with the mpl/tikz branches and the docstring
            return qml.draw(self.circuit)(
                params=self.params,
                inputs=inputs,
                *args,
                **kwargs,
            )

    def __repr__(self) -> str:
        """Return text representation of the quantum circuit."""
        return self.draw(figure="text")

    def __str__(self) -> str:
        """Return string representation of the quantum circuit."""
        return self.draw(figure="text")

    def _params_validation(self, params: Optional[jnp.ndarray]) -> jnp.ndarray:
        """
        Validate and normalize variational parameters.

        Ensures parameters have the correct shape with a batch dimension,
        and updates the model's internal parameters if new ones are provided.

        Args:
            params (Optional[jnp.ndarray]): Variational parameters to validate.
                If None, returns the model's current parameters.

        Returns:
            jnp.ndarray: Validated parameters with shape
                (n_layers, n_params_per_layer, batch_size).
        """
        # append batch axis if not provided
        if params is not None:
            if len(params.shape) == 2:
                params = np.expand_dims(params, axis=-1)

            self.params = params
        else:
            params = self.params

        return params

    def _pulse_params_validation(
        self, pulse_params: Optional[jnp.ndarray]
    ) -> jnp.ndarray:
        """
        Validate and normalize pulse parameters.

        Ensures pulse parameters are set, using model defaults if not provided.

        Args:
            pulse_params (Optional[jnp.ndarray]): Pulse parameter scalers.
                If None, returns the model's current pulse parameters.

        Returns:
            jnp.ndarray: Validated pulse parameters with shape
                (n_layers, n_pulse_params_per_layer, batch_size).
        """
        if pulse_params is None:
            pulse_params = self.pulse_params
        else:
            self.pulse_params = pulse_params

        return pulse_params

    def _enc_params_validation(self, enc_params: Optional[jnp.ndarray]) -> jnp.ndarray:
        """
        Validate and normalize encoding parameters.

        Ensures encoding parameters have the correct shape for the model's
        input feature dimensions.

        Args:
            enc_params (Optional[jnp.ndarray]): Encoding parameters to validate.
                If None, returns the model's current encoding parameters.

        Returns:
            jnp.ndarray: Validated encoding parameters with shape
                (n_qubits, n_input_feat).

        Raises:
            ValueError: If enc_params shape is incompatible with n_input_feat > 1.
        """
        if enc_params is None:
            enc_params = self.enc_params
        else:
            if self.trainable_frequencies:
                self.enc_params = enc_params
            else:
                self.enc_params = jnp.array(enc_params)

        if len(enc_params.shape) == 1 and self.n_input_feat == 1:
            enc_params = enc_params.reshape(-1, 1)
        elif len(enc_params.shape) == 1 and self.n_input_feat > 1:
            raise ValueError(
                f"Input dimension {self.n_input_feat} >1 but \
                `enc_params` has shape {enc_params.shape}"
            )

        return enc_params

    def _inputs_validation(
        self, inputs: Union[None, List, float, int, jnp.ndarray]
    ) -> jnp.ndarray:
        """
        Validate and normalize input data.

        Converts various input formats to a standardized 2D array shape
        suitable for batch processing in the quantum circuit.

        Args:
            inputs (Union[None, List, float, int, jnp.ndarray]): Input data in
                various formats:
                - None: Returns zeros with shape (1, n_input_feat)
                - float/int: Single scalar value
                - List: List of values or batched inputs
                - jnp.ndarray: NumPy/JAX array

        Returns:
            jnp.ndarray: Validated inputs with shape (batch_size, n_input_feat).

        Raises:
            ValueError: If input shape is incompatible with expected n_input_feat.

        Warns:
            UserWarning: If input is replicated to match n_input_feat.
        """
        self._zero_inputs = False
        if isinstance(inputs, List):
            inputs = jnp.array(np.stack(inputs))
        elif isinstance(inputs, float) or isinstance(inputs, int):
            inputs = jnp.array([inputs])
        elif inputs is None:
            inputs = jnp.array([[0] * self.n_input_feat])

        if not inputs.any():
            self._zero_inputs = True

        if len(inputs.shape) <= 1:
            if self.n_input_feat == 1:
                # add a batch dimension
                inputs = inputs.reshape(-1, 1)
            else:
                if inputs.shape[0] == self.n_input_feat:
                    inputs = inputs.reshape(1, -1)
                else:
                    inputs = inputs.reshape(-1, 1)
                    inputs = inputs.repeat(self.n_input_feat, axis=1)
                    warnings.warn(
                        f"Expected {self.n_input_feat} inputs, but {inputs.shape[0]} "
                        "was provided, replicating input for all input features.",
                        UserWarning,
                    )
        else:
            if inputs.shape[1] != self.n_input_feat:
                raise ValueError(
                    f"Wrong number of inputs provided. Expected {self.n_input_feat} "
                    f"inputs, but input has shape {inputs.shape}."
                )

        return inputs

    def _mp_executor(
        self,
        f: Callable,
        params: jnp.ndarray,
        pulse_params: jnp.ndarray,
        inputs: jnp.ndarray,
        enc_params: jnp.ndarray,
        noise_params: Optional[Dict[str, Union[float, Dict[str, float]]]],
        random_key: random.PRNGKey,
        gate_mode: str,
    ) -> jnp.ndarray:
        """
        Execute circuit function with optional parallelization over batches.

        Uses JAX's vmap for vectorized execution when batching over inputs,
        parameters, or pulse parameters. Falls back to sequential execution
        for single samples or when multithreading is disabled.

        Args:
            f (Callable): Circuit function to execute (circuit or circuit_mixed).
            params (jnp.ndarray): Variational parameters of shape
                (n_layers, n_params_per_layer, batch_size).
            pulse_params (jnp.ndarray): Pulse parameters of shape
                (n_layers, n_pulse_params_per_layer, batch_size).
            inputs (jnp.ndarray): Input data of shape (batch_size, n_input_feat).
            enc_params (jnp.ndarray): Encoding parameters of shape
                (n_qubits, n_input_feat).
            noise_params (Optional[Dict[str, Union[float, Dict[str, float]]]]):
                Noise configuration dictionary.
            random_key (random.PRNGKey): JAX random key for stochastic operations.
            gate_mode (str): Gate execution mode ("unitary" or "pulse").

        Returns:
            jnp.ndarray: Circuit execution results, post-processed for uniformity.
        """

        def _f(
            _params: jnp.ndarray,
            _inputs: jnp.ndarray,
            _pulse_params: jnp.ndarray,
            _random_key: random.PRNGKey,
        ) -> jnp.ndarray:
            return f(
                params=_params,
                inputs=_inputs,
                pulse_params=_pulse_params,
                random_key=_random_key,
                noise_params=noise_params,
                enc_params=enc_params,
                gate_mode=gate_mode,
            )

        B = np.prod(self.eff_batch_shape)
        if (gate_mode == "pulse" or self.use_multithreading) and B > 1:
            random_keys = safe_random_split(random_key, num=B)

            # wrapper to allow kwargs (not supported by jax)
            result = jax.vmap(
                _f,
                in_axes=(
                    2 if self.batch_shape[1] > 1 else None,  # params
                    0 if self.batch_shape[0] > 1 else None,  # inputs
                    2 if self.batch_shape[2] > 1 else None,  # pulse_params
                    0,  # random_keys
                ),
            )(
                params,
                inputs,
                pulse_params,
                random_keys,
            )
        else:
            result = _f(
                _params=params,
                _pulse_params=pulse_params,
                _inputs=inputs,
                _random_key=random_key,
            )

        return self._postprocess_res(result)

    def _postprocess_res(self, result: Union[List, jnp.ndarray]) -> jnp.ndarray:
        """
        Post-process circuit execution results for uniform shape.

        Converts list outputs (from multiple measurements) to stacked arrays
        and reorders axes for consistent batch dimension placement.

        Args:
            result (Union[List, jnp.ndarray]): Raw circuit output, either a
                list of measurement results or a single array.

        Returns:
            jnp.ndarray: Uniformly shaped result array with batch dimension first.
        """
        if isinstance(result, list):
            # we use moveaxis here because in case of parity measure,
            # there is another dimension appended to the end and
            # simply transposing would result in a wrong shape
            result = jnp.stack(result)
            if len(result.shape) > 1:
                result = jnp.moveaxis(result, 0, 1)
        return result

    def _assimilate_batch(
        self,
        inputs: jnp.ndarray,
        params: jnp.ndarray,
        pulse_params: jnp.ndarray,
    ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
        """
        Align batch dimensions across inputs, parameters, and pulse parameters.

        Broadcasts and reshapes arrays to have compatible batch dimensions
        for vectorized circuit execution. Sets the internal batch_shape.

        Args:
            inputs (jnp.ndarray): Input data of shape (B_I, n_input_feat).
            params (jnp.ndarray): Parameters of shape (n_layers, n_params, B_P).
            pulse_params (jnp.ndarray): Pulse params of shape (n_layers, n_pulse, B_R).

        Returns:
            Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: Tuple containing:
                - inputs: Reshaped to (B, n_input_feat) where B = B_I * B_P * B_R
                - params: Reshaped to (n_layers, n_params, B)
                - pulse_params: Reshaped to (n_layers, n_pulse, B)

        Note:
            The effective batch shape depends on repeat_batch_axis configuration.
            This is the only method that sets self._batch_shape.
        """
        B_I = inputs.shape[0]
        # we check for the product because there is a chance that
        # there are no params. In this case we want B_P to be 1
        B_P = 1 if 0 in params.shape else params.shape[-1]
        B_R = pulse_params.shape[-1]

        # THIS is the only place where we set the batch shape
        self._batch_shape = (B_I, B_P, B_R)
        B = np.prod(self.eff_batch_shape)

        # [B_I, ...] -> [B_I, B_P, B_R, ...] -> [B, ...]
        if B_I > 1 and self.repeat_batch_axis[0]:
            if self.repeat_batch_axis[1]:
                inputs = jnp.repeat(inputs[:, None, None, ...], B_P, axis=1)
            if self.repeat_batch_axis[2]:
                inputs = jnp.repeat(inputs, B_R, axis=2)
            inputs = inputs.reshape(B, *inputs.shape[3:])

        # [..., ..., B_P] -> [..., ..., B_I, B_P, B_R] -> [..., ..., B]
        if B_P > 1 and self.repeat_batch_axis[1]:
            # add B_I axis before last, and B_R axis after last
            params = params[..., None, :, None]  # [..., B_I(=1), B_P, B_R(=1)]
            if self.repeat_batch_axis[0]:
                params = jnp.repeat(params, B_I, axis=-3)  # [..., B_I, B_P, 1]
            if self.repeat_batch_axis[2]:
                params = jnp.repeat(params, B_R, axis=-1)  # [..., B_I, B_P, B_R]
            params = params.reshape(*params.shape[:-3], B)

        # [..., ..., B_R] -> [..., ..., B_I, B_P, B_R] -> [..., ..., B]
        if B_R > 1 and self.repeat_batch_axis[2]:
            # add B_I axis before last, and B_P axis before last (after adding B_I)
            pulse_params = pulse_params[
                ..., None, None, :
            ]  # [..., B_I(=1), B_P(=1), B_R]
            if self.repeat_batch_axis[0]:
                pulse_params = jnp.repeat(
                    pulse_params, B_I, axis=-3
                )  # [..., B_I, 1, B_R]
            if self.repeat_batch_axis[1]:
                pulse_params = jnp.repeat(
                    pulse_params, B_P, axis=-2
                )  # [..., B_I, B_P, B_R]
            pulse_params = pulse_params.reshape(*pulse_params.shape[:-3], B)

        return inputs, params, pulse_params

    def _requires_density(self) -> bool:
        """
        Check if density matrix simulation is required.

        Determines whether the circuit must be executed with the mixed-state
        simulator based on execution type and noise configuration.

        Returns:
            bool: True if density matrix simulation is required, False otherwise.
                Returns True if:
                - execution_type is "density", or
                - Any non-coherent noise channel has non-zero probability
        """
        if self.execution_type == "density":
            return True

        if self.noise_params is None:
            return False

        coherent_noise = {"GateError"}
        for k, v in self.noise_params.items():
            if k in coherent_noise:
                continue
            if v is not None and v > 0:
                return True
        return False

    def __call__(
        self,
        params: Optional[jnp.ndarray] = None,
        inputs: Optional[jnp.ndarray] = None,
        pulse_params: Optional[jnp.ndarray] = None,
        enc_params: Optional[jnp.ndarray] = None,
        noise_params: Optional[Dict[str, Union[float, Dict[str, float]]]] = None,
        execution_type: Optional[str] = None,
        force_mean: bool = False,
        gate_mode: str = "unitary",
    ) -> jnp.ndarray:
        """
        Execute the quantum circuit (callable interface).

        Provides a convenient callable interface for circuit execution,
        delegating to the _forward method.

        Args:
            params (Optional[jnp.ndarray]): Variational parameters of shape
                (n_layers, n_params_per_layer) or (n_layers, n_params_per_layer, batch).
                If None, uses model's internal parameters.
            inputs (Optional[jnp.ndarray]): Input data of shape
                (batch_size, n_input_feat). If None, uses zero inputs.
            pulse_params (Optional[jnp.ndarray]): Pulse parameter scalers for
                pulse-mode gate execution.
            enc_params (Optional[jnp.ndarray]): Encoding parameters of shape
                (n_qubits, n_input_feat). If None, uses model's encoding parameters.
            noise_params (Optional[Dict[str, Union[float, Dict[str, float]]]]):
                Noise configuration. If None, uses previously set noise parameters.
            execution_type (Optional[str]): Measurement type: "expval", "density",
                "probs", or "state". If None, uses current execution_type setting.
            force_mean (bool): If True, averages results over measurement qubits.
                Defaults to False.
            gate_mode (str): Gate execution backend, "unitary" or "pulse".
                Defaults to "unitary".

        Returns:
            jnp.ndarray: Circuit output with shape depending on execution_type:
                - "expval": (n_output_qubits,) or scalar
                - "density": (2^n_output, 2^n_output)
                - "probs": (2^n_output,) or (n_pairs, 2^pair_size)
                - "state": (2^n_qubits,)
        """
        # Call forward method which handles the actual caching etc.
        return self._forward(
            params=params,
            inputs=inputs,
            pulse_params=pulse_params,
            enc_params=enc_params,
            noise_params=noise_params,
            execution_type=execution_type,
            force_mean=force_mean,
            gate_mode=gate_mode,
        )

    def _forward(
        self,
        params: Optional[jnp.ndarray] = None,
        inputs: Optional[jnp.ndarray] = None,
        pulse_params: Optional[jnp.ndarray] = None,
        enc_params: Optional[jnp.ndarray] = None,
        noise_params: Optional[Dict[str, Union[float, Dict[str, float]]]] = None,
        execution_type: Optional[str] = None,
        force_mean: bool = False,
        gate_mode: str = "unitary",
    ) -> jnp.ndarray:
        """
        Execute the quantum circuit forward pass.

        Internal implementation of the forward pass that handles parameter
        validation, batch alignment, and circuit execution routing.

        Args:
            params (Optional[jnp.ndarray]): Variational parameters of shape
                (n_layers, n_params_per_layer) or
                (n_layers, n_params_per_layer, batch).
                If None, uses model's internal parameters.
            inputs (Optional[jnp.ndarray]): Input data of shape
                (batch_size, n_input_feat).
                If None, uses zero inputs.
            pulse_params (Optional[jnp.ndarray]): Pulse parameter scalers for
                pulse-mode gate execution.
            enc_params (Optional[jnp.ndarray]): Encoding parameters of shape
                (n_qubits, n_input_feat). If None, uses model's encoding parameters.
            noise_params (Optional[Dict[str, Union[float, Dict[str, float]]]]):
                Noise configuration. If None, uses previously set noise parameters.
            execution_type (Optional[str]): Measurement type: "expval", "density",
                "probs", or "state". If None, uses current execution_type setting.
            force_mean (bool): If True, averages results over measurement qubits.
                Defaults to False.
            gate_mode (str): Gate execution backend, "unitary" or "pulse".
                Defaults to "unitary".

        Returns:
            jnp.ndarray: Circuit output with shape depending on execution_type:
                - "expval": (n_output_qubits,) or scalar
                - "density": (2^n_output, 2^n_output)
                - "probs": (2^n_output,) or (n_pairs, 2^pair_size)
                - "state": (2^n_qubits,)

        Raises:
            ValueError: If pulse_params provided without pulse gate_mode, or
                if noise_params provided with pulse gate_mode.
        """
        # set the parameters as object attributes
        if noise_params is not None:
            self.noise_params = noise_params
        if execution_type is not None:
            self.execution_type = execution_type
        self.gate_mode = gate_mode

        # consistency checks
        if pulse_params is not None and gate_mode != "pulse":
            raise ValueError(
                "pulse_params were provided but gate_mode is not 'pulse'. "
                "Either switch gate_mode='pulse' or do not pass pulse_params."
            )

        if noise_params is not None and gate_mode == "pulse":
            raise ValueError(
                "Noise is not supported in 'pulse' gate_mode. "
                "Either remove noise_params or use gate_mode='unitary'."
            )

        params = self._params_validation(params)
        pulse_params = self._pulse_params_validation(pulse_params)
        inputs = self._inputs_validation(inputs)
        enc_params = self._enc_params_validation(enc_params)

        inputs, params, pulse_params = self._assimilate_batch(
            inputs,
            params,
            pulse_params,
        )

        result: Optional[jnp.ndarray] = None
        self.random_key, subkey = safe_random_split(self.random_key)

        # if density matrix requested or noise params used
        if self._requires_density():
            result = self._mp_executor(
                f=self.circuit_mixed,
                params=params,
                pulse_params=pulse_params,
                inputs=inputs,
                enc_params=enc_params,
                noise_params=self.noise_params,
                random_key=subkey,
                gate_mode=gate_mode,
            )
        else:
            if not isinstance(self.circuit, qml.QNode):
                result = self.circuit(
                    inputs=inputs,
                )
            else:
                result = self._mp_executor(
                    f=self.circuit,
                    params=params,
                    pulse_params=pulse_params,
                    inputs=inputs,
                    enc_params=enc_params,
                    noise_params=self.noise_params,
                    random_key=subkey,
                    gate_mode=gate_mode,
                )

        result = result.reshape((*self.eff_batch_shape, *self._result_shape)).squeeze()

        if (
            self.execution_type in ("expval", "probs")
            and force_mean
            and len(result.shape) > 0
            and self._result_shape[0] > 1
        ):
            result = result.mean(axis=-1)

        return result

all_qubit_measurement property #

Check if measurement is performed on all qubits.

batch_shape property #

Get the batch shape (B_I, B_P, B_R). If the model was not called before, it returns (1, 1, 1).

Returns:

Type Description
Tuple[int, ...]

Tuple[int, ...]: Tuple of (input_batch, param_batch, pulse_batch). Returns (1, 1, 1) if model has not been called yet.

eff_batch_shape property #

Get the effective batch shape after applying repeat_batch_axis mask.

Returns:

Type Description
Tuple[int, ...]

Tuple[int, ...]: Effective batch dimensions, excluding zeros.

enc_params property writable #

Get the encoding parameters used for input transformation.

execution_type property writable #

Gets the execution type of the model.

Returns:

Name Type Description
str str

The execution type, one of 'density', 'expval', 'probs', or 'state'.

noise_params property writable #

Gets the noise parameters of the model.

Returns:

Type Description
Optional[Dict[str, Union[float, Dict[str, float]]]]

Optional[Dict[str, Union[float, Dict[str, float]]]]: A dictionary of noise parameters, or None if not set.

output_qubit property writable #

Get the output qubit indices for measurement.

params property writable #

Get the variational parameters of the model.

pulse_params property writable #

Get the pulse parameters for pulse-mode gate execution.

shots property writable #

Gets the number of shots to use for the quantum device.

Returns:

Type Description
Optional[int]

Optional[int]: The number of shots.

__call__(params=None, inputs=None, pulse_params=None, enc_params=None, noise_params=None, execution_type=None, force_mean=False, gate_mode='unitary') #

Execute the quantum circuit (callable interface).

Provides a convenient callable interface for circuit execution, delegating to the _forward method.

Parameters:

Name Type Description Default
params Optional[ndarray]

Variational parameters of shape (n_layers, n_params_per_layer) or (n_layers, n_params_per_layer, batch). If None, uses model's internal parameters.

None
inputs Optional[ndarray]

Input data of shape (batch_size, n_input_feat). If None, uses zero inputs.

None
pulse_params Optional[ndarray]

Pulse parameter scalers for pulse-mode gate execution.

None
enc_params Optional[ndarray]

Encoding parameters of shape (n_qubits, n_input_feat). If None, uses model's encoding parameters.

None
noise_params Optional[Dict[str, Union[float, Dict[str, float]]]]

Noise configuration. If None, uses previously set noise parameters.

None
execution_type Optional[str]

Measurement type: "expval", "density", "probs", or "state". If None, uses current execution_type setting.

None
force_mean bool

If True, averages results over measurement qubits. Defaults to False.

False
gate_mode str

Gate execution backend, "unitary" or "pulse". Defaults to "unitary".

'unitary'

Returns:

Type Description
ndarray

jnp.ndarray: Circuit output whose shape depends on execution_type: "expval": (n_output_qubits,) or scalar; "density": (2^n_output, 2^n_output); "probs": (2^n_output,) or (n_pairs, 2^pair_size); "state": (2^n_qubits,).

Source code in qml_essentials/model.py
def __call__(
    self,
    params: Optional[jnp.ndarray] = None,
    inputs: Optional[jnp.ndarray] = None,
    pulse_params: Optional[jnp.ndarray] = None,
    enc_params: Optional[jnp.ndarray] = None,
    noise_params: Optional[Dict[str, Union[float, Dict[str, float]]]] = None,
    execution_type: Optional[str] = None,
    force_mean: bool = False,
    gate_mode: str = "unitary",
) -> jnp.ndarray:
    """
    Run the circuit; makes the model instance callable.

    This is a pure pass-through: every argument is handed unchanged, by
    keyword, to ``self._forward``, whose return value is returned as-is.

    Args:
        params (Optional[jnp.ndarray]): Variational parameters of shape
            (n_layers, n_params_per_layer) or
            (n_layers, n_params_per_layer, batch). None means the model's
            stored parameters are used.
        inputs (Optional[jnp.ndarray]): Input data of shape
            (batch_size, n_input_feat). None means zero inputs.
        pulse_params (Optional[jnp.ndarray]): Pulse parameter scalers used
            when gate_mode is "pulse".
        enc_params (Optional[jnp.ndarray]): Encoding parameters of shape
            (n_qubits, n_input_feat). None means the model's encoding
            parameters are used.
        noise_params (Optional[Dict[str, Union[float, Dict[str, float]]]]):
            Noise configuration. None keeps the previously set noise.
        execution_type (Optional[str]): One of "expval", "density", "probs"
            or "state". None keeps the current execution_type.
        force_mean (bool): Average the result over the measured qubits.
            Defaults to False.
        gate_mode (str): Gate backend, "unitary" or "pulse".
            Defaults to "unitary".

    Returns:
        jnp.ndarray: The circuit output; its shape depends on execution_type:
            - "expval": (n_output_qubits,) or scalar
            - "density": (2^n_output, 2^n_output)
            - "probs": (2^n_output,) or (n_pairs, 2^pair_size)
            - "state": (2^n_qubits,)
    """
    # _forward owns validation, caching and dispatch; we only bundle the args.
    forward_kwargs = {
        "params": params,
        "inputs": inputs,
        "pulse_params": pulse_params,
        "enc_params": enc_params,
        "noise_params": noise_params,
        "execution_type": execution_type,
        "force_mean": force_mean,
        "gate_mode": gate_mode,
    }
    return self._forward(**forward_kwargs)

__init__(n_qubits, n_layers, circuit_type='No_Ansatz', data_reupload=True, state_preparation=None, encoding=Gates.RX, trainable_frequencies=False, initialization='random', initialization_domain=[0, 2 * jnp.pi], output_qubit=-1, shots=None, random_seed=1000, remove_zero_encoding=True, use_multithreading=False, repeat_batch_axis=[True, True, True]) #

Initialize the quantum circuit model. Parameters will have the shape [impl_n_layers, parameters_per_layer], where impl_n_layers is the number of layers provided, increased by one when data re-uploading is used, and parameters_per_layer is given by the chosen ansatz.

The model is initialized with the following defaults: noise_params: None; execution_type: "expval"; shots: None.

Parameters:

Name Type Description Default
n_qubits int

The number of qubits in the circuit.

required
n_layers int

The number of layers in the circuit.

required
circuit_type (str, Circuit)

The type of quantum circuit to use. If None, defaults to "No_Ansatz".

'No_Ansatz'
data_reupload Union[bool, List[bool], List[List[bool]]]

Whether to reupload data to the quantum device on each layer and qubit. Detailed re-uploading instructions can be given as a list/array of 0/False and 1/True with shape (n_layers, n_qubits) or (n_layers, n_qubits, n_input_feat) to specify where to upload the data. Defaults to True for applying data re-uploading to the full circuit.

True
encoding Union[str, Callable, List[str], List[Callable]]

The unitary to use for encoding the input data. Can be a string (e.g. "RX") or a callable (e.g. qml.RX). Defaults to Gates.RX. If input is multidimensional it is assumed to be a list of unitaries or a list of strings.

RX
trainable_frequencies bool

Sets trainable encoding parameters for trainable frequencies. Defaults to False.

False
initialization str

The strategy to initialize the parameters. Can be "random", "zeros", "zero-controlled", "pi", or "pi-controlled". Defaults to "random".

'random'
output_qubit (List[int], int)

The index of the output qubit (or qubits). When set to -1 all qubits are measured, or a global measurement is conducted, depending on the execution type.

-1
shots Optional[int]

The number of shots to use for the quantum device. Defaults to None.

None
random_seed int

Seed for the random number generator, used when initialization is "random" and for random noise parameters. Defaults to 1000.

1000
remove_zero_encoding bool

whether to remove the zero encoding from the circuit. Defaults to True.

True
use_multithreading bool

whether to use JAX multithreading to parallelise over batch dimension.

False

Returns:

Type Description
None

None

Source code in qml_essentials/model.py
def __init__(
    self,
    n_qubits: int,
    n_layers: int,
    circuit_type: Union[str, Circuit] = "No_Ansatz",
    data_reupload: Union[bool, List[List[bool]], List[List[List[bool]]]] = True,
    state_preparation: Union[
        str, Callable, List[Union[str, Callable]], None
    ] = None,
    encoding: Union[Encoding, str, Callable, List[Union[str, Callable]]] = Gates.RX,
    trainable_frequencies: bool = False,
    initialization: str = "random",
    initialization_domain: List[float] = [0, 2 * jnp.pi],
    output_qubit: Union[List[int], int] = -1,
    shots: Optional[int] = None,
    random_seed: int = 1000,
    remove_zero_encoding: bool = True,
    use_multithreading: bool = False,
    repeat_batch_axis: List[bool] = [True, True, True],
) -> None:
    """
    Initialize the quantum circuit model.
    Parameters will have the shape [impl_n_layers, parameters_per_layer]
    where impl_n_layers is the number of layers provided, increased by one
    when data re-uploading is used, and parameters_per_layer is given by
    the chosen ansatz.

    The model is initialized with the following parameters as defaults:
    - noise_params: None
    - execution_type: "expval"
    - shots: None

    Args:
        n_qubits (int): The number of qubits in the circuit.
        n_layers (int): The number of layers in the circuit.
        circuit_type (str, Circuit): The type of quantum circuit to use:
            either the name of an ``Ansaetze`` member or a circuit class
            (which is instantiated here). If None or empty, defaults to
            "No_Ansatz".
        data_reupload (Union[bool, List[List[bool]], List[List[List[bool]]]], optional):
            Whether to reupload data to the quantum device on each
            layer and qubit. Detailed re-uploading instructions can be given
            as a list/array of 0/False and 1/True with shape
            (n_layers, n_qubits) or (n_layers, n_qubits, n_input_feat) to
            specify where to upload the data. A 2D mask is applied to every
            input feature. True applies data re-uploading to the full
            circuit; False uploads the data only once (first layer, first
            qubit). Defaults to True.
        state_preparation (Union[str, Callable, List[Union[str, Callable]], None], optional):
            Gate(s) used for state preparation, given by name or as
            callables and parsed with ``Gates.parse_gates``. Defaults to None.
        encoding (Union[Encoding, str, Callable, List[str], List[Callable]], optional):
            The unitary to use for encoding the input data. Can be a string
            (e.g. "RX") or a callable (e.g. qml.RX). If input is
            multidimensional it is assumed to be a list of unitaries or a
            list of strings. An ``Encoding`` instance is used as-is; anything
            else is wrapped in a "hamming" ``Encoding``. Defaults to Gates.RX.
        trainable_frequencies (bool, optional):
            Sets trainable encoding parameters for trainable frequencies.
            Defaults to False.
        initialization (str, optional): The strategy to initialize the parameters.
            Can be "random", "zeros", "zero-controlled", "pi", or "pi-controlled".
            Defaults to "random".
        initialization_domain (List[float], optional): [min, max] domain used
            for random initialization. Defaults to [0, 2*pi].
        output_qubit (List[int], int, optional): The index of the output
            qubit (or qubits). When set to -1 all qubits are measured, or a
            global measurement is conducted, depending on the execution
            type.
        shots (Optional[int], optional): The number of shots to use for
            the quantum device. Defaults to None.
        random_seed (int, optional): Seed for the random number generator,
            used when initialization is "random" and for random noise
            parameters. Defaults to 1000.
        remove_zero_encoding (bool, optional): whether to
            remove the zero encoding from the circuit. Defaults to True.
        use_multithreading (bool, optional): whether to use JAX
            multithreading to parallelise over batch dimension. It is forced
            to False when the "lightning.qubit" device is selected.
            Defaults to False.
        repeat_batch_axis (List[bool], optional): Mask over the batch axes
            that is applied when computing ``eff_batch_shape``; presumably
            ordered as (inputs, params, pulse_params) — TODO confirm.
            Defaults to [True, True, True].

    Returns:
        None

    Raises:
        ValueError: If state_preparation cannot be parsed.
        AssertionError: If data_reupload has an unexpected shape.
    """
    import copy

    # Initialize default parameters needed for circuit evaluation
    self.n_qubits: int = n_qubits
    self.output_qubit: Union[List[int], int] = output_qubit
    self.n_layers: int = n_layers
    self.noise_params: Optional[Dict[str, Union[float, Dict[str, float]]]] = None
    self.shots = shots
    self.remove_zero_encoding = remove_zero_encoding
    self.use_multithreading = use_multithreading
    self.trainable_frequencies: bool = trainable_frequencies
    self.execution_type: str = "expval"
    # Shallow copy: the default list is shared between calls, and mutating
    # model.repeat_batch_axis must not leak into later models.
    self.repeat_batch_axis: List[bool] = copy.copy(repeat_batch_axis)

    # --- State Preparation ---
    try:
        self._sp = Gates.parse_gates(state_preparation, Gates)
    except ValueError as e:
        raise ValueError(f"Error parsing state preparation: {e}") from e

    # prepare corresponding pulse parameters (always optimized pulses)
    self.sp_pulse_params = []
    for sp in self._sp:
        sp_name = sp.__name__ if hasattr(sp, "__name__") else str(sp)
        gate_info = pinfo.gate_by_name(sp_name)
        # None marks a gate without pulse parametrization
        self.sp_pulse_params.append(gate_info.params if gate_info is not None else None)

    # --- Encoding ---
    if isinstance(encoding, Encoding):
        # user wants custom strategy? do it!
        self._enc = encoding
    else:
        # use hamming encoding by default
        self._enc = Encoding("hamming", encoding)

    # Number of possible inputs
    self.n_input_feat = len(self._enc)
    log.debug(f"Number of input features: {self.n_input_feat}")

    # Trainable frequencies, default initialization as in arXiv:2309.03279v2
    self.enc_params = jnp.ones((self.n_qubits, self.n_input_feat))

    self._zero_inputs = False

    # --- Data-Reuploading ---
    # Normalise every supported form into a (n_layers, n_qubits, n_input_feat)
    # mask of where inputs are uploaded.
    if not isinstance(data_reupload, bool):
        if not isinstance(data_reupload, np.ndarray):
            data_reupload = np.array(data_reupload)

        if len(data_reupload.shape) == 2:
            assert data_reupload.shape == (n_layers, n_qubits), (
                "Data reuploading array has wrong shape. "
                f"Expected {(n_layers, n_qubits)} or "
                f"{(n_layers, n_qubits, self.n_input_feat)}, "
                f"got {data_reupload.shape}."
            )
            # broadcast the per-(layer, qubit) mask over every input feature
            data_reupload = data_reupload.reshape(*data_reupload.shape, 1)
            data_reupload = np.repeat(data_reupload, self.n_input_feat, axis=2)

        assert data_reupload.shape == (n_layers, n_qubits, self.n_input_feat), (
            "Data reuploading array has wrong shape. "
            f"Expected {(n_layers, n_qubits, self.n_input_feat)}, "
            f"got {data_reupload.shape}."
        )

        log.debug(f"Data reuploading array:\n{data_reupload}")
    elif data_reupload:
        data_reupload = np.ones((n_layers, n_qubits, self.n_input_feat))
        log.debug("Full data reuploading.")
    else:
        data_reupload = np.zeros((n_layers, n_qubits, self.n_input_feat))
        # upload all features exactly once: first layer, first qubit
        data_reupload[0][0] = 1
        log.debug("No data reuploading.")

    # convert to boolean values
    data_reupload = data_reupload.astype(bool)
    self.data_reupload = jnp.array(data_reupload)

    # per input feature: number of frequencies / spectrum given how often it
    # is uploaded
    self.degree: Tuple = tuple(
        self._enc.get_n_freqs(jnp.count_nonzero(self.data_reupload[..., i]))
        for i in range(self.n_input_feat)
    )

    self.frequencies: Tuple = tuple(
        self._enc.get_spectrum(jnp.count_nonzero(self.data_reupload[..., i]))
        for i in range(self.n_input_feat)
    )

    self.has_dru = jnp.max(jnp.array([jnp.max(f) for f in self.frequencies])) > 1

    # check for the highest degree among all input dimensions
    if self.has_dru:
        impl_n_layers: int = n_layers + 1  # we need L+1 according to Schuld et al.
    else:
        impl_n_layers = n_layers
    log.info(f"Number of implicit layers: {impl_n_layers}.")

    # --- Ansatz ---
    # only weak check for str. We trust the user to provide sth useful.
    # None / "" fall back to "No_Ansatz" as documented.
    if circuit_type is None or isinstance(circuit_type, str):
        self.pqc: Callable[[Optional[jnp.ndarray], int], int] = getattr(
            Ansaetze, circuit_type or "No_Ansatz"
        )()
    else:
        self.pqc = circuit_type()
    log.info(f"Using Ansatz {circuit_type}.")

    # calculate the shape of the parameter vector here, we will re-use this in init.
    params_per_layer = self.pqc.n_params_per_layer(self.n_qubits)
    self._params_shape: Tuple[int, int] = (impl_n_layers, params_per_layer)
    log.info(f"Parameters per layer: {params_per_layer}")

    pulse_params_per_layer = self.pqc.n_pulse_params_per_layer(self.n_qubits)
    self._pulse_params_shape: Tuple[int, int] = (
        impl_n_layers,
        pulse_params_per_layer,
    )

    # initialize to None as we can't know this yet
    self._batch_shape = None

    # this will also be re-used in the init method,
    # however, only if nothing is provided
    self._inialization_strategy = initialization
    # copied for the same shared-default reason as repeat_batch_axis
    self._initialization_domain = copy.copy(initialization_domain)

    # ..here! where we only require a JAX random key
    self.random_key = self.initialize_params(random.key(random_seed))

    # Initializing pulse params
    self.pulse_params: jnp.ndarray = jnp.ones((*self._pulse_params_shape, 1))

    log.info(f"Initialized pulse parameters with shape {self.pulse_params.shape}.")

    # Initialize two circuits, one with the default device and
    # one with the mixed device
    # which allows us to later route depending on the state_vector flag.
    # Large circuits switch to lightning.qubit, which disables multithreading.
    if self.n_qubits < self.lightning_threshold:
        device = "default.qubit"
    else:
        device = "lightning.qubit"
        self.use_multithreading = False
    self.circuit: qml.QNode = qml.QNode(
        self._circuit,
        qml.device(
            device,
            shots=self.shots,
            wires=self.n_qubits,
        ),
        interface="jax-jit",
        diff_method="parameter-shift" if self.shots is not None else "best",
    )

    self.circuit_mixed: qml.QNode = qml.QNode(
        self._circuit,
        qml.device("default.mixed", shots=self.shots, wires=self.n_qubits),
        interface="jax-jit",
        diff_method="parameter-shift" if self.shots is not None else "best",
    )

__repr__() #

Return text representation of the quantum circuit.

Source code in qml_essentials/model.py
def __repr__(self) -> str:
    """Return the circuit's ASCII diagram, as produced by ``self.draw(figure="text")``."""
    text_diagram = self.draw(figure="text")
    return text_diagram

__str__() #

Return string representation of the quantum circuit.

Source code in qml_essentials/model.py
def __str__(self) -> str:
    """Render the circuit with ``self.draw(figure="text")`` and return that string."""
    rendered = self.draw(figure="text")
    return rendered

draw(inputs=None, figure='text', *args, **kwargs) #

Visualize the quantum circuit.

Generates a visual representation of the circuit using the specified rendering method.

Parameters:

Name Type Description Default
inputs Optional[ndarray]

Input data for the circuit. If None, default zero inputs are used. Defaults to None.

None
figure str

Visualization format. Options: - "text": ASCII text representation - "mpl": Matplotlib figure - "tikz": TikZ/LaTeX code for publication-quality figures Defaults to "text".

'text'
*args Any

Additional positional arguments passed to the visualization backend.

()
**kwargs Any

Additional keyword arguments passed to the visualization backend. May include pulse_params, gate_mode, enc_params, or noise_params.

{}

Returns:

Type Description
Union[str, Any]

Union[str, Any]: Visualization output: - "text": String with ASCII circuit diagram - "mpl": Matplotlib Figure and Axes objects - "tikz": TikZ code string

Raises:

Type Description
AssertionError

If figure is not one of "text", "mpl", or "tikz".

Source code in qml_essentials/model.py
def draw(
    self,
    inputs: Optional[jnp.ndarray] = None,
    figure: str = "text",
    *args: Any,
    **kwargs: Any,
) -> Union[str, Any]:
    """
    Visualize the quantum circuit.

    Generates a visual representation of the circuit using the specified
    rendering method. The circuit is evaluated with the model's current
    ``params`` and the validated ``inputs``; ``*args``/``**kwargs`` are
    forwarded to the circuit call for every figure type.

    Args:
        inputs (Optional[jnp.ndarray]): Input data for the circuit. If None,
            default zero inputs are used. Defaults to None.
        figure (str): Visualization format. Options:
            - "text": ASCII text representation
            - "mpl": Matplotlib figure
            - "tikz": TikZ/LaTeX code for publication-quality figures
            Defaults to "text".
        *args (Any): Additional positional arguments passed to the
            visualization backend.
        **kwargs (Any): Additional keyword arguments passed to the
            visualization backend. May include pulse_params, gate_mode,
            enc_params, or noise_params.

    Returns:
        Union[str, Any]: Visualization output:
            - "text": String with ASCII circuit diagram
            - "mpl": Matplotlib Figure and Axes objects
            - "tikz": TikZ code string
            An empty string if ``self.circuit`` is not a ``qml.QNode``.

    Raises:
        AssertionError: If figure is not one of "text", "mpl", or "tikz".
    """

    if not isinstance(self.circuit, qml.QNode):
        # TODO: throws strange argument error if not caught
        return ""

    assert figure in (
        "text",
        "mpl",
        "tikz",
    ), f"Invalid figure: {figure}. Must be 'text', 'mpl' or 'tikz'."

    inputs = self._inputs_validation(inputs)

    if figure == "mpl":
        return qml.draw_mpl(self.circuit)(
            *args,
            params=self.params,
            inputs=inputs,
            **kwargs,
        )
    if figure == "tikz":
        return QuanTikz.build(
            self.circuit,
            *args,
            params=self.params,
            inputs=inputs,
            **kwargs,
        )
    # "text": forward args/kwargs like the other backends so options such as
    # gate_mode or pulse_params are not silently ignored.
    return qml.draw(self.circuit)(
        *args,
        params=self.params,
        inputs=inputs,
        **kwargs,
    )

initialize_params(random_key=None, repeat=1, initialization=None, initialization_domain=None) #

Initialize the variational parameters of the model.

Parameters:

Name Type Description Default
random_key Optional[PRNGKey]

JAX random key for initialization. If None, uses the model's internal random key.

None
repeat int

Number of parameter sets to create (batch dimension). Defaults to 1.

1
initialization Optional[str]

Strategy for parameter initialization. Options: "random", "zeros", "pi", "zero-controlled", "pi-controlled". If None, uses the strategy specified in the constructor.

None
initialization_domain Optional[List[float]]

Domain [min, max] for random initialization. If None, uses the domain from constructor.

None

Returns:

Type Description
PRNGKey

random.PRNGKey: Updated random key after initialization.

Raises:

Type Description
Exception

If an invalid initialization method is specified.

Source code in qml_essentials/model.py
def initialize_params(
    self,
    random_key: Optional[random.PRNGKey] = None,
    repeat: int = 1,
    initialization: Optional[str] = None,
    initialization_domain: Optional[List[float]] = None,
) -> random.PRNGKey:
    """
    Initialize the variational parameters of the model.

    Builds a parameter tensor of shape ``(*self._params_shape, repeat)``,
    stores it in ``self.params`` and returns the advanced random key.

    Args:
        random_key (Optional[random.PRNGKey]): JAX random key for initialization.
            If None, uses the model's internal random key.
        repeat (int): Number of parameter sets to create (batch dimension).
            Defaults to 1.
        initialization (Optional[str]): Strategy for parameter initialization.
            Options: "random", "zeros", "pi", "zero-controlled", "pi-controlled".
            The "-controlled" variants draw random values and then set the
            parameters at the ansatz's control indices to 0 or pi; if the
            ansatz reports no control indices, a UserWarning is emitted and
            the random values are kept.
            If None, uses the strategy specified in the constructor.
        initialization_domain (Optional[List[float]]): Domain [min, max] for
            random initialization. If None, uses the domain from constructor.

    Returns:
        random.PRNGKey: Updated random key after initialization.

    Raises:
        ValueError: If an invalid initialization method is specified
            (a subclass of Exception, so ``except Exception`` still applies).
    """
    # Initializing params
    params_shape = (*self._params_shape, repeat)

    # use existing strategy if not specified
    initialization = initialization or self._inialization_strategy
    initialization_domain = initialization_domain or self._initialization_domain

    # keep one half of the split for the caller, consume the other here
    random_key, sub_key = safe_random_split(
        random_key if random_key is not None else self.random_key
    )

    def uniform_params() -> jnp.ndarray:
        """Sample params uniformly from the configured [min, max) domain."""
        return random.uniform(
            sub_key,
            params_shape,
            minval=initialization_domain[0],
            maxval=initialization_domain[1],
        )

    def set_control_params(params: jnp.ndarray, value: float) -> jnp.ndarray:
        """Overwrite the controlled-rotation parameters (axis 1) with ``value``."""
        indices = self.pqc.get_control_indices(self.n_qubits)
        if indices is None:
            warnings.warn(
                f"Specified {initialization} but circuit does not contain "
                "controlled rotation gates. Parameters are initialized randomly.",
                UserWarning,
            )
            return params
        # JAX arrays are immutable: edit a NumPy copy, then convert back
        np_params = np.array(params)
        np_params[:, indices[0] : indices[1] : indices[2]] = value
        return jnp.array(np_params)

    if initialization == "random":
        self.params: jnp.ndarray = uniform_params()
    elif initialization == "zeros":
        self.params: jnp.ndarray = jnp.zeros(params_shape)
    elif initialization == "pi":
        self.params: jnp.ndarray = jnp.ones(params_shape) * jnp.pi
    elif initialization == "zero-controlled":
        self.params: jnp.ndarray = uniform_params()
        self.params = set_control_params(self.params, 0)
    elif initialization == "pi-controlled":
        self.params: jnp.ndarray = uniform_params()
        self.params = set_control_params(self.params, jnp.pi)
    else:
        raise ValueError(
            f"Invalid initialization method: {initialization!r}. Must be one of "
            "'random', 'zeros', 'pi', 'zero-controlled', 'pi-controlled'."
        )

    log.info(
        f"Initialized parameters with shape {self.params.shape} "
        f"using strategy {initialization}."
    )

    return random_key

transform_input(inputs, enc_params) #

Transform input data by scaling with encoding parameters.

Implements the input transformation as described in arXiv:2309.03279v2, where inputs are linearly scaled by encoding parameters before being used in the quantum circuit.

Parameters:

Name Type Description Default
inputs ndarray

Input data point of shape (n_input_feat,) or (batch_size, n_input_feat).

required
enc_params ndarray

Encoding weight scalar or vector used to scale the input.

required

Returns:

Type Description
ndarray

jnp.ndarray: Transformed input, element-wise product of inputs and enc_params.

Source code in qml_essentials/model.py
def transform_input(
    self, inputs: jnp.ndarray, enc_params: jnp.ndarray
) -> jnp.ndarray:
    """
    Scale the inputs by the encoding parameters (arXiv:2309.03279v2).

    The inputs are multiplied element-wise (with the usual broadcasting of
    ``*``) by the encoding parameters before they enter the circuit.

    Args:
        inputs (jnp.ndarray): Input data point of shape (n_input_feat,) or
            (batch_size, n_input_feat).
        enc_params (jnp.ndarray): Encoding weight scalar or vector used to
            scale the input.

    Returns:
        jnp.ndarray: ``inputs * enc_params``.
    """
    scaled_inputs = inputs * enc_params
    return scaled_inputs

Entanglement#

from qml_essentials.entanglement import Entanglement
Source code in qml_essentials/entanglement.py
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
class Entanglement:
    """
    Entanglement measures for parameterised quantum circuits.

    Every public method is a ``staticmethod`` taking a ``Model``. Each either
    draws ``n_samples`` fresh parameter sets (``n_samples > 0``, which
    requires a ``seed``) or evaluates the model's current parameters
    (``n_samples`` is ``None`` or ``<= 0``).
    """

    @staticmethod
    def meyer_wallach(
        model: Model,
        n_samples: Optional[int],
        seed: Optional[int],
        scale: bool = False,
        **kwargs: Any,
    ) -> float:
        """
        Calculates the entangling capacity of a given quantum circuit
        using the Meyer-Wallach measure.

        Args:
            model (Model): The quantum circuit model.
            n_samples (Optional[int]): Number of parameter samples to draw.
                If None or <= 0, the current parameters of the model are used.
            seed (Optional[int]): Seed for the random number generator.
                Required when samples are drawn, ignored otherwise.
            scale (bool): Whether to multiply the number of samples by
                2**n_qubits.
            kwargs (Any): Additional keyword arguments for the model function.

        Returns:
            float: Mean entangling capacity over all samples, between 0.0
                and 1.0 (up to numerical error).
        """
        if "noise_params" in kwargs:
            log.warning(
                "Meyer-Wallach measure not suitable for noisy circuits. "
                "Consider 'relative_entropy' instead."
            )

        if scale and n_samples is not None:
            n_samples = (2**model.n_qubits) * n_samples

        if n_samples is not None and n_samples > 0:
            assert seed is not None, "Seed must be provided when samples > 0"
            # Build the key only when sampling: jax.random.key needs an
            # integer seed, and seed may be None when no sampling happens.
            model.initialize_params(jax.random.key(seed), repeat=n_samples)
        elif seed is not None:
            log.warning("Seed is ignored when samples is 0")

        # implicitly set input to none in case it's not needed
        kwargs.setdefault("inputs", None)
        # explicitly set execution type because everything else won't work
        dim = 2**model.n_qubits
        rhos = model(execution_type="density", **kwargs).reshape(-1, dim, dim)

        ent = Entanglement._compute_meyer_wallach_meas(
            rhos, model.n_qubits, model.use_multithreading
        )

        log.debug(f"Variance of measure: {ent.var()}")

        return ent.mean()

    @staticmethod
    def _compute_meyer_wallach_meas(
        rhos: jnp.ndarray, n_qubits: int, use_multithreading: bool = False
    ) -> jnp.ndarray:
        """
        Computes the Meyer-Wallach entangling capability measure for a given
        set of density matrices.

        Args:
            rhos (jnp.ndarray): Density matrices of the sample quantum states.
                The shape is (B_s, 2^n, 2^n), where B_s is the number of samples
                (batch) and n the number of qubits
            n_qubits (int): The number of qubits
            use_multithreading (bool): Whether to use JAX vectorisation.

        Returns:
            jnp.ndarray: Entangling capability for each sample, array with
                shape (B_s,)
        """
        qb = list(range(n_qubits))

        def _f(rhos):
            purity = 0
            for j in range(n_qubits):
                # Formula 6 in https://doi.org/10.48550/arXiv.quant-ph/0305094:
                # reduced density matrix of qubit j (all other qubits traced out)
                density = qml.math.partial_trace(rhos, qb[:j] + qb[j + 1 :])
                # only the real part is kept: the purity Tr(rho_j^2) is real,
                # and for a single qubit 1/2 <= Tr(rho_j^2) <= 1
                purity += jnp.trace((density @ density).real, axis1=-2, axis2=-1)

            # inverse averaged purity, scaled to [0, 1]
            return 2 * (1 - purity / n_qubits)

        if use_multithreading:
            return jax.vmap(_f)(rhos)
        return _f(rhos)

    @staticmethod
    def bell_measurements(
        model: Model, n_samples: int, seed: int, scale: bool = False, **kwargs: Any
    ) -> float:
        """
        Compute the Bell measurement for a given model.

        Two copies of the circuit state are prepared on wires [0, n) and
        [n, 2n), and each pair (q, q + n) is measured in the Bell basis.
        ``model.circuit`` and ``model.output_qubit`` are replaced temporarily
        and restored afterwards, also if an exception is raised.

        Args:
            model (Model): The quantum circuit model.
            n_samples (int): The number of samples to compute the measure for.
                If None or <= 0, the current parameters of the model are used.
            seed (int): The seed for the random number generator.
            scale (bool): Whether to scale the number of samples
                according to the number of qubits.
            **kwargs (Any): Additional keyword arguments for the model function.

        Returns:
            float: The Bell measurement value, clipped to [0.0, 1.0].
        """
        if "noise_params" in kwargs:
            log.warning(
                "Bell Measurements not suitable for noisy circuits. "
                "Consider 'relative_entropy' instead."
            )

        if scale and n_samples is not None:
            n_samples = (2**model.n_qubits) * n_samples

        n_qubits = model.n_qubits

        def _circuit(
            params: jnp.ndarray, inputs: jnp.ndarray, **circuit_kwargs
        ) -> List[jnp.ndarray]:
            """
            Compute the Bell measurement circuit.

            Args:
                params (jnp.ndarray): The model parameters.
                inputs (jnp.ndarray): The input to the model.
                circuit_kwargs: Keyword arguments passed by the model (e.g.
                    pulse / encoding parameters), forwarded to both copies.

            Returns:
                List[jnp.ndarray]: The probabilities of the Bell measurement.
            """
            model._variational(params, inputs, **circuit_kwargs)

            # second copy on the auxiliary wires; it must receive the same
            # keyword arguments so that both copies prepare the same state
            qml.map_wires(
                model._variational,
                {i: i + n_qubits for i in range(n_qubits)},
            )(params, inputs, **circuit_kwargs)

            # rotate every pair (q, q + n) into the Bell basis
            for q in range(n_qubits):
                qml.CNOT(wires=[q, q + n_qubits])
                qml.H(q)

            # look at the auxiliary qubits
            return model._observable()

        prev_output_qubit = model.output_qubit
        prev_circuit = model.circuit
        model.output_qubit = [(q, q + n_qubits) for q in range(n_qubits)]
        model.circuit = qml.QNode(
            _circuit,
            qml.device(
                "default.qubit",
                shots=model.shots,
                wires=n_qubits * 2,
            ),
        )

        try:
            if n_samples is not None and n_samples > 0:
                assert seed is not None, "Seed must be provided when samples > 0"
                model.initialize_params(jax.random.key(seed), repeat=n_samples)
                params = model.params
            else:
                if seed is not None:
                    log.warning("Seed is ignored when samples is 0")

                if len(model.params.shape) <= 2:
                    # add a trailing sample axis of size 1
                    params = model.params.reshape(*model.params.shape, 1)
                else:
                    log.info(
                        f"Using sample size of model params: {model.params.shape[-1]}"
                    )
                    params = model.params

            # implicitly set input to none in case it's not needed
            kwargs.setdefault("inputs", None)
            exp = model(params=params, execution_type="probs", **kwargs)
            # last probability of every measured pair, i.e. the |11> outcome
            exp = 1 - 2 * exp[..., -1]

            if not jnp.isclose(jnp.sum(exp.imag), 0, atol=1e-6):
                log.warning("Imaginary part of probabilities detected")
                exp = jnp.abs(exp)

            measure = 2 * (1 - exp.mean(axis=0))
            entangling_capability = min(max(measure.mean(), 0.0), 1.0)
            log.debug(f"Variance of measure: {measure.var()}")
        finally:
            # restore state, also when the evaluation fails
            model.output_qubit = prev_output_qubit
            model.circuit = prev_circuit

        return float(entangling_capability)

    @staticmethod
    def relative_entropy(
        model: Model,
        n_samples: int,
        n_sigmas: int,
        seed: Optional[int],
        scale: bool = False,
        **kwargs: Any,
    ) -> float:
        """
        Calculates the relative entropy of entanglement of a given quantum
        circuit. This measure is also applicable to mixed states, although it
        might not be fully accurate in this simplified case.

        The relative entropy of entanglement is defined as the smallest
        relative entropy from the state in question to the set of separable
        states. However, as computing the nearest separable state is NP-hard,
        we select n_sigmas random separable states to compute the distance
        to, which are not necessarily the nearest. Thus, this measure of
        entanglement presents an upper limit of entanglement.

        As the relative entropy is not necessarily between zero and one, this
        function also normalises by the relative entropy of the GHZ state.

        Args:
            model (Model): The quantum circuit model.
            n_samples (int): Number of parameter samples to draw.
                If None or <= 0, the current parameters of the model are used.
            n_sigmas (int): Number of random separable pure states to compare against.
            seed (Optional[int]): Seed for the random number generator. It is
                always used to sample the separable states.
            scale (bool): Whether to scale n_samples and n_sigmas by 2**n_qubits.
            kwargs (Any): Additional keyword arguments for the model function.

        Returns:
            float: Mean of the relative entropy of entanglement, normalised by
                the corresponding value of the GHZ state.
        """
        dim = 2**model.n_qubits
        if scale:
            if n_samples is not None:
                n_samples = dim * n_samples
            n_sigmas = dim * n_sigmas

        random_key = jax.random.key(seed)

        # Random separable states (as matrix logarithms)
        log_sigmas = sample_random_separable_states(
            model.n_qubits, n_samples=n_sigmas, random_key=random_key, take_log=True
        )

        random_key, _ = jax.random.split(random_key)

        if n_samples is not None and n_samples > 0:
            assert seed is not None, "Seed must be provided when samples > 0"
            model.initialize_params(random_key, repeat=n_samples)
        else:
            # no "seed is ignored" warning here: the seed still drives the
            # sampling of the separable states above
            if len(model.params.shape) <= 2:
                model.params = model.params.reshape(*model.params.shape, 1)
            else:
                log.info(f"Using sample size of model params: {model.params.shape[-1]}")

        rhos, log_rhos = Entanglement._compute_log_density(model, **kwargs)

        rel_entropies = jnp.zeros((n_sigmas, model.params.shape[-1]))

        for i, log_sigma in enumerate(log_sigmas):
            rel_entropies = rel_entropies.at[i].set(
                Entanglement._compute_rel_entropies(
                    rhos, log_rhos, log_sigma, model.use_multithreading
                )
            )

        # Entropy of GHZ states should be maximal
        ghz_model = Model(model.n_qubits, 1, "GHZ", data_reupload=False)
        rho_ghz, log_rho_ghz = Entanglement._compute_log_density(ghz_model, **kwargs)
        ghz_entropies = Entanglement._compute_rel_entropies(
            rho_ghz, log_rho_ghz, log_sigmas, use_multithreading=False
        )

        normalised_entropies = rel_entropies / ghz_entropies

        # Per sample, keep the closest separable state (minimum over sigmas)
        entangling_capability = normalised_entropies.T.min(axis=1)
        log.debug(f"Variance of measure: {entangling_capability.var()}")

        return entangling_capability.mean()

    @staticmethod
    def _compute_log_density(model: Model, **kwargs) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """
        Obtains the density matrix of a model and computes its logarithm.

        Args:
            model (Model): The model for which to compute the density matrix.
            kwargs: Additional keyword arguments for the model function.

        Returns:
            Tuple[jnp.ndarray, jnp.ndarray]:
                - jnp.ndarray: density matrices, shape (R, 2^n, 2^n).
                - jnp.ndarray: base-2 logarithm of the density matrices.
        """
        # implicitly set input to none in case it's not needed
        kwargs.setdefault("inputs", None)
        # explicitly set execution type because everything else won't work
        rho = model(execution_type="density", **kwargs)
        rho = rho.reshape(-1, 2**model.n_qubits, 2**model.n_qubits)
        # natural matrix log converted to base 2
        log_rho = logm_v(rho) / jnp.log(2)
        return rho, log_rho

    @staticmethod
    def _compute_rel_entropies(
        rhos: jnp.ndarray,
        log_rhos: jnp.ndarray,
        log_sigmas: jnp.ndarray,
        use_multithreading: bool,
    ) -> jnp.ndarray:
        """
        Compute the relative entropies |Tr[rho (log rho - log sigma)]|.

        Args:
            rhos (jnp.ndarray): Density matrix result of the circuit, has shape
                (R, 2^n, 2^n), with the batch size R and number of qubits n
            log_rhos (jnp.ndarray): Corresponding logarithm of the density
                matrix, has shape (R, 2^n, 2^n).
            log_sigmas (jnp.ndarray): Logarithm of the separable state(s),
                has shape (2^n, 2^n) if it's a single sigma or (S, 2^n, 2^n),
                with the batch size S (number of sigmas).
            use_multithreading (bool): Whether to use JAX vectorisation.

        Returns:
            jnp.ndarray: Relative entropy for each sample, shape (R,) for a
                single sigma or (S, R) for S > 1 sigmas.
        """
        n_rhos = rhos.shape[0]
        if len(log_sigmas.shape) == 3:
            n_sigmas = log_sigmas.shape[0]
            # Build all (sigma, rho) pairs in sigma-major order so the result
            # reshapes to (S, R): rhos are tiled (r_0..r_{R-1}, r_0..r_{R-1},
            # ...) and every sigma is repeated R times consecutively.
            rhos = jnp.tile(rhos, (n_sigmas, 1, 1))
            log_rhos = jnp.tile(log_rhos, (n_sigmas, 1, 1))
            log_sigmas = jnp.repeat(log_sigmas, n_rhos, axis=0)
        else:
            n_sigmas = 1
            log_sigmas = log_sigmas[jnp.newaxis, ...].repeat(n_rhos, axis=0)

        # under vmap `_f` sees single matrices, otherwise the whole batch
        einsum_subscript = "ij,jk->ik" if use_multithreading else "sij,sjk->sik"

        def _f(rhos, log_rhos, log_sigmas):
            prod = jnp.einsum(einsum_subscript, rhos, log_rhos - log_sigmas)
            return jnp.abs(jnp.trace(prod, axis1=-2, axis2=-1))

        if use_multithreading:
            rel_entropies = jax.vmap(_f, in_axes=(0, 0, 0))(rhos, log_rhos, log_sigmas)
        else:
            rel_entropies = _f(rhos, log_rhos, log_sigmas)

        if n_sigmas > 1:
            rel_entropies = rel_entropies.reshape(n_sigmas, n_rhos)
        return rel_entropies

    @staticmethod
    def entanglement_of_formation(
        model: Model,
        n_samples: int,
        seed: Optional[int],
        scale: bool = False,
        always_decompose: bool = False,
        **kwargs: Any,
    ) -> float:
        """
        This function implements the entanglement of formation for mixed
        quantum systems: a mixed state is decomposed into pure states with
        respective probabilities using the eigendecomposition of the density
        matrix. Then, the Meyer-Wallach measure is computed for each pure
        state, weighted by its eigenvalue.
        See e.g. https://doi.org/10.48550/arXiv.quant-ph/0504163

        Note that the decomposition is *not unique*! Therefore, this measure
        presents the entanglement for *some* decomposition into pure states,
        not necessarily the one that is anticipated when applying the Kraus
        channels.
        If a pure state is provided, this results in the same value as the
        Entanglement.meyer_wallach function if `always_decompose` flag is not set.

        Args:
            model (Model): The quantum circuit model.
            n_samples (int): Number of parameter samples to draw.
                If None or <= 0, the current parameters of the model are used.
            seed (Optional[int]): Seed for the random number generator.
            scale (bool): Whether to scale the number of samples.
            always_decompose (bool): Whether to explicitly compute the
                entanglement of formation for the eigendecomposition of a pure
                state.
            kwargs (Any): Additional keyword arguments for the model function.

        Returns:
            float: Entangling capacity of the given circuit, between 0.0 and
                1.0 (up to numerical error).
        """
        if scale and n_samples is not None:
            n_samples = (2**model.n_qubits) * n_samples

        if n_samples is not None and n_samples > 0:
            assert seed is not None, "Seed must be provided when samples > 0"
            model.initialize_params(jax.random.key(seed), repeat=n_samples)
        else:
            if seed is not None:
                log.warning("Seed is ignored when samples is 0")

            if len(model.params.shape) <= 2:
                model.params = model.params.reshape(*model.params.shape, 1)
            else:
                log.info(f"Using sample size of model params: {model.params.shape[-1]}")

        # implicitly set input to none in case it's not needed
        kwargs.setdefault("inputs", None)
        dim = 2**model.n_qubits
        rhos = model(execution_type="density", **kwargs).reshape(-1, dim, dim)
        ent = Entanglement._compute_entanglement_of_formation(
            rhos, model.n_qubits, always_decompose, model.use_multithreading
        )
        return ent.mean()

    @staticmethod
    def _compute_entanglement_of_formation(
        rhos: jnp.ndarray,
        n_qubits: int,
        always_decompose: bool,
        use_multithreading: bool,
    ) -> jnp.ndarray:
        """
        Computes the entanglement of formation for a given batch of density
        matrices.

        Args:
            rhos (jnp.ndarray): The density matrices, has shape (B_s, 2^n, 2^n),
                where B_s is the batch size and n the number of qubits.
            n_qubits (int): Number of qubits
            always_decompose (bool): Whether to explicitly compute the
                entanglement of formation for the eigendecomposition of a pure
                state.
            use_multithreading (bool): Whether to use JAX vectorisation.

        Returns:
            jnp.ndarray: Entanglement for the provided density matrices,
                shape (B_s,).
        """
        dim = 2**n_qubits
        # eigh returns the eigenvectors as the *columns*: eigenvectors[s, :, i]
        eigenvalues, eigenvectors = jnp.linalg.eigh(rhos)
        if not always_decompose and jnp.isclose(eigenvalues, 1.0).any(axis=-1).all():
            # every state is pure, the decomposition is trivial
            return Entanglement._compute_meyer_wallach_meas(
                rhos, n_qubits, use_multithreading
            )

        # projector |v_i><v_i| per eigenvector, indexed [sample, i, row, col]
        pure_states = jnp.einsum("sji,ski->sijk", eigenvectors, eigenvectors.conj())
        measures = Entanglement._compute_meyer_wallach_meas(
            pure_states.reshape(-1, dim, dim), n_qubits, use_multithreading
        )
        # weight each pure-state measure by its eigenvalue (probability)
        return jnp.einsum("si,si->s", measures.reshape(-1, dim), eigenvalues)

    @staticmethod
    def concentratable_entanglement(
        model: Model, n_samples: int, seed: int, scale: bool = False, **kwargs: Any
    ) -> float:
        """
        Computes the concentratable entanglement of a given model.

        This method utilizes the Concentratable Entanglement measure from
        https://arxiv.org/abs/2104.06923.

        Args:
            model (Model): The quantum circuit model.
            n_samples (int): The number of samples to compute the measure for.
                If None or <= 0, the current parameters of the model are used.
            seed (int): The seed for the random number generator.
            scale (bool): Whether to scale the number of samples according to
                the number of qubits.
            **kwargs (Any): Additional keyword arguments for the model function.

        Returns:
            float: Entangling capability of the given circuit, between 0.0
                and 1.0 (up to numerical error).
        """
        n = model.n_qubits
        N = 2**n

        if scale and n_samples is not None:
            n_samples = N * n_samples

        dev = qml.device(
            "default.mixed",
            shots=model.shots,
            wires=n * 3,
        )

        @qml.qnode(device=dev)
        def _swap_test(
            params: jnp.ndarray, inputs: jnp.ndarray, **circuit_kwargs
        ) -> jnp.ndarray:
            """
            Swap-test circuit on three registers: control qubits on wires
            [0, n) and two copies of the circuit state on [n, 2n) and [2n, 3n).

            Args:
                params (jnp.ndarray): The model parameters.
                inputs (jnp.ndarray): The (validated) model inputs.
                circuit_kwargs: Forwarded to ``model._variational`` for both copies.

            Returns:
                jnp.ndarray: Probabilities of the control register.
            """
            qml.map_wires(model._variational, wire_map={o: o + n for o in range(n)})(
                params, inputs, **circuit_kwargs
            )
            qml.map_wires(
                model._variational, wire_map={o: o + 2 * n for o in range(n)}
            )(params, inputs, **circuit_kwargs)

            # Perform swap test
            for i in range(n):
                qml.H(i)

            for i in range(n):
                qml.CSWAP([i, i + n, i + 2 * n])

            for i in range(n):
                qml.H(i)

            return qml.probs(wires=[i for i in range(n)])

        if n_samples is not None and n_samples > 0:
            assert seed is not None, "Seed must be provided when samples > 0"
            model.initialize_params(jax.random.key(seed), repeat=n_samples)
        else:
            if seed is not None:
                log.warning("Seed is ignored when samples is 0")

            if len(model.params.shape) <= 2:
                model.params = model.params.reshape(*model.params.shape, 1)
            else:
                log.info(f"Using sample size of model params: {model.params.shape[-1]}")

        def _f(params):
            probs = _swap_test(params, model._inputs_validation(None), **kwargs)
            # CE = 1 - probability of the all-zero outcome on the control qubits
            return 1 - probs[..., 0]

        if model.use_multithreading:
            ent = jax.vmap(_f, in_axes=2)(model.params)
        else:
            ent = _f(model.params)

        log.debug(f"Variance of measure: {ent.var()}")

        return ent.mean()

bell_measurements(model, n_samples, seed, scale=False, **kwargs) staticmethod #

Compute the Bell measurement for a given model.

Parameters:

Name Type Description Default
model Model

The quantum circuit model.

required
n_samples int

The number of samples to compute the measure for.

required
seed int

The seed for the random number generator.

required
scale bool

Whether to scale the number of samples according to the number of qubits.

False
**kwargs Any

Additional keyword arguments for the model function.

{}

Returns:

Name Type Description
float float

The Bell measurement value.

Source code in qml_essentials/entanglement.py
@staticmethod
def bell_measurements(
    model: Model, n_samples: int, seed: int, scale: bool = False, **kwargs: Any
) -> float:
    """
    Compute the Bell measurement for a given model.

    Two copies of the circuit state are prepared on wires [0, n) and
    [n, 2n), and each pair (q, q + n) is measured in the Bell basis.
    ``model.circuit`` and ``model.output_qubit`` are replaced temporarily
    and restored afterwards, also if an exception is raised.

    Args:
        model (Model): The quantum circuit model.
        n_samples (int): The number of samples to compute the measure for.
            If None or <= 0, the current parameters of the model are used.
        seed (int): The seed for the random number generator.
        scale (bool): Whether to scale the number of samples
            according to the number of qubits.
        **kwargs (Any): Additional keyword arguments for the model function.

    Returns:
        float: The Bell measurement value, clipped to [0.0, 1.0].
    """
    if "noise_params" in kwargs:
        log.warning(
            "Bell Measurements not suitable for noisy circuits. "
            "Consider 'relative_entropy' instead."
        )

    if scale and n_samples is not None:
        n_samples = (2**model.n_qubits) * n_samples

    n_qubits = model.n_qubits

    def _circuit(
        params: jnp.ndarray, inputs: jnp.ndarray, **circuit_kwargs
    ) -> List[jnp.ndarray]:
        """
        Compute the Bell measurement circuit.

        Args:
            params (jnp.ndarray): The model parameters.
            inputs (jnp.ndarray): The input to the model.
            circuit_kwargs: Keyword arguments passed by the model (e.g.
                pulse / encoding parameters), forwarded to both copies.

        Returns:
            List[jnp.ndarray]: The probabilities of the Bell measurement.
        """
        model._variational(params, inputs, **circuit_kwargs)

        # second copy on the auxiliary wires; it must receive the same
        # keyword arguments so that both copies prepare the same state
        qml.map_wires(
            model._variational,
            {i: i + n_qubits for i in range(n_qubits)},
        )(params, inputs, **circuit_kwargs)

        # rotate every pair (q, q + n) into the Bell basis
        for q in range(n_qubits):
            qml.CNOT(wires=[q, q + n_qubits])
            qml.H(q)

        # look at the auxiliary qubits
        return model._observable()

    prev_output_qubit = model.output_qubit
    prev_circuit = model.circuit
    model.output_qubit = [(q, q + n_qubits) for q in range(n_qubits)]
    model.circuit = qml.QNode(
        _circuit,
        qml.device(
            "default.qubit",
            shots=model.shots,
            wires=n_qubits * 2,
        ),
    )

    try:
        if n_samples is not None and n_samples > 0:
            assert seed is not None, "Seed must be provided when samples > 0"
            model.initialize_params(jax.random.key(seed), repeat=n_samples)
            params = model.params
        else:
            if seed is not None:
                log.warning("Seed is ignored when samples is 0")

            if len(model.params.shape) <= 2:
                # add a trailing sample axis of size 1
                params = model.params.reshape(*model.params.shape, 1)
            else:
                log.info(
                    f"Using sample size of model params: {model.params.shape[-1]}"
                )
                params = model.params

        # implicitly set input to none in case it's not needed
        kwargs.setdefault("inputs", None)
        exp = model(params=params, execution_type="probs", **kwargs)
        # last probability of every measured pair, i.e. the |11> outcome
        exp = 1 - 2 * exp[..., -1]

        if not jnp.isclose(jnp.sum(exp.imag), 0, atol=1e-6):
            log.warning("Imaginary part of probabilities detected")
            exp = jnp.abs(exp)

        measure = 2 * (1 - exp.mean(axis=0))
        entangling_capability = min(max(measure.mean(), 0.0), 1.0)
        log.debug(f"Variance of measure: {measure.var()}")
    finally:
        # restore state, also when the evaluation fails
        model.output_qubit = prev_output_qubit
        model.circuit = prev_circuit

    return float(entangling_capability)

concentratable_entanglement(model, n_samples, seed, scale=False, **kwargs) staticmethod #

Computes the concentratable entanglement of a given model.

This method utilizes the Concentratable Entanglement measure from https://arxiv.org/abs/2104.06923.

Parameters:

Name Type Description Default
model Model

The quantum circuit model.

required
n_samples int

The number of samples to compute the measure for.

required
seed int

The seed for the random number generator.

required
scale bool

Whether to scale the number of samples according to the number of qubits.

False
**kwargs Any

Additional keyword arguments for the model function.

{}

Returns:

Name Type Description
float float

Entangling capability of the given circuit, guaranteed to be between 0.0 and 1.0.

Source code in qml_essentials/entanglement.py
@staticmethod
def concentratable_entanglement(
    model: Model, n_samples: int, seed: int, scale: bool = False, **kwargs: Any
) -> float:
    """
    Computes the concentratable entanglement of a given model.

    This method utilizes the Concentratable Entanglement measure from
    https://arxiv.org/abs/2104.06923.

    Args:
        model (Model): The quantum circuit model.
        n_samples (int): The number of samples to compute the measure for.
            If None or <= 0, the current parameters of the model are used.
        seed (int): The seed for the random number generator.
        scale (bool): Whether to scale the number of samples according to
            the number of qubits.
        **kwargs (Any): Additional keyword arguments for the model function.

    Returns:
        float: Entangling capability of the given circuit, between 0.0
            and 1.0 (up to numerical error).
    """
    n = model.n_qubits
    N = 2**n

    if scale and n_samples is not None:
        n_samples = N * n_samples

    dev = qml.device(
        "default.mixed",
        shots=model.shots,
        wires=n * 3,
    )

    @qml.qnode(device=dev)
    def _swap_test(
        params: jnp.ndarray, inputs: jnp.ndarray, **circuit_kwargs
    ) -> jnp.ndarray:
        """
        Swap-test circuit on three registers: control qubits on wires
        [0, n) and two copies of the circuit state on [n, 2n) and [2n, 3n).

        Args:
            params (jnp.ndarray): The model parameters.
            inputs (jnp.ndarray): The (validated) model inputs.
            circuit_kwargs: Forwarded to ``model._variational`` for both copies.

        Returns:
            jnp.ndarray: Probabilities of the control register.
        """
        qml.map_wires(model._variational, wire_map={o: o + n for o in range(n)})(
            params, inputs, **circuit_kwargs
        )
        qml.map_wires(
            model._variational, wire_map={o: o + 2 * n for o in range(n)}
        )(params, inputs, **circuit_kwargs)

        # Perform swap test
        for i in range(n):
            qml.H(i)

        for i in range(n):
            qml.CSWAP([i, i + n, i + 2 * n])

        for i in range(n):
            qml.H(i)

        return qml.probs(wires=[i for i in range(n)])

    if n_samples is not None and n_samples > 0:
        assert seed is not None, "Seed must be provided when samples > 0"
        model.initialize_params(jax.random.key(seed), repeat=n_samples)
    else:
        if seed is not None:
            log.warning("Seed is ignored when samples is 0")

        if len(model.params.shape) <= 2:
            model.params = model.params.reshape(*model.params.shape, 1)
        else:
            log.info(f"Using sample size of model params: {model.params.shape[-1]}")

    def _f(params):
        probs = _swap_test(params, model._inputs_validation(None), **kwargs)
        # CE = 1 - probability of the all-zero outcome on the control qubits
        return 1 - probs[..., 0]

    if model.use_multithreading:
        ent = jax.vmap(_f, in_axes=2)(model.params)
    else:
        ent = _f(model.params)

    log.debug(f"Variance of measure: {ent.var()}")

    return ent.mean()

entanglement_of_formation(model, n_samples, seed, scale=False, always_decompose=False, **kwargs) staticmethod #

This function implements the entanglement of formation for mixed quantum systems: a mixed state is decomposed into pure states with respective probabilities using the eigendecomposition of the density matrix. Then, the Meyer-Wallach measure is computed for each pure state, weighted by its eigenvalue. See e.g. https://doi.org/10.48550/arXiv.quant-ph/0504163

Note that the decomposition is not unique! Therefore, this measure presents the entanglement for some decomposition into pure states, not necessarily the one that is anticipated when applying the Kraus channels. If a pure state is provided, this results in the same value as the Entanglement.meyer_wallach function if always_decompose flag is not set.

Parameters:

Name Type Description Default
model Model

The quantum circuit model.

required
n_samples int

Number of parameter samples to draw. If None or <= 0, the current parameters of the model are used.

required
seed Optional[int]

Seed for the random number generator.

required
scale bool

Whether to scale the number of samples.

False
always_decompose bool

Whether to explicitly compute the entanglement of formation for the eigendecomposition of a pure state.

False
kwargs Any

Additional keyword arguments for the model function.

{}

Returns:

Name Type Description
float float

Entangling capacity of the given circuit, guaranteed to be between 0.0 and 1.0.

Source code in qml_essentials/entanglement.py
@staticmethod
def entanglement_of_formation(
    model: Model,
    n_samples: int,
    seed: Optional[int],
    scale: bool = False,
    always_decompose: bool = False,
    **kwargs: Any,
) -> float:
    """
    This function implements the entanglement of formation for mixed
    quantum systems.
    A mixed state is decomposed into pure states with respective
    probabilities using the eigendecomposition of the density matrix.
    Then, the Meyer-Wallach measure is computed for each pure state,
    weighted by the eigenvalue.
    See e.g. https://doi.org/10.48550/arXiv.quant-ph/0504163

    Note that the decomposition is *not unique*! Therefore, this measure
    presents the entanglement for *some* decomposition into pure states,
    not necessarily the one that is anticipated when applying the Kraus
    channels.
    If a pure state is provided, this results in the same value as the
    Entanglement.meyer_wallach function if `always_decompose` flag is not set.

    Args:
        model (Model): The quantum circuit model.
        n_samples (int): Number of samples per qubit. If None or <= 0, the
            current parameters of the model are used.
        seed (Optional[int]): Seed for the random number generator. Must be
            provided when ``n_samples > 0``; ignored otherwise.
        scale (bool): Whether to scale the number of samples by
            ``2**n_qubits``.
        always_decompose (bool): Whether to explicitly compute the
            entanglement of formation for the eigendecomposition of a pure
            state.
        kwargs (Any): Additional keyword arguments for the model function.

    Returns:
        float: Entangling capacity of the given circuit, guaranteed
            to be between 0.0 and 1.0.
    """

    # Plain Python int keeps the sample count usable as a repeat count;
    # skip scaling when no sample count was requested.
    if scale and n_samples is not None:
        n_samples = 2**model.n_qubits * n_samples

    if n_samples is not None and n_samples > 0:
        assert seed is not None, "Seed must be provided when samples > 0"
        # Build the PRNG key only once a seed is known to exist:
        # jax.random.key(None) raises, which previously made the assertion
        # above unreachable and broke the n_samples <= 0 path for seed=None.
        random_key = jax.random.key(seed)
        model.initialize_params(random_key, repeat=n_samples)
    else:
        if seed is not None:
            log.warning("Seed is ignored when samples is 0")

        # Append a trailing sample axis of size 1 so the current parameters
        # are handled as a single sample.
        if len(model.params.shape) <= 2:
            model.params = model.params.reshape(*model.params.shape, 1)
        else:
            log.info(f"Using sample size of model params: {model.params.shape[-1]}")

    # implicitly set input to none in case it's not needed
    kwargs.setdefault("inputs", None)
    rhos = model(execution_type="density", **kwargs)
    # Collapse all leading dimensions into one batch of
    # (2**n_qubits, 2**n_qubits) density matrices.
    rhos = rhos.reshape(-1, 2**model.n_qubits, 2**model.n_qubits)
    ent = Entanglement._compute_entanglement_of_formation(
        rhos, model.n_qubits, always_decompose, model.use_multithreading
    )
    return ent.mean()

meyer_wallach(model, n_samples, seed, scale=False, **kwargs) staticmethod #

Calculates the entangling capacity of a given quantum circuit using Meyer-Wallach measure.

Parameters:

Name Type Description Default
model Model

The quantum circuit model.

required
n_samples Optional[int]

Number of samples per qubit. If None or <= 0, the current parameters of the model are used.

required
seed Optional[int]

Seed for the random number generator.

required
scale bool

Whether to scale the number of samples.

False
kwargs Any

Additional keyword arguments for the model function.

{}

Returns:

Name Type Description
float float

Entangling capacity of the given circuit, guaranteed to be between 0.0 and 1.0.

Source code in qml_essentials/entanglement.py
@staticmethod
def meyer_wallach(
    model: Model,
    n_samples: Optional[int],
    seed: Optional[int],
    scale: bool = False,
    **kwargs: Any,
) -> float:
    """
    Calculates the entangling capacity of a given quantum circuit
    using Meyer-Wallach measure.

    Args:
        model (Model): The quantum circuit model.
        n_samples (Optional[int]): Number of samples per qubit.
            If None or <= 0, the current parameters of the model are used.
        seed (Optional[int]): Seed for the random number generator. Must be
            provided when ``n_samples > 0``; ignored otherwise.
        scale (bool): Whether to scale the number of samples by
            ``2**n_qubits``.
        kwargs (Any): Additional keyword arguments for the model function.

    Returns:
        float: Entangling capacity of the given circuit, guaranteed
            to be between 0.0 and 1.0.
    """
    if "noise_params" in kwargs:
        # Implicit concatenation; a backslash continuation inside the literal
        # would embed the next line's indentation into the message.
        log.warning(
            "Meyer-Wallach measure not suitable for noisy circuits. "
            "Consider 'relative_entropy' instead."
        )

    # Plain Python int keeps the sample count usable as a repeat count;
    # skip scaling when no sample count was requested.
    if scale and n_samples is not None:
        n_samples = 2**model.n_qubits * n_samples

    if n_samples is not None and n_samples > 0:
        assert seed is not None, "Seed must be provided when samples > 0"
        # Build the PRNG key only once a seed is known to exist:
        # jax.random.key(None) raises, which previously made the assertion
        # above unreachable and broke the n_samples <= 0 path for seed=None.
        random_key = jax.random.key(seed)
        model.initialize_params(random_key, repeat=n_samples)
    else:
        if seed is not None:
            log.warning("Seed is ignored when samples is 0")

    # implicitly set input to none in case it's not needed
    kwargs.setdefault("inputs", None)
    # explicitly set execution type because everything else won't work
    rhos = model(execution_type="density", **kwargs).reshape(
        -1, 2**model.n_qubits, 2**model.n_qubits
    )

    ent = Entanglement._compute_meyer_wallach_meas(
        rhos, model.n_qubits, model.use_multithreading
    )

    log.debug(f"Variance of measure: {ent.var()}")

    return ent.mean()

relative_entropy(model, n_samples, n_sigmas, seed, scale=False, **kwargs) staticmethod #

Calculates the relative entropy of entanglement of a given quantum circuit. This measure is also applicable to mixed states, albeit it might not be fully accurate in this simplified case.

The relative entropy of entanglement is generally defined as the smallest relative entropy from the state in question to the set of separable states. However, as computing the nearest separable state is NP-hard, we select n_sigmas random separable states to compute the distance to, which are not necessarily the nearest. Thus, this measure of entanglement presents an upper bound of the entanglement.

As the relative entropy is not necessarily between zero and one, this function also normalises by the relative entropy of the GHZ state.

Parameters:

Name Type Description Default
model Model

The quantum circuit model.

required
n_samples int

Number of samples per qubit. If <= 0, the current parameters of the model are used.

required
n_sigmas int

Number of random separable pure states to compare against.

required
seed Optional[int]

Seed for the random number generator.

required
scale bool

Whether to scale the number of samples.

False
kwargs Any

Additional keyword arguments for the model function.

{}

Returns:

Name Type Description
float float

Entangling capacity of the given circuit, guaranteed to be between 0.0 and 1.0.

Source code in qml_essentials/entanglement.py
@staticmethod
def relative_entropy(
    model: Model,
    n_samples: int,
    n_sigmas: int,
    seed: Optional[int],
    scale: bool = False,
    **kwargs: Any,
) -> float:
    """
    Calculates the relative entropy of entanglement of a given quantum
    circuit. This measure is also applicable to mixed states, albeit it
    might not be fully accurate in this simplified case.

    The relative entropy of entanglement is generally defined as the
    smallest relative entropy from the state in question to the set of
    separable states. However, as computing the nearest separable state is
    NP-hard, we select n_sigmas random separable states to compute the
    distance to, which are not necessarily the nearest. Thus, this measure
    of entanglement presents an upper bound of the entanglement.

    As the relative entropy is not necessarily between zero and one, this
    function also normalises by the relative entropy of the GHZ state.

    Args:
        model (Model): The quantum circuit model.
        n_samples (int): Number of samples per qubit.
            If None or <= 0, the current parameters of the model are used.
        n_sigmas (int): Number of random separable pure states to compare against.
        seed (Optional[int]): Seed for the random number generator. Always
            required, because it is used to sample the separable states.
        scale (bool): Whether to scale the number of samples and of
            separable states by ``2**n_qubits``.
        kwargs (Any): Additional keyword arguments for the model function.

    Returns:
        float: Entangling capacity of the given circuit, guaranteed
            to be between 0.0 and 1.0.
    """
    # Plain Python int so the scaled counts are valid array shapes.
    dim = 2**model.n_qubits
    if scale:
        if n_samples is not None:
            n_samples = dim * n_samples
        n_sigmas = dim * n_sigmas

    # The seed is needed regardless of n_samples: it drives the sampling of
    # the random separable reference states. Fail with a clear message
    # instead of an opaque error from jax.random.key(None).
    assert seed is not None, "Seed must be provided to sample separable states"
    random_key = jax.random.key(seed)

    # Random separable states
    log_sigmas = sample_random_separable_states(
        model.n_qubits, n_samples=n_sigmas, random_key=random_key, take_log=True
    )

    # Derive a new key from the one consumed by the separable states
    # before initialising parameters.
    random_key, _ = jax.random.split(random_key)

    if n_samples is not None and n_samples > 0:
        model.initialize_params(random_key, repeat=n_samples)
    else:
        # Append a trailing sample axis of size 1 so the current parameters
        # are handled as a single sample.
        if len(model.params.shape) <= 2:
            model.params = model.params.reshape(*model.params.shape, 1)
        else:
            log.info(f"Using sample size of model params: {model.params.shape[-1]}")

    rhos, log_rhos = Entanglement._compute_log_density(model, **kwargs)

    # One row per separable reference state, one column per parameter sample.
    rel_entropies = jnp.zeros((n_sigmas, model.params.shape[-1]))

    for i, log_sigma in enumerate(log_sigmas):
        rel_entropies = rel_entropies.at[i].set(
            Entanglement._compute_rel_entropies(
                rhos, log_rhos, log_sigma, model.use_multithreading
            )
        )

    # Entropy of GHZ states should be maximal
    ghz_model = Model(model.n_qubits, 1, "GHZ", data_reupload=False)
    rho_ghz, log_rho_ghz = Entanglement._compute_log_density(ghz_model, **kwargs)
    ghz_entropies = Entanglement._compute_rel_entropies(
        rho_ghz, log_rho_ghz, log_sigmas, use_multithreading=False
    )

    normalised_entropies = rel_entropies / ghz_entropies

    # For every parameter sample keep the smallest (closest) normalised
    # relative entropy over all separable reference states.
    entangling_capability = normalised_entropies.T.min(axis=1)
    log.debug(f"Variance of measure: {entangling_capability.var()}")

    return entangling_capability.mean()

Expressibility#

from qml_essentials.expressibility import Expressibility
Source code in qml_essentials/expressibility.py
class Expressibility:
    @staticmethod
    def _sample_state_fidelities(
        model: Model,
        x_samples: jnp.ndarray,
        n_samples: int,
        seed: int,
        kwargs: Any,
    ) -> jnp.ndarray:
        """
        Compute the fidelities for each pair of input samples and parameter sets.

        Args:
            model (Model): The quantum circuit model.
            x_samples (jnp.ndarray): Input samples; the first axis indexes the
                samples (``state_fidelities`` passes a 1D array).
            n_samples (int): Number of parameter-set pairs to generate.
            seed (int): Random number generator seed.
            kwargs (Any): Additional keyword arguments for the model function.

        Returns:
            jnp.ndarray: Array of shape (n_input_samples, n_samples)
            containing the fidelities.
        """
        random_key = random.key(seed)

        # Generate random parameter sets
        # We need two sets of parameters, as we are computing fidelities for a
        # pair of random states
        model.initialize_params(random_key, repeat=n_samples * 2)

        # Initialize array to store fidelities
        fidelities: jnp.ndarray = jnp.zeros((len(x_samples), n_samples))

        for idx, x_sample in enumerate(x_samples):
            # Evaluate all parameter sets for the current input sample.
            # Execution type is explicitly set to density
            sv: jnp.ndarray = model(
                inputs=x_sample,
                params=model.params,
                execution_type="density",
                **kwargs,
            )

            # The first n_samples density matrices (rho) are paired with the
            # last n_samples ones (sigma).
            # $\sqrt{\rho}$
            sqrt_sv1: jnp.ndarray = jnp.array([sqrtm(m) for m in sv[:n_samples]])

            # $\sqrt{\rho} \sigma \sqrt{\rho}$
            inner_fidelity = sqrt_sv1 @ sv[n_samples:] @ sqrt_sv1

            # Uhlmann fidelity:
            # $F = (\mathrm{Tr} \sqrt{\sqrt{\rho} \sigma \sqrt{\rho}})^2$
            fidelity: jnp.ndarray = (
                jnp.trace(
                    jnp.array([sqrtm(m) for m in inner_fidelity]),
                    axis1=1,
                    axis2=2,
                )
                ** 2
            )

            fidelities = fidelities.at[idx].set(jnp.abs(fidelity))

        return fidelities

    @staticmethod
    def state_fidelities(
        seed: int,
        n_samples: int,
        n_bins: int,
        model: Model,
        n_input_samples: int = 0,
        input_domain: List[float] = None,
        scale: bool = False,
        **kwargs: Any,
    ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
        """
        Sample the state fidelities and histogram them into a 2D array.

        Args:
            seed (int): Random number generator seed.
            n_samples (int): Number of parameter sets to generate.
            n_bins (int): Number of histogram bins.
            model (Model): The quantum circuit model.
            n_input_samples (int): Number of input samples.
            input_domain (List[float]): Input domain.
            scale (bool): Whether to scale the number of samples and bins.
            kwargs (Any): Additional keyword arguments for the model function.

        Returns:
            Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: Tuple containing the
                input samples, bin edges, and histogram values (flattened to
                1D when only a single input sample is used).
        """
        if scale:
            n_samples = jnp.power(2, model.n_qubits) * n_samples
            n_bins = model.n_qubits * n_bins

        if input_domain is None or n_input_samples is None or n_input_samples == 0:
            # No input sweep requested: evaluate at a single zero input.
            x = jnp.zeros((1))
            n_input_samples = 1
        else:
            x = jnp.linspace(*input_domain, n_input_samples)

        fidelities = Expressibility._sample_state_fidelities(
            x_samples=x,
            n_samples=n_samples,
            seed=seed,
            model=model,
            kwargs=kwargs,
        )
        z: np.ndarray = np.zeros((n_input_samples, n_bins))

        y: jnp.ndarray = jnp.linspace(0, 1, n_bins + 1)

        for i, f in enumerate(fidelities):
            z[i], _ = jnp.histogram(f, bins=y)

        # Turn counts into relative frequencies.
        z = z / n_samples

        if z.shape[0] == 1:
            z = z.flatten()

        return x, y, z

    @staticmethod
    def _haar_probability(fidelity: float, n_qubits: int) -> float:
        """
        Calculates theoretical probability density function for random Haar states
        as proposed by Sim et al. (https://arxiv.org/abs/1905.10876).

        Args:
            fidelity (float): fidelity of two parameter assignments in [0, 1]
            n_qubits (int): number of qubits in the quantum system

        Returns:
            float: probability density for a given fidelity
        """
        N = 2**n_qubits

        prob = (N - 1) * (1 - fidelity) ** (N - 2)
        return prob

    @staticmethod
    def _sample_haar_integral(n_qubits: int, n_bins: int) -> jnp.ndarray:
        """
        Integrates the theoretical Haar fidelity density of Sim et al.
        (https://arxiv.org/abs/1905.10876) over ``n_bins`` equal-width bins
        on [0, 1].

        Args:
            n_qubits (int): number of qubits in the quantum system
            n_bins (int): number of histogram bins

        Returns:
            jnp.ndarray: probability mass of each fidelity bin
        """
        dist = np.zeros(n_bins)
        for idx in range(n_bins):
            v = idx / n_bins
            u = (idx + 1) / n_bins
            dist[idx], _ = integrate.quad(
                Expressibility._haar_probability, v, u, args=(n_qubits,)
            )

        return dist

    @staticmethod
    def haar_integral(
        n_qubits: int,
        n_bins: int,
        cache: bool = True,
        scale: bool = False,
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """
        Calculates theoretical probability density function for random Haar states
        as proposed by Sim et al. (https://arxiv.org/abs/1905.10876) and bins it
        into a histogram.

        Args:
            n_qubits (int): number of qubits in the quantum system
            n_bins (int): number of histogram bins
            cache (bool): whether to cache the haar integral in ``.cache/``
            scale (bool): whether to scale the number of bins

        Returns:
            Tuple[jnp.ndarray, jnp.ndarray]:
                - x component (bins): the input domain
                - y component (probabilities): the haar probability density
                  function for random Haar states
        """
        if scale:
            n_bins = n_qubits * n_bins

        x = jnp.linspace(0, 1, n_bins)

        if cache:
            name = f"haar_{n_qubits}q_{n_bins}s_{'scaled' if scale else ''}.npy"

            cache_folder = ".cache"
            # exist_ok avoids a race between checking for and creating the
            # folder when several processes populate the cache concurrently.
            os.makedirs(cache_folder, exist_ok=True)

            file_path = os.path.join(cache_folder, name)

            if os.path.isfile(file_path):
                y = jnp.load(file_path)
                return x, y

        y = Expressibility._sample_haar_integral(n_qubits, n_bins)

        if cache:
            jnp.save(file_path, y)

        return x, y

    @staticmethod
    def kullback_leibler_divergence(
        vqc_prob_dist: jnp.ndarray,
        haar_dist: jnp.ndarray,
    ) -> jnp.ndarray:
        """
        Calculates the KL divergence between two probability distributions (Haar
        probability distribution and the fidelity distribution sampled from a VQC).

        Args:
            vqc_prob_dist (jnp.ndarray): VQC fidelity probability distribution.
                Should have shape (n_inputs_samples, n_bins) or (n_bins, )
            haar_dist (jnp.ndarray): Haar probability distribution.
                Should have shape (n_bins, )

        Returns:
            jnp.ndarray: Array of KL-Divergence values, one per row of
            ``vqc_prob_dist``
        """
        if len(vqc_prob_dist.shape) > 1:
            assert all([haar_dist.shape == p.shape for p in vqc_prob_dist]), (
                "All probabilities for inputs should have the same shape as Haar. "
                f"Got {haar_dist.shape} for Haar and {vqc_prob_dist.shape} for VQC"
            )
        else:
            vqc_prob_dist = vqc_prob_dist.reshape((1, -1))

        kl_divergence = np.zeros(vqc_prob_dist.shape[0])
        for idx, p in enumerate(vqc_prob_dist):
            kl_divergence[idx] = jnp.sum(rel_entr(p, haar_dist))

        return kl_divergence

    @staticmethod
    def kl_divergence_to_haar(
        model: Model,
        seed: int,
        n_samples: int,
        n_bins: int,
        n_input_samples: int = 0,
        input_domain: List[float] = None,
        scale: bool = False,
        **kwargs: Any,
    ) -> np.ndarray:
        """
        Shortcut method to compute the KL-Divergence between a model and the
        Haar distribution. The basic steps are:
            - Sample the state fidelities for randomly initialised parameters.
            - Calculates the KL divergence between the sampled probability and
              the Haar probability distribution.

        Args:
            model (Model): The quantum circuit model.
            seed (int): Random number generator seed.
            n_samples (int): Number of parameter sets to generate.
            n_bins (int): Number of histogram bins.
            n_input_samples (int): Number of input samples.
            input_domain (List[float]): Input domain.
            scale (bool): Whether to scale the number of samples and bins.
            kwargs (Any): Additional keyword arguments for the model function.

        Returns:
            np.ndarray: KL-divergence values, one entry per input sample.
        """
        _, _, fidelities = Expressibility.state_fidelities(
            model=model,
            seed=seed,
            n_samples=n_samples,
            n_bins=n_bins,
            n_input_samples=n_input_samples,
            input_domain=input_domain,
            scale=scale,
            **kwargs,
        )
        _, haar_probs = Expressibility.haar_integral(
            model.n_qubits, n_bins=n_bins, scale=scale
        )
        return Expressibility.kullback_leibler_divergence(fidelities, haar_probs)

haar_integral(n_qubits, n_bins, cache=True, scale=False) staticmethod #

Calculates theoretical probability density function for random Haar states as proposed by Sim et al. (https://arxiv.org/abs/1905.10876) and bins it into a histogram.

Parameters:

Name Type Description Default
n_qubits int

number of qubits in the quantum system

required
n_bins int

number of histogram bins

required
cache bool

whether to cache the haar integral

True
scale bool

whether to scale the number of bins

False

Returns:

Type Description
Tuple[ndarray, ndarray]

Tuple[jnp.ndarray, jnp.ndarray]: - x component (bins): the input domain - y component (probabilities): the haar probability density function for random Haar states

Source code in qml_essentials/expressibility.py
@staticmethod
def haar_integral(
    n_qubits: int,
    n_bins: int,
    cache: bool = True,
    scale: bool = False,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """
    Calculates theoretical probability density function for random Haar states
    as proposed by Sim et al. (https://arxiv.org/abs/1905.10876) and bins it
    into a histogram.

    Args:
        n_qubits (int): number of qubits in the quantum system
        n_bins (int): number of histogram bins
        cache (bool): whether to cache the haar integral in ``.cache/``
        scale (bool): whether to scale the number of bins

    Returns:
        Tuple[jnp.ndarray, jnp.ndarray]:
            - x component (bins): the input domain
            - y component (probabilities): the haar probability density
              function for random Haar states
    """
    if scale:
        n_bins = n_qubits * n_bins

    x = jnp.linspace(0, 1, n_bins)

    if cache:
        name = f"haar_{n_qubits}q_{n_bins}s_{'scaled' if scale else ''}.npy"

        cache_folder = ".cache"
        # exist_ok avoids a race between checking for and creating the
        # folder when several processes populate the cache concurrently.
        os.makedirs(cache_folder, exist_ok=True)

        file_path = os.path.join(cache_folder, name)

        if os.path.isfile(file_path):
            y = jnp.load(file_path)
            return x, y

    y = Expressibility._sample_haar_integral(n_qubits, n_bins)

    if cache:
        jnp.save(file_path, y)

    return x, y

kl_divergence_to_haar(model, seed, n_samples, n_bins, n_input_samples=0, input_domain=None, scale=False, **kwargs) #

Shortcut method to compute the KL-Divergence between a model and the Haar distribution. The basic steps are: - Sample the state fidelities for randomly initialised parameters. - Calculates the KL divergence between the sampled probability and the Haar probability distribution.

Parameters:

Name Type Description Default
model Model

Function that models the quantum circuit.

required
seed int

Random number generator seed.

required
n_samples int

Number of parameter sets to generate.

required
n_bins int

Number of histogram bins.

required
n_input_samples int

Number of input samples.

0
input_domain List[float]

Input domain.

None
scale bool

Whether to scale the number of samples and bins.

False
kwargs Any

Additional keyword arguments for the model function.

{}

Returns:

Type Description
float

np.ndarray: KL-divergence values, one entry per input sample (as returned by kullback_leibler_divergence).

Source code in qml_essentials/expressibility.py
def kl_divergence_to_haar(
    model: Model,
    seed: int,
    n_samples: int,
    n_bins: int,
    n_input_samples: int = 0,
    input_domain: List[float] = None,
    scale: bool = False,
    **kwargs: Any,
) -> np.ndarray:
    """
    Shortcut method to compute the KL-Divergence between a model and the
    Haar distribution. The basic steps are:
        - Sample the state fidelities for randomly initialised parameters.
        - Calculates the KL divergence between the sampled probability and
          the Haar probability distribution.

    Args:
        model (Model): The quantum circuit model.
        seed (int): Random number generator seed.
        n_samples (int): Number of parameter sets to generate.
        n_bins (int): Number of histogram bins.
        n_input_samples (int): Number of input samples.
        input_domain (List[float]): Input domain.
        scale (bool): Whether to scale the number of samples and bins.
        kwargs (Any): Additional keyword arguments for the model function.

    Returns:
        np.ndarray: KL-divergence values, one entry per input sample.
    """
    # Histogrammed fidelity distribution(s) of the model.
    _, _, fidelities = Expressibility.state_fidelities(
        model=model,
        seed=seed,
        n_samples=n_samples,
        n_bins=n_bins,
        n_input_samples=n_input_samples,
        input_domain=input_domain,
        scale=scale,
        **kwargs,
    )
    # Reference Haar distribution binned the same way (same n_bins / scale).
    _, haar_probs = Expressibility.haar_integral(
        model.n_qubits, n_bins=n_bins, scale=scale
    )
    return Expressibility.kullback_leibler_divergence(fidelities, haar_probs)

kullback_leibler_divergence(vqc_prob_dist, haar_dist) staticmethod #

Calculates the KL divergence between two probability distributions (Haar probability distribution and the fidelity distribution sampled from a VQC).

Parameters:

Name Type Description Default
vqc_prob_dist ndarray

VQC fidelity probability distribution. Should have shape (n_inputs_samples, n_bins)

required
haar_dist ndarray

Haar probability distribution. Should have shape (n_bins, )

required

Returns:

Type Description
ndarray

jnp.ndarray: Array of KL-Divergence values for all values in axis 1

Source code in qml_essentials/expressibility.py
@staticmethod
def kullback_leibler_divergence(
    vqc_prob_dist: jnp.ndarray,
    haar_dist: jnp.ndarray,
) -> jnp.ndarray:
    """
    Compute the KL divergence of each sampled VQC fidelity distribution
    with respect to the Haar fidelity distribution.

    Args:
        vqc_prob_dist (jnp.ndarray): VQC fidelity probability distribution(s),
            shape (n_inputs_samples, n_bins) or a single (n_bins, ) vector.
        haar_dist (jnp.ndarray): Haar probability distribution of shape
            (n_bins, ).

    Returns:
        jnp.ndarray: One KL-divergence value per row of ``vqc_prob_dist``
        (a length-1 array for a single 1D distribution).
    """
    if vqc_prob_dist.ndim <= 1:
        # Promote a single distribution to a batch containing one row.
        vqc_prob_dist = vqc_prob_dist.reshape((1, -1))
    else:
        assert all([haar_dist.shape == p.shape for p in vqc_prob_dist]), (
            "All probabilities for inputs should have the same shape as Haar. "
            f"Got {haar_dist.shape} for Haar and {vqc_prob_dist.shape} for VQC"
        )

    # sum_i p_i * log(p_i / q_i) for every row, collected as float64.
    return np.array(
        [jnp.sum(rel_entr(row, haar_dist)) for row in vqc_prob_dist],
        dtype=float,
    )

state_fidelities(seed, n_samples, n_bins, model, n_input_samples=0, input_domain=None, scale=False, **kwargs) staticmethod #

Sample the state fidelities and histogram them into a 2D array.

Parameters:

Name Type Description Default
seed int

Random number generator seed.

required
n_samples int

Number of parameter sets to generate.

required
n_bins int

Number of histogram bins.

required
n_input_samples int

Number of input samples.

0
input_domain List[float]

Input domain.

None
model Callable

Function that models the quantum circuit.

required
scale bool

Whether to scale the number of samples and bins.

False
kwargs Any

Additional keyword arguments for the model function.

{}

Returns:

Type Description
Tuple[ndarray, ndarray, ndarray]

Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: Tuple containing the input samples, bin edges, and histogram values.

Source code in qml_essentials/expressibility.py
@staticmethod
def state_fidelities(
    seed: int,
    n_samples: int,
    n_bins: int,
    model: Model,
    n_input_samples: int = 0,
    input_domain: List[float] = None,
    scale: bool = False,
    **kwargs: Any,
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """
    Draw random parameter pairs, compute their state fidelities for every
    input sample and histogram them over [0, 1].

    Args:
        seed (int): Random number generator seed.
        n_samples (int): Number of parameter sets to generate.
        n_bins (int): Number of histogram bins.
        model (Model): The quantum circuit model.
        n_input_samples (int): Number of input samples.
        input_domain (List[float]): Input domain.
        scale (bool): Whether to scale the number of samples and bins.
        kwargs (Any): Additional keyword arguments for the model function.

    Returns:
        Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: The input samples, the
            bin edges, and the normalised histogram values (flattened to 1D
            when only a single input sample is used).
    """
    if scale:
        # Samples grow with 2**n_qubits, bins linearly with n_qubits.
        n_samples = jnp.power(2, model.n_qubits) * n_samples
        n_bins = model.n_qubits * n_bins

    sweep_inputs = not (
        input_domain is None or n_input_samples is None or n_input_samples == 0
    )
    if sweep_inputs:
        x = jnp.linspace(*input_domain, n_input_samples)
    else:
        # Without an input sweep, evaluate once at a single zero input.
        x = jnp.zeros((1))
        n_input_samples = 1

    fidelities = Expressibility._sample_state_fidelities(
        x_samples=x,
        n_samples=n_samples,
        seed=seed,
        model=model,
        kwargs=kwargs,
    )

    bin_edges: jnp.ndarray = jnp.linspace(0, 1, n_bins + 1)
    counts: np.ndarray = np.zeros((n_input_samples, n_bins))
    for row, fid in enumerate(fidelities):
        counts[row], _ = jnp.histogram(fid, bins=bin_edges)

    # Relative frequencies instead of raw counts.
    counts = counts / n_samples

    if counts.shape[0] == 1:
        counts = counts.flatten()

    return x, bin_edges, counts

Coefficients#

from qml_essentials.coefficients import Coefficients
Source code in qml_essentials/coefficients.py
class Coefficients:
    """
    Static helpers for extracting, post-processing and evaluating the
    Fourier spectrum of a model's output.
    """

    @staticmethod
    def get_spectrum(
        model: Model,
        mfs: int = 1,
        mts: int = 1,
        shift=False,
        trim=False,
        **kwargs,
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """
        Extracts the coefficients of a given model using an FFT (jnp.fft).

        Note that the coefficients are complex numbers, but the imaginary part
        of the coefficients should be very close to zero, since the expectation
        values of the Pauli operators are real numbers.

        It can perform oversampling in both the frequency and time domain
        using the `mfs` and `mts` arguments.

        Args:
            model (Model): The model to sample.
            mfs (int): Multiplier for the highest frequency. Default is 1.
            mts (int): Multiplier for the number of time samples. Default is 1.
            shift (bool): Whether to apply an fftshift to the coefficients and
                to each frequency array. Default is False.
            trim (bool): Whether to remove the Nyquist frequency along every
                input axis that has an even number of samples.
                Default is False.
            kwargs (Any): Additional keyword arguments for the model function.
                ``force_mean`` defaults to True and ``execution_type`` to
                "expval" unless provided.

        Returns:
            Tuple[jnp.ndarray, jnp.ndarray]: Tuple containing the coefficients
            and frequencies. With a single input feature the frequencies are
            a 1D array; with several features they are a list holding one 1D
            array per feature.

        Raises:
            ValueError: If the imaginary part of the summed coefficients is
                not within an absolute tolerance of 1e-5 of zero.
        """
        kwargs.setdefault("force_mean", True)
        kwargs.setdefault("execution_type", "expval")

        coeffs, freqs = Coefficients._fourier_transform(
            model, mfs=mfs, mts=mts, **kwargs
        )

        # Compare against zero with an *absolute* tolerance: a relative
        # tolerance is scaled by |0| and therefore has no effect.
        imag_sum = jnp.sum(coeffs).imag
        if not jnp.isclose(imag_sum, 0.0, atol=1.0e-5):
            raise ValueError(
                "Spectrum is not real. Imaginary part of coefficients is: "
                f"{imag_sum}"
            )

        if trim:
            # One frequency array per input feature; copy so each can be
            # trimmed independently.
            freqs = list(freqs)
            for ax in range(model.n_input_feat):
                n_ax = coeffs.shape[ax]
                if n_ax % 2 == 0:
                    # For an even sample count the unshifted FFT stores the
                    # Nyquist frequency at index n // 2 of that axis.
                    coeffs = np.delete(coeffs, n_ax // 2, axis=ax)
                    freqs[ax] = np.delete(freqs[ax], freqs[ax].shape[0] // 2)

        if shift:
            coeffs = jnp.fft.fftshift(coeffs, axes=list(range(model.n_input_feat)))
            # Shift each feature's frequencies on its own; shifting the list
            # as one array would also roll across the feature axis.
            freqs = [np.fft.fftshift(freq) for freq in freqs]

        if len(freqs) == 1:
            freqs = freqs[0]

        return coeffs, freqs

    @staticmethod
    def _fourier_transform(
        model: Model, mfs: int, mts: int, **kwargs: Any
    ) -> Tuple[jnp.ndarray, List[jnp.ndarray]]:
        """
        Sample the model on a regular grid and return its normalized FFT.

        Args:
            model (Model): The model to sample.
            mfs (int): Multiplier for the number of frequencies per input.
            mts (int): Multiplier for the sampled interval [0, 2*pi*mts).
            kwargs (Any): Forwarded to ``model``.

        Returns:
            Tuple[jnp.ndarray, List[jnp.ndarray]]: The FFT coefficients divided
            by the number of grid points, and one ``fftfreq``-ordered frequency
            array per input feature.
        """
        # Create a frequency vector with as many frequencies as model degrees,
        # oversampled by mfs
        n_freqs: jnp.ndarray = jnp.array(
            [mfs * model.degree[i] for i in range(model.n_input_feat)]
        )
        # Samples per input feature; this is also the length of the matching
        # frequency array, so coefficients and frequencies always line up.
        n_samples = [int(mts * n_freqs[i]) for i in range(model.n_input_feat)]
        step = 2 * jnp.pi / n_freqs

        # Sample each input with spacing 2*pi/n_freqs. Building the grid from
        # an integer count avoids arange's float-step length ambiguity.
        inputs: List = [
            jnp.arange(n_samples[i]) * step[i] for i in range(model.n_input_feat)
        ]

        # Cartesian product of all inputs, row-major over the input axes so it
        # matches the reshape of the outputs below.
        nd_inputs = jnp.stack(jnp.meshgrid(*inputs, indexing="ij"), axis=-1).reshape(
            -1, model.n_input_feat
        )

        # Output vector is not necessarily the same length as input
        outputs = model(inputs=nd_inputs, **kwargs)
        outputs = outputs.reshape(*n_samples, -1).squeeze()

        coeffs = jnp.fft.fftn(outputs, axes=list(range(model.n_input_feat)))

        freqs = [
            jnp.fft.fftfreq(n_samples[i], 1 / n_freqs[i])
            for i in range(model.n_input_feat)
        ]

        # TODO: account for different frequencies in multidim input scenarios
        # Normalize by the number of grid points (product over input axes).
        return (
            coeffs / math.prod(outputs.shape[0 : model.n_input_feat]),
            freqs,
        )

    @staticmethod
    def get_psd(coeffs: jnp.ndarray) -> jnp.ndarray:
        """
        Calculates the power spectral density (PSD) from given Fourier coefficients.

        Each entry is ``2 * |c|**2 / len(coeffs)**2``.

        Args:
            coeffs (jnp.ndarray): The Fourier coefficients.

        Returns:
            jnp.ndarray: The power spectral density (same shape as ``coeffs``).
        """
        # TODO: if we apply trim=True in advance, this will be slightly wrong..
        # NOTE(review): len(coeffs) is the size of the first axis only; for
        # multi-dimensional coefficients this normalization may not be the
        # intended one — confirm.

        def abs2(x):
            # Squared magnitude without the sqrt of abs().
            return x.real**2 + x.imag**2

        scale = 2.0 / (len(coeffs) ** 2)
        return scale * abs2(coeffs)

    @staticmethod
    def evaluate_Fourier_series(
        coefficients: jnp.ndarray,
        frequencies: jnp.ndarray,
        inputs: Union[jnp.ndarray, list, float],
    ) -> jnp.ndarray:
        """
        Evaluate a Fourier series at one or more input points.

        Computes the real part of the sum of ``c * exp(1j * <f, x>)`` over all
        frequencies ``f`` with coefficients ``c``.

        Args:
            coefficients (jnp.ndarray): Coefficients of the Fourier series.
            frequencies (jnp.ndarray): Corresponding frequencies; a list with
                one frequency array per input feature for multi-dimensional
                inputs.
            inputs (Union[jnp.ndarray, list, float]): Point(s) at which to
                evaluate the function. Python scalars are accepted.

        Returns:
            jnp.ndarray: The real function value(s), squeezed (0-d for a
            single point).
        """
        # Append a trailing output axis when the coefficients have none.
        if isinstance(frequencies, list):
            if len(coefficients.shape) <= len(frequencies):
                coefficients = coefficients[..., jnp.newaxis]
        else:
            if len(coefficients.shape) == 1:
                coefficients = coefficients[..., jnp.newaxis]

        # Accept Python scalars, lists and arrays alike; a scalar becomes a
        # one-element batch.
        inputs = jnp.asarray(inputs)
        if inputs.ndim < 1:
            inputs = inputs[jnp.newaxis, ...]

        if isinstance(frequencies, list):
            input_dim = len(frequencies)
            frequencies = jnp.stack(jnp.meshgrid(*frequencies))
            if input_dim != len(inputs):
                frequencies = jnp.repeat(
                    frequencies[jnp.newaxis, ...], inputs.shape[0], axis=0
                )
                freq_inputs = jnp.einsum("bi...,b->b...", frequencies, inputs)
                exponents = jnp.exp(1j * freq_inputs).T
                exp = jnp.einsum("jl...k,jl...b->b...k", coefficients, exponents)
            else:
                freq_inputs = jnp.einsum("i...,i->...", frequencies, inputs)
                exponents = jnp.exp(1j * freq_inputs).T
                exp = jnp.einsum("jl...k,jl...->k...", coefficients, exponents)
        else:
            frequencies = jnp.repeat(
                frequencies[jnp.newaxis, ...], inputs.shape[0], axis=0
            )
            freq_inputs = jnp.einsum("i...,i->i...", frequencies, inputs)
            exponents = jnp.exp(1j * freq_inputs)
            exp = jnp.einsum("j...k,ij...->ik...", coefficients, exponents)

        return jnp.squeeze(jnp.real(exp))

evaluate_Fourier_series(coefficients, frequencies, inputs) staticmethod #

Evaluate the function value of a Fourier series at one or more input points.

Parameters:

Name Type Description Default
coefficients ndarray

Coefficients of the Fourier series.

required
frequencies ndarray

Corresponding frequencies.

required
inputs ndarray | list | float

Point at which to evaluate the function.

required

Returns: float: The function value at the input point.

Source code in qml_essentials/coefficients.py
@staticmethod
def evaluate_Fourier_series(
    coefficients: jnp.ndarray,
    frequencies: jnp.ndarray,
    inputs: Union[jnp.ndarray, list, float],
) -> jnp.ndarray:
    """
    Evaluate a Fourier series at one or more input points.

    Computes the real part of the sum of ``c * exp(1j * <f, x>)`` over all
    frequencies ``f`` with coefficients ``c``.

    Args:
        coefficients (jnp.ndarray): Coefficients of the Fourier series.
        frequencies (jnp.ndarray): Corresponding frequencies; a list with one
            frequency array per input feature for multi-dimensional inputs.
        inputs (Union[jnp.ndarray, list, float]): Point(s) at which to
            evaluate the function. Python scalars are accepted.

    Returns:
        jnp.ndarray: The real function value(s), squeezed (0-d for a single
        point).
    """
    # Append a trailing output axis when the coefficients have none.
    if isinstance(frequencies, list):
        if len(coefficients.shape) <= len(frequencies):
            coefficients = coefficients[..., jnp.newaxis]
    else:
        if len(coefficients.shape) == 1:
            coefficients = coefficients[..., jnp.newaxis]

    # Accept Python scalars, lists and arrays alike; a scalar becomes a
    # one-element batch. (A bare float previously failed on `.shape`.)
    inputs = jnp.asarray(inputs)
    if inputs.ndim < 1:
        inputs = inputs[jnp.newaxis, ...]

    if isinstance(frequencies, list):
        input_dim = len(frequencies)
        frequencies = jnp.stack(jnp.meshgrid(*frequencies))
        if input_dim != len(inputs):
            frequencies = jnp.repeat(
                frequencies[jnp.newaxis, ...], inputs.shape[0], axis=0
            )
            freq_inputs = jnp.einsum("bi...,b->b...", frequencies, inputs)
            exponents = jnp.exp(1j * freq_inputs).T
            exp = jnp.einsum("jl...k,jl...b->b...k", coefficients, exponents)
        else:
            freq_inputs = jnp.einsum("i...,i->...", frequencies, inputs)
            exponents = jnp.exp(1j * freq_inputs).T
            exp = jnp.einsum("jl...k,jl...->k...", coefficients, exponents)
    else:
        frequencies = jnp.repeat(
            frequencies[jnp.newaxis, ...], inputs.shape[0], axis=0
        )
        freq_inputs = jnp.einsum("i...,i->i...", frequencies, inputs)
        exponents = jnp.exp(1j * freq_inputs)
        exp = jnp.einsum("j...k,ij...->ik...", coefficients, exponents)

    return jnp.squeeze(jnp.real(exp))

get_psd(coeffs) staticmethod #

Calculates the power spectral density (PSD) from given Fourier coefficients.

Parameters:

Name Type Description Default
coeffs ndarray

The Fourier coefficients.

required

Returns:

Type Description
ndarray

jnp.ndarray: The power spectral density.

Source code in qml_essentials/coefficients.py
@staticmethod
def get_psd(coeffs: jnp.ndarray) -> jnp.ndarray:
    """
    Compute the power spectral density (PSD) of Fourier coefficients.

    Every entry is ``2 * (re**2 + im**2) / len(coeffs)**2``, where ``len`` is
    the size of the first axis of ``coeffs``.

    Args:
        coeffs (jnp.ndarray): The Fourier coefficients.

    Returns:
        jnp.ndarray: The power spectral density, same shape as ``coeffs``.
    """
    # TODO: if we apply trim=True in advance, this will be slightly wrong..
    first_axis_len = len(coeffs)
    # Squared magnitude, computed without the sqrt hidden in abs().
    squared_magnitude = coeffs.real**2 + coeffs.imag**2
    return (2.0 / (first_axis_len**2)) * squared_magnitude

get_spectrum(model, mfs=1, mts=1, shift=False, trim=False, **kwargs) staticmethod #

Extracts the coefficients of a given model using an FFT (jnp.fft).

Note that the coefficients are complex numbers, but the imaginary part of the coefficients should be very close to zero, since the expectation values of the Pauli operators are real numbers.

It can perform oversampling in both the frequency and time domain using the mfs and mts arguments.

Parameters:

Name Type Description Default
model Model

The model to sample.

required
mfs int

Multiplier for the highest frequency. Default is 1.

1
mts int

Multiplier for the number of time samples. Default is 1.

1
shift bool

Whether to apply jnp-fftshift. Default is False.

False
trim bool

Whether to remove the Nyquist frequency along each input axis that has an even number of samples. Default is False.

False
kwargs Any

Additional keyword arguments for the model function.

{}

Returns:

Type Description
Tuple[ndarray, ndarray]

Tuple containing the coefficients and frequencies.

Source code in qml_essentials/coefficients.py
@staticmethod
def get_spectrum(
    model: Model,
    mfs: int = 1,
    mts: int = 1,
    shift=False,
    trim=False,
    **kwargs,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """
    Extracts the coefficients of a given model using an FFT (jnp.fft).

    Note that the coefficients are complex numbers, but the imaginary part
    of the coefficients should be very close to zero, since the expectation
    values of the Pauli operators are real numbers.

    It can perform oversampling in both the frequency and time domain
    using the `mfs` and `mts` arguments.

    Args:
        model (Model): The model to sample.
        mfs (int): Multiplier for the highest frequency. Default is 1.
        mts (int): Multiplier for the number of time samples. Default is 1.
        shift (bool): Whether to apply an fftshift to the coefficients and to
            each frequency array. Default is False.
        trim (bool): Whether to remove the Nyquist frequency along every input
            axis that has an even number of samples. Default is False.
        kwargs (Any): Additional keyword arguments for the model function.
            ``force_mean`` defaults to True and ``execution_type`` to
            "expval" unless provided.

    Returns:
        Tuple[jnp.ndarray, jnp.ndarray]: Tuple containing the coefficients
        and frequencies. With a single input feature the frequencies are a 1D
        array; with several features they are a list holding one 1D array per
        feature.

    Raises:
        ValueError: If the imaginary part of the summed coefficients is not
            within an absolute tolerance of 1e-5 of zero.
    """
    kwargs.setdefault("force_mean", True)
    kwargs.setdefault("execution_type", "expval")

    coeffs, freqs = Coefficients._fourier_transform(
        model, mfs=mfs, mts=mts, **kwargs
    )

    # Compare against zero with an *absolute* tolerance: a relative tolerance
    # is scaled by |0| and therefore has no effect.
    imag_sum = jnp.sum(coeffs).imag
    if not jnp.isclose(imag_sum, 0.0, atol=1.0e-5):
        raise ValueError(
            "Spectrum is not real. Imaginary part of coefficients is: "
            f"{imag_sum}"
        )

    if trim:
        # One frequency array per input feature; copy so each can be trimmed
        # independently.
        freqs = list(freqs)
        for ax in range(model.n_input_feat):
            n_ax = coeffs.shape[ax]
            if n_ax % 2 == 0:
                # For an even sample count the unshifted FFT stores the
                # Nyquist frequency at index n // 2 of that axis.
                coeffs = np.delete(coeffs, n_ax // 2, axis=ax)
                freqs[ax] = np.delete(freqs[ax], freqs[ax].shape[0] // 2)

    if shift:
        coeffs = jnp.fft.fftshift(coeffs, axes=list(range(model.n_input_feat)))
        # Shift each feature's frequencies on its own; shifting the list as
        # one array would also roll across the feature axis.
        freqs = [np.fft.fftshift(freq) for freq in freqs]

    if len(freqs) == 1:
        freqs = freqs[0]

    return coeffs, freqs