diff --git a/willow/proto/willow/messages.proto b/willow/proto/willow/messages.proto index 98e8dcc..53c1318 100644 --- a/willow/proto/willow/messages.proto +++ b/willow/proto/willow/messages.proto @@ -22,6 +22,7 @@ import "willow/proto/zk/proofs.proto"; option java_multiple_files = true; option java_outer_classname = "MessagesProto"; + message ClientMessage { ShellKaheCiphertext kahe_ciphertext = 1; ShellAheCiphertext ahe_ciphertext = 2; @@ -47,3 +48,20 @@ message DecryptionRequestContribution { RlweRelationProofListProto proof = 2; bytes nonce = 3; } + +message DecryptorStateProto { + ShellAheSecretKeyShare sk_share = 1; +} + +message ServerStateProto { + map decryptor_public_key_shares = 1; + ShellKaheCiphertext client_sum_kahe = 2; + ShellAheRecoverCiphertext client_sum_ahe_recover = 3; + ShellAhePartialDecryption partial_decryption_sum = 4; +} + +message VerifierStateProto { + ShellAhePartialDecCiphertext partial_dec_ciphertext_sum = 1; + bytes nonce_lower_bound = 2; + bytes nonce_upper_bound = 3; +} diff --git a/willow/src/willow_v1/BUILD b/willow/src/willow_v1/BUILD index 48c0035..cf1aa85 100644 --- a/willow/src/willow_v1/BUILD +++ b/willow/src/willow_v1/BUILD @@ -51,30 +51,76 @@ rust_test( ], ) +rust_test( + name = "willow_v1_decryptor_test", + crate = ":willow_v1_decryptor", + deps = [ + "@crate_index//:googletest", + "//willow/src/shell:parameters_shell", + "//willow/src/shell:single_thread_hkdf", + "//willow/src/shell:vahe_shell", + "//willow/src/traits:ahe_traits", + "//willow/src/traits:decryptor_traits", + "//willow/src/traits:prng_traits", + "//willow/src/traits:proto_serialization_traits", + ], +) + rust_library( name = "willow_v1_decryptor", srcs = [ "decryptor.rs", ], deps = [ + "@protobuf//rust:protobuf", "//shell_wrapper:status", + "//willow/proto/shell:shell_ciphertexts_rust_proto", + "//willow/proto/willow:messages_rust_proto", "//willow/src/traits:ahe_traits", "//willow/src/traits:decryptor_traits", "//willow/src/traits:messages", + "//willow/src/traits:proto_serialization_traits", "//willow/src/traits:vahe_traits", ], ) +rust_test( + name = "willow_v1_server_test", + crate = ":willow_v1_server", + deps = [ + ":willow_v1_client", + ":willow_v1_decryptor", + ":willow_v1_verifier", + "@crate_index//:googletest", + "//willow/src/shell:kahe_shell", + "//willow/src/shell:parameters_shell", + "//willow/src/shell:single_thread_hkdf", + "//willow/src/shell:vahe_shell", + "//willow/src/testing_utils", + "//willow/src/traits:ahe_traits", + "//willow/src/traits:client_traits", + "//willow/src/traits:decryptor_traits", + "//willow/src/traits:prng_traits", + "//willow/src/traits:proto_serialization_traits", + "//willow/src/traits:server_traits", + "//willow/src/traits:verifier_traits", + ], +) + rust_library( name = "willow_v1_server", srcs = [ "server.rs", ], deps = [ + "@protobuf//rust:protobuf", "//shell_wrapper:status", + "//willow/proto/shell:shell_ciphertexts_rust_proto", + "//willow/proto/willow:messages_rust_proto", "//willow/src/traits:ahe_traits", "//willow/src/traits:kahe_traits", "//willow/src/traits:messages", + "//willow/src/traits:proto_serialization_traits", "//willow/src/traits:server_traits", "//willow/src/traits:vahe_traits", ], @@ -86,10 +132,14 @@ rust_library( "verifier.rs", ], deps = [ + "@protobuf//rust:protobuf", "//shell_wrapper:status", + "//willow/proto/shell:shell_ciphertexts_rust_proto", + "//willow/proto/willow:messages_rust_proto", "//willow/src/traits:ahe_traits", "//willow/src/traits:kahe_traits", "//willow/src/traits:messages", + "//willow/src/traits:proto_serialization_traits", "//willow/src/traits:vahe_traits", "//willow/src/traits:verifier_traits", ], @@ -105,6 +155,7 @@ rust_test( "@crate_index//:googletest", "//shell_wrapper:status_matchers_rs", "//willow/src/shell:kahe_shell", + "//willow/src/shell:parameters_shell", "//willow/src/shell:single_thread_hkdf", "//willow/src/shell:vahe_shell", "//willow/src/testing_utils", @@ -114,6 +165,7 @@ rust_test( "//willow/src/traits:decryptor_traits", "//willow/src/traits:kahe_traits", "//willow/src/traits:prng_traits", + "//willow/src/traits:proto_serialization_traits", "//willow/src/traits:server_traits", "//willow/src/traits:vahe_traits", ], diff --git a/willow/src/willow_v1/decryptor.rs b/willow/src/willow_v1/decryptor.rs index 7e76977..9301df0 100644 --- a/willow/src/willow_v1/decryptor.rs +++ b/willow/src/willow_v1/decryptor.rs @@ -15,6 +15,11 @@ use ahe_traits::{AheKeygen, PartialDec}; use decryptor_traits::SecureAggregationDecryptor; use messages::{DecryptorPublicKeyShare, PartialDecryptionRequest, PartialDecryptionResponse}; +use messages_rust_proto::DecryptorStateProto; +use proto_serialization_traits::{FromProto, ToProto}; +use protobuf::{proto, AsView}; +use shell_ciphertexts_rust_proto::ShellAheSecretKeyShare; +use status::StatusError; use vahe_traits::{EncryptVerify, HasVahe, VaheBase}; /// Lightweight decryptor directly exposing KAHE/VAHE types. It verifies only the client proofs, @@ -41,6 +46,45 @@ impl Default for DecryptorState { } } +impl<'a, C, Vahe> ToProto<&'a C> for DecryptorState +where + C: HasVahe, + Vahe: VaheBase + 'a, + Vahe::SecretKeyShare: ToProto<&'a Vahe, Proto = ShellAheSecretKeyShare>, +{ + type Proto = DecryptorStateProto; + + fn to_proto(&self, context: &'a C) -> Result { + let mut proto = DecryptorStateProto::new(); + if let Some(sk) = &self.sk_share { + proto.set_sk_share(sk.to_proto(context.vahe())?); + } + Ok(proto) + } +} + +impl<'a, C, Vahe> FromProto<&'a C> for DecryptorState +where + C: HasVahe, + Vahe: VaheBase + 'a, + Vahe::SecretKeyShare: FromProto<&'a Vahe, Proto = ShellAheSecretKeyShare>, +{ + type Proto = DecryptorStateProto; + + fn from_proto( + proto: impl AsView, + context: &'a C, + ) -> Result { + let proto = proto.as_view(); + let sk_share = if proto.has_sk_share() { + Some(Vahe::SecretKeyShare::from_proto(proto.sk_share(), context.vahe())?) + } else { + None + }; + Ok(DecryptorState { sk_share }) + } +} + /// Implementation of the `SecureAggregationDecryptor` trait for the generic /// KAHE/AHE decryptor, using WillowCommon as the common types (e.g. protocol /// messages are directly the AHE public key and ciphertexts). @@ -82,3 +126,44 @@ where Ok(PartialDecryptionResponse { partial_decryption: pd }) } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::{DecryptorState, WillowV1Decryptor}; + use ahe_traits::AheBase; + use decryptor_traits::SecureAggregationDecryptor; + use googletest::{gtest, verify_true}; + use parameters_shell::create_shell_ahe_config; + use prng_traits::SecurePrng; + use proto_serialization_traits::{FromProto, ToProto}; + use single_thread_hkdf::SingleThreadHkdfPrng; + use vahe_shell::ShellVahe; + + const CONTEXT_STRING: &[u8] = b"testing_context_string"; + + #[gtest] + fn decryptor_state_serialization_roundtrip() -> googletest::Result<()> { + let vahe = ShellVahe::new(create_shell_ahe_config(1).unwrap(), CONTEXT_STRING).unwrap(); + let seed = SingleThreadHkdfPrng::generate_seed()?; + let prng = SingleThreadHkdfPrng::create(&seed)?; + let mut decryptor = WillowV1Decryptor { vahe, prng }; + let mut decryptor_state = DecryptorState::default(); + + // Check empty state serialization. + let decryptor_state_proto = decryptor_state.to_proto(&decryptor)?; + let decryptor_state_roundtrip = + DecryptorState::from_proto(decryptor_state_proto, &decryptor)?; + verify_true!(decryptor_state_roundtrip.sk_share.is_none())?; + + // Check populated state serialization. + decryptor.create_public_key_share(&mut decryptor_state)?; + verify_true!(decryptor_state.sk_share.is_some())?; + let decryptor_state_proto = decryptor_state.to_proto(&decryptor)?; + let decryptor_state_roundtrip = + DecryptorState::from_proto(decryptor_state_proto, &decryptor)?; + verify_true!(decryptor_state_roundtrip.sk_share.is_some())?; + + Ok(()) + } +} diff --git a/willow/src/willow_v1/server.rs b/willow/src/willow_v1/server.rs index 6a0b39b..3ba55a6 100644 --- a/willow/src/willow_v1/server.rs +++ b/willow/src/willow_v1/server.rs @@ -18,12 +18,21 @@ use messages::{ CiphertextContribution, ClientMessage, DecryptionRequestContribution, DecryptorPublicKey, DecryptorPublicKeyShare, PartialDecryptionResponse, }; +use messages_rust_proto::ServerStateProto; +use proto_serialization_traits::{FromProto, ToProto}; +use protobuf::AsView; use server_traits::SecureAggregationServer; +use shell_ciphertexts_rust_proto::{ + ShellAhePartialDecryption, ShellAhePublicKeyShare, ShellAheRecoverCiphertext, + ShellKaheCiphertext, +}; +use status::StatusError; use std::collections::HashMap; use vahe_traits::{EncryptVerify, HasVahe, Recover, VaheBase}; -/// The server struct, containing a WillowCommon instance. Only the clients messages are verified, -/// not the key generation or partial decryptions. +/// Implements the `server` role in the Willow protocol. This includes aggregating public key shares +/// from the decryptors, aggregating client ciphertexts, and recovering the aggregation result after +/// receiving partial decryption responses from the decryptors. pub struct WillowV1Server { pub kahe: Kahe, pub vahe: Vahe, @@ -73,6 +82,92 @@ impl Clone for ServerState ToProto<&'a C> for ServerState +where + C: HasKahe + HasVahe, + Kahe: KaheBase + 'a, + Vahe: VaheBase + PartialDec + 'a, + Kahe::Ciphertext: ToProto<&'a Kahe, Proto = ShellKaheCiphertext>, // TODO: Rename protos to be generic once cl/836370582 has landed. + Vahe::RecoverCiphertext: ToProto<&'a Vahe, Proto = ShellAheRecoverCiphertext>, + Vahe::PartialDecryption: ToProto<&'a Vahe, Proto = ShellAhePartialDecryption>, + Vahe::PublicKeyShare: ToProto<&'a Vahe, Proto = ShellAhePublicKeyShare>, +{ + type Proto = ServerStateProto; + + fn to_proto(&self, context: &'a C) -> Result { + let mut proto = ServerStateProto::new(); + + if let Some((kahe, ahe)) = &self.client_sum { + proto.set_client_sum_kahe(kahe.to_proto(context.kahe())?); + proto.set_client_sum_ahe_recover(ahe.to_proto(context.vahe())?); + } + + for (k, v) in &self.decryptor_public_key_shares { + proto.decryptor_public_key_shares_mut().insert(k.as_str(), v.to_proto(context.vahe())?); + } + + if let Some(pd) = &self.partial_decryption_sum { + proto.set_partial_decryption_sum(pd.to_proto(context.vahe())?); + } + + Ok(proto) + } +} + +impl<'a, C, Kahe, Vahe> FromProto<&'a C> for ServerState +where + C: HasKahe + HasVahe, + Kahe: KaheBase + 'a, + Vahe: VaheBase + PartialDec + 'a, + Kahe::Ciphertext: FromProto<&'a Kahe, Proto = ShellKaheCiphertext>, + Vahe::RecoverCiphertext: FromProto<&'a Vahe, Proto = ShellAheRecoverCiphertext>, + Vahe::PartialDecryption: FromProto<&'a Vahe, Proto = ShellAhePartialDecryption>, + Vahe::PublicKeyShare: FromProto<&'a Vahe, Proto = ShellAhePublicKeyShare>, +{ + type Proto = ServerStateProto; + + fn from_proto( + proto: impl AsView, + context: &'a C, + ) -> Result { + let proto = proto.as_view(); + + let client_sum = if proto.has_client_sum_kahe() && proto.has_client_sum_ahe_recover() { + Some(( + Kahe::Ciphertext::from_proto(proto.client_sum_kahe(), context.kahe())?, + Vahe::RecoverCiphertext::from_proto( + proto.client_sum_ahe_recover(), + context.vahe(), + )?, + )) + } else if !proto.has_client_sum_kahe() && !proto.has_client_sum_ahe_recover() { + None + } else { + return Err(status::invalid_argument( + "ServerStateProto must have both or neither of client_sum_kahe and \ + client_sum_ahe_recover", + )); + }; + + let mut decryptor_public_key_shares = HashMap::new(); + for (k, v) in proto.decryptor_public_key_shares() { + decryptor_public_key_shares + .insert(k.to_string(), Vahe::PublicKeyShare::from_proto(v, context.vahe())?); + } + + let partial_decryption_sum = if proto.has_partial_decryption_sum() { + Some(Vahe::PartialDecryption::from_proto( + proto.partial_decryption_sum(), + context.vahe(), + )?) + } else { + None + }; + + Ok(ServerState { decryptor_public_key_shares, client_sum, partial_decryption_sum }) + } +} + impl SecureAggregationServer for WillowV1Server where Vahe: EncryptVerify + PartialDec + Recover, @@ -258,3 +353,126 @@ where Ok(merged_server_state) } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ServerState, WillowV1Server}; + use ahe_traits::AheBase; + use client_traits::SecureAggregationClient; + use decryptor_traits::SecureAggregationDecryptor; + use googletest::{gtest, verify_true}; + use kahe_shell::ShellKahe; + use parameters_shell::{create_shell_ahe_config, create_shell_kahe_config}; + use prng_traits::SecurePrng; + use proto_serialization_traits::{FromProto, ToProto}; + use server_traits::SecureAggregationServer; + use single_thread_hkdf::SingleThreadHkdfPrng; + use std::collections::HashMap; + use testing_utils::{generate_aggregation_config, generate_random_nonce}; + use vahe_shell::ShellVahe; + use verifier_traits::SecureAggregationVerifier; + use willow_v1_client::WillowV1Client; + use willow_v1_decryptor::{DecryptorState, WillowV1Decryptor}; + use willow_v1_verifier::{VerifierState, WillowV1Verifier}; + + const CONTEXT_STRING: &[u8] = b"testing_context_string"; + const DEFAULT_VECTOR_ID: &str = "default"; + + #[gtest] + fn server_state_serialization_roundtrip() -> googletest::Result<()> { + let aggregation_config = + generate_aggregation_config(DEFAULT_VECTOR_ID.to_string(), 16, 10, 1, 1); + let max_number_of_decryptors = aggregation_config.max_number_of_decryptors; + + // Create client. + let kahe = + ShellKahe::new(create_shell_kahe_config(&aggregation_config).unwrap(), CONTEXT_STRING) + .unwrap(); + let vahe = ShellVahe::new( + create_shell_ahe_config(max_number_of_decryptors).unwrap(), + CONTEXT_STRING, + ) + .unwrap(); + let seed = SingleThreadHkdfPrng::generate_seed()?; + let prng = SingleThreadHkdfPrng::create(&seed)?; + let mut client = WillowV1Client { kahe, vahe, prng }; + + // Create decryptor. + let vahe = ShellVahe::new( + create_shell_ahe_config(max_number_of_decryptors).unwrap(), + CONTEXT_STRING, + ) + .unwrap(); + let seed = SingleThreadHkdfPrng::generate_seed()?; + let prng = SingleThreadHkdfPrng::create(&seed)?; + let mut decryptor_state = DecryptorState::default(); + let mut decryptor = WillowV1Decryptor { vahe, prng }; + + // Create server. + let kahe = + ShellKahe::new(create_shell_kahe_config(&aggregation_config).unwrap(), CONTEXT_STRING) + .unwrap(); + let vahe = ShellVahe::new( + create_shell_ahe_config(max_number_of_decryptors).unwrap(), + CONTEXT_STRING, + ) + .unwrap(); + let server = WillowV1Server { kahe, vahe }; + let mut server_state = ServerState::default(); + + // Create verifier. + let vahe = ShellVahe::new( + create_shell_ahe_config(max_number_of_decryptors).unwrap(), + CONTEXT_STRING, + ) + .unwrap(); + let verifier = WillowV1Verifier { vahe }; + let mut verifier_state = VerifierState::default(); + + // Check empty state serialization + let server_state_proto = server_state.to_proto(&server)?; + let server_state_roundtrip = ServerState::from_proto(server_state_proto, &server)?; + verify_true!(server_state_roundtrip.decryptor_public_key_shares.is_empty())?; + verify_true!(server_state_roundtrip.client_sum.is_none())?; + verify_true!(server_state_roundtrip.partial_decryption_sum.is_none())?; + + // Populate server state. + let public_key_share = decryptor.create_public_key_share(&mut decryptor_state)?; + server.handle_decryptor_public_key_share( + public_key_share, + "Decryptor 0", + &mut server_state, + )?; + let public_key = server.create_decryptor_public_key(&server_state)?; + let client_plaintext = HashMap::from([( + DEFAULT_VECTOR_ID.to_string(), + vec![1, 2, 3, 4, 5, 6, 7, 8, 7, 6, 5, 4, 3, 2, 1], + )]); + let nonce = generate_random_nonce(); + let client_message = client.create_client_message( + &ShellKahe::plaintext_as_slice(&client_plaintext), + &public_key, + &nonce, + )?; + let (ciphertext_contribution, decryption_request_contribution) = + server.split_client_message(client_message)?; + verifier.verify_and_include(decryption_request_contribution, &mut verifier_state)?; + server.handle_ciphertext_contribution(ciphertext_contribution, &mut server_state)?; + let pd_ct = verifier.create_partial_decryption_request(verifier_state)?; + let pd = decryptor.handle_partial_decryption_request(pd_ct, &decryptor_state)?; + server.handle_partial_decryption(pd, &mut server_state)?; + + // Check populated state serialization + verify_true!(!server_state.decryptor_public_key_shares.is_empty())?; + verify_true!(server_state.client_sum.is_some())?; + verify_true!(server_state.partial_decryption_sum.is_some())?; + let server_state_proto = server_state.to_proto(&server)?; + let server_state_roundtrip = ServerState::from_proto(server_state_proto, &server)?; + verify_true!(!server_state_roundtrip.decryptor_public_key_shares.is_empty())?; + verify_true!(server_state_roundtrip.client_sum.is_some())?; + verify_true!(server_state_roundtrip.partial_decryption_sum.is_some())?; + + Ok(()) + } +} diff --git a/willow/src/willow_v1/verifier.rs b/willow/src/willow_v1/verifier.rs index 1107c54..19c4e6c 100644 --- a/willow/src/willow_v1/verifier.rs +++ b/willow/src/willow_v1/verifier.rs @@ -13,6 +13,11 @@ // limitations under the License. use messages::{DecryptionRequestContribution, PartialDecryptionRequest}; +use messages_rust_proto::VerifierStateProto; +use proto_serialization_traits::{FromProto, ToProto}; +use protobuf::{proto, AsView}; +use shell_ciphertexts_rust_proto::ShellAhePartialDecCiphertext; +use status::StatusError; use std::fmt::Debug; use vahe_traits::{EncryptVerify, HasVahe, VaheBase}; use verifier_traits::SecureAggregationVerifier; @@ -97,6 +102,61 @@ impl Clone for VerifierState { } } +impl<'a, C, Vahe> ToProto<&'a C> for VerifierState +where + C: HasVahe, + Vahe: VaheBase + 'a, + Vahe::PartialDecCiphertext: ToProto<&'a Vahe, Proto = ShellAhePartialDecCiphertext>, +{ + type Proto = VerifierStateProto; + + fn to_proto(&self, context: &'a C) -> Result { + if let Some(state) = &self.0 { + Ok(proto!(VerifierStateProto { + partial_dec_ciphertext_sum: state + .partial_dec_ciphertext_sum + .to_proto(context.vahe())?, + nonce_lower_bound: state.nonce_bounds.0.clone(), + nonce_upper_bound: state.nonce_bounds.1.clone(), + })) + } else { + Ok(proto!(VerifierStateProto {})) + } + } +} + +impl<'a, C, Vahe> FromProto<&'a C> for VerifierState +where + C: HasVahe, + Vahe: VaheBase + 'a, + Vahe::PartialDecCiphertext: FromProto<&'a Vahe, Proto = ShellAhePartialDecCiphertext>, +{ + type Proto = VerifierStateProto; + + fn from_proto( + proto: impl AsView, + context: &'a C, + ) -> Result { + let proto = proto.as_view(); + if proto.has_partial_dec_ciphertext_sum() { + let state = VerifierState(Some(NonemptyVerifierState { + partial_dec_ciphertext_sum: Vahe::PartialDecCiphertext::from_proto( + proto.partial_dec_ciphertext_sum(), + context.vahe(), + )?, + nonce_bounds: ( + proto.nonce_lower_bound().to_vec(), + proto.nonce_upper_bound().to_vec(), + ), + })); + state.validate()?; + Ok(state) + } else { + Ok(VerifierState(None)) + } + } +} + impl SecureAggregationVerifier for WillowV1Verifier where Vahe: EncryptVerify, @@ -210,9 +270,10 @@ mod tests { }; use kahe_shell::ShellKahe; use kahe_traits::KaheBase; + use parameters_shell::{create_shell_ahe_config, create_shell_kahe_config}; use prng_traits::SecurePrng; + use proto_serialization_traits::{FromProto, ToProto}; use server_traits::SecureAggregationServer; - use shell_testing_parameters::{make_ahe_config, make_kahe_config}; use single_thread_hkdf::SingleThreadHkdfPrng; use status_matchers_rs::status_is; use std::collections::HashMap; @@ -233,30 +294,51 @@ mod tests { fn setup() -> Result { let aggregation_config = generate_aggregation_config(DEFAULT_VECTOR_ID.to_string(), 16, 10, 1, 1); + let max_number_of_decryptors = aggregation_config.max_number_of_decryptors; // Create client. - let kahe = ShellKahe::new(make_kahe_config(&aggregation_config), CONTEXT_STRING).unwrap(); - let vahe = ShellVahe::new(make_ahe_config(), CONTEXT_STRING).unwrap(); + let kahe = + ShellKahe::new(create_shell_kahe_config(&aggregation_config).unwrap(), CONTEXT_STRING) + .unwrap(); + let vahe = ShellVahe::new( + create_shell_ahe_config(max_number_of_decryptors).unwrap(), + CONTEXT_STRING, + ) + .unwrap(); let seed = SingleThreadHkdfPrng::generate_seed()?; let prng = SingleThreadHkdfPrng::create(&seed)?; let mut client = WillowV1Client { kahe, vahe, prng }; // Create decryptor, which needs its own `vahe` (with same public polynomials // generated from the seeds) and `prng`. - let vahe = ShellVahe::new(make_ahe_config(), CONTEXT_STRING).unwrap(); + let vahe = ShellVahe::new( + create_shell_ahe_config(max_number_of_decryptors).unwrap(), + CONTEXT_STRING, + ) + .unwrap(); let seed = SingleThreadHkdfPrng::generate_seed()?; let prng = SingleThreadHkdfPrng::create(&seed)?; let mut decryptor_state = DecryptorState::default(); let mut decryptor = WillowV1Decryptor { vahe, prng }; // Create server. - let kahe = ShellKahe::new(make_kahe_config(&aggregation_config), CONTEXT_STRING).unwrap(); - let vahe = ShellVahe::new(make_ahe_config(), CONTEXT_STRING).unwrap(); + let kahe = + ShellKahe::new(create_shell_kahe_config(&aggregation_config).unwrap(), CONTEXT_STRING) + .unwrap(); + let vahe = ShellVahe::new( + create_shell_ahe_config(max_number_of_decryptors).unwrap(), + CONTEXT_STRING, + ) + .unwrap(); let server = WillowV1Server { kahe, vahe }; let mut server_state = ServerState::default(); // Create verifier. - let vahe = ShellVahe::new(make_ahe_config(), CONTEXT_STRING).unwrap(); + let vahe = ShellVahe::new( + create_shell_ahe_config(max_number_of_decryptors).unwrap(), + CONTEXT_STRING, + ) + .unwrap(); let verifier = WillowV1Verifier { vahe }; // Decryptor generates public key share. @@ -399,4 +481,30 @@ mod tests { .with_message(contains_substring("at least one client message "))) ) } + + #[gtest] + fn verifier_state_serialization_roundtrip() -> googletest::Result<()> { + let setup = setup()?; + let mut verifier_state = VerifierState::default(); + + // Check empty state serialization + let verifier_state_proto = verifier_state.to_proto(&setup.verifier)?; + let verifier_state_roundtrip = + VerifierState::from_proto(verifier_state_proto, &setup.verifier)?; + verify_true!(verifier_state_roundtrip.0.is_none())?; + + // Check populated state serialization + setup.verifier.verify_and_include( + setup.decryption_request_contribution.clone(), + &mut verifier_state, + )?; + let nonce_bounds_before = verifier_state.0.as_ref().unwrap().nonce_bounds.clone(); + let verifier_state_proto = verifier_state.to_proto(&setup.verifier)?; + let verifier_state_roundtrip = + VerifierState::from_proto(verifier_state_proto, &setup.verifier)?; + verify_true!(verifier_state_roundtrip.0.is_some())?; + verify_eq!(verifier_state_roundtrip.0.as_ref().unwrap().nonce_bounds, nonce_bounds_before)?; + + Ok(()) + } }