[CSM] feat: proper missing frames handling #557

Merged · 20 commits · Feb 19, 2025

Changes from 5 commits
7 changes: 4 additions & 3 deletions src/modules/csm/checkpoint.py
@@ -143,6 +143,7 @@ def exec(self, checkpoint: FrameCheckpoint) -> int:
             for duty_epoch in unprocessed_epochs
         }
         self._process(unprocessed_epochs, duty_epochs_roots)
+        self.state.commit()
         return len(unprocessed_epochs)
 
     def _get_block_roots(self, checkpoint_slot: SlotNumber):
@@ -204,18 +205,18 @@ def _check_duty(
         for root in block_roots:
             attestations = self.cc.get_block_attestations(root)
             process_attestations(attestations, committees, self.eip7549_supported)
 
+        frame = self.state.find_frame(duty_epoch)
         with lock:
             for committee in committees.values():
                 for validator_duty in committee:
-                    self.state.inc(
+                    self.state.increment_duty(
+                        frame,
                         validator_duty.index,
                         included=validator_duty.included,
                     )
             if duty_epoch not in self.state.unprocessed_epochs:
                 raise ValueError(f"Epoch {duty_epoch} is not in epochs that should be processed")
             self.state.add_processed_epoch(duty_epoch)
-            self.state.commit()
             self.state.log_progress()
             unprocessed_epochs = self.state.unprocessed_epochs
             CSM_UNPROCESSED_EPOCHS_COUNT.set(len(unprocessed_epochs))
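The net effect of this file's change: duty counters are now keyed by frame via `find_frame`/`increment_duty`, and the state commit moves out of the per-epoch `_check_duty` path into `exec`, so state is flushed once per checkpoint rather than once per processed epoch. A minimal sketch of the per-frame bookkeeping (illustrative only: `Frame`, `AttestationsAccumulator` and `State` here are simplified stand-ins for the real classes in src/modules/csm/state.py):

    from collections import defaultdict
    from dataclasses import dataclass

    Frame = tuple[int, int]  # (from_epoch, to_epoch), inclusive bounds

    @dataclass
    class AttestationsAccumulator:
        assigned: int = 0  # duties the validator was assigned in the frame
        included: int = 0  # duties that made it on chain

        @property
        def perf(self) -> float:
            return self.included / self.assigned if self.assigned else 0.0

    class State:
        def __init__(self, frames: list[Frame]):
            self.frames = frames
            # One accumulator map per frame instead of a single flat map.
            self.data: dict[Frame, defaultdict[int, AttestationsAccumulator]] = {
                frame: defaultdict(AttestationsAccumulator) for frame in frames
            }

        def find_frame(self, epoch: int) -> Frame:
            for frame in self.frames:
                if frame[0] <= epoch <= frame[1]:
                    return frame
            raise ValueError(f"Epoch {epoch} is out of frames range")

        def increment_duty(self, frame: Frame, validator_index: int, included: bool) -> None:
            duty = self.data[frame][validator_index]
            duty.assigned += 1
            duty.included += included  # bool counts as 0/1

    state = State(frames=[(0, 31), (32, 63)])
    state.increment_duty(state.find_frame(33), validator_index=7, included=True)
    assert state.data[(32, 63)][7].perf == 1.0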
190 changes: 130 additions & 60 deletions src/modules/csm/csm.py
@@ -12,8 +12,8 @@
 )
 from src.metrics.prometheus.duration_meter import duration_meter
 from src.modules.csm.checkpoint import FrameCheckpointProcessor, FrameCheckpointsIterator, MinStepIsNotReached
-from src.modules.csm.log import FramePerfLog
-from src.modules.csm.state import State
+from src.modules.csm.log import FramePerfLog, OperatorFrameSummary
+from src.modules.csm.state import State, Frame, AttestationsAccumulator
 from src.modules.csm.tree import Tree
 from src.modules.csm.types import ReportData, Shares
 from src.modules.submodules.consensus import ConsensusModule
@@ -32,9 +32,9 @@
 )
 from src.utils.blockstamp import build_blockstamp
 from src.utils.cache import global_lru_cache as lru_cache
-from src.utils.slot import get_next_non_missed_slot
+from src.utils.slot import get_next_non_missed_slot, get_reference_blockstamp
 from src.utils.web3converter import Web3Converter
-from src.web3py.extensions.lido_validators import NodeOperatorId, StakingModule, ValidatorsByNodeOperator
+from src.web3py.extensions.lido_validators import NodeOperatorId, StakingModule, ValidatorsByNodeOperator, LidoValidator
 from src.web3py.types import Web3
 
 logger = logging.getLogger(__name__)
@@ -101,15 +101,15 @@ def build_report(self, blockstamp: ReferenceBlockStamp) -> tuple:
         if (prev_cid is None) != (prev_root == ZERO_HASH):
             raise InconsistentData(f"Got inconsistent previous tree data: {prev_root=} {prev_cid=}")
 
-        distributed, shares, log = self.calculate_distribution(blockstamp)
+        total_distributed, total_rewards, logs = self.calculate_distribution(blockstamp)
 
-        if distributed != sum(shares.values()):
-            raise InconsistentData(f"Invalid distribution: {sum(shares.values())=} != {distributed=}")
+        if total_distributed != sum(total_rewards.values()):
+            raise InconsistentData(f"Invalid distribution: {sum(total_rewards.values())=} != {total_distributed=}")
 
-        log_cid = self.publish_log(log)
+        log_cid = self.publish_log(logs)
 
-        if not distributed and not shares:
-            logger.info({"msg": "No shares distributed in the current frame"})
+        if not total_distributed and not total_rewards:
+            logger.info({"msg": "No rewards distributed in the current frame"})
             return ReportData(
                 self.get_consensus_version(blockstamp),
                 blockstamp.ref_slot,
@@ -122,11 +122,11 @@ def build_report(self, blockstamp: ReferenceBlockStamp) -> tuple:
         if prev_cid and prev_root != ZERO_HASH:
             # Update cumulative amount of shares for all operators.
             for no_id, acc_shares in self.get_accumulated_shares(prev_cid, prev_root):
-                shares[no_id] += acc_shares
+                total_rewards[no_id] += acc_shares
         else:
             logger.info({"msg": "No previous distribution. Nothing to accumulate"})
 
-        tree = self.make_tree(shares)
+        tree = self.make_tree(total_rewards)
         tree_cid = self.publish_tree(tree)
 
         return ReportData(
return ReportData(
Expand All @@ -135,7 +135,7 @@ def build_report(self, blockstamp: ReferenceBlockStamp) -> tuple:
tree_root=tree.root,
tree_cid=tree_cid,
log_cid=log_cid,
distributed=distributed,
distributed=total_distributed,
).as_tuple()

def is_main_data_submitted(self, blockstamp: BlockStamp) -> bool:
@@ -201,7 +201,7 @@ def collect_data(self, blockstamp: BlockStamp) -> bool:
             logger.info({"msg": "The starting epoch of the frame is not finalized yet"})
             return False
 
-        self.state.migrate(l_epoch, r_epoch, consensus_version)
+        self.state.init_or_migrate(l_epoch, r_epoch, converter.frame_config.epochs_per_frame, consensus_version)
         self.state.log_progress()
 
         if self.state.is_fulfilled:
@@ -227,63 +227,133 @@
 
     def calculate_distribution(
         self, blockstamp: ReferenceBlockStamp
-    ) -> tuple[int, defaultdict[NodeOperatorId, int], FramePerfLog]:
+    ) -> tuple[int, defaultdict[NodeOperatorId, int], list[FramePerfLog]]:
         """Computes distribution of fee shares at the given timestamp"""
 
-        network_avg_perf = self.state.get_network_aggr().perf
-        threshold = network_avg_perf - self.w3.csm.oracle.perf_leeway_bp(blockstamp.block_hash) / TOTAL_BASIS_POINTS
         operators_to_validators = self.module_validators_by_node_operators(blockstamp)
 
-        # Build the map of the current distribution operators.
-        distribution: dict[NodeOperatorId, int] = defaultdict(int)
-        stuck_operators = self.stuck_operators(blockstamp)
-        log = FramePerfLog(blockstamp, self.state.frame, threshold)
+        total_distributed = 0
+        total_rewards = defaultdict[NodeOperatorId, int](int)
+        logs: list[FramePerfLog] = []
 
-        for (_, no_id), validators in operators_to_validators.items():
-            if no_id in stuck_operators:
-                log.operators[no_id].stuck = True
-                continue
+        for frame in self.state.frames:
+            from_epoch, to_epoch = frame
+            logger.info({"msg": f"Calculating distribution for frame [{from_epoch};{to_epoch}]"})
 
-            for v in validators:
-                aggr = self.state.data.get(v.index)
+            frame_blockstamp = blockstamp
+            if to_epoch != blockstamp.ref_epoch:
+                frame_blockstamp = self._get_ref_blockstamp_for_frame(blockstamp, to_epoch)
 
-                if aggr is None:
-                    # It's possible that the validator is not assigned to any duty, hence it's performance
-                    # is not presented in the aggregates (e.g. exited, pending for activation etc).
-                    continue
+            total_rewards_to_distribute = self.w3.csm.fee_distributor.shares_to_distribute(frame_blockstamp.block_hash)
+            rewards_to_distribute_in_frame = total_rewards_to_distribute - total_distributed
 
-                if v.validator.slashed is True:
-                    # It means that validator was active during the frame and got slashed and didn't meet the exit
-                    # epoch, so we should not count such validator for operator's share.
-                    log.operators[no_id].validators[v.index].slashed = True
-                    continue
+            rewards_in_frame, log = self._calculate_distribution_in_frame(
+                frame, frame_blockstamp, rewards_to_distribute_in_frame, operators_to_validators
+            )
+            distributed_in_frame = sum(rewards_in_frame.values())
 
-                if aggr.perf > threshold:
-                    # Count of assigned attestations used as a metrics of time
-                    # the validator was active in the current frame.
-                    distribution[no_id] += aggr.assigned
+            total_distributed += distributed_in_frame
+            if total_distributed > total_rewards_to_distribute:
+                raise CSMError(f"Invalid distribution: {total_distributed=} > {total_rewards_to_distribute=}")
 
-                log.operators[no_id].validators[v.index].perf = aggr
+            for no_id, rewards in rewards_in_frame.items():
+                total_rewards[no_id] += rewards
 
-        # Calculate share of each CSM node operator.
-        shares = defaultdict[NodeOperatorId, int](int)
-        total = sum(p for p in distribution.values())
+            logs.append(log)
 
-        if not total:
-            return 0, shares, log
+        return total_distributed, total_rewards, logs
 
-        to_distribute = self.w3.csm.fee_distributor.shares_to_distribute(blockstamp.block_hash)
-        log.distributable = to_distribute
+    def _get_ref_blockstamp_for_frame(
+        self, blockstamp: ReferenceBlockStamp, frame_ref_epoch: EpochNumber
+    ) -> ReferenceBlockStamp:
+        converter = self.converter(blockstamp)
+        return get_reference_blockstamp(
+            cc=self.w3.cc,
+            ref_slot=converter.get_epoch_last_slot(frame_ref_epoch),
+            ref_epoch=frame_ref_epoch,
+            last_finalized_slot_number=blockstamp.slot_number,
+        )
 
-        for no_id, no_share in distribution.items():
-            if no_share:
-                shares[no_id] = to_distribute * no_share // total
-                log.operators[no_id].distributed = shares[no_id]
+    def _calculate_distribution_in_frame(
+        self,
+        frame: Frame,
+        blockstamp: ReferenceBlockStamp,
+        rewards_to_distribute: int,
+        operators_to_validators: ValidatorsByNodeOperator
+    ):
+        threshold = self._get_performance_threshold(frame, blockstamp)
+        log = FramePerfLog(blockstamp, frame, threshold)
+
+        participation_shares: defaultdict[NodeOperatorId, int] = defaultdict(int)
+
+        stuck_operators = self.stuck_operators(blockstamp)
+        for (_, no_id), validators in operators_to_validators.items():
+            log_operator = log.operators[no_id]
+            if no_id in stuck_operators:
+                log_operator.stuck = True
+                continue
+            for validator in validators:
+                duty = self.state.data[frame].get(validator.index)
+                self.process_validator_duty(validator, duty, threshold, participation_shares, log_operator)
+
+        rewards_distribution = self.calc_rewards_distribution_in_frame(participation_shares, rewards_to_distribute)
+
+        for no_id, no_rewards in rewards_distribution.items():
+            log.operators[no_id].distributed = no_rewards
+
+        log.distributable = rewards_to_distribute
+
+        return rewards_distribution, log
+
+    def _get_performance_threshold(self, frame: Frame, blockstamp: ReferenceBlockStamp) -> float:
+        network_perf = self.state.get_network_aggr(frame).perf
+        perf_leeway = self.w3.csm.oracle.perf_leeway_bp(blockstamp.block_hash) / TOTAL_BASIS_POINTS
+        threshold = network_perf - perf_leeway
+        return threshold
+
+    @staticmethod
+    def process_validator_duty(
+        validator: LidoValidator,
+        attestation_duty: AttestationsAccumulator | None,
+        threshold: float,
+        participation_shares: defaultdict[NodeOperatorId, int],
+        log_operator: OperatorFrameSummary
+    ):
+        if attestation_duty is None:
+            # It's possible that the validator is not assigned to any duty, hence it's performance
+            # is not presented in the aggregates (e.g. exited, pending for activation etc).
+            # TODO: check `sync_aggr` to strike (in case of bad sync performance) after validator exit
+            return
+
+        log_validator = log_operator.validators[validator.index]
+
+        if validator.validator.slashed is True:
+            # It means that validator was active during the frame and got slashed and didn't meet the exit
+            # epoch, so we should not count such validator for operator's share.
+            log_validator.slashed = True
+            return
+
+        if attestation_duty.perf > threshold:
+            # Count of assigned attestations used as a metrics of time
+            # the validator was active in the current frame.
+            participation_shares[validator.lido_id.operatorIndex] += attestation_duty.assigned
+
+        log_validator.attestation_duty = attestation_duty
+
+    @staticmethod
+    def calc_rewards_distribution_in_frame(
+        participation_shares: dict[NodeOperatorId, int],
+        rewards_to_distribute: int,
+    ) -> dict[NodeOperatorId, int]:
+        rewards_distribution: dict[NodeOperatorId, int] = defaultdict(int)
+        total_participation = sum(participation_shares.values())
+
+        for no_id, no_participation_share in participation_shares.items():
+            if no_participation_share == 0:
+                # Skip operators with zero participation
+                continue
+            rewards_distribution[no_id] = rewards_to_distribute * no_participation_share // total_participation
 
-        distributed = sum(s for s in shares.values())
-        if distributed > to_distribute:
-            raise CSMError(f"Invalid distribution: {distributed=} > {to_distribute=}")
-        return distributed, shares, log
+        return rewards_distribution
 
     def get_accumulated_shares(self, cid: CID, root: HexBytes) -> Iterator[tuple[NodeOperatorId, Shares]]:
         logger.info({"msg": "Fetching tree by CID from IPFS", "cid": repr(cid)})
@@ -348,9 +418,9 @@ def publish_tree(self, tree: Tree) -> CID:
         logger.info({"msg": "Tree dump uploaded to IPFS", "cid": repr(tree_cid)})
         return tree_cid
 
-    def publish_log(self, log: FramePerfLog) -> CID:
-        log_cid = self.w3.ipfs.publish(log.encode())
-        logger.info({"msg": "Frame log uploaded to IPFS", "cid": repr(log_cid)})
+    def publish_log(self, logs: list[FramePerfLog]) -> CID:
+        log_cid = self.w3.ipfs.publish(FramePerfLog.encode(logs))
+        logger.info({"msg": "Frame(s) log uploaded to IPFS", "cid": repr(log_cid)})
         return log_cid
 
     @lru_cache(maxsize=1)
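Two pieces of arithmetic in this file are worth pinning down. A validator's duties count toward its operator's `participation_shares` only when its frame performance (`included / assigned`) clears `_get_performance_threshold`, i.e. network average performance minus `perf_leeway_bp / 10000`. Rewards are then split pro rata by floor division, so each frame can leave a remainder undistributed; because the next frame computes `rewards_to_distribute_in_frame = total_rewards_to_distribute - total_distributed`, that dust rolls forward rather than being lost. A self-contained rerun of the `calc_rewards_distribution_in_frame` arithmetic with invented toy numbers:

    from collections import defaultdict

    def calc_rewards_distribution_in_frame(
        participation_shares: dict[int, int],
        rewards_to_distribute: int,
    ) -> dict[int, int]:
        # Same logic as the static method in the diff above.
        rewards_distribution: dict[int, int] = defaultdict(int)
        total_participation = sum(participation_shares.values())
        for no_id, share in participation_shares.items():
            if share == 0:
                continue  # skip operators with zero participation
            rewards_distribution[no_id] = rewards_to_distribute * share // total_participation
        return rewards_distribution

    # 100 units of rewards over 3 + 2 + 1 = 6 assigned attestations:
    print(dict(calc_rewards_distribution_in_frame({0: 3, 1: 2, 2: 1}, 100)))
    # {0: 50, 1: 33, 2: 16}; 99 units distributed, 1 unit of rounding dust is
    # carried into the next frame's rewards_to_distribute_in_frame.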
7 changes: 4 additions & 3 deletions src/modules/csm/log.py
@@ -12,7 +12,7 @@ class LogJSONEncoder(json.JSONEncoder): ...
 
 @dataclass
 class ValidatorFrameSummary:
-    perf: AttestationsAccumulator = field(default_factory=AttestationsAccumulator)
+    attestation_duty: AttestationsAccumulator = field(default_factory=AttestationsAccumulator)
     slashed: bool = False
 
 
@@ -35,13 +35,14 @@ class FramePerfLog:
         default_factory=lambda: defaultdict(OperatorFrameSummary)
     )
 
-    def encode(self) -> bytes:
+    @staticmethod
+    def encode(logs: list['FramePerfLog']) -> bytes:
         return (
             LogJSONEncoder(
                 indent=None,
                 separators=(',', ':'),
                 sort_keys=True,
             )
-            .encode(asdict(self))
+            .encode([asdict(log) for log in logs])
             .encode()
         )
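Since `encode` now serializes a list, a single IPFS object carries one `FramePerfLog` per frame, and consumers that previously parsed one JSON object now get a JSON array. A sketch of the resulting encoding, using a stripped-down stand-in dataclass (the real `FramePerfLog` also carries the blockstamp, threshold, and per-operator summaries):

    import json
    from dataclasses import asdict, dataclass

    @dataclass
    class FrameLogStub:  # stand-in for FramePerfLog
        frame: tuple[int, int]
        threshold: float = 0.0

    logs = [FrameLogStub((0, 31)), FrameLogStub((32, 63))]
    encoded = json.dumps(
        [asdict(log) for log in logs],  # list in, JSON array out
        indent=None,
        separators=(',', ':'),
        sort_keys=True,  # deterministic bytes for content-addressed IPFS storage
    ).encode()
    print(encoded)
    # b'[{"frame":[0,31],"threshold":0.0},{"frame":[32,63],"threshold":0.0}]'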