diff --git a/crates/fhe/Cargo.toml b/crates/fhe/Cargo.toml index 85e0f68f..14a90e12 100644 --- a/crates/fhe/Cargo.toml +++ b/crates/fhe/Cargo.toml @@ -60,11 +60,20 @@ harness = false name = "bfv_rgsw" harness = false +[[example]] +name = "bfv_basic" + +[[example]] +name = "bfv_ops" + [[example]] name = "mulpir" [[example]] name = "sealpir" +[[example]] +name = "rgsw" + [[example]] name = "voting" diff --git a/crates/fhe/benches/bfv.rs b/crates/fhe/benches/bfv.rs index d1a74a2e..19ebc252 100644 --- a/crates/fhe/benches/bfv.rs +++ b/crates/fhe/benches/bfv.rs @@ -18,7 +18,7 @@ pub fn bfv_benchmark(c: &mut Criterion) { group.warm_up_time(Duration::from_millis(600)); group.measurement_time(Duration::from_millis(1000)); - for par in BfvParameters::default_parameters_128(20) { + for par in BfvParameters::default_parameters_128(20).unwrap() { let sk = SecretKey::random(&par, &mut rng); let ek = if par.moduli().len() > 1 { Some( diff --git a/crates/fhe/benches/bfv_optimized_ops.rs b/crates/fhe/benches/bfv_optimized_ops.rs index f61ed62f..296002ef 100644 --- a/crates/fhe/benches/bfv_optimized_ops.rs +++ b/crates/fhe/benches/bfv_optimized_ops.rs @@ -12,7 +12,7 @@ pub fn bfv_benchmark(c: &mut Criterion) { group.warm_up_time(Duration::from_secs(1)); group.measurement_time(Duration::from_secs(1)); - for par in BfvParameters::default_parameters_128(20).skip(2) { + for par in BfvParameters::default_parameters_128(20).unwrap() { for size in [10, 128, 1000] { let sk = SecretKey::random(&par, &mut rng); let pt1 = diff --git a/crates/fhe/benches/bfv_rgsw.rs b/crates/fhe/benches/bfv_rgsw.rs index efdf51f9..eac89887 100644 --- a/crates/fhe/benches/bfv_rgsw.rs +++ b/crates/fhe/benches/bfv_rgsw.rs @@ -11,7 +11,7 @@ pub fn bfv_rgsw_benchmark(c: &mut Criterion) { group.warm_up_time(Duration::from_secs(1)); group.measurement_time(Duration::from_secs(1)); - for par in BfvParameters::default_parameters_128(20).skip(2) { + for par in BfvParameters::default_parameters_128(20).unwrap() { let mut rng = rng(); let sk = SecretKey::random(&par, &mut rng); diff --git a/crates/fhe/examples/bfv_basic.rs b/crates/fhe/examples/bfv_basic.rs new file mode 100644 index 00000000..31427b4e --- /dev/null +++ b/crates/fhe/examples/bfv_basic.rs @@ -0,0 +1,41 @@ +use std::error::Error; + +use fhe::bfv::{BfvParameters, Encoding, Plaintext, PublicKey, SecretKey}; +use fhe_traits::{FheDecoder, FheDecrypter, FheEncoder, FheEncrypter}; +use rand::rng; + +fn main() -> Result<(), Box> { + let mut rng = rng(); + // Use default parameters + let params = BfvParameters::default_parameters_128(16)? + .nth(2) + .ok_or("Could not generate parameters")?; + + // Generate keys + let sk = SecretKey::random(¶ms, &mut rng); + let pk = PublicKey::new(&sk, &mut rng); + + // ----- Without SIMD ----- + let pt_a = Plaintext::try_encode(&[3u64], Encoding::poly(), ¶ms)?; + let pt_b = Plaintext::try_encode(&[5u64], Encoding::poly(), ¶ms)?; + let ct_a = pk.try_encrypt(&pt_a, &mut rng)?; + let ct_b = pk.try_encrypt(&pt_b, &mut rng)?; + let ct_sum = &ct_a + &ct_b; + let pt_sum = sk.try_decrypt(&ct_sum)?; + let res = Vec::::try_decode(&pt_sum, Encoding::poly())?; + println!("3 + 5 = {}", res[0]); + + // ----- With SIMD ----- + let v1 = vec![1u64, 2, 3, 4]; + let v2 = vec![5u64, 6, 7, 8]; + let pt_v1 = Plaintext::try_encode(&v1, Encoding::simd(), ¶ms)?; + let pt_v2 = Plaintext::try_encode(&v2, Encoding::simd(), ¶ms)?; + let ct_v1 = pk.try_encrypt(&pt_v1, &mut rng)?; + let ct_v2 = pk.try_encrypt(&pt_v2, &mut rng)?; + let ct_vsum = &ct_v1 + &ct_v2; + let pt_vsum = sk.try_decrypt(&ct_vsum)?; + let res_v = Vec::::try_decode(&pt_vsum, Encoding::simd())?; + println!("{:?} + {:?} = {:?}", v1, v2, &res_v[..v1.len()]); + + Ok(()) +} diff --git a/crates/fhe/examples/bfv_ops.rs b/crates/fhe/examples/bfv_ops.rs new file mode 100644 index 00000000..ee15b6cb --- /dev/null +++ b/crates/fhe/examples/bfv_ops.rs @@ -0,0 +1,152 @@ +mod util; + +use std::error::Error; +use std::sync::Arc; + +use fhe::bfv::{ + BfvParameters, Ciphertext, Encoding, EvaluationKeyBuilder, Plaintext, PublicKey, + RelinearizationKey, SecretKey, +}; +use fhe_traits::{FheDecoder, FheDecrypter, FheEncoder, FheEncrypter}; +use rand::rng; +use util::timeit::timeit; + +fn weighted_sum_plain( + cts: &[Ciphertext], + weights: &[u64], + params: &Arc, + sk: &SecretKey, +) -> Result> { + let mut acc = Ciphertext::zero(params); + for (ct, w) in cts.iter().zip(weights.iter()) { + let pt_w = Plaintext::try_encode(&[*w], Encoding::poly(), params)?; + acc += &(ct * &pt_w); + } + let pt = sk.try_decrypt(&acc)?; + let v = Vec::::try_decode(&pt, Encoding::poly())?; + Ok(v[0]) +} + +fn weighted_sum_simd( + ct: &Ciphertext, + weights: &Plaintext, + ek: &fhe::bfv::EvaluationKey, + sk: &SecretKey, +) -> Result> { + let tmp = ct * weights; + let summed = ek.computes_inner_sum(&tmp)?; + let pt = sk.try_decrypt(&summed)?; + let v = Vec::::try_decode(&pt, Encoding::simd())?; + Ok(v[0]) +} + +fn main() -> Result<(), Box> { + let mut rng = rng(); + let params = BfvParameters::default_parameters_128(20) + .unwrap() + .nth(2) // first parameters do not support key switching + .unwrap(); + let sk = SecretKey::random(¶ms, &mut rng); + let pk = PublicKey::new(&sk, &mut rng); + let ek = EvaluationKeyBuilder::new_leveled(&sk, 0, 0)? + .enable_inner_sum()? + .build(&mut rng)?; + let rk = RelinearizationKey::new(&sk, &mut rng)?; + + // ----- Weighted sum without SIMD ----- + let values = [1u64, 2, 3]; + let weights = [4u64, 5, 6]; + timeit!("inner product (no SIMD)", { + let cts: Vec = values + .iter() + .map(|v| { + let pt = Plaintext::try_encode(&[*v], Encoding::poly(), ¶ms)?; + Ok(pk.try_encrypt(&pt, &mut rng)?) + }) + .collect::>>()?; + let ws_plain = weighted_sum_plain(&cts, &weights, ¶ms, &sk)?; + println!("Weighted sum (no SIMD) = {ws_plain}"); + }); + + // ----- Weighted sum with SIMD ----- + let pt_vals = Plaintext::try_encode(&values, Encoding::simd(), ¶ms)?; + let ct_vals = pk.try_encrypt(&pt_vals, &mut rng)?; + let pt_ws = Plaintext::try_encode(&weights, Encoding::simd(), ¶ms)?; + timeit!("inner product (SIMD)", { + let ws_simd = weighted_sum_simd(&ct_vals, &pt_ws, &ek, &sk)?; + println!("Weighted sum (SIMD) = {ws_simd}"); + }); + + // ----- Inner product without SIMD ----- + let v1 = [1u64, 2, 3]; + let v2 = [7u64, 8, 9]; + let ct_v1: Vec = v1 + .iter() + .map(|v| { + let pt = Plaintext::try_encode(&[*v], Encoding::poly(), ¶ms)?; + Ok(pk.try_encrypt(&pt, &mut rng)?) + }) + .collect::>>()?; + let ct_v2: Vec = v2 + .iter() + .map(|v| { + let pt = Plaintext::try_encode(&[*v], Encoding::poly(), ¶ms)?; + Ok(pk.try_encrypt(&pt, &mut rng)?) + }) + .collect::>>()?; + let mut acc = Ciphertext::zero(¶ms); + for (a, b) in ct_v1.iter().zip(ct_v2.iter()) { + let mut prod = a * b; + rk.relinearizes(&mut prod)?; + acc += ∏ + } + let pt = sk.try_decrypt(&acc)?; + let ip_plain = Vec::::try_decode(&pt, Encoding::poly())?[0]; + println!("Inner product (no SIMD) = {ip_plain}"); + + // ----- Inner product with SIMD ----- + let pt1 = Plaintext::try_encode(&v1, Encoding::simd(), ¶ms)?; + let pt2 = Plaintext::try_encode(&v2, Encoding::simd(), ¶ms)?; + let ct1 = pk.try_encrypt(&pt1, &mut rng)?; + let ct2 = pk.try_encrypt(&pt2, &mut rng)?; + let mut prod = &ct1 * &ct2; + rk.relinearizes(&mut prod)?; + let summed = ek.computes_inner_sum(&prod)?; + let pt = sk.try_decrypt(&summed)?; + let ip_simd = Vec::::try_decode(&pt, Encoding::simd())?[0]; + println!("Inner product (SIMD) = {ip_simd}"); + + // ----- Polynomial evaluation without SIMD ----- + let x = 3u64; + let pt_x = Plaintext::try_encode(&[x], Encoding::poly(), ¶ms)?; + let ct_x = pk.try_encrypt(&pt_x, &mut rng)?; + let mut ct_x2 = &ct_x * &ct_x; // x^2 + rk.relinearizes(&mut ct_x2)?; + let pt_three = Plaintext::try_encode(&[3u64], Encoding::poly(), ¶ms)?; + let pt_two = Plaintext::try_encode(&[2u64], Encoding::poly(), ¶ms)?; + let pt_one = Plaintext::try_encode(&[1u64], Encoding::poly(), ¶ms)?; + let mut ct_res = &ct_x2 * &pt_three; + ct_res += &(&ct_x * &pt_two); + ct_res += &pt_one; + let pt = sk.try_decrypt(&ct_res)?; + let poly_plain = Vec::::try_decode(&pt, Encoding::poly())?[0]; + println!("Polynomial (no SIMD) = {poly_plain}"); + + // ----- Polynomial evaluation with SIMD ----- + let x_vec = [1u64, 2, 3, 4]; + let pt_xv = Plaintext::try_encode(&x_vec, Encoding::simd(), ¶ms)?; + let ct_xv = pk.try_encrypt(&pt_xv, &mut rng)?; + let mut ct_xv2 = &ct_xv * &ct_xv; + rk.relinearizes(&mut ct_xv2)?; + let pt_three_v = Plaintext::try_encode(&vec![3u64; x_vec.len()], Encoding::simd(), ¶ms)?; + let pt_two_v = Plaintext::try_encode(&vec![2u64; x_vec.len()], Encoding::simd(), ¶ms)?; + let pt_one_v = Plaintext::try_encode(&vec![1u64; x_vec.len()], Encoding::simd(), ¶ms)?; + let mut ct_res_v = &ct_xv2 * &pt_three_v; + ct_res_v += &(&ct_xv * &pt_two_v); + ct_res_v += &pt_one_v; + let pt = sk.try_decrypt(&ct_res_v)?; + let poly_simd = Vec::::try_decode(&pt, Encoding::simd())?; + println!("Polynomial (SIMD) = {:?}", &poly_simd[..x_vec.len()]); + + Ok(()) +} diff --git a/crates/fhe/examples/rgsw.rs b/crates/fhe/examples/rgsw.rs new file mode 100644 index 00000000..fa30427a --- /dev/null +++ b/crates/fhe/examples/rgsw.rs @@ -0,0 +1,50 @@ +use std::error::Error; + +use fhe::bfv::{BfvParameters, Ciphertext, Encoding, Plaintext, RGSWCiphertext, SecretKey}; +use fhe_traits::{FheDecoder, FheDecrypter, FheEncoder, FheEncrypter, Serialize}; +use rand::rng; + +fn main() -> Result<(), Box> { + let mut rng = rng(); + let params = BfvParameters::default_parameters_128(20) + .unwrap() + .nth(2) + .unwrap(); + let sk = SecretKey::random(¶ms, &mut rng); + + let v1 = vec![1u64, 2, 3, 4]; + let v2 = vec![5u64, 6, 7, 8]; + let pt1 = Plaintext::try_encode(&v1, Encoding::simd(), ¶ms)?; + let pt2 = Plaintext::try_encode(&v2, Encoding::simd(), ¶ms)?; + let ct1: Ciphertext = sk.try_encrypt(&pt1, &mut rng)?; + let ct2: Ciphertext = sk.try_encrypt(&pt2, &mut rng)?; + let ct2_rgsw: RGSWCiphertext = sk.try_encrypt(&pt2, &mut rng)?; + + let mut product = &ct1 * &ct2_rgsw; + let expected = &ct1 * &ct2; + + println!("Noise in product: {}", unsafe { + sk.measure_noise(&product)? + }); + println!("Size of product: {} bytes", product.to_bytes().len()); + println!("Noise in expected: {}", unsafe { + sk.measure_noise(&product)? + }); + + product.switch_to_level(product.max_switchable_level())?; + println!("Noise in product: {}", unsafe { + sk.measure_noise(&product)? + }); + println!("Size of product: {} bytes", product.to_bytes().len()); + + let pt_prod = sk.try_decrypt(&product)?; + let pt_exp = sk.try_decrypt(&expected)?; + assert_eq!(pt_prod, pt_exp); + let decoded = Vec::::try_decode(&pt_prod, Encoding::simd())?; + println!( + "RGSW external product successful: {:?}", + &decoded[..v1.len()] + ); + + Ok(()) +} diff --git a/crates/fhe/src/bfv/parameters.rs b/crates/fhe/src/bfv/parameters.rs index 48d40dbc..a5947d0b 100644 --- a/crates/fhe/src/bfv/parameters.rs +++ b/crates/fhe/src/bfv/parameters.rs @@ -12,7 +12,7 @@ use fhe_math::{ use fhe_traits::{Deserialize, FheParameters, Serialize}; use itertools::Itertools; use num_bigint::BigUint; -use num_traits::ToPrimitive; +use num_traits::{PrimInt as _, ToPrimitive}; use prost::Message; use std::collections::HashMap; use std::fmt::Debug; @@ -156,9 +156,13 @@ impl BfvParameters { /// Iterator over default parameters providing about 128 bits of security /// according to the standard. + /// Filters out parameters where the modulus product bitlength is smaller + /// than the plaintext modulus bitlength. + /// + /// Returns an error if no parameters are available after filtering. pub fn default_parameters_128( plaintext_nbits: usize, - ) -> impl Iterator> { + ) -> Result>> { debug_assert!(plaintext_nbits < 64); let mut n_and_qs = HashMap::new(); @@ -190,7 +194,7 @@ impl BfvParameters { ], ); - n_and_qs + let parameters: Vec> = n_and_qs .into_iter() .sorted_by_key(|(n, _)| *n) .filter_map(move |(n, moduli)| { @@ -199,15 +203,38 @@ impl BfvParameters { 2 * n as u64, u64::MAX >> (64 - plaintext_nbits), ) - .map(|plaintext_modulus| { - BfvParametersBuilder::new() - .set_degree(n as usize) - .set_plaintext_modulus(plaintext_modulus) - .set_moduli(&moduli) - .build_arc() - .unwrap() + .and_then(|plaintext_modulus| { + // Calculate the bitlength of the product of moduli + let modulus_product_bitlength = moduli + .iter() + .map(|&m| 64 - m.leading_zeros() as usize) + .sum::(); + + // Filter out parameters where modulus product bitlength < plaintext bitlength + if modulus_product_bitlength >= plaintext_nbits { + BfvParametersBuilder::new() + .set_degree(n as usize) + .set_plaintext_modulus(plaintext_modulus) + .set_moduli(&moduli) + .build_arc() + .ok() + } else { + None + } }) }) + .collect(); + + // Check if we have any valid parameters after filtering + if parameters.is_empty() { + return Err(Error::ParametersError(ParametersError::NoParametersAvailable { + reason: format!( + "No default parameters available for plaintext modulus of {plaintext_nbits} bits. All parameter sets have modulus product bitlength smaller than the plaintext modulus." + ), + })); + } + + Ok(parameters.into_iter()) } #[cfg(test)] @@ -636,7 +663,33 @@ mod tests { #[test] fn default_parameters_iterator() { - let mut it = BfvParameters::default_parameters_128(20); + let mut it = BfvParameters::default_parameters_128(20).unwrap(); assert!(it.next().is_some()); } + + #[test] + fn default_parameters_filtering() { + // Test that parameters are filtered correctly + let params: Vec<_> = BfvParameters::default_parameters_128(20).unwrap().collect(); + + // All returned parameters should have sufficient modulus bitlength + for param in ¶ms { + let modulus_product_bitlength = param.moduli_sizes.iter().sum::(); + assert!(modulus_product_bitlength >= 20); + } + + // Test with a very small plaintext modulus for which we won't be able to + // create any parameters + let result = BfvParameters::default_parameters_128(10); + assert!(result.is_err()); + + match result { + Err(e) => { + let error_string = format!("{e}"); + assert!(error_string.contains("No parameters available")); + assert!(error_string.contains("10 bits")); + } + Ok(_) => panic!("Expected error"), + } + } } diff --git a/crates/fhe/src/errors.rs b/crates/fhe/src/errors.rs index 5bcf5a5c..6027a6c2 100644 --- a/crates/fhe/src/errors.rs +++ b/crates/fhe/src/errors.rs @@ -139,8 +139,8 @@ impl Error { U: std::fmt::Debug, { Self::ContextMismatch { - found: format!("{:?}", found), - expected: format!("{:?}", expected), + found: format!("{found:?}"), + expected: format!("{expected:?}"), } } @@ -313,6 +313,10 @@ pub enum ParametersError { /// Indicates missing required parameter #[error("Missing required parameter: {parameter}")] MissingParameter { parameter: String }, + + /// Indicates no parameters are available after filtering + #[error("No parameters available: {reason}")] + NoParametersAvailable { reason: String }, } impl ParametersError {