diff --git a/willow/proto/willow/BUILD b/willow/proto/willow/BUILD index 935130e..5aa4e57 100644 --- a/willow/proto/willow/BUILD +++ b/willow/proto/willow/BUILD @@ -14,6 +14,7 @@ load("@protobuf//bazel:cc_proto_library.bzl", "cc_proto_library") load("@protobuf//bazel:proto_library.bzl", "proto_library") +load("@protobuf//rust:defs.bzl", "rust_proto_library") package( default_visibility = ["//visibility:public"], @@ -48,3 +49,22 @@ cc_proto_library( name = "input_spec_cc_proto", deps = [":input_spec_proto"], ) + +proto_library( + name = "messages_proto", + srcs = ["messages.proto"], + deps = [ + "//willow/proto/shell:shell_ciphertexts_proto", + "//willow/proto/zk:proofs_proto", + ], +) + +cc_proto_library( + name = "messages_cc_proto", + deps = [":messages_proto"], +) + +rust_proto_library( + name = "messages_rust_proto", + deps = [":messages_proto"], +) diff --git a/willow/proto/willow/messages.proto b/willow/proto/willow/messages.proto new file mode 100644 index 0000000..98e8dcc --- /dev/null +++ b/willow/proto/willow/messages.proto @@ -0,0 +1,49 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto3"; + +package secure_aggregation.willow; + +import "willow/proto/shell/ciphertexts.proto"; +import "willow/proto/zk/proofs.proto"; + +option java_multiple_files = true; +option java_outer_classname = "MessagesProto"; + +message ClientMessage { + ShellKaheCiphertext kahe_ciphertext = 1; + ShellAheCiphertext ahe_ciphertext = 2; + RlweRelationProofListProto proof = 3; + bytes nonce = 4; +} + +message PartialDecryptionRequest { + ShellAhePartialDecCiphertext partial_dec_ciphertext = 1; +} + +message PartialDecryptionResponse { + ShellAhePartialDecryption partial_decryption = 1; +} + +message CiphertextContribution { + ShellKaheCiphertext kahe_ciphertext = 1; + ShellAheRecoverCiphertext ahe_recover_ciphertext = 2; +} + +message DecryptionRequestContribution { + ShellAhePartialDecCiphertext partial_dec_ciphertext = 1; + RlweRelationProofListProto proof = 2; + bytes nonce = 3; +} diff --git a/willow/src/traits/BUILD b/willow/src/traits/BUILD index d92e1a9..6b1503a 100644 --- a/willow/src/traits/BUILD +++ b/willow/src/traits/BUILD @@ -85,7 +85,13 @@ rust_library( deps = [ ":ahe_traits", ":kahe_traits", + ":proto_serialization_traits", ":vahe_traits", + "@protobuf//rust:protobuf", + "//shell_wrapper:status", + "//willow/proto/shell:shell_ciphertexts_rust_proto", + "//willow/proto/willow:messages_rust_proto", + "//willow/proto/zk:proofs_rust_proto", ], ) diff --git a/willow/src/traits/messages.rs b/willow/src/traits/messages.rs index 674e2cd..07ba25f 100644 --- a/willow/src/traits/messages.rs +++ b/willow/src/traits/messages.rs @@ -13,9 +13,23 @@ // limitations under the License. use ahe_traits::AheBase; -use kahe_traits::KaheBase; +use kahe_traits::{HasKahe, KaheBase}; +use messages_rust_proto::{ + CiphertextContribution as CiphertextContributionProto, ClientMessage as ClientMessageProto, + DecryptionRequestContribution as DecryptionRequestContributionProto, + PartialDecryptionRequest as PartialDecryptionRequestProto, + PartialDecryptionResponse as PartialDecryptionResponseProto, +}; +use proofs_rust_proto::RlweRelationProofListProto; +use proto_serialization_traits::{FromProto, ToProto}; +use protobuf::{proto, AsView}; +use shell_ciphertexts_rust_proto::{ + ShellAheCiphertext, ShellAhePartialDecCiphertext, ShellAhePartialDecryption, + ShellAheRecoverCiphertext, ShellKaheCiphertext, +}; +use status::StatusError; use std::fmt::Debug; -use vahe_traits::VaheBase; +use vahe_traits::{HasVahe, VaheBase}; pub type DecryptorPublicKeyShare = ::PublicKeyShare; @@ -30,6 +44,52 @@ pub struct ClientMessage { pub nonce: Vec, } +impl<'a, C, Kahe, Vahe> ToProto<&'a C> for ClientMessage +where + C: HasKahe + HasVahe, + Kahe: KaheBase + 'a, + Vahe: VaheBase + 'a, + Kahe::Ciphertext: ToProto<&'a Kahe, Proto = ShellKaheCiphertext>, + Vahe::Ciphertext: ToProto<&'a Vahe, Proto = ShellAheCiphertext>, + Vahe::EncryptionProof: ToProto, +{ + type Proto = ClientMessageProto; + + fn to_proto(&self, context: &'a C) -> Result { + Ok(proto!(ClientMessageProto { + kahe_ciphertext: self.kahe_ciphertext.to_proto(context.kahe())?, + ahe_ciphertext: self.ahe_ciphertext.to_proto(context.vahe())?, + proof: self.proof.to_proto(())?, + nonce: self.nonce.clone(), + })) + } +} + +impl<'a, C, Kahe, Vahe> FromProto<&'a C> for ClientMessage +where + C: HasKahe + HasVahe, + Kahe: KaheBase + 'a, + Vahe: VaheBase + 'a, + Kahe::Ciphertext: FromProto<&'a Kahe, Proto = ShellKaheCiphertext>, + Vahe::Ciphertext: FromProto<&'a Vahe, Proto = ShellAheCiphertext>, + Vahe::EncryptionProof: FromProto, +{ + type Proto = ClientMessageProto; + + fn from_proto( + proto: impl AsView, + context: &'a C, + ) -> Result { + let proto = proto.as_view(); + Ok(ClientMessage { + kahe_ciphertext: Kahe::Ciphertext::from_proto(proto.kahe_ciphertext(), context.kahe())?, + ahe_ciphertext: Vahe::Ciphertext::from_proto(proto.ahe_ciphertext(), context.vahe())?, + proof: Vahe::EncryptionProof::from_proto(proto.proof(), ())?, + nonce: proto.nonce().to_vec(), + }) + } +} + impl Clone for ClientMessage { fn clone(self: &ClientMessage) -> ClientMessage { ClientMessage { @@ -46,6 +106,43 @@ pub struct PartialDecryptionRequest { pub partial_dec_ciphertext: Vahe::PartialDecCiphertext, } +impl<'a, C, Vahe> ToProto<&'a C> for PartialDecryptionRequest +where + C: HasVahe, + Vahe: VaheBase + 'a, + Vahe::PartialDecCiphertext: ToProto<&'a Vahe, Proto = ShellAhePartialDecCiphertext>, +{ + type Proto = PartialDecryptionRequestProto; + + fn to_proto(&self, context: &'a C) -> Result { + Ok(proto!(PartialDecryptionRequestProto { + partial_dec_ciphertext: self.partial_dec_ciphertext.to_proto(context.vahe())?, + })) + } +} + +impl<'a, C, Vahe> FromProto<&'a C> for PartialDecryptionRequest +where + C: HasVahe, + Vahe: VaheBase + 'a, + Vahe::PartialDecCiphertext: FromProto<&'a Vahe, Proto = ShellAhePartialDecCiphertext>, +{ + type Proto = PartialDecryptionRequestProto; + + fn from_proto( + proto: impl AsView, + context: &'a C, + ) -> Result { + let proto = proto.as_view(); + Ok(PartialDecryptionRequest { + partial_dec_ciphertext: Vahe::PartialDecCiphertext::from_proto( + proto.partial_dec_ciphertext(), + context.vahe(), + )?, + }) + } +} + /// We manually implement clone for PartialDecryptionRequest because Vahe is not cloneable. impl Clone for PartialDecryptionRequest { fn clone(self: &PartialDecryptionRequest) -> PartialDecryptionRequest { @@ -65,12 +162,92 @@ pub struct PartialDecryptionResponse { pub partial_decryption: Vahe::PartialDecryption, } +impl<'a, C, Vahe> ToProto<&'a C> for PartialDecryptionResponse +where + C: HasVahe, + Vahe: VaheBase + 'a, + Vahe::PartialDecryption: ToProto<&'a Vahe, Proto = ShellAhePartialDecryption>, +{ + type Proto = PartialDecryptionResponseProto; + + fn to_proto(&self, context: &'a C) -> Result { + Ok(proto!(PartialDecryptionResponseProto { + partial_decryption: self.partial_decryption.to_proto(context.vahe())?, + })) + } +} + +impl<'a, C, Vahe> FromProto<&'a C> for PartialDecryptionResponse +where + C: HasVahe, + Vahe: VaheBase + 'a, + Vahe::PartialDecryption: FromProto<&'a Vahe, Proto = ShellAhePartialDecryption>, +{ + type Proto = PartialDecryptionResponseProto; + + fn from_proto( + proto: impl AsView, + context: &'a C, + ) -> Result { + let proto = proto.as_view(); + Ok(PartialDecryptionResponse { + partial_decryption: Vahe::PartialDecryption::from_proto( + proto.partial_decryption(), + context.vahe(), + )?, + }) + } +} + /// The part of the client message that the verifier needn't check pub struct CiphertextContribution { pub kahe_ciphertext: Kahe::Ciphertext, pub ahe_recover_ciphertext: Vahe::RecoverCiphertext, } +impl<'a, C, Kahe, Vahe> ToProto<&'a C> for CiphertextContribution +where + C: HasKahe + HasVahe, + Kahe: KaheBase + 'a, + Vahe: VaheBase + 'a, + Kahe::Ciphertext: ToProto<&'a Kahe, Proto = ShellKaheCiphertext>, + Vahe::RecoverCiphertext: ToProto<&'a Vahe, Proto = ShellAheRecoverCiphertext>, +{ + type Proto = CiphertextContributionProto; + + fn to_proto(&self, context: &'a C) -> Result { + Ok(proto!(CiphertextContributionProto { + kahe_ciphertext: self.kahe_ciphertext.to_proto(context.kahe())?, + ahe_recover_ciphertext: self.ahe_recover_ciphertext.to_proto(context.vahe())?, + })) + } +} + +impl<'a, C, Kahe, Vahe> FromProto<&'a C> for CiphertextContribution +where + C: HasKahe + HasVahe, + Kahe: KaheBase + 'a, + Vahe: VaheBase + 'a, + Kahe::Ciphertext: FromProto<&'a Kahe, Proto = ShellKaheCiphertext>, + Vahe::RecoverCiphertext: FromProto<&'a Vahe, Proto = ShellAheRecoverCiphertext>, +{ + type Proto = CiphertextContributionProto; + + fn from_proto( + proto: impl AsView, + context: &'a C, + ) -> Result { + let proto = proto.as_view(); + Ok(CiphertextContribution { + kahe_ciphertext: Kahe::Ciphertext::from_proto(proto.kahe_ciphertext(), context.kahe())?, + ahe_recover_ciphertext: Vahe::RecoverCiphertext::from_proto( + proto.ahe_recover_ciphertext(), + context.vahe(), + )?, + }) + } +} + impl Clone for CiphertextContribution { fn clone(&self) -> CiphertextContribution { CiphertextContribution { @@ -88,6 +265,49 @@ pub struct DecryptionRequestContribution { pub nonce: Vec, } +impl<'a, C, Vahe> ToProto<&'a C> for DecryptionRequestContribution +where + C: HasVahe, + Vahe: VaheBase + 'a, + Vahe::PartialDecCiphertext: ToProto<&'a Vahe, Proto = ShellAhePartialDecCiphertext>, + Vahe::EncryptionProof: ToProto, +{ + type Proto = DecryptionRequestContributionProto; + + fn to_proto(&self, context: &'a C) -> Result { + Ok(proto!(DecryptionRequestContributionProto { + partial_dec_ciphertext: self.partial_dec_ciphertext.to_proto(context.vahe())?, + proof: self.proof.to_proto(())?, + nonce: self.nonce.clone(), + })) + } +} + +impl<'a, C, Vahe> FromProto<&'a C> for DecryptionRequestContribution +where + C: HasVahe, + Vahe: VaheBase + 'a, + Vahe::PartialDecCiphertext: FromProto<&'a Vahe, Proto = ShellAhePartialDecCiphertext>, + Vahe::EncryptionProof: FromProto, +{ + type Proto = DecryptionRequestContributionProto; + + fn from_proto( + proto: impl AsView, + context: &'a C, + ) -> Result { + let proto = proto.as_view(); + Ok(DecryptionRequestContribution { + partial_dec_ciphertext: Vahe::PartialDecCiphertext::from_proto( + proto.partial_dec_ciphertext(), + context.vahe(), + )?, + proof: Vahe::EncryptionProof::from_proto(proto.proof(), ())?, + nonce: proto.nonce().to_vec(), + }) + } +} + impl Clone for DecryptionRequestContribution { fn clone(&self) -> DecryptionRequestContribution { DecryptionRequestContribution { diff --git a/willow/tests/BUILD b/willow/tests/BUILD index c12fd4c..7e0ad48 100644 --- a/willow/tests/BUILD +++ b/willow/tests/BUILD @@ -39,6 +39,7 @@ rust_test( "//willow/src/traits:kahe_traits", "//willow/src/traits:messages", "//willow/src/traits:prng_traits", + "//willow/src/traits:proto_serialization_traits", "//willow/src/traits:server_traits", "//willow/src/traits:vahe_traits", "//willow/src/traits:verifier_traits", diff --git a/willow/tests/willow_v1_shell.rs b/willow/tests/willow_v1_shell.rs index 0e1bdc4..4c3e5ff 100644 --- a/willow/tests/willow_v1_shell.rs +++ b/willow/tests/willow_v1_shell.rs @@ -19,8 +19,13 @@ use googletest::prelude::container_eq; use googletest::{gtest, verify_eq, verify_that}; use kahe_shell::ShellKahe; use kahe_traits::KaheBase; +use messages::{ + CiphertextContribution, ClientMessage, DecryptionRequestContribution, DecryptorPublicKeyShare, + PartialDecryptionRequest, PartialDecryptionResponse, +}; use parameters_shell::{create_shell_ahe_config, create_shell_kahe_config}; use prng_traits::SecurePrng; +use proto_serialization_traits::{FromProto, ToProto}; use server_traits::SecureAggregationServer; use single_thread_hkdf::SingleThreadHkdfPrng; use status::StatusErrorCode; @@ -128,6 +133,140 @@ fn encrypt_decrypt_one() -> googletest::Result<()> { ) } +/// Encrypt and decrypt with a single decryptor and single client, using serialization. +#[gtest] +fn encrypt_decrypt_one_serialized() -> googletest::Result<()> { + let default_id = String::from("default"); + let aggregation_config = generate_aggregation_config(default_id.clone(), 16, 10, 1, 1); + let max_number_of_decryptors = aggregation_config.max_number_of_decryptors; + + // Create client. + let kahe = + ShellKahe::new(create_shell_kahe_config(&aggregation_config).unwrap(), CONTEXT_STRING) + .unwrap(); + let vahe = + ShellVahe::new(create_shell_ahe_config(max_number_of_decryptors).unwrap(), CONTEXT_STRING) + .unwrap(); + let seed = SingleThreadHkdfPrng::generate_seed().unwrap(); + let prng = SingleThreadHkdfPrng::create(&seed).unwrap(); + let mut client = WillowV1Client { kahe, vahe, prng }; + + // Create decryptor, which needs its own `vahe` (with same public polynomials + // generated from the seeds) and `prng`. + let vahe = + ShellVahe::new(create_shell_ahe_config(max_number_of_decryptors).unwrap(), CONTEXT_STRING) + .unwrap(); + let seed = SingleThreadHkdfPrng::generate_seed().unwrap(); + let prng = SingleThreadHkdfPrng::create(&seed).unwrap(); + let mut decryptor_state = DecryptorState::default(); + let mut decryptor = WillowV1Decryptor { vahe, prng }; + + // Create server. + let kahe = + ShellKahe::new(create_shell_kahe_config(&aggregation_config).unwrap(), CONTEXT_STRING) + .unwrap(); + let vahe = + ShellVahe::new(create_shell_ahe_config(max_number_of_decryptors).unwrap(), CONTEXT_STRING) + .unwrap(); + let server = WillowV1Server { kahe, vahe }; + let mut server_state = ServerState::default(); + + // Create verifier. + let vahe = + ShellVahe::new(create_shell_ahe_config(max_number_of_decryptors).unwrap(), CONTEXT_STRING) + .unwrap(); + let verifier = WillowV1Verifier { vahe }; + let mut verifier_state = VerifierState::default(); + + // Decryptor generates public key share. + let public_key_share = decryptor.create_public_key_share(&mut decryptor_state).unwrap(); + + // Serialize and deserialize the public key share. + let public_key_share_proto = public_key_share.to_proto(&decryptor.vahe)?; + let public_key_share: DecryptorPublicKeyShare = + DecryptorPublicKeyShare::::from_proto(public_key_share_proto, &server.vahe)?; + + // Server handles the public key share. + server + .handle_decryptor_public_key_share(public_key_share, "Decryptor 0", &mut server_state) + .unwrap(); + + // Server creates the public key. + let public_key = server.create_decryptor_public_key(&server_state).unwrap(); + + // Serialize and deserialize the public key. + let public_key_proto = public_key.to_proto(&server.vahe)?; + let public_key = + messages::DecryptorPublicKey::::from_proto(public_key_proto, &client.vahe)?; + + // Client encrypts. + let client_plaintext = + HashMap::from([(default_id.clone(), vec![1, 2, 3, 4, 5, 6, 7, 8, 7, 6, 5, 4, 3, 2, 1])]); + let nonce = generate_random_nonce(); + let client_message = client + .create_client_message( + &ShellKahe::plaintext_as_slice(&client_plaintext), + &public_key, + &nonce, + ) + .unwrap(); + + // Serialize and deserialize the client message. + let client_message_proto = client_message.to_proto(&client)?; + let client_message: ClientMessage = + ClientMessage::from_proto(client_message_proto, &server)?; + + // The client message is split and handled by the server and verifier. + let (ciphertext_contribution, decryption_request_contribution) = + server.split_client_message(client_message).unwrap(); + + // Serialize and deserialize the contributions. + let ciphertext_contribution_proto = ciphertext_contribution.to_proto(&server)?; + let ciphertext_contribution: CiphertextContribution = + CiphertextContribution::from_proto(ciphertext_contribution_proto, &server)?; + + let decryption_request_contribution_proto = + decryption_request_contribution.to_proto(&server)?; + let decryption_request_contribution: DecryptionRequestContribution = + DecryptionRequestContribution::from_proto( + decryption_request_contribution_proto, + &verifier, + )?; + + verifier.verify_and_include(decryption_request_contribution, &mut verifier_state).unwrap(); + server.handle_ciphertext_contribution(ciphertext_contribution, &mut server_state).unwrap(); + + // Verifier creates the partial decryption request. + let pd_ct = verifier.create_partial_decryption_request(verifier_state).unwrap(); + + // Serialize and deserialize the partial decryption request. + let pd_ct_proto = pd_ct.to_proto(&verifier)?; + let pd_ct: PartialDecryptionRequest = + PartialDecryptionRequest::from_proto(pd_ct_proto, &decryptor)?; + + // Decryptor creates partial decryption. + let pd = decryptor.handle_partial_decryption_request(pd_ct, &decryptor_state).unwrap(); + + // Serialize and deserialize the partial decryption. + let pd_proto = pd.to_proto(&decryptor)?; + let pd: PartialDecryptionResponse = + PartialDecryptionResponse::from_proto(pd_proto, &server)?; + + // Server handles the partial decryption. + server.handle_partial_decryption(pd, &mut server_state).unwrap(); + + // Server recovers the aggregation result. + let aggregation_result = server.recover_aggregation_result(&server_state).unwrap(); + + // Check that the (padded) result matches the client plaintext. + verify_that!(aggregation_result.keys().collect::>(), container_eq([&default_id]))?; + let client_plaintext_length = client_plaintext.get(&default_id).unwrap().len(); + verify_eq!( + aggregation_result.get(&default_id).unwrap()[..client_plaintext_length], + client_plaintext.get(&default_id).unwrap()[..] + ) +} + // Encrypt and decrypt with multiple clients and a single decryptor. #[gtest] fn encrypt_decrypt_multiple_clients() -> googletest::Result<()> {