Skip to content

Commit

Permalink
slh-dsa: remove allocations (#860)
Browse files Browse the repository at this point in the history
  • Loading branch information
jonmon6691 authored Oct 20, 2024
1 parent f46f4a3 commit 8f93676
Show file tree
Hide file tree
Showing 8 changed files with 32 additions and 26 deletions.
8 changes: 4 additions & 4 deletions slh-dsa/src/hashes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,15 @@ pub(crate) trait HashSuite: Sized + Clone + Debug + PartialEq + Eq {
fn prf_msg(
sk_prf: &SkPrf<Self::N>,
opt_rand: &Array<u8, Self::N>,
msg: impl AsRef<[u8]>,
msg: &[impl AsRef<[u8]>],
) -> Array<u8, Self::N>;

/// Hashes a message using a given randomizer
fn h_msg(
rand: &Array<u8, Self::N>,
pk_seed: &PkSeed<Self::N>,
pk_root: &Array<u8, Self::N>,
msg: impl AsRef<[u8]>,
msg: &[impl AsRef<[u8]>],
) -> Array<u8, Self::M>;

/// PRF that is used to generate the secret values in WOTS+ and FORS private keys.
Expand Down Expand Up @@ -76,7 +76,7 @@ mod tests {
let opt_rand = Array::<u8, H::N>::from_fn(|_| 1);
let msg = [2u8; 32];

let result = H::prf_msg(&sk_prf, &opt_rand, msg);
let result = H::prf_msg(&sk_prf, &opt_rand, &[msg]);

assert_eq!(result.as_slice(), expected);
}
Expand All @@ -87,7 +87,7 @@ mod tests {
let pk_root = Array::<u8, H::N>::from_fn(|_| 2);
let msg = [3u8; 32];

let result = H::h_msg(&rand, &pk_seed, &pk_root, msg);
let result = H::h_msg(&rand, &pk_seed, &pk_root, &[msg]);

assert_eq!(result.as_slice(), expected);
}
Expand Down
18 changes: 10 additions & 8 deletions slh-dsa/src/hashes/sha2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,12 @@ where
fn prf_msg(
sk_prf: &SkPrf<Self::N>,
opt_rand: &Array<u8, Self::N>,
msg: impl AsRef<[u8]>,
msg: &[impl AsRef<[u8]>],
) -> Array<u8, Self::N> {
let mut mac = Hmac::<Sha256>::new_from_slice(sk_prf.as_ref()).unwrap();
mac.update(opt_rand.as_slice());
mac.update(msg.as_ref());
msg.iter()
.for_each(|msg_part| mac.update(msg_part.as_ref()));
let result = mac.finalize().into_bytes();
Array::clone_from_slice(&result[..Self::N::USIZE])
}
Expand All @@ -73,13 +74,13 @@ where
rand: &Array<u8, Self::N>,
pk_seed: &PkSeed<Self::N>,
pk_root: &Array<u8, Self::N>,
msg: impl AsRef<[u8]>,
msg: &[impl AsRef<[u8]>],
) -> Array<u8, Self::M> {
let mut h = Sha256::new();
h.update(rand);
h.update(pk_seed);
h.update(pk_root);
h.update(msg.as_ref());
msg.iter().for_each(|msg_part| h.update(msg_part.as_ref()));
let result = Array(h.finalize().into());
let seed = rand.clone().concat(pk_seed.0.clone()).concat(result);
mgf1::<Sha256, Self::M>(&seed)
Expand Down Expand Up @@ -220,11 +221,12 @@ where
fn prf_msg(
sk_prf: &SkPrf<Self::N>,
opt_rand: &Array<u8, Self::N>,
msg: impl AsRef<[u8]>,
msg: &[impl AsRef<[u8]>],
) -> Array<u8, Self::N> {
let mut mac = Hmac::<Sha512>::new_from_slice(sk_prf.as_ref()).unwrap();
mac.update(opt_rand.as_slice());
mac.update(msg.as_ref());
msg.iter()
.for_each(|msg_part| mac.update(msg_part.as_ref()));
let result = mac.finalize().into_bytes();
Array::clone_from_slice(&result[..Self::N::USIZE])
}
Expand All @@ -233,13 +235,13 @@ where
rand: &Array<u8, Self::N>,
pk_seed: &PkSeed<Self::N>,
pk_root: &Array<u8, Self::N>,
msg: impl AsRef<[u8]>,
msg: &[impl AsRef<[u8]>],
) -> Array<u8, Self::M> {
let mut h = Sha512::new();
h.update(rand);
h.update(pk_seed);
h.update(pk_root);
h.update(msg.as_ref());
msg.iter().for_each(|msg_part| h.update(msg_part.as_ref()));
let result = Array(h.finalize().into());
let seed = rand.clone().concat(pk_seed.0.clone()).concat(result);
mgf1::<Sha512, Self::M>(&seed)
Expand Down
12 changes: 7 additions & 5 deletions slh-dsa/src/hashes/shake.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,13 @@ where
fn prf_msg(
sk_prf: &SkPrf<Self::N>,
opt_rand: &Array<u8, Self::N>,
msg: impl AsRef<[u8]>,
msg: &[impl AsRef<[u8]>],
) -> Array<u8, Self::N> {
let mut hasher = Shake256::default();
hasher.update(sk_prf.as_ref());
hasher.update(opt_rand.as_slice());
hasher.update(msg.as_ref());
msg.iter()
.for_each(|msg_part| hasher.update(msg_part.as_ref()));
let mut output = Array::<u8, Self::N>::default();
hasher.finalize_xof_into(&mut output);
output
Expand All @@ -49,13 +50,14 @@ where
rand: &Array<u8, Self::N>,
pk_seed: &PkSeed<Self::N>,
pk_root: &Array<u8, Self::N>,
msg: impl AsRef<[u8]>,
msg: &[impl AsRef<[u8]>],
) -> Array<u8, Self::M> {
let mut hasher = Shake256::default();
hasher.update(rand.as_slice());
hasher.update(pk_seed.as_ref());
hasher.update(pk_root.as_ref());
hasher.update(msg.as_ref());
msg.iter()
.for_each(|msg_part| hasher.update(msg_part.as_ref()));
let mut output = Array::<u8, Self::M>::default();
hasher.finalize_xof_into(&mut output);
output
Expand Down Expand Up @@ -267,7 +269,7 @@ mod tests {

let expected = hex!("bc5c062307df0a41aeeae19ad655f7b2");

let result = H::prf_msg(&sk_prf, &opt_rand, msg);
let result = H::prf_msg(&sk_prf, &opt_rand, &[msg]);

assert_eq!(result.as_slice(), expected);
}
Expand Down
5 changes: 2 additions & 3 deletions slh-dsa/src/signing_key.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ impl<P: ParameterSet> SigningKey<P> {
/// Implements [slh_sign_internal] as defined in FIPS-205.
/// Published for KAT validation purposes but not intended for general use.
/// opt_rand must be a P::N length slice, panics otherwise.
pub fn slh_sign_internal(&self, msg: &[u8], opt_rand: Option<&[u8]>) -> Signature<P> {
pub fn slh_sign_internal(&self, msg: &[&[u8]], opt_rand: Option<&[u8]>) -> Signature<P> {
let rand = opt_rand
.unwrap_or(&self.verifying_key.pk_seed.0)
.try_into()
Expand Down Expand Up @@ -142,8 +142,7 @@ impl<P: ParameterSet> SigningKey<P> {
let ctx_len = u8::try_from(ctx.len()).map_err(|_| Error::new())?;
let ctx_len_bytes = ctx_len.to_be_bytes();

// TODO - figure out what to do about this allocation. Maybe pass a chained iterator to slh_sign_internal?
let ctx_msg = [&[0], &ctx_len_bytes, ctx, msg].concat();
let ctx_msg = [&[0], &ctx_len_bytes, ctx, msg];
Ok(self.slh_sign_internal(&ctx_msg, opt_rand))
}

Expand Down
9 changes: 6 additions & 3 deletions slh-dsa/src/verifying_key.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,11 @@ impl<P: ParameterSet + VerifyingKeyLen> VerifyingKey<P> {
/// Verify a raw message (without context).
/// Implements [slh_verify_internal] as defined in FIPS-205.
/// Published for KAT validation purposes but not intended for general use.
pub fn slh_verify_internal(&self, msg: &[u8], signature: &Signature<P>) -> Result<(), Error> {
pub fn slh_verify_internal(
&self,
msg: &[&[u8]],
signature: &Signature<P>,
) -> Result<(), Error> {
let pk_seed = &self.pk_seed;
let randomizer = &signature.randomizer;
let fors_sig = &signature.fors_sig;
Expand Down Expand Up @@ -79,8 +83,7 @@ impl<P: ParameterSet + VerifyingKeyLen> VerifyingKey<P> {
let ctx_len = u8::try_from(ctx.len()).map_err(|_| Error::new())?;
let ctx_len_bytes = ctx_len.to_be_bytes();

// TODO - figure out what to do about this allocation. Maybe pass a chained iterator to slh_sign_internal?
let ctx_msg = [&[0], &ctx_len_bytes, ctx, msg].concat();
let ctx_msg = [&[0], &ctx_len_bytes, ctx, msg];
self.slh_verify_internal(&ctx_msg, signature) // TODO - context processing
}

Expand Down
2 changes: 1 addition & 1 deletion slh-dsa/tests/acvp_sig.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ macro_rules! parameter_case {
.additionalRandomness
.as_ref()
.map(|x| x.data.as_slice());
let sig = sk.slh_sign_internal($test_case.message.data.as_slice(), opt_rand);
let sig = sk.slh_sign_internal(&[$test_case.message.data.as_slice()], opt_rand);
assert_eq!(sig.to_vec(), $test_case.signature.data);
}};
}
Expand Down
2 changes: 1 addition & 1 deletion slh-dsa/tests/acvp_ver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ macro_rules! parameter_case {
($param:ident, $test_case:expr) => {{
let sk = VerifyingKey::<$param>::try_from($test_case.pk.data.as_slice()).unwrap();
if let Ok(sig) = $test_case.signature.data.as_slice().try_into() {
let success = sk.slh_verify_internal($test_case.message.data.as_slice(), &sig);
let success = sk.slh_verify_internal(&[$test_case.message.data.as_slice()], &sig);
assert_eq!($test_case.testPassed, success.is_ok());
} else {
assert!(!$test_case.testPassed);
Expand Down
2 changes: 1 addition & 1 deletion slh-dsa/tests/known_answer_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ where
let mut opt_rand = vec![0; P::VkLen::USIZE / 2];
rng.fill_bytes(opt_rand.as_mut());

let sig = sk.slh_sign_internal(msg, Some(&opt_rand)).to_bytes();
let sig = sk.slh_sign_internal(&[msg], Some(&opt_rand)).to_bytes();
writeln!(resp, "smlen = {}", sig.as_slice().len() + msg.len()).unwrap();
writeln!(
resp,
Expand Down

0 comments on commit 8f93676

Please sign in to comment.