diff --git a/.azure/gpu-tests-fabric.yml b/.azure/gpu-tests-fabric.yml index b2a1d255418eb..88dfe3b7f15cb 100644 --- a/.azure/gpu-tests-fabric.yml +++ b/.azure/gpu-tests-fabric.yml @@ -59,6 +59,9 @@ jobs: "Fabric | latest": image: "pytorchlightning/pytorch_lightning:base-cuda-py3.10-torch2.1-cuda12.1.0" PACKAGE_NAME: "fabric" + "Fabric | future": + image: "pytorchlightning/pytorch_lightning:base-cuda-py3.11-torch2.2-cuda12.1.0" + PACKAGE_NAME: "fabric" "Lightning | latest": image: "pytorchlightning/pytorch_lightning:base-cuda-py3.10-torch2.1-cuda12.1.0" PACKAGE_NAME: "lightning" @@ -73,6 +76,10 @@ jobs: scope=$(python -c 'n = "$(PACKAGE_NAME)" ; print(dict(fabric="lightning_fabric").get(n, n))') echo "##vso[task.setvariable variable=COVERAGE_SOURCE]$scope" displayName: "set env. vars" + - bash: | + echo "##vso[task.setvariable variable=TORCH_URL]https://download.pytorch.org/whl/test/cu${CUDA_VERSION_MM}/torch_test.html" + condition: endsWith(variables['Agent.JobName'], 'future') + displayName: "set env. vars 4 future" - bash: | echo $(DEVICES) @@ -99,7 +106,7 @@ jobs: - bash: | extra=$(python -c "print({'lightning': 'fabric-'}.get('$(PACKAGE_NAME)', ''))") - pip install -e ".[${extra}dev]" pytest-timeout -U --find-links ${TORCH_URL} + pip install -e ".[${extra}dev]" pytest-timeout -U --find-links="${TORCH_URL}" displayName: "Install package & dependencies" - bash: | diff --git a/.azure/gpu-tests-pytorch.yml b/.azure/gpu-tests-pytorch.yml index 0dc40af5c8977..21e68af90e0ed 100644 --- a/.azure/gpu-tests-pytorch.yml +++ b/.azure/gpu-tests-pytorch.yml @@ -51,6 +51,10 @@ jobs: "PyTorch | latest": image: "pytorchlightning/pytorch_lightning:base-cuda-py3.10-torch2.1-cuda12.1.0" PACKAGE_NAME: "pytorch" + "PyTorch | future": + # todo: failed to install `pygame` with py3.11 + image: "pytorchlightning/pytorch_lightning:base-cuda-py3.10-torch2.2-cuda12.1.0" + PACKAGE_NAME: "pytorch" "Lightning | latest": image: "pytorchlightning/pytorch_lightning:base-cuda-py3.10-torch2.1-cuda12.1.0" PACKAGE_NAME: "lightning" @@ -76,6 +80,10 @@ jobs: scope=$(python -c 'n = "$(PACKAGE_NAME)" ; print(dict(pytorch="pytorch_lightning").get(n, n))') echo "##vso[task.setvariable variable=COVERAGE_SOURCE]$scope" displayName: "set env. vars" + - bash: | + echo "##vso[task.setvariable variable=TORCH_URL]https://download.pytorch.org/whl/test/cu${CUDA_VERSION_MM}/torch_test.html" + condition: endsWith(variables['Agent.JobName'], 'future') + displayName: "set env. vars 4 future" - bash: | echo $(DEVICES) @@ -109,7 +117,7 @@ jobs: - bash: | extra=$(python -c "print({'lightning': 'pytorch-'}.get('$(PACKAGE_NAME)', ''))") - pip install -e ".[${extra}dev]" -r requirements/_integrations/strategies.txt pytest-timeout -U --find-links ${TORCH_URL} + pip install -e ".[${extra}dev]" -r requirements/_integrations/strategies.txt pytest-timeout -U --find-links="${TORCH_URL}" displayName: "Install package & dependencies" - bash: pip uninstall -y lightning diff --git a/.github/workflows/ci-tests-data.yml b/.github/workflows/ci-tests-data.yml index 8af4b83c443c1..629855e9b75da 100644 --- a/.github/workflows/ci-tests-data.yml +++ b/.github/workflows/ci-tests-data.yml @@ -87,7 +87,7 @@ jobs: # ls -lh $PYPI_CACHE_DIR - name: Install package & dependencies - timeout-minutes: 20 + timeout-minutes: 30 run: | pip install -e ".[data-dev]" -U --prefer-binary -f ${TORCH_URL} pip list diff --git a/.github/workflows/ci-tests-fabric.yml b/.github/workflows/ci-tests-fabric.yml index 3bde8946ff7d2..f82add5d8e6a4 100644 --- a/.github/workflows/ci-tests-fabric.yml +++ b/.github/workflows/ci-tests-fabric.yml @@ -42,13 +42,17 @@ jobs: - { os: "macOS-11", pkg-name: "lightning", python-version: "3.10", pytorch-version: "1.13" } - { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.10", pytorch-version: "1.13" } - { os: "windows-2022", pkg-name: "lightning", python-version: "3.10", pytorch-version: "1.13" } - # only run PyTorch latest with Python recent + # only run PyTorch latest - { os: "macOS-11", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.0" } - { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.0" } - { os: "windows-2022", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.0" } - { os: "macOS-11", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.1" } - { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.1" } - { os: "windows-2022", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.1" } + # only run PyTorch future + - { os: "macOS-12", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.2" } + - { os: "ubuntu-22.04", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.2" } + - { os: "windows-2022", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.2" } # only run PyTorch latest with Python latest, use Fabric scope to limit dependency issues - { os: "macOS-12", pkg-name: "fabric", python-version: "3.11", pytorch-version: "2.0" } - { os: "ubuntu-22.04", pkg-name: "fabric", python-version: "3.11", pytorch-version: "2.0" } @@ -125,7 +129,7 @@ jobs: - name: Env. variables run: | # Switch PyTorch URL - #python -c "print('TORCH_URL=' + str('${{env.TORCH_URL_TEST}}' if '${{ matrix.release }}' == 'pre' else '${{env.TORCH_URL_STABLE}}'))" >> $GITHUB_ENV + python -c "print('TORCH_URL=' + str('${{env.TORCH_URL_TEST}}' if '${{ matrix.pytorch-version }}' == '2.2' else '${{env.TORCH_URL_STABLE}}'))" >> $GITHUB_ENV # Switch coverage scope python -c "print('COVERAGE_SCOPE=' + str('lightning' if '${{matrix.pkg-name}}' == 'lightning' else 'lightning_fabric'))" >> $GITHUB_ENV # if you install mono-package set dependency only for this subpackage @@ -154,7 +158,7 @@ jobs: - name: Testing Warnings working-directory: tests/tests_fabric - # needs to run outside of `pytest` + # needs to run outside `pytest` run: python utilities/test_warnings.py - name: Testing Fabric diff --git a/.github/workflows/ci-tests-pytorch.yml b/.github/workflows/ci-tests-pytorch.yml index 12285efbf61b0..717aac509663b 100644 --- a/.github/workflows/ci-tests-pytorch.yml +++ b/.github/workflows/ci-tests-pytorch.yml @@ -46,13 +46,17 @@ jobs: - { os: "macOS-11", pkg-name: "lightning", python-version: "3.10", pytorch-version: "1.13" } - { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.10", pytorch-version: "1.13" } - { os: "windows-2022", pkg-name: "lightning", python-version: "3.10", pytorch-version: "1.13" } - # only run PyTorch latest with Python recent + # only run PyTorch latest - { os: "macOS-11", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.0" } - { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.0" } - { os: "windows-2022", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.0" } - { os: "macOS-11", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.1" } - { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.1" } - { os: "windows-2022", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.1" } + # only run PyTorch future + - { os: "macOS-12", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.2" } + - { os: "ubuntu-22.04", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.2" } + - { os: "windows-2022", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.2" } # only run PyTorch latest with Python latest, use PyTorch scope to limit dependency issues - { os: "macOS-12", pkg-name: "pytorch", python-version: "3.11", pytorch-version: "2.0" } - { os: "ubuntu-22.04", pkg-name: "pytorch", python-version: "3.11", pytorch-version: "2.0" } @@ -131,7 +135,7 @@ jobs: - name: Env. variables run: | # Switch PyTorch URL - #python -c "print('TORCH_URL=' + str('${{env.TORCH_URL_TEST}}' if '${{ matrix.release }}' == 'pre' else '${{env.TORCH_URL_STABLE}}'))" >> $GITHUB_ENV + python -c "print('TORCH_URL=' + str('${{env.TORCH_URL_TEST}}' if '${{ matrix.pytorch-version }}' == '2.2' else '${{env.TORCH_URL_STABLE}}'))" >> $GITHUB_ENV # Switch coverage scope python -c "print('COVERAGE_SCOPE=' + str('lightning' if '${{matrix.pkg-name}}' == 'lightning' else 'pytorch_lightning'))" >> $GITHUB_ENV # if you install mono-package set dependency only for this subpackage @@ -191,7 +195,7 @@ jobs: - name: Testing Warnings working-directory: tests/tests_pytorch - # needs to run outside of `pytest` + # needs to run outside `pytest` run: python utilities/test_warnings.py - name: Testing PyTorch diff --git a/.github/workflows/docker-build.yml b/.github/workflows/docker-build.yml index 65580e9817326..4b9a56d0c963a 100644 --- a/.github/workflows/docker-build.yml +++ b/.github/workflows/docker-build.yml @@ -101,12 +101,16 @@ jobs: fail-fast: false matrix: include: - # These are the base images for PL release docker images. - # Make sure the matrix here matches the one above. + # These are the base images for PL release docker images, + # so include at least all the combinations in release-dockers.yml. - { python_version: "3.9", pytorch_version: "1.13", cuda_version: "11.8.0" } - { python_version: "3.9", pytorch_version: "1.13", cuda_version: "12.0.1" } - { python_version: "3.10", pytorch_version: "2.0", cuda_version: "11.8.0" } - { python_version: "3.10", pytorch_version: "2.1", cuda_version: "12.1.0" } + - { python_version: "3.10", pytorch_version: "2.2", cuda_version: "12.1.0" } + - { python_version: "3.11", pytorch_version: "2.1", cuda_version: "12.1.0" } + - { python_version: "3.11", pytorch_version: "2.2", cuda_version: "12.1.0" } + # - { python_version: "3.12", pytorch_version: "2.2", cuda_version: "12.1.0" } # todo: pending on `onnxruntime` steps: - uses: actions/checkout@v4 - uses: docker/setup-buildx-action@v3 diff --git a/dockers/base-cuda/Dockerfile b/dockers/base-cuda/Dockerfile index dd89692dd7c13..d5b72768148ed 100644 --- a/dockers/base-cuda/Dockerfile +++ b/dockers/base-cuda/Dockerfile @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -ARG UBUNTU_VERSION=20.04 +ARG UBUNTU_VERSION=22.04 ARG CUDA_VERSION=11.7.1 @@ -38,7 +38,7 @@ RUN \ # https://github.com/NVIDIA/nvidia-docker/issues/1631 # https://github.com/NVIDIA/nvidia-docker/issues/1631#issuecomment-1264715214 apt-get update && apt-get install -y wget && \ - wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/3bf863cc.pub && \ + wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/3bf863cc.pub && \ mkdir -p /etc/apt/keyrings/ && mv 3bf863cc.pub /etc/apt/keyrings/ && \ echo "deb [signed-by=/etc/apt/keyrings/3bf863cc.pub] https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/ /" /etc/apt/sources.list.d/cuda.list && \ apt-get update -qq --fix-missing && \ @@ -82,9 +82,7 @@ COPY requirements/_integrations/ requirements/_integrations/ ENV PYTHONPATH="/usr/lib/python${PYTHON_VERSION}/site-packages" RUN \ - wget https://bootstrap.pypa.io/get-pip.py --progress=bar:force:noscroll --no-check-certificate && \ - python${PYTHON_VERSION} get-pip.py && \ - rm get-pip.py && \ + curl https://bootstrap.pypa.io/get-pip.py | python${PYTHON_VERSION} && \ # Disable cache \ pip config set global.cache-dir false && \ # set particular PyTorch version \ @@ -99,7 +97,8 @@ RUN \ -r requirements/pytorch/extra.txt \ -r requirements/pytorch/test.txt \ -r requirements/pytorch/strategies.txt \ - --find-links "https://download.pytorch.org/whl/cu${CUDA_VERSION_MM//'.'/''}/torch_stable.html" + --find-links="https://download.pytorch.org/whl/cu${CUDA_VERSION_MM//'.'/''}/torch_stable.html" \ + --find-links="https://download.pytorch.org/whl/test/cu${CUDA_VERSION_MM//'.'/''}/torch_test.html" RUN \ # Show what we have diff --git a/docs/source-fabric/api/fabric_methods.rst b/docs/source-fabric/api/fabric_methods.rst index d6e7ccbc38e31..315ba195767dd 100644 --- a/docs/source-fabric/api/fabric_methods.rst +++ b/docs/source-fabric/api/fabric_methods.rst @@ -108,7 +108,6 @@ This is useful if your model experiences *exploding gradients* during training. fabric.clip_gradients(model, optimizer, max_norm=2.0, norm_type="inf") The :meth:`~lightning.fabric.fabric.Fabric.clip_gradients` method is agnostic to the precision and strategy being used. -Note: Gradient clipping with FSDP is not yet fully supported. to_device diff --git a/docs/source-fabric/fundamentals/launch.rst b/docs/source-fabric/fundamentals/launch.rst index bc4a5bafd42db..0f25ebcda6021 100644 --- a/docs/source-fabric/fundamentals/launch.rst +++ b/docs/source-fabric/fundamentals/launch.rst @@ -172,7 +172,7 @@ Choose from the following options based on your expertise level and available in
.. displayitem:: - :header: Lightning Cloud + :header: Run single or multi-node on Lightning Studios :description: The easiest way to scale models in the cloud. No infrastructure setup required. :col_css: col-md-4 :button_link: ../guide/multi_node/cloud.html diff --git a/docs/source-fabric/guide/checkpoint/distributed_checkpoint.rst b/docs/source-fabric/guide/checkpoint/distributed_checkpoint.rst index 57e546e03f1c5..da380c26a7783 100644 --- a/docs/source-fabric/guide/checkpoint/distributed_checkpoint.rst +++ b/docs/source-fabric/guide/checkpoint/distributed_checkpoint.rst @@ -183,4 +183,36 @@ Note that you can load the distributed checkpoint even if the world size has cha Convert a distributed checkpoint ******************************** -Coming soon. +It is possible to convert a distributed checkpoint to a regular, single-file checkpoint with this utility: + +.. code-block:: bash + + python -m lightning.fabric.utilities.consolidate_checkpoint path/to/my/checkpoint + +You will need to do this for example if you want to load the checkpoint into a script that doesn't use FSDP, or need to export the checkpoint to a different format for deployment, evaluation, etc. + +.. note:: + + All tensors in the checkpoint will be converted to CPU tensors, and no GPUs are required to run the conversion command. + This function assumes you have enough free CPU memory to hold the entire checkpoint in memory. + +.. collapse:: Full example + + Assuming you have saved a checkpoint ``my-checkpoint.ckpt`` using the examples above, run the following command to convert it: + + .. code-block:: bash + + python -m lightning.fabric.utilities.consolidate_checkpoint my-checkpoint.ckpt + + This saves a new file ``my-checkpoint.ckpt.consolidated`` next to the sharded checkpoint which you can load normally in PyTorch: + + .. code-block:: python + + import torch + + checkpoint = torch.load("my-checkpoint.ckpt.consolidated") + print(list(checkpoint.keys())) + print(checkpoint["model"]["transformer.decoder.layers.31.norm1.weight"]) + + +| diff --git a/docs/source-fabric/guide/multi_node/cloud.rst b/docs/source-fabric/guide/multi_node/cloud.rst index 2f1854bc33323..39758cd3d094e 100644 --- a/docs/source-fabric/guide/multi_node/cloud.rst +++ b/docs/source-fabric/guide/multi_node/cloud.rst @@ -1,14 +1,13 @@ :orphan: -########################## -Run in the Lightning Cloud -########################## +############################################# +Run single or multi-node on Lightning Studios +############################################# **Audience**: Users who don't want to waste time on cluster configuration and maintenance. - -The Lightning AI cloud is a platform where you can build, train, finetune and deploy models without worrying about infrastructure, cost management, scaling, and other technical headaches. -In this guide, and within just 10 minutes, you will learn how to run a Fabric training script across multiple nodes in the cloud. +`Lightning Studios `_ is a cloud platform where you can build, train, finetune and deploy models without worrying about infrastructure, cost management, scaling, and other technical headaches. +This guide shows you how easy it is to run a Fabric training script across multiple machines on Lightning Studios. ---- @@ -19,13 +18,8 @@ Initial Setup ************* First, create a free `Lightning AI account `_. -Then, log in from the CLI: - -.. code-block:: bash - - lightning login - -A page opens in your browser where you can follow the instructions to complete the setup. +You get free credits every month you can spend on GPU compute. +To use machines with multiple GPUs or run jobs across machines, you need to be on the `Pro or Teams plan `_. ---- @@ -35,66 +29,107 @@ A page opens in your browser where you can follow the instructions to complete t Launch multi-node training in the cloud *************************************** -**Step 1:** Put your code inside a ``lightning.app.core.work.LightningWork``: +**Step 1:** Start a new Studio. -.. code-block:: python - :emphasize-lines: 5 - :caption: app.py +.. video:: https://pl-public-data.s3.amazonaws.com/assets_lightning/fabric/videos/start-studio-for-mmt.mp4 + :width: 800 + :loop: + :muted: + +| + +**Step 2:** Bring your code into the Studio. You can clone a GitHub repo, drag and drop local files, or use the following demo example: + +.. collapse:: Code Example + + .. code-block:: python + + import lightning as L + import torch + import torch.nn.functional as F + from lightning.pytorch.demos import Transformer, WikiText2 + from torch.utils.data import DataLoader + + + def main(): + L.seed_everything(42) + + fabric = L.Fabric() + fabric.launch() - import lightning as L - from lightning.app.components import FabricMultiNode + # Data + with fabric.rank_zero_first(): + dataset = WikiText2() + train_dataloader = DataLoader(dataset, batch_size=20, shuffle=True) - # 1. Put your code inside a LightningWork - class MyTrainingComponent(L.LightningWork): - def run(self): - # Set up Fabric - # The `devices` and `num_nodes` gets set by Lightning automatically - fabric = L.Fabric(strategy="ddp", precision="16-mixed") + # Model + model = Transformer(vocab_size=dataset.vocab_size) + + # Optimizer + optimizer = torch.optim.SGD(model.parameters(), lr=0.1) - # Your training code - model = ... - optimizer = ... model, optimizer = fabric.setup(model, optimizer) - ... + train_dataloader = fabric.setup_dataloaders(train_dataloader) + + for batch_idx, batch in enumerate(train_dataloader): + input, target = batch + output = model(input, target) + loss = F.nll_loss(output, target.view(-1)) + fabric.backward(loss) + optimizer.step() + optimizer.zero_grad() + + if batch_idx % 10 == 0: + fabric.print(f"iteration: {batch_idx} - loss {loss.item():.4f}") + + + if __name__ == "__main__": + main() + +| -**Step 2:** Init a ``lightning.app.core.app.LightningApp`` with the ``FabricMultiNode`` component. -Configure the number of nodes, the number of GPUs per node, and the type of GPU: +**Step 3:** Remove hardcoded accelerator settings if any and let Lightning automatically set them for you. No other changes are required in your script. .. code-block:: python - :emphasize-lines: 5,7 - :caption: app.py - # 2. Create the app with the FabricMultiNode component inside - app = L.LightningApp( - FabricMultiNode( - MyTrainingComponent, - # Run with 2 nodes - num_nodes=2, - # Each with 4 x V100 GPUs, total 8 GPUs - cloud_compute=L.CloudCompute("gpu-fast-multi"), - ) - ) + # These are the defaults + fabric = L.Fabric(accelerator="auto", devices="auto") + # DON'T hardcode these, leave them default/auto + # fabric = L.Fabric(accelerator="cpu", devices=3) -**Step 3:** Run your code from the CLI: +| -.. code-block:: bash +**Step 4:** Install dependencies and download all necessary data. Test that your script runs in the Studio first. If it runs in the Studio, it will run in multi-node! - lightning run app app.py --cloud +| -This command will upload your Python file and then opens the app admin view, where you can see the logs of what's happening. +**Step 5:** Open the Multi-Machine Training (MMT) app. Type the command to run your script, select the machine type and how many machines you want to launch it on. Click "Run" to start the job. -.. figure:: https://pl-public-data.s3.amazonaws.com/assets_lightning/fabric/fabric-multi-node-admin.png - :alt: The Lightning AI admin page of an app running a multi-node fabric training script - :width: 100% +.. video:: https://pl-public-data.s3.amazonaws.com/assets_lightning/fabric/videos/lightning-ai-mmt-demo-fabric.mp4 + :width: 800 + :loop: + :muted: + +After submitting the job, you will be redirected to a page where you can monitor the machine metrics and logs in real-time. ---- +**************************** +Bring your own cloud account +**************************** + +As a `Teams or Enterprise `_ customer, you have the option to connect your existing cloud account to Lightning AI. +This gives your organization the ability to keep all compute and data on your own cloud account and your Virtual Private Cloud (VPC). + + +---- + ********** -Next steps +Learn more ********** .. raw:: html @@ -103,8 +138,8 @@ Next steps
.. displayitem:: - :header: Lightning Platform - :description: Develop, Train and Deploy models on the cloud + :header: Lightning Studios + :description: Code together. Prototype. Train. Deploy. Host AI web apps. From your browser - with zero setup. :col_css: col-md-4 :button_link: https://lightning.ai :height: 150 diff --git a/docs/source-pytorch/common/checkpointing_expert.rst b/docs/source-pytorch/common/checkpointing_expert.rst index a8f7205596604..23e5215e7cb97 100644 --- a/docs/source-pytorch/common/checkpointing_expert.rst +++ b/docs/source-pytorch/common/checkpointing_expert.rst @@ -136,4 +136,37 @@ Note that you can load the distributed checkpoint even if the world size has cha Convert a distributed checkpoint ******************************** -Coming soon. +It is possible to convert a distributed checkpoint to a regular, single-file checkpoint with this utility: + +.. code-block:: bash + + python -m lightning.pytorch.utilities.consolidate_checkpoint path/to/my/checkpoint + +You will need to do this for example if you want to load the checkpoint into a script that doesn't use FSDP, or need to export the checkpoint to a different format for deployment, evaluation, etc. + +.. note:: + + All tensors in the checkpoint will be converted to CPU tensors, and no GPUs are required to run the conversion command. + This function assumes you have enough free CPU memory to hold the entire checkpoint in memory. + +.. collapse:: Full example + + Assuming you have saved a checkpoint ``epoch=0-step=3.ckpt`` using the examples above, run the following command to convert it: + + .. code-block:: bash + + cd lightning_logs/version_0/checkpoints + python -m lightning.pytorch.utilities.consolidate_checkpoint epoch=0-step=3.ckpt + + This saves a new file ``epoch=0-step=3.ckpt.consolidated`` next to the sharded checkpoint which you can load normally in PyTorch: + + .. code-block:: python + + import torch + + checkpoint = torch.load("epoch=0-step=3.ckpt.consolidated") + print(list(checkpoint.keys())) + print(checkpoint["state_dict"]["model.transformer.decoder.layers.31.norm1.weight"]) + + +| diff --git a/requirements/app/app.txt b/requirements/app/app.txt index 6a042ae7c573b..8c77f23e666b2 100644 --- a/requirements/app/app.txt +++ b/requirements/app/app.txt @@ -1,6 +1,6 @@ lightning-cloud == 0.5.61 # Must be pinned to ensure compatibility packaging -typing-extensions >=4.4.0, <4.8.0 +typing-extensions >=4.4.0, <4.10.0 deepdiff >=5.7.0, <6.6.0 fsspec[http] >=2022.5.0, <2023.11.0 croniter >=1.3.0, <1.5.0 # strict; TODO: for now until we find something more robust. diff --git a/requirements/data/test.txt b/requirements/data/test.txt index d30343b08a628..38439e2d6705a 100644 --- a/requirements/data/test.txt +++ b/requirements/data/test.txt @@ -4,3 +4,4 @@ pytest-cov ==4.1.0 pytest-timeout ==2.1.0 pytest-rerunfailures ==12.0 pytest-random-order ==1.1.0 +viztracer diff --git a/requirements/fabric/base.txt b/requirements/fabric/base.txt index d1496393a8b6f..64bd5cc6368bc 100644 --- a/requirements/fabric/base.txt +++ b/requirements/fabric/base.txt @@ -5,5 +5,5 @@ numpy >=1.17.2, <1.27.0 torch >=1.13.0, <2.2.0 fsspec[http] >=2022.5.0, <2023.11.0 packaging >=20.0, <=23.1 -typing-extensions >=4.4.0, <4.8.0 +typing-extensions >=4.4.0, <4.10.0 lightning-utilities >=0.8.0, <0.10.0 diff --git a/requirements/pytorch/base.txt b/requirements/pytorch/base.txt index 2b790dd277358..17af46699ece9 100644 --- a/requirements/pytorch/base.txt +++ b/requirements/pytorch/base.txt @@ -8,5 +8,5 @@ PyYAML >=5.4, <6.1.0 fsspec[http] >=2022.5.0, <2023.11.0 torchmetrics >=0.7.0, <1.3.0 # needed for using fixed compare_version packaging >=20.0, <=23.1 -typing-extensions >=4.4.0, <4.8.0 +typing-extensions >=4.4.0, <4.10.0 lightning-utilities >=0.8.0, <0.10.0 diff --git a/requirements/pytorch/extra.txt b/requirements/pytorch/extra.txt index 4bdfdcd7ade80..fddaece918b4a 100644 --- a/requirements/pytorch/extra.txt +++ b/requirements/pytorch/extra.txt @@ -5,7 +5,7 @@ matplotlib>3.1, <3.9.0 omegaconf >=2.0.5, <2.4.0 hydra-core >=1.0.5, <1.4.0 -jsonargparse[signatures] >=4.26.1, <4.27.0 +jsonargparse[signatures] >=4.26.1, <4.28.0 rich >=12.3.0, <13.6.0 tensorboardX >=2.2, <2.7.0 # min version is set by torch.onnx missing attribute bitsandbytes ==0.41.0 # strict diff --git a/src/lightning/data/__init__.py b/src/lightning/data/__init__.py index 9750dba22edaa..88b384e8a227b 100644 --- a/src/lightning/data/__init__.py +++ b/src/lightning/data/__init__.py @@ -1,7 +1,7 @@ from lightning.data.streaming.combined import CombinedStreamingDataset from lightning.data.streaming.dataloader import StreamingDataLoader from lightning.data.streaming.dataset import StreamingDataset -from lightning.data.streaming.functions import map, optimize +from lightning.data.streaming.functions import map, optimize, walk __all__ = [ "LightningDataset", @@ -11,4 +11,5 @@ "LightningIterableDataset", "map", "optimize", + "walk", ] diff --git a/src/lightning/data/streaming/client.py b/src/lightning/data/streaming/client.py index cfe757537d218..c42d3af9f6ed1 100644 --- a/src/lightning/data/streaming/client.py +++ b/src/lightning/data/streaming/client.py @@ -7,8 +7,6 @@ if _BOTO3_AVAILABLE: import boto3 import botocore - from botocore.credentials import InstanceMetadataProvider - from botocore.utils import InstanceMetadataFetcher class S3Client: @@ -31,14 +29,8 @@ def client(self) -> Any: # Re-generate credentials for EC2 if self._last_time is None or (time() - self._last_time) > self._refetch_interval: - provider = InstanceMetadataProvider(iam_role_fetcher=InstanceMetadataFetcher(timeout=3600, num_attempts=5)) - credentials = provider.load() self._client = boto3.client( - "s3", - aws_access_key_id=credentials.access_key, - aws_secret_access_key=credentials.secret_key, - aws_session_token=credentials.token, - config=botocore.config.Config(retries={"max_attempts": 1000, "mode": "adaptive"}), + "s3", config=botocore.config.Config(retries={"max_attempts": 1000, "mode": "adaptive"}) ) self._last_time = time() diff --git a/src/lightning/data/streaming/combined.py b/src/lightning/data/streaming/combined.py index aab5691acba9a..7b3373a3e800e 100644 --- a/src/lightning/data/streaming/combined.py +++ b/src/lightning/data/streaming/combined.py @@ -17,20 +17,29 @@ from torch.utils.data import IterableDataset from lightning.data.streaming.dataset import StreamingDataset +from lightning.data.utilities.env import _WorkerEnv + +__NUM_SAMPLES_YIELDED_KEY__ = "__NUM_SAMPLES_YIELDED__" +__SAMPLES_KEY__ = "__SAMPLES__" class CombinedStreamingDataset(IterableDataset): """The `CombinedStreamingDataset` enables to stream data from multiple StreamingDataset with the sampling ratio of your choice. - Addtionally, the `CombinedStreamingDataset` keeps track of the number of - samples fetched to enable resumability of the datasets. + Addtionally, the `CombinedStreamingDataset` keeps track of the number of samples fetched to enable resumability + of the datasets. + + Note that due to the random sampling, the number of samples returned from the iterator is variable and a function + of the given seed. The combined dataset will raise a StopIteration as soon as any of the datasets is exhausted. """ def __init__( self, datasets: List[StreamingDataset], seed: int = 42, weights: Optional[Sequence[float]] = None ) -> None: + self._check_datasets(datasets) + self._seed = seed self._datasets = datasets self._weights = weights @@ -43,34 +52,83 @@ def __init__( self._weights = [w / sum(weights) for w in weights] self._iterator: Optional[_CombinedDatasetIterator] = None + self._use_streaming_dataloader = False + self._num_samples_yielded: Optional[List[int]] = None + self._current_epoch = 0 - def __len__(self) -> int: - assert self._weights - return int(min([1 / w * len(d) for w, d in zip(self._weights, self._datasets) if w > 0])) + def set_epoch(self, current_epoch: int) -> None: + """Set the current epoch to the datasets on epoch starts. + + When using the StreamingDataLoader, this is done automatically + + """ + self._current_epoch = current_epoch + for dataset in self._datasets: + dataset.set_epoch(current_epoch) + + def _check_datasets(self, datasets: List[StreamingDataset]) -> None: + if any(not isinstance(d, StreamingDataset) for d in datasets): + raise RuntimeError("The provided datasets should be instances of the StreamingDataset.") + + def _set_use_streaming_dataloader(self, use_streaming_dataloader: bool) -> None: + # Used to prevent returning num_samples_yielded when using PyTorch DataLoader + self._use_streaming_dataloader = use_streaming_dataloader def __iter__(self) -> Iterator[Any]: assert self._weights - self._iterator = _CombinedDatasetIterator(self._datasets, self._seed, self._weights) + + worker_env = _WorkerEnv.detect() + + num_samples_yielded = None + + if self._num_samples_yielded is not None and worker_env.rank in self._num_samples_yielded: + num_samples_yielded = self._num_samples_yielded[worker_env.rank] + + self._iterator = _CombinedDatasetIterator( + self._datasets, + self._seed, + self._weights, + self._use_streaming_dataloader, + num_samples_yielded, + ) return self._iterator - def state_dict(self, num_workers: int, batch_size: int) -> Dict[str, Any]: + def state_dict( + self, num_workers: int, batch_size: int, num_samples_yielded: Optional[List[int]] = None + ) -> Dict[str, Any]: if self._iterator is None: - return {} + if num_samples_yielded is None: + return {} + return _state_dict(self._datasets, num_samples_yielded, num_workers, batch_size) return self._iterator.state_dict(num_workers, batch_size) def load_state_dict(self, state_dict: Dict[str, Any]) -> None: - if len(state_dict) != len(self._datasets): + if not state_dict: + return + + if len(state_dict["dataset"]) != len(self._datasets): raise RuntimeError(f"The provided state doesn't match the current number of datasets: {self._datasets}.") for dataset_idx, dataset in enumerate(self._datasets): - if str(dataset_idx) not in state_dict: + if str(dataset_idx) not in state_dict["dataset"]: raise RuntimeError(f"The provided state doesn't contain the index {dataset_idx}.") - dataset.load_state_dict(state_dict[str(dataset_idx)]) + dataset.load_state_dict(state_dict["dataset"][str(dataset_idx)]) + + # Used to iterate over the sampler to avoid sampling the same samples + if self._use_streaming_dataloader: + self._num_samples_yielded = state_dict["num_samples_yielded"] class _CombinedDatasetIterator(Iterator): - def __init__(self, datasets: List[StreamingDataset], seed: int, weights: Sequence[float]) -> None: + def __init__( + self, + datasets: List[StreamingDataset], + seed: int, + weights: Sequence[float], + use_streaming_dataloader: bool, + num_samples_yielded: Optional[Any] = None, + ) -> None: self._datasets = datasets self._dataset_iters = [iter(dataset) for dataset in datasets] self._dataset_indexes = list(range(len(datasets))) @@ -78,6 +136,13 @@ def __init__(self, datasets: List[StreamingDataset], seed: int, weights: Sequenc self._weights = weights self._rng = random.Random(seed) + if num_samples_yielded is not None: + self._num_samples_yielded = num_samples_yielded + for _ in range(sum(num_samples_yielded)): + self._rng.choices(self._dataset_indexes, weights=self._weights, k=1) + + self._use_streaming_dataloader = use_streaming_dataloader + def __next__(self) -> Any: # randomly select a dataset index (dataset_index,) = self._rng.choices(self._dataset_indexes, weights=self._weights, k=1) @@ -85,11 +150,26 @@ def __next__(self) -> Any: # keep track the sample was fetched self._num_samples_yielded[dataset_index] += 1 + sample = next(self._dataset_iters[dataset_index]) + # return a new sample - return next(self._dataset_iters[dataset_index]) + if self._use_streaming_dataloader: + return { + __SAMPLES_KEY__: sample, + __NUM_SAMPLES_YIELDED_KEY__: self._num_samples_yielded, + } + return sample def state_dict(self, num_workers: int = 0, batch_size: int = 1) -> Dict[str, Any]: - return { - str(dataset_idx): dataset.state_dict(self._num_samples_yielded[dataset_idx], num_workers, batch_size) - for dataset_idx, dataset in enumerate(self._datasets) - } + return _state_dict(self._datasets, self._num_samples_yielded, num_workers, batch_size) + + +def _state_dict( + datasets: List[StreamingDataset], num_samples_yielded: List[int], num_workers: int = 0, batch_size: int = 1 +) -> Dict[str, Any]: + return { + str(dataset_idx): dataset.state_dict( + num_samples_yielded=num_samples_yielded[dataset_idx], num_workers=num_workers, batch_size=batch_size + ) + for dataset_idx, dataset in enumerate(datasets) + } diff --git a/src/lightning/data/streaming/data_processor.py b/src/lightning/data/streaming/data_processor.py index a589a5906e26c..a6e100e3deb64 100644 --- a/src/lightning/data/streaming/data_processor.py +++ b/src/lightning/data/streaming/data_processor.py @@ -563,7 +563,7 @@ def _handle_data_chunk_recipe_end(self) -> None: def _handle_data_transform_recipe(self, index: int) -> None: # Don't use a context manager to avoid deleting files that are being uploaded. output_dir = tempfile.mkdtemp() - item_data = self.data_recipe.prepare_item(str(output_dir), self.items[index]) + item_data = self.data_recipe.prepare_item(self.items[index], str(output_dir)) if item_data is not None: raise ValueError( "When using a `DataTransformRecipe`, the `prepare_item` shouldn't return anything." @@ -753,7 +753,7 @@ def prepare_structure(self, input_dir: Optional[str]) -> List[T]: """ @abstractmethod - def prepare_item(self, output_dir: str, item_metadata: T) -> None: # type: ignore + def prepare_item(self, item_metadata: T, output_dir: str) -> None: # type: ignore """Use your item metadata to process your files and save the file outputs into `output_dir`.""" diff --git a/src/lightning/data/streaming/dataloader.py b/src/lightning/data/streaming/dataloader.py index bd49e83349a0e..acd0ebef19af6 100644 --- a/src/lightning/data/streaming/dataloader.py +++ b/src/lightning/data/streaming/dataloader.py @@ -15,7 +15,9 @@ import inspect import logging import os +from copy import deepcopy from importlib import reload +from itertools import cycle from typing import Any, Callable, Dict, List, Optional, Union import torch @@ -32,7 +34,11 @@ from torch.utils.data.sampler import BatchSampler, Sampler from lightning.data.streaming import Cache -from lightning.data.streaming.combined import CombinedStreamingDataset +from lightning.data.streaming.combined import ( + __NUM_SAMPLES_YIELDED_KEY__, + __SAMPLES_KEY__, + CombinedStreamingDataset, +) from lightning.data.streaming.constants import _DEFAULT_CHUNK_BYTES, _TORCH_GREATER_EQUAL_2_1_0, _VIZ_TRACKER_AVAILABLE from lightning.data.streaming.dataset import StreamingDataset from lightning.data.streaming.sampler import CacheBatchSampler @@ -341,6 +347,137 @@ def _get_iterator(self) -> "_BaseDataLoaderIter": return _MultiProcessingDataLoaderIterPatch(self) +def _wrapper(fetcher: Any, func: Callable, tracer: Any, profile: int, profile_dir: str) -> Callable: + counter = 0 + + def wrap(*args: Any, **kwargs: Any) -> Any: + nonlocal counter + result = func(*args, **kwargs) + + if tracer.enable and counter == profile: + tracer.stop() + tracer.save() + print( + f"Saved {os.path.join(profile_dir, 'result.json')} file after {profile} batches." + "Use chrome://tracing/ to view it." + ) + fetcher.fetch = func + + counter += 1 + return result + + return wrap + + +class _ProfileWorkerLoop: + """Wrap the PyTorch DataLoader WorkerLoop to add profiling.""" + + def __init__(self, profile: Union[int, bool], profile_dir: Optional[str] = None): + self._profile = profile + self._profile_dir = profile_dir if profile_dir else os.getcwd() + + def __call__( + self, + dataset_kind: Any, + dataset: Any, + index_queue: Any, + data_queue: Any, + done_event: Any, + auto_collation: Any, + collate_fn: Any, + drop_last: Any, + base_seed: Any, + init_fn: Any, + worker_id: Any, + *args: Any, + **kwargs: Any, + ) -> None: + from torch.utils.data._utils import worker + from viztracer import VizTracer + + if worker_id == 0: + output_file = os.path.join(self._profile_dir, "result.json") + + if os.path.exists(output_file): + os.remove(output_file) + + tracer = VizTracer(output_file=output_file, verbose=0) + tracer.start() + + # Reload to remove the patching + reloaded_worker = reload(worker) + create_fetcher = _DatasetKind.create_fetcher + fetcher = None + + def create_fetcher_fn(*args: Any, **kwargs: Any) -> "_BaseDatasetFetcher": + nonlocal fetcher + fetcher = create_fetcher(*args, **kwargs) + + if worker_id == 0 and isinstance(self._profile, int): + fetcher.fetch = _wrapper(fetcher, fetcher.fetch, tracer, self._profile, self._profile_dir) + return fetcher + + _DatasetKind.create_fetcher = create_fetcher_fn # type: ignore + + reloaded_worker._worker_loop( + dataset_kind, + dataset, + index_queue, + data_queue, + done_event, + auto_collation, + collate_fn, + drop_last, + base_seed, + init_fn, + worker_id, + *args, + **kwargs, + ) + + if worker_id == 0 and isinstance(self._profile, bool): + tracer.stop() + tracer.save() + + +class _StreamingMultiProcessingDataLoaderIter(_MultiProcessingDataLoaderIter): + def __init__(self, loader: DataLoader) -> None: + self._loader = loader + self._indexes = ( + list(range(self._loader._latest_worker_idx, self._loader.num_workers)) + if self._loader._latest_worker_idx > 0 + else [] + ) + self._num_workers = loader.num_workers + + distributed_env = _DistributedEnv.detect() + + if self._loader._profile_bactches and distributed_env.global_rank == 0 and _VIZ_TRACKER_AVAILABLE: + from torch.utils.data._utils import worker + + worker._worker_loop = _ProfileWorkerLoop(self._loader._profile_bactches, self._loader._profile_dir) + + super().__init__(loader) + + def _try_put_index(self) -> None: + # Used to restart on the right DataLoader worker + if self._loader.restore and self._indexes: + assert self._tasks_outstanding < self._prefetch_factor * self._num_workers + + try: + index = self._next_index() + except StopIteration: + return + worker_queue_idx = self._indexes.pop(0) + + self._index_queues[worker_queue_idx].put((self._send_idx, index)) + self._task_info[self._send_idx] = (worker_queue_idx,) + self._tasks_outstanding += 1 + self._send_idx += 1 + else: + super()._try_put_index() + + class StreamingDataLoader(DataLoader): """The `StreamingDataLoader` keeps track of the number of samples fetched in order to enable resumability of the dataset.""" @@ -353,29 +490,102 @@ def __init__( *args: Any, batch_size: int = 1, num_workers: int = 0, + profile_bactches: Union[bool, int] = False, + profile_dir: Optional[str] = None, + prefetch_factor: Optional[int] = None, **kwargs: Any, ) -> None: # pyright: ignore + if not isinstance(dataset, (StreamingDataset, CombinedStreamingDataset)): + raise RuntimeError( + "The provided dataset should be either an instance of StreamingDataset or CombinedStreamingDataset." + f" Found {dataset}." + ) + + if profile_bactches and not _VIZ_TRACKER_AVAILABLE: + raise ModuleNotFoundError("To use profile_bactches, viztracer is required. Run `pip install viztracer`") + + if profile_bactches and num_workers == 0: + raise ValueError("Profiling is supported only with num_workers >= 1.") + + self.current_epoch = 0 self.batch_size = batch_size self.num_workers = num_workers - self.num_samples_yielded = 0 - super().__init__(dataset, *args, batch_size=batch_size, num_workers=num_workers, **kwargs) # type: ignore + self._profile_bactches = profile_bactches + self._profile_dir = profile_dir + self._num_samples_yielded_streaming = 0 + self._num_samples_yielded_combined: Dict[int, List[Any]] = {} + self.rng_state: Optional[Any] = None + self._worker_idx = cycle(list(range(self.num_workers if self.num_workers > 0 else 1))) + self._worker_idx_iter: Optional[Any] = None + self._latest_worker_idx = 0 + self.restore = False + super().__init__( + dataset, + *args, + batch_size=batch_size, + num_workers=num_workers, + prefetch_factor=(10 if num_workers > 0 else None) if prefetch_factor is None else prefetch_factor, + **kwargs, + ) # type: ignore def __iter__(self) -> Any: + if not self.restore: + self._latest_worker_idx = 0 + self._worker_idx = cycle(list(range(self.num_workers if self.num_workers > 0 else 1))) + self._worker_idx_iter = iter(self._worker_idx) + self.current_epoch += 1 + self._num_samples_yielded_combined = {} + self._num_samples_yielded_streaming = 0 + + self.dataset.set_epoch(self.current_epoch) + if isinstance(self.dataset, StreamingDataset): assert self.batch_size - self.num_samples_yielded = 0 for batch in super().__iter__(): - self.num_samples_yielded += self.batch_size + self._latest_worker_idx = next(self._worker_idx_iter) # type: ignore + self._num_samples_yielded_streaming += self.batch_size yield batch else: - yield from super().__iter__() + self.dataset._set_use_streaming_dataloader(True) + assert self.batch_size + # TODO: Inject a custom collate function to avoid collating the __NUM_SAMPLES_YIELDED__ key + for batch in super().__iter__(): + self._latest_worker_idx = next(self._worker_idx_iter) # type: ignore + if isinstance(batch, dict) and __NUM_SAMPLES_YIELDED_KEY__ in batch: + self._num_samples_yielded_combined[self._latest_worker_idx] = [ + sample[-1].item() if self.batch_size > 1 else sample.item() + for sample in batch[__NUM_SAMPLES_YIELDED_KEY__] + ] + + yield batch[__SAMPLES_KEY__] + else: + yield batch + + self.restore = False def state_dict(self) -> Dict[str, Any]: if isinstance(self.dataset, StreamingDataset): assert self.batch_size - num_samples = self.num_samples_yielded - return self.dataset.state_dict(num_samples, self.num_workers, self.batch_size) - return self.dataset.state_dict(self.num_workers, self.batch_size) + return { + "dataset": self.dataset.state_dict( + self._num_samples_yielded_streaming, self.num_workers, self.batch_size + ), + "current_epoch": self.current_epoch, + "num_samples_yielded": self._num_samples_yielded_streaming, + "latest_worker_idx": self._latest_worker_idx, + } + + num_samples_yieled = [0 for _ in range(len(list(self._num_samples_yielded_combined.values())[0]))] + for worker_idx in self._num_samples_yielded_combined: + for dataset_idx, samples_yieled in enumerate(self._num_samples_yielded_combined[worker_idx]): + num_samples_yieled[dataset_idx] += samples_yieled + + return { + "dataset": self.dataset.state_dict(self.num_workers, self.batch_size, num_samples_yieled), + "current_epoch": self.current_epoch if self.restore else self.current_epoch - 1, + "latest_worker_idx": self._latest_worker_idx, + "num_samples_yielded": deepcopy(self._num_samples_yielded_combined), + } def load_state_dict(self, obj: Dict[str, Any]) -> None: """Load a dict containing training state (called from non-worker process). @@ -386,7 +596,34 @@ def load_state_dict(self, obj: Dict[str, Any]) -> None: obj (Any): The state. """ - if isinstance(self.dataset, (StreamingDataset, CombinedStreamingDataset)): + self.current_epoch = obj["current_epoch"] + + if isinstance(self.dataset, StreamingDataset): + self._num_samples_yielded_streaming = obj["num_samples_yielded"] + else: + self._num_samples_yielded_combined = obj["num_samples_yielded"] + + # Used to restart on the next DataLoader worker from the previous run. + self._latest_worker_idx = obj["latest_worker_idx"] + 1 + self._worker_idx_iter = iter(self._worker_idx) + for _ in range(self._latest_worker_idx): + next(self._worker_idx_iter) + + # Inform we are resuming and disable resetting the StreamingDataLoader state. + # This is toggle back to False when the `__iter__` method of the StreamingDataLoader completes. + self.restore = True + + if isinstance(self.dataset, CombinedStreamingDataset): + self.dataset._set_use_streaming_dataloader(True) self.dataset.load_state_dict(obj) + elif isinstance(self.dataset, StreamingDataset): + self.dataset.load_state_dict(obj["dataset"]) else: raise RuntimeError("The provided dataset should be a `StreamingDataset` or a `CombinedStreamingDataset`.") + + def _get_iterator(self) -> "_BaseDataLoaderIter": + """Overriden to ensure the `Cache.done()` method is triggered on iteration done.""" + if self.num_workers == 0: + return _SingleProcessDataLoaderIter(self) + self.check_worker_number_rationality() + return _StreamingMultiProcessingDataLoaderIter(self) diff --git a/src/lightning/data/streaming/dataset.py b/src/lightning/data/streaming/dataset.py index ec331a92dec95..e281bbc0a1e2d 100644 --- a/src/lightning/data/streaming/dataset.py +++ b/src/lightning/data/streaming/dataset.py @@ -17,7 +17,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np -from torch.utils.data import IterableDataset, get_worker_info +from torch.utils.data import IterableDataset from lightning.data.streaming import Cache from lightning.data.streaming.constants import ( @@ -29,7 +29,7 @@ from lightning.data.streaming.sampler import ChunkedIndex from lightning.data.streaming.serializers import Serializer from lightning.data.streaming.shuffle import FullShuffle, NoShuffle, Shuffle -from lightning.data.utilities.env import _DistributedEnv, _WorkerEnv +from lightning.data.utilities.env import _DistributedEnv, _is_in_dataloader_worker, _WorkerEnv class StreamingDataset(IterableDataset): @@ -90,6 +90,17 @@ def __init__( self.serializers = serializers self._state_dict: Optional[Dict[str, Any]] = None + def set_epoch(self, current_epoch: int) -> None: + """Set the current epoch to the dataset on epoch starts. + + When using the StreamingDataLoader, this is done automatically + + """ + # If the state dict has been reloaded, don't override the current epoch + # The StreamingDataloader would clean this out + if self._state_dict is None: + self.current_epoch = current_epoch + def _create_cache(self, worker_env: _WorkerEnv) -> Cache: if _should_replace_path(self.input_dir.path): cache_path = _try_create_cache_dir( @@ -119,8 +130,7 @@ def _create_shuffler(self, cache: Cache) -> Shuffle: seed = self.seed drop_last = self.drop_last if self._state_dict is not None: - restart_keys = sorted(self._state_dict) - state: Dict[str, Any] = self._state_dict[restart_keys[-1]] + state: Dict[str, Any] = self._state_dict seed = state["seed"] drop_last = state["drop_last"] return FullShuffle(cache, seed, drop_last) if self.shuffle else NoShuffle(cache, seed, drop_last) @@ -139,8 +149,7 @@ def __iter__(self) -> "StreamingDataset": # Handle restart if self._state_dict: self._validate_state_dict() - restart_keys = sorted(self._state_dict) - state: Dict[str, Any] = self._state_dict[restart_keys[-1]] + state: Dict[str, Any] = self._state_dict self.current_epoch = state["current_epoch"] chunks_per_replica, intervals_per_replica = self.shuffler.get_chunks_and_intervals_per_ranks( @@ -187,16 +196,13 @@ def _resume(self, chunks_replica: List[int], intervals_replica: List[Any]) -> No assert self.worker_env assert self.shuffler - restart_keys = sorted(self._state_dict) - - # Get the state from the previous run - state: Dict[str, Any] = self._state_dict[restart_keys[-1]] + state: Dict[str, Any] = self._state_dict num_workers = state["num_workers"] batch_size = state["batch_size"] # TODO: Implement elastic sampling where the number of workers, ranks can change. - num_samples_yielded = sum([state["num_samples_yielded"] for state in self._state_dict.values()]) + num_samples_yielded = self._state_dict["num_samples_yielded"] # replay sampling from each worker / chunks using the batch size workers_chunks, workers_intervals = _associate_chunks_to_workers( @@ -213,7 +219,7 @@ def _resume(self, chunks_replica: List[int], intervals_replica: List[Any]) -> No self.worker_intervals = workers_intervals[worker_rank] # replay the indexes for the current chunks - interval = workers_intervals[worker_rank][self.chunk_index] + interval = self.worker_intervals[self.chunk_index] current_indexes = np.arange(interval[0], interval[1]) # re-shuffle the indexes @@ -223,6 +229,8 @@ def _resume(self, chunks_replica: List[int], intervals_replica: List[Any]) -> No current_indexes = current_indexes[indexes[worker_rank] :] self.current_indexes = current_indexes + self.global_index = num_samples_yielded + # bump the chunk_index self.chunk_index += 1 @@ -283,6 +291,10 @@ def state_dict(self, num_samples_yielded: int, num_workers: int, batch_size: int if _is_in_dataloader_worker(): raise RuntimeError("The method `state_dict` should only be called in the main process.") + if self._state_dict is not None: + self._state_dict["num_samples_yielded"] = num_samples_yielded + return self._state_dict + state = { "num_samples_yielded": num_samples_yielded, "num_workers": num_workers, @@ -297,10 +309,7 @@ def state_dict(self, num_samples_yielded: int, num_workers: int, batch_size: int "shuffle": self.shuffle, } - if self._state_dict: - num_restarts = len(self._state_dict) - return {**self._state_dict, f"{num_restarts}": state} - return {"0": state} + return state def load_state_dict(self, state_dict: Dict[str, Any]) -> None: if state_dict: @@ -312,8 +321,7 @@ def _validate_state_dict(self) -> None: assert self.worker_env assert self.cache - restart_keys = sorted(self._state_dict) - state: Dict[str, Any] = self._state_dict[restart_keys[-1]] + state: Dict[str, Any] = self._state_dict if state["shuffle"] != self.shuffle: raise ValueError( @@ -327,7 +335,18 @@ def _validate_state_dict(self) -> None: f"Found `{self.worker_env.world_size}` instead of `{state['num_workers']}`." ) - if state["input_dir_path"] != self.input_dir.path: + # Note: We need to check whether the path has been resolved to its associated cache. + # In this case, validate the cache folder is the same. + if _should_replace_path(state["input_dir_path"]): + cache_path = _try_create_cache_dir( + input_dir=state["input_dir_path"] if state["input_dir_path"] else state["input_dir_url"] + ) + if cache_path != self.input_dir.path: + raise ValueError( + "The provided `input_dir` path state doesn't match the current one. " + f"Found `{self.input_dir.path}` instead of `{cache_path}`." + ) + elif state["input_dir_path"] != self.input_dir.path: raise ValueError( "The provided `input_dir` path state doesn't match the current one. " f"Found `{self.input_dir.path}` instead of `{state['input_dir_path']}`." @@ -374,11 +393,7 @@ def _should_replace_path(path: Optional[str]) -> bool: if path is None or path == "": return True - return "/datasets/" in path or "_connections/" in path - - -def _is_in_dataloader_worker() -> bool: - return get_worker_info() is not None + return path.startswith("/teamspace/datasets/") or path.startswith("/teamspace/s3_connections/") def is_integer(value: str) -> bool: diff --git a/src/lightning/data/streaming/downloader.py b/src/lightning/data/streaming/downloader.py index b9097c843e66e..982b4d25142b3 100644 --- a/src/lightning/data/streaming/downloader.py +++ b/src/lightning/data/streaming/downloader.py @@ -12,6 +12,7 @@ # limitations under the License. import os import shutil +import subprocess from abc import ABC from typing import Any, Dict, List from urllib import parse @@ -19,6 +20,7 @@ from filelock import FileLock, Timeout from lightning.data.streaming.client import S3Client +from lightning.data.streaming.constants import _INDEX_FILENAME class Downloader(ABC): @@ -40,7 +42,10 @@ def download_file(self, remote_chunkpath: str, local_chunkpath: str) -> None: class S3Downloader(Downloader): def __init__(self, remote_dir: str, cache_dir: str, chunks: List[Dict[str, Any]]): super().__init__(remote_dir, cache_dir, chunks) - self._client = S3Client() + self._s5cmd_available = os.system("s5cmd > /dev/null 2>&1") == 0 + + if not self._s5cmd_available: + self._client = S3Client() def download_file(self, remote_filepath: str, local_filepath: str) -> None: obj = parse.urlparse(remote_filepath) @@ -48,21 +53,34 @@ def download_file(self, remote_filepath: str, local_filepath: str) -> None: if obj.scheme != "s3": raise ValueError(f"Expected obj.scheme to be `s3`, instead, got {obj.scheme} for remote={remote_filepath}") - from boto3.s3.transfer import TransferConfig - - extra_args: Dict[str, Any] = {} + if os.path.exists(local_filepath): + return try: - with FileLock(local_filepath + ".lock", timeout=1): - if not os.path.exists(local_filepath): - # Issue: https://github.com/boto/boto3/issues/3113 - self._client.client.download_file( - obj.netloc, - obj.path.lstrip("/"), - local_filepath, - ExtraArgs=extra_args, - Config=TransferConfig(use_threads=False), + with FileLock(local_filepath + ".lock", timeout=1 if obj.path.endswith(_INDEX_FILENAME) else 0): + if self._s5cmd_available: + proc = subprocess.Popen( + f"s5cmd cp {remote_filepath} {local_filepath}", + shell=True, + stdout=subprocess.PIPE, ) + proc.wait() + else: + from boto3.s3.transfer import TransferConfig + + extra_args: Dict[str, Any] = {} + + # try: + # with FileLock(local_filepath + ".lock", timeout=1): + if not os.path.exists(local_filepath): + # Issue: https://github.com/boto/boto3/issues/3113 + self._client.client.download_file( + obj.netloc, + obj.path.lstrip("/"), + local_filepath, + ExtraArgs=extra_args, + Config=TransferConfig(use_threads=False), + ) except Timeout: # another process is responsible to download that file, continue pass diff --git a/src/lightning/data/streaming/functions.py b/src/lightning/data/streaming/functions.py index 1d06f6a1a3885..f6aefcbcd0b33 100644 --- a/src/lightning/data/streaming/functions.py +++ b/src/lightning/data/streaming/functions.py @@ -11,13 +11,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +import concurrent.futures import inspect import os from datetime import datetime from functools import partial from pathlib import Path from types import FunctionType -from typing import Any, Callable, Dict, Optional, Sequence, Union +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union import torch @@ -81,24 +82,24 @@ def __init__(self, fn: Callable[[str, Any], None], inputs: Sequence[Any]): params = inspect.signature(_fn).parameters self._contains_device = "device" in params - def prepare_structure(self, input_dir: Optional[str]) -> Any: + def prepare_structure(self, _: Optional[str]) -> Any: return self._inputs - def prepare_item(self, output_dir: str, item_metadata: Any) -> None: # type: ignore + def prepare_item(self, item_metadata: Any, output_dir: str) -> None: # type: ignore if self._contains_device and self._device is None: self._find_device() if isinstance(self._fn, (FunctionType, partial)): if self._contains_device: - self._fn(output_dir, item_metadata, self._device) + self._fn(item_metadata, output_dir, self._device) else: - self._fn(output_dir, item_metadata) + self._fn(item_metadata, output_dir) elif callable(self._fn): if self._contains_device: - self._fn.__call__(output_dir, item_metadata, self._device) # type: ignore + self._fn.__call__(item_metadata, output_dir, self._device) # type: ignore else: - self._fn.__call__(output_dir, item_metadata) # type: ignore + self._fn.__call__(item_metadata, output_dir) # type: ignore else: raise ValueError(f"The provided {self._fn} isn't supported.") @@ -286,3 +287,51 @@ def optimize( num_nodes, machine, ) + + +def _listdir(folder: str) -> Tuple[str, List[str]]: + return folder, os.listdir(folder) + + +class walk: + """This class is an optimized version of os.walk for listing files and folders from cloud filesystem. + + Note: The order of files and folders yielded aren't depth-first anymore due to the asynchronous listing call. + + """ + + def __init__(self, folder: str, max_workers: Optional[int] = os.cpu_count()) -> None: + self.folders = [folder] + self.max_workers = max_workers or 1 + self.futures: List[concurrent.futures.Future] = [] + + def __iter__(self) -> Any: + """This function queues the folders to perform listdir across multiple workers.""" + with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_workers) as executor: + while len(self.folders): + folder = self.folders.pop(0) + future = executor.submit(_listdir, folder) + self.futures.append(future) + + while self.futures: + for future in concurrent.futures.as_completed(self.futures): + filenames = [] + folders = [] + + folder, files_or_folders = future.result() + self.futures = [f for f in self.futures if f != future] + + for file_or_folder in files_or_folders: + if os.path.isfile(os.path.join(folder, file_or_folder)): + filenames.append(file_or_folder) + else: + folders.append(file_or_folder) + self.folders.append(os.path.join(folder, file_or_folder)) + + yield folder, folders, filenames + + while len(self.folders) and len(self.futures) <= self.max_workers * 2: + folder = self.folders.pop(0) + future = executor.submit(_listdir, folder) + self.futures.append(future) + return diff --git a/src/lightning/data/streaming/item_loader.py b/src/lightning/data/streaming/item_loader.py index 6ce919e2804c8..779a683146182 100644 --- a/src/lightning/data/streaming/item_loader.py +++ b/src/lightning/data/streaming/item_loader.py @@ -87,15 +87,11 @@ def load_item_from_chunk(self, index: int, chunk_index: int, chunk_filepath: str del self._chunk_filepaths[chunk_filepath] if chunk_filepath not in self._chunk_filepaths: - first_exists = exists = os.path.exists(chunk_filepath) + exists = os.path.exists(chunk_filepath) and os.stat(chunk_filepath).st_size > 0 while not exists: sleep(0.1) - exists = os.path.exists(chunk_filepath) - - # Wait to avoid any corruption when the file appears - if not first_exists: - sleep(0.001) + exists = os.path.exists(chunk_filepath) and os.stat(chunk_filepath).st_size > 0 self._chunk_filepaths[chunk_filepath] = True @@ -189,16 +185,12 @@ def load_item_from_chunk(self, index: int, chunk_index: int, chunk_filepath: str del self._chunk_filepaths[chunk_filepath] if chunk_filepath not in self._chunk_filepaths: - first_exists = exists = os.path.exists(chunk_filepath) and os.stat(chunk_filepath).st_size > 0 + exists = os.path.exists(chunk_filepath) and os.stat(chunk_filepath).st_size > 0 while not exists: sleep(0.1) exists = os.path.exists(chunk_filepath) and os.stat(chunk_filepath).st_size > 0 - # Wait to avoid any corruption when the file appears - if not first_exists: - sleep(0.1) - self._chunk_filepaths[chunk_filepath] = True self._load_chunk(chunk_index, chunk_filepath) diff --git a/src/lightning/data/streaming/writer.py b/src/lightning/data/streaming/writer.py index 68b2f6c762d2e..e58194289b208 100644 --- a/src/lightning/data/streaming/writer.py +++ b/src/lightning/data/streaming/writer.py @@ -105,12 +105,12 @@ def filled(self) -> bool: return True files = os.listdir(self._cache_dir) index_files = [f for f in files if f.endswith(_INDEX_FILENAME)] - worker_end = _WorkerEnv.detect() + worker_env = _WorkerEnv.detect() data_optimiser_num_workers = os.getenv("DATA_OPTIMIZER_NUM_WORKERS", None) if data_optimiser_num_workers is not None: self._is_done = len(index_files) == int(data_optimiser_num_workers) else: - self._is_done = len(index_files) == self._distributed_env.world_size * worker_end.world_size + self._is_done = len(index_files) == self._distributed_env.world_size * worker_env.world_size return self._is_done @property diff --git a/src/lightning/data/utilities/env.py b/src/lightning/data/utilities/env.py index fa91e714f8666..c9406963d909b 100644 --- a/src/lightning/data/utilities/env.py +++ b/src/lightning/data/utilities/env.py @@ -163,3 +163,7 @@ def __repr__(self) -> str: def __str__(self) -> str: return repr(self) + + +def _is_in_dataloader_worker() -> bool: + return torch_get_worker_info() is not None diff --git a/src/lightning/fabric/CHANGELOG.md b/src/lightning/fabric/CHANGELOG.md index 808d8f3ee2746..32344efc17610 100644 --- a/src/lightning/fabric/CHANGELOG.md +++ b/src/lightning/fabric/CHANGELOG.md @@ -24,6 +24,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added support for clipping gradients by value with FSDP ([#19236](https://github.com/Lightning-AI/lightning/pull/19236)) +- Added a utility function and CLI to consolidate FSDP sharded checkpoints into a single file ([#19213](https://github.com/Lightning-AI/lightning/pull/19213)) + + +- (Experimental) Added support for re-compiling the model inside `Fabric.setup()` over the FSDP/DDP wrappers ([#19280](https://github.com/Lightning-AI/lightning/pull/19280)) + + ### Changed - `seed_everything()` without passing in a seed no longer randomly selects a seed, and now defaults to `0` ([#18846](https://github.com/Lightning-AI/lightning/pull/18846)) diff --git a/src/lightning/fabric/fabric.py b/src/lightning/fabric/fabric.py index 4afed65356e62..3eb12f2afad67 100644 --- a/src/lightning/fabric/fabric.py +++ b/src/lightning/fabric/fabric.py @@ -76,6 +76,7 @@ _FabricDataLoader, _FabricModule, _FabricOptimizer, + _to_compiled, _unwrap_compiled, _unwrap_objects, ) @@ -213,6 +214,7 @@ def setup( module: nn.Module, *optimizers: Optimizer, move_to_device: bool = True, + _reapply_compile: Optional[bool] = None, ) -> Any: # no specific return because the way we want our API to look does not play well with mypy r"""Set up a model and its optimizers for accelerated training. @@ -221,12 +223,17 @@ def setup( *optimizers: The optimizer(s) to set up (no optimizers is also possible) move_to_device: If set ``True`` (default), moves the model to the correct device. Set this to ``False`` and alternatively use :meth:`to_device` manually. + _reapply_compile: (Experimental) If set to ``True``, and the model was ``torch.compile``d before, the + corresponding :class:`~torch._dynamo.OptimizedModule` wrapper will be removed and reapplied with the + same settings after the model was set up by the strategy (e.g., after the model was wrapped by DDP, + FSDP etc.). Only supported on PyTorch >= 2.1. Defaults to ``False``, but it may change in the future. Returns: The tuple containing wrapped module and the optimizers, in the same order they were passed in. """ self._validate_setup(module, optimizers) + module, compile_kwargs = _unwrap_compiled(module) if _reapply_compile else (module, None) original_module = module module = self._precision.convert_module(module) @@ -242,6 +249,8 @@ def setup( else: module = self._strategy.setup_module(module) + if compile_kwargs is not None: + module = _to_compiled(module, compile_kwargs) module = _FabricModule(module, self._precision, original_module=original_module) # Update the _DeviceDtypeModuleMixin's device parameter @@ -258,8 +267,8 @@ def setup( self._models_setup += 1 if hasattr(original_module, "_fabric"): # this is probably a LightningModule - original_module._fabric = self # type: ignore[assignment] - original_module._fabric_optimizers = optimizers # type: ignore[assignment] + original_module._fabric = self + original_module._fabric_optimizers = optimizers if original_module not in self._callbacks: self._callbacks.append(original_module) @@ -270,7 +279,9 @@ def setup( return (module, *optimizers) return module - def setup_module(self, module: nn.Module, move_to_device: bool = True) -> _FabricModule: + def setup_module( + self, module: nn.Module, move_to_device: bool = True, _reapply_compile: Optional[bool] = None + ) -> _FabricModule: r"""Set up a model for accelerated training or inference. This is the same as calling ``.setup(model)`` with no optimizers. It is useful for inference or for certain @@ -281,12 +292,17 @@ def setup_module(self, module: nn.Module, move_to_device: bool = True) -> _Fabri module: A :class:`torch.nn.Module` to set up move_to_device: If set ``True`` (default), moves the model to the correct device. Set this to ``False`` and alternatively use :meth:`to_device` manually. + _reapply_compile: (Experimental) If set to ``True``, and the model was ``torch.compile``d before, the + corresponding :class:`~torch._dynamo.OptimizedModule` wrapper will be removed and reapplied with the + same settings after the model was set up by the strategy (e.g., after the model was wrapped by DDP, + FSDP etc.). Only supported on PyTorch >= 2.1. Defaults to ``False``, but it may change in the future. Returns: The wrapped model. """ self._validate_setup_module(module) + module, compile_kwargs = _unwrap_compiled(module) if _reapply_compile else (module, None) original_module = module module = self._precision.convert_module(module) @@ -296,6 +312,9 @@ def setup_module(self, module: nn.Module, move_to_device: bool = True) -> _Fabri # Let strategy wrap and connect the module alone module = self._strategy.setup_module(module) + + if compile_kwargs is not None: + module = _to_compiled(module, compile_kwargs) module = _FabricModule(module, self._precision, original_module=original_module) # Update the _DeviceDtypeModuleMixin's device parameter @@ -305,7 +324,7 @@ def setup_module(self, module: nn.Module, move_to_device: bool = True) -> _Fabri ) if hasattr(original_module, "_fabric"): # this is probably a LightningModule - original_module._fabric = self # type: ignore[assignment] + original_module._fabric = self if original_module not in self._callbacks: self._callbacks.append(original_module) @@ -410,6 +429,7 @@ def backward(self, tensor: Tensor, *args: Any, model: Optional[_FabricModule] = """ module = model._forward_module if model is not None else model + module, _ = _unwrap_compiled(module) if isinstance(self._strategy, DeepSpeedStrategy): if model is None: if self._models_setup == 0: @@ -641,7 +661,7 @@ def no_backward_sync(self, module: _FabricModule, enabled: bool = True) -> Conte skip. """ - module = _unwrap_compiled(module) + module, _ = _unwrap_compiled(module) if not isinstance(module, _FabricModule): raise TypeError( "You need to set up the model first before you can call `fabric.no_backward_sync()`:" @@ -656,7 +676,9 @@ def no_backward_sync(self, module: _FabricModule, enabled: bool = True) -> Conte category=PossibleUserWarning, ) return nullcontext() - return self._strategy._backward_sync_control.no_backward_sync(module._forward_module) + + forward_module, _ = _unwrap_compiled(module._forward_module) + return self._strategy._backward_sync_control.no_backward_sync(forward_module) def sharded_model(self) -> ContextManager: r"""Instantiate a model under this context manager to prepare it for model-parallel sharding. @@ -772,7 +794,7 @@ def load( # We need to unwrap objects (see above) but this creates a new dictionary. In-place updates # (for user metadata) wouldn't show up in the original dict, so we need to copy the data back. for k in list(unwrapped_state.keys()): - obj = _unwrap_compiled(state[k]) + obj, _ = _unwrap_compiled(state[k]) if isinstance(obj, (_FabricModule, _FabricOptimizer, _FabricDataLoader)): continue state[k] = unwrapped_state[k] diff --git a/src/lightning/fabric/strategies/fsdp.py b/src/lightning/fabric/strategies/fsdp.py index f429a364b5618..420d49605fcba 100644 --- a/src/lightning/fabric/strategies/fsdp.py +++ b/src/lightning/fabric/strategies/fsdp.py @@ -68,7 +68,7 @@ _TORCH_GREATER_EQUAL_2_2, ) from lightning.fabric.utilities.init import _EmptyInit -from lightning.fabric.utilities.load import _lazy_load, _materialize_tensors, _move_state_into +from lightning.fabric.utilities.load import _METADATA_FILENAME, _lazy_load, _materialize_tensors, _move_state_into from lightning.fabric.utilities.rank_zero import rank_zero_deprecation, rank_zero_only, rank_zero_warn from lightning.fabric.utilities.seed import reset_seed from lightning.fabric.utilities.types import _PATH, _Stateful @@ -86,7 +86,6 @@ _SHARDING_STRATEGY = Union[ShardingStrategy, Literal["FULL_SHARD", "SHARD_GRAD_OP", "NO_SHARD", "HYBRID_SHARD"]] _FSDP_ALIASES = ("fsdp", "fsdp_cpu_offload") -_METADATA_FILENAME = "meta.pt" class FSDPStrategy(ParallelStrategy, _Sharded): diff --git a/src/lightning/fabric/utilities/consolidate_checkpoint.py b/src/lightning/fabric/utilities/consolidate_checkpoint.py new file mode 100644 index 0000000000000..b41e8f8a1312e --- /dev/null +++ b/src/lightning/fabric/utilities/consolidate_checkpoint.py @@ -0,0 +1,79 @@ +import logging +from argparse import ArgumentParser, Namespace +from pathlib import Path + +import torch + +from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_1 +from lightning.fabric.utilities.load import _METADATA_FILENAME, _load_distributed_checkpoint + +_log = logging.getLogger(__name__) + + +def _parse_cli_args() -> Namespace: + parser = ArgumentParser( + description=( + "Converts a distributed/sharded checkpoint into a single file that can be loaded with `torch.load()`." + " Only supports FSDP sharded checkpoints at the moment." + ), + ) + parser.add_argument( + "checkpoint_folder", + type=str, + help=( + "Path to a checkpoint folder, containing the sharded checkpoint files saved using the" + " `torch.distributed.checkpoint` API." + ), + ) + parser.add_argument( + "--output_file", + type=str, + help=( + "Path to the file where the converted checkpoint should be saved. The file should not already exist." + " If no path is provided, the file will be saved next to the input checkpoint folder with the same name" + " and a '.consolidated' suffix." + ), + ) + return parser.parse_args() + + +def _process_cli_args(args: Namespace) -> Namespace: + if not _TORCH_GREATER_EQUAL_2_1: + _log.error("Processing distributed checkpoints requires PyTorch >= 2.1.") + exit(1) + + checkpoint_folder = Path(args.checkpoint_folder) + if not checkpoint_folder.exists(): + _log.error(f"The provided checkpoint folder does not exist: {checkpoint_folder}") + exit(1) + if not checkpoint_folder.is_dir(): + _log.error( + f"The provided checkpoint path must be a folder, containing the checkpoint shards: {checkpoint_folder}" + ) + exit(1) + if not (checkpoint_folder / _METADATA_FILENAME).is_file(): + _log.error( + "Only FSDP-sharded checkpoints saved with Lightning are supported for consolidation. The provided folder" + f" is not in that format: {checkpoint_folder}" + ) + exit(1) + + if args.output_file is None: + output_file = checkpoint_folder.with_suffix(checkpoint_folder.suffix + ".consolidated") + else: + output_file = Path(args.output_file) + if output_file.exists(): + _log.error( + "The path for the converted checkpoint already exists. Choose a different path by providing" + f" `--output_file` or move/delete the file first: {output_file}" + ) + exit(1) + + return Namespace(checkpoint_folder=checkpoint_folder, output_file=output_file) + + +if __name__ == "__main__": + args = _parse_cli_args() + config = _process_cli_args(args) + checkpoint = _load_distributed_checkpoint(config.checkpoint_folder) + torch.save(checkpoint, config.output_file) diff --git a/src/lightning/fabric/utilities/load.py b/src/lightning/fabric/utilities/load.py index 6925ccfabfd7a..bab8e29823903 100644 --- a/src/lightning/fabric/utilities/load.py +++ b/src/lightning/fabric/utilities/load.py @@ -15,7 +15,8 @@ import warnings from functools import partial from io import BytesIO -from typing import IO, TYPE_CHECKING, Any, Callable, Dict, Optional, OrderedDict, Sequence, Set, Union +from pathlib import Path +from typing import IO, TYPE_CHECKING, Any, Callable, Dict, Optional, OrderedDict, Sequence, Set, Tuple, Union import torch from lightning_utilities.core.apply_func import apply_to_collection @@ -24,9 +25,16 @@ from torch.nn import Parameter from typing_extensions import override -from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0 +from lightning.fabric.utilities.imports import ( + _TORCH_GREATER_EQUAL_2_0, + _TORCH_GREATER_EQUAL_2_1, + _TORCH_GREATER_EQUAL_2_2, +) from lightning.fabric.utilities.types import _PATH, _Stateful +_METADATA_FILENAME = "meta.pt" + + if TYPE_CHECKING: from torch.storage import TypedStorage @@ -227,3 +235,76 @@ def _move_state_into( destination[key].load_state_dict(state) else: destination[key] = state + + +def _load_distributed_checkpoint(checkpoint_folder: Path) -> Dict[str, Any]: + """Loads a sharded checkpoint saved with the `torch.distributed.checkpoint` into a full state dict. + + The current implementation assumes that the entire checkpoint fits in CPU memory. + + """ + if not _TORCH_GREATER_EQUAL_2_1: + raise ImportError("Processing distributed checkpoints requires PyTorch >= 2.1.") + + from torch.distributed.checkpoint import FileSystemReader + from torch.distributed.checkpoint.metadata import BytesStorageMetadata, TensorStorageMetadata + + if _TORCH_GREATER_EQUAL_2_2: + from torch.distributed.checkpoint import load + else: + from torch.distributed.checkpoint import load_state_dict as load # deprecated + + reader = FileSystemReader(checkpoint_folder) + metadata = reader.read_metadata() + + # TODO: Add sequential save to avoid storing the entire checkpoint in memory + checkpoint: Dict[str, Any] = {} + for tensor_name, sd_metadata in metadata.state_dict_metadata.items(): + if isinstance(sd_metadata, BytesStorageMetadata): + checkpoint[tensor_name] = "" + elif isinstance(sd_metadata, TensorStorageMetadata): + checkpoint[tensor_name] = torch.empty( + size=sd_metadata.size, + dtype=sd_metadata.properties.dtype, + device=torch.device("cpu"), + memory_format=sd_metadata.properties.memory_format, + layout=sd_metadata.properties.layout, + requires_grad=sd_metadata.properties.requires_grad, + pin_memory=sd_metadata.properties.pin_memory, + ) + + load(state_dict=checkpoint, storage_reader=reader, no_dist=True) + checkpoint = _unflatten_dict(checkpoint, key_map=metadata.planner_data) + + # This is the extra file saved by Fabric, with user data separate from weights and optimizer states + extra_file = checkpoint_folder / _METADATA_FILENAME + extra = torch.load(extra_file, map_location="cpu") if extra_file.is_file() else {} + checkpoint.update(extra) + + return checkpoint + + +def _unflatten_dict(checkpoint: Dict[str, Any], key_map: Dict[str, Tuple[str, ...]]) -> Dict[str, Any]: + """Converts the flat dictionary with keys 'x.y.z...' to a nested dictionary using the provided key map. + + Args: + checkpoint: The flat checkpoint dictionary. + key_map: A dictionary that maps the keys in flattened format 'x.y.z...' to a tuple representing + the index path into the nested dictonary that this function should construct. + + """ + assert checkpoint.keys() == key_map.keys() + converted: Dict[str, Any] = {} + for flat_key in checkpoint: + key_path = key_map[flat_key] + _set_nested_dict_value(converted, key_path, checkpoint[flat_key]) + return converted + + +def _set_nested_dict_value(nested_dict: Dict[str, Any], key_path: Tuple[str, ...], value: Any) -> None: + result = nested_dict + for key in key_path[:-1]: + if key not in result: + result[key] = {} + result = result[key] + result[key_path[-1]] = value diff --git a/src/lightning/fabric/utilities/throughput.py b/src/lightning/fabric/utilities/throughput.py index 6b971a58ada9d..d96673df6451d 100644 --- a/src/lightning/fabric/utilities/throughput.py +++ b/src/lightning/fabric/utilities/throughput.py @@ -597,7 +597,9 @@ def get_available_flops(device: torch.device, dtype: Union[torch.dtype, str]) -> else: from torch_xla.experimental import tpu - device_name = tpu.get_tpu_env()["TYPE"] + tpu_env = tpu.get_tpu_env() + # not all TPU generations define the "TYPE" envar. example: TYPE="V4", ACCELERATOR_TYPE="v4-8" + device_name = tpu_env.get("TYPE") or tpu_env["ACCELERATOR_TYPE"].split("-")[0] chip = device_name.lower() assert isinstance(device_name, str) if chip not in _TPU_FLOPS: diff --git a/src/lightning/fabric/wrappers.py b/src/lightning/fabric/wrappers.py index 16611eb4c754b..1aabb718d9abd 100644 --- a/src/lightning/fabric/wrappers.py +++ b/src/lightning/fabric/wrappers.py @@ -12,8 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. import inspect +from copy import deepcopy from functools import wraps -from typing import Any, Callable, Dict, Generator, Iterator, List, Mapping, Optional, TypeVar, Union, overload +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Generator, + Iterator, + List, + Mapping, + Optional, + Tuple, + TypeVar, + Union, + overload, +) import torch from lightning_utilities.core.apply_func import apply_to_collection @@ -32,6 +47,9 @@ from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0 from lightning.fabric.utilities.types import Optimizable +if TYPE_CHECKING: + from torch._dynamo import OptimizedModule + T_destination = TypeVar("T_destination", bound=Dict[str, Any]) _LIGHTNING_MODULE_STEP_METHODS = ("training_step", "validation_step", "test_step", "predict_step") @@ -285,8 +303,8 @@ def _unwrap_objects(collection: Any) -> Any: def _unwrap( obj: Union[_FabricModule, _FabricOptimizer, _FabricDataLoader] ) -> Union[nn.Module, Optimizer, DataLoader]: - if isinstance(unwrapped := _unwrap_compiled(obj), _FabricModule): - return unwrapped._forward_module + if isinstance(unwrapped := _unwrap_compiled(obj)[0], _FabricModule): + return _unwrap_compiled(unwrapped._forward_module)[0] if isinstance(obj, _FabricOptimizer): return obj.optimizer if isinstance(obj, _FabricDataLoader): @@ -302,19 +320,33 @@ def _unwrap( return apply_to_collection(collection, dtype=tuple(types), function=_unwrap) -def _unwrap_compiled(obj: Any) -> Any: +def _unwrap_compiled(obj: Union[Any, "OptimizedModule"]) -> Tuple[Union[Any, nn.Module], Optional[Dict[str, Any]]]: """Removes the :class:`torch._dynamo.OptimizedModule` around the object if it is wrapped. Use this function before instance checks against e.g. :class:`_FabricModule`. """ if not _TORCH_GREATER_EQUAL_2_0: - return obj + # obj can't be an `OptimizedModule` anyway + return obj, None + from torch._dynamo import OptimizedModule if isinstance(obj, OptimizedModule): - return obj._orig_mod - return obj + if (compile_kwargs := getattr(obj, "_compile_kwargs", None)) is None: + raise RuntimeError( + "Failed to determine the arguments that were used to compile the module. Make sure to import" + " lightning before `torch.compile` is used." + ) + return obj._orig_mod, compile_kwargs + return obj, None + + +def _to_compiled(module: nn.Module, compile_kwargs: Dict[str, Any]) -> "OptimizedModule": + if not _TORCH_GREATER_EQUAL_2_0: + raise RuntimeError("Converting to a compiled module is only supported in PyTorch >= 2.0.0") + + return torch.compile(module, **compile_kwargs) # type: ignore[return-value] def is_wrapped(obj: object) -> bool: @@ -328,5 +360,30 @@ def is_wrapped(obj: object) -> bool: obj: The object to test. """ - obj = _unwrap_compiled(obj) + obj, _ = _unwrap_compiled(obj) return isinstance(obj, (_FabricModule, _FabricOptimizer, _FabricDataLoader)) + + +def _capture_compile_kwargs(compile_fn: Callable) -> Callable: + """Wraps the ``torch.compile`` function and captures the compile arguments. + + We extract the compile arguments so that we can reapply ``torch.compile`` in ``Fabric.setup()`` with the + same arguments as the user passed to the original call. The arguments get stored in a dictionary + ``_compile_kwargs`` on the returned compiled module. + + """ + # Limitation: Currently, the global compile config does not get captured on a per-model basis. + # PyTorch will resolve this in the future: https://github.com/pytorch/pytorch/issues/116575 + + @wraps(compile_fn) + def _capture(model: Any, **kwargs: Any) -> Any: + compiled_model = compile_fn(model, **kwargs) + if isinstance(model, nn.Module): + compiled_model._compile_kwargs = deepcopy(kwargs) + return compiled_model + + return _capture + + +if _TORCH_GREATER_EQUAL_2_0: + torch.compile = _capture_compile_kwargs(torch.compile) diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index dcb056d59dfac..fb6a835509946 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -24,6 +24,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added the option `ModelCheckpoint(save_last='link')` to create a symbolic link for the 'last.ckpt' file ([#19191](https://github.com/Lightning-AI/lightning/pull/19191)) +- Added a utility function and CLI to consolidate FSDP sharded checkpoints into a single file ([#19213](https://github.com/Lightning-AI/lightning/pull/19213)) + + ### Changed - `seed_everything()` without passing in a seed no longer randomly selects a seed, and now defaults to `0` ([#18846](https://github.com/Lightning-AI/lightning/pull/18846)) @@ -77,6 +80,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed an issue with the ModelCheckpoint callback not saving relative symlinks with `ModelCheckpoint(save_last="link")` ([#19303](https://github.com/Lightning-AI/lightning/pull/19303)) +- Fixed issue where the `_restricted_classmethod_impl` would incorrectly raise a TypeError on inspection rather than on call ([#19332](https://github.com/Lightning-AI/lightning/pull/19332)) + + ## [2.1.3] - 2023-12-21 ### Changed diff --git a/src/lightning/pytorch/core/module.py b/src/lightning/pytorch/core/module.py index a46a8a2ccf8c6..acfd17c308e6b 100644 --- a/src/lightning/pytorch/core/module.py +++ b/src/lightning/pytorch/core/module.py @@ -688,9 +688,11 @@ def training_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT: Return: - :class:`~torch.Tensor` - The loss tensor - - ``dict`` - A dictionary. Can include any keys, but must include the key ``'loss'``. - - ``None`` - Skip to the next batch. This is only supported for automatic optimization. - This is not supported for multi-GPU, TPU, IPU, or DeepSpeed. + - ``dict`` - A dictionary which can include any keys, but must include the key ``'loss'`` in the case of + automatic optimization. + - ``None`` - In automatic optimization, this will skip to the next batch (but is not supported for + multi-GPU, TPU, or DeepSpeed). For manual optimization, this has no special meaning, as returning + the loss is not required. In this step you'd normally do the forward pass and calculate the loss for a batch. You can also do fancier things like multiple forward passes or something model specific. diff --git a/src/lightning/pytorch/profilers/pytorch.py b/src/lightning/pytorch/profilers/pytorch.py index 9554439da2fa9..b685fb403497a 100644 --- a/src/lightning/pytorch/profilers/pytorch.py +++ b/src/lightning/pytorch/profilers/pytorch.py @@ -240,9 +240,8 @@ def __init__( table_kwargs: Optional[Dict[str, Any]] = None, **profiler_kwargs: Any, ) -> None: - r"""This profiler uses PyTorch's Autograd Profiler and lets you inspect the cost of. - - different operators inside your model - both on the CPU and GPU + r"""This profiler uses PyTorch's Autograd Profiler and lets you inspect the cost of + different operators inside your model - both on the CPU and GPU. Args: dirpath: Directory path for the ``filename``. If ``dirpath`` is ``None`` but ``filename`` is present, the diff --git a/src/lightning/pytorch/utilities/consolidate_checkpoint.py b/src/lightning/pytorch/utilities/consolidate_checkpoint.py new file mode 100644 index 0000000000000..6f150bab0f23c --- /dev/null +++ b/src/lightning/pytorch/utilities/consolidate_checkpoint.py @@ -0,0 +1,30 @@ +import re +from typing import Any, Dict + +import torch + +from lightning.fabric.utilities.consolidate_checkpoint import _parse_cli_args, _process_cli_args +from lightning.fabric.utilities.load import _load_distributed_checkpoint + + +def _format_checkpoint(checkpoint: Dict[str, Any]) -> Dict[str, Any]: + """Converts the special FSDP checkpoint format to the standard format the Lightning Trainer can load.""" + # Rename the model key + checkpoint["state_dict"] = checkpoint.pop("model") + + optimizer_keys = [key for key in checkpoint if re.match("optimizer_[0-9]+", key)] + if not optimizer_keys: + return checkpoint + + # Optimizers are saved in special keys named `optimizer_0`, `optimizer_1`, etc. + # These need to be merged back into a Python list + checkpoint["optimizer_states"] = [checkpoint.pop(f"optimizer_{opt_idx}") for opt_idx in range(len(optimizer_keys))] + return checkpoint + + +if __name__ == "__main__": + args = _parse_cli_args() + config = _process_cli_args(args) + checkpoint = _load_distributed_checkpoint(config.checkpoint_folder) + checkpoint = _format_checkpoint(checkpoint) + torch.save(checkpoint, config.output_file) diff --git a/src/lightning/pytorch/utilities/model_helpers.py b/src/lightning/pytorch/utilities/model_helpers.py index bb307be1db88f..dde6cf2a33b02 100644 --- a/src/lightning/pytorch/utilities/model_helpers.py +++ b/src/lightning/pytorch/utilities/model_helpers.py @@ -11,10 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import functools import inspect import logging import os -from types import MethodType from typing import TYPE_CHECKING, Any, Callable, Dict, Generic, Optional, Type, TypeVar from lightning_utilities.core.imports import RequirementCache @@ -108,18 +108,23 @@ class _restricted_classmethod_impl(Generic[_T, _P, _R_co]): """Drop-in replacement for @classmethod, but raises an exception when the decorated method is called on an instance instead of a class type.""" - def __init__(self, method: Callable[Concatenate[_T, _P], _R_co]) -> None: + def __init__(self, method: Callable[Concatenate[Type[_T], _P], _R_co]) -> None: self.method = method def __get__(self, instance: Optional[_T], cls: Type[_T]) -> Callable[_P, _R_co]: - # Workaround for https://github.com/pytorch/pytorch/issues/67146 - is_scripting = any(os.path.join("torch", "jit") in frameinfo.filename for frameinfo in inspect.stack()) - if instance is not None and not is_scripting: - raise TypeError( - f"The classmethod `{cls.__name__}.{self.method.__name__}` cannot be called on an instance." - " Please call it on the class type and make sure the return value is used." - ) - return MethodType(self.method, cls) + # The wrapper ensures that the method can be inspected, but not called on an instance + @functools.wraps(self.method) + def wrapper(*args: Any, **kwargs: Any) -> _R_co: + # Workaround for https://github.com/pytorch/pytorch/issues/67146 + is_scripting = any(os.path.join("torch", "jit") in frameinfo.filename for frameinfo in inspect.stack()) + if instance is not None and not is_scripting: + raise TypeError( + f"The classmethod `{cls.__name__}.{self.method.__name__}` cannot be called on an instance." + " Please call it on the class type and make sure the return value is used." + ) + return self.method(cls, *args, **kwargs) + + return wrapper # trick static type checkers into thinking it's a @classmethod diff --git a/tests/tests_data/streaming/test_client.py b/tests/tests_data/streaming/test_client.py index 0b18a2ae98270..e4d9d80cbdbe9 100644 --- a/tests/tests_data/streaming/test_client.py +++ b/tests/tests_data/streaming/test_client.py @@ -31,12 +31,6 @@ def test_s3_client_with_cloud_space_id(monkeypatch): botocore = mock.MagicMock() monkeypatch.setattr(client, "botocore", botocore) - instance_metadata_provider = mock.MagicMock() - monkeypatch.setattr(client, "InstanceMetadataProvider", instance_metadata_provider) - - instance_metadata_fetcher = mock.MagicMock() - monkeypatch.setattr(client, "InstanceMetadataFetcher", instance_metadata_fetcher) - monkeypatch.setenv("LIGHTNING_CLOUD_SPACE_ID", "dummy") s3 = client.S3Client(1) diff --git a/tests/tests_data/streaming/test_combined.py b/tests/tests_data/streaming/test_combined.py index 1b72ac545e011..5d8d4baa16581 100644 --- a/tests/tests_data/streaming/test_combined.py +++ b/tests/tests_data/streaming/test_combined.py @@ -1,22 +1,36 @@ +import os +import sys +from unittest.mock import ANY + import pytest +import torch +from lightning.data.streaming.cache import Cache from lightning.data.streaming.combined import CombinedStreamingDataset +from lightning.data.streaming.dataloader import StreamingDataLoader +from lightning.data.streaming.dataset import Dir, StreamingDataset from torch.utils.data import IterableDataset +from torch.utils.data.dataloader import DataLoader + + +class TestCombinedStreamingDataset(CombinedStreamingDataset): + def _check_datasets(self, datasets) -> None: + pass def test_combined_dataset_num_samples_yield(): - dataset = CombinedStreamingDataset([range(10), range(0, -10, -1)], 42, weights=(0.5, 0.5)) + dataset = TestCombinedStreamingDataset([range(10), range(0, -10, -1)], 42, weights=(0.5, 0.5)) dataset_iter = iter(dataset) data = list(dataset_iter) assert data == [0, 0, 1, 2, -1, -2, -3, 3, 4, 5, 6, -4, 7, 8, -5, -6, 9, -7, -8] - dataset = CombinedStreamingDataset([range(10), range(0, -10, -1)], 37, weights=(0.5, 0.5)) + dataset = TestCombinedStreamingDataset([range(10), range(0, -10, -1)], 37, weights=(0.5, 0.5)) dataset_iter = iter(dataset) data = list(dataset_iter) assert data == [0, 0, -1, -2, -3, -4, -5, 1, -6, 2, -7, -8, 3, 4, -9, 5] - dataset = CombinedStreamingDataset([range(10), range(0, -10, -1)], 23, weights=(0.5, 0.5)) + dataset = TestCombinedStreamingDataset([range(10), range(0, -10, -1)], 23, weights=(0.5, 0.5)) dataset_iter = iter(dataset) data = [next(dataset_iter) for _ in range(5)] @@ -54,14 +68,14 @@ def load_state_dict(self, state_dict): def test_combined_dataset_state_dict(): - dataset = CombinedStreamingDataset( + dataset = TestCombinedStreamingDataset( [TestStatefulDataset(10, 1), TestStatefulDataset(10, -1)], 42, weights=(0.5, 0.5) ) assert dataset.state_dict(0, 1) == {} dataset_iter = iter(dataset) assert dataset.state_dict(0, 1) == {"0": {"counter": 0}, "1": {"counter": 0}} - dataset2 = CombinedStreamingDataset( + dataset2 = TestCombinedStreamingDataset( [TestStatefulDataset(10, 1), TestStatefulDataset(10, -1)], 42, weights=(0.5, 0.5) ) assert dataset2.state_dict(0, 1) == {} @@ -96,7 +110,7 @@ def test_combined_dataset_state_dict(): {"0": {"counter": 10}, "1": {"counter": 9}}, ] - dataset2 = CombinedStreamingDataset( + dataset2 = TestCombinedStreamingDataset( [TestStatefulDataset(10, 1), TestStatefulDataset(10, -1)], 42, weights=(0.5, 0.5) ) assert dataset2.state_dict(0, 1) == {} @@ -104,7 +118,7 @@ def test_combined_dataset_state_dict(): data_2 = [] for state in states: - dataset.load_state_dict(state) + dataset.load_state_dict({"dataset": state}) data_2.append(next(dataset2_iter)) assert data == data_2 @@ -122,7 +136,7 @@ def test_combined_dataset_state_dict(): ], ) def test_combined_dataset_normalizes_weights(weights, expected): - combined_dataset = CombinedStreamingDataset([[1], [2, 3]], weights=weights, seed=1) + combined_dataset = TestCombinedStreamingDataset([[1], [2, 3]], weights=weights, seed=1) assert combined_dataset._weights == expected @@ -135,28 +149,687 @@ def __init__(self, start, end): def __iter__(self): return iter(range(self._start, self._end)) + def state_dict(self, **kwargs): + return kwargs + + def set_epoch(self, current_epoch): + pass + def test_combined_dataset(): dataset1 = SimpleDataset(0, 10) dataset2 = SimpleDataset(10, 20) - dataset = CombinedStreamingDataset(datasets=[dataset1, dataset2], weights=[1.0, 0.0], seed=12345) + dataset = TestCombinedStreamingDataset(datasets=[dataset1, dataset2], weights=[1.0, 0.0], seed=12345) res = list(dataset) assert res == list(range(0, 10)) dataset1 = SimpleDataset(0, 10) dataset2 = SimpleDataset(10, 20) - dataset = CombinedStreamingDataset(datasets=[dataset1, dataset2], weights=[0.0, 1.0], seed=12345) + dataset = TestCombinedStreamingDataset(datasets=[dataset1, dataset2], weights=[0.0, 1.0], seed=12345) res = list(dataset) assert res == list(range(10, 20)) dataset1 = SimpleDataset(0, 10) dataset2 = SimpleDataset(10, 20) - dataset = CombinedStreamingDataset(datasets=[dataset1, dataset2], weights=[0.5, 0.5], seed=12345) + dataset = TestCombinedStreamingDataset(datasets=[dataset1, dataset2], weights=[0.5, 0.5], seed=12345) res = list(dataset) assert 9 in res or 19 in res if len(res) > 10: assert 0 in res assert 10 in res + + dataset1 = SimpleDataset(0, 10) + dataset2 = SimpleDataset(10, 20) + dataset = TestCombinedStreamingDataset(datasets=[dataset1, dataset2], weights=[0.5, 0.5], seed=12345) + dataloader = DataLoader(dataset, batch_size=2, num_workers=1) + dataloader_iter = iter(dataloader) + assert torch.equal(next(dataloader_iter), torch.Tensor([0, 1])) + + +@pytest.mark.parametrize("batch_size", [1, 2]) +def test_combined_dataset_with_dataloader_and_one_worker(batch_size): + dataset1 = SimpleDataset(0, 10) + dataset2 = SimpleDataset(10, 20) + dataset = TestCombinedStreamingDataset(datasets=[dataset1, dataset2], weights=[0.5, 0.5], seed=12345) + dataloader = StreamingDataLoader(dataset, num_workers=1, batch_size=batch_size, prefetch_factor=1) + dataloader_iter = iter(dataloader) + + if batch_size == 2: + assert torch.equal(next(dataloader_iter), torch.Tensor([0, 1])) + assert torch.equal(next(dataloader_iter), torch.Tensor([10, 2])) + assert torch.equal(next(dataloader_iter), torch.Tensor([3, 4])) + assert torch.equal(next(dataloader_iter), torch.Tensor([11, 5])) + assert torch.equal(next(dataloader_iter), torch.Tensor([6, 7])) + assert torch.equal(next(dataloader_iter), torch.Tensor([12, 8])) + + else: + assert torch.equal(next(dataloader_iter), torch.Tensor([0])) + assert torch.equal(next(dataloader_iter), torch.Tensor([1])) + assert torch.equal(next(dataloader_iter), torch.Tensor([10])) + assert torch.equal(next(dataloader_iter), torch.Tensor([2])) + assert torch.equal(next(dataloader_iter), torch.Tensor([3])) + assert torch.equal(next(dataloader_iter), torch.Tensor([4])) + assert torch.equal(next(dataloader_iter), torch.Tensor([11])) + assert torch.equal(next(dataloader_iter), torch.Tensor([5])) + assert torch.equal(next(dataloader_iter), torch.Tensor([6])) + assert torch.equal(next(dataloader_iter), torch.Tensor([7])) + assert torch.equal(next(dataloader_iter), torch.Tensor([12])) + assert torch.equal(next(dataloader_iter), torch.Tensor([8])) + + assert dataloader.state_dict() == { + "dataset": { + "0": {"num_samples_yielded": 9, "num_workers": 1, "batch_size": batch_size}, + "1": {"num_samples_yielded": 3, "num_workers": 1, "batch_size": batch_size}, + }, + "current_epoch": 0, + "latest_worker_idx": 0, + "num_samples_yielded": {0: [9, 3]}, + } + + +@pytest.mark.skipif(sys.platform == "win32" or sys.platform == "darwin", reason="too slow in CI") +def test_combined_dataset_with_dataloader_2_epochs(tmpdir): + data_dir_1 = os.path.join(tmpdir, "data_1") + data_dir_2 = os.path.join(tmpdir, "data_2") + cache_dir_1 = os.path.join(tmpdir, "cache_dir_1") + cache_dir_2 = os.path.join(tmpdir, "cache_dir_2") + + os.makedirs(data_dir_1) + os.makedirs(data_dir_2) + os.makedirs(cache_dir_1) + os.makedirs(cache_dir_2) + + cache = Cache(input_dir=str(data_dir_1), chunk_size=2) + + for i in range(10): + cache[i] = i + + cache.done() + cache.merge() + + cache = Cache(input_dir=str(data_dir_2), chunk_size=2) + + for i in range(10): + cache[i] = i + 5 + + cache.done() + cache.merge() + + dataset1 = StreamingDataset(input_dir=Dir(cache_dir_1, data_dir_1), shuffle=True) + dataset2 = StreamingDataset(input_dir=Dir(cache_dir_2, data_dir_2), shuffle=True) + dataset = CombinedStreamingDataset(datasets=[dataset1, dataset2], weights=[0.5, 0.5], seed=12345) + dataloader = StreamingDataLoader(dataset, num_workers=3, batch_size=2) + + assert dataset1.current_epoch == 1 + assert dataset2.current_epoch == 1 + + batches_1 = [] + states_1 = [] + for batch in dataloader: + batches_1.append(batch) + states_1.append(dataloader.state_dict()) + + assert dataset1.current_epoch == 1 + assert dataset2.current_epoch == 1 + + batches_2 = [] + states_2 = [] + for batch in dataloader: + batches_2.append(batch) + states_2.append(dataloader.state_dict()) + assert dataset1.current_epoch == 2 + assert dataset2.current_epoch == 2 + + assert sum(torch.equal(b1, b2) for b1, b2 in zip(batches_1, batches_2)) != len(batches_1) + + assert states_1 == [ + { + "dataset": { + "0": { + "num_samples_yielded": 2, + "num_workers": 3, + "batch_size": 2, + "current_epoch": 1, + "input_dir_path": ANY, + "input_dir_url": ANY, + "item_loader": None, + "drop_last": False, + "seed": 42, + "world_size": 1, + "shuffle": True, + }, + "1": { + "num_samples_yielded": 0, + "num_workers": 3, + "batch_size": 2, + "current_epoch": 1, + "input_dir_path": ANY, + "input_dir_url": ANY, + "item_loader": None, + "drop_last": False, + "seed": 42, + "world_size": 1, + "shuffle": True, + }, + }, + "current_epoch": 0, + "latest_worker_idx": 0, + "num_samples_yielded": {0: [2, 0]}, + }, + { + "dataset": { + "0": { + "num_samples_yielded": 4, + "num_workers": 3, + "batch_size": 2, + "current_epoch": 1, + "input_dir_path": ANY, + "input_dir_url": ANY, + "item_loader": None, + "drop_last": False, + "seed": 42, + "world_size": 1, + "shuffle": True, + }, + "1": { + "num_samples_yielded": 0, + "num_workers": 3, + "batch_size": 2, + "current_epoch": 1, + "input_dir_path": ANY, + "input_dir_url": ANY, + "item_loader": None, + "drop_last": False, + "seed": 42, + "world_size": 1, + "shuffle": True, + }, + }, + "current_epoch": 0, + "latest_worker_idx": 1, + "num_samples_yielded": {0: [2, 0], 1: [2, 0]}, + }, + { + "dataset": { + "0": { + "num_samples_yielded": 6, + "num_workers": 3, + "batch_size": 2, + "current_epoch": 1, + "input_dir_path": ANY, + "input_dir_url": ANY, + "item_loader": None, + "drop_last": False, + "seed": 42, + "world_size": 1, + "shuffle": True, + }, + "1": { + "num_samples_yielded": 0, + "num_workers": 3, + "batch_size": 2, + "current_epoch": 1, + "input_dir_path": ANY, + "input_dir_url": ANY, + "item_loader": None, + "drop_last": False, + "seed": 42, + "world_size": 1, + "shuffle": True, + }, + }, + "current_epoch": 0, + "latest_worker_idx": 2, + "num_samples_yielded": {0: [2, 0], 1: [2, 0], 2: [2, 0]}, + }, + { + "dataset": { + "0": { + "num_samples_yielded": 7, + "num_workers": 3, + "batch_size": 2, + "current_epoch": 1, + "input_dir_path": ANY, + "input_dir_url": ANY, + "item_loader": None, + "drop_last": False, + "seed": 42, + "world_size": 1, + "shuffle": True, + }, + "1": { + "num_samples_yielded": 1, + "num_workers": 3, + "batch_size": 2, + "current_epoch": 1, + "input_dir_path": ANY, + "input_dir_url": ANY, + "item_loader": None, + "drop_last": False, + "seed": 42, + "world_size": 1, + "shuffle": True, + }, + }, + "current_epoch": 0, + "latest_worker_idx": 0, + "num_samples_yielded": {0: [3, 1], 1: [2, 0], 2: [2, 0]}, + }, + { + "dataset": { + "0": { + "num_samples_yielded": 8, + "num_workers": 3, + "batch_size": 2, + "current_epoch": 1, + "input_dir_path": ANY, + "input_dir_url": ANY, + "item_loader": None, + "drop_last": False, + "seed": 42, + "world_size": 1, + "shuffle": True, + }, + "1": { + "num_samples_yielded": 2, + "num_workers": 3, + "batch_size": 2, + "current_epoch": 1, + "input_dir_path": ANY, + "input_dir_url": ANY, + "item_loader": None, + "drop_last": False, + "seed": 42, + "world_size": 1, + "shuffle": True, + }, + }, + "current_epoch": 0, + "latest_worker_idx": 1, + "num_samples_yielded": {0: [3, 1], 1: [3, 1], 2: [2, 0]}, + }, + { + "dataset": { + "0": { + "num_samples_yielded": 9, + "num_workers": 3, + "batch_size": 2, + "current_epoch": 1, + "input_dir_path": ANY, + "input_dir_url": ANY, + "item_loader": None, + "drop_last": False, + "seed": 42, + "world_size": 1, + "shuffle": True, + }, + "1": { + "num_samples_yielded": 3, + "num_workers": 3, + "batch_size": 2, + "current_epoch": 1, + "input_dir_path": ANY, + "input_dir_url": ANY, + "item_loader": None, + "drop_last": False, + "seed": 42, + "world_size": 1, + "shuffle": True, + }, + }, + "current_epoch": 0, + "latest_worker_idx": 2, + "num_samples_yielded": {0: [3, 1], 1: [3, 1], 2: [3, 1]}, + }, + { + "dataset": { + "0": { + "num_samples_yielded": 11, + "num_workers": 3, + "batch_size": 2, + "current_epoch": 1, + "input_dir_path": ANY, + "input_dir_url": ANY, + "item_loader": None, + "drop_last": False, + "seed": 42, + "world_size": 1, + "shuffle": True, + }, + "1": { + "num_samples_yielded": 3, + "num_workers": 3, + "batch_size": 2, + "current_epoch": 1, + "input_dir_path": ANY, + "input_dir_url": ANY, + "item_loader": None, + "drop_last": False, + "seed": 42, + "world_size": 1, + "shuffle": True, + }, + }, + "current_epoch": 0, + "latest_worker_idx": 0, + "num_samples_yielded": {0: [5, 1], 1: [3, 1], 2: [3, 1]}, + }, + { + "dataset": { + "0": { + "num_samples_yielded": 13, + "num_workers": 3, + "batch_size": 2, + "current_epoch": 1, + "input_dir_path": ANY, + "input_dir_url": ANY, + "item_loader": None, + "drop_last": False, + "seed": 42, + "world_size": 1, + "shuffle": True, + }, + "1": { + "num_samples_yielded": 3, + "num_workers": 3, + "batch_size": 2, + "current_epoch": 1, + "input_dir_path": ANY, + "input_dir_url": ANY, + "item_loader": None, + "drop_last": False, + "seed": 42, + "world_size": 1, + "shuffle": True, + }, + }, + "current_epoch": 0, + "latest_worker_idx": 1, + "num_samples_yielded": {0: [5, 1], 1: [5, 1], 2: [3, 1]}, + }, + ] + + assert states_2 == [ + { + "dataset": { + "0": { + "num_samples_yielded": 2, + "num_workers": 3, + "batch_size": 2, + "current_epoch": 2, + "input_dir_path": ANY, + "input_dir_url": ANY, + "item_loader": None, + "drop_last": False, + "seed": 42, + "world_size": 1, + "shuffle": True, + }, + "1": { + "num_samples_yielded": 0, + "num_workers": 3, + "batch_size": 2, + "current_epoch": 2, + "input_dir_path": ANY, + "input_dir_url": ANY, + "item_loader": None, + "drop_last": False, + "seed": 42, + "world_size": 1, + "shuffle": True, + }, + }, + "current_epoch": 1, + "latest_worker_idx": 0, + "num_samples_yielded": {0: [2, 0]}, + }, + { + "dataset": { + "0": { + "num_samples_yielded": 4, + "num_workers": 3, + "batch_size": 2, + "current_epoch": 2, + "input_dir_path": ANY, + "input_dir_url": ANY, + "item_loader": None, + "drop_last": False, + "seed": 42, + "world_size": 1, + "shuffle": True, + }, + "1": { + "num_samples_yielded": 0, + "num_workers": 3, + "batch_size": 2, + "current_epoch": 2, + "input_dir_path": ANY, + "input_dir_url": ANY, + "item_loader": None, + "drop_last": False, + "seed": 42, + "world_size": 1, + "shuffle": True, + }, + }, + "current_epoch": 1, + "latest_worker_idx": 1, + "num_samples_yielded": {0: [2, 0], 1: [2, 0]}, + }, + { + "dataset": { + "0": { + "num_samples_yielded": 6, + "num_workers": 3, + "batch_size": 2, + "current_epoch": 2, + "input_dir_path": ANY, + "input_dir_url": ANY, + "item_loader": None, + "drop_last": False, + "seed": 42, + "world_size": 1, + "shuffle": True, + }, + "1": { + "num_samples_yielded": 0, + "num_workers": 3, + "batch_size": 2, + "current_epoch": 2, + "input_dir_path": ANY, + "input_dir_url": ANY, + "item_loader": None, + "drop_last": False, + "seed": 42, + "world_size": 1, + "shuffle": True, + }, + }, + "current_epoch": 1, + "latest_worker_idx": 2, + "num_samples_yielded": {0: [2, 0], 1: [2, 0], 2: [2, 0]}, + }, + { + "dataset": { + "0": { + "num_samples_yielded": 7, + "num_workers": 3, + "batch_size": 2, + "current_epoch": 2, + "input_dir_path": ANY, + "input_dir_url": ANY, + "item_loader": None, + "drop_last": False, + "seed": 42, + "world_size": 1, + "shuffle": True, + }, + "1": { + "num_samples_yielded": 1, + "num_workers": 3, + "batch_size": 2, + "current_epoch": 2, + "input_dir_path": ANY, + "input_dir_url": ANY, + "item_loader": None, + "drop_last": False, + "seed": 42, + "world_size": 1, + "shuffle": True, + }, + }, + "current_epoch": 1, + "latest_worker_idx": 0, + "num_samples_yielded": {0: [3, 1], 1: [2, 0], 2: [2, 0]}, + }, + { + "dataset": { + "0": { + "num_samples_yielded": 8, + "num_workers": 3, + "batch_size": 2, + "current_epoch": 2, + "input_dir_path": ANY, + "input_dir_url": ANY, + "item_loader": None, + "drop_last": False, + "seed": 42, + "world_size": 1, + "shuffle": True, + }, + "1": { + "num_samples_yielded": 2, + "num_workers": 3, + "batch_size": 2, + "current_epoch": 2, + "input_dir_path": ANY, + "input_dir_url": ANY, + "item_loader": None, + "drop_last": False, + "seed": 42, + "world_size": 1, + "shuffle": True, + }, + }, + "current_epoch": 1, + "latest_worker_idx": 1, + "num_samples_yielded": {0: [3, 1], 1: [3, 1], 2: [2, 0]}, + }, + { + "dataset": { + "0": { + "num_samples_yielded": 9, + "num_workers": 3, + "batch_size": 2, + "current_epoch": 2, + "input_dir_path": ANY, + "input_dir_url": ANY, + "item_loader": None, + "drop_last": False, + "seed": 42, + "world_size": 1, + "shuffle": True, + }, + "1": { + "num_samples_yielded": 3, + "num_workers": 3, + "batch_size": 2, + "current_epoch": 2, + "input_dir_path": ANY, + "input_dir_url": ANY, + "item_loader": None, + "drop_last": False, + "seed": 42, + "world_size": 1, + "shuffle": True, + }, + }, + "current_epoch": 1, + "latest_worker_idx": 2, + "num_samples_yielded": {0: [3, 1], 1: [3, 1], 2: [3, 1]}, + }, + { + "dataset": { + "0": { + "num_samples_yielded": 11, + "num_workers": 3, + "batch_size": 2, + "current_epoch": 2, + "input_dir_path": ANY, + "input_dir_url": ANY, + "item_loader": None, + "drop_last": False, + "seed": 42, + "world_size": 1, + "shuffle": True, + }, + "1": { + "num_samples_yielded": 3, + "num_workers": 3, + "batch_size": 2, + "current_epoch": 2, + "input_dir_path": ANY, + "input_dir_url": ANY, + "item_loader": None, + "drop_last": False, + "seed": 42, + "world_size": 1, + "shuffle": True, + }, + }, + "current_epoch": 1, + "latest_worker_idx": 0, + "num_samples_yielded": {0: [5, 1], 1: [3, 1], 2: [3, 1]}, + }, + { + "dataset": { + "0": { + "num_samples_yielded": 13, + "num_workers": 3, + "batch_size": 2, + "current_epoch": 2, + "input_dir_path": ANY, + "input_dir_url": ANY, + "item_loader": None, + "drop_last": False, + "seed": 42, + "world_size": 1, + "shuffle": True, + }, + "1": { + "num_samples_yielded": 3, + "num_workers": 3, + "batch_size": 2, + "current_epoch": 2, + "input_dir_path": ANY, + "input_dir_url": ANY, + "item_loader": None, + "drop_last": False, + "seed": 42, + "world_size": 1, + "shuffle": True, + }, + }, + "current_epoch": 1, + "latest_worker_idx": 1, + "num_samples_yielded": {0: [5, 1], 1: [5, 1], 2: [3, 1]}, + }, + ] + + dataloader.load_state_dict(states_2[1]) + + assert dataloader.restore + + batches_23 = [] + states_23 = [] + for batch in dataloader: + batches_23.append(batch) + states_23.append(dataloader.state_dict()) + + assert sum(not torch.equal(b1, b2) for b1, b2 in zip(batches_2[2:], batches_23)) == 0 + assert states_23[0]["current_epoch"] == 1 + + assert not dataloader.restore diff --git a/tests/tests_data/streaming/test_data_processor.py b/tests/tests_data/streaming/test_data_processor.py index 9c8ba63391d86..33edc1d724b14 100644 --- a/tests/tests_data/streaming/test_data_processor.py +++ b/tests/tests_data/streaming/test_data_processor.py @@ -581,7 +581,7 @@ def prepare_structure(self, input_dir: str): filepaths = [os.path.join(input_dir, filename) for filename in os.listdir(input_dir)] return [filepath for filepath in filepaths if os.path.isfile(filepath)] - def prepare_item(self, output_dir: str, filepath: Any) -> None: + def prepare_item(self, filepath: Any, output_dir: str) -> None: from PIL import Image img = Image.open(filepath) @@ -628,7 +628,7 @@ def test_data_process_transform(monkeypatch, tmpdir): assert img.size == (12, 12) -def map_fn(output_dir, filepath): +def map_fn(filepath, output_dir): from PIL import Image img = Image.open(filepath) @@ -833,7 +833,7 @@ def fn(output_dir, item, device): data_recipe = LambdaDataTransformRecipe(fn, range(1)) - data_recipe.prepare_item("", 1) + data_recipe.prepare_item(1, "") assert called @@ -847,13 +847,13 @@ def test_lambda_transform_recipe_class(monkeypatch): called = False class Transform: - def __call__(self, output_dir, item, device): + def __call__(self, item, output_dir, device): nonlocal called assert device == "cuda:2" called = True data_recipe = LambdaDataTransformRecipe(Transform(), range(1)) - data_recipe.prepare_item("", 1) + data_recipe.prepare_item(1, "") assert called @@ -894,7 +894,7 @@ def test_get_item_filesizes(tmp_path): _get_item_filesizes([str(tmp_path / "empty_file")]) -def map_fn_index(output_dir, index): +def map_fn_index(index, output_dir): with open(os.path.join(output_dir, f"{index}.JPEG"), "w") as f: f.write("Hello") diff --git a/tests/tests_data/streaming/test_dataloader.py b/tests/tests_data/streaming/test_dataloader.py index 88c2f84def0ff..293a96636adae 100644 --- a/tests/tests_data/streaming/test_dataloader.py +++ b/tests/tests_data/streaming/test_dataloader.py @@ -1,5 +1,9 @@ +import os + +import pytest import torch from lightning.data.streaming import CombinedStreamingDataset, StreamingDataLoader +from lightning.data.streaming import dataloader as streaming_dataloader_module from torch import tensor @@ -29,9 +33,17 @@ def state_dict(self, *args, **kwargs): def load_state_dict(self, state_dict): self.counter = state_dict["counter"] + def set_epoch(self, current_epoch): + pass + + +class TestCombinedStreamingDataset(CombinedStreamingDataset): + def _check_datasets(self, datasets) -> None: + pass + def test_streaming_dataloader(): - dataset = CombinedStreamingDataset( + dataset = TestCombinedStreamingDataset( [TestStatefulDataset(10, 1), TestStatefulDataset(10, -1)], 42, weights=(0.5, 0.5) ) dataloader = StreamingDataLoader(dataset, batch_size=2) @@ -56,4 +68,27 @@ def test_streaming_dataloader(): for exp, gen in zip(expected, batches): assert torch.equal(exp, gen) - assert dataloader.state_dict() == {"0": {"counter": 10}, "1": {"counter": 9}} + assert dataloader.state_dict() == { + "dataset": {"0": {"counter": 10}, "1": {"counter": 9}}, + "current_epoch": 0, + "latest_worker_idx": 0, + "num_samples_yielded": {0: [11, 9]}, + } + + +@pytest.mark.parametrize("profile", [2, True]) +def test_dataloader_profiling(profile, tmpdir, monkeypatch): + monkeypatch.setattr(streaming_dataloader_module, "_VIZ_TRACKER_AVAILABLE", True) + + dataset = TestCombinedStreamingDataset( + [TestStatefulDataset(10, 1), TestStatefulDataset(10, -1)], 42, weights=(0.5, 0.5) + ) + dataloader = StreamingDataLoader( + dataset, batch_size=2, profile_bactches=profile, profile_dir=str(tmpdir), num_workers=1 + ) + dataloader_iter = iter(dataloader) + batches = [] + for batch in dataloader_iter: + batches.append(batch) + + assert os.path.exists(os.path.join(tmpdir, "result.json")) diff --git a/tests/tests_data/streaming/test_dataset.py b/tests/tests_data/streaming/test_dataset.py index f3f26595ec44f..db294f6ea564f 100644 --- a/tests/tests_data/streaming/test_dataset.py +++ b/tests/tests_data/streaming/test_dataset.py @@ -69,8 +69,10 @@ def test_streaming_dataset(tmpdir, monkeypatch): def test_should_replace_path(): assert _should_replace_path(None) assert _should_replace_path("") - assert _should_replace_path(".../datasets/...") - assert _should_replace_path(".../_connections/...") + assert not _should_replace_path(".../datasets/...") + assert not _should_replace_path(".../s3__connections/...") + assert _should_replace_path("/teamspace/datasets/...") + assert _should_replace_path("/teamspace/s3_connections/...") assert not _should_replace_path("something_else") @@ -647,21 +649,22 @@ def test_resumable_dataset_two_workers(tmpdir): _ = next(dataloader_iter) state_dict_0 = dataloader.state_dict() - assert state_dict_0["0"]["num_samples_yielded"] == 2 - assert state_dict_0["0"]["num_workers"] == 2 - assert state_dict_0["0"]["batch_size"] == 2 + + assert state_dict_0["dataset"]["num_samples_yielded"] == 2 + assert state_dict_0["dataset"]["num_workers"] == 2 + assert state_dict_0["dataset"]["batch_size"] == 2 _ = next(dataloader_iter) state_dict_1 = dataloader.state_dict() - assert state_dict_1["0"]["num_samples_yielded"] == 4 - assert state_dict_1["0"]["num_workers"] == 2 - assert state_dict_1["0"]["batch_size"] == 2 + assert state_dict_1["dataset"]["num_samples_yielded"] == 4 + assert state_dict_1["dataset"]["num_workers"] == 2 + assert state_dict_1["dataset"]["batch_size"] == 2 batch_2 = next(dataloader_iter) state_dict_2 = dataloader.state_dict() - assert state_dict_2["0"]["num_samples_yielded"] == 6 - assert state_dict_2["0"]["num_workers"] == 2 - assert state_dict_2["0"]["batch_size"] == 2 + assert state_dict_2["dataset"]["num_samples_yielded"] == 6 + assert state_dict_2["dataset"]["num_workers"] == 2 + assert state_dict_2["dataset"]["batch_size"] == 2 dataset = EmulateS3StreamingDataset( input_dir=Dir(cache_dir, data_dir), @@ -669,21 +672,17 @@ def test_resumable_dataset_two_workers(tmpdir): shuffle=True, ) - dataset.load_state_dict(state_dict_1) dataloader = StreamingDataLoader(dataset, num_workers=2, batch_size=2, prefetch_factor=1) + dataloader.load_state_dict(state_dict_1) dataloader_iter = iter(dataloader) batch_0_restart = next(dataloader_iter) - state_dict_2 = dataloader.state_dict() - assert len(state_dict_2) == 2 - assert state_dict_2["0"]["num_samples_yielded"] == 4 - assert state_dict_2["0"]["num_workers"] == 2 - assert state_dict_2["0"]["batch_size"] == 2 + state_dict_2 = dataloader.state_dict()["dataset"] - assert state_dict_2["1"]["num_samples_yielded"] == 2 - assert state_dict_2["1"]["num_workers"] == 2 - assert state_dict_2["1"]["batch_size"] == 2 + assert state_dict_2["num_samples_yielded"] == 6 + assert state_dict_2["num_workers"] == 2 + assert state_dict_2["batch_size"] == 2 assert torch.equal(batch_2, batch_0_restart) @@ -738,7 +737,7 @@ def test_resumable_dataset_two_workers_2_epochs(tmpdir): @pytest.mark.skipif(sys.platform == "win32", reason="Not tested on windows and MacOs") -def test_dataset_valid_state(tmpdir): +def test_dataset_valid_state(tmpdir, monkeypatch): seed_everything(42) data_dir = os.path.join(tmpdir, "data") @@ -776,7 +775,7 @@ def test_dataset_valid_state(tmpdir): dataset._validate_state_dict() - state_dict["0"]["drop_last"] = True + state_dict["drop_last"] = True dataset.load_state_dict(state_dict) with pytest.raises( ValueError, @@ -784,7 +783,7 @@ def test_dataset_valid_state(tmpdir): ): dataset._validate_state_dict() - state_dict["0"]["item_loader"] = {} + state_dict["item_loader"] = {} dataset.load_state_dict(state_dict) with pytest.raises( ValueError, @@ -792,7 +791,7 @@ def test_dataset_valid_state(tmpdir): ): dataset._validate_state_dict() - state_dict["0"]["seed"] = 12 + state_dict["seed"] = 12 dataset.load_state_dict(state_dict) with pytest.raises( ValueError, @@ -800,7 +799,7 @@ def test_dataset_valid_state(tmpdir): ): dataset._validate_state_dict() - state_dict["0"]["input_dir_url"] = "toto" + state_dict["input_dir_url"] = "toto" dataset.load_state_dict(state_dict) with pytest.raises( ValueError, @@ -808,7 +807,7 @@ def test_dataset_valid_state(tmpdir): ): dataset._validate_state_dict() - state_dict["0"]["input_dir_path"] = "toto" + state_dict["input_dir_path"] = "toto" dataset.load_state_dict(state_dict) with pytest.raises( ValueError, @@ -816,7 +815,15 @@ def test_dataset_valid_state(tmpdir): ): dataset._validate_state_dict() - state_dict["0"]["num_workers"] = "8" + state_dict["input_dir_path"] = "/teamspace/datasets/coco" + dataset.load_state_dict(state_dict) + with pytest.raises( + ValueError, + match=f"The provided `input_dir` path state doesn't match the current one. Found `{cache_dir}` instead of ", # noqa E501 + ): + dataset._validate_state_dict() + + state_dict["num_workers"] = "8" dataset.load_state_dict(state_dict) with pytest.raises( ValueError, @@ -824,7 +831,7 @@ def test_dataset_valid_state(tmpdir): ): dataset._validate_state_dict() - state_dict["0"]["shuffle"] = True + state_dict["shuffle"] = True dataset.load_state_dict(state_dict) with pytest.raises( ValueError, diff --git a/tests/tests_data/streaming/test_downloader.py b/tests/tests_data/streaming/test_downloader.py new file mode 100644 index 0000000000000..3d4e5421711c2 --- /dev/null +++ b/tests/tests_data/streaming/test_downloader.py @@ -0,0 +1,13 @@ +import os +from unittest.mock import MagicMock + +from lightning.data.streaming.downloader import S3Downloader, subprocess + + +def test_s3_downloader_fast(tmpdir, monkeypatch): + monkeypatch.setattr(os, "system", MagicMock(return_value=0)) + popen_mock = MagicMock() + monkeypatch.setattr(subprocess, "Popen", MagicMock(return_value=popen_mock)) + downloader = S3Downloader(tmpdir, tmpdir, []) + downloader.download_file("s3://random_bucket/a.txt", os.path.join(tmpdir, "a.txt")) + popen_mock.wait.assert_called() diff --git a/tests/tests_data/streaming/test_functions.py b/tests/tests_data/streaming/test_functions.py index b5d1b737b5c43..10bf40caf7c2f 100644 --- a/tests/tests_data/streaming/test_functions.py +++ b/tests/tests_data/streaming/test_functions.py @@ -1,8 +1,10 @@ +import os import sys from unittest import mock import pytest -from lightning.data.streaming.functions import _get_input_dir, os +from lightning.data import walk +from lightning.data.streaming.functions import _get_input_dir @pytest.mark.skipif(sys.platform == "win32", reason="currently not supported for windows.") @@ -19,3 +21,17 @@ def fn(*_, **__): with pytest.raises(ValueError, match="The provided item didn't contain any filepaths."): assert _get_input_dir(["", "/teamspace/studios/asd/b"]) + + +def test_walk(tmpdir): + for i in range(5): + folder_path = os.path.join(tmpdir, str(i)) + os.makedirs(folder_path, exist_ok=True) + for j in range(5): + filepath = os.path.join(folder_path, f"{j}.txt") + with open(filepath, "w") as f: + f.write("hello world !") + + walks_os = sorted(os.walk(tmpdir)) + walks_function = sorted(walk(tmpdir)) + assert walks_os == walks_function diff --git a/tests/tests_fabric/plugins/precision/test_amp_integration.py b/tests/tests_fabric/plugins/precision/test_amp_integration.py index cc9c687487f92..c4f4a856007bc 100644 --- a/tests/tests_fabric/plugins/precision/test_amp_integration.py +++ b/tests/tests_fabric/plugins/precision/test_amp_integration.py @@ -12,10 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. """Integration tests for Automatic Mixed Precision (AMP) training.""" +import sys + import pytest import torch import torch.nn as nn from lightning.fabric import Fabric, seed_everything +from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2 from tests_fabric.helpers.runif import RunIf @@ -37,6 +40,11 @@ def forward(self, x): return output +@pytest.mark.xfail( + # https://github.com/pytorch/pytorch/issues/116056 + sys.platform == "win32" and _TORCH_GREATER_EQUAL_2_2, + reason="Windows + DDP issue in PyTorch 2.2", +) @pytest.mark.parametrize( ("accelerator", "precision", "expected_dtype"), [ diff --git a/tests/tests_fabric/strategies/launchers/test_multiprocessing_integration.py b/tests/tests_fabric/strategies/launchers/test_multiprocessing_integration.py index ee1bee5cc5c81..c986ee0db91bb 100644 --- a/tests/tests_fabric/strategies/launchers/test_multiprocessing_integration.py +++ b/tests/tests_fabric/strategies/launchers/test_multiprocessing_integration.py @@ -11,10 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import sys + import pytest import torch import torch.nn as nn from lightning.fabric import Fabric +from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2 from tests_fabric.helpers.runif import RunIf @@ -28,6 +31,11 @@ def __init__(self): self.register_buffer("buffer", torch.ones(3)) +@pytest.mark.xfail( + # https://github.com/pytorch/pytorch/issues/116056 + sys.platform == "win32" and _TORCH_GREATER_EQUAL_2_2, + reason="Windows + DDP issue in PyTorch 2.2", +) @pytest.mark.parametrize("strategy", ["ddp_spawn", pytest.param("ddp_fork", marks=RunIf(skip_windows=True))]) def test_memory_sharing_disabled(strategy): """Test that the multiprocessing launcher disables memory sharing on model parameters and buffers to avoid race diff --git a/tests/tests_fabric/strategies/test_ddp_integration.py b/tests/tests_fabric/strategies/test_ddp_integration.py index a0c810669d8da..251d1870d90ac 100644 --- a/tests/tests_fabric/strategies/test_ddp_integration.py +++ b/tests/tests_fabric/strategies/test_ddp_integration.py @@ -11,16 +11,28 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os +import sys from copy import deepcopy +from unittest import mock +from unittest.mock import Mock import pytest import torch from lightning.fabric import Fabric +from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0, _TORCH_GREATER_EQUAL_2_2 +from torch.nn.parallel.distributed import DistributedDataParallel from tests_fabric.helpers.runif import RunIf from tests_fabric.strategies.test_single_device import _run_test_clip_gradients +from tests_fabric.test_fabric import BoringModel +@pytest.mark.xfail( + # https://github.com/pytorch/pytorch/issues/116056 + sys.platform == "win32" and _TORCH_GREATER_EQUAL_2_2, + reason="Windows + DDP issue in PyTorch 2.2", +) @pytest.mark.parametrize( "accelerator", [ @@ -64,6 +76,40 @@ def assert_params_equal(params0, params1): assert_params_equal(params_before, wrapped_model.parameters()) +@RunIf(min_cuda_gpus=2, standalone=True, min_torch="2.1.0", dynamo=True) +@mock.patch( + "lightning.fabric.wrappers.torch.compile", + Mock(wraps=(torch.compile if _TORCH_GREATER_EQUAL_2_0 else None)), +) +@mock.patch.dict(os.environ, {}) +def test_reapply_compile(): + """Test that Fabric can rewrap a compiled module such that compilation happens over the DDP-wrapper.""" + from torch._dynamo import OptimizedModule + + fabric = Fabric(accelerator="cuda", devices=2, strategy="ddp") + fabric.launch() + + model = BoringModel() + compile_kwargs = {"mode": "reduce-overhead"} + compiled_model = torch.compile(model, **compile_kwargs) + torch.compile.reset_mock() + + fabric_model = fabric.setup(compiled_model, _reapply_compile=True) + + assert isinstance(fabric_model._forward_module, OptimizedModule) + assert isinstance(fabric_model._forward_module._orig_mod, DistributedDataParallel) + # Assert we called compile again with the same arguments, but on the DDP-wrapped module + torch.compile.assert_called_with(fabric_model._forward_module._orig_mod, **compile_kwargs) + + assert fabric_model._original_module == model + assert fabric_model._forward_module._orig_mod.module == model + assert fabric_model.device == fabric.device + + # Smoke-testing forward to ensure we don't get compilation errors + for _ in range(3): + fabric_model(torch.randn(2, 32, device=fabric.device)).sum().backward() + + @pytest.mark.parametrize( ("clip_type", "accelerator", "precision"), [ diff --git a/tests/tests_fabric/strategies/test_fsdp_integration.py b/tests/tests_fabric/strategies/test_fsdp_integration.py index dc7b18095e0c3..704a86a997bc5 100644 --- a/tests/tests_fabric/strategies/test_fsdp_integration.py +++ b/tests/tests_fabric/strategies/test_fsdp_integration.py @@ -15,13 +15,18 @@ from copy import deepcopy from pathlib import Path from unittest import mock +from unittest.mock import Mock import pytest import torch from lightning.fabric import Fabric from lightning.fabric.plugins import FSDPPrecision from lightning.fabric.strategies import FSDPStrategy -from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0, _TORCH_GREATER_EQUAL_2_1 +from lightning.fabric.utilities.imports import ( + _TORCH_GREATER_EQUAL_2_0, + _TORCH_GREATER_EQUAL_2_1, +) +from lightning.fabric.utilities.load import _load_distributed_checkpoint from lightning.fabric.wrappers import _FabricOptimizer from torch.distributed.fsdp import FlatParameter, FullyShardedDataParallel, OptimStateKeyType from torch.distributed.fsdp.wrap import always_wrap_policy, wrap @@ -344,33 +349,40 @@ def test_setup_with_orig_params_and_multiple_param_groups(): assert not isinstance(layer.weight, FlatParameter) -@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True, dynamo=True) -@mock.patch.dict(os.environ, {}) -@pytest.mark.parametrize( - "compile_after_setup", - [ - False, - # https://github.com/pytorch/pytorch/issues/97811 - pytest.param(True, marks=RunIf(min_python="3.9")), - ], +@RunIf(min_cuda_gpus=2, standalone=True, min_torch="2.1.0", dynamo=True, skip_windows=True) +@mock.patch( + "lightning.fabric.wrappers.torch.compile", + Mock(wraps=(torch.compile if _TORCH_GREATER_EQUAL_2_0 else None)), ) -def test_compile(compile_after_setup): - """Test that the model can be compiled before and after the model is wrapped in FSDP.""" - model = BoringModel() +@mock.patch.dict(os.environ, {}) +def test_reapply_compile(): + """Test that Fabric can rewrap a compiled module such that compilation happens over the FSDP-wrapper.""" + from torch._dynamo import OptimizedModule + strategy = FSDPStrategy(auto_wrap_policy=always_wrap_policy) fabric = Fabric(accelerator="cuda", devices=2, strategy=strategy) fabric.launch() - if not compile_after_setup: - model = torch.compile(model) + model = BoringModel() + compile_kwargs = {"mode": "reduce-overhead"} + compiled_model = torch.compile(model, **compile_kwargs) + torch.compile.reset_mock() + + fabric_model = fabric.setup(compiled_model, _reapply_compile=True) - model = fabric.setup(model) + assert isinstance(fabric_model._forward_module, OptimizedModule) + assert isinstance(fabric_model._forward_module._orig_mod, FullyShardedDataParallel) - if compile_after_setup: - model = torch.compile(model) + # Assert we called compile again with the same arguments, but on the FSDP-wrapped module + torch.compile.assert_called_with(fabric_model._forward_module._orig_mod, **compile_kwargs) + assert fabric_model._original_module == model + assert fabric_model._forward_module._orig_mod.module == model + assert fabric_model.device == fabric.device + + # Smoke-testing forward to ensure we don't get compilation errors for _ in range(3): - model(torch.rand(2, 32, device=fabric.device)).sum().backward() + fabric_model(torch.randn(2, 32, device=fabric.device)).sum().backward() @RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True) @@ -549,3 +561,52 @@ def test_clip_gradients(clip_type, precision): optimizer.step() optimizer.zero_grad() + + +# TODO: Support checkpoint consolidation with PyTorch >= 2.2 +@RunIf(min_cuda_gpus=2, standalone=True, min_torch="2.1.0", max_torch="2.2.0") +def test_save_sharded_and_consolidate_and_load(tmp_path): + """Test the consolidation of a FSDP-sharded checkpoint into a single file.""" + + fabric = Fabric( + accelerator="cuda", + strategy=FSDPStrategy(auto_wrap_policy=always_wrap_policy, state_dict_type="sharded"), + devices=2, + ) + fabric.launch() + + model = BoringModel() + optimizer = torch.optim.Adam(model.parameters()) + model, optimizer = fabric.setup(model, optimizer) + state = {"model": model, "optimizer": optimizer, "steps": 1} + + # run one iteration to init the state of the optimizer + model(torch.rand(1, 32, device=fabric.device)).sum().backward() + optimizer.step() + + checkpoint_path_sharded = fabric.broadcast(str(tmp_path / "checkpoint_sharded")) + fabric.save(checkpoint_path_sharded, state) + assert set(os.listdir(checkpoint_path_sharded)) == {"meta.pt", ".metadata", "__0_0.distcp", "__1_0.distcp"} + + # consolidate the checkpoint to a single file + checkpoint_path_full = fabric.broadcast(str(tmp_path / "checkpoint_full.pt")) + if fabric.global_rank == 0: + checkpoint = _load_distributed_checkpoint(Path(checkpoint_path_sharded)) + torch.save(checkpoint, checkpoint_path_full) + fabric.barrier() + + # re-init and load from full checkpoint + fabric = Fabric( + accelerator="cuda", + strategy=FSDPStrategy(auto_wrap_policy=always_wrap_policy), + devices=2, + ) + + # Hack: we already called launch() on another Fabric instance above + fabric._launched = True + + model = BoringModel() + optimizer = torch.optim.Adam(model.parameters()) + model, optimizer = fabric.setup(model, optimizer) + state = {"model": model, "optimizer": optimizer, "steps": 1} + fabric.load(checkpoint_path_full, state) diff --git a/tests/tests_fabric/test_fabric.py b/tests/tests_fabric/test_fabric.py index b9e7a9a3430d5..2c860eb45d78d 100644 --- a/tests/tests_fabric/test_fabric.py +++ b/tests/tests_fabric/test_fabric.py @@ -89,18 +89,28 @@ def test_setup_module(ddp_mock, setup_method): @RunIf(skip_windows=True, dynamo=True) @pytest.mark.parametrize("setup_method", ["setup", "setup_module"]) -def test_setup_compiled_module(setup_method): +@pytest.mark.parametrize("reapply_compile", [True, False, None]) +def test_setup_compiled_module(reapply_compile, setup_method): """Test that an `OptimizedModule` can be passed to the setup method.""" from torch._dynamo.eval_frame import OptimizedModule fabric = Fabric(devices=1) model = nn.Linear(1, 2) compiled_model = torch.compile(model) + assert compiled_model._compile_kwargs is not None assert isinstance(compiled_model, OptimizedModule) setup_method = getattr(fabric, setup_method) - fabric_model = setup_method(compiled_model) - - assert fabric_model.module == compiled_model + fabric_model = setup_method(compiled_model, _reapply_compile=reapply_compile) + + assert isinstance(fabric_model._forward_module, OptimizedModule) + if reapply_compile: + # The forward_module got rewrapped into a new OptimizedModule + assert fabric_model._forward_module != fabric_model._original_module + # The original_module points to the pure module + assert fabric_model._original_module is model + assert fabric_model._forward_module._orig_mod is model + else: + assert fabric_model._forward_module is fabric_model._original_module # Attributes get passed through assert fabric_model.weight is model.weight diff --git a/tests/tests_fabric/test_wrappers.py b/tests/tests_fabric/test_wrappers.py index 820d9032f5b99..09d731c064d39 100644 --- a/tests/tests_fabric/test_wrappers.py +++ b/tests/tests_fabric/test_wrappers.py @@ -24,6 +24,7 @@ _FabricDataLoader, _FabricModule, _FabricOptimizer, + _unwrap_compiled, _unwrap_objects, is_wrapped, ) @@ -593,3 +594,23 @@ def normal_method(self): fabric_module = _FabricModule(forward_module=original_module, precision=Mock(), original_module=original_module) assert fabric_module.training_step == original_module.training_step assert fabric_module.validation_step == original_module.validation_step + + +@RunIf(dynamo=True) +def test_unwrap_compiled(): + model = torch.nn.Linear(1, 1) + + with mock.patch("lightning.fabric.wrappers", "_TORCH_GREATER_EQUAL_2_0", False): + unwrapped, compile_kwargs = _unwrap_compiled(model) + assert unwrapped is model + assert compile_kwargs is None + + compiled = torch.compile(model, fullgraph=True, dynamic=True, disable=False) + assert compiled._compile_kwargs == {"fullgraph": True, "dynamic": True, "disable": False} + unwrapped, compile_kwargs = _unwrap_compiled(compiled) + assert unwrapped is compiled._orig_mod + assert compile_kwargs == {"fullgraph": True, "dynamic": True, "disable": False} + + del compiled._compile_kwargs + with pytest.raises(RuntimeError, match="Failed to determine the arguments that were used to compile the module"): + _unwrap_compiled(compiled) diff --git a/tests/tests_fabric/utilities/test_consolidate_checkpoint.py b/tests/tests_fabric/utilities/test_consolidate_checkpoint.py new file mode 100644 index 0000000000000..16feb3d3c1014 --- /dev/null +++ b/tests/tests_fabric/utilities/test_consolidate_checkpoint.py @@ -0,0 +1,97 @@ +# Copyright The Lightning AI team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +from argparse import Namespace +from pathlib import Path +from unittest import mock + +import lightning.fabric +import pytest +from lightning.fabric.utilities.consolidate_checkpoint import _parse_cli_args, _process_cli_args +from lightning.fabric.utilities.load import _METADATA_FILENAME + + +@pytest.mark.parametrize( + ("args", "expected"), + [ + (["path/to/checkpoint"], {"checkpoint_folder": "path/to/checkpoint", "output_file": None}), + ( + ["path/to/checkpoint", "--output_file", "path/to/output"], + {"checkpoint_folder": "path/to/checkpoint", "output_file": "path/to/output"}, + ), + ], +) +def test_parse_cli_args(args, expected): + with mock.patch("sys.argv", ["any.py", *args]): + args = _parse_cli_args() + assert vars(args) == expected + + +def test_process_cli_args(tmp_path, caplog, monkeypatch): + # PyTorch version < 2.1 + monkeypatch.setattr(lightning.fabric.utilities.consolidate_checkpoint, "_TORCH_GREATER_EQUAL_2_1", False) + with caplog.at_level(logging.ERROR, logger="lightning.fabric.utilities.consolidate_checkpoint"), pytest.raises( + SystemExit + ): + _process_cli_args(Namespace()) + assert "requires PyTorch >= 2.1." in caplog.text + caplog.clear() + monkeypatch.setattr(lightning.fabric.utilities.consolidate_checkpoint, "_TORCH_GREATER_EQUAL_2_1", True) + + # Checkpoint does not exist + checkpoint_folder = Path("does/not/exist") + with caplog.at_level(logging.ERROR, logger="lightning.fabric.utilities.consolidate_checkpoint"), pytest.raises( + SystemExit + ): + _process_cli_args(Namespace(checkpoint_folder=checkpoint_folder)) + assert f"checkpoint folder does not exist: {checkpoint_folder}" in caplog.text + caplog.clear() + + # Checkpoint exists but is not a folder + file = tmp_path / "checkpoint_file" + file.touch() + with caplog.at_level(logging.ERROR, logger="lightning.fabric.utilities.consolidate_checkpoint"), pytest.raises( + SystemExit + ): + _process_cli_args(Namespace(checkpoint_folder=file)) + assert "checkpoint path must be a folder" in caplog.text + caplog.clear() + + # Checkpoint exists but is not an FSDP checkpoint + folder = tmp_path / "checkpoint_folder" + folder.mkdir() + with caplog.at_level(logging.ERROR, logger="lightning.fabric.utilities.consolidate_checkpoint"), pytest.raises( + SystemExit + ): + _process_cli_args(Namespace(checkpoint_folder=folder)) + assert "Only FSDP-sharded checkpoints saved with Lightning are supported" in caplog.text + caplog.clear() + + # Checkpoint is a FSDP folder, output file not specified + (folder / _METADATA_FILENAME).touch() + config = _process_cli_args(Namespace(checkpoint_folder=folder, output_file=None)) + assert vars(config) == { + "checkpoint_folder": folder, + "output_file": folder.with_suffix(folder.suffix + ".consolidated"), + } + + # Checkpoint is a FSDP folder, output file already exists + file = tmp_path / "ouput_file" + file.touch() + with caplog.at_level(logging.ERROR, logger="lightning.fabric.utilities.consolidate_checkpoint"), pytest.raises( + SystemExit + ): + _process_cli_args(Namespace(checkpoint_folder=folder, output_file=file)) + assert "path for the converted checkpoint already exists" in caplog.text + caplog.clear() diff --git a/tests/tests_fabric/utilities/test_distributed.py b/tests/tests_fabric/utilities/test_distributed.py index ad32c936a1c4d..4badff47403eb 100644 --- a/tests/tests_fabric/utilities/test_distributed.py +++ b/tests/tests_fabric/utilities/test_distributed.py @@ -1,5 +1,6 @@ import functools import os +import sys from functools import partial from pathlib import Path from unittest import mock @@ -17,6 +18,7 @@ _sync_ddp, is_shared_filesystem, ) +from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2 from tests_fabric.helpers.runif import RunIf @@ -118,6 +120,11 @@ def test_collective_operations(devices, process): spawn_launch(process, devices) +@pytest.mark.xfail( + # https://github.com/pytorch/pytorch/issues/116056 + sys.platform == "win32" and _TORCH_GREATER_EQUAL_2_2, + reason="Windows + DDP issue in PyTorch 2.2", +) @pytest.mark.flaky(reruns=3) # flaky with "process 0 terminated with signal SIGABRT" (GLOO) def test_is_shared_filesystem(tmp_path, monkeypatch): # In the non-distributed case, every location is interpreted as 'shared' diff --git a/tests/tests_fabric/utilities/test_load.py b/tests/tests_fabric/utilities/test_load.py index 3dccf2258e74b..eb534bf1cdca1 100644 --- a/tests/tests_fabric/utilities/test_load.py +++ b/tests/tests_fabric/utilities/test_load.py @@ -14,7 +14,13 @@ import pytest import torch import torch.nn as nn -from lightning.fabric.utilities.load import _lazy_load, _materialize_tensors, _move_state_into, _NotYetLoadedTensor +from lightning.fabric.utilities.load import ( + _lazy_load, + _materialize_tensors, + _move_state_into, + _NotYetLoadedTensor, + _unflatten_dict, +) from tests_fabric.helpers.runif import RunIf @@ -139,3 +145,24 @@ def load_state_dict(self, state_dict): assert source == {} assert destination["cocofruit"] == 2 assert destination["banana"].count == 100 + + +def test_unflatten_dict(): + assert _unflatten_dict({}, {}) == {} + + tensor0 = torch.rand(2, 2) + tensor1 = torch.tensor(3.0) + data = { + "model.layer.weight": tensor0, + "optimizer.state.layer.weight.exp_avg": {"test": tensor1}, + "optimizer.param_groups": "param_groups", + } + key_map = { + "model.layer.weight": ("model", "layer.weight"), + "optimizer.state.layer.weight.exp_avg": ("optimizer", "state", "layer.weight", "exp_avg"), + "optimizer.param_groups": ("optimizer", "param_groups"), + } + assert _unflatten_dict(data, key_map) == { + "model": {"layer.weight": tensor0}, + "optimizer": {"state": {"layer.weight": {"exp_avg": {"test": tensor1}}}, "param_groups": "param_groups"}, + } diff --git a/tests/tests_fabric/utilities/test_spike.py b/tests/tests_fabric/utilities/test_spike.py index 9739540af7f18..ab7d9f474383d 100644 --- a/tests/tests_fabric/utilities/test_spike.py +++ b/tests/tests_fabric/utilities/test_spike.py @@ -4,6 +4,7 @@ import pytest import torch from lightning.fabric import Fabric +from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2 from lightning.fabric.utilities.spike import _TORCHMETRICS_GREATER_EQUAL_1_0_0, SpikeDetection, TrainingSpikeException @@ -28,6 +29,11 @@ def spike_detection_test(fabric, global_rank_spike, spike_value, should_raise): ) +@pytest.mark.xfail( + # https://github.com/pytorch/pytorch/issues/116056 + sys.platform == "win32" and _TORCH_GREATER_EQUAL_2_2, + reason="Windows + DDP issue in PyTorch 2.2", +) @pytest.mark.flaky(max_runs=3) @pytest.mark.parametrize( ("global_rank_spike", "num_devices", "spike_value", "finite_only"), diff --git a/tests/tests_fabric/utilities/test_throughput.py b/tests/tests_fabric/utilities/test_throughput.py index 2a66ce25e9861..f2c3de30a325b 100644 --- a/tests/tests_fabric/utilities/test_throughput.py +++ b/tests/tests_fabric/utilities/test_throughput.py @@ -49,8 +49,8 @@ def test_get_available_flops(xla_available): from torch_xla.experimental import tpu assert isinstance(tpu, Mock) - tpu.get_tpu_env.return_value = {"TYPE": "V4"} + tpu.get_tpu_env.return_value = {"TYPE": "V4"} flops = get_available_flops(torch.device("xla"), torch.bfloat16) assert flops == 275e12 @@ -58,6 +58,10 @@ def test_get_available_flops(xla_available): with pytest.warns(match="not found for TPU 'V1'"): assert get_available_flops(torch.device("xla"), torch.bfloat16) is None + tpu.get_tpu_env.return_value = {"ACCELERATOR_TYPE": "v3-8"} + flops = get_available_flops(torch.device("xla"), torch.bfloat16) + assert flops == 123e12 + tpu.reset_mock() diff --git a/tests/tests_pytorch/callbacks/test_spike.py b/tests/tests_pytorch/callbacks/test_spike.py index f4d0c946cefa0..b3bc54bd6e836 100644 --- a/tests/tests_pytorch/callbacks/test_spike.py +++ b/tests/tests_pytorch/callbacks/test_spike.py @@ -3,6 +3,7 @@ import pytest import torch +from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2 from lightning.fabric.utilities.spike import _TORCHMETRICS_GREATER_EQUAL_1_0_0, TrainingSpikeException from lightning.pytorch import LightningModule, Trainer from lightning.pytorch.callbacks.spike import SpikeDetection @@ -46,6 +47,11 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx) +@pytest.mark.xfail( + # https://github.com/pytorch/pytorch/issues/116056 + sys.platform == "win32" and _TORCH_GREATER_EQUAL_2_2, + reason="Windows + DDP issue in PyTorch 2.2", +) @pytest.mark.flaky(max_runs=3) @pytest.mark.parametrize( ("global_rank_spike", "num_devices", "spike_value", "finite_only"), diff --git a/tests/tests_pytorch/loops/test_prediction_loop.py b/tests/tests_pytorch/loops/test_prediction_loop.py index 2f4b281836510..5f6b9d6328d55 100644 --- a/tests/tests_pytorch/loops/test_prediction_loop.py +++ b/tests/tests_pytorch/loops/test_prediction_loop.py @@ -12,8 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. import itertools +import sys import pytest +from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2 from lightning.pytorch import LightningModule, Trainer from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset from lightning.pytorch.overrides.distributed import _IndexBatchSamplerWrapper @@ -50,6 +52,11 @@ def predict_step(self, batch, batch_idx): assert trainer.predict_loop.predictions == [] +@pytest.mark.xfail( + # https://github.com/pytorch/pytorch/issues/116056 + sys.platform == "win32" and _TORCH_GREATER_EQUAL_2_2, + reason="Windows + DDP issue in PyTorch 2.2", +) @pytest.mark.parametrize("use_distributed_sampler", [False, True]) def test_prediction_loop_batch_sampler_set_epoch_called(tmp_path, use_distributed_sampler): """Tests that set_epoch is called on the dataloader's batch sampler (if any) during prediction.""" diff --git a/tests/tests_pytorch/models/test_amp.py b/tests/tests_pytorch/models/test_amp.py index d51b139d6118d..5608c02026072 100644 --- a/tests/tests_pytorch/models/test_amp.py +++ b/tests/tests_pytorch/models/test_amp.py @@ -12,11 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. import os +import sys from unittest import mock import pytest import torch from lightning.fabric.plugins.environments import SLURMEnvironment +from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2 from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset from torch.utils.data import DataLoader @@ -53,7 +55,16 @@ def _assert_autocast_enabled(self): [ ("single_device", "16-mixed", 1), ("single_device", "bf16-mixed", 1), - ("ddp_spawn", "16-mixed", 2), + pytest.param( + "ddp_spawn", + "16-mixed", + 2, + marks=pytest.mark.xfail( + # https://github.com/pytorch/pytorch/issues/116056 + sys.platform == "win32" and _TORCH_GREATER_EQUAL_2_2, + reason="Windows + DDP issue in PyTorch 2.2", + ), + ), pytest.param("ddp_spawn", "bf16-mixed", 2, marks=RunIf(skip_windows=True)), ], ) diff --git a/tests/tests_pytorch/serve/test_servable_module_validator.py b/tests/tests_pytorch/serve/test_servable_module_validator.py index 578619ece75eb..9d299da460770 100644 --- a/tests/tests_pytorch/serve/test_servable_module_validator.py +++ b/tests/tests_pytorch/serve/test_servable_module_validator.py @@ -1,7 +1,9 @@ +import sys from typing import Dict import pytest import torch +from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2 from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.serve.servable_module_validator import ServableModule, ServableModuleValidator @@ -36,6 +38,11 @@ def test_servable_module_validator(): callback.on_train_start(Trainer(accelerator="cpu"), model) +@pytest.mark.xfail( + # https://github.com/pytorch/pytorch/issues/116056 + sys.platform == "win32" and _TORCH_GREATER_EQUAL_2_2, + reason="Windows + DDP issue in PyTorch 2.2", +) @pytest.mark.flaky(reruns=3) def test_servable_module_validator_with_trainer(tmpdir): callback = ServableModuleValidator() diff --git a/tests/tests_pytorch/strategies/launchers/test_multiprocessing.py b/tests/tests_pytorch/strategies/launchers/test_multiprocessing.py index beab1e20f46b3..d70a4776b81bd 100644 --- a/tests/tests_pytorch/strategies/launchers/test_multiprocessing.py +++ b/tests/tests_pytorch/strategies/launchers/test_multiprocessing.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import os +import sys from multiprocessing import Process from unittest import mock from unittest.mock import ANY, Mock, call, patch @@ -19,6 +20,7 @@ import pytest import torch from lightning.fabric.plugins import ClusterEnvironment +from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2 from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.strategies import DDPStrategy @@ -194,6 +196,11 @@ def on_fit_start(self) -> None: assert torch.equal(self.layer.weight.data, self.tied_layer.weight.data) +@pytest.mark.xfail( + # https://github.com/pytorch/pytorch/issues/116056 + sys.platform == "win32" and _TORCH_GREATER_EQUAL_2_2, + reason="Windows + DDP issue in PyTorch 2.2", +) def test_memory_sharing_disabled(): """Test that the multiprocessing launcher disables memory sharing on model parameters and buffers to avoid race conditions on model updates.""" @@ -214,6 +221,11 @@ def test_check_for_missing_main_guard(): launcher.launch(function=Mock()) +@pytest.mark.xfail( + # https://github.com/pytorch/pytorch/issues/116056 + sys.platform == "win32" and _TORCH_GREATER_EQUAL_2_2, + reason="Windows + DDP issue in PyTorch 2.2", +) def test_fit_twice_raises(): model = BoringModel() trainer = Trainer( diff --git a/tests/tests_pytorch/strategies/test_fsdp.py b/tests/tests_pytorch/strategies/test_fsdp.py index 7be116456154a..bcd46ebde7526 100644 --- a/tests/tests_pytorch/strategies/test_fsdp.py +++ b/tests/tests_pytorch/strategies/test_fsdp.py @@ -18,6 +18,7 @@ _TORCH_GREATER_EQUAL_2_0, _TORCH_GREATER_EQUAL_2_1, ) +from lightning.fabric.utilities.load import _load_distributed_checkpoint from lightning.pytorch import Trainer from lightning.pytorch.callbacks import ModelCheckpoint from lightning.pytorch.demos.boring_classes import BoringModel @@ -25,6 +26,7 @@ from lightning.pytorch.plugins.precision.fsdp import FSDPPrecision from lightning.pytorch.strategies import FSDPStrategy from lightning.pytorch.trainer.states import TrainerFn +from lightning.pytorch.utilities.consolidate_checkpoint import _format_checkpoint from lightning.pytorch.utilities.exceptions import MisconfigurationException from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, FullyShardedDataParallel, MixedPrecision from torch.distributed.fsdp.wrap import always_wrap_policy, size_based_auto_wrap_policy, wrap @@ -991,3 +993,40 @@ def _run_setup_assertions(empty_init, expected_device): else: # Case 2: Empty-init with PyTorch < 2.1 only supports `torch.empty()`-init _run_setup_assertions(empty_init=True, expected_device=torch.device("cpu")) + + +# TODO: Support checkpoint consolidation with PyTorch >= 2.2 +@RunIf(min_cuda_gpus=2, standalone=True, min_torch="2.1.0", max_torch="2.2.0") +def test_save_sharded_and_consolidate_and_load(tmp_path): + """Test the consolidation of a FSDP-sharded checkpoint into a single file.""" + + model = BoringModel() + trainer = Trainer( + default_root_dir=tmp_path, + accelerator="cuda", + devices=2, + strategy=FSDPStrategy(auto_wrap_policy=always_wrap_policy, state_dict_type="sharded"), + max_steps=3, + ) + trainer.fit(model) + + checkpoint_path_sharded = trainer.strategy.broadcast(str(trainer.checkpoint_callback.best_model_path)) + assert set(os.listdir(checkpoint_path_sharded)) == {"meta.pt", ".metadata", "__0_0.distcp", "__1_0.distcp"} + + # consolidate the checkpoint to a single file + checkpoint_path_full = trainer.strategy.broadcast(str(tmp_path / "checkpoint_full.ckpt")) + if trainer.global_rank == 0: + checkpoint = _load_distributed_checkpoint(Path(checkpoint_path_sharded)) + checkpoint = _format_checkpoint(checkpoint) + torch.save(checkpoint, checkpoint_path_full) + trainer.strategy.barrier() + + model = BoringModel() + trainer = Trainer( + default_root_dir=tmp_path, + accelerator="cuda", + devices=2, + strategy="ddp", + max_steps=4, + ) + trainer.fit(model, ckpt_path=checkpoint_path_full) diff --git a/tests/tests_pytorch/trainer/connectors/test_data_connector.py b/tests/tests_pytorch/trainer/connectors/test_data_connector.py index 5f77604a7a6d4..e9684657dd3c5 100644 --- a/tests/tests_pytorch/trainer/connectors/test_data_connector.py +++ b/tests/tests_pytorch/trainer/connectors/test_data_connector.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import sys from re import escape from typing import Sized from unittest import mock @@ -19,6 +20,7 @@ import lightning.fabric import pytest from lightning.fabric.utilities.distributed import DistributedSamplerWrapper +from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2 from lightning.fabric.utilities.warnings import PossibleUserWarning from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringDataModule, BoringModel, RandomDataset @@ -123,6 +125,11 @@ def on_train_end(self): self.ctx.__exit__(None, None, None) +@pytest.mark.xfail( + # https://github.com/pytorch/pytorch/issues/116056 + sys.platform == "win32" and _TORCH_GREATER_EQUAL_2_2, + reason="Windows + DDP issue in PyTorch 2.2", +) @pytest.mark.parametrize("num_workers", [0, 1, 2]) def test_dataloader_persistent_workers_performance_warning(num_workers, tmp_path): """Test that when the multiprocessing start-method is 'spawn', we recommend setting `persistent_workers=True`.""" diff --git a/tests/tests_pytorch/trainer/logging_/test_train_loop_logging.py b/tests/tests_pytorch/trainer/logging_/test_train_loop_logging.py index bc2abdacb0e04..bb17220610366 100644 --- a/tests/tests_pytorch/trainer/logging_/test_train_loop_logging.py +++ b/tests/tests_pytorch/trainer/logging_/test_train_loop_logging.py @@ -15,6 +15,7 @@ import collections import itertools +import sys from re import escape from unittest import mock from unittest.mock import call @@ -22,6 +23,7 @@ import numpy as np import pytest import torch +from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2 from lightning.pytorch import Trainer, callbacks from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint, TQDMProgressBar from lightning.pytorch.core.module import LightningModule @@ -346,7 +348,15 @@ def validation_step(self, batch, batch_idx): ("devices", "accelerator"), [ (1, "cpu"), - (2, "cpu"), + pytest.param( + 2, + "cpu", + marks=pytest.mark.xfail( + # https://github.com/pytorch/pytorch/issues/116056 + sys.platform == "win32" and _TORCH_GREATER_EQUAL_2_2, + reason="Windows + DDP issue in PyTorch 2.2", + ), + ), pytest.param(2, "gpu", marks=RunIf(min_cuda_gpus=2)), ], ) diff --git a/tests/tests_pytorch/utilities/test_consolidate_checkpoint.py b/tests/tests_pytorch/utilities/test_consolidate_checkpoint.py new file mode 100644 index 0000000000000..5bbaa4181ecdc --- /dev/null +++ b/tests/tests_pytorch/utilities/test_consolidate_checkpoint.py @@ -0,0 +1,33 @@ +# Copyright The Lightning AI team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from unittest.mock import Mock + +from lightning.pytorch.utilities.consolidate_checkpoint import _format_checkpoint + + +def test_format_checkpoint(): + # The 'model' key gets renamed to 'state_dict' + model = Mock() + checkpoint = {"model": model} + assert _format_checkpoint(checkpoint) == {"state_dict": model} + + # Optimizer states with keys 'optimizer_0', 'optimizer_1', etc. get converted to a list + optimizer0 = Mock() + optimizer1 = Mock() + checkpoint = {"model": model, "optimizer_1": optimizer1, "optimizer_0": optimizer0, "optimizer_abc": "other"} + assert _format_checkpoint(checkpoint) == { + "state_dict": model, + "optimizer_states": [optimizer0, optimizer1], + "optimizer_abc": "other", + } diff --git a/tests/tests_pytorch/utilities/test_model_helpers.py b/tests/tests_pytorch/utilities/test_model_helpers.py index 248a40cb38989..78a63a7e9d2a7 100644 --- a/tests/tests_pytorch/utilities/test_model_helpers.py +++ b/tests/tests_pytorch/utilities/test_model_helpers.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import inspect import logging import pytest @@ -66,10 +67,12 @@ def cmethod(cls): def test_restricted_classmethod(): + restricted_method = RestrictedClass().restricted_cmethod # no exception when getting restricted method + with pytest.raises(TypeError, match="cannot be called on an instance"): - RestrictedClass().restricted_cmethod() + restricted_method() - RestrictedClass.restricted_cmethod() # no exception + _ = inspect.getmembers(RestrictedClass()) # no exception on inspecting instance def test_module_mode():