Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
17e75c4
WIP Implement integration/evaluation of functions with arbitrary numb…
btalamini Jun 24, 2025
4d4f263
Fix bug, output size is at least correct
btalamini Jun 24, 2025
9b4c737
WIP Allow gradients of fields in kernel function
btalamini Jun 25, 2025
2a0c004
Allow user to specify signature of q-function in terms of values and …
btalamini Jun 25, 2025
b23f483
Put missing required property (quad point vmap axis) in abstract base…
btalamini Jun 25, 2025
f1d85a7
Add support for quadrature weight-like fields (quad point data on par…
btalamini Jun 26, 2025
dddaa33
Tidy and give more appropriate names
btalamini Jun 26, 2025
cd22ac0
Put in more input validation in constructor
btalamini Jun 27, 2025
0d4d8ac
Implement field evaluation in another way that allows transforming th…
btalamini Jul 6, 2025
b5d40d4
Implement first unit tests
btalamini Jul 6, 2025
e135b03
Implement uniform fields and quadrature fields
btalamini Jul 6, 2025
5617359
Remove temperory tests that were located directly in module
btalamini Jul 6, 2025
1c7631a
Imrpove clarity of tests
btalamini Jul 6, 2025
8f1581b
Allow different order fields by storing a connectivity table for each…
btalamini Jul 16, 2025
a8704c4
Make tests pass with recent change to Pk space (holds mesh)
btalamini Aug 5, 2025
1d1b742
Prevent (for now) use of Pk space with higher order meshes, until mes…
btalamini Aug 5, 2025
733e348
Remove field dimension from Field classes, they don't need to know this
btalamini Aug 7, 2025
75ff055
Add a test with a quadrature field
btalamini Aug 7, 2025
d1f6779
Test helmholtz against analytical solution
btalamini Aug 7, 2025
d1195d2
Change how loop over elements is implemented so that you can use fiel…
btalamini Aug 8, 2025
cd51e32
Add coordinates to all spaces, get integrals with higher order fields…
btalamini Aug 10, 2025
ac2cb84
Add comments and rename things for clarity
btalamini Aug 11, 2025
846e1aa
Add more type hints and input error checking
btalamini Aug 11, 2025
83ecaa4
Remove earlier attempt
btalamini Aug 11, 2025
5b90259
Add test that code can jit and support jax gradients
btalamini Aug 11, 2025
e9cdb74
Specify block of elements to integrate over
btalamini Aug 11, 2025
9ccfeb4
write docstrigns
btalamini Aug 11, 2025
8d1f9a8
Remove unused imports
btalamini Aug 11, 2025
d109cd9
Add more type annotations
btalamini Aug 11, 2025
bf528d8
Rename variables in tests for increased clarity
btalamini Aug 11, 2025
7518a87
Rename class to FieldOperator
btalamini Aug 11, 2025
8435161
Rename test class consistently
btalamini Aug 11, 2025
3589dbf
Add DG field with one test
btalamini Aug 12, 2025
e4a3482
Test interpolation of gradient for DG space
btalamini Aug 13, 2025
965f55c
Test DG P0 field
btalamini Aug 13, 2025
74d1413
Make type annotations work with older versions of Python with a futur…
btalamini Aug 22, 2025
2c7bc6c
Merge branch 'main' into new_function_space
btalamini Aug 22, 2025
cca6878
Update ci-build.yml
cmhamel Aug 28, 2025
acc95ff
Merge branch 'main' into new_function_space
btalamini Oct 16, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 12 additions & 12 deletions .github/workflows/ci-build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,18 +36,18 @@ jobs:
pip install -e ".[sparse,test]"
python -m pytest -n auto optimism --cov=optimism -Wignore
# we can also add the flag -n auto for parallel testing
- name: docs
run: |
pip install -e ".[docs,sparse,test]"
cd docs
sphinx-apidoc -o source/ ../optimism -P
make html
- name: Deploy to GitHub Pages
uses: peaceiris/actions-gh-pages@v3
with:
github_token: ${{ secrets.GITHUB_TOKEN }}
publish_dir: docs/build/html # Adjust this if your output directory is different
publish_branch: gh-pages # The branch to deploy to
# - name: docs
# run: |
# pip install -e ".[docs,sparse,test]"
# cd docs
# sphinx-apidoc -o source/ ../optimism -P
# make html
# - name: Deploy to GitHub Pages
# uses: peaceiris/actions-gh-pages@v3
# with:
# github_token: ${{ secrets.GITHUB_TOKEN }}
# publish_dir: docs/build/html # Adjust this if your output directory is different
# publish_branch: gh-pages # The branch to deploy to
- name: codecov
uses: codecov/codecov-action@v4
with:
Expand Down
315 changes: 315 additions & 0 deletions optimism/FieldOperator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,315 @@
from __future__ import annotations
import abc
from collections.abc import Callable
import dataclasses

import jax
import jax.numpy as np
from jax import Array
from jaxtyping import Int, Float

from optimism import Interpolants
from optimism import Mesh
from optimism import QuadratureRule

class Field(abc.ABC):
@abc.abstractmethod
def interpolate(self, shape, U, el_id):
pass

@abc.abstractmethod
def interpolate_gradient(self, shape, U, el_id):
pass

@property
@abc.abstractmethod
def quadpoint_axis(self):
pass

@abc.abstractmethod
def compute_shape_functions(self, points):
pass

@abc.abstractmethod
def map_shape_functions(self, shapes, jacs):
return shapes


class PkField(Field):
"""Standard Lagrange polynomial finite element fields."""
quadpoint_axis = 0

def __init__(self, k: int, mesh: Mesh.Mesh) -> None:
assert k > 0, "Polynomial degree must be positive"
# temporarily restrict to first order meshes.
# Building the connectivity table for higher-order fields requires knowing the
# simplex connectivity. The mesh object must be updated to store that.
assert mesh.parentElement.degree == 1
self.order = k
self.element, self.element1d = Interpolants.make_parent_elements(k)
self.mesh = mesh
self.conns, self.coords = self._make_connectivity_and_coordinates()

def interpolate(self, shape, field, el_id):
e_vector = field[self.conns[el_id]]
return shape.values@e_vector

def interpolate_gradient(self, shape, field, el_id):
return jax.vmap(lambda ev, dshp: np.tensordot(ev, dshp, (0, 0)), (None, 0))(field[self.conns[el_id]], shape.gradients)

def compute_shape_functions(self, points):
return Interpolants.compute_shapes(self.element, points)

def map_shape_functions(self, shapes, jacs):
shape_grads = jax.vmap(lambda dshp, J: [email protected](J))(shapes.gradients, jacs)
return Interpolants.ShapeFunctions(shapes.values, shape_grads)

def _make_connectivity_and_coordinates(self):
if self.order == 1:
return self.mesh.conns, self.mesh.coords

conns = np.zeros((Mesh.num_elements(self.mesh), self.element.coordinates.shape[0]), dtype=int)

# step 1/3: vertex nodes
conns = conns.at[:, self.element.vertexNodes].set(self.mesh.conns)
nodeOrdinalOffset = self.mesh.conns.max() + 1 # offset for later node numbering

# The non-vertex nodes are placed using linear interpolation. When we add meshes that
# have higher-order coordinate spaces, we must remember to update this part with
# the higher-order interpolation.

# step 2/3: mid-edge nodes (excluding vertices)
edgeConns, edges = Mesh.create_edges(self.mesh.conns)
A = np.column_stack((1.0 - self.element1d.coordinates[self.element1d.interiorNodes],
self.element1d.coordinates[self.element1d.interiorNodes]))
edgeCoords = jax.vmap(lambda edgeConn: np.dot(A, self.mesh.coords[edgeConn, :]))(edgeConns)

nNodesPerEdge = self.element1d.interiorNodes.size
for e, edge in enumerate(edges):
edgeNodeOrdinals = nodeOrdinalOffset + np.arange(e*nNodesPerEdge,(e+1)*nNodesPerEdge)

elemLeft = edge[0]
sideLeft = edge[1]
edgeMasterNodes = self.element.faceNodes[sideLeft][self.element1d.interiorNodes]
conns = conns.at[elemLeft, edgeMasterNodes].set(edgeNodeOrdinals)

elemRight = edge[2]
if elemRight >= 0:
sideRight = edge[3]
edgeMasterNodes = self.element.faceNodes[sideRight][self.element1d.interiorNodes]
conns = conns.at[elemRight, edgeMasterNodes].set(np.flip(edgeNodeOrdinals))

nEdges = edges.shape[0]
nodeOrdinalOffset += nEdges*nNodesPerEdge # for offset of interior node numbering

# step 3/3: interior nodes
nInNodesPerTri = self.element.interiorNodes.shape[0]
if nInNodesPerTri > 0:
N0 = self.element.coordinates[self.element.interiorNodes, 0]
N1 = self.element.coordinates[self.element.interiorNodes, 1]
N2 = 1.0 - N0 - N1
A = np.column_stack((N0, N1, N2))
interiorCoords = jax.vmap(lambda triConn: np.dot(A, self.mesh.coords[triConn]))(self.mesh.conns)

def add_element_interior_nodes(conn, newNodeOrdinals):
return conn.at[self.element.interiorNodes].set(newNodeOrdinals)

nTri = conns.shape[0]
newNodeOrdinals = np.arange(nTri*nInNodesPerTri).reshape(nTri,nInNodesPerTri) \
+ nodeOrdinalOffset

conns = jax.vmap(add_element_interior_nodes)(conns, newNodeOrdinals)
else:
interiorCoords = np.zeros((0, 2))

coords = np.vstack((self.mesh.coords, edgeCoords.reshape(-1,2), interiorCoords.reshape(-1,2)))
return conns, coords


class DG_PkField(Field):
quadpoint_axis = 0

def __init__(self, k: int, mesh: Mesh.Mesh) -> None:
assert k >= 0, "Polynomial degree must be >= 0"
assert mesh.parentElement.degree == 1
self.order = k
self.element, self.element1d = Interpolants.make_parent_elements(k)
self.mesh = mesh
self.conns, self.coords = self._make_connectivity_and_coordinates()
self.field_shape = Mesh.num_elements(mesh), self.element.num_nodes

def _make_connectivity_and_coordinates(self):
nnodes = Mesh.num_elements(self.mesh)*self.element.num_nodes
conns = np.arange(nnodes).reshape(-1, self.element.num_nodes)

def set_elem_coords(simplex_element_conn):
Xs = self.mesh.coords[simplex_element_conn]
J = np.column_stack((Xs[0] - Xs[2], Xs[1] - Xs[2]))
Jxi = [email protected]
b = Xs[0] - Jxi[0]
return [email protected] + b

coords = jax.vmap(set_elem_coords)(self.mesh.conns)
return conns, coords

def compute_shape_functions(self, points):
return Interpolants.compute_shapes(self.element, points)

def interpolate(self, shape, U, el_id):
e_vector = U[el_id]
return shape.values@e_vector

def interpolate_gradient(self, shape, U, el_id):
return jax.vmap(lambda ev, dshp: np.tensordot(ev, dshp, (0, 0)), (None, 0))(U[el_id], shape.gradients)

def map_shape_functions(self, shapes, jacs):
shape_grads = jax.vmap(lambda dshp, J: [email protected](J))(shapes.gradients, jacs)
return Interpolants.ShapeFunctions(shapes.values, shape_grads)


class UniformField(Field):
"""A unique value for the whole mesh (things like time)."""
quadpoint_axis = None

def interpolate(self, shape, U, el_id):
return U

def interpolate_gradient(self, shape, U, el_id):
raise NotImplementedError(f"Gradients not supported for {type(self).__name__}")

def compute_shape_functions(self, points):
return Interpolants.ShapeFunctions(np.array([]), np.array([]))

def map_shape_functions(self, shapes, jacs):
return super().map_shape_functions(shapes, jacs)


class QuadratureField(Field):
"""Arrays defined directly at quadrature points (things like internal variables)."""
quadpoint_axis = 0

def interpolate(self, shape, field, el_id):
return field[el_id]

def interpolate_gradient(self, shape, field, el_id):
raise NotImplementedError(f"Gradients not supported for {type(self).__name__}")

def compute_shape_functions(self, points):
return Interpolants.ShapeFunctions(np.array([]), np.array([]))

def map_shape_functions(self, shapes, jacs):
return super().map_shape_functions(shapes, jacs)


@dataclasses.dataclass
class FieldInterpolation:
"""Abstract base class for specific types of field interpolations.

This class should not be instantiated, only derived from.
"""
field: int


class Value(FieldInterpolation):
"""Sentinel to indicate you want to interpolate the value of the field."""
pass


class Gradient(FieldInterpolation):
"""Sentinel to indicate you want to interpolate the gradient of the field."""
pass


def _choose_interpolation_function(input, spaces):
"""Helper function to choose correct field space interpolation function for each input."""
if input.field >= len(spaces):
raise IndexError(f"Field space {input.field} exceeds range, which is {len(spaces) - 1}.")
if type(input) is Value:
return spaces[input.field].interpolate
elif type(input) is Gradient:
return spaces[input.field].interpolate_gradient
else:
raise TypeError("Type of object in qfunction signature is invalid.")


class FieldOperator:
def __init__(self, input_spaces: tuple[Field, ...], qfunction_signature: tuple[FieldInterpolation, ...],
mesh: Mesh.Mesh, quadrature_rule: QuadratureRule.QuadratureRule) -> None:
"""Entity that can evaluate and integrate functions at points in a mesh.

Args:
spaces: Collection of ``Field`` objects describing the inputs
qfunction_signature: Tells the ``FieldOperator`` what quantitites to interpolate to the quadrature points.
mesh: Finite mesh over which to evaluate/integrate
quad_rule: Quadrature rule to define evaluation points and integration weights
"""
for input in qfunction_signature:
assert 0 <= input.field < len(input_spaces), """Field index in qfunction signature outside valid range."""
# the coord space should live on the Mesh eventually
self._coord_space = PkField(mesh.parentElement.degree, mesh)
self._coord_shapes = self._coord_space.compute_shape_functions(quadrature_rule.xigauss)

self._spaces = input_spaces
self._mesh = mesh
self._quadrature_rule = quadrature_rule
self._shapes = [space.compute_shape_functions(quadrature_rule.xigauss) for space in input_spaces]
self._input_fields = tuple(input.field for input in qfunction_signature)
self._interpolators = tuple(_choose_interpolation_function(input, self._spaces) for input in qfunction_signature)

def evaluate(self, f: Callable, coords: Float[Array, "nnode ndim"],
block: Int[Array, "nelem"], *fields: Array | tuple[Array, ...]) -> Array:
"""Evaluate a function at quadrature points.

Args:
f: Integrand function.
coords: Spatial coordinates of the mesh, used for parametric mapping
of gradients. Can be different values than given in the
constructor, but the shape must be the same.
block: Element ids identifying the domain of integration.
*fields: Input fields to the functional. Must match the
specifications of the ``inputs`` agrument to the constructor. For
performance reasons, this is not checked.

Returns:
A ``QuadratureField`` with the value of ``f`` at every quadrature point in the mesh.
"""
f_vmap_axis = None
compute_values = jax.vmap(self._evaluate_on_element, (f_vmap_axis, 0, None) + tuple(None for field in fields))
return compute_values(f, block, coords, *fields)

def integrate(self, f: Callable, coords: Float[Array, "nnode ndim"],
block: Int[Array, "nelem"], *fields: Array | tuple[Array, ...]) -> Array:
"""Integrate a function over a mesh block.

Args:
f: Integrand function.
coords: Spatial coordinates of the mesh, used for parametric mapping
of gradients. Can be different values than given in the
constructor, but the shape must be the same.
block: Element ids identifying the domain of integration.
*fields: Input fields to the functional. Must match the
specifications of the ``inputs`` agrument to the constructor. For
performance reasons, this is not checked.

Returns:
The integral of ``f`` over the domain.
"""
f_vmap_axis = None
integrate = jax.vmap(self._integrate_over_element, (f_vmap_axis, 0, None) + tuple(None for field in fields))
return np.sum(integrate(f, block, coords, *fields))

def _evaluate_on_element(self, f, el_id, coords, *fields):
jacs = self._coord_space.interpolate_gradient(self._coord_shapes, coords, el_id)
shapes = [space.map_shape_functions(shape, jacs) for (space, shape) in zip(self._spaces, self._shapes)]
f_args = [interp(shapes[field_id], fields[field_id], el_id) for (interp, field_id) in zip(self._interpolators, self._input_fields)]
f_batch = jax.vmap(f, tuple(self._spaces[input].quadpoint_axis for input in self._input_fields))
return f_batch(*f_args)

def _integrate_over_element(self, f, el_id, coords, *fields):
jacs = self._coord_space.interpolate_gradient(self._coord_shapes, coords, el_id)
shapes = [space.map_shape_functions(shape, jacs) for (space, shape) in zip(self._spaces, self._shapes)]
dVs = jax.vmap(lambda J, w: np.linalg.det(J)*w)(jacs, self._quadrature_rule.wgauss)
f_args = [interp(shapes[field_id], fields[field_id], el_id) for (interp, field_id) in zip(self._interpolators, self._input_fields)]
f_batch = jax.vmap(f, tuple(self._spaces[input].quadpoint_axis for input in self._input_fields))
f_vals = f_batch(*f_args)
return np.dot(f_vals, dVs)
Loading