diff --git a/willow/benches/shell_benchmarks.rs b/willow/benches/shell_benchmarks.rs index ff11bd9..1a64813 100644 --- a/willow/benches/shell_benchmarks.rs +++ b/willow/benches/shell_benchmarks.rs @@ -324,8 +324,10 @@ fn setup_server_recover_aggregation_result(args: &Args) -> ServerRecoverInputs { let pd_ct = inputs.verifier.create_partial_decryption_request(inputs.verifier_state).unwrap(); // Decryptor creates partial decryption. - let pd = - inputs.decryptor.handle_partial_decryption_request(pd_ct, &inputs.decryptor_state).unwrap(); + let pd = inputs + .decryptor + .handle_partial_decryption_request(pd_ct, &mut inputs.decryptor_state) + .unwrap(); // Server handles the partial decryption. inputs.server.handle_partial_decryption(pd, &mut inputs.server_state).unwrap(); @@ -384,7 +386,7 @@ fn run_decryptor_partial_decryption(inputs: &mut DecryptorInputs) { .decryptor .handle_partial_decryption_request( black_box(inputs.partial_decryption_request.clone()), - black_box(&inputs.decryptor_state), + black_box(&mut inputs.decryptor_state), ) .unwrap(); let _ = black_box(res); // Prevent optimization. diff --git a/willow/proto/willow/BUILD b/willow/proto/willow/BUILD index f5dabb5..f543f23 100644 --- a/willow/proto/willow/BUILD +++ b/willow/proto/willow/BUILD @@ -109,6 +109,7 @@ proto_library( name = "messages_proto", srcs = ["messages.proto"], deps = [ + ":aggregation_config_proto", "//willow/proto/shell:shell_ciphertexts_proto", "//willow/proto/zk:proofs_proto", ], diff --git a/willow/proto/willow/messages.proto b/willow/proto/willow/messages.proto index a9375de..8e3c66d 100644 --- a/willow/proto/willow/messages.proto +++ b/willow/proto/willow/messages.proto @@ -17,6 +17,7 @@ syntax = "proto3"; package secure_aggregation.willow; import "willow/proto/shell/ciphertexts.proto"; +import "willow/proto/willow/aggregation_config.proto"; import "willow/proto/zk/proofs.proto"; option java_multiple_files = true; @@ -32,10 +33,13 @@ message ClientMessage { message PartialDecryptionRequest { ShellAhePartialDecCiphertext partial_dec_ciphertext = 1; + AggregationConfigProto aggregation_config = 2; } message PartialDecryptionResponse { ShellAhePartialDecryption partial_decryption = 1; + // Noise contribution to the final aggregated result. + CiphertextContribution dp_ciphertext_contribution = 2; } message CiphertextContribution { @@ -51,6 +55,7 @@ message DecryptionRequestContribution { message DecryptorStateProto { ShellAheSecretKeyShare sk_share = 1; + AggregationConfigProto aggregation_config = 2; } message ServerStateProto { diff --git a/willow/protocol/BUILD b/willow/protocol/BUILD index 9e74a04..e6df2de 100644 --- a/willow/protocol/BUILD +++ b/willow/protocol/BUILD @@ -29,11 +29,13 @@ rust_library( deps = [ "@protobuf//rust:protobuf", "//ffi_utils:status", + "//willow/api:aggregation_config", "//willow/api:proto_serialization_traits", "//willow/crypto:ahe_traits", "//willow/crypto:kahe_traits", "//willow/crypto:vahe_traits", "//willow/proto/shell:shell_ciphertexts_rust_proto", + "//willow/proto/willow:aggregation_config_rust_proto", "//willow/proto/willow:messages_rust_proto", "//willow/proto/zk:proofs_rust_proto", ], @@ -74,6 +76,8 @@ rust_library( deps = [ ":messages", "//ffi_utils:status", + "//willow/crypto:ahe_traits", + "//willow/crypto:kahe_traits", "//willow/crypto:vahe_traits", ], ) @@ -130,8 +134,10 @@ rust_test( "@crate_index//:googletest", "//willow/api:proto_serialization_traits", "//willow/crypto:ahe_traits", + "//willow/crypto:shell_kahe", "//willow/crypto:shell_parameters", "//willow/crypto:shell_vahe", + "//willow/testing_utils", ], ) @@ -145,9 +151,13 @@ rust_library( ":messages", "@protobuf//rust:protobuf", "//ffi_utils:status", + "//willow/api:aggregation_config", "//willow/api:proto_serialization_traits", "//willow/crypto:ahe_traits", + "//willow/crypto:kahe_traits", "//willow/crypto:prng_traits", + "//willow/crypto:shell_kahe", + "//willow/crypto:shell_parameters", "//willow/crypto:vahe_traits", "//willow/proto/shell:shell_ciphertexts_rust_proto", "//willow/proto/willow:messages_rust_proto", diff --git a/willow/protocol/decryptor_traits.rs b/willow/protocol/decryptor_traits.rs index b0d7dae..334ac65 100644 --- a/willow/protocol/decryptor_traits.rs +++ b/willow/protocol/decryptor_traits.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Google LLC +// Copyright 2026 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,7 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -use messages::{DecryptorPublicKeyShare, PartialDecryptionRequest, PartialDecryptionResponse}; +use ahe_traits::AheBase; +use kahe_traits::KaheBase; +use messages::{ + DecryptorPublicKey, DecryptorPublicKeyShare, PartialDecryptionRequest, + PartialDecryptionResponse, RecoveryRequest, RecoveryResponse, SetupContribution, + VerifyKeyContributionsRequest, +}; use status::StatusError; use vahe_traits::HasVahe; @@ -28,11 +34,136 @@ pub trait SecureAggregationDecryptor: HasVahe { decryptor_state: &mut Self::DecryptorState, ) -> Result::Vahe>, StatusError>; + type Kahe: KaheBase; + /// Handles a partial decryption request received from the Server. Returns a /// partial decryption to the Server. fn handle_partial_decryption_request( &self, partial_decryption_request: PartialDecryptionRequest<::Vahe>, - decryptor_state: &Self::DecryptorState, - ) -> Result::Vahe>, StatusError>; + decryptor_state: &mut Self::DecryptorState, + ) -> Result::Vahe>, StatusError>; +} + +/// Trait for reputable/non-recoverable decryptors (e.g. TEEs) in a multi-decryptor committee. +pub trait SecureAggregationBaseMultiDecryptor: HasVahe { + /// The state held by the Decryptor between messages. + type DecryptorState: Default; + + /// Creates a public key share, a ZK proof of knowledge of the secret key, + /// and encrypted shares of the randomness used for key generation. + /// + /// The randomness shares are encrypted for other committee members. + fn create_setup_contribution( + &self, + decryptor_state: &mut Self::DecryptorState, + ) -> Result, StatusError>; + + /// Creates a public key share to be sent to the Server, updating the + /// decryptor state. + fn create_public_key_share( + &self, + decryptor_state: &mut Self::DecryptorState, + ) -> Result::Vahe>, StatusError>; + + /// Handles a partial decryption request received from the Server. Returns a + /// partial decryption to the Server. + fn handle_partial_decryption_request( + &self, + partial_decryption_request: PartialDecryptionRequest<::Vahe>, + kahe: Option<&Kahe>, + decryptor_state: &mut Self::DecryptorState, + ) -> Result::Vahe>, StatusError>; +} + +/// Trait for the reputable decryptors in a multi-decryptor committee. +/// +/// Reputable decryptors are assumed to be stable and do not share their +/// randomness for recovery. +pub trait SecureAggregationReputableDecryptor: SecureAggregationBaseMultiDecryptor { + /// Verifies the ZK proofs of knowledge of the secret key for all public key + /// shares, and returns the aggregated public key. Calling code should sign + /// the aggregated public key for the aggregation. + fn verify_and_aggregate_key_contributions( + &self, + request: VerifyKeyContributionsRequest<::Vahe>, + ) -> Result::Vahe>, StatusError>; +} + +/// Trait for the non-reputable decryptors in a multi-decryptor committee. +pub trait SecureAggregationNonReputableMultiDecryptor: SecureAggregationBaseMultiDecryptor { + /// Handles a request to decrypt shares of dropped decryptors. + /// + /// The decryptor should verify they are not being asked to decrypt more than + /// the allowed threshold of shares. + fn handle_recovery_request( + &self, + recovery_request: RecoveryRequest, + decryptor_state: &mut Self::DecryptorState, + ) -> Result; +} + +/// Trait for the protocol coordinator managing the multi-decryptor committee. +/// +/// The coordinator manages protocol flow and aggregates messages from all +/// decryptors. The coordinator itself does not contribute to the public key. +/// +/// The coordinator is not trusted for security at all, it does not hold any secrets and the +/// protocol is secure even if it behaves arbitrarily. +/// As such it need not be run on secure hardware, however it does need access to the +/// cryptographic library for most of these functions. +pub trait SecureAggregationCoordinator: HasVahe { + /// The state held by the Coordinator between protocol rounds. + type CoordinatorState: Default; + + /// Stores setup contributions from all decryptors and creates a request to verify the + /// contributions. + fn handle_setup_submissions( + &self, + non_reputable_contributions: Vec>, + reputable_contributions: Vec>, + coordinator_state: &mut Self::CoordinatorState, + ) -> Result, StatusError>; + + /// Combines the verifier's ciphertext half with the accumulated AHE components. + /// + /// The result should be forwarded to decryptors for partial decryption. + fn prepare_decryption_request( + &self, + verifier_ciphertext: &::PartialDecCiphertext, + coordinator_state: &mut Self::CoordinatorState, + ) -> Result, StatusError>; + + /// Accumulates partial decryptions from responding decryptors. + fn aggregate_partial_decryptions( + &self, + partial_responses: Vec>, + kahe: Option<&Kahe>, + coordinator_state: &mut Self::CoordinatorState, + ) -> Result<(), StatusError>; + + /// Creates recovery requests for surviving decryptors to decrypt shares of dropped client + /// decryptors. + /// + /// If the vector is empty, there are no dropped decryptors to recover and + /// recover_dropped_decryptors can be called immediately with an empty vector. + fn create_recovery_requests( + &self, + coordinator_state: &mut Self::CoordinatorState, + ) -> Result, StatusError>; + + /// Finalizes the decryption by recovering randomness from dropped decryptors. + /// + /// Uses decrypted shares from survivors to simulate missing partial decryptions. + fn recover_dropped_decryptors( + &self, + recovery_responses: Vec, + coordinator_state: &mut Self::CoordinatorState, + ) -> Result<(), StatusError>; + + /// Returns the resulting plaintext from the final decryption, if available. + fn get_plaintext( + &self, + coordinator_state: &mut Self::CoordinatorState, + ) -> Result<::Plaintext, StatusError>; } diff --git a/willow/protocol/messages.rs b/willow/protocol/messages.rs index 80cae44..0e19133 100644 --- a/willow/protocol/messages.rs +++ b/willow/protocol/messages.rs @@ -1,4 +1,4 @@ -// Copyright 2025 Google LLC +// Copyright 2026 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +use aggregation_config::AggregationConfig; use ahe_traits::AheBase; use kahe_traits::{HasKahe, KaheBase}; use messages_rust_proto::{ @@ -104,6 +105,8 @@ impl Clone for ClientMessage { // Partial decryption request is an aggregated AHE ciphertext. pub struct PartialDecryptionRequest { pub partial_dec_ciphertext: Vahe::PartialDecCiphertext, + // Only set if the decryptor is adding DP noise. + pub aggregation_config: Option, } impl<'a, C, Vahe> ToProto<&'a C> for PartialDecryptionRequest @@ -115,9 +118,13 @@ where type Proto = PartialDecryptionRequestProto; fn to_proto(&self, context: &'a C) -> Result { - Ok(proto!(PartialDecryptionRequestProto { + let mut proto = proto!(PartialDecryptionRequestProto { partial_dec_ciphertext: self.partial_dec_ciphertext.to_proto(context.vahe())?, - })) + }); + if let Some(config) = &self.aggregation_config { + proto.set_aggregation_config(config.to_proto(())?); + } + Ok(proto) } } @@ -134,19 +141,28 @@ where context: &'a C, ) -> Result { let proto = proto.as_view(); + let aggregation_config = if proto.has_aggregation_config() { + Some(AggregationConfig::from_proto(proto.aggregation_config(), ())?) + } else { + None + }; Ok(PartialDecryptionRequest { partial_dec_ciphertext: Vahe::PartialDecCiphertext::from_proto( proto.partial_dec_ciphertext(), context.vahe(), )?, + aggregation_config, }) } } /// We manually implement clone for PartialDecryptionRequest because Vahe is not cloneable. impl Clone for PartialDecryptionRequest { - fn clone(self: &PartialDecryptionRequest) -> PartialDecryptionRequest { - PartialDecryptionRequest { partial_dec_ciphertext: self.partial_dec_ciphertext.clone() } + fn clone(&self) -> PartialDecryptionRequest { + PartialDecryptionRequest { + partial_dec_ciphertext: self.partial_dec_ciphertext.clone(), + aggregation_config: self.aggregation_config.clone(), + } } } @@ -154,34 +170,61 @@ impl Debug for PartialDecryptionRequest { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> { f.debug_struct("PartialDecryptionRequest") .field("partial_dec_ciphertext", &"(OMITTED)") + .field("aggregation_config", &self.aggregation_config) .finish() } } -pub struct PartialDecryptionResponse { +pub struct PartialDecryptionResponse { pub partial_decryption: Vahe::PartialDecryption, + // This contribution just contains encrypted DP noise. The server will be forced to include this + // contribution in the result because the randomness of the AHE encryption was included in the + // partial decryption request. + pub dp_ciphertext_contribution: Option>, } -impl<'a, C, Vahe> ToProto<&'a C> for PartialDecryptionResponse +impl<'a, C, Kahe, Vahe> ToProto<(&'a C, Option<&'a Kahe>)> for PartialDecryptionResponse where C: HasVahe, + Kahe: KaheBase + 'a, Vahe: VaheBase + 'a, Vahe::PartialDecryption: ToProto<&'a Vahe, Proto = ShellAhePartialDecryption>, + Kahe::Ciphertext: ToProto<&'a Kahe, Proto = ShellKaheCiphertext>, + Vahe::RecoverCiphertext: ToProto<&'a Vahe, Proto = ShellAheRecoverCiphertext>, { type Proto = PartialDecryptionResponseProto; - fn to_proto(&self, context: &'a C) -> Result { - Ok(proto!(PartialDecryptionResponseProto { - partial_decryption: self.partial_decryption.to_proto(context.vahe())?, - })) + fn to_proto( + &self, + (context, kahe): (&'a C, Option<&'a Kahe>), + ) -> Result { + let vahe = context.vahe(); + let mut proto = proto!(PartialDecryptionResponseProto { + partial_decryption: self.partial_decryption.to_proto(vahe)?, + }); + if let Some(dp_ct) = &self.dp_ciphertext_contribution { + let kahe = kahe.ok_or_else(|| { + status::failed_precondition("Missing Kahe context for DP ciphertext contribution") + })?; + proto.set_dp_ciphertext_contribution(proto!( + messages_rust_proto::CiphertextContribution { + kahe_ciphertext: dp_ct.kahe_ciphertext.to_proto(kahe)?, + ahe_recover_ciphertext: dp_ct.ahe_recover_ciphertext.to_proto(vahe)?, + } + )); + } + Ok(proto) } } -impl<'a, C, Vahe> FromProto<&'a C> for PartialDecryptionResponse +impl<'a, C, Kahe, Vahe> FromProto<&'a C> for PartialDecryptionResponse where - C: HasVahe, + C: HasKahe + HasVahe, + Kahe: KaheBase + 'a, Vahe: VaheBase + 'a, Vahe::PartialDecryption: FromProto<&'a Vahe, Proto = ShellAhePartialDecryption>, + Kahe::Ciphertext: FromProto<&'a Kahe, Proto = ShellKaheCiphertext>, + Vahe::RecoverCiphertext: FromProto<&'a Vahe, Proto = ShellAheRecoverCiphertext>, { type Proto = PartialDecryptionResponseProto; @@ -190,11 +233,17 @@ where context: &'a C, ) -> Result { let proto = proto.as_view(); + let dp_ciphertext_contribution = if proto.has_dp_ciphertext_contribution() { + Some(CiphertextContribution::from_proto(proto.dp_ciphertext_contribution(), context)?) + } else { + None + }; Ok(PartialDecryptionResponse { partial_decryption: Vahe::PartialDecryption::from_proto( proto.partial_decryption(), context.vahe(), )?, + dp_ciphertext_contribution, }) } } @@ -317,3 +366,123 @@ impl Clone for DecryptionRequestContribution { } } } + +/// A public key share and proof of knowledge of the secret key from a decryptor. +pub struct KeyContribution { + pub public_key_share: DecryptorPublicKeyShare, + /// This is required unless this is the only reputable decryptor. + pub proof: Option, +} + +/// The initial contribution to the decryption key from a decryptor. +pub struct SetupContribution { + pub key_contribution: KeyContribution, + /// Only needed if we are adding DP noise. + pub dp_setup: Option>, + /// Shares of the randomness (PRNG state) used for key generation, encrypted for other + /// decryptors to enable recovery. Only used by non-reputable decryptors. + pub encrypted_randomness_shares: Option>, +} + +/// Placeholder for an encrypted share of the PRNG state. +#[derive(Debug, Clone)] +pub struct SecretSharingContribution { + pub encrypted_share: Vec, +} + +/// Public key independent half of ciphertext for noise, only present for adding DP. +pub struct DPSetupContribution { + /// The public-key-independent half of the DP noise ciphertext (component_a). + /// This is also the half that must be sent to the verifier. + pub dp_partial_dec_ciphertext: Vahe::PartialDecCiphertext, + /// Proof of knowledge of the randomness used to create the ciphertext. + pub dp_partial_dec_ciphertext_proof: Option, +} + +/// A request to a surviving decryptor to decrypt shares of dropped decryptors. +#[derive(Debug)] +pub struct RecoveryRequest { + /// The encrypted shares that this decryptor is being asked to decrypt. + pub encrypted_shares: Vec, +} + +/// A response containing the decrypted shares of dropped decryptors. +#[derive(Debug)] +pub struct RecoveryResponse { + /// The decrypted shares. + pub decrypted_shares: Vec>, +} + +/// A request from the coordinator to a reputable decryptor to verify setup contributions +/// and construct the aggregated public key. +pub struct VerifyKeyContributionsRequest { + pub key_contributions: Vec>, +} + +/// Tracks a multi-decryptor's progress through the protocol. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub enum DecryptorStatus { + #[default] + PreSetup, + AwaitingKeyVerificationRequest, + AwaitingDecryptionRequest, + AwaitingRecoveryRequest, + Finished, +} + +/// State stored by a multi-decryptor between protocol rounds. +pub struct MultiDecryptorState { + /// The current status of the decryptor in the protocol. + pub status: DecryptorStatus, + /// The secret key share generated by the decryptor. None before setup. + pub sk_share: Option, + /// The PRNG state / randomness used during the DP setup phase, if applicable. + pub dp_randomness: Option>, +} + +impl Default for MultiDecryptorState { + fn default() -> Self { + Self { status: DecryptorStatus::default(), sk_share: None, dp_randomness: None } + } +} + +/// Tracks the coordinator's progress through the multi-decryptor protocol. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub enum CoordinatorStatus { + #[default] + PreSetup, + KeySharesReceived, + AwaitingContributions, + AwaitingPartialDecryptions, + AwaitingRecovery, + OutputReady, + Finished, +} + +/// State stored by the coordinator between protocol rounds. +pub struct CoordinatorState { + /// The current status of the coordinator in the protocol. + pub status: CoordinatorStatus, + /// The encrypted secret shares received from non-reputable decryptors during setup. Each inner + /// `Vec` contains shares from one non-reputable decryptor, encrypted for each *other* + /// decryptor. The outer index corresponds to the sender (the non-reputable decryptor), and + /// the inner index corresponds to the recipient decryptor. + pub encrypted_randomness_shares: Vec>, + /// The accumulated public key independent components (sum of A*r+e) for DP noise, if + /// applicable. + pub dp_noise_component_sum: Option, + /// The public key shares and proofs from all contributions, for sending back to decryptors for + /// signing requests. Only applicable in multiple reputable decryptor case and until used. + pub setup_contributions: Option>>, +} + +impl Default for CoordinatorState { + fn default() -> Self { + Self { + status: CoordinatorStatus::default(), + encrypted_randomness_shares: Vec::new(), + dp_noise_component_sum: None, + setup_contributions: None, + } + } +} diff --git a/willow/protocol/server_traits.rs b/willow/protocol/server_traits.rs index 4f7b70b..8808a89 100644 --- a/willow/protocol/server_traits.rs +++ b/willow/protocol/server_traits.rs @@ -70,7 +70,7 @@ pub trait SecureAggregationServer: HasKahe + HasVahe { /// server state. fn handle_partial_decryption( &self, - partial_decryption_response: PartialDecryptionResponse>, + partial_decryption_response: PartialDecryptionResponse, Vahe>, server_state: &mut Self::ServerState, ) -> Result<(), StatusError>; diff --git a/willow/protocol/willow_v1_decryptor.rs b/willow/protocol/willow_v1_decryptor.rs index 25999d4..4b75022 100644 --- a/willow/protocol/willow_v1_decryptor.rs +++ b/willow/protocol/willow_v1_decryptor.rs @@ -12,20 +12,23 @@ // See the License for the specific language governing permissions and // limitations under the License. +use aggregation_config::AggregationConfig; use ahe_traits::{AheKeygen, PartialDec}; use decryptor_traits::SecureAggregationDecryptor; +use kahe_traits::KaheBase; use messages::{DecryptorPublicKeyShare, PartialDecryptionRequest, PartialDecryptionResponse}; use messages_rust_proto::DecryptorStateProto; use prng_traits::SecurePrng; use proto_serialization_traits::{FromProto, ToProto}; use protobuf::AsView; use shell_ciphertexts_rust_proto::ShellAheSecretKeyShare; +use shell_kahe::ShellKahe; use status::StatusError; use std::cell::RefCell; use std::rc::Rc; use vahe_traits::{EncryptVerify, HasVahe, VaheBase}; -/// Lightweight decryptor directly exposing KAHE/VAHE types. It verifies only the client proofs, +/// Lightweight decryptor directly exposing VAHE types. It verifies only the client proofs, /// does not provide verifiable partial decryptions. pub struct WillowV1Decryptor { pub vahe: Rc, @@ -48,12 +51,14 @@ impl WillowV1Decryptor { } pub struct DecryptorState { - sk_share: Option, + pub sk_share: Option, + pub kahe: Option>, + pub aggregation_config: Option, } impl Default for DecryptorState { fn default() -> Self { - Self { sk_share: None } + Self { sk_share: None, kahe: None, aggregation_config: None } } } @@ -70,6 +75,9 @@ where if let Some(sk) = &self.sk_share { proto.set_sk_share(sk.to_proto(context.vahe())?); } + if let Some(config) = &self.aggregation_config { + proto.set_aggregation_config(config.to_proto(())?); + } Ok(proto) } } @@ -92,7 +100,19 @@ where } else { None }; - Ok(DecryptorState { sk_share }) + let aggregation_config = if proto.has_aggregation_config() { + Some(AggregationConfig::from_proto(proto.aggregation_config(), ())?) + } else { + None + }; + let kahe = if let Some(config) = &aggregation_config { + use shell_parameters::create_shell_configs; + let (kahe_config, _) = create_shell_configs(config)?; + Some(Rc::new(ShellKahe::new(kahe_config, &config.key_id)?)) + } else { + None + }; + Ok(DecryptorState { sk_share, kahe, aggregation_config }) } } @@ -104,6 +124,7 @@ where Vahe: VaheBase + EncryptVerify + PartialDec + AheKeygen, { type DecryptorState = DecryptorState; + type Kahe = ShellKahe; /// Creates a public key share to be sent to the Server, updating the /// decryptor state. @@ -121,8 +142,16 @@ where fn handle_partial_decryption_request( &self, partial_decryption_request: PartialDecryptionRequest, - decryptor_state: &Self::DecryptorState, - ) -> Result, status::StatusError> { + decryptor_state: &mut Self::DecryptorState, + ) -> Result, status::StatusError> { + if let Some(config) = &partial_decryption_request.aggregation_config { + if decryptor_state.kahe.is_none() { + use shell_parameters::create_shell_configs; + let (kahe_config, _) = create_shell_configs(config)?; + decryptor_state.kahe = Some(Rc::new(ShellKahe::new(kahe_config, &config.key_id)?)); + decryptor_state.aggregation_config = Some(config.clone()); + } + } let Some(ref sk_share) = decryptor_state.sk_share else { return Err(status::failed_precondition( "decryptor_state does not contain a secret key share", @@ -134,7 +163,7 @@ where sk_share, &mut self.prng.borrow_mut(), )?; - Ok(PartialDecryptionResponse { partial_decryption: pd }) + Ok(PartialDecryptionResponse { partial_decryption: pd, dp_ciphertext_contribution: None }) } } @@ -174,4 +203,38 @@ mod tests { Ok(()) } + + #[gtest] + fn decryptor_state_with_config_roundtrip() -> googletest::Result<()> { + use kahe_traits::KaheBase; + use shell_kahe::ShellKahe; + use shell_parameters::create_shell_kahe_config; + use testing_utils::generate_aggregation_config; + + let vahe = + Rc::new(ShellVahe::new(create_shell_ahe_config(1).unwrap(), CONTEXT_STRING).unwrap()); + let decryptor = WillowV1Decryptor::new_with_randomly_generated_seed(vahe)?; + + let config = generate_aggregation_config("default".to_string(), 16, 10, 1, 1); + let kahe_config = create_shell_kahe_config(&config).unwrap(); + let kahe = Rc::new(ShellKahe::new(kahe_config, &config.key_id).unwrap()); + + let mut decryptor_state = DecryptorState::default(); + decryptor.create_public_key_share(&mut decryptor_state)?; + decryptor_state.kahe = Some(kahe); + decryptor_state.aggregation_config = Some(config); + + verify_true!(decryptor_state.kahe.is_some())?; + verify_true!(decryptor_state.aggregation_config.is_some())?; + + let decryptor_state_proto = decryptor_state.to_proto(&decryptor)?; + let decryptor_state_roundtrip = + DecryptorState::from_proto(decryptor_state_proto, &decryptor)?; + + verify_true!(decryptor_state_roundtrip.sk_share.is_some())?; + verify_true!(decryptor_state_roundtrip.kahe.is_some())?; + verify_true!(decryptor_state_roundtrip.aggregation_config.is_some())?; + + Ok(()) + } } diff --git a/willow/protocol/willow_v1_server.rs b/willow/protocol/willow_v1_server.rs index 22ad4b0..9837899 100644 --- a/willow/protocol/willow_v1_server.rs +++ b/willow/protocol/willow_v1_server.rs @@ -259,9 +259,14 @@ where /// server state. fn handle_partial_decryption( &self, - partial_decryption_response: PartialDecryptionResponse, + partial_decryption_response: PartialDecryptionResponse, server_state: &mut Self::ServerState, ) -> Result<(), status::StatusError> { + if partial_decryption_response.dp_ciphertext_contribution.is_some() { + return Err(status::failed_precondition( + "DP ciphertext contributions are not yet supported.", + )); + } let partial_decryption = partial_decryption_response.partial_decryption; if let Some(ref mut partial_decryption_sum) = server_state.partial_decryption_sum { self.vahe @@ -441,7 +446,7 @@ mod tests { verifier.verify_and_include(decryption_request_contribution, &mut verifier_state)?; server.handle_ciphertext_contribution(ciphertext_contribution, &mut server_state)?; let pd_ct = verifier.create_partial_decryption_request(verifier_state)?; - let pd = decryptor.handle_partial_decryption_request(pd_ct, &decryptor_state)?; + let pd = decryptor.handle_partial_decryption_request(pd_ct, &mut decryptor_state)?; server.handle_partial_decryption(pd, &mut server_state)?; // Check populated state serialization diff --git a/willow/protocol/willow_v1_verifier.rs b/willow/protocol/willow_v1_verifier.rs index 9217aea..fd94644 100644 --- a/willow/protocol/willow_v1_verifier.rs +++ b/willow/protocol/willow_v1_verifier.rs @@ -250,6 +250,7 @@ where state.validate()?; Ok(PartialDecryptionRequest { partial_dec_ciphertext: state.partial_dec_ciphertext_sum, + aggregation_config: None, }) } else { Err(status::failed_precondition( diff --git a/willow/testing_utils/shell_testing_decryptor.rs b/willow/testing_utils/shell_testing_decryptor.rs index b9074ef..27a5f2f 100644 --- a/willow/testing_utils/shell_testing_decryptor.rs +++ b/willow/testing_utils/shell_testing_decryptor.rs @@ -52,6 +52,13 @@ impl HasVahe for ShellTestingDecryptor { } } +impl kahe_traits::HasKahe for ShellTestingDecryptor { + type Kahe = ShellKahe; + fn kahe(&self) -> &Self::Kahe { + &self.kahe + } +} + impl ShellTestingDecryptor { /// Creates a new ShellTestingDecryptor, using the given context string to seed KAHE and AHE /// public parameters. @@ -157,7 +164,7 @@ impl ShellTestingDecryptor { fn generate_partial_decryption_response( &mut self, request: &PartialDecryptionRequest, - ) -> Result, StatusError> { + ) -> Result, StatusError> { match &self.secret_key { None => Err(status::invalid_argument("No secret key available")), Some(sk_share) => { @@ -166,7 +173,10 @@ impl ShellTestingDecryptor { sk_share, &mut self.prng.borrow_mut(), )?; - Ok(PartialDecryptionResponse { partial_decryption }) + Ok(PartialDecryptionResponse { + partial_decryption, + dp_ciphertext_contribution: None, + }) } } } @@ -181,7 +191,7 @@ impl ShellTestingDecryptor { let request = PartialDecryptionRequest::from_proto(request_proto, self)?; let response = self.generate_partial_decryption_response(&request)?; response - .to_proto(self) + .to_proto((self, Some(&self.kahe))) .map_err(|e| status::internal(&format!("ToProto error: {}", e)))? .serialize() .map_err(|e| status::internal(&format!("Serialize error: {}", e))) diff --git a/willow/tests/willow_v1_shell.rs b/willow/tests/willow_v1_shell.rs index 6c03e1e..4e41f77 100644 --- a/willow/tests/willow_v1_shell.rs +++ b/willow/tests/willow_v1_shell.rs @@ -103,7 +103,7 @@ fn encrypt_decrypt_one() -> googletest::Result<()> { let pd_ct = verifier.create_partial_decryption_request(verifier_state).unwrap(); // Decryptor creates partial decryption. - let pd = decryptor.handle_partial_decryption_request(pd_ct, &decryptor_state).unwrap(); + let pd = decryptor.handle_partial_decryption_request(pd_ct, &mut decryptor_state).unwrap(); // Server handles the partial decryption. server.handle_partial_decryption(pd, &mut server_state).unwrap(); @@ -226,11 +226,11 @@ fn encrypt_decrypt_one_serialized() -> googletest::Result<()> { PartialDecryptionRequest::from_proto(pd_ct_proto, &decryptor)?; // Decryptor creates partial decryption. - let pd = decryptor.handle_partial_decryption_request(pd_ct, &decryptor_state).unwrap(); + let pd = decryptor.handle_partial_decryption_request(pd_ct, &mut decryptor_state).unwrap(); // Serialize and deserialize the partial decryption. - let pd_proto = pd.to_proto(&decryptor)?; - let pd: PartialDecryptionResponse = + let pd_proto = pd.to_proto((&decryptor, None))?; + let pd: PartialDecryptionResponse = PartialDecryptionResponse::from_proto(pd_proto, &server)?; // Server handles the partial decryption. @@ -348,7 +348,7 @@ fn encrypt_decrypt_multiple_clients() -> googletest::Result<()> { let pd_ct = verifier.create_partial_decryption_request(verifier_state).unwrap(); // Decryptor creates partial decryption. - let pd = decryptor.handle_partial_decryption_request(pd_ct, &decryptor_state).unwrap(); + let pd = decryptor.handle_partial_decryption_request(pd_ct, &mut decryptor_state).unwrap(); // Server handles the partial decryption. server.handle_partial_decryption(pd, &mut server_state).unwrap(); @@ -496,7 +496,7 @@ fn encrypt_decrypt_multiple_clients_including_invalid_proofs() -> googletest::Re let pd_ct = verifier.create_partial_decryption_request(verifier_state).unwrap(); // Decryptor creates partial decryption. - let pd = decryptor.handle_partial_decryption_request(pd_ct, &decryptor_state).unwrap(); + let pd = decryptor.handle_partial_decryption_request(pd_ct, &mut decryptor_state).unwrap(); // Server handles the partial decryption. server.handle_partial_decryption(pd, &mut server_state).unwrap(); @@ -621,7 +621,7 @@ fn encrypt_decrypt_many_clients_decryptors() -> googletest::Result<()> { for i in 0..NUM_DECRYPTORS { // Each decryptor creates partial decryption. let pd = decryptors[i] - .handle_partial_decryption_request(pd_ct.clone(), &decryptor_states[i]) + .handle_partial_decryption_request(pd_ct.clone(), &mut decryptor_states[i]) .unwrap(); // Server handles the partial decryption. @@ -741,7 +741,7 @@ fn encrypt_decrypt_no_dropout() -> googletest::Result<()> { // Decryptors perform partial decryption. for i in 0..decryptors.len() { let pd = decryptors[i] - .handle_partial_decryption_request(pd_ct.clone(), &decryptor_states[i]) + .handle_partial_decryption_request(pd_ct.clone(), &mut decryptor_states[i]) .unwrap(); // Server handles the partial decryption. server.handle_partial_decryption(pd, &mut server_state).unwrap();