diff --git a/src/lean_spec/subspecs/xmss/interface.py b/src/lean_spec/subspecs/xmss/interface.py index 5ece55c2..db6de779 100644 --- a/src/lean_spec/subspecs/xmss/interface.py +++ b/src/lean_spec/subspecs/xmss/interface.py @@ -8,7 +8,7 @@ from __future__ import annotations -from typing import List, Tuple, cast +from typing import List, Tuple from pydantic import model_validator @@ -26,11 +26,6 @@ XmssConfig, ) from .containers import PublicKey, SecretKey, Signature -from .merkle_tree import ( - PROD_MERKLE_TREE, - TEST_MERKLE_TREE, - MerkleTree, -) from .prf import PROD_PRF, TEST_PRF, Prf from .rand import PROD_RAND, TEST_RAND, Rand from .tweak_hash import ( @@ -60,9 +55,6 @@ class GeneralizedXmssScheme(StrictBaseModel): hasher: TweakHasher """Hash function with tweakable domain separation.""" - merkle_tree: MerkleTree - """Merkle tree implementation for authentication paths.""" - encoder: TargetSumEncoder """Message encoder that produces valid codewords.""" @@ -76,7 +68,6 @@ def enforce_strict_types(self) -> "GeneralizedXmssScheme": "config": XmssConfig, "prf": Prf, "hasher": TweakHasher, - "merkle_tree": MerkleTree, "encoder": TargetSumEncoder, "rand": Rand, } @@ -175,7 +166,7 @@ def key_gen( left_bottom_tree = bottom_tree_from_prf_key( self.prf, self.hasher, - self.merkle_tree, + self.rand, config, prf_key, Uint64(start_bottom_tree_index), @@ -184,7 +175,7 @@ def key_gen( right_bottom_tree = bottom_tree_from_prf_key( self.prf, self.hasher, - self.merkle_tree, + self.rand, config, prf_key, Uint64(start_bottom_tree_index + 1), @@ -202,7 +193,7 @@ def key_gen( tree = bottom_tree_from_prf_key( self.prf, self.hasher, - self.merkle_tree, + self.rand, config, prf_key, Uint64(i), @@ -453,9 +444,7 @@ def verify(self, pk: PublicKey, epoch: Uint64, message: bytes, sig: Signature) - chain_ends: List[List[Fp]] = [] for chain_index, xi in enumerate(codeword): # The signature provides `start_digest`, which is the hash value after `xi` steps. - # Extract from SSZ type: HashDigestList -> HashDigestVector -> List[Fp] - hash_vector = cast(HashDigestVector, sig.hashes[chain_index]) - start_digest = cast(List[Fp], list(hash_vector.data)) + start_digest: List[Fp] = list(sig.hashes[chain_index]) # We must perform the remaining `BASE - 1 - xi` hashing steps # to compute the public endpoint of the chain. num_steps_remaining = config.BASE - 1 - xi @@ -475,7 +464,10 @@ def verify(self, pk: PublicKey, epoch: Uint64, message: bytes, sig: Signature) - # - Hashes the `chain_ends` to get the leaf node for the epoch, # - Uses the `opening` path from the signature to compute a candidate root. # - It returns true if and only if this candidate root matches the public key's root. - return self.merkle_tree.verify_path( + from .subtree import verify_path + + return verify_path( + hasher=self.hasher, parameter=pk.parameter, root=pk.root, position=epoch, @@ -568,7 +560,7 @@ def advance_preparation(self, sk: SecretKey) -> SecretKey: new_right_bottom_tree = bottom_tree_from_prf_key( prf=self.prf, hasher=self.hasher, - merkle_tree=self.merkle_tree, + rand=self.rand, config=self.config, prf_key=sk.prf_key, bottom_tree_index=new_right_tree_index, @@ -589,7 +581,6 @@ def advance_preparation(self, sk: SecretKey) -> SecretKey: config=PROD_CONFIG, prf=PROD_PRF, hasher=PROD_TWEAK_HASHER, - merkle_tree=PROD_MERKLE_TREE, encoder=PROD_TARGET_SUM_ENCODER, rand=PROD_RAND, ) @@ -599,7 +590,6 @@ def advance_preparation(self, sk: SecretKey) -> SecretKey: config=TEST_CONFIG, prf=TEST_PRF, hasher=TEST_TWEAK_HASHER, - merkle_tree=TEST_MERKLE_TREE, encoder=TEST_TARGET_SUM_ENCODER, rand=TEST_RAND, ) diff --git a/src/lean_spec/subspecs/xmss/merkle_tree.py b/src/lean_spec/subspecs/xmss/merkle_tree.py deleted file mode 100644 index 480167d3..00000000 --- a/src/lean_spec/subspecs/xmss/merkle_tree.py +++ /dev/null @@ -1,328 +0,0 @@ -""" -Implements the sparse Merkle tree used in the Generalized XMSS scheme. - -### Usage of Merkle Trees in XMSS scheme: Aggregating Keys - -A Merkle tree is a cryptographic data structure that allows for the efficient -aggregation and verification of large sets of data. In XMSS, its role is to -aggregate **one-time public keys** (the leaves of the tree) into a single, -compact **master public key** (the root of the tree). - -A verifier who knows only the root can be given a small proof (an "authentication -path") to efficiently verify that a specific one-time public key is a legitimate -part of the overall scheme. - -### Key Optimizations for XMSS - -This implementation includes two important features tailored for this use case: - -1. **Sparsity**: A key pair might have a massive theoretical lifetime (e.g., 2^32 epochs), - but in practice, a signer only needs to generate and store the keys for a - much smaller, active range (e.g., 2^20 epochs). This implementation builds a - "sparse" tree, only computing and storing the nodes and branches relevant to - this active range of leaves, leading to enormous savings in computation and memory. - -2. **Random Padding**: To simplify the algorithm that builds the tree, each active - layer is padded with random hash values. This ensures that every node can - always be paired with a sibling, eliminating complex edge-case logic for - "orphan" nodes at the boundaries of the sparse region. -""" - -from __future__ import annotations - -from typing import Iterator, List, Tuple, cast - -from pydantic import model_validator - -from lean_spec.types import StrictBaseModel, Uint64 - -from ..koalabear import Fp -from .constants import ( - PROD_CONFIG, - TEST_CONFIG, - XmssConfig, -) -from .rand import PROD_RAND, TEST_RAND, Rand -from .subtree import HashSubTree -from .tweak_hash import ( - PROD_TWEAK_HASHER, - TEST_TWEAK_HASHER, - TreeTweak, - TweakHasher, -) -from .types import ( - HashDigestList, - HashDigestVector, - HashTreeLayer, - HashTreeLayers, - HashTreeOpening, - Parameter, -) -from .utils import get_padded_layer - - -class MerkleTree(StrictBaseModel): - """An instance of the Merkle Tree handler for a given config.""" - - config: XmssConfig - """Configuration parameters for the Merkle tree.""" - - hasher: TweakHasher - """Hash function for hashing tree nodes.""" - - rand: Rand - """Random generator for padding.""" - - @model_validator(mode="after") - def enforce_strict_types(self) -> "MerkleTree": - """Validates that only exact approved types are used (rejects subclasses).""" - checks = {"config": XmssConfig, "hasher": TweakHasher, "rand": Rand} - for field, expected in checks.items(): - if type(getattr(self, field)) is not expected: - raise TypeError( - f"{field} must be exactly {expected.__name__}, " - f"got {type(getattr(self, field)).__name__}" - ) - return self - - def build( - self, - depth: int, - start_index: Uint64, - parameter: Parameter, - leaf_hashes: List[List[Fp]], - ) -> HashSubTree: - """ - Builds a new sparse Merkle tree from a contiguous range of leaf hashes. - - ### Construction Algorithm - - 1. **Initialization**: The process starts with the provided `leaf_hashes` - at the bottom of the tree (level 0). This layer is padded. - - 2. **Bottom-Up Iteration**: The tree is built level by level, from the - leaves towards the root. - - 3. **Parent Generation**: In each level, the current layer's nodes are - grouped into pairs (left child, right child). Each pair is then - hashed together (using a level- and index-specific tweak) to create - a parent node in the level above. - - 4. **Padding**: The new list of parent nodes is padded to prepare for the - next iteration, ensuring the sibling-pairing logic remains simple. - - 5. **Termination**: This process repeats until a layer with a single - node is produced. This final node is the tree's root. - - Args: - depth: The total depth of the tree (e.g., 32 for a 2^32 leaf space). - start_index: The absolute index of the first leaf in `leaf_hashes`. - parameter: The public parameter `P` for the hash function. - leaf_hashes: The list of pre-hashed leaf nodes. - - Returns: - The fully constructed `HashSubTree` object containing all computed layers. - """ - # Check there is enough space for the leafs in the tree. - if int(start_index) + len(leaf_hashes) > 2**depth: - raise ValueError("Not enough space for leafs in the tree.") - - # Start with the leaf hashes and apply the initial padding. - layers: List[HashTreeLayer] = [] - current_layer = get_padded_layer(self.rand, leaf_hashes, start_index) - layers.append(current_layer) - - # Iterate from the leaf layer (level 0) up to the root. - for level in range(depth): - parents: List[List[Fp]] = [] - # Group the current layer's nodes into pairs of (left, right) siblings. - # - # The padding guarantees this works perfectly without leaving orphan nodes. - children_iter = cast( - Iterator[Tuple[HashDigestVector, HashDigestVector]], - zip( - current_layer.nodes.data[0::2], - current_layer.nodes.data[1::2], - strict=False, - ), - ) - for i, children in enumerate(children_iter): - # Calculate the position of the parent node in the next level up. - parent_index = (current_layer.start_index // Uint64(2)) + Uint64(i) - # Create the tweak for hashing these two children. - tweak = TreeTweak(level=level + 1, index=parent_index) - # Hash the left and right children to get their parent. - # Convert HashDigestVector to List[Fp] for hashing - left_data = cast("Tuple[Fp, ...]", children[0].data) - right_data = cast("Tuple[Fp, ...]", children[1].data) - parent_node = self.hasher.apply( - parameter, tweak, [list(left_data), list(right_data)] - ) - parents.append(parent_node) - - # Pad the new list of parents to prepare for the next iteration. - new_start_index = current_layer.start_index // Uint64(2) - current_layer = get_padded_layer(self.rand, parents, new_start_index) - layers.append(current_layer) - - # Return the completed tree containing all computed layers. - # A full tree is represented as a HashSubTree with lowest_layer=0 - return HashSubTree( - depth=Uint64(depth), lowest_layer=Uint64(0), layers=HashTreeLayers(data=layers) - ) - - def root(self, tree: HashSubTree) -> HashDigestVector: - """ - Extracts the root digest from a constructed Merkle tree. - - The root is the single node in the final, highest layer of the `HashSubTree` - and serves as the primary component of the master public key. - """ - # The root is the single node in the final layer. - return cast(HashDigestVector, cast(HashTreeLayer, tree.layers.data[-1]).nodes[0]) - - def path(self, tree: HashSubTree, position: Uint64) -> HashTreeOpening: - """ - Computes the authentication path for a leaf. - - The path is the minimal set of sibling nodes a verifier needs to reconstruct - the root from a given leaf. This `O(log N)` proof is what makes Merkle - tree verification highly efficient. - - ### Path Generation Algorithm - The algorithm "climbs" the tree from the leaf level to the root. At each - level, it identifies the sibling of the current node on the path and adds - it to the `co_path`. It then moves up to the parent's position for the - next level. - - Args: - tree: The `HashSubTree` from which to extract the path. - position: The absolute index of the leaf whose path is needed. - - Returns: - A `HashTreeOpening` object containing the list of sibling hashes. - """ - # Check that there is at least one layer in the tree. - if len(tree.layers) == 0: - raise ValueError("Cannot generate path for empty tree.") - - # Check that the position is within the tree's range. - first_layer = cast(HashTreeLayer, tree.layers.data[0]) - if position < first_layer.start_index: - raise ValueError("Position (before start) is invalid.") - - if position >= first_layer.start_index + Uint64(len(first_layer.nodes)): - raise ValueError("Position (after end) is invalid.") - - co_path: List[List[Fp]] = [] - current_position = position - - # Iterate from the leaf layer (level 0) up to the layer below the root. - for level in range(int(tree.depth)): - # Determine the sibling's position by flipping the last bit (XOR with 1). - sibling_position = current_position ^ Uint64(1) - # Find the sibling's index within our sparsely stored `nodes` vector. - layer = cast(HashTreeLayer, tree.layers.data[level]) - sibling_index_in_vec = sibling_position - layer.start_index - # Add the sibling's hash to the co-path. - sibling_node = cast(HashDigestVector, layer.nodes.data[int(sibling_index_in_vec)]) - sibling_data = cast("Tuple[Fp, ...]", sibling_node.data) - co_path.append(list(sibling_data)) - # Move up to the parent's position for the next iteration. - current_position = current_position // Uint64(2) - - # Wrap in SSZ types - ssz_siblings = [HashDigestVector(data=sibling) for sibling in co_path] - return HashTreeOpening(siblings=HashDigestList(data=ssz_siblings)) - - def verify_path( - self, - parameter: Parameter, - root: HashDigestVector, - position: Uint64, - leaf_parts: List[List[Fp]], - opening: HashTreeOpening, - ) -> bool: - """ - Verifies a Merkle authentication path against a known, trusted root. - - This function is the final check in signature verification. It proves that the - one-time public key used for the signature (represented by `leaf_parts`) is a - legitimate member of the set committed to by the Merkle `root`. - - ### Verification Algorithm - - 1. **Leaf Computation**: The process begins at the bottom. The verifier first - hashes the `leaf_parts` to compute the actual leaf digest. This becomes the - starting `current_node` for the climb up the tree. - - 2. **Bottom-Up Reconstruction**: The verifier iterates through the `opening.siblings` - path. At each `level`, it takes the `current_node` and the `sibling_node` - from the path. - - 3. **Parent Calculation**: It determines if the `current_node` is a left or - right child based on its `position`. The two nodes are placed in the - correct `(left, right)` order and hashed (with the correct `TreeTweak`) - to compute the parent. This parent becomes the `current_node` for the - next level. - - 4. **Final Comparison**: After all siblings are used, the final `current_node` - is the candidate root. The path is valid if and only if it matches the trusted `root`. - - Args: - parameter: The public parameter `P` for the hash function. - root: The known, trusted Merkle root from the public key. - position: The absolute index of the leaf being verified. - leaf_parts: The list of digests that constitute the original leaf. - opening: The `HashTreeOpening` object containing the sibling path. - - Returns: - `True` if the path is valid and reconstructs the root, `False` otherwise. - """ - # Compute the depth - depth = len(opening.siblings) - # Compute the number of leafs in the tree - num_leafs = 2**depth - # Check that the tree depth is at most 32. - if len(opening.siblings) > 32: - raise ValueError("Tree depth must be at most 32.") - # Check that the position and path length match. - if int(position) >= num_leafs: - raise ValueError("Position and path length do not match.") - - # The first step is to hash the constituent parts of the leaf to get - # the actual node at layer 0 of the tree. - leaf_tweak = TreeTweak(level=0, index=int(position)) - current_node = self.hasher.apply(parameter, leaf_tweak, leaf_parts) - - # Iterate up the tree, hashing the current node with its sibling from - # the path at each level. - current_position = int(position) - for level, sibling_vector in enumerate(opening.siblings): - # Convert HashDigestVector to List[Fp] - sibling_node = list(sibling_vector.data) - # Determine if the current node is a left or right child. - if current_position % 2 == 0: - # Current node is a left child; sibling is on the right. - children = [current_node, sibling_node] - else: - # Current node is a right child; sibling is on the left. - children = [sibling_node, current_node] - - # Move up to the parent's position for the next iteration. - current_position //= 2 - # Create the tweak for the parent's level and position. - parent_tweak = TreeTweak(level=level + 1, index=current_position) - # Hash the children to compute the parent node. - current_node = self.hasher.apply(parameter, parent_tweak, children) - - # After iterating through the entire path, the final computed node - # should be the root of the tree. - return current_node == list(root.data) - - -PROD_MERKLE_TREE = MerkleTree(config=PROD_CONFIG, hasher=PROD_TWEAK_HASHER, rand=PROD_RAND) -"""An instance configured for production-level parameters.""" - -TEST_MERKLE_TREE = MerkleTree(config=TEST_CONFIG, hasher=TEST_TWEAK_HASHER, rand=TEST_RAND) -"""A lightweight instance for test environments.""" diff --git a/src/lean_spec/subspecs/xmss/subtree.py b/src/lean_spec/subspecs/xmss/subtree.py index 00e59f44..1811e234 100644 --- a/src/lean_spec/subspecs/xmss/subtree.py +++ b/src/lean_spec/subspecs/xmss/subtree.py @@ -7,7 +7,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Iterator, List, Tuple, cast +from typing import TYPE_CHECKING, List from lean_spec.types import Uint64 from lean_spec.types.container import Container @@ -164,24 +164,19 @@ def new( # Group current layer's nodes into pairs of (left, right) siblings. # The padding guarantees this works perfectly without orphan nodes. - children_iter = cast( - Iterator[Tuple[HashDigestVector, HashDigestVector]], - zip( - current_layer.nodes.data[0::2], - current_layer.nodes.data[1::2], - strict=False, - ), - ) - for i, children in enumerate(children_iter): + num_nodes = len(current_layer.nodes) + for i in range(0, num_nodes, 2): + left_node = current_layer.nodes[i] + right_node = current_layer.nodes[i + 1] + # Calculate the position of the parent node in the next level up. - parent_index = (current_layer.start_index // Uint64(2)) + Uint64(i) + parent_index = (current_layer.start_index // Uint64(2)) + Uint64(i // 2) # Create the tweak for hashing these two children. tweak = TreeTweak(level=level + 1, index=parent_index) # Hash the left and right children to get their parent. - # Convert HashDigestVector to List[Fp] for hashing - left_data = cast("Tuple[Fp, ...]", children[0].data) - right_data = cast("Tuple[Fp, ...]", children[1].data) - parent_node = hasher.apply(parameter, tweak, [list(left_data), list(right_data)]) + left_data = list(left_node) + right_data = list(right_node) + parent_node = hasher.apply(parameter, tweak, [left_data, right_data]) parents.append(parent_node) # Pad the new list of parents to prepare for the next iteration. @@ -248,7 +243,7 @@ def new_top_tree( lowest_layer = depth // 2 # Convert HashDigestVector roots to List[Fp] for building - roots_as_lists = [cast(List[Fp], list(root.data)) for root in bottom_tree_roots] + roots_as_lists: List[List[Fp]] = [list(root) for root in bottom_tree_roots] # Build the top tree using the bottom tree roots as the lowest layer. return cls.new( @@ -341,22 +336,20 @@ def new_bottom_tree( # the bottom_tree_index (if it's the left child of its parent in the top tree) # or bottom_tree_index (if it's the right child). Since we're at layer depth/2, # the position is simply bottom_tree_index. - middle_layer = cast(HashTreeLayer, full_tree.layers.data[depth // 2]) + middle_layer = full_tree.layers[depth // 2] # The root is at position (start_index >> (depth // 2)) = bottom_tree_index # within the middle layer. We need to find it in the stored nodes. root_position_in_layer = bottom_tree_index - middle_layer.start_index - root_node = cast(HashDigestVector, middle_layer.nodes.data[int(root_position_in_layer)]) - root_data = cast("Tuple[Fp, ...]", root_node.data) - root = list(root_data) + root_node = middle_layer.nodes[int(root_position_in_layer)] + root: List[Fp] = list(root_node) # Truncate layers to keep only 0 through depth/2 - 1. - truncated_layers = list(full_tree.layers.data[: (depth // 2)]) + truncated_layers: List[HashTreeLayer] = [full_tree.layers[i] for i in range(depth // 2)] # Add a final layer containing just the root. - root_vector = HashDigestVector(data=root) root_layer = HashTreeLayer( - start_index=bottom_tree_index, nodes=HashDigestList(data=[root_vector]) + start_index=bottom_tree_index, nodes=HashDigestList(data=[HashDigestVector(data=root)]) ) truncated_layers.append(root_layer) @@ -382,13 +375,13 @@ def root(self) -> HashDigestVector: if len(self.layers) == 0: raise ValueError("Cannot get root of empty subtree.") - highest_layer = cast(HashTreeLayer, self.layers.data[-1]) - if len(highest_layer.nodes.data) == 0: + highest_layer = self.layers[-1] + if len(highest_layer.nodes) == 0: raise ValueError("Highest layer of subtree is empty.") # The root is the only node in the highest layer for proper subtrees. # For top trees and bottom trees, the highest layer should have exactly one node. - return cast(HashDigestVector, highest_layer.nodes[0]) + return highest_layer.nodes[0] def path(self, position: Uint64) -> HashTreeOpening: """ @@ -413,41 +406,39 @@ def path(self, position: Uint64) -> HashTreeOpening: if len(self.layers) == 0: raise ValueError("Cannot generate path for empty subtree.") - lowest_layer = cast(HashTreeLayer, self.layers.data[0]) - if position < lowest_layer.start_index: + first_layer = self.layers[0] + if position < first_layer.start_index: raise ValueError("Position is before the subtree's start index.") - if position >= lowest_layer.start_index + Uint64(len(lowest_layer.nodes)): + if position >= first_layer.start_index + Uint64(len(first_layer.nodes)): raise ValueError("Position is beyond the subtree's range.") # Build the co-path directly with SSZ types - siblings = HashDigestList(data=[]) + siblings: List[HashDigestVector] = [] current_position = position # Iterate through layers from lowest to highest, EXCLUDING the final root layer. # The root layer doesn't contribute a sibling to the authentication path. - # self.layers.data[:-1] gives all layers except the last (root) layer. - for layer_raw in self.layers.data[:-1]: - layer = cast(HashTreeLayer, layer_raw) + num_layers = len(self.layers) + for layer_idx in range(num_layers - 1): + layer = self.layers[layer_idx] # Determine the sibling's position by flipping the last bit. sibling_position = current_position ^ Uint64(1) - sibling_index = sibling_position - layer.start_index + sibling_index = int(sibling_position - layer.start_index) # Ensure the sibling exists in this layer - if sibling_index < Uint64(0) or sibling_index >= Uint64(len(layer.nodes)): + if sibling_index < 0 or sibling_index >= len(layer.nodes): raise ValueError( f"Sibling index {sibling_index} out of bounds for layer " f"with {len(layer.nodes)} nodes" ) - # Access the sibling directly from the SSZ list and add to path - siblings = siblings + [layer.nodes[int(sibling_index)]] + siblings.append(layer.nodes[sibling_index]) # Move to the parent's position for the next iteration. current_position = current_position // Uint64(2) - # Return the opening with SSZ-typed siblings - return HashTreeOpening(siblings=siblings) + return HashTreeOpening(siblings=HashDigestList(data=siblings)) def combined_path( @@ -518,7 +509,7 @@ def combined_path( # Verify that the provided bottom_tree actually corresponds to this position. # The bottom tree's lowest layer starts at bottom_tree_index * leafs_per_bottom_tree. expected_start = bottom_tree_index * Uint64(leafs_per_bottom_tree) - actual_start = cast(HashTreeLayer, bottom_tree.layers.data[0]).start_index + actual_start = bottom_tree.layers[0].start_index if actual_start != expected_start: raise ValueError( @@ -536,7 +527,98 @@ def combined_path( # Concatenate the two paths: bottom siblings first, then top siblings. # This creates a complete path from leaf to global root. - # Since siblings are now HashDigestList, we need to concatenate their data. - combined_siblings_data = list(bottom_path.siblings.data) + list(top_path.siblings.data) + combined_siblings: List[HashDigestVector] = [ + bottom_path.siblings[i] for i in range(len(bottom_path.siblings)) + ] + [top_path.siblings[i] for i in range(len(top_path.siblings))] + + return HashTreeOpening(siblings=HashDigestList(data=combined_siblings)) + + +def verify_path( + hasher: "TweakHasher", + parameter: Parameter, + root: HashDigestVector, + position: Uint64, + leaf_parts: List[List[Fp]], + opening: HashTreeOpening, +) -> bool: + """ + Verifies a Merkle authentication path against a known, trusted root. + + This function is the final check in signature verification. It proves that the + one-time public key used for the signature (represented by `leaf_parts`) is a + legitimate member of the set committed to by the Merkle `root`. + + ### Verification Algorithm + + 1. **Leaf Computation**: The process begins at the bottom. The verifier first + hashes the `leaf_parts` to compute the actual leaf digest. This becomes the + starting `current_node` for the climb up the tree. - return HashTreeOpening(siblings=HashDigestList(data=combined_siblings_data)) + 2. **Bottom-Up Reconstruction**: The verifier iterates through the `opening.siblings` + path. At each `level`, it takes the `current_node` and the `sibling_node` + from the path. + + 3. **Parent Calculation**: It determines if the `current_node` is a left or + right child based on its `position`. The two nodes are placed in the + correct `(left, right)` order and hashed (with the correct `TreeTweak`) + to compute the parent. This parent becomes the `current_node` for the + next level. + + 4. **Final Comparison**: After all siblings are used, the final `current_node` + is the candidate root. The path is valid if and only if it matches the trusted `root`. + + Args: + hasher: The tweakable hash instance for computing parent nodes. + parameter: The public parameter `P` for the hash function. + root: The known, trusted Merkle root from the public key. + position: The absolute index of the leaf being verified. + leaf_parts: The list of digests that constitute the original leaf. + opening: The `HashTreeOpening` object containing the sibling path. + + Returns: + `True` if the path is valid and reconstructs the root, `False` otherwise. + + Raises: + ValueError: If the tree depth exceeds 32 or position doesn't match path length. + """ + # Compute the depth + depth = len(opening.siblings) + # Compute the number of leafs in the tree + num_leafs = 2**depth + # Check that the tree depth is at most 32. + if len(opening.siblings) > 32: + raise ValueError("Tree depth must be at most 32.") + # Check that the position and path length match. + if int(position) >= num_leafs: + raise ValueError("Position and path length do not match.") + + # The first step is to hash the constituent parts of the leaf to get + # the actual node at layer 0 of the tree. + leaf_tweak = TreeTweak(level=0, index=int(position)) + current_node = hasher.apply(parameter, leaf_tweak, leaf_parts) + + # Iterate up the tree, hashing the current node with its sibling from + # the path at each level. + current_position = int(position) + for level in range(len(opening.siblings)): + sibling_node: List[Fp] = list(opening.siblings[level]) + + # Determine if the current node is a left or right child. + if current_position % 2 == 0: + # Current node is a left child; sibling is on the right. + children = [current_node, sibling_node] + else: + # Current node is a right child; sibling is on the left. + children = [sibling_node, current_node] + + # Move up to the parent's position for the next iteration. + current_position //= 2 + # Create the tweak for the parent's level and position. + parent_tweak = TreeTweak(level=level + 1, index=current_position) + # Hash the children to compute the parent node. + current_node = hasher.apply(parameter, parent_tweak, children) + + # After iterating through the entire path, the final computed node + # should be the root of the tree. + return current_node == list(root) diff --git a/src/lean_spec/subspecs/xmss/types.py b/src/lean_spec/subspecs/xmss/types.py index 268e9cdf..31f82541 100644 --- a/src/lean_spec/subspecs/xmss/types.py +++ b/src/lean_spec/subspecs/xmss/types.py @@ -74,6 +74,10 @@ class HashDigestList(SSZList): ELEMENT_TYPE = HashDigestVector LIMIT = NODE_LIST_LIMIT + def __getitem__(self, index: int) -> HashDigestVector: + """Access a hash digest by index with proper typing.""" + return self.data[index] # type: ignore[return-value] + class Parameter(SSZVector): """ @@ -164,3 +168,7 @@ class HashTreeLayers(SSZList): ELEMENT_TYPE = HashTreeLayer LIMIT = LAYERS_LIMIT + + def __getitem__(self, index: int) -> HashTreeLayer: + """Access a layer by index with proper typing.""" + return self.data[index] # type: ignore[return-value] diff --git a/src/lean_spec/subspecs/xmss/utils.py b/src/lean_spec/subspecs/xmss/utils.py index d1add5ee..744048e9 100644 --- a/src/lean_spec/subspecs/xmss/utils.py +++ b/src/lean_spec/subspecs/xmss/utils.py @@ -9,7 +9,6 @@ from .types import HashDigestList, HashDigestVector, HashTreeLayer, Parameter, PRFKey if TYPE_CHECKING: - from .merkle_tree import MerkleTree from .prf import Prf from .subtree import HashSubTree from .tweak_hash import TweakHasher @@ -164,7 +163,7 @@ def expand_activation_time( def bottom_tree_from_prf_key( prf: "Prf", hasher: "TweakHasher", - merkle_tree: "MerkleTree", + rand: Rand, config: XmssConfig, prf_key: PRFKey, bottom_tree_index: Uint64, @@ -193,7 +192,7 @@ def bottom_tree_from_prf_key( Args: prf: The PRF instance for key derivation. hasher: The tweakable hash instance. - merkle_tree: The Merkle tree instance for tree construction. + rand: Random generator for padding values. config: The XMSS configuration. prf_key: The master PRF secret key. bottom_tree_index: The index of the bottom tree to generate (0, 1, 2, ...). @@ -243,7 +242,7 @@ def bottom_tree_from_prf_key( return HashSubTree.new_bottom_tree( hasher=hasher, - rand=merkle_tree.rand, + rand=rand, depth=config.LOG_LIFETIME, bottom_tree_index=bottom_tree_index, parameter=parameter, diff --git a/tests/lean_spec/subspecs/xmss/test_merkle_tree.py b/tests/lean_spec/subspecs/xmss/test_merkle_tree.py index 988ba39d..282e7496 100644 --- a/tests/lean_spec/subspecs/xmss/test_merkle_tree.py +++ b/tests/lean_spec/subspecs/xmss/test_merkle_tree.py @@ -3,18 +3,19 @@ import pytest from lean_spec.subspecs.koalabear import Fp -from lean_spec.subspecs.xmss.merkle_tree import ( - PROD_MERKLE_TREE, - MerkleTree, -) +from lean_spec.subspecs.xmss.rand import PROD_RAND, Rand +from lean_spec.subspecs.xmss.subtree import HashSubTree, verify_path from lean_spec.subspecs.xmss.tweak_hash import ( + PROD_TWEAK_HASHER, TreeTweak, + TweakHasher, ) from lean_spec.types import Uint64 def _run_commit_open_verify_roundtrip( - merkle_tree: MerkleTree, + hasher: TweakHasher, + rand: Rand, num_leaves: int, depth: int, start_index: int, @@ -31,19 +32,22 @@ def _run_commit_open_verify_roundtrip( 5. Verify that each path is valid for its corresponding leaf and root. Args: + hasher: The tweakable hash instance for computing parent nodes. + rand: Random generator for padding values. num_leaves: The number of active leaves in the tree. + depth: The total depth of the Merkle tree. start_index: The starting index of the first active leaf. leaf_parts_len: The number of digests that constitute a single leaf. """ # SETUP: Generate a random parameter and the raw leaf data. - parameter = merkle_tree.rand.parameter() + parameter = rand.parameter() leaves: list[list[list[Fp]]] = [ - [merkle_tree.rand.domain() for _ in range(leaf_parts_len)] for _ in range(num_leaves) + [rand.domain() for _ in range(leaf_parts_len)] for _ in range(num_leaves) ] # HASH LEAVES: Compute the layer 0 nodes by hashing the leaf parts. leaf_hashes: list[list[Fp]] = [ - merkle_tree.hasher.apply( + hasher.apply( parameter, TreeTweak(level=0, index=start_index + i), leaf_parts, @@ -52,14 +56,29 @@ def _run_commit_open_verify_roundtrip( ] # COMMIT: Build the Merkle tree from the leaf hashes. - tree = merkle_tree.build(depth, Uint64(start_index), parameter, leaf_hashes) - root = merkle_tree.root(tree) + tree = HashSubTree.new( + hasher=hasher, + rand=rand, + lowest_layer=0, + depth=depth, + start_index=Uint64(start_index), + parameter=parameter, + lowest_layer_nodes=leaf_hashes, + ) + root = tree.root() # OPEN & VERIFY: For each leaf, generate and verify its path. for i, leaf_parts in enumerate(leaves): position = Uint64(start_index + i) - opening = merkle_tree.path(tree, position) - is_valid = merkle_tree.verify_path(parameter, root, position, leaf_parts, opening) + opening = tree.path(position) + is_valid = verify_path( + hasher=hasher, + parameter=parameter, + root=root, + position=position, + leaf_parts=leaf_parts, + opening=opening, + ) assert is_valid, f"Verification failed for leaf at position {position}" @@ -87,5 +106,5 @@ def test_commit_open_verify_roundtrip( assert start_index + num_leaves <= (1 << depth) _run_commit_open_verify_roundtrip( - PROD_MERKLE_TREE, num_leaves, depth, start_index, leaf_parts_len + PROD_TWEAK_HASHER, PROD_RAND, num_leaves, depth, start_index, leaf_parts_len ) diff --git a/tests/lean_spec/subspecs/xmss/test_strict_types.py b/tests/lean_spec/subspecs/xmss/test_strict_types.py index d3013ad8..cbd5f7cc 100644 --- a/tests/lean_spec/subspecs/xmss/test_strict_types.py +++ b/tests/lean_spec/subspecs/xmss/test_strict_types.py @@ -10,7 +10,6 @@ from lean_spec.subspecs.xmss.constants import PROD_CONFIG, TEST_CONFIG, XmssConfig from lean_spec.subspecs.xmss.interface import GeneralizedXmssScheme -from lean_spec.subspecs.xmss.merkle_tree import PROD_MERKLE_TREE, MerkleTree from lean_spec.subspecs.xmss.message_hash import PROD_MESSAGE_HASHER, MessageHasher from lean_spec.subspecs.xmss.poseidon import PROD_POSEIDON, PoseidonXmss from lean_spec.subspecs.xmss.prf import PROD_PRF, Prf @@ -145,83 +144,6 @@ def test_tweak_hasher_frozen(self) -> None: PROD_TWEAK_HASHER.config = TEST_CONFIG -class TestMerkleTreeStrictTypes: - """Tests for MerkleTree strict type checking.""" - - def test_merkle_tree_accepts_exact_types(self) -> None: - """MerkleTree initialization succeeds with exact types.""" - tree = MerkleTree(config=PROD_CONFIG, hasher=PROD_TWEAK_HASHER, rand=PROD_RAND) - assert tree.config == PROD_CONFIG - - def test_merkle_tree_rejects_subclass_config(self) -> None: - """MerkleTree rejects XmssConfig subclass.""" - - class CustomConfig(XmssConfig): - pass - - custom_config = XmssConfig.__new__(CustomConfig) - custom_config.__dict__.update(PROD_CONFIG.__dict__) - - with pytest.raises(TypeError, match="config must be exactly XmssConfig"): - MerkleTree(config=custom_config, hasher=PROD_TWEAK_HASHER, rand=PROD_RAND) - - def test_merkle_tree_rejects_subclass_hasher(self) -> None: - """MerkleTree rejects TweakHasher subclass.""" - - class CustomHasher(TweakHasher): - pass - - custom_hasher = TweakHasher.__new__(CustomHasher) - custom_hasher.__dict__.update(PROD_TWEAK_HASHER.__dict__) - - with pytest.raises(TypeError, match="hasher must be exactly TweakHasher"): - MerkleTree(config=PROD_CONFIG, hasher=custom_hasher, rand=PROD_RAND) - - def test_merkle_tree_rejects_subclass_rand(self) -> None: - """MerkleTree rejects Rand subclass.""" - - class CustomRand(Rand): - pass - - custom_rand = Rand.__new__(CustomRand) - custom_rand.__dict__.update(PROD_RAND.__dict__) - - with pytest.raises(TypeError, match="rand must be exactly Rand"): - MerkleTree(config=PROD_CONFIG, hasher=PROD_TWEAK_HASHER, rand=custom_rand) - - def test_merkle_tree_rejects_wrong_type_config(self) -> None: - """MerkleTree rejects completely wrong type for config.""" - - class RandomClass: - pass - - with pytest.raises((TypeError, ValidationError)): - MerkleTree(config=RandomClass(), hasher=PROD_TWEAK_HASHER, rand=PROD_RAND) - - def test_merkle_tree_rejects_wrong_type_hasher(self) -> None: - """MerkleTree rejects completely wrong type for hasher.""" - - class RandomClass: - pass - - with pytest.raises((TypeError, ValidationError)): - MerkleTree(config=PROD_CONFIG, hasher=RandomClass(), rand=PROD_RAND) - - def test_merkle_tree_rejects_wrong_type_rand(self) -> None: - """MerkleTree rejects completely wrong type for rand.""" - - class RandomClass: - pass - - with pytest.raises((TypeError, ValidationError)): - MerkleTree(config=PROD_CONFIG, hasher=PROD_TWEAK_HASHER, rand=RandomClass()) - - def test_merkle_tree_frozen(self) -> None: - """MerkleTree is immutable (frozen).""" - with pytest.raises(ValidationError): - PROD_MERKLE_TREE.config = TEST_CONFIG - - class TestTargetSumEncoderStrictTypes: """Tests for TargetSumEncoder strict type checking.""" @@ -287,7 +209,6 @@ def test_scheme_accepts_exact_types(self) -> None: config=PROD_CONFIG, prf=PROD_PRF, hasher=PROD_TWEAK_HASHER, - merkle_tree=PROD_MERKLE_TREE, encoder=PROD_TARGET_SUM_ENCODER, rand=PROD_RAND, ) @@ -307,7 +228,6 @@ class CustomConfig(XmssConfig): config=custom_config, prf=PROD_PRF, hasher=PROD_TWEAK_HASHER, - merkle_tree=PROD_MERKLE_TREE, encoder=PROD_TARGET_SUM_ENCODER, rand=PROD_RAND, ) @@ -326,7 +246,6 @@ class CustomPrf(Prf): config=PROD_CONFIG, prf=custom_prf, hasher=PROD_TWEAK_HASHER, - merkle_tree=PROD_MERKLE_TREE, encoder=PROD_TARGET_SUM_ENCODER, rand=PROD_RAND, ) @@ -345,26 +264,6 @@ class CustomHasher(TweakHasher): config=PROD_CONFIG, prf=PROD_PRF, hasher=custom_hasher, - merkle_tree=PROD_MERKLE_TREE, - encoder=PROD_TARGET_SUM_ENCODER, - rand=PROD_RAND, - ) - - def test_scheme_rejects_subclass_merkle_tree(self) -> None: - """GeneralizedXmssScheme rejects MerkleTree subclass.""" - - class CustomMerkleTree(MerkleTree): - pass - - custom_tree = MerkleTree.__new__(CustomMerkleTree) - custom_tree.__dict__.update(PROD_MERKLE_TREE.__dict__) - - with pytest.raises(TypeError, match="merkle_tree must be exactly MerkleTree"): - GeneralizedXmssScheme( - config=PROD_CONFIG, - prf=PROD_PRF, - hasher=PROD_TWEAK_HASHER, - merkle_tree=custom_tree, encoder=PROD_TARGET_SUM_ENCODER, rand=PROD_RAND, ) @@ -383,7 +282,6 @@ class CustomEncoder(TargetSumEncoder): config=PROD_CONFIG, prf=PROD_PRF, hasher=PROD_TWEAK_HASHER, - merkle_tree=PROD_MERKLE_TREE, encoder=custom_encoder, rand=PROD_RAND, ) @@ -402,7 +300,6 @@ class CustomRand(Rand): config=PROD_CONFIG, prf=PROD_PRF, hasher=PROD_TWEAK_HASHER, - merkle_tree=PROD_MERKLE_TREE, encoder=PROD_TARGET_SUM_ENCODER, rand=custom_rand, ) @@ -414,7 +311,6 @@ def test_scheme_rejects_extra_fields(self) -> None: config=PROD_CONFIG, prf=PROD_PRF, hasher=PROD_TWEAK_HASHER, - merkle_tree=PROD_MERKLE_TREE, encoder=PROD_TARGET_SUM_ENCODER, rand=PROD_RAND, extra_field="should_fail", diff --git a/tests/lean_spec/subspecs/xmss/test_utils.py b/tests/lean_spec/subspecs/xmss/test_utils.py index 679eff09..42cbc879 100644 --- a/tests/lean_spec/subspecs/xmss/test_utils.py +++ b/tests/lean_spec/subspecs/xmss/test_utils.py @@ -7,8 +7,8 @@ from lean_spec.subspecs.koalabear.field import Fp, P from lean_spec.subspecs.xmss.constants import TEST_CONFIG -from lean_spec.subspecs.xmss.merkle_tree import TEST_MERKLE_TREE from lean_spec.subspecs.xmss.prf import TEST_PRF +from lean_spec.subspecs.xmss.rand import TEST_RAND from lean_spec.subspecs.xmss.tweak_hash import TEST_TWEAK_HASHER from lean_spec.subspecs.xmss.types import HashTreeLayer, Parameter from lean_spec.subspecs.xmss.utils import ( @@ -125,7 +125,7 @@ def test_bottom_tree_from_prf_key() -> None: bottom_tree = bottom_tree_from_prf_key( prf=TEST_PRF, hasher=TEST_TWEAK_HASHER, - merkle_tree=TEST_MERKLE_TREE, + rand=TEST_RAND, config=config, prf_key=prf_key, bottom_tree_index=Uint64(0), @@ -159,7 +159,7 @@ def test_bottom_tree_from_prf_key_deterministic() -> None: tree1 = bottom_tree_from_prf_key( prf=TEST_PRF, hasher=TEST_TWEAK_HASHER, - merkle_tree=TEST_MERKLE_TREE, + rand=TEST_RAND, config=config, prf_key=prf_key, bottom_tree_index=Uint64(0), @@ -169,7 +169,7 @@ def test_bottom_tree_from_prf_key_deterministic() -> None: tree2 = bottom_tree_from_prf_key( prf=TEST_PRF, hasher=TEST_TWEAK_HASHER, - merkle_tree=TEST_MERKLE_TREE, + rand=TEST_RAND, config=config, prf_key=prf_key, bottom_tree_index=Uint64(0), @@ -195,7 +195,7 @@ def test_bottom_tree_from_prf_key_different_indices() -> None: tree0 = bottom_tree_from_prf_key( prf=TEST_PRF, hasher=TEST_TWEAK_HASHER, - merkle_tree=TEST_MERKLE_TREE, + rand=TEST_RAND, config=config, prf_key=prf_key, bottom_tree_index=Uint64(0), @@ -205,7 +205,7 @@ def test_bottom_tree_from_prf_key_different_indices() -> None: tree1 = bottom_tree_from_prf_key( prf=TEST_PRF, hasher=TEST_TWEAK_HASHER, - merkle_tree=TEST_MERKLE_TREE, + rand=TEST_RAND, config=config, prf_key=prf_key, bottom_tree_index=Uint64(1),