Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions shell_wrapper/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,7 @@ cc_library(
":status_macros",
"@abseil-cpp//absl/status",
"@abseil-cpp//absl/status:statusor",
"@abseil-cpp//absl/strings:str_format",
"@abseil-cpp//absl/strings:string_view",
"@abseil-cpp//absl/types:span",
"@shell-encryption//shell_encryption/rns:coefficient_encoder",
Expand Down
16 changes: 14 additions & 2 deletions shell_wrapper/kahe.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "include/cxx.h"
Expand Down Expand Up @@ -347,6 +348,7 @@ FfiStatus PackMessagesRaw(rust::Slice<const uint64_t> messages,

FfiStatus UnpackMessagesRaw(uint64_t packing_base, uint64_t packing_dimension,
uint64_t num_packed_values,
uint64_t num_unpacked_values,
BigIntVectorWrapper& packed_values,
rust::Vec<uint64_t>& out) {
// Validate the wrappers.
Expand All @@ -358,15 +360,25 @@ FfiStatus UnpackMessagesRaw(uint64_t packing_base, uint64_t packing_dimension,
return MakeFfiStatus(
absl::InvalidArgumentError("insufficient number of packed values."));
}

// `unpacked_messages` is padded with zeros if needed.
std::vector<uint64_t> unpacked_messages =
rlwe::UnpackMessagesFlat<secure_aggregation::Integer,
secure_aggregation::BigInteger>(
absl::MakeSpan(*packed_values.ptr).subspan(0, num_packed_values),
packing_base, packing_dimension);
packed_values.ptr->erase(packed_values.ptr->begin(),
packed_values.ptr->begin() + num_packed_values);
for (auto& val : unpacked_messages) {
out.push_back(val);

// Remove padding and copy values to Rust output vector.
if (unpacked_messages.size() < num_unpacked_values) {
return MakeFfiStatus(absl::InvalidArgumentError(
absl::StrFormat("unpacked messages is too short (%d) for the requested "
"number of unpacked values (%d)",
unpacked_messages.size(), num_unpacked_values)));
}
for (size_t i = 0; i < num_unpacked_values; ++i) {
out.push_back(unpacked_messages[i]);
}
return MakeFfiStatus();
}
Expand Down
6 changes: 4 additions & 2 deletions shell_wrapper/kahe.h
Original file line number Diff line number Diff line change
Expand Up @@ -154,14 +154,16 @@ FfiStatus PackMessagesRaw(rust::Slice<const uint64_t> messages,
uint64_t num_packed_values,
BigIntVectorWrapper* packed_values);

// Unpacks messages stored at `packed_values[0..num_packed_values]` and appends
// them to `out`, and removes these packed values from `packed_values`.
// Unpacks messages stored at `packed_values[0..num_packed_values]`, removes
// these packed values from `packed_values` and appends the first
// `num_unpacked_values` messages to `out`.
// Expects `packed_values.ptr` to be a valid pointer to the vector of packed
// values, and expects packing_base > 1, packing_dimension > 0,
// num_packed_values > 0, packing_base^packing_dimension <
// std::numeric_limits<BigInteger>::max().
FfiStatus UnpackMessagesRaw(uint64_t packing_base, uint64_t packing_dimension,
uint64_t num_packed_values,
uint64_t num_unpacked_values,
BigIntVectorWrapper& packed_values,
rust::Vec<uint64_t>& out);

Expand Down
13 changes: 13 additions & 0 deletions shell_wrapper/kahe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,22 @@ use std::collections::HashMap;
use std::marker::PhantomData;
use std::mem::MaybeUninit;

/// Configuration for packing and unpacking. Used to convert a long vector of `length` small
/// integers in [0, `base`) into a short vector of `num_packed_coeffs` large integers in
/// [0, `base`^`dimension`), and vice versa.
#[derive(Debug, PartialEq, Clone)]
pub struct PackedVectorConfig {
/// Base for packing.
pub base: u64,

/// Number of elements packed into each coefficient.
pub dimension: u64,

/// Number of coefficients in the packed vector.
pub num_packed_coeffs: u64,

/// Number of elements in the plaintext vector before packing.
pub length: u64,
}

#[cxx::bridge]
Expand Down Expand Up @@ -93,6 +104,7 @@ mod ffi {
packing_base: u64,
packing_dimension: u64,
num_packed_values: u64,
num_unpacked_values: u64,
packed_values: &mut BigIntVectorWrapper,
out: &mut Vec<u64>,
) -> FfiStatus;
Expand Down Expand Up @@ -260,6 +272,7 @@ pub fn decrypt(
packed_vector_config.base,
packed_vector_config.dimension,
packed_vector_config.num_packed_coeffs,
packed_vector_config.length,
&mut packed_values,
&mut unpacked_values,
)
Expand Down
120 changes: 104 additions & 16 deletions shell_wrapper/kahe_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,8 @@ TEST(KaheTest, UnpackMessagesRawRemovesConsumedPackedValues) {
rust::Vec<Integer> unpacked_messages;
SECAGG_EXPECT_OK(UnwrapFfiStatus(
UnpackMessagesRaw(packing_base, packing_dimension, num_packed_values,
packed_values, unpacked_messages)));
num_packed_values * packing_dimension, packed_values,
unpacked_messages)));
EXPECT_EQ(packed_values.ptr->size(), num_packed_values);
EXPECT_EQ(unpacked_messages.size(), num_packed_values);
// Unpacked values should match the first half of the original packed values.
Expand All @@ -349,6 +350,94 @@ TEST(KaheTest, UnpackMessagesRawRemovesConsumedPackedValues) {
absl::MakeSpan(packed).subspan(num_packed_values));
}

