From 8fcfe96ce5908fa6cfdb93e18563402495e5998e Mon Sep 17 00:00:00 2001 From: vgorkavenko Date: Wed, 30 Oct 2024 14:12:54 +0100 Subject: [PATCH 01/20] feat: per frame data --- src/modules/csm/checkpoint.py | 5 +- src/modules/csm/csm.py | 73 +++++++--- src/modules/csm/log.py | 6 +- src/modules/csm/state.py | 145 ++++++++++++++------ tests/modules/csm/test_checkpoint.py | 14 +- tests/modules/csm/test_csm_module.py | 196 ++++++++++++++++++++++++--- tests/modules/csm/test_log.py | 61 +++++++-- tests/modules/csm/test_state.py | 149 ++++++++++++-------- 8 files changed, 496 insertions(+), 153 deletions(-) diff --git a/src/modules/csm/checkpoint.py b/src/modules/csm/checkpoint.py index 0efc326c6..b111fe197 100644 --- a/src/modules/csm/checkpoint.py +++ b/src/modules/csm/checkpoint.py @@ -143,6 +143,7 @@ def exec(self, checkpoint: FrameCheckpoint) -> int: for duty_epoch in unprocessed_epochs } self._process(unprocessed_epochs, duty_epochs_roots) + self.state.commit() return len(unprocessed_epochs) def _get_block_roots(self, checkpoint_slot: SlotNumber): @@ -208,14 +209,14 @@ def _check_duty( with lock: for committee in committees.values(): for validator_duty in committee: - self.state.inc( + self.state.increment_duty( + duty_epoch, validator_duty.index, included=validator_duty.included, ) if duty_epoch not in self.state.unprocessed_epochs: raise ValueError(f"Epoch {duty_epoch} is not in epochs that should be processed") self.state.add_processed_epoch(duty_epoch) - self.state.commit() self.state.log_progress() unprocessed_epochs = self.state.unprocessed_epochs CSM_UNPROCESSED_EPOCHS_COUNT.set(len(unprocessed_epochs)) diff --git a/src/modules/csm/csm.py b/src/modules/csm/csm.py index 543a276e2..3d3619a94 100644 --- a/src/modules/csm/csm.py +++ b/src/modules/csm/csm.py @@ -13,7 +13,7 @@ from src.metrics.prometheus.duration_meter import duration_meter from src.modules.csm.checkpoint import FrameCheckpointProcessor, FrameCheckpointsIterator, MinStepIsNotReached from 
src.modules.csm.log import FramePerfLog -from src.modules.csm.state import State +from src.modules.csm.state import State, Frame from src.modules.csm.tree import Tree from src.modules.csm.types import ReportData, Shares from src.modules.submodules.consensus import ConsensusModule @@ -29,10 +29,11 @@ SlotNumber, StakingModuleAddress, StakingModuleId, + ValidatorIndex, ) from src.utils.blockstamp import build_blockstamp from src.utils.cache import global_lru_cache as lru_cache -from src.utils.slot import get_next_non_missed_slot +from src.utils.slot import get_next_non_missed_slot, get_reference_blockstamp from src.utils.web3converter import Web3Converter from src.web3py.extensions.lido_validators import NodeOperatorId, StakingModule, ValidatorsByNodeOperator from src.web3py.types import Web3 @@ -101,12 +102,12 @@ def build_report(self, blockstamp: ReferenceBlockStamp) -> tuple: if (prev_cid is None) != (prev_root == ZERO_HASH): raise InconsistentData(f"Got inconsistent previous tree data: {prev_root=} {prev_cid=}") - distributed, shares, log = self.calculate_distribution(blockstamp) + distributed, shares, logs = self.calculate_distribution(blockstamp) if distributed != sum(shares.values()): raise InconsistentData(f"Invalid distribution: {sum(shares.values())=} != {distributed=}") - log_cid = self.publish_log(log) + log_cid = self.publish_log(logs) if not distributed and not shares: logger.info({"msg": "No shares distributed in the current frame"}) @@ -201,7 +202,7 @@ def collect_data(self, blockstamp: BlockStamp) -> bool: logger.info({"msg": "The starting epoch of the frame is not finalized yet"}) return False - self.state.migrate(l_epoch, r_epoch, consensus_version) + self.state.init_or_migrate(l_epoch, r_epoch, converter.frame_config.epochs_per_frame, consensus_version) self.state.log_progress() if self.state.is_fulfilled: @@ -227,17 +228,56 @@ def collect_data(self, blockstamp: BlockStamp) -> bool: def calculate_distribution( self, blockstamp: ReferenceBlockStamp 
- ) -> tuple[int, defaultdict[NodeOperatorId, int], FramePerfLog]: + ) -> tuple[int, defaultdict[NodeOperatorId, int], list[FramePerfLog]]: """Computes distribution of fee shares at the given timestamp""" - - network_avg_perf = self.state.get_network_aggr().perf - threshold = network_avg_perf - self.w3.csm.oracle.perf_leeway_bp(blockstamp.block_hash) / TOTAL_BASIS_POINTS operators_to_validators = self.module_validators_by_node_operators(blockstamp) + distributed = 0 + # Calculate share of each CSM node operator. + shares = defaultdict[NodeOperatorId, int](int) + logs: list[FramePerfLog] = [] + + for frame in self.state.data: + from_epoch, to_epoch = frame + logger.info({"msg": f"Calculating distribution for frame [{from_epoch};{to_epoch}]"}) + frame_blockstamp = blockstamp + if to_epoch != blockstamp.ref_epoch: + frame_blockstamp = self._get_ref_blockstamp_for_frame(blockstamp, to_epoch) + distributed_in_frame, shares_in_frame, log = self._calculate_distribution_in_frame( + frame_blockstamp, operators_to_validators, frame, distributed + ) + distributed += distributed_in_frame + for no_id, share in shares_in_frame.items(): + shares[no_id] += share + logs.append(log) + + return distributed, shares, logs + + def _get_ref_blockstamp_for_frame( + self, blockstamp: ReferenceBlockStamp, frame_ref_epoch: EpochNumber + ) -> ReferenceBlockStamp: + converter = self.converter(blockstamp) + return get_reference_blockstamp( + cc=self.w3.cc, + ref_slot=converter.get_epoch_last_slot(frame_ref_epoch), + ref_epoch=frame_ref_epoch, + last_finalized_slot_number=blockstamp.slot_number, + ) + + def _calculate_distribution_in_frame( + self, + blockstamp: ReferenceBlockStamp, + operators_to_validators: ValidatorsByNodeOperator, + frame: Frame, + distributed: int, + ): + network_perf = self.state.get_network_aggr(frame).perf + threshold = network_perf - self.w3.csm.oracle.perf_leeway_bp(blockstamp.block_hash) / TOTAL_BASIS_POINTS + # Build the map of the current distribution operators. 
distribution: dict[NodeOperatorId, int] = defaultdict(int) stuck_operators = self.stuck_operators(blockstamp) - log = FramePerfLog(blockstamp, self.state.frame, threshold) + log = FramePerfLog(blockstamp, frame, threshold) for (_, no_id), validators in operators_to_validators.items(): if no_id in stuck_operators: @@ -245,7 +285,7 @@ def calculate_distribution( continue for v in validators: - aggr = self.state.data.get(v.index) + aggr = self.state.data[frame].get(ValidatorIndex(int(v.index))) if aggr is None: # It's possible that the validator is not assigned to any duty, hence it's performance @@ -268,13 +308,12 @@ def calculate_distribution( # Calculate share of each CSM node operator. shares = defaultdict[NodeOperatorId, int](int) total = sum(p for p in distribution.values()) + to_distribute = self.w3.csm.fee_distributor.shares_to_distribute(blockstamp.block_hash) - distributed + log.distributable = to_distribute if not total: return 0, shares, log - to_distribute = self.w3.csm.fee_distributor.shares_to_distribute(blockstamp.block_hash) - log.distributable = to_distribute - for no_id, no_share in distribution.items(): if no_share: shares[no_id] = to_distribute * no_share // total @@ -348,9 +387,9 @@ def publish_tree(self, tree: Tree) -> CID: logger.info({"msg": "Tree dump uploaded to IPFS", "cid": repr(tree_cid)}) return tree_cid - def publish_log(self, log: FramePerfLog) -> CID: - log_cid = self.w3.ipfs.publish(log.encode()) - logger.info({"msg": "Frame log uploaded to IPFS", "cid": repr(log_cid)}) + def publish_log(self, logs: list[FramePerfLog]) -> CID: + log_cid = self.w3.ipfs.publish(FramePerfLog.encode(logs)) + logger.info({"msg": "Frame(s) log uploaded to IPFS", "cid": repr(log_cid)}) return log_cid @lru_cache(maxsize=1) diff --git a/src/modules/csm/log.py b/src/modules/csm/log.py index f89f4ef58..39832c8c0 100644 --- a/src/modules/csm/log.py +++ b/src/modules/csm/log.py @@ -12,6 +12,7 @@ class LogJSONEncoder(json.JSONEncoder): ... 
@dataclass class ValidatorFrameSummary: + # TODO: Should be renamed. Perf means different things in different contexts perf: AttestationsAccumulator = field(default_factory=AttestationsAccumulator) slashed: bool = False @@ -35,13 +36,14 @@ class FramePerfLog: default_factory=lambda: defaultdict(OperatorFrameSummary) ) - def encode(self) -> bytes: + @staticmethod + def encode(logs: list['FramePerfLog']) -> bytes: return ( LogJSONEncoder( indent=None, separators=(',', ':'), sort_keys=True, ) - .encode(asdict(self)) + .encode([asdict(log) for log in logs]) .encode() ) diff --git a/src/modules/csm/state.py b/src/modules/csm/state.py index 4373f5259..fd27a8d62 100644 --- a/src/modules/csm/state.py +++ b/src/modules/csm/state.py @@ -3,6 +3,7 @@ import pickle from collections import defaultdict from dataclasses import dataclass +from itertools import batched from pathlib import Path from typing import Self @@ -12,6 +13,8 @@ logger = logging.getLogger(__name__) +type Frame = tuple[EpochNumber, EpochNumber] + class InvalidState(ValueError): """State has data considered as invalid for a report""" @@ -43,18 +46,21 @@ class State: The state can be migrated to be used for another frame's report by calling the `migrate` method. """ - - data: defaultdict[ValidatorIndex, AttestationsAccumulator] + data: dict[Frame, defaultdict[ValidatorIndex, AttestationsAccumulator]] _epochs_to_process: tuple[EpochNumber, ...] 
_processed_epochs: set[EpochNumber] + _epochs_per_frame: int _consensus_version: int = 1 - def __init__(self, data: dict[ValidatorIndex, AttestationsAccumulator] | None = None) -> None: - self.data = defaultdict(AttestationsAccumulator, data or {}) + def __init__(self, data: dict[Frame, dict[ValidatorIndex, AttestationsAccumulator]] | None = None) -> None: + self.data = { + frame: defaultdict(AttestationsAccumulator, validators) for frame, validators in (data or {}).items() + } self._epochs_to_process = tuple() self._processed_epochs = set() + self._epochs_per_frame = 0 EXTENSION = ".pkl" @@ -89,14 +95,37 @@ def file(cls) -> Path: def buffer(self) -> Path: return self.file().with_suffix(".buf") + @property + def is_empty(self) -> bool: + return not self.data and not self._epochs_to_process and not self._processed_epochs + + @property + def unprocessed_epochs(self) -> set[EpochNumber]: + if not self._epochs_to_process: + raise ValueError("Epochs to process are not set") + diff = set(self._epochs_to_process) - self._processed_epochs + return diff + + @property + def is_fulfilled(self) -> bool: + return not self.unprocessed_epochs + def clear(self) -> None: - self.data = defaultdict(AttestationsAccumulator) + self.data = {} self._epochs_to_process = tuple() self._processed_epochs.clear() assert self.is_empty - def inc(self, key: ValidatorIndex, included: bool) -> None: - self.data[key].add_duty(included) + def find_frame(self, epoch: EpochNumber) -> Frame: + frames = self.data.keys() + for epoch_range in frames: + if epoch_range[0] <= epoch <= epoch_range[1]: + return epoch_range + raise ValueError(f"Epoch {epoch} is out of frames range: {frames}") + + def increment_duty(self, epoch: EpochNumber, val_index: ValidatorIndex, included: bool) -> None: + epoch_range = self.find_frame(epoch) + self.data[epoch_range][val_index].add_duty(included) def add_processed_epoch(self, epoch: EpochNumber) -> None: self._processed_epochs.add(epoch) @@ -104,7 +133,7 @@ def 
add_processed_epoch(self, epoch: EpochNumber) -> None: def log_progress(self) -> None: logger.info({"msg": f"Processed {len(self._processed_epochs)} of {len(self._epochs_to_process)} epochs"}) - def migrate(self, l_epoch: EpochNumber, r_epoch: EpochNumber, consensus_version: int): + def init_or_migrate(self, l_epoch: EpochNumber, r_epoch: EpochNumber, epochs_per_frame: int, consensus_version: int) -> None: if consensus_version != self._consensus_version: logger.warning( { @@ -114,17 +143,60 @@ def migrate(self, l_epoch: EpochNumber, r_epoch: EpochNumber, consensus_version: ) self.clear() - for state_epochs in (self._epochs_to_process, self._processed_epochs): - for epoch in state_epochs: - if epoch < l_epoch or epoch > r_epoch: - logger.warning({"msg": "Discarding invalidated state cache"}) - self.clear() - break + if not self.is_empty: + invalidated = self._migrate_or_invalidate(l_epoch, r_epoch, epochs_per_frame) + if invalidated: + self.clear() + self._fill_frames(l_epoch, r_epoch, epochs_per_frame) + self._epochs_per_frame = epochs_per_frame self._epochs_to_process = tuple(sequence(l_epoch, r_epoch)) self._consensus_version = consensus_version self.commit() + def _fill_frames(self, l_epoch: EpochNumber, r_epoch: EpochNumber, epochs_per_frame: int) -> None: + frames = self.calculate_frames(tuple(sequence(l_epoch, r_epoch)), epochs_per_frame) + for frame in frames: + self.data.setdefault(frame, defaultdict(AttestationsAccumulator)) + + def _migrate_or_invalidate(self, l_epoch: EpochNumber, r_epoch: EpochNumber, epochs_per_frame: int) -> bool: + current_frames = self.calculate_frames(self._epochs_to_process, self._epochs_per_frame) + new_frames = self.calculate_frames(tuple(sequence(l_epoch, r_epoch)), epochs_per_frame) + inv_msg = f"Discarding invalid state cache because of frames change. 
{current_frames=}, {new_frames=}" + + if self._invalidate_on_epoch_range_change(l_epoch, r_epoch): + logger.warning({"msg": inv_msg}) + return True + + frame_expanded = epochs_per_frame > self._epochs_per_frame + frame_shrunk = epochs_per_frame < self._epochs_per_frame + + has_single_frame = len(current_frames) == len(new_frames) == 1 + + if has_single_frame and frame_expanded: + current_frame, *_ = current_frames + new_frame, *_ = new_frames + self.data[new_frame] = self.data.pop(current_frame) + logger.info({"msg": f"Migrated state cache to a new frame. {current_frame=}, {new_frame=}"}) + return False + + if has_single_frame and frame_shrunk: + logger.warning({"msg": inv_msg}) + return True + + if not has_single_frame and frame_expanded or frame_shrunk: + logger.warning({"msg": inv_msg}) + return True + + return False + + def _invalidate_on_epoch_range_change(self, l_epoch: EpochNumber, r_epoch: EpochNumber) -> bool: + """Check if the epoch range has been invalidated.""" + for epoch_set in (self._epochs_to_process, self._processed_epochs): + if any(epoch < l_epoch or epoch > r_epoch for epoch in epoch_set): + return True + return False + def validate(self, l_epoch: EpochNumber, r_epoch: EpochNumber) -> None: if not self.is_fulfilled: raise InvalidState(f"State is not fulfilled. 
{self.unprocessed_epochs=}") @@ -135,34 +207,25 @@ def validate(self, l_epoch: EpochNumber, r_epoch: EpochNumber) -> None: for epoch in sequence(l_epoch, r_epoch): if epoch not in self._processed_epochs: - raise InvalidState(f"Epoch {epoch} should be processed") - - @property - def is_empty(self) -> bool: - return not self.data and not self._epochs_to_process and not self._processed_epochs - - @property - def unprocessed_epochs(self) -> set[EpochNumber]: - if not self._epochs_to_process: - raise ValueError("Epochs to process are not set") - diff = set(self._epochs_to_process) - self._processed_epochs - return diff - - @property - def is_fulfilled(self) -> bool: - return not self.unprocessed_epochs - - @property - def frame(self) -> tuple[EpochNumber, EpochNumber]: - if not self._epochs_to_process: - raise ValueError("Epochs to process are not set") - return min(self._epochs_to_process), max(self._epochs_to_process) - - def get_network_aggr(self) -> AttestationsAccumulator: - """Return `AttestationsAccumulator` over duties of all the network validators""" - + raise InvalidState(f"Epoch {epoch} missing in processed epochs") + + @staticmethod + def calculate_frames(epochs_to_process: tuple[EpochNumber, ...], epochs_per_frame: int) -> list[Frame]: + """Split epochs to process into frames of `epochs_per_frame` length""" + frames = [] + for frame_epochs in batched(epochs_to_process, epochs_per_frame): + if len(frame_epochs) < epochs_per_frame: + raise ValueError("Insufficient epochs to form a frame") + frames.append((frame_epochs[0], frame_epochs[-1])) + return frames + + def get_network_aggr(self, frame: Frame) -> AttestationsAccumulator: + # TODO: exclude `active_slashed` validators from the calculation included = assigned = 0 - for validator, acc in self.data.items(): + frame_data = self.data.get(frame) + if not frame_data: + raise ValueError(f"No data for frame {frame} to calculate network aggregate") + for validator, acc in frame_data.items(): if acc.included > 
acc.assigned: raise ValueError(f"Invalid accumulator: {validator=}, {acc=}") included += acc.included diff --git a/tests/modules/csm/test_checkpoint.py b/tests/modules/csm/test_checkpoint.py index 44f23735e..4b456ed03 100644 --- a/tests/modules/csm/test_checkpoint.py +++ b/tests/modules/csm/test_checkpoint.py @@ -326,7 +326,7 @@ def test_checkpoints_processor_no_eip7549_support( monkeypatch: pytest.MonkeyPatch, ): state = State() - state.migrate(EpochNumber(0), EpochNumber(255), 1) + state.init_or_migrate(EpochNumber(0), EpochNumber(255), 256, 1) processor = FrameCheckpointProcessor( consensus_client, state, @@ -354,7 +354,7 @@ def test_checkpoints_processor_check_duty( converter, ): state = State() - state.migrate(0, 255, 1) + state.init_or_migrate(0, 255, 256, 1) finalized_blockstamp = ... processor = FrameCheckpointProcessor( consensus_client, @@ -367,7 +367,7 @@ def test_checkpoints_processor_check_duty( assert len(state._processed_epochs) == 1 assert len(state._epochs_to_process) == 256 assert len(state.unprocessed_epochs) == 255 - assert len(state.data) == 2048 * 32 + assert len(state.data[(0, 255)]) == 2048 * 32 def test_checkpoints_processor_process( @@ -379,7 +379,7 @@ def test_checkpoints_processor_process( converter, ): state = State() - state.migrate(0, 255, 1) + state.init_or_migrate(0, 255, 256, 1) finalized_blockstamp = ... processor = FrameCheckpointProcessor( consensus_client, @@ -392,7 +392,7 @@ def test_checkpoints_processor_process( assert len(state._processed_epochs) == 2 assert len(state._epochs_to_process) == 256 assert len(state.unprocessed_epochs) == 254 - assert len(state.data) == 2048 * 32 + assert len(state.data[(0, 255)]) == 2048 * 32 def test_checkpoints_processor_exec( @@ -404,7 +404,7 @@ def test_checkpoints_processor_exec( converter, ): state = State() - state.migrate(0, 255, 1) + state.init_or_migrate(0, 255, 256, 1) finalized_blockstamp = ... 
processor = FrameCheckpointProcessor( consensus_client, @@ -418,4 +418,4 @@ def test_checkpoints_processor_exec( assert len(state._processed_epochs) == 2 assert len(state._epochs_to_process) == 256 assert len(state.unprocessed_epochs) == 254 - assert len(state.data) == 2048 * 32 + assert len(state.data[(0, 255)]) == 2048 * 32 diff --git a/tests/modules/csm/test_csm_module.py b/tests/modules/csm/test_csm_module.py index f74af8d69..cdb0c92c5 100644 --- a/tests/modules/csm/test_csm_module.py +++ b/tests/modules/csm/test_csm_module.py @@ -9,7 +9,7 @@ from src.constants import UINT64_MAX from src.modules.csm.csm import CSOracle -from src.modules.csm.state import AttestationsAccumulator, State +from src.modules.csm.state import AttestationsAccumulator, State, Frame from src.modules.csm.tree import Tree from src.modules.submodules.oracle_module import ModuleExecuteDelay from src.modules.submodules.types import CurrentFrame, ZERO_HASH @@ -166,26 +166,37 @@ def test_calculate_distribution(module: CSOracle, csm: CSM): ] ) + frame_0: Frame = (EpochNumber(0), EpochNumber(999)) + + module.state.init_or_migrate(*frame_0, epochs_per_frame=1000, consensus_version=1) module.state = State( { - ValidatorIndex(0): AttestationsAccumulator(included=200, assigned=200), # short on frame - ValidatorIndex(1): AttestationsAccumulator(included=1000, assigned=1000), - ValidatorIndex(2): AttestationsAccumulator(included=1000, assigned=1000), - ValidatorIndex(3): AttestationsAccumulator(included=999, assigned=1000), - ValidatorIndex(4): AttestationsAccumulator(included=900, assigned=1000), - ValidatorIndex(5): AttestationsAccumulator(included=500, assigned=1000), # underperforming - ValidatorIndex(6): AttestationsAccumulator(included=0, assigned=0), # underperforming - ValidatorIndex(7): AttestationsAccumulator(included=900, assigned=1000), - ValidatorIndex(8): AttestationsAccumulator(included=500, assigned=1000), # underperforming - # ValidatorIndex(9): AttestationsAggregate(included=0, 
assigned=0), # missing in state - ValidatorIndex(10): AttestationsAccumulator(included=1000, assigned=1000), - ValidatorIndex(11): AttestationsAccumulator(included=1000, assigned=1000), - ValidatorIndex(12): AttestationsAccumulator(included=1000, assigned=1000), + frame_0: { + ValidatorIndex(0): AttestationsAccumulator(included=200, assigned=200), # short on frame + ValidatorIndex(1): AttestationsAccumulator(included=1000, assigned=1000), + ValidatorIndex(2): AttestationsAccumulator(included=1000, assigned=1000), + ValidatorIndex(3): AttestationsAccumulator(included=999, assigned=1000), + ValidatorIndex(4): AttestationsAccumulator(included=900, assigned=1000), + ValidatorIndex(5): AttestationsAccumulator(included=500, assigned=1000), # underperforming + ValidatorIndex(6): AttestationsAccumulator(included=0, assigned=0), # underperforming + ValidatorIndex(7): AttestationsAccumulator(included=900, assigned=1000), + ValidatorIndex(8): AttestationsAccumulator(included=500, assigned=1000), # underperforming + # ValidatorIndex(9): AttestationsAggregate(included=0, assigned=0), # missing in state + ValidatorIndex(10): AttestationsAccumulator(included=1000, assigned=1000), + ValidatorIndex(11): AttestationsAccumulator(included=1000, assigned=1000), + ValidatorIndex(12): AttestationsAccumulator(included=1000, assigned=1000), + } } ) - module.state.migrate(EpochNumber(100), EpochNumber(500), 1) - _, shares, log = module.calculate_distribution(blockstamp=Mock()) + l_epoch, r_epoch = frame_0 + + frame_0_network_aggr = module.state.get_network_aggr(frame_0) + + blockstamp = ReferenceBlockStampFactory.build(slot_number=r_epoch * 32, ref_epoch=r_epoch, ref_slot=r_epoch * 32) + _, shares, logs = module.calculate_distribution(blockstamp=blockstamp) + + log, *_ = logs assert tuple(shares.items()) == ( (NodeOperatorId(0), 476), @@ -225,8 +236,157 @@ def test_calculate_distribution(module: CSOracle, csm: CSM): assert log.operators[NodeOperatorId(3)].distributed == 2380 assert 
log.operators[NodeOperatorId(6)].distributed == 2380 - assert log.frame == (100, 500) - assert log.threshold == module.state.get_network_aggr().perf - 0.05 + assert log.frame == frame_0 + assert log.threshold == frame_0_network_aggr.perf - 0.05 + + +def test_calculate_distribution_with_missed_with_two_frames(module: CSOracle, csm: CSM): + csm.oracle.perf_leeway_bp = Mock(return_value=500) + csm.fee_distributor.shares_to_distribute = Mock(side_effect=[10000, 20000]) + + module.module_validators_by_node_operators = Mock( + return_value={ + (None, NodeOperatorId(0)): [Mock(index=0, validator=Mock(slashed=False))], + (None, NodeOperatorId(1)): [Mock(index=1, validator=Mock(slashed=False))], + (None, NodeOperatorId(2)): [Mock(index=2, validator=Mock(slashed=False))], # stuck + (None, NodeOperatorId(3)): [Mock(index=3, validator=Mock(slashed=False))], + (None, NodeOperatorId(4)): [Mock(index=4, validator=Mock(slashed=False))], # stuck + (None, NodeOperatorId(5)): [ + Mock(index=5, validator=Mock(slashed=False)), + Mock(index=6, validator=Mock(slashed=False)), + ], + (None, NodeOperatorId(6)): [ + Mock(index=7, validator=Mock(slashed=False)), + Mock(index=8, validator=Mock(slashed=False)), + ], + (None, NodeOperatorId(7)): [Mock(index=9, validator=Mock(slashed=False))], + (None, NodeOperatorId(8)): [ + Mock(index=10, validator=Mock(slashed=False)), + Mock(index=11, validator=Mock(slashed=True)), + ], + (None, NodeOperatorId(9)): [Mock(index=12, validator=Mock(slashed=True))], + } + ) + + module.stuck_operators = Mock( + side_effect=[ + [ + NodeOperatorId(2), + NodeOperatorId(4), + ], + [ + NodeOperatorId(2), + NodeOperatorId(4), + ], + ] + ) + + module.state = State() + l_epoch, r_epoch = EpochNumber(0), EpochNumber(1999) + frame_0 = (0, 999) + frame_1 = (1000, 1999) + module.state.init_or_migrate(l_epoch, r_epoch, epochs_per_frame=1000, consensus_version=1) + module.state = State( + { + frame_0: { + ValidatorIndex(0): AttestationsAccumulator(included=200, assigned=200), 
# short on frame + ValidatorIndex(1): AttestationsAccumulator(included=1000, assigned=1000), + ValidatorIndex(2): AttestationsAccumulator(included=1000, assigned=1000), + ValidatorIndex(3): AttestationsAccumulator(included=999, assigned=1000), + ValidatorIndex(4): AttestationsAccumulator(included=900, assigned=1000), + ValidatorIndex(5): AttestationsAccumulator(included=500, assigned=1000), # underperforming + ValidatorIndex(6): AttestationsAccumulator(included=0, assigned=0), # underperforming + ValidatorIndex(7): AttestationsAccumulator(included=900, assigned=1000), + ValidatorIndex(8): AttestationsAccumulator(included=500, assigned=1000), # underperforming + # ValidatorIndex(9): AttestationsAggregate(included=0, assigned=0), # missing in state + ValidatorIndex(10): AttestationsAccumulator(included=1000, assigned=1000), + ValidatorIndex(11): AttestationsAccumulator(included=1000, assigned=1000), + ValidatorIndex(12): AttestationsAccumulator(included=1000, assigned=1000), + }, + frame_1: { + ValidatorIndex(0): AttestationsAccumulator(included=200, assigned=200), # short on frame + ValidatorIndex(1): AttestationsAccumulator(included=1000, assigned=1000), + ValidatorIndex(2): AttestationsAccumulator(included=1000, assigned=1000), + ValidatorIndex(3): AttestationsAccumulator(included=999, assigned=1000), + ValidatorIndex(4): AttestationsAccumulator(included=900, assigned=1000), + ValidatorIndex(5): AttestationsAccumulator(included=500, assigned=1000), # underperforming + ValidatorIndex(6): AttestationsAccumulator(included=0, assigned=0), # underperforming + ValidatorIndex(7): AttestationsAccumulator(included=900, assigned=1000), + ValidatorIndex(8): AttestationsAccumulator(included=500, assigned=1000), # underperforming + # ValidatorIndex(9): AttestationsAggregate(included=0, assigned=0), # missing in state + ValidatorIndex(10): AttestationsAccumulator(included=1000, assigned=1000), + ValidatorIndex(11): AttestationsAccumulator(included=1000, assigned=1000), + 
ValidatorIndex(12): AttestationsAccumulator(included=1000, assigned=1000), + }, + } + ) + module.w3.cc = Mock() + + module.converter = Mock( + side_effect=lambda _: Mock( + frame_config=FrameConfigFactory.build(epochs_per_frame=1000), + get_epoch_last_slot=lambda epoch: epoch * 32 + 31, + ) + ) + + module._get_ref_blockstamp_for_frame = Mock( + side_effect=[ + ReferenceBlockStampFactory.build( + slot_number=frame_0[1] * 32, ref_epoch=frame_0[1], ref_slot=frame_0[1] * 32 + ), + ReferenceBlockStampFactory.build(slot_number=r_epoch * 32, ref_epoch=r_epoch, ref_slot=r_epoch * 32), + ] + ) + + blockstamp = ReferenceBlockStampFactory.build(slot_number=r_epoch * 32, ref_epoch=r_epoch, ref_slot=r_epoch * 32) + distributed, shares, logs = module.calculate_distribution(blockstamp=blockstamp) + + assert distributed == 2 * 9_998 # because of the rounding + + assert tuple(shares.items()) == ( + (NodeOperatorId(0), 952), + (NodeOperatorId(1), 4761), + (NodeOperatorId(3), 4761), + (NodeOperatorId(6), 4761), + (NodeOperatorId(8), 4761), + ) + + assert len(logs) == 2 + + for log in logs: + + assert log.frame in module.state.data.keys() + assert log.threshold == module.state.get_network_aggr(log.frame).perf - 0.05 + + assert tuple(log.operators.keys()) == ( + NodeOperatorId(0), + NodeOperatorId(1), + NodeOperatorId(2), + NodeOperatorId(3), + NodeOperatorId(4), + NodeOperatorId(5), + NodeOperatorId(6), + # NodeOperatorId(7), # Missing in state + NodeOperatorId(8), + NodeOperatorId(9), + ) + + assert not log.operators[NodeOperatorId(1)].stuck + + assert log.operators[NodeOperatorId(2)].validators == {} + assert log.operators[NodeOperatorId(2)].stuck + assert log.operators[NodeOperatorId(4)].validators == {} + assert log.operators[NodeOperatorId(4)].stuck + + assert 5 in log.operators[NodeOperatorId(5)].validators + assert 6 in log.operators[NodeOperatorId(5)].validators + assert 7 in log.operators[NodeOperatorId(6)].validators + + assert log.operators[NodeOperatorId(0)].distributed == 
476 + assert log.operators[NodeOperatorId(1)].distributed in [2380, 2381] + assert log.operators[NodeOperatorId(2)].distributed == 0 + assert log.operators[NodeOperatorId(3)].distributed in [2380, 2381] + assert log.operators[NodeOperatorId(6)].distributed in [2380, 2381] # Static functions you were dreaming of for so long. diff --git a/tests/modules/csm/test_log.py b/tests/modules/csm/test_log.py index de52ca9ef..61004e9ed 100644 --- a/tests/modules/csm/test_log.py +++ b/tests/modules/csm/test_log.py @@ -1,8 +1,7 @@ import json import pytest -from src.modules.csm.log import FramePerfLog -from src.modules.csm.state import AttestationsAccumulator +from src.modules.csm.log import FramePerfLog, AttestationsAccumulator from src.types import EpochNumber, NodeOperatorId, ReferenceBlockStamp from tests.factory.blockstamp import ReferenceBlockStampFactory @@ -33,16 +32,56 @@ def test_log_encode(log: FramePerfLog): log.operators[NodeOperatorId(42)].distributed = 17 log.operators[NodeOperatorId(0)].distributed = 0 - encoded = log.encode() + logs = [log] + + encoded = FramePerfLog.encode(logs) + + for decoded in json.loads(encoded): + assert decoded["operators"]["42"]["validators"]["41337"]["perf"]["assigned"] == 220 + assert decoded["operators"]["42"]["validators"]["41337"]["perf"]["included"] == 119 + assert decoded["operators"]["42"]["distributed"] == 17 + assert decoded["operators"]["0"]["distributed"] == 0 + + assert decoded["blockstamp"]["block_hash"] == log.blockstamp.block_hash + assert decoded["blockstamp"]["ref_slot"] == log.blockstamp.ref_slot + + assert decoded["threshold"] == log.threshold + assert decoded["frame"] == list(log.frame) + + +def test_logs_encode(): + log_0 = FramePerfLog(ReferenceBlockStampFactory.build(), (EpochNumber(100), EpochNumber(500))) + log_0.operators[NodeOperatorId(42)].validators["41337"].perf = AttestationsAccumulator(220, 119) + log_0.operators[NodeOperatorId(42)].distributed = 17 + log_0.operators[NodeOperatorId(0)].distributed = 0 + 
+ log_1 = FramePerfLog(ReferenceBlockStampFactory.build(), (EpochNumber(500), EpochNumber(900))) + log_1.operators[NodeOperatorId(5)].validators["1234"].perf = AttestationsAccumulator(400, 399) + log_1.operators[NodeOperatorId(5)].distributed = 40 + log_1.operators[NodeOperatorId(18)].distributed = 0 + + logs = [log_0, log_1] + + encoded = FramePerfLog.encode(logs) + decoded = json.loads(encoded) - assert decoded["operators"]["42"]["validators"]["41337"]["perf"]["assigned"] == 220 - assert decoded["operators"]["42"]["validators"]["41337"]["perf"]["included"] == 119 - assert decoded["operators"]["42"]["distributed"] == 17 - assert decoded["operators"]["0"]["distributed"] == 0 + assert len(decoded) == 2 + + assert decoded[0]["operators"]["42"]["validators"]["41337"]["perf"]["assigned"] == 220 + assert decoded[0]["operators"]["42"]["validators"]["41337"]["perf"]["included"] == 119 + assert decoded[0]["operators"]["42"]["distributed"] == 17 + assert decoded[0]["operators"]["0"]["distributed"] == 0 + + assert decoded[1]["operators"]["5"]["validators"]["1234"]["perf"]["assigned"] == 400 + assert decoded[1]["operators"]["5"]["validators"]["1234"]["perf"]["included"] == 399 + assert decoded[1]["operators"]["5"]["distributed"] == 40 + assert decoded[1]["operators"]["18"]["distributed"] == 0 - assert decoded["blockstamp"]["block_hash"] == log.blockstamp.block_hash - assert decoded["blockstamp"]["ref_slot"] == log.blockstamp.ref_slot + for i, log in enumerate(logs): + assert decoded[i]["blockstamp"]["block_hash"] == log.blockstamp.block_hash + assert decoded[i]["blockstamp"]["ref_slot"] == log.blockstamp.ref_slot - assert decoded["threshold"] == log.threshold - assert decoded["frame"] == list(log.frame) + assert decoded[i]["threshold"] == log.threshold + assert decoded[i]["frame"] == list(log.frame) + assert decoded[i]["distributable"] == log.distributable diff --git a/tests/modules/csm/test_state.py b/tests/modules/csm/test_state.py index 7539f7d26..d781522e2 100644 --- 
a/tests/modules/csm/test_state.py +++ b/tests/modules/csm/test_state.py @@ -26,51 +26,43 @@ def test_attestation_aggregate_perf(): def test_state_avg_perf(): state = State() - assert state.get_network_aggr().perf == 0 + frame = (0, 999) - state = State( - { + with pytest.raises(ValueError): + state.get_network_aggr(frame) + + state = State() + state.init_or_migrate(*frame, 1000, 1) + state.data = { + frame: { ValidatorIndex(0): AttestationsAccumulator(included=0, assigned=0), ValidatorIndex(1): AttestationsAccumulator(included=0, assigned=0), } - ) + } - assert state.get_network_aggr().perf == 0 + assert state.get_network_aggr(frame).perf == 0 - state = State( - { + state.data = { + frame: { ValidatorIndex(0): AttestationsAccumulator(included=333, assigned=777), ValidatorIndex(1): AttestationsAccumulator(included=167, assigned=223), } - ) - - assert state.get_network_aggr().perf == 0.5 - + } -def test_state_frame(): - state = State() - - state.migrate(EpochNumber(100), EpochNumber(500), 1) - assert state.frame == (100, 500) - - state.migrate(EpochNumber(300), EpochNumber(301), 1) - assert state.frame == (300, 301) - - state.clear() - - with pytest.raises(ValueError, match="Epochs to process are not set"): - state.frame + assert state.get_network_aggr(frame).perf == 0.5 def test_state_attestations(): state = State( { - ValidatorIndex(0): AttestationsAccumulator(included=333, assigned=777), - ValidatorIndex(1): AttestationsAccumulator(included=167, assigned=223), + (0, 999): { + ValidatorIndex(0): AttestationsAccumulator(included=333, assigned=777), + ValidatorIndex(1): AttestationsAccumulator(included=167, assigned=223), + } } ) - network_aggr = state.get_network_aggr() + network_aggr = state.get_network_aggr((0, 999)) assert network_aggr.assigned == 1000 assert network_aggr.included == 500 @@ -79,8 +71,10 @@ def test_state_attestations(): def test_state_load(): orig = State( { - ValidatorIndex(0): AttestationsAccumulator(included=333, assigned=777), - 
ValidatorIndex(1): AttestationsAccumulator(included=167, assigned=223), + (0, 999): { + ValidatorIndex(0): AttestationsAccumulator(included=333, assigned=777), + ValidatorIndex(1): AttestationsAccumulator(included=167, assigned=223), + } } ) @@ -92,8 +86,10 @@ def test_state_load(): def test_state_clear(): state = State( { - ValidatorIndex(0): AttestationsAccumulator(included=333, assigned=777), - ValidatorIndex(1): AttestationsAccumulator(included=167, assigned=223), + (0, 999): { + ValidatorIndex(0): AttestationsAccumulator(included=333, assigned=777), + ValidatorIndex(1): AttestationsAccumulator(included=167, assigned=223), + } } ) @@ -113,27 +109,42 @@ def test_state_add_processed_epoch(): def test_state_inc(): + + frame_0 = (0, 999) + frame_1 = (1000, 1999) + state = State( { - ValidatorIndex(0): AttestationsAccumulator(included=0, assigned=0), - ValidatorIndex(1): AttestationsAccumulator(included=1, assigned=2), + frame_0: { + ValidatorIndex(0): AttestationsAccumulator(included=333, assigned=777), + ValidatorIndex(1): AttestationsAccumulator(included=167, assigned=223), + }, + frame_1: { + ValidatorIndex(0): AttestationsAccumulator(included=1, assigned=1), + ValidatorIndex(1): AttestationsAccumulator(included=0, assigned=1), + }, } ) - state.inc(ValidatorIndex(0), True) - state.inc(ValidatorIndex(0), False) + state.increment_duty(999, ValidatorIndex(0), True) + state.increment_duty(999, ValidatorIndex(0), False) + state.increment_duty(999, ValidatorIndex(1), True) + state.increment_duty(999, ValidatorIndex(1), True) + state.increment_duty(999, ValidatorIndex(1), False) + state.increment_duty(999, ValidatorIndex(2), True) - state.inc(ValidatorIndex(1), True) - state.inc(ValidatorIndex(1), True) - state.inc(ValidatorIndex(1), False) + state.increment_duty(1000, ValidatorIndex(2), False) - state.inc(ValidatorIndex(2), True) - state.inc(ValidatorIndex(2), False) + assert tuple(state.data[frame_0].values()) == ( + AttestationsAccumulator(included=334, 
assigned=779), + AttestationsAccumulator(included=169, assigned=226), + AttestationsAccumulator(included=1, assigned=1), + ) - assert tuple(state.data.values()) == ( - AttestationsAccumulator(included=1, assigned=2), - AttestationsAccumulator(included=3, assigned=5), - AttestationsAccumulator(included=1, assigned=2), + assert tuple(state.data[frame_1].values()) == ( + AttestationsAccumulator(included=1, assigned=1), + AttestationsAccumulator(included=0, assigned=1), + AttestationsAccumulator(included=0, assigned=1), ) @@ -155,7 +166,7 @@ def test_empty_to_new_frame(self): l_epoch = EpochNumber(1) r_epoch = EpochNumber(255) - state.migrate(l_epoch, r_epoch, 1) + state.init_or_migrate(l_epoch, r_epoch, 255, 1) assert not state.is_empty assert state.unprocessed_epochs == set(sequence(l_epoch, r_epoch)) @@ -171,32 +182,60 @@ def test_empty_to_new_frame(self): def test_new_frame_requires_discarding_state(self, l_epoch_old, r_epoch_old, l_epoch_new, r_epoch_new): state = State() state.clear = Mock(side_effect=state.clear) - state.migrate(l_epoch_old, r_epoch_old, 1) + state.init_or_migrate(l_epoch_old, r_epoch_old, r_epoch_old - l_epoch_old + 1, 1) state.clear.assert_not_called() - state.migrate(l_epoch_new, r_epoch_new, 1) + state.init_or_migrate(l_epoch_new, r_epoch_new, r_epoch_new - l_epoch_new + 1, 1) state.clear.assert_called_once() assert state.unprocessed_epochs == set(sequence(l_epoch_new, r_epoch_new)) @pytest.mark.parametrize( - ("l_epoch_old", "r_epoch_old", "l_epoch_new", "r_epoch_new"), + ("l_epoch_old", "r_epoch_old", "l_epoch_new", "r_epoch_new", "epochs_per_frame"), + [ + pytest.param(1, 255, 1, 510, 255, id="Migrate Aa..b..B"), + ], + ) + def test_new_frame_extends_old_state(self, l_epoch_old, r_epoch_old, l_epoch_new, r_epoch_new, epochs_per_frame): + state = State() + state.clear = Mock(side_effect=state.clear) + + state.init_or_migrate(l_epoch_old, r_epoch_old, epochs_per_frame, 1) + state.clear.assert_not_called() + + 
state.init_or_migrate(l_epoch_new, r_epoch_new, epochs_per_frame, 1) + state.clear.assert_not_called() + + assert state.unprocessed_epochs == set(sequence(l_epoch_new, r_epoch_new)) + assert len(state.data) == 2 + assert list(state.data.keys()) == [(l_epoch_old, r_epoch_old), (r_epoch_old + 1, r_epoch_new)] + assert state.calculate_frames(state._epochs_to_process, epochs_per_frame) == [ + (l_epoch_old, r_epoch_old), + (r_epoch_old + 1, r_epoch_new), + ] + + @pytest.mark.parametrize( + ("l_epoch_old", "r_epoch_old", "epochs_per_frame_old", "l_epoch_new", "r_epoch_new", "epochs_per_frame_new"), [ - pytest.param(1, 255, 1, 510, id="Migrate Aa..b..B"), - pytest.param(32, 510, 1, 510, id="Migrate: A..a..b..B"), + pytest.param(32, 510, 479, 1, 510, 510, id="Migrate: A..a..b..B"), ], ) - def test_new_frame_extends_old_state(self, l_epoch_old, r_epoch_old, l_epoch_new, r_epoch_new): + def test_new_frame_extends_old_state_with_single_frame( + self, l_epoch_old, r_epoch_old, epochs_per_frame_old, l_epoch_new, r_epoch_new, epochs_per_frame_new + ): state = State() state.clear = Mock(side_effect=state.clear) - state.migrate(l_epoch_old, r_epoch_old, 1) + state.init_or_migrate(l_epoch_old, r_epoch_old, epochs_per_frame_old, 1) state.clear.assert_not_called() - state.migrate(l_epoch_new, r_epoch_new, 1) + state.init_or_migrate(l_epoch_new, r_epoch_new, epochs_per_frame_new, 1) state.clear.assert_not_called() assert state.unprocessed_epochs == set(sequence(l_epoch_new, r_epoch_new)) + assert len(state.data) == 1 + assert list(state.data.keys())[0] == (l_epoch_new, r_epoch_new) + assert state.calculate_frames(state._epochs_to_process, epochs_per_frame_new) == [(l_epoch_new, r_epoch_new)] @pytest.mark.parametrize( ("old_version", "new_version"), @@ -212,8 +251,8 @@ def test_consensus_version_change(self, old_version, new_version): l_epoch = r_epoch = EpochNumber(255) - state.migrate(l_epoch, r_epoch, old_version) + state.init_or_migrate(l_epoch, r_epoch, 1, old_version) 
state.clear.assert_not_called() - state.migrate(l_epoch, r_epoch, new_version) + state.init_or_migrate(l_epoch, r_epoch, 1, new_version) state.clear.assert_called_once() From 1aa722885ffbda44f09d8417864f62a307c27355 Mon Sep 17 00:00:00 2001 From: vgorkavenko Date: Thu, 13 Feb 2025 12:43:14 +0100 Subject: [PATCH 02/20] refactor: `State` and tests --- src/modules/csm/checkpoint.py | 4 +- src/modules/csm/state.py | 135 ++++---- tests/modules/csm/test_state.py | 592 ++++++++++++++++++++------------ 3 files changed, 445 insertions(+), 286 deletions(-) diff --git a/src/modules/csm/checkpoint.py b/src/modules/csm/checkpoint.py index b111fe197..69d0a79dd 100644 --- a/src/modules/csm/checkpoint.py +++ b/src/modules/csm/checkpoint.py @@ -205,12 +205,12 @@ def _check_duty( for root in block_roots: attestations = self.cc.get_block_attestations(root) process_attestations(attestations, committees, self.eip7549_supported) - + frame = self.state.find_frame(duty_epoch) with lock: for committee in committees.values(): for validator_duty in committee: self.state.increment_duty( - duty_epoch, + frame, validator_duty.index, included=validator_duty.included, ) diff --git a/src/modules/csm/state.py b/src/modules/csm/state.py index fd27a8d62..c269b7fcb 100644 --- a/src/modules/csm/state.py +++ b/src/modules/csm/state.py @@ -13,8 +13,6 @@ logger = logging.getLogger(__name__) -type Frame = tuple[EpochNumber, EpochNumber] - class InvalidState(ValueError): """State has data considered as invalid for a report""" @@ -36,6 +34,10 @@ def add_duty(self, included: bool) -> None: self.included += 1 if included else 0 +type Frame = tuple[EpochNumber, EpochNumber] +type StateData = dict[Frame, defaultdict[ValidatorIndex, AttestationsAccumulator]] + + class State: """ Processing state of a CSM performance oracle frame. @@ -46,7 +48,7 @@ class State: The state can be migrated to be used for another frame's report by calling the `migrate` method. 
""" - data: dict[Frame, defaultdict[ValidatorIndex, AttestationsAccumulator]] + data: StateData _epochs_to_process: tuple[EpochNumber, ...] _processed_epochs: set[EpochNumber] @@ -54,10 +56,8 @@ class State: _consensus_version: int = 1 - def __init__(self, data: dict[Frame, dict[ValidatorIndex, AttestationsAccumulator]] | None = None) -> None: - self.data = { - frame: defaultdict(AttestationsAccumulator, validators) for frame, validators in (data or {}).items() - } + def __init__(self) -> None: + self.data = {} self._epochs_to_process = tuple() self._processed_epochs = set() self._epochs_per_frame = 0 @@ -110,6 +110,16 @@ def unprocessed_epochs(self) -> set[EpochNumber]: def is_fulfilled(self) -> bool: return not self.unprocessed_epochs + @staticmethod + def calculate_frames(epochs_to_process: tuple[EpochNumber, ...], epochs_per_frame: int) -> list[Frame]: + """Split epochs to process into frames of `epochs_per_frame` length""" + frames = [] + for frame_epochs in batched(epochs_to_process, epochs_per_frame): + if len(frame_epochs) < epochs_per_frame: + raise ValueError("Insufficient epochs to form a frame") + frames.append((frame_epochs[0], frame_epochs[-1])) + return frames + def clear(self) -> None: self.data = {} self._epochs_to_process = tuple() @@ -123,9 +133,10 @@ def find_frame(self, epoch: EpochNumber) -> Frame: return epoch_range raise ValueError(f"Epoch {epoch} is out of frames range: {frames}") - def increment_duty(self, epoch: EpochNumber, val_index: ValidatorIndex, included: bool) -> None: - epoch_range = self.find_frame(epoch) - self.data[epoch_range][val_index].add_duty(included) + def increment_duty(self, frame: Frame, val_index: ValidatorIndex, included: bool) -> None: + if frame not in self.data: + raise ValueError(f"Frame {frame} is not found in the state") + self.data[frame][val_index].add_duty(included) def add_processed_epoch(self, epoch: EpochNumber) -> None: self._processed_epochs.add(epoch) @@ -133,7 +144,9 @@ def add_processed_epoch(self, 
epoch: EpochNumber) -> None: def log_progress(self) -> None: logger.info({"msg": f"Processed {len(self._processed_epochs)} of {len(self._epochs_to_process)} epochs"}) - def init_or_migrate(self, l_epoch: EpochNumber, r_epoch: EpochNumber, epochs_per_frame: int, consensus_version: int) -> None: + def init_or_migrate( + self, l_epoch: EpochNumber, r_epoch: EpochNumber, epochs_per_frame: int, consensus_version: int + ) -> None: if consensus_version != self._consensus_version: logger.warning( { @@ -143,59 +156,55 @@ def init_or_migrate(self, l_epoch: EpochNumber, r_epoch: EpochNumber, epochs_per ) self.clear() + frames = self.calculate_frames(tuple(sequence(l_epoch, r_epoch)), epochs_per_frame) + frames_data: StateData = {frame: defaultdict(AttestationsAccumulator) for frame in frames} + if not self.is_empty: - invalidated = self._migrate_or_invalidate(l_epoch, r_epoch, epochs_per_frame) - if invalidated: - self.clear() + cached_frames = self.calculate_frames(self._epochs_to_process, self._epochs_per_frame) + if cached_frames == frames: + logger.info({"msg": "No need to migrate duties data cache"}) + return + + frames_data, migration_status = self._migrate_frames_data(cached_frames, frames) + + for current_frame, migrated in migration_status.items(): + if not migrated: + logger.warning({"msg": f"Invalidating frame duties data cache: {current_frame}"}) + for epoch in sequence(*current_frame): + self._processed_epochs.discard(epoch) - self._fill_frames(l_epoch, r_epoch, epochs_per_frame) + self.data = frames_data self._epochs_per_frame = epochs_per_frame self._epochs_to_process = tuple(sequence(l_epoch, r_epoch)) self._consensus_version = consensus_version self.commit() - def _fill_frames(self, l_epoch: EpochNumber, r_epoch: EpochNumber, epochs_per_frame: int) -> None: - frames = self.calculate_frames(tuple(sequence(l_epoch, r_epoch)), epochs_per_frame) - for frame in frames: - self.data.setdefault(frame, defaultdict(AttestationsAccumulator)) - - def 
_migrate_or_invalidate(self, l_epoch: EpochNumber, r_epoch: EpochNumber, epochs_per_frame: int) -> bool: - current_frames = self.calculate_frames(self._epochs_to_process, self._epochs_per_frame) - new_frames = self.calculate_frames(tuple(sequence(l_epoch, r_epoch)), epochs_per_frame) - inv_msg = f"Discarding invalid state cache because of frames change. {current_frames=}, {new_frames=}" - - if self._invalidate_on_epoch_range_change(l_epoch, r_epoch): - logger.warning({"msg": inv_msg}) - return True - - frame_expanded = epochs_per_frame > self._epochs_per_frame - frame_shrunk = epochs_per_frame < self._epochs_per_frame - - has_single_frame = len(current_frames) == len(new_frames) == 1 - - if has_single_frame and frame_expanded: - current_frame, *_ = current_frames - new_frame, *_ = new_frames - self.data[new_frame] = self.data.pop(current_frame) - logger.info({"msg": f"Migrated state cache to a new frame. {current_frame=}, {new_frame=}"}) - return False - - if has_single_frame and frame_shrunk: - logger.warning({"msg": inv_msg}) - return True - - if not has_single_frame and frame_expanded or frame_shrunk: - logger.warning({"msg": inv_msg}) - return True - - return False - - def _invalidate_on_epoch_range_change(self, l_epoch: EpochNumber, r_epoch: EpochNumber) -> bool: - """Check if the epoch range has been invalidated.""" - for epoch_set in (self._epochs_to_process, self._processed_epochs): - if any(epoch < l_epoch or epoch > r_epoch for epoch in epoch_set): - return True - return False + def _migrate_frames_data( + self, current_frames: list[Frame], new_frames: list[Frame] + ) -> tuple[StateData, dict[Frame, bool]]: + migration_status = {frame: False for frame in current_frames} + new_data: StateData = {frame: defaultdict(AttestationsAccumulator) for frame in new_frames} + + logger.info({"msg": f"Migrating duties data cache: {current_frames=} -> {new_frames=}"}) + + for current_frame in current_frames: + curr_frame_l_epoch, curr_frame_r_epoch = current_frame + for 
new_frame in new_frames: + if current_frame == new_frame: + new_data[new_frame] = self.data[current_frame] + migration_status[current_frame] = True + break + + new_frame_l_epoch, new_frame_r_epoch = new_frame + if curr_frame_l_epoch >= new_frame_l_epoch and curr_frame_r_epoch <= new_frame_r_epoch: + logger.info({"msg": f"Migrating frame duties data cache: {current_frame=} -> {new_frame=}"}) + for val in self.data[current_frame]: + new_data[new_frame][val].assigned += self.data[current_frame][val].assigned + new_data[new_frame][val].included += self.data[current_frame][val].included + migration_status[current_frame] = True + break + + return new_data, migration_status def validate(self, l_epoch: EpochNumber, r_epoch: EpochNumber) -> None: if not self.is_fulfilled: @@ -209,21 +218,11 @@ def validate(self, l_epoch: EpochNumber, r_epoch: EpochNumber) -> None: if epoch not in self._processed_epochs: raise InvalidState(f"Epoch {epoch} missing in processed epochs") - @staticmethod - def calculate_frames(epochs_to_process: tuple[EpochNumber, ...], epochs_per_frame: int) -> list[Frame]: - """Split epochs to process into frames of `epochs_per_frame` length""" - frames = [] - for frame_epochs in batched(epochs_to_process, epochs_per_frame): - if len(frame_epochs) < epochs_per_frame: - raise ValueError("Insufficient epochs to form a frame") - frames.append((frame_epochs[0], frame_epochs[-1])) - return frames - def get_network_aggr(self, frame: Frame) -> AttestationsAccumulator: # TODO: exclude `active_slashed` validators from the calculation included = assigned = 0 frame_data = self.data.get(frame) - if not frame_data: + if frame_data is None: raise ValueError(f"No data for frame {frame} to calculate network aggregate") for validator, acc in frame_data.items(): if acc.included > acc.assigned: diff --git a/tests/modules/csm/test_state.py b/tests/modules/csm/test_state.py index d781522e2..b5d8f8808 100644 --- a/tests/modules/csm/test_state.py +++ 
b/tests/modules/csm/test_state.py @@ -1,258 +1,418 @@ +import os +import pickle +from collections import defaultdict from pathlib import Path from unittest.mock import Mock import pytest -from src.modules.csm.state import AttestationsAccumulator, State -from src.types import EpochNumber, ValidatorIndex +from src import variables +from src.modules.csm.state import AttestationsAccumulator, State, InvalidState +from src.types import ValidatorIndex from src.utils.range import sequence -@pytest.fixture() -def state_file_path(tmp_path: Path) -> Path: - return (tmp_path / "mock").with_suffix(State.EXTENSION) +@pytest.fixture(autouse=True) +def remove_state_files(): + state_file = Path("/tmp/state.pkl") + state_buf = Path("/tmp/state.buf") + state_file.unlink(missing_ok=True) + state_buf.unlink(missing_ok=True) + yield + state_file.unlink(missing_ok=True) + state_buf.unlink(missing_ok=True) + + +def test_load_restores_state_from_file(monkeypatch): + monkeypatch.setattr("src.modules.csm.state.State.file", lambda _=None: Path("/tmp/state.pkl")) + state = State() + state.data = { + (0, 31): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(10, 5)}), + } + state.commit() + loaded_state = State.load() + assert loaded_state.data == state.data -@pytest.fixture(autouse=True) -def mock_state_file(state_file_path: Path): - State.file = Mock(return_value=state_file_path) +def test_load_returns_new_instance_if_file_not_found(monkeypatch): + monkeypatch.setattr("src.modules.csm.state.State.file", lambda: Path("/non/existent/path")) + state = State.load() + assert state.is_empty -def test_attestation_aggregate_perf(): - aggr = AttestationsAccumulator(included=333, assigned=777) - assert aggr.perf == pytest.approx(0.4285, abs=1e-4) +def test_load_returns_new_instance_if_empty_object(monkeypatch, tmp_path): + with open('/tmp/state.pkl', "wb") as f: + pickle.dump(None, f) + monkeypatch.setattr("src.modules.csm.state.State.file", lambda: 
Path("/tmp/state.pkl")) + state = State.load() + assert state.is_empty + + +def test_commit_saves_state_to_file(monkeypatch): + state = State() + state.data = { + (0, 31): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(10, 5)}), + } + monkeypatch.setattr("src.modules.csm.state.State.file", lambda _: Path("/tmp/state.pkl")) + monkeypatch.setattr("os.replace", Mock(side_effect=os.replace)) + state.commit() + with open("/tmp/state.pkl", "rb") as f: + loaded_state = pickle.load(f) + assert loaded_state.data == state.data + os.replace.assert_called_once_with(Path("/tmp/state.buf"), Path("/tmp/state.pkl")) + + +def test_file_returns_correct_path(monkeypatch): + monkeypatch.setattr(variables, "CACHE_PATH", Path("/tmp")) + assert State.file() == Path("/tmp/cache.pkl") + + +def test_buffer_returns_correct_path(monkeypatch): + monkeypatch.setattr(variables, "CACHE_PATH", Path("/tmp")) + state = State() + assert state.buffer == Path("/tmp/cache.buf") + + +def test_is_empty_returns_true_for_empty_state(): + state = State() + assert state.is_empty + + +def test_is_empty_returns_false_for_non_empty_state(): + state = State() + state.data = {(0, 31): defaultdict(AttestationsAccumulator)} + assert not state.is_empty + + +def test_unprocessed_epochs_raises_error_if_epochs_not_set(): + state = State() + with pytest.raises(ValueError, match="Epochs to process are not set"): + state.unprocessed_epochs + + +def test_unprocessed_epochs_returns_correct_set(): + state = State() + state._epochs_to_process = tuple(sequence(0, 95)) + state._processed_epochs = set(sequence(0, 63)) + assert state.unprocessed_epochs == set(sequence(64, 95)) + + +def test_is_fulfilled_returns_true_if_no_unprocessed_epochs(): + state = State() + state._epochs_to_process = tuple(sequence(0, 95)) + state._processed_epochs = set(sequence(0, 95)) + assert state.is_fulfilled + + +def test_is_fulfilled_returns_false_if_unprocessed_epochs_exist(): + state = State() + 
state._epochs_to_process = tuple(sequence(0, 95)) + state._processed_epochs = set(sequence(0, 63)) + assert not state.is_fulfilled + + +def test_calculate_frames_handles_exact_frame_size(): + epochs = tuple(range(10)) + frames = State.calculate_frames(epochs, 5) + assert frames == [(0, 4), (5, 9)] + + +def test_calculate_frames_raises_error_for_insufficient_epochs(): + epochs = tuple(range(8)) + with pytest.raises(ValueError, match="Insufficient epochs to form a frame"): + State.calculate_frames(epochs, 5) + + +def test_clear_resets_state_to_empty(): + state = State() + state.data = {(0, 31): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(10, 5)})} + state.clear() + assert state.is_empty + + +def test_find_frame_returns_correct_frame(): + state = State() + state.data = {(0, 31): defaultdict(AttestationsAccumulator)} + assert state.find_frame(15) == (0, 31) -def test_state_avg_perf(): +def test_find_frame_raises_error_for_out_of_range_epoch(): state = State() + state.data = {(0, 31): defaultdict(AttestationsAccumulator)} + with pytest.raises(ValueError, match="Epoch 32 is out of frames range"): + state.find_frame(32) - frame = (0, 999) - with pytest.raises(ValueError): - state.get_network_aggr(frame) +def test_increment_duty_adds_duty_correctly(): + state = State() + frame = (0, 31) + state.data = { + frame: defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(10, 5)}), + } + state.increment_duty(frame, ValidatorIndex(1), True) + assert state.data[frame][ValidatorIndex(1)].assigned == 11 + assert state.data[frame][ValidatorIndex(1)].included == 6 + +def test_increment_duty_creates_new_validator_entry(): state = State() - state.init_or_migrate(*frame, 1000, 1) + frame = (0, 31) state.data = { - frame: { - ValidatorIndex(0): AttestationsAccumulator(included=0, assigned=0), - ValidatorIndex(1): AttestationsAccumulator(included=0, assigned=0), - } + frame: defaultdict(AttestationsAccumulator), } + 
state.increment_duty(frame, ValidatorIndex(2), True) + assert state.data[frame][ValidatorIndex(2)].assigned == 1 + assert state.data[frame][ValidatorIndex(2)].included == 1 - assert state.get_network_aggr(frame).perf == 0 +def test_increment_duty_handles_non_included_duty(): + state = State() + frame = (0, 31) state.data = { - frame: { - ValidatorIndex(0): AttestationsAccumulator(included=333, assigned=777), - ValidatorIndex(1): AttestationsAccumulator(included=167, assigned=223), - } + frame: defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(10, 5)}), } + state.increment_duty(frame, ValidatorIndex(1), False) + assert state.data[frame][ValidatorIndex(1)].assigned == 11 + assert state.data[frame][ValidatorIndex(1)].included == 5 - assert state.get_network_aggr(frame).perf == 0.5 +def test_increment_duty_raises_error_for_out_of_range_epoch(): + state = State() + state.data = { + (0, 31): defaultdict(AttestationsAccumulator), + } + with pytest.raises(ValueError, match="is not found in the state"): + state.increment_duty((0, 32), ValidatorIndex(1), True) -def test_state_attestations(): - state = State( - { - (0, 999): { - ValidatorIndex(0): AttestationsAccumulator(included=333, assigned=777), - ValidatorIndex(1): AttestationsAccumulator(included=167, assigned=223), - } - } - ) - network_aggr = state.get_network_aggr((0, 999)) +def test_add_processed_epoch_adds_epoch_to_processed_set(): + state = State() + state.add_processed_epoch(5) + assert 5 in state._processed_epochs - assert network_aggr.assigned == 1000 - assert network_aggr.included == 500 +def test_add_processed_epoch_does_not_duplicate_epochs(): + state = State() + state.add_processed_epoch(5) + state.add_processed_epoch(5) + assert len(state._processed_epochs) == 1 -def test_state_load(): - orig = State( - { - (0, 999): { - ValidatorIndex(0): AttestationsAccumulator(included=333, assigned=777), - ValidatorIndex(1): AttestationsAccumulator(included=167, assigned=223), - } - } - ) 
- orig.commit() - copy = State.load() - assert copy.data == orig.data +def test_init_or_migrate_discards_data_on_version_change(): + state = State() + state._consensus_version = 1 + state.clear = Mock() + state.commit = Mock() + state.init_or_migrate(0, 63, 32, 2) + state.clear.assert_called_once() + state.commit.assert_called_once() -def test_state_clear(): - state = State( - { - (0, 999): { - ValidatorIndex(0): AttestationsAccumulator(included=333, assigned=777), - ValidatorIndex(1): AttestationsAccumulator(included=167, assigned=223), - } - } - ) +def test_init_or_migrate_no_migration_needed(): + state = State() + state._consensus_version = 1 + state._epochs_to_process = tuple(sequence(0, 63)) + state._epochs_per_frame = 32 + state.data = { + (0, 31): defaultdict(AttestationsAccumulator), + (32, 63): defaultdict(AttestationsAccumulator), + } + state.commit = Mock() + state.init_or_migrate(0, 63, 32, 1) + state.commit.assert_not_called() - state._epochs_to_process = (EpochNumber(1), EpochNumber(33)) - state._processed_epochs = {EpochNumber(42), EpochNumber(17)} - state.clear() - assert state.is_empty - assert not state.data +def test_init_or_migrate_migrates_data(): + state = State() + state._consensus_version = 1 + state._epochs_to_process = tuple(sequence(0, 63)) + state._epochs_per_frame = 32 + state.data = { + (0, 31): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(10, 5)}), + (32, 63): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(20, 15)}), + } + state.commit = Mock() + state.init_or_migrate(0, 63, 64, 1) + assert state.data == { + (0, 63): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(30, 20)}), + } + state.commit.assert_called_once() + + +def test_init_or_migrate_invalidates_unmigrated_frames(): + state = State() + state._consensus_version = 1 + state._epochs_to_process = tuple(sequence(0, 63)) + state._epochs_per_frame = 64 + state.data = { + (0, 63): 
defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(30, 20)}), + } + state.commit = Mock() + state.init_or_migrate(0, 31, 32, 1) + assert state.data == { + (0, 31): defaultdict(AttestationsAccumulator), + } + assert state._processed_epochs == set() + state.commit.assert_called_once() + + +def test_init_or_migrate_discards_unmigrated_frame(): + state = State() + state._consensus_version = 1 + state._epochs_to_process = tuple(sequence(0, 95)) + state._epochs_per_frame = 32 + state.data = { + (0, 31): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(10, 5)}), + (32, 63): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(20, 15)}), + (64, 95): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(30, 25)}), + } + state._processed_epochs = set(sequence(0, 95)) + state.commit = Mock() + state.init_or_migrate(0, 63, 32, 1) + assert state.data == { + (0, 31): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(10, 5)}), + (32, 63): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(20, 15)}), + } + assert state._processed_epochs == set(sequence(0, 63)) + state.commit.assert_called_once() + + +def test_migrate_frames_data_creates_new_data_correctly(): + state = State() + current_frames = [(0, 31), (32, 63)] + new_frames = [(0, 63)] + state.data = { + (0, 31): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(10, 5)}), + (32, 63): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(20, 15)}), + } + new_data, migration_status = state._migrate_frames_data(current_frames, new_frames) + assert new_data == { + (0, 63): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(30, 20)}) + } + assert migration_status == {(0, 31): True, (32, 63): True} + + +def test_migrate_frames_data_handles_no_migration(): + state = State() 
+ current_frames = [(0, 31)] + new_frames = [(0, 31)] + state.data = { + (0, 31): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(10, 5)}), + } + new_data, migration_status = state._migrate_frames_data(current_frames, new_frames) + assert new_data == { + (0, 31): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(10, 5)}) + } + assert migration_status == {(0, 31): True} + + +def test_migrate_frames_data_handles_partial_migration(): + state = State() + current_frames = [(0, 31), (32, 63)] + new_frames = [(0, 31), (32, 95)] + state.data = { + (0, 31): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(10, 5)}), + (32, 63): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(20, 15)}), + } + new_data, migration_status = state._migrate_frames_data(current_frames, new_frames) + assert new_data == { + (0, 31): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(10, 5)}), + (32, 95): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(20, 15)}), + } + assert migration_status == {(0, 31): True, (32, 63): True} + + +def test_migrate_frames_data_handles_no_data(): + state = State() + current_frames = [(0, 31)] + new_frames = [(0, 31)] + state.data = {frame: defaultdict(AttestationsAccumulator) for frame in current_frames} + new_data, migration_status = state._migrate_frames_data(current_frames, new_frames) + assert new_data == {(0, 31): defaultdict(AttestationsAccumulator)} + assert migration_status == {(0, 31): True} + + +def test_migrate_frames_data_handles_wider_old_frame(): + state = State() + current_frames = [(0, 63)] + new_frames = [(0, 31), (32, 63)] + state.data = { + (0, 63): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(30, 20)}), + } + new_data, migration_status = state._migrate_frames_data(current_frames, new_frames) + assert new_data == { + (0, 31): 
defaultdict(AttestationsAccumulator), + (32, 63): defaultdict(AttestationsAccumulator), + } + assert migration_status == {(0, 63): False} + + +def test_validate_raises_error_if_state_not_fulfilled(): + state = State() + state._epochs_to_process = tuple(sequence(0, 95)) + state._processed_epochs = set(sequence(0, 94)) + with pytest.raises(InvalidState, match="State is not fulfilled"): + state.validate(0, 95) + + +def test_validate_raises_error_if_processed_epoch_out_of_range(): + state = State() + state._epochs_to_process = tuple(sequence(0, 95)) + state._processed_epochs = set(sequence(0, 95)) + state._processed_epochs.add(96) + with pytest.raises(InvalidState, match="Processed epoch 96 is out of range"): + state.validate(0, 95) + + +def test_validate_raises_error_if_epoch_missing_in_processed_epochs(): + state = State() + state._epochs_to_process = tuple(sequence(0, 94)) + state._processed_epochs = set(sequence(0, 94)) + with pytest.raises(InvalidState, match="Epoch 95 missing in processed epochs"): + state.validate(0, 95) -def test_state_add_processed_epoch(): +def test_validate_passes_for_fulfilled_state(): state = State() - state.add_processed_epoch(EpochNumber(42)) - state.add_processed_epoch(EpochNumber(17)) - assert state._processed_epochs == {EpochNumber(42), EpochNumber(17)} + state._epochs_to_process = tuple(sequence(0, 95)) + state._processed_epochs = set(sequence(0, 95)) + state.validate(0, 95) -def test_state_inc(): - - frame_0 = (0, 999) - frame_1 = (1000, 1999) - - state = State( - { - frame_0: { - ValidatorIndex(0): AttestationsAccumulator(included=333, assigned=777), - ValidatorIndex(1): AttestationsAccumulator(included=167, assigned=223), - }, - frame_1: { - ValidatorIndex(0): AttestationsAccumulator(included=1, assigned=1), - ValidatorIndex(1): AttestationsAccumulator(included=0, assigned=1), - }, - } - ) - - state.increment_duty(999, ValidatorIndex(0), True) - state.increment_duty(999, ValidatorIndex(0), False) - state.increment_duty(999, 
ValidatorIndex(1), True) - state.increment_duty(999, ValidatorIndex(1), True) - state.increment_duty(999, ValidatorIndex(1), False) - state.increment_duty(999, ValidatorIndex(2), True) - - state.increment_duty(1000, ValidatorIndex(2), False) - - assert tuple(state.data[frame_0].values()) == ( - AttestationsAccumulator(included=334, assigned=779), - AttestationsAccumulator(included=169, assigned=226), - AttestationsAccumulator(included=1, assigned=1), - ) - - assert tuple(state.data[frame_1].values()) == ( - AttestationsAccumulator(included=1, assigned=1), - AttestationsAccumulator(included=0, assigned=1), - AttestationsAccumulator(included=0, assigned=1), - ) - - -def test_state_file_is_path(): - assert isinstance(State.file(), Path) - - -class TestStateTransition: - """Tests for State's transition for different l_epoch, r_epoch values""" - - @pytest.fixture(autouse=True) - def no_commit(self, monkeypatch: pytest.MonkeyPatch): - monkeypatch.setattr(State, "commit", Mock()) - - def test_empty_to_new_frame(self): - state = State() - assert state.is_empty - - l_epoch = EpochNumber(1) - r_epoch = EpochNumber(255) - - state.init_or_migrate(l_epoch, r_epoch, 255, 1) - - assert not state.is_empty - assert state.unprocessed_epochs == set(sequence(l_epoch, r_epoch)) - - @pytest.mark.parametrize( - ("l_epoch_old", "r_epoch_old", "l_epoch_new", "r_epoch_new"), - [ - pytest.param(1, 255, 256, 510, id="Migrate a..bA..B"), - pytest.param(1, 255, 32, 510, id="Migrate a..A..b..B"), - pytest.param(32, 510, 1, 255, id="Migrate: A..a..B..b"), - ], - ) - def test_new_frame_requires_discarding_state(self, l_epoch_old, r_epoch_old, l_epoch_new, r_epoch_new): - state = State() - state.clear = Mock(side_effect=state.clear) - state.init_or_migrate(l_epoch_old, r_epoch_old, r_epoch_old - l_epoch_old + 1, 1) - state.clear.assert_not_called() - - state.init_or_migrate(l_epoch_new, r_epoch_new, r_epoch_new - l_epoch_new + 1, 1) - state.clear.assert_called_once() - - assert 
state.unprocessed_epochs == set(sequence(l_epoch_new, r_epoch_new)) - - @pytest.mark.parametrize( - ("l_epoch_old", "r_epoch_old", "l_epoch_new", "r_epoch_new", "epochs_per_frame"), - [ - pytest.param(1, 255, 1, 510, 255, id="Migrate Aa..b..B"), - ], - ) - def test_new_frame_extends_old_state(self, l_epoch_old, r_epoch_old, l_epoch_new, r_epoch_new, epochs_per_frame): - state = State() - state.clear = Mock(side_effect=state.clear) - - state.init_or_migrate(l_epoch_old, r_epoch_old, epochs_per_frame, 1) - state.clear.assert_not_called() - - state.init_or_migrate(l_epoch_new, r_epoch_new, epochs_per_frame, 1) - state.clear.assert_not_called() - - assert state.unprocessed_epochs == set(sequence(l_epoch_new, r_epoch_new)) - assert len(state.data) == 2 - assert list(state.data.keys()) == [(l_epoch_old, r_epoch_old), (r_epoch_old + 1, r_epoch_new)] - assert state.calculate_frames(state._epochs_to_process, epochs_per_frame) == [ - (l_epoch_old, r_epoch_old), - (r_epoch_old + 1, r_epoch_new), - ] - - @pytest.mark.parametrize( - ("l_epoch_old", "r_epoch_old", "epochs_per_frame_old", "l_epoch_new", "r_epoch_new", "epochs_per_frame_new"), - [ - pytest.param(32, 510, 479, 1, 510, 510, id="Migrate: A..a..b..B"), - ], - ) - def test_new_frame_extends_old_state_with_single_frame( - self, l_epoch_old, r_epoch_old, epochs_per_frame_old, l_epoch_new, r_epoch_new, epochs_per_frame_new - ): - state = State() - state.clear = Mock(side_effect=state.clear) - - state.init_or_migrate(l_epoch_old, r_epoch_old, epochs_per_frame_old, 1) - state.clear.assert_not_called() - - state.init_or_migrate(l_epoch_new, r_epoch_new, epochs_per_frame_new, 1) - state.clear.assert_not_called() - - assert state.unprocessed_epochs == set(sequence(l_epoch_new, r_epoch_new)) - assert len(state.data) == 1 - assert list(state.data.keys())[0] == (l_epoch_new, r_epoch_new) - assert state.calculate_frames(state._epochs_to_process, epochs_per_frame_new) == [(l_epoch_new, r_epoch_new)] - - @pytest.mark.parametrize( - 
("old_version", "new_version"), - [ - pytest.param(2, 3, id="Increase consensus version"), - pytest.param(3, 2, id="Decrease consensus version"), - ], - ) - def test_consensus_version_change(self, old_version, new_version): - state = State() - state.clear = Mock(side_effect=state.clear) - state._consensus_version = old_version - - l_epoch = r_epoch = EpochNumber(255) - - state.init_or_migrate(l_epoch, r_epoch, 1, old_version) - state.clear.assert_not_called() - - state.init_or_migrate(l_epoch, r_epoch, 1, new_version) - state.clear.assert_called_once() +def test_attestation_aggregate_perf(): + aggr = AttestationsAccumulator(included=333, assigned=777) + assert aggr.perf == pytest.approx(0.4285, abs=1e-4) + + +def test_get_network_aggr_computes_correctly(): + state = State() + state.data = { + (0, 31): defaultdict( + AttestationsAccumulator, + {ValidatorIndex(1): AttestationsAccumulator(10, 5), ValidatorIndex(2): AttestationsAccumulator(20, 15)}, + ) + } + aggr = state.get_network_aggr((0, 31)) + assert aggr.assigned == 30 + assert aggr.included == 20 + + +def test_get_network_aggr_raises_error_for_invalid_accumulator(): + state = State() + state.data = {(0, 31): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(10, 15)})} + with pytest.raises(ValueError, match="Invalid accumulator"): + state.get_network_aggr((0, 31)) + + +def test_get_network_aggr_raises_error_for_missing_frame_data(): + state = State() + with pytest.raises(ValueError, match="No data for frame"): + state.get_network_aggr((0, 31)) + + +def test_get_network_aggr_handles_empty_frame_data(): + state = State() + state.data = {(0, 31): defaultdict(AttestationsAccumulator)} + aggr = state.get_network_aggr((0, 31)) + assert aggr.assigned == 0 + assert aggr.included == 0 From 8163d4d9c102c52ecfadc90c411795fa5e205936 Mon Sep 17 00:00:00 2001 From: vgorkavenko Date: Thu, 13 Feb 2025 17:37:51 +0100 Subject: [PATCH 03/20] refactor: distribution and tests --- src/modules/csm/csm.py 
| 173 ++++++---- src/modules/csm/log.py | 3 +- src/modules/csm/state.py | 11 +- .../execution/contracts/cs_fee_distributor.py | 4 +- tests/modules/csm/test_csm_distribution.py | 325 ++++++++++++++++++ tests/modules/csm/test_csm_module.py | 257 -------------- tests/modules/csm/test_state.py | 4 + 7 files changed, 441 insertions(+), 336 deletions(-) create mode 100644 tests/modules/csm/test_csm_distribution.py diff --git a/src/modules/csm/csm.py b/src/modules/csm/csm.py index 3d3619a94..ed146677e 100644 --- a/src/modules/csm/csm.py +++ b/src/modules/csm/csm.py @@ -12,8 +12,8 @@ ) from src.metrics.prometheus.duration_meter import duration_meter from src.modules.csm.checkpoint import FrameCheckpointProcessor, FrameCheckpointsIterator, MinStepIsNotReached -from src.modules.csm.log import FramePerfLog -from src.modules.csm.state import State, Frame +from src.modules.csm.log import FramePerfLog, OperatorFrameSummary +from src.modules.csm.state import State, Frame, AttestationsAccumulator from src.modules.csm.tree import Tree from src.modules.csm.types import ReportData, Shares from src.modules.submodules.consensus import ConsensusModule @@ -29,13 +29,12 @@ SlotNumber, StakingModuleAddress, StakingModuleId, - ValidatorIndex, ) from src.utils.blockstamp import build_blockstamp from src.utils.cache import global_lru_cache as lru_cache from src.utils.slot import get_next_non_missed_slot, get_reference_blockstamp from src.utils.web3converter import Web3Converter -from src.web3py.extensions.lido_validators import NodeOperatorId, StakingModule, ValidatorsByNodeOperator +from src.web3py.extensions.lido_validators import NodeOperatorId, StakingModule, ValidatorsByNodeOperator, LidoValidator from src.web3py.types import Web3 logger = logging.getLogger(__name__) @@ -102,15 +101,15 @@ def build_report(self, blockstamp: ReferenceBlockStamp) -> tuple: if (prev_cid is None) != (prev_root == ZERO_HASH): raise InconsistentData(f"Got inconsistent previous tree data: {prev_root=} 
{prev_cid=}") - distributed, shares, logs = self.calculate_distribution(blockstamp) + total_distributed, total_rewards, logs = self.calculate_distribution(blockstamp) - if distributed != sum(shares.values()): - raise InconsistentData(f"Invalid distribution: {sum(shares.values())=} != {distributed=}") + if total_distributed != sum(total_rewards.values()): + raise InconsistentData(f"Invalid distribution: {sum(total_rewards.values())=} != {total_distributed=}") log_cid = self.publish_log(logs) - if not distributed and not shares: - logger.info({"msg": "No shares distributed in the current frame"}) + if not total_distributed and not total_rewards: + logger.info({"msg": "No rewards distributed in the current frame"}) return ReportData( self.get_consensus_version(blockstamp), blockstamp.ref_slot, @@ -123,11 +122,11 @@ def build_report(self, blockstamp: ReferenceBlockStamp) -> tuple: if prev_cid and prev_root != ZERO_HASH: # Update cumulative amount of shares for all operators. for no_id, acc_shares in self.get_accumulated_shares(prev_cid, prev_root): - shares[no_id] += acc_shares + total_rewards[no_id] += acc_shares else: logger.info({"msg": "No previous distribution. Nothing to accumulate"}) - tree = self.make_tree(shares) + tree = self.make_tree(total_rewards) tree_cid = self.publish_tree(tree) return ReportData( @@ -136,7 +135,7 @@ def build_report(self, blockstamp: ReferenceBlockStamp) -> tuple: tree_root=tree.root, tree_cid=tree_cid, log_cid=log_cid, - distributed=distributed, + distributed=total_distributed, ).as_tuple() def is_main_data_submitted(self, blockstamp: BlockStamp) -> bool: @@ -232,26 +231,36 @@ def calculate_distribution( """Computes distribution of fee shares at the given timestamp""" operators_to_validators = self.module_validators_by_node_operators(blockstamp) - distributed = 0 - # Calculate share of each CSM node operator. 
- shares = defaultdict[NodeOperatorId, int](int) + total_distributed = 0 + total_rewards = defaultdict[NodeOperatorId, int](int) logs: list[FramePerfLog] = [] - for frame in self.state.data: + for frame in self.state.frames: from_epoch, to_epoch = frame logger.info({"msg": f"Calculating distribution for frame [{from_epoch};{to_epoch}]"}) + frame_blockstamp = blockstamp if to_epoch != blockstamp.ref_epoch: frame_blockstamp = self._get_ref_blockstamp_for_frame(blockstamp, to_epoch) - distributed_in_frame, shares_in_frame, log = self._calculate_distribution_in_frame( - frame_blockstamp, operators_to_validators, frame, distributed + + total_rewards_to_distribute = self.w3.csm.fee_distributor.shares_to_distribute(frame_blockstamp.block_hash) + rewards_to_distribute_in_frame = total_rewards_to_distribute - total_distributed + + rewards_in_frame, log = self._calculate_distribution_in_frame( + frame, frame_blockstamp, rewards_to_distribute_in_frame, operators_to_validators ) - distributed += distributed_in_frame - for no_id, share in shares_in_frame.items(): - shares[no_id] += share + distributed_in_frame = sum(rewards_in_frame.values()) + + total_distributed += distributed_in_frame + if total_distributed > total_rewards_to_distribute: + raise CSMError(f"Invalid distribution: {total_distributed=} > {total_rewards_to_distribute=}") + + for no_id, rewards in rewards_in_frame.items(): + total_rewards[no_id] += rewards + logs.append(log) - return distributed, shares, logs + return total_distributed, total_rewards, logs def _get_ref_blockstamp_for_frame( self, blockstamp: ReferenceBlockStamp, frame_ref_epoch: EpochNumber @@ -266,63 +275,85 @@ def _get_ref_blockstamp_for_frame( def _calculate_distribution_in_frame( self, - blockstamp: ReferenceBlockStamp, - operators_to_validators: ValidatorsByNodeOperator, frame: Frame, - distributed: int, + blockstamp: ReferenceBlockStamp, + rewards_to_distribute: int, + operators_to_validators: ValidatorsByNodeOperator ): - network_perf = 
self.state.get_network_aggr(frame).perf - threshold = network_perf - self.w3.csm.oracle.perf_leeway_bp(blockstamp.block_hash) / TOTAL_BASIS_POINTS - - # Build the map of the current distribution operators. - distribution: dict[NodeOperatorId, int] = defaultdict(int) - stuck_operators = self.stuck_operators(blockstamp) + threshold = self._get_performance_threshold(frame, blockstamp) log = FramePerfLog(blockstamp, frame, threshold) + participation_shares: defaultdict[NodeOperatorId, int] = defaultdict(int) + + stuck_operators = self.stuck_operators(blockstamp) for (_, no_id), validators in operators_to_validators.items(): + log_operator = log.operators[no_id] if no_id in stuck_operators: - log.operators[no_id].stuck = True + log_operator.stuck = True + continue + for validator in validators: + duty = self.state.data[frame].get(validator.index) + self.process_validator_duty(validator, duty, threshold, participation_shares, log_operator) + + rewards_distribution = self.calc_rewards_distribution_in_frame(participation_shares, rewards_to_distribute) + + for no_id, no_rewards in rewards_distribution.items(): + log.operators[no_id].distributed = no_rewards + + log.distributable = rewards_to_distribute + + return rewards_distribution, log + + def _get_performance_threshold(self, frame: Frame, blockstamp: ReferenceBlockStamp) -> float: + network_perf = self.state.get_network_aggr(frame).perf + perf_leeway = self.w3.csm.oracle.perf_leeway_bp(blockstamp.block_hash) / TOTAL_BASIS_POINTS + threshold = network_perf - perf_leeway + return threshold + + @staticmethod + def process_validator_duty( + validator: LidoValidator, + attestation_duty: AttestationsAccumulator | None, + threshold: float, + participation_shares: defaultdict[NodeOperatorId, int], + log_operator: OperatorFrameSummary + ): + if attestation_duty is None: + # It's possible that the validator is not assigned to any duty, hence it's performance + # is not presented in the aggregates (e.g. 
exited, pending for activation etc). + # TODO: check `sync_aggr` to strike (in case of bad sync performance) after validator exit + return + + log_validator = log_operator.validators[validator.index] + + if validator.validator.slashed is True: + # It means that validator was active during the frame and got slashed and didn't meet the exit + # epoch, so we should not count such validator for operator's share. + log_validator.slashed = True + return + + if attestation_duty.perf > threshold: + # Count of assigned attestations used as a metrics of time + # the validator was active in the current frame. + participation_shares[validator.lido_id.operatorIndex] += attestation_duty.assigned + + log_validator.attestation_duty = attestation_duty + + @staticmethod + def calc_rewards_distribution_in_frame( + participation_shares: dict[NodeOperatorId, int], + rewards_to_distribute: int, + ) -> dict[NodeOperatorId, int]: + rewards_distribution: dict[NodeOperatorId, int] = defaultdict(int) + total_participation = sum(participation_shares.values()) + + for no_id, no_participation_share in participation_shares.items(): + if no_participation_share == 0: + # Skip operators with zero participation continue + rewards_distribution[no_id] = rewards_to_distribute * no_participation_share // total_participation - for v in validators: - aggr = self.state.data[frame].get(ValidatorIndex(int(v.index))) - - if aggr is None: - # It's possible that the validator is not assigned to any duty, hence it's performance - # is not presented in the aggregates (e.g. exited, pending for activation etc). - continue - - if v.validator.slashed is True: - # It means that validator was active during the frame and got slashed and didn't meet the exit - # epoch, so we should not count such validator for operator's share. 
- log.operators[no_id].validators[v.index].slashed = True - continue - - if aggr.perf > threshold: - # Count of assigned attestations used as a metrics of time - # the validator was active in the current frame. - distribution[no_id] += aggr.assigned - - log.operators[no_id].validators[v.index].perf = aggr - - # Calculate share of each CSM node operator. - shares = defaultdict[NodeOperatorId, int](int) - total = sum(p for p in distribution.values()) - to_distribute = self.w3.csm.fee_distributor.shares_to_distribute(blockstamp.block_hash) - distributed - log.distributable = to_distribute - - if not total: - return 0, shares, log - - for no_id, no_share in distribution.items(): - if no_share: - shares[no_id] = to_distribute * no_share // total - log.operators[no_id].distributed = shares[no_id] - - distributed = sum(s for s in shares.values()) - if distributed > to_distribute: - raise CSMError(f"Invalid distribution: {distributed=} > {to_distribute=}") - return distributed, shares, log + return rewards_distribution def get_accumulated_shares(self, cid: CID, root: HexBytes) -> Iterator[tuple[NodeOperatorId, Shares]]: logger.info({"msg": "Fetching tree by CID from IPFS", "cid": repr(cid)}) diff --git a/src/modules/csm/log.py b/src/modules/csm/log.py index 39832c8c0..29ab24902 100644 --- a/src/modules/csm/log.py +++ b/src/modules/csm/log.py @@ -12,8 +12,7 @@ class LogJSONEncoder(json.JSONEncoder): ... @dataclass class ValidatorFrameSummary: - # TODO: Should be renamed. 
Perf means different things in different contexts - perf: AttestationsAccumulator = field(default_factory=AttestationsAccumulator) + attestation_duty: AttestationsAccumulator = field(default_factory=AttestationsAccumulator) slashed: bool = False diff --git a/src/modules/csm/state.py b/src/modules/csm/state.py index c269b7fcb..e6fb7d866 100644 --- a/src/modules/csm/state.py +++ b/src/modules/csm/state.py @@ -110,6 +110,10 @@ def unprocessed_epochs(self) -> set[EpochNumber]: def is_fulfilled(self) -> bool: return not self.unprocessed_epochs + @property + def frames(self): + return self.calculate_frames(self._epochs_to_process, self._epochs_per_frame) + @staticmethod def calculate_frames(epochs_to_process: tuple[EpochNumber, ...], epochs_per_frame: int) -> list[Frame]: """Split epochs to process into frames of `epochs_per_frame` length""" @@ -127,11 +131,10 @@ def clear(self) -> None: assert self.is_empty def find_frame(self, epoch: EpochNumber) -> Frame: - frames = self.data.keys() - for epoch_range in frames: + for epoch_range in self.frames: if epoch_range[0] <= epoch <= epoch_range[1]: return epoch_range - raise ValueError(f"Epoch {epoch} is out of frames range: {frames}") + raise ValueError(f"Epoch {epoch} is out of frames range: {self.frames}") def increment_duty(self, frame: Frame, val_index: ValidatorIndex, included: bool) -> None: if frame not in self.data: @@ -160,7 +163,7 @@ def init_or_migrate( frames_data: StateData = {frame: defaultdict(AttestationsAccumulator) for frame in frames} if not self.is_empty: - cached_frames = self.calculate_frames(self._epochs_to_process, self._epochs_per_frame) + cached_frames = self.frames if cached_frames == frames: logger.info({"msg": "No need to migrate duties data cache"}) return diff --git a/src/providers/execution/contracts/cs_fee_distributor.py b/src/providers/execution/contracts/cs_fee_distributor.py index 937c7dc49..8a556b250 100644 --- a/src/providers/execution/contracts/cs_fee_distributor.py +++ 
b/src/providers/execution/contracts/cs_fee_distributor.py @@ -3,7 +3,7 @@ from eth_typing import ChecksumAddress from hexbytes import HexBytes from web3 import Web3 -from web3.types import BlockIdentifier +from web3.types import BlockIdentifier, Wei from ..base_interface import ContractInterface @@ -26,7 +26,7 @@ def oracle(self, block_identifier: BlockIdentifier = "latest") -> ChecksumAddres ) return Web3.to_checksum_address(resp) - def shares_to_distribute(self, block_identifier: BlockIdentifier = "latest") -> int: + def shares_to_distribute(self, block_identifier: BlockIdentifier = "latest") -> Wei: """Returns the amount of shares that are pending to be distributed""" resp = self.functions.pendingSharesToDistribute().call(block_identifier=block_identifier) diff --git a/tests/modules/csm/test_csm_distribution.py b/tests/modules/csm/test_csm_distribution.py new file mode 100644 index 000000000..cb07c2e95 --- /dev/null +++ b/tests/modules/csm/test_csm_distribution.py @@ -0,0 +1,325 @@ +from collections import defaultdict +from unittest.mock import Mock + +import pytest +from web3.types import Wei + +from src.constants import UINT64_MAX +from src.modules.csm.csm import CSOracle, CSMError +from src.modules.csm.log import ValidatorFrameSummary, OperatorFrameSummary +from src.modules.csm.state import AttestationsAccumulator, State +from src.types import NodeOperatorId, ValidatorIndex +from src.web3py.extensions import CSM +from tests.factory.no_registry import LidoValidatorFactory + + +@pytest.fixture(autouse=True) +def mock_get_module_id(monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr(CSOracle, "_get_module_id", Mock()) + + +@pytest.fixture() +def module(web3, csm: CSM): + yield CSOracle(web3) + + +def test_calculate_distribution_handles_single_frame(module): + module.state = Mock() + module.state.frames = [(1, 2)] + blockstamp = Mock() + module.module_validators_by_node_operators = Mock() + module._get_ref_blockstamp_for_frame = Mock(return_value=blockstamp) 
+ module.w3.csm.fee_distributor.shares_to_distribute = Mock(return_value=500) + module._calculate_distribution_in_frame = Mock(return_value=({NodeOperatorId(1): 500}, Mock())) + + total_distributed, total_rewards, logs = module.calculate_distribution(blockstamp) + + assert total_distributed == 500 + assert total_rewards[NodeOperatorId(1)] == 500 + assert len(logs) == 1 + + +def test_calculate_distribution_handles_multiple_frames(module): + module.state = Mock() + module.state.frames = [(1, 2), (3, 4), (5, 6)] + blockstamp = Mock() + module.module_validators_by_node_operators = Mock() + module._get_ref_blockstamp_for_frame = Mock(return_value=blockstamp) + module.w3.csm.fee_distributor.shares_to_distribute = Mock(return_value=800) + module._calculate_distribution_in_frame = Mock( + side_effect=[ + ({NodeOperatorId(1): 500}, Mock()), + ({NodeOperatorId(1): 136}, Mock()), + ({NodeOperatorId(1): 164}, Mock()), + ] + ) + + total_distributed, total_rewards, logs = module.calculate_distribution(blockstamp) + + assert total_distributed == 800 + assert total_rewards[NodeOperatorId(1)] == 800 + assert len(logs) == 3 + + +def test_calculate_distribution_handles_invalid_distribution(module): + module.state = Mock() + module.state.frames = [(1, 2)] + blockstamp = Mock() + module.module_validators_by_node_operators = Mock() + module._get_ref_blockstamp_for_frame = Mock(return_value=blockstamp) + module.w3.csm.fee_distributor.shares_to_distribute = Mock(return_value=500) + module._calculate_distribution_in_frame = Mock(return_value=({NodeOperatorId(1): 600}, Mock())) + + with pytest.raises(CSMError, match="Invalid distribution"): + module.calculate_distribution(blockstamp) + + +def test_calculate_distribution_in_frame_handles_stuck_operator(module): + frame = Mock() + blockstamp = Mock() + rewards_to_distribute = UINT64_MAX + operators_to_validators = {(Mock(), NodeOperatorId(1)): [LidoValidatorFactory.build()]} + module.state = State() + module.state.data = {frame: 
defaultdict(AttestationsAccumulator)} + module.stuck_operators = Mock(return_value={NodeOperatorId(1)}) + module._get_performance_threshold = Mock() + + rewards_distribution, log = module._calculate_distribution_in_frame( + frame, blockstamp, rewards_to_distribute, operators_to_validators + ) + + assert rewards_distribution[NodeOperatorId(1)] == 0 + assert log.operators[NodeOperatorId(1)].stuck is True + assert log.operators[NodeOperatorId(1)].distributed == 0 + assert log.operators[NodeOperatorId(1)].validators == defaultdict(ValidatorFrameSummary) + + +def test_calculate_distribution_in_frame_handles_no_attestation_duty(module): + frame = Mock() + blockstamp = Mock() + rewards_to_distribute = UINT64_MAX + validator = LidoValidatorFactory.build() + node_operator_id = validator.lido_id.operatorIndex + operators_to_validators = {(Mock(), node_operator_id): [validator]} + module.state = State() + module.state.data = {frame: defaultdict(AttestationsAccumulator)} + module.stuck_operators = Mock(return_value=set()) + module._get_performance_threshold = Mock() + + rewards_distribution, log = module._calculate_distribution_in_frame( + frame, blockstamp, rewards_to_distribute, operators_to_validators + ) + + assert rewards_distribution[node_operator_id] == 0 + assert log.operators[node_operator_id].stuck is False + assert log.operators[node_operator_id].distributed == 0 + assert log.operators[node_operator_id].validators == defaultdict(ValidatorFrameSummary) + + +def test_calculate_distribution_in_frame_handles_above_threshold_performance(module): + frame = Mock() + blockstamp = Mock() + rewards_to_distribute = UINT64_MAX + validator = LidoValidatorFactory.build() + validator.validator.slashed = False + node_operator_id = validator.lido_id.operatorIndex + operators_to_validators = {(Mock(), node_operator_id): [validator]} + module.state = State() + attestation_duty = AttestationsAccumulator(assigned=10, included=6) + module.state.data = {frame: {validator.index: 
attestation_duty}} + module.stuck_operators = Mock(return_value=set()) + module._get_performance_threshold = Mock(return_value=0.5) + + rewards_distribution, log = module._calculate_distribution_in_frame( + frame, blockstamp, rewards_to_distribute, operators_to_validators + ) + + assert rewards_distribution[node_operator_id] > 0 # no need to check exact value + assert log.operators[node_operator_id].stuck is False + assert log.operators[node_operator_id].distributed > 0 + assert log.operators[node_operator_id].validators[validator.index].attestation_duty == attestation_duty + + +def test_calculate_distribution_in_frame_handles_below_threshold_performance(module): + frame = Mock() + blockstamp = Mock() + rewards_to_distribute = UINT64_MAX + validator = LidoValidatorFactory.build() + validator.validator.slashed = False + node_operator_id = validator.lido_id.operatorIndex + operators_to_validators = {(Mock(), node_operator_id): [validator]} + module.state = State() + attestation_duty = AttestationsAccumulator(assigned=10, included=5) + module.state.data = {frame: {validator.index: attestation_duty}} + module.stuck_operators = Mock(return_value=set()) + module._get_performance_threshold = Mock(return_value=0.5) + + rewards_distribution, log = module._calculate_distribution_in_frame( + frame, blockstamp, rewards_to_distribute, operators_to_validators + ) + + assert rewards_distribution[node_operator_id] == 0 + assert log.operators[node_operator_id].stuck is False + assert log.operators[node_operator_id].distributed == 0 + assert log.operators[node_operator_id].validators[validator.index].attestation_duty == attestation_duty + + +def test_performance_threshold_calculates_correctly(module): + state = State() + state.data = { + (0, 31): { + ValidatorIndex(1): AttestationsAccumulator(10, 10), + ValidatorIndex(2): AttestationsAccumulator(10, 10), + }, + } + module.w3.csm.oracle.perf_leeway_bp.return_value = 500 + module.state = state + + threshold = 
module._get_performance_threshold((0, 31), Mock()) + + assert threshold == 0.95 + + +def test_performance_threshold_handles_zero_leeway(module): + state = State() + state.data = { + (0, 31): { + ValidatorIndex(1): AttestationsAccumulator(10, 10), + ValidatorIndex(2): AttestationsAccumulator(10, 10), + }, + } + module.w3.csm.oracle.perf_leeway_bp.return_value = 0 + module.state = state + + threshold = module._get_performance_threshold((0, 31), Mock()) + + assert threshold == 1.0 + + +def test_performance_threshold_handles_high_leeway(module): + state = State() + state.data = { + (0, 31): {ValidatorIndex(1): AttestationsAccumulator(10, 1), ValidatorIndex(2): AttestationsAccumulator(10, 1)}, + } + module.w3.csm.oracle.perf_leeway_bp.return_value = 5000 + module.state = state + + threshold = module._get_performance_threshold((0, 31), Mock()) + + assert threshold == -0.4 + + +def test_process_validator_duty_handles_above_threshold_performance(): + validator = LidoValidatorFactory.build() + validator.validator.slashed = False + log_operator = Mock() + log_operator.validators = defaultdict(ValidatorFrameSummary) + participation_shares = defaultdict(int) + threshold = 0.5 + + attestation_duty = AttestationsAccumulator(assigned=10, included=6) + + CSOracle.process_validator_duty(validator, attestation_duty, threshold, participation_shares, log_operator) + + assert participation_shares[validator.lido_id.operatorIndex] == 10 + assert log_operator.validators[validator.index].attestation_duty == attestation_duty + + +def test_process_validator_duty_handles_below_threshold_performance(): + validator = LidoValidatorFactory.build() + validator.validator.slashed = False + log_operator = Mock() + log_operator.validators = defaultdict(ValidatorFrameSummary) + participation_shares = defaultdict(int) + threshold = 0.5 + + attestation_duty = AttestationsAccumulator(assigned=10, included=4) + + CSOracle.process_validator_duty(validator, attestation_duty, threshold, participation_shares, 
log_operator)
+
+    assert participation_shares[validator.lido_id.operatorIndex] == 0
+    assert log_operator.validators[validator.index].attestation_duty == attestation_duty
+
+
+def test_process_validator_duty_handles_non_empty_participation_shares():
+    validator = LidoValidatorFactory.build()
+    validator.validator.slashed = False
+    log_operator = Mock()
+    log_operator.validators = defaultdict(ValidatorFrameSummary)
+    participation_shares = {validator.lido_id.operatorIndex: 25}
+    threshold = 0.5
+
+    attestation_duty = AttestationsAccumulator(assigned=10, included=6)
+
+    CSOracle.process_validator_duty(validator, attestation_duty, threshold, participation_shares, log_operator)
+
+    assert participation_shares[validator.lido_id.operatorIndex] == 35
+    assert log_operator.validators[validator.index].attestation_duty == attestation_duty
+
+
+def test_process_validator_duty_handles_no_duty_assigned():
+    validator = LidoValidatorFactory.build()
+    log_operator = Mock()
+    log_operator.validators = defaultdict(ValidatorFrameSummary)
+    participation_shares = defaultdict(int)
+    threshold = 0.5
+
+    CSOracle.process_validator_duty(validator, None, threshold, participation_shares, log_operator)
+
+    assert participation_shares[validator.lido_id.operatorIndex] == 0
+    assert validator.index not in log_operator.validators
+
+
+def test_process_validator_duty_handles_slashed_validator():
+    validator = LidoValidatorFactory.build()
+    validator.validator.slashed = True
+    log_operator = Mock()
+    log_operator.validators = defaultdict(ValidatorFrameSummary)
+    participation_shares = defaultdict(int)
+    threshold = 0.5
+
+    attestation_duty = AttestationsAccumulator(assigned=1, included=1)
+
+    CSOracle.process_validator_duty(validator, attestation_duty, threshold, participation_shares, log_operator)
+
+    assert participation_shares[validator.lido_id.operatorIndex] == 0
+    assert log_operator.validators[validator.index].slashed is True
+
+
+def 
test_calc_rewards_distribution_in_frame_correctly_distributes_rewards(): + participation_shares = {NodeOperatorId(1): 100, NodeOperatorId(2): 200} + rewards_to_distribute = Wei(1 * 10**18) + + rewards_distribution = CSOracle.calc_rewards_distribution_in_frame(participation_shares, rewards_to_distribute) + + assert rewards_distribution[NodeOperatorId(1)] == Wei(333333333333333333) + assert rewards_distribution[NodeOperatorId(2)] == Wei(666666666666666666) + + +def test_calc_rewards_distribution_in_frame_handles_zero_participation(): + participation_shares = {NodeOperatorId(1): 0, NodeOperatorId(2): 0} + rewards_to_distribute = Wei(1 * 10**18) + + rewards_distribution = CSOracle.calc_rewards_distribution_in_frame(participation_shares, rewards_to_distribute) + + assert rewards_distribution[NodeOperatorId(1)] == 0 + assert rewards_distribution[NodeOperatorId(2)] == 0 + + +def test_calc_rewards_distribution_in_frame_handles_no_participation(): + participation_shares = {} + rewards_to_distribute = Wei(1 * 10**18) + + rewards_distribution = CSOracle.calc_rewards_distribution_in_frame(participation_shares, rewards_to_distribute) + + assert len(rewards_distribution) == 0 + + +def test_calc_rewards_distribution_in_frame_handles_partial_participation(): + participation_shares = {NodeOperatorId(1): 100, NodeOperatorId(2): 0} + rewards_to_distribute = Wei(1 * 10**18) + + rewards_distribution = CSOracle.calc_rewards_distribution_in_frame(participation_shares, rewards_to_distribute) + + assert rewards_distribution[NodeOperatorId(1)] == Wei(1 * 10**18) + assert rewards_distribution[NodeOperatorId(2)] == 0 diff --git a/tests/modules/csm/test_csm_module.py b/tests/modules/csm/test_csm_module.py index cdb0c92c5..1d396c083 100644 --- a/tests/modules/csm/test_csm_module.py +++ b/tests/modules/csm/test_csm_module.py @@ -132,263 +132,6 @@ def test_stuck_operators_left_border_before_enact(module: CSOracle, csm: CSM, ca assert caplog.messages[0].startswith("No CSM digest at blockstamp") 
-def test_calculate_distribution(module: CSOracle, csm: CSM): - csm.fee_distributor.shares_to_distribute = Mock(return_value=10_000) - csm.oracle.perf_leeway_bp = Mock(return_value=500) - - module.module_validators_by_node_operators = Mock( - return_value={ - (None, NodeOperatorId(0)): [Mock(index=0, validator=Mock(slashed=False))], - (None, NodeOperatorId(1)): [Mock(index=1, validator=Mock(slashed=False))], - (None, NodeOperatorId(2)): [Mock(index=2, validator=Mock(slashed=False))], # stuck - (None, NodeOperatorId(3)): [Mock(index=3, validator=Mock(slashed=False))], - (None, NodeOperatorId(4)): [Mock(index=4, validator=Mock(slashed=False))], # stuck - (None, NodeOperatorId(5)): [ - Mock(index=5, validator=Mock(slashed=False)), - Mock(index=6, validator=Mock(slashed=False)), - ], - (None, NodeOperatorId(6)): [ - Mock(index=7, validator=Mock(slashed=False)), - Mock(index=8, validator=Mock(slashed=False)), - ], - (None, NodeOperatorId(7)): [Mock(index=9, validator=Mock(slashed=False))], - (None, NodeOperatorId(8)): [ - Mock(index=10, validator=Mock(slashed=False)), - Mock(index=11, validator=Mock(slashed=True)), - ], - (None, NodeOperatorId(9)): [Mock(index=12, validator=Mock(slashed=True))], - } - ) - module.stuck_operators = Mock( - return_value=[ - NodeOperatorId(2), - NodeOperatorId(4), - ] - ) - - frame_0: Frame = (EpochNumber(0), EpochNumber(999)) - - module.state.init_or_migrate(*frame_0, epochs_per_frame=1000, consensus_version=1) - module.state = State( - { - frame_0: { - ValidatorIndex(0): AttestationsAccumulator(included=200, assigned=200), # short on frame - ValidatorIndex(1): AttestationsAccumulator(included=1000, assigned=1000), - ValidatorIndex(2): AttestationsAccumulator(included=1000, assigned=1000), - ValidatorIndex(3): AttestationsAccumulator(included=999, assigned=1000), - ValidatorIndex(4): AttestationsAccumulator(included=900, assigned=1000), - ValidatorIndex(5): AttestationsAccumulator(included=500, assigned=1000), # underperforming - 
ValidatorIndex(6): AttestationsAccumulator(included=0, assigned=0), # underperforming - ValidatorIndex(7): AttestationsAccumulator(included=900, assigned=1000), - ValidatorIndex(8): AttestationsAccumulator(included=500, assigned=1000), # underperforming - # ValidatorIndex(9): AttestationsAggregate(included=0, assigned=0), # missing in state - ValidatorIndex(10): AttestationsAccumulator(included=1000, assigned=1000), - ValidatorIndex(11): AttestationsAccumulator(included=1000, assigned=1000), - ValidatorIndex(12): AttestationsAccumulator(included=1000, assigned=1000), - } - } - ) - - l_epoch, r_epoch = frame_0 - - frame_0_network_aggr = module.state.get_network_aggr(frame_0) - - blockstamp = ReferenceBlockStampFactory.build(slot_number=r_epoch * 32, ref_epoch=r_epoch, ref_slot=r_epoch * 32) - _, shares, logs = module.calculate_distribution(blockstamp=blockstamp) - - log, *_ = logs - - assert tuple(shares.items()) == ( - (NodeOperatorId(0), 476), - (NodeOperatorId(1), 2380), - (NodeOperatorId(3), 2380), - (NodeOperatorId(6), 2380), - (NodeOperatorId(8), 2380), - ) - - assert tuple(log.operators.keys()) == ( - NodeOperatorId(0), - NodeOperatorId(1), - NodeOperatorId(2), - NodeOperatorId(3), - NodeOperatorId(4), - NodeOperatorId(5), - NodeOperatorId(6), - # NodeOperatorId(7), # Missing in state - NodeOperatorId(8), - NodeOperatorId(9), - ) - - assert not log.operators[NodeOperatorId(1)].stuck - - assert log.operators[NodeOperatorId(2)].validators == {} - assert log.operators[NodeOperatorId(2)].stuck - assert log.operators[NodeOperatorId(4)].validators == {} - assert log.operators[NodeOperatorId(4)].stuck - - assert 5 in log.operators[NodeOperatorId(5)].validators - assert 6 in log.operators[NodeOperatorId(5)].validators - assert 7 in log.operators[NodeOperatorId(6)].validators - - assert log.operators[NodeOperatorId(0)].distributed == 476 - assert log.operators[NodeOperatorId(1)].distributed == 2380 - assert log.operators[NodeOperatorId(2)].distributed == 0 - assert 
log.operators[NodeOperatorId(3)].distributed == 2380 - assert log.operators[NodeOperatorId(6)].distributed == 2380 - - assert log.frame == frame_0 - assert log.threshold == frame_0_network_aggr.perf - 0.05 - - -def test_calculate_distribution_with_missed_with_two_frames(module: CSOracle, csm: CSM): - csm.oracle.perf_leeway_bp = Mock(return_value=500) - csm.fee_distributor.shares_to_distribute = Mock(side_effect=[10000, 20000]) - - module.module_validators_by_node_operators = Mock( - return_value={ - (None, NodeOperatorId(0)): [Mock(index=0, validator=Mock(slashed=False))], - (None, NodeOperatorId(1)): [Mock(index=1, validator=Mock(slashed=False))], - (None, NodeOperatorId(2)): [Mock(index=2, validator=Mock(slashed=False))], # stuck - (None, NodeOperatorId(3)): [Mock(index=3, validator=Mock(slashed=False))], - (None, NodeOperatorId(4)): [Mock(index=4, validator=Mock(slashed=False))], # stuck - (None, NodeOperatorId(5)): [ - Mock(index=5, validator=Mock(slashed=False)), - Mock(index=6, validator=Mock(slashed=False)), - ], - (None, NodeOperatorId(6)): [ - Mock(index=7, validator=Mock(slashed=False)), - Mock(index=8, validator=Mock(slashed=False)), - ], - (None, NodeOperatorId(7)): [Mock(index=9, validator=Mock(slashed=False))], - (None, NodeOperatorId(8)): [ - Mock(index=10, validator=Mock(slashed=False)), - Mock(index=11, validator=Mock(slashed=True)), - ], - (None, NodeOperatorId(9)): [Mock(index=12, validator=Mock(slashed=True))], - } - ) - - module.stuck_operators = Mock( - side_effect=[ - [ - NodeOperatorId(2), - NodeOperatorId(4), - ], - [ - NodeOperatorId(2), - NodeOperatorId(4), - ], - ] - ) - - module.state = State() - l_epoch, r_epoch = EpochNumber(0), EpochNumber(1999) - frame_0 = (0, 999) - frame_1 = (1000, 1999) - module.state.init_or_migrate(l_epoch, r_epoch, epochs_per_frame=1000, consensus_version=1) - module.state = State( - { - frame_0: { - ValidatorIndex(0): AttestationsAccumulator(included=200, assigned=200), # short on frame - ValidatorIndex(1): 
AttestationsAccumulator(included=1000, assigned=1000), - ValidatorIndex(2): AttestationsAccumulator(included=1000, assigned=1000), - ValidatorIndex(3): AttestationsAccumulator(included=999, assigned=1000), - ValidatorIndex(4): AttestationsAccumulator(included=900, assigned=1000), - ValidatorIndex(5): AttestationsAccumulator(included=500, assigned=1000), # underperforming - ValidatorIndex(6): AttestationsAccumulator(included=0, assigned=0), # underperforming - ValidatorIndex(7): AttestationsAccumulator(included=900, assigned=1000), - ValidatorIndex(8): AttestationsAccumulator(included=500, assigned=1000), # underperforming - # ValidatorIndex(9): AttestationsAggregate(included=0, assigned=0), # missing in state - ValidatorIndex(10): AttestationsAccumulator(included=1000, assigned=1000), - ValidatorIndex(11): AttestationsAccumulator(included=1000, assigned=1000), - ValidatorIndex(12): AttestationsAccumulator(included=1000, assigned=1000), - }, - frame_1: { - ValidatorIndex(0): AttestationsAccumulator(included=200, assigned=200), # short on frame - ValidatorIndex(1): AttestationsAccumulator(included=1000, assigned=1000), - ValidatorIndex(2): AttestationsAccumulator(included=1000, assigned=1000), - ValidatorIndex(3): AttestationsAccumulator(included=999, assigned=1000), - ValidatorIndex(4): AttestationsAccumulator(included=900, assigned=1000), - ValidatorIndex(5): AttestationsAccumulator(included=500, assigned=1000), # underperforming - ValidatorIndex(6): AttestationsAccumulator(included=0, assigned=0), # underperforming - ValidatorIndex(7): AttestationsAccumulator(included=900, assigned=1000), - ValidatorIndex(8): AttestationsAccumulator(included=500, assigned=1000), # underperforming - # ValidatorIndex(9): AttestationsAggregate(included=0, assigned=0), # missing in state - ValidatorIndex(10): AttestationsAccumulator(included=1000, assigned=1000), - ValidatorIndex(11): AttestationsAccumulator(included=1000, assigned=1000), - ValidatorIndex(12): 
AttestationsAccumulator(included=1000, assigned=1000), - }, - } - ) - module.w3.cc = Mock() - - module.converter = Mock( - side_effect=lambda _: Mock( - frame_config=FrameConfigFactory.build(epochs_per_frame=1000), - get_epoch_last_slot=lambda epoch: epoch * 32 + 31, - ) - ) - - module._get_ref_blockstamp_for_frame = Mock( - side_effect=[ - ReferenceBlockStampFactory.build( - slot_number=frame_0[1] * 32, ref_epoch=frame_0[1], ref_slot=frame_0[1] * 32 - ), - ReferenceBlockStampFactory.build(slot_number=r_epoch * 32, ref_epoch=r_epoch, ref_slot=r_epoch * 32), - ] - ) - - blockstamp = ReferenceBlockStampFactory.build(slot_number=r_epoch * 32, ref_epoch=r_epoch, ref_slot=r_epoch * 32) - distributed, shares, logs = module.calculate_distribution(blockstamp=blockstamp) - - assert distributed == 2 * 9_998 # because of the rounding - - assert tuple(shares.items()) == ( - (NodeOperatorId(0), 952), - (NodeOperatorId(1), 4761), - (NodeOperatorId(3), 4761), - (NodeOperatorId(6), 4761), - (NodeOperatorId(8), 4761), - ) - - assert len(logs) == 2 - - for log in logs: - - assert log.frame in module.state.data.keys() - assert log.threshold == module.state.get_network_aggr(log.frame).perf - 0.05 - - assert tuple(log.operators.keys()) == ( - NodeOperatorId(0), - NodeOperatorId(1), - NodeOperatorId(2), - NodeOperatorId(3), - NodeOperatorId(4), - NodeOperatorId(5), - NodeOperatorId(6), - # NodeOperatorId(7), # Missing in state - NodeOperatorId(8), - NodeOperatorId(9), - ) - - assert not log.operators[NodeOperatorId(1)].stuck - - assert log.operators[NodeOperatorId(2)].validators == {} - assert log.operators[NodeOperatorId(2)].stuck - assert log.operators[NodeOperatorId(4)].validators == {} - assert log.operators[NodeOperatorId(4)].stuck - - assert 5 in log.operators[NodeOperatorId(5)].validators - assert 6 in log.operators[NodeOperatorId(5)].validators - assert 7 in log.operators[NodeOperatorId(6)].validators - - assert log.operators[NodeOperatorId(0)].distributed == 476 - assert 
log.operators[NodeOperatorId(1)].distributed in [2380, 2381] - assert log.operators[NodeOperatorId(2)].distributed == 0 - assert log.operators[NodeOperatorId(3)].distributed in [2380, 2381] - assert log.operators[NodeOperatorId(6)].distributed in [2380, 2381] - - # Static functions you were dreaming of for so long. diff --git a/tests/modules/csm/test_state.py b/tests/modules/csm/test_state.py index b5d8f8808..ee09fe1e5 100644 --- a/tests/modules/csm/test_state.py +++ b/tests/modules/csm/test_state.py @@ -132,12 +132,16 @@ def test_clear_resets_state_to_empty(): def test_find_frame_returns_correct_frame(): state = State() + state._epochs_to_process = tuple(sequence(0, 31)) + state._epochs_per_frame = 32 state.data = {(0, 31): defaultdict(AttestationsAccumulator)} assert state.find_frame(15) == (0, 31) def test_find_frame_raises_error_for_out_of_range_epoch(): state = State() + state._epochs_to_process = tuple(sequence(0, 31)) + state._epochs_per_frame = 32 state.data = {(0, 31): defaultdict(AttestationsAccumulator)} with pytest.raises(ValueError, match="Epoch 32 is out of frames range"): state.find_frame(32) From f3896d4d2fb4b17277b6dc50eecb5450d957ac9c Mon Sep 17 00:00:00 2001 From: vgorkavenko Date: Thu, 13 Feb 2025 17:58:43 +0100 Subject: [PATCH 04/20] fix: log tests --- tests/modules/csm/test_log.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/tests/modules/csm/test_log.py b/tests/modules/csm/test_log.py index 61004e9ed..c95ef93ca 100644 --- a/tests/modules/csm/test_log.py +++ b/tests/modules/csm/test_log.py @@ -28,7 +28,7 @@ def test_fields_access(log: FramePerfLog): def test_log_encode(log: FramePerfLog): # Fill in dynamic fields to make sure we have data in it to be encoded. 
- log.operators[NodeOperatorId(42)].validators["41337"].perf = AttestationsAccumulator(220, 119) + log.operators[NodeOperatorId(42)].validators["41337"].attestation_duty = AttestationsAccumulator(220, 119) log.operators[NodeOperatorId(42)].distributed = 17 log.operators[NodeOperatorId(0)].distributed = 0 @@ -37,8 +37,8 @@ def test_log_encode(log: FramePerfLog): encoded = FramePerfLog.encode(logs) for decoded in json.loads(encoded): - assert decoded["operators"]["42"]["validators"]["41337"]["perf"]["assigned"] == 220 - assert decoded["operators"]["42"]["validators"]["41337"]["perf"]["included"] == 119 + assert decoded["operators"]["42"]["validators"]["41337"]["attestation_duty"]["assigned"] == 220 + assert decoded["operators"]["42"]["validators"]["41337"]["attestation_duty"]["included"] == 119 assert decoded["operators"]["42"]["distributed"] == 17 assert decoded["operators"]["0"]["distributed"] == 0 @@ -51,12 +51,12 @@ def test_log_encode(log: FramePerfLog): def test_logs_encode(): log_0 = FramePerfLog(ReferenceBlockStampFactory.build(), (EpochNumber(100), EpochNumber(500))) - log_0.operators[NodeOperatorId(42)].validators["41337"].perf = AttestationsAccumulator(220, 119) + log_0.operators[NodeOperatorId(42)].validators["41337"].attestation_duty = AttestationsAccumulator(220, 119) log_0.operators[NodeOperatorId(42)].distributed = 17 log_0.operators[NodeOperatorId(0)].distributed = 0 log_1 = FramePerfLog(ReferenceBlockStampFactory.build(), (EpochNumber(500), EpochNumber(900))) - log_1.operators[NodeOperatorId(5)].validators["1234"].perf = AttestationsAccumulator(400, 399) + log_1.operators[NodeOperatorId(5)].validators["1234"].attestation_duty = AttestationsAccumulator(400, 399) log_1.operators[NodeOperatorId(5)].distributed = 40 log_1.operators[NodeOperatorId(18)].distributed = 0 @@ -68,13 +68,13 @@ def test_logs_encode(): assert len(decoded) == 2 - assert decoded[0]["operators"]["42"]["validators"]["41337"]["perf"]["assigned"] == 220 - assert 
decoded[0]["operators"]["42"]["validators"]["41337"]["perf"]["included"] == 119 + assert decoded[0]["operators"]["42"]["validators"]["41337"]["attestation_duty"]["assigned"] == 220 + assert decoded[0]["operators"]["42"]["validators"]["41337"]["attestation_duty"]["included"] == 119 assert decoded[0]["operators"]["42"]["distributed"] == 17 assert decoded[0]["operators"]["0"]["distributed"] == 0 - assert decoded[1]["operators"]["5"]["validators"]["1234"]["perf"]["assigned"] == 400 - assert decoded[1]["operators"]["5"]["validators"]["1234"]["perf"]["included"] == 399 + assert decoded[1]["operators"]["5"]["validators"]["1234"]["attestation_duty"]["assigned"] == 400 + assert decoded[1]["operators"]["5"]["validators"]["1234"]["attestation_duty"]["included"] == 399 assert decoded[1]["operators"]["5"]["distributed"] == 40 assert decoded[1]["operators"]["18"]["distributed"] == 0 From 27073f41b6d63a2c035e65deb2ef3aacc853a4a9 Mon Sep 17 00:00:00 2001 From: vgorkavenko Date: Fri, 14 Feb 2025 09:22:29 +0100 Subject: [PATCH 05/20] fix: review --- src/modules/csm/state.py | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/src/modules/csm/state.py b/src/modules/csm/state.py index e6fb7d866..46f664157 100644 --- a/src/modules/csm/state.py +++ b/src/modules/csm/state.py @@ -117,12 +117,9 @@ def frames(self): @staticmethod def calculate_frames(epochs_to_process: tuple[EpochNumber, ...], epochs_per_frame: int) -> list[Frame]: """Split epochs to process into frames of `epochs_per_frame` length""" - frames = [] - for frame_epochs in batched(epochs_to_process, epochs_per_frame): - if len(frame_epochs) < epochs_per_frame: - raise ValueError("Insufficient epochs to form a frame") - frames.append((frame_epochs[0], frame_epochs[-1])) - return frames + if len(epochs_to_process) % epochs_per_frame != 0: + raise ValueError("Insufficient epochs to form a frame") + return [(frame[0], frame[-1]) for frame in batched(epochs_to_process, epochs_per_frame)] def 
clear(self) -> None: self.data = {} @@ -173,8 +170,7 @@ def init_or_migrate( for current_frame, migrated in migration_status.items(): if not migrated: logger.warning({"msg": f"Invalidating frame duties data cache: {current_frame}"}) - for epoch in sequence(*current_frame): - self._processed_epochs.discard(epoch) + self._processed_epochs.difference_update(sequence(*current_frame)) self.data = frames_data self._epochs_per_frame = epochs_per_frame @@ -201,9 +197,9 @@ def _migrate_frames_data( new_frame_l_epoch, new_frame_r_epoch = new_frame if curr_frame_l_epoch >= new_frame_l_epoch and curr_frame_r_epoch <= new_frame_r_epoch: logger.info({"msg": f"Migrating frame duties data cache: {current_frame=} -> {new_frame=}"}) - for val in self.data[current_frame]: - new_data[new_frame][val].assigned += self.data[current_frame][val].assigned - new_data[new_frame][val].included += self.data[current_frame][val].included + for val, duty in self.data[current_frame].items(): + new_data[new_frame][val].assigned += duty.assigned + new_data[new_frame][val].included += duty.included migration_status[current_frame] = True break From cbedd04871a7687e11a3062880f1330ca676653d Mon Sep 17 00:00:00 2001 From: vgorkavenko Date: Fri, 14 Feb 2025 13:08:24 +0100 Subject: [PATCH 06/20] fix: coverage + renaming --- src/modules/csm/csm.py | 8 ++++---- tests/modules/csm/test_csm_distribution.py | 8 ++++++-- tests/modules/csm/test_csm_module.py | 6 +++--- 3 files changed, 13 insertions(+), 9 deletions(-) diff --git a/src/modules/csm/csm.py b/src/modules/csm/csm.py index ed146677e..861e902e7 100644 --- a/src/modules/csm/csm.py +++ b/src/modules/csm/csm.py @@ -120,9 +120,9 @@ def build_report(self, blockstamp: ReferenceBlockStamp) -> tuple: ).as_tuple() if prev_cid and prev_root != ZERO_HASH: - # Update cumulative amount of shares for all operators. 
- for no_id, acc_shares in self.get_accumulated_shares(prev_cid, prev_root): - total_rewards[no_id] += acc_shares + # Update cumulative amount of stETH shares for all operators. + for no_id, accumulated_rewards in self.get_accumulated_rewards(prev_cid, prev_root): + total_rewards[no_id] += accumulated_rewards else: logger.info({"msg": "No previous distribution. Nothing to accumulate"}) @@ -355,7 +355,7 @@ def calc_rewards_distribution_in_frame( return rewards_distribution - def get_accumulated_shares(self, cid: CID, root: HexBytes) -> Iterator[tuple[NodeOperatorId, Shares]]: + def get_accumulated_rewards(self, cid: CID, root: HexBytes) -> Iterator[tuple[NodeOperatorId, Shares]]: logger.info({"msg": "Fetching tree by CID from IPFS", "cid": repr(cid)}) tree = Tree.decode(self.w3.ipfs.fetch(cid)) diff --git a/tests/modules/csm/test_csm_distribution.py b/tests/modules/csm/test_csm_distribution.py index cb07c2e95..f58545af3 100644 --- a/tests/modules/csm/test_csm_distribution.py +++ b/tests/modules/csm/test_csm_distribution.py @@ -1,5 +1,5 @@ from collections import defaultdict -from unittest.mock import Mock +from unittest.mock import Mock, call import pytest from web3.types import Wei @@ -10,6 +10,7 @@ from src.modules.csm.state import AttestationsAccumulator, State from src.types import NodeOperatorId, ValidatorIndex from src.web3py.extensions import CSM +from tests.factory.blockstamp import ReferenceBlockStampFactory from tests.factory.no_registry import LidoValidatorFactory @@ -42,7 +43,7 @@ def test_calculate_distribution_handles_single_frame(module): def test_calculate_distribution_handles_multiple_frames(module): module.state = Mock() module.state.frames = [(1, 2), (3, 4), (5, 6)] - blockstamp = Mock() + blockstamp = ReferenceBlockStampFactory.build(ref_epoch=2) module.module_validators_by_node_operators = Mock() module._get_ref_blockstamp_for_frame = Mock(return_value=blockstamp) module.w3.csm.fee_distributor.shares_to_distribute = Mock(return_value=800) @@ 
-59,6 +60,9 @@ def test_calculate_distribution_handles_multiple_frames(module): assert total_distributed == 800 assert total_rewards[NodeOperatorId(1)] == 800 assert len(logs) == 3 + module._get_ref_blockstamp_for_frame.assert_has_calls( + [call(blockstamp, frame[1]) for frame in module.state.frames[1:]] + ) def test_calculate_distribution_handles_invalid_distribution(module): diff --git a/tests/modules/csm/test_csm_module.py b/tests/modules/csm/test_csm_module.py index 1d396c083..b3f964a55 100644 --- a/tests/modules/csm/test_csm_module.py +++ b/tests/modules/csm/test_csm_module.py @@ -595,7 +595,7 @@ def test_build_report(csm: CSM, module: CSOracle, param: BuildReportTestParam): # mock previous report module.w3.csm.get_csm_tree_root = Mock(return_value=param.prev_tree_root) module.w3.csm.get_csm_tree_cid = Mock(return_value=param.prev_tree_cid) - module.get_accumulated_shares = Mock(return_value=param.prev_acc_shares) + module.get_accumulated_rewards = Mock(return_value=param.prev_acc_shares) # mock current frame module.calculate_distribution = param.curr_distribution module.make_tree = Mock(return_value=Mock(root=param.curr_tree_root)) @@ -654,7 +654,7 @@ def test_get_accumulated_shares(module: CSOracle, tree: Tree): encoded_tree = tree.encode() module.w3.ipfs = Mock(fetch=Mock(return_value=encoded_tree)) - for i, leaf in enumerate(module.get_accumulated_shares(cid=CIDv0("0x100500"), root=tree.root)): + for i, leaf in enumerate(module.get_accumulated_rewards(cid=CIDv0("0x100500"), root=tree.root)): assert tuple(leaf) == tree.tree.values[i]["value"] @@ -663,7 +663,7 @@ def test_get_accumulated_shares_unexpected_root(module: CSOracle, tree: Tree): module.w3.ipfs = Mock(fetch=Mock(return_value=encoded_tree)) with pytest.raises(ValueError): - next(module.get_accumulated_shares(cid=CIDv0("0x100500"), root=HexBytes("0x100500"))) + next(module.get_accumulated_rewards(cid=CIDv0("0x100500"), root=HexBytes("0x100500"))) @dataclass(frozen=True) From 
c83aa47f2a502c646d6b3045364d864b86b01543 Mon Sep 17 00:00:00 2001 From: vgorkavenko Date: Fri, 14 Feb 2025 14:09:31 +0100 Subject: [PATCH 07/20] fix: get_stuck_operators --- src/modules/csm/csm.py | 32 ++++++++++----------- tests/modules/csm/test_csm_distribution.py | 8 +++--- tests/modules/csm/test_csm_module.py | 33 ++++++++++++---------- 3 files changed, 38 insertions(+), 35 deletions(-) diff --git a/src/modules/csm/csm.py b/src/modules/csm/csm.py index 861e902e7..4bdd5bd3c 100644 --- a/src/modules/csm/csm.py +++ b/src/modules/csm/csm.py @@ -160,7 +160,7 @@ def module_validators_by_node_operators(self, blockstamp: BlockStamp) -> Validat def validate_state(self, blockstamp: ReferenceBlockStamp) -> None: # NOTE: We cannot use `r_epoch` from the `current_frame_range` call because the `blockstamp` is a # `ReferenceBlockStamp`, hence it's a block the frame ends at. We use `ref_epoch` instead. - l_epoch, _ = self.current_frame_range(blockstamp) + l_epoch, _ = self.get_epochs_range_to_process(blockstamp) r_epoch = blockstamp.ref_epoch self.state.validate(l_epoch, r_epoch) @@ -175,8 +175,8 @@ def collect_data(self, blockstamp: BlockStamp) -> bool: converter = self.converter(blockstamp) - l_epoch, r_epoch = self.current_frame_range(blockstamp) - logger.info({"msg": f"Frame for performance data collect: epochs [{l_epoch};{r_epoch}]"}) + l_epoch, r_epoch = self.get_epochs_range_to_process(blockstamp) + logger.info({"msg": f"Epochs range for performance data collect: [{l_epoch};{r_epoch}]"}) # NOTE: Finalized slot is the first slot of justifying epoch, so we need to take the previous. 
But if the first # slot of the justifying epoch is empty, blockstamp.slot_number will point to the slot where the last finalized @@ -192,13 +192,13 @@ def collect_data(self, blockstamp: BlockStamp) -> bool: if report_blockstamp and report_blockstamp.ref_epoch != r_epoch: logger.warning( { - "msg": f"Frame has been changed, but the change is not yet observed on finalized epoch {finalized_epoch}" + "msg": f"Epochs range has been changed, but the change is not yet observed on finalized epoch {finalized_epoch}" } ) return False if l_epoch > finalized_epoch: - logger.info({"msg": "The starting epoch of the frame is not finalized yet"}) + logger.info({"msg": "The starting epoch of the epochs range is not finalized yet"}) return False self.state.init_or_migrate(l_epoch, r_epoch, converter.frame_config.epochs_per_frame, consensus_version) @@ -218,8 +218,8 @@ def collect_data(self, blockstamp: BlockStamp) -> bool: processor = FrameCheckpointProcessor(self.w3.cc, self.state, converter, blockstamp, eip7549_supported) for checkpoint in checkpoints: - if self.current_frame_range(self._receive_last_finalized_slot()) != (l_epoch, r_epoch): - logger.info({"msg": "Checkpoints were prepared for an outdated frame, stop processing"}) + if self.get_epochs_range_to_process(self._receive_last_finalized_slot()) != (l_epoch, r_epoch): + logger.info({"msg": "Checkpoints were prepared for an outdated epochs range, stop processing"}) raise ValueError("Outdated checkpoint") processor.exec(checkpoint) @@ -285,7 +285,7 @@ def _calculate_distribution_in_frame( participation_shares: defaultdict[NodeOperatorId, int] = defaultdict(int) - stuck_operators = self.stuck_operators(blockstamp) + stuck_operators = self.get_stuck_operators(frame, blockstamp) for (_, no_id), validators in operators_to_validators.items(): log_operator = log.operators[no_id] if no_id in stuck_operators: @@ -367,17 +367,17 @@ def get_accumulated_rewards(self, cid: CID, root: HexBytes) -> Iterator[tuple[No for v in 
tree.tree.values: yield v["value"] - def stuck_operators(self, blockstamp: ReferenceBlockStamp) -> set[NodeOperatorId]: + def get_stuck_operators(self, frame: Frame, frame_blockstamp: ReferenceBlockStamp) -> set[NodeOperatorId]: stuck: set[NodeOperatorId] = set() - l_epoch, _ = self.current_frame_range(blockstamp) - l_ref_slot = self.converter(blockstamp).get_epoch_first_slot(l_epoch) + l_epoch, _ = frame + l_ref_slot = self.converter(frame_blockstamp).get_epoch_first_slot(l_epoch) # NOTE: r_block is guaranteed to be <= ref_slot, and the check # in the inner frames assures the l_block <= r_block. l_blockstamp = build_blockstamp( get_next_non_missed_slot( self.w3.cc, l_ref_slot, - blockstamp.slot_number, + frame_blockstamp.slot_number, ) ) @@ -390,7 +390,7 @@ def stuck_operators(self, blockstamp: ReferenceBlockStamp) -> set[NodeOperatorId stuck.update( self.w3.csm.get_operators_with_stucks_in_range( l_blockstamp.block_hash, - blockstamp.block_hash, + frame_blockstamp.block_hash, ) ) return stuck @@ -424,7 +424,7 @@ def publish_log(self, logs: list[FramePerfLog]) -> CID: return log_cid @lru_cache(maxsize=1) - def current_frame_range(self, blockstamp: BlockStamp) -> tuple[EpochNumber, EpochNumber]: + def get_epochs_range_to_process(self, blockstamp: BlockStamp) -> tuple[EpochNumber, EpochNumber]: converter = self.converter(blockstamp) far_future_initial_epoch = converter.get_epoch_by_timestamp(UINT64_MAX) @@ -455,9 +455,9 @@ def current_frame_range(self, blockstamp: BlockStamp) -> tuple[EpochNumber, Epoc ) if l_ref_slot < last_processing_ref_slot: - raise CSMError(f"Got invalid frame range: {l_ref_slot=} < {last_processing_ref_slot=}") + raise CSMError(f"Got invalid epochs range: {l_ref_slot=} < {last_processing_ref_slot=}") if l_ref_slot >= r_ref_slot: - raise CSMError(f"Got invalid frame range {r_ref_slot=}, {l_ref_slot=}") + raise CSMError(f"Got invalid epochs range {r_ref_slot=}, {l_ref_slot=}") l_epoch = converter.get_epoch_by_slot(SlotNumber(l_ref_slot + 1)) 
r_epoch = converter.get_epoch_by_slot(r_ref_slot) diff --git a/tests/modules/csm/test_csm_distribution.py b/tests/modules/csm/test_csm_distribution.py index f58545af3..aac45d83e 100644 --- a/tests/modules/csm/test_csm_distribution.py +++ b/tests/modules/csm/test_csm_distribution.py @@ -85,7 +85,7 @@ def test_calculate_distribution_in_frame_handles_stuck_operator(module): operators_to_validators = {(Mock(), NodeOperatorId(1)): [LidoValidatorFactory.build()]} module.state = State() module.state.data = {frame: defaultdict(AttestationsAccumulator)} - module.stuck_operators = Mock(return_value={NodeOperatorId(1)}) + module.get_stuck_operators = Mock(return_value={NodeOperatorId(1)}) module._get_performance_threshold = Mock() rewards_distribution, log = module._calculate_distribution_in_frame( @@ -107,7 +107,7 @@ def test_calculate_distribution_in_frame_handles_no_attestation_duty(module): operators_to_validators = {(Mock(), node_operator_id): [validator]} module.state = State() module.state.data = {frame: defaultdict(AttestationsAccumulator)} - module.stuck_operators = Mock(return_value=set()) + module.get_stuck_operators = Mock(return_value=set()) module._get_performance_threshold = Mock() rewards_distribution, log = module._calculate_distribution_in_frame( @@ -131,7 +131,7 @@ def test_calculate_distribution_in_frame_handles_above_threshold_performance(mod module.state = State() attestation_duty = AttestationsAccumulator(assigned=10, included=6) module.state.data = {frame: {validator.index: attestation_duty}} - module.stuck_operators = Mock(return_value=set()) + module.get_stuck_operators = Mock(return_value=set()) module._get_performance_threshold = Mock(return_value=0.5) rewards_distribution, log = module._calculate_distribution_in_frame( @@ -155,7 +155,7 @@ def test_calculate_distribution_in_frame_handles_below_threshold_performance(mod module.state = State() attestation_duty = AttestationsAccumulator(assigned=10, included=5) module.state.data = {frame: 
{validator.index: attestation_duty}} - module.stuck_operators = Mock(return_value=set()) + module.get_stuck_operators = Mock(return_value=set()) module._get_performance_threshold = Mock(return_value=0.5) rewards_distribution, log = module._calculate_distribution_in_frame( diff --git a/tests/modules/csm/test_csm_module.py b/tests/modules/csm/test_csm_module.py index b3f964a55..1ef4c9234 100644 --- a/tests/modules/csm/test_csm_module.py +++ b/tests/modules/csm/test_csm_module.py @@ -39,7 +39,7 @@ def test_init(module: CSOracle): assert module -def test_stuck_operators(module: CSOracle, csm: CSM): +def test_get_stuck_operators(module: CSOracle, csm: CSM): module.module = Mock() # type: ignore module.module_id = StakingModuleId(1) module.w3.cc = Mock() @@ -66,7 +66,7 @@ def test_stuck_operators(module: CSOracle, csm: CSM): return_value=[NodeOperatorId(2), NodeOperatorId(4), NodeOperatorId(6), NodeOperatorId(1337)] ) - module.current_frame_range = Mock(return_value=(69, 100)) + module.get_epochs_range_to_process = Mock(return_value=(69, 100)) module.converter = Mock() module.converter.get_epoch_first_slot = Mock(return_value=lambda epoch: epoch * 32) @@ -78,12 +78,12 @@ def test_stuck_operators(module: CSOracle, csm: CSM): with patch('src.modules.csm.csm.build_blockstamp', return_value=l_blockstamp): with patch('src.modules.csm.csm.get_next_non_missed_slot', return_value=Mock()): - stuck = module.stuck_operators(blockstamp=blockstamp) + stuck = module.get_stuck_operators(frame=(69, 100), frame_blockstamp=blockstamp) assert stuck == {NodeOperatorId(2), NodeOperatorId(4), NodeOperatorId(5), NodeOperatorId(6), NodeOperatorId(1337)} -def test_stuck_operators_left_border_before_enact(module: CSOracle, csm: CSM, caplog: pytest.LogCaptureFixture): +def test_get_stuck_operators_left_border_before_enact(module: CSOracle, csm: CSM, caplog: pytest.LogCaptureFixture): module.module = Mock() # type: ignore module.module_id = StakingModuleId(3) module.w3.cc = Mock() @@ -112,7 +112,7 
@@ def test_stuck_operators_left_border_before_enact(module: CSOracle, csm: CSM, ca ] ) - module.current_frame_range = Mock(return_value=(69, 100)) + module.get_epochs_range_to_process = Mock(return_value=(69, 100)) module.converter = Mock() module.converter.get_epoch_first_slot = Mock(return_value=lambda epoch: epoch * 32) @@ -121,7 +121,7 @@ def test_stuck_operators_left_border_before_enact(module: CSOracle, csm: CSM, ca with patch('src.modules.csm.csm.build_blockstamp', return_value=l_blockstamp): with patch('src.modules.csm.csm.get_next_non_missed_slot', return_value=Mock()): - stuck = module.stuck_operators(blockstamp=blockstamp) + stuck = module.get_stuck_operators(frame=(69, 100), frame_blockstamp=blockstamp) assert stuck == { NodeOperatorId(2), @@ -283,11 +283,11 @@ def test_current_frame_range(module: CSOracle, csm: CSM, mock_chain_config: NoRe if param.expected_frame is ValueError: with pytest.raises(ValueError): - module.current_frame_range(ReferenceBlockStampFactory.build(slot_number=param.finalized_slot)) + module.get_epochs_range_to_process(ReferenceBlockStampFactory.build(slot_number=param.finalized_slot)) else: bs = ReferenceBlockStampFactory.build(slot_number=param.finalized_slot) - l_epoch, r_epoch = module.current_frame_range(bs) + l_epoch, r_epoch = module.get_epochs_range_to_process(bs) assert (l_epoch, r_epoch) == param.expected_frame @@ -321,7 +321,7 @@ class CollectDataTestParam: collect_frame_range=Mock(return_value=(0, 1)), report_blockstamp=Mock(ref_epoch=3), state=Mock(), - expected_msg="Frame has been changed, but the change is not yet observed on finalized epoch 1", + expected_msg="Epochs range has been changed, but the change is not yet observed on finalized epoch 1", expected_result=False, ), id="frame_changed_forward", @@ -332,7 +332,7 @@ class CollectDataTestParam: collect_frame_range=Mock(return_value=(0, 2)), report_blockstamp=Mock(ref_epoch=1), state=Mock(), - expected_msg="Frame has been changed, but the change is not yet 
observed on finalized epoch 1", + expected_msg="Epochs range has been changed, but the change is not yet observed on finalized epoch 1", expected_result=False, ), id="frame_changed_backward", @@ -343,7 +343,7 @@ class CollectDataTestParam: collect_frame_range=Mock(return_value=(1, 2)), report_blockstamp=Mock(ref_epoch=2), state=Mock(), - expected_msg="The starting epoch of the frame is not finalized yet", + expected_msg="The starting epoch of the epochs range is not finalized yet", expected_result=False, ), id="starting_epoch_not_finalized", @@ -393,7 +393,7 @@ def test_collect_data( module.w3 = Mock() module._receive_last_finalized_slot = Mock() module.state = param.state - module.current_frame_range = param.collect_frame_range + module.get_epochs_range_to_process = param.collect_frame_range module.get_blockstamp_for_report = Mock(return_value=param.report_blockstamp) with caplog.at_level(logging.DEBUG): @@ -419,7 +419,7 @@ def test_collect_data_outdated_checkpoint( unprocessed_epochs=list(range(0, 101)), is_fulfilled=False, ) - module.current_frame_range = Mock(side_effect=[(0, 100), (50, 150)]) + module.get_epochs_range_to_process = Mock(side_effect=[(0, 100), (50, 150)]) module.get_blockstamp_for_report = Mock(return_value=Mock(ref_epoch=100)) with caplog.at_level(logging.DEBUG): @@ -427,7 +427,10 @@ def test_collect_data_outdated_checkpoint( module.collect_data(blockstamp=Mock(slot_number=640)) msg = list( - filter(lambda log: "Checkpoints were prepared for an outdated frame, stop processing" in log, caplog.messages) + filter( + lambda log: "Checkpoints were prepared for an outdated epochs range, stop processing" in log, + caplog.messages, + ) ) assert len(msg), "Expected message not found in logs" @@ -443,7 +446,7 @@ def test_collect_data_fulfilled_state( unprocessed_epochs=list(range(0, 101)), ) type(module.state).is_fulfilled = PropertyMock(side_effect=[False, True]) - module.current_frame_range = Mock(return_value=(0, 100)) + 
module.get_epochs_range_to_process = Mock(return_value=(0, 100)) module.get_blockstamp_for_report = Mock(return_value=Mock(ref_epoch=100)) with caplog.at_level(logging.DEBUG): From b58e7272a5dc81a822ed8bafe86b6e9a88477747 Mon Sep 17 00:00:00 2001 From: vgorkavenko Date: Fri, 14 Feb 2025 18:03:28 +0100 Subject: [PATCH 08/20] refactor: get digests for `get_stuck_operators` --- src/modules/csm/csm.py | 30 ++++++++---------- tests/modules/csm/test_csm_module.py | 46 +++++++++------------------- 2 files changed, 27 insertions(+), 49 deletions(-) diff --git a/src/modules/csm/csm.py b/src/modules/csm/csm.py index 4bdd5bd3c..4781e6f00 100644 --- a/src/modules/csm/csm.py +++ b/src/modules/csm/csm.py @@ -28,7 +28,6 @@ ReferenceBlockStamp, SlotNumber, StakingModuleAddress, - StakingModuleId, ) from src.utils.blockstamp import build_blockstamp from src.utils.cache import global_lru_cache as lru_cache @@ -62,13 +61,13 @@ class CSOracle(BaseModule, ConsensusModule): COMPATIBLE_ONCHAIN_VERSIONS = [(1, 1), (1, 2)] report_contract: CSFeeOracleContract - module_id: StakingModuleId + staking_module: StakingModule def __init__(self, w3: Web3): self.report_contract = w3.csm.oracle self.state = State.load() super().__init__(w3) - self.module_id = self._get_module_id() + self.staking_module = self._get_staking_module() def refresh_contracts(self): self.report_contract = self.w3.csm.oracle # type: ignore @@ -368,7 +367,6 @@ def get_accumulated_rewards(self, cid: CID, root: HexBytes) -> Iterator[tuple[No yield v["value"] def get_stuck_operators(self, frame: Frame, frame_blockstamp: ReferenceBlockStamp) -> set[NodeOperatorId]: - stuck: set[NodeOperatorId] = set() l_epoch, _ = frame l_ref_slot = self.converter(frame_blockstamp).get_epoch_first_slot(l_epoch) # NOTE: r_block is guaranteed to be <= ref_slot, and the check @@ -381,19 +379,17 @@ def get_stuck_operators(self, frame: Frame, frame_blockstamp: ReferenceBlockStam ) ) - nos_by_module = 
self.w3.lido_validators.get_lido_node_operators_by_modules(l_blockstamp) - if self.module_id in nos_by_module: - stuck.update(no.id for no in nos_by_module[self.module_id] if no.stuck_validators_count > 0) - else: + digests = self.w3.lido_contracts.staking_router.get_all_node_operator_digests( + self.staking_module, l_blockstamp.block_hash + ) + if not digests: logger.warning("No CSM digest at blockstamp=%s, module was not added yet?", l_blockstamp) - - stuck.update( - self.w3.csm.get_operators_with_stucks_in_range( - l_blockstamp.block_hash, - frame_blockstamp.block_hash, - ) + stuck_from_digests = (no.id for no in digests if no.stuck_validators_count > 0) + stuck_from_events = self.w3.csm.get_operators_with_stucks_in_range( + l_blockstamp.block_hash, + frame_blockstamp.block_hash, ) - return stuck + return set(stuck_from_digests) | set(stuck_from_events) def make_tree(self, shares: dict[NodeOperatorId, Shares]) -> Tree: if not shares: @@ -471,11 +467,11 @@ def get_epochs_range_to_process(self, blockstamp: BlockStamp) -> tuple[EpochNumb def converter(self, blockstamp: BlockStamp) -> Web3Converter: return Web3Converter(self.get_chain_config(blockstamp), self.get_frame_config(blockstamp)) - def _get_module_id(self) -> StakingModuleId: + def _get_staking_module(self) -> StakingModule: modules: list[StakingModule] = self.w3.lido_contracts.staking_router.get_staking_modules() for mod in modules: if mod.staking_module_address == self.w3.csm.module.address: - return mod.id + return mod raise NoModuleFound diff --git a/tests/modules/csm/test_csm_module.py b/tests/modules/csm/test_csm_module.py index 1ef4c9234..8fd863cd5 100644 --- a/tests/modules/csm/test_csm_module.py +++ b/tests/modules/csm/test_csm_module.py @@ -9,12 +9,12 @@ from src.constants import UINT64_MAX from src.modules.csm.csm import CSOracle -from src.modules.csm.state import AttestationsAccumulator, State, Frame +from src.modules.csm.state import State from src.modules.csm.tree import Tree from 
src.modules.submodules.oracle_module import ModuleExecuteDelay from src.modules.submodules.types import CurrentFrame, ZERO_HASH from src.providers.ipfs import CIDv0, CID -from src.types import EpochNumber, NodeOperatorId, SlotNumber, StakingModuleId, ValidatorIndex +from src.types import NodeOperatorId, SlotNumber, StakingModuleId from src.web3py.extensions.csm import CSM from tests.factory.blockstamp import BlockStampFactory, ReferenceBlockStampFactory from tests.factory.configs import ChainConfigFactory, FrameConfigFactory @@ -22,7 +22,7 @@ @pytest.fixture(autouse=True) def mock_get_module_id(monkeypatch: pytest.MonkeyPatch): - monkeypatch.setattr(CSOracle, "_get_module_id", Mock()) + monkeypatch.setattr(CSOracle, "_get_staking_module", Mock()) @pytest.fixture(autouse=True) @@ -45,21 +45,16 @@ def test_get_stuck_operators(module: CSOracle, csm: CSM): module.w3.cc = Mock() module.w3.lido_validators = Mock() module.w3.lido_contracts = Mock() - module.w3.lido_validators.get_lido_node_operators_by_modules = Mock( - return_value={ - 1: { - type('NodeOperator', (object,), {'id': 0, 'stuck_validators_count': 0})(), - type('NodeOperator', (object,), {'id': 1, 'stuck_validators_count': 0})(), - type('NodeOperator', (object,), {'id': 2, 'stuck_validators_count': 1})(), - type('NodeOperator', (object,), {'id': 3, 'stuck_validators_count': 0})(), - type('NodeOperator', (object,), {'id': 4, 'stuck_validators_count': 100500})(), - type('NodeOperator', (object,), {'id': 5, 'stuck_validators_count': 100})(), - type('NodeOperator', (object,), {'id': 6, 'stuck_validators_count': 0})(), - }, - 2: {}, - 3: {}, - 4: {}, - } + module.w3.lido_contracts.staking_router.get_all_node_operator_digests = Mock( + return_value=[ + type('NodeOperator', (object,), {'id': 0, 'stuck_validators_count': 0})(), + type('NodeOperator', (object,), {'id': 1, 'stuck_validators_count': 0})(), + type('NodeOperator', (object,), {'id': 2, 'stuck_validators_count': 1})(), + type('NodeOperator', (object,), 
{'id': 3, 'stuck_validators_count': 0})(), + type('NodeOperator', (object,), {'id': 4, 'stuck_validators_count': 100500})(), + type('NodeOperator', (object,), {'id': 5, 'stuck_validators_count': 100})(), + type('NodeOperator', (object,), {'id': 6, 'stuck_validators_count': 0})(), + ] ) module.w3.csm.get_operators_with_stucks_in_range = Mock( @@ -89,20 +84,7 @@ def test_get_stuck_operators_left_border_before_enact(module: CSOracle, csm: CSM module.w3.cc = Mock() module.w3.lido_validators = Mock() module.w3.lido_contracts = Mock() - module.w3.lido_validators.get_lido_node_operators_by_modules = Mock( - return_value={ - 1: { - type('NodeOperator', (object,), {'id': 0, 'stuck_validators_count': 0})(), - type('NodeOperator', (object,), {'id': 1, 'stuck_validators_count': 0})(), - type('NodeOperator', (object,), {'id': 2, 'stuck_validators_count': 1})(), - type('NodeOperator', (object,), {'id': 3, 'stuck_validators_count': 0})(), - type('NodeOperator', (object,), {'id': 4, 'stuck_validators_count': 100500})(), - type('NodeOperator', (object,), {'id': 5, 'stuck_validators_count': 100})(), - type('NodeOperator', (object,), {'id': 6, 'stuck_validators_count': 0})(), - }, - 2: {}, - } - ) + module.w3.lido_contracts.staking_router.get_all_node_operator_digests = Mock(return_value=[]) module.w3.csm.get_operators_with_stucks_in_range = Mock( return_value=[ From bacfa9c76970981a9913596d95a944a0360deded Mon Sep 17 00:00:00 2001 From: vgorkavenko Date: Mon, 17 Feb 2025 11:02:29 +0100 Subject: [PATCH 09/20] fix: `mock_get_staking_module` --- tests/modules/csm/test_csm_distribution.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/modules/csm/test_csm_distribution.py b/tests/modules/csm/test_csm_distribution.py index aac45d83e..c2c04591a 100644 --- a/tests/modules/csm/test_csm_distribution.py +++ b/tests/modules/csm/test_csm_distribution.py @@ -6,7 +6,7 @@ from src.constants import UINT64_MAX from src.modules.csm.csm import CSOracle, CSMError -from 
src.modules.csm.log import ValidatorFrameSummary, OperatorFrameSummary +from src.modules.csm.log import ValidatorFrameSummary from src.modules.csm.state import AttestationsAccumulator, State from src.types import NodeOperatorId, ValidatorIndex from src.web3py.extensions import CSM @@ -15,8 +15,8 @@ @pytest.fixture(autouse=True) -def mock_get_module_id(monkeypatch: pytest.MonkeyPatch): - monkeypatch.setattr(CSOracle, "_get_module_id", Mock()) +def mock_get_staking_module(monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr(CSOracle, "_get_staking_module", Mock()) @pytest.fixture() From 6d0ab0148757ca1f0d89cfcf3c5f9310d1846760 Mon Sep 17 00:00:00 2001 From: vgorkavenko Date: Wed, 19 Feb 2025 11:08:33 +0100 Subject: [PATCH 10/20] fix: _calculate_frames --- src/modules/csm/state.py | 8 ++++---- tests/modules/csm/test_state.py | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/modules/csm/state.py b/src/modules/csm/state.py index 46f664157..0cb3eeac8 100644 --- a/src/modules/csm/state.py +++ b/src/modules/csm/state.py @@ -112,14 +112,14 @@ def is_fulfilled(self) -> bool: @property def frames(self): - return self.calculate_frames(self._epochs_to_process, self._epochs_per_frame) + return self._calculate_frames(self._epochs_to_process, self._epochs_per_frame) @staticmethod - def calculate_frames(epochs_to_process: tuple[EpochNumber, ...], epochs_per_frame: int) -> list[Frame]: + def _calculate_frames(epochs_to_process: tuple[EpochNumber, ...], epochs_per_frame: int) -> list[Frame]: """Split epochs to process into frames of `epochs_per_frame` length""" if len(epochs_to_process) % epochs_per_frame != 0: raise ValueError("Insufficient epochs to form a frame") - return [(frame[0], frame[-1]) for frame in batched(epochs_to_process, epochs_per_frame)] + return [(frame[0], frame[-1]) for frame in batched(sorted(epochs_to_process), epochs_per_frame)] def clear(self) -> None: self.data = {} @@ -156,7 +156,7 @@ def init_or_migrate( ) self.clear() - frames 
= self.calculate_frames(tuple(sequence(l_epoch, r_epoch)), epochs_per_frame) + frames = self._calculate_frames(tuple(sequence(l_epoch, r_epoch)), epochs_per_frame) frames_data: StateData = {frame: defaultdict(AttestationsAccumulator) for frame in frames} if not self.is_empty: diff --git a/tests/modules/csm/test_state.py b/tests/modules/csm/test_state.py index ee09fe1e5..82763561a 100644 --- a/tests/modules/csm/test_state.py +++ b/tests/modules/csm/test_state.py @@ -113,14 +113,14 @@ def test_is_fulfilled_returns_false_if_unprocessed_epochs_exist(): def test_calculate_frames_handles_exact_frame_size(): epochs = tuple(range(10)) - frames = State.calculate_frames(epochs, 5) + frames = State._calculate_frames(epochs, 5) assert frames == [(0, 4), (5, 9)] def test_calculate_frames_raises_error_for_insufficient_epochs(): epochs = tuple(range(8)) with pytest.raises(ValueError, match="Insufficient epochs to form a frame"): - State.calculate_frames(epochs, 5) + State._calculate_frames(epochs, 5) def test_clear_resets_state_to_empty(): From 9ebdc412b9709e60fc0c4bb3094ff0f28b0afb82 Mon Sep 17 00:00:00 2001 From: vgorkavenko Date: Wed, 19 Feb 2025 11:23:56 +0100 Subject: [PATCH 11/20] feat: cached `find_frame` --- src/modules/csm/checkpoint.py | 3 +-- src/modules/csm/state.py | 8 +++++--- tests/modules/csm/test_state.py | 21 ++++++++++++++++----- 3 files changed, 22 insertions(+), 10 deletions(-) diff --git a/src/modules/csm/checkpoint.py b/src/modules/csm/checkpoint.py index 69d0a79dd..222d1e626 100644 --- a/src/modules/csm/checkpoint.py +++ b/src/modules/csm/checkpoint.py @@ -205,12 +205,11 @@ def _check_duty( for root in block_roots: attestations = self.cc.get_block_attestations(root) process_attestations(attestations, committees, self.eip7549_supported) - frame = self.state.find_frame(duty_epoch) with lock: for committee in committees.values(): for validator_duty in committee: self.state.increment_duty( - frame, + duty_epoch, validator_duty.index, 
included=validator_duty.included, ) diff --git a/src/modules/csm/state.py b/src/modules/csm/state.py index 0cb3eeac8..6ea3db1a2 100644 --- a/src/modules/csm/state.py +++ b/src/modules/csm/state.py @@ -3,6 +3,7 @@ import pickle from collections import defaultdict from dataclasses import dataclass +from functools import lru_cache from itertools import batched from pathlib import Path from typing import Self @@ -127,15 +128,15 @@ def clear(self) -> None: self._processed_epochs.clear() assert self.is_empty + @lru_cache(variables.CSM_ORACLE_MAX_CONCURRENCY) def find_frame(self, epoch: EpochNumber) -> Frame: for epoch_range in self.frames: if epoch_range[0] <= epoch <= epoch_range[1]: return epoch_range raise ValueError(f"Epoch {epoch} is out of frames range: {self.frames}") - def increment_duty(self, frame: Frame, val_index: ValidatorIndex, included: bool) -> None: - if frame not in self.data: - raise ValueError(f"Frame {frame} is not found in the state") + def increment_duty(self, epoch: EpochNumber, val_index: ValidatorIndex, included: bool) -> None: + frame = self.find_frame(epoch) self.data[frame][val_index].add_duty(included) def add_processed_epoch(self, epoch: EpochNumber) -> None: @@ -176,6 +177,7 @@ def init_or_migrate( self._epochs_per_frame = epochs_per_frame self._epochs_to_process = tuple(sequence(l_epoch, r_epoch)) self._consensus_version = consensus_version + self.find_frame.cache_clear() self.commit() def _migrate_frames_data( diff --git a/tests/modules/csm/test_state.py b/tests/modules/csm/test_state.py index 82763561a..97004053c 100644 --- a/tests/modules/csm/test_state.py +++ b/tests/modules/csm/test_state.py @@ -149,44 +149,55 @@ def test_find_frame_raises_error_for_out_of_range_epoch(): def test_increment_duty_adds_duty_correctly(): state = State() + state._epochs_to_process = tuple(sequence(0, 31)) + state._epochs_per_frame = 32 frame = (0, 31) + duty_epoch, _ = frame state.data = { frame: defaultdict(AttestationsAccumulator, {ValidatorIndex(1): 
AttestationsAccumulator(10, 5)}), } - state.increment_duty(frame, ValidatorIndex(1), True) + state.increment_duty(duty_epoch, ValidatorIndex(1), True) assert state.data[frame][ValidatorIndex(1)].assigned == 11 assert state.data[frame][ValidatorIndex(1)].included == 6 def test_increment_duty_creates_new_validator_entry(): state = State() + state._epochs_to_process = tuple(sequence(0, 31)) + state._epochs_per_frame = 32 frame = (0, 31) + duty_epoch, _ = frame state.data = { frame: defaultdict(AttestationsAccumulator), } - state.increment_duty(frame, ValidatorIndex(2), True) + state.increment_duty(duty_epoch, ValidatorIndex(2), True) assert state.data[frame][ValidatorIndex(2)].assigned == 1 assert state.data[frame][ValidatorIndex(2)].included == 1 def test_increment_duty_handles_non_included_duty(): state = State() + state._epochs_to_process = tuple(sequence(0, 31)) + state._epochs_per_frame = 32 frame = (0, 31) + duty_epoch, _ = frame state.data = { frame: defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(10, 5)}), } - state.increment_duty(frame, ValidatorIndex(1), False) + state.increment_duty(duty_epoch, ValidatorIndex(1), False) assert state.data[frame][ValidatorIndex(1)].assigned == 11 assert state.data[frame][ValidatorIndex(1)].included == 5 def test_increment_duty_raises_error_for_out_of_range_epoch(): state = State() + state._epochs_to_process = tuple(sequence(0, 31)) + state._epochs_per_frame = 32 state.data = { (0, 31): defaultdict(AttestationsAccumulator), } - with pytest.raises(ValueError, match="is not found in the state"): - state.increment_duty((0, 32), ValidatorIndex(1), True) + with pytest.raises(ValueError, match="is out of frames range"): + state.increment_duty(32, ValidatorIndex(1), True) def test_add_processed_epoch_adds_epoch_to_processed_set(): From f9d4f6d711da3db2a5d30084e70ef372621a1bcc Mon Sep 17 00:00:00 2001 From: vgorkavenko Date: Wed, 19 Feb 2025 11:34:07 +0100 Subject: [PATCH 12/20] refactor: 
`init_or_migrate` and `_migrate_frames_data` --- src/modules/csm/state.py | 50 +++++++++++++-------------------- tests/modules/csm/test_state.py | 39 ++++++++++++------------- 2 files changed, 40 insertions(+), 49 deletions(-) diff --git a/src/modules/csm/state.py b/src/modules/csm/state.py index 6ea3db1a2..653193992 100644 --- a/src/modules/csm/state.py +++ b/src/modules/csm/state.py @@ -158,54 +158,44 @@ def init_or_migrate( self.clear() frames = self._calculate_frames(tuple(sequence(l_epoch, r_epoch)), epochs_per_frame) - frames_data: StateData = {frame: defaultdict(AttestationsAccumulator) for frame in frames} if not self.is_empty: cached_frames = self.frames if cached_frames == frames: logger.info({"msg": "No need to migrate duties data cache"}) return + self._migrate_frames_data(frames) + else: + self.data = {frame: defaultdict(AttestationsAccumulator) for frame in frames} - frames_data, migration_status = self._migrate_frames_data(cached_frames, frames) - - for current_frame, migrated in migration_status.items(): - if not migrated: - logger.warning({"msg": f"Invalidating frame duties data cache: {current_frame}"}) - self._processed_epochs.difference_update(sequence(*current_frame)) - - self.data = frames_data self._epochs_per_frame = epochs_per_frame self._epochs_to_process = tuple(sequence(l_epoch, r_epoch)) self._consensus_version = consensus_version self.find_frame.cache_clear() self.commit() - def _migrate_frames_data( - self, current_frames: list[Frame], new_frames: list[Frame] - ) -> tuple[StateData, dict[Frame, bool]]: - migration_status = {frame: False for frame in current_frames} + def _migrate_frames_data(self, new_frames: list[Frame]): + logger.info({"msg": f"Migrating duties data cache: {self.frames=} -> {new_frames=}"}) new_data: StateData = {frame: defaultdict(AttestationsAccumulator) for frame in new_frames} - logger.info({"msg": f"Migrating duties data cache: {current_frames=} -> {new_frames=}"}) + def overlaps(a: Frame, b: Frame): + return 
a[0] <= b[0] and a[1] >= b[1] - for current_frame in current_frames: - curr_frame_l_epoch, curr_frame_r_epoch = current_frame - for new_frame in new_frames: - if current_frame == new_frame: - new_data[new_frame] = self.data[current_frame] - migration_status[current_frame] = True - break - - new_frame_l_epoch, new_frame_r_epoch = new_frame - if curr_frame_l_epoch >= new_frame_l_epoch and curr_frame_r_epoch <= new_frame_r_epoch: - logger.info({"msg": f"Migrating frame duties data cache: {current_frame=} -> {new_frame=}"}) - for val, duty in self.data[current_frame].items(): + consumed = [] + for new_frame in new_frames: + for frame_to_consume in self.frames: + if overlaps(new_frame, frame_to_consume): + assert frame_to_consume not in consumed + consumed.append(frame_to_consume) + for val, duty in self.data[frame_to_consume].items(): new_data[new_frame][val].assigned += duty.assigned new_data[new_frame][val].included += duty.included - migration_status[current_frame] = True - break - - return new_data, migration_status + for frame in self.frames: + if frame in consumed: + continue + logger.warning({"msg": f"Invalidating frame duties data cache: {frame}"}) + self._processed_epochs -= set(sequence(*frame)) + self.data = new_data def validate(self, l_epoch: EpochNumber, r_epoch: EpochNumber) -> None: if not self.is_fulfilled: diff --git a/tests/modules/csm/test_state.py b/tests/modules/csm/test_state.py index 97004053c..11808b1ea 100644 --- a/tests/modules/csm/test_state.py +++ b/tests/modules/csm/test_state.py @@ -294,72 +294,73 @@ def test_init_or_migrate_discards_unmigrated_frame(): def test_migrate_frames_data_creates_new_data_correctly(): state = State() - current_frames = [(0, 31), (32, 63)] + state._epochs_to_process = tuple(sequence(0, 63)) + state._epochs_per_frame = 32 new_frames = [(0, 63)] state.data = { (0, 31): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(10, 5)}), (32, 63): defaultdict(AttestationsAccumulator, 
{ValidatorIndex(1): AttestationsAccumulator(20, 15)}), } - new_data, migration_status = state._migrate_frames_data(current_frames, new_frames) - assert new_data == { + state._migrate_frames_data(new_frames) + assert state.data == { (0, 63): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(30, 20)}) } - assert migration_status == {(0, 31): True, (32, 63): True} def test_migrate_frames_data_handles_no_migration(): state = State() - current_frames = [(0, 31)] + state._epochs_to_process = tuple(sequence(0, 31)) + state._epochs_per_frame = 32 new_frames = [(0, 31)] state.data = { (0, 31): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(10, 5)}), } - new_data, migration_status = state._migrate_frames_data(current_frames, new_frames) - assert new_data == { + state._migrate_frames_data(new_frames) + assert state.data == { (0, 31): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(10, 5)}) } - assert migration_status == {(0, 31): True} def test_migrate_frames_data_handles_partial_migration(): state = State() - current_frames = [(0, 31), (32, 63)] + state._epochs_to_process = tuple(sequence(0, 63)) + state._epochs_per_frame = 32 new_frames = [(0, 31), (32, 95)] state.data = { (0, 31): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(10, 5)}), (32, 63): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(20, 15)}), } - new_data, migration_status = state._migrate_frames_data(current_frames, new_frames) - assert new_data == { + state._migrate_frames_data(new_frames) + assert state.data == { (0, 31): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(10, 5)}), (32, 95): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(20, 15)}), } - assert migration_status == {(0, 31): True, (32, 63): True} def test_migrate_frames_data_handles_no_data(): state = State() + 
state._epochs_to_process = tuple(sequence(0, 31)) + state._epochs_per_frame = 32 current_frames = [(0, 31)] new_frames = [(0, 31)] state.data = {frame: defaultdict(AttestationsAccumulator) for frame in current_frames} - new_data, migration_status = state._migrate_frames_data(current_frames, new_frames) - assert new_data == {(0, 31): defaultdict(AttestationsAccumulator)} - assert migration_status == {(0, 31): True} + state._migrate_frames_data(new_frames) + assert state.data == {(0, 31): defaultdict(AttestationsAccumulator)} def test_migrate_frames_data_handles_wider_old_frame(): state = State() - current_frames = [(0, 63)] + state._epochs_to_process = tuple(sequence(0, 63)) + state._epochs_per_frame = 64 new_frames = [(0, 31), (32, 63)] state.data = { (0, 63): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(30, 20)}), } - new_data, migration_status = state._migrate_frames_data(current_frames, new_frames) - assert new_data == { + state._migrate_frames_data(new_frames) + assert state.data == { (0, 31): defaultdict(AttestationsAccumulator), (32, 63): defaultdict(AttestationsAccumulator), } - assert migration_status == {(0, 63): False} def test_validate_raises_error_if_state_not_fulfilled(): From 84eb94f877656a71a414ee7479fe8bc500a84673 Mon Sep 17 00:00:00 2001 From: vgorkavenko Date: Wed, 19 Feb 2025 11:35:07 +0100 Subject: [PATCH 13/20] refactor: `find_frame` --- src/modules/csm/state.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/modules/csm/state.py b/src/modules/csm/state.py index 653193992..cb33c499c 100644 --- a/src/modules/csm/state.py +++ b/src/modules/csm/state.py @@ -131,7 +131,8 @@ def clear(self) -> None: @lru_cache(variables.CSM_ORACLE_MAX_CONCURRENCY) def find_frame(self, epoch: EpochNumber) -> Frame: for epoch_range in self.frames: - if epoch_range[0] <= epoch <= epoch_range[1]: + from_epoch, to_epoch = epoch_range + if from_epoch <= epoch <= to_epoch: return epoch_range raise 
ValueError(f"Epoch {epoch} is out of frames range: {self.frames}") From ac0d8b0b6ff9337bec25ec4d606e4f62ac9a0c89 Mon Sep 17 00:00:00 2001 From: vgorkavenko Date: Wed, 19 Feb 2025 11:35:38 +0100 Subject: [PATCH 14/20] refactor: `init_or_migrate` -> `migrate` --- src/modules/csm/csm.py | 2 +- src/modules/csm/state.py | 2 +- tests/modules/csm/test_checkpoint.py | 8 ++++---- tests/modules/csm/test_state.py | 10 +++++----- 4 files changed, 11 insertions(+), 11 deletions(-) diff --git a/src/modules/csm/csm.py b/src/modules/csm/csm.py index 4781e6f00..5fe6f7a94 100644 --- a/src/modules/csm/csm.py +++ b/src/modules/csm/csm.py @@ -200,7 +200,7 @@ def collect_data(self, blockstamp: BlockStamp) -> bool: logger.info({"msg": "The starting epoch of the epochs range is not finalized yet"}) return False - self.state.init_or_migrate(l_epoch, r_epoch, converter.frame_config.epochs_per_frame, consensus_version) + self.state.migrate(l_epoch, r_epoch, converter.frame_config.epochs_per_frame, consensus_version) self.state.log_progress() if self.state.is_fulfilled: diff --git a/src/modules/csm/state.py b/src/modules/csm/state.py index cb33c499c..2b9abeb86 100644 --- a/src/modules/csm/state.py +++ b/src/modules/csm/state.py @@ -146,7 +146,7 @@ def add_processed_epoch(self, epoch: EpochNumber) -> None: def log_progress(self) -> None: logger.info({"msg": f"Processed {len(self._processed_epochs)} of {len(self._epochs_to_process)} epochs"}) - def init_or_migrate( + def migrate( self, l_epoch: EpochNumber, r_epoch: EpochNumber, epochs_per_frame: int, consensus_version: int ) -> None: if consensus_version != self._consensus_version: diff --git a/tests/modules/csm/test_checkpoint.py b/tests/modules/csm/test_checkpoint.py index 4b456ed03..070d82bf9 100644 --- a/tests/modules/csm/test_checkpoint.py +++ b/tests/modules/csm/test_checkpoint.py @@ -326,7 +326,7 @@ def test_checkpoints_processor_no_eip7549_support( monkeypatch: pytest.MonkeyPatch, ): state = State() - 
state.init_or_migrate(EpochNumber(0), EpochNumber(255), 256, 1) + state.migrate(EpochNumber(0), EpochNumber(255), 256, 1) processor = FrameCheckpointProcessor( consensus_client, state, @@ -354,7 +354,7 @@ def test_checkpoints_processor_check_duty( converter, ): state = State() - state.init_or_migrate(0, 255, 256, 1) + state.migrate(0, 255, 256, 1) finalized_blockstamp = ... processor = FrameCheckpointProcessor( consensus_client, @@ -379,7 +379,7 @@ def test_checkpoints_processor_process( converter, ): state = State() - state.init_or_migrate(0, 255, 256, 1) + state.migrate(0, 255, 256, 1) finalized_blockstamp = ... processor = FrameCheckpointProcessor( consensus_client, @@ -404,7 +404,7 @@ def test_checkpoints_processor_exec( converter, ): state = State() - state.init_or_migrate(0, 255, 256, 1) + state.migrate(0, 255, 256, 1) finalized_blockstamp = ... processor = FrameCheckpointProcessor( consensus_client, diff --git a/tests/modules/csm/test_state.py b/tests/modules/csm/test_state.py index 11808b1ea..4eba5cfd0 100644 --- a/tests/modules/csm/test_state.py +++ b/tests/modules/csm/test_state.py @@ -218,7 +218,7 @@ def test_init_or_migrate_discards_data_on_version_change(): state._consensus_version = 1 state.clear = Mock() state.commit = Mock() - state.init_or_migrate(0, 63, 32, 2) + state.migrate(0, 63, 32, 2) state.clear.assert_called_once() state.commit.assert_called_once() @@ -233,7 +233,7 @@ def test_init_or_migrate_no_migration_needed(): (32, 63): defaultdict(AttestationsAccumulator), } state.commit = Mock() - state.init_or_migrate(0, 63, 32, 1) + state.migrate(0, 63, 32, 1) state.commit.assert_not_called() @@ -247,7 +247,7 @@ def test_init_or_migrate_migrates_data(): (32, 63): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(20, 15)}), } state.commit = Mock() - state.init_or_migrate(0, 63, 64, 1) + state.migrate(0, 63, 64, 1) assert state.data == { (0, 63): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): 
AttestationsAccumulator(30, 20)}), } @@ -263,7 +263,7 @@ def test_init_or_migrate_invalidates_unmigrated_frames(): (0, 63): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(30, 20)}), } state.commit = Mock() - state.init_or_migrate(0, 31, 32, 1) + state.migrate(0, 31, 32, 1) assert state.data == { (0, 31): defaultdict(AttestationsAccumulator), } @@ -283,7 +283,7 @@ def test_init_or_migrate_discards_unmigrated_frame(): } state._processed_epochs = set(sequence(0, 95)) state.commit = Mock() - state.init_or_migrate(0, 63, 32, 1) + state.migrate(0, 63, 32, 1) assert state.data == { (0, 31): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(10, 5)}), (32, 63): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(20, 15)}), From 9583a6b488a571cd1b985f4e390ce18e6ccd88d1 Mon Sep 17 00:00:00 2001 From: vgorkavenko Date: Wed, 19 Feb 2025 11:43:22 +0100 Subject: [PATCH 15/20] refactor: `calculate_distribution` types --- src/modules/csm/csm.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/modules/csm/csm.py b/src/modules/csm/csm.py index 5fe6f7a94..39f22796d 100644 --- a/src/modules/csm/csm.py +++ b/src/modules/csm/csm.py @@ -3,6 +3,7 @@ from typing import Iterator from hexbytes import HexBytes +from web3.types import Wei from src.constants import TOTAL_BASIS_POINTS, UINT64_MAX from src.metrics.prometheus.business import CONTRACT_ON_PAUSE @@ -226,12 +227,12 @@ def collect_data(self, blockstamp: BlockStamp) -> bool: def calculate_distribution( self, blockstamp: ReferenceBlockStamp - ) -> tuple[int, defaultdict[NodeOperatorId, int], list[FramePerfLog]]: + ) -> tuple[Shares, defaultdict[NodeOperatorId, Shares], list[FramePerfLog]]: """Computes distribution of fee shares at the given timestamp""" operators_to_validators = self.module_validators_by_node_operators(blockstamp) - total_distributed = 0 - total_rewards = defaultdict[NodeOperatorId, int](int) + 
total_distributed = Shares(0) + total_rewards = defaultdict[NodeOperatorId, Shares](Shares) logs: list[FramePerfLog] = [] for frame in self.state.frames: From 80dc2f6facf8f6eb0999dcfca6c665cbf58ec93d Mon Sep 17 00:00:00 2001 From: vgorkavenko Date: Wed, 19 Feb 2025 11:45:25 +0100 Subject: [PATCH 16/20] refactor: `test_csm_module.py` --- tests/modules/csm/test_csm_module.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/tests/modules/csm/test_csm_module.py b/tests/modules/csm/test_csm_module.py index 8fd863cd5..ee096b64d 100644 --- a/tests/modules/csm/test_csm_module.py +++ b/tests/modules/csm/test_csm_module.py @@ -21,7 +21,7 @@ @pytest.fixture(autouse=True) -def mock_get_module_id(monkeypatch: pytest.MonkeyPatch): +def mock_get_staking_module(monkeypatch: pytest.MonkeyPatch): monkeypatch.setattr(CSOracle, "_get_staking_module", Mock()) @@ -41,19 +41,18 @@ def test_init(module: CSOracle): def test_get_stuck_operators(module: CSOracle, csm: CSM): module.module = Mock() # type: ignore - module.module_id = StakingModuleId(1) module.w3.cc = Mock() module.w3.lido_validators = Mock() module.w3.lido_contracts = Mock() module.w3.lido_contracts.staking_router.get_all_node_operator_digests = Mock( return_value=[ - type('NodeOperator', (object,), {'id': 0, 'stuck_validators_count': 0})(), - type('NodeOperator', (object,), {'id': 1, 'stuck_validators_count': 0})(), - type('NodeOperator', (object,), {'id': 2, 'stuck_validators_count': 1})(), - type('NodeOperator', (object,), {'id': 3, 'stuck_validators_count': 0})(), - type('NodeOperator', (object,), {'id': 4, 'stuck_validators_count': 100500})(), - type('NodeOperator', (object,), {'id': 5, 'stuck_validators_count': 100})(), - type('NodeOperator', (object,), {'id': 6, 'stuck_validators_count': 0})(), + Mock(id=0, stuck_validators_count=0), + Mock(id=1, stuck_validators_count=0), + Mock(id=2, stuck_validators_count=1), + Mock(id=3, stuck_validators_count=0), + Mock(id=4, 
stuck_validators_count=100500), + Mock(id=5, stuck_validators_count=100), + Mock(id=6, stuck_validators_count=0), ] ) @@ -80,7 +79,6 @@ def test_get_stuck_operators(module: CSOracle, csm: CSM): def test_get_stuck_operators_left_border_before_enact(module: CSOracle, csm: CSM, caplog: pytest.LogCaptureFixture): module.module = Mock() # type: ignore - module.module_id = StakingModuleId(3) module.w3.cc = Mock() module.w3.lido_validators = Mock() module.w3.lido_contracts = Mock() From 0d122e87786d5d0c0523de71938673bacb26a229 Mon Sep 17 00:00:00 2001 From: vgorkavenko Date: Wed, 19 Feb 2025 11:52:39 +0100 Subject: [PATCH 17/20] feat: add negative value checking in `calc_rewards_distribution_in_frame` --- src/modules/csm/csm.py | 2 ++ tests/modules/csm/test_csm_distribution.py | 8 ++++++++ 2 files changed, 10 insertions(+) diff --git a/src/modules/csm/csm.py b/src/modules/csm/csm.py index 39f22796d..be5dd6f2a 100644 --- a/src/modules/csm/csm.py +++ b/src/modules/csm/csm.py @@ -344,6 +344,8 @@ def calc_rewards_distribution_in_frame( participation_shares: dict[NodeOperatorId, int], rewards_to_distribute: int, ) -> dict[NodeOperatorId, int]: + if rewards_to_distribute < 0: + raise ValueError(f"Invalid rewards to distribute: {rewards_to_distribute}") rewards_distribution: dict[NodeOperatorId, int] = defaultdict(int) total_participation = sum(participation_shares.values()) diff --git a/tests/modules/csm/test_csm_distribution.py b/tests/modules/csm/test_csm_distribution.py index c2c04591a..f277f6eb8 100644 --- a/tests/modules/csm/test_csm_distribution.py +++ b/tests/modules/csm/test_csm_distribution.py @@ -327,3 +327,11 @@ def test_calc_rewards_distribution_in_frame_handles_partial_participation(): assert rewards_distribution[NodeOperatorId(1)] == Wei(1 * 10**18) assert rewards_distribution[NodeOperatorId(2)] == 0 + + +def test_calc_rewards_distribution_in_frame_handles_negative_to_distribute(): + participation_shares = {NodeOperatorId(1): 100, NodeOperatorId(2): 200} + 
rewards_to_distribute = Wei(-1) + + with pytest.raises(ValueError, match="Invalid rewards to distribute"): + CSOracle.calc_rewards_distribution_in_frame(participation_shares, rewards_to_distribute) From f9290da6ecc3711e2867d897366a25f33fe41c89 Mon Sep 17 00:00:00 2001 From: vgorkavenko Date: Wed, 19 Feb 2025 11:53:40 +0100 Subject: [PATCH 18/20] fix: linter --- src/modules/csm/csm.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/modules/csm/csm.py b/src/modules/csm/csm.py index be5dd6f2a..e9e042172 100644 --- a/src/modules/csm/csm.py +++ b/src/modules/csm/csm.py @@ -3,7 +3,6 @@ from typing import Iterator from hexbytes import HexBytes -from web3.types import Wei from src.constants import TOTAL_BASIS_POINTS, UINT64_MAX from src.metrics.prometheus.business import CONTRACT_ON_PAUSE From 0731525c1f27cd2578b1822fee65311e6798f026 Mon Sep 17 00:00:00 2001 From: vgorkavenko Date: Wed, 19 Feb 2025 12:39:44 +0100 Subject: [PATCH 19/20] refactor: `frames` now is `State` attribute --- src/modules/csm/state.py | 9 ++---- tests/modules/csm/test_state.py | 52 ++++++++++++--------------------- 2 files changed, 21 insertions(+), 40 deletions(-) diff --git a/src/modules/csm/state.py b/src/modules/csm/state.py index 2b9abeb86..2f284f980 100644 --- a/src/modules/csm/state.py +++ b/src/modules/csm/state.py @@ -49,11 +49,11 @@ class State: The state can be migrated to be used for another frame's report by calling the `migrate` method. """ + frames: list[Frame] data: StateData _epochs_to_process: tuple[EpochNumber, ...] 
_processed_epochs: set[EpochNumber] - _epochs_per_frame: int _consensus_version: int = 1 @@ -61,7 +61,6 @@ def __init__(self) -> None: self.data = {} self._epochs_to_process = tuple() self._processed_epochs = set() - self._epochs_per_frame = 0 EXTENSION = ".pkl" @@ -111,10 +110,6 @@ def unprocessed_epochs(self) -> set[EpochNumber]: def is_fulfilled(self) -> bool: return not self.unprocessed_epochs - @property - def frames(self): - return self._calculate_frames(self._epochs_to_process, self._epochs_per_frame) - @staticmethod def _calculate_frames(epochs_to_process: tuple[EpochNumber, ...], epochs_per_frame: int) -> list[Frame]: """Split epochs to process into frames of `epochs_per_frame` length""" @@ -169,7 +164,7 @@ def migrate( else: self.data = {frame: defaultdict(AttestationsAccumulator) for frame in frames} - self._epochs_per_frame = epochs_per_frame + self.frames = frames self._epochs_to_process = tuple(sequence(l_epoch, r_epoch)) self._consensus_version = consensus_version self.find_frame.cache_clear() diff --git a/tests/modules/csm/test_state.py b/tests/modules/csm/test_state.py index 4eba5cfd0..2fb3cd078 100644 --- a/tests/modules/csm/test_state.py +++ b/tests/modules/csm/test_state.py @@ -132,16 +132,14 @@ def test_clear_resets_state_to_empty(): def test_find_frame_returns_correct_frame(): state = State() - state._epochs_to_process = tuple(sequence(0, 31)) - state._epochs_per_frame = 32 + state.frames = [(0, 31)] state.data = {(0, 31): defaultdict(AttestationsAccumulator)} assert state.find_frame(15) == (0, 31) def test_find_frame_raises_error_for_out_of_range_epoch(): state = State() - state._epochs_to_process = tuple(sequence(0, 31)) - state._epochs_per_frame = 32 + state.frames = [(0, 31)] state.data = {(0, 31): defaultdict(AttestationsAccumulator)} with pytest.raises(ValueError, match="Epoch 32 is out of frames range"): state.find_frame(32) @@ -149,9 +147,8 @@ def test_find_frame_raises_error_for_out_of_range_epoch(): def 
test_increment_duty_adds_duty_correctly(): state = State() - state._epochs_to_process = tuple(sequence(0, 31)) - state._epochs_per_frame = 32 frame = (0, 31) + state.frames = [frame] duty_epoch, _ = frame state.data = { frame: defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(10, 5)}), @@ -163,9 +160,8 @@ def test_increment_duty_adds_duty_correctly(): def test_increment_duty_creates_new_validator_entry(): state = State() - state._epochs_to_process = tuple(sequence(0, 31)) - state._epochs_per_frame = 32 frame = (0, 31) + state.frames = [frame] duty_epoch, _ = frame state.data = { frame: defaultdict(AttestationsAccumulator), @@ -177,8 +173,8 @@ def test_increment_duty_creates_new_validator_entry(): def test_increment_duty_handles_non_included_duty(): state = State() - state._epochs_to_process = tuple(sequence(0, 31)) - state._epochs_per_frame = 32 + frame = (0, 31) + state.frames = [frame] frame = (0, 31) duty_epoch, _ = frame state.data = { @@ -191,10 +187,10 @@ def test_increment_duty_handles_non_included_duty(): def test_increment_duty_raises_error_for_out_of_range_epoch(): state = State() - state._epochs_to_process = tuple(sequence(0, 31)) - state._epochs_per_frame = 32 + frame = (0, 31) + state.frames = [frame] state.data = { - (0, 31): defaultdict(AttestationsAccumulator), + frame: defaultdict(AttestationsAccumulator), } with pytest.raises(ValueError, match="is out of frames range"): state.increment_duty(32, ValidatorIndex(1), True) @@ -226,8 +222,7 @@ def test_init_or_migrate_discards_data_on_version_change(): def test_init_or_migrate_no_migration_needed(): state = State() state._consensus_version = 1 - state._epochs_to_process = tuple(sequence(0, 63)) - state._epochs_per_frame = 32 + state.frames = [(0, 31), (32, 63)] state.data = { (0, 31): defaultdict(AttestationsAccumulator), (32, 63): defaultdict(AttestationsAccumulator), @@ -240,8 +235,7 @@ def test_init_or_migrate_no_migration_needed(): def 
test_init_or_migrate_migrates_data(): state = State() state._consensus_version = 1 - state._epochs_to_process = tuple(sequence(0, 63)) - state._epochs_per_frame = 32 + state.frames = [(0, 31), (32, 63)] state.data = { (0, 31): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(10, 5)}), (32, 63): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(20, 15)}), @@ -257,8 +251,7 @@ def test_init_or_migrate_migrates_data(): def test_init_or_migrate_invalidates_unmigrated_frames(): state = State() state._consensus_version = 1 - state._epochs_to_process = tuple(sequence(0, 63)) - state._epochs_per_frame = 64 + state.frames = [(0, 63)] state.data = { (0, 63): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(30, 20)}), } @@ -274,8 +267,7 @@ def test_init_or_migrate_invalidates_unmigrated_frames(): def test_init_or_migrate_discards_unmigrated_frame(): state = State() state._consensus_version = 1 - state._epochs_to_process = tuple(sequence(0, 95)) - state._epochs_per_frame = 32 + state.frames = [(0, 31), (32, 63), (64, 95)] state.data = { (0, 31): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(10, 5)}), (32, 63): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(20, 15)}), @@ -294,8 +286,7 @@ def test_init_or_migrate_discards_unmigrated_frame(): def test_migrate_frames_data_creates_new_data_correctly(): state = State() - state._epochs_to_process = tuple(sequence(0, 63)) - state._epochs_per_frame = 32 + state.frames = [(0, 31), (32, 63)] new_frames = [(0, 63)] state.data = { (0, 31): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(10, 5)}), @@ -309,8 +300,7 @@ def test_migrate_frames_data_creates_new_data_correctly(): def test_migrate_frames_data_handles_no_migration(): state = State() - state._epochs_to_process = tuple(sequence(0, 31)) - state._epochs_per_frame = 32 + state.frames = [(0, 
31)] new_frames = [(0, 31)] state.data = { (0, 31): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(10, 5)}), @@ -323,8 +313,7 @@ def test_migrate_frames_data_handles_no_migration(): def test_migrate_frames_data_handles_partial_migration(): state = State() - state._epochs_to_process = tuple(sequence(0, 63)) - state._epochs_per_frame = 32 + state.frames = [(0, 31), (32, 63)] new_frames = [(0, 31), (32, 95)] state.data = { (0, 31): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(10, 5)}), @@ -339,19 +328,16 @@ def test_migrate_frames_data_handles_partial_migration(): def test_migrate_frames_data_handles_no_data(): state = State() - state._epochs_to_process = tuple(sequence(0, 31)) - state._epochs_per_frame = 32 - current_frames = [(0, 31)] + state.frames = [(0, 31)] new_frames = [(0, 31)] - state.data = {frame: defaultdict(AttestationsAccumulator) for frame in current_frames} + state.data = {frame: defaultdict(AttestationsAccumulator) for frame in state.frames} state._migrate_frames_data(new_frames) assert state.data == {(0, 31): defaultdict(AttestationsAccumulator)} def test_migrate_frames_data_handles_wider_old_frame(): state = State() - state._epochs_to_process = tuple(sequence(0, 63)) - state._epochs_per_frame = 64 + state.frames = [(0, 63)] new_frames = [(0, 31), (32, 63)] state.data = { (0, 63): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(30, 20)}), From 15d5ab2681e7f44f18b0278bd8c51cd5357a0e12 Mon Sep 17 00:00:00 2001 From: vgorkavenko Date: Wed, 19 Feb 2025 12:53:46 +0100 Subject: [PATCH 20/20] refactor: `migrate` --- src/modules/csm/state.py | 20 ++++++++------------ 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/src/modules/csm/state.py b/src/modules/csm/state.py index 2f284f980..48debfe83 100644 --- a/src/modules/csm/state.py +++ b/src/modules/csm/state.py @@ -58,6 +58,7 @@ class State: _consensus_version: int = 1 def __init__(self) -> 
None: + self.frames = [] self.data = {} self._epochs_to_process = tuple() self._processed_epochs = set() @@ -153,21 +154,16 @@ def migrate( ) self.clear() - frames = self._calculate_frames(tuple(sequence(l_epoch, r_epoch)), epochs_per_frame) + new_frames = self._calculate_frames(tuple(sequence(l_epoch, r_epoch)), epochs_per_frame) + if self.frames == new_frames: + logger.info({"msg": "No need to migrate duties data cache"}) + return + self._migrate_frames_data(new_frames) - if not self.is_empty: - cached_frames = self.frames - if cached_frames == frames: - logger.info({"msg": "No need to migrate duties data cache"}) - return - self._migrate_frames_data(frames) - else: - self.data = {frame: defaultdict(AttestationsAccumulator) for frame in frames} - - self.frames = frames + self.frames = new_frames + self.find_frame.cache_clear() self._epochs_to_process = tuple(sequence(l_epoch, r_epoch)) self._consensus_version = consensus_version - self.find_frame.cache_clear() self.commit() def _migrate_frames_data(self, new_frames: list[Frame]):