Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions src/lean_spec/subspecs/xmss/_validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
"""Internal validation utilities for the XMSS scheme."""

from __future__ import annotations

from typing import Any


def enforce_strict_types(instance: Any, **field_types: type) -> None:
"""
Validate that specified fields are exact types, not subclasses.

This is a helper function to be called from Pydantic model validators.

It enforces that field values are exactly the declared type, preventing
type confusion attacks where a malicious subclass could override behavior.

Args:
instance: The model instance being validated.
**field_types: Mapping of field names to their exact expected types.

Raises:
TypeError: If any field is a subclass rather than the exact type.
"""
for field_name, expected_type in field_types.items():
value = getattr(instance, field_name)
if type(value) is not expected_type:
raise TypeError(
f"{field_name} must be exactly {expected_type.__name__}, not a subclass"
)
21 changes: 10 additions & 11 deletions src/lean_spec/subspecs/xmss/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
)
from lean_spec.types import StrictBaseModel, Uint64

from ._validation import enforce_strict_types
from .constants import (
PROD_CONFIG,
TEST_CONFIG,
Expand Down Expand Up @@ -60,18 +61,16 @@ class GeneralizedXmssScheme(StrictBaseModel):
"""Random data generator for key generation."""

@model_validator(mode="after")
def enforce_strict_types(self) -> "GeneralizedXmssScheme":
def _validate_strict_types(self) -> "GeneralizedXmssScheme":
"""Reject subclasses to prevent type confusion attacks."""
if type(self.config) is not XmssConfig:
raise TypeError("config must be exactly XmssConfig, not a subclass")
if type(self.prf) is not Prf:
raise TypeError("prf must be exactly Prf, not a subclass")
if type(self.hasher) is not TweakHasher:
raise TypeError("hasher must be exactly TweakHasher, not a subclass")
if type(self.encoder) is not TargetSumEncoder:
raise TypeError("encoder must be exactly TargetSumEncoder, not a subclass")
if type(self.rand) is not Rand:
raise TypeError("rand must be exactly Rand, not a subclass")
enforce_strict_types(
self,
config=XmssConfig,
prf=Prf,
hasher=TweakHasher,
encoder=TargetSumEncoder,
rand=Rand,
)
return self

def key_gen(
Expand Down
8 changes: 3 additions & 5 deletions src/lean_spec/subspecs/xmss/message_hash.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from lean_spec.types import StrictBaseModel, Uint64

from ..koalabear import Fp, P
from ._validation import enforce_strict_types
from .constants import (
PROD_CONFIG,
TEST_CONFIG,
Expand All @@ -64,12 +65,9 @@ class MessageHasher(StrictBaseModel):
"""Poseidon hash engine."""

@model_validator(mode="after")
def enforce_strict_types(self) -> "MessageHasher":
def _validate_strict_types(self) -> "MessageHasher":
"""Reject subclasses to prevent type confusion attacks."""
if type(self.config) is not XmssConfig:
raise TypeError("config must be exactly XmssConfig, not a subclass")
if type(self.poseidon) is not PoseidonXmss:
raise TypeError("poseidon must be exactly PoseidonXmss, not a subclass")
enforce_strict_types(self, config=XmssConfig, poseidon=PoseidonXmss)
return self

def encode_message(self, message: bytes) -> list[Fp]:
Expand Down
8 changes: 3 additions & 5 deletions src/lean_spec/subspecs/xmss/poseidon.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
Poseidon2Params,
permute,
)
from ._validation import enforce_strict_types
from .utils import int_to_base_p


Expand All @@ -46,12 +47,9 @@ class PoseidonXmss(StrictBaseModel):
"""Poseidon2 parameters for 24-width permutation."""

@model_validator(mode="after")
def enforce_strict_types(self) -> "PoseidonXmss":
def _validate_strict_types(self) -> "PoseidonXmss":
"""Reject subclasses to prevent type confusion attacks."""
if type(self.params16) is not Poseidon2Params:
raise TypeError("params16 must be exactly Poseidon2Params, not a subclass")
if type(self.params24) is not Poseidon2Params:
raise TypeError("params24 must be exactly Poseidon2Params, not a subclass")
enforce_strict_types(self, params16=Poseidon2Params, params24=Poseidon2Params)
return self

def compress(self, input_vec: list[Fp], width: int, output_len: int) -> list[Fp]:
Expand Down
6 changes: 3 additions & 3 deletions src/lean_spec/subspecs/xmss/prf.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from lean_spec.subspecs.koalabear import Fp
from lean_spec.types import StrictBaseModel, Uint64

from ._validation import enforce_strict_types
from .constants import (
PRF_KEY_LENGTH,
PROD_CONFIG,
Expand Down Expand Up @@ -109,10 +110,9 @@ class Prf(StrictBaseModel):
"""Configuration parameters for the PRF."""

@model_validator(mode="after")
def enforce_strict_types(self) -> "Prf":
def _validate_strict_types(self) -> "Prf":
"""Reject subclasses to prevent type confusion attacks."""
if type(self.config) is not XmssConfig:
raise TypeError("config must be exactly XmssConfig, not a subclass")
enforce_strict_types(self, config=XmssConfig)
return self

def key_gen(self) -> PRFKey:
Expand Down
6 changes: 3 additions & 3 deletions src/lean_spec/subspecs/xmss/rand.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from lean_spec.types import StrictBaseModel

from ..koalabear import Fp, P
from ._validation import enforce_strict_types
from .constants import PROD_CONFIG, TEST_CONFIG, XmssConfig
from .types import HashDigestVector, Parameter, Randomness

Expand All @@ -18,10 +19,9 @@ class Rand(StrictBaseModel):
"""Configuration parameters for the random generator."""

@model_validator(mode="after")
def enforce_strict_types(self) -> "Rand":
def _validate_strict_types(self) -> "Rand":
"""Reject subclasses to prevent type confusion attacks."""
if type(self.config) is not XmssConfig:
raise TypeError("config must be exactly XmssConfig, not a subclass")
enforce_strict_types(self, config=XmssConfig)
return self

def field_elements(self, length: int) -> list[Fp]:
Expand Down
8 changes: 3 additions & 5 deletions src/lean_spec/subspecs/xmss/target_sum.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from lean_spec.types import StrictBaseModel, Uint64

from ._validation import enforce_strict_types
from .constants import PROD_CONFIG, TEST_CONFIG, XmssConfig
from .message_hash import (
PROD_MESSAGE_HASHER,
Expand All @@ -34,12 +35,9 @@ class TargetSumEncoder(StrictBaseModel):
"""Message hasher for encoding."""

@model_validator(mode="after")
def enforce_strict_types(self) -> "TargetSumEncoder":
def _validate_strict_types(self) -> "TargetSumEncoder":
"""Reject subclasses to prevent type confusion attacks."""
if type(self.config) is not XmssConfig:
raise TypeError("config must be exactly XmssConfig, not a subclass")
if type(self.message_hasher) is not MessageHasher:
raise TypeError("message_hasher must be exactly MessageHasher, not a subclass")
enforce_strict_types(self, config=XmssConfig, message_hasher=MessageHasher)
return self

def encode(
Expand Down
31 changes: 15 additions & 16 deletions src/lean_spec/subspecs/xmss/tweak_hash.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from lean_spec.types import StrictBaseModel, Uint64

from ..koalabear import Fp
from ._validation import enforce_strict_types
from .constants import (
PROD_CONFIG,
TEST_CONFIG,
Expand Down Expand Up @@ -86,12 +87,9 @@ class TweakHasher(StrictBaseModel):
"""Poseidon permutation instance for hashing."""

@model_validator(mode="after")
def enforce_strict_types(self) -> "TweakHasher":
def _validate_strict_types(self) -> "TweakHasher":
"""Reject subclasses to prevent type confusion attacks."""
if type(self.config) is not XmssConfig:
raise TypeError("config must be exactly XmssConfig, not a subclass")
if type(self.poseidon) is not PoseidonXmss:
raise TypeError("poseidon must be exactly PoseidonXmss, not a subclass")
enforce_strict_types(self, config=XmssConfig, poseidon=PoseidonXmss)
return self

def _encode_tweak(self, tweak: TreeTweak | ChainTweak, length: int) -> list[Fp]:
Expand Down Expand Up @@ -123,17 +121,18 @@ def _encode_tweak(self, tweak: TreeTweak | ChainTweak, length: int) -> list[Fp]:
# Pack the tweak's integer fields into a single large integer.
#
# A hardcoded prefix is included for domain separation between tweak types.
if isinstance(tweak, TreeTweak):
# Packing scheme: (level << 40) | (index << 8) | PREFIX
acc = (tweak.level << 40) | (int(tweak.index) << 8) | TWEAK_PREFIX_TREE.value
else:
# Packing scheme: (epoch << 24) | (chain_index << 16) | (step << 8) | PREFIX
acc = (
(int(tweak.epoch) << 24)
| (tweak.chain_index << 16)
| (tweak.step << 8)
| TWEAK_PREFIX_CHAIN.value
)
match tweak:
case TreeTweak(level=level, index=index):
# Packing scheme: (level << 40) | (index << 8) | PREFIX
acc = (level << 40) | (int(index) << 8) | TWEAK_PREFIX_TREE.value
case ChainTweak(epoch=epoch, chain_index=chain_index, step=step):
# Packing scheme: (epoch << 24) | (chain_index << 16) | (step << 8) | PREFIX
acc = (
(int(epoch) << 24)
| (chain_index << 16)
| (step << 8)
| TWEAK_PREFIX_CHAIN.value
)

# Decompose the packed integer `acc` into a list of base-P field elements.
return int_to_base_p(acc, length)
Expand Down
Loading