-
Notifications
You must be signed in to change notification settings - Fork 27
framework: add configurable seeds/activation to XmssKeyManager (#129) #173
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 1 commit
4d5f8d2
ac15db6
55cec11
c404d1f
309b0b1
034d5ab
f513830
fddbd93
0cd8a9b
986285c
9700818
cafdbfb
028f786
3f12f66
85a0b83
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,16 +1,23 @@ | ||
| """XMSS key management utilities for testing.""" | ||
|
|
||
| from typing import NamedTuple, Optional | ||
| from __future__ import annotations | ||
|
|
||
| import random | ||
| from typing import Any, NamedTuple, Optional | ||
|
|
||
| from lean_spec.subspecs.containers import Attestation, Signature | ||
| from lean_spec.subspecs.containers.slot import Slot | ||
| from lean_spec.subspecs.koalabear import Fp, P | ||
| from lean_spec.subspecs.ssz.hash import hash_tree_root | ||
| from lean_spec.subspecs.xmss.constants import PRF_KEY_LENGTH, XmssConfig | ||
| from lean_spec.subspecs.xmss.containers import PublicKey, SecretKey | ||
| from lean_spec.subspecs.xmss.interface import ( | ||
| TEST_SIGNATURE_SCHEME, | ||
| GeneralizedXmssScheme, | ||
| ) | ||
| from lean_spec.types import ValidatorIndex | ||
| from lean_spec.subspecs.xmss.prf import Prf | ||
| from lean_spec.subspecs.xmss.utils import Rand | ||
| from lean_spec.types import Uint64, ValidatorIndex | ||
|
|
||
|
|
||
| class KeyPair(NamedTuple): | ||
|
|
@@ -23,24 +30,63 @@ class KeyPair(NamedTuple): | |
| """The validator's secret key (used for signing).""" | ||
|
|
||
|
|
||
| _KEY_CACHE: dict[tuple[int, int], KeyPair] = {} | ||
| _KEY_CACHE: dict[tuple[int, int, int, int | None], KeyPair] = {} | ||
| """ | ||
| Cache keys across tests to avoid regenerating them for the same validator/lifetime combo. | ||
|
|
||
| Key: (validator_index, num_active_epochs) -> KeyPair | ||
| Key: (validator_index, activation_epoch, num_active_epochs, seed) -> KeyPair | ||
| """ | ||
|
|
||
|
|
||
| def _to_int(value: int | Slot | Uint64 | None, default: int = 0) -> int: | ||
| """Normalize Slot/Uint64/int to int with an optional default.""" | ||
| if value is None: | ||
| return default | ||
| if isinstance(value, Slot): | ||
| return value.as_int() | ||
| return int(value) | ||
|
|
||
|
|
||
| class SeededRand(Rand): | ||
| """Deterministic Rand helper to make key generation repeatable in tests.""" | ||
|
|
||
| def __init__(self, config: XmssConfig, seed: int) -> None: | ||
| """Initialize with a deterministic seed.""" | ||
| super().__init__(config) | ||
| self._rng = random.Random(seed) | ||
|
|
||
| def field_elements(self, length: int) -> list[Fp]: | ||
| """Generate deterministic field elements from the seeded RNG.""" | ||
| return [Fp(value=self._rng.randrange(P)) for _ in range(length)] | ||
|
|
||
|
|
||
| class SeededPrf(Prf): | ||
| """Deterministic PRF helper for repeatable PRF key generation.""" | ||
|
|
||
| def __init__(self, config: XmssConfig, seed: int) -> None: | ||
| """Initialize with a deterministic seed.""" | ||
| super().__init__(config) | ||
| self._rng = random.Random(seed) | ||
|
|
||
| def key_gen(self) -> bytes: | ||
| """Generate a deterministic PRF key for repeatable tests.""" | ||
| # Use a deterministic stream rather than os.urandom for repeatability in tests. | ||
| return self._rng.randbytes(PRF_KEY_LENGTH) | ||
|
||
|
|
||
|
|
||
| class XmssKeyManager: | ||
| """Lazy key manager for test validators using XMSS signatures.""" | ||
|
|
||
| DEFAULT_MAX_SLOT = Slot(100) | ||
| """Default maximum slot horizon if not specified.""" | ||
| DEFAULT_ACTIVATION_EPOCH = 0 | ||
|
||
|
|
||
| def __init__( | ||
| self, | ||
| max_slot: Optional[Slot] = None, | ||
| scheme: GeneralizedXmssScheme = TEST_SIGNATURE_SCHEME, | ||
| default_activation_epoch: int | Slot | Uint64 = DEFAULT_ACTIVATION_EPOCH, | ||
|
||
| default_seed: int | None = 0, | ||
| ) -> None: | ||
| """ | ||
| Initialize the key manager. | ||
|
|
@@ -53,6 +99,11 @@ def __init__( | |
| scheme : GeneralizedXmssScheme, optional | ||
| The XMSS scheme to use. | ||
| Defaults to `TEST_SIGNATURE_SCHEME`. | ||
| default_activation_epoch : int | Slot | Uint64, optional | ||
| Activation epoch used when none is provided for key generation. | ||
| default_seed : int | None, optional | ||
| Seed for deterministic key generation. Set to None to use non-deterministic | ||
| randomness from the underlying XMSS scheme. | ||
|
|
||
| Notes: | ||
| ----- | ||
|
|
@@ -61,7 +112,92 @@ def __init__( | |
| """ | ||
| self.max_slot = max_slot if max_slot is not None else self.DEFAULT_MAX_SLOT | ||
| self.scheme = scheme | ||
| self.default_activation_epoch = _to_int( | ||
| default_activation_epoch, self.DEFAULT_ACTIVATION_EPOCH | ||
| ) | ||
| self.default_seed = default_seed | ||
| self._key_pairs: dict[ValidatorIndex, KeyPair] = {} | ||
| self._key_metadata: dict[ValidatorIndex, dict[str, Any]] = {} | ||
| self._schemes_by_seed: dict[int, GeneralizedXmssScheme] = {} | ||
|
|
||
| @property | ||
| def default_num_active_epochs(self) -> int: | ||
| """Default lifetime derived from the configured max_slot.""" | ||
| return self.max_slot.as_int() + 1 | ||
|
Comment on lines
113
to
115
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think the function naming is misleading here. You can maybe name it |
||
|
|
||
| def _scheme_for_seed(self, seed: int | None) -> GeneralizedXmssScheme: | ||
| """ | ||
| Return a scheme instance appropriate for the provided seed. | ||
|
|
||
| A deterministic scheme (SeededRand + SeededPrf) is returned when a specific | ||
| seed is provided; otherwise the base scheme is used. | ||
| """ | ||
| if seed is None: | ||
| return self.scheme | ||
|
|
||
| if seed not in self._schemes_by_seed: | ||
| self._schemes_by_seed[seed] = GeneralizedXmssScheme( | ||
| config=self.scheme.config, | ||
| prf=SeededPrf(self.scheme.config, seed), | ||
| hasher=self.scheme.hasher, | ||
| merkle_tree=self.scheme.merkle_tree, | ||
| encoder=self.scheme.encoder, | ||
| rand=SeededRand(self.scheme.config, seed), | ||
| ) | ||
|
|
||
| return self._schemes_by_seed[seed] | ||
|
||
|
|
||
| def create_and_store_key_pair( | ||
| self, | ||
| validator_index: ValidatorIndex, | ||
| *, | ||
| activation_epoch: int | Slot | Uint64 | None = None, | ||
| num_active_epochs: int | Slot | Uint64 | None = None, | ||
| seed: int | None = None, | ||
| ) -> KeyPair: | ||
| """ | ||
| Generate and store a key pair with explicit control over key generation. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| validator_index : ValidatorIndex | ||
| The validator for whom a key pair should be generated. | ||
| activation_epoch : int | Slot | Uint64, optional | ||
| First epoch for which the key is valid. Defaults to `default_activation_epoch`. | ||
| num_active_epochs : int | Slot | Uint64, optional | ||
| Number of consecutive epochs the key should remain active. | ||
| Defaults to `max_slot + 1` (to include genesis). | ||
| seed : int | None, optional | ||
| Seed used for deterministic key generation. If None, the base scheme's | ||
| randomness is used. | ||
| """ | ||
| activation_epoch_int = _to_int(activation_epoch, self.default_activation_epoch) | ||
| num_active_epochs_int = _to_int(num_active_epochs, self.default_num_active_epochs) | ||
| key_seed = seed if seed is not None else self.default_seed | ||
|
|
||
| scheme = self._scheme_for_seed(key_seed) | ||
|
|
||
| cache_key = ( | ||
| int(validator_index), | ||
| activation_epoch_int, | ||
| num_active_epochs_int, | ||
| key_seed, | ||
| ) | ||
|
|
||
| if cache_key in _KEY_CACHE: | ||
| key_pair = _KEY_CACHE[cache_key] | ||
| else: | ||
| pk, sk = scheme.key_gen(Uint64(activation_epoch_int), Uint64(num_active_epochs_int)) | ||
| key_pair = KeyPair(public=pk, secret=sk) | ||
| _KEY_CACHE[cache_key] = key_pair | ||
|
|
||
| self._key_pairs[validator_index] = key_pair | ||
| self._key_metadata[validator_index] = { | ||
| "activation_epoch": activation_epoch_int, | ||
| "num_active_epochs": num_active_epochs_int, | ||
| "seed": key_seed, | ||
| } | ||
| return key_pair | ||
tcoratger marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| def __getitem__(self, validator_index: ValidatorIndex) -> KeyPair: | ||
| """ | ||
|
|
@@ -83,37 +219,10 @@ def __getitem__(self, validator_index: ValidatorIndex) -> KeyPair: | |
| - Keys are deterministic for testing (`seed=0`). | ||
| - Lifetime = `max_slot + 1` to include the genesis slot. | ||
| """ | ||
| # Return cached keys if they exist. | ||
| if validator_index in self._key_pairs: | ||
| return self._key_pairs[validator_index] | ||
|
|
||
| # Generate New Key Pair | ||
| # | ||
| # XMSS requires knowing the total number of signatures in advance. | ||
| # We use max_slot + 1 as the lifetime since: | ||
| # - Validators may sign once per slot (attestations) | ||
| # - We include slot 0 (genesis) in the count | ||
| num_active_epochs = self.max_slot.as_int() + 1 | ||
|
|
||
| # Check global cache first (keys are reused across tests) | ||
| cache_key = (int(validator_index), num_active_epochs) | ||
| if cache_key in _KEY_CACHE: | ||
| key_pair = _KEY_CACHE[cache_key] | ||
| self._key_pairs[validator_index] = key_pair | ||
| return key_pair | ||
|
|
||
| # Generate the key pair using the default XMSS scheme. | ||
| # | ||
| # The seed is set to 0 for deterministic test keys. | ||
| from lean_spec.types import Uint64 | ||
|
|
||
| pk, sk = self.scheme.key_gen(Uint64(0), Uint64(num_active_epochs)) | ||
|
|
||
| # Store as a cohesive unit and return. | ||
| key_pair = KeyPair(public=pk, secret=sk) | ||
| _KEY_CACHE[cache_key] = key_pair # Cache globally for reuse across tests | ||
| self._key_pairs[validator_index] = key_pair | ||
| return key_pair | ||
| return self.create_and_store_key_pair(validator_index) | ||
|
|
||
| def sign_attestation(self, attestation: Attestation) -> Signature: | ||
| """ | ||
|
|
@@ -143,27 +252,37 @@ def sign_attestation(self, attestation: Attestation) -> Signature: | |
| # Get the current secret key | ||
| sk = key_pair.secret | ||
|
|
||
| metadata = self._key_metadata.get( | ||
| validator_id, | ||
| { | ||
| "seed": self.default_seed, | ||
| "activation_epoch": self.default_activation_epoch, | ||
| "num_active_epochs": self.default_num_active_epochs, | ||
| }, | ||
| ) | ||
| scheme = self._scheme_for_seed(metadata.get("seed")) | ||
|
|
||
| # Map the attestation slot to an XMSS epoch. | ||
| # | ||
| # Each slot gets its own epoch to avoid key reuse. | ||
| epoch = attestation.data.slot | ||
|
|
||
| # Loop until the epoch is inside the prepared interval | ||
| prepared_interval = self.scheme.get_prepared_interval(sk) | ||
| prepared_interval = scheme.get_prepared_interval(sk) | ||
| while int(epoch) not in prepared_interval: | ||
| # Check if we're advancing past the key's total lifetime | ||
| activation_interval = self.scheme.get_activation_interval(sk) | ||
| activation_interval = scheme.get_activation_interval(sk) | ||
| if prepared_interval.stop >= activation_interval.stop: | ||
| raise ValueError( | ||
| f"Cannot sign for epoch {epoch}: " | ||
| f"it is beyond the key's max lifetime {activation_interval.stop}" | ||
| ) | ||
|
|
||
| # Advance the key and get the new key object | ||
| sk = self.scheme.advance_preparation(sk) | ||
| sk = scheme.advance_preparation(sk) | ||
|
|
||
| # Update the prepared interval for the next loop check | ||
| prepared_interval = self.scheme.get_prepared_interval(sk) | ||
| prepared_interval = scheme.get_prepared_interval(sk) | ||
|
|
||
| # Update the cached key pair with the new, advanced secret key. | ||
| # This ensures the *next* call to sign() uses the advanced state. | ||
|
|
@@ -175,18 +294,10 @@ def sign_attestation(self, attestation: Attestation) -> Signature: | |
| message = bytes(hash_tree_root(attestation)) | ||
|
|
||
| # Generate the XMSS signature using the validator's (now prepared) secret key. | ||
| xmss_sig = self.scheme.sign(sk, epoch, message) | ||
|
|
||
| # Convert the signature to the wire format (byte array). | ||
| signature_bytes = xmss_sig.to_bytes(self.scheme.config) | ||
|
|
||
| # Ensure the signature meets the consensus spec length (3100 bytes). | ||
| # | ||
| # This is necessary when using TEST_CONFIG (796 bytes) vs PROD_CONFIG. | ||
| # Padding with zeros on the right maintains compatibility. | ||
| padded_bytes = signature_bytes.ljust(Signature.LENGTH, b"\x00") | ||
| xmss_sig = scheme.sign(sk, epoch, message) | ||
|
|
||
| return Signature(padded_bytes) | ||
| # Convert to the consensus Signature container (handles padding internally). | ||
| return Signature.from_xmss(xmss_sig, scheme) | ||
|
|
||
| def get_public_key(self, validator_index: ValidatorIndex) -> PublicKey: | ||
| """ | ||
|
|
@@ -225,3 +336,30 @@ def __contains__(self, validator_index: ValidatorIndex) -> bool: | |
| def __len__(self) -> int: | ||
| """Return the number of validators with generated keys.""" | ||
| return len(self._key_pairs) | ||
|
|
||
| def export_test_vectors(self, include_private_keys: bool = False) -> list[dict[str, Any]]: | ||
| """ | ||
| Export generated keys in a JSON-serializable structure for downstream clients. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| include_private_keys : bool | ||
| When True, include the full secret key dump; otherwise only public data. | ||
| """ | ||
| vectors: list[dict[str, Any]] = [] | ||
| for validator_index, key_pair in self._key_pairs.items(): | ||
| meta = self._key_metadata.get(validator_index, {}) | ||
| entry: dict[str, Any] = { | ||
| "validator_index": int(validator_index), | ||
| "activation_epoch": meta.get("activation_epoch"), | ||
| "num_active_epochs": meta.get("num_active_epochs"), | ||
| "seed": meta.get("seed"), | ||
| "public_key": key_pair.public.to_bytes(self.scheme.config).hex(), | ||
| } | ||
| if include_private_keys: | ||
| # Pydantic models are JSON-serializable; keep the raw dump for full fidelity. | ||
| entry["secret_key"] = key_pair.secret.model_dump(mode="json") | ||
|
|
||
| vectors.append(entry) | ||
|
|
||
| return vectors | ||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -28,3 +28,22 @@ def verify( | |||||
| return scheme.verify(public_key, epoch, message, signature) | ||||||
| except Exception: | ||||||
| return False | ||||||
|
|
||||||
| @classmethod | ||||||
| def from_xmss( | ||||||
| cls, xmss_signature: XmssSignature, scheme: GeneralizedXmssScheme = TEST_SIGNATURE_SCHEME | ||||||
| ) -> "Signature": | ||||||
|
||||||
| """ | ||||||
| Create a consensus `Signature` container from an XMSS signature object. | ||||||
|
|
||||||
| Handles padding to the fixed 3100-byte length required by the consensus layer, | ||||||
| delegating all encoding details to the XMSS container itself. | ||||||
|
||||||
| Handles padding to the fixed 3100-byte length required by the consensus layer, | |
| delegating all encoding details to the XMSS container itself. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@vyakart ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should not have this, this is not clean to have. We should probably rely only on Uint64 or int and not have this. This is a primitive conversion stuff that should not be inside the testing framework folder.