Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions CLAUDE.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,11 @@ uv run fill --fork=devnet --clean -n auto # Generate test vectors
### Code Quality

```bash
uv run ruff format src tests packages # Format code
uv run ruff check --fix src tests packages # Lint and fix
uvx tox -e typecheck # Type check
uvx tox -e all-checks # All quality checks
uvx tox # Everything (checks + tests + docs)
uv run ruff format # Format code
uv run ruff check --fix # Lint and fix
uvx tox -e typecheck # Type check
uvx tox -e all-checks # All quality checks
uvx tox # Everything (checks + tests + docs)
```

### Common Tasks
Expand Down
34 changes: 17 additions & 17 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -97,16 +97,16 @@ uv run fill --clean --fork=devnet

```bash
# Check code style and errors
uv run ruff check src tests packages
uv run ruff check

# Auto-fix issues
uv run ruff check --fix src tests packages
uv run ruff check --fix

# Format code
uv run ruff format src tests packages
uv run ruff format

# Type checking
uv run mypy src tests packages
uv run ty check
```

### Using Tox for Comprehensive Checks
Expand Down Expand Up @@ -179,24 +179,24 @@ def test_withdrawal_amount_above_uint64_max():
- **pytest**: Testing framework - just name test files `test_*.py` and functions `test_*`
- **uv**: Fast Python package manager - like npm/yarn but for Python
- **ruff**: Linter and formatter
- **mypy**: Type checker that works with Pydantic models
- **ty**: Type checker
- **tox**: Automation tool for running tests across multiple environments (used via `uvx`)
- **mkdocs**: Documentation generator - write docs in Markdown, serve them locally

## Common Commands Reference

| Task | Command |
|-----------------------------------------------|----------------------------------------------|
| Install and sync project and dev dependencies | `uv sync` |
| Run tests | `uv run pytest ...` |
| Format code | `uv run ruff format src tests packages` |
| Lint code | `uv run ruff check src tests packages` |
| Fix lint errors | `uv run ruff check --fix src tests packages` |
| Type check | `uv run mypy src tests packages` |
| Build docs | `uv run mkdocs build` |
| Serve docs | `uv run mkdocs serve` |
| Run everything (checks + tests + docs) | `uvx tox` |
| Run all quality checks (no tests/docs) | `uvx tox -e all-checks` |
| Task | Command |
|-----------------------------------------------|---------------------------|
| Install and sync project and dev dependencies | `uv sync` |
| Run tests | `uv run pytest ...` |
| Format code | `uv run ruff format` |
| Lint code | `uv run ruff check` |
| Fix lint errors | `uv run ruff check --fix` |
| Type check | `uv run ty check` |
| Build docs | `uv run mkdocs build` |
| Serve docs | `uv run mkdocs serve` |
| Run everything (checks + tests + docs) | `uvx tox` |
| Run all quality checks (no tests/docs) | `uvx tox -e all-checks` |

## Contributing

