Skip to content

Commit

Permalink
Make struct names a bit shorter
Browse files Browse the repository at this point in the history
  • Loading branch information
bifurcation committed Dec 4, 2024
1 parent 3061383 commit 7d347c0
Show file tree
Hide file tree
Showing 8 changed files with 156 additions and 156 deletions.
24 changes: 12 additions & 12 deletions ml-dsa/src/algebra.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@ define_field!(BaseField, u32, u64, u128, 8380417);

pub type Int = <BaseField as Field>::Int;

pub type FieldElement = algebra::Elem<BaseField>;
pub type Elem = algebra::Elem<BaseField>;
pub type Polynomial = algebra::Polynomial<BaseField>;
pub type PolynomialVector<K> = algebra::PolynomialVector<BaseField, K>;
pub type Vector<K> = algebra::Vector<BaseField, K>;
pub type NttPolynomial = algebra::NttPolynomial<BaseField>;
pub type NttVector<K> = algebra::NttVector<BaseField, K>;
pub type NttMatrix<K, L> = algebra::NttMatrix<BaseField, K, L>;
Expand Down Expand Up @@ -46,17 +46,17 @@ where
}

pub trait Decompose {
fn decompose<TwoGamma2: Unsigned>(self) -> (FieldElement, FieldElement);
fn decompose<TwoGamma2: Unsigned>(self) -> (Elem, Elem);
}

