diff --git a/ml-dsa/Cargo.toml b/ml-dsa/Cargo.toml index 96f9c670..65395b17 100644 --- a/ml-dsa/Cargo.toml +++ b/ml-dsa/Cargo.toml @@ -16,15 +16,17 @@ keywords = ["crypto", "signature"] [features] default = ["std"] -std = [] #["sha3/std"] +std = ["sha3/std"] zeroize = ["dep:zeroize"] +rand_core = ["signature/rand_core"] [dependencies] # TODO reset to something like: { version = "0.2.0-rc.9", features = ["extra-sizes"] } hybrid-array = { path = "../../hybrid-array/", features = ["extra-sizes"]} num-traits = "0.2.19" +rand = "0.8.5" sha3 = "0.10.8" -signature = { version = "2.3.0-pre.4", features = ["rand_core"] } +signature = "2.3.0-pre.4" zeroize = { version = "1.8.1", optional = true, default-features = false } [dev-dependencies] diff --git a/ml-dsa/src/lib.rs b/ml-dsa/src/lib.rs index 91f561aa..92ecba1a 100644 --- a/ml-dsa/src/lib.rs +++ b/ml-dsa/src/lib.rs @@ -26,7 +26,9 @@ mod util; // TODO(RLB) Move module to an independent crate shared with ml_kem mod module_lattice; +use core::convert::{AsRef, TryFrom, TryInto}; use hybrid_array::{typenum::*, Array}; +use rand::{CryptoRng, RngCore}; use crate::algebra::*; use crate::crypto::*; @@ -37,22 +39,19 @@ use crate::sampling::*; use crate::util::*; // TODO(RLB) Clean up this API -pub use crate::param::{ - EncodedSignature, EncodedSigningKey, EncodedVerificationKey, SignatureParams, SigningKeyParams, - VerificationKeyParams, -}; - +pub use crate::param::{EncodedSignature, EncodedSigningKey, EncodedVerifyingKey, MlDsaParams}; pub use crate::util::B32; +pub use signature::Error; /// An ML-DSA signature #[derive(Clone, PartialEq)] -pub struct Signature { +pub struct Signature { c_tilde: Array, z: Vector, h: Hint

, } -impl Signature

