Skip to content

Commit e091d12

Browse files
committed
refactor: distribution and tests
1 parent 1aa7228 commit e091d12

File tree

7 files changed

+426
-336
lines changed

7 files changed

+426
-336
lines changed

src/modules/csm/csm.py

+102-71
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@
1212
)
1313
from src.metrics.prometheus.duration_meter import duration_meter
1414
from src.modules.csm.checkpoint import FrameCheckpointProcessor, FrameCheckpointsIterator, MinStepIsNotReached
15-
from src.modules.csm.log import FramePerfLog
16-
from src.modules.csm.state import State, Frame
15+
from src.modules.csm.log import FramePerfLog, OperatorFrameSummary
16+
from src.modules.csm.state import State, Frame, AttestationsAccumulator
1717
from src.modules.csm.tree import Tree
1818
from src.modules.csm.types import ReportData, Shares
1919
from src.modules.submodules.consensus import ConsensusModule
@@ -29,13 +29,12 @@
2929
SlotNumber,
3030
StakingModuleAddress,
3131
StakingModuleId,
32-
ValidatorIndex,
3332
)
3433
from src.utils.blockstamp import build_blockstamp
3534
from src.utils.cache import global_lru_cache as lru_cache
3635
from src.utils.slot import get_next_non_missed_slot, get_reference_blockstamp
3736
from src.utils.web3converter import Web3Converter
38-
from src.web3py.extensions.lido_validators import NodeOperatorId, StakingModule, ValidatorsByNodeOperator
37+
from src.web3py.extensions.lido_validators import NodeOperatorId, StakingModule, ValidatorsByNodeOperator, LidoValidator
3938
from src.web3py.types import Web3
4039

4140
logger = logging.getLogger(__name__)
@@ -102,15 +101,15 @@ def build_report(self, blockstamp: ReferenceBlockStamp) -> tuple:
102101
if (prev_cid is None) != (prev_root == ZERO_HASH):
103102
raise InconsistentData(f"Got inconsistent previous tree data: {prev_root=} {prev_cid=}")
104103

105-
distributed, shares, logs = self.calculate_distribution(blockstamp)
104+
total_distributed, total_rewards, logs = self.calculate_distribution(blockstamp)
106105

107-
if distributed != sum(shares.values()):
108-
raise InconsistentData(f"Invalid distribution: {sum(shares.values())=} != {distributed=}")
106+
if total_distributed != sum(total_rewards.values()):
107+
raise InconsistentData(f"Invalid distribution: {sum(total_rewards.values())=} != {total_distributed=}")
109108

110109
log_cid = self.publish_log(logs)
111110

