Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 49 additions & 4 deletions firedrake/matrix.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
import itertools
from numbers import Complex
import ufl

from pyop2 import op2
from pyop2.mpi import internal_comm
from pyop2.utils import as_tuple
from firedrake.petsc import PETSc
from firedrake.function import Function
from firedrake.cofunction import Cofunction
from firedrake.constant import Constant


class DummyOP2Mat:
Expand Down Expand Up @@ -108,18 +112,59 @@ def __str__(self):

def __add__(self, other):
if isinstance(other, MatrixBase):
mat = self.petscmat + other.petscmat
return AssembledMatrix(self.arguments(), (), mat)
if self.arguments() != other.arguments():
raise ValueError("Arguments in matrix addition must match.")
return ufl.FormSum((self, 1.), (other, 1.))
else:
return NotImplemented

def __sub__(self, other):
if isinstance(other, MatrixBase):
mat = self.petscmat - other.petscmat
return AssembledMatrix(self.arguments(), (), mat)
if self.arguments() != other.arguments():
raise ValueError("Arguments in matrix subtraction must match.")
return ufl.FormSum((self, 1.), (other, -1.))
else:
return NotImplemented

def __matmul__(self, other):
Copy link
Contributor

@pbrubeck pbrubeck Nov 20, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this class inherits from ufl.Matrix, and ufl.Matrix is a BaseForm, shouldn't it already inherit addition and scalar multiplication?

I'm not aware of matrix multiplication being implemented in BaseForm, but it seems that it would not harm to have it there.

Copy link
Contributor

@pbrubeck pbrubeck Nov 20, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Matrix multiplication is there, but it seems that it only supports multiplication times a Function (and breaks for Cofunction and Matrix).

https://github.com/FEniCS/ufl/blob/main/ufl/form.py#L247

https://github.com/FEniCS/ufl/blob/main/ufl/form.py#L219-L225

if isinstance(other, MatrixBase | Function | Cofunction):
return ufl.Action(self, other)
else:
return NotImplemented

def __rmatmul__(self, other):
if isinstance(other, MatrixBase):
return ufl.Action(other, self)
elif isinstance(other, Cofunction):
return ufl.Action(ufl.Adjoint(self), other)
else:
return NotImplemented

def __mul__(self, other):
# Scalar multiplication
if isinstance(other, Complex | Constant):
return ufl.FormSum((self, other))
else:
return NotImplemented

def __rmul__(self, other):
# Scalar multiplication from the left
if isinstance(other, Complex | Constant):
return ufl.FormSum((self, other))
else:
return NotImplemented

def __truediv__(self, other):
# Scalar division
if isinstance(other, Complex | Constant):
other = other.values().item() if isinstance(other, Constant) else other
Copy link
Contributor

@pbrubeck pbrubeck Nov 20, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a bit dangerous, we are silently freezing the Constant. This returns a symbolic expression that no longer makes reference to the Constant, which is inconsistent with how scalar multiplication is implemented.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It might be better to leave this one for a separate PR where we enable BaseFormAssembler to numerically evaluate constant-valued symbolic expressions in the weights of a FormSum.

return ufl.FormSum((self, 1.0/other))
else:
return NotImplemented

def __neg__(self):
return ufl.FormSum((self, -1.))

def assign(self, val):
"""Set matrix entries."""
if isinstance(val, MatrixBase):
Expand Down
174 changes: 174 additions & 0 deletions tests/firedrake/regression/test_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,177 @@ def test_solve_with_assembled_matrix(a):
solve(A == L, solution)

assert norm(assemble(f - solution)) < 1e-15


@pytest.fixture
def matrices():
mesh = UnitSquareMesh(4, 4)
V1 = FunctionSpace(mesh, "CG", 1)
V2 = FunctionSpace(mesh, "CG", 2)

M1 = assemble(inner(TrialFunction(V1), TestFunction(V1)) * dx)
M1_2 = assemble(inner(2 * TrialFunction(V1), TestFunction(V1)) * dx)
I1 = assemble(interpolate(TrialFunction(V1), V1)) # Identity on V1
I12 = assemble(interpolate(TrialFunction(V1), V2)) # Interp from V1 to V2
return M1, M1_2, I1, I12


def test_matrix_matmul(matrices):
M1, M1_2, I1, _ = matrices

# Test matrix-matrix multiplication
matmul_symb = M1 @ I1
res1 = assemble(matmul_symb)
assert isinstance(matmul_symb, ufl.Action)
assert isinstance(res1, matrix.AssembledMatrix)
assert np.allclose(res1.petscmat[:, :], M1.petscmat[:, :])

# Incompatible matmul
with pytest.raises(TypeError, match="Incompatible function spaces in Action"):
M1 @ M1_2


def test_matrix_addition(matrices):
M1, M1_2, I1, _ = matrices

V1 = M1.arguments()[0].function_space()

# Test matrix addition
matadd_symb = I1 + I1
res2 = assemble(matadd_symb)
assert isinstance(matadd_symb, ufl.FormSum)
assert isinstance(res2, matrix.AssembledMatrix)
assert np.allclose(res2.petscmat[:, :], 2 * np.eye(V1.dim()))

# Matrix addition incorrect args
with pytest.raises(ValueError, match="Arguments in matrix addition must match."):
I1 + M1

