3
3
import pickle
4
4
from collections import defaultdict
5
5
from dataclasses import dataclass
6
+ from functools import lru_cache
7
+ from itertools import batched
6
8
from pathlib import Path
7
9
from typing import Self
8
10
@@ -33,6 +35,10 @@ def add_duty(self, included: bool) -> None:
33
35
self .included += 1 if included else 0
34
36
35
37
38
+ type Frame = tuple [EpochNumber , EpochNumber ]
39
+ type StateData = dict [Frame , defaultdict [ValidatorIndex , AttestationsAccumulator ]]
40
+
41
+
36
42
class State :
37
43
"""
38
44
Processing state of a CSM performance oracle frame.
@@ -43,16 +49,17 @@ class State:
43
49
44
50
The state can be migrated to be used for another frame's report by calling the `migrate` method.
45
51
"""
46
-
47
- data : defaultdict [ ValidatorIndex , AttestationsAccumulator ]
52
+ frames : list [ Frame ]
53
+ data : StateData
48
54
49
55
_epochs_to_process : tuple [EpochNumber , ...]
50
56
_processed_epochs : set [EpochNumber ]
51
57
52
58
_consensus_version : int = 1
53
59
54
- def __init__ (self , data : dict [ValidatorIndex , AttestationsAccumulator ] | None = None ) -> None :
55
- self .data = defaultdict (AttestationsAccumulator , data or {})
60
+ def __init__ (self ) -> None :
61
+ self .frames = []
62
+ self .data = {}
56
63
self ._epochs_to_process = tuple ()
57
64
self ._processed_epochs = set ()
58
65
@@ -89,22 +96,55 @@ def file(cls) -> Path:
89
96
def buffer (self ) -> Path :
90
97
return self .file ().with_suffix (".buf" )
91
98
99
+ @property
100
+ def is_empty (self ) -> bool :
101
+ return not self .data and not self ._epochs_to_process and not self ._processed_epochs
102
+
103
+ @property
104
+ def unprocessed_epochs (self ) -> set [EpochNumber ]:
105
+ if not self ._epochs_to_process :
106
+ raise ValueError ("Epochs to process are not set" )
107
+ diff = set (self ._epochs_to_process ) - self ._processed_epochs
108
+ return diff
109
+
110
+ @property
111
+ def is_fulfilled (self ) -> bool :
112
+ return not self .unprocessed_epochs
113
+
114
+ @staticmethod
115
+ def _calculate_frames (epochs_to_process : tuple [EpochNumber , ...], epochs_per_frame : int ) -> list [Frame ]:
116
+ """Split epochs to process into frames of `epochs_per_frame` length"""
117
+ if len (epochs_to_process ) % epochs_per_frame != 0 :
118
+ raise ValueError ("Insufficient epochs to form a frame" )
119
+ return [(frame [0 ], frame [- 1 ]) for frame in batched (sorted (epochs_to_process ), epochs_per_frame )]
120
+
92
121
def clear (self ) -> None :
93
- self .data = defaultdict ( AttestationsAccumulator )
122
+ self .data = {}
94
123
self ._epochs_to_process = tuple ()
95
124
self ._processed_epochs .clear ()
96
125
assert self .is_empty
97
126
98
- def inc (self , key : ValidatorIndex , included : bool ) -> None :
99
- self .data [key ].add_duty (included )
127
+ @lru_cache (variables .CSM_ORACLE_MAX_CONCURRENCY )
128
+ def find_frame (self , epoch : EpochNumber ) -> Frame :
129
+ for epoch_range in self .frames :
130
+ from_epoch , to_epoch = epoch_range
131
+ if from_epoch <= epoch <= to_epoch :
132
+ return epoch_range
133
+ raise ValueError (f"Epoch { epoch } is out of frames range: { self .frames } " )
134
+
135
+ def increment_duty (self , epoch : EpochNumber , val_index : ValidatorIndex , included : bool ) -> None :
136
+ frame = self .find_frame (epoch )
137
+ self .data [frame ][val_index ].add_duty (included )
100
138
101
139
def add_processed_epoch (self , epoch : EpochNumber ) -> None :
102
140
self ._processed_epochs .add (epoch )
103
141
104
142
def log_progress (self ) -> None :
105
143
logger .info ({"msg" : f"Processed { len (self ._processed_epochs )} of { len (self ._epochs_to_process )} epochs" })
106
144
107
- def migrate (self , l_epoch : EpochNumber , r_epoch : EpochNumber , consensus_version : int ):
145
+ def migrate (
146
+ self , l_epoch : EpochNumber , r_epoch : EpochNumber , epochs_per_frame : int , consensus_version : int
147
+ ) -> None :
108
148
if consensus_version != self ._consensus_version :
109
149
logger .warning (
110
150
{
@@ -114,17 +154,41 @@ def migrate(self, l_epoch: EpochNumber, r_epoch: EpochNumber, consensus_version:
114
154
)
115
155
self .clear ()
116
156
117
- for state_epochs in (self ._epochs_to_process , self ._processed_epochs ):
118
- for epoch in state_epochs :
119
- if epoch < l_epoch or epoch > r_epoch :
120
- logger .warning ({"msg" : "Discarding invalidated state cache" })
121
- self .clear ()
122
- break
157
+ new_frames = self ._calculate_frames (tuple (sequence (l_epoch , r_epoch )), epochs_per_frame )
158
+ if self .frames == new_frames :
159
+ logger .info ({"msg" : "No need to migrate duties data cache" })
160
+ return
161
+ self ._migrate_frames_data (new_frames )
123
162
163
+ self .frames = new_frames
164
+ self .find_frame .cache_clear ()
124
165
self ._epochs_to_process = tuple (sequence (l_epoch , r_epoch ))
125
166
self ._consensus_version = consensus_version
126
167
self .commit ()
127
168
169
+ def _migrate_frames_data (self , new_frames : list [Frame ]):
170
+ logger .info ({"msg" : f"Migrating duties data cache: { self .frames = } -> { new_frames = } " })
171
+ new_data : StateData = {frame : defaultdict (AttestationsAccumulator ) for frame in new_frames }
172
+
173
+ def overlaps (a : Frame , b : Frame ):
174
+ return a [0 ] <= b [0 ] and a [1 ] >= b [1 ]
175
+
176
+ consumed = []
177
+ for new_frame in new_frames :
178
+ for frame_to_consume in self .frames :
179
+ if overlaps (new_frame , frame_to_consume ):
180
+ assert frame_to_consume not in consumed
181
+ consumed .append (frame_to_consume )
182
+ for val , duty in self .data [frame_to_consume ].items ():
183
+ new_data [new_frame ][val ].assigned += duty .assigned
184
+ new_data [new_frame ][val ].included += duty .included
185
+ for frame in self .frames :
186
+ if frame in consumed :
187
+ continue
188
+ logger .warning ({"msg" : f"Invalidating frame duties data cache: { frame } " })
189
+ self ._processed_epochs -= set (sequence (* frame ))
190
+ self .data = new_data
191
+
128
192
def validate (self , l_epoch : EpochNumber , r_epoch : EpochNumber ) -> None :
129
193
if not self .is_fulfilled :
130
194
raise InvalidState (f"State is not fulfilled. { self .unprocessed_epochs = } " )
@@ -135,34 +199,15 @@ def validate(self, l_epoch: EpochNumber, r_epoch: EpochNumber) -> None:
135
199
136
200
for epoch in sequence (l_epoch , r_epoch ):
137
201
if epoch not in self ._processed_epochs :
138
- raise InvalidState (f"Epoch { epoch } should be processed" )
139
-
140
- @property
141
- def is_empty (self ) -> bool :
142
- return not self .data and not self ._epochs_to_process and not self ._processed_epochs
143
-
144
- @property
145
- def unprocessed_epochs (self ) -> set [EpochNumber ]:
146
- if not self ._epochs_to_process :
147
- raise ValueError ("Epochs to process are not set" )
148
- diff = set (self ._epochs_to_process ) - self ._processed_epochs
149
- return diff
150
-
151
- @property
152
- def is_fulfilled (self ) -> bool :
153
- return not self .unprocessed_epochs
154
-
155
- @property
156
- def frame (self ) -> tuple [EpochNumber , EpochNumber ]:
157
- if not self ._epochs_to_process :
158
- raise ValueError ("Epochs to process are not set" )
159
- return min (self ._epochs_to_process ), max (self ._epochs_to_process )
160
-
161
- def get_network_aggr (self ) -> AttestationsAccumulator :
162
- """Return `AttestationsAccumulator` over duties of all the network validators"""
202
+ raise InvalidState (f"Epoch { epoch } missing in processed epochs" )
163
203
204
+ def get_network_aggr (self , frame : Frame ) -> AttestationsAccumulator :
205
+ # TODO: exclude `active_slashed` validators from the calculation
164
206
included = assigned = 0
165
- for validator , acc in self .data .items ():
207
+ frame_data = self .data .get (frame )
208
+ if frame_data is None :
209
+ raise ValueError (f"No data for frame { frame } to calculate network aggregate" )
210
+ for validator , acc in frame_data .items ():
166
211
if acc .included > acc .assigned :
167
212
raise ValueError (f"Invalid accumulator: { validator = } , { acc = } " )
168
213
included += acc .included
0 commit comments