Skip to content

Commit 1aa7228

Browse files
committedFeb 13, 2025··
refactor: State and tests

File tree

3 files changed

+445
-286
lines changed

3 files changed

+445
-286
lines changed
 

‎src/modules/csm/checkpoint.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -205,12 +205,12 @@ def _check_duty(
205205
for root in block_roots:
206206
attestations = self.cc.get_block_attestations(root)
207207
process_attestations(attestations, committees, self.eip7549_supported)
208-
208+
frame = self.state.find_frame(duty_epoch)
209209
with lock:
210210
for committee in committees.values():
211211
for validator_duty in committee:
212212
self.state.increment_duty(
213-
duty_epoch,
213+
frame,
214214
validator_duty.index,
215215
included=validator_duty.included,
216216
)

‎src/modules/csm/state.py

+67-68
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,6 @@
1313

1414
logger = logging.getLogger(__name__)
1515

16-
type Frame = tuple[EpochNumber, EpochNumber]
17-
1816

1917
class InvalidState(ValueError):
2018
"""State has data considered as invalid for a report"""
@@ -36,6 +34,10 @@ def add_duty(self, included: bool) -> None:
3634
self.included += 1 if included else 0
3735

3836

37+
type Frame = tuple[EpochNumber, EpochNumber]
38+
type StateData = dict[Frame, defaultdict[ValidatorIndex, AttestationsAccumulator]]
39+
40+
3941
class State:
4042
"""
4143
Processing state of a CSM performance oracle frame.
@@ -46,18 +48,16 @@ class State:
4648
4749
The state can be migrated to be used for another frame's report by calling the `migrate` method.
4850
"""
49-
data: dict[Frame, defaultdict[ValidatorIndex, AttestationsAccumulator]]
51+
data: StateData
5052

5153
_epochs_to_process: tuple[EpochNumber, ...]
5254
_processed_epochs: set[EpochNumber]
5355
_epochs_per_frame: int
5456

5557
_consensus_version: int = 1
5658

57-
def __init__(self, data: dict[Frame, dict[ValidatorIndex, AttestationsAccumulator]] | None = None) -> None:
58-
self.data = {
59-
frame: defaultdict(AttestationsAccumulator, validators) for frame, validators in (data or {}).items()
60-
}
59+
def __init__(self) -> None:
60+
self.data = {}
6161
self._epochs_to_process = tuple()
6262
self._processed_epochs = set()
6363
self._epochs_per_frame = 0
@@ -110,6 +110,16 @@ def unprocessed_epochs(self) -> set[EpochNumber]:
110110
def is_fulfilled(self) -> bool:
111111
return not self.unprocessed_epochs
112112

113+
@staticmethod
114+
def calculate_frames(epochs_to_process: tuple[EpochNumber, ...], epochs_per_frame: int) -> list[Frame]:
115+
"""Split epochs to process into frames of `epochs_per_frame` length"""
116+
frames = []
117+
for frame_epochs in batched(epochs_to_process, epochs_per_frame):
118+
if len(frame_epochs) < epochs_per_frame:
119+
raise ValueError("Insufficient epochs to form a frame")
120+
frames.append((frame_epochs[0], frame_epochs[-1]))
121+
return frames
122+
113123
def clear(self) -> None:
114124
self.data = {}
115125
self._epochs_to_process = tuple()
@@ -123,17 +133,20 @@ def find_frame(self, epoch: EpochNumber) -> Frame:
123133
return epoch_range
124134
raise ValueError(f"Epoch {epoch} is out of frames range: {frames}")
125135

126-
def increment_duty(self, epoch: EpochNumber, val_index: ValidatorIndex, included: bool) -> None:
127-
epoch_range = self.find_frame(epoch)
128-
self.data[epoch_range][val_index].add_duty(included)
136+
def increment_duty(self, frame: Frame, val_index: ValidatorIndex, included: bool) -> None:
137+
if frame not in self.data:
138+
raise ValueError(f"Frame {frame} is not found in the state")
139+
self.data[frame][val_index].add_duty(included)
129140

130141
def add_processed_epoch(self, epoch: EpochNumber) -> None:
131142
self._processed_epochs.add(epoch)
132143

133144
def log_progress(self) -> None:
134145
logger.info({"msg": f"Processed {len(self._processed_epochs)} of {len(self._epochs_to_process)} epochs"})
135146

136-
def init_or_migrate(self, l_epoch: EpochNumber, r_epoch: EpochNumber, epochs_per_frame: int, consensus_version: int) -> None:
147+
def init_or_migrate(
148+
self, l_epoch: EpochNumber, r_epoch: EpochNumber, epochs_per_frame: int, consensus_version: int
149+
) -> None:
137150
if consensus_version != self._consensus_version:
138151
logger.warning(
139152
{
@@ -143,59 +156,55 @@ def init_or_migrate(self, l_epoch: EpochNumber, r_epoch: EpochNumber, epochs_per
143156
)
144157
self.clear()
145158

159+
frames = self.calculate_frames(tuple(sequence(l_epoch, r_epoch)), epochs_per_frame)
160+
frames_data: StateData = {frame: defaultdict(AttestationsAccumulator) for frame in frames}
161+
146162
if not self.is_empty:
147-
invalidated = self._migrate_or_invalidate(l_epoch, r_epoch, epochs_per_frame)
148-
if invalidated:
149-
self.clear()
163+
cached_frames = self.calculate_frames(self._epochs_to_process, self._epochs_per_frame)
164+
if cached_frames == frames:
165+
logger.info({"msg": "No need to migrate duties data cache"})
166+
return
167+
168+
frames_data, migration_status = self._migrate_frames_data(cached_frames, frames)
169+
170+
for current_frame, migrated in migration_status.items():
171+
if not migrated:
172+
logger.warning({"msg": f"Invalidating frame duties data cache: {current_frame}"})
173+
for epoch in sequence(*current_frame):
174+
self._processed_epochs.discard(epoch)
150175

151-
self._fill_frames(l_epoch, r_epoch, epochs_per_frame)
176+
self.data = frames_data
152177
self._epochs_per_frame = epochs_per_frame
153178
self._epochs_to_process = tuple(sequence(l_epoch, r_epoch))
154179
self._consensus_version = consensus_version
155180
self.commit()
156181

157-
def _fill_frames(self, l_epoch: EpochNumber, r_epoch: EpochNumber, epochs_per_frame: int) -> None:
158-
frames = self.calculate_frames(tuple(sequence(l_epoch, r_epoch)), epochs_per_frame)
159-
for frame in frames:
160-
self.data.setdefault(frame, defaultdict(AttestationsAccumulator))
161-
162-
def _migrate_or_invalidate(self, l_epoch: EpochNumber, r_epoch: EpochNumber, epochs_per_frame: int) -> bool:
163-
current_frames = self.calculate_frames(self._epochs_to_process, self._epochs_per_frame)
164-
new_frames = self.calculate_frames(tuple(sequence(l_epoch, r_epoch)), epochs_per_frame)
165-
inv_msg = f"Discarding invalid state cache because of frames change. {current_frames=}, {new_frames=}"
166-
167-
if self._invalidate_on_epoch_range_change(l_epoch, r_epoch):
168-
logger.warning({"msg": inv_msg})
169-
return True
170-
171-
frame_expanded = epochs_per_frame > self._epochs_per_frame
172-
frame_shrunk = epochs_per_frame < self._epochs_per_frame
173-
174-
has_single_frame = len(current_frames) == len(new_frames) == 1
175-
176-
if has_single_frame and frame_expanded:
177-
current_frame, *_ = current_frames
178-
new_frame, *_ = new_frames
179-
self.data[new_frame] = self.data.pop(current_frame)
180-
logger.info({"msg": f"Migrated state cache to a new frame. {current_frame=}, {new_frame=}"})
181-
return False
182-
183-
if has_single_frame and frame_shrunk:
184-
logger.warning({"msg": inv_msg})
185-
return True
186-
187-
if not has_single_frame and frame_expanded or frame_shrunk:
188-
logger.warning({"msg": inv_msg})
189-
return True
190-
191-
return False
192-
193-
def _invalidate_on_epoch_range_change(self, l_epoch: EpochNumber, r_epoch: EpochNumber) -> bool:
194-
"""Check if the epoch range has been invalidated."""
195-
for epoch_set in (self._epochs_to_process, self._processed_epochs):
196-
if any(epoch < l_epoch or epoch > r_epoch for epoch in epoch_set):
197-
return True
198-
return False
182+
def _migrate_frames_data(
183+
self, current_frames: list[Frame], new_frames: list[Frame]
184+
) -> tuple[StateData, dict[Frame, bool]]:
185+
migration_status = {frame: False for frame in current_frames}
186+
new_data: StateData = {frame: defaultdict(AttestationsAccumulator) for frame in new_frames}
187+
188+
logger.info({"msg": f"Migrating duties data cache: {current_frames=} -> {new_frames=}"})
189+
190+
for current_frame in current_frames:
191+
curr_frame_l_epoch, curr_frame_r_epoch = current_frame
192+
for new_frame in new_frames:
193+
if current_frame == new_frame:
194+
new_data[new_frame] = self.data[current_frame]
195+
migration_status[current_frame] = True
196+
break
197+
198+
new_frame_l_epoch, new_frame_r_epoch = new_frame
199+
if curr_frame_l_epoch >= new_frame_l_epoch and curr_frame_r_epoch <= new_frame_r_epoch:
200+
logger.info({"msg": f"Migrating frame duties data cache: {current_frame=} -> {new_frame=}"})
201+
for val in self.data[current_frame]:
202+
new_data[new_frame][val].assigned += self.data[current_frame][val].assigned
203+
new_data[new_frame][val].included += self.data[current_frame][val].included
204+
migration_status[current_frame] = True
205+
break
206+
207+
return new_data, migration_status
199208

200209
def validate(self, l_epoch: EpochNumber, r_epoch: EpochNumber) -> None:
201210
if not self.is_fulfilled:
@@ -209,21 +218,11 @@ def validate(self, l_epoch: EpochNumber, r_epoch: EpochNumber) -> None:
209218
if epoch not in self._processed_epochs:
210219
raise InvalidState(f"Epoch {epoch} missing in processed epochs")
211220

212-
@staticmethod
213-
def calculate_frames(epochs_to_process: tuple[EpochNumber, ...], epochs_per_frame: int) -> list[Frame]:
214-
"""Split epochs to process into frames of `epochs_per_frame` length"""
215-
frames = []
216-
for frame_epochs in batched(epochs_to_process, epochs_per_frame):
217-
if len(frame_epochs) < epochs_per_frame:
218-
raise ValueError("Insufficient epochs to form a frame")
219-
frames.append((frame_epochs[0], frame_epochs[-1]))
220-
return frames
221-
222221
def get_network_aggr(self, frame: Frame) -> AttestationsAccumulator:
223222
# TODO: exclude `active_slashed` validators from the calculation
224223
included = assigned = 0
225224
frame_data = self.data.get(frame)
226-
if not frame_data:
225+
if frame_data is None:
227226
raise ValueError(f"No data for frame {frame} to calculate network aggregate")
228227
for validator, acc in frame_data.items():
229228
if acc.included > acc.assigned:

‎tests/modules/csm/test_state.py

+376-216
Original file line numberDiff line numberDiff line change
@@ -1,258 +1,418 @@
1+
import os
2+
import pickle
3+
from collections import defaultdict
14
from pathlib import Path
25
from unittest.mock import Mock
36

47
import pytest
58

6-
from src.modules.csm.state import AttestationsAccumulator, State
7-
from src.types import EpochNumber, ValidatorIndex
9+
from src import variables
10+
from src.modules.csm.state import AttestationsAccumulator, State, InvalidState
11+
from src.types import ValidatorIndex
812
from src.utils.range import sequence
913

1014

11-
@pytest.fixture()
12-
def state_file_path(tmp_path: Path) -> Path:
13-
return (tmp_path / "mock").with_suffix(State.EXTENSION)
15+
@pytest.fixture(autouse=True)
16+
def remove_state_files():
17+
state_file = Path("/tmp/state.pkl")
18+
state_buf = Path("/tmp/state.buf")
19+
state_file.unlink(missing_ok=True)
20+
state_buf.unlink(missing_ok=True)
21+
yield
22+
state_file.unlink(missing_ok=True)
23+
state_buf.unlink(missing_ok=True)
24+
25+
26+
def test_load_restores_state_from_file(monkeypatch):
27+
monkeypatch.setattr("src.modules.csm.state.State.file", lambda _=None: Path("/tmp/state.pkl"))
28+
state = State()
29+
state.data = {
30+
(0, 31): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(10, 5)}),
31+
}
32+
state.commit()
33+
loaded_state = State.load()
34+
assert loaded_state.data == state.data
1435

1536

16-
@pytest.fixture(autouse=True)
17-
def mock_state_file(state_file_path: Path):
18-
State.file = Mock(return_value=state_file_path)
37+
def test_load_returns_new_instance_if_file_not_found(monkeypatch):
38+
monkeypatch.setattr("src.modules.csm.state.State.file", lambda: Path("/non/existent/path"))
39+
state = State.load()
40+
assert state.is_empty
1941

2042

21-
def test_attestation_aggregate_perf():
22-
aggr = AttestationsAccumulator(included=333, assigned=777)
23-
assert aggr.perf == pytest.approx(0.4285, abs=1e-4)
43+
def test_load_returns_new_instance_if_empty_object(monkeypatch, tmp_path):
44+
with open('/tmp/state.pkl', "wb") as f:
45+
pickle.dump(None, f)
46+
monkeypatch.setattr("src.modules.csm.state.State.file", lambda: Path("/tmp/state.pkl"))
47+
state = State.load()
48+
assert state.is_empty
49+
50+
51+
def test_commit_saves_state_to_file(monkeypatch):
52+
state = State()
53+
state.data = {
54+
(0, 31): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(10, 5)}),
55+
}
56+
monkeypatch.setattr("src.modules.csm.state.State.file", lambda _: Path("/tmp/state.pkl"))
57+
monkeypatch.setattr("os.replace", Mock(side_effect=os.replace))
58+
state.commit()
59+
with open("/tmp/state.pkl", "rb") as f:
60+
loaded_state = pickle.load(f)
61+
assert loaded_state.data == state.data
62+
os.replace.assert_called_once_with(Path("/tmp/state.buf"), Path("/tmp/state.pkl"))
63+
64+
65+
def test_file_returns_correct_path(monkeypatch):
66+
monkeypatch.setattr(variables, "CACHE_PATH", Path("/tmp"))
67+
assert State.file() == Path("/tmp/cache.pkl")
68+
69+
70+
def test_buffer_returns_correct_path(monkeypatch):
71+
monkeypatch.setattr(variables, "CACHE_PATH", Path("/tmp"))
72+
state = State()
73+
assert state.buffer == Path("/tmp/cache.buf")
74+
75+
76+
def test_is_empty_returns_true_for_empty_state():
77+
state = State()
78+
assert state.is_empty
79+
80+
81+
def test_is_empty_returns_false_for_non_empty_state():
82+
state = State()
83+
state.data = {(0, 31): defaultdict(AttestationsAccumulator)}
84+
assert not state.is_empty
85+
86+
87+
def test_unprocessed_epochs_raises_error_if_epochs_not_set():
88+
state = State()
89+
with pytest.raises(ValueError, match="Epochs to process are not set"):
90+
state.unprocessed_epochs
91+
92+
93+
def test_unprocessed_epochs_returns_correct_set():
94+
state = State()
95+
state._epochs_to_process = tuple(sequence(0, 95))
96+
state._processed_epochs = set(sequence(0, 63))
97+
assert state.unprocessed_epochs == set(sequence(64, 95))
98+
99+
100+
def test_is_fulfilled_returns_true_if_no_unprocessed_epochs():
101+
state = State()
102+
state._epochs_to_process = tuple(sequence(0, 95))
103+
state._processed_epochs = set(sequence(0, 95))
104+
assert state.is_fulfilled
105+
106+
107+
def test_is_fulfilled_returns_false_if_unprocessed_epochs_exist():
108+
state = State()
109+
state._epochs_to_process = tuple(sequence(0, 95))
110+
state._processed_epochs = set(sequence(0, 63))
111+
assert not state.is_fulfilled
112+
113+
114+
def test_calculate_frames_handles_exact_frame_size():
115+
epochs = tuple(range(10))
116+
frames = State.calculate_frames(epochs, 5)
117+
assert frames == [(0, 4), (5, 9)]
118+
119+
120+
def test_calculate_frames_raises_error_for_insufficient_epochs():
121+
epochs = tuple(range(8))
122+
with pytest.raises(ValueError, match="Insufficient epochs to form a frame"):
123+
State.calculate_frames(epochs, 5)
124+
125+
126+
def test_clear_resets_state_to_empty():
127+
state = State()
128+
state.data = {(0, 31): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(10, 5)})}
129+
state.clear()
130+
assert state.is_empty
131+
132+
133+
def test_find_frame_returns_correct_frame():
134+
state = State()
135+
state.data = {(0, 31): defaultdict(AttestationsAccumulator)}
136+
assert state.find_frame(15) == (0, 31)
24137

