Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CSM] feat: proper missing frames handling #557

Merged
merged 20 commits into from
Feb 19, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 2 additions & 7 deletions src/modules/csm/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,19 +49,18 @@ class State:

The state can be migrated to be used for another frame's report by calling the `migrate` method.
"""
frames: list[Frame]
data: StateData

_epochs_to_process: tuple[EpochNumber, ...]
_processed_epochs: set[EpochNumber]
_epochs_per_frame: int

_consensus_version: int = 1

def __init__(self) -> None:
self.data = {}
self._epochs_to_process = tuple()
self._processed_epochs = set()
self._epochs_per_frame = 0

EXTENSION = ".pkl"

Expand Down Expand Up @@ -111,10 +110,6 @@ def unprocessed_epochs(self) -> set[EpochNumber]:
def is_fulfilled(self) -> bool:
return not self.unprocessed_epochs

@property
def frames(self):
return self._calculate_frames(self._epochs_to_process, self._epochs_per_frame)

@staticmethod
def _calculate_frames(epochs_to_process: tuple[EpochNumber, ...], epochs_per_frame: int) -> list[Frame]:
"""Split epochs to process into frames of `epochs_per_frame` length"""
Expand Down Expand Up @@ -169,7 +164,7 @@ def migrate(
else:
self.data = {frame: defaultdict(AttestationsAccumulator) for frame in frames}

self._epochs_per_frame = epochs_per_frame
self.frames = frames
self._epochs_to_process = tuple(sequence(l_epoch, r_epoch))
self._consensus_version = consensus_version
self.find_frame.cache_clear()
Expand Down
52 changes: 19 additions & 33 deletions tests/modules/csm/test_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,26 +132,23 @@ def test_clear_resets_state_to_empty():

def test_find_frame_returns_correct_frame():
state = State()
state._epochs_to_process = tuple(sequence(0, 31))
state._epochs_per_frame = 32
state.frames = [(0, 31)]
state.data = {(0, 31): defaultdict(AttestationsAccumulator)}
assert state.find_frame(15) == (0, 31)


def test_find_frame_raises_error_for_out_of_range_epoch():
state = State()
state._epochs_to_process = tuple(sequence(0, 31))
state._epochs_per_frame = 32
state.frames = [(0, 31)]
state.data = {(0, 31): defaultdict(AttestationsAccumulator)}
with pytest.raises(ValueError, match="Epoch 32 is out of frames range"):
state.find_frame(32)


def test_increment_duty_adds_duty_correctly():
state = State()
state._epochs_to_process = tuple(sequence(0, 31))
state._epochs_per_frame = 32
frame = (0, 31)
state.frames = [frame]
duty_epoch, _ = frame
state.data = {
frame: defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(10, 5)}),
Expand All @@ -163,9 +160,8 @@ def test_increment_duty_adds_duty_correctly():

def test_increment_duty_creates_new_validator_entry():
state = State()
state._epochs_to_process = tuple(sequence(0, 31))
state._epochs_per_frame = 32
frame = (0, 31)
state.frames = [frame]
duty_epoch, _ = frame
state.data = {
frame: defaultdict(AttestationsAccumulator),
Expand All @@ -177,8 +173,8 @@ def test_increment_duty_creates_new_validator_entry():

def test_increment_duty_handles_non_included_duty():
state = State()
state._epochs_to_process = tuple(sequence(0, 31))
state._epochs_per_frame = 32
frame = (0, 31)
state.frames = [frame]
frame = (0, 31)
duty_epoch, _ = frame
state.data = {
Expand All @@ -191,10 +187,10 @@ def test_increment_duty_handles_non_included_duty():

def test_increment_duty_raises_error_for_out_of_range_epoch():
state = State()
state._epochs_to_process = tuple(sequence(0, 31))
state._epochs_per_frame = 32
frame = (0, 31)
state.frames = [frame]
state.data = {
(0, 31): defaultdict(AttestationsAccumulator),
frame: defaultdict(AttestationsAccumulator),
}
with pytest.raises(ValueError, match="is out of frames range"):
state.increment_duty(32, ValidatorIndex(1), True)
Expand Down Expand Up @@ -226,8 +222,7 @@ def test_init_or_migrate_discards_data_on_version_change():
def test_init_or_migrate_no_migration_needed():
state = State()
state._consensus_version = 1
state._epochs_to_process = tuple(sequence(0, 63))
state._epochs_per_frame = 32
state.frames = [(0, 31), (32, 63)]
state.data = {
(0, 31): defaultdict(AttestationsAccumulator),
(32, 63): defaultdict(AttestationsAccumulator),
Expand All @@ -240,8 +235,7 @@ def test_init_or_migrate_no_migration_needed():
def test_init_or_migrate_migrates_data():
state = State()
state._consensus_version = 1
state._epochs_to_process = tuple(sequence(0, 63))
state._epochs_per_frame = 32
state.frames = [(0, 31), (32, 63)]
state.data = {
(0, 31): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(10, 5)}),
(32, 63): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(20, 15)}),
Expand All @@ -257,8 +251,7 @@ def test_init_or_migrate_migrates_data():
def test_init_or_migrate_invalidates_unmigrated_frames():
state = State()
state._consensus_version = 1
state._epochs_to_process = tuple(sequence(0, 63))
state._epochs_per_frame = 64
state.frames = [(0, 63)]
state.data = {
(0, 63): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(30, 20)}),
}
Expand All @@ -274,8 +267,7 @@ def test_init_or_migrate_invalidates_unmigrated_frames():
def test_init_or_migrate_discards_unmigrated_frame():
state = State()
state._consensus_version = 1
state._epochs_to_process = tuple(sequence(0, 95))
state._epochs_per_frame = 32
state.frames = [(0, 31), (32, 63), (64, 95)]
state.data = {
(0, 31): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(10, 5)}),
(32, 63): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(20, 15)}),
Expand All @@ -294,8 +286,7 @@ def test_init_or_migrate_discards_unmigrated_frame():

def test_migrate_frames_data_creates_new_data_correctly():
state = State()
state._epochs_to_process = tuple(sequence(0, 63))
state._epochs_per_frame = 32
state.frames = [(0, 31), (32, 63)]
new_frames = [(0, 63)]
state.data = {
(0, 31): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(10, 5)}),
Expand All @@ -309,8 +300,7 @@ def test_migrate_frames_data_creates_new_data_correctly():

def test_migrate_frames_data_handles_no_migration():
state = State()
state._epochs_to_process = tuple(sequence(0, 31))
state._epochs_per_frame = 32
state.frames = [(0, 31)]
new_frames = [(0, 31)]
state.data = {
(0, 31): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(10, 5)}),
Expand All @@ -323,8 +313,7 @@ def test_migrate_frames_data_handles_no_migration():

def test_migrate_frames_data_handles_partial_migration():
state = State()
state._epochs_to_process = tuple(sequence(0, 63))
state._epochs_per_frame = 32
state.frames = [(0, 31), (32, 63)]
new_frames = [(0, 31), (32, 95)]
state.data = {
(0, 31): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(10, 5)}),
Expand All @@ -339,19 +328,16 @@ def test_migrate_frames_data_handles_partial_migration():

def test_migrate_frames_data_handles_no_data():
state = State()
state._epochs_to_process = tuple(sequence(0, 31))
state._epochs_per_frame = 32
current_frames = [(0, 31)]
state.frames = [(0, 31)]
new_frames = [(0, 31)]
state.data = {frame: defaultdict(AttestationsAccumulator) for frame in current_frames}
state.data = {frame: defaultdict(AttestationsAccumulator) for frame in state.frames}
state._migrate_frames_data(new_frames)
assert state.data == {(0, 31): defaultdict(AttestationsAccumulator)}


def test_migrate_frames_data_handles_wider_old_frame():
state = State()
state._epochs_to_process = tuple(sequence(0, 63))
state._epochs_per_frame = 64
state.frames = [(0, 63)]
new_frames = [(0, 31), (32, 63)]
state.data = {
(0, 63): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(30, 20)}),
Expand Down