3131 )
3232from pymbolic .mapper .dependency import (
3333 DependencyMapper as DependencyMapperBase )
34+ from pymbolic .mapper .equality import (
35+ EqualityMapper as EqualityMapperBase )
3436from pymbolic .geometric_algebra .mapper import (
3537 CombineMapper as CombineMapperBase ,
3638 IdentityMapper as IdentityMapperBase ,
5153import pytential .symbolic .primitives as prim
5254
5355
56+ # {{{ IdentityMapper
57+
5458def rec_int_g_arguments (mapper , expr ):
5559 densities = mapper .rec (expr .densities )
5660 kernel_arguments = {
@@ -138,6 +142,11 @@ def map_interpolation(self, expr):
138142 return type (expr )(expr .from_dd , expr .to_dd , operand )
139143
140144
145+ # }}}
146+
147+
148+ # {{{ CombineMapper
149+
141150class CombineMapper (CombineMapperBase ):
142151 def map_node_sum (self , expr ):
143152 return self .rec (expr .operand )
@@ -168,6 +177,10 @@ def map_is_shape_class(self, expr):
168177
169178 map_error_expression = map_is_shape_class
170179
180+ # }}}
181+
182+
183+ # {{{ Collector
171184
172185class Collector (CollectorBase , CombineMapper ):
173186 def map_ones (self , expr ):
@@ -186,6 +199,10 @@ def map_int_g(self, expr):
186199class DependencyMapper (DependencyMapperBase , Collector ):
187200 pass
188201
202+ # }}}
203+
204+
205+ # {{{ EvaluationMapper
189206
190207class EvaluationMapper (EvaluationMapperBase ):
191208 """Unlike :mod:`pymbolic.mapper.evaluation.EvaluationMapper`, this class
@@ -249,8 +266,10 @@ def map_common_subexpression(self, expr):
249266 expr .prefix ,
250267 expr .scope )
251268
269+ # }}}
270+
252271
253- # {{{ dofdesc tagging
272+ # {{{ dofdesc tagging: LocationTagger, ToTargetTagger
254273
255274class LocationTagger (CSECachingMapperMixin , IdentityMapper ):
256275 """Used internally by :class:`ToTargetTagger`."""
@@ -655,6 +674,88 @@ def map_int_g(self, expr):
655674# }}}
656675
657676
677+ # {{{ EqualityMapper
678+
679+ class EqualityMapper (EqualityMapperBase ):
680+ def map_ones (self , expr , other ) -> bool :
681+ return expr .dofdesc == other .dofdesc
682+
683+ map_q_weight = map_ones
684+
685+ def map_node_coordinate_component (self , expr , other ) -> bool :
686+ return (
687+ expr .ambient_axis == other .ambient_axis
688+ and expr .dofdesc == other .dofdesc )
689+
690+ def map_num_reference_derivative (self , expr , other ) -> bool :
691+ return (
692+ expr .ref_axes == other .ref_axes
693+ and expr .dofdesc == other .dofdesc
694+ and self .rec (expr .operand , other .operand )
695+ )
696+
697+ def map_node_sum (self , expr , other ) -> bool :
698+ return self .rec (expr .operand , other .operand )
699+
700+ map_node_max = map_node_sum
701+ map_node_min = map_node_sum
702+
703+ def map_elementwise_sum (self , expr , other ) -> bool :
704+ return (
705+ expr .dofdesc == other .dofdesc
706+ and self .rec (expr .operand , other .operand ))
707+
708+ map_elementwise_max = map_elementwise_sum
709+ map_elementwise_min = map_elementwise_sum
710+
711+ def map_int_g (self , expr , other ) -> bool :
712+ import numpy as np
713+
714+ def as_hashable (kernel_arg_value ):
715+ # FIXME: this is here to match the fact that pickled IntGs get
716+ # restored as tuples, not ndarray, so they don't equal anymore
717+ if isinstance (kernel_arg_value , np .ndarray ):
718+ return tuple (kernel_arg_value )
719+ return kernel_arg_value
720+
721+ return (
722+ expr .qbx_forced_limit == other .qbx_forced_limit
723+ and expr .source == other .source
724+ and expr .target == other .target
725+ and len (expr .kernel_arguments ) == len (other .kernel_arguments )
726+ and len (expr .source_kernels ) == len (other .source_kernels )
727+ and len (expr .densities ) == len (other .densities )
728+ and expr .target_kernel == other .target_kernel
729+ and all (knl == other_knl for knl , other_knl in zip (
730+ expr .source_kernels , other .source_kernels )
731+ )
732+ and all (d == other_d for d , other_d in zip (
733+ expr .densities , other .densities ))
734+ and all (k == other_k
735+ and self .rec (as_hashable (v ), as_hashable (other_v ))
736+ for (k , v ), (other_k , other_v ) in zip (
737+ sorted (expr .kernel_arguments .items ()),
738+ sorted (other .kernel_arguments .items ())))
739+ )
740+
741+ def map_interpolation (self , expr , other ) -> bool :
742+ return (
743+ expr .from_dd == other .from_dd
744+ and expr .to_dd == other .to_dd
745+ and self .rec (expr .operand , other .operand ))
746+
747+ def map_is_shape_class (self , expr , other ) -> bool :
748+ return (
749+ expr .shape is other .shape ,
750+ expr .dofdesc == other .dofdesc
751+ )
752+
753+ def map_error_expression (self , expr , other ) -> bool :
754+ return expr .message == other .message
755+
756+ # }}}
757+
758+
658759# {{{ stringifier
659760
660761def stringify_where (where ):
@@ -768,13 +869,13 @@ def map_is_shape_class(self, expr, enclosing_prec):
768869 return "IsShape[{}]({})" .format (stringify_where (expr .dofdesc ),
769870 expr .shape .__name__ )
770871
771- # }}}
772-
773872
774873class PrettyStringifyMapper (
775874 CSESplittingStringifyMapperMixin , StringifyMapper ):
776875 pass
777876
877+ # }}}
878+
778879
779880# {{{ graphviz
780881
0 commit comments