Coverage for qml_essentials/coefficients.py: 96%

424 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2026-02-20 14:03 +0000

1from __future__ import annotations 

2import math 

3from collections import defaultdict 

4from dataclasses import dataclass 

5import pennylane as qml 

6import jax.numpy as jnp 

7from jax import random 

8import numpy as np 

9from pennylane.operation import Operator 

10import pennylane.ops.op_math as qml_op 

11from scipy.stats import rankdata 

12from functools import reduce 

13from typing import List, Tuple, Optional, Any, Dict, Union 

14 

15from qml_essentials.model import Model 

16from qml_essentials.utils import PauliCircuit 

17 

18import logging 

19 

20log = logging.getLogger(__name__) 

21 

22 

class Coefficients:
    """
    Numerical Fourier analysis of models: sample the model on an equidistant
    grid over the input domain and obtain its spectrum via a
    (multi-dimensional) FFT.
    """

    @staticmethod
    def get_spectrum(
        model: Model,
        mfs: int = 1,
        mts: int = 1,
        shift=False,
        trim=False,
        **kwargs,
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """
        Extracts the coefficients of a given model using a FFT (jnp-fft).

        Note that the coefficients are complex numbers, but the imaginary part
        of the coefficients should be very close to zero, since the expectation
        values of the Pauli operators are real numbers.

        It can perform oversampling in both the frequency and time domain
        using the `mfs` and `mts` arguments.

        Args:
            model (Model): The model to sample.
            mfs (int): Multiplicator for the highest frequency. Default is 1.
            mts (int): Multiplicator for the number of time samples. Default is 1.
            shift (bool): Whether to apply jnp-fftshift. Default is False.
            trim (bool): Whether to remove the Nyquist frequency along every
                input axis that has an even number of samples.
                Default is False.
            kwargs (Any): Additional keyword arguments for the model function.
                If not given, `force_mean` defaults to True and
                `execution_type` to "expval".

        Returns:
            Tuple[jnp.ndarray, jnp.ndarray]: Tuple containing the coefficients
            and frequencies. For a single input feature the frequencies are a
            single array, otherwise a list with one array per input feature.

        Raises:
            ValueError: If the sum of the coefficients has a non-negligible
                imaginary part.
        """
        kwargs.setdefault("force_mean", True)
        kwargs.setdefault("execution_type", "expval")

        coeffs, freqs = Coefficients._fourier_transform(
            model, mfs=mfs, mts=mts, **kwargs
        )

        imag_sum = jnp.sum(coeffs).imag
        if not jnp.isclose(imag_sum, 0.0, rtol=1.0e-5):
            raise ValueError(
                "Spectrum is not real. Imaginary part of coefficients is: "
                f"{imag_sum}"
            )

        if trim:
            for ax in range(model.n_input_feat):
                n_samples = coeffs.shape[ax]
                if n_samples % 2 == 0:
                    # In an unshifted FFT the Nyquist bin sits at index n/2 of
                    # the respective axis. Remove it from that axis only and
                    # from the frequency vector belonging to that axis.
                    mid = n_samples // 2
                    coeffs = np.delete(coeffs, mid, axis=ax)
                    freqs[ax] = np.delete(freqs[ax], mid)

        if shift:
            coeffs = jnp.fft.fftshift(coeffs, axes=list(range(model.n_input_feat)))
            # Shift each frequency vector on its own; shifting the whole list
            # at once would also shift (i.e. permute) the feature axis.
            freqs = [np.fft.fftshift(freq) for freq in freqs]

        if len(freqs) == 1:
            freqs = freqs[0]

        return coeffs, freqs

    @staticmethod
    def _fourier_transform(
        model: Model, mfs: int, mts: int, **kwargs: Any
    ) -> Tuple[jnp.ndarray, List[jnp.ndarray]]:
        """
        Samples the model on an equidistant grid over [0, 2*pi*mts) for every
        input feature and computes the normalised n-dimensional FFT.

        Args:
            model (Model): The model to sample.
            mfs (int): Multiplicator for the highest frequency.
            mts (int): Multiplicator for the number of time samples.
            kwargs (Any): Additional keyword arguments for the model function.

        Returns:
            Tuple[jnp.ndarray, List[jnp.ndarray]]: The FFT coefficients
            (input-feature axes first), normalised by the number of grid
            points, and one frequency vector per input feature.
        """
        n_feat = model.n_input_feat

        # Create a frequency vector with as many frequencies as model degrees,
        # oversampled by mfs
        n_freqs: jnp.ndarray = jnp.array([mfs * model.degree[i] for i in range(n_feat)])

        # Number of samples per feature (oversampled in time by mts). The
        # count is computed explicitly rather than via a float-step arange,
        # so the grid length always matches the length of the fftfreq vector.
        n_samples = [int(mts * n_freqs[i]) for i in range(n_feat)]
        step = 2 * jnp.pi / n_freqs
        inputs: List = [jnp.arange(n_samples[i]) * step[i] for i in range(n_feat)]

        # Cartesian product of all per-feature grids. "ij" indexing makes the
        # flattened order row-major over (feature 0, feature 1, ...), which
        # matches the reshape of the outputs below for any number of features.
        nd_inputs = jnp.stack(jnp.meshgrid(*inputs, indexing="ij"), axis=-1).reshape(
            -1, n_feat
        )

        # Output vector is not necessarily the same length as input
        outputs = model(inputs=nd_inputs, **kwargs)
        outputs = outputs.reshape(*n_samples, -1).squeeze()

        coeffs = jnp.fft.fftn(outputs, axes=list(range(n_feat)))

        freqs = [
            jnp.fft.fftfreq(n_samples[i], 1 / n_freqs[i]) for i in range(n_feat)
        ]

        # Normalise by the total number of grid points (product over features)
        return coeffs / math.prod(n_samples), freqs

    @staticmethod
    def get_psd(coeffs: jnp.ndarray) -> jnp.ndarray:
        """
        Calculates the power spectral density (PSD) from given Fourier coefficients.

        The PSD is computed as ``2 / len(coeffs)**2 * |coeffs|**2``; note that
        ``len`` only considers the first axis of `coeffs`.

        Args:
            coeffs (jnp.ndarray): The Fourier coefficients.

        Returns:
            jnp.ndarray: The power spectral density.
        """
        # TODO: if we apply trim=True in advance, this will be slightly wrong..

        def abs2(x):
            # squared magnitude without taking a square root
            return x.real**2 + x.imag**2

        scale = 2.0 / (len(coeffs) ** 2)
        return scale * abs2(coeffs)

    @staticmethod
    def evaluate_Fourier_series(
        coefficients: jnp.ndarray,
        frequencies: jnp.ndarray,
        inputs: Union[jnp.ndarray, list, float],
    ) -> float:
        """
        Evaluate the function value of a Fourier series at one point.

        Args:
            coefficients (jnp.ndarray): Coefficients of the Fourier series.
            frequencies (jnp.ndarray): Corresponding frequencies. A list is
                interpreted as one frequency vector per input feature.
            inputs (jnp.ndarray): Point at which to evaluate the function.
                Lists are converted to arrays.

        Returns:
            float: The real part of the function value at the input point
            (squeezed).
        """
        # Ensure a trailing "output" axis on the coefficients
        if isinstance(frequencies, list):
            if len(coefficients.shape) <= len(frequencies):
                coefficients = coefficients[..., jnp.newaxis]
        else:
            if len(coefficients.shape) == 1:
                coefficients = coefficients[..., jnp.newaxis]

        if isinstance(inputs, list):
            inputs = jnp.array(inputs)
        if len(inputs.shape) < 1:
            inputs = inputs[jnp.newaxis, ...]

        if isinstance(frequencies, list):
            input_dim = len(frequencies)
            frequencies = jnp.stack(jnp.meshgrid(*frequencies))
            if input_dim != len(inputs):
                frequencies = jnp.repeat(
                    frequencies[jnp.newaxis, ...], inputs.shape[0], axis=0
                )
                freq_inputs = jnp.einsum("bi...,b->b...", frequencies, inputs)
                exponents = jnp.exp(1j * freq_inputs).T
                exp = jnp.einsum("jl...k,jl...b->b...k", coefficients, exponents)
            else:
                freq_inputs = jnp.einsum("i...,i->...", frequencies, inputs)
                exponents = jnp.exp(1j * freq_inputs).T
                exp = jnp.einsum("jl...k,jl...->k...", coefficients, exponents)
        else:
            frequencies = jnp.repeat(
                frequencies[jnp.newaxis, ...], inputs.shape[0], axis=0
            )
            freq_inputs = jnp.einsum("i...,i->i...", frequencies, inputs)
            exponents = jnp.exp(1j * freq_inputs)
            exp = jnp.einsum("j...k,ij...->ik...", coefficients, exponents)

        return jnp.squeeze(jnp.real(exp))

197 return jnp.squeeze(jnp.real(exp)) 

198 

199 

class FourierTree:
    """
    Sine-cosine tree representation for the algorithm by Nemkov et al.
    This tree can be used to obtain analytical Fourier coefficients for a given
    Pauli-Clifford circuit.
    """

    class CoefficientsTreeNode:
        """
        Representation of a node in the coefficients tree for the algorithm by
        Nemkov et al.
        """

        def __init__(
            self,
            parameter_idx: Optional[int],
            observable: FourierTree.PauliOperator,
            is_sine_factor: bool,
            is_cosine_factor: bool,
            left: Optional[FourierTree.CoefficientsTreeNode] = None,
            right: Optional[FourierTree.CoefficientsTreeNode] = None,
        ):
            """
            Coefficient tree node initialisation. Each node has information about
            its creation context and its children, i.e.:

            Args:
                parameter_idx (Optional[int]): Index of the corresp. parameter
                    (None for a root node).
                observable (FourierTree.PauliOperator): The node's observable to
                    obtain the expectation value that contributes to the constant
                    term.
                is_sine_factor (bool): If this node belongs to a sine coefficient.
                is_cosine_factor (bool): If this node belongs to a cosine coefficient.
                left (Optional[CoefficientsTreeNode]): left child (if any).
                right (Optional[CoefficientsTreeNode]): right child (if any).
            """
            self.parameter_idx = parameter_idx

            assert not (
                is_sine_factor and is_cosine_factor
            ), "Cannot be sine and cosine at the same time"
            self.is_sine_factor = is_sine_factor
            self.is_cosine_factor = is_cosine_factor

            # If the observable does not consist of only Z and I, the
            # expectation (and therefore the constant node term) is zero
            if jnp.logical_or(
                observable.list_repr == 0, observable.list_repr == 1
            ).any():
                self.term = 0.0
            else:
                self.term = observable.phase

            self.left = left
            self.right = right

        def evaluate(self, parameters: list[float]) -> float:
            """
            Recursive function to evaluate the expectation of the coefficient tree,
            starting from the current node.

            Args:
                parameters (list[float]): The parameters, by which the circuit (and
                    therefore the tree) is parametrised.

            Returns:
                float: The expectation for the current node and its children.
            """
            factor = (
                parameters[self.parameter_idx]
                if self.parameter_idx is not None
                else 1.0
            )
            if self.is_sine_factor:
                factor = 1j * jnp.sin(factor)
            elif self.is_cosine_factor:
                factor = jnp.cos(factor)

            if self.left is None and self.right is None:  # leaf
                return factor * self.term

            sum_children = 0.0
            for child in (self.left, self.right):
                if child is not None:
                    sum_children = sum_children + child.evaluate(parameters)

            return factor * sum_children

        def get_leafs(
            self,
            sin_list: jnp.ndarray,
            cos_list: jnp.ndarray,
            existing_leafs: Optional[List[FourierTree.TreeLeaf]] = None,
        ) -> List[FourierTree.TreeLeaf]:
            """
            Traverse the tree starting from the current node, to obtain the tree
            leafs only.
            The leafs correspond to the terms in the sine-cosine tree
            representation that eventually are used to obtain coefficients and
            frequencies.
            Sine and cosine counters are recursively passed to the children
            until a leaf is reached (top to bottom).
            Leafs are then passed bottom to top to the caller.
            Leafs with a zero constant term are dropped.

            Args:
                sin_list (jnp.ndarray): Current number of sine contributions for
                    each parameter. Has the same length as the parameters, as each
                    position corresponds to one parameter.
                cos_list (jnp.ndarray): Current number of cosine contributions for
                    each parameter. Has the same length as the parameters, as each
                    position corresponds to one parameter.
                existing_leafs (Optional[List[TreeLeaf]]): List to which the leafs
                    of the subtrees are appended. A new list is created if None.

            Returns:
                List[TreeLeaf]: Updated list of leaf nodes.
            """
            # Fresh list per call: a shared mutable default would leak leafs
            # between independent traversals.
            if existing_leafs is None:
                existing_leafs = []

            # jax arrays are immutable; ``.at[...]`` returns a new array, so the
            # counters can be handed to both subtrees without copying.
            if self.is_sine_factor:
                sin_list = sin_list.at[self.parameter_idx].add(1)
            if self.is_cosine_factor:
                cos_list = cos_list.at[self.parameter_idx].add(1)

            if self.left is None and self.right is None:  # leaf
                if self.term != 0.0:
                    return [FourierTree.TreeLeaf(sin_list, cos_list, self.term)]
                return []

            # Children collect into their own fresh lists, so leafs already in
            # existing_leafs are not duplicated.
            for child in (self.left, self.right):
                if child is not None:
                    existing_leafs.extend(child.get_leafs(sin_list, cos_list))
            return existing_leafs

    @dataclass
    class TreeLeaf:
        """
        Coefficient tree leafs according to the algorithm by Nemkov et al., which
        correspond to the terms in the sine-cosine tree representation that
        eventually are used to obtain coefficients and frequencies.

        Args:
            sin_indices (jnp.ndarray): Number of sine contributions for each
                parameter. Has the same length as the parameters, as each
                position corresponds to one parameter.
            cos_indices (jnp.ndarray): Number of cosine contributions for each
                parameter. Has the same length as the parameters, as each
                position corresponds to one parameter.
            term (jnp.complex): Constant factor of the leaf, depending on the
                expectation value of the observable, and a phase.
        """

        sin_indices: jnp.ndarray
        cos_indices: jnp.ndarray
        term: jnp.complex128

    class PauliOperator:
        """
        Utility class for storing Pauli Rotations, the corresponding indices
        in the XY-Space (whether there is a gate with X or Y generator at a
        certain qubit) and the phase.

        Args:
            pauli (Union[Operator, jnp.ndarray[int]]): Pauli Rotation Operation
                or list representation
            n_qubits (int): Number of qubits in the circuit
            prev_xy_indices (Optional[jnp.ndarray[bool]]): X/Y indices of the
                previous Pauli sequence. Defaults to None.
            is_observable (bool): If the operator is an observable. Defaults to
                False.
            is_init (bool): If this Pauli operator is initialised the first
                time. Defaults to True.
            phase (float): Phase of the operator. Defaults to 1.0
        """

        def __init__(
            self,
            pauli: Union[Operator, jnp.ndarray[int]],
            n_qubits: int,
            prev_xy_indices: Optional[jnp.ndarray[bool]] = None,
            is_observable: bool = False,
            is_init: bool = True,
            phase: float = 1.0,
        ):
            self.is_observable = is_observable
            self.phase = phase

            if is_init:
                if not is_observable:
                    # rotation gate: encode its generator
                    pauli = pauli.generator()[0].base
                self.list_repr = self._create_list_representation(pauli, n_qubits)
            else:
                assert isinstance(pauli, jnp.ndarray)
                self.list_repr = pauli

            if prev_xy_indices is None:
                prev_xy_indices = jnp.zeros(n_qubits, dtype=bool)
            # accumulate X/Y positions over the gate sequence so far
            self.xy_indices = jnp.logical_or(
                prev_xy_indices,
                self._compute_xy_indices(self.list_repr, rev=is_observable),
            )

        @staticmethod
        def _compute_xy_indices(
            op: jnp.ndarray[int], rev: bool = False
        ) -> jnp.ndarray[bool]:
            """
            Computes the positions of X or Y gates to a one-hot encoded boolean
            array.

            Args:
                op (jnp.ndarray[int]): Pauli-Operation list representation.
                rev (bool): Whether to negate the array.

            Returns:
                jnp.ndarray[bool]: One hot encoded boolean array.
            """
            xy_indices = jnp.logical_or(op == 0, op == 1)
            if rev:
                xy_indices = ~xy_indices
            return xy_indices

        @staticmethod
        def _create_list_representation(
            op: Operator, n_qubits: int
        ) -> jnp.ndarray[int]:
            """
            Create list representation of a Pennylane Operator.
            Generally, the list representation is a list of length n_qubits,
            where at each position a Pauli Operator is encoded as such:
            I: -1
            X: 0
            Y: 1
            Z: 2

            Args:
                op (Operator): Pennylane Operator
                n_qubits (int): number of qubits in the circuit

            Returns:
                jnp.ndarray[int]: List representation
            """
            pauli_repr = -jnp.ones(n_qubits, dtype=int)
            op = op.terms()[1][0] if isinstance(op, qml_op.Prod) else op
            op = op.base if isinstance(op, qml_op.SProd) else op

            if isinstance(op, qml.PauliX):
                pauli_repr = pauli_repr.at[op.wires[0]].set(0)
            elif isinstance(op, qml.PauliY):
                pauli_repr = pauli_repr.at[op.wires[0]].set(1)
            elif isinstance(op, qml.PauliZ):
                pauli_repr = pauli_repr.at[op.wires[0]].set(2)
            else:
                for o in op:
                    if isinstance(o, qml.PauliX):
                        pauli_repr = pauli_repr.at[o.wires[0]].set(0)
                    elif isinstance(o, qml.PauliY):
                        pauli_repr = pauli_repr.at[o.wires[0]].set(1)
                    elif isinstance(o, qml.PauliZ):
                        pauli_repr = pauli_repr.at[o.wires[0]].set(2)
            return pauli_repr

        def is_commuting(self, pauli: jnp.ndarray[int]) -> bool:
            """
            Computes if this Pauli commutes with another Pauli operator.
            This computation is based on the fact that the commutator is zero
            if and only if the number of anticommuting single-qubit Paulis is
            even.

            Args:
                pauli (jnp.ndarray[int]): List representation of another Pauli

            Returns:
                bool: If the current and other Pauli are commuting.
            """
            # a position anticommutes iff both are non-identity and different
            anticommutator = jnp.where(
                pauli < 0,
                False,
                jnp.where(
                    self.list_repr < 0,
                    False,
                    jnp.where(self.list_repr == pauli, False, True),
                ),
            )
            return not (jnp.sum(anticommutator) % 2)

        def tensor(self, pauli: jnp.ndarray[int]) -> FourierTree.PauliOperator:
            """
            Compute the product between the current Pauli and a given list
            representation of another Pauli operator (qubit-wise, including the
            resulting phase).

            Args:
                pauli (jnp.ndarray[int]): List representation of Pauli

            Returns:
                FourierTree.PauliOperator: New Pauli operator object (an
                observable), which contains the product and its phase.
            """
            diff = (pauli - self.list_repr + 3) % 3
            phase = jnp.where(
                self.list_repr < 0,
                1.0,
                jnp.where(
                    pauli < 0,
                    1.0,
                    jnp.where(
                        diff == 2,
                        1.0j,
                        jnp.where(diff == 1, -1.0j, 1.0),
                    ),
                ),
            )

            obs = jnp.where(
                self.list_repr < 0,
                pauli,
                jnp.where(
                    pauli < 0,
                    self.list_repr,
                    jnp.where(
                        diff == 2,
                        (self.list_repr + 1) % 3,
                        jnp.where(diff == 1, (self.list_repr + 2) % 3, -1),
                    ),
                ),
            )
            phase = self.phase * jnp.prod(phase)
            return FourierTree.PauliOperator(
                obs, phase=phase, n_qubits=obs.size, is_init=False, is_observable=True
            )

    def __init__(self, model: Model, inputs: Optional[jnp.ndarray] = None):
        """
        Tree initialisation, based on the Pauli-Clifford representation of a model.
        Currently, only one input feature is supported.

        **Usage**:
        ```
        # initialise a model
        model = Model(...)

        # initialise and build FourierTree
        tree = FourierTree(model)

        # get expectation value
        exp = tree()

        # Get spectrum (for each observable, we have one list element)
        coeff_list, freq_list = tree.get_spectrum()
        ```

        Args:
            model (Model): The Model, for which to build the tree
            inputs (Optional[jnp.ndarray]): Inputs used to construct the
                circuit tape. Defaults to [1.0] if None.
        """
        self.model = model
        self.tree_roots = None

        inputs = (
            self.model._inputs_validation(inputs)
            if inputs is not None
            else self.model._inputs_validation([1.0])
        )

        # TODO: duplicate the input to find out, where it is in the tape. Not
        # really pretty.
        if inputs.shape[0] == 1:
            inputs = jnp.repeat(inputs, 2, axis=0)

        quantum_tape = qml.workflow.construct_tape(self.model.circuit)(
            params=model.params, inputs=inputs
        )

        quantum_tape = PauliCircuit.from_parameterised_circuit(quantum_tape)

        self.parameters = [jnp.squeeze(p) for p in quantum_tape.get_parameters()]

        # the (duplicated) inputs are the only non-scalar parameters
        self.input_indices = [
            i for (i, p) in enumerate(self.parameters) if p.shape != ()
        ]

        self.observables = self._encode_observables(quantum_tape.observables)

        pauli_rot = FourierTree.PauliOperator(
            quantum_tape.operations[0],
            self.model.n_qubits,
        )
        self.pauli_rotations = [pauli_rot]
        for op in quantum_tape.operations[1:]:
            pauli_rot = FourierTree.PauliOperator(
                op, self.model.n_qubits, pauli_rot.xy_indices
            )
            self.pauli_rotations.append(pauli_rot)

        self.tree_roots = self.build()
        self.leafs: List[List[FourierTree.TreeLeaf]] = self._get_tree_leafs()

    def __call__(
        self,
        params: Optional[jnp.ndarray] = None,
        inputs: Optional[jnp.ndarray] = None,
        **kwargs,
    ) -> jnp.ndarray:
        """
        Evaluates the Fourier tree via sine-cosine terms sum. This is
        equivalent to computing the expectation value of the observables with
        respect to the corresponding circuit.

        Note: this re-extracts the circuit parameters and stores them in
        ``self.parameters``.

        Args:
            params (Optional[jnp.ndarray], optional): Parameters of the model.
                Defaults to None (use the model's parameters).
            inputs (Optional[jnp.ndarray], optional): Inputs to the circuit.
                Defaults to None.
            kwargs: ``force_mean`` (average over observables),
                ``execution_type`` and ``noise_params`` are inspected.

        Returns:
            jnp.ndarray: Expectation value of the tree.

        Raises:
            NotImplementedError: When using other "execution_type" as expval.
            NotImplementedError: When using "noise_params"
        """
        params = (
            self.model._params_validation(params)
            if params is not None
            else self.model.params
        )
        inputs = (
            self.model._inputs_validation(inputs)
            if inputs is not None
            else self.model._inputs_validation(1.0)
        )

        if kwargs.get("execution_type", "expval") != "expval":
            raise NotImplementedError(
                f'Currently, only "expval" execution type is supported when '
                f"building FourierTree. Got {kwargs.get('execution_type', 'expval')}."
            )
        if kwargs.get("noise_params", None) is not None:
            raise NotImplementedError(
                "Currently, noise is not supported when building FourierTree."
            )

        # use the validated params explicitly instead of relying on model state
        quantum_tape = qml.workflow.construct_tape(self.model.circuit)(
            params=params, inputs=inputs
        )
        quantum_tape = PauliCircuit.from_parameterised_circuit(quantum_tape)

        self.parameters = [jnp.squeeze(p) for p in quantum_tape.get_parameters()]

        results = jnp.zeros(len(self.tree_roots))
        for i, root in enumerate(self.tree_roots):
            if root is None:
                # root pruned by early stopping: expectation is zero
                continue
            results = results.at[i].set(jnp.real(root.evaluate(self.parameters)))

        if kwargs.get("force_mean", False):
            return jnp.mean(results)
        else:
            return results

    def build(self) -> List[Optional[FourierTree.CoefficientsTreeNode]]:
        """
        Creates the coefficient tree, i.e. it creates and initialises the tree
        nodes.
        Leafs can be obtained separately in _get_tree_leafs, once the tree is
        set up.

        Returns:
            List[Optional[CoefficientsTreeNode]]: The list of root nodes (one
            root for each observable). A root is None if it was discarded by
            early stopping.
        """
        pauli_rotation_idx = len(self.pauli_rotations) - 1
        return [
            self._create_tree_node(obs, pauli_rotation_idx) for obs in self.observables
        ]

    def _encode_observables(
        self, tape_obs: List[Operator]
    ) -> List[FourierTree.PauliOperator]:
        """
        Encodes Pennylane observables from tape as FourierTree.PauliOperator
        utility objects.

        Args:
            tape_obs (List[Operator]): Pennylane tape observables

        Returns:
            List[FourierTree.PauliOperator]: List of Pauli operators
        """
        return [
            FourierTree.PauliOperator(obs, self.model.n_qubits, is_observable=True)
            for obs in tape_obs
        ]

    def _get_tree_leafs(self) -> List[List[FourierTree.TreeLeaf]]:
        """
        Obtain all Leaf Nodes with its sine- and cosine terms.

        Returns:
            List[List[TreeLeaf]]: For each observable (root), the list of leaf
            nodes (empty for roots discarded by early stopping).
        """
        leafs = []
        for root in self.tree_roots:
            if root is None:
                leafs.append([])
                continue
            sin_list = jnp.zeros(len(self.parameters), dtype=jnp.int32)
            cos_list = jnp.zeros(len(self.parameters), dtype=jnp.int32)
            leafs.append(root.get_leafs(sin_list, cos_list, []))
        return leafs

    def get_spectrum(
        self, force_mean: bool = False
    ) -> Tuple[List[jnp.ndarray], List[jnp.ndarray]]:
        """
        Computes the Fourier spectrum for the tree, consisting of the
        frequencies and its corresponding coefficients.
        If force_mean is set, the mean coefficient over all observables
        (roots) is computed.

        Args:
            force_mean (bool, optional): Whether to average over multiple
                observables. Defaults to False.

        Returns:
            Tuple[List[jnp.ndarray], List[jnp.ndarray]]:
                - List of coefficients, one array for each observable (root).
                - List of corresponding frequencies, one array for each
                  observable (root).
        """
        input_indices = set(self.input_indices)
        parameter_indices = [
            i for i in range(len(self.parameters)) if i not in input_indices
        ]

        coeffs = []
        for leafs in self.leafs:
            freq_terms = defaultdict(np.complex128)
            for leaf in leafs:
                leaf_factor, s, c = self._compute_leaf_factors(leaf, parameter_indices)

                # Binomial expansion of (i*sin x)^s * cos^c x into exponentials
                # e^{i k x}; the 0.5**(s+c) prefactor is part of leaf_factor.
                for a in range(s + 1):
                    for b in range(c + 1):
                        comb = math.comb(s, a) * math.comb(c, b) * (-1) ** (s - a)
                        freq_terms[2 * a + 2 * b - s - c] += comb * leaf_factor

            coeffs.append(freq_terms)

        frequencies, coefficients = self._freq_terms_to_coeffs(coeffs, force_mean)
        return coefficients, frequencies

    def _freq_terms_to_coeffs(
        self, coeffs: List[Dict[int, jnp.ndarray]], force_mean: bool
    ) -> Tuple[List[jnp.ndarray], List[jnp.ndarray]]:
        """
        Given a list of dictionaries of the form:
        [
            {
                freq_obs1_1: coeff1,
                freq_obs1_2: coeff2,
                ...
            },
            {
                freq_obs2_1: coeff3,
                freq_obs2_2: coeff4,
                ...
            }
            ...
        ],
        Compute two separate lists of frequencies and coefficients,
        such that:
        freqs: [
            [freq_obs1_1, freq_obs1_2, ...],
            [freq_obs2_1, freq_obs2_2, ...],
            ...
        ]
        coeffs: [
            [coeff1, coeff2, ...],
            [coeff3, coeff4, ...],
            ...
        ]
        Frequencies are sorted in ascending order.

        If force_mean is set, the length of the resulting frequency and
        coefficient lists is 1 (missing frequencies count as 0).

        Args:
            coeffs (List[Dict[int, jnp.ndarray]]): Frequency->Coefficients
                dictionary list, one dict for each observable (root).
            force_mean (bool): Whether to average coefficients over multiple
                observables.

        Returns:
            Tuple[List[jnp.ndarray], List[jnp.ndarray]]:
                - List of frequencies, one list for each observable (root).
                - List of corresponding coefficients, one list for each
                  observable (root).
        """
        frequencies = []
        coefficients = []
        if force_mean:
            all_freqs = sorted({f for c in coeffs for f in c.keys()})
            coefficients.append(
                jnp.array(
                    [
                        jnp.mean(jnp.array([c.get(f, 0.0) for c in coeffs]))
                        for f in all_freqs
                    ]
                )
            )
            frequencies.append(jnp.array(all_freqs))
        else:
            for freq_terms in coeffs:
                freq_terms = dict(sorted(freq_terms.items()))
                frequencies.append(jnp.array(list(freq_terms.keys())))
                coefficients.append(jnp.array(list(freq_terms.values())))
        return frequencies, coefficients

    def _compute_leaf_factors(
        self, leaf: FourierTree.TreeLeaf, parameter_indices: List[int]
    ) -> Tuple[float, int, int]:
        """
        Computes the constant coefficient factor for each leaf.
        Additionally sine and cosine contributions of the input parameters for
        this leaf are returned, which are required to obtain the corresponding
        frequencies.

        Args:
            leaf (TreeLeaf): The leaf for which to compute the factor.
            parameter_indices (List[int]): Variational parameter indices.

        Returns:
            Tuple[float, int, int]:
                - float: the constant factor for the leaf
                - int: number of sine contributions of the input
                - int: number of cosine contributions of the input
        """
        leaf_factor = 1.0
        for i in parameter_indices:
            leaf_factor = leaf_factor * (
                jnp.cos(self.parameters[i]) ** leaf.cos_indices[i]
                * (1j * jnp.sin(self.parameters[i])) ** leaf.sin_indices[i]
            )

        # Number of sine and cosine factors to which the input contributes,
        # as plain Python ints (used for range() and math.comb)
        c = sum(int(leaf.cos_indices[k]) for k in self.input_indices)
        s = sum(int(leaf.sin_indices[k]) for k in self.input_indices)

        leaf_factor = leaf.term * leaf_factor * 0.5 ** (s + c)

        return leaf_factor, s, c

    def _early_stopping_possible(
        self, pauli_rotation_idx: int, observable: FourierTree.PauliOperator
    ):
        """
        Checks if a node for an observable can be discarded as all expectation
        values that can result through further branching are zero.
        The method is mentioned in the paper by Nemkov et al.: If the one-hot
        encoded indices for X/Y operations in the Pauli-rotation generators are
        a basis for that of the observable, the node must be processed further.
        If not, it can be discarded.

        Args:
            pauli_rotation_idx (int): Index of remaining Pauli rotation gates.
                Gates itself are attributes of the class. A negative index
                means no rotations remain.
            observable (FourierTree.PauliOperator): Current observable

        Returns:
            bool: True if the node can be discarded.
        """
        if pauli_rotation_idx < 0:
            # No rotations left; a negative index must not wrap around to the
            # last rotation.
            rotation_xy = jnp.zeros_like(observable.xy_indices)
        else:
            rotation_xy = self.pauli_rotations[pauli_rotation_idx].xy_indices

        xy_indices_obs = jnp.logical_or(observable.xy_indices, rotation_xy).all()

        return not xy_indices_obs

    def _create_tree_node(
        self,
        observable: FourierTree.PauliOperator,
        pauli_rotation_idx: int,
        parameter_idx: Optional[int] = None,
        is_sine: bool = False,
        is_cosine: bool = False,
    ) -> Optional[FourierTree.CoefficientsTreeNode]:
        """
        Builds the Fourier-Tree according to the algorithm by Nemkov et al.

        Args:
            observable (FourierTree.PauliOperator): Current observable
            pauli_rotation_idx (int): Index of remaining Pauli rotation gates.
                Gates itself are attributes of the class.
            parameter_idx (Optional[int]): Index of the current parameter.
                Parameters itself are attributes of the class.
            is_sine (bool): If the current node is a sine (right) node.
            is_cosine (bool): If the current node is a cosine (left) node.

        Returns:
            Optional[CoefficientsTreeNode]: The resulting node (None if it can
            be discarded). Children are set recursively. The top level
            receives the tree root.
        """
        if self._early_stopping_possible(pauli_rotation_idx, observable):
            return None

        # remove commuting paulis
        while pauli_rotation_idx >= 0:
            last_pauli = self.pauli_rotations[pauli_rotation_idx]
            if not observable.is_commuting(last_pauli.list_repr):
                break
            pauli_rotation_idx -= 1
        else:  # leaf
            return FourierTree.CoefficientsTreeNode(
                parameter_idx, observable, is_sine, is_cosine
            )

        last_pauli = self.pauli_rotations[pauli_rotation_idx]

        # cosine branch keeps the observable
        left = self._create_tree_node(
            observable,
            pauli_rotation_idx - 1,
            pauli_rotation_idx,
            is_cosine=True,
        )

        # sine branch continues with the product of rotation and observable
        next_observable = self._create_new_observable(last_pauli.list_repr, observable)
        right = self._create_tree_node(
            next_observable,
            pauli_rotation_idx - 1,
            pauli_rotation_idx,
            is_sine=True,
        )

        return FourierTree.CoefficientsTreeNode(
            parameter_idx,
            observable,
            is_sine,
            is_cosine,
            left,
            right,
        )

    def _create_new_observable(
        self, pauli: jnp.ndarray[int], observable: FourierTree.PauliOperator
    ) -> FourierTree.PauliOperator:
        """
        Utility function to obtain the new observable for a tree node, if the
        last Pauli and the observable do not commute.

        Args:
            pauli (jnp.ndarray[int]): The int array representation of the last
                Pauli rotation in the operation sequence.
            observable (FourierTree.PauliOperator): The current observable.

        Returns:
            FourierTree.PauliOperator: The new observable.
        """
        return observable.tensor(pauli)

977 

978 

class FCC:
    """
    Namespace for computing the Fourier Coefficient Correlation (FCC) of a
    model, together with its intermediate "Fourier fingerprint" (the
    correlation matrix between Fourier coefficients). All methods are static.
    """

    @staticmethod
    def get_fcc(
        model: Model,
        n_samples: int,
        seed: int,
        method: Optional[str] = "pearson",
        scale: Optional[bool] = False,
        weight: Optional[bool] = False,
        trim_redundant: Optional[bool] = True,
        **kwargs,
    ) -> float:
        """
        Shortcut method to get just the FCC.
        This includes
        1. What is done in `get_fourier_fingerprint`:
            1. Calculating the coefficients (using `n_samples` and `seed`)
            2. Correlating the result from 1) using `method`
            3. Weighting the correlation matrix (if `weight` is True)
            4. Removing redundancies (if `trim_redundant` is True)
        2. What is done in `calculate_fcc`:
            1. Absolute value of the fingerprint
            2. Average, ignoring NaN entries

        Args:
            model (Model): The QFM model
            n_samples (int): Number of samples to calculate average of coefficients
            seed (int): Seed to initialize random parameters
            method (Optional[str], optional): Correlation method. Defaults to "pearson".
            scale (Optional[bool], optional): Whether to scale the number of samples.
                Defaults to False.
            weight (Optional[bool], optional): Whether to weight the correlation matrix.
                Defaults to False.
            trim_redundant (Optional[bool], optional): Whether to remove redundant
                correlations. Defaults to True.
            **kwargs: Additional keyword arguments for the model function.

        Returns:
            float: The FCC
        """
        fourier_fingerprint, _ = FCC.get_fourier_fingerprint(
            model,
            n_samples,
            seed,
            method,
            scale,
            weight,
            trim_redundant=trim_redundant,
            **kwargs,
        )

        return FCC.calculate_fcc(fourier_fingerprint)

    @staticmethod
    def get_fourier_fingerprint(
        model: Model,
        n_samples: int,
        seed: int,
        method: Optional[str] = "pearson",
        scale: Optional[bool] = False,
        weight: Optional[bool] = False,
        trim_redundant: Optional[bool] = True,
        **kwargs,
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """
        Shortcut method to get just the fourier fingerprint.
        This includes
        1. Calculating the coefficients (using `n_samples` and `seed`)
        2. Correlating the result from 1) using `method`
        3. Weighting the correlation matrix (if `weight` is True)
        4. Removing redundancies (if `trim_redundant` is True)

        Args:
            model (Model): The QFM model
            n_samples (int): Number of samples to calculate average of coefficients
            seed (int): Seed to initialize random parameters
            method (Optional[str], optional): Correlation method. Defaults to "pearson".
            scale (Optional[bool], optional): Whether to scale the number of samples.
                Defaults to False.
            weight (Optional[bool], optional): Whether to weight the correlation matrix.
                Defaults to False.
            trim_redundant (Optional[bool], optional): Whether to remove redundant
                correlations. Defaults to True.
            **kwargs: Additional keyword arguments for the model function.

        Returns:
            Tuple[jnp.ndarray, jnp.ndarray]: The fourier fingerprint
            and the frequency indices (returned unmodified)
        """
        _, coeffs, freqs = FCC._calculate_coefficients(
            model, n_samples, seed, scale, **kwargs
        )
        fourier_fingerprint = FCC._correlate(coeffs.transpose(), method=method)

        # perform weighting if requested
        if weight:
            fourier_fingerprint = FCC._weighting(fourier_fingerprint)

        if trim_redundant:
            # The mask is NaN for redundant correlations and 1 elsewhere, so
            # multiplying turns every redundant entry into NaN.
            mask = FCC._calculate_mask(freqs)
            fourier_fingerprint = mask * fourier_fingerprint

            # drop rows/columns that contain no finite value at all
            finite = jnp.isfinite(fourier_fingerprint)
            row_mask = jnp.any(finite, axis=1)
            col_mask = jnp.any(finite, axis=0)
            fourier_fingerprint = fourier_fingerprint[row_mask][:, col_mask]

        return fourier_fingerprint, freqs

    @staticmethod
    def calculate_fcc(
        fourier_fingerprint: jnp.ndarray,
    ) -> float:
        """
        Method to calculate the FCC based on an existing correlation matrix.
        Takes the absolute value and then averages over the matrix, ignoring
        NaN entries (e.g. correlations removed by `trim_redundant`).
        The fingerprint can be obtained via `get_fourier_fingerprint`.

        Args:
            fourier_fingerprint (jnp.ndarray): Correlation matrix of coefficients

        Returns:
            float: The FCC
        """
        return jnp.nanmean(jnp.abs(fourier_fingerprint))

    @staticmethod
    def _calculate_mask(freqs: jnp.ndarray) -> jnp.ndarray:
        """
        Method to calculate a mask filtering out redundant elements
        of the Fourier correlation matrix, based on the provided frequency vector.
        It does so by 'simulating' the operations that would be performed
        by `_correlate`.

        The input is not modified (JAX arrays are immutable, and the caller
        returns `freqs` to the user).

        Args:
            freqs (jnp.ndarray): Array of frequencies (1D, or one row per
                input dimension)

        Returns:
            jnp.ndarray: The mask; NaN for redundant correlations, 1 otherwise
        """
        # TODO: this part can be heavily optimized, by e.g. using a "positive_only"
        # flag when calculating the coefficients.
        # However this would change the numerical values
        # (while the order should be still the same).

        # disregard all the negative frequencies (out-of-place, so the caller's
        # array is left untouched and JAX inputs work as well)
        freqs = jnp.asarray(freqs)
        freqs = jnp.where(freqs < 0, jnp.nan, freqs)

        # compute the outer product of the frequency vectors for arbitrary dimensions
        # or just use the existing frequency vector if it is 1D
        nd_freqs = reduce(jnp.multiply, jnp.ix_(*freqs)) if freqs.ndim > 1 else freqs

        # "simulate" what would happen on correlating the coefficients
        # (jnp.outer flattens its inputs)
        corr_freqs = jnp.outer(nd_freqs, nd_freqs)
        # mask all correlations with a negative frequency component (now NaN)
        corr_mask = jnp.where(jnp.isnan(corr_freqs), jnp.nan, 1.0)
        # from this, disregard all the other redundant correlations
        # (i.e. c_0_1 = c_1_0) including the trivial diagonal
        corr_mask = corr_mask.at[jnp.triu_indices(corr_mask.shape[0], 0)].set(jnp.nan)

        return corr_mask

    @staticmethod
    def _calculate_coefficients(
        model: Model,
        n_samples: int,
        seed: int,
        scale: bool = False,
        **kwargs,
    ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
        """
        Calculates the Fourier coefficients of a given model
        using `n_samples` and `seed`.
        Optionally, `noise_params` can be passed to perform noisy simulation.

        If `n_samples` is not positive, the model's current parameters are
        used as-is (no re-initialization).

        Args:
            model (Model): The QFM model
            n_samples (int): Number of samples to calculate average of coefficients
            seed (int): Seed to initialize random parameters
            scale (bool, optional): Whether to scale the number of samples
                by 2**n_qubits * n_input_feat. Defaults to False.
            **kwargs: Additional keyword arguments for the model function.

        Returns:
            Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: Model parameters,
            coefficients and frequencies
        """
        if n_samples > 0:
            if scale:
                # plain Python ints cannot overflow, unlike jnp integer powers
                total_samples = int(
                    2**model.n_qubits * n_samples * model.n_input_feat
                )
                log.info("Using %d samples.", total_samples)
            else:
                total_samples = n_samples
            model.initialize_params(random.key(seed), repeat=total_samples)

        coeffs, freqs = Coefficients.get_spectrum(
            model, shift=True, trim=True, **kwargs
        )

        return model.params, coeffs, freqs

    @staticmethod
    def _correlate(mat: jnp.ndarray, method: str = "pearson") -> jnp.ndarray:
        """
        Correlates the columns of `mat` pairwise using `method`.
        Currently, `pearson` and `spearman` are supported.

        Args:
            mat (jnp.ndarray): Array with at least 2 dimensions; all axes after
                the first are flattened into columns
            method (str, optional): Correlation method. Defaults to "pearson".

        Raises:
            ValueError: If the method is not supported.

        Returns:
            jnp.ndarray: Correlation matrix between the flattened columns.
        """
        assert len(mat.shape) >= 2, "Input matrix must have at least 2 dimensions"

        if method == "pearson":
            correlator = FCC._pearson
        elif method == "spearman":
            correlator = FCC._spearman
        else:
            raise ValueError(
                f"Unknown correlation method: {method}. "
                "Must be 'pearson' or 'spearman'."
            )

        # Keep the first axis (rows = observations, see `_pearson`) and
        # flatten the remaining axes (C order) into one column per coefficient.
        return correlator(mat.reshape(mat.shape[0], -1))

    @staticmethod
    def _pearson(
        mat: jnp.ndarray, cov: Optional[bool] = False, minp: Optional[int] = 1
    ) -> jnp.ndarray:
        """
        Based on Pandas correlation method as implemented here:
        https://github.com/pandas-dev/pandas/blob/main/pandas/_libs/algos.pyx

        Compute Pearson correlation between columns of `mat`,
        permitting missing values (NaN or ±Inf).

        Args:
            mat : array_like, shape (N, K)
                Input data.
            cov : bool, optional
                If True, return the sample covariance (denominator n - 1)
                instead of the correlation.
            minp : int, optional
                Minimum number of paired observations required to form a correlation.
                If the number of valid pairs for (i, j) is < minp, the result is NaN.

        Returns:
            corr : ndarray, shape (K, K)
                Pearson correlation (or covariance) matrix.
        """
        # NOTE: float64 is only honored when JAX x64 mode is enabled.
        mat = jnp.asarray(mat, dtype=jnp.float64)
        _, n_cols = mat.shape

        # pre-compute finite-mask
        finite = jnp.isfinite(mat)

        # output (NumPy, so it can be filled element-wise)
        result = np.empty((n_cols, n_cols), dtype=jnp.float64)

        # TODO: optimize in future iterations
        # loop over the lower triangle (including diagonal) and mirror it
        for i in range(n_cols):
            for j in range(i + 1):
                # rows where both columns are finite
                both = finite[:, i] & finite[:, j]
                n = jnp.count_nonzero(both)
                if n < minp:
                    # too few pairs
                    value = jnp.nan
                else:
                    dx = mat[both, i]
                    dy = mat[both, j]
                    dx = dx - dx.mean()
                    dy = dy - dy.mean()

                    cxy = jnp.dot(dx, dy)
                    if cov:
                        # sample covariance (denominator n-1)
                        value = cxy / (n - 1) if n > 1 else jnp.nan
                    else:
                        # Pearson r = cov / (sigma_x sigma_y)
                        denom = jnp.sqrt(jnp.dot(dx, dx) * jnp.dot(dy, dy))
                        if denom == 0.0:
                            value = jnp.nan
                        else:
                            value = cxy / denom
                            # clip numerical drift
                            if value > 1.0:
                                value = 1.0
                            elif value < -1.0:
                                value = -1.0

                result[i, j] = result[j, i] = value

        return result

    @staticmethod
    def _spearman(mat: jnp.ndarray, minp: Optional[int] = 1) -> jnp.ndarray:
        """
        Based on Pandas correlation method as implemented here:
        https://github.com/pandas-dev/pandas/blob/main/pandas/_libs/algos.pyx

        Compute Spearman correlation between columns of `mat`,
        permitting missing values (NaN or ±Inf).

        Args:
            mat : array_like, shape (N, K)
                Input data.
            minp : int, optional
                Minimum number of paired observations required to form a correlation.
                If the number of valid pairs for (i, j) is < minp, the result is NaN.

        Returns:
            corr : ndarray, shape (K, K)
                Spearman correlation matrix.
        """
        N, K = mat.shape
        # trivial all-NaN answer if too few rows
        if N < minp:
            return jnp.full((K, K), jnp.nan, dtype=float)

        # mask of finite entries
        mask = jnp.isfinite(mat)  # shape (N, K), dtype=bool

        # precompute ranks column-wise ignoring NaNs
        ranks = np.full((N, K), jnp.nan, dtype=float)
        for j in range(K):
            valid = mask[:, j]
            if valid.any():
                # rankdata by default gives average ranks for ties
                ranks[valid, j] = rankdata(mat[valid, j], method="average")

        # allocate result
        result = np.empty((K, K), dtype=float)

        # TODO: optimize in future iterations
        # loop lower triangle (including diagonal) and mirror it
        for i in range(K):
            for j in range(i + 1):
                # find rows where both columns are finite
                valid = mask[:, i] & mask[:, j]
                nobs = valid.sum()

                if nobs < minp:
                    rho = jnp.nan
                else:
                    xi = ranks[valid, i]
                    yj = ranks[valid, j]
                    # subtract means
                    xi = xi - xi.mean()
                    yj = yj - yj.mean()
                    num = jnp.dot(xi, yj)
                    den = jnp.sqrt(jnp.dot(xi, xi) * jnp.dot(yj, yj))
                    rho = num / den if den > 0 else jnp.nan

                result[i, j] = rho
                result[j, i] = rho

        return result

    @staticmethod
    def _weighting(fourier_fingerprint: jnp.ndarray) -> jnp.ndarray:
        """
        Performs weighting on the given correlation matrix.

        The weight of entry (i, j) is the sum of both indices' closeness to the
        centre index, normalized so that the centre entry gets weight 1 and the
        corners get weight 0. With a shifted spectrum (as produced by
        `_calculate_coefficients`), the centre presumably corresponds to the
        lowest frequencies, which are therefore weighted more heavily.

        Args:
            fourier_fingerprint (jnp.ndarray): Square correlation matrix with
                odd dimensions

        Returns:
            jnp.ndarray: The element-wise weighted correlation matrix
        """
        # TODO: in Future iterations, this can be optimized by computing
        # on the trimmed matrix instead.

        assert (
            fourier_fingerprint.shape[0] % 2 != 0
            and fourier_fingerprint.shape[1] % 2 != 0
        ), (
            "Correlation matrix must have odd dimensions. "
            "Hint: use `trim` argument when calling `get_spectrum`."
        )
        assert (
            fourier_fingerprint.shape[0] == fourier_fingerprint.shape[1]
        ), "Correlation matrix must be square."

        size = fourier_fingerprint.shape[0]
        center = size // 2
        # 0 at the borders, `center` at the middle index
        closeness = center - jnp.abs(jnp.arange(size) - center)
        weights_matrix = (closeness[:, None] + closeness[None, :]) / (2 * center)

        return fourier_fingerprint * weights_matrix

1421 

1422 

class Datasets:
    """Helpers for generating synthetic Fourier-series datasets for a model."""

    @staticmethod
    def generate_fourier_series(
        random_key: random.PRNGKey,
        model: Model,
        coefficients_min: float = 0.0,
        coefficients_max: float = 1.0,
        zero_centered: bool = False,
    ) -> List[jnp.ndarray]:
        """
        Generates the Fourier series representation of a function.
        It uses the `model.frequencies` property to retrieve the frequency
        information. This ensures that the resulting Fourier series is
        compatible with the model.

        This function is capable of generating $D$-dimensional Fourier series
        (again defined by `model.n_input_feat`).
        The number of samples per dimension is taken from `model.degree`.

        Samples of the Fourier coefficients are drawn uniformly from an annulus
        in the complex plane (see `uniform_circle`), and mirrored so that the
        spectrum is conjugate-symmetric (i.e. the series is real-valued).

        Args:
            random_key (random.PRNGKey): Random number key for JAX.
            model (Model): The quantum circuit model.
            coefficients_min (float, optional): Minimum magnitude of the
                coefficients. Defaults to 0.0.
            coefficients_max (float, optional): Maximum magnitude of the
                coefficients. Defaults to 1.0.
            zero_centered (bool, optional): Whether to zero-center the coefficients.
                Defaults to False.

        Returns:
            List[jnp.ndarray]: Three arrays:
                - input domain samples with shape (*model.degree, D)
                - Fourier series values with shape model.degree
                - Fourier coefficients with shape model.degree
        """
        # TODO: the following code samples a fully random spectrum. Consider
        # adding some constraints on the spectrum to capture more realistic
        # (not fully random) functions.

        # Note: one key observation for understanding the following code is,
        # that instead of wrapping your head around symmetries in multi-
        # dimensional coefficient matrices, one can simply look at the flattened
        # version of such a matrix and reshape later. It just works out.

        # Equidistant grid on [0, 2*pi) per input dimension, permuted into an
        # n-d "coordinate system". `linspace(..., endpoint=False)` always yields
        # exactly `d` points, whereas a float `arange` may return `d + 1`
        # points due to rounding (which would break the final reshape).
        grids = [
            jnp.linspace(0, 2 * jnp.pi, d, endpoint=False) for d in model.degree
        ]
        domain_samples_per_input_dim = jnp.stack(jnp.meshgrid(*grids)).T.reshape(
            -1, model.n_input_feat
        )

        # generate the frequency indices for each dimension.
        # this will have the same shape as the domain samples
        frequencies = jnp.stack(jnp.meshgrid(*model.frequencies)).T.reshape(
            -1, model.n_input_feat
        )

        # sample the non-negative half of the (flattened) spectrum
        coefficients = Datasets.uniform_circle(
            random_key,
            low=coefficients_min,
            high=coefficients_max,
            size=math.prod(model.degree) // 2 + 1,
        )

        # The first sampled coefficient is the offset (zero frequency); it must
        # be real for the series to be real, or zero if zero-centering.
        if zero_centered:
            coefficients = coefficients.at[0].set(0.0)
        else:
            coefficients = coefficients.at[0].set(coefficients[0].real)

        # ensure conjugate symmetry, giving us the full coefficients vector
        coefficients = jnp.concat(
            [
                jnp.flip(coefficients[..., 1:]).conjugate(),
                coefficients,
            ],
            axis=-1,
        )

        # Vectorized version of $f(x) = \sum_{n=0}^{N-1} c_n * e^{i * \omega_n * x}$
        # it takes into account the input dimension, i.e. the output is a matrix
        # normalization uses the n_freqs component of the coefficients
        values = jnp.real(
            (
                jnp.exp(1j * (domain_samples_per_input_dim @ frequencies.T))
                * coefficients
            ).sum(axis=1)
            / coefficients.size
        )

        # return all the information we have
        return [
            domain_samples_per_input_dim.reshape(*model.degree, -1),
            values.reshape(model.degree),
            coefficients.reshape(model.degree),
        ]

    @staticmethod
    def uniform_circle(
        random_key: random.PRNGKey,
        size: Union[jnp.ndarray, List, int],
        low=0.0,
        high=1.0,
    ):
        """
        Random number generator for complex numbers sampled uniformly (by area)
        from the annulus `low <= |z| <= high`.

        Args:
            random_key (random.PRNGKey): Random number key for JAX.
            size (Union[jnp.ndarray, List, int]): Number of samples, or the
                shape of the output.
            low (float, optional): Minimum Radius. Defaults to 0.0.
            high (float, optional): Maximum Radius. Defaults to 1.0.

        Returns:
            jnp.ndarray: Array of complex numbers with shape of `size`
        """
        if isinstance(size, int):
            size = (size,)

        radius_key, phase_key = random.split(random_key)
        # Sampling r**2 uniformly in [low**2, high**2] yields points uniformly
        # distributed over the annulus area with radius in [low, high].
        # (Sampling r**2 in [low, high] would give radii in
        # [sqrt(low), sqrt(high)], contradicting the documented bounds.)
        radius = jnp.sqrt(
            random.uniform(radius_key, size, minval=low**2, maxval=high**2)
        )
        phase = jnp.exp(2j * jnp.pi * random.uniform(phase_key, size))
        return radius * phase
1556 ) * jnp.exp(2j * jnp.pi * random.uniform(random_key1, size))