Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
232 changes: 185 additions & 47 deletions packages/testing/src/consensus_testing/keys.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,23 @@
"""XMSS key management utilities for testing."""

from typing import NamedTuple, Optional
from __future__ import annotations

import random
from typing import Any, NamedTuple, Optional

from lean_spec.subspecs.containers import Attestation, Signature
from lean_spec.subspecs.containers.slot import Slot
from lean_spec.subspecs.koalabear import Fp, P
from lean_spec.subspecs.ssz.hash import hash_tree_root
from lean_spec.subspecs.xmss.constants import PRF_KEY_LENGTH, XmssConfig
from lean_spec.subspecs.xmss.containers import PublicKey, SecretKey
from lean_spec.subspecs.xmss.interface import (
TEST_SIGNATURE_SCHEME,
GeneralizedXmssScheme,
)
from lean_spec.types import ValidatorIndex
from lean_spec.subspecs.xmss.prf import Prf
from lean_spec.subspecs.xmss.utils import Rand
from lean_spec.types import Uint64, ValidatorIndex


class KeyPair(NamedTuple):
Expand All @@ -23,24 +30,63 @@ class KeyPair(NamedTuple):
"""The validator's secret key (used for signing)."""


_KEY_CACHE: dict[tuple[int, int], KeyPair] = {}
_KEY_CACHE: dict[tuple[int, int, int, int | None], KeyPair] = {}
"""
Cache keys across tests to avoid regenerating them for the same validator/lifetime combo.

Key: (validator_index, num_active_epochs) -> KeyPair
Key: (validator_index, activation_epoch, num_active_epochs, seed) -> KeyPair
"""


def _to_int(value: int | Slot | Uint64 | None, default: int = 0) -> int:
"""Normalize Slot/Uint64/int to int with an optional default."""
if value is None:
return default
if isinstance(value, Slot):
return value.as_int()
return int(value)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should not have this, this is not clean to have. We should probably rely only on Uint64 or int and not have this. This is a primitive conversion stuff that should not be inside the testing framework folder.



class SeededRand(Rand):
"""Deterministic Rand helper to make key generation repeatable in tests."""

def __init__(self, config: XmssConfig, seed: int) -> None:
"""Initialize with a deterministic seed."""
super().__init__(config)
self._rng = random.Random(seed)

def field_elements(self, length: int) -> list[Fp]:
"""Generate deterministic field elements from the seeded RNG."""
return [Fp(value=self._rng.randrange(P)) for _ in range(length)]


class SeededPrf(Prf):
"""Deterministic PRF helper for repeatable PRF key generation."""

def __init__(self, config: XmssConfig, seed: int) -> None:
"""Initialize with a deterministic seed."""
super().__init__(config)
self._rng = random.Random(seed)

def key_gen(self) -> bytes:
"""Generate a deterministic PRF key for repeatable tests."""
# Use a deterministic stream rather than os.urandom for repeatability in tests.
return self._rng.randbytes(PRF_KEY_LENGTH)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should not be allowed because we want the PRF and RAND to be based on the XMSS paper, so we can't just put whatever we want there. Here, I've created a pull request to prevent this behavior: #175

Now this should error out



class XmssKeyManager:
"""Lazy key manager for test validators using XMSS signatures."""

DEFAULT_MAX_SLOT = Slot(100)
"""Default maximum slot horizon if not specified."""
DEFAULT_ACTIVATION_EPOCH = 0
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need a doc string to explain


def __init__(
self,
max_slot: Optional[Slot] = None,
scheme: GeneralizedXmssScheme = TEST_SIGNATURE_SCHEME,
default_activation_epoch: int | Slot | Uint64 = DEFAULT_ACTIVATION_EPOCH,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should probably have only Uint64 here to avoid additional complexities

default_seed: int | None = 0,
) -> None:
"""
Initialize the key manager.
Expand All @@ -53,6 +99,11 @@ def __init__(
scheme : GeneralizedXmssScheme, optional
The XMSS scheme to use.
Defaults to `TEST_SIGNATURE_SCHEME`.
default_activation_epoch : int | Slot | Uint64, optional
Activation epoch used when none is provided for key generation.
default_seed : int | None, optional
Seed for deterministic key generation. Set to None to use non-deterministic
randomness from the underlying XMSS scheme.

Notes:
-----
Expand All @@ -61,7 +112,92 @@ def __init__(
"""
self.max_slot = max_slot if max_slot is not None else self.DEFAULT_MAX_SLOT
self.scheme = scheme
self.default_activation_epoch = _to_int(
default_activation_epoch, self.DEFAULT_ACTIVATION_EPOCH
)
self.default_seed = default_seed
self._key_pairs: dict[ValidatorIndex, KeyPair] = {}
self._key_metadata: dict[ValidatorIndex, dict[str, Any]] = {}
self._schemes_by_seed: dict[int, GeneralizedXmssScheme] = {}

