Skip to content

Commit cbd2ce7

Browse files
tholopcopybara-github
authored andcommitted
Prevent memory leaks from MaybeUninit BigIntVectorWrappers.
PiperOrigin-RevId: 844883517
1 parent fe065d7 commit cbd2ce7

File tree

2 files changed

+17
-11
lines changed

2 files changed

+17
-11
lines changed

shell_wrapper/kahe.cc

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -322,12 +322,6 @@ FfiStatus PackMessagesRaw(rust::Slice<const uint64_t> messages,
322322
return MakeFfiStatus(absl::InvalidArgumentError(
323323
secure_aggregation::kNullPointerErrorMessage));
324324
}
325-
326-
// Allocate the vector for output packed values if needed.
327-
if (packed_values->ptr == nullptr) {
328-
packed_values->ptr =
329-
std::make_unique<std::vector<secure_aggregation::BigInteger>>();
330-
}
331325
auto curr_packed_values =
332326
rlwe::PackMessagesFlat<secure_aggregation::Integer,
333327
secure_aggregation::BigInteger>(
@@ -339,6 +333,11 @@ FfiStatus PackMessagesRaw(rust::Slice<const uint64_t> messages,
339333
}
340334
// Pad with zeros if needed.
341335
curr_packed_values.resize(num_packed_values, 0);
336+
// Allocate the vector for output packed values if needed.
337+
if (packed_values->ptr == nullptr) {
338+
packed_values->ptr =
339+
std::make_unique<std::vector<secure_aggregation::BigInteger>>();
340+
}
342341
// Append the packed values to the end of the output vector.
343342
packed_values->ptr->insert(packed_values->ptr->end(),
344343
curr_packed_values.begin(),

shell_wrapper/kahe.rs

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ mod ffi {
3737
}
3838

3939
pub struct BigIntVectorWrapper {
40-
pub ptr: UniquePtr<CxxVector<BigInteger>>,
40+
pub(crate) ptr: UniquePtr<CxxVector<BigInteger>>,
4141
}
4242

4343
unsafe extern "C++" {
@@ -190,7 +190,11 @@ pub fn encrypt(
190190
params: &KahePublicParametersWrapper,
191191
prng: &mut SingleThreadHkdfWrapper,
192192
) -> Result<RnsPolynomialVec, status::StatusError> {
193-
let mut packed_values = MaybeUninit::<BigIntVectorWrapper>::zeroed();
193+
// SAFETY: this initializes `packed_values` with packed_values.ptr == nullptr. The following
194+
// loop ensures that we either return an error (and drop `packed_values`, including potential
195+
// partial allocations), or make packed_values.ptr point to a valid C++ vector.
196+
let mut packed_values = BigIntVectorWrapper { ptr: cxx::UniquePtr::null() };
197+
194198
// SAFETY: No lifetime constraints (`PackMessagesRaw` may create a new vector of BigIntegers
195199
// wrapped by `packed_values` which does not keep any reference to the inputs).
196200
// `PackMessagesRaw` only appends to the C++ vector wrapped by `packed_values`,
@@ -206,7 +210,7 @@ pub fn encrypt(
206210
packed_vector_config.base,
207211
packed_vector_config.dimension,
208212
packed_vector_config.num_packed_coeffs,
209-
packed_values.as_mut_ptr(),
213+
&mut packed_values,
210214
)
211215
})?;
212216
}
@@ -217,7 +221,7 @@ pub fn encrypt(
217221
// wrapped by `packed_values`, updates the states wrapped by `prng`, and writes into the C++
218222
// vector wrapped by `out`.
219223
rust_status_from_cpp(unsafe {
220-
ffi::Encrypt(&packed_values.assume_init(), secret_key, params, prng, out.as_mut_ptr())
224+
ffi::Encrypt(&packed_values, secret_key, params, prng, out.as_mut_ptr())
221225
})?;
222226
// SAFETY: `out` is safely initialized if we get to this point.
223227
Ok(unsafe { out.assume_init() })
@@ -239,6 +243,9 @@ pub fn decrypt(
239243
ffi::Decrypt(ciphertext, secret_key, params, packed_values.as_mut_ptr())
240244
})?;
241245

246+
// SAFETY: `packed_values` is safely initialized if we get to this point.
247+
let mut packed_values = unsafe { packed_values.assume_init() };
248+
242249
let mut output_vectors = HashMap::<String, Vec<u64>>::new();
243250
// Assume the packed values are stored in the same order as the configs.
244251
for (id, packed_vector_config) in packed_vector_configs.iter() {
@@ -253,7 +260,7 @@ pub fn decrypt(
253260
packed_vector_config.base,
254261
packed_vector_config.dimension,
255262
packed_vector_config.num_packed_coeffs,
256-
packed_values.assume_init_mut(),
263+
&mut packed_values,
257264
&mut unpacked_values,
258265
)
259266
})?;

0 commit comments

Comments
 (0)