112-
if not distributed and not shares:
113-
logger.info({"msg": "No shares distributed in the current frame"})
111+
if not total_distributed and not total_rewards:
112+
logger.info({"msg": "No rewards distributed in the current frame"})
114113
return ReportData(
115114
self.get_consensus_version(blockstamp),
116115
blockstamp.ref_slot,
@@ -123,11 +122,11 @@ def build_report(self, blockstamp: ReferenceBlockStamp) -> tuple:
123122
if prev_cid and prev_root != ZERO_HASH:
124123
# Update cumulative amount of shares for all operators.
125124
for no_id, acc_shares in self.get_accumulated_shares(prev_cid, prev_root):
126-
shares[no_id] += acc_shares
125+
total_rewards[no_id] += acc_shares
127126
else:
128127
logger.info({"msg": "No previous distribution. Nothing to accumulate"})
129128

130-
tree = self.make_tree(shares)
129+
tree = self.make_tree(total_rewards)
131130
tree_cid = self.publish_tree(tree)
132131

133132
return ReportData(
@@ -136,7 +135,7 @@ def build_report(self, blockstamp: ReferenceBlockStamp) -> tuple:
136135
tree_root=tree.root,
137136
tree_cid=tree_cid,
138137
log_cid=log_cid,
139-
distributed=distributed,
138+
distributed=total_distributed,
140139
).as_tuple()
141140

142141
def is_main_data_submitted(self, blockstamp: BlockStamp) -> bool:
@@ -232,26 +231,36 @@ def calculate_distribution(
232231
"""Computes distribution of fee shares at the given timestamp"""
233232
operators_to_validators = self.module_validators_by_node_operators(blockstamp)
234233

235-
distributed = 0
236-
# Calculate share of each CSM node operator.
237-
shares = defaultdict[NodeOperatorId, int](int)
234+
total_distributed = 0
235+
total_rewards = defaultdict[NodeOperatorId, int](int)
238236
logs: list[FramePerfLog] = []
239237

240-
for frame in self.state.data:
238+
for frame in self.state.frames:
241239
from_epoch, to_epoch = frame
242240
logger.info({"msg": f"Calculating distribution for frame [{from_epoch};{to_epoch}]"})
241+
243242
frame_blockstamp = blockstamp
244243
if to_epoch != blockstamp.ref_epoch:
245244
frame_blockstamp = self._get_ref_blockstamp_for_frame(blockstamp, to_epoch)
246-
distributed_in_frame, shares_in_frame, log = self._calculate_distribution_in_frame(
247-
frame_blockstamp, operators_to_validators, frame, distributed
245+
246+
total_rewards_to_distribute = self.w3.csm.fee_distributor.shares_to_distribute(frame_blockstamp.block_hash)
247+
rewards_to_distribute_in_frame = total_rewards_to_distribute - total_distributed
248+
249+
rewards_in_frame, log = self._calculate_distribution_in_frame(
250+
frame, frame_blockstamp, rewards_to_distribute_in_frame, operators_to_validators
248251
)
249-
distributed += distributed_in_frame
250-
for no_id, share in shares_in_frame.items():
251-
shares[no_id] += share
252+
distributed_in_frame = sum(rewards_in_frame.values())
253+
254+
total_distributed += distributed_in_frame
255+
if total_distributed > total_rewards_to_distribute:
256+
raise CSMError(f"Invalid distribution: {total_distributed=} > {total_rewards_to_distribute=}")
257+
258+
for no_id, rewards in rewards_in_frame.items():
259+
total_rewards[no_id] += rewards
260+
252261
logs.append(log)
253262

254-
return distributed, shares, logs
263+
return total_distributed, total_rewards, logs
255264

256265
def _get_ref_blockstamp_for_frame(
257266
self, blockstamp: ReferenceBlockStamp, frame_ref_epoch: EpochNumber
@@ -266,63 +275,85 @@ def _get_ref_blockstamp_for_frame(
266275

267276
def _calculate_distribution_in_frame(
268277
self,
269-
blockstamp: ReferenceBlockStamp,
270-
operators_to_validators: ValidatorsByNodeOperator,
271278
frame: Frame,
272-
distributed: int,
279+
blockstamp: ReferenceBlockStamp,
280+
rewards_to_distribute: int,
281+
operators_to_validators: ValidatorsByNodeOperator
273282
):
274-
network_perf = self.state.get_network_aggr(frame).perf
275-
threshold = network_perf - self.w3.csm.oracle.perf_leeway_bp(blockstamp.block_hash) / TOTAL_BASIS_POINTS
276-
277-
# Build the map of the current distribution operators.
278-
distribution: dict[NodeOperatorId, int] = defaultdict(int)
279-
stuck_operators = self.stuck_operators(blockstamp)
283+
threshold = self._get_performance_threshold(frame, blockstamp)
280284
log = FramePerfLog(blockstamp, frame, threshold)
281285

286+
participation_shares: defaultdict[NodeOperatorId, int] = defaultdict(int)
287+
288+
stuck_operators = self.stuck_operators(blockstamp)
282289
for (_, no_id), validators in operators_to_validators.items():
290+
log_operator = log.operators[no_id]
283291
if no_id in stuck_operators:
284-
log.operators[no_id].stuck = True
292+
log_operator.stuck = True
293+
continue
294+
for validator in validators:
295+
duty = self.state.data[frame].get(validator.index)
296+
self.process_validator_duty(validator, duty, threshold, participation_shares, log_operator)
297+
298+
rewards_distribution = self.calc_rewards_distribution_in_frame(participation_shares, rewards_to_distribute)
299+
300+
for no_id, no_rewards in rewards_distribution.items():
301+
log.operators[no_id].distributed = no_rewards
302+
303+
log.distributable = rewards_to_distribute
304+
305+
return rewards_distribution, log
306+
307+
def _get_performance_threshold(self, frame: Frame, blockstamp: ReferenceBlockStamp) -> float:
308+
network_perf = self.state.get_network_aggr(frame).perf
309+
perf_leeway = self.w3.csm.oracle.perf_leeway_bp(blockstamp.block_hash) / TOTAL_BASIS_POINTS
310+
threshold = network_perf - perf_leeway
311+
return threshold
312+
313+
@staticmethod
314+
def process_validator_duty(
315+
validator: LidoValidator,
316+
attestation_duty: AttestationsAccumulator | None,
317+
threshold: float,
318+
participation_shares: defaultdict[NodeOperatorId, int],
319+
log_operator: OperatorFrameSummary
320+
):
321+
if attestation_duty is None:
322+
# It's possible that the validator is not assigned to any duty, hence it's performance
323+
# is not presented in the aggregates (e.g. exited, pending for activation etc).
324+
# TODO: check `sync_aggr` to strike (in case of bad sync performance) after validator exit
325+
return
326+
327+
log_validator = log_operator.validators[validator.index]
328+
329+
if validator.validator.slashed is True:
330+
# It means that validator was active during the frame and got slashed and didn't meet the exit
331+
# epoch, so we should not count such validator for operator's share.
332+
log_validator.slashed = True
333+
return
334+
335+
if attestation_duty.perf > threshold:
336+
# Count of assigned attestations used as a metrics of time
337+
# the validator was active in the current frame.
338+
participation_shares[validator.lido_id.operatorIndex] += attestation_duty.assigned
339+
340+
log_validator.attestation_duty = attestation_duty
341+
342+
@staticmethod
343+
def calc_rewards_distribution_in_frame(
344+
participation_shares: dict[NodeOperatorId, int],
345+
rewards_to_distribute: int,
346+
) -> dict[NodeOperatorId, int]:
347+
rewards_distribution: dict[NodeOperatorId, int] = defaultdict(int)
348+
total_participation = sum(participation_shares.values())
349+
350+
for no_id, no_participation_share in participation_shares.items():
351+
if no_participation_share == 0:
352+
# Skip operators with zero participation
285353
continue
354+
rewards_distribution[no_id] = rewards_to_distribute * no_participation_share // total_participation
286355

287-
for v in validators:
288-
aggr = self.state.data[frame].get(ValidatorIndex(int(v.index)))
289-
290-
if aggr is None:
291-
# It's possible that the validator is not assigned to any duty, hence it's performance
292-
# is not presented in the aggregates (e.g. exited, pending for activation etc).
293-
continue
294-
295-
if v.validator.slashed is True:
296-
# It means that validator was active during the frame and got slashed and didn't meet the exit
297-
# epoch, so we should not count such validator for operator's share.
298-
log.operators[no_id].validators[v.index].slashed = True
299-
continue
300-
301-
if aggr.perf > threshold:
302-
# Count of assigned attestations used as a metrics of time
303-
# the validator was active in the current frame.
304-
distribution[no_id] += aggr.assigned
305-
306-
log.operators[no_id].validators[v.index].perf = aggr
307-
308-
# Calculate share of each CSM node operator.
309-
shares = defaultdict[NodeOperatorId, int](int)
310-
total = sum(p for p in distribution.values())
311-
to_distribute = self.w3.csm.fee_distributor.shares_to_distribute(blockstamp.block_hash) - distributed
312-
log.distributable = to_distribute
313-
314-
if not total:
315-
return 0, shares, log
316-
317-
for no_id, no_share in distribution.items():
318-
if no_share:
319-
shares[no_id] = to_distribute * no_share // total
320-
log.operators[no_id].distributed = shares[no_id]
321-
322-
distributed = sum(s for s in shares.values())
323-
if distributed > to_distribute:
324-
raise CSMError(f"Invalid distribution: {distributed=} > {to_distribute=}")
325-
return distributed, shares, log
356+
return rewards_distribution
326357

327358
def get_accumulated_shares(self, cid: CID, root: HexBytes) -> Iterator[tuple[NodeOperatorId, Shares]]:
328359
logger.info({"msg": "Fetching tree by CID from IPFS", "cid": repr(cid)})

src/modules/csm/log.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,7 @@ class LogJSONEncoder(json.JSONEncoder): ...
1212

1313
@dataclass
1414
class ValidatorFrameSummary:
15-
# TODO: Should be renamed. Perf means different things in different contexts
16-
perf: AttestationsAccumulator = field(default_factory=AttestationsAccumulator)
15+
attestation_duty: AttestationsAccumulator = field(default_factory=AttestationsAccumulator)
1716
slashed: bool = False
1817

1918

src/modules/csm/state.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,10 @@ def unprocessed_epochs(self) -> set[EpochNumber]:
110110
def is_fulfilled(self) -> bool:
111111
return not self.unprocessed_epochs
112112

113+
@property
114+
def frames(self):
115+
return self.calculate_frames(self._epochs_to_process, self._epochs_per_frame)
116+
113117
@staticmethod
114118
def calculate_frames(epochs_to_process: tuple[EpochNumber, ...], epochs_per_frame: int) -> list[Frame]:
115119
"""Split epochs to process into frames of `epochs_per_frame` length"""
@@ -127,11 +131,10 @@ def clear(self) -> None:
127131
assert self.is_empty
128132

129133
def find_frame(self, epoch: EpochNumber) -> Frame:
130-
frames = self.data.keys()
131-
for epoch_range in frames:
134+
for epoch_range in self.frames:
132135
if epoch_range[0] <= epoch <= epoch_range[1]:
133136
return epoch_range
134-
raise ValueError(f"Epoch {epoch} is out of frames range: {frames}")
137+
raise ValueError(f"Epoch {epoch} is out of frames range: {self.frames}")
135138

136139
def increment_duty(self, frame: Frame, val_index: ValidatorIndex, included: bool) -> None:
137140
if frame not in self.data:
@@ -160,7 +163,7 @@ def init_or_migrate(
160163
frames_data: StateData = {frame: defaultdict(AttestationsAccumulator) for frame in frames}
161164

162165
if not self.is_empty:
163-
cached_frames = self.calculate_frames(self._epochs_to_process, self._epochs_per_frame)
166+
cached_frames = self.frames
164167
if cached_frames == frames:
165168
logger.info({"msg": "No need to migrate duties data cache"})
166169
return

src/providers/execution/contracts/cs_fee_distributor.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from eth_typing import ChecksumAddress
44
from hexbytes import HexBytes
55
from web3 import Web3
6-
from web3.types import BlockIdentifier
6+
from web3.types import BlockIdentifier, Wei
77

88
from ..base_interface import ContractInterface
99

@@ -26,7 +26,7 @@ def oracle(self, block_identifier: BlockIdentifier = "latest") -> ChecksumAddres
2626
)
2727
return Web3.to_checksum_address(resp)
2828

29-
def shares_to_distribute(self, block_identifier: BlockIdentifier = "latest") -> int:
29+
def shares_to_distribute(self, block_identifier: BlockIdentifier = "latest") -> Wei:
3030
"""Returns the amount of shares that are pending to be distributed"""
3131

3232
resp = self.functions.pendingSharesToDistribute().call(block_identifier=block_identifier)

0 commit comments

Comments
 (0)