Skip to content

Commit e9553c2

Browse files
committed
cut stack usage
1 parent abebea1 commit e9553c2

File tree

5 files changed

+53
-34
lines changed

5 files changed

+53
-34
lines changed

CHANGELOG.md

+5
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,11 @@ All notable changes to this project will be documented in this file.
55
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
66
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
77

8+
## 0.4.4 (2024-10-XX)
9+
10+
- Significantly reduced the required stack size
11+
- Internal-only refactoring and polishing
12+
813
## 0.4.3 (2024-10-16)
914

1015
- Adapted ExpandedPrivateKey into PrivateKey and ExpandedPublicKey into PublicKey; removed the Expanded* types

Cargo.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ workspace = { exclude = ["ct_cm4", "dudect", "fuzz", "wasm"] }
22

33
[package]
44
name = "fips204"
5-
version = "0.4.3"
5+
version = "0.4.4"
66
authors = ["Eric Schorn <[email protected]>"]
77
description = "FIPS 204: Module-Lattice-Based Digital Signature"
88
categories = ["cryptography", "no-std"]

src/lib.rs

+13-8
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@
1515

1616

1717
// TODO Roadmap
18+
// 0. Code clean-up, more carefully shrink stack
1819
// 1. Improve docs on first/last few algorithms
19-
// 2. Several outstanding refactors (mostly down below in this file)
20-
// 3. Always more testing...
20+
// 2. Always more testing...
2121

2222

