Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion .github/workflows/rust.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,16 @@ jobs:
- name: Install protobuf compiler
run: sudo apt-get update && sudo apt-get install -y protobuf-compiler

- name: Cache cargo build
uses: Swatinem/rust-cache@v2

- name: Check code format
run: cargo fmt --all -- --check

- name: Run clippy
run: cargo clippy --all-targets --all-features -- -D warnings

- name: Run tests
run: cargo test --all-features
# Release profile: one compile pass, and the trBFV secure-preset e2e
# tests (tests/trbfv_secure_e2e.rs) run in seconds instead of minutes.
run: cargo test --release --all-features
1 change: 1 addition & 0 deletions crates/fhe/src/trbfv/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ pub fn validate_threshold_config(n: usize, threshold: usize) -> Result<(), Error
if n == 0 {
return Err(Error::invalid_party_count(n, 1));
}
// TODO: make stronger assumptions on minimum requirement (and / or) exact requirements.
if threshold > (n - 1) / 2 {
return Err(Error::threshold_too_large(threshold, n));
}
Expand Down
18 changes: 16 additions & 2 deletions crates/fhe/src/trbfv/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,15 @@ impl Error {
#[must_use]
pub fn invalid_party_id(party_id: usize, max_party_id: usize) -> Self {
Error::UnspecifiedInput(format!(
"Invalid party ID: {party_id}, must be between 0 and {max_party_id}"
"Invalid party ID: {party_id}, must be between 1 and {max_party_id}"
))
}

/// Create a duplicate party ID error.
#[must_use]
pub fn duplicate_party_id(party_id: usize) -> Self {
Error::UnspecifiedInput(format!(
"Duplicate party ID {party_id} in reconstructing parties"
))
}

Expand Down Expand Up @@ -133,7 +141,13 @@ mod tests {
let error = Error::invalid_party_id(5, 3);
assert_eq!(
error.to_string(),
"Invalid party ID: 5, must be between 0 and 3"
"Invalid party ID: 5, must be between 1 and 3"
);

let error = Error::duplicate_party_id(2);
assert_eq!(
error.to_string(),
"Duplicate party ID 2 in reconstructing parties"
);

let error = Error::secret_sharing("Test secret sharing error");
Expand Down
83 changes: 64 additions & 19 deletions crates/fhe/src/trbfv/shamir.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use crate::Error;
use fhe_util::rng08;
/// Shamir Secret Sharing implementation for threshold BFV.
///
Expand Down Expand Up @@ -40,7 +41,7 @@ use rayon::prelude::*;
/// let shares = sss.split(secret.clone());
///
/// println!("shares: {:?}", shares);
/// assert_eq!(secret, sss.recover(&shares[0..sss.threshold +1]));
/// assert_eq!(secret, sss.recover(&shares[0..sss.threshold +1]).unwrap());
/// }
///
/// Fork a full-entropy ChaCha20 seed from the caller's RNG.
Expand Down Expand Up @@ -175,28 +176,40 @@ impl ShamirSecretSharing {
///
/// The reconstructed secret value.
///
/// # Panics
/// # Errors
///
/// Panics if the number of shares provided is not equal to the threshold + one.
#[must_use]
pub fn recover(&self, shares: &[(usize, BigInt)]) -> BigInt {
assert!(shares.len() == (self.threshold + 1), "wrong shares number");
/// Returns an error if the number of shares provided is not equal to
/// threshold + 1, or if a Lagrange denominator is not invertible
/// (e.g., duplicate share indices).
pub fn recover(&self, shares: &[(usize, BigInt)]) -> Result<BigInt, Error> {
if shares.len() != self.threshold + 1 {
return Err(Error::secret_sharing(format!(
"wrong shares number: expected {}, got {}",
self.threshold + 1,
shares.len()
)));
}
let (xs, ys): (Vec<usize>, Vec<BigInt>) = shares.iter().cloned().unzip();
let result = self.lagrange_interpolation(Zero::zero(), xs, ys);
let result = self.lagrange_interpolation(Zero::zero(), xs, ys)?;
if result < Zero::zero() {
result + &self.prime
Ok(result + &self.prime)
} else {
result
Ok(result)
}
}

// indices i and item iterate 0..len, same as xs_bigint.len() and ys.len()
#[allow(clippy::indexing_slicing)]
fn lagrange_interpolation(&self, x: BigInt, xs: Vec<usize>, ys: Vec<BigInt>) -> BigInt {
fn lagrange_interpolation(
&self,
x: BigInt,
xs: Vec<usize>,
ys: Vec<BigInt>,
) -> Result<BigInt, Error> {
let len = xs.len();
let xs_bigint: Vec<BigInt> = xs.iter().map(|x| BigInt::from(*x as i64)).collect();

(0..len)
let terms: Result<Vec<BigInt>, Error> = (0..len)
.into_par_iter()
.map(|item| {
let numerator = (0..len).fold(One::one(), |product: BigInt, i| {
Expand All @@ -214,20 +227,28 @@ impl ShamirSecretSharing {
}
});
// Calculate this Lagrange term
(numerator * self.mod_reverse(denominator) * &ys[item]) % &self.prime
Ok((numerator * self.mod_reverse(denominator)? * &ys[item]) % &self.prime)
})
.reduce(Zero::zero, |sum, term| (sum + term) % &self.prime)
.collect();

Ok(terms?
.into_iter()
.fold(Zero::zero(), |sum: BigInt, term| (sum + term) % &self.prime))
}

fn mod_reverse(&self, num: BigInt) -> BigInt {
fn mod_reverse(&self, num: BigInt) -> Result<BigInt, Error> {
let num1 = if num < Zero::zero() {
num + &self.prime
} else {
num
};
let (_gcd, _, inv) = self.extend_euclid_algo(num1);
// println!("inv:{}", inv);
inv
let (gcd, _, inv) = self.extend_euclid_algo(num1);
if !gcd.is_one() {
return Err(Error::secret_sharing(
"non-invertible Lagrange denominator (duplicate or invalid share indices)",
));
}
Ok(inv)
}

/**
Expand Down Expand Up @@ -304,10 +325,34 @@ mod tests {
(1, BigInt::from(1494)),
(2, BigInt::from(329)),
(3, BigInt::from(965))
]),
])
.unwrap(),
BigInt::from(1234)
);
}
#[test]
fn test_recover_rejects_bad_shares() {
let sss = ShamirSecretSharing {
threshold: 2,
share_amount: 6,
prime: BigInt::from(1613),
};
// Wrong share count
assert!(
sss.recover(&[(1, BigInt::from(1494)), (2, BigInt::from(329))])
.is_err()
);
// Duplicate share indices -> non-invertible Lagrange denominator
assert!(
sss.recover(&[
(1, BigInt::from(1494)),
(1, BigInt::from(1494)),
(3, BigInt::from(965))
])
.is_err()
);
}

#[test]
fn test_large_prime() {
let sss = ShamirSecretSharing {
Expand All @@ -322,6 +367,6 @@ mod tests {
};
let secret = BigInt::parse_bytes(b"ffffffffffffffffffffffffffffffffffffff", 16).unwrap();
let shares = sss.split(secret.clone(), &mut rand::rng());
assert_eq!(secret, sss.recover(&shares[0..sss.threshold + 1]));
assert_eq!(secret, sss.recover(&shares[0..sss.threshold + 1]).unwrap());
}
}
84 changes: 70 additions & 14 deletions crates/fhe/src/trbfv/shares.rs
Original file line number Diff line number Diff line change
Expand Up @@ -261,8 +261,10 @@ impl ShareManager {
reconstructing_parties: Vec<usize>,
ciphertext: Arc<Ciphertext>,
) -> Result<Plaintext, Error> {
// Validate we have enough shares
if d_share_polys.len() < (self.threshold + 1) {
// Reconstruction consumes exactly threshold + 1 shares; requiring
// exactness (rather than truncating extras) avoids silently depending
// on the order of the provided shares.
if d_share_polys.len() != self.threshold + 1 {
return Err(Error::insufficient_shares(
d_share_polys.len(),
self.threshold + 1,
Expand All @@ -274,9 +276,22 @@ impl ShareManager {
"reconstructing_parties length must match d_share_polys length".to_string(),
));
}
let m_data: Vec<u64> = (0..self.params.moduli().len())
// Shamir x-coordinates are 1-based, bounded by n, and must be distinct:
// index 0 would evaluate the sharing polynomial at the secret itself,
// and duplicates make the Lagrange denominators non-invertible.
let mut seen = vec![false; self.n + 1];
for &idx in &reconstructing_parties {
if idx == 0 || idx > self.n {
return Err(Error::invalid_party_id(idx, self.n));
}
if seen[idx] {
return Err(Error::duplicate_party_id(idx));
}
seen[idx] = true;
}
let recovered: Result<Vec<Vec<u64>>, Error> = (0..self.params.moduli().len())
.into_par_iter()
.flat_map(|m| {
.map(|m| {
let shamir_ss = ShamirSecretSharing::new(
self.threshold,
self.n,
Expand All @@ -286,13 +301,11 @@ impl ShareManager {
// Parallelize coefficient recovery within each modulus
(0..self.params.degree())
.into_par_iter()
.map(|i| {
.map(|i| -> Result<u64, Error> {
let mut shamir_open_vec_mod: Vec<(usize, BigInt)> =
Vec::with_capacity(self.params.degree());
for (party_idx, d_share_poly) in reconstructing_parties
.iter()
.zip(d_share_polys.iter())
.take(self.threshold + 1)
Vec::with_capacity(self.threshold + 1);
for (party_idx, d_share_poly) in
reconstructing_parties.iter().zip(d_share_polys.iter())
{
let coeffs = d_share_poly.coefficients();
let coeff_arr = coeffs.row(m);
Expand All @@ -301,13 +314,17 @@ impl ShareManager {
let coeff_formatted = (*party_idx, coeff.to_bigint().unwrap());
shamir_open_vec_mod.push(coeff_formatted);
}
let shamir_result =
shamir_ss.recover(&shamir_open_vec_mod[0..self.threshold + 1]);
shamir_result.to_u64().unwrap()
let shamir_result = shamir_ss.recover(&shamir_open_vec_mod)?;
shamir_result.to_u64().ok_or_else(|| {
Error::DefaultError(
"recovered Shamir coefficient does not fit in u64".to_string(),
)
})
})
.collect::<Vec<u64>>()
.collect::<Result<Vec<u64>, Error>>()
})
.collect();
let m_data: Vec<u64> = recovered?.into_iter().flatten().collect();

// scale result poly
let arr_matrix =
Expand Down Expand Up @@ -850,6 +867,45 @@ mod tests {
);
}

#[test]
fn test_decrypt_from_shares_rejects_invalid_party_indices() {
let mut rng = rng();
let params = test_params();
let n = 5;
let threshold = 2; // needs exactly 3 shares
let manager = ShareManager::new(n, threshold, params.clone());

let sk = SecretKey::random(&params, &mut rng);
let pk = PublicKey::new(&sk, &mut rng);
let pt = Plaintext::try_encode(&[1u64], Encoding::poly(), &params).unwrap();
let ct = Arc::new(pk.try_encrypt(&pt, &mut rng).unwrap());

let ctx = params.context_at_level(0).unwrap();
let shares: Vec<Poly<PowerBasis>> = (0..3).map(|_| Poly::<PowerBasis>::zero(ctx)).collect();

// Duplicate index
let result = manager.decrypt_from_shares(shares.clone(), vec![1, 2, 2], ct.clone());
assert!(result.is_err());

// Index 0 (would evaluate the sharing polynomial at the secret)
let result = manager.decrypt_from_shares(shares.clone(), vec![0, 1, 2], ct.clone());
assert!(result.is_err());

// Index > n
let result = manager.decrypt_from_shares(shares.clone(), vec![1, 2, 6], ct.clone());
assert!(result.is_err());

// Wrong share count: more than threshold + 1 is rejected
let four: Vec<Poly<PowerBasis>> = (0..4).map(|_| Poly::<PowerBasis>::zero(ctx)).collect();
let result = manager.decrypt_from_shares(four, vec![1, 2, 3, 4], ct.clone());
assert!(result.is_err());

// Fewer than threshold + 1 is rejected
let two: Vec<Poly<PowerBasis>> = (0..2).map(|_| Poly::<PowerBasis>::zero(ctx)).collect();
let result = manager.decrypt_from_shares(two, vec![1, 2], ct);
assert!(result.is_err());
}

#[test]
fn test_threshold_decryption_random_party_order() {
let mut rng = rng();
Expand Down
Loading
Loading