Coverage for qml_essentials/entanglement.py: 88%

67 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-03-07 14:54 +0000

1from typing import Optional, Any 

2import pennylane as qml 

3import pennylane.numpy as np 

4from copy import deepcopy 

5 

6from qml_essentials.model import Model 

7import logging 

8 

# Module-level logger used by the Entanglement measures in this module.
log: logging.Logger = logging.getLogger(__name__)

10 

11 

class Entanglement:
    """Entangling-capability measures for models of quantum circuits.

    ``meyer_wallach`` evaluates the Meyer-Wallach measure from the reduced
    single-qubit density matrices of the circuit. ``bell_measurements``
    estimates the same measure from Bell-basis measurements on two copies
    of the circuit. Both average the measure over parameter samples and clip
    the result to [0, 1].
    """

    @staticmethod
    def meyer_wallach(
        model: Model,
        n_samples: Optional[int],
        seed: Optional[int],
        **kwargs: Any,
    ) -> float:
        """
        Calculates the entangling capacity of a given quantum circuit
        using the Meyer-Wallach measure.

        Args:
            model (Model): Model of the quantum circuit. If ``n_samples`` is
                positive, its parameters are re-initialized in place.
            n_samples (Optional[int]): Number of random parameter sets to
                draw via ``model.initialize_params(repeat=n_samples)``.
                If None or <= 0, the current parameters of the model are
                used. If ``model.params`` then has more than two dimensions,
                its last axis is treated as the sample axis.
            seed (Optional[int]): Seed for the random number generator.
                Required if ``n_samples`` is positive, ignored otherwise.
            kwargs (Any): Additional keyword arguments for the model call.
                ``inputs`` defaults to None.

        Returns:
            float: Entangling capacity of the given circuit, averaged over
            all parameter samples and clipped to [0.0, 1.0].

        Raises:
            AssertionError: If ``n_samples`` is positive and ``seed`` is None.
        """
        params = Entanglement._prepare_params(model, n_samples, seed)
        n_sets = params.shape[-1]
        n_qubits = model.n_qubits
        qubits = list(range(n_qubits))

        # implicitly set input to none in case it's not needed
        kwargs.setdefault("inputs", None)

        mw_measure = np.zeros(n_sets)
        # TODO: vectorize in future iterations
        for i in range(n_sets):
            # explicitly set execution type because everything else won't work
            rho = model(params=params[:, :, i], execution_type="density", **kwargs)

            # Formula 6 in https://doi.org/10.48550/arXiv.quant-ph/0305094:
            # Q = 2 * (1 - 1/n * sum_k Tr(rho_k^2)), with rho_k the reduced
            # state of qubit k.
            purity = 0
            for k in range(n_qubits):
                # trace out every qubit except k
                rho_k = qml.math.partial_trace(rho, qubits[:k] + qubits[k + 1 :])
                # Tr(rho_k^2) is real, so only the real part is kept.
                # For a single qubit, 1/2 <= Tr(rho_k^2) <= 1.
                purity += np.trace((rho_k @ rho_k).real)

            # invert the averaged purity and scale to [0, 1]
            mw_measure[i] = 2 * (1 - purity / n_qubits)

        return Entanglement._average_measure(mw_measure)

    @staticmethod
    def bell_measurements(
        model: Model,
        n_samples: Optional[int],
        seed: Optional[int],
        **kwargs: Any,
    ) -> float:
        """
        Calculates the entangling capacity of a given quantum circuit from
        Bell measurements on two copies of the circuit.

        Qubit q of the first copy is paired with qubit q + n_qubits of the
        second copy. After CNOT + H on such a pair, the probability of the
        outcome |11> is (1 - Tr(rho_q^2)) / 2. The per-qubit purities, and
        from them the Meyer-Wallach measure, are therefore estimated from
        the measured probabilities.

        Args:
            model (Model): Model of the quantum circuit. ``model.circuit`` is
                temporarily replaced by the doubled Bell circuit and restored
                afterwards. If ``n_samples`` is positive, the model's
                parameters are re-initialized in place.
            n_samples (Optional[int]): Number of random parameter sets to
                draw via ``model.initialize_params(repeat=n_samples)``.
                If None or <= 0, the current parameters of the model are
                used. If ``model.params`` then has more than two dimensions,
                its last axis is treated as the sample axis.
            seed (Optional[int]): Seed for the random number generator.
                Required if ``n_samples`` is positive, ignored otherwise.
            kwargs (Any): Additional keyword arguments for the model call.
                ``inputs`` defaults to None.

        Returns:
            float: Entangling capacity of the given circuit, averaged over
            all parameter samples and clipped to [0.0, 1.0].

        Raises:
            AssertionError: If ``n_samples`` is positive and ``seed`` is None.
        """
        n_qubits = model.n_qubits

        def _circuit(params, inputs):
            # First copy of the variational circuit, then a second copy
            # shifted onto wires n..2n-1.
            # NOTE(review): assumes model._variational acts on wires
            # 0..n_qubits-1 — confirm against Model.
            model._variational(params, inputs)
            qml.map_wires(
                model._variational,
                {w: w + n_qubits for w in range(n_qubits)},
            )(params, inputs)

            # rotate each pair (q, q + n) into the computational basis so
            # that the singlet component is read out as |11>
            for q in range(n_qubits):
                qml.CNOT(wires=[q, q + n_qubits])
                qml.H(q)

            return [qml.probs(wires=(q, q + n_qubits)) for q in range(n_qubits)]

        # Swap in the doubled circuit only for the duration of this call so
        # that the caller's model is left unchanged afterwards.
        original_circuit = model.circuit
        model.circuit = qml.QNode(
            _circuit,
            qml.device(
                "default.qubit",
                shots=model.shots,
                wires=n_qubits * 2,
            ),
        )
        try:
            params = Entanglement._prepare_params(model, n_samples, seed)
            n_sets = params.shape[-1]

            # implicitly set input to none in case it's not needed
            kwargs.setdefault("inputs", None)

            mw_measure = np.zeros(n_sets)
            for i in range(n_sets):
                # NOTE(review): assumes the model stacks the per-pair
                # probabilities into an array of shape (n_qubits, 4).
                probs = model(params=params[:, :, i], **kwargs)

                # P(|11>) = (1 - Tr(rho_q^2)) / 2  =>  Tr(rho_q^2) = 1 - 2 P(|11>)
                purity = 1 - 2 * probs[:, -1]
                mw_measure[i] = 2 * (1 - purity.mean())
        finally:
            model.circuit = original_circuit

        return Entanglement._average_measure(mw_measure)

    @staticmethod
    def _prepare_params(
        model: Model,
        n_samples: Optional[int],
        seed: Optional[int],
    ) -> Any:
        """Return the parameters to evaluate, with samples on the last axis.

        If ``n_samples`` is positive, ``model`` is re-initialized with
        ``n_samples`` parameter sets drawn from an RNG seeded with ``seed``,
        and ``model.params`` is returned. Otherwise the current parameters
        are used. A parameter tensor with at most two dimensions gets a
        trailing sample axis of length 1.

        Raises:
            AssertionError: If ``n_samples`` is positive and ``seed`` is None.
        """
        if n_samples is not None and n_samples > 0:
            assert seed is not None, "Seed must be provided when samples > 0"
            # TODO: maybe switch to JAX rng
            rng = np.random.default_rng(seed)
            model.initialize_params(rng=rng, repeat=n_samples)
            return model.params

        if seed is not None:
            log.warning("Seed is ignored when samples is 0")

        if len(model.params.shape) <= 2:
            # a single parameter set: add a sample axis of length 1
            return model.params.reshape(*model.params.shape, 1)

        log.info(f"Using sample size of model params: {model.params.shape[-1]}")
        return model.params

    @staticmethod
    def _average_measure(measure: Any) -> float:
        """Return the mean of the per-sample measures, clipped to [0, 1]."""
        log.debug(f"Variance of measure: {measure.var()}")
        # clip to catch floating point errors slightly outside [0, 1]
        return float(min(max(measure.mean(), 0.0), 1.0))