diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index ead3f1a03717..9ecf1de07789 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -565,13 +565,15 @@ def set_initialized_submodules(model, state_dict_keys): Sets the `_is_hf_initialized` flag in all submodules of a given model when all its weights are in the loaded state dict. """ + state_dict_keys = set(state_dict_keys) not_initialized_submodules = {} for module_name, module in model.named_modules(): - loaded_keys = {k.replace(f"{module_name}.", "") for k in state_dict_keys if k.startswith(f"{module_name}.")} - # When checking if the root module is loaded all state_dict_keys must be used. if module_name == "": - loaded_keys = set(state_dict_keys) - if loaded_keys.issuperset(module.state_dict()): + # When checking if the root module is loaded there's no need to prepend module_name. + module_keys = set(module.state_dict()) + else: + module_keys = {f"{module_name}.{k}" for k in module.state_dict()} + if module_keys.issubset(state_dict_keys): module._is_hf_initialized = True else: not_initialized_submodules[module_name] = module