Skip to content

Commit b1a4be4

Browse files
committed
fn
1 parent 2814284 commit b1a4be4

File tree

1 file changed

+11
-3
lines changed

1 file changed

+11
-3
lines changed

smdebug/pytorch/hook.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ def forward_hook(self, module, inputs, outputs):
155155
return
156156

157157
# _get_name() is implemented by nn.Module which is the base class for all modules
158-
module_name = module._get_name()
158+
module_name = module._module_name
159159
# This overwhelms the logs; turn back on if you really need it
160160
# logger.debug("Processing the global step {0} for module {1}".format(self.step, module_name))
161161

@@ -198,6 +198,14 @@ def register_hook(self, module):
198198
# for compatibility with ZCC patches which call this
199199
self.register_module(module)
200200

201+
@staticmethod
202+
def _add_module_name(module, module_name):
203+
if isinstance(module, torch.nn.parallel.data_parallel.DataParallel):
204+
module.module._module_name = module_name
205+
else:
206+
module._module_name = module_name
207+
return module
208+
201209
def register_module(self, module):
202210
"""
203211
This function registers the forward hook. If user wants to register the hook
@@ -216,9 +224,9 @@ def register_module(self, module):
216224

217225
for name, submodule in module.named_modules():
218226
assert submodule not in self.module_set, f"Don't register module={module} twice"
219-
submodule._module_name = name
227+
Hook._add_module_name(submodule, name)
220228
self.module_set.add(submodule)
221-
module._module_name = module._get_name()
229+
Hook._add_module_name(module, module._get_name())
222230
self.module_set.add(module)
223231

224232
# Use `forward_pre_hook` for the entire net

0 commit comments

Comments
 (0)