diff --git a/shell_wrapper/BUILD b/shell_wrapper/BUILD index d62094d..6f97ab9 100644 --- a/shell_wrapper/BUILD +++ b/shell_wrapper/BUILD @@ -242,6 +242,7 @@ cc_library( ":status_macros", "@abseil-cpp//absl/status", "@abseil-cpp//absl/status:statusor", + "@abseil-cpp//absl/strings:str_format", "@abseil-cpp//absl/strings:string_view", "@abseil-cpp//absl/types:span", "@shell-encryption//shell_encryption/rns:coefficient_encoder", diff --git a/shell_wrapper/kahe.cc b/shell_wrapper/kahe.cc index 8c1d3b7..81565d0 100644 --- a/shell_wrapper/kahe.cc +++ b/shell_wrapper/kahe.cc @@ -25,6 +25,7 @@ #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "include/cxx.h" @@ -347,6 +348,7 @@ FfiStatus PackMessagesRaw(rust::Slice messages, FfiStatus UnpackMessagesRaw(uint64_t packing_base, uint64_t packing_dimension, uint64_t num_packed_values, + uint64_t num_unpacked_values, BigIntVectorWrapper& packed_values, rust::Vec& out) { // Validate the wrappers. @@ -358,6 +360,8 @@ FfiStatus UnpackMessagesRaw(uint64_t packing_base, uint64_t packing_dimension, return MakeFfiStatus( absl::InvalidArgumentError("insufficient number of packed values.")); } + + // `unpacked_messages` is padded with zeros if needed. std::vector unpacked_messages = rlwe::UnpackMessagesFlat( @@ -365,8 +369,16 @@ FfiStatus UnpackMessagesRaw(uint64_t packing_base, uint64_t packing_dimension, packing_base, packing_dimension); packed_values.ptr->erase(packed_values.ptr->begin(), packed_values.ptr->begin() + num_packed_values); - for (auto& val : unpacked_messages) { - out.push_back(val); + + // Remove padding and copy values to Rust output vector. + if (unpacked_messages.size() < num_unpacked_values) { + return MakeFfiStatus(absl::InvalidArgumentError( + absl::StrFormat("unpacked messages is too short (%d) for the requested " + "number of unpacked values (%d)", + unpacked_messages.size(), num_unpacked_values))); + } + for (size_t i = 0; i < num_unpacked_values; ++i) { + out.push_back(unpacked_messages[i]); } return MakeFfiStatus(); } diff --git a/shell_wrapper/kahe.h b/shell_wrapper/kahe.h index dbfac25..cb770c4 100644 --- a/shell_wrapper/kahe.h +++ b/shell_wrapper/kahe.h @@ -154,14 +154,16 @@ FfiStatus PackMessagesRaw(rust::Slice messages, uint64_t num_packed_values, BigIntVectorWrapper* packed_values); -// Unpacks messages stored at `packed_values[0..num_packed_values]` and appends -// them to `out`, and removes these packed values from `packed_values`. +// Unpacks messages stored at `packed_values[0..num_packed_values]`, removes +// these packed values from `packed_values` and appends the first +// `num_unpacked_values` messages to `out`. // Expects `packed_values.ptr` to be a valid pointer to the vector of packed // values, and expects packing_base > 1, packing_dimension > 0, // num_packed_values > 0, packing_base^packing_dimension < // std::numeric_limits::max(). FfiStatus UnpackMessagesRaw(uint64_t packing_base, uint64_t packing_dimension, uint64_t num_packed_values, + uint64_t num_unpacked_values, BigIntVectorWrapper& packed_values, rust::Vec& out); diff --git a/shell_wrapper/kahe.rs b/shell_wrapper/kahe.rs index 699a455..92a5f3b 100644 --- a/shell_wrapper/kahe.rs +++ b/shell_wrapper/kahe.rs @@ -22,11 +22,22 @@ use std::collections::HashMap; use std::marker::PhantomData; use std::mem::MaybeUninit; +/// Configuration for packing and unpacking. Used to convert a long vector of `length` small +/// integers in [0, `base`) into a short vector of `num_packed_coeffs` large integers in +/// [0, `base`^`dimension`), and vice versa. #[derive(Debug, PartialEq, Clone)] pub struct PackedVectorConfig { + /// Base for packing. pub base: u64, + + /// Number of elements packed into each coefficient. pub dimension: u64, + + /// Number of coefficients in the packed vector. pub num_packed_coeffs: u64, + + /// Number of elements in the plaintext vector before packing. + pub length: u64, } #[cxx::bridge] @@ -93,6 +104,7 @@ mod ffi { packing_base: u64, packing_dimension: u64, num_packed_values: u64, + num_unpacked_values: u64, packed_values: &mut BigIntVectorWrapper, out: &mut Vec, ) -> FfiStatus; @@ -260,6 +272,7 @@ pub fn decrypt( packed_vector_config.base, packed_vector_config.dimension, packed_vector_config.num_packed_coeffs, + packed_vector_config.length, &mut packed_values, &mut unpacked_values, ) diff --git a/shell_wrapper/kahe_test.cc b/shell_wrapper/kahe_test.cc index 2d8981c..0b2b8be 100644 --- a/shell_wrapper/kahe_test.cc +++ b/shell_wrapper/kahe_test.cc @@ -337,7 +337,8 @@ TEST(KaheTest, UnpackMessagesRawRemovesConsumedPackedValues) { rust::Vec unpacked_messages; SECAGG_EXPECT_OK(UnwrapFfiStatus( UnpackMessagesRaw(packing_base, packing_dimension, num_packed_values, - packed_values, unpacked_messages))); + num_packed_values * packing_dimension, packed_values, + unpacked_messages))); EXPECT_EQ(packed_values.ptr->size(), num_packed_values); EXPECT_EQ(unpacked_messages.size(), num_packed_values); // Unpacked values should match the first half of the original packed values. @@ -349,6 +350,94 @@ TEST(KaheTest, UnpackMessagesRawRemovesConsumedPackedValues) { absl::MakeSpan(packed).subspan(num_packed_values)); } +TEST(KaheTest, RawEncryptDecryptPadding) { + constexpr int num_packing = 2; + constexpr int num_public_polynomials = 2; + constexpr int num_messages = 9; + constexpr Integer packing_base = 10; + // 2 messages per coefficient, last coefficient has only 1 message. + constexpr int num_packed_messages = 5; + + std::unique_ptr public_seed; + SECAGG_ASSERT_OK(UnwrapFfiStatus(GenerateSingleThreadHkdfSeed(public_seed))); + KahePublicParametersWrapper params; + SECAGG_ASSERT_OK(UnwrapFfiStatus(CreateKahePublicParametersWrapper( + kLogN, kLogT, ToRustSlice(kQs), num_public_polynomials, + ToRustSlice(*public_seed), ¶ms))); + std::unique_ptr private_seed; + SECAGG_ASSERT_OK(UnwrapFfiStatus(GenerateSingleThreadHkdfSeed(private_seed))); + SingleThreadHkdfWrapper prng; + SECAGG_ASSERT_OK(UnwrapFfiStatus( + CreateSingleThreadHkdf(ToRustSlice(*private_seed), prng))); + RnsPolynomialWrapper key; + SECAGG_ASSERT_OK( + UnwrapFfiStatus(GenerateSecretKeyWrapper(params, &prng, &key))); + + // Pack messages that don't fully occupy the packed coefficients. + std::vector input_messages = + rlwe::testing::SampleMessages(num_messages, packing_base); + BigIntVectorWrapper packed_messages{ + .ptr = std::make_unique>()}; + SECAGG_ASSERT_OK(UnwrapFfiStatus( + PackMessagesRaw(ToRustSlice(input_messages), packing_base, num_packing, + num_packed_messages, &packed_messages))); + + // Encrypt the packed messages. + RnsPolynomialVecWrapper ciphertexts; + SECAGG_ASSERT_OK(UnwrapFfiStatus( + Encrypt(packed_messages, key, params, &prng, &ciphertexts))); + + // Decrypt to get a packed plaintext. + BigIntVectorWrapper packed_messages_1{ + .ptr = std::make_unique>()}; + SECAGG_ASSERT_OK( + UnwrapFfiStatus(Decrypt(ciphertexts, key, params, &packed_messages_1))); + + // Unpack and retrieve the original messages plus padding. + rust::Vec unpacked_messages_1; + SECAGG_ASSERT_OK(UnwrapFfiStatus(UnpackMessagesRaw( + packing_base, num_packing, packed_messages_1.ptr->size(), + num_packed_messages * num_packing, packed_messages_1, + unpacked_messages_1))); + + // Decrypted messages are padded to zero up to the end of the polynomial. + EXPECT_THAT( + absl::MakeSpan(unpacked_messages_1.data(), unpacked_messages_1.size()) + .subspan(num_messages, + num_packed_messages * num_packing - num_messages), + ::testing::Each(::testing::Eq(0))); + + // Decrypt to obtain a fresh packed plaintext. + BigIntVectorWrapper packed_messages_2{ + .ptr = std::make_unique>()}; + SECAGG_ASSERT_OK( + UnwrapFfiStatus(Decrypt(ciphertexts, key, params, &packed_messages_2))); + + // Now unpack and directly pass the right length to remove padding. + rust::Vec unpacked_messages_2; + SECAGG_ASSERT_OK(UnwrapFfiStatus(UnpackMessagesRaw( + packing_base, num_packing, packed_messages_2.ptr->size(), num_messages, + packed_messages_2, unpacked_messages_2))); + EXPECT_EQ( + absl::MakeSpan(unpacked_messages_2.data(), unpacked_messages_2.size()), + absl::MakeSpan(input_messages)); + + // Finally, check that we fail if we request too many unpacked messages + BigIntVectorWrapper packed_messages_3{ + .ptr = std::make_unique>()}; + SECAGG_ASSERT_OK( + UnwrapFfiStatus(Decrypt(ciphertexts, key, params, &packed_messages_2))); + SECAGG_ASSERT_OK(UnwrapFfiStatus( + Encrypt(packed_messages_3, key, params, &prng, &ciphertexts))); + rust::Vec unpacked_messages_3; + int num_unpacked_messages_3 = packed_messages_3.ptr->size() * num_packing + 1; + EXPECT_THAT( + UnwrapFfiStatus(UnpackMessagesRaw( + packing_base, num_packing, packed_messages_3.ptr->size(), + num_unpacked_messages_3, packed_messages_3, unpacked_messages_3)), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + TEST(KaheTest, PackAndEncrypt) { constexpr int num_packing = 8; constexpr int num_public_polynomials = 2; @@ -399,19 +488,14 @@ TEST(KaheTest, PackAndEncrypt) { .ptr = std::make_unique>(std::move(decrypted))}; rust::Vec unpacked_messages; SECAGG_ASSERT_OK(UnwrapFfiStatus( - UnpackMessagesRaw(packing_base, num_packing, packed_messages.size(), - decrypted_wrapper, unpacked_messages))); + UnpackMessagesRaw(packing_base, num_packing, num_packed_messages, + num_messages, decrypted_wrapper, unpacked_messages))); EXPECT_EQ(absl::MakeSpan(unpacked_messages.data(), num_messages), absl::MakeSpan(expected_unpacked_messages.data(), num_messages)); + // Check against the original input messages. EXPECT_EQ(absl::MakeSpan(unpacked_messages.data(), num_messages), - absl::MakeSpan(input_messages).subspan(0, num_messages)); - // Check unpacked messages are padded with zeros. - ASSERT_GE(expected_unpacked_messages.size(), num_messages); - EXPECT_THAT( - absl::MakeSpan(unpacked_messages.data(), unpacked_messages.size()) - .subspan(num_messages, unpacked_messages.size() - num_messages), - ::testing::Each(::testing::Eq(0))); + absl::MakeSpan(input_messages)); } TEST(KaheTest, RawVectorEncryptOnePolynomial) { @@ -461,7 +545,8 @@ TEST(KaheTest, RawVectorEncryptOnePolynomial) { rust::Vec unpacked_decrypted_messages; SECAGG_ASSERT_OK(UnwrapFfiStatus(UnpackMessagesRaw( packing_base, num_packing, decrypted_wrapper.ptr->size(), - decrypted_wrapper, unpacked_decrypted_messages))); + decrypted_wrapper.ptr->size() * num_packing, decrypted_wrapper, + unpacked_decrypted_messages))); // Filled the whole buffer with right messages. EXPECT_EQ(absl::MakeSpan(unpacked_decrypted_messages.data(), num_messages), @@ -481,6 +566,7 @@ TEST(KaheTest, RawVectorEncryptOnePolynomial) { unpacked_decrypted_long_messages.reserve(buffer_length); SECAGG_ASSERT_OK(UnwrapFfiStatus(UnpackMessagesRaw( packing_base, num_packing, decrypted_long_messages_wrapper.ptr->size(), + decrypted_long_messages_wrapper.ptr->size() * num_packing, decrypted_long_messages_wrapper, unpacked_decrypted_long_messages))); // The non-zero messages are identical. @@ -538,10 +624,10 @@ TEST(KaheTest, RawVectorEncryptTwoPolynomials) { UnwrapFfiStatus(Decrypt(ciphertexts, key, params, &decrypted_wrapper))); rust::Vec unpacked_decrypted_messages; SECAGG_ASSERT_OK(UnwrapFfiStatus(UnpackMessagesRaw( - packing_base, num_packing, decrypted_wrapper.ptr->size(), + packing_base, num_packing, decrypted_wrapper.ptr->size(), num_messages, decrypted_wrapper, unpacked_decrypted_messages))); - EXPECT_GE(unpacked_decrypted_messages.size(), num_messages); + EXPECT_EQ(unpacked_decrypted_messages.size(), num_messages); EXPECT_EQ(absl::MakeSpan(input_messages), absl::MakeSpan(unpacked_decrypted_messages.data(), num_messages)); } @@ -643,7 +729,8 @@ TEST(KaheTest, UnpackMessagesRawFailsIfUnallocatedPackedValues) { rust::Vec unpacked_messages; EXPECT_THAT(UnwrapFfiStatus(UnpackMessagesRaw( packing_base, packing_dimension, num_packed_messages, - bad_packed_values, unpacked_messages)), + num_packed_messages * packing_dimension, bad_packed_values, + unpacked_messages)), StatusIs(absl::StatusCode::kInvalidArgument)); } @@ -658,7 +745,8 @@ TEST(KaheTest, UnpackMessagesRawFailsIfPackedValuesTooShort) { rust::Vec unpacked_messages; EXPECT_THAT(UnwrapFfiStatus(UnpackMessagesRaw( packing_base, packing_dimension, num_packed_messages, - bad_packed_values, unpacked_messages)), + num_packed_messages * packing_dimension, bad_packed_values, + unpacked_messages)), StatusIs(absl::StatusCode::kInvalidArgument)); } @@ -738,7 +826,7 @@ TEST(KaheTest, AddInPlacePolynomial) { rust::Vec unpacked_decrypted_messages; unpacked_decrypted_messages.reserve(num_messages); SECAGG_ASSERT_OK(UnwrapFfiStatus(UnpackMessagesRaw( - packing_base, num_packing, decrypted_wrapper.ptr->size(), + packing_base, num_packing, decrypted_wrapper.ptr->size(), num_messages, decrypted_wrapper, unpacked_decrypted_messages))); for (int i = 0; i < num_messages; ++i) { EXPECT_EQ(input_values1[i] + input_values2[i], diff --git a/shell_wrapper/kahe_test.rs b/shell_wrapper/kahe_test.rs index 2b4afaf..d9db5ee 100644 --- a/shell_wrapper/kahe_test.rs +++ b/shell_wrapper/kahe_test.rs @@ -47,13 +47,13 @@ fn encrypt_decrypt() -> Result<()> { let plaintext = HashMap::from([(DEFAULT_ID, input_values.as_slice())]); let packed_vector_configs = HashMap::from([( DEFAULT_ID.to_string(), - PackedVectorConfig { base: 10, dimension: 2, num_packed_coeffs: 2 }, + PackedVectorConfig { base: 10, dimension: 2, num_packed_coeffs: 2, length: 3 }, )]); let ciphertext = encrypt(&plaintext, &packed_vector_configs, &secret_key, ¶ms, &mut prng)?; let output_values = decrypt(&ciphertext, &secret_key, ¶ms, &packed_vector_configs)?; expect_that!(output_values.contains_key(DEFAULT_ID), eq(true)); - expect_that!(output_values[DEFAULT_ID][..3], container_eq(input_values)); + expect_that!(output_values[DEFAULT_ID], container_eq(input_values)); Ok(()) } @@ -80,7 +80,8 @@ fn encrypt_decrypt_padding() -> Result<()> { let input_values: Vec = (0..num_input_values).map(|_| rand::thread_rng().gen_range(0..input_domain)).collect(); - // Encrypt the vector. + // Encrypt the vector. Pass a longer length than what we need. + let padded_length = (num_packed_coeffs * packing_dimension) as usize; let plaintext = HashMap::from([(DEFAULT_ID, input_values.as_slice())]); let packed_vector_configs = HashMap::from([( DEFAULT_ID.to_string(), @@ -88,6 +89,7 @@ fn encrypt_decrypt_padding() -> Result<()> { base: input_domain as u64, dimension: packing_dimension as u64, num_packed_coeffs: num_packed_coeffs as u64, + length: padded_length as u64, }, )]); let ciphertext = encrypt(&plaintext, &packed_vector_configs, &secret_key, ¶ms, &mut prng)?; @@ -97,7 +99,6 @@ fn encrypt_decrypt_padding() -> Result<()> { let output_values = &decrypted[DEFAULT_ID]; // Check that message is correctly decrypted with right padding. - let padded_length = (num_packed_coeffs * packing_dimension) as usize; expect_that!(output_values.len(), eq(padded_length)); expect_that!(output_values.len(), gt(num_input_values)); expect_that!(output_values[..num_input_values], container_eq(input_values)); @@ -138,6 +139,7 @@ fn encrypt_decrypt_long() -> Result<()> { base: input_domain as u64, dimension: packing_dimension as u64, num_packed_coeffs: num_packed_coeffs as u64, + length: num_input_values as u64, }, )]); let ciphertext = encrypt(&plaintext, &packed_vector_configs, &secret_key, ¶ms, &mut prng)?; @@ -145,15 +147,9 @@ fn encrypt_decrypt_long() -> Result<()> { let decrypted = decrypt(&ciphertext, &secret_key, ¶ms, &packed_vector_configs)?; let output_values = &decrypted[DEFAULT_ID]; - // Check that message is correctly decrypted with right padding. - let padded_length = num_packed_coeffs * packing_dimension; - expect_that!(output_values.len(), eq(padded_length)); - expect_that!(output_values.len(), gt(num_input_values)); - expect_that!(output_values[..num_input_values], container_eq(input_values)); - expect_that!( - output_values[num_input_values..], - container_eq(vec![0; padded_length - num_input_values]) - ); + // Check that message is correctly decrypted (no padding). + expect_that!(output_values.len(), eq(num_input_values)); + expect_that!(output_values, container_eq(input_values)); // If the input is too long, we should fail. let num_values_too_long = num_public_polynomials * poly_capacity + 1; @@ -195,6 +191,7 @@ fn encrypt_decrypt_two_vectors() -> Result<()> { base: input_domains[0] as u64, dimension: packing_dimensions[0] as u64, num_packed_coeffs: num_packed_coeffs[0] as u64, + length: num_input_values[0] as u64, }, ), ( @@ -203,6 +200,7 @@ fn encrypt_decrypt_two_vectors() -> Result<()> { base: input_domains[1] as u64, dimension: packing_dimensions[1] as u64, num_packed_coeffs: num_packed_coeffs[1] as u64, + length: num_input_values[1] as u64, }, ), ]); @@ -218,26 +216,16 @@ fn encrypt_decrypt_two_vectors() -> Result<()> { HashMap::from([(ID0, input_values0.as_slice()), (ID1, input_values1.as_slice())]); let ciphertext = encrypt(&plaintext, &packed_vector_configs, &secret_key, ¶ms, &mut prng)?; - // Decrypt and check the output contains the two vectors that are padded correctly. + // Decrypt and check the output contains the two vectors. let decrypted = decrypt(&ciphertext, &secret_key, ¶ms, &packed_vector_configs)?; verify_that!(decrypted.contains_key(ID0), eq(true))?; verify_that!(decrypted.contains_key(ID1), eq(true))?; let output_values0 = &decrypted[ID0]; let output_values1 = &decrypted[ID1]; - expect_that!(output_values0.len(), eq(num_packed_coeffs[0] * packing_dimensions[0])); - expect_that!(output_values0.len(), gt(num_input_values[0])); - expect_that!(output_values0[..num_input_values[0]], container_eq(input_values0)); - expect_that!( - output_values0[num_input_values[0]..], - container_eq(vec![0; num_packed_coeffs[0] * packing_dimensions[0] - num_input_values[0]]) - ); - expect_that!(output_values1.len(), eq(num_packed_coeffs[1] * packing_dimensions[1])); - expect_that!(output_values1.len(), gt(num_input_values[1])); - expect_that!(output_values1[..num_input_values[1]], container_eq(input_values1)); - expect_that!( - output_values1[num_input_values[1]..], - container_eq(vec![0; num_packed_coeffs[1] * packing_dimensions[1] - num_input_values[1]]) - ); + expect_that!(output_values0.len(), eq(num_input_values[0])); + expect_that!(output_values0, container_eq(input_values0)); + expect_that!(output_values1.len(), eq(num_input_values[1])); + expect_that!(output_values1, container_eq(input_values1)); Ok(()) } diff --git a/willow/proto/shell/parameters.proto b/willow/proto/shell/parameters.proto index 235ab69..ebb01ed 100644 --- a/willow/proto/shell/parameters.proto +++ b/willow/proto/shell/parameters.proto @@ -20,14 +20,15 @@ option java_multiple_files = true; option java_outer_classname = "ParametersProto"; // This proto defines how to pack an input vector into a KAHE plaintext. -// An input vector is split into `num_packed_coeffs` many sub-vectors of -// length `dimension` each. Each sub-vector is then packed into a single -// plaintext coefficient using base `base` encoding to allow summation over -// all clients' contributions. +// An input vector of length `length` is split into `num_packed_coeffs` many +// sub-vectors of length `dimension` each. Each sub-vector is then packed into a +// single plaintext coefficient using base `base` encoding to allow summation +// over all clients' contributions. message PackedVectorConfigProto { int64 base = 1; int64 dimension = 2; int64 num_packed_coeffs = 3; + int64 length = 4; } // This proto defines the parameters for instantiating the KAHE scheme diff --git a/willow/src/shell/kahe.rs b/willow/src/shell/kahe.rs index 8ff84a3..5103f0d 100644 --- a/willow/src/shell/kahe.rs +++ b/willow/src/shell/kahe.rs @@ -402,7 +402,7 @@ mod test { let plaintext_modulus_bits = 39; let packed_vector_configs = HashMap::from([( DEFAULT_ID.to_string(), - PackedVectorConfig { base: 10, dimension: 2, num_packed_coeffs: 5 }, + PackedVectorConfig { base: 10, dimension: 2, num_packed_coeffs: 5, length: 10 }, )]); let kahe_config = make_kahe_config_for(plaintext_modulus_bits, packed_vector_configs)?; let kahe = ShellKahe::new(kahe_config, CONTEXT_STRING)?; @@ -417,12 +417,32 @@ mod test { verify_eq!(&pt, &decrypted) } + #[gtest] + fn test_encrypt_decrypt_short_padding() -> googletest::Result<()> { + let plaintext_modulus_bits = 39; + let packed_vector_configs = HashMap::from([( + DEFAULT_ID.to_string(), + PackedVectorConfig { base: 10, dimension: 2, num_packed_coeffs: 5, length: 8 }, + )]); + let kahe_config = make_kahe_config_for(plaintext_modulus_bits, packed_vector_configs)?; + let kahe = ShellKahe::new(kahe_config, CONTEXT_STRING)?; + + let pt = HashMap::from([(DEFAULT_ID.to_string(), vec![0, 1, 2, 3, 4, 5, 6, 7])]); + let seed = SingleThreadHkdfPrng::generate_seed()?; + let mut prng = SingleThreadHkdfPrng::create(&seed)?; + let sk = kahe.key_gen(&mut prng)?; + let pt_slice = ShellKahe::plaintext_as_slice(&pt); + let ct = kahe.encrypt(&pt_slice, &sk, &mut prng)?; + let decrypted = kahe.decrypt(&ct, &sk)?; + verify_eq!(&pt, &decrypted) + } + #[gtest] fn test_encrypt_decrypt_with_serialized_key() -> googletest::Result<()> { let plaintext_modulus_bits = 39; let packed_vector_configs = HashMap::from([( DEFAULT_ID.to_string(), - PackedVectorConfig { base: 10, dimension: 2, num_packed_coeffs: 5 }, + PackedVectorConfig { base: 10, dimension: 2, num_packed_coeffs: 5, length: 10 }, )]); let kahe_config = make_kahe_config_for(plaintext_modulus_bits, packed_vector_configs)?; let kahe = ShellKahe::new(kahe_config, CONTEXT_STRING)?; @@ -453,6 +473,7 @@ mod test { base: input_domain, dimension: 1, num_packed_coeffs: 0, // Dummy value until we compute it from kahe_config. + length: 0, // Dummy value. }, )]); let mut kahe_config = make_kahe_config_for(plaintext_modulus_bits, packed_vector_configs)?; @@ -460,6 +481,7 @@ mod test { let num_messages = (1 << kahe_config.log_n) * 2; // Needs two polynomials. let packed_vector_config = kahe_config.packed_vector_configs.get_mut(DEFAULT_ID).unwrap(); packed_vector_config.num_packed_coeffs = num_messages; + packed_vector_config.length = num_messages; set_kahe_num_public_polynomials(&mut kahe_config); let kahe = ShellKahe::new(kahe_config, CONTEXT_STRING)?; @@ -491,6 +513,7 @@ mod test { base: input_domain * 2, dimension: 1, num_packed_coeffs: num_messages, + length: num_messages, }, )]); let kahe_config = make_kahe_config_for(plaintext_modulus_bits, packed_vector_configs)?; @@ -574,7 +597,7 @@ mod test { let plaintext_modulus_bits = 39; let packed_vector_configs = HashMap::from([( String::from(DEFAULT_ID), - PackedVectorConfig { base: 10, dimension: 2, num_packed_coeffs: 5 }, + PackedVectorConfig { base: 10, dimension: 2, num_packed_coeffs: 5, length: 10 }, )]); let kahe_config = make_kahe_config_for(plaintext_modulus_bits, packed_vector_configs)?; let kahe = ShellKahe::new(kahe_config, CONTEXT_STRING)?; diff --git a/willow/src/shell/parameters_generation.rs b/willow/src/shell/parameters_generation.rs index daab1d3..8d5f348 100644 --- a/willow/src/shell/parameters_generation.rs +++ b/willow/src/shell/parameters_generation.rs @@ -89,6 +89,7 @@ pub fn generate_packing_config( base: base as u64, dimension: dimension as u64, num_packed_coeffs: num_packed_coeffs as u64, + length: *length as u64, }, ); } @@ -200,15 +201,30 @@ mod test { )?; expect_eq!( packed_vector_configs.get("small").unwrap(), - &PackedVectorConfig { base: 1 << 11, dimension: 2, num_packed_coeffs: 512 } + &PackedVectorConfig { + base: 1 << 11, + dimension: 2, + num_packed_coeffs: 512, + length: 1024 + } ); expect_eq!( packed_vector_configs.get("large").unwrap(), - &PackedVectorConfig { base: 1 << 24, dimension: 1, num_packed_coeffs: 32 } + &PackedVectorConfig { + base: 1 << 24, + dimension: 1, + num_packed_coeffs: 32, + length: 32 + } ); expect_eq!( packed_vector_configs.get("long").unwrap(), - &PackedVectorConfig { base: 1 << 24, dimension: 1, num_packed_coeffs: 65536 } + &PackedVectorConfig { + base: 1 << 24, + dimension: 1, + num_packed_coeffs: 65536, + length: 65536 + } ); Ok(()) } diff --git a/willow/src/shell/parameters_utils.rs b/willow/src/shell/parameters_utils.rs index c309d11..d472a1d 100644 --- a/willow/src/shell/parameters_utils.rs +++ b/willow/src/shell/parameters_utils.rs @@ -30,6 +30,7 @@ pub fn packed_vector_config_to_proto(config: &PackedVectorConfig) -> PackedVecto base: config.base as i64, dimension: config.dimension as i64, num_packed_coeffs: config.num_packed_coeffs as i64, + length: config.length as i64, }) } @@ -39,6 +40,7 @@ pub fn packed_vector_config_from_proto(proto: PackedVectorConfigProtoView) -> Pa base: proto.base() as u64, dimension: proto.dimension() as u64, num_packed_coeffs: proto.num_packed_coeffs() as u64, + length: proto.length() as u64, } } @@ -90,7 +92,12 @@ mod test { #[gtest] fn test_packed_vector_config_proto_roundtrip() -> googletest::Result<()> { - let config = PackedVectorConfig { base: 8u64, dimension: 2u64, num_packed_coeffs: 1024u64 }; + let config = PackedVectorConfig { + base: 8u64, + dimension: 2u64, + num_packed_coeffs: 1024u64, + length: 2048u64, + }; let proto = packed_vector_config_to_proto(&config); let config_from_proto = packed_vector_config_from_proto(proto.as_view()); verify_eq!(config_from_proto, config) @@ -106,7 +113,12 @@ mod test { packed_vector_configs: HashMap::from([ ( String::from("vector0"), - PackedVectorConfig { base: 16u64, dimension: 8u64, num_packed_coeffs: 1024u64 }, + PackedVectorConfig { + base: 16u64, + dimension: 8u64, + num_packed_coeffs: 1024u64, + length: 8192u64, + }, ), ( String::from("vector1"), @@ -114,6 +126,7 @@ mod test { base: 65536u64, dimension: 1u64, num_packed_coeffs: 16u64, + length: 16u64, }, ), ]),