
Add support to saving.py for loading GPU-trained models on CPU-only machines #19024

Open · wants to merge 10 commits into master (base)
Conversation

@amorehead (Contributor) commented Nov 17, 2023

What does this PR do?

  • Adds support to saving.py for loading GPU-trained models on CPU-only machines.
  • Without this fix, a .to() call in the context of CPU-only inference may lead to AssertionError: Torch not compiled with CUDA enabled.
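The guard this PR proposes can be illustrated roughly as follows. This is a minimal sketch of the control flow only, not the actual patch; `move_to_device` and the `DummyModel` stand-in are hypothetical names introduced for illustration:

```python
class DummyModel:
    """Stand-in for a torch.nn.Module, used here to show control flow only."""

    def __init__(self):
        self.moved_to = None

    def to(self, device):
        # Records the requested device instead of actually moving tensors.
        self.moved_to = device
        return self


def move_to_device(model, device, cuda_available):
    """Skip the .to() call when targeting CPU on a CUDA-less build.

    On a CPU-only PyTorch build, .to("cpu") can still reach
    torch.cuda._lazy_init() through torchmetrics' _apply, raising
    'AssertionError: Torch not compiled with CUDA enabled'.
    Returning early avoids touching CUDA at all.
    """
    if str(device) == "cpu" and not cuda_available:
        return model  # semantically already on CPU; do not call .to()
    return model.to(device)
```

With `cuda_available=False` the model object is returned untouched; otherwise it is moved as before.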
Before submitting
  • Was this discussed/agreed via a GitHub issue? (not for typos and docs)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure your PR does only one thing, instead of bundling different changes together?
  • Did you make sure to update the documentation with your changes? (if necessary)
  • Did you write any new necessary tests? (not for typos and docs)
  • Did you verify new and existing tests pass locally with your changes?
  • Did you list all the breaking changes introduced by this pull request?
  • Did you update the CHANGELOG? (not for typos, docs, test updates, or minor internal changes/refactors)

PR review

Anyone in the community is welcome to review the PR.
Before you start reviewing, make sure you have read the review guidelines. In short, see the following bullet-list:

Reviewer checklist
  • Is this pull request ready for review? (if not, please submit in draft mode)
  • Check that all items from Before submitting are resolved
  • Make sure the title is self-explanatory and the description concisely explains the PR
  • Add labels and milestones (and optionally projects) to the PR so it can be classified

📚 Documentation preview 📚: https://pytorch-lightning--19024.org.readthedocs.build/en/19024/

@github-actions github-actions bot added the pl Generic label for PyTorch Lightning package label Nov 17, 2023
@carmocca (Contributor)
@amorehead Have you tried reporting this on PyTorch? You would expect that cpu_thing.to(cpu) is always a no-op

@amorehead (Contributor, Author)
@carmocca, great point. I'll open an issue for PyTorch as well and link it to this one for Lightning. However, since it may take a while for PyTorch to fix the issue on their end, I think this PR should still be useful in the meantime, in case other users run into the same issue I am facing.

@carmocca (Contributor)

Yes, we can merge this, but I would like to hear from their team first before moving forward. Then we could have this:

if not _TORCH_GREATER_EQUAL_2_2:
    # your patch

@awaelchli (Contributor)

@amorehead Great find. If you still have it, could you provide the full stack trace of the error?

@amorehead (Contributor, Author) commented Nov 18, 2023

@awaelchli, yes, the stack trace is as follows.

Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/home/acmwhb/mambaforge/envs/GCPNet/lib/python3.10/site-packages/lightning/pytorch/core/module.py", line 1552, in load_from_checkpoint
    loaded = _load_from_checkpoint(
  File "/home/acmwhb/mambaforge/envs/GCPNet/lib/python3.10/site-packages/lightning/pytorch/core/saving.py", line 97, in _load_from_checkpoint
    return model.to(device)
  File "/home/acmwhb/mambaforge/envs/GCPNet/lib/python3.10/site-packages/lightning/fabric/utilities/device_dtype_mixin.py", line 54, in to
    return super().to(*args, **kwargs)
  File "/home/acmwhb/mambaforge/envs/GCPNet/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1160, in to
    return self._apply(convert)
  File "/home/acmwhb/mambaforge/envs/GCPNet/lib/python3.10/site-packages/torch/nn/modules/module.py", line 810, in _apply
    module._apply(fn)
  File "/home/acmwhb/mambaforge/envs/GCPNet/lib/python3.10/site-packages/torch/nn/modules/module.py", line 810, in _apply
    module._apply(fn)
  File "/home/acmwhb/mambaforge/envs/GCPNet/lib/python3.10/site-packages/torchmetrics/metric.py", line 808, in _apply
    self._device = fn(torch.zeros(1, device=self.device)).device
  File "/home/acmwhb/mambaforge/envs/GCPNet/lib/python3.10/site-packages/torch/cuda/__init__.py", line 289, in _lazy_init
    raise AssertionError("Torch not compiled with CUDA enabled")
AssertionError: Torch not compiled with CUDA enabled

This is triggered by calling:

my_lightning_model.__class__.load_from_checkpoint(
    checkpoint_path=ckpt_path,
    map_location="cpu",
    strict=True,
)

The issue happens with both Lightning 2.1.0 and 2.1.2 (note the __class__ bit for 2.1.2). When I install my patched version of Lightning (as packaged in this PR), the issue goes away because the patch skips these .to() calls altogether.

@awaelchli (Contributor)

Given the stack trace, we see that it goes through torchmetrics and fails at this line:

    self._device = fn(torch.zeros(1, device=self.device)).device

Maybe self.device is (for some reason) cuda and not cpu? In any case, it would be good to identify whether this is an issue with PyTorch or with torchmetrics. I couldn't reproduce it on my macOS machine. @amorehead, any chance you could help sanity-check that self.device is CPU at the line above?
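One way to run that sanity check without stepping through a debugger is to walk the module tree and list every submodule whose cached device is not CPU. The helper name below is hypothetical; it relies only on the fact, visible in the stack trace, that torchmetrics' `Metric` caches its device in a `_device` attribute:

```python
def find_non_cpu_modules(named_modules):
    """Return names of submodules whose cached `_device` is not a CPU device.

    `named_modules` is an iterable of (name, module) pairs, e.g. the result
    of torch.nn.Module.named_modules(). Plain modules without a `_device`
    attribute (ordinary nn.Modules) are skipped.
    """
    offenders = []
    for name, module in named_modules:
        device = getattr(module, "_device", None)
        if device is not None and not str(device).startswith("cpu"):
            offenders.append(name)
    return offenders
```

Calling `find_non_cpu_modules(model.named_modules())` right after loading the checkpoint with `map_location="cpu"` would reveal whether a pickled metric still reports a CUDA device.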

@tringwald
This seems to be a torchmetrics bug, see discussion on the PyTorch issue tracker (pytorch/pytorch#113973).

@awaelchli (Contributor)

@amorehead Did you actually end up with the entire torchmetrics object pickled in the checkpoint, as described by this user in Lightning-AI/torchmetrics#2223, or was it a proper checkpoint containing only the metric's state dict? The former would indeed explain your issue, but then the fix should be to not pickle the metric in the first place.

@amorehead (Contributor, Author)

@awaelchli, you have described it perfectly. The checkpoints I am trying to load on a CPU-only machine unintentionally contain full TorchMetrics objects, which is clearly not best practice. Are you aware of any workarounds, given that the metrics are fully saved in my checkpoint files, or is the only solution to save just the metrics' state_dicts in the first place?
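One possible workaround, assuming the metric objects ended up pickled into the checkpoint's `hyper_parameters` entry (key names are assumptions based on Lightning's usual checkpoint layout; adjust to how your checkpoint actually stores them), is to strip them out and re-save the checkpoint:

```python
def strip_metric_objects(checkpoint):
    """Return a copy of a checkpoint dict without torchmetrics objects
    pickled into its hyper_parameters.

    The model weights in checkpoint["state_dict"] are left untouched; only
    hparams whose type is defined in a torchmetrics module are dropped.
    (Helper name and key layout are hypothetical.)
    """
    cleaned = dict(checkpoint)
    hparams = dict(cleaned.get("hyper_parameters", {}))
    cleaned["hyper_parameters"] = {
        key: value
        for key, value in hparams.items()
        if not type(value).__module__.startswith("torchmetrics")
    }
    return cleaned
```

After cleaning, something like `torch.save(strip_metric_objects(torch.load(path, map_location="cpu")), new_path)` should yield a checkpoint that loads on CPU-only machines. Going forward, passing `ignore=[...]` to `save_hyperparameters()` keeps the metric out of the checkpoint in the first place.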

@sfalkena commented Dec 7, 2023

Hi, I was having the same issue, and this commit fixed it for me! I would be very happy if this gets merged.

@Borda (Member) left a comment
Seems reasonable. Could you please add a test that trains a model on GPU and then mocks CUDA_VISIBLE_DEVICES to load it with CPU only?

codecov bot commented Dec 7, 2023

Codecov Report

Attention: Patch coverage is 83.33333% with 1 line in your changes missing coverage. Please review.

Project coverage is 47%. Comparing base (bb14a97) to head (c856888).
Report is 339 commits behind head on master.

❗ There is a different number of reports uploaded between BASE (bb14a97) and HEAD (c856888). Click for more details.

HEAD has 205 fewer uploads than BASE:

Flag               BASE (bb14a97)   HEAD (c856888)
lightning          44               15
cpu                74               24
pytest             56               0
python3.10         21               9
app                9                0
examples           9                0
gpu                4                0
lightning_fabric   10               0
python3.9          6                3
python3.11         15               6
python3.8          12               6
tpu                2                0
pytorch_lightning  10               9
lightning_app      5                0
Additional details and impacted files
@@            Coverage Diff             @@
##           master   #19024      +/-   ##
==========================================
- Coverage      83%      47%     -36%     
==========================================
  Files         445      437       -8     
  Lines       37289    37140     -149     
==========================================
- Hits        31119    17586   -13533     
- Misses       6170    19554   +13384     

@mergify mergify bot removed the has conflicts label Jan 10, 2024
gitguardian bot commented Jan 16, 2024

️✅ There are no secrets present in this pull request anymore.

If these secrets were true positive and are still valid, we highly recommend you to revoke them.
Once a secret has been leaked into a git repository, you should consider it compromised, even if it was deleted immediately.

@mergify mergify bot removed the has conflicts label Feb 1, 2024
Labels: pl (Generic label for PyTorch Lightning package)

6 participants