Skip to content

Commit b7368d7

Browse files
committed
mesh sequence fixes
fixes
1 parent 9ee597f commit b7368d7

File tree

1 file changed

+28
-11
lines changed

1 file changed

+28
-11
lines changed

firedrake/interpolation.py

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,7 @@ def __init__(self, expr: Interpolate):
209209
"""The dual argument slot of the Interpolate expression."""
210210
self.target_space = dual_arg.function_space().dual()
211211
"""The primal space we are interpolating into."""
212-
self.target_mesh = self.target_space.mesh()
212+
self.target_mesh = self.target_space.mesh().unique()
213213
"""The domain we are interpolating into."""
214214
self.source_mesh = extract_unique_domain(operand) or self.target_mesh
215215
"""The domain we are interpolating from."""
@@ -312,7 +312,18 @@ def get_interpolator(expr: Interpolate) -> Interpolator:
312312

313313
operand, = expr.ufl_operands
314314
target_mesh = expr.target_space.mesh()
315-
source_mesh = extract_unique_domain(operand) or target_mesh
315+
316+
try:
317+
source_mesh = extract_unique_domain(operand) or target_mesh
318+
except ValueError:
319+
raise NotImplementedError("Interpolating an expression defined on multiple meshes is not implemented yet.")
320+
321+
try:
322+
target_mesh = target_mesh.unique()
323+
source_mesh = source_mesh.unique()
324+
except RuntimeError:
325+
return MixedInterpolator(expr)
326+
316327
submesh_interp_implemented = (
317328
all(isinstance(m.topology, MeshTopology) for m in [target_mesh, source_mesh])
318329
and target_mesh.submesh_ancesters[-1] is source_mesh.submesh_ancesters[-1]
@@ -321,11 +332,8 @@ def get_interpolator(expr: Interpolate) -> Interpolator:
321332
if target_mesh is source_mesh or submesh_interp_implemented:
322333
return SameMeshInterpolator(expr)
323334

324-
target_topology = target_mesh.topology
325-
source_topology = source_mesh.topology
326-
327-
if isinstance(target_topology, VertexOnlyMeshTopology):
328-
if isinstance(source_topology, VertexOnlyMeshTopology):
335+
if isinstance(target_mesh.topology, VertexOnlyMeshTopology):
336+
if isinstance(source_mesh.topology, VertexOnlyMeshTopology):
329337
return VomOntoVomInterpolator(expr)
330338
if target_mesh.geometric_dimension != source_mesh.geometric_dimension:
331339
raise ValueError("Cannot interpolate onto a VertexOnlyMesh of a different geometric dimension.")
@@ -614,10 +622,19 @@ def _get_tensor(self) -> op2.Mat | Function | Cofunction:
614622
return f
615623

616624
def _get_callable(self, tensor=None, bcs=None):
617-
f = tensor or self._get_tensor()
618-
op2_tensor = f if isinstance(f, op2.Mat) else f.dat
625+
if (isinstance(tensor, Cofunction) and isinstance(self.dual_arg, Cofunction)) and set(tensor.dat).intersection(set(self.dual_arg.dat)):
626+
# adjoint one-form case: we need a zero tensor, so if it shares dats with
627+
# the dual_arg we cannot use it directly
628+
f = self._get_tensor()
629+
copyout = (partial(f.dat.copy, tensor.dat),)
630+
else:
631+
f = tensor or self._get_tensor()
632+
copyout = ()
619633

634+
op2_tensor = f if isinstance(f, op2.Mat) else f.dat
620635
loops = []
636+
if self.access is op2.INC:
637+
loops.append(op2_tensor.zero)
621638

622639
# Arguments in the operand are allowed to be from a MixedFunctionSpace
623640
# We need to split the target space V and generate separate kernels
@@ -646,6 +663,8 @@ def _get_callable(self, tensor=None, bcs=None):
646663
if bcs and self.rank == 1:
647664
loops.extend(partial(bc.apply, f) for bc in bcs)
648665

666+
loops.extend(copyout)
667+
649668
def callable() -> Function | Cofunction | PETSc.Mat | Number:
650669
for l in loops:
651670
l()
@@ -912,8 +931,6 @@ def _build_interpolation_callables(
912931
if isinstance(tensor, op2.Mat):
913932
return parloop, tensor.assemble
914933
else:
915-
if access == op2.INC:
916-
copyin += (tensor.zero,)
917934
return copyin + (parloop, ) + copyout
918935

919936

0 commit comments

Comments
 (0)