
Incompatibility in model checkpoints between multi-GPU/single-GPU #1087

Closed
zphang opened this issue May 15, 2020 · 5 comments
Assignees: pruksmhc
Labels: bug (Something isn't working), high-priority (Fix this before addressing any other major issue), jiant-v1-legacy (Relevant to versions <= v1.3.2)

Comments

@zphang
Collaborator

zphang commented May 15, 2020

Summary: If you train a model with multi-GPU and then try to load the checkpoint with single-GPU, the weights will not be loaded successfully. The run prints warnings but does not fail.

Background:

  • When a PyTorch module is wrapped in DataParallel, the original module will be stored as multi_gpu_model.module.
  • This means that calling state_dict on multi_gpu_model will add a "module." prefix to every key in the state_dict.
  • In other words, state_dicts are not compatible between single-GPU and multi-GPU models (see the short illustration below).
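
To make the prefix concrete, here is a minimal standalone sketch (not jiant code):

import torch.nn as nn

plain = nn.Linear(4, 2)
wrapped = nn.DataParallel(nn.Linear(4, 2))

# DataParallel stores the original module under the attribute "module",
# so its state_dict keys gain a "module." prefix.
print(list(plain.state_dict().keys()))    # ['weight', 'bias']
print(list(wrapped.state_dict().keys()))  # ['module.weight', 'module.bias']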

What this affects:

  • Two-phase training runs launched as two separate commands (via load_target_train_checkpoint) will incorrectly load the model weights, silently reverting to the pretrained weights, if one phase uses multi-GPU and the other uses single-GPU.
  • E.g. if intermediate-task training runs on multi-GPU and target-task training runs on single-GPU, the target-task run will fail to load the intermediate-trained weights and fall back to the pretrained weights. The logs fill with "parameter missed" warnings, but the run itself does not fail (see the sketch below).
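
A minimal sketch (again standalone, not jiant's actual checkpoint-loading code) of why this surfaces only as warnings: if the checkpoint is loaded non-strictly, as the "parameter missed" warnings suggest, mismatched keys are skipped instead of raising an error.

import torch.nn as nn

single_gpu_model = nn.Linear(4, 2)
multi_gpu_ckpt = nn.DataParallel(nn.Linear(4, 2)).state_dict()  # keys carry the "module." prefix

result = single_gpu_model.load_state_dict(multi_gpu_ckpt, strict=False)
print(result.missing_keys)     # ['weight', 'bias'] -> these parameters keep their existing values
print(result.unexpected_keys)  # ['module.weight', 'module.bias']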

Proposed solution:

  • Don't call state_dict directly on a DataParallel model. Write a function that checks whether the model is wrapped in DataParallel and calls model.module.state_dict or model.state_dict accordingly:

import torch.nn as nn

def get_state_dict_for_saving(model: nn.Module) -> dict:
    # Always save the unwrapped module's state_dict so checkpoint keys
    # never carry the DataParallel "module." prefix.
    if isinstance(model, nn.DataParallel):
        return model.module.state_dict()
    return model.state_dict()
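
A hypothetical call site (checkpoint_path is just a placeholder) would then be:

torch.save(get_state_dict_for_saving(model), checkpoint_path)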
@zphang added the high-priority and bug labels May 15, 2020
@zphang changed the title from "Inconsistency in state_dict from saving model checkpoints with multi-GPU/single-GPU" to "Incompatibility in model checkpoints between multi-GPU/single-GPU" May 15, 2020
@sleepinyourhat
Contributor

Eek—does anyone have bandwidth to put together a fix?

@pruksmhc
Contributor

I can take a look at this over the weekend.

@pruksmhc pruksmhc self-assigned this May 16, 2020
@pruksmhc
Contributor

Update: It's actually not quite as simple as using the function above when saving, because that breaks restarting a job with multi-GPU: the saved checkpoint would then always contain the unwrapped model.state_dict() keys, while the DataParallel-wrapped model expects keys with the "module." prefix when loading.

@zphang
Collaborator Author

zphang commented May 18, 2020

Could you instead call model.module.load_state_dict?
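
A minimal sketch of that suggestion (the helper name is hypothetical, mirroring get_state_dict_for_saving above): checkpoints are always saved without the prefix, and loading is routed through the inner module when the live model is wrapped.

import torch.nn as nn

def load_state_dict_for_restoring(model: nn.Module, state_dict) -> None:
    # Checkpoints saved by get_state_dict_for_saving have no "module." prefix,
    # so for a DataParallel model we load into the wrapped module directly.
    if isinstance(model, nn.DataParallel):
        model.module.load_state_dict(state_dict)
    else:
        model.load_state_dict(state_dict)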

@zphang
Collaborator Author

zphang commented Oct 16, 2020

This is an automatically generated comment.

As we update jiant to v2.x, jiant v1.x has been migrated to https://github.com/nyu-mll/jiant-v1-legacy. As such, we are closing all issues relating to jiant v1.x in this repository.

If this issue is still affecting you in jiant v1.x, please follow up at nyu-mll/jiant-v1-legacy#1087.

If this issue is still affecting you in jiant v2.x, reopen this issue or create a new one.

@zphang zphang closed this as completed Oct 16, 2020