Skip to content

Commit

Permalink
Fix torch.load() in model_utils.py (#696)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #696

Addressing Github issue #690 (#690).

Since Opacus is primarily intended for academic use cases, it is considered low risk to utilize torch.load() with the setting "weights_only=False". Additionally, reminders have been added to the function description.

Reviewed By: iden-kalemaj

Differential Revision: D66999322

fbshipit-source-id: 6eddc7b5a0390809ef25fc9fe9e5d28bf6b55130
  • Loading branch information
HuanyuZhang authored and facebook-github-bot committed Dec 10, 2024
1 parent 10eb10a commit 4dbe5ef
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion opacus/utils/module_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ def clone_module(module: nn.Module) -> nn.Module:
"""
Handy utility to clone an nn.Module. PyTorch doesn't always support copy.deepcopy(), so it is
just easier to serialize the model to a BytesIO and read it from there.
When ``weights_only=False``, ``torch.load()`` uses "pickle" module implicity, which is known to be insecure.
Only load the model you trust.
Args:
module: The module to clone
Expand All @@ -99,7 +101,7 @@ def clone_module(module: nn.Module) -> nn.Module:
with io.BytesIO() as bytesio:
torch.save(module, bytesio)
bytesio.seek(0)
module_copy = torch.load(bytesio)
module_copy = torch.load(bytesio, weights_only=False)
next_param = next(
module.parameters(), None
) # Eg, InstanceNorm with affine=False has no params
Expand Down

0 comments on commit 4dbe5ef

Please sign in to comment.