Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit e8a36cb

Browse files
committedOct 31, 2024··
feat: state data is list of tuples now
1 parent 54b5f3a commit e8a36cb

File tree

6 files changed

+87
-68
lines changed

6 files changed

+87
-68
lines changed
 

‎src/modules/csm/checkpoint.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ class FrameCheckpoint:
3333
@dataclass
3434
class ValidatorDuty:
3535
index: ValidatorIndex
36-
included: bool
36+
is_included: bool
3737

3838

3939
class FrameCheckpointsIterator:
@@ -200,7 +200,7 @@ def _check_duty(
200200
for validator_duty in committee:
201201
self.state.inc(
202202
validator_duty.index,
203-
included=validator_duty.included,
203+
validator_duty.is_included,
204204
)
205205
if duty_epoch not in self.state.unprocessed_epochs:
206206
raise ValueError(f"Epoch {duty_epoch} is not in epochs that should be processed")
@@ -222,7 +222,7 @@ def _prepare_committees(self, epoch: EpochNumber) -> Committees:
222222
validators = []
223223
# Order of insertion is used to track the positions in the committees.
224224
for validator in committee.validators:
225-
validators.append(ValidatorDuty(index=ValidatorIndex(int(validator)), included=False))
225+
validators.append(ValidatorDuty(index=ValidatorIndex(int(validator)), is_included=False))
226226
committees[(committee.slot, committee.index)] = validators
227227
return committees
228228

@@ -233,7 +233,7 @@ def process_attestations(attestations: Iterable[BlockAttestation], committees: C
233233
committee = committees.get(committee_id, [])
234234
att_bits = _to_bits(attestation.aggregation_bits)
235235
for index_in_committee, validator_duty in enumerate(committee):
236-
validator_duty.included = validator_duty.included or _is_attested(att_bits, index_in_committee)
236+
validator_duty.is_included = validator_duty.is_included or _is_attested(att_bits, index_in_committee)
237237

238238

239239
def _is_attested(bits: Sequence[bool], index: int) -> bool:

‎src/modules/csm/csm.py

+13-8
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@
1212
)
1313
from src.metrics.prometheus.duration_meter import duration_meter
1414
from src.modules.csm.checkpoint import FrameCheckpointProcessor, FrameCheckpointsIterator, MinStepIsNotReached
15-
from src.modules.csm.log import FramePerfLog
16-
from src.modules.csm.state import State
15+
from src.modules.csm.log import FramePerfLog, AttestationsAccumulatorLog
16+
from src.modules.csm.state import State, perf, Assigned, Included, AttestationsAccumulator
1717
from src.modules.csm.tree import Tree
1818
from src.modules.csm.types import ReportData, Shares
1919
from src.modules.submodules.consensus import ConsensusModule
@@ -228,7 +228,7 @@ def calculate_distribution(
228228
) -> tuple[int, defaultdict[NodeOperatorId, int], FramePerfLog]:
229229
"""Computes distribution of fee shares at the given timestamp"""
230230

231-
network_avg_perf = self.state.get_network_aggr().perf
231+
network_avg_perf = perf(self.state.get_network_aggr())
232232
threshold = network_avg_perf - self.w3.csm.oracle.perf_leeway_bp(blockstamp.block_hash) / TOTAL_BASIS_POINTS
233233
operators_to_validators = self.module_validators_by_node_operators(blockstamp)
234234

@@ -243,9 +243,14 @@ def calculate_distribution(
243243
continue
244244

245245
for v in validators:
246-
aggr = self.state.data.get(ValidatorIndex(int(v.index)))
246+
aggr = (
247+
self.state.data[ValidatorIndex(int(v.index))] or
248+
AttestationsAccumulator((Assigned(0), Included(0)))
249+
)
247250

248-
if aggr is None:
251+
assigned, included = aggr
252+
253+
if not assigned:
249254
# It's possible that the validator is not assigned to any duty, hence it's performance
250255
# is not presented in the aggregates (e.g. exited, pending for activation etc).
251256
continue
@@ -256,12 +261,12 @@ def calculate_distribution(
256261
log.operators[no_id].validators[v.index].slashed = True
257262
continue
258263

259-
if aggr.perf > threshold:
264+
if perf(aggr) > threshold:
260265
# Count of assigned attestations used as a metrics of time
261266
# the validator was active in the current frame.
262-
distribution[no_id] += aggr.assigned
267+
distribution[no_id] += assigned
263268

264-
log.operators[no_id].validators[v.index].perf = aggr
269+
log.operators[no_id].validators[v.index].perf = AttestationsAccumulatorLog(assigned, included)
265270

266271
# Calculate share of each CSM node operator.
267272
shares = defaultdict[NodeOperatorId, int](int)

‎src/modules/csm/log.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,23 @@
22
from collections import defaultdict
33
from dataclasses import asdict, dataclass, field
44

5-
from src.modules.csm.state import AttestationsAccumulator
65
from src.modules.csm.types import Shares
76
from src.types import EpochNumber, NodeOperatorId, ReferenceBlockStamp
87

98

109
class LogJSONEncoder(json.JSONEncoder): ...
1110

1211

12+
@dataclass
13+
class AttestationsAccumulatorLog:
14+
assigned: int = 0
15+
included: int = 0
16+
17+
1318
@dataclass
1419
class ValidatorFrameSummary:
15-
perf: AttestationsAccumulator = field(default_factory=AttestationsAccumulator)
20+
# TODO: Should be renamed. Perf means different things in different contexts
21+
perf: AttestationsAccumulatorLog = field(default_factory=AttestationsAccumulatorLog)
1622
slashed: bool = False
1723

1824

‎src/modules/csm/state.py

+28-33
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
11
import logging
22
import os
33
import pickle
4-
from collections import defaultdict
5-
from dataclasses import dataclass
64
from pathlib import Path
7-
from typing import Self
5+
from typing import Self, NewType
86

97
from src.types import EpochNumber, ValidatorIndex
108
from src.utils.range import sequence
@@ -17,20 +15,14 @@ class InvalidState(ValueError):
1715
"""State has data considered as invalid for a report"""
1816

1917

20-
@dataclass
21-
class AttestationsAccumulator:
22-
"""Accumulator of attestations duties observed for a validator"""
18+
Assigned = NewType("Assigned", int)
19+
Included = NewType("Included", int)
20+
AttestationsAccumulator = NewType('AttestationsAccumulator', tuple[Assigned, Included])
2321

24-
assigned: int = 0
25-
included: int = 0
2622

27-
@property
28-
def perf(self) -> float:
29-
return self.included / self.assigned if self.assigned else 0
30-
31-
def add_duty(self, included: bool) -> None:
32-
self.assigned += 1
33-
self.included += 1 if included else 0
23+
def perf(acc: AttestationsAccumulator) -> float:
24+
assigned, included = acc
25+
return included / assigned if assigned else 0
3426

3527

3628
class State:
@@ -43,14 +35,14 @@ class State:
4335
4436
The state can be migrated to be used for another frame's report by calling the `migrate` method.
4537
"""
46-
47-
data: defaultdict[ValidatorIndex, AttestationsAccumulator]
38+
# validator_index -> (assigned, included)
39+
data: list[AttestationsAccumulator | None]
4840

4941
_epochs_to_process: set[EpochNumber]
5042
_processed_epochs: set[EpochNumber]
5143

52-
def __init__(self, data: dict[ValidatorIndex, AttestationsAccumulator] | None = None) -> None:
53-
self.data = defaultdict(AttestationsAccumulator, data or {})
44+
def __init__(self, data: list[AttestationsAccumulator | None] | None = None) -> None:
45+
self.data = data or []
5446
self._epochs_to_process = set()
5547
self._processed_epochs = set()
5648

@@ -88,13 +80,16 @@ def buffer(self) -> Path:
8880
return self.file().with_suffix(".buf")
8981

9082
def clear(self) -> None:
91-
self.data = defaultdict(AttestationsAccumulator)
83+
self.data = []
9284
self._epochs_to_process.clear()
9385
self._processed_epochs.clear()
9486
assert self.is_empty
9587

96-
def inc(self, key: ValidatorIndex, included: bool) -> None:
97-
self.data[key].add_duty(included)
88+
def inc(self, key: ValidatorIndex, is_included: bool) -> None:
89+
if key >= len(self.data):
90+
self.data += [None] * (key - len(self.data) + 1)
91+
assigned, included = self.data[key] or (Assigned(0), Included(0))
92+
self.data[key] = AttestationsAccumulator((Assigned(assigned + 1), Included(included + 1 if is_included else included)))
9893

9994
def add_processed_epoch(self, epoch: EpochNumber) -> None:
10095
self._processed_epochs.add(epoch)
@@ -149,15 +144,15 @@ def frame(self) -> tuple[EpochNumber, EpochNumber]:
149144
def get_network_aggr(self) -> AttestationsAccumulator:
150145
"""Return `AttestationsAccumulator` over duties of all the network validators"""
151146

152-
included = assigned = 0
153-
for validator, acc in self.data.items():
154-
if acc.included > acc.assigned:
155-
raise ValueError(f"Invalid accumulator: {validator=}, {acc=}")
156-
included += acc.included
157-
assigned += acc.assigned
158-
aggr = AttestationsAccumulator(
159-
included=included,
160-
assigned=assigned,
161-
)
162-
logger.info({"msg": "Network attestations aggregate computed", "value": repr(aggr), "avg_perf": aggr.perf})
147+
net_included = net_assigned = 0
148+
for validator_index, acc in enumerate(self.data):
149+
if acc is None:
150+
continue
151+
assigned, included = acc
152+
if included > assigned:
153+
raise ValueError(f"Invalid accumulator: {validator_index=}, {acc=}")
154+
net_included += included
155+
net_assigned += assigned
156+
aggr = AttestationsAccumulator((Assigned(net_assigned), Included(net_included)))
157+
logger.info({"msg": "Network attestations aggregate computed", "value": repr(aggr), "avg_perf": perf(aggr)})
163158
return aggr

‎tests/modules/csm/test_checkpoint.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,7 @@ def test_checkpoints_processor_prepare_committees(mock_get_attestation_committee
237237
assert int(committee_index) == committee_from_raw.index
238238
assert len(validators) == 32
239239
for validator in validators:
240-
assert validator.included is False
240+
assert validator.is_included is False
241241

242242

243243
def test_checkpoints_processor_process_attestations(mock_get_attestation_committees, consensus_client, converter):
@@ -265,9 +265,9 @@ def test_checkpoints_processor_process_attestations(mock_get_attestation_committ
265265
for validator in validators:
266266
# only the first attestation is accounted
267267
if index == 0:
268-
assert validator.included is True
268+
assert validator.is_included is True
269269
else:
270-
assert validator.included is False
270+
assert validator.is_included is False
271271

272272

273273
def test_checkpoints_processor_process_attestations_undefined_committee(
@@ -290,7 +290,7 @@ def test_checkpoints_processor_process_attestations_undefined_committee(
290290
process_attestations([attestation], committees)
291291
for validators in committees.values():
292292
for v in validators:
293-
assert v.included is False
293+
assert v.is_included is False
294294

295295

296296
@pytest.fixture()

‎tests/modules/csm/test_csm_module.py

+30-17
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
from src.constants import UINT64_MAX
1111
from src.modules.csm.csm import CSOracle
12-
from src.modules.csm.state import AttestationsAccumulator, State
12+
from src.modules.csm.state import AttestationsAccumulator, State, perf
1313
from src.modules.csm.tree import Tree
1414
from src.modules.submodules.oracle_module import ModuleExecuteDelay
1515
from src.modules.submodules.types import CurrentFrame, ZERO_HASH
@@ -118,21 +118,34 @@ def test_calculate_distribution(module: CSOracle, csm: CSM):
118118
)
119119

120120
module.state = State(
121-
{
122-
ValidatorIndex(0): AttestationsAccumulator(included=200, assigned=200), # short on frame
123-
ValidatorIndex(1): AttestationsAccumulator(included=1000, assigned=1000),
124-
ValidatorIndex(2): AttestationsAccumulator(included=1000, assigned=1000),
125-
ValidatorIndex(3): AttestationsAccumulator(included=999, assigned=1000),
126-
ValidatorIndex(4): AttestationsAccumulator(included=900, assigned=1000),
127-
ValidatorIndex(5): AttestationsAccumulator(included=500, assigned=1000), # underperforming
128-
ValidatorIndex(6): AttestationsAccumulator(included=0, assigned=0), # underperforming
129-
ValidatorIndex(7): AttestationsAccumulator(included=900, assigned=1000),
130-
ValidatorIndex(8): AttestationsAccumulator(included=500, assigned=1000), # underperforming
131-
# ValidatorIndex(9): AttestationsAggregate(included=0, assigned=0), # missing in state
132-
ValidatorIndex(10): AttestationsAccumulator(included=1000, assigned=1000),
133-
ValidatorIndex(11): AttestationsAccumulator(included=1000, assigned=1000),
134-
ValidatorIndex(12): AttestationsAccumulator(included=1000, assigned=1000),
135-
}
121+
[
122+
# ValidatorIndex(0):
123+
AttestationsAccumulator((200, 200)), # short on frame
124+
# ValidatorIndex(1):
125+
AttestationsAccumulator((1000, 1000)),
126+
# ValidatorIndex(2):
127+
AttestationsAccumulator((1000, 1000)),
128+
# ValidatorIndex(3):
129+
AttestationsAccumulator((1000, 999)),
130+
# ValidatorIndex(4):
131+
AttestationsAccumulator((1000, 900)),
132+
# ValidatorIndex(5):
133+
AttestationsAccumulator((1000, 500)), # underperforming
134+
# ValidatorIndex(6):
135+
AttestationsAccumulator((0, 0)), # underperforming
136+
# ValidatorIndex(7):
137+
AttestationsAccumulator((1000, 900)),
138+
# ValidatorIndex(8):
139+
AttestationsAccumulator((1000, 500)), # underperforming
140+
# ValidatorIndex(9):
141+
None, # missing in state
142+
# ValidatorIndex(10):
143+
AttestationsAccumulator((1000, 1000)),
144+
# ValidatorIndex(11):
145+
AttestationsAccumulator((1000, 1000)),
146+
# ValidatorIndex(12):
147+
AttestationsAccumulator((1000, 1000)),
148+
]
136149
)
137150
module.state.migrate(EpochNumber(100), EpochNumber(500))
138151

@@ -177,7 +190,7 @@ def test_calculate_distribution(module: CSOracle, csm: CSM):
177190
assert log.operators[NodeOperatorId(6)].distributed == 2380
178191

179192
assert log.frame == (100, 500)
180-
assert log.threshold == module.state.get_network_aggr().perf - 0.05
193+
assert log.threshold == perf(module.state.get_network_aggr()) - 0.05
181194

182195

183196
# Static functions you were dreaming of for so long.

0 commit comments

Comments
 (0)