From 05257566f59097047b6464844dfdb062bbaed6f0 Mon Sep 17 00:00:00 2001 From: Leo Collins Date: Tue, 18 Nov 2025 13:40:24 +0000 Subject: [PATCH] add magic methods --- firedrake/matrix.py | 53 ++++++- tests/firedrake/regression/test_matrix.py | 174 ++++++++++++++++++++++ 2 files changed, 223 insertions(+), 4 deletions(-) diff --git a/firedrake/matrix.py b/firedrake/matrix.py index 2f33841289..9ccbe42ee5 100644 --- a/firedrake/matrix.py +++ b/firedrake/matrix.py @@ -1,10 +1,14 @@ import itertools +from numbers import Complex import ufl from pyop2 import op2 from pyop2.mpi import internal_comm from pyop2.utils import as_tuple from firedrake.petsc import PETSc +from firedrake.function import Function +from firedrake.cofunction import Cofunction +from firedrake.constant import Constant class DummyOP2Mat: @@ -108,18 +112,59 @@ def __str__(self): def __add__(self, other): if isinstance(other, MatrixBase): - mat = self.petscmat + other.petscmat - return AssembledMatrix(self.arguments(), (), mat) + if self.arguments() != other.arguments(): + raise ValueError("Arguments in matrix addition must match.") + return ufl.FormSum((self, 1.), (other, 1.)) else: return NotImplemented def __sub__(self, other): if isinstance(other, MatrixBase): - mat = self.petscmat - other.petscmat - return AssembledMatrix(self.arguments(), (), mat) + if self.arguments() != other.arguments(): + raise ValueError("Arguments in matrix subtraction must match.") + return ufl.FormSum((self, 1.), (other, -1.)) else: return NotImplemented + def __matmul__(self, other): + if isinstance(other, MatrixBase | Function | Cofunction): + return ufl.Action(self, other) + else: + return NotImplemented + + def __rmatmul__(self, other): + if isinstance(other, MatrixBase): + return ufl.Action(other, self) + elif isinstance(other, Cofunction): + return ufl.Action(ufl.Adjoint(self), other) + else: + return NotImplemented + + def __mul__(self, other): + # Scalar multiplication + if isinstance(other, Complex | Constant): + return ufl.FormSum((self, other)) + else: + return NotImplemented + + def __rmul__(self, other): + # Scalar multiplication from the left + if isinstance(other, Complex | Constant): + return ufl.FormSum((self, other)) + else: + return NotImplemented + + def __truediv__(self, other): + # Scalar division + if isinstance(other, Complex | Constant): + other = other.values().item() if isinstance(other, Constant) else other + return ufl.FormSum((self, 1.0/other)) + else: + return NotImplemented + + def __neg__(self): + return ufl.FormSum((self, -1.)) + def assign(self, val): """Set matrix entries.""" if isinstance(val, MatrixBase): diff --git a/tests/firedrake/regression/test_matrix.py b/tests/firedrake/regression/test_matrix.py index 48680747fe..d34d69eeb9 100644 --- a/tests/firedrake/regression/test_matrix.py +++ b/tests/firedrake/regression/test_matrix.py @@ -49,3 +49,177 @@ def test_solve_with_assembled_matrix(a): solve(A == L, solution) assert norm(assemble(f - solution)) < 1e-15 + + +@pytest.fixture +def matrices(): + mesh = UnitSquareMesh(4, 4) + V1 = FunctionSpace(mesh, "CG", 1) + V2 = FunctionSpace(mesh, "CG", 2) + + M1 = assemble(inner(TrialFunction(V1), TestFunction(V1)) * dx) + M1_2 = assemble(inner(2 * TrialFunction(V1), TestFunction(V1)) * dx) + I1 = assemble(interpolate(TrialFunction(V1), V1)) # Identity on V1 + I12 = assemble(interpolate(TrialFunction(V1), V2)) # Interp from V1 to V2 + return M1, M1_2, I1, I12 + + +def test_matrix_matmul(matrices): + M1, M1_2, I1, _ = matrices + + # Test matrix-matrix multiplication + matmul_symb = M1 @ I1 + res1 = assemble(matmul_symb) + assert isinstance(matmul_symb, ufl.Action) + assert isinstance(res1, matrix.AssembledMatrix) + assert np.allclose(res1.petscmat[:, :], M1.petscmat[:, :]) + + # Incompatible matmul + with pytest.raises(TypeError, match="Incompatible function spaces in Action"): + M1 @ M1_2 + + +def test_matrix_addition(matrices): + M1, M1_2, I1, _ = matrices + + V1 = M1.arguments()[0].function_space() + + # Test matrix addition + matadd_symb = I1 + I1 + res2 = assemble(matadd_symb) + assert isinstance(matadd_symb, ufl.FormSum) + assert isinstance(res2, matrix.AssembledMatrix) + assert np.allclose(res2.petscmat[:, :], 2 * np.eye(V1.dim())) + + # Matrix addition incorrect args + with pytest.raises(ValueError, match="Arguments in matrix addition must match."): + I1 + M1 + + # Different matrices, correct arguments + matadd_symb3 = M1 + M1_2 + res_add3 = assemble(matadd_symb3) + assert isinstance(matadd_symb3, ufl.FormSum) + assert isinstance(res_add3, matrix.AssembledMatrix) + assert np.allclose(res_add3.petscmat[:, :], 3 * M1.petscmat[:, :]) + + +def test_matrix_subtraction(matrices): + M1, M1_2, I1, _ = matrices + + matsub = M1 - M1_2 + res_sub = assemble(matsub) + assert isinstance(matsub, ufl.FormSum) + assert isinstance(res_sub, matrix.AssembledMatrix) + assert np.allclose(res_sub.petscmat[:, :], -1 * M1.petscmat[:, :]) + + matsub2 = M1_2 - M1 + res_sub2 = assemble(matsub2) + assert isinstance(matsub2, ufl.FormSum) + assert isinstance(res_sub2, matrix.AssembledMatrix) + assert np.allclose(res_sub2.petscmat[:, :], M1.petscmat[:, :]) + + with pytest.raises(ValueError, match="Arguments in matrix subtraction must match."): + I1 - M1 + + +def test_matrix_scalar_multiplication(matrices): + M1, _, _, _ = matrices + + # Test left scalar multiplication + matscal_left = 2.5 * M1 + res7 = assemble(matscal_left) + assert isinstance(matscal_left, ufl.FormSum) + assert isinstance(res7, matrix.AssembledMatrix) + assert np.allclose(res7.petscmat[:, :], 2.5 * M1.petscmat[:, :]) + + # Test right scalar multiplication + matscal_right = M1 * 3.0 + res8 = assemble(matscal_right) + assert isinstance(matscal_right, ufl.FormSum) + assert isinstance(res8, matrix.AssembledMatrix) + assert np.allclose(res8.petscmat[:, :], 3.0 * M1.petscmat[:, :]) + + # Test with Constant + c = Constant(4.0) + matscal_const = c * M1 + res_const = assemble(matscal_const) + assert isinstance(matscal_const, ufl.FormSum) + assert isinstance(res_const, matrix.AssembledMatrix) + assert np.allclose(res_const.petscmat[:, :], 4.0 * M1.petscmat[:, :]) + + +def test_matrix_scalar_division(matrices): + M1, _, _, _ = matrices + + # Test scalar division + matdiv = M1 / 2.0 + res9 = assemble(matdiv) + assert isinstance(matdiv, ufl.FormSum) + assert isinstance(res9, matrix.AssembledMatrix) + assert np.allclose(res9.petscmat[:, :], 0.5 * M1.petscmat[:, :]) + + # Test division by Constant + c = Constant(4.0) + matdiv_const = M1 / c + res_const = assemble(matdiv_const) + assert isinstance(matdiv_const, ufl.FormSum) + assert isinstance(res_const, matrix.AssembledMatrix) + assert np.allclose(res_const.petscmat[:, :], 0.25 * M1.petscmat[:, :]) + + +def test_matrix_negation(matrices): + M1, _, _, _ = matrices + + matneg_symb = -M1 + res_neg = assemble(matneg_symb) + assert isinstance(matneg_symb, ufl.FormSum) + assert isinstance(res_neg, matrix.AssembledMatrix) + assert np.allclose(res_neg.petscmat[:, :], -1 * M1.petscmat[:, :]) + + matneg2_symb = -matneg_symb + res_neg2 = assemble(matneg2_symb) + isinstance(matneg2_symb, ufl.FormSum) + isinstance(res_neg2, matrix.AssembledMatrix) + assert np.allclose(res_neg2.petscmat[:, :], M1.petscmat[:, :]) + + +def test_matrix_vector_product(matrices): + M1, _, I1, I12 = matrices + + V1 = M1.arguments()[0].function_space() + + f = Function(V1).assign(1.0) + matvec = I1 @ f + assert isinstance(matvec, ufl.Action) + res4 = assemble(matvec) + assert isinstance(res4, Function) + assert np.allclose(res4.dat.data[:], f.dat.data[:]) + + # test vector-matrix product + x, y = SpatialCoordinate(V1.mesh()) + f = Function(V1).interpolate(x + y) + with pytest.raises(TypeError, match=r"unsupported operand type\(s\) for @: 'Function' and 'AssembledMatrix'"): + f @ I12 + + +def test_cofunction_matrix_product(matrices): + M1, _, I1, I12 = matrices + + V1 = M1.arguments()[0].function_space() + V2 = I12.arguments()[0].function_space().dual() + + f = assemble(conj(TestFunction(V2)) * dx) # Cofunction in V2* + vecmat = f @ I12 # adjoint interpolation from V2^* to V1^* + assert isinstance(vecmat, ufl.Action) + res5 = assemble(vecmat) + assert isinstance(res5, Cofunction) + res5_comp = assemble(conj(TestFunction(V1)) * dx) + assert np.allclose(res5.dat.data_ro[:], res5_comp.dat.data_ro[:]) + + I12_adj = assemble(adjoint(I12)) + vecmat = I12_adj @ f + assert isinstance(vecmat, ufl.Action) + res6 = assemble(vecmat) + assert isinstance(res6, Cofunction) + res6_comp = assemble(conj(TestFunction(V1)) * dx) + assert np.allclose(res6.dat.data_ro[:], res6_comp.dat.data_ro[:])