diff --git a/src/gt4py/next/iterator/transforms/inline_lambdas.py b/src/gt4py/next/iterator/transforms/inline_lambdas.py index b7eb45d156..a41f74ebc1 100644 --- a/src/gt4py/next/iterator/transforms/inline_lambdas.py +++ b/src/gt4py/next/iterator/transforms/inline_lambdas.py @@ -71,9 +71,9 @@ def inline_lambda( # see todo above if eligible ) ) - syms: set[str] = node.fun.expr.pre_walk_values().if_isinstance(ir.Sym).getattr("id").to_set() + syms: set[str] = node.fun.pre_walk_values().if_isinstance(ir.Sym).getattr("id").to_set() clashes = refs & syms - expr = node.fun.expr + fun = node.fun if clashes: # TODO(tehrengruber): find a better way of generating new symbols in `name_map` that don't collide with each other. E.g. this must still work: # (lambda arg, arg_: (lambda arg_: ...)(arg))(a, b) # noqa: ERA001 [commented-out-code] @@ -82,14 +82,17 @@ def inline_lambda( # see todo above for sym in clashes: name_map[sym] = ir_misc.unique_symbol(sym, refs | syms | {*name_map.values()}) - expr = RenameSymbols().visit(expr, name_map=name_map) + # Let's rename the symbols (including params) of the function. + # If we would like to preserve the original param names, we could alternatively + # rename the eligible symrefs in `args`. + fun = RenameSymbols().visit(fun, name_map=name_map) symbol_map = { param.id: arg - for param, arg, eligible in zip(node.fun.params, node.args, eligible_params) + for param, arg, eligible in zip(fun.params, node.args, eligible_params) if eligible } - new_expr = RemapSymbolRefs().visit(expr, symbol_map=symbol_map) + new_expr = RemapSymbolRefs().visit(fun.expr, symbol_map=symbol_map) if all(eligible_params): new_expr.location = node.location @@ -97,9 +100,7 @@ def inline_lambda( # see todo above new_expr = ir.FunCall( fun=ir.Lambda( params=[ - param - for param, eligible in zip(node.fun.params, eligible_params) - if not eligible + param for param, eligible in zip(fun.params, eligible_params) if not eligible ], expr=new_expr, ), diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lambdas.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lambdas.py index c10d48ad06..005e9b047c 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lambdas.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lambdas.py @@ -55,6 +55,20 @@ False: im.plus(im.call("opaque")(), im.call("opaque")()), }, ), + ( + # symbol clash when partially inlining (opcount preserving) + "symbol_clash", + # (λ(x, y) → f(x, x + y))(y + y, x) + im.call(im.lambda_("x", "y")(im.call("f")("x", im.plus("x", "y"))))(im.plus("y", "y"), "x"), + { + # (λ(x_) → f(x_, x_ + x))(y + y) + True: im.call(im.lambda_("x_")(im.call("f")("x_", im.plus("x_", "x"))))( + im.plus("y", "y") + ), + # f(y + y, (y + y) + x) # noqa: ERA001 + False: im.call("f")(im.plus("y", "y"), im.plus(im.plus("y", "y"), "x")), + }, + ), ]