Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions willow/proto/willow/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -87,3 +87,22 @@ rust_proto_library(
name = "messages_rust_proto",
deps = [":messages_proto"],
)

proto_library(
name = "server_accumulator_proto",
srcs = ["server_accumulator.proto"],
deps = [
":aggregation_config_proto",
":messages_proto",
],
)

cc_proto_library(
name = "server_accumulator_cc_proto",
deps = [":server_accumulator_proto"],
)

rust_proto_library(
name = "server_accumulator_rust_proto",
deps = [":server_accumulator_proto"],
)
33 changes: 33 additions & 0 deletions willow/proto/willow/server_accumulator.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

syntax = "proto3";

package secure_aggregation.willow;

import "willow/proto/willow/aggregation_config.proto";
import "willow/proto/willow/messages.proto";

option java_multiple_files = true;
option java_outer_classname = "ServerAccumulatorProto";

message ServerAccumulatorState {
ServerStateProto server_state = 1;
VerifierStateProto verifier_state = 2;
AggregationConfigProto aggregation_config = 3;
}

message ClientMessageList {
repeated ClientMessage client_messages = 1;
}
69 changes: 69 additions & 0 deletions willow/src/api/BUILD
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
load("@cxx.rs//tools/bazel:rust_cxx_bridge.bzl", "rust_cxx_bridge")

# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand All @@ -12,6 +14,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

load("@rules_cc//cc:cc_library.bzl", "cc_library")
load("@rules_cc//cc:cc_test.bzl", "cc_test")
load("@rules_rust//rust:defs.bzl", "rust_library", "rust_test")

package(
Expand Down Expand Up @@ -42,3 +46,68 @@ rust_test(
"//willow/src/testing_utils",
],
)

cc_library(
name = "server_accumulator",
srcs = ["server_accumulator.cc"],
hdrs = ["server_accumulator.h"],
deps = [
":server_accumulator_cxx/include", # fixdeps: keep
":server_accumulator_rust", # fixdeps: keep
"@abseil-cpp//absl/memory",
"@abseil-cpp//absl/status",
"@abseil-cpp//absl/status:statusor",
"@abseil-cpp//absl/strings",
"@cxx.rs//:cxx",
"@cxx.rs//:core",
"//willow/proto/willow:aggregation_config_cc_proto",
"//willow/proto/willow:server_accumulator_cc_proto",
],
)

cc_test(
name = "server_accumulator_test",
srcs = ["server_accumulator_test.cc"],
deps = [
":server_accumulator",
":server_accumulator_cxx", # fixdeps: keep
"@googletest//:gtest_main",
"@abseil-cpp//absl/status",
"@abseil-cpp//absl/status:statusor",
"//willow/proto/willow:aggregation_config_cc_proto",
"//willow/proto/willow:server_accumulator_cc_proto",
],
)

rust_cxx_bridge(
name = "server_accumulator_cxx",
src = "server_accumulator.rs",
deps = [
":server_accumulator",
],
)

rust_library(
name = "server_accumulator_rust",
srcs = ["server_accumulator.rs"],
deps = [
":aggregation_config",
"@protobuf//rust:protobuf",
"@cxx.rs//:cxx",
"//shell_wrapper:status",
"//willow/proto/willow:aggregation_config_rust_proto",
"//willow/proto/willow:server_accumulator_rust_proto",
"//willow/src/shell:kahe_shell",
"//willow/src/shell:parameters_shell",
"//willow/src/shell:vahe_shell",
"//willow/src/traits:ahe_traits",
"//willow/src/traits:kahe_traits",
"//willow/src/traits:messages",
"//willow/src/traits:proto_serialization_traits",
"//willow/src/traits:server_traits",
"//willow/src/traits:vahe_traits",
"//willow/src/traits:verifier_traits",
"//willow/src/willow_v1:willow_v1_server",
"//willow/src/willow_v1:willow_v1_verifier",
],
)
97 changes: 97 additions & 0 deletions willow/src/api/server_accumulator.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "willow/src/api/server_accumulator.h"

#include <cstdint>
#include <memory>
#include <string>
#include <utility>

#include "absl/memory/memory.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "include/cxx.h"
#include "willow/proto/willow/aggregation_config.pb.h"
#include "willow/proto/willow/server_accumulator.pb.h"
#include "willow/src/api/server_accumulator.rs.h"

