88
99from __future__ import annotations
1010
11- from typing import List , Tuple , cast
11+ from typing import List , Tuple
1212
1313from pydantic import model_validator
1414
2626 XmssConfig ,
2727)
2828from .containers import PublicKey , SecretKey , Signature
29- from .merkle_tree import (
30- PROD_MERKLE_TREE ,
31- TEST_MERKLE_TREE ,
32- MerkleTree ,
33- )
3429from .prf import PROD_PRF , TEST_PRF , Prf
3530from .rand import PROD_RAND , TEST_RAND , Rand
3631from .tweak_hash import (
@@ -60,9 +55,6 @@ class GeneralizedXmssScheme(StrictBaseModel):
6055 hasher : TweakHasher
6156 """Hash function with tweakable domain separation."""
6257
63- merkle_tree : MerkleTree
64- """Merkle tree implementation for authentication paths."""
65-
6658 encoder : TargetSumEncoder
6759 """Message encoder that produces valid codewords."""
6860
@@ -76,7 +68,6 @@ def enforce_strict_types(self) -> "GeneralizedXmssScheme":
7668 "config" : XmssConfig ,
7769 "prf" : Prf ,
7870 "hasher" : TweakHasher ,
79- "merkle_tree" : MerkleTree ,
8071 "encoder" : TargetSumEncoder ,
8172 "rand" : Rand ,
8273 }
@@ -175,7 +166,7 @@ def key_gen(
175166 left_bottom_tree = bottom_tree_from_prf_key (
176167 self .prf ,
177168 self .hasher ,
178- self .merkle_tree ,
169+ self .rand ,
179170 config ,
180171 prf_key ,
181172 Uint64 (start_bottom_tree_index ),
@@ -184,7 +175,7 @@ def key_gen(
184175 right_bottom_tree = bottom_tree_from_prf_key (
185176 self .prf ,
186177 self .hasher ,
187- self .merkle_tree ,
178+ self .rand ,
188179 config ,
189180 prf_key ,
190181 Uint64 (start_bottom_tree_index + 1 ),
@@ -202,7 +193,7 @@ def key_gen(
202193 tree = bottom_tree_from_prf_key (
203194 self .prf ,
204195 self .hasher ,
205- self .merkle_tree ,
196+ self .rand ,
206197 config ,
207198 prf_key ,
208199 Uint64 (i ),
@@ -453,9 +444,7 @@ def verify(self, pk: PublicKey, epoch: Uint64, message: bytes, sig: Signature) -
453444 chain_ends : List [List [Fp ]] = []
454445 for chain_index , xi in enumerate (codeword ):
455446 # The signature provides `start_digest`, which is the hash value after `xi` steps.
456- # Extract from SSZ type: HashDigestList -> HashDigestVector -> List[Fp]
457- hash_vector = cast (HashDigestVector , sig .hashes [chain_index ])
458- start_digest = cast (List [Fp ], list (hash_vector .data ))
447+ start_digest : List [Fp ] = list (sig .hashes [chain_index ])
459448 # We must perform the remaining `BASE - 1 - xi` hashing steps
460449 # to compute the public endpoint of the chain.
461450 num_steps_remaining = config .BASE - 1 - xi
@@ -475,7 +464,10 @@ def verify(self, pk: PublicKey, epoch: Uint64, message: bytes, sig: Signature) -
475464 # - Hashes the `chain_ends` to get the leaf node for the epoch,
476465 # - Uses the `opening` path from the signature to compute a candidate root.
477466 # - It returns true if and only if this candidate root matches the public key's root.
478- return self .merkle_tree .verify_path (
467+ from .subtree import verify_path
468+
469+ return verify_path (
470+ hasher = self .hasher ,
479471 parameter = pk .parameter ,
480472 root = pk .root ,
481473 position = epoch ,
@@ -568,7 +560,7 @@ def advance_preparation(self, sk: SecretKey) -> SecretKey:
568560 new_right_bottom_tree = bottom_tree_from_prf_key (
569561 prf = self .prf ,
570562 hasher = self .hasher ,
571- merkle_tree = self .merkle_tree ,
563+ rand = self .rand ,
572564 config = self .config ,
573565 prf_key = sk .prf_key ,
574566 bottom_tree_index = new_right_tree_index ,
@@ -589,7 +581,6 @@ def advance_preparation(self, sk: SecretKey) -> SecretKey:
589581 config = PROD_CONFIG ,
590582 prf = PROD_PRF ,
591583 hasher = PROD_TWEAK_HASHER ,
592- merkle_tree = PROD_MERKLE_TREE ,
593584 encoder = PROD_TARGET_SUM_ENCODER ,
594585 rand = PROD_RAND ,
595586)
@@ -599,7 +590,6 @@ def advance_preparation(self, sk: SecretKey) -> SecretKey:
599590 config = TEST_CONFIG ,
600591 prf = TEST_PRF ,
601592 hasher = TEST_TWEAK_HASHER ,
602- merkle_tree = TEST_MERKLE_TREE ,
603593 encoder = TEST_TARGET_SUM_ENCODER ,
604594 rand = TEST_RAND ,
605595)
0 commit comments