Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 15 additions & 11 deletions packages/testing/src/consensus_testing/keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from pathlib import Path
from typing import TYPE_CHECKING, Iterator, Self

from lean_spec.subspecs.containers import Attestation
from lean_spec.subspecs.containers import AttestationData
from lean_spec.subspecs.containers.slot import Slot
from lean_spec.subspecs.ssz.hash import hash_tree_root
from lean_spec.subspecs.xmss.containers import PublicKey, SecretKey, Signature
Expand Down Expand Up @@ -121,7 +121,7 @@ class XmssKeyManager:
>>> mgr = XmssKeyManager()
>>> mgr[Uint64(0)] # Get key pair
>>> mgr.get_public_key(Uint64(1)) # Get public key only
>>> mgr.sign_attestation(attestation) # Sign with auto-advancement
>>> mgr.sign_attestation_data(validator_id, attestation_data) # Sign with auto-advancement
"""

def __init__(
Expand Down Expand Up @@ -167,25 +167,29 @@ def get_all_public_keys(self) -> dict[Uint64, PublicKey]:
"""Get all public keys (from base keys, not advanced state)."""
return {idx: kp.public for idx, kp in self.keys.items()}

def sign_attestation(self, attestation: Attestation) -> Signature:
def sign_attestation_data(
self,
validator_id: Uint64,
attestation_data: AttestationData,
) -> Signature:
"""
Sign an attestation with automatic key state advancement.
Sign an attestation data with automatic key state advancement.

XMSS is stateful: signing advances the internal key state.
This method handles advancement transparently.

Args:
attestation: The attestation to sign.
validator_id: The validator index to sign the attestation data for.
attestation_data: The attestation data to sign.

Returns:
XMSS signature.