# Different matrices, correct arguments
matadd_symb3 = M1 + M1_2
res_add3 = assemble(matadd_symb3)
assert isinstance(matadd_symb3, ufl.FormSum)
assert isinstance(res_add3, matrix.AssembledMatrix)
assert np.allclose(res_add3.petscmat[:, :], 3 * M1.petscmat[:, :])


def test_matrix_subtraction(matrices):
M1, M1_2, I1, _ = matrices

matsub = M1 - M1_2
res_sub = assemble(matsub)
assert isinstance(matsub, ufl.FormSum)
assert isinstance(res_sub, matrix.AssembledMatrix)
assert np.allclose(res_sub.petscmat[:, :], -1 * M1.petscmat[:, :])

matsub2 = M1_2 - M1
res_sub2 = assemble(matsub2)
assert isinstance(matsub2, ufl.FormSum)
assert isinstance(res_sub2, matrix.AssembledMatrix)
assert np.allclose(res_sub2.petscmat[:, :], M1.petscmat[:, :])

with pytest.raises(ValueError, match="Arguments in matrix subtraction must match."):
I1 - M1


def test_matrix_scalar_multiplication(matrices):
M1, _, _, _ = matrices

# Test left scalar multiplication
matscal_left = 2.5 * M1
res7 = assemble(matscal_left)
assert isinstance(matscal_left, ufl.FormSum)
assert isinstance(res7, matrix.AssembledMatrix)
assert np.allclose(res7.petscmat[:, :], 2.5 * M1.petscmat[:, :])

# Test right scalar multiplication
matscal_right = M1 * 3.0
res8 = assemble(matscal_right)
assert isinstance(matscal_right, ufl.FormSum)
assert isinstance(res8, matrix.AssembledMatrix)
assert np.allclose(res8.petscmat[:, :], 3.0 * M1.petscmat[:, :])

# Test with Constant
c = Constant(4.0)
matscal_const = c * M1
res_const = assemble(matscal_const)
assert isinstance(matscal_const, ufl.FormSum)
assert isinstance(res_const, matrix.AssembledMatrix)
assert np.allclose(res_const.petscmat[:, :], 4.0 * M1.petscmat[:, :])


def test_matrix_scalar_division(matrices):
M1, _, _, _ = matrices

# Test scalar division
matdiv = M1 / 2.0
res9 = assemble(matdiv)
assert isinstance(matdiv, ufl.FormSum)
assert isinstance(res9, matrix.AssembledMatrix)
assert np.allclose(res9.petscmat[:, :], 0.5 * M1.petscmat[:, :])

# Test division by Constant
c = Constant(4.0)
matdiv_const = M1 / c
res_const = assemble(matdiv_const)
assert isinstance(matdiv_const, ufl.FormSum)
assert isinstance(res_const, matrix.AssembledMatrix)
assert np.allclose(res_const.petscmat[:, :], 0.25 * M1.petscmat[:, :])


def test_matrix_negation(matrices):
M1, _, _, _ = matrices

matneg_symb = -M1
res_neg = assemble(matneg_symb)
assert isinstance(matneg_symb, ufl.FormSum)
assert isinstance(res_neg, matrix.AssembledMatrix)
assert np.allclose(res_neg.petscmat[:, :], -1 * M1.petscmat[:, :])

matneg2_symb = -matneg_symb
res_neg2 = assemble(matneg2_symb)
isinstance(matneg2_symb, ufl.FormSum)
isinstance(res_neg2, matrix.AssembledMatrix)
assert np.allclose(res_neg2.petscmat[:, :], M1.petscmat[:, :])


def test_matrix_vector_product(matrices):
M1, _, I1, I12 = matrices

V1 = M1.arguments()[0].function_space()

f = Function(V1).assign(1.0)
matvec = I1 @ f
assert isinstance(matvec, ufl.Action)
res4 = assemble(matvec)
assert isinstance(res4, Function)
assert np.allclose(res4.dat.data[:], f.dat.data[:])

# test vector-matrix product
x, y = SpatialCoordinate(V1.mesh())
f = Function(V1).interpolate(x + y)
with pytest.raises(TypeError, match=r"unsupported operand type\(s\) for @: 'Function' and 'AssembledMatrix'"):
f @ I12


def test_cofunction_matrix_product(matrices):
M1, _, I1, I12 = matrices

V1 = M1.arguments()[0].function_space()
V2 = I12.arguments()[0].function_space().dual()

f = assemble(conj(TestFunction(V2)) * dx) # Cofunction in V2*
vecmat = f @ I12 # adjoint interpolation from V2^* to V1^*
assert isinstance(vecmat, ufl.Action)
res5 = assemble(vecmat)
assert isinstance(res5, Cofunction)
res5_comp = assemble(conj(TestFunction(V1)) * dx)
assert np.allclose(res5.dat.data_ro[:], res5_comp.dat.data_ro[:])

I12_adj = assemble(adjoint(I12))
vecmat = I12_adj @ f
assert isinstance(vecmat, ufl.Action)
res6 = assemble(vecmat)
assert isinstance(res6, Cofunction)
res6_comp = assemble(conj(TestFunction(V1)) * dx)
assert np.allclose(res6.dat.data_ro[:], res6_comp.dat.data_ro[:])
Loading