Skip to content

Commit 416e713

Browse files
authored
fix[next]: symbol clash in inline_lambda (#2134)
Before this PR `(λ(x, y) → f(x, x + y))(y + y, x)` with `opcount_preserving=True` would inline to `(λ(x) → f(x, x + x))(y + y)` because the outer `x` vs the lambda parameter `x` was not properly handled. Now, in case of such a conflict, the lambda is first transformed to `λ(x_, y) → f(x_, x_ + y)`.
1 parent 8811635 commit 416e713

File tree

2 files changed

+23
-8
lines changed

2 files changed

+23
-8
lines changed

src/gt4py/next/iterator/transforms/inline_lambdas.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -71,9 +71,9 @@ def inline_lambda( # see todo above
7171
if eligible
7272
)
7373
)
74-
syms: set[str] = node.fun.expr.pre_walk_values().if_isinstance(ir.Sym).getattr("id").to_set()
74+
syms: set[str] = node.fun.pre_walk_values().if_isinstance(ir.Sym).getattr("id").to_set()
7575
clashes = refs & syms
76-
expr = node.fun.expr
76+
fun = node.fun
7777
if clashes:
7878
# TODO(tehrengruber): find a better way of generating new symbols in `name_map` that don't collide with each other. E.g. this must still work:
7979
# (lambda arg, arg_: (lambda arg_: ...)(arg))(a, b) # noqa: ERA001 [commented-out-code]
@@ -82,24 +82,25 @@ def inline_lambda( # see todo above
8282
for sym in clashes:
8383
name_map[sym] = ir_misc.unique_symbol(sym, refs | syms | {*name_map.values()})
8484

85-
expr = RenameSymbols().visit(expr, name_map=name_map)
85+
# Let's rename the symbols (including params) of the function.
86+
# If we would like to preserve the original param names, we could alternatively
87+
# rename the eligible symrefs in `args`.
88+
fun = RenameSymbols().visit(fun, name_map=name_map)
8689

8790
symbol_map = {
8891
param.id: arg
89-
for param, arg, eligible in zip(node.fun.params, node.args, eligible_params)
92+
for param, arg, eligible in zip(fun.params, node.args, eligible_params)
9093
if eligible
9194
}
92-
new_expr = RemapSymbolRefs().visit(expr, symbol_map=symbol_map)
95+
new_expr = RemapSymbolRefs().visit(fun.expr, symbol_map=symbol_map)
9396

9497
if all(eligible_params):
9598
new_expr.location = node.location
9699
else:
97100
new_expr = ir.FunCall(
98101
fun=ir.Lambda(
99102
params=[
100-
param
101-
for param, eligible in zip(node.fun.params, eligible_params)
102-
if not eligible
103+
param for param, eligible in zip(fun.params, eligible_params) if not eligible
103104
],
104105
expr=new_expr,
105106
),

tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lambdas.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,20 @@
5555
False: im.plus(im.call("opaque")(), im.call("opaque")()),
5656
},
5757
),
58+
(
59+
# symbol clash when partially inlining (opcount preserving)
60+
"symbol_clash",
61+
# (λ(x, y) → f(x, x + y))(y + y, x)
62+
im.call(im.lambda_("x", "y")(im.call("f")("x", im.plus("x", "y"))))(im.plus("y", "y"), "x"),
63+
{
64+
# (λ(x_) → f(x_, x_ + x))(y + y)
65+
True: im.call(im.lambda_("x_")(im.call("f")("x_", im.plus("x_", "x"))))(
66+
im.plus("y", "y")
67+
),
68+
# f(y + y, (y + y) + x) # noqa: ERA001
69+
False: im.call("f")(im.plus("y", "y"), im.plus(im.plus("y", "y"), "x")),
70+
},
71+
),
5872
]
5973

6074

0 commit comments

Comments
 (0)