Commit 2fd61ed

Implemented symbol table to prevent redefinition of func symbol name (#28)
Implements a symbol table to prevent redefinition of a function's symbol name during backward(). `loss.backward()` in an example test case calls `import_stateless_graph()` a second time, so the func symbol name inside the module must not be redefined. With the symbol table in place, we no longer check against a set() of previously used func names, as we did in PR #25. After the `FuncOp` is created in the FX Importer, `self.symbol_table.insert(func)` renames the symbol in the symbol table if its name is already taken.
1 parent 38525d6 commit 2fd61ed
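
For context, here is a minimal sketch (not part of this commit) of how MLIR's `SymbolTable` deduplicates symbol names on insertion, using the same `iree.compiler` bindings the importer uses; the module contents, the repeated `main` name, and the empty function type are illustrative assumptions.

```python
# Minimal sketch (assumption: the iree-compiler Python bindings are installed).
# Shows how SymbolTable.insert() uniques a colliding symbol name instead of
# redefining it -- the same call this commit adds after FuncOp creation.
from iree.compiler.ir import (
    Context,
    FunctionType,
    InsertionPoint,
    Location,
    Module,
    SymbolTable,
)
import iree.compiler.dialects.func as func_dialect

with Context(), Location.unknown():
    module = Module.create()
    symbol_table = SymbolTable(module.operation)
    ftype = FunctionType.get([], [])
    with InsertionPoint(module.body):
        for _ in range(2):
            func = func_dialect.FuncOp("main", ftype)
            with InsertionPoint(func.add_entry_block()):
                func_dialect.ReturnOp([])
            # insert() registers the symbol and renames it on collision,
            # so the second "main" becomes a uniqued name such as "main_0".
            symbol_table.insert(func)
    print(module)  # two funcs, uniquely named
```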

1 file changed: +9 -14 lines

python/shark_turbine/dynamo/importer.py

```diff
@@ -3,7 +3,7 @@
 # Licensed under the Apache License v2.0 with LLVM Exceptions.
 # See https://llvm.org/LICENSE.txt for license information.
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
+import builtins
 import logging
 import operator
 import re
@@ -27,7 +27,7 @@
 )
 
 import iree.compiler.dialects.func as func_dialect
-
+from iree.compiler.ir import SymbolTable
 # import iree.compiler.dialects.torch as torch_dialect
 
 
@@ -132,7 +132,7 @@ class FxImporter:
         "_cc",
         "_m",
         "_m_ip",
-        "used_function_names",
+        "symbol_table",
     ]
 
     def __init__(
@@ -153,7 +153,7 @@ def __init__(
         self._config_check()
         self._cc = ContextCache(self._c)
         self._m_ip = InsertionPoint(self._m.body)
-        self.used_function_names: Set[str] = set()
+        self.symbol_table = SymbolTable(self._m.operation)
 
     def _config_check(self):
        for dname in REQUIRED_DIALCTS:
@@ -173,12 +173,6 @@ def import_graph_module(self, gm: GraphModule):
         self.import_stateless_graph(gm.graph)
 
     def import_stateless_graph(self, g: Graph, func_name: str = "main"):
-        # TODO(Stella): Switch this to SymbolTable insertion/dedup
-        if func_name in self.used_function_names:
-            new_name = f"{func_name}_{len(self.used_function_names)}"
-            func_name = new_name
-        self.used_function_names.add(func_name)
-
         ftype, loc = self._graph_to_function_meta(g)
         # TODO: The FuncOp constructor requires a context-manager context.
         # Fix upstream and then unnest.
@@ -191,6 +185,7 @@ def import_stateless_graph(self, g: Graph, func_name: str = "main"):
         entry_block = Block.create_at_start(func.body, ftype.inputs)
         node_importer = GraphNodeImporter(self._c, self._cc, entry_block)
         node_importer.import_nodes(g.nodes)
+        self.symbol_table.insert(func)
 
     def _graph_to_function_meta(self, g: Graph) -> Tuple[FunctionType, Location]:
         """Extracts function metadata from the Graph.
@@ -392,7 +387,7 @@ def import_nodes(self, nodes: Sequence[torch_fx.Node]):
         func_dialect.ReturnOp(operands, loc=loc)
 
     def _import_torch_op_overload(
-        self, loc: Location, node: torch_fx.Node, target: TorchOpOverload
+        self, loc: Location, node: torch_fx.Node, target: TorchOpOverload
     ):
         schema = target._schema
         assert isinstance(schema, FunctionSchema)
@@ -406,7 +401,7 @@ def _import_torch_op_overload(
 
         # Intervening to use Scalar ops due to incorrect ops from AOT-autograd with scalar arguments.
         if mlir_op_name in TENSOR_SCALAR_OP_CONVERTER and (
-            isinstance(node.args[1], float) or isinstance(node.args[1], int)
+            isinstance(node.args[1], float) or isinstance(node.args[1], int)
         ):
             mlir_op_name = TENSOR_SCALAR_OP_CONVERTER[mlir_op_name]
 
@@ -508,7 +503,7 @@ def _import_list_argument(self, loc: Location, arg):
             val_type = str(val.type)
             match = re.match(pattern, val_type)
             assert (
-                match is not None
+                match is not None
             ), f"Unexpected MlirType in list: '{val_type}'"
             list_type = match.group(1)
             result_type = MlirType.parse(f"!torch.list<{list_type}>")
@@ -582,7 +577,7 @@ def lookup(self, t: type) -> Any:
 
 
 def _make_constant_op(
-    op_name: str, value_attr: MlirAttribute, result_type: Optional[MlirType] = None
+    op_name: str, value_attr: MlirAttribute, result_type: Optional[MlirType] = None
 ) -> Operation:
     return Operation.create(
         op_name,
```

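At the importer level, the fix means a second `import_stateless_graph()` call on the same `FxImporter` (as happens when `loss.backward()` re-enters the importer) no longer collides on `main`. A hypothetical usage sketch follows; `FxImporter`'s constructor arguments are an assumption, and only `import_stateless_graph()`'s signature appears in the diff above.

```python
# Hypothetical sketch of the scenario this commit fixes. Assumptions are
# marked inline; only import_stateless_graph(g, func_name="main") is
# confirmed by the diff above.
import torch.fx as fx
from shark_turbine.dynamo.importer import FxImporter

def forward(x):
    return x * 2.0

gm = fx.symbolic_trace(forward)

imp = FxImporter()  # assumption: default construction creates its own module
imp.import_stateless_graph(gm.graph)  # first import: func @main
# Previously this second call consulted used_function_names to pick a fresh
# name; now SymbolTable.insert() uniques the symbol instead (e.g. @main_0).
imp.import_stateless_graph(gm.graph)
```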