Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CSM] feat: proper missing frames handling #557

Merged
merged 20 commits into from
Feb 19, 2025
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/modules/csm/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ def exec(self, checkpoint: FrameCheckpoint) -> int:
for duty_epoch in unprocessed_epochs
}
self._process(unprocessed_epochs, duty_epochs_roots)
self.state.commit()
return len(unprocessed_epochs)

def _get_block_roots(self, checkpoint_slot: SlotNumber):
Expand Down Expand Up @@ -204,18 +205,17 @@ def _check_duty(
for root in block_roots:
attestations = self.cc.get_block_attestations(root)
process_attestations(attestations, committees, self.eip7549_supported)

with lock:
for committee in committees.values():
for validator_duty in committee:
self.state.inc(
self.state.increment_duty(
duty_epoch,
validator_duty.index,
included=validator_duty.included,
)
if duty_epoch not in self.state.unprocessed_epochs:
raise ValueError(f"Epoch {duty_epoch} is not in epochs that should be processed")
self.state.add_processed_epoch(duty_epoch)
self.state.commit()
self.state.log_progress()
unprocessed_epochs = self.state.unprocessed_epochs
CSM_UNPROCESSED_EPOCHS_COUNT.set(len(unprocessed_epochs))
Expand Down
256 changes: 162 additions & 94 deletions src/modules/csm/csm.py

Large diffs are not rendered by default.

7 changes: 4 additions & 3 deletions src/modules/csm/log.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class LogJSONEncoder(json.JSONEncoder): ...

@dataclass
class ValidatorFrameSummary:
    """Per-validator summary of attestation performance within a single report frame."""

    # Accumulated attestation duties (assigned vs. included) for the frame.
    attestation_duty: AttestationsAccumulator = field(default_factory=AttestationsAccumulator)
    # Whether the validator was observed slashed during the frame.
    slashed: bool = False


Expand All @@ -35,13 +35,14 @@ class FramePerfLog:
default_factory=lambda: defaultdict(OperatorFrameSummary)
)

@staticmethod
def encode(logs: list['FramePerfLog']) -> bytes:
    """Serialize a list of frame performance logs to compact JSON bytes.

    Compact separators and sorted keys make the output deterministic, so the
    same logs always serialize to the same byte sequence.
    """
    return (
        LogJSONEncoder(
            indent=None,
            separators=(',', ':'),
            sort_keys=True,
        )
        .encode([asdict(log) for log in logs])  # dataclasses -> JSON string
        .encode()  # str -> bytes (default UTF-8)
    )
129 changes: 89 additions & 40 deletions src/modules/csm/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import pickle
from collections import defaultdict
from dataclasses import dataclass
from functools import lru_cache
from itertools import batched
from pathlib import Path
from typing import Self

Expand Down Expand Up @@ -33,6 +35,10 @@ def add_duty(self, included: bool) -> None:
self.included += 1 if included else 0


# Inclusive epoch range [from_epoch, to_epoch] covered by one report frame.
type Frame = tuple[EpochNumber, EpochNumber]
# Per-frame duty accumulators, keyed by frame, then by validator index.
type StateData = dict[Frame, defaultdict[ValidatorIndex, AttestationsAccumulator]]


class State:
"""
Processing state of a CSM performance oracle frame.
Expand All @@ -43,16 +49,16 @@ class State:

The state can be migrated to be used for another frame's report by calling the `migrate` method.
"""

data: defaultdict[ValidatorIndex, AttestationsAccumulator]
frames: list[Frame]
data: StateData

_epochs_to_process: tuple[EpochNumber, ...]
_processed_epochs: set[EpochNumber]

_consensus_version: int = 1

def __init__(self, data: dict[ValidatorIndex, AttestationsAccumulator] | None = None) -> None:
self.data = defaultdict(AttestationsAccumulator, data or {})
def __init__(self) -> None:
self.data = {}
self._epochs_to_process = tuple()
self._processed_epochs = set()

Expand Down Expand Up @@ -89,22 +95,55 @@ def file(cls) -> Path:
def buffer(self) -> Path:
return self.file().with_suffix(".buf")

@property
def is_empty(self) -> bool:
    """True when no duty data, target epochs, or processed epochs are stored."""
    return not (self.data or self._epochs_to_process or self._processed_epochs)

@property
def unprocessed_epochs(self) -> set[EpochNumber]:
    """Epochs from the configured range that have not been processed yet.

    Raises ValueError if the target epochs have not been set via `migrate`.
    """
    if not self._epochs_to_process:
        raise ValueError("Epochs to process are not set")
    return set(self._epochs_to_process).difference(self._processed_epochs)

@property
def is_fulfilled(self) -> bool:
    """True once every target epoch has been processed."""
    return len(self.unprocessed_epochs) == 0

@staticmethod
def _calculate_frames(epochs_to_process: tuple[EpochNumber, ...], epochs_per_frame: int) -> list[Frame]:
    """Split the sorted epochs into consecutive (first, last) frame ranges.

    Raises ValueError when the epochs do not divide evenly into frames.
    """
    if len(epochs_to_process) % epochs_per_frame != 0:
        raise ValueError("Insufficient epochs to form a frame")
    ordered = sorted(epochs_to_process)
    return [
        (ordered[start], ordered[start + epochs_per_frame - 1])
        for start in range(0, len(ordered), epochs_per_frame)
    ]

def clear(self) -> None:
    """Reset the state to an empty, unconfigured condition."""
    self.data = {}
    self._epochs_to_process = tuple()
    # In-place clear (not reassignment) — any external references to the set
    # observe the mutation.
    self._processed_epochs.clear()
    # NOTE(review): `assert` is stripped under `python -O`; sanity check only.
    assert self.is_empty

def inc(self, key: ValidatorIndex, included: bool) -> None:
self.data[key].add_duty(included)
# NOTE(review): `lru_cache` on an instance method keys on `self` and keeps the
# instance alive for the cache's lifetime (ruff B019). `migrate` calls
# `find_frame.cache_clear()`, so the cache is reset on reconfiguration —
# presumably a single long-lived State instance is expected; confirm.
@lru_cache(variables.CSM_ORACLE_MAX_CONCURRENCY)
def find_frame(self, epoch: EpochNumber) -> Frame:
    """Return the frame (inclusive epoch range) containing `epoch`.

    Raises ValueError if `epoch` falls outside every configured frame.
    """
    for epoch_range in self.frames:
        from_epoch, to_epoch = epoch_range
        if from_epoch <= epoch <= to_epoch:
            return epoch_range
    raise ValueError(f"Epoch {epoch} is out of frames range: {self.frames}")

def increment_duty(self, epoch: EpochNumber, val_index: ValidatorIndex, included: bool) -> None:
    """Account one attestation duty of `val_index` in the frame covering `epoch`."""
    self.data[self.find_frame(epoch)][val_index].add_duty(included)

def add_processed_epoch(self, epoch: EpochNumber) -> None:
    """Mark `epoch` as processed; it no longer appears in `unprocessed_epochs`."""
    self._processed_epochs.add(epoch)

def log_progress(self) -> None:
    """Log how many of the target epochs have been processed so far."""
    done = len(self._processed_epochs)
    total = len(self._epochs_to_process)
    logger.info({"msg": f"Processed {done} of {total} epochs"})

def migrate(self, l_epoch: EpochNumber, r_epoch: EpochNumber, consensus_version: int):
def migrate(
self, l_epoch: EpochNumber, r_epoch: EpochNumber, epochs_per_frame: int, consensus_version: int
) -> None:
if consensus_version != self._consensus_version:
logger.warning(
{
Expand All @@ -114,17 +153,46 @@ def migrate(self, l_epoch: EpochNumber, r_epoch: EpochNumber, consensus_version:
)
self.clear()

for state_epochs in (self._epochs_to_process, self._processed_epochs):
for epoch in state_epochs:
if epoch < l_epoch or epoch > r_epoch:
logger.warning({"msg": "Discarding invalidated state cache"})
self.clear()
break
frames = self._calculate_frames(tuple(sequence(l_epoch, r_epoch)), epochs_per_frame)

if not self.is_empty:
cached_frames = self.frames
if cached_frames == frames:
logger.info({"msg": "No need to migrate duties data cache"})
return
self._migrate_frames_data(frames)
else:
self.data = {frame: defaultdict(AttestationsAccumulator) for frame in frames}

self.frames = frames
self._epochs_to_process = tuple(sequence(l_epoch, r_epoch))
self._consensus_version = consensus_version
self.find_frame.cache_clear()
self.commit()

def _migrate_frames_data(self, new_frames: list[Frame]):
    """Carry accumulated duty data over from the current frame layout to `new_frames`.

    An old frame's accumulators are merged into the first new frame that fully
    contains it. Old frames not contained by any new frame are discarded, and
    their epochs are removed from the processed set so they get re-processed.
    """
    logger.info({"msg": f"Migrating duties data cache: {self.frames=} -> {new_frames=}"})
    new_data: StateData = {frame: defaultdict(AttestationsAccumulator) for frame in new_frames}

    def contains(outer: Frame, inner: Frame) -> bool:
        # Full containment of `inner` within `outer` — partial overlaps do NOT
        # qualify and such frames are invalidated below. (Previously named
        # `overlaps`, which misstated the check.)
        return outer[0] <= inner[0] and outer[1] >= inner[1]

    consumed: set[Frame] = set()
    for new_frame in new_frames:
        for frame_to_consume in self.frames:
            if contains(new_frame, frame_to_consume):
                # New frames are disjoint, so an old frame is consumed at most once.
                assert frame_to_consume not in consumed
                consumed.add(frame_to_consume)
                for val, duty in self.data[frame_to_consume].items():
                    new_data[new_frame][val].assigned += duty.assigned
                    new_data[new_frame][val].included += duty.included
    for frame in self.frames:
        if frame in consumed:
            continue
        logger.warning({"msg": f"Invalidating frame duties data cache: {frame}"})
        # Force re-processing of the invalidated frame's epochs.
        self._processed_epochs -= set(sequence(*frame))
    self.data = new_data

def validate(self, l_epoch: EpochNumber, r_epoch: EpochNumber) -> None:
if not self.is_fulfilled:
raise InvalidState(f"State is not fulfilled. {self.unprocessed_epochs=}")
Expand All @@ -135,34 +203,15 @@ def validate(self, l_epoch: EpochNumber, r_epoch: EpochNumber) -> None:

for epoch in sequence(l_epoch, r_epoch):
if epoch not in self._processed_epochs:
raise InvalidState(f"Epoch {epoch} should be processed")

@property
def is_empty(self) -> bool:
return not self.data and not self._epochs_to_process and not self._processed_epochs

@property
def unprocessed_epochs(self) -> set[EpochNumber]:
if not self._epochs_to_process:
raise ValueError("Epochs to process are not set")
diff = set(self._epochs_to_process) - self._processed_epochs
return diff

@property
def is_fulfilled(self) -> bool:
return not self.unprocessed_epochs

@property
def frame(self) -> tuple[EpochNumber, EpochNumber]:
if not self._epochs_to_process:
raise ValueError("Epochs to process are not set")
return min(self._epochs_to_process), max(self._epochs_to_process)

def get_network_aggr(self) -> AttestationsAccumulator:
"""Return `AttestationsAccumulator` over duties of all the network validators"""
raise InvalidState(f"Epoch {epoch} missing in processed epochs")

def get_network_aggr(self, frame: Frame) -> AttestationsAccumulator:
# TODO: exclude `active_slashed` validators from the calculation
included = assigned = 0
for validator, acc in self.data.items():
frame_data = self.data.get(frame)
if frame_data is None:
raise ValueError(f"No data for frame {frame} to calculate network aggregate")
for validator, acc in frame_data.items():
if acc.included > acc.assigned:
raise ValueError(f"Invalid accumulator: {validator=}, {acc=}")
included += acc.included
Expand Down
4 changes: 2 additions & 2 deletions src/providers/execution/contracts/cs_fee_distributor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from eth_typing import ChecksumAddress
from hexbytes import HexBytes
from web3 import Web3
from web3.types import BlockIdentifier
from web3.types import BlockIdentifier, Wei

from ..base_interface import ContractInterface

Expand All @@ -26,7 +26,7 @@ def oracle(self, block_identifier: BlockIdentifier = "latest") -> ChecksumAddres
)
return Web3.to_checksum_address(resp)

def shares_to_distribute(self, block_identifier: BlockIdentifier = "latest") -> int:
def shares_to_distribute(self, block_identifier: BlockIdentifier = "latest") -> Wei:
"""Returns the amount of shares that are pending to be distributed"""

resp = self.functions.pendingSharesToDistribute().call(block_identifier=block_identifier)
Expand Down
14 changes: 7 additions & 7 deletions tests/modules/csm/test_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ def test_checkpoints_processor_no_eip7549_support(
monkeypatch: pytest.MonkeyPatch,
):
state = State()
state.migrate(EpochNumber(0), EpochNumber(255), 1)
state.migrate(EpochNumber(0), EpochNumber(255), 256, 1)
processor = FrameCheckpointProcessor(
consensus_client,
state,
Expand Down Expand Up @@ -354,7 +354,7 @@ def test_checkpoints_processor_check_duty(
converter,
):
state = State()
state.migrate(0, 255, 1)
state.migrate(0, 255, 256, 1)
finalized_blockstamp = ...
processor = FrameCheckpointProcessor(
consensus_client,
Expand All @@ -367,7 +367,7 @@ def test_checkpoints_processor_check_duty(
assert len(state._processed_epochs) == 1
assert len(state._epochs_to_process) == 256
assert len(state.unprocessed_epochs) == 255
assert len(state.data) == 2048 * 32
assert len(state.data[(0, 255)]) == 2048 * 32


def test_checkpoints_processor_process(
Expand All @@ -379,7 +379,7 @@ def test_checkpoints_processor_process(
converter,
):
state = State()
state.migrate(0, 255, 1)
state.migrate(0, 255, 256, 1)
finalized_blockstamp = ...
processor = FrameCheckpointProcessor(
consensus_client,
Expand All @@ -392,7 +392,7 @@ def test_checkpoints_processor_process(
assert len(state._processed_epochs) == 2
assert len(state._epochs_to_process) == 256
assert len(state.unprocessed_epochs) == 254
assert len(state.data) == 2048 * 32
assert len(state.data[(0, 255)]) == 2048 * 32


def test_checkpoints_processor_exec(
Expand All @@ -404,7 +404,7 @@ def test_checkpoints_processor_exec(
converter,
):
state = State()
state.migrate(0, 255, 1)
state.migrate(0, 255, 256, 1)
finalized_blockstamp = ...
processor = FrameCheckpointProcessor(
consensus_client,
Expand All @@ -418,4 +418,4 @@ def test_checkpoints_processor_exec(
assert len(state._processed_epochs) == 2
assert len(state._epochs_to_process) == 256
assert len(state.unprocessed_epochs) == 254
assert len(state.data) == 2048 * 32
assert len(state.data[(0, 255)]) == 2048 * 32
Loading