
Commit e5f601e

Merge pull request #557 from lidofinance/feat/csm/state-data-as-tuples
[CSM] feat: proper missing frames handling
2 parents 24a2219 + 15d5ab2 commit e5f601e

10 files changed: +1041 −464 lines


src/modules/csm/checkpoint.py (+3 −3)
@@ -143,6 +143,7 @@ def exec(self, checkpoint: FrameCheckpoint) -> int:
             for duty_epoch in unprocessed_epochs
         }
         self._process(unprocessed_epochs, duty_epochs_roots)
+        self.state.commit()
         return len(unprocessed_epochs)

     def _get_block_roots(self, checkpoint_slot: SlotNumber):
@@ -204,18 +205,17 @@ def _check_duty(
         for root in block_roots:
             attestations = self.cc.get_block_attestations(root)
             process_attestations(attestations, committees, self.eip7549_supported)
-
         with lock:
             for committee in committees.values():
                 for validator_duty in committee:
-                    self.state.inc(
+                    self.state.increment_duty(
+                        duty_epoch,
                         validator_duty.index,
                         included=validator_duty.included,
                     )
             if duty_epoch not in self.state.unprocessed_epochs:
                 raise ValueError(f"Epoch {duty_epoch} is not in epochs that should be processed")
             self.state.add_processed_epoch(duty_epoch)
-            self.state.commit()
             self.state.log_progress()
             unprocessed_epochs = self.state.unprocessed_epochs
             CSM_UNPROCESSED_EPOCHS_COUNT.set(len(unprocessed_epochs))
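The net effect: `state.commit()` moves out of the per-epoch lock in `_check_duty` and runs once at the end of `exec`, and every duty is now attributed with its `duty_epoch` via the new `State.increment_duty`. A minimal sketch of that call order, with hypothetical helper names (`duties_for_epoch`) and simplified signatures rather than the oracle's real ones:

# Sketch only: simplified post-change call order; helper names are hypothetical.
from typing import Callable, Iterable

def process_checkpoint(
    state,                                                           # a State-like object
    duty_epochs: Iterable[int],
    duties_for_epoch: Callable[[int], Iterable[tuple[int, bool]]],   # hypothetical helper
) -> int:
    processed = 0
    for duty_epoch in duty_epochs:
        for validator_index, included in duties_for_epoch(duty_epoch):
            # The epoch goes first so the state can resolve the owning frame bucket.
            state.increment_duty(duty_epoch, validator_index, included=included)
        state.add_processed_epoch(duty_epoch)
        processed += 1
    state.commit()  # a single commit per checkpoint instead of one per epoch
    return processed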

src/modules/csm/csm.py (+162 −94)
Large diffs are not rendered by default.

src/modules/csm/log.py (+4 −3)
@@ -12,7 +12,7 @@ class LogJSONEncoder(json.JSONEncoder): ...

 @dataclass
 class ValidatorFrameSummary:
-    perf: AttestationsAccumulator = field(default_factory=AttestationsAccumulator)
+    attestation_duty: AttestationsAccumulator = field(default_factory=AttestationsAccumulator)
     slashed: bool = False


@@ -35,13 +35,14 @@ class FramePerfLog:
         default_factory=lambda: defaultdict(OperatorFrameSummary)
     )

-    def encode(self) -> bytes:
+    @staticmethod
+    def encode(logs: list['FramePerfLog']) -> bytes:
         return (
             LogJSONEncoder(
                 indent=None,
                 separators=(',', ':'),
                 sort_keys=True,
             )
-            .encode(asdict(self))
+            .encode([asdict(log) for log in logs])
             .encode()
         )
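With `encode` now a static method over `list['FramePerfLog']`, a report can serialize one log per frame into a single JSON array. The toy below is illustrative only (a stand-in dataclass, not `FramePerfLog`), but it follows the same pattern: `asdict` each log, sorted keys, compact separators, then encode the string to bytes.

# Illustrative only: a stand-in dataclass encoded the way FramePerfLog.encode(logs) encodes.
import json
from dataclasses import dataclass, asdict

@dataclass
class ToyLog:
    frame: tuple[int, int]
    distributed: int = 0

logs = [ToyLog(frame=(0, 255)), ToyLog(frame=(256, 511), distributed=42)]
raw: bytes = (
    json.JSONEncoder(indent=None, separators=(',', ':'), sort_keys=True)
    .encode([asdict(log) for log in logs])
    .encode()
)
print(raw)  # b'[{"distributed":0,"frame":[0,255]},{"distributed":42,"frame":[256,511]}]'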

src/modules/csm/state.py (+85 −40)
@@ -3,6 +3,8 @@
 import pickle
 from collections import defaultdict
 from dataclasses import dataclass
+from functools import lru_cache
+from itertools import batched
 from pathlib import Path
 from typing import Self

@@ -33,6 +35,10 @@ def add_duty(self, included: bool) -> None:
         self.included += 1 if included else 0


+type Frame = tuple[EpochNumber, EpochNumber]
+type StateData = dict[Frame, defaultdict[ValidatorIndex, AttestationsAccumulator]]
+
+
 class State:
     """
     Processing state of a CSM performance oracle frame.
@@ -43,16 +49,17 @@ class State:

     The state can be migrated to be used for another frame's report by calling the `migrate` method.
     """
-
-    data: defaultdict[ValidatorIndex, AttestationsAccumulator]
+    frames: list[Frame]
+    data: StateData

     _epochs_to_process: tuple[EpochNumber, ...]
     _processed_epochs: set[EpochNumber]

     _consensus_version: int = 1

-    def __init__(self, data: dict[ValidatorIndex, AttestationsAccumulator] | None = None) -> None:
-        self.data = defaultdict(AttestationsAccumulator, data or {})
+    def __init__(self) -> None:
+        self.frames = []
+        self.data = {}
         self._epochs_to_process = tuple()
         self._processed_epochs = set()

@@ -89,22 +96,55 @@ def file(cls) -> Path:
     def buffer(self) -> Path:
         return self.file().with_suffix(".buf")

+    @property
+    def is_empty(self) -> bool:
+        return not self.data and not self._epochs_to_process and not self._processed_epochs
+
+    @property
+    def unprocessed_epochs(self) -> set[EpochNumber]:
+        if not self._epochs_to_process:
+            raise ValueError("Epochs to process are not set")
+        diff = set(self._epochs_to_process) - self._processed_epochs
+        return diff
+
+    @property
+    def is_fulfilled(self) -> bool:
+        return not self.unprocessed_epochs
+
+    @staticmethod
+    def _calculate_frames(epochs_to_process: tuple[EpochNumber, ...], epochs_per_frame: int) -> list[Frame]:
+        """Split epochs to process into frames of `epochs_per_frame` length"""
+        if len(epochs_to_process) % epochs_per_frame != 0:
+            raise ValueError("Insufficient epochs to form a frame")
+        return [(frame[0], frame[-1]) for frame in batched(sorted(epochs_to_process), epochs_per_frame)]
+
     def clear(self) -> None:
-        self.data = defaultdict(AttestationsAccumulator)
+        self.data = {}
         self._epochs_to_process = tuple()
         self._processed_epochs.clear()
         assert self.is_empty

-    def inc(self, key: ValidatorIndex, included: bool) -> None:
-        self.data[key].add_duty(included)
+    @lru_cache(variables.CSM_ORACLE_MAX_CONCURRENCY)
+    def find_frame(self, epoch: EpochNumber) -> Frame:
+        for epoch_range in self.frames:
+            from_epoch, to_epoch = epoch_range
+            if from_epoch <= epoch <= to_epoch:
+                return epoch_range
+        raise ValueError(f"Epoch {epoch} is out of frames range: {self.frames}")
+
+    def increment_duty(self, epoch: EpochNumber, val_index: ValidatorIndex, included: bool) -> None:
+        frame = self.find_frame(epoch)
+        self.data[frame][val_index].add_duty(included)

     def add_processed_epoch(self, epoch: EpochNumber) -> None:
         self._processed_epochs.add(epoch)

     def log_progress(self) -> None:
         logger.info({"msg": f"Processed {len(self._processed_epochs)} of {len(self._epochs_to_process)} epochs"})

-    def migrate(self, l_epoch: EpochNumber, r_epoch: EpochNumber, consensus_version: int):
+    def migrate(
+        self, l_epoch: EpochNumber, r_epoch: EpochNumber, epochs_per_frame: int, consensus_version: int
+    ) -> None:
         if consensus_version != self._consensus_version:
             logger.warning(
                 {
@@ -114,17 +154,41 @@ def migrate(self, l_epoch: EpochNumber, r_epoch: EpochNumber, consensus_version:
             )
             self.clear()

-        for state_epochs in (self._epochs_to_process, self._processed_epochs):
-            for epoch in state_epochs:
-                if epoch < l_epoch or epoch > r_epoch:
-                    logger.warning({"msg": "Discarding invalidated state cache"})
-                    self.clear()
-                    break
+        new_frames = self._calculate_frames(tuple(sequence(l_epoch, r_epoch)), epochs_per_frame)
+        if self.frames == new_frames:
+            logger.info({"msg": "No need to migrate duties data cache"})
+            return
+        self._migrate_frames_data(new_frames)

+        self.frames = new_frames
+        self.find_frame.cache_clear()
         self._epochs_to_process = tuple(sequence(l_epoch, r_epoch))
         self._consensus_version = consensus_version
         self.commit()

+    def _migrate_frames_data(self, new_frames: list[Frame]):
+        logger.info({"msg": f"Migrating duties data cache: {self.frames=} -> {new_frames=}"})
+        new_data: StateData = {frame: defaultdict(AttestationsAccumulator) for frame in new_frames}
+
+        def overlaps(a: Frame, b: Frame):
+            return a[0] <= b[0] and a[1] >= b[1]
+
+        consumed = []
+        for new_frame in new_frames:
+            for frame_to_consume in self.frames:
+                if overlaps(new_frame, frame_to_consume):
+                    assert frame_to_consume not in consumed
+                    consumed.append(frame_to_consume)
+                    for val, duty in self.data[frame_to_consume].items():
+                        new_data[new_frame][val].assigned += duty.assigned
+                        new_data[new_frame][val].included += duty.included
+        for frame in self.frames:
+            if frame in consumed:
+                continue
+            logger.warning({"msg": f"Invalidating frame duties data cache: {frame}"})
+            self._processed_epochs -= set(sequence(*frame))
+        self.data = new_data
+
     def validate(self, l_epoch: EpochNumber, r_epoch: EpochNumber) -> None:
         if not self.is_fulfilled:
             raise InvalidState(f"State is not fulfilled. {self.unprocessed_epochs=}")
@@ -135,34 +199,15 @@ def validate(self, l_epoch: EpochNumber, r_epoch: EpochNumber) -> None:

         for epoch in sequence(l_epoch, r_epoch):
             if epoch not in self._processed_epochs:
-                raise InvalidState(f"Epoch {epoch} should be processed")
-
-    @property
-    def is_empty(self) -> bool:
-        return not self.data and not self._epochs_to_process and not self._processed_epochs
-
-    @property
-    def unprocessed_epochs(self) -> set[EpochNumber]:
-        if not self._epochs_to_process:
-            raise ValueError("Epochs to process are not set")
-        diff = set(self._epochs_to_process) - self._processed_epochs
-        return diff
-
-    @property
-    def is_fulfilled(self) -> bool:
-        return not self.unprocessed_epochs
-
-    @property
-    def frame(self) -> tuple[EpochNumber, EpochNumber]:
-        if not self._epochs_to_process:
-            raise ValueError("Epochs to process are not set")
-        return min(self._epochs_to_process), max(self._epochs_to_process)
-
-    def get_network_aggr(self) -> AttestationsAccumulator:
-        """Return `AttestationsAccumulator` over duties of all the network validators"""
+                raise InvalidState(f"Epoch {epoch} missing in processed epochs")

+    def get_network_aggr(self, frame: Frame) -> AttestationsAccumulator:
+        # TODO: exclude `active_slashed` validators from the calculation
         included = assigned = 0
-        for validator, acc in self.data.items():
+        frame_data = self.data.get(frame)
+        if frame_data is None:
+            raise ValueError(f"No data for frame {frame} to calculate network aggregate")
+        for validator, acc in frame_data.items():
             if acc.included > acc.assigned:
                 raise ValueError(f"Invalid accumulator: {validator=}, {acc=}")
             included += acc.included
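`State` now buckets accumulators per frame: `_calculate_frames` splits the epoch range into `(l_epoch, r_epoch)` tuples with `itertools.batched`, `find_frame` maps an epoch to its tuple, and `increment_duty` updates the accumulator inside that bucket. A standalone sketch of the bucketing, using plain ints instead of `EpochNumber`/`ValidatorIndex` and skipping the lru_cache and pickle persistence (assumes Python 3.12 for `batched`):

# Standalone sketch of the frame bucketing (simplified; no caching, no persistence).
from collections import defaultdict
from dataclasses import dataclass
from itertools import batched  # Python 3.12+

Frame = tuple[int, int]

@dataclass
class Acc:
    assigned: int = 0
    included: int = 0

def calculate_frames(epochs: tuple[int, ...], epochs_per_frame: int) -> list[Frame]:
    if len(epochs) % epochs_per_frame != 0:
        raise ValueError("Insufficient epochs to form a frame")
    return [(chunk[0], chunk[-1]) for chunk in batched(sorted(epochs), epochs_per_frame)]

frames = calculate_frames(tuple(range(512)), 256)         # [(0, 255), (256, 511)]
data: dict[Frame, defaultdict[int, Acc]] = {f: defaultdict(Acc) for f in frames}

def increment_duty(epoch: int, val_index: int, included: bool) -> None:
    frame = next(f for f in frames if f[0] <= epoch <= f[1])  # find_frame equivalent
    acc = data[frame][val_index]
    acc.assigned += 1
    acc.included += 1 if included else 0

increment_duty(300, val_index=7, included=True)
assert data[(256, 511)][7].assigned == 1

Migration then reduces to merging old frames that fit wholly inside a new frame and invalidating the rest (and un-marking their processed epochs), which is what `_migrate_frames_data` does.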

src/providers/execution/contracts/cs_fee_distributor.py (+2 −2)
@@ -3,7 +3,7 @@
 from eth_typing import ChecksumAddress
 from hexbytes import HexBytes
 from web3 import Web3
-from web3.types import BlockIdentifier
+from web3.types import BlockIdentifier, Wei

 from ..base_interface import ContractInterface

@@ -26,7 +26,7 @@ def oracle(self, block_identifier: BlockIdentifier = "latest") -> ChecksumAddres
         )
         return Web3.to_checksum_address(resp)

-    def shares_to_distribute(self, block_identifier: BlockIdentifier = "latest") -> int:
+    def shares_to_distribute(self, block_identifier: BlockIdentifier = "latest") -> Wei:
         """Returns the amount of shares that are pending to be distributed"""

         resp = self.functions.pendingSharesToDistribute().call(block_identifier=block_identifier)
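The only functional content here is the annotation: `Wei` from `web3.types` is a typing `NewType` over `int`, so callers still receive a plain integer at runtime and the hint merely documents the unit. A tiny illustration (the value is made up):

# Wei is typing-only: a NewType over int from web3.types; illustrative value below.
from web3.types import Wei

pending: Wei = Wei(10**18)
assert isinstance(pending, int)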

tests/modules/csm/test_checkpoint.py (+7 −7)
@@ -326,7 +326,7 @@ def test_checkpoints_processor_no_eip7549_support(
     monkeypatch: pytest.MonkeyPatch,
 ):
     state = State()
-    state.migrate(EpochNumber(0), EpochNumber(255), 1)
+    state.migrate(EpochNumber(0), EpochNumber(255), 256, 1)
     processor = FrameCheckpointProcessor(
         consensus_client,
         state,
@@ -354,7 +354,7 @@ def test_checkpoints_processor_check_duty(
     converter,
 ):
     state = State()
-    state.migrate(0, 255, 1)
+    state.migrate(0, 255, 256, 1)
     finalized_blockstamp = ...
     processor = FrameCheckpointProcessor(
         consensus_client,
@@ -367,7 +367,7 @@ def test_checkpoints_processor_check_duty(
     assert len(state._processed_epochs) == 1
     assert len(state._epochs_to_process) == 256
     assert len(state.unprocessed_epochs) == 255
-    assert len(state.data) == 2048 * 32
+    assert len(state.data[(0, 255)]) == 2048 * 32


 def test_checkpoints_processor_process(
@@ -379,7 +379,7 @@
     converter,
 ):
     state = State()
-    state.migrate(0, 255, 1)
+    state.migrate(0, 255, 256, 1)
     finalized_blockstamp = ...
     processor = FrameCheckpointProcessor(
         consensus_client,
@@ -392,7 +392,7 @@ def test_checkpoints_processor_process(
     assert len(state._processed_epochs) == 2
     assert len(state._epochs_to_process) == 256
     assert len(state.unprocessed_epochs) == 254
-    assert len(state.data) == 2048 * 32
+    assert len(state.data[(0, 255)]) == 2048 * 32


 def test_checkpoints_processor_exec(
@@ -404,7 +404,7 @@
     converter,
 ):
     state = State()
-    state.migrate(0, 255, 1)
+    state.migrate(0, 255, 256, 1)
     finalized_blockstamp = ...
     processor = FrameCheckpointProcessor(
         consensus_client,
@@ -418,4 +418,4 @@
     assert len(state._processed_epochs) == 2
     assert len(state._epochs_to_process) == 256
     assert len(state.unprocessed_epochs) == 254
-    assert len(state.data) == 2048 * 32
+    assert len(state.data[(0, 255)]) == 2048 * 32