@property
def default_num_active_epochs(self) -> int:
"""Default lifetime derived from the configured max_slot."""
return self.max_slot.as_int() + 1
Comment on lines 113 to 115
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the function naming is misleading here. You can maybe name it default_max_epoch and use DEFAULT_MAX_SLOT here.


def _scheme_for_seed(self, seed: int | None) -> GeneralizedXmssScheme:
"""
Return a scheme instance appropriate for the provided seed.

A deterministic scheme (SeededRand + SeededPrf) is returned when a specific
seed is provided; otherwise the base scheme is used.
"""
if seed is None:
return self.scheme

if seed not in self._schemes_by_seed:
self._schemes_by_seed[seed] = GeneralizedXmssScheme(
config=self.scheme.config,
prf=SeededPrf(self.scheme.config, seed),
hasher=self.scheme.hasher,
merkle_tree=self.scheme.merkle_tree,
encoder=self.scheme.encoder,
rand=SeededRand(self.scheme.config, seed),
)

return self._schemes_by_seed[seed]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think that we need to use this


def create_and_store_key_pair(
self,
validator_index: ValidatorIndex,
*,
activation_epoch: int | Slot | Uint64 | None = None,
num_active_epochs: int | Slot | Uint64 | None = None,
seed: int | None = None,
) -> KeyPair:
"""
Generate and store a key pair with explicit control over key generation.

Parameters
----------
validator_index : ValidatorIndex
The validator for whom a key pair should be generated.
activation_epoch : int | Slot | Uint64, optional
First epoch for which the key is valid. Defaults to `default_activation_epoch`.
num_active_epochs : int | Slot | Uint64, optional
Number of consecutive epochs the key should remain active.
Defaults to `max_slot + 1` (to include genesis).
seed : int | None, optional
Seed used for deterministic key generation. If None, the base scheme's
randomness is used.
"""
activation_epoch_int = _to_int(activation_epoch, self.default_activation_epoch)
num_active_epochs_int = _to_int(num_active_epochs, self.default_num_active_epochs)
key_seed = seed if seed is not None else self.default_seed

scheme = self._scheme_for_seed(key_seed)

cache_key = (
int(validator_index),
activation_epoch_int,
num_active_epochs_int,
key_seed,
)

if cache_key in _KEY_CACHE:
key_pair = _KEY_CACHE[cache_key]
else:
pk, sk = scheme.key_gen(Uint64(activation_epoch_int), Uint64(num_active_epochs_int))
key_pair = KeyPair(public=pk, secret=sk)
_KEY_CACHE[cache_key] = key_pair

self._key_pairs[validator_index] = key_pair
self._key_metadata[validator_index] = {
"activation_epoch": activation_epoch_int,
"num_active_epochs": num_active_epochs_int,
"seed": key_seed,
}
return key_pair

def __getitem__(self, validator_index: ValidatorIndex) -> KeyPair:
"""
Expand All @@ -83,37 +219,10 @@ def __getitem__(self, validator_index: ValidatorIndex) -> KeyPair:
- Keys are deterministic for testing (`seed=0`).
- Lifetime = `max_slot + 1` to include the genesis slot.
"""
# Return cached keys if they exist.
if validator_index in self._key_pairs:
return self._key_pairs[validator_index]

# Generate New Key Pair
#
# XMSS requires knowing the total number of signatures in advance.
# We use max_slot + 1 as the lifetime since:
# - Validators may sign once per slot (attestations)
# - We include slot 0 (genesis) in the count
num_active_epochs = self.max_slot.as_int() + 1

# Check global cache first (keys are reused across tests)
cache_key = (int(validator_index), num_active_epochs)
if cache_key in _KEY_CACHE:
key_pair = _KEY_CACHE[cache_key]
self._key_pairs[validator_index] = key_pair
return key_pair

# Generate the key pair using the default XMSS scheme.
#
# The seed is set to 0 for deterministic test keys.
from lean_spec.types import Uint64

pk, sk = self.scheme.key_gen(Uint64(0), Uint64(num_active_epochs))

# Store as a cohesive unit and return.
key_pair = KeyPair(public=pk, secret=sk)
_KEY_CACHE[cache_key] = key_pair # Cache globally for reuse across tests
self._key_pairs[validator_index] = key_pair
return key_pair
return self.create_and_store_key_pair(validator_index)

