unflatten isinstance (pytorch#143664)
When we unflatten, the submodules we generate (`InterpreterModule` or `InterpreterModuleDispatcher`) are not related by type to the original submodule classes, so a check like `isinstance(mod, N)` against an original class `N` fails. Since the original types are no longer available after export, the best we can do is expose a `type_name()` method that returns the original type's qualified name, which is already recorded in `nn_module_stack` entries.
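For example, a minimal sketch of the resulting behavior (hypothetical, not part of this commit; it assumes toy modules `M` and `N` as in the test below, with `M` holding a submodule `self.n = N()`, and that export recorded `N`'s type name in `nn_module_stack`):

```python
import torch

# Hypothetical usage sketch: M holds a submodule self.n = N().
ep = torch.export.export(M(), (torch.zeros(4),))
ufm = torch.export.unflatten(ep)

sub = ufm.n                    # an InterpreterModule, not an N
assert not isinstance(sub, N)  # the original type relationship is gone
# The original type name is still recoverable:
assert sub.type_name() == f"{N.__module__}.{N.__qualname__}"
```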

Differential Revision: [D67526542](https://our.internmc.facebook.com/intern/diff/D67526542/)

Pull Request resolved: pytorch#143664
Approved by: https://github.com/tugsbayasgalan
avikchaudhuri authored and pytorchmergebot committed Dec 21, 2024
1 parent d88ebbf commit bdeee82
Showing 3 changed files with 53 additions and 9 deletions.
30 changes: 30 additions & 0 deletions test/export/test_export.py
@@ -4021,6 +4021,36 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         epm = export(mod, inp).module()
         self.assertTrue(torch.allclose(epm(*inp), mod(*inp)))
 
+    def test_unflatten_isinstance(self):
+        class N(torch.nn.Module):
+            def forward(self, x, b):
+                if b:
+                    return x + 1
+                else:
+                    return x + 2
+
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.n = N()
+
+            def forward(self, x):
+                return self.n(x + 1, True) + self.n(x + 1, False)
+
+        x = torch.zeros(4)
+        types = {} if is_retracebility_test(self._testMethodName) else {"n": N}
+        ep = export(
+            M(),
+            (x,),
+            preserve_module_call_signature=tuple(types.keys()),
+        )
+        ufm = torch.export.unflatten(ep)
+        self.assertTrue(torch.allclose(ufm(x), x + 5))
+        for fqn, mod in ufm.named_modules(remove_duplicate=False):
+            if cls := types.get(fqn):
+                ty = f"{cls.__module__}.{cls.__qualname__}"
+                self.assertEqual(ty, mod.type_name())
+
     def test_unflatten_asserts(self):
         # TODO: strict-export fails
         class M1(torch.nn.Module):
2 changes: 1 addition & 1 deletion torch/distributed/pipelining/_unflatten.py
@@ -22,7 +22,7 @@ def _outline_submodules(orig_graph: torch.fx.Graph):
         seen_attrs,
         created_modules,
         None,
-        [("", 0)],
+        [("", None, 0)],
         "",
         {},
         module=new_module,
30 changes: 22 additions & 8 deletions torch/export/unflatten.py
@@ -120,7 +120,14 @@ def _assign_attr(
         setattr(to_module, field, from_obj)
 
 
-class InterpreterModule(torch.nn.Module):
+class _SubmoduleBase:
+    _ty: Optional[str]
+
+    def type_name(self) -> Optional[str]:
+        return self._ty
+
+
+class InterpreterModule(_SubmoduleBase, torch.nn.Module):
     """A module that uses torch.fx.Interpreter to execute instead of the usual
     codegen that GraphModule uses. This provides better stack trace information
     and makes it easier to debug execution.
@@ -131,9 +138,11 @@ class InterpreterModule(torch.nn.Module):
     def __init__(
         self,
         graph: torch.fx.Graph,
+        ty: Optional[str] = None,
     ):
         super().__init__()
         self.graph = graph
+        self._ty = ty
         self.graph.owning_module = self
         self._run_with_interpreter = RUN_WITH_INTERPRETER
 
@@ -206,7 +215,7 @@ def print_readable(
         )
 
 
-class InterpreterModuleDispatcher(torch.nn.Module):
+class InterpreterModuleDispatcher(_SubmoduleBase, torch.nn.Module):
     """
     A module that carries a sequence of InterpreterModules corresponding to
     a sequence of calls of that module. Each call to the module dispatches
@@ -219,6 +228,7 @@ def __init__(self, attrs: Set[str], call_modules: List[InterpreterModule]):
         self._modules = call_modules[0]._modules
         for accessor in attrs:
             setattr(self, accessor, getattr(call_modules[0], accessor))
+        self._ty = call_modules[0]._ty
         self._call_modules = call_modules
         self._num_calls = 0
 
@@ -898,7 +908,7 @@ def __init__(
         seen_attrs,
         created_modules,
         parent,
-        module_stack: List[Tuple[str, int]],
+        module_stack: List[Tuple[str, Optional[str], int]],
         module_id,
         module_call_graph: Dict[str, ModuleCallSignature],
         module: Optional[Union[torch.fx.GraphModule, UnflattenedModule]] = None,
@@ -916,7 +926,7 @@ def __init__(
         self.module_call_graph = module_call_graph
         self.verbose = False
 
-        self.fqn, num_calls = self.module_stack[-1]
+        self.fqn, ty, num_calls = self.module_stack[-1]
         # generate call name for self.fqn
         self.child_fqn = _call_name(self.fqn, num_calls + 1)
 
@@ -927,7 +937,7 @@ def __init__(
         else:
             self.module = self.created_modules.get(
                 self.fqn,
-                InterpreterModule(torch.fx.Graph()),
+                InterpreterModule(torch.fx.Graph(), ty=ty),
             )
             self.ivals = parent.ivals
 
@@ -945,7 +955,7 @@ def create_module(fqn):
             path = f"{parent.fqn}.{fqn}" if parent.fqn else fqn
             if path in self.created_modules:
                 return self.created_modules[path]
-            submod = InterpreterModule(torch.fx.Graph())
+            submod = InterpreterModule(torch.fx.Graph(), ty=ty)
             self.created_modules[path] = submod
             return submod
 
@@ -1274,7 +1284,11 @@ def run_from(self, node_idx):
                 node_module_stack = self.module_stack
             else:
                 node_module_stack = [
-                    (path, int(k.split("@")[-1]) if "@" in k else 0)
+                    (
+                        path,
+                        ty if path else None,
+                        int(k.split("@")[-1]) if "@" in k else 0,
+                    )
                     for k, (path, ty) in node.meta["nn_module_stack"].items()
                 ]
 
@@ -1349,7 +1363,7 @@ def _outline_submodules(orig_graph: torch.fx.Graph, root_module: UnflattenedModule):
         seen_attrs,
         created_modules,
         None,
-        [("", 0)],
+        [("", None, 0)],
         "",
         {
             entry.fqn: entry.signature
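Taken together, both generated submodule classes now share `_SubmoduleBase`, so callers can branch on the recorded name instead of `isinstance`. A hypothetical helper (not part of this commit; `is_originally` is an illustrative name) might look like:

```python
import torch

def is_originally(mod: torch.nn.Module, cls: type) -> bool:
    """Hypothetical helper: True if mod is a live instance of cls, or an
    unflattened stand-in whose recorded type name matches cls."""
    if isinstance(mod, cls):
        return True
    type_name = getattr(mod, "type_name", None)  # present on unflattened submodules
    return callable(type_name) and type_name() == f"{cls.__module__}.{cls.__qualname__}"
```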