Raises:
ValueError: If slot exceeds key lifetime.
"""
idx = attestation.validator_id
epoch = attestation.data.slot
kp = self[idx]
epoch = attestation_data.slot
kp = self[validator_id]
sk = kp.secret

# Advance key state until epoch is in prepared interval
Expand All @@ -198,10 +202,10 @@ def sign_attestation(self, attestation: Attestation) -> Signature:
prepared = self.scheme.get_prepared_interval(sk)

# Cache advanced state
self._state[idx] = kp.with_secret(sk)
self._state[validator_id] = kp.with_secret(sk)

# Sign hash tree root
message = bytes(hash_tree_root(attestation))
# Sign hash tree root of the attestation data
message = bytes(hash_tree_root(attestation_data))
return self.scheme.sign(sk, epoch, message)


Expand Down
37 changes: 27 additions & 10 deletions packages/testing/src/consensus_testing/test_fixtures/fork_choice.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,17 @@
AttestationData,
SignedAttestation,
)
from lean_spec.subspecs.containers.block.block import (
from lean_spec.subspecs.containers.block import (
Block,
BlockBody,
BlockSignatures,
BlockWithAttestation,
SignedBlockWithAttestation,
)
from lean_spec.subspecs.containers.block.types import Attestations, BlockSignatures
from lean_spec.subspecs.containers.block.types import (
AggregatedAttestations,
NaiveAggregatedSignature,
)
from lean_spec.subspecs.containers.checkpoint import Checkpoint
from lean_spec.subspecs.containers.slot import Slot
from lean_spec.subspecs.containers.state import Validators
Expand Down Expand Up @@ -133,7 +137,7 @@ def set_anchor_block_default(self) -> ForkChoiceTest:
proposer_index=self.anchor_state.latest_block_header.proposer_index,
parent_root=self.anchor_state.latest_block_header.parent_root,
state_root=hash_tree_root(self.anchor_state),
body=BlockBody(attestations=Attestations(data=[])),
body=BlockBody(attestations=AggregatedAttestations(data=[])),
)
return self

Expand All @@ -153,7 +157,7 @@ def set_max_slot_default(self) -> ForkChoiceTest:
if isinstance(step, BlockStep):
max_slot_value = max(max_slot_value, int(step.block.slot))
elif isinstance(step, AttestationStep):
max_slot_value = max(max_slot_value, int(step.attestation.message.data.slot))
max_slot_value = max(max_slot_value, int(step.attestation.message.slot))

self.max_slot = Slot(max_slot_value)

Expand Down Expand Up @@ -344,15 +348,23 @@ def _build_block_from_spec(
)

# Sign all attestations and the proposer attestation
signature_list = [key_manager.sign_attestation(att) for att in attestations]
signature_list.append(key_manager.sign_attestation(proposer_attestation))
attestation_signatures = [
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

all attestations here have same data? since we are forming a single list?

key_manager.sign_attestation_data(att.validator_id, att.data) for att in attestations
]
proposer_signature = key_manager.sign_attestation_data(
proposer_attestation.validator_id,
proposer_attestation.data,
)

return SignedBlockWithAttestation(
message=BlockWithAttestation(
block=final_block,
proposer_attestation=proposer_attestation,
),
signature=BlockSignatures(data=signature_list),
signature=BlockSignatures(
attestation_signatures=NaiveAggregatedSignature(data=attestation_signatures),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
attestation_signatures=NaiveAggregatedSignature(data=attestation_signatures),
attestation_signatures=NaiveAggregatedSignature(data=attestation_signatures),

this is a list of NaiveAggregatedSignature

proposer_signature=proposer_signature,
),
)

def _resolve_parent_root(
Expand Down Expand Up @@ -415,9 +427,13 @@ def _build_attestations_from_spec(
signed_att = self._build_signed_attestation_from_spec(
att_spec, block_registry, parent_state
)
attestations.append(signed_att.message)
attestations.append(
Attestation(validator_id=signed_att.validator_id, data=signed_att.message)
)
else:
attestations.append(att_spec.message)
attestations.append(
Attestation(validator_id=att_spec.validator_id, data=att_spec.message)
)

return attestations

Expand Down Expand Up @@ -473,7 +489,8 @@ def _build_signed_attestation_from_spec(

# Create signed attestation
return SignedAttestation(
message=attestation,
validator_id=attestation.validator_id,
message=attestation.data,
signature=(
spec.signature
or Signature(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,12 @@

from pydantic import ConfigDict, PrivateAttr, field_serializer

from lean_spec.subspecs.containers.attestation import (
aggregated_attestations_to_plain,
attestation_to_aggregated,
)
from lean_spec.subspecs.containers.block.block import Block, BlockBody
from lean_spec.subspecs.containers.block.types import Attestations
from lean_spec.subspecs.containers.block.types import AggregatedAttestations
from lean_spec.subspecs.containers.state.state import State
from lean_spec.subspecs.ssz.hash import hash_tree_root
from lean_spec.types import Bytes32, Uint64
Expand Down Expand Up @@ -204,8 +208,17 @@ def _build_block_from_spec(self, spec: BlockSpec, state: State) -> tuple[Block,
temp_state = state.process_slots(spec.slot)
parent_root = hash_tree_root(temp_state.latest_block_header)

# Extract attestations from body if provided
attestations = list(spec.body.attestations) if spec.body else []
# Extract attestations from body if provided, converting from aggregated form
# Flatten all plain attestations from all aggregated attestations
attestations = (
[
plain_att
for att in spec.body.attestations
for plain_att in aggregated_attestations_to_plain(att)
]
if spec.body
else []
)

# Handle explicit state root override
if spec.state_root is not None:
Expand All @@ -214,7 +227,7 @@ def _build_block_from_spec(self, spec: BlockSpec, state: State) -> tuple[Block,
proposer_index=proposer_index,
parent_root=parent_root,
state_root=spec.state_root,
body=spec.body or BlockBody(attestations=Attestations(data=[])),
body=spec.body or BlockBody(attestations=AggregatedAttestations(data=[])),
)
return block, None

Expand All @@ -225,7 +238,12 @@ def _build_block_from_spec(self, spec: BlockSpec, state: State) -> tuple[Block,
proposer_index=proposer_index,
parent_root=parent_root,
state_root=Bytes32.zero(),
body=spec.body or BlockBody(attestations=Attestations(data=attestations)),
body=spec.body
or BlockBody(
attestations=AggregatedAttestations(
data=[attestation_to_aggregated(att) for att in attestations]
)
),
)
return block, None

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,12 @@
AttestationData,
SignedAttestation,
)
from lean_spec.subspecs.containers.block.block import (
from lean_spec.subspecs.containers.block import (
BlockSignatures,
BlockWithAttestation,
SignedBlockWithAttestation,
)
from lean_spec.subspecs.containers.block.types import BlockSignatures
from lean_spec.subspecs.containers.block.types import NaiveAggregatedSignature
from lean_spec.subspecs.containers.checkpoint import Checkpoint
from lean_spec.subspecs.containers.slot import Slot
from lean_spec.subspecs.containers.state.state import State
Expand Down Expand Up @@ -216,7 +217,10 @@ def _build_block_from_spec(

# Sign proposer attestation - use valid or dummy signature based on spec
if spec.valid_signature:
proposer_attestation_signature = key_manager.sign_attestation(proposer_attestation)
proposer_attestation_signature = key_manager.sign_attestation_data(
proposer_attestation.validator_id,
proposer_attestation.data,
)
else:
# Generate an invalid dummy signature (all zeros)
from lean_spec.subspecs.xmss.constants import TEST_CONFIG
Expand All @@ -229,14 +233,15 @@ def _build_block_from_spec(
hashes=HashDigestList(data=[]),
)

signatures.append(proposer_attestation_signature)

return SignedBlockWithAttestation(
message=BlockWithAttestation(
block=final_block,
proposer_attestation=proposer_attestation,
),
signature=BlockSignatures(data=signatures),
signature=BlockSignatures(
attestation_signatures=NaiveAggregatedSignature(data=signatures),
proposer_signature=proposer_attestation_signature,
),
)

def _build_attestations_from_spec(
Expand All @@ -257,10 +262,22 @@ def _build_attestations_from_spec(
signed_attestation = self._build_signed_attestation_from_spec(
attestation_item, state, key_manager
)
attestations.append(signed_attestation.message)
# Reconstruct Attestation from SignedAttestation components
attestations.append(
Attestation(
validator_id=signed_attestation.validator_id,
data=signed_attestation.message,
)
)
attestation_signatures.append(signed_attestation.signature)
else:
attestations.append(attestation_item.message)
# Reconstruct Attestation from existing SignedAttestation
attestations.append(
Attestation(
validator_id=attestation_item.validator_id,
data=attestation_item.message,
)
)
attestation_signatures.append(attestation_item.signature)

return attestations, attestation_signatures
Expand Down Expand Up @@ -313,7 +330,10 @@ def _build_signed_attestation_from_spec(
# Sign the attestation - use dummy signature if expecting invalid signature
if spec.valid_signature:
# Generate valid signature using key manager
signature = key_manager.sign_attestation(attestation)
signature = key_manager.sign_attestation_data(
attestation.validator_id,
attestation.data,
)
else:
# Generate an invalid dummy signature (all zeros)
from lean_spec.subspecs.xmss.constants import TEST_CONFIG
Expand All @@ -328,6 +348,7 @@ def _build_signed_attestation_from_spec(

# Create signed attestation
return SignedAttestation(
message=attestation,
validator_id=attestation.validator_id,
message=attestation.data,
signature=signature,
)
Original file line number Diff line number Diff line change
Expand Up @@ -51,31 +51,31 @@ def validate_attestation(
expected = getattr(self, field_name)

if field_name == "attestation_slot":
actual = attestation.message.data.slot
actual = attestation.message.slot
if actual != expected:
raise AssertionError(
f"Step {step_index}: validator {self.validator} {location} "
f"attestation slot = {actual}, expected {expected}"
)

elif field_name == "head_slot":
actual = attestation.message.data.head.slot
actual = attestation.message.head.slot
if actual != expected:
raise AssertionError(
f"Step {step_index}: validator {self.validator} {location} "
f"head slot = {actual}, expected {expected}"
)

elif field_name == "source_slot":
actual = attestation.message.data.source.slot
actual = attestation.message.source.slot
if actual != expected:
raise AssertionError(
f"Step {step_index}: validator {self.validator} {location} "
f"source slot = {actual}, expected {expected}"
)

elif field_name == "target_slot":
actual = attestation.message.data.target.slot
actual = attestation.message.target.slot
if actual != expected:
raise AssertionError(
f"Step {step_index}: validator {self.validator} {location} "
Expand Down Expand Up @@ -442,7 +442,7 @@ def validate_against_store(
# An attestation votes for this fork if its head is this block or a descendant
weight = 0
for attestation in store.latest_known_attestations.values():
att_head_root = attestation.message.data.head.root
att_head_root = attestation.message.head.root
# Check if attestation head is this block or a descendant
if att_head_root == root:
weight += 1
Expand Down
4 changes: 2 additions & 2 deletions src/lean_spec/subspecs/containers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
"""

from .attestation import (
AggregatedAttestations,
AggregatedAttestation,
AggregatedSignatures,
AggregationBits,
Attestation,
Expand All @@ -30,7 +30,7 @@
from .validator import Validator

__all__ = [
"AggregatedAttestations",
"AggregatedAttestation",
"AggregatedSignatures",
"AggregationBits",
"AttestationData",
Expand Down
Loading
Loading