From 7965712fd73adcfd4faee8426b93743507c09315 Mon Sep 17 00:00:00 2001 From: Phillipp Schoppmann Date: Mon, 15 Dec 2025 18:17:17 -0800 Subject: [PATCH] Implement C++ accumulator that wraps Server and Verifier PiperOrigin-RevId: 845002382 --- willow/proto/willow/BUILD | 19 ++ willow/proto/willow/server_accumulator.proto | 33 ++ willow/src/api/BUILD | 69 ++++ willow/src/api/server_accumulator.cc | 97 ++++++ willow/src/api/server_accumulator.h | 70 ++++ willow/src/api/server_accumulator.rs | 336 +++++++++++++++++++ willow/src/api/server_accumulator_test.cc | 101 ++++++ willow/src/traits/server.rs | 6 +- willow/src/traits/verifier.rs | 4 +- willow/src/willow_v1/server.rs | 54 ++- willow/src/willow_v1/verifier.rs | 14 +- willow/tests/willow_v1_shell.rs | 3 +- 12 files changed, 765 insertions(+), 41 deletions(-) create mode 100644 willow/proto/willow/server_accumulator.proto create mode 100644 willow/src/api/server_accumulator.cc create mode 100644 willow/src/api/server_accumulator.h create mode 100644 willow/src/api/server_accumulator.rs create mode 100644 willow/src/api/server_accumulator_test.cc diff --git a/willow/proto/willow/BUILD b/willow/proto/willow/BUILD index c5c72b6..fae653c 100644 --- a/willow/proto/willow/BUILD +++ b/willow/proto/willow/BUILD @@ -87,3 +87,22 @@ rust_proto_library( name = "messages_rust_proto", deps = [":messages_proto"], ) + +proto_library( + name = "server_accumulator_proto", + srcs = ["server_accumulator.proto"], + deps = [ + ":aggregation_config_proto", + ":messages_proto", + ], +) + +cc_proto_library( + name = "server_accumulator_cc_proto", + deps = [":server_accumulator_proto"], +) + +rust_proto_library( + name = "server_accumulator_rust_proto", + deps = [":server_accumulator_proto"], +) diff --git a/willow/proto/willow/server_accumulator.proto b/willow/proto/willow/server_accumulator.proto new file mode 100644 index 0000000..f6b697d --- /dev/null +++ b/willow/proto/willow/server_accumulator.proto @@ -0,0 +1,33 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto3"; + +package secure_aggregation.willow; + +import "willow/proto/willow/aggregation_config.proto"; +import "willow/proto/willow/messages.proto"; + +option java_multiple_files = true; +option java_outer_classname = "ServerAccumulatorProto"; + +message ServerAccumulatorState { + ServerStateProto server_state = 1; + VerifierStateProto verifier_state = 2; + AggregationConfigProto aggregation_config = 3; +} + +message ClientMessageList { + repeated ClientMessage client_messages = 1; +} diff --git a/willow/src/api/BUILD b/willow/src/api/BUILD index bc42a7c..076cdf5 100644 --- a/willow/src/api/BUILD +++ b/willow/src/api/BUILD @@ -1,3 +1,5 @@ +load("@cxx.rs//tools/bazel:rust_cxx_bridge.bzl", "rust_cxx_bridge") + # Copyright 2025 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -12,6 +14,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +load("@rules_cc//cc:cc_library.bzl", "cc_library") +load("@rules_cc//cc:cc_test.bzl", "cc_test") load("@rules_rust//rust:defs.bzl", "rust_library", "rust_test") package( @@ -42,3 +46,68 @@ rust_test( "//willow/src/testing_utils", ], ) + +cc_library( + name = "server_accumulator", + srcs = ["server_accumulator.cc"], + hdrs = ["server_accumulator.h"], + deps = [ + ":server_accumulator_cxx/include", # fixdeps: keep + ":server_accumulator_rust", # fixdeps: keep + "@abseil-cpp//absl/memory", + "@abseil-cpp//absl/status", + "@abseil-cpp//absl/status:statusor", + "@abseil-cpp//absl/strings", + "@cxx.rs//:cxx", + "@cxx.rs//:core", + "//willow/proto/willow:aggregation_config_cc_proto", + "//willow/proto/willow:server_accumulator_cc_proto", + ], +) + +cc_test( + name = "server_accumulator_test", + srcs = ["server_accumulator_test.cc"], + deps = [ + ":server_accumulator", + ":server_accumulator_cxx", # fixdeps: keep + "@googletest//:gtest_main", + "@abseil-cpp//absl/status", + "@abseil-cpp//absl/status:statusor", + "//willow/proto/willow:aggregation_config_cc_proto", + "//willow/proto/willow:server_accumulator_cc_proto", + ], +) + +rust_cxx_bridge( + name = "server_accumulator_cxx", + src = "server_accumulator.rs", + deps = [ + ":server_accumulator", + ], +) + +rust_library( + name = "server_accumulator_rust", + srcs = ["server_accumulator.rs"], + deps = [ + ":aggregation_config", + "@protobuf//rust:protobuf", + "@cxx.rs//:cxx", + "//shell_wrapper:status", + "//willow/proto/willow:aggregation_config_rust_proto", + "//willow/proto/willow:server_accumulator_rust_proto", + "//willow/src/shell:kahe_shell", + "//willow/src/shell:parameters_shell", + "//willow/src/shell:vahe_shell", + "//willow/src/traits:ahe_traits", + "//willow/src/traits:kahe_traits", + "//willow/src/traits:messages", + "//willow/src/traits:proto_serialization_traits", + "//willow/src/traits:server_traits", + "//willow/src/traits:vahe_traits", + "//willow/src/traits:verifier_traits", + "//willow/src/willow_v1:willow_v1_server", + "//willow/src/willow_v1:willow_v1_verifier", + ], +) diff --git a/willow/src/api/server_accumulator.cc b/willow/src/api/server_accumulator.cc new file mode 100644 index 0000000..f395f39 --- /dev/null +++ b/willow/src/api/server_accumulator.cc @@ -0,0 +1,97 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "willow/src/api/server_accumulator.h" + +#include +#include +#include +#include + +#include "absl/memory/memory.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "include/cxx.h" +#include "willow/proto/willow/aggregation_config.pb.h" +#include "willow/proto/willow/server_accumulator.pb.h" +#include "willow/src/api/server_accumulator.rs.h" + +namespace secure_aggregation { + +absl::StatusOr> +WillowShellServerAccumulator::Create( + const willow::AggregationConfigProto& aggregation_config) { + secure_aggregation::ServerAccumulator* out; + std::unique_ptr status_message; + int status_code = + secure_aggregation::NewServerAccumulatorFromSerializedConfig( + std::make_unique(aggregation_config.SerializeAsString()), + &out, &status_message); + if (status_code != 0) { + return absl::Status(absl::StatusCode(status_code), *status_message); + } + return absl::WrapUnique(new WillowShellServerAccumulator(IntoBox(out))); +} + +absl::StatusOr> +WillowShellServerAccumulator::CreateFromSerializedState( + std::string serialized_state) { + secure_aggregation::ServerAccumulator* out; + std::unique_ptr status_message; + int status_code = secure_aggregation::NewServerAccumulatorFromSerializedState( + std::make_unique(std::move(serialized_state)), &out, + &status_message); + if (status_code != 0) { + return absl::Status(absl::StatusCode(status_code), *status_message); + } + return absl::WrapUnique(new WillowShellServerAccumulator(IntoBox(out))); +} + +absl::Status WillowShellServerAccumulator::ProcessClientMessages( + willow::ClientMessageList client_messages) { + client_messages.Clear(); + std::unique_ptr status_message; + int status_code = accumulator_->ProcessClientMessages( + std::make_unique(client_messages.SerializeAsString()), + &status_message); + if (status_code != 0) { + return absl::Status(absl::StatusCode(status_code), *status_message); + } + return absl::OkStatus(); +} + +absl::Status WillowShellServerAccumulator::Merge( + std::unique_ptr other) { + std::unique_ptr status_message; + int status_code = + accumulator_->Merge(std::move(other->accumulator_), &status_message); + if (status_code != 0) { + return absl::Status(absl::StatusCode(status_code), *status_message); + } + return absl::OkStatus(); +} + +absl::StatusOr WillowShellServerAccumulator::ToSerializedState() { + rust::Vec serialized_state; + std::unique_ptr status_message; + int status_code = + accumulator_->ToSerializedState(&serialized_state, &status_message); + if (status_code != 0) { + return absl::Status(absl::StatusCode(status_code), *status_message); + } + return std::string(reinterpret_cast(serialized_state.data()), + serialized_state.size()); +} + +} // namespace secure_aggregation \ No newline at end of file diff --git a/willow/src/api/server_accumulator.h b/willow/src/api/server_accumulator.h new file mode 100644 index 0000000..ef17477 --- /dev/null +++ b/willow/src/api/server_accumulator.h @@ -0,0 +1,70 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef SECURE_AGGREGATION_WILLOW_SRC_API_SERVER_ACCUMULATOR_H_ +#define SECURE_AGGREGATION_WILLOW_SRC_API_SERVER_ACCUMULATOR_H_ + +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "include/cxx.h" +#include "willow/proto/willow/aggregation_config.pb.h" +#include "willow/proto/willow/server_accumulator.pb.h" +#include "willow/src/api/server_accumulator.rs.h" + +namespace secure_aggregation { + +// Implements an accumulator class intended to be used by a batch processing +// system. Combines both the server and the verifier functionality of willow_v1, +// using SHELL for the underlying cryptography. +class WillowShellServerAccumulator { + public: + // Creates a new accumulator with the given aggregation_config and empty + // state. + static absl::StatusOr> Create( + const willow::AggregationConfigProto& aggregation_config); + + // Creates a new accumulator from the given serialized state, which must + // correspond to a serialized ServerAccumulatorState proto. + static absl::StatusOr> + CreateFromSerializedState(std::string serialized_state); + + // Processes a list of client messages. If an invalid message is encountered, + // an error is logged and processing continues. + absl::Status ProcessClientMessages(willow::ClientMessageList client_messages); + + // Merges the state of `other` into the current accumulator. + absl::Status Merge(std::unique_ptr other); + + // Converts the current state of the accumulator to a serialized + // ServerAccumulatorState proto. + absl::StatusOr ToSerializedState(); + + private: + explicit WillowShellServerAccumulator( + rust::Box accumulator) + : accumulator_(std::move(accumulator)) {} + + rust::Box accumulator_; +}; + +} // namespace secure_aggregation + +#endif // SECURE_AGGREGATION_WILLOW_SRC_API_SERVER_ACCUMULATOR_H_ diff --git a/willow/src/api/server_accumulator.rs b/willow/src/api/server_accumulator.rs new file mode 100644 index 0000000..aa3e3ae --- /dev/null +++ b/willow/src/api/server_accumulator.rs @@ -0,0 +1,336 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use aggregation_config::AggregationConfig; +use aggregation_config_rust_proto::AggregationConfigProto; +use ahe_traits::AheBase; +use kahe_shell::ShellKahe; +use kahe_traits::KaheBase; +use messages::ClientMessage; +use parameters_shell::{create_shell_ahe_config, create_shell_kahe_config}; +use proto_serialization_traits::{FromProto, ToProto}; +use protobuf::prelude::*; +use protobuf::AsView; +use server_accumulator_rust_proto::{ClientMessageList, ServerAccumulatorState}; +use server_traits::SecureAggregationServer; +use status::StatusError; +use vahe_shell::ShellVahe; +use verifier_traits::SecureAggregationVerifier; +use willow_v1_server::{ServerState, WillowV1Server}; +use willow_v1_verifier::{VerifierState, WillowV1Verifier}; + +#[cxx::bridge] +pub mod ffi { + extern "Rust" { + #[namespace = "secure_aggregation"] + type ServerAccumulator; + + // We cannot use status::FfiStatus because CXX requires shared structs to be defined in the + // same module. So using separate message and pointer as a workaround. + // SAFETY: All functions in this module are only called from the wrapping C++ library, + // ensuring that output pointers are correctly wrapped by a rust::Box, and that pointer + // arguments are not null. + + #[namespace = "secure_aggregation"] + #[cxx_name = "NewServerAccumulatorFromSerializedConfig"] + unsafe fn new_server_accumulator_from_serialized_config( + serialized_aggregation_config: UniquePtr, + out: *mut *mut ServerAccumulator, + out_status_message: *mut UniquePtr, + ) -> i32; + + #[namespace = "secure_aggregation"] + #[cxx_name = "NewServerAccumulatorFromSerializedState"] + unsafe fn new_server_accumulator_from_serialized_state( + serialized_server_accumulator: UniquePtr, + out: *mut *mut ServerAccumulator, + out_status_message: *mut UniquePtr, + ) -> i32; + + #[namespace = "secure_aggregation"] + #[cxx_name = "ProcessClientMessages"] + unsafe fn process_client_messages_ffi( + self: &mut ServerAccumulator, + client_messages: UniquePtr, + out_status_message: *mut UniquePtr, + ) -> i32; + + #[namespace = "secure_aggregation"] + #[cxx_name = "ToSerializedState"] + unsafe fn to_serialized_state_ffi( + self: &ServerAccumulator, + out: *mut Vec, + out_status_message: *mut UniquePtr, + ) -> i32; + + #[namespace = "secure_aggregation"] + #[cxx_name = "Merge"] + unsafe fn merge_ffi( + self: &mut ServerAccumulator, + other: Box, + out_status_message: *mut UniquePtr, + ) -> i32; + + #[namespace = "secure_aggregation"] + #[cxx_name = "IntoBox"] + unsafe fn into_box(ptr: *mut ServerAccumulator) -> Box; + } +} + +use status::ffi::FfiStatus; + +pub struct ServerAccumulator { + server: WillowV1Server, + server_state: ServerState, + verifier: WillowV1Verifier, + verifier_state: VerifierState, + aggregation_config: AggregationConfig, +} + +impl ServerAccumulator { + fn new(aggregation_config: AggregationConfig) -> Result { + let context_string = aggregation_config.session_id.as_bytes(); + let vahe_config = create_shell_ahe_config(aggregation_config.max_number_of_decryptors)?; + let kahe_config = create_shell_kahe_config(&aggregation_config)?; + let server_kahe = ShellKahe::new(kahe_config, context_string)?; + let server_vahe = ShellVahe::new(vahe_config.clone(), context_string)?; + let verifier_vahe = ShellVahe::new(vahe_config, context_string)?; + let server = WillowV1Server { kahe: server_kahe, vahe: server_vahe }; + let verifier = WillowV1Verifier { vahe: verifier_vahe }; + Ok(Self { + server: server, + server_state: Default::default(), + verifier: verifier, + verifier_state: Default::default(), + aggregation_config: aggregation_config, + }) + } + + fn new_from_serialized_config( + mut serialized_aggregation_config: cxx::UniquePtr, + ) -> Result { + let serialized_aggregation_config_proto = AggregationConfigProto::parse( + serialized_aggregation_config.as_bytes(), + ) + .map_err(|e| status::internal(format!("Failed to parse AggregationConfigProto: {}", e)))?; + serialized_aggregation_config = cxx::UniquePtr::null(); // Release memory. + let aggregation_config = + AggregationConfig::from_proto(serialized_aggregation_config_proto, ())?; + Self::new(aggregation_config) + } + + fn new_from_serialized_state( + mut serialized_server_accumulator: cxx::UniquePtr, + ) -> Result { + let serialized_server_accumulator_proto = ServerAccumulatorState::parse( + serialized_server_accumulator.as_bytes(), + ) + .map_err(|e| status::internal(format!("Failed to parse ServerAccumulatorState: {}", e)))?; + serialized_server_accumulator = cxx::UniquePtr::null(); // Release memory. + Self::from_proto(serialized_server_accumulator_proto, ()) + } + + fn process_client_message( + &mut self, + client_message: ClientMessage, + ) -> Result<(), StatusError> { + let (ciphertext_contribution, decryption_request_contribution) = + self.server.split_client_message(client_message)?; + // Create a copy of the server and verifier state. Only update the accumulator state if + // processing succeededs all the way. + let mut server_state = self.server_state.clone(); + let mut verifier_state = self.verifier_state.clone(); + self.verifier.verify_and_include(decryption_request_contribution, &mut verifier_state)?; + self.server.handle_ciphertext_contribution(ciphertext_contribution, &mut server_state)?; + self.server_state = server_state; + self.verifier_state = verifier_state; + Ok(()) + } + + // Processes a list of client messages. If an invalid message is encountered, an error is logged + // and processing continues. + pub fn process_client_messages( + &mut self, + mut client_messages: Vec>, + ) -> () { + client_messages.sort_by(|a, b| a.nonce.cmp(&b.nonce)); + for message in client_messages { + if let Err(status) = self.process_client_message(message) { + eprintln!("Failed to process client message: {}", status); + } + } + } + + fn process_client_messages_serialized( + &mut self, + mut client_messages: cxx::UniquePtr, + ) -> Result<(), StatusError> { + let client_messages_proto = ClientMessageList::parse(client_messages.as_bytes()) + .map_err(|e| status::internal(format!("Failed to parse ClientMessageList: {}", e)))?; + client_messages = cxx::UniquePtr::null(); // Release memory. + let client_messages: Result, _> = client_messages_proto + .client_messages() + .iter() + .map(|m| ClientMessage::from_proto(m, &self.server)) + .collect(); + self.process_client_messages(client_messages?); + Ok(()) + } + + // SAFETY: + // - `out_status_message` must not be null. + pub unsafe fn process_client_messages_ffi( + &mut self, + client_messages: cxx::UniquePtr, + out_status_message: *mut cxx::UniquePtr, + ) -> i32 { + match self.process_client_messages_serialized(client_messages) { + Ok(()) => 0, + Err(status_error) => { + let ffi_status: FfiStatus = status_error.into(); + *out_status_message = ffi_status.message; + ffi_status.code + } + } + } + + pub fn merge(&mut self, other: Box) -> Result<(), StatusError> { + if self.aggregation_config != other.aggregation_config { + return Err(status::invalid_argument("Aggregation config mismatch")); + } + let server_state = std::mem::take(&mut self.server_state); + let verifier_state = std::mem::take(&mut self.verifier_state); + self.server_state = self.server.merge_states(server_state, other.server_state)?; + self.verifier_state = self.verifier.merge_states(verifier_state, other.verifier_state)?; + Ok(()) + } + + fn to_serialized_state(&self) -> Result, StatusError> { + self.to_proto(())?.serialize().map_err(|e| { + status::internal(format!("Failed to serialize ServerAccumulatorState: {}", e)) + }) + } + + // SAFETY: + // - `out_status_message` must not be null. + pub unsafe fn merge_ffi( + self: &mut ServerAccumulator, + other: Box, + out_status_message: *mut cxx::UniquePtr, + ) -> i32 { + match self.merge(other) { + Ok(()) => 0, + Err(status_error) => { + let ffi_status: FfiStatus = status_error.into(); + *out_status_message = ffi_status.message; + ffi_status.code + } + } + } + + pub unsafe fn to_serialized_state_ffi( + &self, + out: *mut Vec, + out_status_message: *mut cxx::UniquePtr, + ) -> i32 { + match self.to_serialized_state() { + Ok(serialized_state) => { + *out = serialized_state; + 0 + } + Err(status_error) => { + let ffi_status: FfiStatus = status_error.into(); + *out_status_message = ffi_status.message; + ffi_status.code + } + } + } +} + +impl ToProto for ServerAccumulator { + type Proto = ServerAccumulatorState; + + fn to_proto(&self, _context: ()) -> Result { + Ok(proto!(ServerAccumulatorState { + server_state: self.server_state.to_proto(&self.server)?, + verifier_state: self.verifier_state.to_proto(&self.verifier)?, + aggregation_config: self.aggregation_config.to_proto(())?, + })) + } +} + +impl FromProto for ServerAccumulator { + type Proto = ServerAccumulatorState; + + fn from_proto( + proto: impl AsView, + _context: (), + ) -> Result { + let proto = proto.as_view(); + let aggregation_config = AggregationConfig::from_proto(proto.aggregation_config(), ())?; + let mut result = Self::new(aggregation_config)?; + result.server_state = ServerState::from_proto(proto.server_state(), &result.server)?; + result.verifier_state = + VerifierState::from_proto(proto.verifier_state(), &result.verifier)?; + Ok(result) + } +} + +// SAFETY: +// - `out` must not be null. It must be turned into a rust::Box on the C++ side. +// - `out_status_message` must not be null. +unsafe fn new_server_accumulator_from_serialized_config( + serialized_aggregation_config: cxx::UniquePtr, + out: *mut *mut ServerAccumulator, + out_status_message: *mut cxx::UniquePtr, +) -> i32 { + match ServerAccumulator::new_from_serialized_config(serialized_aggregation_config) { + Ok(server_accumulator) => { + *out = Box::into_raw(Box::new(server_accumulator)); + 0 + } + Err(status_error) => { + let ffi_status: FfiStatus = status_error.into(); + *out_status_message = ffi_status.message; + ffi_status.code + } + } +} + +// SAFETY: +// - `out` must not be null. It must be turned into a rust::Box on the C++ side. +// - `out_status_message` must not be null. +unsafe fn new_server_accumulator_from_serialized_state( + serialized_server_accumulator: cxx::UniquePtr, + out: *mut *mut ServerAccumulator, + out_status_message: *mut cxx::UniquePtr, +) -> i32 { + match ServerAccumulator::new_from_serialized_state(serialized_server_accumulator) { + Ok(server_accumulator) => { + *out = Box::into_raw(Box::new(server_accumulator)); + 0 + } + Err(status_error) => { + let ffi_status: FfiStatus = status_error.into(); + *out_status_message = ffi_status.message; + ffi_status.code + } + } +} + +// SAFETY: +// - `ptr` must have been created by Box::into_raw or one of the functions in this module. +unsafe fn into_box(ptr: *mut ServerAccumulator) -> Box { + Box::from_raw(ptr) +} diff --git a/willow/src/api/server_accumulator_test.cc b/willow/src/api/server_accumulator_test.cc new file mode 100644 index 0000000..0550be1 --- /dev/null +++ b/willow/src/api/server_accumulator_test.cc @@ -0,0 +1,101 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "willow/src/api/server_accumulator.h" + +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "gtest/gtest.h" +#include "willow/proto/willow/aggregation_config.pb.h" +#include "willow/proto/willow/server_accumulator.pb.h" + +namespace secure_aggregation { +namespace { + +using ::secure_aggregation::willow::AggregationConfigProto; +using ::secure_aggregation::willow::ClientMessageList; +using ::secure_aggregation::willow::ServerAccumulatorState; +using ::secure_aggregation::willow::VectorConfig; + +AggregationConfigProto CreateValidConfig() { + AggregationConfigProto config; + VectorConfig vector_config; + vector_config.set_length(10); + vector_config.set_bound(100); + (*config.mutable_vector_configs())["test_vector"] = vector_config; + config.set_max_number_of_decryptors(1); + config.set_max_number_of_clients(10); + config.set_session_id("test_session"); + return config; +} + +TEST(WillowShellServerAccumulatorTest, CreateSucceedsWithValidConfig) { + AggregationConfigProto config = CreateValidConfig(); + auto accumulator_or = WillowShellServerAccumulator::Create(config); + ASSERT_TRUE(accumulator_or.ok()) << accumulator_or.status(); + EXPECT_NE(*accumulator_or, nullptr); +} + +TEST(WillowShellServerAccumulatorTest, ToSerializedStateHasCorrectConfig) { + AggregationConfigProto config = CreateValidConfig(); + auto accumulator = *WillowShellServerAccumulator::Create(config); + auto serialized_state_or = accumulator->ToSerializedState(); + ASSERT_TRUE(serialized_state_or.ok()) << serialized_state_or.status(); + + ServerAccumulatorState state; + ASSERT_TRUE(state.ParseFromString(*serialized_state_or)); + // Check if the config matches. We serialize and deserialize to compare protos + // easily or check fields. + EXPECT_EQ(state.aggregation_config().session_id(), config.session_id()); + EXPECT_EQ(state.aggregation_config().max_number_of_clients(), + config.max_number_of_clients()); +} + +TEST(WillowShellServerAccumulatorTest, CreateFromSerializedStateRoundTrip) { + AggregationConfigProto config = CreateValidConfig(); + auto accumulator = *WillowShellServerAccumulator::Create(config); + auto serialized_state_or = accumulator->ToSerializedState(); + ASSERT_TRUE(serialized_state_or.ok()) << serialized_state_or.status(); + + auto accumulator2_or = + WillowShellServerAccumulator::CreateFromSerializedState( + *serialized_state_or); + ASSERT_TRUE(accumulator2_or.ok()) << accumulator2_or.status(); + EXPECT_NE(*accumulator2_or, nullptr); + + auto serialized_state2_or = (*accumulator2_or)->ToSerializedState(); + ASSERT_TRUE(serialized_state2_or.ok()) << serialized_state2_or.status(); + EXPECT_EQ(*serialized_state_or, *serialized_state2_or); +} + +TEST(WillowShellServerAccumulatorTest, MergeSucceedsWithEmptyAccumulators) { + AggregationConfigProto config = CreateValidConfig(); + auto accumulator1 = *WillowShellServerAccumulator::Create(config); + auto accumulator2 = *WillowShellServerAccumulator::Create(config); + + EXPECT_TRUE(accumulator1->Merge(std::move(accumulator2)).ok()); +} + +TEST(WillowShellServerAccumulatorTest, ProcessClientMessagesWithEmptyList) { + AggregationConfigProto config = CreateValidConfig(); + auto accumulator = *WillowShellServerAccumulator::Create(config); + ClientMessageList empty_list; + EXPECT_TRUE(accumulator->ProcessClientMessages(empty_list).ok()); +} + +} // namespace +} // namespace secure_aggregation diff --git a/willow/src/traits/server.rs b/willow/src/traits/server.rs index ee5049f..967c15e 100644 --- a/willow/src/traits/server.rs +++ b/willow/src/traits/server.rs @@ -82,9 +82,9 @@ pub trait SecureAggregationServer: HasKahe + HasVahe { ) -> Result; /// Merges two server states into one. - fn merge_server_states( + fn merge_states( &self, - server_state_1: &Self::ServerState, - server_state_2: &Self::ServerState, + server_state_1: Self::ServerState, + server_state_2: Self::ServerState, ) -> Result; } diff --git a/willow/src/traits/verifier.rs b/willow/src/traits/verifier.rs index 0cc45ba..867b209 100644 --- a/willow/src/traits/verifier.rs +++ b/willow/src/traits/verifier.rs @@ -33,8 +33,8 @@ pub trait SecureAggregationVerifier: HasVahe { /// `verify_and_include` on all the contributions included in both states. fn merge_states( &self, - state1: &Self::VerifierState, - state2: &Self::VerifierState, + state1: Self::VerifierState, + state2: Self::VerifierState, ) -> Result; /// Returns a decryption request for the sum of the contributions, consumes the state. diff --git a/willow/src/willow_v1/server.rs b/willow/src/willow_v1/server.rs index 3ba55a6..5ca4f61 100644 --- a/willow/src/willow_v1/server.rs +++ b/willow/src/willow_v1/server.rs @@ -300,55 +300,53 @@ where /// client sums and partial decryption sums. The public key shares will be merged by joining all /// public key shares with unique IDs. In case IDs are present in both server states, the public /// key share from `server_state_1` will be used. - fn merge_server_states( + fn merge_states( &self, - server_state_1: &Self::ServerState, - server_state_2: &Self::ServerState, + server_state_1: Self::ServerState, + server_state_2: Self::ServerState, ) -> Result { let mut merged_server_state = ServerState::default(); // Merge public key shares. merged_server_state.decryptor_public_key_shares = - server_state_1.decryptor_public_key_shares.clone(); - for (id, key_share) in server_state_2.decryptor_public_key_shares.iter() { - if !merged_server_state.decryptor_public_key_shares.contains_key(id) { - merged_server_state - .decryptor_public_key_shares - .insert(id.to_string(), key_share.clone()); + server_state_1.decryptor_public_key_shares; + for (id, key_share) in server_state_2.decryptor_public_key_shares.into_iter() { + if !merged_server_state.decryptor_public_key_shares.contains_key(&id) { + merged_server_state.decryptor_public_key_shares.insert(id, key_share); } } merged_server_state.client_sum = - match (&server_state_1.client_sum, &server_state_2.client_sum) { + match (server_state_1.client_sum, server_state_2.client_sum) { ( Some((kahe_ciphertext_1, ahe_recover_ciphertext_1)), Some((kahe_ciphertext_2, ahe_recover_ciphertext_2)), ) => { - let mut merged_kahe_ciphertext = kahe_ciphertext_1.clone(); - let mut merged_ahe_recover_ciphertext = ahe_recover_ciphertext_1.clone(); - self.kahe - .add_ciphertexts_in_place(kahe_ciphertext_2, &mut merged_kahe_ciphertext)?; + let mut merged_kahe_ciphertext = kahe_ciphertext_1; + let mut merged_ahe_recover_ciphertext = ahe_recover_ciphertext_1; + self.kahe.add_ciphertexts_in_place( + &kahe_ciphertext_2, + &mut merged_kahe_ciphertext, + )?; self.vahe.add_recover_ciphertexts_in_place( - ahe_recover_ciphertext_2, + &ahe_recover_ciphertext_2, &mut merged_ahe_recover_ciphertext, )?; Some((merged_kahe_ciphertext, merged_ahe_recover_ciphertext)) } - (Some(s), None) | (None, Some(s)) => Some(s.clone()), + (Some(s), None) | (None, Some(s)) => Some(s), (None, None) => None, }; - merged_server_state.partial_decryption_sum = match ( - &server_state_1.partial_decryption_sum, - &server_state_2.partial_decryption_sum, - ) { - (Some(sum1), Some(sum2)) => { - let mut merged_sum = sum1.clone(); - self.vahe.add_partial_decryptions_in_place(sum2, &mut merged_sum)?; - Some(merged_sum) - } - (Some(s), None) | (None, Some(s)) => Some(s.clone()), - (None, None) => None, - }; + merged_server_state.partial_decryption_sum = + match (server_state_1.partial_decryption_sum, server_state_2.partial_decryption_sum) { + (Some(sum1), Some(sum2)) => { + let mut merged_sum = sum1; + self.vahe.add_partial_decryptions_in_place(&sum2, &mut merged_sum)?; + Some(merged_sum) + } + (Some(s), None) | (None, Some(s)) => Some(s), + (None, None) => None, + }; Ok(merged_server_state) } diff --git a/willow/src/willow_v1/verifier.rs b/willow/src/willow_v1/verifier.rs index 19c4e6c..f9006fa 100644 --- a/willow/src/willow_v1/verifier.rs +++ b/willow/src/willow_v1/verifier.rs @@ -203,8 +203,8 @@ where /// Merges two states into one. Fails if the intervals in the two states overlap. fn merge_states( &self, - state1: &Self::VerifierState, - state2: &Self::VerifierState, + state1: Self::VerifierState, + state2: Self::VerifierState, ) -> Result { match (&state1.0, &state2.0) { (Some(state1), Some(state2)) => { @@ -426,7 +426,7 @@ mod tests { // Try to merge the states, should fail. verify_that!( - setup.verifier.merge_states(&verifier_state_1, &verifier_state_2), + setup.verifier.merge_states(verifier_state_1, verifier_state_2), err(status_is(status::StatusErrorCode::InvalidArgument).with_message(eq( "`nonce_bounds.0` must be less than or equal to `nonce_bounds.1`" ))) @@ -444,8 +444,10 @@ mod tests { )?; // Merge with empty state, should preserve nonce bounds. - let verifier_state_3 = setup.verifier.merge_states(&verifier_state_1, &verifier_state_2)?; - let verifier_state_4 = setup.verifier.merge_states(&verifier_state_2, &verifier_state_1)?; + let verifier_state_3 = + setup.verifier.merge_states(verifier_state_1.clone(), verifier_state_2.clone())?; + let verifier_state_4 = + setup.verifier.merge_states(verifier_state_2.clone(), verifier_state_1.clone())?; // Nonce bounds should be the same as in verifier_state_1. verify_true!(verifier_state_3.0.is_some())?; @@ -466,7 +468,7 @@ mod tests { let verifier_state_1 = VerifierState::default(); let verifier_state_2 = VerifierState::default(); - let verifier_state_3 = setup.verifier.merge_states(&verifier_state_1, &verifier_state_2)?; + let verifier_state_3 = setup.verifier.merge_states(verifier_state_1, verifier_state_2)?; verify_true!(verifier_state_3.0.is_none()) } diff --git a/willow/tests/willow_v1_shell.rs b/willow/tests/willow_v1_shell.rs index 4c3e5ff..9e497fd 100644 --- a/willow/tests/willow_v1_shell.rs +++ b/willow/tests/willow_v1_shell.rs @@ -370,8 +370,7 @@ fn encrypt_decrypt_multiple_clients() -> googletest::Result<()> { if i < half { &mut verifier_state_1 } else { &mut verifier_state_2 }; verifier.verify_and_include(decryption_request_contribution, &mut verifier_state).unwrap(); } - let verifier_state_merged = - verifier.merge_states(&verifier_state_1, &verifier_state_2).unwrap(); + let verifier_state_merged = verifier.merge_states(verifier_state_1, verifier_state_2).unwrap(); // Run the rest of the protocol twice, once with each of the the two copies of the verifier state. for (mut server_state, verifier_state) in