Expand Down
6 changes: 3 additions & 3 deletions packages/testing/src/framework/forks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,9 @@ def __init_subclass__(
cls._children = set()

# Track parent-child relationships
base_class = cls.__bases__[0]
if base_class != BaseFork and hasattr(base_class, "_children"):
base_class._children.add(cls)
for base in cls.__bases__:
if base is not BaseFork and issubclass(base, BaseFork):
base._children.add(cls)

@classmethod
@abstractmethod
Expand Down
17 changes: 6 additions & 11 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -66,16 +66,11 @@ known-first-party = ["lean_spec"]
[tool.ruff.lint.per-file-ignores]
"tests/**" = ["D", "F401", "F403"]

[tool.mypy]
python_version = "3.12"
plugins = ["pydantic.mypy"]
strict = true
warn_return_any = true
warn_unused_configs = true
warn_unused_ignores = true
no_implicit_reexport = true
namespace_packages = false
explicit_package_bases = false
[tool.ty.environment]
python-version = "3.12"

[tool.ty.terminal]
error-on-warning = true

[tool.pytest.ini_options]
minversion = "8.3.3"
Expand Down Expand Up @@ -125,7 +120,7 @@ test = [
"lean-ethereum-testing",
]
lint = [
"mypy>=1.17.0,<2",
"ty>=0.0.1a34",
"ruff>=0.13.2,<1",
"codespell>=2.4.1,<3",
]
Expand Down
2 changes: 1 addition & 1 deletion src/lean_spec/subspecs/containers/attestation/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def to_validator_indices(self) -> list[Uint64]:
return indices


class NaiveAggregatedSignature(SSZList):
class NaiveAggregatedSignature(SSZList[Signature]):
"""Naive list of validator signatures used for aggregation placeholders."""

ELEMENT_TYPE = Signature
Expand Down
10 changes: 2 additions & 8 deletions src/lean_spec/subspecs/containers/block/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,12 @@
from ..attestation import AggregatedAttestation, AttestationData, NaiveAggregatedSignature


class AggregatedAttestations(SSZList):
class AggregatedAttestations(SSZList[AggregatedAttestation]):
"""List of aggregated attestations included in a block."""

ELEMENT_TYPE = AggregatedAttestation
LIMIT = int(VALIDATOR_REGISTRY_LIMIT)

def __getitem__(self, index: int) -> AggregatedAttestation:
"""Access an aggregated attestation by index with proper typing."""
item = self.data[index]
assert isinstance(item, AggregatedAttestation)
return item

def has_duplicate_data(self) -> bool:
"""Check if any two attestations share the same AttestationData."""
seen: set[AttestationData] = set()
Expand All @@ -28,7 +22,7 @@ def has_duplicate_data(self) -> bool:
return False


class AttestationSignatures(SSZList):
class AttestationSignatures(SSZList[NaiveAggregatedSignature]):
"""List of per-attestation naive signature lists aligned with block body attestations."""

ELEMENT_TYPE = NaiveAggregatedSignature
Expand Down
6 changes: 3 additions & 3 deletions src/lean_spec/subspecs/containers/state/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,8 +333,8 @@ def process_block_header(self, block: Block) -> "State":
update={
"latest_justified": new_latest_justified,
"latest_finalized": new_latest_finalized,
"historical_block_hashes": HistoricalBlockHashes(data=new_historical_hashes_data),
"justified_slots": JustifiedSlots(data=new_justified_slots_data),
"historical_block_hashes": new_historical_hashes_data,
"justified_slots": new_justified_slots_data,
"latest_block_header": new_header,
}
)
Expand Down Expand Up @@ -570,7 +570,7 @@ def process_attestations(
"justifications_validators": JustificationValidators(
data=[vote for root in sorted_roots for vote in justifications[root]]
),
"justified_slots": JustifiedSlots(data=justified_slots),
"justified_slots": justified_slots,
"latest_justified": latest_justified,
"latest_finalized": latest_finalized,
}
Expand Down
12 changes: 3 additions & 9 deletions src/lean_spec/subspecs/containers/state/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@
from ..validator import Validator


class HistoricalBlockHashes(SSZList):
class HistoricalBlockHashes(SSZList[Bytes32]):
"""List of historical block root hashes up to historical_roots_limit."""

ELEMENT_TYPE = Bytes32
LIMIT = int(DEVNET_CONFIG.historical_roots_limit)


class JustificationRoots(SSZList):
class JustificationRoots(SSZList[Bytes32]):
"""List of justified block roots up to historical_roots_limit."""

ELEMENT_TYPE = Bytes32
Expand All @@ -33,14 +33,8 @@ class JustificationValidators(BaseBitlist):
LIMIT = int(DEVNET_CONFIG.historical_roots_limit) * int(DEVNET_CONFIG.validator_registry_limit)


class Validators(SSZList):
class Validators(SSZList[Validator]):
"""Validator registry tracked in the state."""

ELEMENT_TYPE = Validator
LIMIT = int(DEVNET_CONFIG.validator_registry_limit)

def __getitem__(self, index: int) -> Validator:
"""Access a validator by index with proper typing."""
item = self.data[index]
assert isinstance(item, Validator)
return item
4 changes: 2 additions & 2 deletions src/lean_spec/subspecs/networking/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
MAX_REQUEST_BLOCKS: Final = 2**10
"""Maximum number of blocks in a single request."""

MESSAGE_DOMAIN_INVALID_SNAPPY: Final = DomainType(b"\x00\x00\x00\x00")
MESSAGE_DOMAIN_INVALID_SNAPPY: Final[DomainType] = b"\x00\x00\x00\x00"
"""4-byte domain for gossip message-id isolation of invalid snappy messages."""

MESSAGE_DOMAIN_VALID_SNAPPY: Final = DomainType(b"\x01\x00\x00\x00")
MESSAGE_DOMAIN_VALID_SNAPPY: Final[DomainType] = b"\x01\x00\x00\x00"
"""4-byte domain for gossip message-id isolation of valid snappy messages."""
4 changes: 2 additions & 2 deletions src/lean_spec/subspecs/networking/gossipsub/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,8 @@ def id(self) -> MessageId:
self.raw_data,
)

# Compute the raw ID bytes and cast to our strict type before caching
self._id = MessageId(self._compute_raw_id(domain, data_for_hash))
# Compute the raw ID bytes and assign with proper type annotation
self._id: MessageId = self._compute_raw_id(domain, data_for_hash)
return self._id

def _compute_raw_id(self, domain: bytes, message_data: bytes) -> bytes:
Expand Down
8 changes: 2 additions & 6 deletions src/lean_spec/subspecs/xmss/message_hash.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@

from __future__ import annotations

from typing import List, cast
from typing import List

from pydantic import model_validator

