diff --git a/src/lean_spec/subspecs/xmss/_validation.py b/src/lean_spec/subspecs/xmss/_validation.py new file mode 100644 index 00000000..52ed909b --- /dev/null +++ b/src/lean_spec/subspecs/xmss/_validation.py @@ -0,0 +1,29 @@ +"""Internal validation utilities for the XMSS scheme.""" + +from __future__ import annotations + +from typing import Any + + +def enforce_strict_types(instance: Any, **field_types: type) -> None: + """ + Validate that specified fields are exact types, not subclasses. + + This is a helper function to be called from Pydantic model validators. + + It enforces that field values are exactly the declared type, preventing + type confusion attacks where a malicious subclass could override behavior. + + Args: + instance: The model instance being validated. + **field_types: Mapping of field names to their exact expected types. + + Raises: + TypeError: If any field is a subclass rather than the exact type. + """ + for field_name, expected_type in field_types.items(): + value = getattr(instance, field_name) + if type(value) is not expected_type: + raise TypeError( + f"{field_name} must be exactly {expected_type.__name__}, not a subclass" + ) diff --git a/src/lean_spec/subspecs/xmss/interface.py b/src/lean_spec/subspecs/xmss/interface.py index 7a7916ad..23880fa2 100644 --- a/src/lean_spec/subspecs/xmss/interface.py +++ b/src/lean_spec/subspecs/xmss/interface.py @@ -18,6 +18,7 @@ ) from lean_spec.types import StrictBaseModel, Uint64 +from ._validation import enforce_strict_types from .constants import ( PROD_CONFIG, TEST_CONFIG, @@ -60,18 +61,16 @@ class GeneralizedXmssScheme(StrictBaseModel): """Random data generator for key generation.""" @model_validator(mode="after") - def enforce_strict_types(self) -> "GeneralizedXmssScheme": + def _validate_strict_types(self) -> "GeneralizedXmssScheme": """Reject subclasses to prevent type confusion attacks.""" - if type(self.config) is not XmssConfig: - raise TypeError("config must be exactly XmssConfig, not a subclass") - if type(self.prf) is not Prf: - raise TypeError("prf must be exactly Prf, not a subclass") - if type(self.hasher) is not TweakHasher: - raise TypeError("hasher must be exactly TweakHasher, not a subclass") - if type(self.encoder) is not TargetSumEncoder: - raise TypeError("encoder must be exactly TargetSumEncoder, not a subclass") - if type(self.rand) is not Rand: - raise TypeError("rand must be exactly Rand, not a subclass") + enforce_strict_types( + self, + config=XmssConfig, + prf=Prf, + hasher=TweakHasher, + encoder=TargetSumEncoder, + rand=Rand, + ) return self def key_gen( diff --git a/src/lean_spec/subspecs/xmss/message_hash.py b/src/lean_spec/subspecs/xmss/message_hash.py index 23d1a72b..4a2537ff 100644 --- a/src/lean_spec/subspecs/xmss/message_hash.py +++ b/src/lean_spec/subspecs/xmss/message_hash.py @@ -39,6 +39,7 @@ from lean_spec.types import StrictBaseModel, Uint64 from ..koalabear import Fp, P +from ._validation import enforce_strict_types from .constants import ( PROD_CONFIG, TEST_CONFIG, @@ -64,12 +65,9 @@ class MessageHasher(StrictBaseModel): """Poseidon hash engine.""" @model_validator(mode="after") - def enforce_strict_types(self) -> "MessageHasher": + def _validate_strict_types(self) -> "MessageHasher": """Reject subclasses to prevent type confusion attacks.""" - if type(self.config) is not XmssConfig: - raise TypeError("config must be exactly XmssConfig, not a subclass") - if type(self.poseidon) is not PoseidonXmss: - raise TypeError("poseidon must be exactly PoseidonXmss, not a subclass") + enforce_strict_types(self, config=XmssConfig, poseidon=PoseidonXmss) return self def encode_message(self, message: bytes) -> list[Fp]: diff --git a/src/lean_spec/subspecs/xmss/poseidon.py b/src/lean_spec/subspecs/xmss/poseidon.py index e57a2431..b8ed4ae9 100644 --- a/src/lean_spec/subspecs/xmss/poseidon.py +++ b/src/lean_spec/subspecs/xmss/poseidon.py @@ -33,6 +33,7 @@ Poseidon2Params, permute, ) +from ._validation import enforce_strict_types from .utils import int_to_base_p @@ -46,12 +47,9 @@ class PoseidonXmss(StrictBaseModel): """Poseidon2 parameters for 24-width permutation.""" @model_validator(mode="after") - def enforce_strict_types(self) -> "PoseidonXmss": + def _validate_strict_types(self) -> "PoseidonXmss": """Reject subclasses to prevent type confusion attacks.""" - if type(self.params16) is not Poseidon2Params: - raise TypeError("params16 must be exactly Poseidon2Params, not a subclass") - if type(self.params24) is not Poseidon2Params: - raise TypeError("params24 must be exactly Poseidon2Params, not a subclass") + enforce_strict_types(self, params16=Poseidon2Params, params24=Poseidon2Params) return self def compress(self, input_vec: list[Fp], width: int, output_len: int) -> list[Fp]: diff --git a/src/lean_spec/subspecs/xmss/prf.py b/src/lean_spec/subspecs/xmss/prf.py index babb7463..c5167ae3 100644 --- a/src/lean_spec/subspecs/xmss/prf.py +++ b/src/lean_spec/subspecs/xmss/prf.py @@ -17,6 +17,7 @@ from lean_spec.subspecs.koalabear import Fp from lean_spec.types import StrictBaseModel, Uint64 +from ._validation import enforce_strict_types from .constants import ( PRF_KEY_LENGTH, PROD_CONFIG, @@ -109,10 +110,9 @@ class Prf(StrictBaseModel): """Configuration parameters for the PRF.""" @model_validator(mode="after") - def enforce_strict_types(self) -> "Prf": + def _validate_strict_types(self) -> "Prf": """Reject subclasses to prevent type confusion attacks.""" - if type(self.config) is not XmssConfig: - raise TypeError("config must be exactly XmssConfig, not a subclass") + enforce_strict_types(self, config=XmssConfig) return self def key_gen(self) -> PRFKey: diff --git a/src/lean_spec/subspecs/xmss/rand.py b/src/lean_spec/subspecs/xmss/rand.py index 488b47a5..c108b165 100644 --- a/src/lean_spec/subspecs/xmss/rand.py +++ b/src/lean_spec/subspecs/xmss/rand.py @@ -7,6 +7,7 @@ from lean_spec.types import StrictBaseModel from ..koalabear import Fp, P +from ._validation import enforce_strict_types from .constants import PROD_CONFIG, TEST_CONFIG, XmssConfig from .types import HashDigestVector, Parameter, Randomness @@ -18,10 +19,9 @@ class Rand(StrictBaseModel): """Configuration parameters for the random generator.""" @model_validator(mode="after") - def enforce_strict_types(self) -> "Rand": + def _validate_strict_types(self) -> "Rand": """Reject subclasses to prevent type confusion attacks.""" - if type(self.config) is not XmssConfig: - raise TypeError("config must be exactly XmssConfig, not a subclass") + enforce_strict_types(self, config=XmssConfig) return self def field_elements(self, length: int) -> list[Fp]: diff --git a/src/lean_spec/subspecs/xmss/target_sum.py b/src/lean_spec/subspecs/xmss/target_sum.py index 4ab9f0ff..fc97614c 100644 --- a/src/lean_spec/subspecs/xmss/target_sum.py +++ b/src/lean_spec/subspecs/xmss/target_sum.py @@ -10,6 +10,7 @@ from lean_spec.types import StrictBaseModel, Uint64 +from ._validation import enforce_strict_types from .constants import PROD_CONFIG, TEST_CONFIG, XmssConfig from .message_hash import ( PROD_MESSAGE_HASHER, @@ -34,12 +35,9 @@ class TargetSumEncoder(StrictBaseModel): """Message hasher for encoding.""" @model_validator(mode="after") - def enforce_strict_types(self) -> "TargetSumEncoder": + def _validate_strict_types(self) -> "TargetSumEncoder": """Reject subclasses to prevent type confusion attacks.""" - if type(self.config) is not XmssConfig: - raise TypeError("config must be exactly XmssConfig, not a subclass") - if type(self.message_hasher) is not MessageHasher: - raise TypeError("message_hasher must be exactly MessageHasher, not a subclass") + enforce_strict_types(self, config=XmssConfig, message_hasher=MessageHasher) return self def encode( diff --git a/src/lean_spec/subspecs/xmss/tweak_hash.py b/src/lean_spec/subspecs/xmss/tweak_hash.py index 54a0181f..ec8b0538 100644 --- a/src/lean_spec/subspecs/xmss/tweak_hash.py +++ b/src/lean_spec/subspecs/xmss/tweak_hash.py @@ -31,6 +31,7 @@ from lean_spec.types import StrictBaseModel, Uint64 from ..koalabear import Fp +from ._validation import enforce_strict_types from .constants import ( PROD_CONFIG, TEST_CONFIG, @@ -86,12 +87,9 @@ class TweakHasher(StrictBaseModel): """Poseidon permutation instance for hashing.""" @model_validator(mode="after") - def enforce_strict_types(self) -> "TweakHasher": + def _validate_strict_types(self) -> "TweakHasher": """Reject subclasses to prevent type confusion attacks.""" - if type(self.config) is not XmssConfig: - raise TypeError("config must be exactly XmssConfig, not a subclass") - if type(self.poseidon) is not PoseidonXmss: - raise TypeError("poseidon must be exactly PoseidonXmss, not a subclass") + enforce_strict_types(self, config=XmssConfig, poseidon=PoseidonXmss) return self def _encode_tweak(self, tweak: TreeTweak | ChainTweak, length: int) -> list[Fp]: @@ -123,17 +121,18 @@ def _encode_tweak(self, tweak: TreeTweak | ChainTweak, length: int) -> list[Fp]: # Pack the tweak's integer fields into a single large integer. # # A hardcoded prefix is included for domain separation between tweak types. - if isinstance(tweak, TreeTweak): - # Packing scheme: (level << 40) | (index << 8) | PREFIX - acc = (tweak.level << 40) | (int(tweak.index) << 8) | TWEAK_PREFIX_TREE.value - else: - # Packing scheme: (epoch << 24) | (chain_index << 16) | (step << 8) | PREFIX - acc = ( - (int(tweak.epoch) << 24) - | (tweak.chain_index << 16) - | (tweak.step << 8) - | TWEAK_PREFIX_CHAIN.value - ) + match tweak: + case TreeTweak(level=level, index=index): + # Packing scheme: (level << 40) | (index << 8) | PREFIX + acc = (level << 40) | (int(index) << 8) | TWEAK_PREFIX_TREE.value + case ChainTweak(epoch=epoch, chain_index=chain_index, step=step): + # Packing scheme: (epoch << 24) | (chain_index << 16) | (step << 8) | PREFIX + acc = ( + (int(epoch) << 24) + | (chain_index << 16) + | (step << 8) + | TWEAK_PREFIX_CHAIN.value + ) # Decompose the packed integer `acc` into a list of base-P field elements. return int_to_base_p(acc, length)