Skip to content

Commit 06289e0

Browse files
schoppmpcopybara-github
authored andcommitted
Leave this unchanged when {Add,Sub,Mul,FusedMulAdd}InPlace fail
PiperOrigin-RevId: 850507664
1 parent 5a73855 commit 06289e0

File tree

2 files changed

+152
-16
lines changed

2 files changed

+152
-16
lines changed

shell_encryption/rns/rns_polynomial.h

Lines changed: 70 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,10 @@ class RnsPolynomial {
190190
return output;
191191
}
192192

193-
// Adds `that` to `this` in-place.
193+
// Adds `that` to `this` in-place. May fail if the coefficient vectors of
194+
// `this` and `that` are not the same size, or the number of coefficient
195+
// vectors does not match the number of moduli. In case of failure, `this`
196+
// is not modified.
194197
absl::Status AddInPlace(
195198
const RnsPolynomial& that,
196199
absl::Span<const PrimeModulus<ModularInt>* const> moduli) {
@@ -203,14 +206,28 @@ class RnsPolynomial {
203206
return absl::InvalidArgumentError(
204207
absl::StrCat("`moduli` must contain ", num_moduli, " RNS moduli."));
205208
}
209+
// Check all coeff vectors have the right size to catch errors before
210+
// modifying `this`.
206211
for (int i = 0; i < num_moduli; ++i) {
207-
RLWE_RETURN_IF_ERROR(ModularInt::BatchAddInPlace(
208-
&coeff_vectors_[i], that.coeff_vectors_[i], moduli[i]->ModParams()));
212+
if (coeff_vectors_[i].size() != that.coeff_vectors_[i].size()) {
213+
return absl::InvalidArgumentError(
214+
absl::StrCat("Size of coefficient vector ", i, " does not match"));
215+
}
216+
}
217+
for (int i = 0; i < num_moduli; ++i) {
218+
ModularInt::BatchAddInPlace(&coeff_vectors_[i], that.coeff_vectors_[i],
219+
moduli[i]->ModParams())
220+
.IgnoreError(); // We already checked the sizes above, and
221+
// BatchAddInPlace will only fail if its inputs have
222+
// different sizes.
209223
}
210224
return absl::OkStatus();
211225
}
212226

213-
// Substracts `that` from `this` in-place.
227+
// Substracts `that` from `this` in-place. May fail if the coefficient
228+
// vectors of `this` and `that` are not the same size, or the number of
229+
// coefficient vectors does not match the number of moduli. In case of
230+
// failure, `this` is not modified.
214231
absl::Status SubInPlace(
215232
const RnsPolynomial& that,
216233
absl::Span<const PrimeModulus<ModularInt>* const> moduli) {
@@ -223,9 +240,20 @@ class RnsPolynomial {
223240
return absl::InvalidArgumentError(
224241
absl::StrCat("`moduli` must contain ", num_moduli, " RNS moduli."));
225242
}
243+
// Check all coeff vectors have the right size to catch errors before
244+
// modifying `this`.
245+
for (int i = 0; i < num_moduli; ++i) {
246+
if (coeff_vectors_[i].size() != that.coeff_vectors_[i].size()) {
247+
return absl::InvalidArgumentError(
248+
absl::StrCat("Size of coefficient vector ", i, " does not match"));
249+
}
250+
}
226251
for (int i = 0; i < num_moduli; ++i) {
227-
RLWE_RETURN_IF_ERROR(ModularInt::BatchSubInPlace(
228-
&coeff_vectors_[i], that.coeff_vectors_[i], moduli[i]->ModParams()));
252+
ModularInt::BatchSubInPlace(&coeff_vectors_[i], that.coeff_vectors_[i],
253+
moduli[i]->ModParams())
254+
.IgnoreError(); // We already checked the sizes above, and
255+
// BatchSubInPlace will only fail if its inputs have
256+
// different sizes.
229257
}
230258
return absl::OkStatus();
231259
}
@@ -295,8 +323,10 @@ class RnsPolynomial {
295323
return output;
296324
}
297325

