Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions src/gt4py/next/iterator/transforms/inline_lambdas.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,9 @@ def inline_lambda( # see todo above
if eligible
)
)
syms: set[str] = node.fun.expr.pre_walk_values().if_isinstance(ir.Sym).getattr("id").to_set()
syms: set[str] = node.fun.pre_walk_values().if_isinstance(ir.Sym).getattr("id").to_set()
clashes = refs & syms
expr = node.fun.expr
fun = node.fun
if clashes:
# TODO(tehrengruber): find a better way of generating new symbols in `name_map` that don't collide with each other. E.g. this must still work:
# (lambda arg, arg_: (lambda arg_: ...)(arg))(a, b) # noqa: ERA001 [commented-out-code]
Expand All @@ -82,24 +82,25 @@ def inline_lambda( # see todo above
for sym in clashes:
name_map[sym] = ir_misc.unique_symbol(sym, refs | syms | {*name_map.values()})

expr = RenameSymbols().visit(expr, name_map=name_map)
# Let's rename the symbols (including params) of the function.
# If we would like to preserve the original param names, we could alternatively
# rename the eligible symrefs in `args`.
fun = RenameSymbols().visit(fun, name_map=name_map)

symbol_map = {
param.id: arg
for param, arg, eligible in zip(node.fun.params, node.args, eligible_params)
for param, arg, eligible in zip(fun.params, node.args, eligible_params)
if eligible
}
new_expr = RemapSymbolRefs().visit(expr, symbol_map=symbol_map)
new_expr = RemapSymbolRefs().visit(fun.expr, symbol_map=symbol_map)

if all(eligible_params):
new_expr.location = node.location
else:
new_expr = ir.FunCall(
fun=ir.Lambda(
params=[
param
for param, eligible in zip(node.fun.params, eligible_params)
if not eligible
param for param, eligible in zip(fun.params, eligible_params) if not eligible
],
expr=new_expr,
),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,20 @@
False: im.plus(im.call("opaque")(), im.call("opaque")()),
},
),
(
# symbol clash when partially inlining (opcount preserving)
"symbol_clash",
# (λ(x, y) → f(x, x + y))(y + y, x)
im.call(im.lambda_("x", "y")(im.call("f")("x", im.plus("x", "y"))))(im.plus("y", "y"), "x"),
{
# (λ(x_) → f(x_, x_ + x))(y + y)
True: im.call(im.lambda_("x_")(im.call("f")("x_", im.plus("x_", "x"))))(
im.plus("y", "y")
),
# f(y + y, (y + y) + x) # noqa: ERA001
False: im.call("f")(im.plus("y", "y"), im.plus(im.plus("y", "y"), "x")),
},
),
]


Expand Down