Skip to content

References

Ansaetze

from qml_essentials.ansaetze import Ansaetze
Source code in qml_essentials/ansaetze.py
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
class Ansaetze:
    """Collection of circuit ansaetze described declaratively.

    Every nested class derives from ``DeclarativeCircuit`` and describes one
    layer through ``structure()``, which returns a tuple of ``Block`` entries
    (a gate plus optional topology settings such as ``topology``, ``offset``,
    ``span``, ``wrap``, ``reverse``, ``mirror``). How blocks are expanded into
    concrete gates is defined by ``DeclarativeCircuit``/``Block``, which are
    not part of this class.
    """

    @staticmethod
    def get_available(parameterized_only=False):
        """Return the ansatz classes offered by this collection.

        Args:
            parameterized_only (bool): If True, return only the ansaetze in the
                parameterized list. Otherwise also append the
                non-parameterized ``No_Ansatz`` and ``GHZ``.

        Returns:
            list: Ansatz classes, parameterized ones first.
        """
        # list of parameterized ansaetze
        ansaetze = [
            Ansaetze.Circuit_1,
            Ansaetze.Circuit_2,
            Ansaetze.Circuit_3,
            Ansaetze.Circuit_4,
            Ansaetze.Circuit_5,
            Ansaetze.Circuit_6,
            Ansaetze.Circuit_7,
            Ansaetze.Circuit_8,
            Ansaetze.Circuit_9,
            Ansaetze.Circuit_10,
            Ansaetze.Circuit_13,
            Ansaetze.Circuit_14,
            Ansaetze.Circuit_15,
            Ansaetze.Circuit_16,
            Ansaetze.Circuit_17,
            Ansaetze.Circuit_18,
            Ansaetze.Circuit_19,
            Ansaetze.Circuit_20,
            Ansaetze.No_Entangling,
            Ansaetze.Strongly_Entangling,
            Ansaetze.Hardware_Efficient,
        ]

        # extend by the non-parameterized ones
        if not parameterized_only:
            ansaetze += [
                Ansaetze.No_Ansatz,
                Ansaetze.GHZ,
            ]

        return ansaetze

    class No_Ansatz(DeclarativeCircuit):
        """Empty ansatz: its structure contains no blocks."""

        @staticmethod
        def structure():
            return ()

    class GHZ(DeclarativeCircuit):
        """GHZ-style circuit: a Hadamard followed by a chain of CX gates."""

        # NOTE(review): structure() uses an unrestricted H block and a reversed
        # CX stairs block, while build() applies H on wire 0 only followed by a
        # forward CX chain -- confirm both describe the same circuit.
        @staticmethod
        def structure():
            return (
                Block(gate=Gates.H),
                Block(gate=Gates.CX, topology=Topology.stairs, reverse=True),
            )

        @staticmethod
        def build(w: np.ndarray, n_qubits: int, **kwargs):
            """Apply H on wire 0, then CX(q, q + 1) for q = 0 .. n_qubits - 2.

            ``w`` is accepted for interface compatibility but is not used.
            Extra keyword arguments are forwarded to every gate call.
            """
            Gates.H(wires=0, **kwargs)
            for q in range(n_qubits - 1):
                Gates.CX(wires=[q, q + 1], **kwargs)

        @staticmethod
        def n_pulse_params_per_layer(n_qubits: int) -> int:
            """Number of pulse parameters for one layer as emitted by ``build``.

            This counts one H plus ``n_qubits - 1`` CX gates.
            """
            n_params = PulseInformation.num_params("H")  # only 1 H
            # Query by gate name, consistently with the "H" lookup above.
            n_params += (n_qubits - 1) * PulseInformation.num_params("CX")
            return n_params

    # Circuit_<k> names presumably follow the circuit numbering of Sim et al.,
    # "Expressibility and entangling capability of parameterized quantum
    # circuits" (2019) -- confirm. Circuit_11 and Circuit_12 are not defined.
    class Circuit_1(DeclarativeCircuit):
        @staticmethod
        def structure():
            return (
                Block(gate=Gates.RX),
                Block(gate=Gates.RZ),
            )

    class Circuit_2(DeclarativeCircuit):
        @staticmethod
        def structure():
            return (
                Block(gate=Gates.RX),
                Block(gate=Gates.RZ),
                Block(
                    gate=Gates.CX,
                    topology=Topology.stairs,
                ),
            )

    class Circuit_3(DeclarativeCircuit):
        @staticmethod
        def structure():
            return (
                Block(gate=Gates.RX),
                Block(gate=Gates.RZ),
                Block(gate=Gates.CRZ, topology=Topology.stairs),
            )

    class Circuit_4(DeclarativeCircuit):
        @staticmethod
        def structure():
            return (
                Block(gate=Gates.RX),
                Block(gate=Gates.RZ),
                Block(gate=Gates.CRX, topology=Topology.stairs),
            )

    class Circuit_5(DeclarativeCircuit):
        @staticmethod
        def structure():
            return (
                Block(gate=Gates.RX),
                Block(gate=Gates.RZ),
                Block(gate=Gates.CRZ, topology=Topology.all_to_all),
                Block(gate=Gates.RX),
                Block(gate=Gates.RZ),
            )

    class Circuit_6(DeclarativeCircuit):
        @staticmethod
        def structure():
            return (
                Block(gate=Gates.RX),
                Block(gate=Gates.RZ),
                Block(gate=Gates.CRX, topology=Topology.all_to_all),
                Block(gate=Gates.RX),
                Block(gate=Gates.RZ),
            )

    class Circuit_7(DeclarativeCircuit):
        @staticmethod
        def structure():
            return (
                Block(gate=Gates.RX),
                Block(gate=Gates.RZ),
                Block(
                    gate=Gates.CRZ,
                    topology=Topology.bricks,
                ),
                Block(gate=Gates.RX),
                Block(gate=Gates.RZ),
                Block(
                    gate=Gates.CRZ,
                    topology=Topology.bricks,
                    offset=1,
                ),
            )

    class Circuit_8(DeclarativeCircuit):
        @staticmethod
        def structure():
            return (
                Block(gate=Gates.RX),
                Block(gate=Gates.RZ),
                Block(
                    gate=Gates.CRX,
                    topology=Topology.bricks,
                ),
                Block(gate=Gates.RX),
                Block(gate=Gates.RZ),
                Block(
                    gate=Gates.CRX,
                    topology=Topology.bricks,
                    offset=1,
                ),
            )

    class Circuit_9(DeclarativeCircuit):
        # NOTE(review): the CZ gate is given by name ("CZ") rather than as
        # Gates.CZ like the other blocks -- presumably Block resolves strings;
        # confirm.
        @staticmethod
        def structure():
            return (
                Block(gate=Gates.H),
                Block(gate="CZ", topology=Topology.stairs),
                Block(gate=Gates.RX),
            )

    class Circuit_10(DeclarativeCircuit):
        # NOTE(review): "CZ" given by name, see Circuit_9.
        @staticmethod
        def structure():
            return (
                Block(gate=Gates.RY),
                Block(gate="CZ", topology=Topology.stairs, offset=-1, wrap=True),
                Block(gate=Gates.RY),
            )

    class Circuit_13(DeclarativeCircuit):
        @staticmethod
        def structure():
            return (
                Block(gate=Gates.RY),
                Block(
                    gate=Gates.CRZ,
                    topology=Topology.stairs,
                    wrap=True,
                    reverse=True,
                    mirror=False,
                ),
                Block(gate=Gates.RY),
                Block(
                    gate=Gates.CRZ,
                    topology=Topology.stairs,
                    reverse=False,
                    mirror=False,
                    offset=lambda n: n - 1,
                    span=3,
                    wrap=True,
                ),
            )

    class Circuit_14(DeclarativeCircuit):
        @staticmethod
        def structure():
            return (
                Block(gate=Gates.RY),
                Block(
                    gate=Gates.CRX,
                    topology=Topology.stairs,
                    wrap=True,
                    reverse=True,
                    mirror=False,
                ),
                Block(gate=Gates.RY),
                Block(
                    gate=Gates.CRX,
                    topology=Topology.stairs,
                    reverse=False,
                    mirror=False,
                    offset=lambda n: n - 1,
                    span=3,
                    wrap=True,
                ),
            )

    class Circuit_15(DeclarativeCircuit):
        @staticmethod
        def structure():
            return (
                Block(gate=Gates.RY),
                Block(
                    gate=Gates.CX,
                    topology=Topology.stairs,
                    wrap=True,
                    reverse=True,
                    mirror=False,
                ),
                Block(gate=Gates.RY),
                Block(
                    gate=Gates.CX,
                    topology=Topology.stairs,
                    reverse=False,
                    mirror=False,
                    offset=lambda n: n - 1,
                    span=3,
                    wrap=True,
                ),
            )

    class Circuit_16(DeclarativeCircuit):
        @staticmethod
        def structure():
            return (
                Block(gate=Gates.RX),
                Block(gate=Gates.RZ),
                Block(
                    gate=Gates.CRZ,
                    topology=Topology.bricks,
                ),
                Block(
                    gate=Gates.CRZ,
                    topology=Topology.bricks,
                    offset=1,
                ),
            )

    class Circuit_17(DeclarativeCircuit):
        @staticmethod
        def structure():
            return (
                Block(gate=Gates.RX),
                Block(gate=Gates.RZ),
                Block(
                    gate=Gates.CRX,
                    topology=Topology.bricks,
                ),
                Block(
                    gate=Gates.CRX,
                    topology=Topology.bricks,
                    offset=1,
                ),
            )

    class Circuit_18(DeclarativeCircuit):
        @staticmethod
        def structure():
            return (
                Block(gate=Gates.RX),
                Block(gate=Gates.RZ),
                Block(
                    gate=Gates.CRZ,
                    topology=Topology.stairs,
                    wrap=True,
                    mirror=False,
                ),
            )

    class Circuit_19(DeclarativeCircuit):
        @staticmethod
        def structure():
            return (
                Block(gate=Gates.RX),
                Block(gate=Gates.RZ),
                Block(
                    gate=Gates.CRX,
                    topology=Topology.stairs,
                    wrap=True,
                    mirror=False,
                ),
            )

    class Circuit_20(DeclarativeCircuit):
        @staticmethod
        def structure():
            return (
                Block(gate=Gates.RY),
                Block(
                    gate=Gates.CX,
                    topology=Topology.stairs,
                    wrap=True,
                    reverse=True,
                    mirror=False,
                ),
                Block(gate=Gates.RY),
                Block(
                    gate=Gates.CX,
                    topology=Topology.stairs,
                    reverse=False,
                    offset=lambda n: n - 2,
                    span=1,
                    wrap=True,
                ),
            )

    class No_Entangling(DeclarativeCircuit):
        """Single layer of general ``Rot`` rotations without entangling gates."""

        @staticmethod
        def structure():
            return (Block(gate=Gates.Rot),)

    class Hardware_Efficient(DeclarativeCircuit):
        @staticmethod
        def structure():
            return (
                Block(gate=Gates.RY),
                Block(gate=Gates.RZ),
                Block(gate=Gates.RY),
                Block(
                    gate=Gates.CX,
                    topology=Topology.bricks,
                    mirror=False,
                ),
                Block(
                    gate=Gates.CX,
                    topology=Topology.bricks,
                    offset=-1,
                    modulo=True,
                    wrap=True,
                    mirror=False,
                ),
            )

    class Strongly_Entangling(DeclarativeCircuit):
        @staticmethod
        def structure():
            return (
                Block(gate=Gates.Rot),
                Block(
                    gate=Gates.CX,
                    topology=Topology.stairs,
                    wrap=True,
                    reverse=False,
                    mirror=False,
                ),
                Block(gate=Gates.Rot),
                Block(
                    gate=Gates.CX,
                    topology=Topology.stairs,
                    reverse=False,
                    span=lambda n: n // 2,
                    wrap=True,
                    mirror=False,
                ),
            )

Gates

As the structure of the different classes used to realize pulse and unitary gates can be a bit confusing, the following diagram might help:

[Diagram: Gate Structure]

from qml_essentials.gates import Gates

Dynamic accessor for quantum Gates.

Routes calls like Gates.RX(...) to either UnitaryGates or PulseGates depending on the gate_mode keyword (defaults to 'unitary').

During circuit building, the pulse manager can be activated via pulse_manager_context, which slices the global model pulse parameters and passes them to each gate. Model pulse parameters act as element-wise scaling factors on the gate's optimized pulse parameters.

Parameters

gate_mode : str, optional — Determines the backend: 'unitary' for UnitaryGates, 'pulse' for PulseGates. Defaults to 'unitary'.

Examples

Gates.RX(w, wires)
Gates.RX(w, wires, gate_mode="unitary")
Gates.RX(w, wires, gate_mode="pulse")
Gates.RX(w, wires, pulse_params, gate_mode="pulse")

Source code in qml_essentials/gates.py
class Gates(metaclass=GatesMeta):
    """
    Dynamic accessor for quantum Gates.

    Routes calls like `Gates.RX(...)` to either `UnitaryGates` or `PulseGates`
    depending on the `gate_mode` keyword (defaults to 'unitary').

    During circuit building, the pulse manager can be activated via
    `pulse_manager_context`, which slices the global model pulse parameters
    and passes them to each gate. Model pulse parameters act as element-wise
    scaling factors on the gate's optimized pulse parameters.

    Parameters
    ----------
    gate_mode : str, optional
        Determines the backend. 'unitary' for UnitaryGates, 'pulse' for PulseGates.
        Defaults to 'unitary'.

    Examples
    --------
    >>> Gates.RX(w, wires)
    >>> Gates.RX(w, wires, gate_mode="unitary")
    >>> Gates.RX(w, wires, gate_mode="pulse")
    >>> Gates.RX(w, wires, pulse_params, gate_mode="pulse")
    """

    # Pulse-parameter manager installed by ``pulse_manager_context``; None
    # outside such a context. Declared explicitly so that reading it never
    # depends on a ``getattr`` default (GatesMeta, not shown here, may resolve
    # unknown attributes dynamically).
    _pulse_mgr = None

    def __getattr__(self, gate_name):
        """Return a callable that applies ``gate_name`` via ``_inner_getattr``.

        Only invoked for attribute lookups on a ``Gates`` instance that normal
        lookup does not resolve. Positional and keyword arguments are both
        forwarded.
        """
        # Protocol probes such as ``__deepcopy__`` must not become gate handlers.
        if gate_name.startswith("__") and gate_name.endswith("__"):
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {gate_name!r}"
            )

        def handler(*args, **kwargs):
            return self._inner_getattr(gate_name, *args, **kwargs)

        # Expose the gate name so is_rotational/is_entangling can classify it.
        handler.__name__ = gate_name
        return handler

    @staticmethod
    def _inner_getattr(gate_name, *args, **kwargs):
        """
        Resolve ``gate_name`` on the selected backend and apply it.

        Parameters
        ----------
        gate_name : str
            Name of the gate, e.g. ``"RX"``. ``"Barrier"`` is handled
            separately and never routed to a backend.
        *args
            Positional arguments forwarded unchanged to the gate.
        **kwargs
            Keyword arguments. ``gate_mode`` ('unitary' or 'pulse', default
            'unitary') selects the backend and is removed. Every other keyword
            outside the supported set is dropped and logged at debug level.

        Returns
        -------
        object
            Whatever the selected backend gate returns.

        Raises
        ------
        ValueError
            Unknown ``gate_mode``, or a ``pulse_params`` length that does not
            match the gate (checked only when no pulse manager is active).
        TypeError
            ``pulse_params`` of an unsupported container or element type.
        AttributeError
            The selected backend has no gate named ``gate_name``.
        """
        if gate_name == "Barrier":
            # NOTE(review): ``gate_mode`` has not been popped yet and is
            # forwarded to Barrier if given -- confirm Barrier accepts it.
            return Barrier(*args, **kwargs)

        gate_mode = kwargs.pop("gate_mode", "unitary")

        # Backend selection and the keywords the backend may receive
        allowed_args = {
            "w",
            "wires",
            "phi",
            "theta",
            "omega",
            "input_idx",
            "noise_params",
            "random_key",
        }
        if gate_mode == "unitary":
            gate_backend = UnitaryGates
        elif gate_mode == "pulse":
            gate_backend = PulseGates
            allowed_args.add("pulse_params")
        else:
            raise ValueError(
                f"Unknown gate mode: {gate_mode}. Use 'unitary' or 'pulse'."
            )

        # Resolve the gate up front, so an unknown name fails before any pulse
        # parameters are drawn from the active pulse manager.
        gate = getattr(gate_backend, gate_name, None)
        if gate is None:
            # gate_backend is a class: report its own name, not its metaclass's.
            raise AttributeError(
                f"'{gate_backend.__name__}' has no attribute '{gate_name}'"
            )

        unsupported = kwargs.keys() - allowed_args
        if unsupported:
            # TODO: pulse params are always provided?
            log.debug("Unsupported keyword arguments: %s", list(unsupported))

        kwargs = {k: v for k, v in kwargs.items() if k in allowed_args}
        pulse_params = kwargs.get("pulse_params")
        pulse_mgr = getattr(Gates, "_pulse_mgr", None)

        # TODO: rework this part to convert to valid PulseParams earlier
        # Type and length checks on explicitly given pulse parameters
        if pulse_params is not None:
            import numpy as np  # host (NumPy) arrays are accepted as well

            # flatten pulse parameters
            if isinstance(pulse_params, (list, tuple)):
                flat_params = pulse_params
            elif isinstance(pulse_params, jax.core.Tracer):
                # traced values cannot be converted to Python lists
                flat_params = jnp.ravel(pulse_params)
            elif isinstance(pulse_params, (np.ndarray, jnp.ndarray)):
                flat_params = pulse_params.flatten().tolist()
            elif isinstance(pulse_params, PulseParams):
                # extract the params in case a full object is given
                kwargs["pulse_params"] = pulse_params.params
                flat_params = pulse_params.params.flatten().tolist()
            else:
                raise TypeError(f"Unsupported pulse_params type: {type(pulse_params)}")

            # checks elements in flat parameters are real numbers or jax Tracer
            if not all(
                isinstance(x, (numbers.Real, jax.core.Tracer)) for x in flat_params
            ):
                raise TypeError(
                    "All elements in pulse_params must be int or float, "
                    f"got {pulse_params}, type {type(pulse_params)}. "
                )

            # With an active pulse manager the parameters are replaced below,
            # so the length only matters without one.
            if not isinstance(pulse_mgr, PulseParamManager):
                n_params = PulseInformation.gate_by_name(gate_name).size
                if len(flat_params) != n_params:
                    raise ValueError(
                        f"Gate '{gate_name}' expects {n_params} pulse parameters, "
                        f"got {len(flat_params)}"
                    )

        # Pulse slicing + scaling: the manager supplies element-wise scaling
        # factors for the gate's base pulse parameters.
        if gate_mode == "pulse" and isinstance(pulse_mgr, PulseParamManager):
            gate_info = PulseInformation.gate_by_name(gate_name)
            scalers = pulse_mgr.get(gate_info.size)
            kwargs["pulse_params"] = gate_info.params * scalers

        return gate(*args, **kwargs)

    @staticmethod
    @contextmanager
    def pulse_manager_context(pulse_params: jnp.ndarray):
        """Temporarily set the global pulse manager for circuit building.

        The previously installed manager (None outside any context) is
        restored on exit, so nested contexts do not clobber the outer one.
        """
        previous = Gates._pulse_mgr
        Gates._pulse_mgr = PulseParamManager(pulse_params)
        try:
            yield
        finally:
            Gates._pulse_mgr = previous

    @staticmethod
    def parse_gates(
        gates: Union[str, Callable, List[Union[str, Callable]]],
        set_of_gates=None,
    ):
        """
        Normalize a gate specification into a list of callables.

        Parameters
        ----------
        gates : str, callable, list/tuple of str or callable, or None
            A gate name (looked up on ``set_of_gates``), a callable, a list or
            tuple mixing both, or None. None yields a single no-op callable.
        set_of_gates : object, optional
            Namespace that gate names are looked up on. Defaults to ``Gates``.

        Returns
        -------
        list
            The resolved callables, in input order.

        Raises
        ------
        ValueError
            If an entry is neither a string nor a callable.
        """
        set_of_gates = set_of_gates or Gates

        def _resolve(entry):
            # Resolve one list entry: names are looked up, callables kept as-is.
            if isinstance(entry, str):
                return getattr(set_of_gates, entry)
            if callable(entry):
                return entry
            raise ValueError(
                f"Operation {entry} is not a valid gate or callable. "
                f"Got {type(entry)}"
            )

        if isinstance(gates, str):
            return [getattr(set_of_gates, gates)]
        if isinstance(gates, (list, tuple)):
            return [_resolve(entry) for entry in gates]
        if callable(gates):
            return [gates]
        if gates is None:
            return [lambda *args, **kwargs: None]
        raise ValueError(
            f"Operation {gates} is not a valid gate or callable or list of both."
        )

    @staticmethod
    def is_rotational(gate):
        """Return True if ``gate.__name__`` names a (controlled) rotation gate."""
        return gate.__name__ in ("RX", "RY", "RZ", "Rot", "CRX", "CRY", "CRZ")

    @staticmethod
    def is_entangling(gate):
        """Return True if ``gate.__name__`` names a two-qubit controlled gate."""
        return gate.__name__ in ("CX", "CY", "CZ", "CRX", "CRY", "CRZ")

pulse_manager_context(pulse_params) (staticmethod)

Temporarily set the global pulse manager for circuit building.

Source code in qml_essentials/gates.py
@staticmethod
@contextmanager
def pulse_manager_context(pulse_params: jnp.ndarray):
    """Temporarily set the global pulse manager for circuit building.

    The previously installed manager (None outside any context) is restored
    on exit, so nested contexts do not clobber the outer one.
    """
    # Read the class's own namespace so the lookup cannot fall through to any
    # dynamic attribute resolution on the Gates metaclass.
    previous = vars(Gates).get("_pulse_mgr")
    Gates._pulse_mgr = PulseParamManager(pulse_params)
    try:
        yield
    finally:
        Gates._pulse_mgr = previous
from qml_essentials.gates import UnitaryGates

Unitary Gates

Collection of unitary quantum gates with optional noise simulation.

Source code in qml_essentials/unitary.py
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
class UnitaryGates:
    """Collection of unitary quantum gates with optional noise simulation.

    Each gate wrapper applies the corresponding ``op`` gate and then calls
    :meth:`Noise` on the same wires, so the noise channels configured in
    ``noise_params`` follow every gate. Parameterised gates first pass their
    angle(s) through :meth:`GateError` to add Gaussian angle noise.
    """

    # Controls how GateError samples noise. True: split ``random_key`` on every
    # call and draw one sample per element of the angle array. False: use a
    # fixed key and a single scalar sample (see GateError).
    batch_gate_error = True

    @staticmethod
    def NQubitDepolarizingChannel(p: float, wires: List[int]) -> op.QubitChannel:
        """
        Generate Kraus operators for n-qubit depolarizing channel.

        The n-qubit depolarizing channel models uniform depolarizing noise
        acting on n qubits simultaneously, useful for simulating realistic
        multi-qubit noise affecting entangling gates.

        The Kraus set is ``sqrt(1 - p * (4**n - 1) / 4**n) * I`` together with
        ``sqrt(p / 4**n) * P`` for each of the ``4**n - 1`` non-identity
        n-qubit Pauli strings ``P``. Together these realise the map
        ``rho -> (1 - p) * rho + p * I / 2**n``.

        Args:
            p (float): Total probability of depolarizing error (0 ≤ p ≤ 1).
            wires (List[int]): Qubit indices on which the channel acts.
                Must contain at least 2 qubits.

        Returns:
            op.QubitChannel: QubitChannel with Kraus operators
                representing the depolarizing noise channel.

        Raises:
            ValueError: If p is not in [0, 1] or if fewer than 2 qubits provided.
        """

        def n_qubit_depolarizing_kraus(p: float, n: int) -> List[jnp.ndarray]:
            if not (0.0 <= p <= 1.0):
                raise ValueError(f"Probability p must be between 0 and 1, got {p}")
            if n < 2:
                raise ValueError(f"Number of qubits must be >= 2, got {n}")

            paulis = [
                jnp.eye(2),
                op.PauliX._matrix,
                op.PauliY._matrix,
                op.PauliZ._matrix,
            ]
            num_paulis = 4**n
            identity_coeff = jnp.sqrt(1 - p * (num_paulis - 1) / num_paulis)
            pauli_coeff = jnp.sqrt(p / num_paulis)

            # K0: the identity term carries all the "no error" weight.
            kraus_ops = [identity_coeff * jnp.eye(2**n)]
            for indices in itertools.product(range(4), repeat=n):
                if not any(indices):
                    # The all-identity string is already represented by K0.
                    continue
                P = jnp.eye(1)
                for idx in indices:
                    P = jnp.kron(P, paulis[idx])
                kraus_ops.append(pauli_coeff * P)

            return kraus_ops

        return op.QubitChannel(n_qubit_depolarizing_kraus(p, len(wires)), wires=wires)

    @staticmethod
    def Noise(
        wires: Union[int, List[int]], noise_params: Optional[Dict[str, float]] = None
    ) -> None:
        """
        Apply noise channels to specified qubits.

        Applies various single-qubit and multi-qubit noise channels based on
        the provided noise parameters dictionary. For each wire, the
        single-qubit channels are applied in the order BitFlip, PhaseFlip,
        Depolarizing. The multi-qubit depolarizing channel then follows on all
        wires jointly.

        Args:
            wires (Union[int, List[int]]): Qubit index or list of qubit indices
                to apply noise to.
            noise_params (Optional[Dict[str, float]]): Dictionary of noise
                parameters. Supported keys:
                - "BitFlip" (float): Bit flip error probability
                - "PhaseFlip" (float): Phase flip error probability
                - "Depolarizing" (float): Single-qubit depolarizing probability
                - "MultiQubitDepolarizing" (float): Multi-qubit depolarizing
                  probability (applies if len(wires) > 1)
                All parameters default to 0.0 if not provided. A channel is
                only applied when its probability is > 0.

        Returns:
            None: Noise channels are applied in-place to the circuit.
        """
        if noise_params is None:
            return

        if isinstance(wires, int):
            wires = [wires]  # single qubit gate

        # The probabilities do not depend on the wire, so read them once.
        bf = noise_params.get("BitFlip", 0.0)
        pf = noise_params.get("PhaseFlip", 0.0)
        dp = noise_params.get("Depolarizing", 0.0)

        # noise on single qubits
        for wire in wires:
            if bf > 0:
                op.BitFlip(bf, wires=wire)
            if pf > 0:
                op.PhaseFlip(pf, wires=wire)
            if dp > 0:
                op.DepolarizingChannel(dp, wires=wire)

        # correlated noise across all wires of a multi-qubit gate
        if len(wires) > 1:
            mqd = noise_params.get("MultiQubitDepolarizing", 0.0)
            if mqd > 0:
                UnitaryGates.NQubitDepolarizingChannel(mqd, wires)

    @staticmethod
    def GateError(
        w: Union[float, jnp.ndarray, List[float]],
        noise_params: Optional[Dict[str, float]] = None,
        random_key: Optional[jax.random.PRNGKey] = None,
    ) -> Tuple[jnp.ndarray, jax.random.PRNGKey]:
        """
        Apply gate error noise to rotation angle(s).

        Adds zero-mean Gaussian noise with standard deviation
        ``noise_params["GateError"]`` to the rotation angle(s) to simulate
        imperfect gate implementations. If ``noise_params`` is None, or its
        "GateError" entry is missing or None, ``w`` and ``random_key`` are
        returned unchanged.

        When noise is applied, ``w`` is first converted with ``jnp.asarray``.
        Lists and NumPy arrays are therefore supported, and the caller's
        array is never modified in place.

        With ``UnitaryGates.batch_gate_error`` True, ``random_key`` is split
        and one sample is drawn per element of ``w``. With it False, a fixed
        key and a single scalar sample are used, and ``random_key`` is
        returned unchanged.

        Args:
            w (Union[float, jnp.ndarray, List[float]]): Rotation angle(s) in radians.
            noise_params (Optional[Dict[str, float]]): Dictionary with optional
                "GateError" key specifying standard deviation of Gaussian noise.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for
                stochastic noise generation.

        Returns:
            Tuple[jnp.ndarray, jax.random.PRNGKey]: Tuple containing:
                - Modified rotation angle(s) with applied noise
                - Updated JAX random key

        Raises:
            AssertionError: If noise_params contains "GateError" but random_key is None.
        """
        if noise_params is None or noise_params.get("GateError", None) is None:
            return w, random_key

        if random_key is None:
            # Raised explicitly (not via ``assert``) so the check survives
            # ``python -O``; the documented exception type is kept.
            raise AssertionError("A random_key must be provided when using GateError")

        if UnitaryGates.batch_gate_error:
            random_key, sub_key = safe_random_split(random_key)
        else:
            # Use a fixed key so that every batch element (under vmap)
            # draws the same noise value, effectively broadcasting.
            # NOTE(review): the fixed key also yields the same value for every
            # gate and every call — confirm this is intended.
            sub_key = jax.random.key(0)

        # Rebind instead of ``+=``: avoids mutating a caller's NumPy array in
        # place and makes list inputs work.
        w = jnp.asarray(w)
        noise_shape = w.shape if UnitaryGates.batch_gate_error else ()
        w = w + noise_params["GateError"] * jax.random.normal(sub_key, noise_shape)
        return w, random_key

    @staticmethod
    def Rot(
        phi: Union[float, jnp.ndarray, List[float]],
        theta: Union[float, jnp.ndarray, List[float]],
        omega: Union[float, jnp.ndarray, List[float]],
        wires: Union[int, List[int]],
        noise_params: Optional[Dict[str, float]] = None,
        random_key: Optional[jax.random.PRNGKey] = None,
        input_idx: int = -1,
    ) -> None:
        """
        Apply general rotation gate with optional noise.

        Applies a three-angle rotation Rot(phi, theta, omega) with optional
        gate errors and noise channels.

        Args:
            phi (Union[float, jnp.ndarray, List[float]]): First rotation angle.
            theta (Union[float, jnp.ndarray, List[float]]): Second rotation angle.
            omega (Union[float, jnp.ndarray, List[float]]): Third rotation angle.
            wires (Union[int, List[int]]): Qubit index or indices to apply rotation to.
            noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
                Supports BitFlip, PhaseFlip, Depolarizing, and GateError.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for noise.
            input_idx (int): Flag for the tape to track inputs

        Returns:
            None: Gate and noise are applied in-place to the circuit.
        """
        # GateError is a no-op unless noise_params has a non-None "GateError".
        # The key is threaded through so each angle receives an independent draw.
        phi, random_key = UnitaryGates.GateError(phi, noise_params, random_key)
        theta, random_key = UnitaryGates.GateError(theta, noise_params, random_key)
        omega, random_key = UnitaryGates.GateError(omega, noise_params, random_key)
        # Forward the caller's input_idx, as every other rotation gate does.
        op.Rot(phi, theta, omega, wires=wires, input_idx=input_idx)
        UnitaryGates.Noise(wires, noise_params)

    @staticmethod
    def PauliRot(
        theta: float,
        pauli: str,
        wires: Union[int, List[int]],
        noise_params: Optional[Dict[str, float]] = None,
        random_key: Optional[jax.random.PRNGKey] = None,
        input_idx: int = -1,
    ) -> None:
        """
        Apply a Pauli-string rotation with optional noise.

        Applies ``op.PauliRot(theta, pauli)`` on ``wires`` with optional gate
        error on ``theta`` and noise channels afterwards.

        Args:
            theta (Union[float, jnp.ndarray, List[float]]): Rotation angle.
            pauli (str): Pauli string naming the rotation generator; passed
                unchanged to ``op.PauliRot``.
            wires (Union[int, List[int]]): Qubit index or indices to apply rotation to.
            noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
                Supports BitFlip, PhaseFlip, Depolarizing, and GateError.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for noise.
            input_idx (int): Flag for the tape to track inputs

        Returns:
            None: Gate and noise are applied in-place to the circuit.
        """
        theta, random_key = UnitaryGates.GateError(theta, noise_params, random_key)
        op.PauliRot(theta, pauli, wires=wires, input_idx=input_idx)
        UnitaryGates.Noise(wires, noise_params)

    @staticmethod
    def RX(
        w: Union[float, jnp.ndarray, List[float]],
        wires: Union[int, List[int]],
        noise_params: Optional[Dict[str, float]] = None,
        random_key: Optional[jax.random.PRNGKey] = None,
        input_idx: int = -1,
    ) -> None:
        """
        Apply X-axis rotation with optional noise.

        Args:
            w (Union[float, jnp.ndarray, List[float]]): Rotation angle.
            wires (Union[int, List[int]]): Qubit index or indices.
            noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for noise.
            input_idx (int): Flag for the tape to track inputs

        Returns:
            None: Gate and noise are applied in-place to the circuit.
        """
        w, random_key = UnitaryGates.GateError(w, noise_params, random_key)
        op.RX(w, wires=wires, input_idx=input_idx)
        UnitaryGates.Noise(wires, noise_params)

    @staticmethod
    def RY(
        w: Union[float, jnp.ndarray, List[float]],
        wires: Union[int, List[int]],
        noise_params: Optional[Dict[str, float]] = None,
        random_key: Optional[jax.random.PRNGKey] = None,
        input_idx: int = -1,
    ) -> None:
        """
        Apply Y-axis rotation with optional noise.

        Args:
            w (Union[float, jnp.ndarray, List[float]]): Rotation angle.
            wires (Union[int, List[int]]): Qubit index or indices.
            noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for noise.
            input_idx (int): Flag for the tape to track inputs

        Returns:
            None: Gate and noise are applied in-place to the circuit.
        """
        w, random_key = UnitaryGates.GateError(w, noise_params, random_key)
        op.RY(w, wires=wires, input_idx=input_idx)
        UnitaryGates.Noise(wires, noise_params)

    @staticmethod
    def RZ(
        w: Union[float, jnp.ndarray, List[float]],
        wires: Union[int, List[int]],
        noise_params: Optional[Dict[str, float]] = None,
        random_key: Optional[jax.random.PRNGKey] = None,
        input_idx: int = -1,
    ) -> None:
        """
        Apply Z-axis rotation with optional noise.

        Args:
            w (Union[float, jnp.ndarray, List[float]]): Rotation angle.
            wires (Union[int, List[int]]): Qubit index or indices.
            noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for noise.
            input_idx (int): Flag for the tape to track inputs

        Returns:
            None: Gate and noise are applied in-place to the circuit.
        """
        w, random_key = UnitaryGates.GateError(w, noise_params, random_key)
        op.RZ(w, wires=wires, input_idx=input_idx)
        UnitaryGates.Noise(wires, noise_params)

    @staticmethod
    def CRX(
        w: Union[float, jnp.ndarray, List[float]],
        wires: Union[int, List[int]],
        noise_params: Optional[Dict[str, float]] = None,
        random_key: Optional[jax.random.PRNGKey] = None,
        input_idx: int = -1,
    ) -> None:
        """
        Apply controlled X-rotation with optional noise.

        Args:
            w (Union[float, jnp.ndarray, List[float]]): Rotation angle.
            wires (Union[int, List[int]]): Control and target qubit indices.
            noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for noise.
            input_idx (int): Flag for the tape to track inputs

        Returns:
            None: Gate and noise are applied in-place to the circuit.
        """
        w, random_key = UnitaryGates.GateError(w, noise_params, random_key)
        op.CRX(w, wires=wires, input_idx=input_idx)
        UnitaryGates.Noise(wires, noise_params)

    @staticmethod
    def CRY(
        w: Union[float, jnp.ndarray, List[float]],
        wires: Union[int, List[int]],
        noise_params: Optional[Dict[str, float]] = None,
        random_key: Optional[jax.random.PRNGKey] = None,
        input_idx: int = -1,
    ) -> None:
        """
        Apply controlled Y-rotation with optional noise.

        Args:
            w (Union[float, jnp.ndarray, List[float]]): Rotation angle.
            wires (Union[int, List[int]]): Control and target qubit indices.
            noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for noise.
            input_idx (int): Flag for the tape to track inputs

        Returns:
            None: Gate and noise are applied in-place to the circuit.
        """
        w, random_key = UnitaryGates.GateError(w, noise_params, random_key)
        op.CRY(w, wires=wires, input_idx=input_idx)
        UnitaryGates.Noise(wires, noise_params)

    @staticmethod
    def CRZ(
        w: Union[float, jnp.ndarray, List[float]],
        wires: Union[int, List[int]],
        noise_params: Optional[Dict[str, float]] = None,
        random_key: Optional[jax.random.PRNGKey] = None,
        input_idx: int = -1,
    ) -> None:
        """
        Apply controlled Z-rotation with optional noise.

        Args:
            w (Union[float, jnp.ndarray, List[float]]): Rotation angle.
            wires (Union[int, List[int]]): Control and target qubit indices.
            noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for noise.
            input_idx (int): Flag for the tape to track inputs

        Returns:
            None: Gate and noise are applied in-place to the circuit.
        """
        w, random_key = UnitaryGates.GateError(w, noise_params, random_key)
        op.CRZ(w, wires=wires, input_idx=input_idx)
        UnitaryGates.Noise(wires, noise_params)

    @staticmethod
    def CX(
        wires: Union[int, List[int]],
        noise_params: Optional[Dict[str, float]] = None,
        random_key: Optional[jax.random.PRNGKey] = None,
    ) -> None:
        """
        Apply controlled-NOT (CNOT) gate with optional noise.

        Args:
            wires (Union[int, List[int]]): Control and target qubit indices.
            noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for compatibility
                (not used in this gate).

        Returns:
            None: Gate and noise are applied in-place to the circuit.
        """
        op.CX(wires=wires)
        UnitaryGates.Noise(wires, noise_params)

    @staticmethod
    def CY(
        wires: Union[int, List[int]],
        noise_params: Optional[Dict[str, float]] = None,
        random_key: Optional[jax.random.PRNGKey] = None,
    ) -> None:
        """
        Apply controlled-Y gate with optional noise.

        Args:
            wires (Union[int, List[int]]): Control and target qubit indices.
            noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for compatibility
                (not used in this gate).

        Returns:
            None: Gate and noise are applied in-place to the circuit.
        """
        op.CY(wires=wires)
        UnitaryGates.Noise(wires, noise_params)

    @staticmethod
    def CZ(
        wires: Union[int, List[int]],
        noise_params: Optional[Dict[str, float]] = None,
        random_key: Optional[jax.random.PRNGKey] = None,
    ) -> None:
        """
        Apply controlled-Z gate with optional noise.

        Args:
            wires (Union[int, List[int]]): Control and target qubit indices.
            noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for compatibility
                (not used in this gate).

        Returns:
            None: Gate and noise are applied in-place to the circuit.
        """
        op.CZ(wires=wires)
        UnitaryGates.Noise(wires, noise_params)

    @staticmethod
    def H(
        wires: Union[int, List[int]],
        noise_params: Optional[Dict[str, float]] = None,
        random_key: Optional[jax.random.PRNGKey] = None,
    ) -> None:
        """
        Apply Hadamard gate with optional noise.

        Args:
            wires (Union[int, List[int]]): Qubit index or indices.
            noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for compatibility
                (not used in this gate).

        Returns:
            None: Gate and noise are applied in-place to the circuit.
        """
        op.H(wires=wires)
        UnitaryGates.Noise(wires, noise_params)

CRX(w, wires, noise_params=None, random_key=None, input_idx=-1) staticmethod #

Apply controlled X-rotation with optional noise.

Parameters:

Name Type Description Default
w Union[float, ndarray, List[float]]

Rotation angle.

required
wires Union[int, List[int]]

Control and target qubit indices.

required
noise_params Optional[Dict[str, float]]

Noise parameters dictionary.

None
random_key Optional[PRNGKey]

JAX random key for noise.

None
input_idx int

Flag for the tape to track inputs

-1

Returns:

Name Type Description
None None

Gate and noise are applied in-place to the circuit.

Source code in qml_essentials/unitary.py
@staticmethod
def CRX(
    w: Union[float, jnp.ndarray, List[float]],
    wires: Union[int, List[int]],
    noise_params: Optional[Dict[str, float]] = None,
    random_key: Optional[jax.random.PRNGKey] = None,
    input_idx: int = -1,
) -> None:
    """
    Place a controlled rotation about X on ``wires``, then apply noise.

    The angle is first passed through ``UnitaryGates.GateError``. The gate is
    then applied, and ``UnitaryGates.Noise`` adds the channels configured in
    ``noise_params``.

    Args:
        w (Union[float, jnp.ndarray, List[float]]): Rotation angle.
        wires (Union[int, List[int]]): Control and target qubit indices.
        noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
        random_key (Optional[jax.random.PRNGKey]): JAX random key for noise.
        input_idx (int): Flag for the tape to track inputs

    Returns:
        None: Gate and noise are applied in-place to the circuit.
    """
    noisy_angle = UnitaryGates.GateError(w, noise_params, random_key)[0]
    op.CRX(noisy_angle, wires=wires, input_idx=input_idx)
    UnitaryGates.Noise(wires=wires, noise_params=noise_params)

CRY(w, wires, noise_params=None, random_key=None, input_idx=-1) staticmethod #

Apply controlled Y-rotation with optional noise.

Parameters:

Name Type Description Default
w Union[float, ndarray, List[float]]

Rotation angle.

required
wires Union[int, List[int]]

Control and target qubit indices.

required
noise_params Optional[Dict[str, float]]

Noise parameters dictionary.

None
random_key Optional[PRNGKey]

JAX random key for noise.

None
input_idx int

Flag for the tape to track inputs

-1

Returns:

Name Type Description
None None

Gate and noise are applied in-place to the circuit.

Source code in qml_essentials/unitary.py
@staticmethod
def CRY(
    w: Union[float, jnp.ndarray, List[float]],
    wires: Union[int, List[int]],
    noise_params: Optional[Dict[str, float]] = None,
    random_key: Optional[jax.random.PRNGKey] = None,
    input_idx: int = -1,
) -> None:
    """
    Place a controlled rotation about Y on ``wires``, then apply noise.

    The angle is first passed through ``UnitaryGates.GateError``. The gate is
    then applied, and ``UnitaryGates.Noise`` adds the channels configured in
    ``noise_params``.

    Args:
        w (Union[float, jnp.ndarray, List[float]]): Rotation angle.
        wires (Union[int, List[int]]): Control and target qubit indices.
        noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
        random_key (Optional[jax.random.PRNGKey]): JAX random key for noise.
        input_idx (int): Flag for the tape to track inputs

    Returns:
        None: Gate and noise are applied in-place to the circuit.
    """
    noisy_angle = UnitaryGates.GateError(w, noise_params, random_key)[0]
    op.CRY(noisy_angle, wires=wires, input_idx=input_idx)
    UnitaryGates.Noise(wires=wires, noise_params=noise_params)

CRZ(w, wires, noise_params=None, random_key=None, input_idx=-1) staticmethod #

Apply controlled Z-rotation with optional noise.

Parameters:

Name Type Description Default
w Union[float, ndarray, List[float]]

Rotation angle.

required
wires Union[int, List[int]]

Control and target qubit indices.

required
noise_params Optional[Dict[str, float]]

Noise parameters dictionary.

None
random_key Optional[PRNGKey]

JAX random key for noise.

None
input_idx int

Flag for the tape to track inputs

-1

Returns:

Name Type Description
None None

Gate and noise are applied in-place to the circuit.

Source code in qml_essentials/unitary.py
@staticmethod
def CRZ(
    w: Union[float, jnp.ndarray, List[float]],
    wires: Union[int, List[int]],
    noise_params: Optional[Dict[str, float]] = None,
    random_key: Optional[jax.random.PRNGKey] = None,
    input_idx: int = -1,
) -> None:
    """
    Place a controlled rotation about Z on ``wires``, then apply noise.

    The angle is first passed through ``UnitaryGates.GateError``. The gate is
    then applied, and ``UnitaryGates.Noise`` adds the channels configured in
    ``noise_params``.

    Args:
        w (Union[float, jnp.ndarray, List[float]]): Rotation angle.
        wires (Union[int, List[int]]): Control and target qubit indices.
        noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
        random_key (Optional[jax.random.PRNGKey]): JAX random key for noise.
        input_idx (int): Flag for the tape to track inputs

    Returns:
        None: Gate and noise are applied in-place to the circuit.
    """
    noisy_angle = UnitaryGates.GateError(w, noise_params, random_key)[0]
    op.CRZ(noisy_angle, wires=wires, input_idx=input_idx)
    UnitaryGates.Noise(wires=wires, noise_params=noise_params)

CX(wires, noise_params=None, random_key=None) staticmethod #

Apply controlled-NOT (CNOT) gate with optional noise.

Parameters:

Name Type Description Default
wires Union[int, List[int]]

Control and target qubit indices.

required
noise_params Optional[Dict[str, float]]

Noise parameters dictionary.

None
random_key Optional[PRNGKey]

JAX random key for compatibility (not used in this gate).

None

Returns:

Name Type Description
None None

Gate and noise are applied in-place to the circuit.

Source code in qml_essentials/unitary.py
@staticmethod
def CX(
    wires: Union[int, List[int]],
    noise_params: Optional[Dict[str, float]] = None,
    random_key: Optional[jax.random.PRNGKey] = None,
) -> None:
    """
    Place a CNOT on ``wires`` and follow it with any configured noise.

    Args:
        wires (Union[int, List[int]]): Control and target qubit indices.
        noise_params (Optional[Dict[str, float]]): Noise parameters dictionary,
            forwarded to ``UnitaryGates.Noise``.
        random_key (Optional[jax.random.PRNGKey]): Accepted so the signature
            matches the parameterised gates; ignored here.

    Returns:
        None: Gate and noise are applied in-place to the circuit.
    """
    op.CX(wires=wires)
    UnitaryGates.Noise(wires=wires, noise_params=noise_params)

CY(wires, noise_params=None, random_key=None) staticmethod #

Apply controlled-Y gate with optional noise.

Parameters:

Name Type Description Default
wires Union[int, List[int]]

Control and target qubit indices.

required
noise_params Optional[Dict[str, float]]

Noise parameters dictionary.

None
random_key Optional[PRNGKey]

JAX random key for compatibility (not used in this gate).

None

Returns:

Name Type Description
None None

Gate and noise are applied in-place to the circuit.

Source code in qml_essentials/unitary.py
@staticmethod
def CY(
    wires: Union[int, List[int]],
    noise_params: Optional[Dict[str, float]] = None,
    random_key: Optional[jax.random.PRNGKey] = None,
) -> None:
    """
    Place a controlled-Y on ``wires`` and follow it with any configured noise.

    Args:
        wires (Union[int, List[int]]): Control and target qubit indices.
        noise_params (Optional[Dict[str, float]]): Noise parameters dictionary,
            forwarded to ``UnitaryGates.Noise``.
        random_key (Optional[jax.random.PRNGKey]): Accepted so the signature
            matches the parameterised gates; ignored here.

    Returns:
        None: Gate and noise are applied in-place to the circuit.
    """
    op.CY(wires=wires)
    UnitaryGates.Noise(wires=wires, noise_params=noise_params)

CZ(wires, noise_params=None, random_key=None) staticmethod #

Apply controlled-Z gate with optional noise.

Parameters:

Name Type Description Default
wires Union[int, List[int]]

Control and target qubit indices.

required
noise_params Optional[Dict[str, float]]

Noise parameters dictionary.

None
random_key Optional[PRNGKey]

JAX random key for compatibility (not used in this gate).

None

Returns:

Name Type Description
None None

Gate and noise are applied in-place to the circuit.

Source code in qml_essentials/unitary.py
@staticmethod
def CZ(
    wires: Union[int, List[int]],
    noise_params: Optional[Dict[str, float]] = None,
    random_key: Optional[jax.random.PRNGKey] = None,
) -> None:
    """
    Place a controlled-Z on ``wires`` and follow it with any configured noise.

    Args:
        wires (Union[int, List[int]]): Control and target qubit indices.
        noise_params (Optional[Dict[str, float]]): Noise parameters dictionary,
            forwarded to ``UnitaryGates.Noise``.
        random_key (Optional[jax.random.PRNGKey]): Accepted so the signature
            matches the parameterised gates; ignored here.

    Returns:
        None: Gate and noise are applied in-place to the circuit.
    """
    op.CZ(wires=wires)
    UnitaryGates.Noise(wires=wires, noise_params=noise_params)

GateError(w, noise_params=None, random_key=None) staticmethod #

Apply gate error noise to rotation angle(s).

Adds Gaussian noise to gate rotation angles to simulate imperfect gate implementations.

Parameters:

Name Type Description Default
w Union[float, ndarray, List[float]]

Rotation angle(s) in radians.

required
noise_params Optional[Dict[str, float]]

Dictionary with optional "GateError" key specifying standard deviation of Gaussian noise.

None
random_key Optional[PRNGKey]

JAX random key for stochastic noise generation.

None

Returns:

Type Description
Tuple[ndarray, PRNGKey]

Tuple[jnp.ndarray, jax.random.PRNGKey]: A tuple containing (1) the rotation angle(s) with noise applied and (2) the updated JAX random key.

Raises:

Type Description
AssertionError

If noise_params contains "GateError" but random_key is None.

Source code in qml_essentials/unitary.py
@staticmethod
def GateError(
    w: Union[float, jnp.ndarray, List[float]],
    noise_params: Optional[Dict[str, float]] = None,
    random_key: Optional[jax.random.PRNGKey] = None,
) -> Tuple[jnp.ndarray, jax.random.PRNGKey]:
    """
    Apply gate error noise to rotation angle(s).

    Adds zero-mean Gaussian noise with standard deviation
    ``noise_params["GateError"]`` to the rotation angle(s) to simulate
    imperfect gate implementations. If ``noise_params`` is None, or its
    "GateError" entry is missing or None, ``w`` and ``random_key`` are
    returned unchanged.

    When noise is applied, ``w`` is first converted with ``jnp.asarray``.
    Lists and NumPy arrays are therefore supported, and the caller's array
    is never modified in place.

    Args:
        w (Union[float, jnp.ndarray, List[float]]): Rotation angle(s) in radians.
        noise_params (Optional[Dict[str, float]]): Dictionary with optional
            "GateError" key specifying standard deviation of Gaussian noise.
        random_key (Optional[jax.random.PRNGKey]): JAX random key for
            stochastic noise generation.

    Returns:
        Tuple[jnp.ndarray, jax.random.PRNGKey]: Tuple containing:
            - Modified rotation angle(s) with applied noise
            - Updated JAX random key

    Raises:
        AssertionError: If noise_params contains "GateError" but random_key is None.
    """
    if noise_params is None or noise_params.get("GateError", None) is None:
        return w, random_key

    if random_key is None:
        # Raised explicitly (not via ``assert``) so the check survives
        # ``python -O``; the documented exception type is kept.
        raise AssertionError("A random_key must be provided when using GateError")

    if UnitaryGates.batch_gate_error:
        random_key, sub_key = safe_random_split(random_key)
    else:
        # Use a fixed key so that every batch element (under vmap)
        # draws the same noise value, effectively broadcasting.
        # NOTE(review): the fixed key also yields the same value for every
        # gate and every call — confirm this is intended.
        sub_key = jax.random.key(0)

    # Rebind instead of ``+=``: avoids mutating a caller's NumPy array in
    # place and makes list inputs work.
    w = jnp.asarray(w)
    noise_shape = w.shape if UnitaryGates.batch_gate_error else ()
    w = w + noise_params["GateError"] * jax.random.normal(sub_key, noise_shape)
    return w, random_key

H(wires, noise_params=None, random_key=None) staticmethod #

Apply Hadamard gate with optional noise.

Parameters:

Name Type Description Default
wires Union[int, List[int]]

Qubit index or indices.

required
noise_params Optional[Dict[str, float]]

Noise parameters dictionary.

None
random_key Optional[PRNGKey]

JAX random key for compatibility (not used in this gate).

None

Returns:

Name Type Description
None None

Gate and noise are applied in-place to the circuit.

Source code in qml_essentials/unitary.py
@staticmethod
def H(
    wires: Union[int, List[int]],
    noise_params: Optional[Dict[str, float]] = None,
    random_key: Optional[jax.random.PRNGKey] = None,
) -> None:
    """
    Place a Hadamard on ``wires`` and follow it with any configured noise.

    Args:
        wires (Union[int, List[int]]): Qubit index or indices.
        noise_params (Optional[Dict[str, float]]): Noise parameters dictionary,
            forwarded to ``UnitaryGates.Noise``.
        random_key (Optional[jax.random.PRNGKey]): Accepted so the signature
            matches the parameterised gates; ignored here.

    Returns:
        None: Gate and noise are applied in-place to the circuit.
    """
    op.H(wires=wires)
    UnitaryGates.Noise(wires=wires, noise_params=noise_params)

NQubitDepolarizingChannel(p, wires) staticmethod #

Generate Kraus operators for n-qubit depolarizing channel.

The n-qubit depolarizing channel models uniform depolarizing noise acting on n qubits simultaneously, useful for simulating realistic multi-qubit noise affecting entangling gates.

Parameters:

Name Type Description Default
p float

Total probability of depolarizing error (0 ≤ p ≤ 1).

required
wires List[int]

Qubit indices on which the channel acts. Must contain at least 2 qubits.

required

Returns:

Type Description
QubitChannel

op.QubitChannel: QubitChannel with Kraus operators representing the depolarizing noise channel.

Raises:

Type Description
ValueError

If p is not in [0, 1] or if fewer than 2 qubits provided.

Source code in qml_essentials/unitary.py
@staticmethod
def NQubitDepolarizingChannel(p: float, wires: List[int]) -> op.QubitChannel:
    """
    Build an n-qubit depolarizing channel acting jointly on ``wires``.

    The Kraus set is ``sqrt(1 - p * (4**n - 1) / 4**n) * I`` followed by
    ``sqrt(p / 4**n) * P`` for every non-identity n-qubit Pauli string ``P``,
    enumerated in ``itertools.product`` order. It models uniform
    depolarizing noise on several qubits at once, e.g. after an entangling
    gate.

    Args:
        p (float): Total probability of depolarizing error (0 ≤ p ≤ 1).
        wires (List[int]): Qubit indices on which the channel acts.
            Must contain at least 2 qubits.

    Returns:
        op.QubitChannel: QubitChannel with Kraus operators
            representing the depolarizing noise channel.

    Raises:
        ValueError: If p is not in [0, 1] or if fewer than 2 qubits provided.
    """

    def _kraus_set(prob: float, num_qubits: int) -> List[jnp.ndarray]:
        if not (0.0 <= prob <= 1.0):
            raise ValueError(f"Probability p must be between 0 and 1, got {prob}")
        if num_qubits < 2:
            raise ValueError(f"Number of qubits must be >= 2, got {num_qubits}")

        single_qubit_paulis = (
            jnp.eye(2),
            op.PauliX._matrix,
            op.PauliY._matrix,
            op.PauliZ._matrix,
        )
        n_strings = 4**num_qubits

        operators = [
            jnp.sqrt(1 - prob * (n_strings - 1) / n_strings) * jnp.eye(2**num_qubits)
        ]
        for pauli_word in itertools.product(range(4), repeat=num_qubits):
            if not any(pauli_word):
                # all-identity word: already covered by the first operator
                continue
            tensor = jnp.eye(1)
            for letter in pauli_word:
                tensor = jnp.kron(tensor, single_qubit_paulis[letter])
            operators.append(jnp.sqrt(prob / n_strings) * tensor)
        return operators

    kraus = _kraus_set(p, len(wires))
    return op.QubitChannel(kraus, wires=wires)

Noise(wires, noise_params=None) staticmethod #

Apply noise channels to specified qubits.

Applies various single-qubit and multi-qubit noise channels based on the provided noise parameters dictionary.

Parameters:

Name Type Description Default
wires Union[int, List[int]]

Qubit index or list of qubit indices to apply noise to.

required
noise_params Optional[Dict[str, float]]

Dictionary of noise parameters. Supported keys: "BitFlip" (float, bit-flip error probability); "PhaseFlip" (float, phase-flip error probability); "Depolarizing" (float, single-qubit depolarizing probability); "MultiQubitDepolarizing" (float, multi-qubit depolarizing probability, applied only if len(wires) > 1). All parameters default to 0.0 if not provided.

None

Returns:

Name Type Description
None None

Noise channels are applied in-place to the circuit.

Source code in qml_essentials/unitary.py
@staticmethod
def Noise(
    wires: Union[int, List[int]], noise_params: Optional[Dict[str, float]] = None
) -> None:
    """
    Append the requested noise channels to the given qubits.

    Every wire receives the single-qubit channels whose probability is
    positive. When more than one wire is given, a multi-qubit depolarizing
    channel over all of them is added as well.

    Args:
        wires (Union[int, List[int]]): Qubit index or list of qubit indices
            to apply noise to. A bare ``int`` is treated as one wire.
        noise_params (Optional[Dict[str, float]]): Noise parameters. Recognised keys:
            - "BitFlip" (float): Bit flip error probability
            - "PhaseFlip" (float): Phase flip error probability
            - "Depolarizing" (float): Single-qubit depolarizing probability
            - "MultiQubitDepolarizing" (float): Multi-qubit depolarizing
              probability (only used if len(wires) > 1)
            Missing keys default to 0.0, i.e. the channel is skipped.
            ``None`` disables noise entirely.

    Returns:
        None: Noise channels are applied in-place to the circuit.
    """
    if noise_params is None:
        return

    targets = [wires] if isinstance(wires, int) else wires

    # These probabilities do not depend on the wire, so read them once.
    bit_flip = noise_params.get("BitFlip", 0.0)
    phase_flip = noise_params.get("PhaseFlip", 0.0)
    depolarizing = noise_params.get("Depolarizing", 0.0)

    for wire in targets:
        if bit_flip > 0:
            op.BitFlip(bit_flip, wires=wire)
        if phase_flip > 0:
            op.PhaseFlip(phase_flip, wires=wire)
        if depolarizing > 0:
            op.DepolarizingChannel(depolarizing, wires=wire)

    # Correlated noise is only meaningful for gates acting on several qubits.
    if len(targets) > 1:
        multi_qubit = noise_params.get("MultiQubitDepolarizing", 0.0)
        if multi_qubit > 0:
            UnitaryGates.NQubitDepolarizingChannel(multi_qubit, targets)

PauliRot(theta, pauli, wires, noise_params=None, random_key=None, input_idx=-1) staticmethod #

Apply a Pauli rotation gate with optional noise.

Applies a single-angle rotation by theta about the given Pauli operator, with optional gate errors and noise channels.

Parameters:

Name Type Description Default
theta Union[float, ndarray, List[float]]

Rotation angle.

required
pauli str

Pauli operator to apply. Must be "X", "Y", or "Z".

required
wires Union[int, List[int]]

Qubit index or indices to apply rotation to.

required
noise_params Optional[Dict[str, float]]

Noise parameters dictionary. Supports BitFlip, PhaseFlip, Depolarizing, and GateError.

None
random_key Optional[PRNGKey]

JAX random key for noise.

None
input_idx int

Flag for the tape to track inputs

-1

Returns:

Name Type Description
None None

Gate and noise are applied in-place to the circuit.

Source code in qml_essentials/unitary.py
@staticmethod
def PauliRot(
    theta: Union[float, jnp.ndarray, List[float]],
    pauli: str,
    wires: Union[int, List[int]],
    noise_params: Optional[Dict[str, float]] = None,
    random_key: Optional[jax.random.PRNGKey] = None,
    input_idx: int = -1,
) -> None:
    """
    Apply a Pauli rotation gate with optional noise.

    Rotates by the single angle ``theta`` about the Pauli operator named by
    ``pauli`` (delegated to ``op.PauliRot``). If ``noise_params`` contains a
    "GateError" entry, ``theta`` is perturbed first. The configured noise
    channels are then applied to ``wires``.

    Args:
        theta (Union[float, jnp.ndarray, List[float]]): Rotation angle.
        pauli (str): Pauli operator to apply. Must be "X", "Y", or "Z".
        wires (Union[int, List[int]]): Qubit index or indices to apply rotation to.
        noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
            Supports BitFlip, PhaseFlip, Depolarizing, and GateError.
        random_key (Optional[jax.random.PRNGKey]): JAX random key for noise.
        input_idx (int): Flag for the tape to track inputs

    Returns:
        None: Gate and noise are applied in-place to the circuit.
    """
    # Only sample a gate error when one is actually requested.
    if noise_params is not None and "GateError" in noise_params:
        theta, random_key = UnitaryGates.GateError(theta, noise_params, random_key)
    op.PauliRot(theta, pauli, wires=wires, input_idx=input_idx)
    UnitaryGates.Noise(wires, noise_params)

RX(w, wires, noise_params=None, random_key=None, input_idx=-1) staticmethod #

Apply X-axis rotation with optional noise.

Parameters:

Name Type Description Default
w Union[float, ndarray, List[float]]

Rotation angle.

required
wires Union[int, List[int]]

Qubit index or indices.

required
noise_params Optional[Dict[str, float]]

Noise parameters dictionary.

None
random_key Optional[PRNGKey]

JAX random key for noise.

None
input_idx int

Flag for the tape to track inputs

-1

Returns:

Name Type Description
None None

Gate and noise are applied in-place to the circuit.

Source code in qml_essentials/unitary.py
@staticmethod
def RX(
    w: Union[float, jnp.ndarray, List[float]],
    wires: Union[int, List[int]],
    noise_params: Optional[Dict[str, float]] = None,
    random_key: Optional[jax.random.PRNGKey] = None,
    input_idx: int = -1,
) -> None:
    """
    Rotate about the X axis, then apply the configured noise.

    Args:
        w (Union[float, jnp.ndarray, List[float]]): Rotation angle, possibly
            perturbed by ``UnitaryGates.GateError`` before use.
        wires (Union[int, List[int]]): Qubit index or indices.
        noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
        random_key (Optional[jax.random.PRNGKey]): JAX random key for noise.
        input_idx (int): Flag for the tape to track inputs

    Returns:
        None: Gate and noise are applied in-place to the circuit.
    """
    noisy_angle, _ = UnitaryGates.GateError(w, noise_params, random_key)
    op.RX(noisy_angle, wires=wires, input_idx=input_idx)
    UnitaryGates.Noise(wires=wires, noise_params=noise_params)

RY(w, wires, noise_params=None, random_key=None, input_idx=-1) staticmethod #

Apply Y-axis rotation with optional noise.

Parameters:

Name Type Description Default
w Union[float, ndarray, List[float]]

Rotation angle.

required
wires Union[int, List[int]]

Qubit index or indices.

required
noise_params Optional[Dict[str, float]]

Noise parameters dictionary.

None
random_key Optional[PRNGKey]

JAX random key for noise.

None
input_idx int

Flag for the tape to track inputs

-1

Returns:

Name Type Description
None None

Gate and noise are applied in-place to the circuit.

Source code in qml_essentials/unitary.py
@staticmethod
def RY(
    w: Union[float, jnp.ndarray, List[float]],
    wires: Union[int, List[int]],
    noise_params: Optional[Dict[str, float]] = None,
    random_key: Optional[jax.random.PRNGKey] = None,
    input_idx: int = -1,
) -> None:
    """
    Rotate about the Y axis, then apply the configured noise.

    Args:
        w (Union[float, jnp.ndarray, List[float]]): Rotation angle, possibly
            perturbed by ``UnitaryGates.GateError`` before use.
        wires (Union[int, List[int]]): Qubit index or indices.
        noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
        random_key (Optional[jax.random.PRNGKey]): JAX random key for noise.
        input_idx (int): Flag for the tape to track inputs

    Returns:
        None: Gate and noise are applied in-place to the circuit.
    """
    noisy_angle, _ = UnitaryGates.GateError(w, noise_params, random_key)
    op.RY(noisy_angle, wires=wires, input_idx=input_idx)
    UnitaryGates.Noise(wires=wires, noise_params=noise_params)

RZ(w, wires, noise_params=None, random_key=None, input_idx=-1) staticmethod #

Apply Z-axis rotation with optional noise.

Parameters:

Name Type Description Default
w Union[float, ndarray, List[float]]

Rotation angle.

required
wires Union[int, List[int]]

Qubit index or indices.

required
noise_params Optional[Dict[str, float]]

Noise parameters dictionary.

None
random_key Optional[PRNGKey]

JAX random key for noise.

None
input_idx int

Flag for the tape to track inputs

-1

Returns:

Name Type Description
None None

Gate and noise are applied in-place to the circuit.

Source code in qml_essentials/unitary.py
@staticmethod
def RZ(
    w: Union[float, jnp.ndarray, List[float]],
    wires: Union[int, List[int]],
    noise_params: Optional[Dict[str, float]] = None,
    random_key: Optional[jax.random.PRNGKey] = None,
    input_idx: int = -1,
) -> None:
    """
    Rotate about the Z axis, then apply the configured noise.

    Args:
        w (Union[float, jnp.ndarray, List[float]]): Rotation angle, possibly
            perturbed by ``UnitaryGates.GateError`` before use.
        wires (Union[int, List[int]]): Qubit index or indices.
        noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
        random_key (Optional[jax.random.PRNGKey]): JAX random key for noise.
        input_idx (int): Flag for the tape to track inputs

    Returns:
        None: Gate and noise are applied in-place to the circuit.
    """
    noisy_angle, _ = UnitaryGates.GateError(w, noise_params, random_key)
    op.RZ(noisy_angle, wires=wires, input_idx=input_idx)
    UnitaryGates.Noise(wires=wires, noise_params=noise_params)

Rot(phi, theta, omega, wires, noise_params=None, random_key=None, input_idx=-1) staticmethod #

Apply general rotation gate with optional noise.

Applies a three-angle rotation Rot(phi, theta, omega) with optional gate errors and noise channels.

Parameters:

Name Type Description Default
phi Union[float, ndarray, List[float]]

First rotation angle.

required
theta Union[float, ndarray, List[float]]

Second rotation angle.

required
omega Union[float, ndarray, List[float]]

Third rotation angle.

required
wires Union[int, List[int]]

Qubit index or indices to apply rotation to.

required
noise_params Optional[Dict[str, float]]

Noise parameters dictionary. Supports BitFlip, PhaseFlip, Depolarizing, and GateError.

None
random_key Optional[PRNGKey]

JAX random key for noise.

None
input_idx int

Flag for the tape to track inputs

-1

Returns:

Name Type Description
None None

Gate and noise are applied in-place to the circuit.

Source code in qml_essentials/unitary.py
@staticmethod
def Rot(
    phi: Union[float, jnp.ndarray, List[float]],
    theta: Union[float, jnp.ndarray, List[float]],
    omega: Union[float, jnp.ndarray, List[float]],
    wires: Union[int, List[int]],
    noise_params: Optional[Dict[str, float]] = None,
    random_key: Optional[jax.random.PRNGKey] = None,
    input_idx: int = -1,
) -> None:
    """
    Apply general rotation gate with optional noise.

    Applies a three-angle rotation Rot(phi, theta, omega) with optional
    gate errors and noise channels.

    Args:
        phi (Union[float, jnp.ndarray, List[float]]): First rotation angle.
        theta (Union[float, jnp.ndarray, List[float]]): Second rotation angle.
        omega (Union[float, jnp.ndarray, List[float]]): Third rotation angle.
        wires (Union[int, List[int]]): Qubit index or indices to apply rotation to.
        noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
            Supports BitFlip, PhaseFlip, Depolarizing, and GateError.
        random_key (Optional[jax.random.PRNGKey]): JAX random key for noise.
        input_idx (int): Flag for the tape to track inputs. It is forwarded
            unchanged to ``op.Rot``, as the other rotation gates do.

    Returns:
        None: Gate and noise are applied in-place to the circuit.
    """
    if noise_params is not None and "GateError" in noise_params:
        # Thread the key through so each angle gets an independent sample.
        phi, random_key = UnitaryGates.GateError(phi, noise_params, random_key)
        theta, random_key = UnitaryGates.GateError(theta, noise_params, random_key)
        omega, random_key = UnitaryGates.GateError(omega, noise_params, random_key)
    # Forward the caller's input_idx. It was previously hard-coded to False,
    # which silently discarded the caller's value.
    op.Rot(phi, theta, omega, wires=wires, input_idx=input_idx)
    UnitaryGates.Noise(wires, noise_params)

Pulse Gates#

from qml_essentials.gates import PulseGates

Pulse-level implementations of quantum gates.

Implements quantum gates using time-dependent Hamiltonians and pulse sequences, following the approach from https://doi.org/10.5445/IR/1000184129. The active pulse envelope is selected via `PulseInformation.set_envelope`.

Attributes:

Name Type Description
omega_q

Qubit frequency (10π).

omega_c

Carrier frequency (10π).

_active_envelope str

Name of the currently active envelope shape.

Source code in qml_essentials/pulses.py
 614
 615
 616
 617
 618
 619
 620
 621
 622
 623
 624
 625
 626
 627
 628
 629
 630
 631
 632
 633
 634
 635
 636
 637
 638
 639
 640
 641
 642
 643
 644
 645
 646
 647
 648
 649
 650
 651
 652
 653
 654
 655
 656
 657
 658
 659
 660
 661
 662
 663
 664
 665
 666
 667
 668
 669
 670
 671
 672
 673
 674
 675
 676
 677
 678
 679
 680
 681
 682
 683
 684
 685
 686
 687
 688
 689
 690
 691
 692
 693
 694
 695
 696
 697
 698
 699
 700
 701
 702
 703
 704
 705
 706
 707
 708
 709
 710
 711
 712
 713
 714
 715
 716
 717
 718
 719
 720
 721
 722
 723
 724
 725
 726
 727
 728
 729
 730
 731
 732
 733
 734
 735
 736
 737
 738
 739
 740
 741
 742
 743
 744
 745
 746
 747
 748
 749
 750
 751
 752
 753
 754
 755
 756
 757
 758
 759
 760
 761
 762
 763
 764
 765
 766
 767
 768
 769
 770
 771
 772
 773
 774
 775
 776
 777
 778
 779
 780
 781
 782
 783
 784
 785
 786
 787
 788
 789
 790
 791
 792
 793
 794
 795
 796
 797
 798
 799
 800
 801
 802
 803
 804
 805
 806
 807
 808
 809
 810
 811
 812
 813
 814
 815
 816
 817
 818
 819
 820
 821
 822
 823
 824
 825
 826
 827
 828
 829
 830
 831
 832
 833
 834
 835
 836
 837
 838
 839
 840
 841
 842
 843
 844
 845
 846
 847
 848
 849
 850
 851
 852
 853
 854
 855
 856
 857
 858
 859
 860
 861
 862
 863
 864
 865
 866
 867
 868
 869
 870
 871
 872
 873
 874
 875
 876
 877
 878
 879
 880
 881
 882
 883
 884
 885
 886
 887
 888
 889
 890
 891
 892
 893
 894
 895
 896
 897
 898
 899
 900
 901
 902
 903
 904
 905
 906
 907
 908
 909
 910
 911
 912
 913
 914
 915
 916
 917
 918
 919
 920
 921
 922
 923
 924
 925
 926
 927
 928
 929
 930
 931
 932
 933
 934
 935
 936
 937
 938
 939
 940
 941
 942
 943
 944
 945
 946
 947
 948
 949
 950
 951
 952
 953
 954
 955
 956
 957
 958
 959
 960
 961
 962
 963
 964
 965
 966
 967
 968
 969
 970
 971
 972
 973
 974
 975
 976
 977
 978
 979
 980
 981
 982
 983
 984
 985
 986
 987
 988
 989
 990
 991
 992
 993
 994
 995
 996
 997
 998
 999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
class PulseGates:
    """Pulse-level implementations of quantum gates.

    Implements quantum gates using time-dependent Hamiltonians and pulse
    sequences, following the approach from https://doi.org/10.5445/IR/1000184129.
    The active pulse envelope is selected via
    :meth:`PulseInformation.set_envelope`.

    Attributes:
        omega_q: Qubit frequency (10Ï€).
        omega_c: Carrier frequency (10Ï€).
        _active_envelope: Name of the currently active envelope shape.
    """

    # NOTE: Implementation of S, RX, RY, RZ, CZ, CNOT/CX and H pulse level
    #   gates closely follow https://doi.org/10.5445/IR/1000184129
    omega_q = 10 * jnp.pi
    omega_c = 10 * jnp.pi

    H_static = jnp.array(
        [[jnp.exp(1j * omega_q / 2), 0], [0, jnp.exp(-1j * omega_q / 2)]]
    )

    Id = jnp.eye(2, dtype=jnp.complex64)
    X = jnp.array([[0, 1], [1, 0]])
    Y = jnp.array([[0, -1j], [1j, 0]])
    Z = jnp.array([[1, 0], [0, -1]])

    _H_X = H_static.conj().T @ X @ H_static
    _H_Y = H_static.conj().T @ Y @ H_static

    _H_CZ = (jnp.pi / 4) * (
        jnp.kron(Id, Id) - jnp.kron(Z, Id) - jnp.kron(Id, Z) + jnp.kron(Z, Z)
    )

    _H_corr = jnp.pi / 2 * jnp.eye(2, dtype=jnp.complex64)

    _active_envelope: str = "gaussian"

    @staticmethod
    def _coeff_Sx(p, t):
        """Drive coefficient for the RX pulse.

        ``p`` is the packed parameter vector built by ``PulseGates.RX``,
        ``[envelope_params..., w]``, so ``p[-1]`` is the rotation angle.

        NOTE(review): the docstring of the original said "active envelope",
        but this always calls ``PulseEnvelope.gaussian``. In contrast,
        ``_record_pulse_event`` looks the envelope up through
        ``PulseInformation.get_envelope()``. Confirm whether that is intended.
        The envelope centre is also derived from the current time ``t``
        (``t / 2``) rather than from a fixed pulse duration; confirm against
        ``PulseEnvelope.gaussian``.
        """
        envelope = PulseEnvelope.gaussian(p, t, t / 2)
        return envelope * jnp.cos(PulseGates.omega_c * t + jnp.pi) * p[-1]

    @staticmethod
    def _coeff_Sy(p, t):
        """Drive coefficient for the RY pulse.

        Same form as the RX coefficient, but the carrier phase is shifted by
        ``-pi/2`` instead of ``+pi``. ``p`` is ``[envelope_params..., w]``
        as packed by ``PulseGates.RY``, so ``p[-1]`` is the rotation angle.

        NOTE(review): always uses ``PulseEnvelope.gaussian`` rather than the
        envelope selected via ``PulseInformation``. Confirm this is intended.
        """
        envelope = PulseEnvelope.gaussian(p, t, t / 2)
        return envelope * jnp.cos(PulseGates.omega_c * t - jnp.pi / 2) * p[-1]

    @staticmethod
    def _coeff_Sz(p, t):
        """Coefficient function for RZ pulse: p * w."""
        return p[0] * p[1]

    @staticmethod
    def _coeff_Sc(p, t):
        """Constant coefficient for H correction phase."""
        return -1.0

    @staticmethod
    def _coeff_Scz(p, t):
        """Coefficient function for CZ pulse."""
        return p * jnp.pi

    @staticmethod
    def _record_pulse_event(gate_name, w, wires, pulse_params, parent=None):
        """Record a ``PulseEvent`` on the active pulse tape, if any.

        Called from the leaf gate methods (RX, RY, RZ, CZ) so that
        :func:`~qml_essentials.tape.pulse_recording` can collect events
        without callers having to know about the tape. Returns immediately
        when no tape is active.

        Args:
            gate_name: Leaf gate name, used to look up ``LEAF_META`` and the
                gate's ``PulseInformation`` entry.
            w: Rotation angle for the event. Non-physical gates record 0.0
                when ``w`` is a list.
            wires: Wire index or iterable of wire indices.
            pulse_params: Pulse parameters, re-split via ``split_params``.
            parent: Optional parent reference stored on the event.
        """
        tape = active_pulse_tape()
        if tape is None:
            return

        # Deferred import; only needed once a pulse tape is active.
        from qml_essentials.drawing import PulseEvent, LEAF_META

        gate_meta = LEAF_META.get(gate_name, {})
        target_wires = list(wires) if not isinstance(wires, int) else [wires]

        if not gate_meta.get("physical", False):
            # Virtual / non-physical gate: no envelope, unit duration.
            split = PulseInformation.gate_by_name(gate_name).split_params(pulse_params)
            tape.append(
                PulseEvent(
                    gate=gate_name,
                    wires=target_wires,
                    envelope_fn=None,
                    envelope_params=jnp.ravel(jnp.asarray(split)),
                    w=0.0 if isinstance(w, list) else float(w),
                    duration=1.0,
                    carrier_phase=0.0,
                    parent=parent,
                )
            )
            return

        # Physical pulse: the last split parameter is the duration and the
        # rest are the envelope parameters.
        envelope_info = PulseEnvelope.get(PulseInformation.get_envelope())
        split = PulseInformation.gate_by_name(gate_name).split_params(pulse_params)
        duration = float(split[-1])
        tape.append(
            PulseEvent(
                gate=gate_name,
                wires=target_wires,
                envelope_fn=envelope_info["fn"],
                envelope_params=jnp.array(split[:-1]),
                w=float(w),
                duration=duration,
                carrier_phase=gate_meta["carrier_phase"],
                parent=parent,
            )
        )

    @staticmethod
    def Rot(
        phi: float,
        theta: float,
        omega: float,
        wires: Union[int, List[int]],
        pulse_params: Optional[jnp.ndarray] = None,
        noise_params: Optional[Dict[str, float]] = None,
        random_key: Optional[jax.random.PRNGKey] = None,
    ) -> None:
        """
        Three-angle rotation realised through the "Rot" pulse decomposition,
        RZ(phi) · RY(theta) · RZ(omega).

        Args:
            phi (float): First rotation angle.
            theta (float): Second rotation angle.
            omega (float): Third rotation angle.
            wires (Union[int, List[int]]): Qubit index or indices to apply rotation to.
            pulse_params (Optional[jnp.ndarray]): Pulse parameters split across
                the composing gates. If None, uses optimized parameters.
            noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
                A "GateError" entry perturbs each of the three angles.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for compatibility

        Returns:
            None: Gates are applied in-place to the circuit.
        """
        if noise_params is not None and "GateError" in noise_params:
            # Perturb each angle in turn, threading the random key through.
            angles = []
            for angle in (phi, theta, omega):
                angle, random_key = UnitaryGates.GateError(
                    angle, noise_params, random_key
                )
                angles.append(angle)
        else:
            angles = [phi, theta, omega]

        PulseGates._execute_composite("Rot", angles, wires, pulse_params)
        UnitaryGates.Noise(wires, noise_params)

    @staticmethod
    def PauliRot(
        pauli: str,
        theta: float,
        wires: Union[int, List[int]],
        pulse_params: Optional[jnp.ndarray] = None,
        noise_params: Optional[Dict[str, float]] = None,
        random_key: Optional[jax.random.PRNGKey] = None,
    ) -> None:
        """Not implemented as a PulseGate."""
        raise NotImplementedError("PauliRot gate is not implemented as PulseGate")

    @staticmethod
    def RX(
        w: float,
        wires: Union[int, List[int]],
        pulse_params: Optional[jnp.ndarray] = None,
        noise_params: Optional[Dict[str, float]] = None,
        random_key: Optional[jax.random.PRNGKey] = None,
    ) -> None:
        """Pulse-level X rotation driven by ``_coeff_Sx``.

        Args:
            w: Rotation angle in radians. It is recorded on the pulse tape
                before any gate error is applied.
            wires: Qubit index or indices.
            pulse_params: Envelope parameters ``[env_0, ..., env_n, t]``; the
                last entry is the evolution time. ``None`` selects the
                optimized defaults.
            noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for compatibility
        """
        params = PulseInformation.RX.split_params(pulse_params)
        PulseGates._record_pulse_event("RX", w, wires, params)

        drive = PulseGates._coeff_Sx * op.Hermitian(
            PulseGates._H_X, wires=wires, record=False
        )

        # The (possibly perturbed) angle rides along as the last coefficient
        # parameter, so _coeff_Sx can read it as p[-1].
        angle, _ = UnitaryGates.GateError(w, noise_params, random_key)
        coeff_args = jnp.array([*params[:-1], angle])
        ys.evolve(drive, name="RX")([coeff_args], params[-1])
        UnitaryGates.Noise(wires, noise_params)

    @staticmethod
    def RY(
        w: float,
        wires: Union[int, List[int]],
        pulse_params: Optional[jnp.ndarray] = None,
        noise_params: Optional[Dict[str, float]] = None,
        random_key: Optional[jax.random.PRNGKey] = None,
    ) -> None:
        """Pulse-level Y rotation driven by ``_coeff_Sy``.

        Args:
            w: Rotation angle in radians. It is recorded on the pulse tape
                before any gate error is applied.
            wires: Qubit index or indices.
            pulse_params: Envelope parameters ``[env_0, ..., env_n, t]``; the
                last entry is the evolution time. ``None`` selects the
                optimized defaults.
            noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for compatibility
        """
        params = PulseInformation.RY.split_params(pulse_params)
        PulseGates._record_pulse_event("RY", w, wires, params)

        drive = PulseGates._coeff_Sy * op.Hermitian(
            PulseGates._H_Y, wires=wires, record=False
        )

        # The angle is passed as data instead of captured in a closure; per
        # the original design note this lets all RY calls share the JIT
        # solver cache.
        angle, _ = UnitaryGates.GateError(w, noise_params, random_key)
        coeff_args = jnp.array([*params[:-1], angle])
        ys.evolve(drive, name="RY")([coeff_args], params[-1])
        UnitaryGates.Noise(wires, noise_params)

    @staticmethod
    def RZ(
        w: float,
        wires: Union[int, List[int]],
        pulse_params: Optional[float] = None,
        noise_params: Optional[Dict[str, float]] = None,
        random_key: Optional[jax.random.PRNGKey] = None,
    ) -> None:
        """
        Pulse-level Z rotation implemented as a virtual Z (phase tracking),
        without a physical drive pulse.

        The generator ``Z`` is evolved for unit time with the constant
        coefficient ``pulse_params * w`` (see ``_coeff_Sz``).

        Args:
            w (float): Rotation angle in radians.
            wires (Union[int, List[int]]): Qubit index or indices to apply rotation to.
            pulse_params (Optional[float]): Duration parameter for the pulse.
                Rotation angle = w * 2 * pulse_params. Defaults to 0.5 if None.
            noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for compatibility

        Returns:
            None: Gate is applied in-place to the circuit.
        """
        params = PulseInformation.RZ.split_params(pulse_params)
        PulseGates._record_pulse_event("RZ", w, wires, params)

        generator = PulseGates._coeff_Sz * op.Hermitian(
            PulseGates.Z, wires=wires, record=False
        )

        angle, _ = UnitaryGates.GateError(w, noise_params, random_key)
        # params may be a scalar or a 1-element array; flatten and take the
        # first entry so [scale, angle] is always a flat 2-vector.
        scale = jnp.ravel(jnp.asarray(params))[0]
        ys.evolve(generator, name="RZ")([jnp.array([scale, angle])], 1)

        UnitaryGates.Noise(wires, noise_params)

    @staticmethod
    def _resolve_wires(wire_fn, wires):
        """Resolve a wire selector string to actual wire(s).

        Args:
            wire_fn: ``"all"``, ``"target"``, or ``"control"``.
            wires: Parent gate's wire(s) (int or list).

        Returns:
            Wire(s) for the child gate.
        """
        wires_list = [wires] if isinstance(wires, int) else list(wires)
        if wire_fn == "all":
            return wires if len(wires_list) > 1 else wires_list[0]
        if wire_fn == "target":
            return wires_list[-1] if len(wires_list) > 1 else wires_list[0]
        if wire_fn == "control":
            return wires_list[0]
        raise ValueError(f"Unknown wire_fn: {wire_fn!r}")

    @staticmethod
    def _execute_composite(gate_name, w, wires, pulse_params=None):
        """Execute a composite gate by walking its decomposition.

        Reads the :class:`DecompositionStep` list from
        :class:`PulseInformation` and dispatches each step to the
        appropriate ``PulseGates`` method. Children are called without
        ``noise_params``/``random_key``; the public composite gates in this
        class (H, CX, CY, Rot) apply noise once, after this returns.

        Args:
            gate_name: Gate name (e.g. ``"H"``, ``"CX"``).
            w: Rotation angle(s) passed to the parent gate (a list of three
                angles for ``"Rot"``).
            wires: Wire(s) of the parent gate.
            pulse_params: Optional pulse parameters (split across children).
        """
        pp_obj = PulseInformation.gate_by_name(gate_name)
        # One parameter chunk per decomposition step.
        parts = pp_obj.split_params(pulse_params)

        # NOTE(review): zip() silently stops at the shorter sequence; this
        # assumes split_params returns exactly one chunk per step — confirm.
        for step, child_params in zip(pp_obj.decomposition, parts):
            child_wires = PulseGates._resolve_wires(step.wire_fn, wires)
            # angle_fn (when present) derives the child's angle from the parent's.
            child_w = step.angle_fn(w) if step.angle_fn is not None else w
            child_gate = getattr(PulseGates, step.gate.name)

            # Leaf gates that take a rotation angle
            if step.gate.name in ("RX", "RY", "RZ"):
                child_gate(child_w, wires=child_wires, pulse_params=child_params)
            # Leaf gates without a rotation angle
            elif step.gate.name in ("CZ",):
                child_gate(wires=child_wires, pulse_params=child_params)
            # Composite gates with a rotation angle (CRX, CRY, CRZ, Rot, ...)
            elif step.gate.name in ("Rot",):
                # Rot expects (phi, theta, omega, wires, ...)
                child_gate(*child_w, wires=child_wires, pulse_params=child_params)
            # NOTE(review): a CRX/CRY/CRZ step whose gate has no decomposition
            # falls through to the angle-less else-branch below — confirm
            # that case cannot occur.
            elif step.gate.decomposition is not None and step.gate.name in (
                "CRX",
                "CRY",
                "CRZ",
            ):
                child_gate(child_w, wires=child_wires, pulse_params=child_params)
            # Other composite gates (H, CX, CY, ...)
            else:
                child_gate(wires=child_wires, pulse_params=child_params)

    @staticmethod
    def H(
        wires: Union[int, List[int]],
        pulse_params: Optional[jnp.ndarray] = None,
        noise_params: Optional[Dict[str, float]] = None,
        random_key: Optional[jax.random.PRNGKey] = None,
    ) -> None:
        """Hadamard gate built from its pulse decomposition plus a phase correction.

        Decomposes as RZ(π) · RY(π/2) followed by a correction phase.

        Args:
            wires (Union[int, List[int]]): Qubit index or indices.
            pulse_params (Optional[jnp.ndarray]): Pulse parameters split across
                the composing gates. If None, uses optimized parameters.
            noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for compatibility
                (not used in this gate).
        """
        PulseGates._execute_composite("H", 0.0, wires, pulse_params=pulse_params)

        # H-specific correction term. _H_corr is proportional to the identity,
        # so it only contributes a phase on the wire(s).
        correction = PulseGates._coeff_Sc * op.Hermitian(
            PulseGates._H_corr, wires=wires, record=False
        )
        ys.evolve(correction, name="H")([0], 1)
        UnitaryGates.Noise(wires, noise_params=noise_params)

    @staticmethod
    def CX(
        wires: List[int],
        pulse_params: Optional[jnp.ndarray] = None,
        noise_params: Optional[Dict[str, float]] = None,
        random_key: Optional[jax.random.PRNGKey] = None,
    ) -> None:
        """CNOT realised as H(target) · CZ · H(target) at pulse level, followed by noise.

        Args:
            wires (List[int]): ``[control, target]`` qubit indices.
            pulse_params (Optional[jnp.ndarray]): Pulse parameters split across
                the composing gates. If None, uses optimized parameters.
            noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for compatibility
                (not used in this gate).

        Returns:
            None: Gate is applied in-place to the circuit.
        """
        # CX carries no rotation angle, so 0.0 is a placeholder.
        PulseGates._execute_composite("CX", 0.0, wires, pulse_params=pulse_params)
        UnitaryGates.Noise(wires, noise_params=noise_params)

    @staticmethod
    def CY(
        wires: List[int],
        pulse_params: Optional[jnp.ndarray] = None,
        noise_params: Optional[Dict[str, float]] = None,
        random_key: Optional[jax.random.PRNGKey] = None,
    ) -> None:
        """Apply a controlled-Y gate through its pulse decomposition.

        Args:
            wires (List[int]): ``[control, target]`` qubit indices.
            pulse_params (Optional[jnp.ndarray]): Parameters for the composing
                gates; None selects the optimized parameters.
            noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
            random_key (Optional[jax.random.PRNGKey]): Accepted for a uniform
                gate signature; not used by this gate.
        """
        # CY has no rotation angle, so the composite receives 0.0.
        no_angle = 0.0
        PulseGates._execute_composite("CY", no_angle, wires, pulse_params)
        UnitaryGates.Noise(wires, noise_params)

    @staticmethod
    def CZ(
        wires: List[int],
        pulse_params: Optional[float] = None,
        noise_params: Optional[Dict[str, float]] = None,
        random_key: Optional[jax.random.PRNGKey] = None,
    ) -> None:
        """Apply a controlled-Z by evolving under the ZZ coupling Hamiltonian.

        Args:
            wires (List[int]): Control and target qubit indices.
            pulse_params (Optional[float]): Time/duration parameter handed to
                the evolution. If None, ``PulseInformation.CZ.params`` is used.
            noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
            random_key (Optional[jax.random.PRNGKey]): Accepted for a uniform
                gate signature; not used by this gate.
        """
        duration = (
            PulseInformation.CZ.params if pulse_params is None else pulse_params
        )

        PulseGates._record_pulse_event("CZ", 0.0, wires, duration)

        coupling_op = op.Hermitian(PulseGates._H_CZ, wires=wires, record=False)
        hamiltonian = PulseGates._coeff_Scz * coupling_op
        ys.evolve(hamiltonian, name="CZ")([duration], 1)
        UnitaryGates.Noise(wires, noise_params)

    @staticmethod
    def CRX(
        w: float,
        wires: List[int],
        pulse_params: Optional[jnp.ndarray] = None,
        noise_params: Optional[Dict[str, float]] = None,
        random_key: Optional[jax.random.PRNGKey] = None,
    ) -> None:
        """Apply controlled-RX via decomposition.

        Args:
            w (float): Rotation angle in radians.
            wires (List[int]): Control and target qubit indices [control, target].
            pulse_params (Optional[jnp.ndarray]): Pulse parameters for the
                composing gates. If None, uses optimized parameters.
            noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
            random_key (Optional[jax.random.PRNGKey]): JAX random key passed,
                together with ``noise_params``, to ``UnitaryGates.GateError``.
        """
        # Perturb the angle with the configured gate error, exactly as
        # CRY/CRZ do. The composing child gates are called without
        # noise_params, so the error must be applied here or it is lost.
        w, random_key = UnitaryGates.GateError(w, noise_params, random_key)
        PulseGates._execute_composite("CRX", w, wires, pulse_params)
        UnitaryGates.Noise(wires, noise_params)

    @staticmethod
    def CRY(
        w: float,
        wires: List[int],
        pulse_params: Optional[jnp.ndarray] = None,
        noise_params: Optional[Dict[str, float]] = None,
        random_key: Optional[jax.random.PRNGKey] = None,
    ) -> None:
        """Apply a controlled-RY gate through its pulse decomposition.

        Args:
            w (float): Rotation angle in radians.
            wires (List[int]): ``[control, target]`` qubit indices.
            pulse_params (Optional[jnp.ndarray]): Parameters for the composing
                gates; None selects the optimized parameters.
            noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
            random_key (Optional[jax.random.PRNGKey]): JAX random key passed,
                together with ``noise_params``, to ``UnitaryGates.GateError``.
        """
        # The composite sees the (possibly perturbed) angle.
        noisy_w, _ = UnitaryGates.GateError(w, noise_params, random_key)
        PulseGates._execute_composite("CRY", noisy_w, wires, pulse_params)
        UnitaryGates.Noise(wires, noise_params)

    @staticmethod
    def CRZ(
        w: float,
        wires: List[int],
        pulse_params: Optional[jnp.ndarray] = None,
        noise_params: Optional[Dict[str, float]] = None,
        random_key: Optional[jax.random.PRNGKey] = None,
    ) -> None:
        """Apply a controlled-RZ gate through its pulse decomposition.

        Args:
            w (float): Rotation angle in radians.
            wires (List[int]): ``[control, target]`` qubit indices.
            pulse_params (Optional[jnp.ndarray]): Parameters for the composing
                gates; None selects the optimized parameters.
            noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
            random_key (Optional[jax.random.PRNGKey]): JAX random key passed,
                together with ``noise_params``, to ``UnitaryGates.GateError``.
        """
        # The composite sees the (possibly perturbed) angle.
        noisy_w, _ = UnitaryGates.GateError(w, noise_params, random_key)
        PulseGates._execute_composite("CRZ", noisy_w, wires, pulse_params)
        UnitaryGates.Noise(wires, noise_params)

CRX(w, wires, pulse_params=None, noise_params=None, random_key=None) staticmethod #

Apply controlled-RX via decomposition.

Parameters:

Name Type Description Default
w float

Rotation angle in radians.

required
wires List[int]

Control and target qubit indices [control, target].

required
pulse_params Optional[ndarray]

Pulse parameters for the composing gates. If None, uses optimized parameters.

None
noise_params Optional[Dict[str, float]]

Noise parameters dictionary.

None
random_key Optional[PRNGKey]

JAX random key for compatibility (not used in this gate).

None
Source code in qml_essentials/pulses.py
@staticmethod
def CRX(
    w: float,
    wires: List[int],
    pulse_params: Optional[jnp.ndarray] = None,
    noise_params: Optional[Dict[str, float]] = None,
    random_key: Optional[jax.random.PRNGKey] = None,
) -> None:
    """Apply controlled-RX via decomposition.

    Args:
        w (float): Rotation angle in radians.
        wires (List[int]): Control and target qubit indices [control, target].
        pulse_params (Optional[jnp.ndarray]): Pulse parameters for the
            composing gates. If None, uses optimized parameters.
        noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
        random_key (Optional[jax.random.PRNGKey]): JAX random key passed,
            together with ``noise_params``, to ``UnitaryGates.GateError``.
    """
    # Perturb the angle with the configured gate error, exactly as CRY/CRZ
    # do. The composing child gates are called without noise_params, so the
    # error must be applied here or it is lost.
    w, random_key = UnitaryGates.GateError(w, noise_params, random_key)
    PulseGates._execute_composite("CRX", w, wires, pulse_params)
    UnitaryGates.Noise(wires, noise_params)

CRY(w, wires, pulse_params=None, noise_params=None, random_key=None) staticmethod #

Apply controlled-RY via decomposition.

Parameters:

Name Type Description Default
w float

Rotation angle in radians.

required
wires List[int]

Control and target qubit indices [control, target].

required
pulse_params Optional[ndarray]

Pulse parameters for the composing gates. If None, uses optimized parameters.

None
noise_params Optional[Dict[str, float]]

Noise parameters dictionary.

None
random_key Optional[PRNGKey]

JAX random key passed, together with noise_params, to the gate-error model.

None
Source code in qml_essentials/pulses.py
@staticmethod
def CRY(
    w: float,
    wires: List[int],
    pulse_params: Optional[jnp.ndarray] = None,
    noise_params: Optional[Dict[str, float]] = None,
    random_key: Optional[jax.random.PRNGKey] = None,
) -> None:
    """Apply a controlled-RY gate through its pulse decomposition.

    Args:
        w (float): Rotation angle in radians.
        wires (List[int]): ``[control, target]`` qubit indices.
        pulse_params (Optional[jnp.ndarray]): Parameters for the composing
            gates; None selects the optimized parameters.
        noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
        random_key (Optional[jax.random.PRNGKey]): JAX random key passed,
            together with ``noise_params``, to ``UnitaryGates.GateError``.
    """
    # The composite sees the (possibly perturbed) angle.
    noisy_w, _ = UnitaryGates.GateError(w, noise_params, random_key)
    PulseGates._execute_composite("CRY", noisy_w, wires, pulse_params)
    UnitaryGates.Noise(wires, noise_params)

CRZ(w, wires, pulse_params=None, noise_params=None, random_key=None) staticmethod #

Apply controlled-RZ via decomposition.

Parameters:

Name Type Description Default
w float

Rotation angle in radians.

required
wires List[int]

Control and target qubit indices [control, target].

required
pulse_params Optional[ndarray]

Pulse parameters for the composing gates. If None, uses optimized parameters.

None
noise_params Optional[Dict[str, float]]

Noise parameters dictionary.

None
random_key Optional[PRNGKey]

JAX random key passed, together with noise_params, to the gate-error model.

None
Source code in qml_essentials/pulses.py
@staticmethod
def CRZ(
    w: float,
    wires: List[int],
    pulse_params: Optional[jnp.ndarray] = None,
    noise_params: Optional[Dict[str, float]] = None,
    random_key: Optional[jax.random.PRNGKey] = None,
) -> None:
    """Apply a controlled-RZ gate through its pulse decomposition.

    Args:
        w (float): Rotation angle in radians.
        wires (List[int]): ``[control, target]`` qubit indices.
        pulse_params (Optional[jnp.ndarray]): Parameters for the composing
            gates; None selects the optimized parameters.
        noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
        random_key (Optional[jax.random.PRNGKey]): JAX random key passed,
            together with ``noise_params``, to ``UnitaryGates.GateError``.
    """
    # The composite sees the (possibly perturbed) angle.
    noisy_w, _ = UnitaryGates.GateError(w, noise_params, random_key)
    PulseGates._execute_composite("CRZ", noisy_w, wires, pulse_params)
    UnitaryGates.Noise(wires, noise_params)

CX(wires, pulse_params=None, noise_params=None, random_key=None) staticmethod #

Apply CNOT gate via decomposition: H(target) · CZ · H(target).

Parameters:

Name Type Description Default
wires List[int]

Control and target qubit indices [control, target].

required
pulse_params Optional[ndarray]

Pulse parameters for the composing gates. If None, uses optimized parameters.

None
noise_params Optional[Dict[str, float]]

Noise parameters dictionary.

None
random_key Optional[PRNGKey]

JAX random key for compatibility (not used in this gate).

None

Returns:

Name Type Description
None None

Gate is applied in-place to the circuit.

Source code in qml_essentials/pulses.py
@staticmethod
def CX(
    wires: List[int],
    pulse_params: Optional[jnp.ndarray] = None,
    noise_params: Optional[Dict[str, float]] = None,
    random_key: Optional[jax.random.PRNGKey] = None,
) -> None:
    """Apply a CNOT built as H(target) · CZ · H(target).

    Args:
        wires (List[int]): ``[control, target]`` qubit indices.
        pulse_params (Optional[jnp.ndarray]): Parameters for the composing
            gates; None selects the optimized parameters.
        noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
        random_key (Optional[jax.random.PRNGKey]): Accepted for a uniform
            gate signature; not used by this gate.

    Returns:
        None: The gate acts on the circuit in place.
    """
    # CX has no rotation angle, so the composite receives 0.0.
    no_angle = 0.0
    PulseGates._execute_composite("CX", no_angle, wires, pulse_params)
    UnitaryGates.Noise(wires, noise_params)

CY(wires, pulse_params=None, noise_params=None, random_key=None) staticmethod #

Apply controlled-Y via decomposition.

Parameters:

Name Type Description Default
wires List[int]

Control and target qubit indices [control, target].

required
pulse_params Optional[ndarray]

Pulse parameters for the composing gates. If None, uses optimized parameters.

None
noise_params Optional[Dict[str, float]]

Noise parameters dictionary.

None
random_key Optional[PRNGKey]

JAX random key for compatibility (not used in this gate).

None
Source code in qml_essentials/pulses.py
@staticmethod
def CY(
    wires: List[int],
    pulse_params: Optional[jnp.ndarray] = None,
    noise_params: Optional[Dict[str, float]] = None,
    random_key: Optional[jax.random.PRNGKey] = None,
) -> None:
    """Apply a controlled-Y gate through its pulse decomposition.

    Args:
        wires (List[int]): ``[control, target]`` qubit indices.
        pulse_params (Optional[jnp.ndarray]): Parameters for the composing
            gates; None selects the optimized parameters.
        noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
        random_key (Optional[jax.random.PRNGKey]): Accepted for a uniform
            gate signature; not used by this gate.
    """
    # CY has no rotation angle, so the composite receives 0.0.
    no_angle = 0.0
    PulseGates._execute_composite("CY", no_angle, wires, pulse_params)
    UnitaryGates.Noise(wires, noise_params)

CZ(wires, pulse_params=None, noise_params=None, random_key=None) staticmethod #

Apply controlled-Z using ZZ coupling Hamiltonian.

Parameters:

Name Type Description Default
wires List[int]

Control and target qubit indices.

required
pulse_params Optional[float]

Time or duration parameter for the pulse evolution. If None, uses optimized value.

None
noise_params Optional[Dict[str, float]]

Noise parameters dictionary.

None
random_key Optional[PRNGKey]

JAX random key for compatibility (not used in this gate).

None
Source code in qml_essentials/pulses.py
@staticmethod
def CZ(
    wires: List[int],
    pulse_params: Optional[float] = None,
    noise_params: Optional[Dict[str, float]] = None,
    random_key: Optional[jax.random.PRNGKey] = None,
) -> None:
    """Apply a controlled-Z by evolving under the ZZ coupling Hamiltonian.

    Args:
        wires (List[int]): Control and target qubit indices.
        pulse_params (Optional[float]): Time/duration parameter handed to the
            evolution. If None, ``PulseInformation.CZ.params`` is used.
        noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
        random_key (Optional[jax.random.PRNGKey]): Accepted for a uniform
            gate signature; not used by this gate.
    """
    duration = (
        PulseInformation.CZ.params if pulse_params is None else pulse_params
    )

    PulseGates._record_pulse_event("CZ", 0.0, wires, duration)

    coupling_op = op.Hermitian(PulseGates._H_CZ, wires=wires, record=False)
    hamiltonian = PulseGates._coeff_Scz * coupling_op
    ys.evolve(hamiltonian, name="CZ")([duration], 1)
    UnitaryGates.Noise(wires, noise_params)

H(wires, pulse_params=None, noise_params=None, random_key=None) staticmethod #

Apply Hadamard gate using pulse decomposition.

Decomposes as RZ(π) · RY(π/2) followed by a correction phase.

Parameters:

Name Type Description Default
wires Union[int, List[int]]

Qubit index or indices.

required
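pulse_params Optional[ndarray]

Pulse parameters for the composing gates. If None, uses optimized parameters.

None
noise_params Optional[Dict[str, float]]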

Noise parameters dictionary.

None
random_key Optional[PRNGKey]

JAX random key for compatibility (not used in this gate).

None
Source code in qml_essentials/pulses.py
@staticmethod
def H(
    wires: Union[int, List[int]],
    pulse_params: Optional[jnp.ndarray] = None,
    noise_params: Optional[Dict[str, float]] = None,
    random_key: Optional[jax.random.PRNGKey] = None,
) -> None:
    """Apply a Hadamard gate built from pulses.

    The composite registered under ``"H"`` (documented as RZ(π) · RY(π/2))
    runs first. An extra evolution under the ``_H_corr`` Hamiltonian then
    supplies the correction phase that only this gate needs. Finally, the
    configured noise channel is applied.

    Args:
        wires (Union[int, List[int]]): Qubit index or indices.
        pulse_params (Optional[jnp.ndarray]): Pulse parameters forwarded to
            the composing gates. If None, uses optimized parameters.
        noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
        random_key (Optional[jax.random.PRNGKey]): Accepted for a uniform
            gate signature; not used by this gate.
    """
    # Step 1: the pulse sequence of the composing gates (no angle for H).
    PulseGates._execute_composite("H", 0.0, wires, pulse_params)

    # Step 2: correction phase specific to the Hadamard.
    correction_op = op.Hermitian(PulseGates._H_corr, wires=wires, record=False)
    correction_hamiltonian = PulseGates._coeff_Sc * correction_op
    ys.evolve(correction_hamiltonian, name="H")([0], 1)

    # Step 3: noise channel (if configured).
    UnitaryGates.Noise(wires, noise_params)

PauliRot(pauli, theta, wires, pulse_params=None, noise_params=None, random_key=None) staticmethod #

Not implemented as a PulseGate.

Source code in qml_essentials/pulses.py
@staticmethod
def PauliRot(
    pauli: str,
    theta: float,
    wires: Union[int, List[int]],
    pulse_params: Optional[jnp.ndarray] = None,
    noise_params: Optional[Dict[str, float]] = None,
    random_key: Optional[jax.random.PRNGKey] = None,
) -> None:
    """Placeholder: PauliRot has no pulse-level implementation.

    The arguments exist only so the signature matches the other gates;
    none of them is inspected.

    Raises:
        NotImplementedError: Always.
    """
    message = "PauliRot gate is not implemented as PulseGate"
    raise NotImplementedError(message)

RX(w, wires, pulse_params=None, noise_params=None, random_key=None) staticmethod #

Apply X-axis rotation using the active pulse envelope.

Parameters:

Name Type Description Default
w float

Rotation angle in radians.

required
wires Union[int, List[int]]

Qubit index or indices.

required
pulse_params Optional[ndarray]

Envelope parameters [env_0, ..., env_n, t]. If None, uses optimized defaults.

None
noise_params Optional[Dict[str, float]]

Noise parameters dictionary.

None
random_key Optional[PRNGKey]

JAX random key passed, together with noise_params, to the gate-error model.

None
Source code in qml_essentials/pulses.py
@staticmethod
def RX(
    w: float,
    wires: Union[int, List[int]],
    pulse_params: Optional[jnp.ndarray] = None,
    noise_params: Optional[Dict[str, float]] = None,
    random_key: Optional[jax.random.PRNGKey] = None,
) -> None:
    """Rotate about the X axis with the currently active pulse envelope.

    Args:
        w: Rotation angle in radians.
        wires: Qubit index or indices.
        pulse_params: Envelope parameters ``[env_0, ..., env_n, t]``; the
            final entry is the evolution time. If ``None``, uses optimized
            defaults from ``PulseInformation.RX``.
        noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
        random_key (Optional[jax.random.PRNGKey]): JAX random key passed,
            together with ``noise_params``, to ``UnitaryGates.GateError``.
    """
    params = PulseInformation.RX.split_params(pulse_params)

    # The event log receives the requested angle, before any gate error.
    PulseGates._record_pulse_event("RX", w, wires, params)

    x_op = op.Hermitian(PulseGates._H_X, wires=wires, record=False)
    hamiltonian = PulseGates._coeff_Sx * x_op

    # Coefficient parameters are [envelope..., w]; the last entry of the
    # pulse parameters is the evolution time, not an envelope parameter.
    w, random_key = UnitaryGates.GateError(w, noise_params, random_key)
    envelope, t_evolve = params[:-1], params[-1]
    packed = jnp.array([*envelope, w])
    ys.evolve(hamiltonian, name="RX")([packed], t_evolve)
    UnitaryGates.Noise(wires, noise_params)

RY(w, wires, pulse_params=None, noise_params=None, random_key=None) staticmethod #

Apply Y-axis rotation using the active pulse envelope.

Parameters:

Name Type Description Default
w float

Rotation angle in radians.

required
wires Union[int, List[int]]

Qubit index or indices.

required
pulse_params Optional[ndarray]

Envelope parameters [env_0, ..., env_n, t]. If None, uses optimized defaults.

None
noise_params Optional[Dict[str, float]]

Noise parameters dictionary.

None
random_key Optional[PRNGKey]

JAX random key passed, together with noise_params, to the gate-error model.

None
Source code in qml_essentials/pulses.py
@staticmethod
def RY(
    w: float,
    wires: Union[int, List[int]],
    pulse_params: Optional[jnp.ndarray] = None,
    noise_params: Optional[Dict[str, float]] = None,
    random_key: Optional[jax.random.PRNGKey] = None,
) -> None:
    """Rotate about the Y axis with the currently active pulse envelope.

    Args:
        w: Rotation angle in radians.
        wires: Qubit index or indices.
        pulse_params: Envelope parameters ``[env_0, ..., env_n, t]``; the
            final entry is the evolution time. If ``None``, uses optimized
            defaults from ``PulseInformation.RY``.
        noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
        random_key (Optional[jax.random.PRNGKey]): JAX random key passed,
            together with ``noise_params``, to ``UnitaryGates.GateError``.
    """
    params = PulseInformation.RY.split_params(pulse_params)

    # The event log receives the requested angle, before any gate error.
    PulseGates._record_pulse_event("RY", w, wires, params)

    y_op = op.Hermitian(PulseGates._H_Y, wires=wires, record=False)
    hamiltonian = PulseGates._coeff_Sy * y_op

    # w travels inside the parameter vector instead of via a closure, which
    # enables JIT solver cache sharing across all RY calls.
    w, random_key = UnitaryGates.GateError(w, noise_params, random_key)
    envelope, t_evolve = params[:-1], params[-1]
    packed = jnp.array([*envelope, w])
    ys.evolve(hamiltonian, name="RY")([packed], t_evolve)
    UnitaryGates.Noise(wires, noise_params)

RZ(w, wires, pulse_params=None, noise_params=None, random_key=None) staticmethod #

Apply Z-axis rotation using pulse-level implementation.

Implements RZ rotation using virtual Z rotations (phase tracking) without physical pulse application.

Parameters:

Name Type Description Default
w float

Rotation angle in radians.

required
wires Union[int, List[int]]

Qubit index or indices to apply rotation to.

required
pulse_params Optional[float]

Duration parameter for the pulse. Rotation angle = w * 2 * pulse_params. Defaults to 0.5 if None.

None
noise_params Optional[Dict[str, float]]

Noise parameters dictionary.

None
random_key Optional[PRNGKey]

JAX random key passed, together with noise_params, to the gate-error model.

None

Returns:

Name Type Description
None None

Gate is applied in-place to the circuit.

Source code in qml_essentials/pulses.py
@staticmethod
def RZ(
    w: float,
    wires: Union[int, List[int]],
    pulse_params: Optional[float] = None,
    noise_params: Optional[Dict[str, float]] = None,
    random_key: Optional[jax.random.PRNGKey] = None,
) -> None:
    """
    Apply a Z-axis rotation at the pulse level.

    The gate is simulated as an evolution (for time 1) under the Pauli-Z
    operator scaled by ``PulseGates._coeff_Sz``. The coefficient function
    receives ``[pulse_param, w]``.

    Args:
        w (float): Rotation angle in radians.
        wires (Union[int, List[int]]): Qubit index or indices to apply rotation to.
        pulse_params (Optional[float]): Scalar pulse parameter (a one-element
            array is also accepted). Documented relation:
            rotation angle = w * 2 * pulse_params. If None, the default from
            ``PulseInformation.RZ`` is used.
        noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
        random_key (Optional[jax.random.PRNGKey]): JAX random key passed,
            together with ``noise_params``, to ``UnitaryGates.GateError``.

    Returns:
        None: Gate is applied in-place to the circuit.
    """
    params = PulseInformation.RZ.split_params(pulse_params)

    # The event log receives the requested angle, before any gate error.
    PulseGates._record_pulse_event("RZ", w, wires, params)

    z_op = op.Hermitian(PulseGates.Z, wires=wires, record=False)
    hamiltonian = PulseGates._coeff_Sz * z_op

    # Packing w next to the pulse parameter (rather than closing over it)
    # enables JIT solver cache sharing. The parameter may be a scalar or a
    # one-element array, so flatten it and take the first entry.
    w, random_key = UnitaryGates.GateError(w, noise_params, random_key)
    scalar_param = jnp.asarray(params).ravel()[0]
    packed = jnp.array([scalar_param, w])
    ys.evolve(hamiltonian, name="RZ")([packed], 1)

    UnitaryGates.Noise(wires, noise_params)

Rot(phi, theta, omega, wires, pulse_params=None, noise_params=None, random_key=None) staticmethod #

Apply general rotation via decomposition: RZ(phi) · RY(theta) · RZ(omega).

Parameters:

Name Type Description Default
phi float

First rotation angle.

required
theta float

Second rotation angle.

required
omega float

Third rotation angle.

required
wires Union[int, List[int]]

Qubit index or indices to apply rotation to.

required
pulse_params Optional[ndarray]

Pulse parameters for the composing gates. If None, uses optimized parameters.

None
noise_params Optional[Dict[str, float]]

Noise parameters dictionary.

None
random_key Optional[PRNGKey]

JAX random key used to sample gate errors when noise_params contains a "GateError" entry.

None

Returns:

Name Type Description
None None

Gates are applied in-place to the circuit.

Source code in qml_essentials/pulses.py
@staticmethod
def Rot(
    phi: float,
    theta: float,
    omega: float,
    wires: Union[int, List[int]],
    pulse_params: Optional[jnp.ndarray] = None,
    noise_params: Optional[Dict[str, float]] = None,
    random_key: Optional[jax.random.PRNGKey] = None,
) -> None:
    """
    Apply a general rotation composed as RZ(phi) · RY(theta) · RZ(omega).

    Args:
        phi (float): First rotation angle.
        theta (float): Second rotation angle.
        omega (float): Third rotation angle.
        wires (Union[int, List[int]]): Qubit index or indices to apply rotation to.
        pulse_params (Optional[jnp.ndarray]): Parameters for the composing
            gates; None selects the optimized parameters.
        noise_params (Optional[Dict[str, float]]): Noise parameters dictionary.
        random_key (Optional[jax.random.PRNGKey]): JAX random key used to
            sample gate errors when ``noise_params`` has a ``"GateError"`` entry.

    Returns:
        None: Gates are applied in-place to the circuit.
    """
    angles = [phi, theta, omega]

    wants_gate_error = noise_params is not None and "GateError" in noise_params
    if wants_gate_error:
        # Perturb each angle in order, threading the key through every draw.
        perturbed = []
        for angle in angles:
            angle, random_key = UnitaryGates.GateError(
                angle, noise_params, random_key
            )
            perturbed.append(angle)
        angles = perturbed

    PulseGates._execute_composite("Rot", angles, wires, pulse_params)
    UnitaryGates.Noise(wires, noise_params)

Pulse Structure#

from qml_essentials.gates import PulseParams

Container for hierarchical pulse parameters.

Leaf nodes hold direct parameters; composite nodes hold a list of `DecompositionStep` objects that describe how the gate is built from simpler gates.

Attributes:

Name Type Description
name

Gate identifier (e.g. "RX", "H").

decomposition

List of `DecompositionStep` objects (composite nodes only).

Source code in qml_essentials/pulses.py
class PulseParams:
    """Container for hierarchical pulse parameters.

    Leaf nodes hold direct parameters; composite nodes hold a list of
    :class:`DecompositionStep` objects that describe how the gate is
    built from simpler gates.

    Attributes:
        name: Gate identifier (e.g. ``"RX"``, ``"H"``).
        decomposition: List of :class:`DecompositionStep` (composite only).
    """

    def __init__(
        self,
        name: str = "",
        params: Optional[jnp.ndarray] = None,
        decomposition: Optional[List["DecompositionStep"]] = None,
    ) -> None:
        """
        Args:
            name: Gate name.
            params: Direct pulse parameters (leaf gates).
                Mutually exclusive with *decomposition*.
            decomposition: List of :class:`DecompositionStep` (composite gates).
                Mutually exclusive with *params*. Each step must expose the
                child ``PulseParams`` as ``step.gate``.
        """
        assert (params is None) != (
            decomposition is None
        ), "Exactly one of `params` or `decomposition` must be provided."

        self.decomposition = decomposition
        # Derive _pulse_obj for backward compat with childs/leafs/split_params
        self._pulse_obj = (
            [step.gate for step in decomposition] if decomposition else None
        )

        if params is not None:
            self._params = params

        self.name = name

    def __len__(self) -> int:
        """
        Get the total number of pulse parameters.

        For composite gates, returns the accumulated count from all children
        (a child that appears several times is counted each time).

        Returns:
            int: Total number of pulse parameters.
        """
        return len(self.params)

    def __getitem__(self, idx: int) -> Union[float, jnp.ndarray]:
        """
        Access pulse parameter(s) by index.

        For leaf gates, returns the parameter at the given index.
        For composite gates, returns parameters of the child at the given index.

        Args:
            idx (int): Index to access.

        Returns:
            Union[float, jnp.ndarray]: Parameter value or child parameters.
        """
        if self.is_leaf:
            return self.params[idx]
        else:
            return self.childs[idx].params

    def __str__(self) -> str:
        """Return string representation (gate name)."""
        return self.name

    def __repr__(self) -> str:
        """Return repr string (gate name)."""
        return self.name

    @property
    def is_leaf(self) -> bool:
        """Check if this is a leaf node (direct parameters, no children)."""
        return self._pulse_obj is None

    @property
    def size(self) -> int:
        """Get the total parameter count (alias for __len__)."""
        return len(self)

    @property
    def leafs(self) -> List["PulseParams"]:
        """
        Get all leaf nodes in the hierarchy.

        Recursively collects all leaf PulseParams objects in the tree.
        Duplicates (the same object reached through several children) are
        removed, keeping the order of first occurrence in a depth-first walk.

        Returns:
            List[PulseParams]: List of unique leaf nodes, in deterministic order.
        """
        if self.is_leaf:
            return [self]

        collected = []
        for obj in self._pulse_obj:
            collected.extend(obj.leafs)

        # De-duplicate while preserving first-occurrence order. A plain
        # ``set`` would order leaves by object hash (memory address), so
        # ``leaf_params`` would be laid out differently between runs and a
        # saved parameter vector could not be reloaded reliably.
        return list(dict.fromkeys(collected))

    @property
    def childs(self) -> List["PulseParams"]:
        """
        Get direct children of this node.

        Returns:
            List[PulseParams]: List of child PulseParams objects, or empty list
                if this is a leaf node.
        """
        if self.is_leaf:
            return []

        return self._pulse_obj

    @property
    def shape(self) -> List[int]:
        """
        Get the shape of pulse parameters.

        For leaf nodes, returns a one-element list with the parameter count.
        For composite nodes, returns one entry per direct child: an int for a
        leaf child, or that child's own (nested) shape list for a composite
        child.

        Returns:
            List[int]: Parameter shape specification (nested for deeper
                hierarchies).
        """
        if self.is_leaf:
            return [len(self.params)]

        # ``shape`` is a property, so it is read as an attribute; calling it
        # (``obj.shape()``) would try to call the returned list.
        return [
            child.shape[0] if child.is_leaf else child.shape
            for child in self.childs
        ]

    @property
    def params(self) -> jnp.ndarray:
        """
        Get or compute pulse parameters.

        For leaf nodes, returns internal pulse parameters.
        For composite nodes, returns concatenated parameters from all children.

        Returns:
            jnp.ndarray: Pulse parameters array.
        """
        if self.is_leaf:
            return self._params

        params = self.split_params(params=None, leafs=False)

        return jnp.concatenate(params)

    @params.setter
    def params(self, value: jnp.ndarray) -> None:
        """
        Set pulse parameters.

        For leaf nodes, sets internal parameters directly.
        For composite nodes, distributes consecutive slices of *value* across
        the direct children (sized by each child's current ``size``).

        Args:
            value (jnp.ndarray): Pulse parameters to set.

        Raises:
            AssertionError: If value is not jnp.ndarray for leaf nodes.
        """
        if self.is_leaf:
            assert isinstance(value, jnp.ndarray), "params must be a jnp.ndarray"
            self._params = value
            return

        idx = 0
        for obj in self.childs:
            nidx = idx + obj.size
            obj.params = value[idx:nidx]
            idx = nidx

    @property
    def leaf_params(self) -> jnp.ndarray:
        """
        Get parameters from all leaf nodes.

        Returns:
            jnp.ndarray: Concatenated parameters from all unique leaf nodes,
                in the order given by :attr:`leafs`.
        """
        if self.is_leaf:
            return self._params

        params = self.split_params(None, leafs=True)

        return jnp.concatenate(params)

    @leaf_params.setter
    def leaf_params(self, value: jnp.ndarray) -> None:
        """
        Set parameters for all leaf nodes.

        Args:
            value (jnp.ndarray): Parameters to distribute across leaf nodes,
                laid out in the order given by :attr:`leafs`.
        """
        if self.is_leaf:
            self._params = value
            return

        idx = 0
        for obj in self.leafs:
            nidx = idx + obj.size
            obj.params = value[idx:nidx]
            idx = nidx

    def split_params(
        self,
        params: Optional[jnp.ndarray] = None,
        leafs: bool = False,
    ) -> List[jnp.ndarray]:
        """
        Split parameters into sub-arrays for children or leaves.

        Args:
            params (Optional[jnp.ndarray]): Parameters to split. If None,
                uses internal parameters.
            leafs (bool): If True, splits across leaf nodes; if False,
                splits across direct children. Defaults to False.

        Returns:
            List[jnp.ndarray]: List of parameter arrays for children or leaves.
                For a leaf node the parameters are returned unsplit.
        """
        if params is None:
            if self.is_leaf:
                return self._params

            objs = self.leafs if leafs else self.childs
            s_params = []
            for obj in objs:
                s_params.append(obj.params)

            return s_params
        else:
            if self.is_leaf:
                return params

            objs = self.leafs if leafs else self.childs
            s_params = []
            idx = 0
            for obj in objs:
                nidx = idx + obj.size
                s_params.append(params[idx:nidx])
                idx = nidx

            return s_params

childs property #

Get direct children of this node.

Returns:

Type Description
List[PulseParams]

List[PulseParams]: List of child PulseParams objects, or empty list if this is a leaf node.

is_leaf property #

Check if this is a leaf node (direct parameters, no children).

leaf_params property writable #

Get parameters from all leaf nodes.

Returns:

Type Description
ndarray

jnp.ndarray: Concatenated parameters from all leaf nodes.

leafs property #

Get all leaf nodes in the hierarchy.

Recursively collects all leaf PulseParams objects in the tree.

Returns:

Type Description
List[PulseParams]

List[PulseParams]: List of unique leaf nodes.

params property writable #

Get or compute pulse parameters.

For leaf nodes, returns internal pulse parameters. For composite nodes, returns concatenated parameters from all children.

Returns:

Type Description
ndarray

jnp.ndarray: Pulse parameters array.

shape property #

Get the shape of pulse parameters.

For leaf nodes, returns list with parameter count. For composite nodes, returns nested list of child shapes.

Returns:

Type Description
List[int]

List[int]: Parameter shape specification.

size property #

Get the total parameter count (alias for `__len__`).

__getitem__(idx) #

Access pulse parameter(s) by index.

For leaf gates, returns the parameter at the given index. For composite gates, returns parameters of the child at the given index.

Parameters:

Name Type Description Default
idx int

Index to access.

required

Returns:

Type Description
Union[float, ndarray]

Union[float, jnp.ndarray]: Parameter value or child parameters.

Source code in qml_essentials/pulses.py
def __getitem__(self, idx: int) -> Union[float, jnp.ndarray]:
    """
    Index into the pulse parameters.

    A leaf node indexes its own parameter array directly, while a composite
    node returns the full parameter array of the child at position ``idx``.

    Args:
        idx (int): Position to look up.

    Returns:
        Union[float, jnp.ndarray]: A single parameter (leaf) or the
        parameters of the selected child (composite).
    """
    if not self.is_leaf:
        child = self.childs[idx]
        return child.params
    return self.params[idx]

__init__(name='', params=None, decomposition=None) #

Parameters:

Name Type Description Default
name str

Gate name.

''
params Optional[ndarray]

Direct pulse parameters (leaf gates). Mutually exclusive with decomposition.

None
decomposition Optional[List[DecompositionStep]]

List of `DecompositionStep` objects (composite gates). Mutually exclusive with params.

None
Source code in qml_essentials/pulses.py
def __init__(
    self,
    name: str = "",
    params: Optional[jnp.ndarray] = None,
    decomposition: Optional[List[DecompositionStep]] = None,
) -> None:
    """
    Build either a leaf or a composite pulse-parameter node.

    Args:
        name: Gate name (also used by ``__str__``/``__repr__``).
        params: Direct pulse parameters, making this a leaf gate.
            Mutually exclusive with *decomposition*.
        decomposition: List of ``DecompositionStep`` objects, making this a
            composite gate. Mutually exclusive with *params*.

    Raises:
        AssertionError: If both or neither of *params* and *decomposition*
            are given.
    """
    has_params = params is not None
    has_decomposition = decomposition is not None
    assert (
        has_params ^ has_decomposition
    ), "Exactly one of `params` or `decomposition` must be provided."

    self.decomposition = decomposition

    # _pulse_obj is kept for backward compatibility with childs/leafs/
    # split_params; an empty decomposition deliberately maps to None.
    if decomposition:
        self._pulse_obj = [step.gate for step in decomposition]
    else:
        self._pulse_obj = None

    # only leaf nodes carry their own parameter array
    if has_params:
        self._params = params

    self.name = name

__len__() #

Get the total number of pulse parameters.

For composite gates, returns the accumulated count from all children.

Returns:

Name Type Description
int int

Total number of pulse parameters.

Source code in qml_essentials/pulses.py
def __len__(self) -> int:
    """
    Count the pulse parameters held by this node.

    Composite nodes report the combined length of their ``params`` array,
    i.e. the parameters accumulated from all children.

    Returns:
        int: Number of pulse parameters.
    """
    all_params = self.params
    return len(all_params)

__repr__() #

Return repr string (gate name).

Source code in qml_essentials/pulses.py
def __repr__(self) -> str:
    """Represent the node by its gate name."""
    gate_name = self.name
    return gate_name

__str__() #

Return string representation (gate name).

Source code in qml_essentials/pulses.py
def __str__(self) -> str:
    """Render the node as its gate name."""
    label = self.name
    return label

split_params(params=None, leafs=False) #

Split parameters into sub-arrays for children or leaves.

Parameters:

Name Type Description Default
params Optional[ndarray]

Parameters to split. If None, uses internal parameters.

None
leafs bool

If True, splits across leaf nodes; if False, splits across direct children. Defaults to False.

False

Returns:

Type Description
List[ndarray]

List[jnp.ndarray]: List of parameter arrays for children or leaves.

Source code in qml_essentials/pulses.py
def split_params(
    self,
    params: Optional[jnp.ndarray] = None,
    leafs: bool = False,
) -> List[jnp.ndarray]:
    """
    Partition parameters across direct children or across leaf nodes.

    Args:
        params (Optional[jnp.ndarray]): Flat parameter array to partition.
            When None, each target node's own ``params`` are collected
            instead of slicing.
        leafs (bool): Partition across leaf nodes when True, across direct
            children when False. Defaults to False.

    Returns:
        List[jnp.ndarray]: One parameter array per target node. A leaf node
        returns its parameter array unsplit (the internal ``_params`` when
        *params* is None, otherwise *params* itself).
    """
    if self.is_leaf:
        return self._params if params is None else params

    targets = self.leafs if leafs else self.childs

    if params is None:
        return [target.params for target in targets]

    # consecutive slices, each as long as the corresponding node's size
    pieces = []
    start = 0
    for target in targets:
        stop = start + target.size
        pieces.append(params[start:stop])
        start = stop
    return pieces

Model#

from qml_essentials.model import Model

A quantum circuit model.

Source code in qml_essentials/model.py
  21
  22
  23
  24
  25
  26
  27
  28
  29
  30
  31
  32
  33
  34
  35
  36
  37
  38
  39
  40
  41
  42
  43
  44
  45
  46
  47
  48
  49
  50
  51
  52
  53
  54
  55
  56
  57
  58
  59
  60
  61
  62
  63
  64
  65
  66
  67
  68
  69
  70
  71
  72
  73
  74
  75
  76
  77
  78
  79
  80
  81
  82
  83
  84
  85
  86
  87
  88
  89
  90
  91
  92
  93
  94
  95
  96
  97
  98
  99
 100
 101
 102
 103
 104
 105
 106
 107
 108
 109
 110
 111
 112
 113
 114
 115
 116
 117
 118
 119
 120
 121
 122
 123
 124
 125
 126
 127
 128
 129
 130
 131
 132
 133
 134
 135
 136
 137
 138
 139
 140
 141
 142
 143
 144
 145
 146
 147
 148
 149
 150
 151
 152
 153
 154
 155
 156
 157
 158
 159
 160
 161
 162
 163
 164
 165
 166
 167
 168
 169
 170
 171
 172
 173
 174
 175
 176
 177
 178
 179
 180
 181
 182
 183
 184
 185
 186
 187
 188
 189
 190
 191
 192
 193
 194
 195
 196
 197
 198
 199
 200
 201
 202
 203
 204
 205
 206
 207
 208
 209
 210
 211
 212
 213
 214
 215
 216
 217
 218
 219
 220
 221
 222
 223
 224
 225
 226
 227
 228
 229
 230
 231
 232
 233
 234
 235
 236
 237
 238
 239
 240
 241
 242
 243
 244
 245
 246
 247
 248
 249
 250
 251
 252
 253
 254
 255
 256
 257
 258
 259
 260
 261
 262
 263
 264
 265
 266
 267
 268
 269
 270
 271
 272
 273
 274
 275
 276
 277
 278
 279
 280
 281
 282
 283
 284
 285
 286
 287
 288
 289
 290
 291
 292
 293
 294
 295
 296
 297
 298
 299
 300
 301
 302
 303
 304
 305
 306
 307
 308
 309
 310
 311
 312
 313
 314
 315
 316
 317
 318
 319
 320
 321
 322
 323
 324
 325
 326
 327
 328
 329
 330
 331
 332
 333
 334
 335
 336
 337
 338
 339
 340
 341
 342
 343
 344
 345
 346
 347
 348
 349
 350
 351
 352
 353
 354
 355
 356
 357
 358
 359
 360
 361
 362
 363
 364
 365
 366
 367
 368
 369
 370
 371
 372
 373
 374
 375
 376
 377
 378
 379
 380
 381
 382
 383
 384
 385
 386
 387
 388
 389
 390
 391
 392
 393
 394
 395
 396
 397
 398
 399
 400
 401
 402
 403
 404
 405
 406
 407
 408
 409
 410
 411
 412
 413
 414
 415
 416
 417
 418
 419
 420
 421
 422
 423
 424
 425
 426
 427
 428
 429
 430
 431
 432
 433
 434
 435
 436
 437
 438
 439
 440
 441
 442
 443
 444
 445
 446
 447
 448
 449
 450
 451
 452
 453
 454
 455
 456
 457
 458
 459
 460
 461
 462
 463
 464
 465
 466
 467
 468
 469
 470
 471
 472
 473
 474
 475
 476
 477
 478
 479
 480
 481
 482
 483
 484
 485
 486
 487
 488
 489
 490
 491
 492
 493
 494
 495
 496
 497
 498
 499
 500
 501
 502
 503
 504
 505
 506
 507
 508
 509
 510
 511
 512
 513
 514
 515
 516
 517
 518
 519
 520
 521
 522
 523
 524
 525
 526
 527
 528
 529
 530
 531
 532
 533
 534
 535
 536
 537
 538
 539
 540
 541
 542
 543
 544
 545
 546
 547
 548
 549
 550
 551
 552
 553
 554
 555
 556
 557
 558
 559
 560
 561
 562
 563
 564
 565
 566
 567
 568
 569
 570
 571
 572
 573
 574
 575
 576
 577
 578
 579
 580
 581
 582
 583
 584
 585
 586
 587
 588
 589
 590
 591
 592
 593
 594
 595
 596
 597
 598
 599
 600
 601
 602
 603
 604
 605
 606
 607
 608
 609
 610
 611
 612
 613
 614
 615
 616
 617
 618
 619
 620
 621
 622
 623
 624
 625
 626
 627
 628
 629
 630
 631
 632
 633
 634
 635
 636
 637
 638
 639
 640
 641
 642
 643
 644
 645
 646
 647
 648
 649
 650
 651
 652
 653
 654
 655
 656
 657
 658
 659
 660
 661
 662
 663
 664
 665
 666
 667
 668
 669
 670
 671
 672
 673
 674
 675
 676
 677
 678
 679
 680
 681
 682
 683
 684
 685
 686
 687
 688
 689
 690
 691
 692
 693
 694
 695
 696
 697
 698
 699
 700
 701
 702
 703
 704
 705
 706
 707
 708
 709
 710
 711
 712
 713
 714
 715
 716
 717
 718
 719
 720
 721
 722
 723
 724
 725
 726
 727
 728
 729
 730
 731
 732
 733
 734
 735
 736
 737
 738
 739
 740
 741
 742
 743
 744
 745
 746
 747
 748
 749
 750
 751
 752
 753
 754
 755
 756
 757
 758
 759
 760
 761
 762
 763
 764
 765
 766
 767
 768
 769
 770
 771
 772
 773
 774
 775
 776
 777
 778
 779
 780
 781
 782
 783
 784
 785
 786
 787
 788
 789
 790
 791
 792
 793
 794
 795
 796
 797
 798
 799
 800
 801
 802
 803
 804
 805
 806
 807
 808
 809
 810
 811
 812
 813
 814
 815
 816
 817
 818
 819
 820
 821
 822
 823
 824
 825
 826
 827
 828
 829
 830
 831
 832
 833
 834
 835
 836
 837
 838
 839
 840
 841
 842
 843
 844
 845
 846
 847
 848
 849
 850
 851
 852
 853
 854
 855
 856
 857
 858
 859
 860
 861
 862
 863
 864
 865
 866
 867
 868
 869
 870
 871
 872
 873
 874
 875
 876
 877
 878
 879
 880
 881
 882
 883
 884
 885
 886
 887
 888
 889
 890
 891
 892
 893
 894
 895
 896
 897
 898
 899
 900
 901
 902
 903
 904
 905
 906
 907
 908
 909
 910
 911
 912
 913
 914
 915
 916
 917
 918
 919
 920
 921
 922
 923
 924
 925
 926
 927
 928
 929
 930
 931
 932
 933
 934
 935
 936
 937
 938
 939
 940
 941
 942
 943
 944
 945
 946
 947
 948
 949
 950
 951
 952
 953
 954
 955
 956
 957
 958
 959
 960
 961
 962
 963
 964
 965
 966
 967
 968
 969
 970
 971
 972
 973
 974
 975
 976
 977
 978
 979
 980
 981
 982
 983
 984
 985
 986
 987
 988
 989
 990
 991
 992
 993
 994
 995
 996
 997
 998
 999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
1520
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
1553
1554
1555
1556
1557
1558
1559
1560
1561
1562
1563
1564
1565
1566
1567
1568
1569
1570
1571
1572
1573
1574
1575
1576
1577
1578
1579
1580
1581
1582
1583
1584
1585
1586
1587
1588
1589
1590
1591
1592
1593
1594
1595
1596
1597
1598
1599
1600
1601
1602
1603
1604
1605
1606
1607
1608
1609
1610
1611
1612
1613
1614
1615
1616
1617
1618
1619
1620
1621
1622
1623
1624
1625
1626
1627
1628
1629
1630
1631
1632
class Model:
    """
    A quantum circuit model.
    """

    def __init__(
        self,
        n_qubits: int,
        n_layers: int,
        circuit_type: Union[str, Circuit, None] = "No_Ansatz",
        data_reupload: Union[bool, List[List[bool]], List[List[List[bool]]]] = True,
        state_preparation: Union[
            str, Callable, List[Union[str, Callable]], None
        ] = None,
        encoding: Union[Encoding, str, Callable, List[Union[str, Callable]]] = Gates.RX,
        trainable_frequencies: bool = False,
        initialization: str = "random",
        initialization_domain: Optional[List[float]] = None,
        output_qubit: Union[List[int], int] = -1,
        shots: Optional[int] = None,
        random_seed: int = 1000,
        remove_zero_encoding: bool = True,
        repeat_batch_axis: Optional[List[bool]] = None,
        pulse_shape: str = "gaussian",
    ) -> None:
        """
        Initialize the quantum circuit model.
        Parameters will have the shape [impl_n_layers, parameters_per_layer]
        where impl_n_layers is the number of layers provided and added by one
        depending if data_reupload is True and parameters_per_layer is given by
        the chosen ansatz.

        Besides the arguments below, the model starts with:
        - noise_params: None
        - execution_type: "expval"

        Args:
            n_qubits (int): The number of qubits in the circuit.
            n_layers (int): The number of layers in the circuit.
            circuit_type (str, Circuit, None): The type of quantum circuit to
                use, either the name of an attribute of ``Ansaetze`` or a
                Circuit class. None or "" default to "No_Ansatz".
            data_reupload (Union[bool, List[List[bool]], List[List[List[bool]]]],
                optional): Whether to reupload data to the quantum device on
                each layer and qubit. Detailed re-uploading instructions can be
                given as a list/array of 0/False and 1/True with shape
                (n_layers, n_qubits) or (n_layers, n_qubits, n_input_feat) to
                specify where to upload the data. Defaults to True for applying
                data re-uploading to the full circuit.
            state_preparation (Union[str, Callable, List[Union[str, Callable]],
                None], optional): State-preparation gate(s), parsed with
                ``Gates.parse_gates``. Defaults to None.
            encoding (Union[Encoding, str, Callable, List[str], List[Callable]],
                optional): An ``Encoding`` instance, or the unitary to use for
                encoding the input data, which is then wrapped in a "hamming"
                ``Encoding``. Can be a string (e.g. "RX") or a callable
                (e.g. op.RX). Defaults to op.RX. If input is multidimensional
                it is assumed to be a list of unitaries or a list of strings.
            trainable_frequencies (bool, optional):
                Sets trainable encoding parameters for trainable frequencies.
                Defaults to False.
            initialization (str, optional): The strategy to initialize the
                parameters. Can be "random", "zeros", "zero-controlled", "pi",
                or "pi-controlled". Defaults to "random".
            initialization_domain (Optional[List[float]], optional): Parameter
                domain stored for ``initialize_params``. Defaults to
                [0, 2 * pi].
            output_qubit (List[int], int, optional): The index of the output
                qubit (or qubits). When set to -1 all qubits are measured, or a
                global measurement is conducted, depending on the execution
                type.
            shots (Optional[int], optional): The number of shots to use for
                the quantum device. Defaults to None.
            random_seed (int, optional): seed for the random number generator
                in initialization is "random" and for random noise parameters.
                Defaults to 1000.
            remove_zero_encoding (bool, optional): whether to
                remove the zero encoding from the circuit. Defaults to True.
            repeat_batch_axis (Optional[List[bool]], optional): Each boolean in
                the array determines over which axes to parallelise
                computation. The axes correspond to [inputs, params,
                pulse_params]. Defaults to [True, True, True], meaning that
                batching is enabled over all axes.
            pulse_shape (str, optional): Pulse envelope shape for pulse-level
                simulation. One of ``PulseEnvelope.available()``.
                Defaults to ``"gaussian"``.

        Raises:
            ValueError: If ``state_preparation`` cannot be parsed.

        Returns:
            None
        """
        # List defaults are built per call. A list literal in the signature is
        # created once and shared by every Model, so e.g. mutating
        # ``model.repeat_batch_axis`` would leak into all later instances.
        if initialization_domain is None:
            initialization_domain = [0, 2 * jnp.pi]
        if repeat_batch_axis is None:
            repeat_batch_axis = [True, True, True]

        # Initialize default parameters needed for circuit evaluation.
        # Order matters: the output_qubit setter reads n_qubits, and the
        # execution_type setter reads output_qubit and shots.
        self.n_qubits: int = n_qubits
        self.output_qubit: Union[List[int], int] = output_qubit
        self.n_layers: int = n_layers
        self.noise_params: Optional[Dict[str, Union[float, Dict[str, float]]]] = None
        self.shots = shots
        self.remove_zero_encoding = remove_zero_encoding
        self.trainable_frequencies: bool = trainable_frequencies
        self.execution_type: str = "expval"
        self.repeat_batch_axis: List[bool] = repeat_batch_axis

        # --- Pulse envelope ---
        pinfo.set_envelope(pulse_shape)

        # --- State Preparation ---
        try:
            self._sp = Gates.parse_gates(state_preparation, Gates)
        except ValueError as e:
            raise ValueError(f"Error parsing state preparation: {e}") from e

        # prepare corresponding pulse parameters (always optimized pulses)
        self.sp_pulse_params = []
        for sp in self._sp:
            sp_name = sp.__name__ if hasattr(sp, "__name__") else str(sp)
            sp_gate = pinfo.gate_by_name(sp_name)
            # gates without a pulse parametrization get a None placeholder
            self.sp_pulse_params.append(sp_gate.params if sp_gate is not None else None)

        # --- Encoding ---
        if isinstance(encoding, Encoding):
            # user-supplied encoding strategy is used as-is
            self._enc = encoding
        else:
            # use hamming encoding by default
            self._enc = Encoding("hamming", encoding)

        # Number of possible inputs
        self.n_input_feat = len(self._enc)
        log.debug(f"Number of input features: {self.n_input_feat}")

        # Trainable frequencies, default initialization as in arXiv:2309.03279v2
        self.enc_params = jnp.ones((self.n_layers, self.n_qubits, self.n_input_feat))

        self._zero_inputs = False

        # --- Data-Reuploading ---

        # Keep as NumPy array (not JAX) so that ``if data_reupload[q, idx]``
        # in _iec remains a concrete Python bool even under jax.jit tracing.
        # Setting this also updates self.degree, self.frequencies and, in
        # consequence, self.has_dru (it needs self._enc and n_input_feat).
        self.data_reupload = data_reupload

        # check for the highest degree among all input dimensions
        if self.has_dru:
            impl_n_layers: int = n_layers + 1  # we need L+1 according to Schuld et al.
        else:
            impl_n_layers = n_layers
        log.info(f"Number of implicit layers: {impl_n_layers}.")

        # --- Ansatz ---
        # only weak check for str. We trust the user to provide sth useful.
        # None (and "") fall back to "No_Ansatz", as documented above.
        if circuit_type is None or isinstance(circuit_type, str):
            self.pqc: Callable[[Optional[jnp.ndarray], int], int] = getattr(
                Ansaetze, circuit_type or "No_Ansatz"
            )()
        else:
            self.pqc = circuit_type()
        log.info(f"Using Ansatz {circuit_type}.")

        # calculate the shape of the parameter vector here, we will re-use this in init.
        params_per_layer = self.pqc.n_params_per_layer(self.n_qubits)
        self._params_shape: Tuple[int, int] = (impl_n_layers, params_per_layer)
        log.info(f"Parameters per layer: {params_per_layer}")

        pulse_params_per_layer = self.pqc.n_pulse_params_per_layer(self.n_qubits)
        self._pulse_params_shape: Tuple[int, int] = (
            impl_n_layers,
            pulse_params_per_layer,
        )

        # initialize to None as we can't know this yet
        self._batch_shape = None

        # this will also be re-used in the init method,
        # however, only if nothing is provided
        # (attribute name spelling kept as-is; other methods may read it)
        self._inialization_strategy = initialization
        self._initialization_domain = initialization_domain

        # ..here! where we only require a JAX random key
        self.random_key = self.initialize_params(random.key(random_seed))

        # Initializing pulse params
        self.pulse_params: jnp.ndarray = jnp.ones((1, *self._pulse_params_shape))

        log.info(f"Initialized pulse parameters with shape {self.pulse_params.shape}.")

        # Initialise the yaqsi Script that wraps _variational.
        # No device selection needed - yaqsi auto-routes between statevector
        # and density-matrix simulation based on whether noise channels are
        # present on the tape.
        self.script = ys.Script(f=self._variational, n_qubits=self.n_qubits)

    @property
    def noise_params(self) -> Optional[Dict[str, Union[float, Dict[str, float]]]]:
        """
        Noise configuration currently attached to the model.

        Returns:
            Optional[Dict[str, Union[float, Dict[str, float]]]]: Mapping from
            noise-channel name to its parameter, or None when no noise is set
            (the setter also stores None when every value is zero).
        """
        current = self._noise_params
        return current

    @noise_params.setter
    def noise_params(
        self, kvs: Optional[Dict[str, Union[float, Dict[str, float]]]]
    ) -> None:
        """
        Sets the noise parameters of the model.

        Typically a "noise parameter" refers to the error probability.
        ThermalRelaxation is a special case, and supports a dict as value with
        structure:
            "ThermalRelaxation":
            {
                "t1": 2000,  # relative t1 time.
                "t2": 1000,  # relative t2 time
                "t_factor": 1,  # relative gate time factor
            },
        All three entries must be non-zero and t2 <= 2 * t1, otherwise a
        warning is issued and thermal relaxation is disabled (set to 0.0).

        Missing noise types are filled with their defaults (0.0, or None for
        ThermalRelaxation). Unsupported keys trigger a UserWarning but are
        kept. The caller's dictionaries are not modified; a copy is stored.

        Args:
            kvs (Optional[Dict[str, Union[float, Dict[str, float]]]]): A
            dictionary of noise parameters. If all values are 0.0, the noise
            parameters are set to None.

        Returns:
            None
        """
        # set to None if only zero values provided
        if kvs is not None and all(v == 0.0 for v in kvs.values()):
            kvs = None

        if kvs is not None:
            # work on a shallow copy so the caller's dict is left untouched
            kvs = dict(kvs)

            # set default values
            defaults = {
                "BitFlip": 0.0,
                "PhaseFlip": 0.0,
                "Depolarizing": 0.0,
                "MultiQubitDepolarizing": 0.0,
                "AmplitudeDamping": 0.0,
                "PhaseDamping": 0.0,
                "GateError": 0.0,
                "ThermalRelaxation": None,
                "StatePreparation": 0.0,
                "Measurement": 0.0,
            }
            for key, default_val in defaults.items():
                kvs.setdefault(key, default_val)

            # check if there are any keys not supported
            for key in kvs:
                if key not in defaults:
                    warnings.warn(
                        f"Noise type {key} is not supported by this package",
                        UserWarning,
                    )

            # check valid params for thermal relaxation noise channel
            tr_params = kvs["ThermalRelaxation"]
            if isinstance(tr_params, dict):
                # copy the nested dict as well before filling in defaults
                tr_params = dict(tr_params)
                tr_params.setdefault("t1", 0.0)
                tr_params.setdefault("t2", 0.0)
                tr_params.setdefault("t_factor", 0.0)
                kvs["ThermalRelaxation"] = tr_params

                valid_tr_keys = {"t1", "t2", "t_factor"}
                for k in tr_params:
                    if k not in valid_tr_keys:
                        warnings.warn(
                            f"Thermal Relaxation parameter {k} is not supported "
                            f"by this package",
                            UserWarning,
                        )
                # every entry must be non-zero and t2 may not exceed 2 * t1
                if not all(tr_params.values()) or tr_params["t2"] > 2 * tr_params["t1"]:
                    warnings.warn(
                        "Received invalid values for Thermal Relaxation noise "
                        "parameter. Thermal relaxation is not applied!",
                        UserWarning,
                    )
                    kvs["ThermalRelaxation"] = 0.0

        self._noise_params = kvs

    @property
    def output_qubit(self) -> List[int]:
        """Qubit indices that are measured (normalized by the setter)."""
        measured = self._output_qubit
        return measured

    @output_qubit.setter
    def output_qubit(self, value: Union[int, List[int]]) -> None:
        """
        Set the output qubit(s) for measurement.

        An int is normalized to a single-element list, and -1 expands to all
        qubits. Other values (e.g. lists) are stored unchanged.

        Args:
            value: Qubit index or list of indices. Use -1 for all qubits.

        Raises:
            AssertionError: If a list is longer than ``n_qubits`` or a single
                index is not smaller than ``n_qubits``.
        """
        if isinstance(value, list):
            # plain string concatenation: a backslash continuation inside the
            # f-string used to embed a long run of spaces into the message
            assert len(value) <= self.n_qubits, (
                f"Size of output_qubit {len(value)} cannot be "
                f"larger than number of qubits {self.n_qubits}."
            )
        elif isinstance(value, int):
            if value == -1:
                value = list(range(self.n_qubits))
            else:
                assert value < self.n_qubits, (
                    f"Output qubit {value} must be smaller than the number "
                    f"of qubits {self.n_qubits}."
                )
                value = [value]

        self._output_qubit = value

    @property
    def execution_type(self) -> str:
        """
        Measurement mode used when the model is evaluated.

        Returns:
            str: One of "density", "expval", "probs" or "state".
        """
        mode = self._execution_type
        return mode

    @execution_type.setter
    def execution_type(self, value: str) -> None:
        """
        Set the execution type and the matching result shape.

        All validation happens before any attribute is written, so a rejected
        value leaves the model unchanged.

        Args:
            value (str): One of "density", "expval", "probs" or "state".

        Raises:
            ValueError: For an unknown execution type, or for "density" while
                shots are set.
        """
        if value == "density":
            dim = 2 ** len(self.output_qubit)
            result_shape = (dim, dim)
        elif value == "expval":
            # one value per output qubit, both for full measurements and for
            # partial (parity / n-local) ones
            result_shape = (len(self.output_qubit),)
        elif value == "probs":
            # in case this is a list of parities,
            # each pair has 2^len(qubits) probabilities
            result_shape = (
                (2,) * len(self.output_qubit)
                if isinstance(self.output_qubit, (tuple, list))
                else (2,)
            )
        elif value == "state":
            # NOTE(review): the warning below says output_qubit is ignored for
            # "state", yet the shape is derived from it - confirm intent.
            result_shape = (2 ** len(self.output_qubit),)
        else:
            raise ValueError(f"Invalid execution type: {value}.")

        if value == "state" and not self.all_qubit_measurement:
            warnings.warn(
                f"{value} measurement does ignore output_qubit, which is "
                f"{self.output_qubit}.",
                UserWarning,
            )

        if value == "probs" and self.shots is None:
            warnings.warn(
                "Setting execution_type to probs without specifying shots.",
                UserWarning,
            )

        if value == "density" and self.shots is not None:
            raise ValueError("Setting execution_type to density with shots not None.")

        # commit only after every check has passed
        self._result_shape = result_shape
        self._execution_type = value

    @property
    def shots(self) -> Optional[int]:
        """
        Number of measurement shots configured for the simulator.

        Returns:
            Optional[int]: Shot count, or None when no shots are set.
        """
        n_shots = self._shots
        return n_shots

    @shots.setter
    def shots(self, value: Optional[int]) -> None:
        """
        Sets the number of shots to use for the quantum device.

        Args:
            value (Optional[int]): The number of shots.
            If an integer (Python or NumPy integer scalar) less than or equal
            to 0 is provided, it is set to None.

        Returns:
            None
        """
        # NumPy integer scalars are normalized too; they previously slipped
        # through ``type(value) is int``. bool is excluded explicitly because
        # it is a subclass of int.
        if (
            isinstance(value, (int, np.integer))
            and not isinstance(value, bool)
            and value <= 0
        ):
            value = None
        self._shots = value

    @property
    def params(self) -> jnp.ndarray:
        """Variational parameters, stored with a leading batch axis."""
        batched_params = self._params
        return batched_params

    @params.setter
    def params(self, value: jnp.ndarray) -> None:
        """
        Store the variational parameters.

        A 2-D array (layers x parameters per layer) is promoted to a batch of
        one by prepending an axis of size 1; other ranks are kept as given.
        """
        needs_batch_axis = len(value.shape) == 2
        self._params = value.reshape(1, *value.shape) if needs_batch_axis else value

    @property
    def enc_params(self) -> jnp.ndarray:
        """Encoding parameters that scale the input features."""
        encoding_weights = self._enc_params
        return encoding_weights

    @enc_params.setter
    def enc_params(self, new_weights: jnp.ndarray) -> None:
        """Replace the encoding parameters (stored without conversion)."""
        self._enc_params = new_weights

    @property
    def pulse_params(self) -> jnp.ndarray:
        """Pulse parameters used when gates are executed at pulse level."""
        pulses = self._pulse_params
        return pulses

    @pulse_params.setter
    def pulse_params(self, new_pulses: jnp.ndarray) -> None:
        """Replace the pulse parameters (stored without conversion)."""
        self._pulse_params = new_pulses

    @property
    def data_reupload(self) -> jnp.ndarray:
        """
        Boolean re-upload mask as set by the setter: a NumPy bool array of
        shape (n_layers, n_qubits, n_input_feat).
        """
        mask = self._data_reupload
        return mask

    @data_reupload.setter
    def data_reupload(self, value: Union[bool, np.ndarray, List]) -> None:
        """Set the data reupload mask.

        Always converts to a concrete NumPy boolean array of shape
        (n_layers, n_qubits, n_input_feat) so that
        ``if data_reupload[q, idx]`` in :meth:`_iec` remains a plain
        Python ``bool`` even inside JAX-traced functions (jit / grad / vmap).
        Also recomputes ``degree``, ``frequencies`` and ``has_dru``.

        Args:
            value: True for re-uploading everywhere; False for uploading each
                feature only once (layer 0, qubit 0); or an array-like mask of
                shape (n_layers, n_qubits) (broadcast over all input features)
                or (n_layers, n_qubits, n_input_feat).

        Raises:
            AssertionError: If an array-like mask has the wrong shape.
        """
        expected_3d = (self.n_layers, self.n_qubits, self.n_input_feat)

        # Process data reuploading strategy and set degree
        if not isinstance(value, bool):
            if not isinstance(value, np.ndarray):
                value = np.array(value)

            if len(value.shape) == 2:
                # messages built by concatenation; backslash continuations
                # inside f-strings used to embed long runs of spaces
                assert value.shape == (self.n_layers, self.n_qubits), (
                    "Data reuploading array has wrong shape. "
                    f"Expected {(self.n_layers, self.n_qubits)} or "
                    f"{expected_3d}, got {value.shape}."
                )
                # repeat the per-(layer, qubit) mask for every input feature
                value = value.reshape(*value.shape, 1)
                value = np.repeat(value, self.n_input_feat, axis=2)

            assert value.shape == expected_3d, (
                "Data reuploading array has wrong shape. "
                f"Expected {expected_3d}, got {value.shape}."
            )

            log.debug(f"Data reuploading array:\n{value}")
        else:
            if value:
                value = np.ones((self.n_layers, self.n_qubits, self.n_input_feat))
                log.debug("Full data reuploading.")
            else:
                # each feature is uploaded exactly once: layer 0, qubit 0
                value = np.zeros((self.n_layers, self.n_qubits, self.n_input_feat))
                value[0][0] = 1
                log.debug("No data reuploading.")

        # convert to boolean values
        self._data_reupload = np.asarray(value).astype(bool)

        # per-feature number of uploads determines degree and spectrum
        self.degree: Tuple = tuple(
            self._enc.get_n_freqs(np.count_nonzero(self.data_reupload[..., i]))
            for i in range(self.n_input_feat)
        )

        self.frequencies: Tuple = tuple(
            self._enc.get_spectrum(np.count_nonzero(self.data_reupload[..., i]))
            for i in range(self.n_input_feat)
        )

        # Cache has_dru as a plain Python bool so that it can be used in
        # Python ``if`` statements even inside JAX-traced functions.
        self._has_dru: bool = bool(max(int(np.max(f)) for f in self._frequencies) > 1)

    @property
    def degree(self) -> Tuple:
        """Per-input-feature degree, as computed by the data_reupload setter."""
        per_feature_degree = self._degree
        return per_feature_degree

    @degree.setter
    def degree(self, new_degree: Tuple):
        """Store the per-feature degree tuple as-is."""
        self._degree = new_degree

    @property
    def frequencies(self) -> Tuple:
        """Per-input-feature frequency spectrum of the model."""
        spectra = self._frequencies
        return spectra

    @frequencies.setter
    def frequencies(self, new_spectra: Tuple):
        """Store the per-feature frequency spectra as-is."""
        self._frequencies = new_spectra

    @property
    def has_dru(self) -> bool:
        """Whether the model uses data re-uploading.

        Returns the cached plain-Python flag ``self._has_dru`` so callers can
        branch on it with an ordinary ``if``, even inside traced code.
        """
        flag = self._has_dru
        return flag

    @property
    def all_qubit_measurement(self) -> bool:
        """Whether ``output_qubit`` is exactly ``[0, 1, ..., n_qubits - 1]``."""
        every_qubit = list(range(self.n_qubits))
        return self.output_qubit == every_qubit

    @property
    def batch_shape(self) -> Tuple[int, ...]:
        """
        Batch shape ``(B_I, B_P, B_R)`` recorded by the last batch alignment.

        Returns:
            Tuple[int, ...]: ``(input_batch, param_batch, pulse_batch)``, or
                ``(1, 1, 1)`` (with a debug log message) when no batch shape
                has been recorded yet, i.e. the model was not called.
        """
        recorded = self._batch_shape
        if recorded is not None:
            return recorded
        log.debug("Model was not called yet. Returning (1,1,1) as batch shape.")
        return (1, 1, 1)

    @property
    def eff_batch_shape(self) -> Tuple[int, ...]:
        """
        Batch dimensions that are actually repeated.

        ``batch_shape`` is multiplied element-wise by ``repeat_batch_axis``.
        Entries that become zero are then dropped.

        Returns:
            numpy array of the remaining (non-zero) batch sizes. Despite the
            annotation this is an ``np.ndarray``, not a tuple.
        """
        masked = np.array(self.batch_shape) * self.repeat_batch_axis
        keep = masked != 0
        return masked[keep]

    def initialize_params(
        self,
        random_key: Optional[random.PRNGKey] = None,
        repeat: int = 1,
        initialization: Optional[str] = None,
        initialization_domain: Optional[List[float]] = None,
    ) -> random.PRNGKey:
        """
        Initialize the variational parameters of the model.

        The result is stored in ``self.params`` with shape
        ``(repeat, *self._params_shape)``.

        Args:
            random_key (Optional[random.PRNGKey]): JAX random key for initialization.
                If None, uses the model's internal random key.
            repeat (int): Number of parameter sets to create (batch dimension).
                Defaults to 1.
            initialization (Optional[str]): Strategy for parameter initialization.
                Options: "random", "zeros", "pi", "zero-controlled", "pi-controlled".
                The "*-controlled" strategies draw uniform random values and then
                overwrite the parameters selected by
                ``self.pqc.get_control_indices`` with 0 or pi respectively.
                If None, uses the strategy specified in the constructor.
            initialization_domain (Optional[List[float]]): Domain [min, max] for
                random initialization. If None, uses the domain from constructor.

        Returns:
            random.PRNGKey: Updated random key after initialization.

        Raises:
            ValueError: If an invalid initialization method is specified
                (``ValueError`` subclasses ``Exception``, so existing
                ``except Exception`` handlers keep working).
        """
        params_shape = (repeat, *self._params_shape)

        # use existing strategy / domain if not specified
        initialization = initialization or self._inialization_strategy
        initialization_domain = initialization_domain or self._initialization_domain

        random_key, sub_key = safe_random_split(
            random_key if random_key is not None else self.random_key
        )

        # strategies that start from uniform random values
        random_strategies = ("random", "zero-controlled", "pi-controlled")
        # value written into the controlled-rotation parameters afterwards
        control_values = {"zero-controlled": 0, "pi-controlled": jnp.pi}

        def set_control_params(params: jnp.ndarray, value: float) -> jnp.ndarray:
            """Overwrite the controlled-rotation parameters of ``params`` with ``value``."""
            indices = self.pqc.get_control_indices(self.n_qubits)
            if indices is None:
                warnings.warn(
                    f"Specified {initialization} but circuit does not contain "
                    "controlled rotation gates. Parameters are initialized randomly.",
                    UserWarning,
                )
                return params
            # indices are used as (start, stop, step) on the last axis; go
            # through numpy because JAX arrays do not support item assignment
            control = slice(indices[0], indices[1], indices[2])
            np_params = np.array(params)
            np_params[:, :, control] = value
            return jnp.array(np_params)

        if initialization in random_strategies:
            self.params: jnp.ndarray = random.uniform(
                sub_key,
                params_shape,
                minval=initialization_domain[0],
                maxval=initialization_domain[1],
            )
            if initialization in control_values:
                self.params = set_control_params(
                    self.params, control_values[initialization]
                )
        elif initialization == "zeros":
            self.params = jnp.zeros(params_shape)
        elif initialization == "pi":
            self.params = jnp.ones(params_shape) * jnp.pi
        else:
            raise ValueError(
                f"Invalid initialization method '{initialization}'. Expected one of "
                "'random', 'zeros', 'pi', 'zero-controlled', 'pi-controlled'."
            )

        log.info(
            f"Initialized parameters with shape {self.params.shape} "
            f"using strategy {initialization}."
        )

        return random_key

    def transform_input(
        self, inputs: jnp.ndarray, enc_params: jnp.ndarray
    ) -> jnp.ndarray:
        """
        Scale the inputs by the encoding parameters.

        This is the linear input transformation of arXiv:2309.03279v2. The
        scaled values are what the encoding gates receive.

        Args:
            inputs (jnp.ndarray): Input feature value(s); any shape that
                broadcasts against ``enc_params``.
            enc_params (jnp.ndarray): Encoding weight scalar or array.

        Returns:
            jnp.ndarray: ``inputs * enc_params`` (element-wise product).
        """
        scaled = inputs * enc_params
        return scaled

    def _iec(
        self,
        inputs: jnp.ndarray,
        data_reupload: jnp.ndarray,
        enc: Encoding,
        enc_params: jnp.ndarray,
        noise_params: Optional[Dict[str, Union[float, Dict[str, float]]]] = None,
        random_key: Optional[random.PRNGKey] = None,
    ) -> None:
        """
        Apply the Input Encoding Circuit (IEC).

        For every qubit ``q`` and feature ``f`` whose ``data_reupload[q, f]``
        entry is truthy, the gate ``enc[f]`` is applied to wire ``q``. The gate
        receives the input feature scaled by ``enc_params[q, f]``.

        The whole layer is skipped when ``remove_zero_encoding`` is set, the
        last validated inputs were all zero, and the input batch size is 1.

        Args:
            inputs (jnp.ndarray): Input data; the last axis indexes features.
            data_reupload (jnp.ndarray): Mask indexed as ``[qubit, feature]``
                selecting where encoding gates are placed.
            enc (Encoding): Encoding whose ``enc[feature]`` gate is applied.
            enc_params (jnp.ndarray): Scaling weights indexed as
                ``[qubit, feature]``.
            noise_params (Optional[Dict[str, Union[float, Dict[str, float]]]]):
                Gate-level noise parameters forwarded to each gate.
            random_key (Optional[random.PRNGKey]): Key that is split once per
                applied gate.

        Returns:
            None: Gates are recorded onto the active circuit.
        """
        # inputs cannot be None here (validated earlier), so only the
        # all-zero case can be short-circuited
        skip_encoding = (
            self.remove_zero_encoding
            and self._zero_inputs
            and self.batch_shape[0] == 1
        )
        if skip_encoding:
            return

        n_features = inputs.shape[-1]
        for wire in range(self.n_qubits):
            for feature in range(n_features):
                if not data_reupload[wire, feature]:
                    continue
                random_key, sub_key = safe_random_split(random_key)
                # index only the last (feature) axis with an ellipsis: the
                # inputs themselves are not qubit dependent
                scaled = self.transform_input(
                    inputs[..., feature], enc_params[wire, feature]
                )
                enc[feature](
                    scaled,
                    wires=wire,
                    noise_params=noise_params,
                    random_key=sub_key,
                    input_idx=feature,
                )

    def _variational(
        self,
        params: jnp.ndarray,
        inputs: jnp.ndarray,
        pulse_params: Optional[jnp.ndarray] = None,
        random_key: Optional[random.PRNGKey] = None,
        enc_params: Optional[jnp.ndarray] = None,
        gate_mode: str = "unitary",
        noise_params: Optional[Dict[str, Union[float, Dict[str, float]]]] = None,
    ) -> None:
        """
        Build the variational quantum circuit structure.

        Constructs the circuit by applying state preparation, alternating
        variational ansatz layers with input encoding layers, and optional
        noise channels.

        The first five parameters (after ``self``) - ``params``, ``inputs``,
        ``pulse_params``, ``random_key``, ``enc_params`` - are the batchable
        positional arguments.
        The remaining keyword arguments are broadcast across the batch.

        Args:
            params (jnp.ndarray): Variational parameters indexed per layer as
                ``params[layer]``. When ``has_dru`` is true, an extra row
                ``params[n_layers]`` is used for the final ansatz layer.
                A leading batch axis of size 1 is squeezed.
            inputs (jnp.ndarray): Input data of shape (n_input_feat,); a leading
                batch axis of size 1 is squeezed.
            pulse_params (Optional[jnp.ndarray]): Pulse parameter scalers indexed
                per layer as ``pulse_params[layer]``. The final ansatz layer
                uses ``pulse_params[-1]``. Defaults to None (uses the model's
                pulse_params).
            random_key (Optional[random.PRNGKey]): JAX random key for stochastic
                operations. Defaults to None.
            enc_params (Optional[jnp.ndarray]): Encoding parameters indexed as
                ``enc_params[layer][qubit, feature]``. Defaults to None (uses
                the model's enc_params).
            gate_mode (str): Gate execution mode, either "unitary" or "pulse".
                Defaults to "unitary".
            noise_params (Optional[Dict[str, Union[float, Dict[str, float]]]]):
                Noise parameters for simulation. Defaults to None (falls back to
                the model's noise_params if those are set).

        Returns:
            None: Gates are applied in-place to the quantum circuit.

        Note:
            Issues RuntimeWarning if called directly without providing parameters
            that would normally be passed through the forward method.
        """
        # TODO: rework and double check params shape
        # squeeze a singleton leading batch axis (batch-first convention)
        if len(params.shape) > 2 and params.shape[0] == 1:
            params = params[0]

        if len(inputs.shape) > 1 and inputs.shape[0] == 1:
            inputs = inputs[0]

        if enc_params is None:
            # only warn when the encoding parameters actually matter, i.e.
            # when frequencies are trainable
            if self.trainable_frequencies:
                warnings.warn(
                    "Explicit call to `_circuit` or `_variational` detected: "
                    "`enc_params` is None, using `self.enc_params` instead.",
                    RuntimeWarning,
                )
            enc_params = self.enc_params

        if pulse_params is None:
            if gate_mode == "pulse":
                warnings.warn(
                    "Explicit call to `_circuit` or `_variational` detected: "
                    "`pulse_params` is None, using `self.pulse_params` instead.",
                    RuntimeWarning,
                )
            pulse_params = self.pulse_params

        # Squeeze batch dimension for pulse_params (batch-first convention)
        if len(pulse_params.shape) > 2 and pulse_params.shape[0] == 1:
            pulse_params = pulse_params[0]

        if noise_params is None:
            if self.noise_params is not None:
                warnings.warn(
                    "Explicit call to `_circuit` or `_variational` detected: "
                    "`noise_params` is None, using `self.noise_params` instead.",
                    RuntimeWarning,
                )
                noise_params = self.noise_params

        if noise_params is not None:
            if random_key is None:
                # the message must name the key that is actually used
                warnings.warn(
                    "Explicit call to `_circuit` or `_variational` detected: "
                    "`random_key` is None, using `self.random_key` instead.",
                    RuntimeWarning,
                )
                random_key = self.random_key
            self._apply_state_prep_noise(noise_params=noise_params)

        # state preparation
        for q in range(self.n_qubits):
            for _sp, sp_pulse_params in zip(self._sp, self.sp_pulse_params):
                random_key, sub_key = safe_random_split(random_key)
                _sp(
                    wires=q,
                    pulse_params=sp_pulse_params,
                    noise_params=noise_params,
                    random_key=sub_key,
                    gate_mode=gate_mode,
                )

        # circuit building
        for layer in range(0, self.n_layers):
            random_key, sub_key = safe_random_split(random_key)
            # ansatz layers
            self.pqc(
                params[layer],
                self.n_qubits,
                pulse_params=pulse_params[layer],
                noise_params=noise_params,
                random_key=sub_key,
                gate_mode=gate_mode,
            )

            random_key, sub_key = safe_random_split(random_key)
            # encoding layers
            self._iec(
                inputs,
                data_reupload=self.data_reupload[layer],
                enc=self._enc,
                enc_params=enc_params[layer],
                noise_params=noise_params,
                random_key=sub_key,
            )

        # final ansatz layer
        if self.has_dru:  # same check as in init
            random_key, sub_key = safe_random_split(random_key)
            self.pqc(
                params[self.n_layers],
                self.n_qubits,
                pulse_params=pulse_params[-1],
                noise_params=noise_params,
                random_key=sub_key,
                gate_mode=gate_mode,
            )

        # channel noise
        if noise_params is not None:
            self._apply_general_noise(noise_params=noise_params)

    def _build_obs(self) -> Tuple[str, List[op.Operation]]:
        """Map ``execution_type`` / ``output_qubit`` to yaqsi measurement args.

        The result is suitable for
        :meth:`~qml_essentials.yaqsi.Script.execute`.

        Returns:
            Tuple ``(meas_type, obs)``:

            * ``"expval"``: one observable per ``output_qubit`` entry. An int
              entry gives ``PauliZ`` on that wire; any other entry is treated
              as a collection of wires and gives a parity observable
              (Z \\otimes Z \\otimes …).
            * ``"density"``, ``"state"``, ``"probs"``: an empty observable
              list. Probabilities cover the full system; any subsystem
              marginalisation has to happen after execution.

        Raises:
            ValueError: If ``execution_type`` is none of the above.
        """
        mode = self.execution_type

        if mode == "expval":
            observables: List[op.Operation] = [
                op.PauliZ(wires=spec)
                if isinstance(spec, int)
                else ys.build_parity_observable(list(spec))
                for spec in self.output_qubit
            ]
            return "expval", observables

        for observable_free in ("density", "state", "probs"):
            if mode == observable_free:
                return observable_free, []

        raise ValueError(f"Invalid execution_type: {self.execution_type}.")

    def _apply_state_prep_noise(
        self, noise_params: Dict[str, Union[float, Dict[str, float]]]
    ) -> None:
        """
        Model imperfect state preparation with a BitFlip on every qubit.

        Args:
            noise_params (Dict[str, Union[float, Dict[str, float]]]): Noise
                configuration; the "StatePreparation" entry (default 0.0) is the
                BitFlip probability. Nothing is applied unless it is > 0.

        Returns:
            None: Channels are recorded onto the active circuit.
        """
        flip_prob = noise_params.get("StatePreparation", 0.0)
        if not (flip_prob > 0):
            return
        for wire in range(self.n_qubits):
            op.BitFlip(flip_prob, wires=wire)

    def _apply_general_noise(
        self, noise_params: Dict[str, Union[float, Dict[str, float]]]
    ) -> None:
        """
        Append decoherence / readout channels to every qubit.

        For each qubit, the channels are applied in this order when enabled:
        AmplitudeDamping, PhaseDamping, BitFlip (measurement error), then
        ThermalRelaxationError.

        Args:
            noise_params (Dict[str, Union[float, Dict[str, float]]]): Noise
                configuration with the keys:
                - "AmplitudeDamping" (float): amplitude damping probability.
                - "PhaseDamping" (float): phase damping probability.
                - "Measurement" (float): BitFlip probability modelling readout error.
                - "ThermalRelaxation" (Dict): must contain "t1", "t2" and
                  "t_factor"; the gate time is ``circuit depth * t_factor``.
                Float entries default to 0.0 and are skipped unless > 0; thermal
                relaxation is applied only when its entry is a dict.

        Returns:
            None: Channels are recorded onto the active circuit.

        Note:
            Gate-level noise (e.g., GateError) is handled separately in the
            Gates.Noise module and applied at the individual gate level.
        """
        amplitude = noise_params.get("AmplitudeDamping", 0.0)
        dephasing = noise_params.get("PhaseDamping", 0.0)
        thermal = noise_params.get("ThermalRelaxation", 0.0)
        readout = noise_params.get("Measurement", 0.0)
        thermal_enabled = isinstance(thermal, dict)

        for wire in range(self.n_qubits):
            if amplitude > 0:
                op.AmplitudeDamping(amplitude, wires=wire)
            if dephasing > 0:
                op.PhaseDamping(dephasing, wires=wire)
            if readout > 0:
                op.BitFlip(readout, wires=wire)
            if not thermal_enabled:
                continue
            t1, t2, t_factor = thermal["t1"], thermal["t2"], thermal["t_factor"]
            # depth is cached after the first call in _get_circuit_depth
            gate_time = self._get_circuit_depth() * t_factor
            op.ThermalRelaxationError(1.0, t1, t2, gate_time, wire)

    def _get_circuit_depth(self, inputs: Optional[jnp.ndarray] = None) -> int:
        """
        Calculate the depth of the quantum circuit.

        Records the circuit onto a tape (without noise) and computes the
        depth as the length of the critical path: each gate is scheduled
        at the earliest time step after all of its qubits are free.

        Args:
            inputs (Optional[jnp.ndarray]): Input data for circuit evaluation.
                If None, default zero inputs are used.

        Returns:
            int: The circuit depth (longest path of gates in the circuit).

        Note:
            The first computed depth is cached in ``self._cached_circuit_depth``
            and returned on later calls regardless of ``inputs``.
            ``self._noise_params`` and ``self._zero_inputs`` are restored on
            exit, even if recording raises.
        """
        # Return cached value if available
        if hasattr(self, "_cached_circuit_depth"):
            return self._cached_circuit_depth

        # _inputs_validation overwrites self._zero_inputs (read by _iec), and
        # this method can run while the caller's own circuit is being built;
        # remember the caller's value so it can be restored.
        saved_zero_inputs = getattr(self, "_zero_inputs", False)
        # Temporarily clear noise_params to prevent _variational from
        # picking them up (which would call _apply_general_noise ->
        # _get_circuit_depth again, causing infinite recursion).
        saved_noise = self._noise_params
        try:
            inputs = self._inputs_validation(inputs)
            self._noise_params = None

            with recording() as tape:
                self._variational(
                    self.params[0] if self.params.ndim == 3 else self.params,
                    inputs[0] if inputs.ndim == 2 else inputs,
                    noise_params=None,
                )
        finally:
            # never leave the model silently noise-free after an exception
            self._noise_params = saved_noise
            self._zero_inputs = saved_zero_inputs

        # Filter out noise channels - only count unitary gates
        ops = [o for o in tape if not isinstance(o, KrausChannel)]

        if not ops:
            self._cached_circuit_depth = 0
            return 0

        # Schedule each gate at the earliest time step where all its wires
        # are free.  ``wire_busy[q]`` tracks the next free time step for
        # qubit ``q``.
        wire_busy: Dict[int, int] = {}
        depth = 0
        for gate in ops:
            start = max((wire_busy.get(w, 0) for w in gate.wires), default=0)
            end = start + 1
            for w in gate.wires:
                wire_busy[w] = end
            depth = max(depth, end)

        self._cached_circuit_depth = depth
        return depth

    def draw(
        self,
        inputs: Optional[jnp.ndarray] = None,
        figure: str = "text",
        **kwargs: Any,
    ) -> Union[str, Any]:
        """Visualize the quantum circuit.

        Records the circuit tape (without noise) and renders the gate
        sequence using the requested backend.

        Args:
            inputs (Optional[jnp.ndarray]): Input data for the circuit.
                If ``None``, default zero inputs are used.
            figure (str): Rendering backend.  One of:

                * ``"text"``  - ASCII art (returned as a ``str``).
                * ``"mpl"``   - Matplotlib figure (returns ``(fig, ax)``).
                * ``"tikz"``  - LaTeX/TikZ ``quantikz`` code (returns a
                  :class:`TikzFigure`).
                * ``"pulse"`` - Pulse schedule (returns ``(fig, axes)``).
                  Only meaningful for pulse-mode models.

            **kwargs: Extra options forwarded to the drawing backend
                (e.g. ``gate_values=True``).

        Returns:
            Depends on figure:

            * ``"text"``  -> ``str``
            * ``"mpl"``   -> ``(matplotlib.figure.Figure, matplotlib.axes.Axes)``
            * ``"tikz"``  -> :class:`TikzFigure`
            * ``"pulse"`` -> ``(fig, axes)`` as returned by :meth:`draw_pulse`

        Raises:
            ValueError: If figure is not one of the supported modes.
        """
        # Delegate before validating so inputs are validated (and any
        # replication warning is emitted) only once, inside draw_pulse.
        if figure == "pulse":
            return self.draw_pulse(inputs=inputs, **kwargs)

        inputs = self._inputs_validation(inputs)
        params = self.params[0] if self.params.ndim == 3 else self.params
        inp = inputs[0] if inputs.ndim == 2 else inputs

        # Record without noise to get a clean circuit; restore the noise
        # configuration even if drawing raises (e.g. unsupported figure).
        saved_noise = self._noise_params
        self._noise_params = None
        try:
            draw_script = ys.Script(f=self._variational, n_qubits=self.n_qubits)
            return draw_script.draw(
                figure=figure,
                args=(params, inp),
                kwargs={"noise_params": None},
                **kwargs,
            )
        finally:
            self._noise_params = saved_noise

    def draw_pulse(
        self,
        inputs: Optional[jnp.ndarray] = None,
        **kwargs: Any,
    ) -> Any:
        """Visualize the pulse schedule for the circuit.

        Records the circuit in pulse mode and collects PulseEvents
        automatically via the pulse-event tape, then renders them.

        Like :meth:`draw`, the circuit is recorded without noise. The model's
        noise configuration is cleared for the duration of the call and
        restored afterwards. The first batch entry of ``params`` and
        ``pulse_params`` is used; the pulse parameters are passed explicitly,
        so no "explicit call" warning is raised.

        Args:
            inputs: Input data.  If ``None``, default zero inputs are used.
            **kwargs: Forwarded to
                :func:`~qml_essentials.drawing.draw_pulse_schedule`
                (e.g. ``show_carrier=True``, ``n_samples=300``).

        Returns:
            ``(fig, axes)`` — Matplotlib Figure and array of Axes.
        """
        inputs = self._inputs_validation(inputs)
        params = self.params[0] if self.params.ndim == 3 else self.params
        inp = inputs[0] if inputs.ndim == 2 else inputs
        pulse_params = self.pulse_params
        if pulse_params.ndim == 3:
            # same batch handling as for params above
            pulse_params = pulse_params[0]

        # Record without noise (consistent with draw); otherwise _variational
        # would fall back to self.noise_params and draw noise channels.
        saved_noise = self._noise_params
        self._noise_params = None
        try:
            draw_script = ys.Script(f=self._variational, n_qubits=self.n_qubits)
            return draw_script.draw(
                figure="pulse",
                args=(params, inp, pulse_params),
                kwargs={
                    "gate_mode": "pulse",
                    "noise_params": None,
                },
                **kwargs,
            )
        finally:
            self._noise_params = saved_noise

    def __repr__(self) -> str:
        """Render the circuit as ASCII text via ``draw(figure="text")``."""
        rendered = self.draw(figure="text")
        return rendered

    def __str__(self) -> str:
        """Render the circuit as ASCII text via ``draw(figure="text")``."""
        rendered = self.draw(figure="text")
        return rendered

    def _params_validation(self, params: Optional[jnp.ndarray]) -> jnp.ndarray:
        """
        Validate and normalize variational parameters.

        A 2-D ``params`` gets a leading batch axis. Any non-None ``params`` is
        also stored in ``self.params``.

        Args:
            params (Optional[jnp.ndarray]): Variational parameters to validate.
                If None, returns the model's current parameters.

        Returns:
            jnp.ndarray: Parameters with a leading batch axis, i.e.
                (batch_size, n_layers, n_params_per_layer) for 2-D/3-D input.
        """
        if params is None:
            return self.params

        # append batch axis if not provided; jnp (not np) keeps this
        # traceable under jax.grad / jax.jit, matching _pulse_params_validation
        if len(params.shape) == 2:
            params = jnp.expand_dims(params, axis=0)

        self.params = params
        return params

    def _pulse_params_validation(
        self, pulse_params: Optional[jnp.ndarray]
    ) -> jnp.ndarray:
        """
        Validate and normalize pulse parameters.

        ``None`` yields the model's current ``pulse_params``. Otherwise a 2-D
        array gets a leading batch axis (batch-first convention) and the
        result is stored in ``self.pulse_params``.

        Args:
            pulse_params (Optional[jnp.ndarray]): Pulse parameter scalers.

        Returns:
            jnp.ndarray: Pulse parameters with a leading batch axis, i.e.
                (batch_size, n_layers, n_pulse_params_per_layer).
        """
        if pulse_params is None:
            return self.pulse_params

        if len(pulse_params.shape) == 2:
            pulse_params = jnp.expand_dims(pulse_params, axis=0)
        self.pulse_params = pulse_params
        return pulse_params

    def _enc_params_validation(self, enc_params: Optional[jnp.ndarray]) -> jnp.ndarray:
        """
        Validate and normalize encoding parameters.

        A non-None ``enc_params`` is stored on the model: as given when
        ``trainable_frequencies`` is set, otherwise converted with
        ``jnp.array``. The returned array is reshaped only in one case: a 1-D
        array with a single input feature gets a trailing axis of size 1.
        Other shapes are passed through unchanged.

        Args:
            enc_params (Optional[jnp.ndarray]): Encoding parameters to validate.
                If None, the model's current encoding parameters are used.

        Returns:
            jnp.ndarray: The (possibly reshaped) encoding parameters.

        Raises:
            ValueError: If ``enc_params`` is 1-D while ``n_input_feat > 1``.
        """
        if enc_params is not None:
            self.enc_params = (
                enc_params if self.trainable_frequencies else jnp.array(enc_params)
            )
        else:
            enc_params = self.enc_params

        is_vector = len(enc_params.shape) == 1
        if is_vector and self.n_input_feat == 1:
            enc_params = enc_params.reshape(-1, 1)
        elif is_vector and self.n_input_feat > 1:
            raise ValueError(
                f"Input dimension {self.n_input_feat} >1 but \
                `enc_params` has shape {enc_params.shape}"
            )

        return enc_params

    def _inputs_validation(
        self, inputs: Union[None, List, float, int, jnp.ndarray]
    ) -> jnp.ndarray:
        """
        Validate and normalize input data.

        Converts various input formats to a standardized 2D array shape
        suitable for batch processing in the quantum circuit. Also sets
        ``self._zero_inputs`` to whether all inputs are zero.

        Args:
            inputs (Union[None, List, float, int, jnp.ndarray]): Input data in
                various formats:
                - None: Returns zeros with shape (1, n_input_feat)
                - float/int (or a 0-d array / NumPy scalar): Single scalar value
                - list/tuple: Values or batched inputs (stacked with ``np.stack``)
                - jnp.ndarray: NumPy/JAX array

        Returns:
            jnp.ndarray: Validated inputs with shape (batch_size, n_input_feat).

        Raises:
            ValueError: If input shape is incompatible with expected n_input_feat.

        Warns:
            UserWarning: If input is replicated to match n_input_feat.
        """
        self._zero_inputs = False
        if isinstance(inputs, (list, tuple)):
            inputs = jnp.array(np.stack(inputs))
        elif isinstance(inputs, (float, int)):
            inputs = jnp.array([inputs])
        elif inputs is None:
            inputs = jnp.array([[0] * self.n_input_feat])

        # 0-d arrays / NumPy scalars are treated like Python scalars; without
        # this, ``inputs.shape[0]`` below raises IndexError for n_input_feat > 1
        if inputs.ndim == 0:
            inputs = inputs.reshape(1)

        if not inputs.any():
            self._zero_inputs = True

        if len(inputs.shape) <= 1:
            if self.n_input_feat == 1:
                # add a batch dimension
                inputs = inputs.reshape(-1, 1)
            else:
                if inputs.shape[0] == self.n_input_feat:
                    inputs = inputs.reshape(1, -1)
                else:
                    # treat each value as one sample and copy it to every feature
                    inputs = inputs.reshape(-1, 1)
                    inputs = inputs.repeat(self.n_input_feat, axis=1)
                    warnings.warn(
                        f"Expected {self.n_input_feat} inputs, but {inputs.shape[0]} "
                        "was provided, replicating input for all input features.",
                        UserWarning,
                    )
        else:
            if inputs.shape[1] != self.n_input_feat:
                raise ValueError(
                    f"Wrong number of inputs provided. Expected {self.n_input_feat} "
                    f"inputs, but input has shape {inputs.shape}."
                )

        return inputs

    def _postprocess_res(self, result: Union[List, jnp.ndarray]) -> jnp.ndarray:
        """
        Bring raw circuit output into a uniform array layout.

        Non-list results are returned untouched. A list (one entry per
        measurement) is stacked along a new leading axis. When the stack has
        more than one dimension, that axis is moved to position 1 so the batch
        axis comes first.

        Args:
            result (Union[List, jnp.ndarray]): Raw circuit output.

        Returns:
            jnp.ndarray: The (possibly stacked and reordered) result.
        """
        if not isinstance(result, list):
            return result

        stacked = jnp.stack(result)
        # moveaxis instead of a transpose: parity measurements append an extra
        # trailing dimension that a full transpose would move to the wrong place
        return jnp.moveaxis(stacked, 0, 1) if len(stacked.shape) > 1 else stacked

    def _assimilate_batch(
        self,
        inputs: jnp.ndarray,
        params: jnp.ndarray,
        pulse_params: jnp.ndarray,
    ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
        """
        Align batch dimensions across inputs, parameters, and pulse parameters.

        Broadcasts and reshapes arrays to have compatible batch dimensions
        for vectorized circuit execution. Sets the internal batch_shape.

        Args:
            inputs (jnp.ndarray): Input data of shape (B_I, n_input_feat).
            params (jnp.ndarray): Parameters of shape (B_P, n_layers, n_params).
            pulse_params (jnp.ndarray): Pulse params of shape (B_R, n_layers, n_pulse).

        Returns:
            Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: Tuple containing:
                - inputs: Reshaped to (B, n_input_feat)
                - params: Reshaped to (B, n_layers, n_params)
                - pulse_params: Reshaped to (B, n_layers, n_pulse)
                where B is the product of ``eff_batch_shape`` (only the axes
                enabled in ``repeat_batch_axis`` contribute). An array whose
                own batch size is 1, or whose axis is disabled, is returned
                unchanged.

        Note:
            The effective batch shape depends on repeat_batch_axis configuration.
            This is the only method that sets self._batch_shape.
        """
        B_I = inputs.shape[0]
        # params may be empty (a zero-sized axis); in that case B_P is 1
        B_P = 1 if 0 in params.shape else params.shape[0]
        B_R = pulse_params.shape[0]

        # THIS is the only place where we set the batch shape
        self._batch_shape = (B_I, B_P, B_R)
        B = np.prod(self.eff_batch_shape)

        # [B_I, ...] -> [B_I, B_P, B_R, ...] -> [B, ...]
        if B_I > 1 and self.repeat_batch_axis[0]:
            # Always add the B_P / B_R axes (as done for params and
            # pulse_params below). Otherwise, with repeat_batch_axis[1]
            # disabled, the axis=2 repeat fails on a 2-D array and the
            # shape[3:] reshape drops the feature axis.
            inputs = inputs[:, None, None, ...]  # [B_I, B_P(=1), B_R(=1), ...]
            if self.repeat_batch_axis[1]:
                inputs = jnp.repeat(inputs, B_P, axis=1)  # [B_I, B_P, 1, ...]
            if self.repeat_batch_axis[2]:
                inputs = jnp.repeat(inputs, B_R, axis=2)  # [B_I, B_P, B_R, ...]
            inputs = inputs.reshape(B, *inputs.shape[3:])

        # [B_P, ..., ...] -> [B_I, B_P, B_R, ..., ...] -> [B, ..., ...]
        if B_P > 1 and self.repeat_batch_axis[1]:
            # add B_I axis before first, and B_R axis after first batch dim
            params = params[None, :, None, ...]  # [B_I(=1), B_P, B_R(=1), ...]
            if self.repeat_batch_axis[0]:
                params = jnp.repeat(params, B_I, axis=0)  # [B_I, B_P, 1, ...]
            if self.repeat_batch_axis[2]:
                params = jnp.repeat(params, B_R, axis=2)  # [B_I, B_P, B_R, ...]
            params = params.reshape(B, *params.shape[3:])

        # [B_R, ..., ...] -> [B_I, B_P, B_R, ..., ...] -> [B, ..., ...]
        if B_R > 1 and self.repeat_batch_axis[2]:
            # add B_I axis and B_P axis before B_R
            pulse_params = pulse_params[None, None, ...]  # [B_I(=1), B_P(=1), B_R, ...]
            if self.repeat_batch_axis[0]:
                pulse_params = jnp.repeat(
                    pulse_params, B_I, axis=0
                )  # [B_I, 1, B_R, ...]
            if self.repeat_batch_axis[1]:
                pulse_params = jnp.repeat(
                    pulse_params, B_P, axis=1
                )  # [B_I, B_P, B_R, ...]
            pulse_params = pulse_params.reshape(B, *pulse_params.shape[3:])

        return inputs, params, pulse_params

    def _requires_density(self) -> bool:
        """
        Check if density matrix simulation is required.

        Determines whether the circuit must be executed with the mixed-state
        simulator based on execution type and noise configuration.

        Returns:
            bool: True if density matrix simulation is required, False otherwise.
                Returns True if:
                - execution_type is "density", or
                - Any non-coherent noise channel has non-zero probability
        """
        if self.execution_type == "density":
            return True

        if self.noise_params is None:
            return False

        coherent_noise = {"GateError"}
        for k, v in self.noise_params.items():
            if k in coherent_noise:
                continue
            if v is not None and v > 0:
                return True
        return False

    def __call__(
        self,
        params: Optional[jnp.ndarray] = None,
        inputs: Optional[jnp.ndarray] = None,
        pulse_params: Optional[jnp.ndarray] = None,
        enc_params: Optional[jnp.ndarray] = None,
        data_reupload: Union[bool, List[List[bool]], List[List[List[bool]]]] = None,
        noise_params: Optional[Dict[str, Union[float, Dict[str, float]]]] = None,
        execution_type: Optional[str] = None,
        force_mean: bool = False,
        gate_mode: str = "unitary",
    ) -> jnp.ndarray:
        """
        Execute the quantum circuit (callable interface).

        Provides a convenient callable interface for circuit execution,
        delegating to the _forward method.

        Args:
            params (Optional[jnp.ndarray]): Variational parameters of shape
                (n_layers, n_params_per_layer) or (batch, n_layers, n_params_per_layer).
                If None, uses model's internal parameters.
            inputs (Optional[jnp.ndarray]): Input data of shape
                (batch_size, n_input_feat). If None, uses zero inputs.
            pulse_params (Optional[jnp.ndarray]): Pulse parameter scalers for
                pulse-mode gate execution.
            enc_params (Optional[jnp.ndarray]): Encoding parameters of shape
                (n_qubits, n_input_feat). If None, uses model's encoding parameters.
            data_reupload (Union[bool, List[List[bool]], List[List[List[bool]]]]):
                Data reupload configuration. If None, uses previously set reupload
                configuration.
            noise_params (Optional[Dict[str, Union[float, Dict[str, float]]]]):
                Noise configuration. If None, uses previously set noise parameters.
            execution_type (Optional[str]): Measurement type: "expval", "density",
                "probs", or "state". If None, uses current execution_type setting.
            force_mean (bool): If True, averages results over measurement qubits.
                Defaults to False.
            gate_mode (str): Gate execution backend, "unitary" or "pulse".
                Defaults to "unitary".

        Returns:
            jnp.ndarray: Circuit output with shape depending on execution_type:
                - "expval": (n_output_qubits,) or scalar
                - "density": (2^n_output, 2^n_output)
                - "probs": (2^n_output,) or (n_pairs, 2^pair_size)
                - "state": (2^n_qubits,)
        """
        # Call forward method which handles the actual caching etc.
        return self._forward(
            params=params,
            inputs=inputs,
            pulse_params=pulse_params,
            enc_params=enc_params,
            data_reupload=data_reupload,
            noise_params=noise_params,
            execution_type=execution_type,
            force_mean=force_mean,
            gate_mode=gate_mode,
        )

    def _forward(
        self,
        params: Optional[jnp.ndarray] = None,
        inputs: Optional[jnp.ndarray] = None,
        pulse_params: Optional[jnp.ndarray] = None,
        enc_params: Optional[jnp.ndarray] = None,
        data_reupload: Union[bool, List[List[bool]], List[List[List[bool]]]] = None,
        noise_params: Optional[Dict[str, Union[float, Dict[str, float]]]] = None,
        execution_type: Optional[str] = None,
        force_mean: bool = False,
        gate_mode: str = "unitary",
    ) -> jnp.ndarray:
        """
        Execute the quantum circuit forward pass.

        Internal implementation of the forward pass that handles parameter
        validation, batch alignment, and circuit execution routing.

        Side effects:
            - ``self.noise_params`` and ``self.execution_type`` are overwritten
              when the corresponding argument is not None.
            - ``self.gate_mode`` is always overwritten.
            - ``self.data_reupload`` is overwritten when data_reupload is not
              None.
            - ``self.random_key`` is advanced by one split.

            noise_params, execution_type and gate_mode are written *before*
            the pulse_params consistency check. They therefore persist even
            if that check raises.

        Args:
            params (Optional[jnp.ndarray]): Variational parameters of shape
                (n_layers, n_params_per_layer) or
                (batch, n_layers, n_params_per_layer).
                If None, uses model's internal parameters.
            inputs (Optional[jnp.ndarray]): Input data of shape
                (batch_size, n_input_feat).
                If None, uses zero inputs.
            pulse_params (Optional[jnp.ndarray]): Pulse parameter scalers for
                pulse-mode gate execution.
            enc_params (Optional[jnp.ndarray]): Encoding parameters of shape
                (n_qubits, n_input_feat). If None, uses model's encoding parameters.
            data_reupload (Union[bool, List[List[bool]], List[List[List[bool]]]]):
                Data reupload configuration. If None, uses previously set reupload
                configuration.
            noise_params (Optional[Dict[str, Union[float, Dict[str, float]]]]):
                Noise configuration. If None, uses previously set noise parameters.
            execution_type (Optional[str]): Measurement type: "expval", "density",
                "probs", or "state". If None, uses current execution_type setting.
            force_mean (bool): If True, averages results over the last axis
                (the measurement qubits). Defaults to False.
            gate_mode (str): Gate execution backend, "unitary" or "pulse".
                Defaults to "unitary".

        Returns:
            jnp.ndarray: Circuit output with shape depending on execution_type:
                - "expval": (n_output_qubits,) or scalar
                - "density": (2^n_output, 2^n_output)
                - "probs": (2^n_output,) or (n_pairs, 2^pair_size)
                - "state": (2^n_qubits,)
                The effective batch dimensions are prepended to these shapes,
                and every size-1 axis is squeezed away.

        Raises:
            ValueError: If pulse_params are provided while gate_mode is not
                "pulse". The validation helpers called here may raise further
                errors; those are not visible in this method.
        """
        # set the parameters as object attributes
        if noise_params is not None:
            self.noise_params = noise_params
        if execution_type is not None:
            self.execution_type = execution_type
        self.gate_mode = gate_mode

        # consistency checks
        if pulse_params is not None and gate_mode != "pulse":
            raise ValueError(
                "pulse_params were provided but gate_mode is not 'pulse'. "
                "Either switch gate_mode='pulse' or do not pass pulse_params."
            )

        # TODO: add testing
        if data_reupload is not None:
            self.data_reupload = data_reupload

        params = self._params_validation(params)
        pulse_params = self._pulse_params_validation(pulse_params)
        inputs = self._inputs_validation(inputs)
        enc_params = self._enc_params_validation(enc_params)

        # flatten all batch axes to a single leading axis of size B
        # (this also sets self._batch_shape)
        inputs, params, pulse_params = self._assimilate_batch(
            inputs,
            params,
            pulse_params,
        )

        # split to generate a sub_key, required for actual execution
        self.random_key, sub_key = safe_random_split(self.random_key)

        # Build measurement type & observables from execution_type / output_qubit
        meas_type, obs = self._build_obs()

        # Yaqsi auto-routes between statevector and density-matrix simulation
        # based on whether noise channels appear on the tape, so no explicit
        # simulator selection is needed here.
        # B is the total (flattened) batch size.
        B = np.prod(self.eff_batch_shape)

        # kwargs are broadcast (not vmapped over)
        exec_kwargs = dict(
            noise_params=self.noise_params,
            gate_mode=self.gate_mode,
        )

        # Build a shot key from the random_key if shots are requested
        shot_key = None
        if self.shots is not None:
            # overwrite subkey and split shot_key
            sub_key, shot_key = safe_random_split(sub_key)

        if B > 1:
            # use random keys, derived from the subkey
            random_keys = safe_random_split(sub_key, num=B)

            # only arguments that actually carry a batch axis are mapped over;
            # unbatched ones (batch size 1) are broadcast
            in_axes = (
                0 if self.batch_shape[1] > 1 else None,  # params
                0 if self.batch_shape[0] > 1 else None,  # inputs
                0 if self.batch_shape[2] > 1 else None,  # pulse_params
                0,  # random_keys
                None,  # enc_params (broadcast, not batched)
            )

            result = self.script.execute(
                type=meas_type,
                obs=obs,
                args=(params, inputs, pulse_params, random_keys, enc_params),
                kwargs=exec_kwargs,
                in_axes=in_axes,
                shots=self.shots,
                key=shot_key,
            )
        else:
            # use the subkey directly
            result = self.script.execute(
                type=meas_type,
                obs=obs,
                args=(params, inputs, pulse_params, sub_key, enc_params),
                kwargs=exec_kwargs,
                shots=self.shots,
                key=shot_key,
            )

        result = self._postprocess_res(result)

        # --- Post-processing for partial-qubit measurements ---------------
        if self.execution_type == "density" and not self.all_qubit_measurement:
            result = ys.partial_trace(result, self.n_qubits, self.output_qubit)

        if self.execution_type == "probs" and not self.all_qubit_measurement:
            if isinstance(self.output_qubit[0], (list, tuple)):
                # list of qubit groups - marginalize each independently
                result = jnp.stack(
                    [
                        ys.marginalize_probs(result, self.n_qubits, list(group))
                        for group in self.output_qubit
                    ]
                )
            else:
                result = ys.marginalize_probs(result, self.n_qubits, self.output_qubit)

        # restore the (effective) batch axes, then drop all size-1 axes
        result = jnp.asarray(result)
        result = result.reshape((*self.eff_batch_shape, *self._result_shape)).squeeze()

        # average over the last axis; skipped for scalar results and when the
        # per-sample result has only a single entry along its first axis
        if (
            self.execution_type in ("expval", "probs")
            and force_mean
            and len(result.shape) > 0
            and self._result_shape[0] > 1
        ):
            result = result.mean(axis=-1)

        return result

all_qubit_measurement property #

Check if measurement is performed on all qubits.

batch_shape property #

Get the batch shape (B_I, B_P, B_R). If the model was not called before, it returns (1, 1, 1).

Returns:

Type Description
Tuple[int, ...]

Tuple[int, ...]: Tuple of (input_batch, param_batch, pulse_batch). Returns (1, 1, 1) if model has not been called yet.

data_reupload property writable #

Get the data reupload mask.

degree property writable #

Get the degree of the model.

eff_batch_shape property #

Get the effective batch shape after applying repeat_batch_axis mask.

Returns:

Type Description
Tuple[int, ...]

Tuple[int, ...]: Effective batch dimensions, excluding zeros.

enc_params property writable #

Get the encoding parameters used for input transformation.

execution_type property writable #

Gets the execution type of the model.

Returns:

Name Type Description
str str

The execution type, one of 'density', 'expval', 'probs', or 'state'.

frequencies property writable #

Get the frequencies of the model.

has_dru property #

Check if the model has data reupload.

noise_params property writable #

Gets the noise parameters of the model.

Returns:

Type Description
Optional[Dict[str, Union[float, Dict[str, float]]]]

Optional[Dict[str, Union[float, Dict[str, float]]]]: A dictionary of noise parameters, or None if not set.

output_qubit property writable #

Get the output qubit indices for measurement.

params property writable #

Get the variational parameters of the model.

pulse_params property writable #

Get the pulse parameters for pulse-mode gate execution.

shots property writable #

Gets the number of shots to use for the quantum device.

Returns:

Type Description
Optional[int]

Optional[int]: The number of shots.

__call__(params=None, inputs=None, pulse_params=None, enc_params=None, data_reupload=None, noise_params=None, execution_type=None, force_mean=False, gate_mode='unitary') #

Execute the quantum circuit (callable interface).

Provides a convenient callable interface for circuit execution, delegating to the _forward method.

Parameters:

Name Type Description Default
params Optional[ndarray]

Variational parameters of shape (n_layers, n_params_per_layer) or (batch, n_layers, n_params_per_layer). If None, uses model's internal parameters.

None
inputs Optional[ndarray]

Input data of shape (batch_size, n_input_feat). If None, uses zero inputs.

None
pulse_params Optional[ndarray]

Pulse parameter scalers for pulse-mode gate execution.

None
enc_params Optional[ndarray]

Encoding parameters of shape (n_qubits, n_input_feat). If None, uses model's encoding parameters.

None
data_reupload Union[bool, List[List[bool]], List[List[List[bool]]]]

Data reupload configuration. If None, uses previously set reupload configuration.

None
noise_params Optional[Dict[str, Union[float, Dict[str, float]]]]

Noise configuration. If None, uses previously set noise parameters.

None
execution_type Optional[str]

Measurement type: "expval", "density", "probs", or "state". If None, uses current execution_type setting.

None
force_mean bool

If True, averages results over measurement qubits. Defaults to False.

False
gate_mode str

Gate execution backend, "unitary" or "pulse". Defaults to "unitary".

'unitary'

Returns:

Type Description
ndarray

jnp.ndarray: Circuit output with shape depending on execution_type: - "expval": (n_output_qubits,) or scalar - "density": (2^n_output, 2^n_output) - "probs": (2^n_output,) or (n_pairs, 2^pair_size) - "state": (2^n_qubits,)

Source code in qml_essentials/model.py
def __call__(
    self,
    params: Optional[jnp.ndarray] = None,
    inputs: Optional[jnp.ndarray] = None,
    pulse_params: Optional[jnp.ndarray] = None,
    enc_params: Optional[jnp.ndarray] = None,
    data_reupload: Union[bool, List[List[bool]], List[List[List[bool]]]] = None,
    noise_params: Optional[Dict[str, Union[float, Dict[str, float]]]] = None,
    execution_type: Optional[str] = None,
    force_mean: bool = False,
    gate_mode: str = "unitary",
) -> jnp.ndarray:
    """
    Run the quantum circuit; the callable front-end of the model.

    All arguments are passed through unchanged to ``_forward``, which does
    the validation, batching and execution.

    Args:
        params (Optional[jnp.ndarray]): Variational parameters, shaped
            (n_layers, n_params_per_layer) or
            (batch, n_layers, n_params_per_layer). None means the model's
            stored parameters are used.
        inputs (Optional[jnp.ndarray]): Input data shaped
            (batch_size, n_input_feat). None means zero inputs.
        pulse_params (Optional[jnp.ndarray]): Pulse parameter scalers, only
            meaningful in pulse gate mode.
        enc_params (Optional[jnp.ndarray]): Encoding parameters. None means
            the model's stored encoding parameters are used.
        data_reupload (Union[bool, List[List[bool]], List[List[List[bool]]]]):
            Data re-upload configuration. None keeps the current one.
        noise_params (Optional[Dict[str, Union[float, Dict[str, float]]]]):
            Noise configuration. None keeps the current one.
        execution_type (Optional[str]): One of "expval", "density", "probs"
            or "state". None keeps the current setting.
        force_mean (bool): Average the result over measurement qubits.
            Defaults to False.
        gate_mode (str): Gate backend, "unitary" or "pulse".
            Defaults to "unitary".

    Returns:
        jnp.ndarray: Circuit output; the shape depends on execution_type:
            - "expval": (n_output_qubits,) or scalar
            - "density": (2^n_output, 2^n_output)
            - "probs": (2^n_output,) or (n_pairs, 2^pair_size)
            - "state": (2^n_qubits,)
    """
    passthrough = dict(
        params=params,
        inputs=inputs,
        pulse_params=pulse_params,
        enc_params=enc_params,
        data_reupload=data_reupload,
        noise_params=noise_params,
        execution_type=execution_type,
        force_mean=force_mean,
        gate_mode=gate_mode,
    )
    # _forward is where caching, validation and execution actually happen
    return self._forward(**passthrough)

__init__(n_qubits, n_layers, circuit_type='No_Ansatz', data_reupload=True, state_preparation=None, encoding=Gates.RX, trainable_frequencies=False, initialization='random', initialization_domain=[0, 2 * jnp.pi], output_qubit=-1, shots=None, random_seed=1000, remove_zero_encoding=True, repeat_batch_axis=[True, True, True], pulse_shape='gaussian') #

Initialize the quantum circuit model. Parameters will have the shape [impl_n_layers, parameters_per_layer], where impl_n_layers is the number of layers provided, increased by one if data re-uploading is enabled, and parameters_per_layer is given by the chosen ansatz.

The model is initialized with the following parameters as defaults: - noise_params: None - execution_type: "expval" - shots: None

Parameters:

Name Type Description Default
n_qubits int

The number of qubits in the circuit.

required
n_layers int

The number of layers in the circuit.

required
circuit_type (str, Circuit)

The type of quantum circuit to use. If None, defaults to "No_Ansatz".

'No_Ansatz'
data_reupload Union[bool, List[bool], List[List[bool]]]

Whether to reupload data to the quantum device on each layer and qubit. Detailed re-uploading instructions can be given as a list/array of 0/False and 1/True with shape (n_qubits, n_layers) to specify where to upload the data. Defaults to True for applying data re-uploading to the full circuit.

True
encoding Union[str, Callable, List[str], List[Callable]]

The unitary to use for encoding the input data. Can be a string (e.g. "RX") or a callable (e.g. op.RX). Defaults to op.RX. If input is multidimensional it is assumed to be a list of unitaries or a list of strings.

RX
trainable_frequencies bool

Sets trainable encoding parameters for trainable frequencies. Defaults to False.

False
initialization str

The strategy to initialize the parameters. Can be "random", "zeros", "zero-controlled", "pi", or "pi-controlled". Defaults to "random".

'random'
output_qubit (List[int], int)

The index of the output qubit (or qubits). When set to -1 all qubits are measured, or a global measurement is conducted, depending on the execution type.

-1
shots Optional[int]

The number of shots to use for the quantum device. Defaults to None.

None
random_seed int

Seed for the random number generator, used when initialization is "random" and for random noise parameters. Defaults to 1000.

1000
remove_zero_encoding bool

whether to remove the zero encoding from the circuit. Defaults to True.

True
repeat_batch_axis List[bool]

Each boolean in the array determines over which axes to parallelise computation. The axes correspond to [inputs, params, pulse_params]. Defaults to [True, True, True], meaning that batching is enabled over all axes.

[True, True, True]
pulse_shape str

Pulse envelope shape for pulse-level simulation. One of PulseEnvelope.available(). Defaults to "gaussian".

'gaussian'

Returns:

Type Description
None

None

Source code in qml_essentials/model.py
def __init__(
    self,
    n_qubits: int,
    n_layers: int,
    circuit_type: Union[str, Circuit] = "No_Ansatz",
    data_reupload: Union[bool, List[List[bool]], List[List[List[bool]]]] = True,
    state_preparation: Union[
        str, Callable, List[Union[str, Callable]], None
    ] = None,
    encoding: Union[Encoding, str, Callable, List[Union[str, Callable]]] = Gates.RX,
    trainable_frequencies: bool = False,
    initialization: str = "random",
    initialization_domain: List[float] = [0, 2 * jnp.pi],
    output_qubit: Union[List[int], int] = -1,
    shots: Optional[int] = None,
    random_seed: int = 1000,
    remove_zero_encoding: bool = True,
    repeat_batch_axis: List[bool] = [True, True, True],
    pulse_shape: str = "gaussian",
) -> None:
    """
    Initialize the quantum circuit model.
    Parameters will have the shape [impl_n_layers, parameters_per_layer],
    where impl_n_layers is the number of layers provided, increased by one
    if data re-uploading is enabled, and parameters_per_layer is given by
    the chosen ansatz.

    The model is initialized with the following parameters as defaults:
    - noise_params: None
    - execution_type: "expval"
    - shots: None

    Args:
        n_qubits (int): The number of qubits in the circuit.
        n_layers (int): The number of layers in the circuit.
        circuit_type (str, Circuit): The type of quantum circuit to use,
            either the name of an ``Ansaetze`` attribute or a circuit class.
            If None (or an empty string), defaults to "No_Ansatz".
        data_reupload (Union[bool, List[bool], List[List[bool]]], optional):
            Whether to reupload data to the quantum device on each
            layer and qubit. Detailed re-uploading instructions can be given
            as a list/array of 0/False and 1/True with shape (n_qubits,
            n_layers) to specify where to upload the data. Defaults to True
            for applying data re-uploading to the full circuit.
        state_preparation (Union[str, Callable, List[Union[str, Callable]], None],
            optional): Gate(s) used for state preparation, parsed with
            ``Gates.parse_gates``. Parsed gates whose name has a pulse
            definition also get their pulse parameters stored in
            ``self.sp_pulse_params`` (None otherwise). Defaults to None.
        encoding (Union[str, Callable, List[str], List[Callable]], optional):
            The unitary to use for encoding the input data. Can be a string
            (e.g. "RX") or a callable (e.g. op.RX). Defaults to op.RX.
            If input is multidimensional it is assumed to be a list of
            unitaries or a list of strings. An ``Encoding`` instance is used
            as-is; anything else is wrapped in a "hamming" ``Encoding``.
        trainable_frequencies (bool, optional):
            Sets trainable encoding parameters for trainable frequencies.
            Defaults to False.
        initialization (str, optional): The strategy to initialize the parameters.
            Can be "random", "zeros", "zero-controlled", "pi", or "pi-controlled".
            Defaults to "random".
        initialization_domain (List[float], optional): Domain handed to the
            parameter initialization (stored as ``self._initialization_domain``;
            presumably the [low, high] range for random initialization).
            Defaults to [0, 2*pi].
        output_qubit (List[int], int, optional): The index of the output
            qubit (or qubits). When set to -1 all qubits are measured, or a
            global measurement is conducted, depending on the execution
            type.
        shots (Optional[int], optional): The number of shots to use for
            the quantum device. Defaults to None.
        random_seed (int, optional): Seed for the random number generator,
            used when initialization is "random" and for random noise
            parameters. Defaults to 1000.
        remove_zero_encoding (bool, optional): Whether to
            remove the zero encoding from the circuit. Defaults to True.
        repeat_batch_axis (List[bool], optional): Each boolean in the array
            determines over which axes to parallelise computation. The axes
            correspond to [inputs, params, pulse_params]. Defaults to
            [True, True, True], meaning that batching is enabled over all
            axes.
        pulse_shape (str, optional): Pulse envelope shape for pulse-level
            simulation. One of ``PulseEnvelope.available()``.
            Defaults to ``"gaussian"``.

    Returns:
        None

    Raises:
        ValueError: If state_preparation cannot be parsed.
    """
    # Initialize default parameters needed for circuit evaluation
    self.n_qubits: int = n_qubits
    self.output_qubit: Union[List[int], int] = output_qubit
    self.n_layers: int = n_layers
    self.noise_params: Optional[Dict[str, Union[float, Dict[str, float]]]] = None
    self.shots = shots
    self.remove_zero_encoding = remove_zero_encoding
    self.trainable_frequencies: bool = trainable_frequencies
    self.execution_type: str = "expval"
    # copy: the default list is shared between all calls, so storing it by
    # reference would let one model's mutation leak into every other model
    self.repeat_batch_axis: List[bool] = list(repeat_batch_axis)

    # --- Pulse envelope ---
    pinfo.set_envelope(pulse_shape)

    # --- State Preparation ---
    try:
        self._sp = Gates.parse_gates(state_preparation, Gates)
    except ValueError as e:
        raise ValueError(f"Error parsing state preparation: {e}") from e

    # prepare corresponding pulse parameters (always optimized pulses)
    self.sp_pulse_params = []
    for sp in self._sp:
        sp_name = sp.__name__ if hasattr(sp, "__name__") else str(sp)
        pulse_gate = pinfo.gate_by_name(sp_name)
        # gates without a pulse parametrization get a None placeholder
        self.sp_pulse_params.append(
            pulse_gate.params if pulse_gate is not None else None
        )

    # --- Encoding ---
    if isinstance(encoding, Encoding):
        # user wants custom strategy? do it!
        self._enc = encoding
    else:
        # use hamming encoding by default
        self._enc = Encoding("hamming", encoding)

    # Number of possible inputs
    self.n_input_feat = len(self._enc)
    log.debug(f"Number of input features: {self.n_input_feat}")

    # Trainable frequencies, default initialization as in arXiv:2309.03279v2
    self.enc_params = jnp.ones((self.n_layers, self.n_qubits, self.n_input_feat))

    self._zero_inputs = False

    # --- Data-Reuploading ---

    # Keep as NumPy array (not JAX) so that ``if data_reupload[q, idx]``
    # in _iec remains a concrete Python bool even under jax.jit tracing.
    # note that setting this will also update self.degree and self.frequencies
    # and in consequence also self.has_dru
    self.data_reupload = data_reupload

    # check for the highest degree among all input dimensions
    if self.has_dru:
        impl_n_layers: int = n_layers + 1  # we need L+1 according to Schuld et al.
    else:
        impl_n_layers = n_layers
    log.info(f"Number of implicit layers: {impl_n_layers}.")

    # --- Ansatz ---
    # only weak check for str. We trust the user to provide sth useful.
    # None is handled here as well, as documented (falls back to "No_Ansatz").
    if circuit_type is None or isinstance(circuit_type, str):
        self.pqc: Callable[[Optional[jnp.ndarray], int], int] = getattr(
            Ansaetze, circuit_type or "No_Ansatz"
        )()
    else:
        self.pqc = circuit_type()
    log.info(f"Using Ansatz {circuit_type}.")

    # calculate the shape of the parameter vector here, we will re-use this in init.
    params_per_layer = self.pqc.n_params_per_layer(self.n_qubits)
    self._params_shape: Tuple[int, int] = (impl_n_layers, params_per_layer)
    log.info(f"Parameters per layer: {params_per_layer}")

    pulse_params_per_layer = self.pqc.n_pulse_params_per_layer(self.n_qubits)
    self._pulse_params_shape: Tuple[int, int] = (
        impl_n_layers,
        pulse_params_per_layer,
    )

    # initialize to None as we can't know this yet
    self._batch_shape = None

    # this will also be re-used in the init method,
    # however, only if nothing is provided
    self._inialization_strategy = initialization
    # copy for the same shared-default reason as repeat_batch_axis
    self._initialization_domain = list(initialization_domain)

    # ..here! where we only require a JAX random key
    self.random_key = self.initialize_params(random.key(random_seed))

    # Initializing pulse params
    self.pulse_params: jnp.ndarray = jnp.ones((1, *self._pulse_params_shape))

    log.info(f"Initialized pulse parameters with shape {self.pulse_params.shape}.")

    # Initialise the yaqsi Script that wraps _variational.
    # No device selection needed - yaqsi auto-routes between statevector
    # and density-matrix simulation based on whether noise channels are
    # present on the tape.
    self.script = ys.Script(f=self._variational, n_qubits=self.n_qubits)

__repr__() #

Return text representation of the quantum circuit model.

Source code in qml_essentials/model.py
def __repr__(self) -> str:
    """Render the model as a text (ASCII) drawing of its circuit."""
    text_drawing = self.draw(figure="text")
    return text_drawing

__str__() #

Return string representation of the quantum circuit model.

Source code in qml_essentials/model.py
def __str__(self) -> str:
    """Render the model as a human-readable ASCII circuit diagram."""
    diagram = self.draw(figure="text")
    return diagram

draw(inputs=None, figure='text', **kwargs) #

Visualize the quantum circuit.

Records the circuit tape (without noise) and renders the gate sequence using the requested backend.

Parameters:

Name Type Description Default
inputs Optional[ndarray]

Input data for the circuit. If None, default zero inputs are used.

None
figure str

Rendering backend. One of:

  • "text" - ASCII art (returned as a str).
  • "mpl" - Matplotlib figure (returns (fig, ax)).
  • "tikz" - LaTeX/TikZ quantikz code (returns a :class:TikzFigure).
  • "pulse" - Pulse schedule (returns (fig, axes)). Only meaningful for pulse-mode models.
'text'
**kwargs Any

Extra options forwarded to the drawing backend (e.g. gate_values=True).

{}

Returns:

Type Description
Union[str, Any]

Depends on figure:

  • "text" -> str
  • "mpl" -> (matplotlib.figure.Figure, matplotlib.axes.Axes)
  • "tikz" -> :class:TikzFigure
  • "pulse" -> (fig, axes)

Raises:

Type Description
ValueError

If figure is not one of the supported modes.

Source code in qml_essentials/model.py
def draw(
    self,
    inputs: Optional[jnp.ndarray] = None,
    figure: str = "text",
    **kwargs: Any,
) -> Union[str, Any]:
    """Visualize the quantum circuit.

    Records the circuit tape (without noise) and renders the gate
    sequence using the requested backend. The model's noise configuration
    is restored afterwards, even if rendering fails.

    Args:
        inputs (Optional[jnp.ndarray]): Input data for the circuit.
            If ``None``, default zero inputs are used.
        figure (str): Rendering backend.  One of:

            * ``"text"``  - ASCII art (returned as a ``str``).
            * ``"mpl"``   - Matplotlib figure (returns ``(fig, ax)``).
            * ``"tikz"``  - LaTeX/TikZ ``quantikz`` code (returns a
              :class:`TikzFigure`).
            * ``"pulse"`` - Pulse schedule (returns ``(fig, axes)``).
              Only meaningful for pulse-mode models.

        **kwargs: Extra options forwarded to the drawing backend
            (e.g. ``gate_values=True``).

    Returns:
        Depends on figure:

        * ``"text"``  -> ``str``
        * ``"mpl"``   -> ``(matplotlib.figure.Figure, matplotlib.axes.Axes)``
        * ``"tikz"``  -> :class:`TikzFigure`
        * ``"pulse"`` -> ``(fig, axes)`` as returned by :meth:`draw_pulse`

    Raises:
        ValueError: If figure is not one of the supported modes (raised by
            the drawing backend).
    """
    # The pulse path validates the inputs itself, so dispatch first to
    # avoid validating them twice.
    if figure == "pulse":
        return self.draw_pulse(inputs=inputs, **kwargs)

    inputs = self._inputs_validation(inputs)
    # Only the first parameter set / input sample is drawn when batched.
    params = self.params[0] if self.params.ndim == 3 else self.params
    inp = inputs[0] if inputs.ndim == 2 else inputs

    # Record without noise to get a clean circuit. The saved configuration
    # is restored in ``finally`` so a failing backend (e.g. an unsupported
    # ``figure``) cannot leave the model silently noise-free.
    saved_noise = self._noise_params
    self._noise_params = None
    try:
        draw_script = ys.Script(f=self._variational, n_qubits=self.n_qubits)
        return draw_script.draw(
            figure=figure,
            args=(params, inp),
            kwargs={"noise_params": None},
            **kwargs,
        )
    finally:
        self._noise_params = saved_noise

draw_pulse(inputs=None, **kwargs) #

Visualize the pulse schedule for the circuit.

Records the circuit in pulse mode and collects PulseEvents automatically via the pulse-event tape, then renders them.

Parameters:

Name Type Description Default
inputs Optional[ndarray]

Input data. If None, default zero inputs are used.

None
**kwargs Any

Forwarded to :func:~qml_essentials.drawing.draw_pulse_schedule (e.g. show_carrier=True, n_samples=300).

{}

Returns:

Type Description
Any

(fig, axes) — Matplotlib Figure and array of Axes.

Source code in qml_essentials/model.py
def draw_pulse(
    self,
    inputs: Optional[jnp.ndarray] = None,
    **kwargs: Any,
) -> Any:
    """Render the pulse schedule of the circuit.

    The circuit is recorded with ``gate_mode="pulse"`` and without noise;
    the pulse events collected on the tape are then drawn.

    Args:
        inputs: Input data.  If ``None``, default zero inputs are used.
        **kwargs: Forwarded to
            :func:`~qml_essentials.drawing.draw_pulse_schedule`
            (e.g. ``show_carrier=True``, ``n_samples=300``).

    Returns:
        ``(fig, axes)`` — Matplotlib Figure and array of Axes.
    """
    validated = self._inputs_validation(inputs)

    # Draw only the first parameter set / input sample when batched.
    if self.params.ndim == 3:
        first_params = self.params[0]
    else:
        first_params = self.params
    if validated.ndim == 2:
        first_input = validated[0]
    else:
        first_input = validated

    pulse_kwargs = {
        "gate_mode": "pulse",
        "noise_params": None,
    }
    pulse_script = ys.Script(f=self._variational, n_qubits=self.n_qubits)
    return pulse_script.draw(
        figure="pulse",
        args=(first_params, first_input),
        kwargs=pulse_kwargs,
        **kwargs,
    )

initialize_params(random_key=None, repeat=1, initialization=None, initialization_domain=None) #

Initialize the variational parameters of the model.

Parameters:

Name Type Description Default
random_key Optional[PRNGKey]

JAX random key for initialization. If None, uses the model's internal random key.

None
repeat int

Number of parameter sets to create (batch dimension). Defaults to 1.

1
initialization Optional[str]

Strategy for parameter initialization. Options: "random", "zeros", "pi", "zero-controlled", "pi-controlled". If None, uses the strategy specified in the constructor.

None
initialization_domain Optional[List[float]]

Domain [min, max] for random initialization. If None, uses the domain from constructor.

None

Returns:

Type Description
PRNGKey

random.PRNGKey: Updated random key after initialization.

Raises:

Type Description
Exception

If an invalid initialization method is specified.

Source code in qml_essentials/model.py
def initialize_params(
    self,
    random_key: Optional[random.PRNGKey] = None,
    repeat: int = 1,
    initialization: Optional[str] = None,
    initialization_domain: Optional[List[float]] = None,
) -> random.PRNGKey:
    """
    Initialize the variational parameters of the model.

    The new parameters have shape ``(repeat, *self._params_shape)`` and are
    stored in ``self.params``.

    Args:
        random_key (Optional[random.PRNGKey]): JAX random key for initialization.
            If None, uses the model's internal random key.
        repeat (int): Number of parameter sets to create (batch dimension).
            Defaults to 1.
        initialization (Optional[str]): Strategy for parameter initialization.
            Options: "random", "zeros", "pi", "zero-controlled", "pi-controlled".
            If None, uses the strategy specified in the constructor.
        initialization_domain (Optional[List[float]]): Domain [min, max] for
            random initialization. If None, uses the domain from constructor.

    Returns:
        random.PRNGKey: Updated random key after initialization. Note that
            ``self.random_key`` itself is not modified here.

    Raises:
        ValueError: If an invalid initialization method is specified
            (a subclass of ``Exception``, so existing handlers keep working).
    """
    params_shape = (repeat, *self._params_shape)

    # use existing strategy / domain if not specified
    initialization = initialization or self._inialization_strategy
    initialization_domain = initialization_domain or self._initialization_domain

    random_key, sub_key = safe_random_split(
        random_key if random_key is not None else self.random_key
    )

    def _uniform() -> jnp.ndarray:
        # Uniform samples in [min, max) of the initialization domain.
        return random.uniform(
            sub_key,
            params_shape,
            minval=initialization_domain[0],
            maxval=initialization_domain[1],
        )

    def _set_control_params(params: jnp.ndarray, value: float) -> jnp.ndarray:
        # Overwrite the controlled-rotation parameters, selected along the
        # third axis by the (start, stop, step) triple of the PQC, with a
        # constant value.
        indices = self.pqc.get_control_indices(self.n_qubits)
        if indices is None:
            warnings.warn(
                f"Specified {initialization} but circuit does not contain "
                "controlled rotation gates. Parameters are initialized randomly.",
                UserWarning,
            )
            return params
        # JAX arrays are immutable: edit a writable NumPy copy, convert back.
        np_params = np.array(params)
        np_params[:, :, indices[0] : indices[1] : indices[2]] = value
        return jnp.array(np_params)

    if initialization == "random":
        new_params = _uniform()
    elif initialization == "zeros":
        new_params = jnp.zeros(params_shape)
    elif initialization == "pi":
        new_params = jnp.ones(params_shape) * jnp.pi
    elif initialization == "zero-controlled":
        new_params = _set_control_params(_uniform(), 0)
    elif initialization == "pi-controlled":
        new_params = _set_control_params(_uniform(), jnp.pi)
    else:
        raise ValueError(f"Invalid initialization method: {initialization!r}")

    self.params: jnp.ndarray = new_params

    log.info(
        f"Initialized parameters with shape {self.params.shape} "
        f"using strategy {initialization}."
    )

    return random_key

transform_input(inputs, enc_params) #

Transform input data by scaling with encoding parameters.

Implements the input transformation as described in arXiv:2309.03279v2, where inputs are linearly scaled by encoding parameters before being used in the quantum circuit.

Parameters:

Name Type Description Default
inputs ndarray

Input data point of shape (n_input_feat,) or (batch_size, n_input_feat).

required
enc_params ndarray

Encoding weight scalar or vector used to scale the input.

required

Returns:

Type Description
ndarray

jnp.ndarray: Transformed input, element-wise product of inputs and enc_params.

Source code in qml_essentials/model.py
def transform_input(
    self, inputs: jnp.ndarray, enc_params: jnp.ndarray
) -> jnp.ndarray:
    """
    Scale the raw inputs by the trainable encoding weights.

    This is the linear input transformation from arXiv:2309.03279v2: each
    input feature is multiplied by its encoding parameter before being fed
    into the quantum circuit.

    Args:
        inputs (jnp.ndarray): Input data point of shape (n_input_feat,) or
            (batch_size, n_input_feat).
        enc_params (jnp.ndarray): Encoding weight scalar or vector used to
            scale the input.

    Returns:
        jnp.ndarray: Element-wise (broadcast) product ``inputs * enc_params``.
    """
    scaled_inputs = inputs * enc_params
    return scaled_inputs

Entanglement#

from qml_essentials.entanglement import Entanglement
Source code in qml_essentials/entanglement.py
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
class Entanglement:
    """Entanglement measures for quantum circuit models.

    Every public measure is a static method that takes a :class:`Model`.
    When ``n_samples`` is a positive integer, that many parameter sets are
    drawn with ``model.initialize_params``, which overwrites ``model.params``.
    Otherwise the parameters already stored on the model are used.
    """

    @staticmethod
    def _scaled_samples(
        n_samples: Optional[int], n_qubits: int, scale: bool
    ) -> Optional[int]:
        """
        Apply the optional ``2**n_qubits`` scaling to a sample count.

        Args:
            n_samples (Optional[int]): Requested number of samples. ``None``
                means "use the model's current parameters". It is passed
                through unchanged instead of raising a ``TypeError``.
            n_qubits (int): Number of qubits of the model.
            scale (bool): Whether to multiply ``n_samples`` by ``2**n_qubits``.

        Returns:
            Optional[int]: The possibly scaled number of samples.
        """
        if scale and n_samples is not None:
            # Plain integer arithmetic (instead of jnp.power) keeps the result
            # directly usable as an array shape / repeat count.
            return 2**n_qubits * n_samples
        return n_samples

    @staticmethod
    def _shift_and_append(tape_ops, offset: int) -> None:
        """
        Re-register recorded operations on the active tape with shifted wires.

        Args:
            tape_ops: Operations recorded on a separate tape.
            offset (int): Value added to every wire index of each operation.
        """
        from qml_essentials.tape import active_tape as _active_tape

        current = _active_tape()
        if current is None:
            return
        for o in tape_ops:
            # Copy the operation via __new__ + __dict__. This bypasses
            # __init__, presumably so the copy is not recorded a second time.
            # Then move the copy onto the shifted wires.
            shifted = o.__class__.__new__(o.__class__)
            shifted.__dict__.update(o.__dict__)
            shifted._wires = [w + offset for w in o.wires]
            current.append(shifted)

    @staticmethod
    def meyer_wallach(
        model: Model,
        n_samples: Optional[int],
        random_key: Optional[jax.random.PRNGKey] = None,
        scale: bool = False,
        **kwargs: Any,
    ) -> float:
        """
        Calculates the entangling capacity of a given quantum circuit
        using the Meyer-Wallach measure.

        Args:
            model (Model): The quantum circuit model.
            n_samples (Optional[int]): Number of parameter samples.
                If None or <= 0, the current parameters of the model are used.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for
                parameter initialization. If None, uses the model's internal
                random key.
            scale (bool): Whether to multiply the number of samples by
                ``2**n_qubits``. This is ignored if ``n_samples`` is None.
            kwargs (Any): Additional keyword arguments for the model function.

        Returns:
            float: Mean Meyer-Wallach measure over all samples, returned as a
                0-d JAX array. For pure states it lies in [0, 1], up to
                numerical error.
        """
        if "noise_params" in kwargs:
            log.warning(
                "Meyer-Wallach measure not suitable for noisy circuits. "
                "Consider 'concentratable entanglement' instead."
            )

        n_samples = Entanglement._scaled_samples(n_samples, model.n_qubits, scale)

        if n_samples is not None and n_samples > 0:
            model.initialize_params(random_key, repeat=n_samples)

        # implicitly set input to none in case it's not needed
        kwargs.setdefault("inputs", None)
        dim = 2**model.n_qubits
        # explicitly set execution type because everything else won't work
        rhos = model(execution_type="density", **kwargs).reshape(-1, dim, dim)

        ent = Entanglement._compute_meyer_wallach_meas(rhos, model.n_qubits)

        log.debug(f"Variance of measure: {ent.var()}")

        return ent.mean()

    @staticmethod
    def _compute_meyer_wallach_meas(rhos: jnp.ndarray, n_qubits: int) -> jnp.ndarray:
        """
        Computes the Meyer-Wallach entangling capability measure for a given
        set of density matrices.

        Args:
            rhos (jnp.ndarray): Density matrices of the sample quantum states.
                The shape is (B_s, 2^n, 2^n), where B_s is the number of samples
                (batch) and n the number of qubits.
            n_qubits (int): The number of qubits.

        Returns:
            jnp.ndarray: Entangling capability for each sample, array with
                shape (B_s,).
        """
        qb = list(range(n_qubits))

        def _f(rho):
            purity_sum = 0
            for j in range(n_qubits):
                # Formula 6 in https://doi.org/10.48550/arXiv.quant-ph/0305094
                # Trace out qubit j and keep all others. For a pure state the
                # purity of this complement equals the purity of qubit j.
                keep = qb[:j] + qb[j + 1 :]
                density = ys.partial_trace(rho, n_qubits, keep)
                # Tr(rho^2) is real, so only the real part is kept.
                purity_sum += jnp.trace((density @ density).real, axis1=-2, axis2=-1)

            # For pure states the average purity is in [1/2, 1]; map it to [0, 1].
            return 2 * (1 - purity_sum / n_qubits)

        return jax.vmap(_f)(rhos)

    @staticmethod
    def bell_measurements(
        model: Model,
        n_samples: int,
        random_key: Optional[jax.random.PRNGKey] = None,
        scale: bool = False,
        **kwargs: Any,
    ) -> float:
        """
        Compute the Bell measurement for a given model.

        Constructs a ``2 * n_qubits`` circuit that prepares two copies of
        the model state on disjoint qubit registers. It then applies CNOTs
        and Hadamards, measures the joint probabilities, and marginalizes
        them onto each qubit pair ``(q, q + n_qubits)``.

        Args:
            model (Model): The quantum circuit model.
            n_samples (int): The number of samples to compute the measure for.
                If None or <= 0, the current parameters of the model are used.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for
                parameter initialization. If None, uses the model's internal
                random key.
            scale (bool): Whether to scale the number of samples
                according to the number of qubits.
            **kwargs (Any): Additional keyword arguments for the model function.
                An ``inputs`` entry is used as the circuit input.

        Returns:
            float: The Bell measurement value, clipped to [0, 1].
        """
        if "noise_params" in kwargs:
            log.warning(
                "Bell Measurements not suitable for noisy circuits. "
                "Consider 'concentratable entanglement' instead."
            )

        n_samples = Entanglement._scaled_samples(n_samples, model.n_qubits, scale)

        n = model.n_qubits

        def _bell_circuit(params, inputs, pulse_params=None, random_key=None, **kw):
            """Bell measurement circuit on 2*n qubits."""
            from qml_essentials.tape import recording as _recording

            # First copy on wires 0..n-1, recorded directly on the active tape
            model._variational(
                params, inputs, pulse_params=pulse_params, random_key=random_key, **kw
            )

            # TODO: this is very user-unfriendly and we should find a better way

            # Second copy on wires n..2n-1: record the tape then shift wires
            with _recording() as shifted_tape:
                model._variational(
                    params,
                    inputs,
                    pulse_params=pulse_params,
                    random_key=random_key,
                    **kw,
                )
            Entanglement._shift_and_append(shifted_tape, n)

            # Bell measurement: CNOT + H
            for q in range(n):
                op.CX(wires=[q, q + n])
                op.H(wires=q)

        bell_script = ys.Script(f=_bell_circuit, n_qubits=2 * n)

        if n_samples is not None and n_samples > 0:
            random_key = model.initialize_params(random_key, repeat=n_samples)
            params = model.params
        elif len(model.params.shape) <= 2:
            params = model.params.reshape(1, *model.params.shape)
        else:
            log.info(f"Using sample size of model params: {model.params.shape[0]}")
            params = model.params

        n_batch = params.shape[0]
        # ``inputs`` is passed positionally to the circuit. Pop it from kwargs
        # so it is not supplied a second time as a keyword argument.
        inputs = model._inputs_validation(kwargs.pop("inputs", None))

        # Execute: vmap over batch dimension of params (axis 0)
        if n_batch > 1:
            from qml_essentials.utils import safe_random_split

            random_keys = safe_random_split(random_key, num=n_batch)
            result = bell_script.execute(
                type="probs",
                args=(params, inputs, model.pulse_params, random_keys),
                kwargs=kwargs,
                in_axes=(0, None, None, 0),
            )
        else:
            result = bell_script.execute(
                type="probs",
                args=(params, inputs, model.pulse_params, random_key),
                kwargs=kwargs,
            )

        # Marginalize: for each qubit q, keep wires [q, q+n] from the 2n-qubit
        # probs. The last entry of each pair distribution is P(|11>), i.e. the
        # singlet outcome, and 1 - 2 * P(|11>) is the purity of qubit q.
        per_qubit = [ys.marginalize_probs(result, 2 * n, [q, q + n]) for q in range(n)]
        # per_qubit[q] has shape (n_samples, 4) or (4,)
        exp = jnp.stack(per_qubit, axis=-2)  # (..., n, 4)
        exp = 1 - 2 * exp[..., -1]  # (..., n)

        if not jnp.isclose(jnp.sum(exp.imag), 0, atol=1e-6):
            log.warning("Imaginary part of probabilities detected")
            exp = jnp.abs(exp)

        measure = 2 * (1 - exp.mean(axis=0))
        entangling_capability = min(max(float(measure.mean()), 0.0), 1.0)
        log.debug(f"Variance of measure: {measure.var()}")

        return entangling_capability

    @staticmethod
    def relative_entropy(
        model: Model,
        n_samples: int,
        n_sigmas: int,
        random_key: Optional[jax.random.PRNGKey] = None,
        scale: bool = False,
        **kwargs: Any,
    ) -> float:
        """
        Calculates the relative entropy of entanglement of a given quantum
        circuit. This measure is also applicable to mixed states, although it
        might not be fully accurate in this simplified case.

        The relative entropy of entanglement is defined as the smallest
        relative entropy from the state in question to the set of separable
        states. Computing the nearest separable state is NP-hard, so this
        function instead compares against n_sigmas random separable states,
        which are not necessarily the nearest. The result is therefore an
        upper bound on the entanglement.

        The relative entropy is not necessarily between zero and one, so this
        function also normalises by the relative entropy of the GHZ state.

        Args:
            model (Model): The quantum circuit model.
            n_samples (int): Number of samples per qubit.
                If None or <= 0, the current parameters of the model are used.
            n_sigmas (int): Number of random separable pure states to compare against.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for
                parameter initialization. If None, uses the model's internal
                random key.
            scale (bool): Whether to scale the number of samples and sigmas
                by ``2**n_qubits``.
            kwargs (Any): Additional keyword arguments for the model function.

        Returns:
            float: Mean over samples of the GHZ-normalised relative entropy to
                the closest of the sampled separable states.
        """
        dim = 2**model.n_qubits
        n_samples = Entanglement._scaled_samples(n_samples, model.n_qubits, scale)
        if scale:
            n_sigmas = dim * n_sigmas

        if random_key is None:
            random_key = model.random_key

        # Random separable states
        log_sigmas = sample_random_separable_states(
            model.n_qubits, n_samples=n_sigmas, random_key=random_key, take_log=True
        )

        random_key, _ = jax.random.split(random_key)

        if n_samples is not None and n_samples > 0:
            model.initialize_params(random_key, repeat=n_samples)
        elif len(model.params.shape) <= 2:
            model.params = model.params.reshape(1, *model.params.shape)
        else:
            log.info(f"Using sample size of model params: {model.params.shape[0]}")

        rhos, log_rhos = Entanglement._compute_log_density(model, **kwargs)

        rel_entropies = jnp.zeros((n_sigmas, model.params.shape[0]))

        # One sigma at a time keeps memory at O(n_rhos) instead of
        # O(n_sigmas * n_rhos) density matrices.
        for i, log_sigma in enumerate(log_sigmas):
            rel_entropies = rel_entropies.at[i].set(
                Entanglement._compute_rel_entropies(rhos, log_rhos, log_sigma)
            )

        # Entropy of GHZ states should be maximal
        ghz_model = Model(model.n_qubits, 1, "GHZ", data_reupload=False)
        rho_ghz, log_rho_ghz = Entanglement._compute_log_density(ghz_model, **kwargs)
        ghz_entropies = Entanglement._compute_rel_entropies(
            rho_ghz, log_rho_ghz, log_sigmas
        )

        normalised_entropies = rel_entropies / ghz_entropies

        # For each sample take the closest sigma (min over sigmas); the
        # samples are averaged below.
        entangling_capability = normalised_entropies.T.min(axis=1)
        log.debug(f"Variance of measure: {entangling_capability.var()}")

        return entangling_capability.mean()

    @staticmethod
    def _compute_log_density(model: Model, **kwargs) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """
        Obtains the density matrix of a model and computes its logarithm.

        Args:
            model (Model): The model for which to compute the density matrix.
            kwargs: Additional keyword arguments for the model function.

        Returns:
            Tuple[jnp.ndarray, jnp.ndarray]:
                - jnp.ndarray: density matrix, shape (-1, 2^n, 2^n).
                - jnp.ndarray: base-2 logarithm of the density matrix.
        """
        # implicitly set input to none in case it's not needed
        kwargs.setdefault("inputs", None)
        # explicitly set execution type because everything else won't work
        rho = model(execution_type="density", **kwargs)
        rho = rho.reshape(-1, 2**model.n_qubits, 2**model.n_qubits)
        log_rho = logm_v(rho) / jnp.log(2)
        return rho, log_rho

    @staticmethod
    def _compute_rel_entropies(
        rhos: jnp.ndarray,
        log_rhos: jnp.ndarray,
        log_sigmas: jnp.ndarray,
    ) -> jnp.ndarray:
        """
        Compute the relative entropy for a given model.

        Args:
            rhos (jnp.ndarray): Density matrix result of the circuit, has shape
                (R, 2^n, 2^n), with the batch size R and number of qubits n
            log_rhos (jnp.ndarray): Corresponding logarithm of the density
                matrix, has shape (R, 2^n, 2^n).
            log_sigmas (jnp.ndarray): Logarithm of the separable state(s),
                has shape (2^n, 2^n) if it's a single sigma or (S, 2^n, 2^n),
                with the batch size S (number of sigmas).

        Returns:
            jnp.ndarray: Relative entropies with shape (R,) for a single sigma
                (or S == 1), otherwise (S, R).
        """
        n_rhos = rhos.shape[0]
        if len(log_sigmas.shape) == 3:
            n_sigmas = log_sigmas.shape[0]
            # Flatten the (sigma, rho) grid to the index s * n_rhos + r: tile
            # the whole rho batch once per sigma, and repeat each sigma
            # n_rhos times so both operands have S * R entries for vmap.
            rhos = jnp.tile(rhos, (n_sigmas, 1, 1))
            log_rhos = jnp.tile(log_rhos, (n_sigmas, 1, 1))
            log_sigmas = jnp.repeat(log_sigmas, n_rhos, axis=0)
        else:
            n_sigmas = 1
            log_sigmas = log_sigmas[jnp.newaxis, ...].repeat(n_rhos, axis=0)

        def _f(rho, log_rho, log_sigma):
            # S(rho || sigma) = Tr[rho (log rho - log sigma)]
            prod = jnp.einsum("ij,jk->ik", rho, log_rho - log_sigma)
            return jnp.abs(jnp.trace(prod, axis1=-2, axis2=-1))

        rel_entropies = jax.vmap(_f, in_axes=(0, 0, 0))(rhos, log_rhos, log_sigmas)

        if n_sigmas > 1:
            rel_entropies = rel_entropies.reshape(n_sigmas, n_rhos)
        return rel_entropies

    @staticmethod
    def entanglement_of_formation(
        model: Model,
        n_samples: int,
        random_key: Optional[jax.random.PRNGKey] = None,
        scale: bool = False,
        always_decompose: bool = False,
        **kwargs: Any,
    ) -> float:
        """
        This function implements the entanglement of formation for mixed
        quantum systems.
        A mixed state is decomposed into pure states with respective
        probabilities using the eigendecomposition of the density matrix.
        Then the Meyer-Wallach measure is computed for each pure state and
        weighted by its eigenvalue.
        See e.g. https://doi.org/10.48550/arXiv.quant-ph/0504163

        Note that the decomposition is *not unique*! Therefore, this measure
        presents the entanglement for *some* decomposition into pure states,
        not necessarily the one that is anticipated when applying the Kraus
        channels.
        For a pure state, this gives the same value as the
        Entanglement.meyer_wallach function unless the `always_decompose`
        flag is set.

        Args:
            model (Model): The quantum circuit model.
            n_samples (int): Number of samples per qubit.
                If None or <= 0, the current parameters of the model are used.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for
                parameter initialization. If None, uses the model's internal
                random key.
            scale (bool): Whether to scale the number of samples.
            always_decompose (bool): Whether to explicitly compute the
                entanglement of formation via the eigendecomposition even for
                pure states.
            kwargs (Any): Additional keyword arguments for the model function.

        Returns:
            float: Mean entanglement of formation over all samples, in
                [0, 1] up to numerical error.
        """
        n_samples = Entanglement._scaled_samples(n_samples, model.n_qubits, scale)

        if n_samples is not None and n_samples > 0:
            model.initialize_params(random_key, repeat=n_samples)
        elif len(model.params.shape) <= 2:
            model.params = model.params.reshape(1, *model.params.shape)
        else:
            log.info(f"Using sample size of model params: {model.params.shape[0]}")

        # implicitly set input to none in case it's not needed
        kwargs.setdefault("inputs", None)
        rhos = model(execution_type="density", **kwargs)
        rhos = rhos.reshape(-1, 2**model.n_qubits, 2**model.n_qubits)
        ent = Entanglement._compute_entanglement_of_formation(
            rhos, model.n_qubits, always_decompose
        )
        return ent.mean()

    @staticmethod
    def _compute_entanglement_of_formation(
        rhos: jnp.ndarray,
        n_qubits: int,
        always_decompose: bool,
    ) -> jnp.ndarray:
        """
        Computes the entanglement of formation for a given batch of density
        matrices.

        Args:
            rhos (jnp.ndarray): The density matrices, has shape (B_s, 2^n, 2^n),
                where B_s is the batch size and n the number of qubits.
            n_qubits (int): Number of qubits
            always_decompose (bool): Whether to explicitly compute the
                entanglement of formation via the eigendecomposition even for
                pure states.

        Returns:
            jnp.ndarray: Entanglement for the provided density matrices,
                shape (B_s,).
        """
        dim = 2**n_qubits
        eigenvalues, eigenvectors = jnp.linalg.eigh(rhos)
        # Shortcut: if every state has an eigenvalue of 1 (i.e. is pure), the
        # Meyer-Wallach measure of the state itself is the answer.
        if not always_decompose and jnp.isclose(eigenvalues, 1.0).any(axis=-1).all():
            return Entanglement._compute_meyer_wallach_meas(rhos, n_qubits)

        # eigh returns eigenvectors as COLUMNS: eigenvectors[s, :, i] belongs
        # to eigenvalues[s, i]. Build the projectors |v_i><v_i| indexed as
        # [s, i, j, k] = v_i[j] * conj(v_i[k]).
        projectors = jnp.einsum(
            "sji,ski->sijk", eigenvectors, eigenvectors.conjugate()
        )
        measures = Entanglement._compute_meyer_wallach_meas(
            projectors.reshape(-1, dim, dim), n_qubits
        )
        # Eigenvalue-weighted sum of the pure-state measures per sample
        return jnp.einsum("si,si->s", measures.reshape(-1, dim), eigenvalues)

    @staticmethod
    def concentratable_entanglement(
        model: Model,
        n_samples: int,
        random_key: Optional[jax.random.PRNGKey] = None,
        scale: bool = False,
        **kwargs: Any,
    ) -> float:
        """
        Computes the concentratable entanglement of a given model.

        This method uses the Concentratable Entanglement measure from
        https://arxiv.org/abs/2104.06923. The swap test is implemented
        directly in yaqsi using a ``3 * n_qubits`` circuit.

        Args:
            model (Model): The quantum circuit model.
            n_samples (int): The number of samples to compute the measure for.
                If None or <= 0, the current parameters of the model are used.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for
                parameter initialization. If None, uses the model's internal
                random key.
            scale (bool): Whether to scale the number of samples according to
                the number of qubits.
            **kwargs (Any): Additional keyword arguments for the model function.
                An ``inputs`` entry is used as the circuit input.

        Returns:
            float: Mean concentratable entanglement over all samples.
        """
        n = model.n_qubits
        n_samples = Entanglement._scaled_samples(n_samples, n, scale)

        def _swap_test_circuit(
            params, inputs, pulse_params=None, random_key=None, **kw
        ):
            """Swap-test circuit on 3*n qubits."""
            from qml_essentials.tape import recording as _recording

            # Two copies of the model state: wires n..2n-1 and 2n..3n-1
            for offset in (n, 2 * n):
                with _recording() as copy_tape:
                    model._variational(
                        params,
                        inputs,
                        pulse_params=pulse_params,
                        random_key=random_key,
                        **kw,
                    )
                Entanglement._shift_and_append(copy_tape, offset)

            # Swap test: H on ancilla register (wires 0..n-1)
            for i in range(n):
                op.H(wires=i)

            for i in range(n):
                op.CSWAP(wires=[i, i + n, i + 2 * n])

            for i in range(n):
                op.H(wires=i)

        swap_script = ys.Script(f=_swap_test_circuit, n_qubits=3 * n)

        if n_samples is not None and n_samples > 0:
            random_key = model.initialize_params(random_key, repeat=n_samples)
        elif len(model.params.shape) <= 2:
            model.params = model.params.reshape(1, *model.params.shape)
        else:
            log.info(f"Using sample size of model params: {model.params.shape[0]}")

        params = model.params
        # ``inputs`` is passed positionally to the circuit. Pop it from kwargs
        # so it is not supplied a second time as a keyword argument.
        inputs = model._inputs_validation(kwargs.pop("inputs", None))
        n_batch = params.shape[0]

        marg_probs = jax.jit(ys.marginalize_probs, static_argnums=(1, 2))

        if n_batch > 1:
            from qml_essentials.utils import safe_random_split

            random_keys = safe_random_split(random_key, num=n_batch)
            probs = swap_script.execute(
                type="probs",
                args=(params, inputs, model.pulse_params, random_keys),
                in_axes=(0, None, None, 0),
                kwargs=kwargs,
            )
        else:
            probs = swap_script.execute(
                type="probs",
                args=(params, inputs, model.pulse_params, random_key),
                kwargs=kwargs,
            )

        # Marginalize to the ancilla register (wires 0..n-1)
        probs = marg_probs(probs, 3 * n, tuple(range(n)))

        # CE = 1 - P(ancilla register measured in |0...0>)
        ent = 1 - probs[..., 0]

        log.debug(f"Variance of measure: {ent.var()}")

        return float(ent.mean())

bell_measurements(model, n_samples, random_key=None, scale=False, **kwargs) staticmethod #

Compute the Bell measurement for a given model.

Constructs a 2 * n_qubits circuit that prepares two copies of the model state (on disjoint qubit registers), applies CNOTs and Hadamards, measures the joint probabilities, and marginalizes them onto each qubit pair (q, q + n_qubits).

Parameters:

Name Type Description Default
model Model

The quantum circuit model.

required
n_samples int

The number of samples to compute the measure for.

required
random_key Optional[PRNGKey]

JAX random key for parameter initialization. If None, uses the model's internal random key.

None
scale bool

Whether to scale the number of samples according to the number of qubits.

False
**kwargs Any

Additional keyword arguments for the model function.

{}

Returns:

Name Type Description
float float

The Bell measurement value.

Source code in qml_essentials/entanglement.py
@staticmethod
def bell_measurements(
    model: Model,
    n_samples: int,
    random_key: Optional[jax.random.PRNGKey] = None,
    scale: bool = False,
    **kwargs: Any,
) -> float:
    """
    Compute the Bell measurement for a given model.

    Constructs a ``2 * n_qubits`` circuit that prepares two copies of
    the model state (on disjoint qubit registers), applies CNOTs and
    Hadamards, and measures probabilities on the first register.

    For every qubit pair ``(q, q + n)`` the probability of ``|11>`` is
    marginalised out, turned into ``1 - 2 * P(|11>)``, averaged over samples,
    and combined into ``2 * (1 - value)``. The final mean is clamped to
    ``[0, 1]``.

    Args:
        model (Model): The quantum circuit model.
        n_samples (int): The number of samples to compute the measure for.
            If None or <= 0, the current parameters of the model are used.
        random_key (Optional[jax.random.PRNGKey]): JAX random key for
            parameter initialization. If None, uses the model's internal
            random key.
        scale (bool): Whether to scale the number of samples
            according to the number of qubits. Ignored when ``n_samples``
            is None.
        **kwargs (Any): Additional keyword arguments for the model function.

    Returns:
        float: The Bell measurement value, clamped to [0.0, 1.0].
    """
    if "noise_params" in kwargs:
        log.warning(
            "Bell Measurements not suitable for noisy circuits. "
            "Consider 'concentratable entanglement' instead."
        )

    # ``None`` means "reuse the current model parameters"; only an explicit
    # sample count can be scaled (multiplying None raises a TypeError).
    if scale and n_samples is not None:
        n_samples = jnp.power(2, model.n_qubits) * n_samples

    n = model.n_qubits

    def _bell_circuit(params, inputs, pulse_params=None, random_key=None, **kw):
        """Bell measurement circuit on 2*n qubits."""
        from qml_essentials.tape import active_tape as _active_tape
        from qml_essentials.tape import recording as _recording

        # First copy on wires 0..n-1
        model._variational(
            params, inputs, pulse_params=pulse_params, random_key=random_key, **kw
        )

        # TODO: this is very user-unfriendly and we should find a better way

        # Second copy on wires n..2n-1: record the ops on a scratch tape,
        # then re-register shallow copies with their wires shifted by n.
        with _recording() as shifted_tape:
            model._variational(
                params,
                inputs,
                pulse_params=pulse_params,
                random_key=random_key,
                **kw,
            )

        # Appending ops does not change which tape is active, so look it up
        # once instead of once per recorded op.
        tape = _active_tape()
        if tape is not None:
            for o in shifted_tape:
                shifted_op = o.__class__.__new__(o.__class__)
                shifted_op.__dict__.update(o.__dict__)
                shifted_op._wires = [w + n for w in o.wires]
                tape.append(shifted_op)

        # Bell measurement: CNOT + H
        for q in range(n):
            op.CX(wires=[q, q + n])
            op.H(wires=q)

    bell_script = ys.Script(f=_bell_circuit, n_qubits=2 * n)

    if n_samples is not None and n_samples > 0:
        random_key = model.initialize_params(random_key, repeat=n_samples)
        params = model.params
    else:
        # Reuse the current parameters; add a leading batch axis if missing.
        if len(model.params.shape) <= 2:
            params = model.params.reshape(1, *model.params.shape)
        else:
            log.info(f"Using sample size of model params: {model.params.shape[0]}")
            params = model.params

    n_samples = params.shape[0]
    inputs = model._inputs_validation(kwargs.get("inputs", None))

    # Execute: vmap over batch dimension of params (axis 0)
    if n_samples > 1:
        from qml_essentials.utils import safe_random_split

        random_keys = safe_random_split(random_key, num=n_samples)
        result = bell_script.execute(
            type="probs",
            args=(params, inputs, model.pulse_params, random_keys),
            kwargs=kwargs,
            in_axes=(0, None, None, 0),
        )
    else:
        result = bell_script.execute(
            type="probs",
            args=(params, inputs, model.pulse_params, random_key),
            kwargs=kwargs,
        )

    # Marginalize: for each qubit q, keep wires [q, q+n] from the 2n-qubit probs.
    # The last entry of each 4-element marginal is P(|11>) for that pair.
    per_qubit = [ys.marginalize_probs(result, 2 * n, [q, q + n]) for q in range(n)]
    # per_qubit[q] has shape (n_samples, 4) or (4,)
    exp = jnp.stack(per_qubit, axis=-2)  # (..., n, 4)
    exp = 1 - 2 * exp[..., -1]  # (..., n)

    # Inspect the magnitude of every imaginary component: a signed sum could
    # cancel out and hide non-negligible imaginary parts.
    if jnp.any(jnp.abs(exp.imag) > 1e-6):
        log.warning("Imaginary part of probabilities detected")
        exp = jnp.abs(exp)

    measure = 2 * (1 - exp.mean(axis=0))
    entangling_capability = min(max(float(measure.mean()), 0.0), 1.0)
    log.debug(f"Variance of measure: {measure.var()}")

    return entangling_capability

concentratable_entanglement(model, n_samples, random_key=None, scale=False, **kwargs) staticmethod #

Computes the concentratable entanglement of a given model.

This method utilizes the Concentratable Entanglement measure from https://arxiv.org/abs/2104.06923. The swap test is implemented directly in yaqsi using a 3 * n_qubits circuit.

Parameters:

Name Type Description Default
model Model

The quantum circuit model.

required
n_samples int

The number of samples to compute the measure for.

required
random_key Optional[PRNGKey]

JAX random key for parameter initialization. If None, uses the model's internal random key.

None
scale bool

Whether to scale the number of samples according to the number of qubits.

False
**kwargs Any

Additional keyword arguments for the model function.

{}

Returns:

Name Type Description
float float

Entangling capability of the given circuit, guaranteed to be between 0.0 and 1.0.

Source code in qml_essentials/entanglement.py
@staticmethod
def concentratable_entanglement(
    model: Model,
    n_samples: int,
    random_key: Optional[jax.random.PRNGKey] = None,
    scale: bool = False,
    **kwargs: Any,
) -> float:
    """
    Computes the concentratable entanglement of a given model.

    This method utilizes the Concentratable Entanglement measure from
    https://arxiv.org/abs/2104.06923.  The swap test is implemented
    directly in yaqsi using a ``3 * n_qubits`` circuit: wires ``0..n-1``
    are the ancilla register, and two copies of the model state live on
    wires ``n..2n-1`` and ``2n..3n-1``.

    Args:
        model (Model): The quantum circuit model.
        n_samples (int): The number of samples to compute the measure for.
            If None or <= 0, the current parameters of the model are used.
        random_key (Optional[jax.random.PRNGKey]): JAX random key for
            parameter initialization. If None, uses the model's internal
            random key.
        scale (bool): Whether to scale the number of samples according to
            the number of qubits. Ignored when ``n_samples`` is None.
        **kwargs (Any): Additional keyword arguments for the model function.

    Returns:
        float: Entangling capability of the given circuit, clamped to
            lie between 0.0 and 1.0.
    """
    n = model.n_qubits
    N = 2**n

    # ``None`` means "reuse the current model parameters"; only an explicit
    # sample count can be scaled (multiplying None raises a TypeError).
    if scale and n_samples is not None:
        n_samples = N * n_samples

    def _shift_and_append(tape_ops, offset):
        """Re-register *tape_ops* on the active tape with wires shifted."""
        from qml_essentials.tape import active_tape as _active_tape

        current = _active_tape()
        if current is None:
            return
        for o in tape_ops:
            shifted = o.__class__.__new__(o.__class__)
            shifted.__dict__.update(o.__dict__)
            shifted._wires = [w + offset for w in o.wires]
            current.append(shifted)

    def _swap_test_circuit(
        params, inputs, pulse_params=None, random_key=None, **kw
    ):
        """Swap-test circuit on 3*n qubits."""
        from qml_essentials.tape import recording as _recording

        # First copy on wires n..2n-1
        with _recording() as copy1_tape:
            model._variational(
                params,
                inputs,
                pulse_params=pulse_params,
                random_key=random_key,
                **kw,
            )
        _shift_and_append(copy1_tape, n)

        # Second copy on wires 2n..3n-1
        with _recording() as copy2_tape:
            model._variational(
                params,
                inputs,
                pulse_params=pulse_params,
                random_key=random_key,
                **kw,
            )
        _shift_and_append(copy2_tape, 2 * n)

        # Swap test: H on ancilla register (wires 0..n-1)
        for i in range(n):
            op.H(wires=i)

        for i in range(n):
            op.CSWAP(wires=[i, i + n, i + 2 * n])

        for i in range(n):
            op.H(wires=i)

    swap_script = ys.Script(f=_swap_test_circuit, n_qubits=3 * n)

    if n_samples is not None and n_samples > 0:
        random_key = model.initialize_params(random_key, repeat=n_samples)
    else:
        # Reuse the current parameters; add a leading batch axis if missing.
        if len(model.params.shape) <= 2:
            model.params = model.params.reshape(1, *model.params.shape)
        else:
            log.info(f"Using sample size of model params: {model.params.shape[0]}")

    params = model.params
    inputs = model._inputs_validation(kwargs.get("inputs", None))
    n_batch = params.shape[0]

    # Wire count and kept wires are static so they can be used as shapes.
    marg_probs = jax.jit(ys.marginalize_probs, static_argnums=(1, 2))

    if n_batch > 1:
        from qml_essentials.utils import safe_random_split

        random_keys = safe_random_split(random_key, num=n_batch)
        probs = swap_script.execute(
            type="probs",
            args=(params, inputs, model.pulse_params, random_keys),
            in_axes=(0, None, None, 0),
            kwargs=kwargs,
        )
    else:
        probs = swap_script.execute(
            type="probs",
            args=(params, inputs, model.pulse_params, random_key),
            kwargs=kwargs,
        )

    # Marginalize to the ancilla register (wires 0..n-1)
    probs = marg_probs(probs, 3 * n, tuple(range(n)))

    # CE = 1 - P(ancilla register measured in |0...0>)
    ent = 1 - probs[..., 0]

    log.debug(f"Variance of measure: {ent.var()}")

    # Enforce the documented [0, 1] range; floating-point round-off can push
    # the estimate marginally outside it.
    return min(max(float(ent.mean()), 0.0), 1.0)

entanglement_of_formation(model, n_samples, random_key=None, scale=False, always_decompose=False, **kwargs) staticmethod #

This function implements the entanglement of formation for mixed quantum systems. A mixed state is decomposed into pure states with respective probabilities using the eigendecomposition of the density matrix. Then, the Meyer-Wallach measure is computed for each pure state, weighted by the eigenvalue. See e.g. https://doi.org/10.48550/arXiv.quant-ph/0504163

Note that the decomposition is not unique! Therefore, this measure presents the entanglement for some decomposition into pure states, not necessarily the one that is anticipated when applying the Kraus channels. If a pure state is provided, this results in the same value as the Entanglement.meyer_wallach function if always_decompose flag is not set.

Parameters:

Name Type Description Default
model Model

The quantum circuit model.

required
n_samples int

Number of samples per qubit.

required
random_key Optional[PRNGKey]

JAX random key for parameter initialization. If None, uses the model's internal random key.

None
scale bool

Whether to scale the number of samples.

False
always_decompose bool

Whether to explicitly compute the entanglement of formation for the eigendecomposition of a pure state.

False
kwargs Any

Additional keyword arguments for the model function.

{}

Returns:

Name Type Description
float float

Entangling capacity of the given circuit, guaranteed to be between 0.0 and 1.0.

Source code in qml_essentials/entanglement.py
@staticmethod
def entanglement_of_formation(
    model: Model,
    n_samples: int,
    random_key: Optional[jax.random.PRNGKey] = None,
    scale: bool = False,
    always_decompose: bool = False,
    **kwargs: Any,
) -> float:
    """
    This function implements the entanglement of formation for mixed
    quantum systems.
    A mixed state is decomposed into pure states with respective
    probabilities using the eigendecomposition of the density matrix.
    Then, the Meyer-Wallach measure is computed for each pure state,
    weighted by the eigenvalue.
    See e.g. https://doi.org/10.48550/arXiv.quant-ph/0504163

    Note that the decomposition is *not unique*! Therefore, this measure
    presents the entanglement for *some* decomposition into pure states,
    not necessarily the one that is anticipated when applying the Kraus
    channels.
    If a pure state is provided, this results in the same value as the
    Entanglement.meyer_wallach function if `always_decompose` flag is not set.

    Args:
        model (Model): The quantum circuit model.
        n_samples (int): Number of samples per qubit.
            If None or <= 0, the current parameters of the model are used.
        random_key (Optional[jax.random.PRNGKey]): JAX random key for
            parameter initialization. If None, uses the model's internal
            random key.
        scale (bool): Whether to scale the number of samples. Ignored when
            ``n_samples`` is None.
        always_decompose (bool): Whether to explicitly compute the
            entanglement of formation for the eigendecomposition of a pure
            state.
        kwargs (Any): Additional keyword arguments for the model function.

    Returns:
        float: Entangling capacity of the given circuit, guaranteed
            to be between 0.0 and 1.0.
    """

    # ``None`` means "reuse the current model parameters"; only an explicit
    # sample count can be scaled (multiplying None raises a TypeError).
    if scale and n_samples is not None:
        n_samples = jnp.power(2, model.n_qubits) * n_samples

    if n_samples is not None and n_samples > 0:
        model.initialize_params(random_key, repeat=n_samples)
    else:
        # Reuse the current parameters; add a leading batch axis if missing.
        if len(model.params.shape) <= 2:
            model.params = model.params.reshape(1, *model.params.shape)
        else:
            log.info(f"Using sample size of model params: {model.params.shape[0]}")

    # implicitly set input to none in case it's not needed
    kwargs.setdefault("inputs", None)
    rhos = model(execution_type="density", **kwargs)
    # Flatten any batch dimensions into one axis of (2^n x 2^n) matrices.
    rhos = rhos.reshape(-1, 2**model.n_qubits, 2**model.n_qubits)
    ent = Entanglement._compute_entanglement_of_formation(
        rhos, model.n_qubits, always_decompose
    )
    return ent.mean()

meyer_wallach(model, n_samples, random_key=None, scale=False, **kwargs) staticmethod #

Calculates the entangling capacity of a given quantum circuit using Meyer-Wallach measure.

Parameters:

Name Type Description Default
model Model

The quantum circuit model.

required
n_samples Optional[int]

Number of samples per qubit. If None or <= 0, the current parameters of the model are used.

required
random_key Optional[PRNGKey]

JAX random key for parameter initialization. If None, uses the model's internal random key.

None
scale bool

Whether to scale the number of samples.

False
kwargs Any

Additional keyword arguments for the model function.

{}

Returns:

Name Type Description
float float

Entangling capacity of the given circuit, guaranteed to be between 0.0 and 1.0.

Source code in qml_essentials/entanglement.py
@staticmethod
def meyer_wallach(
    model: Model,
    n_samples: Optional[int],
    random_key: Optional[jax.random.PRNGKey] = None,
    scale: bool = False,
    **kwargs: Any,
) -> float:
    """
    Calculates the entangling capacity of a given quantum circuit
    using Meyer-Wallach measure.

    Args:
        model (Model): The quantum circuit model.
        n_samples (Optional[int]): Number of samples per qubit.
            If None or <= 0, the current parameters of the model are used.
        random_key (Optional[jax.random.PRNGKey]): JAX random key for
            parameter initialization. If None, uses the model's internal
            random key.
        scale (bool): Whether to scale the number of samples. Ignored when
            ``n_samples`` is None.
        kwargs (Any): Additional keyword arguments for the model function.

    Returns:
        float: Entangling capacity of the given circuit, guaranteed
            to be between 0.0 and 1.0.
    """
    if "noise_params" in kwargs:
        log.warning(
            "Meyer-Wallach measure not suitable for noisy circuits. "
            "Consider 'concentratable entanglement' instead."
        )

    # ``None`` means "reuse the current model parameters"; only an explicit
    # sample count can be scaled (multiplying None raises a TypeError).
    if scale and n_samples is not None:
        n_samples = jnp.power(2, model.n_qubits) * n_samples

    if n_samples is not None and n_samples > 0:
        # Draw a fresh batch of parameters. The returned key is not needed:
        # the model call below does not take a random key.
        model.initialize_params(random_key, repeat=n_samples)

    # implicitly set input to none in case it's not needed
    kwargs.setdefault("inputs", None)
    # explicitly set execution type because everything else won't work
    rhos = model(execution_type="density", **kwargs).reshape(
        -1, 2**model.n_qubits, 2**model.n_qubits
    )

    ent = Entanglement._compute_meyer_wallach_meas(rhos, model.n_qubits)

    log.debug(f"Variance of measure: {ent.var()}")

    return ent.mean()

relative_entropy(model, n_samples, n_sigmas, random_key=None, scale=False, **kwargs) staticmethod #

Calculates the relative entropy of entanglement of a given quantum circuit. This measure is also applicable to mixed states, albeit it might not be fully accurate in this simplified case.

The relative entropy of entanglement is defined as the smallest relative entropy from the state in question to the set of separable states. However, as computing the nearest separable state is NP-hard, we select n_sigmas random separable states to compute the distance to, which are not necessarily the nearest. Thus, this measure of entanglement presents an upper limit of entanglement.

As the relative entropy is not necessarily between zero and one, this function also normalises by the relative entropy of the GHZ state.

Parameters:

Name Type Description Default
model Model

The quantum circuit model.

required
n_samples int

Number of samples per qubit. If <= 0, the current parameters of the model are used.

required
n_sigmas int

Number of random separable pure states to compare against.

required
random_key Optional[PRNGKey]

JAX random key for parameter initialization. If None, uses the model's internal random key.

None
scale bool

Whether to scale the number of samples.

False
kwargs Any

Additional keyword arguments for the model function.

{}

Returns:

Name Type Description
float float

Entangling capacity of the given circuit, guaranteed to be between 0.0 and 1.0.

Source code in qml_essentials/entanglement.py
@staticmethod
def relative_entropy(
    model: Model,
    n_samples: int,
    n_sigmas: int,
    random_key: Optional[jax.random.PRNGKey] = None,
    scale: bool = False,
    **kwargs: Any,
) -> float:
    """
    Calculates the relative entropy of entanglement of a given quantum
    circuit. This measure is also applicable to mixed states, albeit it
    might not be fully accurate in this simplified case.

    The relative entropy of entanglement is defined as the smallest
    relative entropy from the state in question to the set of separable
    states. However, as computing the nearest separable state is NP-hard,
    we select ``n_sigmas`` random separable states to compute the distance
    to, which are not necessarily the nearest. Thus, this measure of
    entanglement presents an upper limit of entanglement.

    As the relative entropy is not necessarily between zero and one, this
    function also normalises by the relative entropy of the GHZ state to
    the same separable states.

    Args:
        model (Model): The quantum circuit model.
        n_samples (int): Number of samples per qubit.
            If None or <= 0, the current parameters of the model are used.
        n_sigmas (int): Number of random separable pure states to compare against.
        random_key (Optional[jax.random.PRNGKey]): JAX random key for
            parameter initialization. If None, uses the model's internal
            random key.
        scale (bool): Whether to scale the number of samples (and of
            separable states).
        kwargs (Any): Additional keyword arguments for the model function.

    Returns:
        float: Entangling capacity of the given circuit, guaranteed
            to be between 0.0 and 1.0.
    """
    dim = jnp.power(2, model.n_qubits)
    if scale:
        # ``None`` means "reuse the current model parameters"; only an
        # explicit sample count can be scaled.
        if n_samples is not None:
            n_samples = dim * n_samples
        n_sigmas = dim * n_sigmas

    if random_key is None:
        random_key = model.random_key

    # Random separable states (as log-density matrices)
    log_sigmas = sample_random_separable_states(
        model.n_qubits, n_samples=n_sigmas, random_key=random_key, take_log=True
    )

    # NOTE(review): ``random_key`` was already used by the sampler above
    # before being split here; JAX recommends splitting before first use.
    # Kept as-is so seeded results stay reproducible.
    random_key, _ = jax.random.split(random_key)

    if n_samples is not None and n_samples > 0:
        model.initialize_params(random_key, repeat=n_samples)
    else:
        # Reuse the current parameters; add a leading batch axis if missing.
        if len(model.params.shape) <= 2:
            model.params = model.params.reshape(1, *model.params.shape)
        else:
            log.info(f"Using sample size of model params: {model.params.shape[0]}")

    rhos, log_rhos = Entanglement._compute_log_density(model, **kwargs)

    # Row i holds the relative entropies of all parameter samples w.r.t.
    # the i-th separable state: shape (n_sigmas, n_batch).
    rel_entropies = jnp.zeros((n_sigmas, model.params.shape[0]))

    for i, log_sigma in enumerate(log_sigmas):
        rel_entropies = rel_entropies.at[i].set(
            Entanglement._compute_rel_entropies(rhos, log_rhos, log_sigma)
        )

    # Entropy of GHZ states should be maximal; used as per-sigma normaliser.
    ghz_model = Model(model.n_qubits, 1, "GHZ", data_reupload=False)
    rho_ghz, log_rho_ghz = Entanglement._compute_log_density(ghz_model, **kwargs)
    ghz_entropies = Entanglement._compute_rel_entropies(
        rho_ghz, log_rho_ghz, log_sigmas
    )

    # NOTE(review): assumes ghz_entropies broadcasts along the sigma axis of
    # rel_entropies (n_sigmas, n_batch) — confirm its shape.
    normalised_entropies = rel_entropies / ghz_entropies

    # For each parameter sample, keep the minimum normalised relative
    # entropy over all sampled separable states.
    entangling_capability = normalised_entropies.T.min(axis=1)
    log.debug(f"Variance of measure: {entangling_capability.var()}")

    # Average over parameter samples
    return entangling_capability.mean()

Expressibility#

from qml_essentials.expressibility import Expressibility
Source code in qml_essentials/expressibility.py
class Expressibility:
    @staticmethod
    def _sample_state_fidelities(
        model: Model,
        n_samples: int,
        random_key: Optional[jax.random.PRNGKey] = None,
        kwargs: Any = None,
    ) -> jnp.ndarray:
        """
        Compute the fidelities for each parameter set.

        Draws ``2 * n_samples`` parameter sets and evaluates the model's
        density matrices. It then computes the fidelity
        ``(Tr sqrt(sqrt(rho) sigma sqrt(rho)))^2`` between the i-th state of
        the first half (rho) and the i-th state of the second half (sigma).

        Args:
            model (Model): The quantum circuit model.
            n_samples (int): Number of fidelity pairs to generate.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for
                parameter initialization. If None, uses the model's internal
                random key.
            kwargs (Any): Additional keyword arguments for the model function.
                ``None`` is treated as "no extra arguments".

        Returns:
            jnp.ndarray: Array of shape (n_samples,) containing the fidelities.
        """
        # ``kwargs`` defaults to None, which cannot be unpacked with ``**``.
        if kwargs is None:
            kwargs = {}

        # Generate random parameter sets
        # We need two sets of parameters, as we are computing fidelities for a
        # pair of random states
        model.initialize_params(random_key, repeat=n_samples * 2)

        # Evaluate the model for all parameters
        # Execution type is explicitly set to density
        sv: jnp.ndarray = model(
            params=model.params,
            execution_type="density",
            **kwargs,
        )

        # $\sqrt{\rho}$
        sqrt_sv1: jnp.ndarray = jnp.array([sqrtm(m) for m in sv[:n_samples]])

        # $\sqrt{\rho} \sigma \sqrt{\rho}$
        inner_fidelity = sqrt_sv1 @ sv[n_samples:] @ sqrt_sv1

        # Fidelity: squared trace of the matrix square root of the product
        fidelity: jnp.ndarray = (
            jnp.trace(
                jnp.array([sqrtm(m) for m in inner_fidelity]),
                axis1=1,
                axis2=2,
            )
            ** 2
        )

        # Take the magnitude, since the trace may be complex-valued
        fidelity = jnp.abs(fidelity)

        return fidelity

    @staticmethod
    def state_fidelities(
        n_samples: int,
        n_bins: int,
        model: Model,
        random_key: Optional[jax.random.PRNGKey] = None,
        scale: bool = False,
        **kwargs: Any,
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """
        Sample the state fidelities and histogram them into a 2D array.

        Args:
            n_samples (int): Number of parameter sets to generate.
            n_bins (int): Number of histogram bins.
            model (Model): The quantum circuit model.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for
                parameter initialization. If None, uses the model's internal
                random key.
            scale (bool): Whether to scale the number of samples and bins.
            kwargs (Any): Additional keyword arguments for the model function.

        Returns:
            Tuple[jnp.ndarray, jnp.ndarray]: Tuple containing the bin edges
            (``n_bins + 1`` values in [0, 1]) and the histogram values
            normalised by ``n_samples``.
        """
        if scale:
            n_samples = jnp.power(2, model.n_qubits) * n_samples
            n_bins = model.n_qubits * n_bins

        fidelities = Expressibility._sample_state_fidelities(
            n_samples=n_samples,
            random_key=random_key,
            model=model,
            kwargs=kwargs,
        )

        y: jnp.ndarray = jnp.linspace(0, 1, n_bins + 1)

        z, _ = jnp.histogram(fidelities, bins=y)

        z = z / n_samples

        return y, z

    @staticmethod
    def _haar_probability(fidelity: float, n_qubits: int) -> float:
        """
        Calculates theoretical probability density function for random Haar states
        as proposed by Sim et al. (https://arxiv.org/abs/1905.10876).

        Args:
            fidelity (float): fidelity of two parameter assignments in [0, 1]
            n_qubits (int): number of qubits in the quantum system

        Returns:
            float: probability density ``(N - 1) * (1 - F)^(N - 2)`` with
                ``N = 2**n_qubits``
        """
        N = 2**n_qubits

        prob = (N - 1) * (1 - fidelity) ** (N - 2)
        return prob

    @staticmethod
    def _sample_haar_integral(n_qubits: int, n_bins: int) -> jnp.ndarray:
        """
        Calculates theoretical probability density function for random Haar states
        as proposed by Sim et al. (https://arxiv.org/abs/1905.10876) and bins it
        into a histogram by integrating it over ``n_bins`` equal-width bins of
        [0, 1].

        Args:
            n_qubits (int): number of qubits in the quantum system
            n_bins (int): number of histogram bins

        Returns:
            np.ndarray: probability mass of each fidelity bin
        """
        dist = np.zeros(n_bins)
        for idx in range(n_bins):
            v = idx / n_bins
            u = (idx + 1) / n_bins
            dist[idx], _ = integrate.quad(
                Expressibility._haar_probability, v, u, args=(n_qubits,)
            )

        return dist

    @staticmethod
    def haar_integral(
        n_qubits: int,
        n_bins: int,
        cache: bool = True,
        scale: bool = False,
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """
        Calculates theoretical probability density function for random Haar states
        as proposed by Sim et al. (https://arxiv.org/abs/1905.10876) and bins it
        into a histogram.

        When ``cache`` is True, results are stored as ``.npy`` files in a
        ``.cache`` folder relative to the current working directory and
        reloaded on later calls.

        Args:
            n_qubits (int): number of qubits in the quantum system
            n_bins (int): number of histogram bins
            cache (bool): whether to cache the haar integral
            scale (bool): whether to scale the number of bins

        Returns:
            Tuple[jnp.ndarray, jnp.ndarray]:
                - x component (bins): the input domain
                - y component (probabilities): the haar probability density
                  function for random Haar states
        """
        if scale:
            n_bins = n_qubits * n_bins

        x = jnp.linspace(0, 1, n_bins)

        if cache:
            name = f"haar_{n_qubits}q_{n_bins}s_{'scaled' if scale else ''}.npy"

            cache_folder = ".cache"
            # exist_ok avoids the race between checking for and creating the
            # folder (e.g. several processes computing integrals at once).
            os.makedirs(cache_folder, exist_ok=True)

            file_path = os.path.join(cache_folder, name)

            if os.path.isfile(file_path):
                y = jnp.load(file_path)
                return x, y

        y = Expressibility._sample_haar_integral(n_qubits, n_bins)

        if cache:
            jnp.save(file_path, y)

        return x, y

    @staticmethod
    def kullback_leibler_divergence(
        vqc_prob_dist: jnp.ndarray,
        haar_dist: jnp.ndarray,
    ) -> jnp.ndarray:
        """
        Calculates the KL divergence between two probability distributions (Haar
        probability distribution and the fidelity distribution sampled from a VQC).

        Args:
            vqc_prob_dist (jnp.ndarray): VQC fidelity probability distribution.
                Should have shape (n_inputs_samples, n_bins); a 1-D array is
                treated as a single distribution.
            haar_dist (jnp.ndarray): Haar probability distribution.
                Should have shape (n_bins, )

        Returns:
            np.ndarray: Array of KL-Divergence values for all values in axis 1
        """
        if len(vqc_prob_dist.shape) > 1:
            assert all([haar_dist.shape == p.shape for p in vqc_prob_dist]), (
                "All probabilities for inputs should have the same shape as Haar. "
                f"Got {haar_dist.shape} for Haar and {vqc_prob_dist.shape} for VQC"
            )
        else:
            vqc_prob_dist = vqc_prob_dist.reshape((1, -1))

        kl_divergence = np.zeros(vqc_prob_dist.shape[0])
        for idx, p in enumerate(vqc_prob_dist):
            kl_divergence[idx] = jnp.sum(rel_entr(p, haar_dist))

        return kl_divergence

    @staticmethod
    def kl_divergence_to_haar(
        model: Model,
        n_samples: int,
        n_bins: int,
        random_key: Optional[jax.random.PRNGKey] = None,
        scale: bool = False,
        **kwargs: Any,
    ) -> float:
        """
        Shortcut method to compute the KL-Divergence between a model and the
        Haar distribution. The basic steps are:
            - Sample the state fidelities for randomly initialised parameters.
            - Calculates the KL divergence between the sampled probability and
              the Haar probability distribution.

        Args:
            model (Model): Function that models the quantum circuit.
            n_samples (int): Number of parameter sets to generate.
            n_bins (int): Number of histogram bins.
            random_key (Optional[jax.random.PRNGKey]): JAX random key for
                parameter initialization. If None, uses the model's internal
                random key.
            scale (bool): Whether to scale the number of samples and bins.
            kwargs (Any): Additional keyword arguments for the model function.

        Returns:
            np.ndarray: KL-divergence values, one per sampled fidelity
                distribution (a single entry, since one histogram is sampled).
        """
        _, fidelities = Expressibility.state_fidelities(
            model=model,
            random_key=random_key,
            n_samples=n_samples,
            n_bins=n_bins,
            scale=scale,
            **kwargs,
        )
        _, haar_probs = Expressibility.haar_integral(
            model.n_qubits, n_bins=n_bins, scale=scale
        )
        return Expressibility.kullback_leibler_divergence(fidelities, haar_probs)

haar_integral(n_qubits, n_bins, cache=True, scale=False) staticmethod #

Calculates theoretical probability density function for random Haar states as proposed by Sim et al. (https://arxiv.org/abs/1905.10876) and bins it into a 3D-histogram.

Parameters:

Name Type Description Default
n_qubits int

number of qubits in the quantum system

required
n_bins int

number of histogram bins

required
cache bool

whether to cache the haar integral

True
scale bool

whether to scale the number of bins

False

Returns:

Type Description
Tuple[ndarray, ndarray]

Tuple[jnp.ndarray, jnp.ndarray]: - x component (bins): the input domain - y component (probabilities): the haar probability density function for random Haar states

Source code in qml_essentials/expressibility.py
@staticmethod
def haar_integral(
    n_qubits: int,
    n_bins: int,
    cache: bool = True,
    scale: bool = False,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """
    Calculates theoretical probability density function for random Haar states
    as proposed by Sim et al. (https://arxiv.org/abs/1905.10876) and bins it
    into a histogram.

    When ``cache`` is True, results are stored as ``.npy`` files in a
    ``.cache`` folder relative to the current working directory and
    reloaded on later calls.

    Args:
        n_qubits (int): number of qubits in the quantum system
        n_bins (int): number of histogram bins
        cache (bool): whether to cache the haar integral
        scale (bool): whether to scale the number of bins

    Returns:
        Tuple[jnp.ndarray, jnp.ndarray]:
            - x component (bins): the input domain
            - y component (probabilities): the haar probability density
              function for random Haar states
    """
    if scale:
        n_bins = n_qubits * n_bins

    x = jnp.linspace(0, 1, n_bins)

    if cache:
        name = f"haar_{n_qubits}q_{n_bins}s_{'scaled' if scale else ''}.npy"

        cache_folder = ".cache"
        # exist_ok avoids the race between checking for and creating the
        # folder (e.g. several processes computing integrals at once).
        os.makedirs(cache_folder, exist_ok=True)

        file_path = os.path.join(cache_folder, name)

        if os.path.isfile(file_path):
            y = jnp.load(file_path)
            return x, y

    y = Expressibility._sample_haar_integral(n_qubits, n_bins)

    if cache:
        jnp.save(file_path, y)

    return x, y

kl_divergence_to_haar(model, n_samples, n_bins, random_key=None, scale=False, **kwargs) #

Shortcut method to compute the KL-Divergence between a model and the Haar distribution. The basic steps are: - Sample the state fidelities for randomly initialised parameters. - Calculate the KL divergence between the sampled probability and the Haar probability distribution.

Parameters:

Name Type Description Default
model Model

Function that models the quantum circuit.

required
n_samples int

Number of parameter sets to generate.

required
n_bins int

Number of histogram bins.

required
random_key Optional[PRNGKey]

JAX random key for parameter initialization. If None, uses the model's internal random key.

None
scale bool

Whether to scale the number of samples and bins.

False
kwargs Any

Additional keyword arguments for the model function.

{}

Returns:

Type Description
ndarray

np.ndarray: KL divergence between the model's sampled fidelity distribution and the Haar distribution (a single value, returned as an array of shape (1,)).

Source code in qml_essentials/expressibility.py
def kl_divergence_to_haar(
    model: Model,
    n_samples: int,
    n_bins: int,
    random_key: Optional[jax.random.PRNGKey] = None,
    scale: bool = False,
    **kwargs: Any,
) -> np.ndarray:
    """
    Shortcut method to compute the KL-Divergence between a model and the
    Haar distribution. The basic steps are:
        - Sample the state fidelities for randomly initialised parameters.
        - Calculate the KL divergence between the sampled probability and
          the Haar probability distribution.

    Args:
        model (Model): Function that models the quantum circuit.
        n_samples (int): Number of parameter sets to generate.
        n_bins (int): Number of histogram bins.
        random_key (Optional[jax.random.PRNGKey]): JAX random key for
            parameter initialization. If None, uses the model's internal
            random key.
        scale (bool): Whether to scale the number of samples and bins.
            Applied consistently to both the sampled and the Haar histogram.
        kwargs (Any): Additional keyword arguments for the model function.

    Returns:
        np.ndarray: KL divergence between the model's fidelity histogram and
            the Haar histogram. A single (1D) fidelity histogram is passed to
            `kullback_leibler_divergence`, so the array holds one value
            (shape (1,)).
    """
    # Bin edges (first element) are not needed, only the histogram values.
    _, fidelities = Expressibility.state_fidelities(
        model=model,
        random_key=random_key,
        n_samples=n_samples,
        n_bins=n_bins,
        scale=scale,
        **kwargs,
    )
    # Same n_bins/scale as above, so both histograms have matching lengths.
    _, haar_probs = Expressibility.haar_integral(
        model.n_qubits, n_bins=n_bins, scale=scale
    )
    return Expressibility.kullback_leibler_divergence(fidelities, haar_probs)

kullback_leibler_divergence(vqc_prob_dist, haar_dist) staticmethod #

Calculates the KL divergence between two probability distributions (Haar probability distribution and the fidelity distribution sampled from a VQC).

Parameters:

Name Type Description Default
vqc_prob_dist ndarray

VQC fidelity probability distribution. Should have shape (n_inputs_samples, n_bins)

required
haar_dist ndarray

Haar probability distribution. Should have shape (n_bins,)

required

Returns:

Type Description
ndarray

ndarray: Array with one KL-Divergence value per row of vqc_prob_dist (the relative entropy summed over axis 1)

Source code in qml_essentials/expressibility.py
@staticmethod
def kullback_leibler_divergence(
    vqc_prob_dist: jnp.ndarray,
    haar_dist: jnp.ndarray,
) -> np.ndarray:
    """
    Calculates the KL divergence between two probability distributions (Haar
    probability distribution and the fidelity distribution sampled from a VQC).

    Args:
        vqc_prob_dist (jnp.ndarray): VQC fidelity probability distribution.
            Should have shape (n_inputs_samples, n_bins); a 1D array of
            shape (n_bins,) is treated as a single distribution.
        haar_dist (jnp.ndarray): Haar probability distribution. Should have
            shape (n_bins,)

    Returns:
        np.ndarray: Array with one KL-Divergence value per row of
            `vqc_prob_dist` (the relative entropy summed over axis 1).

    Raises:
        AssertionError: If the per-row shape of `vqc_prob_dist` does not
            match the shape of `haar_dist`.
    """
    # Promote a single distribution to a batch of one row.
    if vqc_prob_dist.ndim < 2:
        vqc_prob_dist = vqc_prob_dist.reshape((1, -1))

    # Validate after reshaping so 1D inputs are checked too. Raised explicitly
    # (not via `assert`) so the check survives `python -O`, while keeping the
    # exception type callers may already handle.
    if vqc_prob_dist.shape[1:] != haar_dist.shape:
        raise AssertionError(
            "All probabilities for inputs should have the same shape as Haar. "
            f"Got {haar_dist.shape} for Haar and {vqc_prob_dist.shape} for VQC"
        )

    kl_divergence = np.zeros(vqc_prob_dist.shape[0])
    for idx, p in enumerate(vqc_prob_dist):
        kl_divergence[idx] = jnp.sum(rel_entr(p, haar_dist))

    return kl_divergence

state_fidelities(n_samples, n_bins, model, random_key=None, scale=False, **kwargs) staticmethod #

Sample the state fidelities and histogram them into normalized bin frequencies.

Parameters:

Name Type Description Default
n_samples int

Number of parameter sets to generate.

required
n_bins int

Number of histogram bins.

required
model Callable

Function that models the quantum circuit.

required
random_key Optional[PRNGKey]

JAX random key for parameter initialization. If None, uses the model's internal random key.

None
scale bool

Whether to scale the number of samples and bins.

False
kwargs Any

Additional keyword arguments for the model function.

{}

Returns:

Type Description
ndarray

Tuple[jnp.ndarray, jnp.ndarray]: Tuple containing the bin edges,

ndarray

and histogram values.

Source code in qml_essentials/expressibility.py
@staticmethod
def state_fidelities(
    n_samples: int,
    n_bins: int,
    model: Model,
    random_key: Optional[jax.random.PRNGKey] = None,
    scale: bool = False,
    **kwargs: Any,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """
    Sample the state fidelities and histogram them.

    Args:
        n_samples (int): Number of parameter sets to generate.
        n_bins (int): Number of histogram bins.
        model (Model): Model of the quantum circuit.
        random_key (Optional[jax.random.PRNGKey]): JAX random key for
            parameter initialization. If None, uses the model's internal
            random key.
        scale (bool): Whether to scale the number of samples (by
            2**n_qubits) and bins (by n_qubits).
        kwargs (Any): Additional keyword arguments for the model function.

    Returns:
        Tuple[jnp.ndarray, jnp.ndarray]: Tuple containing the n_bins + 1 bin
        edges over [0, 1] and the n_bins histogram values, normalized by the
        number of samples.
    """
    if scale:
        # Plain Python ints: `jnp.power` would turn the sample count into a
        # 0-d (int32 by default, overflow-prone) JAX array.
        n_samples = 2**model.n_qubits * n_samples
        n_bins = model.n_qubits * n_bins

    fidelities = Expressibility._sample_state_fidelities(
        n_samples=n_samples,
        random_key=random_key,
        model=model,
        kwargs=kwargs,
    )

    # n_bins + 1 edges -> n_bins histogram buckets over [0, 1].
    y: jnp.ndarray = jnp.linspace(0, 1, n_bins + 1)

    z, _ = jnp.histogram(fidelities, bins=y)

    # Convert counts into relative frequencies.
    z = z / n_samples

    return y, z

Coefficients#

from qml_essentials.coefficients import Coefficients
Source code in qml_essentials/coefficients.py
class Coefficients:
    @staticmethod
    def get_spectrum(
        model: Model,
        mfs: int = 1,
        mts: int = 1,
        shift=False,
        trim=False,
        numerical_cap: Optional[float] = -1,
        **kwargs,
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """
        Extracts the coefficients of a given model using a FFT (jnp-fft).

        Note that the coefficients are complex numbers, but the imaginary part
        of the coefficients should be very close to zero, since the expectation
        values of the Pauli operators are real numbers.

        It can perform oversampling in both the frequency and time domain
        using the `mfs` and `mts` arguments.

        Args:
            model (Model): The model to sample.
            mfs (int): Multiplicator for the highest frequency. Default is 1.
            mts (int): Multiplicator for the number of time samples. Default is 1.
            shift (bool): Whether to apply jnp-fftshift. Default is False.
            trim (bool): Whether to remove the Nyquist frequency along every
                input axis with an even number of samples. Default is False.
            numerical_cap (Optional[float]): Coefficients whose magnitude is
                below this value are set to zero. Disabled when None or <= 0.
                Default is -1 (disabled).
            kwargs (Any): Additional keyword arguments for the model function.
                `force_mean=True` and `execution_type="expval"` are used
                unless given explicitly.

        Returns:
            Tuple[jnp.ndarray, jnp.ndarray]: Tuple containing the coefficients
            and frequencies. For a single input feature the frequencies are
            one array, otherwise a list with one array per input feature.

        Raises:
            ValueError: If the summed coefficients have a non-negligible
                imaginary part.
        """
        kwargs.setdefault("force_mean", True)
        kwargs.setdefault("execution_type", "expval")

        coeffs, freqs = Coefficients._fourier_transform(
            model, mfs=mfs, mts=mts, **kwargs
        )

        if not jnp.isclose(jnp.sum(coeffs).imag, 0.0, rtol=1.0e-5):
            raise ValueError(
                f"Spectrum is not real. Imaginary part of coefficients is:\
                {jnp.sum(coeffs).imag}"
            )

        if trim:
            for ax in range(model.n_input_feat):
                n_ax = coeffs.shape[ax]
                if n_ax % 2 == 0:
                    # For an even-length (unshifted) FFT, index n // 2 holds
                    # the Nyquist bin. Remove it from this coefficient axis and
                    # from the frequency vector of the same input feature only.
                    nyquist = n_ax // 2
                    coeffs = np.delete(coeffs, nyquist, axis=ax)
                    freqs[ax] = np.delete(freqs[ax], nyquist)

        if shift:
            coeffs = jnp.fft.fftshift(coeffs, axes=list(range(model.n_input_feat)))
            # Shift each frequency vector separately: shifting the whole list
            # would also roll the list axis (swapping features) and fails for
            # vectors of different lengths.
            freqs = [np.fft.fftshift(freq) for freq in freqs]

        if numerical_cap is not None and numerical_cap > 0:
            # set coeffs below threshold to zero
            coeffs = jnp.where(
                jnp.abs(coeffs) < numerical_cap,
                jnp.zeros_like(coeffs),
                coeffs,
            )

        if len(freqs) == 1:
            freqs = freqs[0]

        return coeffs, freqs

    @staticmethod
    def _fourier_transform(
        model: Model, mfs: int, mts: int, **kwargs: Any
    ) -> Tuple[jnp.ndarray, List[jnp.ndarray]]:
        """
        Sample the model on an equidistant grid and compute its normalized
        n-dimensional FFT.

        Args:
            model (Model): The model to sample. Its `degree` and
                `n_input_feat` attributes define the sampling grid.
            mfs (int): Multiplicator for the number of frequencies per input
                feature (`mfs * model.degree[i]`).
            mts (int): Multiplicator for the sampled interval, which spans
                [0, 2 * pi * mts) along every input feature.
            kwargs (Any): Additional keyword arguments for the model call.

        Returns:
            Tuple[jnp.ndarray, List[jnp.ndarray]]: The FFT coefficients,
            normalized by the number of grid points, and one `fftfreq` vector
            per input feature.
        """
        # Create a frequency vector with as many frequencies as model degrees,
        # oversampled by mfs
        n_freqs: jnp.ndarray = jnp.array(
            [mfs * model.degree[i] for i in range(model.n_input_feat)]
        )

        step = 2 * jnp.pi / n_freqs
        # Samples per input feature. Building each axis from an integer count
        # (rather than a float `arange(0, 2*pi*mts, step)`, whose length may be
        # off by one due to rounding) keeps it equal to the length of the
        # matching `fftfreq` vector below.
        n_samples = [int(mts * n_freqs[i]) for i in range(model.n_input_feat)]
        # Stretch according to the number of frequencies
        inputs: List = [
            jnp.arange(n_samples[i]) * step[i] for i in range(model.n_input_feat)
        ]

        # permute with input dimensionality
        nd_inputs = jnp.array(
            jnp.meshgrid(*[inputs[i] for i in range(model.n_input_feat)])
        ).T.reshape(-1, model.n_input_feat)

        # Output vector is not necessarily the same length as input
        outputs = model(inputs=nd_inputs, **kwargs)
        outputs = outputs.reshape(
            *[inputs[i].shape[0] for i in range(model.n_input_feat)], -1
        ).squeeze()

        coeffs = jnp.fft.fftn(outputs, axes=list(range(model.n_input_feat)))

        freqs = [
            jnp.fft.fftfreq(n_samples[i], 1 / n_freqs[i])
            for i in range(model.n_input_feat)
        ]

        # TODO: this could cause issues with multidim input
        # FIXME: account for different frequencies in multidim input scenarios
        # Normalize the output (using the product of the grid sizes if multidim)
        return (
            coeffs / math.prod(outputs.shape[0 : model.n_input_feat]),
            freqs,
        )

    @staticmethod
    def get_psd(coeffs: jnp.ndarray) -> jnp.ndarray:
        """
        Calculates the power spectral density (PSD) from given Fourier coefficients.

        The squared magnitude of every coefficient is scaled by 2 / N**2,
        where N = len(coeffs) is the size of the first axis.

        Args:
            coeffs (jnp.ndarray): The Fourier coefficients.

        Returns:
            jnp.ndarray: The power spectral density (same shape as `coeffs`).
        """
        # TODO: if we apply trim=True in advance, this will be slightly wrong..

        def abs2(x):
            return x.real**2 + x.imag**2

        scale = 2.0 / (len(coeffs) ** 2)
        return scale * abs2(coeffs)

    @staticmethod
    def evaluate_Fourier_series(
        coefficients: jnp.ndarray,
        frequencies: jnp.ndarray,
        inputs: Union[jnp.ndarray, list, float],
    ) -> float:
        """
        Evaluate a Fourier series, given by its coefficients and frequencies,
        at the given input point(s).

        Args:
            coefficients (jnp.ndarray): Coefficients of the Fourier series.
            frequencies (jnp.ndarray): Corresponding frequencies. A list is
                interpreted as one frequency vector per input feature (the
                form `get_spectrum` returns for several input features).
            inputs (Union[jnp.ndarray, list, float]): Point(s) at which to
                evaluate the function. Scalars and lists are accepted.
        Returns:
            float: The real part of the function value at the input point,
            squeezed; an array when several points/outputs are evaluated.
        """
        # Ensure a trailing output axis `k` exists for the einsum below.
        if isinstance(frequencies, list):
            if len(coefficients.shape) <= len(frequencies):
                coefficients = coefficients[..., jnp.newaxis]
        else:
            if len(coefficients.shape) == 1:
                coefficients = coefficients[..., jnp.newaxis]

        # asarray handles lists, arrays and plain Python scalars alike
        # (a float has no `.shape`).
        inputs = jnp.asarray(inputs)
        if inputs.ndim < 1:
            inputs = inputs[jnp.newaxis, ...]

        if isinstance(frequencies, list):
            input_dim = len(frequencies)
            frequencies = jnp.stack(jnp.meshgrid(*frequencies))
            if input_dim != len(inputs):
                frequencies = jnp.repeat(
                    frequencies[jnp.newaxis, ...], inputs.shape[0], axis=0
                )
                freq_inputs = jnp.einsum("bi...,b->b...", frequencies, inputs)
                exponents = jnp.exp(1j * freq_inputs).T
                exp = jnp.einsum("jl...k,jl...b->b...k", coefficients, exponents)
            else:
                freq_inputs = jnp.einsum("i...,i->...", frequencies, inputs)
                exponents = jnp.exp(1j * freq_inputs).T
                exp = jnp.einsum("jl...k,jl...->k...", coefficients, exponents)
        else:
            frequencies = jnp.repeat(
                frequencies[jnp.newaxis, ...], inputs.shape[0], axis=0
            )
            freq_inputs = jnp.einsum("i...,i->i...", frequencies, inputs)
            exponents = jnp.exp(1j * freq_inputs)
            exp = jnp.einsum("j...k,ij...->ik...", coefficients, exponents)

        return jnp.squeeze(jnp.real(exp))

evaluate_Fourier_series(coefficients, frequencies, inputs) staticmethod #

Evaluate the function value of a Fourier series at one point.

Parameters:

Name Type Description Default
coefficients ndarray

Coefficients of the Fourier series.

required
frequencies ndarray

Corresponding frequencies.

required
inputs ndarray

Point at which to evaluate the function.

required

Returns: float: The function value at the input point.

Source code in qml_essentials/coefficients.py
@staticmethod
def evaluate_Fourier_series(
    coefficients: jnp.ndarray,
    frequencies: jnp.ndarray,
    inputs: Union[jnp.ndarray, list, float],
) -> float:
    """
    Evaluate a Fourier series, given by its coefficients and frequencies,
    at the given input point(s).

    Args:
        coefficients (jnp.ndarray): Coefficients of the Fourier series.
        frequencies (jnp.ndarray): Corresponding frequencies. A list is
            interpreted as one frequency vector per input feature (the form
            `get_spectrum` returns for several input features).
        inputs (Union[jnp.ndarray, list, float]): Point(s) at which to
            evaluate the function. Scalars and lists are accepted.
    Returns:
        float: The real part of the function value at the input point,
        squeezed; an array when several points/outputs are evaluated.
    """
    # Ensure a trailing output axis `k` exists for the einsum below.
    if isinstance(frequencies, list):
        if len(coefficients.shape) <= len(frequencies):
            coefficients = coefficients[..., jnp.newaxis]
    else:
        if len(coefficients.shape) == 1:
            coefficients = coefficients[..., jnp.newaxis]

    # asarray handles lists, arrays and plain Python scalars alike
    # (a float has no `.shape`).
    inputs = jnp.asarray(inputs)
    if inputs.ndim < 1:
        inputs = inputs[jnp.newaxis, ...]

    if isinstance(frequencies, list):
        input_dim = len(frequencies)
        frequencies = jnp.stack(jnp.meshgrid(*frequencies))
        if input_dim != len(inputs):
            frequencies = jnp.repeat(
                frequencies[jnp.newaxis, ...], inputs.shape[0], axis=0
            )
            freq_inputs = jnp.einsum("bi...,b->b...", frequencies, inputs)
            exponents = jnp.exp(1j * freq_inputs).T
            exp = jnp.einsum("jl...k,jl...b->b...k", coefficients, exponents)
        else:
            freq_inputs = jnp.einsum("i...,i->...", frequencies, inputs)
            exponents = jnp.exp(1j * freq_inputs).T
            exp = jnp.einsum("jl...k,jl...->k...", coefficients, exponents)
    else:
        frequencies = jnp.repeat(
            frequencies[jnp.newaxis, ...], inputs.shape[0], axis=0
        )
        freq_inputs = jnp.einsum("i...,i->i...", frequencies, inputs)
        exponents = jnp.exp(1j * freq_inputs)
        exp = jnp.einsum("j...k,ij...->ik...", coefficients, exponents)

    return jnp.squeeze(jnp.real(exp))

get_psd(coeffs) staticmethod #

Calculates the power spectral density (PSD) from given Fourier coefficients.

Parameters:

Name Type Description Default
coeffs ndarray

The Fourier coefficients.

required

Returns:

Type Description
ndarray

jnp.ndarray: The power spectral density.

Source code in qml_essentials/coefficients.py
@staticmethod
def get_psd(coeffs: jnp.ndarray) -> jnp.ndarray:
    """
    Compute the power spectral density (PSD) of a set of Fourier coefficients.

    Every coefficient's squared magnitude (real**2 + imag**2) is multiplied
    by 2 / N**2, where N = len(coeffs) is the size of the first axis.

    Args:
        coeffs (jnp.ndarray): The Fourier coefficients.

    Returns:
        jnp.ndarray: The power spectral density, same shape as `coeffs`.
    """
    # TODO: if we apply trim=True in advance, this will be slightly wrong..
    n_coeffs = len(coeffs)
    squared_magnitude = coeffs.real**2 + coeffs.imag**2
    return (2.0 / (n_coeffs**2)) * squared_magnitude

get_spectrum(model, mfs=1, mts=1, shift=False, trim=False, numerical_cap=-1, **kwargs) staticmethod #

Extracts the coefficients of a given model using a FFT (jnp-fft).

Note that the coefficients are complex numbers, but the imaginary part of the coefficients should be very close to zero, since the expectation values of the Pauli operators are real numbers.

It can perform oversampling in both the frequency and time domain using the mfs and mts arguments.

Parameters:

Name Type Description Default
model Model

The model to sample.

required
mfs int

Multiplicator for the highest frequency. Default is 1.

1
mts int

Multiplicator for the number of time samples. Default is 1.

1
shift bool

Whether to apply jnp-fftshift. Default is False.

False
trim bool

Whether to remove the Nyquist frequency if spectrum is even. Default is False.

False
numerical_cap Optional[float]

Numerical cap for the coefficients.

-1
kwargs Any

Additional keyword arguments for the model function.

{}

Returns:

Type Description
ndarray

Tuple[jnp.ndarray, jnp.ndarray]: Tuple containing the coefficients

ndarray

and frequencies.

Source code in qml_essentials/coefficients.py
@staticmethod
def get_spectrum(
    model: Model,
    mfs: int = 1,
    mts: int = 1,
    shift=False,
    trim=False,
    numerical_cap: Optional[float] = -1,
    **kwargs,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """
    Extracts the coefficients of a given model using a FFT (jnp-fft).

    Note that the coefficients are complex numbers, but the imaginary part
    of the coefficients should be very close to zero, since the expectation
    values of the Pauli operators are real numbers.

    It can perform oversampling in both the frequency and time domain
    using the `mfs` and `mts` arguments.

    Args:
        model (Model): The model to sample.
        mfs (int): Multiplicator for the highest frequency. Default is 1.
        mts (int): Multiplicator for the number of time samples. Default is 1.
        shift (bool): Whether to apply jnp-fftshift. Default is False.
        trim (bool): Whether to remove the Nyquist frequency along every
            input axis with an even number of samples. Default is False.
        numerical_cap (Optional[float]): Coefficients whose magnitude is
            below this value are set to zero. Disabled when None or <= 0.
            Default is -1 (disabled).
        kwargs (Any): Additional keyword arguments for the model function.
            `force_mean=True` and `execution_type="expval"` are used unless
            given explicitly.

    Returns:
        Tuple[jnp.ndarray, jnp.ndarray]: Tuple containing the coefficients
        and frequencies. For a single input feature the frequencies are one
        array, otherwise a list with one array per input feature.

    Raises:
        ValueError: If the summed coefficients have a non-negligible
            imaginary part.
    """
    kwargs.setdefault("force_mean", True)
    kwargs.setdefault("execution_type", "expval")

    coeffs, freqs = Coefficients._fourier_transform(
        model, mfs=mfs, mts=mts, **kwargs
    )

    if not jnp.isclose(jnp.sum(coeffs).imag, 0.0, rtol=1.0e-5):
        raise ValueError(
            f"Spectrum is not real. Imaginary part of coefficients is:\
            {jnp.sum(coeffs).imag}"
        )

    if trim:
        for ax in range(model.n_input_feat):
            n_ax = coeffs.shape[ax]
            if n_ax % 2 == 0:
                # For an even-length (unshifted) FFT, index n // 2 holds the
                # Nyquist bin. Remove it from this coefficient axis and from
                # the frequency vector of the same input feature only.
                nyquist = n_ax // 2
                coeffs = np.delete(coeffs, nyquist, axis=ax)
                freqs[ax] = np.delete(freqs[ax], nyquist)

    if shift:
        coeffs = jnp.fft.fftshift(coeffs, axes=list(range(model.n_input_feat)))
        # Shift each frequency vector separately: shifting the whole list
        # would also roll the list axis (swapping features) and fails for
        # vectors of different lengths.
        freqs = [np.fft.fftshift(freq) for freq in freqs]

    if numerical_cap is not None and numerical_cap > 0:
        # set coeffs below threshold to zero
        coeffs = jnp.where(
            jnp.abs(coeffs) < numerical_cap,
            jnp.zeros_like(coeffs),
            coeffs,
        )

    if len(freqs) == 1:
        freqs = freqs[0]

    return coeffs, freqs