Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sanitize KAPI keys to lowercase #522

Merged
merged 10 commits into from
Jan 10, 2025
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 30 additions & 5 deletions src/providers/keys/client.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from time import sleep
from typing import cast

from eth_typing import HexStr

from src.metrics.prometheus.basic import KEYS_API_REQUESTS_DURATION, KEYS_API_LATEST_BLOCKNUMBER
from src.providers.http_provider import HTTPProvider, NotOkResponse
from src.providers.keys.types import LidoKey, KeysApiStatus
from src.types import BlockStamp, StakingModuleAddress
from src.utils.dataclass import list_of_dataclasses
from src.utils.cache import global_lru_cache as lru_cache


Expand All @@ -17,6 +18,28 @@ class KAPIClientError(NotOkResponse):
pass


def _transform_keys_to_lowercase(lido_keys: list[LidoKey]) -> list[LidoKey]:
"""
Transforms the `key` field of each LidoKey in the input list to lowercase.

Args:
lido_keys (List[LidoKey]): List of LidoKey objects.

Returns:
List[LidoKey]: List of transformed LidoKey objects.
"""
return [
LidoKey(
key=HexStr(lido_key.key.lower()),
depositSignature=lido_key.depositSignature,
operatorIndex=lido_key.operatorIndex,
used=lido_key.used,
moduleAddress=lido_key.moduleAddress
)
for lido_key in lido_keys
]


class KeysAPIClient(HTTPProvider):
"""
Lido Keys are stored in different modules in on-chain and off-chain format.
Expand Down Expand Up @@ -51,17 +74,19 @@ def _get_with_blockstamp(self, url: str, blockstamp: BlockStamp, params: dict |
raise KeysOutdatedException(f'Keys API Service stuck, no updates for {self.backoff_factor * self.retry_count} seconds.')

@lru_cache(maxsize=1)
@list_of_dataclasses(LidoKey.from_response)
def get_used_lido_keys(self, blockstamp: BlockStamp) -> list[dict]:
def get_used_lido_keys(self, blockstamp: BlockStamp) -> list[LidoKey]:
"""Docs: https://keys-api.lido.fi/api/static/index.html#/keys/KeysController_get"""
return cast(list[dict], self._get_with_blockstamp(self.USED_KEYS, blockstamp))
lido_keys = list(map(lambda x: LidoKey.from_response(**x), self._get_with_blockstamp(self.USED_KEYS, blockstamp)))
return _transform_keys_to_lowercase(lido_keys)

@lru_cache(maxsize=1)
def get_module_operators_keys(self, module_address: StakingModuleAddress, blockstamp: BlockStamp) -> dict:
"""
Docs: https://keys-api.lido.fi/api/static/index.html#/operators-keys/SRModulesOperatorsKeysController_getOperatorsKeys
"""
return cast(dict, self._get_with_blockstamp(self.MODULE_OPERATORS_KEYS.format(module_address), blockstamp))
data = cast(dict, self._get_with_blockstamp(self.MODULE_OPERATORS_KEYS.format(module_address), blockstamp))
data['keys'] = _transform_keys_to_lowercase(data['keys'])
return data

def get_status(self) -> KeysApiStatus:
"""Docs: https://keys-api.lido.fi/api/static/index.html#/status/StatusController_get"""
Expand Down
3 changes: 2 additions & 1 deletion src/providers/keys/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@

from eth_typing import ChecksumAddress, HexStr

from src.types import NodeOperatorId
from src.utils.dataclass import FromResponse


@dataclass
class LidoKey(FromResponse):
key: HexStr
depositSignature: HexStr
operatorIndex: int
operatorIndex: NodeOperatorId
used: bool
moduleAddress: ChecksumAddress

Expand Down
5 changes: 2 additions & 3 deletions src/services/exit_order/iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@
from src.metrics.prometheus.duration_meter import duration_meter
from src.modules.submodules.types import ChainConfig
from src.services.exit_order.iterator_state import ExitOrderIteratorStateService, NodeOperatorPredictableState
from src.types import ReferenceBlockStamp, NodeOperatorGlobalIndex, StakingModuleId, NodeOperatorId

from src.types import ReferenceBlockStamp, NodeOperatorGlobalIndex, StakingModuleId
from src.utils.validator_state import get_validator_age
from src.web3py.extensions.lido_validators import LidoValidator
from src.web3py.types import Web3
Expand Down Expand Up @@ -173,5 +172,5 @@ def operator_index_by_validator(
) -> NodeOperatorGlobalIndex:
return (
StakingModuleId(staking_module_id[validator.lido_id.moduleAddress]),
NodeOperatorId(validator.lido_id.operatorIndex),
validator.lido_id.operatorIndex,
)
30 changes: 21 additions & 9 deletions src/web3py/extensions/lido_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,8 @@
from src.utils.dataclass import Nested
from src.utils.cache import global_lru_cache as lru_cache


