
Incompatibility in model checkpoints between multi-GPU/single-GPU #1087

Open
jeswan opened this issue Sep 17, 2020 · 4 comments
Labels
bug Something isn't working high-priority Fix this before addressing any other major issue.

Comments

@jeswan
Contributor

jeswan commented Sep 17, 2020

Issue by zphang
Friday May 15, 2020 at 06:16 GMT
Originally opened as nyu-mll/jiant#1087


Summary: If you train a model with multi-GPU and then try to load it with single-GPU, the weights will not be loaded successfully. The run prints warnings but does not fail.

Background:

  • When a PyTorch module is wrapped in DataParallel, the original module will be stored as multi_gpu_model.module.
  • This means that calling state_dict on multi_gpu_model will add a "module." prefix to every key in the state_dict.
  • In other words, state_dicts are not compatible between single-GPU and multi-GPU models (see the sketch below).
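
A minimal sketch of the key prefix (runs on CPU; the layer and shapes are arbitrary):

from torch import nn

model = nn.Linear(4, 2)
wrapped = nn.DataParallel(model)

print(list(model.state_dict().keys()))    # ['weight', 'bias']
print(list(wrapped.state_dict().keys()))  # ['module.weight', 'module.bias']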

What this affects:

  • Two-phase training runs that are launched as two separate commands (via load_target_train_checkpoint) will incorrectly load the model weights, reverting to the pretrained weights, if one phase uses multi-GPU and the other uses single-GPU.
  • E.g. if intermediate task training is on multi-GPU and target task training is on single-GPU, the target task run will fail to load the intermediate-trained weights and revert to pretrained weights. The logs will be filled with "parameter missed" warnings, but the run will not fail (see the sketch below).
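
A sketch of why the run only warns instead of failing, assuming the loading path passes strict=False (which the "parameter missed" warnings suggest); the variable names are illustrative, not jiant's actual code:

from torch import nn

single_gpu_model = nn.Linear(4, 2)
multi_gpu_model = nn.DataParallel(nn.Linear(4, 2))

ckpt = multi_gpu_model.state_dict()  # keys: 'module.weight', 'module.bias'
result = single_gpu_model.load_state_dict(ckpt, strict=False)
print(result.missing_keys)     # ['weight', 'bias'] -- nothing was actually loaded
print(result.unexpected_keys)  # ['module.weight', 'module.bias']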

Proposed solution:

  • Don't call state_dict on a DataParallel model. Write a function that determines whether the model is wrapped in DataParallel, and call model.state_dict or model.module.state_dict accordingly:
import torch
from torch import nn

def get_state_dict_for_saving(model: nn.Module) -> "dict[str, torch.Tensor]":
    # Unwrap DataParallel so the saved keys do not carry the "module." prefix.
    if isinstance(model, nn.DataParallel):
        return model.module.state_dict()
    else:
        return model.state_dict()
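
A hypothetical call site at save time (checkpoint_path is illustrative, not jiant's actual save path):

torch.save(get_state_dict_for_saving(model), checkpoint_path)
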
@jeswan jeswan added bug Something isn't working high-priority Fix this before addressing any other major issue. labels Sep 17, 2020
@jeswan
Contributor Author

jeswan commented Sep 17, 2020

Comment by sleepinyourhat
Friday May 15, 2020 at 15:50 GMT


Eek—does anyone have bandwidth to put together a fix?

@jeswan
Contributor Author

jeswan commented Sep 17, 2020

Comment by pruksmhc
Friday May 15, 2020 at 19:24 GMT


I can take a look at this this weekend

@jeswan
Contributor Author

jeswan commented Sep 17, 2020

Comment by pruksmhc
Sunday May 17, 2020 at 16:32 GMT


Update: It's actually not quite as simple as applying the above function when saving, since fixing saving alone will not work if we want to restart a job on multi-GPU: the checkpoint will then always hold the keys of model.state_dict(), while the DataParallel-wrapped model expects keys with the "module." prefix when loading.
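
A sketch of the restart problem described above (shapes are arbitrary): a checkpoint saved without the "module." prefix cannot be loaded directly into the DataParallel wrapper, which uses strict=True by default.

from torch import nn

ckpt = nn.Linear(4, 2).state_dict()         # plain keys: 'weight', 'bias'
wrapped = nn.DataParallel(nn.Linear(4, 2))  # expects 'module.weight', 'module.bias'

try:
    wrapped.load_state_dict(ckpt)  # strict=True by default
except RuntimeError as err:
    print(err)  # missing 'module.*' keys, unexpected 'weight'/'bias'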

@jeswan
Contributor Author

jeswan commented Sep 17, 2020

Comment by zphang
Monday May 18, 2020 at 04:45 GMT


Could you instead call model.module.load_state_dict?
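
A sketch of what this suggestion could look like as a loading counterpart to get_state_dict_for_saving (the function name here is hypothetical, not jiant's API):

from torch import nn

def load_state_dict_into_model(model: nn.Module, state_dict) -> None:
    # Load into the underlying module whether or not the model is wrapped,
    # so checkpoints saved without the "module." prefix work in both modes.
    if isinstance(model, nn.DataParallel):
        model.module.load_state_dict(state_dict)
    else:
        model.load_state_dict(state_dict)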
