Bug description
The logic in _EmptyInit() seems to cause issues when calling .load_from_checkpoint()
on a PyTorch Lightning module that contains a torch.compile-d submodule.
File "/homefs/home/daigavaa/miniforge3/envs/jamun-extras/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 895, in step
self.dispatch_table[inst.opcode](self, inst)
File "/homefs/home/daigavaa/miniforge3/envs/jamun-extras/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 560, in inner
raise exc.UserError(
torch._dynamo.exc.UserError: Dynamic control flow is not supported at the moment. Please use functorch.experimental.control_flow.cond to explicitly capture the control flow. For more information about this error, see: https://pytorch.org/docs/main/generated/exportdb/index.html#cond-operands
from user code:
File "/homefs/home/daigavaa/miniforge3/envs/jamun-extras/lib/python3.10/site-packages/lightning/fabric/utilities/init.py", line 52, in __torch_function__
if not self.enabled:
You can suppress this exception and fall back to eager by setting:
import torch._dynamo
torch._dynamo.config.suppress_errors = True
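The line dynamo rejects is the `if not self.enabled:` branch inside `__torch_function__`. One torch-free sketch of an alternative pattern (names are hypothetical, not Lightning's API): decide whether the override is active when the context is entered, so no branch remains inside the intercepted call for torch.compile to trace.

```python
# Torch-free sketch: decide enabled/disabled at context entry instead of
# branching inside the intercepted call. Names here are hypothetical.
from contextlib import contextmanager, nullcontext


@contextmanager
def empty_init_mode():
    # Stand-in for _EmptyInit(); in the real code this would be a
    # TorchFunctionMode that skips tensor initialization.
    yield "empty-init active"


def init_context(empty_init: bool):
    # The enabled/disabled decision happens here, outside any code
    # that torch.compile might trace.
    return empty_init_mode() if empty_init else nullcontext()


with init_context(empty_init=True) as ctx:
    print(ctx)   # empty-init active
with init_context(empty_init=False) as ctx:
    print(ctx)   # None (nullcontext yields None)
```

With this shape, a disabled mode is simply never entered, which is equivalent to the `enabled=False` branch but invisible to dynamo.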
I replaced the class with:

class _EmptyInit(TorchFunctionMode):
    """Initialize `nn.Module` with empty tensors, i.e., uninitialized memory.

    Example::

        with _EmptyInit():
            model = BigModel()
        model.load_state_dict(torch.load("checkpoint.pt"))
    """

    def __init__(self, enabled: bool = True) -> None:
        super().__init__()
        self.enabled = enabled

    @override
    def __torch_function__(
        self,
        func: Callable,
        types: Sequence,
        args: Sequence[Any] = (),
        kwargs: Optional[Dict] = None,
    ) -> Any:
        kwargs = kwargs or {}
        return func(*args, **kwargs)
and now I can load the checkpoint successfully.
This must be a recent change, because the original code worked in my previous environments.
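A side note on the replacement snippet: a mutable default argument such as `kwargs: Optional[Dict] = {}` is shared across all calls that rely on the default, which is why the `None` default plus `kwargs = kwargs or {}` is the safer Python idiom. A torch-free illustration:

```python
# Demonstration of why a mutable default argument is risky: the same
# dict object is reused across every call that relies on the default.
from typing import Dict, Optional


def bad(kwargs: Dict = {}) -> Dict:
    kwargs.setdefault("count", 0)
    kwargs["count"] += 1
    return kwargs


def good(kwargs: Optional[Dict] = None) -> Dict:
    kwargs = kwargs or {}  # fresh dict on every call
    kwargs.setdefault("count", 0)
    kwargs["count"] += 1
    return kwargs


print(bad()["count"])   # state leaks between calls: 1
print(bad()["count"])   # then 2
print(good()["count"])  # always starts fresh: 1
print(good()["count"])  # still 1
```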
What version are you seeing the problem on?
v2.4
How to reproduce the bug
No response
Error messages and logs
(Same traceback as under "Bug description" above.)
Environment
Current environment
- CUDA:
- GPU:
- NVIDIA A100-SXM4-80GB
- NVIDIA A100-SXM4-80GB
- available: True
- version: 12.4
- Lightning:
- lightning: 2.4.0
- lightning-utilities: 0.11.9
- pytorch-lightning: 2.4.0
- torch: 2.5.1
- torch-cluster: 1.6.3+pt25cu121
- torch-geometric: 2.6.1
- torch-scatter: 2.1.2+pt25cu121
- torchmetrics: 1.6.0
- torchvision: 0.20.1
- Packages:
- aiobotocore: 2.15.2
- aiohappyeyeballs: 2.4.4
- aiohttp: 3.11.10
- aioitertools: 0.12.0
- aiosignal: 1.3.2
- amberutils: 21.0
- annotated-types: 0.7.0
- antlr4-python3-runtime: 4.9.3
- ase: 3.23.0
- asttokens: 3.0.0
- async-timeout: 5.0.1
- attrs: 24.3.0
- autocommand: 2.2.2
- backports.tarfile: 1.2.0
- boto3: 1.35.36
- botocore: 1.35.36
- cached-property: 1.5.2
- certifi: 2024.12.14
- cfgv: 3.4.0
- cftime: 1.6.4.post1
- charset-normalizer: 3.4.0
- click: 8.1.7
- colorama: 0.4.6
- comm: 0.2.2
- contourpy: 1.3.1
- cycler: 0.12.1
- debugpy: 1.8.11
- decorator: 5.1.1
- deeptime: 0.4.4
- dill: 0.3.9
- distlib: 0.3.9
- docker-pycreds: 0.4.0
- e3nn: 0.5.4
- edgembar: 0.2
- einops: 0.8.0
- exceptiongroup: 1.2.2
- executing: 2.1.0
- fastjsonschema: 2.21.1
- filelock: 3.16.1
- fonttools: 4.55.3
- frozenlist: 1.5.0
- fsspec: 2024.10.0
- gitdb: 4.0.11
- gitpython: 3.1.43
- h5py: 3.12.1
- hydra-core: 1.3.2
- identify: 2.6.4
- idna: 3.10
- importlib-metadata: 8.0.0
- inflect: 7.3.1
- iniconfig: 2.0.0
- ipykernel: 6.29.5
- ipython: 8.31.0
- jamun: 0.0.post1.dev68+g46d412b.d20250101
- jaraco.collections: 5.1.0
- jaraco.context: 5.3.0
- jaraco.functools: 4.0.1
- jaraco.text: 3.12.1
- jedi: 0.19.2
- jinja2: 3.1.4
- jmespath: 1.0.1
- joblib: 1.4.2
- jsonschema: 4.23.0
- jsonschema-specifications: 2024.10.1
- jupyter-client: 8.6.3
- jupyter-core: 5.7.2
- kiwisolver: 1.4.7
- lightning: 2.4.0
- lightning-utilities: 0.11.9
- lovelyplots: 1.0.2
- markupsafe: 3.0.2
- matplotlib: 3.10.0
- matplotlib-inline: 0.1.7
- mdtraj: 1.10.2
- mmpbsa.py: 16.0
- more-itertools: 10.3.0
- mpmath: 1.3.0
- multidict: 6.1.0
- multiprocess: 0.70.17
- munkres: 1.1.4
- nbformat: 5.10.4
- nbstripout: 0.8.1
- ndfes: 1.8
- nest-asyncio: 1.6.0
- netcdf4: 1.7.2
- networkx: 3.4.2
- ninja: 1.11.1.3
- nodeenv: 1.9.1
- numexpr: 2.10.2
- numpy: 1.23.4
- nvidia-cublas-cu12: 12.4.5.8
- nvidia-cuda-cupti-cu12: 12.4.127
- nvidia-cuda-nvrtc-cu12: 12.4.127
- nvidia-cuda-runtime-cu12: 12.4.127
- nvidia-cudnn-cu12: 9.1.0.70
- nvidia-cufft-cu12: 11.2.1.3
- nvidia-curand-cu12: 10.3.5.147
- nvidia-cusolver-cu12: 11.6.1.9
- nvidia-cusparse-cu12: 12.3.1.170
- nvidia-nccl-cu12: 2.21.5
- nvidia-nvjitlink-cu12: 12.4.127
- nvidia-nvtx-cu12: 12.4.127
- omegaconf: 2.3.0
- openmm: 8.2.0
- opt-einsum: 3.4.0
- opt-einsum-fx: 0.1.4
- packaging: 24.2
- packmol-memgen: 2024.2.9
- pandas: 2.2.3
- parmed: 4.3.0
- parso: 0.8.4
- pathos: 0.3.3
- pdb4amber: 22.0
- pdbfixer: 1.10.0
- pexpect: 4.9.0
- pillow: 11.0.0
- pip: 24.3.1
- platformdirs: 4.3.6
- plotly: 5.24.1
- pluggy: 1.5.0
- posebusters: 0.3.1
- pot: 0.9.5
- pox: 0.3.5
- ppft: 1.7.6.9
- pre-commit: 4.0.1
- prompt-toolkit: 3.0.48
- propcache: 0.2.1
- protobuf: 5.29.1
- psutil: 6.1.0
- ptyprocess: 0.7.0
- pure-eval: 0.2.3
- py-cpuinfo: 9.0.0
- py3dmol: 2.4.2
- pydantic: 2.10.3
- pydantic-core: 2.27.1
- pyemma: 2.5.12
- pygments: 2.18.0
- pymsmt: 22.0
- pyparsing: 3.2.0
- pytest: 8.3.4
- python-dateutil: 2.9.0.post0
- python-dotenv: 1.0.1
- pytorch-lightning: 2.4.0
- pytraj: 2.0.6
- pytz: 2024.2
- pyyaml: 6.0.2
- pyzmq: 26.2.0
- rdkit: 2024.3.6
- referencing: 0.35.1
- requests: 2.32.3
- rpds-py: 0.22.3
- ruff: 0.8.4
- s3fs: 2024.10.0
- s3transfer: 0.10.4
- sander: 22.0
- scikit-learn: 1.6.0
- scipy: 1.13.1
- sentry-sdk: 2.19.2
- setproctitle: 1.3.4
- setuptools: 75.6.0
- six: 1.17.0
- smmap: 5.0.1
- stack-data: 0.6.3
- sympy: 1.13.1
- tables: 3.10.1
- tenacity: 9.0.0
- threadpoolctl: 3.5.0
- tomli: 2.2.1
- torch: 2.5.1
- torch-cluster: 1.6.3+pt25cu121
- torch-geometric: 2.6.1
- torch-scatter: 2.1.2+pt25cu121
- torchmetrics: 1.6.0
- torchvision: 0.20.1
- tornado: 6.4.2
- tqdm: 4.67.1
- traitlets: 5.14.3
- triton: 3.1.0
- typeguard: 4.3.0
- typing-extensions: 4.12.2
- tzdata: 2024.2
- unicodedata2: 15.1.0
- universal-pathlib: 0.2.6
- urllib3: 2.2.3
- virtualenv: 20.28.0
- wandb: 0.19.1
- wcwidth: 0.2.13
- wheel: 0.45.1
- wrapt: 1.17.0
- yarl: 1.18.3
- zipp: 3.19.2
- System:
- OS: Linux
- architecture:
- 64bit
- ELF
- processor: x86_64
- python: 3.10.16
- release: 6.1.82-99.168.amzn2023.x86_64
- version: #1 SMP PREEMPT_DYNAMIC Mon Mar 25 17:11:31 UTC 2024
More info
No response