2727
2828
2929class EqualityMapper (Mapper ):
30+ __slots__ = ["_ids_to_result" ]
31+
3032 def __init__ (self ) -> None :
3133 self ._ids_to_result : Dict [Tuple [int , int ], bool ] = {}
3234
3335 def rec (self , expr : Any , other : Any ) -> bool :
3436 key = (id (expr ), id (other ))
35-
36- try :
37- result = self ._ids_to_result [key ]
38- except KeyError :
39- if expr is other :
40- result = True
41- elif expr .__class__ != other .__class__ :
42- result = False
43- elif hash (expr ) != hash (other ):
44- result = False
45- else :
46- try :
47- method = getattr (self , expr .mapper_method )
48- except AttributeError :
49- if isinstance (expr , Expression ):
50- result = self .handle_unsupported_expression (expr , other )
51- else :
52- result = self .map_foreign (expr , other )
37+ if key in self ._ids_to_result :
38+ return self ._ids_to_result [key ]
39+
40+ if expr is other :
41+ result = True
42+ elif expr .__class__ != other .__class__ :
43+ result = False
44+ else :
45+ try :
46+ method = getattr (self , expr .mapper_method )
47+ except AttributeError :
48+ if isinstance (expr , Expression ):
49+ result = self .handle_unsupported_expression (expr , other )
5350 else :
54- result = method (expr , other )
55-
56- self . _ids_to_result [ key ] = result
51+ result = self . map_foreign (expr , other )
52+ else :
53+ result = method ( expr , other )
5754
55+ self ._ids_to_result [key ] = result
5856 return result
5957
6058 def __call__ (self , expr : Any , other : Any ) -> bool :
6159 return self .rec (expr , other )
6260
61+ # {{{ handle_unsupported_expression
62+
6363 def handle_unsupported_expression (self , expr , other ) -> bool :
6464 eq = expr .make_equality_mapper ()
6565 if type (self ) == type (eq ):
6666 raise UnsupportedExpressionError (
6767 "'{}' cannot handle expressions of type '{}'" .format (
6868 type (self ).__name__ , type (expr ).__name__ ))
6969
70+ # NOTE: this may look fishy, but we want to preserve the cache as we
71+ # go through the expression tree, so that it does not do
72+ # unnecessary checks when we change the mapper for some subclass
7073 eq ._ids_to_result = self ._ids_to_result
74+
7175 return eq (expr , other )
7276
73- def map_constant (self , expr , other ) -> bool :
74- return expr == other
77+ # }}}
78+
79+ # {{{ foreign
80+
81+ def map_tuple (self , expr , other ) -> bool :
82+ return (
83+ len (expr ) == len (other )
84+ and all (self .rec (el , other_el )
85+ for el , other_el in zip (expr , other )))
86+
87+ def map_foreign (self , expr , other ) -> bool :
88+ from pymbolic .primitives import VALID_CONSTANT_CLASSES
89+
90+ if isinstance (expr , VALID_CONSTANT_CLASSES ):
91+ return expr == other
92+ elif isinstance (expr , tuple ):
93+ return self .map_tuple (expr , other )
94+ else :
95+ raise ValueError (
96+ f"{ type (self ).__name__ } encountered invalid foreign object: "
97+ f"{ expr !r} " )
98+
99+ # }}}
100+
101+ # {{{
102+
103+ # NOTE: `type(expr) == type(other)` is checked in `__call__`, so the
104+ # checks below can assume that the two operands always have the same type
105+
106+ # NOTE: as much as possible, these should try to put the "cheap" checks
107+ # first so that the shortcircuiting removes the need to to extra work
108+
109+ # NOTE: `all` is also shortcircuiting, so should be better to use a
110+ # generator there to avoid extra work
111+
112+ def map_wildcard (self , expr , other ) -> bool :
113+ return True
114+
115+ def map_function_symbol (self , expr , other ) -> bool :
116+ return True
75117
76118 def map_variable (self , expr , other ) -> bool :
77119 return expr .name == other .name
@@ -112,6 +154,7 @@ def map_sum(self, expr, other) -> bool:
112154 for child , other_child in zip (expr .children , other .children ))
113155 )
114156
157+ map_slice = map_sum
115158 map_product = map_sum
116159 map_min = map_sum
117160 map_max = map_sum
@@ -143,8 +186,8 @@ def map_power(self, expr, other) -> bool:
143186
144187 def map_left_shift (self , expr , other ) -> bool :
145188 return (
146- self .rec (expr .shiftee , other .shiftee )
147- and self .rec (expr .shift , other .shift ))
189+ self .rec (expr .shift , other .shift )
190+ and self .rec (expr .shiftee , other .shiftee ))
148191
149192 map_right_shift = map_left_shift
150193
@@ -160,6 +203,12 @@ def map_if(self, expr, other) -> bool:
160203 and self .rec (expr .then , other .then )
161204 and self .rec (expr .else_ , other .else_ ))
162205
206+ def map_if_positive (self , expr , other ) -> bool :
207+ return (
208+ self .rec (expr .criterion , other .criterion )
209+ and self .rec (expr .then , other .then )
210+ and self .rec (expr .else_ , other .else_ ))
211+
163212 def map_common_subexpression (self , expr , other ) -> bool :
164213 return (
165214 expr .prefix == other .prefix
@@ -188,12 +237,4 @@ def map_polynomial(self, expr, other) -> bool:
188237 self .rec (expr .Base , other .Data )
189238 and self .rec (expr .Data , other .Data ))
190239
191- # {{{ foreign
192-
193- def map_tuple (self , expr , other ) -> bool :
194- return (
195- len (expr ) == len (other )
196- and all (self .rec (el , other_el )
197- for el , other_el in zip (expr , other )))
198-
199240 # }}}
0 commit comments