Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 10 additions & 20 deletions src/lean_spec/subspecs/xmss/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from __future__ import annotations

from typing import List, Tuple, cast
from typing import List, Tuple

from pydantic import model_validator

Expand All @@ -26,11 +26,6 @@
XmssConfig,
)
from .containers import PublicKey, SecretKey, Signature
from .merkle_tree import (
PROD_MERKLE_TREE,
TEST_MERKLE_TREE,
MerkleTree,
)
from .prf import PROD_PRF, TEST_PRF, Prf
from .rand import PROD_RAND, TEST_RAND, Rand
from .tweak_hash import (
Expand Down Expand Up @@ -60,9 +55,6 @@ class GeneralizedXmssScheme(StrictBaseModel):
hasher: TweakHasher
"""Hash function with tweakable domain separation."""

merkle_tree: MerkleTree
"""Merkle tree implementation for authentication paths."""

encoder: TargetSumEncoder
"""Message encoder that produces valid codewords."""

Expand All @@ -76,7 +68,6 @@ def enforce_strict_types(self) -> "GeneralizedXmssScheme":
"config": XmssConfig,
"prf": Prf,
"hasher": TweakHasher,
"merkle_tree": MerkleTree,
"encoder": TargetSumEncoder,
"rand": Rand,
}
Expand Down Expand Up @@ -175,7 +166,7 @@ def key_gen(
left_bottom_tree = bottom_tree_from_prf_key(
self.prf,
self.hasher,
self.merkle_tree,
self.rand,
config,
prf_key,
Uint64(start_bottom_tree_index),
Expand All @@ -184,7 +175,7 @@ def key_gen(
right_bottom_tree = bottom_tree_from_prf_key(
self.prf,
self.hasher,
self.merkle_tree,
self.rand,
config,
prf_key,
Uint64(start_bottom_tree_index + 1),
Expand All @@ -202,7 +193,7 @@ def key_gen(
tree = bottom_tree_from_prf_key(
self.prf,
self.hasher,
self.merkle_tree,
self.rand,
config,
prf_key,
Uint64(i),
Expand Down Expand Up @@ -453,9 +444,7 @@ def verify(self, pk: PublicKey, epoch: Uint64, message: bytes, sig: Signature) -
chain_ends: List[List[Fp]] = []
for chain_index, xi in enumerate(codeword):
# The signature provides `start_digest`, which is the hash value after `xi` steps.
# Extract from SSZ type: HashDigestList -> HashDigestVector -> List[Fp]
hash_vector = cast(HashDigestVector, sig.hashes[chain_index])
start_digest = cast(List[Fp], list(hash_vector.data))
start_digest: List[Fp] = list(sig.hashes[chain_index])
# We must perform the remaining `BASE - 1 - xi` hashing steps
# to compute the public endpoint of the chain.
num_steps_remaining = config.BASE - 1 - xi
Expand All @@ -475,7 +464,10 @@ def verify(self, pk: PublicKey, epoch: Uint64, message: bytes, sig: Signature) -
# - Hashes the `chain_ends` to get the leaf node for the epoch,
# - Uses the `opening` path from the signature to compute a candidate root.
# - It returns true if and only if this candidate root matches the public key's root.
return self.merkle_tree.verify_path(
from .subtree import verify_path

return verify_path(
hasher=self.hasher,
parameter=pk.parameter,
root=pk.root,
position=epoch,
Expand Down Expand Up @@ -568,7 +560,7 @@ def advance_preparation(self, sk: SecretKey) -> SecretKey:
new_right_bottom_tree = bottom_tree_from_prf_key(
prf=self.prf,
hasher=self.hasher,
merkle_tree=self.merkle_tree,
rand=self.rand,
config=self.config,
prf_key=sk.prf_key,
bottom_tree_index=new_right_tree_index,
Expand All @@ -589,7 +581,6 @@ def advance_preparation(self, sk: SecretKey) -> SecretKey:
config=PROD_CONFIG,
prf=PROD_PRF,
hasher=PROD_TWEAK_HASHER,
merkle_tree=PROD_MERKLE_TREE,
encoder=PROD_TARGET_SUM_ENCODER,
rand=PROD_RAND,
)
Expand All @@ -599,7 +590,6 @@ def advance_preparation(self, sk: SecretKey) -> SecretKey:
config=TEST_CONFIG,
prf=TEST_PRF,
hasher=TEST_TWEAK_HASHER,
merkle_tree=TEST_MERKLE_TREE,
encoder=TEST_TARGET_SUM_ENCODER,
rand=TEST_RAND,
)
Expand Down
Loading
Loading