3
3
# Licensed under the Apache License v2.0 with LLVM Exceptions.
4
4
# See https://llvm.org/LICENSE.txt for license information.
5
5
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
-
6
+ import builtins
7
7
import logging
8
8
import operator
9
9
import re
27
27
)
28
28
29
29
import iree .compiler .dialects .func as func_dialect
30
-
30
+ from iree . compiler . ir import SymbolTable
31
31
# import iree.compiler.dialects.torch as torch_dialect
32
32
33
33
@@ -132,7 +132,7 @@ class FxImporter:
132
132
"_cc" ,
133
133
"_m" ,
134
134
"_m_ip" ,
135
- "used_function_names " ,
135
+ "symbol_table " ,
136
136
]
137
137
138
138
def __init__ (
@@ -153,7 +153,7 @@ def __init__(
153
153
self ._config_check ()
154
154
self ._cc = ContextCache (self ._c )
155
155
self ._m_ip = InsertionPoint (self ._m .body )
156
- self .used_function_names : Set [ str ] = set ( )
156
+ self .symbol_table = SymbolTable ( self . _m . operation )
157
157
158
158
def _config_check (self ):
159
159
for dname in REQUIRED_DIALCTS :
@@ -173,12 +173,6 @@ def import_graph_module(self, gm: GraphModule):
173
173
self .import_stateless_graph (gm .graph )
174
174
175
175
def import_stateless_graph (self , g : Graph , func_name : str = "main" ):
176
- # TODO(Stella): Switch this to SymbolTable insertion/dedup
177
- if func_name in self .used_function_names :
178
- new_name = f"{ func_name } _{ len (self .used_function_names )} "
179
- func_name = new_name
180
- self .used_function_names .add (func_name )
181
-
182
176
ftype , loc = self ._graph_to_function_meta (g )
183
177
# TODO: The FuncOp constructor requires a context-manager context.
184
178
# Fix upstream and then unnest.
@@ -191,6 +185,7 @@ def import_stateless_graph(self, g: Graph, func_name: str = "main"):
191
185
entry_block = Block .create_at_start (func .body , ftype .inputs )
192
186
node_importer = GraphNodeImporter (self ._c , self ._cc , entry_block )
193
187
node_importer .import_nodes (g .nodes )
188
+ self .symbol_table .insert (func )
194
189
195
190
def _graph_to_function_meta (self , g : Graph ) -> Tuple [FunctionType , Location ]:
196
191
"""Extracts function metadata from the Graph.
@@ -392,7 +387,7 @@ def import_nodes(self, nodes: Sequence[torch_fx.Node]):
392
387
func_dialect .ReturnOp (operands , loc = loc )
393
388
394
389
def _import_torch_op_overload (
395
- self , loc : Location , node : torch_fx .Node , target : TorchOpOverload
390
+ self , loc : Location , node : torch_fx .Node , target : TorchOpOverload
396
391
):
397
392
schema = target ._schema
398
393
assert isinstance (schema , FunctionSchema )
@@ -406,7 +401,7 @@ def _import_torch_op_overload(
406
401
407
402
# Intervening to use Scalar ops due to incorrect ops from AOT-autograd with scalar arguments.
408
403
if mlir_op_name in TENSOR_SCALAR_OP_CONVERTER and (
409
- isinstance (node .args [1 ], float ) or isinstance (node .args [1 ], int )
404
+ isinstance (node .args [1 ], float ) or isinstance (node .args [1 ], int )
410
405
):
411
406
mlir_op_name = TENSOR_SCALAR_OP_CONVERTER [mlir_op_name ]
412
407
@@ -508,7 +503,7 @@ def _import_list_argument(self, loc: Location, arg):
508
503
val_type = str (val .type )
509
504
match = re .match (pattern , val_type )
510
505
assert (
511
- match is not None
506
+ match is not None
512
507
), f"Unexpected MlirType in list: '{ val_type } '"
513
508
list_type = match .group (1 )
514
509
result_type = MlirType .parse (f"!torch.list<{ list_type } >" )
@@ -582,7 +577,7 @@ def lookup(self, t: type) -> Any:
582
577
583
578
584
579
def _make_constant_op (
585
- op_name : str , value_attr : MlirAttribute , result_type : Optional [MlirType ] = None
580
+ op_name : str , value_attr : MlirAttribute , result_type : Optional [MlirType ] = None
586
581
) -> Operation :
587
582
return Operation .create (
588
583
op_name ,
0 commit comments