@@ -88,53 +88,22 @@ def external_mutable_weights_pass(
8888 return PassResult (gm , mutated )
8989
9090
91- def delegate_external_constants_pass (
92- gm : GraphModule ,
93- ep : ExportedProgram ,
94- gen_tag_fn : Optional [Callable [[torch .fx .Node ], str ]] = None ,
95- ) -> PassResult :
96- """
97- Tag external constants before to_backend.
98-
99- Note: this pass must be run after run_decompositions(), as tags on
100- constants are removed then.
101-
102- Args:
103- gm: GraphModule to tag.
104- ep: ExportedProgram, to distinguish if a node is a constant.
105- gen_tag_fn: node -> str callable indicating the tag for the node.
106- Returns:
107- PassResult: The resulting gm, and if it was mutated or not.
108- """
109- mutated = False
110- for module in gm .modules ():
111- if not isinstance (module , torch .fx .GraphModule ):
112- continue
113- for node in module .graph .nodes :
114- if node .op == "placeholder" and is_param_node (ep , node ):
115- if gen_tag_fn is not None :
116- node .meta .setdefault ("custom" , {})
117- node .meta ["custom" ]["delegate_constant_tag" ] = gen_tag_fn (node )
118- mutated = True
119- return PassResult (gm , mutated )
120-
121-
12291# Note: this pass must be run on an unlifted graph, e.g. ep.module(),
12392# and not on a lifted graph, e.g. ep.graph_module.
12493# This is using 'get_attr' to tag constants, which only appears in
12594# unlifted graphs.
12695def delegate_external_constants_pass_unlifted (
127- gm : GraphModule ,
128- gen_tag_fn : Optional [Callable [[torch .fx .Node ], str ]] = None ,
96+ module : torch . nn . Module ,
97+ gen_tag_fn : Optional [Callable [[torch .fx .Node ], Optional [ str ] ]] = None ,
12998) -> PassResult :
13099 mutated = False
131- for module in gm .modules ():
132- if not isinstance (module , torch .fx .GraphModule ):
100+ for m in module .modules ():
101+ if not isinstance (m , torch .fx .GraphModule ):
133102 continue
134- for node in module .graph .nodes :
103+ for node in m .graph .nodes :
135104 if node .op == "get_attr" :
136105 if gen_tag_fn is not None :
137106 node .meta .setdefault ("custom" , {})
138107 node .meta ["custom" ]["delegate_constant_tag" ] = gen_tag_fn (node )
139108 mutated = True
140- return PassResult (gm , mutated )
109+ return PassResult (module , mutated )
0 commit comments