diff --git a/shell_wrapper/kahe.cc b/shell_wrapper/kahe.cc index 2fc9583..8c1d3b7 100644 --- a/shell_wrapper/kahe.cc +++ b/shell_wrapper/kahe.cc @@ -322,12 +322,6 @@ FfiStatus PackMessagesRaw(rust::Slice messages, return MakeFfiStatus(absl::InvalidArgumentError( secure_aggregation::kNullPointerErrorMessage)); } - - // Allocate the vector for output packed values if needed. - if (packed_values->ptr == nullptr) { - packed_values->ptr = - std::make_unique>(); - } auto curr_packed_values = rlwe::PackMessagesFlat( @@ -339,6 +333,11 @@ FfiStatus PackMessagesRaw(rust::Slice messages, } // Pad with zeros if needed. curr_packed_values.resize(num_packed_values, 0); + // Allocate the vector for output packed values if needed. + if (packed_values->ptr == nullptr) { + packed_values->ptr = + std::make_unique>(); + } // Append the packed values to the end of the output vector. packed_values->ptr->insert(packed_values->ptr->end(), curr_packed_values.begin(), diff --git a/shell_wrapper/kahe.rs b/shell_wrapper/kahe.rs index 40c6a73..699a455 100644 --- a/shell_wrapper/kahe.rs +++ b/shell_wrapper/kahe.rs @@ -37,7 +37,7 @@ mod ffi { } pub struct BigIntVectorWrapper { - pub ptr: UniquePtr>, + pub(crate) ptr: UniquePtr>, } unsafe extern "C++" { @@ -190,7 +190,11 @@ pub fn encrypt( params: &KahePublicParametersWrapper, prng: &mut SingleThreadHkdfWrapper, ) -> Result { - let mut packed_values = MaybeUninit::::zeroed(); + // SAFETY: this initializes `packed_values` with packed_values.ptr == nullptr. The following + // loop ensures that we either return an error (and drop `packed_values`, including potential + // partial allocations), or make packed_values.ptr point to a valid C++ vector. + let mut packed_values = BigIntVectorWrapper { ptr: cxx::UniquePtr::null() }; + // SAFETY: No lifetime constraints (`PackMessagesRaw` may create a new vector of BigIntegers // wrapped by `packed_values` which does not keep any reference to the inputs). // `PackMessagesRaw` only appends to the C++ vector wrapped by `packed_values`, @@ -206,7 +210,7 @@ pub fn encrypt( packed_vector_config.base, packed_vector_config.dimension, packed_vector_config.num_packed_coeffs, - packed_values.as_mut_ptr(), + &mut packed_values, ) })?; } @@ -217,7 +221,7 @@ pub fn encrypt( // wrapped by `packed_values`, updates the states wrapped by `prng`, and writes into the C++ // vector wrapped by `out`. rust_status_from_cpp(unsafe { - ffi::Encrypt(&packed_values.assume_init(), secret_key, params, prng, out.as_mut_ptr()) + ffi::Encrypt(&packed_values, secret_key, params, prng, out.as_mut_ptr()) })?; // SAFETY: `out` is safely initialized if we get to this point. Ok(unsafe { out.assume_init() }) @@ -239,6 +243,9 @@ pub fn decrypt( ffi::Decrypt(ciphertext, secret_key, params, packed_values.as_mut_ptr()) })?; + // SAFETY: `packed_values` is safely initialized if we get to this point. + let mut packed_values = unsafe { packed_values.assume_init() }; + let mut output_vectors = HashMap::>::new(); // Assume the packed values are stored in the same order as the configs. for (id, packed_vector_config) in packed_vector_configs.iter() { @@ -253,7 +260,7 @@ pub fn decrypt( packed_vector_config.base, packed_vector_config.dimension, packed_vector_config.num_packed_coeffs, - packed_values.assume_init_mut(), + &mut packed_values, &mut unpacked_values, ) })?;