Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 65 additions & 11 deletions packages/testing/src/consensus_testing/keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,12 @@
from pathlib import Path
from typing import TYPE_CHECKING, Iterator, Self

from lean_spec.subspecs.containers import Attestation
from lean_spec.subspecs.containers import AttestationData
from lean_spec.subspecs.containers.block.types import (
AggregatedAttestations,
AttestationSignatures,
NaiveAggregatedSignatures,
)
from lean_spec.subspecs.containers.slot import Slot
from lean_spec.subspecs.ssz.hash import hash_tree_root
from lean_spec.subspecs.xmss.containers import PublicKey, SecretKey, Signature
Expand Down Expand Up @@ -121,7 +126,7 @@ class XmssKeyManager:
>>> mgr = XmssKeyManager()
>>> mgr[Uint64(0)] # Get key pair
>>> mgr.get_public_key(Uint64(1)) # Get public key only
>>> mgr.sign_attestation(attestation) # Sign with auto-advancement
>>> mgr.sign_attestation_data(validator_id, attestation_data) # Sign with auto-advancement
"""

def __init__(
Expand Down Expand Up @@ -167,25 +172,29 @@ def get_all_public_keys(self) -> dict[Uint64, PublicKey]:
"""Get all public keys (from base keys, not advanced state)."""
return {idx: kp.public for idx, kp in self.keys.items()}

def sign_attestation(self, attestation: Attestation) -> Signature:
def sign_attestation_data(
self,
validator_id: Uint64,
attestation_data: AttestationData,
) -> Signature:
"""
Sign an attestation with automatic key state advancement.
Sign an attestation data with automatic key state advancement.

XMSS is stateful: signing advances the internal key state.
This method handles advancement transparently.

Args:
attestation: The attestation to sign.
validator_id: The validator index to sign the attestation data for.
attestation_data: The attestation data to sign.

Returns:
XMSS signature.

Raises:
ValueError: If slot exceeds key lifetime.
"""
idx = attestation.validator_id
epoch = attestation.data.slot
kp = self[idx]
epoch = attestation_data.slot
kp = self[validator_id]
sk = kp.secret

# Advance key state until epoch is in prepared interval
Expand All @@ -198,12 +207,57 @@ def sign_attestation(self, attestation: Attestation) -> Signature:
prepared = self.scheme.get_prepared_interval(sk)

# Cache advanced state
self._state[idx] = kp.with_secret(sk)
self._state[validator_id] = kp.with_secret(sk)

# Sign hash tree root
message = bytes(hash_tree_root(attestation))
# Sign hash tree root of the attestation data
message = bytes(hash_tree_root(attestation_data))
return self.scheme.sign(sk, epoch, message)

def build_attestation_signatures(
self,
aggregated_attestations: AggregatedAttestations,
signature_lookup: Mapping[tuple[Uint64, bytes], Signature] | None = None,
) -> AttestationSignatures:
"""
Build `AttestationSignatures` for already-aggregated attestations.

This is a convenience helper for tests/fixtures that need to produce
`BlockSignatures.attestation_signatures` for a block.

Args:
aggregated_attestations: Iterable of aggregated attestation containers.
Each item is expected to have:
- `.data` (AttestationData)
- `.aggregation_bits.to_validator_indices()` (Iterable[Uint64])
signature_lookup: Optional override map keyed by
`(validator_id, bytes(hash_tree_root(attestation_data))) -> signature`.
When provided and a key exists, that signature is used instead of signing.

