Coverage for qml_essentials / coefficients.py: 95%

504 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-06-01 13:52 +0000

1from __future__ import annotations 

2import math 

3from collections import defaultdict 

4from dataclasses import dataclass 

5import jax.numpy as jnp 

6from jax import random 

7import numpy as np 

8from scipy.stats import rankdata 

9from functools import reduce 

10from typing import List, Tuple, Optional, Any, Dict, Union 

11 

12from qml_essentials.model import Model 

13from qml_essentials.utils import PauliCircuit 

14from qml_essentials.operations import ( 

15 Operation, 

16 PauliX, 

17 PauliY, 

18 PauliZ, 

19) 

20 

21import logging 

22 

# Module-wide logger, registered under this module's import name.
log = logging.getLogger(name=__name__)

24 

25 

class Coefficients:
    """Numerical (FFT-based) Fourier spectrum utilities for models."""

    @classmethod
    def get_spectrum(
        cls,
        model: Model,
        mfs: int = 1,
        mts: int = 1,
        shift=False,
        trim=False,
        numerical_cap: Optional[float] = -1,
        **kwargs,
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """
        Extracts the coefficients of a given model using a FFT (jnp-fft).

        Note that the coefficients are complex numbers, but the imaginary part
        of their sum should be very close to zero, since the expectation
        values of the Pauli operators are real numbers.

        It can perform oversampling in both the frequency and time domain
        using the `mfs` and `mts` arguments.

        Args:
            model (Model): The model to sample. Must expose `n_input_feat`,
                `degree` (indexable per input feature) and be callable as
                `model(inputs=..., **kwargs)`.
            mfs (int): Multiplicator for the highest frequency. Default is 1.
            mts (int): Multiplicator for the number of time samples. Default is 1.
            shift (bool): Whether to apply fftshift (zero frequency centred),
                per input axis. Default is False.
            trim (bool): Whether to remove the Nyquist frequency along every
                input axis that has an even number of samples. Default is False.
            numerical_cap (Optional[float]): Coefficients whose magnitude is
                below this value are set to zero. Disabled when None or <= 0.
                Default is -1 (disabled).
            kwargs (Any): Additional keyword arguments for the model function.
                `force_mean` defaults to True, `execution_type` to "expval".

        Returns:
            Tuple[jnp.ndarray, jnp.ndarray]: Tuple containing the coefficients
            and frequencies. With a single input feature the frequencies are
            one array; otherwise a list with one frequency array per input axis.

        Raises:
            ValueError: If the sum of the coefficients has a non-negligible
                imaginary part.
        """
        kwargs.setdefault("force_mean", True)
        kwargs.setdefault("execution_type", "expval")

        coeffs, freqs = cls._fourier_transform(model, mfs=mfs, mts=mts, **kwargs)
        # One frequency vector per input axis; copied so it can be edited per axis
        freqs = list(freqs)
        n_axes = model.n_input_feat

        imag_sum = jnp.sum(coeffs).imag
        if not jnp.isclose(imag_sum, 0.0, rtol=1.0e-5):
            raise ValueError(
                f"Spectrum is not real. Imaginary part of coefficients is: {imag_sum}"
            )

        if trim:
            for ax in range(n_axes):
                if coeffs.shape[ax] % 2 == 0:
                    # In (unshifted) FFT order the Nyquist bin is at index N // 2
                    # of *this* axis; only this axis' frequency vector is trimmed.
                    coeffs = np.delete(coeffs, coeffs.shape[ax] // 2, axis=ax)
                    freqs[ax] = np.delete(freqs[ax], len(freqs[ax]) // 2)

        if shift:
            coeffs = jnp.fft.fftshift(coeffs, axes=list(range(n_axes)))
            # Shift each axis' frequency vector on its own (lengths may differ)
            freqs = [np.fft.fftshift(freq) for freq in freqs]

        if numerical_cap is not None and numerical_cap > 0:
            # set coeffs below threshold to zero
            coeffs = jnp.where(
                jnp.abs(coeffs) < numerical_cap,
                jnp.zeros_like(coeffs),
                coeffs,
            )

        if len(freqs) == 1:
            freqs = freqs[0]

        return coeffs, freqs

    @classmethod
    def _fourier_transform(
        cls, model: Model, mfs: int, mts: int, **kwargs: Any
    ) -> Tuple[jnp.ndarray, List[jnp.ndarray]]:
        """
        Samples the model on a regular grid and computes its normalised FFT.

        Along input axis ``i`` the model is evaluated at
        ``int(mts * mfs * degree[i])`` equidistant points in
        ``[0, 2 * mts * pi)``.

        Args:
            model (Model): The model to sample.
            mfs (int): Multiplicator for the highest frequency.
            mts (int): Multiplicator for the number of time samples.
            kwargs (Any): Forwarded to the model call.

        Returns:
            Tuple[jnp.ndarray, List[jnp.ndarray]]: FFT coefficients divided by
            the number of grid points, and one `fftfreq` vector per input axis.
        """
        n_axes = model.n_input_feat
        # Frequencies per axis: model degree, oversampled by mfs
        n_freqs = [mfs * model.degree[i] for i in range(n_axes)]
        n_samples = [int(mts * n_freqs[i]) for i in range(n_axes)]

        # linspace (rather than a float-step arange) guarantees exactly
        # n_samples[i] points, i.e. the same length as the fftfreq vectors.
        inputs: List = [
            jnp.linspace(0.0, 2 * mts * jnp.pi, n_samples[i], endpoint=False)
            for i in range(n_axes)
        ]

        # Cartesian grid in "ij" order: row r of nd_inputs corresponds to the
        # C-order multi-index of the reshaped outputs below (any dimension).
        grid = jnp.meshgrid(*inputs, indexing="ij")
        nd_inputs = jnp.stack(grid, axis=-1).reshape(-1, n_axes)

        # Output vector is not necessarily the same length as input
        outputs = model(inputs=nd_inputs, **kwargs)
        outputs = outputs.reshape(*[x.shape[0] for x in inputs], -1)
        if outputs.shape[-1] == 1:
            # Drop only the singleton output axis; input axes are kept even
            # when they hold a single sample.
            outputs = outputs[..., 0]

        coeffs = jnp.fft.fftn(outputs, axes=list(range(n_axes)))

        freqs = [
            jnp.fft.fftfreq(n_samples[i], 1 / n_freqs[i]) for i in range(n_axes)
        ]

        # Normalise by the total number of grid points (product if multidim)
        return (
            coeffs / math.prod(outputs.shape[0:n_axes]),
            freqs,
        )

    @classmethod
    def get_psd(cls, coeffs: jnp.ndarray) -> jnp.ndarray:
        """
        Calculates the power spectral density (PSD) from given Fourier coefficients.

        Computed as ``2 / len(coeffs)**2 * |coeffs|**2``; note that
        ``len`` is the size of the first axis only.

        Args:
            coeffs (jnp.ndarray): The Fourier coefficients.

        Returns:
            jnp.ndarray: The power spectral density.
        """
        # TODO: if we apply trim=True in advance, this will be slightly wrong..

        def abs2(x):
            return x.real**2 + x.imag**2

        scale = 2.0 / (len(coeffs) ** 2)
        return scale * abs2(coeffs)

    @classmethod
    def evaluate_Fourier_series(
        cls,
        coefficients: jnp.ndarray,
        frequencies: jnp.ndarray,
        inputs: Union[jnp.ndarray, list, float],
    ) -> float:
        """
        Evaluate the function value of a Fourier series at one point.

        Args:
            coefficients (jnp.ndarray): Coefficients of the Fourier series.
            frequencies (jnp.ndarray): Corresponding frequencies; a list of
                per-axis arrays for multi-dimensional inputs.
            inputs (Union[jnp.ndarray, list, float]): Point at which to
                evaluate the function. Scalars are promoted to shape (1,).

        Returns:
            float: The real part of the series at the input point (squeezed;
            an array if several outputs are evaluated).
        """
        # Add a trailing "output" axis when coefficients have none
        if isinstance(frequencies, list):
            if len(coefficients.shape) <= len(frequencies):
                coefficients = coefficients[..., jnp.newaxis]
        else:
            if len(coefficients.shape) == 1:
                coefficients = coefficients[..., jnp.newaxis]

        # Accept floats, lists and arrays alike
        inputs = jnp.asarray(inputs)
        if len(inputs.shape) < 1:
            inputs = inputs[jnp.newaxis, ...]

        if isinstance(frequencies, list):
            input_dim = len(frequencies)
            frequencies = jnp.stack(jnp.meshgrid(*frequencies))
            if input_dim != len(inputs):
                frequencies = jnp.repeat(
                    frequencies[jnp.newaxis, ...], inputs.shape[0], axis=0
                )
                freq_inputs = jnp.einsum("bi...,b->b...", frequencies, inputs)
                exponents = jnp.exp(1j * freq_inputs).T
                exp = jnp.einsum("jl...k,jl...b->b...k", coefficients, exponents)
            else:
                freq_inputs = jnp.einsum("i...,i->...", frequencies, inputs)
                exponents = jnp.exp(1j * freq_inputs).T
                exp = jnp.einsum("jl...k,jl...->k...", coefficients, exponents)
        else:
            frequencies = jnp.repeat(
                frequencies[jnp.newaxis, ...], inputs.shape[0], axis=0
            )
            freq_inputs = jnp.einsum("i...,i->i...", frequencies, inputs)
            exponents = jnp.exp(1j * freq_inputs)
            exp = jnp.einsum("j...k,ij...->ik...", coefficients, exponents)

        return jnp.squeeze(jnp.real(exp))

211 

212 

213class FourierTree: 

214 """ 

215 Sine-cosine tree representation for the algorithm by Nemkov et al. 

216 This tree can be used to obtain analytical Fourier coefficients for a given 

217 Pauli-Clifford circuit. 

218 """ 

219 

220 class CoefficientsTreeNode: 

221 """ 

222 Representation of a node in the coefficients tree for the algorithm by 

223 Nemkov et al. 

224 """ 

225 

226 def __init__( 

227 self, 

228 parameter_idx: Optional[int], 

229 observable: FourierTree.PauliOperator, 

230 is_sine_factor: bool, 

231 is_cosine_factor: bool, 

232 left: Optional[FourierTree.CoefficientsTreeNode] = None, 

233 right: Optional[FourierTree.CoefficientsTreeNode] = None, 

234 ): 

235 """ 

236 Coefficient tree node initialisation. Each node has information about 

237 its creation context and it's children, i.e.: 

238 

239 Args: 

240 parameter_idx (Optional[int]): Index of the corresp. param. index i. 

241 observable (FourierTree.PauliOperator): The nodes observable to 

242 obtain the expectation value that contributes to the constant 

243 term. 

244 is_sine_factor (bool): If this node belongs to a sine coefficient. 

245 is_cosine_factor (bool): If this node belongs to a cosine coefficient. 

246 left (Optional[CoefficientsTreeNode]): left child (if any). 

247 right (Optional[CoefficientsTreeNode]): right child (if any). 

248 """ 

249 self.parameter_idx = parameter_idx 

250 

251 assert not (is_sine_factor and is_cosine_factor), ( 

252 "Cannot be sine and cosine at the same time" 

253 ) 

254 self.is_sine_factor = is_sine_factor 

255 self.is_cosine_factor = is_cosine_factor 

256 

257 # If the observable does not constist of only Z and I, the 

258 # expectation (and therefore the constant node term) is zero 

259 if jnp.logical_or( 

260 observable.list_repr == 0, observable.list_repr == 1 

261 ).any(): 

262 self.term = 0.0 

263 else: 

264 self.term = observable.phase 

265 

266 self.left = left 

267 self.right = right 

268 

269 def evaluate(self, parameters: list[float]) -> float: 

270 """ 

271 Recursive function to evaluate the expectation of the coefficient tree, 

272 starting from the current node. 

273 

274 Args: 

275 parameters (list[float]): The parameters, by which the circuit (and 

276 therefore the tree) is parametrised. 

277 

278 Returns: 

279 float: The expectation for the current node and it's children. 

280 """ 

281 factor = ( 

282 parameters[self.parameter_idx] 

283 if self.parameter_idx is not None 

284 else 1.0 

285 ) 

286 if self.is_sine_factor: 

287 factor = 1j * jnp.sin(factor) 

288 elif self.is_cosine_factor: 

289 factor = jnp.cos(factor) 

290 if not (self.left or self.right): # leaf 

291 return factor * self.term 

292 

293 sum_children = 0.0 

294 if self.left: 

295 left = self.left.evaluate(parameters) 

296 sum_children = sum_children + left 

297 if self.right: 

298 right = self.right.evaluate(parameters) 

299 sum_children = sum_children + right 

300 

301 return factor * sum_children 

302 

303 def get_leafs( 

304 self, 

305 sin_list: List[int], 

306 cos_list: List[int], 

307 existing_leafs: List[FourierTree.TreeLeaf] = [], 

308 ) -> List[FourierTree.TreeLeaf]: 

309 """ 

310 Traverse the tree starting from the current node, to obtain the tree 

311 leafs only. 

312 The leafs correspond to the terms in the sine-cosine tree 

313 representation that eventually are used to obtain coefficients and 

314 frequencies. 

315 Sine and cosine lists are recursively passed to the children until a 

316 leaf is reached (top to bottom). 

317 Leafs are then passed bottom to top to the caller. 

318 

319 Args: 

320 sin_list (List[int]): Current number of sine contributions for each 

321 parameter. Has the same length as the parameters, as each 

322 position corresponds to one parameter. 

323 cos_list (List[int]): Current number of cosine contributions for 

324 each parameter. Has the same length as the parameters, as each 

325 position corresponds to one parameter. 

326 existing_leafs (List[TreeLeaf]): Current list of leaf nodes from 

327 parents. 

328 

329 Returns: 

330 List[TreeLeaf]: Updated list of leaf nodes. 

331 """ 

332 

333 if self.is_sine_factor: 

334 sin_list = sin_list.at[self.parameter_idx].set( 

335 sin_list[self.parameter_idx] + 1 

336 ) 

337 if self.is_cosine_factor: 

338 cos_list = cos_list.at[self.parameter_idx].set( 

339 cos_list[self.parameter_idx] + 1 

340 ) 

341 

342 if not (self.left or self.right): # leaf 

343 if self.term != 0.0: 

344 return [FourierTree.TreeLeaf(sin_list, cos_list, self.term)] 

345 else: 

346 return [] 

347 

348 if self.left: 

349 leafs_left = self.left.get_leafs( 

350 sin_list.copy(), cos_list.copy(), existing_leafs.copy() 

351 ) 

352 else: 

353 leafs_left = [] 

354 

355 if self.right: 

356 leafs_right = self.right.get_leafs( 

357 sin_list.copy(), cos_list.copy(), existing_leafs.copy() 

358 ) 

359 else: 

360 leafs_right = [] 

361 

362 existing_leafs.extend(leafs_left) 

363 existing_leafs.extend(leafs_right) 

364 return existing_leafs 

365 

366 @dataclass 

367 class TreeLeaf: 

368 """ 

369 Coefficient tree leafs according to the algorithm by Nemkov et al., which 

370 correspond to the terms in the sine-cosine tree representation that 

371 eventually are used to obtain coefficients and frequencies. 

372 

373 Args: 

374 sin_indices (List[int]): Current number of sine contributions for each 

375 parameter. Has the same length as the parameters, as each 

376 position corresponds to one parameter. 

377 cos_indices (List[int]): Current number of cosine contributions for 

378 each parameter. Has the same length as the parameters, as each 

379 position corresponds to one parameter. 

380 term (jnp.complex): Constant factor of the leaf, depending on the 

381 expectation value of the observable, and a phase. 

382 """ 

383 

384 sin_indices: List[int] 

385 cos_indices: List[int] 

386 term: complex 

387 

388 class PauliOperator: 

389 """ 

390 Utility class for storing Pauli Rotations, the corresponding indices 

391 in the XY-Space (whether there is a gate with X or Y generator at a 

392 certain qubit) and the phase. 

393 

394 Args: 

395 pauli (Union[Operator, jnp.ndarray[int]]): Pauli Rotation Operation 

396 or list representation 

397 n_qubits (int): Number of qubits in the circuit 

398 prev_xy_indices (Optional[jnp.ndarray[bool]]): X/Y indices of the 

399 previous Pauli sequence. Defaults to None. 

400 is_observable (bool): If the operator is an observable. Defaults to 

401 False. 

402 is_init (bool): If this Pauli operator is initialised the first 

403 time. Defaults to True. 

404 phase (float): Phase of the operator. Defaults to 1.0 

405 """ 

406 

407 def __init__( 

408 self, 

409 pauli: Union[Operation, jnp.ndarray[int]], 

410 n_qubits: int, 

411 prev_xy_indices: Optional[jnp.ndarray[bool]] = None, 

412 is_observable: bool = False, 

413 is_init: bool = True, 

414 phase: float = 1.0, 

415 ): 

416 self.is_observable = is_observable 

417 self.phase = phase 

418 

419 if is_init: 

420 if not is_observable: 

421 pauli = pauli.generator() 

422 self.list_repr = self._create_list_representation(pauli, n_qubits) 

423 else: 

424 assert isinstance(pauli, jnp.ndarray) 

425 self.list_repr = pauli 

426 

427 if prev_xy_indices is None: 

428 prev_xy_indices = jnp.zeros(n_qubits, dtype=bool) 

429 self.xy_indices = jnp.logical_or( 

430 prev_xy_indices, 

431 self._compute_xy_indices(self.list_repr, rev=is_observable), 

432 ) 

433 

434 @staticmethod 

435 def _compute_xy_indices( 

436 op: jnp.ndarray[int], rev: bool = False 

437 ) -> jnp.ndarray[bool]: 

438 """ 

439 Computes the positions of X or Y gates to an one-hot encoded boolen 

440 array. 

441 

442 Args: 

443 op (jnp.ndarray[int]): Pauli-Operation list representation. 

444 rev (bool): Whether to negate the array. 

445 

446 Returns: 

447 jnp.ndarray[bool]: One hot encoded boolean array. 

448 """ 

449 xy_indices = (op == 0) + (op == 1) 

450 if rev: 

451 xy_indices = ~xy_indices 

452 return xy_indices 

453 

454 @staticmethod 

455 def _create_list_representation( 

456 op: Operation, n_qubits: int 

457 ) -> jnp.ndarray[int]: 

458 """ 

459 Create list representation of an Operation. 

460 Generally, the list representation is a list of length n_qubits, 

461 where at each position a Pauli Operator is encoded as such: 

462 I: -1 

463 X: 0 

464 Y: 1 

465 Z: 2 

466 

467 Args: 

468 op (Operation): Gate operation (PauliX, PauliY, PauliZ, or 

469 Hermitian wrapping a multi-qubit Pauli tensor product). 

470 n_qubits (int): number of qubits in the circuit 

471 

472 Returns: 

473 jnp.ndarray[int]: List representation 

474 """ 

475 pauli_repr = -jnp.ones(n_qubits, dtype=int) 

476 

477 _NAME_TO_IDX = {"PauliX": 0, "PauliY": 1, "PauliZ": 2} 

478 

479 if op.name in _NAME_TO_IDX: 

480 pauli_repr = pauli_repr.at[op.wires[0]].set(_NAME_TO_IDX[op.name]) 

481 elif isinstance(op, PauliX): 

482 pauli_repr = pauli_repr.at[op.wires[0]].set(0) 

483 elif isinstance(op, PauliY): 

484 pauli_repr = pauli_repr.at[op.wires[0]].set(1) 

485 elif isinstance(op, PauliZ): 

486 pauli_repr = pauli_repr.at[op.wires[0]].set(2) 

487 else: 

488 # Multi-qubit case: decompose via pauli_string_from_operation 

489 from qml_essentials.operations import pauli_string_from_operation 

490 

491 pauli_str = pauli_string_from_operation(op) 

492 char_to_idx = {"X": 0, "Y": 1, "Z": 2, "I": -1} 

493 for i, (wire, ch) in enumerate(zip(op.wires, pauli_str)): 

494 idx = char_to_idx.get(ch, -1) 

495 if idx >= 0: 

496 pauli_repr = pauli_repr.at[wire].set(idx) 

497 

498 return pauli_repr 

499 

500 def is_commuting(self, pauli: jnp.ndarray[int]) -> bool: 

501 """ 

502 Computes if this Pauli commutes with another Pauli operator. 

503 This computation is based on the fact that The commutator is zero 

504 if and only if the number of anticommuting single-qubit Paulis is 

505 even. 

506 

507 Args: 

508 pauli (jnp.ndarray[int]): List representation of another Pauli 

509 

510 Returns: 

511 bool: If the current and other Pauli are commuting. 

512 """ 

513 anticommutator = jnp.where( 

514 pauli < 0, 

515 False, 

516 jnp.where( 

517 self.list_repr < 0, 

518 False, 

519 jnp.where(self.list_repr == pauli, False, True), 

520 ), 

521 ) 

522 return not (jnp.sum(anticommutator) % 2) 

523 

524 def tensor(self, pauli: jnp.ndarray[int]) -> FourierTree.PauliOperator: 

525 """ 

526 Compute tensor product between the current Pauli and a given list 

527 representation of another Pauli operator. 

528 

529 Args: 

530 pauli (jnp.ndarray[int]): List representation of Pauli 

531 

532 Returns: 

533 FourierTree.PauliOperator: New Pauli operator object, which 

534 contains the tensor product 

535 """ 

536 diff = (pauli - self.list_repr + 3) % 3 

537 phase = jnp.where( 

538 self.list_repr < 0, 

539 1.0, 

540 jnp.where( 

541 pauli < 0, 

542 1.0, 

543 jnp.where( 

544 diff == 2, 

545 1.0j, 

546 jnp.where(diff == 1, -1.0j, 1.0), 

547 ), 

548 ), 

549 ) 

550 

551 obs = jnp.where( 

552 self.list_repr < 0, 

553 pauli, 

554 jnp.where( 

555 pauli < 0, 

556 self.list_repr, 

557 jnp.where( 

558 diff == 2, 

559 (self.list_repr + 1) % 3, 

560 jnp.where(diff == 1, (self.list_repr + 2) % 3, -1), 

561 ), 

562 ), 

563 ) 

564 phase = self.phase * jnp.prod(phase) 

565 return FourierTree.PauliOperator( 

566 obs, phase=phase, n_qubits=obs.size, is_init=False, is_observable=True 

567 ) 

568 

569 def __init__(self, model: Model): 

570 """ 

571 Tree initialisation, based on the Pauli-Clifford representation of a model. 

572 Currently, only one input feature is supported. 

573 

574 **Usage**: 

575 ``` 

576 # initialise a model 

577 model = Model(...) 

578 

579 # initialise and build FourierTree 

580 tree = FourierTree(model) 

581 

582 # get expectaion value 

583 exp = tree() 

584 

585 # Get spectrum (for each observable, we have one list element) 

586 coeff_list, freq_list = tree.spectrum() 

587 ``` 

588 

589 Args: 

590 model (Model): The Model, for which to build the tree 

591 """ 

592 self.model = model 

593 self.tree_roots = None 

594 

595 inputs = self.model._inputs_validation([1.0]) 

596 

597 # Record the circuit tape using jaqsi's tape recording 

598 raw_tape = self.model.script._record(params=model.params, inputs=inputs) 

599 

600 # Build observables from the model's output_qubit configuration 

601 _, obs_list = self.model._build_obs() 

602 

603 quantum_tape = PauliCircuit.from_parameterised_circuit( 

604 raw_tape, observables=obs_list 

605 ) 

606 

607 self.parameters = [jnp.squeeze(p) for p in quantum_tape.get_parameters()] 

608 

609 self.input_indices, self.all_input_indices = quantum_tape.get_input_indices() 

610 

611 self.observables = self._encode_observables(quantum_tape.observables) 

612 

613 pauli_rot = FourierTree.PauliOperator( 

614 quantum_tape.operations[0], 

615 self.model.n_qubits, 

616 ) 

617 self.pauli_rotations = [pauli_rot] 

618 for op in quantum_tape.operations[1:]: 

619 pauli_rot = FourierTree.PauliOperator( 

620 op, self.model.n_qubits, pauli_rot.xy_indices 

621 ) 

622 self.pauli_rotations.append(pauli_rot) 

623 

624 self.tree_roots = self.build() 

625 self.leafs: List[List[FourierTree.TreeLeaf]] = self._get_tree_leafs() 

626 

627 def __call__( 

628 self, 

629 params: Optional[jnp.ndarray] = None, 

630 inputs: Optional[jnp.ndarray] = None, 

631 **kwargs, 

632 ) -> jnp.ndarray: 

633 """ 

634 Evaluates the Fourier tree via sine-cosine terms sum. This is 

635 equivalent to computing the expectation value of the observables with 

636 respect to the corresponding circuit. 

637 

638 Args: 

639 params (Optional[jnp.ndarray], optional): Parameters of the model. 

640 Defaults to None. 

641 inputs (Optional[jnp.ndarray], optional): Inputs to the circuit. 

642 Defaults to None. 

643 

644 Returns: 

645 jnp.ndarray: Expectation value of the tree. 

646 

647 Raises: 

648 NotImplementedError: When using other "execution_type" as expval. 

649 NotImplementedError: When using "noise_params" 

650 

651 

652 """ 

653 params = ( 

654 self.model._params_validation(params) 

655 if params is not None 

656 else self.model.params 

657 ) 

658 inputs = ( 

659 self.model._inputs_validation(inputs) 

660 if inputs is not None 

661 else self.model._inputs_validation(1.0) 

662 ) 

663 

664 if kwargs.get("execution_type", "expval") != "expval": 

665 raise NotImplementedError( 

666 f'Currently, only "expval" execution type is supported when ' 

667 f"building FourierTree. Got {kwargs.get('execution_type', 'expval')}." 

668 ) 

669 if kwargs.get("noise_params", None) is not None: 

670 raise NotImplementedError( 

671 "Currently, noise is not supported when building FourierTree." 

672 ) 

673 

674 # Record the circuit tape using jaqsi's tape recording 

675 raw_tape = self.model.script._record(params=self.model.params, inputs=inputs) 

676 

677 # Build observables from the model's output_qubit configuration 

678 _, obs_list = self.model._build_obs() 

679 

680 quantum_tape = PauliCircuit.from_parameterised_circuit( 

681 raw_tape, observables=obs_list 

682 ) 

683 

684 self.parameters = [jnp.squeeze(p) for p in quantum_tape.get_parameters()] 

685 

686 results = jnp.zeros(len(self.tree_roots)) 

687 for i, root in enumerate(self.tree_roots): 

688 results = results.at[i].set(jnp.real(root.evaluate(self.parameters))) 

689 

690 if kwargs.get("force_mean", False): 

691 return jnp.mean(results) 

692 else: 

693 return results 

694 

695 def build(self) -> List[CoefficientsTreeNode]: 

696 """ 

697 Creates the coefficient tree, i.e. it creates and initialises the tree 

698 nodes. 

699 Leafs can be obtained separately in _get_tree_leafs, once the tree is 

700 set up. 

701 

702 Returns: 

703 List[CoefficientsTreeNode]: The list of root nodes (one root for 

704 each observable). 

705 """ 

706 tree_roots = [] 

707 pauli_rotation_idx = len(self.pauli_rotations) - 1 

708 for obs in self.observables: 

709 root = self._create_tree_node(obs, pauli_rotation_idx) 

710 tree_roots.append(root) 

711 return tree_roots 

712 

713 def _encode_observables( 

714 self, tape_obs: List[Operation] 

715 ) -> List[FourierTree.PauliOperator]: 

716 """ 

717 Encodes observables from tape as FourierTree.PauliOperator 

718 utility objects. 

719 

720 Args: 

721 tape_obs (List[Operation]): Observable operations 

722 

723 Returns: 

724 List[FourierTree.PauliOperator]: List of Pauli operators 

725 """ 

726 observables = [] 

727 for obs in tape_obs: 

728 observables.append( 

729 FourierTree.PauliOperator(obs, self.model.n_qubits, is_observable=True) 

730 ) 

731 return observables 

732 

733 def _get_tree_leafs(self) -> List[List[TreeLeaf]]: 

734 """ 

735 Obtain all Leaf Nodes with its sine- and cosine terms. 

736 

737 Returns: 

738 List[List[TreeLeaf]]: For each observable (root), the list of leaf 

739 nodes. 

740 """ 

741 leafs = [] 

742 for root in self.tree_roots: 

743 sin_list = jnp.zeros(len(self.parameters), dtype=jnp.int32) 

744 cos_list = jnp.zeros(len(self.parameters), dtype=jnp.int32) 

745 leafs.append(root.get_leafs(sin_list, cos_list, [])) 

746 return leafs 

747 

748 def get_spectrum( 

749 self, force_mean: bool = False 

750 ) -> Tuple[List[jnp.ndarray], List[jnp.ndarray]]: 

751 """ 

752 Computes the Fourier spectrum for the tree, consisting of the 

753 frequencies and its corresponding coefficinets. 

754 If the frag force_mean was set in the constructor, the mean coefficient 

755 over all observables (roots) are computed. 

756 

757 Args: 

758 force_mean (bool, optional): Whether to average over multiple 

759 observables. Defaults to False. 

760 

761 Returns: 

762 Tuple[List[jnp.ndarray], List[jnp.ndarray]]: 

763 - List of frequencies, one list for each observable (root). 

764 - List of corresponding coefficents, one list for each 

765 observable (root). 

766 """ 

767 parameter_indices = [ 

768 i for i in range(len(self.parameters)) if i not in self.all_input_indices 

769 ] 

770 

771 coeffs = [] 

772 for leafs in self.leafs: 

773 freq_terms = defaultdict(np.complex128) 

774 for input_idx in self.input_indices: 

775 for leaf in leafs: 

776 leaf_factor, s, c = self._compute_leaf_factors( 

777 leaf, parameter_indices, input_idx 

778 ) 

779 

780 for a in range(s + 1): 

781 for b in range(c + 1): 

782 comb = math.comb(s, a) * math.comb(c, b) * (-1) ** (s - a) 

783 freq_terms[2 * a + 2 * b - s - c] += comb * leaf_factor 

784 

785 coeffs.append(freq_terms) 

786 

787 frequencies, coefficients = self._freq_terms_to_coeffs(coeffs, force_mean) 

788 return coefficients, frequencies 

789 

790 def _freq_terms_to_coeffs( 

791 self, coeffs: List[Dict[int, jnp.ndarray]], force_mean: bool 

792 ) -> Tuple[List[jnp.ndarray], List[jnp.ndarray]]: 

793 """ 

794 Given a list of dictionaries of the form: 

795 [ 

796 { 

797 freq_obs1_1: coeff1, 

798 freq_obs1_2: coeff2, 

799 ... 

800 }, 

801 { 

802 freq_obs2_1: coeff3, 

803 freq_obs2_2: coeff4, 

804 ... 

805 } 

806 ... 

807 ], 

808 Compute two separate lists of frequencies and coefficients. 

809 such that: 

810 freqs: [ 

811 [freq_obs1_1, freq_obs1_1, ...], 

812 [freq_obs2_1, freq_obs2_1, ...], 

813 ... 

814 ] 

815 coeffs: [ 

816 [coeff1, coeff2, ...], 

817 [coeff3, coeff4, ...], 

818 ... 

819 ] 

820 

821 If force_mean is set length of the resulting frequency and coefficent 

822 list is 1. 

823 

824 Args: 

825 coeffs (List[Dict[int, jnp.ndarray]]): Frequency->Coefficients 

826 dictionary list, one dict for each observable (root). 

827 force_mean (bool): Whether to average coefficients over multiple 

828 observables. 

829 

830 Returns: 

831 Tuple[List[jnp.ndarray], List[jnp.ndarray]]: 

832 - List of frequencies, one list for each observable (root). 

833 - List of corresponding coefficents, one list for each 

834 observable (root). 

835 """ 

836 frequencies = [] 

837 coefficients = [] 

838 if force_mean: 

839 all_freqs = sorted(set([f for c in coeffs for f in c.keys()])) 

840 coefficients.append( 

841 jnp.array( 

842 [ 

843 jnp.mean(jnp.array([c.get(f, 0.0) for c in coeffs])) 

844 for f in all_freqs 

845 ] 

846 ) 

847 ) 

848 frequencies.append(jnp.array(all_freqs)) 

849 else: 

850 for freq_terms in coeffs: 

851 freq_terms = dict(sorted(freq_terms.items())) 

852 frequencies.append(jnp.array(list(freq_terms.keys()))) 

853 coefficients.append(jnp.array(list(freq_terms.values()))) 

854 return frequencies, coefficients 

855 

856 def _compute_leaf_factors( 

857 self, 

858 leaf: TreeLeaf, 

859 parameter_indices: List[int], 

860 input_idx: int, 

861 ) -> Tuple[float, int, int]: 

862 """ 

863 Computes the constant coefficient factor for each leaf. 

864 Additionally sine and cosine contributions of the input parameters for 

865 this leaf are returned, which are required to obtain the corresponding 

866 frequencies. 

867 

868 Args: 

869 leaf (TreeLeaf): The leaf for which to compute the factor. 

870 parameter_indices (List[int]): Variational parameter indices. 

871 

872 Returns: 

873 Tuple[float, int, int]: 

874 - float: the constant factor for the leaf 

875 - int: number of sine contributions of the input 

876 - int: number of cosine contributions of the input 

877 """ 

878 leaf_factor = 1.0 

879 for i in parameter_indices: 

880 interm_factor = ( 

881 jnp.cos(self.parameters[i]) ** leaf.cos_indices[i] 

882 * (1j * jnp.sin(self.parameters[i])) ** leaf.sin_indices[i] 

883 ) 

884 leaf_factor = leaf_factor * interm_factor 

885 

886 # Get number of sine and cosine factors to which the input contributes 

887 c = jnp.sum( 

888 jnp.array([leaf.cos_indices[k] for k in self.input_indices[input_idx]]) 

889 ) 

890 s = jnp.sum( 

891 jnp.array([leaf.sin_indices[k] for k in self.input_indices[input_idx]]) 

892 ) 

893 

894 leaf_factor = leaf.term * leaf_factor * 0.5 ** (s + c) 

895 

896 return leaf_factor, int(s), int(c) 

897 

898 def _early_stopping_possible( 

899 self, pauli_rotation_idx: int, observable: FourierTree.PauliOperator 

900 ): 

901 """ 

902 Checks if a node for an observable can be discarded as all expecation 

903 values that can result through further branching are zero. 

904 The method is mentioned in the paper by Nemkov et al.: If the one-hot 

905 encoded indices for X/Y operations in the Pauli-rotation generators are 

906 a basis for that of the observable, the node must be processed further. 

907 If not, it can be discarded. 

908 

909 Args: 

910 pauli_rotation_idx (int): Index of remaining Pauli rotation gates. 

911 Gates itself are attributes of the class. 

912 observable (FourierTree.PauliOperator): Current observable 

913 """ 

914 xy_indices_obs = jnp.logical_or( 

915 observable.xy_indices, self.pauli_rotations[pauli_rotation_idx].xy_indices 

916 ).all() 

917 

918 return not xy_indices_obs 

919 

def _create_tree_node(
    self,
    observable: FourierTree.PauliOperator,
    pauli_rotation_idx: int,
    parameter_idx: Optional[int] = None,
    is_sine: bool = False,
    is_cosine: bool = False,
) -> Optional[CoefficientsTreeNode]:
    """
    Builds the Fourier-Tree according to the algorithm by Nemkov et al.

    Starting at `pauli_rotation_idx` and walking backwards, every rotation
    that commutes with `observable` is skipped. If no rotation is left, a
    leaf node is returned. Otherwise the first non-commuting rotation
    splits the tree into two children:
    - a cosine (left) child that keeps `observable`, and
    - a sine (right) child whose observable is `observable` tensored with
      that rotation's Pauli representation.

    Both children recurse on the rotations preceding the splitting one and
    receive its index as their `parameter_idx`.

    Args:
        observable (FourierTree.PauliOperator): Current observable
        pauli_rotation_idx (int): Index of remaining Pauli rotation gates.
            Gates itself are attributes of the class.
        parameter_idx (Optional[int]): Index of the current parameter.
            Parameters itself are attributes of the class.
        is_sine (bool): If the current node is a sine (left) node.
        is_cosine (bool): If the current node is a cosine (right) node.

    Returns:
        Optional[CoefficientsTreeNode]: The resulting node, or None if
            `_early_stopping_possible` prunes this branch. Children are set
            recursively; the top-level call receives the tree root.
    """
    # NOTE(review): when recursion passes the first rotation,
    # pauli_rotation_idx is -1. As far as its body is visible here,
    # `_early_stopping_possible` then reads `self.pauli_rotations[-1]`, i.e.
    # the *last* rotation via Python negative indexing. Confirm intended.
    if self._early_stopping_possible(pauli_rotation_idx, observable):
        return None

    # Skip (backwards) all rotations that commute with the observable.
    idx = pauli_rotation_idx
    while idx >= 0 and observable.is_commuting(self.pauli_rotations[idx].list_repr):
        idx -= 1

    if idx < 0:
        # Every remaining rotation commutes: this node is a leaf.
        return FourierTree.CoefficientsTreeNode(
            parameter_idx, observable, is_sine, is_cosine
        )

    splitting_rotation = self.pauli_rotations[idx]

    cosine_child = self._create_tree_node(
        observable,
        idx - 1,
        idx,
        is_cosine=True,
    )
    sine_child = self._create_tree_node(
        self._create_new_observable(splitting_rotation.list_repr, observable),
        idx - 1,
        idx,
        is_sine=True,
    )

    return FourierTree.CoefficientsTreeNode(
        parameter_idx,
        observable,
        is_sine,
        is_cosine,
        cosine_child,
        sine_child,
    )

983 

def _create_new_observable(
    self, pauli: jnp.ndarray[int], observable: FourierTree.PauliOperator
) -> FourierTree.PauliOperator:
    """
    Return the observable for a sine (right) branch of the Fourier tree.

    Used when the last Pauli rotation and the current observable do not
    commute. The new observable is `observable.tensor(pauli)`.

    Args:
        pauli (jnp.ndarray[int]): Int-array representation of the last
            Pauli rotation in the operation sequence.
        observable (FourierTree.PauliOperator): The current observable.

    Returns:
        FourierTree.PauliOperator: The result of `observable.tensor(pauli)`.
    """
    return observable.tensor(pauli)

1001 

1002 

class FCC:
    """
    Fourier Coefficient Correlation (FCC) utilities.

    All methods are classmethods; the class holds no state. The usual entry
    points are `get_fcc` and `get_fourier_fingerprint`.
    """

    @classmethod
    def get_fcc(
        cls,
        model: Model,
        n_samples: int,
        random_key: Optional[random.PRNGKey] = None,
        method: Optional[str] = "pearson",
        scale: Optional[bool] = False,
        weight: Optional[bool] = False,
        trim_redundant: Optional[bool] = True,
        **kwargs,
    ) -> float:
        """
        Shortcut method to get just the FCC.
        This includes
        1. What is done in `get_fourier_fingerprint`:
            1. Calculating the coefficients (using `n_samples`)
            2. Correlating the result from 1) using `method`
            3. Weighting the correlation matrix (if `weight` is True)
            4. Remove redundancies (if `trim_redundant` is True)
        2. What is done in `calculate_fcc`:
            1. Absolute of the fingerprint
            2. Average

        When `trim_redundant` is True and `weight` is False, the average is
        computed directly from the non-negative-frequency correlation block,
        without materialising the trimmed fingerprint.

        Args:
            model (Model): The QFM model
            n_samples (int): Number of samples to calculate average of coefficients
            random_key (Optional[random.PRNGKey]): JAX random key for parameter
                initialization. If None, uses the model's internal random key.
            method (Optional[str], optional): Correlation method. Supported values are
                "pearson", "complex_pearson", and "spearman". Defaults to "pearson".
            scale (Optional[bool], optional): Whether to scale the number of samples.
                Defaults to False.
            weight (Optional[bool], optional): Whether to weight the correlation matrix.
                Defaults to False.
            trim_redundant (Optional[bool], optional): Whether to remove redundant
                correlations. Defaults to True.
            **kwargs (Any): Additional keyword arguments for the model function.

        Returns:
            float: The FCC (a 0-d JAX array).
        """
        # Memory-efficient fast path
        if trim_redundant and not weight:
            _, coeffs, freqs = cls._calculate_coefficients(
                model, n_samples, random_key, scale, **kwargs
            )
            fp = cls._correlate_non_negative(coeffs, freqs, method)
            abs_fp = jnp.abs(fp)
            diag = jnp.abs(jnp.diagonal(fp))

            # |fp| is symmetric (all correlators here produce symmetric or
            # Hermitian matrices), so the strict lower triangle holds exactly
            # half of the off-diagonal entries. Averaging over it therefore
            # needs no explicit triangle mask.
            total_sum = jnp.nansum(abs_fp)
            total_count = jnp.sum(jnp.isfinite(abs_fp))
            diag_sum = jnp.nansum(diag)
            diag_count = jnp.sum(jnp.isfinite(diag))

            lower_sum = (total_sum - diag_sum) / 2.0
            lower_count = (total_count - diag_count) / 2.0
            return lower_sum / lower_count

        fourier_fingerprint, _ = cls.get_fourier_fingerprint(
            model,
            n_samples,
            random_key,
            method,
            scale,
            weight,
            trim_redundant=trim_redundant,
            **kwargs,
        )

        return cls.calculate_fcc(fourier_fingerprint)

    @classmethod
    def get_fourier_fingerprint(
        cls,
        model: Model,
        n_samples: int,
        random_key: Optional[random.PRNGKey] = None,
        method: Optional[str] = "pearson",
        scale: Optional[bool] = False,
        weight: Optional[bool] = False,
        trim_redundant: Optional[bool] = True,
        nan_to_one: Optional[bool] = False,
        **kwargs: Any,
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """
        Shortcut method to get just the fourier fingerprint.
        This includes
        1. Calculating the coefficients (using `n_samples`)
        2. Correlating the result from 1) using `method`
        3. Weighting the correlation matrix (if `weight` is True)
        4. Remove redundancies (if `trim_redundant` is True)

        Args:
            model (Model): The QFM model
            n_samples (int): Number of samples to calculate average of coefficients
            random_key (Optional[random.PRNGKey]): JAX random key for parameter
                initialization. If None, uses the model's internal random key.
            method (Optional[str], optional): Correlation method. Supported values are
                "pearson", "complex_pearson", and "spearman". Defaults to "pearson".
            scale (Optional[bool], optional): Whether to scale the number of samples.
                Defaults to False.
            weight (Optional[bool], optional): Whether to weight the correlation matrix.
                Defaults to False.
            trim_redundant (Optional[bool], optional): Whether to remove redundant
                correlations. Defaults to True.
            nan_to_one (Optional[bool], optional): Whether to set nan to 1.
                Defaults to False.
            **kwargs: Additional keyword arguments for the model function.

        Returns:
            Tuple[jnp.ndarray, jnp.ndarray]: The fourier fingerprint
            and the frequency indices
        """
        _, coeffs, freqs = cls._calculate_coefficients(
            model, n_samples, random_key, scale, **kwargs
        )

        # Memory-efficient fast path: correlate only the non-negative
        # frequency rows instead of building the full N x N matrix.
        if trim_redundant and not weight:
            fourier_fingerprint = cls._correlate_non_negative(coeffs, freqs, method)
            if nan_to_one:
                fourier_fingerprint = jnp.where(
                    jnp.isnan(fourier_fingerprint), 1.0, fourier_fingerprint
                )
            return cls._keep_strict_lower_triangle(fourier_fingerprint), freqs

        fourier_fingerprint = cls._correlate(coeffs.transpose(), method=method)

        if nan_to_one:
            # JAX arrays are immutable: build a new array rather than
            # assigning into the existing one.
            fourier_fingerprint = jnp.where(
                jnp.isnan(fourier_fingerprint), 1.0, fourier_fingerprint
            )

        if weight:
            fourier_fingerprint = cls._weighting_mean(fourier_fingerprint, coeffs)

        if trim_redundant:
            # `_correlate(coeffs.transpose())` flattens the frequency axes in
            # reversed order, i.e. F (column-major) order of `coeffs`, so the
            # mask indices must use the same order.
            pos_idx = cls._calculate_mask(freqs, order="F")

            # Restrict to the non-negative-frequency sub-block (M x M) first.
            fourier_fingerprint = fourier_fingerprint[pos_idx][:, pos_idx]
            fourier_fingerprint = cls._keep_strict_lower_triangle(fourier_fingerprint)

        return fourier_fingerprint, freqs

    @classmethod
    def calculate_fcc(
        cls,
        fourier_fingerprint: jnp.ndarray,
    ) -> float:
        """
        Method to calculate the FCC based on an existing correlation matrix.
        Calculate absolute and then the average over this matrix.
        The Fingerprint can be obtained via `get_fourier_fingerprint`

        Args:
            fourier_fingerprint (jnp.ndarray): Correlation matrix of coefficients
        Returns:
            float: The FCC
        """
        # NaN entries (e.g. the masked-out upper triangle) are ignored.
        return jnp.nanmean(jnp.abs(fourier_fingerprint))

    @classmethod
    def _correlate_non_negative(
        cls, coeffs: jnp.ndarray, freqs: jnp.ndarray, method: str = "pearson"
    ) -> jnp.ndarray:
        """
        Correlate only the coefficients lying on non-negative frequencies.

        All frequency axes of `coeffs` are flattened in C order (the last axis
        is the sample axis). This matches the C-order flat indices returned by
        `_calculate_mask`.

        Args:
            coeffs (jnp.ndarray): Coefficients with the sample axis last.
            freqs (jnp.ndarray): Frequencies, see `_calculate_mask`.
            method (str, optional): Correlation method. Defaults to "pearson".

        Returns:
            jnp.ndarray: (M, M) correlation matrix, M = number of selected rows.
        """
        pos_idx = cls._calculate_mask(freqs)
        coeffs_flat = coeffs.reshape(-1, coeffs.shape[-1])
        return cls._correlate(coeffs_flat[pos_idx].transpose(), method=method)

    @classmethod
    def _keep_strict_lower_triangle(cls, fourier_fingerprint: jnp.ndarray) -> jnp.ndarray:
        """
        Keep only the strict lower triangle of a square matrix.

        All other entries are set to NaN. Rows and columns that then contain
        no finite value are dropped.

        Args:
            fourier_fingerprint (jnp.ndarray): Square correlation matrix.

        Returns:
            jnp.ndarray: The trimmed matrix.
        """
        M = fourier_fingerprint.shape[0]
        lower_tri_mask = jnp.tri(M, k=-1, dtype=bool)
        fourier_fingerprint = jnp.where(lower_tri_mask, fourier_fingerprint, jnp.nan)

        row_mask = jnp.any(jnp.isfinite(fourier_fingerprint), axis=1)
        col_mask = jnp.any(jnp.isfinite(fourier_fingerprint), axis=0)
        return fourier_fingerprint[row_mask][:, col_mask]

    @classmethod
    def _calculate_mask(cls, freqs: jnp.ndarray, order: str = "C") -> jnp.ndarray:
        """
        Determine the flat indices of the Fourier correlation matrix
        that lie on a non-negative-frequency row/column. Together with
        the strict-lower-triangle condition (handled by the caller),
        these indices select the entries of the correlation matrix
        that survive the redundancy filter applied in
        `get_fourier_fingerprint`:

        - rows/columns whose flat frequency component is negative are
          discarded (they are the complex-conjugate redundancies of
          their positive counterparts);
        - of the remaining positive-frequency sub-block, only the
          strict lower triangle is kept (the upper triangle, including
          the diagonal, contains either duplicates from symmetry or
          self-correlations).

        Args:
            freqs (jnp.ndarray): Array of frequencies. One of:
                - a 1-D vector (single input feature);
                - a 2-D array of shape ``(n_input_feat, K)`` whose rows are
                  the per-axis frequency vectors;
                - a list/tuple of per-axis 1-D frequency vectors, which may
                  have different lengths.
            order (str, optional): Flattening order of the multi-dimensional
                mask. Use "C" (row-major, default) or "F" (column-major).
                The order must match how the correlated coefficients were
                flattened.

        Raises:
            ValueError: If `order` is not "C" or "F".

        Returns:
            jnp.ndarray: 1-D int array of flat indices selecting the
            non-negative-frequency rows/cols of the fingerprint.
        """
        if order not in ("C", "F"):
            raise ValueError(f"Unknown flattening order: {order}. Must be 'C' or 'F'.")

        if isinstance(freqs, (list, tuple)) and len(freqs) > 0 and jnp.ndim(freqs[0]) > 0:
            # Sequence of per-axis vectors (possibly ragged): no stacking.
            axes = [jnp.asarray(axis_freqs) for axis_freqs in freqs]
        else:
            freqs_arr = jnp.asarray(freqs)
            if freqs_arr.ndim == 1:
                return jnp.where(freqs_arr >= 0)[0]
            axes = [freqs_arr[i] for i in range(freqs_arr.shape[0])]

        # N-D case: build the per-axis non-negativity masks and combine
        # them via broadcasting (no float `jnp.outer`!).
        n_axes = len(axes)
        expanded = []
        for i, axis_freqs in enumerate(axes):
            shape = [1] * n_axes
            shape[i] = axis_freqs.shape[0]
            expanded.append((axis_freqs >= 0).reshape(shape))
        nd_pos = reduce(jnp.logical_and, expanded)

        # F order == reverse the axes, then flatten row-major.
        pos_flat = nd_pos.reshape(-1) if order == "C" else nd_pos.T.reshape(-1)
        return jnp.where(pos_flat)[0]

    @classmethod
    def _calculate_coefficients(
        cls,
        model: Model,
        n_samples: int,
        random_key: Optional[random.PRNGKey] = None,
        scale: bool = False,
        **kwargs: Any,
    ) -> Tuple[Any, jnp.ndarray, jnp.ndarray]:
        """
        Calculates the Fourier coefficients of a given model
        using `n_samples`.

        Parameters are re-initialised only when `n_samples > 0`; otherwise
        the model's current parameters are used. Additional keyword
        arguments (e.g. `noise_params` for noisy simulation) are forwarded
        to `Coefficients.get_spectrum`.

        Args:
            model (Model): The QFM model
            n_samples (int): Number of samples to calculate average of coefficients
            random_key (Optional[random.PRNGKey]): JAX random key for parameter
                initialization. If None, uses the model's internal random key.
            scale (bool, optional): Whether to scale the number of samples by
                ``2**n_qubits * n_input_feat``. Defaults to False.
            **kwargs: Additional keyword arguments for the model function.

        Returns:
            Tuple[Any, jnp.ndarray, jnp.ndarray]: The model parameters, the
            coefficients and the frequencies. The callers in this class treat
            the last axis of the coefficients as the sample axis.
        """
        if n_samples > 0:
            if scale:
                # Plain Python ints: jnp.power on int32 can silently overflow.
                total_samples = int(
                    2**model.n_qubits * n_samples * model.n_input_feat
                )
                log.info("Using %d samples.", total_samples)
            else:
                total_samples = n_samples
            model.initialize_params(random_key, repeat=total_samples)

        coeffs, freqs = Coefficients.get_spectrum(
            model, shift=True, trim=True, **kwargs
        )

        return model.params, coeffs, freqs

    @classmethod
    def _correlate(cls, mat: jnp.ndarray, method: str = "pearson") -> jnp.ndarray:
        """
        Correlates the columns of `mat` using `method`.
        Currently, `pearson`, `complex_pearson`, and `spearman` are supported.

        Args:
            mat (jnp.ndarray): Array whose first axis is the sample axis. All
                remaining axes are flattened (C order) into the variables that
                are correlated.
            method (str, optional): Correlation method. Defaults to "pearson".

        Raises:
            ValueError: If the method is not supported.

        Returns:
            jnp.ndarray: (K, K) correlation matrix of the flattened variables.
        """
        assert len(mat.shape) >= 2, "Input matrix must have at least 2 dimensions"

        # Flatten everything but the first (sample) axis in C order. Callers
        # pass `coeffs.transpose()`, so this is the F (column-major) order of
        # the original coefficient grid. Example: for a frequency grid
        # [[1,2,3],[4,5,6],[7,8,9]] the variables come out as
        # [1, 4, 7, 2, 5, 8, 3, 6, 9].
        flat = mat.reshape(mat.shape[0], -1)

        if method == "pearson":
            return cls._pearson(flat)
        if method == "complex_pearson":
            return cls._complex_pearson(flat)
        if method == "spearman":
            return cls._spearman(flat)
        raise ValueError(
            f"Unknown correlation method: {method}. "
            "Must be 'pearson', 'complex_pearson' or 'spearman'."
        )

    @classmethod
    def _complex_pearson(
        cls, mat: jnp.ndarray, cov: Optional[bool] = False, minp: Optional[int] = 1
    ) -> jnp.ndarray:
        """
        Compute the complex Pearson correlation between columns of `mat`,
        permitting missing values (NaN or ±Inf).

        This uses the Hermitian normalized covariance
            sum(conj(x_i - mean_i) * (x_j - mean_j)) /
            sqrt(sum(abs(x_i - mean_i)**2) * sum(abs(x_j - mean_j)**2)).
        Consequently, if column j is exp(1j * phi) times column i, then
        abs(corr[i, j]) is 1 and angle(corr[i, j]) is phi.

        Args:
            mat : array_like, shape (N, K)
                Input data.
            cov : bool, optional
                If True, return the sample covariance matrix instead of
                correlation. Defaults to False.
            minp : int, optional
                Minimum number of paired observations required to form a correlation.
                If the number of valid pairs for (i, j) is < minp, the result is NaN.

        Returns:
            corr : ndarray, shape (K, K)
                Complex Pearson correlation matrix.
        """
        mat = jnp.asarray(mat)
        real_dtype = jnp.asarray(mat.real).dtype

        mask = jnp.isfinite(mat)
        fmask = mask.astype(real_dtype)
        # Non-finite entries become 0; `fmask` tracks validity.
        safe = jnp.where(mask, mat, 0.0)

        # Pairwise valid-observation counts (K, K)
        nobs = fmask.T @ fmask
        nobs_safe = jnp.where(nobs > 0, nobs, 1.0)

        # Pairwise sums over mutually valid rows
        sum_x = safe.T @ fmask
        sum_y = fmask.T @ safe

        # `safe` is already zero on invalid rows, so no extra masking needed.
        sum_conj_xy = jnp.conj(safe).T @ safe

        safe_abs_sq = jnp.abs(safe) ** 2
        sum_abs_x2 = safe_abs_sq.T @ fmask
        sum_abs_y2 = fmask.T @ safe_abs_sq

        ssx = sum_abs_x2 - jnp.abs(sum_x) ** 2 / nobs_safe
        ssy = sum_abs_y2 - jnp.abs(sum_y) ** 2 / nobs_safe
        sxy = sum_conj_xy - (jnp.conj(sum_x) * sum_y) / nobs_safe

        if cov:
            denom = jnp.where(nobs > 1, nobs - 1, jnp.nan)
            result = sxy / denom
        else:
            denom = jnp.sqrt(ssx * ssy)
            result = jnp.where(denom > 0, sxy / denom, jnp.nan)
            # Project numerical drift back onto the unit disk.
            magnitude = jnp.abs(result)
            result = jnp.where(magnitude > 1.0, result / magnitude, result)

        return jnp.where(nobs < minp, jnp.nan, result)

    @classmethod
    def _pearson(
        cls, mat: jnp.ndarray, cov: Optional[bool] = False, minp: Optional[int] = 1
    ) -> jnp.ndarray:
        """
        Based on Pandas correlation method as implemented here:
        https://github.com/pandas-dev/pandas/blob/main/pandas/_libs/algos.pyx

        Compute Pearson correlation between columns of `mat`,
        permitting missing values (NaN or ±Inf).

        If the input is complex, real and imaginary parts are stacked along
        the sample axis so that both components contribute to the correlation
        without discarding information.

        Args:
            mat : array_like, shape (N, K)
                Input data.
            cov : bool, optional
                If True, return the sample covariance matrix instead of
                correlation. Defaults to False.
            minp : int, optional
                Minimum number of paired observations required to form a correlation.
                If the number of valid pairs for (i, j) is < minp, the result is NaN.

        Returns:
            corr : ndarray, shape (K, K)
                Pearson correlation matrix.
        """
        # Convert first so array_like input (e.g. nested lists) also works.
        mat = jnp.asarray(mat)

        # Preserve complex information by splitting into real / imag samples
        if jnp.iscomplexobj(mat):
            mat = jnp.concatenate([mat.real, mat.imag], axis=0)

        mask = jnp.isfinite(mat)
        fmask = mask.astype(mat.dtype)

        # Replace non-finite entries with 0 so arithmetic is safe;
        # the mask keeps track of validity.
        safe = jnp.where(mask, mat, 0.0)

        # Pairwise valid-observation counts (K, K)
        nobs = fmask.T @ fmask
        nobs_safe = jnp.where(nobs > 0, nobs, 1.0)

        # For columns i, j the "valid" rows are mask[:,i] & mask[:,j].
        # safe.T @ fmask gives entry (i,j) = sum of safe[:,i] * mask[:,j],
        # i.e. the sum of column i over the rows where both are valid.
        sum_x = safe.T @ fmask  # (K, K) – row-var sums
        sum_y = fmask.T @ safe  # (K, K) – col-var sums
        sum_xy = safe.T @ safe  # (K, K); `safe` is zero on invalid rows

        # Sum of squares over the *pair* mask, not just the column mask.
        safe_sq = safe**2
        sum_x2 = safe_sq.T @ fmask  # (K, K)
        sum_y2 = fmask.T @ safe_sq  # (K, K)

        # Computational formula (implicitly mean-centred):
        #   ssx = Σx² − (Σx)²/n, ssy = Σy² − (Σy)²/n, sxy = Σxy − (Σx)(Σy)/n
        ssx = sum_x2 - sum_x**2 / nobs_safe
        ssy = sum_y2 - sum_y**2 / nobs_safe
        sxy = sum_xy - (sum_x * sum_y) / nobs_safe

        if cov:
            denom = jnp.where(nobs > 1, nobs - 1, jnp.nan)
            result = sxy / denom
        else:
            denom = jnp.sqrt(ssx * ssy)
            result = jnp.where(denom > 0, sxy / denom, jnp.nan)
            # clip numerical drift to [-1, 1]
            result = jnp.clip(result, -1.0, 1.0)

        # Enforce minp: set entries with too few observations to NaN
        return jnp.where(nobs < minp, jnp.nan, result)

    @classmethod
    def _spearman(cls, mat: jnp.ndarray, minp: Optional[int] = 1) -> jnp.ndarray:
        """
        Based on Pandas correlation method as implemented here:
        https://github.com/pandas-dev/pandas/blob/main/pandas/_libs/algos.pyx

        Compute Spearman correlation between columns of `mat`,
        permitting missing values (NaN or ±Inf).

        If the input is complex, real and imaginary parts are stacked along
        the sample axis so that both components contribute to the correlation
        without discarding information.

        Args:
            mat : array_like, shape (N, K)
                Input data.
            minp : int, optional
                Minimum number of paired observations required to form a correlation.
                If the number of valid pairs for (i, j) is < minp, the result is NaN.

        Returns:
            corr : ndarray, shape (K, K)
                Spearman correlation matrix.
        """
        mat = jnp.asarray(mat)

        # Preserve complex information by splitting into real / imag samples
        if jnp.iscomplexobj(mat):
            mat = jnp.concatenate([mat.real, mat.imag], axis=0)

        N, K = mat.shape

        # trivial all-NaN answer if too few rows
        if N < minp:
            return jnp.full((K, K), jnp.nan)

        # Rank column-wise on the host, ignoring non-finite entries (which
        # stay NaN). Converting once avoids per-column device indexing.
        mat_np = np.asarray(mat)
        finite = np.isfinite(mat_np)
        ranks = np.full((N, K), np.nan)
        for j in range(K):
            valid = finite[:, j]
            if valid.any():
                ranks[valid, j] = rankdata(mat_np[valid, j], method="average")

        # Spearman == Pearson on the ranks (same pairwise-NaN handling).
        return cls._pearson(jnp.asarray(ranks), minp=minp)

    @classmethod
    def _weighting_linear(cls, fourier_fingerprint: jnp.ndarray) -> jnp.ndarray:
        """
        Performs weighting on the given correlation matrix.
        Here, low-frequent coefficients are weighted more heavily.

        Args:
            fourier_fingerprint (jnp.ndarray): Correlation matrix (square, odd size)

        Returns:
            jnp.ndarray: The weighted correlation matrix.
        """
        assert (
            fourier_fingerprint.shape[0] % 2 != 0
            and fourier_fingerprint.shape[1] % 2 != 0
        ), (
            "Correlation matrix must have odd dimensions. "
            "Hint: use `trim` argument when calling `get_spectrum`."
        )
        assert fourier_fingerprint.shape[0] == fourier_fingerprint.shape[1], (
            "Correlation matrix must be square."
        )

        # The weight matrix is a "tent" sum along the two axes. With
        # N = fourier_fingerprint.shape[0] (odd) and center = N // 2,
        #     W[i, j] = u[i] + u[j]
        # where u[k] = (center - |k - center|) / (2 * center)
        # is a triangular weighting peaking at the centre (the zero
        # frequency) and decaying linearly to 0 at the spectrum edges.
        # NOTE(review): for N == 1, center == 0 and u is 0/0 (NaN).
        N = fourier_fingerprint.shape[0]
        center = N // 2
        k = jnp.arange(N)
        u = (center - jnp.abs(k - center)) / (2 * center)

        return fourier_fingerprint * (u[:, None] + u[None, :])

    @classmethod
    def _weighting_mean(
        cls, fourier_fingerprint: jnp.ndarray, coeffs: jnp.ndarray
    ) -> jnp.ndarray:
        """
        Performs weighting on the given correlation matrix.
        Here, we use the product of the mean of the coefficients as weights.
        This suppresses correlations where the mean of the coefficients is near zero.

        Args:
            fourier_fingerprint (jnp.ndarray): Correlation matrix
            coeffs (jnp.ndarray): Fourier coefficients (sample axis last)

        Returns:
            jnp.ndarray: The weighted correlation matrix.
        """
        assert fourier_fingerprint.shape[0] == fourier_fingerprint.shape[1], (
            "Correlation matrix must be square."
        )
        assert len(coeffs.shape) >= 2, (
            "Coefficient matrix must contain coefficient axes and a sample axis."
        )

        # `.T` before flattening yields F order, matching the variable order
        # produced by `_correlate(coeffs.transpose())`.
        coefficient_means = jnp.abs(jnp.mean(coeffs, axis=-1))
        coefficient_means = coefficient_means.T.reshape(-1)

        assert fourier_fingerprint.shape[0] == coefficient_means.shape[0], (
            "Correlation matrix size must match the number of Fourier coefficients."
        )

        # Apply the rank-1 weight w[i] * w[j] via broadcasting instead
        # of materialising an explicit `jnp.outer` N x N intermediate.
        return (
            fourier_fingerprint
            * coefficient_means[:, None]
            * coefficient_means[None, :]
        )

1652 

1653 

1654class Datasets: 

1655 @classmethod 

1656 def generate_fourier_series( 

1657 cls, 

1658 random_key: random.PRNGKey, 

1659 model: Model, 

1660 coefficients_min: float = 0.0, 

1661 coefficients_max: float = 1.0, 

1662 zero_centered: bool = False, 

1663 ) -> jnp.ndarray: 

1664 """ 

1665 Generates the Fourier series representation of a function. 

1666 It uses the `model.frequencies` property to retrieve the frequency 

1667 information. This ensures that the resulting Fourier series is 

1668 compatible with the model. 

1669 

1670 This function is capable of generating $D$-dimensional Fourier series 

1671 (again defined by `model.n_input_feat`). 

1672 The highest frequency $N$ is retrieved per dimension. 

1673 

1674 Samples of the Fourier coefficients are drawn from a uniform circle. 

1675 

1676 Args: 

1677 random_key (random.PRNGKey): Random number key for JAX. 

1678 model (Model): The quantum circuit model. 

1679 coefficients_min (float, optional): Minimum value for the coefficients. 

1680 Defaults to 0.0. 

1681 coefficients_max (float, optional): Maximum value for the coefficients. 

1682 Defaults to 1.0. 

1683 zero_centered (bool, optional): Whether to zero-center the coefficients. 

1684 Defaults to False. 

1685 

1686 Returns: 

1687 jnp.ndarray: Input domain samples with shape ((N,)*D, D) 

1688 jnp.ndarray: Fourier series values with shape ((N,)*D) 

1689 jnp.ndarray: Fourier coefficients with shape ((N,)*D) 

1690 

1691 """ 

1692 # TODO: the following code can be considered to 

1693 # capturing a truly random spectrum. 

1694 # add some constraints on the spectrum, i.e. not fully 

1695 

1696 # Note: one key observation for understanding the following code is, 

1697 # that instead of wrapping your head around symmetries in multi- 

1698 # dimensional coefficient matrices, one can simply look at the flattened 

1699 # version of such a matrix and reshape later. It just works out. 

1700 

1701 # going from [0, 2pi] with the resolution required for highest frequency 

1702 # permute with input dimensionality to get an n-d grid of domain samples 

1703 # the output shape comes from the fact that want to create a "coordinate system" 

1704 domain_samples_per_input_dim = jnp.stack( 

1705 jnp.meshgrid( 

1706 *[jnp.arange(0, 2 * jnp.pi, 2 * jnp.pi / d) for d in model.degree] 

1707 ) 

1708 ).T.reshape(-1, model.n_input_feat) 

1709 

1710 # generate the frequency indices for each dimension. 

1711 # this will have the same shape as the domain samples 

1712 frequencies = jnp.stack(jnp.meshgrid(*model.frequencies)).T.reshape( 

1713 -1, model.n_input_feat 

1714 ) 

1715 

1716 # using the frequency information, sample coefficients for each dimension 

1717 # shape: (input_dims, n_freqs_per_input_dim // 2 + 1) 

1718 

1719 coefficients = cls.uniform_circle( 

1720 random_key, 

1721 low=coefficients_min, 

1722 high=coefficients_max, 

1723 size=math.prod(model.degree) // 2 + 1, 

1724 ) 

1725 

1726 # zero center (first coeff = 0) 

1727 # we can assume the first coeff is the offset, because we're dealing 

1728 # with a non-symmetric spectrum here 

1729 if zero_centered: 

1730 coefficients = coefficients.at[0].set(0.0) 

1731 else: 

1732 coefficients = coefficients.at[0].set(coefficients[0].real) 

1733 

1734 # ensure symmetry (here, non_negative_ is removed!), 

1735 # giving us the full coefficients vector 

1736 coefficients = jnp.concat( 

1737 [ 

1738 jnp.flip(coefficients[..., 1:]).conjugate(), 

1739 coefficients, 

1740 ], 

1741 axis=-1, 

1742 ) 

1743 

1744 # Vectorized version of $f(x) = \sum_{n=0}^{N-1} c_n * e^{i * \omega_n * x}$ 

1745 # it takes into account the input dimension, i.e. the output is a matrix 

1746 # normalization uses the n_freqs component of the coefficients 

1747 values = jnp.real( 

1748 ( 

1749 jnp.exp(1j * (domain_samples_per_input_dim @ frequencies.T)) 

1750 * coefficients 

1751 ).sum(axis=1) 

1752 / coefficients.size 

1753 ) 

1754 

1755 # return all the information we have 

1756 return [ 

1757 domain_samples_per_input_dim.reshape(*model.degree, -1), 

1758 values.reshape(model.degree), 

1759 coefficients.reshape(model.degree), 

1760 ] 

1761 

1762 @classmethod 

1763 def uniform_circle( 

1764 cls, 

1765 random_key: random.PRNGKey, 

1766 size: Union[jnp.ndarray, List, int], 

1767 low=0.0, 

1768 high=1.0, 

1769 ): 

1770 """ 

1771 Random number generator for complex numbers sampled inside the unit circle 

1772 

1773 Args: 

1774 random_key (random.PRNGKey): Random number key for JAX. 

1775 size (Union[jnp.ndarray, int]): Number of samples. If a 2D array is passed, 

1776 the first dimension will be the number of dimensions. 

1777 low (float, optional): Minimum Radius. Defaults to 0.0. 

1778 high (float, optional): Maximum Radius. Defaults to 1.0. 

1779 

1780 Returns 

1781 jnp.ndarray: Array of complex numbers with shape of `size` 

1782 """ 

1783 

1784 if isinstance(size, int): 

1785 size = jnp.array([size]) 

1786 

1787 random_key, random_key1 = random.split(random_key) 

1788 return jnp.sqrt( 

1789 random.uniform(random_key, size, minval=low, maxval=high) 

1790 ) * jnp.exp(2j * jnp.pi * random.uniform(random_key1, size))