Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions src/lean_spec/subspecs/containers/block/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,12 +163,11 @@ def verify_signatures(self, parent_state: "State") -> bool:

# Verify each attestation signature
for attestation, signature in zip(all_attestations, signatures, strict=True):
# Identify the validator who created this attestation
validator_id = attestation.validator_id.as_int()

# Ensure validator exists in the active set
assert validator_id < len(validators), "Validator index out of range"
validator = cast(Validator, validators[validator_id])
assert attestation.validator_id < Uint64(len(validators)), (
"Validator index out of range"
)
validator = cast(Validator, validators[attestation.validator_id])

# Verify the XMSS signature
#
Expand Down
3 changes: 2 additions & 1 deletion src/lean_spec/subspecs/containers/slot.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ def is_justifiable_after(self, finalized_slot: Slot) -> bool:
assert self >= finalized_slot, "Candidate slot must not be before finalized slot"

# Calculate the distance in slots from the last finalized slot.
delta = (self - finalized_slot).as_int()
# Convert to int for pure arithmetic operations below.
delta = int(self - finalized_slot)

return (
# Rule 1: The first 5 slots after finalization are always justifiable.
Expand Down
27 changes: 11 additions & 16 deletions src/lean_spec/subspecs/containers/state/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,18 +258,16 @@ def process_block_header(self, block: Block) -> "State":
)

# If there were empty slots between parent and this block, fill them.
num_empty_slots = (block.slot - parent_header.slot - Slot(1)).as_int()
num_empty_slots = int(block.slot - parent_header.slot - Slot(1))

# Build new historical hashes list
new_historical_hashes_data = (
self.historical_block_hashes + [parent_root] + ([ZERO_HASH] * num_empty_slots)
self.historical_block_hashes + [parent_root] + [ZERO_HASH] * num_empty_slots
)

# Build new justified slots list
new_justified_slots_data = (
self.justified_slots
+ [Boolean(is_genesis_parent)]
+ ([Boolean(False)] * num_empty_slots)
self.justified_slots + [Boolean(is_genesis_parent)] + [Boolean(False)] * num_empty_slots
)

