Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 5 additions & 8 deletions src/lean_spec/subspecs/xmss/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
)
from lean_spec.types import StrictBaseModel, Uint64

from ..koalabear import Fp
from .constants import (
PROD_CONFIG,
TEST_CONFIG,
Expand Down Expand Up @@ -330,7 +329,7 @@ def sign(self, sk: SecretKey, epoch: Uint64, message: bytes) -> Signature:
raise RuntimeError("Encoding is broken: returned too many or too few chunks.")

# Compute the one-time signature hashes based on the codeword.
ots_hashes: List[List[Fp]] = []
ots_hashes: List[HashDigestVector] = []
for chain_index, steps in enumerate(codeword):
# Derive the secret start of the current chain using the master PRF key.
start_digest = self.prf.apply(sk.prf_key, epoch, Uint64(chain_index))
Expand Down Expand Up @@ -380,11 +379,9 @@ def sign(self, sk: SecretKey, epoch: Uint64, message: bytes) -> Signature:
# - The OTS,
# - The Merkle path,
# - The randomness `rho` needed for verification.
# Wrap ots_hashes in SSZ types
from .types import HashDigestList, HashDigestVector
from .types import HashDigestList

ssz_hashes = [HashDigestVector(data=hash_digest) for hash_digest in ots_hashes]
return Signature(path=path, rho=rho, hashes=HashDigestList(data=ssz_hashes))
return Signature(path=path, rho=rho, hashes=HashDigestList(data=ots_hashes))

def verify(self, pk: PublicKey, epoch: Uint64, message: bytes, sig: Signature) -> bool:
r"""
Expand Down Expand Up @@ -441,10 +438,10 @@ def verify(self, pk: PublicKey, epoch: Uint64, message: bytes, sig: Signature) -
return False

# Reconstruct the one-time public key (the list of chain endpoints).
chain_ends: List[List[Fp]] = []
chain_ends: List[HashDigestVector] = []
for chain_index, xi in enumerate(codeword):
# The signature provides `start_digest`, which is the hash value after `xi` steps.
start_digest: List[Fp] = list(sig.hashes[chain_index])
start_digest = sig.hashes[chain_index]
# We must perform the remaining `BASE - 1 - xi` hashing steps
# to compute the public endpoint of the chain.
num_steps_remaining = config.BASE - 1 - xi
Expand Down
26 changes: 13 additions & 13 deletions src/lean_spec/subspecs/xmss/prf.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@

import hashlib
import os
from typing import List

from pydantic import model_validator

Expand All @@ -24,7 +23,7 @@
TEST_CONFIG,
XmssConfig,
)
from .types import PRFKey, Randomness
from .types import HashDigestVector, PRFKey, Randomness

PRF_DOMAIN_SEP: bytes = bytes(
[
Expand Down Expand Up @@ -106,7 +105,7 @@ def key_gen(self) -> PRFKey:
"""
return PRFKey(os.urandom(PRF_KEY_LENGTH))

def apply(self, key: PRFKey, epoch: Uint64, chain_index: Uint64) -> List[Fp]:
def apply(self, key: PRFKey, epoch: Uint64, chain_index: Uint64) -> HashDigestVector:
"""
Applies the PRF to derive the secret starting value for a single hash chain.

Expand All @@ -127,8 +126,7 @@ def apply(self, key: PRFKey, epoch: Uint64, chain_index: Uint64) -> List[Fp]:
chain_index: The index of the hash chain within that epoch's OTS.

Returns:
A list of field elements representing the secret start of a single
hash chain (i.e., a `HashDigest`).
A hash digest representing the secret start of a single hash chain.
"""
# Retrieve the scheme's configuration parameters.
config = self.config
Expand Down Expand Up @@ -160,15 +158,17 @@ def apply(self, key: PRFKey, epoch: Uint64, chain_index: Uint64) -> List[Fp]:
# - Slice an 8-byte (64-bit) chunk from the `prf_output_bytes`.
# - Convert that chunk from a big-endian byte representation to an integer.
# - Create a field element from the integer (the Fp constructor handles the modulo).
return [
Fp(
value=int.from_bytes(
prf_output_bytes[i * PRF_BYTES_PER_FE : (i + 1) * PRF_BYTES_PER_FE],
"big",
return HashDigestVector(
data=[
Fp(
value=int.from_bytes(
prf_output_bytes[i * PRF_BYTES_PER_FE : (i + 1) * PRF_BYTES_PER_FE],
"big",
)
)
)
for i in range(config.HASH_LEN_FE)
]
for i in range(config.HASH_LEN_FE)
]
)

def get_randomness(
self, key: PRFKey, epoch: Uint64, message: bytes, counter: Uint64
Expand Down
6 changes: 3 additions & 3 deletions src/lean_spec/subspecs/xmss/rand.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from ..koalabear import Fp, P
from .constants import PROD_CONFIG, TEST_CONFIG, XmssConfig
from .types import Parameter, Randomness
from .types import HashDigestVector, Parameter, Randomness


class Rand(StrictBaseModel):
Expand All @@ -34,9 +34,9 @@ def parameter(self) -> Parameter:
"""Generates a random public parameter."""
return Parameter(data=self.field_elements(self.config.PARAMETER_LEN))

def domain(self) -> List[Fp]:
def domain(self) -> HashDigestVector:
"""Generates a random hash digest."""
return self.field_elements(self.config.HASH_LEN_FE)
return HashDigestVector(data=self.field_elements(self.config.HASH_LEN_FE))

def rho(self) -> Randomness:
"""Generates randomness `rho` for message encoding."""
Expand Down
Loading
Loading