Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 3 additions & 37 deletions packages/testing/src/consensus_testing/keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,9 @@
import tempfile
import urllib.request
from concurrent.futures import ProcessPoolExecutor
from dataclasses import dataclass
from functools import cache, partial
from pathlib import Path
from typing import TYPE_CHECKING, Iterator, Self
from typing import TYPE_CHECKING, Iterator

from lean_spec.config import LEAN_ENV
from lean_spec.subspecs.containers import AttestationData
Expand All @@ -46,7 +45,7 @@
AttestationSignatures,
)
from lean_spec.subspecs.containers.slot import Slot
from lean_spec.subspecs.xmss.containers import PublicKey, SecretKey, Signature
from lean_spec.subspecs.xmss.containers import KeyPair, PublicKey, Signature
from lean_spec.subspecs.xmss.interface import (
PROD_SIGNATURE_SCHEME,
TEST_SIGNATURE_SCHEME,
Expand Down Expand Up @@ -120,39 +119,6 @@ def get_shared_key_manager(max_slot: Slot = _DEFAULT_MAX_SLOT) -> XmssKeyManager
"""Key lifetime in epochs (derived from DEFAULT_MAX_SLOT)."""


@dataclass(frozen=True, slots=True)
class KeyPair:
"""
Immutable XMSS key pair for a validator.

Attributes:
public: Public key for signature verification.
secret: Secret key containing Merkle tree structures.
"""

public: PublicKey
secret: SecretKey

@classmethod
def from_dict(cls, data: Mapping[str, str]) -> Self:
"""Deserialize from JSON-compatible dict with hex-encoded SSZ."""
return cls(
public=PublicKey.decode_bytes(bytes.fromhex(data["public"])),
secret=SecretKey.decode_bytes(bytes.fromhex(data["secret"])),
)

def to_dict(self) -> dict[str, str]:
"""Serialize to JSON-compatible dict with hex-encoded SSZ."""
return {
"public": self.public.encode_bytes().hex(),
"secret": self.secret.encode_bytes().hex(),
}

def with_secret(self, secret: SecretKey) -> KeyPair:
"""Return a new KeyPair with updated secret key (for state advancement)."""
return KeyPair(public=self.public, secret=secret)


def _get_keys_dir(scheme_name: str) -> Path:
"""Get the keys directory path for the given scheme."""
return Path(__file__).parent / "test_keys" / f"{scheme_name}_scheme"
Expand Down Expand Up @@ -298,7 +264,7 @@ def sign_attestation_data(
prepared = self.scheme.get_prepared_interval(sk)

# Cache advanced state
self._state[validator_id] = kp.with_secret(sk)
self._state[validator_id] = kp._replace(secret=sk)

# Sign hash tree root of the attestation data
message = attestation_data.data_root_bytes()
Expand Down
30 changes: 29 additions & 1 deletion src/lean_spec/subspecs/xmss/containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from __future__ import annotations

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Mapping, NamedTuple

from ...types import Uint64
from ...types.container import Container
Expand Down Expand Up @@ -181,3 +181,31 @@ class SecretKey(Container):
Together with `left_bottom_tree`, this provides a prepared interval of
exactly `2 * sqrt(LIFETIME)` consecutive epochs.
"""


class KeyPair(NamedTuple):
"""
Immutable XMSS key pair for a validator.

Attributes:
public: Public key for signature verification.
secret: Secret key containing Merkle tree structures.
"""

public: PublicKey
secret: SecretKey

@classmethod
def from_dict(cls, data: Mapping[str, str]) -> "KeyPair":
"""Deserialize from JSON-compatible dict with hex-encoded SSZ."""
return cls(
public=PublicKey.decode_bytes(bytes.fromhex(data["public"])),
secret=SecretKey.decode_bytes(bytes.fromhex(data["secret"])),
)

def to_dict(self) -> dict[str, str]:
"""Serialize to JSON-compatible dict with hex-encoded SSZ."""
return {
"public": self.public.encode_bytes().hex(),
"secret": self.secret.encode_bytes().hex(),
}
10 changes: 4 additions & 6 deletions src/lean_spec/subspecs/xmss/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
TEST_CONFIG,
XmssConfig,
)
from .containers import PublicKey, SecretKey, Signature
from .containers import KeyPair, PublicKey, SecretKey, Signature
from .prf import PROD_PRF, TEST_PRF, Prf
from .rand import PROD_RAND, TEST_RAND, Rand
from .subtree import HashSubTree, combined_path, verify_path
Expand Down Expand Up @@ -73,9 +73,7 @@ def _validate_strict_types(self) -> "GeneralizedXmssScheme":
)
return self

def key_gen(
self, activation_epoch: Uint64, num_active_epochs: Uint64
) -> tuple[PublicKey, SecretKey]:
def key_gen(self, activation_epoch: Uint64, num_active_epochs: Uint64) -> KeyPair:
"""
Generates a new cryptographic key pair for a specified range of epochs.

Expand Down Expand Up @@ -120,7 +118,7 @@ def key_gen(
- Will be rounded up to at least `2 * sqrt(LIFETIME)`.

Returns:
A tuple containing the `PublicKey` and `SecretKey`.
A `KeyPair` containing the public and secret keys.

Note:
The actual activation epoch and num_active_epochs in the returned SecretKey
Expand Down Expand Up @@ -220,7 +218,7 @@ def key_gen(
left_bottom_tree=left_bottom_tree,
right_bottom_tree=right_bottom_tree,
)
return pk, sk
return KeyPair(public=pk, secret=sk)

def sign(self, sk: SecretKey, epoch: Uint64, message: bytes) -> Signature:
"""
Expand Down
Loading