# Construct the new latest block header.
Expand Down Expand Up @@ -384,27 +382,25 @@ def process_attestations(
# Ignore attestations whose source is not already justified,
# or whose target is not in the history, or whose target is not a
# valid justifiable slot
source_slot = source.slot.as_int()
target_slot = target.slot.as_int()

# Source slot must be justified
if not justified_slots[source_slot]:
if not justified_slots[source.slot]:
continue

# Target slot must not be already justified
# This condition is missing in 3sf mini but has been added here because
# we don't want to re-introduce the target again for remaining votes if
# the slot is already justified and its tracking already cleared out
# from justifications map
if justified_slots[target_slot]:
if justified_slots[target.slot]:
continue

# Source root must match the state's historical block hashes
if source.root != self.historical_block_hashes[source_slot]:
if source.root != self.historical_block_hashes[source.slot]:
continue

# Target root must match the state's historical block hashes
if target.root != self.historical_block_hashes[target_slot]:
if target.root != self.historical_block_hashes[target.slot]:
continue

# Target slot must be after source slot
Expand All @@ -419,9 +415,8 @@ def process_attestations(
if target.root not in justifications:
justifications[target.root] = [Boolean(False)] * self.validators.count

validator_id = attestation.validator_id.as_int()
if not justifications[target.root][validator_id]:
justifications[target.root][validator_id] = Boolean(True)
if not justifications[target.root][attestation.validator_id]:
justifications[target.root][attestation.validator_id] = Boolean(True)

count = sum(bool(justified) for justified in justifications[target.root])

Expand All @@ -432,14 +427,14 @@ def process_attestations(
# justifying specially if the num_validators is low in testing scenarios
if 3 * count >= (2 * self.validators.count):
latest_justified = target
justified_slots[target_slot] = True
justified_slots[target.slot] = True
del justifications[target.root]

# Finalization: if the target is the next valid justifiable
# hash after the source
if not any(
Slot(slot).is_justifiable_after(self.latest_finalized.slot)
for slot in range(source_slot + 1, target_slot)
for slot in range(source.slot + Slot(1), target.slot)
):
latest_finalized = source

Expand Down
13 changes: 5 additions & 8 deletions src/lean_spec/subspecs/containers/state/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,36 +11,33 @@ class HistoricalBlockHashes(SSZList):
"""List of historical block root hashes up to historical_roots_limit."""

ELEMENT_TYPE = Bytes32
LIMIT = DEVNET_CONFIG.historical_roots_limit.as_int()
LIMIT = int(DEVNET_CONFIG.historical_roots_limit)


class JustificationRoots(SSZList):
"""List of justified block roots up to historical_roots_limit."""

ELEMENT_TYPE = Bytes32
LIMIT = DEVNET_CONFIG.historical_roots_limit.as_int()
LIMIT = int(DEVNET_CONFIG.historical_roots_limit)


class JustifiedSlots(BaseBitlist):
"""Bitlist tracking justified slots up to historical roots limit."""

LIMIT = DEVNET_CONFIG.historical_roots_limit.as_int()
LIMIT = int(DEVNET_CONFIG.historical_roots_limit)


class JustificationValidators(BaseBitlist):
"""Bitlist for tracking validator justifications per historical root."""

LIMIT = (
DEVNET_CONFIG.historical_roots_limit.as_int()
* DEVNET_CONFIG.validator_registry_limit.as_int()
)
LIMIT = int(DEVNET_CONFIG.historical_roots_limit) * int(DEVNET_CONFIG.validator_registry_limit)


class Validators(SSZList):
"""Validator registry tracked in the state."""

ELEMENT_TYPE = Validator
LIMIT = DEVNET_CONFIG.validator_registry_limit.as_int()
LIMIT = int(DEVNET_CONFIG.validator_registry_limit)

@property
def count(self) -> int:
Expand Down
4 changes: 1 addition & 3 deletions src/lean_spec/subspecs/networking/gossipsub/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,7 @@ class GossipsubParameters(StrictBaseModel):
"""The number of history windows to gossip about."""

seen_ttl_secs: int = (
DEVNET_CONFIG.seconds_per_slot.as_int()
* DEVNET_CONFIG.justification_lookback_slots.as_int()
* 2
int(DEVNET_CONFIG.seconds_per_slot) * int(DEVNET_CONFIG.justification_lookback_slots) * 2
)
"""
The expiry time in seconds for the cache of seen message IDs.
Expand Down
8 changes: 4 additions & 4 deletions src/lean_spec/types/uint.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,10 +170,6 @@ def max_value(cls) -> Self:
"""The maximum value for this unsigned integer."""
return cls(2**cls.BITS - 1)

def as_int(self) -> int:
"""Convert the unsigned integer to a plain integer."""
return int(self)

def to_bytes(
self,
length: SupportsIndex | None = None,
Expand Down Expand Up @@ -392,6 +388,10 @@ def __hash__(self) -> int:
"""Return a distinct hash for the object."""
return hash((type(self), int(self)))

def __index__(self) -> int:
"""Return self as an integer for use in slicing and indexing."""
return int(self)


class Uint8(BaseUint):
"""A type representing an 8-bit unsigned integer (uint8)."""
Expand Down
14 changes: 7 additions & 7 deletions tests/lean_spec/subspecs/forkchoice/test_time_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,8 @@ def test_on_tick_already_current(self, sample_store: Store) -> None:
# Try to advance to current time (should be no-op)
sample_store = sample_store.on_tick(current_target, has_proposal=True)

# Should not change significantly
assert abs(sample_store.time.as_int() - initial_time.as_int()) <= 10 # small tolerance
# Should not change significantly (time can only increase)
assert sample_store.time - initial_time <= Uint64(10) # small tolerance

def test_on_tick_small_increment(self, sample_store: Store) -> None:
"""Test on_tick with small time increment."""
Expand Down Expand Up @@ -156,12 +156,12 @@ def test_tick_interval_actions_by_phase(self, sample_store: Store) -> None:
)

# Tick through a complete slot cycle
for interval in range(INTERVALS_PER_SLOT.as_int()):
for interval in range(INTERVALS_PER_SLOT):
has_proposal = interval == 0 # Proposal only in first interval
sample_store = sample_store.tick_interval(has_proposal=has_proposal)

current_interval = sample_store.time % INTERVALS_PER_SLOT
expected_interval = Uint64((interval + 1) % INTERVALS_PER_SLOT.as_int())
expected_interval = Uint64((interval + 1)) % INTERVALS_PER_SLOT
assert current_interval == expected_interval


Expand All @@ -175,15 +175,15 @@ def test_slot_to_time_conversion(self, sample_config: Config) -> None:
genesis_time = sample_config.genesis_time

# Slot 0 should be at genesis time
slot_0_time = genesis_time + Uint64(0 * SECONDS_PER_SLOT.as_int())
slot_0_time = genesis_time + Uint64(0) * SECONDS_PER_SLOT
assert slot_0_time == genesis_time

# Slot 1 should be at genesis + SECONDS_PER_SLOT
slot_1_time = genesis_time + Uint64(1 * SECONDS_PER_SLOT.as_int())
slot_1_time = genesis_time + Uint64(1) * SECONDS_PER_SLOT
assert slot_1_time == genesis_time + SECONDS_PER_SLOT

# Slot 10 should be at genesis + 10 * SECONDS_PER_SLOT
slot_10_time = genesis_time + Uint64(10 * SECONDS_PER_SLOT.as_int())
slot_10_time = genesis_time + Uint64(10) * SECONDS_PER_SLOT
assert slot_10_time == genesis_time + Uint64(10) * SECONDS_PER_SLOT

def test_time_to_slot_conversion(self, sample_config: Config) -> None:
Expand Down
59 changes: 58 additions & 1 deletion tests/lean_spec/types/test_uint.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def test_arithmetic_operators(uint_class: Type[BaseUint]) -> None:
assert uint_class(b_val) ** 4 == uint_class(b_val**4)
if uint_class.BITS <= 16: # Pow gets too big quickly
with pytest.raises(OverflowError):
_ = a ** b.as_int()
_ = a ** int(b)


@pytest.mark.parametrize("uint_class", ALL_UINT_TYPES)
Expand Down Expand Up @@ -241,6 +241,63 @@ def test_hash(uint_class: Type[BaseUint]) -> None:
assert hash(uint_class(1)) != hash(uint_class(2))


@pytest.mark.parametrize("uint_class", ALL_UINT_TYPES)
def test_index_list_access(uint_class: Type[BaseUint]) -> None:
"""Tests that Uint types can be used directly for list indexing."""
data = ["a", "b", "c", "d", "e"]
idx = uint_class(2)
assert data[idx] == "c"
assert data[uint_class(0)] == "a"
assert data[uint_class(4)] == "e"


@pytest.mark.parametrize("uint_class", ALL_UINT_TYPES)
def test_index_slicing(uint_class: Type[BaseUint]) -> None:
"""Tests that Uint types can be used in slice operations."""
data = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
start = uint_class(2)
stop = uint_class(7)
step = uint_class(2)

assert data[start:stop] == [2, 3, 4, 5, 6]
assert data[:stop] == [0, 1, 2, 3, 4, 5, 6]
assert data[start:] == [2, 3, 4, 5, 6, 7, 8, 9]
assert data[start:stop:step] == [2, 4, 6]


@pytest.mark.parametrize("uint_class", ALL_UINT_TYPES)
def test_index_range(uint_class: Type[BaseUint]) -> None:
"""Tests that Uint types can be used in range()."""
n = uint_class(5)
result = list(range(n))
assert result == [0, 1, 2, 3, 4]

start = uint_class(2)
stop = uint_class(8)
step = uint_class(2)
result = list(range(start, stop, step))
assert result == [2, 4, 6]


@pytest.mark.parametrize("uint_class", ALL_UINT_TYPES)
def test_index_hex_bin_oct(uint_class: Type[BaseUint]) -> None:
"""Tests that Uint types work with hex(), bin(), oct()."""
val = uint_class(42)
assert hex(val) == "0x2a"
assert bin(val) == "0b101010"
assert oct(val) == "0o52"


@pytest.mark.parametrize("uint_class", ALL_UINT_TYPES)
def test_index_operator_index(uint_class: Type[BaseUint]) -> None:
"""Tests that operator.index() works with Uint types."""
import operator

val = uint_class(42)
assert operator.index(val) == 42
assert isinstance(operator.index(val), int)


@pytest.mark.parametrize("uint_class", ALL_UINT_TYPES)
def test_to_bytes_default(uint_class: Type[BaseUint]) -> None:
"""Tests the default behavior of the to_bytes method."""
Expand Down
Loading