25138

26-
def test_state_avg_perf():
139+
def test_find_frame_raises_error_for_out_of_range_epoch():
27140
state = State()
141+
state.data = {(0, 31): defaultdict(AttestationsAccumulator)}
142+
with pytest.raises(ValueError, match="Epoch 32 is out of frames range"):
143+
state.find_frame(32)
28144

29-
frame = (0, 999)
30145

31-
with pytest.raises(ValueError):
32-
state.get_network_aggr(frame)
146+
def test_increment_duty_adds_duty_correctly():
147+
state = State()
148+
frame = (0, 31)
149+
state.data = {
150+
frame: defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(10, 5)}),
151+
}
152+
state.increment_duty(frame, ValidatorIndex(1), True)
153+
assert state.data[frame][ValidatorIndex(1)].assigned == 11
154+
assert state.data[frame][ValidatorIndex(1)].included == 6
33155

156+
157+
def test_increment_duty_creates_new_validator_entry():
34158
state = State()
35-
state.init_or_migrate(*frame, 1000, 1)
159+
frame = (0, 31)
36160
state.data = {
37-
frame: {
38-
ValidatorIndex(0): AttestationsAccumulator(included=0, assigned=0),
39-
ValidatorIndex(1): AttestationsAccumulator(included=0, assigned=0),
40-
}
161+
frame: defaultdict(AttestationsAccumulator),
41162
}
163+
state.increment_duty(frame, ValidatorIndex(2), True)
164+
assert state.data[frame][ValidatorIndex(2)].assigned == 1
165+
assert state.data[frame][ValidatorIndex(2)].included == 1
42166

