diff --git a/willow/src/api/BUILD b/willow/src/api/BUILD index 076cdf5..d5b7fe5 100644 --- a/willow/src/api/BUILD +++ b/willow/src/api/BUILD @@ -33,6 +33,7 @@ rust_library( "@protobuf//rust:protobuf", "//shell_wrapper:status", "//willow/proto/willow:aggregation_config_rust_proto", + "//willow/src/shell:single_thread_hkdf", "//willow/src/traits:proto_serialization_traits", ], ) diff --git a/willow/src/api/aggregation_config.rs b/willow/src/api/aggregation_config.rs index 16a85f8..1a6a1af 100644 --- a/willow/src/api/aggregation_config.rs +++ b/willow/src/api/aggregation_config.rs @@ -82,6 +82,19 @@ impl ToProto for AggregationConfig { } } +impl AggregationConfig { + /// Computes context bytes by hashing the session ID in the config. + pub fn compute_context_bytes(&self) -> Result, StatusError> { + let context_seed = single_thread_hkdf::compute_hkdf( + self.session_id.as_bytes(), + b"", + b"AggregationConfig.context_string", + single_thread_hkdf::seed_length(), + )?; + Ok(context_seed.as_bytes().to_vec()) + } +} + #[cfg(test)] mod tests { use crate::AggregationConfig; diff --git a/willow/src/input_encoding/BUILD b/willow/src/input_encoding/BUILD index 348fb4c..d8a7eb2 100644 --- a/willow/src/input_encoding/BUILD +++ b/willow/src/input_encoding/BUILD @@ -18,6 +18,7 @@ load("@rules_cc//cc:cc_test.bzl", "cc_test") package( default_applicable_licenses = [ ], + default_visibility = ["//visibility:public"], ) cc_library( diff --git a/willow/src/testing_utils/BUILD b/willow/src/testing_utils/BUILD index 9b3203b..fb70edb 100644 --- a/willow/src/testing_utils/BUILD +++ b/willow/src/testing_utils/BUILD @@ -12,12 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +load("@cxx.rs//tools/bazel:rust_cxx_bridge.bzl", "rust_cxx_bridge") +load("@rules_cc//cc:cc_library.bzl", "cc_library") +load("@rules_cc//cc:cc_test.bzl", "cc_test") load("@rules_rust//rust:defs.bzl", "rust_library", "rust_test") package( default_applicable_licenses = [ ], - default_visibility = ["//visibility:public"], + default_visibility = ["//:internal"], ) # PRNG @@ -71,7 +74,6 @@ rust_library( "//shell_wrapper:status", "//willow/src/api:aggregation_config", "//willow/src/shell:kahe_shell", - "//willow/src/shell:single_thread_hkdf", "//willow/src/shell:vahe_shell", "//willow/src/traits:ahe_traits", "//willow/src/traits:kahe_traits", @@ -90,16 +92,28 @@ rust_test( ], ) +rust_cxx_bridge( + name = "shell_testing_decryptor_cxx", + src = "shell_testing_decryptor.rs", + deps = [ + ":shell_testing_decryptor", + ], +) + rust_library( name = "shell_testing_decryptor", - testonly = 1, srcs = [ "shell_testing_decryptor.rs", ], deps = [ - ":shell_testing_parameters", + "@protobuf//rust:protobuf", + "@cxx.rs//:cxx", + "//shell_wrapper:shell_types_cc", "//shell_wrapper:status", + "//willow/proto/willow:aggregation_config_rust_proto", + "//willow/proto/willow:messages_rust_proto", "//willow/src/api:aggregation_config", + "//willow/src/shell:ahe_shell", "//willow/src/shell:kahe_shell", "//willow/src/shell:parameters_shell", "//willow/src/shell:single_thread_hkdf", @@ -108,6 +122,35 @@ rust_library( "//willow/src/traits:kahe_traits", "//willow/src/traits:messages", "//willow/src/traits:prng_traits", + "//willow/src/traits:proto_serialization_traits", "//willow/src/traits:vahe_traits", ], ) + +cc_library( + name = "shell_testing_decryptor_cc", + srcs = ["shell_testing_decryptor.cc"], + hdrs = ["shell_testing_decryptor.h"], + deps = [ + ":shell_testing_decryptor_cxx", + "@abseil-cpp//absl/memory", + "@abseil-cpp//absl/status", + "@abseil-cpp//absl/status:statusor", + "//shell_wrapper:shell_types_cc", + "//willow/proto/shell:shell_ciphertexts_cc_proto", + "//willow/proto/willow:aggregation_config_cc_proto", + "//willow/proto/willow:messages_cc_proto", + "//willow/src/input_encoding:codec", + ], +) + +cc_test( + name = "shell_testing_decryptor_test", + srcs = ["shell_testing_decryptor_test.cc"], + deps = [ + ":shell_testing_decryptor_cc", + "@googletest//:gtest_main", + "//shell_wrapper:status_matchers", + "//willow/proto/willow:aggregation_config_cc_proto", + ], +) diff --git a/willow/src/testing_utils/shell_testing_decryptor.cc b/willow/src/testing_utils/shell_testing_decryptor.cc new file mode 100644 index 0000000..08cd02e --- /dev/null +++ b/willow/src/testing_utils/shell_testing_decryptor.cc @@ -0,0 +1,103 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "willow/src/testing_utils/shell_testing_decryptor.h" + +#include +#include +#include +#include +#include + +#include "absl/memory/memory.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "shell_wrapper/shell_types.h" +#include "willow/src/input_encoding/codec.h" +#include "willow/src/testing_utils/shell_testing_decryptor.rs.h" + +namespace secure_aggregation { + +ShellTestingDecryptor::ShellTestingDecryptor( + rust::Box decryptor) + : decryptor_(std::move(decryptor)) {} + +absl::StatusOr> +ShellTestingDecryptor::Create( + const willow::AggregationConfigProto& aggregation_config) { + std::string aggregation_config_proto = aggregation_config.SerializeAsString(); + rust::Slice slice = ToRustSlice(aggregation_config_proto); + + secure_aggregation::ShellTestingDecryptorRust* out; + std::unique_ptr status_message; + int status_code = + create_shell_testing_decryptor(slice, &out, &status_message); + + if (status_code != 0) { + return absl::Status(absl::StatusCode(status_code), *status_message); + } + // Use `into_box` to avoid linker issues arising from rust::Box::from_raw. + return absl::WrapUnique(new ShellTestingDecryptor(decryptor_into_box(out))); +} + +absl::StatusOr +ShellTestingDecryptor::GeneratePublicKey() { + rust::Vec out; + std::unique_ptr status_message; + int status_code = decryptor_->generate_public_key(&out, &status_message); + + if (status_code != 0) { + return absl::Status(absl::StatusCode(status_code), *status_message); + } + + willow::ShellAhePublicKey public_key; + if (!public_key.ParseFromArray(out.data(), out.size())) { + return absl::InternalError("Failed to parse ShellAhePublicKey"); + } + return public_key; +} + +absl::StatusOr ShellTestingDecryptor::Decrypt( + const willow::ClientMessage& message) { + std::string contribution_proto = message.SerializeAsString(); + rust::Slice slice( + reinterpret_cast(contribution_proto.data()), + contribution_proto.size()); + + rust::Vec rust_flat_data; + std::unique_ptr status_message; + int status_code = + decryptor_->decrypt(slice, &rust_flat_data, &status_message); + + if (status_code != 0) { + return absl::Status(absl::StatusCode(status_code), *status_message); + } + + willow::EncodedData encoded_data; + for (const auto& rust_entry : rust_flat_data) { + std::string key(rust_entry.key); + std::vector val; + val.reserve(rust_entry.values.size()); + for (auto v : rust_entry.values) { + val.push_back(static_cast(v)); + } + encoded_data[std::move(key)] = std::move(val); + } + + return encoded_data; +} + +} // namespace secure_aggregation diff --git a/willow/src/testing_utils/shell_testing_decryptor.h b/willow/src/testing_utils/shell_testing_decryptor.h new file mode 100644 index 0000000..6aa2988 --- /dev/null +++ b/willow/src/testing_utils/shell_testing_decryptor.h @@ -0,0 +1,59 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef SECURE_AGGREGATION_WILLOW_SRC_TESTING_UTILS_SHELL_TESTING_DECRYPTOR_H_ +#define SECURE_AGGREGATION_WILLOW_SRC_TESTING_UTILS_SHELL_TESTING_DECRYPTOR_H_ + +#include + +#include "absl/status/statusor.h" +#include "willow/proto/shell/ciphertexts.pb.h" +#include "willow/proto/willow/aggregation_config.pb.h" +#include "willow/proto/willow/messages.pb.h" +#include "willow/src/input_encoding/codec.h" +#include "willow/src/testing_utils/shell_testing_decryptor.rs.h" + +namespace secure_aggregation { + +// Basic implementation of a single decryptor that uses Shell operations +// directly. Useful for testing Shell clients, by checking that encrypted +// messages can be decrypted properly. +class ShellTestingDecryptor { + public: + // Creates a new ShellTestingDecryptor from the given config, hashing the + // session ID from the config to seed KAHE and AHE public parameters. + static absl::StatusOr> Create( + const willow::AggregationConfigProto& aggregation_config); + + // Generates a new AHE public key, and stores the corresponding secret key. + absl::StatusOr GeneratePublicKey(); + + // Decrypts a client message using the stored AHE secret key, by recovering + // the KAHE key from the AHE ciphertext and then decrypting the KAHE + // ciphertext. Does not verify the client proof contained in the message. + absl::StatusOr Decrypt( + const willow::ClientMessage& message); + + private: + explicit ShellTestingDecryptor( + rust::Box decryptor); + + rust::Box decryptor_; +}; + +} // namespace secure_aggregation + +#endif // SECURE_AGGREGATION_WILLOW_SRC_TESTING_UTILS_SHELL_TESTING_DECRYPTOR_H_ diff --git a/willow/src/testing_utils/shell_testing_decryptor.rs b/willow/src/testing_utils/shell_testing_decryptor.rs index 32f8718..be53c0e 100644 --- a/willow/src/testing_utils/shell_testing_decryptor.rs +++ b/willow/src/testing_utils/shell_testing_decryptor.rs @@ -15,19 +15,28 @@ */ use aggregation_config::AggregationConfig; +use aggregation_config_rust_proto::AggregationConfigProto; +use ahe_shell::Ciphertext as VaheCiphertext; use ahe_traits::{AheBase, AheKeygen, PartialDec}; +use kahe_shell::Ciphertext as KaheCiphertext; use kahe_shell::ShellKahe; use kahe_traits::{KaheBase, KaheDecrypt, TrySecretKeyFrom}; use messages::ClientMessage; +use messages_rust_proto::ClientMessage as ClientMessageProto; use parameters_shell::create_shell_configs; use prng_traits::SecurePrng; +use proto_serialization_traits::{FromProto, ToProto}; +use protobuf::prelude::*; use single_thread_hkdf::SingleThreadHkdfPrng; +use status::ffi::FfiStatus; use status::{StatusError, StatusErrorCode}; use vahe_shell::ShellVahe; use vahe_traits::Recover; +use vahe_traits::VaheBase; /// Basic implementation of a single decryptor that uses Shell operations directly. Useful for -/// testing Shell clients, by checking that encrypted messages can be decrypted properly. +/// testing Shell clients, by checking that encrypted messages can be decrypted properly. Comes with +/// a C++ interface. pub struct ShellTestingDecryptor { kahe: ShellKahe, vahe: ShellVahe, @@ -88,4 +97,157 @@ impl ShellTestingDecryptor { } } } + + fn generate_public_key_serialized(&mut self) -> Result, StatusError> { + let pk = self.generate_public_key()?; + pk.to_proto(&self.vahe) + .map_err(|e| status::internal(format!("ToProto error: {}", e)))? + .serialize() + .map_err(|e| status::internal(format!("Serialize error: {}", e))) + } + + /// SAFETY: `out` and `out_status_message` must not be null. + unsafe fn generate_public_key_ffi( + &mut self, + out: *mut Vec, + out_status_message: *mut cxx::UniquePtr, + ) -> i32 { + match self.generate_public_key_serialized() { + Ok(pk) => { + *out = pk; + 0 + } + Err(status_error) => { + let ffi_status: FfiStatus = status_error.into(); + *out_status_message = ffi_status.message; + ffi_status.code + } + } + } + + fn decrypt_serialized( + &mut self, + contribution: &[u8], + ) -> Result, StatusError> { + let client_message_proto = ClientMessageProto::parse(contribution) + .map_err(|e| status::internal(format!("Failed to parse ClientMessageProto: {}", e)))?; + + let kahe_ciphertext = + KaheCiphertext::from_proto(client_message_proto.kahe_ciphertext(), &self.kahe)?; + let ahe_ciphertext = + VaheCiphertext::from_proto(client_message_proto.ahe_ciphertext(), &self.vahe)?; + + let proof = + ::EncryptionProof::from_proto(client_message_proto.proof(), ())?; + let nonce = client_message_proto.nonce().to_vec(); + + let client_message = ClientMessage { kahe_ciphertext, ahe_ciphertext, proof, nonce }; + + let plaintext = self.decrypt(&client_message)?; + let entries = plaintext + .into_iter() + .map(|(key, values)| ffi::EncodedDataEntry { key, values }) + .collect(); + Ok(entries) + } + + /// SAFETY: `out` and `out_status_message` must not be null. + unsafe fn decrypt_ffi( + &mut self, + contribution: &[u8], + out: *mut Vec, + out_status_message: *mut cxx::UniquePtr, + ) -> i32 { + match self.decrypt_serialized(contribution) { + Ok(result) => { + *out = result; + 0 + } + Err(status_error) => { + let ffi_status: FfiStatus = status_error.into(); + *out_status_message = ffi_status.message; + ffi_status.code + } + } + } +} + +/// CXX bridge to call ShellTestingDecryptor from C++, using serialized protos as input and output. +/// +/// SAFETY: all functions in this module are only called from the wrapping C++ library, +/// ensuring that output pointers are correctly wrapped by a rust::Box, and that pointer +/// arguments are not null. +#[cxx::bridge(namespace = "secure_aggregation")] +pub mod ffi { + struct EncodedDataEntry { + key: String, + values: Vec, + } + + extern "Rust" { + #[cxx_name = "ShellTestingDecryptorRust"] + type ShellTestingDecryptor; + + unsafe fn create_shell_testing_decryptor( + config: &[u8], + out: *mut *mut ShellTestingDecryptor, + out_status_message: *mut UniquePtr, + ) -> i32; + + #[rust_name = "generate_public_key_ffi"] + unsafe fn generate_public_key( + self: &mut ShellTestingDecryptor, + out: *mut Vec, + out_status_message: *mut UniquePtr, + ) -> i32; + + #[rust_name = "decrypt_ffi"] + unsafe fn decrypt( + self: &mut ShellTestingDecryptor, + contribution: &[u8], + out: *mut Vec, + out_status_message: *mut UniquePtr, + ) -> i32; + + unsafe fn decryptor_into_box(ptr: *mut ShellTestingDecryptor) + -> Box; + } +} + +fn create_shell_testing_decryptor_impl( + config: &[u8], +) -> Result, StatusError> { + let aggregation_config_proto = AggregationConfigProto::parse(config) + .map_err(|e| status::internal(format!("Failed to parse AggregationConfigProto: {}", e)))?; + let aggregation_config = AggregationConfig::from_proto(aggregation_config_proto, ())?; + let context_bytes = aggregation_config.compute_context_bytes()?; + let decryptor = ShellTestingDecryptor::new(&aggregation_config, &context_bytes)?; + Ok(Box::new(decryptor)) +} + +/// SAFETY: `out` and `out_status_message` must not be null. +unsafe fn create_shell_testing_decryptor( + config: &[u8], + out: *mut *mut ShellTestingDecryptor, + out_status_message: *mut cxx::UniquePtr, +) -> i32 { + match create_shell_testing_decryptor_impl(config) { + Ok(decryptor) => { + *out = Box::into_raw(decryptor); + 0 + } + Err(status_error) => { + let ffi_status: FfiStatus = status_error.into(); + *out_status_message = ffi_status.message; + ffi_status.code + } + } +} + +/// Converts a raw pointer to a Box. Ideally we would use `rust::Box::from_raw` +/// (https://cxx.rs/binding/box.html) directly from C++, but that causes linker errors. +/// +/// SAFETY: `ptr` must have been created by `Box::into_raw`, as in `create_shell_testing_decryptor`. +unsafe fn decryptor_into_box(ptr: *mut ShellTestingDecryptor) -> Box { + Box::from_raw(ptr) } diff --git a/willow/src/testing_utils/shell_testing_decryptor_test.cc b/willow/src/testing_utils/shell_testing_decryptor_test.cc new file mode 100644 index 0000000..f3d8180 --- /dev/null +++ b/willow/src/testing_utils/shell_testing_decryptor_test.cc @@ -0,0 +1,48 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "willow/src/testing_utils/shell_testing_decryptor.h" + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "shell_wrapper/status_matchers.h" +#include "willow/proto/willow/aggregation_config.pb.h" + +namespace secure_aggregation { +namespace { + +using ::testing::NotNull; + +TEST(ShellTestingDecryptorTest, CreateAndGenerateKey) { + willow::AggregationConfigProto config; + config.set_max_number_of_decryptors(1); + config.set_max_number_of_clients(1); + config.set_max_decryptor_dropouts(0); + config.set_session_id("test_session"); + auto& vector_config = (*config.mutable_vector_configs())["test_vec"]; + vector_config.set_length(10); + vector_config.set_bound(100); + + SECAGG_ASSERT_OK_AND_ASSIGN(auto decryptor, + ShellTestingDecryptor::Create(config)); + ASSERT_THAT(decryptor, NotNull()); + + SECAGG_ASSERT_OK_AND_ASSIGN(const auto& pk, decryptor->GeneratePublicKey()); + EXPECT_TRUE(pk.has_poly()); +} + +} // namespace +} // namespace secure_aggregation