5050 CSECachingMapperMixin ,
5151 )
5252import immutables
53+ from pymbolic .mapper .equality import (
54+ EqualityMapper as EqualityMapperBase )
5355from pymbolic .mapper .evaluator import \
5456 CachedEvaluationMapper as EvaluationMapperBase
5557from pymbolic .mapper .substitutor import \
@@ -502,6 +504,60 @@ def map_substitution(self, name, rule, arguments):
502504
503505 return self .rec (expr )
504506
507+
508+ class EqualityMapper (EqualityMapperBase ):
509+ def map_loopy_function_identifier (self , expr , other ) -> bool :
510+ return True
511+
512+ def map_reduction (self , expr , other ) -> bool :
513+ return (
514+ expr .operation == other .operation
515+ and expr .allow_simultaneous == other .allow_simultaneous
516+ and self .rec (expr .expr , other .expr )
517+ and all (iname == other_iname
518+ for iname , other_iname in zip (expr .inames , other .inames )))
519+
520+ def map_group_hw_index (self , expr , other ) -> bool :
521+ return expr .axis == other .axis
522+
523+ map_local_hw_index = map_group_hw_index
524+
525+ def map_linear_subscript (self , expr , other ) -> bool :
526+ return (
527+ self .rec (expr .index , other .index )
528+ and self .rec (expr .aggregate , other .aggregate ))
529+
530+ def map_rule_argument (self , expr , other ) -> bool :
531+ return expr .index == other .index
532+
533+ def map_resolved_function (self , expr , other ) -> bool :
534+ return self .rec (expr .function , other .function )
535+
536+ def map_sub_array_ref (self , expr , other ) -> bool :
537+ return (
538+ len (expr .swept_inames ) == len (other .swept_inames )
539+ and self .rec (expr .subscript , other .subscript )
540+ and all (self .rec (iname , other_iname )
541+ for iname , other_iname in zip (
542+ expr .swept_inames ,
543+ other .swept_inames ))
544+ )
545+
546+ def map_tagged_variable (self , expr , other ) -> bool :
547+ return (
548+ expr .name == other .name
549+ and all (tag == other_tag
550+ for tag , other_tag in zip (expr .tags , other .tags ))
551+ )
552+
553+ def map_type_cast (self , expr , other ) -> bool :
554+ return (
555+ expr .type == other .type
556+ and self .rec (expr .child , other .child ))
557+
558+ def map_fortran_division (self , expr , other ) -> bool :
559+ return self .map_quotient (expr , other )
560+
505561# }}}
506562
507563
@@ -515,15 +571,18 @@ def stringifier(self):
515571 def make_stringifier (self , originating_stringifier = None ):
516572 return StringifyMapper ()
517573
574+ def make_equality_mapper (self ):
575+ return EqualityMapper ()
576+
518577
519578class Literal (LoopyExpressionBase ):
520579 """A literal to be used during code generation.
521580
522581 .. note::
523582
524583 Only used in the output of
525- :mod :`loopy.target.c.codegen.expression.ExpressionToCExpressionMapper` (and
526- similar mappers). Not for use in Loopy source representation.
584+ :class :`loopy.target.c.codegen.expression.ExpressionToCExpressionMapper`
585+ (and similar mappers). Not for use in :mod:`loopy` source representation.
527586 """
528587
529588 def __init__ (self , s ):
@@ -543,8 +602,8 @@ class ArrayLiteral(LoopyExpressionBase):
543602 .. note::
544603
545604 Only used in the output of
546- :mod :`loopy.target.c.codegen.expression.ExpressionToCExpressionMapper` (and
547- similar mappers). Not for use in Loopy source representation.
605+ :class :`loopy.target.c.codegen.expression.ExpressionToCExpressionMapper`
606+ (and similar mappers). Not for use in :mod:`loopy` source representation.
548607 """
549608
550609 def __init__ (self , children ):
@@ -573,8 +632,8 @@ class GroupHardwareAxisIndex(HardwareAxisIndex):
573632 .. note::
574633
575634 Only used in the output of
576- :mod :`loopy.target.c.expression.ExpressionToCExpressionMapper` (and
577- similar mappers). Not for use in Loopy source representation.
635+ :class :`loopy.target.c.codegen. expression.ExpressionToCExpressionMapper`
636+ (and similar mappers). Not for use in :mod:`loopy` source representation.
578637 """
579638 mapper_method = "map_group_hw_index"
580639
@@ -584,8 +643,8 @@ class LocalHardwareAxisIndex(HardwareAxisIndex):
584643 .. note::
585644
586645 Only used in the output of
587- :mod :`loopy.target.c.expression.ExpressionToCExpressionMapper` (and
588- similar mappers). Not for use in Loopy source representation.
646+ :class :`loopy.target.c.expression.ExpressionToCExpressionMapper` (and
647+ similar mappers). Not for use in :mod:`loopy` source representation.
589648 """
590649 mapper_method = "map_local_hw_index"
591650
@@ -792,12 +851,6 @@ def __getinitargs__(self):
792851 def get_hash (self ):
793852 return hash ((self .__class__ , self .operation , self .inames , self .expr ))
794853
795- def is_equal (self , other ):
796- return (other .__class__ == self .__class__
797- and other .operation == self .operation
798- and other .inames == self .inames
799- and other .expr == self .expr )
800-
801854 @property
802855 def is_tuple_typed (self ):
803856 return self .operation .arg_count > 1
@@ -994,14 +1047,6 @@ def __getinitargs__(self):
9941047 def get_hash (self ):
9951048 return hash ((self .__class__ , self .swept_inames , self .subscript ))
9961049
997- def is_equal (self , other ):
998- """
999- Returns *True* iff the sub-array refs have identical expressions.
1000- """
1001- return (other .__class__ == self .__class__
1002- and other .subscript == self .subscript
1003- and other .swept_inames == self .swept_inames )
1004-
10051050 def make_stringifier (self , originating_stringifier = None ):
10061051 return StringifyMapper ()
10071052
0 commit comments