Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions willow/src/api/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ rust_library(
"@protobuf//rust:protobuf",
"//shell_wrapper:status",
"//willow/proto/willow:aggregation_config_rust_proto",
"//willow/src/shell:single_thread_hkdf",
"//willow/src/traits:proto_serialization_traits",
],
)
Expand Down
13 changes: 13 additions & 0 deletions willow/src/api/aggregation_config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,19 @@ impl ToProto for AggregationConfig {
}
}

impl AggregationConfig {
/// Computes context bytes by hashing the session ID in the config.
pub fn compute_context_bytes(&self) -> Result<Vec<u8>, StatusError> {
let context_seed = single_thread_hkdf::compute_hkdf(
self.session_id.as_bytes(),
b"",
b"AggregationConfig.context_string",
single_thread_hkdf::seed_length(),
)?;
Ok(context_seed.as_bytes().to_vec())
}
}

#[cfg(test)]
mod tests {
use crate::AggregationConfig;
Expand Down
1 change: 1 addition & 0 deletions willow/src/input_encoding/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ load("@rules_cc//cc:cc_test.bzl", "cc_test")
package(
default_applicable_licenses = [
],
default_visibility = ["//visibility:public"],
)

cc_library(
Expand Down
51 changes: 47 additions & 4 deletions willow/src/testing_utils/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

load("@cxx.rs//tools/bazel:rust_cxx_bridge.bzl", "rust_cxx_bridge")
load("@rules_cc//cc:cc_library.bzl", "cc_library")
load("@rules_cc//cc:cc_test.bzl", "cc_test")
load("@rules_rust//rust:defs.bzl", "rust_library", "rust_test")

package(
default_applicable_licenses = [
],
default_visibility = ["//visibility:public"],
default_visibility = ["//:internal"],
)

# PRNG
Expand Down Expand Up @@ -71,7 +74,6 @@ rust_library(
"//shell_wrapper:status",
"//willow/src/api:aggregation_config",
"//willow/src/shell:kahe_shell",
"//willow/src/shell:single_thread_hkdf",
"//willow/src/shell:vahe_shell",
"//willow/src/traits:ahe_traits",
"//willow/src/traits:kahe_traits",
Expand All @@ -90,16 +92,28 @@ rust_test(
],
)

rust_cxx_bridge(
name = "shell_testing_decryptor_cxx",
src = "shell_testing_decryptor.rs",
deps = [
":shell_testing_decryptor",
],
)

rust_library(
name = "shell_testing_decryptor",
testonly = 1,
srcs = [
"shell_testing_decryptor.rs",
],
deps = [
":shell_testing_parameters",
"@protobuf//rust:protobuf",
"@cxx.rs//:cxx",
"//shell_wrapper:shell_types_cc",
"//shell_wrapper:status",
"//willow/proto/willow:aggregation_config_rust_proto",
"//willow/proto/willow:messages_rust_proto",
"//willow/src/api:aggregation_config",
"//willow/src/shell:ahe_shell",
"//willow/src/shell:kahe_shell",
"//willow/src/shell:parameters_shell",
"//willow/src/shell:single_thread_hkdf",
Expand All @@ -108,6 +122,35 @@ rust_library(
"//willow/src/traits:kahe_traits",
"//willow/src/traits:messages",
"//willow/src/traits:prng_traits",
"//willow/src/traits:proto_serialization_traits",
"//willow/src/traits:vahe_traits",
],
)

cc_library(
name = "shell_testing_decryptor_cc",
srcs = ["shell_testing_decryptor.cc"],
hdrs = ["shell_testing_decryptor.h"],
deps = [
":shell_testing_decryptor_cxx",
"@abseil-cpp//absl/memory",
"@abseil-cpp//absl/status",
"@abseil-cpp//absl/status:statusor",
"//shell_wrapper:shell_types_cc",
"//willow/proto/shell:shell_ciphertexts_cc_proto",
"//willow/proto/willow:aggregation_config_cc_proto",
"//willow/proto/willow:messages_cc_proto",
"//willow/src/input_encoding:codec",
],
)

cc_test(
name = "shell_testing_decryptor_test",
srcs = ["shell_testing_decryptor_test.cc"],
deps = [
":shell_testing_decryptor_cc",
"@googletest//:gtest_main",
"//shell_wrapper:status_matchers",
"//willow/proto/willow:aggregation_config_cc_proto",
],
)
103 changes: 103 additions & 0 deletions willow/src/testing_utils/shell_testing_decryptor.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "willow/src/testing_utils/shell_testing_decryptor.h"

#include <cstdint>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/memory/memory.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "shell_wrapper/shell_types.h"
#include "willow/src/input_encoding/codec.h"
#include "willow/src/testing_utils/shell_testing_decryptor.rs.h"

namespace secure_aggregation {

ShellTestingDecryptor::ShellTestingDecryptor(
rust::Box<ShellTestingDecryptorRust> decryptor)
: decryptor_(std::move(decryptor)) {}

absl::StatusOr<std::unique_ptr<ShellTestingDecryptor>>
ShellTestingDecryptor::Create(
const willow::AggregationConfigProto& aggregation_config) {
std::string aggregation_config_proto = aggregation_config.SerializeAsString();
rust::Slice<const uint8_t> slice = ToRustSlice(aggregation_config_proto);

secure_aggregation::ShellTestingDecryptorRust* out;
std::unique_ptr<std::string> status_message;
int status_code =
create_shell_testing_decryptor(slice, &out, &status_message);

if (status_code != 0) {
return absl::Status(absl::StatusCode(status_code), *status_message);
}
// Use `into_box` to avoid linker issues arising from rust::Box::from_raw.
return absl::WrapUnique(new ShellTestingDecryptor(decryptor_into_box(out)));
}

absl::StatusOr<willow::ShellAhePublicKey>
ShellTestingDecryptor::GeneratePublicKey() {
rust::Vec<uint8_t> out;
std::unique_ptr<std::string> status_message;
int status_code = decryptor_->generate_public_key(&out, &status_message);

if (status_code != 0) {
return absl::Status(absl::StatusCode(status_code), *status_message);
}

willow::ShellAhePublicKey public_key;
if (!public_key.ParseFromArray(out.data(), out.size())) {
return absl::InternalError("Failed to parse ShellAhePublicKey");
}
return public_key;
}

absl::StatusOr<willow::EncodedData> ShellTestingDecryptor::Decrypt(
const willow::ClientMessage& message) {
std::string contribution_proto = message.SerializeAsString();
rust::Slice<const uint8_t> slice(
reinterpret_cast<const uint8_t*>(contribution_proto.data()),
contribution_proto.size());

rust::Vec<secure_aggregation::EncodedDataEntry> rust_flat_data;
std::unique_ptr<std::string> status_message;
int status_code =
decryptor_->decrypt(slice, &rust_flat_data, &status_message);

if (status_code != 0) {
return absl::Status(absl::StatusCode(status_code), *status_message);
}

willow::EncodedData encoded_data;
for (const auto& rust_entry : rust_flat_data) {
std::string key(rust_entry.key);
std::vector<int64_t> val;
val.reserve(rust_entry.values.size());
for (auto v : rust_entry.values) {
val.push_back(static_cast<int64_t>(v));
}
encoded_data[std::move(key)] = std::move(val);
}

return encoded_data;
}

} // namespace secure_aggregation
59 changes: 59 additions & 0 deletions willow/src/testing_utils/shell_testing_decryptor.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef SECURE_AGGREGATION_WILLOW_SRC_TESTING_UTILS_SHELL_TESTING_DECRYPTOR_H_
#define SECURE_AGGREGATION_WILLOW_SRC_TESTING_UTILS_SHELL_TESTING_DECRYPTOR_H_

#include <memory>

#include "absl/status/statusor.h"
#include "willow/proto/shell/ciphertexts.pb.h"
#include "willow/proto/willow/aggregation_config.pb.h"
#include "willow/proto/willow/messages.pb.h"
#include "willow/src/input_encoding/codec.h"
#include "willow/src/testing_utils/shell_testing_decryptor.rs.h"

namespace secure_aggregation {

// Basic implementation of a single decryptor that uses Shell operations
// directly. Useful for testing Shell clients, by checking that encrypted
// messages can be decrypted properly.
class ShellTestingDecryptor {
public:
// Creates a new ShellTestingDecryptor from the given config, hashing the
// session ID from the config to seed KAHE and AHE public parameters.
static absl::StatusOr<std::unique_ptr<ShellTestingDecryptor>> Create(
const willow::AggregationConfigProto& aggregation_config);

// Generates a new AHE public key, and stores the corresponding secret key.
absl::StatusOr<willow::ShellAhePublicKey> GeneratePublicKey();

// Decrypts a client message using the stored AHE secret key, by recovering
// the KAHE key from the AHE ciphertext and then decrypting the KAHE
// ciphertext. Does not verify the client proof contained in the message.
absl::StatusOr<willow::EncodedData> Decrypt(
const willow::ClientMessage& message);

private:
explicit ShellTestingDecryptor(
rust::Box<ShellTestingDecryptorRust> decryptor);

rust::Box<ShellTestingDecryptorRust> decryptor_;
};

} // namespace secure_aggregation

#endif // SECURE_AGGREGATION_WILLOW_SRC_TESTING_UTILS_SHELL_TESTING_DECRYPTOR_H_
Loading
Loading