Coverage for qml_essentials / coefficients.py: 96%
643 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-06-11 15:51 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-06-11 15:51 +0000
1from __future__ import annotations
2import sys
3import math
4import warnings
5import itertools
6from collections import defaultdict
7import jax.numpy as jnp
8from jax import random
9import numpy as np
10from scipy.stats import rankdata
11from functools import reduce, lru_cache
12from typing import List, Tuple, Optional, Any, Dict, Union
14from qml_essentials.model import Model
15from qml_essentials.pauli import PauliCircuit
16from qml_essentials.operations import PauliWord
18import logging
20log = logging.getLogger(__name__)
class Coefficients:
    """FFT-based extraction and evaluation of a model's Fourier spectrum."""

    @classmethod
    def get_spectrum(
        cls,
        model: Model,
        mfs: int = 1,
        mts: int = 1,
        shift=False,
        trim=False,
        numerical_cap: Optional[float] = -1,
        **kwargs,
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """
        Extracts the coefficients of a given model using a FFT (jnp-fft).

        Note that the coefficients are complex numbers, but the imaginary part
        of the coefficients should be very close to zero, since the expectation
        values of the Pauli operators are real numbers.

        It can perform oversampling in both the frequency and time domain
        using the `mfs` and `mts` arguments.

        Args:
            model (Model): The model to sample.
            mfs (int): Multiplicator for the highest frequency. Default is 1.
            mts (int): Multiplicator for the number of time samples. Default is 1.
            shift (bool): Whether to apply fftshift (per input-feature axis).
                Default is False.
            trim (bool): Whether to remove the Nyquist frequency of every
                input-feature axis whose length is even. Default is False.
            numerical_cap (Optional[float]): Numerical cap for the coefficients.
                If positive, coefficients with magnitude below the cap are
                zeroed and, for a single input feature, frequencies that
                vanish entirely are removed from both `coeffs` and `freqs`.
                ``None`` or a non-positive value disables capping.
            kwargs (Any): Additional keyword arguments for the model function.

        Returns:
            Tuple[jnp.ndarray, jnp.ndarray]: Tuple containing the coefficients
            and frequencies. For a single input feature the frequencies are
            one array; otherwise a list with one array per input feature.

        Raises:
            ValueError: If the summed coefficients have a non-negligible
                imaginary part.
        """
        kwargs.setdefault("force_mean", True)
        kwargs.setdefault("execution_type", "expval")

        coeffs, freqs = cls._fourier_transform(model, mfs=mfs, mts=mts, **kwargs)

        imag_total = jnp.sum(coeffs).imag
        if not jnp.isclose(imag_total, 0.0, atol=1.0e-6):
            # Implicit concatenation keeps source indentation out of the message.
            raise ValueError(
                "Spectrum is not real. Imaginary part of coefficients is: "
                f"{imag_total}"
            )

        feature_axes = list(range(model.n_input_feat))

        if trim:
            # Each feature axis has its own length and its own frequency
            # vector, so the Nyquist bin is located and removed per axis.
            freqs = list(freqs)
            for ax in feature_axes:
                n_samples = coeffs.shape[ax]
                if n_samples % 2 == 0:
                    coeffs = np.delete(coeffs, n_samples // 2, axis=ax)
                    freqs[ax] = np.delete(freqs[ax], len(freqs[ax]) // 2)

        if shift:
            coeffs = jnp.fft.fftshift(coeffs, axes=feature_axes)
            # Shift every frequency vector on its own; shifting the list as a
            # whole would also roll the feature axis and requires equal
            # lengths for all features.
            freqs = [np.fft.fftshift(freq) for freq in freqs]

        if numerical_cap is not None and numerical_cap > 0:
            # set coeffs below threshold to zero
            coeffs = jnp.where(
                jnp.abs(coeffs) < numerical_cap,
                jnp.zeros_like(coeffs),
                coeffs,
            )

            # Drop frequencies whose coefficients vanish entirely after
            # capping, so the returned spectrum reflects only the surviving
            # frequencies. Well-defined only for a single (1-D) frequency
            # axis; for multi-dim input the rectangular grid is left intact.
            if model.n_input_feat == 1:
                if coeffs.ndim == 1:
                    surviving = coeffs != 0
                else:
                    surviving = jnp.any(
                        coeffs != 0, axis=tuple(range(1, coeffs.ndim))
                    )
                coeffs = coeffs[surviving]
                freqs = [freqs[0][surviving]]

        if len(freqs) == 1:
            freqs = freqs[0]

        return coeffs, freqs

    @classmethod
    def _fourier_transform(
        cls, model: Model, mfs: int, mts: int, **kwargs: Any
    ) -> Tuple[jnp.ndarray, List[jnp.ndarray]]:
        """
        Sample the model on a regular input grid and apply an n-dim FFT.

        Args:
            model (Model): The model to sample; ``model.degree`` and
                ``model.n_input_feat`` define the grid.
            mfs (int): Multiplicator for the highest frequency.
            mts (int): Multiplicator for the number of time samples.
            kwargs (Any): Forwarded to the model call.

        Returns:
            Tuple[jnp.ndarray, List[jnp.ndarray]]: The FFT coefficients,
            normalised by the number of grid points, and one frequency
            vector per input feature.
        """
        n_feat = model.n_input_feat

        # Create a frequency vector with as many frequencies as model degrees,
        # oversampled by mfs
        n_freqs: jnp.ndarray = jnp.array(
            [mfs * model.degree[i] for i in range(n_feat)]
        )

        start, stop, step = 0, 2 * mts * jnp.pi, 2 * jnp.pi / n_freqs
        # Stretch according to the number of frequencies
        inputs: List = [jnp.arange(start, stop, step[i]) for i in range(n_feat)]

        # permute with input dimensionality
        nd_inputs = jnp.array(jnp.meshgrid(*inputs)).T.reshape(-1, n_feat)

        # Output vector is not necessarily the same length as input
        outputs = model(inputs=nd_inputs, **kwargs)
        outputs = outputs.reshape(*[axis.shape[0] for axis in inputs], -1).squeeze()

        coeffs = jnp.fft.fftn(outputs, axes=list(range(n_feat)))

        freqs = [
            jnp.fft.fftfreq(int(mts * n_freqs[i]), 1 / n_freqs[i])
            for i in range(n_feat)
        ]

        # TODO: this could cause issues with multidim input
        # FIXME: account for different frequencies in multidim input scenarios
        # Normalize the output (using the product of grid sizes if multidim)
        return (
            coeffs / math.prod(outputs.shape[0:n_feat]),
            freqs,
        )

    @classmethod
    def get_psd(cls, coeffs: jnp.ndarray) -> jnp.ndarray:
        """
        Calculates the power spectral density (PSD) from given Fourier coefficients.

        Args:
            coeffs (jnp.ndarray): The Fourier coefficients.

        Returns:
            jnp.ndarray: The power spectral density, i.e. the squared
            magnitudes scaled by ``2 / len(coeffs) ** 2``.
        """
        # TODO: if we apply trim=True in advance, this will be slightly wrong..

        def abs2(x):
            return x.real**2 + x.imag**2

        scale = 2.0 / (len(coeffs) ** 2)
        return scale * abs2(coeffs)

    @classmethod
    def evaluate_Fourier_series(
        cls,
        coefficients: jnp.ndarray,
        frequencies: jnp.ndarray,
        inputs: Union[jnp.ndarray, list, float],
    ) -> jnp.ndarray:
        """
        Evaluate a Fourier series at one or more points.

        Args:
            coefficients (jnp.ndarray): Coefficients of the Fourier series.
            frequencies (jnp.ndarray): Corresponding frequencies. Either a
                1-D array (single feature), a list / 2-D array of per-axis
                frequency vectors spanning a grid, or an explicit
                ``(n_freq, n_features)`` array of frequency vectors.
            inputs (jnp.ndarray): Point(s) at which to evaluate the function.
                A scalar is used for every input feature.

        Returns:
            jnp.ndarray: The real part of the series at the input point(s),
            with singleton dimensions squeezed away.
        """
        coefficients = jnp.asarray(coefficients)

        def flatten_grid(freq_axes):
            # Expand per-axis frequency vectors to the full grid and flatten
            # grid and coefficients consistently ("ij" ordering).
            freq_axes = [jnp.asarray(freq) for freq in freq_axes]
            freq_grid = jnp.stack(jnp.meshgrid(*freq_axes, indexing="ij"), axis=-1)
            flat_frequencies = freq_grid.reshape(-1, len(freq_axes))
            flat_coefficients = coefficients.reshape(
                flat_frequencies.shape[0], *coefficients.shape[len(freq_axes) :]
            )
            return flat_coefficients, flat_frequencies

        if isinstance(frequencies, list):
            flat_coefficients, flat_frequencies = flatten_grid(frequencies)
        else:
            frequencies = jnp.asarray(frequencies)
            if frequencies.ndim == 1:
                flat_frequencies = frequencies[:, jnp.newaxis]
                flat_coefficients = coefficients.reshape(
                    flat_frequencies.shape[0], *coefficients.shape[1:]
                )
            else:
                n_features, n_axis_freqs = frequencies.shape
                # A 2-D array is read as per-axis vectors only when the
                # coefficients actually have the matching grid shape.
                is_axis_frequencies = (
                    coefficients.shape[:n_features] == (n_axis_freqs,) * n_features
                )

                if is_axis_frequencies:
                    flat_coefficients, flat_frequencies = flatten_grid(frequencies)
                else:
                    flat_frequencies = frequencies
                    flat_coefficients = coefficients.reshape(
                        flat_frequencies.shape[0], *coefficients.shape[1:]
                    )

        n_features = flat_frequencies.shape[1]
        inputs = jnp.asarray(inputs)
        if inputs.ndim == 0:
            # A scalar point applies to every feature (same convention as the
            # 1-D fallback below).
            inputs = jnp.full((1, n_features), inputs)
        elif inputs.ndim == 1:
            if n_features == 1:
                inputs = inputs[:, jnp.newaxis]
            elif inputs.shape[0] == n_features:
                inputs = inputs[jnp.newaxis, :]
            else:
                inputs = jnp.repeat(inputs[:, jnp.newaxis], n_features, axis=1)

        exponents = jnp.exp(1j * (inputs @ flat_frequencies.T))
        exp = jnp.tensordot(exponents, flat_coefficients, axes=([1], [0]))

        return jnp.squeeze(jnp.real(exp))
240class FourierTree:
241 """
242 Sine-cosine tree representation for the algorithm by Nemkov et al.
244 Computes the analytical Fourier coefficients/frequencies of a Pauli-Clifford
245 circuit. The symbolic structure of the tree (which Pauli rotations
246 contribute sine/cosine factors to which leaf, and the leaf observables) is
247 built once in NumPy; the parameter-dependent coefficients are then obtained
248 with a small number of vectorised JAX operations, so the result remains
249 jittable / differentiable with respect to the model parameters.
251 The resulting spectrum is the d-dimensional set of frequency vectors,
252 where $d$ is the input dimensionality.
254 **Usage**:
255 ```
256 model = Model(...)
257 tree = FourierTree(model)
258 exp = tree() # expectation value
259 coeff_list, freq_list = tree.get_spectrum()
260 ```
261 """
def __init__(self, model: Model):
    """
    Set up the tree from the Pauli-Clifford representation of a model.

    Only the cheap symbolic pre-processing is done here; the explicit leaf
    structure is created on first use.

    Args:
        model (Model): The Model, for which to build the tree.
    """
    self.model = model
    n_qubits = model.n_qubits
    self.n_qubits = n_qubits

    # The tree describes exactly one parameter set, so de-batch up front.
    self._params = self._single_param_set(model.params)

    # Record the canonical circuit once at a fixed base input. The value only
    # sets rotation angles, not which Pauli words appear.
    base_inputs = np.ones(model.n_input_feat)
    operations, observables = self._build_canonical_tape(self._params, base_inputs)

    angles = PauliCircuit.get_parameters(operations)
    self.parameters = [jnp.squeeze(angle) for angle in angles]
    self.n_params = len(self.parameters)

    # Symbolic Pauli generator of every canonical rotation.
    self.pauli_words: List[PauliWord] = [
        PauliWord.from_operation(operation, n_qubits) for operation in operations
    ]

    # Prefix X/Y supports: entry k marks every qubit carrying an X/Y in any
    # generator of rotations[0..k] (used for light-cone early stopping).
    prefix_supports: List[np.ndarray] = []
    support = np.zeros(n_qubits, dtype=bool)
    for word in self.pauli_words:
        support = np.logical_or(support, word.xy_mask)
        prefix_supports.append(support.copy())
    self.cumulative_xy: List[np.ndarray] = prefix_supports

    # One tree root per observable.
    self.observable_words: List[PauliWord] = [
        PauliWord.from_operation(observable, n_qubits) for observable in observables
    ]

    # Locate the input-encoding columns; this sets ``input_indices``,
    # ``all_input_indices``, ``input_scaling``, ``var_positions`` and
    # ``features``.
    self._detect_inputs(base_inputs)

    # Leaf arrays and spectrum structure are built lazily because the number
    # of tree paths can explode for deep circuits.
    self._structure_built = False
def _ensure_structure(self) -> None:
    """Create the explicit leaf and spectrum structure once, on demand."""
    if self._structure_built:
        return
    # First the per-root (S, C, terms) leaf arrays, then the
    # parameter-independent frequency/weight structure derived from them.
    self._build_leaf_arrays()
    self._build_spectrum_structure()
    self._structure_built = True
def _single_param_set(self, params) -> jnp.ndarray:
    """Reduce possibly batched model parameters to a single parameter set.

    Parameters with more than two dimensions and a leading axis longer than
    one are treated as a batch: a ``UserWarning`` is issued and the first
    entry is used. Anything else is returned as a JAX array unchanged.
    """
    as_array = jnp.asarray(params)
    is_batched = as_array.ndim > 2 and as_array.shape[0] > 1
    if not is_batched:
        return as_array

    warnings.warn(
        f"FourierTree supports a single parameter set; using the first "
        f"of {as_array.shape[0]} batched parameter sets.",
        UserWarning,
    )
    return as_array[0]
def _build_canonical_tape(self, params, inputs):
    """Record the circuit and bring it into Pauli-Clifford normal form.

    Args:
        params: Model parameters; batched parameters are reduced to one set.
        inputs: Circuit inputs, validated by the model before recording.

    Returns:
        The ``(operations, observables)`` of the canonical circuit
        (see :meth:`PauliCircuit.from_parameterised_circuit`).
    """
    single_params = self._single_param_set(params)
    checked_inputs = self.model._inputs_validation(inputs)
    recorded = self.model.script._record(params=single_params, inputs=checked_inputs)
    # Only the observable list of the model is needed here.
    observables = self.model._build_obs()[1]
    return PauliCircuit.from_parameterised_circuit(
        recorded, observables=observables, n_qubits=self.n_qubits
    )
def _canonical_parameters(self, inputs) -> np.ndarray:
    """Return the canonical rotation angles for ``inputs`` as a 1-D float array.

    The circuit is re-recorded with the tree's own parameter set.
    """
    operations, _ = self._build_canonical_tape(self._params, inputs)
    angles = []
    for angle in PauliCircuit.get_parameters(operations):
        angles.append(float(jnp.squeeze(angle)))
    return np.array(angles)
def _detect_inputs(self, base_inputs: np.ndarray) -> None:
    r"""Find the input-encoding columns by probing the recorded angles.

    Every canonical rotation angle is treated as an affine function of the
    inputs. Each feature is shifted by one in turn and the recorded angles
    are differenced against the base angles, which yields the linear
    response :math:`\omega_k` of every column to that feature.

    Sets :attr:`input_indices` (``{feature: [columns]}``),
    :attr:`all_input_indices`, :attr:`input_scaling` (per column, ``1`` for
    variational columns), :attr:`var_positions`, and :attr:`features`.

    Args:
        base_inputs (np.ndarray): Inputs at which ``self.parameters`` were
            recorded.

    Raises:
        NotImplementedError: If a rotation depends on more than one feature
            (the tree requires single-feature encodings).
    """
    tol = 1e-6
    n_feat = self.model.n_input_feat
    base = np.asarray(base_inputs, dtype=float)
    base_angles = np.array([float(p) for p in self.parameters])

    # response[f, k]: change of angle k when feature f is increased by one.
    response = np.zeros((n_feat, self.n_params))
    for feat in range(n_feat):
        probe = base.copy()
        probe[feat] += 1.0
        response[feat] = self._canonical_parameters(probe) - base_angles

    depends = np.abs(response) > tol
    by_feature: Dict[int, list] = defaultdict(list)
    encoding_columns: List[int] = []
    scaling = np.ones(self.n_params, dtype=np.int64)

    for col in range(self.n_params):
        feats = np.flatnonzero(depends[:, col])
        if feats.size > 1:
            raise NotImplementedError(
                f"Rotation {col} depends on multiple input features "
                f"{feats.tolist()}; the Fourier tree requires each encoding "
                "rotation to be linear in a single feature."
            )
        if feats.size == 0:
            # No response to any feature: a variational column.
            continue

        feature = int(feats[0])
        omega = float(response[feature, col])
        rounded = int(round(omega))
        if abs(omega - rounded) > tol:
            warnings.warn(
                f"Non-integer input scaling {omega:.4f} on rotation {col} "
                f"(feature {feature}); rounding to {rounded}. The Fourier tree supports "
                "integer frequency scalings only.",
                UserWarning,
            )
        by_feature[feature].append(col)
        encoding_columns.append(col)
        scaling[col] = rounded

    self.input_indices = by_feature
    self.all_input_indices = encoding_columns
    self.input_scaling = scaling

    encoding_set = set(encoding_columns)
    variational = [col for col in range(self.n_params) if col not in encoding_set]
    self.var_positions = np.array(variational, dtype=np.int64)
    # Ordered list of input feature keys (d-dimensional spectrum).
    self.features = sorted(by_feature.keys())
432 # Symbolic tree construction (NumPy)
def _build_leaf_arrays(self) -> None:
    """Enumerate the leaves of every root and store them as count matrices.

    ``self.leaf_arrays`` receives one ``(S, C, terms)`` tuple per observable:
    - ``S``: (n_leaves, n_params) sine-factor counts per parameter,
    - ``C``: (n_leaves, n_params) cosine-factor counts per parameter,
    - ``terms``: (n_leaves,) complex leaf constants ``<0|O_leaf|0>``.
    A root without leaves gets correspondingly shaped empty arrays.
    """
    n_params = self.n_params
    collected: List[Tuple[np.ndarray, np.ndarray, np.ndarray]] = []

    for root in self.observable_words:
        leaves: List[Tuple[np.ndarray, np.ndarray, complex]] = []
        self._collect_leaves(
            root,
            n_params - 1,
            np.zeros(n_params, dtype=np.int64),
            np.zeros(n_params, dtype=np.int64),
            leaves,
        )

        if not leaves:
            S = np.zeros((0, n_params), dtype=np.int64)
            C = np.zeros((0, n_params), dtype=np.int64)
            terms = np.zeros(0, dtype=np.complex128)
        else:
            sin_rows, cos_rows, constants = zip(*leaves)
            S = np.stack(sin_rows)
            C = np.stack(cos_rows)
            terms = np.array(constants, dtype=np.complex128)
        collected.append((S, C, terms))

    self.leaf_arrays: List[Tuple[np.ndarray, np.ndarray, np.ndarray]] = collected
def _collect_leaves(
    self,
    observable: PauliWord,
    pauli_idx: int,
    sin_counts: np.ndarray,
    cos_counts: np.ndarray,
    leaves: List[Tuple[np.ndarray, np.ndarray, complex]],
) -> None:
    """Depth-first enumeration of the coefficient-tree leaves below a node.

    ``sin_counts``/``cos_counts`` already contain the factor contributed by
    the edge leading to this node. Leaves with a non-zero constant are
    appended to ``leaves`` as ``(sin_counts, cos_counts, term)``.

    Args:
        observable (PauliWord): Observable at this node.
        pauli_idx (int): Index of the last rotation still to be processed.
        sin_counts (np.ndarray): Sine-factor counts accumulated so far.
        cos_counts (np.ndarray): Cosine-factor counts accumulated so far.
        leaves: Output list, extended in place.
    """
    if self._early_stopping_possible(pauli_idx, observable):
        return

    # Walk past trailing rotations that commute with the observable.
    idx = pauli_idx
    while idx >= 0 and observable.commutes_with(self.pauli_words[idx]):
        idx -= 1

    if idx < 0:
        # All rotations consumed: this node is a leaf.
        term = observable.zero_expectation()
        if term != 0:
            leaves.append((sin_counts, cos_counts, term))
        return

    generator = self.pauli_words[idx]

    # Cosine branch: the observable is kept as is.
    cos_branch = cos_counts.copy()
    cos_branch[idx] += 1
    self._collect_leaves(observable, idx - 1, sin_counts.copy(), cos_branch, leaves)

    # Sine branch: the observable becomes P . O.
    sin_branch = sin_counts.copy()
    sin_branch[idx] += 1
    self._collect_leaves(
        generator.compose(observable),
        idx - 1,
        sin_branch,
        cos_counts.copy(),
        leaves,
    )
def _early_stopping_possible(self, pauli_idx: int, observable: PauliWord) -> bool:
    """Whether a node can be discarded (all reachable expectations vanish).

    Mirrors the criterion of Nemkov et al. (light cone): a qubit on which
    the observable carries an X/Y must be covered by an X/Y generator of
    some remaining rotation (rotations[0..pauli_idx]); otherwise that X/Y can
    never be rotated into a diagonal term and the whole node contributes
    zero. Equivalently, the node survives iff every qubit is either I/Z in
    the observable or covered by the cumulative rotation X/Y support.

    Args:
        pauli_idx (int): Index of the last remaining rotation; a negative
            value means no rotation is left.
        observable (PauliWord): Observable at the node.

    Returns:
        bool: True if the node can be dropped.
    """
    obs_iz = np.logical_not(observable.xy_mask)
    if pauli_idx < 0:
        # No rotation left, so nothing is covered. Indexing with -1 would
        # wrap around to the full support, or raise on a rotation-free
        # circuit.
        covered = obs_iz
    else:
        covered = np.logical_or(obs_iz, self.cumulative_xy[pauli_idx])
    return not bool(np.all(covered))
522 # Frequency / weight structure (NumPy, parameter independent)
def _build_spectrum_structure(self) -> None:
    """Derive, per root, the frequency vectors and the weight matrix ``W``.

    ``W`` has shape (n_freq, n_leaves) and satisfies
    ``coeffs = W @ (terms * variational)``. The results are stored in
    ``self.freqs_per_root`` and ``self.weights_per_root``; frequencies are
    collapsed to a 1-D array when there is a single frequency axis.
    """
    n_axes = len(self.features)
    # Key of the constant term; a model without input features still gets a
    # single frequency axis.
    zero_key = (0,) * max(n_axes, 1)

    root_freqs: List[np.ndarray] = []
    root_weights: List[np.ndarray] = []

    for S, C, _ in self.leaf_arrays:
        n_leaves = S.shape[0]
        columns: Dict[tuple, np.ndarray] = {}

        for leaf in range(n_leaves):
            # One expansion per active input column. Columns of the same
            # feature may carry different integer scalings, so each one is
            # expanded separately and the expansions are convolved.
            expansions: List[List[Tuple[int, int, float]]] = []
            n_trig = 0
            for axis, feat in enumerate(self.features):
                for col in self.input_indices[feat]:
                    s = int(S[leaf, col])
                    c = int(C[leaf, col])
                    if s == 0 and c == 0:
                        continue
                    n_trig += s + c
                    scale = int(self.input_scaling[col])
                    expansions.append(
                        [
                            (axis, int(o) * scale, wt)
                            for o, wt in self._binomial_terms(s, c)
                        ]
                    )
            prefactor = 0.5**n_trig

            if not expansions:
                # No input dependence: the leaf only feeds the constant term.
                contributions = [(zero_key, prefactor)]
            else:
                contributions = []
                for combo in itertools.product(*expansions):
                    omega = [0] * n_axes
                    weight = prefactor
                    for axis, offset, wt in combo:
                        omega[axis] += offset
                        weight *= wt
                    contributions.append((tuple(omega), weight))

            for key, weight in contributions:
                column = columns.get(key)
                if column is None:
                    column = np.zeros(n_leaves, dtype=np.complex128)
                    columns[key] = column
                column[leaf] += weight

        if columns:
            ordered = sorted(columns.keys())
            W = np.stack([columns[key] for key in ordered])  # (n_freq, n_leaves)
            freqs = np.array(ordered, dtype=np.int64)  # (n_freq, d)
        else:
            freqs = np.zeros((1, max(n_axes, 1)), dtype=np.int64)
            W = np.zeros((1, n_leaves), dtype=np.complex128)

        # Collapse to 1-D frequency array for the single-feature case.
        if freqs.shape[1] == 1:
            freqs = freqs[:, 0]
        root_freqs.append(freqs)
        # W stays a NumPy complex128 array so that later zero-tests on its
        # entries are done in NumPy.
        root_weights.append(W)

    self.freqs_per_root: List[np.ndarray] = root_freqs
    self.weights_per_root: List[np.ndarray] = root_weights
595 @staticmethod
596 def _binomial_terms(s: int, c: int) -> List[Tuple[int, float]]:
597 """Expand ``cos^c (i sin)^s`` in ``e^{i omega x}`` (without the 0.5 factor).
599 Returns a list of ``(omega, weight)`` with
600 ``omega = 2a + 2b - s - c`` and ``weight = C(s,a) C(c,b) (-1)^{s-a}``.
601 """
602 terms = []
603 for a in range(s + 1):
604 for b in range(c + 1):
605 weight = math.comb(s, a) * math.comb(c, b) * (-1) ** (s - a)
606 terms.append((2 * a + 2 * b - s - c, float(weight)))
607 return terms
609 # Vectorised numeric evaluation (JAX)
610 @staticmethod
611 def _safe_pow(base: jnp.ndarray, exp: jnp.ndarray) -> jnp.ndarray:
612 """Elementwise ``base ** exp`` for real base and non-negative integer
613 exponents, correct for negative bases (avoids ``log`` of negatives).
615 Args:
616 base: real array of shape ``(n,)``.
617 exp: integer array of shape ``(n_leaves, n)``.
618 """
619 mag = jnp.abs(base)[None, :] ** exp
620 sign = jnp.where(exp % 2 == 0, 1.0, jnp.sign(base)[None, :])
621 return sign * mag
623 _I_POW = None # set lazily to jnp.array([1, 1j, -1, -1j])
625 def _leaf_factors(
626 self, S: np.ndarray, C: np.ndarray, columns: np.ndarray
627 ) -> jnp.ndarray:
628 """Per-leaf product ``prod_i cos(theta_i)^{C} (i sin(theta_i))^{S}`` over
629 the given parameter ``columns`` (vectorised over leaves).
630 """
631 if FourierTree._I_POW is None:
632 FourierTree._I_POW = jnp.array([1, 1j, -1, -1j])
634 if S.shape[0] == 0:
635 return jnp.zeros(0, dtype=jnp.complex64)
637 theta = jnp.stack([self.parameters[i] for i in columns])
638 S_sub = jnp.asarray(S[:, columns])
639 C_sub = jnp.asarray(C[:, columns])
641 cos_part = self._safe_pow(jnp.cos(theta), C_sub)
642 sin_mag = self._safe_pow(jnp.sin(theta), S_sub)
643 i_part = FourierTree._I_POW[S_sub % 4]
644 return jnp.prod(cos_part * sin_mag * i_part, axis=1)
def __call__(
    self,
    params: Optional[jnp.ndarray] = None,
    inputs: Optional[jnp.ndarray] = None,
    **kwargs,
) -> jnp.ndarray:
    """
    Compute the expectation value(s) of the model's observables from the
    sine-cosine tree.

    Calling the tree replaces ``self.parameters`` with the canonical angles
    recorded for the given ``params`` and ``inputs``.

    Args:
        params (Optional[jnp.ndarray]): Model parameters. Defaults to the
            model's parameters.
        inputs (Optional[jnp.ndarray]): Inputs to the circuit. Defaults to 1.

    Returns:
        jnp.ndarray: Expectation value per observable (or their mean if
        ``force_mean`` is set).

    Raises:
        NotImplementedError: For execution types other than "expval" or when
            noise is requested.
    """
    if params is None:
        params = self.model.params
    else:
        params = self.model._params_validation(params)
    inputs = self.model._inputs_validation(1.0 if inputs is None else inputs)

    execution_type = kwargs.get("execution_type", "expval")
    if execution_type != "expval":
        raise NotImplementedError(
            f'Currently, only "expval" execution type is supported when '
            f"building FourierTree. Got {execution_type}."
        )
    if kwargs.get("noise_params", None) is not None:
        raise NotImplementedError(
            "Currently, noise is not supported when building FourierTree."
        )

    # Only the angle values depend on params/inputs; the leaf arrays are
    # reused as they are.
    operations, _ = self._build_canonical_tape(params, inputs)
    angles = PauliCircuit.get_parameters(operations)
    self.parameters = [jnp.squeeze(angle) for angle in angles]

    self._ensure_structure()
    every_column = np.arange(self.n_params, dtype=np.int64)
    expectations = jnp.array(
        [
            jnp.real(
                jnp.sum(jnp.asarray(terms) * self._leaf_factors(S, C, every_column))
            )
            for S, C, terms in self.leaf_arrays
        ]
    )

    if kwargs.get("force_mean", False):
        return jnp.mean(expectations)
    return expectations
def get_spectrum(
    self, force_mean: bool = False
) -> Tuple[List[jnp.ndarray], List[jnp.ndarray]]:
    """
    Return the Fourier coefficients and frequencies described by the tree.

    Per root, the leaf constants are multiplied with the leaf factors of the
    variational columns and mapped to coefficients by the root's weight
    matrix.

    Args:
        force_mean (bool, optional): Average the coefficients over all
            observables (roots). Defaults to False.

    Returns:
        Tuple[List[jnp.ndarray], List[jnp.ndarray]]:
            - List of coefficients, one entry per observable (root).
            - List of corresponding frequencies, one entry per root.
            When ``force_mean`` is set, both lists have a single entry.
    """
    self._ensure_structure()

    def root_coefficients(leaf_data, weights):
        S, C, terms = leaf_data
        variational = self._leaf_factors(S, C, self.var_positions)
        return jnp.asarray(weights) @ (jnp.asarray(terms) * variational)

    per_root_coeffs: List[jnp.ndarray] = [
        root_coefficients(leaf_data, weights)
        for leaf_data, weights in zip(self.leaf_arrays, self.weights_per_root)
    ]
    return self._combine_roots(per_root_coeffs, self.freqs_per_root, force_mean)
def get_exact_support(self, method: str = "tree") -> List[np.ndarray]:
    r"""Symbolically derive the exact frequency support (no sampling).

    A frequency :math:`\omega` belongs to the exact spectrum iff its
    coefficient :math:`c_\omega(\theta) = \sum_l W_{\omega l}\,
    \text{term}_l\, v_l(\theta)` is not identically zero in the
    variational parameters :math:`\theta`.

    Two methods are available:

    - ``"tree"`` (default, fully exact): enumerates the explicit tree
      leaves. Because the branch index strictly decreases along every tree
      path, each parameter contributes **at most one** sine *or* cosine
      factor per leaf (:math:`S_{li}, C_{li} \in \{0, 1\}`). Every
      variational leaf factor :math:`v_l` is therefore a *square-free*
      monomial over :math:`\{1, \cos\theta_i, i\sin\theta_i\}`, and
      monomials with distinct signatures are linearly independent functions
      (no :math:`\cos^2 + \sin^2` identities can arise without squares).
      Hence

      .. math::
          c_\omega \equiv 0 \iff \sum_{l \in g} W_{\omega l}\,\text{term}_l
          = 0 \quad \text{for every signature group } g.

      Since all involved quantities are dyadic rationals times
      :math:`\{\pm 1, \pm i\}`, the group sums are exact in float64 and the
      zero-test is exact. The number of leaves can however grow
      exponentially with circuit depth.

    - ``"dp"`` (scalable): merges tree nodes with identical
      ``(rotation index, observable)`` — at most ``n_params * 4^n_qubits``
      states — and tracks the achievable input sine/cosine count pairs
      ``(s, c)`` per state. The support is the union of the (exact)
      expansion supports of :math:`\cos^c x\, (i \sin x)^s` over all
      achievable pairs. This is exact per tree path (including interior
      zero coefficients of the expansions), but unlike ``"tree"`` it cannot
      detect coefficients that cancel identically *across* paths with
      identical variational signatures (e.g. directly repeated encodings).
      It therefore yields a tight superset in such corner cases.
      Currently restricted to a single input feature.

    Args:
        method (str): ``"tree"`` (fully exact) or ``"dp"`` (scalable).

    Returns:
        List[np.ndarray]: For each observable (root), the frequency vectors
        with not-identically-zero coefficient — shape ``(n_freq,)`` for a
        single input feature, ``(n_freq, n_features)`` otherwise.

    Raises:
        ValueError: If ``method`` is neither ``"tree"`` nor ``"dp"``.
    """
    if method == "dp":
        return self._support_dp()
    if method != "tree":
        raise ValueError(f"Unknown method '{method}'. Use 'tree' or 'dp'.")

    self._ensure_structure()
    supports = []
    for (S, C, terms), W, freqs in zip(
        self.leaf_arrays, self.weights_per_root, self.freqs_per_root
    ):
        freqs = np.asarray(freqs)
        n_leaves = S.shape[0]
        if n_leaves == 0:
            supports.append(freqs[:0])
            continue

        # Group leaves by their variational sine/cosine signature.
        signature = np.hstack([S[:, self.var_positions], C[:, self.var_positions]])
        _, groups = np.unique(signature, axis=0, return_inverse=True)
        # NumPy 2.0.0 returned a 2-D inverse when ``axis`` is given; flatten
        # so that the accumulation below always sees one group id per leaf.
        groups = np.asarray(groups).reshape(-1)
        n_groups = int(groups.max()) + 1

        # Per-group sums of W[omega, l] * term_l.
        contrib = (W * terms[None, :]).T  # (n_leaves, n_freq)
        group_sums = np.zeros((n_groups, W.shape[0]), dtype=np.complex128)
        np.add.at(group_sums, groups, contrib)

        # A frequency survives if any signature group has a non-zero sum.
        mask = (np.abs(group_sums) > 1e-12).any(axis=0)  # (n_freq,)
        supports.append(freqs[mask])
    return supports
def _support_dp(self) -> List[np.ndarray]:
    """Frequency support via a dynamic program over merged tree nodes.

    Nodes are merged on ``(rotation index, observable bits)`` and memoised.
    Each state is an integer bitmask of the achievable input ``(s, c)``
    count pairs, where the pair ``(s, c)`` occupies bit ``s * stride + c``.
    See :meth:`get_exact_support` for semantics and limitations.

    Returns:
        List[np.ndarray]: Sorted int64 frequencies per observable.

    Raises:
        NotImplementedError: For anything but exactly one input feature, or
            for non-unit input frequency scaling.
    """
    if len(self.features) != 1:
        raise NotImplementedError(
            "The 'dp' support method currently supports exactly one input "
            "feature; use method='tree' for multi-feature models."
        )

    if self.all_input_indices and np.any(
        self.input_scaling[self.all_input_indices] != 1
    ):
        raise NotImplementedError(
            "The 'dp' support method does not support non-unit input "
            "frequency scaling (it aggregates sin/cos counts and cannot "
            "represent per-gate scalings); use method='tree'."
        )

    n_qubits = self.n_qubits
    input_flags = np.zeros(self.n_params, dtype=bool)
    input_flags[self.all_input_indices] = True
    # c never exceeds the number of input columns, so this stride keeps the
    # (s, c) pairs on distinct bits.
    stride = int(input_flags.sum()) + 1

    def to_bits(word: PauliWord) -> Tuple[int, int]:
        # Pack the per-qubit x/z flags of a word into two integers.
        x_bits = 0
        z_bits = 0
        for q in range(n_qubits):
            x_bits |= int(word.x[q]) << q
            z_bits |= int(word.z[q]) << q
        return x_bits, z_bits

    rotations = [to_bits(word) for word in self.pauli_words]
    # covered[k]: union of the x bits of rotations[0..k].
    covered = list(
        itertools.accumulate((x_bits for x_bits, _ in rotations), lambda a, b: a | b)
    )

    def odd_parity(value: int) -> int:
        return bin(value).count("1") & 1

    def reachable(idx: int, xo: int, zo: int, memo: dict) -> int:
        # Light-cone early stopping (cf. _early_stopping_possible).
        if idx >= 0 and (xo & ~covered[idx]):
            return 0
        # Skip trailing rotations that commute with the observable.
        while idx >= 0:
            xp, zp = rotations[idx]
            if odd_parity(xo & zp) != odd_parity(zo & xp):
                break
            idx -= 1
        if idx < 0:
            # Leaf: contributes the pair (s=0, c=0) iff no x bit is set.
            return 1 if xo == 0 else 0

        key = (idx, xo, zo)
        cached = memo.get(key)
        if cached is not None:
            return cached

        xp, zp = rotations[idx]
        keep = reachable(idx - 1, xo, zo, memo)
        flip = reachable(idx - 1, xo ^ xp, zo ^ zp, memo)
        if input_flags[idx]:
            # Input gate: the cosine branch increments c, the sine branch s.
            result = (keep << 1) | (flip << stride)
        else:
            result = keep | flip
        memo[key] = result
        return result

    # Recursion depth is bounded by the number of rotations.
    saved_limit = sys.getrecursionlimit()
    sys.setrecursionlimit(max(saved_limit, self.n_params + 1000))
    try:
        supports = []
        for observable in self.observable_words:
            xo, zo = to_bits(observable)
            pairs = reachable(self.n_params - 1, xo, zo, {})
            found: set = set()
            position = 0
            # Walk over the set bits and decode each back into (s, c).
            while pairs:
                if pairs & 1:
                    found |= self._expansion_support(
                        position // stride, position % stride
                    )
                pairs >>= 1
                position += 1
            supports.append(np.array(sorted(found), dtype=np.int64))
    finally:
        sys.setrecursionlimit(saved_limit)
    return supports
908 @staticmethod
909 @lru_cache(maxsize=None)
910 def _expansion_support(s: int, c: int) -> frozenset:
911 r"""Frequencies with non-zero coefficient in :math:`\cos^c x (i\sin x)^s`.
913 Computed exactly with integer arithmetic via the polynomial
914 :math:`(t - 1)^s (t + 1)^c` (with :math:`t = e^{2ix}` up to a shift);
915 interior coefficients can vanish, e.g. :math:`\cos x \sin x` only
916 contains :math:`\pm 2`.
917 """
918 coeffs = [1]
919 for _ in range(s): # multiply by (t - 1)
920 new = [0] * (len(coeffs) + 1)
921 for i, a in enumerate(coeffs):
922 new[i + 1] += a
923 new[i] -= a
924 coeffs = new
925 for _ in range(c): # multiply by (t + 1)
926 new = [0] * (len(coeffs) + 1)
927 for i, a in enumerate(coeffs):
928 new[i + 1] += a
929 new[i] += a
930 coeffs = new
931 m = s + c
932 return frozenset(2 * k - m for k, a in enumerate(coeffs) if a != 0)
934 def _combine_roots(
935 self,
936 per_root_coeffs: List[jnp.ndarray],
937 per_root_freqs: List[np.ndarray],
938 force_mean: bool,
939 ) -> Tuple[List[jnp.ndarray], List[jnp.ndarray]]:
940 """Assemble the per-root spectra, optionally averaging over roots."""
941 if not force_mean:
942 coefficients = [jnp.asarray(c) for c in per_root_coeffs]
943 frequencies = [jnp.asarray(f) for f in per_root_freqs]
944 return coefficients, frequencies
946 # Average over roots on the union of all frequency vectors.
947 accum: Dict[tuple, complex] = defaultdict(complex)
948 for coeffs, freqs in zip(per_root_coeffs, per_root_freqs):
949 freqs_np = np.asarray(freqs)
950 for k in range(freqs_np.shape[0]):
951 key = (
952 (int(freqs_np[k]),)
953 if freqs_np.ndim == 1
954 else tuple(int(v) for v in freqs_np[k])
955 )
956 accum[key] += complex(coeffs[k])
957 n_roots = max(len(per_root_coeffs), 1)
958 keys = sorted(accum.keys())
959 mean_coeffs = jnp.array([accum[k] / n_roots for k in keys])
960 freq_arr = np.array(keys, dtype=np.int64)
961 if freq_arr.shape[1] == 1:
962 freq_arr = freq_arr[:, 0]
963 return [mean_coeffs], [jnp.asarray(freq_arr)]
class FCC:
    """
    Helpers for computing the Fourier fingerprint (the correlation matrix of
    a model's Fourier coefficients over parameter samples) and the FCC, the
    mean absolute value of that fingerprint.
    """

    @classmethod
    def get_fcc(
        cls,
        model: Model,
        n_samples: int,
        random_key: Optional[random.PRNGKey] = None,
        method: Optional[str] = "pearson",
        scale: Optional[bool] = False,
        weight: Optional[bool] = False,
        trim_redundant: Optional[bool] = True,
        **kwargs,
    ) -> float:
        """
        Shortcut method to get just the FCC.
        This includes
        1. What is done in `get_fourier_fingerprint`:
            1. Calculating the coefficients (using `n_samples`)
            2. Correlating the result from 1) using `method`
            3. Weighting the correlation matrix (if `weight` is True)
            4. Remove redundancies
        2. What is done in `calculate_fcc`:
            1. Absolute of the fingerprint
            2. Average

        Args:
            model (Model): The QFM model
            n_samples (int): Number of samples to calculate average of coefficients
            random_key (Optional[random.PRNGKey]): JAX random key for parameter
                initialization. If None, uses the model's internal random key.
            method (Optional[str], optional): Correlation method. Supported values are
                "pearson", "complex_pearson", "spearman", and "covariance".
                Defaults to "pearson".
            scale (Optional[bool], optional): Whether to scale the number of samples.
                Defaults to False.
            weight (Optional[bool], optional): Whether to weight the correlation matrix.
                Defaults to False.
            trim_redundant (Optional[bool], optional): Whether to remove redundant
                correlations. Defaults to True.
            **kwargs (Any): Additional keyword arguments for the model function.

        Returns:
            float: The FCC
        """
        # Memory-efficient fast path: only the non-negative-frequency
        # coefficients are correlated, and the mean over the strict lower
        # triangle is obtained from full-matrix and diagonal sums instead of
        # materialising a masked copy.
        if trim_redundant and not weight:
            _, coeffs, freqs = cls._calculate_coefficients(
                model, n_samples, random_key, scale, **kwargs
            )
            pos_idx = cls._calculate_mask(freqs)
            coeffs_flat = coeffs.reshape(-1, coeffs.shape[-1])
            coeffs_sub = coeffs_flat[pos_idx]

            fp = cls._correlate(coeffs_sub.transpose(), method=method)
            abs_fp = jnp.abs(fp)
            diag = jnp.abs(jnp.diagonal(fp))

            total_sum = jnp.nansum(abs_fp)
            total_count = jnp.sum(jnp.isfinite(abs_fp))
            diag_sum = jnp.nansum(diag)
            diag_count = jnp.sum(jnp.isfinite(diag))

            # (full matrix - diagonal) / 2 == strict lower triangle, since
            # the absolute correlation matrix is symmetric.
            lower_sum = (total_sum - diag_sum) / 2.0
            lower_count = (total_count - diag_count) / 2.0
            return lower_sum / lower_count

        fourier_fingerprint, _ = cls.get_fourier_fingerprint(
            model,
            n_samples,
            random_key,
            method,
            scale,
            weight,
            trim_redundant=trim_redundant,
            **kwargs,
        )

        return cls.calculate_fcc(fourier_fingerprint)

    @classmethod
    def get_fourier_fingerprint(
        cls,
        model: Model,
        n_samples: int,
        random_key: Optional[random.PRNGKey] = None,
        method: Optional[str] = "pearson",
        scale: Optional[bool] = False,
        weight: Optional[bool] = False,
        trim_redundant: Optional[bool] = True,
        nan_to_one: Optional[bool] = False,
        **kwargs: Any,
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """
        Shortcut method to get just the fourier fingerprint.
        This includes
        1. Calculating the coefficients (using `n_samples`)
        2. Correlating the result from 1) using `method`
        3. Weighting the correlation matrix (if `weight` is True)
        4. Remove redundancies (if `trim_redundant` is True)

        Args:
            model (Model): The QFM model
            n_samples (int): Number of samples to calculate average of coefficients
            random_key (Optional[random.PRNGKey]): JAX random key for parameter
                initialization. If None, uses the model's internal random key.
            method (Optional[str], optional): Correlation method. Supported values are
                "pearson", "complex_pearson", "spearman", and "covariance".
                Defaults to "pearson".
            scale (Optional[bool], optional): Whether to scale the number of samples.
                Defaults to False.
            weight (Optional[bool], optional): Whether to weight the correlation matrix.
                Defaults to False.
            trim_redundant (Optional[bool], optional): Whether to remove redundant
                correlations. Defaults to True.
            nan_to_one (Optional[bool], optional): Whether to set nan to 1.
                Defaults to False.
            **kwargs: Additional keyword arguments for the model function.

        Returns:
            Tuple[jnp.ndarray, jnp.ndarray]: The fourier fingerprint and the
            corresponding frequency indices. If `trim_redundant` is True the
            frequencies are returned as a `(row_freqs, col_freqs)` tuple that
            labels the two (redundancy-trimmed) matrix axes; otherwise the
            full frequency vector is returned.
        """
        _, coeffs, freqs = cls._calculate_coefficients(
            model, n_samples, random_key, scale, **kwargs
        )

        # Memory-efficient fast path: correlate only the non-negative
        # frequencies instead of building the full N x N matrix first.
        if trim_redundant and not weight:
            pos_idx = cls._calculate_mask(freqs)
            pos_freqs = cls._flat_frequencies(freqs)[pos_idx]

            # Flatten all frequency axes; the last axis is the sample
            # axis. `_calculate_mask` returns flat indices in C order,
            # matching this reshape.
            coeffs_flat = coeffs.reshape(-1, coeffs.shape[-1])
            coeffs_sub = coeffs_flat[pos_idx]

            fourier_fingerprint = cls._correlate(
                coeffs_sub.transpose(), method=method
            )

            if nan_to_one:
                fourier_fingerprint = jnp.where(
                    jnp.isnan(fourier_fingerprint), 1.0, fourier_fingerprint
                )

            return cls._keep_strict_lower_triangle(fourier_fingerprint, pos_freqs)

        # NOTE(review): for multi-feature coefficients `coeffs.transpose()`
        # reverses all axes, so `_correlate` flattens the frequency axes in
        # reversed order, while `_calculate_mask` / `_flat_frequencies`
        # below use C order. Confirm the row labels are as intended when
        # the per-axis frequency vectors differ.
        fourier_fingerprint = cls._correlate(coeffs.transpose(), method=method)

        if nan_to_one:
            # JAX arrays are immutable, so NaNs are replaced functionally
            # rather than by in-place item assignment.
            fourier_fingerprint = jnp.where(
                jnp.isnan(fourier_fingerprint), 1.0, fourier_fingerprint
            )

        # perform weighting if requested
        if weight:
            fourier_fingerprint = cls._weighting_mean(fourier_fingerprint, coeffs)

        if trim_redundant:
            pos_idx = cls._calculate_mask(freqs)
            pos_freqs = cls._flat_frequencies(freqs)[pos_idx]

            # Restrict to the non-negative-frequency sub-block (M x M with
            # M = number of non-negative flat frequencies) instead of
            # building a full N x N mask.
            fourier_fingerprint = fourier_fingerprint[pos_idx][:, pos_idx]

            return cls._keep_strict_lower_triangle(fourier_fingerprint, pos_freqs)

        return fourier_fingerprint, freqs

    @classmethod
    def _keep_strict_lower_triangle(
        cls, fingerprint: jnp.ndarray, pos_freqs: jnp.ndarray
    ) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]:
        """
        Keep only the strict lower triangle of a square fingerprint and drop
        rows/columns that contain no finite entry afterwards.

        Args:
            fingerprint (jnp.ndarray): Square correlation matrix of the
                non-negative-frequency coefficients.
            pos_freqs (jnp.ndarray): Frequency labels, one per row/column of
                `fingerprint`.

        Returns:
            Tuple[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]: The trimmed
            fingerprint (entries outside the strict lower triangle are NaN)
            and the `(row_freqs, col_freqs)` labels of its two axes.
        """
        size = fingerprint.shape[0]
        # keep only the strict lower triangle; the rest -> nan
        lower_tri_mask = jnp.tri(size, k=-1, dtype=bool)
        fingerprint = jnp.where(lower_tri_mask, fingerprint, jnp.nan)

        row_mask = jnp.any(jnp.isfinite(fingerprint), axis=1)
        col_mask = jnp.any(jnp.isfinite(fingerprint), axis=0)
        fingerprint = fingerprint[row_mask][:, col_mask]

        return fingerprint, (pos_freqs[row_mask], pos_freqs[col_mask])

    @classmethod
    def calculate_fcc(
        cls,
        fourier_fingerprint: jnp.ndarray,
    ) -> float:
        """
        Method to calculate the FCC based on an existing correlation matrix.
        Calculate absolute and then the average over this matrix.
        The Fingerprint can be obtained via `get_fourier_fingerprint`

        Args:
            fourier_fingerprint (jnp.ndarray): Correlation matrix of coefficients
        Returns:
            float: The FCC
        """
        # NaN entries (e.g. those masked out as redundant) are ignored.
        return jnp.nanmean(jnp.abs(fourier_fingerprint))

    @classmethod
    def _calculate_mask(cls, freqs: jnp.ndarray) -> jnp.ndarray:
        """
        Determine the flat indices of the Fourier correlation matrix
        that lie on a non-negative-frequency row/column. Together with
        the strict-lower-triangle condition (handled by the caller),
        these indices select the entries of the correlation matrix
        that survive the redundancy filter applied in
        `get_fourier_fingerprint`:

        - rows/columns whose flat frequency component is negative are
          discarded (they are the complex-conjugate redundancies of
          their positive counterparts);
        - of the remaining positive-frequency sub-block, only the
          strict lower triangle is kept (the upper triangle, including
          the diagonal, contains either duplicates from symmetry or
          self-correlations).

        Args:
            freqs (jnp.ndarray): Array of frequencies. Either a 1-D
                vector (single input feature) or a 2-D array of shape
                ``(n_input_feat, K)`` whose rows are the per-axis
                frequency vectors.

        Returns:
            jnp.ndarray: 1-D int array of flat indices selecting the
                non-negative-frequency rows/cols of the fingerprint.
        """
        freqs_arr = jnp.asarray(freqs)

        if freqs_arr.ndim == 1:
            pos_flat = freqs_arr >= 0
        else:
            # N-D case: build the per-axis non-negativity masks and
            # combine them via broadcasting (no float `jnp.outer`!),
            # then flatten in row-major (C) order.
            axes_pos = [freqs_arr[i] >= 0 for i in range(freqs_arr.shape[0])]
            expanded = []
            n_axes = len(axes_pos)
            for i, p in enumerate(axes_pos):
                # Place axis i's mask along dimension i only.
                shape = [1] * n_axes
                shape[i] = p.shape[0]
                expanded.append(p.reshape(shape))
            nd_pos = reduce(jnp.logical_and, expanded)
            pos_flat = nd_pos.flatten()

        return jnp.where(pos_flat)[0]

    @classmethod
    def _flat_frequencies(cls, freqs: jnp.ndarray) -> jnp.ndarray:
        """
        Build the per-coefficient flat frequency labels in C order so
        they can be indexed by the flat indices from `_calculate_mask`.

        Args:
            freqs (jnp.ndarray): Either a 1-D vector (single input feature)
                or a ``(n_input_feat, K)`` stack / list of per-axis frequency
                vectors (multi-dim input).

        Returns:
            jnp.ndarray: 1-D frequency vector (single input feature) or a
                ``(N, n_input_feat)`` array of per-coefficient frequency
                tuples (multi-dim input).
        """
        fa = jnp.asarray(freqs)
        if fa.ndim == 1:
            return fa
        # Multi-dim: per-axis vectors -> flat grid of frequency tuples in the
        # same C-order used by `_calculate_mask`.
        grids = jnp.meshgrid(*[fa[i] for i in range(fa.shape[0])], indexing="ij")
        return jnp.stack(grids, axis=-1).reshape(-1, fa.shape[0])

    @classmethod
    def _calculate_coefficients(
        cls,
        model: Model,
        n_samples: int,
        random_key: Optional[random.PRNGKey] = None,
        scale: bool = False,
        **kwargs: Any,
    ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
        """
        Calculates the Fourier coefficients of a given model
        using `n_samples`.
        Optionally, `noise_params` can be passed to perform noisy simulation.

        Args:
            model (Model): The QFM model
            n_samples (int): Number of samples to calculate average of
                coefficients. If not positive, the model parameters are not
                re-initialized and the model's current parameters are used.
            random_key (Optional[random.PRNGKey]): JAX random key for parameter
                initialization. If None, uses the model's internal random key.
            scale (bool, optional): Whether to scale the number of samples
                by `2**n_qubits * n_input_feat`. Defaults to False.
            **kwargs: Additional keyword arguments for the model function.

        Returns:
            Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: The model
            parameters, the coefficients and the frequencies as returned by
            `Coefficients.get_spectrum(model, shift=True, trim=True)`.
        """
        if n_samples > 0:
            if scale:
                total_samples = int(
                    jnp.power(2, model.n_qubits) * n_samples * model.n_input_feat
                )
                log.info(f"Using {total_samples} samples.")
            else:
                total_samples = n_samples
            model.initialize_params(random_key, repeat=total_samples)

        coeffs, freqs = Coefficients.get_spectrum(
            model, shift=True, trim=True, **kwargs
        )

        return model.params, coeffs, freqs

    @classmethod
    def _correlate(cls, mat: jnp.ndarray, method: str = "pearson") -> jnp.ndarray:
        """
        Correlates the columns of `mat` using `method`.
        Currently, `pearson`, `complex_pearson`, `spearman`, and `covariance`
        are supported.

        Args:
            mat (jnp.ndarray): Array whose first axis is the sample axis;
                all remaining axes are flattened into the column axis,
                giving shape (N, K).
            method (str, optional): Correlation method. Defaults to "pearson".

        Raises:
            ValueError: If the method is not supported.

        Returns:
            jnp.ndarray: (K, K) correlation matrix between the columns.
        """
        assert len(mat.shape) >= 2, "Input matrix must have at least 2 dimensions"

        correlators = {
            "pearson": cls._pearson,
            "complex_pearson": cls._complex_pearson,
            "spearman": cls._spearman,
            "covariance": cls._covariance,
        }
        correlator = correlators.get(method) if isinstance(method, str) else None
        if correlator is None:
            raise ValueError(
                f"Unknown correlation method: {method}. Must be 'pearson', "
                "'complex_pearson', 'spearman' or 'covariance'."
            )

        # For the general n-D case everything but the first (sample) axis is
        # flattened. The resulting column order depends on the axis order
        # of `mat` as prepared by the caller.
        return correlator(mat.reshape(mat.shape[0], -1))

    @classmethod
    def _covariance(cls, mat: jnp.ndarray, minp: Optional[int] = 1) -> jnp.ndarray:
        """
        Compute the Hermitian sample covariance between columns of `mat`,
        permitting missing values (NaN or ±Inf).

        For each pair (i, j) the covariance is computed over the rows that are
        finite in both columns, as
            sum(conj(x_i - mean_i) * (x_j - mean_j)) / (nobs - 1).
        Real input collapses to the ordinary real sample covariance; complex
        input yields a complex matrix whose magnitude and angle carry the
        covariance strength and relative phase.

        Args:
            mat : array_like, shape (N, K)
                Input data.
            minp : int, optional
                Minimum number of paired observations required to form a
                covariance. If the number of valid pairs for (i, j) is < minp,
                the result is NaN.

        Returns:
            cov : ndarray, shape (K, K)
                Sample covariance matrix. Pairs with fewer than two valid
                observations are NaN.
        """
        mat = jnp.asarray(mat)
        real_dtype = jnp.asarray(mat.real).dtype

        # Non-finite entries are zeroed so they drop out of every sum.
        mask = jnp.isfinite(mat)
        fmask = mask.astype(real_dtype)
        safe = jnp.where(mask, mat, 0.0)

        # Pairwise number of rows that are valid in both columns.
        nobs = fmask.T @ fmask
        nobs_safe = jnp.where(nobs > 0, nobs, 1.0)

        # Pairwise sums of column i / column j over the mutually valid rows.
        sum_x = safe.T @ fmask
        sum_y = fmask.T @ safe

        masked = safe * fmask
        sum_conj_xy = jnp.conj(masked).T @ masked

        sxy = sum_conj_xy - (jnp.conj(sum_x) * sum_y) / nobs_safe

        denom = jnp.where(nobs > 1, nobs - 1, jnp.nan)
        result = sxy / denom

        result = jnp.where(nobs < minp, jnp.nan, result)

        return result

    @classmethod
    def _complex_pearson(
        cls, mat: jnp.ndarray, minp: Optional[int] = 1
    ) -> jnp.ndarray:
        """
        Compute the complex Pearson correlation between columns of `mat`,
        permitting missing values (NaN or ±Inf).

        This uses the Hermitian normalized covariance
            sum(conj(x_i - mean_i) * (x_j - mean_j)) /
            sqrt(sum(abs(x_i - mean_i)**2) * sum(abs(x_j - mean_j)**2)).
        Consequently, if column j is exp(1j * phi) times column i, then
        abs(corr[i, j]) is 1 and angle(corr[i, j]) is phi.

        Args:
            mat : array_like, shape (N, K)
                Input data.
            minp : int, optional
                Minimum number of paired observations required to form a correlation.
                If the number of valid pairs for (i, j) is < minp, the result is NaN.

        Returns:
            corr : ndarray, shape (K, K)
                Complex Pearson correlation matrix.
        """
        mat = jnp.asarray(mat)
        real_dtype = jnp.asarray(mat.real).dtype

        # Non-finite entries are zeroed so they drop out of every sum.
        mask = jnp.isfinite(mat)
        fmask = mask.astype(real_dtype)
        safe = jnp.where(mask, mat, 0.0)

        nobs = fmask.T @ fmask
        nobs_safe = jnp.where(nobs > 0, nobs, 1.0)

        sum_x = safe.T @ fmask
        sum_y = fmask.T @ safe

        masked = safe * fmask
        sum_conj_xy = jnp.conj(masked).T @ masked

        safe_abs_sq = jnp.abs(safe) ** 2
        sum_abs_x2 = safe_abs_sq.T @ fmask
        sum_abs_y2 = fmask.T @ safe_abs_sq

        # Centered sums of squares and cross products.
        ssx = sum_abs_x2 - jnp.abs(sum_x) ** 2 / nobs_safe
        ssy = sum_abs_y2 - jnp.abs(sum_y) ** 2 / nobs_safe
        sxy = sum_conj_xy - (jnp.conj(sum_x) * sum_y) / nobs_safe

        denom = jnp.sqrt(ssx * ssy)
        result = jnp.where(denom > 0, sxy / denom, jnp.nan)
        # Rescale entries whose magnitude drifted above 1 back onto the unit
        # circle; the phase is preserved.
        magnitude = jnp.abs(result)
        result = jnp.where(magnitude > 1.0, result / magnitude, result)

        result = jnp.where(nobs < minp, jnp.nan, result)

        return result

    @classmethod
    def _pearson(cls, mat: jnp.ndarray, minp: Optional[int] = 1) -> jnp.ndarray:
        """
        Compute Pearson correlation between columns of `mat`,
        permitting missing values (NaN or ±Inf).

        The Pearson correlation is the normalized covariance,
            corr[i, j] = cov[i, j] / sqrt(cov[i, i] * cov[j, j]),
        so it is obtained by normalizing `_covariance` by the per-column
        standard deviations.

        If the input is complex, real and imaginary parts are stacked along
        the sample axis so that both components contribute to the correlation
        without discarding information.

        Args:
            mat : array_like, shape (N, K)
                Input data.
            minp : int, optional
                Minimum number of paired observations required to form a correlation.
                If the number of valid pairs for (i, j) is < minp, the result is NaN.

        Returns:
            corr : ndarray, shape (K, K)
                Pearson correlation matrix.
        """
        # Preserve complex information by splitting into real / imag samples.
        # After stacking the data is real, so the Hermitian `_covariance`
        # reduces to the ordinary real sample covariance.
        if jnp.iscomplexobj(mat):
            mat = jnp.concatenate([mat.real, mat.imag], axis=0)

        cov = cls._covariance(mat, minp=minp)

        # corr[i, j] = cov[i, j] / (std_i * std_j) with std_i = sqrt(cov[i, i])
        std = jnp.sqrt(jnp.diagonal(cov))
        denom = std[:, None] * std[None, :]
        # Zero-variance (constant) columns have no defined correlation.
        result = jnp.where(denom > 0, cov / denom, jnp.nan)

        # clip numerical drift to [-1, 1]
        result = jnp.clip(jnp.real(result), -1.0, 1.0)

        return result

    @classmethod
    def _spearman(cls, mat: jnp.ndarray, minp: Optional[int] = 1) -> jnp.ndarray:
        """
        Based on Pandas correlation method as implemented here:
        https://github.com/pandas-dev/pandas/blob/main/pandas/_libs/algos.pyx

        Compute Spearman correlation between columns of `mat`,
        permitting missing values (NaN or ±Inf).

        If the input is complex, real and imaginary parts are stacked along
        the sample axis so that both components contribute to the correlation
        without discarding information.

        Args:
            mat : array_like, shape (N, K)
                Input data.
            minp : int, optional
                Minimum number of paired observations required to form a correlation.
                If the number of valid pairs for (i, j) is < minp, the result is NaN.

        Returns:
            corr : ndarray, shape (K, K)
                Spearman correlation matrix.
        """
        # Preserve complex information by splitting into real / imag samples
        if jnp.iscomplexobj(mat):
            mat = jnp.concatenate([mat.real, mat.imag], axis=0)

        mat = jnp.asarray(mat)
        N, K = mat.shape

        # trivial all-NaN answer if too few rows
        if N < minp:
            return jnp.full((K, K), jnp.nan)

        # Ranking runs in SciPy on the host, so move the data to NumPy once
        # instead of slicing the device array inside the column loop.
        mat_np = np.asarray(mat)
        finite = np.isfinite(mat_np)  # shape (N, K), dtype=bool

        # precompute ranks column-wise ignoring non-finite entries
        ranks = np.full((N, K), np.nan)
        for j in range(K):
            valid = finite[:, j]
            if valid.any():
                ranks[valid, j] = rankdata(mat_np[valid, j], method="average")

        ranks = jnp.asarray(ranks)

        # Vectorised Pearson on the ranks
        # Replace NaN ranks with 0; use mask to track validity.
        rank_mask = jnp.isfinite(ranks)
        safe_ranks = jnp.where(rank_mask, ranks, 0.0)

        # Pairwise valid-observation counts (K, K)
        fmask = rank_mask.astype(ranks.dtype)
        nobs = fmask.T @ fmask

        # Pairwise sums over mutually-valid rows
        sum_x = safe_ranks.T @ fmask  # (K, K)
        sum_y = fmask.T @ safe_ranks  # (K, K)

        # Pairwise products
        masked_ranks = safe_ranks * fmask  # same as safe_ranks
        sum_xy = masked_ranks.T @ masked_ranks  # (K, K)

        safe_sq = safe_ranks**2
        sum_x2 = safe_sq.T @ fmask  # (K, K)
        sum_y2 = fmask.T @ safe_sq  # (K, K)

        nobs_safe = jnp.where(nobs > 0, nobs, 1.0)
        ssx = sum_x2 - sum_x**2 / nobs_safe
        ssy = sum_y2 - sum_y**2 / nobs_safe
        sxy = sum_xy - (sum_x * sum_y) / nobs_safe

        denom = jnp.sqrt(ssx * ssy)
        result = jnp.where(denom > 0, sxy / denom, jnp.nan)
        result = jnp.clip(result, -1.0, 1.0)

        # Enforce minp
        result = jnp.where(nobs < minp, jnp.nan, result)

        return result

    @classmethod
    def _weighting_linear(cls, fourier_fingerprint: jnp.ndarray) -> jnp.ndarray:
        """
        Performs weighting on the given correlation matrix.
        Here, low-frequent coefficients are weighted more heavily.

        Args:
            fourier_fingerprint (jnp.ndarray): Square correlation matrix
                with odd dimensions.

        Returns:
            jnp.ndarray: The element-wise weighted correlation matrix.
        """
        assert (
            fourier_fingerprint.shape[0] % 2 != 0
            and fourier_fingerprint.shape[1] % 2 != 0
        ), (
            "Correlation matrix must have odd dimensions. "
            "Hint: use `trim` argument when calling `get_spectrum`."
        )
        assert fourier_fingerprint.shape[0] == fourier_fingerprint.shape[1], (
            "Correlation matrix must be square."
        )

        # The weight matrix is a "tent" sum along the two axes.
        # Concretely, with N = fourier_fingerprint.shape[0]
        # (odd) and center = N // 2,
        #     W[i, j] = u[i] + u[j]
        # where u[k] = (center - |k - center|) / (2 * center)
        # is a triangular weighting peaking at the centre index
        # and decaying linearly to 0 at the first and last index.
        N = fourier_fingerprint.shape[0]
        center = N // 2
        k = jnp.arange(N)
        u = (center - jnp.abs(k - center)) / (2 * center)

        return fourier_fingerprint * (u[:, None] + u[None, :])

    @classmethod
    def _weighting_mean(
        cls, fourier_fingerprint: jnp.ndarray, coeffs: jnp.ndarray
    ) -> jnp.ndarray:
        """
        Performs weighting on the given correlation matrix.
        Here, we use the product of the mean of the coefficients as weights.
        This suppresses correlations where the mean of the coefficients is near zero.

        Args:
            fourier_fingerprint (jnp.ndarray): Square correlation matrix
            coeffs (jnp.ndarray): Fourier coefficients; the last axis is the
                sample axis.

        Returns:
            jnp.ndarray: The element-wise weighted correlation matrix.
        """
        assert fourier_fingerprint.shape[0] == fourier_fingerprint.shape[1], (
            "Correlation matrix must be square."
        )
        assert len(coeffs.shape) >= 2, (
            "Coefficient matrix must contain coefficient axes and a sample axis."
        )

        # Magnitude of the per-coefficient mean over the sample axis,
        # flattened after transposing the remaining (frequency) axes.
        coefficient_means = jnp.abs(jnp.mean(coeffs, axis=-1))
        coefficient_means = coefficient_means.T.reshape(-1)

        assert fourier_fingerprint.shape[0] == coefficient_means.shape[0], (
            "Correlation matrix size must match the number of Fourier coefficients."
        )

        # Apply the rank-1 weight w[i] * w[j] via broadcasting instead
        # of materialising an explicit `jnp.outer` N x N intermediate.
        return (
            fourier_fingerprint
            * coefficient_means[:, None]
            * coefficient_means[None, :]
        )
class Datasets:
    """Generators for synthetic data, e.g. random Fourier series that match
    the frequency layout of a given model."""

    @classmethod
    def generate_fourier_series(
        cls,
        random_key: random.PRNGKey,
        model: Model,
        coefficients_min: float = 0.0,
        coefficients_max: float = 1.0,
        zero_centered: bool = False,
    ) -> List[jnp.ndarray]:
        """
        Generates the Fourier series representation of a function.
        It uses the `model.frequencies` property to retrieve the frequency
        information. This ensures that the resulting Fourier series is
        compatible with the model.

        This function is capable of generating $D$-dimensional Fourier series
        (again defined by `model.n_input_feat`).
        The number of domain samples per dimension is taken from
        `model.degree`.

        Samples of the Fourier coefficients are drawn via `uniform_circle`.

        Args:
            random_key (random.PRNGKey): Random number key for JAX.
            model (Model): The quantum circuit model.
            coefficients_min (float, optional): Lower bound passed as `low`
                to `uniform_circle`. Defaults to 0.0.
            coefficients_max (float, optional): Upper bound passed as `high`
                to `uniform_circle`. Defaults to 1.0.
            zero_centered (bool, optional): Whether to set the offset
                (zero-frequency) coefficient to zero. Defaults to False.

        Returns:
            List[jnp.ndarray]: A list of three arrays:
                - Input domain samples with shape ((N,)*D, D)
                - Fourier series values with shape ((N,)*D)
                - Fourier coefficients with shape ((N,)*D)
        """
        # TODO: the following code can be considered to add some constraints
        # on the spectrum, i.e. not fully capturing a truly random spectrum.

        # Note: one key observation for understanding the following code is,
        # that instead of wrapping your head around symmetries in multi-
        # dimensional coefficient matrices, one can simply look at the flattened
        # version of such a matrix and reshape later. It just works out.

        # Going over [0, 2pi) with `d` samples per input dimension, then
        # combining the per-dimension axes into an n-d grid of domain samples
        # (a "coordinate system"). The grid is built from an integer count
        # (`arange(d) * step`) rather than a float-step `arange`, so every
        # axis has exactly `d` points regardless of rounding.
        domain_samples_per_input_dim = jnp.stack(
            jnp.meshgrid(
                *[jnp.arange(d) * (2 * jnp.pi / d) for d in model.degree]
            )
        ).T.reshape(-1, model.n_input_feat)

        # generate the frequency indices for each dimension.
        # this will have the same shape as the domain samples
        frequencies = jnp.stack(jnp.meshgrid(*model.frequencies)).T.reshape(
            -1, model.n_input_feat
        )

        # Sample only the non-negative half of the flattened spectrum;
        # shape: (prod(model.degree) // 2 + 1,)
        coefficients = cls.uniform_circle(
            random_key,
            low=coefficients_min,
            high=coefficients_max,
            size=math.prod(model.degree) // 2 + 1,
        )

        # zero center (first coeff = 0)
        # we can assume the first coeff is the offset, because we're dealing
        # with a non-symmetric spectrum here. Otherwise the offset is made
        # real-valued.
        if zero_centered:
            coefficients = coefficients.at[0].set(0.0)
        else:
            coefficients = coefficients.at[0].set(coefficients[0].real)

        # ensure symmetry by prepending the flipped complex conjugates of
        # the non-offset coefficients, giving us the full coefficients vector
        coefficients = jnp.concat(
            [
                jnp.flip(coefficients[..., 1:]).conjugate(),
                coefficients,
            ],
            axis=-1,
        )

        # Vectorized version of $f(x) = \sum_{n=0}^{N-1} c_n * e^{i * \omega_n * x}$
        # it takes into account the input dimension, i.e. the output is a matrix
        # normalization uses the total number of coefficients
        values = jnp.real(
            (
                jnp.exp(1j * (domain_samples_per_input_dim @ frequencies.T))
                * coefficients
            ).sum(axis=1)
            / coefficients.size
        )

        # return all the information we have
        return [
            domain_samples_per_input_dim.reshape(*model.degree, -1),
            values.reshape(model.degree),
            coefficients.reshape(model.degree),
        ]

    @classmethod
    def uniform_circle(
        cls,
        random_key: random.PRNGKey,
        size: Union[jnp.ndarray, List, int],
        low=0.0,
        high=1.0,
    ):
        """
        Random number generator for complex numbers with a uniformly random
        phase and a magnitude of `sqrt(u)`, where `u` is drawn uniformly from
        `[low, high)`. With the defaults this samples inside the unit circle.

        NOTE(review): because the square root is applied after sampling, the
        magnitudes lie in `[sqrt(low), sqrt(high))` rather than `[low, high)`
        — confirm this is intended for non-default bounds.

        Args:
            random_key (random.PRNGKey): Random number key for JAX.
            size (Union[jnp.ndarray, List, int]): Shape of the output. An
                integer `n` yields a 1-D array of `n` samples; a sequence is
                used as the full output shape.
            low (float, optional): Lower bound of `u`. Defaults to 0.0.
            high (float, optional): Upper bound of `u`. Defaults to 1.0.

        Returns:
            jnp.ndarray: Array of complex numbers with shape of `size`
        """
        # JAX expects the shape as a tuple of ints.
        if np.ndim(size) == 0:
            shape = (int(size),)
        else:
            shape = tuple(int(s) for s in size)

        # Independent sub-keys for the radius and the phase.
        random_key, random_key1 = random.split(random_key)
        return jnp.sqrt(
            random.uniform(random_key, shape, minval=low, maxval=high)
        ) * jnp.exp(2j * jnp.pi * random.uniform(random_key1, shape))