Coverage for qml_essentials / coefficients.py: 96%
431 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-05-07 09:43 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-05-07 09:43 +0000
1from __future__ import annotations
2import math
3from collections import defaultdict
4from dataclasses import dataclass
5import jax.numpy as jnp
6from jax import random
7import numpy as np
8from scipy.stats import rankdata
9from functools import reduce
10from typing import List, Tuple, Optional, Any, Dict, Union
12from qml_essentials.model import Model
13from qml_essentials.utils import PauliCircuit
14from qml_essentials.operations import (
15 Operation,
16 PauliX,
17 PauliY,
18 PauliZ,
19)
21import logging
23log = logging.getLogger(__name__)
class Coefficients:
    @classmethod
    def get_spectrum(
        cls,
        model: Model,
        mfs: int = 1,
        mts: int = 1,
        shift=False,
        trim=False,
        numerical_cap: Optional[float] = -1,
        **kwargs,
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """
        Extracts the coefficients of a given model using a FFT (jnp-fft).

        Note that the coefficients are complex numbers, but the imaginary part
        of the coefficients should be very close to zero, since the expectation
        values of the Pauli operators are real numbers.

        It can perform oversampling in both the frequency and time domain
        using the `mfs` and `mts` arguments.

        Args:
            model (Model): The model to sample.
            mfs (int): Multiplicator for the highest frequency. Default is 1.
            mts (int): Multiplicator for the number of time samples. Default is 1.
            shift (bool): Whether to apply jnp-fftshift. Default is False.
            trim (bool): Whether to remove the Nyquist frequency if spectrum is even.
                Default is False.
            numerical_cap (Optional[float]): Numerical cap for the coefficients.
                Coefficients with absolute value below this threshold are set to
                zero. ``None`` or a non-positive value disables the cap.
                Default is -1 (disabled).
            kwargs (Any): Additional keyword arguments for the model function.

        Returns:
            Tuple[jnp.ndarray, jnp.ndarray]: Tuple containing the coefficients
            and frequencies.

        Raises:
            ValueError: If the summed imaginary part of the coefficients is not
                close to zero (i.e. the spectrum is not real).
        """
        kwargs.setdefault("force_mean", True)
        kwargs.setdefault("execution_type", "expval")

        coeffs, freqs = cls._fourier_transform(model, mfs=mfs, mts=mts, **kwargs)

        if not jnp.isclose(jnp.sum(coeffs).imag, 0.0, rtol=1.0e-5):
            raise ValueError(
                f"Spectrum is not real. Imaginary part of coefficients is:\
                {jnp.sum(coeffs).imag}"
            )

        if trim:
            # Drop the (ambiguous) Nyquist bin on every even-length axis.
            for ax in range(model.n_input_feat):
                if coeffs.shape[ax] % 2 == 0:
                    coeffs = np.delete(coeffs, len(coeffs) // 2, axis=ax)
                    freqs = [np.delete(freq, len(freq) // 2, axis=ax) for freq in freqs]

        if shift:
            coeffs = jnp.fft.fftshift(coeffs, axes=list(range(model.n_input_feat)))
            freqs = np.fft.fftshift(freqs)

        # Guard against numerical_cap=None: the parameter is Optional, and
        # `None > 0` would raise a TypeError.
        if numerical_cap is not None and numerical_cap > 0:
            # set coeffs below threshold to zero
            coeffs = jnp.where(
                jnp.abs(coeffs) < numerical_cap,
                jnp.zeros_like(coeffs),
                coeffs,
            )

        # For a single input feature, unwrap the one-element frequency list.
        if len(freqs) == 1:
            freqs = freqs[0]

        return coeffs, freqs

    @classmethod
    def _fourier_transform(
        cls, model: Model, mfs: int, mts: int, **kwargs: Any
    ) -> Tuple[jnp.ndarray, List[jnp.ndarray]]:
        """
        Samples the model on an equidistant grid and runs an n-dim FFT.

        Args:
            model (Model): The model to sample.
            mfs (int): Multiplicator for the highest frequency.
            mts (int): Multiplicator for the number of time samples.
            kwargs (Any): Additional keyword arguments for the model function.

        Returns:
            Tuple[jnp.ndarray, List[jnp.ndarray]]: Normalized FFT coefficients
            and the per-feature frequency vectors.
        """
        # Create a frequency vector with as many frequencies as model degrees,
        # oversampled by mfs
        n_freqs: jnp.ndarray = jnp.array(
            [mfs * model.degree[i] for i in range(model.n_input_feat)]
        )

        start, stop, step = 0, 2 * mts * jnp.pi, 2 * jnp.pi / n_freqs
        # Stretch according to the number of frequencies
        inputs: List = [
            jnp.arange(start, stop, step[i]) for i in range(model.n_input_feat)
        ]

        # permute with input dimensionality
        nd_inputs = jnp.array(
            jnp.meshgrid(*[inputs[i] for i in range(model.n_input_feat)])
        ).T.reshape(-1, model.n_input_feat)

        # Output vector is not necessarily the same length as input
        outputs = model(inputs=nd_inputs, **kwargs)
        outputs = outputs.reshape(
            *[inputs[i].shape[0] for i in range(model.n_input_feat)], -1
        ).squeeze()

        coeffs = jnp.fft.fftn(outputs, axes=list(range(model.n_input_feat)))

        freqs = [
            jnp.fft.fftfreq(int(mts * n_freqs[i]), 1 / n_freqs[i])
            for i in range(model.n_input_feat)
        ]

        # TODO: this could cause issues with multidim input
        # FIXME: account for different frequencies in multidim input scenarios
        # Run the fft and rearrange +
        # normalize the output (using product if multidim)
        return (
            coeffs / math.prod(outputs.shape[0 : model.n_input_feat]),
            freqs,
        )

    @classmethod
    def get_psd(cls, coeffs: jnp.ndarray) -> jnp.ndarray:
        """
        Calculates the power spectral density (PSD) from given Fourier coefficients.

        Args:
            coeffs (jnp.ndarray): The Fourier coefficients.

        Returns:
            jnp.ndarray: The power spectral density.
        """
        # TODO: if we apply trim=True in advance, this will be slightly wrong..

        def abs2(x):
            # Squared magnitude without the sqrt of jnp.abs.
            return x.real**2 + x.imag**2

        scale = 2.0 / (len(coeffs) ** 2)
        return scale * abs2(coeffs)

    @classmethod
    def evaluate_Fourier_series(
        cls,
        coefficients: jnp.ndarray,
        frequencies: jnp.ndarray,
        inputs: Union[jnp.ndarray, list, float],
    ) -> float:
        """
        Evaluate the function value of a Fourier series at one point.

        Args:
            coefficients (jnp.ndarray): Coefficients of the Fourier series.
            frequencies (jnp.ndarray): Corresponding frequencies.
            inputs (jnp.ndarray): Point at which to evaluate the function.

        Returns:
            float: The function value at the input point.
        """
        # Ensure coefficients carry a trailing (output) axis.
        if isinstance(frequencies, list):
            if len(coefficients.shape) <= len(frequencies):
                coefficients = coefficients[..., jnp.newaxis]
        else:
            if len(coefficients.shape) == 1:
                coefficients = coefficients[..., jnp.newaxis]

        if isinstance(inputs, list):
            inputs = jnp.array(inputs)
        if len(inputs.shape) < 1:
            inputs = inputs[jnp.newaxis, ...]

        if isinstance(frequencies, list):
            # Multi-dimensional input: build the frequency grid first.
            input_dim = len(frequencies)
            frequencies = jnp.stack(jnp.meshgrid(*frequencies))
            if input_dim != len(inputs):
                # Batched evaluation: broadcast grid over the batch axis.
                frequencies = jnp.repeat(
                    frequencies[jnp.newaxis, ...], inputs.shape[0], axis=0
                )
                freq_inputs = jnp.einsum("bi...,b->b...", frequencies, inputs)
                exponents = jnp.exp(1j * freq_inputs).T
                exp = jnp.einsum("jl...k,jl...b->b...k", coefficients, exponents)
            else:
                freq_inputs = jnp.einsum("i...,i->...", frequencies, inputs)
                exponents = jnp.exp(1j * freq_inputs).T
                exp = jnp.einsum("jl...k,jl...->k...", coefficients, exponents)
        else:
            frequencies = jnp.repeat(
                frequencies[jnp.newaxis, ...], inputs.shape[0], axis=0
            )
            freq_inputs = jnp.einsum("i...,i->i...", frequencies, inputs)
            exponents = jnp.exp(1j * freq_inputs)
            exp = jnp.einsum("j...k,ij...->ik...", coefficients, exponents)

        return jnp.squeeze(jnp.real(exp))
class FourierTree:
    """
    Sine-cosine tree representation for the algorithm by Nemkov et al.
    This tree can be used to obtain analytical Fourier coefficients for a given
    Pauli-Clifford circuit.
    """

    class CoefficientsTreeNode:
        """
        Representation of a node in the coefficients tree for the algorithm by
        Nemkov et al.
        """

        def __init__(
            self,
            parameter_idx: Optional[int],
            observable: FourierTree.PauliOperator,
            is_sine_factor: bool,
            is_cosine_factor: bool,
            left: Optional[FourierTree.CoefficientsTreeNode] = None,
            right: Optional[FourierTree.CoefficientsTreeNode] = None,
        ):
            """
            Coefficient tree node initialisation. Each node has information about
            its creation context and it's children, i.e.:

            Args:
                parameter_idx (Optional[int]): Index of the corresp. param. index i.
                observable (FourierTree.PauliOperator): The nodes observable to
                    obtain the expectation value that contributes to the constant
                    term.
                is_sine_factor (bool): If this node belongs to a sine coefficient.
                is_cosine_factor (bool): If this node belongs to a cosine coefficient.
                left (Optional[CoefficientsTreeNode]): left child (if any).
                right (Optional[CoefficientsTreeNode]): right child (if any).
            """
            self.parameter_idx = parameter_idx

            assert not (is_sine_factor and is_cosine_factor), (
                "Cannot be sine and cosine at the same time"
            )
            self.is_sine_factor = is_sine_factor
            self.is_cosine_factor = is_cosine_factor

            # If the observable does not consist of only Z and I, the
            # expectation (and therefore the constant node term) is zero
            if jnp.logical_or(
                observable.list_repr == 0, observable.list_repr == 1
            ).any():
                self.term = 0.0
            else:
                self.term = observable.phase

            self.left = left
            self.right = right

        def evaluate(self, parameters: list[float]) -> float:
            """
            Recursive function to evaluate the expectation of the coefficient tree,
            starting from the current node.

            Args:
                parameters (list[float]): The parameters, by which the circuit (and
                    therefore the tree) is parametrised.

            Returns:
                float: The expectation for the current node and it's children.
            """
            factor = (
                parameters[self.parameter_idx]
                if self.parameter_idx is not None
                else 1.0
            )
            if self.is_sine_factor:
                factor = 1j * jnp.sin(factor)
            elif self.is_cosine_factor:
                factor = jnp.cos(factor)
            if not (self.left or self.right):  # leaf
                return factor * self.term

            sum_children = 0.0
            if self.left:
                left = self.left.evaluate(parameters)
                sum_children = sum_children + left
            if self.right:
                right = self.right.evaluate(parameters)
                sum_children = sum_children + right

            return factor * sum_children

        def get_leafs(
            self,
            sin_list: List[int],
            cos_list: List[int],
            existing_leafs: Optional[List[FourierTree.TreeLeaf]] = None,
        ) -> List[FourierTree.TreeLeaf]:
            """
            Traverse the tree starting from the current node, to obtain the tree
            leafs only.
            The leafs correspond to the terms in the sine-cosine tree
            representation that eventually are used to obtain coefficients and
            frequencies.
            Sine and cosine lists are recursively passed to the children until a
            leaf is reached (top to bottom).
            Leafs are then passed bottom to top to the caller.

            Args:
                sin_list (List[int]): Current number of sine contributions for each
                    parameter. Has the same length as the parameters, as each
                    position corresponds to one parameter.
                cos_list (List[int]): Current number of cosine contributions for
                    each parameter. Has the same length as the parameters, as each
                    position corresponds to one parameter.
                existing_leafs (Optional[List[TreeLeaf]]): Current list of leaf
                    nodes from parents. Defaults to None (empty list).

            Returns:
                List[TreeLeaf]: Updated list of leaf nodes.
            """
            # NOTE: previously `existing_leafs=[]` — a mutable default argument
            # shared across calls. Use a None sentinel instead.
            if existing_leafs is None:
                existing_leafs = []

            if self.is_sine_factor:
                sin_list = sin_list.at[self.parameter_idx].set(
                    sin_list[self.parameter_idx] + 1
                )
            if self.is_cosine_factor:
                cos_list = cos_list.at[self.parameter_idx].set(
                    cos_list[self.parameter_idx] + 1
                )

            if not (self.left or self.right):  # leaf
                if self.term != 0.0:
                    return [FourierTree.TreeLeaf(sin_list, cos_list, self.term)]
                else:
                    return []

            if self.left:
                leafs_left = self.left.get_leafs(
                    sin_list.copy(), cos_list.copy(), existing_leafs.copy()
                )
            else:
                leafs_left = []

            if self.right:
                leafs_right = self.right.get_leafs(
                    sin_list.copy(), cos_list.copy(), existing_leafs.copy()
                )
            else:
                leafs_right = []

            existing_leafs.extend(leafs_left)
            existing_leafs.extend(leafs_right)
            return existing_leafs

    @dataclass
    class TreeLeaf:
        """
        Coefficient tree leafs according to the algorithm by Nemkov et al., which
        correspond to the terms in the sine-cosine tree representation that
        eventually are used to obtain coefficients and frequencies.

        Args:
            sin_indices (List[int]): Current number of sine contributions for each
                parameter. Has the same length as the parameters, as each
                position corresponds to one parameter.
            cos_indices (List[int]): Current number of cosine contributions for
                each parameter. Has the same length as the parameters, as each
                position corresponds to one parameter.
            term (complex): Constant factor of the leaf, depending on the
                expectation value of the observable, and a phase.
        """

        sin_indices: List[int]
        cos_indices: List[int]
        term: complex

    class PauliOperator:
        """
        Utility class for storing Pauli Rotations, the corresponding indices
        in the XY-Space (whether there is a gate with X or Y generator at a
        certain qubit) and the phase.

        Args:
            pauli (Union[Operation, jnp.ndarray[int]]): Pauli Rotation Operation
                or list representation
            n_qubits (int): Number of qubits in the circuit
            prev_xy_indices (Optional[jnp.ndarray[bool]]): X/Y indices of the
                previous Pauli sequence. Defaults to None.
            is_observable (bool): If the operator is an observable. Defaults to
                False.
            is_init (bool): If this Pauli operator is initialised the first
                time. Defaults to True.
            phase (float): Phase of the operator. Defaults to 1.0
        """

        def __init__(
            self,
            pauli: Union[Operation, jnp.ndarray[int]],
            n_qubits: int,
            prev_xy_indices: Optional[jnp.ndarray[bool]] = None,
            is_observable: bool = False,
            is_init: bool = True,
            phase: float = 1.0,
        ):
            self.is_observable = is_observable
            self.phase = phase

            if is_init:
                if not is_observable:
                    # Rotations are encoded by their generator.
                    pauli = pauli.generator()
                self.list_repr = self._create_list_representation(pauli, n_qubits)
            else:
                assert isinstance(pauli, jnp.ndarray)
                self.list_repr = pauli

            if prev_xy_indices is None:
                prev_xy_indices = jnp.zeros(n_qubits, dtype=bool)
            self.xy_indices = jnp.logical_or(
                prev_xy_indices,
                self._compute_xy_indices(self.list_repr, rev=is_observable),
            )

        @staticmethod
        def _compute_xy_indices(
            op: jnp.ndarray[int], rev: bool = False
        ) -> jnp.ndarray[bool]:
            """
            Computes the positions of X or Y gates to an one-hot encoded boolean
            array.

            Args:
                op (jnp.ndarray[int]): Pauli-Operation list representation.
                rev (bool): Whether to negate the array.

            Returns:
                jnp.ndarray[bool]: One hot encoded boolean array.
            """
            xy_indices = (op == 0) + (op == 1)
            if rev:
                xy_indices = ~xy_indices
            return xy_indices

        @staticmethod
        def _create_list_representation(
            op: Operation, n_qubits: int
        ) -> jnp.ndarray[int]:
            """
            Create list representation of an Operation.
            Generally, the list representation is a list of length n_qubits,
            where at each position a Pauli Operator is encoded as such:
            I: -1
            X: 0
            Y: 1
            Z: 2

            Args:
                op (Operation): Gate operation (PauliX, PauliY, PauliZ, or
                    Hermitian wrapping a multi-qubit Pauli tensor product).
                n_qubits (int): number of qubits in the circuit

            Returns:
                jnp.ndarray[int]: List representation
            """
            pauli_repr = -jnp.ones(n_qubits, dtype=int)

            _NAME_TO_IDX = {"PauliX": 0, "PauliY": 1, "PauliZ": 2}

            if op.name in _NAME_TO_IDX:
                pauli_repr = pauli_repr.at[op.wires[0]].set(_NAME_TO_IDX[op.name])
            elif isinstance(op, PauliX):
                pauli_repr = pauli_repr.at[op.wires[0]].set(0)
            elif isinstance(op, PauliY):
                pauli_repr = pauli_repr.at[op.wires[0]].set(1)
            elif isinstance(op, PauliZ):
                pauli_repr = pauli_repr.at[op.wires[0]].set(2)
            else:
                # Multi-qubit case: decompose via pauli_string_from_operation
                from qml_essentials.operations import pauli_string_from_operation

                pauli_str = pauli_string_from_operation(op)
                char_to_idx = {"X": 0, "Y": 1, "Z": 2, "I": -1}
                for i, (wire, ch) in enumerate(zip(op.wires, pauli_str)):
                    idx = char_to_idx.get(ch, -1)
                    if idx >= 0:
                        pauli_repr = pauli_repr.at[wire].set(idx)

            return pauli_repr

        def is_commuting(self, pauli: jnp.ndarray[int]) -> bool:
            """
            Computes if this Pauli commutes with another Pauli operator.
            This computation is based on the fact that the commutator is zero
            if and only if the number of anticommuting single-qubit Paulis is
            even.

            Args:
                pauli (jnp.ndarray[int]): List representation of another Pauli

            Returns:
                bool: If the current and other Pauli are commuting.
            """
            # Per qubit: anticommuting iff both are non-identity and differ.
            anticommutator = jnp.where(
                pauli < 0,
                False,
                jnp.where(
                    self.list_repr < 0,
                    False,
                    jnp.where(self.list_repr == pauli, False, True),
                ),
            )
            return not (jnp.sum(anticommutator) % 2)

        def tensor(self, pauli: jnp.ndarray[int]) -> FourierTree.PauliOperator:
            """
            Compute tensor product between the current Pauli and a given list
            representation of another Pauli operator.

            Args:
                pauli (jnp.ndarray[int]): List representation of Pauli

            Returns:
                FourierTree.PauliOperator: New Pauli operator object, which
                    contains the tensor product
            """
            # Cyclic difference between Pauli indices determines the product
            # and the accumulated phase (i / -i for odd permutations).
            diff = (pauli - self.list_repr + 3) % 3
            phase = jnp.where(
                self.list_repr < 0,
                1.0,
                jnp.where(
                    pauli < 0,
                    1.0,
                    jnp.where(
                        diff == 2,
                        1.0j,
                        jnp.where(diff == 1, -1.0j, 1.0),
                    ),
                ),
            )

            obs = jnp.where(
                self.list_repr < 0,
                pauli,
                jnp.where(
                    pauli < 0,
                    self.list_repr,
                    jnp.where(
                        diff == 2,
                        (self.list_repr + 1) % 3,
                        jnp.where(diff == 1, (self.list_repr + 2) % 3, -1),
                    ),
                ),
            )
            phase = self.phase * jnp.prod(phase)
            return FourierTree.PauliOperator(
                obs, phase=phase, n_qubits=obs.size, is_init=False, is_observable=True
            )

    def __init__(self, model: Model):
        """
        Tree initialisation, based on the Pauli-Clifford representation of a model.
        Currently, only one input feature is supported.

        **Usage**:
        ```
        # initialise a model
        model = Model(...)

        # initialise and build FourierTree
        tree = FourierTree(model)

        # get expectation value
        exp = tree()

        # Get spectrum (for each observable, we have one list element)
        coeff_list, freq_list = tree.spectrum()
        ```

        Args:
            model (Model): The Model, for which to build the tree
        """
        self.model = model
        self.tree_roots = None

        inputs = self.model._inputs_validation([1.0])

        # Record the circuit tape using yaqsi's tape recording
        raw_tape = self.model.script._record(params=model.params, inputs=inputs)

        # Build observables from the model's output_qubit configuration
        _, obs_list = self.model._build_obs()

        quantum_tape = PauliCircuit.from_parameterised_circuit(
            raw_tape, observables=obs_list
        )

        self.parameters = [jnp.squeeze(p) for p in quantum_tape.get_parameters()]

        self.input_indices, self.all_input_indices = quantum_tape.get_input_indices()

        self.observables = self._encode_observables(quantum_tape.observables)

        # Chain the Pauli rotations so each carries the cumulative X/Y
        # indices of its predecessors.
        pauli_rot = FourierTree.PauliOperator(
            quantum_tape.operations[0],
            self.model.n_qubits,
        )
        self.pauli_rotations = [pauli_rot]
        for op in quantum_tape.operations[1:]:
            pauli_rot = FourierTree.PauliOperator(
                op, self.model.n_qubits, pauli_rot.xy_indices
            )
            self.pauli_rotations.append(pauli_rot)

        self.tree_roots = self.build()
        self.leafs: List[List[FourierTree.TreeLeaf]] = self._get_tree_leafs()

    def __call__(
        self,
        params: Optional[jnp.ndarray] = None,
        inputs: Optional[jnp.ndarray] = None,
        **kwargs,
    ) -> jnp.ndarray:
        """
        Evaluates the Fourier tree via sine-cosine terms sum. This is
        equivalent to computing the expectation value of the observables with
        respect to the corresponding circuit.

        Args:
            params (Optional[jnp.ndarray], optional): Parameters of the model.
                Defaults to None.
            inputs (Optional[jnp.ndarray], optional): Inputs to the circuit.
                Defaults to None.

        Returns:
            jnp.ndarray: Expectation value of the tree.

        Raises:
            NotImplementedError: When using other "execution_type" as expval.
            NotImplementedError: When using "noise_params"
        """
        params = (
            self.model._params_validation(params)
            if params is not None
            else self.model.params
        )
        inputs = (
            self.model._inputs_validation(inputs)
            if inputs is not None
            else self.model._inputs_validation(1.0)
        )

        if kwargs.get("execution_type", "expval") != "expval":
            raise NotImplementedError(
                f'Currently, only "expval" execution type is supported when '
                f"building FourierTree. Got {kwargs.get('execution_type', 'expval')}."
            )
        if kwargs.get("noise_params", None) is not None:
            raise NotImplementedError(
                "Currently, noise is not supported when building FourierTree."
            )

        # Record the circuit tape using yaqsi's tape recording
        raw_tape = self.model.script._record(params=self.model.params, inputs=inputs)

        # Build observables from the model's output_qubit configuration
        _, obs_list = self.model._build_obs()

        quantum_tape = PauliCircuit.from_parameterised_circuit(
            raw_tape, observables=obs_list
        )

        self.parameters = [jnp.squeeze(p) for p in quantum_tape.get_parameters()]

        results = jnp.zeros(len(self.tree_roots))
        for i, root in enumerate(self.tree_roots):
            results = results.at[i].set(jnp.real(root.evaluate(self.parameters)))

        if kwargs.get("force_mean", False):
            return jnp.mean(results)
        else:
            return results

    def build(self) -> List[CoefficientsTreeNode]:
        """
        Creates the coefficient tree, i.e. it creates and initialises the tree
        nodes.
        Leafs can be obtained separately in _get_tree_leafs, once the tree is
        set up.

        Returns:
            List[CoefficientsTreeNode]: The list of root nodes (one root for
                each observable).
        """
        tree_roots = []
        pauli_rotation_idx = len(self.pauli_rotations) - 1
        for obs in self.observables:
            root = self._create_tree_node(obs, pauli_rotation_idx)
            tree_roots.append(root)
        return tree_roots

    def _encode_observables(
        self, tape_obs: List[Operation]
    ) -> List[FourierTree.PauliOperator]:
        """
        Encodes observables from tape as FourierTree.PauliOperator
        utility objects.

        Args:
            tape_obs (List[Operation]): Observable operations

        Returns:
            List[FourierTree.PauliOperator]: List of Pauli operators
        """
        observables = []
        for obs in tape_obs:
            observables.append(
                FourierTree.PauliOperator(obs, self.model.n_qubits, is_observable=True)
            )
        return observables

    def _get_tree_leafs(self) -> List[List[TreeLeaf]]:
        """
        Obtain all Leaf Nodes with its sine- and cosine terms.

        Returns:
            List[List[TreeLeaf]]: For each observable (root), the list of leaf
                nodes.
        """
        leafs = []
        for root in self.tree_roots:
            sin_list = jnp.zeros(len(self.parameters), dtype=jnp.int32)
            cos_list = jnp.zeros(len(self.parameters), dtype=jnp.int32)
            leafs.append(root.get_leafs(sin_list, cos_list, []))
        return leafs

    def get_spectrum(
        self, force_mean: bool = False
    ) -> Tuple[List[jnp.ndarray], List[jnp.ndarray]]:
        """
        Computes the Fourier spectrum for the tree, consisting of the
        frequencies and its corresponding coefficients.
        If the flag force_mean is set, the mean coefficient over all
        observables (roots) is computed.

        Args:
            force_mean (bool, optional): Whether to average over multiple
                observables. Defaults to False.

        Returns:
            Tuple[List[jnp.ndarray], List[jnp.ndarray]]:
                - List of frequencies, one list for each observable (root).
                - List of corresponding coefficients, one list for each
                  observable (root).
        """
        parameter_indices = [
            i for i in range(len(self.parameters)) if i not in self.all_input_indices
        ]

        coeffs = []
        for leafs in self.leafs:
            freq_terms = defaultdict(np.complex128)
            for input_idx in self.input_indices:
                for leaf in leafs:
                    leaf_factor, s, c = self._compute_leaf_factors(
                        leaf, parameter_indices, input_idx
                    )

                    # Expand sin^s * cos^c into exponentials: each (a, b)
                    # combination contributes to frequency 2a + 2b - s - c.
                    for a in range(s + 1):
                        for b in range(c + 1):
                            comb = math.comb(s, a) * math.comb(c, b) * (-1) ** (s - a)
                            freq_terms[2 * a + 2 * b - s - c] += comb * leaf_factor

            coeffs.append(freq_terms)

        frequencies, coefficients = self._freq_terms_to_coeffs(coeffs, force_mean)
        return coefficients, frequencies

    def _freq_terms_to_coeffs(
        self, coeffs: List[Dict[int, jnp.ndarray]], force_mean: bool
    ) -> Tuple[List[jnp.ndarray], List[jnp.ndarray]]:
        """
        Given a list of dictionaries of the form:
        [
            {
                freq_obs1_1: coeff1,
                freq_obs1_2: coeff2,
                ...
            },
            {
                freq_obs2_1: coeff3,
                freq_obs2_2: coeff4,
                ...
            }
            ...
        ],
        Compute two separate lists of frequencies and coefficients.
        such that:
        freqs: [
            [freq_obs1_1, freq_obs1_1, ...],
            [freq_obs2_1, freq_obs2_1, ...],
            ...
        ]
        coeffs: [
            [coeff1, coeff2, ...],
            [coeff3, coeff4, ...],
            ...
        ]

        If force_mean is set, the length of the resulting frequency and
        coefficient list is 1.

        Args:
            coeffs (List[Dict[int, jnp.ndarray]]): Frequency->Coefficients
                dictionary list, one dict for each observable (root).
            force_mean (bool): Whether to average coefficients over multiple
                observables.

        Returns:
            Tuple[List[jnp.ndarray], List[jnp.ndarray]]:
                - List of frequencies, one list for each observable (root).
                - List of corresponding coefficients, one list for each
                  observable (root).
        """
        frequencies = []
        coefficients = []
        if force_mean:
            # Union of all frequencies; missing entries count as 0.0.
            all_freqs = sorted(set([f for c in coeffs for f in c.keys()]))
            coefficients.append(
                jnp.array(
                    [
                        jnp.mean(jnp.array([c.get(f, 0.0) for c in coeffs]))
                        for f in all_freqs
                    ]
                )
            )
            frequencies.append(jnp.array(all_freqs))
        else:
            for freq_terms in coeffs:
                freq_terms = dict(sorted(freq_terms.items()))
                frequencies.append(jnp.array(list(freq_terms.keys())))
                coefficients.append(jnp.array(list(freq_terms.values())))
        return frequencies, coefficients

    def _compute_leaf_factors(
        self,
        leaf: TreeLeaf,
        parameter_indices: List[int],
        input_idx: int,
    ) -> Tuple[float, int, int]:
        """
        Computes the constant coefficient factor for each leaf.
        Additionally sine and cosine contributions of the input parameters for
        this leaf are returned, which are required to obtain the corresponding
        frequencies.

        Args:
            leaf (TreeLeaf): The leaf for which to compute the factor.
            parameter_indices (List[int]): Variational parameter indices.
            input_idx (int): Index into the input-parameter groups.

        Returns:
            Tuple[float, int, int]:
                - float: the constant factor for the leaf
                - int: number of sine contributions of the input
                - int: number of cosine contributions of the input
        """
        leaf_factor = 1.0
        for i in parameter_indices:
            interm_factor = (
                jnp.cos(self.parameters[i]) ** leaf.cos_indices[i]
                * (1j * jnp.sin(self.parameters[i])) ** leaf.sin_indices[i]
            )
            leaf_factor = leaf_factor * interm_factor

        # Get number of sine and cosine factors to which the input contributes
        c = jnp.sum(
            jnp.array([leaf.cos_indices[k] for k in self.input_indices[input_idx]])
        )
        s = jnp.sum(
            jnp.array([leaf.sin_indices[k] for k in self.input_indices[input_idx]])
        )

        leaf_factor = leaf.term * leaf_factor * 0.5 ** (s + c)

        return leaf_factor, int(s), int(c)

    def _early_stopping_possible(
        self, pauli_rotation_idx: int, observable: FourierTree.PauliOperator
    ):
        """
        Checks if a node for an observable can be discarded as all expectation
        values that can result through further branching are zero.
        The method is mentioned in the paper by Nemkov et al.: If the one-hot
        encoded indices for X/Y operations in the Pauli-rotation generators are
        a basis for that of the observable, the node must be processed further.
        If not, it can be discarded.

        Args:
            pauli_rotation_idx (int): Index of remaining Pauli rotation gates.
                Gates itself are attributes of the class.
            observable (FourierTree.PauliOperator): Current observable
        """
        xy_indices_obs = jnp.logical_or(
            observable.xy_indices, self.pauli_rotations[pauli_rotation_idx].xy_indices
        ).all()

        return not xy_indices_obs

    def _create_tree_node(
        self,
        observable: FourierTree.PauliOperator,
        pauli_rotation_idx: int,
        parameter_idx: Optional[int] = None,
        is_sine: bool = False,
        is_cosine: bool = False,
    ) -> Optional[CoefficientsTreeNode]:
        """
        Builds the Fourier-Tree according to the algorithm by Nemkov et al.

        Args:
            observable (FourierTree.PauliOperator): Current observable
            pauli_rotation_idx (int): Index of remaining Pauli rotation gates.
                Gates itself are attributes of the class.
            parameter_idx (Optional[int]): Index of the current parameter.
                Parameters itself are attributes of the class.
            is_sine (bool): If the current node is a sine (left) node.
            is_cosine (bool): If the current node is a cosine (right) node.

        Returns:
            Optional[CoefficientsTreeNode]: The resulting node. Children are set
                recursively. The top level receives the tree root.
        """
        if self._early_stopping_possible(pauli_rotation_idx, observable):
            return None

        # remove commuting paulis
        while pauli_rotation_idx >= 0:
            last_pauli = self.pauli_rotations[pauli_rotation_idx]
            if not observable.is_commuting(last_pauli.list_repr):
                break
            pauli_rotation_idx -= 1
        else:  # leaf
            return FourierTree.CoefficientsTreeNode(
                parameter_idx, observable, is_sine, is_cosine
            )

        last_pauli = self.pauli_rotations[pauli_rotation_idx]

        left = self._create_tree_node(
            observable,
            pauli_rotation_idx - 1,
            pauli_rotation_idx,
            is_cosine=True,
        )

        next_observable = self._create_new_observable(last_pauli.list_repr, observable)
        right = self._create_tree_node(
            next_observable,
            pauli_rotation_idx - 1,
            pauli_rotation_idx,
            is_sine=True,
        )

        return FourierTree.CoefficientsTreeNode(
            parameter_idx,
            observable,
            is_sine,
            is_cosine,
            left,
            right,
        )

    def _create_new_observable(
        self, pauli: jnp.ndarray[int], observable: FourierTree.PauliOperator
    ) -> FourierTree.PauliOperator:
        """
        Utility function to obtain the new observable for a tree node, if the
        last Pauli and the observable do not commute.

        Args:
            pauli (jnp.ndarray[int]): The int array representation of the last
                Pauli rotation in the operation sequence.
            observable (FourierTree.PauliOperator): The current observable.

        Returns:
            FourierTree.PauliOperator: The new observable.
        """
        observable = observable.tensor(pauli)
        return observable
1003class FCC:
    @classmethod
    def get_fcc(
        cls,
        model: Model,
        n_samples: int,
        random_key: Optional[random.PRNGKey] = None,
        method: Optional[str] = "pearson",
        scale: Optional[bool] = False,
        weight: Optional[bool] = False,
        trim_redundant: Optional[bool] = True,
        **kwargs,
    ) -> float:
        """
        Shortcut method to get just the FCC.
        This includes
        1. What is done in `get_fourier_fingerprint`:
            1. Calculating the coefficients (using `n_samples`)
            2. Correlating the result from 1) using `method`
            3. Weighting the correlation matrix (if `weight` is True)
            4. Remove redundancies
        2. What is done in `calculate_fcc`:
            1. Absolute of the fingerprint
            2. Average

        Args:
            model (Model): The QFM model
            n_samples (int): Number of samples to calculate average of coefficients
            random_key (Optional[random.PRNGKey]): JAX random key for parameter
                initialization. If None, uses the model's internal random key.
            method (Optional[str], optional): Correlation method. Defaults to "pearson".
            scale (Optional[bool], optional): Whether to scale the number of samples.
                Defaults to False.
            weight (Optional[bool], optional): Whether to weight the correlation matrix.
                Defaults to False.
            trim_redundant (Optional[bool], optional): Whether to remove redundant
                correlations. Defaults to True.
            **kwargs (Any): Additional keyword arguments for the model function.

        Returns:
            float: The FCC
        """
        # Delegate fingerprint computation (coefficients -> correlation ->
        # optional weighting/trimming), then reduce it to the scalar FCC.
        fourier_fingerprint, _ = cls.get_fourier_fingerprint(
            model,
            n_samples,
            random_key,
            method,
            scale,
            weight,
            trim_redundant=trim_redundant,
            **kwargs,
        )

        return cls.calculate_fcc(fourier_fingerprint)
1059 @classmethod
1060 def get_fourier_fingerprint(
1061 cls,
1062 model: Model,
1063 n_samples: int,
1064 random_key: Optional[random.PRNGKey] = None,
1065 method: Optional[str] = "pearson",
1066 scale: Optional[bool] = False,
1067 weight: Optional[bool] = False,
1068 trim_redundant: Optional[bool] = True,
1069 nan_to_one: Optional[bool] = False,
1070 **kwargs: Any,
1071 ) -> Tuple[jnp.ndarray, jnp.ndarray]:
1072 """
1073 Shortcut method to get just the fourier fingerprint.
1074 This includes
1075 1. Calculating the coefficients (using `n_samples`)
1076 2. Correlating the result from 1) using `method`
1077 3. Weighting the correlation matrix (if `weight` is True)
1078 4. Remove redundancies (if `trim_redundant` is True)
1080 Args:
1081 model (Model): The QFM model
1082 n_samples (int): Number of samples to calculate average of coefficients
1083 random_key (Optional[random.PRNGKey]): JAX random key for parameter
1084 initialization. If None, uses the model's internal random key.
1085 method (Optional[str], optional): Correlation method. Defaults to "pearson".
1086 scale (Optional[bool], optional): Whether to scale the number of samples.
1087 Defaults to False.
1088 weight (Optional[bool], optional): Whether to weight the correlation matrix.
1089 Defaults to False.
1090 trim_redundant (Optional[bool], optional): Whether to remove redundant
1091 correlations. Defaults to True.
1092 nan_to_one (Optional[bool], optional): Whether to set nan to 1.
1093 Defaults to False.
1094 **kwargs: Additional keyword arguments for the model function.
1096 Returns:
1097 Tuple[jnp.ndarray, jnp.ndarray]: The fourier fingerprint
1098 and the frequency indices
1099 """
1100 _, coeffs, freqs = cls._calculate_coefficients(
1101 model, n_samples, random_key, scale, **kwargs
1102 )
1104 fourier_fingerprint = cls._correlate(coeffs.transpose(), method=method)
1106 if nan_to_one:
1107 # set nan to 1
1108 fourier_fingerprint[jnp.isnan(fourier_fingerprint)] = 1.0
1110 # perform weighting if requested
1111 fourier_fingerprint = (
1112 cls._weighting(fourier_fingerprint) if weight else fourier_fingerprint
1113 )
1115 if trim_redundant:
1116 mask = cls._calculate_mask(freqs)
1118 # apply the mask on the fingerprint
1119 fourier_fingerprint = mask * fourier_fingerprint
1121 row_mask = jnp.any(jnp.isfinite(fourier_fingerprint), axis=1)
1122 col_mask = jnp.any(jnp.isfinite(fourier_fingerprint), axis=0)
1124 fourier_fingerprint = fourier_fingerprint[row_mask][:, col_mask]
1126 return fourier_fingerprint, freqs
1128 @classmethod
1129 def calculate_fcc(
1130 cls,
1131 fourier_fingerprint: jnp.ndarray,
1132 ) -> float:
1133 """
1134 Method to calculate the FCC based on an existing correlation matrix.
1135 Calculate absolute and then the average over this matrix.
1136 The Fingerprint can be obtained via `get_fourier_fingerprint`
1138 Args:
1139 fourier_fingerprint (jnp.ndarray): Correlation matrix of coefficients
1140 Returns:
1141 float: The FCC
1142 """
1143 # apply the mask on the fingerprint
1144 return jnp.nanmean(jnp.abs(fourier_fingerprint))
1146 @classmethod
1147 def _calculate_mask(cls, freqs: jnp.ndarray) -> jnp.ndarray:
1148 """
1149 Method to calculate a mask filtering out redundant elements
1150 of the Fourier correlation matrix, based on the provided frequency vector.
1151 It does so by 'simulating' the operations that would be performed
1152 by `_correlate`.
1154 Args:
1155 freqs (jnp.ndarray): Array of frequencies
1157 Returns:
1158 jnp.ndarray: The mask
1159 """
1160 # TODO: this part can be heavily optimized, by e.g. using a "positive_only"
1161 # flag when calculating the coefficients.
1162 # However this would change the numerical values
1163 # (while the order should be still the same).
1165 # disregard all the negativ frequencies
1166 freqs[freqs < 0] = jnp.nan
1167 # compute the outer product of the frequency vectors for arbitrary dimensions
1168 # or just use the existing frequency vector if it is 1D
1169 nd_freqs = (
1170 reduce(jnp.multiply, jnp.ix_(*freqs)) if len(freqs.shape) > 1 else freqs
1171 )
1172 # TODO: could prevent this if we're not using .squeeze()..
1174 # "simulate" what would happen on correlating the coefficients
1175 corr_freqs = jnp.outer(nd_freqs, nd_freqs)
1176 # mask all frequencies that are nan now
1177 # (i.e. all correlations with a negative frequency component)
1178 corr_mask = jnp.where(jnp.isnan(corr_freqs), corr_freqs, 1)
1179 # from this, disregard all the other redundant correlations (i.e. c_0_1 = c_1_0)
1180 corr_mask = corr_mask.at[jnp.triu_indices(corr_mask.shape[0], 0)].set(jnp.nan)
1182 return corr_mask
1184 @classmethod
1185 def _calculate_coefficients(
1186 cls,
1187 model: Model,
1188 n_samples: int,
1189 random_key: Optional[random.PRNGKey] = None,
1190 scale: bool = False,
1191 **kwargs: Any,
1192 ) -> Tuple[jnp.ndarray, jnp.ndarray]:
1193 """
1194 Calculates the Fourier coefficients of a given model
1195 using `n_samples`.
1196 Optionally, `noise_params` can be passed to perform noisy simulation.
1198 Args:
1199 model (Model): The QFM model
1200 n_samples (int): Number of samples to calculate average of coefficients
1201 random_key (Optional[random.PRNGKey]): JAX random key for parameter
1202 initialization. If None, uses the model's internal random key.
1203 scale (bool, optional): Whether to scale the number of samples.
1204 Defaults to False.
1205 **kwargs: Additional keyword arguments for the model function.
1207 Returns:
1208 Tuple[jnp.ndarray, jnp.ndarray]: Parameters and Coefficients of size NxK
1209 """
1210 if n_samples > 0:
1211 if scale:
1212 total_samples = int(
1213 jnp.power(2, model.n_qubits) * n_samples * model.n_input_feat
1214 )
1215 log.info(f"Using {total_samples} samples.")
1216 else:
1217 total_samples = n_samples
1218 model.initialize_params(random_key, repeat=total_samples)
1219 else:
1220 total_samples = 1
1222 coeffs, freqs = Coefficients.get_spectrum(
1223 model, shift=True, trim=True, **kwargs
1224 )
1226 return model.params, coeffs, freqs
1228 @classmethod
1229 def _correlate(cls, mat: jnp.ndarray, method: str = "pearson") -> jnp.ndarray:
1230 """
1231 Correlates two arrays using `method`.
1232 Currently, `pearson` and `spearman` are supported.
1234 Args:
1235 mat (jnp.ndarray): Array of shape (N, K)
1236 method (str, optional): Correlation method. Defaults to "pearson".
1238 Raises:
1239 ValueError: If the method is not supported.
1241 Returns:
1242 jnp.ndarray: Correlation matrix of `a` and `b`.
1243 """
1244 assert len(mat.shape) >= 2, "Input matrix must have at least 2 dimensions"
1246 # Note that for the general n-D case, we have to flatten along
1247 # the first axis (last one is batch).
1248 # Note that the order here is important so we can easily filter out
1249 # negative coefficients later.
1250 # Consider the following example: [[1,2,3],[4,5,6],[7,8,9]]
1251 # we want to get [1, 4, 7, 2, 5, 8, 3, 6, 9]
1252 # such that after correlation, all positive indexed coefficients
1253 # will be in the bottom right quadrant
1254 if method == "pearson":
1255 result = cls._pearson(mat.reshape(mat.shape[0], -1))
1256 # result = cls._pearson(mat.reshape(mat.shape[-1], -1, order="F"))
1257 elif method == "spearman":
1258 result = cls._spearman(mat.reshape(mat.shape[0], -1))
1259 # result = cls._spearman(mat.reshape(mat.shape[-1], -1, order="F"))
1260 else:
1261 raise ValueError(
1262 f"Unknown correlation method: {method}. \
1263 Must be 'pearson' or 'spearman'."
1264 )
1266 return result
1268 @classmethod
1269 def _pearson(
1270 cls, mat: jnp.ndarray, cov: Optional[bool] = False, minp: Optional[int] = 1
1271 ) -> jnp.ndarray:
1272 """
1273 Based on Pandas correlation method as implemented here:
1274 https://github.com/pandas-dev/pandas/blob/main/pandas/_libs/algos.pyx
1276 Compute Pearson correlation between columns of `mat`,
1277 permitting missing values (NaN or ±Inf).
1279 If the input is complex, real and imaginary parts are stacked along
1280 the sample axis so that both components contribute to the correlation
1281 without discarding information.
1283 Args:
1284 mat : array_like, shape (N, K)
1285 Input data.
1286 cov : bool, optional
1287 If True, return the sample covariance matrix instead of
1288 correlation. Defaults to False.
1289 minp : int, optional
1290 Minimum number of paired observations required to form a correlation.
1291 If the number of valid pairs for (i, j) is < minp, the result is NaN.
1293 Returns:
1294 corr : ndarray, shape (K, K)
1295 Pearson correlation matrix.
1296 """
1297 # Preserve complex information by splitting into real / imag samples
1298 if jnp.iscomplexobj(mat):
1299 mat = jnp.concatenate([mat.real, mat.imag], axis=0)
1301 mat = jnp.asarray(mat)
1303 # pre-compute finite mask (N, K)
1304 mask = jnp.isfinite(mat)
1305 fmask = mask.astype(mat.dtype)
1307 # Replace non-finite entries with 0 so arithmetic is safe;
1308 # the mask keeps track of validity.
1309 safe = jnp.where(mask, mat, 0.0)
1311 # Pairwise valid-observation counts (K, K)
1312 nobs = fmask.T @ fmask
1314 # Pairwise sums (only over mutually valid rows)
1315 # For columns i, j the "valid" rows are mask[:,i] & mask[:,j].
1316 # sum_x[i,j] = sum of mat[:,i] where both i and j are valid.
1317 # Using: safe[:,i] * mask[:,j] zeroes out rows invalid for j.
1318 # Then summing over N gives sum_x[i,j].
1319 # safe.T @ fmask gives (K, K) where entry (i,j) = sum of safe[:,i]*mask[:,j]
1320 sum_x = safe.T @ fmask # (K, K) – row-var sums
1321 sum_y = fmask.T @ safe # (K, K) – col-var sums
1323 # Note: explicit means (sum_x/nobs, sum_y/nobs) are not needed as
1324 # separate variables — the computational formula used below
1325 # (e.g. ssx = Σx² − (Σx)²/n) implicitly handles mean-centering.
1327 # Cross products, sum-of-squares via computational formula:
1328 # ssx = Σx² − (Σx)²/n, ssy = Σy² − (Σy)²/n,
1329 # sxy = Σxy − (Σx)(Σy)/n
1330 # All sums are taken over the pairwise-valid subset for each (i,j).
1331 masked = safe * fmask # same as safe but explicit
1332 sum_xy = masked.T @ masked # (K, K)
1334 # ssx[i,j] = sum_xx_ij - nobs * mean_x^2 (but sum_xx_ij uses pair mask)
1335 # We need sum of x^2 over the *pair* mask, not just column mask.
1336 # sum_x2[i,j] = sum_n safe[n,i]^2 * mask[n,i] * mask[n,j]
1337 safe_sq = safe**2
1338 sum_x2 = safe_sq.T @ fmask # (K, K)
1339 sum_y2 = fmask.T @ safe_sq # (K, K)
1341 ssx = sum_x2 - sum_x**2 / jnp.where(nobs > 0, nobs, 1.0)
1342 ssy = sum_y2 - sum_y**2 / jnp.where(nobs > 0, nobs, 1.0)
1343 sxy = sum_xy - (sum_x * sum_y) / jnp.where(nobs > 0, nobs, 1.0)
1345 if cov:
1346 denom = jnp.where(nobs > 1, nobs - 1, jnp.nan)
1347 result = sxy / denom
1348 else:
1349 denom = jnp.sqrt(ssx * ssy)
1350 result = jnp.where(denom > 0, sxy / denom, jnp.nan)
1351 # clip numerical drift to [-1, 1]
1352 result = jnp.clip(result, -1.0, 1.0)
1354 # Enforce minp: set entries with too few observations to NaN
1355 result = jnp.where(nobs < minp, jnp.nan, result)
1357 return result
1359 @classmethod
1360 def _spearman(cls, mat: jnp.ndarray, minp: Optional[int] = 1) -> jnp.ndarray:
1361 """
1362 Based on Pandas correlation method as implemented here:
1363 https://github.com/pandas-dev/pandas/blob/main/pandas/_libs/algos.pyx
1365 Compute Spearman correlation between columns of `mat`,
1366 permitting missing values (NaN or ±Inf).
1368 If the input is complex, real and imaginary parts are stacked along
1369 the sample axis so that both components contribute to the correlation
1370 without discarding information.
1372 Args:
1373 mat : array_like, shape (N, K)
1374 Input data.
1375 minp : int, optional
1376 Minimum number of paired observations required to form a correlation.
1377 If the number of valid pairs for (i, j) is < minp, the result is NaN.
1379 Returns:
1380 corr : ndarray, shape (K, K)
1381 Spearman correlation matrix.
1382 """
1383 # Preserve complex information by splitting into real / imag samples
1384 if jnp.iscomplexobj(mat):
1385 mat = jnp.concatenate([mat.real, mat.imag], axis=0)
1387 mat = jnp.asarray(mat)
1388 N, K = mat.shape
1390 # trivial all-NaN answer if too few rows
1391 if N < minp:
1392 return jnp.full((K, K), jnp.nan)
1394 # mask of finite entries
1395 mask = jnp.isfinite(mat) # shape (N, K), dtype=bool
1397 # precompute ranks column-wise ignoring NaNs
1398 ranks = np.full((N, K), np.nan)
1399 for j in range(K):
1400 valid = mask[:, j]
1401 if valid.any():
1402 ranks[valid, j] = rankdata(mat[valid, j], method="average")
1404 ranks = jnp.asarray(ranks)
1406 # Vectorised Pearson on the ranks
1407 # Replace NaN ranks with 0; use mask to track validity.
1408 rank_mask = jnp.isfinite(ranks)
1409 safe_ranks = jnp.where(rank_mask, ranks, 0.0)
1411 # Pairwise valid-observation counts (K, K)
1412 fmask = rank_mask.astype(ranks.dtype)
1413 nobs = fmask.T @ fmask
1415 # Pairwise sums over mutually-valid rows
1416 sum_x = safe_ranks.T @ fmask # (K, K)
1417 sum_y = fmask.T @ safe_ranks # (K, K)
1419 # Pairwise products
1420 masked_ranks = safe_ranks * fmask # same as safe_ranks
1421 sum_xy = masked_ranks.T @ masked_ranks # (K, K)
1423 safe_sq = safe_ranks**2
1424 sum_x2 = safe_sq.T @ fmask # (K, K)
1425 sum_y2 = fmask.T @ safe_sq # (K, K)
1427 nobs_safe = jnp.where(nobs > 0, nobs, 1.0)
1428 ssx = sum_x2 - sum_x**2 / nobs_safe
1429 ssy = sum_y2 - sum_y**2 / nobs_safe
1430 sxy = sum_xy - (sum_x * sum_y) / nobs_safe
1432 denom = jnp.sqrt(ssx * ssy)
1433 result = jnp.where(denom > 0, sxy / denom, jnp.nan)
1434 result = jnp.clip(result, -1.0, 1.0)
1436 # Enforce minp
1437 result = jnp.where(nobs < minp, jnp.nan, result)
1439 return result
1441 @classmethod
1442 def _weighting(cls, fourier_fingerprint: jnp.ndarray) -> jnp.ndarray:
1443 """
1444 Performs weighting on the given correlation matrix.
1445 Here, low-frequent coefficients are weighted more heavily.
1447 Args:
1448 correlation (jnp.ndarray): Correlation matrix
1449 """
1450 # TODO: in Future iterations, this can be optimized by computing
1451 # on the trimmed matrix instead.
1453 assert (
1454 fourier_fingerprint.shape[0] % 2 != 0
1455 and fourier_fingerprint.shape[1] % 2 != 0
1456 ), (
1457 "Correlation matrix must have odd dimensions. \
1458 Hint: use `trim` argument when calling `get_spectrum`."
1459 )
1460 assert fourier_fingerprint.shape[0] == fourier_fingerprint.shape[1], (
1461 "Correlation matrix must be square."
1462 )
1464 def quadrant_to_matrix(a: jnp.ndarray) -> jnp.ndarray:
1465 """
1466 Transforms [[1,2],[3,4]] to
1467 [[1,2,1],[3,4,3],[1,2,1]]
1469 Args:
1470 a (jnp.ndarray): _description_
1472 Returns:
1473 jnp.ndarray: _description_
1474 """
1475 # rotates a from [[1,2],[3,4]] to [[3,4],[1,2]]
1476 a_rot = jnp.rot90(a)
1477 # merge the two matrices
1478 left = jnp.concat([a, a_rot])
1479 # merges left and right (left flipped)
1480 b = jnp.concat(
1481 [left, jnp.flip(left)],
1482 axis=1,
1483 )
1484 # remove the middle column and row
1485 return jnp.delete(
1486 jnp.delete(b, (b.shape[0] // 2), axis=0), (b.shape[1] // 2), axis=1
1487 )
1489 nc = fourier_fingerprint.shape[0] // 2 + 1
1490 weights = jnp.mgrid[0:nc:1, 0:nc:1].sum(axis=0) / ((nc - 1) * 2)
1491 weights_matrix = quadrant_to_matrix(weights)
1493 return fourier_fingerprint * weights_matrix
1494 raise NotImplementedError("Weighting method is not implemented")
class Datasets:
    @classmethod
    def generate_fourier_series(
        cls,
        random_key: random.PRNGKey,
        model: Model,
        coefficients_min: float = 0.0,
        coefficients_max: float = 1.0,
        zero_centered: bool = False,
    ) -> List[jnp.ndarray]:
        """
        Generates the Fourier series representation of a function.
        It uses the `model.frequencies` property to retrieve the frequency
        information. This ensures that the resulting Fourier series is
        compatible with the model.

        This function is capable of generating $D$-dimensional Fourier series
        (again defined by `model.n_input_feat`).
        The highest frequency $N$ is retrieved per dimension.

        Samples of the Fourier coefficients are drawn from a uniform circle.

        Args:
            random_key (random.PRNGKey): Random number key for JAX.
            model (Model): The quantum circuit model.
            coefficients_min (float, optional): Minimum value for the coefficients.
                Defaults to 0.0.
            coefficients_max (float, optional): Maximum value for the coefficients.
                Defaults to 1.0.
            zero_centered (bool, optional): Whether to zero-center the coefficients.
                Defaults to False.

        Returns:
            List[jnp.ndarray]: Three arrays:
                - Input domain samples with shape ((N,)*D, D)
                - Fourier series values with shape ((N,)*D)
                - Fourier coefficients with shape ((N,)*D)
        """
        # TODO: the following code can be considered to
        # capturing a truly random spectrum.
        # add some constraints on the spectrum, i.e. not fully

        # Note: one key observation for understanding the following code is,
        # that instead of wrapping your head around symmetries in multi-
        # dimensional coefficient matrices, one can simply look at the flattened
        # version of such a matrix and reshape later. It just works out.

        # going from [0, 2pi] with the resolution required for highest frequency
        # permute with input dimensionality to get an n-d grid of domain samples
        # the output shape comes from the fact that want to create a "coordinate system"
        domain_samples_per_input_dim = jnp.stack(
            jnp.meshgrid(
                *[jnp.arange(0, 2 * jnp.pi, 2 * jnp.pi / d) for d in model.degree]
            )
        ).T.reshape(-1, model.n_input_feat)

        # generate the frequency indices for each dimension.
        # this will have the same shape as the domain samples
        frequencies = jnp.stack(jnp.meshgrid(*model.frequencies)).T.reshape(
            -1, model.n_input_feat
        )

        # using the frequency information, sample coefficients for each dimension
        # shape: (input_dims, n_freqs_per_input_dim // 2 + 1)
        coefficients = cls.uniform_circle(
            random_key,
            low=coefficients_min,
            high=coefficients_max,
            size=math.prod(model.degree) // 2 + 1,
        )

        # zero center (first coeff = 0)
        # we can assume the first coeff is the offset, because we're dealing
        # with a non-symmetric spectrum here
        if zero_centered:
            coefficients = coefficients.at[0].set(0.0)
        else:
            # the DC offset must be real for a real-valued series
            coefficients = coefficients.at[0].set(coefficients[0].real)

        # ensure symmetry (here, non_negative_ is removed!),
        # giving us the full coefficients vector
        coefficients = jnp.concat(
            [
                jnp.flip(coefficients[..., 1:]).conjugate(),
                coefficients,
            ],
            axis=-1,
        )

        # Vectorized version of $f(x) = \sum_{n=0}^{N-1} c_n * e^{i * \omega_n * x}$
        # it takes into account the input dimension, i.e. the output is a matrix
        # normalization uses the n_freqs component of the coefficients
        values = jnp.real(
            (
                jnp.exp(1j * (domain_samples_per_input_dim @ frequencies.T))
                * coefficients
            ).sum(axis=1)
            / coefficients.size
        )

        # return all the information we have
        return [
            domain_samples_per_input_dim.reshape(*model.degree, -1),
            values.reshape(model.degree),
            coefficients.reshape(model.degree),
        ]

    @classmethod
    def uniform_circle(
        cls,
        random_key: random.PRNGKey,
        size: Union[jnp.ndarray, List, int],
        low: float = 0.0,
        high: float = 1.0,
    ) -> jnp.ndarray:
        """
        Random number generator for complex numbers sampled inside the unit circle

        Args:
            random_key (random.PRNGKey): Random number key for JAX.
            size (Union[jnp.ndarray, List, int]): Number of samples. If a 2D array
                is passed, the first dimension will be the number of dimensions.
            low (float, optional): Minimum Radius. Defaults to 0.0.
            high (float, optional): Maximum Radius. Defaults to 1.0.

        Returns:
            jnp.ndarray: Array of complex numbers with shape of `size`
        """
        if isinstance(size, int):
            size = jnp.array([size])

        # independent keys for radius and phase
        random_key, random_key1 = random.split(random_key)
        # sqrt of the uniform radius sample yields points spread over the disc;
        # the phase is uniform on [0, 2pi)
        return jnp.sqrt(
            random.uniform(random_key, size, minval=low, maxval=high)
        ) * jnp.exp(2j * jnp.pi * random.uniform(random_key1, size))