298-
// Multiplies this polynomial by `that`.
299-
// Both `this` and `that` must be in NTT form.
326+
// Multiplies this polynomial by `that`. Both `this` and `that` must be in NTT
327+
// form. May fail if the coefficient vectors of `this` and `that` are not the
328+
// same size, or the number of coefficient vectors does not match the number
329+
// of moduli. In case of failure, `this` is not modified.
300330
absl::Status MulInPlace(
301331
const RnsPolynomial& that,
302332
absl::Span<const PrimeModulus<ModularInt>* const> moduli) {
@@ -317,16 +347,29 @@ class RnsPolynomial {
317347
return absl::InvalidArgumentError(
318348
"RNS polynomial `that` must be in NTT form.");
319349
}
320-
350+
// Check all coeff vectors have the right size to catch errors before
351+
// modifying `this`.
321352
for (int i = 0; i < num_moduli; ++i) {
322-
RLWE_RETURN_IF_ERROR(ModularInt::BatchMulInPlace(
323-
&coeff_vectors_[i], that.coeff_vectors_[i], moduli[i]->ModParams()));
353+
if (coeff_vectors_[i].size() != that.coeff_vectors_[i].size()) {
354+
return absl::InvalidArgumentError(
355+
absl::StrCat("Size of coefficient vector ", i, " does not match"));
356+
}
357+
}
358+
for (int i = 0; i < num_moduli; ++i) {
359+
ModularInt::BatchMulInPlace(&coeff_vectors_[i], that.coeff_vectors_[i],
360+
moduli[i]->ModParams())
361+
.IgnoreError(); // We already checked the sizes above, and
362+
// BatchMulInPlace will only fail if its inputs have
363+
// different sizes.
324364
}
325365
return absl::OkStatus();
326366
}
327367

328-
// Adds the polynomial product a * b to this polynomial.
329-
// Polynomials `this`, `a`, and `b` must be all in NTT form.
368+
// Adds the polynomial product a * b to this polynomial. Polynomials `this`,
369+
// `a`, and `b` must be all in NTT form. May fail if the coefficient vectors
370+
// of `this`, `a`, and `b` are not the same size, or the number of coefficient
371+
// vectors does not match the number of moduli. In case of failure, `this` is
372+
// not modified.
330373
absl::Status FusedMulAddInPlace(
331374
const RnsPolynomial& a, const RnsPolynomial& b,
332375
absl::Span<const PrimeModulus<ModularInt>* const> moduli) {
@@ -355,11 +398,22 @@ class RnsPolynomial {
355398
return absl::InvalidArgumentError(
356399
"RNS polynomial `b` must be in NTT form.");
357400
}
358-
401+
// Check all coeff vectors have the right size to catch errors before
402+
// modifying `this`.
403+
for (int i = 0; i < num_moduli; ++i) {
404+
if (coeff_vectors_[i].size() != a.coeff_vectors_[i].size() ||
405+
coeff_vectors_[i].size() != b.coeff_vectors_[i].size()) {
406+
return absl::InvalidArgumentError(
407+
absl::StrCat("Size of coefficient vector ", i, " does not match"));
408+
}
409+
}
359410
for (int i = 0; i < num_moduli; ++i) {
360-
RLWE_RETURN_IF_ERROR(ModularInt::BatchFusedMulAddInPlace(
411+
ModularInt::BatchFusedMulAddInPlace(
361412
&coeff_vectors_[i], a.coeff_vectors_[i], b.coeff_vectors_[i],
362-
moduli[i]->ModParams()));
413+
moduli[i]->ModParams())
414+
.IgnoreError(); // We already checked the sizes above, and
415+
// BatchFusedMulAddInPlace will only fail if its
416+
// inputs have different sizes.
363417
}
364418
return absl::OkStatus();
365419
}

shell_encryption/rns/rns_polynomial_test.cc

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2610,5 +2610,87 @@ TYPED_TEST(RnsPolynomialTest, SerializedDeserializes) {
26102610
EXPECT_EQ(deserialized2, a);
26112611
}
26122612

