4343 CSECachingMapperMixin ,
4444 )
4545import immutables
46+ from pymbolic .mapper .equality import (
47+ EqualityMapper as EqualityMapperBase )
4648from pymbolic .mapper .evaluator import \
4749 CachedEvaluationMapper as EvaluationMapperBase
4850from pymbolic .mapper .substitutor import \
@@ -501,6 +503,60 @@ def map_substitution(self, name, rule, arguments):
501503
502504 return self .rec (expr )
503505
506+
507+ class EqualityMapper (EqualityMapperBase ):
508+ def map_loopy_function_identifier (self , expr , other ) -> bool :
509+ return True
510+
511+ def map_reduction (self , expr , other ) -> bool :
512+ return (
513+ expr .operation == other .operation
514+ and expr .allow_simultaneous == other .allow_simultaneous
515+ and self .rec (expr .expr , other .expr )
516+ and all (iname == other_iname
517+ for iname , other_iname in zip (expr .inames , other .inames )))
518+
519+ def map_group_hw_index (self , expr , other ) -> bool :
520+ return expr .axis == other .axis
521+
522+ map_local_hw_index = map_group_hw_index
523+
524+ def map_linear_subscript (self , expr , other ) -> bool :
525+ return (
526+ self .rec (expr .index , other .index )
527+ and self .rec (expr .aggregate , other .aggregate ))
528+
529+ def map_rule_argument (self , expr , other ) -> bool :
530+ return expr .index == other .index
531+
532+ def map_resolved_function (self , expr , other ) -> bool :
533+ return self .rec (expr .function , other .function )
534+
535+ def map_sub_array_ref (self , expr , other ) -> bool :
536+ return (
537+ len (expr .swept_inames ) == len (other .swept_inames )
538+ and self .rec (expr .subscript , other .subscript )
539+ and all (self .rec (iname , other_iname )
540+ for iname , other_iname in zip (
541+ expr .swept_inames ,
542+ other .swept_inames ))
543+ )
544+
545+ def map_tagged_variable (self , expr , other ) -> bool :
546+ return (
547+ expr .name == other .name
548+ and all (tag == other_tag
549+ for tag , other_tag in zip (expr .tags , other .tags ))
550+ )
551+
552+ def map_type_cast (self , expr , other ) -> bool :
553+ return (
554+ expr .type == other .type
555+ and self .rec (expr .child , other .child ))
556+
557+ def map_fortran_division (self , expr , other ) -> bool :
558+ return self .map_quotient (expr , other )
559+
504560# }}}
505561
506562
@@ -514,15 +570,18 @@ def stringifier(self):
514570 def make_stringifier (self , originating_stringifier = None ):
515571 return StringifyMapper ()
516572
573+ def make_equality_mapper (self ):
574+ return EqualityMapper ()
575+
517576
518577class Literal (LoopyExpressionBase ):
519578 """A literal to be used during code generation.
520579
521580 .. note::
522581
523582 Only used in the output of
524- :mod :`loopy.target.c.codegen.expression.ExpressionToCExpressionMapper` (and
525- similar mappers). Not for use in Loopy source representation.
583+ :class :`loopy.target.c.codegen.expression.ExpressionToCExpressionMapper`
584+ (and similar mappers). Not for use in :mod:`loopy` source representation.
526585 """
527586
528587 def __init__ (self , s ):
@@ -542,8 +601,8 @@ class ArrayLiteral(LoopyExpressionBase):
542601 .. note::
543602
544603 Only used in the output of
545- :mod :`loopy.target.c.codegen.expression.ExpressionToCExpressionMapper` (and
546- similar mappers). Not for use in Loopy source representation.
604+ :class :`loopy.target.c.codegen.expression.ExpressionToCExpressionMapper`
605+ (and similar mappers). Not for use in :mod:`loopy` source representation.
547606 """
548607
549608 def __init__ (self , children ):
@@ -572,8 +631,8 @@ class GroupHardwareAxisIndex(HardwareAxisIndex):
572631 .. note::
573632
574633 Only used in the output of
575- :mod :`loopy.target.c.expression.ExpressionToCExpressionMapper` (and
576- similar mappers). Not for use in Loopy source representation.
634+ :class :`loopy.target.c.codegen. expression.ExpressionToCExpressionMapper`
635+ (and similar mappers). Not for use in :mod:`loopy` source representation.
577636 """
578637 mapper_method = "map_group_hw_index"
579638
@@ -583,8 +642,8 @@ class LocalHardwareAxisIndex(HardwareAxisIndex):
583642 .. note::
584643
585644 Only used in the output of
586- :mod :`loopy.target.c.expression.ExpressionToCExpressionMapper` (and
587- similar mappers). Not for use in Loopy source representation.
645+ :class :`loopy.target.c.expression.ExpressionToCExpressionMapper` (and
646+ similar mappers). Not for use in :mod:`loopy` source representation.
588647 """
589648 mapper_method = "map_local_hw_index"
590649
@@ -791,12 +850,6 @@ def __getinitargs__(self):
791850 def get_hash (self ):
792851 return hash ((self .__class__ , self .operation , self .inames , self .expr ))
793852
794- def is_equal (self , other ):
795- return (other .__class__ == self .__class__
796- and other .operation == self .operation
797- and other .inames == self .inames
798- and other .expr == self .expr )
799-
800853 @property
801854 def is_tuple_typed (self ):
802855 return self .operation .arg_count > 1
@@ -994,14 +1047,6 @@ def __getinitargs__(self):
9941047 def get_hash (self ):
9951048 return hash ((self .__class__ , self .swept_inames , self .subscript ))
9961049
997- def is_equal (self , other ):
998- """
999- Returns *True* iff the sub-array refs have identical expressions.
1000- """
1001- return (other .__class__ == self .__class__
1002- and other .subscript == self .subscript
1003- and other .swept_inames == self .swept_inames )
1004-
10051050 def make_stringifier (self , originating_stringifier = None ):
10061051 return StringifyMapper ()
10071052
0 commit comments