Skip to content
11 changes: 9 additions & 2 deletions ffcx/codegeneration/C/form.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
#
# SPDX-License-Identifier: LGPL-3.0-or-later
#
# Modified by Chris Richardson and Jørgen S. Dokken 2023
# Modified by Chris Richardson and Jørgen S. Dokken 2023-2025
#
# Note: Most of the code in this file is a direct translation from the
# old implementation in FFC
Expand All @@ -17,6 +17,7 @@
import numpy as np

from ffcx.codegeneration.C import form_template
from ffcx.definitions import IntegralType
from ffcx.ir.representation import FormIR

logger = logging.getLogger("ffcx")
Expand Down Expand Up @@ -119,7 +120,13 @@ def generator(ir: FormIR, options):
integral_offsets = [0]
integral_domains = []
# Note: the order of this list is defined by the enum ufcx_integral_type in ufcx.h
for itg_type in ("cell", "exterior_facet", "interior_facet", "vertex", "ridge"):
for itg_type in (
IntegralType(codim=0),
IntegralType(codim=1, num_neighbours=1),
IntegralType(codim=1, num_neighbours=2),
IntegralType(codim=-1),
IntegralType(codim=2),
):
unsorted_integrals = []
unsorted_ids = []
unsorted_domains = []
Expand Down
66 changes: 33 additions & 33 deletions ffcx/codegeneration/access.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (C) 2011-2017 Martin Sandve Alnæs
# Copyright (C) 2011-2025 Martin Sandve Alnæs, Jørgen S. Dokken
#
# This file is part of FFCx. (https://www.fenicsproject.org)
#
Expand All @@ -12,7 +12,7 @@
import ufl

import ffcx.codegeneration.lnodes as L
from ffcx.definitions import entity_types
from ffcx.definitions import IntegralType
from ffcx.ir.analysis.modified_terminals import ModifiedTerminal
from ffcx.ir.elementtables import UniqueTableReferenceT
from ffcx.ir.representationutils import QuadratureRule
Expand All @@ -23,12 +23,11 @@
class FFCXBackendAccess:
"""FFCx specific formatter class."""

entity_type: entity_types
integral_type: IntegralType

