diff --git a/src/lean_spec/subspecs/xmss/interface.py b/src/lean_spec/subspecs/xmss/interface.py
index db6de779..d13a058d 100644
--- a/src/lean_spec/subspecs/xmss/interface.py
+++ b/src/lean_spec/subspecs/xmss/interface.py
@@ -19,7 +19,6 @@
 )
 from lean_spec.types import StrictBaseModel, Uint64
 
-from ..koalabear import Fp
 from .constants import (
     PROD_CONFIG,
     TEST_CONFIG,
@@ -330,7 +329,7 @@ def sign(self, sk: SecretKey, epoch: Uint64, message: bytes) -> Signature:
             raise RuntimeError("Encoding is broken: returned too many or too few chunks.")
 
         # Compute the one-time signature hashes based on the codeword.
-        ots_hashes: List[List[Fp]] = []
+        ots_hashes: List[HashDigestVector] = []
         for chain_index, steps in enumerate(codeword):
             # Derive the secret start of the current chain using the master PRF key.
             start_digest = self.prf.apply(sk.prf_key, epoch, Uint64(chain_index))
@@ -380,11 +379,9 @@ def sign(self, sk: SecretKey, epoch: Uint64, message: bytes) -> Signature:
         # - The OTS,
         # - The Merkle path,
         # - The randomness `rho` needed for verification.
-        # Wrap ots_hashes in SSZ types
-        from .types import HashDigestList, HashDigestVector
+        from .types import HashDigestList
 
-        ssz_hashes = [HashDigestVector(data=hash_digest) for hash_digest in ots_hashes]
-        return Signature(path=path, rho=rho, hashes=HashDigestList(data=ssz_hashes))
+        return Signature(path=path, rho=rho, hashes=HashDigestList(data=ots_hashes))
 
     def verify(self, pk: PublicKey, epoch: Uint64, message: bytes, sig: Signature) -> bool:
         r"""
@@ -441,10 +438,10 @@ def verify(self, pk: PublicKey, epoch: Uint64, message: bytes, sig: Signature) -
             return False
 
         # Reconstruct the one-time public key (the list of chain endpoints).
-        chain_ends: List[List[Fp]] = []
+        chain_ends: List[HashDigestVector] = []
         for chain_index, xi in enumerate(codeword):
             # The signature provides `start_digest`, which is the hash value after `xi` steps.
-            start_digest: List[Fp] = list(sig.hashes[chain_index])
+            start_digest = sig.hashes[chain_index]
             # We must perform the remaining `BASE - 1 - xi` hashing steps
             # to compute the public endpoint of the chain.
             num_steps_remaining = config.BASE - 1 - xi
diff --git a/src/lean_spec/subspecs/xmss/prf.py b/src/lean_spec/subspecs/xmss/prf.py
index c7542c98..7a8c3fd5 100644
--- a/src/lean_spec/subspecs/xmss/prf.py
+++ b/src/lean_spec/subspecs/xmss/prf.py
@@ -11,7 +11,6 @@
 
 import hashlib
 import os
-from typing import List
 
 from pydantic import model_validator
 
@@ -24,7 +23,7 @@
     TEST_CONFIG,
     XmssConfig,
 )
-from .types import PRFKey, Randomness
+from .types import HashDigestVector, PRFKey, Randomness
 
 PRF_DOMAIN_SEP: bytes = bytes(
     [
@@ -106,7 +105,7 @@ def key_gen(self) -> PRFKey:
         """
         return PRFKey(os.urandom(PRF_KEY_LENGTH))
 
-    def apply(self, key: PRFKey, epoch: Uint64, chain_index: Uint64) -> List[Fp]:
+    def apply(self, key: PRFKey, epoch: Uint64, chain_index: Uint64) -> HashDigestVector:
         """
         Applies the PRF to derive the secret starting value for a single hash chain.
 
@@ -127,8 +126,7 @@ def apply(self, key: PRFKey, epoch: Uint64, chain_index: Uint64) -> List[Fp]:
             chain_index: The index of the hash chain within that epoch's OTS.
 
         Returns:
-            A list of field elements representing the secret start of a single
-            hash chain (i.e., a `HashDigest`).
+            A hash digest representing the secret start of a single hash chain.
         """
         # Retrieve the scheme's configuration parameters.
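         # (Aside, a hedged usage sketch with hypothetical values: calling
         # `prf.apply(key, Uint64(3), Uint64(7))` yields one HashDigestVector of
         # HASH_LEN_FE field elements, built below by slicing the raw PRF output
         # into 8-byte big-endian chunks, one chunk per field element.)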
         config = self.config
@@ -160,15 +158,17 @@ def apply(self, key: PRFKey, epoch: Uint64, chain_index: Uint64) -> List[Fp]:
         # - Slice an 8-byte (64-bit) chunk from the `prf_output_bytes`.
         # - Convert that chunk from a big-endian byte representation to an integer.
         # - Create a field element from the integer (the Fp constructor handles the modulo).
-        return [
-            Fp(
-                value=int.from_bytes(
-                    prf_output_bytes[i * PRF_BYTES_PER_FE : (i + 1) * PRF_BYTES_PER_FE],
-                    "big",
+        return HashDigestVector(
+            data=[
+                Fp(
+                    value=int.from_bytes(
+                        prf_output_bytes[i * PRF_BYTES_PER_FE : (i + 1) * PRF_BYTES_PER_FE],
+                        "big",
+                    )
                 )
-            )
-            for i in range(config.HASH_LEN_FE)
-        ]
+                for i in range(config.HASH_LEN_FE)
+            ]
+        )
 
     def get_randomness(
         self, key: PRFKey, epoch: Uint64, message: bytes, counter: Uint64
diff --git a/src/lean_spec/subspecs/xmss/rand.py b/src/lean_spec/subspecs/xmss/rand.py
index 4854eb83..bf0e0a4d 100644
--- a/src/lean_spec/subspecs/xmss/rand.py
+++ b/src/lean_spec/subspecs/xmss/rand.py
@@ -9,7 +9,7 @@
 
 from ..koalabear import Fp, P
 from .constants import PROD_CONFIG, TEST_CONFIG, XmssConfig
-from .types import Parameter, Randomness
+from .types import HashDigestVector, Parameter, Randomness
 
 
 class Rand(StrictBaseModel):
@@ -34,9 +34,9 @@ def parameter(self) -> Parameter:
         """Generates a random public parameter."""
         return Parameter(data=self.field_elements(self.config.PARAMETER_LEN))
 
-    def domain(self) -> List[Fp]:
+    def domain(self) -> HashDigestVector:
         """Generates a random hash digest."""
-        return self.field_elements(self.config.HASH_LEN_FE)
+        return HashDigestVector(data=self.field_elements(self.config.HASH_LEN_FE))
 
     def rho(self) -> Randomness:
         """Generates randomness `rho` for message encoding."""
diff --git a/src/lean_spec/subspecs/xmss/subtree.py b/src/lean_spec/subspecs/xmss/subtree.py
index 1811e234..451c5cc0 100644
--- a/src/lean_spec/subspecs/xmss/subtree.py
+++ b/src/lean_spec/subspecs/xmss/subtree.py
@@ -12,7 +12,6 @@
 from lean_spec.types import Uint64
 from lean_spec.types.container import Container
 
-from ..koalabear import Fp
 from .tweak_hash import TreeTweak
 from .types import (
     HashDigestList,
@@ -105,7 +104,7 @@ def new(
         depth: int,
         start_index: Uint64,
         parameter: Parameter,
-        lowest_layer_nodes: List[List[Fp]],
+        lowest_layer_nodes: List[HashDigestVector],
     ) -> HashSubTree:
         """
         Builds a new sparse Merkle subtree starting from a specified layer.
@@ -143,48 +142,37 @@ def new(
         Returns:
             A `HashSubTree` containing all computed layers from `lowest_layer` to root.
         """
-        # Validate that we have enough space in the tree for these nodes.
-        # At layer `lowest_layer`, there are 2^(depth - lowest_layer) possible positions.
-        max_index_at_layer = 1 << (depth - lowest_layer)
-        if start_index + Uint64(len(lowest_layer_nodes)) > Uint64(max_index_at_layer):
+        # Validate: nodes must fit in available positions at this layer.
+        max_positions = 1 << (depth - lowest_layer)
+        if int(start_index) + len(lowest_layer_nodes) > max_positions:
             raise ValueError(
-                f"Not enough space at layer {lowest_layer}: "
-                f"start_index={start_index}, nodes={len(lowest_layer_nodes)}, "
-                f"max={max_index_at_layer}"
+                f"Overflow at layer {lowest_layer}: "
+                f"start={start_index}, count={len(lowest_layer_nodes)}, max={max_positions}"
             )
 
-        # Start with the lowest layer nodes and apply initial padding.
+        # Initialize with padded input layer.
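         # (Aside, an illustrative trace of the padding contract with
         # hypothetical values: 3 nodes starting at index 5 come back as 4 nodes
         # starting at index 4, because `get_padded_layer` prepends a random node
         # when the start index is odd, so the pairwise hashing below never sees
         # an orphan node.)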
         layers: List[HashTreeLayer] = []
-        current_layer = get_padded_layer(rand, lowest_layer_nodes, start_index)
-        layers.append(current_layer)
+        current = get_padded_layer(rand, lowest_layer_nodes, start_index)
+        layers.append(current)
 
-        # Build the tree layer by layer from lowest_layer up to the root.
+        # Build upward: hash pairs of children to create parents.
         for level in range(lowest_layer, depth):
-            parents: List[List[Fp]] = []
-
-            # Group current layer's nodes into pairs of (left, right) siblings.
-            # The padding guarantees this works perfectly without orphan nodes.
-            num_nodes = len(current_layer.nodes)
-            for i in range(0, num_nodes, 2):
-                left_node = current_layer.nodes[i]
-                right_node = current_layer.nodes[i + 1]
-
-                # Calculate the position of the parent node in the next level up.
-                parent_index = (current_layer.start_index // Uint64(2)) + Uint64(i // 2)
-                # Create the tweak for hashing these two children.
-                tweak = TreeTweak(level=level + 1, index=parent_index)
-                # Hash the left and right children to get their parent.
-                left_data = list(left_node)
-                right_data = list(right_node)
-                parent_node = hasher.apply(parameter, tweak, [left_data, right_data])
-                parents.append(parent_node)
-
-            # Pad the new list of parents to prepare for the next iteration.
-            new_start_index = current_layer.start_index // Uint64(2)
-            current_layer = get_padded_layer(rand, parents, new_start_index)
-            layers.append(current_layer)
-
-        # Return the completed subtree.
+            parent_start = current.start_index // Uint64(2)
+
+            # Hash each pair of siblings into their parent.
+            parents = [
+                hasher.apply(
+                    parameter,
+                    TreeTweak(level=level + 1, index=int(parent_start) + i),
+                    [current.nodes[2 * i], current.nodes[2 * i + 1]],
+                )
+                for i in range(len(current.nodes) // 2)
+            ]
+
+            # Pad and store the new layer.
+            current = get_padded_layer(rand, parents, parent_start)
+            layers.append(current)
+
         return cls(
             depth=Uint64(depth),
             lowest_layer=Uint64(lowest_layer),
@@ -234,26 +222,17 @@ def new_top_tree(
             ValueError: If depth is odd (top-bottom split requires even depth).
         """
         if depth % 2 != 0:
-            raise ValueError(
-                f"Top-bottom tree split requires even depth, got {depth}. "
-                f"The top tree must start at depth/2, which must be an integer."
-            )
-
-        # The top tree starts at the middle layer.
-        lowest_layer = depth // 2
+            raise ValueError(f"Depth must be even for top-bottom split, got {depth}.")
 
-        # Convert HashDigestVector roots to List[Fp] for building
-        roots_as_lists: List[List[Fp]] = [list(root) for root in bottom_tree_roots]
-
-        # Build the top tree using the bottom tree roots as the lowest layer.
+        # Build from middle layer using bottom tree roots as leaves.
         return cls.new(
             hasher=hasher,
             rand=rand,
-            lowest_layer=lowest_layer,
+            lowest_layer=depth // 2,
             depth=depth,
             start_index=start_bottom_tree_index,
             parameter=parameter,
-            lowest_layer_nodes=roots_as_lists,
+            lowest_layer_nodes=bottom_tree_roots,
         )
 
     @classmethod
@@ -264,7 +243,7 @@ def new_bottom_tree(
         depth: int,
         bottom_tree_index: Uint64,
         parameter: Parameter,
-        leaves: List[List[Fp]],
+        leaves: List[HashDigestVector],
     ) -> HashSubTree:
         """
         Constructs a single bottom tree from leaf hashes.
@@ -303,60 +282,40 @@ def new_bottom_tree(
             ValueError: If depth is odd or leaves count doesn't match `sqrt(LIFETIME)`.
         """
         if depth % 2 != 0:
-            raise ValueError(
-                f"Top-bottom tree split requires even depth, got {depth}. "
-                f"Bottom trees must span exactly depth/2 layers."
-            )
+            raise ValueError(f"Depth must be even for top-bottom split, got {depth}.")
 
-        leafs_per_bottom_tree = 1 << (depth // 2)
-        if len(leaves) != leafs_per_bottom_tree:
+        # Each bottom tree has exactly sqrt(LIFETIME) leaves.
+        leafs_per_tree = 1 << (depth // 2)
+        if len(leaves) != leafs_per_tree:
             raise ValueError(
-                f"Bottom tree must have exactly {leafs_per_bottom_tree} leaves "
-                f"(sqrt(LIFETIME) for depth={depth}), got {len(leaves)}"
+                f"Expected {leafs_per_tree} leaves for depth={depth}, got {len(leaves)}."
             )
 
-        # Calculate the starting index for this bottom tree's leaves.
-        start_index = bottom_tree_index * Uint64(leafs_per_bottom_tree)
-
-        # Build a full subtree from layer 0 using the leaves.
+        # Build full tree from leaves.
         full_tree = cls.new(
            hasher=hasher,
            rand=rand,
            lowest_layer=0,
            depth=depth,
-            start_index=start_index,
+            start_index=bottom_tree_index * Uint64(leafs_per_tree),
            parameter=parameter,
            lowest_layer_nodes=leaves,
         )
 
-        # Truncate to remove upper layers that would be incompatible with the top tree.
-        # We keep layers 0 through depth/2 (inclusive).
-        #
-        # Extract the root at layer depth/2. The root's index in that layer is either
-        # the bottom_tree_index (if it's the left child of its parent in the top tree)
-        # or bottom_tree_index (if it's the right child). Since we're at layer depth/2,
-        # the position is simply bottom_tree_index.
-        middle_layer = full_tree.layers[depth // 2]
-
-        # The root is at position (start_index >> (depth // 2)) = bottom_tree_index
-        # within the middle layer. We need to find it in the stored nodes.
-        root_position_in_layer = bottom_tree_index - middle_layer.start_index
-        root_node = middle_layer.nodes[int(root_position_in_layer)]
-        root: List[Fp] = list(root_node)
-
-        # Truncate layers to keep only 0 through depth/2 - 1.
-        truncated_layers: List[HashTreeLayer] = [full_tree.layers[i] for i in range(depth // 2)]
-
-        # Add a final layer containing just the root.
+        # Extract root from middle layer.
+        middle = full_tree.layers[depth // 2]
+        root_idx = int(bottom_tree_index - middle.start_index)
         root_layer = HashTreeLayer(
-            start_index=bottom_tree_index, nodes=HashDigestList(data=[HashDigestVector(data=root)])
+            start_index=bottom_tree_index,
+            nodes=HashDigestList(data=[middle.nodes[root_idx]]),
         )
-        truncated_layers.append(root_layer)
 
+        # Keep bottom half + single root node.
+        truncated = [full_tree.layers[i] for i in range(depth // 2)]
         return cls(
             depth=Uint64(depth),
             lowest_layer=Uint64(0),
-            layers=HashTreeLayers(data=truncated_layers),
+            layers=HashTreeLayers(data=truncated + [root_layer]),
         )
 
     def root(self) -> HashDigestVector:
         """
@@ -372,16 +331,11 @@ def root(self) -> HashDigestVector:
         Raises:
             ValueError: If the subtree has no layers or the highest layer is empty.
         """
-        if len(self.layers) == 0:
-            raise ValueError("Cannot get root of empty subtree.")
-
-        highest_layer = self.layers[-1]
-        if len(highest_layer.nodes) == 0:
-            raise ValueError("Highest layer of subtree is empty.")
-
-        # The root is the only node in the highest layer for proper subtrees.
-        # For top trees and bottom trees, the highest layer should have exactly one node.
-        return highest_layer.nodes[0]
+        if not self.layers:
+            raise ValueError("Empty subtree has no root.")
+        if not self.layers[-1].nodes:
+            raise ValueError("Top layer is empty.")
+        return self.layers[-1].nodes[0]
 
     def path(self, position: Uint64) -> HashTreeOpening:
         """
@@ -403,46 +357,37 @@ def path(self, position: Uint64) -> HashTreeOpening:
         Raises:
             ValueError: If the subtree is empty or the position is out of bounds.
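 
         Example (illustrative, hypothetical indices): for a leaf at position 6,
         the walk collects node 7 (6 ^ 1), then node 2 (sibling of parent 3),
         then node 0 (sibling of parent 1), one sibling per layer below the root.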
""" - if len(self.layers) == 0: - raise ValueError("Cannot generate path for empty subtree.") + if not self.layers: + raise ValueError("Empty subtree.") - first_layer = self.layers[0] - if position < first_layer.start_index: - raise ValueError("Position is before the subtree's start index.") + # Check bounds. + first = self.layers[0] + if not (first.start_index <= position < first.start_index + Uint64(len(first.nodes))): + raise ValueError(f"Position {position} out of bounds.") - if position >= first_layer.start_index + Uint64(len(first_layer.nodes)): - raise ValueError("Position is beyond the subtree's range.") - - # Build the co-path directly with SSZ types + # Collect sibling at each layer (except root). siblings: List[HashDigestVector] = [] - current_position = position + pos = position - # Iterate through layers from lowest to highest, EXCLUDING the final root layer. - # The root layer doesn't contribute a sibling to the authentication path. + # Iterate over all layers except the last (root). num_layers = len(self.layers) - for layer_idx in range(num_layers - 1): - layer = self.layers[layer_idx] - # Determine the sibling's position by flipping the last bit. - sibling_position = current_position ^ Uint64(1) - sibling_index = int(sibling_position - layer.start_index) - - # Ensure the sibling exists in this layer - if sibling_index < 0 or sibling_index >= len(layer.nodes): - raise ValueError( - f"Sibling index {sibling_index} out of bounds for layer " - f"with {len(layer.nodes)} nodes" - ) - - siblings.append(layer.nodes[sibling_index]) + for i in range(num_layers - 1): + layer = self.layers[i] + # Sibling index: flip last bit of position, adjust for layer offset. + sibling_idx = int((pos ^ Uint64(1)) - layer.start_index) + if not (0 <= sibling_idx < len(layer.nodes)): + raise ValueError(f"Sibling index {sibling_idx} out of bounds.") - # Move to the parent's position for the next iteration. - current_position = current_position // Uint64(2) + siblings.append(layer.nodes[sibling_idx]) + pos = pos // Uint64(2) # Move to parent position. return HashTreeOpening(siblings=HashDigestList(data=siblings)) def combined_path( - top_tree: HashSubTree, bottom_tree: HashSubTree, position: Uint64 + top_tree: HashSubTree, + bottom_tree: HashSubTree, + position: Uint64, ) -> HashTreeOpening: """ Generates a combined authentication path spanning top and bottom trees. @@ -480,58 +425,31 @@ def combined_path( Raises: ValueError: If trees have mismatched depths, odd depth, or position is - out of bounds for the bottom tree. + out of bounds for the bottom tree. """ - # Validate that both trees have the same depth. + # Validate matching depths. if top_tree.depth != bottom_tree.depth: - raise ValueError( - f"Top and bottom trees must have same depth: " - f"top={top_tree.depth}, bottom={bottom_tree.depth}" - ) - - depth = top_tree.depth - - # Validate even depth (required for top-bottom split). - if depth % Uint64(2) != Uint64(0): - raise ValueError( - f"Top-bottom tree traversal requires even depth, got {depth}. " - f"Cannot split tree into equal top and bottom halves." - ) + raise ValueError(f"Depth mismatch: top={top_tree.depth}, bottom={bottom_tree.depth}.") - # Calculate parameters for bottom trees. - leafs_per_bottom_tree = 1 << int(depth // Uint64(2)) + depth = int(top_tree.depth) + if depth % 2 != 0: + raise ValueError(f"Depth must be even, got {depth}.") - # Determine which bottom tree this position belongs to. 
-    #
-    # Bottom tree index = floor(position / sqrt(LIFETIME))
-    bottom_tree_index = position // Uint64(leafs_per_bottom_tree)
-
-    # Verify that the provided bottom_tree actually corresponds to this position.
-    # The bottom tree's lowest layer starts at bottom_tree_index * leafs_per_bottom_tree.
-    expected_start = bottom_tree_index * Uint64(leafs_per_bottom_tree)
-    actual_start = bottom_tree.layers[0].start_index
-
-    if actual_start != expected_start:
+    # Validate bottom tree matches position.
+    leafs_per_tree = Uint64(1 << (depth // 2))
+    expected_start = (position // leafs_per_tree) * leafs_per_tree
+    if bottom_tree.layers[0].start_index != expected_start:
         raise ValueError(
-            f"Bottom tree mismatch: position {position} belongs to "
-            f"bottom tree {bottom_tree_index} (should start at {expected_start}), "
-            f"but provided bottom tree starts at {actual_start}"
+            f"Wrong bottom tree: position {position} needs start {expected_start}, "
+            f"got {bottom_tree.layers[0].start_index}."
         )
 
-    # Get the authentication path within the bottom tree (from leaf to bottom tree root).
+    # Concatenate: bottom path + top path.
     bottom_path = bottom_tree.path(position)
+    top_path = top_tree.path(position // leafs_per_tree)
+    combined = bottom_path.siblings.data + top_path.siblings.data
 
-    # Get the authentication path within the top tree (from bottom tree root to global root).
-    # The bottom tree's root is at position `bottom_tree_index` in the top tree's lowest layer.
-    top_path = top_tree.path(Uint64(bottom_tree_index))
-
-    # Concatenate the two paths: bottom siblings first, then top siblings.
-    # This creates a complete path from leaf to global root.
-    combined_siblings: List[HashDigestVector] = [
-        bottom_path.siblings[i] for i in range(len(bottom_path.siblings))
-    ] + [top_path.siblings[i] for i in range(len(top_path.siblings))]
-
-    return HashTreeOpening(siblings=HashDigestList(data=combined_siblings))
+    return HashTreeOpening(siblings=HashDigestList(data=combined))
 
 
 def verify_path(
@@ -539,7 +457,7 @@ def verify_path(
     parameter: Parameter,
     root: HashDigestVector,
     position: Uint64,
-    leaf_parts: List[List[Fp]],
+    leaf_parts: List[HashDigestVector],
     opening: HashTreeOpening,
 ) -> bool:
     """
@@ -584,41 +502,33 @@ def verify_path(
     # Compute the depth
     depth = len(opening.siblings)
-    # Compute the number of leafs in the tree
-    num_leafs = 2**depth
-    # Check that the tree depth is at most 32.
-    if len(opening.siblings) > 32:
-        raise ValueError("Tree depth must be at most 32.")
-    # Check that the position and path length match.
-    if int(position) >= num_leafs:
-        raise ValueError("Position and path length do not match.")
-
-    # The first step is to hash the constituent parts of the leaf to get
-    # the actual node at layer 0 of the tree.
-    leaf_tweak = TreeTweak(level=0, index=int(position))
-    current_node = hasher.apply(parameter, leaf_tweak, leaf_parts)
-
-    # Iterate up the tree, hashing the current node with its sibling from
-    # the path at each level.
-    current_position = int(position)
-    for level in range(len(opening.siblings)):
-        sibling_node: List[Fp] = list(opening.siblings[level])
-
-        # Determine if the current node is a left or right child.
-        if current_position % 2 == 0:
-            # Current node is a left child; sibling is on the right.
-            children = [current_node, sibling_node]
+    if depth > 32:
+        raise ValueError("Depth exceeds maximum of 32.")
+    if int(position) >= (1 << depth):
+        raise ValueError("Position exceeds tree capacity.")
+
+    # Start: hash leaf parts to get leaf node.
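+    # (Hedged sketch of the walk that follows, hypothetical position: for
+    # position 5 and depth 3, `pos` is odd, even, odd on the way up, so the
+    # children are hashed as [sibling, current], [current, sibling],
+    # [sibling, current] before `pos` reaches the root at 0.)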
+    current = hasher.apply(
+        parameter,
+        TreeTweak(level=0, index=int(position)),
+        leaf_parts,
+    )
+    pos = int(position)
+
+    # Walk up: hash current with each sibling.
+    for level, sibling in enumerate(opening.siblings):
+        # Left child has even position, right child has odd.
+        if pos % 2 == 0:
+            left, right = current, sibling
         else:
-            # Current node is a right child; sibling is on the left.
-            children = [sibling_node, current_node]
-
-        # Move up to the parent's position for the next iteration.
-        current_position //= 2
-        # Create the tweak for the parent's level and position.
-        parent_tweak = TreeTweak(level=level + 1, index=current_position)
-        # Hash the children to compute the parent node.
-        current_node = hasher.apply(parameter, parent_tweak, children)
-
-    # After iterating through the entire path, the final computed node
-    # should be the root of the tree.
-    return current_node == list(root)
+            left, right = sibling, current
+
+        pos //= 2  # Parent position.
+        current = hasher.apply(
+            parameter,
+            TreeTweak(level=level + 1, index=pos),
+            [left, right],
+        )
+
+    # Valid if we reconstructed the expected root.
+    return current == root
diff --git a/src/lean_spec/subspecs/xmss/tweak_hash.py b/src/lean_spec/subspecs/xmss/tweak_hash.py
index 6fa9597f..b3b7ca8a 100644
--- a/src/lean_spec/subspecs/xmss/tweak_hash.py
+++ b/src/lean_spec/subspecs/xmss/tweak_hash.py
@@ -26,8 +26,7 @@
 
 from __future__ import annotations
 
-from itertools import chain
-from typing import List, Union, cast
+from typing import List, Union
 
 from pydantic import Field, model_validator
 
@@ -46,7 +45,7 @@
     TEST_POSEIDON,
     PoseidonXmss,
 )
-from .types import Parameter
+from .types import HashDigestVector, Parameter
 from .utils import int_to_base_p
 
@@ -150,8 +149,8 @@ def apply(
         self,
         parameter: Parameter,
         tweak: Union[TreeTweak, ChainTweak],
-        message_parts: List[List[Fp]],
-    ) -> List[Fp]:
+        message_parts: List[HashDigestVector],
+    ) -> HashDigestVector:
         """
         Applies the tweakable Poseidon2 hash function to a message.
 
@@ -189,28 +188,30 @@ def apply(
             # Case 1: Hashing a single digest (used in hash chains).
             #
             # We use the efficient width-16 compression mode.
-            input_vec = cast(List[Fp], list(parameter.data)) + encoded_tweak + message_parts[0]
-            return self.poseidon.compress(input_vec, 16, config.HASH_LEN_FE)
+            input_vec = parameter.elements + encoded_tweak + message_parts[0].elements
+            result = self.poseidon.compress(input_vec, 16, config.HASH_LEN_FE)
         elif len(message_parts) == 2:
             # Case 2: Hashing two digests (used for Merkle tree nodes).
             #
             # We use the slightly larger width-24 compression mode.
             input_vec = (
-                cast(List[Fp], list(parameter.data))
+                parameter.elements
                 + encoded_tweak
-                + message_parts[0]
-                + message_parts[1]
+                + message_parts[0].elements
+                + message_parts[1].elements
             )
-            return self.poseidon.compress(input_vec, 24, config.HASH_LEN_FE)
+            result = self.poseidon.compress(input_vec, 24, config.HASH_LEN_FE)
         else:
             # Case 3: Hashing many digests (used for the Merkle tree leaf).
             #
             # We use the robust sponge mode.
             # First, flatten the list of message parts into a single vector.
-            flattened_message = list(chain.from_iterable(message_parts))
-            input_vec = cast(List[Fp], list(parameter.data)) + encoded_tweak + flattened_message
+            flattened_message: List[Fp] = []
+            for part in message_parts:
+                flattened_message.extend(part.elements)
+            input_vec = parameter.elements + encoded_tweak + flattened_message
 
             # Create a domain separator for the sponge mode based on the input dimensions.
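            # (Recap of the three dispatch arms above: one digest uses width-16
            # compression for chain steps; two digests use width-24 compression
            # for Merkle nodes; anything else falls through to this sponge path
            # for leaves. Every arm prefixes `parameter.elements` and
            # `encoded_tweak` to the input, keeping the cases domain-separated.)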
            #
@@ -223,7 +224,9 @@ def apply(
             ]
             capacity_value = self.poseidon.safe_domain_separator(lengths, config.CAPACITY)
 
-            return self.poseidon.sponge(input_vec, capacity_value, config.HASH_LEN_FE, 24)
+            result = self.poseidon.sponge(input_vec, capacity_value, config.HASH_LEN_FE, 24)
+
+        return HashDigestVector(data=result)
 
     def hash_chain(
         self,
@@ -232,8 +235,8 @@ def hash_chain(
         chain_index: int,
         start_step: int,
         num_steps: int,
-        start_digest: List[Fp],
-    ) -> List[Fp]:
+        start_digest: HashDigestVector,
+    ) -> HashDigestVector:
         """
         Performs repeated hashing to traverse a WOTS+ hash chain.
 
diff --git a/src/lean_spec/subspecs/xmss/types.py b/src/lean_spec/subspecs/xmss/types.py
index 31f82541..1f4a9f2c 100644
--- a/src/lean_spec/subspecs/xmss/types.py
+++ b/src/lean_spec/subspecs/xmss/types.py
@@ -1,5 +1,7 @@
 """Base types for the XMSS signature scheme."""
 
+from typing import List
+
 from lean_spec.subspecs.koalabear import Fp
 
 from ...types import Uint64
@@ -61,6 +63,11 @@ class HashDigestVector(SSZVector):
     ELEMENT_TYPE = Fp
     LENGTH = HASH_DIGEST_LENGTH
 
+    @property
+    def elements(self) -> List[Fp]:
+        """Return the field elements as a typed list."""
+        return list(self.data)  # type: ignore[arg-type]
+
 
 class HashDigestList(SSZList):
     """
@@ -93,6 +100,11 @@ class Parameter(SSZVector):
     ELEMENT_TYPE = Fp
     LENGTH = PROD_CONFIG.PARAMETER_LEN
 
+    @property
+    def elements(self) -> List[Fp]:
+        """Return the field elements as a typed list."""
+        return list(self.data)  # type: ignore[arg-type]
+
 
 class Randomness(SSZVector):
     """
diff --git a/src/lean_spec/subspecs/xmss/utils.py b/src/lean_spec/subspecs/xmss/utils.py
index 744048e9..b0f2dfde 100644
--- a/src/lean_spec/subspecs/xmss/utils.py
+++ b/src/lean_spec/subspecs/xmss/utils.py
@@ -14,7 +14,9 @@
 from .tweak_hash import TweakHasher
 
 
-def get_padded_layer(rand: Rand, nodes: List[List[Fp]], start_index: Uint64) -> HashTreeLayer:
+def get_padded_layer(
+    rand: Rand, nodes: List[HashDigestVector], start_index: Uint64
+) -> HashTreeLayer:
     """
     Pads a layer of nodes with random hashes to simplify tree construction.
 
@@ -32,7 +34,7 @@ def get_padded_layer(
     Returns:
         A new `HashTreeLayer` with the necessary padding applied.
     """
-    nodes_with_padding: List[List[Fp]] = []
+    nodes_with_padding: List[HashDigestVector] = []
     end_index = start_index + Uint64(len(nodes)) - Uint64(1)
 
     # Prepend random padding if the layer starts at an odd index.
@@ -50,10 +52,9 @@ def get_padded_layer(
     if end_index % Uint64(2) == Uint64(0):
         nodes_with_padding.append(rand.domain())
 
-    # Convert to SSZ-friendly types: each digest becomes a HashDigestVector,
-    # and the list becomes a HashDigestList.
-    ssz_nodes = [HashDigestVector(data=node) for node in nodes_with_padding]
-    return HashTreeLayer(start_index=actual_start_index, nodes=HashDigestList(data=ssz_nodes))
+    return HashTreeLayer(
+        start_index=actual_start_index, nodes=HashDigestList(data=nodes_with_padding)
+    )
 
 
 def int_to_base_p(value: int, num_limbs: int) -> List[Fp]:
     """
@@ -211,11 +212,11 @@ def bottom_tree_from_prf_key(
     end_epoch = start_epoch + Uint64(leafs_per_bottom_tree)
 
     # Generate leaf hashes for all epochs in this bottom tree.
-    leaf_hashes: List[List[Fp]] = []
+    leaf_hashes: List[HashDigestVector] = []
     for epoch in range(int(start_epoch), int(end_epoch)):
         # For each epoch, compute the one-time public key (chain endpoints).
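        # (Sketch of what follows, names per this module: for each of the
        # config.DIMENSION chains, `prf.apply` derives the chain's secret start
        # and the hasher walks the full chain to its public endpoint; the
        # endpoints together are then hashed into that epoch's leaf digest.)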
-        chain_ends: List[List[Fp]] = []
+        chain_ends: List[HashDigestVector] = []
 
         for chain_index in range(config.DIMENSION):
             # Derive the secret start of the chain from the PRF key.
diff --git a/tests/lean_spec/subspecs/xmss/test_merkle_tree.py b/tests/lean_spec/subspecs/xmss/test_merkle_tree.py
index 282e7496..b1fa6934 100644
--- a/tests/lean_spec/subspecs/xmss/test_merkle_tree.py
+++ b/tests/lean_spec/subspecs/xmss/test_merkle_tree.py
@@ -2,7 +2,6 @@
 
 import pytest
 
-from lean_spec.subspecs.koalabear import Fp
 from lean_spec.subspecs.xmss.rand import PROD_RAND, Rand
 from lean_spec.subspecs.xmss.subtree import HashSubTree, verify_path
 from lean_spec.subspecs.xmss.tweak_hash import (
@@ -10,6 +9,7 @@
     TreeTweak,
     TweakHasher,
 )
+from lean_spec.subspecs.xmss.types import HashDigestVector
 from lean_spec.types import Uint64
 
@@ -41,12 +41,12 @@ def _run_commit_open_verify_roundtrip(
     """
     # SETUP: Generate a random parameter and the raw leaf data.
     parameter = rand.parameter()
-    leaves: list[list[list[Fp]]] = [
+    leaves: list[list[HashDigestVector]] = [
         [rand.domain() for _ in range(leaf_parts_len)] for _ in range(num_leaves)
     ]
 
     # HASH LEAVES: Compute the layer 0 nodes by hashing the leaf parts.
-    leaf_hashes: list[list[Fp]] = [
+    leaf_hashes: list[HashDigestVector] = [
         hasher.apply(
             parameter,
             TreeTweak(level=0, index=start_index + i),
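For orientation, a minimal commit/open/verify roundtrip against the refactored API, mirroring the test helper above. It is a sketch, not part of the patch: the `hasher` and `rand` arguments are assumed to be constructed as elsewhere in the test suite (their constructors are outside this diff), and the depth and leaf counts are hypothetical.

    from lean_spec.subspecs.xmss.rand import Rand
    from lean_spec.subspecs.xmss.subtree import HashSubTree, verify_path
    from lean_spec.subspecs.xmss.tweak_hash import TreeTweak, TweakHasher
    from lean_spec.types import Uint64

    def roundtrip(hasher: TweakHasher, rand: Rand) -> None:
        parameter = rand.parameter()
        # Eight leaves, each made of two random digest parts.
        leaf_parts = [[rand.domain(), rand.domain()] for _ in range(8)]
        # Layer-0 nodes: hash each leaf's parts under a level-0 tweak.
        leaf_hashes = [
            hasher.apply(parameter, TreeTweak(level=0, index=i), parts)
            for i, parts in enumerate(leaf_parts)
        ]
        # Commit: build a depth-3 subtree over all eight leaves.
        tree = HashSubTree.new(
            hasher=hasher,
            rand=rand,
            lowest_layer=0,
            depth=3,
            start_index=Uint64(0),
            parameter=parameter,
            lowest_layer_nodes=leaf_hashes,
        )
        # Open leaf 5, then verify its path against the root.
        opening = tree.path(Uint64(5))
        assert verify_path(hasher, parameter, tree.root(), Uint64(5), leaf_parts[5], opening)

Note that `leaf_hashes` is now a `list[HashDigestVector]`, matching the new `lowest_layer_nodes` annotation, so no `List[List[Fp]]` conversions remain anywhere on the path from leaf generation to verification.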