Coverage for qml_essentials/coefficients.py: 96%
424 statements
« prev ^ index » next coverage.py v7.9.2, created at 2026-02-20 14:03 +0000
« prev ^ index » next coverage.py v7.9.2, created at 2026-02-20 14:03 +0000
1from __future__ import annotations
2import math
3from collections import defaultdict
4from dataclasses import dataclass
5import pennylane as qml
6import jax.numpy as jnp
7from jax import random
8import numpy as np
9from pennylane.operation import Operator
10import pennylane.ops.op_math as qml_op
11from scipy.stats import rankdata
12from functools import reduce
13from typing import List, Tuple, Optional, Any, Dict, Union
15from qml_essentials.model import Model
16from qml_essentials.utils import PauliCircuit
18import logging
# Module-level logger, named after this module.
log = logging.getLogger(name=__name__)
class Coefficients:
    @staticmethod
    def get_spectrum(
        model: Model,
        mfs: int = 1,
        mts: int = 1,
        shift=False,
        trim=False,
        **kwargs,
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """
        Extracts the coefficients of a given model using a FFT (jnp-fft).

        Note that the coefficients are complex numbers, but the imaginary part
        of the coefficients should be very close to zero, since the expectation
        values of the Pauli operators are real numbers.

        It can perform oversampling in both the frequency and time domain
        using the `mfs` and `mts` arguments.

        Args:
            model (Model): The model to sample.
            mfs (int): Multiplicator for the highest frequency. Default is 1.
            mts (int): Multiplicator for the number of time samples. Default is 1.
            shift (bool): Whether to apply jnp-fftshift. Default is False.
            trim (bool): Whether to remove the Nyquist frequency along every
                input axis that has an even number of samples.
                Default is False.
            kwargs (Any): Additional keyword arguments for the model function.
                `force_mean` defaults to True and `execution_type` defaults
                to "expval".

        Returns:
            Tuple[jnp.ndarray, jnp.ndarray]: Tuple containing the coefficients
            and frequencies. For a single input feature the frequencies are a
            single array, otherwise a list with one array per input feature.

        Raises:
            ValueError: If the summed coefficients have a non-negligible
                imaginary part.
        """
        kwargs.setdefault("force_mean", True)
        kwargs.setdefault("execution_type", "expval")

        coeffs, freqs = Coefficients._fourier_transform(
            model, mfs=mfs, mts=mts, **kwargs
        )

        if not jnp.isclose(jnp.sum(coeffs).imag, 0.0, rtol=1.0e-5):
            raise ValueError(
                "Spectrum is not real. Imaginary part of coefficients is: "
                f"{jnp.sum(coeffs).imag}"
            )

        if trim:
            freqs = list(freqs)
            for ax in range(model.n_input_feat):
                n_samples = coeffs.shape[ax]
                if n_samples % 2 == 0:
                    # In the (unshifted) FFT layout, index n // 2 holds the
                    # Nyquist frequency for an even number of samples n.
                    # Only this axis' coefficients and frequencies are trimmed.
                    coeffs = np.delete(coeffs, n_samples // 2, axis=ax)
                    freqs[ax] = np.delete(freqs[ax], len(freqs[ax]) // 2)

        if shift:
            coeffs = jnp.fft.fftshift(coeffs, axes=list(range(model.n_input_feat)))
            # Shift each per-axis frequency vector on its own; the vectors may
            # differ in length, so they must not be stacked into one array.
            freqs = [np.fft.fftshift(freq) for freq in freqs]

        if len(freqs) == 1:
            freqs = freqs[0]

        return coeffs, freqs

    @staticmethod
    def _fourier_transform(
        model: Model, mfs: int, mts: int, **kwargs: Any
    ) -> Tuple[jnp.ndarray, List[jnp.ndarray]]:
        """
        Samples the model on a regular grid over [0, 2*pi*mts) per input
        feature and returns the normalised n-dimensional FFT.

        Args:
            model (Model): The model to sample.
            mfs (int): Multiplicator for the highest frequency.
            mts (int): Multiplicator for the number of time samples.
            kwargs (Any): Additional keyword arguments for the model function.

        Returns:
            Tuple[jnp.ndarray, List[jnp.ndarray]]: The FFT coefficients,
            normalised by the number of grid points. Also returned is a list
            with one (unshifted) frequency vector per input feature.
        """
        # Create a frequency vector with as many frequencies as model degrees,
        # oversampled by mfs
        n_freqs: jnp.ndarray = jnp.array(
            [mfs * model.degree[i] for i in range(model.n_input_feat)]
        )

        step = 2 * jnp.pi / n_freqs
        # Number of samples per axis (mts periods, n_freqs points each).
        # Computed explicitly rather than via jnp.arange with a float step,
        # so that the sample count always matches the fftfreq length below.
        n_samples = [int(mts * n_freqs[i]) for i in range(model.n_input_feat)]
        inputs: List = [
            jnp.arange(n_samples[i]) * step[i] for i in range(model.n_input_feat)
        ]

        # permute with input dimensionality
        nd_inputs = jnp.array(
            jnp.meshgrid(*[inputs[i] for i in range(model.n_input_feat)])
        ).T.reshape(-1, model.n_input_feat)

        # Output vector is not necessarily the same length as input
        outputs = model(inputs=nd_inputs, **kwargs)
        outputs = outputs.reshape(*n_samples, -1).squeeze()

        coeffs = jnp.fft.fftn(outputs, axes=list(range(model.n_input_feat)))

        freqs = [
            jnp.fft.fftfreq(n_samples[i], 1 / n_freqs[i])
            for i in range(model.n_input_feat)
        ]

        # NOTE: different frequencies across input dimensions are not yet
        # accounted for in multidimensional scenarios.
        # Normalise by the total number of grid points (product if multidim).
        return coeffs / math.prod(n_samples), freqs

    @staticmethod
    def get_psd(coeffs: jnp.ndarray) -> jnp.ndarray:
        """
        Calculates the power spectral density (PSD) from given Fourier coefficients.

        Args:
            coeffs (jnp.ndarray): The Fourier coefficients.

        Returns:
            jnp.ndarray: The power spectral density, i.e.
            2 / len(coeffs)**2 * |coeffs|**2.
        """
        # TODO: if trim=True was applied in advance, len(coeffs) differs from
        # the original number of samples and the scale is slightly off.

        def abs2(x):
            return x.real**2 + x.imag**2

        scale = 2.0 / (len(coeffs) ** 2)
        return scale * abs2(coeffs)

    @staticmethod
    def evaluate_Fourier_series(
        coefficients: jnp.ndarray,
        frequencies: jnp.ndarray,
        inputs: Union[jnp.ndarray, list, float],
    ) -> float:
        """
        Evaluate the function value of a Fourier series at one point.

        Args:
            coefficients (jnp.ndarray): Coefficients of the Fourier series.
            frequencies (jnp.ndarray): Corresponding frequencies (a list with
                one array per input feature for multidimensional inputs).
            inputs (jnp.ndarray): Point at which to evaluate the function.

        Returns:
            float: The (real part of the) function value at the input point.
        """
        # Make sure coefficients carry a trailing output axis
        if isinstance(frequencies, list):
            if len(coefficients.shape) <= len(frequencies):
                coefficients = coefficients[..., jnp.newaxis]
        else:
            if len(coefficients.shape) == 1:
                coefficients = coefficients[..., jnp.newaxis]

        if isinstance(inputs, list):
            inputs = jnp.array(inputs)
        if len(inputs.shape) < 1:
            inputs = inputs[jnp.newaxis, ...]

        if isinstance(frequencies, list):
            input_dim = len(frequencies)
            frequencies = jnp.stack(jnp.meshgrid(*frequencies))
            if input_dim != len(inputs):
                # batch of input points
                frequencies = jnp.repeat(
                    frequencies[jnp.newaxis, ...], inputs.shape[0], axis=0
                )
                freq_inputs = jnp.einsum("bi...,b->b...", frequencies, inputs)
                exponents = jnp.exp(1j * freq_inputs).T
                exp = jnp.einsum("jl...k,jl...b->b...k", coefficients, exponents)
            else:
                freq_inputs = jnp.einsum("i...,i->...", frequencies, inputs)
                exponents = jnp.exp(1j * freq_inputs).T
                exp = jnp.einsum("jl...k,jl...->k...", coefficients, exponents)
        else:
            frequencies = jnp.repeat(
                frequencies[jnp.newaxis, ...], inputs.shape[0], axis=0
            )
            freq_inputs = jnp.einsum("i...,i->i...", frequencies, inputs)
            exponents = jnp.exp(1j * freq_inputs)
            exp = jnp.einsum("j...k,ij...->ik...", coefficients, exponents)

        return jnp.squeeze(jnp.real(exp))
class FourierTree:
    """
    Sine-cosine tree representation for the algorithm by Nemkov et al.
    This tree can be used to obtain analytical Fourier coefficients for a given
    Pauli-Clifford circuit.
    """

    class CoefficientsTreeNode:
        """
        Representation of a node in the coefficients tree for the algorithm by
        Nemkov et al.
        """

        def __init__(
            self,
            parameter_idx: Optional[int],
            observable: FourierTree.PauliOperator,
            is_sine_factor: bool,
            is_cosine_factor: bool,
            left: Optional[FourierTree.CoefficientsTreeNode] = None,
            right: Optional[FourierTree.CoefficientsTreeNode] = None,
        ):
            """
            Coefficient tree node initialisation. Each node has information about
            its creation context and its children, i.e.:

            Args:
                parameter_idx (Optional[int]): Index of the corresp. param. index i.
                observable (FourierTree.PauliOperator): The node's observable to
                    obtain the expectation value that contributes to the constant
                    term.
                is_sine_factor (bool): If this node belongs to a sine coefficient.
                is_cosine_factor (bool): If this node belongs to a cosine coefficient.
                left (Optional[CoefficientsTreeNode]): left child (if any).
                right (Optional[CoefficientsTreeNode]): right child (if any).
            """
            self.parameter_idx = parameter_idx

            assert not (
                is_sine_factor and is_cosine_factor
            ), "Cannot be sine and cosine at the same time"
            self.is_sine_factor = is_sine_factor
            self.is_cosine_factor = is_cosine_factor

            # If the observable does not consist of only Z and I, the
            # expectation (and therefore the constant node term) is zero
            if jnp.logical_or(
                observable.list_repr == 0, observable.list_repr == 1
            ).any():
                self.term = 0.0
            else:
                self.term = observable.phase

            self.left = left
            self.right = right

        def evaluate(self, parameters: list[float]) -> float:
            """
            Recursive function to evaluate the expectation of the coefficient tree,
            starting from the current node.

            Args:
                parameters (list[float]): The parameters, by which the circuit (and
                    therefore the tree) is parametrised.

            Returns:
                float: The expectation for the current node and its children
                (complex-typed; callers take the real part).
            """
            factor = (
                parameters[self.parameter_idx]
                if self.parameter_idx is not None
                else 1.0
            )
            if self.is_sine_factor:
                factor = 1j * jnp.sin(factor)
            elif self.is_cosine_factor:
                factor = jnp.cos(factor)
            if not (self.left or self.right):  # leaf
                return factor * self.term

            sum_children = 0.0
            if self.left:
                left = self.left.evaluate(parameters)
                sum_children = sum_children + left
            if self.right:
                right = self.right.evaluate(parameters)
                sum_children = sum_children + right

            return factor * sum_children

        def get_leafs(
            self,
            sin_list: jnp.ndarray,
            cos_list: jnp.ndarray,
            existing_leafs: Optional[List[FourierTree.TreeLeaf]] = None,
        ) -> List[FourierTree.TreeLeaf]:
            """
            Traverse the tree starting from the current node, to obtain the tree
            leafs only.
            The leafs correspond to the terms in the sine-cosine tree
            representation that eventually are used to obtain coefficients and
            frequencies.
            Sine and cosine lists are recursively passed to the children until a
            leaf is reached (top to bottom).
            Leafs are then passed bottom to top to the caller.

            Args:
                sin_list (jnp.ndarray): Current number of sine contributions for
                    each parameter (int array). Has the same length as the
                    parameters, as each position corresponds to one parameter.
                cos_list (jnp.ndarray): Current number of cosine contributions
                    for each parameter (int array). Has the same length as the
                    parameters, as each position corresponds to one parameter.
                existing_leafs (Optional[List[TreeLeaf]]): Current list of leaf
                    nodes from parents. A fresh list is used when omitted.

            Returns:
                List[TreeLeaf]: Updated list of leaf nodes.
            """
            # Avoid a shared mutable default: it would accumulate leafs
            # across independent calls.
            if existing_leafs is None:
                existing_leafs = []

            if self.is_sine_factor:
                sin_list = sin_list.at[self.parameter_idx].set(
                    sin_list[self.parameter_idx] + 1
                )
            if self.is_cosine_factor:
                cos_list = cos_list.at[self.parameter_idx].set(
                    cos_list[self.parameter_idx] + 1
                )

            if not (self.left or self.right):  # leaf
                if self.term != 0.0:
                    return [FourierTree.TreeLeaf(sin_list, cos_list, self.term)]
                else:
                    return []

            if self.left:
                leafs_left = self.left.get_leafs(
                    sin_list.copy(), cos_list.copy(), existing_leafs.copy()
                )
            else:
                leafs_left = []

            if self.right:
                leafs_right = self.right.get_leafs(
                    sin_list.copy(), cos_list.copy(), existing_leafs.copy()
                )
            else:
                leafs_right = []

            existing_leafs.extend(leafs_left)
            existing_leafs.extend(leafs_right)
            return existing_leafs

    @dataclass
    class TreeLeaf:
        """
        Coefficient tree leafs according to the algorithm by Nemkov et al., which
        correspond to the terms in the sine-cosine tree representation that
        eventually are used to obtain coefficients and frequencies.

        Args:
            sin_indices (List[int]): Current number of sine contributions for each
                parameter. Has the same length as the parameters, as each
                position corresponds to one parameter.
            cos_indices (List[int]): Current number of cosine contributions for
                each parameter. Has the same length as the parameters, as each
                position corresponds to one parameter.
            term (jnp.complex): Constant factor of the leaf, depending on the
                expectation value of the observable, and a phase.
        """

        sin_indices: List[int]
        cos_indices: List[int]
        term: jnp.complex128

    class PauliOperator:
        """
        Utility class for storing Pauli Rotations, the corresponding indices
        in the XY-Space (whether there is a gate with X or Y generator at a
        certain qubit) and the phase.

        Args:
            pauli (Union[Operator, jnp.ndarray[int]]): Pauli Rotation Operation
                or list representation
            n_qubits (int): Number of qubits in the circuit
            prev_xy_indices (Optional[jnp.ndarray[bool]]): X/Y indices of the
                previous Pauli sequence. Defaults to None.
            is_observable (bool): If the operator is an observable. Defaults to
                False.
            is_init (bool): If this Pauli operator is initialised the first
                time. Defaults to True.
            phase (float): Phase of the operator. Defaults to 1.0
        """

        def __init__(
            self,
            pauli: Union[Operator, jnp.ndarray[int]],
            n_qubits: int,
            prev_xy_indices: Optional[jnp.ndarray[bool]] = None,
            is_observable: bool = False,
            is_init: bool = True,
            phase: float = 1.0,
        ):
            self.is_observable = is_observable
            self.phase = phase

            if is_init:
                if not is_observable:
                    # Rotation gates are represented by their generator
                    pauli = pauli.generator()[0].base
                self.list_repr = self._create_list_representation(pauli, n_qubits)
            else:
                assert isinstance(pauli, jnp.ndarray)
                self.list_repr = pauli

            if prev_xy_indices is None:
                prev_xy_indices = jnp.zeros(n_qubits, dtype=bool)
            # X/Y positions accumulate over the sequence of rotations
            self.xy_indices = jnp.logical_or(
                prev_xy_indices,
                self._compute_xy_indices(self.list_repr, rev=is_observable),
            )

        @staticmethod
        def _compute_xy_indices(
            op: jnp.ndarray[int], rev: bool = False
        ) -> jnp.ndarray[bool]:
            """
            Computes the positions of X or Y gates to a one-hot encoded boolean
            array.

            Args:
                op (jnp.ndarray[int]): Pauli-Operation list representation.
                rev (bool): Whether to negate the array.

            Returns:
                jnp.ndarray[bool]: One hot encoded boolean array.
            """
            xy_indices = (op == 0) + (op == 1)
            if rev:
                xy_indices = ~xy_indices
            return xy_indices

        @staticmethod
        def _create_list_representation(
            op: Operator, n_qubits: int
        ) -> jnp.ndarray[int]:
            """
            Create list representation of a Pennylane Operator.
            Generally, the list representation is a list of length n_qubits,
            where at each position a Pauli Operator is encoded as such:
            I: -1
            X: 0
            Y: 1
            Z: 2

            Args:
                op (Operator): Pennylane Operator
                n_qubits (int): number of qubits in the circuit

            Returns:
                jnp.ndarray[int]: List representation
            """
            pauli_repr = -jnp.ones(n_qubits, dtype=int)
            op = op.terms()[1][0] if isinstance(op, qml_op.Prod) else op
            op = op.base if isinstance(op, qml_op.SProd) else op

            if isinstance(op, qml.PauliX):
                pauli_repr = pauli_repr.at[op.wires[0]].set(0)
            elif isinstance(op, qml.PauliY):
                pauli_repr = pauli_repr.at[op.wires[0]].set(1)
            elif isinstance(op, qml.PauliZ):
                pauli_repr = pauli_repr.at[op.wires[0]].set(2)
            else:
                for o in op:
                    if isinstance(o, qml.PauliX):
                        pauli_repr = pauli_repr.at[o.wires[0]].set(0)
                    elif isinstance(o, qml.PauliY):
                        pauli_repr = pauli_repr.at[o.wires[0]].set(1)
                    elif isinstance(o, qml.PauliZ):
                        pauli_repr = pauli_repr.at[o.wires[0]].set(2)
            return pauli_repr

        def is_commuting(self, pauli: jnp.ndarray[int]) -> bool:
            """
            Computes if this Pauli commutes with another Pauli operator.
            This computation is based on the fact that the commutator is zero
            if and only if the number of anticommuting single-qubit Paulis is
            even.

            Args:
                pauli (jnp.ndarray[int]): List representation of another Pauli

            Returns:
                bool: If the current and other Pauli are commuting.
            """
            # A position anticommutes iff both are non-identity and differ
            anticommutator = jnp.where(
                pauli < 0,
                False,
                jnp.where(
                    self.list_repr < 0,
                    False,
                    jnp.where(self.list_repr == pauli, False, True),
                ),
            )
            return not (jnp.sum(anticommutator) % 2)

        def tensor(self, pauli: jnp.ndarray[int]) -> FourierTree.PauliOperator:
            """
            Compute tensor product between the current Pauli and a given list
            representation of another Pauli operator.

            Args:
                pauli (jnp.ndarray[int]): List representation of Pauli

            Returns:
                FourierTree.PauliOperator: New Pauli operator object, which
                    contains the tensor product
            """
            diff = (pauli - self.list_repr + 3) % 3
            phase = jnp.where(
                self.list_repr < 0,
                1.0,
                jnp.where(
                    pauli < 0,
                    1.0,
                    jnp.where(
                        diff == 2,
                        1.0j,
                        jnp.where(diff == 1, -1.0j, 1.0),
                    ),
                ),
            )

            obs = jnp.where(
                self.list_repr < 0,
                pauli,
                jnp.where(
                    pauli < 0,
                    self.list_repr,
                    jnp.where(
                        diff == 2,
                        (self.list_repr + 1) % 3,
                        jnp.where(diff == 1, (self.list_repr + 2) % 3, -1),
                    ),
                ),
            )
            phase = self.phase * jnp.prod(phase)
            return FourierTree.PauliOperator(
                obs, phase=phase, n_qubits=obs.size, is_init=False, is_observable=True
            )

    def __init__(self, model: Model, inputs: Optional[jnp.ndarray] = None):
        """
        Tree initialisation, based on the Pauli-Clifford representation of a model.
        Currently, only one input feature is supported.

        **Usage**:
        ```
        # initialise a model
        model = Model(...)

        # initialise and build FourierTree
        tree = FourierTree(model)

        # get expectation value
        exp = tree()

        # Get spectrum (for each observable, we have one list element)
        coeff_list, freq_list = tree.get_spectrum()
        ```

        Args:
            model (Model): The Model, for which to build the tree
            inputs (Optional[jnp.ndarray]): Possible default inputs.
                Defaults to [1.0].
        """
        self.model = model
        self.tree_roots = None

        inputs = (
            self.model._inputs_validation(inputs)
            if inputs is not None
            else self.model._inputs_validation([1.0])
        )

        # TODO: duplicate the input to find out, where it is in the tape. Not
        # really pretty.
        if inputs.shape[0] == 1:
            inputs = jnp.repeat(inputs, 2, axis=0)

        quantum_tape = qml.workflow.construct_tape(self.model.circuit)(
            params=model.params, inputs=inputs
        )

        quantum_tape = PauliCircuit.from_parameterised_circuit(quantum_tape)

        self.parameters = [jnp.squeeze(p) for p in quantum_tape.get_parameters()]

        # Input-dependent parameters are non-scalar due to the duplicated input
        self.input_indices = [
            i for (i, p) in enumerate(self.parameters) if p.shape != ()
        ]

        self.observables = self._encode_observables(quantum_tape.observables)

        pauli_rot = FourierTree.PauliOperator(
            quantum_tape.operations[0],
            self.model.n_qubits,
        )
        self.pauli_rotations = [pauli_rot]
        for op in quantum_tape.operations[1:]:
            pauli_rot = FourierTree.PauliOperator(
                op, self.model.n_qubits, pauli_rot.xy_indices
            )
            self.pauli_rotations.append(pauli_rot)

        self.tree_roots = self.build()
        self.leafs: List[List[FourierTree.TreeLeaf]] = self._get_tree_leafs()

    def __call__(
        self,
        params: Optional[jnp.ndarray] = None,
        inputs: Optional[jnp.ndarray] = None,
        **kwargs,
    ) -> jnp.ndarray:
        """
        Evaluates the Fourier tree via sine-cosine terms sum. This is
        equivalent to computing the expectation value of the observables with
        respect to the corresponding circuit.

        Args:
            params (Optional[jnp.ndarray], optional): Parameters of the model.
                Defaults to None (the model's current parameters).
            inputs (Optional[jnp.ndarray], optional): Inputs to the circuit.
                Defaults to None (input 1.0).
            **kwargs: `force_mean` (average over observables, default False),
                `execution_type` and `noise_params` (see Raises).

        Returns:
            jnp.ndarray: Expectation value of the tree.

        Raises:
            NotImplementedError: When using other "execution_type" as expval.
            NotImplementedError: When using "noise_params"
        """
        params = (
            self.model._params_validation(params)
            if params is not None
            else self.model.params
        )
        inputs = (
            self.model._inputs_validation(inputs)
            if inputs is not None
            else self.model._inputs_validation(1.0)
        )

        if kwargs.get("execution_type", "expval") != "expval":
            raise NotImplementedError(
                f'Currently, only "expval" execution type is supported when '
                f"building FourierTree. Got {kwargs.get('execution_type', 'expval')}."
            )
        if kwargs.get("noise_params", None) is not None:
            raise NotImplementedError(
                "Currently, noise is not supported when building FourierTree."
            )

        # Use the (validated) params argument; the tree parameters for this
        # evaluation are kept local so get_spectrum() is not affected.
        quantum_tape = qml.workflow.construct_tape(self.model.circuit)(
            params=params, inputs=inputs
        )
        quantum_tape = PauliCircuit.from_parameterised_circuit(quantum_tape)

        parameters = [jnp.squeeze(p) for p in quantum_tape.get_parameters()]

        results = jnp.zeros(len(self.tree_roots))
        for i, root in enumerate(self.tree_roots):
            if root is None:
                # Root pruned by early stopping: expectation is zero.
                continue
            results = results.at[i].set(jnp.real(root.evaluate(parameters)))

        if kwargs.get("force_mean", False):
            return jnp.mean(results)
        else:
            return results

    def build(self) -> List[CoefficientsTreeNode]:
        """
        Creates the coefficient tree, i.e. it creates and initialises the tree
        nodes.
        Leafs can be obtained separately in _get_tree_leafs, once the tree is
        set up.

        Returns:
            List[CoefficientsTreeNode]: The list of root nodes (one root for
                each observable). An entry is None if the whole tree for that
                observable was pruned.
        """
        tree_roots = []
        pauli_rotation_idx = len(self.pauli_rotations) - 1
        for obs in self.observables:
            root = self._create_tree_node(obs, pauli_rotation_idx)
            tree_roots.append(root)
        return tree_roots

    def _encode_observables(
        self, tape_obs: List[Operator]
    ) -> List[FourierTree.PauliOperator]:
        """
        Encodes Pennylane observables from tape as FourierTree.PauliOperator
        utility objects.

        Args:
            tape_obs (List[Operator]): Pennylane tape operations

        Returns:
            List[FourierTree.PauliOperator]: List of Pauli operators
        """
        observables = []
        for obs in tape_obs:
            observables.append(
                FourierTree.PauliOperator(obs, self.model.n_qubits, is_observable=True)
            )
        return observables

    def _get_tree_leafs(self) -> List[List[TreeLeaf]]:
        """
        Obtain all Leaf Nodes with its sine- and cosine terms.

        Returns:
            List[List[TreeLeaf]]: For each observable (root), the list of leaf
                nodes (empty for a pruned root).
        """
        leafs = []
        for root in self.tree_roots:
            if root is None:
                leafs.append([])
                continue
            sin_list = jnp.zeros(len(self.parameters), dtype=jnp.int32)
            cos_list = jnp.zeros(len(self.parameters), dtype=jnp.int32)
            leafs.append(root.get_leafs(sin_list, cos_list, []))
        return leafs

    def get_spectrum(
        self, force_mean: bool = False
    ) -> Tuple[List[jnp.ndarray], List[jnp.ndarray]]:
        """
        Computes the Fourier spectrum for the tree, consisting of the
        frequencies and its corresponding coefficients.
        If the flag force_mean is set, the mean coefficient over all
        observables (roots) is computed.

        Args:
            force_mean (bool, optional): Whether to average over multiple
                observables. Defaults to False.

        Returns:
            Tuple[List[jnp.ndarray], List[jnp.ndarray]]:
                - List of coefficients, one array for each observable (root).
                - List of corresponding frequencies, one array for each
                  observable (root).
        """
        input_index_set = set(self.input_indices)
        parameter_indices = [
            i for i in range(len(self.parameters)) if i not in input_index_set
        ]

        coeffs = []
        for leafs in self.leafs:
            freq_terms = defaultdict(np.complex128)
            for leaf in leafs:
                leaf_factor, s, c = self._compute_leaf_factors(leaf, parameter_indices)

                # Expand sin^s * cos^c of the input into exponentials
                for a in range(s + 1):
                    for b in range(c + 1):
                        comb = math.comb(s, a) * math.comb(c, b) * (-1) ** (s - a)
                        freq_terms[2 * a + 2 * b - s - c] += comb * leaf_factor

            coeffs.append(freq_terms)

        frequencies, coefficients = self._freq_terms_to_coeffs(coeffs, force_mean)
        return coefficients, frequencies

    def _freq_terms_to_coeffs(
        self, coeffs: List[Dict[int, jnp.ndarray]], force_mean: bool
    ) -> Tuple[List[jnp.ndarray], List[jnp.ndarray]]:
        """
        Given a list of dictionaries of the form:
        [
            {
                freq_obs1_1: coeff1,
                freq_obs1_2: coeff2,
                ...
            },
            {
                freq_obs2_1: coeff3,
                freq_obs2_2: coeff4,
                ...
            }
            ...
        ],
        Compute two separate lists of frequencies and coefficients.
        such that:
        freqs: [
                [freq_obs1_1, freq_obs1_2, ...],
                [freq_obs2_1, freq_obs2_2, ...],
                ...
            ]
        coeffs: [
                [coeff1, coeff2, ...],
                [coeff3, coeff4, ...],
                ...
            ]

        If force_mean is set, the length of the resulting frequency and
        coefficient lists is 1.

        Args:
            coeffs (List[Dict[int, jnp.ndarray]]): Frequency->Coefficients
                dictionary list, one dict for each observable (root).
            force_mean (bool): Whether to average coefficients over multiple
                observables.

        Returns:
            Tuple[List[jnp.ndarray], List[jnp.ndarray]]:
                - List of frequencies, one list for each observable (root).
                - List of corresponding coefficients, one list for each
                  observable (root).
        """
        frequencies = []
        coefficients = []
        if force_mean:
            all_freqs = sorted(set([f for c in coeffs for f in c.keys()]))
            coefficients.append(
                jnp.array(
                    [
                        jnp.mean(jnp.array([c.get(f, 0.0) for c in coeffs]))
                        for f in all_freqs
                    ]
                )
            )
            frequencies.append(jnp.array(all_freqs))
        else:
            for freq_terms in coeffs:
                freq_terms = dict(sorted(freq_terms.items()))
                frequencies.append(jnp.array(list(freq_terms.keys())))
                coefficients.append(jnp.array(list(freq_terms.values())))
        return frequencies, coefficients

    def _compute_leaf_factors(
        self, leaf: TreeLeaf, parameter_indices: List[int]
    ) -> Tuple[float, int, int]:
        """
        Computes the constant coefficient factor for each leaf.
        Additionally sine and cosine contributions of the input parameters for
        this leaf are returned, which are required to obtain the corresponding
        frequencies.

        Args:
            leaf (TreeLeaf): The leaf for which to compute the factor.
            parameter_indices (List[int]): Variational parameter indices.

        Returns:
            Tuple[float, int, int]:
                - float: the constant factor for the leaf
                - int: number of sine contributions of the input
                - int: number of cosine contributions of the input
        """
        leaf_factor = 1.0
        for i in parameter_indices:
            interm_factor = (
                jnp.cos(self.parameters[i]) ** leaf.cos_indices[i]
                * (1j * jnp.sin(self.parameters[i])) ** leaf.sin_indices[i]
            )
            leaf_factor = leaf_factor * interm_factor

        # Get number of sine and cosine factors to which the input contributes.
        # Plain Python ints are returned so that callers can use them with
        # range()/math.comb (also when there are no input parameters).
        c = sum(int(leaf.cos_indices[k]) for k in self.input_indices)
        s = sum(int(leaf.sin_indices[k]) for k in self.input_indices)

        leaf_factor = leaf.term * leaf_factor * 0.5 ** (s + c)

        return leaf_factor, s, c

    def _early_stopping_possible(
        self, pauli_rotation_idx: int, observable: FourierTree.PauliOperator
    ):
        """
        Checks if a node for an observable can be discarded as all expectation
        values that can result through further branching are zero.
        The method is mentioned in the paper by Nemkov et al.: If the one-hot
        encoded indices for X/Y operations in the Pauli-rotation generators are
        a basis for that of the observable, the node must be processed further.
        If not, it can be discarded.

        Args:
            pauli_rotation_idx (int): Index of remaining Pauli rotation gates.
                Gates itself are attributes of the class.
            observable (FourierTree.PauliOperator): Current observable
        """
        xy_indices_obs = jnp.logical_or(
            observable.xy_indices, self.pauli_rotations[pauli_rotation_idx].xy_indices
        ).all()

        return not xy_indices_obs

    def _create_tree_node(
        self,
        observable: FourierTree.PauliOperator,
        pauli_rotation_idx: int,
        parameter_idx: Optional[int] = None,
        is_sine: bool = False,
        is_cosine: bool = False,
    ) -> Optional[CoefficientsTreeNode]:
        """
        Builds the Fourier-Tree according to the algorithm by Nemkov et al.

        Args:
            observable (FourierTree.PauliOperator): Current observable
            pauli_rotation_idx (int): Index of remaining Pauli rotation gates.
                Gates itself are attributes of the class.
            parameter_idx (Optional[int]): Index of the current parameter.
                Parameters itself are attributes of the class.
            is_sine (bool): If the current node is a sine (right) node.
            is_cosine (bool): If the current node is a cosine (left) node.

        Returns:
            Optional[CoefficientsTreeNode]: The resulting node. Children are set
                recursively. The top level receives the tree root.
        """
        if self._early_stopping_possible(pauli_rotation_idx, observable):
            return None

        # remove commuting paulis
        while pauli_rotation_idx >= 0:
            last_pauli = self.pauli_rotations[pauli_rotation_idx]
            if not observable.is_commuting(last_pauli.list_repr):
                break
            pauli_rotation_idx -= 1
        else:  # leaf
            return FourierTree.CoefficientsTreeNode(
                parameter_idx, observable, is_sine, is_cosine
            )

        last_pauli = self.pauli_rotations[pauli_rotation_idx]

        left = self._create_tree_node(
            observable,
            pauli_rotation_idx - 1,
            pauli_rotation_idx,
            is_cosine=True,
        )

        next_observable = self._create_new_observable(last_pauli.list_repr, observable)
        right = self._create_tree_node(
            next_observable,
            pauli_rotation_idx - 1,
            pauli_rotation_idx,
            is_sine=True,
        )

        return FourierTree.CoefficientsTreeNode(
            parameter_idx,
            observable,
            is_sine,
            is_cosine,
            left,
            right,
        )

    def _create_new_observable(
        self, pauli: jnp.ndarray[int], observable: FourierTree.PauliOperator
    ) -> FourierTree.PauliOperator:
        """
        Utility function to obtain the new observable for a tree node, if the
        last Pauli and the observable do not commute.

        Args:
            pauli (jnp.ndarray[int]): The int array representation of the last
                Pauli rotation in the operation sequence.
            observable (FourierTree.PauliOperator): The current observable.

        Returns:
            FourierTree.PauliOperator: The new observable.
        """
        observable = observable.tensor(pauli)
        return observable
979class FCC:
980 @staticmethod
981 def get_fcc(
982 model: Model,
983 n_samples: int,
984 seed: int,
985 method: Optional[str] = "pearson",
986 scale: Optional[bool] = False,
987 weight: Optional[bool] = False,
988 trim_redundant: Optional[bool] = True,
989 **kwargs,
990 ) -> float:
991 """
992 Shortcut method to get just the FCC.
993 This includes
994 1. What is done in `get_fourier_fingerprint`:
995 1. Calculating the coefficients (using `n_samples` and `seed`)
996 2. Correlating the result from 1) using `method`
997 3. Weighting the correlation matrix (if `weight` is True)
998 4. Remove redundancies
999 2. What is done in `calculate_fcc`:
1000 1. Absolute of the fingerprint
1001 2. Average
1003 Args:
1004 model (Model): The QFM model
1005 n_samples (int): Number of samples to calculate average of coefficients
1006 seed (int): Seed to initialize random parameters
1007 method (Optional[str], optional): Correlation method. Defaults to "pearson".
1008 scale (Optional[bool], optional): Whether to scale the number of samples.
1009 Defaults to False.
1010 weight (Optional[bool], optional): Whether to weight the correlation matrix.
1011 Defaults to False.
1012 trim_redundant (Optional[bool], optional): Whether to remove redundant
1013 correlations. Defaults to False.
1014 **kwargs: Additional keyword arguments for the model function.
1016 Returns:
1017 float: The FCC
1018 """
1019 fourier_fingerprint, _ = FCC.get_fourier_fingerprint(
1020 model,
1021 n_samples,
1022 seed,
1023 method,
1024 scale,
1025 weight,
1026 trim_redundant=trim_redundant,
1027 **kwargs,
1028 )
1030 return FCC.calculate_fcc(fourier_fingerprint)
@staticmethod
def get_fourier_fingerprint(
    model: Model,
    n_samples: int,
    seed: int,
    method: Optional[str] = "pearson",
    scale: Optional[bool] = False,
    weight: Optional[bool] = False,
    trim_redundant: Optional[bool] = True,
    **kwargs,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """
    Shortcut method to get just the fourier fingerprint.
    This includes
    1. Calculating the coefficients (using `n_samples` and `seed`)
    2. Correlating the result from 1) using `method`
    3. Weighting the correlation matrix (if `weight` is True)
    4. Removing redundancies (if `trim_redundant` is True)

    Args:
        model (Model): The QFM model
        n_samples (int): Number of samples to calculate average of coefficients
        seed (int): Seed to initialize random parameters
        method (Optional[str], optional): Correlation method. Defaults to "pearson".
        scale (Optional[bool], optional): Whether to scale the number of samples.
            Defaults to False.
        weight (Optional[bool], optional): Whether to weight the correlation matrix.
            Defaults to False.
        trim_redundant (Optional[bool], optional): Whether to remove redundant
            correlations. Defaults to True.
        **kwargs: Additional keyword arguments for the model function.

    Returns:
        Tuple[jnp.ndarray, jnp.ndarray]: The fourier fingerprint
        and the (unmodified) frequency indices
    """
    _, coeffs, freqs = FCC._calculate_coefficients(
        model, n_samples, seed, scale, **kwargs
    )
    fourier_fingerprint = FCC._correlate(coeffs.transpose(), method=method)

    # perform weighting if requested
    if weight:
        fourier_fingerprint = FCC._weighting(fourier_fingerprint)

    if trim_redundant:
        # Hand a float copy to the mask computation so that the `freqs`
        # returned below can never be altered by it (e.g. negative entries
        # overwritten with NaN).
        mask = FCC._calculate_mask(np.array(freqs, dtype=float))

        # apply the mask on the fingerprint (NaN marks redundant entries)
        fourier_fingerprint = mask * fourier_fingerprint

        # drop rows/columns that contain no finite entry at all
        row_mask = jnp.any(jnp.isfinite(fourier_fingerprint), axis=1)
        col_mask = jnp.any(jnp.isfinite(fourier_fingerprint), axis=0)

        fourier_fingerprint = fourier_fingerprint[row_mask][:, col_mask]

    return fourier_fingerprint, freqs
@staticmethod
def calculate_fcc(
    fourier_fingerprint: jnp.ndarray,
) -> float:
    """
    Method to calculate the FCC based on an existing correlation matrix.
    Takes the absolute value of every entry and averages over the whole
    matrix, ignoring NaN entries (e.g. those masked out as redundant).
    The fingerprint can be obtained via `get_fourier_fingerprint`.

    Args:
        fourier_fingerprint (jnp.ndarray): Correlation matrix of coefficients

    Returns:
        float: The FCC (a 0-d array as produced by `jnp.nanmean`)
    """
    # NaN entries mark masked/redundant correlations and are skipped
    return jnp.nanmean(jnp.abs(fourier_fingerprint))
@staticmethod
def _calculate_mask(freqs: jnp.ndarray) -> jnp.ndarray:
    """
    Method to calculate a mask filtering out redundant elements
    of the Fourier correlation matrix, based on the provided frequency vector.
    It does so by 'simulating' the operations that would be performed
    by `_correlate`.

    The input is not modified; integer and (immutable) JAX arrays are
    accepted as well as NumPy float arrays.

    Args:
        freqs (jnp.ndarray): Array of frequencies (1D, or one row per
            input dimension)

    Returns:
        jnp.ndarray: The mask, containing 1 for entries to keep and NaN
        for redundant ones
    """
    # TODO: this part can be heavily optimized, by e.g. using a "positive_only"
    # flag when calculating the coefficients.
    # However this would change the numerical values
    # (while the order should be still the same).

    # disregard all the negative frequencies; done out-of-place so the
    # caller's array is left untouched
    freqs = jnp.where(jnp.asarray(freqs) < 0, jnp.nan, freqs)
    # compute the outer product of the frequency vectors for arbitrary dimensions
    # or just use the existing frequency vector if it is 1D
    nd_freqs = reduce(jnp.multiply, jnp.ix_(*freqs)) if freqs.ndim > 1 else freqs
    # TODO: could prevent this if we're not using .squeeze()..

    # "simulate" what would happen on correlating the coefficients
    corr_freqs = jnp.outer(nd_freqs, nd_freqs)
    # mask all frequencies that are nan now
    # (i.e. all correlations with a negative frequency component)
    corr_mask = jnp.where(jnp.isnan(corr_freqs), corr_freqs, 1)
    # from this, disregard all the other redundant correlations (i.e. c_0_1 = c_1_0)
    corr_mask = corr_mask.at[jnp.triu_indices(corr_mask.shape[0], 0)].set(jnp.nan)

    return corr_mask
@staticmethod
def _calculate_coefficients(
    model: Model,
    n_samples: int,
    seed: int,
    scale: bool = False,
    **kwargs,
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """
    Calculates the Fourier coefficients of a given model
    using `n_samples` and `seed`.

    If `n_samples > 0`, the model parameters are re-initialized from `seed`
    with `n_samples` repetitions, or with
    ``2**model.n_qubits * n_samples * model.n_input_feat`` repetitions if
    `scale` is True. Otherwise the model's current parameters are used.
    Additional keyword arguments (e.g. `noise_params` for a noisy
    simulation) are forwarded to `Coefficients.get_spectrum`.

    Args:
        model (Model): The QFM model
        n_samples (int): Number of samples to calculate average of coefficients
        seed (int): Seed to initialize random parameters
        scale (bool, optional): Whether to scale the number of samples.
            Defaults to False.
        **kwargs: Additional keyword arguments for the model function.

    Returns:
        Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: The model parameters,
        the (shifted and trimmed) coefficients and the corresponding
        frequencies.
    """
    if n_samples > 0:
        if scale:
            # Plain Python integers cannot overflow, unlike jnp's default
            # int32 which `jnp.power` would produce for large qubit counts.
            total_samples = int(2**model.n_qubits * n_samples * model.n_input_feat)
            log.info(f"Using {total_samples} samples.")
        else:
            total_samples = n_samples
        random_key = random.key(seed)
        model.initialize_params(random_key, repeat=total_samples)

    coeffs, freqs = Coefficients.get_spectrum(
        model, shift=True, trim=True, **kwargs
    )

    return model.params, coeffs, freqs
@staticmethod
def _correlate(mat: jnp.ndarray, method: str = "pearson") -> jnp.ndarray:
    """
    Computes the correlation matrix between the columns of `mat` using
    `method`. Currently, `pearson` and `spearman` are supported.

    All axes after the first are flattened (C order) into a single column
    axis, so an input of shape (N, K1, K2, ...) yields a (K, K) matrix with
    K = K1 * K2 * ...

    Args:
        mat (jnp.ndarray): Array with at least 2 dimensions; the first axis
            indexes the observations (e.g. parameter samples).
        method (str, optional): Correlation method. Defaults to "pearson".

    Raises:
        ValueError: If the method is not supported.

    Returns:
        jnp.ndarray: Correlation matrix of shape (K, K).
    """
    assert len(mat.shape) >= 2, "Input matrix must have at least 2 dimensions"

    # Note that for the general n-D case, we have to flatten along
    # the first axis (last one is batch).
    # Note that the order here is important so we can easily filter out
    # negative coefficients later.
    # Consider the following example: [[1,2,3],[4,5,6],[7,8,9]]
    # we want to get [1, 4, 7, 2, 5, 8, 3, 6, 9]
    # such that after correlation, all positive indexed coefficients
    # will be in the bottom right quadrant
    if method == "pearson":
        result = FCC._pearson(mat.reshape(mat.shape[0], -1))
    elif method == "spearman":
        result = FCC._spearman(mat.reshape(mat.shape[0], -1))
    else:
        # implicit string concatenation keeps source indentation out of
        # the message (a backslash continuation would embed it)
        raise ValueError(
            f"Unknown correlation method: {method}. "
            "Must be 'pearson' or 'spearman'."
        )

    return result
@staticmethod
def _pearson(
    mat: jnp.ndarray, cov: Optional[bool] = False, minp: Optional[int] = 1
) -> jnp.ndarray:
    """
    Pairwise Pearson correlation (or covariance) between the columns of `mat`.

    Adapted from the pandas implementation in
    https://github.com/pandas-dev/pandas/blob/main/pandas/_libs/algos.pyx
    Non-finite entries (NaN or ±Inf) are tolerated: every column pair is
    evaluated only on the rows where both columns are finite.

    Args:
        mat : array_like, shape (N, K)
            Input data.
        cov : bool, optional
            If True, return the sample covariance (denominator n - 1)
            instead of the correlation coefficient. Defaults to False.
        minp : int, optional
            Minimum number of paired observations required to form a value.
            Pairs with fewer jointly finite rows yield NaN.

    Returns:
        corr : ndarray, shape (K, K)
            Symmetric Pearson correlation (or covariance) matrix.
    """
    data = jnp.asarray(mat, dtype=jnp.float64)
    _, n_cols = data.shape

    finite = jnp.isfinite(data)
    out = np.empty((n_cols, n_cols), dtype=jnp.float64)

    def pair_value(col_a, col_b):
        # rows on which both columns hold finite values
        rows = finite[:, col_a] & finite[:, col_b]
        n_obs = jnp.count_nonzero(rows)
        if n_obs < minp:
            return jnp.nan

        xa = data[rows, col_a]
        xb = data[rows, col_b]
        dev_a = xa - xa.mean()
        dev_b = xb - xb.mean()
        cross = jnp.dot(dev_a, dev_b)

        if cov:
            return cross / (n_obs - 1) if n_obs > 1 else jnp.nan

        norm = jnp.sqrt(jnp.dot(dev_a, dev_a) * jnp.dot(dev_b, dev_b))
        if norm == 0.0:
            return jnp.nan
        r = cross / norm
        # guard against floating-point drift just outside [-1, 1]
        if r > 1.0:
            return 1.0
        if r < -1.0:
            return -1.0
        return r

    # TODO: optimize in future iterations
    for col_a in range(n_cols):
        for col_b in range(col_a + 1):
            out[col_a, col_b] = out[col_b, col_a] = pair_value(col_a, col_b)

    return out
@staticmethod
def _spearman(mat: jnp.ndarray, minp: Optional[int] = 1) -> jnp.ndarray:
    """
    Based on Pandas correlation method as implemented here:
    https://github.com/pandas-dev/pandas/blob/main/pandas/_libs/algos.pyx

    Compute Spearman correlation between columns of `mat`,
    permitting missing values (NaN or ±Inf).

    As in pandas, each column pair is evaluated on the rows where both
    columns are finite. If that subset is smaller than the rows used to rank
    either column on its own, the subset is re-ranked, so the result is the
    Spearman correlation of the pairwise-complete observations.

    Args:
        mat : array_like, shape (N, K)
            Input data.
        minp : int, optional
            Minimum number of paired observations required to form a correlation.
            If the number of valid pairs for (i, j) is < minp, the result is NaN.

    Returns:
        corr : ndarray, shape (K, K)
            Spearman correlation matrix.
    """
    N, K = mat.shape
    # trivial all-NaN answer if too few rows
    if N < minp:
        return jnp.full((K, K), jnp.nan, dtype=float)

    # mask of finite entries, shape (N, K)
    mask = jnp.isfinite(mat)
    # number of finite entries per column
    col_counts = mask.sum(axis=0)

    # precompute ranks column-wise ignoring non-finite entries
    ranks = np.full((N, K), jnp.nan, dtype=float)
    for j in range(K):
        valid = mask[:, j]
        if valid.any():
            # rankdata by default gives average ranks for ties
            ranks[valid, j] = rankdata(mat[valid, j], method="average")

    # allocate result
    result = np.empty((K, K), dtype=float)

    # TODO: optimize in future iterations
    # loop lower triangle (including diagonal)
    for i in range(K):
        for j in range(i + 1):
            # find rows where both columns are finite
            valid = mask[:, i] & mask[:, j]
            nobs = valid.sum()

            if nobs < minp:
                rho = jnp.nan
            else:
                if nobs == col_counts[i] and nobs == col_counts[j]:
                    # both columns were ranked on exactly these rows,
                    # so the precomputed ranks can be reused
                    xi = ranks[valid, i]
                    yj = ranks[valid, j]
                else:
                    # rows were dropped for this pair: ranks must be
                    # recomputed on the pairwise-complete subset
                    xi = rankdata(mat[valid, i], method="average")
                    yj = rankdata(mat[valid, j], method="average")
                # subtract means
                xi = xi - xi.mean()
                yj = yj - yj.mean()
                num = jnp.dot(xi, yj)
                den = jnp.sqrt(jnp.dot(xi, xi) * jnp.dot(yj, yj))
                rho = num / den if den > 0 else jnp.nan

            result[i, j] = rho
            result[j, i] = rho

    return result
@staticmethod
def _weighting(fourier_fingerprint: jnp.ndarray) -> jnp.ndarray:
    """
    Performs weighting on the given correlation matrix.
    Here, low-frequent coefficients are weighted more heavily.

    The matrix must be square with an odd side length ``n``. Index ``n // 2``
    is treated as the centre (presumably the zero frequency of an
    fft-shifted, trimmed spectrum). Entry (i, j) is multiplied by
    ``(d_i + d_j) / (2 * (n // 2))`` with ``d_k = n // 2 - |k - n // 2|``.
    The central entry therefore keeps weight 1 and the corners get weight 0.

    Args:
        fourier_fingerprint (jnp.ndarray): Square correlation matrix with
            odd side length.

    Returns:
        jnp.ndarray: The element-wise weighted correlation matrix.
    """
    # TODO: in Future iterations, this can be optimized by computing
    # on the trimmed matrix instead.

    assert (
        fourier_fingerprint.shape[0] % 2 != 0
        and fourier_fingerprint.shape[1] % 2 != 0
    ), "Correlation matrix must have odd dimensions. \
Hint: use `trim` argument when calling `get_spectrum`."
    assert (
        fourier_fingerprint.shape[0] == fourier_fingerprint.shape[1]
    ), "Correlation matrix must be square."

    nc = fourier_fingerprint.shape[0] // 2 + 1
    center = nc - 1
    # closeness of every index to the centre: 0 at the borders, `center`
    # in the middle
    idx = jnp.arange(fourier_fingerprint.shape[0])
    closeness = center - jnp.abs(idx - center)
    weights_matrix = (closeness[:, None] + closeness[None, :]) / (center * 2)

    return fourier_fingerprint * weights_matrix
1423class Datasets:
@staticmethod
def generate_fourier_series(
    random_key: random.PRNGKey,
    model: Model,
    coefficients_min: float = 0.0,
    coefficients_max: float = 1.0,
    zero_centered: bool = False,
) -> List[jnp.ndarray]:
    """
    Generates the Fourier series representation of a function.
    It uses the `model.frequencies` property to retrieve the frequency
    information. This ensures that the resulting Fourier series is
    compatible with the model.

    This function is capable of generating $D$-dimensional Fourier series
    (with $D$ = `model.n_input_feat`). The number of samples per dimension
    is taken from `model.degree`. The product of `model.degree` must be
    odd: only then does the Hermitian-completed coefficient vector have
    exactly one entry per frequency.

    Samples of the Fourier coefficients are drawn from a uniform circle
    (see `Datasets.uniform_circle`).

    Args:
        random_key (random.PRNGKey): Random number key for JAX.
        model (Model): The quantum circuit model.
        coefficients_min (float, optional): Minimum value for the coefficients.
            Defaults to 0.0.
        coefficients_max (float, optional): Maximum value for the coefficients.
            Defaults to 1.0.
        zero_centered (bool, optional): Whether to zero-center the coefficients.
            Defaults to False.

    Returns:
        List[jnp.ndarray]: Three arrays:
            - input domain samples with shape (*model.degree, D)
            - Fourier series values with shape model.degree
            - Fourier coefficients with shape model.degree
    """
    # TODO: the following code can be considered to add some constraints on
    # the spectrum, i.e. it does not fully capture a truly random spectrum.

    # Note: one key observation for understanding the following code is,
    # that instead of wrapping your head around symmetries in multi-
    # dimensional coefficient matrices, one can simply look at the flattened
    # version of such a matrix and reshape later. It just works out.

    # going from [0, 2pi] with the resolution required for highest frequency
    # permute with input dimensionality to get an n-d grid of domain samples
    # the output shape comes from the fact that we want to create a
    # "coordinate system"
    domain_samples_per_input_dim = jnp.stack(
        jnp.meshgrid(
            *[jnp.arange(0, 2 * jnp.pi, 2 * jnp.pi / d) for d in model.degree]
        )
    ).T.reshape(-1, model.n_input_feat)

    # generate the frequency indices for each dimension.
    # this will have the same shape as the domain samples
    frequencies = jnp.stack(jnp.meshgrid(*model.frequencies)).T.reshape(
        -1, model.n_input_feat
    )

    # using the frequency information, sample the non-negative half of the
    # flattened coefficient vector, shape: (prod(model.degree) // 2 + 1,)
    coefficients = Datasets.uniform_circle(
        random_key,
        low=coefficients_min,
        high=coefficients_max,
        size=math.prod(model.degree) // 2 + 1,
    )

    # zero center (first coeff = 0)
    # we can assume the first coeff is the offset, because we're dealing
    # with a non-symmetric spectrum here
    if zero_centered:
        coefficients = coefficients.at[0].set(0.0)
    else:
        coefficients = coefficients.at[0].set(coefficients[0].real)

    # ensure symmetry (here, non_negative_ is removed!),
    # giving us the full coefficients vector
    coefficients = jnp.concat(
        [
            jnp.flip(coefficients[..., 1:]).conjugate(),
            coefficients,
        ],
        axis=-1,
    )

    # Vectorized version of $f(x) = \sum_{n=0}^{N-1} c_n * e^{i * \omega_n * x}$
    # it takes into account the input dimension, i.e. the output is a matrix
    # normalization uses the n_freqs component of the coefficients
    values = jnp.real(
        (
            jnp.exp(1j * (domain_samples_per_input_dim @ frequencies.T))
            * coefficients
        ).sum(axis=1)
        / coefficients.size
    )

    # return all the information we have
    return [
        domain_samples_per_input_dim.reshape(*model.degree, -1),
        values.reshape(model.degree),
        coefficients.reshape(model.degree),
    ]
@staticmethod
def uniform_circle(
    random_key: random.PRNGKey,
    size: Union[jnp.ndarray, List, int],
    low=0.0,
    high=1.0,
):
    """
    Random number generator for complex numbers sampled uniformly (by area)
    from the annulus ``low <= |z| <= high``; with the defaults this is the
    unit disk.

    Args:
        random_key (random.PRNGKey): Random number key for JAX.
        size (Union[jnp.ndarray, List, int]): Output shape. An int ``n`` is
            treated as the 1D shape ``(n,)``.
        low (float, optional): Minimum Radius. Defaults to 0.0.
        high (float, optional): Maximum Radius. Defaults to 1.0.

    Returns:
        jnp.ndarray: Array of complex numbers with shape of `size`
    """
    if isinstance(size, int):
        size = (size,)

    radius_key, phase_key = random.split(random_key)
    # Drawing r**2 uniformly yields a uniform density over the annulus area;
    # squaring the bounds keeps |z| within [low, high] (identical to the
    # previous behaviour for the default bounds 0 and 1).
    radius = jnp.sqrt(
        random.uniform(radius_key, size, minval=low**2, maxval=high**2)
    )
    phase = jnp.exp(2j * jnp.pi * random.uniform(phase_key, size))
    return radius * phase