Expand Down Expand Up @@ -191,11 +191,7 @@ def apply(

# The input is: rho || P || epoch || message || iteration.
combined_input = (
cast(List[Fp], list(rho.data))
+ cast(List[Fp], list(parameter.data))
+ epoch_fe
+ message_fe
+ iteration_separator
list(rho.data) + list(parameter.data) + epoch_fe + message_fe + iteration_separator
)

# Hash the combined input using Poseidon2 compression mode.
Expand Down
10 changes: 5 additions & 5 deletions src/lean_spec/subspecs/xmss/subtree.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def new(
parents = [
hasher.apply(
parameter,
TreeTweak(level=level + 1, index=int(parent_start) + i),
TreeTweak(level=level + 1, index=Uint64(int(parent_start) + i)),
[current.nodes[2 * i], current.nodes[2 * i + 1]],
)
for i in range(len(current.nodes) // 2)
Expand Down Expand Up @@ -394,7 +394,7 @@ def from_prf_key(
chain_ends.append(end_digest)

# Hash the chain ends to get the leaf for this epoch.
leaf_tweak = TreeTweak(level=0, index=epoch)
leaf_tweak = TreeTweak(level=0, index=Uint64(epoch))
leaf_hash = hasher.apply(parameter, leaf_tweak, chain_ends)
leaf_hashes.append(leaf_hash)

Expand Down Expand Up @@ -537,7 +537,7 @@ def combined_path(
# Concatenate: bottom path + top path.
bottom_path = bottom_tree.path(position)
top_path = top_tree.path(position // leafs_per_tree)
combined = bottom_path.siblings.data + top_path.siblings.data
combined = tuple(bottom_path.siblings.data) + tuple(top_path.siblings.data)

return HashTreeOpening(siblings=HashDigestList(data=combined))

Expand Down Expand Up @@ -600,7 +600,7 @@ def verify_path(
# Start: hash leaf parts to get leaf node.
current = hasher.apply(
parameter,
TreeTweak(level=0, index=int(position)),
TreeTweak(level=0, index=Uint64(position)),
leaf_parts,
)
pos = int(position)
Expand All @@ -616,7 +616,7 @@ def verify_path(
pos //= 2 # Parent position.
current = hasher.apply(
parameter,
TreeTweak(level=level + 1, index=pos),
TreeTweak(level=level + 1, index=Uint64(pos)),
[left, right],
)

Expand Down
30 changes: 5 additions & 25 deletions src/lean_spec/subspecs/xmss/types.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
"""Base types for the XMSS signature scheme."""

from typing import List

from lean_spec.subspecs.koalabear import Fp

from ...types import Uint64
Expand Down Expand Up @@ -49,7 +47,7 @@ class PRFKey(BaseBytes):
"""


class HashDigestVector(SSZVector):
class HashDigestVector(SSZVector[Fp]):
"""
A single hash digest represented as a fixed-size vector of field elements.

Expand All @@ -63,13 +61,8 @@ class HashDigestVector(SSZVector):
ELEMENT_TYPE = Fp
LENGTH = HASH_DIGEST_LENGTH

@property
def elements(self) -> List[Fp]:
"""Return the field elements as a typed list."""
return list(self.data) # type: ignore[arg-type]


class HashDigestList(SSZList):
class HashDigestList(SSZList[HashDigestVector]):
"""
Variable-length list of hash digests.

Expand All @@ -81,12 +74,8 @@ class HashDigestList(SSZList):
ELEMENT_TYPE = HashDigestVector
LIMIT = NODE_LIST_LIMIT

def __getitem__(self, index: int) -> HashDigestVector:
"""Access a hash digest by index with proper typing."""
return self.data[index] # type: ignore[return-value]


class Parameter(SSZVector):
class Parameter(SSZVector[Fp]):
"""
The public parameter P.

Expand All @@ -100,13 +89,8 @@ class Parameter(SSZVector):
ELEMENT_TYPE = Fp
LENGTH = PROD_CONFIG.PARAMETER_LEN

@property
def elements(self) -> List[Fp]:
"""Return the field elements as a typed list."""
return list(self.data) # type: ignore[arg-type]


class Randomness(SSZVector):
class Randomness(SSZVector[Fp]):
"""
The randomness `rho` (ρ) used during signing.

Expand Down Expand Up @@ -164,7 +148,7 @@ class HashTreeLayer(Container):
"""


class HashTreeLayers(SSZList):
class HashTreeLayers(SSZList[HashTreeLayer]):
"""
Variable-length list of Merkle tree layers.

Expand All @@ -180,7 +164,3 @@ class HashTreeLayers(SSZList):

ELEMENT_TYPE = HashTreeLayer
LIMIT = LAYERS_LIMIT

def __getitem__(self, index: int) -> HashTreeLayer:
"""Access a layer by index with proper typing."""
return self.data[index] # type: ignore[return-value]
Loading
Loading