diff --git a/packages/testing/src/consensus_testing/keys.py b/packages/testing/src/consensus_testing/keys.py index ce2f6bdb..0dad9b0f 100644 --- a/packages/testing/src/consensus_testing/keys.py +++ b/packages/testing/src/consensus_testing/keys.py @@ -29,9 +29,13 @@ from pathlib import Path from typing import TYPE_CHECKING, Iterator, Self -from lean_spec.subspecs.containers import Attestation +from lean_spec.subspecs.containers import AttestationData +from lean_spec.subspecs.containers.attestation.types import NaiveAggregatedSignature +from lean_spec.subspecs.containers.block.types import ( + AggregatedAttestations, + AttestationSignatures, +) from lean_spec.subspecs.containers.slot import Slot -from lean_spec.subspecs.ssz.hash import hash_tree_root from lean_spec.subspecs.xmss.containers import PublicKey, SecretKey, Signature from lean_spec.subspecs.xmss.interface import TEST_SIGNATURE_SCHEME, GeneralizedXmssScheme from lean_spec.types import Uint64 @@ -121,7 +125,7 @@ class XmssKeyManager: >>> mgr = XmssKeyManager() >>> mgr[Uint64(0)] # Get key pair >>> mgr.get_public_key(Uint64(1)) # Get public key only - >>> mgr.sign_attestation(attestation) # Sign with auto-advancement + >>> mgr.sign_attestation_data(validator_id, attestation_data) # Sign with auto-advancement """ def __init__( @@ -167,15 +171,20 @@ def get_all_public_keys(self) -> dict[Uint64, PublicKey]: """Get all public keys (from base keys, not advanced state).""" return {idx: kp.public for idx, kp in self.keys.items()} - def sign_attestation(self, attestation: Attestation) -> Signature: + def sign_attestation_data( + self, + validator_id: Uint64, + attestation_data: AttestationData, + ) -> Signature: """ - Sign an attestation with automatic key state advancement. + Sign an attestation data with automatic key state advancement. XMSS is stateful: signing advances the internal key state. This method handles advancement transparently. Args: - attestation: The attestation to sign. + validator_id: The validator index to sign the attestation data for. + attestation_data: The attestation data to sign. Returns: XMSS signature. @@ -183,9 +192,8 @@ def sign_attestation(self, attestation: Attestation) -> Signature: Raises: ValueError: If slot exceeds key lifetime. """ - idx = attestation.validator_id - epoch = attestation.data.slot - kp = self[idx] + epoch = attestation_data.slot + kp = self[validator_id] sk = kp.secret # Advance key state until epoch is in prepared interval @@ -198,12 +206,34 @@ def sign_attestation(self, attestation: Attestation) -> Signature: prepared = self.scheme.get_prepared_interval(sk) # Cache advanced state - self._state[idx] = kp.with_secret(sk) + self._state[validator_id] = kp.with_secret(sk) - # Sign hash tree root - message = bytes(hash_tree_root(attestation)) + # Sign hash tree root of the attestation data + message = attestation_data.data_root_bytes() return self.scheme.sign(sk, epoch, message) + def build_attestation_signatures( + self, + aggregated_attestations: AggregatedAttestations, + signature_lookup: Mapping[tuple[Uint64, bytes], Signature] | None = None, + ) -> AttestationSignatures: + """Build `AttestationSignatures` for already-aggregated attestations.""" + lookup = signature_lookup or {} + return AttestationSignatures( + data=[ + NaiveAggregatedSignature( + data=[ + ( + lookup.get((vid, agg.data.data_root_bytes())) + or self.sign_attestation_data(vid, agg.data) + ) + for vid in agg.aggregation_bits.to_validator_indices() + ] + ) + for agg in aggregated_attestations + ] + ) + def _generate_single_keypair(num_epochs: int) -> dict[str, str]: """Generate one key pair (module-level for pickling in ProcessPoolExecutor).""" diff --git a/packages/testing/src/consensus_testing/test_fixtures/fork_choice.py b/packages/testing/src/consensus_testing/test_fixtures/fork_choice.py index b4d3d133..bef3295e 100644 --- a/packages/testing/src/consensus_testing/test_fixtures/fork_choice.py +++ b/packages/testing/src/consensus_testing/test_fixtures/fork_choice.py @@ -13,13 +13,16 @@ AttestationData, SignedAttestation, ) -from lean_spec.subspecs.containers.block.block import ( +from lean_spec.subspecs.containers.block import ( Block, BlockBody, + BlockSignatures, BlockWithAttestation, SignedBlockWithAttestation, ) -from lean_spec.subspecs.containers.block.types import Attestations, BlockSignatures +from lean_spec.subspecs.containers.block.types import ( + AggregatedAttestations, +) from lean_spec.subspecs.containers.checkpoint import Checkpoint from lean_spec.subspecs.containers.slot import Slot from lean_spec.subspecs.containers.state import Validators @@ -133,7 +136,7 @@ def set_anchor_block_default(self) -> ForkChoiceTest: proposer_index=self.anchor_state.latest_block_header.proposer_index, parent_root=self.anchor_state.latest_block_header.parent_root, state_root=hash_tree_root(self.anchor_state), - body=BlockBody(attestations=Attestations(data=[])), + body=BlockBody(attestations=AggregatedAttestations(data=[])), ) return self @@ -153,7 +156,7 @@ def set_max_slot_default(self) -> ForkChoiceTest: if isinstance(step, BlockStep): max_slot_value = max(max_slot_value, int(step.block.slot)) elif isinstance(step, AttestationStep): - max_slot_value = max(max_slot_value, int(step.attestation.message.data.slot)) + max_slot_value = max(max_slot_value, int(step.attestation.message.slot)) self.max_slot = Slot(max_slot_value) @@ -344,15 +347,24 @@ def _build_block_from_spec( ) # Sign all attestations and the proposer attestation - signature_list = [key_manager.sign_attestation(att) for att in attestations] - signature_list.append(key_manager.sign_attestation(proposer_attestation)) + attestation_signatures = key_manager.build_attestation_signatures( + final_block.body.attestations + ) + + proposer_signature = key_manager.sign_attestation_data( + proposer_attestation.validator_id, + proposer_attestation.data, + ) return SignedBlockWithAttestation( message=BlockWithAttestation( block=final_block, proposer_attestation=proposer_attestation, ), - signature=BlockSignatures(data=signature_list), + signature=BlockSignatures( + attestation_signatures=attestation_signatures, + proposer_signature=proposer_signature, + ), ) def _resolve_parent_root( @@ -415,9 +427,13 @@ def _build_attestations_from_spec( signed_att = self._build_signed_attestation_from_spec( att_spec, block_registry, parent_state ) - attestations.append(signed_att.message) + attestations.append( + Attestation(validator_id=signed_att.validator_id, data=signed_att.message) + ) else: - attestations.append(att_spec.message) + attestations.append( + Attestation(validator_id=att_spec.validator_id, data=att_spec.message) + ) return attestations @@ -473,7 +489,8 @@ def _build_signed_attestation_from_spec( # Create signed attestation return SignedAttestation( - message=attestation, + validator_id=attestation.validator_id, + message=attestation.data, signature=( spec.signature or Signature( diff --git a/packages/testing/src/consensus_testing/test_fixtures/state_transition.py b/packages/testing/src/consensus_testing/test_fixtures/state_transition.py index 5e6ca0ed..ca4985d0 100644 --- a/packages/testing/src/consensus_testing/test_fixtures/state_transition.py +++ b/packages/testing/src/consensus_testing/test_fixtures/state_transition.py @@ -5,7 +5,7 @@ from pydantic import ConfigDict, PrivateAttr, field_serializer from lean_spec.subspecs.containers.block.block import Block, BlockBody -from lean_spec.subspecs.containers.block.types import Attestations +from lean_spec.subspecs.containers.block.types import AggregatedAttestations from lean_spec.subspecs.containers.state.state import State from lean_spec.subspecs.ssz.hash import hash_tree_root from lean_spec.types import Bytes32, Uint64 @@ -205,7 +205,9 @@ def _build_block_from_spec(self, spec: BlockSpec, state: State) -> tuple[Block, parent_root = hash_tree_root(temp_state.latest_block_header) # Extract attestations from body if provided - attestations = list(spec.body.attestations) if spec.body else [] + aggregated_attestations = ( + spec.body.attestations if spec.body else AggregatedAttestations(data=[]) + ) # Handle explicit state root override if spec.state_root is not None: @@ -214,7 +216,7 @@ def _build_block_from_spec(self, spec: BlockSpec, state: State) -> tuple[Block, proposer_index=proposer_index, parent_root=parent_root, state_root=spec.state_root, - body=spec.body or BlockBody(attestations=Attestations(data=[])), + body=spec.body or BlockBody(attestations=aggregated_attestations), ) return block, None @@ -225,7 +227,7 @@ def _build_block_from_spec(self, spec: BlockSpec, state: State) -> tuple[Block, proposer_index=proposer_index, parent_root=parent_root, state_root=Bytes32.zero(), - body=spec.body or BlockBody(attestations=Attestations(data=attestations)), + body=spec.body or BlockBody(attestations=aggregated_attestations), ) return block, None @@ -234,6 +236,10 @@ def _build_block_from_spec(self, spec: BlockSpec, state: State) -> tuple[Block, slot=spec.slot, proposer_index=proposer_index, parent_root=parent_root, - attestations=attestations, + attestations=[ + attestation + for aggregated_attestation in aggregated_attestations + for attestation in aggregated_attestation.to_plain() + ], ) return block, post_state diff --git a/packages/testing/src/consensus_testing/test_fixtures/verify_signatures.py b/packages/testing/src/consensus_testing/test_fixtures/verify_signatures.py index f9c53c3f..dca4bbd3 100644 --- a/packages/testing/src/consensus_testing/test_fixtures/verify_signatures.py +++ b/packages/testing/src/consensus_testing/test_fixtures/verify_signatures.py @@ -12,11 +12,11 @@ AttestationData, SignedAttestation, ) -from lean_spec.subspecs.containers.block.block import ( +from lean_spec.subspecs.containers.block import ( + BlockSignatures, BlockWithAttestation, SignedBlockWithAttestation, ) -from lean_spec.subspecs.containers.block.types import BlockSignatures from lean_spec.subspecs.containers.checkpoint import Checkpoint from lean_spec.subspecs.containers.slot import Slot from lean_spec.subspecs.containers.state.state import State @@ -192,7 +192,9 @@ def _build_block_from_spec( parent_root = hash_tree_root(parent_state.latest_block_header) # Build attestations from spec - attestations, signatures = self._build_attestations_from_spec(spec, state, key_manager) + attestations, attestation_signature_inputs = self._build_attestations_from_spec( + spec, state, key_manager + ) # Use State.build_block for core block building (pure spec logic) final_block, _, _, _ = state.build_block( @@ -202,6 +204,20 @@ def _build_block_from_spec( attestations=attestations, ) + # Preserve per-attestation validity from the spec. + # + # For signature tests we must ensure that the signatures in the input spec are used + # for any intentionally-invalid signature from the input spec remains invalid + # in the produced `SignedBlockWithAttestation`. + signature_lookup: dict[tuple[Uint64, bytes], Signature] = { + (att.validator_id, att.data.data_root_bytes()): sig + for att, sig in zip(attestations, attestation_signature_inputs, strict=True) + } + attestation_signatures = key_manager.build_attestation_signatures( + final_block.body.attestations, + signature_lookup=signature_lookup, + ) + # Create proposer attestation for this block block_root = hash_tree_root(final_block) proposer_attestation = Attestation( @@ -216,7 +232,10 @@ def _build_block_from_spec( # Sign proposer attestation - use valid or dummy signature based on spec if spec.valid_signature: - proposer_attestation_signature = key_manager.sign_attestation(proposer_attestation) + proposer_attestation_signature = key_manager.sign_attestation_data( + proposer_attestation.validator_id, + proposer_attestation.data, + ) else: # Generate an invalid dummy signature (all zeros) from lean_spec.subspecs.xmss.constants import TEST_CONFIG @@ -229,14 +248,15 @@ def _build_block_from_spec( hashes=HashDigestList(data=[]), ) - signatures.append(proposer_attestation_signature) - return SignedBlockWithAttestation( message=BlockWithAttestation( block=final_block, proposer_attestation=proposer_attestation, ), - signature=BlockSignatures(data=signatures), + signature=BlockSignatures( + attestation_signatures=attestation_signatures, + proposer_signature=proposer_attestation_signature, + ), ) def _build_attestations_from_spec( @@ -257,10 +277,22 @@ def _build_attestations_from_spec( signed_attestation = self._build_signed_attestation_from_spec( attestation_item, state, key_manager ) - attestations.append(signed_attestation.message) + # Reconstruct Attestation from SignedAttestation components + attestations.append( + Attestation( + validator_id=signed_attestation.validator_id, + data=signed_attestation.message, + ) + ) attestation_signatures.append(signed_attestation.signature) else: - attestations.append(attestation_item.message) + # Reconstruct Attestation from existing SignedAttestation + attestations.append( + Attestation( + validator_id=attestation_item.validator_id, + data=attestation_item.message, + ) + ) attestation_signatures.append(attestation_item.signature) return attestations, attestation_signatures @@ -313,7 +345,10 @@ def _build_signed_attestation_from_spec( # Sign the attestation - use dummy signature if expecting invalid signature if spec.valid_signature: # Generate valid signature using key manager - signature = key_manager.sign_attestation(attestation) + signature = key_manager.sign_attestation_data( + attestation.validator_id, + attestation.data, + ) else: # Generate an invalid dummy signature (all zeros) from lean_spec.subspecs.xmss.constants import TEST_CONFIG @@ -328,6 +363,7 @@ def _build_signed_attestation_from_spec( # Create signed attestation return SignedAttestation( - message=attestation, + validator_id=attestation.validator_id, + message=attestation.data, signature=signature, ) diff --git a/packages/testing/src/consensus_testing/test_types/store_checks.py b/packages/testing/src/consensus_testing/test_types/store_checks.py index d248dace..0fbef3d5 100644 --- a/packages/testing/src/consensus_testing/test_types/store_checks.py +++ b/packages/testing/src/consensus_testing/test_types/store_checks.py @@ -51,7 +51,7 @@ def validate_attestation( expected = getattr(self, field_name) if field_name == "attestation_slot": - actual = attestation.message.data.slot + actual = attestation.message.slot if actual != expected: raise AssertionError( f"Step {step_index}: validator {self.validator} {location} " @@ -59,7 +59,7 @@ def validate_attestation( ) elif field_name == "head_slot": - actual = attestation.message.data.head.slot + actual = attestation.message.head.slot if actual != expected: raise AssertionError( f"Step {step_index}: validator {self.validator} {location} " @@ -67,7 +67,7 @@ def validate_attestation( ) elif field_name == "source_slot": - actual = attestation.message.data.source.slot + actual = attestation.message.source.slot if actual != expected: raise AssertionError( f"Step {step_index}: validator {self.validator} {location} " @@ -75,7 +75,7 @@ def validate_attestation( ) elif field_name == "target_slot": - actual = attestation.message.data.target.slot + actual = attestation.message.target.slot if actual != expected: raise AssertionError( f"Step {step_index}: validator {self.validator} {location} " @@ -442,7 +442,7 @@ def validate_against_store( # An attestation votes for this fork if its head is this block or a descendant weight = 0 for attestation in store.latest_known_attestations.values(): - att_head_root = attestation.message.data.head.root + att_head_root = attestation.message.head.root # Check if attestation head is this block or a descendant if att_head_root == root: weight += 1 diff --git a/src/lean_spec/subspecs/containers/__init__.py b/src/lean_spec/subspecs/containers/__init__.py index 4548fef5..0cf54111 100644 --- a/src/lean_spec/subspecs/containers/__init__.py +++ b/src/lean_spec/subspecs/containers/__init__.py @@ -9,12 +9,12 @@ """ from .attestation import ( - AggregatedAttestations, - AggregatedSignatures, + AggregatedAttestation, AggregationBits, Attestation, AttestationData, - SignedAggregatedAttestations, + NaiveAggregatedSignature, + SignedAggregatedAttestation, SignedAttestation, ) from .block import ( @@ -30,13 +30,13 @@ from .validator import Validator __all__ = [ - "AggregatedAttestations", - "AggregatedSignatures", + "AggregatedAttestation", + "NaiveAggregatedSignature", "AggregationBits", "AttestationData", "Attestation", "SignedAttestation", - "SignedAggregatedAttestations", + "SignedAggregatedAttestation", "Block", "BlockWithAttestation", "BlockBody", diff --git a/src/lean_spec/subspecs/containers/attestation/__init__.py b/src/lean_spec/subspecs/containers/attestation/__init__.py index 08526865..7e05a530 100644 --- a/src/lean_spec/subspecs/containers/attestation/__init__.py +++ b/src/lean_spec/subspecs/containers/attestation/__init__.py @@ -1,20 +1,20 @@ """Attestation containers and related types for the Lean spec.""" from .attestation import ( - AggregatedAttestations, + AggregatedAttestation, Attestation, AttestationData, - SignedAggregatedAttestations, + SignedAggregatedAttestation, SignedAttestation, ) -from .types import AggregatedSignatures, AggregationBits +from .types import AggregationBits, NaiveAggregatedSignature __all__ = [ "AttestationData", "Attestation", "SignedAttestation", - "SignedAggregatedAttestations", - "AggregatedAttestations", - "AggregatedSignatures", + "SignedAggregatedAttestation", + "AggregatedAttestation", + "NaiveAggregatedSignature", "AggregationBits", ] diff --git a/src/lean_spec/subspecs/containers/attestation/attestation.py b/src/lean_spec/subspecs/containers/attestation/attestation.py index e4983284..c6614251 100644 --- a/src/lean_spec/subspecs/containers/attestation/attestation.py +++ b/src/lean_spec/subspecs/containers/attestation/attestation.py @@ -12,12 +12,17 @@ doesn't do this yet. """ +from __future__ import annotations + +from collections import defaultdict + from lean_spec.subspecs.containers.slot import Slot +from lean_spec.subspecs.ssz import hash_tree_root from lean_spec.types import Container, Uint64 from ...xmss.containers import Signature from ..checkpoint import Checkpoint -from .types import AggregatedSignatures, AggregationBits +from .types import AggregationBits, NaiveAggregatedSignature class AttestationData(Container): @@ -35,6 +40,10 @@ class AttestationData(Container): source: Checkpoint """The checkpoint representing the source block as observed by the validator.""" + def data_root_bytes(self) -> bytes: + """The root of the attestation data.""" + return bytes(hash_tree_root(self)) + class Attestation(Container): """Validator specific attestation wrapping shared attestation data.""" @@ -49,14 +58,17 @@ class Attestation(Container): class SignedAttestation(Container): """Validator attestation bundled with its signature.""" - message: Attestation + validator_id: Uint64 + """The index of the validator making the attestation.""" + + message: AttestationData """The attestation message signed by the validator.""" signature: Signature """Signature aggregation produced by the leanVM (SNARKs in the future).""" -class AggregatedAttestations(Container): +class AggregatedAttestation(Container): """Aggregated attestation consisting of participation bits and message.""" aggregation_bits: AggregationBits @@ -69,18 +81,62 @@ class AggregatedAttestations(Container): committee assignments. """ - -class SignedAggregatedAttestations(Container): + def to_plain(self) -> list[Attestation]: + """ + Expand this aggregated attestation into plain per-validator attestations. + + Returns: + One `Attestation` per participating validator index, all sharing the same + `AttestationData`. + """ + validator_indices = self.aggregation_bits.to_validator_indices() + return [ + Attestation(validator_id=validator_id, data=self.data) + for validator_id in validator_indices + ] + + @classmethod + def aggregate_by_data( + cls, + attestations: list[Attestation], + ) -> list[AggregatedAttestation]: + """ + Aggregate plain per-validator attestations by their shared AttestationData. + + Args: + attestations: Attestations to aggregate. + + Returns: + One AggregatedAttestation per unique AttestationData, with aggregation + bits set for all participating validators. + """ + data_to_validator_ids: dict[AttestationData, list[Uint64]] = defaultdict(list) + for attestation in attestations: + data_to_validator_ids[attestation.data].append(attestation.validator_id) + + return [ + cls( + aggregation_bits=AggregationBits.from_validator_indices(validator_ids), + data=data, + ) + for data, validator_ids in data_to_validator_ids.items() + ] + + +class SignedAggregatedAttestation(Container): """Aggregated attestation bundled with aggregated signatures.""" - message: AggregatedAttestations + message: AggregatedAttestation """Aggregated attestation data.""" - signature: AggregatedSignatures + signature: NaiveAggregatedSignature """Aggregated attestation plus its combined signature. Stores a naive list of validator signatures that mirrors the attestation order. - TODO: this will be replaced by a SNARK in future devnets. + TODO: + - signatures will be replaced by MegaBytes in next PR to include leanVM proof. + - this will be replaced by a SNARK in future devnets. + - this will be aggregated by aggregators in future devnets. """ diff --git a/src/lean_spec/subspecs/containers/attestation/types.py b/src/lean_spec/subspecs/containers/attestation/types.py index c4a7b99a..9dec9bae 100644 --- a/src/lean_spec/subspecs/containers/attestation/types.py +++ b/src/lean_spec/subspecs/containers/attestation/types.py @@ -1,6 +1,8 @@ """Attestation-related SSZ types for the Lean consensus specification.""" -from lean_spec.types import SSZList +from __future__ import annotations + +from lean_spec.types import SSZList, Uint64 from lean_spec.types.bitfields import BaseBitlist from ...chain.config import VALIDATOR_REGISTRY_LIMIT @@ -12,8 +14,52 @@ class AggregationBits(BaseBitlist): LIMIT = int(VALIDATOR_REGISTRY_LIMIT) + @classmethod + def from_validator_indices(cls, indices: list[Uint64]) -> AggregationBits: + """ + Construct aggregation bits from a set of validator indices. + + Args: + indices: Validator indices to set in the bitlist. + + Returns: + AggregationBits with the corresponding indices set to True. + + Raises: + AssertionError: If no indices are provided. + AssertionError: If any index is outside the supported LIMIT. + """ + ids = [int(i) for i in indices] + if not ids: + raise AssertionError("Aggregated attestation must reference at least one validator") + + max_id = max(ids) + if max_id >= cls.LIMIT: + raise AssertionError("Validator index out of range for aggregation bits") + + bits = [False] * (max_id + 1) + for i in ids: + bits[i] = True + + return cls(data=bits) + + def to_validator_indices(self) -> list[Uint64]: + """ + Extract all validator indices encoded in these aggregation bits. + + Returns: + List of validator indices, sorted in ascending order. + + Raises: + AssertionError: If no bits are set. + """ + if not (indices := [Uint64(i) for i, bit in enumerate(self.data) if bool(bit)]): + raise AssertionError("Aggregated attestation must reference at least one validator") + + return indices + -class AggregatedSignatures(SSZList): +class NaiveAggregatedSignature(SSZList): """Naive list of validator signatures used for aggregation placeholders.""" ELEMENT_TYPE = Signature diff --git a/src/lean_spec/subspecs/containers/block/__init__.py b/src/lean_spec/subspecs/containers/block/__init__.py index a3a844cd..4ed7dfa7 100644 --- a/src/lean_spec/subspecs/containers/block/__init__.py +++ b/src/lean_spec/subspecs/containers/block/__init__.py @@ -4,17 +4,22 @@ Block, BlockBody, BlockHeader, + BlockSignatures, BlockWithAttestation, SignedBlockWithAttestation, ) -from .types import Attestations, BlockSignatures +from .types import ( + AggregatedAttestations, + AttestationSignatures, +) __all__ = [ "Block", "BlockBody", "BlockHeader", + "BlockSignatures", "BlockWithAttestation", "SignedBlockWithAttestation", - "Attestations", - "BlockSignatures", + "AggregatedAttestations", + "AttestationSignatures", ] diff --git a/src/lean_spec/subspecs/containers/block/block.py b/src/lean_spec/subspecs/containers/block/block.py index 5c985932..c1eef772 100644 --- a/src/lean_spec/subspecs/containers/block/block.py +++ b/src/lean_spec/subspecs/containers/block/block.py @@ -12,13 +12,16 @@ from typing import TYPE_CHECKING, cast from lean_spec.subspecs.containers.slot import Slot -from lean_spec.subspecs.ssz.hash import hash_tree_root from lean_spec.types import Bytes32, Uint64 from lean_spec.types.container import Container +from ...xmss.containers import Signature as XmssSignature from ..attestation import Attestation from ..validator import Validator -from .types import Attestations, BlockSignatures +from .types import ( + AggregatedAttestations, + AttestationSignatures, +) if TYPE_CHECKING: from ..state import State @@ -32,7 +35,7 @@ class BlockBody(Container): packaged into blocks. """ - attestations: Attestations + attestations: AggregatedAttestations """Plain validator attestations carried in the block body. Individual signatures live in the aggregated block signature list, so @@ -97,6 +100,24 @@ class BlockWithAttestation(Container): """The proposer's attestation corresponding to this block.""" +class BlockSignatures(Container): + """Signature payload for the block.""" + + attestation_signatures: AttestationSignatures + """Attestation signatures for the aggregated attestations in the block body. + + Each entry corresponds to an aggregated attestation from the block body and + contains all XMSS signatures from the participating validators. + + TODO: + - Currently, this is list of lists of signatures. + - The list of signatures will be replaced by a BytesArray to include leanVM aggregated proof. + """ + + proposer_signature: XmssSignature + """Signature for the proposer's attestation.""" + + class SignedBlockWithAttestation(Container): """Envelope carrying a block, an attestation from proposer, and aggregated signatures.""" @@ -137,48 +158,52 @@ def verify_signatures(self, parent_state: "State") -> bool: - Validator index out of range - XMSS signature verification failure """ - # Unpack the signed block components block = self.message.block signatures = self.signature + aggregated_attestations = block.body.attestations + attestation_signatures = signatures.attestation_signatures - # Combine all attestations that need verification - # - # This creates a single list containing both: - # 1. Block body attestations (from other validators) - # 2. Proposer attestation (from the block producer) - all_attestations = block.body.attestations + [self.message.proposer_attestation] - - # Verify signature count matches attestation count - # - # Each attestation must have exactly one corresponding signature. - # - # The ordering must be preserved: - # 1. Block body attestations, - # 2. The proposer attestation. - assert len(signatures) == len(all_attestations), ( - "Number of signatures does not match number of attestations" + assert len(aggregated_attestations) == len(attestation_signatures), ( + "Attestation signature groups must align with block body attestations" ) validators = parent_state.validators - # Verify each attestation signature - for attestation, signature in zip(all_attestations, signatures, strict=True): - # Ensure validator exists in the active set - assert attestation.validator_id < Uint64(len(validators)), ( - "Validator index out of range" + for aggregated_attestation, aggregated_signature in zip( + aggregated_attestations, attestation_signatures, strict=True + ): + validator_ids = aggregated_attestation.aggregation_bits.to_validator_indices() + + assert len(aggregated_signature) == len(validator_ids), ( + "Aggregated attestation signature count mismatch" ) - validator = cast(Validator, validators[attestation.validator_id]) - - # Verify the XMSS signature - # - # This cryptographically proves that: - # - The validator possesses the secret key for their public key - # - The attestation has not been tampered with - # - The signature was created at the correct epoch (slot) - assert signature.verify( - validator.get_pubkey(), - attestation.data.slot, - bytes(hash_tree_root(attestation)), - ), "Attestation signature verification failed" + + attestation_root = aggregated_attestation.data.data_root_bytes() + + # Verify each validator's attestation signature + for validator_id, signature in zip(validator_ids, aggregated_signature, strict=True): + # Ensure validator exists in the active set + assert validator_id < Uint64(len(validators)), "Validator index out of range" + validator = cast(Validator, validators[validator_id]) + + assert signature.verify( + validator.get_pubkey(), + aggregated_attestation.data.slot, + attestation_root, + ), "Attestation signature verification failed" + + # Verify proposer attestation signature + proposer_attestation = self.message.proposer_attestation + proposer_signature = signatures.proposer_signature + assert proposer_attestation.validator_id < Uint64(len(validators)), ( + "Proposer index out of range" + ) + proposer = cast(Validator, validators[proposer_attestation.validator_id]) + + assert proposer_signature.verify( + proposer.get_pubkey(), + proposer_attestation.data.slot, + proposer_attestation.data.data_root_bytes(), + ), "Proposer signature verification failed" return True diff --git a/src/lean_spec/subspecs/containers/block/types.py b/src/lean_spec/subspecs/containers/block/types.py index 0e5f68ff..e602ef20 100644 --- a/src/lean_spec/subspecs/containers/block/types.py +++ b/src/lean_spec/subspecs/containers/block/types.py @@ -3,19 +3,18 @@ from lean_spec.types import SSZList from ...chain.config import VALIDATOR_REGISTRY_LIMIT -from ...xmss.containers import Signature -from ..attestation import Attestation +from ..attestation import AggregatedAttestation, NaiveAggregatedSignature -class Attestations(SSZList): - """List of validator attestations included in a block.""" +class AggregatedAttestations(SSZList): + """List of aggregated attestations included in a block.""" - ELEMENT_TYPE = Attestation + ELEMENT_TYPE = AggregatedAttestation LIMIT = int(VALIDATOR_REGISTRY_LIMIT) -class BlockSignatures(SSZList): - """Aggregated signature list included alongside the block.""" +class AttestationSignatures(SSZList): + """List of per-attestation naive signature lists aligned with block body attestations.""" - ELEMENT_TYPE = Signature + ELEMENT_TYPE = NaiveAggregatedSignature LIMIT = int(VALIDATOR_REGISTRY_LIMIT) diff --git a/src/lean_spec/subspecs/containers/state/state.py b/src/lean_spec/subspecs/containers/state/state.py index 9617f9f6..f6a9476f 100644 --- a/src/lean_spec/subspecs/containers/state/state.py +++ b/src/lean_spec/subspecs/containers/state/state.py @@ -12,12 +12,16 @@ is_proposer, ) -from ..attestation import Attestation, SignedAttestation +from ..attestation import ( + AggregatedAttestation, + Attestation, + SignedAttestation, +) if TYPE_CHECKING: from lean_spec.subspecs.xmss.containers import Signature from ..block import Block, BlockBody, BlockHeader -from ..block.types import Attestations +from ..block.types import AggregatedAttestations from ..checkpoint import Checkpoint from ..config import Config from ..slot import Slot @@ -96,7 +100,7 @@ def generate_genesis(cls, genesis_time: Uint64, validators: Validators) -> "Stat proposer_index=Uint64(0), parent_root=Bytes32.zero(), state_root=Bytes32.zero(), - body_root=hash_tree_root(BlockBody(attestations=Attestations(data=[]))), + body_root=hash_tree_root(BlockBody(attestations=AggregatedAttestations(data=[]))), ) # Assemble and return the full genesis state. @@ -352,16 +356,31 @@ def process_block(self, block: Block) -> "State": ------- State A new state with the processed block. + + Raises: + ------ + AssertionError + If block contains duplicate AttestationData. """ # First process the block header. state = self.process_block_header(block) - # Process justification attestations. - return state.process_attestations(block.body.attestations) + # Process justification attestations by converting aggregated payloads + attestations: list[Attestation] = [] + attestations_data = set() + for aggregated_att in block.body.attestations: + # No partial aggregation is allowed. + if aggregated_att.data in attestations_data: + raise AssertionError("Block contains duplicate AttestationData") + + attestations_data.add(aggregated_att.data) + attestations.extend(aggregated_att.to_plain()) + + return state.process_attestations(attestations) def process_attestations( self, - attestations: Attestations, + attestations: list[Attestation], ) -> "State": """ Apply attestations and update justification/finalization @@ -649,7 +668,11 @@ def build_block( proposer_index=proposer_index, parent_root=parent_root, state_root=Bytes32.zero(), - body=BlockBody(attestations=Attestations(data=attestations)), + body=BlockBody( + attestations=AggregatedAttestations( + data=AggregatedAttestation.aggregate_by_data(attestations) + ) + ), ) # Apply state transition to get the post-block state @@ -664,7 +687,11 @@ def build_block( new_signatures: list[Signature] = [] for signed_attestation in available_signed_attestations: - data = signed_attestation.message.data + data = signed_attestation.message + attestation = Attestation( + validator_id=signed_attestation.validator_id, + data=data, + ) # Skip if target block is unknown if data.head.root not in known_block_roots: @@ -675,8 +702,8 @@ def build_block( continue # Add attestation if not already included - if signed_attestation.message not in attestations: - new_attestations.append(signed_attestation.message) + if attestation not in attestations: + new_attestations.append(attestation) new_signatures.append(signed_attestation.signature) # Fixed point reached: no new attestations found diff --git a/src/lean_spec/subspecs/forkchoice/store.py b/src/lean_spec/subspecs/forkchoice/store.py index 2f48ff13..f6bba9db 100644 --- a/src/lean_spec/subspecs/forkchoice/store.py +++ b/src/lean_spec/subspecs/forkchoice/store.py @@ -211,7 +211,7 @@ def validate_attestation(self, signed_attestation: SignedAttestation) -> None: Raises: AssertionError: If attestation fails validation. """ - data = signed_attestation.message.data + data = signed_attestation.message # Availability Check # @@ -293,11 +293,11 @@ def on_attestation( self.validate_attestation(signed_attestation) # Extract the validator index that produced this attestation. - validator_id = Uint64(signed_attestation.message.validator_id) + validator_id = Uint64(signed_attestation.validator_id) # Extract the attestation's slot: # - used to decide if this attestation is "newer" than a previous one. - attestation_slot = signed_attestation.message.data.slot + attestation_slot = signed_attestation.message.slot # Copy the known attestation map: # - we build a new Store immutably, @@ -321,7 +321,7 @@ def on_attestation( # Update the known attestation for this validator if: # - there is no known attestation yet, or # - this attestation is from a later slot than the known one. - if latest_known is None or latest_known.message.data.slot < attestation_slot: + if latest_known is None or latest_known.message.slot < attestation_slot: new_known[validator_id] = signed_attestation # Fetch any pending ("new") attestation for this validator. @@ -332,7 +332,7 @@ def on_attestation( # - it is from an equal or earlier slot than this on-chain attestation. # # In that case, the on-chain attestation supersedes it. - if existing_new is not None and existing_new.message.data.slot <= attestation_slot: + if existing_new is not None and existing_new.message.slot <= attestation_slot: del new_new[validator_id] else: # Network gossip attestation processing @@ -355,7 +355,7 @@ def on_attestation( # Update the pending attestation for this validator if: # - there is no pending attestation yet, or # - this one is from a later slot than the pending one. - if latest_new is None or latest_new.message.data.slot < attestation_slot: + if latest_new is None or latest_new.message.slot < attestation_slot: new_new[validator_id] = signed_attestation # Return a new Store with updated "known" and "new" attestation maps. @@ -410,7 +410,6 @@ def on_block(self, signed_block_with_attestation: SignedBlockWithAttestation) -> # Unpack block components block = signed_block_with_attestation.message.block proposer_attestation = signed_block_with_attestation.message.proposer_attestation - signatures = signed_block_with_attestation.signature block_root = hash_tree_root(block) # Skip duplicate blocks (idempotent operation) @@ -457,23 +456,35 @@ def on_block(self, signed_block_with_attestation: SignedBlockWithAttestation) -> } ) - # Process block body attestations - # - # Iterate over attestations and their corresponding signatures. - for attestation, signature in zip( - signed_block_with_attestation.message.block.body.attestations, - signed_block_with_attestation.signature, - strict=False, + # Process block body attestations. + aggregated_attestations = signed_block_with_attestation.message.block.body.attestations + attestation_signatures = signed_block_with_attestation.signature.attestation_signatures + + assert len(aggregated_attestations) == len(attestation_signatures), ( + "Attestation signature groups must match aggregated attestations" + ) + + for aggregated_attestation, aggregated_signature in zip( + aggregated_attestations, attestation_signatures, strict=True ): - # Process as on-chain attestation (immediately becomes "known") - store = store.on_attestation( - signed_attestation=SignedAttestation( - message=attestation, - signature=signature, - ), - is_from_block=True, + plain_attestations = aggregated_attestation.to_plain() + + assert len(plain_attestations) == len(aggregated_signature), ( + "Aggregated attestation signature count mismatch" ) + for attestation, signature in zip( + plain_attestations, aggregated_signature, strict=True + ): + store = store.on_attestation( + signed_attestation=SignedAttestation( + validator_id=attestation.validator_id, + message=attestation.data, + signature=signature, + ), + is_from_block=True, + ) + # Update forkchoice head based on new block and attestations # # IMPORTANT: This must happen BEFORE processing proposer attestation @@ -489,8 +500,9 @@ def on_block(self, signed_block_with_attestation: SignedBlockWithAttestation) -> # 3. Influence fork choice only after interval 3 (end of slot) store = store.on_attestation( signed_attestation=SignedAttestation( - message=proposer_attestation, - signature=signatures[len(block.body.attestations)], + validator_id=proposer_attestation.validator_id, + message=proposer_attestation.data, + signature=signed_block_with_attestation.signature.proposer_signature, ), is_from_block=False, ) @@ -552,7 +564,7 @@ def _compute_lmd_ghost_head( # # Each visited block accumulates one unit of weight from that validator. for attestation in attestations.values(): - current_root = attestation.message.data.head.root + current_root = attestation.message.head.root # Climb towards the anchor while staying inside the known tree. # diff --git a/tests/consensus/devnet/state_transition/test_genesis.py b/tests/consensus/devnet/state_transition/test_genesis.py index b991cc30..5e761711 100644 --- a/tests/consensus/devnet/state_transition/test_genesis.py +++ b/tests/consensus/devnet/state_transition/test_genesis.py @@ -11,7 +11,7 @@ from consensus_testing import StateExpectation, StateTransitionTestFiller, generate_pre_state from lean_spec.subspecs.containers.block import Block, BlockBody -from lean_spec.subspecs.containers.block.types import Attestations +from lean_spec.subspecs.containers.block.types import AggregatedAttestations from lean_spec.subspecs.containers.slot import Slot from lean_spec.subspecs.containers.state import State, Validators from lean_spec.subspecs.containers.state.types import ( @@ -55,7 +55,7 @@ def test_genesis_default_configuration( latest_block_header_parent_root=Bytes32.zero(), latest_block_header_state_root=Bytes32.zero(), latest_block_header_body_root=hash_tree_root( - BlockBody(attestations=Attestations(data=[])) + BlockBody(attestations=AggregatedAttestations(data=[])) ), historical_block_hashes=HistoricalBlockHashes(data=[]), justified_slots=JustifiedSlots(data=[]), @@ -100,7 +100,7 @@ def test_genesis_custom_time( latest_block_header_parent_root=Bytes32.zero(), latest_block_header_state_root=Bytes32.zero(), latest_block_header_body_root=hash_tree_root( - BlockBody(attestations=Attestations(data=[])) + BlockBody(attestations=AggregatedAttestations(data=[])) ), historical_block_hashes=HistoricalBlockHashes(data=[]), justified_slots=JustifiedSlots(data=[]), @@ -143,7 +143,7 @@ def test_genesis_custom_validator_set( latest_block_header_parent_root=Bytes32.zero(), latest_block_header_state_root=Bytes32.zero(), latest_block_header_body_root=hash_tree_root( - BlockBody(attestations=Attestations(data=[])) + BlockBody(attestations=AggregatedAttestations(data=[])) ), historical_block_hashes=HistoricalBlockHashes(data=[]), justified_slots=JustifiedSlots(data=[]), @@ -173,7 +173,7 @@ def test_genesis_block_hash_comparison() -> None: proposer_index=Uint64(0), parent_root=Bytes32.zero(), state_root=hash_tree_root(genesis_state1), - body=BlockBody(attestations=Attestations(data=[])), + body=BlockBody(attestations=AggregatedAttestations(data=[])), ) # Compute hash of first genesis block @@ -190,7 +190,7 @@ def test_genesis_block_hash_comparison() -> None: proposer_index=Uint64(0), parent_root=Bytes32.zero(), state_root=hash_tree_root(genesis_state1_copy), - body=BlockBody(attestations=Attestations(data=[])), + body=BlockBody(attestations=AggregatedAttestations(data=[])), ) genesis_block_hash1_copy = hash_tree_root(genesis_block1_copy) @@ -215,7 +215,7 @@ def test_genesis_block_hash_comparison() -> None: proposer_index=Uint64(0), parent_root=Bytes32.zero(), state_root=hash_tree_root(genesis_state2), - body=BlockBody(attestations=Attestations(data=[])), + body=BlockBody(attestations=AggregatedAttestations(data=[])), ) genesis_block_hash2 = hash_tree_root(genesis_block2) @@ -240,7 +240,7 @@ def test_genesis_block_hash_comparison() -> None: proposer_index=Uint64(0), parent_root=Bytes32.zero(), state_root=hash_tree_root(genesis_state3), - body=BlockBody(attestations=Attestations(data=[])), + body=BlockBody(attestations=AggregatedAttestations(data=[])), ) genesis_block_hash3 = hash_tree_root(genesis_block3) diff --git a/tests/lean_spec/subspecs/containers/test_attestation_aggregation.py b/tests/lean_spec/subspecs/containers/test_attestation_aggregation.py new file mode 100644 index 00000000..1d5e2e13 --- /dev/null +++ b/tests/lean_spec/subspecs/containers/test_attestation_aggregation.py @@ -0,0 +1,142 @@ +"""Tests for attestation aggregation and signature ordering.""" + +import pytest + +from lean_spec.subspecs.containers.attestation import ( + AggregatedAttestation, + AggregationBits, + Attestation, + AttestationData, +) +from lean_spec.subspecs.containers.checkpoint import Checkpoint +from lean_spec.subspecs.containers.slot import Slot +from lean_spec.types import Bytes32, Uint64 + + +class TestAttestationAggregation: + """Test proper attestation aggregation by common data.""" + + def test_reject_empty_aggregation_bits(self) -> None: + """Validate aggregated attestation must include at least one validator.""" + bits = AggregationBits(data=[False, False, False]) + with pytest.raises(AssertionError, match="at least one validator"): + bits.to_validator_indices() + + def test_aggregate_attestations_by_common_data(self) -> None: + """Test that attestations with same data are properly aggregated.""" + # Create three attestations with two having common data + att_data1 = AttestationData( + slot=Slot(5), + head=Checkpoint(root=Bytes32.zero(), slot=Slot(4)), + target=Checkpoint(root=Bytes32.zero(), slot=Slot(3)), + source=Checkpoint(root=Bytes32.zero(), slot=Slot(2)), + ) + att_data2 = AttestationData( + slot=Slot(6), + head=Checkpoint(root=Bytes32.zero(), slot=Slot(5)), + target=Checkpoint(root=Bytes32.zero(), slot=Slot(4)), + source=Checkpoint(root=Bytes32.zero(), slot=Slot(3)), + ) + + attestations = [ + Attestation(validator_id=Uint64(1), data=att_data1), + Attestation(validator_id=Uint64(3), data=att_data1), + Attestation(validator_id=Uint64(5), data=att_data2), + ] + + aggregated = AggregatedAttestation.aggregate_by_data(attestations) + + # Should have 2 aggregated attestations (one per unique data) + assert len(aggregated) == 2 + + # Find the aggregated attestation with att_data1 + agg1 = next(agg for agg in aggregated if agg.data == att_data1) + validator_ids1 = agg1.aggregation_bits.to_validator_indices() + + # Should contain validators 1 and 3 + assert set(validator_ids1) == {Uint64(1), Uint64(3)} + + # Find the aggregated attestation with att_data2 + agg2 = next(agg for agg in aggregated if agg.data == att_data2) + validator_ids2 = agg2.aggregation_bits.to_validator_indices() + + # Should contain only validator 5 + assert set(validator_ids2) == {Uint64(5)} + + def test_aggregate_attestations_sets_all_bits(self) -> None: + """Test that aggregation sets all validator bits correctly.""" + att_data = AttestationData( + slot=Slot(5), + head=Checkpoint(root=Bytes32.zero(), slot=Slot(4)), + target=Checkpoint(root=Bytes32.zero(), slot=Slot(3)), + source=Checkpoint(root=Bytes32.zero(), slot=Slot(2)), + ) + + attestations = [ + Attestation(validator_id=Uint64(2), data=att_data), + Attestation(validator_id=Uint64(7), data=att_data), + Attestation(validator_id=Uint64(10), data=att_data), + ] + + aggregated = AggregatedAttestation.aggregate_by_data(attestations) + + assert len(aggregated) == 1 + validator_ids = aggregated[0].aggregation_bits.to_validator_indices() + + # Should have all three validators + assert len(validator_ids) == 3 + assert set(validator_ids) == {Uint64(2), Uint64(7), Uint64(10)} + + def test_aggregate_empty_attestations(self) -> None: + """Test aggregation with no attestations.""" + aggregated = AggregatedAttestation.aggregate_by_data([]) + assert len(aggregated) == 0 + + def test_aggregate_single_attestation(self) -> None: + """Test aggregation with a single attestation.""" + att_data = AttestationData( + slot=Slot(5), + head=Checkpoint(root=Bytes32.zero(), slot=Slot(4)), + target=Checkpoint(root=Bytes32.zero(), slot=Slot(3)), + source=Checkpoint(root=Bytes32.zero(), slot=Slot(2)), + ) + + attestations = [Attestation(validator_id=Uint64(5), data=att_data)] + + aggregated = AggregatedAttestation.aggregate_by_data(attestations) + + assert len(aggregated) == 1 + validator_ids = aggregated[0].aggregation_bits.to_validator_indices() + assert validator_ids == [Uint64(5)] + + +class TestDuplicateAttestationDataValidation: + """Test validation that blocks don't contain duplicate AttestationData.""" + + def test_duplicate_attestation_data_detection(self) -> None: + """Ensure conversion to plain attestations preserves duplicates.""" + att_data = AttestationData( + slot=Slot(1), + head=Checkpoint(root=Bytes32.zero(), slot=Slot(0)), + target=Checkpoint(root=Bytes32.zero(), slot=Slot(0)), + source=Checkpoint(root=Bytes32.zero(), slot=Slot(0)), + ) + + from lean_spec.subspecs.containers.attestation import AggregatedAttestation + from lean_spec.subspecs.containers.attestation.types import AggregationBits + + agg1 = AggregatedAttestation( + aggregation_bits=AggregationBits(data=[False, True]), + data=att_data, + ) + agg2 = AggregatedAttestation( + aggregation_bits=AggregationBits(data=[False, True, True]), + data=att_data, + ) + + plain = [plain_att for aggregated in (agg1, agg2) for plain_att in aggregated.to_plain()] + + # Expect 2 plain attestations (because validator 1 is common in agg1 and agg2) + # validator 1 and validator 2 are the only unique validators in the attestations + assert len(set(plain)) == 2 + assert all(att.data == att_data for att in plain) diff --git a/tests/lean_spec/subspecs/forkchoice/conftest.py b/tests/lean_spec/subspecs/forkchoice/conftest.py index d8b97a0e..9750cf0a 100644 --- a/tests/lean_spec/subspecs/forkchoice/conftest.py +++ b/tests/lean_spec/subspecs/forkchoice/conftest.py @@ -5,14 +5,13 @@ import pytest from lean_spec.subspecs.containers import ( - Attestation, AttestationData, BlockBody, Checkpoint, SignedAttestation, State, ) -from lean_spec.subspecs.containers.block import Attestations, BlockHeader +from lean_spec.subspecs.containers.block import AggregatedAttestations, BlockHeader from lean_spec.subspecs.containers.config import Config from lean_spec.subspecs.containers.slot import Slot from lean_spec.subspecs.containers.state import Validators @@ -45,7 +44,7 @@ def __init__(self, latest_justified: Checkpoint) -> None: proposer_index=Uint64(0), parent_root=Bytes32.zero(), state_root=Bytes32.zero(), - body_root=hash_tree_root(BlockBody(attestations=Attestations(data=[]))), + body_root=hash_tree_root(BlockBody(attestations=AggregatedAttestations(data=[]))), ) super().__init__( @@ -76,12 +75,9 @@ def build_signed_attestation( target=target, source=source_checkpoint, ) - message = Attestation( - validator_id=validator, - data=attestation_data, - ) return SignedAttestation( - message=message, + validator_id=validator, + message=attestation_data, signature=Signature( path=HashTreeOpening(siblings=HashDigestList(data=[])), rho=Randomness(data=[Fp(0) for _ in range(PROD_CONFIG.RAND_LEN_FE)]), diff --git a/tests/lean_spec/subspecs/forkchoice/test_store_attestations.py b/tests/lean_spec/subspecs/forkchoice/test_store_attestations.py new file mode 100644 index 00000000..c651db07 --- /dev/null +++ b/tests/lean_spec/subspecs/forkchoice/test_store_attestations.py @@ -0,0 +1,111 @@ +"""Tests for Store attestation handling.""" + +from consensus_testing.keys import XmssKeyManager + +from lean_spec.subspecs.chain.config import SECONDS_PER_SLOT +from lean_spec.subspecs.containers.attestation import ( + Attestation, + AttestationData, + SignedAttestation, +) +from lean_spec.subspecs.containers.block import ( + Block, + BlockBody, + BlockSignatures, + BlockWithAttestation, + SignedBlockWithAttestation, +) +from lean_spec.subspecs.containers.block.types import ( + AggregatedAttestations, +) +from lean_spec.subspecs.containers.checkpoint import Checkpoint +from lean_spec.subspecs.containers.slot import Slot +from lean_spec.subspecs.containers.state import State, Validators +from lean_spec.subspecs.containers.validator import Validator +from lean_spec.subspecs.forkchoice import Store +from lean_spec.subspecs.ssz.hash import hash_tree_root +from lean_spec.types import Bytes32, Uint64 + + +def test_on_block_processes_multi_validator_aggregations() -> None: + """Ensure Store.on_block handles aggregated attestations with many validators.""" + key_manager = XmssKeyManager(max_slot=Slot(10)) + validators = Validators( + data=[ + Validator(pubkey=key_manager[Uint64(i)].public.encode_bytes(), index=Uint64(i)) + for i in range(3) + ] + ) + genesis_state = State.generate_genesis(genesis_time=Uint64(0), validators=validators) + genesis_block = Block( + slot=Slot(0), + proposer_index=Uint64(0), + parent_root=Bytes32.zero(), + state_root=hash_tree_root(genesis_state), + body=BlockBody(attestations=AggregatedAttestations(data=[])), + ) + + base_store = Store.get_forkchoice_store(genesis_state, genesis_block) + consumer_store = base_store + + # Producer view knows about attestations from validators 1 and 2 + attestation_slot = Slot(1) + attestation_data = base_store.produce_attestation_data(attestation_slot) + signed_attestations = { + validator_id: SignedAttestation( + validator_id=validator_id, + message=attestation_data, + signature=key_manager.sign_attestation_data(validator_id, attestation_data), + ) + for validator_id in (Uint64(1), Uint64(2)) + } + producer_store = base_store.model_copy( + update={"latest_known_attestations": signed_attestations} + ) + + # For slot 1 with 3 validators: 1 % 3 == 1, so validator 1 is the proposer + proposer_index = Uint64(1) + _, block, _ = producer_store.produce_block_with_signatures( + attestation_slot, + proposer_index, + ) + + block_root = hash_tree_root(block) + parent_state = producer_store.states[block.parent_root] + proposer_attestation = Attestation( + validator_id=proposer_index, + data=AttestationData( + slot=attestation_slot, + head=Checkpoint(root=block_root, slot=attestation_slot), + target=Checkpoint(root=block_root, slot=attestation_slot), + source=Checkpoint(root=block.parent_root, slot=parent_state.latest_block_header.slot), + ), + ) + proposer_signature = key_manager.sign_attestation_data( + proposer_attestation.validator_id, + proposer_attestation.data, + ) + + attestation_signatures = key_manager.build_attestation_signatures(block.body.attestations) + + signed_block = SignedBlockWithAttestation( + message=BlockWithAttestation( + block=block, + proposer_attestation=proposer_attestation, + ), + signature=BlockSignatures( + attestation_signatures=attestation_signatures, + proposer_signature=proposer_signature, + ), + ) + + # Advance consumer store time to block's slot before processing + block_time = consumer_store.config.genesis_time + block.slot * Uint64(SECONDS_PER_SLOT) + consumer_store = consumer_store.on_tick(block_time, has_proposal=True) + + updated_store = consumer_store.on_block(signed_block) + + assert Uint64(1) in updated_store.latest_known_attestations + assert Uint64(2) in updated_store.latest_known_attestations + assert updated_store.latest_known_attestations[Uint64(1)].message == attestation_data + assert updated_store.latest_known_attestations[Uint64(2)].message == attestation_data diff --git a/tests/lean_spec/subspecs/forkchoice/test_time_management.py b/tests/lean_spec/subspecs/forkchoice/test_time_management.py index c4e9260b..509f9399 100644 --- a/tests/lean_spec/subspecs/forkchoice/test_time_management.py +++ b/tests/lean_spec/subspecs/forkchoice/test_time_management.py @@ -10,7 +10,7 @@ State, Validator, ) -from lean_spec.subspecs.containers.block import Attestations +from lean_spec.subspecs.containers.block import AggregatedAttestations from lean_spec.subspecs.containers.slot import Slot from lean_spec.subspecs.containers.state import Validators from lean_spec.subspecs.forkchoice import Store @@ -35,7 +35,7 @@ def sample_store(sample_config: Config) -> Store: proposer_index=Uint64(0), parent_root=Bytes32.zero(), state_root=Bytes32(b"state" + b"\x00" * 27), - body=BlockBody(attestations=Attestations(data=[])), + body=BlockBody(attestations=AggregatedAttestations(data=[])), ) genesis_hash = hash_tree_root(genesis_block) @@ -281,7 +281,7 @@ def test_accept_new_attestations_multiple(self, sample_store: Store) -> None: # Verify correct mapping for i, checkpoint in enumerate(checkpoints): stored = sample_store.latest_known_attestations[Uint64(i)] - assert stored.message.data.target == checkpoint + assert stored.message.target == checkpoint def test_accept_new_attestations_empty(self, sample_store: Store) -> None: """Test accepting new attestations when there are none.""" @@ -306,7 +306,7 @@ def test_get_proposal_head_basic(self, sample_store: Store) -> None: proposer_index=Uint64(0), parent_root=Bytes32.zero(), state_root=Bytes32(b"genesis" + b"\x00" * 25), - body=BlockBody(attestations=Attestations(data=[])), + body=BlockBody(attestations=AggregatedAttestations(data=[])), ) genesis_hash = hash_tree_root(genesis_block) @@ -353,7 +353,7 @@ def test_get_proposal_head_processes_attestations(self, sample_store: Store) -> assert Uint64(10) not in store.latest_new_attestations assert Uint64(10) in store.latest_known_attestations stored = store.latest_known_attestations[Uint64(10)] - assert stored.message.data.target == checkpoint + assert stored.message.target == checkpoint class TestTimeConstants: diff --git a/tests/lean_spec/subspecs/forkchoice/test_validator.py b/tests/lean_spec/subspecs/forkchoice/test_validator.py index 54c0b353..400ca09f 100644 --- a/tests/lean_spec/subspecs/forkchoice/test_validator.py +++ b/tests/lean_spec/subspecs/forkchoice/test_validator.py @@ -14,7 +14,7 @@ State, Validator, ) -from lean_spec.subspecs.containers.block import Attestations +from lean_spec.subspecs.containers.block import AggregatedAttestations from lean_spec.subspecs.containers.slot import Slot from lean_spec.subspecs.containers.state import ( HistoricalBlockHashes, @@ -82,7 +82,7 @@ def sample_store(config: Config, sample_state: State) -> Store: proposer_index=Uint64(0), parent_root=Bytes32.zero(), state_root=hash_tree_root(sample_state), - body=BlockBody(attestations=Attestations(data=[])), + body=BlockBody(attestations=AggregatedAttestations(data=[])), ) genesis_hash = hash_tree_root(genesis_block) @@ -134,12 +134,9 @@ def build_signed_attestation( target=target, source=source, ) - message = Attestation( - validator_id=validator, - data=data, - ) return SignedAttestation( - message=message, + validator_id=validator, + message=data, signature=Signature( path=HashTreeOpening(siblings=HashDigestList(data=[])), rho=Randomness(data=[Fp(0) for _ in range(PROD_CONFIG.RAND_LEN_FE)]), @@ -518,7 +515,7 @@ def test_validator_operations_empty_store(self) -> None: config = Config(genesis_time=Uint64(1000)) # Create minimal genesis block first - genesis_body = BlockBody(attestations=Attestations(data=[])) + genesis_body = BlockBody(attestations=AggregatedAttestations(data=[])) # Create validators list with 3 validators validators = Validators( diff --git a/tests/lean_spec/subspecs/ssz/test_block.py b/tests/lean_spec/subspecs/ssz/test_block.py index 685d78bd..a7f880db 100644 --- a/tests/lean_spec/subspecs/ssz/test_block.py +++ b/tests/lean_spec/subspecs/ssz/test_block.py @@ -1,49 +1,63 @@ -from lean_spec.subspecs.containers.attestation import ( - Attestation, - AttestationData, -) -from lean_spec.subspecs.containers.block.block import ( - Block, - BlockBody, - BlockWithAttestation, - SignedBlockWithAttestation, -) -from lean_spec.subspecs.containers.block.types import Attestations, BlockSignatures -from lean_spec.subspecs.containers.checkpoint import Checkpoint -from lean_spec.types import Bytes32, Uint64 - - -def test_encode_decode_signed_block_with_attestation_roundtrip() -> None: - signed_block_with_attestation = SignedBlockWithAttestation( - message=BlockWithAttestation( - block=Block( - slot=0, - proposer_index=Uint64(0), - parent_root=Bytes32.zero(), - state_root=Bytes32.zero(), - body=BlockBody(attestations=Attestations(data=[])), - ), - proposer_attestation=Attestation( - validator_id=Uint64(0), - data=AttestationData( - slot=0, - head=Checkpoint(root=Bytes32.zero(), slot=0), - target=Checkpoint(root=Bytes32.zero(), slot=0), - source=Checkpoint(root=Bytes32.zero(), slot=0), - ), - ), - ), - signature=BlockSignatures(data=[]), - ) - - encode = signed_block_with_attestation.encode_bytes() - expected_value = ( - "08000000ec0000008c000000000000000000000000000000000000000000000000000000000000000000000000" - "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000" - "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000" - "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000" - "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000" - "0000005400000004000000" - ) - assert encode.hex() == expected_value - assert SignedBlockWithAttestation.decode_bytes(encode) == signed_block_with_attestation +from lean_spec.subspecs.containers.attestation import Attestation, AttestationData +from lean_spec.subspecs.containers.block import ( + Block, + BlockBody, + BlockSignatures, + BlockWithAttestation, + SignedBlockWithAttestation, +) +from lean_spec.subspecs.containers.block.types import ( + AggregatedAttestations, + AttestationSignatures, +) +from lean_spec.subspecs.containers.checkpoint import Checkpoint +from lean_spec.subspecs.koalabear import Fp +from lean_spec.subspecs.xmss.constants import PROD_CONFIG +from lean_spec.subspecs.xmss.containers import Signature +from lean_spec.subspecs.xmss.types import HashDigestList, HashTreeOpening, Randomness +from lean_spec.types import Bytes32, Uint64 + + +def test_encode_decode_signed_block_with_attestation_roundtrip() -> None: + signed_block_with_attestation = SignedBlockWithAttestation( + message=BlockWithAttestation( + block=Block( + slot=0, + proposer_index=Uint64(0), + parent_root=Bytes32.zero(), + state_root=Bytes32.zero(), + body=BlockBody(attestations=AggregatedAttestations(data=[])), + ), + proposer_attestation=Attestation( + validator_id=Uint64(0), + data=AttestationData( + slot=0, + head=Checkpoint(root=Bytes32.zero(), slot=0), + target=Checkpoint(root=Bytes32.zero(), slot=0), + source=Checkpoint(root=Bytes32.zero(), slot=0), + ), + ), + ), + signature=BlockSignatures( + attestation_signatures=AttestationSignatures(data=[]), + proposer_signature=Signature( + path=HashTreeOpening(siblings=HashDigestList(data=[])), + rho=Randomness(data=[Fp(0) for _ in range(PROD_CONFIG.RAND_LEN_FE)]), + hashes=HashDigestList(data=[]), + ), + ), + ) + + encode = signed_block_with_attestation.encode_bytes() + expected_value = ( + "08000000ec0000008c00000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000000000000000000000000000000000000000000000" + "00000000000000000000000000000054000000040000000800000008000000240000000" + "00000000000000000000000000000000000000000000000000000002800000004000000" + ) + assert encode.hex() == expected_value, "Encoded value must match hardcoded expected value" + assert SignedBlockWithAttestation.decode_bytes(encode) == signed_block_with_attestation diff --git a/tests/lean_spec/subspecs/ssz/test_signed_attestation.py b/tests/lean_spec/subspecs/ssz/test_signed_attestation.py index a7a416fa..f30e6fc0 100644 --- a/tests/lean_spec/subspecs/ssz/test_signed_attestation.py +++ b/tests/lean_spec/subspecs/ssz/test_signed_attestation.py @@ -1,46 +1,39 @@ -from lean_spec.subspecs.containers import ( - Attestation, - AttestationData, - Checkpoint, - SignedAttestation, -) -from lean_spec.subspecs.koalabear import Fp -from lean_spec.subspecs.xmss.constants import PROD_CONFIG -from lean_spec.subspecs.xmss.containers import Signature -from lean_spec.subspecs.xmss.types import HashDigestList, HashTreeOpening, Randomness -from lean_spec.types import Bytes32, Uint64 - - -def test_encode_decode_signed_attestation_roundtrip() -> None: - signed_attestation = SignedAttestation( - message=Attestation( - validator_id=Uint64(0), - data=AttestationData( - slot=0, - head=Checkpoint(root=Bytes32.zero(), slot=0), - target=Checkpoint(root=Bytes32.zero(), slot=0), - source=Checkpoint(root=Bytes32.zero(), slot=0), - ), - ), - signature=Signature( - path=HashTreeOpening(siblings=HashDigestList(data=[])), - rho=Randomness(data=[Fp(0) for _ in range(PROD_CONFIG.RAND_LEN_FE)]), - hashes=HashDigestList(data=[]), - ), - ) - - # Test that encoding produces the expected hardcoded value - encode = signed_attestation.encode_bytes() - expected_value = ( - "000000000000000000000000000000000000000000000000000000000000" - "000000000000000000000000000000000000000000000000000000000000" - "000000000000000000000000000000000000000000000000000000000000" - "000000000000000000000000000000000000000000000000000000000000" - "000000000000000000000000000000008c00000024000000000000000000" - "000000000000000000000000000000000000000000002800000004000000" - ) - assert encode.hex() == expected_value, "Encoded value must match hardcoded expected value" - - # Test that decoding round-trips correctly - decoded = SignedAttestation.decode_bytes(encode) - assert decoded == signed_attestation +from lean_spec.subspecs.containers import AttestationData, Checkpoint, SignedAttestation +from lean_spec.subspecs.koalabear import Fp +from lean_spec.subspecs.xmss.constants import PROD_CONFIG +from lean_spec.subspecs.xmss.containers import Signature +from lean_spec.subspecs.xmss.types import HashDigestList, HashTreeOpening, Randomness +from lean_spec.types import Bytes32, Uint64 + + +def test_encode_decode_signed_attestation_roundtrip() -> None: + attestation_data = AttestationData( + slot=0, + head=Checkpoint(root=Bytes32.zero(), slot=0), + target=Checkpoint(root=Bytes32.zero(), slot=0), + source=Checkpoint(root=Bytes32.zero(), slot=0), + ) + signed_attestation = SignedAttestation( + validator_id=Uint64(0), + message=attestation_data, + signature=Signature( + path=HashTreeOpening(siblings=HashDigestList(data=[])), + rho=Randomness(data=[Fp(0) for _ in range(PROD_CONFIG.RAND_LEN_FE)]), + hashes=HashDigestList(data=[]), + ), + ) + + # Test that encoding produces the expected hardcoded value + encoded = signed_attestation.encode_bytes() + expected_value = ( + "000000000000000000000000000000000000000000000000000000000000000000000000" + "000000000000000000000000000000000000000000000000000000000000000000000000" + "000000000000000000000000000000000000000000000000000000000000000000000000" + "000000000000000000000000000000000000000000000000000000008c00000024000000" + "000000000000000000000000000000000000000000000000000000002800000004000000" + ) + + assert encoded.hex() == expected_value, "Encoded value must match hardcoded expected value" + # Test that decoding round-trips correctly + decoded = SignedAttestation.decode_bytes(encoded) + assert decoded == signed_attestation