From 8635b2c45074f71c636c7e60969afc9e177c7b31 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Thu, 10 Jul 2025 23:55:34 +0200 Subject: [PATCH 1/3] fix[next]: symbol clash in inline_lambda --- .../iterator/transforms/inline_lambdas.py | 17 +-- .../transforms_tests/test_inline_lambdas.py | 109 ++++++++++++++++++ 2 files changed, 118 insertions(+), 8 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/inline_lambdas.py b/src/gt4py/next/iterator/transforms/inline_lambdas.py index b7eb45d156..a41f74ebc1 100644 --- a/src/gt4py/next/iterator/transforms/inline_lambdas.py +++ b/src/gt4py/next/iterator/transforms/inline_lambdas.py @@ -71,9 +71,9 @@ def inline_lambda( # see todo above if eligible ) ) - syms: set[str] = node.fun.expr.pre_walk_values().if_isinstance(ir.Sym).getattr("id").to_set() + syms: set[str] = node.fun.pre_walk_values().if_isinstance(ir.Sym).getattr("id").to_set() clashes = refs & syms - expr = node.fun.expr + fun = node.fun if clashes: # TODO(tehrengruber): find a better way of generating new symbols in `name_map` that don't collide with each other. E.g. this must still work: # (lambda arg, arg_: (lambda arg_: ...)(arg))(a, b) # noqa: ERA001 [commented-out-code] @@ -82,14 +82,17 @@ def inline_lambda( # see todo above for sym in clashes: name_map[sym] = ir_misc.unique_symbol(sym, refs | syms | {*name_map.values()}) - expr = RenameSymbols().visit(expr, name_map=name_map) + # Let's rename the symbols (including params) of the function. + # If we would like to preserve the original param names, we could alternatively + # rename the eligible symrefs in `args`. + fun = RenameSymbols().visit(fun, name_map=name_map) symbol_map = { param.id: arg - for param, arg, eligible in zip(node.fun.params, node.args, eligible_params) + for param, arg, eligible in zip(fun.params, node.args, eligible_params) if eligible } - new_expr = RemapSymbolRefs().visit(expr, symbol_map=symbol_map) + new_expr = RemapSymbolRefs().visit(fun.expr, symbol_map=symbol_map) if all(eligible_params): new_expr.location = node.location @@ -97,9 +100,7 @@ def inline_lambda( # see todo above new_expr = ir.FunCall( fun=ir.Lambda( params=[ - param - for param, eligible in zip(node.fun.params, eligible_params) - if not eligible + param for param, eligible in zip(fun.params, eligible_params) if not eligible ], expr=new_expr, ), diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lambdas.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lambdas.py index c10d48ad06..00f1fb1a1b 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lambdas.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lambdas.py @@ -55,6 +55,17 @@ False: im.plus(im.call("opaque")(), im.call("opaque")()), }, ), + ( + # symbol clash when partially inlining (opcount preserving) + "symbol_clash", + im.call(im.lambda_("x", "y")(im.call("f")("x", im.plus("x", "y"))))(im.plus("y", "y"), "x"), + { + True: im.call(im.lambda_("x_")(im.call("f")("x_", im.plus("x_", "x"))))( + im.plus("y", "y") + ), + False: im.call("f")(im.plus("y", "y"), im.plus(im.plus("y", "y"), "x")), + }, + ), ] @@ -91,3 +102,101 @@ def test_type_preservation(): testee.type = testee.annex.type = ts.ScalarType(kind=ts.ScalarKind.FLOAT32) inlined = InlineLambdas.apply(testee) assert inlined.type == inlined.annex.type == ts.ScalarType(kind=ts.ScalarKind.FLOAT32) + + +def test_dbg1(): + # (λ(foo, bar) → + # (λ(foo, bar) → + # if_(c0, + # foo, + # if_(c1, foo, bar)))( + # if_(c2, foo, bar), foo + # ))(a, b) + + testee = im.call( + im.lambda_("foo", "bar")( + im.call( + im.lambda_("foo", "bar")( + im.if_(im.ref("c0"), "foo", im.if_(im.ref("c1"), "foo", "bar")) + ) + )( + im.if_(im.ref("c2"), "foo", "bar"), + "foo", + ) + ) + )("a", "b") + print(testee) + inlined = InlineLambdas.apply(testee, opcount_preserving=True) + print(inlined) + inlined = InlineLambdas.apply(inlined, opcount_preserving=False) + print(inlined) + + # if c0 then b else if c1 then b else b + + # expected: + # if_(c0, + # if_(c2, a, b), + # if_(c1, + # if_(c2, a, b), + # a) + # ) + expected = im.if_( + im.ref("c0"), + im.if_(im.ref("c2"), "a", "b"), + im.if_(im.ref("c1"), im.if_(im.ref("c2"), "a", "b"), "a"), + ) + print(expected) + assert inlined == expected + + +# def test_dbg2(): +# testee = im.call( +# im.lambda_("x", "y")(im.multiplies_(im.call(im.lambda_("x")(im.plus("x", 1)))("y"), "x")) +# )(im.plus("x", "x"), "x") + +# print(testee) +# inlined = InlineLambdas.apply(testee, opcount_preserving=True) +# print(inlined) + + +def test_dbg2(): + testee = im.call( + im.lambda_("x", "y")( + im.call(im.lambda_("x", "y")(im.call("f")("x", im.plus("x", "y"))))( + im.plus("y", "y"), "x" + ) + ) + )("a", "b") + + print(testee) + inlined = InlineLambdas.apply(testee, opcount_preserving=True) + print(inlined) + inlined = InlineLambdas.apply(inlined, opcount_preserving=False) + print(inlined) + + direct = InlineLambdas.apply(testee, opcount_preserving=False) + print(direct) + + +def test_dbg3(): + testee = im.call(im.lambda_("x", "y")(im.call("f")("x", im.plus("x", "y"))))( + im.plus("y", "y"), "x" + ) + + print(testee) + # inlined = InlineLambdas.apply(testee, opcount_preserving=True) + inlined = InlineLambdas.apply(testee, opcount_preserving=True) + print(inlined) + # inlined = inline_lambda(testee, opcount_preserving=False) + # print(inlined) + + # expected = (λ(x) → f(x_, x_ + x))(y + y) + expected = im.call(im.lambda_("x_")(im.call("f")("x_", im.plus("x_", "x"))))(im.plus("y", "y")) + print(expected) + assert inlined == expected + + # inlined = InlineLambdas.apply(inlined, opcount_preserving=False) + # print(inlined) + + # direct = InlineLambdas.apply(testee, opcount_preserving=False) + # print(direct) From c72f4945e5d0c98dc50c7900f822e57cb8474106 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Fri, 11 Jul 2025 07:15:05 +0200 Subject: [PATCH 2/3] cleanup --- .../transforms_tests/test_inline_lambdas.py | 98 ------------------- 1 file changed, 98 deletions(-) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lambdas.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lambdas.py index 00f1fb1a1b..314d14cd7f 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lambdas.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lambdas.py @@ -102,101 +102,3 @@ def test_type_preservation(): testee.type = testee.annex.type = ts.ScalarType(kind=ts.ScalarKind.FLOAT32) inlined = InlineLambdas.apply(testee) assert inlined.type == inlined.annex.type == ts.ScalarType(kind=ts.ScalarKind.FLOAT32) - - -def test_dbg1(): - # (λ(foo, bar) → - # (λ(foo, bar) → - # if_(c0, - # foo, - # if_(c1, foo, bar)))( - # if_(c2, foo, bar), foo - # ))(a, b) - - testee = im.call( - im.lambda_("foo", "bar")( - im.call( - im.lambda_("foo", "bar")( - im.if_(im.ref("c0"), "foo", im.if_(im.ref("c1"), "foo", "bar")) - ) - )( - im.if_(im.ref("c2"), "foo", "bar"), - "foo", - ) - ) - )("a", "b") - print(testee) - inlined = InlineLambdas.apply(testee, opcount_preserving=True) - print(inlined) - inlined = InlineLambdas.apply(inlined, opcount_preserving=False) - print(inlined) - - # if c0 then b else if c1 then b else b - - # expected: - # if_(c0, - # if_(c2, a, b), - # if_(c1, - # if_(c2, a, b), - # a) - # ) - expected = im.if_( - im.ref("c0"), - im.if_(im.ref("c2"), "a", "b"), - im.if_(im.ref("c1"), im.if_(im.ref("c2"), "a", "b"), "a"), - ) - print(expected) - assert inlined == expected - - -# def test_dbg2(): -# testee = im.call( -# im.lambda_("x", "y")(im.multiplies_(im.call(im.lambda_("x")(im.plus("x", 1)))("y"), "x")) -# )(im.plus("x", "x"), "x") - -# print(testee) -# inlined = InlineLambdas.apply(testee, opcount_preserving=True) -# print(inlined) - - -def test_dbg2(): - testee = im.call( - im.lambda_("x", "y")( - im.call(im.lambda_("x", "y")(im.call("f")("x", im.plus("x", "y"))))( - im.plus("y", "y"), "x" - ) - ) - )("a", "b") - - print(testee) - inlined = InlineLambdas.apply(testee, opcount_preserving=True) - print(inlined) - inlined = InlineLambdas.apply(inlined, opcount_preserving=False) - print(inlined) - - direct = InlineLambdas.apply(testee, opcount_preserving=False) - print(direct) - - -def test_dbg3(): - testee = im.call(im.lambda_("x", "y")(im.call("f")("x", im.plus("x", "y"))))( - im.plus("y", "y"), "x" - ) - - print(testee) - # inlined = InlineLambdas.apply(testee, opcount_preserving=True) - inlined = InlineLambdas.apply(testee, opcount_preserving=True) - print(inlined) - # inlined = inline_lambda(testee, opcount_preserving=False) - # print(inlined) - - # expected = (λ(x) → f(x_, x_ + x))(y + y) - expected = im.call(im.lambda_("x_")(im.call("f")("x_", im.plus("x_", "x"))))(im.plus("y", "y")) - print(expected) - assert inlined == expected - - # inlined = InlineLambdas.apply(inlined, opcount_preserving=False) - # print(inlined) - - # direct = InlineLambdas.apply(testee, opcount_preserving=False) - # print(direct) From 95d0bab44983724b0092dcef39158bdb4bcea3b2 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Fri, 11 Jul 2025 09:15:02 +0200 Subject: [PATCH 3/3] add pretty printed example --- .../iterator_tests/transforms_tests/test_inline_lambdas.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lambdas.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lambdas.py index 314d14cd7f..005e9b047c 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lambdas.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lambdas.py @@ -58,11 +58,14 @@ ( # symbol clash when partially inlining (opcount preserving) "symbol_clash", + # (λ(x, y) → f(x, x + y))(y + y, x) im.call(im.lambda_("x", "y")(im.call("f")("x", im.plus("x", "y"))))(im.plus("y", "y"), "x"), { + # (λ(x_) → f(x_, x_ + x))(y + y) True: im.call(im.lambda_("x_")(im.call("f")("x_", im.plus("x_", "x"))))( im.plus("y", "y") ), + # f(y + y, (y + y) + x) # noqa: ERA001 False: im.call("f")(im.plus("y", "y"), im.plus(im.plus("y", "y"), "x")), }, ),