@@ -209,7 +209,7 @@ def __init__(self, expr: Interpolate):
209209 """The dual argument slot of the Interpolate expression."""
210210 self .target_space = dual_arg .function_space ().dual ()
211211 """The primal space we are interpolating into."""
212- self .target_mesh = self .target_space .mesh ()
212+ self .target_mesh = self .target_space .mesh (). unique ()
213213 """The domain we are interpolating into."""
214214 self .source_mesh = extract_unique_domain (operand ) or self .target_mesh
215215 """The domain we are interpolating from."""
@@ -312,7 +312,18 @@ def get_interpolator(expr: Interpolate) -> Interpolator:
312312
313313 operand , = expr .ufl_operands
314314 target_mesh = expr .target_space .mesh ()
315- source_mesh = extract_unique_domain (operand ) or target_mesh
315+
316+ try :
317+ source_mesh = extract_unique_domain (operand ) or target_mesh
318+ except ValueError :
319+ raise NotImplementedError ("Interpolating an expression defined on multiple meshes is not implemented yet." )
320+
321+ try :
322+ target_mesh = target_mesh .unique ()
323+ source_mesh = source_mesh .unique ()
324+ except RuntimeError :
325+ return MixedInterpolator (expr )
326+
316327 submesh_interp_implemented = (
317328 all (isinstance (m .topology , MeshTopology ) for m in [target_mesh , source_mesh ])
318329 and target_mesh .submesh_ancesters [- 1 ] is source_mesh .submesh_ancesters [- 1 ]
@@ -321,11 +332,8 @@ def get_interpolator(expr: Interpolate) -> Interpolator:
321332 if target_mesh is source_mesh or submesh_interp_implemented :
322333 return SameMeshInterpolator (expr )
323334
324- target_topology = target_mesh .topology
325- source_topology = source_mesh .topology
326-
327- if isinstance (target_topology , VertexOnlyMeshTopology ):
328- if isinstance (source_topology , VertexOnlyMeshTopology ):
335+ if isinstance (target_mesh .topology , VertexOnlyMeshTopology ):
336+ if isinstance (source_mesh .topology , VertexOnlyMeshTopology ):
329337 return VomOntoVomInterpolator (expr )
330338 if target_mesh .geometric_dimension != source_mesh .geometric_dimension :
331339 raise ValueError ("Cannot interpolate onto a VertexOnlyMesh of a different geometric dimension." )
@@ -614,10 +622,19 @@ def _get_tensor(self) -> op2.Mat | Function | Cofunction:
614622 return f
615623
616624 def _get_callable (self , tensor = None , bcs = None ):
617- f = tensor or self ._get_tensor ()
618- op2_tensor = f if isinstance (f , op2 .Mat ) else f .dat
625+ if (isinstance (tensor , Cofunction ) and isinstance (self .dual_arg , Cofunction )) and set (tensor .dat ).intersection (set (self .dual_arg .dat )):
626+ # adjoint one-form case: we need a zero tensor, so if it shares dats with
627+ # the dual_arg we cannot use it directly
628+ f = self ._get_tensor ()
629+ copyout = (partial (f .dat .copy , tensor .dat ),)
630+ else :
631+ f = tensor or self ._get_tensor ()
632+ copyout = ()
619633
634+ op2_tensor = f if isinstance (f , op2 .Mat ) else f .dat
620635 loops = []
636+ if self .access is op2 .INC :
637+ loops .append (op2_tensor .zero )
621638
622639 # Arguments in the operand are allowed to be from a MixedFunctionSpace
623640 # We need to split the target space V and generate separate kernels
@@ -646,6 +663,8 @@ def _get_callable(self, tensor=None, bcs=None):
646663 if bcs and self .rank == 1 :
647664 loops .extend (partial (bc .apply , f ) for bc in bcs )
648665
666+ loops .extend (copyout )
667+
649668 def callable () -> Function | Cofunction | PETSc .Mat | Number :
650669 for l in loops :
651670 l ()
@@ -912,8 +931,6 @@ def _build_interpolation_callables(
912931 if isinstance (tensor , op2 .Mat ):
913932 return parloop , tensor .assemble
914933 else :
915- if access == op2 .INC :
916- copyin += (tensor .zero ,)
917934 return copyin + (parloop , ) + copyout
918935
919936
0 commit comments