@@ -155,7 +155,7 @@ def forward_hook(self, module, inputs, outputs):
155155            return 
156156
157157        # _get_name() is implemented by nn.Module which is the base class for all modules 
158-         module_name  =  module ._get_name () 
158+         module_name  =  module ._module_name 
159159        # This overwhelms the logs; turn back on if you really need it 
160160        # logger.debug("Processing the global step {0} for module {1}".format(self.step, module_name)) 
161161
@@ -198,6 +198,14 @@ def register_hook(self, module):
198198        # for compatibility with ZCC patches which call this 
199199        self .register_module (module )
200200
201+     @staticmethod  
202+     def  _add_module_name (module , module_name ):
203+         if  isinstance (module , torch .nn .parallel .data_parallel .DataParallel ):
204+             module .module ._module_name  =  module_name 
205+         else :
206+             module ._module_name  =  module_name 
207+         return  module 
208+ 
201209    def  register_module (self , module ):
202210        """ 
203211        This function registers the forward hook. If user wants to register the hook 
@@ -216,9 +224,9 @@ def register_module(self, module):
216224
217225        for  name , submodule  in  module .named_modules ():
218226            assert  submodule  not  in self .module_set , f"Don't register module={ module }  
219-             submodule . _module_name   =   name 
227+             Hook . _add_module_name ( submodule ,  name ) 
220228            self .module_set .add (submodule )
221-         module . _module_name   =   module ._get_name ()
229+         Hook . _add_module_name ( module ,  module ._get_name () )
222230        self .module_set .add (module )
223231
224232        # Use `forward_pre_hook` for the entire net 
0 commit comments