Returns:
AttestationSignatures matching the ordering of `aggregated_attestations`
and per-attestation validator index ordering.
"""
return AttestationSignatures(
data=[
NaiveAggregatedSignatures(
data=[
(
signature_lookup.get(
(validator_id, aggregated_attestation.data.data_root_bytes())
)
if signature_lookup is not None
else None
)
or self.sign_attestation_data(validator_id, aggregated_attestation.data)
for validator_id in (
aggregated_attestation.aggregation_bits.to_validator_indices()
)
]
)
for aggregated_attestation in aggregated_attestations
]
)


def _generate_single_keypair(num_epochs: int) -> dict[str, str]:
"""Generate one key pair (module-level for pickling in ProcessPoolExecutor)."""
Expand Down
37 changes: 27 additions & 10 deletions packages/testing/src/consensus_testing/test_fixtures/fork_choice.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,16 @@
AttestationData,
SignedAttestation,
)
from lean_spec.subspecs.containers.block.block import (
from lean_spec.subspecs.containers.block import (
Block,
BlockBody,
BlockSignatures,
BlockWithAttestation,
SignedBlockWithAttestation,
)
from lean_spec.subspecs.containers.block.types import Attestations, BlockSignatures
from lean_spec.subspecs.containers.block.types import (
AggregatedAttestations,
)
from lean_spec.subspecs.containers.checkpoint import Checkpoint
from lean_spec.subspecs.containers.slot import Slot
from lean_spec.subspecs.containers.state import Validators
Expand Down Expand Up @@ -133,7 +136,7 @@ def set_anchor_block_default(self) -> ForkChoiceTest:
proposer_index=self.anchor_state.latest_block_header.proposer_index,
parent_root=self.anchor_state.latest_block_header.parent_root,
state_root=hash_tree_root(self.anchor_state),
body=BlockBody(attestations=Attestations(data=[])),
body=BlockBody(attestations=AggregatedAttestations(data=[])),
)
return self

Expand All @@ -153,7 +156,7 @@ def set_max_slot_default(self) -> ForkChoiceTest:
if isinstance(step, BlockStep):
max_slot_value = max(max_slot_value, int(step.block.slot))
elif isinstance(step, AttestationStep):
max_slot_value = max(max_slot_value, int(step.attestation.message.data.slot))
max_slot_value = max(max_slot_value, int(step.attestation.message.slot))

self.max_slot = Slot(max_slot_value)

Expand Down Expand Up @@ -344,15 +347,24 @@ def _build_block_from_spec(
)

# Sign all attestations and the proposer attestation
signature_list = [key_manager.sign_attestation(att) for att in attestations]
signature_list.append(key_manager.sign_attestation(proposer_attestation))
attestation_signatures = key_manager.build_attestation_signatures(
final_block.body.attestations
)

proposer_signature = key_manager.sign_attestation_data(
proposer_attestation.validator_id,
proposer_attestation.data,
)

return SignedBlockWithAttestation(
message=BlockWithAttestation(
block=final_block,
proposer_attestation=proposer_attestation,
),
signature=BlockSignatures(data=signature_list),
signature=BlockSignatures(
attestation_signatures=attestation_signatures,
proposer_signature=proposer_signature,
),
)

def _resolve_parent_root(
Expand Down Expand Up @@ -415,9 +427,13 @@ def _build_attestations_from_spec(
signed_att = self._build_signed_attestation_from_spec(
att_spec, block_registry, parent_state
)
attestations.append(signed_att.message)
attestations.append(
Attestation(validator_id=signed_att.validator_id, data=signed_att.message)
)
else:
attestations.append(att_spec.message)
attestations.append(
Attestation(validator_id=att_spec.validator_id, data=att_spec.message)
)

return attestations

Expand Down Expand Up @@ -473,7 +489,8 @@ def _build_signed_attestation_from_spec(

# Create signed attestation
return SignedAttestation(
message=attestation,
validator_id=attestation.validator_id,
message=attestation.data,
signature=(
spec.signature
or Signature(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from pydantic import ConfigDict, PrivateAttr, field_serializer

from lean_spec.subspecs.containers.block.block import Block, BlockBody
from lean_spec.subspecs.containers.block.types import Attestations
from lean_spec.subspecs.containers.block.types import AggregatedAttestations
from lean_spec.subspecs.containers.state.state import State
from lean_spec.subspecs.ssz.hash import hash_tree_root
from lean_spec.types import Bytes32, Uint64
Expand Down Expand Up @@ -205,7 +205,9 @@ def _build_block_from_spec(self, spec: BlockSpec, state: State) -> tuple[Block,
parent_root = hash_tree_root(temp_state.latest_block_header)

# Extract attestations from body if provided
attestations = list(spec.body.attestations) if spec.body else []
aggregated_attestations = (
spec.body.attestations if spec.body else AggregatedAttestations(data=[])
)

# Handle explicit state root override
if spec.state_root is not None:
Expand All @@ -214,7 +216,7 @@ def _build_block_from_spec(self, spec: BlockSpec, state: State) -> tuple[Block,
proposer_index=proposer_index,
parent_root=parent_root,
state_root=spec.state_root,
body=spec.body or BlockBody(attestations=Attestations(data=[])),
body=spec.body or BlockBody(attestations=aggregated_attestations),
)
return block, None

Expand All @@ -225,7 +227,7 @@ def _build_block_from_spec(self, spec: BlockSpec, state: State) -> tuple[Block,
proposer_index=proposer_index,
parent_root=parent_root,
state_root=Bytes32.zero(),
body=spec.body or BlockBody(attestations=Attestations(data=attestations)),
body=spec.body or BlockBody(attestations=aggregated_attestations),
)
return block, None

Expand All @@ -234,6 +236,10 @@ def _build_block_from_spec(self, spec: BlockSpec, state: State) -> tuple[Block,
slot=spec.slot,
proposer_index=proposer_index,
parent_root=parent_root,
attestations=attestations,
attestations=[
attestation
for aggregated_attestation in aggregated_attestations
for attestation in aggregated_attestation.to_plain()
],
)
return block, post_state
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@
AttestationData,
SignedAttestation,
)
from lean_spec.subspecs.containers.block.block import (
from lean_spec.subspecs.containers.block import (
BlockSignatures,
BlockWithAttestation,
SignedBlockWithAttestation,
)
from lean_spec.subspecs.containers.block.types import BlockSignatures
from lean_spec.subspecs.containers.checkpoint import Checkpoint
from lean_spec.subspecs.containers.slot import Slot
from lean_spec.subspecs.containers.state.state import State
Expand Down Expand Up @@ -192,7 +192,9 @@ def _build_block_from_spec(
parent_root = hash_tree_root(parent_state.latest_block_header)

# Build attestations from spec
attestations, signatures = self._build_attestations_from_spec(spec, state, key_manager)
attestations, attestation_signature_inputs = self._build_attestations_from_spec(
spec, state, key_manager
)

# Use State.build_block for core block building (pure spec logic)
final_block, _, _, _ = state.build_block(
Expand All @@ -202,6 +204,20 @@ def _build_block_from_spec(
attestations=attestations,
)

# Preserve per-attestation validity from the spec.
#
# For signature tests we must ensure that the signatures in the input spec are used
# for any intentionally-invalid signature from the input spec remains invalid
# in the produced `SignedBlockWithAttestation`.
signature_lookup: dict[tuple[Uint64, bytes], Signature] = {
(att.validator_id, bytes(hash_tree_root(att.data))): sig
for att, sig in zip(attestations, attestation_signature_inputs, strict=True)
}
attestation_signatures = key_manager.build_attestation_signatures(
final_block.body.attestations,
signature_lookup=signature_lookup,
)

# Create proposer attestation for this block
block_root = hash_tree_root(final_block)
proposer_attestation = Attestation(
Expand All @@ -216,7 +232,10 @@ def _build_block_from_spec(

# Sign proposer attestation - use valid or dummy signature based on spec
if spec.valid_signature:
proposer_attestation_signature = key_manager.sign_attestation(proposer_attestation)
proposer_attestation_signature = key_manager.sign_attestation_data(
proposer_attestation.validator_id,
proposer_attestation.data,
)
else:
# Generate an invalid dummy signature (all zeros)
from lean_spec.subspecs.xmss.constants import TEST_CONFIG
Expand All @@ -229,14 +248,15 @@ def _build_block_from_spec(
hashes=HashDigestList(data=[]),
)

signatures.append(proposer_attestation_signature)

return SignedBlockWithAttestation(
message=BlockWithAttestation(
block=final_block,
proposer_attestation=proposer_attestation,
),
signature=BlockSignatures(data=signatures),
signature=BlockSignatures(
attestation_signatures=attestation_signatures,
proposer_signature=proposer_attestation_signature,
),
)

def _build_attestations_from_spec(
Expand All @@ -257,10 +277,22 @@ def _build_attestations_from_spec(
signed_attestation = self._build_signed_attestation_from_spec(
attestation_item, state, key_manager
)
attestations.append(signed_attestation.message)
# Reconstruct Attestation from SignedAttestation components
attestations.append(
Attestation(
validator_id=signed_attestation.validator_id,
data=signed_attestation.message,
)
)
attestation_signatures.append(signed_attestation.signature)
else:
attestations.append(attestation_item.message)
# Reconstruct Attestation from existing SignedAttestation
attestations.append(
Attestation(
validator_id=attestation_item.validator_id,
data=attestation_item.message,
)
)
attestation_signatures.append(attestation_item.signature)

return attestations, attestation_signatures
Expand Down Expand Up @@ -313,7 +345,10 @@ def _build_signed_attestation_from_spec(
# Sign the attestation - use dummy signature if expecting invalid signature
if spec.valid_signature:
# Generate valid signature using key manager
signature = key_manager.sign_attestation(attestation)
signature = key_manager.sign_attestation_data(
attestation.validator_id,
attestation.data,
)
else:
# Generate an invalid dummy signature (all zeros)
from lean_spec.subspecs.xmss.constants import TEST_CONFIG
Expand All @@ -328,6 +363,7 @@ def _build_signed_attestation_from_spec(

# Create signed attestation
return SignedAttestation(
message=attestation,
validator_id=attestation.validator_id,
message=attestation.data,
signature=signature,
)
Loading
Loading