Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 70 additions & 16 deletions shell_encryption/rns/rns_polynomial.h
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,10 @@ class RnsPolynomial {
return output;
}

// Adds `that` to `this` in-place.
// Adds `that` to `this` in-place. May fail if the coefficient vectors of
// `this` and `that` are not the same size, or the number of coefficient
// vectors does not match the number of moduli. In case of failure, `this`
// is not modified.
absl::Status AddInPlace(
const RnsPolynomial& that,
absl::Span<const PrimeModulus<ModularInt>* const> moduli) {
Expand All @@ -203,14 +206,28 @@ class RnsPolynomial {
return absl::InvalidArgumentError(
absl::StrCat("`moduli` must contain ", num_moduli, " RNS moduli."));
}
// Check all coeff vectors have the right size to catch errors before
// modifying `this`.
for (int i = 0; i < num_moduli; ++i) {
RLWE_RETURN_IF_ERROR(ModularInt::BatchAddInPlace(
&coeff_vectors_[i], that.coeff_vectors_[i], moduli[i]->ModParams()));
if (coeff_vectors_[i].size() != that.coeff_vectors_[i].size()) {
return absl::InvalidArgumentError(
absl::StrCat("Size of coefficient vector ", i, " does not match"));
}
}
for (int i = 0; i < num_moduli; ++i) {
ModularInt::BatchAddInPlace(&coeff_vectors_[i], that.coeff_vectors_[i],
moduli[i]->ModParams())
.IgnoreError(); // We already checked the sizes above, and
// BatchAddInPlace will only fail if its inputs have
// different sizes.
}
return absl::OkStatus();
}

// Substracts `that` from `this` in-place.
// Substracts `that` from `this` in-place. May fail if the coefficient
// vectors of `this` and `that` are not the same size, or the number of
// coefficient vectors does not match the number of moduli. In case of
// failure, `this` is not modified.
absl::Status SubInPlace(
const RnsPolynomial& that,
absl::Span<const PrimeModulus<ModularInt>* const> moduli) {
Expand All @@ -223,9 +240,20 @@ class RnsPolynomial {
return absl::InvalidArgumentError(
absl::StrCat("`moduli` must contain ", num_moduli, " RNS moduli."));
}
// Check all coeff vectors have the right size to catch errors before
// modifying `this`.
for (int i = 0; i < num_moduli; ++i) {
if (coeff_vectors_[i].size() != that.coeff_vectors_[i].size()) {
return absl::InvalidArgumentError(
absl::StrCat("Size of coefficient vector ", i, " does not match"));
}
}
for (int i = 0; i < num_moduli; ++i) {
RLWE_RETURN_IF_ERROR(ModularInt::BatchSubInPlace(
&coeff_vectors_[i], that.coeff_vectors_[i], moduli[i]->ModParams()));
ModularInt::BatchSubInPlace(&coeff_vectors_[i], that.coeff_vectors_[i],
moduli[i]->ModParams())
.IgnoreError(); // We already checked the sizes above, and
// BatchSubInPlace will only fail if its inputs have
// different sizes.
}
return absl::OkStatus();
}
Expand Down Expand Up @@ -295,8 +323,10 @@ class RnsPolynomial {
return output;
}

// Multiplies this polynomial by `that`.
// Both `this` and `that` must be in NTT form.
// Multiplies this polynomial by `that`. Both `this` and `that` must be in NTT
// form. May fail if the coefficient vectors of `this` and `that` are not the
// same size, or the number of coefficient vectors does not match the number
// of moduli. In case of failure, `this` is not modified.
absl::Status MulInPlace(
const RnsPolynomial& that,
absl::Span<const PrimeModulus<ModularInt>* const> moduli) {
Expand All @@ -317,16 +347,29 @@ class RnsPolynomial {
return absl::InvalidArgumentError(
"RNS polynomial `that` must be in NTT form.");
}

// Check all coeff vectors have the right size to catch errors before
// modifying `this`.
for (int i = 0; i < num_moduli; ++i) {
RLWE_RETURN_IF_ERROR(ModularInt::BatchMulInPlace(
&coeff_vectors_[i], that.coeff_vectors_[i], moduli[i]->ModParams()));
if (coeff_vectors_[i].size() != that.coeff_vectors_[i].size()) {
return absl::InvalidArgumentError(
absl::StrCat("Size of coefficient vector ", i, " does not match"));
}
}
for (int i = 0; i < num_moduli; ++i) {
ModularInt::BatchMulInPlace(&coeff_vectors_[i], that.coeff_vectors_[i],
moduli[i]->ModParams())
.IgnoreError(); // We already checked the sizes above, and
// BatchMulInPlace will only fail if its inputs have
// different sizes.
}
return absl::OkStatus();
}

// Adds the polynomial product a * b to this polynomial.
// Polynomials `this`, `a`, and `b` must be all in NTT form.
// Adds the polynomial product a * b to this polynomial. Polynomials `this`,
// `a`, and `b` must be all in NTT form. May fail if the coefficient vectors
// of `this`, `a`, and `b` are not the same size, or the number of coefficient
// vectors does not match the number of moduli. In case of failure, `this` is
// not modified.
absl::Status FusedMulAddInPlace(
const RnsPolynomial& a, const RnsPolynomial& b,
absl::Span<const PrimeModulus<ModularInt>* const> moduli) {
Expand Down Expand Up @@ -355,11 +398,22 @@ class RnsPolynomial {
return absl::InvalidArgumentError(
"RNS polynomial `b` must be in NTT form.");
}

// Check all coeff vectors have the right size to catch errors before
// modifying `this`.
for (int i = 0; i < num_moduli; ++i) {
if (coeff_vectors_[i].size() != a.coeff_vectors_[i].size() ||
coeff_vectors_[i].size() != b.coeff_vectors_[i].size()) {
return absl::InvalidArgumentError(
absl::StrCat("Size of coefficient vector ", i, " does not match"));
}
}
for (int i = 0; i < num_moduli; ++i) {
RLWE_RETURN_IF_ERROR(ModularInt::BatchFusedMulAddInPlace(
ModularInt::BatchFusedMulAddInPlace(
&coeff_vectors_[i], a.coeff_vectors_[i], b.coeff_vectors_[i],
moduli[i]->ModParams()));
moduli[i]->ModParams())
.IgnoreError(); // We already checked the sizes above, and
// BatchFusedMulAddInPlace will only fail if its
// inputs have different sizes.
}
return absl::OkStatus();
}
Expand Down
82 changes: 82 additions & 0 deletions shell_encryption/rns/rns_polynomial_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2610,5 +2610,87 @@ TYPED_TEST(RnsPolynomialTest, SerializedDeserializes) {
EXPECT_EQ(deserialized2, a);
}

