diff --git a/src/lean_spec/types/__init__.py b/src/lean_spec/types/__init__.py index cd3e46b0..0182d312 100644 --- a/src/lean_spec/types/__init__.py +++ b/src/lean_spec/types/__init__.py @@ -6,11 +6,18 @@ from .byte_arrays import ZERO_HASH, Bytes32, Bytes52, Bytes3116 from .collections import SSZList, SSZVector from .container import Container +from .exceptions import ( + SSZError, + SSZSerializationError, + SSZTypeError, + SSZValueError, +) from .ssz_base import SSZType from .uint import Uint64 from .validator import is_proposer __all__ = [ + # Core types "Uint64", "BasisPoint", "Bytes32", @@ -25,4 +32,9 @@ "SSZType", "Boolean", "Container", + # Exceptions + "SSZError", + "SSZTypeError", + "SSZValueError", + "SSZSerializationError", ] diff --git a/src/lean_spec/types/bitfields.py b/src/lean_spec/types/bitfields.py index f14021aa..cb9b8b1b 100644 --- a/src/lean_spec/types/bitfields.py +++ b/src/lean_spec/types/bitfields.py @@ -29,6 +29,7 @@ from typing_extensions import Self from .boolean import Boolean +from .exceptions import SSZSerializationError, SSZTypeError, SSZValueError from .ssz_base import SSZModel @@ -55,13 +56,15 @@ class BaseBitvector(SSZModel): def _coerce_and_validate(cls, v: Any) -> tuple[Boolean, ...]: """Validate and convert input data to typed tuple of Booleans.""" if not hasattr(cls, "LENGTH"): - raise TypeError(f"{cls.__name__} must define LENGTH") + raise SSZTypeError(f"{cls.__name__} must define LENGTH") if not isinstance(v, (list, tuple)): v = tuple(v) if len(v) != cls.LENGTH: - raise ValueError(f"{cls.__name__} requires exactly {cls.LENGTH} bits, got {len(v)}") + raise SSZValueError( + f"{cls.__name__} requires exactly {cls.LENGTH} elements, got {len(v)}" + ) return tuple(Boolean(bit) for bit in v) @@ -86,10 +89,12 @@ def deserialize(cls, stream: IO[bytes], scope: int) -> Self: """Read SSZ bytes from a stream and return an instance.""" expected_len = cls.get_byte_length() if scope != expected_len: - raise ValueError(f"{cls.__name__}: expected {expected_len} bytes, got {scope}") + raise SSZSerializationError( + f"{cls.__name__}: expected {expected_len} bytes, got {scope}" + ) data = stream.read(scope) if len(data) != scope: - raise IOError(f"Expected {scope} bytes, got {len(data)}") + raise SSZSerializationError(f"{cls.__name__}: expected {scope} bytes, got {len(data)}") return cls.decode_bytes(data) def encode_bytes(self) -> bytes: @@ -115,7 +120,7 @@ def decode_bytes(cls, data: bytes) -> Self: """ expected = cls.get_byte_length() if len(data) != expected: - raise ValueError(f"{cls.__name__}: expected {expected} bytes, got {len(data)}") + raise SSZValueError(f"{cls.__name__}: expected {expected} bytes, got {len(data)}") bits = tuple(Boolean((data[i // 8] >> (i % 8)) & 1) for i in range(cls.LENGTH)) return cls(data=bits) @@ -144,7 +149,7 @@ class BaseBitlist(SSZModel): def _coerce_and_validate(cls, v: Any) -> tuple[Boolean, ...]: """Validate and convert input to a tuple of Boolean elements.""" if not hasattr(cls, "LIMIT"): - raise TypeError(f"{cls.__name__} must define LIMIT") + raise SSZTypeError(f"{cls.__name__} must define LIMIT") # Handle various input types if isinstance(v, (list, tuple)): @@ -152,11 +157,11 @@ def _coerce_and_validate(cls, v: Any) -> tuple[Boolean, ...]: elif hasattr(v, "__iter__") and not isinstance(v, (str, bytes)): elements = list(v) else: - raise TypeError(f"Bitlist data must be iterable, got {type(v)}") + raise SSZTypeError(f"Expected iterable, got {type(v).__name__}") # Check limit if len(elements) > cls.LIMIT: - raise ValueError(f"{cls.__name__} cannot exceed {cls.LIMIT} bits, got {len(elements)}") + raise SSZValueError(f"{cls.__name__} exceeds limit of {cls.LIMIT}, got {len(elements)}") return tuple(Boolean(bit) for bit in elements) @@ -197,8 +202,8 @@ def is_fixed_size(cls) -> bool: @classmethod def get_byte_length(cls) -> int: - """Lists are variable-size, so this raises a TypeError.""" - raise TypeError(f"{cls.__name__} is variable-size") + """Lists are variable-size, so this raises an SSZTypeError.""" + raise SSZTypeError(f"{cls.__name__}: variable-size bitlist has no fixed byte length") def serialize(self, stream: IO[bytes]) -> int: """Write SSZ bytes to a binary stream.""" @@ -211,7 +216,7 @@ def deserialize(cls, stream: IO[bytes], scope: int) -> Self: """Read SSZ bytes from a stream and return an instance.""" data = stream.read(scope) if len(data) != scope: - raise IOError(f"Expected {scope} bytes, got {len(data)}") + raise SSZSerializationError(f"{cls.__name__}: expected {scope} bytes, got {len(data)}") return cls.decode_bytes(data) def encode_bytes(self) -> bytes: @@ -254,7 +259,7 @@ def decode_bytes(cls, data: bytes) -> Self: the last data bit. All bits after the delimiter are assumed to be 0. """ if len(data) == 0: - raise ValueError("Cannot decode empty bytes to Bitlist") + raise SSZSerializationError(f"{cls.__name__}: cannot decode empty bytes") # Find the position of the delimiter bit (rightmost 1). delimiter_pos = None @@ -267,12 +272,12 @@ def decode_bytes(cls, data: bytes) -> Self: break if delimiter_pos is None: - raise ValueError("No delimiter bit found in Bitlist data") + raise SSZSerializationError(f"{cls.__name__}: no delimiter bit found") # Extract data bits (everything before the delimiter). num_bits = delimiter_pos if num_bits > cls.LIMIT: - raise ValueError(f"{cls.__name__} decoded length {num_bits} exceeds limit {cls.LIMIT}") + raise SSZValueError(f"{cls.__name__} exceeds limit of {cls.LIMIT}, got {num_bits}") bits = tuple(Boolean((data[i // 8] >> (i % 8)) & 1) for i in range(num_bits)) return cls(data=bits) diff --git a/src/lean_spec/types/boolean.py b/src/lean_spec/types/boolean.py index e461c73e..82012e59 100644 --- a/src/lean_spec/types/boolean.py +++ b/src/lean_spec/types/boolean.py @@ -8,6 +8,7 @@ from pydantic_core import CoreSchema, core_schema from typing_extensions import Self +from .exceptions import SSZSerializationError, SSZTypeError, SSZValueError from .ssz_base import SSZType @@ -31,14 +32,14 @@ def __new__(cls, value: bool | int) -> Self: Accepts only `True`, `False`, `1`, or `0`. Raises: - TypeError: If `value` is not a bool or int. - ValueError: If `value` is an integer other than 0 or 1. + SSZTypeCoercionError: If `value` is not a bool or int. + SSZDecodeError: If `value` is an integer other than 0 or 1. """ if not isinstance(value, int): - raise TypeError(f"Expected bool or int, got {type(value).__name__}") + raise SSZTypeError(f"Expected bool or int, got {type(value).__name__}") if value not in (0, 1): - raise ValueError(f"Boolean value must be 0 or 1, not {value}") + raise SSZValueError(f"Boolean value must be 0 or 1, not {value}") return super().__new__(cls, value) @@ -93,9 +94,9 @@ def encode_bytes(self) -> bytes: def decode_bytes(cls, data: bytes) -> Self: """Deserialize a single byte into a Boolean instance.""" if len(data) != 1: - raise ValueError(f"Expected 1 byte for Boolean, got {len(data)}") + raise SSZSerializationError(f"Boolean: expected 1 byte, got {len(data)}") if data[0] not in (0, 1): - raise ValueError(f"Boolean byte must be 0x00 or 0x01, got {data[0]:#04x}") + raise SSZSerializationError(f"Boolean: byte must be 0x00 or 0x01, got {data[0]:#04x}") return cls(data[0]) def serialize(self, stream: IO[bytes]) -> int: @@ -108,10 +109,10 @@ def serialize(self, stream: IO[bytes]) -> int: def deserialize(cls, stream: IO[bytes], scope: int) -> Self: """Deserialize a boolean from a binary stream.""" if scope != 1: - raise ValueError(f"Invalid scope for Boolean: expected 1, got {scope}") + raise SSZSerializationError(f"Boolean: expected scope of 1, got {scope}") data = stream.read(1) if len(data) != 1: - raise IOError("Stream ended prematurely while decoding Boolean") + raise SSZSerializationError(f"Boolean: expected 1 byte, got {len(data)}") return cls.decode_bytes(data) def _raise_type_error(self, other: Any, op_symbol: str) -> None: diff --git a/src/lean_spec/types/byte_arrays.py b/src/lean_spec/types/byte_arrays.py index 5a360085..83273624 100644 --- a/src/lean_spec/types/byte_arrays.py +++ b/src/lean_spec/types/byte_arrays.py @@ -16,6 +16,7 @@ from pydantic_core import core_schema from typing_extensions import Self +from .exceptions import SSZSerializationError, SSZTypeError, SSZValueError from .ssz_base import SSZModel, SSZType @@ -64,14 +65,15 @@ def __new__(cls, value: Any = b"") -> Self: value: Any value coercible to bytes (see `_coerce_to_bytes`). Raises: - ValueError: If the resulting byte length differs from `LENGTH`. + SSZTypeDefinitionError: If the class doesn't define LENGTH. + SSZLengthError: If the resulting byte length differs from `LENGTH`. """ if not hasattr(cls, "LENGTH"): - raise TypeError(f"{cls.__name__} must define LENGTH") + raise SSZTypeError(f"{cls.__name__} must define LENGTH") b = _coerce_to_bytes(value) if len(b) != cls.LENGTH: - raise ValueError(f"{cls.__name__} expects exactly {cls.LENGTH} bytes, got {len(b)}") + raise SSZValueError(f"{cls.__name__} requires exactly {cls.LENGTH} bytes, got {len(b)}") return super().__new__(cls, b) @classmethod @@ -112,16 +114,14 @@ def deserialize(cls, stream: IO[bytes], scope: int) -> Self: For a fixed-size type, `scope` must match `LENGTH`. Raises: - ValueError: if `scope` != `LENGTH`. - IOError: if the stream ends prematurely. + SSZDecodeError: if `scope` != `LENGTH`. + SSZStreamError: if the stream ends prematurely. """ if scope != cls.LENGTH: - raise ValueError( - f"Invalid scope for ByteVector[{cls.LENGTH}]: expected {cls.LENGTH}, got {scope}" - ) + raise SSZSerializationError(f"{cls.__name__}: expected {cls.LENGTH} bytes, got {scope}") data = stream.read(scope) if len(data) != scope: - raise IOError("Stream ended prematurely while decoding ByteVector") + raise SSZSerializationError(f"{cls.__name__}: expected {scope} bytes, got {len(data)}") return cls(data) def encode_bytes(self) -> bytes: @@ -136,7 +136,9 @@ def decode_bytes(cls, data: bytes) -> Self: For a fixed-size type, the data must be exactly `LENGTH` bytes. """ if len(data) != cls.LENGTH: - raise ValueError(f"{cls.__name__} expects exactly {cls.LENGTH} bytes, got {len(data)}") + raise SSZValueError( + f"{cls.__name__} requires exactly {cls.LENGTH} bytes, got {len(data)}" + ) return cls(data) @classmethod @@ -262,11 +264,11 @@ class BaseByteList(SSZModel): def _validate_byte_list_data(cls, v: Any) -> bytes: """Validate and convert input to bytes with limit checking.""" if not hasattr(cls, "LIMIT"): - raise TypeError(f"{cls.__name__} must define LIMIT") + raise SSZTypeError(f"{cls.__name__} must define LIMIT") b = _coerce_to_bytes(v) if len(b) > cls.LIMIT: - raise ValueError(f"ByteList[{cls.LIMIT}] length {len(b)} exceeds limit {cls.LIMIT}") + raise SSZValueError(f"{cls.__name__} exceeds limit of {cls.LIMIT}, got {len(b)}") return b @field_serializer("data", when_used="json") @@ -282,7 +284,7 @@ def is_fixed_size(cls) -> bool: @classmethod def get_byte_length(cls) -> int: """ByteList is variable-size, so this should not be called.""" - raise TypeError(f"{cls.__name__} is variable-size and has no fixed byte length") + raise SSZTypeError(f"{cls.__name__}: variable-size byte list has no fixed byte length") def serialize(self, stream: IO[bytes]) -> int: """ @@ -303,16 +305,17 @@ def deserialize(cls, stream: IO[bytes], scope: int) -> Self: knows how many bytes belong to this value in its context). Raises: - ValueError: if the decoded length exceeds `LIMIT`. - IOError: if the stream ends prematurely. + SSZDecodeError: if the scope is negative. + SSZLengthError: if the decoded length exceeds `LIMIT`. + SSZStreamError: if the stream ends prematurely. """ if scope < 0: - raise ValueError("Invalid scope for ByteList: negative") + raise SSZSerializationError(f"{cls.__name__}: negative scope") if scope > cls.LIMIT: - raise ValueError(f"ByteList[{cls.LIMIT}] scope {scope} exceeds limit") + raise SSZValueError(f"{cls.__name__} exceeds limit of {cls.LIMIT}, got {scope}") data = stream.read(scope) if len(data) != scope: - raise IOError("Stream ended prematurely while decoding ByteList") + raise SSZSerializationError(f"{cls.__name__}: expected {scope} bytes, got {len(data)}") return cls(data=data) def encode_bytes(self) -> bytes: @@ -327,7 +330,7 @@ def decode_bytes(cls, data: bytes) -> Self: For variable-size types, the data length must be `<= LIMIT`. """ if len(data) > cls.LIMIT: - raise ValueError(f"ByteList[{cls.LIMIT}] length {len(data)} exceeds limit") + raise SSZValueError(f"{cls.__name__} exceeds limit of {cls.LIMIT}, got {len(data)}") return cls(data=data) def __bytes__(self) -> bytes: diff --git a/src/lean_spec/types/collections.py b/src/lean_spec/types/collections.py index 9c5924d1..d9d43a46 100644 --- a/src/lean_spec/types/collections.py +++ b/src/lean_spec/types/collections.py @@ -22,6 +22,7 @@ from lean_spec.types.constants import OFFSET_BYTE_LENGTH from .byte_arrays import BaseBytes +from .exceptions import SSZSerializationError, SSZTypeError, SSZValueError from .ssz_base import SSZModel, SSZType from .uint import Uint32 @@ -103,7 +104,7 @@ def _serialize_data(self, value: Sequence[T]) -> list[Any]: def _validate_vector_data(cls, v: Any) -> tuple[SSZType, ...]: """Validate and convert input to a typed tuple of exactly LENGTH elements.""" if not hasattr(cls, "ELEMENT_TYPE") or not hasattr(cls, "LENGTH"): - raise TypeError(f"{cls.__name__} must define ELEMENT_TYPE and LENGTH") + raise SSZTypeError(f"{cls.__name__} must define ELEMENT_TYPE and LENGTH") if not isinstance(v, (list, tuple)): v = tuple(v) @@ -115,9 +116,8 @@ def _validate_vector_data(cls, v: Any) -> tuple[SSZType, ...]: ) if len(typed_values) != cls.LENGTH: - raise ValueError( - f"{cls.__name__} requires exactly {cls.LENGTH} items, " - f"but {len(typed_values)} were provided." + raise SSZValueError( + f"{cls.__name__} requires exactly {cls.LENGTH} elements, got {len(typed_values)}" ) return typed_values @@ -131,7 +131,7 @@ def is_fixed_size(cls) -> bool: def get_byte_length(cls) -> int: """Get the byte length if the SSZVector is fixed-size.""" if not cls.is_fixed_size(): - raise TypeError(f"{cls.__name__} is not a fixed-size type.") + raise SSZTypeError(f"{cls.__name__}: variable-size vector has no fixed byte length") return cls.ELEMENT_TYPE.get_byte_length() * cls.LENGTH def serialize(self, stream: IO[bytes]) -> int: @@ -162,9 +162,8 @@ def deserialize(cls, stream: IO[bytes], scope: int) -> Self: if cls.is_fixed_size(): elem_byte_length = cls.get_byte_length() // cls.LENGTH if scope != cls.get_byte_length(): - raise ValueError( - f"Invalid scope for {cls.__name__}: " - f"expected {cls.get_byte_length()}, got {scope}" + raise SSZSerializationError( + f"{cls.__name__}: expected {cls.get_byte_length()} bytes, got {scope}" ) elements = [ cls.ELEMENT_TYPE.deserialize(stream, elem_byte_length) for _ in range(cls.LENGTH) @@ -175,7 +174,10 @@ def deserialize(cls, stream: IO[bytes], scope: int) -> Self: # The first offset tells us where the data starts, which must be after all offsets. first_offset = int(Uint32.deserialize(stream, OFFSET_BYTE_LENGTH)) if first_offset != cls.LENGTH * OFFSET_BYTE_LENGTH: - raise ValueError("Invalid first offset in variable-size vector.") + expected = cls.LENGTH * OFFSET_BYTE_LENGTH + raise SSZSerializationError( + f"{cls.__name__}: invalid offset {first_offset}, expected {expected}" + ) # Read the remaining offsets and add the total scope as the final boundary. offsets = [first_offset] + [ int(Uint32.deserialize(stream, OFFSET_BYTE_LENGTH)) for _ in range(cls.LENGTH - 1) @@ -185,7 +187,9 @@ def deserialize(cls, stream: IO[bytes], scope: int) -> Self: for i in range(cls.LENGTH): start, end = offsets[i], offsets[i + 1] if start > end: - raise ValueError(f"Invalid offsets: start {start} > end {end}") + raise SSZSerializationError( + f"{cls.__name__}: invalid offsets start={start} > end={end}" + ) elements.append(cls.ELEMENT_TYPE.deserialize(stream, end - start)) return cls(data=elements) @@ -294,7 +298,7 @@ def _serialize_data(self, value: Sequence[T]) -> list[Any]: def _validate_list_data(cls, v: Any) -> tuple[SSZType, ...]: """Validate and convert input to a tuple of SSZType elements.""" if not hasattr(cls, "ELEMENT_TYPE") or not hasattr(cls, "LIMIT"): - raise TypeError(f"{cls.__name__} must define ELEMENT_TYPE and LIMIT") + raise SSZTypeError(f"{cls.__name__} must define ELEMENT_TYPE and LIMIT") # Handle various input types if isinstance(v, (list, tuple)): @@ -302,25 +306,23 @@ def _validate_list_data(cls, v: Any) -> tuple[SSZType, ...]: elif hasattr(v, "__iter__") and not isinstance(v, (str, bytes)): elements = list(v) else: - raise TypeError(f"List data must be iterable, got {type(v)}") + raise SSZTypeError(f"Expected iterable, got {type(v).__name__}") # Check limit if len(elements) > cls.LIMIT: - raise ValueError( - f"{cls.__name__} cannot contain more than {cls.LIMIT} elements, got {len(elements)}" - ) + raise SSZValueError(f"{cls.__name__} exceeds limit of {cls.LIMIT}, got {len(elements)}") # Convert and validate each element typed_values = [] - for i, element in enumerate(elements): + for element in elements: if isinstance(element, cls.ELEMENT_TYPE): typed_values.append(element) else: try: typed_values.append(cast(Any, cls.ELEMENT_TYPE)(element)) except Exception as e: - raise ValueError( - f"Element {i} cannot be converted to {cls.ELEMENT_TYPE.__name__}: {e}" + raise SSZTypeError( + f"Expected {cls.ELEMENT_TYPE.__name__}, got {type(element).__name__}" ) from e return tuple(typed_values) @@ -342,8 +344,8 @@ def is_fixed_size(cls) -> bool: @classmethod def get_byte_length(cls) -> int: - """Lists are variable-size, so this raises a TypeError.""" - raise TypeError(f"{cls.__name__} is variable-size") + """Lists are variable-size, so this raises an SSZTypeError.""" + raise SSZTypeError(f"{cls.__name__}: variable-size list has no fixed byte length") def serialize(self, stream: IO[bytes]) -> int: """Serialize the list to a binary stream.""" @@ -372,11 +374,15 @@ def deserialize(cls, stream: IO[bytes], scope: int) -> Self: # Fixed-size elements: read them back-to-back element_size = cls.ELEMENT_TYPE.get_byte_length() if scope % element_size != 0: - raise ValueError(f"Scope {scope} is not divisible by element size {element_size}") + raise SSZSerializationError( + f"{cls.__name__}: scope {scope} not divisible by element size {element_size}" + ) num_elements = scope // element_size if num_elements > cls.LIMIT: - raise ValueError(f"Too many elements: {num_elements} > {cls.LIMIT}") + raise SSZValueError( + f"{cls.__name__} exceeds limit of {cls.LIMIT}, got {num_elements}" + ) elements = [ cls.ELEMENT_TYPE.deserialize(stream, element_size) for _ in range(num_elements) @@ -389,16 +395,18 @@ def deserialize(cls, stream: IO[bytes], scope: int) -> Self: # Empty list case return cls(data=[]) if scope < OFFSET_BYTE_LENGTH: - raise ValueError(f"Invalid scope for variable-size list: {scope}") + raise SSZSerializationError( + f"{cls.__name__}: scope {scope} too small for variable-size list" + ) # Read the first offset to determine the number of elements. first_offset = int(Uint32.deserialize(stream, OFFSET_BYTE_LENGTH)) if first_offset > scope or first_offset % OFFSET_BYTE_LENGTH != 0: - raise ValueError("Invalid first offset in list.") + raise SSZSerializationError(f"{cls.__name__}: invalid offset {first_offset}") count = first_offset // OFFSET_BYTE_LENGTH if count > cls.LIMIT: - raise ValueError(f"Decoded list length {count} exceeds limit of {cls.LIMIT}") + raise SSZValueError(f"{cls.__name__} exceeds limit of {cls.LIMIT}, got {count}") # Read the rest of the offsets. offsets = [first_offset] + [ @@ -411,7 +419,9 @@ def deserialize(cls, stream: IO[bytes], scope: int) -> Self: for i in range(count): start, end = offsets[i], offsets[i + 1] if start > end: - raise ValueError(f"Invalid offsets: start {start} > end {end}") + raise SSZSerializationError( + f"{cls.__name__}: invalid offsets start={start} > end={end}" + ) elements.append(cls.ELEMENT_TYPE.deserialize(stream, end - start)) return cls(data=elements) diff --git a/src/lean_spec/types/container.py b/src/lean_spec/types/container.py index 4caa88be..c09d594a 100644 --- a/src/lean_spec/types/container.py +++ b/src/lean_spec/types/container.py @@ -16,6 +16,7 @@ from typing_extensions import Self from .constants import OFFSET_BYTE_LENGTH +from .exceptions import SSZSerializationError, SSZTypeError from .ssz_base import SSZModel, SSZType from .uint import Uint32 @@ -31,14 +32,11 @@ def _get_ssz_field_type(annotation: Any) -> Type[SSZType]: The SSZType class. Raises: - TypeError: If the annotation is not a valid SSZType class. + SSZTypeCoercionError: If the annotation is not a valid SSZType class. """ # Check if it's a class and is a subclass of SSZType if not (inspect.isclass(annotation) and issubclass(annotation, SSZType)): - raise TypeError( - f"Field annotation {annotation} is not a valid SSZType class. " - f"Container fields must be concrete SSZType subclasses." - ) + raise SSZTypeError(f"Expected SSZType subclass, got {annotation}") return annotation @@ -98,11 +96,11 @@ def get_byte_length(cls) -> int: Total byte length of all fields summed together. Raises: - TypeError: If called on a variable-size container. + SSZTypeDefinitionError: If called on a variable-size container. """ # Only fixed-size containers have a deterministic byte length if not cls.is_fixed_size(): - raise TypeError(f"{cls.__name__} is variable-size") + raise SSZTypeError(f"{cls.__name__}: variable-size container has no fixed byte length") # Sum the byte lengths of all fixed-size fields return sum( @@ -185,8 +183,8 @@ def deserialize(cls, stream: IO[bytes], scope: int) -> Self: New container instance with deserialized values. Raises: - IOError: If stream ends unexpectedly. - ValueError: If offsets are invalid. + SSZStreamError: If stream ends unexpectedly. + SSZOffsetError: If offsets are invalid. """ fields = {} # Collected field values var_fields = [] # (name, type, offset) for variable fields @@ -201,14 +199,19 @@ def deserialize(cls, stream: IO[bytes], scope: int) -> Self: size = field_type.get_byte_length() data = stream.read(size) if len(data) != size: - raise IOError(f"Unexpected EOF reading {field_name}") + raise SSZSerializationError( + f"{cls.__name__}.{field_name}: expected {size} bytes, got {len(data)}" + ) fields[field_name] = field_type.decode_bytes(data) bytes_read += size else: # Read offset pointer for variable field offset_bytes = stream.read(OFFSET_BYTE_LENGTH) if len(offset_bytes) != OFFSET_BYTE_LENGTH: - raise IOError(f"Unexpected EOF reading offset for {field_name}") + raise SSZSerializationError( + f"{cls.__name__}.{field_name}: " + f"expected {OFFSET_BYTE_LENGTH} offset bytes, got {len(offset_bytes)}" + ) offset = int(Uint32.decode_bytes(offset_bytes)) var_fields.append((field_name, field_type, offset)) bytes_read += OFFSET_BYTE_LENGTH @@ -219,7 +222,10 @@ def deserialize(cls, stream: IO[bytes], scope: int) -> Self: var_section_size = scope - bytes_read var_section = stream.read(var_section_size) if len(var_section) != var_section_size: - raise IOError("Unexpected EOF in variable section") + raise SSZSerializationError( + f"{cls.__name__}: " + f"expected {var_section_size} variable bytes, got {len(var_section)}" + ) # Extract each variable field using offsets offsets = [offset for _, _, offset in var_fields] + [scope] @@ -231,7 +237,9 @@ def deserialize(cls, stream: IO[bytes], scope: int) -> Self: # Validate offset bounds if rel_start < 0 or rel_start > rel_end: - raise ValueError(f"Invalid offsets for {name}") + raise SSZSerializationError( + f"{cls.__name__}.{name}: invalid offsets start={start}, end={end}" + ) # Deserialize field from its slice field_data = var_section[rel_start:rel_end] diff --git a/src/lean_spec/types/exceptions.py b/src/lean_spec/types/exceptions.py new file mode 100644 index 00000000..c1fbe607 --- /dev/null +++ b/src/lean_spec/types/exceptions.py @@ -0,0 +1,17 @@ +"""Exception hierarchy for the SSZ type system.""" + + +class SSZError(Exception): + """Base exception for all SSZ-related errors.""" + + +class SSZTypeError(SSZError): + """Raised for type-related errors (coercion, definition, invalid types).""" + + +class SSZValueError(SSZError): + """Raised for value-related errors (overflow, length, bounds).""" + + +class SSZSerializationError(SSZError): + """Raised for serialization errors (encoding, decoding, stream issues).""" diff --git a/src/lean_spec/types/uint.py b/src/lean_spec/types/uint.py index b3274966..1acef0f1 100644 --- a/src/lean_spec/types/uint.py +++ b/src/lean_spec/types/uint.py @@ -8,6 +8,7 @@ from pydantic_core import core_schema from typing_extensions import Self +from .exceptions import SSZSerializationError, SSZTypeError, SSZValueError from .ssz_base import SSZType @@ -22,16 +23,17 @@ def __new__(cls, value: SupportsInt) -> Self: Create and validate a new Uint instance. Raises: - TypeError: If `value` is not an int (rejects bool, string, float). - OverflowError: If `value` is outside the allowed range [0, 2**BITS - 1]. + SSZTypeCoercionError: If `value` is not an int (rejects bool, string, float). + SSZOverflowError: If `value` is outside the allowed range [0, 2**BITS - 1]. """ # We should accept only ints. if not isinstance(value, int) or isinstance(value, bool): - raise TypeError(f"Expected int, got {type(value).__name__}") + raise SSZTypeError(f"Expected int, got {type(value).__name__}") int_value = int(value) - if not (0 <= int_value < (2**cls.BITS)): - raise OverflowError(f"{int_value} is out of range for {cls.__name__}") + max_value = 2**cls.BITS - 1 + if not (0 <= int_value <= max_value): + raise SSZValueError(f"{int_value} out of range for {cls.__name__} [0, {max_value}]") return super().__new__(cls, int_value) @classmethod @@ -104,7 +106,7 @@ def decode_bytes(cls, data: bytes) -> Self: data (bytes): The SSZ byte string to deserialize. Raises: - ValueError: If the byte string has an incorrect length. + SSZDecodeError: If the byte string has an incorrect length. Returns: Self: A new instance of the Uint class. @@ -112,9 +114,8 @@ def decode_bytes(cls, data: bytes) -> Self: # Ensure the input data has the correct number of bytes. expected_length = cls.get_byte_length() if len(data) != expected_length: - raise ValueError( - f"Invalid byte length for {cls.__name__}: " - f"expected {expected_length}, got {len(data)}" + raise SSZSerializationError( + f"{cls.__name__}: expected {expected_length} bytes, got {len(data)}" ) # The `from_bytes` class method from `int` is used to convert the data. return cls(int.from_bytes(data, "little")) @@ -146,7 +147,8 @@ def deserialize(cls, stream: IO[bytes], scope: int) -> Self: scope (int): The number of bytes available to read for this object. Raises: - ValueError: If the scope does not match the type's byte length. + SSZDecodeError: If the scope does not match the type's byte length. + SSZStreamError: If the stream ends prematurely. Returns: Self: A new instance of the Uint class. @@ -154,14 +156,16 @@ def deserialize(cls, stream: IO[bytes], scope: int) -> Self: # For a fixed-size type, the scope must exactly match the byte length. byte_length = cls.get_byte_length() if scope != byte_length: - raise ValueError( - f"Invalid scope for {cls.__name__}: expected {byte_length}, got {scope}" + raise SSZSerializationError( + f"{cls.__name__}: invalid scope, expected {byte_length} bytes, got {scope}" ) # Read the required number of bytes from the stream. data = stream.read(byte_length) # Ensure the correct number of bytes was read. if len(data) != byte_length: - raise IOError(f"Stream ended prematurely while decoding {cls.__name__}") + raise SSZSerializationError( + f"{cls.__name__}: expected {byte_length} bytes, got {len(data)}" + ) # Decode the bytes into a new instance. return cls.decode_bytes(data) diff --git a/src/lean_spec/types/union.py b/src/lean_spec/types/union.py index 13571b49..c8f46824 100644 --- a/src/lean_spec/types/union.py +++ b/src/lean_spec/types/union.py @@ -16,6 +16,7 @@ from pydantic import model_validator from typing_extensions import Self +from .exceptions import SSZSerializationError, SSZTypeError, SSZValueError from .ssz_base import SSZModel, SSZType # Constants for Union implementation @@ -103,26 +104,26 @@ def _validate_union_data(cls, data: Any) -> dict[str, Any]: """Validate selector and value together.""" # Check required class attributes and get options if not hasattr(cls, "OPTIONS") or not isinstance(cls.OPTIONS, tuple): - raise TypeError(f"{cls.__name__} must define OPTIONS as a tuple of SSZ types") + raise SSZTypeError(f"{cls.__name__} must define OPTIONS as a tuple of SSZ types") options, options_count = cls.OPTIONS, len(cls.OPTIONS) # Validate OPTIONS constraints if options_count == 0: - raise TypeError(f"{cls.__name__} OPTIONS cannot be empty") + raise SSZTypeError(f"{cls.__name__}: OPTIONS cannot be empty") if options_count > MAX_UNION_OPTIONS: - raise TypeError( - f"{cls.__name__} has {options_count} options, but maximum is {MAX_UNION_OPTIONS}" + raise SSZTypeError( + f"{cls.__name__}: has {options_count} options, max is {MAX_UNION_OPTIONS}" ) if options[0] is None and options_count == 1: - raise TypeError(f"{cls.__name__} cannot have None as the only option") + raise SSZTypeError(f"{cls.__name__}: cannot have None as the only option") # Validate None placement (only at index 0) and types for i, opt in enumerate(options): if opt is None and i != 0: - raise TypeError(f"{cls.__name__} can only have None at index 0, found at index {i}") + raise SSZTypeError(f"{cls.__name__}: None only allowed at index 0, found at {i}") elif opt is not None and not isinstance(opt, type): - raise TypeError(f"{cls.__name__} option {i} must be a type, got {type(opt)}") + raise SSZTypeError(f"{cls.__name__}: option {i} must be a type, got {type(opt)}") # Extract selector and value from input selector = data.get("selector") @@ -130,12 +131,13 @@ def _validate_union_data(cls, data: Any) -> dict[str, Any]: # Validate selector if not isinstance(selector, int) or not 0 <= selector < options_count: - raise ValueError(f"Invalid selector {selector} for {options_count} options") + sel = selector if isinstance(selector, int) else -1 + raise SSZValueError(f"{cls.__name__}: selector {sel} out of range [0, {options_count})") # Handle None option if (selected_type := options[selector]) is None: if value is not None: - raise TypeError("Selected option is None, therefore value must be None") + raise SSZTypeError(f"Expected None, got {type(value).__name__}") return {"selector": selector, "value": None} # Handle non-None option - coerce value if needed @@ -146,8 +148,8 @@ def _validate_union_data(cls, data: Any) -> dict[str, Any]: coerced_value = cast(Any, selected_type)(value) return {"selector": selector, "value": coerced_value} except Exception as e: - raise TypeError( - f"Cannot coerce {type(value).__name__} to {selected_type.__name__}: {e}" + raise SSZTypeError( + f"Expected {selected_type.__name__}, got {type(value).__name__}" ) from e @property @@ -168,7 +170,7 @@ def is_fixed_size(cls) -> bool: @classmethod def get_byte_length(cls) -> int: """Union types are variable-size and don't have fixed length.""" - raise TypeError(f"{cls.__name__} is variable-size") + raise SSZTypeError(f"{cls.__name__}: variable-size union has no fixed byte length") def serialize(self, stream: IO[bytes]) -> int: """Serialize this Union to a byte stream in SSZ format.""" @@ -183,40 +185,49 @@ def deserialize(cls, stream: IO[bytes], scope: int) -> Self: """Deserialize a Union from a byte stream using SSZ format.""" # Validate scope for selector byte if scope < SELECTOR_BYTE_SIZE: - raise ValueError("Scope too small for Union selector") + raise SSZSerializationError(f"{cls.__name__}: scope too small for selector") # Read selector byte selector_bytes = stream.read(SELECTOR_BYTE_SIZE) if len(selector_bytes) != SELECTOR_BYTE_SIZE: - raise IOError("Stream ended reading Union selector") + raise SSZSerializationError( + f"{cls.__name__}: " + f"expected {SELECTOR_BYTE_SIZE} selector bytes, got {len(selector_bytes)}" + ) selector = int.from_bytes(selector_bytes, byteorder="little") remaining_bytes = scope - SELECTOR_BYTE_SIZE # Validate selector range if not 0 <= selector < len(cls.OPTIONS): - raise ValueError(f"Selector {selector} out of range for {len(cls.OPTIONS)} options") + raise SSZValueError( + f"{cls.__name__}: selector {selector} out of range [0, {len(cls.OPTIONS)})" + ) selected_type = cls.OPTIONS[selector] # Handle None option if selected_type is None: if remaining_bytes != 0: - raise ValueError("Invalid encoding: None arm must have no payload bytes") + raise SSZSerializationError(f"{cls.__name__}: None arm must have no payload bytes") return cls(selector=selector, value=None) # Handle non-None option if selected_type.is_fixed_size() and hasattr(selected_type, "get_byte_length"): required_bytes = selected_type.get_byte_length() if remaining_bytes < required_bytes: - raise IOError(f"Need {required_bytes} bytes, got {remaining_bytes}") + raise SSZSerializationError( + f"{cls.__name__}: expected {required_bytes} bytes, got {remaining_bytes}" + ) # Deserialize value try: value = selected_type.deserialize(stream, remaining_bytes) return cls(selector=selector, value=value) except Exception as e: - raise IOError(f"Failed to deserialize {selected_type.__name__}: {e}") from e + raise SSZSerializationError( + f"{cls.__name__}: failed to deserialize {selected_type.__name__}: {e}" + ) from e def encode_bytes(self) -> bytes: """Encode this Union to bytes.""" diff --git a/tests/lean_spec/types/test_bitfields.py b/tests/lean_spec/types/test_bitfields.py index 554c6255..e30f0f37 100644 --- a/tests/lean_spec/types/test_bitfields.py +++ b/tests/lean_spec/types/test_bitfields.py @@ -9,6 +9,10 @@ from lean_spec.types.bitfields import BaseBitlist, BaseBitvector from lean_spec.types.boolean import Boolean +from lean_spec.types.exceptions import SSZSerializationError, SSZTypeError, SSZValueError + +# Type alias for errors that can be SSZValueError or wrapped in ValidationError +ValueOrValidationError = (SSZValueError, ValidationError) # Define bitfield types at module level for reuse and model classes @@ -54,7 +58,7 @@ class Bitvector16(BaseBitvector): def test_instantiate_raw_type_raises_error(self) -> None: """Tests that the raw, non-specialized BaseBitvector cannot be instantiated.""" - with pytest.raises(TypeError, match="must define LENGTH"): + with pytest.raises(SSZTypeError, match="must define LENGTH"): BaseBitvector(data=[]) def test_instantiation_success(self) -> None: @@ -74,7 +78,7 @@ def test_instantiation_success(self) -> None: ) def test_instantiation_with_wrong_length_raises_error(self, values: list[Boolean]) -> None: """Tests that providing the wrong number of items during instantiation fails.""" - with pytest.raises(ValueError, match="requires exactly 4 bits"): + with pytest.raises(ValueOrValidationError): Bitvector4(data=values) def test_pydantic_validation_accepts_valid_list(self) -> None: @@ -93,7 +97,7 @@ def test_pydantic_validation_accepts_valid_list(self) -> None: ) def test_pydantic_validation_rejects_invalid_values(self, invalid_value: Any) -> None: """Tests that Pydantic validation rejects lists of the wrong length.""" - with pytest.raises(ValidationError): + with pytest.raises(ValueOrValidationError): Bitvector4Model(value=invalid_value) def test_bitvector_is_immutable(self) -> None: @@ -125,7 +129,7 @@ class Bitlist16(BaseBitlist): def test_instantiate_raw_type_raises_error(self) -> None: """Tests that the raw, non-specialized BaseBitlist cannot be instantiated.""" - with pytest.raises(TypeError, match="must define LIMIT"): + with pytest.raises(SSZTypeError, match="must define LIMIT"): BaseBitlist(data=[]) def test_instantiation_success(self) -> None: @@ -141,7 +145,7 @@ def test_instantiation_over_limit_raises_error(self) -> None: class Bitlist4(BaseBitlist): LIMIT = 4 - with pytest.raises(ValueError, match="cannot exceed 4 bits"): + with pytest.raises(ValueOrValidationError): Bitlist4(data=[Boolean(b) for b in [True, False, True, False, True]]) def test_pydantic_validation_accepts_valid_list(self) -> None: @@ -159,7 +163,7 @@ def test_pydantic_validation_accepts_valid_list(self) -> None: ) def test_pydantic_validation_rejects_invalid_values(self, invalid_value: Any) -> None: """Tests that Pydantic validation rejects lists that exceed the limit.""" - with pytest.raises(ValidationError): + with pytest.raises(ValueOrValidationError): Bitlist8Model(value=invalid_value) def test_add_with_list(self) -> None: @@ -197,7 +201,7 @@ class Bitlist4(BaseBitlist): LIMIT = 4 bitlist = Bitlist4(data=[Boolean(True), Boolean(False), Boolean(True)]) - with pytest.raises(ValueError, match="cannot exceed 4 bits"): + with pytest.raises(ValueOrValidationError): bitlist + [Boolean(False), Boolean(True)] @@ -267,7 +271,7 @@ def test_bitvector_decode_invalid_length(self) -> None: class Bitvector8(BaseBitvector): LENGTH = 8 - with pytest.raises(ValueError, match="expected 1 bytes, got 2"): + with pytest.raises(SSZValueError, match="expected 1 bytes, got 2"): Bitvector8.decode_bytes(b"\x01\x02") # Expects 1 byte, gets 2 def test_bitlist_decode_invalid_data(self) -> None: @@ -276,7 +280,7 @@ def test_bitlist_decode_invalid_data(self) -> None: class Bitlist8(BaseBitlist): LIMIT = 8 - with pytest.raises(ValueError, match="Cannot decode empty bytes"): + with pytest.raises(SSZSerializationError, match="cannot decode empty bytes"): Bitlist8.decode_bytes(b"") @@ -295,7 +299,7 @@ class Bitlist10(BaseBitlist): LIMIT = 10 assert Bitlist10.is_fixed_size() is False - with pytest.raises(TypeError): + with pytest.raises(SSZTypeError): Bitlist10.get_byte_length() def test_bitvector_deserialize_invalid_scope(self) -> None: @@ -303,7 +307,7 @@ class Bitvector8(BaseBitvector): LENGTH = 8 stream = io.BytesIO(b"\xff") - with pytest.raises(ValueError, match="expected 1 bytes, got 2"): + with pytest.raises(SSZSerializationError, match="expected 1 bytes, got 2"): Bitvector8.deserialize(stream, scope=2) def test_bitvector_deserialize_premature_end(self) -> None: @@ -311,7 +315,7 @@ class Bitvector16(BaseBitvector): LENGTH = 16 stream = io.BytesIO(b"\xff") # Only 1 byte, expects 2 - with pytest.raises(IOError, match="Expected 2 bytes, got 1"): + with pytest.raises(SSZSerializationError, match="expected 2 bytes, got 1"): Bitvector16.deserialize(stream, scope=2) def test_bitlist_deserialize_premature_end(self) -> None: @@ -319,7 +323,7 @@ class Bitlist16(BaseBitlist): LIMIT = 16 stream = io.BytesIO(b"\xff") # Only 1 byte - with pytest.raises(IOError, match="Expected 2 bytes, got 1"): + with pytest.raises(SSZSerializationError, match="expected 2 bytes, got 1"): Bitlist16.deserialize(stream, scope=2) # Scope says to read 2 @pytest.mark.parametrize( diff --git a/tests/lean_spec/types/test_boolean.py b/tests/lean_spec/types/test_boolean.py index 7e6168ba..e9c17a4c 100644 --- a/tests/lean_spec/types/test_boolean.py +++ b/tests/lean_spec/types/test_boolean.py @@ -7,6 +7,7 @@ from pydantic import BaseModel, ValidationError from lean_spec.types.boolean import Boolean +from lean_spec.types.exceptions import SSZSerializationError, SSZTypeError, SSZValueError class BooleanModel(BaseModel): @@ -39,15 +40,15 @@ def test_instantiation_from_valid_types(valid_value: bool | int) -> None: @pytest.mark.parametrize("invalid_int", [-1, 2, 100]) def test_instantiation_from_invalid_int_raises_error(invalid_int: int) -> None: - """Tests that instantiating with an int other than 0 or 1 raises ValueError.""" - with pytest.raises(ValueError, match="Boolean value must be 0 or 1"): + """Tests that instantiating with an int other than 0 or 1 raises SSZValueError.""" + with pytest.raises(SSZValueError, match="Boolean value must be 0 or 1"): Boolean(invalid_int) @pytest.mark.parametrize("invalid_type", [1.0, "True", b"\x01", None]) def test_instantiation_from_invalid_types_raises_error(invalid_type: Any) -> None: - """Tests that instantiating with non-bool/non-int types raises a TypeError.""" - with pytest.raises(TypeError, match="Expected bool or int"): + """Tests that instantiating with non-bool/non-int types raises SSZTypeError.""" + with pytest.raises(SSZTypeError, match="Expected bool or int"): Boolean(invalid_type) @@ -210,16 +211,16 @@ def test_encode_decode_roundtrip(self, value: bool, expected_bytes: bytes) -> No def test_decode_invalid_length(self) -> None: """Tests that decode_bytes fails with incorrect byte length.""" - with pytest.raises(ValueError, match="Expected 1 byte"): + with pytest.raises(SSZSerializationError, match="expected 1 byte"): Boolean.decode_bytes(b"") - with pytest.raises(ValueError, match="Expected 1 byte"): + with pytest.raises(SSZSerializationError, match="expected 1 byte"): Boolean.decode_bytes(b"\x00\x01") def test_decode_invalid_value(self) -> None: """Tests that decode_bytes fails with an invalid byte value.""" - with pytest.raises(ValueError, match="must be 0x00 or 0x01"): + with pytest.raises(SSZSerializationError, match="must be 0x00 or 0x01"): Boolean.decode_bytes(b"\x02") - with pytest.raises(ValueError, match="must be 0x00 or 0x01"): + with pytest.raises(SSZSerializationError, match="must be 0x00 or 0x01"): Boolean.decode_bytes(b"\xff") @pytest.mark.parametrize("value", [True, False]) @@ -241,15 +242,15 @@ def test_serialize_deserialize_roundtrip(self, value: bool) -> None: def test_deserialize_invalid_scope(self) -> None: """Tests that deserialize fails with an incorrect scope.""" stream = io.BytesIO(b"\x01") - with pytest.raises(ValueError, match="Invalid scope for Boolean"): + with pytest.raises(SSZSerializationError, match="expected scope of 1"): Boolean.deserialize(stream, scope=0) stream.seek(0) - with pytest.raises(ValueError, match="Invalid scope for Boolean"): + with pytest.raises(SSZSerializationError, match="expected scope of 1"): Boolean.deserialize(stream, scope=2) def test_deserialize_premature_stream_end(self) -> None: """Tests that deserialize fails if the stream ends prematurely.""" stream = io.BytesIO(b"") # Empty stream - with pytest.raises(IOError, match="Stream ended prematurely"): + with pytest.raises(SSZSerializationError, match="expected 1 byte, got 0"): Boolean.deserialize(stream, scope=1) diff --git a/tests/lean_spec/types/test_byte_arrays.py b/tests/lean_spec/types/test_byte_arrays.py index 8888f4e4..5f9cb6cd 100644 --- a/tests/lean_spec/types/test_byte_arrays.py +++ b/tests/lean_spec/types/test_byte_arrays.py @@ -19,6 +19,7 @@ BaseBytes, BaseByteList, ) +from lean_spec.types.exceptions import SSZSerializationError, SSZTypeError, SSZValueError def sha256(b: bytes) -> bytes: @@ -61,11 +62,11 @@ def test_bytevector_coercion(value: Any, expected: bytes) -> None: def test_bytevector_wrong_length_raises() -> None: - with pytest.raises(ValueError): + with pytest.raises(SSZValueError): Bytes4(b"\x00\x01\x02") # 3 != 4 - with pytest.raises(ValueError): + with pytest.raises(SSZValueError): Bytes4([0, 1, 2]) # 3 != 4 - with pytest.raises(ValueError): + with pytest.raises(SSZValueError): Bytes4("000102") # 3 != 4 (hex nibbles -> 3 bytes) @@ -89,7 +90,7 @@ class ByteList5(BaseByteList): def test_bytelist_over_limit_raises() -> None: # Test with ByteList64 that has limit 64 - with pytest.raises(ValueError): + with pytest.raises(SSZValueError): ByteList64(data=b"\x00" * 65) # Over the limit @@ -170,7 +171,7 @@ def test_encode_decode_roundtrip_vector(Typ: Type[BaseBytes], payload: bytes) -> def test_vector_deserialize_scope_mismatch_raises() -> None: v = Bytes4(b"\x00\x01\x02\x03") buf = io.BytesIO(v.encode_bytes()) - with pytest.raises(ValueError): + with pytest.raises(SSZSerializationError, match="expected 4 bytes, got 3"): Bytes4.deserialize(buf, 3) # wrong scope @@ -204,7 +205,7 @@ class TestByteList2(BaseByteList): LIMIT = 2 buf = io.BytesIO(b"\x00\x01\x02") - with pytest.raises(ValueError): + with pytest.raises(SSZValueError): TestByteList2.deserialize(buf, 3) @@ -213,7 +214,7 @@ class TestByteList10(BaseByteList): LIMIT = 10 buf = io.BytesIO(b"\x00\x01") - with pytest.raises(IOError): + with pytest.raises(SSZSerializationError): TestByteList10.deserialize(buf, 3) # stream too short @@ -249,11 +250,11 @@ def test_pydantic_accepts_various_inputs_for_vectors() -> None: def test_pydantic_validates_vector_lengths() -> None: - with pytest.raises(ValueError): + with pytest.raises(SSZValueError): ModelVectors(root=Bytes32(b"\x11" * 31), key=Bytes4(b"\x00\x01\x02\x03")) # too short - with pytest.raises(ValueError): + with pytest.raises(SSZValueError): ModelVectors(root=Bytes32(b"\x11" * 33), key=Bytes4(b"\x00\x01\x02\x03")) # too long - with pytest.raises(ValueError): + with pytest.raises(SSZValueError): ModelVectors(root=Bytes32(b"\x11" * 32), key=Bytes4(b"\x00\x01\x02")) # key too short @@ -277,7 +278,7 @@ def test_pydantic_accepts_and_serializes_bytelist() -> None: def test_pydantic_bytelist_limit_enforced() -> None: - with pytest.raises(ValueError): + with pytest.raises(SSZValueError): ModelLists(payload=ByteList16(data=bytes(range(17)))) # over limit diff --git a/tests/lean_spec/types/test_collections.py b/tests/lean_spec/types/test_collections.py index 4d98c03b..6529fb9c 100644 --- a/tests/lean_spec/types/test_collections.py +++ b/tests/lean_spec/types/test_collections.py @@ -10,8 +10,12 @@ from lean_spec.types.boolean import Boolean from lean_spec.types.collections import SSZList, SSZVector from lean_spec.types.container import Container +from lean_spec.types.exceptions import SSZSerializationError, SSZTypeError, SSZValueError from lean_spec.types.uint import Uint8, Uint16, Uint32, Uint256 +# Type alias for errors that can be SSZValueError or wrapped in ValidationError +ValueOrValidationError = (SSZValueError, ValidationError) + # Define some List types that are needed for Container definitions class Uint16List4(SSZList): @@ -234,9 +238,9 @@ def test_instantiation_success(self) -> None: def test_instantiation_with_wrong_length_raises_error(self) -> None: """Tests that providing the wrong number of items during instantiation fails.""" vec_type = Uint8Vector4 - with pytest.raises(ValueError, match="requires exactly 4 items"): + with pytest.raises(ValueOrValidationError): vec_type(data=[Uint8(1), Uint8(2), Uint8(3)]) # Too few - with pytest.raises(ValueError, match="requires exactly 4 items"): + with pytest.raises(ValueOrValidationError): vec_type(data=[Uint8(1), Uint8(2), Uint8(3), Uint8(4), Uint8(5)]) # Too many def test_pydantic_validation(self) -> None: @@ -246,11 +250,11 @@ def test_pydantic_validation(self) -> None: assert isinstance(instance.value, Uint8Vector2) assert list(instance.value) == [Uint8(10), Uint8(20)] # Test invalid data - with pytest.raises(ValidationError): + with pytest.raises(ValueOrValidationError): Uint8Vector2Model(value={"data": [10]}) # type: ignore[arg-type] - with pytest.raises(ValidationError): + with pytest.raises(ValueOrValidationError): Uint8Vector2Model(value={"data": [10, 20, 30]}) # type: ignore[arg-type] - with pytest.raises(TypeError): + with pytest.raises(SSZTypeError): Uint8Vector2Model(value={"data": [10, "bad"]}) # type: ignore[arg-type] def test_vector_is_immutable(self) -> None: @@ -281,13 +285,13 @@ def test_class_getitem_creates_specialized_type(self) -> None: def test_instantiate_raw_type_raises_error(self) -> None: """Tests that the raw, non-specialized SSZList cannot be instantiated.""" - with pytest.raises(TypeError, match="must define ELEMENT_TYPE and LIMIT"): + with pytest.raises(SSZTypeError, match="must define ELEMENT_TYPE and LIMIT"): SSZList(data=[]) def test_instantiation_over_limit_raises_error(self) -> None: """Tests that providing more items than the limit during instantiation fails.""" list_type = Uint8List4 - with pytest.raises(ValueError, match="cannot contain more than 4 elements"): + with pytest.raises(ValueOrValidationError): list_type(data=[Uint8(1), Uint8(2), Uint8(3), Uint8(4), Uint8(5)]) def test_pydantic_validation(self) -> None: @@ -297,19 +301,19 @@ def test_pydantic_validation(self) -> None: assert isinstance(instance.value, Uint8List4) assert list(instance.value) == [Uint8(10), Uint8(20)] # Test invalid data - list too long - with pytest.raises(ValidationError): + with pytest.raises(ValueOrValidationError): Uint8List4Model( value=Uint8List4(data=[Uint8(10), Uint8(20), Uint8(30), Uint8(40), Uint8(50)]) ) def test_append_at_limit_raises_error(self) -> None: """Tests that creating a list at limit +1 fails during construction.""" - with pytest.raises(ValueError, match="cannot contain more than 4 elements"): + with pytest.raises(ValueOrValidationError): BooleanList4(data=[Boolean(True)] * 5) def test_extend_over_limit_raises_error(self) -> None: """Tests that creating a list over the limit fails during construction.""" - with pytest.raises(ValueError, match="cannot contain more than 4 elements"): + with pytest.raises(ValueOrValidationError): BooleanList4( data=[Boolean(True), Boolean(False), Boolean(True), Boolean(False), Boolean(True)] ) @@ -332,7 +336,7 @@ def test_add_with_sszlist(self) -> None: def test_add_exceeding_limit_raises_error(self) -> None: """Tests that concatenating beyond the limit raises an error.""" list1 = Uint8List4(data=[Uint8(1), Uint8(2), Uint8(3)]) - with pytest.raises(ValueError, match="cannot contain more than 4 elements"): + with pytest.raises(ValueOrValidationError): list1 + [4, 5] diff --git a/tests/lean_spec/types/test_uint.py b/tests/lean_spec/types/test_uint.py index e973154b..5f8f37ae 100644 --- a/tests/lean_spec/types/test_uint.py +++ b/tests/lean_spec/types/test_uint.py @@ -6,6 +6,7 @@ import pytest from pydantic import BaseModel, ValidationError +from lean_spec.types.exceptions import SSZSerializationError, SSZTypeError, SSZValueError from lean_spec.types.uint import ( BaseUint, Uint8, @@ -92,9 +93,9 @@ def test_pydantic_strict_mode_rejects_invalid_types( def test_instantiation_from_invalid_types_raises_error( uint_class: Type[BaseUint], invalid_value: Any, expected_type_name: str ) -> None: - """Tests that instantiating with non-integer types raises a TypeError.""" + """Tests that instantiating with non-integer types raises SSZTypeError.""" expected_msg = f"Expected int, got {expected_type_name}" - with pytest.raises(TypeError, match=expected_msg): + with pytest.raises(SSZTypeError, match=expected_msg): uint_class(invalid_value) @@ -109,16 +110,16 @@ def test_instantiation_and_type(uint_class: Type[BaseUint]) -> None: @pytest.mark.parametrize("uint_class", ALL_UINT_TYPES) def test_instantiation_negative(uint_class: Type[BaseUint]) -> None: - """Tests that instantiating with a negative number raises OverflowError.""" - with pytest.raises(OverflowError): + """Tests that instantiating with a negative number raises SSZValueError.""" + with pytest.raises(SSZValueError): uint_class(-5) @pytest.mark.parametrize("uint_class", ALL_UINT_TYPES) def test_instantiation_too_large(uint_class: Type[BaseUint]) -> None: - """Tests that instantiating with a value >= MAX raises OverflowError.""" + """Tests that instantiating with a value >= MAX raises SSZValueError.""" max_value = 2**uint_class.BITS - with pytest.raises(OverflowError): + with pytest.raises(SSZValueError): uint_class(max_value) @@ -140,17 +141,17 @@ def test_arithmetic_operators(uint_class: Type[BaseUint]) -> None: # Addition assert a + b == uint_class(a_val + b_val) - with pytest.raises(OverflowError): + with pytest.raises(SSZValueError): _ = max_val + b # Subtraction assert a - b == uint_class(a_val - b_val) - with pytest.raises(OverflowError): + with pytest.raises(SSZValueError): _ = b - a # Multiplication assert a * b == uint_class(a_val * b_val) - with pytest.raises(OverflowError): + with pytest.raises(SSZValueError): _ = max_val * b # Floor Division @@ -162,7 +163,7 @@ def test_arithmetic_operators(uint_class: Type[BaseUint]) -> None: # Exponentiation assert uint_class(b_val) ** 4 == uint_class(b_val**4) if uint_class.BITS <= 16: # Pow gets too big quickly - with pytest.raises(OverflowError): + with pytest.raises(SSZValueError): _ = a ** int(b) @@ -399,10 +400,10 @@ def test_encode_decode_roundtrip( @pytest.mark.parametrize("uint_class", ALL_UINT_TYPES) def test_decode_bytes_invalid_length(self, uint_class: Type[BaseUint]) -> None: - """Tests that `decode_bytes` raises a ValueError for data of the wrong length.""" + """Tests that `decode_bytes` raises SSZSerializationError for wrong length data.""" # Create byte string that is one byte too short. invalid_data = b"\x00" * (uint_class.get_byte_length() - 1) - with pytest.raises(ValueError, match="Invalid byte length"): + with pytest.raises(SSZSerializationError, match="expected .* bytes, got"): uint_class.decode_bytes(invalid_data) @pytest.mark.parametrize("uint_class", ALL_UINT_TYPES) @@ -426,17 +427,17 @@ def test_serialize_deserialize_stream_roundtrip(self, uint_class: Type[BaseUint] @pytest.mark.parametrize("uint_class", ALL_UINT_TYPES) def test_deserialize_invalid_scope(self, uint_class: Type[BaseUint]) -> None: - """Tests that `deserialize` raises a ValueError if the scope is incorrect.""" + """Tests that `deserialize` raises an SSZSerializationError if the scope is incorrect.""" stream = io.BytesIO(b"\x00" * uint_class.get_byte_length()) invalid_scope = uint_class.get_byte_length() - 1 - with pytest.raises(ValueError, match="Invalid scope"): + with pytest.raises(SSZSerializationError, match="invalid scope"): uint_class.deserialize(stream, scope=invalid_scope) @pytest.mark.parametrize("uint_class", ALL_UINT_TYPES) def test_deserialize_stream_too_short(self, uint_class: Type[BaseUint]) -> None: - """Tests that `deserialize` raises an IOError if the stream ends prematurely.""" + """Tests that `deserialize` raises SSZSerializationError if stream ends prematurely.""" byte_length = uint_class.get_byte_length() # Create a stream that is shorter than what the type requires. stream = io.BytesIO(b"\x00" * (byte_length - 1)) - with pytest.raises(IOError, match="Stream ended prematurely"): + with pytest.raises(SSZSerializationError, match="expected .* bytes, got"): uint_class.deserialize(stream, scope=byte_length) diff --git a/tests/lean_spec/types/test_union.py b/tests/lean_spec/types/test_union.py index b7ab6ac1..59b010dd 100644 --- a/tests/lean_spec/types/test_union.py +++ b/tests/lean_spec/types/test_union.py @@ -9,6 +9,7 @@ from lean_spec.types.collections import SSZList, SSZVector from lean_spec.types.container import Container +from lean_spec.types.exceptions import SSZSerializationError, SSZTypeError, SSZValueError from lean_spec.types.ssz_base import SSZType from lean_spec.types.uint import Uint8, Uint16, Uint32 from lean_spec.types.union import SSZUnion @@ -112,14 +113,14 @@ def test_constructor_success() -> None: def test_constructor_errors() -> None: """Test Union construction error cases.""" # Invalid selector (out of range) - with pytest.raises(ValueError, match="Invalid selector"): + with pytest.raises(SSZValueError, match="out of range"): OptionalNumericUnion(selector=3, value=None) # None value for None option should work OptionalNumericUnion(selector=0, value=None) # Non-None value for None option should fail - with pytest.raises(TypeError, match="value must be None"): + with pytest.raises(SSZTypeError, match="Expected None"): OptionalNumericUnion(selector=0, value=Uint16(1)) @@ -138,22 +139,22 @@ def test_pydantic_validation_ok() -> None: def test_pydantic_validation_errors() -> None: """Test Pydantic validation error cases.""" # Test invalid selector directly - with pytest.raises(ValueError, match="Invalid selector"): + with pytest.raises(SSZValueError, match="out of range"): OptionalNumericUnion(selector=9, value=0) # Test invalid value for None option directly - with pytest.raises(TypeError, match="value must be None"): + with pytest.raises(SSZTypeError, match="Expected None"): OptionalNumericUnion(selector=0, value=1) # Test with Pydantic model wrapper - should catch underlying errors model = create_model("M", v=(OptionalNumericUnion, ...)) # Invalid selector in model context - with pytest.raises((ValidationError, ValueError)): + with pytest.raises((ValidationError, SSZValueError)): model(v={"selector": 9, "value": 0}) # Invalid value for None option in model context - with pytest.raises((ValidationError, TypeError)): + with pytest.raises((ValidationError, SSZTypeError)): model(v={"selector": 0, "value": 1}) @@ -192,15 +193,15 @@ def test_union_with_nested_composites_roundtrip() -> None: def test_deserialize_errors() -> None: """Test deserialization error cases.""" # Too small scope - with pytest.raises(ValueError, match="Scope too small"): + with pytest.raises(SSZSerializationError, match="scope too small"): SimpleUnion.deserialize(io.BytesIO(b""), 0) # Invalid selector - with pytest.raises(ValueError, match="out of range"): + with pytest.raises(SSZValueError, match="out of range"): SimpleUnion.deserialize(io.BytesIO(b"\x09"), 1) # None option with payload - with pytest.raises(ValueError, match="no payload bytes"): + with pytest.raises(SSZSerializationError, match="no payload bytes"): OptionalNumericUnion.deserialize(io.BytesIO(b"\x00\xff"), 2) @@ -255,7 +256,7 @@ def test_is_fixed_size_helper() -> None: def test_get_byte_length_raises() -> None: """Test get_byte_length() raises for variable-size types.""" - with pytest.raises(TypeError, match="variable-size"): + with pytest.raises(SSZTypeError, match="variable-size"): NumericUnion.get_byte_length() @@ -270,7 +271,7 @@ class ValidUnion(SSZUnion): assert instance.selector == 0 # Invalid union with None not at index 0 should fail during validation - with pytest.raises(TypeError, match="None at index 0"): + with pytest.raises(SSZTypeError, match="None only allowed at index 0"): class InvalidUnion1(SSZUnion): OPTIONS = (Uint16, None) @@ -281,7 +282,7 @@ class InvalidUnion1(SSZUnion): class NotSSZ: pass - with pytest.raises(TypeError, match="takes no arguments"): + with pytest.raises(SSZTypeError): class InvalidUnion2(SSZUnion): OPTIONS = (cast(PyType[SSZType], NotSSZ),) @@ -297,7 +298,7 @@ def test_union_boundary_cases() -> None: assert u.value == Uint16(42) # None-only union should fail validation - with pytest.raises(TypeError, match="only option"): + with pytest.raises(SSZTypeError, match="only option"): class NoneOnlyUnion(SSZUnion): OPTIONS = (None,)