Issue with loading checkpoints from internally compiled model from _EmptyInit() #20524

Closed
@ameya98

Bug description

Calling .load_from_checkpoint() on a PyTorch Lightning module that contains a compiled submodule fails inside _EmptyInit(): Dynamo raises a UserError while tracing the "if not self.enabled:" check in its __torch_function__.

  File "/homefs/home/daigavaa/miniforge3/envs/jamun-extras/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 895, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/homefs/home/daigavaa/miniforge3/envs/jamun-extras/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 560, in inner
    raise exc.UserError(
torch._dynamo.exc.UserError: Dynamic control flow is not supported at the moment. Please use functorch.experimental.control_flow.cond to explicitly capture the control flow. For more information about this error, see: https://pytorch.org/docs/main/generated/exportdb/index.html#cond-operands

from user code:
   File "/homefs/home/daigavaa/miniforge3/envs/jamun-extras/lib/python3.10/site-packages/lightning/fabric/utilities/init.py", line 52, in __torch_function__
    if not self.enabled:


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

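I replaced the class with the following pass-through version:

from typing import Any, Callable, Dict, Optional, Sequence

from torch.overrides import TorchFunctionMode
from typing_extensions import override


class _EmptyInit(TorchFunctionMode):
    """Initialize `nn.Module` with empty tensors, i.e., uninitialized memory.

    Example::

        with _EmptyInit():
            model = BigModel()
        model.load_state_dict(torch.load("checkpoint.pt"))

    """

    def __init__(self, enabled: bool = True) -> None:
        super().__init__()
        self.enabled = enabled  # kept for API compatibility; no longer read below

    @override
    def __torch_function__(
        self,
        func: Callable,
        types: Sequence,
        args: Sequence[Any] = (),
        kwargs: Optional[Dict] = None,
    ) -> Any:
        # Workaround: no branch on self.enabled. Every call is forwarded unchanged,
        # so parameter initialization is no longer skipped.
        return func(*args, **(kwargs or {}))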

With this change, the checkpoint loads successfully.
This looks like a recent regression: the original code worked in my previous environments.
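
For context, the docstring above describes a mode that leaves parameters uninitialized so that a checkpoint can fill them in afterwards. The following is a minimal sketch of that idea, not Lightning's actual implementation. It assumes that the initializers in torch.nn.init dispatch through __torch_function__ while a mode is active; the class name SkipInitSketch and its skipped list are hypothetical and exist only for illustration.

```python
import torch
from torch import nn
from torch.overrides import TorchFunctionMode


class SkipInitSketch(TorchFunctionMode):
    """Hypothetical mode: leave tensors untouched when a torch.nn.init function is called."""

    def __init__(self, enabled: bool = True) -> None:
        super().__init__()
        self.enabled = enabled
        self.skipped = []  # names of intercepted initializers (illustration only)

    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # A plain Python branch on a bool attribute. The traceback in this issue
        # points at a similar `self.enabled` check inside the original hook.
        if self.enabled and getattr(func, "__module__", None) == "torch.nn.init":
            self.skipped.append(func.__name__)
            # Return the tensor as-is instead of filling it with initial values.
            return kwargs["tensor"] if "tensor" in kwargs else args[0]
        return func(*args, **kwargs)


mode = SkipInitSketch()
with mode:
    # reset_parameters() still runs, but its torch.nn.init calls are skipped.
    layer = nn.Linear(4, 2)

# Real weights would come from a checkpoint; a second layer stands in for one here.
layer.load_state_dict(nn.Linear(4, 2).state_dict())
```

The pass-through replacement in this report corresponds to always taking the final `return func(*args, **kwargs)` path, which avoids the branch but also gives up the skipped initialization.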

What version are you seeing the problem on?

v2.4

How to reproduce the bug

No response

Environment

Current environment
  • CUDA:
    - GPU:
    - NVIDIA A100-SXM4-80GB
    - NVIDIA A100-SXM4-80GB
    - available: True
    - version: 12.4
  • Lightning:
    - lightning: 2.4.0
    - lightning-utilities: 0.11.9
    - pytorch-lightning: 2.4.0
    - torch: 2.5.1
    - torch-cluster: 1.6.3+pt25cu121
    - torch-geometric: 2.6.1
    - torch-scatter: 2.1.2+pt25cu121
    - torchmetrics: 1.6.0
    - torchvision: 0.20.1
  • Packages:
    - aiobotocore: 2.15.2
    - aiohappyeyeballs: 2.4.4
    - aiohttp: 3.11.10
    - aioitertools: 0.12.0
    - aiosignal: 1.3.2
    - amberutils: 21.0
    - annotated-types: 0.7.0
    - antlr4-python3-runtime: 4.9.3
    - ase: 3.23.0
    - asttokens: 3.0.0
    - async-timeout: 5.0.1
    - attrs: 24.3.0
    - autocommand: 2.2.2
    - backports.tarfile: 1.2.0
    - boto3: 1.35.36
    - botocore: 1.35.36
    - cached-property: 1.5.2
    - certifi: 2024.12.14
    - cfgv: 3.4.0
    - cftime: 1.6.4.post1
    - charset-normalizer: 3.4.0
    - click: 8.1.7
    - colorama: 0.4.6
    - comm: 0.2.2
    - contourpy: 1.3.1
    - cycler: 0.12.1
    - debugpy: 1.8.11
    - decorator: 5.1.1
    - deeptime: 0.4.4
    - dill: 0.3.9
    - distlib: 0.3.9
    - docker-pycreds: 0.4.0
    - e3nn: 0.5.4
    - edgembar: 0.2
    - einops: 0.8.0
    - exceptiongroup: 1.2.2
    - executing: 2.1.0
    - fastjsonschema: 2.21.1
    - filelock: 3.16.1
    - fonttools: 4.55.3
    - frozenlist: 1.5.0
    - fsspec: 2024.10.0
    - gitdb: 4.0.11
    - gitpython: 3.1.43
    - h5py: 3.12.1
    - hydra-core: 1.3.2
    - identify: 2.6.4
    - idna: 3.10
    - importlib-metadata: 8.0.0
    - inflect: 7.3.1
    - iniconfig: 2.0.0
    - ipykernel: 6.29.5
    - ipython: 8.31.0
    - jamun: 0.0.post1.dev68+g46d412b.d20250101
    - jaraco.collections: 5.1.0
    - jaraco.context: 5.3.0
    - jaraco.functools: 4.0.1
    - jaraco.text: 3.12.1
    - jedi: 0.19.2
    - jinja2: 3.1.4
    - jmespath: 1.0.1
    - joblib: 1.4.2
    - jsonschema: 4.23.0
    - jsonschema-specifications: 2024.10.1
    - jupyter-client: 8.6.3
    - jupyter-core: 5.7.2
    - kiwisolver: 1.4.7
    - lightning: 2.4.0
    - lightning-utilities: 0.11.9
    - lovelyplots: 1.0.2
    - markupsafe: 3.0.2
    - matplotlib: 3.10.0
    - matplotlib-inline: 0.1.7
    - mdtraj: 1.10.2
    - mmpbsa.py: 16.0
    - more-itertools: 10.3.0
    - mpmath: 1.3.0
    - multidict: 6.1.0
    - multiprocess: 0.70.17
    - munkres: 1.1.4
    - nbformat: 5.10.4
    - nbstripout: 0.8.1
    - ndfes: 1.8
    - nest-asyncio: 1.6.0
    - netcdf4: 1.7.2
    - networkx: 3.4.2
    - ninja: 1.11.1.3
    - nodeenv: 1.9.1
    - numexpr: 2.10.2
    - numpy: 1.23.4
    - nvidia-cublas-cu12: 12.4.5.8
    - nvidia-cuda-cupti-cu12: 12.4.127
    - nvidia-cuda-nvrtc-cu12: 12.4.127
    - nvidia-cuda-runtime-cu12: 12.4.127
    - nvidia-cudnn-cu12: 9.1.0.70
    - nvidia-cufft-cu12: 11.2.1.3
    - nvidia-curand-cu12: 10.3.5.147
    - nvidia-cusolver-cu12: 11.6.1.9
    - nvidia-cusparse-cu12: 12.3.1.170
    - nvidia-nccl-cu12: 2.21.5
    - nvidia-nvjitlink-cu12: 12.4.127
    - nvidia-nvtx-cu12: 12.4.127
    - omegaconf: 2.3.0
    - openmm: 8.2.0
    - opt-einsum: 3.4.0
    - opt-einsum-fx: 0.1.4
    - packaging: 24.2
    - packmol-memgen: 2024.2.9
    - pandas: 2.2.3
    - parmed: 4.3.0
    - parso: 0.8.4
    - pathos: 0.3.3
    - pdb4amber: 22.0
    - pdbfixer: 1.10.0
    - pexpect: 4.9.0
    - pillow: 11.0.0
    - pip: 24.3.1
    - platformdirs: 4.3.6
    - plotly: 5.24.1
    - pluggy: 1.5.0
    - posebusters: 0.3.1
    - pot: 0.9.5
    - pox: 0.3.5
    - ppft: 1.7.6.9
    - pre-commit: 4.0.1
    - prompt-toolkit: 3.0.48
    - propcache: 0.2.1
    - protobuf: 5.29.1
    - psutil: 6.1.0
    - ptyprocess: 0.7.0
    - pure-eval: 0.2.3
    - py-cpuinfo: 9.0.0
    - py3dmol: 2.4.2
    - pydantic: 2.10.3
    - pydantic-core: 2.27.1
    - pyemma: 2.5.12
    - pygments: 2.18.0
    - pymsmt: 22.0
    - pyparsing: 3.2.0
    - pytest: 8.3.4
    - python-dateutil: 2.9.0.post0
    - python-dotenv: 1.0.1
    - pytorch-lightning: 2.4.0
    - pytraj: 2.0.6
    - pytz: 2024.2
    - pyyaml: 6.0.2
    - pyzmq: 26.2.0
    - rdkit: 2024.3.6
    - referencing: 0.35.1
    - requests: 2.32.3
    - rpds-py: 0.22.3
    - ruff: 0.8.4
    - s3fs: 2024.10.0
    - s3transfer: 0.10.4
    - sander: 22.0
    - scikit-learn: 1.6.0
    - scipy: 1.13.1
    - sentry-sdk: 2.19.2
    - setproctitle: 1.3.4
    - setuptools: 75.6.0
    - six: 1.17.0
    - smmap: 5.0.1
    - stack-data: 0.6.3
    - sympy: 1.13.1
    - tables: 3.10.1
    - tenacity: 9.0.0
    - threadpoolctl: 3.5.0
    - tomli: 2.2.1
    - torch: 2.5.1
    - torch-cluster: 1.6.3+pt25cu121
    - torch-geometric: 2.6.1
    - torch-scatter: 2.1.2+pt25cu121
    - torchmetrics: 1.6.0
    - torchvision: 0.20.1
    - tornado: 6.4.2
    - tqdm: 4.67.1
    - traitlets: 5.14.3
    - triton: 3.1.0
    - typeguard: 4.3.0
    - typing-extensions: 4.12.2
    - tzdata: 2024.2
    - unicodedata2: 15.1.0
    - universal-pathlib: 0.2.6
    - urllib3: 2.2.3
    - virtualenv: 20.28.0
    - wandb: 0.19.1
    - wcwidth: 0.2.13
    - wheel: 0.45.1
    - wrapt: 1.17.0
    - yarl: 1.18.3
    - zipp: 3.19.2
  • System:
    - OS: Linux
    - architecture:
    - 64bit
    - ELF
    - processor: x86_64
    - python: 3.10.16
    - release: 6.1.82-99.168.amzn2023.x86_64
    - version: #1 SMP PREEMPT_DYNAMIC Mon Mar 25 17:11:31 UTC 2024

More info

No response

Labels: bug, repro needed, ver: 2.4.x
