Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 51 additions & 64 deletions crates/fhe-math/src/rns/scaler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use itertools::{izip, Itertools};
use ndarray::{ArrayView1, ArrayViewMut1};
use num_bigint::BigUint;
use num_traits::{One, ToPrimitive, Zero};
use std::{cmp::min, sync::Arc};
use std::{borrow::Cow, cmp::min, sync::Arc};

/// Scaling factor when performing a RNS scaling.
#[derive(Default, Debug, Clone, PartialEq, Eq)]
Expand Down Expand Up @@ -49,18 +49,15 @@ pub struct RnsScaler {

gamma: Box<[u64]>,
gamma_shoup: Box<[u64]>,
theta_gamma_lo: u64,
theta_gamma_hi: u64,
theta_gamma: u128,
theta_gamma_sign: bool,

omega: Box<[Box<[u64]>]>,
omega_shoup: Box<[Box<[u64]>]>,
theta_omega_lo: Box<[u64]>,
theta_omega_hi: Box<[u64]>,
theta_omega: Box<[u128]>,
theta_omega_sign: Box<[bool]>,

theta_garner_lo: Box<[u64]>,
theta_garner_hi: Box<[u64]>,
theta_garner: Box<[u128]>,
theta_garner_shift: usize,
}

Expand All @@ -74,14 +71,13 @@ impl RnsScaler {
scaling_factor: ScalingFactor,
) -> Self {
// Let's define gamma = round(numerator * from.product / denominator)
let (gamma, theta_gamma_lo, theta_gamma_hi, theta_gamma_sign) =
Self::extract_projection_and_theta(
to,
&from.product,
&scaling_factor.numerator,
&scaling_factor.denominator,
false,
);
let (gamma, theta_gamma, theta_gamma_sign) = Self::extract_projection_and_theta(
to,
&from.product,
&scaling_factor.numerator,
&scaling_factor.denominator,
false,
);
let gamma_shoup = izip!(&gamma, &to.moduli)
.map(|(wi, q)| q.shoup(*wi))
.collect_vec();
Expand All @@ -93,25 +89,22 @@ impl RnsScaler {
omega.push(vec![0u64; from.moduli.len()].into_boxed_slice());
omega_shoup.push(vec![0u64; from.moduli.len()].into_boxed_slice());
}
let mut theta_omega_lo = Vec::with_capacity(from.garner.len());
let mut theta_omega_hi = Vec::with_capacity(from.garner.len());
let mut theta_omega = Vec::with_capacity(from.garner.len());
let mut theta_omega_sign = Vec::with_capacity(from.garner.len());
for i in 0..from.garner.len() {
let (omega_i, theta_omega_i_lo, theta_omega_i_hi, theta_omega_i_sign) =
Self::extract_projection_and_theta(
to,
&from.garner[i],
&scaling_factor.numerator,
&scaling_factor.denominator,
true,
);
let (omega_i, theta_omega_i, theta_omega_i_sign) = Self::extract_projection_and_theta(
to,
&from.garner[i],
&scaling_factor.numerator,
&scaling_factor.denominator,
true,
);
for j in 0..to.moduli.len() {
let qj = &to.moduli[j];
omega[j][i] = qj.reduce(omega_i[j]);
omega_shoup[j][i] = qj.shoup(omega[j][i]);
}
theta_omega_lo.push(theta_omega_i_lo);
theta_omega_hi.push(theta_omega_i_hi);
theta_omega.push(theta_omega_i);
theta_omega_sign.push(theta_omega_i_sign);
}

Expand All @@ -132,15 +125,15 @@ impl RnsScaler {
);
// Finally, define theta_garner_i = from.garner_i / product, also scaled by
// 2^127.
let mut theta_garner_lo = Vec::with_capacity(from.garner.len());
let mut theta_garner_hi = Vec::with_capacity(from.garner.len());
let mut theta_garner = Vec::with_capacity(from.garner.len());
for garner_i in &from.garner {
let mut theta: BigUint =
((garner_i << theta_garner_shift) + (&from.product >> 1)) / &from.product;
let theta_hi: BigUint = &theta >> 64;
theta -= &theta_hi << 64;
theta_garner_lo.push(theta.to_u64().unwrap());
theta_garner_hi.push(theta_hi.to_u64().unwrap());
let theta_combined =
(theta.to_u64().unwrap() as u128) | ((theta_hi.to_u64().unwrap() as u128) << 64);
theta_garner.push(theta_combined);
}

Self {
Expand All @@ -149,16 +142,13 @@ impl RnsScaler {
scaling_factor,
gamma: gamma.into_boxed_slice(),
gamma_shoup: gamma_shoup.into_boxed_slice(),
theta_gamma_lo,
theta_gamma_hi,
theta_gamma,
theta_gamma_sign,
omega: omega.into_boxed_slice(),
omega_shoup: omega_shoup.into_boxed_slice(),
theta_omega_lo: theta_omega_lo.into_boxed_slice(),
theta_omega_hi: theta_omega_hi.into_boxed_slice(),
theta_omega: theta_omega.into_boxed_slice(),
theta_omega_sign: theta_omega_sign.into_boxed_slice(),
theta_garner_lo: theta_garner_lo.into_boxed_slice(),
theta_garner_hi: theta_garner_hi.into_boxed_slice(),
theta_garner: theta_garner.into_boxed_slice(),
theta_garner_shift: theta_garner_shift as usize,
}
}
Expand All @@ -175,7 +165,7 @@ impl RnsScaler {
numerator: &BigUint,
denominator: &BigUint,
round_up: bool,
) -> (Vec<u64>, u64, u64, bool) {
) -> (Vec<u64>, u128, bool) {
let gamma = (numerator * input + (denominator >> 1)) / denominator;
let projected = ctx.project(&gamma);

Expand Down Expand Up @@ -211,10 +201,10 @@ impl RnsScaler {
}
let theta_hi_biguint: BigUint = &theta >> 64;
theta -= &theta_hi_biguint << 64;
let theta_lo = theta.to_u64().unwrap();
let theta_hi = theta_hi_biguint.to_u64().unwrap();
let theta_combined = (theta.to_u64().unwrap() as u128)
| ((theta_hi_biguint.to_u64().unwrap() as u128) << 64);

(projected, theta_lo, theta_hi, theta_sign)
(projected, theta_combined, theta_sign)
}

/// Output the RNS representation of the rests scaled by numerator *
Expand Down Expand Up @@ -244,16 +234,15 @@ impl RnsScaler {
debug_assert!(!out.is_empty());
debug_assert!(starting_index + out.len() <= self.to.moduli_u64.len());

// First, let's compute the inner product of the rests with theta_omega.
// First, let's compute the inner product of the rests with theta_garner.
let rest_cow = rests
.as_slice()
.map(Cow::Borrowed)
.unwrap_or_else(|| Cow::Owned(rests.to_vec()));
let rest_slice = rest_cow.as_ref();
let mut sum_theta_garner = u256::ZERO;
for (thetag_lo, thetag_hi, ri) in izip!(
self.theta_garner_lo.iter(),
self.theta_garner_hi.iter(),
rests
) {
sum_theta_garner = sum_theta_garner.wrapping_add(
U256::from(*ri) * U256::from((*thetag_lo as u128) | ((*thetag_hi as u128) << 64)),
);
for (thetag, ri) in self.theta_garner.iter().zip(rest_slice.iter()) {
sum_theta_garner = sum_theta_garner.wrapping_add(U256::from(*ri) * U256::from(*thetag));
}
// Let's compute v = round(sum_theta_garner / 2^theta_garner_shift)
sum_theta_garner >>= self.theta_garner_shift - 1;
Expand All @@ -265,14 +254,13 @@ impl RnsScaler {
let mut w = 0u128;
if !self.scaling_factor.is_one {
let mut sum_theta_omega = u256::ZERO;
for (thetao_lo, thetao_hi, thetao_sign, ri) in izip!(
self.theta_omega_lo.iter(),
self.theta_omega_hi.iter(),
self.theta_omega_sign.iter(),
rests
) {
let product = U256::from(*ri)
* U256::from((*thetao_lo as u128) | ((*thetao_hi as u128) << 64));
for ((thetao, thetao_sign), ri) in self
.theta_omega
.iter()
.zip(self.theta_omega_sign.iter())
.zip(rest_slice.iter())
{
let product = U256::from(*ri) * U256::from(*thetao);
if *thetao_sign {
sum_theta_omega = sum_theta_omega.wrapping_sub(product);
} else {
Expand All @@ -281,8 +269,7 @@ impl RnsScaler {
}

// Let's subtract v * theta_gamma to sum_theta_omega.
let v_theta_gamma = U256::from(v)
* U256::from((self.theta_gamma_lo as u128) | ((self.theta_gamma_hi as u128) << 64));
let v_theta_gamma = U256::from(v) * U256::from(self.theta_gamma);
if self.theta_gamma_sign {
sum_theta_omega = sum_theta_omega.wrapping_add(v_theta_gamma);
} else {
Expand Down Expand Up @@ -324,11 +311,11 @@ impl RnsScaler {
yi += if w_sign { **qi * 2 - wi } else { wi } as u128;
}

debug_assert!(rests.len() <= omega_i.len());
debug_assert!(rests.len() <= omega_shoup_i.len());
for j in 0..rests.len() {
debug_assert!(rest_slice.len() <= omega_i.len());
debug_assert!(rest_slice.len() <= omega_shoup_i.len());
for j in 0..rest_slice.len() {
yi += qi.lazy_mul_shoup(
*rests.get(j).unwrap(),
*rest_slice.get_unchecked(j),
*omega_i.get_unchecked(j),
*omega_shoup_i.get_unchecked(j),
) as u128;
Expand Down
Loading