Skip to content

Commit e73b15c

Browse files
authored
xmss: mv from_prf_key method to HashSubTree (#232)
1 parent d6cec9b commit e73b15c

File tree

4 files changed

+131
-139
lines changed

4 files changed

+131
-139
lines changed

src/lean_spec/subspecs/xmss/interface.py

Lines changed: 28 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,14 @@
2727
from .containers import PublicKey, SecretKey, Signature
2828
from .prf import PROD_PRF, TEST_PRF, Prf
2929
from .rand import PROD_RAND, TEST_RAND, Rand
30+
from .subtree import HashSubTree
3031
from .tweak_hash import (
3132
PROD_TWEAK_HASHER,
3233
TEST_TWEAK_HASHER,
3334
TweakHasher,
3435
)
3536
from .types import HashDigestVector
36-
from .utils import bottom_tree_from_prf_key, expand_activation_time
37+
from .utils import expand_activation_time
3738

3839

3940
class GeneralizedXmssScheme(StrictBaseModel):
@@ -162,23 +163,23 @@ def key_gen(
162163
actual_num_active_epochs = num_bottom_trees * leafs_per_bottom_tree
163164

164165
# Step 2: Generate the first two bottom trees (kept in memory).
165-
left_bottom_tree = bottom_tree_from_prf_key(
166-
self.prf,
167-
self.hasher,
168-
self.rand,
169-
config,
170-
prf_key,
171-
Uint64(start_bottom_tree_index),
172-
parameter,
166+
left_bottom_tree = HashSubTree.from_prf_key(
167+
prf=self.prf,
168+
hasher=self.hasher,
169+
rand=self.rand,
170+
config=config,
171+
prf_key=prf_key,
172+
bottom_tree_index=Uint64(start_bottom_tree_index),
173+
parameter=parameter,
173174
)
174-
right_bottom_tree = bottom_tree_from_prf_key(
175-
self.prf,
176-
self.hasher,
177-
self.rand,
178-
config,
179-
prf_key,
180-
Uint64(start_bottom_tree_index + 1),
181-
parameter,
175+
right_bottom_tree = HashSubTree.from_prf_key(
176+
prf=self.prf,
177+
hasher=self.hasher,
178+
rand=self.rand,
179+
config=config,
180+
prf_key=prf_key,
181+
bottom_tree_index=Uint64(start_bottom_tree_index + 1),
182+
parameter=parameter,
182183
)
183184

184185
# Collect roots for building the top tree.
@@ -189,21 +190,18 @@ def key_gen(
189190

190191
# Step 3: Generate remaining bottom trees (only their roots).
191192
for i in range(start_bottom_tree_index + 2, end_bottom_tree_index):
192-
tree = bottom_tree_from_prf_key(
193-
self.prf,
194-
self.hasher,
195-
self.rand,
196-
config,
197-
prf_key,
198-
Uint64(i),
199-
parameter,
193+
tree = HashSubTree.from_prf_key(
194+
prf=self.prf,
195+
hasher=self.hasher,
196+
rand=self.rand,
197+
config=config,
198+
prf_key=prf_key,
199+
bottom_tree_index=Uint64(i),
200+
parameter=parameter,
200201
)
201-
root = tree.root()
202-
bottom_tree_roots.append(root)
202+
bottom_tree_roots.append(tree.root())
203203

204204
# Step 4: Build the top tree from bottom tree roots.
205-
from .subtree import HashSubTree
206-
207205
top_tree = HashSubTree.new_top_tree(
208206
hasher=self.hasher,
209207
rand=self.rand,
@@ -554,7 +552,7 @@ def advance_preparation(self, sk: SecretKey) -> SecretKey:
554552

555553
# Compute the next bottom tree (the one after the current right tree)
556554
new_right_tree_index = sk.left_bottom_tree_index + Uint64(2)
557-
new_right_bottom_tree = bottom_tree_from_prf_key(
555+
new_right_bottom_tree = HashSubTree.from_prf_key(
558556
prf=self.prf,
559557
hasher=self.hasher,
560558
rand=self.rand,

src/lean_spec/subspecs/xmss/subtree.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,13 @@
2020
HashTreeLayers,
2121
HashTreeOpening,
2222
Parameter,
23+
PRFKey,
2324
)
2425
from .utils import get_padded_layer
2526

2627
if TYPE_CHECKING:
28+
from .constants import XmssConfig
29+
from .prf import Prf
2730
from .rand import Rand
2831
from .tweak_hash import TweakHasher
2932

@@ -318,6 +321,93 @@ def new_bottom_tree(
318321
layers=HashTreeLayers(data=truncated + [root_layer]),
319322
)
320323

324+
@classmethod
325+
def from_prf_key(
326+
cls,
327+
prf: "Prf",
328+
hasher: "TweakHasher",
329+
rand: "Rand",
330+
config: "XmssConfig",
331+
prf_key: PRFKey,
332+
bottom_tree_index: Uint64,
333+
parameter: Parameter,
334+
) -> "HashSubTree":
335+
"""
336+
Generates a single bottom tree on-demand from the PRF key.
337+
338+
This is a key component of the top-bottom tree approach: instead of storing all
339+
one-time secret keys, we regenerate them on-demand using the PRF. This enables
340+
O(sqrt(LIFETIME)) memory usage.
341+
342+
### Algorithm
343+
344+
1. **Determine epoch range**: Bottom tree `i` covers epochs
345+
`[i * sqrt(LIFETIME), (i+1) * sqrt(LIFETIME))`
346+
347+
2. **Generate leaves**: For each epoch in parallel:
348+
- For each chain (0 to DIMENSION-1):
349+
- Derive secret start: `PRF(prf_key, epoch, chain_index)`
350+
- Compute public end: hash chain for `BASE - 1` steps
351+
- Hash all chain ends to get the leaf
352+
353+
3. **Build bottom tree**: Construct the bottom tree from the leaves
354+
355+
Args:
356+
prf: The PRF instance for key derivation.
357+
hasher: The tweakable hash instance.
358+
rand: Random generator for padding values.
359+
config: The XMSS configuration.
360+
prf_key: The master PRF secret key.
361+
bottom_tree_index: The index of the bottom tree to generate (0, 1, 2, ...).
362+
parameter: The public parameter `P` for the hash function.
363+
364+
Returns:
365+
A `HashSubTree` representing the requested bottom tree.
366+
"""
367+
# Calculate the number of leaves per bottom tree: sqrt(LIFETIME).
368+
leafs_per_bottom_tree = 1 << (config.LOG_LIFETIME // 2)
369+
370+
# Determine the epoch range for this bottom tree.
371+
start_epoch = bottom_tree_index * Uint64(leafs_per_bottom_tree)
372+
end_epoch = start_epoch + Uint64(leafs_per_bottom_tree)
373+
374+
# Generate leaf hashes for all epochs in this bottom tree.
375+
leaf_hashes: List[HashDigestVector] = []
376+
377+
for epoch in range(int(start_epoch), int(end_epoch)):
378+
# For each epoch, compute the one-time public key (chain endpoints).
379+
chain_ends: List[HashDigestVector] = []
380+
381+
for chain_index in range(config.DIMENSION):
382+
# Derive the secret start of the chain from the PRF key.
383+
start_digest = prf.apply(prf_key, Uint64(epoch), Uint64(chain_index))
384+
385+
# Compute the public end by hashing BASE - 1 times.
386+
end_digest = hasher.hash_chain(
387+
parameter=parameter,
388+
epoch=Uint64(epoch),
389+
chain_index=chain_index,
390+
start_step=0,
391+
num_steps=config.BASE - 1,
392+
start_digest=start_digest,
393+
)
394+
chain_ends.append(end_digest)
395+
396+
# Hash the chain ends to get the leaf for this epoch.
397+
leaf_tweak = TreeTweak(level=0, index=epoch)
398+
leaf_hash = hasher.apply(parameter, leaf_tweak, chain_ends)
399+
leaf_hashes.append(leaf_hash)
400+
401+
# Build the bottom tree from the leaf hashes.
402+
return cls.new_bottom_tree(
403+
hasher=hasher,
404+
rand=rand,
405+
depth=config.LOG_LIFETIME,
406+
bottom_tree_index=bottom_tree_index,
407+
parameter=parameter,
408+
leaves=leaf_hashes,
409+
)
410+
321411
def root(self) -> HashDigestVector:
322412
"""
323413
Extracts the root digest from this subtree.

src/lean_spec/subspecs/xmss/utils.py

Lines changed: 2 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,11 @@
11
"""Utility functions for the XMSS signature scheme."""
22

3-
from typing import TYPE_CHECKING, List
3+
from typing import List
44

55
from ...types.uint import Uint64
66
from ..koalabear import Fp, P
7-
from .constants import XmssConfig
87
from .rand import Rand
9-
from .types import HashDigestList, HashDigestVector, HashTreeLayer, Parameter, PRFKey
10-
11-
if TYPE_CHECKING:
12-
from .prf import Prf
13-
from .subtree import HashSubTree
14-
from .tweak_hash import TweakHasher
8+
from .types import HashDigestList, HashDigestVector, HashTreeLayer
159

1610

1711
def get_padded_layer(
@@ -159,93 +153,3 @@ def expand_activation_time(
159153
end_bottom_tree_index = end // c
160154

161155
return (start_bottom_tree_index, end_bottom_tree_index)
162-
163-
164-
def bottom_tree_from_prf_key(
165-
prf: "Prf",
166-
hasher: "TweakHasher",
167-
rand: Rand,
168-
config: XmssConfig,
169-
prf_key: PRFKey,
170-
bottom_tree_index: Uint64,
171-
parameter: Parameter,
172-
) -> "HashSubTree":
173-
"""
174-
Generates a single bottom tree on-demand from the PRF key.
175-
176-
This is a key component of the top-bottom tree approach: instead of storing all
177-
one-time secret keys, we regenerate them on-demand using the PRF. This enables
178-
O(sqrt(LIFETIME)) memory usage.
179-
180-
### Algorithm
181-
182-
1. **Determine epoch range**: Bottom tree `i` covers epochs
183-
`[i * sqrt(LIFETIME), (i+1) * sqrt(LIFETIME))`
184-
185-
2. **Generate leaves**: For each epoch in parallel:
186-
- For each chain (0 to DIMENSION-1):
187-
- Derive secret start: `PRF(prf_key, epoch, chain_index)`
188-
- Compute public end: hash chain for `BASE - 1` steps
189-
- Hash all chain ends to get the leaf
190-
191-
3. **Build bottom tree**: Construct the bottom tree from the leaves
192-
193-
Args:
194-
prf: The PRF instance for key derivation.
195-
hasher: The tweakable hash instance.
196-
rand: Random generator for padding values.
197-
config: The XMSS configuration.
198-
prf_key: The master PRF secret key.
199-
bottom_tree_index: The index of the bottom tree to generate (0, 1, 2, ...).
200-
parameter: The public parameter `P` for the hash function.
201-
202-
Returns:
203-
A `HashSubTree` representing the requested bottom tree.
204-
"""
205-
from .tweak_hash import TreeTweak
206-
207-
# Calculate the number of leaves per bottom tree: sqrt(LIFETIME).
208-
leafs_per_bottom_tree = 1 << (config.LOG_LIFETIME // 2)
209-
210-
# Determine the epoch range for this bottom tree.
211-
start_epoch = bottom_tree_index * Uint64(leafs_per_bottom_tree)
212-
end_epoch = start_epoch + Uint64(leafs_per_bottom_tree)
213-
214-
# Generate leaf hashes for all epochs in this bottom tree.
215-
leaf_hashes: List[HashDigestVector] = []
216-
217-
for epoch in range(int(start_epoch), int(end_epoch)):
218-
# For each epoch, compute the one-time public key (chain endpoints).
219-
chain_ends: List[HashDigestVector] = []
220-
221-
for chain_index in range(config.DIMENSION):
222-
# Derive the secret start of the chain from the PRF key.
223-
start_digest = prf.apply(prf_key, Uint64(epoch), Uint64(chain_index))
224-
225-
# Compute the public end by hashing BASE - 1 times.
226-
end_digest = hasher.hash_chain(
227-
parameter=parameter,
228-
epoch=Uint64(epoch),
229-
chain_index=chain_index,
230-
start_step=0,
231-
num_steps=config.BASE - 1,
232-
start_digest=start_digest,
233-
)
234-
chain_ends.append(end_digest)
235-
236-
# Hash the chain ends to get the leaf for this epoch.
237-
leaf_tweak = TreeTweak(level=0, index=epoch)
238-
leaf_hash = hasher.apply(parameter, leaf_tweak, chain_ends)
239-
leaf_hashes.append(leaf_hash)
240-
241-
# Build the bottom tree from the leaf hashes.
242-
from .subtree import HashSubTree
243-
244-
return HashSubTree.new_bottom_tree(
245-
hasher=hasher,
246-
rand=rand,
247-
depth=config.LOG_LIFETIME,
248-
bottom_tree_index=bottom_tree_index,
249-
parameter=parameter,
250-
leaves=leaf_hashes,
251-
)

0 commit comments

Comments
 (0)