def __init__(self, entity_type: entity_types, integral_type: str, symbols, options):
def __init__(self, integral_type: IntegralType, symbols, options):
"""Initialise."""
# Store ir and options
self.entity_type = entity_type
self.integral_type = integral_type
self.symbols = symbols
self.options = options
Expand Down Expand Up @@ -122,27 +121,28 @@ def spatial_coordinate(
if mt.averaged is not None:
raise RuntimeError("Not expecting average of SpatialCoordinates.")

if self.integral_type in ufl.custom_integral_types:
if mt.local_derivatives:
raise RuntimeError("FIXME: Jacobian in custom integrals is not implemented.")

# Access predefined quadrature points table
x = self.symbols.custom_points_table
iq = self.symbols.quadrature_loop_index
(gdim,) = mt.terminal.ufl_shape
if gdim == 1:
index = iq
else:
index = iq * gdim + mt.flat_component
return x[index]
elif self.integral_type == "expression":
# Physical coordinates are computed by code generated in
# definitions
return self.symbols.x_component(mt)
else:
# Physical coordinates are computed by code generated in
# definitions
return self.symbols.x_component(mt)
match self.integral_type:
case IntegralType(is_custom=True):
if mt.local_derivatives:
raise RuntimeError("FIXME: Jacobian in custom integrals is not implemented.")

# Access predefined quadrature points table
x = self.symbols.custom_points_table
iq = self.symbols.quadrature_loop_index
(gdim,) = mt.terminal.ufl_shape
if gdim == 1:
index = iq
else:
index = iq * gdim + mt.flat_component
return x[index]
case IntegralType(is_expression=True):
# Physical coordinates are computed by code generated in
# definitions
return self.symbols.x_component(mt)
case _:
# Physical coordinates are computed by code generated in
# definitions
return self.symbols.x_component(mt)

def cell_coordinate(self, mt, tabledata, num_points):
"""Access a cell coordinate."""
Expand Down Expand Up @@ -181,7 +181,7 @@ def facet_coordinate(self, mt, tabledata, num_points):
if mt.restriction:
raise RuntimeError("Not expecting restriction of FacetCoordinate.")

if self.integral_type in ("interior_facet", "exterior_facet"):
if self.integral_type in ("interior_facet", "facet"):
(tdim,) = mt.terminal.ufl_shape
if tdim == 0:
raise RuntimeError("Vertices have no facet coordinates.")
Expand Down Expand Up @@ -233,7 +233,7 @@ def reference_normal(self, mt, tabledata, access):
cellname = ufl.domain.extract_unique_domain(mt.terminal).ufl_cell().cellname()
if cellname in ("interval", "triangle", "tetrahedron", "quadrilateral", "hexahedron"):
table = L.Symbol(f"{cellname}_reference_normals", dtype=L.DataType.REAL)
facet = self.symbols.entity("facet", mt.restriction)
facet = self.symbols.entity(IntegralType(codim=1), mt.restriction)
return table[facet][mt.component[0]]
else:
raise RuntimeError(f"Unhandled cell types {cellname}.")
Expand All @@ -250,7 +250,7 @@ def cell_facet_jacobian(self, mt, tabledata, num_points):
"pyramid",
):
table = L.Symbol(f"{cellname}_cell_facet_jacobian", dtype=L.DataType.REAL)
facet = self.symbols.entity("facet", mt.restriction)
facet = self.symbols.entity(IntegralType(codim=1), mt.restriction)
return table[facet][mt.component[0]][mt.component[1]]
elif cellname == "interval":
raise RuntimeError("The reference facet jacobian doesn't make sense for interval cell.")
Expand All @@ -262,7 +262,7 @@ def cell_ridge_jacobian(self, mt, tabledata, num_points):
cellname = ufl.domain.extract_unique_domain(mt.terminal).ufl_cell().cellname()
if cellname in ("tetrahedron", "prism", "hexahedron"):
table = L.Symbol(f"{cellname}_cell_ridge_jacobian", dtype=L.DataType.REAL)
ridge = self.symbols.entity("ridge", mt.restriction)
ridge = self.symbols.entity(IntegralType(codim=2), mt.restriction)
return table[ridge][mt.component[0]][mt.component[1]]
elif cellname in ["triangle", "quadrilateral"]:
raise RuntimeError("The ridge jacobian doesn't make sense for 2D cells.")
Expand Down Expand Up @@ -422,7 +422,7 @@ def _pass(self, *args, **kwargs):
def table_access(
self,
tabledata: UniqueTableReferenceT,
entity_type: entity_types,
integral_type: IntegralType,
restriction: str,
quadrature_index: L.MultiIndex,
dof_index: L.MultiIndex,
Expand All @@ -431,12 +431,12 @@ def table_access(

Args:
tabledata: Table data object
entity_type: Entity type
integral_type: Integral type
restriction: Restriction ("+", "-")
quadrature_index: Quadrature index
dof_index: Dof index
"""
entity = self.symbols.entity(entity_type, restriction)
entity = self.symbols.entity(integral_type, restriction)
iq_global_index = quadrature_index.global_index
ic_global_index = dof_index.global_index
qp = 0 # quadrature permutation
Expand Down
8 changes: 2 additions & 6 deletions ffcx/codegeneration/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,5 @@ def __init__(self, ir: IntegralIR | ExpressionIR, options):
self.symbols = FFCXBackendSymbols(
coefficient_numbering, coefficient_offsets, original_constant_offsets
)
self.access = FFCXBackendAccess(
ir.expression.entity_type, ir.expression.integral_type, self.symbols, options
)
self.definitions = FFCXBackendDefinitions(
ir.expression.entity_type, ir.expression.integral_type, self.access, options
)
self.access = FFCXBackendAccess(ir.expression.integral_type, self.symbols, options)
self.definitions = FFCXBackendDefinitions(ir.expression.integral_type, self.access, options)
13 changes: 6 additions & 7 deletions ffcx/codegeneration/definitions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (C) 2011-2023 Martin Sandve Alnæs, Igor A. Baratta
# Copyright (C) 2011-2025 Martin Sandve Alnæs, Igor A. Baratta and Jørgen S. Dokken
#
# This file is part of FFCx. (https://www.fenicsproject.org)
#
Expand All @@ -10,7 +10,7 @@
import ufl

import ffcx.codegeneration.lnodes as L
from ffcx.definitions import entity_types
from ffcx.definitions import IntegralType
from ffcx.ir.analysis.modified_terminals import ModifiedTerminal
from ffcx.ir.elementtables import UniqueTableReferenceT
from ffcx.ir.representationutils import QuadratureRule
Expand Down Expand Up @@ -50,13 +50,12 @@ def create_dof_index(tabledata, dof_index_symbol):
class FFCXBackendDefinitions:
"""FFCx specific code definitions."""

entity_type: entity_types
integral_type: IntegralType

def __init__(self, entity_type: entity_types, integral_type: str, access, options):
def __init__(self, integral_type: IntegralType, access, options):
"""Initialise."""
# Store ir and options
self.integral_type = integral_type
self.entity_type = entity_type
self.access = access
self.options = options

Expand Down Expand Up @@ -148,7 +147,7 @@ def coefficient(
assert begin < end

# Get access to element table
FE, tables = self.access.table_access(tabledata, self.entity_type, mt.restriction, iq, ic)
FE, tables = self.access.table_access(tabledata, self.integral_type, mt.restriction, iq, ic)
dof_access: L.ArrayAccess = self.symbols.coefficient_dof_access(
mt.terminal, (ic.global_index) * bs + begin
)
Expand Down Expand Up @@ -197,7 +196,7 @@ def _define_coordinate_dofs_lincomb(
iq_symbol = self.symbols.quadrature_loop_index
ic = create_dof_index(tabledata, ic_symbol)
iq = create_quadrature_index(quadrature_rule, iq_symbol)
FE, tables = self.access.table_access(tabledata, self.entity_type, mt.restriction, iq, ic)
FE, tables = self.access.table_access(tabledata, self.integral_type, mt.restriction, iq, ic)

dof_access = L.Symbol("coordinate_dofs", dtype=L.DataType.REAL)

Expand Down
4 changes: 2 additions & 2 deletions ffcx/codegeneration/expression_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def generate_geometry_tables(self):
mt = attr.get("mt")
if mt is not None:
t = type(mt.terminal)
if self.ir.expression.entity_type == "cell" and issubclass(
if self.ir.expression.integral_type.codim == 0 and issubclass(
t, ufl.geometry.GeometricFacetQuantity
):
raise RuntimeError(f"Expressions for cells do not support {t}.")
Expand Down Expand Up @@ -302,7 +302,7 @@ def get_arg_factors(self, blockdata, block_rank, indices):
]

table = self.backend.symbols.element_table(
td, self.ir.expression.entity_type, mt.restriction
td, self.ir.expression.integral_type, mt.restriction
)

assert td.ttype != "zeros"
Expand Down
3 changes: 1 addition & 2 deletions ffcx/codegeneration/integral_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,13 +417,12 @@ def get_arg_factors(self, blockdata, block_rank, quadrature_rule, domain, iq, in
# now because it assumes too much about indices.

assert td.ttype != "zeros"

if td.ttype == "ones":
arg_factor = 1
else:
# Assuming B sparsity follows element table sparsity
arg_factor, arg_tables = self.backend.access.table_access(
td, self.ir.expression.entity_type, mt.restriction, iq, indices[i]
td, self.ir.expression.integral_type, mt.restriction, iq, indices[i]
)

tables += arg_tables
Expand Down
2 changes: 1 addition & 1 deletion ffcx/codegeneration/jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ def compile_forms(


def compile_expressions(
expressions: list[tuple[ufl.Expr, npt.NDArray[np.floating]]], # type: ignore
expressions: list[tuple[ufl.core.expr.Expr, npt.NDArray[np.floating]]],
options: dict = {},
cache_dir: Path | None = None,
timeout: int = 10,
Expand Down
44 changes: 23 additions & 21 deletions ffcx/codegeneration/symbols.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
# Copyright (C) 2011-2023 Martin Sandve Alnæs, Igor A. Baratta
# Copyright (C) 2011-2025 Martin Sandve Alnæs, Igor A. Baratta and Jørgen S. Dokken
#
# This file is part of FFCx. (https://www.fenicsproject.org)
#
# SPDX-License-Identifier: LGPL-3.0-or-later
"""FFCx/UFC specific symbol naming."""

import logging
import typing

import ufl

import ffcx.codegeneration.lnodes as L
from ffcx.definitions import entity_types
from ffcx.definitions import IntegralType

logger = logging.getLogger("ffcx")

Expand Down Expand Up @@ -96,23 +97,25 @@ def __init__(self, coefficient_numbering, coefficient_offsets, original_constant
# Table for chunk of custom quadrature points (physical coordinates).
self.custom_points_table = L.Symbol("points_chunk", dtype=L.DataType.REAL)

def entity(self, entity_type: entity_types, restriction):
def entity(self, integral_type: IntegralType, restriction: typing.Literal["+", "-"]):
"""Entity index for lookup in element tables."""
if entity_type == "cell":
# Always 0 for cells (even with restriction)
return L.LiteralInt(0)

if entity_type == "facet":
if restriction == "-":
return self.entity_local_index[1]
else:
match integral_type:
case IntegralType(codim=0):
# Always 0 for cells (even with restriction)
return L.LiteralInt(0)
case IntegralType(codim=1):
if restriction == "-":
return self.entity_local_index[1]
else:
return self.entity_local_index[0]
case IntegralType(codim=-1):
return self.entity_local_index[0]
elif entity_type == "vertex":
return self.entity_local_index[0]
elif entity_type == "ridge":
return self.entity_local_index[0]
else:
logger.exception(f"Unknown entity_type {entity_type}")
case IntegralType(codim=2):
return self.entity_local_index[0]
case _:
raise RuntimeError(
f"Unknown integral_type {integral_type} in entity restriction lookup/"
)

def argument_loop_index(self, iarg):
"""Loop index for argument iarg."""
Expand Down Expand Up @@ -180,14 +183,13 @@ def constant_index_access(self, constant, index):
return c[offset + index]

# TODO: Remove this, use table_access instead
def element_table(self, tabledata, entity_type: entity_types, restriction):
def element_table(self, tabledata, integral_type: IntegralType, restriction):
"""Get an element table."""
entity = self.entity(entity_type, restriction)

entity = self.entity(integral_type, restriction)
if tabledata.is_uniform:
entity = 0
else:
entity = self.entity(entity_type, restriction)
entity = self.entity(integral_type, restriction)

if tabledata.is_piecewise:
iq = 0
Expand Down
2 changes: 1 addition & 1 deletion ffcx/codegeneration/ufcx.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ extern "C"
typedef enum
{
cell = 0,
exterior_facet = 1,
facet = 1,
interior_facet = 2,
vertex = 3,
ridge = 4,
Expand Down
Loading
Loading