-
Notifications
You must be signed in to change notification settings - Fork 2
Fix adjoint interpolation (take 2) #191
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 9 commits
4e404cc
78b9efe
9148fe4
7b61902
21a2c9b
5cb4cd9
9896571
aa301f2
9f8b6bb
3d758a0
6de0c0a
fe602c3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -66,63 +66,53 @@ def transfer(source, target_space, transfer_method="project", **kwargs): | |
| raise TypeError( | ||
| "Second argument must be a FunctionSpace, Function, or Cofunction." | ||
| ) | ||
| if isinstance(source, firedrake.Cofunction): | ||
| return _transfer_adjoint(source, target, transfer_method, **kwargs) | ||
| elif source.function_space() == target.function_space(): | ||
| return target.assign(source) | ||
| if transfer_method == "interpolate": | ||
| return interpolate(source, target, **kwargs) | ||
| else: | ||
| return _transfer_forward(source, target, transfer_method, **kwargs) | ||
| return project(source, target, **kwargs) | ||
|
|
||
|
|
||
| @PETSc.Log.EventDecorator() | ||
| def interpolate(source, target_space, **kwargs): | ||
| r""" | ||
| Overload function :func:`firedrake.__future__.interpolate` to account for the case | ||
| of two mixed function spaces defined on different meshes and for the adjoint | ||
| interpolation operator when applied to :class:`firedrake.cofunction.Cofunction`\s. | ||
|
|
||
| :arg source: the function to be transferred | ||
| :type source: :class:`firedrake.function.Function` or | ||
| :class:`firedrake.cofunction.Cofunction` | ||
| :arg target_space: the function space which we seek to transfer onto, or the | ||
| function or cofunction to use as the target | ||
| :type target_space: :class:`firedrake.functionspaceimpl.FunctionSpace`, | ||
| :class:`firedrake.function.Function` or :class:`firedrake.cofunction.Cofunction` | ||
| :returns: the transferred function | ||
| :rtype: :class:`firedrake.function.Function` or | ||
| :class:`firedrake.cofunction.Cofunction` | ||
|
|
||
| Extra keyword arguments are passed to :func:`firedrake.__future__.interpolate` | ||
| """ | ||
| return transfer(source, target_space, transfer_method="interpolate", **kwargs) | ||
| def _validate_consistent_spaces(Vs, Vt): | ||
| if Vs._dual != Vt._dual: | ||
| raise ValueError("Spaces must be both primal or both dual.") | ||
| if hasattr(Vs, "num_sub_spaces"): | ||
| if not hasattr(Vt, "num_sub_spaces"): | ||
| raise ValueError( | ||
| "Source space has multiple components but target space does not." | ||
| ) | ||
| if Vs.num_sub_spaces() != Vt.num_sub_spaces(): | ||
| raise ValueError( | ||
| "Inconsistent numbers of components in source and target spaces:" | ||
| f" {Vs.num_sub_spaces()} vs. {Vt.num_sub_spaces()}." | ||
| ) | ||
| elif hasattr(Vt, "num_sub_spaces"): | ||
| raise ValueError( | ||
| "Target space has multiple components but source space does not." | ||
| ) | ||
|
|
||
|
|
||
| @PETSc.Log.EventDecorator() | ||
| def project(source, target_space, **kwargs): | ||
| def interpolate(source, target, **kwargs): | ||
| r""" | ||
| Overload function :func:`firedrake.projection.project` to account for the case of | ||
| two mixed function spaces defined on different meshes and for the adjoint | ||
| projection operator when applied to :class:`firedrake.cofunction.Cofunction`\s. | ||
|
|
||
| For details on the approach for achieving boundedness through mass lumping and | ||
| post-processing, see :cite:`Farrell:2009`. | ||
| Overload :func:`firedrake.__future__.interpolate` to account for the case of mixed | ||
| function spaces. | ||
|
|
||
| :arg source: the function to be transferred | ||
| :arg source: the function or cofunction to be transferred | ||
| :type source: :class:`firedrake.function.Function` or | ||
| :class:`firedrake.cofunction.Cofunction` | ||
| :arg target_space: the function space which we seek to transfer onto, or the | ||
| function or cofunction to use as the target | ||
| :type target_space: :class:`firedrake.functionspaceimpl.FunctionSpace`, | ||
| :class:`firedrake.function.Function` or :class:`firedrake.cofunction.Cofunction` | ||
| :returns: the transferred function | ||
| :rtype: :class:`firedrake.function.Function` or | ||
| :arg target: the function or cofunction to use as the target, which is modified in | ||
| place | ||
| :type target: :class:`firedrake.function.Function` or | ||
| :class:`firedrake.cofunction.Cofunction` | ||
| :kwarg bounded: apply mass lumping to the mass matrix to ensure boundedness | ||
| :type bounded: :class:`bool` | ||
|
|
||
| Extra keyword arguments are passed to :func:`firedrake.projection.project`. | ||
| Extra keyword arguments are passed to :func:`firedrake.__future__.interpolate` | ||
|
||
| """ | ||
| return transfer(source, target_space, transfer_method="project", **kwargs) | ||
| _validate_consistent_spaces(source.function_space(), target.function_space()) | ||
| if hasattr(target.function_space(), "num_sub_spaces"): | ||
| for s, t in zip(source.subfunctions, target.subfunctions): | ||
| t.interpolate(s, **kwargs) | ||
| else: | ||
| target.interpolate(source, **kwargs) | ||
|
|
||
|
|
||
| # TODO: Reimplement by introducing a LumpedSupermeshProjector subclass of | ||
|
|
@@ -148,148 +138,46 @@ def _supermesh_project(source, target, bounded=False): | |
|
|
||
|
|
||
| @PETSc.Log.EventDecorator() | ||
| def _transfer_forward(source, target, transfer_method, **kwargs): | ||
| """ | ||
| Apply mesh-to-mesh transfer operator to a Function. | ||
| def project(source, target, bounded=False, **kwargs): | ||
| r""" | ||
| Overload :func:`firedrake.projection.project` to account for the case of mixed | ||
| function spaces. | ||
|
||
|
|
||
| This function extends the functionality of :func:`firedrake.__future__.interpolate` | ||
| and :func:`firedrake.projection.project` to account for mixed spaces. | ||
| For details on the approach for achieving boundedness through mass lumping and | ||
| post-processing, see :cite:`Farrell:2009`. | ||
|
|
||
| :arg source: the Function to be transferred | ||
| :type source: :class:`firedrake.function.Function` | ||
| :arg target: the Function which we seek to transfer onto | ||
| :type target: :class:`firedrake.function.Function` | ||
| :kwarg transfer_method: the method to use for the transfer. Options are | ||
| "interpolate" (default) and "project" | ||
| :type transfer_method: str | ||
| :arg source: the function or cofunction to be transferred | ||
| :type source: :class:`firedrake.function.Function` or | ||
| :class:`firedrake.cofunction.Cofunction` | ||
| :arg target: the function or cofunction to transfer onto, which is modified in | ||
| place | ||
| :type target: :class:`firedrake.function.Function` or | ||
| :class:`firedrake.cofunction.Cofunction` | ||
| :kwarg bounded: apply mass lumping to the mass matrix to ensure boundedness | ||
| (project only) | ||
| :type bounded: :class:`bool` | ||
| :returns: the transferred Function | ||
| :rtype: :class:`firedrake.function.Function` | ||
|
|
||
| Extra keyword arguments are passed to :func:`firedrake.__future__.interpolate` or | ||
| :func:`firedrake.projection.project`. | ||
| Extra keyword arguments are passed to :func:`firedrake.projection.project`. | ||
| """ | ||
| is_project = transfer_method == "project" | ||
| bounded = is_project and kwargs.pop("bounded", False) | ||
| Vs = source.function_space() | ||
| Vt = target.function_space() | ||
| _validate_matching_spaces(Vs, Vt) | ||
| assert isinstance(target, firedrake.Function) | ||
| if hasattr(Vt, "num_sub_spaces"): | ||
| _validate_consistent_spaces(source.function_space(), target.function_space()) | ||
| if not source.function_space()._dual: | ||
| for s, t in zip(source.subfunctions, target.subfunctions): | ||
| if transfer_method == "interpolate": | ||
| t.interpolate(s, **kwargs) | ||
| elif transfer_method == "project": | ||
| if bounded: | ||
| _supermesh_project(s, t, bounded=True) | ||
| else: | ||
| t.project(s, **kwargs) | ||
| else: | ||
| raise ValueError( | ||
| f"Invalid transfer method: {transfer_method}." | ||
| " Options are 'interpolate' and 'project'." | ||
| ) | ||
| else: | ||
| if transfer_method == "interpolate": | ||
| target.interpolate(source, **kwargs) | ||
| elif transfer_method == "project": | ||
| if bounded: | ||
| _supermesh_project(source, target, bounded=True) | ||
| _supermesh_project(s, t, bounded=True) | ||
| else: | ||
| target.project(source, **kwargs) | ||
| else: | ||
| raise ValueError( | ||
| f"Invalid transfer method: {transfer_method}." | ||
| " Options are 'interpolate' and 'project'." | ||
| ) | ||
| return target | ||
|
|
||
|
|
||
| @PETSc.Log.EventDecorator() | ||
| def _transfer_adjoint(target_b, source_b, transfer_method, **kwargs): | ||
| """ | ||
| Apply an adjoint mesh-to-mesh transfer operator to a Cofunction. | ||
|
|
||
| :arg target_b: seed Cofunction from the target space of the forward projection | ||
| :type target_b: :class:`firedrake.cofunction.Cofunction` | ||
| :arg source_b: output Cofunction from the source space of the forward projection | ||
| :type source_b: :class:`firedrake.cofunction.Cofunction` | ||
| :kwarg transfer_method: the method to use for the transfer. Options are | ||
| "interpolate" (default) and "project" | ||
| :type transfer_method: str | ||
| :kwarg bounded: apply mass lumping to the mass matrix to ensure boundedness | ||
| (project only) | ||
| :type bounded: :class:`bool` | ||
| :returns: the transferred Cofunction | ||
| :rtype: :class:`firedrake.cofunction.Cofunction` | ||
|
|
||
| Extra keyword arguments are passed to :func:`firedrake.__future__.interpolate` or | ||
| :func:`firedrake.projection.project`. | ||
| """ | ||
| is_project = transfer_method == "project" | ||
| bounded = is_project and kwargs.pop("bounded", False) | ||
|
|
||
| # Map to Functions to apply the adjoint transfer | ||
| if not isinstance(target_b, firedrake.Function): | ||
| target_b = cofunction2function(target_b) | ||
| if not isinstance(source_b, firedrake.Function): | ||
| source_b = cofunction2function(source_b) | ||
|
|
||
| Vt = target_b.function_space() | ||
| Vs = source_b.function_space() | ||
| if Vs == Vt: | ||
| source_b.assign(target_b) | ||
| return function2cofunction(source_b) | ||
|
|
||
| _validate_matching_spaces(Vs, Vt) | ||
| if hasattr(Vs, "num_sub_spaces"): | ||
| target_b_split = target_b.subfunctions | ||
| source_b_split = source_b.subfunctions | ||
| t.project(s, **kwargs) | ||
| else: | ||
| target_b_split = [target_b] | ||
| source_b_split = [source_b] | ||
|
|
||
| # Apply adjoint transfer operator to each component | ||
| for i, (t_b, s_b) in enumerate(zip(target_b_split, source_b_split)): | ||
| if transfer_method == "interpolate": | ||
| raise NotImplementedError( | ||
| "Adjoint of interpolation operator not implemented." | ||
| ) # TODO (#113) | ||
| elif transfer_method == "project": | ||
| for s, t in zip(source.subfunctions, target.subfunctions): | ||
| sf = cofunction2function(s) | ||
| tf = cofunction2function(t) | ||
| Vs = sf.function_space() | ||
| ksp = petsc4py.KSP().create() | ||
| ksp.setOperators(assemble_mass_matrix(t_b.function_space(), lumped=bounded)) | ||
| mixed_mass = assemble_mixed_mass_matrix(Vt[i], Vs[i]) | ||
| with t_b.dat.vec_ro as tb, s_b.dat.vec_wo as sb: | ||
| residual = tb.copy() | ||
| ksp.solveTranspose(tb, residual) | ||
| mixed_mass.mult(residual, sb) # NOTE: already transposed above | ||
| else: | ||
| raise ValueError( | ||
| f"Invalid transfer method: {transfer_method}." | ||
| " Options are 'interpolate' and 'project'." | ||
| ) | ||
|
|
||
| # Map back to a Cofunction | ||
| return function2cofunction(source_b) | ||
|
|
||
|
|
||
| def _validate_matching_spaces(Vs, Vt): | ||
| if hasattr(Vs, "num_sub_spaces"): | ||
| if not hasattr(Vt, "num_sub_spaces"): | ||
| raise ValueError( | ||
| "Source space has multiple components but target space does not." | ||
| ) | ||
| if Vs.num_sub_spaces() != Vt.num_sub_spaces(): | ||
| raise ValueError( | ||
| "Inconsistent numbers of components in source and target spaces:" | ||
| f" {Vs.num_sub_spaces()} vs. {Vt.num_sub_spaces()}." | ||
| ) | ||
| elif hasattr(Vt, "num_sub_spaces"): | ||
| raise ValueError( | ||
| "Target space has multiple components but source space does not." | ||
| ) | ||
| ksp.setOperators(assemble_mass_matrix(Vs, lumped=bounded)) | ||
| mixed_mass = assemble_mixed_mass_matrix(Vs, tf.function_space()) | ||
| with sf.dat.vec_ro as vs, tf.dat.vec_wo as vt: | ||
| residual = vs.copy() | ||
| ksp.solveTranspose(vs, residual) | ||
| mixed_mass.mult(residual, vt) # NOTE: already transposed above | ||
| function2cofunction(tf, cofunc=t) | ||
|
|
||
|
|
||
| @PETSc.Log.EventDecorator() | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this is no longer in
__future__(the future is now!): firedrakeproject/firedrake#4346There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh nice, thanks. Done in 6de0c0a.