logger = logging.getLogger(__name__)


if TYPE_CHECKING:
from src.web3py.types import Web3 # pragma: no cover

Expand Down Expand Up @@ -190,7 +188,7 @@ def get_lido_validators_by_node_operators(self, blockstamp: BlockStamp) -> Valid
for validator in merged_validators:
global_no_id = (
staking_module_address[validator.lido_id.moduleAddress],
NodeOperatorId(validator.lido_id.operatorIndex),
validator.lido_id.operatorIndex,
)

if global_no_id in no_validators:
Expand All @@ -204,30 +202,44 @@ def get_lido_validators_by_node_operators(self, blockstamp: BlockStamp) -> Valid
return no_validators

@lru_cache(maxsize=1)
def get_module_validators_by_node_operators(self, module_address: StakingModuleAddress, blockstamp: BlockStamp) -> ValidatorsByNodeOperator:
"""Get module validators by querying the KeysAPI for the module keys"""
def get_module_validators_by_node_operators(
self,
module_address: StakingModuleAddress,
blockstamp: BlockStamp
) -> ValidatorsByNodeOperator:
"""
Get module validators by querying the KeysAPI for the module keys.

Args:
module_address (StakingModuleAddress): The address of the staking module.
blockstamp (BlockStamp): The block timestamp for querying validators.

Returns:
ValidatorsByNodeOperator: A mapping of node operator IDs to their corresponding validators.
"""
# Fetch module operator keys from the KeysAPI
kapi = self.w3.kac.get_module_operators_keys(module_address, blockstamp)
if (kapi_module_address := kapi['module']['stakingModuleAddress']) != module_address:
raise ValueError(f"Module address mismatch: {kapi_module_address=} != {module_address=}")
operators = kapi['operators']
keys = {k['key']: k for k in kapi['keys']}
keys = {k['key']: LidoKey.from_response(**k) for k in kapi['keys']}
validators = self.w3.cc.get_validators(blockstamp)

module_id = StakingModuleId(int(kapi['module']['id']))

# Make sure even empty NO will be presented in dict
no_validators: ValidatorsByNodeOperator = {
(module_id, NodeOperatorId(int(operator['index']))): [] for operator in operators
}

# Map validators to their corresponding node operators
for validator in validators:
lido_key = keys.get(validator.validator.pubkey)
if not lido_key:
continue
global_id = (module_id, lido_key['operatorIndex'])
global_id = (module_id, lido_key.operatorIndex)
no_validators[global_id].append(
LidoValidator(
lido_id=LidoKey.from_response(**lido_key),
lido_id=lido_key,
**asdict(validator),
)
)
Expand Down
2 changes: 1 addition & 1 deletion tests/modules/accounting/test_validator_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def validator(index: int, exit_epoch: int, pubkey: HexStr, activation_epoch: int
lido_id=LidoKey(
key=pubkey,
depositSignature="",
operatorIndex=-1,
operatorIndex=NodeOperatorId(-1),
used=True,
moduleAddress="",
),
Expand Down
8 changes: 4 additions & 4 deletions tests/modules/ejector/test_exit_order_iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@

@pytest.mark.unit
def test_predicates():
def v(module_address, operator, index, activation_epoch) -> LidoValidator:
def v(module_address, operator: int, index, activation_epoch) -> LidoValidator:
validator = object.__new__(LidoValidator)
validator.lido_id = object.__new__(LidoKey)
validator.validator = object.__new__(ValidatorState)
validator.lido_id.moduleAddress = module_address
validator.lido_id.operatorIndex = operator
validator.lido_id.operatorIndex = NodeOperatorId(operator)
validator.index = index
validator.validator.activation_epoch = activation_epoch
return validator
Expand Down Expand Up @@ -75,12 +75,12 @@ def v(module_address, operator, index, activation_epoch) -> LidoValidator:

@pytest.mark.unit
def test_decrease_node_operator_stats():
def v(module_address, operator, index, activation_epoch) -> LidoValidator:
def v(module_address, operator: int, index, activation_epoch) -> LidoValidator:
validator = object.__new__(LidoValidator)
validator.lido_id = object.__new__(LidoKey)
validator.validator = object.__new__(ValidatorState)
validator.lido_id.moduleAddress = module_address
validator.lido_id.operatorIndex = operator
validator.lido_id.operatorIndex = NodeOperatorId(operator)
validator.index = index
validator.validator.activation_epoch = activation_epoch
return validator
Expand Down
Loading