Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions src/lean_spec/types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,18 @@
from .byte_arrays import ZERO_HASH, Bytes32, Bytes52, Bytes3116
from .collections import SSZList, SSZVector
from .container import Container
from .exceptions import (
SSZError,
SSZSerializationError,
SSZTypeError,
SSZValueError,
)
from .ssz_base import SSZType
from .uint import Uint64
from .validator import is_proposer

__all__ = [
# Core types
"Uint64",
"BasisPoint",
"Bytes32",
Expand All @@ -25,4 +32,9 @@
"SSZType",
"Boolean",
"Container",
# Exceptions
"SSZError",
"SSZTypeError",
"SSZValueError",
"SSZSerializationError",
]
33 changes: 19 additions & 14 deletions src/lean_spec/types/bitfields.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from typing_extensions import Self

from .boolean import Boolean
from .exceptions import SSZSerializationError, SSZTypeError, SSZValueError
from .ssz_base import SSZModel


Expand All @@ -55,13 +56,15 @@ class BaseBitvector(SSZModel):
def _coerce_and_validate(cls, v: Any) -> tuple[Boolean, ...]:
"""Validate and convert input data to typed tuple of Booleans."""
if not hasattr(cls, "LENGTH"):
raise TypeError(f"{cls.__name__} must define LENGTH")
raise SSZTypeError(f"{cls.__name__} must define LENGTH")

if not isinstance(v, (list, tuple)):
v = tuple(v)

if len(v) != cls.LENGTH:
raise ValueError(f"{cls.__name__} requires exactly {cls.LENGTH} bits, got {len(v)}")
raise SSZValueError(
f"{cls.__name__} requires exactly {cls.LENGTH} elements, got {len(v)}"
)

return tuple(Boolean(bit) for bit in v)

Expand All @@ -86,10 +89,12 @@ def deserialize(cls, stream: IO[bytes], scope: int) -> Self:
"""Read SSZ bytes from a stream and return an instance."""
expected_len = cls.get_byte_length()
if scope != expected_len:
raise ValueError(f"{cls.__name__}: expected {expected_len} bytes, got {scope}")
raise SSZSerializationError(
f"{cls.__name__}: expected {expected_len} bytes, got {scope}"
)
data = stream.read(scope)
if len(data) != scope:
raise IOError(f"Expected {scope} bytes, got {len(data)}")
raise SSZSerializationError(f"{cls.__name__}: expected {scope} bytes, got {len(data)}")
return cls.decode_bytes(data)

def encode_bytes(self) -> bytes:
Expand All @@ -115,7 +120,7 @@ def decode_bytes(cls, data: bytes) -> Self:
"""
expected = cls.get_byte_length()
if len(data) != expected:
raise ValueError(f"{cls.__name__}: expected {expected} bytes, got {len(data)}")
raise SSZValueError(f"{cls.__name__}: expected {expected} bytes, got {len(data)}")

bits = tuple(Boolean((data[i // 8] >> (i % 8)) & 1) for i in range(cls.LENGTH))
return cls(data=bits)
Expand Down Expand Up @@ -144,19 +149,19 @@ class BaseBitlist(SSZModel):
def _coerce_and_validate(cls, v: Any) -> tuple[Boolean, ...]:
"""Validate and convert input to a tuple of Boolean elements."""
if not hasattr(cls, "LIMIT"):
raise TypeError(f"{cls.__name__} must define LIMIT")
raise SSZTypeError(f"{cls.__name__} must define LIMIT")

# Handle various input types
if isinstance(v, (list, tuple)):
elements = v
elif hasattr(v, "__iter__") and not isinstance(v, (str, bytes)):
elements = list(v)
else:
raise TypeError(f"Bitlist data must be iterable, got {type(v)}")
raise SSZTypeError(f"Expected iterable, got {type(v).__name__}")

# Check limit
if len(elements) > cls.LIMIT:
raise ValueError(f"{cls.__name__} cannot exceed {cls.LIMIT} bits, got {len(elements)}")
raise SSZValueError(f"{cls.__name__} exceeds limit of {cls.LIMIT}, got {len(elements)}")

return tuple(Boolean(bit) for bit in elements)

Expand Down Expand Up @@ -197,8 +202,8 @@ def is_fixed_size(cls) -> bool:

@classmethod
def get_byte_length(cls) -> int:
"""Lists are variable-size, so this raises a TypeError."""
raise TypeError(f"{cls.__name__} is variable-size")
"""Lists are variable-size, so this raises an SSZTypeError."""
raise SSZTypeError(f"{cls.__name__}: variable-size bitlist has no fixed byte length")

def serialize(self, stream: IO[bytes]) -> int:
"""Write SSZ bytes to a binary stream."""
Expand All @@ -211,7 +216,7 @@ def deserialize(cls, stream: IO[bytes], scope: int) -> Self:
"""Read SSZ bytes from a stream and return an instance."""
data = stream.read(scope)
if len(data) != scope:
raise IOError(f"Expected {scope} bytes, got {len(data)}")
raise SSZSerializationError(f"{cls.__name__}: expected {scope} bytes, got {len(data)}")
return cls.decode_bytes(data)

def encode_bytes(self) -> bytes:
Expand Down Expand Up @@ -254,7 +259,7 @@ def decode_bytes(cls, data: bytes) -> Self:
the last data bit. All bits after the delimiter are assumed to be 0.
"""
if len(data) == 0:
raise ValueError("Cannot decode empty bytes to Bitlist")
raise SSZSerializationError(f"{cls.__name__}: cannot decode empty bytes")

