Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions src/lean_spec/subspecs/containers/block/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,12 +163,11 @@ def verify_signatures(self, parent_state: "State") -> bool:

# Verify each attestation signature
for attestation, signature in zip(all_attestations, signatures, strict=True):
# Identify the validator who created this attestation
validator_id = attestation.validator_id.as_int()

# Ensure validator exists in the active set
assert validator_id < len(validators), "Validator index out of range"
validator = cast(Validator, validators[validator_id])
assert attestation.validator_id < Uint64(len(validators)), (
"Validator index out of range"
)
validator = cast(Validator, validators[attestation.validator_id])

# Verify the XMSS signature
#
Expand Down
3 changes: 2 additions & 1 deletion src/lean_spec/subspecs/containers/slot.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ def is_justifiable_after(self, finalized_slot: Slot) -> bool:
assert self >= finalized_slot, "Candidate slot must not be before finalized slot"

# Calculate the distance in slots from the last finalized slot.
delta = (self - finalized_slot).as_int()
# Convert to int for pure arithmetic operations below.
delta = int(self - finalized_slot)

return (
# Rule 1: The first 5 slots after finalization are always justifiable.
Expand Down
21 changes: 9 additions & 12 deletions src/lean_spec/subspecs/containers/state/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ def process_block_header(self, block: Block) -> "State":
# If slots were skipped (missed proposals), we must record them.
#
# Formula: (Current - Parent - 1). Adjacent blocks have a gap of 0.
num_empty_slots = (block.slot - parent_header.slot - Slot(1)).as_int()
num_empty_slots = int(block.slot - parent_header.slot - Slot(1))

# Update the list of historical block roots.
#
Expand Down Expand Up @@ -425,27 +425,25 @@ def process_attestations(
# Ignore attestations whose source is not already justified,
# or whose target is not in the history, or whose target is not a
# valid justifiable slot
source_slot = source.slot.as_int()
target_slot = target.slot.as_int()

# Source slot must be justified
if not justified_slots[source_slot]:
if not justified_slots[source.slot]:
continue

# Target slot must not be already justified
# This condition is missing in 3sf mini but has been added here because
# we don't want to re-introduce the target again for remaining votes if
# the slot is already justified and its tracking already cleared out
# from justifications map
if justified_slots[target_slot]:
if justified_slots[target.slot]:
continue

# Source root must match the state's historical block hashes
if source.root != self.historical_block_hashes[source_slot]:
if source.root != self.historical_block_hashes[source.slot]:
continue

# Target root must match the state's historical block hashes
if target.root != self.historical_block_hashes[target_slot]:
if target.root != self.historical_block_hashes[target.slot]:
continue

# Target slot must be after source slot
Expand All @@ -460,9 +458,8 @@ def process_attestations(
if target.root not in justifications:
justifications[target.root] = [Boolean(False)] * self.validators.count

validator_id = attestation.validator_id.as_int()
if not justifications[target.root][validator_id]:
justifications[target.root][validator_id] = Boolean(True)
if not justifications[target.root][attestation.validator_id]:
justifications[target.root][attestation.validator_id] = Boolean(True)

count = sum(bool(justified) for justified in justifications[target.root])

Expand All @@ -473,14 +470,14 @@ def process_attestations(
# justifying specially if the num_validators is low in testing scenarios
if 3 * count >= (2 * self.validators.count):
latest_justified = target
justified_slots[target_slot] = True
justified_slots[target.slot] = True
del justifications[target.root]

# Finalization: if the target is the next valid justifiable
# hash after the source
if not any(
Slot(slot).is_justifiable_after(self.latest_finalized.slot)
for slot in range(source_slot + 1, target_slot)
for slot in range(source.slot + Slot(1), target.slot)
):
latest_finalized = source

Expand Down
13 changes: 5 additions & 8 deletions src/lean_spec/subspecs/containers/state/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,36 +11,33 @@ class HistoricalBlockHashes(SSZList):
"""List of historical block root hashes up to historical_roots_limit."""

ELEMENT_TYPE = Bytes32
LIMIT = DEVNET_CONFIG.historical_roots_limit.as_int()
LIMIT = int(DEVNET_CONFIG.historical_roots_limit)


class JustificationRoots(SSZList):
"""List of justified block roots up to historical_roots_limit."""

ELEMENT_TYPE = Bytes32
LIMIT = DEVNET_CONFIG.historical_roots_limit.as_int()
LIMIT = int(DEVNET_CONFIG.historical_roots_limit)


class JustifiedSlots(BaseBitlist):
"""Bitlist tracking justified slots up to historical roots limit."""

LIMIT = DEVNET_CONFIG.historical_roots_limit.as_int()
LIMIT = int(DEVNET_CONFIG.historical_roots_limit)


class JustificationValidators(BaseBitlist):
"""Bitlist for tracking validator justifications per historical root."""

LIMIT = (
DEVNET_CONFIG.historical_roots_limit.as_int()
* DEVNET_CONFIG.validator_registry_limit.as_int()
)
LIMIT = int(DEVNET_CONFIG.historical_roots_limit) * int(DEVNET_CONFIG.validator_registry_limit)


class Validators(SSZList):
"""Validator registry tracked in the state."""

ELEMENT_TYPE = Validator
LIMIT = DEVNET_CONFIG.validator_registry_limit.as_int()
LIMIT = int(DEVNET_CONFIG.validator_registry_limit)

@property
def count(self) -> int:
Expand Down
4 changes: 1 addition & 3 deletions src/lean_spec/subspecs/networking/gossipsub/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,7 @@ class GossipsubParameters(StrictBaseModel):
"""The number of history windows to gossip about."""

seen_ttl_secs: int = (
DEVNET_CONFIG.seconds_per_slot.as_int()
* DEVNET_CONFIG.justification_lookback_slots.as_int()
* 2
int(DEVNET_CONFIG.seconds_per_slot) * int(DEVNET_CONFIG.justification_lookback_slots) * 2
)
"""
The expiry time in seconds for the cache of seen message IDs.
Expand Down
8 changes: 4 additions & 4 deletions src/lean_spec/types/uint.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,10 +170,6 @@ def max_value(cls) -> Self:
"""The maximum value for this unsigned integer."""
return cls(2**cls.BITS - 1)

def as_int(self) -> int:
"""Convert the unsigned integer to a plain integer."""
return int(self)

def to_bytes(
self,
length: SupportsIndex | None = None,
Expand Down Expand Up @@ -392,6 +388,10 @@ def __hash__(self) -> int:
"""Return a distinct hash for the object."""
return hash((type(self), int(self)))

def __index__(self) -> int:
"""Return self as an integer for use in slicing and indexing."""
return int(self)


class Uint8(BaseUint):
"""A type representing an 8-bit unsigned integer (uint8)."""
Expand Down
14 changes: 7 additions & 7 deletions tests/lean_spec/subspecs/forkchoice/test_time_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,8 @@ def test_on_tick_already_current(self, sample_store: Store) -> None:
# Try to advance to current time (should be no-op)
sample_store = sample_store.on_tick(current_target, has_proposal=True)

# Should not change significantly
assert abs(sample_store.time.as_int() - initial_time.as_int()) <= 10 # small tolerance
# Should not change significantly (time can only increase)
assert sample_store.time - initial_time <= Uint64(10) # small tolerance

def test_on_tick_small_increment(self, sample_store: Store) -> None:
"""Test on_tick with small time increment."""
Expand Down Expand Up @@ -156,12 +156,12 @@ def test_tick_interval_actions_by_phase(self, sample_store: Store) -> None:
)

# Tick through a complete slot cycle
for interval in range(INTERVALS_PER_SLOT.as_int()):
for interval in range(INTERVALS_PER_SLOT):
has_proposal = interval == 0 # Proposal only in first interval
sample_store = sample_store.tick_interval(has_proposal=has_proposal)

current_interval = sample_store.time % INTERVALS_PER_SLOT
expected_interval = Uint64((interval + 1) % INTERVALS_PER_SLOT.as_int())
expected_interval = Uint64((interval + 1)) % INTERVALS_PER_SLOT
assert current_interval == expected_interval


Expand All @@ -175,15 +175,15 @@ def test_slot_to_time_conversion(self, sample_config: Config) -> None:
genesis_time = sample_config.genesis_time

# Slot 0 should be at genesis time
slot_0_time = genesis_time + Uint64(0 * SECONDS_PER_SLOT.as_int())
slot_0_time = genesis_time + Uint64(0) * SECONDS_PER_SLOT
assert slot_0_time == genesis_time

# Slot 1 should be at genesis + SECONDS_PER_SLOT
slot_1_time = genesis_time + Uint64(1 * SECONDS_PER_SLOT.as_int())
slot_1_time = genesis_time + Uint64(1) * SECONDS_PER_SLOT
assert slot_1_time == genesis_time + SECONDS_PER_SLOT

# Slot 10 should be at genesis + 10 * SECONDS_PER_SLOT
slot_10_time = genesis_time + Uint64(10 * SECONDS_PER_SLOT.as_int())
slot_10_time = genesis_time + Uint64(10) * SECONDS_PER_SLOT
assert slot_10_time == genesis_time + Uint64(10) * SECONDS_PER_SLOT

def test_time_to_slot_conversion(self, sample_config: Config) -> None:
Expand Down
59 changes: 58 additions & 1 deletion tests/lean_spec/types/test_uint.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def test_arithmetic_operators(uint_class: Type[BaseUint]) -> None:
assert uint_class(b_val) ** 4 == uint_class(b_val**4)
if uint_class.BITS <= 16: # Pow gets too big quickly
with pytest.raises(OverflowError):
_ = a ** b.as_int()
_ = a ** int(b)


@pytest.mark.parametrize("uint_class", ALL_UINT_TYPES)
Expand Down Expand Up @@ -241,6 +241,63 @@ def test_hash(uint_class: Type[BaseUint]) -> None:
assert hash(uint_class(1)) != hash(uint_class(2))


@pytest.mark.parametrize("uint_class", ALL_UINT_TYPES)
def test_index_list_access(uint_class: Type[BaseUint]) -> None:
"""Tests that Uint types can be used directly for list indexing."""
data = ["a", "b", "c", "d", "e"]
idx = uint_class(2)
assert data[idx] == "c"
assert data[uint_class(0)] == "a"
assert data[uint_class(4)] == "e"


@pytest.mark.parametrize("uint_class", ALL_UINT_TYPES)
def test_index_slicing(uint_class: Type[BaseUint]) -> None:
"""Tests that Uint types can be used in slice operations."""
data = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
start = uint_class(2)
stop = uint_class(7)
step = uint_class(2)

assert data[start:stop] == [2, 3, 4, 5, 6]
assert data[:stop] == [0, 1, 2, 3, 4, 5, 6]
assert data[start:] == [2, 3, 4, 5, 6, 7, 8, 9]
assert data[start:stop:step] == [2, 4, 6]


@pytest.mark.parametrize("uint_class", ALL_UINT_TYPES)
def test_index_range(uint_class: Type[BaseUint]) -> None:
"""Tests that Uint types can be used in range()."""
n = uint_class(5)
result = list(range(n))
assert result == [0, 1, 2, 3, 4]

start = uint_class(2)
stop = uint_class(8)
step = uint_class(2)
result = list(range(start, stop, step))
assert result == [2, 4, 6]


@pytest.mark.parametrize("uint_class", ALL_UINT_TYPES)
def test_index_hex_bin_oct(uint_class: Type[BaseUint]) -> None:
"""Tests that Uint types work with hex(), bin(), oct()."""
val = uint_class(42)
assert hex(val) == "0x2a"
assert bin(val) == "0b101010"
assert oct(val) == "0o52"


@pytest.mark.parametrize("uint_class", ALL_UINT_TYPES)
def test_index_operator_index(uint_class: Type[BaseUint]) -> None:
"""Tests that operator.index() works with Uint types."""
import operator

val = uint_class(42)
assert operator.index(val) == 42
assert isinstance(operator.index(val), int)


@pytest.mark.parametrize("uint_class", ALL_UINT_TYPES)
def test_to_bytes_default(uint_class: Type[BaseUint]) -> None:
"""Tests the default behavior of the to_bytes method."""
Expand Down
Loading