Coverage for qml_essentials / coefficients.py: 95%

520 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-06-10 08:17 +0000

1from __future__ import annotations 

2import math 

3from collections import defaultdict 

4from dataclasses import dataclass 

5import jax.numpy as jnp 

6from jax import random 

7import numpy as np 

8from scipy.stats import rankdata 

9from functools import reduce 

10from typing import List, Tuple, Optional, Any, Dict, Union 

11 

12from qml_essentials.model import Model 

13from qml_essentials.utils import PauliCircuit 

14from qml_essentials.operations import ( 

15 Operation, 

16 PauliX, 

17 PauliY, 

18 PauliZ, 

19) 

20 

21import logging 

22 

# Module-level logger, named after this module so handlers/levels can be
# configured per module via the standard logging hierarchy.
log: logging.Logger = logging.getLogger(__name__)

24 

25 

class Coefficients:
    """
    Numerical Fourier-spectrum utilities: sample a model on a regular grid,
    FFT the outputs, and evaluate truncated Fourier series.
    """

    @classmethod
    def get_spectrum(
        cls,
        model: Model,
        mfs: int = 1,
        mts: int = 1,
        shift=False,
        trim=False,
        numerical_cap: Optional[float] = -1,
        **kwargs,
    ) -> Tuple[jnp.ndarray, Union[jnp.ndarray, List[jnp.ndarray]]]:
        """
        Extracts the coefficients of a given model using a FFT (jnp-fft).

        Note that the coefficients are complex numbers, but the imaginary part
        of the coefficients should be very close to zero, since the expectation
        values of the Pauli operators are real numbers.

        It can perform oversampling in both the frequency and time domain
        using the `mfs` and `mts` arguments.

        Args:
            model (Model): The model to sample.
            mfs (int): Multiplicator for the highest frequency. Default is 1.
            mts (int): Multiplicator for the number of time samples. Default is 1.
            shift (bool): Whether to apply jnp-fftshift. Default is False.
            trim (bool): Whether to remove the Nyquist frequency along every
                input axis whose length is even. Default is False.
            numerical_cap (Optional[float]): Numerical cap for the coefficients.
                If positive, coefficients with magnitude below the cap are
                zeroed and, for a single input feature, frequencies that
                vanish entirely are removed from both `coeffs` and `freqs`.
                ``None`` or a non-positive value disables capping.
            kwargs (Any): Additional keyword arguments for the model function.

        Returns:
            Tuple[jnp.ndarray, Union[jnp.ndarray, List[jnp.ndarray]]]: The
                coefficients and the frequencies. Frequencies are a single
                array for one input feature, otherwise one array per feature.

        Raises:
            ValueError: If the summed coefficients have a non-negligible
                imaginary part.
        """
        kwargs.setdefault("force_mean", True)
        kwargs.setdefault("execution_type", "expval")

        coeffs, freqs = cls._fourier_transform(model, mfs=mfs, mts=mts, **kwargs)

        imag_sum = jnp.sum(coeffs).imag
        if not jnp.isclose(imag_sum, 0.0, rtol=1.0e-5):
            raise ValueError(
                f"Spectrum is not real. Imaginary part of coefficients is: {imag_sum}"
            )

        # One frequency vector per input axis; work on a list copy so each
        # axis' vector can be replaced independently below.
        freqs = list(freqs)

        if trim:
            for ax in range(model.n_input_feat):
                n = coeffs.shape[ax]
                if n % 2 == 0:
                    # In unshifted FFT ordering the Nyquist bin of an
                    # even-length axis sits at index n // 2. Only this axis'
                    # coefficients and frequency vector are affected.
                    coeffs = np.delete(coeffs, n // 2, axis=ax)
                    freqs[ax] = np.delete(freqs[ax], n // 2)

        if shift:
            coeffs = jnp.fft.fftshift(coeffs, axes=list(range(model.n_input_feat)))
            # Shift each axis' frequency vector on its own (they may differ
            # in length, and must not be permuted among each other).
            freqs = [np.fft.fftshift(f) for f in freqs]

        if numerical_cap is not None and numerical_cap > 0:
            # set coeffs below threshold to zero
            coeffs = jnp.where(
                jnp.abs(coeffs) < numerical_cap,
                jnp.zeros_like(coeffs),
                coeffs,
            )

            # Drop frequencies whose coefficients vanish entirely after
            # capping, so the returned spectrum reflects only the surviving
            # frequencies. Well-defined only for a single (1-D) frequency
            # axis; for multi-dim input the rectangular grid is left intact.
            if model.n_input_feat == 1:
                if coeffs.ndim == 1:
                    surviving = coeffs != 0
                else:
                    surviving = jnp.any(
                        coeffs != 0, axis=tuple(range(1, coeffs.ndim))
                    )
                # A host-side boolean mask indexes both jax and numpy arrays.
                surviving = np.asarray(surviving)
                coeffs = coeffs[surviving]
                freqs = [freqs[0][surviving]]

        if len(freqs) == 1:
            freqs = freqs[0]

        return coeffs, freqs

    @classmethod
    def _fourier_transform(
        cls, model: Model, mfs: int, mts: int, **kwargs: Any
    ) -> Tuple[jnp.ndarray, List[jnp.ndarray]]:
        """
        Samples the model on a regular grid over [0, 2*pi*mts) per input
        feature and returns the normalized n-dimensional FFT of the outputs.

        Args:
            model (Model): The model to sample.
            mfs (int): Multiplicator for the highest frequency.
            mts (int): Multiplicator for the number of time samples.
            kwargs (Any): Forwarded to the model call.

        Returns:
            Tuple[jnp.ndarray, List[jnp.ndarray]]: The FFT coefficients
                (normalized by the number of grid points) and one
                frequency vector per input feature.
        """
        n_feat = model.n_input_feat

        # Create a frequency vector with as many frequencies as model degrees,
        # oversampled by mfs
        n_freqs: jnp.ndarray = jnp.array([mfs * model.degree[i] for i in range(n_feat)])

        stop = 2 * mts * jnp.pi
        step = 2 * jnp.pi / n_freqs
        # Stretch according to the number of frequencies
        inputs: List = [jnp.arange(0, stop, step[i]) for i in range(n_feat)]

        # Cartesian product of all per-feature grids, one row per sample.
        nd_inputs = jnp.array(jnp.meshgrid(*inputs)).T.reshape(-1, n_feat)

        # Output vector is not necessarily the same length as input
        outputs = model(inputs=nd_inputs, **kwargs)
        grid_shape = tuple(inp.shape[0] for inp in inputs)
        outputs = outputs.reshape(*grid_shape, -1)
        # Only drop the trailing output axis when it is trivial; squeezing
        # everything would also remove length-1 grid axes and misalign the
        # FFT axes and the normalization below.
        if outputs.shape[-1] == 1:
            outputs = outputs[..., 0]

        coeffs = jnp.fft.fftn(outputs, axes=list(range(n_feat)))

        freqs = [
            jnp.fft.fftfreq(int(mts * n_freqs[i]), 1 / n_freqs[i])
            for i in range(n_feat)
        ]

        # TODO: this could cause issues with multidim input
        # FIXME: account for different frequencies in multidim input scenarios
        # Normalize by the total number of grid points (product if multidim).
        return coeffs / math.prod(grid_shape), freqs

    @classmethod
    def get_psd(cls, coeffs: jnp.ndarray) -> jnp.ndarray:
        """
        Calculates the power spectral density (PSD) from given Fourier coefficients.

        Args:
            coeffs (jnp.ndarray): The Fourier coefficients.

        Returns:
            jnp.ndarray: The power spectral density, i.e.
                ``2 / len(coeffs)**2 * |coeffs|**2`` element-wise.
        """
        # TODO: if we apply trim=True in advance, this will be slightly wrong..

        def abs2(x):
            # Squared magnitude without taking a square root.
            return x.real**2 + x.imag**2

        scale = 2.0 / (len(coeffs) ** 2)
        return scale * abs2(coeffs)

    @classmethod
    def evaluate_Fourier_series(
        cls,
        coefficients: jnp.ndarray,
        frequencies: jnp.ndarray,
        inputs: Union[jnp.ndarray, list, float],
    ) -> float:
        """
        Evaluate the function value of a Fourier series at one point.

        Args:
            coefficients (jnp.ndarray): Coefficients of the Fourier series.
            frequencies (jnp.ndarray): Corresponding frequencies (an array,
                or a list with one frequency vector per input dimension).
            inputs (Union[jnp.ndarray, list, float]): Point at which to
                evaluate the function. Scalars are promoted to a length-1
                array.

        Returns:
            float: The real part of the series value at the input point
                (squeezed).
        """
        if isinstance(frequencies, list):
            if len(coefficients.shape) <= len(frequencies):
                coefficients = coefficients[..., jnp.newaxis]
        else:
            if len(coefficients.shape) == 1:
                coefficients = coefficients[..., jnp.newaxis]

        # Accept lists, scalars, numpy and jax arrays alike; a plain float
        # has no `.shape`, so convert before inspecting it.
        inputs = jnp.asarray(inputs)
        if len(inputs.shape) < 1:
            inputs = inputs[jnp.newaxis, ...]

        if isinstance(frequencies, list):
            input_dim = len(frequencies)
            frequencies = jnp.stack(jnp.meshgrid(*frequencies))
            if input_dim != len(inputs):
                frequencies = jnp.repeat(
                    frequencies[jnp.newaxis, ...], inputs.shape[0], axis=0
                )
                freq_inputs = jnp.einsum("bi...,b->b...", frequencies, inputs)
                exponents = jnp.exp(1j * freq_inputs).T
                exp = jnp.einsum("jl...k,jl...b->b...k", coefficients, exponents)
            else:
                freq_inputs = jnp.einsum("i...,i->...", frequencies, inputs)
                exponents = jnp.exp(1j * freq_inputs).T
                exp = jnp.einsum("jl...k,jl...->k...", coefficients, exponents)
        else:
            frequencies = jnp.repeat(
                frequencies[jnp.newaxis, ...], inputs.shape[0], axis=0
            )
            freq_inputs = jnp.einsum("i...,i->i...", frequencies, inputs)
            exponents = jnp.exp(1j * freq_inputs)
            exp = jnp.einsum("j...k,ij...->ik...", coefficients, exponents)

        return jnp.squeeze(jnp.real(exp))

228 

229 

class FourierTree:
    """
    Sine-cosine tree representation for the algorithm by Nemkov et al.
    This tree can be used to obtain analytical Fourier coefficients for a given
    Pauli-Clifford circuit.
    """

    class CoefficientsTreeNode:
        """
        Representation of a node in the coefficients tree for the algorithm by
        Nemkov et al.
        """

        def __init__(
            self,
            parameter_idx: Optional[int],
            observable: FourierTree.PauliOperator,
            is_sine_factor: bool,
            is_cosine_factor: bool,
            left: Optional[FourierTree.CoefficientsTreeNode] = None,
            right: Optional[FourierTree.CoefficientsTreeNode] = None,
        ):
            """
            Coefficient tree node initialisation. Each node has information about
            its creation context and its children, i.e.:

            Args:
                parameter_idx (Optional[int]): Index of the corresp. param. index i.
                observable (FourierTree.PauliOperator): The nodes observable to
                    obtain the expectation value that contributes to the constant
                    term.
                is_sine_factor (bool): If this node belongs to a sine coefficient.
                is_cosine_factor (bool): If this node belongs to a cosine coefficient.
                left (Optional[CoefficientsTreeNode]): left child (if any).
                right (Optional[CoefficientsTreeNode]): right child (if any).
            """
            self.parameter_idx = parameter_idx

            assert not (is_sine_factor and is_cosine_factor), (
                "Cannot be sine and cosine at the same time"
            )
            self.is_sine_factor = is_sine_factor
            self.is_cosine_factor = is_cosine_factor

            # If the observable does not consist of only Z and I (encoded as
            # 2 and -1), the expectation (and therefore the constant node
            # term) is zero.
            if jnp.logical_or(
                observable.list_repr == 0, observable.list_repr == 1
            ).any():
                self.term = 0.0
            else:
                self.term = observable.phase

            self.left = left
            self.right = right

        def evaluate(self, parameters: List[float]) -> complex:
            """
            Recursive function to evaluate the expectation of the coefficient tree,
            starting from the current node.

            Args:
                parameters (List[float]): The parameters, by which the circuit
                    (and therefore the tree) is parametrised.

            Returns:
                complex: The expectation for the current node and its children
                    (callers take the real part).
            """
            factor = (
                parameters[self.parameter_idx]
                if self.parameter_idx is not None
                else 1.0
            )
            if self.is_sine_factor:
                factor = 1j * jnp.sin(factor)
            elif self.is_cosine_factor:
                factor = jnp.cos(factor)

            if self.left is None and self.right is None:  # leaf
                return factor * self.term

            sum_children = 0.0
            for child in (self.left, self.right):
                if child is not None:
                    sum_children = sum_children + child.evaluate(parameters)

            return factor * sum_children

        def get_leafs(
            self,
            sin_list: jnp.ndarray,
            cos_list: jnp.ndarray,
            existing_leafs: Optional[List[FourierTree.TreeLeaf]] = None,
        ) -> List[FourierTree.TreeLeaf]:
            """
            Traverse the tree starting from the current node, to obtain the tree
            leafs only.
            The leafs correspond to the terms in the sine-cosine tree
            representation that eventually are used to obtain coefficients and
            frequencies.
            Sine and cosine counters are passed down to the children until a
            leaf is reached (top to bottom); leafs are collected bottom to top.

            Args:
                sin_list (jnp.ndarray): Current number of sine contributions
                    for each parameter (integer array, one entry per
                    parameter).
                cos_list (jnp.ndarray): Current number of cosine contributions
                    for each parameter (integer array, one entry per
                    parameter).
                existing_leafs (Optional[List[TreeLeaf]]): Leafs to prepend to
                    the result. The list is not modified. Defaults to None.

            Returns:
                List[TreeLeaf]: A new list with ``existing_leafs`` followed by
                    the non-zero leafs below (and including) this node.
            """
            # Never mutate the caller's list (and never share a default list
            # between calls).
            leafs = [] if existing_leafs is None else list(existing_leafs)

            # jax arrays are immutable: `.at[...]` returns updated copies, so
            # siblings never see each other's counts.
            if self.is_sine_factor:
                sin_list = sin_list.at[self.parameter_idx].add(1)
            if self.is_cosine_factor:
                cos_list = cos_list.at[self.parameter_idx].add(1)

            if self.left is None and self.right is None:  # leaf
                # Zero-term leafs do not contribute to any coefficient.
                if self.term != 0.0:
                    leafs.append(FourierTree.TreeLeaf(sin_list, cos_list, self.term))
                return leafs

            for child in (self.left, self.right):
                if child is not None:
                    leafs.extend(child.get_leafs(sin_list, cos_list))
            return leafs

    @dataclass
    class TreeLeaf:
        """
        Coefficient tree leafs according to the algorithm by Nemkov et al., which
        correspond to the terms in the sine-cosine tree representation that
        eventually are used to obtain coefficients and frequencies.

        Args:
            sin_indices (jnp.ndarray): Number of sine contributions for each
                parameter (one entry per parameter).
            cos_indices (jnp.ndarray): Number of cosine contributions for each
                parameter (one entry per parameter).
            term (complex): Constant factor of the leaf, depending on the
                expectation value of the observable, and a phase.
        """

        sin_indices: jnp.ndarray
        cos_indices: jnp.ndarray
        term: complex

    class PauliOperator:
        """
        Utility class for storing Pauli Rotations, the corresponding indices
        in the XY-Space (whether there is a gate with X or Y generator at a
        certain qubit) and the phase.

        Args:
            pauli (Union[Operation, jnp.ndarray[int]]): Pauli Rotation Operation
                or list representation
            n_qubits (int): Number of qubits in the circuit
            prev_xy_indices (Optional[jnp.ndarray[bool]]): X/Y indices of the
                previous Pauli sequence. Defaults to None.
            is_observable (bool): If the operator is an observable. Defaults to
                False.
            is_init (bool): If this Pauli operator is initialised the first
                time. Defaults to True.
            phase (float): Phase of the operator. Defaults to 1.0
        """

        def __init__(
            self,
            pauli: Union[Operation, jnp.ndarray[int]],
            n_qubits: int,
            prev_xy_indices: Optional[jnp.ndarray[bool]] = None,
            is_observable: bool = False,
            is_init: bool = True,
            phase: float = 1.0,
        ):
            self.is_observable = is_observable
            self.phase = phase

            if is_init:
                if not is_observable:
                    # Rotation gates are encoded through their generator.
                    pauli = pauli.generator()
                self.list_repr = self._create_list_representation(pauli, n_qubits)
            else:
                assert isinstance(pauli, jnp.ndarray)
                self.list_repr = pauli

            if prev_xy_indices is None:
                prev_xy_indices = jnp.zeros(n_qubits, dtype=bool)
            # Accumulate X/Y positions along the rotation sequence. For
            # observables the mask is inverted (see _compute_xy_indices).
            self.xy_indices = jnp.logical_or(
                prev_xy_indices,
                self._compute_xy_indices(self.list_repr, rev=is_observable),
            )

        @staticmethod
        def _compute_xy_indices(
            op: jnp.ndarray[int], rev: bool = False
        ) -> jnp.ndarray[bool]:
            """
            Computes the positions of X or Y gates to an one-hot encoded boolean
            array.

            Args:
                op (jnp.ndarray[int]): Pauli-Operation list representation.
                rev (bool): Whether to negate the array.

            Returns:
                jnp.ndarray[bool]: One hot encoded boolean array.
            """
            xy_indices = (op == 0) + (op == 1)
            if rev:
                xy_indices = ~xy_indices
            return xy_indices

        @staticmethod
        def _create_list_representation(
            op: Operation, n_qubits: int
        ) -> jnp.ndarray[int]:
            """
            Create list representation of an Operation.
            Generally, the list representation is a list of length n_qubits,
            where at each position a Pauli Operator is encoded as such:
                I: -1
                X: 0
                Y: 1
                Z: 2

            Args:
                op (Operation): Gate operation (PauliX, PauliY, PauliZ, or
                    Hermitian wrapping a multi-qubit Pauli tensor product).
                n_qubits (int): number of qubits in the circuit

            Returns:
                jnp.ndarray[int]: List representation
            """
            pauli_repr = -jnp.ones(n_qubits, dtype=int)

            name_to_idx = {"PauliX": 0, "PauliY": 1, "PauliZ": 2}
            if op.name in name_to_idx:
                return pauli_repr.at[op.wires[0]].set(name_to_idx[op.name])

            for op_cls, idx in ((PauliX, 0), (PauliY, 1), (PauliZ, 2)):
                if isinstance(op, op_cls):
                    return pauli_repr.at[op.wires[0]].set(idx)

            # Multi-qubit case: decompose via pauli_string_from_operation
            from qml_essentials.operations import pauli_string_from_operation

            pauli_str = pauli_string_from_operation(op)
            char_to_idx = {"X": 0, "Y": 1, "Z": 2, "I": -1}
            for wire, ch in zip(op.wires, pauli_str):
                idx = char_to_idx.get(ch, -1)
                if idx >= 0:
                    pauli_repr = pauli_repr.at[wire].set(idx)

            return pauli_repr

        def is_commuting(self, pauli: jnp.ndarray[int]) -> bool:
            """
            Computes if this Pauli commutes with another Pauli operator.
            The commutator is zero if and only if the number of
            anticommuting single-qubit Paulis is even.

            Args:
                pauli (jnp.ndarray[int]): List representation of another Pauli

            Returns:
                bool: If the current and other Pauli are commuting.
            """
            # Two single-qubit Paulis anticommute iff both are non-identity
            # and they differ.
            anticommuting = (pauli >= 0) & (self.list_repr >= 0) & (
                self.list_repr != pauli
            )
            return not (jnp.sum(anticommuting) % 2)

        def tensor(self, pauli: jnp.ndarray[int]) -> FourierTree.PauliOperator:
            """
            Compute tensor product between the current Pauli and a given list
            representation of another Pauli operator.

            Args:
                pauli (jnp.ndarray[int]): List representation of Pauli

            Returns:
                FourierTree.PauliOperator: New Pauli operator object, which
                    contains the tensor product
            """
            diff = (pauli - self.list_repr + 3) % 3
            phase = jnp.where(
                self.list_repr < 0,
                1.0,
                jnp.where(
                    pauli < 0,
                    1.0,
                    jnp.where(
                        diff == 2,
                        1.0j,
                        jnp.where(diff == 1, -1.0j, 1.0),
                    ),
                ),
            )

            obs = jnp.where(
                self.list_repr < 0,
                pauli,
                jnp.where(
                    pauli < 0,
                    self.list_repr,
                    jnp.where(
                        diff == 2,
                        (self.list_repr + 1) % 3,
                        jnp.where(diff == 1, (self.list_repr + 2) % 3, -1),
                    ),
                ),
            )
            phase = self.phase * jnp.prod(phase)
            return FourierTree.PauliOperator(
                obs, phase=phase, n_qubits=obs.size, is_init=False, is_observable=True
            )

    def __init__(self, model: Model):
        """
        Tree initialisation, based on the Pauli-Clifford representation of a model.
        Currently, only one input feature is supported.

        **Usage**:
        ```
        # initialise a model
        model = Model(...)

        # initialise and build FourierTree
        tree = FourierTree(model)

        # get expectation value
        exp = tree()

        # Get spectrum (for each observable, we have one list element)
        coeff_list, freq_list = tree.get_spectrum()
        ```

        Args:
            model (Model): The Model, for which to build the tree
        """
        self.model = model
        self.tree_roots = None

        inputs = self.model._inputs_validation([1.0])
        quantum_tape = self._record_pauli_circuit(model.params, inputs)

        self.parameters = [jnp.squeeze(p) for p in quantum_tape.get_parameters()]

        self.input_indices, self.all_input_indices = quantum_tape.get_input_indices()

        self.observables = self._encode_observables(quantum_tape.observables)

        # Each rotation accumulates the X/Y positions of all rotations
        # before it (the first one starts from an all-False mask).
        self.pauli_rotations = []
        prev_xy_indices = None
        for op in quantum_tape.operations:
            pauli_rot = FourierTree.PauliOperator(
                op, self.model.n_qubits, prev_xy_indices
            )
            self.pauli_rotations.append(pauli_rot)
            prev_xy_indices = pauli_rot.xy_indices

        self.tree_roots = self.build()
        self.leafs: List[List[FourierTree.TreeLeaf]] = self._get_tree_leafs()

    def _record_pauli_circuit(self, params: Any, inputs: Any) -> PauliCircuit:
        """
        Records the model's circuit for the given parameters and inputs and
        converts it into its Pauli-Clifford representation.

        Args:
            params (Any): Model parameters used for recording.
            inputs (Any): Validated model inputs used for recording.

        Returns:
            PauliCircuit: The Pauli-Clifford tape incl. the model observables.
        """
        # Record the circuit tape using jaqsi's tape recording
        raw_tape = self.model.script._record(params=params, inputs=inputs)

        # Build observables from the model's output_qubit configuration
        _, obs_list = self.model._build_obs()

        return PauliCircuit.from_parameterised_circuit(raw_tape, observables=obs_list)

    def __call__(
        self,
        params: Optional[jnp.ndarray] = None,
        inputs: Optional[jnp.ndarray] = None,
        **kwargs,
    ) -> jnp.ndarray:
        """
        Evaluates the Fourier tree via sine-cosine terms sum. This is
        equivalent to computing the expectation value of the observables with
        respect to the corresponding circuit.

        Note: ``self.parameters`` is updated to the parameters of this call.

        Args:
            params (Optional[jnp.ndarray], optional): Parameters of the model.
                Defaults to None (the model's current parameters).
            inputs (Optional[jnp.ndarray], optional): Inputs to the circuit.
                Defaults to None.

        Returns:
            jnp.ndarray: Expectation value of the tree (one per observable,
                or their mean if ``force_mean`` is set).

        Raises:
            NotImplementedError: When using other "execution_type" as expval.
            NotImplementedError: When using "noise_params"
        """
        params = (
            self.model._params_validation(params)
            if params is not None
            else self.model.params
        )
        inputs = (
            self.model._inputs_validation(inputs)
            if inputs is not None
            else self.model._inputs_validation(1.0)
        )

        execution_type = kwargs.get("execution_type", "expval")
        if execution_type != "expval":
            raise NotImplementedError(
                f'Currently, only "expval" execution type is supported when '
                f"building FourierTree. Got {execution_type}."
            )
        if kwargs.get("noise_params", None) is not None:
            raise NotImplementedError(
                "Currently, noise is not supported when building FourierTree."
            )

        # Record with the *requested* parameters (previously the validated
        # `params` were ignored in favour of the model's own parameters).
        quantum_tape = self._record_pauli_circuit(params, inputs)

        self.parameters = [jnp.squeeze(p) for p in quantum_tape.get_parameters()]

        results = jnp.zeros(len(self.tree_roots))
        for i, root in enumerate(self.tree_roots):
            # A None root was pruned entirely: its expectation is zero.
            if root is not None:
                results = results.at[i].set(jnp.real(root.evaluate(self.parameters)))

        if kwargs.get("force_mean", False):
            return jnp.mean(results)
        return results

    def build(self) -> List[Optional[CoefficientsTreeNode]]:
        """
        Creates the coefficient tree, i.e. it creates and initialises the tree
        nodes.
        Leafs can be obtained separately in _get_tree_leafs, once the tree is
        set up.

        Returns:
            List[Optional[CoefficientsTreeNode]]: The list of root nodes (one
                root for each observable). A root is None if the whole tree
                for that observable was pruned by early stopping.
        """
        pauli_rotation_idx = len(self.pauli_rotations) - 1
        return [
            self._create_tree_node(obs, pauli_rotation_idx) for obs in self.observables
        ]

    def _encode_observables(
        self, tape_obs: List[Operation]
    ) -> List[FourierTree.PauliOperator]:
        """
        Encodes observables from tape as FourierTree.PauliOperator
        utility objects.

        Args:
            tape_obs (List[Operation]): Observable operations

        Returns:
            List[FourierTree.PauliOperator]: List of Pauli operators
        """
        return [
            FourierTree.PauliOperator(obs, self.model.n_qubits, is_observable=True)
            for obs in tape_obs
        ]

    def _get_tree_leafs(self) -> List[List[TreeLeaf]]:
        """
        Obtain all Leaf Nodes with its sine- and cosine terms.

        Returns:
            List[List[TreeLeaf]]: For each observable (root), the list of leaf
                nodes (empty for a pruned root).
        """
        leafs = []
        for root in self.tree_roots:
            if root is None:
                leafs.append([])
                continue
            sin_list = jnp.zeros(len(self.parameters), dtype=jnp.int32)
            cos_list = jnp.zeros(len(self.parameters), dtype=jnp.int32)
            leafs.append(root.get_leafs(sin_list, cos_list, []))
        return leafs

    def get_spectrum(
        self, force_mean: bool = False
    ) -> Tuple[List[jnp.ndarray], List[jnp.ndarray]]:
        """
        Computes the Fourier spectrum for the tree, consisting of the
        frequencies and their corresponding coefficients.

        Args:
            force_mean (bool, optional): Whether to average the coefficients
                over all observables (roots). Defaults to False.

        Returns:
            Tuple[List[jnp.ndarray], List[jnp.ndarray]]:
                - List of coefficient arrays, one per observable (root), or a
                  single entry if ``force_mean`` is set.
                - List of the corresponding frequency arrays, same layout.
        """
        parameter_indices = [
            i for i in range(len(self.parameters)) if i not in self.all_input_indices
        ]

        coeffs = []
        for leafs in self.leafs:
            freq_terms = defaultdict(np.complex128)
            for input_idx in self.input_indices:
                for leaf in leafs:
                    leaf_factor, s, c = self._compute_leaf_factors(
                        leaf, parameter_indices, input_idx
                    )

                    # Binomial expansion of (i*sin x)^s * cos^c x into
                    # exponentials e^{i k x}; the 2^-(s+c) normalization is
                    # already contained in leaf_factor.
                    for a in range(s + 1):
                        for b in range(c + 1):
                            comb = math.comb(s, a) * math.comb(c, b) * (-1) ** (s - a)
                            freq_terms[2 * a + 2 * b - s - c] += comb * leaf_factor

            coeffs.append(freq_terms)

        frequencies, coefficients = self._freq_terms_to_coeffs(coeffs, force_mean)
        return coefficients, frequencies

    def _freq_terms_to_coeffs(
        self, coeffs: List[Dict[int, jnp.ndarray]], force_mean: bool
    ) -> Tuple[List[jnp.ndarray], List[jnp.ndarray]]:
        """
        Given a list of dictionaries of the form:
        [
            {
                freq_obs1_1: coeff1,
                freq_obs1_2: coeff2,
                ...
            },
            {
                freq_obs2_1: coeff3,
                freq_obs2_2: coeff4,
                ...
            }
            ...
        ],
        compute two separate lists of frequencies and coefficients
        (frequencies sorted ascending), such that:
        freqs: [
            [freq_obs1_1, freq_obs1_2, ...],
            [freq_obs2_1, freq_obs2_2, ...],
            ...
        ]
        coeffs: [
            [coeff1, coeff2, ...],
            [coeff3, coeff4, ...],
            ...
        ]

        If force_mean is set, the length of the resulting frequency and
        coefficient lists is 1; frequencies missing for an observable count
        as a zero coefficient in the mean.

        Args:
            coeffs (List[Dict[int, jnp.ndarray]]): Frequency->Coefficients
                dictionary list, one dict for each observable (root).
            force_mean (bool): Whether to average coefficients over multiple
                observables.

        Returns:
            Tuple[List[jnp.ndarray], List[jnp.ndarray]]:
                - List of frequencies, one array for each observable (root).
                - List of corresponding coefficients, one array for each
                  observable (root).
        """
        frequencies = []
        coefficients = []
        if force_mean:
            all_freqs = sorted({f for c in coeffs for f in c.keys()})
            coefficients.append(
                jnp.array(
                    [
                        jnp.mean(jnp.array([c.get(f, 0.0) for c in coeffs]))
                        for f in all_freqs
                    ]
                )
            )
            frequencies.append(jnp.array(all_freqs))
        else:
            for freq_terms in coeffs:
                ordered = sorted(freq_terms.items())
                frequencies.append(jnp.array([f for f, _ in ordered]))
                coefficients.append(jnp.array([v for _, v in ordered]))
        return frequencies, coefficients

    def _compute_leaf_factors(
        self,
        leaf: TreeLeaf,
        parameter_indices: List[int],
        input_idx: int,
    ) -> Tuple[float, int, int]:
        """
        Computes the constant coefficient factor for each leaf.
        Additionally sine and cosine contributions of the input parameters for
        this leaf are returned, which are required to obtain the corresponding
        frequencies.

        Args:
            leaf (TreeLeaf): The leaf for which to compute the factor.
            parameter_indices (List[int]): Variational parameter indices.
            input_idx (int): Key into ``self.input_indices`` selecting the
                parameter indices that encode the input.

        Returns:
            Tuple[float, int, int]:
                - float: the constant factor for the leaf
                - int: number of sine contributions of the input
                - int: number of cosine contributions of the input
        """
        leaf_factor = 1.0
        for i in parameter_indices:
            leaf_factor = leaf_factor * (
                jnp.cos(self.parameters[i]) ** leaf.cos_indices[i]
                * (1j * jnp.sin(self.parameters[i])) ** leaf.sin_indices[i]
            )

        # Get number of sine and cosine factors to which the input contributes
        input_param_indices = self.input_indices[input_idx]
        c = jnp.sum(jnp.array([leaf.cos_indices[k] for k in input_param_indices]))
        s = jnp.sum(jnp.array([leaf.sin_indices[k] for k in input_param_indices]))

        leaf_factor = leaf.term * leaf_factor * 0.5 ** (s + c)

        return leaf_factor, int(s), int(c)

    def _early_stopping_possible(
        self, pauli_rotation_idx: int, observable: FourierTree.PauliOperator
    ) -> bool:
        """
        Checks if a node for an observable can be discarded as all expectation
        values that can result through further branching are zero.
        The method is mentioned in the paper by Nemkov et al.: If the one-hot
        encoded indices for X/Y operations in the Pauli-rotation generators are
        a basis for that of the observable, the node must be processed further.
        If not, it can be discarded.

        Args:
            pauli_rotation_idx (int): Index of remaining Pauli rotation gates
                (negative when none remain). Gates itself are attributes of
                the class.
            observable (FourierTree.PauliOperator): Current observable

        Returns:
            bool: True if the node can be discarded.
        """
        if pauli_rotation_idx < 0:
            # No rotations remain (a negative index must not wrap around to
            # the last rotation): only an all-I/Z observable, whose
            # xy_indices mask is all True, can contribute.
            return not bool(observable.xy_indices.all())

        xy_covered = jnp.logical_or(
            observable.xy_indices, self.pauli_rotations[pauli_rotation_idx].xy_indices
        ).all()

        return not xy_covered

    def _create_tree_node(
        self,
        observable: FourierTree.PauliOperator,
        pauli_rotation_idx: int,
        parameter_idx: Optional[int] = None,
        is_sine: bool = False,
        is_cosine: bool = False,
    ) -> Optional[CoefficientsTreeNode]:
        """
        Builds the Fourier-Tree according to the algorithm by Nemkov et al.

        Args:
            observable (FourierTree.PauliOperator): Current observable
            pauli_rotation_idx (int): Index of remaining Pauli rotation gates.
                Gates itself are attributes of the class.
            parameter_idx (Optional[int]): Index of the current parameter.
                Parameters itself are attributes of the class.
            is_sine (bool): If the current node is a sine (right) node.
            is_cosine (bool): If the current node is a cosine (left) node.

        Returns:
            Optional[CoefficientsTreeNode]: The resulting node (None if it
                was pruned). Children are set recursively. The top level
                receives the tree root.
        """
        if self._early_stopping_possible(pauli_rotation_idx, observable):
            return None

        # remove commuting paulis
        while pauli_rotation_idx >= 0:
            last_pauli = self.pauli_rotations[pauli_rotation_idx]
            if not observable.is_commuting(last_pauli.list_repr):
                break
            pauli_rotation_idx -= 1
        else:  # leaf
            return FourierTree.CoefficientsTreeNode(
                parameter_idx, observable, is_sine, is_cosine
            )

        last_pauli = self.pauli_rotations[pauli_rotation_idx]

        # Cosine branch: observable unchanged.
        left = self._create_tree_node(
            observable,
            pauli_rotation_idx - 1,
            pauli_rotation_idx,
            is_cosine=True,
        )

        # Sine branch: observable multiplied by the rotation generator.
        next_observable = self._create_new_observable(last_pauli.list_repr, observable)
        right = self._create_tree_node(
            next_observable,
            pauli_rotation_idx - 1,
            pauli_rotation_idx,
            is_sine=True,
        )

        return FourierTree.CoefficientsTreeNode(
            parameter_idx,
            observable,
            is_sine,
            is_cosine,
            left,
            right,
        )

    def _create_new_observable(
        self, pauli: jnp.ndarray[int], observable: FourierTree.PauliOperator
    ) -> FourierTree.PauliOperator:
        """
        Utility function to obtain the new observable for a tree node, if the
        last Pauli and the observable do not commute.

        Args:
            pauli (jnp.ndarray[int]): The int array representation of the last
                Pauli rotation in the operation sequence.
            observable (FourierTree.PauliOperator): The current observable.

        Returns:
            FourierTree.PauliOperator: The new observable.
        """
        return observable.tensor(pauli)

1018 

1019 

class FCC:
    """
    Fourier Coefficient Correlation (FCC) of a model.

    The pipeline:
      1. samples model parameters and computes the Fourier coefficients for
         each sample (`_calculate_coefficients`);
      2. correlates the coefficients with each other across samples
         (`_correlate`), giving the "Fourier fingerprint";
      3. averages the absolute non-redundant correlations (`calculate_fcc`).
    """

    @classmethod
    def get_fcc(
        cls,
        model: Model,
        n_samples: int,
        random_key: Optional[random.PRNGKey] = None,
        method: Optional[str] = "pearson",
        scale: Optional[bool] = False,
        weight: Optional[bool] = False,
        trim_redundant: Optional[bool] = True,
        **kwargs,
    ) -> float:
        """
        Shortcut method to get just the FCC.
        This includes
            1. What is done in `get_fourier_fingerprint`:
                1. Calculating the coefficients (using `n_samples`)
                2. Correlating the result from 1) using `method`
                3. Weighting the correlation matrix (if `weight` is True)
                4. Removing redundancies (if `trim_redundant` is True)
            2. What is done in `calculate_fcc`:
                1. Absolute value of the fingerprint
                2. Average, ignoring NaNs

        Args:
            model (Model): The QFM model
            n_samples (int): Number of samples to calculate average of coefficients
            random_key (Optional[random.PRNGKey]): JAX random key for parameter
                initialization. If None, uses the model's internal random key.
            method (Optional[str], optional): Correlation method. Supported values are
                "pearson", "complex_pearson", and "spearman". Defaults to "pearson".
            scale (Optional[bool], optional): Whether to scale the number of samples.
                Defaults to False.
            weight (Optional[bool], optional): Whether to weight the correlation matrix.
                Defaults to False.
            trim_redundant (Optional[bool], optional): Whether to remove redundant
                correlations. Defaults to True.
            **kwargs (Any): Additional keyword arguments for the model function.

        Returns:
            float: The FCC
        """
        # Memory-efficient fast path: correlate only the non-negative-frequency
        # coefficients and average the strict lower triangle of |fingerprint|
        # directly, without building the masked fingerprint.
        if trim_redundant and not weight:
            _, coeffs, freqs = cls._calculate_coefficients(
                model, n_samples, random_key, scale, **kwargs
            )
            pos_idx = cls._calculate_mask(freqs)
            abs_fp = jnp.abs(cls._correlate_positive(coeffs, pos_idx, method=method))
            abs_diag = jnp.diagonal(abs_fp)

            # The correlation matrices from `_correlate` are symmetric in
            # magnitude, so the strict lower triangle holds exactly half of
            # the off-diagonal sum and of the off-diagonal finite-entry count.
            lower_sum = (jnp.nansum(abs_fp) - jnp.nansum(abs_diag)) / 2.0
            lower_count = (
                jnp.sum(jnp.isfinite(abs_fp)) - jnp.sum(jnp.isfinite(abs_diag))
            ) / 2.0
            return lower_sum / lower_count

        fourier_fingerprint, _ = cls.get_fourier_fingerprint(
            model,
            n_samples,
            random_key,
            method,
            scale,
            weight,
            trim_redundant=trim_redundant,
            **kwargs,
        )

        return cls.calculate_fcc(fourier_fingerprint)

    @classmethod
    def get_fourier_fingerprint(
        cls,
        model: Model,
        n_samples: int,
        random_key: Optional[random.PRNGKey] = None,
        method: Optional[str] = "pearson",
        scale: Optional[bool] = False,
        weight: Optional[bool] = False,
        trim_redundant: Optional[bool] = True,
        nan_to_one: Optional[bool] = False,
        **kwargs: Any,
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """
        Shortcut method to get just the fourier fingerprint.
        This includes
            1. Calculating the coefficients (using `n_samples`)
            2. Correlating the result from 1) using `method`
            3. Weighting the correlation matrix (if `weight` is True)
            4. Removing redundancies (if `trim_redundant` is True)

        Args:
            model (Model): The QFM model
            n_samples (int): Number of samples to calculate average of coefficients
            random_key (Optional[random.PRNGKey]): JAX random key for parameter
                initialization. If None, uses the model's internal random key.
            method (Optional[str], optional): Correlation method. Supported values are
                "pearson", "complex_pearson", and "spearman". Defaults to "pearson".
            scale (Optional[bool], optional): Whether to scale the number of samples.
                Defaults to False.
            weight (Optional[bool], optional): Whether to weight the correlation matrix.
                Defaults to False.
            trim_redundant (Optional[bool], optional): Whether to remove redundant
                correlations. Defaults to True.
            nan_to_one (Optional[bool], optional): Whether to set nan to 1.
                Defaults to False.
            **kwargs: Additional keyword arguments for the model function.

        Returns:
            Tuple[jnp.ndarray, jnp.ndarray]: The fourier fingerprint and the
            corresponding frequency labels.

            - If `trim_redundant` is True, the fingerprint is restricted to the
              strict lower triangle of the non-negative-frequency block. Rows
              and columns without any finite entry are dropped. The labels are
              a `(row_freqs, col_freqs)` tuple taken from `_flat_frequencies`,
              i.e. in C (row-major) order of the frequency grid.
            - Otherwise the full correlation matrix and the unmodified `freqs`
              are returned. For multi-dimensional inputs its rows/columns
              follow the Fortran (column-major) order of the frequency grid.
        """
        _, coeffs, freqs = cls._calculate_coefficients(
            model, n_samples, random_key, scale, **kwargs
        )

        # Memory-efficient fast path: only the non-negative-frequency
        # coefficients are correlated.
        if trim_redundant and not weight:
            pos_idx = cls._calculate_mask(freqs)
            pos_freqs = cls._flat_frequencies(freqs)[pos_idx]

            fourier_fingerprint = cls._correlate_positive(coeffs, pos_idx, method=method)

            if nan_to_one:
                fourier_fingerprint = jnp.where(
                    jnp.isnan(fourier_fingerprint), 1.0, fourier_fingerprint
                )

            return cls._strict_lower_triangle(fourier_fingerprint, pos_freqs)

        # Full path: `coeffs.transpose()` moves the sample axis to the front
        # and reverses the frequency axes, so the fingerprint is laid out in
        # Fortran order of the frequency grid (as is `_weighting_mean`).
        fourier_fingerprint = cls._correlate(coeffs.transpose(), method=method)

        if nan_to_one:
            # JAX arrays are immutable: use `where` instead of item assignment.
            fourier_fingerprint = jnp.where(
                jnp.isnan(fourier_fingerprint), 1.0, fourier_fingerprint
            )

        if weight:
            fourier_fingerprint = cls._weighting_mean(fourier_fingerprint, coeffs)

        if trim_redundant:
            pos_idx = cls._calculate_mask(freqs)
            pos_freqs = cls._flat_frequencies(freqs)[pos_idx]

            # `pos_idx` is in C order while this fingerprint is in Fortran
            # order. Translate so the selected block matches `pos_freqs`
            # (identical for single-feature inputs).
            fp_idx = cls._c_to_fortran_indices(pos_idx, coeffs.shape[:-1])
            fourier_fingerprint = fourier_fingerprint[fp_idx][:, fp_idx]

            return cls._strict_lower_triangle(fourier_fingerprint, pos_freqs)

        return fourier_fingerprint, freqs

    @classmethod
    def calculate_fcc(
        cls,
        fourier_fingerprint: jnp.ndarray,
    ) -> float:
        """
        Method to calculate the FCC based on an existing correlation matrix.
        Takes the absolute value and then the NaN-ignoring average over the
        matrix. The fingerprint can be obtained via `get_fourier_fingerprint`.

        Args:
            fourier_fingerprint (jnp.ndarray): Correlation matrix of coefficients
        Returns:
            float: The FCC
        """
        # masked-out entries are NaN and therefore ignored by `nanmean`
        return jnp.nanmean(jnp.abs(fourier_fingerprint))

    @classmethod
    def _correlate_positive(
        cls, coeffs: jnp.ndarray, pos_idx: jnp.ndarray, method: str = "pearson"
    ) -> jnp.ndarray:
        """
        Correlate only the coefficients selected by `pos_idx`.

        Args:
            coeffs (jnp.ndarray): Coefficients whose last axis is the sample
                axis; all leading axes are frequency axes.
            pos_idx (jnp.ndarray): C-order flat frequency indices, as returned
                by `_calculate_mask`.
            method (str, optional): Correlation method. Defaults to "pearson".

        Returns:
            jnp.ndarray: (M, M) correlation matrix, rows/cols ordered as `pos_idx`.
        """
        # C-order flattening of the frequency axes matches `_calculate_mask`.
        coeffs_flat = coeffs.reshape(-1, coeffs.shape[-1])
        return cls._correlate(coeffs_flat[pos_idx].transpose(), method=method)

    @classmethod
    def _strict_lower_triangle(
        cls, fourier_fingerprint: jnp.ndarray, pos_freqs: jnp.ndarray
    ) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]:
        """
        Keep only the strict lower triangle of a square fingerprint (the rest
        becomes NaN), then drop rows/columns without any finite entry.

        Args:
            fourier_fingerprint (jnp.ndarray): Square correlation matrix.
            pos_freqs (jnp.ndarray): Frequency label per row/column.

        Returns:
            Tuple: The trimmed fingerprint and `(row_freqs, col_freqs)`.
        """
        M = fourier_fingerprint.shape[0]
        lower_tri_mask = jnp.tri(M, k=-1, dtype=bool)
        fourier_fingerprint = jnp.where(lower_tri_mask, fourier_fingerprint, jnp.nan)

        finite = jnp.isfinite(fourier_fingerprint)
        row_mask = jnp.any(finite, axis=1)
        col_mask = jnp.any(finite, axis=0)

        fourier_fingerprint = fourier_fingerprint[row_mask][:, col_mask]
        return fourier_fingerprint, (pos_freqs[row_mask], pos_freqs[col_mask])

    @classmethod
    def _c_to_fortran_indices(cls, flat_idx: jnp.ndarray, grid_shape) -> np.ndarray:
        """
        Convert C-order (row-major) flat indices on a grid of shape
        `grid_shape` into the Fortran-order (column-major) flat indices of the
        same grid points.

        Args:
            flat_idx (jnp.ndarray): 1-D integer C-order flat indices.
            grid_shape (Sequence[int]): Shape of the grid.

        Returns:
            np.ndarray: 1-D integer Fortran-order flat indices.
        """
        grid_shape = tuple(int(s) for s in grid_shape)
        multi_index = np.unravel_index(np.asarray(flat_idx), grid_shape)
        return np.ravel_multi_index(multi_index, grid_shape, order="F")

    @classmethod
    def _frequency_axes(cls, freqs) -> Tuple[List[jnp.ndarray], bool]:
        """
        Normalise `freqs` into a list of 1-D per-axis frequency vectors.

        Accepts a 1-D vector (single input feature), a 2-D
        ``(n_input_feat, K)`` stack, or a list/tuple of 1-D per-axis vectors
        (whose lengths may differ).

        Returns:
            Tuple[List[jnp.ndarray], bool]: The per-axis vectors and a flag
            that is True when `freqs` was a single 1-D vector.
        """
        if (
            isinstance(freqs, (list, tuple))
            and len(freqs) > 0
            and jnp.ndim(freqs[0]) >= 1
        ):
            return [jnp.asarray(f) for f in freqs], False

        freqs_arr = jnp.asarray(freqs)
        if freqs_arr.ndim == 1:
            return [freqs_arr], True
        return [freqs_arr[i] for i in range(freqs_arr.shape[0])], False

    @classmethod
    def _calculate_mask(cls, freqs: jnp.ndarray) -> jnp.ndarray:
        """
        Determine the flat indices of the Fourier correlation matrix
        that lie on a non-negative-frequency row/column. Together with
        the strict-lower-triangle condition (handled by the caller),
        these indices select the entries of the correlation matrix
        that survive the redundancy filter applied in
        `get_fourier_fingerprint`:

        - rows/columns whose flat frequency has a negative component are
          discarded (they are the complex-conjugate redundancies of
          their positive counterparts);
        - of the remaining positive-frequency sub-block, only the
          strict lower triangle is kept (the upper triangle, including
          the diagonal, contains either duplicates from symmetry or
          self-correlations).

        Args:
            freqs (jnp.ndarray): Frequencies. Either a 1-D vector (single
                input feature), a 2-D array of shape ``(n_input_feat, K)``
                whose rows are the per-axis frequency vectors, or a list of
                per-axis 1-D vectors (lengths may differ).

        Returns:
            jnp.ndarray: 1-D int array of C-order (row-major) flat indices
            selecting the non-negative-frequency rows/cols.
        """
        axes, _ = cls._frequency_axes(freqs)
        n_axes = len(axes)

        # Per-axis non-negativity masks, combined via broadcasting (no float
        # `jnp.outer`), then flattened in C order like the coefficients.
        expanded = []
        for i, axis in enumerate(axes):
            shape = [1] * n_axes
            shape[i] = axis.shape[0]
            expanded.append((axis >= 0).reshape(shape))
        nd_pos = reduce(jnp.logical_and, expanded)

        return jnp.where(nd_pos.reshape(-1))[0]

    @classmethod
    def _flat_frequencies(cls, freqs: jnp.ndarray) -> jnp.ndarray:
        """
        Build the per-coefficient flat frequency labels in the same
        C-order used by `_calculate_mask`, so they can be indexed by its
        flat indices.

        Args:
            freqs (jnp.ndarray): A 1-D vector (single input feature), a
                ``(n_input_feat, K)`` stack, or a list of per-axis frequency
                vectors (lengths may differ).

        Returns:
            jnp.ndarray: 1-D frequency vector (single input feature) or an
            ``(N, n_input_feat)`` array of per-coefficient frequency tuples.
        """
        axes, is_vector = cls._frequency_axes(freqs)
        if is_vector:
            return axes[0]
        grids = jnp.meshgrid(*axes, indexing="ij")
        return jnp.stack(grids, axis=-1).reshape(-1, len(axes))

    @classmethod
    def _calculate_coefficients(
        cls,
        model: Model,
        n_samples: int,
        random_key: Optional[random.PRNGKey] = None,
        scale: bool = False,
        **kwargs: Any,
    ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
        """
        Calculates the Fourier coefficients of a given model
        using `n_samples` parameter samples.
        Optionally, `noise_params` can be passed (via `kwargs`) to perform
        noisy simulation.

        Args:
            model (Model): The QFM model
            n_samples (int): Number of parameter samples. If <= 0, the model's
                current parameters are used without re-initialisation.
            random_key (Optional[random.PRNGKey]): JAX random key for parameter
                initialization. If None, uses the model's internal random key.
            scale (bool, optional): Whether to scale the number of samples by
                ``2**n_qubits * n_input_feat``. Defaults to False.
            **kwargs: Additional keyword arguments for the model function.

        Returns:
            Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: Model parameters,
            coefficients and frequencies as returned by
            `Coefficients.get_spectrum(model, shift=True, trim=True, ...)`.
        """
        if n_samples > 0:
            if scale:
                # Python ints cannot overflow; `jnp.power` computes in (by
                # default 32-bit) JAX integers and may wrap silently.
                total_samples = (
                    2 ** int(model.n_qubits) * int(n_samples) * int(model.n_input_feat)
                )
                log.info("Using %d samples.", total_samples)
            else:
                total_samples = n_samples
            model.initialize_params(random_key, repeat=total_samples)

        coeffs, freqs = Coefficients.get_spectrum(model, shift=True, trim=True, **kwargs)

        return model.params, coeffs, freqs

    @classmethod
    def _correlate(cls, mat: jnp.ndarray, method: str = "pearson") -> jnp.ndarray:
        """
        Correlates the columns of `mat` using `method`.
        Currently, `pearson`, `complex_pearson`, and `spearman` are supported.

        Args:
            mat (jnp.ndarray): Array whose first axis holds the samples, e.g.
                of shape (N, K). Any further axes are flattened (C order) into
                the column axis.
            method (str, optional): Correlation method. Defaults to "pearson".

        Raises:
            ValueError: If the method is not supported.

        Returns:
            jnp.ndarray: Correlation matrix between the flattened columns.
        """
        assert len(mat.shape) >= 2, "Input matrix must have at least 2 dimensions"

        # Keep the sample axis and flatten all others in C order. Callers that
        # pass `coeffs.transpose()` (sample axis first, frequency axes reversed)
        # thereby get the Fortran-order flattening of the frequency grid, e.g.
        # [[1, 2, 3], [4, 5, 6], [7, 8, 9]] -> [1, 4, 7, 2, 5, 8, 3, 6, 9].
        columns = mat.reshape(mat.shape[0], -1)

        if method == "pearson":
            return cls._pearson(columns)
        if method == "complex_pearson":
            return cls._complex_pearson(columns)
        if method == "spearman":
            return cls._spearman(columns)
        raise ValueError(
            f"Unknown correlation method: {method}. "
            "Must be 'pearson', 'complex_pearson' or 'spearman'."
        )

    @classmethod
    def _complex_pearson(
        cls, mat: jnp.ndarray, cov: Optional[bool] = False, minp: Optional[int] = 1
    ) -> jnp.ndarray:
        """
        Compute the complex Pearson correlation between columns of `mat`,
        permitting missing values (NaN or ±Inf).

        This uses the Hermitian normalized covariance
            sum(conj(x_i - mean_i) * (x_j - mean_j)) /
            sqrt(sum(abs(x_i - mean_i)**2) * sum(abs(x_j - mean_j)**2)).
        Consequently, if column j is exp(1j * phi) times column i, then
        abs(corr[i, j]) is 1 and angle(corr[i, j]) is phi.

        Args:
            mat : array_like, shape (N, K)
                Input data.
            cov : bool, optional
                If True, return the sample covariance matrix instead of
                correlation. Defaults to False.
            minp : int, optional
                Minimum number of paired observations required to form a correlation.
                If the number of valid pairs for (i, j) is < minp, the result is NaN.

        Returns:
            corr : ndarray, shape (K, K)
                Complex Pearson correlation matrix.
        """
        mat = jnp.asarray(mat)
        real_dtype = jnp.asarray(mat.real).dtype

        mask = jnp.isfinite(mat)
        fmask = mask.astype(real_dtype)
        safe = jnp.where(mask, mat, 0.0)

        # pairwise valid-observation counts (K, K)
        nobs = fmask.T @ fmask
        nobs_safe = jnp.where(nobs > 0, nobs, 1.0)

        # sums restricted to rows valid for both columns of each pair
        sum_x = safe.T @ fmask
        sum_y = fmask.T @ safe

        masked = safe * fmask
        sum_conj_xy = jnp.conj(masked).T @ masked

        safe_abs_sq = jnp.abs(safe) ** 2
        sum_abs_x2 = safe_abs_sq.T @ fmask
        sum_abs_y2 = fmask.T @ safe_abs_sq

        ssx = sum_abs_x2 - jnp.abs(sum_x) ** 2 / nobs_safe
        ssy = sum_abs_y2 - jnp.abs(sum_y) ** 2 / nobs_safe
        sxy = sum_conj_xy - (jnp.conj(sum_x) * sum_y) / nobs_safe

        if cov:
            denom = jnp.where(nobs > 1, nobs - 1, jnp.nan)
            result = sxy / denom
        else:
            denom = jnp.sqrt(ssx * ssy)
            result = jnp.where(denom > 0, sxy / denom, jnp.nan)
            # renormalise numerical drift beyond the unit circle
            magnitude = jnp.abs(result)
            result = jnp.where(magnitude > 1.0, result / magnitude, result)

        result = jnp.where(nobs < minp, jnp.nan, result)

        return result

    @classmethod
    def _pearson(
        cls, mat: jnp.ndarray, cov: Optional[bool] = False, minp: Optional[int] = 1
    ) -> jnp.ndarray:
        """
        Based on Pandas correlation method as implemented here:
        https://github.com/pandas-dev/pandas/blob/main/pandas/_libs/algos.pyx

        Compute Pearson correlation between columns of `mat`,
        permitting missing values (NaN or ±Inf).

        If the input is complex, real and imaginary parts are stacked along
        the sample axis so that both components contribute to the correlation
        without discarding information.

        Args:
            mat : array_like, shape (N, K)
                Input data.
            cov : bool, optional
                If True, return the sample covariance matrix instead of
                correlation. Defaults to False.
            minp : int, optional
                Minimum number of paired observations required to form a correlation.
                If the number of valid pairs for (i, j) is < minp, the result is NaN.

        Returns:
            corr : ndarray, shape (K, K)
                Pearson correlation matrix.
        """
        # Preserve complex information by splitting into real / imag samples
        if jnp.iscomplexobj(mat):
            mat = jnp.concatenate([mat.real, mat.imag], axis=0)

        mat = jnp.asarray(mat)

        # finite mask (N, K); non-finite entries are zeroed so arithmetic is
        # safe while the mask keeps track of validity
        mask = jnp.isfinite(mat)
        fmask = mask.astype(mat.dtype)
        safe = jnp.where(mask, mat, 0.0)

        # Pairwise valid-observation counts (K, K)
        nobs = fmask.T @ fmask
        nobs_safe = jnp.where(nobs > 0, nobs, 1.0)

        # Pairwise sums over mutually valid rows: entry (i, j) of
        # `safe.T @ fmask` is sum_n safe[n, i] * mask[n, j].
        sum_x = safe.T @ fmask  # (K, K) – row-var sums
        sum_y = fmask.T @ safe  # (K, K) – col-var sums

        masked = safe * fmask
        sum_xy = masked.T @ masked  # (K, K)

        # sums of squares over the *pair* mask, not just the column mask
        safe_sq = safe**2
        sum_x2 = safe_sq.T @ fmask  # (K, K)
        sum_y2 = fmask.T @ safe_sq  # (K, K)

        # computational formulas; they implicitly mean-centre:
        #   ssx = Σx² − (Σx)²/n, ssy = Σy² − (Σy)²/n, sxy = Σxy − (Σx)(Σy)/n
        ssx = sum_x2 - sum_x**2 / nobs_safe
        ssy = sum_y2 - sum_y**2 / nobs_safe
        sxy = sum_xy - (sum_x * sum_y) / nobs_safe

        if cov:
            denom = jnp.where(nobs > 1, nobs - 1, jnp.nan)
            result = sxy / denom
        else:
            denom = jnp.sqrt(ssx * ssy)
            result = jnp.where(denom > 0, sxy / denom, jnp.nan)
            # clip numerical drift to [-1, 1]
            result = jnp.clip(result, -1.0, 1.0)

        # Enforce minp: set entries with too few observations to NaN
        result = jnp.where(nobs < minp, jnp.nan, result)

        return result

    @classmethod
    def _spearman(cls, mat: jnp.ndarray, minp: Optional[int] = 1) -> jnp.ndarray:
        """
        Based on Pandas correlation method as implemented here:
        https://github.com/pandas-dev/pandas/blob/main/pandas/_libs/algos.pyx

        Compute Spearman correlation between columns of `mat`,
        permitting missing values (NaN or ±Inf).

        If the input is complex, real and imaginary parts are stacked along
        the sample axis so that both components contribute to the correlation
        without discarding information.

        Args:
            mat : array_like, shape (N, K)
                Input data.
            minp : int, optional
                Minimum number of paired observations required to form a correlation.
                If the number of valid pairs for (i, j) is < minp, the result is NaN.

        Returns:
            corr : ndarray, shape (K, K)
                Spearman correlation matrix.
        """
        # Preserve complex information by splitting into real / imag samples
        if jnp.iscomplexobj(mat):
            mat = jnp.concatenate([mat.real, mat.imag], axis=0)

        mat = jnp.asarray(mat)
        N, K = mat.shape

        # trivial all-NaN answer if too few rows
        if N < minp:
            return jnp.full((K, K), jnp.nan)

        # Column-wise average ranks ignoring non-finite entries. Ranking is
        # done on a NumPy copy since `rankdata` is a SciPy (NumPy) routine.
        mat_np = np.asarray(mat)
        finite_np = np.isfinite(mat_np)
        ranks = np.full((N, K), np.nan)
        for j in range(K):
            valid = finite_np[:, j]
            if valid.any():
                ranks[valid, j] = rankdata(mat_np[valid, j], method="average")

        ranks = jnp.asarray(ranks)

        # Vectorised Pearson on the ranks (NaN ranks -> 0, tracked by mask)
        rank_mask = jnp.isfinite(ranks)
        safe_ranks = jnp.where(rank_mask, ranks, 0.0)

        fmask = rank_mask.astype(ranks.dtype)
        nobs = fmask.T @ fmask

        sum_x = safe_ranks.T @ fmask  # (K, K)
        sum_y = fmask.T @ safe_ranks  # (K, K)

        masked_ranks = safe_ranks * fmask
        sum_xy = masked_ranks.T @ masked_ranks  # (K, K)

        safe_sq = safe_ranks**2
        sum_x2 = safe_sq.T @ fmask  # (K, K)
        sum_y2 = fmask.T @ safe_sq  # (K, K)

        nobs_safe = jnp.where(nobs > 0, nobs, 1.0)
        ssx = sum_x2 - sum_x**2 / nobs_safe
        ssy = sum_y2 - sum_y**2 / nobs_safe
        sxy = sum_xy - (sum_x * sum_y) / nobs_safe

        denom = jnp.sqrt(ssx * ssy)
        result = jnp.where(denom > 0, sxy / denom, jnp.nan)
        result = jnp.clip(result, -1.0, 1.0)

        # Enforce minp
        result = jnp.where(nobs < minp, jnp.nan, result)

        return result

    @classmethod
    def _weighting_linear(cls, fourier_fingerprint: jnp.ndarray) -> jnp.ndarray:
        """
        Performs weighting on the given correlation matrix.
        Here, low-frequent coefficients are weighted more heavily.

        With N = fourier_fingerprint.shape[0] (odd) and center = N // 2, the
        weight is W[i, j] = u[i] + u[j] with
        u[k] = (center - |k - center|) / (2 * center), a triangular weight
        peaking at the centre (zero frequency) and reaching 0 at the edges.

        Args:
            fourier_fingerprint (jnp.ndarray): Square correlation matrix with
                odd side length.

        Returns:
            jnp.ndarray: The weighted correlation matrix.
        """
        assert (
            fourier_fingerprint.shape[0] % 2 != 0
            and fourier_fingerprint.shape[1] % 2 != 0
        ), (
            "Correlation matrix must have odd dimensions. "
            "Hint: use `trim` argument when calling `get_spectrum`."
        )
        assert fourier_fingerprint.shape[0] == fourier_fingerprint.shape[1], (
            "Correlation matrix must be square."
        )

        N = fourier_fingerprint.shape[0]
        center = N // 2
        if center == 0:
            # 1x1 matrix (zero frequency only): the formula is 0/0 here; use
            # its value at the centre, u = 1/2 for any N, i.e. W = 1.
            return fourier_fingerprint * 1.0

        k = jnp.arange(N)
        u = (center - jnp.abs(k - center)) / (2 * center)

        return fourier_fingerprint * (u[:, None] + u[None, :])

    @classmethod
    def _weighting_mean(
        cls, fourier_fingerprint: jnp.ndarray, coeffs: jnp.ndarray
    ) -> jnp.ndarray:
        """
        Performs weighting on the given correlation matrix.
        Here, we use the product of the absolute means of the coefficients as
        weights. This suppresses correlations where the mean of the
        coefficients is near zero.

        The means are flattened in Fortran order (`.T.reshape(-1)`), matching
        the layout of the fingerprint produced by `_correlate(coeffs.transpose())`.

        Args:
            fourier_fingerprint (jnp.ndarray): Square correlation matrix
            coeffs (jnp.ndarray): Fourier coefficients; the last axis is the
                sample axis.

        Returns:
            jnp.ndarray: The weighted correlation matrix.
        """
        assert fourier_fingerprint.shape[0] == fourier_fingerprint.shape[1], (
            "Correlation matrix must be square."
        )
        assert len(coeffs.shape) >= 2, (
            "Coefficient matrix must contain coefficient axes and a sample axis."
        )

        coefficient_means = jnp.abs(jnp.mean(coeffs, axis=-1))
        coefficient_means = coefficient_means.T.reshape(-1)

        assert fourier_fingerprint.shape[0] == coefficient_means.shape[0], (
            "Correlation matrix size must match the number of Fourier coefficients."
        )

        # rank-1 weight w[i] * w[j] via broadcasting (no explicit outer product)
        return (
            fourier_fingerprint
            * coefficient_means[:, None]
            * coefficient_means[None, :]
        )

1701 

1702 

1703class Datasets: 

@classmethod
def generate_fourier_series(
    cls,
    random_key: random.PRNGKey,
    model: Model,
    coefficients_min: float = 0.0,
    coefficients_max: float = 1.0,
    zero_centered: bool = False,
) -> List[jnp.ndarray]:
    """
    Generates the Fourier series representation of a function.
    It uses the `model.frequencies` property to retrieve the frequency
    information. This ensures that the resulting Fourier series is
    compatible with the model.

    This function is capable of generating $D$-dimensional Fourier series
    (again defined by `model.n_input_feat`). The number of grid points per
    dimension is taken from `model.degree`.

    The non-negative half of the (flattened) coefficients is drawn via
    `cls.uniform_circle(low=coefficients_min, high=coefficients_max)`. The
    other half follows from conjugate symmetry, so the series is real.

    Args:
        random_key (random.PRNGKey): Random number key for JAX.
        model (Model): The quantum circuit model.
        coefficients_min (float, optional): Lower bound passed to
            `uniform_circle` as `low`. Defaults to 0.0.
        coefficients_max (float, optional): Upper bound passed to
            `uniform_circle` as `high`. Defaults to 1.0.
        zero_centered (bool, optional): Whether to zero-center the coefficients.
            Defaults to False.

    Returns:
        List[jnp.ndarray]:
            - Input domain samples with shape (*model.degree, D)
            - Fourier series values with shape model.degree
            - Fourier coefficients with shape model.degree
    """
    # TODO: add some constraints on the spectrum, i.e. the following code
    # can be considered as not fully capturing a truly random spectrum.

    # Note: one key observation for understanding the following code is
    # that instead of wrapping your head around symmetries in multi-
    # dimensional coefficient matrices, one can simply look at the flattened
    # version of such a matrix and reshape later. This requires a
    # point-symmetric flat frequency order, which C-order flattening of
    # symmetric per-axis frequency vectors provides.
    n_dims = model.n_input_feat

    # Regular grid over [0, 2*pi) per input dimension. `linspace` with
    # `endpoint=False` yields exactly `d` points, whereas a float-step
    # `arange` can produce an extra point through rounding.
    axes = [jnp.linspace(0.0, 2 * jnp.pi, d, endpoint=False) for d in model.degree]

    # "ij" indexing keeps grid axis k aligned with input dimension k, so the
    # C-order flattening here is undone by `reshape(*model.degree, ...)` below.
    domain_samples_per_input_dim = jnp.stack(
        jnp.meshgrid(*axes, indexing="ij"), axis=-1
    ).reshape(-1, n_dims)

    # Frequency tuples in the same flat order as the domain samples.
    frequencies = jnp.stack(
        jnp.meshgrid(*model.frequencies, indexing="ij"), axis=-1
    ).reshape(-1, n_dims)

    # Sample the non-negative half of the flattened coefficients,
    # shape: (prod(model.degree) // 2 + 1,)
    coefficients = cls.uniform_circle(
        random_key,
        low=coefficients_min,
        high=coefficients_max,
        size=math.prod(model.degree) // 2 + 1,
    )

    # The first sampled coefficient is the offset (zero frequency); it must
    # be real for a real-valued series, or zero if zero-centred.
    if zero_centered:
        coefficients = coefficients.at[0].set(0.0)
    else:
        coefficients = coefficients.at[0].set(coefficients[0].real)

    # Mirror with complex conjugation (excluding the offset) to obtain the
    # full, conjugate-symmetric coefficient vector.
    coefficients = jnp.concatenate(
        [
            jnp.flip(coefficients[..., 1:]).conjugate(),
            coefficients,
        ],
        axis=-1,
    )

    # Vectorized version of $f(x) = \sum_{n=0}^{N-1} c_n * e^{i * \omega_n * x}$
    # evaluated at every domain sample, normalised by the number of coefficients.
    values = jnp.real(
        (
            jnp.exp(1j * (domain_samples_per_input_dim @ frequencies.T))
            * coefficients
        ).sum(axis=1)
        / coefficients.size
    )

    return [
        domain_samples_per_input_dim.reshape(*model.degree, -1),
        values.reshape(model.degree),
        coefficients.reshape(model.degree),
    ]

1810 

@classmethod
def uniform_circle(
    cls,
    random_key: random.PRNGKey,
    size: Union[jnp.ndarray, List, int],
    low=0.0,
    high=1.0,
):
    """
    Random number generator for complex numbers sampled uniformly (by area)
    from the annulus `low <= |z| <= high` (the unit disc for the defaults).

    Args:
        random_key (random.PRNGKey): Random number key for JAX.
        size (Union[jnp.ndarray, List, int]): Output shape. An int `n` is
            treated as the shape `(n,)`.
        low (float, optional): Minimum radius. Defaults to 0.0.
        high (float, optional): Maximum radius. Defaults to 1.0.

    Returns:
        jnp.ndarray: Array of complex numbers with shape `size`.

    Raises:
        ValueError: If `low` is negative or `high` is smaller than `low`.
    """
    if low < 0 or high < low:
        raise ValueError(
            f"Radii must satisfy 0 <= low <= high, got low={low}, high={high}."
        )

    if isinstance(size, int):
        size = (size,)

    radius_key, phase_key = random.split(random_key)
    # Area-uniform radii: r = sqrt(u) with u ~ U(low**2, high**2), so that
    # r lies in [low, high] (identical to sqrt(U(0, 1)) for the defaults).
    radius = jnp.sqrt(
        random.uniform(radius_key, size, minval=low**2, maxval=high**2)
    )
    phase = jnp.exp(2j * jnp.pi * random.uniform(phase_key, size))
    return radius * phase