diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS new file mode 100644 index 00000000..bdcebb7a --- /dev/null +++ b/.github/CODEOWNERS @@ -0,0 +1,15 @@ +# Each line is a file pattern followed by one or more owners. + +# These owners will be the default owners for everything in +# the repo. Unless a later match takes precedence, +# will be requested for review when someone opens a pull request. +* @jwallwork23 @stephankramer + +# Order is important; the last matching pattern takes the most +# precedence. When someone opens a pull request that only +# modifies specified file types, only that owner will be +# requested for a review, not the global owner(s). + +# You can also use email addresses if you prefer. They'll be +# used to look up users just like we do for commit author +# emails. diff --git a/demos/burgers-goal_oriented.py b/demos/burgers-goal_oriented.py index 00beafca..2114f0fc 100644 --- a/demos/burgers-goal_oriented.py +++ b/demos/burgers-goal_oriented.py @@ -8,21 +8,18 @@ # multiple meshes to adapt. We also chose to apply a QoI which integrates in time as # well as space. # -# We copy over the setup as before. The only difference is that we import from -# ``goalie_adjoint`` rather than ``goalie``. :: +# We copy over the setup as before. :: import matplotlib.pyplot as plt from animate.adapt import adapt from animate.metric import RiemannianMetric from firedrake import * -from goalie_adjoint import * +from goalie import * -fields = ["u"] - - -def get_function_spaces(mesh): - return {"u": VectorFunctionSpace(mesh, "CG", 2)} +n = 32 +meshes = [UnitSquareMesh(n, n), UnitSquareMesh(n, n)] +fields = [Field("u", family="Lagrange", degree=2, vector=True)] def get_initial_condition(mesh_seq): @@ -37,7 +34,7 @@ def get_initial_condition(mesh_seq): def get_solver(mesh_seq): def solver(index): - u, u_ = mesh_seq.fields["u"] + u, u_ = mesh_seq.field_functions["u"] # Define constants R = FunctionSpace(mesh_seq[index], "R", 0) @@ -77,7 +74,7 @@ def get_qoi(mesh_seq, i): dt = Function(R).assign(mesh_seq.time_partition.timesteps[i]) def time_integrated_qoi(t): - u = mesh_seq.fields["u"][0] + u = mesh_seq.field_functions["u"][0] return dt * inner(u, u) * ds(2) return time_integrated_qoi @@ -87,11 +84,8 @@ def time_integrated_qoi(t): # as in a `previous demo <./burgers2.py.html>`__, except that we export # every timestep rather than every other timestep. :: -n = 32 -meshes = [UnitSquareMesh(n, n, diagonal="left"), UnitSquareMesh(n, n, diagonal="left")] end_time = 0.5 dt = 1 / n - num_subintervals = len(meshes) time_partition = TimePartition( end_time, @@ -107,7 +101,6 @@ def time_integrated_qoi(t): mesh_seq = GoalOrientedMeshSeq( time_partition, meshes, - get_function_spaces=get_function_spaces, get_initial_condition=get_initial_condition, get_solver=get_solver, get_qoi=get_qoi, @@ -178,7 +171,7 @@ def adaptor(mesh_seq, solutions=None, indicators=None): num_elem = mesh_seq.count_elements() pyrint(f"fixed point iteration {iteration + 1}:") for i, (complexity, ndofs, nelem) in enumerate( - zip(complexities, num_dofs, num_elem) + zip(complexities, num_dofs, num_elem, strict=True) ): pyrint( f" subinterval {i}, complexity: {complexity:4.0f}" @@ -339,7 +332,7 @@ def adaptor(mesh_seq, solutions=None, indicators=None): num_elem = mesh_seq.count_elements() pyrint(f"fixed point iteration {iteration + 1}:") for i, (complexity, ndofs, nelem) in enumerate( - zip(complexities, num_dofs, num_elem) + zip(complexities, num_dofs, num_elem, strict=False) ): pyrint( f" subinterval {i}, complexity: {complexity:4.0f}" @@ -365,7 +358,6 @@ def adaptor(mesh_seq, solutions=None, indicators=None): mesh_seq = GoalOrientedMeshSeq( time_partition, meshes, - get_function_spaces=get_function_spaces, get_initial_condition=get_initial_condition, get_solver=get_solver, get_qoi=get_qoi, diff --git a/demos/burgers-hessian.py b/demos/burgers-hessian.py index 06ffa45a..533831b5 100644 --- a/demos/burgers-hessian.py +++ b/demos/burgers-hessian.py @@ -7,7 +7,12 @@ # we consider the time-dependent case. Moreover, we consider a :class:`MeshSeq` with # multiple subintervals and hence multiple meshes to adapt. # -# As before, we copy over what is now effectively boiler plate to set up our problem. :: +# As before, we copy over what is now effectively boiler plate to set up our problem. +# +# The only difference is that we need to specifically define the initial mesh for each +# subinterval and pass them as a list. When a single mesh is passed to the +# :class:`~.MeshSeq` constructor, it is shallowed copied, which is insufficient for mesh +# adaptation. :: import matplotlib.pyplot as plt from animate.adapt import adapt @@ -16,16 +21,14 @@ from goalie import * -field_names = ["u"] - - -def get_function_spaces(mesh): - return {"u": VectorFunctionSpace(mesh, "CG", 2)} +n = 32 +meshes = [UnitSquareMesh(n, n), UnitSquareMesh(n, n)] +fields = [Field("u", family="Lagrange", degree=2, vector=True)] def get_solver(mesh_seq): def solver(index): - u, u_ = mesh_seq.fields["u"] + u, u_ = mesh_seq.field_functions["u"] # Define constants R = FunctionSpace(mesh_seq[index], "R", 0) @@ -61,24 +64,20 @@ def get_initial_condition(mesh_seq): return {"u": Function(fs).interpolate(as_vector([sin(pi * x), 0]))} -n = 32 -meshes = [UnitSquareMesh(n, n, diagonal="left"), UnitSquareMesh(n, n, diagonal="left")] end_time = 0.5 dt = 1 / n - num_subintervals = len(meshes) time_partition = TimePartition( end_time, num_subintervals, dt, - field_names, + fields, num_timesteps_per_export=2, ) mesh_seq = MeshSeq( time_partition, meshes, - get_function_spaces=get_function_spaces, get_initial_condition=get_initial_condition, get_solver=get_solver, ) @@ -148,7 +147,7 @@ def adaptor(mesh_seq, solutions): num_elem = mesh_seq.count_elements() pyrint(f"fixed point iteration {iteration + 1}:") for i, (complexity, ndofs, nelem) in enumerate( - zip(complexities, num_dofs, num_elem) + zip(complexities, num_dofs, num_elem, strict=True) ): pyrint( f" subinterval {i}, complexity: {complexity:4.0f}" diff --git a/demos/burgers.py b/demos/burgers.py index 3b7943ea..701512fd 100644 --- a/demos/burgers.py +++ b/demos/burgers.py @@ -24,26 +24,28 @@ from goalie import * -# In this problem, we have a single prognostic variable, -# :math:`\mathbf u`. Its name is recorded in a list of -# field names. :: - -field_names = ["u"] - -# For each such field, we need to be able to specify how to -# build a :class:`FunctionSpace`, given some mesh. Since there -# could be more than one field, function spaces are given as a -# dictionary, indexed by the prognostic solution field names. :: +# We begin by defining the two meshes of the unit sequare that we'd like to solve over. +# For simplicity, we just use the same mesh twice: a :math:`32\times32` grid of the unit +# square, with each grid-box divided into right-angled triangles. :: +n = 32 +mesh = UnitSquareMesh(n, n) -def get_function_spaces(mesh): - return {"u": VectorFunctionSpace(mesh, "CG", 2)} +# In the Burgers problem, we have a single prognostic variable, :math:`\mathbf u`. Its +# name and other metadata are recorded in a :class:`~.Field` object. One important piece +# of metadata is the finite element used to define function spaces for the field (given +# some mesh). This can be defined either using the :class:`finat.ufl.FiniteElement` +# class, or using the same arguments as can be passed to +# :class:`firedrake.functionspace.FunctionSpace` (e.g., `mesh`, `family`, `degree`). In +# this case, we use a :math:`\mathbb{P}2` space so specify `family="Lagrange"` and +# `degree=2`.Since Burgers is a vector equation, we need to specify `vector=True`. :: +fields = [Field("u", family="Lagrange", degree=2, vector=True)] # The solution :class:`Function`\s are automatically built on the function spaces given -# by the :func:`get_function_spaces` function and are accessed via the :attr:`fields` -# attribute of the :class:`MeshSeq`. This attribute provides a dictionary of tuples -# containing the current and lagged solutions for each field. +# by the :func:`get_function_spaces` function and are accessed via the +# :attr:`field_functions` attribute of the :class:`MeshSeq`. This attribute provides a +# dictionary of tuples containing the current and lagged solutions for each field. # # In order to solve the PDE, we need to choose a time integration routine and solver # parameters for the underlying linear and nonlinear systems. This is achieved below by @@ -64,7 +66,7 @@ def get_function_spaces(mesh): def get_solver(mesh_seq): def solver(index): # Get the current and lagged solutions - u, u_ = mesh_seq.fields["u"] + u, u_ = mesh_seq.field_functions["u"] # Define constants R = FunctionSpace(mesh_seq[index], "R", 0) @@ -105,37 +107,29 @@ def get_initial_condition(mesh_seq): return {"u": Function(fs).interpolate(as_vector([sin(pi * x), 0]))} -# Now that we have the above functions defined, we move onto the -# concrete parts of the solver. To begin with, we require a -# sequence of meshes, simulation end time and a timestep. :: +# Now that we have the above functions defined, we need to define the time +# discretisation used for the solver. To do this, we create a :class:`TimePartition` for +# the problem with two subintervals. :: -n = 32 -meshes = [ - UnitSquareMesh(n, n), - UnitSquareMesh(n, n), -] end_time = 0.5 dt = 1 / n - -# These can be used to create a :class:`TimePartition` for the -# problem with two subintervals. :: - -num_subintervals = len(meshes) +num_subintervals = 2 time_partition = TimePartition( end_time, num_subintervals, dt, - field_names, + fields, num_timesteps_per_export=2, ) -# Finally, we are able to construct a :class:`MeshSeq` and -# solve Burgers equation over the meshes in sequence. :: +# Finally, we are able to construct a :class:`~.MeshSeq` and solve Burgers equation over +# the meshes in sequence. Note that the second argument can be either a list of meshes +# or just a single mesh. If a single mesh is passed then this will be used for all +# subintervals. :: mesh_seq = MeshSeq( time_partition, - meshes, - get_function_spaces=get_function_spaces, + mesh, get_initial_condition=get_initial_condition, get_solver=get_solver, ) diff --git a/demos/burgers1.py b/demos/burgers1.py index 9fcf63bb..1cc9ec43 100644 --- a/demos/burgers1.py +++ b/demos/burgers1.py @@ -6,32 +6,25 @@ # automatic differentiation functionality in order to # automatically form and solve discrete adjoint problems. # -# We always begin by importing Goalie. Adjoint mode is used -# so that we have access to the :class:`AdjointMeshSeq` class. -# :: +# We always begin by importing Goalie. :: from firedrake import * -from goalie_adjoint import * +from goalie import * -# For ease, the list of field names and functions for obtaining the -# function spaces, solvers, and initial conditions -# are redefined as in the previous demo. The only difference -# is that now we are solving the adjoint problem, which -# requires that the PDE solve is labelled with an -# ``ad_block_tag`` that matches the corresponding prognostic -# variable name. :: +# For ease, the list of fields and functions for obtaining the solvers and initial +# conditions are redefined as in the previous demo. The only difference is that now we +# are solving the adjoint problem, which requires that the PDE solve is labelled with an +# ``ad_block_tag`` that matches the corresponding prognostic variable name. :: -field_names = ["u"] - - -def get_function_spaces(mesh): - return {"u": VectorFunctionSpace(mesh, "CG", 2)} +n = 32 +mesh = UnitSquareMesh(n, n) +fields = [Field("u", family="Lagrange", degree=2, vector=True)] def get_solver(mesh_seq): def solver(index): - u, u_ = mesh_seq.fields["u"] + u, u_ = mesh_seq.field_functions["u"] # Define constants R = FunctionSpace(mesh_seq[index], "R", 0) @@ -83,26 +76,19 @@ def get_initial_condition(mesh_seq): def get_qoi(mesh_seq, i): def end_time_qoi(): - u = mesh_seq.fields["u"][0] + u = mesh_seq.field_functions["u"][0] return inner(u, u) * ds(2) return end_time_qoi -# Now that we have the above functions defined, we move onto the -# concrete parts of the solver, which mimic the original demo. :: +# Next, we define the :class:`~.TimePartition`. In cases where we only solve over a +# single time subinterval (as in this demo), the partition is trivial and we can use the +# :class:`~.TimeInterval` constructor, which requires fewer arguments. :: -n = 32 -mesh = UnitSquareMesh(n, n) end_time = 0.5 dt = 1 / n - -# Another requirement to solve the adjoint problem using -# Goalie is a :class:`TimePartition`. In our case, there is a -# single mesh, so the partition is trivial and we can use the -# :class:`TimeInterval` constructor. :: - -time_partition = TimeInterval(end_time, dt, field_names, num_timesteps_per_export=2) +time_partition = TimeInterval(end_time, dt, fields, num_timesteps_per_export=2) # Finally, we are able to construct an :class:`AdjointMeshSeq` and # thereby call its :meth:`solve_adjoint` method. This computes the QoI @@ -112,7 +98,6 @@ def end_time_qoi(): mesh_seq = AdjointMeshSeq( time_partition, mesh, - get_function_spaces=get_function_spaces, get_initial_condition=get_initial_condition, get_solver=get_solver, get_qoi=get_qoi, diff --git a/demos/burgers2.py b/demos/burgers2.py index 673a0f12..394dcf8f 100644 --- a/demos/burgers2.py +++ b/demos/burgers2.py @@ -5,28 +5,25 @@ # <./burgers1.py.html>`__, but now using two subintervals. There # is still no error estimation or mesh adaptation; the same mesh # is used in each case to verify that the framework works. -# -# Again, begin by importing Goalie with adjoint mode activated. :: from firedrake import * -from goalie_adjoint import * +from goalie import * set_log_level(DEBUG) -# Redefine the ``field_names`` variable from the previous demo, as well as all the -# getter functions. :: - -field_names = ["u"] - +# Redefine the meshes and field metadata as in previous demos, as well as all the +# getter functions. In this case, we make the default `diagonal="left"` keyword argument +# to :class:`~.UnitSquareMesh` explicit. (See later.) :: -def get_function_spaces(mesh): - return {"u": VectorFunctionSpace(mesh, "CG", 2)} +n = 32 +mesh = UnitSquareMesh(n, n, diagonal="left") +fields = [Field("u", family="Lagrange", degree=2, vector=True)] def get_solver(mesh_seq): def solver(index): - u, u_ = mesh_seq.fields["u"] + u, u_ = mesh_seq.field_functions["u"] # Define constants R = FunctionSpace(mesh_seq[index], "R", 0) @@ -64,35 +61,27 @@ def get_initial_condition(mesh_seq): def get_qoi(mesh_seq, i): def end_time_qoi(): - u = mesh_seq.fields["u"][0] + u = mesh_seq.field_functions["u"][0] return inner(u, u) * ds(2) return end_time_qoi -# The solver, initial condition and QoI may be imported from the -# previous demo. The same basic setup is used. The only difference -# is that the :class:`MeshSeq` contains two meshes. :: +# This time, the ``TimePartition`` is defined on **two** subintervals. :: -n = 32 -meshes = [UnitSquareMesh(n, n, diagonal="left"), UnitSquareMesh(n, n, diagonal="left")] end_time = 0.5 dt = 1 / n - -# This time, the ``TimePartition`` is defined on **two** subintervals. :: - -num_subintervals = len(meshes) +num_subintervals = 2 time_partition = TimePartition( end_time, num_subintervals, dt, - field_names, + fields, num_timesteps_per_export=2, ) mesh_seq = AdjointMeshSeq( time_partition, - meshes, - get_function_spaces=get_function_spaces, + mesh, get_initial_condition=get_initial_condition, get_solver=get_solver, get_qoi=get_qoi, diff --git a/demos/burgers_ee.py b/demos/burgers_ee.py index 4a38270a..032f0f1b 100644 --- a/demos/burgers_ee.py +++ b/demos/burgers_ee.py @@ -26,26 +26,32 @@ from firedrake import * -from goalie_adjoint import * +from goalie import * set_log_level(DEBUG) -# Redefine the ``field_names`` variable and the getter functions as in the first -# adjoint Burgers demo. The only difference is the inclusion of the -# :meth:`GoalOrientedMeshSeq.read_forms()` method in the ``get_solver`` function. The -# method is used to communicate the variational form to the mesh sequence object so that -# Goalie can utilise it in the error estimation process described above. :: - -field_names = ["u"] - +# Redefine the meshes, fields and the getter functions as in the first adjoint Burgers +# demo, with two differences: +# +# * We need to specifically define the mesh for each subinterval and pass them as a +# list. When a single mesh is passed to the :class:`~.MeshSeq` constructor, it is +# shallow copied, which is insufficient for the :math:`h`-refinement used in the error +# estimation step. :: +# * We need to call the :meth:`~.GoalOrientedMeshSeq.read_forms()` method in the +# ``get_solver`` function. This is used to communicate the variational form to the +# mesh sequence object so that Goalie can utilise it in the error estimation process +# described above. +# +# :: -def get_function_spaces(mesh): - return {"u": VectorFunctionSpace(mesh, "CG", 2)} +n = 32 +meshes = [UnitSquareMesh(n, n), UnitSquareMesh(n, n)] +fields = [Field("u", family="Lagrange", degree=2, vector=True)] def get_solver(mesh_seq): def solver(index): - u, u_ = mesh_seq.fields["u"] + u, u_ = mesh_seq.field_functions["u"] # Define constants R = FunctionSpace(mesh_seq[index], "R", 0) @@ -86,16 +92,14 @@ def get_initial_condition(mesh_seq): def get_qoi(mesh_seq, i): def end_time_qoi(): - u = mesh_seq.fields["u"][0] + u = mesh_seq.field_functions["u"][0] return inner(u, u) * ds(2) return end_time_qoi -# Next, create a sequence of meshes and a :class:`TimePartition`. :: +# Next, create a :class:`TimePartition`. :: -n = 32 -meshes = [UnitSquareMesh(n, n, diagonal="left"), UnitSquareMesh(n, n, diagonal="left")] end_time = 0.5 dt = 1 / n num_subintervals = len(meshes) @@ -103,7 +107,7 @@ def end_time_qoi(): end_time, num_subintervals, dt, - field_names, + fields, num_timesteps_per_export=2, ) @@ -115,7 +119,6 @@ def end_time_qoi(): mesh_seq = GoalOrientedMeshSeq( time_partition, meshes, - get_function_spaces=get_function_spaces, get_initial_condition=get_initial_condition, get_solver=get_solver, get_qoi=get_qoi, diff --git a/demos/burgers_oo.py b/demos/burgers_oo.py index e7190e09..49973e20 100644 --- a/demos/burgers_oo.py +++ b/demos/burgers_oo.py @@ -20,19 +20,20 @@ from firedrake import * -from goalie_adjoint import * +from goalie import * set_log_level(DEBUG) class BurgersMeshSeq(GoalOrientedMeshSeq): - @staticmethod - def get_function_spaces(mesh): - return {"u": VectorFunctionSpace(mesh, "CG", 2)} + def get_initial_condition(self): + fs = self.function_spaces["u"][0] + x, y = SpatialCoordinate(self[0]) + return {"u": Function(fs).interpolate(as_vector([sin(pi * x), 0]))} def get_solver(self): def solver(index): - u, u_ = self.fields["u"] + u, u_ = self.field_functions["u"] # Define constants R = FunctionSpace(self[index], "R", 0) @@ -66,22 +67,17 @@ def solver(index): return solver - def get_initial_condition(self): - fs = self.function_spaces["u"][0] - x, y = SpatialCoordinate(self[0]) - return {"u": Function(fs).interpolate(as_vector([sin(pi * x), 0]))} - @annotate_qoi def get_qoi(self, i): R = FunctionSpace(self[i], "R", 0) dt = Function(R).assign(self.time_partition.timesteps[i]) def end_time_qoi(): - u = self.fields["u"][0] + u = self.field_functions["u"][0] return inner(u, u) * ds(2) def time_integrated_qoi(t): - u = self.fields["u"][0] + u = self.field_functions["u"][0] return dt * inner(u, u) * ds(2) if self.qoi_type == "end_time": @@ -94,16 +90,19 @@ def time_integrated_qoi(t): # methods have been modified to account for both ``"end_time"`` and # ``"time_integrated"`` QoIs. # -# We apply exactly the same setup as before, except that the -# :class:`BurgersMeshSeq` class is used. :: +# We apply exactly the same setup as before, except that the :class:`BurgersMeshSeq` +# class is used and we again need to specifically define the mesh for each subinterval. +# :: n = 32 -meshes = [UnitSquareMesh(n, n, diagonal="left"), UnitSquareMesh(n, n, diagonal="left")] +meshes = [UnitSquareMesh(n, n), UnitSquareMesh(n, n)] +fields = [Field("u", family="Lagrange", degree=2, vector=True)] + end_time = 0.5 dt = 1 / n num_subintervals = len(meshes) time_partition = TimePartition( - end_time, num_subintervals, dt, ["u"], num_timesteps_per_export=2 + end_time, num_subintervals, dt, fields, num_timesteps_per_export=2 ) mesh_seq = BurgersMeshSeq(time_partition, meshes, qoi_type="time_integrated") solutions, indicators = mesh_seq.indicate_errors( diff --git a/demos/burgers_time_integrated.py b/demos/burgers_time_integrated.py index f8f171bb..ee0f0f88 100644 --- a/demos/burgers_time_integrated.py +++ b/demos/burgers_time_integrated.py @@ -1,21 +1,22 @@ # Adjoint Burgers equation with a time integrated QoI -# ====================================================== +# =================================================== # # So far, we only considered a quantity of interest corresponding to a spatial integral # at the end time. For some problems, it is more suitable to have a QoI which integrates # in time as well as space. # # Begin by importing from Firedrake and Goalie. -from firedrake import * -from goalie_adjoint import * +from firedrake import * -# Redefine the ``get_initial_condition`` and ``get_function_spaces``, functions as in -# the first Burgers demo. :: +from goalie import * +# Redefine the mesh, fields and ``get_initial_condition`` function as in `the previous +# demo <./burgers2.py.html>`__. :: -def get_function_spaces(mesh): - return {"u": VectorFunctionSpace(mesh, "CG", 2)} +n = 32 +mesh = UnitSquareMesh(n, n) +fields = [Field("u", family="Lagrange", degree=2, vector=True)] def get_initial_condition(mesh_seq): @@ -34,7 +35,7 @@ def get_initial_condition(mesh_seq): def get_solver(mesh_seq): def solver(index): - u, u_ = mesh_seq.fields["u"] + u, u_ = mesh_seq.field_functions["u"] # Define constants R = FunctionSpace(mesh_seq[index], "R", 0) @@ -81,23 +82,20 @@ def get_qoi(mesh_seq, i): dt = Function(R).assign(mesh_seq.time_partition.timesteps[i]) def time_integrated_qoi(t): - u = mesh_seq.fields["u"][0] + u = mesh_seq.field_functions["u"][0] return dt * inner(u, u) * ds(2) return time_integrated_qoi -# We use the same mesh setup as in `the previous demo <./burgers2.py.html>`__ and the -# same time partitioning, except that we export every timestep rather than every other -# timestep. :: +# We use the same time partitioning as in `the previous demo <./burgers2.py.html>`__, +# except that we export every timestep rather than every other timestep. :: -n = 32 -meshes = [UnitSquareMesh(n, n, diagonal="left"), UnitSquareMesh(n, n, diagonal="left")] end_time = 0.5 dt = 1 / n -num_subintervals = len(meshes) +num_subintervals = 2 time_partition = TimePartition( - end_time, num_subintervals, dt, ["u"], num_timesteps_per_export=1 + end_time, num_subintervals, dt, fields, num_timesteps_per_export=1 ) # The only difference when defining the :class:`AdjointMeshSeq` is that we specify @@ -105,8 +103,7 @@ def time_integrated_qoi(t): mesh_seq = AdjointMeshSeq( time_partition, - meshes, - get_function_spaces=get_function_spaces, + mesh, get_initial_condition=get_initial_condition, get_solver=get_solver, get_qoi=get_qoi, diff --git a/demos/gray_scott.py b/demos/gray_scott.py index c059405d..984cd79c 100644 --- a/demos/gray_scott.py +++ b/demos/gray_scott.py @@ -11,21 +11,19 @@ from firedrake import * -from goalie_adjoint import * +from goalie import * # The problem is defined on a doubly periodic mesh of squares. :: -field_names = ["ab"] mesh = PeriodicSquareMesh(65, 65, 2.5, quadrilateral=True, direction="both") # We solve for the tracer species using a mixed formulation, with a :math:`\mathbb P1` -# approximation for both components. :: - - -def get_function_spaces(mesh): - V = FunctionSpace(mesh, "CG", 1) - return {"ab": V * V} +# approximation for both components. In this case, it's more convenient to define the +# finite element and pass this directly to the constructor for :class:`~.Field`, rather +# than using its other keyword arguments. :: +p1_element = FiniteElement("Lagrange", quadrilateral, 1) +fields = [Field("ab", finite_element=MixedElement([p1_element, p1_element]))] # The initial conditions are localised within the region :math:`[1, 1.5]^2`. :: @@ -53,7 +51,7 @@ def get_initial_condition(mesh_seq): def get_solver(mesh_seq): def solver(index): - ab, ab_ = mesh_seq.fields["ab"] + ab, ab_ = mesh_seq.field_functions["ab"] # Define constants R = FunctionSpace(mesh_seq[index], "R", 0) @@ -101,7 +99,7 @@ def solver(index): def get_qoi(mesh_seq, index): def qoi(): - ab = mesh_seq.fields["ab"][0] + ab = mesh_seq.field_functions["ab"][0] a, b = split(ab) return a * b**2 * dx @@ -120,7 +118,7 @@ def qoi(): end_time, num_subintervals, dt, - field_names, + fields, num_timesteps_per_export=dt_per_export, subintervals=[ (0.0, 0.001), @@ -136,7 +134,6 @@ def qoi(): mesh_seq = AdjointMeshSeq( time_partition, mesh, - get_function_spaces=get_function_spaces, get_initial_condition=get_initial_condition, get_solver=get_solver, get_qoi=get_qoi, diff --git a/demos/gray_scott_split.py b/demos/gray_scott_split.py index 7cea528c..6d6c30fd 100644 --- a/demos/gray_scott_split.py +++ b/demos/gray_scott_split.py @@ -10,20 +10,14 @@ from firedrake import * -from goalie_adjoint import * +from goalie import * -# This time, we have two fields instead of one, as well as two function spaces. :: +# This time, we have two fields instead of one and so use two separate +# :math:`\mathbb{P}1` spaces rather than a mixed space with two such components. :: -field_names = ["a", "b"] mesh = PeriodicSquareMesh(65, 65, 2.5, quadrilateral=True, direction="both") - - -def get_function_spaces(mesh): - return { - "a": FunctionSpace(mesh, "CG", 1), - "b": FunctionSpace(mesh, "CG", 1), - } - +p1_element = FiniteElement("Lagrange", quadrilateral, 1) +fields = [Field("a", finite_element=p1_element), Field("b", finite_element=p1_element)] # Therefore, the initial condition must be constructed using separate # :class:`Function`\s. :: @@ -50,8 +44,8 @@ def get_initial_condition(mesh_seq): def get_solver(mesh_seq): def solver(index): - a, a_ = mesh_seq.fields["a"] - b, b_ = mesh_seq.fields["b"] + a, a_ = mesh_seq.field_functions["a"] + b, b_ = mesh_seq.field_functions["b"] # Define constants R = FunctionSpace(mesh_seq[index], "R", 0) @@ -104,8 +98,8 @@ def solver(index): def get_qoi(mesh_seq, index): def qoi(): - a = mesh_seq.fields["a"][0] - b = mesh_seq.fields["b"][0] + a = mesh_seq.field_functions["a"][0] + b = mesh_seq.field_functions["b"][0] return a * b**2 * dx return qoi @@ -119,7 +113,7 @@ def qoi(): end_time, num_subintervals, dt, - field_names, + fields, num_timesteps_per_export=dt_per_export, subintervals=[ (0.0, 0.001), @@ -133,7 +127,6 @@ def qoi(): mesh_seq = AdjointMeshSeq( time_partition, mesh, - get_function_spaces=get_function_spaces, get_initial_condition=get_initial_condition, get_solver=get_solver, get_qoi=get_qoi, diff --git a/demos/mantle_convection.py b/demos/mantle_convection.py index e94e263b..b135fac1 100644 --- a/demos/mantle_convection.py +++ b/demos/mantle_convection.py @@ -43,16 +43,21 @@ # The problem is solved simultaneously for the velocity :math:`\mathbf{u}` and pressure # :math:`p` using a *mixed* formulation, which was introduced in a `previous demo on # advection-diffusion reaction <./gray_scott.py.html>`__. - -fields = ["up", "T"] - - -def get_function_spaces(mesh): - V = VectorFunctionSpace(mesh, "CG", 2, name="velocity") - W = FunctionSpace(mesh, "CG", 1, name="pressure") - Z = MixedFunctionSpace([V, W], name="velocity-pressure") - Q = FunctionSpace(mesh, "CG", 1, name="temperature") - return {"up": Z, "T": Q} +# +# To account for the lack of time derivative in the Stokes equations, we set the +# ``unsteady`` keyword argument of the initialiser for the :class:`~.Field` class to +# ``False`` rather than the default ``True`` value to specify that the ``"up"`` +# field is *steady* (i.e. without a time derivative). The ``T`` field is *unsteady* +# (i.e. involves a time derivative) so we can use ``unsteady=True``. Again, given the +# mixed finite element space used for velocity, it is more convenient to define the +# finite elements and pass these directly to the :class:`~.Field` constructor. :: + +p2v_element = VectorElement(FiniteElement("Lagrange", triangle, 2), dim=2) +p1_element = FiniteElement("Lagrange", triangle, 1) +fields = [ + Field("up", finite_element=MixedElement([p2v_element, p1_element]), unsteady=False), + Field("T", finite_element=p1_element, unsteady=True), +] # We must set initial conditions to solve the problem. Note that we define the initial @@ -78,9 +83,9 @@ def solver(index): Z = mesh_seq.function_spaces["up"][index] Q = mesh_seq.function_spaces["T"][index] - up = mesh_seq.fields["up"] + up = mesh_seq.field_functions["up"] u, p = split(up) - T, T_ = mesh_seq.fields["T"] + T, T_ = mesh_seq.field_functions["T"] # Crank-Nicolson time discretisation for temperature Ttheta = 0.5 * (T + T_) @@ -159,25 +164,17 @@ def solver(index): end_time = dt * num_timesteps dt_per_export = [10 for _ in range(num_subintervals)] -# To account for the lack of time derivative in the Stokes equations, we use the -# ``field_types`` argument of the ``TimePartition`` object to specify that the ``"up"`` -# field is *steady* (i.e. without a time derivative) and that the ``T`` field is -# *unsteady* (i.e. involves a time derivative). The order in ``field_types`` must -# match the order of the fields in the ``fields`` list above. - time_partition = TimePartition( end_time, num_subintervals, dt, fields, num_timesteps_per_export=dt_per_export, - field_types=["steady", "unsteady"], ) mesh_seq = MeshSeq( time_partition, meshes, - get_function_spaces=get_function_spaces, get_initial_condition=get_initial_condition, get_solver=get_solver, transfer_method="interpolate", diff --git a/demos/mesh_seq.py b/demos/mesh_seq.py index 8df5c6c8..adc26d6f 100644 --- a/demos/mesh_seq.py +++ b/demos/mesh_seq.py @@ -22,14 +22,14 @@ # Consider the final subinterval from the previous demo. :: end_time = 1.0 -field_names = ["solution"] +fields = [Field("solution", family="Real")] dt = [0.125, 0.0625] subintervals = [(0.0, 0.75), (0.75, 1.0)] time_partition = TimePartition( end_time, len(subintervals), dt, - field_names, + fields, num_timesteps_per_export=[2, 4], subintervals=subintervals, ) diff --git a/demos/ode.py b/demos/ode.py index 04163495..86e53cba 100644 --- a/demos/ode.py +++ b/demos/ode.py @@ -45,33 +45,31 @@ from goalie import * -# Next, create a simple :class:`~.TimeInterval` object to hold information related to -# the time discretisation. This is a simplified version of :class:`~.TimePartition`, -# which only has one subinterval. :: - -end_time = 1 -time_partition = TimeInterval(end_time, dt, "u") - # Much of the following might seem excessive for this example. However, it exists to # allow for the flexibility required in later PDE examples. # -# We need to create a :class:`~.FunctionSpace` for the solution field to live in. Given -# that we have a scalar ODE, the solution is just a real number at each time level. We -# represent this using the degree-0 :math:`R`-space, as follows. A mesh is required to -# define a function space in Firedrake, although what the mesh is doesn't actually -# matter for this example. :: +# We need to be able to create :class:`~.FunctionSpace`\s for the solution field to live +# in. Given that we have a scalar ODE, the mesh can be interpreted as a vertex-only mesh +# with a single vertex and the finite element as a real number, i.e., the degree-0 +# :math:`R`-space. :: +mesh = VertexOnlyMesh(UnitIntervalMesh(1), [[0.5]]) +fields = [Field("u", family="Real", degree=0)] -def get_function_spaces(mesh): - return {"u": FunctionSpace(mesh, "R", 0)} +# Next, create a simple :class:`~.TimeInterval` object to hold information related to +# the time discretisation. This is a simplified version of :class:`~.TimePartition`, +# which only has one subinterval. :: + +end_time = 1 +time_partition = TimeInterval(end_time, dt, fields) # Next, we need to supply the initial condition :math:`u(0) = 1`. We do this by creating # a :class:`~.Function` in the :math:`R`-space and assigning it the value 1. :: -def get_initial_condition(point_seq): - fs = point_seq.function_spaces["u"][0] +def get_initial_condition(mesh_seq): + fs = mesh_seq.function_spaces["u"][0] return {"u": Function(fs).assign(1.0)} @@ -104,15 +102,15 @@ def get_initial_condition(point_seq): # The Forward Euler scheme may be implemented and solved as follows. :: -def get_solver_forward_euler(point_seq): +def get_solver_forward_euler(mesh_seq): def solver(index): - tp = point_seq.time_partition + tp = mesh_seq.time_partition # Get the current and lagged solutions - u, u_ = point_seq.fields["u"] + u, u_ = mesh_seq.field_functions["u"] # Define the (trivial) form - R = point_seq.function_spaces["u"][index] + R = mesh_seq.function_spaces["u"][index] dt = Function(R).assign(tp.timesteps[index]) v = TestFunction(R) F = (u - u_ - dt * u_) * v * dx @@ -135,13 +133,9 @@ def solver(index): return solver -# For this ODE problem, the main driver object is a :class:`~.PointSeq`, which is -# defined in terms of the :class:`~.TimePartition` describing the time discretisation, -# plus the functions defined above. :: - -point_seq = PointSeq( +mesh_seq = MeshSeq( time_partition, - get_function_spaces=get_function_spaces, + mesh, get_initial_condition=get_initial_condition, get_solver=get_solver_forward_euler, ) @@ -152,7 +146,7 @@ def solver(index): # solution field. For the purposes of this demo, we have field ``"u"``, which is a # forward solution. The resulting solution trajectory is a list. :: -solutions = point_seq.solve_forward()["u"]["forward"] +solutions = mesh_seq.solve_forward()["u"]["forward"] # Note that the solution trajectory does not include the initial value, so we prepend # it. We also convert the solution :class:`~.Function`\s to :class:`~.float`\s, for @@ -193,15 +187,15 @@ def solver(index): # .. math:: # \int_0^1 (u_{i+1} - u_{i} - \Delta t u_{i+1}) v \mathrm{d}t, \forall v\in R. # -# To apply Backward Euler we create the :class:`~.PointSeq` in the same way, just with +# To apply Backward Euler we create the :class:`~.MeshSeq` in the same way, just with # `get_solver_forward_euler` substituted for `get_solver_backward_euler`. :: -def get_solver_backward_euler(point_seq): +def get_solver_backward_euler(mesh_seq): def solver(index): - tp = point_seq.time_partition - u, u_ = point_seq.fields["u"] - R = point_seq.function_spaces["u"][index] + tp = mesh_seq.time_partition + u, u_ = mesh_seq.field_functions["u"] + R = mesh_seq.function_spaces["u"][index] dt = Function(R).assign(tp.timesteps[index]) v = TestFunction(R) @@ -222,13 +216,13 @@ def solver(index): return solver -point_seq = PointSeq( +mesh_seq = MeshSeq( time_partition, - get_function_spaces=get_function_spaces, + mesh, get_initial_condition=get_initial_condition, get_solver=get_solver_backward_euler, ) -solutions = point_seq.solve_forward()["u"]["forward"] +solutions = mesh_seq.solve_forward()["u"]["forward"] backward_euler_trajectory = [1] backward_euler_trajectory += [ @@ -261,11 +255,11 @@ def solver(index): # :: -def get_solver_crank_nicolson(point_seq): +def get_solver_crank_nicolson(mesh_seq): def solver(index): - tp = point_seq.time_partition - u, u_ = point_seq.fields["u"] - R = point_seq.function_spaces["u"][index] + tp = mesh_seq.time_partition + u, u_ = mesh_seq.field_functions["u"] + R = mesh_seq.function_spaces["u"][index] dt = Function(R).assign(tp.timesteps[index]) v = TestFunction(R) @@ -287,14 +281,14 @@ def solver(index): return solver -point_seq = PointSeq( +mesh_seq = MeshSeq( time_partition, - get_function_spaces=get_function_spaces, + mesh, get_initial_condition=get_initial_condition, get_solver=get_solver_crank_nicolson, ) -solutions = point_seq.solve_forward()["u"]["forward"] +solutions = mesh_seq.solve_forward()["u"]["forward"] crank_nicolson_trajectory = [1] crank_nicolson_trajectory += [ float(sol) for subinterval in solutions for sol in subinterval diff --git a/demos/point_discharge2d-goal_oriented.py b/demos/point_discharge2d-goal_oriented.py index db22f79c..9642dcee 100644 --- a/demos/point_discharge2d-goal_oriented.py +++ b/demos/point_discharge2d-goal_oriented.py @@ -7,8 +7,7 @@ # `another previous demo <./point_discharge2d.py.html>`__ to provide the first # exposition of goal-oriented mesh adaptation in these demos. # -# We copy over the setup as before. The only difference is that we import from -# `goalie_adjoint` rather than `goalie`. :: +# We copy over the setup as before. :: import matplotlib.colors as mcolors import matplotlib.pyplot as plt @@ -17,13 +16,10 @@ from firedrake import * from matplotlib import ticker -from goalie_adjoint import * +from goalie import * -field_names = ["c"] - - -def get_function_spaces(mesh): - return {"c": FunctionSpace(mesh, "CG", 1)} +mesh = RectangleMesh(50, 10, 50, 10) +fields = [Field("c", family="Lagrange", degree=1, unsteady=False)] def source(mesh): @@ -35,7 +31,7 @@ def source(mesh): def get_solver(mesh_seq): def solver(index): function_space = mesh_seq.function_spaces["c"][index] - c = mesh_seq.fields["c"] + c = mesh_seq.field_functions["c"] h = CellSize(mesh_seq[index]) S = source(mesh_seq[index]) @@ -72,7 +68,7 @@ def solver(index): def get_qoi(mesh_seq, index): def qoi(): - c = mesh_seq.fields["c"] + c = mesh_seq.field_functions["c"] x, y = SpatialCoordinate(mesh_seq[index]) xr, yr, rr = 20, 7.5, 0.5 kernel = conditional((x - xr) ** 2 + (y - yr) ** 2 < rr**2, 1, 0) @@ -84,12 +80,10 @@ def qoi(): # Since we want to do goal-oriented mesh adaptation, we use a # :class:`GoalOrientedMeshSeq`. :: -mesh = RectangleMesh(50, 10, 50, 10) -time_partition = TimeInstant(field_names) +time_partition = TimeInstant(fields) mesh_seq = GoalOrientedMeshSeq( time_partition, mesh, - get_function_spaces=get_function_spaces, get_solver=get_solver, get_qoi=get_qoi, qoi_type="steady", @@ -317,7 +311,6 @@ def adaptor(mesh_seq, solutions, indicators): mesh_seq = GoalOrientedMeshSeq( time_partition, mesh, - get_function_spaces=get_function_spaces, get_solver=get_solver, get_qoi=get_qoi, qoi_type="steady", diff --git a/demos/point_discharge2d-hessian.py b/demos/point_discharge2d-hessian.py index 19928b75..98140b2b 100644 --- a/demos/point_discharge2d-hessian.py +++ b/demos/point_discharge2d-hessian.py @@ -21,13 +21,11 @@ # We again consider the "point discharge with diffusion" test case from the # `previous demo <./point_discharge2d.py.html>`__, approximating the tracer -# concentration :math:`c` in :math:`\mathbb P1` space. :: +# concentration :math:`c` in :math:`\mathbb P1` space. We start with a relatively coarse +# initial mesh. :: -field_names = ["c"] - - -def get_function_spaces(mesh): - return {"c": FunctionSpace(mesh, "CG", 1)} +mesh = RectangleMesh(50, 10, 50, 10) +fields = [Field("c", family="Lagrange", degree=1, unsteady=False)] def source(mesh): @@ -39,7 +37,7 @@ def source(mesh): def get_solver(mesh_seq): def solver(index): function_space = mesh_seq.function_spaces["c"][index] - c = mesh_seq.fields["c"] + c = mesh_seq.field_functions["c"] h = CellSize(mesh_seq[index]) S = source(mesh_seq[index]) @@ -71,15 +69,13 @@ def solver(index): return solver -# Take a relatively coarse initial mesh, a :class:`TimeInstant` (since we have a -# steady-state problem), and put everything together in a :class:`MeshSeq`. :: +# Take a :class:`TimeInstant` (since we have a steady-state problem), and put everything +# together in a :class:`MeshSeq`. :: -mesh = RectangleMesh(50, 10, 50, 10) -time_partition = TimeInstant(field_names) +time_partition = TimeInstant(fields) mesh_seq = MeshSeq( time_partition, mesh, - get_function_spaces=get_function_spaces, get_solver=get_solver, ) diff --git a/demos/point_discharge2d.py b/demos/point_discharge2d.py index db9fcea7..a241d41f 100644 --- a/demos/point_discharge2d.py +++ b/demos/point_discharge2d.py @@ -28,16 +28,12 @@ from firedrake import * -from goalie_adjoint import * +from goalie import * # We solve the advection-diffusion problem in :math:`\mathbb P1` space. :: -field_names = ["c"] - - -def get_function_spaces(mesh): - return {"c": FunctionSpace(mesh, "CG", 1)} - +mesh = RectangleMesh(200, 40, 50, 10) +fields = [Field("c", family="Lagrange", degree=1, unsteady=False)] # Point sources are difficult to represent in numerical models. Here we # follow :cite:`Wallwork:2022` in using a Gaussian approximation. Let @@ -68,7 +64,7 @@ def source(mesh): # # where :math:`h` measures cell size. # -# Note that :attr:`mesh_seq.fields` now returns a single +# Note that :attr:`mesh_seq.field_functions` now returns a single # :class:`~firedrake.function.Function` object since the problem is steady, so there is # no notion of a lagged solution, unlike in previous (time-dependent) demos. # With these ingredients, we can now define the :meth:`get_solver` method. Don't forget @@ -81,7 +77,7 @@ def source(mesh): def get_solver(mesh_seq): def solver(index): function_space = mesh_seq.function_spaces["c"][index] - c = mesh_seq.fields["c"] + c = mesh_seq.field_functions["c"] h = CellSize(mesh_seq[index]) S = source(mesh_seq[index]) @@ -127,7 +123,7 @@ def solver(index): def get_qoi(mesh_seq, index): def qoi(): - c = mesh_seq.fields["c"] + c = mesh_seq.field_functions["c"] x, y = SpatialCoordinate(mesh_seq[index]) xr, yr, rr = 20, 7.5, 0.5 kernel = conditional((x - xr) ** 2 + (y - yr) ** 2 < rr**2, 1, 0) @@ -139,8 +135,7 @@ def qoi(): # Finally, we can set up the problem. Instead of using a :class:`TimePartition`, # we use the subclass :class:`TimeInstant`, whose only input is the field list. :: -mesh = RectangleMesh(200, 40, 50, 10) -time_partition = TimeInstant(field_names) +time_partition = TimeInstant(fields) # When creating the :class:`MeshSeq`, we need to set the ``"qoi_type"`` to # ``"steady"``. :: @@ -148,7 +143,6 @@ def qoi(): mesh_seq = GoalOrientedMeshSeq( time_partition, mesh, - get_function_spaces=get_function_spaces, get_solver=get_solver, get_qoi=get_qoi, qoi_type="steady", diff --git a/demos/solid_body_rotation.py b/demos/solid_body_rotation.py index ee04337c..8443e76e 100644 --- a/demos/solid_body_rotation.py +++ b/demos/solid_body_rotation.py @@ -32,30 +32,23 @@ # curve of discontinuities. The test case was introduced in # :cite:`LeVeque:1996`. # -# As usual, we import from Firedrake and Goalie, with -# adjoint mode activated. :: +# As usual, we import from Firedrake and Goalie. :: from firedrake import * -from goalie_adjoint import * - -# For simplicity, we use a :math:`\mathbb P1` space for the -# concentration field. The domain of interest is again the -# unit square, in this case shifted to have its centre at -# the origin. :: - -field_names = ["c"] - - -def get_function_spaces(mesh): - return {"c": FunctionSpace(mesh, "CG", 1)} +from goalie import * +# The domain of interest is again the unit square, in this case shifted to have its +# centre at the origin. For simplicity, we use a :math:`\mathbb P1` space for the +# concentration field. :: mesh = UnitSquareMesh(40, 40) coords = mesh.coordinates.copy(deepcopy=True) coords.interpolate(coords - as_vector([0.5, 0.5])) mesh = Mesh(coords) +fields = [Field("c", family="Lagrange", degree=1)] + # Next, let's define the initial condition, to get a # better idea of the problem at hand. :: @@ -100,28 +93,20 @@ def get_initial_condition(mesh_seq): time_partition = TimeInterval( end_time, dt, - field_names, + fields, num_timesteps_per_export=25, ) -# For the purposes of plotting, we set up a :class:`MeshSeq` with -# only the :meth:`get_function_spaces` and :meth:`get_initial_condition` -# methods implemented. :: +# For the purposes of plotting, we set up a :class:`MeshSeq` with only the +# :meth:`get_initial_condition` method implemented. :: import matplotlib.pyplot as plt from firedrake.pyplot import tricontourf -mesh_seq = MeshSeq( - time_partition, - mesh, - get_function_spaces=get_function_spaces, - get_initial_condition=get_initial_condition, -) - -c_init = mesh_seq.get_initial_condition()["c"] +mesh_seq = MeshSeq(time_partition, mesh, get_initial_condition=get_initial_condition) fig, axes = plt.subplots() -tc = tricontourf(c_init, axes=axes) +tc = tricontourf(mesh_seq.get_initial_condition()["c"], axes=axes) fig.colorbar(tc) axes.set_aspect("equal") plt.tight_layout() @@ -143,7 +128,7 @@ def get_initial_condition(mesh_seq): def get_solver(mesh_seq): def solver(index): V = mesh_seq.function_spaces["c"][index] - c, c_ = mesh_seq.fields["c"] + c, c_ = mesh_seq.field_functions["c"] # Define velocity field x, y = SpatialCoordinate(mesh) @@ -195,7 +180,7 @@ def solver(index): def get_qoi(mesh_seq, index): def qoi(): - c = mesh_seq.fields["c"][0] + c = mesh_seq.field_functions["c"][0] x, y = SpatialCoordinate(mesh_seq[index]) x0, y0, r0 = 0.0, 0.25, 0.15 ball = conditional((x - x0) ** 2 + (y - y0) ** 2 < r0**2, 1.0, 0.0) @@ -209,7 +194,6 @@ def qoi(): mesh_seq = AdjointMeshSeq( time_partition, mesh, - get_function_spaces=get_function_spaces, get_initial_condition=get_initial_condition, get_solver=get_solver, get_qoi=get_qoi, diff --git a/demos/time_partition.py b/demos/time_partition.py index 3450eeab..fb0a0b1f 100644 --- a/demos/time_partition.py +++ b/demos/time_partition.py @@ -36,7 +36,7 @@ # * the end time; # * the number of subintervals; # * the timestep on each subinterval; -# * a list of field names for the solution components. +# * a list fields for the solution components. # # If the start time is not set then it is # assumed to be zero. @@ -51,8 +51,11 @@ end_time = 1.0 num_subintervals = 1 dt = 0.125 -field_names = ["solution"] +fields = [Field("solution", family="Real")] +# The :class:`~.Field` class accepts keyword arguments to customise more than just the +# name and finite element family, which we demonstrate in later demos. +# # With these definitions, we should get # one subinterval of :math:`(0,1]` containing # eight timesteps. When constructing a @@ -61,7 +64,7 @@ # mode. This is specified using :func:`set_log_level`. :: set_log_level(DEBUG) -tp = TimePartition(end_time, num_subintervals, dt, field_names) +tp = TimePartition(end_time, num_subintervals, dt, fields) # Notice that one of the things which is printed # out is ``num_timesteps_per_export``, which controls @@ -81,9 +84,7 @@ # than one subinterval. :: num_subintervals = 2 -tp = TimePartition( - end_time, num_subintervals, dt, field_names, num_timesteps_per_export=2 -) +tp = TimePartition(end_time, num_subintervals, dt, fields, num_timesteps_per_export=2) # In some problems, the dynamics evolve such # that different timesteps are suitable during @@ -92,9 +93,7 @@ # timesteps corresponding to each subinterval. :: dt = [0.125, 0.0625] -tp = TimePartition( - end_time, num_subintervals, dt, field_names, num_timesteps_per_export=2 -) +tp = TimePartition(end_time, num_subintervals, dt, fields, num_timesteps_per_export=2) # Note that this means that there are more # exports in the second subinterval than the first. @@ -102,7 +101,7 @@ # ``num_timesteps_per_export`` as a list. :: tp = TimePartition( - end_time, num_subintervals, dt, field_names, num_timesteps_per_export=[2, 4] + end_time, num_subintervals, dt, fields, num_timesteps_per_export=[2, 4] ) # So far, we have assumed that the subintervals @@ -116,7 +115,7 @@ end_time, num_subintervals, dt, - field_names, + fields, num_timesteps_per_export=[2, 4], subintervals=subintervals, ) diff --git a/goalie/__init__.py b/goalie/__init__.py index af7db6e4..8265f88f 100644 --- a/goalie/__init__.py +++ b/goalie/__init__.py @@ -6,8 +6,10 @@ from goalie.metric import * # noqa from goalie.mesh_seq import * # noqa from goalie.options import * # noqa -from goalie.point_seq import * # noqa from goalie.function_data import * # noqa +from goalie.field import * # noqa from goalie.error_estimation import * # noqa +from goalie.adjoint import * # noqa +from goalie.go_mesh_seq import * # noqa __version__ = "0.1" diff --git a/goalie/adjoint.py b/goalie/adjoint.py index d55e0372..8a1a2476 100644 --- a/goalie/adjoint.py +++ b/goalie/adjoint.py @@ -14,7 +14,6 @@ from .function_data import AdjointSolutionData from .log import pyrint from .mesh_seq import MeshSeq -from .utility import AttrDict __all__ = ["AdjointMeshSeq", "annotate_qoi"] @@ -189,18 +188,24 @@ def get_checkpoints(self, solver_kwargs=None, run_final_subinterval=False): return checkpoints @PETSc.Log.EventDecorator() - def get_solve_blocks(self, field, subinterval): + def get_solve_blocks(self, fieldname, subinterval): r""" Get all blocks of the tape corresponding to solve steps for prognostic solution field on a given subinterval. - :arg field: name of the prognostic solution field - :type field: :class:`str` + :arg fieldname: name of the prognostic solution field + :type fieldname: :class:`str` :arg subinterval: subinterval index :type subinterval: :class:`int` :returns: list of solve blocks :rtype: :class:`list` of :class:`pyadjoint.block.Block`\s """ + field = self._get_field_metadata(fieldname) + if not field.solved_for: + raise ValueError( + f"Cannot retrieve solve blocks for field '{fieldname}' because it isn't" + " solved for." + ) blocks = pyadjoint.get_working_tape().get_blocks() if len(blocks) == 0: self.warning("Tape has no blocks!") @@ -216,25 +221,25 @@ def get_solve_blocks(self, field, subinterval): solve_blocks = [ block for block in solve_blocks - if isinstance(block.tag, str) and block.tag.startswith(field) + if isinstance(block.tag, str) and block.tag.startswith(fieldname) ] N = len(solve_blocks) if N == 0: self.warning( - f"No solve blocks associated with field '{field}'." + f"No solve blocks associated with field '{fieldname}'." " Has ad_block_tag been used correctly?" ) return solve_blocks self.debug( - f"Field '{field}' on subinterval {subinterval} has {N} solve blocks." + f"Field '{fieldname}' on subinterval {subinterval} has {N} solve blocks." ) # Check FunctionSpaces are consistent across solve blocks - element = self.function_spaces[field][subinterval].ufl_element() + element = self.function_spaces[fieldname][subinterval].ufl_element() for block in solve_blocks: if element != block.function_space.ufl_element(): raise ValueError( - f"Solve block list for field '{field}' contains mismatching" + f"Solve block list for field '{fieldname}' contains mismatching" f" elements: {element} vs. {block.function_space.ufl_element()}." ) @@ -243,7 +248,7 @@ def get_solve_blocks(self, field, subinterval): if num_timesteps > N: raise ValueError( f"Number of timesteps exceeds number of solve blocks for field" - f" '{field}' on subinterval {subinterval}: {num_timesteps} > {N}." + f" '{fieldname}' on subinterval {subinterval}: {num_timesteps} > {N}." ) # Check the number of timesteps is divisible by the number of solve blocks @@ -251,18 +256,18 @@ def get_solve_blocks(self, field, subinterval): if not np.isclose(np.round(ratio), ratio): raise ValueError( "Number of timesteps is not divisible by number of solve blocks for" - f" field '{field}' on subinterval {subinterval}: {num_timesteps} vs." - f" {N}." + f" field '{fieldname}' on subinterval {subinterval}: {num_timesteps}" + f" vs. {N}." ) return solve_blocks - def _output(self, field, subinterval, solve_block): + def _output(self, fieldname, subinterval, solve_block): """ For a given solve block and solution field, get the block's outputs corresponding to the solution from the current timestep. - :arg field: field of interest - :type field: :class:`str` + :arg fieldname: name of the field of interest + :type fieldname: :class:`str` :arg subinterval: subinterval index :type subinterval: :class:`int` :arg solve_block: taped solve block @@ -271,7 +276,7 @@ def _output(self, field, subinterval, solve_block): :rtype: :class:`firedrake.function.Function` """ # TODO #93: Inconsistent return value - can be None - fs = self.function_spaces[field][subinterval] + fs = self.function_spaces[fieldname][subinterval] # Loop through the solve block's outputs candidates = [] @@ -285,7 +290,7 @@ def _output(self, field, subinterval, solve_block): # Look for Functions whose name matches that of the field # NOTE: Here we assume that the user has set this correctly in their # get_solver method - if not out.output.name() == field: + if not out.output.name() == fieldname: continue # Add to the list of candidates @@ -297,21 +302,21 @@ def _output(self, field, subinterval, solve_block): elif len(candidates) > 1: raise AttributeError( "Cannot determine a unique output index for the solution associated" - f" with field '{field}' out of {len(candidates)} candidates." + f" with field '{fieldname}' out of {len(candidates)} candidates." ) elif not self.steady: raise AttributeError( - f"Solve block for field '{field}' on subinterval {subinterval} has no" - " outputs." + f"Solve block for field '{fieldname}' on subinterval {subinterval} has" + " no outputs." ) - def _dependency(self, field, subinterval, solve_block): + def _dependency(self, fieldname, subinterval, solve_block): """ For a given solve block and solution field, get the block's dependency which corresponds to the solution from the previous timestep. - :arg field: field of interest - :type field: :class:`str` + :arg fieldname: name of the field of interest + :type fieldname: :class:`str` :arg subinterval: subinterval index :type subinterval: :class:`int` :arg solve_block: taped solve block @@ -320,9 +325,10 @@ def _dependency(self, field, subinterval, solve_block): :rtype: :class:`firedrake.function.Function` """ # TODO #93: Inconsistent return value - can be None - if self.field_types[field] == "steady": + field = self._get_field_metadata(fieldname) + if not field.unsteady: return - fs = self.function_spaces[field][subinterval] + fs = self.function_spaces[fieldname][subinterval] # Loop through the solve block's dependencies candidates = [] @@ -336,7 +342,7 @@ def _dependency(self, field, subinterval, solve_block): # Look for Functions whose name is the lagged version of the field's # NOTE: Here we assume that the user has set this correctly in their # get_solver method - if not dep.output.name() == f"{field}_old": + if not dep.output.name() == f"{fieldname}_old": continue # Add to the list of candidates @@ -348,12 +354,13 @@ def _dependency(self, field, subinterval, solve_block): elif len(candidates) > 1: raise AttributeError( "Cannot determine a unique dependency index for the lagged solution" - f" associated with field '{field}' out of {len(candidates)} candidates." + f" associated with field '{fieldname}' out of {len(candidates)}" + " candidates." ) elif not self.steady: raise AttributeError( - f"Solve block for field '{field}' on subinterval {subinterval} has no" - " dependencies." + f"Solve block for field '{fieldname}' on subinterval {subinterval} has" + " no dependencies." ) def _create_solutions(self): @@ -428,12 +435,16 @@ def _solve_adjoint( self.J = 0 if get_adj_values: - for field in self.fields: - self.solutions.extract(layout="field")[field]["adj_value"] = [] - for i, fs in enumerate(self.function_spaces[field]): - self.solutions.extract(layout="field")[field]["adj_value"].append( + for fieldname in self.field_names: + self.solutions.extract(layout="field")[fieldname]["adj_value"] = [] + for i, fs in enumerate(self.function_spaces[fieldname]): + self.solutions.extract(layout="field")[fieldname][ + "adj_value" + ].append( [ - firedrake.Cofunction(fs.dual(), name=f"{field}_adj_value") + firedrake.Cofunction( + fs.dual(), name=f"{fieldname}_adj_value" + ) for j in range(tp.num_exports_per_subinterval[i] - 1) ] ) @@ -453,16 +464,15 @@ def wrapped_solver(subinterval, initial_condition_map, **kwargs): All keyword arguments are passed to the solver. """ - copy_map = AttrDict( - { - field: initial_condition.copy(deepcopy=True) - for field, initial_condition in initial_condition_map.items() - } - ) - self._controls = list(map(pyadjoint.Control, copy_map.values())) + + # Stash a version of the above map as Controls + self._controls = { + fieldname: pyadjoint.Control(function) + for fieldname, function in initial_condition_map.items() + } # Reinitialise fields and assign initial conditions - self._reinitialise_fields(copy_map) + self._reinitialise_fields(initial_condition_map) return solver(subinterval, **kwargs) @@ -500,8 +510,10 @@ def wrapped_solver(subinterval, initial_condition_map, **kwargs): # Final solution is used as the initial condition for the next subinterval checkpoint = { - field: sol[0] if self.field_types[field] == "unsteady" else sol - for field, sol in self.fields.items() + fieldname: solution_function[0] + if self._get_field_metadata(fieldname).unsteady + else solution_function + for fieldname, solution_function in self.field_functions.items() } # Get seed vector for reverse propagation @@ -514,45 +526,52 @@ def wrapped_solver(subinterval, initial_condition_map, **kwargs): self.warning("Zero QoI. Is it implemented as intended?") pyadjoint.pause_annotation() else: - for field, fs in self.function_spaces.items(): - checkpoint[field].block_variable.adj_value = self._transfer( - seeds[field], fs[i] + for fieldname in self.solution_names: + checkpoint[fieldname].block_variable.adj_value = self._transfer( + seeds[fieldname], self.function_spaces[fieldname][i] ) # Update adjoint solver kwargs - for field in self.fields: - for block in self.get_solve_blocks(field, i): + for fieldname in self.solution_names: + for block in self.get_solve_blocks(fieldname, i): block.adj_kwargs.update(adj_solver_kwargs) # Solve adjoint problem tape = pyadjoint.get_working_tape() with PETSc.Log.Event("goalie.AdjointMeshSeq.solve_adjoint.evaluate_adj"): - controls = pyadjoint.enlisting.Enlist(self._controls) + controls = pyadjoint.enlisting.Enlist(list(self._controls.values())) with pyadjoint.stop_annotating(): - with tape.marked_nodes(controls): - tape.evaluate_adj(markings=True) + with tape.marked_control_dependents(controls): + with tape.marked_functional_dependencies(self.J): + tape.evaluate_adj(markings=True) # Compute the gradient on the first subinterval if i == 0 and compute_gradient: - self._gradient = controls.delist( - [control.get_derivative() for control in controls] - ) + self._gradient = { + field: control.get_derivative() + for field, control in zip( + self._controls.keys(), controls, strict=True + ) + } # Loop over prognostic variables - for field, fs in self.function_spaces.items(): + for fieldname in self.solution_names: + field = self._get_field_metadata(fieldname) + # Get solve blocks - solve_blocks = self.get_solve_blocks(field, i) + solve_blocks = self.get_solve_blocks(fieldname, i) num_solve_blocks = len(solve_blocks) if num_solve_blocks == 0: raise ValueError( "Looks like no solves were written to tape!" " Does the solution depend on the initial condition?" ) - if fs[0].ufl_element() != solve_blocks[0].function_space.ufl_element(): + finite_element = field.get_element(self.meshes[i]) + sb_element0 = solve_blocks[0].function_space.ufl_element() + if finite_element != sb_element0: raise ValueError( - f"Solve block list for field '{field}' contains mismatching" - f" finite elements: ({fs[0].ufl_element()} vs. " - f" {solve_blocks[0].function_space.ufl_element()})" + f"Solve block list for field '{fieldname}' contains mismatching" + f" finite elements: ({finite_element} vs. {sb_element0})" ) # Detect whether we have a steady problem @@ -564,15 +583,15 @@ def wrapped_solver(subinterval, initial_condition_map, **kwargs): if len(solve_blocks[::stride]) >= num_exports: self.warning( "More solve blocks than expected:" - f" ({len(solve_blocks[::stride])} > {num_exports-1})." + f" ({len(solve_blocks[::stride])} > {num_exports - 1})." ) # Update forward and adjoint solution data based on block dependencies # and outputs - solutions = self.solutions.extract(layout="field")[field] + solutions = self.solutions.extract(layout="field")[fieldname] for j, block in enumerate(reversed(solve_blocks[::-stride])): # Current forward solution is determined from outputs - out = self._output(field, i, block) + out = self._output(fieldname, i, block) if out is not None: solutions.forward[i][j].assign(out.saved_output) @@ -580,7 +599,7 @@ def wrapped_solver(subinterval, initial_condition_map, **kwargs): solutions.adjoint[i][j].assign(block.adj_sol) # Lagged forward solution comes from dependencies - dep = self._dependency(field, i, block) + dep = self._dependency(fieldname, i, block) if not self.steady and dep is not None: solutions.forward_old[i][j].assign(dep.saved_output) @@ -611,27 +630,28 @@ def wrapped_solver(subinterval, initial_condition_map, **kwargs): # Check non-zero adjoint solution/value if np.isclose(norm(solutions.adjoint[i][0]), 0.0): self.warning( - f"Adjoint solution for field '{field}' on {self.th(i)}" + f"Adjoint solution for field '{fieldname}' on {self.th(i)}" " subinterval is zero." ) if get_adj_values and np.isclose(norm(solutions.adj_value[i][0]), 0.0): self.warning( - f"Adjoint action for field '{field}' on {self.th(i)}" + f"Adjoint action for field '{fieldname}' on {self.th(i)}" " subinterval is zero." ) # Get adjoint action on each subinterval with pyadjoint.stop_annotating(): - for field, control in zip(self.fields, self._controls): - seeds[field] = firedrake.Cofunction( - self.function_spaces[field][i].dual() - ) + for fieldname in self.solution_names: + control = self._controls[fieldname] + field = self._get_field_metadata(fieldname) + function_space = self.function_spaces[fieldname][i] + seeds[fieldname] = firedrake.Cofunction(function_space.dual()) if control.block_variable.adj_value is not None: - seeds[field].assign(control.block_variable.adj_value) - if not self.steady and np.isclose(norm(seeds[field]), 0.0): + seeds[fieldname].assign(control.block_variable.adj_value) + if field.unsteady and np.isclose(norm(seeds[fieldname]), 0.0): self.warning( - f"Adjoint action for field '{field}' on {self.th(i)}" - " subinterval is zero." + f"Adjoint action for field '{fieldname}' on" + f" {self.th(i)} subinterval is zero." ) yield self.solutions @@ -739,7 +759,7 @@ def check_qoi_convergence(self): qoi_, qoi = self.qoi_values[-2:] if abs(qoi - qoi_) < self.params.qoi_rtol * abs(qoi_): pyrint( - f"QoI converged after {self.fp_iteration+1} iterations" + f"QoI converged after {self.fp_iteration + 1} iterations" f" under relative tolerance {self.params.qoi_rtol}." ) return True diff --git a/goalie/field.py b/goalie/field.py new file mode 100644 index 00000000..c94a1121 --- /dev/null +++ b/goalie/field.py @@ -0,0 +1,147 @@ +import ufl +from finat.ufl import ( + FiniteElementBase, + VectorElement, +) +from firedrake.functionspace import FunctionSpace, make_scalar_element +from firedrake.utility_meshes import UnitIntervalMesh + + +class Field: + """ + A class to represent a field. + """ + + def __init__( + self, + name, + finite_element=None, + vector=None, + solved_for=True, + unsteady=True, + **kwargs, + ): + """ + Constructs all the necessary attributes for the field object. + + The finite element for the Field should be set either using the `finite_element` + keyword argument or a combination of the `family`, `degree`, `vfamily`, + `vdegree`, and/or `variant` keyword arguments. For details on these arguments, + see :class:`firedrake.functionspace.FunctionSpace`. If the `finite_element` + keyword argument is specified, these other arguments are ignored. + + To account for mixed and tensor elements, please fully specify the element and + pass it via the `finite_element` keyword argument. + + :arg name: The name of the field. + :type name: :class:`str` + :kwarg finite_element: The finite element associated with the field (default is + Real space on an interval). + :type finite_element: :class:`~.FiniteElement` + :kwarg vector: Is the element a vector element? (default is False) + :type vector: :class:`bool` + :arg solved_for: Indicates if the field is to be solved for (default is True). + :type solved_for: :class:`bool` + :arg unsteady: Indicates if the field is time-dependent (default is True). + :type unsteady: :class:`bool` + """ + assert isinstance(name, str), "Field name must be a string." + self.name = name + if finite_element is not None: + if not isinstance(finite_element, FiniteElementBase): + raise TypeError( + "Field finite element must be a FiniteElement, MixedElement," + " VectorElement, or TensorElement object." + ) + if vector is not None: + raise ValueError( + "The finite_element and vector arguments cannot be used in" + " conjunction." + ) + elif kwargs.get("family") is None: + raise ValueError("Either the finite_element or family must be specified.") + self.finite_element = finite_element + self.vector = False if vector is None else vector + self.family = kwargs.pop("family", None) + self.degree = kwargs.pop("degree", None) + self.vfamily = kwargs.pop("vfamily", None) + self.vdegree = kwargs.pop("vdegree", None) + self.variant = kwargs.pop("variant", None) + if kwargs: + raise ValueError(f"Unexpected keyword argument '{list(kwargs.keys())[0]}'.") + assert isinstance(solved_for, bool), "'solved_for' argument must be a bool" + self.solved_for = solved_for + assert isinstance(unsteady, bool), "'unsteady' argument must be a bool" + self.unsteady = unsteady + + def __str__(self): + return f"Field({self.name})" + + def __repr__(self): + if self.finite_element is not None: + element_str = self.finite_element + else: + _element = self.get_element(UnitIntervalMesh(1)) + element_str = str(_element).replace("interval", "unknown cell type") + + return ( + f"Field('{self.name}', {element_str}, solved_for={self.solved_for}," + f" unsteady={self.unsteady})" + ) + + def __eq__(self, other): + if not isinstance(other, Field): + return False + return ( + self.name == other.name + and self.finite_element == other.finite_element + and self.solved_for == other.solved_for + and self.unsteady == other.unsteady + ) + + def __ne__(self, other): + return not self.__eq__(other) + + def get_element(self, mesh): + """ + Given a mesh, return the finite element associated with the field. + + :arg mesh: The mesh to determine the cell from. + :type mesh: :class:`~.firedrake.mesh.MeshGeometry` + :return: The finite element associated with the field. + :rtype: An appropriate subclass of :class:`~.FiniteElementBase` + """ + if self.finite_element is not None: + if isinstance(self.finite_element.cell, ufl.cell.CellSequence): + for cell in self.finite_element.cell.cells: + assert cell == mesh.coordinates.ufl_element().cell + else: + assert self.finite_element.cell == mesh.coordinates.ufl_element().cell + return self.finite_element + + finite_element = make_scalar_element( + mesh, + self.family, + self.degree, + self.vfamily, + self.vdegree, + self.variant, + ) + + if self.vector: + finite_element = VectorElement( + finite_element, dim=finite_element.cell.topological_dimension + ) + return finite_element + + def get_function_space(self, mesh): + """ + Given a mesh, return the function space associated with the field. + + :arg mesh: The mesh to determine the cell from. + :type mesh: :class:`~.firedrake.mesh.MeshGeometry` + :return: The function space associated with the field. + :rtype: :class:`~firedrake.functionspaceimpl.FunctionSpace` + """ + finite_element = self.get_element(mesh) + return FunctionSpace(mesh, finite_element) diff --git a/goalie/function_data.py b/goalie/function_data.py index ae182a3b..de7397af 100644 --- a/goalie/function_data.py +++ b/goalie/function_data.py @@ -20,8 +20,11 @@ class FunctionData(ABC): - """ + r""" Abstract base class for classes holding field data. + + Note that any :class:`~.Field`\s with `solved_for=False` will not be included in the + field data. """ @abstractmethod @@ -33,6 +36,11 @@ def __init__(self, time_partition, function_spaces): discretise the problem in space """ self.time_partition = time_partition + self.solution_names = [ + fieldname + for fieldname, field in time_partition.field_metadata.items() + if field.solved_for + ] self.function_spaces = function_spaces self._data = None self.labels = self._label_dict[ @@ -44,19 +52,19 @@ def _create_data(self): tp = self.time_partition self._data = AttrDict( { - field: AttrDict( + fieldname: AttrDict( { label: [ [ - ffunc.Function(fs, name=f"{field}_{label}") + ffunc.Function(fs, name=f"{fieldname}_{label}") for j in range(tp.num_exports_per_subinterval[i] - 1) ] - for i, fs in enumerate(self.function_spaces[field]) + for i, fs in enumerate(self.function_spaces[fieldname]) ] for label in self.labels } ) - for field in tp.field_names + for fieldname in self.solution_names } ) @@ -73,7 +81,11 @@ def _data_by_field(self): return self._data def __getitem__(self, key): - return self._data_by_field[key] + try: + return self._data_by_field[key] + except KeyError as ke: + errmsg = f"Field '{key}' is not associated with {type(self)} object." + raise ValueError(errmsg) from ke def items(self): return self._data_by_field.items() @@ -86,11 +98,13 @@ def _data_by_label(self): of the doubly-nested dictionary are doubly-nested lists, which retain the default layout: indexed first by subinterval and then by export. """ - tp = self.time_partition return AttrDict( { label: AttrDict( - {f: self._data_by_field[f][label] for f in tp.field_names} + { + fieldname: self._data_by_field[fieldname][label] + for fieldname in self.solution_names + } ) for label in self.labels } @@ -109,13 +123,13 @@ def _data_by_subinterval(self): return [ AttrDict( { - field: AttrDict( + fieldname: AttrDict( { - label: self._data_by_field[field][label][subinterval] + label: self._data_by_field[fieldname][label][subinterval] for label in self.labels } ) - for field in tp.field_names + for fieldname in self.solution_names } ) for subinterval in range(tp.num_subintervals) @@ -216,7 +230,8 @@ def _export_vtk(self, output_fpath, export_field_types, initial_condition=None): outfile = VTKFile(output_fpath, adaptive=True) if initial_condition is not None: ics = [] - for field, ic in sorted(initial_condition.items()): + for fieldname in sorted(self.solution_names): + ic = initial_condition[fieldname] for field_type in export_field_types: icc = ic.copy(deepcopy=True) # If the function space is mixed, rename and append each @@ -226,12 +241,12 @@ def _export_vtk(self, output_fpath, export_field_types, initial_condition=None): if field_type != "forward": sf = sf.copy(deepcopy=True) sf.assign(float("nan")) - sf.rename(f"{field}[{idx}]_{field_type}") + sf.rename(f"{fieldname}[{idx}]_{field_type}") ics.append(sf) else: if field_type != "forward": icc.assign(float("nan")) - icc.rename(f"{field}_{field_type}") + icc.rename(f"{fieldname}_{field_type}") ics.append(icc) outfile.write(*ics, time=tp.subintervals[0][0]) @@ -242,16 +257,18 @@ def _export_vtk(self, output_fpath, export_field_types, initial_condition=None): + (j + 1) * tp.timesteps[i] * tp.num_timesteps_per_export[i] ) fs = [] - for field in sorted(tp.field_names): - mixed = hasattr(self.function_spaces[field][0], "num_sub_spaces") + for fieldname in sorted(self.solution_names): + mixed = hasattr( + self.function_spaces[fieldname][0], "num_sub_spaces" + ) for field_type in export_field_types: - f = self._data[field][field_type][i][j].copy(deepcopy=True) + f = self._data[fieldname][field_type][i][j].copy(deepcopy=True) if mixed: for idx, sf in enumerate(f.subfunctions): - sf.rename(f"{field}[{idx}]_{field_type}") + sf.rename(f"{fieldname}[{idx}]_{field_type}") fs.append(sf) else: - f.rename(f"{field}_{field_type}") + f.rename(f"{fieldname}_{field_type}") fs.append(f) outfile.write(*fs, time=time) @@ -263,68 +280,69 @@ def _export_h5(self, output_fpath, export_field_types, initial_condition=None): tp = self.time_partition # Mesh names must be unique - mesh_names = [fs.mesh().name for fs in self.function_spaces[tp.field_names[0]]] + fieldname0 = self.solution_names[0] + mesh_names = [fs.mesh().name for fs in self.function_spaces[fieldname0]] rename_meshes = len(set(mesh_names)) != len(mesh_names) with CheckpointFile(output_fpath, "w") as outfile: if initial_condition is not None: - for field, ic in initial_condition.items(): - outfile.save_function(ic, name=f"{field}_initial") + for fieldname, ic in initial_condition.items(): + outfile.save_function(ic, name=f"{fieldname}_initial") for i in range(tp.num_subintervals): if rename_meshes: mesh_name = f"mesh_{i}" - mesh = self.function_spaces[tp.field_names[0]][i].mesh() + mesh = self.function_spaces[self.solution_names[0]][i].mesh() mesh.name = mesh_name mesh.topology_dm.name = mesh_name - for field in tp.field_names: + for fieldname in self.solution_names: for field_type in export_field_types: - name = f"{field}_{field_type}" + name = f"{fieldname}_{field_type}" for j in range(tp.num_exports_per_subinterval[i] - 1): - f = self._data[field][field_type][i][j] + f = self._data[fieldname][field_type][i][j] outfile.save_function(f, name=name, idx=j) - def transfer(self, target, method="interpolate"): + def transfer(self, other, method="interpolate"): """ - Transfer all functions from this :class:`~.FunctionData` object to the target + Transfer all functions from this :class:`~.FunctionData` object to the other :class:`~.FunctionData` object by interpolation, projection or prolongation. - :arg target: the target :class:`~.FunctionData` object to which to transfer the + :arg other: the other :class:`~.FunctionData` object to which to transfer the data - :type target: :class:`~.FunctionData` + :type other: :class:`~.FunctionData` :arg method: the transfer method to use. Either 'interpolate', 'project' or 'prolong' :type method: :class:`str` """ stp = self.time_partition - ttp = target.time_partition + otp = other.time_partition if method not in ["interpolate", "project", "prolong"]: raise ValueError( f"Transfer method '{method}' not supported." " Supported methods are 'interpolate', 'project', and 'prolong'." ) - if stp.num_subintervals != ttp.num_subintervals: + if stp.num_subintervals != otp.num_subintervals: raise ValueError( "Source and target have different numbers of subintervals." ) - if stp.num_exports_per_subinterval != ttp.num_exports_per_subinterval: + if stp.num_exports_per_subinterval != otp.num_exports_per_subinterval: raise ValueError( "Source and target have different numbers of exports per subinterval." ) - common_fields = set(stp.field_names) & set(ttp.field_names) + common_fields = set(self.solution_names) & set(other.solution_names) if not common_fields: raise ValueError("No common fields between source and target.") - common_labels = set(self.labels) & set(target.labels) + common_labels = set(self.labels) & set(other.labels) if not common_labels: raise ValueError("No common labels between source and target.") - for field in common_fields: + for fieldname in common_fields: for label in common_labels: for i in range(stp.num_subintervals): for j in range(stp.num_exports_per_subinterval[i] - 1): - source_function = self._data[field][label][i][j] - target_function = target._data[field][label][i][j] + source_function = self._data[fieldname][label][i][j] + target_function = other._data[fieldname][label][i][j] if method == "interpolate": target_function.interpolate(source_function) elif method == "project": @@ -334,7 +352,7 @@ def transfer(self, target, method="interpolate"): class ForwardSolutionData(FunctionData): - """ + r""" Class representing solution data for general forward problems. For a given exported timestep, the field types are: @@ -342,6 +360,9 @@ class ForwardSolutionData(FunctionData): * ``'forward'``: the forward solution after taking the timestep; * ``'forward_old'``: the forward solution before taking the timestep (provided the problem is not steady-state). + + Note that any :class:`~.Field`\s with `solved_for=False` will not be included in the + field data. """ def __init__(self, *args, **kwargs): @@ -353,7 +374,7 @@ def __init__(self, *args, **kwargs): class AdjointSolutionData(FunctionData): - """ + r""" Class representing solution data for general adjoint problems. For a given exported timestep, the field types are: @@ -364,6 +385,9 @@ class AdjointSolutionData(FunctionData): * ``'adjoint'``: the adjoint solution after taking the timestep; * ``'adjoint_next'``: the adjoint solution before taking the timestep backwards (provided the problem is not steady-state). + + Note that any :class:`~.Field`\s with `solved_for=False` will not be included in the + field data. """ def __init__(self, *args, **kwargs): @@ -375,11 +399,15 @@ def __init__(self, *args, **kwargs): class IndicatorData(FunctionData): - """ + r""" Class representing error indicator data. Note that this class has a single dictionary with the field name as the key, rather than a doubly-nested dictionary. + + Note that any :class:`~.Field`\s with `solved_for=False` will not be included in the + field data. (It doesn't usually make sense to compute error indicators for those + anyway.) """ def __init__(self, time_partition, meshes): @@ -389,13 +417,12 @@ def __init__(self, time_partition, meshes): :arg meshes: the list of meshes used to discretise the problem in space """ self._label_dict = dict.fromkeys(("steady", "unsteady"), ("error_indicator",)) - super().__init__( - time_partition, - { - key: [ffs.FunctionSpace(mesh, "DG", 0) for mesh in meshes] - for key in time_partition.field_names - }, - ) + solution_spaces = { + fieldname: [ffs.FunctionSpace(mesh, "DG", 0) for mesh in meshes] + for fieldname, field in time_partition.field_metadata.items() + if field.solved_for + } + super().__init__(time_partition, solution_spaces) @property def _data_by_field(self): @@ -408,8 +435,8 @@ def _data_by_field(self): self._create_data() return AttrDict( { - field: self._data[field]["error_indicator"] - for field in self.time_partition.field_names + fieldname: self._data[fieldname]["error_indicator"] + for fieldname in self.solution_names } ) @@ -428,8 +455,12 @@ def _data_by_subinterval(self): subinterval. Entries of the list are dictionaries, keyed by field label. Entries of the dictionaries are lists of field data, indexed by export. """ - tp = self.time_partition return [ - AttrDict({f: self._data_by_field[f][subinterval] for f in tp.field_names}) - for subinterval in range(tp.num_subintervals) + AttrDict( + { + fieldname: self._data_by_field[fieldname][subinterval] + for fieldname in self.solution_names + } + ) + for subinterval in range(self.time_partition.num_subintervals) ] diff --git a/goalie/go_mesh_seq.py b/goalie/go_mesh_seq.py index 4caa5592..2986a354 100644 --- a/goalie/go_mesh_seq.py +++ b/goalie/go_mesh_seq.py @@ -8,14 +8,17 @@ import numpy as np import ufl from animate.interpolation import interpolate +from animate.utility import function_data_sum from firedrake import Function, FunctionSpace, MeshHierarchy, TransferManager from firedrake.petsc import PETSc from .adjoint import AdjointMeshSeq from .error_estimation import get_dwr_indicator +from .field import Field from .function_data import IndicatorData from .log import pyrint from .options import GoalOrientedAdaptParameters +from .time_partition import TimePartition __all__ = ["GoalOrientedMeshSeq"] @@ -40,16 +43,18 @@ def read_forms(self, forms_dictionary): values are the UFL forms :type forms_dictionary: :class:`dict` """ - for field, form in forms_dictionary.items(): - if field not in self.fields: + for fieldname, form in forms_dictionary.items(): + if fieldname not in self.solution_names: raise ValueError( - f"Unexpected field '{field}' in forms dictionary." - f" Expected one of {self.time_partition.field_names}." - ) - if not isinstance(form, ufl.Form): - raise TypeError( - f"Expected a UFL form for field '{field}', not '{type(form)}'." + f"Unexpected field '{fieldname}' in forms dictionary." + f" Expected one of {self.solution_names}." ) + if self.field_metadata[fieldname].solved_for: + if not isinstance(form, ufl.Form): + raise TypeError( + f"Expected a UFL form for field '{fieldname}', not" + f" '{type(form)}'." + ) self._forms = forms_dictionary @property @@ -86,29 +91,29 @@ def _detect_changing_coefficients(self, export_idx): if export_idx == 0: # Copy coefficients at subinterval's first export timestep self._prev_form_coeffs = { - field: deepcopy(form.coefficients()) - for field, form in self.forms.items() + fieldname: deepcopy(form.coefficients()) + for fieldname, form in self.forms.items() + } + self._changed_form_coeffs = { + fieldname: {} for fieldname in self.solution_names } - self._changed_form_coeffs = {field: {} for field in self.fields} else: # Store coefficients that have changed since the previous export timestep - for field in self.fields: + for fieldname, form in self.forms.items(): # Coefficients at the current timestep - coeffs = self.forms[field].coefficients() + coeffs = form.coefficients() for coeff_idx, (coeff, init_coeff) in enumerate( - zip(coeffs, self._prev_form_coeffs[field]) + zip(coeffs, self._prev_form_coeffs[fieldname], strict=True) ): # Skip solution fields since they are stored separately - if coeff.name().split("_old")[0] in self.time_partition.field_names: + if coeff.name().split("_old")[0] in self.function_spaces: continue - if not np.allclose( - coeff.vector().array(), init_coeff.vector().array() - ): - if coeff_idx not in self._changed_form_coeffs[field]: - self._changed_form_coeffs[field][coeff_idx] = { + if not np.allclose(coeff.dat.data_ro, init_coeff.dat.data_ro): + if coeff_idx not in self._changed_form_coeffs[fieldname]: + self._changed_form_coeffs[fieldname][coeff_idx] = { 0: deepcopy(init_coeff) } - self._changed_form_coeffs[field][coeff_idx][export_idx] = ( + self._changed_form_coeffs[fieldname][coeff_idx][export_idx] = ( deepcopy(coeff) ) # Use the current coeff for comparison in the next timestep @@ -147,11 +152,33 @@ def get_enriched_mesh_seq(self, enrichment_method="p", num_enrichments=1): else: meshes = self.meshes + # Apply p-refinement + tp = self.time_partition + if enrichment_method == "p": + field_metadata = {} + for fieldname, field in self.field_metadata.items(): + element = field.get_element(meshes[0]) + element = element.reconstruct(degree=element.degree() + num_enrichments) + field_metadata[fieldname] = Field( + fieldname, + finite_element=element, + solved_for=field.solved_for, + unsteady=field.unsteady, + ) + tp = TimePartition( + tp.end_time, + tp.num_subintervals, + tp.timesteps, + field_metadata, + num_timesteps_per_export=tp.num_timesteps_per_export, + start_time=tp.start_time, + subintervals=tp.subintervals, + ) + # Construct object to hold enriched spaces enriched_mesh_seq = type(self)( - self.time_partition, + tp, meshes, - get_function_spaces=self._get_function_spaces, get_initial_condition=self._get_initial_condition, get_solver=self._get_solver, get_qoi=self._get_qoi, @@ -159,18 +186,6 @@ def get_enriched_mesh_seq(self, enrichment_method="p", num_enrichments=1): ) enriched_mesh_seq._update_function_spaces() - # Apply p-refinement - if enrichment_method == "p": - for label, fs in enriched_mesh_seq.function_spaces.items(): - for n, _space in enumerate(fs): - element = _space.ufl_element() - element = element.reconstruct( - degree=element.degree() + num_enrichments - ) - enriched_mesh_seq._fs[label][n] = FunctionSpace( - enriched_mesh_seq.meshes[n], element - ) - return enriched_mesh_seq @staticmethod @@ -262,16 +277,19 @@ def indicate_errors( # Get Functions u, u_, u_star, u_star_next, u_star_e = {}, {}, {}, {}, {} enriched_spaces = { - f: enriched_mesh_seq.function_spaces[f][i] for f in self.fields + fieldname: enriched_mesh_seq.function_spaces[fieldname][i] + for fieldname in self.field_functions } - for f, fs_e in enriched_spaces.items(): - if self.field_types[f] == "steady": - u[f] = enriched_mesh_seq.fields[f] + for fieldname, fs_e in enriched_spaces.items(): + field = self._get_field_metadata(fieldname) + if field.unsteady: + u[fieldname] = enriched_mesh_seq.field_functions[fieldname][0] + u_[fieldname] = enriched_mesh_seq.field_functions[fieldname][1] else: - u[f], u_[f] = enriched_mesh_seq.fields[f] - u_star[f] = Function(fs_e) - u_star_next[f] = Function(fs_e) - u_star_e[f] = Function(fs_e) + u[fieldname] = enriched_mesh_seq.field_functions[fieldname] + u_star[fieldname] = Function(fs_e) + u_star_next[fieldname] = Function(fs_e) + u_star_e[fieldname] = Function(fs_e) # Loop over each timestep for j in range(self.time_partition.num_exports_per_subinterval[i] - 1): @@ -280,44 +298,55 @@ def indicate_errors( # latter fields from the previous timestep. Therefore, we must transfer # the lagged solution of latter fields as if they were the current # timestep solutions. This assumes that the order of fields being solved - # for in get_solver is the same as their order in self.fields - for f_next in self.time_partition.field_names[1:]: - transfer(self.solutions[f_next][FWD_OLD][i][j], u[f_next]) + # for in get_solver is the same as their order in self.field_functions + for fieldname in self.solution_names: + transfer(self.solutions[fieldname][FWD_OLD][i][j], u[fieldname]) # Loop over each strongly coupled field - for f in self.fields: - # Transfer solutions associated with the current field f - transfer(self.solutions[f][FWD][i][j], u[f]) - if self.field_types[f] == "unsteady": - transfer(self.solutions[f][FWD_OLD][i][j], u_[f]) - transfer(self.solutions[f][ADJ][i][j], u_star[f]) - transfer(self.solutions[f][ADJ_NEXT][i][j], u_star_next[f]) + for fieldname in self.solution_names: + solutions = self.solutions[fieldname] + enriched_solutions = enriched_mesh_seq.solutions[fieldname] + + # Transfer solutions associated with the current field + transfer(solutions[FWD][i][j], u[fieldname]) + field = self._get_field_metadata(fieldname) + if field.unsteady: + transfer(solutions[FWD_OLD][i][j], u_[fieldname]) + transfer(solutions[ADJ][i][j], u_star[fieldname]) + transfer(solutions[ADJ_NEXT][i][j], u_star_next[fieldname]) # Combine adjoint solutions as appropriate - u_star[f].assign(0.5 * (u_star[f] + u_star_next[f])) - u_star_e[f].assign( + u_star[fieldname].assign( + 0.5 * (u_star[fieldname] + u_star_next[fieldname]) + ) + u_star_e[fieldname].assign( 0.5 * ( - enriched_mesh_seq.solutions[f][ADJ][i][j] - + enriched_mesh_seq.solutions[f][ADJ_NEXT][i][j] + enriched_solutions[ADJ][i][j] + + enriched_solutions[ADJ_NEXT][i][j] ) ) - u_star_e[f] -= u_star[f] + u_star_e[fieldname] -= u_star[fieldname] # Update other time-dependent form coefficients if they changed # since the previous export timestep - emseq = enriched_mesh_seq - if not self.steady and emseq._changed_form_coeffs[f]: - for idx, coeffs in emseq._changed_form_coeffs[f].items(): + changed_coeffs = enriched_mesh_seq._changed_form_coeffs + if not self.steady and changed_coeffs: + for idx, coeffs in changed_coeffs[fieldname].items(): if j in coeffs: - emseq.forms[f].coefficients()[idx].assign(coeffs[j]) + form = enriched_mesh_seq.forms[fieldname] + form.coefficients()[idx].assign(coeffs[j]) # Evaluate error indicator - indi_e = indicator_fn(enriched_mesh_seq.forms[f], u_star_e[f]) + indi_e = indicator_fn( + enriched_mesh_seq.forms[fieldname], u_star_e[fieldname] + ) # Transfer back to the base space indi = self._transfer(indi_e, P0_spaces[i]) indi.interpolate(abs(indi)) - self.indicators[f][i][j].interpolate(ufl.max_value(indi, 1.0e-16)) + self.indicators[fieldname][i][j].interpolate( + ufl.max_value(indi, 1.0e-16) + ) return self.solutions, self.indicators @@ -338,20 +367,20 @@ def error_estimate(self, absolute_value=False): f"Expected 'absolute_value' to be a bool, not '{type(absolute_value)}'." ) estimator = 0 - for field, by_field in self.indicators.items(): - if field not in self.time_partition.field_names: - raise ValueError( - f"Key '{field}' does not exist in the TimePartition provided." - ) - assert not isinstance(by_field, Function) and isinstance(by_field, Iterable) - for by_mesh, dt in zip(by_field, self.time_partition.timesteps): + for fieldname in self.solution_names: + by_field = self.indicators[fieldname] + assert not isinstance(by_field, Function) + assert isinstance(by_field, Iterable) + for by_mesh, dt in zip( + by_field, self.time_partition.timesteps, strict=True + ): assert not isinstance(by_mesh, Function) and isinstance( by_mesh, Iterable ) for indicator in by_mesh: if absolute_value: indicator.interpolate(abs(indicator)) - estimator += dt * indicator.vector().gather().sum() + estimator += dt * function_data_sum(indicator) return estimator def check_estimator_convergence(self): @@ -372,8 +401,9 @@ def check_estimator_convergence(self): ee_, ee = self.estimator_values[-2:] if abs(ee - ee_) < self.params.estimator_rtol * abs(ee_): pyrint( - f"Error estimator converged after {self.fp_iteration+1} iterations" - f" under relative tolerance {self.params.estimator_rtol}." + f"Error estimator converged after {self.fp_iteration + 1}" + " iterations under relative tolerance" + f" {self.params.estimator_rtol}." ) return True return False diff --git a/goalie/mesh_seq.py b/goalie/mesh_seq.py index 3b24d16e..c86998d4 100644 --- a/goalie/mesh_seq.py +++ b/goalie/mesh_seq.py @@ -8,8 +8,9 @@ import numpy as np from animate.interpolation import transfer from animate.quality import QualityMeasure -from animate.utility import Mesh +from animate.utility import Mesh, function_data_max from firedrake.adjoint import pyadjoint +from firedrake.mesh import MeshSequenceGeometry from firedrake.petsc import PETSc from firedrake.pyplot import triplot @@ -35,8 +36,6 @@ def __init__(self, time_partition, initial_meshes, **kwargs): :arg initial_meshes: a list of meshes corresponding to the subinterval of the time partition, or a single mesh to use for all subintervals :type initial_meshes: :class:`list` or :class:`~.MeshGeometry` - :kwarg get_function_spaces: a function as described in - :meth:`~.MeshSeq.get_function_spaces` :kwarg get_initial_condition: a function as described in :meth:`~.MeshSeq.get_initial_condition` :kwarg get_solver: a function as described in :meth:`~.MeshSeq.get_solver` @@ -49,13 +48,27 @@ def __init__(self, time_partition, initial_meshes, **kwargs): take various types """ self.time_partition = time_partition - self.fields = dict.fromkeys(time_partition.field_names) - self.field_types = dict(zip(self.fields, time_partition.field_types)) self.subintervals = time_partition.subintervals self.num_subintervals = time_partition.num_subintervals + self.field_names = time_partition.field_names + self.field_metadata = time_partition.field_metadata + self.solution_names = [ + fieldname + for fieldname in self.field_names + if self.field_metadata[fieldname].solved_for + ] + + # Create a dictionary to hold field Functions with field names as keys and None + # as values + self.field_functions = dict.fromkeys(self.field_metadata) + self.set_meshes(initial_meshes) self._fs = None - self._get_function_spaces = kwargs.get("get_function_spaces") + if "get_function_spaces" in kwargs: + raise KeyError( + "get_function_spaces is no longer supported. Specify the finite_element" + " argument for the Field class instead." + ) self._get_initial_condition = kwargs.get("get_initial_condition") self._get_solver = kwargs.get("get_solver") self._transfer_method = kwargs.get("transfer_method", "project") @@ -169,7 +182,7 @@ def set_meshes(self, meshes): if not isinstance(meshes, list): meshes = [Mesh(meshes) for subinterval in self.subintervals] self.meshes = meshes - dim = np.array([mesh.topological_dimension() for mesh in meshes]) + dim = np.array([mesh.topological_dimension for mesh in meshes]) if dim.min() != dim.max(): raise ValueError("Meshes must all have the same topological dimension.") self.dim = dim.min() @@ -180,7 +193,7 @@ def set_meshes(self, meshes): nv = self.vertex_counts[0][i] qm = QualityMeasure(mesh) ar = qm("aspect_ratio") - mar = ar.vector().gather().max() + mar = function_data_max(ar) self.debug( f"{i}: {nc:7d} cells, {nv:7d} vertices, max aspect ratio {mar:.2f}" ) @@ -234,6 +247,11 @@ def plot(self, fig=None, axes=None, **kwargs): axes = axes[0] return fig, axes + def _get_field_metadata(self, fieldname): + if fieldname not in self.field_names: + raise ValueError(f"Field '{fieldname}' is not associated with the MeshSeq.") + return self.field_metadata[fieldname] + def get_function_spaces(self, mesh): """ Construct the function spaces corresponding to each field, for a given mesh. @@ -245,9 +263,10 @@ def get_function_spaces(self, mesh): :rtype: :class:`dict` with :class:`str` keys and :class:`firedrake.functionspaceimpl.FunctionSpace` values """ - if self._get_function_spaces is None: - raise NotImplementedError("'get_function_spaces' needs implementing.") - return self._get_function_spaces(mesh) + function_spaces = {} + for fieldname, field in self.field_metadata.items(): + function_spaces[fieldname] = field.get_function_space(mesh) + return function_spaces def get_initial_condition(self): r""" @@ -261,8 +280,8 @@ def get_initial_condition(self): if self._get_initial_condition is not None: return self._get_initial_condition(self) return { - field: firedrake.Function(fs[0]) - for field, fs in self.function_spaces.items() + fieldname: firedrake.Function(fs[0]) + for fieldname, fs in self.function_spaces.items() } def get_solver(self): @@ -315,15 +334,13 @@ def _transfer(self, source, target_space, **kwargs): def _outputs_consistent(self): """ Assert that function spaces and initial conditions are given in a - dictionary format with :attr:`MeshSeq.fields` as keys. + dictionary format with the same keys as :attr:`MeshSeq.field_metadata`. """ - for method in ["function_spaces", "initial_condition", "solver"]: + for method in ["initial_condition", "solver"]: if getattr(self, f"_get_{method}") is None: continue method_map = getattr(self, f"get_{method}") - if method == "function_spaces": - method_map = method_map(self.meshes[0]) - elif method == "initial_condition": + if method == "initial_condition": method_map = method_map() elif method == "solver": self._reinitialise_fields(self.get_initial_condition()) @@ -331,15 +348,15 @@ def _outputs_consistent(self): assert hasattr(solver_gen, "__next__"), "solver should yield" if logger.level == DEBUG: next(solver_gen) - f, f_ = self.fields[next(iter(self.fields))] - if np.array_equal(f.vector().array(), f_.vector().array()): + f, f_ = self.field_functions[next(iter(self.field_functions))] + if np.array_equal(f.dat.data_ro, f_.dat.data_ro): self.debug( "Current and lagged solutions are equal. Does the" " solver yield before updating lagged solutions?" ) # noqa break assert isinstance(method_map, dict), f"get_{method} should return a dict" - mesh_seq_fields = set(self.fields) + mesh_seq_fields = set(self.field_functions) method_fields = set(method_map.keys()) diff = mesh_seq_fields.difference(method_fields) assert len(diff) == 0, f"missing fields {diff} in get_{method}" @@ -356,14 +373,24 @@ def _function_spaces_consistent(self): :rtype: `:class:`bool` """ consistent = len(self.time_partition) == len(self) - consistent &= all(len(self) == len(self._fs[field]) for field in self.fields) - for field in self.fields: - consistent &= all( - mesh == fs.mesh() for mesh, fs in zip(self.meshes, self._fs[field]) - ) + consistent &= all( + len(self) == len(self._fs[fieldname]) for fieldname in self.field_functions + ) + for fieldname in self.field_functions: + if isinstance(self._fs[fieldname][0].mesh(), MeshSequenceGeometry): + consistent &= all( + mesh1 == mesh2 + for mesh1, fs in zip(self.meshes, self._fs[fieldname], strict=True) + for mesh2 in fs.mesh() + ) + else: + consistent &= all( + mesh == fs.mesh() + for mesh, fs in zip(self.meshes, self._fs[fieldname], strict=True) + ) consistent &= all( - self._fs[field][0].ufl_element() == fs.ufl_element() - for fs in self._fs[field] + self._fs[fieldname][0].ufl_element() == fs.ufl_element() + for fs in self._fs[fieldname] ) return consistent @@ -374,13 +401,15 @@ def _update_function_spaces(self): if self._fs is None or not self._function_spaces_consistent(): self._fs = AttrDict( { - field: [self.get_function_spaces(mesh)[field] for mesh in self] - for field in self.fields + fieldname: [ + self.get_function_spaces(mesh)[fieldname] for mesh in self + ] + for fieldname in self.field_functions } ) - assert ( - self._function_spaces_consistent() - ), "Meshes and function spaces are inconsistent" + assert self._function_spaces_consistent(), ( + "Meshes and function spaces are inconsistent" + ) @property def function_spaces(self): @@ -438,15 +467,18 @@ def _reinitialise_fields(self, initial_conditions): :type initial_conditions: :class:`dict` with :class:`str` keys and :class:`firedrake.function.Function` values """ - for field, ic in initial_conditions.items(): + for fieldname in self.field_names: + ic = initial_conditions[fieldname] fs = ic.function_space() - if self.field_types[field] == "steady": - self.fields[field] = firedrake.Function(fs, name=f"{field}").assign(ic) - else: - self.fields[field] = ( - firedrake.Function(fs, name=field), - firedrake.Function(fs, name=f"{field}_old").assign(ic), + field = self._get_field_metadata(fieldname) + if field.unsteady: + self.field_functions[fieldname] = ( + firedrake.Function(fs, name=fieldname), + firedrake.Function(fs, name=f"{fieldname}_old").assign(ic), ) + else: + self.field_functions[fieldname] = firedrake.Function(fs, name=fieldname) + self.field_functions[fieldname].assign(ic) @PETSc.Log.EventDecorator() def _solve_forward(self, update_solutions=True, solver_kwargs=None): @@ -492,14 +524,15 @@ def _solve_forward(self, update_solutions=True, solver_kwargs=None): for _ in range(tp.num_timesteps_per_export[i]): next(solver_gen) # Update the solution data - for field, sol in self.fields.items(): - if not self.field_types[field] == "steady": + for fieldname, sol in self.field_functions.items(): + field = self._get_field_metadata(fieldname) + if field.unsteady: assert isinstance(sol, tuple) - solutions[field].forward[i][j].assign(sol[0]) - solutions[field].forward_old[i][j].assign(sol[1]) + solutions[fieldname].forward[i][j].assign(sol[0]) + solutions[fieldname].forward_old[i][j].assign(sol[1]) else: assert isinstance(sol, firedrake.Function) - solutions[field].forward[i][j].assign(sol) + solutions[fieldname].forward[i][j].assign(sol) else: # Solve over the entire subinterval in one go for _ in range(tp.num_timesteps_per_subinterval[i]): @@ -509,13 +542,13 @@ def _solve_forward(self, update_solutions=True, solver_kwargs=None): if i < num_subintervals - 1: checkpoint = AttrDict( { - field: self._transfer( - self.fields[field] - if self.field_types[field] == "steady" - else self.fields[field][0], + fieldname: self._transfer( + self.field_functions[fieldname][0] + if self._get_field_metadata(fieldname).unsteady + else self.field_functions[fieldname], fs[i + 1], ) - for field, fs in self._fs.items() + for fieldname, fs in self._fs.items() } ) @@ -589,7 +622,7 @@ def check_element_count_convergence(self): else: converged = np.array([False] * len(self), dtype=bool) if len(self.element_counts) >= max(2, self.params.miniter + 1): - for i, (ne_, ne) in enumerate(zip(*self.element_counts[-2:])): + for i, (ne_, ne) in enumerate(zip(*self.element_counts[-2:], strict=True)): if not self.check_convergence[i]: self.info( f"Skipping element count convergence check on subinterval {i})" @@ -600,14 +633,14 @@ def check_element_count_convergence(self): converged[i] = True if len(self) == 1: pyrint( - f"Element count converged after {self.fp_iteration+1}" + f"Element count converged after {self.fp_iteration + 1}" " iterations under relative tolerance" f" {self.params.element_rtol}." ) else: pyrint( f"Element count converged on subinterval {i} after" - f" {self.fp_iteration+1} iterations under relative" + f" {self.fp_iteration + 1} iterations under relative" f" tolerance {self.params.element_rtol}." ) diff --git a/goalie/metric.py b/goalie/metric.py index 167e1b76..5d20c880 100644 --- a/goalie/metric.py +++ b/goalie/metric.py @@ -47,7 +47,7 @@ def enforce_variable_constraints( h_max = [h_max] * len(metrics) if not isinstance(a_max, Iterable): a_max = [a_max] * len(metrics) - for metric, hmin, hmax, amax in zip(metrics, h_min, h_max, a_max): + for metric, hmin, hmax, amax in zip(metrics, h_min, h_max, a_max, strict=True): metric.set_parameters( { "dm_plex_metric_h_min": hmin, @@ -98,7 +98,7 @@ def space_time_normalise( """ if isinstance(metric_parameters, dict): metric_parameters = [metric_parameters for _ in range(len(time_partition))] - d = metrics[0].function_space().mesh().topological_dimension() + d = metrics[0].function_space().mesh().topological_dimension if len(metrics) != len(time_partition): raise ValueError( "Number of metrics does not match number of subintervals:" @@ -112,7 +112,7 @@ def space_time_normalise( # Preparation step metric_parameters = metric_parameters.copy() - for metric, mp in zip(metrics, metric_parameters): + for metric, mp in zip(metrics, metric_parameters, strict=True): if not isinstance(mp, dict): raise TypeError( "Expected metric_parameters to consist of dictionaries," @@ -147,7 +147,7 @@ def space_time_normalise( integral = 0 p = mp["dm_plex_metric_p"] exponent = 0.5 if np.isinf(p) else p / (2 * p + d) - for metric, S in zip(metrics, time_partition): + for metric, S in zip(metrics, time_partition, strict=True): dX = (ufl.ds if boundary else ufl.dx)(metric.function_space().mesh()) scaling = pow(S.num_timesteps, 2 * exponent) integral += scaling * firedrake.assemble( @@ -158,7 +158,7 @@ def space_time_normalise( global_factor = firedrake.Constant(pow(target / integral, 2 / d)) debug(f"space_time_normalise: global scale factor={float(global_factor):.4e}") - for metric, S in zip(metrics, time_partition): + for metric, S in zip(metrics, time_partition, strict=True): # Normalise according to the global normalisation factor metric.normalise( global_factor=global_factor, diff --git a/goalie/point_seq.py b/goalie/point_seq.py deleted file mode 100644 index ba4f122e..00000000 --- a/goalie/point_seq.py +++ /dev/null @@ -1,44 +0,0 @@ -import firedrake -import firedrake.mesh as fmesh - -from .mesh_seq import MeshSeq - -__all__ = ["PointSeq"] - - -class PointSeq(MeshSeq): - """ - A simplified subset of :class:`~.MeshSeq` for ODE problems. - - In this version, a single mesh comprised of a single vertex is shared across all - subintervals. - """ - - def __init__(self, time_partition, **kwargs): - r""" - :arg time_partition: the :class:`~.TimePartition` which partitions the temporal - domain - :kwarg get_function_spaces: a function, whose only argument is a - :class:`~.MeshSeq`, which constructs prognostic - :class:`firedrake.functionspaceimpl.FunctionSpace`\s for each subinterval - :kwarg get_initial_condition: a function, whose only argument is a - :class:`~.MeshSeq`, which specifies initial conditions on the first mesh - :kwarg get_solver: a function, whose only argument is a :class:`~.MeshSeq`, - which returns a function that integrates initial data over a subinterval - :kwarg get_bcs: a function, whose only argument is a :class:`~.MeshSeq`, which - returns a function that determines any Dirichlet boundary conditions - """ - mesh = fmesh.VertexOnlyMesh(firedrake.UnitIntervalMesh(1), [[0.5]]) - super().__init__(time_partition, mesh, **kwargs) - - def set_meshes(self, mesh): - """ - Update the mesh associated with the :class:`~.PointSeq`, as well as the - associated attributes. - - :arg mesh: the vertex-only mesh - """ - self.meshes = [mesh for _ in self.subintervals] - self.dim = mesh.topological_dimension() - assert self.dim == 0 - self._reset_counts() diff --git a/goalie/time_partition.py b/goalie/time_partition.py index 0a080cb6..d379ebc9 100644 --- a/goalie/time_partition.py +++ b/goalie/time_partition.py @@ -6,6 +6,7 @@ import numpy as np +from .field import Field from .log import debug __all__ = ["TimePartition", "TimeInterval", "TimeInstant"] @@ -24,11 +25,10 @@ def __init__( end_time, num_subintervals, timesteps, - field_names, + field_metadata, num_timesteps_per_export=1, start_time=0.0, subintervals=None, - field_types=None, ): r""" :arg end_time: end time of the interval of interest @@ -38,8 +38,10 @@ def __init__( :arg timesteps: a list timesteps to be used on each subinterval, or a single timestep to use for all subintervals :type timesteps: :class:`list` of :class:`float`\s or :class:`float` - :arg field_names: the list of field names to consider - :type field_names: :class:`list` of :class:`str`\s or :class:`str` + :arg field_metadata: the Field or list or dict thereof to consider. In the case + of a dict, the keys should be consistent with the field names + :type field_metadata: :class:`~.Field`, :class:`list` of :class:`~.Field`\s, or + :class:`dict` with :class:`str` keys and :class:`~.Field` values :kwarg num_timesteps_per_export: a list of numbers of timesteps per export for each subinterval, or a single number to use for all subintervals :type num_timesteps_per_export: :class:`list` of :class`int`\s or :class:`int` @@ -48,15 +50,28 @@ def __init__( :kwarg subinterals: sequence of subintervals (which need not be of uniform length), or ``None`` to use uniform subintervals (the default) :type subintervals: :class:`list` of :class:`tuple`\s - :kwarg field_types: a list of strings indicating whether each field is - 'unsteady' or 'steady', i.e., does the corresponding equation involve time - derivatives or not? - :type field_types: :class:`list` of :class:`str`\s or :class:`str` """ debug(100 * "-") - if isinstance(field_names, str): - field_names = [field_names] - self.field_names = field_names + + # Extract field metadata as a dictionary with field names as keys, if not + # already in this format + if isinstance(field_metadata, Field): + field_metadata = [field_metadata] + if not isinstance(field_metadata, (dict, list)): + raise TypeError( + "field_metadata argument must be a Field or a dict or list thereof." + ) + if isinstance(field_metadata, dict): + for fieldname, field in field_metadata.items(): + if fieldname != field.name: + raise ValueError("Inconstent field names passed as field_metadata.") + self.field_metadata = field_metadata + else: + if not all(isinstance(field, Field) for field in field_metadata): + raise TypeError("All fields must be instances of Field.") + self.field_metadata = {field.name: field for field in field_metadata} + + self.field_names = list(self.field_metadata.keys()) self.start_time = start_time self.end_time = end_time self.num_subintervals = int(np.round(num_subintervals)) @@ -91,7 +106,9 @@ def __init__( # Get number of timesteps on each subinterval self.num_timesteps_per_subinterval = [] - for i, ((ts, tf), dt) in enumerate(zip(self.subintervals, self.timesteps)): + for i, ((ts, tf), dt) in enumerate( + zip(self.subintervals, self.timesteps, strict=True) + ): num_timesteps = (tf - ts) / dt self.num_timesteps_per_subinterval.append(int(np.round(num_timesteps))) if not np.isclose(num_timesteps, self.num_timesteps_per_subinterval[-1]): @@ -112,7 +129,9 @@ def __init__( self.num_exports_per_subinterval = [ tsps // tspe + 1 for tspe, tsps in zip( - self.num_timesteps_per_export, self.num_timesteps_per_subinterval + self.num_timesteps_per_export, + self.num_timesteps_per_subinterval, + strict=True, ) ] self.debug("num_exports_per_subinterval") @@ -120,16 +139,6 @@ def __init__( self.num_subintervals == 1 and self.num_timesteps_per_subinterval[0] == 1 ) self.debug("steady") - - # Process field types - if field_types is None: - num_fields = len(self.field_names) - field_types = ["steady" if self.steady else "unsteady"] * num_fields - elif isinstance(field_types, str): - field_types = [field_types] - self.field_types = field_types - self._check_field_types() - debug("field_types") debug(100 * "-") def debug(self, attr): @@ -152,13 +161,13 @@ def __str__(self): def __repr__(self): timesteps = ", ".join([str(dt) for dt in self.timesteps]) - field_names = ", ".join([f"'{field_name}'" for field_name in self.field_names]) + fields = ", ".join([repr(field) for field in self.field_metadata.values()]) return ( f"TimePartition(" f"end_time={self.end_time}, " f"num_subintervals={self.num_subintervals}, " f"timesteps=[{timesteps}], " - f"field_names=[{field_names}])" + f"field_metadata=[{fields}])" ) def __len__(self): @@ -184,10 +193,9 @@ def __getitem__(self, index_or_slice): end_time=self.subintervals[sl.stop - 1][1], num_subintervals=num_subintervals, timesteps=self.timesteps[sl], - field_names=self.field_names, + field_metadata=self.field_metadata, num_timesteps_per_export=self.num_timesteps_per_export[sl], start_time=self.subintervals[sl.start][0], - field_types=self.field_types, ) @property @@ -213,8 +221,8 @@ def _check_subintervals(self): if not np.isclose(self.subintervals[i][1], self.subintervals[i + 1][0]): raise ValueError( f"The end of subinterval {i} does not match the start of" - f" subinterval {i+1}: {self.subintervals[i][1]} !=" - f" {self.subintervals[i+1][0]}." + f" subinterval {i + 1}: {self.subintervals[i][1]} !=" + f" {self.subintervals[i + 1][0]}." ) if not np.isclose(self.subintervals[-1][1], self.end_time): raise ValueError( @@ -239,7 +247,11 @@ def _check_num_timesteps_per_export(self): f" != {len(self.num_timesteps_per_subinterval)}." ) for i, (tspe, tsps) in enumerate( - zip(self.num_timesteps_per_export, self.num_timesteps_per_subinterval) + zip( + self.num_timesteps_per_export, + self.num_timesteps_per_subinterval, + strict=True, + ) ): if not isinstance(tspe, int): raise TypeError( @@ -253,19 +265,6 @@ def _check_num_timesteps_per_export(self): f" {tsps} | {tspe} != 0." ) - def _check_field_types(self): - if len(self.field_names) != len(self.field_types): - raise ValueError( - "Number of field names does not match number of field types:" - f" {len(self.field_names)} != {len(self.field_types)}." - ) - for field_name, field_type in zip(self.field_names, self.field_types): - if field_type not in ("unsteady", "steady"): - raise ValueError( - f"Expected field type for field '{field_name}' to be either" - f" 'unsteady' or 'steady', but got '{field_type}'." - ) - def __eq__(self, other): if len(self) != len(other): return False @@ -275,22 +274,11 @@ def __eq__(self, other): and np.allclose( self.num_exports_per_subinterval, other.num_exports_per_subinterval ) - and self.field_names == other.field_names - and self.field_types == other.field_types + and self.field_metadata == other.field_metadata ) def __ne__(self, other): - if len(self) != len(other): - return True - return ( - not np.allclose(self.subintervals, other.subintervals) - or not np.allclose(self.timesteps, other.timesteps) - or not np.allclose( - self.num_exports_per_subinterval, other.num_exports_per_subinterval - ) - or not self.field_names == other.field_names - or not self.field_types == other.field_types - ) + return not self.__eq__(other) class TimeInterval(TimePartition): @@ -306,15 +294,18 @@ def __init__(self, *args, **kwargs): else: end_time = args[0] timestep = args[1] - field_names = args[2] - super().__init__(end_time, 1, timestep, field_names, **kwargs) + field_metadata = args[2] + super().__init__(end_time, 1, timestep, field_metadata, **kwargs) def __repr__(self): + field_metadata = ", ".join( + [repr(field) for field in self.field_metadata.values()] + ) return ( f"TimeInterval(" f"end_time={self.end_time}, " f"timestep={self.timestep}, " - f"field_names={self.field_names})" + f"field_metadata=[{field_metadata}])" ) @property @@ -333,7 +324,7 @@ class TimeInstant(TimeInterval): Under the hood this means dividing :math:`[0,1)` into a single timestep. """ - def __init__(self, field_names, **kwargs): + def __init__(self, field_metadata, **kwargs): if "end_time" in kwargs: if "time" in kwargs: raise ValueError("Both 'time' and 'end_time' are set.") @@ -341,12 +332,13 @@ def __init__(self, field_names, **kwargs): else: time = kwargs.pop("time", 1.0) timestep = time - super().__init__(time, timestep, field_names, **kwargs) + super().__init__(time, timestep, field_metadata, **kwargs) def __str__(self): return f"({self.end_time})" def __repr__(self): - return ( - f"TimeInstant(" f"time={self.end_time}, " f"field_names={self.field_names})" + field_metadata = ", ".join( + [repr(field) for field in self.field_metadata.values()] ) + return f"TimeInstant(time={self.end_time}, field_metadata=[{field_metadata}])" diff --git a/goalie/utility.py b/goalie/utility.py index b3d3626b..6a9e5b22 100644 --- a/goalie/utility.py +++ b/goalie/utility.py @@ -5,7 +5,7 @@ import os import firedrake -import numpy as np +from animate.utility import function_data_sum __all__ = ["AttrDict", "create_directory", "effectivity_index"] @@ -43,8 +43,7 @@ def effectivity_index(error_indicator, Je): el = error_indicator.ufl_element() if not (el.family() == "Discontinuous Lagrange" and el.degree() == 0): raise ValueError("Error indicator must be P0.") - eta = error_indicator.vector().gather().sum() - return np.abs(eta / Je) + return abs(function_data_sum(error_indicator) / Je) def create_directory(path, comm=firedrake.COMM_WORLD): diff --git a/goalie_adjoint/__init__.py b/goalie_adjoint/__init__.py deleted file mode 100644 index bae10d55..00000000 --- a/goalie_adjoint/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -from __future__ import absolute_import - -from pyadjoint import no_annotations # noqa - -from goalie import * # noqa -from goalie.adjoint import * # noqa -from goalie.go_mesh_seq import * # noqa diff --git a/pyproject.toml b/pyproject.toml index a6e9506e..3d8abc5f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,8 +12,6 @@ authors = [ ] maintainers = [ {name = "Joseph G. Wallwork", email = "joe.wallwork@outlook.com"}, - {name = "Davor Dundovic"}, - {name = "Eleda Johnson"}, {name = "Stephan C. Kramer"}, ] description = "Goal-oriented error estimation and mesh adaptation for finite element problems solved using Firedrake" @@ -42,7 +40,7 @@ Documentation = "https://mesh-adaptation.github.io/goalie/index.html" Repository = "https://github.com/mesh-adaptation/goalie" [tool.setuptools] -packages = ["goalie", "goalie_adjoint"] +packages = ["goalie"] [tool.ruff] line-length = 88 diff --git a/test/adjoint/Makefile b/test/adjoint/Makefile index 1dbc440f..092fae67 100644 --- a/test/adjoint/Makefile +++ b/test/adjoint/Makefile @@ -2,7 +2,7 @@ all: run run: @echo "Running all adjoint tests..." - @python3 -m pytest -v --durations=10 . + @python3 -m pytest -v -n auto --durations=10 . @echo "Done." clean: diff --git a/test/adjoint/examples/burgers.py b/test/adjoint/examples/burgers.py index d074d03a..9a230bb6 100644 --- a/test/adjoint/examples/burgers.py +++ b/test/adjoint/examples/burgers.py @@ -8,16 +8,20 @@ """ import ufl +from finat.ufl import FiniteElement, VectorElement from firedrake.function import Function -from firedrake.functionspace import FunctionSpace, VectorFunctionSpace +from firedrake.functionspace import FunctionSpace from firedrake.solving import solve from firedrake.ufl_expr import TestFunction from firedrake.utility_meshes import UnitSquareMesh +from goalie.field import Field + # Problem setup n = 32 mesh = UnitSquareMesh(n, n, diagonal="left") -fields = ["uv_2d"] +finite_element = VectorElement(FiniteElement("Lagrange", ufl.triangle, 2), dim=2) +fields = [Field("uv_2d", finite_element=finite_element)] end_time = 0.5 dt = 1 / n dt_per_export = 2 @@ -25,11 +29,6 @@ get_bcs = None -def get_function_spaces(mesh): - r""":math:`\mathbb P2` space.""" - return {"uv_2d": VectorFunctionSpace(mesh, "CG", 2)} - - def get_solver(self): """ Burgers equation solved using a direct method and backward Euler timestepping. @@ -39,7 +38,7 @@ def solver(i): t_start, t_end = self.time_partition.subintervals[i] dt = self.time_partition.timesteps[i] - u, u_ = self.fields["uv_2d"] + u, u_ = self.field_functions["uv_2d"] # Setup variational problem dt = self.time_partition.timesteps[i] @@ -89,7 +88,7 @@ def get_qoi(self, i): dtc = Function(R).assign(self.time_partition.timesteps[i]) def time_integrated_qoi(t): - u = self.fields["uv_2d"][0] + u = self.field_functions["uv_2d"][0] return dtc * ufl.inner(u, u) * ufl.ds(2) def end_time_qoi(): diff --git a/test/adjoint/examples/point_discharge2d.py b/test/adjoint/examples/point_discharge2d.py index 38da03e0..e375baf9 100644 --- a/test/adjoint/examples/point_discharge2d.py +++ b/test/adjoint/examples/point_discharge2d.py @@ -12,6 +12,7 @@ import numpy as np import ufl +from finat.ufl import FiniteElement from firedrake.assemble import assemble from firedrake.bcs import DirichletBC from firedrake.function import Function @@ -20,12 +21,14 @@ from firedrake.ufl_expr import CellSize, TestFunction from firedrake.utility_meshes import RectangleMesh +from goalie.field import Field from goalie.math import bessk0 # Problem setup n = 0 mesh = RectangleMesh(100 * 2**n, 20 * 2**n, 50, 10) -fields = ["tracer_2d"] +finite_element = FiniteElement("Lagrange", ufl.triangle, 1) +fields = [Field("tracer_2d", finite_element=finite_element, unsteady=False)] end_time = 1.0 dt = 1.0 dt_per_export = 1 @@ -35,13 +38,6 @@ get_initial_condition = None -def get_function_spaces(mesh): - r""" - :math:`\mathbb P1` space. - """ - return {"tracer_2d": FunctionSpace(mesh, "CG", 1)} - - def source(mesh): """ Gaussian approximation to a point source at (2, 5) with discharge rate 100 on a @@ -56,7 +52,7 @@ def get_solver(self): def solver(i): fs = self.function_spaces["tracer_2d"][i] - c = self.fields["tracer_2d"] + c = self.field_functions["tracer_2d"] # Define constants fs = self.function_spaces["tracer_2d"][i] @@ -106,7 +102,7 @@ def get_qoi(self, i): """ def steady_qoi(): - c = self.fields["tracer_2d"] + c = self.field_functions["tracer_2d"] x, y = ufl.SpatialCoordinate(self[i]) kernel = ufl.conditional((x - rec_x) ** 2 + (y - rec_y) ** 2 < rec_r**2, 1, 0) area = assemble(kernel * ufl.dx) diff --git a/test/adjoint/examples/point_discharge3d.py b/test/adjoint/examples/point_discharge3d.py index cb00e8c6..15cd29d2 100644 --- a/test/adjoint/examples/point_discharge3d.py +++ b/test/adjoint/examples/point_discharge3d.py @@ -16,6 +16,7 @@ import numpy as np import ufl +from finat.ufl import FiniteElement from firedrake.assemble import assemble from firedrake.bcs import DirichletBC from firedrake.function import Function @@ -24,12 +25,14 @@ from firedrake.ufl_expr import CellSize, TestFunction from firedrake.utility_meshes import BoxMesh +from goalie.field import Field from goalie.math import bessk0 # Problem setup n = 0 mesh = BoxMesh(100 * 2**n, 20 * 2**n, 20 * 2**n, 50, 10, 10) -fields = ["tracer_3d"] +finite_element = FiniteElement("Lagrange", ufl.tetrahedron, 1) +fields = [Field("tracer_3d", finite_element=finite_element, unsteady=False)] end_time = 1.0 dt = 1.0 dt_per_export = 1 @@ -60,7 +63,7 @@ def get_solver(self): def solver(i): fs = self.function_spaces["tracer_3d"][i] - c = self.fields["tracer_3d"] + c = self.field_functions["tracer_3d"] # Define constants fs = self.function_spaces["tracer_3d"][i] @@ -111,7 +114,7 @@ def get_qoi(self, i): """ def steady_qoi(): - c = self.fields["tracer_3d"] + c = self.field_functions["tracer_3d"] x, y, z = ufl.SpatialCoordinate(self[i]) kernel = ufl.conditional( (x - rec_x) ** 2 + (y - rec_y) ** 2 + (z - rec_z) ** 2 < rec_r**2, 1, 0 diff --git a/test/adjoint/examples/steady_flow_past_cyl.py b/test/adjoint/examples/steady_flow_past_cyl.py index fc535df8..b608d184 100644 --- a/test/adjoint/examples/steady_flow_past_cyl.py +++ b/test/adjoint/examples/steady_flow_past_cyl.py @@ -11,15 +11,22 @@ import os import ufl +from finat.ufl import FiniteElement, MixedElement, VectorElement from firedrake.bcs import DirichletBC from firedrake.function import Function -from firedrake.functionspace import FunctionSpace, VectorFunctionSpace +from firedrake.functionspace import FunctionSpace from firedrake.mesh import Mesh from firedrake.solving import solve from firedrake.ufl_expr import TestFunctions +from goalie.field import Field + mesh = Mesh(os.path.join(os.path.dirname(__file__), "mesh-with-hole.msh")) -fields = ["up"] +p2v_element = VectorElement(FiniteElement("Lagrange", ufl.triangle, 2), dim=2) +p1_element = FiniteElement("Lagrange", ufl.triangle, 1) +fields = [ + Field("up", finite_element=MixedElement([p2v_element, p1_element]), unsteady=False) +] dt = 1.0 end_time = dt dt_per_export = 1 @@ -27,17 +34,12 @@ steady = True -def get_function_spaces(mesh): - r"""Taylor-Hood :math:`\mathbb P2-\mathbb P1` space.""" - return {"up": VectorFunctionSpace(mesh, "CG", 2) * FunctionSpace(mesh, "CG", 1)} - - def get_solver(self): """Stokes problem solved using a direct method.""" def solver(i): W = self.function_spaces["up"][i] - up = self.fields["up"] + up = self.field_functions["up"] # Define variational problem R = FunctionSpace(self[i], "R", 0) @@ -96,7 +98,7 @@ def get_qoi(self, i): """Quantity of interest which integrates pressure over the boundary of the hole.""" def steady_qoi(): - u, p = ufl.split(self.fields["up"]) + u, p = ufl.split(self.field_functions["up"]) return p * ufl.ds(4) return steady_qoi diff --git a/test/adjoint/test_adjoint.py b/test/adjoint/test_adjoint.py index 0fb81ea9..7a99f63f 100644 --- a/test/adjoint/test_adjoint.py +++ b/test/adjoint/test_adjoint.py @@ -16,6 +16,7 @@ from firedrake.utility_meshes import UnitTriangleMesh from goalie.adjoint import AdjointMeshSeq +from goalie.field import Field from goalie.log import DEBUG, pyrint, set_log_level from goalie.time_partition import TimeInterval, TimePartition from goalie.utility import AttrDict @@ -33,7 +34,7 @@ class TestAdjointMeshSeqGeneric(unittest.TestCase): """ def setUp(self): - self.time_interval = TimeInterval(1.0, [0.5], ["field"]) + self.time_interval = TimeInterval(1.0, [0.5], Field("field", family="Real")) self.meshes = [UnitTriangleMesh()] def test_qoi_type_error(self): @@ -118,7 +119,6 @@ def test_adjoint_same_mesh(problem, qoi_type, debug=False): mesh_seq = AdjointMeshSeq( time_partition, test_case.mesh, - get_function_spaces=test_case.get_function_spaces, get_initial_condition=test_case.get_initial_condition, get_solver=test_case.get_solver, get_qoi=test_case.get_qoi, @@ -143,8 +143,9 @@ def test_adjoint_same_mesh(problem, qoi_type, debug=False): m = pyadjoint.enlisting.Enlist(controls) assert pyadjoint.annotate_tape() pyadjoint.pause_annotation() - with tape.marked_nodes(m): - tape.evaluate_adj(markings=True) + with tape.marked_control_dependents(m): + with tape.marked_functional_dependencies(J): + tape.evaluate_adj(markings=True) # FIXME: Using mixed Functions as Controls not correct J_expected = float(J) @@ -179,7 +180,6 @@ def test_adjoint_same_mesh(problem, qoi_type, debug=False): mesh_seq = AdjointMeshSeq( time_partition, test_case.mesh, - get_function_spaces=test_case.get_function_spaces, get_initial_condition=test_case.get_initial_condition, get_solver=test_case.get_solver, get_qoi=test_case.get_qoi, @@ -195,14 +195,14 @@ def test_adjoint_same_mesh(problem, qoi_type, debug=False): # Check adjoint solutions at first export time match first_export_time = test_case.dt * test_case.dt_per_export - for field in time_partition.field_names: - adj_sol_expected = adj_sols_expected[field] + for fieldname in time_partition.field_names: + adj_sol_expected = adj_sols_expected[fieldname] expected_norm = norm(adj_sol_expected) if np.isclose(expected_norm, 0.0): raise ValueError( f"'Expected' norm at t={first_export_time} is unexpectedly zero." ) - adj_sol_computed = solutions[field].adjoint[0][0] + adj_sol_computed = solutions[fieldname].adjoint[0][0] err = errornorm(adj_sol_expected, adj_sol_computed) / expected_norm if not np.isclose(err, 0.0): raise ValueError( @@ -212,9 +212,9 @@ def test_adjoint_same_mesh(problem, qoi_type, debug=False): # Check adjoint actions at first export time match if not steady: - for field in time_partition.field_names: - adj_value_expected = adj_values_expected[field] - adj_value_computed = solutions[field].adj_value[0][0] + for fieldname in time_partition.field_names: + adj_value_expected = adj_values_expected[fieldname] + adj_value_computed = solutions[fieldname].adj_value[0][0] err = errornorm(adj_value_expected, adj_value_computed) / norm( adj_value_expected ) @@ -256,7 +256,6 @@ def plot_solutions(problem, qoi_type, debug=True): solutions = AdjointMeshSeq( time_partition, test_case.mesh, - get_function_spaces=test_case.get_function_spaces, get_initial_condition=test_case.get_initial_condition, get_solver=test_case.get_solver, get_qoi=test_case.get_qoi, @@ -276,8 +275,8 @@ def plot_solutions(problem, qoi_type, debug=True): for label in outfiles: for k in range(time_partition.num_exports_per_subinterval[0] - 1): to_plot = [] - for field in time_partition.field_names: - sol = solutions[field][label][0][k] + for fieldname in time_partition.field_names: + sol = solutions[fieldname][label][0][k] to_plot += ( [sol] if not hasattr(sol, "subfunctions") diff --git a/test/adjoint/test_adjoint_mesh_seq.py b/test/adjoint/test_adjoint_mesh_seq.py index d70a0c28..3afcb281 100644 --- a/test/adjoint/test_adjoint_mesh_seq.py +++ b/test/adjoint/test_adjoint_mesh_seq.py @@ -10,8 +10,9 @@ import pytest import ufl from animate.utility import norm +from finat.ufl import FiniteElement, VectorElement from firedrake.function import Function -from firedrake.functionspace import FunctionSpace, VectorFunctionSpace +from firedrake.functionspace import FunctionSpace from firedrake.solving import solve from firedrake.ufl_expr import TestFunction, TrialFunction from firedrake.utility_meshes import UnitSquareMesh, UnitTriangleMesh @@ -19,6 +20,7 @@ from pyadjoint.block_variable import BlockVariable from goalie.adjoint import AdjointMeshSeq +from goalie.field import Field from goalie.go_mesh_seq import GoalOrientedMeshSeq from goalie.log import WARNING from goalie.time_partition import TimeInterval, TimePartition @@ -34,9 +36,11 @@ class RSpaceTestCase(unittest.TestCase): Unit test case using R-space. """ - @staticmethod - def get_function_spaces(mesh): - return {"field": FunctionSpace(mesh, "R", 0)} + def setUp(self): + mesh = UnitSquareMesh(1, 1) + self.meshes = [mesh] + self.field = Field("field", family="Real", degree=0) + self.function_space = FunctionSpace(mesh, self.field.get_element(mesh)) class TrivialGoalOrientedBaseClass(unittest.TestCase): """ @@ -44,8 +48,6 @@ class TrivialGoalOrientedBaseClass(unittest.TestCase): """ def setUp(self): - self.field = "field" - self.time_interval = TimeInterval(1.0, [1.0], [self.field]) self.meshes = [UnitSquareMesh(1, 1)] @staticmethod @@ -53,28 +55,36 @@ def constant_qoi(mesh_seq, solutions, index): R = FunctionSpace(mesh_seq[index], "R", 0) return lambda: Function(R).assign(1) * ufl.dx - def go_mesh_seq(self, get_function_spaces, parameters=None): + def go_mesh_seq(self, element=None, parameters=None): + if element is None: + element = FiniteElement("Real", ufl.triangle, 0) + field = Field("field", finite_element=element) return GoalOrientedMeshSeq( - self.time_interval, + TimeInterval(1.0, [1.0], field), self.meshes, - get_function_spaces=get_function_spaces, qoi_type="steady", parameters=parameters, ) - class GoalOrientedBaseClass(RSpaceTestCase): + class GoalOrientedBaseClass(unittest.TestCase): """ Base class for tests with a complete :class:`GoalOrientedMeshSeq`. """ def setUp(self): - self.field = "field" - self.time_partition = TimePartition(1.0, 1, 0.5, [self.field]) - self.meshes = [UnitSquareMesh(1, 1)] + mesh = UnitSquareMesh(1, 1) + self.meshes = [mesh] + self.field = Field("field", family="Real", degree=0) def go_mesh_seq(self, coeff_diff=0.0): + self.time_partition = TimePartition(1.0, 1, 0.5, [self.field]) + def get_initial_condition(mesh_seq): - return {self.field: Function(mesh_seq.function_spaces[self.field][0])} + return { + self.field.name: Function( + mesh_seq.function_spaces[self.field.name][0] + ) + } def get_solver(mesh_seq): def solver(index): @@ -82,14 +92,14 @@ def solver(index): R = FunctionSpace(mesh_seq[index], "R", 0) dt = Function(R).assign(tp.timesteps[index]) - u, u_ = mesh_seq.fields[self.field] + u, u_ = mesh_seq.field_functions[self.field.name] f = Function(R).assign(1.0001) v = TestFunction(u.function_space()) F = (u - u_) / dt * v * ufl.dx - f * v * ufl.dx - mesh_seq.read_forms({self.field: F}) + mesh_seq.read_forms({self.field.name: F}) for _ in range(tp.num_timesteps_per_subinterval[index]): - solve(F == 0, u, ad_block_tag=self.field) + solve(F == 0, u, ad_block_tag=self.field.name) yield u_.assign(u) @@ -99,7 +109,7 @@ def solver(index): def get_qoi(mesh_seq, i): def end_time_qoi(): - u = mesh_seq.fields[self.field][0] + u = mesh_seq.field_functions[self.field.name][0] return ufl.inner(u, u) * ufl.dx return end_time_qoi @@ -108,7 +118,6 @@ def end_time_qoi(): self.time_partition, self.meshes, get_initial_condition=get_initial_condition, - get_function_spaces=self.get_function_spaces, get_solver=get_solver, get_qoi=get_qoi, qoi_type="end_time", @@ -121,14 +130,29 @@ class TestBlockLogic(BaseClasses.RSpaceTestCase): """ def setUp(self): - self.time_interval = TimeInterval(1.0, 0.5, "field") - self.mesh = UnitTriangleMesh() + super().setUp() self.mesh_seq = AdjointMeshSeq( - self.time_interval, + TimeInterval(1.0, 0.5, self.field), + self.meshes, + qoi_type="end_time", + ) + assert len(self.meshes) == 1 + self.mesh = self.meshes[0] + + def test_field_not_solved_for(self): + field_not_solved_for = Field("field", family="Real", degree=0, solved_for=False) + mesh_seq = AdjointMeshSeq( + TimeInterval(1.0, 0.5, field_not_solved_for), self.mesh, - get_function_spaces=self.get_function_spaces, qoi_type="end_time", ) + with self.assertRaises(ValueError) as cm: + mesh_seq.get_solve_blocks("field", 0) + msg = ( + "Cannot retrieve solve blocks for field 'field' because it isn't solved" + " for." + ) + self.assertEqual(str(cm.exception), msg) @patch("firedrake.adjoint_utils.blocks.solving.GenericSolveBlock") def test_output_not_function(self, MockSolveBlock): @@ -143,7 +167,9 @@ def test_output_not_function(self, MockSolveBlock): @patch("firedrake.adjoint_utils.blocks.solving.GenericSolveBlock") def test_output_wrong_function_space(self, MockSolveBlock): solve_block = MockSolveBlock() - block_variable = BlockVariable(Function(FunctionSpace(self.mesh, "CG", 1))) + block_variable = BlockVariable( + Function(FunctionSpace(self.mesh, self.field.get_element(self.mesh))) + ) solve_block.get_outputs = lambda: [block_variable] with self.assertRaises(AttributeError) as cm: self.mesh_seq._output("field", 0, solve_block) @@ -153,8 +179,7 @@ def test_output_wrong_function_space(self, MockSolveBlock): @patch("firedrake.adjoint_utils.blocks.solving.GenericSolveBlock") def test_output_wrong_name(self, MockSolveBlock): solve_block = MockSolveBlock() - function_space = FunctionSpace(self.mesh, "R", 0) - block_variable = BlockVariable(Function(function_space, name="field2")) + block_variable = BlockVariable(Function(self.function_space, name="field2")) solve_block.get_outputs = lambda: [block_variable] with self.assertRaises(AttributeError) as cm: self.mesh_seq._output("field", 0, solve_block) @@ -164,16 +189,14 @@ def test_output_wrong_name(self, MockSolveBlock): @patch("firedrake.adjoint_utils.blocks.solving.GenericSolveBlock") def test_output_valid(self, MockSolveBlock): solve_block = MockSolveBlock() - function_space = FunctionSpace(self.mesh, "R", 0) - block_variable = BlockVariable(Function(function_space, name="field")) + block_variable = BlockVariable(Function(self.function_space, name="field")) solve_block.get_outputs = lambda: [block_variable] self.assertIsNotNone(self.mesh_seq._output("field", 0, solve_block)) @patch("firedrake.adjoint_utils.blocks.solving.GenericSolveBlock") def test_output_multiple_valid_error(self, MockSolveBlock): solve_block = MockSolveBlock() - function_space = FunctionSpace(self.mesh, "R", 0) - block_variable = BlockVariable(Function(function_space, name="field")) + block_variable = BlockVariable(Function(self.function_space, name="field")) solve_block.get_outputs = lambda: [block_variable, block_variable] with self.assertRaises(AttributeError) as cm: self.mesh_seq._output("field", 0, solve_block) @@ -196,7 +219,9 @@ def test_dependency_not_function(self, MockSolveBlock): @patch("firedrake.adjoint_utils.blocks.solving.GenericSolveBlock") def test_dependency_wrong_function_space(self, MockSolveBlock): solve_block = MockSolveBlock() - block_variable = BlockVariable(Function(FunctionSpace(self.mesh, "CG", 1))) + block_variable = BlockVariable( + Function(FunctionSpace(self.mesh, "Lagrange", 1)) + ) solve_block.get_dependencies = lambda: [block_variable] with self.assertRaises(AttributeError) as cm: self.mesh_seq._dependency("field", 0, solve_block) @@ -217,16 +242,14 @@ def test_dependency_wrong_name(self, MockSolveBlock): @patch("firedrake.adjoint_utils.blocks.solving.GenericSolveBlock") def test_dependency_valid(self, MockSolveBlock): solve_block = MockSolveBlock() - function_space = FunctionSpace(self.mesh, "R", 0) - block_variable = BlockVariable(Function(function_space, name="field_old")) + block_variable = BlockVariable(Function(self.function_space, name="field_old")) solve_block.get_dependencies = lambda: [block_variable] self.assertIsNotNone(self.mesh_seq._dependency("field", 0, solve_block)) @patch("firedrake.adjoint_utils.blocks.solving.GenericSolveBlock") def test_dependency_multiple_valid_error(self, MockSolveBlock): solve_block = MockSolveBlock() - function_space = FunctionSpace(self.mesh, "R", 0) - block_variable = BlockVariable(Function(function_space, name="field_old")) + block_variable = BlockVariable(Function(self.function_space, name="field_old")) solve_block.get_dependencies = lambda: [block_variable, block_variable] with self.assertRaises(AttributeError) as cm: self.mesh_seq._dependency("field", 0, solve_block) @@ -238,11 +261,11 @@ def test_dependency_multiple_valid_error(self, MockSolveBlock): @patch("firedrake.adjoint_utils.blocks.solving.GenericSolveBlock") def test_dependency_steady(self, MockSolveBlock): - time_interval = TimeInterval(1.0, 0.5, "field", field_types="steady") + field = Field("field", family="Real", unsteady=False) + self.time_interval = TimeInterval(1.0, 0.5, field) mesh_seq = AdjointMeshSeq( - time_interval, + self.time_interval, self.mesh, - get_function_spaces=self.get_function_spaces, qoi_type="end_time", ) solve_block = MockSolveBlock() @@ -255,11 +278,11 @@ class TestGetSolveBlocks(BaseClasses.RSpaceTestCase): """ def setUp(self): - time_interval = TimeInterval(1.0, [1.0], ["field"]) + super().setUp() + time_interval = TimeInterval(1.0, [1.0], self.field) self.mesh_seq = AdjointMeshSeq( time_interval, - [UnitSquareMesh(1, 1)], - get_function_spaces=self.get_function_spaces, + self.meshes, qoi_type="steady", ) if not pyadjoint.annotate_tape(): @@ -313,7 +336,7 @@ def test_wrong_solve_block(self): self.assertTrue(msg in str(self._caplog.records[0])) def test_wrong_function_space(self): - fs = FunctionSpace(self.mesh_seq[0], "CG", 1) + fs = FunctionSpace(self.mesh_seq[0], "Lagrange", 1) u = Function(fs, name="field") self.arbitrary_solve(u) msg = ( @@ -325,11 +348,10 @@ def test_wrong_function_space(self): self.assertEqual(str(cm.exception), msg) def test_too_many_timesteps(self): - time_interval = TimeInterval(1.0, [0.5], ["field"]) + time_interval = TimeInterval(1.0, [0.5], self.field) mesh_seq = AdjointMeshSeq( time_interval, [UnitSquareMesh(1, 1)], - get_function_spaces=self.get_function_spaces, qoi_type="end_time", ) fs = mesh_seq.function_spaces["field"][0] @@ -344,11 +366,10 @@ def test_too_many_timesteps(self): self.assertEqual(str(cm.exception), msg) def test_incompatible_timesteps(self): - time_interval = TimeInterval(1.0, [0.5], ["field"]) + time_interval = TimeInterval(1.0, [0.5], self.field) mesh_seq = AdjointMeshSeq( time_interval, [UnitSquareMesh(1, 1)], - get_function_spaces=self.get_function_spaces, qoi_type="end_time", ) fs = mesh_seq.function_spaces["field"][0] @@ -365,28 +386,40 @@ def test_incompatible_timesteps(self): self.assertEqual(str(cm.exception), msg) -class TestGoalOrientedMeshSeq( - BaseClasses.RSpaceTestCase, BaseClasses.TrivialGoalOrientedBaseClass -): +class TestGoalOrientedMeshSeq(BaseClasses.TrivialGoalOrientedBaseClass): """ Unit tests for a :class:`GoalOrientedMeshSeq`. """ def test_read_forms_error_field(self): - mesh_seq = self.go_mesh_seq(self.get_function_spaces) + fields = [ + Field("field", family="R"), + Field("field2", family="R", solved_for=False), + ] + go_mesh_seq = GoalOrientedMeshSeq( + TimeInterval(1.0, [1.0], fields), + self.meshes, + qoi_type="steady", + ) + with self.assertRaises(ValueError) as cm: - mesh_seq.read_forms({"field2": None}) + go_mesh_seq.read_forms({"field2": None}) msg = ( - "Unexpected field 'field2' in forms dictionary." - f" Expected one of ['{self.field}']." + "Unexpected field 'field2' in forms dictionary. Expected one of ['field']." + ) + self.assertEqual(str(cm.exception), msg) + + with self.assertRaises(ValueError) as cm: + go_mesh_seq.read_forms({"field3": None}) + msg = ( + "Unexpected field 'field3' in forms dictionary. Expected one of ['field']." ) self.assertEqual(str(cm.exception), msg) def test_read_forms_error_form(self): - mesh_seq = self.go_mesh_seq(self.get_function_spaces) with self.assertRaises(TypeError) as cm: - mesh_seq.read_forms({self.field: None}) - msg = f"Expected a UFL form for field '{self.field}', not ''." + self.go_mesh_seq().read_forms({"field": None}) + msg = "Expected a UFL form for field 'field', not ''." self.assertEqual(str(cm.exception), msg) @@ -395,34 +428,28 @@ class TestGlobalEnrichment(BaseClasses.TrivialGoalOrientedBaseClass): Unit tests for global enrichment of a :class:`GoalOrientedMeshSeq`. """ - def get_function_spaces_decorator(self, degree, family, rank): - def get_function_spaces(mesh): - if rank == 0: - return {self.field: FunctionSpace(mesh, degree, family)} - elif rank == 1: - return {self.field: VectorFunctionSpace(mesh, degree, family)} - else: - raise NotImplementedError - - return get_function_spaces + def element(self, family, degree, rank): + if rank == 0: + return FiniteElement(family, ufl.triangle, degree) + elif rank == 1: + return VectorElement(FiniteElement(family, ufl.triangle, degree)) + else: + raise NotImplementedError def test_enrichment_error(self): - mesh_seq = self.go_mesh_seq(self.get_function_spaces_decorator("R", 0, 0)) with self.assertRaises(ValueError) as cm: - mesh_seq.get_enriched_mesh_seq(enrichment_method="q") + self.go_mesh_seq().get_enriched_mesh_seq(enrichment_method="q") self.assertEqual(str(cm.exception), "Enrichment method 'q' not supported.") def test_num_enrichments_error(self): - mesh_seq = self.go_mesh_seq(self.get_function_spaces_decorator("R", 0, 0)) with self.assertRaises(ValueError) as cm: - mesh_seq.get_enriched_mesh_seq(num_enrichments=0) + self.go_mesh_seq().get_enriched_mesh_seq(num_enrichments=0) msg = "A positive number of enrichments is required." self.assertEqual(str(cm.exception), msg) def test_form_error(self): - mesh_seq = self.go_mesh_seq(self.get_function_spaces_decorator("R", 0, 0)) with self.assertRaises(AttributeError) as cm: - mesh_seq.forms() + self.go_mesh_seq().forms() msg = ( "Forms have not been read in. Use read_forms({'field_name': F}) in" " get_solver to read in the forms." @@ -433,8 +460,9 @@ def test_h_enrichment_error(self): end_time = 1.0 num_subintervals = 2 dt = end_time / num_subintervals + field = Field("field", family="Real") mesh_seq = GoalOrientedMeshSeq( - TimePartition(end_time, num_subintervals, dt, "field"), + TimePartition(end_time, num_subintervals, dt, field), [UnitTriangleMesh()] * num_subintervals, get_qoi=self.constant_qoi, qoi_type="end_time", @@ -459,7 +487,7 @@ def test_h_enrichment_mesh(self, num_enrichments): |/ | |/ |/ | |/|/|/|/| o-------o o---o---o o-o-o-o-o """ - mesh_seq = self.go_mesh_seq(self.get_function_spaces_decorator("R", 0, 0)) + mesh_seq = self.go_mesh_seq() mesh_seq_e = mesh_seq.get_enriched_mesh_seq( enrichment_method="h", num_enrichments=num_enrichments ) @@ -476,31 +504,30 @@ def test_h_enrichment_mesh(self, num_enrichments): @parameterized.expand( [ - ("DG", 0, 0), - ("DG", 0, 1), - ("CG", 1, 0), - ("CG", 1, 1), - ("CG", 2, 0), - ("CG", 2, 1), + ("Discontinuous Lagrange", 0, 0), + ("Discontinuous Lagrange", 0, 1), + ("Lagrange", 1, 0), + ("Lagrange", 1, 1), + ("Lagrange", 2, 0), + ("Lagrange", 2, 1), ] ) def test_h_enrichment_space(self, family, degree, rank): - mesh_seq = self.go_mesh_seq( - self.get_function_spaces_decorator(family, degree, rank) - ) + mesh_seq = self.go_mesh_seq(element=self.element(family, degree, rank)) mesh_seq_e = mesh_seq.get_enriched_mesh_seq( enrichment_method="h", num_enrichments=1 ) - fspace = mesh_seq.function_spaces[self.field][0] + field_name0 = mesh_seq.field_names[0] + fspace = mesh_seq.function_spaces[field_name0][0] element = fspace.ufl_element() - enriched_fspace = mesh_seq_e.function_spaces[self.field][0] + enriched_fspace = mesh_seq_e.function_spaces[field_name0][0] enriched_element = enriched_fspace.ufl_element() self.assertEqual(element.family(), enriched_element.family()) self.assertEqual(element.degree(), enriched_element.degree()) self.assertEqual(fspace.value_shape, enriched_fspace.value_shape) def test_p_enrichment_mesh(self): - mesh_seq = self.go_mesh_seq(self.get_function_spaces_decorator("CG", 1, 0)) + mesh_seq = self.go_mesh_seq(self.element("Lagrange", 1, 0)) mesh_seq_e = mesh_seq.get_enriched_mesh_seq( enrichment_method="p", num_enrichments=1 ) @@ -509,30 +536,29 @@ def test_p_enrichment_mesh(self): @parameterized.expand( [ - ("DG", 0, 0, 1), - ("DG", 0, 0, 2), - ("DG", 0, 1, 1), - ("DG", 0, 1, 2), - ("CG", 1, 0, 1), - ("CG", 1, 0, 2), - ("CG", 1, 1, 1), - ("CG", 1, 1, 2), - ("CG", 2, 0, 1), - ("CG", 2, 0, 2), - ("CG", 2, 1, 1), - ("CG", 2, 1, 2), + ("Discontinuous Lagrange", 0, 0, 1), + ("Discontinuous Lagrange", 0, 0, 2), + ("Discontinuous Lagrange", 0, 1, 1), + ("Discontinuous Lagrange", 0, 1, 2), + ("Lagrange", 1, 0, 1), + ("Lagrange", 1, 0, 2), + ("Lagrange", 1, 1, 1), + ("Lagrange", 1, 1, 2), + ("Lagrange", 2, 0, 1), + ("Lagrange", 2, 0, 2), + ("Lagrange", 2, 1, 1), + ("Lagrange", 2, 1, 2), ] ) def test_p_enrichment_space(self, family, degree, rank, num_enrichments): - mesh_seq = self.go_mesh_seq( - self.get_function_spaces_decorator(family, degree, rank) - ) + mesh_seq = self.go_mesh_seq(element=self.element(family, degree, rank)) mesh_seq_e = mesh_seq.get_enriched_mesh_seq( enrichment_method="p", num_enrichments=num_enrichments ) - fspace = mesh_seq.function_spaces[self.field][0] + field_name0 = mesh_seq.field_names[0] + fspace = mesh_seq.function_spaces[field_name0][0] element = fspace.ufl_element() - enriched_fspace = mesh_seq_e.function_spaces[self.field][0] + enriched_fspace = mesh_seq_e.function_spaces[field_name0][0] enriched_element = enriched_fspace.ufl_element() self.assertEqual(element.family(), enriched_element.family()) self.assertEqual(element.degree() + num_enrichments, enriched_element.degree()) @@ -540,38 +566,36 @@ def test_p_enrichment_space(self, family, degree, rank, num_enrichments): @parameterized.expand( [ - ("DG", 0, 0, "h", 1), - ("DG", 0, 0, "h", 2), - ("CG", 1, 0, "h", 1), - ("CG", 1, 0, "h", 2), - ("CG", 2, 0, "h", 1), - ("CG", 2, 0, "h", 2), - ("DG", 0, 0, "p", 1), - ("DG", 0, 0, "p", 2), - ("CG", 1, 0, "p", 1), - ("CG", 1, 0, "p", 2), - ("CG", 2, 0, "p", 1), - ("CG", 2, 0, "p", 2), - ("DG", 0, 1, "h", 1), - ("DG", 0, 1, "h", 2), - ("CG", 1, 1, "h", 1), - ("CG", 1, 1, "h", 2), - ("CG", 2, 1, "h", 1), - ("CG", 2, 1, "h", 2), - ("DG", 0, 1, "p", 1), - ("DG", 0, 1, "p", 2), - ("CG", 1, 1, "p", 1), - ("CG", 1, 1, "p", 2), - ("CG", 2, 1, "p", 1), - ("CG", 2, 1, "p", 2), + ("Discontinuous Lagrange", 0, 0, "h", 1), + ("Discontinuous Lagrange", 0, 0, "h", 2), + ("Lagrange", 1, 0, "h", 1), + ("Lagrange", 1, 0, "h", 2), + ("Lagrange", 2, 0, "h", 1), + ("Lagrange", 2, 0, "h", 2), + ("Discontinuous Lagrange", 0, 0, "p", 1), + ("Discontinuous Lagrange", 0, 0, "p", 2), + ("Lagrange", 1, 0, "p", 1), + ("Lagrange", 1, 0, "p", 2), + ("Lagrange", 2, 0, "p", 1), + ("Lagrange", 2, 0, "p", 2), + ("Discontinuous Lagrange", 0, 1, "h", 1), + ("Discontinuous Lagrange", 0, 1, "h", 2), + ("Lagrange", 1, 1, "h", 1), + ("Lagrange", 1, 1, "h", 2), + ("Lagrange", 2, 1, "h", 1), + ("Lagrange", 2, 1, "h", 2), + ("Discontinuous Lagrange", 0, 1, "p", 1), + ("Discontinuous Lagrange", 0, 1, "p", 2), + ("Lagrange", 1, 1, "p", 1), + ("Lagrange", 1, 1, "p", 2), + ("Lagrange", 2, 1, "p", 1), + ("Lagrange", 2, 1, "p", 2), ] ) def test_enrichment_transfer( self, family, degree, rank, enrichment_method, num_enrichments ): - mesh_seq = self.go_mesh_seq( - self.get_function_spaces_decorator(family, degree, rank) - ) + mesh_seq = self.go_mesh_seq(element=self.element(family, degree, rank)) mesh_seq_e = mesh_seq.get_enriched_mesh_seq( enrichment_method=enrichment_method, num_enrichments=num_enrichments ) @@ -595,7 +619,7 @@ def test_constant_coefficients(self): # Solve over the first (only) subinterval next(mesh_seq._solve_adjoint(track_coefficients=True)) # Check no coefficients have changed - self.assertEqual(mesh_seq._changed_form_coeffs, {self.field: {}}) + self.assertEqual(mesh_seq._changed_form_coeffs, {self.field.name: {}}) def test_changed_coefficients(self): # Change coefficient f by coeff_diff every timestep @@ -603,7 +627,7 @@ def test_changed_coefficients(self): mesh_seq = self.go_mesh_seq(coeff_diff=coeff_diff) # Solve over the first (only) subinterval next(mesh_seq._solve_adjoint(track_coefficients=True)) - changed_coeffs_dict = mesh_seq._changed_form_coeffs[self.field] + changed_coeffs_dict = mesh_seq._changed_form_coeffs[self.field.name] coeff_idx = next(iter(changed_coeffs_dict)) for export_idx, f in changed_coeffs_dict[coeff_idx].items(): - self.assertTrue(f.vector().gather() == [1.0001 + export_idx * coeff_diff]) + self.assertTrue(f.dat.data == [1.0001 + export_idx * coeff_diff]) diff --git a/test/adjoint/test_gradient.py b/test/adjoint/test_gradient.py index 2c8c4b05..54898d86 100644 --- a/test/adjoint/test_gradient.py +++ b/test/adjoint/test_gradient.py @@ -15,6 +15,7 @@ from parameterized import parameterized from goalie.adjoint import AdjointMeshSeq, annotate_qoi +from goalie.field import Field from goalie.time_partition import TimeInterval, TimePartition @@ -24,8 +25,9 @@ class TestExceptions(unittest.TestCase): """ def test_attribute_error(self): + field = Field("field", family="Real", degree=0, unsteady=False) mesh_seq = AdjointMeshSeq( - TimeInterval(1.0, 1.0, "field"), + TimeInterval(1.0, 1.0, field), UnitIntervalMesh(1), qoi_type="steady", ) @@ -66,10 +68,10 @@ def solver(index): fs = self.function_spaces["field"][index] tp = self.time_partition if tp.steady: - u = self.fields["field"] + u = self.field_functions["field"] u_ = Function(fs, name="field_old").assign(u) else: - u, u_ = self.fields["field"] + u, u_ = self.field_functions["field"] v = TestFunction(fs) F = u * v * ufl.dx - Constant(self.scalar) * u_ * v * ufl.dx @@ -103,14 +105,14 @@ def get_qoi(self, index): tp = self.time_partition def steady_qoi(): - return self.integrand(self.fields["field"]) * ufl.dx + return self.integrand(self.field_functions["field"]) * ufl.dx def end_time_qoi(): - return self.integrand(self.fields["field"][0]) * ufl.dx + return self.integrand(self.field_functions["field"][0]) * ufl.dx def time_integrated_qoi(t): dt = tp.timesteps[index] - return dt * self.integrand(self.fields["field"][0]) * ufl.dx + return dt * self.integrand(self.field_functions["field"][0]) * ufl.dx if self.qoi_type == "steady": return steady_qoi @@ -147,8 +149,9 @@ class TestGradientComputation(unittest.TestCase): Unit tests that check gradient values can be computed correctly. """ - def time_partition(self, num_subintervals, dt): - return TimePartition(1.0, num_subintervals, dt, "field") + def time_partition(self, num_subintervals, dt, unsteady=True): + field = Field("field", family="Real", degree=0, unsteady=unsteady) + return TimePartition(1.0, num_subintervals, dt, field) @parameterized.expand( [ @@ -165,14 +168,14 @@ def time_partition(self, num_subintervals, dt): def test_single_timestep_steady_qoi(self, qoi_degree, initial_value): mesh_seq = GradientTestMeshSeq( {"qoi_degree": qoi_degree, "initial_value": initial_value}, - self.time_partition(1, 1.0), + self.time_partition(1, 1.0, unsteady=False), UnitIntervalMesh(1), qoi_type="steady", ) mesh_seq.solve_adjoint(compute_gradient=True) self.assertTrue( np.allclose( - mesh_seq.gradient[0].dat.data, + mesh_seq.gradient["field"].dat.data, mesh_seq.expected_gradient(), ) ) @@ -199,7 +202,7 @@ def test_two_timesteps_end_time_qoi(self, qoi_degree, initial_value): mesh_seq.solve_adjoint(compute_gradient=True) self.assertTrue( np.allclose( - mesh_seq.gradient[0].dat.data, + mesh_seq.gradient["field"].dat.data, mesh_seq.expected_gradient(), ) ) @@ -226,7 +229,7 @@ def test_two_subintervals_end_time_qoi(self, qoi_degree, initial_value): mesh_seq.solve_adjoint(compute_gradient=True) self.assertTrue( np.allclose( - mesh_seq.gradient[0].dat.data, + mesh_seq.gradient["field"].dat.data, mesh_seq.expected_gradient(), ) ) @@ -253,7 +256,7 @@ def test_two_subintervals_time_integrated_qoi(self, qoi_degree, initial_value): mesh_seq.solve_adjoint(compute_gradient=True) self.assertTrue( np.allclose( - mesh_seq.gradient[0].dat.data, + mesh_seq.gradient["field"].dat.data, mesh_seq.expected_gradient(), ) ) diff --git a/test/adjoint/test_utils.py b/test/adjoint/test_utils.py index ee3b3bad..277601cd 100644 --- a/test/adjoint/test_utils.py +++ b/test/adjoint/test_utils.py @@ -7,6 +7,7 @@ from firedrake.utility_meshes import UnitSquareMesh from goalie.adjoint import AdjointMeshSeq, annotate_qoi +from goalie.field import Field from goalie.go_mesh_seq import GoalOrientedAdaptParameters from goalie.time_partition import TimeInterval @@ -17,7 +18,7 @@ class TestAdjointUtils(unittest.TestCase): """ def setUp(self): - self.time_interval = TimeInterval(1.0, [0.5], ["field"]) + self.time_interval = TimeInterval(1.0, [0.5], Field("field", family="Real")) self.mesh = UnitSquareMesh(1, 1) def mesh_seq(self, qoi_type="end_time"): @@ -100,7 +101,7 @@ def test_annotate_qoi_not_steady(self): assert str(cm.exception) == msg def test_annotate_qoi_steady(self): - time_interval = TimeInterval(1.0, [1.0], ["field"]) + time_interval = TimeInterval(1.0, [1.0], Field("field", family="Real")) with self.assertRaises(ValueError) as cm: AdjointMeshSeq(time_interval, [self.mesh], qoi_type="end_time") msg = "Time partition is steady but the QoI type is set to 'end_time'." diff --git a/test/test_error_estimation.py b/test/test_error_estimation.py index 1ea19a67..e5d2d215 100644 --- a/test/test_error_estimation.py +++ b/test/test_error_estimation.py @@ -12,6 +12,7 @@ form2indicator, get_dwr_indicator, ) +from goalie.field import Field from goalie.function_data import IndicatorData from goalie.go_mesh_seq import GoalOrientedMeshSeq from goalie.time_partition import TimeInstant, TimePartition @@ -24,6 +25,7 @@ class ErrorEstimationTestCase(unittest.TestCase): def setUp(self): self.mesh = UnitSquareMesh(1, 1) + self.field = Field("field", family="Real", degree=0) self.fs = FunctionSpace(self.mesh, "CG", 1) self.trial = TrialFunction(self.fs) self.test = TestFunction(self.fs) @@ -70,18 +72,23 @@ class TestIndicators2Estimator(ErrorEstimationTestCase): def mesh_seq(self, time_partition=None): num_timesteps = 1 if time_partition is None else time_partition.num_timesteps return GoalOrientedMeshSeq( - time_partition or TimeInstant("field"), + time_partition or TimeInstant(self.field), self.mesh, qoi_type="steady" if num_timesteps == 1 else "end_time", ) def test_time_partition_wrong_field_error(self): - mesh_seq = self.mesh_seq(TimeInstant("field")) - time_partition = TimeInstant("f") - mesh_seq._indicators = IndicatorData(time_partition, mesh_seq.meshes) + time_partition1 = TimeInstant(self.field) + field2 = Field("field2", family="Real", degree=0) + time_partition2 = TimeInstant(field2) + mesh_seq = self.mesh_seq(time_partition=time_partition1) + mesh_seq._indicators = IndicatorData(time_partition2, mesh_seq.meshes) with self.assertRaises(ValueError) as cm: mesh_seq.error_estimate() - msg = "Key 'f' does not exist in the TimePartition provided." + msg = ( + "Field 'field' is not associated with" + " object." + ) self.assertEqual(str(cm.exception), msg) def test_absolute_value_type_error(self): @@ -92,14 +99,14 @@ def test_absolute_value_type_error(self): self.assertEqual(str(cm.exception), msg) def test_unit_time_instant(self): - mesh_seq = self.mesh_seq(time_partition=TimeInstant("field", time=1.0)) + mesh_seq = self.mesh_seq(time_partition=TimeInstant(self.field, time=1.0)) mesh_seq.indicators["field"][0][0].assign(form2indicator(self.one * ufl.dx)) estimator = mesh_seq.error_estimate() self.assertAlmostEqual(estimator, 1) # 1 * (0.5 + 0.5) @parameterized.expand([[False], [True]]) def test_unit_time_instant_abs(self, absolute_value): - mesh_seq = self.mesh_seq(time_partition=TimeInstant("field", time=1.0)) + mesh_seq = self.mesh_seq(time_partition=TimeInstant(self.field, time=1.0)) mesh_seq.indicators["field"][0][0].assign(form2indicator(-self.one * ufl.dx)) estimator = mesh_seq.error_estimate(absolute_value=absolute_value) self.assertAlmostEqual( @@ -107,14 +114,14 @@ def test_unit_time_instant_abs(self, absolute_value): ) # (-)1 * (0.5 + 0.5) def test_half_time_instant(self): - mesh_seq = self.mesh_seq(time_partition=TimeInstant("field", time=0.5)) + mesh_seq = self.mesh_seq(time_partition=TimeInstant(self.field, time=0.5)) mesh_seq.indicators["field"][0][0].assign(form2indicator(self.one * ufl.dx)) estimator = mesh_seq.error_estimate() self.assertAlmostEqual(estimator, 0.5) # 0.5 * (0.5 + 0.5) def test_time_partition_same_timestep(self): mesh_seq = self.mesh_seq( - time_partition=TimePartition(1.0, 2, [0.5, 0.5], ["field"]) + time_partition=TimePartition(1.0, 2, [0.5, 0.5], [self.field]) ) mesh_seq.indicators["field"][0][0].assign(form2indicator(2 * self.one * ufl.dx)) estimator = mesh_seq.error_estimate() @@ -122,7 +129,7 @@ def test_time_partition_same_timestep(self): def test_time_partition_different_timesteps(self): mesh_seq = self.mesh_seq( - time_partition=TimePartition(1.0, 2, [0.5, 0.25], ["field"]) + time_partition=TimePartition(1.0, 2, [0.5, 0.25], [self.field]) ) indicator = form2indicator(self.one * ufl.dx) mesh_seq.indicators["field"][0][0].assign(indicator) @@ -134,11 +141,12 @@ def test_time_partition_different_timesteps(self): ) # 0.5 * (0.5 + 0.5) + 0.25 * 2 * (0.5 + 0.5) def test_time_instant_multiple_fields(self): + field2 = Field("field2", family="Real", degree=0) mesh_seq = self.mesh_seq( - time_partition=TimeInstant(["field1", "field2"], time=1.0) + time_partition=TimeInstant([self.field, field2], time=1.0) ) indicator = form2indicator(self.one * ufl.dx) - mesh_seq.indicators["field1"][0][0].assign(indicator) + mesh_seq.indicators["field"][0][0].assign(indicator) mesh_seq.indicators["field2"][0][0].assign(indicator) estimator = mesh_seq.error_estimate() self.assertAlmostEqual(estimator, 2) # 2 * (1 * (0.5 + 0.5)) diff --git a/test/test_field.py b/test/test_field.py new file mode 100644 index 00000000..f47371ad --- /dev/null +++ b/test/test_field.py @@ -0,0 +1,168 @@ +import unittest + +import ufl +from finat.ufl import FiniteElement +from firedrake.utility_meshes import UnitIntervalMesh, UnitTriangleMesh + +from goalie.field import Field + + +def p1_element(): + return FiniteElement("Lagrange", ufl.triangle, 1) + + +def real_element(): + return FiniteElement("Real", ufl.interval, 0) + + +def mesh1d(): + return UnitIntervalMesh(1) + + +def mesh2d(): + return UnitTriangleMesh() + + +class TestExceptions(unittest.TestCase): + """ + Test exceptions raised by Field class. + + NOTE: We don't check the exact exception raised in the + test_make_scalar_element_error* tests because those errors would be raised in + Firedrake. + """ + + def test_unexpected_kwarg_error(self): + with self.assertRaises(ValueError) as cm: + Field("field", family="Real", degree=0, kwarg="blah") + self.assertEqual(str(cm.exception), "Unexpected keyword argument 'kwarg'.") + + def test_element_and_rank_error(self): + with self.assertRaises(Exception) as cm: + Field("field", p1_element(), vector=True) + msg = "The finite_element and vector arguments cannot be used in conjunction." + self.assertEqual(str(cm.exception), msg) + + def test_field_invalid_finite_element(self): + with self.assertRaises(TypeError) as cm: + Field("field", "element") + msg = ( + "Field finite element must be a FiniteElement, MixedElement, VectorElement," + " or TensorElement object." + ) + self.assertEqual(str(cm.exception), msg) + + def test_insufficient_arguments(self): + with self.assertRaises(ValueError) as cm: + Field("field") + msg = "Either the finite_element or family must be specified." + self.assertEqual(str(cm.exception), msg) + + +class TestInit(unittest.TestCase): + """Test initialisation of Field class.""" + + def test_field_defaults(self): + field = Field("field", p1_element()) + self.assertTrue(field.solved_for) + self.assertTrue(field.unsteady) + + def test_field_initialization(self): + field = Field( + name="field", + finite_element=p1_element(), + solved_for=False, + unsteady=False, + ) + self.assertEqual(field.name, "field") + self.assertEqual(field.finite_element, p1_element()) + self.assertFalse(field.solved_for) + self.assertFalse(field.unsteady) + + def test_field_alternative_real(self): + field = Field( + name="field", + family="Real", + degree=0, + ) + self.assertEqual(field.get_element(mesh1d()), real_element()) + + def test_field_alternative_p1(self): + field = Field( + name="field", + family="Lagrange", + degree=1, + ) + self.assertEqual(field.get_element(mesh2d()), p1_element()) + + +class TestGetElement(unittest.TestCase): + """Test `get_element` method of Field class.""" + + def test_make_scalar_element_error1(self): + field = Field("field", family="Real") + with self.assertRaises(AttributeError): + field.get_element("mesh") + + def test_make_scalar_element_error2(self): + field = Field("field", family="family") + with self.assertRaises(ValueError): + field.get_element(mesh1d()) + + def test_make_scalar_element_error3(self): + field = Field("field", family="Real", degree=-1) + with self.assertRaises(ValueError): + field.get_element(mesh1d()) + + def test_field_alternative_real(self): + field = Field(name="field", family="Real", degree=0) + finite_element = field.get_element(mesh1d()) + self.assertEqual(finite_element, real_element()) + + def test_field_alternative_p1(self): + field = Field(name="field", family="Lagrange", degree=1) + finite_element = field.get_element(mesh2d()) + self.assertEqual(finite_element, p1_element()) + + +class TestInterrogation(unittest.TestCase): + """Test interrogation of Field class.""" + + def test_str(self): + self.assertEqual(str(Field("field", p1_element())), "Field(field)") + + def test_repr(self): + expected_repr = ( + "Field('field', , solved_for=True, unsteady=True)" + ) + self.assertEqual(repr(Field("field", p1_element())), expected_repr) + + def test_eq(self): + field1 = Field("field", p1_element(), solved_for=True, unsteady=True) + field2 = Field("field", p1_element(), solved_for=True, unsteady=True) + self.assertEqual(field1, field2) + + def test_ne_name(self): + field1 = Field("field1", p1_element(), solved_for=True, unsteady=True) + field2 = Field("field2", p1_element(), solved_for=True, unsteady=True) + self.assertNotEqual(field1, field2) + + def test_ne_element(self): + p2_element = FiniteElement("Lagrange", ufl.triangle, 2) + field1 = Field("field", p1_element(), solved_for=True, unsteady=True) + field2 = Field("field", p2_element, solved_for=True, unsteady=True) + self.assertNotEqual(field1, field2) + + def test_ne_solved_for(self): + field1 = Field("field", p1_element(), solved_for=True, unsteady=True) + field2 = Field("field", p1_element(), solved_for=False, unsteady=True) + self.assertNotEqual(field1, field2) + + def test_ne_unsteady(self): + field1 = Field("field", p1_element(), solved_for=True, unsteady=True) + field2 = Field("field", p1_element(), solved_for=True, unsteady=False) + self.assertNotEqual(field1, field2) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_function_data.py b/test/test_function_data.py index b969cb86..1ea2955a 100644 --- a/test/test_function_data.py +++ b/test/test_function_data.py @@ -12,6 +12,7 @@ from firedrake.mg.mesh import MeshHierarchy from firedrake.utility_meshes import UnitTriangleMesh +from goalie.field import Field from goalie.function_data import AdjointSolutionData, ForwardSolutionData, IndicatorData from goalie.time_partition import TimePartition from goalie.utility import AttrDict @@ -32,14 +33,14 @@ def setUpUnsteady(self): end_time = 1.0 self.num_subintervals = 2 timesteps = [0.5, 0.25] - self.field = "field" + self.field = Field("field", family="Real") self.num_exports = [1, 2] self.mesh = UnitTriangleMesh() self.time_partition = TimePartition( end_time, self.num_subintervals, timesteps, self.field ) self.function_spaces = { - self.field: [ + self.field.name: [ FunctionSpace(self.mesh, "DG", 0) for _ in range(self.num_subintervals) ] @@ -50,14 +51,14 @@ def setUpSteady(self): end_time = 1.0 self.num_subintervals = 1 timesteps = [1.0] - self.field = "field" + self.field = Field("field", family="Real") self.num_exports = [1] self.mesh = UnitTriangleMesh() self.time_partition = TimePartition( end_time, self.num_subintervals, timesteps, self.field ) self.function_spaces = { - self.field: [ + self.field.name: [ FunctionSpace(self.mesh, "DG", 0) for _ in range(self.num_subintervals) ] @@ -71,16 +72,18 @@ def _create_function_data(self): def test_extract_by_field(self): data = self.solution_data.extract(layout="field") self.assertTrue(isinstance(data, AttrDict)) - self.assertTrue(self.field in data) + self.assertTrue(self.field.name in data) for label in self.labels: - self.assertTrue(isinstance(data[self.field], AttrDict)) - self.assertTrue(label in data[self.field]) - self.assertTrue(isinstance(data[self.field][label], list)) - self.assertEqual(len(data[self.field][label]), self.num_subintervals) + self.assertTrue(isinstance(data[self.field.name], AttrDict)) + self.assertTrue(label in data[self.field.name]) + self.assertTrue(isinstance(data[self.field.name][label], list)) + self.assertEqual( + len(data[self.field.name][label]), self.num_subintervals + ) for i, num_exports in enumerate(self.num_exports): - self.assertTrue(isinstance(data[self.field][label][i], list)) - self.assertEqual(len(data[self.field][label][i]), num_exports) - for f in data[self.field][label][i]: + self.assertTrue(isinstance(data[self.field.name][label][i], list)) + self.assertEqual(len(data[self.field.name][label][i]), num_exports) + for f in data[self.field.name][label][i]: self.assertTrue(isinstance(f, Function)) def test_extract_by_label(self): @@ -89,13 +92,15 @@ def test_extract_by_label(self): for label in self.labels: self.assertTrue(label in data) self.assertTrue(isinstance(data[label], AttrDict)) - self.assertTrue(self.field in data[label]) - self.assertTrue(isinstance(data[label][self.field], list)) - self.assertEqual(len(data[label][self.field]), self.num_subintervals) + self.assertTrue(self.field.name in data[label]) + self.assertTrue(isinstance(data[label][self.field.name], list)) + self.assertEqual( + len(data[label][self.field.name]), self.num_subintervals + ) for i, num_exports in enumerate(self.num_exports): - self.assertTrue(isinstance(data[label][self.field][i], list)) - self.assertEqual(len(data[label][self.field][i]), num_exports) - for f in data[label][self.field][i]: + self.assertTrue(isinstance(data[label][self.field.name][i], list)) + self.assertEqual(len(data[label][self.field.name][i]), num_exports) + for f in data[label][self.field.name][i]: self.assertTrue(isinstance(f, Function)) def test_extract_by_subinterval(self): @@ -104,15 +109,15 @@ def test_extract_by_subinterval(self): self.assertEqual(len(data), self.num_subintervals) for i, sub_data in enumerate(data): self.assertTrue(isinstance(sub_data, AttrDict)) - self.assertTrue(self.field in sub_data) - self.assertTrue(isinstance(sub_data[self.field], AttrDict)) + self.assertTrue(self.field.name in sub_data) + self.assertTrue(isinstance(sub_data[self.field.name], AttrDict)) for label in self.labels: - self.assertTrue(label in sub_data[self.field]) - self.assertTrue(isinstance(sub_data[self.field][label], list)) + self.assertTrue(label in sub_data[self.field.name]) + self.assertTrue(isinstance(sub_data[self.field.name][label], list)) self.assertEqual( - len(sub_data[self.field][label]), self.num_exports[i] + len(sub_data[self.field.name][label]), self.num_exports[i] ) - for f in sub_data[self.field][label]: + for f in sub_data[self.field.name][label]: self.assertTrue(isinstance(f, Function)) @@ -192,12 +197,12 @@ def _create_function_data(self): def _test_extract_by_field_or_label(self, data): self.assertTrue(isinstance(data, AttrDict)) - self.assertTrue(self.field in data) - self.assertEqual(len(data[self.field]), self.num_subintervals) + self.assertTrue(self.field.name in data) + self.assertEqual(len(data[self.field.name]), self.num_subintervals) for i, num_exports in enumerate(self.num_exports): - self.assertTrue(isinstance(data[self.field][i], list)) - self.assertEqual(len(data[self.field][i]), num_exports) - for f in data[self.field][i]: + self.assertTrue(isinstance(data[self.field.name][i], list)) + self.assertEqual(len(data[self.field.name][i]), num_exports) + for f in data[self.field.name][i]: self.assertTrue(isinstance(f, Function)) def test_extract_by_field(self): @@ -214,9 +219,9 @@ def test_extract_by_subinterval(self): self.assertEqual(len(data), self.num_subintervals) for sub_data in data: self.assertTrue(isinstance(sub_data, AttrDict)) - self.assertTrue(self.field in sub_data) - self.assertTrue(isinstance(sub_data[self.field], list)) - for f in sub_data[self.field]: + self.assertTrue(self.field.name in sub_data) + self.assertTrue(isinstance(sub_data[self.field.name], list)) + for f in sub_data[self.field.name]: self.assertTrue(isinstance(f, Function)) @@ -290,11 +295,11 @@ def _create_function_data(self): # Assign 1 to all functions tp = self.solution_data.time_partition - for field in tp.field_names: + for fieldname in tp.field_names: for label in self.solution_data.labels: for i in range(tp.num_subintervals): for j in range(tp.num_exports_per_subinterval[i] - 1): - self.solution_data._data[field][label][i][j].assign(1) + self.solution_data._data[fieldname][label][i][j].assign(1) def test_transfer_method_error(self): target_solution_data = ForwardSolutionData( @@ -314,10 +319,10 @@ def test_transfer_subintervals_error(self): 1.5 * self.time_partition.end_time, self.time_partition.num_subintervals + 1, self.time_partition.timesteps + [0.25], - self.time_partition.field_names, + self.time_partition.field_metadata, ) target_function_spaces = { - self.field: [ + self.field.name: [ FunctionSpace(self.mesh, "DG", 0) for _ in range(target_time_partition.num_subintervals) ] @@ -338,11 +343,11 @@ def test_transfer_exports_error(self): self.time_partition.end_time, self.time_partition.num_subintervals, self.time_partition.timesteps, - self.time_partition.field_names, + self.time_partition.field_metadata, num_timesteps_per_export=[1, 2], ) target_function_spaces = { - self.field: [ + self.field.name: [ FunctionSpace(self.mesh, "DG", 0) for _ in range(target_time_partition.num_subintervals) ] @@ -363,7 +368,7 @@ def test_transfer_common_fields_error(self): self.time_partition.end_time, self.time_partition.num_subintervals, self.time_partition.timesteps, - ["different_field"], + [Field("different_field", family="Real")], ) target_function_spaces = { "different_field": [ @@ -399,15 +404,19 @@ def test_transfer_interpolate(self): ) target_solution_data._create_data() self.solution_data.transfer(target_solution_data, method="interpolate") - for field in self.solution_data.time_partition.field_names: + for fieldname in self.solution_data.time_partition.field_names: for label in self.solution_data.labels: for i in range(self.solution_data.time_partition.num_subintervals): for j in range( self.solution_data.time_partition.num_exports_per_subinterval[i] - 1 ): - source_function = self.solution_data._data[field][label][i][j] - target_function = target_solution_data._data[field][label][i][j] + source_function = self.solution_data._data[fieldname][label][i][ + j + ] + target_function = target_solution_data._data[fieldname][label][ + i + ][j] self.assertTrue( source_function.dat.data.all() == target_function.dat.data.all() @@ -419,15 +428,19 @@ def test_transfer_project(self): ) target_solution_data._create_data() self.solution_data.transfer(target_solution_data, method="project") - for field in self.solution_data.time_partition.field_names: + for fieldname in self.solution_data.time_partition.field_names: for label in self.solution_data.labels: for i in range(self.solution_data.time_partition.num_subintervals): for j in range( self.solution_data.time_partition.num_exports_per_subinterval[i] - 1 ): - source_function = self.solution_data._data[field][label][i][j] - target_function = target_solution_data._data[field][label][i][j] + source_function = self.solution_data._data[fieldname][label][i][ + j + ] + target_function = target_solution_data._data[fieldname][label][ + i + ][j] self.assertTrue( source_function.dat.data.all() == target_function.dat.data.all() @@ -436,7 +449,7 @@ def test_transfer_project(self): def test_transfer_prolong(self): enriched_mesh = MeshHierarchy(self.mesh, 1)[-1] target_function_spaces = { - self.field: [ + self.field.name: [ FunctionSpace(enriched_mesh, "DG", 0) for _ in range(self.num_subintervals) ] @@ -446,15 +459,19 @@ def test_transfer_prolong(self): ) target_solution_data._create_data() self.solution_data.transfer(target_solution_data, method="prolong") - for field in self.solution_data.time_partition.field_names: + for fieldname in self.solution_data.time_partition.field_names: for label in self.solution_data.labels: for i in range(self.solution_data.time_partition.num_subintervals): for j in range( self.solution_data.time_partition.num_exports_per_subinterval[i] - 1 ): - source_function = self.solution_data._data[field][label][i][j] - target_function = target_solution_data._data[field][label][i][j] + source_function = self.solution_data._data[fieldname][label][i][ + j + ] + target_function = target_solution_data._data[fieldname][label][ + i + ][j] self.assertTrue( source_function.dat.data.all() == target_function.dat.data.all() diff --git a/test/test_mesh_seq.py b/test/test_mesh_seq.py index fa918297..da513954 100644 --- a/test/test_mesh_seq.py +++ b/test/test_mesh_seq.py @@ -14,6 +14,7 @@ ) from parameterized import parameterized +from goalie.field import Field from goalie.mesh_seq import MeshSeq from goalie.time_partition import TimeInterval, TimePartition @@ -29,8 +30,9 @@ class MeshSeqTestCase(unittest.TestCase): """ def setUp(self): - self.time_partition = TimePartition(1.0, 2, [0.5, 0.5], ["field"]) - self.time_interval = TimeInterval(1.0, [0.5], ["field"]) + self.field = Field("field", family="Real") + self.time_partition = TimePartition(1.0, 2, [0.5, 0.5], self.field) + self.time_interval = TimeInterval(1.0, [0.5], self.field) def trivial_mesh(self, dim): try: @@ -55,44 +57,46 @@ def test_inconsistent_dim_error(self): msg = "Meshes must all have the same topological dimension." self.assertEqual(str(cm.exception), msg) - @parameterized.expand(["get_function_spaces", "get_solver"]) - def test_notimplemented_error(self, function_name): + def test_get_solver_notimplemented_error(self): mesh_seq = MeshSeq(self.time_interval, self.trivial_mesh(2)) with self.assertRaises(NotImplementedError) as cm: - if function_name == "get_function_spaces": - getattr(mesh_seq, function_name)(mesh_seq[0]) - else: - getattr(mesh_seq, function_name)() - self.assertEqual(str(cm.exception), f"'{function_name}' needs implementing.") - - @parameterized.expand(["get_function_spaces", "get_initial_condition"]) - def test_return_dict_error(self, method): - kwargs = {method: lambda _: 0} + mesh_seq.get_solver() + self.assertEqual(str(cm.exception), "'get_solver' needs implementing.") + + def test_return_dict_error(self): with self.assertRaises(AssertionError) as cm: - MeshSeq(self.time_interval, self.trivial_mesh(2), **kwargs) - self.assertEqual(str(cm.exception), f"{method} should return a dict") + MeshSeq( + self.time_interval, + self.trivial_mesh(2), + get_initial_condition=lambda _: 0, + ) + msg = "get_initial_condition should return a dict" + self.assertEqual(str(cm.exception), msg) - @parameterized.expand(["get_function_spaces", "get_initial_condition"]) - def test_missing_field_error(self, method): - kwargs = {method: lambda _: {}} + def test_missing_field_error(self): with self.assertRaises(AssertionError) as cm: - MeshSeq(self.time_interval, self.trivial_mesh(2), **kwargs) - msg = "missing fields {'field'} in " + f"{method}" + MeshSeq( + self.time_interval, + self.trivial_mesh(2), + get_initial_condition=lambda _: {}, + ) + msg = "missing fields {'field'} in get_initial_condition" self.assertEqual(str(cm.exception), msg) - @parameterized.expand(["get_function_spaces", "get_initial_condition"]) - def test_unexpected_field_error(self, method): - kwargs = {method: lambda _: {"field": None, "extra_field": None}} + def test_unexpected_field_error(self): with self.assertRaises(AssertionError) as cm: - MeshSeq(self.time_interval, self.trivial_mesh(2), **kwargs) - msg = "unexpected fields {'extra_field'} in " + f"{method}" + MeshSeq( + self.time_interval, + self.trivial_mesh(2), + get_initial_condition=lambda _: {"field": None, "extra_field": None}, + ) + msg = "unexpected fields {'extra_field'} in get_initial_condition" self.assertEqual(str(cm.exception), msg) def test_solver_generator_error(self): mesh = self.trivial_mesh(2) f_space = FunctionSpace(mesh, "CG", 1) kwargs = { - "get_function_spaces": lambda _: {"field": f_space}, "get_initial_condition": lambda _: {"field": Function(f_space)}, "get_solver": lambda _: lambda *_: {}, } @@ -113,10 +117,6 @@ class TestGeneric(BaseClasses.MeshSeqTestCase): Generic unit tests for :class:`MeshSeq`. """ - def setUp(self): - self.time_partition = TimePartition(1.0, 2, [0.5, 0.5], ["field"]) - self.time_interval = TimeInterval(1.0, [0.5], ["field"]) - def test_setitem(self): mesh1 = UnitSquareMesh(1, 1, diagonal="left") mesh2 = UnitSquareMesh(1, 1, diagonal="right") diff --git a/test/test_metric.py b/test/test_metric.py index 6c2637e1..28c8f35f 100644 --- a/test/test_metric.py +++ b/test/test_metric.py @@ -13,6 +13,7 @@ from parameterized import parameterized from utility import mesh_for_sensors, uniform_mesh +from goalie.field import Field from goalie.metric import ( enforce_variable_constraints, ramp_complexity, @@ -65,10 +66,10 @@ class TestMetricNormalisation(BaseClasses.MetricTestCase): def setUp(self): super().setUp() - self.time_partition = TimeInterval(1.0, 1.0, "u") + self.time_partition = TimeInterval(1.0, 1.0, Field("u", family="Real")) def test_time_partition_length_error(self): - time_partition = TimePartition(1.0, 2, [0.5, 0.5], "u") + time_partition = TimePartition(1.0, 2, [0.5, 0.5], Field("u", family="Real")) mp = {"dm_plex_metric_target_complexity": 1.0} with self.assertRaises(ValueError) as cm: space_time_normalise([self.simple_metric], time_partition, [mp]) diff --git a/test/test_parallel.py b/test/test_parallel.py index 917a2330..552fd31e 100644 --- a/test/test_parallel.py +++ b/test/test_parallel.py @@ -2,6 +2,7 @@ from firedrake.utility_meshes import UnitCubeMesh, UnitSquareMesh from pyop2.mpi import COMM_WORLD +from goalie.field import Field from goalie.mesh_seq import MeshSeq from goalie.time_partition import TimeInterval @@ -9,7 +10,7 @@ @pytest.mark.parallel(nprocs=2) def test_counting_2d(): assert COMM_WORLD.size == 2 - time_interval = TimeInterval(1.0, [0.5], ["field"]) + time_interval = TimeInterval(1.0, [0.5], Field("field", family="Real")) mesh_seq = MeshSeq(time_interval, [UnitSquareMesh(3, 3)]) assert mesh_seq.count_elements() == [18] assert mesh_seq.count_vertices() == [16] @@ -18,7 +19,7 @@ def test_counting_2d(): @pytest.mark.parallel(nprocs=2) def test_counting_3d(): assert COMM_WORLD.size == 2 - time_interval = TimeInterval(1.0, [0.5], ["field"]) + time_interval = TimeInterval(1.0, [0.5], Field("field", family="Real")) mesh_seq = MeshSeq(time_interval, [UnitCubeMesh(3, 3, 3)]) assert mesh_seq.count_elements() == [162] assert mesh_seq.count_vertices() == [64] diff --git a/test/test_time_partition.py b/test/test_time_partition.py index 2f57a5fa..9734a60c 100644 --- a/test/test_time_partition.py +++ b/test/test_time_partition.py @@ -4,123 +4,58 @@ import unittest +from goalie.field import Field from goalie.time_partition import TimeInstant, TimeInterval, TimePartition -class TestSetup(unittest.TestCase): +class BaseTestCase(unittest.TestCase): """ - Tests related to the construction of time partitions. + Base class for unit tests related to the time partition objects. + """ + + def setUp(self): + self.end_time = 1.0 + self.field = Field("field", family="Real") + self.field_metadata_list = [self.field] + self.field_metadata_dict = {"field": self.field} + + +class TestExceptions(BaseTestCase): + """ + Tests for exceptions raised by the time partition objects. """ def test_time_instant_multiple_kwargs(self): with self.assertRaises(ValueError) as cm: - TimeInstant("field", time=1.0, end_time=1.0) + TimeInstant(self.field, time=1.0, end_time=1.0) msg = "Both 'time' and 'end_time' are set." self.assertEqual(str(cm.exception), msg) - def test_time_partition_eq_positive(self): - time_partition1 = TimePartition(1.0, 1, [1.0], "field") - time_partition2 = TimePartition(1.0, 1, [1.0], "field") - self.assertTrue(time_partition1 == time_partition2) - - def test_time_partition_eq_negative(self): - time_partition1 = TimePartition(1.0, 1, [1.0], "field") - time_partition2 = TimePartition(2.0, 1, [1.0], "field") - self.assertFalse(time_partition1 == time_partition2) - - def test_time_partition_ne_positive(self): - time_partition1 = TimePartition(1.0, 2, [0.5, 0.5], "field") - time_partition2 = TimePartition(1.0, 1, [1.0], "field") - self.assertTrue(time_partition1 != time_partition2) - - def test_time_partition_ne_negative(self): - time_partition1 = TimePartition( - 1.0, 1, [1.0], "field", num_timesteps_per_export=1 - ) - time_partition2 = TimePartition(1.0, 1, [1.0], "field") - self.assertFalse(time_partition1 != time_partition2) - - def test_time_interval_eq_positive(self): - time_interval1 = TimeInterval(1.0, 1.0, "field") - time_interval2 = TimeInterval((0.0, 1.0), 1.0, ["field"]) - self.assertTrue(time_interval1 == time_interval2) - - def test_time_interval_eq_negative(self): - time_interval1 = TimeInterval(1.0, 1.0, "field") - time_interval2 = TimeInterval((0.5, 1.0), 0.5, "field") - self.assertFalse(time_interval1 == time_interval2) - - def test_time_interval_ne_positive(self): - time_interval1 = TimeInterval(1.0, 1.0, "field") - time_interval2 = TimeInterval((-0.5, 0.5), 1.0, "field") - self.assertTrue(time_interval1 != time_interval2) - - def test_time_interval_ne_negative(self): - time_interval1 = TimeInterval(1.0, 1.0, "field") - time_interval2 = TimeInterval((0.0, 1.0), 1.0, ["field"]) - self.assertFalse(time_interval1 != time_interval2) - - def test_time_instant_eq_positive(self): - time_instant1 = TimeInstant("field", time=1.0) - time_instant2 = TimeInstant(["field"], time=1.0) - self.assertTrue(time_instant1 == time_instant2) - - def test_time_instant_eq_negative(self): - time_instant1 = TimeInstant("field", time=1.0) - time_instant2 = TimeInstant("f", time=1.0) - self.assertFalse(time_instant1 == time_instant2) - - def test_time_instant_ne_positive(self): - time_instant1 = TimeInstant("field", time=1.0) - time_instant2 = TimeInstant("field", time=2.0) - self.assertTrue(time_instant1 != time_instant2) - - def test_time_instant_ne_negative(self): - time_instant1 = TimeInstant("field", time=1.0) - time_instant2 = TimeInstant("field", end_time=1.0) - self.assertFalse(time_instant1 != time_instant2) - - def test_time_partition_eq_interval_positive(self): - time_partition = TimePartition(1.0, 1, [0.5], ["field"]) - time_interval = TimeInterval(1.0, 0.5, "field") - self.assertTrue(time_partition == time_interval) - - def test_time_partition_eq_interval_negative(self): - time_partition = TimePartition(1.0, 2, [0.5, 0.5], ["field"]) - time_interval = TimeInterval(1.0, 0.5, "field") - self.assertFalse(time_partition == time_interval) - - def test_time_partition_ne_interval_positive(self): - time_partition = TimePartition(0.5, 1, [0.5], "field") - time_interval = TimeInterval(1.0, 0.5, "field") - self.assertTrue(time_partition != time_interval) - - def test_time_partition_ne_interval_negative(self): - time_partition = TimePartition(1.0, 1, 0.5, ["field"], start_time=0.0) - time_interval = TimeInterval(1.0, 0.5, "field") - self.assertFalse(time_partition != time_interval) - def test_noninteger_num_subintervals(self): with self.assertRaises(ValueError) as cm: - TimePartition(1.0, 1.1, 0.5, "field") + TimePartition(1.0, 1.1, 0.5, self.field) msg = "Non-integer number of subintervals '1.1'." self.assertEqual(str(cm.exception), msg) def test_wrong_number_of_subintervals(self): with self.assertRaises(ValueError) as cm: - TimePartition(1.0, 1, 0.5, "field", subintervals=[(0.0, 0.5), (0.5, 1.0)]) + TimePartition( + 1.0, 1, 0.5, self.field, subintervals=[(0.0, 0.5), (0.5, 1.0)] + ) msg = "Number of subintervals provided differs from num_subintervals: 2 != 1." self.assertEqual(str(cm.exception), msg) def test_wrong_subinterval_start(self): with self.assertRaises(ValueError) as cm: - TimePartition(1.0, 1, 0.5, "field", subintervals=[(0.1, 1.0)]) + TimePartition(1.0, 1, 0.5, self.field, subintervals=[(0.1, 1.0)]) msg = "The first subinterval does not start at the start time: 0.1 != 0.0." self.assertEqual(str(cm.exception), msg) def test_inconsistent_subintervals(self): with self.assertRaises(ValueError) as cm: - TimePartition(1.0, 2, 0.5, "field", subintervals=[(0.0, 0.6), (0.5, 1.0)]) + TimePartition( + 1.0, 2, 0.5, self.field, subintervals=[(0.0, 0.6), (0.5, 1.0)] + ) msg = ( "The end of subinterval 0 does not match the start of subinterval 1:" " 0.6 != 0.5." @@ -129,25 +64,25 @@ def test_inconsistent_subintervals(self): def test_wrong_subinterval_end(self): with self.assertRaises(ValueError) as cm: - TimePartition(1.0, 1, 0.5, "field", subintervals=[(0.0, 1.1)]) + TimePartition(1.0, 1, 0.5, self.field, subintervals=[(0.0, 1.1)]) msg = "The final subinterval does not end at the end time: 1.1 != 1.0." self.assertEqual(str(cm.exception), msg) def test_wrong_num_timesteps(self): with self.assertRaises(ValueError) as cm: - TimePartition(1.0, 1, [0.5, 0.5], "field") + TimePartition(1.0, 1, [0.5, 0.5], self.field) msg = "Number of timesteps does not match num_subintervals: 2 != 1." self.assertEqual(str(cm.exception), msg) def test_noninteger_num_timesteps_per_subinterval(self): with self.assertRaises(ValueError) as cm: - TimePartition(1.0, 1, [0.4], "field") + TimePartition(1.0, 1, [0.4], self.field) msg = "Non-integer number of timesteps on subinterval 0: 2.5." self.assertEqual(str(cm.exception), msg) def test_noninteger_num_timesteps_per_export(self): with self.assertRaises(TypeError) as cm: - TimePartition(1.0, 1, [0.5], "field", num_timesteps_per_export=1.1) + TimePartition(1.0, 1, [0.5], self.field, num_timesteps_per_export=1.1) msg = ( "Expected number of timesteps per export on subinterval 0 to be an integer," " not ''." @@ -156,13 +91,13 @@ def test_noninteger_num_timesteps_per_export(self): def test_nonmatching_num_timesteps_per_export(self): with self.assertRaises(ValueError) as cm: - TimePartition(1.0, 1, [0.5], "field", num_timesteps_per_export=[1, 2]) + TimePartition(1.0, 1, [0.5], self.field, num_timesteps_per_export=[1, 2]) msg = "Number of timesteps per export and subinterval do not match: 2 != 1." self.assertEqual(str(cm.exception), msg) def test_indivisible_num_timesteps_per_export(self): with self.assertRaises(ValueError) as cm: - TimePartition(1.0, 1, [0.5], "field", num_timesteps_per_export=4) + TimePartition(1.0, 1, [0.5], self.field, num_timesteps_per_export=4) msg = ( "Number of timesteps per export does not divide number of timesteps per" " subinterval on subinterval 0: 2 | 4 != 0." @@ -171,40 +106,126 @@ def test_indivisible_num_timesteps_per_export(self): def test_debug_invalid_field(self): with self.assertRaises(AttributeError) as cm: - TimeInstant("field").debug("blah") + TimeInstant(self.field).debug("blah") msg = "Attribute 'blah' cannot be debugged because it doesn't exist." self.assertEqual(str(cm.exception), msg) - def test_field_type_error(self): - with self.assertRaises(ValueError) as cm: - TimeInstant("field", field_types="type") - msg = ( - "Expected field type for field 'field' to be either 'unsteady' or" - " 'steady', but got 'type'." - ) + def test_field_typeerror1(self): + with self.assertRaises(TypeError) as cm: + TimeInstant("field") + msg = "field_metadata argument must be a Field or a dict or list thereof." self.assertEqual(str(cm.exception), msg) - def test_num_field_types_error(self): + def test_field_typeerror2(self): + with self.assertRaises(TypeError) as cm: + TimeInstant(["field"]) + self.assertEqual(str(cm.exception), "All fields must be instances of Field.") + + def test_inconsistent_fieldname_valueerror(self): with self.assertRaises(ValueError) as cm: - TimeInstant("field", field_types=["type1", "type2"]) - msg = "Number of field names does not match number of field types: 1 != 2." + TimeInstant({"field1": Field("field2", family="Real")}) + msg = "Inconstent field names passed as field_metadata." self.assertEqual(str(cm.exception), msg) -class TestStringFormatting(unittest.TestCase): +class TestSetup(BaseTestCase): + """ + Tests related to the construction of time partitions. + """ + + def test_time_partition_eq_positive(self): + time_partition1 = TimePartition(1.0, 1, [1.0], self.field) + time_partition2 = TimePartition(1.0, 1, [1.0], self.field) + self.assertTrue(time_partition1 == time_partition2) + + def test_time_partition_eq_negative(self): + time_partition1 = TimePartition(1.0, 1, [1.0], self.field) + time_partition2 = TimePartition(2.0, 1, [1.0], self.field) + self.assertFalse(time_partition1 == time_partition2) + + def test_time_partition_ne_positive(self): + time_partition1 = TimePartition(1.0, 2, [0.5, 0.5], self.field) + time_partition2 = TimePartition(1.0, 1, [1.0], self.field) + self.assertTrue(time_partition1 != time_partition2) + + def test_time_partition_ne_negative(self): + time_partition1 = TimePartition( + 1.0, 1, [1.0], self.field, num_timesteps_per_export=1 + ) + time_partition2 = TimePartition(1.0, 1, [1.0], self.field) + self.assertFalse(time_partition1 != time_partition2) + + def test_time_interval_eq_positive(self): + time_interval1 = TimeInterval(1.0, 1.0, self.field) + time_interval2 = TimeInterval((0.0, 1.0), 1.0, [self.field]) + self.assertTrue(time_interval1 == time_interval2) + + def test_time_interval_eq_negative(self): + time_interval1 = TimeInterval(1.0, 1.0, self.field) + time_interval2 = TimeInterval((0.5, 1.0), 0.5, self.field) + self.assertFalse(time_interval1 == time_interval2) + + def test_time_interval_ne_positive(self): + time_interval1 = TimeInterval(1.0, 1.0, self.field) + time_interval2 = TimeInterval((-0.5, 0.5), 1.0, self.field) + self.assertTrue(time_interval1 != time_interval2) + + def test_time_interval_ne_negative(self): + time_interval1 = TimeInterval(1.0, 1.0, self.field) + time_interval2 = TimeInterval((0.0, 1.0), 1.0, [self.field]) + self.assertFalse(time_interval1 != time_interval2) + + def test_time_instant_eq_positive(self): + time_instant1 = TimeInstant(self.field, time=1.0) + time_instant2 = TimeInstant([self.field], time=1.0) + self.assertTrue(time_instant1 == time_instant2) + + def test_time_instant_eq_negative(self): + time_instant1 = TimeInstant(self.field, time=1.0) + time_instant2 = TimeInstant(Field("f", family="Real"), time=1.0) + self.assertFalse(time_instant1 == time_instant2) + + def test_time_instant_ne_positive(self): + time_instant1 = TimeInstant(self.field, time=1.0) + time_instant2 = TimeInstant(self.field, time=2.0) + self.assertTrue(time_instant1 != time_instant2) + + def test_time_instant_ne_negative(self): + time_instant1 = TimeInstant(self.field, time=1.0) + time_instant2 = TimeInstant(self.field, end_time=1.0) + self.assertFalse(time_instant1 != time_instant2) + + def test_time_partition_eq_interval_positive(self): + time_partition = TimePartition(1.0, 1, [0.5], [self.field]) + time_interval = TimeInterval(1.0, 0.5, self.field) + self.assertTrue(time_partition == time_interval) + + def test_time_partition_eq_interval_negative(self): + time_partition = TimePartition(1.0, 2, [0.5, 0.5], [self.field]) + time_interval = TimeInterval(1.0, 0.5, self.field) + self.assertFalse(time_partition == time_interval) + + def test_time_partition_ne_interval_positive(self): + time_partition = TimePartition(0.5, 1, [0.5], self.field) + time_interval = TimeInterval(1.0, 0.5, self.field) + self.assertTrue(time_partition != time_interval) + + def test_time_partition_ne_interval_negative(self): + time_partition = TimePartition(1.0, 1, 0.5, [self.field], start_time=0.0) + time_interval = TimeInterval(1.0, 0.5, self.field) + self.assertFalse(time_partition != time_interval) + + +class TestStringFormatting(BaseTestCase): """ Test that the :meth:`__str__`` and :meth:`__repr__`` methods work as intended for Goalie's time partition objects. """ - def setUp(self): - self.end_time = 1.0 - self.field_names = ["field"] - def get_time_partition(self, n): split = self.end_time / n timesteps = [split if i % 2 else split / 2 for i in range(n)] - return TimePartition(self.end_time, n, timesteps, self.field_names) + return TimePartition(self.end_time, n, timesteps, self.field_metadata_list) def test_time_partition1_str(self): expected = "[(0.0, 1.0)]" @@ -220,74 +241,83 @@ def test_time_partition4_str(self): def test_time_interval_str(self): expected = "[(0.0, 1.0)]" - time_interval = TimeInterval(self.end_time, [0.5], self.field_names) + time_interval = TimeInterval(self.end_time, [0.5], self.field_metadata_list) self.assertEqual(str(time_interval), expected) def test_time_instant_str(self): expected = "(1.0)" - time_instant = TimeInstant(self.field_names, time=self.end_time) + time_instant = TimeInstant(self.field_metadata_list, time=self.end_time) self.assertEqual(str(time_instant), expected) def test_time_partition1_repr(self): expected = ( "TimePartition(end_time=1.0, num_subintervals=1," - " timesteps=[0.5], field_names=['field'])" + " timesteps=[0.5], field_metadata=[Field('field', , solved_for=True, unsteady=True)])" ) self.assertEqual(repr(self.get_time_partition(1)), expected) def test_time_partition2_repr(self): expected = ( "TimePartition(end_time=1.0, num_subintervals=2," - " timesteps=[0.25, 0.5], field_names=['field'])" + " timesteps=[0.25, 0.5], field_metadata=[Field('field', , solved_for=True, unsteady=True)])" ) self.assertEqual(repr(self.get_time_partition(2)), expected) def test_time_partition4_repr(self): expected = ( "TimePartition(end_time=1.0, num_subintervals=4," - " timesteps=[0.125, 0.25, 0.125, 0.25], field_names=['field'])" + " timesteps=[0.125, 0.25, 0.125, 0.25], field_metadata=[Field('field'," + " , solved_for=True, unsteady=True)])" ) self.assertEqual(repr(self.get_time_partition(4)), expected) def test_time_interval_repr(self): - expected = "TimeInterval(end_time=1.0, timestep=0.5, field_names=['field'])" - time_interval = TimeInterval(self.end_time, [0.5], self.field_names) + expected = ( + "TimeInterval(end_time=1.0, timestep=0.5, field_metadata=[Field('field'," + " , solved_for=True, unsteady=True)])" + ) + time_interval = TimeInterval(self.end_time, [0.5], self.field_metadata_list) self.assertEqual(repr(time_interval), expected) def test_time_instant_repr(self): - expected = "TimeInstant(time=1.0, field_names=['field'])" - time_instant = TimeInstant(self.field_names, time=self.end_time) + expected = ( + "TimeInstant(time=1.0, field_metadata=[Field('field'," + " , solved_for=True, unsteady=True)])" + ) + time_instant = TimeInstant(self.field_metadata_list, time=self.end_time) self.assertEqual(repr(time_instant), expected) -class TestIndexing(unittest.TestCase): +class TestIndexing(BaseTestCase): r""" Unit tests for indexing :class:`~.TimePartition`\s. """ - def setUp(self): - self.end_time = 1.0 - self.field_names = ["field"] - def test_invalid_step(self): timesteps = [0.5, 0.25] - time_partition = TimePartition(self.end_time, 2, timesteps, self.field_names) + time_partition = TimePartition( + self.end_time, 2, timesteps, self.field_metadata_list + ) with self.assertRaises(NotImplementedError) as cm: time_partition[::2] msg = "Can only currently handle slices with step size 1." self.assertEqual(str(cm.exception), msg) def test_time_interval(self): - time_interval = TimeInterval(self.end_time, [0.5], self.field_names) + time_interval = TimeInterval(self.end_time, [0.5], self.field_metadata_list) self.assertEqual(time_interval, time_interval[0]) def test_time_instant(self): - time_instant = TimeInstant(self.field_names, time=self.end_time) + time_instant = TimeInstant(self.field_metadata_list, time=self.end_time) self.assertEqual(time_instant, time_instant[0]) def test_time_partition(self): timesteps = [0.5, 0.25] - time_partition = TimePartition(self.end_time, 2, timesteps, self.field_names) + time_partition = TimePartition( + self.end_time, 2, timesteps, self.field_metadata_list + ) tp0, tp1 = time_partition self.assertEqual(len(tp0), 1) self.assertEqual(len(tp1), 1) @@ -297,36 +327,34 @@ def test_time_partition(self): self.assertAlmostEqual(tp1.end_time, 1.0) self.assertAlmostEqual(tp0.timesteps[0], timesteps[0]) self.assertAlmostEqual(tp1.timesteps[0], timesteps[1]) - self.assertEqual(tp0.field_names, self.field_names) - self.assertEqual(tp0.field_names, tp1.field_names) + self.assertEqual(tp0.field_metadata, self.field_metadata_dict) + self.assertEqual(tp0.field_metadata, tp1.field_metadata) -class TestSlicing(unittest.TestCase): +class TestSlicing(BaseTestCase): r""" Unit tests for slicing :class:`~.TimePartition`\s. """ - def setUp(self): - self.end_time = 1.0 - self.field_names = ["field"] - def test_time_interval(self): - time_interval = TimeInterval(self.end_time, [0.5], self.field_names) + time_interval = TimeInterval(self.end_time, [0.5], self.field_metadata_list) self.assertEqual(time_interval, time_interval[0:1]) def test_time_instant(self): - time_instant = TimeInstant(self.field_names, time=self.end_time) + time_instant = TimeInstant(self.field_metadata_list, time=self.end_time) self.assertEqual(time_instant, time_instant[0:1]) def test_time_partition(self): timesteps = [0.5, 0.25] - time_partition = TimePartition(self.end_time, 2, timesteps, self.field_names) + time_partition = TimePartition( + self.end_time, 2, timesteps, self.field_metadata_list + ) self.assertEqual(time_partition, time_partition[0:2]) def test_time_partition_slice(self): end_time = 0.75 timesteps = [0.25, 0.05, 0.01] - time_partition = TimePartition(end_time, 3, timesteps, self.field_names) + time_partition = TimePartition(end_time, 3, timesteps, self.field_metadata_list) tp0 = time_partition[0] tp12 = time_partition[1:3] self.assertEqual(len(tp0), 1) @@ -337,5 +365,5 @@ def test_time_partition_slice(self): self.assertAlmostEqual(tp12.end_time, end_time) self.assertAlmostEqual(tp0.timesteps[0], timesteps[0]) self.assertAlmostEqual(tp12.timesteps, timesteps[1:3]) - self.assertEqual(tp0.field_names, self.field_names) - self.assertEqual(tp0.field_names, tp12.field_names) + self.assertEqual(tp0.field_metadata, self.field_metadata_dict) + self.assertEqual(tp0.field_metadata, tp12.field_metadata) diff --git a/test/utility.py b/test/utility.py index b4800b81..2b95f213 100644 --- a/test/utility.py +++ b/test/utility.py @@ -16,7 +16,7 @@ def uniform_mesh(dim, n, length=1, **kwargs): def uniform_metric(function_space, scaling): - dim = function_space.mesh().topological_dimension() + dim = function_space.mesh().topological_dimension metric = RiemannianMetric(function_space) metric.interpolate(scaling * ufl.Identity(dim)) return metric