Skip to content

Commit 7a7b32f

Browse files
committed
feat: state data is list of sequences now
1 parent 54b5f3a commit 7a7b32f

10 files changed

+455
-192
lines changed

src/modules/csm/checkpoint.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,7 @@
2020
lock = Lock()
2121

2222

23-
class MinStepIsNotReached(Exception):
24-
...
23+
class MinStepIsNotReached(Exception): ...
2524

2625

2726
@dataclass
@@ -33,7 +32,7 @@ class FrameCheckpoint:
3332
@dataclass
3433
class ValidatorDuty:
3534
index: ValidatorIndex
36-
included: bool
35+
is_included: bool
3736

3837

3938
class FrameCheckpointsIterator:
@@ -198,9 +197,10 @@ def _check_duty(
198197
with lock:
199198
for committee in committees.values():
200199
for validator_duty in committee:
201-
self.state.inc(
200+
self.state.set_duty_status(
201+
duty_epoch,
202202
validator_duty.index,
203-
included=validator_duty.included,
203+
validator_duty.is_included,
204204
)
205205
if duty_epoch not in self.state.unprocessed_epochs:
206206
raise ValueError(f"Epoch {duty_epoch} is not in epochs that should be processed")
@@ -222,7 +222,7 @@ def _prepare_committees(self, epoch: EpochNumber) -> Committees:
222222
validators = []
223223
# Order of insertion is used to track the positions in the committees.
224224
for validator in committee.validators:
225-
validators.append(ValidatorDuty(index=ValidatorIndex(int(validator)), included=False))
225+
validators.append(ValidatorDuty(index=ValidatorIndex(int(validator)), is_included=False))
226226
committees[(committee.slot, committee.index)] = validators
227227
return committees
228228

@@ -233,7 +233,7 @@ def process_attestations(attestations: Iterable[BlockAttestation], committees: C
233233
committee = committees.get(committee_id, [])
234234
att_bits = _to_bits(attestation.aggregation_bits)
235235
for index_in_committee, validator_duty in enumerate(committee):
236-
validator_duty.included = validator_duty.included or _is_attested(att_bits, index_in_committee)
236+
validator_duty.is_included = validator_duty.is_included or _is_attested(att_bits, index_in_committee)
237237

238238

239239
def _is_attested(bits: Sequence[bool], index: int) -> bool:

src/modules/csm/csm.py

+58-20
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@
1212
)
1313
from src.metrics.prometheus.duration_meter import duration_meter
1414
from src.modules.csm.checkpoint import FrameCheckpointProcessor, FrameCheckpointsIterator, MinStepIsNotReached
15-
from src.modules.csm.log import FramePerfLog
15+
from src.modules.csm.duties.attestation import calc_performance
16+
from src.modules.csm.log import FramePerfLog, AttestationsAccumulatorLog
1617
from src.modules.csm.state import State
1718
from src.modules.csm.tree import Tree
1819
from src.modules.csm.types import ReportData, Shares
@@ -33,7 +34,7 @@
3334
)
3435
from src.utils.blockstamp import build_blockstamp
3536
from src.utils.cache import global_lru_cache as lru_cache
36-
from src.utils.slot import get_next_non_missed_slot
37+
from src.utils.slot import get_next_non_missed_slot, get_reference_blockstamp
3738
from src.utils.web3converter import Web3Converter
3839
from src.web3py.extensions.lido_validators import NodeOperatorId, StakingModule, ValidatorsByNodeOperator
3940
from src.web3py.types import Web3
@@ -102,12 +103,12 @@ def build_report(self, blockstamp: ReferenceBlockStamp) -> tuple:
102103
if (prev_cid is None) != (prev_root == ZERO_HASH):
103104
raise InconsistentData(f"Got inconsistent previous tree data: {prev_root=} {prev_cid=}")
104105

105-
distributed, shares, log = self.calculate_distribution(blockstamp)
106+
distributed, shares, logs = self.calculate_distribution(blockstamp)
106107

107108
if distributed != sum(shares.values()):
108109
raise InconsistentData(f"Invalid distribution: {sum(shares.values())=} != {distributed=}")
109110

110-
log_cid = self.publish_log(log)
111+
log_cid = self.publish_log(logs)
111112

112113
if not distributed and not shares:
113114
logger.info({"msg": "No shares distributed in the current frame"})
@@ -225,27 +226,64 @@ def collect_data(self, blockstamp: BlockStamp) -> bool:
225226

226227
def calculate_distribution(
227228
self, blockstamp: ReferenceBlockStamp
228-
) -> tuple[int, defaultdict[NodeOperatorId, int], FramePerfLog]:
229+
) -> tuple[int, defaultdict[NodeOperatorId, int], list[FramePerfLog]]:
229230
"""Computes distribution of fee shares at the given timestamp"""
230-
231-
network_avg_perf = self.state.get_network_aggr().perf
232-
threshold = network_avg_perf - self.w3.csm.oracle.perf_leeway_bp(blockstamp.block_hash) / TOTAL_BASIS_POINTS
233231
operators_to_validators = self.module_validators_by_node_operators(blockstamp)
234232

