Skip to content

Commit 17783e1

Browse files
committed
update docs and add some missing expressions to EqualityMapper
1 parent 2141b26 commit 17783e1

File tree

2 files changed

+77
-37
lines changed

2 files changed

+77
-37
lines changed

pymbolic/mapper/equality.py

Lines changed: 74 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -27,51 +27,93 @@
2727

2828

2929
class EqualityMapper(Mapper):
30+
__slots__ = ["_ids_to_result"]
31+
3032
def __init__(self) -> None:
3133
self._ids_to_result: Dict[Tuple[int, int], bool] = {}
3234

3335
def rec(self, expr: Any, other: Any) -> bool:
3436
key = (id(expr), id(other))
35-
36-
try:
37-
result = self._ids_to_result[key]
38-
except KeyError:
39-
if expr is other:
40-
result = True
41-
elif expr.__class__ != other.__class__:
42-
result = False
43-
elif hash(expr) != hash(other):
44-
result = False
45-
else:
46-
try:
47-
method = getattr(self, expr.mapper_method)
48-
except AttributeError:
49-
if isinstance(expr, Expression):
50-
result = self.handle_unsupported_expression(expr, other)
51-
else:
52-
result = self.map_foreign(expr, other)
37+
if key in self._ids_to_result:
38+
return self._ids_to_result[key]
39+
40+
if expr is other:
41+
result = True
42+
elif expr.__class__ != other.__class__:
43+
result = False
44+
else:
45+
try:
46+
method = getattr(self, expr.mapper_method)
47+
except AttributeError:
48+
if isinstance(expr, Expression):
49+
result = self.handle_unsupported_expression(expr, other)
5350
else:
54-
result = method(expr, other)
55-
56-
self._ids_to_result[key] = result
51+
result = self.map_foreign(expr, other)
52+
else:
53+
result = method(expr, other)
5754

55+
self._ids_to_result[key] = result
5856
return result
5957

6058
def __call__(self, expr: Any, other: Any) -> bool:
6159
return self.rec(expr, other)
6260

61+
# {{{ handle_unsupported_expression
62+
6363
def handle_unsupported_expression(self, expr, other) -> bool:
6464
eq = expr.make_equality_mapper()
6565
if type(self) == type(eq):
6666
raise UnsupportedExpressionError(
6767
"'{}' cannot handle expressions of type '{}'".format(
6868
type(self).__name__, type(expr).__name__))
6969

70+
# NOTE: this may look fishy, but we want to preserve the cache as we
71+
# go through the expression tree, so that it does not do
72+
# unnecessary checks when we change the mapper for some subclass
7073
eq._ids_to_result = self._ids_to_result
74+
7175
return eq(expr, other)
7276

73-
def map_constant(self, expr, other) -> bool:
74-
return expr == other
77+
# }}}
78+
79+
# {{{ foreign
80+
81+
def map_tuple(self, expr, other) -> bool:
82+
return (
83+
len(expr) == len(other)
84+
and all(self.rec(el, other_el)
85+
for el, other_el in zip(expr, other)))
86+
87+
def map_foreign(self, expr, other) -> bool:
88+
from pymbolic.primitives import VALID_CONSTANT_CLASSES
89+
90+
if isinstance(expr, VALID_CONSTANT_CLASSES):
91+
return expr == other
92+
elif isinstance(expr, tuple):
93+
return self.map_tuple(expr, other)
94+
else:
95+
raise ValueError(
96+
f"{type(self).__name__} encountered invalid foreign object: "
97+
f"{expr!r}")
98+
99+
# }}}
100+
101+
# {{{
102+
103+
# NOTE: `type(expr) == type(other)` is checked in `__call__`, so the
104+
# checks below can assume that the two operands always have the same type
105+
106+
# NOTE: as much as possible, these should try to put the "cheap" checks
107+
# first so that the shortcircuiting removes the need to to extra work
108+
109+
# NOTE: `all` is also shortcircuiting, so should be better to use a
110+
# generator there to avoid extra work
111+
112+
def map_wildcard(self, expr, other) -> bool:
113+
return True
114+
115+
def map_function_symbol(self, expr, other) -> bool:
116+
return True
75117

76118
def map_variable(self, expr, other) -> bool:
77119
return expr.name == other.name
@@ -112,6 +154,7 @@ def map_sum(self, expr, other) -> bool:
112154
for child, other_child in zip(expr.children, other.children))
113155
)
114156

157+
map_slice = map_sum
115158
map_product = map_sum
116159
map_min = map_sum
117160
map_max = map_sum
@@ -143,8 +186,8 @@ def map_power(self, expr, other) -> bool:
143186

144187
def map_left_shift(self, expr, other) -> bool:
145188
return (
146-
self.rec(expr.shiftee, other.shiftee)
147-
and self.rec(expr.shift, other.shift))
189+
self.rec(expr.shift, other.shift)
190+
and self.rec(expr.shiftee, other.shiftee))
148191

149192
map_right_shift = map_left_shift
150193

@@ -160,6 +203,12 @@ def map_if(self, expr, other) -> bool:
160203
and self.rec(expr.then, other.then)
161204
and self.rec(expr.else_, other.else_))
162205

206+
def map_if_positive(self, expr, other) -> bool:
207+
return (
208+
self.rec(expr.criterion, other.criterion)
209+
and self.rec(expr.then, other.then)
210+
and self.rec(expr.else_, other.else_))
211+
163212
def map_common_subexpression(self, expr, other) -> bool:
164213
return (
165214
expr.prefix == other.prefix
@@ -188,12 +237,4 @@ def map_polynomial(self, expr, other) -> bool:
188237
self.rec(expr.Base, other.Data)
189238
and self.rec(expr.Data, other.Data))
190239

191-
# {{{ foreign
192-
193-
def map_tuple(self, expr, other) -> bool:
194-
return (
195-
len(expr) == len(other)
196-
and all(self.rec(el, other_el)
197-
for el, other_el in zip(expr, other)))
198-
199240
# }}}

pymbolic/primitives.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ class Expression:
193193
.. automethod:: make_stringifier
194194
195195
.. automethod:: __eq__
196-
.. automethod:: is_equal
196+
.. automethod:: make_equality_mapper
197197
.. automethod:: __hash__
198198
.. automethod:: get_hash
199199
.. automethod:: __str__
@@ -499,11 +499,10 @@ def __repr__(self):
499499
# {{{ hash/equality interface
500500

501501
def __eq__(self, other):
502-
"""Provides equality testing with quick positive and negative paths
503-
based on :func:`id` and :meth:`__hash__`.
502+
"""Provides equality testing with quick positive and negative paths.
504503
505504
Subclasses should generally not override this method, but instead
506-
provide an implementation of :meth:`is_equal`.
505+
provide an implementation of :meth:`make_equality_mapper`.
507506
"""
508507
return self.make_equality_mapper()(self, other)
509508

0 commit comments

Comments
 (0)