# Find the position of the delimiter bit (rightmost 1).
delimiter_pos = None
Expand All @@ -267,12 +272,12 @@ def decode_bytes(cls, data: bytes) -> Self:
break

if delimiter_pos is None:
raise ValueError("No delimiter bit found in Bitlist data")
raise SSZSerializationError(f"{cls.__name__}: no delimiter bit found")

# Extract data bits (everything before the delimiter).
num_bits = delimiter_pos
if num_bits > cls.LIMIT:
raise ValueError(f"{cls.__name__} decoded length {num_bits} exceeds limit {cls.LIMIT}")
raise SSZValueError(f"{cls.__name__} exceeds limit of {cls.LIMIT}, got {num_bits}")

bits = tuple(Boolean((data[i // 8] >> (i % 8)) & 1) for i in range(num_bits))
return cls(data=bits)
17 changes: 9 additions & 8 deletions src/lean_spec/types/boolean.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from pydantic_core import CoreSchema, core_schema
from typing_extensions import Self

from .exceptions import SSZSerializationError, SSZTypeError, SSZValueError
from .ssz_base import SSZType


Expand All @@ -31,14 +32,14 @@ def __new__(cls, value: bool | int) -> Self:
Accepts only `True`, `False`, `1`, or `0`.

Raises:
TypeError: If `value` is not a bool or int.
ValueError: If `value` is an integer other than 0 or 1.
SSZTypeCoercionError: If `value` is not a bool or int.
SSZDecodeError: If `value` is an integer other than 0 or 1.
"""
if not isinstance(value, int):
raise TypeError(f"Expected bool or int, got {type(value).__name__}")
raise SSZTypeError(f"Expected bool or int, got {type(value).__name__}")

if value not in (0, 1):
raise ValueError(f"Boolean value must be 0 or 1, not {value}")
raise SSZValueError(f"Boolean value must be 0 or 1, not {value}")

return super().__new__(cls, value)

Expand Down Expand Up @@ -93,9 +94,9 @@ def encode_bytes(self) -> bytes:
def decode_bytes(cls, data: bytes) -> Self:
"""Deserialize a single byte into a Boolean instance."""
if len(data) != 1:
raise ValueError(f"Expected 1 byte for Boolean, got {len(data)}")
raise SSZSerializationError(f"Boolean: expected 1 byte, got {len(data)}")
if data[0] not in (0, 1):
raise ValueError(f"Boolean byte must be 0x00 or 0x01, got {data[0]:#04x}")
raise SSZSerializationError(f"Boolean: byte must be 0x00 or 0x01, got {data[0]:#04x}")
return cls(data[0])

def serialize(self, stream: IO[bytes]) -> int:
Expand All @@ -108,10 +109,10 @@ def serialize(self, stream: IO[bytes]) -> int:
def deserialize(cls, stream: IO[bytes], scope: int) -> Self:
"""Deserialize a boolean from a binary stream."""
if scope != 1:
raise ValueError(f"Invalid scope for Boolean: expected 1, got {scope}")
raise SSZSerializationError(f"Boolean: expected scope of 1, got {scope}")
data = stream.read(1)
if len(data) != 1:
raise IOError("Stream ended prematurely while decoding Boolean")
raise SSZSerializationError(f"Boolean: expected 1 byte, got {len(data)}")
return cls.decode_bytes(data)

def _raise_type_error(self, other: Any, op_symbol: str) -> None:
Expand Down
41 changes: 22 additions & 19 deletions src/lean_spec/types/byte_arrays.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from pydantic_core import core_schema
from typing_extensions import Self

from .exceptions import SSZSerializationError, SSZTypeError, SSZValueError
from .ssz_base import SSZModel, SSZType


Expand Down Expand Up @@ -64,14 +65,15 @@ def __new__(cls, value: Any = b"") -> Self:
value: Any value coercible to bytes (see `_coerce_to_bytes`).

Raises:
ValueError: If the resulting byte length differs from `LENGTH`.
SSZTypeDefinitionError: If the class doesn't define LENGTH.
SSZLengthError: If the resulting byte length differs from `LENGTH`.
"""
if not hasattr(cls, "LENGTH"):
raise TypeError(f"{cls.__name__} must define LENGTH")
raise SSZTypeError(f"{cls.__name__} must define LENGTH")

b = _coerce_to_bytes(value)
if len(b) != cls.LENGTH:
raise ValueError(f"{cls.__name__} expects exactly {cls.LENGTH} bytes, got {len(b)}")
raise SSZValueError(f"{cls.__name__} requires exactly {cls.LENGTH} bytes, got {len(b)}")
return super().__new__(cls, b)

@classmethod
Expand Down Expand Up @@ -112,16 +114,14 @@ def deserialize(cls, stream: IO[bytes], scope: int) -> Self:
For a fixed-size type, `scope` must match `LENGTH`.

Raises:
ValueError: if `scope` != `LENGTH`.
IOError: if the stream ends prematurely.
SSZDecodeError: if `scope` != `LENGTH`.
SSZStreamError: if the stream ends prematurely.
"""
if scope != cls.LENGTH:
raise ValueError(
f"Invalid scope for ByteVector[{cls.LENGTH}]: expected {cls.LENGTH}, got {scope}"
)
raise SSZSerializationError(f"{cls.__name__}: expected {cls.LENGTH} bytes, got {scope}")
data = stream.read(scope)
if len(data) != scope:
raise IOError("Stream ended prematurely while decoding ByteVector")
raise SSZSerializationError(f"{cls.__name__}: expected {scope} bytes, got {len(data)}")
return cls(data)

def encode_bytes(self) -> bytes:
Expand All @@ -136,7 +136,9 @@ def decode_bytes(cls, data: bytes) -> Self:
For a fixed-size type, the data must be exactly `LENGTH` bytes.
"""
if len(data) != cls.LENGTH:
raise ValueError(f"{cls.__name__} expects exactly {cls.LENGTH} bytes, got {len(data)}")
raise SSZValueError(
f"{cls.__name__} requires exactly {cls.LENGTH} bytes, got {len(data)}"
)
return cls(data)

@classmethod
Expand Down Expand Up @@ -262,11 +264,11 @@ class BaseByteList(SSZModel):
def _validate_byte_list_data(cls, v: Any) -> bytes:
"""Validate and convert input to bytes with limit checking."""
if not hasattr(cls, "LIMIT"):
raise TypeError(f"{cls.__name__} must define LIMIT")
raise SSZTypeError(f"{cls.__name__} must define LIMIT")

b = _coerce_to_bytes(v)
if len(b) > cls.LIMIT:
raise ValueError(f"ByteList[{cls.LIMIT}] length {len(b)} exceeds limit {cls.LIMIT}")
raise SSZValueError(f"{cls.__name__} exceeds limit of {cls.LIMIT}, got {len(b)}")
return b

@field_serializer("data", when_used="json")
Expand All @@ -282,7 +284,7 @@ def is_fixed_size(cls) -> bool:
@classmethod
def get_byte_length(cls) -> int:
"""ByteList is variable-size, so this should not be called."""
raise TypeError(f"{cls.__name__} is variable-size and has no fixed byte length")
raise SSZTypeError(f"{cls.__name__}: variable-size byte list has no fixed byte length")

def serialize(self, stream: IO[bytes]) -> int:
"""
Expand All @@ -303,16 +305,17 @@ def deserialize(cls, stream: IO[bytes], scope: int) -> Self:
knows how many bytes belong to this value in its context).

Raises:
ValueError: if the decoded length exceeds `LIMIT`.
IOError: if the stream ends prematurely.
SSZDecodeError: if the scope is negative.
SSZLengthError: if the decoded length exceeds `LIMIT`.
SSZStreamError: if the stream ends prematurely.
"""
if scope < 0:
raise ValueError("Invalid scope for ByteList: negative")
raise SSZSerializationError(f"{cls.__name__}: negative scope")
if scope > cls.LIMIT:
raise ValueError(f"ByteList[{cls.LIMIT}] scope {scope} exceeds limit")
raise SSZValueError(f"{cls.__name__} exceeds limit of {cls.LIMIT}, got {scope}")
data = stream.read(scope)
if len(data) != scope:
raise IOError("Stream ended prematurely while decoding ByteList")
raise SSZSerializationError(f"{cls.__name__}: expected {scope} bytes, got {len(data)}")
return cls(data=data)

def encode_bytes(self) -> bytes:
Expand All @@ -327,7 +330,7 @@ def decode_bytes(cls, data: bytes) -> Self:
For variable-size types, the data length must be `<= LIMIT`.
"""
if len(data) > cls.LIMIT:
raise ValueError(f"ByteList[{cls.LIMIT}] length {len(data)} exceeds limit")
raise SSZValueError(f"{cls.__name__} exceeds limit of {cls.LIMIT}, got {len(data)}")
return cls(data=data)

def __bytes__(self) -> bytes:
Expand Down
Loading
Loading