233+
distributed = 0
234+
# Calculate share of each CSM node operator.
235+
shares = defaultdict[NodeOperatorId, int](int)
236+
logs: list[FramePerfLog] = []
237+
238+
converter = self.converter(blockstamp)
239+
frames = self.state.calc_frames(converter.frame_config.epochs_per_frame)
240+
for from_epoch, to_epoch in frames:
241+
frame_blockstamp = blockstamp
242+
frame_ref_slot = converter.get_epoch_first_slot(to_epoch)
243+
if blockstamp.slot_number != frame_ref_slot:
244+
frame_blockstamp = get_reference_blockstamp(
245+
cc=self.w3.cc,
246+
ref_slot=converter.get_epoch_first_slot(to_epoch),
247+
ref_epoch=to_epoch,
248+
last_finalized_slot_number=blockstamp.slot_number,
249+
)
250+
distributed_in_frame, shares_in_frame, log = self._calculate_distribution_in_frame(
251+
frame_blockstamp, operators_to_validators, from_epoch, to_epoch, distributed
252+
)
253+
distributed += distributed_in_frame
254+
for no_id, share in shares_in_frame.items():
255+
shares[no_id] += share
256+
logs.append(log)
257+
258+
return distributed, shares, logs
259+
260+
def _calculate_distribution_in_frame(
261+
self,
262+
blockstamp: ReferenceBlockStamp,
263+
operators_to_validators: ValidatorsByNodeOperator,
264+
from_epoch: EpochNumber,
265+
to_epoch: EpochNumber,
266+
distributed: int,
267+
):
268+
network_perf = self.state.calc_network_perf(from_epoch, to_epoch)
269+
threshold = network_perf - self.w3.csm.oracle.perf_leeway_bp(blockstamp.block_hash) / TOTAL_BASIS_POINTS
270+
235271
# Build the map of the current distribution operators.
236272
distribution: dict[NodeOperatorId, int] = defaultdict(int)
237273
stuck_operators = self.stuck_operators(blockstamp)
238-
log = FramePerfLog(blockstamp, self.state.frame, threshold)
274+
log = FramePerfLog(blockstamp, (from_epoch, to_epoch), threshold)
239275

240276
for (_, no_id), validators in operators_to_validators.items():
241277
if no_id in stuck_operators:
242278
log.operators[no_id].stuck = True
243279
continue
244280

245281
for v in validators:
246-
aggr = self.state.data.get(ValidatorIndex(int(v.index)))
282+
missed = self.state.count_missed(ValidatorIndex(int(v.index)), from_epoch, to_epoch)
283+
included = self.state.count_included(ValidatorIndex(int(v.index)), from_epoch, to_epoch)
284+
assigned = missed + included
247285