43-
assert state.get_network_aggr(frame).perf == 0
44167

168+
def test_increment_duty_handles_non_included_duty():
169+
state = State()
170+
frame = (0, 31)
45171
state.data = {
46-
frame: {
47-
ValidatorIndex(0): AttestationsAccumulator(included=333, assigned=777),
48-
ValidatorIndex(1): AttestationsAccumulator(included=167, assigned=223),
49-
}
172+
frame: defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(10, 5)}),
50173
}
174+
state.increment_duty(frame, ValidatorIndex(1), False)
175+
assert state.data[frame][ValidatorIndex(1)].assigned == 11
176+
assert state.data[frame][ValidatorIndex(1)].included == 5
51177

52-
assert state.get_network_aggr(frame).perf == 0.5
53178

179+
def test_increment_duty_raises_error_for_out_of_range_epoch():
180+
state = State()
181+
state.data = {
182+
(0, 31): defaultdict(AttestationsAccumulator),
183+
}
184+
with pytest.raises(ValueError, match="is not found in the state"):
185+
state.increment_duty((0, 32), ValidatorIndex(1), True)
54186

55-
def test_state_attestations():
56-
state = State(
57-
{
58-
(0, 999): {
59-
ValidatorIndex(0): AttestationsAccumulator(included=333, assigned=777),
60-
ValidatorIndex(1): AttestationsAccumulator(included=167, assigned=223),
61-
}
62-
}
63-
)
64187

