Coverage for qml_essentials / coefficients.py: 95%
520 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-06-10 08:17 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-06-10 08:17 +0000
1from __future__ import annotations
2import math
3from collections import defaultdict
4from dataclasses import dataclass
5import jax.numpy as jnp
6from jax import random
7import numpy as np
8from scipy.stats import rankdata
9from functools import reduce
10from typing import List, Tuple, Optional, Any, Dict, Union
12from qml_essentials.model import Model
13from qml_essentials.utils import PauliCircuit
14from qml_essentials.operations import (
15 Operation,
16 PauliX,
17 PauliY,
18 PauliZ,
19)
21import logging
# Module-level logger, named after this module so handlers can target it.
log = logging.getLogger(name=__name__)
class Coefficients:
    """Numerical Fourier analysis of models via (multi-dimensional) FFTs."""

    @classmethod
    def get_spectrum(
        cls,
        model: Model,
        mfs: int = 1,
        mts: int = 1,
        shift=False,
        trim=False,
        numerical_cap: Optional[float] = -1,
        **kwargs,
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """
        Extracts the coefficients of a given model using a FFT (jnp-fft).

        Note that the coefficients are complex numbers, but the imaginary part
        of the coefficients should be very close to zero, since the expectation
        values of the Pauli operators are real numbers.

        It can perform oversampling in both the frequency and time domain
        using the `mfs` and `mts` arguments.

        Args:
            model (Model): The model to sample.
            mfs (int): Multiplicator for the highest frequency. Default is 1.
            mts (int): Multiplicator for the number of time samples. Default is 1.
            shift (bool): Whether to apply fftshift to coefficients and to each
                frequency vector. Default is False.
            trim (bool): Whether to remove the Nyquist frequency along every
                input axis whose length is even. Default is False.
            numerical_cap (Optional[float]): Numerical cap for the coefficients.
                If positive, coefficients with magnitude below the cap are
                zeroed and, for a single input feature, frequencies that
                vanish entirely are removed from both `coeffs` and `freqs`.
                `None` or a non-positive value disables capping.
            kwargs (Any): Additional keyword arguments for the model function.

        Returns:
            Tuple[jnp.ndarray, jnp.ndarray]: Tuple containing the coefficients
            and frequencies. For a single input feature the frequencies are a
            single array, otherwise a list with one array per input axis.

        Raises:
            ValueError: If the summed coefficients have a non-negligible
                imaginary part.
        """
        kwargs.setdefault("force_mean", True)
        kwargs.setdefault("execution_type", "expval")

        coeffs, freqs = cls._fourier_transform(model, mfs=mfs, mts=mts, **kwargs)
        # Own copy of the per-axis frequency list so it can be updated per axis.
        freqs = list(freqs)

        imag_sum = jnp.sum(coeffs).imag
        if not jnp.isclose(imag_sum, 0.0, rtol=1.0e-5):
            raise ValueError(
                "Spectrum is not real. Imaginary part of coefficients is: "
                f"{imag_sum}"
            )

        if trim:
            for ax in range(model.n_input_feat):
                n_ax = coeffs.shape[ax]
                if n_ax % 2 == 0:
                    # The Nyquist bin sits at index n // 2 of an unshifted FFT
                    # axis; only this axis' frequency vector is affected.
                    coeffs = np.delete(coeffs, n_ax // 2, axis=ax)
                    freqs[ax] = np.delete(freqs[ax], len(freqs[ax]) // 2)

        if shift:
            coeffs = jnp.fft.fftshift(coeffs, axes=list(range(model.n_input_feat)))
            # Shift each 1-D frequency vector on its own; stacking them would
            # fail for features with different degrees.
            freqs = [np.fft.fftshift(freq) for freq in freqs]

        if numerical_cap is not None and numerical_cap > 0:
            # set coeffs below threshold to zero
            coeffs = jnp.where(
                jnp.abs(coeffs) < numerical_cap,
                jnp.zeros_like(coeffs),
                coeffs,
            )

            # Drop frequencies whose coefficients vanish entirely after
            # capping, so the returned spectrum reflects only the surviving
            # frequencies. Well-defined only for a single (1-D) frequency
            # axis; for multi-dim input the rectangular grid is left intact.
            if model.n_input_feat == 1:
                if coeffs.ndim == 1:
                    surviving = coeffs != 0
                else:
                    surviving = jnp.any(
                        coeffs != 0, axis=tuple(range(1, coeffs.ndim))
                    )
                coeffs = coeffs[surviving]
                freqs = [freqs[0][surviving]]

        if len(freqs) == 1:
            freqs = freqs[0]

        return coeffs, freqs

    @classmethod
    def _fourier_transform(
        cls, model: Model, mfs: int, mts: int, **kwargs: Any
    ) -> Tuple[jnp.ndarray, List[jnp.ndarray]]:
        """
        Sample the model on a regular grid and run an n-dimensional FFT.

        Args:
            model (Model): The model to sample.
            mfs (int): Multiplicator for the number of frequencies per axis.
            mts (int): Multiplicator for the sampled period (time samples).
            kwargs (Any): Additional keyword arguments for the model function.

        Returns:
            Tuple[jnp.ndarray, List[jnp.ndarray]]: The FFT coefficients,
            normalised by the number of grid points, and one `fftfreq`
            frequency vector per input feature.
        """
        # Create a frequency vector with as many frequencies as model degrees,
        # oversampled by mfs
        n_freqs: jnp.ndarray = jnp.array(
            [mfs * model.degree[i] for i in range(model.n_input_feat)]
        )

        start, stop, step = 0, 2 * mts * jnp.pi, 2 * jnp.pi / n_freqs
        # Stretch according to the number of frequencies
        inputs: List = [
            jnp.arange(start, stop, step[i]) for i in range(model.n_input_feat)
        ]

        # permute with input dimensionality
        nd_inputs = jnp.array(
            jnp.meshgrid(*[inputs[i] for i in range(model.n_input_feat)])
        ).T.reshape(-1, model.n_input_feat)

        # Output vector is not necessarily the same length as input
        outputs = model(inputs=nd_inputs, **kwargs)
        outputs = outputs.reshape(
            *[inputs[i].shape[0] for i in range(model.n_input_feat)], -1
        ).squeeze()

        coeffs = jnp.fft.fftn(outputs, axes=list(range(model.n_input_feat)))

        # TODO: this could cause issues with multidim input
        # FIXME: account for different frequencies in multidim input scenarios
        freqs = [
            jnp.fft.fftfreq(int(mts * n_freqs[i]), 1 / n_freqs[i])
            for i in range(model.n_input_feat)
        ]

        # normalize the output (using product if multidim)
        return (
            coeffs / math.prod(outputs.shape[0 : model.n_input_feat]),
            freqs,
        )

    @classmethod
    def get_psd(cls, coeffs: jnp.ndarray) -> jnp.ndarray:
        """
        Calculates the power spectral density (PSD) from given Fourier coefficients.

        Args:
            coeffs (jnp.ndarray): The Fourier coefficients.

        Returns:
            jnp.ndarray: The power spectral density, i.e.
            ``2 / len(coeffs)**2 * |coeffs|**2``.
        """
        # TODO: if we apply trim=True in advance, this will be slightly wrong..

        def abs2(x):
            return x.real**2 + x.imag**2

        scale = 2.0 / (len(coeffs) ** 2)
        return scale * abs2(coeffs)

    @classmethod
    def evaluate_Fourier_series(
        cls,
        coefficients: jnp.ndarray,
        frequencies: jnp.ndarray,
        inputs: Union[jnp.ndarray, list, float],
    ) -> float:
        """
        Evaluate the function value of a Fourier series at one point.

        Args:
            coefficients (jnp.ndarray): Coefficients of the Fourier series.
            frequencies (jnp.ndarray): Corresponding frequencies (a list of
                per-axis frequency vectors for multi-dimensional input).
            inputs (Union[jnp.ndarray, list, float]): Point at which to
                evaluate the function. Scalars are treated as a 1-D point.

        Returns:
            float: The (real part of the) function value at the input point.
        """
        if isinstance(frequencies, list):
            if len(coefficients.shape) <= len(frequencies):
                coefficients = coefficients[..., jnp.newaxis]
        else:
            if len(coefficients.shape) == 1:
                coefficients = coefficients[..., jnp.newaxis]

        # Normalise lists, tuples, Python scalars and arrays to a jnp array;
        # a 0-d value becomes a one-element point.
        inputs = jnp.asarray(inputs)
        if len(inputs.shape) < 1:
            inputs = inputs[jnp.newaxis, ...]

        if isinstance(frequencies, list):
            input_dim = len(frequencies)
            frequencies = jnp.stack(jnp.meshgrid(*frequencies))
            if input_dim != len(inputs):
                frequencies = jnp.repeat(
                    frequencies[jnp.newaxis, ...], inputs.shape[0], axis=0
                )
                freq_inputs = jnp.einsum("bi...,b->b...", frequencies, inputs)
                exponents = jnp.exp(1j * freq_inputs).T
                exp = jnp.einsum("jl...k,jl...b->b...k", coefficients, exponents)
            else:
                freq_inputs = jnp.einsum("i...,i->...", frequencies, inputs)
                exponents = jnp.exp(1j * freq_inputs).T
                exp = jnp.einsum("jl...k,jl...->k...", coefficients, exponents)
        else:
            frequencies = jnp.repeat(
                frequencies[jnp.newaxis, ...], inputs.shape[0], axis=0
            )
            freq_inputs = jnp.einsum("i...,i->i...", frequencies, inputs)
            exponents = jnp.exp(1j * freq_inputs)
            exp = jnp.einsum("j...k,ij...->ik...", coefficients, exponents)

        return jnp.squeeze(jnp.real(exp))
class FourierTree:
    """
    Sine-cosine tree representation for the algorithm by Nemkov et al.
    This tree can be used to obtain analytical Fourier coefficients for a given
    Pauli-Clifford circuit.
    """

    class CoefficientsTreeNode:
        """
        Representation of a node in the coefficients tree for the algorithm by
        Nemkov et al.
        """

        def __init__(
            self,
            parameter_idx: Optional[int],
            observable: FourierTree.PauliOperator,
            is_sine_factor: bool,
            is_cosine_factor: bool,
            left: Optional[FourierTree.CoefficientsTreeNode] = None,
            right: Optional[FourierTree.CoefficientsTreeNode] = None,
        ):
            """
            Coefficient tree node initialisation. Each node has information
            about its creation context and its children, i.e.:

            Args:
                parameter_idx (Optional[int]): Index of the corresp. param. index i.
                observable (FourierTree.PauliOperator): The node's observable to
                    obtain the expectation value that contributes to the constant
                    term.
                is_sine_factor (bool): If this node belongs to a sine coefficient.
                is_cosine_factor (bool): If this node belongs to a cosine coefficient.
                left (Optional[CoefficientsTreeNode]): left child (if any).
                right (Optional[CoefficientsTreeNode]): right child (if any).
            """
            self.parameter_idx = parameter_idx

            assert not (is_sine_factor and is_cosine_factor), (
                "Cannot be sine and cosine at the same time"
            )
            self.is_sine_factor = is_sine_factor
            self.is_cosine_factor = is_cosine_factor

            # If the observable does not consist of only Z and I (X/Y are
            # encoded as 0/1), the expectation (and therefore the constant
            # node term) is zero.
            if jnp.logical_or(
                observable.list_repr == 0, observable.list_repr == 1
            ).any():
                self.term = 0.0
            else:
                self.term = observable.phase

            self.left = left
            self.right = right

        def _is_leaf(self) -> bool:
            """True when the node has no children (pruned children are None)."""
            return self.left is None and self.right is None

        def evaluate(self, parameters: list[float]) -> float:
            """
            Recursive function to evaluate the expectation of the coefficient
            tree, starting from the current node.

            Args:
                parameters (list[float]): The parameters, by which the circuit
                    (and therefore the tree) is parametrised.

            Returns:
                float: The expectation for the current node and its children.
            """
            factor = (
                parameters[self.parameter_idx]
                if self.parameter_idx is not None
                else 1.0
            )
            if self.is_sine_factor:
                factor = 1j * jnp.sin(factor)
            elif self.is_cosine_factor:
                factor = jnp.cos(factor)
            if self._is_leaf():
                return factor * self.term

            sum_children = 0.0
            if self.left is not None:
                sum_children = sum_children + self.left.evaluate(parameters)
            if self.right is not None:
                sum_children = sum_children + self.right.evaluate(parameters)

            return factor * sum_children

        def get_leafs(
            self,
            sin_list: List[int],
            cos_list: List[int],
            existing_leafs: Optional[List[FourierTree.TreeLeaf]] = None,
        ) -> List[FourierTree.TreeLeaf]:
            """
            Traverse the tree starting from the current node, to obtain the
            tree leafs only.
            The leafs correspond to the terms in the sine-cosine tree
            representation that eventually are used to obtain coefficients and
            frequencies.
            Sine and cosine counters are passed to the children until a leaf is
            reached (top to bottom); leafs are collected into one list in
            left-to-right order.

            Args:
                sin_list (List[int]): Current number of sine contributions for
                    each parameter (a jnp integer array; updated via ``.at``).
                    Has the same length as the parameters, as each position
                    corresponds to one parameter.
                cos_list (List[int]): Current number of cosine contributions
                    for each parameter, same layout as `sin_list`.
                existing_leafs (Optional[List[TreeLeaf]]): List to append the
                    found leafs to. A fresh list is used when None.

            Returns:
                List[TreeLeaf]: `existing_leafs` extended by the leafs found
                below this node (leafs with a zero term are skipped).
            """
            # A fresh list per call; a shared `[]` default would accumulate
            # leafs across calls.
            if existing_leafs is None:
                existing_leafs = []

            if self.is_sine_factor:
                sin_list = sin_list.at[self.parameter_idx].set(
                    sin_list[self.parameter_idx] + 1
                )
            if self.is_cosine_factor:
                cos_list = cos_list.at[self.parameter_idx].set(
                    cos_list[self.parameter_idx] + 1
                )

            if self._is_leaf():
                if self.term != 0.0:
                    existing_leafs.append(
                        FourierTree.TreeLeaf(sin_list, cos_list, self.term)
                    )
                return existing_leafs

            # `.at[...].set` returns new arrays, so both children can safely
            # start from the same counters.
            for child in (self.left, self.right):
                if child is not None:
                    child.get_leafs(sin_list, cos_list, existing_leafs)
            return existing_leafs

    @dataclass
    class TreeLeaf:
        """
        Coefficient tree leafs according to the algorithm by Nemkov et al.,
        which correspond to the terms in the sine-cosine tree representation
        that eventually are used to obtain coefficients and frequencies.

        Args:
            sin_indices (List[int]): Current number of sine contributions for
                each parameter. Has the same length as the parameters, as each
                position corresponds to one parameter.
            cos_indices (List[int]): Current number of cosine contributions for
                each parameter. Has the same length as the parameters, as each
                position corresponds to one parameter.
            term (jnp.complex): Constant factor of the leaf, depending on the
                expectation value of the observable, and a phase.
        """

        sin_indices: List[int]
        cos_indices: List[int]
        term: complex

    class PauliOperator:
        """
        Utility class for storing Pauli Rotations, the corresponding indices
        in the XY-Space (whether there is a gate with X or Y generator at a
        certain qubit) and the phase.

        Args:
            pauli (Union[Operator, jnp.ndarray[int]]): Pauli Rotation Operation
                or list representation
            n_qubits (int): Number of qubits in the circuit
            prev_xy_indices (Optional[jnp.ndarray[bool]]): X/Y indices of the
                previous Pauli sequence. Defaults to None.
            is_observable (bool): If the operator is an observable. Defaults to
                False.
            is_init (bool): If this Pauli operator is initialised the first
                time. Defaults to True.
            phase (float): Phase of the operator. Defaults to 1.0
        """

        def __init__(
            self,
            pauli: Union[Operation, jnp.ndarray[int]],
            n_qubits: int,
            prev_xy_indices: Optional[jnp.ndarray[bool]] = None,
            is_observable: bool = False,
            is_init: bool = True,
            phase: float = 1.0,
        ):
            self.is_observable = is_observable
            self.phase = phase

            if is_init:
                if not is_observable:
                    pauli = pauli.generator()
                self.list_repr = self._create_list_representation(pauli, n_qubits)
            else:
                assert isinstance(pauli, jnp.ndarray)
                self.list_repr = pauli

            if prev_xy_indices is None:
                prev_xy_indices = jnp.zeros(n_qubits, dtype=bool)
            # Rotations accumulate the X/Y positions of all previous rotations.
            self.xy_indices = jnp.logical_or(
                prev_xy_indices,
                self._compute_xy_indices(self.list_repr, rev=is_observable),
            )

        @staticmethod
        def _compute_xy_indices(
            op: jnp.ndarray[int], rev: bool = False
        ) -> jnp.ndarray[bool]:
            """
            Computes the positions of X or Y gates to an one-hot encoded
            boolean array.

            Args:
                op (jnp.ndarray[int]): Pauli-Operation list representation.
                rev (bool): Whether to negate the array.

            Returns:
                jnp.ndarray[bool]: One hot encoded boolean array.
            """
            xy_indices = (op == 0) + (op == 1)
            if rev:
                xy_indices = ~xy_indices
            return xy_indices

        @staticmethod
        def _create_list_representation(
            op: Operation, n_qubits: int
        ) -> jnp.ndarray[int]:
            """
            Create list representation of an Operation.
            Generally, the list representation is a list of length n_qubits,
            where at each position a Pauli Operator is encoded as such:
            I: -1
            X: 0
            Y: 1
            Z: 2

            Args:
                op (Operation): Gate operation (PauliX, PauliY, PauliZ, or
                    Hermitian wrapping a multi-qubit Pauli tensor product).
                n_qubits (int): number of qubits in the circuit

            Returns:
                jnp.ndarray[int]: List representation
            """
            pauli_repr = -jnp.ones(n_qubits, dtype=int)

            _NAME_TO_IDX = {"PauliX": 0, "PauliY": 1, "PauliZ": 2}

            if op.name in _NAME_TO_IDX:
                pauli_repr = pauli_repr.at[op.wires[0]].set(_NAME_TO_IDX[op.name])
            elif isinstance(op, PauliX):
                pauli_repr = pauli_repr.at[op.wires[0]].set(0)
            elif isinstance(op, PauliY):
                pauli_repr = pauli_repr.at[op.wires[0]].set(1)
            elif isinstance(op, PauliZ):
                pauli_repr = pauli_repr.at[op.wires[0]].set(2)
            else:
                # Multi-qubit case: decompose via pauli_string_from_operation
                from qml_essentials.operations import pauli_string_from_operation

                pauli_str = pauli_string_from_operation(op)
                char_to_idx = {"X": 0, "Y": 1, "Z": 2, "I": -1}
                for wire, ch in zip(op.wires, pauli_str):
                    idx = char_to_idx.get(ch, -1)
                    if idx >= 0:
                        pauli_repr = pauli_repr.at[wire].set(idx)

            return pauli_repr

        def is_commuting(self, pauli: jnp.ndarray[int]) -> bool:
            """
            Computes if this Pauli commutes with another Pauli operator.
            This computation is based on the fact that the commutator is zero
            if and only if the number of anticommuting single-qubit Paulis is
            even.

            Args:
                pauli (jnp.ndarray[int]): List representation of another Pauli

            Returns:
                bool: If the current and other Pauli are commuting.
            """
            # Two single-qubit Paulis anticommute iff both are non-identity
            # and they differ.
            anticommutator = jnp.where(
                pauli < 0,
                False,
                jnp.where(
                    self.list_repr < 0,
                    False,
                    jnp.where(self.list_repr == pauli, False, True),
                ),
            )
            return not (jnp.sum(anticommutator) % 2)

        def tensor(self, pauli: jnp.ndarray[int]) -> FourierTree.PauliOperator:
            """
            Compute the qubit-wise product between the current Pauli and a
            given list representation of another Pauli operator.

            Args:
                pauli (jnp.ndarray[int]): List representation of Pauli

            Returns:
                FourierTree.PauliOperator: New Pauli operator object, which
                contains the product and the accumulated phase.
            """
            diff = (pauli - self.list_repr + 3) % 3
            phase = jnp.where(
                self.list_repr < 0,
                1.0,
                jnp.where(
                    pauli < 0,
                    1.0,
                    jnp.where(
                        diff == 2,
                        1.0j,
                        jnp.where(diff == 1, -1.0j, 1.0),
                    ),
                ),
            )

            obs = jnp.where(
                self.list_repr < 0,
                pauli,
                jnp.where(
                    pauli < 0,
                    self.list_repr,
                    jnp.where(
                        diff == 2,
                        (self.list_repr + 1) % 3,
                        jnp.where(diff == 1, (self.list_repr + 2) % 3, -1),
                    ),
                ),
            )
            phase = self.phase * jnp.prod(phase)
            return FourierTree.PauliOperator(
                obs, phase=phase, n_qubits=obs.size, is_init=False, is_observable=True
            )

    def __init__(self, model: Model):
        """
        Tree initialisation, based on the Pauli-Clifford representation of a model.
        Currently, only one input feature is supported.

        **Usage**:
        ```
        # initialise a model
        model = Model(...)

        # initialise and build FourierTree
        tree = FourierTree(model)

        # get expectation value
        exp = tree()

        # Get spectrum (for each observable, we have one list element)
        coeff_list, freq_list = tree.get_spectrum()
        ```

        Args:
            model (Model): The Model, for which to build the tree
        """
        self.model = model
        self.tree_roots = None

        inputs = self.model._inputs_validation([1.0])

        # Record the circuit tape using jaqsi's tape recording
        raw_tape = self.model.script._record(params=model.params, inputs=inputs)

        # Build observables from the model's output_qubit configuration
        _, obs_list = self.model._build_obs()

        quantum_tape = PauliCircuit.from_parameterised_circuit(
            raw_tape, observables=obs_list
        )

        self.parameters = [jnp.squeeze(p) for p in quantum_tape.get_parameters()]

        self.input_indices, self.all_input_indices = quantum_tape.get_input_indices()

        self.observables = self._encode_observables(quantum_tape.observables)

        pauli_rot = FourierTree.PauliOperator(
            quantum_tape.operations[0],
            self.model.n_qubits,
        )
        self.pauli_rotations = [pauli_rot]
        for op in quantum_tape.operations[1:]:
            pauli_rot = FourierTree.PauliOperator(
                op, self.model.n_qubits, pauli_rot.xy_indices
            )
            self.pauli_rotations.append(pauli_rot)

        self.tree_roots = self.build()
        self.leafs: List[List[FourierTree.TreeLeaf]] = self._get_tree_leafs()

    def __call__(
        self,
        params: Optional[jnp.ndarray] = None,
        inputs: Optional[jnp.ndarray] = None,
        **kwargs,
    ) -> jnp.ndarray:
        """
        Evaluates the Fourier tree via sine-cosine terms sum. This is
        equivalent to computing the expectation value of the observables with
        respect to the corresponding circuit.

        Note: the tree parameters (`self.parameters`) are updated to the
        values recorded for the given `params` / `inputs`.

        Args:
            params (Optional[jnp.ndarray], optional): Parameters of the model.
                Defaults to None (the model's current parameters).
            inputs (Optional[jnp.ndarray], optional): Inputs to the circuit.
                Defaults to None (input 1.0).

        Returns:
            jnp.ndarray: Expectation value of the tree (one value per
            observable, or their mean if `force_mean` is set).

        Raises:
            NotImplementedError: When using other "execution_type" as expval.
            NotImplementedError: When using "noise_params"
        """
        params = (
            self.model._params_validation(params)
            if params is not None
            else self.model.params
        )
        inputs = (
            self.model._inputs_validation(inputs)
            if inputs is not None
            else self.model._inputs_validation(1.0)
        )

        if kwargs.get("execution_type", "expval") != "expval":
            raise NotImplementedError(
                f'Currently, only "expval" execution type is supported when '
                f"building FourierTree. Got {kwargs.get('execution_type', 'expval')}."
            )
        if kwargs.get("noise_params", None) is not None:
            raise NotImplementedError(
                "Currently, noise is not supported when building FourierTree."
            )

        # Record the circuit tape with the requested parameters (previously
        # the `params` argument was validated but then ignored here).
        raw_tape = self.model.script._record(params=params, inputs=inputs)

        # Build observables from the model's output_qubit configuration
        _, obs_list = self.model._build_obs()

        quantum_tape = PauliCircuit.from_parameterised_circuit(
            raw_tape, observables=obs_list
        )

        self.parameters = [jnp.squeeze(p) for p in quantum_tape.get_parameters()]

        results = jnp.zeros(len(self.tree_roots))
        for i, root in enumerate(self.tree_roots):
            if root is None:
                # Root pruned by early stopping: expectation is identically 0.
                continue
            results = results.at[i].set(jnp.real(root.evaluate(self.parameters)))

        if kwargs.get("force_mean", False):
            return jnp.mean(results)
        else:
            return results

    def build(self) -> List[Optional[CoefficientsTreeNode]]:
        """
        Creates the coefficient tree, i.e. it creates and initialises the tree
        nodes.
        Leafs can be obtained separately in _get_tree_leafs, once the tree is
        set up.

        Returns:
            List[Optional[CoefficientsTreeNode]]: The list of root nodes (one
            root for each observable). A root is None if early stopping
            discarded the whole tree for that observable.
        """
        pauli_rotation_idx = len(self.pauli_rotations) - 1
        return [
            self._create_tree_node(obs, pauli_rotation_idx) for obs in self.observables
        ]

    def _encode_observables(
        self, tape_obs: List[Operation]
    ) -> List[FourierTree.PauliOperator]:
        """
        Encodes observables from tape as FourierTree.PauliOperator
        utility objects.

        Args:
            tape_obs (List[Operation]): Observable operations

        Returns:
            List[FourierTree.PauliOperator]: List of Pauli operators
        """
        return [
            FourierTree.PauliOperator(obs, self.model.n_qubits, is_observable=True)
            for obs in tape_obs
        ]

    def _get_tree_leafs(self) -> List[List[TreeLeaf]]:
        """
        Obtain all Leaf Nodes with its sine- and cosine terms.

        Returns:
            List[List[TreeLeaf]]: For each observable (root), the list of leaf
            nodes (empty for pruned roots).
        """
        leafs = []
        for root in self.tree_roots:
            if root is None:
                leafs.append([])
                continue
            sin_list = jnp.zeros(len(self.parameters), dtype=jnp.int32)
            cos_list = jnp.zeros(len(self.parameters), dtype=jnp.int32)
            leafs.append(root.get_leafs(sin_list, cos_list, []))
        return leafs

    def get_spectrum(
        self, force_mean: bool = False
    ) -> Tuple[List[jnp.ndarray], List[jnp.ndarray]]:
        """
        Computes the Fourier spectrum for the tree, consisting of the
        frequencies and their corresponding coefficients.
        If the flag force_mean is set, the mean coefficient over all
        observables (roots) is computed.

        Args:
            force_mean (bool, optional): Whether to average over multiple
                observables. Defaults to False.

        Returns:
            Tuple[List[jnp.ndarray], List[jnp.ndarray]]:
                - List of coefficients, one array for each observable (root).
                - List of corresponding frequencies, one array for each
                  observable (root).
        """
        parameter_indices = [
            i for i in range(len(self.parameters)) if i not in self.all_input_indices
        ]

        coeffs = []
        for leafs in self.leafs:
            freq_terms = defaultdict(np.complex128)
            # NOTE(review): assumes iterating `input_indices` yields the keys
            # later used to index it in `_compute_leaf_factors` — confirm.
            for input_idx in self.input_indices:
                for leaf in leafs:
                    leaf_factor, s, c = self._compute_leaf_factors(
                        leaf, parameter_indices, input_idx
                    )

                    # Expand sin^s * cos^c of the input into exponentials.
                    for a in range(s + 1):
                        for b in range(c + 1):
                            comb = math.comb(s, a) * math.comb(c, b) * (-1) ** (s - a)
                            freq_terms[2 * a + 2 * b - s - c] += comb * leaf_factor

            coeffs.append(freq_terms)

        frequencies, coefficients = self._freq_terms_to_coeffs(coeffs, force_mean)
        return coefficients, frequencies

    def _freq_terms_to_coeffs(
        self, coeffs: List[Dict[int, jnp.ndarray]], force_mean: bool
    ) -> Tuple[List[jnp.ndarray], List[jnp.ndarray]]:
        """
        Given a list of dictionaries of the form:
        [
            {
                freq_obs1_1: coeff1,
                freq_obs1_2: coeff2,
                ...
            },
            {
                freq_obs2_1: coeff3,
                freq_obs2_2: coeff4,
                ...
            }
            ...
        ],
        Compute two separate lists of frequencies and coefficients.
        such that:
        freqs: [
            [freq_obs1_1, freq_obs1_2, ...],
            [freq_obs2_1, freq_obs2_2, ...],
            ...
        ]
        coeffs: [
            [coeff1, coeff2, ...],
            [coeff3, coeff4, ...],
            ...
        ]

        Frequencies are sorted ascending. If force_mean is set, the length of
        the resulting frequency and coefficient lists is 1 (missing
        frequencies count as coefficient 0 in the mean).

        Args:
            coeffs (List[Dict[int, jnp.ndarray]]): Frequency->Coefficients
                dictionary list, one dict for each observable (root).
            force_mean (bool): Whether to average coefficients over multiple
                observables.

        Returns:
            Tuple[List[jnp.ndarray], List[jnp.ndarray]]:
                - List of frequencies, one array for each observable (root).
                - List of corresponding coefficients, one array for each
                  observable (root).
        """
        frequencies = []
        coefficients = []
        if force_mean:
            all_freqs = sorted({f for c in coeffs for f in c.keys()})
            coefficients.append(
                jnp.array(
                    [
                        jnp.mean(jnp.array([c.get(f, 0.0) for c in coeffs]))
                        for f in all_freqs
                    ]
                )
            )
            frequencies.append(jnp.array(all_freqs))
        else:
            for freq_terms in coeffs:
                freq_terms = dict(sorted(freq_terms.items()))
                frequencies.append(jnp.array(list(freq_terms.keys())))
                coefficients.append(jnp.array(list(freq_terms.values())))
        return frequencies, coefficients

    def _compute_leaf_factors(
        self,
        leaf: TreeLeaf,
        parameter_indices: List[int],
        input_idx: int,
    ) -> Tuple[float, int, int]:
        """
        Computes the constant coefficient factor for each leaf.
        Additionally sine and cosine contributions of the input parameters for
        this leaf are returned, which are required to obtain the corresponding
        frequencies.

        Args:
            leaf (TreeLeaf): The leaf for which to compute the factor.
            parameter_indices (List[int]): Variational parameter indices.
            input_idx (int): Key into `self.input_indices` selecting the
                parameter positions that carry the input.

        Returns:
            Tuple[float, int, int]:
                - float: the constant factor for the leaf
                - int: number of sine contributions of the input
                - int: number of cosine contributions of the input
        """
        leaf_factor = 1.0
        for i in parameter_indices:
            interm_factor = (
                jnp.cos(self.parameters[i]) ** leaf.cos_indices[i]
                * (1j * jnp.sin(self.parameters[i])) ** leaf.sin_indices[i]
            )
            leaf_factor = leaf_factor * interm_factor

        # Get number of sine and cosine factors to which the input contributes
        c = jnp.sum(
            jnp.array([leaf.cos_indices[k] for k in self.input_indices[input_idx]])
        )
        s = jnp.sum(
            jnp.array([leaf.sin_indices[k] for k in self.input_indices[input_idx]])
        )

        leaf_factor = leaf.term * leaf_factor * 0.5 ** (s + c)

        return leaf_factor, int(s), int(c)

    def _early_stopping_possible(
        self, pauli_rotation_idx: int, observable: FourierTree.PauliOperator
    ):
        """
        Checks if a node for an observable can be discarded as all expectation
        values that can result through further branching are zero.
        The method is mentioned in the paper by Nemkov et al.: If the one-hot
        encoded indices for X/Y operations in the Pauli-rotation generators are
        a basis for that of the observable, the node must be processed further.
        If not, it can be discarded.

        Args:
            pauli_rotation_idx (int): Index of remaining Pauli rotation gates.
                Gates itself are attributes of the class. A negative index
                means no rotations are left.
            observable (FourierTree.PauliOperator): Current observable

        Returns:
            bool: True if the node can be discarded.
        """
        if pauli_rotation_idx >= 0:
            # xy_indices of rotation i already accumulate rotations 0..i.
            rotation_xy = self.pauli_rotations[pauli_rotation_idx].xy_indices
        else:
            # No rotations left: a negative index must not wrap around to the
            # last rotation.
            rotation_xy = jnp.zeros_like(observable.xy_indices)

        xy_indices_obs = jnp.logical_or(observable.xy_indices, rotation_xy).all()
        return not bool(xy_indices_obs)

    def _create_tree_node(
        self,
        observable: FourierTree.PauliOperator,
        pauli_rotation_idx: int,
        parameter_idx: Optional[int] = None,
        is_sine: bool = False,
        is_cosine: bool = False,
    ) -> Optional[CoefficientsTreeNode]:
        """
        Builds the Fourier-Tree according to the algorithm by Nemkov et al.

        Args:
            observable (FourierTree.PauliOperator): Current observable
            pauli_rotation_idx (int): Index of remaining Pauli rotation gates.
                Gates itself are attributes of the class.
            parameter_idx (Optional[int]): Index of the current parameter.
                Parameters itself are attributes of the class.
            is_sine (bool): If the current node is a sine node.
            is_cosine (bool): If the current node is a cosine node.

        Returns:
            Optional[CoefficientsTreeNode]: The resulting node, or None if the
            branch was discarded. Children are set recursively. The top level
            receives the tree root.
        """
        if self._early_stopping_possible(pauli_rotation_idx, observable):
            return None

        # remove commuting paulis
        while pauli_rotation_idx >= 0:
            last_pauli = self.pauli_rotations[pauli_rotation_idx]
            if not observable.is_commuting(last_pauli.list_repr):
                break
            pauli_rotation_idx -= 1
        else:  # leaf
            return FourierTree.CoefficientsTreeNode(
                parameter_idx, observable, is_sine, is_cosine
            )

        last_pauli = self.pauli_rotations[pauli_rotation_idx]

        # cosine branch keeps the observable
        left = self._create_tree_node(
            observable,
            pauli_rotation_idx - 1,
            pauli_rotation_idx,
            is_cosine=True,
        )

        # sine branch multiplies the observable with the anticommuting Pauli
        next_observable = self._create_new_observable(last_pauli.list_repr, observable)
        right = self._create_tree_node(
            next_observable,
            pauli_rotation_idx - 1,
            pauli_rotation_idx,
            is_sine=True,
        )

        return FourierTree.CoefficientsTreeNode(
            parameter_idx,
            observable,
            is_sine,
            is_cosine,
            left,
            right,
        )

    def _create_new_observable(
        self, pauli: jnp.ndarray[int], observable: FourierTree.PauliOperator
    ) -> FourierTree.PauliOperator:
        """
        Utility function to obtain the new observable for a tree node, if the
        last Pauli and the observable do not commute.

        Args:
            pauli (jnp.ndarray[int]): The int array representation of the last
                Pauli rotation in the operation sequence.
            observable (FourierTree.PauliOperator): The current observable.

        Returns:
            FourierTree.PauliOperator: The new observable.
        """
        return observable.tensor(pauli)
class FCC:
    """
    Fourier Coefficient Correlation (FCC) utilities.

    The FCC of a model is obtained by sampling its Fourier coefficients for
    many random parameter sets, correlating the coefficients across those
    samples (the "Fourier fingerprint"), and averaging the absolute value of
    the non-redundant correlations.
    """

    @classmethod
    def get_fcc(
        cls,
        model: Model,
        n_samples: int,
        random_key: Optional[random.PRNGKey] = None,
        method: Optional[str] = "pearson",
        scale: Optional[bool] = False,
        weight: Optional[bool] = False,
        trim_redundant: Optional[bool] = True,
        **kwargs,
    ) -> float:
        """
        Shortcut method to get just the FCC.
        This includes
        1. What is done in `get_fourier_fingerprint`:
            1. Calculating the coefficients (using `n_samples`)
            2. Correlating the result from 1) using `method`
            3. Weighting the correlation matrix (if `weight` is True)
            4. Remove redundancies (if `trim_redundant` is True)
        2. What is done in `calculate_fcc`:
            1. Absolute of the fingerprint
            2. Average

        When `trim_redundant` is True and `weight` is False, the average is
        computed directly from the positive-frequency correlation block,
        without materialising the trimmed fingerprint.

        Args:
            model (Model): The QFM model
            n_samples (int): Number of samples to calculate average of coefficients
            random_key (Optional[random.PRNGKey]): JAX random key for parameter
                initialization. If None, uses the model's internal random key.
            method (Optional[str], optional): Correlation method. Supported values are
                "pearson", "complex_pearson", and "spearman". Defaults to "pearson".
            scale (Optional[bool], optional): Whether to scale the number of samples.
                Defaults to False.
            weight (Optional[bool], optional): Whether to weight the correlation matrix.
                Defaults to False.
            trim_redundant (Optional[bool], optional): Whether to remove redundant
                correlations. Defaults to True.
            **kwargs (Any): Additional keyword arguments for the model function.

        Returns:
            float: The FCC
        """
        if trim_redundant and not weight:
            # Memory-efficient fast path.
            _, coeffs, freqs = cls._calculate_coefficients(
                model, n_samples, random_key, scale, **kwargs
            )
            return cls._fcc_from_coefficients(coeffs, freqs, method=method)

        fourier_fingerprint, _ = cls.get_fourier_fingerprint(
            model,
            n_samples,
            random_key,
            method,
            scale,
            weight,
            trim_redundant=trim_redundant,
            **kwargs,
        )

        return cls.calculate_fcc(fourier_fingerprint)

    @classmethod
    def get_fourier_fingerprint(
        cls,
        model: Model,
        n_samples: int,
        random_key: Optional[random.PRNGKey] = None,
        method: Optional[str] = "pearson",
        scale: Optional[bool] = False,
        weight: Optional[bool] = False,
        trim_redundant: Optional[bool] = True,
        nan_to_one: Optional[bool] = False,
        **kwargs: Any,
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """
        Shortcut method to get just the fourier fingerprint.
        This includes
        1. Calculating the coefficients (using `n_samples`)
        2. Correlating the result from 1) using `method`
        3. Weighting the correlation matrix (if `weight` is True)
        4. Remove redundancies (if `trim_redundant` is True)

        Args:
            model (Model): The QFM model
            n_samples (int): Number of samples to calculate average of coefficients
            random_key (Optional[random.PRNGKey]): JAX random key for parameter
                initialization. If None, uses the model's internal random key.
            method (Optional[str], optional): Correlation method. Supported values are
                "pearson", "complex_pearson", and "spearman". Defaults to "pearson".
            scale (Optional[bool], optional): Whether to scale the number of samples.
                Defaults to False.
            weight (Optional[bool], optional): Whether to weight the correlation matrix.
                Defaults to False.
            trim_redundant (Optional[bool], optional): Whether to remove redundant
                correlations. Defaults to True.
            nan_to_one (Optional[bool], optional): Whether to set nan to 1.
                Defaults to False.
            **kwargs: Additional keyword arguments for the model function.

        Returns:
            Tuple[jnp.ndarray, jnp.ndarray]: The fourier fingerprint and the
            corresponding frequency indices. If `trim_redundant` is True the
            frequencies are returned as a `(row_freqs, col_freqs)` tuple that
            labels the two (redundancy-trimmed) matrix axes; otherwise the
            full frequency vector is returned.
        """
        _, coeffs, freqs = cls._calculate_coefficients(
            model, n_samples, random_key, scale, **kwargs
        )
        return cls._fingerprint_from_coefficients(
            coeffs,
            freqs,
            method=method,
            weight=weight,
            trim_redundant=trim_redundant,
            nan_to_one=nan_to_one,
        )

    @classmethod
    def _fingerprint_from_coefficients(
        cls,
        coeffs: jnp.ndarray,
        freqs: Union[jnp.ndarray, List[jnp.ndarray]],
        method: str = "pearson",
        weight: bool = False,
        trim_redundant: bool = True,
        nan_to_one: bool = False,
    ) -> Tuple[jnp.ndarray, Any]:
        """
        Build the Fourier fingerprint from already computed coefficients.

        Args:
            coeffs (jnp.ndarray): Fourier coefficients. The last axis is the
                sample axis; all leading axes are frequency axes.
            freqs: Frequencies as accepted by `_calculate_mask` (a 1-D vector,
                or one frequency vector per input feature).
            method (str, optional): Correlation method. Defaults to "pearson".
            weight (bool, optional): Whether to apply `_weighting_mean`.
                Defaults to False.
            trim_redundant (bool, optional): Whether to keep only the strict
                lower triangle of the non-negative-frequency block.
                Defaults to True.
            nan_to_one (bool, optional): Whether to replace NaN correlations
                by 1 before weighting/trimming. Defaults to False.

        Returns:
            Tuple[jnp.ndarray, Any]: The fingerprint and either
            `(row_freqs, col_freqs)` (trimmed) or `freqs` (untrimmed).
        """
        if trim_redundant:
            pos_idx = cls._calculate_mask(freqs)
            pos_freqs = cls._flat_frequencies(freqs)[pos_idx]

            # Flatten all frequency axes in C order (last axis = samples) so
            # that rows line up with the flat indices of `_calculate_mask`
            # and the labels of `_flat_frequencies`. Every correlation entry
            # depends on its two columns only, so correlating just the
            # selected coefficients yields the required sub-block directly.
            coeffs_sub = coeffs.reshape(-1, coeffs.shape[-1])[pos_idx]
            fingerprint = cls._correlate(coeffs_sub.transpose(), method=method)

            if nan_to_one:
                fingerprint = jnp.where(jnp.isnan(fingerprint), 1.0, fingerprint)

            if weight:
                # weights computed from the same (C-ordered) sub-block
                fingerprint = cls._weighting_mean(fingerprint, coeffs_sub)

            # keep only the strict lower triangle; the rest -> nan
            M = fingerprint.shape[0]
            lower_tri_mask = jnp.tri(M, k=-1, dtype=bool)
            fingerprint = jnp.where(lower_tri_mask, fingerprint, jnp.nan)

            row_mask = jnp.any(jnp.isfinite(fingerprint), axis=1)
            col_mask = jnp.any(jnp.isfinite(fingerprint), axis=0)
            fingerprint = fingerprint[row_mask][:, col_mask]

            return fingerprint, (pos_freqs[row_mask], pos_freqs[col_mask])

        # Untrimmed: full correlation matrix; `_weighting_mean` flattens its
        # means with the same axis order produced by `coeffs.transpose()`.
        fingerprint = cls._correlate(coeffs.transpose(), method=method)

        if nan_to_one:
            # JAX arrays are immutable -> functional replacement
            fingerprint = jnp.where(jnp.isnan(fingerprint), 1.0, fingerprint)

        if weight:
            fingerprint = cls._weighting_mean(fingerprint, coeffs)

        return fingerprint, freqs

    @classmethod
    def _fcc_from_coefficients(
        cls,
        coeffs: jnp.ndarray,
        freqs: Union[jnp.ndarray, List[jnp.ndarray]],
        method: str = "pearson",
    ) -> jnp.ndarray:
        """
        Compute the unweighted, redundancy-trimmed FCC from coefficients.

        Equivalent to `calculate_fcc` applied to the trimmed, unweighted
        fingerprint, but without building the masked lower-triangular
        matrix. The correlation matrices produced by `_correlate` are
        symmetric in magnitude, so the mean over the strict lower triangle
        equals the mean over all finite off-diagonal entries.

        Args:
            coeffs (jnp.ndarray): Fourier coefficients; the last axis is the
                sample axis.
            freqs: Frequencies as accepted by `_calculate_mask`.
            method (str, optional): Correlation method. Defaults to "pearson".

        Returns:
            jnp.ndarray: Scalar FCC (NaN if there is no finite off-diagonal
            correlation).
        """
        pos_idx = cls._calculate_mask(freqs)
        coeffs_sub = coeffs.reshape(-1, coeffs.shape[-1])[pos_idx]

        fp = cls._correlate(coeffs_sub.transpose(), method=method)
        abs_fp = jnp.abs(fp)
        diag = jnp.abs(jnp.diagonal(fp))

        lower_sum = (jnp.nansum(abs_fp) - jnp.nansum(diag)) / 2.0
        lower_count = (
            jnp.sum(jnp.isfinite(abs_fp)) - jnp.sum(jnp.isfinite(diag))
        ) / 2.0
        return lower_sum / lower_count

    @classmethod
    def calculate_fcc(
        cls,
        fourier_fingerprint: jnp.ndarray,
    ) -> float:
        """
        Method to calculate the FCC based on an existing correlation matrix.
        Calculate absolute and then the average over this matrix.
        The Fingerprint can be obtained via `get_fourier_fingerprint`

        Args:
            fourier_fingerprint (jnp.ndarray): Correlation matrix of coefficients
        Returns:
            float: The FCC
        """
        # NaN entries (e.g. those removed by redundancy trimming) are ignored
        return jnp.nanmean(jnp.abs(fourier_fingerprint))

    @classmethod
    def _axis_frequencies(
        cls, freqs: Union[jnp.ndarray, List[jnp.ndarray]]
    ) -> Optional[List[jnp.ndarray]]:
        """
        Split `freqs` into per-input-feature frequency vectors.

        Args:
            freqs: A 1-D vector (single input feature), a 2-D array whose rows
                are per-axis vectors, or a list/tuple of per-axis vectors
                (which may have different lengths).

        Returns:
            Optional[List[jnp.ndarray]]: None for a single 1-D vector,
            otherwise one frequency vector per input feature.
        """
        if (
            isinstance(freqs, (list, tuple))
            and len(freqs) > 0
            and all(jnp.ndim(f) >= 1 for f in freqs)
        ):
            # list of per-axis vectors; may be ragged, so don't stack
            return [jnp.asarray(f) for f in freqs]

        freqs_arr = jnp.asarray(freqs)
        if freqs_arr.ndim == 1:
            return None
        return [freqs_arr[i] for i in range(freqs_arr.shape[0])]

    @classmethod
    def _calculate_mask(cls, freqs: jnp.ndarray) -> jnp.ndarray:
        """
        Determine the flat indices of the Fourier correlation matrix
        that lie on a non-negative-frequency row/column. Together with
        the strict-lower-triangle condition (handled by the caller),
        these indices select the entries of the correlation matrix
        that survive the redundancy filter applied in
        `get_fourier_fingerprint`:

        - rows/columns whose flat frequency component is negative are
          discarded (they are the complex-conjugate redundancies of
          their positive counterparts);
        - of the remaining positive-frequency sub-block, only the
          strict lower triangle is kept (the upper triangle, including
          the diagonal, contains either duplicates from symmetry or
          self-correlations).

        Args:
            freqs (jnp.ndarray): Array of frequencies. Either a 1-D
                vector (single input feature), a 2-D array of shape
                ``(n_input_feat, K)`` whose rows are the per-axis
                frequency vectors, or a list of per-axis vectors
                (possibly of different lengths).

        Returns:
            jnp.ndarray: 1-D int array of flat (C-order) indices selecting
            the non-negative-frequency rows/cols of the fingerprint.
        """
        axes = cls._axis_frequencies(freqs)

        if axes is None:
            pos_flat = jnp.asarray(freqs) >= 0
        else:
            # N-D case: per-axis non-negativity masks combined via
            # broadcasting (no float `jnp.outer`), then flattened in C order
            # to match the coefficient reshape used by the callers.
            n_axes = len(axes)
            expanded = []
            for i, axis_freqs in enumerate(axes):
                shape = [1] * n_axes
                shape[i] = axis_freqs.shape[0]
                expanded.append((axis_freqs >= 0).reshape(shape))
            pos_flat = reduce(jnp.logical_and, expanded).flatten()

        return jnp.where(pos_flat)[0]

    @classmethod
    def _flat_frequencies(cls, freqs: jnp.ndarray) -> jnp.ndarray:
        """
        Build the per-coefficient flat frequency labels in the same
        C-order used to flatten the coefficient/correlation pipeline, so
        they can be indexed by the flat indices from `_calculate_mask`.

        Args:
            freqs (jnp.ndarray): Either a 1-D vector (single input feature)
                or a ``(n_input_feat, K)`` stack / list of per-axis frequency
                vectors (multi-dim input; list entries may differ in length).

        Returns:
            jnp.ndarray: 1-D frequency vector (single input feature) or a
            ``(N, n_input_feat)`` array of per-coefficient frequency
            tuples (multi-dim input).
        """
        axes = cls._axis_frequencies(freqs)
        if axes is None:
            return jnp.asarray(freqs)
        grids = jnp.meshgrid(*axes, indexing="ij")
        return jnp.stack(grids, axis=-1).reshape(-1, len(axes))

    @classmethod
    def _calculate_coefficients(
        cls,
        model: Model,
        n_samples: int,
        random_key: Optional[random.PRNGKey] = None,
        scale: bool = False,
        **kwargs: Any,
    ) -> Tuple[Any, jnp.ndarray, jnp.ndarray]:
        """
        Calculates the Fourier coefficients of a given model
        using `n_samples`.
        Optionally, `noise_params` can be passed to perform noisy simulation.

        Args:
            model (Model): The QFM model
            n_samples (int): Number of samples to calculate average of
                coefficients. If not positive, the model's current parameters
                are used unchanged.
            random_key (Optional[random.PRNGKey]): JAX random key for parameter
                initialization. If None, uses the model's internal random key.
            scale (bool, optional): Whether to scale the number of samples by
                ``2**n_qubits * n_input_feat``. Defaults to False.
            **kwargs: Additional keyword arguments for the model function.

        Returns:
            Tuple[Any, jnp.ndarray, jnp.ndarray]: The model parameters, the
            Fourier coefficients (sample axis last) and their frequencies, as
            returned by `Coefficients.get_spectrum(shift=True, trim=True)`.
        """
        if n_samples > 0:
            if scale:
                # Python ints avoid the int32 overflow of jnp.power
                total_samples = int(
                    2**model.n_qubits * n_samples * model.n_input_feat
                )
                log.info("Using %d samples.", total_samples)
            else:
                total_samples = n_samples
            model.initialize_params(random_key, repeat=total_samples)

        coeffs, freqs = Coefficients.get_spectrum(
            model, shift=True, trim=True, **kwargs
        )

        return model.params, coeffs, freqs

    @classmethod
    def _correlate(cls, mat: jnp.ndarray, method: str = "pearson") -> jnp.ndarray:
        """
        Correlate the variables (columns) of `mat` using `method`.
        Currently, `pearson`, `complex_pearson`, and `spearman` are supported.

        Args:
            mat (jnp.ndarray): Array of shape (N, ...). The first axis is the
                sample axis; all remaining axes are flattened (C order) into
                K variables.
            method (str, optional): Correlation method. Defaults to "pearson".

        Raises:
            ValueError: If the method is not supported.

        Returns:
            jnp.ndarray: Correlation matrix of shape (K, K).
        """
        assert len(mat.shape) >= 2, "Input matrix must have at least 2 dimensions"

        correlators = {
            "pearson": cls._pearson,
            "complex_pearson": cls._complex_pearson,
            "spearman": cls._spearman,
        }
        if method not in correlators:
            raise ValueError(
                f"Unknown correlation method: {method}. "
                "Must be 'pearson', 'complex_pearson' or 'spearman'."
            )

        return correlators[method](mat.reshape(mat.shape[0], -1))

    @classmethod
    def _complex_pearson(
        cls, mat: jnp.ndarray, cov: Optional[bool] = False, minp: Optional[int] = 1
    ) -> jnp.ndarray:
        """
        Compute the complex Pearson correlation between columns of `mat`,
        permitting missing values (NaN or ±Inf).

        This uses the Hermitian normalized covariance
            sum(conj(x_i - mean_i) * (x_j - mean_j)) /
            sqrt(sum(abs(x_i - mean_i)**2) * sum(abs(x_j - mean_j)**2)).
        Consequently, if column j is exp(1j * phi) times column i, then
        abs(corr[i, j]) is 1 and angle(corr[i, j]) is phi.

        Args:
            mat : array_like, shape (N, K)
                Input data.
            cov : bool, optional
                If True, return the sample covariance matrix instead of
                correlation. Defaults to False.
            minp : int, optional
                Minimum number of paired observations required to form a correlation.
                If the number of valid pairs for (i, j) is < minp, the result is NaN.

        Returns:
            corr : ndarray, shape (K, K)
                Complex Pearson correlation matrix.
        """
        mat = jnp.asarray(mat)
        real_dtype = jnp.asarray(mat.real).dtype

        mask = jnp.isfinite(mat)
        fmask = mask.astype(real_dtype)
        safe = jnp.where(mask, mat, 0.0)

        # pairwise valid-observation counts (K, K)
        nobs = fmask.T @ fmask
        nobs_safe = jnp.where(nobs > 0, nobs, 1.0)

        # sums restricted to rows valid for both columns of each pair
        sum_x = safe.T @ fmask
        sum_y = fmask.T @ safe

        masked = safe * fmask
        sum_conj_xy = jnp.conj(masked).T @ masked

        safe_abs_sq = jnp.abs(safe) ** 2
        sum_abs_x2 = safe_abs_sq.T @ fmask
        sum_abs_y2 = fmask.T @ safe_abs_sq

        ssx = sum_abs_x2 - jnp.abs(sum_x) ** 2 / nobs_safe
        ssy = sum_abs_y2 - jnp.abs(sum_y) ** 2 / nobs_safe
        sxy = sum_conj_xy - (jnp.conj(sum_x) * sum_y) / nobs_safe

        if cov:
            denom = jnp.where(nobs > 1, nobs - 1, jnp.nan)
            result = sxy / denom
        else:
            denom = jnp.sqrt(ssx * ssy)
            result = jnp.where(denom > 0, sxy / denom, jnp.nan)
            # clip numerical drift of the magnitude to <= 1
            magnitude = jnp.abs(result)
            result = jnp.where(magnitude > 1.0, result / magnitude, result)

        result = jnp.where(nobs < minp, jnp.nan, result)

        return result

    @classmethod
    def _pearson(
        cls, mat: jnp.ndarray, cov: Optional[bool] = False, minp: Optional[int] = 1
    ) -> jnp.ndarray:
        """
        Based on Pandas correlation method as implemented here:
        https://github.com/pandas-dev/pandas/blob/main/pandas/_libs/algos.pyx

        Compute Pearson correlation between columns of `mat`,
        permitting missing values (NaN or ±Inf).

        If the input is complex, real and imaginary parts are stacked along
        the sample axis so that both components contribute to the correlation
        without discarding information.

        Args:
            mat : array_like, shape (N, K)
                Input data.
            cov : bool, optional
                If True, return the sample covariance matrix instead of
                correlation. Defaults to False.
            minp : int, optional
                Minimum number of paired observations required to form a correlation.
                If the number of valid pairs for (i, j) is < minp, the result is NaN.

        Returns:
            corr : ndarray, shape (K, K)
                Pearson correlation matrix.
        """
        # Preserve complex information by splitting into real / imag samples
        if jnp.iscomplexobj(mat):
            mat = jnp.concatenate([mat.real, mat.imag], axis=0)

        mat = jnp.asarray(mat)

        # pre-compute finite mask (N, K)
        mask = jnp.isfinite(mat)
        fmask = mask.astype(mat.dtype)

        # Replace non-finite entries with 0 so arithmetic is safe;
        # the mask keeps track of validity.
        safe = jnp.where(mask, mat, 0.0)

        # Pairwise valid-observation counts (K, K)
        nobs = fmask.T @ fmask
        nobs_safe = jnp.where(nobs > 0, nobs, 1.0)

        # Pairwise sums over mutually valid rows: entry (i, j) of
        # safe.T @ fmask is sum_n safe[n, i] * mask[n, j].
        sum_x = safe.T @ fmask  # (K, K) – row-var sums
        sum_y = fmask.T @ safe  # (K, K) – col-var sums

        # Computational formulas (implicit mean-centering):
        #   ssx = Σx² − (Σx)²/n, ssy = Σy² − (Σy)²/n, sxy = Σxy − (Σx)(Σy)/n
        # with all sums over the pairwise-valid subset of each (i, j).
        masked = safe * fmask
        sum_xy = masked.T @ masked  # (K, K)

        safe_sq = safe**2
        sum_x2 = safe_sq.T @ fmask  # (K, K)
        sum_y2 = fmask.T @ safe_sq  # (K, K)

        ssx = sum_x2 - sum_x**2 / nobs_safe
        ssy = sum_y2 - sum_y**2 / nobs_safe
        sxy = sum_xy - (sum_x * sum_y) / nobs_safe

        if cov:
            denom = jnp.where(nobs > 1, nobs - 1, jnp.nan)
            result = sxy / denom
        else:
            denom = jnp.sqrt(ssx * ssy)
            result = jnp.where(denom > 0, sxy / denom, jnp.nan)
            # clip numerical drift to [-1, 1]
            result = jnp.clip(result, -1.0, 1.0)

        # Enforce minp: set entries with too few observations to NaN
        result = jnp.where(nobs < minp, jnp.nan, result)

        return result

    @classmethod
    def _spearman(cls, mat: jnp.ndarray, minp: Optional[int] = 1) -> jnp.ndarray:
        """
        Based on Pandas correlation method as implemented here:
        https://github.com/pandas-dev/pandas/blob/main/pandas/_libs/algos.pyx

        Compute Spearman correlation between columns of `mat`,
        permitting missing values (NaN or ±Inf).

        If the input is complex, real and imaginary parts are stacked along
        the sample axis so that both components contribute to the correlation
        without discarding information.

        Args:
            mat : array_like, shape (N, K)
                Input data.
            minp : int, optional
                Minimum number of paired observations required to form a correlation.
                If the number of valid pairs for (i, j) is < minp, the result is NaN.

        Returns:
            corr : ndarray, shape (K, K)
                Spearman correlation matrix.
        """
        # Preserve complex information by splitting into real / imag samples
        if jnp.iscomplexobj(mat):
            mat = jnp.concatenate([mat.real, mat.imag], axis=0)

        mat = jnp.asarray(mat)
        N, K = mat.shape

        # trivial all-NaN answer if too few rows
        if N < minp:
            return jnp.full((K, K), jnp.nan)

        # Rank column-wise on the host (scipy), ignoring non-finite entries.
        # Convert once instead of indexing NumPy with JAX arrays per column.
        mat_np = np.asarray(mat)
        mask_np = np.isfinite(mat_np)
        ranks = np.full((N, K), np.nan)
        for j in range(K):
            valid = mask_np[:, j]
            if valid.any():
                ranks[valid, j] = rankdata(mat_np[valid, j], method="average")

        ranks = jnp.asarray(ranks)

        # Vectorised Pearson on the ranks
        # Replace NaN ranks with 0; use mask to track validity.
        rank_mask = jnp.isfinite(ranks)
        safe_ranks = jnp.where(rank_mask, ranks, 0.0)

        # Pairwise valid-observation counts (K, K)
        fmask = rank_mask.astype(ranks.dtype)
        nobs = fmask.T @ fmask

        # Pairwise sums over mutually-valid rows
        sum_x = safe_ranks.T @ fmask  # (K, K)
        sum_y = fmask.T @ safe_ranks  # (K, K)

        # Pairwise products
        masked_ranks = safe_ranks * fmask
        sum_xy = masked_ranks.T @ masked_ranks  # (K, K)

        safe_sq = safe_ranks**2
        sum_x2 = safe_sq.T @ fmask  # (K, K)
        sum_y2 = fmask.T @ safe_sq  # (K, K)

        nobs_safe = jnp.where(nobs > 0, nobs, 1.0)
        ssx = sum_x2 - sum_x**2 / nobs_safe
        ssy = sum_y2 - sum_y**2 / nobs_safe
        sxy = sum_xy - (sum_x * sum_y) / nobs_safe

        denom = jnp.sqrt(ssx * ssy)
        result = jnp.where(denom > 0, sxy / denom, jnp.nan)
        result = jnp.clip(result, -1.0, 1.0)

        # Enforce minp
        result = jnp.where(nobs < minp, jnp.nan, result)

        return result

    @classmethod
    def _weighting_linear(cls, fourier_fingerprint: jnp.ndarray) -> jnp.ndarray:
        """
        Performs weighting on the given correlation matrix.
        Here, low-frequent coefficients are weighted more heavily.

        Args:
            fourier_fingerprint (jnp.ndarray): Square correlation matrix with
                odd side length.

        Returns:
            jnp.ndarray: The weighted correlation matrix.
        """
        assert (
            fourier_fingerprint.shape[0] % 2 != 0
            and fourier_fingerprint.shape[1] % 2 != 0
        ), (
            "Correlation matrix must have odd dimensions. "
            "Hint: use `trim` argument when calling `get_spectrum`."
        )
        assert fourier_fingerprint.shape[0] == fourier_fingerprint.shape[1], (
            "Correlation matrix must be square."
        )

        # Closed form of the former quadrant-mirror construction: a "tent"
        # sum along both axes. With N = fourier_fingerprint.shape[0] (odd)
        # and center = N // 2,
        #   W[i, j] = u[i] + u[j],  u[k] = (center - |k - center|) / (2 * center)
        # i.e. a triangular weighting peaking at the centre and decaying
        # linearly to 0 at the edges.
        N = fourier_fingerprint.shape[0]
        center = N // 2
        k = jnp.arange(N)
        u = (center - jnp.abs(k - center)) / (2 * center)

        return fourier_fingerprint * (u[:, None] + u[None, :])

    @classmethod
    def _weighting_mean(
        cls, fourier_fingerprint: jnp.ndarray, coeffs: jnp.ndarray
    ) -> jnp.ndarray:
        """
        Performs weighting on the given correlation matrix.
        Here, we use the product of the mean of the coefficients as weights.
        This suppresses correlations where the mean of the coefficients is near zero.

        Args:
            fourier_fingerprint (jnp.ndarray): Square correlation matrix
            coeffs (jnp.ndarray): Fourier coefficients; the last axis is the
                sample axis.

        Returns:
            jnp.ndarray: The weighted correlation matrix.
        """
        assert fourier_fingerprint.shape[0] == fourier_fingerprint.shape[1], (
            "Correlation matrix must be square."
        )
        assert len(coeffs.shape) >= 2, (
            "Coefficient matrix must contain coefficient axes and a sample axis."
        )

        # `.T` before flattening reverses the frequency axes, matching the
        # axis order of a correlation built from `coeffs.transpose()`.
        coefficient_means = jnp.abs(jnp.mean(coeffs, axis=-1))
        coefficient_means = coefficient_means.T.reshape(-1)

        assert fourier_fingerprint.shape[0] == coefficient_means.shape[0], (
            "Correlation matrix size must match the number of Fourier coefficients."
        )

        # Apply the rank-1 weight w[i] * w[j] via broadcasting instead
        # of materialising an explicit `jnp.outer` N x N intermediate.
        return (
            fourier_fingerprint
            * coefficient_means[:, None]
            * coefficient_means[None, :]
        )
class Datasets:
    """Synthetic datasets whose frequency structure matches a given model."""

    @classmethod
    def generate_fourier_series(
        cls,
        random_key: random.PRNGKey,
        model: Model,
        coefficients_min: float = 0.0,
        coefficients_max: float = 1.0,
        zero_centered: bool = False,
    ) -> List[jnp.ndarray]:
        """
        Generates the Fourier series representation of a function.
        It uses the `model.frequencies` property to retrieve the frequency
        information. This ensures that the resulting Fourier series is
        compatible with the model.

        This function is capable of generating $D$-dimensional Fourier series
        (with $D$ = `model.n_input_feat`). The number of samples per input
        dimension is given by `model.degree`.

        The independent Fourier coefficients are drawn with `uniform_circle`;
        the remaining ones follow from Hermitian symmetry so that the series
        is real-valued.

        Args:
            random_key (random.PRNGKey): Random number key for JAX.
            model (Model): The quantum circuit model.
            coefficients_min (float, optional): Lower bound passed to
                `uniform_circle` (`low`). Defaults to 0.0.
            coefficients_max (float, optional): Upper bound passed to
                `uniform_circle` (`high`). Defaults to 1.0.
            zero_centered (bool, optional): Whether to set the offset
                coefficient to zero. Defaults to False.

        Returns:
            List[jnp.ndarray]: Three arrays:
                - input domain samples with shape (*model.degree, D)
                - Fourier series values with shape (*model.degree)
                - Fourier coefficients with shape (*model.degree)
        """
        # TODO: consider adding some constraints on the spectrum, i.e. not
        # fully capturing a truly random spectrum.

        # Note: one key observation for understanding the following code is,
        # that instead of wrapping your head around symmetries in multi-
        # dimensional coefficient matrices, one can simply look at the flattened
        # version of such a matrix and reshape later. It just works out.

        # Grid over [0, 2pi) with `d` samples per input dimension, flattened
        # into a list of D-dimensional coordinates.
        domain_samples_per_input_dim = jnp.stack(
            jnp.meshgrid(
                *[jnp.arange(0, 2 * jnp.pi, 2 * jnp.pi / d) for d in model.degree]
            )
        ).T.reshape(-1, model.n_input_feat)

        # Frequency tuples, flattened exactly like the domain samples.
        frequencies = jnp.stack(jnp.meshgrid(*model.frequencies)).T.reshape(
            -1, model.n_input_feat
        )

        # Sample only the first half (plus the centre) of the flattened
        # spectrum; shape: (prod(model.degree) // 2 + 1,)
        coefficients = cls.uniform_circle(
            random_key,
            low=coefficients_min,
            high=coefficients_max,
            size=math.prod(model.degree) // 2 + 1,
        )

        # Index 0 ends up at the centre of the mirrored spectrum below, i.e.
        # the offset term (presumably the zero frequency when
        # `model.frequencies` is symmetric); it must be real.
        if zero_centered:
            coefficients = coefficients.at[0].set(0.0)
        else:
            coefficients = coefficients.at[0].set(coefficients[0].real)

        # Mirror with complex conjugation to obtain the full, Hermitian
        # coefficient vector.
        coefficients = jnp.concatenate(
            [
                jnp.flip(coefficients[..., 1:]).conjugate(),
                coefficients,
            ],
            axis=-1,
        )

        # Vectorized version of $f(x) = \sum_{n=0}^{N-1} c_n * e^{i * \omega_n * x}$
        # evaluated at every domain sample, normalised by the number of
        # coefficients.
        values = jnp.real(
            (
                jnp.exp(1j * (domain_samples_per_input_dim @ frequencies.T))
                * coefficients
            ).sum(axis=1)
            / coefficients.size
        )

        return [
            domain_samples_per_input_dim.reshape(*model.degree, -1),
            values.reshape(model.degree),
            coefficients.reshape(model.degree),
        ]

    @classmethod
    def uniform_circle(
        cls,
        random_key: random.PRNGKey,
        size: Union[jnp.ndarray, List, int],
        low=0.0,
        high=1.0,
    ):
        """
        Random complex numbers sampled uniformly (in area) from a disc or
        annulus in the complex plane.

        The squared radius is drawn uniformly from [low, high] and the phase
        uniformly from [0, 2*pi); the defaults therefore sample the unit disc.

        Args:
            random_key (random.PRNGKey): Random number key for JAX.
            size (Union[jnp.ndarray, List, int]): Output shape. An integer
                scalar is interpreted as a 1-D shape of that length.
            low (float, optional): Lower bound of the squared radius, i.e.
                minimum radius sqrt(low). Defaults to 0.0.
            high (float, optional): Upper bound of the squared radius, i.e.
                maximum radius sqrt(high). Defaults to 1.0.

        Returns:
            jnp.ndarray: Array of complex numbers with shape of `size`
        """
        if jnp.ndim(size) == 0:
            size = (int(size),)

        key_radius, key_phase = random.split(random_key)
        radius = jnp.sqrt(random.uniform(key_radius, size, minval=low, maxval=high))
        phase = jnp.exp(2j * jnp.pi * random.uniform(key_phase, size))
        return radius * phase