-
Notifications
You must be signed in to change notification settings - Fork 0
Add norms module #9
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1 +1,3 @@ | ||
| from adapt_common.norms import * # noqa | ||
| from adapt_common.reduction import * # noqa | ||
| from adapt_common.utility import * # noqa |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,163 @@ | ||
| """Module containing overloaded versions of Firedrake's norm and errornorm functions.""" | ||
|
|
||
| import firedrake as fd | ||
| import ufl | ||
| from firedrake.petsc import PETSc | ||
|
|
||
| from adapt_common.utility import cofunction2function | ||
|
|
||
| __all__ = ["errornorm", "norm"] | ||
|
|
||
|
|
||
| @PETSc.Log.EventDecorator() | ||
| def norm(v, norm_type="L2", condition=None, boundary=False): | ||
| r"""Overload :func:`fd.norms.norm` to allow for :math:`\ell^p` norms. | ||
|
|
||
| Currently supported ``norm_type`` options: | ||
| * ``'l1'`` | ||
| * ``'l2'`` | ||
| * ``'linf'`` | ||
| * ``'L2'`` | ||
| * ``'Linf'`` | ||
| * ``'H1'`` | ||
| * ``'Hdiv'`` | ||
| * ``'Hcurl'`` | ||
| * or any ``'Lp'`` with :math:`p >= 1`. | ||
|
|
||
| Note that this version is case sensitive, i.e. ``'l2'`` and ``'L2'`` will give | ||
| different results in general. | ||
|
|
||
| :arg v: the function to take the norm of | ||
| :type v: :class:`fd.function.Function` or | ||
| :class:`fd.cofunction.Cofunction` | ||
| :kwarg norm_type: the type of norm to use | ||
| :type norm_type: :class:`str` | ||
| :kwarg condition: a UFL condition for specifying a subdomain to compute the norm | ||
| over | ||
| :kwarg boundary: if ``True``, the norm is computed over the domain boundary | ||
| :type boundary: :class:`bool` | ||
| :returns: the norm value | ||
| :rtype: :class:`float` | ||
| """ | ||
| if isinstance(v, fd.Cofunction): | ||
| v = cofunction2function(v) | ||
| condition = condition or fd.Constant(1.0) | ||
| norm_codes = {"l1": 0, "l2": 2, "linf": 3} | ||
| p = 2 | ||
| if norm_type in norm_codes or norm_type == "Linf": | ||
| if boundary: | ||
| not_impl_err = "lp errors on the boundary not yet implemented." | ||
| raise NotImplementedError(not_impl_err) | ||
| v.interpolate(condition * v) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hm, this clobbers the input Function |
||
| with v.dat.vec_ro as vv: | ||
| if norm_type == "Linf": | ||
| return vv.max()[1] | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does this need an abs()? I presume "Linf" is in fact the same as "linf" ? |
||
| else: | ||
| return vv.norm(norm_codes[norm_type]) | ||
| elif norm_type[0] == "l": | ||
| not_impl_err = f"lp norm of order {norm_type[1:]} not supported." | ||
| raise NotImplementedError(not_impl_err) | ||
| else: | ||
| dX = ufl.ds if boundary else ufl.dx | ||
| if norm_type.startswith("L"): | ||
| try: | ||
| p = int(norm_type[1:]) | ||
| except ValueError as exc: | ||
| val_err = f"Unable to interpret '{norm_type}' norm." | ||
| raise ValueError(val_err) from exc | ||
| if p < 1: | ||
| val_err = f"'{norm_type}' norm does not make sense." | ||
| raise ValueError(val_err) | ||
| integrand = ufl.inner(v, v) | ||
| elif norm_type.lower() in ("h1", "hdiv", "hcurl"): | ||
| integrand = { | ||
| "h1": lambda w: ufl.inner(w, w) + ufl.inner(ufl.grad(w), ufl.grad(w)), | ||
| "hdiv": lambda w: ufl.inner(w, w) + ufl.div(w) * ufl.div(w), | ||
| "hcurl": lambda w: ufl.inner(w, w) | ||
| + ufl.inner(ufl.curl(w), ufl.curl(w)), | ||
| }[norm_type.lower()](v) | ||
|
Comment on lines
+73
to
+78
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm probably missing something, but what's the point of the lambda's? If you're worried about the cost of ufl symbolic assembly - I would just change it to an if block (or match case) |
||
| else: | ||
| val_err = f"Unknown norm type '{norm_type}'." | ||
| raise ValueError(val_err) | ||
| return fd.assemble(condition * integrand ** (p / 2) * dX) ** (1 / p) | ||
|
|
||
|
|
||
| @PETSc.Log.EventDecorator() | ||
| def errornorm(u, uh, norm_type="L2", boundary=False, **kwargs): | ||
| r"""Overload :func:`fd.norms.errornorm` to allow for :math:`\ell^p` norms. | ||
|
|
||
| Currently supported ``norm_type`` options: | ||
| * ``'l1'`` | ||
| * ``'l2'`` | ||
| * ``'linf'`` | ||
| * ``'L2'`` | ||
| * ``'Linf'`` | ||
| * ``'H1'`` | ||
| * ``'Hdiv'`` | ||
| * ``'Hcurl'`` | ||
| * or any ``'Lp'`` with :math:`p >= 1`. | ||
|
|
||
| Note that this version is case sensitive, i.e. ``'l2'`` and ``'L2'`` will give | ||
| different results in general. | ||
|
|
||
| :arg u: the 'true' value | ||
| :type u: :class:`fd.function.Function` or | ||
| :class:`fd.cofunction.Cofunction` | ||
| :arg uh: the approximation of the 'truth' | ||
| :type uh: :class:`fd.function.Function` or | ||
| :class:`fd.cofunction.Cofunction` | ||
| :kwarg norm_type: the type of norm to use | ||
| :type norm_type: :class:`str` | ||
| :kwarg boundary: if ``True``, the norm is computed over the domain boundary | ||
| :type boundary: :class:`bool` | ||
| :returns: the error norm value | ||
| :rtype: :class:`float` | ||
|
|
||
| Any other keyword arguments are passed to :func:`fd.norms.errornorm`. | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That's not true it seems |
||
| """ | ||
| if isinstance(u, fd.Cofunction): | ||
| u = cofunction2function(u) | ||
| if isinstance(uh, fd.Cofunction): | ||
| uh = cofunction2function(uh) | ||
| if not isinstance(uh, fd.Function): | ||
| type_err = f"uh should be a Function, is a '{type(uh)}'." | ||
| raise TypeError(type_err) | ||
| if norm_type[0] == "l" and not isinstance(u, fd.Function): | ||
| type_err = f"u should be a Function, is a '{type(u)}'." | ||
| raise TypeError(type_err) | ||
|
|
||
| if len(u.ufl_shape) != len(uh.ufl_shape): | ||
| val_err = "Mismatching rank between u and uh." | ||
| raise ValueError(val_err) | ||
|
|
||
| if isinstance(u, fd.Function): | ||
| degree_u = u.function_space().ufl_element().degree() | ||
| degree_uh = uh.function_space().ufl_element().degree() | ||
| if degree_uh > degree_u: | ||
| fd.logging.warning( | ||
| "Degree of exact solution less than approximation degree" | ||
| ) | ||
|
|
||
| # Case 1: point-wise norms | ||
| if norm_type[0] == "l": | ||
| v = u | ||
| v -= uh | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If I'm not mistaken |
||
|
|
||
| # Case 2: UFL norms for mixed function spaces | ||
| elif hasattr(uh.function_space(), "num_sub_spaces"): | ||
| if norm_type == "L2": | ||
| vv = [ | ||
| uu - uuh | ||
| for uu, uuh in zip(u.subfunctions, uh.subfunctions, strict=False) | ||
| ] | ||
| dX = ufl.ds if boundary else ufl.dx | ||
| return ufl.sqrt(fd.assemble(sum([ufl.inner(v, v) for v in vv]) * dX)) | ||
| else: | ||
| not_impl_err = f"Norm type '{norm_type}' not supported for mixed spaces." | ||
| raise NotImplementedError(not_impl_err) | ||
|
|
||
| # Case 3: UFL norms for non-mixed spaces | ||
| else: | ||
| v = u - uh | ||
|
|
||
| return norm(v, norm_type=norm_type, **kwargs) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,111 @@ | ||
| """Unit tests for overloaded norm and errornorm functions.""" | ||
|
|
||
| import firedrake as fd | ||
| import numpy as np | ||
| import pytest | ||
| import ufl | ||
|
|
||
| from adapt_common.norms import errornorm, norm | ||
|
|
||
|
|
||
| @pytest.fixture(params=["L1", "L2", "L4", "H1", "HCurl"]) | ||
| def integral_norm_type(request): | ||
| """Fixture for integral norm types.""" | ||
| return request.param | ||
|
|
||
|
|
||
| @pytest.fixture(params=["L1", "L2", "L4", "H1", "HCurl", "l1", "l2", "linf"]) | ||
| def norm_type(request): | ||
| """Fixture for all norm types.""" | ||
| return request.param | ||
|
|
||
|
|
||
| @pytest.fixture | ||
| def mesh(): | ||
| """Create a simple unit square mesh for testing.""" | ||
| return fd.UnitSquareMesh(4, 4) | ||
|
|
||
|
|
||
| @pytest.fixture | ||
| def scalar_function(mesh): | ||
| """Create a scalar function on the mesh for testing.""" | ||
| x, y = ufl.SpatialCoordinate(mesh) | ||
| V = fd.FunctionSpace(mesh, "CG", 1) | ||
| return fd.Function(V).interpolate(x**2 + y) | ||
|
|
||
|
|
||
| @pytest.fixture | ||
| def vector_function(mesh): | ||
| """Create a vector function on the mesh for testing.""" | ||
| x, y = ufl.SpatialCoordinate(mesh) | ||
| V = fd.VectorFunctionSpace(mesh, "CG", 1) | ||
| return fd.Function(V).interpolate(ufl.as_vector([y * y, -x * x])) | ||
|
|
||
|
|
||
| def test_boundary_error(scalar_function): | ||
| """Test that boundary error raises NotImplementedError under lp norm.""" | ||
| not_impl_err = "lp errors on the boundary not yet implemented." | ||
| with pytest.raises(NotImplementedError, match=not_impl_err): | ||
| norm(scalar_function, norm_type="l1", boundary=True) | ||
|
|
||
|
|
||
| def test_l1(scalar_function): | ||
| """Test l1 norm computation.""" | ||
| expected = np.sum(np.abs(scalar_function.dat.data)) | ||
| got = norm(scalar_function, norm_type="l1") | ||
| assert np.isclose(expected, got) | ||
|
|
||
|
|
||
| def test_l2(scalar_function): | ||
| """Test l2 norm computation.""" | ||
| expected = np.sqrt(np.sum(scalar_function.dat.data**2)) | ||
| got = norm(scalar_function, norm_type="l2") | ||
| assert np.isclose(expected, got) | ||
|
|
||
|
|
||
| def test_linf(scalar_function): | ||
| """Test linf norm computation.""" | ||
| expected = np.max(scalar_function.dat.data) | ||
| got = norm(scalar_function, norm_type="linf") | ||
| assert np.isclose(expected, got) | ||
|
|
||
|
|
||
| def test_notimplemented_lp_error(scalar_function): | ||
| """Test that lp norm raises NotImplementedError.""" | ||
| not_impl_err = "lp norm of order p not supported." | ||
| with pytest.raises(NotImplementedError, match=not_impl_err): | ||
| norm(scalar_function, norm_type="lp") | ||
|
|
||
|
|
||
| def test_invalid_norm_type_error(scalar_function): | ||
| """Test that invalid norm type raises ValueError.""" | ||
| val_err = "Unknown norm type 'X'." | ||
| with pytest.raises(ValueError, match=val_err): | ||
| norm(scalar_function, norm_type="X") | ||
|
|
||
|
|
||
| def test_consistency_firedrake(scalar_function, integral_norm_type): | ||
| """Test consistency with Firedrake's norm implementation.""" | ||
| expected = fd.norm(scalar_function, norm_type=integral_norm_type) | ||
| got = norm(scalar_function, norm_type=integral_norm_type) | ||
| assert np.isclose(expected, got) | ||
|
|
||
|
|
||
| def test_zero_scalar(scalar_function, norm_type): | ||
| """Test that errornorm returns zero for identical scalar functions.""" | ||
| err = errornorm(scalar_function, scalar_function, norm_type=norm_type) | ||
| assert np.isclose(err, 0.0) | ||
|
|
||
|
|
||
| def test_consistency_errornorm(scalar_function, integral_norm_type): | ||
| """Test consistency of errornorm with Firedrake's implementation.""" | ||
| g = fd.Function(scalar_function.function_space()).interpolate(scalar_function + 1) | ||
| expected = fd.errornorm(scalar_function, g, norm_type=integral_norm_type) | ||
| got = errornorm(scalar_function, g, norm_type=integral_norm_type) | ||
| assert np.isclose(expected, got) | ||
|
|
||
|
|
||
| def test_zero_hdiv(vector_function): | ||
| """Test that errornorm returns zero for identical vector functions in HDiv norm.""" | ||
| err = errornorm(vector_function, vector_function, norm_type="HDiv") | ||
| assert np.isclose(err, 0.0) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.