From 9e6a5818dfa1ea64efd13b089e48920cf912d2ef Mon Sep 17 00:00:00 2001
From: NihalHarish
Date: Thu, 29 Oct 2020 17:33:01 -0700
Subject: [PATCH 1/2] get_name

---
 smdebug/pytorch/hook.py | 12 ++++++++++--
 1 file changed, 10 insertions(+), 2 deletions(-)

diff --git a/smdebug/pytorch/hook.py b/smdebug/pytorch/hook.py
index 1eeb0636f..22184d651 100644
--- a/smdebug/pytorch/hook.py
+++ b/smdebug/pytorch/hook.py
@@ -197,6 +197,14 @@ def register_hook(self, module):
         # for compatibility with ZCC patches which call this
         self.register_module(module)

+    @staticmethod
+    def _add_module_name(module, module_name):
+        if isinstance(module, torch.nn.parallel.data_parallel.DataParallel):
+            module.module._module_name = module_name
+        else:
+            module._module_name = module_name
+        return module
+
     def register_module(self, module):
         """
         This function registers the forward hook. If user wants to register the hook
@@ -215,9 +223,9 @@ def register_module(self, module):
         for name, submodule in module.named_modules():
             assert submodule not in self.module_set, f"Don't register module={module} twice"
-            submodule._module_name = name
+            Hook._add_module_name(submodule, name)
             self.module_set.add(submodule)
-        module._module_name = module._get_name()
+        Hook._add_module_name(module, module._get_name())
         self.module_set.add(module)

         # Use `forward_pre_hook` for the entire net

From 3c9b1ad4c46ec7d68ce41b03853cb39f0f63e112 Mon Sep 17 00:00:00 2001
From: NihalHarish
Date: Fri, 30 Oct 2020 01:13:54 -0700
Subject: [PATCH 2/2] resolve import

---
 smdebug/pytorch/hook.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/smdebug/pytorch/hook.py b/smdebug/pytorch/hook.py
index 22184d651..ebcef15c1 100644
--- a/smdebug/pytorch/hook.py
+++ b/smdebug/pytorch/hook.py
@@ -3,6 +3,7 @@
 # Third Party
 import torch
 import torch.distributed as dist
+from torch.nn.parallel.data_parallel import DataParallel

 # First Party
 from smdebug.core.collection import DEFAULT_PYTORCH_COLLECTIONS, CollectionKeys
@@ -199,7 +200,7 @@ def register_hook(self, module):

     @staticmethod
     def _add_module_name(module, module_name):
-        if isinstance(module, torch.nn.parallel.data_parallel.DataParallel):
+        if isinstance(module, DataParallel):
             module.module._module_name = module_name
         else:
             module._module_name = module_name
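
Reviewer note (not part of the patch): a minimal standalone sketch of the
DataParallel handling these commits introduce. DataParallel is a wrapper
whose real layers live on its .module attribute, so the name tag has to be
set on the inner module rather than on the wrapper. The toy Net model and
the assertions below are hypothetical driver code; only _add_module_name
mirrors the patched helper.

import torch.nn as nn
from torch.nn.parallel.data_parallel import DataParallel


def _add_module_name(module, module_name):
    # DataParallel wraps the user's model; tag the wrapped module,
    # otherwise the attribute would land on the DataParallel shell.
    if isinstance(module, DataParallel):
        module.module._module_name = module_name
    else:
        module._module_name = module_name
    return module


class Net(nn.Module):  # hypothetical toy model for illustration
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)


plain = _add_module_name(Net(), "Net")
assert plain._module_name == "Net"

wrapped = _add_module_name(DataParallel(Net()), "Net")
# The tag lands on the wrapped model, not on the DataParallel wrapper.
assert wrapped.module._module_name == "Net"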