@@ -387,6 +387,16 @@ def allocation_integral_types(self):
387387 else :
388388 return self ._allocation_integral_types
389389
390+ @staticmethod
391+ def _as_pyop2_type (tensor , indices = None ):
392+ if isinstance (tensor , (firedrake .Cofunction , firedrake .Function )):
393+ return OneFormAssembler ._as_pyop2_type (tensor , indices = indices )
394+ elif isinstance (tensor , ufl .Matrix ):
395+ return ExplicitMatrixAssembler ._as_pyop2_type (tensor , indices = indices )
396+ else :
397+ assert indices is None
398+ return tensor
399+
390400 def assemble (self , tensor = None , current_state = None ):
391401 """Assemble the form.
392402
@@ -411,21 +421,22 @@ def assemble(self, tensor=None, current_state=None):
411421 """
412422 def visitor (e , * operands ):
413423 t = tensor if e is self ._form else None
414- return self .base_form_assembly_visitor (e , t , * operands )
424+ # Deal with 2-form bcs inside the visitor
425+ bcs = self ._bcs if isinstance (e , ufl .BaseForm ) and len (e .arguments ()) == 2 else ()
426+ return self .base_form_assembly_visitor (e , t , bcs , * operands )
415427
416428 # DAG assembly: traverse the DAG in a post-order fashion and evaluate the node on the fly.
417429 visited = {}
418430 result = BaseFormAssembler .base_form_postorder_traversal (self ._form , visitor , visited )
419431
420- # Apply BCs after assembly
432+ # Deal with 1-form bcs outside the visitor
421433 rank = len (self ._form .arguments ())
422434 if rank == 1 and not isinstance (result , ufl .ZeroBaseForm ):
423435 for bc in self ._bcs :
424436 OneFormAssembler ._apply_bc (self , result , bc , u = current_state )
425-
426437 return result
427438
428- def base_form_assembly_visitor (self , expr , tensor , * args ):
439+ def base_form_assembly_visitor (self , expr , tensor , bcs , * args ):
429440 r"""Assemble a :class:`~ufl.classes.BaseForm` object given its assembled operands.
430441
431442 This functions contains the assembly handlers corresponding to the different nodes that
@@ -446,7 +457,7 @@ def base_form_assembly_visitor(self, expr, tensor, *args):
446457 assembler = OneFormAssembler (form , form_compiler_parameters = self ._form_compiler_params ,
447458 zero_bc_nodes = self ._zero_bc_nodes , diagonal = self ._diagonal , weight = self ._weight )
448459 elif rank == 2 :
449- assembler = TwoFormAssembler (form , bcs = self . _bcs , form_compiler_parameters = self ._form_compiler_params ,
460+ assembler = TwoFormAssembler (form , bcs = bcs , form_compiler_parameters = self ._form_compiler_params ,
450461 mat_type = self ._mat_type , sub_mat_type = self ._sub_mat_type ,
451462 options_prefix = self ._options_prefix , appctx = self ._appctx , weight = self ._weight ,
452463 allocation_integral_types = self .allocation_integral_types )
@@ -457,13 +468,12 @@ def base_form_assembly_visitor(self, expr, tensor, *args):
457468 if len (args ) != 1 :
458469 raise TypeError ("Not enough operands for Adjoint" )
459470 mat , = args
460- res = tensor .petscmat if tensor else PETSc .Mat ()
461- petsc_mat = mat .petscmat
471+ result = tensor .petscmat if tensor else PETSc .Mat ()
462472 # Out-of-place Hermitian transpose
463- petsc_mat . hermitianTranspose (out = res )
464- ( row , col ) = mat . arguments ()
465- return matrix . AssembledMatrix (( col , row ), self ._bcs , res ,
466- options_prefix = self . _options_prefix )
473+ mat . petscmat . hermitianTranspose (out = result )
474+ if tensor is None :
475+ tensor = self .assembled_matrix ( expr , bcs , result )
476+ return tensor
467477 elif isinstance (expr , ufl .Action ):
468478 if len (args ) != 2 :
469479 raise TypeError ("Not enough operands for Action" )
@@ -481,7 +491,7 @@ def base_form_assembly_visitor(self, expr, tensor, *args):
481491 result = tensor .petscmat if tensor else PETSc .Mat ()
482492 lhs .petscmat .matMult (rhs .petscmat , result = result )
483493 if tensor is None :
484- tensor = self .assembled_matrix (expr , result )
494+ tensor = self .assembled_matrix (expr , bcs , result )
485495 return tensor
486496 else :
487497 raise TypeError ("Incompatible RHS for Action." )
@@ -500,9 +510,6 @@ def base_form_assembly_visitor(self, expr, tensor, *args):
500510 raise TypeError ("Mismatching weights and operands in FormSum" )
501511 if len (args ) == 0 :
502512 raise TypeError ("Empty FormSum" )
503- if tensor :
504- tensor .zero ()
505-
506513 # Assemble weights
507514 weights = []
508515 for w in expr .weights ():
@@ -520,27 +527,54 @@ def base_form_assembly_visitor(self, expr, tensor, *args):
520527 raise ValueError ("Expecting a scalar weight expression" )
521528 weights .append (w )
522529
530+ # Scalar FormSum
523531 if all (isinstance (op , numbers .Complex ) for op in args ):
524- result = sum (weight * arg for weight , arg in zip (weights , args ))
525- return tensor .assign (result ) if tensor else result
526- elif (all (isinstance (op , firedrake .Cofunction ) for op in args )
532+ result = numpy .dot (weights , args )
533+ return tensor .assign (result ) if tensor else result .item ()
534+
535+ # Accumulate coefficients in a dictionary for each unique Dat/Mat
536+ terms = defaultdict (PETSc .ScalarType )
537+ for arg , weight in zip (args , weights ):
538+ t = self ._as_pyop2_type (arg )
539+ terms [t ] += weight
540+
541+ # Zero the output tensor, or rescale it if it appears in the sum
542+ tensor_scale = terms .pop (self ._as_pyop2_type (tensor ), 0 )
543+ if tensor is None or tensor_scale == 1 :
544+ pass
545+ elif tensor_scale == 0 :
546+ tensor .zero ()
547+ elif isinstance (tensor , (firedrake .Cofunction , firedrake .Function )):
548+ with tensor .dat .vec as v :
549+ v .scale (tensor_scale )
550+ elif isinstance (tensor , ufl .Matrix ):
551+ tensor .petscmat .scale (tensor_scale )
552+ else :
553+ raise ValueError ("Expecting tensor to be None, Function, Cofunction, or Matrix" )
554+
555+ # Compute the linear combination
556+ if (all (isinstance (op , firedrake .Cofunction ) for op in args )
527557 or all (isinstance (op , firedrake .Function ) for op in args )):
558+ # Vector FormSum
528559 V , = set (a .function_space () for a in args )
529560 result = tensor if tensor else firedrake .Function (V )
530- result .dat .maxpy (weights , [a .dat for a in args ])
561+ weights = terms .values ()
562+ dats = terms .keys ()
563+ result .dat .maxpy (weights , dats )
531564 return result
532565 elif all (isinstance (op , ufl .Matrix ) for op in args ):
566+ # Matrix FormSum
533567 result = tensor .petscmat if tensor else PETSc .Mat ()
534- for (op , w ) in zip ( args , weights ):
568+ for (op , w ) in terms . items ( ):
535569 if result :
536570 # If result is not void, then accumulate on it
537- result .axpy (w , op .petscmat )
571+ result .axpy (w , op .handle )
538572 else :
539573 # If result is void, then allocate it with first term
540- op .petscmat .copy (result = result )
574+ op .handle .copy (result = result )
541575 result .scale (w )
542576 if tensor is None :
543- tensor = self .assembled_matrix (expr , result )
577+ tensor = self .assembled_matrix (expr , bcs , result )
544578 return tensor
545579 else :
546580 raise TypeError ("Mismatching FormSum shapes" )
@@ -572,9 +606,8 @@ def base_form_assembly_visitor(self, expr, tensor, *args):
572606 # Occur in situations such as Interpolate composition
573607 operand = assembled_operand [0 ]
574608
575- reconstruct_interp = expr ._ufl_expr_reconstruct_
576609 if (v , operand ) != expr .argument_slots ():
577- expr = reconstruct_interp (operand , v = v )
610+ expr = expr . _ufl_expr_reconstruct_ (operand , v = v )
578611
579612 rank = len (expr .arguments ())
580613 if rank > 2 :
@@ -591,8 +624,8 @@ def base_form_assembly_visitor(self, expr, tensor, *args):
591624 else :
592625 raise TypeError (f"Unrecognised BaseForm instance: { expr } " )
593626
594- def assembled_matrix (self , expr , petscmat ):
595- return matrix .AssembledMatrix (expr .arguments (), self . _bcs , petscmat ,
627+ def assembled_matrix (self , expr , bcs , petscmat ):
628+ return matrix .AssembledMatrix (expr .arguments (), bcs , petscmat ,
596629 options_prefix = self ._options_prefix )
597630
598631 @staticmethod
@@ -1441,10 +1474,11 @@ def _apply_bc(self, tensor, bc, u=None):
14411474 index = 0 if V .index is None else V .index
14421475 space = V if V .parent is None else V .parent
14431476 if isinstance (bc , DirichletBC ):
1444- if space != spaces [0 ]:
1445- raise TypeError ("bc space does not match the test function space" )
1446- elif space != spaces [1 ]:
1447- raise TypeError ("bc space does not match the trial function space" )
1477+ if not any (space == fs for fs in spaces ):
1478+ raise TypeError ("bc space does not match the test or trial function space" )
1479+ if spaces [0 ] != spaces [1 ]:
1480+ # Not on a diagonal block, we cannot set diagonal entries
1481+ return
14481482
14491483 # Set diagonal entries on bc nodes to 1 if the current
14501484 # block is on the matrix diagonal and its index matches the
0 commit comments