Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
198 changes: 81 additions & 117 deletions crates/fhe-math/src/rq/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ pub mod scaler;
pub mod switcher;
pub mod traits;
use self::{scaler::Scaler, switcher::Switcher, traits::TryConvertFrom};
use crate::{Error, Result};
use crate::{zq::Modulus, Error, Result};
pub use context::Context;
use fhe_util::sample_vec_cbd;
use itertools::{izip, Itertools};
Expand Down Expand Up @@ -158,37 +158,28 @@ impl Poly {

/// Change the representation of the underlying polynomial.
pub fn change_representation(&mut self, to: Representation) {
match self.representation {
Representation::PowerBasis => {
match to {
Representation::Ntt => self.ntt_forward(),
Representation::NttShoup => {
self.ntt_forward();
self.compute_coefficients_shoup();
}
Representation::PowerBasis => {} // no-op
}
if self.representation == to {
return;
}

match (&self.representation, &to) {
(Representation::PowerBasis, Representation::Ntt) => self.ntt_forward(),
(Representation::PowerBasis, Representation::NttShoup) => {
self.ntt_forward();
self.compute_coefficients_shoup()
}
Representation::Ntt => {
match to {
Representation::PowerBasis => self.ntt_backward(),
Representation::NttShoup => self.compute_coefficients_shoup(),
Representation::Ntt => {} // no-op
}
(Representation::Ntt, Representation::PowerBasis) => self.ntt_backward(),
(Representation::Ntt, Representation::NttShoup) => self.compute_coefficients_shoup(),
(Representation::NttShoup, Representation::PowerBasis) => {
self.zeroize_shoup();
self.coefficients_shoup = None;
self.ntt_backward()
}
Representation::NttShoup => {
if to != Representation::NttShoup {
// We are not sure whether this polynomial was sensitive or not,
// so for security, we zeroize the Shoup coefficients.
self.zeroize_shoup();
self.coefficients_shoup = None
}
match to {
Representation::PowerBasis => self.ntt_backward(),
Representation::Ntt => {} // no-op
Representation::NttShoup => {} // no-op
}
(Representation::NttShoup, Representation::Ntt) => {
self.zeroize_shoup();
self.coefficients_shoup = None;
}
_ => unreachable!(),
}

self.representation = to;
Expand Down Expand Up @@ -283,25 +274,24 @@ impl Poly {
rng: &mut T,
) -> Result<Self> {
if !(1..=16).contains(&variance) {
Err(Error::Default(
return Err(Error::Default(
"The variance should be an integer between 1 and 16".to_string(),
))
} else {
let coeffs = Zeroizing::new(
sample_vec_cbd(ctx.degree, variance, rng)
.map_err(|e| Error::Default(e.to_string()))?,
);
let mut p = Poly::try_convert_from(
coeffs.as_ref() as &[i64],
ctx,
false,
Representation::PowerBasis,
)?;
if representation != Representation::PowerBasis {
p.change_representation(representation);
}
Ok(p)
));
}

let coeffs = Zeroizing::new(
sample_vec_cbd(ctx.degree, variance, rng).map_err(|e| Error::Default(e.to_string()))?,
);
let mut p = Poly::try_convert_from(
coeffs.as_ref() as &[i64],
ctx,
false,
Representation::PowerBasis,
)?;
if representation != Representation::PowerBasis {
p.change_representation(representation);
}
Ok(p)
}

/// Access the polynomial coefficients in RNS representation.
Expand Down Expand Up @@ -341,18 +331,7 @@ impl Poly {
unsafe { q.allow_variable_time_computations() }
}
match self.representation {
Representation::Ntt => {
izip!(
q.coefficients.outer_iter_mut(),
self.coefficients.outer_iter()
)
.for_each(|(mut q_row, p_row)| {
for (j, k) in izip!(self.ctx.bitrev.iter(), i.power_bitrev.iter()) {
q_row[*j] = p_row[*k]
}
});
}
Representation::NttShoup => {
Representation::Ntt | Representation::NttShoup => {
izip!(
q.coefficients.outer_iter_mut(),
self.coefficients.outer_iter()
Expand All @@ -362,15 +341,17 @@ impl Poly {
q_row[*j] = p_row[*k]
}
});
izip!(
q.coefficients_shoup.as_mut().unwrap().outer_iter_mut(),
self.coefficients_shoup.as_ref().unwrap().outer_iter()
)
.for_each(|(mut q_row, p_row)| {
for (j, k) in izip!(self.ctx.bitrev.iter(), i.power_bitrev.iter()) {
q_row[*j] = p_row[*k]
}
});
if self.representation == Representation::NttShoup {
izip!(
q.coefficients_shoup.as_mut().unwrap().outer_iter_mut(),
self.coefficients_shoup.as_ref().unwrap().outer_iter()
)
.for_each(|(mut q_row, p_row)| {
for (j, k) in izip!(self.ctx.bitrev.iter(), i.power_bitrev.iter()) {
q_row[*j] = p_row[*k]
}
});
}
}
Representation::PowerBasis => {
let mut power = 0usize;
Expand Down Expand Up @@ -455,57 +436,40 @@ impl Poly {
let (mut q_new_polys, mut q_last_poly) =
self.coefficients.view_mut().split_at(Axis(0), q_len - 1);

if self.allow_variable_time_computations {
unsafe {
q_last_poly
.iter_mut()
.for_each(|coeff| *coeff = q_last.add_vt(*coeff, q_last_div_2));
izip!(
q_new_polys.outer_iter_mut(),
self.ctx.q.iter(),
self.ctx.inv_last_qi_mod_qj.iter(),
self.ctx.inv_last_qi_mod_qj_shoup.iter(),
)
.for_each(|(coeffs, qi, inv, inv_shoup)| {
let q_last_div_2_mod_qi = **qi - qi.reduce_vt(q_last_div_2); // Up to qi.modulus()
for (coeff, q_last_coeff) in izip!(coeffs, q_last_poly.iter()) {
// (x mod q_last - q_L/2) mod q_i
let tmp = qi.lazy_reduce(*q_last_coeff) + q_last_div_2_mod_qi; // Up to 3 * qi.modulus()

// ((x mod q_i) - (x mod q_last) + (q_L/2 mod q_i)) mod q_i
// = (x - x mod q_last + q_L/2) mod q_i
*coeff += 3 * (**qi) - tmp; // Up to 4 * qi.modulus()

// q_last^{-1} * (x - x mod q_last) mod q_i
*coeff = qi.mul_shoup(*coeff, *inv, *inv_shoup);
}
});
}
let add: fn(&Modulus, u64, u64) -> u64 = if self.allow_variable_time_computations {
|qi, a, b| unsafe { qi.add_vt(a, b) }
} else {
q_last_poly
.iter_mut()
.for_each(|coeff| *coeff = q_last.add(*coeff, q_last_div_2));
izip!(
q_new_polys.outer_iter_mut(),
self.ctx.q.iter(),
self.ctx.inv_last_qi_mod_qj.iter(),
self.ctx.inv_last_qi_mod_qj_shoup.iter(),
)
.for_each(|(coeffs, qi, inv, inv_shoup)| {
let q_last_div_2_mod_qi = **qi - qi.reduce(q_last_div_2); // Up to qi.modulus()
for (coeff, q_last_coeff) in izip!(coeffs, q_last_poly.iter()) {
// (x mod q_last - q_L/2) mod q_i
let tmp = qi.lazy_reduce(*q_last_coeff) + q_last_div_2_mod_qi; // Up to 3 * qi.modulus()

// ((x mod q_i) - (x mod q_last) + (q_L/2 mod q_i)) mod q_i
// = (x - x mod q_last + q_L/2) mod q_i
*coeff += 3 * (**qi) - tmp; // Up to 4 * qi.modulus()

// q_last^{-1} * (x - x mod q_last) mod q_i
*coeff = qi.mul_shoup(*coeff, *inv, *inv_shoup);
}
});
}
|qi, a, b| qi.add(a, b)
};
let reduce: unsafe fn(&Modulus, u64) -> u64 = if self.allow_variable_time_computations {
|qi, a| unsafe { qi.reduce_vt(a) }
} else {
|qi, a| qi.reduce(a)
};

q_last_poly
.iter_mut()
.for_each(|coeff| *coeff = add(q_last, *coeff, q_last_div_2));
izip!(
q_new_polys.outer_iter_mut(),
self.ctx.q.iter(),
self.ctx.inv_last_qi_mod_qj.iter(),
self.ctx.inv_last_qi_mod_qj_shoup.iter(),
)
.for_each(|(coeffs, qi, inv, inv_shoup)| {
let q_last_div_2_mod_qi = **qi - unsafe { reduce(qi, q_last_div_2) }; // Up to qi.modulus()
for (coeff, q_last_coeff) in izip!(coeffs, q_last_poly.iter()) {
// (x mod q_last - q_L/2) mod q_i
let tmp = qi.lazy_reduce(*q_last_coeff) + q_last_div_2_mod_qi; // Up to 3 * qi.modulus()

// ((x mod q_i) - (x mod q_last) + (q_L/2 mod q_i)) mod q_i
// = (x - x mod q_last + q_L/2) mod q_i
*coeff += 3 * (**qi) - tmp; // Up to 4 * qi.modulus()

// q_last^{-1} * (x - x mod q_last) mod q_i
*coeff = qi.mul_shoup(*coeff, *inv, *inv_shoup);
}
});

// Remove the last row, and update the context.
if !self.allow_variable_time_computations {
Expand Down
Loading