diff --git a/shell_encryption/rns/rns_polynomial_test.cc b/shell_encryption/rns/rns_polynomial_test.cc index e01757d..e975ae3 100644 --- a/shell_encryption/rns/rns_polynomial_test.cc +++ b/shell_encryption/rns/rns_polynomial_test.cc @@ -2610,5 +2610,87 @@ TYPED_TEST(RnsPolynomialTest, SerializedDeserializes) { EXPECT_EQ(deserialized2, a); } +TYPED_TEST(RnsPolynomialTest, AddInPlaceLeavesPolynomialUnchangedOnFailure) { + ASSERT_OK_AND_ASSIGN(RnsPolynomial a, this->SampleRnsPolynomial()); + ASSERT_OK_AND_ASSIGN(RnsPolynomial b, this->SampleRnsPolynomial()); + + // Double the length of the coefficient vectors of b. + std::vector> b_coeffs = b.Coeffs(); + for (int i = 0; i < b_coeffs.size(); ++i) { + b_coeffs[i].insert(b_coeffs[i].end(), b.Coeffs()[i].begin(), + b.Coeffs()[i].end()); + } + ASSERT_OK_AND_ASSIGN( + b, RnsPolynomial::Create(std::move(b_coeffs), b.IsNttForm())); + + auto a2 = a.Clone(); + EXPECT_THAT( + a.AddInPlace(b, this->moduli_), + StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("same size"))); + EXPECT_EQ(a, a2); +} + +TYPED_TEST(RnsPolynomialTest, SubInPlaceLeavesPolynomialUnchangedOnFailure) { + ASSERT_OK_AND_ASSIGN(RnsPolynomial a, this->SampleRnsPolynomial()); + ASSERT_OK_AND_ASSIGN(RnsPolynomial b, this->SampleRnsPolynomial()); + + // Double the length of the coefficient vectors of b. + std::vector> b_coeffs = b.Coeffs(); + for (int i = 0; i < b_coeffs.size(); ++i) { + b_coeffs[i].insert(b_coeffs[i].end(), b.Coeffs()[i].begin(), + b.Coeffs()[i].end()); + } + ASSERT_OK_AND_ASSIGN( + b, RnsPolynomial::Create(std::move(b_coeffs), b.IsNttForm())); + + auto a2 = a.Clone(); + EXPECT_THAT( + a.SubInPlace(b, this->moduli_), + StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("same size"))); + EXPECT_EQ(a, a2); +} + +TYPED_TEST(RnsPolynomialTest, MulInPlaceLeavesPolynomialUnchangedOnFailure) { + ASSERT_OK_AND_ASSIGN(RnsPolynomial a, this->SampleRnsPolynomial()); + ASSERT_OK_AND_ASSIGN(RnsPolynomial b, this->SampleRnsPolynomial()); + + // Double the length of the coefficient vectors of b. + std::vector> b_coeffs = b.Coeffs(); + for (int i = 0; i < b_coeffs.size(); ++i) { + b_coeffs[i].insert(b_coeffs[i].end(), b.Coeffs()[i].begin(), + b.Coeffs()[i].end()); + } + ASSERT_OK_AND_ASSIGN( + b, RnsPolynomial::Create(std::move(b_coeffs), b.IsNttForm())); + + auto a2 = a.Clone(); + EXPECT_THAT( + a.MulInPlace(b, this->moduli_), + StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("same size"))); + EXPECT_EQ(a, a2); +} + +TYPED_TEST(RnsPolynomialTest, + FusedMulAddInPlaceLeavesPolynomialUnchangedOnFailure) { + ASSERT_OK_AND_ASSIGN(RnsPolynomial a, this->SampleRnsPolynomial()); + ASSERT_OK_AND_ASSIGN(RnsPolynomial b, this->SampleRnsPolynomial()); + ASSERT_OK_AND_ASSIGN(RnsPolynomial c, this->SampleRnsPolynomial()); + + // Double the length of the coefficient vectors of b. + std::vector> b_coeffs = b.Coeffs(); + for (int i = 0; i < b_coeffs.size(); ++i) { + b_coeffs[i].insert(b_coeffs[i].end(), b.Coeffs()[i].begin(), + b.Coeffs()[i].end()); + } + ASSERT_OK_AND_ASSIGN( + b, RnsPolynomial::Create(std::move(b_coeffs), b.IsNttForm())); + + auto a2 = a.Clone(); + EXPECT_THAT( + a.FusedMulAddInPlace(b, c, this->moduli_), + StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("same size"))); + EXPECT_EQ(a, a2); +} + } // namespace } // namespace rlwe