Skip to content

Commit 0fc8143

Browse files
committed
fix
fix lint
1 parent 1862c21 commit 0fc8143

File tree

3 files changed

+93
-38
lines changed

3 files changed

+93
-38
lines changed

firedrake/assemble.py

Lines changed: 65 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -387,6 +387,16 @@ def allocation_integral_types(self):
387387
else:
388388
return self._allocation_integral_types
389389

390+
@staticmethod
391+
def _as_pyop2_type(tensor, indices=None):
392+
if isinstance(tensor, (firedrake.Cofunction, firedrake.Function)):
393+
return OneFormAssembler._as_pyop2_type(tensor, indices=indices)
394+
elif isinstance(tensor, ufl.Matrix):
395+
return ExplicitMatrixAssembler._as_pyop2_type(tensor, indices=indices)
396+
else:
397+
assert indices is None
398+
return tensor
399+
390400
def assemble(self, tensor=None, current_state=None):
391401
"""Assemble the form.
392402
@@ -411,21 +421,22 @@ def assemble(self, tensor=None, current_state=None):
411421
"""
412422
def visitor(e, *operands):
413423
t = tensor if e is self._form else None
414-
return self.base_form_assembly_visitor(e, t, *operands)
424+
# Deal with 2-form bcs inside the visitor
425+
bcs = self._bcs if isinstance(e, ufl.BaseForm) and len(e.arguments()) == 2 else ()
426+
return self.base_form_assembly_visitor(e, t, bcs, *operands)
415427

416428
# DAG assembly: traverse the DAG in a post-order fashion and evaluate the node on the fly.
417429
visited = {}
418430
result = BaseFormAssembler.base_form_postorder_traversal(self._form, visitor, visited)
419431

420-
# Apply BCs after assembly
432+
# Deal with 1-form bcs outside the visitor
421433
rank = len(self._form.arguments())
422434
if rank == 1 and not isinstance(result, ufl.ZeroBaseForm):
423435
for bc in self._bcs:
424436
OneFormAssembler._apply_bc(self, result, bc, u=current_state)
425-
426437
return result
427438

