diff --git a/.github/workflows/ci-build.yml b/.github/workflows/ci-build.yml index 94385938..38fce0da 100644 --- a/.github/workflows/ci-build.yml +++ b/.github/workflows/ci-build.yml @@ -36,18 +36,18 @@ jobs: pip install -e ".[sparse,test]" python -m pytest -n auto optimism --cov=optimism -Wignore # we can also add the flag -n auto for parallel testing - - name: docs - run: | - pip install -e ".[docs,sparse,test]" - cd docs - sphinx-apidoc -o source/ ../optimism -P - make html - - name: Deploy to GitHub Pages - uses: peaceiris/actions-gh-pages@v3 - with: - github_token: ${{ secrets.GITHUB_TOKEN }} - publish_dir: docs/build/html # Adjust this if your output directory is different - publish_branch: gh-pages # The branch to deploy to + # - name: docs + # run: | + # pip install -e ".[docs,sparse,test]" + # cd docs + # sphinx-apidoc -o source/ ../optimism -P + # make html + # - name: Deploy to GitHub Pages + # uses: peaceiris/actions-gh-pages@v3 + # with: + # github_token: ${{ secrets.GITHUB_TOKEN }} + # publish_dir: docs/build/html # Adjust this if your output directory is different + # publish_branch: gh-pages # The branch to deploy to - name: codecov uses: codecov/codecov-action@v4 with: diff --git a/optimism/FieldOperator.py b/optimism/FieldOperator.py new file mode 100644 index 00000000..7545e45b --- /dev/null +++ b/optimism/FieldOperator.py @@ -0,0 +1,315 @@ +from __future__ import annotations +import abc +from collections.abc import Callable +import dataclasses + +import jax +import jax.numpy as np +from jax import Array +from jaxtyping import Int, Float + +from optimism import Interpolants +from optimism import Mesh +from optimism import QuadratureRule + +class Field(abc.ABC): + @abc.abstractmethod + def interpolate(self, shape, U, el_id): + pass + + @abc.abstractmethod + def interpolate_gradient(self, shape, U, el_id): + pass + + @property + @abc.abstractmethod + def quadpoint_axis(self): + pass + + @abc.abstractmethod + def compute_shape_functions(self, points): + pass + + @abc.abstractmethod + def map_shape_functions(self, shapes, jacs): + return shapes + + +class PkField(Field): + """Standard Lagrange polynomial finite element fields.""" + quadpoint_axis = 0 + + def __init__(self, k: int, mesh: Mesh.Mesh) -> None: + assert k > 0, "Polynomial degree must be positive" + # temporarily restrict to first order meshes. + # Building the connectivity table for higher-order fields requires knowing the + # simplex connectivity. The mesh object must be updated to store that. + assert mesh.parentElement.degree == 1 + self.order = k + self.element, self.element1d = Interpolants.make_parent_elements(k) + self.mesh = mesh + self.conns, self.coords = self._make_connectivity_and_coordinates() + + def interpolate(self, shape, field, el_id): + e_vector = field[self.conns[el_id]] + return shape.values@e_vector + + def interpolate_gradient(self, shape, field, el_id): + return jax.vmap(lambda ev, dshp: np.tensordot(ev, dshp, (0, 0)), (None, 0))(field[self.conns[el_id]], shape.gradients) + + def compute_shape_functions(self, points): + return Interpolants.compute_shapes(self.element, points) + + def map_shape_functions(self, shapes, jacs): + shape_grads = jax.vmap(lambda dshp, J: dshp@np.linalg.inv(J))(shapes.gradients, jacs) + return Interpolants.ShapeFunctions(shapes.values, shape_grads) + + def _make_connectivity_and_coordinates(self): + if self.order == 1: + return self.mesh.conns, self.mesh.coords + + conns = np.zeros((Mesh.num_elements(self.mesh), self.element.coordinates.shape[0]), dtype=int) + + # step 1/3: vertex nodes + conns = conns.at[:, self.element.vertexNodes].set(self.mesh.conns) + nodeOrdinalOffset = self.mesh.conns.max() + 1 # offset for later node numbering + + # The non-vertex nodes are placed using linear interpolation. When we add meshes that + # have higher-order coordinate spaces, we must remember to update this part with + # the higher-order interpolation. + + # step 2/3: mid-edge nodes (excluding vertices) + edgeConns, edges = Mesh.create_edges(self.mesh.conns) + A = np.column_stack((1.0 - self.element1d.coordinates[self.element1d.interiorNodes], + self.element1d.coordinates[self.element1d.interiorNodes])) + edgeCoords = jax.vmap(lambda edgeConn: np.dot(A, self.mesh.coords[edgeConn, :]))(edgeConns) + + nNodesPerEdge = self.element1d.interiorNodes.size + for e, edge in enumerate(edges): + edgeNodeOrdinals = nodeOrdinalOffset + np.arange(e*nNodesPerEdge,(e+1)*nNodesPerEdge) + + elemLeft = edge[0] + sideLeft = edge[1] + edgeMasterNodes = self.element.faceNodes[sideLeft][self.element1d.interiorNodes] + conns = conns.at[elemLeft, edgeMasterNodes].set(edgeNodeOrdinals) + + elemRight = edge[2] + if elemRight >= 0: + sideRight = edge[3] + edgeMasterNodes = self.element.faceNodes[sideRight][self.element1d.interiorNodes] + conns = conns.at[elemRight, edgeMasterNodes].set(np.flip(edgeNodeOrdinals)) + + nEdges = edges.shape[0] + nodeOrdinalOffset += nEdges*nNodesPerEdge # for offset of interior node numbering + + # step 3/3: interior nodes + nInNodesPerTri = self.element.interiorNodes.shape[0] + if nInNodesPerTri > 0: + N0 = self.element.coordinates[self.element.interiorNodes, 0] + N1 = self.element.coordinates[self.element.interiorNodes, 1] + N2 = 1.0 - N0 - N1 + A = np.column_stack((N0, N1, N2)) + interiorCoords = jax.vmap(lambda triConn: np.dot(A, self.mesh.coords[triConn]))(self.mesh.conns) + + def add_element_interior_nodes(conn, newNodeOrdinals): + return conn.at[self.element.interiorNodes].set(newNodeOrdinals) + + nTri = conns.shape[0] + newNodeOrdinals = np.arange(nTri*nInNodesPerTri).reshape(nTri,nInNodesPerTri) \ + + nodeOrdinalOffset + + conns = jax.vmap(add_element_interior_nodes)(conns, newNodeOrdinals) + else: + interiorCoords = np.zeros((0, 2)) + + coords = np.vstack((self.mesh.coords, edgeCoords.reshape(-1,2), interiorCoords.reshape(-1,2))) + return conns, coords + + +class DG_PkField(Field): + quadpoint_axis = 0 + + def __init__(self, k: int, mesh: Mesh.Mesh) -> None: + assert k >= 0, "Polynomial degree must be >= 0" + assert mesh.parentElement.degree == 1 + self.order = k + self.element, self.element1d = Interpolants.make_parent_elements(k) + self.mesh = mesh + self.conns, self.coords = self._make_connectivity_and_coordinates() + self.field_shape = Mesh.num_elements(mesh), self.element.num_nodes + + def _make_connectivity_and_coordinates(self): + nnodes = Mesh.num_elements(self.mesh)*self.element.num_nodes + conns = np.arange(nnodes).reshape(-1, self.element.num_nodes) + + def set_elem_coords(simplex_element_conn): + Xs = self.mesh.coords[simplex_element_conn] + J = np.column_stack((Xs[0] - Xs[2], Xs[1] - Xs[2])) + Jxi = self.element.coordinates@J.T + b = Xs[0] - Jxi[0] + return self.element.coordinates@J.T + b + + coords = jax.vmap(set_elem_coords)(self.mesh.conns) + return conns, coords + + def compute_shape_functions(self, points): + return Interpolants.compute_shapes(self.element, points) + + def interpolate(self, shape, U, el_id): + e_vector = U[el_id] + return shape.values@e_vector + + def interpolate_gradient(self, shape, U, el_id): + return jax.vmap(lambda ev, dshp: np.tensordot(ev, dshp, (0, 0)), (None, 0))(U[el_id], shape.gradients) + + def map_shape_functions(self, shapes, jacs): + shape_grads = jax.vmap(lambda dshp, J: dshp@np.linalg.inv(J))(shapes.gradients, jacs) + return Interpolants.ShapeFunctions(shapes.values, shape_grads) + + +class UniformField(Field): + """A unique value for the whole mesh (things like time).""" + quadpoint_axis = None + + def interpolate(self, shape, U, el_id): + return U + + def interpolate_gradient(self, shape, U, el_id): + raise NotImplementedError(f"Gradients not supported for {type(self).__name__}") + + def compute_shape_functions(self, points): + return Interpolants.ShapeFunctions(np.array([]), np.array([])) + + def map_shape_functions(self, shapes, jacs): + return super().map_shape_functions(shapes, jacs) + + +class QuadratureField(Field): + """Arrays defined directly at quadrature points (things like internal variables).""" + quadpoint_axis = 0 + + def interpolate(self, shape, field, el_id): + return field[el_id] + + def interpolate_gradient(self, shape, field, el_id): + raise NotImplementedError(f"Gradients not supported for {type(self).__name__}") + + def compute_shape_functions(self, points): + return Interpolants.ShapeFunctions(np.array([]), np.array([])) + + def map_shape_functions(self, shapes, jacs): + return super().map_shape_functions(shapes, jacs) + + +@dataclasses.dataclass +class FieldInterpolation: + """Abstract base class for specific types of field interpolations. + + This class should not be instantiated, only derived from. + """ + field: int + + +class Value(FieldInterpolation): + """Sentinel to indicate you want to interpolate the value of the field.""" + pass + + +class Gradient(FieldInterpolation): + """Sentinel to indicate you want to interpolate the gradient of the field.""" + pass + + +def _choose_interpolation_function(input, spaces): + """Helper function to choose correct field space interpolation function for each input.""" + if input.field >= len(spaces): + raise IndexError(f"Field space {input.field} exceeds range, which is {len(spaces) - 1}.") + if type(input) is Value: + return spaces[input.field].interpolate + elif type(input) is Gradient: + return spaces[input.field].interpolate_gradient + else: + raise TypeError("Type of object in qfunction signature is invalid.") + + +class FieldOperator: + def __init__(self, input_spaces: tuple[Field, ...], qfunction_signature: tuple[FieldInterpolation, ...], + mesh: Mesh.Mesh, quadrature_rule: QuadratureRule.QuadratureRule) -> None: + """Entity that can evaluate and integrate functions at points in a mesh. + + Args: + spaces: Collection of ``Field`` objects describing the inputs + qfunction_signature: Tells the ``FieldOperator`` what quantitites to interpolate to the quadrature points. + mesh: Finite mesh over which to evaluate/integrate + quad_rule: Quadrature rule to define evaluation points and integration weights + """ + for input in qfunction_signature: + assert 0 <= input.field < len(input_spaces), """Field index in qfunction signature outside valid range.""" + # the coord space should live on the Mesh eventually + self._coord_space = PkField(mesh.parentElement.degree, mesh) + self._coord_shapes = self._coord_space.compute_shape_functions(quadrature_rule.xigauss) + + self._spaces = input_spaces + self._mesh = mesh + self._quadrature_rule = quadrature_rule + self._shapes = [space.compute_shape_functions(quadrature_rule.xigauss) for space in input_spaces] + self._input_fields = tuple(input.field for input in qfunction_signature) + self._interpolators = tuple(_choose_interpolation_function(input, self._spaces) for input in qfunction_signature) + + def evaluate(self, f: Callable, coords: Float[Array, "nnode ndim"], + block: Int[Array, "nelem"], *fields: Array | tuple[Array, ...]) -> Array: + """Evaluate a function at quadrature points. + + Args: + f: Integrand function. + coords: Spatial coordinates of the mesh, used for parametric mapping + of gradients. Can be different values than given in the + constructor, but the shape must be the same. + block: Element ids identifying the domain of integration. + *fields: Input fields to the functional. Must match the + specifications of the ``inputs`` agrument to the constructor. For + performance reasons, this is not checked. + + Returns: + A ``QuadratureField`` with the value of ``f`` at every quadrature point in the mesh. + """ + f_vmap_axis = None + compute_values = jax.vmap(self._evaluate_on_element, (f_vmap_axis, 0, None) + tuple(None for field in fields)) + return compute_values(f, block, coords, *fields) + + def integrate(self, f: Callable, coords: Float[Array, "nnode ndim"], + block: Int[Array, "nelem"], *fields: Array | tuple[Array, ...]) -> Array: + """Integrate a function over a mesh block. + + Args: + f: Integrand function. + coords: Spatial coordinates of the mesh, used for parametric mapping + of gradients. Can be different values than given in the + constructor, but the shape must be the same. + block: Element ids identifying the domain of integration. + *fields: Input fields to the functional. Must match the + specifications of the ``inputs`` agrument to the constructor. For + performance reasons, this is not checked. + + Returns: + The integral of ``f`` over the domain. + """ + f_vmap_axis = None + integrate = jax.vmap(self._integrate_over_element, (f_vmap_axis, 0, None) + tuple(None for field in fields)) + return np.sum(integrate(f, block, coords, *fields)) + + def _evaluate_on_element(self, f, el_id, coords, *fields): + jacs = self._coord_space.interpolate_gradient(self._coord_shapes, coords, el_id) + shapes = [space.map_shape_functions(shape, jacs) for (space, shape) in zip(self._spaces, self._shapes)] + f_args = [interp(shapes[field_id], fields[field_id], el_id) for (interp, field_id) in zip(self._interpolators, self._input_fields)] + f_batch = jax.vmap(f, tuple(self._spaces[input].quadpoint_axis for input in self._input_fields)) + return f_batch(*f_args) + + def _integrate_over_element(self, f, el_id, coords, *fields): + jacs = self._coord_space.interpolate_gradient(self._coord_shapes, coords, el_id) + shapes = [space.map_shape_functions(shape, jacs) for (space, shape) in zip(self._spaces, self._shapes)] + dVs = jax.vmap(lambda J, w: np.linalg.det(J)*w)(jacs, self._quadrature_rule.wgauss) + f_args = [interp(shapes[field_id], fields[field_id], el_id) for (interp, field_id) in zip(self._interpolators, self._input_fields)] + f_batch = jax.vmap(f, tuple(self._spaces[input].quadpoint_axis for input in self._input_fields)) + f_vals = f_batch(*f_args) + return np.dot(f_vals, dVs) diff --git a/optimism/test/test_FieldOperator.py b/optimism/test/test_FieldOperator.py new file mode 100644 index 00000000..75f5edec --- /dev/null +++ b/optimism/test/test_FieldOperator.py @@ -0,0 +1,207 @@ +import pytest +import jax.numpy as np + +from optimism import Mesh +from optimism.FieldOperator import * + +class TestFieldOperator: + coord_degree = 1 + dim = 2 + length = 3.0 + height = 2.0 + mesh = Mesh.construct_structured_mesh(2, 2, [0.0, length], [0.0, height], coord_degree) + quad_rule = QuadratureRule.create_quadrature_rule_on_triangle(4) + + def test_gradient_evaluation(self): + "Check the gradient of an affine field" + k = 1 + spaces = PkField(k, self.mesh), + integrand_signature = Gradient(0), + field_operator = FieldOperator(spaces, integrand_signature, self.mesh, self.quad_rule) + + target_disp_grad = np.array([[0.1, 0.01], + [0.05, 0.3]]) + U = np.einsum('aj, ij', self.mesh.coords, target_disp_grad + np.identity(2)) - self.mesh.coords + + def f(dudX): + return dudX + + disp_grads = field_operator.evaluate(f, self.mesh.coords, self.mesh.blocks['block_0'], U) + + for H in disp_grads.reshape(-1, 2, 2): + assert pytest.approx(H) == target_disp_grad + + def test_trivial_integral(self): + spaces = PkField(self.coord_degree, self.mesh), + integrand_signature = Value(0), + field_operator = FieldOperator(spaces, integrand_signature, self.mesh, self.quad_rule) + U = np.zeros_like(self.mesh.coords) + def f(u): + return 1.0 + area = field_operator.integrate(f, self.mesh.coords, self.mesh.blocks['block_0'], U) + assert pytest.approx(area) == self.length*self.height + + def test_integral_with_one_nodal_field(self): + "Computes area in a non-trivial way, checking consistency of gradient interpolation." + integrand_signature = PkField(self.coord_degree, self.mesh), + POSITION = 0 + # We're taking the gradient of position, which is just the identity tensor + inputs = Gradient(POSITION), + field_operator = FieldOperator(integrand_signature, inputs, self.mesh, self.quad_rule) + + def f(dXdX): + # note dXdX == identity, so + # trace(dXdX)/dim = 1 + return np.trace(dXdX)/self.dim + + area = field_operator.integrate(f, self.mesh.coords, self.mesh.blocks['block_0'], self.mesh.coords) + assert pytest.approx(area) == self.length*self.height + + def test_helmholtz(self): + "Tests interpolation, gradient, and simple use of a QuadratureField" + spaces = PkField(2, self.mesh), QuadratureField() + + integrand_signature = Value(0), Gradient(0), Value(1) + field_operator = FieldOperator(spaces, integrand_signature, self.mesh, self.quad_rule) + + def f(u, dudX, q): + return 0.5*q[0]*(u*u + np.dot(dudX, dudX)) + + # u(X, Y) = 0.1*X + 0.01*Y + 2 + target_grad = np.array([0.1, 0.01]) + U = spaces[0].coords@target_grad + 2.0 + + Q = 2*np.ones((Mesh.num_elements(self.mesh), len(self.quad_rule), 1)) + + energy = field_operator.integrate(f, self.mesh.coords, self.mesh.blocks['block_0'], U, Q) + print(f"{energy:.12e}") + assert energy == pytest.approx(28.0994) + + def test_nonexistent_field_id_gets_error(self): + spaces = PkField(self.coord_degree, self.mesh), + integrand_signature = Gradient(1), # there is no field 1 + with pytest.raises(AssertionError): + field_operator = FieldOperator(spaces, integrand_signature, self.mesh, self.quad_rule) + + def test_jit_and_grad(self): + k = 2 + spaces = PkField(k, self.mesh), + integrand_signature = Gradient(0), + field_operator = FieldOperator(spaces, integrand_signature, self.mesh, self.quad_rule) + + def f(dudX): + return 0.5*np.dot(dudX, dudX) + + @jax.jit + def energy(U): + return field_operator.integrate(f, self.mesh.coords, self.mesh.blocks['block_0'], U) + + target_grad = np.array([0.1, 0.01]) + V = spaces[0].coords@target_grad + + e = energy(V) + assert e == pytest.approx(0.5*np.dot(target_grad, target_grad)*self.length*self.height) + + force = jax.jit(jax.grad(energy)) + F = force(V) + assert sum(F) == pytest.approx(0.0) + + +class TestDGField: + quad_rule = QuadratureRule.create_quadrature_rule_on_triangle(1) + mesh = Mesh.construct_structured_mesh(3, 3, [0.0, 2.0], [0.0, 1.0]) + + def test_dg_interpolation(self): + k = 1 + space = DG_PkField(k, self.mesh) + + # make a DG field that is u = 0.1*x on x < 1, u = 2 on x > 1 + def field_values(el_coords): + centroid = np.mean(el_coords, axis=0) + return np.where(centroid[0] < 1.0, 0.1*el_coords[:, 0], 2.0*np.ones_like(el_coords[:, 0])) + + U = jax.vmap(field_values)(space.coords) + + coord_space = PkField(1, self.mesh) + field_operator = FieldOperator((space, coord_space), (Value(0), Value(1)), self.mesh, self.quad_rule) + def f(u, x): + return u, x + + # elements in left-hand side are [0:4] -> u = 0.1*x + u_q, x_q = field_operator.evaluate(f, self.mesh.coords, np.arange(4), U, self.mesh.coords) + Uq_expected = 0.1*x_q[..., 0] + assert u_q == pytest.approx(Uq_expected) + + # elements in right-hand side [4:7] u = 2.0 + u_q, x_q = field_operator.evaluate(f, self.mesh.coords, np.arange(4, 8), U, self.mesh.coords) + assert u_q == pytest.approx(2.0) + + def test_dg_gradient_interpolation(self): + k = 1 + self.mesh = Mesh.construct_structured_mesh(3, 3, [0.0, 2.0], [0.0, 1.0]) + space = DG_PkField(k, self.mesh) + + # make a DG field that is u = 0.1*x on x < 1, u = 2 on x > 1 + def field_values(el_coords): + centroid = np.mean(el_coords, axis=0) + return np.where(centroid[0] < 1.0, 0.1*el_coords[:, 0], 2.0*np.ones_like(el_coords[:, 0])) + + U = jax.vmap(field_values)(space.coords) + + coord_space = PkField(1, self.mesh) + field_operator = FieldOperator((space, coord_space), (Gradient(0), Value(1)), self.mesh, self.quad_rule) + def f(dudX, x): + return dudX, x + + # elements in left-hand side are [0:4] -> grad u = [0.1, 0.0] + dudX_q, _ = field_operator.evaluate(f, self.mesh.coords, np.arange(4), U, self.mesh.coords) + for dudX in dudX_q.reshape(-1, 2): + assert dudX == pytest.approx(np.array([0.1, 0.0])) + + # elements in right-hand side [4:7] -> grad u = 0 + dudX_q, _ = field_operator.evaluate(f, self.mesh.coords, np.arange(4, 8), U, self.mesh.coords) + assert dudX_q == pytest.approx(0.0) + +def test_parameterized_elasticity(): + mesh = Mesh.construct_structured_mesh(5, 3, [0.0, 1.0], [0.0, 1.0]) + ne = Mesh.num_elements(mesh) + blocks = {'all': np.arange(Mesh.num_elements(mesh)), + 'left': np.arange(ne//2), + 'right': np.arange(ne//2, ne)} + quad_rule = QuadratureRule.create_quadrature_rule_on_triangle(2) + spaces = PkField(1, mesh), DG_PkField(0, mesh) + integrand_signature = Gradient(0), Value(1) + field_operator = FieldOperator(spaces, integrand_signature, mesh, quad_rule) + + lam = 3.0 + + def f(dudX, mu): + strain = 0.5*(dudX + dudX.T) + return mu*np.tensordot(strain, strain) + 0.5*lam*np.trace(strain)**2 + + target_disp_grad = np.array([[0.1, 0.01], + [0.05, 0.3]]) + U = np.einsum('aj, ij', mesh.coords, target_disp_grad + np.identity(2)) - mesh.coords + + mu_left = 1.0 + mu_right = 2.0 + mu = np.zeros(spaces[1].field_shape) + mu = mu.at[blocks['left']].set(mu_left) + mu = mu.at[blocks['right']].set(mu_right) + + def energy(U, mu): + return field_operator.integrate(f, mesh.coords, blocks['all'], U, mu) + + R = jax.grad(energy, 0)(U, mu) + assert R[6] == pytest.approx(np.zeros(2)) + assert R[8] == pytest.approx(np.zeros(2)) + + + stresses = field_operator.evaluate(jax.grad(f, 0), mesh.coords, blocks['all'], U, mu) + strain = 0.5*(target_disp_grad + target_disp_grad.T) + stress_left_exact = 2.0*mu_left*strain + lam*np.trace(strain)*np.identity(2) + stress_right_exact = 2.0*mu_right*strain + lam*np.trace(strain)*np.identity(2) + for stress in stresses[blocks['left']].reshape(-1, 2, 2): + assert stress == pytest.approx(stress_left_exact) + for stress in stresses[blocks['right']].reshape(-1, 2, 2): + assert stress == pytest.approx(stress_right_exact)