From 3f522145ed9133a43f33897cfb943f9b965cb74f Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Tue, 4 Nov 2025 12:34:00 +0000 Subject: [PATCH] Fieldsplit: aggressively expand Forms to detect zero blocks --- firedrake/formmanipulation.py | 12 ++++++++---- tests/firedrake/regression/test_split.py | 12 ++++++++++++ 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/firedrake/formmanipulation.py b/firedrake/formmanipulation.py index eb830493e1..b3dba367ed 100644 --- a/firedrake/formmanipulation.py +++ b/firedrake/formmanipulation.py @@ -3,9 +3,9 @@ import collections from ufl import as_tensor, as_vector, split -from ufl.classes import Zero, FixedIndex, ListTensor, ZeroBaseForm +from ufl.classes import FixedIndex, Form, ListTensor, Zero, ZeroBaseForm from ufl.algorithms.map_integrands import map_integrand_dags -from ufl.algorithms import expand_derivatives +from ufl.algorithms import expand_derivatives, expand_indices from ufl.corealg.map_dag import MultiFunction, map_expr_dags from pyop2 import MixedDat @@ -78,9 +78,13 @@ def split(self, form, argument_indices): assert (len(idx) == 1 for idx in self.blocks.values()) assert (idx[0] == 0 for idx in self.blocks.values()) return form - # TODO find a way to distinguish empty Forms avoiding expand_derivatives f = map_integrand_dags(self, form) - if expand_derivatives(f).empty(): + + # TODO find a better way to distinguish empty Forms + f_expanded = expand_derivatives(f) + if isinstance(f_expanded, Form): + f_expanded = expand_indices(f_expanded) + if f_expanded.empty(): # Get ZeroBaseForm with the right shape f = ZeroBaseForm(tuple(map(self._subspace_argument, form.arguments()))) return f diff --git a/tests/firedrake/regression/test_split.py b/tests/firedrake/regression/test_split.py index fcf8f221f8..a37e578a42 100644 --- a/tests/firedrake/regression/test_split.py +++ b/tests/firedrake/regression/test_split.py @@ -125,3 +125,15 @@ def test_split_coefficient_not_argument(): as_vector([TestFunction(V), 0])), w, wr) assert J00.signature() == expect.signature() + + +def test_split_zero_block(): + mesh = UnitSquareMesh(1, 1) + V = FunctionSpace(mesh, "DG", 0) + Z = V * V * V * V + J = inner(TrialFunction(Z), TestFunction(Z))*dx + splitter = ExtractSubBlock() + + J00 = splitter.split(J, (0, 1)) + + assert isinstance(J00, ZeroBaseForm)