@@ -680,8 +680,10 @@ def _get_callable(self, tensor=None, bcs=None):
680680 assert isinstance (f , Cofunction )
681681
682682 def callable () -> Cofunction :
683- with self .dual_arg .dat .vec_ro as source_vec , f .dat .vec_wo as target_vec :
684- self .mat .handle .multHermitian (source_vec , target_vec )
683+ with self .dual_arg .dat .vec_ro as source_vec :
684+ coeff = self .mat .expr_as_coeff (source_vec )
685+ with coeff .dat .vec_ro as coeff_vec , f .dat .vec_wo as target_vec :
686+ self .mat .handle .multHermitian (coeff_vec , target_vec )
685687 return f
686688 else :
687689 assert isinstance (f , Function )
@@ -1209,6 +1211,8 @@ def __init__(self, interpolator: VomOntoVomInterpolator):
12091211 """The PETSc Star Forest representing the permutation between the VOMs."""
12101212 self .target_space = interpolator .target_space
12111213 """The FunctionSpace being interpolated into."""
1214+ self .target_vom = interpolator .target_mesh
1215+ """The VOM being interpolated to."""
12121216 self .source_vom = interpolator .source_mesh
12131217 """The VOM being interpolated from."""
12141218 self .operand = interpolator .operand
@@ -1280,24 +1284,26 @@ def expr_as_coeff(self, source_vec: PETSc.Vec | None = None) -> Function:
12801284 # so its dat can be sent to the target mesh.
12811285 with stop_annotating ():
12821286 element = self .target_space .ufl_element () # Could be vector/tensor valued
1283- P0DG = FunctionSpace (self .source_vom , element )
12841287 # if we have any arguments in the expression we need to replace
12851288 # them with equivalent coefficients now
1286- coeff_expr = self .operand
12871289 if len (self .arguments ):
12881290 if len (self .arguments ) > 1 :
1289- raise NotImplementedError (
1290- "Can only interpolate expressions with one argument!"
1291- )
1291+ raise NotImplementedError ("Can only interpolate expressions with one argument!" )
12921292 if source_vec is None :
12931293 raise ValueError ("Need to provide a source dat for the argument!" )
1294+
12941295 arg = self .arguments [0 ]
1295- arg_coeff = Function (arg .function_space ())
1296+ source_space = arg .function_space ()
1297+ P0DG = FunctionSpace (self .target_vom if self .is_adjoint else self .source_vom , element )
1298+ arg_coeff = Function (self .target_space if self .is_adjoint else source_space )
12961299 arg_coeff .dat .data_wo [:] = source_vec .getArray (readonly = True ).reshape (
12971300 arg_coeff .dat .data_wo .shape
12981301 )
12991302 coeff_expr = replace (self .operand , {arg : arg_coeff })
1300- coeff = Function (P0DG ).interpolate (coeff_expr )
1303+ coeff = Function (P0DG ).interpolate (coeff_expr )
1304+ else :
1305+ P0DG = FunctionSpace (self .source_vom , element )
1306+ coeff = Function (P0DG ).interpolate (self .operand )
13011307 return coeff
13021308
13031309 def reduce (self , source_vec : PETSc .Vec , target_vec : PETSc .Vec ) -> None :
0 commit comments