Skip to content
Draft
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions src/gt4py/next/ffront/past_to_itir.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,10 +96,13 @@ def past_to_gtir(inp: AOT_PRG) -> stages.CompilableProgram:
i: arg.value for i, arg in enumerate(inp.args.args) if isinstance(arg, arguments.StaticArg)
}
static_args = {
itir_program.params[i].id: im.literal_from_tuple_value(value)
str(itir_program.params[i].id): im.literal_from_tuple_value(value)
for i, value in static_args_index.items()
}
body = remap_symbols.RemapSymbolRefs().visit(itir_program.body, symbol_map=static_args)
body = [
remap_symbols.RemapSymbolRefs.apply(stmt, symbol_map=static_args)
for stmt in itir_program.body
] # type: ignore[arg-type]
itir_program = itir.Program(
id=itir_program.id,
function_definitions=itir_program.function_definitions,
Expand Down
60 changes: 17 additions & 43 deletions src/gt4py/next/iterator/transforms/inline_lambdas.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@

from gt4py.eve import NodeTranslator, PreserveLocationVisitor
from gt4py.next.iterator import ir
from gt4py.next.iterator.ir_utils import ir_makers as im
from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_applied_lift
from gt4py.next.iterator.transforms.remap_symbols import RemapSymbolRefs, RenameSymbols
from gt4py.next.iterator.transforms.symbol_ref_utils import CountSymbolRefs
from gt4py.next.iterator.transforms import symbol_ref_utils
from gt4py.next.iterator.transforms.remap_symbols import RemapSymbolRefs
from gt4py.next.iterator.type_system import inference as itir_inference


Expand All @@ -33,7 +34,9 @@ def inline_lambda( # see todo above
assert len(eligible_params) == len(node.fun.params) == len(node.args)

if opcount_preserving:
ref_counts = CountSymbolRefs.apply(node.fun.expr, [p.id for p in node.fun.params])
ref_counts = symbol_ref_utils.CountSymbolRefs.apply(
node.fun.expr, [p.id for p in node.fun.params]
)

for i, param in enumerate(node.fun.params):
# TODO(tehrengruber): allow inlining more complicated zero-op expressions like ignore_shift(...)(it_sym)
Expand Down Expand Up @@ -63,53 +66,24 @@ def inline_lambda( # see todo above
if node.fun.params and not any(eligible_params):
return node

refs = set().union(
*(
arg.pre_walk_values().if_isinstance(ir.SymRef).getattr("id").to_set()
for arg, eligible in zip(node.args, eligible_params)
if eligible
)
)
syms = node.fun.expr.pre_walk_values().if_isinstance(ir.Sym).getattr("id").to_set()
clashes = refs & syms
expr = node.fun.expr
if clashes:
# TODO(tehrengruber): find a better way of generating new symbols in `name_map` that don't collide with each other. E.g. this must still work:
# (lambda arg, arg_: (lambda arg_: ...)(arg))(a, b) # noqa: ERA001 [commented-out-code]
name_map: dict[ir.SymRef, str] = {}

def new_name(name):
while name in refs or name in syms or name in name_map.values():
name += "_"
return name

for sym in clashes:
name_map[sym] = new_name(sym)

expr = RenameSymbols().visit(expr, name_map=name_map)

symbol_map = {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just leaving a remark here, that there was #2134. Maybe it's already resolved here.

param.id: arg
str(param.id): arg
for param, arg, eligible in zip(node.fun.params, node.args, eligible_params)
if eligible
}
new_expr = RemapSymbolRefs().visit(expr, symbol_map=symbol_map)

new_fun_proto = im.lambda_(
*(param for param, eligible in zip(node.fun.params, eligible_params) if not eligible)
)(node.fun.expr)
new_fun_proto = RemapSymbolRefs.apply(new_fun_proto, symbol_map=symbol_map)
new_expr = im.call(new_fun_proto)(
*(arg for arg, eligible in zip(node.args, eligible_params) if not eligible)
)

if all(eligible_params):
new_expr = new_expr.fun.expr
new_expr.location = node.location
else:
new_expr = ir.FunCall(
fun=ir.Lambda(
params=[
param
for param, eligible in zip(node.fun.params, eligible_params)
if not eligible
],
expr=new_expr,
),
args=[arg for arg, eligible in zip(node.args, eligible_params) if not eligible],
location=node.location,
)

for attr in ("type", "recorded_shifts", "domain"):
if hasattr(node.annex, attr):
setattr(new_expr.annex, attr, getattr(node.annex, attr))
Expand Down
53 changes: 49 additions & 4 deletions src/gt4py/next/iterator/transforms/remap_symbols.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,65 @@

from gt4py.eve import NodeTranslator, PreserveLocationVisitor, SymbolTableTrait
from gt4py.next.iterator import ir
from gt4py.next.iterator.ir_utils import ir_makers as im
from gt4py.next.iterator.transforms import symbol_ref_utils
from gt4py.next.iterator.type_system import inference as type_inference


def unique_name(name, prohibited_symbols):
while name in prohibited_symbols:
name += "_"
return name


class RemapSymbolRefs(PreserveLocationVisitor, NodeTranslator):
# This pass preserves, but doesn't use the `type`, `recorded_shifts`, `domain` annex.
PRESERVED_ANNEX_ATTRS = ("type", "recorded_shifts", "domain")

def visit_SymRef(self, node: ir.SymRef, *, symbol_map: Dict[str, ir.Node]):
@classmethod
def apply(cls, node: ir.Node, symbol_map: Dict[str, ir.Node]):
external_symbols = set().union(
*(symbol_ref_utils.collect_symbol_refs(expr) for expr in [node, *symbol_map.values()])
)
return cls().visit(node, symbol_map=symbol_map, reserved_params=external_symbols)

def visit_SymRef(
self, node: ir.SymRef, *, symbol_map: Dict[str, ir.Node], reserved_params: set[str]
):
return symbol_map.get(str(node.id), node)

def visit_Lambda(self, node: ir.Lambda, *, symbol_map: Dict[str, ir.Node]):
def visit_Lambda(
self, node: ir.Lambda, *, symbol_map: Dict[str, ir.Node], reserved_params: set[str]
):
params = {str(p.id) for p in node.params}
new_symbol_map = {k: v for k, v in symbol_map.items() if k not in params}
return ir.Lambda(params=node.params, expr=self.visit(node.expr, symbol_map=new_symbol_map))

clashes = params & reserved_params
if clashes:
reserved_params = {*reserved_params}
new_symbol_map: Dict[str, ir.Node] = {}
new_params: list[ir.Sym] = []
for param in node.params:
if param.id in clashes:
new_param = im.sym(unique_name(param.id, reserved_params), param.type)
assert new_param.id not in symbol_map
new_symbol_map[param.id] = im.ref(new_param.id, param.type)
reserved_params.add(new_param.id)
else:
new_param = param
new_params.append(new_param)

new_symbol_map = symbol_map | new_symbol_map
else:
new_params = node.params # keep params as is
new_symbol_map = symbol_map

filtered_symbol_map = {k: v for k, v in new_symbol_map.items() if k not in new_params}
return ir.Lambda(
params=new_params,
expr=self.visit(
node.expr, symbol_map=filtered_symbol_map, reserved_params=reserved_params
),
)

def generic_visit(self, node: ir.Node, **kwargs: Any): # type: ignore[override]
assert isinstance(node, SymbolTableTrait) == isinstance(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,19 @@
),
im.multiplies_(im.plus(2, 1), im.plus("x", "x")),
),
(
"name_shadowing_external",
im.call(im.lambda_("x")(im.lambda_("y")(im.plus("x", "y"))))(im.plus("x", "y")),
im.lambda_("y_")(im.plus(im.plus("x", "y"), "y_")),
),
(
"renaming_collision",
# the `y` param of the lambda may not be renamed to `y_` as this name is already referenced
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm. Is this the same problem as in #2134

im.call(im.lambda_("x")(im.lambda_("y")(im.plus(im.plus("x", "y"), "y_"))))(
im.plus("x", "y")
),
im.lambda_("y__")(im.plus(im.plus(im.plus("x", "y"), "y__"), "y_")),
),
(
# ensure opcount preserving option works whether `itir.SymRef` has a type or not
"typed_ref",
Expand Down
Loading