TEST(KaheTest, RawEncryptDecryptPadding) {
constexpr int num_packing = 2;
constexpr int num_public_polynomials = 2;
constexpr int num_messages = 9;
constexpr Integer packing_base = 10;
// 2 messages per coefficient, last coefficient has only 1 message.
constexpr int num_packed_messages = 5;

std::unique_ptr<std::string> public_seed;
SECAGG_ASSERT_OK(UnwrapFfiStatus(GenerateSingleThreadHkdfSeed(public_seed)));
KahePublicParametersWrapper params;
SECAGG_ASSERT_OK(UnwrapFfiStatus(CreateKahePublicParametersWrapper(
kLogN, kLogT, ToRustSlice(kQs), num_public_polynomials,
ToRustSlice(*public_seed), &params)));
std::unique_ptr<std::string> private_seed;
SECAGG_ASSERT_OK(UnwrapFfiStatus(GenerateSingleThreadHkdfSeed(private_seed)));
SingleThreadHkdfWrapper prng;
SECAGG_ASSERT_OK(UnwrapFfiStatus(
CreateSingleThreadHkdf(ToRustSlice(*private_seed), prng)));
RnsPolynomialWrapper key;
SECAGG_ASSERT_OK(
UnwrapFfiStatus(GenerateSecretKeyWrapper(params, &prng, &key)));

// Pack messages that don't fully occupy the packed coefficients.
std::vector<Integer> input_messages =
rlwe::testing::SampleMessages(num_messages, packing_base);
BigIntVectorWrapper packed_messages{
.ptr = std::make_unique<std::vector<BigInteger>>()};
SECAGG_ASSERT_OK(UnwrapFfiStatus(
PackMessagesRaw(ToRustSlice(input_messages), packing_base, num_packing,
num_packed_messages, &packed_messages)));

// Encrypt the packed messages.
RnsPolynomialVecWrapper ciphertexts;
SECAGG_ASSERT_OK(UnwrapFfiStatus(
Encrypt(packed_messages, key, params, &prng, &ciphertexts)));

// Decrypt to get a packed plaintext.
BigIntVectorWrapper packed_messages_1{
.ptr = std::make_unique<std::vector<BigInteger>>()};
SECAGG_ASSERT_OK(
UnwrapFfiStatus(Decrypt(ciphertexts, key, params, &packed_messages_1)));

// Unpack and retrieve the original messages plus padding.
rust::Vec<Integer> unpacked_messages_1;
SECAGG_ASSERT_OK(UnwrapFfiStatus(UnpackMessagesRaw(
packing_base, num_packing, packed_messages_1.ptr->size(),
num_packed_messages * num_packing, packed_messages_1,
unpacked_messages_1)));

// Decrypted messages are padded to zero up to the end of the polynomial.
EXPECT_THAT(
absl::MakeSpan(unpacked_messages_1.data(), unpacked_messages_1.size())
.subspan(num_messages,
num_packed_messages * num_packing - num_messages),
::testing::Each(::testing::Eq(0)));

// Decrypt to obtain a fresh packed plaintext.
BigIntVectorWrapper packed_messages_2{
.ptr = std::make_unique<std::vector<BigInteger>>()};
SECAGG_ASSERT_OK(
UnwrapFfiStatus(Decrypt(ciphertexts, key, params, &packed_messages_2)));

// Now unpack and directly pass the right length to remove padding.
rust::Vec<Integer> unpacked_messages_2;
SECAGG_ASSERT_OK(UnwrapFfiStatus(UnpackMessagesRaw(
packing_base, num_packing, packed_messages_2.ptr->size(), num_messages,
packed_messages_2, unpacked_messages_2)));
EXPECT_EQ(
absl::MakeSpan(unpacked_messages_2.data(), unpacked_messages_2.size()),
absl::MakeSpan(input_messages));

// Finally, check that we fail if we request too many unpacked messages
BigIntVectorWrapper packed_messages_3{
.ptr = std::make_unique<std::vector<BigInteger>>()};
SECAGG_ASSERT_OK(
UnwrapFfiStatus(Decrypt(ciphertexts, key, params, &packed_messages_2)));
SECAGG_ASSERT_OK(UnwrapFfiStatus(
Encrypt(packed_messages_3, key, params, &prng, &ciphertexts)));
rust::Vec<Integer> unpacked_messages_3;
int num_unpacked_messages_3 = packed_messages_3.ptr->size() * num_packing + 1;
EXPECT_THAT(
UnwrapFfiStatus(UnpackMessagesRaw(
packing_base, num_packing, packed_messages_3.ptr->size(),
num_unpacked_messages_3, packed_messages_3, unpacked_messages_3)),
StatusIs(absl::StatusCode::kInvalidArgument));
}

