Skip to content

Commit 9ee597f

Browse files
committed
fix
tidy
1 parent c2a63f0 commit 9ee597f

File tree

1 file changed

+15
-9
lines changed

1 file changed

+15
-9
lines changed

firedrake/interpolation.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -680,8 +680,10 @@ def _get_callable(self, tensor=None, bcs=None):
680680
assert isinstance(f, Cofunction)
681681

682682
def callable() -> Cofunction:
683-
with self.dual_arg.dat.vec_ro as source_vec, f.dat.vec_wo as target_vec:
684-
self.mat.handle.multHermitian(source_vec, target_vec)
683+
with self.dual_arg.dat.vec_ro as source_vec:
684+
coeff = self.mat.expr_as_coeff(source_vec)
685+
with coeff.dat.vec_ro as coeff_vec, f.dat.vec_wo as target_vec:
686+
self.mat.handle.multHermitian(coeff_vec, target_vec)
685687
return f
686688
else:
687689
assert isinstance(f, Function)
@@ -1209,6 +1211,8 @@ def __init__(self, interpolator: VomOntoVomInterpolator):
12091211
"""The PETSc Star Forest representing the permutation between the VOMs."""
12101212
self.target_space = interpolator.target_space
12111213
"""The FunctionSpace being interpolated into."""
1214+
self.target_vom = interpolator.target_mesh
1215+
"""The VOM being interpolated to."""
12121216
self.source_vom = interpolator.source_mesh
12131217
"""The VOM being interpolated from."""
12141218
self.operand = interpolator.operand
@@ -1280,24 +1284,26 @@ def expr_as_coeff(self, source_vec: PETSc.Vec | None = None) -> Function:
12801284
# so its dat can be sent to the target mesh.
12811285
with stop_annotating():
12821286
element = self.target_space.ufl_element() # Could be vector/tensor valued
1283-
P0DG = FunctionSpace(self.source_vom, element)
12841287
# if we have any arguments in the expression we need to replace
12851288
# them with equivalent coefficients now
1286-
coeff_expr = self.operand
12871289
if len(self.arguments):
12881290
if len(self.arguments) > 1:
1289-
raise NotImplementedError(
1290-
"Can only interpolate expressions with one argument!"
1291-
)
1291+
raise NotImplementedError("Can only interpolate expressions with one argument!")
12921292
if source_vec is None:
12931293
raise ValueError("Need to provide a source dat for the argument!")
1294+
12941295
arg = self.arguments[0]
1295-
arg_coeff = Function(arg.function_space())
1296+
source_space = arg.function_space()
1297+
P0DG = FunctionSpace(self.target_vom if self.is_adjoint else self.source_vom, element)
1298+
arg_coeff = Function(self.target_space if self.is_adjoint else source_space)
12961299
arg_coeff.dat.data_wo[:] = source_vec.getArray(readonly=True).reshape(
12971300
arg_coeff.dat.data_wo.shape
12981301
)
12991302
coeff_expr = replace(self.operand, {arg: arg_coeff})
1300-
coeff = Function(P0DG).interpolate(coeff_expr)
1303+
coeff = Function(P0DG).interpolate(coeff_expr)
1304+
else:
1305+
P0DG = FunctionSpace(self.source_vom, element)
1306+
coeff = Function(P0DG).interpolate(self.operand)
13011307
return coeff
13021308

13031309
def reduce(self, source_vec: PETSc.Vec, target_vec: PETSc.Vec) -> None:

0 commit comments

Comments
 (0)