65-
network_aggr = state.get_network_aggr((0, 999))
188+
def test_add_processed_epoch_adds_epoch_to_processed_set():
189+
state = State()
190+
state.add_processed_epoch(5)
191+
assert 5 in state._processed_epochs
66192

67-
assert network_aggr.assigned == 1000
68-
assert network_aggr.included == 500
69193

194+
def test_add_processed_epoch_does_not_duplicate_epochs():
195+
state = State()
196+
state.add_processed_epoch(5)
197+
state.add_processed_epoch(5)
198+
assert len(state._processed_epochs) == 1
70199

71-
def test_state_load():
72-
orig = State(
73-
{
74-
(0, 999): {
75-
ValidatorIndex(0): AttestationsAccumulator(included=333, assigned=777),
76-
ValidatorIndex(1): AttestationsAccumulator(included=167, assigned=223),
77-
}
78-
}
79-
)
80200

81-
orig.commit()
82-
copy = State.load()
83-
assert copy.data == orig.data
201+
def test_init_or_migrate_discards_data_on_version_change():
202+
state = State()
203+
state._consensus_version = 1
204+
state.clear = Mock()
205+
state.commit = Mock()
206+
state.init_or_migrate(0, 63, 32, 2)
207+
state.clear.assert_called_once()
208+
state.commit.assert_called_once()
84209