248-
if aggr is None:
286+
if not assigned:
249287
# It's possible that the validator is not assigned to any duty, hence it's performance
250288
# is not presented in the aggregates (e.g. exited, pending for activation etc).
251289
continue
@@ -256,23 +294,23 @@ def calculate_distribution(
256294
log.operators[no_id].validators[v.index].slashed = True
257295
continue
258296

259-
if aggr.perf > threshold:
297+
perf = calc_performance(included, missed)
298+
if perf > threshold:
260299
# Count of assigned attestations used as a metrics of time
261300
# the validator was active in the current frame.
262-
distribution[no_id] += aggr.assigned
301+
distribution[no_id] += assigned
263302

264-
log.operators[no_id].validators[v.index].perf = aggr
303+
log.operators[no_id].validators[v.index].perf = AttestationsAccumulatorLog(assigned, included)
265304

266305
# Calculate share of each CSM node operator.
267306
shares = defaultdict[NodeOperatorId, int](int)
268307
total = sum(p for p in distribution.values())
308+
to_distribute = self.w3.csm.fee_distributor.shares_to_distribute(blockstamp.block_hash) - distributed
309+
log.distributable = to_distribute
269310

270311
if not total:
271312
return 0, shares, log
272313

273-
to_distribute = self.w3.csm.fee_distributor.shares_to_distribute(blockstamp.block_hash)
274-
log.distributable = to_distribute
275-
276314
for no_id, no_share in distribution.items():
277315
if no_share:
278316
shares[no_id] = to_distribute * no_share // total
@@ -343,9 +381,9 @@ def publish_tree(self, tree: Tree) -> CID:
343381
logger.info({"msg": "Tree dump uploaded to IPFS", "cid": repr(tree_cid)})
344382
return tree_cid
345383

346-
def publish_log(self, log: FramePerfLog) -> CID:
347-
log_cid = self.w3.ipfs.publish(log.encode())
348-
logger.info({"msg": "Frame log uploaded to IPFS", "cid": repr(log_cid)})
384+
def publish_log(self, logs: list[FramePerfLog]) -> CID:
385+
log_cid = self.w3.ipfs.publish(FramePerfLog.encode(logs))
386+
logger.info({"msg": "Frame(s) log uploaded to IPFS", "cid": repr(log_cid)})
349387
return log_cid
350388

351389
@lru_cache(maxsize=1)

src/modules/csm/duties/attestation.py

+60
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
from enum import Enum
2+
from typing import NewType
3+
4+
EpochIndexInFrame = NewType('EpochIndexInFrame', int)
5+
6+
7+
class AttestationStatus(Enum):
8+
NO_DUTY = None
9+
MISSED = False
10+
INCLUDED = True
11+
12+
13+
def calc_performance(included: int, missed: int) -> float:
14+
all_ = missed + included
15+
return included / all_ if all_ else 0
16+
17+
18+
class AttestationSequence(list):
19+
"""
20+
Duties sequence for a CSM performance oracle frame for validators.
21+
22+
It is bits sequence where each pair of bits represents the duty status of a validator for a specific epoch:
23+
- None - no duty
24+
- False - missed. assigned but not included attestation
25+
- True - included attestation
26+
27+
Every index in the sequence corresponds to epoch index in report frame.
28+
For example:
29+
Report frame is [100000 epoch, ..., 100510 epoch],
30+
We need to write duty status for epoch 100000 at index 0, for epoch 100001 at index 1, and so on.
31+
"""
32+
33+
def __init__(self, size: int):
34+
super().__init__([AttestationStatus.NO_DUTY] * size)
35+
36+
def __str__(self):
37+
missed, included = self.count_missed(), self.count_included()
38+
return f"{self.__class__.__name__}({missed=}, {included=})"
39+
40+
def _validate_range(self, from_index: EpochIndexInFrame, to_index: EpochIndexInFrame):
41+
if from_index < 0 or to_index > len(self) or from_index >= to_index:
42+
raise ValueError("Invalid range for from_index and to_index")
43+
44+
def count_missed(self, from_index: EpochIndexInFrame = 0, to_index: EpochIndexInFrame = None):
45+
if to_index is None:
46+
to_index = len(self)
47+
self._validate_range(from_index, to_index)
48+
return self[from_index:to_index].count(AttestationStatus.MISSED)
49+
50+
def count_included(self, from_index: EpochIndexInFrame = 0, to_index: EpochIndexInFrame = None):
51+
if to_index is None:
52+
to_index = len(self)
53+
self._validate_range(from_index, to_index)
54+
return self[from_index:to_index].count(AttestationStatus.INCLUDED)
55+
56+
def get_duty_status(self, epoch_index: EpochIndexInFrame) -> AttestationStatus:
57+
return self[epoch_index]
58+
59+
def set_duty_status(self, epoch_index: EpochIndexInFrame, duty_status: AttestationStatus):
60+
self[epoch_index] = duty_status

src/modules/csm/log.py

+11-4
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,23 @@
22
from collections import defaultdict
33
from dataclasses import asdict, dataclass, field
44

5-
from src.modules.csm.state import AttestationsAccumulator
65
from src.modules.csm.types import Shares
76
from src.types import EpochNumber, NodeOperatorId, ReferenceBlockStamp
87

98

109
class LogJSONEncoder(json.JSONEncoder): ...
1110

1211

12+
@dataclass
13+
class AttestationsAccumulatorLog:
14+
assigned: int = 0
15+
included: int = 0
16+
17+
1318
@dataclass
1419
class ValidatorFrameSummary:
15-
perf: AttestationsAccumulator = field(default_factory=AttestationsAccumulator)
20+
# TODO: Should be renamed. Perf means different things in different contexts
21+
perf: AttestationsAccumulatorLog = field(default_factory=AttestationsAccumulatorLog)
1622
slashed: bool = False
1723

1824

@@ -35,13 +41,14 @@ class FramePerfLog:
3541
default_factory=lambda: defaultdict(OperatorFrameSummary)
3642
)
3743

38-
def encode(self) -> bytes:
44+
@staticmethod
45+
def encode(logs: list['FramePerfLog']) -> bytes:
3946
return (
4047
LogJSONEncoder(
4148
indent=None,
4249
separators=(',', ':'),
4350
sort_keys=True,
4451
)
45-
.encode(asdict(self))
52+
.encode([asdict(log) for log in logs])
4653
.encode()
4754
)

0 commit comments

Comments
 (0)