Coverage for qml_essentials/topologies.py: 98%

47 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2026-02-20 14:03 +0000

1from typing import List, Callable, Union 

2import jax 

3import logging 

4 

# Module-level logger; ``Topology.stairs`` reports skipped self-pairs through it.
log = logging.getLogger(name=__name__)

# Import-time side effect: switch JAX to 64-bit (double) precision.
# This is a global JAX configuration flag, so it affects all JAX code
# running in the same process, not just this module.
jax.config.update("jax_enable_x64", True)

7 

8 

9class Topology: 

10 """ 

11 Generates [control, target] wire-pair lists for two-qubit gates. 

12 

13 All public methods are static and share a small set of private 

14 helpers so that related topologies (e.g. ``linear`` / ``circular``, 

15 ``brick_layer`` / ``brick_layer_wrap``) re-use the same core logic. 

16 

17 Raises 

18 ------ 

19 ValueError 

20 If ``n_qubits < 2`` is passed to any topology method. 

21 """ 

22 

23 @classmethod 

24 def stairs( 

25 cls, 

26 n_qubits: int, 

27 offset: Union[int, Callable] = 0, 

28 wrap=False, 

29 reverse: bool = True, 

30 mirror: bool = True, 

31 span: Union[int, Callable] = 1, 

32 stride: int = 1, 

33 modulo: bool = True, 

34 ) -> List[List[int]]: 

35 """ 

36 Unified generator for nearest-neighbour and spand pair topologies. 

37 Produces ``[control, target]`` pairs of qubits. 

38 

39 The default values, produce an "upstairs" entangling sequence 

40 without wrapping around the last gate. 

41 

42 Parameters 

43 ---------- 

44 n_qubits : int 

45 Number of qubits. 

46 offset : Union[int, Callable] 

47 Offset for starting the entangling sequence. 

48 Can either be a integer or a callable that takes n_qubits as input. 

49 wrap : bool 

50 Wraps around the entangling gates. 

51 reverse : bool 

52 Reverses both the iteration direction (upstairs/ downstairs) 

53 mirror: bool 

54 Flip target/ control qubit 

55 span : int 

56 Offset between control and target qubit. Defaults to 1 

57 stride : int 

58 Step size for entangling gates. Defaults to 1, meaning a stair 

59 pattern will be generated. 

60 modulo : bool 

61 If a gate should be placed when the iterator decreases below 0 

62 or exceeds n_qubits. Defaults to True 

63 

64 Returns 

65 ------- 

66 List[List[int]] 

67 """ 

68 ctrls = [] 

69 targets = [] 

70 

71 n_gates = n_qubits if wrap else n_qubits - 1 

72 _offset = offset(n_qubits) if callable(offset) else offset 

73 _span = span(n_qubits) if callable(span) else span 

74 

75 for q in range(0, n_gates, stride): 

76 _target = q + _offset + _span 

77 if _target >= n_qubits and not modulo: 

78 continue 

79 _control = q + _offset 

80 if _control < 0 and not modulo: 

81 continue 

82 

83 _target = _target % n_qubits 

84 _control = _control % n_qubits 

85 

86 if _target == _control: 

87 log.warning("Skipping gate where control == target") 

88 continue 

89 

90 targets += [_target] 

91 ctrls += [_control] 

92 

93 if reverse: 

94 ctrls = reversed(ctrls) 

95 targets = reversed(targets) 

96 

97 if mirror: 

98 ctrls, targets = targets, ctrls 

99 

100 pairs = list(zip(ctrls, targets, strict=True)) 

101 

102 return pairs 

103 

104 @classmethod 

105 def bricks(cls, n_qubits: int, **kwargs) -> List[List[int]]: 

106 kwargs.setdefault("stride", 2) 

107 kwargs.setdefault("modulo", False) 

108 return cls.stairs(n_qubits=n_qubits, **kwargs) 

109 

110 @classmethod 

111 def all_to_all(cls, n_qubits: int) -> List[List[int]]: 

112 """Every ordered pair ``(i, j)`` with ``i ≠ j``.""" 

113 pairs: List[List[int]] = [] 

114 for ql in range(n_qubits): 

115 for q in range(n_qubits): 

116 if q != ql: 

117 pairs.append( 

118 [ 

119 n_qubits - ql - 1, 

120 (n_qubits - q - 1) % n_qubits, 

121 ] 

122 ) 

123 return pairs