namespace secure_aggregation {

absl::StatusOr<std::unique_ptr<WillowShellServerAccumulator>>
WillowShellServerAccumulator::Create(
const willow::AggregationConfigProto& aggregation_config) {
secure_aggregation::ServerAccumulator* out;
std::unique_ptr<std::string> status_message;
int status_code =
secure_aggregation::NewServerAccumulatorFromSerializedConfig(
std::make_unique<std::string>(aggregation_config.SerializeAsString()),
&out, &status_message);
if (status_code != 0) {
return absl::Status(absl::StatusCode(status_code), *status_message);
}
return absl::WrapUnique(new WillowShellServerAccumulator(IntoBox(out)));
}

absl::StatusOr<std::unique_ptr<WillowShellServerAccumulator>>
WillowShellServerAccumulator::CreateFromSerializedState(
std::string serialized_state) {
secure_aggregation::ServerAccumulator* out;
std::unique_ptr<std::string> status_message;
int status_code = secure_aggregation::NewServerAccumulatorFromSerializedState(
std::make_unique<std::string>(std::move(serialized_state)), &out,
&status_message);
if (status_code != 0) {
return absl::Status(absl::StatusCode(status_code), *status_message);
}
return absl::WrapUnique(new WillowShellServerAccumulator(IntoBox(out)));
}

absl::Status WillowShellServerAccumulator::ProcessClientMessages(
willow::ClientMessageList client_messages) {
client_messages.Clear();
std::unique_ptr<std::string> status_message;
int status_code = accumulator_->ProcessClientMessages(
std::make_unique<std::string>(client_messages.SerializeAsString()),
&status_message);
if (status_code != 0) {
return absl::Status(absl::StatusCode(status_code), *status_message);
}
return absl::OkStatus();
}

absl::Status WillowShellServerAccumulator::Merge(
std::unique_ptr<WillowShellServerAccumulator> other) {
std::unique_ptr<std::string> status_message;
int status_code =
accumulator_->Merge(std::move(other->accumulator_), &status_message);
if (status_code != 0) {
return absl::Status(absl::StatusCode(status_code), *status_message);
}
return absl::OkStatus();
}

absl::StatusOr<std::string> WillowShellServerAccumulator::ToSerializedState() {
rust::Vec<uint8_t> serialized_state;
std::unique_ptr<std::string> status_message;
int status_code =
accumulator_->ToSerializedState(&serialized_state, &status_message);
if (status_code != 0) {
return absl::Status(absl::StatusCode(status_code), *status_message);
}
return std::string(reinterpret_cast<const char*>(serialized_state.data()),
serialized_state.size());
}

} // namespace secure_aggregation
70 changes: 70 additions & 0 deletions willow/src/api/server_accumulator.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef SECURE_AGGREGATION_WILLOW_SRC_API_SERVER_ACCUMULATOR_H_
#define SECURE_AGGREGATION_WILLOW_SRC_API_SERVER_ACCUMULATOR_H_

#include <memory>
#include <string>
#include <utility>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "include/cxx.h"
#include "willow/proto/willow/aggregation_config.pb.h"
#include "willow/proto/willow/server_accumulator.pb.h"
#include "willow/src/api/server_accumulator.rs.h"

namespace secure_aggregation {

// Implements an accumulator class intended to be used by a batch processing
// system. Combines both the server and the verifier functionality of willow_v1,
// using SHELL for the underlying cryptography.
class WillowShellServerAccumulator {
public:
// Creates a new accumulator with the given aggregation_config and empty
// state.
static absl::StatusOr<std::unique_ptr<WillowShellServerAccumulator>> Create(
const willow::AggregationConfigProto& aggregation_config);

// Creates a new accumulator from the given serialized state, which must
// correspond to a serialized ServerAccumulatorState proto.
static absl::StatusOr<std::unique_ptr<WillowShellServerAccumulator>>
CreateFromSerializedState(std::string serialized_state);

// Processes a list of client messages. If an invalid message is encountered,
// an error is logged and processing continues.
absl::Status ProcessClientMessages(willow::ClientMessageList client_messages);

// Merges the state of `other` into the current accumulator.
absl::Status Merge(std::unique_ptr<WillowShellServerAccumulator> other);

// Converts the current state of the accumulator to a serialized
// ServerAccumulatorState proto.
absl::StatusOr<std::string> ToSerializedState();

private:
explicit WillowShellServerAccumulator(
rust::Box<secure_aggregation::ServerAccumulator> accumulator)
: accumulator_(std::move(accumulator)) {}

rust::Box<secure_aggregation::ServerAccumulator> accumulator_;
};

} // namespace secure_aggregation

#endif // SECURE_AGGREGATION_WILLOW_SRC_API_SERVER_ACCUMULATOR_H_
Loading