Coverage for qml_essentials / coefficients.py: 96%

643 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-06-11 15:51 +0000

1from __future__ import annotations 

2import sys 

3import math 

4import warnings 

5import itertools 

6from collections import defaultdict 

7import jax.numpy as jnp 

8from jax import random 

9import numpy as np 

10from scipy.stats import rankdata 

11from functools import reduce, lru_cache 

12from typing import List, Tuple, Optional, Any, Dict, Union 

13 

14from qml_essentials.model import Model 

15from qml_essentials.pauli import PauliCircuit 

16from qml_essentials.operations import PauliWord 

17 

18import logging 

19 

20log = logging.getLogger(__name__) 

21 

22 

class Coefficients:
    """Numerical Fourier analysis of a model via the n-dimensional FFT."""

    @classmethod
    def get_spectrum(
        cls,
        model: Model,
        mfs: int = 1,
        mts: int = 1,
        shift=False,
        trim=False,
        numerical_cap: Optional[float] = -1,
        **kwargs,
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """
        Extracts the coefficients of a given model using a FFT (jnp-fft).

        Note that the coefficients are complex numbers, but the imaginary part
        of the coefficients should be very close to zero, since the expectation
        values of the Pauli operators are real numbers.

        It can perform oversampling in both the frequency and time domain
        using the `mfs` and `mts` arguments.

        Args:
            model (Model): The model to sample.
            mfs (int): Multiplicator for the highest frequency. Default is 1.
            mts (int): Multiplicator for the number of time samples. Default is 1.
            shift (bool): Whether to apply jnp-fftshift. Default is False.
            trim (bool): Whether to remove the Nyquist frequency if spectrum is even.
                Default is False.
            numerical_cap (Optional[float]): Numerical cap for the coefficients.
                If positive, coefficients with magnitude below the cap are
                zeroed and, for a single input feature, frequencies that
                vanish entirely are removed from both `coeffs` and `freqs`.
                ``None`` or a non-positive value disables capping.
            kwargs (Any): Additional keyword arguments for the model function.

        Returns:
            Tuple[jnp.ndarray, jnp.ndarray]: Tuple containing the coefficients
            and frequencies. For a single input feature the frequencies are a
            1-D array, otherwise one frequency vector per input feature.

        Raises:
            ValueError: If the summed coefficients have a non-negligible
                imaginary part.
        """
        kwargs.setdefault("force_mean", True)
        kwargs.setdefault("execution_type", "expval")

        coeffs, freqs = cls._fourier_transform(model, mfs=mfs, mts=mts, **kwargs)
        # Work on a private list so per-axis updates never touch the helper's
        # return value.
        freqs = list(freqs)

        imag_sum = jnp.sum(coeffs).imag
        if not jnp.isclose(imag_sum, 0.0, atol=1.0e-6):
            raise ValueError(
                "Spectrum is not real. Imaginary part of coefficients is: "
                f"{imag_sum}"
            )

        if trim:
            for ax in range(model.n_input_feat):
                n_ax = coeffs.shape[ax]
                if n_ax % 2 == 0:
                    # In (unshifted) FFT order the Nyquist bin sits at n // 2
                    # of *this* axis; only this axis' frequency vector shrinks.
                    nyquist = n_ax // 2
                    coeffs = np.delete(coeffs, nyquist, axis=ax)
                    freqs[ax] = np.delete(freqs[ax], nyquist)

        if shift:
            coeffs = jnp.fft.fftshift(coeffs, axes=list(range(model.n_input_feat)))
            # Shift every frequency vector on its own: shifting the stacked
            # list would also roll the feature axis and fails for vectors of
            # different length.
            freqs = [np.fft.fftshift(np.asarray(freq)) for freq in freqs]
            if len({freq.shape[0] for freq in freqs}) == 1:
                # Keep the historical (n_features, n_freqs) array layout
                # whenever the axes are stackable.
                freqs = np.stack(freqs)

        if numerical_cap is not None and numerical_cap > 0:
            # set coeffs below threshold to zero
            coeffs = jnp.where(
                jnp.abs(coeffs) < numerical_cap,
                jnp.zeros_like(coeffs),
                coeffs,
            )

            # Drop frequencies whose coefficients vanish entirely after
            # capping, so the returned spectrum reflects only the surviving
            # frequencies. Well-defined only for a single (1-D) frequency
            # axis; for multi-dim input the rectangular grid is left intact.
            if model.n_input_feat == 1:
                if coeffs.ndim == 1:
                    surviving = coeffs != 0
                else:
                    surviving = jnp.any(coeffs != 0, axis=tuple(range(1, coeffs.ndim)))
                keep = np.asarray(surviving)
                coeffs = coeffs[keep]
                freqs = [freqs[0][keep]]

        if len(freqs) == 1:
            freqs = freqs[0]

        return coeffs, freqs

    @classmethod
    def _fourier_transform(
        cls, model: Model, mfs: int, mts: int, **kwargs: Any
    ) -> Tuple[jnp.ndarray, List[jnp.ndarray]]:
        """
        Sample the model on a regular grid and apply an n-dimensional FFT.

        Args:
            model (Model): The model to sample.
            mfs (int): Multiplicator for the highest frequency.
            mts (int): Multiplicator for the number of time samples.
            kwargs (Any): Additional keyword arguments for the model function.

        Returns:
            Tuple[jnp.ndarray, List[jnp.ndarray]]: The normalised coefficients
            and one frequency vector per input feature.
        """
        n_feat = model.n_input_feat

        # Create a frequency vector with as many frequencies as model degrees,
        # oversampled by mfs
        n_freqs: jnp.ndarray = jnp.array(
            [mfs * model.degree[i] for i in range(n_feat)]
        )
        # Sample count per axis; shared by the grid and the frequency vectors
        # so both always have the same length.
        n_samples = [int(mts * n_freqs[i]) for i in range(n_feat)]

        # Same grid as arange(0, 2*pi*mts, 2*pi/n_freqs), but with an explicit
        # sample count: a float-stepped arange can gain or lose a point.
        inputs: List = [
            jnp.linspace(0.0, 2 * mts * jnp.pi, n_samples[i], endpoint=False)
            for i in range(n_feat)
        ]

        # "ij" indexing keeps axis k of the grid tied to feature k for any
        # number of features; rows run with the first feature slowest.
        nd_inputs = jnp.stack(jnp.meshgrid(*inputs, indexing="ij"), axis=-1).reshape(
            -1, n_feat
        )

        # Output vector is not necessarily the same length as input
        outputs = model(inputs=nd_inputs, **kwargs)
        outputs = outputs.reshape(*n_samples, -1)
        if outputs.shape[-1] == 1:
            # Drop only the trailing singleton; input axes must survive even
            # when they hold a single sample.
            outputs = outputs[..., 0]

        coeffs = jnp.fft.fftn(outputs, axes=list(range(n_feat)))

        freqs = [
            jnp.fft.fftfreq(n_samples[i], 1 / n_freqs[i]) for i in range(n_feat)
        ]

        # FIXME: account for different frequencies in multidim input scenarios
        # Normalise by the total number of grid points (product if multidim).
        return (
            coeffs / math.prod(n_samples),
            freqs,
        )

    @classmethod
    def get_psd(cls, coeffs: jnp.ndarray) -> jnp.ndarray:
        """
        Calculates the power spectral density (PSD) from given Fourier coefficients.

        Args:
            coeffs (jnp.ndarray): The Fourier coefficients.

        Returns:
            jnp.ndarray: The power spectral density.
        """
        # TODO: if we apply trim=True in advance, this will be slightly wrong..

        def abs2(x):
            return x.real**2 + x.imag**2

        # The scale uses the length of the first axis only.
        scale = 2.0 / (len(coeffs) ** 2)
        return scale * abs2(coeffs)

    @classmethod
    def evaluate_Fourier_series(
        cls,
        coefficients: jnp.ndarray,
        frequencies: jnp.ndarray,
        inputs: Union[jnp.ndarray, list, float],
    ) -> float:
        """
        Evaluate the function value of a Fourier series at one or more points.

        Args:
            coefficients (jnp.ndarray): Coefficients of the Fourier series.
            frequencies (jnp.ndarray): Corresponding frequencies. Either a 1-D
                array, a list of per-axis frequency vectors, or a 2-D array
                (per-axis vectors when the leading coefficient axes match,
                otherwise one frequency vector per row).
            inputs (jnp.ndarray): Point(s) at which to evaluate the function.

        Returns:
            float: The (real part of the) function value at the input
            point(s), with singleton axes squeezed.
        """
        coefficients = jnp.asarray(coefficients)

        def flatten_grid(freq_axes):
            # Expand per-axis frequency vectors into the full rectangular grid
            # and flatten the coefficients' leading axes to match.
            freq_axes = [jnp.asarray(freq) for freq in freq_axes]
            freq_grid = jnp.stack(jnp.meshgrid(*freq_axes, indexing="ij"), axis=-1)
            flat_frequencies = freq_grid.reshape(-1, len(freq_axes))
            flat_coefficients = coefficients.reshape(
                flat_frequencies.shape[0], *coefficients.shape[len(freq_axes) :]
            )
            return flat_coefficients, flat_frequencies

        if isinstance(frequencies, list):
            flat_coefficients, flat_frequencies = flatten_grid(frequencies)
        else:
            frequencies = jnp.asarray(frequencies)
            if frequencies.ndim == 1:
                flat_frequencies = frequencies[:, jnp.newaxis]
                flat_coefficients = coefficients.reshape(
                    flat_frequencies.shape[0], *coefficients.shape[1:]
                )
            else:
                n_features, n_axis_freqs = frequencies.shape
                is_axis_frequencies = (
                    coefficients.shape[:n_features] == (n_axis_freqs,) * n_features
                )

                if is_axis_frequencies:
                    flat_coefficients, flat_frequencies = flatten_grid(frequencies)
                else:
                    flat_frequencies = frequencies
                    flat_coefficients = coefficients.reshape(
                        flat_frequencies.shape[0], *coefficients.shape[1:]
                    )

        # Bring the inputs into (n_points, n_features) layout.
        inputs = jnp.asarray(inputs)
        if inputs.ndim == 0:
            inputs = inputs.reshape(1, 1)
        elif inputs.ndim == 1:
            if flat_frequencies.shape[1] == 1:
                inputs = inputs[:, jnp.newaxis]
            elif inputs.shape[0] == flat_frequencies.shape[1]:
                inputs = inputs[jnp.newaxis, :]
            else:
                inputs = jnp.repeat(
                    inputs[:, jnp.newaxis], flat_frequencies.shape[1], axis=1
                )
        exponents = jnp.exp(1j * (inputs @ flat_frequencies.T))
        exp = jnp.tensordot(exponents, flat_coefficients, axes=([1], [0]))

        return jnp.squeeze(jnp.real(exp))

238 

239 

240class FourierTree: 

241 """ 

242 Sine-cosine tree representation for the algorithm by Nemkov et al. 

243 

244 Computes the analytical Fourier coefficients/frequencies of a Pauli-Clifford 

245 circuit. The symbolic structure of the tree (which Pauli rotations 

246 contribute sine/cosine factors to which leaf, and the leaf observables) is 

247 built once in NumPy; the parameter-dependent coefficients are then obtained 

248 with a small number of vectorised JAX operations, so the result remains 

249 jittable / differentiable with respect to the model parameters. 

250 

251 The resulting spectrum is the d-dimensional set of frequency vectors, 

252 where $d$ is the input dimensionality. 

253 

254 **Usage**: 

255 ``` 

256 model = Model(...) 

257 tree = FourierTree(model) 

258 exp = tree() # expectation value 

259 coeff_list, freq_list = tree.get_spectrum() 

260 ``` 

261 """ 

262 

263 def __init__(self, model: Model): 

264 """ 

265 Tree initialisation, based on the Pauli-Clifford representation of a 

266 model. 

267 

268 Args: 

269 model (Model): The Model, for which to build the tree. 

270 """ 

271 self.model = model 

272 self.n_qubits = model.n_qubits 

273 

274 # A single (de-batched) parameter set drives the whole tree. 

275 self._params = self._single_param_set(model.params) 

276 

277 # Canonical Pauli-Clifford structure, recorded once at a fixed base 

278 # input. The base value is irrelevant to the structure (it only sets 

279 # the rotation angles, not which Pauli words appear). 

280 base_inputs = np.ones(model.n_input_feat) 

281 operations, observables = self._build_canonical_tape(self._params, base_inputs) 

282 

283 self.parameters = [ 

284 jnp.squeeze(p) for p in PauliCircuit.get_parameters(operations) 

285 ] 

286 self.n_params = len(self.parameters) 

287 

288 # Pauli generators of the (canonical) rotations, as symbolic words. 

289 self.pauli_words: List[PauliWord] = [ 

290 PauliWord.from_operation(op, self.n_qubits) for op in operations 

291 ] 

292 

293 # Cumulative X/Y support of the rotations[0..k] (for light-cone early 

294 # stopping). cumulative_xy[k] is True on every qubit touched by an X/Y 

295 # generator in any rotation up to index k. 

296 self.cumulative_xy: List[np.ndarray] = [] 

297 running = np.zeros(self.n_qubits, dtype=bool) 

298 for pw in self.pauli_words: 

299 running = np.logical_or(running, pw.xy_mask) 

300 self.cumulative_xy.append(running.copy()) 

301 

302 # Observable Pauli words (one tree root each). 

303 self.observable_words: List[PauliWord] = [ 

304 PauliWord.from_operation(obs, self.n_qubits) for obs in observables 

305 ] 

306 

307 # Identify the input-encoding columns, their feature, and integer 

308 # frequency scaling directly from the tape (no per-gate tagging). Sets 

309 # ``input_indices``, ``all_input_indices``, ``input_scaling``, 

310 # ``var_positions`` and ``features``. 

311 self._detect_inputs(base_inputs) 

312 

313 # The explicit leaf structure is built lazily: for deep circuits the 

314 # number of tree paths explodes combinatorially, while the canonical 

315 # form above (and the merged-state support DP) remain cheap. 

316 self._structure_built = False 

317 

318 def _ensure_structure(self) -> None: 

319 """Build the explicit leaf/spectrum structure on first use.""" 

320 if not self._structure_built: 

321 # Symbolic structure: per root (S, C, terms) leaf arrays ... 

322 self._build_leaf_arrays() 

323 # ... and the parameter-independent frequency/weight structure. 

324 self._build_spectrum_structure() 

325 self._structure_built = True 

326 

327 def _single_param_set(self, params) -> jnp.ndarray: 

328 """De-batch the model parameters to the single set the tree describes. 

329 

330 Models can carry batched parameters (e.g. after FCC sampling); the tree 

331 is defined for one set, so fall back to the first and warn. 

332 """ 

333 params = jnp.asarray(params) 

334 if params.ndim > 2 and params.shape[0] > 1: 

335 warnings.warn( 

336 f"FourierTree supports a single parameter set; using the first " 

337 f"of {params.shape[0]} batched parameter sets.", 

338 UserWarning, 

339 ) 

340 params = params[0] 

341 return params 

342 

343 def _build_canonical_tape(self, params, inputs): 

344 """Record the circuit and transform it to Pauli-Clifford normal form. 

345 

346 Returns the ``(operations, observables)`` of the canonical circuit 

347 (see :meth:`PauliCircuit.from_parameterised_circuit`). 

348 """ 

349 params = self._single_param_set(params) 

350 inputs = self.model._inputs_validation(inputs) 

351 raw_tape = self.model.script._record(params=params, inputs=inputs) 

352 _, obs_list = self.model._build_obs() 

353 return PauliCircuit.from_parameterised_circuit( 

354 raw_tape, observables=obs_list, n_qubits=self.n_qubits 

355 ) 

356 

357 def _canonical_parameters(self, inputs) -> np.ndarray: 

358 """Recorded canonical rotation angles (1-D float array) for ``inputs``.""" 

359 operations, _ = self._build_canonical_tape(self._params, inputs) 

360 return np.array( 

361 [float(jnp.squeeze(p)) for p in PauliCircuit.get_parameters(operations)] 

362 ) 

363 

364 def _detect_inputs(self, base_inputs: np.ndarray) -> None: 

365 r"""Infer the input-encoding columns directly from the tape (tag-free). 

366 

367 Each encoding rotation applies an angle :math:`\omega_k\,x_f` that is 

368 linear in a single input feature :math:`x_f`, and Clifford commutation 

369 only multiplies a rotation generator by :math:`\pm 1`. Every canonical 

370 rotation angle is therefore an affine function of the inputs, so 

371 perturbing one feature at a time and differencing the recorded angles 

372 isolates exactly the columns that depend on it, together with the 

373 signed integer scaling :math:`\omega_k`. 

374 

375 Sets :attr:`input_indices` (``{feature: [columns]}``), 

376 :attr:`all_input_indices`, :attr:`input_scaling` (per column, ``1`` for 

377 variational columns), :attr:`var_positions`, and :attr:`features`. 

378 

379 Raises: 

380 NotImplementedError: If a rotation depends on more than one feature 

381 (the tree requires single-feature encodings). 

382 """ 

383 tol = 1e-6 

384 d = self.model.n_input_feat 

385 base = np.asarray(base_inputs, dtype=float) 

386 p_base = np.array([float(p) for p in self.parameters]) 

387 

388 # response[f, k] = d(angle_k) / d(x_f), the linear response of column k. 

389 response = np.zeros((d, self.n_params)) 

390 for f in range(d): 

391 step = base.copy() 

392 step[f] += 1.0 

393 response[f] = self._canonical_parameters(step) - p_base 

394 

395 input_indices: Dict[int, list] = defaultdict(list) 

396 all_input_indices: List[int] = [] 

397 scaling = np.ones(self.n_params, dtype=np.int64) 

398 for k in range(self.n_params): 

399 feats = np.flatnonzero(np.abs(response[:, k]) > tol) 

400 if feats.size == 0: 

401 continue # variational column 

402 if feats.size > 1: 

403 raise NotImplementedError( 

404 f"Rotation {k} depends on multiple input features " 

405 f"{feats.tolist()}; the Fourier tree requires each encoding " 

406 "rotation to be linear in a single feature." 

407 ) 

408 f = int(feats[0]) 

409 omega = float(response[f, k]) 

410 w = int(round(omega)) 

411 if abs(omega - w) > tol: 

412 warnings.warn( 

413 f"Non-integer input scaling {omega:.4f} on rotation {k} " 

414 f"(feature {f}); rounding to {w}. The Fourier tree supports " 

415 "integer frequency scalings only.", 

416 UserWarning, 

417 ) 

418 input_indices[f].append(k) 

419 all_input_indices.append(k) 

420 scaling[k] = w 

421 

422 self.input_indices = input_indices 

423 self.all_input_indices = all_input_indices 

424 self.input_scaling = scaling 

425 input_set = set(all_input_indices) 

426 self.var_positions = np.array( 

427 [i for i in range(self.n_params) if i not in input_set], dtype=np.int64 

428 ) 

429 # Ordered list of input feature keys (d-dimensional spectrum). 

430 self.features = sorted(input_indices.keys()) 

431 

432 # Symbolic tree construction (NumPy) 

433 def _build_leaf_arrays(self) -> None: 

434 """Collect the tree leaves for every root into integer count matrices. 

435 

436 For each root (observable) this produces: 

437 - ``S``: (n_leaves, n_params) sine-factor counts per parameter, 

438 - ``C``: (n_leaves, n_params) cosine-factor counts per parameter, 

439 - ``terms``: (n_leaves,) complex leaf constants ``<0|O_leaf|0>``. 

440 """ 

441 self.leaf_arrays: List[Tuple[np.ndarray, np.ndarray, np.ndarray]] = [] 

442 for obs_word in self.observable_words: 

443 leaves: List[Tuple[np.ndarray, np.ndarray, complex]] = [] 

444 zeros = np.zeros(self.n_params, dtype=np.int64) 

445 self._collect_leaves( 

446 obs_word, self.n_params - 1, zeros.copy(), zeros.copy(), leaves 

447 ) 

448 if leaves: 

449 S = np.stack([leaf[0] for leaf in leaves]) 

450 C = np.stack([leaf[1] for leaf in leaves]) 

451 terms = np.array([leaf[2] for leaf in leaves], dtype=np.complex128) 

452 else: 

453 S = np.zeros((0, self.n_params), dtype=np.int64) 

454 C = np.zeros((0, self.n_params), dtype=np.int64) 

455 terms = np.zeros(0, dtype=np.complex128) 

456 self.leaf_arrays.append((S, C, terms)) 

457 

458 def _collect_leaves( 

459 self, 

460 observable: PauliWord, 

461 pauli_idx: int, 

462 sin_counts: np.ndarray, 

463 cos_counts: np.ndarray, 

464 leaves: List[Tuple[np.ndarray, np.ndarray, complex]], 

465 ) -> None: 

466 """Recursively enumerate the leaves of the coefficient tree. 

467 

468 The incoming sine/cosine factor (from the parent edge) is already 

469 accumulated into ``sin_counts``/``cos_counts``. This fuses the tree 

470 construction and leaf traversal of the original implementation into a 

471 single NumPy pass (no per-node JAX scatter updates). 

472 """ 

473 if self._early_stopping_possible(pauli_idx, observable): 

474 return 

475 

476 # Skip trailing Pauli rotations that commute with the observable. 

477 while pauli_idx >= 0: 

478 last = self.pauli_words[pauli_idx] 

479 if not observable.commutes_with(last): 

480 break 

481 pauli_idx -= 1 

482 else: # leaf reached 

483 term = observable.zero_expectation() 

484 if term != 0: 

485 leaves.append((sin_counts, cos_counts, term)) 

486 return 

487 

488 last = self.pauli_words[pauli_idx] 

489 

490 # Left child: cosine factor for this parameter, same observable. 

491 cos_left = cos_counts.copy() 

492 cos_left[pauli_idx] += 1 

493 self._collect_leaves( 

494 observable, pauli_idx - 1, sin_counts.copy(), cos_left, leaves 

495 ) 

496 

497 # Right child: sine factor, observable becomes P . O. 

498 sin_right = sin_counts.copy() 

499 sin_right[pauli_idx] += 1 

500 self._collect_leaves( 

501 last.compose(observable), 

502 pauli_idx - 1, 

503 sin_right, 

504 cos_counts.copy(), 

505 leaves, 

506 ) 

507 

508 def _early_stopping_possible(self, pauli_idx: int, observable: PauliWord) -> bool: 

509 """Whether a node can be discarded (all reachable expectations vanish). 

510 

511 Mirrors the criterion of Nemkov et al. (light cone): a qubit on which 

512 the observable carries an X/Y must be covered by an X/Y generator of 

513 some remaining rotation (rotations[0..pauli_idx]); otherwise that X/Y can 

514 never be rotated into a diagonal term and the whole node contributes 

515 zero. Equivalently, the node survives iff every qubit is either I/Z in 

516 the observable or covered by the cumulative rotation X/Y support. 

517 """ 

518 obs_iz = np.logical_not(observable.xy_mask) 

519 combined = np.logical_or(obs_iz, self.cumulative_xy[pauli_idx]).all() 

520 return not bool(combined) 

521 

522 # Frequency / weight structure (NumPy, parameter independent) 

523 def _build_spectrum_structure(self) -> None: 

524 """Build, per root, the frequency vectors and the (n_freq, n_leaves) 

525 weight matrix ``W`` such that ``coeffs = W @ (terms * variational)``. 

526 """ 

527 self.freqs_per_root: List[np.ndarray] = [] 

528 self.weights_per_root: List[np.ndarray] = [] 

529 d = len(self.features) 

530 

531 for S, C, _ in self.leaf_arrays: 

532 n_leaves = S.shape[0] 

533 freq_to_col: Dict[tuple, np.ndarray] = defaultdict( 

534 lambda: np.zeros(n_leaves, dtype=np.complex128) 

535 ) 

536 for leaf in range(n_leaves): 

537 # One expansion factor per *active* input column, each carrying 

538 # its feature axis and integer frequency scaling. Per leaf a 

539 # column contributes at most one sin/cos factor (square-free), 

540 # but different columns of the same feature may carry different 

541 # scalings, so they are expanded individually and convolved 

542 # rather than aggregating counts (which would assume a common 

543 # unit scaling). 

544 col_factors: List[List[Tuple[int, int, float]]] = [] 

545 half_exp = 0 

546 for axis, feat in enumerate(self.features): 

547 for k in self.input_indices[feat]: 

548 s = int(S[leaf, k]) 

549 c = int(C[leaf, k]) 

550 if s == 0 and c == 0: 

551 continue 

552 half_exp += s + c 

553 w_k = int(self.input_scaling[k]) 

554 col_factors.append( 

555 [ 

556 (axis, int(o) * w_k, wt) 

557 for o, wt in self._binomial_terms(s, c) 

558 ] 

559 ) 

560 half = 0.5**half_exp 

561 

562 if d == 0: 

563 freq_to_col[(0,)][leaf] += half 

564 continue 

565 

566 if not col_factors: 

567 freq_to_col[(0,) * d][leaf] += half 

568 continue 

569 

570 for combo in itertools.product(*col_factors): 

571 omega = [0] * d 

572 weight = half 

573 for axis, o, wt in combo: 

574 omega[axis] += o 

575 weight *= wt 

576 freq_to_col[tuple(omega)][leaf] += weight 

577 

578 if freq_to_col: 

579 omegas = sorted(freq_to_col.keys()) 

580 W = np.stack([freq_to_col[o] for o in omegas]) # (n_freq, n_leaves) 

581 freqs = np.array(omegas, dtype=np.int64) # (n_freq, d) 

582 else: 

583 freqs = np.zeros((1, max(d, 1)), dtype=np.int64) 

584 W = np.zeros((1, n_leaves), dtype=np.complex128) 

585 

586 # Collapse to 1-D frequency array for the single-feature case. 

587 if freqs.shape[1] == 1: 

588 freqs = freqs[:, 0] 

589 self.freqs_per_root.append(freqs) 

590 # Keep W in NumPy complex128: its entries are dyadic rationals 

591 # (binomial weights x 0.5^k x i^m), which are exact in float64 -- 

592 # this allows exact symbolic zero-tests in get_exact_support. 

593 self.weights_per_root.append(W) 

594 

595 @staticmethod 

596 def _binomial_terms(s: int, c: int) -> List[Tuple[int, float]]: 

597 """Expand ``cos^c (i sin)^s`` in ``e^{i omega x}`` (without the 0.5 factor). 

598 

599 Returns a list of ``(omega, weight)`` with 

600 ``omega = 2a + 2b - s - c`` and ``weight = C(s,a) C(c,b) (-1)^{s-a}``. 

601 """ 

602 terms = [] 

603 for a in range(s + 1): 

604 for b in range(c + 1): 

605 weight = math.comb(s, a) * math.comb(c, b) * (-1) ** (s - a) 

606 terms.append((2 * a + 2 * b - s - c, float(weight))) 

607 return terms 

608 

609 # Vectorised numeric evaluation (JAX) 

610 @staticmethod 

611 def _safe_pow(base: jnp.ndarray, exp: jnp.ndarray) -> jnp.ndarray: 

612 """Elementwise ``base ** exp`` for real base and non-negative integer 

613 exponents, correct for negative bases (avoids ``log`` of negatives). 

614 

615 Args: 

616 base: real array of shape ``(n,)``. 

617 exp: integer array of shape ``(n_leaves, n)``. 

618 """ 

619 mag = jnp.abs(base)[None, :] ** exp 

620 sign = jnp.where(exp % 2 == 0, 1.0, jnp.sign(base)[None, :]) 

621 return sign * mag 

622 

623 _I_POW = None # set lazily to jnp.array([1, 1j, -1, -1j]) 

624 

625 def _leaf_factors( 

626 self, S: np.ndarray, C: np.ndarray, columns: np.ndarray 

627 ) -> jnp.ndarray: 

628 """Per-leaf product ``prod_i cos(theta_i)^{C} (i sin(theta_i))^{S}`` over 

629 the given parameter ``columns`` (vectorised over leaves). 

630 """ 

631 if FourierTree._I_POW is None: 

632 FourierTree._I_POW = jnp.array([1, 1j, -1, -1j]) 

633 

634 if S.shape[0] == 0: 

635 return jnp.zeros(0, dtype=jnp.complex64) 

636 

637 theta = jnp.stack([self.parameters[i] for i in columns]) 

638 S_sub = jnp.asarray(S[:, columns]) 

639 C_sub = jnp.asarray(C[:, columns]) 

640 

641 cos_part = self._safe_pow(jnp.cos(theta), C_sub) 

642 sin_mag = self._safe_pow(jnp.sin(theta), S_sub) 

643 i_part = FourierTree._I_POW[S_sub % 4] 

644 return jnp.prod(cos_part * sin_mag * i_part, axis=1) 

645 

646 def __call__( 

647 self, 

648 params: Optional[jnp.ndarray] = None, 

649 inputs: Optional[jnp.ndarray] = None, 

650 **kwargs, 

651 ) -> jnp.ndarray: 

652 """ 

653 Evaluate the expectation value(s) of the model's observables via the 

654 sine-cosine tree (equivalent to the circuit expectation). 

655 

656 Args: 

657 params (Optional[jnp.ndarray]): Model parameters. Defaults to the 

658 model's parameters. 

659 inputs (Optional[jnp.ndarray]): Inputs to the circuit. Defaults to 1. 

660 

661 Returns: 

662 jnp.ndarray: Expectation value per observable (or their mean if 

663 ``force_mean`` is set). 

664 

665 Raises: 

666 NotImplementedError: For execution types other than "expval" or when 

667 noise is requested. 

668 """ 

669 params = ( 

670 self.model._params_validation(params) 

671 if params is not None 

672 else self.model.params 

673 ) 

674 inputs = ( 

675 self.model._inputs_validation(inputs) 

676 if inputs is not None 

677 else self.model._inputs_validation(1.0) 

678 ) 

679 

680 if kwargs.get("execution_type", "expval") != "expval": 

681 raise NotImplementedError( 

682 f'Currently, only "expval" execution type is supported when ' 

683 f"building FourierTree. Got {kwargs.get('execution_type', 'expval')}." 

684 ) 

685 if kwargs.get("noise_params", None) is not None: 

686 raise NotImplementedError( 

687 "Currently, noise is not supported when building FourierTree." 

688 ) 

689 

690 # Re-derive the (canonical) parameter values for the requested inputs; 

691 # the tree structure (leaf arrays) is unchanged. 

692 operations, _ = self._build_canonical_tape(params, inputs) 

693 self.parameters = [ 

694 jnp.squeeze(p) for p in PauliCircuit.get_parameters(operations) 

695 ] 

696 

697 self._ensure_structure() 

698 all_columns = np.arange(self.n_params, dtype=np.int64) 

699 results = [] 

700 for S, C, terms in self.leaf_arrays: 

701 factors = self._leaf_factors(S, C, all_columns) 

702 results.append(jnp.real(jnp.sum(jnp.asarray(terms) * factors))) 

703 results = jnp.array(results) 

704 

705 if kwargs.get("force_mean", False): 

706 return jnp.mean(results) 

707 return results 

708 

709 def get_spectrum( 

710 self, force_mean: bool = False 

711 ) -> Tuple[List[jnp.ndarray], List[jnp.ndarray]]: 

712 """ 

713 Compute the Fourier spectrum (coefficients and frequencies) of the tree. 

714 

715 Args: 

716 force_mean (bool, optional): Average the coefficients over all 

717 observables (roots). Defaults to False. 

718 

719 Returns: 

720 Tuple[List[jnp.ndarray], List[jnp.ndarray]]: 

721 - List of coefficients, one entry per observable (root). 

722 - List of corresponding frequencies, one entry per root. 

723 When ``force_mean`` is set, both lists have a single entry. 

724 """ 

725 self._ensure_structure() 

726 per_root_coeffs: List[jnp.ndarray] = [] 

727 for (S, C, terms), W in zip(self.leaf_arrays, self.weights_per_root): 

728 leaf_const = jnp.asarray(terms) * self._leaf_factors( 

729 S, C, self.var_positions 

730 ) 

731 per_root_coeffs.append(jnp.asarray(W) @ leaf_const) 

732 

733 return self._combine_roots(per_root_coeffs, self.freqs_per_root, force_mean) 

734 

735 def get_exact_support(self, method: str = "tree") -> List[np.ndarray]: 

736 r"""Symbolically derive the exact frequency support (no sampling). 

737 

738 A frequency :math:`\omega` belongs to the exact spectrum iff its 

739 coefficient :math:`c_\omega(\theta) = \sum_l W_{\omega l}\, 

740 \text{term}_l\, v_l(\theta)` is not identically zero in the 

741 variational parameters :math:`\theta`. 

742 

743 Two methods are available: 

744 

745 - ``"tree"`` (default, fully exact): enumerates the explicit tree 

746 leaves. Because the branch index strictly decreases along every tree 

747 path, each parameter contributes **at most one** sine *or* cosine 

748 factor per leaf (:math:`S_{li}, C_{li} \in \{0, 1\}`). Every 

749 variational leaf factor :math:`v_l` is therefore a *square-free* 

750 monomial over :math:`\{1, \cos\theta_i, i\sin\theta_i\}`, and 

751 monomials with distinct signatures are linearly independent functions 

752 (no :math:`\cos^2 + \sin^2` identities can arise without squares). 

753 Hence 

754 

755 .. math:: 

756 c_\omega \equiv 0 \iff \sum_{l \in g} W_{\omega l}\,\text{term}_l 

757 = 0 \quad \text{for every signature group } g. 

758 

759 Since all involved quantities are dyadic rationals times 

760 :math:`\{\pm 1, \pm i\}`, the group sums are exact in float64 and the 

761 zero-test is exact. The number of leaves can however grow 

762 exponentially with circuit depth. 

763 

764 - ``"dp"`` (scalable): merges tree nodes with identical 

765 ``(rotation index, observable)`` — at most ``n_params * 4^n_qubits`` 

766 states — and tracks the achievable input sine/cosine count pairs 

767 ``(s, c)`` per state. The support is the union of the (exact) 

768 expansion supports of :math:`\cos^c x\, (i \sin x)^s` over all 

769 achievable pairs. This is exact per tree path (including interior 

770 zero coefficients of the expansions), but unlike ``"tree"`` it cannot 

771 detect coefficients that cancel identically *across* paths with 

772 identical variational signatures (e.g. directly repeated encodings). 

773 It therefore yields a tight superset in such corner cases. 

774 Currently restricted to a single input feature. 

775 

776 Args: 

777 method (str): ``"tree"`` (fully exact) or ``"dp"`` (scalable). 

778 

779 Returns: 

780 List[np.ndarray]: For each observable (root), the frequency vectors 

781 with not-identically-zero coefficient — shape ``(n_freq,)`` for a 

782 single input feature, ``(n_freq, n_features)`` otherwise. 

783 """ 

784 if method == "dp": 

785 return self._support_dp() 

786 if method != "tree": 

787 raise ValueError(f"Unknown method '{method}'. Use 'tree' or 'dp'.") 

788 

789 self._ensure_structure() 

790 supports = [] 

791 for (S, C, terms), W, freqs in zip( 

792 self.leaf_arrays, self.weights_per_root, self.freqs_per_root 

793 ): 

794 freqs = np.asarray(freqs) 

795 n_leaves = S.shape[0] 

796 if n_leaves == 0: 

797 supports.append(freqs[:0]) 

798 continue 

799 

800 # Group leaves by their variational sine/cosine signature. 

801 signature = np.hstack([S[:, self.var_positions], C[:, self.var_positions]]) 

802 _, groups = np.unique(signature, axis=0, return_inverse=True) 

803 n_groups = int(groups.max()) + 1 

804 

805 # Per-group sums of W[omega, l] * term_l, accumulated exactly. 

806 contrib = (W * terms[None, :]).T # (n_leaves, n_freq) 

807 group_sums = np.zeros((n_groups, W.shape[0]), dtype=np.complex128) 

808 np.add.at(group_sums, groups, contrib) 

809 

810 mask = (np.abs(group_sums) > 1e-12).any(axis=0) # (n_freq,) 

811 supports.append(freqs[mask]) 

812 return supports 

813 

def _support_dp(self) -> List[np.ndarray]:
    """Merged-state dynamic program for the frequency support.

    Instead of enumerating all (worst-case exponentially many) tree paths,
    nodes are merged on ``(rotation index, bare observable)``. Each state
    stores the set of achievable input ``(s, c)`` count pairs as a bitmask,
    so transitions are O(1) big-int operations. See
    :meth:`get_exact_support` for semantics and limitations.

    Returns:
        List[np.ndarray]: One sorted ``int64`` frequency array per entry
            of ``self.observable_words``.

    Raises:
        NotImplementedError: If the model does not have exactly one input
            feature, or if any input gate uses a scaling different from 1.
    """
    if len(self.features) != 1:
        raise NotImplementedError(
            "The 'dp' support method currently supports exactly one input "
            "feature; use method='tree' for multi-feature models."
        )

    if self.all_input_indices and np.any(
        self.input_scaling[self.all_input_indices] != 1
    ):
        raise NotImplementedError(
            "The 'dp' support method does not support non-unit input "
            "frequency scaling (it aggregates sin/cos counts and cannot "
            "represent per-gate scalings); use method='tree'."
        )

    n = self.n_qubits
    # Boolean flag per rotation: True where the rotation encodes the input.
    is_input = np.zeros(self.n_params, dtype=bool)
    is_input[self.all_input_indices] = True
    n_inp = int(is_input.sum())
    # Both s and c are at most n_inp, so n_inp + 1 slots per s-row suffice.
    stride = n_inp + 1  # bit index for (s, c) is s * stride + c

    def encode(word: PauliWord) -> Tuple[int, int]:
        # Pack the per-qubit x / z flags of a Pauli word into two integers
        # (bit q of each integer belongs to qubit q).
        x = z = 0
        for q in range(n):
            x |= int(word.x[q]) << q
            z |= int(word.z[q]) << q
        return x, z

    paulis = [encode(w) for w in self.pauli_words]
    # cum_xy[i] is the OR of the x-masks of rotations 0..i, i.e. the qubits
    # on which any of those rotations carries an x component.
    cum_xy = []
    running = 0
    for xp, _ in paulis:
        running |= xp
        cum_xy.append(running)

    def parity(v: int) -> int:
        # Parity (0 or 1) of the number of set bits in v.
        return bin(v).count("1") & 1

    def dp(idx: int, xo: int, zo: int, memo: dict) -> int:
        # Returns the bitmask of reachable (s, c) pairs for the observable
        # (xo, zo) when only rotations 0..idx remain to be processed.
        # Light-cone early stopping (cf. _early_stopping_possible).
        # The observable has an x component on a qubit that none of the
        # remaining rotations touches with an x component: empty set.
        if idx >= 0 and (xo & ~cum_xy[idx]):
            return 0
        # Skip trailing rotations that commute with the observable.
        while idx >= 0:
            xp, zp = paulis[idx]
            # Odd combined overlap parity ends the skip at this rotation.
            if parity(xo & zp) ^ parity(zo & xp):
                break
            idx -= 1
        else:  # leaf: counts (s=0, c=0) iff the observable is diagonal
            return 1 if xo == 0 else 0
        key = (idx, xo, zo)
        hit = memo.get(key)
        # Compare against None: 0 is a valid (empty-set) cached result.
        if hit is not None:
            return hit
        xp, zp = paulis[idx]
        # Cosine branch keeps the observable; sine branch XORs in the
        # rotation generator.
        cos_child = dp(idx - 1, xo, zo, memo)
        sin_child = dp(idx - 1, xo ^ xp, zo ^ zp, memo)
        if is_input[idx]:
            # Active input gate: cosine increments c, sine increments s.
            # Shifting by 1 moves every (s, c) bit to (s, c + 1); shifting
            # by `stride` moves it to (s + 1, c).
            val = (cos_child << 1) | (sin_child << stride)
        else:
            # Non-input gate: the counts are unchanged, merge both branches.
            val = cos_child | sin_child
        memo[key] = val
        return val

    # Recursion depth is bounded by the number of rotations.
    old_limit = sys.getrecursionlimit()
    sys.setrecursionlimit(max(old_limit, self.n_params + 1000))
    try:
        supports = []
        for obs in self.observable_words:
            # Fresh memo per observable.
            memo: dict = {}
            xo, zo = encode(obs)
            mask = dp(self.n_params - 1, xo, zo, memo)
            freqs: set = set()
            # Walk the set bits of the mask, lowest first.
            while mask:
                bit = mask & -mask
                i = bit.bit_length() - 1
                # Decode the bit index back into (s, c) and add the
                # frequencies of the corresponding sin/cos product.
                freqs |= self._expansion_support(i // stride, i % stride)
                mask ^= bit
            supports.append(np.array(sorted(freqs), dtype=np.int64))
    finally:
        # Always restore the interpreter-wide recursion limit.
        sys.setrecursionlimit(old_limit)
    return supports

907 

908 @staticmethod 

909 @lru_cache(maxsize=None) 

910 def _expansion_support(s: int, c: int) -> frozenset: 

911 r"""Frequencies with non-zero coefficient in :math:`\cos^c x (i\sin x)^s`. 

912 

913 Computed exactly with integer arithmetic via the polynomial 

914 :math:`(t - 1)^s (t + 1)^c` (with :math:`t = e^{2ix}` up to a shift); 

915 interior coefficients can vanish, e.g. :math:`\cos x \sin x` only 

916 contains :math:`\pm 2`. 

917 """ 

918 coeffs = [1] 

919 for _ in range(s): # multiply by (t - 1) 

920 new = [0] * (len(coeffs) + 1) 

921 for i, a in enumerate(coeffs): 

922 new[i + 1] += a 

923 new[i] -= a 

924 coeffs = new 

925 for _ in range(c): # multiply by (t + 1) 

926 new = [0] * (len(coeffs) + 1) 

927 for i, a in enumerate(coeffs): 

928 new[i + 1] += a 

929 new[i] += a 

930 coeffs = new 

931 m = s + c 

932 return frozenset(2 * k - m for k, a in enumerate(coeffs) if a != 0) 

933 

934 def _combine_roots( 

935 self, 

936 per_root_coeffs: List[jnp.ndarray], 

937 per_root_freqs: List[np.ndarray], 

938 force_mean: bool, 

939 ) -> Tuple[List[jnp.ndarray], List[jnp.ndarray]]: 

940 """Assemble the per-root spectra, optionally averaging over roots.""" 

941 if not force_mean: 

942 coefficients = [jnp.asarray(c) for c in per_root_coeffs] 

943 frequencies = [jnp.asarray(f) for f in per_root_freqs] 

944 return coefficients, frequencies 

945 

946 # Average over roots on the union of all frequency vectors. 

947 accum: Dict[tuple, complex] = defaultdict(complex) 

948 for coeffs, freqs in zip(per_root_coeffs, per_root_freqs): 

949 freqs_np = np.asarray(freqs) 

950 for k in range(freqs_np.shape[0]): 

951 key = ( 

952 (int(freqs_np[k]),) 

953 if freqs_np.ndim == 1 

954 else tuple(int(v) for v in freqs_np[k]) 

955 ) 

956 accum[key] += complex(coeffs[k]) 

957 n_roots = max(len(per_root_coeffs), 1) 

958 keys = sorted(accum.keys()) 

959 mean_coeffs = jnp.array([accum[k] / n_roots for k in keys]) 

960 freq_arr = np.array(keys, dtype=np.int64) 

961 if freq_arr.shape[1] == 1: 

962 freq_arr = freq_arr[:, 0] 

963 return [mean_coeffs], [jnp.asarray(freq_arr)] 

964 

965 

class FCC:
    """Helpers to build the Fourier fingerprint and the FCC of a model.

    The fingerprint is a correlation matrix between the Fourier
    coefficients of a model, sampled over several parameter sets; the FCC
    is the mean absolute value of that matrix.
    """

    @classmethod
    def get_fcc(
        cls,
        model: Model,
        n_samples: int,
        random_key: Optional[random.PRNGKey] = None,
        method: Optional[str] = "pearson",
        scale: Optional[bool] = False,
        weight: Optional[bool] = False,
        trim_redundant: Optional[bool] = True,
        **kwargs,
    ) -> float:
        """
        Shortcut method to get just the FCC.
        This includes
        1. What is done in `get_fourier_fingerprint`:
            1. Calculating the coefficients (using `n_samples`)
            2. Correlating the result from 1) using `method`
            3. Weighting the correlation matrix (if `weight` is True)
            4. Remove redundancies (if `trim_redundant` is True)
        2. What is done in `calculate_fcc`:
            1. Absolute of the fingerprint
            2. Average

        Args:
            model (Model): The QFM model
            n_samples (int): Number of samples to calculate average of coefficients
            random_key (Optional[random.PRNGKey]): JAX random key for parameter
                initialization. If None, uses the model's internal random key.
            method (Optional[str], optional): Correlation method. Supported values are
                "pearson", "complex_pearson", "spearman", and "covariance".
                Defaults to "pearson".
            scale (Optional[bool], optional): Whether to scale the number of samples.
                Defaults to False.
            weight (Optional[bool], optional): Whether to weight the correlation matrix.
                Defaults to False.
            trim_redundant (Optional[bool], optional): Whether to remove redundant
                correlations. Defaults to True.
            **kwargs (Any): Additional keyword arguments for the model function.

        Returns:
            float: The FCC
        """

        # Memory-efficient fast path: correlate only the non-negative
        # frequency rows and average the strict lower triangle without
        # materialising the trimmed fingerprint.
        if trim_redundant and not weight:
            _, coeffs, freqs = cls._calculate_coefficients(
                model, n_samples, random_key, scale, **kwargs
            )
            pos_idx = cls._calculate_mask(freqs)
            coeffs_flat = coeffs.reshape(-1, coeffs.shape[-1])
            coeffs_sub = coeffs_flat[pos_idx]

            fp = cls._correlate(coeffs_sub.transpose(), method=method)
            abs_fp = jnp.abs(fp)
            diag = jnp.abs(jnp.diagonal(fp))

            total_sum = jnp.nansum(abs_fp)
            total_count = jnp.sum(jnp.isfinite(abs_fp))
            diag_sum = jnp.nansum(diag)
            diag_count = jnp.sum(jnp.isfinite(diag))

            # |fp| is symmetric, so the strict lower triangle holds half of
            # the off-diagonal sum and half of the off-diagonal count.
            lower_sum = (total_sum - diag_sum) / 2.0
            lower_count = (total_count - diag_count) / 2.0
            return lower_sum / lower_count

        fourier_fingerprint, _ = cls.get_fourier_fingerprint(
            model,
            n_samples,
            random_key,
            method,
            scale,
            weight,
            trim_redundant=trim_redundant,
            **kwargs,
        )

        return cls.calculate_fcc(fourier_fingerprint)

    @classmethod
    def get_fourier_fingerprint(
        cls,
        model: Model,
        n_samples: int,
        random_key: Optional[random.PRNGKey] = None,
        method: Optional[str] = "pearson",
        scale: Optional[bool] = False,
        weight: Optional[bool] = False,
        trim_redundant: Optional[bool] = True,
        nan_to_one: Optional[bool] = False,
        **kwargs: Any,
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """
        Shortcut method to get just the fourier fingerprint.
        This includes
        1. Calculating the coefficients (using `n_samples`)
        2. Correlating the result from 1) using `method`
        3. Weighting the correlation matrix (if `weight` is True)
        4. Remove redundancies (if `trim_redundant` is True)

        Args:
            model (Model): The QFM model
            n_samples (int): Number of samples to calculate average of coefficients
            random_key (Optional[random.PRNGKey]): JAX random key for parameter
                initialization. If None, uses the model's internal random key.
            method (Optional[str], optional): Correlation method. Supported values are
                "pearson", "complex_pearson", "spearman", and "covariance".
                Defaults to "pearson".
            scale (Optional[bool], optional): Whether to scale the number of samples.
                Defaults to False.
            weight (Optional[bool], optional): Whether to weight the correlation matrix.
                Defaults to False.
            trim_redundant (Optional[bool], optional): Whether to remove redundant
                correlations. Defaults to True.
            nan_to_one (Optional[bool], optional): Whether to set nan to 1.
                Defaults to False.
            **kwargs: Additional keyword arguments for the model function.

        Returns:
            Tuple[jnp.ndarray, jnp.ndarray]: The fourier fingerprint and the
            corresponding frequency indices. If `trim_redundant` is True the
            frequencies are returned as a `(row_freqs, col_freqs)` tuple that
            labels the two (redundancy-trimmed) matrix axes; otherwise the
            full frequency vector is returned.
        """
        _, coeffs, freqs = cls._calculate_coefficients(
            model, n_samples, random_key, scale, **kwargs
        )

        # Memory-efficient fast path: correlate only the rows that survive
        # the redundancy filter.
        if trim_redundant and not weight:
            pos_idx = cls._calculate_mask(freqs)
            pos_freqs = cls._flat_frequencies(freqs)[pos_idx]

            # Flatten all frequency axes; the last axis is the sample
            # axis. `_calculate_mask` returns flat indices in C order,
            # matching this reshape.
            coeffs_flat = coeffs.reshape(-1, coeffs.shape[-1])
            coeffs_sub = coeffs_flat[pos_idx]

            fourier_fingerprint = cls._correlate(coeffs_sub.transpose(), method=method)

            if nan_to_one:
                fourier_fingerprint = jnp.where(
                    jnp.isnan(fourier_fingerprint), 1.0, fourier_fingerprint
                )

            return cls._keep_strict_lower_triangle(fourier_fingerprint, pos_freqs)

        # NOTE(review): `coeffs.transpose()` reverses all axes, so for more
        # than one frequency axis this path flattens in the opposite axis
        # order to `_calculate_mask` / `_flat_frequencies` (C order of the
        # original axes) -- confirm this is intended for multi-dim inputs.
        fourier_fingerprint = cls._correlate(coeffs.transpose(), method=method)

        if nan_to_one:
            # JAX arrays are immutable, so in-place item assignment is not
            # available; build the replaced array functionally instead.
            fourier_fingerprint = jnp.where(
                jnp.isnan(fourier_fingerprint), 1.0, fourier_fingerprint
            )

        # perform weighting if requested
        if weight:
            fourier_fingerprint = cls._weighting_mean(fourier_fingerprint, coeffs)

        if trim_redundant:
            pos_idx = cls._calculate_mask(freqs)
            pos_freqs = cls._flat_frequencies(freqs)[pos_idx]

            # Restrict to the positive-frequency sub-block (M x M with
            # M = number of non-negative flat frequencies) instead of
            # building a full N x N mask.
            fourier_fingerprint = fourier_fingerprint[pos_idx][:, pos_idx]

            return cls._keep_strict_lower_triangle(fourier_fingerprint, pos_freqs)

        return fourier_fingerprint, freqs

    @classmethod
    def _keep_strict_lower_triangle(
        cls, fourier_fingerprint: jnp.ndarray, pos_freqs: jnp.ndarray
    ) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]:
        """
        Keep only the strict lower triangle of a square fingerprint and
        drop the rows/columns that are left without any finite entry.

        Args:
            fourier_fingerprint (jnp.ndarray): Square correlation matrix.
            pos_freqs (jnp.ndarray): Frequency label of each row/column.

        Returns:
            Tuple[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]: The trimmed
            matrix (entries outside the strict lower triangle are NaN) and
            the `(row_freqs, col_freqs)` labels of its two axes.
        """
        M = fourier_fingerprint.shape[0]
        # keep only the strict lower triangle; the rest -> nan
        lower_tri_mask = jnp.tri(M, k=-1, dtype=bool)
        fourier_fingerprint = jnp.where(lower_tri_mask, fourier_fingerprint, jnp.nan)

        row_mask = jnp.any(jnp.isfinite(fourier_fingerprint), axis=1)
        col_mask = jnp.any(jnp.isfinite(fourier_fingerprint), axis=0)
        fourier_fingerprint = fourier_fingerprint[row_mask][:, col_mask]

        return fourier_fingerprint, (pos_freqs[row_mask], pos_freqs[col_mask])

    @classmethod
    def calculate_fcc(
        cls,
        fourier_fingerprint: jnp.ndarray,
    ) -> float:
        """
        Method to calculate the FCC based on an existing correlation matrix.
        Calculate absolute and then the average over this matrix, ignoring
        NaN entries.
        The Fingerprint can be obtained via `get_fourier_fingerprint`

        Args:
            fourier_fingerprint (jnp.ndarray): Correlation matrix of coefficients
        Returns:
            float: The FCC
        """
        return jnp.nanmean(jnp.abs(fourier_fingerprint))

    @classmethod
    def _calculate_mask(cls, freqs: jnp.ndarray) -> jnp.ndarray:
        """
        Determine the flat indices of the Fourier correlation matrix
        that lie on a non-negative-frequency row/column. Together with
        the strict-lower-triangle condition (handled by the caller),
        these indices select the entries of the correlation matrix
        that survive the redundancy filter applied in
        `get_fourier_fingerprint`:

        - rows/columns whose flat frequency component is negative are
          discarded (they are the complex-conjugate redundancies of
          their positive counterparts);
        - of the remaining positive-frequency sub-block, only the
          strict lower triangle is kept (the upper triangle, including
          the diagonal, contains either duplicates from symmetry or
          self-correlations).

        Args:
            freqs (jnp.ndarray): Array of frequencies. Either a 1-D
                vector (single input feature) or a 2-D array of shape
                ``(n_input_feat, K)`` whose rows are the per-axis
                frequency vectors.

        Returns:
            jnp.ndarray: 1-D int array of flat indices selecting the
                non-negative-frequency rows/cols of the fingerprint.
        """
        freqs_arr = jnp.asarray(freqs)

        if freqs_arr.ndim == 1:
            pos_flat = freqs_arr >= 0
        else:
            # N-D case: build the per-axis non-negativity masks and
            # combine them via broadcasting (no float `jnp.outer`!),
            # then flatten in row-major (C) order.
            axes_pos = [freqs_arr[i] >= 0 for i in range(freqs_arr.shape[0])]
            expanded = []
            n_axes = len(axes_pos)
            for i, p in enumerate(axes_pos):
                shape = [1] * n_axes
                shape[i] = p.shape[0]
                expanded.append(p.reshape(shape))
            nd_pos = reduce(jnp.logical_and, expanded)
            pos_flat = nd_pos.flatten()

        return jnp.where(pos_flat)[0]

    @classmethod
    def _flat_frequencies(cls, freqs: jnp.ndarray) -> jnp.ndarray:
        """
        Build the per-coefficient flat frequency labels in the same
        C-order used by `_calculate_mask`, so they can be indexed by the
        flat indices it returns.

        Args:
            freqs (jnp.ndarray): Either a 1-D vector (single input feature)
                or a ``(n_input_feat, K)`` stack / list of per-axis frequency
                vectors (multi-dim input).

        Returns:
            jnp.ndarray: 1-D frequency vector (single input feature) or a
                ``(N, n_input_feat)`` array of per-coefficient frequency
                tuples (multi-dim input).
        """
        fa = jnp.asarray(freqs)
        if fa.ndim == 1:
            return fa
        # Multi-dim: per-axis vectors -> flat grid of frequency tuples in the
        # same C-order used by `_calculate_mask` and the coefficient reshape.
        grids = jnp.meshgrid(*[fa[i] for i in range(fa.shape[0])], indexing="ij")
        return jnp.stack(grids, axis=-1).reshape(-1, fa.shape[0])

    @classmethod
    def _calculate_coefficients(
        cls,
        model: Model,
        n_samples: int,
        random_key: Optional[random.PRNGKey] = None,
        scale: bool = False,
        **kwargs: Any,
    ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
        """
        Calculates the Fourier coefficients of a given model
        using `n_samples`.
        Optionally, `noise_params` can be passed to perform noisy simulation.

        If `n_samples` is not positive, the parameters of the model are left
        untouched and the spectrum is computed for its current parameters.

        Args:
            model (Model): The QFM model
            n_samples (int): Number of samples to calculate average of coefficients
            random_key (Optional[random.PRNGKey]): JAX random key for parameter
                initialization. If None, uses the model's internal random key.
            scale (bool, optional): Whether to scale the number of samples
                by `2**n_qubits * n_input_feat`. Defaults to False.
            **kwargs: Additional keyword arguments for the model function.

        Returns:
            Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: The model
            parameters, the coefficients, and the frequencies, as returned
            by `Coefficients.get_spectrum(model, shift=True, trim=True)`.
        """
        if n_samples > 0:
            if scale:
                total_samples = int(
                    jnp.power(2, model.n_qubits) * n_samples * model.n_input_feat
                )
                log.info(f"Using {total_samples} samples.")
            else:
                total_samples = n_samples
            model.initialize_params(random_key, repeat=total_samples)

        coeffs, freqs = Coefficients.get_spectrum(
            model, shift=True, trim=True, **kwargs
        )

        return model.params, coeffs, freqs

    @classmethod
    def _correlate(cls, mat: jnp.ndarray, method: str = "pearson") -> jnp.ndarray:
        """
        Correlates the columns of `mat` using `method`.
        Currently, `pearson`, `complex_pearson`, `spearman`, and `covariance`
        are supported.

        Args:
            mat (jnp.ndarray): Array whose first axis is the sample axis;
                all remaining axes are flattened into the column axis,
                giving shape (N, K).
            method (str, optional): Correlation method. Defaults to "pearson".

        Raises:
            ValueError: If the method is not supported.

        Returns:
            jnp.ndarray: (K, K) correlation matrix between the columns.
        """
        assert len(mat.shape) >= 2, "Input matrix must have at least 2 dimensions"

        if method == "pearson":
            correlator = cls._pearson
        elif method == "complex_pearson":
            correlator = cls._complex_pearson
        elif method == "spearman":
            correlator = cls._spearman
        elif method == "covariance":
            correlator = cls._covariance
        else:
            # Adjacent literals instead of a backslash continuation, which
            # would embed the source indentation in the message.
            raise ValueError(
                f"Unknown correlation method: {method}. Must be 'pearson', "
                "'complex_pearson', 'spearman' or 'covariance'."
            )

        # For the general n-D case all axes after the first (sample) axis
        # are flattened in C order; the resulting column order is what the
        # callers rely on when they filter out negative frequencies later.
        return correlator(mat.reshape(mat.shape[0], -1))

    @classmethod
    def _covariance(cls, mat: jnp.ndarray, minp: Optional[int] = 1) -> jnp.ndarray:
        """
        Compute the Hermitian sample covariance between columns of `mat`,
        permitting missing values (NaN or ±Inf).

        For each pair (i, j) the covariance is computed over the rows that are
        finite in both columns, as
            sum(conj(x_i - mean_i) * (x_j - mean_j)) / (nobs - 1),
        so it computes `X.conj().T @ X`.
        Real input collapses to the ordinary real sample covariance; complex
        input yields a complex matrix whose magnitude and angle carry the
        covariance strength and relative phase.

        Args:
            mat : array_like, shape (N, K)
                Input data.
            minp : int, optional
                Minimum number of paired observations required to form a
                covariance. If the number of valid pairs for (i, j) is < minp,
                the result is NaN.

        Returns:
            cov : ndarray, shape (K, K)
                Sample covariance matrix. Entries with fewer than two
                paired observations are NaN.
        """
        mat = jnp.asarray(mat)
        real_dtype = jnp.asarray(mat.real).dtype

        # Validity mask; invalid entries are zeroed so they do not
        # contribute to any of the sums below.
        mask = jnp.isfinite(mat)
        fmask = mask.astype(real_dtype)
        safe = jnp.where(mask, mat, 0.0)

        # nobs[i, j]: number of rows that are finite in both columns.
        nobs = fmask.T @ fmask
        nobs_safe = jnp.where(nobs > 0, nobs, 1.0)

        # sum_x[i, j]: sum of column i over rows where column j is valid;
        # sum_y[i, j]: sum of column j over rows where column i is valid.
        sum_x = safe.T @ fmask
        sum_y = fmask.T @ safe

        masked = safe * fmask
        sum_conj_xy = jnp.conj(masked).T @ masked

        sxy = sum_conj_xy - (jnp.conj(sum_x) * sum_y) / nobs_safe

        # Unbiased estimator; undefined (NaN) for fewer than two pairs.
        denom = jnp.where(nobs > 1, nobs - 1, jnp.nan)
        result = sxy / denom

        result = jnp.where(nobs < minp, jnp.nan, result)

        return result

    @classmethod
    def _complex_pearson(
        cls, mat: jnp.ndarray, minp: Optional[int] = 1
    ) -> jnp.ndarray:
        """
        Compute the complex Pearson correlation between columns of `mat`,
        permitting missing values (NaN or ±Inf).

        This uses the Hermitian normalized covariance
            sum(conj(x_i - mean_i) * (x_j - mean_j)) /
            sqrt(sum(abs(x_i - mean_i)**2) * sum(abs(x_j - mean_j)**2)).
        Consequently, if column j is exp(1j * phi) times column i, then
        abs(corr[i, j]) is 1 and angle(corr[i, j]) is phi.

        Args:
            mat : array_like, shape (N, K)
                Input data.
            minp : int, optional
                Minimum number of paired observations required to form a correlation.
                If the number of valid pairs for (i, j) is < minp, the result is NaN.

        Returns:
            corr : ndarray, shape (K, K)
                Complex Pearson correlation matrix. Entries with a zero
                denominator are NaN.
        """
        mat = jnp.asarray(mat)
        real_dtype = jnp.asarray(mat.real).dtype

        # Validity mask; invalid entries are zeroed so they do not
        # contribute to any of the sums below.
        mask = jnp.isfinite(mat)
        fmask = mask.astype(real_dtype)
        safe = jnp.where(mask, mat, 0.0)

        # nobs[i, j]: number of rows that are finite in both columns.
        nobs = fmask.T @ fmask
        nobs_safe = jnp.where(nobs > 0, nobs, 1.0)

        sum_x = safe.T @ fmask
        sum_y = fmask.T @ safe

        masked = safe * fmask
        sum_conj_xy = jnp.conj(masked).T @ masked

        safe_abs_sq = jnp.abs(safe) ** 2
        sum_abs_x2 = safe_abs_sq.T @ fmask
        sum_abs_y2 = fmask.T @ safe_abs_sq

        # Pairwise centred sums of squares and cross products.
        ssx = sum_abs_x2 - jnp.abs(sum_x) ** 2 / nobs_safe
        ssy = sum_abs_y2 - jnp.abs(sum_y) ** 2 / nobs_safe
        sxy = sum_conj_xy - (jnp.conj(sum_x) * sum_y) / nobs_safe

        denom = jnp.sqrt(ssx * ssy)
        result = jnp.where(denom > 0, sxy / denom, jnp.nan)
        # Rescale numerical drift back onto the unit circle, keeping the phase.
        magnitude = jnp.abs(result)
        result = jnp.where(magnitude > 1.0, result / magnitude, result)

        result = jnp.where(nobs < minp, jnp.nan, result)

        return result

    @classmethod
    def _pearson(cls, mat: jnp.ndarray, minp: Optional[int] = 1) -> jnp.ndarray:
        """
        Compute Pearson correlation between columns of `mat`,
        permitting missing values (NaN or ±Inf).

        The Pearson correlation is the normalized covariance,
            corr[i, j] = cov[i, j] / sqrt(cov[i, i] * cov[j, j]),
        so it is obtained by normalizing `_covariance` by the per-column
        standard deviations.

        If the input is complex, real and imaginary parts are stacked along
        the sample axis so that both components contribute to the correlation
        without discarding information.

        Args:
            mat : array_like, shape (N, K)
                Input data.
            minp : int, optional
                Minimum number of paired observations required to form a correlation.
                If the number of valid pairs for (i, j) is < minp, the result is NaN.

        Returns:
            corr : ndarray, shape (K, K)
                Pearson correlation matrix. Entries with a zero
                denominator are NaN.
        """
        # Preserve complex information by splitting into real / imag samples.
        # After stacking the data is real, so the Hermitian `_covariance`
        # reduces to the ordinary real sample covariance.
        if jnp.iscomplexobj(mat):
            mat = jnp.concatenate([mat.real, mat.imag], axis=0)

        cov = cls._covariance(mat, minp=minp)

        # corr[i, j] = cov[i, j] / (std_i * std_j) with std_i = sqrt(cov[i, i])
        std = jnp.sqrt(jnp.diagonal(cov))
        denom = std[:, None] * std[None, :]
        result = jnp.where(denom > 0, cov / denom, jnp.nan)

        # clip numerical drift to [-1, 1]
        result = jnp.clip(jnp.real(result), -1.0, 1.0)

        return result

    @classmethod
    def _spearman(cls, mat: jnp.ndarray, minp: Optional[int] = 1) -> jnp.ndarray:
        """
        Based on Pandas correlation method as implemented here:
        https://github.com/pandas-dev/pandas/blob/main/pandas/_libs/algos.pyx

        Compute Spearman correlation between columns of `mat`,
        permitting missing values (NaN or ±Inf).

        If the input is complex, real and imaginary parts are stacked along
        the sample axis so that both components contribute to the correlation
        without discarding information.

        Args:
            mat : array_like, shape (N, K)
                Input data.
            minp : int, optional
                Minimum number of paired observations required to form a correlation.
                If the number of valid pairs for (i, j) is < minp, the result is NaN.

        Returns:
            corr : ndarray, shape (K, K)
                Spearman correlation matrix.
        """
        # Preserve complex information by splitting into real / imag samples
        if jnp.iscomplexobj(mat):
            mat = jnp.concatenate([mat.real, mat.imag], axis=0)

        mat = jnp.asarray(mat)
        N, K = mat.shape

        # trivial all-NaN answer if too few rows
        if N < minp:
            return jnp.full((K, K), jnp.nan)

        # Ranking runs on the host through SciPy, so convert once instead
        # of indexing the JAX array column by column.
        values = np.asarray(mat)
        finite = np.isfinite(values)  # shape (N, K), dtype=bool

        # precompute ranks column-wise ignoring non-finite entries
        ranks = np.full((N, K), np.nan)
        for j in range(K):
            valid = finite[:, j]
            if valid.any():
                ranks[valid, j] = rankdata(values[valid, j], method="average")

        ranks = jnp.asarray(ranks)

        # Vectorised Pearson on the ranks
        # Replace NaN ranks with 0; use mask to track validity.
        rank_mask = jnp.isfinite(ranks)
        safe_ranks = jnp.where(rank_mask, ranks, 0.0)

        # Pairwise valid-observation counts (K, K)
        fmask = rank_mask.astype(ranks.dtype)
        nobs = fmask.T @ fmask

        # Pairwise sums over mutually-valid rows
        sum_x = safe_ranks.T @ fmask  # (K, K)
        sum_y = fmask.T @ safe_ranks  # (K, K)

        # Pairwise products
        masked_ranks = safe_ranks * fmask  # same as safe_ranks
        sum_xy = masked_ranks.T @ masked_ranks  # (K, K)

        safe_sq = safe_ranks**2
        sum_x2 = safe_sq.T @ fmask  # (K, K)
        sum_y2 = fmask.T @ safe_sq  # (K, K)

        nobs_safe = jnp.where(nobs > 0, nobs, 1.0)
        ssx = sum_x2 - sum_x**2 / nobs_safe
        ssy = sum_y2 - sum_y**2 / nobs_safe
        sxy = sum_xy - (sum_x * sum_y) / nobs_safe

        denom = jnp.sqrt(ssx * ssy)
        result = jnp.where(denom > 0, sxy / denom, jnp.nan)
        result = jnp.clip(result, -1.0, 1.0)

        # Enforce minp
        result = jnp.where(nobs < minp, jnp.nan, result)

        return result

    @classmethod
    def _weighting_linear(cls, fourier_fingerprint: jnp.ndarray) -> jnp.ndarray:
        """
        Performs weighting on the given correlation matrix.
        Here, low-frequent coefficients are weighted more heavily.

        Args:
            fourier_fingerprint (jnp.ndarray): Square correlation matrix
                with odd side length.

        Returns:
            jnp.ndarray: The correlation matrix multiplied element-wise by
                the weight matrix described below.
        """
        assert (
            fourier_fingerprint.shape[0] % 2 != 0
            and fourier_fingerprint.shape[1] % 2 != 0
        ), (
            "Correlation matrix must have odd dimensions. "
            "Hint: use `trim` argument when calling `get_spectrum`."
        )
        assert fourier_fingerprint.shape[0] == fourier_fingerprint.shape[1], (
            "Correlation matrix must be square."
        )

        # The weight matrix is a "tent" sum along the two axes.
        # Concretely, with N = fourier_fingerprint.shape[0] (odd) and
        # center = N // 2,
        #     W[i, j] = u[i] + u[j]
        # where u[k] = (center - |k - center|) / (2 * center)
        # is a triangular weighting peaking at the centre (the zero
        # frequency) and decaying linearly to 0 at the spectrum edges.
        N = fourier_fingerprint.shape[0]
        center = N // 2
        if center == 0:
            # 1 x 1 matrix: the only entry is the centre, whose tent value
            # is the peak 1/2 (total weight 1). Evaluating the formula
            # would divide 0 by 0 and return NaN.
            u = jnp.full((1,), 0.5)
        else:
            k = jnp.arange(N)
            u = (center - jnp.abs(k - center)) / (2 * center)

        return fourier_fingerprint * (u[:, None] + u[None, :])

    @classmethod
    def _weighting_mean(
        cls, fourier_fingerprint: jnp.ndarray, coeffs: jnp.ndarray
    ) -> jnp.ndarray:
        """
        Performs weighting on the given correlation matrix.
        Here, we use the product of the mean of the coefficients as weights.
        This suppresses correlations where the mean of the coefficients is near zero.

        Args:
            fourier_fingerprint (jnp.ndarray): Square correlation matrix
            coeffs (jnp.ndarray): Fourier coefficients; the last axis is the
                sample axis that is averaged over.

        Returns:
            jnp.ndarray: The correlation matrix with entry (i, j) multiplied
                by `w[i] * w[j]`, where `w` is the absolute mean coefficient.
        """
        assert fourier_fingerprint.shape[0] == fourier_fingerprint.shape[1], (
            "Correlation matrix must be square."
        )
        assert len(coeffs.shape) >= 2, (
            "Coefficient matrix must contain coefficient axes and a sample axis."
        )

        coefficient_means = jnp.abs(jnp.mean(coeffs, axis=-1))
        # Reverse the axes before flattening, mirroring the
        # `coeffs.transpose()` that feeds `_correlate` in the caller.
        coefficient_means = coefficient_means.T.reshape(-1)

        assert fourier_fingerprint.shape[0] == coefficient_means.shape[0], (
            "Correlation matrix size must match the number of Fourier coefficients."
        )

        # Apply the rank-1 weight w[i] * w[j] via broadcasting instead
        # of materialising an explicit `jnp.outer` N x N intermediate.
        return (
            fourier_fingerprint
            * coefficient_means[:, None]
            * coefficient_means[None, :]
        )

1650 

1651 

1652class Datasets: 

1653 @classmethod 

1654 def generate_fourier_series( 

1655 cls, 

1656 random_key: random.PRNGKey, 

1657 model: Model, 

1658 coefficients_min: float = 0.0, 

1659 coefficients_max: float = 1.0, 

1660 zero_centered: bool = False, 

1661 ) -> jnp.ndarray: 

1662 """ 

1663 Generates the Fourier series representation of a function. 

1664 It uses the `model.frequencies` property to retrieve the frequency 

1665 information. This ensures that the resulting Fourier series is 

1666 compatible with the model. 

1667 

1668 This function is capable of generating $D$-dimensional Fourier series 

1669 (again defined by `model.n_input_feat`). 

1670 The highest frequency $N$ is retrieved per dimension. 

1671 

1672 Samples of the Fourier coefficients are drawn from a uniform circle. 

1673 

1674 Args: 

1675 random_key (random.PRNGKey): Random number key for JAX. 

1676 model (Model): The quantum circuit model. 

1677 coefficients_min (float, optional): Minimum value for the coefficients. 

1678 Defaults to 0.0. 

1679 coefficients_max (float, optional): Maximum value for the coefficients. 

1680 Defaults to 1.0. 

1681 zero_centered (bool, optional): Whether to zero-center the coefficients. 

1682 Defaults to False. 

1683 

1684 Returns: 

1685 jnp.ndarray: Input domain samples with shape ((N,)*D, D) 

1686 jnp.ndarray: Fourier series values with shape ((N,)*D) 

1687 jnp.ndarray: Fourier coefficients with shape ((N,)*D) 

1688 

1689 """ 

1690 # TODO: the following code can be considered to 

1691 # capturing a truly random spectrum. 

1692 # add some constraints on the spectrum, i.e. not fully 

1693 

1694 # Note: one key observation for understanding the following code is, 

1695 # that instead of wrapping your head around symmetries in multi- 

1696 # dimensional coefficient matrices, one can simply look at the flattened 

1697 # version of such a matrix and reshape later. It just works out. 

1698 

1699 # going from [0, 2pi] with the resolution required for highest frequency 

1700 # permute with input dimensionality to get an n-d grid of domain samples 

1701 # the output shape comes from the fact that want to create a "coordinate system" 

1702 domain_samples_per_input_dim = jnp.stack( 

1703 jnp.meshgrid( 

1704 *[jnp.arange(0, 2 * jnp.pi, 2 * jnp.pi / d) for d in model.degree] 

1705 ) 

1706 ).T.reshape(-1, model.n_input_feat) 

1707 

1708 # generate the frequency indices for each dimension. 

1709 # this will have the same shape as the domain samples 

1710 frequencies = jnp.stack(jnp.meshgrid(*model.frequencies)).T.reshape( 

1711 -1, model.n_input_feat 

1712 ) 

1713 

1714 # using the frequency information, sample coefficients for each dimension 

1715 # shape: (input_dims, n_freqs_per_input_dim // 2 + 1) 

1716 

1717 coefficients = cls.uniform_circle( 

1718 random_key, 

1719 low=coefficients_min, 

1720 high=coefficients_max, 

1721 size=math.prod(model.degree) // 2 + 1, 

1722 ) 

1723 

1724 # zero center (first coeff = 0) 

1725 # we can assume the first coeff is the offset, because we're dealing 

1726 # with a non-symmetric spectrum here 

1727 if zero_centered: 

1728 coefficients = coefficients.at[0].set(0.0) 

1729 else: 

1730 coefficients = coefficients.at[0].set(coefficients[0].real) 

1731 

1732 # ensure symmetry (here, non_negative_ is removed!), 

1733 # giving us the full coefficients vector 

1734 coefficients = jnp.concat( 

1735 [ 

1736 jnp.flip(coefficients[..., 1:]).conjugate(), 

1737 coefficients, 

1738 ], 

1739 axis=-1, 

1740 ) 

1741 

1742 # Vectorized version of $f(x) = \sum_{n=0}^{N-1} c_n * e^{i * \omega_n * x}$ 

1743 # it takes into account the input dimension, i.e. the output is a matrix 

1744 # normalization uses the n_freqs component of the coefficients 

1745 values = jnp.real( 

1746 ( 

1747 jnp.exp(1j * (domain_samples_per_input_dim @ frequencies.T)) 

1748 * coefficients 

1749 ).sum(axis=1) 

1750 / coefficients.size 

1751 ) 

1752 

1753 # return all the information we have 

1754 return [ 

1755 domain_samples_per_input_dim.reshape(*model.degree, -1), 

1756 values.reshape(model.degree), 

1757 coefficients.reshape(model.degree), 

1758 ] 

1759 

1760 @classmethod 

1761 def uniform_circle( 

1762 cls, 

1763 random_key: random.PRNGKey, 

1764 size: Union[jnp.ndarray, List, int], 

1765 low=0.0, 

1766 high=1.0, 

1767 ): 

1768 """ 

1769 Random number generator for complex numbers sampled inside the unit circle 

1770 

1771 Args: 

1772 random_key (random.PRNGKey): Random number key for JAX. 

1773 size (Union[jnp.ndarray, int]): Number of samples. If a 2D array is passed, 

1774 the first dimension will be the number of dimensions. 

1775 low (float, optional): Minimum Radius. Defaults to 0.0. 

1776 high (float, optional): Maximum Radius. Defaults to 1.0. 

1777 

1778 Returns 

1779 jnp.ndarray: Array of complex numbers with shape of `size` 

1780 """ 

1781 

1782 if isinstance(size, int): 

1783 size = jnp.array([size]) 

1784 

1785 random_key, random_key1 = random.split(random_key) 

1786 return jnp.sqrt( 

1787 random.uniform(random_key, size, minval=low, maxval=high) 

1788 ) * jnp.exp(2j * jnp.pi * random.uniform(random_key1, size))