TYPED_TEST(RnsPolynomialTest, AddInPlaceLeavesPolynomialUnchangedOnFailure) {
ASSERT_OK_AND_ASSIGN(RnsPolynomial<TypeParam> a, this->SampleRnsPolynomial());
ASSERT_OK_AND_ASSIGN(RnsPolynomial<TypeParam> b, this->SampleRnsPolynomial());

// Double the length of the coefficient vectors of b.
std::vector<std::vector<TypeParam>> b_coeffs = b.Coeffs();
for (int i = 0; i < b_coeffs.size(); ++i) {
b_coeffs[i].insert(b_coeffs[i].end(), b.Coeffs()[i].begin(),
b.Coeffs()[i].end());
}
ASSERT_OK_AND_ASSIGN(
b, RnsPolynomial<TypeParam>::Create(std::move(b_coeffs), b.IsNttForm()));

auto a2 = a.Clone();
EXPECT_THAT(a.AddInPlace(b, this->moduli_),
StatusIs(absl::StatusCode::kInvalidArgument,
"Size of coefficient vector 0 does not match"));
EXPECT_EQ(a, a2);
}

TYPED_TEST(RnsPolynomialTest, SubInPlaceLeavesPolynomialUnchangedOnFailure) {
ASSERT_OK_AND_ASSIGN(RnsPolynomial<TypeParam> a, this->SampleRnsPolynomial());
ASSERT_OK_AND_ASSIGN(RnsPolynomial<TypeParam> b, this->SampleRnsPolynomial());

// Double the length of the coefficient vectors of b.
std::vector<std::vector<TypeParam>> b_coeffs = b.Coeffs();
for (int i = 0; i < b_coeffs.size(); ++i) {
b_coeffs[i].insert(b_coeffs[i].end(), b.Coeffs()[i].begin(),
b.Coeffs()[i].end());
}
ASSERT_OK_AND_ASSIGN(
b, RnsPolynomial<TypeParam>::Create(std::move(b_coeffs), b.IsNttForm()));

auto a2 = a.Clone();
EXPECT_THAT(a.SubInPlace(b, this->moduli_),
StatusIs(absl::StatusCode::kInvalidArgument,
"Size of coefficient vector 0 does not match"));
EXPECT_EQ(a, a2);
}

TYPED_TEST(RnsPolynomialTest, MulInPlaceLeavesPolynomialUnchangedOnFailure) {
ASSERT_OK_AND_ASSIGN(RnsPolynomial<TypeParam> a, this->SampleRnsPolynomial());
ASSERT_OK_AND_ASSIGN(RnsPolynomial<TypeParam> b, this->SampleRnsPolynomial());

// Double the length of the coefficient vectors of b.
std::vector<std::vector<TypeParam>> b_coeffs = b.Coeffs();
for (int i = 0; i < b_coeffs.size(); ++i) {
b_coeffs[i].insert(b_coeffs[i].end(), b.Coeffs()[i].begin(),
b.Coeffs()[i].end());
}
ASSERT_OK_AND_ASSIGN(
b, RnsPolynomial<TypeParam>::Create(std::move(b_coeffs), b.IsNttForm()));

auto a2 = a.Clone();
EXPECT_THAT(a.MulInPlace(b, this->moduli_),
StatusIs(absl::StatusCode::kInvalidArgument,
"Size of coefficient vector 0 does not match"));
EXPECT_EQ(a, a2);
}

TYPED_TEST(RnsPolynomialTest,
FusedMulAddInPlaceLeavesPolynomialUnchangedOnFailure) {
ASSERT_OK_AND_ASSIGN(RnsPolynomial<TypeParam> a, this->SampleRnsPolynomial());
ASSERT_OK_AND_ASSIGN(RnsPolynomial<TypeParam> b, this->SampleRnsPolynomial());
ASSERT_OK_AND_ASSIGN(RnsPolynomial<TypeParam> c, this->SampleRnsPolynomial());

// Double the length of the coefficient vectors of b.
std::vector<std::vector<TypeParam>> b_coeffs = b.Coeffs();
for (int i = 0; i < b_coeffs.size(); ++i) {
b_coeffs[i].insert(b_coeffs[i].end(), b.Coeffs()[i].begin(),
b.Coeffs()[i].end());
}
ASSERT_OK_AND_ASSIGN(
b, RnsPolynomial<TypeParam>::Create(std::move(b_coeffs), b.IsNttForm()));

auto a2 = a.Clone();
EXPECT_THAT(a.FusedMulAddInPlace(b, c, this->moduli_),
StatusIs(absl::StatusCode::kInvalidArgument,
"Size of coefficient vector 0 does not match"));
EXPECT_EQ(a, a2);
}

} // namespace
} // namespace rlwe
Loading