TEST(KaheTest, PackAndEncrypt) {
constexpr int num_packing = 8;
constexpr int num_public_polynomials = 2;
Expand Down Expand Up @@ -399,19 +488,14 @@ TEST(KaheTest, PackAndEncrypt) {
.ptr = std::make_unique<std::vector<BigInteger>>(std::move(decrypted))};
rust::Vec<Integer> unpacked_messages;
SECAGG_ASSERT_OK(UnwrapFfiStatus(
UnpackMessagesRaw(packing_base, num_packing, packed_messages.size(),
decrypted_wrapper, unpacked_messages)));
UnpackMessagesRaw(packing_base, num_packing, num_packed_messages,
num_messages, decrypted_wrapper, unpacked_messages)));
EXPECT_EQ(absl::MakeSpan(unpacked_messages.data(), num_messages),
absl::MakeSpan(expected_unpacked_messages.data(), num_messages));

// Check against the original input messages.
EXPECT_EQ(absl::MakeSpan(unpacked_messages.data(), num_messages),
absl::MakeSpan(input_messages).subspan(0, num_messages));
// Check unpacked messages are padded with zeros.
ASSERT_GE(expected_unpacked_messages.size(), num_messages);
EXPECT_THAT(
absl::MakeSpan(unpacked_messages.data(), unpacked_messages.size())
.subspan(num_messages, unpacked_messages.size() - num_messages),
::testing::Each(::testing::Eq(0)));
absl::MakeSpan(input_messages));
}

