diff --git a/packages/testing/src/consensus_testing/keys.py b/packages/testing/src/consensus_testing/keys.py index a223ab20..65b95cdf 100644 --- a/packages/testing/src/consensus_testing/keys.py +++ b/packages/testing/src/consensus_testing/keys.py @@ -1,6 +1,8 @@ """XMSS key management utilities for testing.""" -from typing import NamedTuple, Optional +from __future__ import annotations + +from typing import Any, NamedTuple, Optional from lean_spec.subspecs.containers import Attestation, Signature from lean_spec.subspecs.containers.slot import Slot @@ -10,7 +12,7 @@ TEST_SIGNATURE_SCHEME, GeneralizedXmssScheme, ) -from lean_spec.types import ValidatorIndex +from lean_spec.types import Uint64, ValidatorIndex class KeyPair(NamedTuple): @@ -23,11 +25,11 @@ class KeyPair(NamedTuple): """The validator's secret key (used for signing).""" -_KEY_CACHE: dict[tuple[int, int], KeyPair] = {} +_KEY_CACHE: dict[tuple[int, int, int, int], KeyPair] = {} """ Cache keys across tests to avoid regenerating them for the same validator/lifetime combo. -Key: (validator_index, num_active_epochs) -> KeyPair +Key: (validator_index, activation_epoch, num_active_epochs, seed) -> KeyPair """ @@ -36,9 +38,17 @@ class XmssKeyManager: DEFAULT_MAX_SLOT = Slot(100) """Default maximum slot horizon if not specified.""" + DEFAULT_ACTIVATION_EPOCH = Uint64(0) + """Default activation epoch when none is provided.""" + DEFAULT_SEED = 0 + """Default deterministic seed when none is provided.""" def __init__( self, + activation_epoch: Optional[Uint64 | Slot | int] = None, + *, + default_activation_epoch: Optional[Uint64 | Slot | int] = None, + default_seed: Optional[int] = None, max_slot: Optional[Slot] = None, scheme: GeneralizedXmssScheme = TEST_SIGNATURE_SCHEME, ) -> None: @@ -47,6 +57,12 @@ def __init__( Parameters ---------- + activation_epoch : Uint64 | Slot | int, optional + Deprecated alias for `default_activation_epoch`. + default_activation_epoch : Uint64 | Slot | int, optional + Activation epoch used when none is provided for key generation. + default_seed : int, optional + Seed value used when none is provided for key generation. max_slot : Slot, optional Highest slot number for which keys must remain valid. Defaults to `Slot(100)`. @@ -58,14 +74,118 @@ def __init__( ----- Internally, keys are stored in a single dictionary: `{ValidatorIndex → KeyPair}`. + + This class manages stateful XMSS keys for testing, handling the complexity of + epoch updates and key evolution that stateless helpers cannot provide. """ self.max_slot = max_slot if max_slot is not None else self.DEFAULT_MAX_SLOT self.scheme = scheme + if activation_epoch is not None and default_activation_epoch is not None: + raise ValueError("Use either activation_epoch or default_activation_epoch, not both.") + effective_activation = ( + default_activation_epoch if default_activation_epoch is not None else activation_epoch + ) + activation_value = ( + self.DEFAULT_ACTIVATION_EPOCH + if effective_activation is None + else self._coerce_uint64(effective_activation) + ) + self._default_activation_epoch = activation_value + self._default_seed = int(default_seed) if default_seed is not None else self.DEFAULT_SEED self._key_pairs: dict[ValidatorIndex, KeyPair] = {} + self._key_metadata: dict[ValidatorIndex, dict[str, Any]] = {} + + @staticmethod + def _coerce_uint64(value: Uint64 | Slot | int) -> Uint64: + """Convert supported numeric inputs to Uint64.""" + if isinstance(value, Uint64): + return Uint64(int(value)) + if isinstance(value, Slot): + return Uint64(value.as_int()) + return Uint64(int(value)) + + @property + def default_max_epoch(self) -> int: + """Default lifetime derived from the manager's configured max_slot.""" + return self.default_num_active_epochs + + @property + def default_num_active_epochs(self) -> int: + """Number of epochs keys stay active when not overridden.""" + return self.max_slot.as_int() + 1 + + @property + def default_activation_epoch(self) -> int: + """Default activation epoch as an int.""" + return int(self._default_activation_epoch) + + @property + def default_seed(self) -> int: + """Default seed used when none is provided.""" + return self._default_seed + + def create_and_store_key_pair( + self, + validator_index: ValidatorIndex, + *, + activation_epoch: Optional[Uint64 | Slot | int] = None, + num_active_epochs: Optional[Uint64 | Slot | int] = None, + seed: Optional[int] = None, + ) -> KeyPair: + """ + Generate and store a key pair with explicit control over key generation. + + Parameters + ---------- + validator_index : ValidatorIndex + The validator for whom a key pair should be generated. + activation_epoch : Uint64 | Slot | int, optional + First epoch for which the key is valid. Defaults to the manager's + configured `default_activation_epoch`. + num_active_epochs : Uint64 | Slot | int, optional + Number of consecutive epochs the key should remain active. + Defaults to `default_num_active_epochs` (derived from `max_slot` to include genesis). + seed : int, optional + Deterministic seed for caching/reuse. Defaults to manager's `default_seed`. + """ + activation_epoch_val = ( + self._coerce_uint64(activation_epoch) + if activation_epoch is not None + else self._default_activation_epoch + ) + num_active_epochs_val = ( + self._coerce_uint64(num_active_epochs) + if num_active_epochs is not None + else self._coerce_uint64(self.default_num_active_epochs) + ) + seed_val = int(seed) if seed is not None else self.default_seed + + cache_key = ( + int(validator_index), + int(activation_epoch_val), + int(num_active_epochs_val), + seed_val, + ) + + if cache_key in _KEY_CACHE: + key_pair = _KEY_CACHE[cache_key] + else: + pk, sk = self.scheme.key_gen(activation_epoch_val, num_active_epochs_val) + key_pair = KeyPair(public=pk, secret=sk) + _KEY_CACHE[cache_key] = key_pair + + self._key_pairs[validator_index] = key_pair + self._key_metadata[validator_index] = { + "activation_epoch": int(activation_epoch_val), + "num_active_epochs": int(num_active_epochs_val), + "seed": seed_val, + } + # TODO: support multiple keys per validator keyed by activation_epoch. + return key_pair def __getitem__(self, validator_index: ValidatorIndex) -> KeyPair: """ - Retrieve or lazily generate a validator’s key pair. + Retrieve or lazily generate a validator's key pair. Parameters ---------- @@ -75,45 +195,18 @@ def __getitem__(self, validator_index: ValidatorIndex) -> KeyPair: Returns: ------- KeyPair - The validator’s XMSS key pair. + XMSS key pair associated with the validator. Notes: ----- - Generates a new key if none exists. - Keys are deterministic for testing (`seed=0`). - - Lifetime = `max_slot + 1` to include the genesis slot. + - Lifetime defaults to `default_num_active_epochs` to include the genesis slot. """ - # Return cached keys if they exist. if validator_index in self._key_pairs: return self._key_pairs[validator_index] - # Generate New Key Pair - # - # XMSS requires knowing the total number of signatures in advance. - # We use max_slot + 1 as the lifetime since: - # - Validators may sign once per slot (attestations) - # - We include slot 0 (genesis) in the count - num_active_epochs = self.max_slot.as_int() + 1 - - # Check global cache first (keys are reused across tests) - cache_key = (int(validator_index), num_active_epochs) - if cache_key in _KEY_CACHE: - key_pair = _KEY_CACHE[cache_key] - self._key_pairs[validator_index] = key_pair - return key_pair - - # Generate the key pair using the default XMSS scheme. - # - # The seed is set to 0 for deterministic test keys. - from lean_spec.types import Uint64 - - pk, sk = self.scheme.key_gen(Uint64(0), Uint64(num_active_epochs)) - - # Store as a cohesive unit and return. - key_pair = KeyPair(public=pk, secret=sk) - _KEY_CACHE[cache_key] = key_pair # Cache globally for reuse across tests - self._key_pairs[validator_index] = key_pair - return key_pair + return self.create_and_store_key_pair(validator_index) def sign_attestation(self, attestation: Attestation) -> Signature: """ @@ -177,16 +270,40 @@ def sign_attestation(self, attestation: Attestation) -> Signature: # Generate the XMSS signature using the validator's (now prepared) secret key. xmss_sig = self.scheme.sign(sk, epoch, message) - # Convert the signature to the wire format (byte array). - signature_bytes = xmss_sig.to_bytes(self.scheme.config) + # Convert to the consensus Signature container (handles padding internally). + return Signature.from_xmss(xmss_sig, self.scheme) - # Ensure the signature meets the consensus spec length (3116 bytes). - # - # This is necessary when using TEST_CONFIG (796 bytes) vs PROD_CONFIG. - # Padding with zeros on the right maintains compatibility. - padded_bytes = signature_bytes.ljust(Signature.LENGTH, b"\x00") + def export_test_vectors(self, include_private_keys: bool = False) -> list[dict[str, Any]]: + """ + Export generated keys as dictionaries suitable for JSON test vectors. + + Parameters + ---------- + include_private_keys : bool, optional + When True, include SecretKey contents for debugging fixtures. - return Signature(padded_bytes) + Returns: + ------- + list[dict[str, Any]] + A list of entries keyed by validator_index with metadata and hex keys. + """ + vectors: list[dict[str, Any]] = [] + for validator_index in sorted(self._key_pairs.keys(), key=int): + key_pair = self._key_pairs[validator_index] + metadata = self._key_metadata.get(validator_index, {}) + entry: dict[str, Any] = { + "validator_index": int(validator_index), + "activation_epoch": metadata.get("activation_epoch", self.default_activation_epoch), + "num_active_epochs": metadata.get( + "num_active_epochs", self.default_num_active_epochs + ), + "seed": metadata.get("seed", self.default_seed), + "public_key": key_pair.public.to_bytes(self.scheme.config).hex(), + } + if include_private_keys: + entry["secret_key"] = key_pair.secret.model_dump() + vectors.append(entry) + return vectors def get_public_key(self, validator_index: ValidatorIndex) -> PublicKey: """ @@ -199,7 +316,7 @@ def get_public_key(self, validator_index: ValidatorIndex) -> PublicKey: Returns: ------- PublicKey - The validator’s public key. + Public key for the validator. """ return self[validator_index].public diff --git a/packages/testing/src/consensus_testing/test_fixtures/fork_choice.py b/packages/testing/src/consensus_testing/test_fixtures/fork_choice.py index ea155e02..fc012570 100644 --- a/packages/testing/src/consensus_testing/test_fixtures/fork_choice.py +++ b/packages/testing/src/consensus_testing/test_fixtures/fork_choice.py @@ -112,6 +112,20 @@ class ForkChoiceTest(BaseConsensusFixture): valid up to the highest slot used in any block or attestation. """ + key_manager_seed: int | None = None + """ + Optional deterministic seed to pass to the XMSS key manager. + + When set, validators' keys and signatures become reproducible across runs. + """ + + key_manager_activation_epoch: Slot | None = None + """ + Optional activation epoch to use when generating keys. + + Defaults to the key manager's own default (0) when unset. + """ + @model_validator(mode="after") def set_anchor_block_default(self) -> ForkChoiceTest: """ @@ -183,10 +197,24 @@ def make_fixture(self) -> ForkChoiceTest: # Use shared key manager if it has sufficient capacity, otherwise create a new one # This optimizes performance by reusing keys across tests when possible shared_key_manager = _get_shared_key_manager() + use_shared = ( + self.key_manager_seed is None + and self.key_manager_activation_epoch is None + and self.max_slot <= shared_key_manager.max_slot + ) key_manager = ( shared_key_manager - if self.max_slot <= shared_key_manager.max_slot - else XmssKeyManager(max_slot=self.max_slot, scheme=TEST_SIGNATURE_SCHEME) + if use_shared + else XmssKeyManager( + max_slot=self.max_slot, + scheme=TEST_SIGNATURE_SCHEME, + default_seed=self.key_manager_seed, + default_activation_epoch=( + self.key_manager_activation_epoch + if self.key_manager_activation_epoch is not None + else XmssKeyManager.DEFAULT_ACTIVATION_EPOCH + ), + ) ) # Update validator pubkeys to match key_manager's generated keys diff --git a/src/lean_spec/subspecs/containers/signature.py b/src/lean_spec/subspecs/containers/signature.py index 697b0231..cf6a6d3c 100644 --- a/src/lean_spec/subspecs/containers/signature.py +++ b/src/lean_spec/subspecs/containers/signature.py @@ -28,3 +28,22 @@ def verify( return scheme.verify(public_key, epoch, message, signature) except Exception: return False + + @classmethod + def from_xmss( + cls, xmss_signature: XmssSignature, scheme: GeneralizedXmssScheme = TEST_SIGNATURE_SCHEME + ) -> Signature: + """ + Create a consensus `Signature` container from an XMSS signature object. + + Applies the consensus-layer fixed-length padding, delegating all encoding + details to the XMSS container itself. + """ + raw = xmss_signature.to_bytes(scheme.config) + if len(raw) > cls.LENGTH: + raise ValueError( + f"XMSS signature length {len(raw)} exceeds container size {cls.LENGTH}" + ) + + # Pad on the right to the fixed-length container expected by consensus. + return cls(raw.ljust(cls.LENGTH, b"\x00")) diff --git a/tests/lean_spec/subspecs/containers/test_signature.py b/tests/lean_spec/subspecs/containers/test_signature.py new file mode 100644 index 00000000..8bb3bff8 --- /dev/null +++ b/tests/lean_spec/subspecs/containers/test_signature.py @@ -0,0 +1,29 @@ +"""Tests for consensus Signature container.""" + +from lean_spec.subspecs.containers import Signature +from lean_spec.subspecs.xmss.interface import TEST_SIGNATURE_SCHEME +from lean_spec.types import Uint64 + + +class TestSignatureFromXmss: + """Tests for Signature.from_xmss conversion method.""" + + def test_from_xmss_roundtrip_with_verify(self) -> None: + """Test that a signature created via from_xmss can be verified.""" + + # Generate a test key pair + pk, sk = TEST_SIGNATURE_SCHEME.key_gen(Uint64(0), Uint64(10)) + + # Create a test message (must be exactly 32 bytes) + message = b"test message for signing123456\x00\x00" # 32 bytes + assert len(message) == 32 + epoch = Uint64(0) + + # Sign the message + xmss_sig = TEST_SIGNATURE_SCHEME.sign(sk, epoch, message) + + # Convert to consensus signature + consensus_sig = Signature.from_xmss(xmss_sig, TEST_SIGNATURE_SCHEME) + + # Verify using the consensus signature's verify method + assert consensus_sig.verify(pk, epoch, message, TEST_SIGNATURE_SCHEME) diff --git a/tests/test_consensus_testing_keys.py b/tests/test_consensus_testing_keys.py new file mode 100644 index 00000000..ab1707c4 --- /dev/null +++ b/tests/test_consensus_testing_keys.py @@ -0,0 +1,58 @@ +import pytest +from consensus_testing.keys import XmssKeyManager + +from lean_spec.types import ValidatorIndex + + +def test_seeded_key_generation_is_deterministic() -> None: + manager_a = XmssKeyManager(default_seed=42) + manager_b = XmssKeyManager(default_seed=42) + manager_c = XmssKeyManager(default_seed=43) + + pair_a = manager_a.create_and_store_key_pair(ValidatorIndex(0)) + pair_b = manager_b.create_and_store_key_pair(ValidatorIndex(0)) + pair_c = manager_c.create_and_store_key_pair(ValidatorIndex(0)) + + assert pair_a.public == pair_b.public + assert pair_a.secret == pair_b.secret + assert pair_a.public != pair_c.public + assert pair_a.secret != pair_c.secret + + +def test_export_test_vectors_shape_and_metadata() -> None: + manager = XmssKeyManager(default_seed=7) + # Explicitly control first key parameters + manager.create_and_store_key_pair( + ValidatorIndex(1), + activation_epoch=5, + num_active_epochs=10, + seed=99, + ) + # Use defaults for second key + manager.create_and_store_key_pair(ValidatorIndex(2)) + + vectors = manager.export_test_vectors(include_private_keys=True) + assert {entry["validator_index"] for entry in vectors} == {1, 2} + + by_validator = {entry["validator_index"]: entry for entry in vectors} + first = by_validator[1] + second = by_validator[2] + + # Public key should be hex-encoded and match the configured length. + pk_len = manager.scheme.config.PUBLIC_KEY_LEN_BYTES * 2 + assert len(first["public_key"]) == pk_len + assert len(second["public_key"]) == pk_len + + # Metadata should reflect the parameters used to create the keys. + assert first["activation_epoch"] == 5 + assert first["num_active_epochs"] == 10 + assert first["seed"] == 99 + + assert second["activation_epoch"] == manager.default_activation_epoch + assert second["num_active_epochs"] == manager.default_num_active_epochs + assert second["seed"] == manager.default_seed + + # Secret key is only present when requested. + assert "secret_key" in first + assert isinstance(first["secret_key"], dict) + assert "secret_key" in second