diff --git a/src/modules/csm/checkpoint.py b/src/modules/csm/checkpoint.py index 0efc326c6..222d1e626 100644 --- a/src/modules/csm/checkpoint.py +++ b/src/modules/csm/checkpoint.py @@ -143,6 +143,7 @@ def exec(self, checkpoint: FrameCheckpoint) -> int: for duty_epoch in unprocessed_epochs } self._process(unprocessed_epochs, duty_epochs_roots) + self.state.commit() return len(unprocessed_epochs) def _get_block_roots(self, checkpoint_slot: SlotNumber): @@ -204,18 +205,17 @@ def _check_duty( for root in block_roots: attestations = self.cc.get_block_attestations(root) process_attestations(attestations, committees, self.eip7549_supported) - with lock: for committee in committees.values(): for validator_duty in committee: - self.state.inc( + self.state.increment_duty( + duty_epoch, validator_duty.index, included=validator_duty.included, ) if duty_epoch not in self.state.unprocessed_epochs: raise ValueError(f"Epoch {duty_epoch} is not in epochs that should be processed") self.state.add_processed_epoch(duty_epoch) - self.state.commit() self.state.log_progress() unprocessed_epochs = self.state.unprocessed_epochs CSM_UNPROCESSED_EPOCHS_COUNT.set(len(unprocessed_epochs)) diff --git a/src/modules/csm/csm.py b/src/modules/csm/csm.py index 543a276e2..e9e042172 100644 --- a/src/modules/csm/csm.py +++ b/src/modules/csm/csm.py @@ -12,8 +12,8 @@ ) from src.metrics.prometheus.duration_meter import duration_meter from src.modules.csm.checkpoint import FrameCheckpointProcessor, FrameCheckpointsIterator, MinStepIsNotReached -from src.modules.csm.log import FramePerfLog -from src.modules.csm.state import State +from src.modules.csm.log import FramePerfLog, OperatorFrameSummary +from src.modules.csm.state import State, Frame, AttestationsAccumulator from src.modules.csm.tree import Tree from src.modules.csm.types import ReportData, Shares from src.modules.submodules.consensus import ConsensusModule @@ -28,13 +28,12 @@ ReferenceBlockStamp, SlotNumber, StakingModuleAddress, - 
StakingModuleId, ) from src.utils.blockstamp import build_blockstamp from src.utils.cache import global_lru_cache as lru_cache -from src.utils.slot import get_next_non_missed_slot +from src.utils.slot import get_next_non_missed_slot, get_reference_blockstamp from src.utils.web3converter import Web3Converter -from src.web3py.extensions.lido_validators import NodeOperatorId, StakingModule, ValidatorsByNodeOperator +from src.web3py.extensions.lido_validators import NodeOperatorId, StakingModule, ValidatorsByNodeOperator, LidoValidator from src.web3py.types import Web3 logger = logging.getLogger(__name__) @@ -62,13 +61,13 @@ class CSOracle(BaseModule, ConsensusModule): COMPATIBLE_ONCHAIN_VERSIONS = [(1, 1), (1, 2)] report_contract: CSFeeOracleContract - module_id: StakingModuleId + staking_module: StakingModule def __init__(self, w3: Web3): self.report_contract = w3.csm.oracle self.state = State.load() super().__init__(w3) - self.module_id = self._get_module_id() + self.staking_module = self._get_staking_module() def refresh_contracts(self): self.report_contract = self.w3.csm.oracle # type: ignore @@ -101,15 +100,15 @@ def build_report(self, blockstamp: ReferenceBlockStamp) -> tuple: if (prev_cid is None) != (prev_root == ZERO_HASH): raise InconsistentData(f"Got inconsistent previous tree data: {prev_root=} {prev_cid=}") - distributed, shares, log = self.calculate_distribution(blockstamp) + total_distributed, total_rewards, logs = self.calculate_distribution(blockstamp) - if distributed != sum(shares.values()): - raise InconsistentData(f"Invalid distribution: {sum(shares.values())=} != {distributed=}") + if total_distributed != sum(total_rewards.values()): + raise InconsistentData(f"Invalid distribution: {sum(total_rewards.values())=} != {total_distributed=}") - log_cid = self.publish_log(log) + log_cid = self.publish_log(logs) - if not distributed and not shares: - logger.info({"msg": "No shares distributed in the current frame"}) + if not total_distributed and not 
total_rewards: + logger.info({"msg": "No rewards distributed in the current frame"}) return ReportData( self.get_consensus_version(blockstamp), blockstamp.ref_slot, @@ -120,13 +119,13 @@ def build_report(self, blockstamp: ReferenceBlockStamp) -> tuple: ).as_tuple() if prev_cid and prev_root != ZERO_HASH: - # Update cumulative amount of shares for all operators. - for no_id, acc_shares in self.get_accumulated_shares(prev_cid, prev_root): - shares[no_id] += acc_shares + # Update cumulative amount of stETH shares for all operators. + for no_id, accumulated_rewards in self.get_accumulated_rewards(prev_cid, prev_root): + total_rewards[no_id] += accumulated_rewards else: logger.info({"msg": "No previous distribution. Nothing to accumulate"}) - tree = self.make_tree(shares) + tree = self.make_tree(total_rewards) tree_cid = self.publish_tree(tree) return ReportData( @@ -135,7 +134,7 @@ def build_report(self, blockstamp: ReferenceBlockStamp) -> tuple: tree_root=tree.root, tree_cid=tree_cid, log_cid=log_cid, - distributed=distributed, + distributed=total_distributed, ).as_tuple() def is_main_data_submitted(self, blockstamp: BlockStamp) -> bool: @@ -160,7 +159,7 @@ def module_validators_by_node_operators(self, blockstamp: BlockStamp) -> Validat def validate_state(self, blockstamp: ReferenceBlockStamp) -> None: # NOTE: We cannot use `r_epoch` from the `current_frame_range` call because the `blockstamp` is a # `ReferenceBlockStamp`, hence it's a block the frame ends at. We use `ref_epoch` instead. 
- l_epoch, _ = self.current_frame_range(blockstamp) + l_epoch, _ = self.get_epochs_range_to_process(blockstamp) r_epoch = blockstamp.ref_epoch self.state.validate(l_epoch, r_epoch) @@ -175,8 +174,8 @@ def collect_data(self, blockstamp: BlockStamp) -> bool: converter = self.converter(blockstamp) - l_epoch, r_epoch = self.current_frame_range(blockstamp) - logger.info({"msg": f"Frame for performance data collect: epochs [{l_epoch};{r_epoch}]"}) + l_epoch, r_epoch = self.get_epochs_range_to_process(blockstamp) + logger.info({"msg": f"Epochs range for performance data collect: [{l_epoch};{r_epoch}]"}) # NOTE: Finalized slot is the first slot of justifying epoch, so we need to take the previous. But if the first # slot of the justifying epoch is empty, blockstamp.slot_number will point to the slot where the last finalized @@ -192,16 +191,16 @@ def collect_data(self, blockstamp: BlockStamp) -> bool: if report_blockstamp and report_blockstamp.ref_epoch != r_epoch: logger.warning( { - "msg": f"Frame has been changed, but the change is not yet observed on finalized epoch {finalized_epoch}" + "msg": f"Epochs range has been changed, but the change is not yet observed on finalized epoch {finalized_epoch}" } ) return False if l_epoch > finalized_epoch: - logger.info({"msg": "The starting epoch of the frame is not finalized yet"}) + logger.info({"msg": "The starting epoch of the epochs range is not finalized yet"}) return False - self.state.migrate(l_epoch, r_epoch, consensus_version) + self.state.migrate(l_epoch, r_epoch, converter.frame_config.epochs_per_frame, consensus_version) self.state.log_progress() if self.state.is_fulfilled: @@ -218,8 +217,8 @@ def collect_data(self, blockstamp: BlockStamp) -> bool: processor = FrameCheckpointProcessor(self.w3.cc, self.state, converter, blockstamp, eip7549_supported) for checkpoint in checkpoints: - if self.current_frame_range(self._receive_last_finalized_slot()) != (l_epoch, r_epoch): - logger.info({"msg": "Checkpoints were prepared 
for an outdated frame, stop processing"}) + if self.get_epochs_range_to_process(self._receive_last_finalized_slot()) != (l_epoch, r_epoch): + logger.info({"msg": "Checkpoints were prepared for an outdated epochs range, stop processing"}) raise ValueError("Outdated checkpoint") processor.exec(checkpoint) @@ -227,65 +226,137 @@ def collect_data(self, blockstamp: BlockStamp) -> bool: def calculate_distribution( self, blockstamp: ReferenceBlockStamp - ) -> tuple[int, defaultdict[NodeOperatorId, int], FramePerfLog]: + ) -> tuple[Shares, defaultdict[NodeOperatorId, Shares], list[FramePerfLog]]: """Computes distribution of fee shares at the given timestamp""" - - network_avg_perf = self.state.get_network_aggr().perf - threshold = network_avg_perf - self.w3.csm.oracle.perf_leeway_bp(blockstamp.block_hash) / TOTAL_BASIS_POINTS operators_to_validators = self.module_validators_by_node_operators(blockstamp) - # Build the map of the current distribution operators. - distribution: dict[NodeOperatorId, int] = defaultdict(int) - stuck_operators = self.stuck_operators(blockstamp) - log = FramePerfLog(blockstamp, self.state.frame, threshold) + total_distributed = Shares(0) + total_rewards = defaultdict[NodeOperatorId, Shares](Shares) + logs: list[FramePerfLog] = [] - for (_, no_id), validators in operators_to_validators.items(): - if no_id in stuck_operators: - log.operators[no_id].stuck = True - continue + for frame in self.state.frames: + from_epoch, to_epoch = frame + logger.info({"msg": f"Calculating distribution for frame [{from_epoch};{to_epoch}]"}) - for v in validators: - aggr = self.state.data.get(v.index) + frame_blockstamp = blockstamp + if to_epoch != blockstamp.ref_epoch: + frame_blockstamp = self._get_ref_blockstamp_for_frame(blockstamp, to_epoch) - if aggr is None: - # It's possible that the validator is not assigned to any duty, hence it's performance - # is not presented in the aggregates (e.g. exited, pending for activation etc). 
- continue + total_rewards_to_distribute = self.w3.csm.fee_distributor.shares_to_distribute(frame_blockstamp.block_hash) + rewards_to_distribute_in_frame = total_rewards_to_distribute - total_distributed - if v.validator.slashed is True: - # It means that validator was active during the frame and got slashed and didn't meet the exit - # epoch, so we should not count such validator for operator's share. - log.operators[no_id].validators[v.index].slashed = True - continue + rewards_in_frame, log = self._calculate_distribution_in_frame( + frame, frame_blockstamp, rewards_to_distribute_in_frame, operators_to_validators + ) + distributed_in_frame = sum(rewards_in_frame.values()) - if aggr.perf > threshold: - # Count of assigned attestations used as a metrics of time - # the validator was active in the current frame. - distribution[no_id] += aggr.assigned + total_distributed += distributed_in_frame + if total_distributed > total_rewards_to_distribute: + raise CSMError(f"Invalid distribution: {total_distributed=} > {total_rewards_to_distribute=}") - log.operators[no_id].validators[v.index].perf = aggr + for no_id, rewards in rewards_in_frame.items(): + total_rewards[no_id] += rewards - # Calculate share of each CSM node operator. 
- shares = defaultdict[NodeOperatorId, int](int) - total = sum(p for p in distribution.values()) + logs.append(log) - if not total: - return 0, shares, log + return total_distributed, total_rewards, logs - to_distribute = self.w3.csm.fee_distributor.shares_to_distribute(blockstamp.block_hash) - log.distributable = to_distribute + def _get_ref_blockstamp_for_frame( + self, blockstamp: ReferenceBlockStamp, frame_ref_epoch: EpochNumber + ) -> ReferenceBlockStamp: + converter = self.converter(blockstamp) + return get_reference_blockstamp( + cc=self.w3.cc, + ref_slot=converter.get_epoch_last_slot(frame_ref_epoch), + ref_epoch=frame_ref_epoch, + last_finalized_slot_number=blockstamp.slot_number, + ) - for no_id, no_share in distribution.items(): - if no_share: - shares[no_id] = to_distribute * no_share // total - log.operators[no_id].distributed = shares[no_id] + def _calculate_distribution_in_frame( + self, + frame: Frame, + blockstamp: ReferenceBlockStamp, + rewards_to_distribute: int, + operators_to_validators: ValidatorsByNodeOperator + ): + threshold = self._get_performance_threshold(frame, blockstamp) + log = FramePerfLog(blockstamp, frame, threshold) - distributed = sum(s for s in shares.values()) - if distributed > to_distribute: - raise CSMError(f"Invalid distribution: {distributed=} > {to_distribute=}") - return distributed, shares, log + participation_shares: defaultdict[NodeOperatorId, int] = defaultdict(int) - def get_accumulated_shares(self, cid: CID, root: HexBytes) -> Iterator[tuple[NodeOperatorId, Shares]]: + stuck_operators = self.get_stuck_operators(frame, blockstamp) + for (_, no_id), validators in operators_to_validators.items(): + log_operator = log.operators[no_id] + if no_id in stuck_operators: + log_operator.stuck = True + continue + for validator in validators: + duty = self.state.data[frame].get(validator.index) + self.process_validator_duty(validator, duty, threshold, participation_shares, log_operator) + + rewards_distribution = 
self.calc_rewards_distribution_in_frame(participation_shares, rewards_to_distribute) + + for no_id, no_rewards in rewards_distribution.items(): + log.operators[no_id].distributed = no_rewards + + log.distributable = rewards_to_distribute + + return rewards_distribution, log + + def _get_performance_threshold(self, frame: Frame, blockstamp: ReferenceBlockStamp) -> float: + network_perf = self.state.get_network_aggr(frame).perf + perf_leeway = self.w3.csm.oracle.perf_leeway_bp(blockstamp.block_hash) / TOTAL_BASIS_POINTS + threshold = network_perf - perf_leeway + return threshold + + @staticmethod + def process_validator_duty( + validator: LidoValidator, + attestation_duty: AttestationsAccumulator | None, + threshold: float, + participation_shares: defaultdict[NodeOperatorId, int], + log_operator: OperatorFrameSummary + ): + if attestation_duty is None: + # It's possible that the validator is not assigned to any duty, hence it's performance + # is not presented in the aggregates (e.g. exited, pending for activation etc). + # TODO: check `sync_aggr` to strike (in case of bad sync performance) after validator exit + return + + log_validator = log_operator.validators[validator.index] + + if validator.validator.slashed is True: + # It means that validator was active during the frame and got slashed and didn't meet the exit + # epoch, so we should not count such validator for operator's share. + log_validator.slashed = True + return + + if attestation_duty.perf > threshold: + # Count of assigned attestations used as a metrics of time + # the validator was active in the current frame. 
+ participation_shares[validator.lido_id.operatorIndex] += attestation_duty.assigned + + log_validator.attestation_duty = attestation_duty + + @staticmethod + def calc_rewards_distribution_in_frame( + participation_shares: dict[NodeOperatorId, int], + rewards_to_distribute: int, + ) -> dict[NodeOperatorId, int]: + if rewards_to_distribute < 0: + raise ValueError(f"Invalid rewards to distribute: {rewards_to_distribute}") + rewards_distribution: dict[NodeOperatorId, int] = defaultdict(int) + total_participation = sum(participation_shares.values()) + + for no_id, no_participation_share in participation_shares.items(): + if no_participation_share == 0: + # Skip operators with zero participation + continue + rewards_distribution[no_id] = rewards_to_distribute * no_participation_share // total_participation + + return rewards_distribution + + def get_accumulated_rewards(self, cid: CID, root: HexBytes) -> Iterator[tuple[NodeOperatorId, Shares]]: logger.info({"msg": "Fetching tree by CID from IPFS", "cid": repr(cid)}) tree = Tree.decode(self.w3.ipfs.fetch(cid)) @@ -297,33 +368,30 @@ def get_accumulated_shares(self, cid: CID, root: HexBytes) -> Iterator[tuple[Nod for v in tree.tree.values: yield v["value"] - def stuck_operators(self, blockstamp: ReferenceBlockStamp) -> set[NodeOperatorId]: - stuck: set[NodeOperatorId] = set() - l_epoch, _ = self.current_frame_range(blockstamp) - l_ref_slot = self.converter(blockstamp).get_epoch_first_slot(l_epoch) + def get_stuck_operators(self, frame: Frame, frame_blockstamp: ReferenceBlockStamp) -> set[NodeOperatorId]: + l_epoch, _ = frame + l_ref_slot = self.converter(frame_blockstamp).get_epoch_first_slot(l_epoch) # NOTE: r_block is guaranteed to be <= ref_slot, and the check # in the inner frames assures the l_block <= r_block. 
l_blockstamp = build_blockstamp( get_next_non_missed_slot( self.w3.cc, l_ref_slot, - blockstamp.slot_number, + frame_blockstamp.slot_number, ) ) - nos_by_module = self.w3.lido_validators.get_lido_node_operators_by_modules(l_blockstamp) - if self.module_id in nos_by_module: - stuck.update(no.id for no in nos_by_module[self.module_id] if no.stuck_validators_count > 0) - else: + digests = self.w3.lido_contracts.staking_router.get_all_node_operator_digests( + self.staking_module, l_blockstamp.block_hash + ) + if not digests: logger.warning("No CSM digest at blockstamp=%s, module was not added yet?", l_blockstamp) - - stuck.update( - self.w3.csm.get_operators_with_stucks_in_range( - l_blockstamp.block_hash, - blockstamp.block_hash, - ) + stuck_from_digests = (no.id for no in digests if no.stuck_validators_count > 0) + stuck_from_events = self.w3.csm.get_operators_with_stucks_in_range( + l_blockstamp.block_hash, + frame_blockstamp.block_hash, ) - return stuck + return set(stuck_from_digests) | set(stuck_from_events) def make_tree(self, shares: dict[NodeOperatorId, Shares]) -> Tree: if not shares: @@ -348,13 +416,13 @@ def publish_tree(self, tree: Tree) -> CID: logger.info({"msg": "Tree dump uploaded to IPFS", "cid": repr(tree_cid)}) return tree_cid - def publish_log(self, log: FramePerfLog) -> CID: - log_cid = self.w3.ipfs.publish(log.encode()) - logger.info({"msg": "Frame log uploaded to IPFS", "cid": repr(log_cid)}) + def publish_log(self, logs: list[FramePerfLog]) -> CID: + log_cid = self.w3.ipfs.publish(FramePerfLog.encode(logs)) + logger.info({"msg": "Frame(s) log uploaded to IPFS", "cid": repr(log_cid)}) return log_cid @lru_cache(maxsize=1) - def current_frame_range(self, blockstamp: BlockStamp) -> tuple[EpochNumber, EpochNumber]: + def get_epochs_range_to_process(self, blockstamp: BlockStamp) -> tuple[EpochNumber, EpochNumber]: converter = self.converter(blockstamp) far_future_initial_epoch = converter.get_epoch_by_timestamp(UINT64_MAX) @@ -385,9 +453,9 @@ def 
current_frame_range(self, blockstamp: BlockStamp) -> tuple[EpochNumber, Epoc ) if l_ref_slot < last_processing_ref_slot: - raise CSMError(f"Got invalid frame range: {l_ref_slot=} < {last_processing_ref_slot=}") + raise CSMError(f"Got invalid epochs range: {l_ref_slot=} < {last_processing_ref_slot=}") if l_ref_slot >= r_ref_slot: - raise CSMError(f"Got invalid frame range {r_ref_slot=}, {l_ref_slot=}") + raise CSMError(f"Got invalid epochs range {r_ref_slot=}, {l_ref_slot=}") l_epoch = converter.get_epoch_by_slot(SlotNumber(l_ref_slot + 1)) r_epoch = converter.get_epoch_by_slot(r_ref_slot) @@ -401,11 +469,11 @@ def current_frame_range(self, blockstamp: BlockStamp) -> tuple[EpochNumber, Epoc def converter(self, blockstamp: BlockStamp) -> Web3Converter: return Web3Converter(self.get_chain_config(blockstamp), self.get_frame_config(blockstamp)) - def _get_module_id(self) -> StakingModuleId: + def _get_staking_module(self) -> StakingModule: modules: list[StakingModule] = self.w3.lido_contracts.staking_router.get_staking_modules() for mod in modules: if mod.staking_module_address == self.w3.csm.module.address: - return mod.id + return mod raise NoModuleFound diff --git a/src/modules/csm/log.py b/src/modules/csm/log.py index f89f4ef58..29ab24902 100644 --- a/src/modules/csm/log.py +++ b/src/modules/csm/log.py @@ -12,7 +12,7 @@ class LogJSONEncoder(json.JSONEncoder): ... 
@dataclass class ValidatorFrameSummary: - perf: AttestationsAccumulator = field(default_factory=AttestationsAccumulator) + attestation_duty: AttestationsAccumulator = field(default_factory=AttestationsAccumulator) slashed: bool = False @@ -35,13 +35,14 @@ class FramePerfLog: default_factory=lambda: defaultdict(OperatorFrameSummary) ) - def encode(self) -> bytes: + @staticmethod + def encode(logs: list['FramePerfLog']) -> bytes: return ( LogJSONEncoder( indent=None, separators=(',', ':'), sort_keys=True, ) - .encode(asdict(self)) + .encode([asdict(log) for log in logs]) .encode() ) diff --git a/src/modules/csm/state.py b/src/modules/csm/state.py index 4373f5259..48debfe83 100644 --- a/src/modules/csm/state.py +++ b/src/modules/csm/state.py @@ -3,6 +3,8 @@ import pickle from collections import defaultdict from dataclasses import dataclass +from functools import lru_cache +from itertools import batched from pathlib import Path from typing import Self @@ -33,6 +35,10 @@ def add_duty(self, included: bool) -> None: self.included += 1 if included else 0 +type Frame = tuple[EpochNumber, EpochNumber] +type StateData = dict[Frame, defaultdict[ValidatorIndex, AttestationsAccumulator]] + + class State: """ Processing state of a CSM performance oracle frame. @@ -43,16 +49,17 @@ class State: The state can be migrated to be used for another frame's report by calling the `migrate` method. """ - - data: defaultdict[ValidatorIndex, AttestationsAccumulator] + frames: list[Frame] + data: StateData _epochs_to_process: tuple[EpochNumber, ...] 
_processed_epochs: set[EpochNumber] _consensus_version: int = 1 - def __init__(self, data: dict[ValidatorIndex, AttestationsAccumulator] | None = None) -> None: - self.data = defaultdict(AttestationsAccumulator, data or {}) + def __init__(self) -> None: + self.frames = [] + self.data = {} self._epochs_to_process = tuple() self._processed_epochs = set() @@ -89,14 +96,45 @@ def file(cls) -> Path: def buffer(self) -> Path: return self.file().with_suffix(".buf") + @property + def is_empty(self) -> bool: + return not self.data and not self._epochs_to_process and not self._processed_epochs + + @property + def unprocessed_epochs(self) -> set[EpochNumber]: + if not self._epochs_to_process: + raise ValueError("Epochs to process are not set") + diff = set(self._epochs_to_process) - self._processed_epochs + return diff + + @property + def is_fulfilled(self) -> bool: + return not self.unprocessed_epochs + + @staticmethod + def _calculate_frames(epochs_to_process: tuple[EpochNumber, ...], epochs_per_frame: int) -> list[Frame]: + """Split epochs to process into frames of `epochs_per_frame` length""" + if len(epochs_to_process) % epochs_per_frame != 0: + raise ValueError("Insufficient epochs to form a frame") + return [(frame[0], frame[-1]) for frame in batched(sorted(epochs_to_process), epochs_per_frame)] + def clear(self) -> None: - self.data = defaultdict(AttestationsAccumulator) + self.data = {} self._epochs_to_process = tuple() self._processed_epochs.clear() assert self.is_empty - def inc(self, key: ValidatorIndex, included: bool) -> None: - self.data[key].add_duty(included) + @lru_cache(variables.CSM_ORACLE_MAX_CONCURRENCY) + def find_frame(self, epoch: EpochNumber) -> Frame: + for epoch_range in self.frames: + from_epoch, to_epoch = epoch_range + if from_epoch <= epoch <= to_epoch: + return epoch_range + raise ValueError(f"Epoch {epoch} is out of frames range: {self.frames}") + + def increment_duty(self, epoch: EpochNumber, val_index: ValidatorIndex, included: bool) -> 
None: + frame = self.find_frame(epoch) + self.data[frame][val_index].add_duty(included) def add_processed_epoch(self, epoch: EpochNumber) -> None: self._processed_epochs.add(epoch) @@ -104,7 +142,9 @@ def add_processed_epoch(self, epoch: EpochNumber) -> None: def log_progress(self) -> None: logger.info({"msg": f"Processed {len(self._processed_epochs)} of {len(self._epochs_to_process)} epochs"}) - def migrate(self, l_epoch: EpochNumber, r_epoch: EpochNumber, consensus_version: int): + def migrate( + self, l_epoch: EpochNumber, r_epoch: EpochNumber, epochs_per_frame: int, consensus_version: int + ) -> None: if consensus_version != self._consensus_version: logger.warning( { @@ -114,17 +154,41 @@ def migrate(self, l_epoch: EpochNumber, r_epoch: EpochNumber, consensus_version: ) self.clear() - for state_epochs in (self._epochs_to_process, self._processed_epochs): - for epoch in state_epochs: - if epoch < l_epoch or epoch > r_epoch: - logger.warning({"msg": "Discarding invalidated state cache"}) - self.clear() - break + new_frames = self._calculate_frames(tuple(sequence(l_epoch, r_epoch)), epochs_per_frame) + if self.frames == new_frames: + logger.info({"msg": "No need to migrate duties data cache"}) + return + self._migrate_frames_data(new_frames) + self.frames = new_frames + self.find_frame.cache_clear() self._epochs_to_process = tuple(sequence(l_epoch, r_epoch)) self._consensus_version = consensus_version self.commit() + def _migrate_frames_data(self, new_frames: list[Frame]): + logger.info({"msg": f"Migrating duties data cache: {self.frames=} -> {new_frames=}"}) + new_data: StateData = {frame: defaultdict(AttestationsAccumulator) for frame in new_frames} + + def overlaps(a: Frame, b: Frame): + return a[0] <= b[0] and a[1] >= b[1] + + consumed = [] + for new_frame in new_frames: + for frame_to_consume in self.frames: + if overlaps(new_frame, frame_to_consume): + assert frame_to_consume not in consumed + consumed.append(frame_to_consume) + for val, duty in 
self.data[frame_to_consume].items(): + new_data[new_frame][val].assigned += duty.assigned + new_data[new_frame][val].included += duty.included + for frame in self.frames: + if frame in consumed: + continue + logger.warning({"msg": f"Invalidating frame duties data cache: {frame}"}) + self._processed_epochs -= set(sequence(*frame)) + self.data = new_data + def validate(self, l_epoch: EpochNumber, r_epoch: EpochNumber) -> None: if not self.is_fulfilled: raise InvalidState(f"State is not fulfilled. {self.unprocessed_epochs=}") @@ -135,34 +199,15 @@ def validate(self, l_epoch: EpochNumber, r_epoch: EpochNumber) -> None: for epoch in sequence(l_epoch, r_epoch): if epoch not in self._processed_epochs: - raise InvalidState(f"Epoch {epoch} should be processed") - - @property - def is_empty(self) -> bool: - return not self.data and not self._epochs_to_process and not self._processed_epochs - - @property - def unprocessed_epochs(self) -> set[EpochNumber]: - if not self._epochs_to_process: - raise ValueError("Epochs to process are not set") - diff = set(self._epochs_to_process) - self._processed_epochs - return diff - - @property - def is_fulfilled(self) -> bool: - return not self.unprocessed_epochs - - @property - def frame(self) -> tuple[EpochNumber, EpochNumber]: - if not self._epochs_to_process: - raise ValueError("Epochs to process are not set") - return min(self._epochs_to_process), max(self._epochs_to_process) - - def get_network_aggr(self) -> AttestationsAccumulator: - """Return `AttestationsAccumulator` over duties of all the network validators""" + raise InvalidState(f"Epoch {epoch} missing in processed epochs") + def get_network_aggr(self, frame: Frame) -> AttestationsAccumulator: + # TODO: exclude `active_slashed` validators from the calculation included = assigned = 0 - for validator, acc in self.data.items(): + frame_data = self.data.get(frame) + if frame_data is None: + raise ValueError(f"No data for frame {frame} to calculate network aggregate") + for 
validator, acc in frame_data.items(): if acc.included > acc.assigned: raise ValueError(f"Invalid accumulator: {validator=}, {acc=}") included += acc.included diff --git a/src/providers/execution/contracts/cs_fee_distributor.py b/src/providers/execution/contracts/cs_fee_distributor.py index 937c7dc49..8a556b250 100644 --- a/src/providers/execution/contracts/cs_fee_distributor.py +++ b/src/providers/execution/contracts/cs_fee_distributor.py @@ -3,7 +3,7 @@ from eth_typing import ChecksumAddress from hexbytes import HexBytes from web3 import Web3 -from web3.types import BlockIdentifier +from web3.types import BlockIdentifier, Wei from ..base_interface import ContractInterface @@ -26,7 +26,7 @@ def oracle(self, block_identifier: BlockIdentifier = "latest") -> ChecksumAddres ) return Web3.to_checksum_address(resp) - def shares_to_distribute(self, block_identifier: BlockIdentifier = "latest") -> int: + def shares_to_distribute(self, block_identifier: BlockIdentifier = "latest") -> Wei: """Returns the amount of shares that are pending to be distributed""" resp = self.functions.pendingSharesToDistribute().call(block_identifier=block_identifier) diff --git a/tests/modules/csm/test_checkpoint.py b/tests/modules/csm/test_checkpoint.py index 44f23735e..070d82bf9 100644 --- a/tests/modules/csm/test_checkpoint.py +++ b/tests/modules/csm/test_checkpoint.py @@ -326,7 +326,7 @@ def test_checkpoints_processor_no_eip7549_support( monkeypatch: pytest.MonkeyPatch, ): state = State() - state.migrate(EpochNumber(0), EpochNumber(255), 1) + state.migrate(EpochNumber(0), EpochNumber(255), 256, 1) processor = FrameCheckpointProcessor( consensus_client, state, @@ -354,7 +354,7 @@ def test_checkpoints_processor_check_duty( converter, ): state = State() - state.migrate(0, 255, 1) + state.migrate(0, 255, 256, 1) finalized_blockstamp = ... 
processor = FrameCheckpointProcessor( consensus_client, @@ -367,7 +367,7 @@ def test_checkpoints_processor_check_duty( assert len(state._processed_epochs) == 1 assert len(state._epochs_to_process) == 256 assert len(state.unprocessed_epochs) == 255 - assert len(state.data) == 2048 * 32 + assert len(state.data[(0, 255)]) == 2048 * 32 def test_checkpoints_processor_process( @@ -379,7 +379,7 @@ def test_checkpoints_processor_process( converter, ): state = State() - state.migrate(0, 255, 1) + state.migrate(0, 255, 256, 1) finalized_blockstamp = ... processor = FrameCheckpointProcessor( consensus_client, @@ -392,7 +392,7 @@ def test_checkpoints_processor_process( assert len(state._processed_epochs) == 2 assert len(state._epochs_to_process) == 256 assert len(state.unprocessed_epochs) == 254 - assert len(state.data) == 2048 * 32 + assert len(state.data[(0, 255)]) == 2048 * 32 def test_checkpoints_processor_exec( @@ -404,7 +404,7 @@ def test_checkpoints_processor_exec( converter, ): state = State() - state.migrate(0, 255, 1) + state.migrate(0, 255, 256, 1) finalized_blockstamp = ... 
processor = FrameCheckpointProcessor( consensus_client, @@ -418,4 +418,4 @@ def test_checkpoints_processor_exec( assert len(state._processed_epochs) == 2 assert len(state._epochs_to_process) == 256 assert len(state.unprocessed_epochs) == 254 - assert len(state.data) == 2048 * 32 + assert len(state.data[(0, 255)]) == 2048 * 32 diff --git a/tests/modules/csm/test_csm_distribution.py b/tests/modules/csm/test_csm_distribution.py new file mode 100644 index 000000000..f277f6eb8 --- /dev/null +++ b/tests/modules/csm/test_csm_distribution.py @@ -0,0 +1,337 @@ +from collections import defaultdict +from unittest.mock import Mock, call + +import pytest +from web3.types import Wei + +from src.constants import UINT64_MAX +from src.modules.csm.csm import CSOracle, CSMError +from src.modules.csm.log import ValidatorFrameSummary +from src.modules.csm.state import AttestationsAccumulator, State +from src.types import NodeOperatorId, ValidatorIndex +from src.web3py.extensions import CSM +from tests.factory.blockstamp import ReferenceBlockStampFactory +from tests.factory.no_registry import LidoValidatorFactory + + +@pytest.fixture(autouse=True) +def mock_get_staking_module(monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr(CSOracle, "_get_staking_module", Mock()) + + +@pytest.fixture() +def module(web3, csm: CSM): + yield CSOracle(web3) + + +def test_calculate_distribution_handles_single_frame(module): + module.state = Mock() + module.state.frames = [(1, 2)] + blockstamp = Mock() + module.module_validators_by_node_operators = Mock() + module._get_ref_blockstamp_for_frame = Mock(return_value=blockstamp) + module.w3.csm.fee_distributor.shares_to_distribute = Mock(return_value=500) + module._calculate_distribution_in_frame = Mock(return_value=({NodeOperatorId(1): 500}, Mock())) + + total_distributed, total_rewards, logs = module.calculate_distribution(blockstamp) + + assert total_distributed == 500 + assert total_rewards[NodeOperatorId(1)] == 500 + assert len(logs) == 1 + + +def 
test_calculate_distribution_handles_multiple_frames(module): + module.state = Mock() + module.state.frames = [(1, 2), (3, 4), (5, 6)] + blockstamp = ReferenceBlockStampFactory.build(ref_epoch=2) + module.module_validators_by_node_operators = Mock() + module._get_ref_blockstamp_for_frame = Mock(return_value=blockstamp) + module.w3.csm.fee_distributor.shares_to_distribute = Mock(return_value=800) + module._calculate_distribution_in_frame = Mock( + side_effect=[ + ({NodeOperatorId(1): 500}, Mock()), + ({NodeOperatorId(1): 136}, Mock()), + ({NodeOperatorId(1): 164}, Mock()), + ] + ) + + total_distributed, total_rewards, logs = module.calculate_distribution(blockstamp) + + assert total_distributed == 800 + assert total_rewards[NodeOperatorId(1)] == 800 + assert len(logs) == 3 + module._get_ref_blockstamp_for_frame.assert_has_calls( + [call(blockstamp, frame[1]) for frame in module.state.frames[1:]] + ) + + +def test_calculate_distribution_handles_invalid_distribution(module): + module.state = Mock() + module.state.frames = [(1, 2)] + blockstamp = Mock() + module.module_validators_by_node_operators = Mock() + module._get_ref_blockstamp_for_frame = Mock(return_value=blockstamp) + module.w3.csm.fee_distributor.shares_to_distribute = Mock(return_value=500) + module._calculate_distribution_in_frame = Mock(return_value=({NodeOperatorId(1): 600}, Mock())) + + with pytest.raises(CSMError, match="Invalid distribution"): + module.calculate_distribution(blockstamp) + + +def test_calculate_distribution_in_frame_handles_stuck_operator(module): + frame = Mock() + blockstamp = Mock() + rewards_to_distribute = UINT64_MAX + operators_to_validators = {(Mock(), NodeOperatorId(1)): [LidoValidatorFactory.build()]} + module.state = State() + module.state.data = {frame: defaultdict(AttestationsAccumulator)} + module.get_stuck_operators = Mock(return_value={NodeOperatorId(1)}) + module._get_performance_threshold = Mock() + + rewards_distribution, log = module._calculate_distribution_in_frame( 
+ frame, blockstamp, rewards_to_distribute, operators_to_validators + ) + + assert rewards_distribution[NodeOperatorId(1)] == 0 + assert log.operators[NodeOperatorId(1)].stuck is True + assert log.operators[NodeOperatorId(1)].distributed == 0 + assert log.operators[NodeOperatorId(1)].validators == defaultdict(ValidatorFrameSummary) + + +def test_calculate_distribution_in_frame_handles_no_attestation_duty(module): + frame = Mock() + blockstamp = Mock() + rewards_to_distribute = UINT64_MAX + validator = LidoValidatorFactory.build() + node_operator_id = validator.lido_id.operatorIndex + operators_to_validators = {(Mock(), node_operator_id): [validator]} + module.state = State() + module.state.data = {frame: defaultdict(AttestationsAccumulator)} + module.get_stuck_operators = Mock(return_value=set()) + module._get_performance_threshold = Mock() + + rewards_distribution, log = module._calculate_distribution_in_frame( + frame, blockstamp, rewards_to_distribute, operators_to_validators + ) + + assert rewards_distribution[node_operator_id] == 0 + assert log.operators[node_operator_id].stuck is False + assert log.operators[node_operator_id].distributed == 0 + assert log.operators[node_operator_id].validators == defaultdict(ValidatorFrameSummary) + + +def test_calculate_distribution_in_frame_handles_above_threshold_performance(module): + frame = Mock() + blockstamp = Mock() + rewards_to_distribute = UINT64_MAX + validator = LidoValidatorFactory.build() + validator.validator.slashed = False + node_operator_id = validator.lido_id.operatorIndex + operators_to_validators = {(Mock(), node_operator_id): [validator]} + module.state = State() + attestation_duty = AttestationsAccumulator(assigned=10, included=6) + module.state.data = {frame: {validator.index: attestation_duty}} + module.get_stuck_operators = Mock(return_value=set()) + module._get_performance_threshold = Mock(return_value=0.5) + + rewards_distribution, log = module._calculate_distribution_in_frame( + frame, 
blockstamp, rewards_to_distribute, operators_to_validators + ) + + assert rewards_distribution[node_operator_id] > 0 # no need to check exact value + assert log.operators[node_operator_id].stuck is False + assert log.operators[node_operator_id].distributed > 0 + assert log.operators[node_operator_id].validators[validator.index].attestation_duty == attestation_duty + + +def test_calculate_distribution_in_frame_handles_below_threshold_performance(module): + frame = Mock() + blockstamp = Mock() + rewards_to_distribute = UINT64_MAX + validator = LidoValidatorFactory.build() + validator.validator.slashed = False + node_operator_id = validator.lido_id.operatorIndex + operators_to_validators = {(Mock(), node_operator_id): [validator]} + module.state = State() + attestation_duty = AttestationsAccumulator(assigned=10, included=5) + module.state.data = {frame: {validator.index: attestation_duty}} + module.get_stuck_operators = Mock(return_value=set()) + module._get_performance_threshold = Mock(return_value=0.5) + + rewards_distribution, log = module._calculate_distribution_in_frame( + frame, blockstamp, rewards_to_distribute, operators_to_validators + ) + + assert rewards_distribution[node_operator_id] == 0 + assert log.operators[node_operator_id].stuck is False + assert log.operators[node_operator_id].distributed == 0 + assert log.operators[node_operator_id].validators[validator.index].attestation_duty == attestation_duty + + +def test_performance_threshold_calculates_correctly(module): + state = State() + state.data = { + (0, 31): { + ValidatorIndex(1): AttestationsAccumulator(10, 10), + ValidatorIndex(2): AttestationsAccumulator(10, 10), + }, + } + module.w3.csm.oracle.perf_leeway_bp.return_value = 500 + module.state = state + + threshold = module._get_performance_threshold((0, 31), Mock()) + + assert threshold == 0.95 + + +def test_performance_threshold_handles_zero_leeway(module): + state = State() + state.data = { + (0, 31): { + ValidatorIndex(1): 
AttestationsAccumulator(10, 10), + ValidatorIndex(2): AttestationsAccumulator(10, 10), + }, + } + module.w3.csm.oracle.perf_leeway_bp.return_value = 0 + module.state = state + + threshold = module._get_performance_threshold((0, 31), Mock()) + + assert threshold == 1.0 + + +def test_performance_threshold_handles_high_leeway(module): + state = State() + state.data = { + (0, 31): {ValidatorIndex(1): AttestationsAccumulator(10, 1), ValidatorIndex(2): AttestationsAccumulator(10, 1)}, + } + module.w3.csm.oracle.perf_leeway_bp.return_value = 5000 + module.state = state + + threshold = module._get_performance_threshold((0, 31), Mock()) + + assert threshold == -0.4 + + +def test_process_validator_duty_handles_above_threshold_performance(): + validator = LidoValidatorFactory.build() + validator.validator.slashed = False + log_operator = Mock() + log_operator.validators = defaultdict(ValidatorFrameSummary) + participation_shares = defaultdict(int) + threshold = 0.5 + + attestation_duty = AttestationsAccumulator(assigned=10, included=6) + + CSOracle.process_validator_duty(validator, attestation_duty, threshold, participation_shares, log_operator) + + assert participation_shares[validator.lido_id.operatorIndex] == 10 + assert log_operator.validators[validator.index].attestation_duty == attestation_duty + + +def test_process_validator_duty_handles_below_threshold_performance(): + validator = LidoValidatorFactory.build() + validator.validator.slashed = False + log_operator = Mock() + log_operator.validators = defaultdict(ValidatorFrameSummary) + participation_shares = defaultdict(int) + threshold = 0.5 + + attestation_duty = AttestationsAccumulator(assigned=10, included=4) + + CSOracle.process_validator_duty(validator, attestation_duty, threshold, participation_shares, log_operator) + + assert participation_shares[validator.lido_id.operatorIndex] == 0 + assert log_operator.validators[validator.index].attestation_duty == attestation_duty + + +def 
test_process_validator_duty_handles_non_empty_participation_shares(): + validator = LidoValidatorFactory.build() + validator.validator.slashed = False + log_operator = Mock() + log_operator.validators = defaultdict(ValidatorFrameSummary) + participation_shares = {validator.lido_id.operatorIndex: 25} + threshold = 0.5 + + attestation_duty = AttestationsAccumulator(assigned=10, included=6) + + CSOracle.process_validator_duty(validator, attestation_duty, threshold, participation_shares, log_operator) + + assert participation_shares[validator.lido_id.operatorIndex] == 35 + assert log_operator.validators[validator.index].attestation_duty == attestation_duty + + +def test_process_validator_duty_handles_no_duty_assigned(): + validator = LidoValidatorFactory.build() + log_operator = Mock() + log_operator.validators = defaultdict(ValidatorFrameSummary) + participation_shares = defaultdict(int) + threshold = 0.5 + + CSOracle.process_validator_duty(validator, None, threshold, participation_shares, log_operator) + + assert participation_shares[validator.lido_id.operatorIndex] == 0 + assert validator.index not in log_operator.validators + + +def test_process_validator_duty_handles_slashed_validator(): + validator = LidoValidatorFactory.build() + validator.validator.slashed = True + log_operator = Mock() + log_operator.validators = defaultdict(ValidatorFrameSummary) + participation_shares = defaultdict(int) + threshold = 0.5 + + attestation_duty = AttestationsAccumulator(assigned=1, included=1) + + CSOracle.process_validator_duty(validator, attestation_duty, threshold, participation_shares, log_operator) + + assert participation_shares[validator.lido_id.operatorIndex] == 0 + assert log_operator.validators[validator.index].slashed is True + + +def test_calc_rewards_distribution_in_frame_correctly_distributes_rewards(): + participation_shares = {NodeOperatorId(1): 100, NodeOperatorId(2): 200} + rewards_to_distribute = Wei(1 * 10**18) + + rewards_distribution = 
CSOracle.calc_rewards_distribution_in_frame(participation_shares, rewards_to_distribute) + + assert rewards_distribution[NodeOperatorId(1)] == Wei(333333333333333333) + assert rewards_distribution[NodeOperatorId(2)] == Wei(666666666666666666) + + +def test_calc_rewards_distribution_in_frame_handles_zero_participation(): + participation_shares = {NodeOperatorId(1): 0, NodeOperatorId(2): 0} + rewards_to_distribute = Wei(1 * 10**18) + + rewards_distribution = CSOracle.calc_rewards_distribution_in_frame(participation_shares, rewards_to_distribute) + + assert rewards_distribution[NodeOperatorId(1)] == 0 + assert rewards_distribution[NodeOperatorId(2)] == 0 + + +def test_calc_rewards_distribution_in_frame_handles_no_participation(): + participation_shares = {} + rewards_to_distribute = Wei(1 * 10**18) + + rewards_distribution = CSOracle.calc_rewards_distribution_in_frame(participation_shares, rewards_to_distribute) + + assert len(rewards_distribution) == 0 + + +def test_calc_rewards_distribution_in_frame_handles_partial_participation(): + participation_shares = {NodeOperatorId(1): 100, NodeOperatorId(2): 0} + rewards_to_distribute = Wei(1 * 10**18) + + rewards_distribution = CSOracle.calc_rewards_distribution_in_frame(participation_shares, rewards_to_distribute) + + assert rewards_distribution[NodeOperatorId(1)] == Wei(1 * 10**18) + assert rewards_distribution[NodeOperatorId(2)] == 0 + + +def test_calc_rewards_distribution_in_frame_handles_negative_to_distribute(): + participation_shares = {NodeOperatorId(1): 100, NodeOperatorId(2): 200} + rewards_to_distribute = Wei(-1) + + with pytest.raises(ValueError, match="Invalid rewards to distribute"): + CSOracle.calc_rewards_distribution_in_frame(participation_shares, rewards_to_distribute) diff --git a/tests/modules/csm/test_csm_module.py b/tests/modules/csm/test_csm_module.py index f74af8d69..ee096b64d 100644 --- a/tests/modules/csm/test_csm_module.py +++ b/tests/modules/csm/test_csm_module.py @@ -9,20 +9,20 @@ from 
src.constants import UINT64_MAX from src.modules.csm.csm import CSOracle -from src.modules.csm.state import AttestationsAccumulator, State +from src.modules.csm.state import State from src.modules.csm.tree import Tree from src.modules.submodules.oracle_module import ModuleExecuteDelay from src.modules.submodules.types import CurrentFrame, ZERO_HASH from src.providers.ipfs import CIDv0, CID -from src.types import EpochNumber, NodeOperatorId, SlotNumber, StakingModuleId, ValidatorIndex +from src.types import NodeOperatorId, SlotNumber, StakingModuleId from src.web3py.extensions.csm import CSM from tests.factory.blockstamp import BlockStampFactory, ReferenceBlockStampFactory from tests.factory.configs import ChainConfigFactory, FrameConfigFactory @pytest.fixture(autouse=True) -def mock_get_module_id(monkeypatch: pytest.MonkeyPatch): - monkeypatch.setattr(CSOracle, "_get_module_id", Mock()) +def mock_get_staking_module(monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr(CSOracle, "_get_staking_module", Mock()) @pytest.fixture(autouse=True) @@ -39,34 +39,28 @@ def test_init(module: CSOracle): assert module -def test_stuck_operators(module: CSOracle, csm: CSM): +def test_get_stuck_operators(module: CSOracle, csm: CSM): module.module = Mock() # type: ignore - module.module_id = StakingModuleId(1) module.w3.cc = Mock() module.w3.lido_validators = Mock() module.w3.lido_contracts = Mock() - module.w3.lido_validators.get_lido_node_operators_by_modules = Mock( - return_value={ - 1: { - type('NodeOperator', (object,), {'id': 0, 'stuck_validators_count': 0})(), - type('NodeOperator', (object,), {'id': 1, 'stuck_validators_count': 0})(), - type('NodeOperator', (object,), {'id': 2, 'stuck_validators_count': 1})(), - type('NodeOperator', (object,), {'id': 3, 'stuck_validators_count': 0})(), - type('NodeOperator', (object,), {'id': 4, 'stuck_validators_count': 100500})(), - type('NodeOperator', (object,), {'id': 5, 'stuck_validators_count': 100})(), - type('NodeOperator', 
(object,), {'id': 6, 'stuck_validators_count': 0})(), - }, - 2: {}, - 3: {}, - 4: {}, - } + module.w3.lido_contracts.staking_router.get_all_node_operator_digests = Mock( + return_value=[ + Mock(id=0, stuck_validators_count=0), + Mock(id=1, stuck_validators_count=0), + Mock(id=2, stuck_validators_count=1), + Mock(id=3, stuck_validators_count=0), + Mock(id=4, stuck_validators_count=100500), + Mock(id=5, stuck_validators_count=100), + Mock(id=6, stuck_validators_count=0), + ] ) module.w3.csm.get_operators_with_stucks_in_range = Mock( return_value=[NodeOperatorId(2), NodeOperatorId(4), NodeOperatorId(6), NodeOperatorId(1337)] ) - module.current_frame_range = Mock(return_value=(69, 100)) + module.get_epochs_range_to_process = Mock(return_value=(69, 100)) module.converter = Mock() module.converter.get_epoch_first_slot = Mock(return_value=lambda epoch: epoch * 32) @@ -78,31 +72,17 @@ def test_stuck_operators(module: CSOracle, csm: CSM): with patch('src.modules.csm.csm.build_blockstamp', return_value=l_blockstamp): with patch('src.modules.csm.csm.get_next_non_missed_slot', return_value=Mock()): - stuck = module.stuck_operators(blockstamp=blockstamp) + stuck = module.get_stuck_operators(frame=(69, 100), frame_blockstamp=blockstamp) assert stuck == {NodeOperatorId(2), NodeOperatorId(4), NodeOperatorId(5), NodeOperatorId(6), NodeOperatorId(1337)} -def test_stuck_operators_left_border_before_enact(module: CSOracle, csm: CSM, caplog: pytest.LogCaptureFixture): +def test_get_stuck_operators_left_border_before_enact(module: CSOracle, csm: CSM, caplog: pytest.LogCaptureFixture): module.module = Mock() # type: ignore - module.module_id = StakingModuleId(3) module.w3.cc = Mock() module.w3.lido_validators = Mock() module.w3.lido_contracts = Mock() - module.w3.lido_validators.get_lido_node_operators_by_modules = Mock( - return_value={ - 1: { - type('NodeOperator', (object,), {'id': 0, 'stuck_validators_count': 0})(), - type('NodeOperator', (object,), {'id': 1, 
'stuck_validators_count': 0})(), - type('NodeOperator', (object,), {'id': 2, 'stuck_validators_count': 1})(), - type('NodeOperator', (object,), {'id': 3, 'stuck_validators_count': 0})(), - type('NodeOperator', (object,), {'id': 4, 'stuck_validators_count': 100500})(), - type('NodeOperator', (object,), {'id': 5, 'stuck_validators_count': 100})(), - type('NodeOperator', (object,), {'id': 6, 'stuck_validators_count': 0})(), - }, - 2: {}, - } - ) + module.w3.lido_contracts.staking_router.get_all_node_operator_digests = Mock(return_value=[]) module.w3.csm.get_operators_with_stucks_in_range = Mock( return_value=[ @@ -112,7 +92,7 @@ def test_stuck_operators_left_border_before_enact(module: CSOracle, csm: CSM, ca ] ) - module.current_frame_range = Mock(return_value=(69, 100)) + module.get_epochs_range_to_process = Mock(return_value=(69, 100)) module.converter = Mock() module.converter.get_epoch_first_slot = Mock(return_value=lambda epoch: epoch * 32) @@ -121,7 +101,7 @@ def test_stuck_operators_left_border_before_enact(module: CSOracle, csm: CSM, ca with patch('src.modules.csm.csm.build_blockstamp', return_value=l_blockstamp): with patch('src.modules.csm.csm.get_next_non_missed_slot', return_value=Mock()): - stuck = module.stuck_operators(blockstamp=blockstamp) + stuck = module.get_stuck_operators(frame=(69, 100), frame_blockstamp=blockstamp) assert stuck == { NodeOperatorId(2), @@ -132,103 +112,6 @@ def test_stuck_operators_left_border_before_enact(module: CSOracle, csm: CSM, ca assert caplog.messages[0].startswith("No CSM digest at blockstamp") -def test_calculate_distribution(module: CSOracle, csm: CSM): - csm.fee_distributor.shares_to_distribute = Mock(return_value=10_000) - csm.oracle.perf_leeway_bp = Mock(return_value=500) - - module.module_validators_by_node_operators = Mock( - return_value={ - (None, NodeOperatorId(0)): [Mock(index=0, validator=Mock(slashed=False))], - (None, NodeOperatorId(1)): [Mock(index=1, validator=Mock(slashed=False))], - (None, 
NodeOperatorId(2)): [Mock(index=2, validator=Mock(slashed=False))], # stuck - (None, NodeOperatorId(3)): [Mock(index=3, validator=Mock(slashed=False))], - (None, NodeOperatorId(4)): [Mock(index=4, validator=Mock(slashed=False))], # stuck - (None, NodeOperatorId(5)): [ - Mock(index=5, validator=Mock(slashed=False)), - Mock(index=6, validator=Mock(slashed=False)), - ], - (None, NodeOperatorId(6)): [ - Mock(index=7, validator=Mock(slashed=False)), - Mock(index=8, validator=Mock(slashed=False)), - ], - (None, NodeOperatorId(7)): [Mock(index=9, validator=Mock(slashed=False))], - (None, NodeOperatorId(8)): [ - Mock(index=10, validator=Mock(slashed=False)), - Mock(index=11, validator=Mock(slashed=True)), - ], - (None, NodeOperatorId(9)): [Mock(index=12, validator=Mock(slashed=True))], - } - ) - module.stuck_operators = Mock( - return_value=[ - NodeOperatorId(2), - NodeOperatorId(4), - ] - ) - - module.state = State( - { - ValidatorIndex(0): AttestationsAccumulator(included=200, assigned=200), # short on frame - ValidatorIndex(1): AttestationsAccumulator(included=1000, assigned=1000), - ValidatorIndex(2): AttestationsAccumulator(included=1000, assigned=1000), - ValidatorIndex(3): AttestationsAccumulator(included=999, assigned=1000), - ValidatorIndex(4): AttestationsAccumulator(included=900, assigned=1000), - ValidatorIndex(5): AttestationsAccumulator(included=500, assigned=1000), # underperforming - ValidatorIndex(6): AttestationsAccumulator(included=0, assigned=0), # underperforming - ValidatorIndex(7): AttestationsAccumulator(included=900, assigned=1000), - ValidatorIndex(8): AttestationsAccumulator(included=500, assigned=1000), # underperforming - # ValidatorIndex(9): AttestationsAggregate(included=0, assigned=0), # missing in state - ValidatorIndex(10): AttestationsAccumulator(included=1000, assigned=1000), - ValidatorIndex(11): AttestationsAccumulator(included=1000, assigned=1000), - ValidatorIndex(12): AttestationsAccumulator(included=1000, assigned=1000), - } - ) - 
module.state.migrate(EpochNumber(100), EpochNumber(500), 1) - - _, shares, log = module.calculate_distribution(blockstamp=Mock()) - - assert tuple(shares.items()) == ( - (NodeOperatorId(0), 476), - (NodeOperatorId(1), 2380), - (NodeOperatorId(3), 2380), - (NodeOperatorId(6), 2380), - (NodeOperatorId(8), 2380), - ) - - assert tuple(log.operators.keys()) == ( - NodeOperatorId(0), - NodeOperatorId(1), - NodeOperatorId(2), - NodeOperatorId(3), - NodeOperatorId(4), - NodeOperatorId(5), - NodeOperatorId(6), - # NodeOperatorId(7), # Missing in state - NodeOperatorId(8), - NodeOperatorId(9), - ) - - assert not log.operators[NodeOperatorId(1)].stuck - - assert log.operators[NodeOperatorId(2)].validators == {} - assert log.operators[NodeOperatorId(2)].stuck - assert log.operators[NodeOperatorId(4)].validators == {} - assert log.operators[NodeOperatorId(4)].stuck - - assert 5 in log.operators[NodeOperatorId(5)].validators - assert 6 in log.operators[NodeOperatorId(5)].validators - assert 7 in log.operators[NodeOperatorId(6)].validators - - assert log.operators[NodeOperatorId(0)].distributed == 476 - assert log.operators[NodeOperatorId(1)].distributed == 2380 - assert log.operators[NodeOperatorId(2)].distributed == 0 - assert log.operators[NodeOperatorId(3)].distributed == 2380 - assert log.operators[NodeOperatorId(6)].distributed == 2380 - - assert log.frame == (100, 500) - assert log.threshold == module.state.get_network_aggr().perf - 0.05 - - # Static functions you were dreaming of for so long. 
@@ -380,11 +263,11 @@ def test_current_frame_range(module: CSOracle, csm: CSM, mock_chain_config: NoRe if param.expected_frame is ValueError: with pytest.raises(ValueError): - module.current_frame_range(ReferenceBlockStampFactory.build(slot_number=param.finalized_slot)) + module.get_epochs_range_to_process(ReferenceBlockStampFactory.build(slot_number=param.finalized_slot)) else: bs = ReferenceBlockStampFactory.build(slot_number=param.finalized_slot) - l_epoch, r_epoch = module.current_frame_range(bs) + l_epoch, r_epoch = module.get_epochs_range_to_process(bs) assert (l_epoch, r_epoch) == param.expected_frame @@ -418,7 +301,7 @@ class CollectDataTestParam: collect_frame_range=Mock(return_value=(0, 1)), report_blockstamp=Mock(ref_epoch=3), state=Mock(), - expected_msg="Frame has been changed, but the change is not yet observed on finalized epoch 1", + expected_msg="Epochs range has been changed, but the change is not yet observed on finalized epoch 1", expected_result=False, ), id="frame_changed_forward", @@ -429,7 +312,7 @@ class CollectDataTestParam: collect_frame_range=Mock(return_value=(0, 2)), report_blockstamp=Mock(ref_epoch=1), state=Mock(), - expected_msg="Frame has been changed, but the change is not yet observed on finalized epoch 1", + expected_msg="Epochs range has been changed, but the change is not yet observed on finalized epoch 1", expected_result=False, ), id="frame_changed_backward", @@ -440,7 +323,7 @@ class CollectDataTestParam: collect_frame_range=Mock(return_value=(1, 2)), report_blockstamp=Mock(ref_epoch=2), state=Mock(), - expected_msg="The starting epoch of the frame is not finalized yet", + expected_msg="The starting epoch of the epochs range is not finalized yet", expected_result=False, ), id="starting_epoch_not_finalized", @@ -490,7 +373,7 @@ def test_collect_data( module.w3 = Mock() module._receive_last_finalized_slot = Mock() module.state = param.state - module.current_frame_range = param.collect_frame_range + 
module.get_epochs_range_to_process = param.collect_frame_range module.get_blockstamp_for_report = Mock(return_value=param.report_blockstamp) with caplog.at_level(logging.DEBUG): @@ -516,7 +399,7 @@ def test_collect_data_outdated_checkpoint( unprocessed_epochs=list(range(0, 101)), is_fulfilled=False, ) - module.current_frame_range = Mock(side_effect=[(0, 100), (50, 150)]) + module.get_epochs_range_to_process = Mock(side_effect=[(0, 100), (50, 150)]) module.get_blockstamp_for_report = Mock(return_value=Mock(ref_epoch=100)) with caplog.at_level(logging.DEBUG): @@ -524,7 +407,10 @@ def test_collect_data_outdated_checkpoint( module.collect_data(blockstamp=Mock(slot_number=640)) msg = list( - filter(lambda log: "Checkpoints were prepared for an outdated frame, stop processing" in log, caplog.messages) + filter( + lambda log: "Checkpoints were prepared for an outdated epochs range, stop processing" in log, + caplog.messages, + ) ) assert len(msg), "Expected message not found in logs" @@ -540,7 +426,7 @@ def test_collect_data_fulfilled_state( unprocessed_epochs=list(range(0, 101)), ) type(module.state).is_fulfilled = PropertyMock(side_effect=[False, True]) - module.current_frame_range = Mock(return_value=(0, 100)) + module.get_epochs_range_to_process = Mock(return_value=(0, 100)) module.get_blockstamp_for_report = Mock(return_value=Mock(ref_epoch=100)) with caplog.at_level(logging.DEBUG): @@ -692,7 +578,7 @@ def test_build_report(csm: CSM, module: CSOracle, param: BuildReportTestParam): # mock previous report module.w3.csm.get_csm_tree_root = Mock(return_value=param.prev_tree_root) module.w3.csm.get_csm_tree_cid = Mock(return_value=param.prev_tree_cid) - module.get_accumulated_shares = Mock(return_value=param.prev_acc_shares) + module.get_accumulated_rewards = Mock(return_value=param.prev_acc_shares) # mock current frame module.calculate_distribution = param.curr_distribution module.make_tree = Mock(return_value=Mock(root=param.curr_tree_root)) @@ -751,7 +637,7 @@ def 
test_get_accumulated_shares(module: CSOracle, tree: Tree): encoded_tree = tree.encode() module.w3.ipfs = Mock(fetch=Mock(return_value=encoded_tree)) - for i, leaf in enumerate(module.get_accumulated_shares(cid=CIDv0("0x100500"), root=tree.root)): + for i, leaf in enumerate(module.get_accumulated_rewards(cid=CIDv0("0x100500"), root=tree.root)): assert tuple(leaf) == tree.tree.values[i]["value"] @@ -760,7 +646,7 @@ def test_get_accumulated_shares_unexpected_root(module: CSOracle, tree: Tree): module.w3.ipfs = Mock(fetch=Mock(return_value=encoded_tree)) with pytest.raises(ValueError): - next(module.get_accumulated_shares(cid=CIDv0("0x100500"), root=HexBytes("0x100500"))) + next(module.get_accumulated_rewards(cid=CIDv0("0x100500"), root=HexBytes("0x100500"))) @dataclass(frozen=True) diff --git a/tests/modules/csm/test_log.py b/tests/modules/csm/test_log.py index de52ca9ef..c95ef93ca 100644 --- a/tests/modules/csm/test_log.py +++ b/tests/modules/csm/test_log.py @@ -1,8 +1,7 @@ import json import pytest -from src.modules.csm.log import FramePerfLog -from src.modules.csm.state import AttestationsAccumulator +from src.modules.csm.log import FramePerfLog, AttestationsAccumulator from src.types import EpochNumber, NodeOperatorId, ReferenceBlockStamp from tests.factory.blockstamp import ReferenceBlockStampFactory @@ -29,20 +28,60 @@ def test_fields_access(log: FramePerfLog): def test_log_encode(log: FramePerfLog): # Fill in dynamic fields to make sure we have data in it to be encoded. 
- log.operators[NodeOperatorId(42)].validators["41337"].perf = AttestationsAccumulator(220, 119) + log.operators[NodeOperatorId(42)].validators["41337"].attestation_duty = AttestationsAccumulator(220, 119) log.operators[NodeOperatorId(42)].distributed = 17 log.operators[NodeOperatorId(0)].distributed = 0 - encoded = log.encode() + logs = [log] + + encoded = FramePerfLog.encode(logs) + + for decoded in json.loads(encoded): + assert decoded["operators"]["42"]["validators"]["41337"]["attestation_duty"]["assigned"] == 220 + assert decoded["operators"]["42"]["validators"]["41337"]["attestation_duty"]["included"] == 119 + assert decoded["operators"]["42"]["distributed"] == 17 + assert decoded["operators"]["0"]["distributed"] == 0 + + assert decoded["blockstamp"]["block_hash"] == log.blockstamp.block_hash + assert decoded["blockstamp"]["ref_slot"] == log.blockstamp.ref_slot + + assert decoded["threshold"] == log.threshold + assert decoded["frame"] == list(log.frame) + + +def test_logs_encode(): + log_0 = FramePerfLog(ReferenceBlockStampFactory.build(), (EpochNumber(100), EpochNumber(500))) + log_0.operators[NodeOperatorId(42)].validators["41337"].attestation_duty = AttestationsAccumulator(220, 119) + log_0.operators[NodeOperatorId(42)].distributed = 17 + log_0.operators[NodeOperatorId(0)].distributed = 0 + + log_1 = FramePerfLog(ReferenceBlockStampFactory.build(), (EpochNumber(500), EpochNumber(900))) + log_1.operators[NodeOperatorId(5)].validators["1234"].attestation_duty = AttestationsAccumulator(400, 399) + log_1.operators[NodeOperatorId(5)].distributed = 40 + log_1.operators[NodeOperatorId(18)].distributed = 0 + + logs = [log_0, log_1] + + encoded = FramePerfLog.encode(logs) + decoded = json.loads(encoded) - assert decoded["operators"]["42"]["validators"]["41337"]["perf"]["assigned"] == 220 - assert decoded["operators"]["42"]["validators"]["41337"]["perf"]["included"] == 119 - assert decoded["operators"]["42"]["distributed"] == 17 - assert 
decoded["operators"]["0"]["distributed"] == 0 + assert len(decoded) == 2 + + assert decoded[0]["operators"]["42"]["validators"]["41337"]["attestation_duty"]["assigned"] == 220 + assert decoded[0]["operators"]["42"]["validators"]["41337"]["attestation_duty"]["included"] == 119 + assert decoded[0]["operators"]["42"]["distributed"] == 17 + assert decoded[0]["operators"]["0"]["distributed"] == 0 + + assert decoded[1]["operators"]["5"]["validators"]["1234"]["attestation_duty"]["assigned"] == 400 + assert decoded[1]["operators"]["5"]["validators"]["1234"]["attestation_duty"]["included"] == 399 + assert decoded[1]["operators"]["5"]["distributed"] == 40 + assert decoded[1]["operators"]["18"]["distributed"] == 0 - assert decoded["blockstamp"]["block_hash"] == log.blockstamp.block_hash - assert decoded["blockstamp"]["ref_slot"] == log.blockstamp.ref_slot + for i, log in enumerate(logs): + assert decoded[i]["blockstamp"]["block_hash"] == log.blockstamp.block_hash + assert decoded[i]["blockstamp"]["ref_slot"] == log.blockstamp.ref_slot - assert decoded["threshold"] == log.threshold - assert decoded["frame"] == list(log.frame) + assert decoded[i]["threshold"] == log.threshold + assert decoded[i]["frame"] == list(log.frame) + assert decoded[i]["distributable"] == log.distributable diff --git a/tests/modules/csm/test_state.py b/tests/modules/csm/test_state.py index 7539f7d26..2fb3cd078 100644 --- a/tests/modules/csm/test_state.py +++ b/tests/modules/csm/test_state.py @@ -1,219 +1,420 @@ +import os +import pickle +from collections import defaultdict from pathlib import Path from unittest.mock import Mock import pytest -from src.modules.csm.state import AttestationsAccumulator, State -from src.types import EpochNumber, ValidatorIndex +from src import variables +from src.modules.csm.state import AttestationsAccumulator, State, InvalidState +from src.types import ValidatorIndex from src.utils.range import sequence -@pytest.fixture() -def state_file_path(tmp_path: Path) -> Path: - 
return (tmp_path / "mock").with_suffix(State.EXTENSION) +@pytest.fixture(autouse=True) +def remove_state_files(): + state_file = Path("/tmp/state.pkl") + state_buf = Path("/tmp/state.buf") + state_file.unlink(missing_ok=True) + state_buf.unlink(missing_ok=True) + yield + state_file.unlink(missing_ok=True) + state_buf.unlink(missing_ok=True) + + +def test_load_restores_state_from_file(monkeypatch): + monkeypatch.setattr("src.modules.csm.state.State.file", lambda _=None: Path("/tmp/state.pkl")) + state = State() + state.data = { + (0, 31): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(10, 5)}), + } + state.commit() + loaded_state = State.load() + assert loaded_state.data == state.data -@pytest.fixture(autouse=True) -def mock_state_file(state_file_path: Path): - State.file = Mock(return_value=state_file_path) +def test_load_returns_new_instance_if_file_not_found(monkeypatch): + monkeypatch.setattr("src.modules.csm.state.State.file", lambda: Path("/non/existent/path")) + state = State.load() + assert state.is_empty -def test_attestation_aggregate_perf(): - aggr = AttestationsAccumulator(included=333, assigned=777) - assert aggr.perf == pytest.approx(0.4285, abs=1e-4) +def test_load_returns_new_instance_if_empty_object(monkeypatch, tmp_path): + with open('/tmp/state.pkl', "wb") as f: + pickle.dump(None, f) + monkeypatch.setattr("src.modules.csm.state.State.file", lambda: Path("/tmp/state.pkl")) + state = State.load() + assert state.is_empty -def test_state_avg_perf(): +def test_commit_saves_state_to_file(monkeypatch): state = State() + state.data = { + (0, 31): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(10, 5)}), + } + monkeypatch.setattr("src.modules.csm.state.State.file", lambda _: Path("/tmp/state.pkl")) + monkeypatch.setattr("os.replace", Mock(side_effect=os.replace)) + state.commit() + with open("/tmp/state.pkl", "rb") as f: + loaded_state = pickle.load(f) + assert loaded_state.data == 
state.data + os.replace.assert_called_once_with(Path("/tmp/state.buf"), Path("/tmp/state.pkl")) - assert state.get_network_aggr().perf == 0 - - state = State( - { - ValidatorIndex(0): AttestationsAccumulator(included=0, assigned=0), - ValidatorIndex(1): AttestationsAccumulator(included=0, assigned=0), - } - ) - assert state.get_network_aggr().perf == 0 +def test_file_returns_correct_path(monkeypatch): + monkeypatch.setattr(variables, "CACHE_PATH", Path("/tmp")) + assert State.file() == Path("/tmp/cache.pkl") - state = State( - { - ValidatorIndex(0): AttestationsAccumulator(included=333, assigned=777), - ValidatorIndex(1): AttestationsAccumulator(included=167, assigned=223), - } - ) - assert state.get_network_aggr().perf == 0.5 +def test_buffer_returns_correct_path(monkeypatch): + monkeypatch.setattr(variables, "CACHE_PATH", Path("/tmp")) + state = State() + assert state.buffer == Path("/tmp/cache.buf") -def test_state_frame(): +def test_is_empty_returns_true_for_empty_state(): state = State() + assert state.is_empty - state.migrate(EpochNumber(100), EpochNumber(500), 1) - assert state.frame == (100, 500) - state.migrate(EpochNumber(300), EpochNumber(301), 1) - assert state.frame == (300, 301) +def test_is_empty_returns_false_for_non_empty_state(): + state = State() + state.data = {(0, 31): defaultdict(AttestationsAccumulator)} + assert not state.is_empty - state.clear() +def test_unprocessed_epochs_raises_error_if_epochs_not_set(): + state = State() with pytest.raises(ValueError, match="Epochs to process are not set"): - state.frame + state.unprocessed_epochs -def test_state_attestations(): - state = State( - { - ValidatorIndex(0): AttestationsAccumulator(included=333, assigned=777), - ValidatorIndex(1): AttestationsAccumulator(included=167, assigned=223), - } - ) +def test_unprocessed_epochs_returns_correct_set(): + state = State() + state._epochs_to_process = tuple(sequence(0, 95)) + state._processed_epochs = set(sequence(0, 63)) + assert state.unprocessed_epochs 
== set(sequence(64, 95)) - network_aggr = state.get_network_aggr() - assert network_aggr.assigned == 1000 - assert network_aggr.included == 500 +def test_is_fulfilled_returns_true_if_no_unprocessed_epochs(): + state = State() + state._epochs_to_process = tuple(sequence(0, 95)) + state._processed_epochs = set(sequence(0, 95)) + assert state.is_fulfilled -def test_state_load(): - orig = State( - { - ValidatorIndex(0): AttestationsAccumulator(included=333, assigned=777), - ValidatorIndex(1): AttestationsAccumulator(included=167, assigned=223), - } - ) +def test_is_fulfilled_returns_false_if_unprocessed_epochs_exist(): + state = State() + state._epochs_to_process = tuple(sequence(0, 95)) + state._processed_epochs = set(sequence(0, 63)) + assert not state.is_fulfilled - orig.commit() - copy = State.load() - assert copy.data == orig.data +def test_calculate_frames_handles_exact_frame_size(): + epochs = tuple(range(10)) + frames = State._calculate_frames(epochs, 5) + assert frames == [(0, 4), (5, 9)] -def test_state_clear(): - state = State( - { - ValidatorIndex(0): AttestationsAccumulator(included=333, assigned=777), - ValidatorIndex(1): AttestationsAccumulator(included=167, assigned=223), - } - ) - state._epochs_to_process = (EpochNumber(1), EpochNumber(33)) - state._processed_epochs = {EpochNumber(42), EpochNumber(17)} +def test_calculate_frames_raises_error_for_insufficient_epochs(): + epochs = tuple(range(8)) + with pytest.raises(ValueError, match="Insufficient epochs to form a frame"): + State._calculate_frames(epochs, 5) + +def test_clear_resets_state_to_empty(): + state = State() + state.data = {(0, 31): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(10, 5)})} state.clear() assert state.is_empty - assert not state.data -def test_state_add_processed_epoch(): +def test_find_frame_returns_correct_frame(): + state = State() + state.frames = [(0, 31)] + state.data = {(0, 31): defaultdict(AttestationsAccumulator)} + assert 
state.find_frame(15) == (0, 31) + + +def test_find_frame_raises_error_for_out_of_range_epoch(): + state = State() + state.frames = [(0, 31)] + state.data = {(0, 31): defaultdict(AttestationsAccumulator)} + with pytest.raises(ValueError, match="Epoch 32 is out of frames range"): + state.find_frame(32) + + +def test_increment_duty_adds_duty_correctly(): + state = State() + frame = (0, 31) + state.frames = [frame] + duty_epoch, _ = frame + state.data = { + frame: defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(10, 5)}), + } + state.increment_duty(duty_epoch, ValidatorIndex(1), True) + assert state.data[frame][ValidatorIndex(1)].assigned == 11 + assert state.data[frame][ValidatorIndex(1)].included == 6 + + +def test_increment_duty_creates_new_validator_entry(): state = State() - state.add_processed_epoch(EpochNumber(42)) - state.add_processed_epoch(EpochNumber(17)) - assert state._processed_epochs == {EpochNumber(42), EpochNumber(17)} + frame = (0, 31) + state.frames = [frame] + duty_epoch, _ = frame + state.data = { + frame: defaultdict(AttestationsAccumulator), + } + state.increment_duty(duty_epoch, ValidatorIndex(2), True) + assert state.data[frame][ValidatorIndex(2)].assigned == 1 + assert state.data[frame][ValidatorIndex(2)].included == 1 + + +def test_increment_duty_handles_non_included_duty(): + state = State() + frame = (0, 31) + state.frames = [frame] + + duty_epoch, _ = frame + state.data = { + frame: defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(10, 5)}), + } + state.increment_duty(duty_epoch, ValidatorIndex(1), False) + assert state.data[frame][ValidatorIndex(1)].assigned == 11 + assert state.data[frame][ValidatorIndex(1)].included == 5 + + +def test_increment_duty_raises_error_for_out_of_range_epoch(): + state = State() + frame = (0, 31) + state.frames = [frame] + state.data = { + frame: defaultdict(AttestationsAccumulator), + } + with pytest.raises(ValueError, match="is 
out of frames range"): + state.increment_duty(32, ValidatorIndex(1), True) -def test_state_inc(): - state = State( - { - ValidatorIndex(0): AttestationsAccumulator(included=0, assigned=0), - ValidatorIndex(1): AttestationsAccumulator(included=1, assigned=2), - } - ) +def test_add_processed_epoch_adds_epoch_to_processed_set(): + state = State() + state.add_processed_epoch(5) + assert 5 in state._processed_epochs - state.inc(ValidatorIndex(0), True) - state.inc(ValidatorIndex(0), False) - state.inc(ValidatorIndex(1), True) - state.inc(ValidatorIndex(1), True) - state.inc(ValidatorIndex(1), False) +def test_add_processed_epoch_does_not_duplicate_epochs(): + state = State() + state.add_processed_epoch(5) + state.add_processed_epoch(5) + assert len(state._processed_epochs) == 1 - state.inc(ValidatorIndex(2), True) - state.inc(ValidatorIndex(2), False) - assert tuple(state.data.values()) == ( - AttestationsAccumulator(included=1, assigned=2), - AttestationsAccumulator(included=3, assigned=5), - AttestationsAccumulator(included=1, assigned=2), - ) +def test_init_or_migrate_discards_data_on_version_change(): + state = State() + state._consensus_version = 1 + state.clear = Mock() + state.commit = Mock() + state.migrate(0, 63, 32, 2) + state.clear.assert_called_once() + state.commit.assert_called_once() + +def test_init_or_migrate_no_migration_needed(): + state = State() + state._consensus_version = 1 + state.frames = [(0, 31), (32, 63)] + state.data = { + (0, 31): defaultdict(AttestationsAccumulator), + (32, 63): defaultdict(AttestationsAccumulator), + } + state.commit = Mock() + state.migrate(0, 63, 32, 1) + state.commit.assert_not_called() + + +def test_init_or_migrate_migrates_data(): + state = State() + state._consensus_version = 1 + state.frames = [(0, 31), (32, 63)] + state.data = { + (0, 31): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(10, 5)}), + (32, 63): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): 
AttestationsAccumulator(20, 15)}), + } + state.commit = Mock() + state.migrate(0, 63, 64, 1) + assert state.data == { + (0, 63): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(30, 20)}), + } + state.commit.assert_called_once() + + +def test_init_or_migrate_invalidates_unmigrated_frames(): + state = State() + state._consensus_version = 1 + state.frames = [(0, 63)] + state.data = { + (0, 63): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(30, 20)}), + } + state.commit = Mock() + state.migrate(0, 31, 32, 1) + assert state.data == { + (0, 31): defaultdict(AttestationsAccumulator), + } + assert state._processed_epochs == set() + state.commit.assert_called_once() + + +def test_init_or_migrate_discards_unmigrated_frame(): + state = State() + state._consensus_version = 1 + state.frames = [(0, 31), (32, 63), (64, 95)] + state.data = { + (0, 31): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(10, 5)}), + (32, 63): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(20, 15)}), + (64, 95): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(30, 25)}), + } + state._processed_epochs = set(sequence(0, 95)) + state.commit = Mock() + state.migrate(0, 63, 32, 1) + assert state.data == { + (0, 31): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(10, 5)}), + (32, 63): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(20, 15)}), + } + assert state._processed_epochs == set(sequence(0, 63)) + state.commit.assert_called_once() + + +def test_migrate_frames_data_creates_new_data_correctly(): + state = State() + state.frames = [(0, 31), (32, 63)] + new_frames = [(0, 63)] + state.data = { + (0, 31): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(10, 5)}), + (32, 63): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): 
AttestationsAccumulator(20, 15)}), + } + state._migrate_frames_data(new_frames) + assert state.data == { + (0, 63): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(30, 20)}) + } + + +def test_migrate_frames_data_handles_no_migration(): + state = State() + state.frames = [(0, 31)] + new_frames = [(0, 31)] + state.data = { + (0, 31): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(10, 5)}), + } + state._migrate_frames_data(new_frames) + assert state.data == { + (0, 31): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(10, 5)}) + } + + +def test_migrate_frames_data_handles_partial_migration(): + state = State() + state.frames = [(0, 31), (32, 63)] + new_frames = [(0, 31), (32, 95)] + state.data = { + (0, 31): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(10, 5)}), + (32, 63): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(20, 15)}), + } + state._migrate_frames_data(new_frames) + assert state.data == { + (0, 31): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(10, 5)}), + (32, 95): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(20, 15)}), + } + + +def test_migrate_frames_data_handles_no_data(): + state = State() + state.frames = [(0, 31)] + new_frames = [(0, 31)] + state.data = {frame: defaultdict(AttestationsAccumulator) for frame in state.frames} + state._migrate_frames_data(new_frames) + assert state.data == {(0, 31): defaultdict(AttestationsAccumulator)} -def test_state_file_is_path(): - assert isinstance(State.file(), Path) +def test_migrate_frames_data_handles_wider_old_frame(): + state = State() + state.frames = [(0, 63)] + new_frames = [(0, 31), (32, 63)] + state.data = { + (0, 63): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(30, 20)}), + } + state._migrate_frames_data(new_frames) + assert 
state.data == { + (0, 31): defaultdict(AttestationsAccumulator), + (32, 63): defaultdict(AttestationsAccumulator), + } + + +def test_validate_raises_error_if_state_not_fulfilled(): + state = State() + state._epochs_to_process = tuple(sequence(0, 95)) + state._processed_epochs = set(sequence(0, 94)) + with pytest.raises(InvalidState, match="State is not fulfilled"): + state.validate(0, 95) -class TestStateTransition: - """Tests for State's transition for different l_epoch, r_epoch values""" - @pytest.fixture(autouse=True) - def no_commit(self, monkeypatch: pytest.MonkeyPatch): - monkeypatch.setattr(State, "commit", Mock()) +def test_validate_raises_error_if_processed_epoch_out_of_range(): + state = State() + state._epochs_to_process = tuple(sequence(0, 95)) + state._processed_epochs = set(sequence(0, 95)) + state._processed_epochs.add(96) + with pytest.raises(InvalidState, match="Processed epoch 96 is out of range"): + state.validate(0, 95) - def test_empty_to_new_frame(self): - state = State() - assert state.is_empty - l_epoch = EpochNumber(1) - r_epoch = EpochNumber(255) +def test_validate_raises_error_if_epoch_missing_in_processed_epochs(): + state = State() + state._epochs_to_process = tuple(sequence(0, 94)) + state._processed_epochs = set(sequence(0, 94)) + with pytest.raises(InvalidState, match="Epoch 95 missing in processed epochs"): + state.validate(0, 95) - state.migrate(l_epoch, r_epoch, 1) - assert not state.is_empty - assert state.unprocessed_epochs == set(sequence(l_epoch, r_epoch)) +def test_validate_passes_for_fulfilled_state(): + state = State() + state._epochs_to_process = tuple(sequence(0, 95)) + state._processed_epochs = set(sequence(0, 95)) + state.validate(0, 95) - @pytest.mark.parametrize( - ("l_epoch_old", "r_epoch_old", "l_epoch_new", "r_epoch_new"), - [ - pytest.param(1, 255, 256, 510, id="Migrate a..bA..B"), - pytest.param(1, 255, 32, 510, id="Migrate a..A..b..B"), - pytest.param(32, 510, 1, 255, id="Migrate: A..a..B..b"), - ], - ) - def 
test_new_frame_requires_discarding_state(self, l_epoch_old, r_epoch_old, l_epoch_new, r_epoch_new): - state = State() - state.clear = Mock(side_effect=state.clear) - state.migrate(l_epoch_old, r_epoch_old, 1) - state.clear.assert_not_called() - state.migrate(l_epoch_new, r_epoch_new, 1) - state.clear.assert_called_once() +def test_attestation_aggregate_perf(): + aggr = AttestationsAccumulator(included=333, assigned=777) + assert aggr.perf == pytest.approx(0.4285, abs=1e-4) - assert state.unprocessed_epochs == set(sequence(l_epoch_new, r_epoch_new)) - @pytest.mark.parametrize( - ("l_epoch_old", "r_epoch_old", "l_epoch_new", "r_epoch_new"), - [ - pytest.param(1, 255, 1, 510, id="Migrate Aa..b..B"), - pytest.param(32, 510, 1, 510, id="Migrate: A..a..b..B"), - ], - ) - def test_new_frame_extends_old_state(self, l_epoch_old, r_epoch_old, l_epoch_new, r_epoch_new): - state = State() - state.clear = Mock(side_effect=state.clear) +def test_get_network_aggr_computes_correctly(): + state = State() + state.data = { + (0, 31): defaultdict( + AttestationsAccumulator, + {ValidatorIndex(1): AttestationsAccumulator(10, 5), ValidatorIndex(2): AttestationsAccumulator(20, 15)}, + ) + } + aggr = state.get_network_aggr((0, 31)) + assert aggr.assigned == 30 + assert aggr.included == 20 + + +def test_get_network_aggr_raises_error_for_invalid_accumulator(): + state = State() + state.data = {(0, 31): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(10, 15)})} + with pytest.raises(ValueError, match="Invalid accumulator"): + state.get_network_aggr((0, 31)) - state.migrate(l_epoch_old, r_epoch_old, 1) - state.clear.assert_not_called() - state.migrate(l_epoch_new, r_epoch_new, 1) - state.clear.assert_not_called() +def test_get_network_aggr_raises_error_for_missing_frame_data(): + state = State() + with pytest.raises(ValueError, match="No data for frame"): + state.get_network_aggr((0, 31)) - assert state.unprocessed_epochs == set(sequence(l_epoch_new, 
r_epoch_new)) - @pytest.mark.parametrize( - ("old_version", "new_version"), - [ - pytest.param(2, 3, id="Increase consensus version"), - pytest.param(3, 2, id="Decrease consensus version"), - ], - ) - def test_consensus_version_change(self, old_version, new_version): - state = State() - state.clear = Mock(side_effect=state.clear) - state._consensus_version = old_version - - l_epoch = r_epoch = EpochNumber(255) - - state.migrate(l_epoch, r_epoch, old_version) - state.clear.assert_not_called() - - state.migrate(l_epoch, r_epoch, new_version) - state.clear.assert_called_once() +def test_get_network_aggr_handles_empty_frame_data(): + state = State() + state.data = {(0, 31): defaultdict(AttestationsAccumulator)} + aggr = state.get_network_aggr((0, 31)) + assert aggr.assigned == 0 + assert aggr.included == 0