diff --git a/shell_encryption/rns/rns_polynomial.h b/shell_encryption/rns/rns_polynomial.h index 7173a4b..ddb2b15 100644 --- a/shell_encryption/rns/rns_polynomial.h +++ b/shell_encryption/rns/rns_polynomial.h @@ -190,7 +190,10 @@ class RnsPolynomial { return output; } - // Adds `that` to `this` in-place. + // Adds `that` to `this` in-place. May fail if the coefficient vectors of + // `this` and `that` are not the same size, or the number of coefficient + // vectors does not match the number of moduli. In case of failure, `this` + // is not modified. absl::Status AddInPlace( const RnsPolynomial& that, absl::Span* const> moduli) { @@ -203,14 +206,28 @@ class RnsPolynomial { return absl::InvalidArgumentError( absl::StrCat("`moduli` must contain ", num_moduli, " RNS moduli.")); } + // Check all coeff vectors have the right size to catch errors before + // modifying `this`. for (int i = 0; i < num_moduli; ++i) { - RLWE_RETURN_IF_ERROR(ModularInt::BatchAddInPlace( - &coeff_vectors_[i], that.coeff_vectors_[i], moduli[i]->ModParams())); + if (coeff_vectors_[i].size() != that.coeff_vectors_[i].size()) { + return absl::InvalidArgumentError( + absl::StrCat("Size of coefficient vector ", i, " does not match")); + } + } + for (int i = 0; i < num_moduli; ++i) { + ModularInt::BatchAddInPlace(&coeff_vectors_[i], that.coeff_vectors_[i], + moduli[i]->ModParams()) + .IgnoreError(); // We already checked the sizes above, and + // BatchAddInPlace will only fail if its inputs have + // different sizes. } return absl::OkStatus(); } - // Substracts `that` from `this` in-place. + // Substracts `that` from `this` in-place. May fail if the coefficient + // vectors of `this` and `that` are not the same size, or the number of + // coefficient vectors does not match the number of moduli. In case of + // failure, `this` is not modified. absl::Status SubInPlace( const RnsPolynomial& that, absl::Span* const> moduli) { @@ -223,9 +240,20 @@ class RnsPolynomial { return absl::InvalidArgumentError( absl::StrCat("`moduli` must contain ", num_moduli, " RNS moduli.")); } + // Check all coeff vectors have the right size to catch errors before + // modifying `this`. + for (int i = 0; i < num_moduli; ++i) { + if (coeff_vectors_[i].size() != that.coeff_vectors_[i].size()) { + return absl::InvalidArgumentError( + absl::StrCat("Size of coefficient vector ", i, " does not match")); + } + } for (int i = 0; i < num_moduli; ++i) { - RLWE_RETURN_IF_ERROR(ModularInt::BatchSubInPlace( - &coeff_vectors_[i], that.coeff_vectors_[i], moduli[i]->ModParams())); + ModularInt::BatchSubInPlace(&coeff_vectors_[i], that.coeff_vectors_[i], + moduli[i]->ModParams()) + .IgnoreError(); // We already checked the sizes above, and + // BatchSubInPlace will only fail if its inputs have + // different sizes. } return absl::OkStatus(); } @@ -295,8 +323,10 @@ class RnsPolynomial { return output; } - // Multiplies this polynomial by `that`. - // Both `this` and `that` must be in NTT form. + // Multiplies this polynomial by `that`. Both `this` and `that` must be in NTT + // form. May fail if the coefficient vectors of `this` and `that` are not the + // same size, or the number of coefficient vectors does not match the number + // of moduli. In case of failure, `this` is not modified. absl::Status MulInPlace( const RnsPolynomial& that, absl::Span* const> moduli) { @@ -317,16 +347,29 @@ class RnsPolynomial { return absl::InvalidArgumentError( "RNS polynomial `that` must be in NTT form."); } - + // Check all coeff vectors have the right size to catch errors before + // modifying `this`. for (int i = 0; i < num_moduli; ++i) { - RLWE_RETURN_IF_ERROR(ModularInt::BatchMulInPlace( - &coeff_vectors_[i], that.coeff_vectors_[i], moduli[i]->ModParams())); + if (coeff_vectors_[i].size() != that.coeff_vectors_[i].size()) { + return absl::InvalidArgumentError( + absl::StrCat("Size of coefficient vector ", i, " does not match")); + } + } + for (int i = 0; i < num_moduli; ++i) { + ModularInt::BatchMulInPlace(&coeff_vectors_[i], that.coeff_vectors_[i], + moduli[i]->ModParams()) + .IgnoreError(); // We already checked the sizes above, and + // BatchMulInPlace will only fail if its inputs have + // different sizes. } return absl::OkStatus(); } - // Adds the polynomial product a * b to this polynomial. - // Polynomials `this`, `a`, and `b` must be all in NTT form. + // Adds the polynomial product a * b to this polynomial. Polynomials `this`, + // `a`, and `b` must be all in NTT form. May fail if the coefficient vectors + // of `this`, `a`, and `b` are not the same size, or the number of coefficient + // vectors does not match the number of moduli. In case of failure, `this` is + // not modified. absl::Status FusedMulAddInPlace( const RnsPolynomial& a, const RnsPolynomial& b, absl::Span* const> moduli) { @@ -355,11 +398,22 @@ class RnsPolynomial { return absl::InvalidArgumentError( "RNS polynomial `b` must be in NTT form."); } - + // Check all coeff vectors have the right size to catch errors before + // modifying `this`. + for (int i = 0; i < num_moduli; ++i) { + if (coeff_vectors_[i].size() != a.coeff_vectors_[i].size() || + coeff_vectors_[i].size() != b.coeff_vectors_[i].size()) { + return absl::InvalidArgumentError( + absl::StrCat("Size of coefficient vector ", i, " does not match")); + } + } for (int i = 0; i < num_moduli; ++i) { - RLWE_RETURN_IF_ERROR(ModularInt::BatchFusedMulAddInPlace( + ModularInt::BatchFusedMulAddInPlace( &coeff_vectors_[i], a.coeff_vectors_[i], b.coeff_vectors_[i], - moduli[i]->ModParams())); + moduli[i]->ModParams()) + .IgnoreError(); // We already checked the sizes above, and + // BatchFusedMulAddInPlace will only fail if its + // inputs have different sizes. } return absl::OkStatus(); } diff --git a/shell_encryption/rns/rns_polynomial_test.cc b/shell_encryption/rns/rns_polynomial_test.cc index e01757d..8d95df5 100644 --- a/shell_encryption/rns/rns_polynomial_test.cc +++ b/shell_encryption/rns/rns_polynomial_test.cc @@ -2610,5 +2610,87 @@ TYPED_TEST(RnsPolynomialTest, SerializedDeserializes) { EXPECT_EQ(deserialized2, a); } +TYPED_TEST(RnsPolynomialTest, AddInPlaceLeavesPolynomialUnchangedOnFailure) { + ASSERT_OK_AND_ASSIGN(RnsPolynomial a, this->SampleRnsPolynomial()); + ASSERT_OK_AND_ASSIGN(RnsPolynomial b, this->SampleRnsPolynomial()); + + // Double the length of the coefficient vectors of b. + std::vector> b_coeffs = b.Coeffs(); + for (int i = 0; i < b_coeffs.size(); ++i) { + b_coeffs[i].insert(b_coeffs[i].end(), b.Coeffs()[i].begin(), + b.Coeffs()[i].end()); + } + ASSERT_OK_AND_ASSIGN( + b, RnsPolynomial::Create(std::move(b_coeffs), b.IsNttForm())); + + auto a2 = a.Clone(); + EXPECT_THAT(a.AddInPlace(b, this->moduli_), + StatusIs(absl::StatusCode::kInvalidArgument, + "Size of coefficient vector 0 does not match")); + EXPECT_EQ(a, a2); +} + +TYPED_TEST(RnsPolynomialTest, SubInPlaceLeavesPolynomialUnchangedOnFailure) { + ASSERT_OK_AND_ASSIGN(RnsPolynomial a, this->SampleRnsPolynomial()); + ASSERT_OK_AND_ASSIGN(RnsPolynomial b, this->SampleRnsPolynomial()); + + // Double the length of the coefficient vectors of b. + std::vector> b_coeffs = b.Coeffs(); + for (int i = 0; i < b_coeffs.size(); ++i) { + b_coeffs[i].insert(b_coeffs[i].end(), b.Coeffs()[i].begin(), + b.Coeffs()[i].end()); + } + ASSERT_OK_AND_ASSIGN( + b, RnsPolynomial::Create(std::move(b_coeffs), b.IsNttForm())); + + auto a2 = a.Clone(); + EXPECT_THAT(a.SubInPlace(b, this->moduli_), + StatusIs(absl::StatusCode::kInvalidArgument, + "Size of coefficient vector 0 does not match")); + EXPECT_EQ(a, a2); +} + +TYPED_TEST(RnsPolynomialTest, MulInPlaceLeavesPolynomialUnchangedOnFailure) { + ASSERT_OK_AND_ASSIGN(RnsPolynomial a, this->SampleRnsPolynomial()); + ASSERT_OK_AND_ASSIGN(RnsPolynomial b, this->SampleRnsPolynomial()); + + // Double the length of the coefficient vectors of b. + std::vector> b_coeffs = b.Coeffs(); + for (int i = 0; i < b_coeffs.size(); ++i) { + b_coeffs[i].insert(b_coeffs[i].end(), b.Coeffs()[i].begin(), + b.Coeffs()[i].end()); + } + ASSERT_OK_AND_ASSIGN( + b, RnsPolynomial::Create(std::move(b_coeffs), b.IsNttForm())); + + auto a2 = a.Clone(); + EXPECT_THAT(a.MulInPlace(b, this->moduli_), + StatusIs(absl::StatusCode::kInvalidArgument, + "Size of coefficient vector 0 does not match")); + EXPECT_EQ(a, a2); +} + +TYPED_TEST(RnsPolynomialTest, + FusedMulAddInPlaceLeavesPolynomialUnchangedOnFailure) { + ASSERT_OK_AND_ASSIGN(RnsPolynomial a, this->SampleRnsPolynomial()); + ASSERT_OK_AND_ASSIGN(RnsPolynomial b, this->SampleRnsPolynomial()); + ASSERT_OK_AND_ASSIGN(RnsPolynomial c, this->SampleRnsPolynomial()); + + // Double the length of the coefficient vectors of b. + std::vector> b_coeffs = b.Coeffs(); + for (int i = 0; i < b_coeffs.size(); ++i) { + b_coeffs[i].insert(b_coeffs[i].end(), b.Coeffs()[i].begin(), + b.Coeffs()[i].end()); + } + ASSERT_OK_AND_ASSIGN( + b, RnsPolynomial::Create(std::move(b_coeffs), b.IsNttForm())); + + auto a2 = a.Clone(); + EXPECT_THAT(a.FusedMulAddInPlace(b, c, this->moduli_), + StatusIs(absl::StatusCode::kInvalidArgument, + "Size of coefficient vector 0 does not match")); + EXPECT_EQ(a, a2); +} + } // namespace } // namespace rlwe