85210

86-
def test_state_clear():
87-
state = State(
88-
{
89-
(0, 999): {
90-
ValidatorIndex(0): AttestationsAccumulator(included=333, assigned=777),
91-
ValidatorIndex(1): AttestationsAccumulator(included=167, assigned=223),
92-
}
93-
}
94-
)
211+
def test_init_or_migrate_no_migration_needed():
212+
state = State()
213+
state._consensus_version = 1
214+
state._epochs_to_process = tuple(sequence(0, 63))
215+
state._epochs_per_frame = 32
216+
state.data = {
217+
(0, 31): defaultdict(AttestationsAccumulator),
218+
(32, 63): defaultdict(AttestationsAccumulator),
219+
}
220+
state.commit = Mock()
221+
state.init_or_migrate(0, 63, 32, 1)
222+
state.commit.assert_not_called()
95223

96-
state._epochs_to_process = (EpochNumber(1), EpochNumber(33))
97-
state._processed_epochs = {EpochNumber(42), EpochNumber(17)}
98224

99-
state.clear()
100-
assert state.is_empty
101-
assert not state.data
225+
def test_init_or_migrate_migrates_data():
226+
state = State()
227+
state._consensus_version = 1
228+
state._epochs_to_process = tuple(sequence(0, 63))
229+
state._epochs_per_frame = 32
230+
state.data = {
231+
(0, 31): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(10, 5)}),
232+
(32, 63): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(20, 15)}),
233+
}
234+
state.commit = Mock()
235+
state.init_or_migrate(0, 63, 64, 1)
236+
assert state.data == {
237+
(0, 63): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(30, 20)}),
238+
}
239+
state.commit.assert_called_once()
240+
241+
242+
def test_init_or_migrate_invalidates_unmigrated_frames():
243+
state = State()
244+
state._consensus_version = 1
245+
state._epochs_to_process = tuple(sequence(0, 63))
246+
state._epochs_per_frame = 64
247+
state.data = {
248+
(0, 63): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(30, 20)}),
249+
}
250+
state.commit = Mock()
251+
state.init_or_migrate(0, 31, 32, 1)
252+
assert state.data == {
253+
(0, 31): defaultdict(AttestationsAccumulator),
254+
}
255+
assert state._processed_epochs == set()
256+
state.commit.assert_called_once()
257+
258+
259+
def test_init_or_migrate_discards_unmigrated_frame():
260+
state = State()
261+
state._consensus_version = 1
262+
state._epochs_to_process = tuple(sequence(0, 95))
263+
state._epochs_per_frame = 32
264+
state.data = {
265+
(0, 31): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(10, 5)}),
266+
(32, 63): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(20, 15)}),
267+
(64, 95): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(30, 25)}),
268+
}
269+
state._processed_epochs = set(sequence(0, 95))
270+
state.commit = Mock()
271+
state.init_or_migrate(0, 63, 32, 1)
272+
assert state.data == {
273+
(0, 31): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(10, 5)}),
274+
(32, 63): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(20, 15)}),
275+
}
276+
assert state._processed_epochs == set(sequence(0, 63))
277+
state.commit.assert_called_once()
278+
279+
280+
def test_migrate_frames_data_creates_new_data_correctly():
281+
state = State()
282+
current_frames = [(0, 31), (32, 63)]
283+
new_frames = [(0, 63)]
284+
state.data = {
285+
(0, 31): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(10, 5)}),
286+
(32, 63): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(20, 15)}),
287+
}
288+
new_data, migration_status = state._migrate_frames_data(current_frames, new_frames)
289+
assert new_data == {
290+
(0, 63): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(30, 20)})
291+
}
292+
assert migration_status == {(0, 31): True, (32, 63): True}
293+
294+
295+
def test_migrate_frames_data_handles_no_migration():
296+
state = State()
297+
current_frames = [(0, 31)]
298+
new_frames = [(0, 31)]
299+
state.data = {
300+
(0, 31): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(10, 5)}),
301+
}
302+
new_data, migration_status = state._migrate_frames_data(current_frames, new_frames)
303+
assert new_data == {
304+
(0, 31): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(10, 5)})
305+
}
306+
assert migration_status == {(0, 31): True}
307+
308+
309+
def test_migrate_frames_data_handles_partial_migration():
310+
state = State()
311+
current_frames = [(0, 31), (32, 63)]
312+
new_frames = [(0, 31), (32, 95)]
313+
state.data = {
314+
(0, 31): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(10, 5)}),
315+
(32, 63): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(20, 15)}),
316+
}
317+
new_data, migration_status = state._migrate_frames_data(current_frames, new_frames)
318+
assert new_data == {
319+
(0, 31): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(10, 5)}),
320+
(32, 95): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(20, 15)}),
321+
}
322+
assert migration_status == {(0, 31): True, (32, 63): True}
323+
324+
325+
def test_migrate_frames_data_handles_no_data():
326+
state = State()
327+
current_frames = [(0, 31)]
328+
new_frames = [(0, 31)]
329+
state.data = {frame: defaultdict(AttestationsAccumulator) for frame in current_frames}
330+
new_data, migration_status = state._migrate_frames_data(current_frames, new_frames)
331+
assert new_data == {(0, 31): defaultdict(AttestationsAccumulator)}
332+
assert migration_status == {(0, 31): True}
333+
334+
335+
def test_migrate_frames_data_handles_wider_old_frame():
336+
state = State()
337+
current_frames = [(0, 63)]
338+
new_frames = [(0, 31), (32, 63)]
339+
state.data = {
340+
(0, 63): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(30, 20)}),
341+
}
342+
new_data, migration_status = state._migrate_frames_data(current_frames, new_frames)
343+
assert new_data == {
344+
(0, 31): defaultdict(AttestationsAccumulator),
345+
(32, 63): defaultdict(AttestationsAccumulator),
346+
}
347+
assert migration_status == {(0, 63): False}
348+
349+
350+
def test_validate_raises_error_if_state_not_fulfilled():
351+
state = State()
352+
state._epochs_to_process = tuple(sequence(0, 95))
353+
state._processed_epochs = set(sequence(0, 94))
354+
with pytest.raises(InvalidState, match="State is not fulfilled"):
355+
state.validate(0, 95)
356+
357+
358+
def test_validate_raises_error_if_processed_epoch_out_of_range():
359+
state = State()
360+
state._epochs_to_process = tuple(sequence(0, 95))
361+
state._processed_epochs = set(sequence(0, 95))
362+
state._processed_epochs.add(96)
363+
with pytest.raises(InvalidState, match="Processed epoch 96 is out of range"):
364+
state.validate(0, 95)
365+
366+
367+
def test_validate_raises_error_if_epoch_missing_in_processed_epochs():
368+
state = State()
369+
state._epochs_to_process = tuple(sequence(0, 94))
370+
state._processed_epochs = set(sequence(0, 94))
371+
with pytest.raises(InvalidState, match="Epoch 95 missing in processed epochs"):
372+
state.validate(0, 95)
102373

