@@ -246,24 +246,34 @@ def __call__(
246246 return gm .forward
247247
248248
249- # Equivalent to backend="aot_eager", but also records graphs that
250- # we can assert on
251- class AOTEagerAndRecordGraphs :
249+ class AotEagerAndRecordGraphs :
252250 def __init__ (self ) -> None :
253251 self .graphs : List [torch .fx .GraphModule ] = []
252+ self .fw_graphs : List [torch .fx .GraphModule ] = []
253+ self .bw_graphs : List [torch .fx .GraphModule ] = []
254254
255255 def __call__ (
256256 self , gm : torch .fx .GraphModule , example_inputs : List [torch .Tensor ]
257257 ) -> Callable [..., Any ]:
258- def save_graph (gm : torch .fx .GraphModule , * args : Any , ** kwargs : Any ) -> Any :
259- self .graphs .append (gm )
258+ self .graphs .append (gm )
259+
260+ def fw_compiler (
261+ gm : torch .fx .GraphModule , example_inputs : List [torch .Tensor ]
262+ ) -> Callable [..., Any ]:
263+ self .fw_graphs .append (gm )
264+ return gm .forward
265+
266+ def bw_compiler (
267+ gm : torch .fx .GraphModule , example_inputs : List [torch .Tensor ]
268+ ) -> Callable [..., Any ]:
269+ self .bw_graphs .append (gm )
260270 return gm .forward
261271
262272 return aot_eager (
263273 gm ,
264274 example_inputs ,
265- fw_compiler = save_graph ,
266- bw_compiler = save_graph ,
275+ fw_compiler = fw_compiler ,
276+ bw_compiler = bw_compiler ,
267277 )
268278
269279
0 commit comments