2323
// Implements FIPS 204 Module-Lattice-Based Digital Signature Standard.
@@ -277,9 +277,11 @@ macro_rules! functionality {
277277
use crate::high_low::power2round;
278278
use crate::helpers::to_mont;
279279
use crate::D;
280+
use crate::hashing::expand_a;
280281

281282
// TODO: refactor
282-
let PrivateKey {rho, cap_k: _, tr, s_hat_1_mont, s_hat_2_mont, t_hat_0_mont, cap_a_hat} = &self;
283+
let PrivateKey {rho, cap_k: _, tr, s_hat_1_mont, s_hat_2_mont, t_hat_0_mont} = &self;
284+
let cap_a_hat: [[T; L]; K] = expand_a::<false, K, L>(&rho);
283285
let s_1: [R; L] = inv_ntt(&core::array::from_fn(|l| T(core::array::from_fn(|n| full_reduce32(mont_reduce(s_hat_1_mont[l].0[n] as i64))))));
284286
let s_1: [R; L] = core::array::from_fn(|l| R(core::array::from_fn(|n| if s_1[l].0[n] > (Q >> 2) {s_1[l].0[n] - Q} else {s_1[l].0[n]})));
285287
let s_2: [R; K] = inv_ntt(&core::array::from_fn(|k| T(core::array::from_fn(|n| full_reduce32(mont_reduce(s_hat_2_mont[k].0[n] as i64))))));
@@ -307,7 +309,8 @@ macro_rules! functionality {
307309
let t1_d2_hat_mont: [T; K] = to_mont(&core::array::from_fn(|k| {
308310
T(core::array::from_fn(|n| mont_reduce(i64::from(t1_hat_mont[k].0[n]) << D)))
309311
}));
310-
let pk = PublicKey { rho: *rho, cap_a_hat: cap_a_hat.clone(), tr: *tr, t1_d2_hat_mont};
312+
//let pk = PublicKey { rho: *rho, cap_a_hat: cap_a_hat.clone(), tr: *tr, t1_d2_hat_mont};
313+
let pk = PublicKey { rho: *rho, tr: *tr, t1_d2_hat_mont};
311314

312315
// 10: return pk
313316
pk
@@ -320,7 +323,7 @@ macro_rules! functionality {
320323

321324
// Algorithm 3 in Verifier trait.
322325
fn verify(&self, message: &[u8], sig: &Self::Signature, ctx: &[u8]) -> bool {
323-
let Ok(res) = ml_dsa::verify::<K, L, LAMBDA_DIV4, PK_LEN, SIG_LEN, W1_LEN>(
326+
let Ok(res) = ml_dsa::verify::<false, K, L, LAMBDA_DIV4, PK_LEN, SIG_LEN, W1_LEN>(
324327
BETA, GAMMA1, GAMMA2, OMEGA, TAU, &self, &message, &sig, ctx, &[], &[], false
325328
) else {
326329
return false;
@@ -335,7 +338,7 @@ macro_rules! functionality {
335338
};
336339
let mut phm = [0u8; 64]; // hashers don't all play well with each other
337340
let (oid, phm_len) = hash_message(message, ph, &mut phm);
338-
let Ok(res) = ml_dsa::verify::<K, L, LAMBDA_DIV4, PK_LEN, SIG_LEN, W1_LEN>(
341+
let Ok(res) = ml_dsa::verify::<false, K, L, LAMBDA_DIV4, PK_LEN, SIG_LEN, W1_LEN>(
339342
BETA, GAMMA1, GAMMA2, OMEGA, TAU, &self, &message, &sig, ctx, &oid, &phm[0..phm_len], false
340343
) else {
341344
return false;
@@ -396,9 +399,11 @@ macro_rules! functionality {
396399
use crate::helpers::full_reduce32;
397400
use crate::ntt::inv_ntt;
398401
use crate::D;
402+
use crate::hashing::expand_a;
399403

400404
// TODO: refactor
401-
let PublicKey {rho, cap_a_hat, tr, t1_d2_hat_mont} = &self;
405+
let PublicKey {rho, tr, t1_d2_hat_mont} = &self;
406+
let cap_a_hat: [[T; L]; K] = expand_a::<false, K, L>(&rho);
402407
let (_, _, _, _) = (rho, cap_a_hat, tr, t1_d2_hat_mont);
403408
let t1_d2: [R; K] = inv_ntt(&core::array::from_fn(|k| T(core::array::from_fn(|n| full_reduce32(mont_reduce(t1_d2_hat_mont[k].0[n] as i64))))));
404409
let t1: [R; K] = core::array::from_fn(|k| R(core::array::from_fn(|n| t1_d2[k].0[n] >> D)));
@@ -486,7 +491,7 @@ macro_rules! functionality {
486491
if ctx.len() > 255 {
487492
return false;
488493
};
489-
let Ok(res) = ml_dsa::verify::<K, L, LAMBDA_DIV4, PK_LEN, SIG_LEN, W1_LEN>(
494+
let Ok(res) = ml_dsa::verify::<false, K, L, LAMBDA_DIV4, PK_LEN, SIG_LEN, W1_LEN>(
490495
BETA, GAMMA1, GAMMA2, OMEGA, TAU, pk, &message, &sig, ctx, &[], &[], true
491496
) else {
492497
return false;

src/ml_dsa.rs

+32-23
Original file line numberDiff line numberDiff line change
@@ -84,15 +84,17 @@ pub(crate) fn sign<
8484

8585
// Extract from expand_private()
8686
let PrivateKey {
87-
rho: _,
87+
rho,
8888
cap_k,
8989
tr,
9090
s_hat_1_mont,
9191
s_hat_2_mont,
9292
t_hat_0_mont,
93-
cap_a_hat,
93+
//cap_a_hat,
9494
} = esk;
9595

96+
let cap_a_hat: [[T; L]; K] = expand_a::<CTEST, K, L>(rho);
97+
9698
// 6: 𝜇 ← H(BytesToBits(𝑡𝑟)||𝑀 , 64) ▷ Compute message representative µ
9799
// We may have arrived from 3 different paths
98100
let mut h6 = if nist {
@@ -134,7 +136,7 @@ pub(crate) fn sign<
134136
// 12: w ← NTT−1(cap_a_hat ◦ NTT(y))
135137
let w: [R; K] = {
136138
let y_hat: [T; L] = ntt(&y);
137-
let ay_hat: [T; K] = mat_vec_mul(cap_a_hat, &y_hat);
139+
let ay_hat: [T; K] = mat_vec_mul(&cap_a_hat, &y_hat);
138140
inv_ntt(&ay_hat)
139141
};
140142

@@ -259,6 +261,7 @@ pub(crate) fn sign<
259261
/// Continuation of `verify_start()`. The `lib.rs` wrapper around this will convert `Error()` to false.
260262
#[allow(clippy::too_many_arguments, clippy::similar_names)]
261263
pub(crate) fn verify<
264+
const CTEST: bool,
262265
const K: usize,
263266
const L: usize,
264267
const LAMBDA_DIV4: usize,
@@ -270,7 +273,8 @@ pub(crate) fn verify<
270273
sig: &[u8; SIG_LEN], ctx: &[u8], oid: &[u8], phm: &[u8], nist: bool,
271274
) -> Result<bool, &'static str> {
272275
//
273-
let PublicKey { rho: _, cap_a_hat, tr, t1_d2_hat_mont } = epk;
276+
//let PublicKey { rho: _, cap_a_hat, tr, t1_d2_hat_mont } = epk;
277+
let PublicKey { rho, tr, t1_d2_hat_mont } = epk;
274278

275279
// 1: (ρ, t_1) ← pkDecode(pk)
276280
// --> calculated in expand_public()
@@ -314,8 +318,10 @@ pub(crate) fn verify<
314318

315319
// 9: w′_Approx ← invNTT(cap_A_hat ◦ NTT(z) - NTT(c) ◦ NTT(t_1 · 2^d)) ▷ w′_Approx = Az − ct1·2^d
316320
let wp_approx: [R; K] = {
321+
// CTEST is supplied by the caller; `false` is expected since everything is public here
322+
let cap_a_hat: [[T; L]; K] = expand_a::<CTEST, K, L>(rho);
317323
let z_hat: [T; L] = ntt(&z);
318-
let az_hat: [T; K] = mat_vec_mul(cap_a_hat, &z_hat);
324+
let az_hat: [T; K] = mat_vec_mul(&cap_a_hat, &z_hat);
319325
// NTT(t_1 · 2^d) --> calculated in expand_public()
320326
let c_hat: &T = &ntt(&[c])[0];
321327
inv_ntt(&core::array::from_fn(|k| {
@@ -378,22 +384,22 @@ pub(crate) fn key_gen_internal<
378384

379385
// There is effectively no step 2 due to a formatting error in the spec
380386

381-
// 3: cap_a_hat ← ExpandA(ρ) ▷ A is generated and stored in NTT representation as Â
382-
let cap_a_hat: [[T; L]; K] = expand_a::<CTEST, K, L>(&rho);
383-
384387
// 4: (s_1, s_2) ← ExpandS(ρ′)
385388
let (s_1, s_2): ([R; L], [R; K]) = expand_s::<CTEST, K, L>(eta, &rho_prime);
386389

390+
// 3: cap_a_hat ← ExpandA(ρ) ▷ A is generated and stored in NTT representation as Â
387391
// 5: t ← NTT−1(cap_a_hat ◦ NTT(s_1)) + s_2 ▷ Compute t = As1 + s2
388-
//let t: [R; K]
389-
let s_1_hat: [T; L] = ntt(&s_1);
390-
let as1_hat: [T; K] = mat_vec_mul(&cap_a_hat, &s_1_hat);
391-
let t_not_reduced: [R; K] = add_vector_ntt(&inv_ntt(&as1_hat), &s_2);
392-
let t: [R; K] =
393-
core::array::from_fn(|k| R(core::array::from_fn(|n| full_reduce32(t_not_reduced[k].0[n]))));
394-
395392
// 6: (t_1, t_0) ← Power2Round(t, d) ▷ Compress t
396-
let (t_1, t_0): ([R; K], [R; K]) = power2round(&t);
393+
394+
let (t_1, t_0): ([R; K], [R; K]) = {
395+
let cap_a_hat: [[T; L]; K] = expand_a::<CTEST, K, L>(&rho);
396+
let s_1_hat: [T; L] = ntt(&s_1);
397+
let as1_hat: [T; K] = mat_vec_mul(&cap_a_hat, &s_1_hat);
398+
let t_not_reduced: [R; K] = add_vector_ntt(&inv_ntt(&as1_hat), &s_2);
399+
let t: [R; K] =
400+
core::array::from_fn(|k| R(core::array::from_fn(|n| full_reduce32(t_not_reduced[k].0[n]))));
401+
power2round(&t)
402+
};
397403

398404
// There is effectively no step 7 due to a formatting error in the spec
399405

@@ -414,10 +420,12 @@ pub(crate) fn key_gen_internal<
414420
let t1_d2_hat_mont: [T; K] = to_mont(&core::array::from_fn(|k| {
415421
T(core::array::from_fn(|n| mont_reduce(i64::from(t1_hat_mont[k].0[n]) << D)))
416422
}));
417-
let pk = PublicKey { rho, cap_a_hat: cap_a_hat.clone(), tr, t1_d2_hat_mont };
423+
//let pk = PublicKey { rho, cap_a_hat: cap_a_hat.clone(), tr, t1_d2_hat_mont };
424+
let pk = PublicKey { rho, tr, t1_d2_hat_mont };
418425

419426
// 2: s_hat_1 ← NTT(s_1)
420-
let s_hat_1_mont: [T; L] = to_mont(&s_1_hat); //ntt(&s_1));
427+
//let s_hat_1_mont: [T; L] = to_mont(&s_1_hat); //ntt(&s_1));
428+
let s_hat_1_mont: [T; L] = to_mont(&ntt(&s_1));
421429
// 3: s_hat_2 ← NTT(s_2)
422430
let s_hat_2_mont: [T; K] = to_mont(&ntt(&s_2));
423431
// 4: t_hat_0 ← NTT(t_0)
@@ -429,7 +437,7 @@ pub(crate) fn key_gen_internal<
429437
s_hat_1_mont,
430438
s_hat_2_mont,
431439
t_hat_0_mont,
432-
cap_a_hat,
440+
// cap_a_hat,
433441
};
434442

435443
// 11: return (pk, sk)
@@ -463,7 +471,7 @@ pub(crate) fn expand_private<
463471
let t_hat_0_mont: [T; K] = to_mont(&ntt(&t_0));
464472

465473
// 5: cap_a_hat ← ExpandA(ρ) ▷ A is generated and stored in NTT representation as Â
466-
let cap_a_hat: [[T; L]; K] = expand_a::<CTEST, K, L>(rho);
474+
//let cap_a_hat: [[T; L]; K] = expand_a::<CTEST, K, L>(rho);
467475

468476
Ok(PrivateKey {
469477
rho: *rho,
@@ -472,7 +480,7 @@ pub(crate) fn expand_private<
472480
s_hat_1_mont,
473481
s_hat_2_mont,
474482
t_hat_0_mont,
475-
cap_a_hat,
483+
//cap_a_hat,
476484
})
477485
}
478486

@@ -489,7 +497,7 @@ pub(crate) fn expand_public<const K: usize, const L: usize, const PK_LEN: usize>
489497
let (rho, t_1): (&[u8; 32], [R; K]) = pk_decode(pk)?;
490498

491499
// 5: cap_a_hat ← ExpandA(ρ) ▷ A is generated and stored in NTT representation as cap_A_hat
492-
let cap_a_hat: [[T; L]; K] = expand_a::<false, K, L>(rho);
500+
//let cap_a_hat: [[T; L]; K] = expand_a::<false, K, L>(rho);
493501

494502
// 6: tr ← H(pk, 64)
495503
let mut h6 = h256_xof(&[pk]);
@@ -503,5 +511,6 @@ pub(crate) fn expand_public<const K: usize, const L: usize, const PK_LEN: usize>
503511
T(core::array::from_fn(|n| mont_reduce(i64::from(t1_hat_mont[k].0[n]) << D)))
504512
}));
505513

506-
Ok(PublicKey { rho: *rho, cap_a_hat, tr, t1_d2_hat_mont })
514+
//Ok(PublicKey { rho: *rho, cap_a_hat, tr, t1_d2_hat_mont })
515+
Ok(PublicKey { rho: *rho, tr, t1_d2_hat_mont })
507516
}

src/types.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ pub struct PrivateKey<const K: usize, const L: usize> {
2525
pub(crate) s_hat_1_mont: [T; L],
2626
pub(crate) s_hat_2_mont: [T; K],
2727
pub(crate) t_hat_0_mont: [T; K],
28-
pub(crate) cap_a_hat: [[T; L]; K],
28+
// pub(crate) cap_a_hat: [[T; L]; K],
2929
}
3030

3131

@@ -37,7 +37,7 @@ pub struct PrivateKey<const K: usize, const L: usize> {
3737
#[repr(align(8))]
3838
pub struct PublicKey<const K: usize, const L: usize> {
3939
pub(crate) rho: [u8; 32],
40-
pub(crate) cap_a_hat: [[T; L]; K],
40+
// pub(crate) cap_a_hat: [[T; L]; K],
4141
pub(crate) tr: [u8; 64],
4242
pub(crate) t1_d2_hat_mont: [T; K],
4343
}

0 commit comments

Comments
 (0)