
Issue with loading checkpoints for an internally compiled model due to _EmptyInit() #20524

Open
ameya98 opened this issue Jan 2, 2025 · 1 comment
Labels: bug (Something isn't working) · ver: 2.4.x · waiting on author (Waiting on user action, correction, or update)

Comments


ameya98 commented Jan 2, 2025

Bug description

The logic in _EmptyInit() raises a Dynamo error when calling .load_from_checkpoint() on a PyTorch Lightning module that contains a compiled submodule (a sketch of the call pattern follows the traceback below):

  File "/homefs/home/daigavaa/miniforge3/envs/jamun-extras/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 895, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/homefs/home/daigavaa/miniforge3/envs/jamun-extras/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 560, in inner
    raise exc.UserError(
torch._dynamo.exc.UserError: Dynamic control flow is not supported at the moment. Please use functorch.experimental.control_flow.cond to explicitly capture the control flow. For more information about this error, see: https://pytorch.org/docs/main/generated/exportdb/index.html#cond-operands

from user code:
   File "/homefs/home/daigavaa/miniforge3/envs/jamun-extras/lib/python3.10/site-packages/lightning/fabric/utilities/init.py", line 52, in __torch_function__
    if not self.enabled:


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True
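
Roughly, the failing call pattern looks like this (untested sketch; LitModel, Net, and the checkpoint path are hypothetical stand-ins, not taken from an actual run):

import torch
import torch.nn as nn
import lightning as L


class Net(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


class LitModel(L.LightningModule):
    def __init__(self) -> None:
        super().__init__()
        # The submodule is compiled at construction time, so anything that
        # re-instantiates the module (e.g. load_from_checkpoint) rebuilds it
        # under whatever TorchFunctionMode Lightning has pushed.
        self.net = torch.compile(Net())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


# load_from_checkpoint wraps model construction in _EmptyInit(), which is
# where the Dynamo error fires. "checkpoint.ckpt" is a placeholder path.
model = LitModel.load_from_checkpoint("checkpoint.ckpt")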

I replaced the class with:

from typing import Any, Callable, Dict, Optional, Sequence

from torch.overrides import TorchFunctionMode
from typing_extensions import override


class _EmptyInit(TorchFunctionMode):
    """Initialize `nn.Module` with empty tensors, i.e., uninitialized memory.

    Example::

        with _EmptyInit():
            model = BigModel()
        model.load_state_dict(torch.load("checkpoint.pt"))

    """

    def __init__(self, enabled: bool = True) -> None:
        super().__init__()
        self.enabled = enabled

    @override
    def __torch_function__(
        self,
        func: Callable,
        types: Sequence,
        args: Sequence[Any] = (),
        kwargs: Optional[Dict] = None,
    ) -> Any:
        # Pass every call straight through. With no branch on self.enabled,
        # there is no dynamic control flow for Dynamo to choke on.
        return func(*args, **(kwargs or {}))

and now the checkpoint loads successfully.
This must be a recent change, because the original code worked in my previous environments.
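
For contrast, the upstream __torch_function__ (the lightning/fabric/utilities/init.py frame in the traceback) branches on self.enabled and intercepts torch.nn.init calls so freshly constructed parameters keep uninitialized memory. Paraphrased from the traceback and the class's documented purpose, so the exact upstream body may differ:

    @override
    def __torch_function__(
        self,
        func: Callable,
        types: Sequence,
        args: Sequence[Any] = (),
        kwargs: Optional[Dict] = None,
    ) -> Any:
        kwargs = kwargs or {}
        # Dynamo reports "Dynamic control flow is not supported" on this
        # branch when it traces a compiled submodule under this mode.
        if not self.enabled:
            return func(*args, **kwargs)
        # Skip torch.nn.init.* calls: return the target tensor untouched,
        # leaving its memory uninitialized.
        if getattr(func, "__module__", None) == "torch.nn.init":
            if "tensor" in kwargs:
                return kwargs["tensor"]
            return args[0]
        return func(*args, **kwargs)

The replacement above drops both branches, giving up the skip-initialization behavior in exchange for code Dynamo can trace without hitting dynamic control flow.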

What version are you seeing the problem on?

v2.4

How to reproduce the bug

No response

Error messages and logs

  File "/homefs/home/daigavaa/miniforge3/envs/jamun-extras/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 895, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/homefs/home/daigavaa/miniforge3/envs/jamun-extras/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 560, in inner
    raise exc.UserError(
torch._dynamo.exc.UserError: Dynamic control flow is not supported at the moment. Please use functorch.experimental.control_flow.cond to explicitly capture the control flow. For more information about this error, see: https://pytorch.org/docs/main/generated/exportdb/index.html#cond-operands

from user code:
   File "/homefs/home/daigavaa/miniforge3/envs/jamun-extras/lib/python3.10/site-packages/lightning/fabric/utilities/init.py", line 52, in __torch_function__
    if not self.enabled:


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

Environment

Current environment
  • CUDA:
    - GPU:
    - NVIDIA A100-SXM4-80GB
    - NVIDIA A100-SXM4-80GB
    - available: True
    - version: 12.4
  • Lightning:
    - lightning: 2.4.0
    - lightning-utilities: 0.11.9
    - pytorch-lightning: 2.4.0
    - torch: 2.5.1
    - torch-cluster: 1.6.3+pt25cu121
    - torch-geometric: 2.6.1
    - torch-scatter: 2.1.2+pt25cu121
    - torchmetrics: 1.6.0
    - torchvision: 0.20.1
  • Packages:
    - aiobotocore: 2.15.2
    - aiohappyeyeballs: 2.4.4
    - aiohttp: 3.11.10
    - aioitertools: 0.12.0
    - aiosignal: 1.3.2
    - amberutils: 21.0
    - annotated-types: 0.7.0
    - antlr4-python3-runtime: 4.9.3
    - ase: 3.23.0
    - asttokens: 3.0.0
    - async-timeout: 5.0.1
    - attrs: 24.3.0
    - autocommand: 2.2.2
    - backports.tarfile: 1.2.0
    - boto3: 1.35.36
    - botocore: 1.35.36
    - cached-property: 1.5.2
    - certifi: 2024.12.14
    - cfgv: 3.4.0
    - cftime: 1.6.4.post1
    - charset-normalizer: 3.4.0
    - click: 8.1.7
    - colorama: 0.4.6
    - comm: 0.2.2
    - contourpy: 1.3.1
    - cycler: 0.12.1
    - debugpy: 1.8.11
    - decorator: 5.1.1
    - deeptime: 0.4.4
    - dill: 0.3.9
    - distlib: 0.3.9
    - docker-pycreds: 0.4.0
    - e3nn: 0.5.4
    - edgembar: 0.2
    - einops: 0.8.0
    - exceptiongroup: 1.2.2
    - executing: 2.1.0
    - fastjsonschema: 2.21.1
    - filelock: 3.16.1
    - fonttools: 4.55.3
    - frozenlist: 1.5.0
    - fsspec: 2024.10.0
    - gitdb: 4.0.11
    - gitpython: 3.1.43
    - h5py: 3.12.1
    - hydra-core: 1.3.2
    - identify: 2.6.4
    - idna: 3.10
    - importlib-metadata: 8.0.0
    - inflect: 7.3.1
    - iniconfig: 2.0.0
    - ipykernel: 6.29.5
    - ipython: 8.31.0
    - jamun: 0.0.post1.dev68+g46d412b.d20250101
    - jaraco.collections: 5.1.0
    - jaraco.context: 5.3.0
    - jaraco.functools: 4.0.1
    - jaraco.text: 3.12.1
    - jedi: 0.19.2
    - jinja2: 3.1.4
    - jmespath: 1.0.1
    - joblib: 1.4.2
    - jsonschema: 4.23.0
    - jsonschema-specifications: 2024.10.1
    - jupyter-client: 8.6.3
    - jupyter-core: 5.7.2
    - kiwisolver: 1.4.7
    - lightning: 2.4.0
    - lightning-utilities: 0.11.9
    - lovelyplots: 1.0.2
    - markupsafe: 3.0.2
    - matplotlib: 3.10.0
    - matplotlib-inline: 0.1.7
    - mdtraj: 1.10.2
    - mmpbsa.py: 16.0
    - more-itertools: 10.3.0
    - mpmath: 1.3.0
    - multidict: 6.1.0
    - multiprocess: 0.70.17
    - munkres: 1.1.4
    - nbformat: 5.10.4
    - nbstripout: 0.8.1
    - ndfes: 1.8
    - nest-asyncio: 1.6.0
    - netcdf4: 1.7.2
    - networkx: 3.4.2
    - ninja: 1.11.1.3
    - nodeenv: 1.9.1
    - numexpr: 2.10.2
    - numpy: 1.23.4
    - nvidia-cublas-cu12: 12.4.5.8
    - nvidia-cuda-cupti-cu12: 12.4.127
    - nvidia-cuda-nvrtc-cu12: 12.4.127
    - nvidia-cuda-runtime-cu12: 12.4.127
    - nvidia-cudnn-cu12: 9.1.0.70
    - nvidia-cufft-cu12: 11.2.1.3
    - nvidia-curand-cu12: 10.3.5.147
    - nvidia-cusolver-cu12: 11.6.1.9
    - nvidia-cusparse-cu12: 12.3.1.170
    - nvidia-nccl-cu12: 2.21.5
    - nvidia-nvjitlink-cu12: 12.4.127
    - nvidia-nvtx-cu12: 12.4.127
    - omegaconf: 2.3.0
    - openmm: 8.2.0
    - opt-einsum: 3.4.0
    - opt-einsum-fx: 0.1.4
    - packaging: 24.2
    - packmol-memgen: 2024.2.9
    - pandas: 2.2.3
    - parmed: 4.3.0
    - parso: 0.8.4
    - pathos: 0.3.3
    - pdb4amber: 22.0
    - pdbfixer: 1.10.0
    - pexpect: 4.9.0
    - pillow: 11.0.0
    - pip: 24.3.1
    - platformdirs: 4.3.6
    - plotly: 5.24.1
    - pluggy: 1.5.0
    - posebusters: 0.3.1
    - pot: 0.9.5
    - pox: 0.3.5
    - ppft: 1.7.6.9
    - pre-commit: 4.0.1
    - prompt-toolkit: 3.0.48
    - propcache: 0.2.1
    - protobuf: 5.29.1
    - psutil: 6.1.0
    - ptyprocess: 0.7.0
    - pure-eval: 0.2.3
    - py-cpuinfo: 9.0.0
    - py3dmol: 2.4.2
    - pydantic: 2.10.3
    - pydantic-core: 2.27.1
    - pyemma: 2.5.12
    - pygments: 2.18.0
    - pymsmt: 22.0
    - pyparsing: 3.2.0
    - pytest: 8.3.4
    - python-dateutil: 2.9.0.post0
    - python-dotenv: 1.0.1
    - pytorch-lightning: 2.4.0
    - pytraj: 2.0.6
    - pytz: 2024.2
    - pyyaml: 6.0.2
    - pyzmq: 26.2.0
    - rdkit: 2024.3.6
    - referencing: 0.35.1
    - requests: 2.32.3
    - rpds-py: 0.22.3
    - ruff: 0.8.4
    - s3fs: 2024.10.0
    - s3transfer: 0.10.4
    - sander: 22.0
    - scikit-learn: 1.6.0
    - scipy: 1.13.1
    - sentry-sdk: 2.19.2
    - setproctitle: 1.3.4
    - setuptools: 75.6.0
    - six: 1.17.0
    - smmap: 5.0.1
    - stack-data: 0.6.3
    - sympy: 1.13.1
    - tables: 3.10.1
    - tenacity: 9.0.0
    - threadpoolctl: 3.5.0
    - tomli: 2.2.1
    - torch: 2.5.1
    - torch-cluster: 1.6.3+pt25cu121
    - torch-geometric: 2.6.1
    - torch-scatter: 2.1.2+pt25cu121
    - torchmetrics: 1.6.0
    - torchvision: 0.20.1
    - tornado: 6.4.2
    - tqdm: 4.67.1
    - traitlets: 5.14.3
    - triton: 3.1.0
    - typeguard: 4.3.0
    - typing-extensions: 4.12.2
    - tzdata: 2024.2
    - unicodedata2: 15.1.0
    - universal-pathlib: 0.2.6
    - urllib3: 2.2.3
    - virtualenv: 20.28.0
    - wandb: 0.19.1
    - wcwidth: 0.2.13
    - wheel: 0.45.1
    - wrapt: 1.17.0
    - yarl: 1.18.3
    - zipp: 3.19.2
  • System:
    - OS: Linux
    - architecture:
    - 64bit
    - ELF
    - processor: x86_64
    - python: 3.10.16
    - release: 6.1.82-99.168.amzn2023.x86_64
    - version: #1 SMP PREEMPT_DYNAMIC Mon Mar 25 17:11:31 UTC 2024

More info

No response

@ameya98 added the bug (Something isn't working) and needs triage (Waiting to be triaged by maintainers) labels on Jan 2, 2025
@lantiga (Collaborator) commented Jan 6, 2025

Hey @ameya98, the code hasn't changed in a long time; it's likely that Dynamo now detects this, whereas earlier it would fall back to eager.

Is it possible for you to send a full reproduction to speed up looking into this?

@lantiga added the waiting on author (Waiting on user action, correction, or update) label and removed the needs triage label on Jan 6, 2025