103374

104-
def test_state_add_processed_epoch():
375+
def test_validate_passes_for_fulfilled_state():
105376
state = State()
106-
state.add_processed_epoch(EpochNumber(42))
107-
state.add_processed_epoch(EpochNumber(17))
108-
assert state._processed_epochs == {EpochNumber(42), EpochNumber(17)}
377+
state._epochs_to_process = tuple(sequence(0, 95))
378+
state._processed_epochs = set(sequence(0, 95))
379+
state.validate(0, 95)
109380

110381

111-
def test_state_inc():
112-
113-
frame_0 = (0, 999)
114-
frame_1 = (1000, 1999)
115-
116-
state = State(
117-
{
118-
frame_0: {
119-
ValidatorIndex(0): AttestationsAccumulator(included=333, assigned=777),
120-
ValidatorIndex(1): AttestationsAccumulator(included=167, assigned=223),
121-
},
122-
frame_1: {
123-
ValidatorIndex(0): AttestationsAccumulator(included=1, assigned=1),
124-
ValidatorIndex(1): AttestationsAccumulator(included=0, assigned=1),
125-
},
126-
}
127-
)
128-
129-
state.increment_duty(999, ValidatorIndex(0), True)
130-
state.increment_duty(999, ValidatorIndex(0), False)
131-
state.increment_duty(999, ValidatorIndex(1), True)
132-
state.increment_duty(999, ValidatorIndex(1), True)
133-
state.increment_duty(999, ValidatorIndex(1), False)
134-
state.increment_duty(999, ValidatorIndex(2), True)
135-
136-
state.increment_duty(1000, ValidatorIndex(2), False)
137-
138-
assert tuple(state.data[frame_0].values()) == (
139-
AttestationsAccumulator(included=334, assigned=779),
140-
AttestationsAccumulator(included=169, assigned=226),
141-
AttestationsAccumulator(included=1, assigned=1),
142-
)
143-
144-
assert tuple(state.data[frame_1].values()) == (
145-
AttestationsAccumulator(included=1, assigned=1),
146-
AttestationsAccumulator(included=0, assigned=1),
147-
AttestationsAccumulator(included=0, assigned=1),
148-
)
149-
150-
151-
def test_state_file_is_path():
152-
assert isinstance(State.file(), Path)
153-
154-
155-
class TestStateTransition:
156-
"""Tests for State's transition for different l_epoch, r_epoch values"""
157-
158-
@pytest.fixture(autouse=True)
159-
def no_commit(self, monkeypatch: pytest.MonkeyPatch):
160-
monkeypatch.setattr(State, "commit", Mock())
161-
162-
def test_empty_to_new_frame(self):
163-
state = State()
164-
assert state.is_empty
165-
166-
l_epoch = EpochNumber(1)
167-
r_epoch = EpochNumber(255)
168-
169-
state.init_or_migrate(l_epoch, r_epoch, 255, 1)
170-
171-
assert not state.is_empty
172-
assert state.unprocessed_epochs == set(sequence(l_epoch, r_epoch))
173-
174-
@pytest.mark.parametrize(
175-
("l_epoch_old", "r_epoch_old", "l_epoch_new", "r_epoch_new"),
176-
[
177-
pytest.param(1, 255, 256, 510, id="Migrate a..bA..B"),
178-
pytest.param(1, 255, 32, 510, id="Migrate a..A..b..B"),
179-
pytest.param(32, 510, 1, 255, id="Migrate: A..a..B..b"),
180-
],
181-
)
182-
def test_new_frame_requires_discarding_state(self, l_epoch_old, r_epoch_old, l_epoch_new, r_epoch_new):
183-
state = State()
184-
state.clear = Mock(side_effect=state.clear)
185-
state.init_or_migrate(l_epoch_old, r_epoch_old, r_epoch_old - l_epoch_old + 1, 1)
186-
state.clear.assert_not_called()
187-
188-
state.init_or_migrate(l_epoch_new, r_epoch_new, r_epoch_new - l_epoch_new + 1, 1)
189-
state.clear.assert_called_once()
190-
191-
assert state.unprocessed_epochs == set(sequence(l_epoch_new, r_epoch_new))
192-
193-
@pytest.mark.parametrize(
194-
("l_epoch_old", "r_epoch_old", "l_epoch_new", "r_epoch_new", "epochs_per_frame"),
195-
[
196-
pytest.param(1, 255, 1, 510, 255, id="Migrate Aa..b..B"),
197-
],
198-
)
199-
def test_new_frame_extends_old_state(self, l_epoch_old, r_epoch_old, l_epoch_new, r_epoch_new, epochs_per_frame):
200-
state = State()
201-
state.clear = Mock(side_effect=state.clear)
202-
203-
state.init_or_migrate(l_epoch_old, r_epoch_old, epochs_per_frame, 1)
204-
state.clear.assert_not_called()
205-
206-
state.init_or_migrate(l_epoch_new, r_epoch_new, epochs_per_frame, 1)
207-
state.clear.assert_not_called()
208-
209-
assert state.unprocessed_epochs == set(sequence(l_epoch_new, r_epoch_new))
210-
assert len(state.data) == 2
211-
assert list(state.data.keys()) == [(l_epoch_old, r_epoch_old), (r_epoch_old + 1, r_epoch_new)]
212-
assert state.calculate_frames(state._epochs_to_process, epochs_per_frame) == [
213-
(l_epoch_old, r_epoch_old),
214-
(r_epoch_old + 1, r_epoch_new),
215-
]
216-
217-
@pytest.mark.parametrize(
218-
("l_epoch_old", "r_epoch_old", "epochs_per_frame_old", "l_epoch_new", "r_epoch_new", "epochs_per_frame_new"),
219-
[
220-
pytest.param(32, 510, 479, 1, 510, 510, id="Migrate: A..a..b..B"),
221-
],
222-
)
223-
def test_new_frame_extends_old_state_with_single_frame(
224-
self, l_epoch_old, r_epoch_old, epochs_per_frame_old, l_epoch_new, r_epoch_new, epochs_per_frame_new
225-
):
226-
state = State()
227-
state.clear = Mock(side_effect=state.clear)
228-
229-
state.init_or_migrate(l_epoch_old, r_epoch_old, epochs_per_frame_old, 1)
230-
state.clear.assert_not_called()
231-
232-
state.init_or_migrate(l_epoch_new, r_epoch_new, epochs_per_frame_new, 1)
233-
state.clear.assert_not_called()
234-
235-
assert state.unprocessed_epochs == set(sequence(l_epoch_new, r_epoch_new))
236-
assert len(state.data) == 1
237-
assert list(state.data.keys())[0] == (l_epoch_new, r_epoch_new)
238-
assert state.calculate_frames(state._epochs_to_process, epochs_per_frame_new) == [(l_epoch_new, r_epoch_new)]
239-
240-
@pytest.mark.parametrize(
241-
("old_version", "new_version"),
242-
[
243-
pytest.param(2, 3, id="Increase consensus version"),
244-
pytest.param(3, 2, id="Decrease consensus version"),
245-
],
246-
)
247-
def test_consensus_version_change(self, old_version, new_version):
248-
state = State()
249-
state.clear = Mock(side_effect=state.clear)
250-
state._consensus_version = old_version
251-
252-
l_epoch = r_epoch = EpochNumber(255)
253-
254-
state.init_or_migrate(l_epoch, r_epoch, 1, old_version)
255-
state.clear.assert_not_called()
256-
257-
state.init_or_migrate(l_epoch, r_epoch, 1, new_version)
258-
state.clear.assert_called_once()
382+
def test_attestation_aggregate_perf():
383+
aggr = AttestationsAccumulator(included=333, assigned=777)
384+
assert aggr.perf == pytest.approx(0.4285, abs=1e-4)
385+
386+
387+
def test_get_network_aggr_computes_correctly():
388+
state = State()
389+
state.data = {
390+
(0, 31): defaultdict(
391+
AttestationsAccumulator,
392+
{ValidatorIndex(1): AttestationsAccumulator(10, 5), ValidatorIndex(2): AttestationsAccumulator(20, 15)},
393+
)
394+
}
395+
aggr = state.get_network_aggr((0, 31))
396+
assert aggr.assigned == 30
397+
assert aggr.included == 20
398+
399+
400+
def test_get_network_aggr_raises_error_for_invalid_accumulator():
401+
state = State()
402+
state.data = {(0, 31): defaultdict(AttestationsAccumulator, {ValidatorIndex(1): AttestationsAccumulator(10, 15)})}
403+
with pytest.raises(ValueError, match="Invalid accumulator"):
404+
state.get_network_aggr((0, 31))
405+
406+
407+
def test_get_network_aggr_raises_error_for_missing_frame_data():
408+
state = State()
409+
with pytest.raises(ValueError, match="No data for frame"):
410+
state.get_network_aggr((0, 31))
411+
412+
413+
def test_get_network_aggr_handles_empty_frame_data():
414+
state = State()
415+
state.data = {(0, 31): defaultdict(AttestationsAccumulator)}
416+
aggr = state.get_network_aggr((0, 31))
417+
assert aggr.assigned == 0
418+
assert aggr.included == 0

0 commit comments

Comments
 (0)
Please sign in to comment.