torch.load error #690

Closed
ClowDragon opened this issue Nov 26, 2024 · 2 comments

ClowDragon commented Nov 26, 2024

🐛 Bug

opacus/utils/module_utils.py, line 102, calls torch.load without setting weights_only=False. A recent PyTorch update (pytorch/pytorch#137602) is the direct cause of this error.
I suppose this needs to be fixed the same way as huggingface/transformers@1339a14.

_pickle.UnpicklingError: Weights only load failed. This file can still be loaded, to do so you have two options, do those steps only if you trust the source of the checkpoint.
(1) Re-running torch.load with weights_only set to False will likely succeed, but it can result in arbitrary code execution. Do it only if you got the file from a trusted source.
(2) Alternatively, to load with weights_only=True please check the recommended steps in the following error message.
WeightsUnpickler error: Unsupported global: GLOBAL torchvision.models.resnet.ResNet was not an allowed global by default. Please use torch.serialization.add_safe_globals([ResNet]) or the torch.serialization.safe_globals([ResNet]) context manager to allowlist this global if you trust this class/function.
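
For readers who want to see the failure in isolation, here is a minimal sketch. It is not the Opacus code: it assumes a PyTorch version that accepts the `weights_only` argument, and it uses a plain `nn.Linear` as a stand-in for the pickled ResNet in the traceback. It exercises only option (1) from the error message; the allowlisting route in option (2) is not shown.

```python
import io
import pickle

import torch
import torch.nn as nn


def save_whole_module(module):
    """Pickle the entire module object (not just its state_dict) into memory."""
    buffer = io.BytesIO()
    torch.save(module, buffer)
    buffer.seek(0)
    return buffer


# Stand-in for the ResNet from the traceback (an assumption for illustration).
model = nn.Linear(4, 2)

# weights_only=True restricts unpickling to an allowlisted set of globals,
# so a pickled module class is expected to be rejected.
try:
    torch.load(save_whole_module(model), weights_only=True)
except pickle.UnpicklingError as err:
    print("restricted load failed:", str(err).splitlines()[0])

# Option (1) from the error message: full unpickling. Use it only on data
# from a trusted source, because it can execute arbitrary code.
restored = torch.load(save_whole_module(model), weights_only=False)
```

Passing `weights_only` explicitly in both calls keeps the behavior the same regardless of which default the installed PyTorch version uses.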

Please reproduce using our template Colab and post here the link

To Reproduce

⚠️ We cannot help you without you sharing reproducible code. Do not ignore this part :)
Steps to reproduce the behavior:

  1. opacus/utils/module_utils.py, line 102 called torch.load

Expected behavior

Environment

Please copy and paste the output from our
environment collection script
(or fill out the checklist below manually).

You can get the script and run it with:

wget https://raw.githubusercontent.com/pytorch/pytorch/master/torch/utils/collect_env.py
# For security purposes, please check the contents of collect_env.py before running it.
python collect_env.py
  • PyTorch Version (e.g., 1.0):
  • OS (e.g., Linux):
  • How you installed PyTorch (conda, pip, source):
  • Build command you used (if compiling from source):
  • Python version:
  • CUDA/cuDNN version:
  • GPU models and configuration:
  • Any other relevant information:

Additional context

@HuanyuZhang
Contributor

Let me take a look and get back to you.

HuanyuZhang added a commit to HuanyuZhang/opacus that referenced this issue Dec 10, 2024
Summary:
Addressing Github issue pytorch#690 (pytorch#690).

Since Opacus is primarily intended for academic use cases, it is considered low risk to utilize torch.load() with the setting "weights_only=False". Additionally, reminders have been added to the function description.

Differential Revision: D66999322
facebook-github-bot pushed a commit that referenced this issue Dec 10, 2024
Summary:
Pull Request resolved: #696

Addressing Github issue #690 (#690).

Since Opacus is primarily intended for academic use cases, it is considered low risk to utilize torch.load() with the setting "weights_only=False". Additionally, reminders have been added to the function description.

Reviewed By: iden-kalemaj

Differential Revision: D66999322

fbshipit-source-id: 6eddc7b5a0390809ef25fc9fe9e5d28bf6b55130
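
The kind of change described in the commit summary can be sketched roughly as follows. This is an illustration under assumptions, not the actual Opacus source: `clone_module_sketch` is a hypothetical helper that copies a module through an in-memory `torch.save` / `torch.load` round trip and passes `weights_only=False` explicitly.

```python
import io

import torch
import torch.nn as nn


def clone_module_sketch(module: nn.Module) -> nn.Module:
    """Deep-copy a module via an in-memory torch.save / torch.load round trip.

    Hypothetical helper, not the Opacus implementation. weights_only=False
    permits arbitrary code execution during unpickling; in this sketch the
    bytes being loaded were written by the same process a moment earlier.
    """
    with io.BytesIO() as buffer:
        torch.save(module, buffer)
        buffer.seek(0)
        return torch.load(buffer, weights_only=False)


original = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
copy = clone_module_sketch(original)
```

The copy has its own parameter storage, so later updates to `original` do not affect `copy`.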
@HuanyuZhang
Contributor

Closing the issue now that the fix has landed.