impl Decompose for FieldElement {
impl Decompose for Elem {
// Algorithm 36 Decompose
fn decompose<TwoGamma2: Unsigned>(self) -> (FieldElement, FieldElement) {
fn decompose<TwoGamma2: Unsigned>(self) -> (Elem, Elem) {
let r_plus = self.clone();
let r0 = r_plus.mod_plus_minus::<TwoGamma2>();

if r_plus - r0 == FieldElement::new(BaseField::Q - 1) {
(FieldElement::new(0), r0 - FieldElement::new(1))
if r_plus - r0 == Elem::new(BaseField::Q - 1) {
(Elem::new(0), r0 - Elem::new(1))
} else {
let mut r1 = r_plus - r0;
r1.0 /= TwoGamma2::U32;
Expand All @@ -73,13 +73,13 @@ pub trait AlgebraExt: Sized {
fn low_bits<TwoGamma2: Unsigned>(&self) -> Self;
}

impl AlgebraExt for FieldElement {
impl AlgebraExt for Elem {
fn mod_plus_minus<M: Unsigned>(&self) -> Self {
let raw_mod = FieldElement::new(M::reduce(self.0));
let raw_mod = Elem::new(M::reduce(self.0));
if raw_mod.0 <= M::U32 >> 1 {
raw_mod
} else {
raw_mod - FieldElement::new(M::U32)
raw_mod - Elem::new(M::U32)
}
}

Expand Down Expand Up @@ -111,7 +111,7 @@ impl AlgebraExt for FieldElement {

let r_plus = self.clone();
let r0 = r_plus.mod_plus_minus::<Pow2D>();
let r1 = FieldElement::new((r_plus - r0).0 >> D::USIZE);
let r1 = Elem::new((r_plus - r0).0 >> D::USIZE);

(r1, r0)
}
Expand Down Expand Up @@ -156,7 +156,7 @@ impl AlgebraExt for Polynomial {
}
}

impl<K: ArraySize> AlgebraExt for PolynomialVector<K> {
impl<K: ArraySize> AlgebraExt for Vector<K> {
fn mod_plus_minus<M: Unsigned>(&self) -> Self {
Self(self.0.iter().map(|x| x.mod_plus_minus::<M>()).collect())
}
Expand Down
56 changes: 28 additions & 28 deletions ml-dsa/src/encode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,10 @@ pub type RangeEncodingBits<A, B> = <(A, B) as RangeEncodingSize>::EncodingSize;
pub type RangeEncodedPolynomialSize<A, B> =
<RangeEncodingBits<A, B> as EncodingSize>::EncodedPolynomialSize;
pub type RangeEncodedPolynomial<A, B> = Array<u8, RangeEncodedPolynomialSize<A, B>>;
pub type RangeEncodedPolynomialVectorSize<A, B, K> =
<RangeEncodingBits<A, B> as VectorEncodingSize<K>>::EncodedPolynomialVectorSize;
pub type RangeEncodedPolynomialVector<A, B, K> =
Array<u8, RangeEncodedPolynomialVectorSize<A, B, K>>;
pub type RangeEncodedVectorSize<A, B, K> =
<RangeEncodingBits<A, B> as VectorEncodingSize<K>>::EncodedVectorSize;
pub type RangeEncodedVector<A, B, K> =
Array<u8, RangeEncodedVectorSize<A, B, K>>;

/// BitPack
pub trait BitPack<A, B> {
Expand All @@ -49,8 +49,8 @@ where

// Algorithm 17 BitPack
fn pack(&self) -> RangeEncodedPolynomial<A, B> {
let a = FieldElement::new(RangeMin::<A, B>::U32);
let b = FieldElement::new(RangeMax::<A, B>::U32);
let a = Elem::new(RangeMin::<A, B>::U32);
let b = Elem::new(RangeMax::<A, B>::U32);

let to_encode = Self::new(
self.0
Expand All @@ -66,8 +66,8 @@ where

// Algorithm 17 BitUnPack
fn unpack(enc: &RangeEncodedPolynomial<A, B>) -> Self {
let a = FieldElement::new(RangeMin::<A, B>::U32);
let b = FieldElement::new(RangeMax::<A, B>::U32);
let a = Elem::new(RangeMin::<A, B>::U32);
let b = Elem::new(RangeMax::<A, B>::U32);
let mut decoded: Self = Encode::<RangeEncodingBits<A, B>>::decode(enc);

for z in decoded.0.iter_mut() {
Expand All @@ -79,20 +79,20 @@ where
}
}

impl<K, A, B> BitPack<A, B> for PolynomialVector<K>
impl<K, A, B> BitPack<A, B> for Vector<K>
where
K: ArraySize,
(A, B): RangeEncodingSize,
RangeEncodingBits<A, B>: VectorEncodingSize<K>,
{
type PackedSize = RangeEncodedPolynomialVectorSize<A, B, K>;
type PackedSize = RangeEncodedVectorSize<A, B, K>;

fn pack(&self) -> RangeEncodedPolynomialVector<A, B, K> {
fn pack(&self) -> RangeEncodedVector<A, B, K> {
let polys = self.0.iter().map(|x| BitPack::<A, B>::pack(x)).collect();
RangeEncodingBits::<A, B>::flatten(polys)
}

fn unpack(enc: &RangeEncodedPolynomialVector<A, B, K>) -> Self {
fn unpack(enc: &RangeEncodedVector<A, B, K>) -> Self {
let unfold = RangeEncodingBits::<A, B>::unflatten(enc);
Self(
unfold
Expand Down Expand Up @@ -146,7 +146,7 @@ pub(crate) mod test {
let mut rng = rand::thread_rng();
let decoded = Polynomial::new(Array::from_fn(|_| {
let x: u32 = rng.gen();
FieldElement::new(x % (b + 1))
Elem::new(x % (b + 1))
}));

let actual_encoded = Encode::<D>::encode(&decoded);
Expand All @@ -162,14 +162,14 @@ pub(crate) mod test {
// Use a standard test pattern across all the cases
let decoded = Polynomial::new(
Array::<_, U8>([
FieldElement::new(0),
FieldElement::new(1),
FieldElement::new(2),
FieldElement::new(3),
FieldElement::new(4),
FieldElement::new(5),
FieldElement::new(6),
FieldElement::new(7),
Elem::new(0),
Elem::new(1),
Elem::new(2),
Elem::new(3),
Elem::new(4),
Elem::new(5),
Elem::new(6),
Elem::new(7),
])
.repeat(),
);
Expand Down Expand Up @@ -205,8 +205,8 @@ pub(crate) mod test {
B: Unsigned,
(A, B): RangeEncodingSize,
{
let a = FieldElement::new(A::U32);
let b = FieldElement::new(B::U32);
let a = Elem::new(A::U32);
let b = Elem::new(B::U32);

// Test known answer
let actual_encoded = BitPack::<A, B>::pack(decoded);
Expand All @@ -220,7 +220,7 @@ pub(crate) mod test {
let decoded = Polynomial::new(Array::from_fn(|_| {
let mut x: u32 = rng.gen();
x = x % (a.0 + b.0);
b - FieldElement::new(x)
b - Elem::new(x)
}));

let actual_encoded = BitPack::<A, B>::pack(&decoded);
Expand All @@ -237,10 +237,10 @@ pub(crate) mod test {
// (We can't use -2 because the eta=2 case doesn't actually cover -2)
let decoded = Polynomial::new(
Array::<_, U4>([
FieldElement::new(BaseField::Q - 1),
FieldElement::new(0),
FieldElement::new(1),
FieldElement::new(2),
Elem::new(BaseField::Q - 1),
Elem::new(0),
Elem::new(1),
Elem::new(2),
])
.repeat(),
);
Expand Down
16 changes: 8 additions & 8 deletions ml-dsa/src/hint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,22 @@ use hybrid_array::{
use crate::algebra::*;
use crate::param::*;

fn make_hint<TwoGamma2: Unsigned>(z: FieldElement, r: FieldElement) -> bool {
fn make_hint<TwoGamma2: Unsigned>(z: Elem, r: Elem) -> bool {
let r1 = r.high_bits::<TwoGamma2>();
let v1 = (r + z).high_bits::<TwoGamma2>();
r1 != v1
}

fn use_hint<TwoGamma2: Unsigned>(h: bool, r: FieldElement) -> FieldElement {
fn use_hint<TwoGamma2: Unsigned>(h: bool, r: Elem) -> Elem {
let m: u32 = (BaseField::Q - 1) / TwoGamma2::U32;
let (r1, r0) = r.decompose::<TwoGamma2>();
let gamma2 = TwoGamma2::U32 / 2;
if h && r0.0 <= gamma2 {
FieldElement::new((r1.0 + 1) % m)
Elem::new((r1.0 + 1) % m)
} else if h && r0.0 > BaseField::Q - gamma2 {
FieldElement::new((r1.0 + m - 1) % m)
Elem::new((r1.0 + m - 1) % m)
} else if h {
// We use the FieldElement encoding even for signed integers. Since r0 is computed
// We use the Elem encoding even for signed integers. Since r0 is computed
// mod+- 2*gamma2, it is guaranteed to be in (gamma2, gamma2].
unreachable!();
} else {
Expand All @@ -48,7 +48,7 @@ impl<P> Hint<P>
where
P: SignatureParams,
{
pub fn new(z: PolynomialVector<P::K>, r: PolynomialVector<P::K>) -> Self {
pub fn new(z: Vector<P::K>, r: Vector<P::K>) -> Self {
let zi = z.0.iter();
let ri = r.0.iter();

Expand All @@ -73,11 +73,11 @@ where
.sum()
}

pub fn use_hint(&self, r: &PolynomialVector<P::K>) -> PolynomialVector<P::K> {
pub fn use_hint(&self, r: &Vector<P::K>) -> Vector<P::K> {
let hi = self.0.iter();
let ri = r.0.iter();

PolynomialVector::new(
Vector::new(
hi.zip(ri)
.map(|(hv, rv)| {
let hvi = hv.iter();
Expand Down
22 changes: 11 additions & 11 deletions ml-dsa/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ pub use crate::util::B32;
#[derive(Clone, PartialEq)]
pub struct Signature<P: SignatureParams> {
c_tilde: Array<u8, P::Lambda>,
z: PolynomialVector<P::L>,
z: Vector<P::L>,
h: Hint<P>,
}

Expand Down Expand Up @@ -83,9 +83,9 @@ pub struct SigningKey<P: ParameterSet> {
rho: B32,
K: B32,
tr: B64,
s1: PolynomialVector<P::L>,
s2: PolynomialVector<P::K>,
t0: PolynomialVector<P::K>,
s1: Vector<P::L>,
s2: Vector<P::K>,
t0: Vector<P::K>,

// Derived values
s1_hat: NttVector<P::L>,
Expand All @@ -99,9 +99,9 @@ impl<P: ParameterSet> SigningKey<P> {
rho: B32,
K: B32,
tr: B64,
s1: PolynomialVector<P::L>,
s2: PolynomialVector<P::K>,
t0: PolynomialVector<P::K>,
s1: Vector<P::L>,
s2: Vector<P::K>,
t0: Vector<P::K>,
A_hat: Option<NttMatrix<P::K, P::L>>,
) -> Self {
let A_hat = A_hat.unwrap_or_else(|| expand_a(&rho));
Expand Down Expand Up @@ -258,7 +258,7 @@ impl<P: ParameterSet> SigningKey<P> {
#[derive(Clone, PartialEq)]
pub struct VerificationKey<P: ParameterSet> {
rho: B32,
t1: PolynomialVector<P::K>,
t1: Vector<P::K>,

// Derived values
A_hat: NttMatrix<P::K, P::L>,
Expand Down Expand Up @@ -294,21 +294,21 @@ impl<P: VerificationKeyParams> VerificationKey<P> {
sigma.c_tilde == cp_tilde
}

fn encode_internal(rho: &B32, t1: &PolynomialVector<P::K>) -> EncodedVerificationKey<P> {
fn encode_internal(rho: &B32, t1: &Vector<P::K>) -> EncodedVerificationKey<P> {
let t1_enc = P::encode_t1(t1);
P::concat_vk(rho.clone(), t1_enc)
}

fn new(
rho: B32,
t1: PolynomialVector<P::K>,
t1: Vector<P::K>,
A_hat: Option<NttMatrix<P::K, P::L>>,
enc: Option<EncodedVerificationKey<P>>,
) -> Self {
let A_hat = A_hat.unwrap_or_else(|| expand_a(&rho));
let enc = enc.unwrap_or_else(|| Self::encode_internal(&rho, &t1));

let t1_2d_hat = (FieldElement::new(1 << 13) * &t1).ntt();
let t1_2d_hat = (Elem::new(1 << 13) * &t1).ntt();
let tr: B64 = H::default().absorb(&enc).squeeze_new();

Self {
Expand Down
Loading

0 comments on commit 7d347c0

Please sign in to comment.