diff --git a/shell_wrapper/kahe.rs b/shell_wrapper/kahe.rs index 92a5f3b..11b40f5 100644 --- a/shell_wrapper/kahe.rs +++ b/shell_wrapper/kahe.rs @@ -18,7 +18,7 @@ use shell_types::{Moduli, RnsContextRef, RnsPolynomial, RnsPolynomialVec}; use single_thread_hkdf::{SeedWrapper, SingleThreadHkdfWrapper}; use status::rust_status_from_cpp; -use std::collections::HashMap; +use std::collections::{BTreeMap, HashMap}; use std::marker::PhantomData; use std::mem::MaybeUninit; @@ -197,7 +197,7 @@ pub use ffi::BigIntVectorWrapper; /// Returns the resulting ciphertexts. pub fn encrypt( input_vectors: &HashMap<&str, &[u64]>, - packed_vector_configs: &HashMap, + packed_vector_configs: &BTreeMap, secret_key: &RnsPolynomial, params: &KahePublicParametersWrapper, prng: &mut SingleThreadHkdfWrapper, @@ -246,7 +246,7 @@ pub fn decrypt( ciphertext: &RnsPolynomialVec, secret_key: &RnsPolynomial, params: &KahePublicParametersWrapper, - packed_vector_configs: &HashMap, + packed_vector_configs: &BTreeMap, ) -> Result>, status::StatusError> { let mut packed_values = MaybeUninit::::zeroed(); // SAFETY: No lifetime constraints (`packed_values` does not keep any reference to the inputs). diff --git a/shell_wrapper/kahe_test.rs b/shell_wrapper/kahe_test.rs index d9db5ee..a6eb31a 100644 --- a/shell_wrapper/kahe_test.rs +++ b/shell_wrapper/kahe_test.rs @@ -21,7 +21,7 @@ use kahe::{create_public_parameters, decrypt, encrypt, generate_secret_key, Pack use rand::Rng; use status::StatusErrorCode; use status_matchers_rs::status_is; -use std::collections::HashMap; +use std::collections::{BTreeMap, HashMap}; // RNS configuration. LOG_T is the bit length of the KAHE plaintext modulus. const LOG_T: u64 = 11; @@ -45,7 +45,7 @@ fn encrypt_decrypt() -> Result<()> { // Encrypt small vector. `ciphertext` is a wrapper around a C++ pointer. let input_values = vec![1, 2, 3]; let plaintext = HashMap::from([(DEFAULT_ID, input_values.as_slice())]); - let packed_vector_configs = HashMap::from([( + let packed_vector_configs = BTreeMap::from([( DEFAULT_ID.to_string(), PackedVectorConfig { base: 10, dimension: 2, num_packed_coeffs: 2, length: 3 }, )]); @@ -83,7 +83,7 @@ fn encrypt_decrypt_padding() -> Result<()> { // Encrypt the vector. Pass a longer length than what we need. let padded_length = (num_packed_coeffs * packing_dimension) as usize; let plaintext = HashMap::from([(DEFAULT_ID, input_values.as_slice())]); - let packed_vector_configs = HashMap::from([( + let packed_vector_configs = BTreeMap::from([( DEFAULT_ID.to_string(), PackedVectorConfig { base: input_domain as u64, @@ -133,7 +133,7 @@ fn encrypt_decrypt_long() -> Result<()> { let input_values: Vec = (0..num_input_values).map(|_| rand::thread_rng().gen_range(0..input_domain)).collect(); let plaintext = HashMap::from([(DEFAULT_ID, input_values.as_slice())]); - let packed_vector_configs = HashMap::from([( + let packed_vector_configs = BTreeMap::from([( DEFAULT_ID.to_string(), PackedVectorConfig { base: input_domain as u64, @@ -184,7 +184,7 @@ fn encrypt_decrypt_two_vectors() -> Result<()> { // The number of packed coefficients for both vectors. let num_packed_coeffs = [5, 5]; - let packed_vector_configs = HashMap::from([ + let packed_vector_configs = BTreeMap::from([ ( ID0.to_string(), PackedVectorConfig { diff --git a/willow/src/shell/kahe.rs b/willow/src/shell/kahe.rs index 5103f0d..2551307 100644 --- a/willow/src/shell/kahe.rs +++ b/willow/src/shell/kahe.rs @@ -26,7 +26,7 @@ use shell_types::{ RnsPolynomialVec, }; use single_thread_hkdf::SingleThreadHkdfPrng; -use std::collections::HashMap; +use std::collections::{BTreeMap, HashMap}; /// Number of bits supported by the C++ big integer type used for KAHE /// plaintext. @@ -38,7 +38,7 @@ pub struct ShellKaheConfig { pub moduli: Vec, pub log_t: usize, pub num_public_polynomials: usize, - pub packed_vector_configs: HashMap, + pub packed_vector_configs: BTreeMap, } /// Base type holding public KAHE configuration and C++ parameters. @@ -376,7 +376,7 @@ mod test { use proto_serialization_traits::{FromProto, ToProto}; use shell_testing_parameters::{make_kahe_config_for, set_kahe_num_public_polynomials}; use single_thread_hkdf::SingleThreadHkdfPrng; - use std::collections::HashMap; + use std::collections::{BTreeMap, HashMap}; use testing_utils::generate_random_unsigned_vector; /// Standard deviation of the discrete Gaussian distribution used for @@ -400,7 +400,7 @@ mod test { #[gtest] fn test_encrypt_decrypt_short() -> googletest::Result<()> { let plaintext_modulus_bits = 39; - let packed_vector_configs = HashMap::from([( + let packed_vector_configs = BTreeMap::from([( DEFAULT_ID.to_string(), PackedVectorConfig { base: 10, dimension: 2, num_packed_coeffs: 5, length: 10 }, )]); @@ -420,7 +420,7 @@ mod test { #[gtest] fn test_encrypt_decrypt_short_padding() -> googletest::Result<()> { let plaintext_modulus_bits = 39; - let packed_vector_configs = HashMap::from([( + let packed_vector_configs = BTreeMap::from([( DEFAULT_ID.to_string(), PackedVectorConfig { base: 10, dimension: 2, num_packed_coeffs: 5, length: 8 }, )]); @@ -440,7 +440,7 @@ mod test { #[gtest] fn test_encrypt_decrypt_with_serialized_key() -> googletest::Result<()> { let plaintext_modulus_bits = 39; - let packed_vector_configs = HashMap::from([( + let packed_vector_configs = BTreeMap::from([( DEFAULT_ID.to_string(), PackedVectorConfig { base: 10, dimension: 2, num_packed_coeffs: 5, length: 10 }, )]); @@ -467,7 +467,7 @@ mod test { fn test_encrypt_decrypt_long() -> googletest::Result<()> { let plaintext_modulus_bits = 17; let input_domain = 5; - let packed_vector_configs = HashMap::from([( + let packed_vector_configs = BTreeMap::from([( DEFAULT_ID.to_string(), PackedVectorConfig { base: input_domain, @@ -507,7 +507,7 @@ mod test { let plaintext_modulus_bits = 93; let input_domain = 10; let num_messages = 50; - let packed_vector_configs = HashMap::from([( + let packed_vector_configs = BTreeMap::from([( DEFAULT_ID.to_string(), PackedVectorConfig { base: input_domain * 2, @@ -553,7 +553,7 @@ mod test { #[gtest] fn read_write_secret_key() -> googletest::Result<()> { let plaintext_modulus_bits = 17; - let packed_vector_configs = HashMap::from([]); + let packed_vector_configs = BTreeMap::from([]); let kahe_config = make_kahe_config_for(plaintext_modulus_bits, packed_vector_configs)?; let kahe = ShellKahe::new(kahe_config, CONTEXT_STRING)?; @@ -595,7 +595,7 @@ mod test { #[gtest] fn test_encrypt_decrypt_serialized_proto() -> googletest::Result<()> { let plaintext_modulus_bits = 39; - let packed_vector_configs = HashMap::from([( + let packed_vector_configs = BTreeMap::from([( String::from(DEFAULT_ID), PackedVectorConfig { base: 10, dimension: 2, num_packed_coeffs: 5, length: 10 }, )]); @@ -624,7 +624,7 @@ mod test { fn test_key_serialization_is_homomorphic() -> googletest::Result<()> { // Set up a ShellKahe instance. let plaintext_modulus_bits = 39; - let packed_vector_configs = HashMap::from([]); + let packed_vector_configs = BTreeMap::from([]); let kahe_config = make_kahe_config_for(plaintext_modulus_bits, packed_vector_configs)?; let kahe = ShellKahe::new(kahe_config, CONTEXT_STRING)?; diff --git a/willow/src/shell/parameters_generation.rs b/willow/src/shell/parameters_generation.rs index 8d5f348..99d897e 100644 --- a/willow/src/shell/parameters_generation.rs +++ b/willow/src/shell/parameters_generation.rs @@ -14,7 +14,7 @@ use aggregation_config::AggregationConfig; use kahe::PackedVectorConfig; -use std::collections::HashMap; +use std::collections::BTreeMap; /// Generating KAHE and AHE parameters given the Willow protocol configuration. @@ -34,7 +34,7 @@ pub fn divide_and_roundup(x: usize, y: usize) -> usize { pub fn generate_packing_config( plaintext_bits: usize, agg_config: &AggregationConfig, -) -> Result, status::StatusError> { +) -> Result, status::StatusError> { if plaintext_bits == 0 { return Err(status::invalid_argument("`plaintext_bits` must be positive.")); } @@ -47,7 +47,7 @@ pub fn generate_packing_config( if agg_config.max_number_of_clients <= 0 { return Err(status::invalid_argument("`max_number_of_clients` must be positive.")); } - let mut packing_configs = HashMap::::new(); + let mut packing_configs = BTreeMap::::new(); for (id, (length, bound)) in agg_config.vector_lengths_and_bounds.iter() { if *length <= 0 { return Err(status::invalid_argument(format!( @@ -210,12 +210,7 @@ mod test { ); expect_eq!( packed_vector_configs.get("large").unwrap(), - &PackedVectorConfig { - base: 1 << 24, - dimension: 1, - num_packed_coeffs: 32, - length: 32 - } + &PackedVectorConfig { base: 1 << 24, dimension: 1, num_packed_coeffs: 32, length: 32 } ); expect_eq!( packed_vector_configs.get("long").unwrap(), diff --git a/willow/src/shell/parameters_utils.rs b/willow/src/shell/parameters_utils.rs index d472a1d..cf3bd11 100644 --- a/willow/src/shell/parameters_utils.rs +++ b/willow/src/shell/parameters_utils.rs @@ -19,7 +19,7 @@ use shell_parameters_rust_proto::{ PackedVectorConfigProto, PackedVectorConfigProtoView, ShellKaheConfigProto, ShellKaheConfigProtoView, }; -use std::collections::HashMap; +use std::collections::BTreeMap; /// This file contains some utility functions for working with Willow parameters: /// - Conversions between Rust structs and their corresponding protos. @@ -81,7 +81,7 @@ pub fn kahe_config_from_proto( Err(status::invalid_argument("invalid id in `packed_vectors`.")) } }) - .collect::, _>>()?, + .collect::, _>>()?, }) } @@ -110,7 +110,7 @@ mod test { moduli: vec![65537u64, 12289u64], log_t: 5usize, num_public_polynomials: 2usize, - packed_vector_configs: HashMap::from([ + packed_vector_configs: BTreeMap::from([ ( String::from("vector0"), PackedVectorConfig { diff --git a/willow/src/testing_utils/shell_testing_parameters.rs b/willow/src/testing_utils/shell_testing_parameters.rs index 2feb9f6..6d699ed 100644 --- a/willow/src/testing_utils/shell_testing_parameters.rs +++ b/willow/src/testing_utils/shell_testing_parameters.rs @@ -17,13 +17,13 @@ use ahe_shell::ShellAheConfig; use kahe::PackedVectorConfig; use kahe_shell::ShellKaheConfig; use shell_parameters_generation::{divide_and_roundup, generate_packing_config}; -use std::collections::HashMap; +use std::collections::BTreeMap; /// Creates an KAHE configuration with the given plaintext modulus bits, by /// looking up some pre-generated configurations. pub fn make_kahe_config_for( plaintext_modulus_bits: usize, - packed_vector_configs: HashMap, + packed_vector_configs: BTreeMap, ) -> Result { // Configurations below come from: // google3/experimental/users/baiyuli/async_rlwe_secagg/parameters.cc,