Skip to content

Commit 8fcfe96

Browse files
committed
feat: per frame data
1 parent 24a2219 commit 8fcfe96

File tree

8 files changed

+496
-153
lines changed

8 files changed

+496
-153
lines changed

src/modules/csm/checkpoint.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,7 @@ def exec(self, checkpoint: FrameCheckpoint) -> int:
143143
for duty_epoch in unprocessed_epochs
144144
}
145145
self._process(unprocessed_epochs, duty_epochs_roots)
146+
self.state.commit()
146147
return len(unprocessed_epochs)
147148

148149
def _get_block_roots(self, checkpoint_slot: SlotNumber):
@@ -208,14 +209,14 @@ def _check_duty(
208209
with lock:
209210
for committee in committees.values():
210211
for validator_duty in committee:
211-
self.state.inc(
212+
self.state.increment_duty(
213+
duty_epoch,
212214
validator_duty.index,
213215
included=validator_duty.included,
214216
)
215217
if duty_epoch not in self.state.unprocessed_epochs:
216218
raise ValueError(f"Epoch {duty_epoch} is not in epochs that should be processed")
217219
self.state.add_processed_epoch(duty_epoch)
218-
self.state.commit()
219220
self.state.log_progress()
220221
unprocessed_epochs = self.state.unprocessed_epochs
221222
CSM_UNPROCESSED_EPOCHS_COUNT.set(len(unprocessed_epochs))

src/modules/csm/csm.py

+56-17
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from src.metrics.prometheus.duration_meter import duration_meter
1414
from src.modules.csm.checkpoint import FrameCheckpointProcessor, FrameCheckpointsIterator, MinStepIsNotReached
1515
from src.modules.csm.log import FramePerfLog
16-
from src.modules.csm.state import State
16+
from src.modules.csm.state import State, Frame
1717
from src.modules.csm.tree import Tree
1818
from src.modules.csm.types import ReportData, Shares
1919
from src.modules.submodules.consensus import ConsensusModule
@@ -29,10 +29,11 @@
2929
SlotNumber,
3030
StakingModuleAddress,
3131
StakingModuleId,
32+
ValidatorIndex,
3233
)
3334
from src.utils.blockstamp import build_blockstamp
3435
from src.utils.cache import global_lru_cache as lru_cache
35-
from src.utils.slot import get_next_non_missed_slot
36+
from src.utils.slot import get_next_non_missed_slot, get_reference_blockstamp
3637
from src.utils.web3converter import Web3Converter
3738
from src.web3py.extensions.lido_validators import NodeOperatorId, StakingModule, ValidatorsByNodeOperator
3839
from src.web3py.types import Web3
@@ -101,12 +102,12 @@ def build_report(self, blockstamp: ReferenceBlockStamp) -> tuple:
101102
if (prev_cid is None) != (prev_root == ZERO_HASH):
102103
raise InconsistentData(f"Got inconsistent previous tree data: {prev_root=} {prev_cid=}")
103104

104-
distributed, shares, log = self.calculate_distribution(blockstamp)
105+
distributed, shares, logs = self.calculate_distribution(blockstamp)
105106

106107
if distributed != sum(shares.values()):
107108
raise InconsistentData(f"Invalid distribution: {sum(shares.values())=} != {distributed=}")
108109

109-
log_cid = self.publish_log(log)
110+
log_cid = self.publish_log(logs)
110111

111112
if not distributed and not shares:
112113
logger.info({"msg": "No shares distributed in the current frame"})
@@ -201,7 +202,7 @@ def collect_data(self, blockstamp: BlockStamp) -> bool:
201202
logger.info({"msg": "The starting epoch of the frame is not finalized yet"})
202203
return False
203204

204-
self.state.migrate(l_epoch, r_epoch, consensus_version)
205+
self.state.init_or_migrate(l_epoch, r_epoch, converter.frame_config.epochs_per_frame, consensus_version)
205206
self.state.log_progress()
206207

207208
if self.state.is_fulfilled:
@@ -227,25 +228,64 @@ def collect_data(self, blockstamp: BlockStamp) -> bool:
227228

228229
def calculate_distribution(
    self, blockstamp: ReferenceBlockStamp
) -> tuple[int, defaultdict[NodeOperatorId, int], list[FramePerfLog]]:
    """Compute the fee-share distribution across every frame held in the state.

    Each key of ``self.state.data`` is a frame ``(from_epoch, to_epoch)``. The
    per-frame distribution is delegated to ``_calculate_distribution_in_frame``
    and the results are accumulated here.

    Args:
        blockstamp: Reference blockstamp of the report being built.

    Returns:
        A tuple of:
          * the total amount distributed, summed over all frames;
          * the per-operator shares, summed over all frames;
          * one ``FramePerfLog`` per frame, in iteration order.
    """
    operators_to_validators = self.module_validators_by_node_operators(blockstamp)

    total_distributed = 0
    total_shares = defaultdict[NodeOperatorId, int](int)
    frame_logs: list[FramePerfLog] = []

    for frame in self.state.data:
        from_epoch, to_epoch = frame
        logger.info({"msg": f"Calculating distribution for frame [{from_epoch};{to_epoch}]"})

        # Only a frame ending at the report's reference epoch can reuse the
        # report blockstamp; any other frame needs its own reference blockstamp.
        frame_blockstamp = (
            self._get_ref_blockstamp_for_frame(blockstamp, to_epoch)
            if to_epoch != blockstamp.ref_epoch
            else blockstamp
        )

        # The running total is handed over so the frame can subtract what has
        # already been distributed by the preceding frames.
        frame_distributed, frame_shares, frame_log = self._calculate_distribution_in_frame(
            frame_blockstamp, operators_to_validators, frame, total_distributed
        )

        total_distributed += frame_distributed
        for no_id, share in frame_shares.items():
            total_shares[no_id] += share
        frame_logs.append(frame_log)

    return total_distributed, total_shares, frame_logs
255+
256+
def _get_ref_blockstamp_for_frame(
    self, blockstamp: ReferenceBlockStamp, frame_ref_epoch: EpochNumber
) -> ReferenceBlockStamp:
    """Build the reference blockstamp for a frame that ends at ``frame_ref_epoch``.

    Args:
        blockstamp: Blockstamp used to obtain the converter; its ``slot_number``
            is passed on as the last finalized slot number.
        frame_ref_epoch: Epoch whose last slot becomes the reference slot.

    Returns:
        The result of ``get_reference_blockstamp`` for that slot and epoch.
    """
    frame_last_slot = self.converter(blockstamp).get_epoch_last_slot(frame_ref_epoch)
    return get_reference_blockstamp(
        cc=self.w3.cc,
        ref_slot=frame_last_slot,
        ref_epoch=frame_ref_epoch,
        last_finalized_slot_number=blockstamp.slot_number,
    )
266+
267+
def _calculate_distribution_in_frame(
268+
self,
269+
blockstamp: ReferenceBlockStamp,
270+
operators_to_validators: ValidatorsByNodeOperator,
271+
frame: Frame,
272+
distributed: int,
273+
):
274+
network_perf = self.state.get_network_aggr(frame).perf
275+
threshold = network_perf - self.w3.csm.oracle.perf_leeway_bp(blockstamp.block_hash) / TOTAL_BASIS_POINTS
276+
237277
# Build the map of the current distribution operators.
238278
distribution: dict[NodeOperatorId, int] = defaultdict(int)
239279
stuck_operators = self.stuck_operators(blockstamp)
240-
log = FramePerfLog(blockstamp, self.state.frame, threshold)
280+
log = FramePerfLog(blockstamp, frame, threshold)
241281

242282
for (_, no_id), validators in operators_to_validators.items():
243283
if no_id in stuck_operators:
244284
log.operators[no_id].stuck = True
245285
continue
246286

247287
for v in validators:
248-
aggr = self.state.data.get(v.index)
288+
aggr = self.state.data[frame].get(ValidatorIndex(int(v.index)))
249289

250290
if aggr is None:
251291
# It's possible that the validator is not assigned to any duty, hence its performance
@@ -268,13 +308,12 @@ def calculate_distribution(
268308
# Calculate share of each CSM node operator.
269309
shares = defaultdict[NodeOperatorId, int](int)
270310
total = sum(p for p in distribution.values())
311+
to_distribute = self.w3.csm.fee_distributor.shares_to_distribute(blockstamp.block_hash) - distributed
312+
log.distributable = to_distribute
271313

272314
if not total:
273315
return 0, shares, log
274316

275-
to_distribute = self.w3.csm.fee_distributor.shares_to_distribute(blockstamp.block_hash)
276-
log.distributable = to_distribute
277-
278317
for no_id, no_share in distribution.items():
279318
if no_share:
280319
shares[no_id] = to_distribute * no_share // total
@@ -348,9 +387,9 @@ def publish_tree(self, tree: Tree) -> CID:
348387
logger.info({"msg": "Tree dump uploaded to IPFS", "cid": repr(tree_cid)})
349388
return tree_cid
350389

351-
def publish_log(self, log: FramePerfLog) -> CID:
352-
log_cid = self.w3.ipfs.publish(log.encode())
353-
logger.info({"msg": "Frame log uploaded to IPFS", "cid": repr(log_cid)})
390+
def publish_log(self, logs: list[FramePerfLog]) -> CID:
    """Publish the encoded per-frame logs to IPFS as a single document.

    Args:
        logs: Frame logs; they are serialized together by ``FramePerfLog.encode``.

    Returns:
        The CID returned by the IPFS publish call.
    """
    ipfs = self.w3.ipfs
    cid = ipfs.publish(FramePerfLog.encode(logs))
    logger.info({"msg": "Frame(s) log uploaded to IPFS", "cid": repr(cid)})
    return cid
355394

356395
@lru_cache(maxsize=1)

src/modules/csm/log.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ class LogJSONEncoder(json.JSONEncoder): ...
1212

1313
@dataclass
1414
class ValidatorFrameSummary:
15+
# TODO: Should be renamed. Perf means different things in different contexts
1516
perf: AttestationsAccumulator = field(default_factory=AttestationsAccumulator)
1617
slashed: bool = False
1718

@@ -35,13 +36,14 @@ class FramePerfLog:
3536
default_factory=lambda: defaultdict(OperatorFrameSummary)
3637
)
3738

38-
def encode(self) -> bytes:
39+
@staticmethod
def encode(logs: list['FramePerfLog']) -> bytes:
    """Serialize a list of frame logs into compact JSON bytes.

    Each log is converted with ``dataclasses.asdict`` and the resulting list is
    dumped by ``LogJSONEncoder`` with no indentation, ``(',', ':')`` separators
    and sorted keys. The JSON text is then encoded to bytes with ``str.encode()``.

    Args:
        logs: Frame logs to serialize; the output is a JSON array with one
            element per log.

    Returns:
        The encoded JSON document.
    """
    encoder = LogJSONEncoder(indent=None, separators=(',', ':'), sort_keys=True)
    payload = [asdict(log) for log in logs]
    return encoder.encode(payload).encode()

0 commit comments

Comments
 (0)