TEST(KaheTest, RawVectorEncryptOnePolynomial) {
Expand Down Expand Up @@ -461,7 +545,8 @@ TEST(KaheTest, RawVectorEncryptOnePolynomial) {
rust::Vec<Integer> unpacked_decrypted_messages;
SECAGG_ASSERT_OK(UnwrapFfiStatus(UnpackMessagesRaw(
packing_base, num_packing, decrypted_wrapper.ptr->size(),
decrypted_wrapper, unpacked_decrypted_messages)));
decrypted_wrapper.ptr->size() * num_packing, decrypted_wrapper,
unpacked_decrypted_messages)));

// Filled the whole buffer with right messages.
EXPECT_EQ(absl::MakeSpan(unpacked_decrypted_messages.data(), num_messages),
Expand All @@ -481,6 +566,7 @@ TEST(KaheTest, RawVectorEncryptOnePolynomial) {
unpacked_decrypted_long_messages.reserve(buffer_length);
SECAGG_ASSERT_OK(UnwrapFfiStatus(UnpackMessagesRaw(
packing_base, num_packing, decrypted_long_messages_wrapper.ptr->size(),
decrypted_long_messages_wrapper.ptr->size() * num_packing,
decrypted_long_messages_wrapper, unpacked_decrypted_long_messages)));

// The non-zero messages are identical.
Expand Down Expand Up @@ -538,10 +624,10 @@ TEST(KaheTest, RawVectorEncryptTwoPolynomials) {
UnwrapFfiStatus(Decrypt(ciphertexts, key, params, &decrypted_wrapper)));
rust::Vec<Integer> unpacked_decrypted_messages;
SECAGG_ASSERT_OK(UnwrapFfiStatus(UnpackMessagesRaw(
packing_base, num_packing, decrypted_wrapper.ptr->size(),
packing_base, num_packing, decrypted_wrapper.ptr->size(), num_messages,
decrypted_wrapper, unpacked_decrypted_messages)));

EXPECT_GE(unpacked_decrypted_messages.size(), num_messages);
EXPECT_EQ(unpacked_decrypted_messages.size(), num_messages);
EXPECT_EQ(absl::MakeSpan(input_messages),
absl::MakeSpan(unpacked_decrypted_messages.data(), num_messages));
}
Expand Down Expand Up @@ -643,7 +729,8 @@ TEST(KaheTest, UnpackMessagesRawFailsIfUnallocatedPackedValues) {
rust::Vec<Integer> unpacked_messages;
EXPECT_THAT(UnwrapFfiStatus(UnpackMessagesRaw(
packing_base, packing_dimension, num_packed_messages,
bad_packed_values, unpacked_messages)),
num_packed_messages * packing_dimension, bad_packed_values,
unpacked_messages)),
StatusIs(absl::StatusCode::kInvalidArgument));
}

Expand All @@ -658,7 +745,8 @@ TEST(KaheTest, UnpackMessagesRawFailsIfPackedValuesTooShort) {
rust::Vec<Integer> unpacked_messages;
EXPECT_THAT(UnwrapFfiStatus(UnpackMessagesRaw(
packing_base, packing_dimension, num_packed_messages,
bad_packed_values, unpacked_messages)),
num_packed_messages * packing_dimension, bad_packed_values,
unpacked_messages)),
StatusIs(absl::StatusCode::kInvalidArgument));
}

Expand Down Expand Up @@ -738,7 +826,7 @@ TEST(KaheTest, AddInPlacePolynomial) {
rust::Vec<Integer> unpacked_decrypted_messages;
unpacked_decrypted_messages.reserve(num_messages);
SECAGG_ASSERT_OK(UnwrapFfiStatus(UnpackMessagesRaw(
packing_base, num_packing, decrypted_wrapper.ptr->size(),
packing_base, num_packing, decrypted_wrapper.ptr->size(), num_messages,
decrypted_wrapper, unpacked_decrypted_messages)));
for (int i = 0; i < num_messages; ++i) {
EXPECT_EQ(input_values1[i] + input_values2[i],
Expand Down
Loading
Loading