{ +impl Signature

{ // Algorithm 26 sigEncode pub fn encode(&self) -> EncodedSignature

{ let c_tilde = self.c_tilde.clone(); @@ -77,9 +76,62 @@ impl Signature

{ } } +impl<'a, P: MlDsaParams> TryFrom<&'a [u8]> for Signature

{ + type Error = Error; + + fn try_from(value: &'a [u8]) -> Result { + let enc = EncodedSignature::

::try_from(value).map_err(|_| Error::new())?; + Self::decode(&enc).ok_or(Error::new()) + } +} + +impl TryInto> for Signature

{ + type Error = Error; + + fn try_into(self) -> Result, Self::Error> { + Ok(self.encode()) + } +} + +impl signature::SignatureEncoding for Signature

{ + type Repr = EncodedSignature

; +} + +// This method takes a slice of slices so that we can accommodate the varying calculations (direct +// for test vectors, 0... for sign/sign_deterministic, 1... for the pre-hashed version) without +// having to allocate memory for components. +fn message_representative(tr: &[u8], Mp: &[&[u8]]) -> B64 { + let mut h = H::default().absorb(tr); + + for m in Mp { + h = h.absorb(m); + } + + h.squeeze_new() +} + +/// An ML-DSA key pair +pub struct KeyPair { + /// The signing key of the key pair + pub signing_key: SigningKey

, + + /// The verifying key of the key pair + pub verifying_key: VerifyingKey

, +} + +impl AsRef> for KeyPair

{ + fn as_ref(&self) -> &VerifyingKey

{ + &self.verifying_key + } +} + +impl signature::KeypairRef for KeyPair

{ + type VerifyingKey = VerifyingKey

; +} + /// An ML-DSA signing key #[derive(Clone, PartialEq)] -pub struct SigningKey { +pub struct SigningKey { rho: B32, K: B32, tr: B64, @@ -94,7 +146,7 @@ pub struct SigningKey { A_hat: NttMatrix, } -impl SigningKey

{ +impl SigningKey

{ fn new( rho: B32, K: B32, @@ -124,47 +176,16 @@ impl SigningKey

{ } } - /// Deterministically generate a signing key pair from the specified seed - pub fn key_gen_internal(xi: &B32) -> (VerificationKey

, SigningKey

) - where - P: SigningKeyParams + VerificationKeyParams, - { - // Derive seeds - let mut h = H::default() - .absorb(xi) - .absorb(&[P::K::U8]) - .absorb(&[P::L::U8]); - - let rho: B32 = h.squeeze_new(); - let rhop: B64 = h.squeeze_new(); - let K: B32 = h.squeeze_new(); - - // Sample private key components - let A_hat = expand_a::(&rho); - let s1 = expand_s::(&rhop, P::Eta::ETA, 0); - let s2 = expand_s::(&rhop, P::Eta::ETA, P::L::USIZE); - - // Compute derived values - let As1_hat = &A_hat * &s1.ntt(); - let t = &As1_hat.ntt_inverse() + &s2; - - // Compress and encode - let (t1, t0) = t.power2round(); - - let vk = VerificationKey::new(rho, t1, Some(A_hat.clone()), None); - let sk = Self::new(rho, K, vk.tr.clone(), s1, s2, t0, Some(A_hat)); - - (vk, sk) - } - // Algorithm 7 ML-DSA.Sign_internal - pub fn sign_internal(&self, Mp: &[u8], rnd: &B32) -> Signature

+ pub fn sign_internal(&self, Mp: &[&[u8]], rnd: &B32) -> Signature

where - P: SignatureParams, + P: MlDsaParams, { // Compute the message representative + // XXX(RLB): This line incorporates some of the logic from ML-DSA.sign to avoid computing + // the concatenated M'. // XXX(RLB) Should the API represent this as an input? - let mu: B64 = H::default().absorb(&self.tr).absorb(&Mp).squeeze_new(); + let mu = message_representative(&self.tr, Mp); // Compute the private random seed let rhopp: B64 = H::default() @@ -218,10 +239,39 @@ impl SigningKey

{ panic!("Rejection sampling failed to find a valid signature"); } + // Algorithm 2 ML-DSA.Sign + pub fn sign( + &self, + M: &[u8], + ctx: &[u8], + rng: &mut (impl CryptoRng + RngCore), + ) -> Result, Error> { + if ctx.len() > 255 { + return Err(Error::new()); + } + + let mut rnd = B32::default(); + rng.try_fill_bytes(&mut rnd).map_err(|_| Error::new())?; + + let Mp = &[&[0], &[Truncate::truncate(ctx.len())], ctx, M]; + Ok(self.sign_internal(Mp, &rnd)) + } + + // Algorithm 2 ML-DSA.Sign (optional deterministic variant) + pub fn sign_deterministic(&self, M: &[u8], ctx: &[u8]) -> Result, Error> { + if ctx.len() > 255 { + return Err(Error::new()); + } + + let rnd = B32::default(); + let Mp = &[&[0], &[Truncate::truncate(ctx.len())], ctx, M]; + Ok(self.sign_internal(Mp, &rnd)) + } + // Algorithm 24 skEncode pub fn encode(&self) -> EncodedSigningKey

where - P: SigningKeyParams, + P: MlDsaParams, { let s1_enc = P::encode_s1(&self.s1); let s2_enc = P::encode_s2(&self.s2); @@ -239,7 +289,7 @@ impl SigningKey

{ // Algorithm 25 skDecode pub fn decode(enc: &EncodedSigningKey

) -> Self where - P: SigningKeyParams, + P: MlDsaParams, { let (rho, K, tr, s1_enc, s2_enc, t0_enc) = P::split_sk(enc); Self::new( @@ -254,9 +304,31 @@ impl SigningKey

{ } } +/// The Signer implementation for SigningKey uses the optional deterministic variant of ML-DSA, and +/// only supports signing with an empty context string. If you would like to include a context +/// string, use the [`SigningKey::sign_deterministic`] method. +impl signature::Signer> for SigningKey

{ + fn try_sign(&self, msg: &[u8]) -> Result, Error> { + self.sign_deterministic(msg, &[]) + } +} + +/// The RandomizedSigner implementation for SigningKey only supports signing with an empty context +/// string. If you would like to include a context string, use the [`SigningKey::sign`] method. +#[cfg(feature = "rand_core")] +impl signature::RandomizedSigner> for SigningKey

{ + fn try_sign_with_rng( + &self, + rng: &mut impl CryptoRngCore, + msg: &[u8], + ) -> Result, Error> { + self.sign(msg, &[], rng) + } +} + /// An ML-DSA verification key #[derive(Clone, PartialEq)] -pub struct VerificationKey { +pub struct VerifyingKey { rho: B32, t1: Vector, @@ -266,13 +338,35 @@ pub struct VerificationKey { tr: B64, } -impl VerificationKey

{ - pub fn verify_internal(&self, Mp: &[u8], sigma: &Signature

) -> bool +impl VerifyingKey

{ + fn new( + rho: B32, + t1: Vector, + A_hat: Option>, + enc: Option>, + ) -> Self { + let A_hat = A_hat.unwrap_or_else(|| expand_a(&rho)); + let enc = enc.unwrap_or_else(|| Self::encode_internal(&rho, &t1)); + + let t1_2d_hat = (Elem::new(1 << 13) * &t1).ntt(); + let tr: B64 = H::default().absorb(&enc).squeeze_new(); + + Self { + rho, + t1, + A_hat, + t1_2d_hat, + tr, + } + } + + // Algorithm 8 ML-DSA.Verify_internal + pub fn verify_internal(&self, Mp: &[&[u8]], sigma: &Signature

) -> bool where - P: SignatureParams, + P: MlDsaParams, { // Compute the message representative - let mu: B64 = H::default().absorb(&self.tr).absorb(&Mp).squeeze_new(); + let mu = message_representative(&self.tr, Mp); // Reconstruct w let c = sample_in_ball(&sigma.c_tilde, P::TAU); @@ -294,45 +388,41 @@ impl VerificationKey

{ sigma.c_tilde == cp_tilde } - fn encode_internal(rho: &B32, t1: &Vector) -> EncodedVerificationKey

{ - let t1_enc = P::encode_t1(t1); - P::concat_vk(rho.clone(), t1_enc) - } - - fn new( - rho: B32, - t1: Vector, - A_hat: Option>, - enc: Option>, - ) -> Self { - let A_hat = A_hat.unwrap_or_else(|| expand_a(&rho)); - let enc = enc.unwrap_or_else(|| Self::encode_internal(&rho, &t1)); + pub fn verify(&self, M: &[u8], ctx: &[u8], sigma: &Signature

) -> bool { + if ctx.len() > 255 { + return false; + } - let t1_2d_hat = (Elem::new(1 << 13) * &t1).ntt(); - let tr: B64 = H::default().absorb(&enc).squeeze_new(); + let Mp = &[&[0], &[Truncate::truncate(ctx.len())], ctx, M]; + return self.verify_internal(Mp, sigma); + } - Self { - rho, - t1, - A_hat, - t1_2d_hat, - tr, - } + fn encode_internal(rho: &B32, t1: &Vector) -> EncodedVerifyingKey

{ + let t1_enc = P::encode_t1(t1); + P::concat_vk(rho.clone(), t1_enc) } // Algorithm 22 pkEncode - pub fn encode(&self) -> EncodedVerificationKey

{ + pub fn encode(&self) -> EncodedVerifyingKey

{ Self::encode_internal(&self.rho, &self.t1) } // Algorithm 23 pkDecode - pub fn decode(enc: &EncodedVerificationKey

) -> Self { + pub fn decode(enc: &EncodedVerifyingKey

) -> Self { let (rho, t1_enc) = P::split_vk(enc); let t1 = P::decode_t1(t1_enc); Self::new(rho.clone(), t1, None, Some(enc.clone())) } } +impl signature::Verifier> for VerifyingKey

{ + fn verify(&self, msg: &[u8], signature: &Signature

) -> Result<(), Error> { + VerifyingKey::verify(self, msg, &[], signature) + .then_some(()) + .ok_or(Error::new()) + } +} + /// `MlDsa44` is the parameter set for security category 2. #[derive(Default, Clone, Debug, PartialEq)] pub struct MlDsa44; @@ -381,10 +471,73 @@ impl ParameterSet for MlDsa87 { const TAU: usize = 60; } +/// A parameter set that knows how to generate key pairs +pub trait KeyGen: MlDsaParams { + type KeyPair: signature::Keypair; + + /// Generate a signing key pair from the specified RNG + fn key_gen(rng: &mut (impl CryptoRng + RngCore)) -> Self::KeyPair; + + /// Deterministically generate a signing key pair from the specified seed + fn key_gen_internal(xi: &B32) -> Self::KeyPair; +} + +impl

KeyGen for P +where + P: MlDsaParams, +{ + type KeyPair = KeyPair

; + + /// Generate a signing key pair from the specified RNG + // Algorithm 1 ML-DSA.KeyGen() + fn key_gen(rng: &mut (impl CryptoRng + RngCore)) -> KeyPair

{ + let mut xi = B32::default(); + rng.fill_bytes(&mut xi); + Self::key_gen_internal(&xi) + } + + /// Deterministically generate a signing key pair from the specified seed + // Algorithm 6 ML-DSA.KeyGen_internal + fn key_gen_internal(xi: &B32) -> KeyPair

+ where + P: MlDsaParams, + { + // Derive seeds + let mut h = H::default() + .absorb(xi) + .absorb(&[P::K::U8]) + .absorb(&[P::L::U8]); + + let rho: B32 = h.squeeze_new(); + let rhop: B64 = h.squeeze_new(); + let K: B32 = h.squeeze_new(); + + // Sample private key components + let A_hat = expand_a::(&rho); + let s1 = expand_s::(&rhop, P::Eta::ETA, 0); + let s2 = expand_s::(&rhop, P::Eta::ETA, P::L::USIZE); + + // Compute derived values + let As1_hat = &A_hat * &s1.ntt(); + let t = &As1_hat.ntt_inverse() + &s2; + + // Compress and encode + let (t1, t0) = t.power2round(); + + let verifying_key = VerifyingKey::new(rho, t1, Some(A_hat.clone()), None); + let signing_key = + SigningKey::new(rho, K, verifying_key.tr.clone(), s1, s2, t0, Some(A_hat)); + + KeyPair { + signing_key, + verifying_key, + } + } +} + #[cfg(test)] mod test { use super::*; - use rand::Rng; #[test] fn output_sizes() { @@ -393,36 +546,37 @@ mod test { // ML-DSA-65 4032 1952 3309 // ML-DSA-87 4896 2592 4627 assert_eq!(SigningKeySize::::USIZE, 2560); - assert_eq!(VerificationKeySize::::USIZE, 1312); + assert_eq!(VerifyingKeySize::::USIZE, 1312); assert_eq!(SignatureSize::::USIZE, 2420); assert_eq!(SigningKeySize::::USIZE, 4032); - assert_eq!(VerificationKeySize::::USIZE, 1952); + assert_eq!(VerifyingKeySize::::USIZE, 1952); assert_eq!(SignatureSize::::USIZE, 3309); assert_eq!(SigningKeySize::::USIZE, 4896); - assert_eq!(VerificationKeySize::::USIZE, 2592); + assert_eq!(VerifyingKeySize::::USIZE, 2592); assert_eq!(SignatureSize::::USIZE, 4627); } fn encode_decode_round_trip_test

() where - P: SigningKeyParams + VerificationKeyParams + SignatureParams + PartialEq, + P: MlDsaParams + PartialEq, { - let mut rng = rand::thread_rng(); + let kp = P::key_gen(&mut rand::thread_rng()); + let sk = kp.signing_key; + let vk = kp.verifying_key; - let seed: [u8; 32] = rng.gen(); - let (pk, sk) = SigningKey::

::key_gen_internal(&seed.into()); - - let pk_bytes = pk.encode(); - let pk2 = VerificationKey::

::decode(&pk_bytes); - assert!(pk == pk2); + let vk_bytes = vk.encode(); + let vk2 = VerifyingKey::

::decode(&vk_bytes); + assert!(vk == vk2); let sk_bytes = sk.encode(); let sk2 = SigningKey::

::decode(&sk_bytes); assert!(sk == sk2); - let sig = sk.sign_internal(&[0, 1, 2, 3], (&[0u8; 32]).into()); + let M = b"Hello world"; + let rnd = Array([0u8; 32]); + let sig = sk.sign_internal(&[M], &rnd); let sig_bytes = sig.encode(); let sig2 = Signature::

::decode(&sig_bytes).unwrap(); assert!(sig == sig2); @@ -437,18 +591,18 @@ mod test { fn sign_verify_round_trip_test

() where - P: SigningKeyParams + VerificationKeyParams + SignatureParams, + P: MlDsaParams, { let mut rng = rand::thread_rng(); + let kp = P::key_gen(&mut rng); + let sk = kp.signing_key; + let vk = kp.verifying_key; - let seed: [u8; 32] = rng.gen(); - let (pk, sk) = SigningKey::

::key_gen_internal(&seed.into()); - - let rnd: [u8; 32] = rng.gen(); - let Mp = b"Hello world"; - let sig = sk.sign_internal(Mp, &rnd.into()); + let M = b"Hello world"; + let rnd = Array([0u8; 32]); + let sig = sk.sign_internal(&[M], &rnd); - assert!(pk.verify_internal(Mp, &sig)); + assert!(vk.verify_internal(&[M], &sig)); } #[test] diff --git a/ml-dsa/src/param.rs b/ml-dsa/src/param.rs index 8671bfa7..200e3981 100644 --- a/ml-dsa/src/param.rs +++ b/ml-dsa/src/param.rs @@ -10,6 +10,7 @@ //! know any details about object sizes. For example, `VectorEncodingSize::flatten` needs to know //! that the size of an encoded vector is `K` times the size of an encoded polynomial. +use core::fmt::Debug; use core::ops::{Add, Div, Mul, Rem, Sub}; use crate::module_lattice::encode::*; @@ -255,35 +256,35 @@ where } } -pub trait VerificationKeyParams: ParameterSet { +pub trait VerifyingKeyParams: ParameterSet { type T1Size: ArraySize; - type VerificationKeySize: ArraySize; + type VerifyingKeySize: ArraySize; fn encode_t1(t1: &Vector) -> EncodedT1; fn decode_t1(enc: &EncodedT1) -> Vector; - fn concat_vk(rho: B32, t1: EncodedT1) -> EncodedVerificationKey; - fn split_vk(enc: &EncodedVerificationKey) -> (&B32, &EncodedT1); + fn concat_vk(rho: B32, t1: EncodedT1) -> EncodedVerifyingKey; + fn split_vk(enc: &EncodedVerifyingKey) -> (&B32, &EncodedT1); } -pub type VerificationKeySize

=

::VerificationKeySize; +pub type VerifyingKeySize

=

::VerifyingKeySize; -pub type EncodedT1

= Array::T1Size>; -pub type EncodedVerificationKey

= Array>; +pub type EncodedT1

= Array::T1Size>; +pub type EncodedVerifyingKey

= Array>; -impl

VerificationKeyParams for P +impl

VerifyingKeyParams for P where P: ParameterSet, // T1 encoding rules U320: Mul, Prod: ArraySize + Div + Rem, - // Verification key encoding rules + // Verifying key encoding rules U32: Add>, Sum: ArraySize, Sum>: ArraySize + Sub>, { type T1Size = EncodedVectorSize; - type VerificationKeySize = Sum; + type VerifyingKeySize = Sum; fn encode_t1(t1: &Vector) -> EncodedT1 { Encode::::encode(t1) @@ -293,11 +294,11 @@ where Encode::::decode(enc) } - fn concat_vk(rho: B32, t1: EncodedT1) -> EncodedVerificationKey { + fn concat_vk(rho: B32, t1: EncodedT1) -> EncodedVerifyingKey { rho.concat(t1) } - fn split_vk(enc: &EncodedVerificationKey) -> (&B32, &EncodedT1) { + fn split_vk(enc: &EncodedVerifyingKey) -> (&B32, &EncodedT1) { enc.split_ref() } } @@ -417,3 +418,19 @@ where (c_tilde, z, h) } } + +pub trait MlDsaParams: + SigningKeyParams + VerifyingKeyParams + SignatureParams + Debug + Default + PartialEq + Clone +{ +} + +impl MlDsaParams for T where + T: SigningKeyParams + + VerifyingKeyParams + + SignatureParams + + Debug + + Default + + PartialEq + + Clone +{ +} diff --git a/ml-dsa/tests/key-gen.rs b/ml-dsa/tests/key-gen.rs index 811ea5b8..f1832c5e 100644 --- a/ml-dsa/tests/key-gen.rs +++ b/ml-dsa/tests/key-gen.rs @@ -25,21 +25,23 @@ fn acvp_key_gen() { } } -fn verify(tc: &acvp::TestCase) { +fn verify(tc: &acvp::TestCase) { // Import test data into the relevant array structures let seed = Array::try_from(tc.seed.as_slice()).unwrap(); - let pk_bytes = EncodedVerificationKey::

::try_from(tc.pk.as_slice()).unwrap(); + let vk_bytes = EncodedVerifyingKey::

::try_from(tc.pk.as_slice()).unwrap(); let sk_bytes = EncodedSigningKey::

::try_from(tc.sk.as_slice()).unwrap(); - let (pk, sk) = SigningKey::

::key_gen_internal(&seed); + let kp = P::key_gen_internal(&seed); + let sk = kp.signing_key; + let vk = kp.verifying_key; // Verify correctness via serialization - assert_eq!(pk.encode(), pk_bytes); assert_eq!(sk.encode(), sk_bytes); + assert_eq!(vk.encode(), vk_bytes); // Verify correctness via deserialization - assert!(pk == VerificationKey::

::decode(&pk_bytes)); assert!(sk == SigningKey::

::decode(&sk_bytes)); + assert!(vk == VerifyingKey::

::decode(&vk_bytes)); } mod acvp { diff --git a/ml-dsa/tests/sig-gen.rs b/ml-dsa/tests/sig-gen.rs index e0469546..a5f2c476 100644 --- a/ml-dsa/tests/sig-gen.rs +++ b/ml-dsa/tests/sig-gen.rs @@ -30,14 +30,14 @@ fn acvp_sig_gen() { } } -fn verify(tc: &acvp::TestCase) { +fn verify(tc: &acvp::TestCase) { // Import the signing key let sk_bytes = EncodedSigningKey::

::try_from(tc.sk.as_slice()).unwrap(); let sk = SigningKey::

::decode(&sk_bytes); // Verify correctness let rnd = B32::try_from(tc.rnd.as_slice()).unwrap(); - let sig = sk.sign_internal(&tc.message, &rnd); + let sig = sk.sign_internal(&[&tc.message], &rnd); let sig_bytes = sig.encode(); assert_eq!(tc.signature.as_slice(), sig_bytes.as_slice()); diff --git a/ml-dsa/tests/sig-ver.rs b/ml-dsa/tests/sig-ver.rs index 787a743f..aa964290 100644 --- a/ml-dsa/tests/sig-ver.rs +++ b/ml-dsa/tests/sig-ver.rs @@ -24,10 +24,10 @@ fn acvp_sig_ver() { } } -fn verify(tg: &acvp::TestGroup, tc: &acvp::TestCase) { +fn verify(tg: &acvp::TestGroup, tc: &acvp::TestCase) { // Import the verification key - let pk_bytes = EncodedVerificationKey::

::try_from(tg.pk.as_slice()).unwrap(); - let pk = VerificationKey::

::decode(&pk_bytes); + let vk_bytes = EncodedVerifyingKey::

::try_from(tg.pk.as_slice()).unwrap(); + let vk = VerifyingKey::

::decode(&vk_bytes); // Import the signature let sig_bytes = EncodedSignature::

::try_from(tc.signature.as_slice()).unwrap(); @@ -35,7 +35,7 @@ fn verify(tg: &acvp::TestGroup, tc: // Verify the signature if it successfully decoded let test_passed = sig - .map(|sig| pk.verify_internal(tc.message.as_slice(), &sig)) + .and_then(|sig| Some(vk.verify_internal(&[&tc.message], &sig))) .unwrap_or_default(); assert_eq!(test_passed, tc.test_passed); }