Coverage for qml_essentials / evolution.py: 94%

170 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-06-11 15:51 +0000

1"""Hamiltonian time-evolution machinery for pulse/gate construction. 

2 

3This module houses the :class:`Evolution` class, which turns a (static or 

4time-dependent) Hamiltonian into a gate factory by solving the Schrödinger 

5equation ``dU/dt = -i H(t) U``. It is the pulse/gate-dependent counterpart to 

6the otherwise pulse-agnostic :mod:`qml_essentials.jaqsi` entry point. 

7 

8The engine is normally reached through the :meth:`evolve` method on the 

9Hamiltonian object (``Hermitian`` / ``ParametrizedHamiltonian``), which delegates 

10to :meth:`Evolution.evolve`. :class:`Evolution` is also where solver defaults 

11live (:meth:`Evolution.set_solver_defaults`). 

12""" 

13 

14from typing import Any, Callable, List, Optional, Tuple, Union 

15import math 

16import threading 

17 

18import diffrax 

19import jax 

20import jax.numpy as jnp 

21import jax.scipy.linalg 

22import equinox as eqx 

23 

24from qml_essentials.operations import ( 

25 Hermitian, 

26 ParametrizedHamiltonian, 

27 Operation, 

28) 

29 

30 

class Evolution:
    """Turn static or time-dependent Hamiltonians into gate factories.

    All state (the compiled-solver cache and the solver defaults) lives on
    the class itself; every method is a ``classmethod`` or ``staticmethod``,
    so no instance is required.
    """

    # Class-level cache for JIT-compiled ODE solvers.  Keyed on
    # (coeff-fn code objects, dim, atol, rtol, max_steps, throw, solver,
    # magnus_steps, x64 flag) so that all evolve() calls with the same
    # pulse shape function and matrix size share one compiled XLA
    # program.  This turns O(n_gates) JIT compilations into
    # O(n_distinct_pulse_shapes) during pulse-mode circuit building.
    _evolve_solver_cache: dict = {}
    _evolve_solver_cache_lock = threading.Lock()

    # Default solver knobs for parametrized (time-dependent) evolution.
    # These can be overridden per-call via the **odeint_kwargs of
    # ``evolve()`` or globally via :meth:`set_solver_defaults`.
    #
    # ``max_steps`` is the hard cap on accepted ODE steps.  Pulse-level
    # workloads at on-resonance carriers (ω_c ≈ ω_q) require many more
    # steps than the diffrax default during JIT — 2**13 = 8192 is
    # large enough for realistic single- and two-qubit pulses while
    # remaining cheap to compile.
    #
    # ``throw`` controls whether diffrax raises on solver failure
    # (e.g. ``MaxStepsReached``).  When set to ``False`` the gate
    # factory instead returns a NaN-filled unitary so the calling
    # optimiser sees a well-defined (but useless) result and can
    # gracefully reject the candidate.
    #
    # ``solver`` selects the time-integration backend for the ODE
    # ``dU/dt = -i H(t) U``:
    #
    # * ``"dopri8"`` (default) — adaptive Dormand-Prince 8(7) via
    #   diffrax.  Robust but expensive on highly oscillatory drives
    #   because the step controller resolves every fast cycle.
    # * ``"dopri5"`` — adaptive Dormand-Prince 5(4) via diffrax
    #   (``diffrax.Dopri5``), driven by the same PID step-size
    #   controller and tolerances as ``"dopri8"``.
    # * ``"magnus2"`` — commutator-free Magnus, 2nd order (midpoint
    #   rule) on a fixed ``magnus_steps`` grid via ``jax.lax.scan``.
    #   One ``expm`` per step.  Preserves unitarity to machine
    #   precision and fuses into a single XLA program.
    # * ``"magnus4"`` — commutator-free Magnus, 4th order (CFM4:2 of
    #   Blanes & Moan) on a fixed ``magnus_steps`` grid.  Two ``H``
    #   evaluations and two ``expm`` per step; typically the best
    #   accuracy/cost trade-off for smooth oscillatory pulse drives.
    #
    # ``magnus_steps`` is the number of fixed substeps for the Magnus
    # integrators (ignored for ``dopri8`` / ``dopri5``).  Choose it so
    # that ``h = T/N`` resolves the fastest oscillation in ``H(t)``
    # (~few steps per period of the highest frequency).
    _solver_defaults: dict = {
        "max_steps": 2**13,
        "throw": True,
        "solver": "dopri8",
        "magnus_steps": 256,
    }
    _valid_solvers = ("dopri8", "dopri5", "magnus2", "magnus4")

    @classmethod
    def set_solver_defaults(
        cls,
        max_steps: Optional[int] = None,
        throw: Optional[bool] = None,
        solver: Optional[str] = None,
        magnus_steps: Optional[int] = None,
    ) -> dict:
        """Update class-level solver defaults; return the previous values.

        The returned dictionary is suitable for restoring the previous
        defaults via ``set_solver_defaults(**prev)``.

        Args:
            max_steps: New default for ``max_steps`` (ignored if ``None``).
            throw: New default for ``throw`` (ignored if ``None``).
            solver: New default solver name; must be one of
                ``_valid_solvers`` (ignored if ``None``).
            magnus_steps: New default number of fixed Magnus substeps
                (ignored if ``None``).

        Returns:
            Dictionary with the previous values of the updated keys.

        Raises:
            ValueError: If ``solver`` is not a known solver name.  No
                default is modified in that case.
        """
        # Validate before mutating anything so that a rejected call cannot
        # leave the defaults half-updated.
        if solver is not None and solver not in cls._valid_solvers:
            raise ValueError(
                f"Unknown solver {solver!r}; expected one of {cls._valid_solvers}"
            )

        updates: dict = {}
        if max_steps is not None:
            updates["max_steps"] = int(max_steps)
        if throw is not None:
            updates["throw"] = bool(throw)
        if solver is not None:
            updates["solver"] = solver
        if magnus_steps is not None:
            updates["magnus_steps"] = int(magnus_steps)

        prev = {key: cls._solver_defaults[key] for key in updates}
        cls._solver_defaults.update(updates)
        return prev

    @classmethod
    def _store_evolve_solver(cls, cache_key: tuple, solve: Callable) -> Callable:
        """Cache a compiled evolve solver unless another thread won the race.

        Args:
            cache_key: Hashable key identifying the solver configuration.
            solve: Freshly built solver callable.

        Returns:
            The solver now stored under ``cache_key``: ``solve`` if the slot
            was empty, otherwise the entry that was already present.
        """
        with cls._evolve_solver_cache_lock:
            existing = cls._evolve_solver_cache.get(cache_key)
            if existing is not None:
                return existing
            cls._evolve_solver_cache[cache_key] = solve
            return solve

    @classmethod
    def clear_evolve_solver_cache(cls) -> None:
        """Drop every cached compiled evolve solver.

        Call this whenever the coefficient functions referenced by the
        cache keys are rebuilt (e.g. when :class:`PulseGates` swaps in
        a new pulse envelope, RWA flag or frame).  Without an explicit
        eviction the cache keeps the old code objects alive and would
        also retain XLA programs that no longer match any active
        coefficient function.
        """
        with cls._evolve_solver_cache_lock:
            cls._evolve_solver_cache.clear()

    @staticmethod
    def _coeff_fn_cache_token(fn: Callable) -> Any:
        """Return the hashable token representing ``fn`` in the cache key.

        Plain Python functions are represented by their code object, so
        functions sharing the same code share one compiled solver.
        Callables without a ``__code__`` attribute (e.g.
        ``functools.partial`` objects or callable instances) are
        represented by the callable itself.
        """
        return getattr(fn, "__code__", fn)

    @classmethod
    def _parse_evolve_solver_options(cls, odeint_kwargs: dict) -> tuple:
        """Pop and validate solver options from ``evolve(..., **odeint_kwargs)``.

        Recognised keys are removed from ``odeint_kwargs`` in place; keys
        that are absent fall back to the precision-dependent tolerance
        default or to ``_solver_defaults``.

        Args:
            odeint_kwargs: Mutable mapping of user-supplied keyword options.

        Returns:
            Tuple ``(atol, rtol, max_steps, throw, solver_name, magnus_steps)``.

        Raises:
            ValueError: If the solver name is unknown, or if a Magnus
                solver is selected with ``magnus_steps < 1``.
        """
        defaults = cls._solver_defaults
        default_tol = 1.0e-10 if jax.config.x64_enabled else 1.4e-8

        # Tolerances become part of the solver cache key, so coerce them to
        # plain (hashable) Python floats.
        atol = float(odeint_kwargs.pop("atol", default_tol))
        rtol = float(odeint_kwargs.pop("rtol", default_tol))
        max_steps = int(odeint_kwargs.pop("max_steps", defaults["max_steps"]))
        throw = bool(odeint_kwargs.pop("throw", defaults["throw"]))

        solver_name = str(odeint_kwargs.pop("solver", defaults["solver"]))
        if solver_name not in cls._valid_solvers:
            raise ValueError(
                f"Unknown solver {solver_name!r}; expected one of {cls._valid_solvers}"
            )

        magnus_steps = int(
            odeint_kwargs.pop("magnus_steps", defaults["magnus_steps"])
        )
        # An empty step grid would make the fixed-step scan a no-op and
        # silently return the identity instead of the evolved unitary.
        if solver_name in ("magnus2", "magnus4") and magnus_steps < 1:
            raise ValueError(
                f"magnus_steps must be a positive integer, got {magnus_steps}"
            )
        return atol, rtol, max_steps, throw, solver_name, magnus_steps

    @classmethod
    def _build_magnus_evolve_solver(
        cls,
        cache_key: tuple,
        coeff_fns: Tuple[Callable, ...],
        n_terms: int,
        dim: int,
        solver_name: str,
        magnus_steps: int,
    ) -> Callable:
        """Build and cache a fixed-step commutator-free Magnus solver.

        Args:
            cache_key: Key under which the solver is stored in the cache.
            coeff_fns: One coefficient function ``f_i(params_i, t)`` per term.
            n_terms: Number of Hamiltonian terms.
            dim: Dimension of the (square) Hamiltonian matrices.
            solver_name: ``"magnus2"`` selects the midpoint rule; any other
                value selects the 4th-order CFM4:2 scheme.
            magnus_steps: Number of fixed substeps on ``[t0, t1]``.

        Returns:
            The cached solver ``_solve(neg_iH_split, params, t0, t1) -> U``.
        """
        # The complex working dtype is fixed at build time from the current
        # precision mode.
        cdtype = jnp.complex128 if jax.config.x64_enabled else jnp.complex64

        @eqx.filter_jit
        def _solve(neg_iH_split, params, t0, t1):
            # Reconstruct the per-term complex matrices ``-i H_i`` from their
            # split (Re, Im) representation so the coefficient sum is a single
            # complex tensordot.
            neg_iH = (neg_iH_split[:, 0] + 1j * neg_iH_split[:, 1]).astype(cdtype)

            h = (t1 - t0) / magnus_steps

            def generator_at(t):
                # Returns ``-i H(t) = sum_i f_i(p_i, t) * (-i H_i)``.  Each
                # coefficient is coerced to a true scalar before stacking.
                c = jnp.stack(
                    [
                        jnp.asarray(coeff_fns[i](params[i], t)).reshape(())
                        for i in range(n_terms)
                    ]
                ).astype(cdtype)
                return jnp.tensordot(c, neg_iH, axes=1)

            if solver_name == "magnus2":

                def step(U, n):
                    # Midpoint rule: a single exponential per step.
                    tn = t0 + n * h
                    Omega = h * generator_at(tn + 0.5 * h)
                    return jax.scipy.linalg.expm(Omega) @ U, None

            else:
                # Two-point Gauss-Legendre nodes and CFM4:2 weights.
                sqrt3 = math.sqrt(3.0)
                c1 = 0.5 - sqrt3 / 6.0
                c2 = 0.5 + sqrt3 / 6.0
                a1 = 0.25 + sqrt3 / 6.0
                a2 = 0.25 - sqrt3 / 6.0

                def step(U, n):
                    tn = t0 + n * h
                    H1 = generator_at(tn + c1 * h)
                    H2 = generator_at(tn + c2 * h)
                    Omega_a = h * (a1 * H1 + a2 * H2)
                    Omega_b = h * (a2 * H1 + a1 * H2)
                    # CFM4:2 ordering (Blanes & Moan 2006, Table II):
                    #   U_{n+1} = exp(Ω_b) · exp(Ω_a) · U_n.
                    U_next = (
                        jax.scipy.linalg.expm(Omega_b)
                        @ jax.scipy.linalg.expm(Omega_a)
                        @ U
                    )
                    return U_next, None

            U0 = jnp.eye(dim, dtype=cdtype)
            U_final, _ = jax.lax.scan(step, U0, jnp.arange(magnus_steps))
            return U_final

        return cls._store_evolve_solver(cache_key, _solve)

    @classmethod
    def _build_diffrax_evolve_solver(
        cls,
        cache_key: tuple,
        coeff_fns: Tuple[Callable, ...],
        n_terms: int,
        dim: int,
        atol: float,
        rtol: float,
        max_steps: int,
        throw: bool,
        solver_name: str,
        _rdtype,
    ) -> Callable:
        """Build and cache an adaptive diffrax-based evolve solver.

        Args:
            cache_key: Key under which the solver is stored in the cache.
            coeff_fns: One coefficient function ``f_i(params_i, t)`` per term.
            n_terms: Number of Hamiltonian terms.
            dim: Dimension of the (square) Hamiltonian matrices.
            atol: Absolute tolerance of the PID step-size controller.
            rtol: Relative tolerance of the PID step-size controller.
            max_steps: Step budget forwarded to ``diffrax.diffeqsolve``.
            throw: Forwarded to ``diffrax.diffeqsolve``; when ``False`` an
                unsuccessful solve yields a NaN-filled matrix.
            solver_name: ``"dopri8"`` selects ``diffrax.Dopri8``; any other
                value selects ``diffrax.Dopri5``.
            _rdtype: Real dtype of the initial state.

        Returns:
            The cached solver ``_solve(neg_iH_split, params, t0, t1) -> U``.
        """
        solver = diffrax.Dopri8() if solver_name == "dopri8" else diffrax.Dopri5()
        stepsize_controller = diffrax.PIDController(atol=atol, rtol=rtol)

        @eqx.filter_jit
        def _solve(neg_iH_split, params, t0, t1):
            """Solve dU/dt = sum_i f_i(p_i, t) * (-iH_i) * U from t0 to t1.

            ``neg_iH_split`` has shape ``(n_terms, 2, dim, dim)`` with
            ``[:, 0]`` = Re(-iH_i) and ``[:, 1]`` = Im(-iH_i).
            ``params`` is a list/tuple of length ``n_terms`` carrying
            each term's coefficient parameters.  The state ``y`` has
            shape ``(2, dim, dim)`` with ``y[0] = Re(U)`` and
            ``y[1] = Im(U)``.
            """
            A_all = neg_iH_split[:, 0]
            B_all = neg_iH_split[:, 1]

            def rhs(t, y, args):
                # Each coefficient function must return a scalar value; some
                # call sites pass a shape-(1,) param array, so coerce to a
                # true scalar before stacking.
                c = jnp.stack(
                    [
                        jnp.asarray(coeff_fns[i](args[i], t)).reshape(())
                        for i in range(n_terms)
                    ]
                )
                u_re = y[0]
                u_im = y[1]
                A_eff = jnp.tensordot(c, A_all, axes=1)
                B_eff = jnp.tensordot(c, B_all, axes=1)
                # (A + iB)(U_re + iU_im), split into real and imaginary parts.
                du_re = A_eff @ u_re - B_eff @ u_im
                du_im = A_eff @ u_im + B_eff @ u_re
                return jnp.stack([du_re, du_im], axis=0)

            # U(t0) = I, stored as (Re, Im).
            y0 = jnp.stack(
                [
                    jnp.eye(dim, dtype=_rdtype),
                    jnp.zeros((dim, dim), dtype=_rdtype),
                ],
                axis=0,
            )

            sol = diffrax.diffeqsolve(
                diffrax.ODETerm(rhs),
                solver,
                t0=t0,
                t1=t1,
                dt0=None,
                y0=y0,
                args=params,
                stepsize_controller=stepsize_controller,
                max_steps=max_steps,
                throw=throw,
            )

            y_final = sol.ys[0]
            U = y_final[0] + 1j * y_final[1]

            if not throw:
                # Signal a failed integration with NaNs instead of raising.
                successful = sol.result == diffrax.RESULTS.successful
                U = jnp.where(successful, U, jnp.full_like(U, jnp.nan))
            return U

        return cls._store_evolve_solver(cache_key, _solve)

    @classmethod
    def evolve(
        cls,
        hamiltonian: Union["Hermitian", "ParametrizedHamiltonian"],
        name: Optional[str] = None,
        **odeint_kwargs: Any,
    ) -> Callable:
        """Return a gate-factory for Hamiltonian time evolution.

        Engine for the :meth:`Hermitian.evolve` / :meth:`ParametrizedHamiltonian.evolve`
        methods (the usual entry point); it dispatches on the Hamiltonian type.

        Supports two modes:

        Static — when *hamiltonian* is a :class:`Hermitian`::

            gate = Hermitian(H_mat, wires=0).evolve()
            gate(t=0.5)   # U = exp(-i*0.5*H)

        Time-dependent — when *hamiltonian* is a
        :class:`ParametrizedHamiltonian` (created via ``coeff_fn * Hermitian``)::

            H_td = coeff_fn * Hermitian(H_mat, wires=0)
            gate = H_td.evolve()
            gate([A, sigma], T)   # U via ODE: dU/dt = -i f(p,t) H * U

        The time-dependent case solves the Schrödinger equation numerically,
        by default with ``diffrax.diffeqsolve`` and a Dopri8 adaptive
        Runge-Kutta solver; ``solver=`` selects ``"dopri5"`` or one of the
        fixed-step Magnus integrators instead.

        Args:
            hamiltonian: Either a :class:`Hermitian` (static evolution) or a
                :class:`ParametrizedHamiltonian` (time-dependent evolution).
            name: Optional name forwarded to the :class:`Operation` objects
                produced by the returned factory.
            **odeint_kwargs: Extra keyword arguments, used only in the
                time-dependent mode.  Recognised keys:

                - ``atol``, ``rtol`` — absolute/relative tolerances for the
                  adaptive step-size controller (default ``1.4e-8`` in fp32
                  mode, ``1.0e-10`` in fp64 mode).
                - ``max_steps``, ``throw``, ``solver``, ``magnus_steps`` —
                  per-call overrides of the class-level solver defaults.

        Returns:
            A callable gate factory.  Signature depends on the mode:

            - Static: ``(t, wires=0) -> Operation``
            - Time-dependent: ``(coeff_args, T) -> Operation``

        Raises:
            TypeError: If *hamiltonian* is neither ``Hermitian`` nor
                ``ParametrizedHamiltonian``.
        """
        if isinstance(hamiltonian, Hermitian):
            return cls._evolve_static(hamiltonian, name=name)
        if isinstance(hamiltonian, ParametrizedHamiltonian):
            return cls._evolve_parametrized(hamiltonian, name=name, **odeint_kwargs)
        raise TypeError(
            f"evolve() expects a Hermitian or ParametrizedHamiltonian, "
            f"got {type(hamiltonian)}"
        )

    @staticmethod
    def _evolve_static(hermitian: "Hermitian", name: Optional[str] = None) -> Callable:
        """Gate factory for static Hamiltonian evolution U = exp(-i t H).

        Args:
            hermitian: Operator whose ``matrix`` attribute is the Hamiltonian.
            name: Optional name forwarded to the produced :class:`Operation`.

        Returns:
            Callable ``(t, wires=0) -> Operation``.
        """
        H_mat = hermitian.matrix

        def _apply(t: float, wires: Union[int, List[int]] = 0) -> "Operation":
            U = jax.scipy.linalg.expm(-1j * t * H_mat)
            return Operation(wires=wires, matrix=U, name=name)

        return _apply

    @classmethod
    def _evolve_parametrized(
        cls,
        ph: "ParametrizedHamiltonian",
        name: Optional[str] = None,
        **odeint_kwargs: Any,
    ) -> Callable:
        """Gate factory for time-dependent (multi-term) Hamiltonian evolution.

        Solves the matrix ODE

            dU/dt = -i [sum_i f_i(params_i, t) * H_i] * U,    U(t0) = I

        by default with ``diffrax.diffeqsolve`` (Dopri8 adaptive RK).  The
        Hamiltonian may contain one or more ``coeff_fn * Hermitian`` terms
        (see :class:`ParametrizedHamiltonian`); the single-term case is the
        usual ``coeff_fn * Hermitian`` and is fully backward compatible.

        Implementation notes:

        - To avoid diffrax's experimental complex dtype path, the diffrax
          solvers integrate the ODE in real arithmetic.  Writing
          ``-iH_i = A_i + i B_i`` and ``U = U_re + i U_im``, each term
          contributes::

              d(U_re)/dt += f_i(p_i,t) * (A_i @ U_re - B_i @ U_im)
              d(U_im)/dt += f_i(p_i,t) * (A_i @ U_im + B_i @ U_re)

        - ``-iH_i`` is precomputed once per term and stacked into a
          ``(n_terms, 2, dim, dim)`` real array, contracted via
          ``tensordot`` against the per-step coefficient vector
          ``c = [f_0(p_0,t), ..., f_{n-1}(p_{n-1},t)]``.

        - The JIT-compiled solver is cached per coefficient-function code
          tuple (and ``dim``, tolerances, solver options, precision mode)
          so multiple ``evolve()`` calls with the same pulse shape — but
          different Hamiltonian matrices or parameters — reuse the same
          compiled XLA program.

        TODO: switch back once diffrax is stable with complex arithmetic.

        Args:
            ph: A :class:`ParametrizedHamiltonian` (one or more terms).
            name: Optional name forwarded to the produced :class:`Operation`.
            **odeint_kwargs: Solver options.  Recognised keys:

                - ``atol``, ``rtol`` — absolute/relative tolerances for the
                  step-size controller (default ``1.4e-8`` in fp32 mode,
                  ``1.0e-10`` in fp64 mode).
                - ``max_steps`` — hard cap on accepted ODE steps
                  (default :attr:`cls._solver_defaults['max_steps']`,
                  currently ``2**13``).  Increase this if the integrator
                  raises ``MaxStepsReached`` for a stiff/oscillatory
                  pulse Hamiltonian.
                - ``throw`` — whether to raise on solver failure
                  (default :attr:`cls._solver_defaults['throw']`,
                  currently ``True``).  When ``False``, a failed
                  integration returns a NaN-filled unitary instead of
                  raising; this is the recommended setting for inner
                  loops of an optimiser (e.g. QOC Stage 0) so a single
                  pathological candidate cannot abort the whole run.
                - ``solver`` — one of ``"dopri8"``, ``"dopri5"``,
                  ``"magnus2"``, ``"magnus4"`` (default
                  :attr:`cls._solver_defaults['solver']`).
                - ``magnus_steps`` — number of fixed substeps for the
                  Magnus solvers (default
                  :attr:`cls._solver_defaults['magnus_steps']`).

        Returns:
            Callable ``(coeff_args, T) -> Operation``.
        """
        coeff_fns = ph.coeff_fns  # tuple of callables
        H_mats = ph.H_mats  # tuple of (dim, dim)
        wires = ph.wires
        n_terms = ph.n_terms
        dim = H_mats[0].shape[0]

        # Pre-compute -i*H_i for each term and split into real / imaginary
        # parts so the ODE RHS uses only real arithmetic.  Final shape:
        # (n_terms, 2, dim, dim).
        neg_iH_split = jnp.stack(
            [
                jnp.stack([jnp.real(-1j * H_mat), jnp.imag(-1j * H_mat)], axis=0)
                for H_mat in H_mats
            ],
            axis=0,
        )

        # Real dtype matching the current precision mode.
        x64_enabled = bool(jax.config.x64_enabled)
        _rdtype = jnp.float64 if x64_enabled else jnp.float32

        # Tolerances default according to precision; the remaining knobs
        # default to the class-level solver defaults.
        atol, rtol, max_steps, throw, solver_name, magnus_steps = (
            cls._parse_evolve_solver_options(odeint_kwargs)
        )

        # Cache key: every coeff fn's code object (same shape of pulse
        # fns -> same JIT program) plus dim, tolerances, and solver
        # budget / throw flag (different budgets mean different XLA
        # programs).  We use the code object itself (hashable, identity-
        # equal) rather than ``id(fn.__code__)``: ids can be reused for
        # later code objects after the original is garbage-collected,
        # which would silently return a stale compiled solver for a
        # different pulse shape.  Holding the code object in the cache
        # keeps it alive for as long as the cached program is valid.
        #
        # The x64 flag is part of the key because both solver builders
        # bake the precision-dependent dtype into the compiled closure; a
        # solver built in one precision mode must not be reused in the other.
        #
        # NOTE: the key identifies a function by its code object only, while
        # the cached solver closes over the first function objects it was
        # built with.  Closures sharing code but capturing different values
        # therefore share one solver; call ``clear_evolve_solver_cache``
        # when such coefficient functions are rebuilt.
        cache_key = (
            tuple(cls._coeff_fn_cache_token(fn) for fn in coeff_fns),
            dim,
            atol,
            rtol,
            max_steps,
            throw,
            solver_name,
            magnus_steps,
            x64_enabled,
        )

        with cls._evolve_solver_cache_lock:
            _solve = cls._evolve_solver_cache.get(cache_key)
        if _solve is None:
            # Built outside the lock; ``_store_evolve_solver`` resolves the
            # race if another thread built the same solver meanwhile.
            if solver_name in ("magnus2", "magnus4"):
                _solve = cls._build_magnus_evolve_solver(
                    cache_key=cache_key,
                    coeff_fns=coeff_fns,
                    n_terms=n_terms,
                    dim=dim,
                    solver_name=solver_name,
                    magnus_steps=magnus_steps,
                )
            else:
                _solve = cls._build_diffrax_evolve_solver(
                    cache_key=cache_key,
                    coeff_fns=coeff_fns,
                    n_terms=n_terms,
                    dim=dim,
                    atol=atol,
                    rtol=rtol,
                    max_steps=max_steps,
                    throw=throw,
                    solver_name=solver_name,
                    _rdtype=_rdtype,
                )

        def _apply(coeff_args, T) -> "Operation":
            """Evolve under the (multi-term) time-dependent Hamiltonian.

            Args:
                coeff_args: List/tuple of parameter sets, one per term.
                    For single-term Hamiltonians the legacy form
                    ``[params]`` works unchanged; ``params`` is forwarded
                    to the sole coefficient function.
                T: Total evolution time.  Scalar -> integrate on
                    ``[0, T]``; 2-element -> integrate on ``[T[0], T[1]]``.

            Returns:
                An :class:`Operation` wrapping the computed unitary.

            Raises:
                ValueError: If the number of parameter sets differs from
                    the number of Hamiltonian terms.
            """
            # Normalise to a tuple of length n_terms.  Accept a bare
            # single-term arg for backward compat.
            if isinstance(coeff_args, (list, tuple)):
                params = tuple(coeff_args)
            else:
                params = (coeff_args,)

            if len(params) != n_terms:
                raise ValueError(
                    f"Expected {n_terms} parameter set(s) for a "
                    f"{n_terms}-term ParametrizedHamiltonian, "
                    f"got {len(params)}."
                )

            # Build time span — resolve at Python level to avoid traced
            # branching.  ``T`` is either a Python scalar / 0-d array (=>
            # integrate on [0, T]) or a 2-element sequence/array (=> integrate
            # on [T[0], T[1]]); only the rank is inspected here.
            T_arr = jnp.asarray(T, dtype=_rdtype)
            if T_arr.ndim == 0:
                t0 = _rdtype(0.0)
                t1 = T_arr
            else:
                t0 = T_arr[0]
                t1 = T_arr[1]

            U = _solve(neg_iH_split, params, t0, t1)

            return Operation(wires=wires, matrix=U, name=name)

        return _apply

568 return Operation(wires=wires, matrix=U, name=name) 

569 

570 return _apply