diff --git a/cpp/dolfinx/fem/DirichletBC.h b/cpp/dolfinx/fem/DirichletBC.h index 861c9b108f..28b893c52a 100644 --- a/cpp/dolfinx/fem/DirichletBC.h +++ b/cpp/dolfinx/fem/DirichletBC.h @@ -255,12 +255,13 @@ class DirichletBC { private: /// Compute number of owned dofs indices. Will contain 'gaps' for - /// sub-spaces. + /// sub-spaces. The dofs must be unrolled. std::size_t num_owned(const DofMap& dofmap, std::span dofs) { + int bs = dofmap.index_map_bs(); std::int32_t map_size = dofmap.index_map->size_local(); - std::int32_t owned_size = map_size; + std::int32_t owned_size = bs * map_size; auto it = std::ranges::lower_bound(dofs, owned_size); return std::distance(dofs.begin(), it); } @@ -324,8 +325,7 @@ class DirichletBC std::vector> DirichletBC(std::shared_ptr> g, X&& dofs, std::shared_ptr> V) - : _function_space(V), _g(g), _dofs0(std::forward(dofs)), - _owned_indices0(num_owned(*V->dofmap(), _dofs0)) + : _function_space(V), _g(g), _dofs0(std::forward(dofs)) { assert(g); assert(V); @@ -351,10 +351,9 @@ class DirichletBC // Unroll _dofs0 if dofmap block size > 1 if (const int bs = V->dofmap()->bs(); bs > 1) - { - _owned_indices0 *= bs; _dofs0 = unroll_dofs(_dofs0, bs); - } + + _owned_indices0 = num_owned(*_function_space->dofmap(), _dofs0); } /// @brief Create a representation of a Dirichlet boundary condition @@ -374,17 +373,15 @@ class DirichletBC std::vector> DirichletBC(std::shared_ptr> g, X&& dofs) : _function_space(g->function_space()), _g(g), - _dofs0(std::forward(dofs)), - _owned_indices0(num_owned(*_function_space->dofmap(), _dofs0)) + _dofs0(std::forward(dofs)) { assert(_function_space); // Unroll _dofs0 if dofmap block size > 1 if (const int bs = _function_space->dofmap()->bs(); bs > 1) - { - _owned_indices0 *= bs; _dofs0 = unroll_dofs(_dofs0, bs); - } + + _owned_indices0 = num_owned(*_function_space->dofmap(), _dofs0); } /// @brief Create a representation of a Dirichlet boundary condition diff --git a/python/test/unit/fem/test_bcs.py b/python/test/unit/fem/test_bcs.py index a97b41e841..d5adbd18fd 100644 --- a/python/test/unit/fem/test_bcs.py +++ b/python/test/unit/fem/test_bcs.py @@ -366,22 +366,36 @@ def test_mixed_blocked_constant(): @pytest.mark.parametrize("shape", [(), (2,), (3, 2)]) def test_blocked_dof_ownership(shape): + """Test that dof ownership is correctly handled for blocked function spaces.""" mesh = create_unit_square(MPI.COMM_WORLD, 4, 4) V = functionspace(mesh, ("Lagrange", 1, shape)) - u_bc = Function(V) - mesh.topology.create_connectivity(mesh.topology.dim - 1, mesh.topology.dim) - bc_facets = exterior_facet_indices(mesh.topology) - # Blocked spaces are not unrolled here - bc_dofs_u = locate_dofs_topological(V, mesh.topology.dim - 1, bc_facets) - - # Num owned dofs - num_owned_blocked = V.dofmap.index_map.size_local - input_dofs_owned = bc_dofs_u[bc_dofs_u < num_owned_blocked] + tdim = mesh.topology.dim + mesh.topology.create_connectivity(tdim - 1, tdim) + boundary_facets = exterior_facet_indices(mesh.topology) + boundary_dofs = locate_dofs_topological(V, tdim - 1, boundary_facets) - bc = dirichletbc(u_bc, bc_dofs_u) - unrolled_bc_dofs, num_owned = bc.dof_indices() + # Test full space BC + bc = dirichletbc(u_bc, boundary_dofs) + unrolled_dofs, num_owned = bc.dof_indices() - assert len(input_dofs_owned) * V.dofmap.index_map_bs == num_owned - assert len(unrolled_bc_dofs) == len(bc_dofs_u) * V.dofmap.index_map_bs + num_owned_blocked = V.dofmap.index_map.size_local + bs = V.dofmap.index_map_bs + owned_input_dofs = boundary_dofs[boundary_dofs < num_owned_blocked] + + assert len(owned_input_dofs) * bs == num_owned + assert len(unrolled_dofs) == len(boundary_dofs) * bs + + # Test subspace BC for tensor spaces + if len(shape) > 1: + V0, _ = V.sub(0).collapse() + boundary_dofs = locate_dofs_topological((V.sub(0), V0), tdim - 1, boundary_facets) + bc_sub = dirichletbc(u_bc, boundary_dofs, V) + unrolled_dofs_sub, num_owned_sub = bc_sub.dof_indices() + + # Check number of unrolled owned dofs in the full non-collapsed space + boundary_dofs_V = boundary_dofs[0] + owned_sub_dofs = boundary_dofs_V[boundary_dofs_V < num_owned_blocked * bs] + assert len(owned_sub_dofs) == num_owned_sub + assert len(unrolled_dofs_sub) == len(boundary_dofs_V)