428-
def base_form_assembly_visitor(self, expr, tensor, *args):
439+
def base_form_assembly_visitor(self, expr, tensor, bcs, *args):
429440
r"""Assemble a :class:`~ufl.classes.BaseForm` object given its assembled operands.
430441
431442
This functions contains the assembly handlers corresponding to the different nodes that
@@ -446,7 +457,7 @@ def base_form_assembly_visitor(self, expr, tensor, *args):
446457
assembler = OneFormAssembler(form, form_compiler_parameters=self._form_compiler_params,
447458
zero_bc_nodes=self._zero_bc_nodes, diagonal=self._diagonal, weight=self._weight)
448459
elif rank == 2:
449-
assembler = TwoFormAssembler(form, bcs=self._bcs, form_compiler_parameters=self._form_compiler_params,
460+
assembler = TwoFormAssembler(form, bcs=bcs, form_compiler_parameters=self._form_compiler_params,
450461
mat_type=self._mat_type, sub_mat_type=self._sub_mat_type,
451462
options_prefix=self._options_prefix, appctx=self._appctx, weight=self._weight,
452463
allocation_integral_types=self.allocation_integral_types)
@@ -457,13 +468,12 @@ def base_form_assembly_visitor(self, expr, tensor, *args):
457468
if len(args) != 1:
458469
raise TypeError("Not enough operands for Adjoint")
459470
mat, = args
460-
res = tensor.petscmat if tensor else PETSc.Mat()
461-
petsc_mat = mat.petscmat
471+
result = tensor.petscmat if tensor else PETSc.Mat()
462472
# Out-of-place Hermitian transpose
463-
petsc_mat.hermitianTranspose(out=res)
464-
(row, col) = mat.arguments()
465-
return matrix.AssembledMatrix((col, row), self._bcs, res,
466-
options_prefix=self._options_prefix)
473+
mat.petscmat.hermitianTranspose(out=result)
474+
if tensor is None:
475+
tensor = self.assembled_matrix(expr, bcs, result)
476+
return tensor
467477
elif isinstance(expr, ufl.Action):
468478
if len(args) != 2:
469479
raise TypeError("Not enough operands for Action")
@@ -481,7 +491,7 @@ def base_form_assembly_visitor(self, expr, tensor, *args):
481491
result = tensor.petscmat if tensor else PETSc.Mat()
482492
lhs.petscmat.matMult(rhs.petscmat, result=result)
483493
if tensor is None:
484-
tensor = self.assembled_matrix(expr, result)
494+
tensor = self.assembled_matrix(expr, bcs, result)
485495
return tensor
486496
else:
487497
raise TypeError("Incompatible RHS for Action.")
@@ -500,9 +510,6 @@ def base_form_assembly_visitor(self, expr, tensor, *args):
500510
raise TypeError("Mismatching weights and operands in FormSum")
501511
if len(args) == 0:
502512
raise TypeError("Empty FormSum")
503-
if tensor:
504-
tensor.zero()
505-
506513
# Assemble weights
507514
weights = []
508515
for w in expr.weights():
@@ -520,27 +527,54 @@ def base_form_assembly_visitor(self, expr, tensor, *args):
520527
raise ValueError("Expecting a scalar weight expression")
521528
weights.append(w)
522529

530+
# Scalar FormSum
523531
if all(isinstance(op, numbers.Complex) for op in args):
524-
result = sum(weight * arg for weight, arg in zip(weights, args))
525-
return tensor.assign(result) if tensor else result
526-
elif (all(isinstance(op, firedrake.Cofunction) for op in args)
532+
result = numpy.dot(weights, args)
533+
return tensor.assign(result) if tensor else result.item()
534+
535+
# Accumulate coefficients in a dictionary for each unique Dat/Mat
536+
terms = defaultdict(PETSc.ScalarType)
537+
for arg, weight in zip(args, weights):
538+
t = self._as_pyop2_type(arg)
539+
terms[t] += weight
540+
541+
# Zero the output tensor, or rescale it if it appears in the sum
542+
tensor_scale = terms.pop(self._as_pyop2_type(tensor), 0)
543+
if tensor is None or tensor_scale == 1:
544+
pass
545+
elif tensor_scale == 0:
546+
tensor.zero()
547+
elif isinstance(tensor, (firedrake.Cofunction, firedrake.Function)):
548+
with tensor.dat.vec as v:
549+
v.scale(tensor_scale)
550+
elif isinstance(tensor, ufl.Matrix):
551+
tensor.petscmat.scale(tensor_scale)
552+
else:
553+
raise ValueError("Expecting tensor to be None, Function, Cofunction, or Matrix")
554+
555+
# Compute the linear combination
556+
if (all(isinstance(op, firedrake.Cofunction) for op in args)
527557
or all(isinstance(op, firedrake.Function) for op in args)):
558+
# Vector FormSum
528559
V, = set(a.function_space() for a in args)
529560
result = tensor if tensor else firedrake.Function(V)
530-
result.dat.maxpy(weights, [a.dat for a in args])
561+
weights = terms.values()
562+
dats = terms.keys()
563+
result.dat.maxpy(weights, dats)
531564
return result
532565
elif all(isinstance(op, ufl.Matrix) for op in args):
566+
# Matrix FormSum
533567
result = tensor.petscmat if tensor else PETSc.Mat()
534-
for (op, w) in zip(args, weights):
568+
for (op, w) in terms.items():
535569
if result:
536570
# If result is not void, then accumulate on it
537-
result.axpy(w, op.petscmat)
571+
result.axpy(w, op.handle)
538572
else:
539573
# If result is void, then allocate it with first term
540-
op.petscmat.copy(result=result)
574+
op.handle.copy(result=result)
541575
result.scale(w)
542576
if tensor is None:
543-
tensor = self.assembled_matrix(expr, result)
577+
tensor = self.assembled_matrix(expr, bcs, result)
544578
return tensor
545579
else:
546580
raise TypeError("Mismatching FormSum shapes")
@@ -572,9 +606,8 @@ def base_form_assembly_visitor(self, expr, tensor, *args):
572606
# Occur in situations such as Interpolate composition
573607
operand = assembled_operand[0]
574608

575-
reconstruct_interp = expr._ufl_expr_reconstruct_
576609
if (v, operand) != expr.argument_slots():
577-
expr = reconstruct_interp(operand, v=v)
610+
expr = expr._ufl_expr_reconstruct_(operand, v=v)
578611

579612
rank = len(expr.arguments())
580613
if rank > 2:
@@ -591,8 +624,8 @@ def base_form_assembly_visitor(self, expr, tensor, *args):
591624
else:
592625
raise TypeError(f"Unrecognised BaseForm instance: {expr}")
593626

594-
def assembled_matrix(self, expr, petscmat):
595-
return matrix.AssembledMatrix(expr.arguments(), self._bcs, petscmat,
627+
def assembled_matrix(self, expr, bcs, petscmat):
628+
return matrix.AssembledMatrix(expr.arguments(), bcs, petscmat,
596629
options_prefix=self._options_prefix)
597630

598631
@staticmethod
@@ -1441,10 +1474,11 @@ def _apply_bc(self, tensor, bc, u=None):
14411474
index = 0 if V.index is None else V.index
14421475
space = V if V.parent is None else V.parent
14431476
if isinstance(bc, DirichletBC):
1444-
if space != spaces[0]:
1445-
raise TypeError("bc space does not match the test function space")
1446-
elif space != spaces[1]:
1447-
raise TypeError("bc space does not match the trial function space")
1477+
if not any(space == fs for fs in spaces):
1478+
raise TypeError("bc space does not match the test or trial function space")
1479+
if spaces[0] != spaces[1]:
1480+
# Not on a diagonal block, we cannot set diagonal entries
1481+
return
14481482

14491483
# Set diagonal entries on bc nodes to 1 if the current
14501484
# block is on the matrix diagonal and its index matches the

tests/firedrake/regression/test_interpolate.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -599,3 +599,24 @@ def test_interpolator_reuse(family, degree, mode):
599599

600600
# Test for correctness
601601
assert np.allclose(result.dat.data, expected)
602+
603+
604+
def test_mixed_space_bcs():
605+
mesh = UnitSquareMesh(2, 2)
606+
V = FunctionSpace(mesh, "CG", 1)
607+
W = V * V
608+
rg = RandomGenerator(PCG64(seed=123456789))
609+
w = rg.uniform(W)
610+
611+
bcs = [DirichletBC(W.sub(0), 0, 1),
612+
DirichletBC(W.sub(1), 0, 2),
613+
DirichletBC(V, 0, (3, 4))]
614+
615+
I = assemble(interpolate(sum(TrialFunction(W)), V), bcs=bcs)
616+
result = assemble(action(I, w))
617+
618+
for bc in bcs[:-1]:
619+
bc.zero(w)
620+
expected = assemble(interpolate(sum(w), V), bcs=bcs[-1:])
621+
622+
assert np.allclose(result.dat.data, expected.dat.data)

tests/firedrake/submesh/test_submesh_interpolate.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,14 @@
1010

1111
def _get_expr(V):
1212
m = V.ufl_domain()
13-
if m.geometric_dimension() == 1:
13+
if m.geometric_dimension == 1:
1414
x, = SpatialCoordinate(m)
1515
y = x * x
1616
z = x + y
17-
elif m.geometric_dimension() == 2:
17+
elif m.geometric_dimension == 2:
1818
x, y = SpatialCoordinate(m)
1919
z = x + y
20-
elif m.geometric_dimension() == 3:
20+
elif m.geometric_dimension == 3:
2121
x, y, z = SpatialCoordinate(m)
2222
else:
2323
raise NotImplementedError("Not implemented")
@@ -28,7 +28,7 @@ def _get_expr(V):
2828

2929

3030
def make_submesh(mesh, subdomain_cond, label_value):
31-
dim = mesh.topological_dimension()
31+
dim = mesh.topological_dimension
3232
DG0 = FunctionSpace(mesh, "DG", 0)
3333
indicator_function = Function(DG0).interpolate(subdomain_cond)
3434
mesh.mark_entities(indicator_function, label_value)
@@ -145,7 +145,7 @@ def test_submesh_interpolate_subcell_subcell_2_processes():
145145
mesh = RectangleMesh(
146146
3, 1, 3., 1., quadrilateral=True, distribution_parameters={"partitioner_type": "simple"},
147147
)
148-
dim = mesh.topological_dimension()
148+
dim = mesh.topological_dimension
149149
x, _ = SpatialCoordinate(mesh)
150150
DG0 = FunctionSpace(mesh, "DG", 0)
151151
f_l = Function(DG0).interpolate(conditional(x < 2.0, 1, 0))
@@ -210,7 +210,7 @@ def expr(m):
210210
)
211211
facet_value = 999
212212
mesh = RelabeledMesh(mesh, [facet_function], [facet_value])
213-
subm = Submesh(mesh, mesh.topological_dimension() - 1, facet_value)
213+
subm = Submesh(mesh, mesh.topological_dimension - 1, facet_value)
214214
DG3d = FunctionSpace(mesh, "DG", degree)
215215
dg3d = Function(DG3d).interpolate(expr(mesh))
216216
DG2d = FunctionSpace(subm, "DG", degree)
@@ -258,7 +258,7 @@ def expr(m):
258258
facet_function = Function(V).interpolate(Constant(1.))
259259
facet_value = 999
260260
mesh = RelabeledMesh(mesh, [facet_function], [facet_value])
261-
subm = Submesh(mesh, mesh.topological_dimension() - 1, facet_value)
261+
subm = Submesh(mesh, mesh.topological_dimension - 1, facet_value)
262262
HDivT3d = FunctionSpace(mesh, "HDiv Trace", degree)
263263
hdivt3d = Function(HDivT3d).interpolate(expr(mesh))
264264
DG2d = FunctionSpace(subm, "DG", degree)

0 commit comments

Comments
 (0)