2613+
TYPED_TEST(RnsPolynomialTest, AddInPlaceLeavesPolynomialUnchangedOnFailure) {
2614+
ASSERT_OK_AND_ASSIGN(RnsPolynomial<TypeParam> a, this->SampleRnsPolynomial());
2615+
ASSERT_OK_AND_ASSIGN(RnsPolynomial<TypeParam> b, this->SampleRnsPolynomial());
2616+
2617+
// Double the length of the coefficient vectors of b.
2618+
std::vector<std::vector<TypeParam>> b_coeffs = b.Coeffs();
2619+
for (int i = 0; i < b_coeffs.size(); ++i) {
2620+
b_coeffs[i].insert(b_coeffs[i].end(), b.Coeffs()[i].begin(),
2621+
b.Coeffs()[i].end());
2622+
}
2623+
ASSERT_OK_AND_ASSIGN(
2624+
b, RnsPolynomial<TypeParam>::Create(std::move(b_coeffs), b.IsNttForm()));
2625+
2626+
auto a2 = a.Clone();
2627+
EXPECT_THAT(a.AddInPlace(b, this->moduli_),
2628+
StatusIs(absl::StatusCode::kInvalidArgument,
2629+
"Size of coefficient vector 0 does not match"));
2630+
EXPECT_EQ(a, a2);
2631+
}
2632+
2633+
TYPED_TEST(RnsPolynomialTest, SubInPlaceLeavesPolynomialUnchangedOnFailure) {
2634+
ASSERT_OK_AND_ASSIGN(RnsPolynomial<TypeParam> a, this->SampleRnsPolynomial());
2635+
ASSERT_OK_AND_ASSIGN(RnsPolynomial<TypeParam> b, this->SampleRnsPolynomial());
2636+
2637+
// Double the length of the coefficient vectors of b.
2638+
std::vector<std::vector<TypeParam>> b_coeffs = b.Coeffs();
2639+
for (int i = 0; i < b_coeffs.size(); ++i) {
2640+
b_coeffs[i].insert(b_coeffs[i].end(), b.Coeffs()[i].begin(),
2641+
b.Coeffs()[i].end());
2642+
}
2643+
ASSERT_OK_AND_ASSIGN(
2644+
b, RnsPolynomial<TypeParam>::Create(std::move(b_coeffs), b.IsNttForm()));
2645+
2646+
auto a2 = a.Clone();
2647+
EXPECT_THAT(a.SubInPlace(b, this->moduli_),
2648+
StatusIs(absl::StatusCode::kInvalidArgument,
2649+
"Size of coefficient vector 0 does not match"));
2650+
EXPECT_EQ(a, a2);
2651+
}
2652+
2653+
TYPED_TEST(RnsPolynomialTest, MulInPlaceLeavesPolynomialUnchangedOnFailure) {
2654+
ASSERT_OK_AND_ASSIGN(RnsPolynomial<TypeParam> a, this->SampleRnsPolynomial());
2655+
ASSERT_OK_AND_ASSIGN(RnsPolynomial<TypeParam> b, this->SampleRnsPolynomial());
2656+
2657+
// Double the length of the coefficient vectors of b.
2658+
std::vector<std::vector<TypeParam>> b_coeffs = b.Coeffs();
2659+
for (int i = 0; i < b_coeffs.size(); ++i) {
2660+
b_coeffs[i].insert(b_coeffs[i].end(), b.Coeffs()[i].begin(),
2661+
b.Coeffs()[i].end());
2662+
}
2663+
ASSERT_OK_AND_ASSIGN(
2664+
b, RnsPolynomial<TypeParam>::Create(std::move(b_coeffs), b.IsNttForm()));
2665+
2666+
auto a2 = a.Clone();
2667+
EXPECT_THAT(a.MulInPlace(b, this->moduli_),
2668+
StatusIs(absl::StatusCode::kInvalidArgument,
2669+
"Size of coefficient vector 0 does not match"));
2670+
EXPECT_EQ(a, a2);
2671+
}
2672+
2673+
TYPED_TEST(RnsPolynomialTest,
2674+
FusedMulAddInPlaceLeavesPolynomialUnchangedOnFailure) {
2675+
ASSERT_OK_AND_ASSIGN(RnsPolynomial<TypeParam> a, this->SampleRnsPolynomial());
2676+
ASSERT_OK_AND_ASSIGN(RnsPolynomial<TypeParam> b, this->SampleRnsPolynomial());
2677+
ASSERT_OK_AND_ASSIGN(RnsPolynomial<TypeParam> c, this->SampleRnsPolynomial());
2678+
2679+
// Double the length of the coefficient vectors of b.
2680+
std::vector<std::vector<TypeParam>> b_coeffs = b.Coeffs();
2681+
for (int i = 0; i < b_coeffs.size(); ++i) {
2682+
b_coeffs[i].insert(b_coeffs[i].end(), b.Coeffs()[i].begin(),
2683+
b.Coeffs()[i].end());
2684+
}
2685+
ASSERT_OK_AND_ASSIGN(
2686+
b, RnsPolynomial<TypeParam>::Create(std::move(b_coeffs), b.IsNttForm()));
2687+
2688+
auto a2 = a.Clone();
2689+
EXPECT_THAT(a.FusedMulAddInPlace(b, c, this->moduli_),
2690+
StatusIs(absl::StatusCode::kInvalidArgument,
2691+
"Size of coefficient vector 0 does not match"));
2692+
EXPECT_EQ(a, a2);
2693+
}
2694+
26132695
} // namespace
26142696
} // namespace rlwe

0 commit comments

Comments
 (0)