Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions shell_wrapper/kahe.cc
Original file line number Diff line number Diff line change
Expand Up @@ -322,12 +322,6 @@ FfiStatus PackMessagesRaw(rust::Slice<const uint64_t> messages,
return MakeFfiStatus(absl::InvalidArgumentError(
secure_aggregation::kNullPointerErrorMessage));
}

// Allocate the vector for output packed values if needed.
if (packed_values->ptr == nullptr) {
packed_values->ptr =
std::make_unique<std::vector<secure_aggregation::BigInteger>>();
}
auto curr_packed_values =
rlwe::PackMessagesFlat<secure_aggregation::Integer,
secure_aggregation::BigInteger>(
Expand All @@ -339,6 +333,11 @@ FfiStatus PackMessagesRaw(rust::Slice<const uint64_t> messages,
}
// Pad with zeros if needed.
curr_packed_values.resize(num_packed_values, 0);
// Allocate the vector for output packed values if needed.
if (packed_values->ptr == nullptr) {
packed_values->ptr =
std::make_unique<std::vector<secure_aggregation::BigInteger>>();
}
// Append the packed values to the end of the output vector.
packed_values->ptr->insert(packed_values->ptr->end(),
curr_packed_values.begin(),
Expand Down
17 changes: 12 additions & 5 deletions shell_wrapper/kahe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ mod ffi {
}

pub struct BigIntVectorWrapper {
pub ptr: UniquePtr<CxxVector<BigInteger>>,
pub(crate) ptr: UniquePtr<CxxVector<BigInteger>>,
}

unsafe extern "C++" {
Expand Down Expand Up @@ -190,7 +190,11 @@ pub fn encrypt(
params: &KahePublicParametersWrapper,
prng: &mut SingleThreadHkdfWrapper,
) -> Result<RnsPolynomialVec, status::StatusError> {
let mut packed_values = MaybeUninit::<BigIntVectorWrapper>::zeroed();
// SAFETY: this initializes `packed_values` with packed_values.ptr == nullptr. The following
// loop ensures that we either return an error (and drop `packed_values`, including potential
// partial allocations), or make packed_values.ptr point to a valid C++ vector.
let mut packed_values = BigIntVectorWrapper { ptr: cxx::UniquePtr::null() };

// SAFETY: No lifetime constraints (`PackMessagesRaw` may create a new vector of BigIntegers
// wrapped by `packed_values` which does not keep any reference to the inputs).
// `PackMessagesRaw` only appends to the C++ vector wrapped by `packed_values`,
Expand All @@ -206,7 +210,7 @@ pub fn encrypt(
packed_vector_config.base,
packed_vector_config.dimension,
packed_vector_config.num_packed_coeffs,
packed_values.as_mut_ptr(),
&mut packed_values,
)
})?;
}
Expand All @@ -217,7 +221,7 @@ pub fn encrypt(
// wrapped by `packed_values`, updates the states wrapped by `prng`, and writes into the C++
// vector wrapped by `out`.
rust_status_from_cpp(unsafe {
ffi::Encrypt(&packed_values.assume_init(), secret_key, params, prng, out.as_mut_ptr())
ffi::Encrypt(&packed_values, secret_key, params, prng, out.as_mut_ptr())
})?;
// SAFETY: `out` is safely initialized if we get to this point.
Ok(unsafe { out.assume_init() })
Expand All @@ -239,6 +243,9 @@ pub fn decrypt(
ffi::Decrypt(ciphertext, secret_key, params, packed_values.as_mut_ptr())
})?;

// SAFETY: `packed_values` is safely initialized if we get to this point.
let mut packed_values = unsafe { packed_values.assume_init() };

let mut output_vectors = HashMap::<String, Vec<u64>>::new();
// Assume the packed values are stored in the same order as the configs.
for (id, packed_vector_config) in packed_vector_configs.iter() {
Expand All @@ -253,7 +260,7 @@ pub fn decrypt(
packed_vector_config.base,
packed_vector_config.dimension,
packed_vector_config.num_packed_coeffs,
packed_values.assume_init_mut(),
&mut packed_values,
&mut unpacked_values,
)
})?;
Expand Down
Loading