-
Notifications
You must be signed in to change notification settings - Fork 177
add matrix magic methods #4729
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
add matrix magic methods #4729
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,10 +1,14 @@ | ||
| import itertools | ||
| from numbers import Complex | ||
| import ufl | ||
|
|
||
| from pyop2 import op2 | ||
| from pyop2.mpi import internal_comm | ||
| from pyop2.utils import as_tuple | ||
| from firedrake.petsc import PETSc | ||
| from firedrake.function import Function | ||
| from firedrake.cofunction import Cofunction | ||
| from firedrake.constant import Constant | ||
|
|
||
|
|
||
| class DummyOP2Mat: | ||
|
|
@@ -108,18 +112,59 @@ def __str__(self): | |
|
|
||
| def __add__(self, other): | ||
| if isinstance(other, MatrixBase): | ||
| mat = self.petscmat + other.petscmat | ||
| return AssembledMatrix(self.arguments(), (), mat) | ||
| if self.arguments() != other.arguments(): | ||
| raise ValueError("Arguments in matrix addition must match.") | ||
| return ufl.FormSum((self, 1.), (other, 1.)) | ||
| else: | ||
| return NotImplemented | ||
|
|
||
| def __sub__(self, other): | ||
| if isinstance(other, MatrixBase): | ||
| mat = self.petscmat - other.petscmat | ||
| return AssembledMatrix(self.arguments(), (), mat) | ||
| if self.arguments() != other.arguments(): | ||
| raise ValueError("Arguments in matrix subtraction must match.") | ||
| return ufl.FormSum((self, 1.), (other, -1.)) | ||
| else: | ||
| return NotImplemented | ||
|
|
||
| def __matmul__(self, other): | ||
| if isinstance(other, MatrixBase | Function | Cofunction): | ||
| return ufl.Action(self, other) | ||
| else: | ||
| return NotImplemented | ||
|
|
||
| def __rmatmul__(self, other): | ||
| if isinstance(other, MatrixBase): | ||
| return ufl.Action(other, self) | ||
| elif isinstance(other, Cofunction): | ||
| return ufl.Action(ufl.Adjoint(self), other) | ||
| else: | ||
| return NotImplemented | ||
|
|
||
| def __mul__(self, other): | ||
| # Scalar multiplication | ||
| if isinstance(other, Complex | Constant): | ||
connorjward marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| return ufl.FormSum((self, other)) | ||
| else: | ||
| return NotImplemented | ||
|
|
||
| def __rmul__(self, other): | ||
| # Scalar multiplication from the left | ||
| if isinstance(other, Complex | Constant): | ||
| return ufl.FormSum((self, other)) | ||
| else: | ||
| return NotImplemented | ||
|
|
||
| def __truediv__(self, other): | ||
| # Scalar division | ||
| if isinstance(other, Complex | Constant): | ||
| other = other.values().item() if isinstance(other, Constant) else other | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is a bit dangerous, we are silently freezing the
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It might be better to leave this one for a separate PR where we enable |
||
| return ufl.FormSum((self, 1.0/other)) | ||
| else: | ||
| return NotImplemented | ||
|
|
||
| def __neg__(self): | ||
| return ufl.FormSum((self, -1.)) | ||
|
|
||
| def assign(self, val): | ||
| """Set matrix entries.""" | ||
| if isinstance(val, MatrixBase): | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If this class inherits from
ufl.Matrix, andufl.Matrixis aBaseForm, shouldn't it already inherit addition and scalar multiplication?I'm not aware of matrix multiplication being implemented in
BaseForm, but it seems that it would not harm to have it there.Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Matrix multiplication is there, but it seems that it only supports multiplication times a
Function(and breaks forCofunctionandMatrix).https://github.com/FEniCS/ufl/blob/main/ufl/form.py#L247
https://github.com/FEniCS/ufl/blob/main/ufl/form.py#L219-L225