diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index 628efb001c..8c320e01cd 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -110,7 +110,10 @@ def past_to_gtir(inp: AOT_PRG) -> stages.CompilableProgram: for name, descr in static_arg_descriptors.items() if not any(el is None for el in gtx_utils.flatten_nested_tuple(descr)) # type: ignore[arg-type] } - body = remap_symbols.RemapSymbolRefs().visit(itir_program.body, symbol_map=static_args) + body = [ + remap_symbols.RemapSymbolRefs.apply(stmt, symbol_map=static_args) + for stmt in itir_program.body + ] # type: ignore[arg-type] itir_program = itir.Program( id=itir_program.id, function_definitions=itir_program.function_definitions, diff --git a/src/gt4py/next/iterator/transforms/inline_lambdas.py b/src/gt4py/next/iterator/transforms/inline_lambdas.py index a41f74ebc1..7ac0b964a3 100644 --- a/src/gt4py/next/iterator/transforms/inline_lambdas.py +++ b/src/gt4py/next/iterator/transforms/inline_lambdas.py @@ -7,14 +7,14 @@ # SPDX-License-Identifier: BSD-3-Clause import dataclasses -from typing import Mapping, Optional, TypeVar +from typing import Optional, TypeVar, Mapping from gt4py.eve import NodeTranslator, PreserveLocationVisitor from gt4py.next.iterator import ir -from gt4py.next.iterator.ir_utils import ir_makers as im, misc as ir_misc +from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_applied_lift -from gt4py.next.iterator.transforms.remap_symbols import RemapSymbolRefs, RenameSymbols -from gt4py.next.iterator.transforms.symbol_ref_utils import CountSymbolRefs +from gt4py.next.iterator.transforms import symbol_ref_utils +from gt4py.next.iterator.transforms.remap_symbols import RemapSymbolRefs from gt4py.next.iterator.type_system import inference as itir_inference @@ -34,7 +34,9 @@ def inline_lambda( # see todo above assert len(eligible_params) == len(node.fun.params) == len(node.args) if opcount_preserving: - ref_counts = CountSymbolRefs.apply(node.fun.expr, [p.id for p in node.fun.params]) + ref_counts = symbol_ref_utils.CountSymbolRefs.apply( + node.fun.expr, [p.id for p in node.fun.params] + ) for i, param in enumerate(node.fun.params): # TODO(tehrengruber): allow inlining more complicated zero-op expressions like ignore_shift(...)(it_sym) @@ -64,49 +66,24 @@ def inline_lambda( # see todo above if node.fun.params and not any(eligible_params): return node - refs: set[str] = set().union( - *( - arg.pre_walk_values().if_isinstance(ir.SymRef).getattr("id").to_set() - for arg, eligible in zip(node.args, eligible_params) - if eligible - ) - ) - syms: set[str] = node.fun.pre_walk_values().if_isinstance(ir.Sym).getattr("id").to_set() - clashes = refs & syms - fun = node.fun - if clashes: - # TODO(tehrengruber): find a better way of generating new symbols in `name_map` that don't collide with each other. E.g. this must still work: - # (lambda arg, arg_: (lambda arg_: ...)(arg))(a, b) # noqa: ERA001 [commented-out-code] - name_map: dict[str, str] = {} - - for sym in clashes: - name_map[sym] = ir_misc.unique_symbol(sym, refs | syms | {*name_map.values()}) - - # Let's rename the symbols (including params) of the function. - # If we would like to preserve the original param names, we could alternatively - # rename the eligible symrefs in `args`. - fun = RenameSymbols().visit(fun, name_map=name_map) - symbol_map = { - param.id: arg - for param, arg, eligible in zip(fun.params, node.args, eligible_params) + str(param.id): arg + for param, arg, eligible in zip(node.fun.params, node.args, eligible_params) if eligible } - new_expr = RemapSymbolRefs().visit(fun.expr, symbol_map=symbol_map) + + new_fun_proto = im.lambda_( + *(param for param, eligible in zip(node.fun.params, eligible_params) if not eligible) + )(node.fun.expr) + new_fun_proto = RemapSymbolRefs.apply(new_fun_proto, symbol_map=symbol_map) + new_expr = im.call(new_fun_proto)( + *(arg for arg, eligible in zip(node.args, eligible_params) if not eligible) + ) if all(eligible_params): + new_expr = new_expr.fun.expr new_expr.location = node.location - else: - new_expr = ir.FunCall( - fun=ir.Lambda( - params=[ - param for param, eligible in zip(fun.params, eligible_params) if not eligible - ], - expr=new_expr, - ), - args=[arg for arg, eligible in zip(node.args, eligible_params) if not eligible], - location=node.location, - ) + for attr in ("type", "recorded_shifts", "domain"): if hasattr(node.annex, attr): setattr(new_expr.annex, attr, getattr(node.annex, attr)) diff --git a/src/gt4py/next/iterator/transforms/remap_symbols.py b/src/gt4py/next/iterator/transforms/remap_symbols.py index 5495f63ae1..a082543089 100644 --- a/src/gt4py/next/iterator/transforms/remap_symbols.py +++ b/src/gt4py/next/iterator/transforms/remap_symbols.py @@ -10,6 +10,8 @@ from gt4py.eve import NodeTranslator, PreserveLocationVisitor, SymbolTableTrait from gt4py.next.iterator import ir +from gt4py.next.iterator.ir_utils import ir_makers as im, misc as ir_misc +from gt4py.next.iterator.transforms import symbol_ref_utils from gt4py.next.iterator.type_system import inference as type_inference @@ -17,13 +19,50 @@ class RemapSymbolRefs(PreserveLocationVisitor, NodeTranslator): # This pass preserves, but doesn't use the `type`, `recorded_shifts`, `domain` annex. PRESERVED_ANNEX_ATTRS = ("type", "recorded_shifts", "domain") - def visit_SymRef(self, node: ir.SymRef, *, symbol_map: Dict[str, ir.Node]): + @classmethod + def apply(cls, node: ir.Node, symbol_map: Dict[str, ir.Node]): + external_symbols = set().union( + *(symbol_ref_utils.collect_symbol_refs(expr) for expr in [node, *symbol_map.values()]) + ) + return cls().visit(node, symbol_map=symbol_map, reserved_params=external_symbols) + + def visit_SymRef( + self, node: ir.SymRef, *, symbol_map: Dict[str, ir.Node], reserved_params: set[str] + ): return symbol_map.get(str(node.id), node) - def visit_Lambda(self, node: ir.Lambda, *, symbol_map: Dict[str, ir.Node]): + def visit_Lambda( + self, node: ir.Lambda, *, symbol_map: Dict[str, ir.Node], reserved_params: set[str] + ): params = {str(p.id) for p in node.params} - new_symbol_map = {k: v for k, v in symbol_map.items() if k not in params} - return ir.Lambda(params=node.params, expr=self.visit(node.expr, symbol_map=new_symbol_map)) + + clashes = params & reserved_params + if clashes: + reserved_params = {*reserved_params} + new_symbol_map: Dict[str, ir.Node] = {} + new_params: list[ir.Sym] = [] + for param in node.params: + if param.id in clashes: + new_param = im.sym(ir_misc.unique_symbol(param.id, reserved_params), param.type) + assert new_param.id not in symbol_map + new_symbol_map[param.id] = im.ref(new_param.id, param.type) + reserved_params.add(new_param.id) + else: + new_param = param + new_params.append(new_param) + + new_symbol_map = symbol_map | new_symbol_map + else: + new_params = node.params # keep params as is + new_symbol_map = symbol_map + + filtered_symbol_map = {k: v for k, v in new_symbol_map.items() if k not in new_params} + return ir.Lambda( + params=new_params, + expr=self.visit( + node.expr, symbol_map=filtered_symbol_map, reserved_params=reserved_params + ), + ) def generic_visit(self, node: ir.Node, **kwargs: Any): # type: ignore[override] assert isinstance(node, SymbolTableTrait) == isinstance(node, ir.Lambda), ( diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lambdas.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lambdas.py index 005e9b047c..12e140f0c7 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lambdas.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lambdas.py @@ -40,6 +40,19 @@ ), im.multiplies_(im.plus(2, 1), im.plus("x", "x")), ), + ( + "name_shadowing_external", + im.call(im.lambda_("x")(im.lambda_("y")(im.plus("x", "y"))))(im.plus("x", "y")), + im.lambda_("y_")(im.plus(im.plus("x", "y"), "y_")), + ), + ( + "renaming_collision", + # the `y` param of the lambda may not be renamed to `y_` as this name is already referenced + im.call(im.lambda_("x")(im.lambda_("y")(im.plus(im.plus("x", "y"), "y_"))))( + im.plus("x", "y") + ), + im.lambda_("y__")(im.plus(im.plus(im.plus("x", "y"), "y__"), "y_")), + ), ( # ensure opcount preserving option works whether `itir.SymRef` has a type or not "typed_ref",