def sign_attestation(self, attestation: Attestation) -> Signature:
"""
Expand Down Expand Up @@ -143,27 +252,37 @@ def sign_attestation(self, attestation: Attestation) -> Signature:
# Get the current secret key
sk = key_pair.secret

metadata = self._key_metadata.get(
validator_id,
{
"seed": self.default_seed,
"activation_epoch": self.default_activation_epoch,
"num_active_epochs": self.default_num_active_epochs,
},
)
scheme = self._scheme_for_seed(metadata.get("seed"))

# Map the attestation slot to an XMSS epoch.
#
# Each slot gets its own epoch to avoid key reuse.
epoch = attestation.data.slot

# Loop until the epoch is inside the prepared interval
prepared_interval = self.scheme.get_prepared_interval(sk)
prepared_interval = scheme.get_prepared_interval(sk)
while int(epoch) not in prepared_interval:
# Check if we're advancing past the key's total lifetime
activation_interval = self.scheme.get_activation_interval(sk)
activation_interval = scheme.get_activation_interval(sk)
if prepared_interval.stop >= activation_interval.stop:
raise ValueError(
f"Cannot sign for epoch {epoch}: "
f"it is beyond the key's max lifetime {activation_interval.stop}"
)

# Advance the key and get the new key object
sk = self.scheme.advance_preparation(sk)
sk = scheme.advance_preparation(sk)

# Update the prepared interval for the next loop check
prepared_interval = self.scheme.get_prepared_interval(sk)
prepared_interval = scheme.get_prepared_interval(sk)

# Update the cached key pair with the new, advanced secret key.
# This ensures the *next* call to sign() uses the advanced state.
Expand All @@ -175,18 +294,10 @@ def sign_attestation(self, attestation: Attestation) -> Signature:
message = bytes(hash_tree_root(attestation))

# Generate the XMSS signature using the validator's (now prepared) secret key.
xmss_sig = self.scheme.sign(sk, epoch, message)

# Convert the signature to the wire format (byte array).
signature_bytes = xmss_sig.to_bytes(self.scheme.config)

# Ensure the signature meets the consensus spec length (3100 bytes).
#
# This is necessary when using TEST_CONFIG (796 bytes) vs PROD_CONFIG.
# Padding with zeros on the right maintains compatibility.
padded_bytes = signature_bytes.ljust(Signature.LENGTH, b"\x00")
xmss_sig = scheme.sign(sk, epoch, message)

return Signature(padded_bytes)
# Convert to the consensus Signature container (handles padding internally).
return Signature.from_xmss(xmss_sig, scheme)

def get_public_key(self, validator_index: ValidatorIndex) -> PublicKey:
"""
Expand Down Expand Up @@ -225,3 +336,30 @@ def __contains__(self, validator_index: ValidatorIndex) -> bool:
def __len__(self) -> int:
"""Return the number of validators with generated keys."""
return len(self._key_pairs)

def export_test_vectors(self, include_private_keys: bool = False) -> list[dict[str, Any]]:
"""
Export generated keys in a JSON-serializable structure for downstream clients.

Parameters
----------
include_private_keys : bool
When True, include the full secret key dump; otherwise only public data.
"""
vectors: list[dict[str, Any]] = []
for validator_index, key_pair in self._key_pairs.items():
meta = self._key_metadata.get(validator_index, {})
entry: dict[str, Any] = {
"validator_index": int(validator_index),
"activation_epoch": meta.get("activation_epoch"),
"num_active_epochs": meta.get("num_active_epochs"),
"seed": meta.get("seed"),
"public_key": key_pair.public.to_bytes(self.scheme.config).hex(),
}
if include_private_keys:
# Pydantic models are JSON-serializable; keep the raw dump for full fidelity.
entry["secret_key"] = key_pair.secret.model_dump(mode="json")

vectors.append(entry)

return vectors
19 changes: 19 additions & 0 deletions src/lean_spec/subspecs/containers/signature.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,22 @@ def verify(
return scheme.verify(public_key, epoch, message, signature)
except Exception:
return False

@classmethod
def from_xmss(
cls, xmss_signature: XmssSignature, scheme: GeneralizedXmssScheme = TEST_SIGNATURE_SCHEME
) -> "Signature":
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: change it to Signature by adding from __future__ import annotations. I know this isn't consistent across the repo but direct refs look cleaner.

"""
Create a consensus `Signature` container from an XMSS signature object.

Handles padding to the fixed 3100-byte length required by the consensus layer,
delegating all encoding details to the XMSS container itself.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That is confusing to put the raw 3100 number here directly, we don't really understand from where it comes from

Suggested change
Handles padding to the fixed 3100-byte length required by the consensus layer,
delegating all encoding details to the XMSS container itself.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"""
raw = xmss_signature.to_bytes(scheme.config)
if len(raw) > cls.LENGTH:
raise ValueError(
f"XMSS signature length {len(raw)} exceeds container size {cls.LENGTH}"
)

# Pad on the right to the fixed-length container expected by consensus.
return cls(raw.ljust(cls.LENGTH, b"\x00"))
Loading