diff --git a/crates/fhe-math/src/rns/scaler.rs b/crates/fhe-math/src/rns/scaler.rs index b3051b1b..40dcaee7 100644 --- a/crates/fhe-math/src/rns/scaler.rs +++ b/crates/fhe-math/src/rns/scaler.rs @@ -8,7 +8,7 @@ use itertools::{izip, Itertools}; use ndarray::{ArrayView1, ArrayViewMut1}; use num_bigint::BigUint; use num_traits::{One, ToPrimitive, Zero}; -use std::{cmp::min, sync::Arc}; +use std::{borrow::Cow, cmp::min, sync::Arc}; /// Scaling factor when performing a RNS scaling. #[derive(Default, Debug, Clone, PartialEq, Eq)] @@ -49,18 +49,15 @@ pub struct RnsScaler { gamma: Box<[u64]>, gamma_shoup: Box<[u64]>, - theta_gamma_lo: u64, - theta_gamma_hi: u64, + theta_gamma: u128, theta_gamma_sign: bool, omega: Box<[Box<[u64]>]>, omega_shoup: Box<[Box<[u64]>]>, - theta_omega_lo: Box<[u64]>, - theta_omega_hi: Box<[u64]>, + theta_omega: Box<[u128]>, theta_omega_sign: Box<[bool]>, - theta_garner_lo: Box<[u64]>, - theta_garner_hi: Box<[u64]>, + theta_garner: Box<[u128]>, theta_garner_shift: usize, } @@ -74,14 +71,13 @@ impl RnsScaler { scaling_factor: ScalingFactor, ) -> Self { // Let's define gamma = round(numerator * from.product / denominator) - let (gamma, theta_gamma_lo, theta_gamma_hi, theta_gamma_sign) = - Self::extract_projection_and_theta( - to, - &from.product, - &scaling_factor.numerator, - &scaling_factor.denominator, - false, - ); + let (gamma, theta_gamma, theta_gamma_sign) = Self::extract_projection_and_theta( + to, + &from.product, + &scaling_factor.numerator, + &scaling_factor.denominator, + false, + ); let gamma_shoup = izip!(&gamma, &to.moduli) .map(|(wi, q)| q.shoup(*wi)) .collect_vec(); @@ -93,25 +89,22 @@ impl RnsScaler { omega.push(vec![0u64; from.moduli.len()].into_boxed_slice()); omega_shoup.push(vec![0u64; from.moduli.len()].into_boxed_slice()); } - let mut theta_omega_lo = Vec::with_capacity(from.garner.len()); - let mut theta_omega_hi = Vec::with_capacity(from.garner.len()); + let mut theta_omega = Vec::with_capacity(from.garner.len()); let mut theta_omega_sign = Vec::with_capacity(from.garner.len()); for i in 0..from.garner.len() { - let (omega_i, theta_omega_i_lo, theta_omega_i_hi, theta_omega_i_sign) = - Self::extract_projection_and_theta( - to, - &from.garner[i], - &scaling_factor.numerator, - &scaling_factor.denominator, - true, - ); + let (omega_i, theta_omega_i, theta_omega_i_sign) = Self::extract_projection_and_theta( + to, + &from.garner[i], + &scaling_factor.numerator, + &scaling_factor.denominator, + true, + ); for j in 0..to.moduli.len() { let qj = &to.moduli[j]; omega[j][i] = qj.reduce(omega_i[j]); omega_shoup[j][i] = qj.shoup(omega[j][i]); } - theta_omega_lo.push(theta_omega_i_lo); - theta_omega_hi.push(theta_omega_i_hi); + theta_omega.push(theta_omega_i); theta_omega_sign.push(theta_omega_i_sign); } @@ -132,15 +125,15 @@ impl RnsScaler { ); // Finally, define theta_garner_i = from.garner_i / product, also scaled by // 2^127. - let mut theta_garner_lo = Vec::with_capacity(from.garner.len()); - let mut theta_garner_hi = Vec::with_capacity(from.garner.len()); + let mut theta_garner = Vec::with_capacity(from.garner.len()); for garner_i in &from.garner { let mut theta: BigUint = ((garner_i << theta_garner_shift) + (&from.product >> 1)) / &from.product; let theta_hi: BigUint = &theta >> 64; theta -= &theta_hi << 64; - theta_garner_lo.push(theta.to_u64().unwrap()); - theta_garner_hi.push(theta_hi.to_u64().unwrap()); + let theta_combined = + (theta.to_u64().unwrap() as u128) | ((theta_hi.to_u64().unwrap() as u128) << 64); + theta_garner.push(theta_combined); } Self { @@ -149,16 +142,13 @@ impl RnsScaler { scaling_factor, gamma: gamma.into_boxed_slice(), gamma_shoup: gamma_shoup.into_boxed_slice(), - theta_gamma_lo, - theta_gamma_hi, + theta_gamma, theta_gamma_sign, omega: omega.into_boxed_slice(), omega_shoup: omega_shoup.into_boxed_slice(), - theta_omega_lo: theta_omega_lo.into_boxed_slice(), - theta_omega_hi: theta_omega_hi.into_boxed_slice(), + theta_omega: theta_omega.into_boxed_slice(), theta_omega_sign: theta_omega_sign.into_boxed_slice(), - theta_garner_lo: theta_garner_lo.into_boxed_slice(), - theta_garner_hi: theta_garner_hi.into_boxed_slice(), + theta_garner: theta_garner.into_boxed_slice(), theta_garner_shift: theta_garner_shift as usize, } } @@ -175,7 +165,7 @@ impl RnsScaler { numerator: &BigUint, denominator: &BigUint, round_up: bool, - ) -> (Vec, u64, u64, bool) { + ) -> (Vec, u128, bool) { let gamma = (numerator * input + (denominator >> 1)) / denominator; let projected = ctx.project(&gamma); @@ -211,10 +201,10 @@ impl RnsScaler { } let theta_hi_biguint: BigUint = &theta >> 64; theta -= &theta_hi_biguint << 64; - let theta_lo = theta.to_u64().unwrap(); - let theta_hi = theta_hi_biguint.to_u64().unwrap(); + let theta_combined = (theta.to_u64().unwrap() as u128) + | ((theta_hi_biguint.to_u64().unwrap() as u128) << 64); - (projected, theta_lo, theta_hi, theta_sign) + (projected, theta_combined, theta_sign) } /// Output the RNS representation of the rests scaled by numerator * @@ -244,16 +234,15 @@ impl RnsScaler { debug_assert!(!out.is_empty()); debug_assert!(starting_index + out.len() <= self.to.moduli_u64.len()); - // First, let's compute the inner product of the rests with theta_omega. + // First, let's compute the inner product of the rests with theta_garner. + let rest_cow = rests + .as_slice() + .map(Cow::Borrowed) + .unwrap_or_else(|| Cow::Owned(rests.to_vec())); + let rest_slice = rest_cow.as_ref(); let mut sum_theta_garner = u256::ZERO; - for (thetag_lo, thetag_hi, ri) in izip!( - self.theta_garner_lo.iter(), - self.theta_garner_hi.iter(), - rests - ) { - sum_theta_garner = sum_theta_garner.wrapping_add( - U256::from(*ri) * U256::from((*thetag_lo as u128) | ((*thetag_hi as u128) << 64)), - ); + for (thetag, ri) in self.theta_garner.iter().zip(rest_slice.iter()) { + sum_theta_garner = sum_theta_garner.wrapping_add(U256::from(*ri) * U256::from(*thetag)); } // Let's compute v = round(sum_theta_garner / 2^theta_garner_shift) sum_theta_garner >>= self.theta_garner_shift - 1; @@ -265,14 +254,13 @@ impl RnsScaler { let mut w = 0u128; if !self.scaling_factor.is_one { let mut sum_theta_omega = u256::ZERO; - for (thetao_lo, thetao_hi, thetao_sign, ri) in izip!( - self.theta_omega_lo.iter(), - self.theta_omega_hi.iter(), - self.theta_omega_sign.iter(), - rests - ) { - let product = U256::from(*ri) - * U256::from((*thetao_lo as u128) | ((*thetao_hi as u128) << 64)); + for ((thetao, thetao_sign), ri) in self + .theta_omega + .iter() + .zip(self.theta_omega_sign.iter()) + .zip(rest_slice.iter()) + { + let product = U256::from(*ri) * U256::from(*thetao); if *thetao_sign { sum_theta_omega = sum_theta_omega.wrapping_sub(product); } else { @@ -281,8 +269,7 @@ impl RnsScaler { } // Let's subtract v * theta_gamma to sum_theta_omega. - let v_theta_gamma = U256::from(v) - * U256::from((self.theta_gamma_lo as u128) | ((self.theta_gamma_hi as u128) << 64)); + let v_theta_gamma = U256::from(v) * U256::from(self.theta_gamma); if self.theta_gamma_sign { sum_theta_omega = sum_theta_omega.wrapping_add(v_theta_gamma); } else { @@ -324,11 +311,11 @@ impl RnsScaler { yi += if w_sign { **qi * 2 - wi } else { wi } as u128; } - debug_assert!(rests.len() <= omega_i.len()); - debug_assert!(rests.len() <= omega_shoup_i.len()); - for j in 0..rests.len() { + debug_assert!(rest_slice.len() <= omega_i.len()); + debug_assert!(rest_slice.len() <= omega_shoup_i.len()); + for j in 0..rest_slice.len() { yi += qi.lazy_mul_shoup( - *rests.get(j).unwrap(), + *rest_slice.get_unchecked(j), *omega_i.get_unchecked(j), *omega_shoup_i.get_unchecked(j), ) as u128;