diff --git a/shell_wrapper/shell_types.rs b/shell_wrapper/shell_types.rs index b62792d..ce82ca7 100644 --- a/shell_wrapper/shell_types.rs +++ b/shell_wrapper/shell_types.rs @@ -125,6 +125,8 @@ impl Clone for RnsPolynomialVec { } } +pub use ffi::RustVecToRnsPolynomialVecWrapper as rust_vec_to_rns_polynomial_vec; + pub use ffi::RnsPolynomialWrapper as RnsPolynomial; impl Deref for RnsPolynomial { diff --git a/willow/proto/shell/ciphertexts.proto b/willow/proto/shell/ciphertexts.proto index c45b24c..ef24c12 100644 --- a/willow/proto/shell/ciphertexts.proto +++ b/willow/proto/shell/ciphertexts.proto @@ -53,3 +53,7 @@ message ShellAheCiphertext { message ShellKaheCiphertext { repeated rlwe.SerializedRnsPolynomial poly = 1; } + +message ShellKaheSecretKey { + rlwe.SerializedRnsPolynomial poly = 1; +} diff --git a/willow/src/shell/BUILD b/willow/src/shell/BUILD index 8b2dbe1..6aaf218 100644 --- a/willow/src/shell/BUILD +++ b/willow/src/shell/BUILD @@ -65,11 +65,16 @@ rust_library( srcs = ["kahe.rs"], deps = [ ":single_thread_hkdf", + "@protobuf//rust:protobuf", "//shell_wrapper:kahe", + "//shell_wrapper:shell_serialization", + "//shell_wrapper:shell_serialization_cxx", # fixdeps: keep "//shell_wrapper:shell_types", "//shell_wrapper:status", + "//willow/proto/shell:shell_ciphertexts_rust_proto", "//willow/src/traits:kahe_traits", "//willow/src/traits:prng_traits", + "//willow/src/traits:proto_serialization_traits", ], ) diff --git a/willow/src/shell/kahe.rs b/willow/src/shell/kahe.rs index 021eb9f..ad3022f 100644 --- a/willow/src/shell/kahe.rs +++ b/willow/src/shell/kahe.rs @@ -16,9 +16,14 @@ use kahe::{KahePublicParametersWrapper, PackedVectorConfig}; use kahe_traits::{ KaheBase, KaheDecrypt, KaheEncrypt, KaheKeygen, TrySecretKeyFrom, TrySecretKeyInto, }; +use proto_serialization_traits::{FromProto, ToProto}; +use protobuf::proto; +use shell_ciphertexts_rust_proto::{ShellKaheCiphertext, ShellKaheSecretKey}; +use shell_serialization::{rns_polynomial_from_proto, rns_polynomial_to_proto}; use shell_types::{ add_in_place, add_in_place_vec, read_small_rns_polynomial_from_buffer, - write_small_rns_polynomial_to_buffer, RnsPolynomial, RnsPolynomialVec, + rust_vec_to_rns_polynomial_vec, write_small_rns_polynomial_to_buffer, RnsPolynomial, + RnsPolynomialVec, }; use single_thread_hkdf::SingleThreadHkdfPrng; use std::collections::HashMap; @@ -105,6 +110,60 @@ pub struct SecretKey(pub RnsPolynomial); #[derive(Clone)] pub struct Ciphertext(pub RnsPolynomialVec); +impl ToProto for SecretKey { + type Proto = ShellKaheSecretKey; + type Context = ShellKahe; + + fn to_proto(&self, ctx: &Self::Context) -> Result { + let moduli = kahe::get_moduli(&ctx.public_kahe_parameters); + let poly_proto = rns_polynomial_to_proto(&self.0, &moduli)?; + Ok(proto!(ShellKaheSecretKey { poly: poly_proto })) + } +} + +impl FromProto for SecretKey { + type Proto = ShellKaheSecretKey; + type Context = ShellKahe; + + fn from_proto( + proto: impl protobuf::AsView, + ctx: &Self::Context, + ) -> Result { + let moduli = kahe::get_moduli(&ctx.public_kahe_parameters); + let poly = rns_polynomial_from_proto(proto.as_view().poly(), &moduli)?; + Ok(Self(poly)) + } +} + +impl ToProto for Ciphertext { + type Proto = ShellKaheCiphertext; + type Context = ShellKahe; + + fn to_proto(&self, ctx: &Self::Context) -> Result { + let moduli = kahe::get_moduli(&ctx.public_kahe_parameters); + let mut result = proto!(ShellKaheCiphertext {}); + for poly in self.0.iter() { + result.poly_mut().push(rns_polynomial_to_proto(&poly, &moduli)?); + } + Ok(result) + } +} + +impl FromProto for Ciphertext { + type Proto = ShellKaheCiphertext; + type Context = ShellKahe; + + fn from_proto( + proto: impl protobuf::AsView, + ctx: &Self::Context, + ) -> Result { + let moduli = kahe::get_moduli(&ctx.public_kahe_parameters); + let polys: Result, _> = + proto.as_view().poly().iter().map(|p| rns_polynomial_from_proto(p, &moduli)).collect(); + Ok(Ciphertext(rust_vec_to_rns_polynomial_vec(polys?))) + } +} + impl KaheBase for ShellKahe { type SecretKey = SecretKey; @@ -312,6 +371,7 @@ mod test { KaheBase, KaheDecrypt, KaheEncrypt, KaheKeygen, TrySecretKeyFrom, TrySecretKeyInto, }; use prng_traits::SecurePrng; + use proto_serialization_traits::{FromProto, ToProto}; use shell_testing_parameters::{make_kahe_config_for, set_kahe_num_public_polynomials}; use single_thread_hkdf::SingleThreadHkdfPrng; use std::collections::HashMap; @@ -507,6 +567,34 @@ mod test { verify_le!(mean.abs(), TAIL_BOUND_MULTIPLIER * mean_std) } + #[gtest] + fn test_encrypt_decrypt_serialized_proto() -> googletest::Result<()> { + let plaintext_modulus_bits = 39; + let packed_vector_configs = HashMap::from([( + String::from(DEFAULT_ID), + PackedVectorConfig { base: 10, dimension: 2, num_packed_coeffs: 5 }, + )]); + let kahe_config = make_kahe_config_for(plaintext_modulus_bits, packed_vector_configs)?; + let kahe = ShellKahe::new(kahe_config, CONTEXT_STRING)?; + + let pt = HashMap::from([(String::from(DEFAULT_ID), vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9])]); + let seed = SingleThreadHkdfPrng::generate_seed()?; + let mut prng = SingleThreadHkdfPrng::create(&seed)?; + let sk = kahe.key_gen(&mut prng)?; + let ct = kahe.encrypt(&ShellKahe::plaintext_as_slice(&pt), &sk, &mut prng)?; + + // Serialize and deserialize key + let sk_proto = sk.to_proto(&kahe)?; + let sk_deserialized = SecretKey::from_proto(sk_proto, &kahe)?; + + // Serialize and deserialize ciphertext + let ct_proto = ct.to_proto(&kahe)?; + let ct_deserialized = Ciphertext::from_proto(ct_proto, &kahe)?; + + let decrypted = kahe.decrypt(&ct_deserialized, &sk_deserialized)?; + verify_eq!(&pt, &decrypted) + } + #[gtest] fn test_key_serialization_is_homomorphic() -> googletest::Result<()> { // Set up a ShellKahe instance.