diff --git a/.github/checkgroup.yml b/.github/checkgroup.yml index d7f0da2042c9d..3f19b567cfb34 100644 --- a/.github/checkgroup.yml +++ b/.github/checkgroup.yml @@ -22,12 +22,15 @@ subprojects: - "pl-cpu (macOS-11, lightning, 3.8, 1.13, oldest)" - "pl-cpu (macOS-11, lightning, 3.10, 1.13)" - "pl-cpu (macOS-11, lightning, 3.10, 2.1)" + - "pl-cpu (macOS-11, lightning, 3.10, 2.2)" - "pl-cpu (ubuntu-20.04, lightning, 3.8, 1.13, oldest)" - "pl-cpu (ubuntu-20.04, lightning, 3.10, 1.13)" - "pl-cpu (ubuntu-20.04, lightning, 3.10, 2.1)" + - "pl-cpu (ubuntu-20.04, lightning, 3.10, 2.2)" - "pl-cpu (windows-2022, lightning, 3.8, 1.13, oldest)" - "pl-cpu (windows-2022, lightning, 3.10, 1.13)" - "pl-cpu (windows-2022, lightning, 3.10, 2.1)" + - "pl-cpu (windows-2022, lightning, 3.10, 2.2)" - "pl-cpu (macOS-11, pytorch, 3.8, 1.13)" - "pl-cpu (ubuntu-20.04, pytorch, 3.8, 1.13)" - "pl-cpu (windows-2022, pytorch, 3.8, 1.13)" @@ -188,12 +191,15 @@ subprojects: - "fabric-cpu (macOS-11, lightning, 3.8, 1.13, oldest)" - "fabric-cpu (macOS-11, lightning, 3.10, 1.13)" - "fabric-cpu (macOS-11, lightning, 3.11, 2.1)" + - "fabric-cpu (macOS-11, lightning, 3.11, 2.2)" - "fabric-cpu (ubuntu-20.04, lightning, 3.8, 1.13, oldest)" - "fabric-cpu (ubuntu-20.04, lightning, 3.10, 1.13)" - "fabric-cpu (ubuntu-20.04, lightning, 3.11, 2.1)" + - "fabric-cpu (ubuntu-20.04, lightning, 3.11, 2.2)" - "fabric-cpu (windows-2022, lightning, 3.8, 1.13, oldest)" - "fabric-cpu (windows-2022, lightning, 3.10, 1.13)" - "fabric-cpu (windows-2022, lightning, 3.11, 2.1)" + - "fabric-cpu (windows-2022, lightning, 3.11, 2.2)" - "fabric-cpu (macOS-11, fabric, 3.8, 1.13)" - "fabric-cpu (ubuntu-20.04, fabric, 3.8, 1.13)" - "fabric-cpu (windows-2022, fabric, 3.8, 1.13)" diff --git a/.github/workflows/ci-examples-app.yml b/.github/workflows/ci-examples-app.yml index cedbdbb7e260b..134930d84be14 100644 --- a/.github/workflows/ci-examples-app.yml +++ b/.github/workflows/ci-examples-app.yml @@ -67,7 +67,7 @@ jobs: run: python .actions/assistant.py replace_oldest_ver - name: pip wheels cache - uses: actions/cache/restore@v3 + uses: actions/cache/restore@v4 with: path: ${{ env.PYPI_CACHE_DIR }} key: pypi_wheels @@ -126,7 +126,7 @@ jobs: coverage report -i - name: Upload coverage to Codecov - uses: codecov/codecov-action@v3 + uses: codecov/codecov-action@v4 with: token: ${{ secrets.CODECOV_TOKEN }} file: tests/coverage.xml diff --git a/.github/workflows/ci-tests-app.yml b/.github/workflows/ci-tests-app.yml index a0650c2cf1213..8d8fb94181903 100644 --- a/.github/workflows/ci-tests-app.yml +++ b/.github/workflows/ci-tests-app.yml @@ -73,7 +73,7 @@ jobs: run: python .actions/assistant.py replace_oldest_ver - name: pip wheels cache - uses: actions/cache/restore@v3 + uses: actions/cache/restore@v4 with: path: ${{ env.PYPI_CACHE_DIR }} key: pypi_wheels @@ -151,7 +151,7 @@ jobs: coverage report -i - name: Upload coverage to Codecov - uses: codecov/codecov-action@v3 + uses: codecov/codecov-action@v4 with: token: ${{ secrets.CODECOV_TOKEN }} file: tests/coverage.xml diff --git a/.github/workflows/ci-tests-fabric.yml b/.github/workflows/ci-tests-fabric.yml index 36ce5d4131a81..ccb6fc928a014 100644 --- a/.github/workflows/ci-tests-fabric.yml +++ b/.github/workflows/ci-tests-fabric.yml @@ -46,8 +46,8 @@ jobs: - { os: "macOS-11", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.1" } - { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.1" } - { os: "windows-2022", pkg-name: "lightning", python-version: "3.11", 
pytorch-version: "2.1" } - - { os: "macOS-12", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.2" } - - { os: "ubuntu-22.04", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.2" } + - { os: "macOS-11", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.2" } + - { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.2" } - { os: "windows-2022", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.2" } # only run PyTorch latest with Python latest, use Fabric scope to limit dependency issues - { os: "macOS-12", pkg-name: "fabric", python-version: "3.11", pytorch-version: "2.0" } @@ -114,7 +114,7 @@ jobs: done - name: pip wheels cache - uses: actions/cache/restore@v3 + uses: actions/cache/restore@v4 with: path: ${{ env.PYPI_CACHE_DIR }} key: pypi_wheels @@ -173,7 +173,7 @@ jobs: coverage xml - name: Upload coverage to Codecov - uses: codecov/codecov-action@v3 + uses: codecov/codecov-action@v4 # see: https://github.com/actions/toolkit/issues/399 continue-on-error: true with: diff --git a/.github/workflows/ci-tests-pytorch.yml b/.github/workflows/ci-tests-pytorch.yml index 40bd439255d70..1e5835054fdfe 100644 --- a/.github/workflows/ci-tests-pytorch.yml +++ b/.github/workflows/ci-tests-pytorch.yml @@ -50,8 +50,8 @@ jobs: - { os: "macOS-11", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.1" } - { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.1" } - { os: "windows-2022", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.1" } - - { os: "macOS-12", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.2" } - - { os: "ubuntu-22.04", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.2" } + - { os: "macOS-11", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.2" } + - { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.2" } - { os: "windows-2022", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.2" } # only run PyTorch latest with Python latest, use PyTorch scope to limit dependency issues - { os: "macOS-12", pkg-name: "pytorch", python-version: "3.11", pytorch-version: "2.0" } @@ -120,7 +120,7 @@ jobs: cat requirements/pytorch/base.txt - name: pip wheels cache - uses: actions/cache/restore@v3 + uses: actions/cache/restore@v4 with: path: ${{ env.PYPI_CACHE_DIR }} key: pypi_wheels @@ -161,7 +161,7 @@ jobs: cache-key: "pypi_wheels" - name: Cache datasets - uses: actions/cache@v3 + uses: actions/cache@v4 with: path: Datasets key: pl-dataset @@ -210,7 +210,7 @@ jobs: coverage xml - name: Upload coverage to Codecov - uses: codecov/codecov-action@v3 + uses: codecov/codecov-action@v4 # see: https://github.com/actions/toolkit/issues/399 continue-on-error: true with: diff --git a/.github/workflows/ci-tests-store.yml b/.github/workflows/ci-tests-store.yml index 8f01f33e036f0..60614005c614d 100644 --- a/.github/workflows/ci-tests-store.yml +++ b/.github/workflows/ci-tests-store.yml @@ -85,7 +85,7 @@ jobs: coverage xml - name: Upload coverage to Codecov - uses: codecov/codecov-action@v3 + uses: codecov/codecov-action@v4 # see: https://github.com/actions/toolkit/issues/399 continue-on-error: true with: diff --git a/.github/workflows/code-checks.yml b/.github/workflows/code-checks.yml index bb27311036e38..535274e403a58 100644 --- a/.github/workflows/code-checks.yml +++ b/.github/workflows/code-checks.yml @@ -34,7 +34,7 @@ jobs: 
python-version: "3.10.6" - name: Mypy cache - uses: actions/cache@v3 + uses: actions/cache@v4 with: path: .mypy_cache key: mypy-${{ hashFiles('requirements/typing.txt') }} diff --git a/.github/workflows/docs-build.yml b/.github/workflows/docs-build.yml index 6e535f2ea3866..e731539a2d8e1 100644 --- a/.github/workflows/docs-build.yml +++ b/.github/workflows/docs-build.yml @@ -80,7 +80,7 @@ jobs: pip install lai-sphinx-theme -U -f ${PYPI_LOCAL_DIR} - name: pip wheels cache - uses: actions/cache/restore@v3 + uses: actions/cache/restore@v4 with: path: ${{ env.PYPI_CACHE_DIR }} key: pypi_wheels diff --git a/.github/workflows/tpu-tests.yml b/.github/workflows/tpu-tests.yml index 22b0ec3cbc96b..8d9d277cac26b 100644 --- a/.github/workflows/tpu-tests.yml +++ b/.github/workflows/tpu-tests.yml @@ -165,7 +165,7 @@ jobs: gcloud compute tpus tpu-vm list - name: Upload coverage to Codecov - uses: codecov/codecov-action@v3 + uses: codecov/codecov-action@v4 continue-on-error: true with: token: ${{ secrets.CODECOV_TOKEN }} diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 72aaa721f665b..1180c52f4c696 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -71,6 +71,11 @@ repos: additional_dependencies: [tomli] args: ["--in-place"] + - repo: https://github.com/sphinx-contrib/sphinx-lint + rev: v0.9.1 + hooks: + - id: sphinx-lint + - repo: https://github.com/asottile/yesqa rev: v1.5.0 hooks: @@ -86,6 +91,8 @@ repos: - repo: https://github.com/astral-sh/ruff-pre-commit rev: "v0.1.15" hooks: + - id: ruff-format + args: ["--preview"] - id: ruff args: ["--fix", "--preview"] diff --git a/docs/source-app/core_api/lightning_app/communication_content.rst b/docs/source-app/core_api/lightning_app/communication_content.rst index ea39749018757..4b373dbb07588 100644 --- a/docs/source-app/core_api/lightning_app/communication_content.rst +++ b/docs/source-app/core_api/lightning_app/communication_content.rst @@ -87,13 +87,13 @@ And here's the output you get when running the App using the **Lightning CLI**: .. code-block:: console - INFO: Your app has started. View it in your browser: http://127.0.0.1:7501/view - State: {'works': {'w': {'vars': {'counter': 1}}}} - State: {'works': {'w': {'vars': {'counter': 2}}}} - State: {'works': {'w': {'vars': {'counter': 3}}}} - State: {'works': {'w': {'vars': {'counter': 3}}}} - State: {'works': {'w': {'vars': {'counter': 4}}}} - ... + INFO: Your app has started. View it in your browser: http://127.0.0.1:7501/view + State: {'works': {'w': {'vars': {'counter': 1}}}} + State: {'works': {'w': {'vars': {'counter': 2}}}} + State: {'works': {'w': {'vars': {'counter': 3}}}} + State: {'works': {'w': {'vars': {'counter': 3}}}} + State: {'works': {'w': {'vars': {'counter': 4}}}} + ... ---- diff --git a/docs/source-app/core_api/lightning_app/dynamic_work_content.rst b/docs/source-app/core_api/lightning_app/dynamic_work_content.rst index 6b8bc06cb6cb0..31616a42a1493 100644 --- a/docs/source-app/core_api/lightning_app/dynamic_work_content.rst +++ b/docs/source-app/core_api/lightning_app/dynamic_work_content.rst @@ -41,9 +41,9 @@ There are a couple of ways you can add a dynamic Work: def run(self): if not hasattr(self, "work"): - # The `Work` component is created and attached here. + # The `Work` component is created and attached here. setattr(self, "work", Work()) - # Run the `Work` component. + # Run the `Work` component. 
getattr(self, "work").run() **OPTION 2:** Use the built-in Lightning classes :class:`~lightning.app.structures.Dict` or :class:`~lightning.app.structures.List` @@ -60,7 +60,7 @@ There are a couple of ways you can add a dynamic Work: def run(self): if "work" not in self.dict: - # The `Work` component is attached here. + # The `Work` component is attached here. self.dict["work"] = Work() self.dict["work"].run() diff --git a/docs/source-app/glossary/environment_variables.rst b/docs/source-app/glossary/environment_variables.rst index c561a5dee4dfd..dcbcd8c6f4ab7 100644 --- a/docs/source-app/glossary/environment_variables.rst +++ b/docs/source-app/glossary/environment_variables.rst @@ -24,4 +24,4 @@ Environment variables are available in all Flows and Works, and can be accessed print(os.environ["BAZ"]) # FAZ .. note:: - Environment variables are not encrypted. For sensitive values, we recommend using :ref:`Encrypted Secrets `. + Environment variables are not encrypted. For sensitive values, we recommend using :ref:`Encrypted Secrets `. diff --git a/docs/source-app/glossary/secrets.rst b/docs/source-app/glossary/secrets.rst index 0f1b226739af7..95a0d564c648d 100644 --- a/docs/source-app/glossary/secrets.rst +++ b/docs/source-app/glossary/secrets.rst @@ -8,7 +8,7 @@ Encrypted Secrets allow you to pass private data to your apps, like API keys, ac Secrets provide you with a secure way to store this data in a way that is accessible to Apps so that they can authenticate third-party services/solutions. .. tip:: - For non-sensitive configuration values, we recommend using :ref:`plain-text Environment Variables `. + For non-sensitive configuration values, we recommend using :ref:`plain-text Environment Variables `. ************ Add a secret diff --git a/docs/source-app/glossary/sharing_components.rst b/docs/source-app/glossary/sharing_components.rst index f9cc48c49b776..2426bb43d4469 100644 --- a/docs/source-app/glossary/sharing_components.rst +++ b/docs/source-app/glossary/sharing_components.rst @@ -34,7 +34,7 @@ Now, imagine you have implemented a **KerasScriptRunner** component for training Here are the best practices steps before sharing the component: -* **Testing**: Ensure your component is well tested by following the ref:`../testing` guide. +* **Testing**: Ensure your component is well tested by following the :doc:`../testing` guide. * **Documented**: Ensure your component has a docstring and comes with some usage explications. .. Note:: As a Lightning user, it helps to implement your components thinking someone else is going to use them. diff --git a/docs/source-app/workflows/access_app_state.rst b/docs/source-app/workflows/access_app_state.rst index 1f7e1258a4e9b..8f99534bd239b 100644 --- a/docs/source-app/workflows/access_app_state.rst +++ b/docs/source-app/workflows/access_app_state.rst @@ -50,10 +50,10 @@ And here's the output you get when running the App using **Lightning CLI**: .. code-block:: console - INFO: Your app has started. View it in your browser: http://127.0.0.1:7501/view - State: {'works': {'w': {'vars': {'counter': 1}}}} - State: {'works': {'w': {'vars': {'counter': 2}}}} - State: {'works': {'w': {'vars': {'counter': 3}}}} - State: {'works': {'w': {'vars': {'counter': 3}}}} - State: {'works': {'w': {'vars': {'counter': 4}}}} - ... + INFO: Your app has started. 
View it in your browser: http://127.0.0.1:7501/view + State: {'works': {'w': {'vars': {'counter': 1}}}} + State: {'works': {'w': {'vars': {'counter': 2}}}} + State: {'works': {'w': {'vars': {'counter': 3}}}} + State: {'works': {'w': {'vars': {'counter': 3}}}} + State: {'works': {'w': {'vars': {'counter': 4}}}} + ... diff --git a/docs/source-app/workflows/add_web_ui/react/communicate_between_react_and_lightning.rst b/docs/source-app/workflows/add_web_ui/react/communicate_between_react_and_lightning.rst index 63b92beea6972..ef836e7e4b0d1 100644 --- a/docs/source-app/workflows/add_web_ui/react/communicate_between_react_and_lightning.rst +++ b/docs/source-app/workflows/add_web_ui/react/communicate_between_react_and_lightning.rst @@ -47,7 +47,7 @@ Update React <-- Lightning app ****************************** To change the React app from the Lightning app, use the values from the `lightningState`. -In this example, when the `react_ui.counter`` increaes in the Lightning app: +In this example, when the ``react_ui.counter`` increaes in the Lightning app: .. literalinclude:: ../../../../../src/lightning/app/cli/react-ui-template/example_app.py :emphasize-lines: 18, 24 diff --git a/docs/source-pytorch/advanced/post_training_quantization.rst b/docs/source-pytorch/advanced/post_training_quantization.rst index 12139d2cea6cd..504a57a4191cf 100644 --- a/docs/source-pytorch/advanced/post_training_quantization.rst +++ b/docs/source-pytorch/advanced/post_training_quantization.rst @@ -55,8 +55,8 @@ Usage Minor code changes are required for the user to get started with Intel® Neural Compressor quantization API. To construct the quantization process, users can specify the below settings via the Python code: -1. Calibration Dataloader (Needed for post-training static quantization) -2. Evaluation Dataloader and Metric +1. Calibration Dataloader (Needed for post-training static quantization) +2. Evaluation Dataloader and Metric The code changes that are required for Intel® Neural Compressor are highlighted with comments in the line above. diff --git a/docs/source-pytorch/conf.py b/docs/source-pytorch/conf.py index 829f1f2c9ed75..fa8eb179ead3b 100644 --- a/docs/source-pytorch/conf.py +++ b/docs/source-pytorch/conf.py @@ -623,4 +623,7 @@ def package_list_from_file(file): "https://deepgenerativemodels.github.io/assets/slides/cs236_lecture11.pdf", "https://www.intel.com/content/www/us/en/products/docs/processors/what-is-a-gpu.html", "https://www.microsoft.com/en-us/research/blog/zero-infinity-and-deepspeed-unlocking-unprecedented-model-scale-for-deep-learning-training/", # noqa: E501 + "https://stackoverflow.com/questions/66640705/how-can-i-install-grpcio-on-an-apple-m1-silicon-laptop", + "https://github.com/Lightning-AI/lightning/blob/master/examples/pytorch/ipu/mnist_sample.py", + "https://ngc.nvidia.com/catalog/containers/nvidia:nemo", # in ecosystem/asr_nlp_tts.rst ] diff --git a/docs/source-pytorch/data/alternatives.rst b/docs/source-pytorch/data/alternatives.rst index 79454b6078440..976f6f9de7297 100644 --- a/docs/source-pytorch/data/alternatives.rst +++ b/docs/source-pytorch/data/alternatives.rst @@ -32,8 +32,8 @@ As datasets grow in size and the number of nodes scales, loading training data c The `StreamingDataset `__ can make training on large datasets from cloud storage as fast, cheap, and scalable as possible. -This library uses a custom built class:`~torch.utils.data.IterableDataset`. The library recommends iterating through it -via a regular class:`~torch.utils.data.DataLoader`. 
This means that support in the ``Trainer`` is seamless: +This library uses a custom built :class:`~torch.utils.data.IterableDataset`. The library recommends iterating through it +via a regular :class:`~torch.utils.data.DataLoader`. This means that support in the ``Trainer`` is seamless: .. code-block:: python diff --git a/docs/source-pytorch/ecosystem/asr_nlp_tts.rst b/docs/source-pytorch/ecosystem/asr_nlp_tts.rst index 5989f11e606b0..af7fd05af709a 100644 --- a/docs/source-pytorch/ecosystem/asr_nlp_tts.rst +++ b/docs/source-pytorch/ecosystem/asr_nlp_tts.rst @@ -660,7 +660,7 @@ Hydra makes every aspect of the NeMo model, including the PyTorch Lightning Trai Using State-Of-The-Art Pre-trained TTS Model -------------------------------------------- -Generate speech using models trained on `LJSpeech `, +Generate speech using models trained on `LJSpeech `_, around 24 hours of single speaker data. See this `TTS notebook `_ diff --git a/docs/source-pytorch/tuning/profiler_basic.rst b/docs/source-pytorch/tuning/profiler_basic.rst index d248cc649004a..99b43adc60e83 100644 --- a/docs/source-pytorch/tuning/profiler_basic.rst +++ b/docs/source-pytorch/tuning/profiler_basic.rst @@ -31,22 +31,22 @@ Once the **.fit()** function has completed, you'll see an output like this: FIT Profiler Report - ----------------------------------------------------------------------------------------------- - | Action | Mean duration (s) | Total time (s) | - ----------------------------------------------------------------------------------------------- - | [LightningModule]BoringModel.prepare_data | 10.0001 | 20.00 | - | run_training_epoch | 6.1558 | 6.1558 | - | run_training_batch | 0.0022506 | 0.015754 | - | [LightningModule]BoringModel.optimizer_step | 0.0017477 | 0.012234 | - | [LightningModule]BoringModel.val_dataloader | 0.00024388 | 0.00024388 | - | on_train_batch_start | 0.00014637 | 0.0010246 | - | [LightningModule]BoringModel.teardown | 2.15e-06 | 2.15e-06 | - | [LightningModule]BoringModel.on_train_start | 1.644e-06 | 1.644e-06 | - | [LightningModule]BoringModel.on_train_end | 1.516e-06 | 1.516e-06 | - | [LightningModule]BoringModel.on_fit_end | 1.426e-06 | 1.426e-06 | - | [LightningModule]BoringModel.setup | 1.403e-06 | 1.403e-06 | - | [LightningModule]BoringModel.on_fit_start | 1.226e-06 | 1.226e-06 | - ----------------------------------------------------------------------------------------------- + ------------------------------------------------------------------------------------------- + | Action | Mean duration (s) | Total time (s) | + ------------------------------------------------------------------------------------------- + | [LightningModule]BoringModel.prepare_data | 10.0001 | 20.00 | + | run_training_epoch | 6.1558 | 6.1558 | + | run_training_batch | 0.0022506 | 0.015754 | + | [LightningModule]BoringModel.optimizer_step | 0.0017477 | 0.012234 | + | [LightningModule]BoringModel.val_dataloader | 0.00024388 | 0.00024388 | + | on_train_batch_start | 0.00014637 | 0.0010246 | + | [LightningModule]BoringModel.teardown | 2.15e-06 | 2.15e-06 | + | [LightningModule]BoringModel.on_train_start | 1.644e-06 | 1.644e-06 | + | [LightningModule]BoringModel.on_train_end | 1.516e-06 | 1.516e-06 | + | [LightningModule]BoringModel.on_fit_end | 1.426e-06 | 1.426e-06 | + | [LightningModule]BoringModel.setup | 1.403e-06 | 1.403e-06 | + | [LightningModule]BoringModel.on_fit_start | 1.226e-06 | 1.226e-06 | + ------------------------------------------------------------------------------------------- In this report we can see 
that the slowest function is **prepare_data**. Now you can figure out why data preparation is slowing down your training. diff --git a/docs/source-pytorch/upgrade/sections/1_7_advanced.rst b/docs/source-pytorch/upgrade/sections/1_7_advanced.rst index 8b920445465aa..21461c070058f 100644 --- a/docs/source-pytorch/upgrade/sections/1_7_advanced.rst +++ b/docs/source-pytorch/upgrade/sections/1_7_advanced.rst @@ -103,15 +103,15 @@ - `PR11871`_ * - used ``Trainer.validated_ckpt_path`` attribute - - rely on generic read-only property ``Trainer.ckpt_path`` which is set when checkpoints are loaded via ``Trainer.validate(````ckpt_path=...)`` + - rely on generic read-only property ``Trainer.ckpt_path`` which is set when checkpoints are loaded via ``Trainer.validate(ckpt_path=...)`` - `PR11696`_ * - used ``Trainer.tested_ckpt_path`` attribute - - rely on generic read-only property ``Trainer.ckpt_path`` which is set when checkpoints are loaded via ``Trainer.test(````ckpt_path=...)`` + - rely on generic read-only property ``Trainer.ckpt_path`` which is set when checkpoints are loaded via ``Trainer.test(ckpt_path=...)`` - `PR11696`_ * - used ``Trainer.predicted_ckpt_path`` attribute - - rely on generic read-only property ``Trainer.ckpt_path``, which is set when checkpoints are loaded via ``Trainer.predict(````ckpt_path=...)`` + - rely on generic read-only property ``Trainer.ckpt_path``, which is set when checkpoints are loaded via ``Trainer.predict(ckpt_path=...)`` - `PR11696`_ * - rely on the returned dictionary from ``Callback.on_save_checkpoint`` diff --git a/docs/source-pytorch/upgrade/sections/1_9_devel.rst b/docs/source-pytorch/upgrade/sections/1_9_devel.rst index cf1493b9055aa..440b21dfc1b8c 100644 --- a/docs/source-pytorch/upgrade/sections/1_9_devel.rst +++ b/docs/source-pytorch/upgrade/sections/1_9_devel.rst @@ -26,7 +26,7 @@ - use DDP instead - `PR16386`_ :doc:`DDP <../../accelerators/gpu_expert>` - * - used the pl.plugins.ApexMixedPrecisionPlugin`` plugin + * - used the ``pl.plugins.ApexMixedPrecisionPlugin`` plugin - use PyTorch native mixed precision - `PR16039`_ diff --git a/docs/source-pytorch/upgrade/sections/1_9_regular.rst b/docs/source-pytorch/upgrade/sections/1_9_regular.rst index 0aa0108e484d5..2f35957bcd21b 100644 --- a/docs/source-pytorch/upgrade/sections/1_9_regular.rst +++ b/docs/source-pytorch/upgrade/sections/1_9_regular.rst @@ -39,11 +39,11 @@ - `PR16184`_ * - called the ``pl.tuner.auto_gpu_select.pick_single_gpu`` function - - use Trainer’s flag``devices="auto"`` + - use Trainer’s flag ``devices="auto"`` - `PR16184`_ * - called the ``pl.tuner.auto_gpu_select.pick_multiple_gpus`` functions - - use Trainer’s flag``devices="auto"`` + - use Trainer’s flag ``devices="auto"`` - `PR16184`_ * - used Trainer’s flag ``accumulate_grad_batches`` with a scheduling dictionary value diff --git a/examples/app/dag/app.py b/examples/app/dag/app.py index c63b0d15c82c2..7e94f64b62c5d 100644 --- a/examples/app/dag/app.py +++ b/examples/app/dag/app.py @@ -65,9 +65,9 @@ def __init__(self, models_paths: list): ) # Step 3: Create the work to train the models_paths in parallel. - self.dict = Dict( - **{model_path.split(".")[-1]: ModelWork(model_path, parallel=True) for model_path in models_paths} - ) + self.dict = Dict(**{ + model_path.split(".")[-1]: ModelWork(model_path, parallel=True) for model_path in models_paths + }) # Step 4: Some element to track components progress. 
self.has_completed = False diff --git a/examples/app/server/app.py b/examples/app/server/app.py index 689c57271ce3d..97030179ffa78 100644 --- a/examples/app/server/app.py +++ b/examples/app/server/app.py @@ -20,13 +20,11 @@ def setup(self): def predict(self, request): image = base64.b64decode(request.image.encode("utf-8")) image = Image.open(io.BytesIO(image)) - transforms = torchvision.transforms.Compose( - [ - torchvision.transforms.Resize(224), - torchvision.transforms.ToTensor(), - torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), - ] - ) + transforms = torchvision.transforms.Compose([ + torchvision.transforms.Resize(224), + torchvision.transforms.ToTensor(), + torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), + ]) image = transforms(image) image = image.to(self._device) prediction = self._model(image.unsqueeze(0)) diff --git a/examples/app/server_with_auto_scaler/app.py b/examples/app/server_with_auto_scaler/app.py index 5c90964a73601..1320da6745fa6 100644 --- a/examples/app/server_with_auto_scaler/app.py +++ b/examples/app/server_with_auto_scaler/app.py @@ -34,13 +34,11 @@ def setup(self): self._model = torchvision.models.resnet18(pretrained=True).to(self._device) def predict(self, requests: BatchRequestModel): - transforms = torchvision.transforms.Compose( - [ - torchvision.transforms.Resize(224), - torchvision.transforms.ToTensor(), - torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), - ] - ) + transforms = torchvision.transforms.Compose([ + torchvision.transforms.Resize(224), + torchvision.transforms.ToTensor(), + torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), + ]) images = [] for request in requests.inputs: image = app.components.serve.types.image.Image.deserialize(request.image) diff --git a/examples/fabric/dcgan/train_fabric.py b/examples/fabric/dcgan/train_fabric.py index b4fd3c98aeb3c..f7a18b2b5bc17 100644 --- a/examples/fabric/dcgan/train_fabric.py +++ b/examples/fabric/dcgan/train_fabric.py @@ -4,6 +4,7 @@ Code adapted from the official PyTorch DCGAN tutorial: https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html """ + import os import time from pathlib import Path @@ -55,14 +56,12 @@ def main(): root=dataroot, split="all", download=True, - transform=transforms.Compose( - [ - transforms.Resize(image_size), - transforms.CenterCrop(image_size), - transforms.ToTensor(), - transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), - ] - ), + transform=transforms.Compose([ + transforms.Resize(image_size), + transforms.CenterCrop(image_size), + transforms.ToTensor(), + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), + ]), ) # Create the dataloader @@ -227,7 +226,7 @@ def __init__(self): nn.ReLU(True), # state size. (ngf) x 32 x 32 nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False), - nn.Tanh() + nn.Tanh(), # state size. 
(nc) x 64 x 64 ) diff --git a/examples/fabric/dcgan/train_torch.py b/examples/fabric/dcgan/train_torch.py index d7ce548b16bda..2f10dcb3db3df 100644 --- a/examples/fabric/dcgan/train_torch.py +++ b/examples/fabric/dcgan/train_torch.py @@ -4,6 +4,7 @@ Code adapted from the official PyTorch DCGAN tutorial: https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html """ + import os import random import time @@ -55,14 +56,12 @@ def main(): root=dataroot, split="all", download=True, - transform=transforms.Compose( - [ - transforms.Resize(image_size), - transforms.CenterCrop(image_size), - transforms.ToTensor(), - transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), - ] - ), + transform=transforms.Compose([ + transforms.Resize(image_size), + transforms.CenterCrop(image_size), + transforms.ToTensor(), + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), + ]), ) # Create the dataloader @@ -236,7 +235,7 @@ def __init__(self): nn.ReLU(True), # state size. (ngf) x 32 x 32 nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False), - nn.Tanh() + nn.Tanh(), # state size. (nc) x 64 x 64 ) diff --git a/examples/fabric/meta_learning/train_fabric.py b/examples/fabric/meta_learning/train_fabric.py index 52d427057b327..69d3e56d5e548 100644 --- a/examples/fabric/meta_learning/train_fabric.py +++ b/examples/fabric/meta_learning/train_fabric.py @@ -14,6 +14,7 @@ Run it with: lightning run model train_fabric.py --accelerator=cuda --devices=2 --strategy=ddp """ + import cherry import learn2learn as l2l import torch diff --git a/examples/fabric/meta_learning/train_torch.py b/examples/fabric/meta_learning/train_torch.py index 365d01d47526d..1e3666755704b 100644 --- a/examples/fabric/meta_learning/train_torch.py +++ b/examples/fabric/meta_learning/train_torch.py @@ -15,6 +15,7 @@ Run it with: torchrun --nproc_per_node=2 --standalone train_torch.py """ + import os import random diff --git a/examples/fabric/reinforcement_learning/train_fabric.py b/examples/fabric/reinforcement_learning/train_fabric.py index ae5fb6683b3f5..c294287ac0542 100644 --- a/examples/fabric/reinforcement_learning/train_fabric.py +++ b/examples/fabric/reinforcement_learning/train_fabric.py @@ -84,14 +84,10 @@ def main(args: argparse.Namespace): ) # Environment setup - envs = gym.vector.SyncVectorEnv( - [ - make_env( - args.env_id, args.seed + rank * args.num_envs + i, rank, args.capture_video, logger.log_dir, "train" - ) - for i in range(args.num_envs) - ] - ) + envs = gym.vector.SyncVectorEnv([ + make_env(args.env_id, args.seed + rank * args.num_envs + i, rank, args.capture_video, logger.log_dir, "train") + for i in range(args.num_envs) + ]) assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported" # Define the agent and the optimizer and setup them with Fabric diff --git a/examples/fabric/reinforcement_learning/train_fabric_decoupled.py b/examples/fabric/reinforcement_learning/train_fabric_decoupled.py index f825b0f1679d6..9dd0ebac762d7 100644 --- a/examples/fabric/reinforcement_learning/train_fabric_decoupled.py +++ b/examples/fabric/reinforcement_learning/train_fabric_decoupled.py @@ -59,9 +59,9 @@ def player(args, world_collective: TorchCollective, player_trainer_collective: T ) # Environment setup - envs = gym.vector.SyncVectorEnv( - [make_env(args.env_id, args.seed + i, 0, args.capture_video, log_dir, "train") for i in range(args.num_envs)] - ) + envs = gym.vector.SyncVectorEnv([ + make_env(args.env_id, args.seed + i, 0, args.capture_video, log_dir, "train") for i in range(args.num_envs) 
+ ]) assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported" # Define the agent diff --git a/examples/fabric/reinforcement_learning/train_torch.py b/examples/fabric/reinforcement_learning/train_torch.py index 8974d305b24e6..2cf755f77f9d9 100644 --- a/examples/fabric/reinforcement_learning/train_torch.py +++ b/examples/fabric/reinforcement_learning/train_torch.py @@ -142,19 +142,17 @@ def main(args: argparse.Namespace): ) # Environment setup - envs = gym.vector.SyncVectorEnv( - [ - make_env( - args.env_id, - args.seed + global_rank * args.num_envs + i, - global_rank, - args.capture_video, - logger.log_dir if global_rank == 0 else None, - "train", - ) - for i in range(args.num_envs) - ] - ) + envs = gym.vector.SyncVectorEnv([ + make_env( + args.env_id, + args.seed + global_rank * args.num_envs + i, + global_rank, + args.capture_video, + logger.log_dir if global_rank == 0 else None, + "train", + ) + for i in range(args.num_envs) + ]) assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported" # Define the agent and the optimizer and setup them with DistributedDataParallel diff --git a/examples/pytorch/basics/autoencoder.py b/examples/pytorch/basics/autoencoder.py index be22d40a32600..260cce4a548b1 100644 --- a/examples/pytorch/basics/autoencoder.py +++ b/examples/pytorch/basics/autoencoder.py @@ -16,6 +16,7 @@ To run: python autoencoder.py --trainer.max_epochs=50 """ + from os import path from typing import Optional, Tuple diff --git a/examples/pytorch/basics/backbone_image_classifier.py b/examples/pytorch/basics/backbone_image_classifier.py index ae05f5435dafd..fceb97dc41cff 100644 --- a/examples/pytorch/basics/backbone_image_classifier.py +++ b/examples/pytorch/basics/backbone_image_classifier.py @@ -16,6 +16,7 @@ To run: python backbone_image_classifier.py --trainer.max_epochs=50 """ + from os import path from typing import Optional diff --git a/examples/pytorch/domain_templates/computer_vision_fine_tuning.py b/examples/pytorch/domain_templates/computer_vision_fine_tuning.py index efc0335e8edd9..d03bb0c4edd16 100644 --- a/examples/pytorch/domain_templates/computer_vision_fine_tuning.py +++ b/examples/pytorch/domain_templates/computer_vision_fine_tuning.py @@ -119,14 +119,12 @@ def normalize_transform(self): @property def train_transform(self): - return transforms.Compose( - [ - transforms.Resize((224, 224)), - transforms.RandomHorizontalFlip(), - transforms.ToTensor(), - self.normalize_transform, - ] - ) + return transforms.Compose([ + transforms.Resize((224, 224)), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + self.normalize_transform, + ]) @property def valid_transform(self): @@ -269,13 +267,11 @@ def add_arguments_to_parser(self, parser): parser.link_arguments("data.batch_size", "model.batch_size") parser.link_arguments("finetuning.milestones", "model.milestones") parser.link_arguments("finetuning.train_bn", "model.train_bn") - parser.set_defaults( - { - "trainer.max_epochs": 15, - "trainer.enable_model_summary": False, - "trainer.num_sanity_val_steps": 0, - } - ) + parser.set_defaults({ + "trainer.max_epochs": 15, + "trainer.enable_model_summary": False, + "trainer.num_sanity_val_steps": 0, + }) def cli_main(): diff --git a/examples/pytorch/domain_templates/generative_adversarial_net.py b/examples/pytorch/domain_templates/generative_adversarial_net.py index 18c9288c95029..311d1c38771b8 100644 --- a/examples/pytorch/domain_templates/generative_adversarial_net.py +++ 
b/examples/pytorch/domain_templates/generative_adversarial_net.py @@ -18,6 +18,7 @@ tensorboard --logdir default """ + from argparse import ArgumentParser, Namespace import numpy as np diff --git a/examples/pytorch/domain_templates/imagenet.py b/examples/pytorch/domain_templates/imagenet.py index 3b52feeab0aaa..49f9b509ce6e7 100644 --- a/examples/pytorch/domain_templates/imagenet.py +++ b/examples/pytorch/domain_templates/imagenet.py @@ -30,6 +30,7 @@ python imagenet.py fit --help """ + import os from typing import Optional @@ -139,14 +140,12 @@ def setup(self, stage: str): train_dir = os.path.join(self.data_path, "train") self.train_dataset = datasets.ImageFolder( train_dir, - transforms.Compose( - [ - transforms.RandomResizedCrop(224), - transforms.RandomHorizontalFlip(), - transforms.ToTensor(), - normalize, - ] - ), + transforms.Compose([ + transforms.RandomResizedCrop(224), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + normalize, + ]), ) # all stages will use the eval dataset normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) diff --git a/examples/pytorch/domain_templates/reinforce_learn_ppo.py b/examples/pytorch/domain_templates/reinforce_learn_ppo.py index 5675bf74d82f0..bc3f8c1b9b193 100644 --- a/examples/pytorch/domain_templates/reinforce_learn_ppo.py +++ b/examples/pytorch/domain_templates/reinforce_learn_ppo.py @@ -28,6 +28,7 @@ [3] https://github.com/sid-sundrani/ppo_lightning """ + import argparse from typing import Callable, Iterator, List, Tuple diff --git a/examples/pytorch/domain_templates/semantic_segmentation.py b/examples/pytorch/domain_templates/semantic_segmentation.py index 3fec6cbea8205..c60518e229cbf 100644 --- a/examples/pytorch/domain_templates/semantic_segmentation.py +++ b/examples/pytorch/domain_templates/semantic_segmentation.py @@ -333,14 +333,10 @@ def __init__( self.net = UNet( num_classes=19, num_layers=self.num_layers, features_start=self.features_start, bilinear=self.bilinear ) - self.transform = transforms.Compose( - [ - transforms.ToTensor(), - transforms.Normalize( - mean=[0.35675976, 0.37380189, 0.3764753], std=[0.32064945, 0.32098866, 0.32325324] - ), - ] - ) + self.transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize(mean=[0.35675976, 0.37380189, 0.3764753], std=[0.32064945, 0.32098866, 0.32325324]), + ]) self.trainset = KITTI(self.data_path, split="train", transform=self.transform) self.validset = KITTI(self.data_path, split="valid", transform=self.transform) diff --git a/setup.py b/setup.py index d5d92d8228ab1..698eaa5abe71e 100755 --- a/setup.py +++ b/setup.py @@ -40,6 +40,7 @@ c) validate packages and publish to PyPI """ + import contextlib import glob import logging diff --git a/src/lightning/__init__.py b/src/lightning/__init__.py index 82143c1bb61d9..0b5b1c1db8011 100644 --- a/src/lightning/__init__.py +++ b/src/lightning/__init__.py @@ -1,4 +1,5 @@ """Root package info.""" + import logging import sys diff --git a/src/lightning/app/__init__.py b/src/lightning/app/__init__.py index b15654e7edbf1..5c904cc4a908c 100644 --- a/src/lightning/app/__init__.py +++ b/src/lightning/app/__init__.py @@ -1,4 +1,5 @@ """Root package info.""" + import logging import os diff --git a/src/lightning/app/cli/app-template/tests/test_placeholdername_app.py b/src/lightning/app/cli/app-template/tests/test_placeholdername_app.py index eabfbb0017a36..6c7743b93ce1e 100644 --- a/src/lightning/app/cli/app-template/tests/test_placeholdername_app.py +++ 
b/src/lightning/app/cli/app-template/tests/test_placeholdername_app.py @@ -6,6 +6,7 @@ run_once runs your app through one cycle of the event loop and then terminates """ + import io import os from contextlib import redirect_stdout diff --git a/src/lightning/app/cli/component-template/tests/test_placeholdername_component.py b/src/lightning/app/cli/component-template/tests/test_placeholdername_component.py index 8c7dad2fe76ea..6b9c28845749c 100644 --- a/src/lightning/app/cli/component-template/tests/test_placeholdername_component.py +++ b/src/lightning/app/cli/component-template/tests/test_placeholdername_component.py @@ -4,6 +4,7 @@ 2. call .run() """ + from placeholdername.component import TemplateComponent diff --git a/src/lightning/app/cli/pl-app-template/core/callbacks.py b/src/lightning/app/cli/pl-app-template/core/callbacks.py index 72a75d7bee6f9..87ec8e9bcbda2 100644 --- a/src/lightning/app/cli/pl-app-template/core/callbacks.py +++ b/src/lightning/app/cli/pl-app-template/core/callbacks.py @@ -297,9 +297,11 @@ def _collect_logger_metadata(self, trainer: "pl.Trainer") -> None: for logger in trainer.loggers: metadata = {"class_name": logger.__class__.__name__} if isinstance(logger, WandbLogger) and not logger._offline: - metadata.update( - {"username": logger.experiment.entity, "project_name": logger.name, "run_id": logger.version} - ) + metadata.update({ + "username": logger.experiment.entity, + "project_name": logger.name, + "run_id": logger.version, + }) if metadata and metadata not in self.work.logger_metadatas: self.work.logger_metadatas.append(metadata) diff --git a/src/lightning/app/cli/pl-app-template/core/components/logger/tensorboard.py b/src/lightning/app/cli/pl-app-template/core/components/logger/tensorboard.py index d0bae8c5c0b6b..0e1a536ff4859 100644 --- a/src/lightning/app/cli/pl-app-template/core/components/logger/tensorboard.py +++ b/src/lightning/app/cli/pl-app-template/core/components/logger/tensorboard.py @@ -32,17 +32,15 @@ def __init__(self, log_dir: Path, sync_every_n_seconds: int = 5) -> None: self._sync_every_n_seconds = sync_every_n_seconds def run(self) -> None: - subprocess.Popen( - [ - "tensorboard", - "--logdir", - str(self.log_dir), - "--host", - self.host, - "--port", - str(self.port), - ] - ) + subprocess.Popen([ + "tensorboard", + "--logdir", + str(self.log_dir), + "--host", + self.host, + "--port", + str(self.port), + ]) # Download the log directory periodically while True: diff --git a/src/lightning/app/cli/pl-app-template/core/components/script_runner/script_runner.py b/src/lightning/app/cli/pl-app-template/core/components/script_runner/script_runner.py index 3e9ce5cd45287..0c2f09e372237 100644 --- a/src/lightning/app/cli/pl-app-template/core/components/script_runner/script_runner.py +++ b/src/lightning/app/cli/pl-app-template/core/components/script_runner/script_runner.py @@ -37,14 +37,12 @@ def configure_tracer(self) -> Tracer: def pre_trainer_init(_, *args: Any, **kwargs: Any) -> Tuple[Dict, Tuple[Any, ...], Dict[str, Any]]: kwargs.setdefault("callbacks", []) - kwargs["callbacks"].extend( - [ - trainer_artifacts_tracker, - trainer_state_tracker, - progress_tracker, - summary, - ] - ) + kwargs["callbacks"].extend([ + trainer_artifacts_tracker, + trainer_state_tracker, + progress_tracker, + summary, + ]) return {}, args, kwargs tracer.add_traced(Trainer, "__init__", pre_fn=pre_trainer_init) diff --git a/src/lightning/app/components/database/utilities.py b/src/lightning/app/components/database/utilities.py index b848f72856447..129f8a210a758 100644 
--- a/src/lightning/app/components/database/utilities.py +++ b/src/lightning/app/components/database/utilities.py @@ -148,23 +148,19 @@ def convert_to_model(self, models: Dict[str, BaseModel]): @classmethod def from_obj(cls, obj, token): - return cls( - **{ - "cls_name": obj.__class__.__name__, - "data": obj.json(), - "token": token, - } - ) + return cls(**{ + "cls_name": obj.__class__.__name__, + "data": obj.json(), + "token": token, + }) @classmethod def from_cls(cls, obj_cls, token): - return cls( - **{ - "cls_name": obj_cls.__name__, - "data": "", - "token": token, - } - ) + return cls(**{ + "cls_name": obj_cls.__name__, + "data": "", + "token": token, + }) class _SelectAll: diff --git a/src/lightning/app/components/multi_node/base.py b/src/lightning/app/components/multi_node/base.py index 082e80d7eb659..e5a651466c0a4 100644 --- a/src/lightning/app/components/multi_node/base.py +++ b/src/lightning/app/components/multi_node/base.py @@ -77,17 +77,15 @@ def run( " To run on multiple nodes in the cloud, launch your app with `--cloud`." ) num_nodes = 1 - self.ws = _List( - *[ - work_cls( - *work_args, - cloud_compute=cloud_compute.clone(), - **work_kwargs, - parallel=True, - ) - for _ in range(num_nodes) - ] - ) + self.ws = _List(*[ + work_cls( + *work_args, + cloud_compute=cloud_compute.clone(), + **work_kwargs, + parallel=True, + ) + for _ in range(num_nodes) + ]) def run(self) -> None: # 1. Wait for all works to be started ! diff --git a/src/lightning/app/components/multi_node/fabric.py b/src/lightning/app/components/multi_node/fabric.py index 8723e95afea90..286a462fe7d2e 100644 --- a/src/lightning/app/components/multi_node/fabric.py +++ b/src/lightning/app/components/multi_node/fabric.py @@ -28,8 +28,7 @@ @runtime_checkable class _FabricWorkProtocol(Protocol): @staticmethod - def run() -> None: - ... + def run() -> None: ... @dataclass diff --git a/src/lightning/app/components/multi_node/trainer.py b/src/lightning/app/components/multi_node/trainer.py index ccaa44b33f915..b3f09bb6e69d3 100644 --- a/src/lightning/app/components/multi_node/trainer.py +++ b/src/lightning/app/components/multi_node/trainer.py @@ -28,8 +28,7 @@ @runtime_checkable class _LightningTrainerWorkProtocol(Protocol): @staticmethod - def run() -> None: - ... + def run() -> None: ... 
@dataclass diff --git a/src/lightning/app/components/serve/auto_scaler.py b/src/lightning/app/components/serve/auto_scaler.py index 8b95d30c22d1c..129bb4e50b635 100644 --- a/src/lightning/app/components/serve/auto_scaler.py +++ b/src/lightning/app/components/serve/auto_scaler.py @@ -614,12 +614,10 @@ def workers(self) -> List[LightningWork]: def create_work(self) -> LightningWork: """Replicates a LightningWork instance with args and kwargs provided via ``__init__``.""" cloud_compute = self._work_kwargs.get("cloud_compute", None) - self._work_kwargs.update( - { - "start_with_flow": False, - "cloud_compute": cloud_compute.clone() if cloud_compute else None, - } - ) + self._work_kwargs.update({ + "start_with_flow": False, + "cloud_compute": cloud_compute.clone() if cloud_compute else None, + }) return self._work_cls(*self._work_args, **self._work_kwargs) def add_work(self, work) -> str: diff --git a/src/lightning/app/core/app.py b/src/lightning/app/core/app.py index ed4c1d5c76f3f..b9ff54f9a8852 100644 --- a/src/lightning/app/core/app.py +++ b/src/lightning/app/core/app.py @@ -350,10 +350,10 @@ def _collect_deltas_from_ui_and_work_queues(self) -> List[Union[Delta, _APIReque while (time() - t0) < self.state_accumulate_wait: # TODO: Fetch all available deltas at once to reduce queue calls. - received_deltas: List[ - Union[_DeltaRequest, _APIRequest, _CommandRequest, ComponentDelta] - ] = self.batch_get_state_changed_from_queue( - self.delta_queue # type: ignore[assignment,arg-type] + received_deltas: List[Union[_DeltaRequest, _APIRequest, _CommandRequest, ComponentDelta]] = ( + self.batch_get_state_changed_from_queue( + self.delta_queue # type: ignore[assignment,arg-type] + ) ) if len(received_deltas) == []: break @@ -371,7 +371,9 @@ def _collect_deltas_from_ui_and_work_queues(self) -> List[Union[Delta, _APIReque if work: delta = _delta_to_app_state_delta( - self.root, work, deepcopy(delta.delta) # type: ignore[arg-type] + self.root, # type: ignore[arg-type] + work, + deepcopy(delta.delta), ) deltas.append(delta) else: diff --git a/src/lightning/app/frontend/panel/__init__.py b/src/lightning/app/frontend/panel/__init__.py index 96cb5550ce3c6..8d11a3d2815c4 100644 --- a/src/lightning/app/frontend/panel/__init__.py +++ b/src/lightning/app/frontend/panel/__init__.py @@ -1,4 +1,5 @@ """The PanelFrontend and AppStateWatcher make it easy to create Lightning Apps with the Panel data app framework.""" + from lightning.app.frontend.panel.app_state_watcher import AppStateWatcher from lightning.app.frontend.panel.panel_frontend import PanelFrontend diff --git a/src/lightning/app/frontend/panel/app_state_comm.py b/src/lightning/app/frontend/panel/app_state_comm.py index 9fec245f7b2f9..ae3316167d4a4 100644 --- a/src/lightning/app/frontend/panel/app_state_comm.py +++ b/src/lightning/app/frontend/panel/app_state_comm.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""The watch_app_state function enables us to trigger a callback function when ever the app state changes.""" + # Todo: Refactor with Streamlit # Note: It would be nice one day to just watch changes within the Flow scope instead of whole app from __future__ import annotations diff --git a/src/lightning/app/frontend/panel/app_state_watcher.py b/src/lightning/app/frontend/panel/app_state_watcher.py index fa9537e006e1b..8abc9cb52e272 100644 --- a/src/lightning/app/frontend/panel/app_state_watcher.py +++ b/src/lightning/app/frontend/panel/app_state_watcher.py @@ -19,6 +19,7 @@ This is particularly useful for the ``PanelFrontend`` but can be used by other frontends too. """ + from __future__ import annotations import os diff --git a/src/lightning/app/frontend/panel/panel_frontend.py b/src/lightning/app/frontend/panel/panel_frontend.py index 31cc1d223865e..6185ab7ce6429 100644 --- a/src/lightning/app/frontend/panel/panel_frontend.py +++ b/src/lightning/app/frontend/panel/panel_frontend.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """The PanelFrontend wraps your Panel code in your LightningFlow.""" + from __future__ import annotations import inspect diff --git a/src/lightning/app/frontend/panel/panel_serve_render_fn.py b/src/lightning/app/frontend/panel/panel_serve_render_fn.py index 0c06cccd068db..4ba8de45813a0 100644 --- a/src/lightning/app/frontend/panel/panel_serve_render_fn.py +++ b/src/lightning/app/frontend/panel/panel_serve_render_fn.py @@ -28,6 +28,7 @@ python panel_serve_render_fn """ + import inspect import os import pydoc diff --git a/src/lightning/app/frontend/streamlit_base.py b/src/lightning/app/frontend/streamlit_base.py index b03628b4495b0..a0ebc4b7cf66d 100644 --- a/src/lightning/app/frontend/streamlit_base.py +++ b/src/lightning/app/frontend/streamlit_base.py @@ -16,6 +16,7 @@ From here, we will call the render function that the user provided in ``configure_layout``. """ + import os import pydoc from typing import Callable diff --git a/src/lightning/app/frontend/utils.py b/src/lightning/app/frontend/utils.py index 80898f1213d67..c05f7e143fe11 100644 --- a/src/lightning/app/frontend/utils.py +++ b/src/lightning/app/frontend/utils.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Utility functions for lightning Frontends.""" + from __future__ import annotations import inspect diff --git a/src/lightning/app/launcher/launcher.py b/src/lightning/app/launcher/launcher.py index f5732855d0dbf..7dc9fca11db42 100644 --- a/src/lightning/app/launcher/launcher.py +++ b/src/lightning/app/launcher/launcher.py @@ -66,12 +66,10 @@ def start_application_server( if isinstance(queues, Dict): kwargs.update(queues) else: - kwargs.update( - { - "api_publish_state_queue": queue_system.get_api_state_publish_queue(queue_id=queue_id), - "api_response_queue": queue_system.get_api_response_queue(queue_id=queue_id), - } - ) + kwargs.update({ + "api_publish_state_queue": queue_system.get_api_state_publish_queue(queue_id=queue_id), + "api_response_queue": queue_system.get_api_response_queue(queue_id=queue_id), + }) app = load_app_from_file(entrypoint_file) diff --git a/src/lightning/app/runners/multiprocess.py b/src/lightning/app/runners/multiprocess.py index 2db62668937fc..c3217197a6a33 100644 --- a/src/lightning/app/runners/multiprocess.py +++ b/src/lightning/app/runners/multiprocess.py @@ -125,14 +125,12 @@ def dispatch(self, *args: Any, open_ui: bool = True, **kwargs: Any): # wait for server to be ready has_started_queue.get() - if all( - [ - open_ui, - "PYTEST_CURRENT_TEST" not in os.environ, - not _is_headless(self.app), - constants.LIGHTNING_CLOUDSPACE_HOST is None, - ] - ): + if all([ + open_ui, + "PYTEST_CURRENT_TEST" not in os.environ, + not _is_headless(self.app), + constants.LIGHTNING_CLOUDSPACE_HOST is None, + ]): click.launch(self._get_app_url()) # Connect the runtime to the application. diff --git a/src/lightning/app/testing/testing.py b/src/lightning/app/testing/testing.py index 456c645d49e7b..c1f50fc3f3643 100644 --- a/src/lightning/app/testing/testing.py +++ b/src/lightning/app/testing/testing.py @@ -338,9 +338,10 @@ def run_app_in_cloud( browser = p.chromium.launch(headless=bool(int(os.getenv("HEADLESS", "0")))) context = browser.new_context( # Eventually this will need to be deleted - http_credentials=HttpCredentials( - {"username": os.getenv("LAI_USER", "").strip(), "password": os.getenv("LAI_PASS", "")} - ), + http_credentials=HttpCredentials({ + "username": os.getenv("LAI_USER", "").strip(), + "password": os.getenv("LAI_PASS", ""), + }), record_video_dir=os.path.join(_Config.video_location, TEST_APP_NAME), record_har_path=_Config.har_location, ) diff --git a/src/lightning/app/utilities/scheduler.py b/src/lightning/app/utilities/scheduler.py index aa661aaac820c..67b081fb56000 100644 --- a/src/lightning/app/utilities/scheduler.py +++ b/src/lightning/app/utilities/scheduler.py @@ -46,13 +46,11 @@ def run_once(self): if current_date > next_event: component_delta = ComponentDelta( id=metadata["name"], - delta=Delta( - { - "values_changed": { - f"root['calls']['scheduling']['{call_hash}']['running']": {"new_value": True} - } + delta=Delta({ + "values_changed": { + f"root['calls']['scheduling']['{call_hash}']['running']": {"new_value": True} } - ), + }), ) self._app.delta_queue.put(component_delta) metadata["start_time"] = next_event.isoformat() diff --git a/src/lightning/app/utilities/tree.py b/src/lightning/app/utilities/tree.py index cd53701fc7799..b60d8a85dc30b 100644 --- a/src/lightning/app/utilities/tree.py +++ b/src/lightning/app/utilities/tree.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Utilities for traversing the tree of components in an app.""" + from typing import TYPE_CHECKING, Type import lightning.app diff --git a/src/lightning/app/utilities/types.py b/src/lightning/app/utilities/types.py index 177ec0c5247d8..a59994a845cfb 100644 --- a/src/lightning/app/utilities/types.py +++ b/src/lightning/app/utilities/types.py @@ -24,5 +24,4 @@ @runtime_checkable class Hashable(Protocol): - def to_dict(self) -> t.Dict[str, t.Any]: - ... + def to_dict(self) -> t.Dict[str, t.Any]: ... diff --git a/src/lightning/data/processing/dns.py b/src/lightning/data/processing/dns.py index f1ca83dbc0e2e..9dce141b1bbb2 100644 --- a/src/lightning/data/processing/dns.py +++ b/src/lightning/data/processing/dns.py @@ -15,6 +15,7 @@ def optimize_dns_context(enable: bool) -> Any: optimize_dns(False) # always disable the optimize DNS raise e + def optimize_dns(enable: bool) -> None: if not _IS_IN_STUDIO: return @@ -22,11 +23,14 @@ def optimize_dns(enable: bool) -> None: with open("/etc/resolv.conf") as f: lines = f.readlines() - if ( - (enable and any("127.0.0.53" in line for line in lines)) - or (not enable and any("127.0.0.1" in line for line in lines)) - ): # noqa E501 - Popen(f"sudo /home/zeus/miniconda3/envs/cloudspace/bin/python -c 'from lightning.data.processing.dns import _optimize_dns; _optimize_dns({enable})'", shell=True).wait() # noqa E501 + if (enable and any("127.0.0.53" in line for line in lines)) or ( + not enable and any("127.0.0.1" in line for line in lines) + ): # E501 + Popen( + f"sudo /home/zeus/miniconda3/envs/cloudspace/bin/python -c 'from lightning.data.processing.dns import _optimize_dns; _optimize_dns({enable})'", # noqa E501 + shell=True, + ).wait() + def _optimize_dns(enable: bool) -> None: with open("/etc/resolv.conf") as f: @@ -36,9 +40,9 @@ def _optimize_dns(enable: bool) -> None: for line in lines: if "nameserver 127" in line: if enable: - write_lines.append('nameserver 127.0.0.1\n') + write_lines.append("nameserver 127.0.0.1\n") else: - write_lines.append('nameserver 127.0.0.53\n') + write_lines.append("nameserver 127.0.0.53\n") else: write_lines.append(line) diff --git a/src/lightning/data/processing/functions.py b/src/lightning/data/processing/functions.py index 00905aa40dcd4..16506ac6eeeb2 100644 --- a/src/lightning/data/processing/functions.py +++ b/src/lightning/data/processing/functions.py @@ -187,7 +187,8 @@ def map( if not _IS_IN_STUDIO and (machine is not None or num_nodes is not None): raise ValueError( "Only https://lightning.ai/ supports multiple nodes or selecting a machine." - " Create an account to try it out.") + " Create an account to try it out." + ) if not _IS_IN_STUDIO: print( diff --git a/src/lightning/data/processing/image.py b/src/lightning/data/processing/image.py index 0bf110b40a37e..eace5a552d583 100644 --- a/src/lightning/data/processing/image.py +++ b/src/lightning/data/processing/image.py @@ -8,6 +8,7 @@ # Credit to the https://github.com/rom1504/img2dataset Github repo # The code was taken from there. It has a MIT License. 
+ def _download_image( url: str, timeout: int = 10, diff --git a/src/lightning/data/processing/readers.py b/src/lightning/data/processing/readers.py index 8519795b6207c..63ea44eeb1809 100644 --- a/src/lightning/data/processing/readers.py +++ b/src/lightning/data/processing/readers.py @@ -13,7 +13,6 @@ class BaseReader(ABC): - def get_num_nodes(self) -> int: return int(os.getenv("DATA_OPTIMIZER_NUM_NODES", 1)) @@ -35,13 +34,13 @@ def read(self, item: Any) -> Any: @dataclass class ParquetSlice: """Keep track of a parquet file slice with its filepath, start and end.""" + filepath: str start: int end: int class ParquetReader(BaseReader): - def __init__(self, num_rows: Optional[int] = 2048, to_pandas: bool = True) -> None: self.num_rows = num_rows self.to_pandas = to_pandas @@ -52,12 +51,14 @@ def __init__(self, num_rows: Optional[int] = 2048, to_pandas: bool = True) -> No def _get_num_rows(self, path: str) -> int: if _PYARROW_AVAILABLE: import pyarrow.dataset as ds + df = ds.dataset(path).scanner() return df.count_rows() # FIXED: There is a bug in polars. This leads to read_parquet to hang. if _POLARS_AVAILABLE: import polars as pol + df = pol.scan_parquet(path) num_rows = df.select(pol.len()).collect().item() return num_rows @@ -67,6 +68,7 @@ def _get_num_rows(self, path: str) -> int: def read(self, item: ParquetSlice) -> Any: if _POLARS_AVAILABLE: import polars as pol + df = pol.scan_parquet(item.filepath).slice(item.start, item.end).collect() if self.to_pandas: @@ -88,7 +90,6 @@ def read(self, item: ParquetSlice) -> Any: raise RuntimeError("Please, install either pyarrow or polars.") - def items_to_workers(self, items: Any, num_workers: int) -> List[List[ParquetSlice]]: intervals = [(0, self._get_num_rows(item)) for item in items] @@ -97,7 +98,8 @@ def items_to_workers(self, items: Any, num_workers: int) -> List[List[ParquetSli fake_distributed_env = _DistributedEnv(world_size, 0, self.get_num_nodes()) parquet_indexes_per_worker, p_slices_per_worker = _associate_chunks_and_internals_to_ranks( - fake_distributed_env, list(range(len(items))), intervals, False) + fake_distributed_env, list(range(len(items))), intervals, False + ) workers_user_items: List[List[ParquetSlice]] = [[] for _ in range(num_workers)] @@ -111,9 +113,11 @@ def items_to_workers(self, items: Any, num_workers: int) -> List[List[ParquetSli if self.num_rows: workers_user_items[worker_idx % num_workers].extend([ ParquetSlice( - items[parquet_index], p_slice_start, p_slice_start + self.num_rows - if p_slice[1] > (p_slice_start + self.num_rows) else - p_slice[1] + items[parquet_index], + p_slice_start, + p_slice_start + self.num_rows + if p_slice[1] > (p_slice_start + self.num_rows) + else p_slice[1], ) for parquet_index, p_slice in zip(parquet_indexes, p_slices) for p_slice_start in range(p_slice[0], p_slice[1] + self.num_rows, self.num_rows) diff --git a/src/lightning/data/streaming/cache.py b/src/lightning/data/streaming/cache.py index 305f393d276ac..1fa8874a185fc 100644 --- a/src/lightning/data/streaming/cache.py +++ b/src/lightning/data/streaming/cache.py @@ -63,7 +63,8 @@ def __init__( if not _LIGHTNING_CLOUD_LATEST: raise ModuleNotFoundError( "The `lightning-cloud` package in your environement is out-dated." - " Run: `pip install -U lightning-cloud` to resolve this.") + " Run: `pip install -U lightning-cloud` to resolve this." 
+ ) input_dir = _resolve_dir(input_dir) self._cache_dir = input_dir.path diff --git a/src/lightning/data/streaming/sampler.py b/src/lightning/data/streaming/sampler.py index 9e4925ba8a147..9d1274d87fa61 100644 --- a/src/lightning/data/streaming/sampler.py +++ b/src/lightning/data/streaming/sampler.py @@ -180,16 +180,14 @@ def __iter_from_shuffled_chunks( interval_indices = np.arange(chunk_interval[0], chunk_interval[1]) shuffled_interval_indices = np.random.permutation(interval_indices).tolist() is_empty = len(indices_per_workers[worker_id]) == 0 - indices_per_workers[worker_id].extend( - [ - ChunkedIndex( - index, - chunk_index, - chunk_indexes=chunks_per_workers[worker_id] if j == 0 and is_empty else None, - ) - for j, index in enumerate(shuffled_interval_indices) - ] - ) + indices_per_workers[worker_id].extend([ + ChunkedIndex( + index, + chunk_index, + chunk_indexes=chunks_per_workers[worker_id] if j == 0 and is_empty else None, + ) + for j, index in enumerate(shuffled_interval_indices) + ]) indices_per_workers_splitted = [self._chunk_list(indices, self._batch_size) for indices in indices_per_workers] diff --git a/src/lightning/data/streaming/serializers.py b/src/lightning/data/streaming/serializers.py index 9e5a92f520741..871843f4d24bf 100644 --- a/src/lightning/data/streaming/serializers.py +++ b/src/lightning/data/streaming/serializers.py @@ -40,6 +40,7 @@ from torchvision.io import decode_jpeg from torchvision.transforms.functional import pil_to_tensor + class Serializer(ABC): """The base interface for any serializers. @@ -322,22 +323,20 @@ def can_serialize(self, data: Any) -> bool: return isinstance(data, str) and os.path.exists(data) and any(data.endswith(ext) for ext in self._EXTENSIONS) -_SERIALIZERS = OrderedDict( - **{ - "video": VideoSerializer(), - "tif": FileSerializer(), - "file": FileSerializer(), - "pil": PILSerializer(), - "int": IntSerializer(), - "jpeg": JPEGSerializer(), - "bytes": BytesSerializer(), - "no_header_numpy": NoHeaderNumpySerializer(), - "numpy": NumpySerializer(), - "no_header_tensor": NoHeaderTensorSerializer(), - "tensor": TensorSerializer(), - "pickle": PickleSerializer(), - } -) +_SERIALIZERS = OrderedDict(**{ + "video": VideoSerializer(), + "tif": FileSerializer(), + "file": FileSerializer(), + "pil": PILSerializer(), + "int": IntSerializer(), + "jpeg": JPEGSerializer(), + "bytes": BytesSerializer(), + "no_header_numpy": NoHeaderNumpySerializer(), + "numpy": NumpySerializer(), + "no_header_tensor": NoHeaderTensorSerializer(), + "tensor": TensorSerializer(), + "pickle": PickleSerializer(), +}) def _get_serializers(serializers: Optional[Dict[str, Serializer]]) -> Dict[str, Serializer]: diff --git a/src/lightning/fabric/CHANGELOG.md b/src/lightning/fabric/CHANGELOG.md index 95332ac22cb29..2989189986f96 100644 --- a/src/lightning/fabric/CHANGELOG.md +++ b/src/lightning/fabric/CHANGELOG.md @@ -5,6 +5,13 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
+## [2.2.1] - 2024-03-04 + +### Fixed + +- Fixed an issue with CSVLogger trying to append to file from a previous run when the version is set manually ([#19446](https://github.com/Lightning-AI/lightning/pull/19446)) + + ## [2.2.0] - 2024-02-08 ### Added diff --git a/src/lightning/fabric/__init__.py b/src/lightning/fabric/__init__.py index 652fe01943733..1add6e93f34d0 100644 --- a/src/lightning/fabric/__init__.py +++ b/src/lightning/fabric/__init__.py @@ -1,4 +1,5 @@ """Root package info.""" + import logging import os diff --git a/src/lightning/fabric/cli.py b/src/lightning/fabric/cli.py index 2805675cecf70..6aec4bec0c436 100644 --- a/src/lightning/fabric/cli.py +++ b/src/lightning/fabric/cli.py @@ -53,8 +53,10 @@ def _legacy_main() -> None: Raises deprecation warning and runs through fabric cli if necessary, else runs the entrypoint directly """ - print("`lightning run model` is deprecated and will be removed in future versions." - " Please call `fabric run model` instead.") + print( + "`lightning run model` is deprecated and will be removed in future versions." + " Please call `fabric run model` instead." + ) args = sys.argv[1:] if args and args[0] == "run" and args[1] == "model": _main() diff --git a/src/lightning/fabric/fabric.py b/src/lightning/fabric/fabric.py index bc07e633a911e..95328229cebe3 100644 --- a/src/lightning/fabric/fabric.py +++ b/src/lightning/fabric/fabric.py @@ -497,16 +497,13 @@ def autocast(self) -> ContextManager: return self._precision.forward_context() @overload - def to_device(self, obj: nn.Module) -> nn.Module: - ... + def to_device(self, obj: nn.Module) -> nn.Module: ... @overload - def to_device(self, obj: Tensor) -> Tensor: - ... + def to_device(self, obj: Tensor) -> Tensor: ... @overload - def to_device(self, obj: Any) -> Any: - ... + def to_device(self, obj: Any) -> Any: ... def to_device(self, obj: Union[nn.Module, Tensor, Any]) -> Union[nn.Module, Tensor, Any]: r"""Move a :class:`torch.nn.Module` or a collection of tensors to the current device, if it is not already on diff --git a/src/lightning/fabric/loggers/csv_logs.py b/src/lightning/fabric/loggers/csv_logs.py index d7bbf1ad720e8..efce80a6fb895 100644 --- a/src/lightning/fabric/loggers/csv_logs.py +++ b/src/lightning/fabric/loggers/csv_logs.py @@ -40,6 +40,8 @@ class CSVLogger(Logger): name: Experiment name. Defaults to ``'lightning_logs'``. version: Experiment version. If version is not specified the logger inspects the save directory for existing versions, then automatically assigns the next available version. + If the version is specified, and the directory already contains a metrics file for that version, it will be + overwritten. prefix: A string to put at the beginning of metric keys. flush_logs_every_n_steps: How often to flush logs to disk (defaults to every 100 steps). @@ -203,15 +205,11 @@ def __init__(self, log_dir: str) -> None: self._fs = get_filesystem(log_dir) self.log_dir = log_dir - if self._fs.exists(self.log_dir) and self._fs.listdir(self.log_dir): - rank_zero_warn( - f"Experiment logs directory {self.log_dir} exists and is not empty." - " Previous log files in this directory will be deleted when the new ones are saved!" 
- ) - self._fs.makedirs(self.log_dir, exist_ok=True) - self.metrics_file_path = os.path.join(self.log_dir, self.NAME_METRICS_FILE) + self._check_log_dir_exists() + self._fs.makedirs(self.log_dir, exist_ok=True) + def log_metrics(self, metrics_dict: Dict[str, float], step: Optional[int] = None) -> None: """Record metrics.""" @@ -264,3 +262,12 @@ def _rewrite_with_new_header(self, fieldnames: List[str]) -> None: writer = csv.DictWriter(file, fieldnames=fieldnames) writer.writeheader() writer.writerows(metrics) + + def _check_log_dir_exists(self) -> None: + if self._fs.exists(self.log_dir) and self._fs.listdir(self.log_dir): + rank_zero_warn( + f"Experiment logs directory {self.log_dir} exists and is not empty." + " Previous log files in this directory will be deleted when the new ones are saved!" + ) + if self._fs.isfile(self.metrics_file_path): + self._fs.rm_file(self.metrics_file_path) diff --git a/src/lightning/fabric/loggers/tensorboard.py b/src/lightning/fabric/loggers/tensorboard.py index 3828b9ee0f4f3..234903e7ac25d 100644 --- a/src/lightning/fabric/loggers/tensorboard.py +++ b/src/lightning/fabric/loggers/tensorboard.py @@ -78,6 +78,7 @@ class TensorBoardLogger(Logger): logger.finalize("success") """ + LOGGER_JOIN_CHAR = "-" def __init__( diff --git a/src/lightning/fabric/plugins/collectives/collective.py b/src/lightning/fabric/plugins/collectives/collective.py index 5c655f189c561..9d24501dce6da 100644 --- a/src/lightning/fabric/plugins/collectives/collective.py +++ b/src/lightning/fabric/plugins/collectives/collective.py @@ -21,13 +21,11 @@ def __init__(self) -> None: @property @abstractmethod - def rank(self) -> int: - ... + def rank(self) -> int: ... @property @abstractmethod - def world_size(self) -> int: - ... + def world_size(self) -> int: ... @property def group(self) -> CollectibleGroup: @@ -38,78 +36,61 @@ def group(self) -> CollectibleGroup: return self._group @abstractmethod - def broadcast(self, tensor: Tensor, src: int) -> Tensor: - ... + def broadcast(self, tensor: Tensor, src: int) -> Tensor: ... @abstractmethod - def all_reduce(self, tensor: Tensor, op: str) -> Tensor: - ... + def all_reduce(self, tensor: Tensor, op: str) -> Tensor: ... @abstractmethod - def reduce(self, tensor: Tensor, dst: int, op: str) -> Tensor: - ... + def reduce(self, tensor: Tensor, dst: int, op: str) -> Tensor: ... @abstractmethod - def all_gather(self, tensor_list: List[Tensor], tensor: Tensor) -> List[Tensor]: - ... + def all_gather(self, tensor_list: List[Tensor], tensor: Tensor) -> List[Tensor]: ... @abstractmethod - def gather(self, tensor: Tensor, gather_list: List[Tensor], dst: int = 0) -> List[Tensor]: - ... + def gather(self, tensor: Tensor, gather_list: List[Tensor], dst: int = 0) -> List[Tensor]: ... @abstractmethod - def scatter(self, tensor: Tensor, scatter_list: List[Tensor], src: int = 0) -> Tensor: - ... + def scatter(self, tensor: Tensor, scatter_list: List[Tensor], src: int = 0) -> Tensor: ... @abstractmethod - def reduce_scatter(self, output: Tensor, input_list: List[Tensor], op: str) -> Tensor: - ... + def reduce_scatter(self, output: Tensor, input_list: List[Tensor], op: str) -> Tensor: ... @abstractmethod - def all_to_all(self, output_tensor_list: List[Tensor], input_tensor_list: List[Tensor]) -> List[Tensor]: - ... + def all_to_all(self, output_tensor_list: List[Tensor], input_tensor_list: List[Tensor]) -> List[Tensor]: ... @abstractmethod - def send(self, tensor: Tensor, dst: int, tag: int = 0) -> None: - ... 
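As a minimal sketch of the CSVLogger behaviour this release fixes (assuming the `lightning` package is installed; the directory and metric names below are made up), re-creating the logger with a manually pinned version now starts a fresh metrics file instead of appending to the one left over from the previous run:

from lightning.fabric.loggers import CSVLogger

# First run with a manually pinned version.
logger = CSVLogger(root_dir="logs", name="demo", version=3)
logger.log_metrics({"loss": 0.42}, step=0)
logger.save()

# A second run re-uses version 3. With the `_check_log_dir_exists` change above,
# the stale metrics.csv is removed (after a warning) rather than appended to.
logger = CSVLogger(root_dir="logs", name="demo", version=3)
logger.log_metrics({"loss": 0.40, "lr": 1e-3}, step=0)
logger.save()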
+ def send(self, tensor: Tensor, dst: int, tag: int = 0) -> None: ... @abstractmethod - def recv(self, tensor: Tensor, src: Optional[int] = None, tag: int = 0) -> Tensor: - ... + def recv(self, tensor: Tensor, src: Optional[int] = None, tag: int = 0) -> Tensor: ... @abstractmethod - def barrier(self, device_ids: Optional[List[int]] = None) -> None: - ... + def barrier(self, device_ids: Optional[List[int]] = None) -> None: ... @classmethod @abstractmethod - def is_available(cls) -> bool: - ... + def is_available(cls) -> bool: ... @classmethod @abstractmethod - def is_initialized(cls) -> bool: - ... + def is_initialized(cls) -> bool: ... @classmethod @abstractmethod - def init_group(cls, **kwargs: Any) -> None: - ... + def init_group(cls, **kwargs: Any) -> None: ... @classmethod @abstractmethod - def new_group(cls, **kwargs: Any) -> CollectibleGroup: - ... + def new_group(cls, **kwargs: Any) -> CollectibleGroup: ... @classmethod @abstractmethod - def destroy_group(cls, group: CollectibleGroup) -> None: - ... + def destroy_group(cls, group: CollectibleGroup) -> None: ... @classmethod @abstractmethod - def _convert_to_native_op(cls, op: str) -> Any: - ... + def _convert_to_native_op(cls, op: str) -> Any: ... def setup(self, **kwargs: Any) -> Self: if not self.is_initialized(): diff --git a/src/lightning/fabric/plugins/precision/transformer_engine.py b/src/lightning/fabric/plugins/precision/transformer_engine.py index 1892d18cb3afb..cb5296b21fc39 100644 --- a/src/lightning/fabric/plugins/precision/transformer_engine.py +++ b/src/lightning/fabric/plugins/precision/transformer_engine.py @@ -116,12 +116,10 @@ def module_init_context(self) -> ContextManager: if self.replace_layers: import transformer_engine.pytorch as te - context_manager = _ClassReplacementContextManager( - { - "torch.nn.Linear": te.Linear, - "torch.nn.LayerNorm": te.LayerNorm, - } - ) + context_manager = _ClassReplacementContextManager({ + "torch.nn.Linear": te.Linear, + "torch.nn.LayerNorm": te.LayerNorm, + }) stack.enter_context(context_manager) stack.enter_context(dtype_ctx) return stack diff --git a/src/lightning/fabric/strategies/deepspeed.py b/src/lightning/fabric/strategies/deepspeed.py index 66f191920ece5..2a1a1272b498e 100644 --- a/src/lightning/fabric/strategies/deepspeed.py +++ b/src/lightning/fabric/strategies/deepspeed.py @@ -745,12 +745,10 @@ def _create_default_config( "max_in_cpu": max_in_cpu, "pin_memory": pin_memory, } - cfg.update( - { - "zero_allow_untested_optimizer": zero_allow_untested_optimizer, - "zero_optimization": zero_config, - } - ) + cfg.update({ + "zero_allow_untested_optimizer": zero_allow_untested_optimizer, + "zero_optimization": zero_config, + }) if logging_batch_size_per_gpu: cfg["train_micro_batch_size_per_gpu"] = logging_batch_size_per_gpu return cfg diff --git a/src/lightning/fabric/strategies/fsdp.py b/src/lightning/fabric/strategies/fsdp.py index 7acc6d6f02eda..961ab0c7dc8ba 100644 --- a/src/lightning/fabric/strategies/fsdp.py +++ b/src/lightning/fabric/strategies/fsdp.py @@ -66,6 +66,7 @@ _TORCH_GREATER_EQUAL_2_0, _TORCH_GREATER_EQUAL_2_1, _TORCH_GREATER_EQUAL_2_2, + _TORCH_GREATER_EQUAL_2_3, ) from lightning.fabric.utilities.init import _EmptyInit from lightning.fabric.utilities.load import _METADATA_FILENAME, _lazy_load, _materialize_tensors, _move_state_into @@ -448,7 +449,6 @@ def save_checkpoint( if path.is_dir() and self._state_dict_type == "full" and not _is_sharded_checkpoint(path): raise IsADirectoryError(f"The checkpoint path exists and is a directory: {path}") - from 
torch.distributed.checkpoint import FileSystemWriter, save_state_dict from torch.distributed.fsdp import FullyShardedDataParallel as FSDP modules = [module for module in state.values() if _has_fsdp_modules(module)] @@ -491,9 +491,7 @@ def save_checkpoint( target_dict = metadata _apply_filter(key, filter or {}, converted, target_dict) - # FSDP's FileSystemWriter streams the tensors to disk to minimize memory peaks - writer = FileSystemWriter(path=path, single_file_per_rank=True) - save_state_dict(converted_state, writer) + _distributed_checkpoint_save(converted_state, path) if self.global_rank == 0: torch.save(metadata, path / _METADATA_FILENAME) @@ -555,16 +553,10 @@ def load_checkpoint( "Loading a single optimizer object from a checkpoint is not supported yet with the FSDP strategy." ) - from torch.distributed.checkpoint import FileSystemReader from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp import OptimStateKeyType - if _TORCH_GREATER_EQUAL_2_2: - from torch.distributed.checkpoint import load - else: - from torch.distributed.checkpoint import load_state_dict as load # deprecated - modules = {key: module for key, module in state.items() if _has_fsdp_modules(module)} if len(modules) == 0: raise ValueError( @@ -583,26 +575,31 @@ def load_checkpoint( if _is_sharded_checkpoint(path): state_dict_ctx = _get_sharded_state_dict_context(module) - reader = FileSystemReader(path=path) with state_dict_ctx: module_state = {module_key: module.state_dict()} - load(module_state, reader) + _distributed_checkpoint_load(module_state, path) module.load_state_dict(module_state[module_key], strict=strict) - # the optimizer states must be loaded separately - for optim_key, optim in optimizers.items(): - optim_state = load_sharded_optimizer_state_dict( - model_state_dict=module_state[module_key], - optimizer_key=optim_key, - storage_reader=reader, - ) - flattened_osd = FSDP.optim_state_dict_to_load( - optim_state_dict=optim_state[optim_key], - model=module, - optim=optim, - ) - optim.load_state_dict(flattened_osd) + if optimizers: + from torch.distributed.checkpoint import FileSystemReader + + # TODO: replace with newer APIs + # https://github.com/pytorch/pytorch/issues/119800#issuecomment-1942156271 + reader = FileSystemReader(path=path) + # the optimizer states must be loaded separately + for optim_key, optim in optimizers.items(): + optim_state = load_sharded_optimizer_state_dict( + model_state_dict=module_state[module_key], + optimizer_key=optim_key, + storage_reader=reader, + ) + flattened_osd = FSDP.optim_state_dict_to_load( + optim_state_dict=optim_state[optim_key], + model=module, + optim=optim, + ) + optim.load_state_dict(flattened_osd) # Load metadata (anything not a module or optimizer) metadata = torch.load(path / _METADATA_FILENAME) @@ -803,8 +800,10 @@ def _init_sharding_strategy(sharding_strategy: "_SHARDING_STRATEGY", kwargs: Dic strategy = ShardingStrategy[sharding_strategy.upper()] if isinstance(sharding_strategy, str) else sharding_strategy if ( - "HYBRID" in strategy.name and kwargs.get("auto_wrap_policy") is None - and kwargs.get("process_group") is None and kwargs.get("device_mesh") is None + "HYBRID" in strategy.name + and kwargs.get("auto_wrap_policy") is None + and kwargs.get("process_group") is None + and kwargs.get("device_mesh") is None ): raise RuntimeError( "The hybrid sharding strategy requires you to pass at least one of the parameters: 
`auto_wrap_policy`," @@ -920,3 +919,40 @@ def _move_torchmetrics_to_device(module: torch.nn.Module, device: torch.device) for metric in (m for m in module.modules() if isinstance(m, Metric)): metric.to(device) # `.to()` is in-place + + +def _distributed_checkpoint_save(converted_state: Dict[str, Any], path: Path) -> None: + if _TORCH_GREATER_EQUAL_2_3: + from torch.distributed.checkpoint import save + + # let torch automatically infer the writer to use. This might also support fsspec paths in the future + # https://github.com/pytorch/pytorch/issues/118036 + save(converted_state, checkpoint_id=path) + else: # deprecated + from torch.distributed.checkpoint import FileSystemWriter + + if _TORCH_GREATER_EQUAL_2_2: + from torch.distributed.checkpoint import save + else: + from torch.distributed.checkpoint import save_state_dict as save + # FSDP's FileSystemWriter streams the tensors to disk to minimize memory peaks + writer = FileSystemWriter(path=path, single_file_per_rank=True) + save(converted_state, writer) + + +def _distributed_checkpoint_load(module_state: Dict[str, Any], path: Path) -> None: + if _TORCH_GREATER_EQUAL_2_3: + from torch.distributed.checkpoint import load + + # let torch automatically infer the reader to use. This might also support fsspec paths in the future + # https://github.com/pytorch/pytorch/issues/118036 + load(module_state, checkpoint_id=path) + else: # deprecated + from torch.distributed.checkpoint import FileSystemReader + + if _TORCH_GREATER_EQUAL_2_2: + from torch.distributed.checkpoint import load + else: + from torch.distributed.checkpoint import load_state_dict as load + reader = FileSystemReader(path=path) + load(module_state, reader) diff --git a/src/lightning/fabric/utilities/apply_func.py b/src/lightning/fabric/utilities/apply_func.py index 7f9b45eb39d46..90cc3c6c8e3ef 100644 --- a/src/lightning/fabric/utilities/apply_func.py +++ b/src/lightning/fabric/utilities/apply_func.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Utilities used for collections.""" + from abc import ABC from functools import partial from typing import Any, Callable, List, Tuple, Union diff --git a/src/lightning/fabric/utilities/cloud_io.py b/src/lightning/fabric/utilities/cloud_io.py index ca6987448179f..af795a4801ef0 100644 --- a/src/lightning/fabric/utilities/cloud_io.py +++ b/src/lightning/fabric/utilities/cloud_io.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Utilities related to data saving/loading.""" + import io import logging from pathlib import Path diff --git a/src/lightning/fabric/utilities/data.py b/src/lightning/fabric/utilities/data.py index 5bca7d37b6270..1ec0edce38050 100644 --- a/src/lightning/fabric/utilities/data.py +++ b/src/lightning/fabric/utilities/data.py @@ -108,9 +108,9 @@ def _get_dataloader_init_args_and_kwargs( if was_wrapped: # if the dataloader was wrapped in a hook, only take arguments with default values # and assume user passes their kwargs correctly - params.update( - {k: v for k, v in inspect.signature(DataLoader.__init__).parameters.items() if v.default is not v.empty} - ) + params.update({ + k: v for k, v in inspect.signature(DataLoader.__init__).parameters.items() if v.default is not v.empty + }) else: params.update(inspect.signature(DataLoader.__init__).parameters) params.pop("self", None) diff --git a/src/lightning/fabric/utilities/imports.py b/src/lightning/fabric/utilities/imports.py index f7fa1ad2ad78a..cc069a2a73338 100644 --- a/src/lightning/fabric/utilities/imports.py +++ b/src/lightning/fabric/utilities/imports.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """General utilities.""" + import operator import platform import sys @@ -26,8 +27,9 @@ _IS_INTERACTIVE = hasattr(sys, "ps1") or bool(sys.flags.interactive) _TORCH_GREATER_EQUAL_2_0 = compare_version("torch", operator.ge, "2.0.0") -_TORCH_GREATER_EQUAL_2_1 = compare_version("torch", operator.ge, "2.1.0", use_base_version=True) -_TORCH_GREATER_EQUAL_2_2 = compare_version("torch", operator.ge, "2.2.0", use_base_version=True) +_TORCH_GREATER_EQUAL_2_1 = compare_version("torch", operator.ge, "2.1.0") +_TORCH_GREATER_EQUAL_2_2 = compare_version("torch", operator.ge, "2.2.0") +_TORCH_GREATER_EQUAL_2_3 = compare_version("torch", operator.ge, "2.3.0", use_base_version=True) _TORCH_EQUAL_2_0 = _TORCH_GREATER_EQUAL_2_0 and not _TORCH_GREATER_EQUAL_2_1 _PYTHON_GREATER_EQUAL_3_8_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 8) diff --git a/src/lightning/fabric/utilities/rank_zero.py b/src/lightning/fabric/utilities/rank_zero.py index 9120e9748b1a6..f54e850d45fcf 100644 --- a/src/lightning/fabric/utilities/rank_zero.py +++ b/src/lightning/fabric/utilities/rank_zero.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Utilities that can be used for calling functions on a particular rank.""" + import logging import os from functools import wraps @@ -52,12 +53,10 @@ def _get_rank() -> Optional[int]: P = ParamSpec("P") @overload - def rank_zero_only(fn: Callable[P, T]) -> Callable[P, Optional[T]]: - ... + def rank_zero_only(fn: Callable[P, T]) -> Callable[P, Optional[T]]: ... @overload - def rank_zero_only(fn: Callable[P, T], default: T) -> Callable[P, T]: - ... + def rank_zero_only(fn: Callable[P, T], default: T) -> Callable[P, T]: ... 
def rank_zero_only(fn: Callable[P, T], default: Optional[T] = None) -> Callable[P, Optional[T]]: @wraps(fn) diff --git a/src/lightning/fabric/utilities/registry.py b/src/lightning/fabric/utilities/registry.py index cbd6bcdce6604..a3b2ac442b46b 100644 --- a/src/lightning/fabric/utilities/registry.py +++ b/src/lightning/fabric/utilities/registry.py @@ -39,9 +39,7 @@ def _load_external_callbacks(group: str) -> List[Any]: from importlib.metadata import entry_points factories = ( - entry_points(group=group) - if _PYTHON_GREATER_EQUAL_3_10_0 - else entry_points().get(group, {}) # type: ignore[arg-type] + entry_points(group=group) if _PYTHON_GREATER_EQUAL_3_10_0 else entry_points().get(group, {}) # type: ignore[arg-type] ) else: from pkg_resources import iter_entry_points diff --git a/src/lightning/fabric/utilities/throughput.py b/src/lightning/fabric/utilities/throughput.py index d96673df6451d..d6d96a79ab941 100644 --- a/src/lightning/fabric/utilities/throughput.py +++ b/src/lightning/fabric/utilities/throughput.py @@ -172,17 +172,16 @@ def compute(self) -> _THROUGHPUT_METRICS: # we are safe from ZeroDivisionError thanks to `_MonotonicWindow` dev_samples_per_sec = elapsed_samples / elapsed_time dev_batches_per_sec = elapsed_batches / elapsed_time - metrics.update( - { - f"device{self.separator}batches_per_sec": elapsed_batches / elapsed_time, - f"device{self.separator}samples_per_sec": dev_samples_per_sec, - } - ) + metrics.update({ + f"device{self.separator}batches_per_sec": elapsed_batches / elapsed_time, + f"device{self.separator}samples_per_sec": dev_samples_per_sec, + }) if add_global_metrics: samples_per_sec = dev_batches_per_sec * self.world_size - metrics.update( - {"batches_per_sec": samples_per_sec, "samples_per_sec": dev_samples_per_sec * self.world_size} - ) + metrics.update({ + "batches_per_sec": samples_per_sec, + "samples_per_sec": dev_samples_per_sec * self.world_size, + }) if len(self._lengths) == self._lengths.maxlen: elapsed_lengths = self._lengths[-1] - self._lengths[0] diff --git a/src/lightning/fabric/utilities/types.py b/src/lightning/fabric/utilities/types.py index da987515db343..c4bc32f3cf319 100644 --- a/src/lightning/fabric/utilities/types.py +++ b/src/lightning/fabric/utilities/types.py @@ -58,20 +58,16 @@ class _Stateful(Protocol[_DictKey]): """This class is used to detect if an object is stateful using `isinstance(obj, _Stateful)`.""" - def state_dict(self) -> Dict[_DictKey, Any]: - ... + def state_dict(self) -> Dict[_DictKey, Any]: ... - def load_state_dict(self, state_dict: Dict[_DictKey, Any]) -> None: - ... + def load_state_dict(self, state_dict: Dict[_DictKey, Any]) -> None: ... @runtime_checkable class CollectibleGroup(Protocol): - def size(self) -> int: - ... + def size(self) -> int: ... - def rank(self) -> int: - ... + def rank(self) -> int: ... # Inferred from `torch.optim.lr_scheduler.pyi` @@ -81,11 +77,9 @@ class LRScheduler(_Stateful[str], Protocol): optimizer: Optimizer base_lrs: List[float] - def __init__(self, optimizer: Optimizer, *args: Any, **kwargs: Any) -> None: - ... + def __init__(self, optimizer: Optimizer, *args: Any, **kwargs: Any) -> None: ... - def step(self, epoch: Optional[int] = None) -> None: - ... + def step(self, epoch: Optional[int] = None) -> None: ... _TORCH_LRSCHEDULER: TypeAlias = ( @@ -114,11 +108,9 @@ def __init__( cooldown: int = ..., min_lr: float = ..., eps: float = ..., - ) -> None: - ... + ) -> None: ... - def step(self, metrics: Union[float, int, Tensor], epoch: Optional[int] = None) -> None: - ... 
+ def step(self, metrics: Union[float, int, Tensor], epoch: Optional[int] = None) -> None: ... @runtime_checkable @@ -126,15 +118,12 @@ class Steppable(Protocol): """To structurally type ``optimizer.step()``""" @overload - def step(self, closure: None = ...) -> None: - ... + def step(self, closure: None = ...) -> None: ... @overload - def step(self, closure: Callable[[], float]) -> float: - ... + def step(self, closure: Callable[[], float]) -> float: ... - def step(self, closure: Optional[Callable[[], float]] = ...) -> Optional[float]: - ... + def step(self, closure: Optional[Callable[[], float]] = ...) -> Optional[float]: ... @runtime_checkable @@ -145,8 +134,6 @@ class Optimizable(Steppable, Protocol): defaults: Dict[Any, Any] state: DefaultDict[Tensor, Any] - def state_dict(self) -> Dict[str, Dict[Any, Any]]: - ... + def state_dict(self) -> Dict[str, Dict[Any, Any]]: ... - def load_state_dict(self, state_dict: Dict[str, Dict[Any, Any]]) -> None: - ... + def load_state_dict(self, state_dict: Dict[str, Dict[Any, Any]]) -> None: ... diff --git a/src/lightning/fabric/utilities/warnings.py b/src/lightning/fabric/utilities/warnings.py index 26445690eb581..62e5f5fc2ff17 100644 --- a/src/lightning/fabric/utilities/warnings.py +++ b/src/lightning/fabric/utilities/warnings.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Warning-related utilities.""" + import warnings from pathlib import Path from typing import Optional, Type, Union diff --git a/src/lightning/fabric/wrappers.py b/src/lightning/fabric/wrappers.py index 1aabb718d9abd..e40c73f011aab 100644 --- a/src/lightning/fabric/wrappers.py +++ b/src/lightning/fabric/wrappers.py @@ -142,12 +142,10 @@ def forward(self, *args: Any, **kwargs: Any) -> Any: return output @overload - def state_dict(self, *, destination: T_destination, prefix: str = ..., keep_vars: bool = ...) -> T_destination: - ... + def state_dict(self, *, destination: T_destination, prefix: str = ..., keep_vars: bool = ...) -> T_destination: ... @overload - def state_dict(self, *, prefix: str = ..., keep_vars: bool = ...) -> Dict[str, Any]: - ... + def state_dict(self, *, prefix: str = ..., keep_vars: bool = ...) -> Dict[str, Any]: ... @override def state_dict( @@ -301,7 +299,7 @@ def __iter__(self) -> Union[Iterator[Any], Generator[Any, None, None]]: def _unwrap_objects(collection: Any) -> Any: def _unwrap( - obj: Union[_FabricModule, _FabricOptimizer, _FabricDataLoader] + obj: Union[_FabricModule, _FabricOptimizer, _FabricDataLoader], ) -> Union[nn.Module, Optimizer, DataLoader]: if isinstance(unwrapped := _unwrap_compiled(obj)[0], _FabricModule): return _unwrap_compiled(unwrapped._forward_module)[0] diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index 12f8beb648f35..096e9f57b2da0 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -5,6 +5,16 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
+## [2.2.1] - 2024-03-04 + +### Fixed + +- Fixed an issue with CSVLogger trying to append to file from a previous run when the version is set manually ([#19446](https://github.com/Lightning-AI/lightning/pull/19446)) +- Fixed the divisibility check for `Trainer.accumulate_grad_batches` and `Trainer.log_every_n_steps` in ThroughputMonitor ([#19470](https://github.com/Lightning-AI/lightning/pull/19470)) +- Fixed support for Remote Stop and Remote Abort with NeptuneLogger ([#19130](https://github.com/Lightning-AI/pytorch-lightning/pull/19130)) +- Fixed infinite recursion error in precision plugin graveyard ([#19542](https://github.com/Lightning-AI/pytorch-lightning/pull/19542)) + + ## [2.2.0] - 2024-02-08 ### Added diff --git a/src/lightning/pytorch/_graveyard/precision.py b/src/lightning/pytorch/_graveyard/precision.py index ee2d590e58134..34ff793809028 100644 --- a/src/lightning/pytorch/_graveyard/precision.py +++ b/src/lightning/pytorch/_graveyard/precision.py @@ -50,7 +50,7 @@ def init(self: type, *args: Any, **kwargs: Any) -> None: f"The `{deprecated_name}` is deprecated." f" Use `lightning.pytorch.plugins.precision.{new_class.__name__}` instead." ) - super(type(self), self).__init__(*args, **kwargs) + new_class.__init__(self, *args, **kwargs) # type: ignore[misc] return type(deprecated_name, (new_class,), {"__init__": init}) diff --git a/src/lightning/pytorch/callbacks/device_stats_monitor.py b/src/lightning/pytorch/callbacks/device_stats_monitor.py index b9602d2c80917..64ea47d2897ed 100644 --- a/src/lightning/pytorch/callbacks/device_stats_monitor.py +++ b/src/lightning/pytorch/callbacks/device_stats_monitor.py @@ -18,6 +18,7 @@ Monitors and logs device stats during training. """ + from typing import Any, Dict, Optional from typing_extensions import override diff --git a/src/lightning/pytorch/callbacks/early_stopping.py b/src/lightning/pytorch/callbacks/early_stopping.py index e44b17add3821..d1212fe8cc2e7 100644 --- a/src/lightning/pytorch/callbacks/early_stopping.py +++ b/src/lightning/pytorch/callbacks/early_stopping.py @@ -18,6 +18,7 @@ Monitor a metric and stop training when it stops improving. """ + import logging from typing import Any, Callable, Dict, Optional, Tuple @@ -85,6 +86,7 @@ class EarlyStopping(Callback): Read more: :ref:`Persisting Callback State ` """ + mode_dict = {"min": torch.lt, "max": torch.gt} order_dict = {"min": "<", "max": ">"} diff --git a/src/lightning/pytorch/callbacks/finetuning.py b/src/lightning/pytorch/callbacks/finetuning.py index 23833b60e53fb..46a90986c091c 100644 --- a/src/lightning/pytorch/callbacks/finetuning.py +++ b/src/lightning/pytorch/callbacks/finetuning.py @@ -17,6 +17,7 @@ Freeze and unfreeze models for finetuning purposes. """ + import logging from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Union diff --git a/src/lightning/pytorch/callbacks/lr_finder.py b/src/lightning/pytorch/callbacks/lr_finder.py index c9c6af53b94fe..f667b5c501a10 100644 --- a/src/lightning/pytorch/callbacks/lr_finder.py +++ b/src/lightning/pytorch/callbacks/lr_finder.py @@ -17,6 +17,7 @@ Finds optimal learning rate """ + from typing import Optional from typing_extensions import override diff --git a/src/lightning/pytorch/callbacks/lr_monitor.py b/src/lightning/pytorch/callbacks/lr_monitor.py index 2110856a83c7e..357cfceefa03e 100644 --- a/src/lightning/pytorch/callbacks/lr_monitor.py +++ b/src/lightning/pytorch/callbacks/lr_monitor.py @@ -19,6 +19,7 @@ Monitor and logs learning rate for lr schedulers during training. 
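The `_graveyard/precision.py` change above replaces `super(type(self), self).__init__(...)` with an explicit `new_class.__init__(self, ...)`. A minimal sketch of why the old form recursed, using generic stand-in classes rather than Lightning's actual plugins:

class NewPlugin:
    def __init__(self) -> None:
        self.ready = True

def make_deprecated(name: str, new_class: type) -> type:
    def init(self, *args, **kwargs) -> None:
        # Old code: super(type(self), self).__init__(*args, **kwargs)
        # If a user subclasses the deprecated class, type(self) is that subclass,
        # so the lookup resolves back to this very `init` -> RecursionError.
        new_class.__init__(self, *args, **kwargs)  # explicit binding avoids it
    return type(name, (new_class,), {"__init__": init})

DeprecatedPlugin = make_deprecated("DeprecatedPlugin", NewPlugin)

class UserPlugin(DeprecatedPlugin):
    pass

assert UserPlugin().ready  # recursed infinitely before the fix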
""" + import itertools from collections import defaultdict from typing import Any, DefaultDict, Dict, List, Literal, Optional, Set, Tuple, Type @@ -210,9 +211,9 @@ def _extract_stats(self, trainer: "pl.Trainer", interval: str) -> Dict[str, floa current_stat = self._get_optimizer_stats(opt, names) latest_stat.update(current_stat) - trainer.callback_metrics.update( - {name: torch.tensor(value, device=trainer.strategy.root_device) for name, value in latest_stat.items()} - ) + trainer.callback_metrics.update({ + name: torch.tensor(value, device=trainer.strategy.root_device) for name, value in latest_stat.items() + }) return latest_stat diff --git a/src/lightning/pytorch/callbacks/model_checkpoint.py b/src/lightning/pytorch/callbacks/model_checkpoint.py index e1e823fc921d2..cef1c7322947d 100644 --- a/src/lightning/pytorch/callbacks/model_checkpoint.py +++ b/src/lightning/pytorch/callbacks/model_checkpoint.py @@ -17,6 +17,7 @@ Automatically save model checkpoints during training. """ + import logging import os import re diff --git a/src/lightning/pytorch/callbacks/model_summary.py b/src/lightning/pytorch/callbacks/model_summary.py index ff8a729cc466b..ddff91ed2949e 100644 --- a/src/lightning/pytorch/callbacks/model_summary.py +++ b/src/lightning/pytorch/callbacks/model_summary.py @@ -21,6 +21,7 @@ the name, type and number of parameters for each layer. """ + import logging from typing import Any, Dict, List, Tuple, Union diff --git a/src/lightning/pytorch/callbacks/on_exception_checkpoint.py b/src/lightning/pytorch/callbacks/on_exception_checkpoint.py index d92dd352d355b..53c14cf5afd24 100644 --- a/src/lightning/pytorch/callbacks/on_exception_checkpoint.py +++ b/src/lightning/pytorch/callbacks/on_exception_checkpoint.py @@ -17,6 +17,7 @@ Automatically save a checkpoints on exception. """ + import os from typing import Any diff --git a/src/lightning/pytorch/callbacks/prediction_writer.py b/src/lightning/pytorch/callbacks/prediction_writer.py index c4bb4b1179305..7f782fb81c091 100644 --- a/src/lightning/pytorch/callbacks/prediction_writer.py +++ b/src/lightning/pytorch/callbacks/prediction_writer.py @@ -17,6 +17,7 @@ Aids in saving predictions """ + from typing import Any, Literal, Optional, Sequence from typing_extensions import override diff --git a/src/lightning/pytorch/callbacks/progress/__init__.py b/src/lightning/pytorch/callbacks/progress/__init__.py index e0fab521aed6b..0687f39fd80b9 100644 --- a/src/lightning/pytorch/callbacks/progress/__init__.py +++ b/src/lightning/pytorch/callbacks/progress/__init__.py @@ -18,6 +18,7 @@ Use or override one of the progress bar callbacks. 
""" + from lightning.pytorch.callbacks.progress.progress_bar import ProgressBar # noqa: F401 from lightning.pytorch.callbacks.progress.rich_progress import RichProgressBar # noqa: F401 from lightning.pytorch.callbacks.progress.tqdm_progress import TQDMProgressBar # noqa: F401 diff --git a/src/lightning/pytorch/callbacks/pruning.py b/src/lightning/pytorch/callbacks/pruning.py index 9935889eb4069..e5e214d1d14d1 100644 --- a/src/lightning/pytorch/callbacks/pruning.py +++ b/src/lightning/pytorch/callbacks/pruning.py @@ -15,6 +15,7 @@ ModelPruning ^^^^^^^^^^^^ """ + import inspect import logging from copy import deepcopy diff --git a/src/lightning/pytorch/callbacks/stochastic_weight_avg.py b/src/lightning/pytorch/callbacks/stochastic_weight_avg.py index 2844f2c752646..731f161683102 100644 --- a/src/lightning/pytorch/callbacks/stochastic_weight_avg.py +++ b/src/lightning/pytorch/callbacks/stochastic_weight_avg.py @@ -15,6 +15,7 @@ Stochastic Weight Averaging Callback ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ """ + from copy import deepcopy from typing import Any, Callable, Dict, List, Optional, Union, cast diff --git a/src/lightning/pytorch/callbacks/throughput_monitor.py b/src/lightning/pytorch/callbacks/throughput_monitor.py index 71a85e431bb7d..a2d73d83184b1 100644 --- a/src/lightning/pytorch/callbacks/throughput_monitor.py +++ b/src/lightning/pytorch/callbacks/throughput_monitor.py @@ -93,21 +93,10 @@ def setup(self, trainer: "Trainer", pl_module: "LightningModule", stage: str) -> dtype = _plugin_to_compute_dtype(trainer.precision_plugin) self.available_flops = get_available_flops(trainer.strategy.root_device, dtype) - if stage == TrainerFn.FITTING: - if trainer.accumulate_grad_batches % trainer.log_every_n_steps != 0: - raise ValueError( - "The `ThroughputMonitor` only logs when gradient accumulation is finished. You set" - f" `Trainer(accumulate_grad_batches={trainer.accumulate_grad_batches}," - f" log_every_n_steps={trainer.log_every_n_steps})` but these are not divisible and thus will not" - " log anything." - ) - - if trainer.enable_validation: - # `fit` includes validation inside - throughput = Throughput( - available_flops=self.available_flops, world_size=trainer.world_size, **self.kwargs - ) - self._throughputs[RunningStage.VALIDATING] = throughput + if stage == TrainerFn.FITTING and trainer.enable_validation: + # `fit` includes validation inside + throughput = Throughput(available_flops=self.available_flops, world_size=trainer.world_size, **self.kwargs) + self._throughputs[RunningStage.VALIDATING] = throughput throughput = Throughput(available_flops=self.available_flops, world_size=trainer.world_size, **self.kwargs) stage = trainer.state.stage diff --git a/src/lightning/pytorch/callbacks/timer.py b/src/lightning/pytorch/callbacks/timer.py index 571e64158fd43..7aa84c0ab5d8c 100644 --- a/src/lightning/pytorch/callbacks/timer.py +++ b/src/lightning/pytorch/callbacks/timer.py @@ -15,6 +15,7 @@ Timer ^^^^^ """ + import logging import time from datetime import timedelta diff --git a/src/lightning/pytorch/core/datamodule.py b/src/lightning/pytorch/core/datamodule.py index 9b9969385e105..9c3518b5dbdba 100644 --- a/src/lightning/pytorch/core/datamodule.py +++ b/src/lightning/pytorch/core/datamodule.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""LightningDataModule for loading DataLoaders with ease.""" + import inspect from typing import IO, Any, Dict, Iterable, Optional, Union, cast diff --git a/src/lightning/pytorch/core/module.py b/src/lightning/pytorch/core/module.py index faaab7e15fd05..24fe014245e40 100644 --- a/src/lightning/pytorch/core/module.py +++ b/src/lightning/pytorch/core/module.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """The LightningModule - an nn.Module with many additional features.""" + import logging import numbers import weakref @@ -142,16 +143,15 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: self._fabric_optimizers: List[_FabricOptimizer] = [] @overload - def optimizers(self, use_pl_optimizer: Literal[True] = True) -> Union[LightningOptimizer, List[LightningOptimizer]]: - ... + def optimizers( + self, use_pl_optimizer: Literal[True] = True + ) -> Union[LightningOptimizer, List[LightningOptimizer]]: ... @overload - def optimizers(self, use_pl_optimizer: Literal[False]) -> Union[Optimizer, List[Optimizer]]: - ... + def optimizers(self, use_pl_optimizer: Literal[False]) -> Union[Optimizer, List[Optimizer]]: ... @overload - def optimizers(self, use_pl_optimizer: bool) -> MODULE_OPTIMIZERS: - ... + def optimizers(self, use_pl_optimizer: bool) -> MODULE_OPTIMIZERS: ... def optimizers(self, use_pl_optimizer: bool = True) -> MODULE_OPTIMIZERS: """Returns the optimizer(s) that are being used during training. Useful for manual optimization. diff --git a/src/lightning/pytorch/core/optimizer.py b/src/lightning/pytorch/core/optimizer.py index 73bedb4f151d1..b7a63a8e17cab 100644 --- a/src/lightning/pytorch/core/optimizer.py +++ b/src/lightning/pytorch/core/optimizer.py @@ -196,7 +196,7 @@ def _init_optimizers_and_lr_schedulers( def _configure_optimizers( - optim_conf: Union[Dict[str, Any], List, Optimizer, Tuple] + optim_conf: Union[Dict[str, Any], List, Optimizer, Tuple], ) -> Tuple[List, List, Optional[str]]: optimizers, lr_schedulers = [], [] monitor = None @@ -398,12 +398,10 @@ def state_dict(self) -> Dict[str, Any]: return {} # Return Empty @overload - def step(self, closure: None = ...) -> None: - ... + def step(self, closure: None = ...) -> None: ... @overload - def step(self, closure: Callable[[], float]) -> float: - ... + def step(self, closure: Callable[[], float]) -> float: ... 
@override def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]: diff --git a/src/lightning/pytorch/demos/mnist_datamodule.py b/src/lightning/pytorch/demos/mnist_datamodule.py index a46585f971682..992527ab67296 100644 --- a/src/lightning/pytorch/demos/mnist_datamodule.py +++ b/src/lightning/pytorch/demos/mnist_datamodule.py @@ -248,9 +248,10 @@ def default_transforms(self) -> Optional[Callable]: from torchvision import transforms if self.normalize: - mnist_transforms = transforms.Compose( - [transforms.ToTensor(), transforms.Normalize(mean=(0.5,), std=(0.5,))] - ) + mnist_transforms = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize(mean=(0.5,), std=(0.5,)), + ]) else: mnist_transforms = transforms.ToTensor() diff --git a/src/lightning/pytorch/demos/transformer.py b/src/lightning/pytorch/demos/transformer.py index 4cd96115807ab..35b07f3a0e683 100644 --- a/src/lightning/pytorch/demos/transformer.py +++ b/src/lightning/pytorch/demos/transformer.py @@ -4,6 +4,7 @@ https://github.com/pytorch/examples/blob/main/word_language_model """ + import math import os from pathlib import Path diff --git a/src/lightning/pytorch/loggers/csv_logs.py b/src/lightning/pytorch/loggers/csv_logs.py index 16aeb3da2eaaa..36c397bb8abdb 100644 --- a/src/lightning/pytorch/loggers/csv_logs.py +++ b/src/lightning/pytorch/loggers/csv_logs.py @@ -18,6 +18,7 @@ CSV logger for basic experiment logging that does not require opening ports """ + import logging import os from argparse import Namespace diff --git a/src/lightning/pytorch/loggers/logger.py b/src/lightning/pytorch/loggers/logger.py index c7dc5cf41d04c..c3051d34a7b09 100644 --- a/src/lightning/pytorch/loggers/logger.py +++ b/src/lightning/pytorch/loggers/logger.py @@ -13,7 +13,6 @@ # limitations under the License. """Abstract base class used to build new loggers.""" - import functools import operator from abc import ABC diff --git a/src/lightning/pytorch/loggers/mlflow.py b/src/lightning/pytorch/loggers/mlflow.py index 9303781d1a4c5..a1f354bd8a91b 100644 --- a/src/lightning/pytorch/loggers/mlflow.py +++ b/src/lightning/pytorch/loggers/mlflow.py @@ -15,6 +15,7 @@ MLflow Logger ------------- """ + import logging import os import re diff --git a/src/lightning/pytorch/loggers/neptune.py b/src/lightning/pytorch/loggers/neptune.py index 52e3700afe159..2769d1bcd0ca6 100644 --- a/src/lightning/pytorch/loggers/neptune.py +++ b/src/lightning/pytorch/loggers/neptune.py @@ -15,11 +15,13 @@ Neptune Logger -------------- """ + import contextlib import logging import os from argparse import Namespace -from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Set, Union +from functools import wraps +from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Optional, Set, Union from lightning_utilities.core.imports import RequirementCache from torch import Tensor @@ -44,6 +46,19 @@ _INTEGRATION_VERSION_KEY = "source_code/integrations/pytorch-lightning" +# Neptune client throws `InactiveRunException` when trying to log to an inactive run. +# This may happen when the run was stopped through the UI and the logger is still trying to log to it. +def _catch_inactive(func: Callable) -> Callable: + @wraps(func) + def wrapper(*args: Any, **kwargs: Any) -> Any: + from neptune.exceptions import InactiveRunException + + with contextlib.suppress(InactiveRunException): + return func(*args, **kwargs) + + return wrapper + + class NeptuneLogger(Logger): r"""Log using `Neptune `_. 
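The `_catch_inactive` decorator added above suppresses Neptune's `InactiveRunException` so that a run stopped from the UI does not crash training. A generic sketch of the same pattern that does not require `neptune` (the exception class below is a hypothetical stand-in):

import contextlib
from functools import wraps
from typing import Any, Callable

class InactiveRunError(RuntimeError):
    """Hypothetical stand-in for neptune.exceptions.InactiveRunException."""

def catch_inactive(func: Callable) -> Callable:
    @wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        # Swallow the error: logging to a remotely stopped run becomes a no-op.
        with contextlib.suppress(InactiveRunError):
            return func(*args, **kwargs)
    return wrapper

@catch_inactive
def log_metric(name: str, value: float) -> None:
    raise InactiveRunError("run was stopped remotely")

log_metric("loss", 0.1)  # returns None instead of raising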
@@ -240,10 +255,7 @@ def __init__( if self._run_instance is not None: self._retrieve_run_data() - if _NEPTUNE_AVAILABLE: - from neptune.handler import Handler - else: - from neptune.new.handler import Handler + from neptune.handler import Handler # make sure that we've log integration version for outside `Run` instances root_obj = self._run_instance @@ -390,7 +402,8 @@ def run(self) -> "Run": @override @rank_zero_only - def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: # type: ignore[override] + @_catch_inactive + def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: r"""Log hyperparameters to the run. Hyperparameters will be logged under the "/hyperparams" namespace. @@ -440,9 +453,8 @@ def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: # @override @rank_zero_only - def log_metrics( # type: ignore[override] - self, metrics: Dict[str, Union[Tensor, float]], step: Optional[int] = None - ) -> None: + @_catch_inactive + def log_metrics(self, metrics: Dict[str, Union[Tensor, float]], step: Optional[int] = None) -> None: """Log metrics (numeric values) in Neptune runs. Args: @@ -460,6 +472,7 @@ def log_metrics( # type: ignore[override] @override @rank_zero_only + @_catch_inactive def finalize(self, status: str) -> None: if not self._run_instance: # When using multiprocessing, finalize() should be a no-op on the main process, as no experiment has been @@ -483,6 +496,7 @@ def save_dir(self) -> Optional[str]: return os.path.join(os.getcwd(), ".neptune") @rank_zero_only + @_catch_inactive def log_model_summary(self, model: "pl.LightningModule", max_depth: int = -1) -> None: if _NEPTUNE_AVAILABLE: from neptune.types import File @@ -496,6 +510,7 @@ def log_model_summary(self, model: "pl.LightningModule", max_depth: int = -1) -> @override @rank_zero_only + @_catch_inactive def after_save_checkpoint(self, checkpoint_callback: Checkpoint) -> None: """Automatically log checkpointed model. Called after model checkpoint callback saves a new checkpoint. diff --git a/src/lightning/pytorch/loggers/tensorboard.py b/src/lightning/pytorch/loggers/tensorboard.py index 9755074132877..bfa193b7f69ac 100644 --- a/src/lightning/pytorch/loggers/tensorboard.py +++ b/src/lightning/pytorch/loggers/tensorboard.py @@ -82,6 +82,7 @@ class TensorBoardLogger(Logger, FabricTensorBoardLogger): of the queue for pending logs before flushing. `flush_secs` determines how many seconds elapses before flushing. 
""" + NAME_HPARAMS_FILE = "hparams.yaml" def __init__( diff --git a/src/lightning/pytorch/loggers/wandb.py b/src/lightning/pytorch/loggers/wandb.py index b8862a464fcbb..f9ab289c8a087 100644 --- a/src/lightning/pytorch/loggers/wandb.py +++ b/src/lightning/pytorch/loggers/wandb.py @@ -15,6 +15,7 @@ Weights and Biases Logger ------------------------- """ + import os from argparse import Namespace from pathlib import Path diff --git a/src/lightning/pytorch/loops/evaluation_loop.py b/src/lightning/pytorch/loops/evaluation_loop.py index 1305329521883..0ab3901cf072d 100644 --- a/src/lightning/pytorch/loops/evaluation_loop.py +++ b/src/lightning/pytorch/loops/evaluation_loop.py @@ -478,7 +478,7 @@ def _get_keys(data: dict) -> Iterable[Tuple[str, ...]]: for new_key in apply_to_collection(v, dict, _EvaluationLoop._get_keys): yield (k, *new_key) # this need to be in parenthesis for older python versions else: - yield k, + yield (k,) @staticmethod def _find_value(data: dict, target: Iterable[str]) -> Optional[Any]: diff --git a/src/lightning/pytorch/loops/optimization/automatic.py b/src/lightning/pytorch/loops/optimization/automatic.py index 450e8a0c60b2a..82666a0e21e5f 100644 --- a/src/lightning/pytorch/loops/optimization/automatic.py +++ b/src/lightning/pytorch/loops/optimization/automatic.py @@ -171,8 +171,7 @@ def run(self, optimizer: Optimizer, batch_idx: int, kwargs: OrderedDict) -> _OUT if ( # when the strategy handles accumulation, we want to always call the optimizer step - not self.trainer.strategy.handles_gradient_accumulation - and self.trainer.fit_loop._should_accumulate() + not self.trainer.strategy.handles_gradient_accumulation and self.trainer.fit_loop._should_accumulate() ): # For gradient accumulation diff --git a/src/lightning/pytorch/loops/training_epoch_loop.py b/src/lightning/pytorch/loops/training_epoch_loop.py index ad98c653007c6..9e36ee65176c8 100644 --- a/src/lightning/pytorch/loops/training_epoch_loop.py +++ b/src/lightning/pytorch/loops/training_epoch_loop.py @@ -18,6 +18,8 @@ from typing_extensions import override import lightning.pytorch as pl +from lightning.fabric.utilities.types import _Stateful +from lightning.fabric.utilities.warnings import PossibleUserWarning from lightning.pytorch import loops # import as loops to avoid circular imports from lightning.pytorch.loops.fetchers import _DataFetcher, _DataLoaderIterDataFetcher from lightning.pytorch.loops.optimization import _AutomaticOptimization, _ManualOptimization @@ -152,10 +154,16 @@ def reset(self) -> None: trainer = self.trainer if trainer.num_training_batches != float("inf"): expected_steps = math.ceil(trainer.num_training_batches / trainer.accumulate_grad_batches) - if self.global_step % expected_steps != 0: + loader = trainer.fit_loop._combined_loader + assert loader is not None + is_resumable_loader = all(isinstance(loader, _Stateful) for loader in loader.flattened) + if self.global_step % expected_steps != 0 and not is_resumable_loader: rank_zero_warn( - "You're resuming from a checkpoint that ended before the epoch ended. This can cause unreliable" - " results if further training is done. Consider using an end-of-epoch checkpoint" + "You're resuming from a checkpoint that ended before the epoch ended and your dataloader is" + " not resumable. This can cause unreliable results if further training is done." 
+ " Consider using an end-of-epoch checkpoint or make your dataloader resumable by implementing" + " the `state_dict` / `load_state_dict` interface.", + category=PossibleUserWarning, ) else: self.batch_progress.reset_on_run() diff --git a/src/lightning/pytorch/overrides/distributed.py b/src/lightning/pytorch/overrides/distributed.py index fe47a5e8f31d3..e4b65285538f1 100644 --- a/src/lightning/pytorch/overrides/distributed.py +++ b/src/lightning/pytorch/overrides/distributed.py @@ -26,7 +26,7 @@ def _find_tensors( - obj: Union[Tensor, list, tuple, dict, Any] + obj: Union[Tensor, list, tuple, dict, Any], ) -> Union[List[Tensor], itertools.chain]: # pragma: no-cover """Recursively find all tensors contained in the specified object.""" if isinstance(obj, Tensor): diff --git a/src/lightning/pytorch/profilers/advanced.py b/src/lightning/pytorch/profilers/advanced.py index f39cbe153c837..9e166b5e34d9f 100644 --- a/src/lightning/pytorch/profilers/advanced.py +++ b/src/lightning/pytorch/profilers/advanced.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Profiler to check if there are any bottlenecks in your code.""" + import cProfile import io import logging diff --git a/src/lightning/pytorch/profilers/profiler.py b/src/lightning/pytorch/profilers/profiler.py index d7f168b3d2aa1..fb448321575a9 100644 --- a/src/lightning/pytorch/profilers/profiler.py +++ b/src/lightning/pytorch/profilers/profiler.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Profiler to check if there are any bottlenecks in your code.""" + import logging import os from abc import ABC, abstractmethod diff --git a/src/lightning/pytorch/profilers/pytorch.py b/src/lightning/pytorch/profilers/pytorch.py index b685fb403497a..96b8d319bb884 100644 --- a/src/lightning/pytorch/profilers/pytorch.py +++ b/src/lightning/pytorch/profilers/pytorch.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Profiler to check if there are any bottlenecks in your code.""" + import inspect import logging import os diff --git a/src/lightning/pytorch/profilers/simple.py b/src/lightning/pytorch/profilers/simple.py index 9b7274ba71029..eef7b12892faa 100644 --- a/src/lightning/pytorch/profilers/simple.py +++ b/src/lightning/pytorch/profilers/simple.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Profiler to check if there are any bottlenecks in your code.""" + import logging import os import time diff --git a/src/lightning/pytorch/strategies/deepspeed.py b/src/lightning/pytorch/strategies/deepspeed.py index c12fe5880e344..2cc099be39d26 100644 --- a/src/lightning/pytorch/strategies/deepspeed.py +++ b/src/lightning/pytorch/strategies/deepspeed.py @@ -556,12 +556,10 @@ def _initialize_deepspeed_inference(self, model: Module) -> None: if "bf16" in self.config: inference_config.update({"bf16": self.config["bf16"]}) if self.zero_stage_3: - inference_config.update( - { - "zero_allow_untested_optimizer": self.config["zero_allow_untested_optimizer"], - "zero_optimization": self.config["zero_optimization"], - } - ) + inference_config.update({ + "zero_allow_untested_optimizer": self.config["zero_allow_untested_optimizer"], + "zero_optimization": self.config["zero_optimization"], + }) # Remove all module hooks before initializing new model remove_module_hooks(model) model, _, _, _ = deepspeed.initialize( diff --git a/src/lightning/pytorch/strategies/fsdp.py b/src/lightning/pytorch/strategies/fsdp.py index 59c23b764a9b3..a3a34dba75cee 100644 --- a/src/lightning/pytorch/strategies/fsdp.py +++ b/src/lightning/pytorch/strategies/fsdp.py @@ -33,6 +33,8 @@ _METADATA_FILENAME, _activation_checkpointing_kwargs, _auto_wrap_policy_kwargs, + _distributed_checkpoint_load, + _distributed_checkpoint_save, _get_full_state_dict_context, _get_sharded_state_dict_context, _has_meta_device_parameters, @@ -55,7 +57,6 @@ from lightning.fabric.utilities.imports import ( _TORCH_GREATER_EQUAL_2_0, _TORCH_GREATER_EQUAL_2_1, - _TORCH_GREATER_EQUAL_2_2, ) from lightning.fabric.utilities.init import _EmptyInit from lightning.fabric.utilities.load import _lazy_load, _materialize_tensors @@ -561,20 +562,16 @@ def save_checkpoint( raise IsADirectoryError(f"The checkpoint path exists and is a directory: {path}") if self._state_dict_type == "sharded": - from torch.distributed.checkpoint import FileSystemWriter, save_state_dict - if path.is_file(): path.unlink() path.mkdir(parents=True, exist_ok=True) converted_state = {"model": checkpoint.pop("state_dict")} - converted_state.update( - {f"optimizer_{idx}": optim_state for idx, optim_state in enumerate(checkpoint.pop("optimizer_states"))} - ) + converted_state.update({ + f"optimizer_{idx}": optim_state for idx, optim_state in enumerate(checkpoint.pop("optimizer_states")) + }) - # FSDP's FileSystemWriter streams the tensors to disk to minimize memory peaks - writer = FileSystemWriter(path=path, single_file_per_rank=True) - save_state_dict(converted_state, writer) + _distributed_checkpoint_save(converted_state, path) if self.global_rank == 0: torch.save(checkpoint, path / _METADATA_FILENAME) @@ -596,23 +593,21 @@ def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]: assert self.lightning_module is not None if _is_sharded_checkpoint(path): - from torch.distributed.checkpoint import FileSystemReader from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict - if _TORCH_GREATER_EQUAL_2_2: - from torch.distributed.checkpoint import load - else: - from torch.distributed.checkpoint import load_state_dict as load # deprecated - state_dict_ctx = _get_sharded_state_dict_context(self.model) - reader = FileSystemReader(path=path) with state_dict_ctx: module_state = {"model": self.model.state_dict()} - load(module_state, reader) + _distributed_checkpoint_load(module_state, path) self.model.load_state_dict(module_state["model"], 
strict=self.lightning_module.strict_loading) - if self.lightning_module.trainer.state.fn == TrainerFn.FITTING: + if self.lightning_module.trainer.state.fn == TrainerFn.FITTING and self.optimizers: + from torch.distributed.checkpoint import FileSystemReader + + # TODO: replace with newer APIs + # https://github.com/pytorch/pytorch/issues/119800#issuecomment-1942156271 + reader = FileSystemReader(path=path) # the optimizer states must be loaded separately for idx, optim in enumerate(self.optimizers): optim_key = f"optimizer_{idx}" diff --git a/src/lightning/pytorch/trainer/trainer.py b/src/lightning/pytorch/trainer/trainer.py index d0988d9d38005..4b9be731295f6 100644 --- a/src/lightning/pytorch/trainer/trainer.py +++ b/src/lightning/pytorch/trainer/trainer.py @@ -19,6 +19,7 @@ # DO NOT REMOVE THIS NOTICE # - WILLIAM FALCON """Trainer to automate the training.""" + import logging import math import os diff --git a/src/lightning/pytorch/utilities/data.py b/src/lightning/pytorch/utilities/data.py index 0676a17e63d4d..fb1c0a4370d6d 100644 --- a/src/lightning/pytorch/utilities/data.py +++ b/src/lightning/pytorch/utilities/data.py @@ -167,9 +167,9 @@ def _get_dataloader_init_args_and_kwargs( if was_wrapped: # if the dataloader was wrapped in a hook, only take arguments with default values # and assume user passes their kwargs correctly - params.update( - {k: v for k, v in inspect.signature(DataLoader.__init__).parameters.items() if v.default is not v.empty} - ) + params.update({ + k: v for k, v in inspect.signature(DataLoader.__init__).parameters.items() if v.default is not v.empty + }) else: params.update(inspect.signature(DataLoader.__init__).parameters) params.pop("self", None) diff --git a/src/lightning/pytorch/utilities/enums.py b/src/lightning/pytorch/utilities/enums.py index 24285ea3b0c0d..89871384b9821 100644 --- a/src/lightning/pytorch/utilities/enums.py +++ b/src/lightning/pytorch/utilities/enums.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Enumerated utilities.""" + from __future__ import annotations from lightning_utilities.core.enums import StrEnum as LightningEnum diff --git a/src/lightning/pytorch/utilities/grads.py b/src/lightning/pytorch/utilities/grads.py index 21f737f8b69d2..f200d892474db 100644 --- a/src/lightning/pytorch/utilities/grads.py +++ b/src/lightning/pytorch/utilities/grads.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Utilities to describe gradients.""" + from typing import Dict, Union import torch diff --git a/src/lightning/pytorch/utilities/imports.py b/src/lightning/pytorch/utilities/imports.py index 23911f1ca1092..a171d08ba04b6 100644 --- a/src/lightning/pytorch/utilities/imports.py +++ b/src/lightning/pytorch/utilities/imports.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""General utilities.""" + import functools import sys diff --git a/src/lightning/pytorch/utilities/migration/migration.py b/src/lightning/pytorch/utilities/migration/migration.py index 69cecff0864bf..6a5a914bed9ba 100644 --- a/src/lightning/pytorch/utilities/migration/migration.py +++ b/src/lightning/pytorch/utilities/migration/migration.py @@ -29,6 +29,7 @@ python -m lightning.pytorch.utilities.upgrade_checkpoint model.ckpt """ + import re from typing import Any, Callable, Dict, List diff --git a/src/lightning/pytorch/utilities/parameter_tying.py b/src/lightning/pytorch/utilities/parameter_tying.py index 5f5ea505dc05a..8680285c272e8 100644 --- a/src/lightning/pytorch/utilities/parameter_tying.py +++ b/src/lightning/pytorch/utilities/parameter_tying.py @@ -17,6 +17,7 @@ https://github.com/pytorch/fairseq/blob/1f7ef9ed1e1061f8c7f88f8b94c7186834398690/fairseq/trainer.py#L110-L118 """ + from typing import Dict, List, Optional from torch import nn diff --git a/src/lightning/pytorch/utilities/rank_zero.py b/src/lightning/pytorch/utilities/rank_zero.py index 407a6ceb23438..7927235e10cba 100644 --- a/src/lightning/pytorch/utilities/rank_zero.py +++ b/src/lightning/pytorch/utilities/rank_zero.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Utilities that can be used for calling functions on a particular rank.""" + import logging # note: we want to keep these indirections so the `rank_zero_module.log` is set (on import) for PL users diff --git a/src/lightning/pytorch/utilities/seed.py b/src/lightning/pytorch/utilities/seed.py index 10badab69c32f..4ba9e7f0f960f 100644 --- a/src/lightning/pytorch/utilities/seed.py +++ b/src/lightning/pytorch/utilities/seed.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Utilities to help with reproducibility of models.""" + from contextlib import contextmanager from typing import Generator diff --git a/src/lightning/pytorch/utilities/types.py b/src/lightning/pytorch/utilities/types.py index a087b2a862d8b..203df53f22ba7 100644 --- a/src/lightning/pytorch/utilities/types.py +++ b/src/lightning/pytorch/utilities/types.py @@ -16,6 +16,7 @@ - Do not include any `_TYPE` suffix - Types used in public hooks (as those in the `LightningModule` and `Callback`) should be public (no leading `_`) """ + from contextlib import contextmanager from dataclasses import dataclass from typing import ( @@ -68,12 +69,10 @@ def __init__( check_reduction: bool = False, gradient_as_bucket_view: bool = False, static_graph: bool = False, - ) -> None: - ... + ) -> None: ... @contextmanager - def no_sync(self) -> Generator: - ... + def no_sync(self) -> Generator: ... # todo: improve LRSchedulerType naming/typing @@ -121,7 +120,7 @@ class OptimizerLRSchedulerConfig(TypedDict): Sequence[Optimizer], Tuple[Sequence[Optimizer], Sequence[Union[LRSchedulerTypeUnion, LRSchedulerConfig]]], OptimizerLRSchedulerConfig, - Sequence[OptimizerLRSchedulerConfig] + Sequence[OptimizerLRSchedulerConfig], ] ] diff --git a/src/lightning/pytorch/utilities/warnings.py b/src/lightning/pytorch/utilities/warnings.py index ee54bb3397d58..bdf3835cdf937 100644 --- a/src/lightning/pytorch/utilities/warnings.py +++ b/src/lightning/pytorch/utilities/warnings.py @@ -12,5 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Warning-related utilities.""" + # backwards compatibility from lightning.fabric.utilities.warnings import PossibleUserWarning # noqa: F401 diff --git a/src/version.info b/src/version.info index a8d39b5507816..c043eea7767ee 100644 --- a/src/version.info +++ b/src/version.info @@ -1 +1 @@ -2.2.0.post0 +2.2.1 diff --git a/tests/legacy/back-compatible-versions.txt b/tests/legacy/back-compatible-versions.txt index 34f518a821146..d150189111304 100644 --- a/tests/legacy/back-compatible-versions.txt +++ b/tests/legacy/back-compatible-versions.txt @@ -96,3 +96,4 @@ 2.1.1 2.1.2 2.1.3 +2.2.0.post0 diff --git a/tests/parity_pytorch/models.py b/tests/parity_pytorch/models.py index 5220f12b841da..f55b0d6f1f36e 100644 --- a/tests/parity_pytorch/models.py +++ b/tests/parity_pytorch/models.py @@ -37,9 +37,10 @@ def __init__(self, backbone="resnet101", hidden_dim=1024, learning_rate=1e-3, we self.classifier = torch.nn.Sequential( torch.nn.Linear(1000, hidden_dim), torch.nn.Linear(hidden_dim, self.num_classes) ) - self.transform = transforms.Compose( - [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] - ) + self.transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), + ]) self._loss = [] # needed for checking if the loss is the same as vanilla torch def training_step(self, batch, batch_idx): diff --git a/tests/tests_app/core/test_lightning_work.py b/tests/tests_app/core/test_lightning_work.py index 7a7f2717f6243..443851d97990f 100644 --- a/tests/tests_app/core/test_lightning_work.py +++ b/tests/tests_app/core/test_lightning_work.py @@ -176,14 +176,12 @@ def run(self): "run_started_counter": 1, "statuses": [], } - caller_queue.put( - { - "args": (), - "kwargs": {}, - "call_hash": call_hash, - "state": work.state, - } - ) + caller_queue.put({ + "args": (), + "kwargs": {}, + "call_hash": call_hash, + "state": work.state, + }) work_runner = WorkRunner( work, work.name, diff --git a/tests/tests_app/frontend/conftest.py b/tests/tests_app/frontend/conftest.py index 6aa31d61874be..9b1d9f174396d 100644 --- a/tests/tests_app/frontend/conftest.py +++ b/tests/tests_app/frontend/conftest.py @@ -1,4 +1,5 @@ """Test configuration.""" + # pylint: disable=protected-access from unittest import mock diff --git a/tests/tests_app/frontend/panel/test_app_state_comm.py b/tests/tests_app/frontend/panel/test_app_state_comm.py index 6baf21172b6e8..a2501db24fde1 100644 --- a/tests/tests_app/frontend/panel/test_app_state_comm.py +++ b/tests/tests_app/frontend/panel/test_app_state_comm.py @@ -1,4 +1,5 @@ """The watch_app_state function enables us to trigger a callback function whenever the App state changes.""" + import os from unittest import mock diff --git a/tests/tests_app/frontend/panel/test_app_state_watcher.py b/tests/tests_app/frontend/panel/test_app_state_watcher.py index 8b9cacad5aaea..f12b989654141 100644 --- a/tests/tests_app/frontend/panel/test_app_state_watcher.py +++ b/tests/tests_app/frontend/panel/test_app_state_watcher.py @@ -6,6 +6,7 @@ This is particularly useful for the PanelFrontend, but can be used by other Frontends too. 
""" + # pylint: disable=protected-access import os from unittest import mock diff --git a/tests/tests_app/frontend/panel/test_panel_frontend.py b/tests/tests_app/frontend/panel/test_panel_frontend.py index 31c2311089bc6..4b6aa41a12104 100644 --- a/tests/tests_app/frontend/panel/test_panel_frontend.py +++ b/tests/tests_app/frontend/panel/test_panel_frontend.py @@ -1,4 +1,5 @@ """The PanelFrontend wraps your Panel code in your LightningFlow.""" + # pylint: disable=protected-access, too-few-public-methods import os import runpy diff --git a/tests/tests_app/frontend/panel/test_panel_serve_render_fn.py b/tests/tests_app/frontend/panel/test_panel_serve_render_fn.py index 66bddaf830f45..6605373b2f3b6 100644 --- a/tests/tests_app/frontend/panel/test_panel_serve_render_fn.py +++ b/tests/tests_app/frontend/panel/test_panel_serve_render_fn.py @@ -3,6 +3,7 @@ These tests are for serving a render_fn function. """ + import inspect import os from unittest import mock diff --git a/tests/tests_app/runners/test_cloud.py b/tests/tests_app/runners/test_cloud.py index 821049d308180..74b74c99a8049 100644 --- a/tests/tests_app/runners/test_cloud.py +++ b/tests/tests_app/runners/test_cloud.py @@ -143,9 +143,9 @@ def test_run_on_deleted_cluster(self, cloud_backend): memberships=[V1Membership(name="Default Project", project_id=project_id)] ) - mock_client.cluster_service_list_clusters.return_value = V1ListClustersResponse( - [Externalv1Cluster(id=DEFAULT_CLUSTER)] - ) + mock_client.cluster_service_list_clusters.return_value = V1ListClustersResponse([ + Externalv1Cluster(id=DEFAULT_CLUSTER) + ]) cloud_backend.client = mock_client app = mock.MagicMock() @@ -192,12 +192,10 @@ def test_new_instance_on_different_cluster(self, tmpdir, cloud_backend, project_ # Note: # backend converts "None" cluster to "litng-ai-03" # dispatch should receive None, but API calls should return "litng-ai-03" - mock_client.cluster_service_list_clusters.return_value = V1ListClustersResponse( - [ - Externalv1Cluster(id=old_cluster or DEFAULT_CLUSTER), - Externalv1Cluster(id=new_cluster or DEFAULT_CLUSTER), - ] - ) + mock_client.cluster_service_list_clusters.return_value = V1ListClustersResponse([ + Externalv1Cluster(id=old_cluster or DEFAULT_CLUSTER), + Externalv1Cluster(id=new_cluster or DEFAULT_CLUSTER), + ]) mock_client.projects_service_list_project_cluster_bindings.return_value = V1ListProjectClusterBindingsResponse( clusters=[ @@ -267,9 +265,9 @@ def test_running_deleted_app(self, tmpdir, cloud_backend, project_id): cluster_id=DEFAULT_CLUSTER ) - mock_client.cluster_service_list_clusters.return_value = V1ListClustersResponse( - [Externalv1Cluster(id=DEFAULT_CLUSTER)] - ) + mock_client.cluster_service_list_clusters.return_value = V1ListClustersResponse([ + Externalv1Cluster(id=DEFAULT_CLUSTER) + ]) mock_client.projects_service_list_project_cluster_bindings.return_value = V1ListProjectClusterBindingsResponse( clusters=[V1ProjectClusterBinding(cluster_id=DEFAULT_CLUSTER)] @@ -365,9 +363,9 @@ def test_run_on_byoc_cluster(self, tmpdir, monkeypatch): V1ListLightningappInstancesResponse(lightningapps=[]) ) mock_client.cloud_space_service_create_lightning_run.return_value = V1LightningRun(cluster_id="test1234") - mock_client.cluster_service_list_clusters.return_value = V1ListClustersResponse( - [Externalv1Cluster(id="test1234")] - ) + mock_client.cluster_service_list_clusters.return_value = V1ListClustersResponse([ + Externalv1Cluster(id="test1234") + ]) cloud_backend = mock.MagicMock() cloud_backend.client = mock_client 
monkeypatch.setattr(backends, "CloudBackend", mock.MagicMock(return_value=cloud_backend)) @@ -1964,9 +1962,7 @@ def run(self): "lightning.app.runners.cloud._parse_lightningignore", wraps=_parse_lightningignore ) as parse_mock, mock.patch( "lightning.app.source_code.local._copytree", wraps=_copytree - ) as copy_mock, caplog.at_level( - logging.WARN - ): + ) as copy_mock, caplog.at_level(logging.WARN): cloud_runtime.dispatch() parse_mock.assert_called_once_with(("foo", "foo", "lightning_logs")) @@ -2015,9 +2011,7 @@ def run(self): "lightning.app.runners.cloud._parse_lightningignore", wraps=_parse_lightningignore ) as parse_mock, mock.patch( "lightning.app.source_code.local._copytree", wraps=_copytree - ) as copy_mock, caplog.at_level( - logging.WARN - ): + ) as copy_mock, caplog.at_level(logging.WARN): cloud_runtime.dispatch() parse_mock.assert_called_once_with(()) diff --git a/tests/tests_app/structures/test_structures.py b/tests/tests_app/structures/test_structures.py index a31215eab7cc5..2d0c5e58005f0 100644 --- a/tests/tests_app/structures/test_structures.py +++ b/tests/tests_app/structures/test_structures.py @@ -361,15 +361,13 @@ def __init__(self, use_list, run_once_iterable, cache_calls): DummyFlow(), ) else: - self.iter = Dict( - **{ - "0": CounterWork(cache_calls), - "1": CounterWork(cache_calls), - "2": CounterWork(cache_calls), - "3": CounterWork(cache_calls), - "4": DummyFlow(), - } - ) + self.iter = Dict(**{ + "0": CounterWork(cache_calls), + "1": CounterWork(cache_calls), + "2": CounterWork(cache_calls), + "3": CounterWork(cache_calls), + "4": DummyFlow(), + }) def run(self): for work_idx, work in self.experimental_iterate(enumerate(self.iter), run_once=self.run_once_iterable): diff --git a/tests/tests_app/test_imports.py b/tests/tests_app/test_imports.py index d33b556d9ee0f..8c17a611552ec 100644 --- a/tests/tests_app/test_imports.py +++ b/tests/tests_app/test_imports.py @@ -8,13 +8,11 @@ def _is_attribute(member, module): - return all( - [ - hasattr(member, "__module__") and module.__name__ in member.__module__, - not isinstance(member, TypeVar), - not isinstance(member, types.ModuleType), - ] - ) + return all([ + hasattr(member, "__module__") and module.__name__ in member.__module__, + not isinstance(member, TypeVar), + not isinstance(member, types.ModuleType), + ]) def _find_exports(module): @@ -47,7 +45,7 @@ def test_import_depth( "lightning.app.launcher", "lightning.app.runners", "lightning.app.utilities", - ] + ], ): """This test ensures that any public exports (functions, classes, etc.) can be imported by users with at most a depth of two. This guarantees that everything user-facing can be imported with (at most) ``lightning.app.*.*``. 
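For reference, the depth-two guarantee checked by `test_import_depth` above means user code never needs more than two trailing segments; an illustrative pair of imports (assuming the public `lightning.app` API):

```python
from lightning.app import LightningFlow, LightningWork  # depth one
from lightning.app.structures import Dict, List  # depth two
```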
diff --git a/tests/tests_app/utilities/packaging/test_docker.py b/tests/tests_app/utilities/packaging/test_docker.py index 697104ca7c040..95a9edb1882c5 100644 --- a/tests/tests_app/utilities/packaging/test_docker.py +++ b/tests/tests_app/utilities/packaging/test_docker.py @@ -40,14 +40,12 @@ def test_docker_runner(): docker_runner.run() caller_queue = queues.get_caller_queue(work_name=work.name, queue_id=queue_id) - caller_queue.put( - { - "args": (), - "kwargs": {}, - "call_hash": call_hash, - "state": work.state, - } - ) + caller_queue.put({ + "args": (), + "kwargs": {}, + "call_hash": call_hash, + "state": work.state, + }) delta_queue = queues.get_delta_queue(queue_id=queue_id) delta_1 = delta_queue.get() delta_2 = delta_queue.get() diff --git a/tests/tests_app/utilities/test_cli_helpers.py b/tests/tests_app/utilities/test_cli_helpers.py index 54ba37ac07510..8d6684080d7a5 100644 --- a/tests/tests_app/utilities/test_cli_helpers.py +++ b/tests/tests_app/utilities/test_cli_helpers.py @@ -24,12 +24,10 @@ def test_format_input_env_variables(): _format_input_env_variables(("=invalid=",)) with pytest.raises(Exception, match="is duplicated. Please only include it once."): - _format_input_env_variables( - ( - "FOO=bar", - "FOO=bar", - ) - ) + _format_input_env_variables(( + "FOO=bar", + "FOO=bar", + )) with pytest.raises( Exception, @@ -189,12 +187,10 @@ def test_check_environment_and_redirect(mock_redirect_command, tmpdir, monkeypat ) with open(descriptor, "w") as f: - f.writelines( - [ - "#!/bin/bash\n", - f'{sys.executable} "$@"', - ] - ) + f.writelines([ + "#!/bin/bash\n", + f'{sys.executable} "$@"', + ]) monkeypatch.setenv("PATH", f"{tmpdir}") assert _check_environment_and_redirect() is None diff --git a/tests/tests_app/utilities/test_proxies.py b/tests/tests_app/utilities/test_proxies.py index edf23d1837d94..3c5d830e30e02 100644 --- a/tests/tests_app/utilities/test_proxies.py +++ b/tests/tests_app/utilities/test_proxies.py @@ -128,14 +128,12 @@ def get(self, timeout: int = 0): "run_started_counter": 1, "statuses": [], } - caller_queue.put( - { - "args": (), - "kwargs": {}, - "call_hash": call_hash, - "state": work.state, - } - ) + caller_queue.put({ + "args": (), + "kwargs": {}, + "call_hash": call_hash, + "state": work.state, + }) work_runner = WorkRunner( work, work.name, @@ -254,15 +252,13 @@ def __call__(self): called = self.caller_queue.get() self.work.set_state(called["state"]) state = deepcopy(self.work.state) - self.work._calls[call_hash]["statuses"].append( - { - "name": self.work.name, - "stage": WorkStageStatus.FAILED, - "reason": WorkFailureReasons.TIMEOUT, - "timestamp": time.time(), - "message": None, - } - ) + self.work._calls[call_hash]["statuses"].append({ + "name": self.work.name, + "stage": WorkStageStatus.FAILED, + "reason": WorkFailureReasons.TIMEOUT, + "timestamp": time.time(), + "message": None, + }) self.delta_queue.put( ComponentDelta(id=self.work_name, delta=Delta(DeepDiff(state, self.work.state, verbose_level=2))) ) @@ -682,14 +678,12 @@ def run(self): "run_started_counter": 1, "statuses": [], } - work_runner.caller_queue.put( - { - "args": (), - "kwargs": {}, - "call_hash": call_hash, - "state": work.state, - } - ) + work_runner.caller_queue.put({ + "args": (), + "kwargs": {}, + "call_hash": call_hash, + "state": work.state, + }) with mock.patch.dict(os.environ, environment, clear=True): work_runner.setup() diff --git a/tests/tests_app/utilities/test_safe_pickle.py b/tests/tests_app/utilities/test_safe_pickle.py index 989804af2c79d..2c9e6d49a2448 100644 --- 
a/tests/tests_app/utilities/test_safe_pickle.py +++ b/tests/tests_app/utilities/test_safe_pickle.py @@ -5,8 +5,9 @@ def test_safe_pickle_app(): test_dir = Path(__file__).parent / "testdata" proc = subprocess.Popen( - ["lightning_app", "run", "app", "safe_pickle_app.py", "--open-ui", "false"], - stdout=subprocess.PIPE, cwd=test_dir, + ["lightning_app", "run", "app", "safe_pickle_app.py", "--open-ui", "false"], + stdout=subprocess.PIPE, + cwd=test_dir, ) stdout, _ = proc.communicate() assert "Exiting the pickling app successfully" in stdout.decode("UTF-8") diff --git a/tests/tests_data/processing/test_data_processor.py b/tests/tests_data/processing/test_data_processor.py index 797d4911d1282..e68c4e7abc1a1 100644 --- a/tests/tests_data/processing/test_data_processor.py +++ b/tests/tests_data/processing/test_data_processor.py @@ -999,6 +999,7 @@ def test_map_error_when_not_empty(monkeypatch): error_when_not_empty=False, ) + def map_fn_is_last(index, output_dir, is_last): with open(os.path.join(output_dir, f"{index}_{is_last}.txt"), "w") as f: f.write("here") @@ -1008,8 +1009,8 @@ def map_fn_is_last(index, output_dir, is_last): @pytest.mark.parametrize( ("num_workers", "expected"), [ - (1, ['0_False.txt', '1_False.txt', '2_False.txt', '3_False.txt', '4_True.txt']), - (2, ['0_False.txt', '1_True.txt', '2_False.txt', '3_False.txt', '4_True.txt']), + (1, ["0_False.txt", "1_False.txt", "2_False.txt", "3_False.txt", "4_True.txt"]), + (2, ["0_False.txt", "1_True.txt", "2_False.txt", "3_False.txt", "4_True.txt"]), ], ) def test_map_is_last(num_workers, expected, tmpdir): diff --git a/tests/tests_data/processing/test_dns.py b/tests/tests_data/processing/test_dns.py index 0703e346b02fd..2bfadcdc8e2ae 100644 --- a/tests/tests_data/processing/test_dns.py +++ b/tests/tests_data/processing/test_dns.py @@ -11,7 +11,6 @@ def test_optimize_dns_context(monkeypatch): monkeypatch.setattr(dns_module, "Popen", popen_mock) class FakeFile: - def __init__(self, *args, **kwargs): pass @@ -30,4 +29,7 @@ def readlines(self): pass cmd = popen_mock._mock_call_args_list[0].args[0] - assert cmd == "sudo /home/zeus/miniconda3/envs/cloudspace/bin/python -c 'from lightning.data.processing.dns import _optimize_dns; _optimize_dns(True)'" # noqa: E501 + assert ( + cmd + == "sudo /home/zeus/miniconda3/envs/cloudspace/bin/python -c 'from lightning.data.processing.dns import _optimize_dns; _optimize_dns(True)'" # noqa: E501 + ) diff --git a/tests/tests_data/processing/test_readers.py b/tests/tests_data/processing/test_readers.py index 31e0b9ec62109..5ae6a76918e1a 100644 --- a/tests/tests_data/processing/test_readers.py +++ b/tests/tests_data/processing/test_readers.py @@ -7,7 +7,6 @@ class DummyReader(BaseReader): - def items_to_workers(self, items, num_workers: int): return [[(worker_idx, idx, item) for idx, item in enumerate(items)] for worker_idx in range(num_workers)] @@ -24,7 +23,7 @@ def fn(data: str, output_dir): def test_reader(tmpdir): map(fn, list(range(3)), output_dir=str(tmpdir), reader=DummyReader(), num_workers=2) - assert sorted(os.listdir(tmpdir)) == ['0_0', '0_1', '0_2', '1_0', '1_1', '1_2'] + assert sorted(os.listdir(tmpdir)) == ["0_0", "0_1", "0_2", "1_0", "1_1", "1_2"] def map_parquet(df, output_dir): @@ -33,9 +32,10 @@ def map_parquet(df, output_dir): with open(os.path.join(output_dir, filename), "w") as f: f.write("hello world") + @pytest.mark.skipif( (not _POLARS_AVAILABLE and not _PYARROW_AVAILABLE) or sys.platform == "linux", - reason="polars and pyarrow are required" + reason="polars and pyarrow are 
required", ) def test_parquet_reader(tmpdir): import polars as pol @@ -53,7 +53,7 @@ def test_parquet_reader(tmpdir): inputs=inputs, output_dir=os.path.join(tmpdir, "output_dir"), reader=ParquetReader(num_rows=10, to_pandas=False), - num_workers=2 + num_workers=2, ) - assert sorted(os.listdir(os.path.join(tmpdir, "output_dir"))) == ['0_10', '10_5', '15_5', '20_10'] + assert sorted(os.listdir(os.path.join(tmpdir, "output_dir"))) == ["0_10", "10_5", "15_5", "20_10"] diff --git a/tests/tests_data/utilities/test_shuffle.py b/tests/tests_data/utilities/test_shuffle.py index 9951a21c603ab..f31aa76e0767e 100644 --- a/tests/tests_data/utilities/test_shuffle.py +++ b/tests/tests_data/utilities/test_shuffle.py @@ -11,7 +11,6 @@ def test_intra_node_chunk_shuffle(): shuffled_indexes = _intra_node_chunk_shuffle(_DistributedEnv(4, 1, 2), chunks_per_ranks, 42, 2) assert shuffled_indexes == [3, 2, 1, 0, 7, 6, 5, 4] - chunks_per_ranks = [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11], [12, 13], [14, 15]] shuffled_indexes = _intra_node_chunk_shuffle(_DistributedEnv(8, 7, 2), chunks_per_ranks, 42, 2) assert shuffled_indexes == [5, 2, 0, 7, 6, 1, 3, 4, 13, 10, 8, 15, 14, 9, 11, 12] diff --git a/tests/tests_fabric/loggers/test_csv.py b/tests/tests_fabric/loggers/test_csv.py index 8bb4476785a0c..784bb6fb45aba 100644 --- a/tests/tests_fabric/loggers/test_csv.py +++ b/tests/tests_fabric/loggers/test_csv.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import os +from unittest import mock from unittest.mock import MagicMock import pytest @@ -50,6 +51,21 @@ def test_manual_versioning(tmp_path): assert logger.version == 1 +def test_manual_versioning_file_exists(tmp_path): + """Test that a warning is emitted and existing files get overwritten.""" + + # Simulate an existing 'version_0' vrom a previous run + (tmp_path / "exp" / "version_0").mkdir(parents=True) + previous_metrics_file = tmp_path / "exp" / "version_0" / "metrics.csv" + previous_metrics_file.touch() + + logger = CSVLogger(root_dir=tmp_path, name="exp", version=0) + assert previous_metrics_file.exists() + with pytest.warns(UserWarning, match="Experiment logs directory .* exists and is not empty"): + _ = logger.experiment + assert not previous_metrics_file.exists() + + def test_named_version(tmp_path): """Verify that manual versioning works for string versions, e.g. 
'2020-02-05-162402'.""" exp_name = "exp" @@ -130,7 +146,11 @@ def test_automatic_step_tracking(tmp_path): assert logger.experiment.metrics[2]["step"] == 2 -def test_append_metrics_file(tmp_path): +@mock.patch( + # Mock the existence check, so we can simulate appending to the metrics file + "lightning.fabric.loggers.csv_logs._ExperimentWriter._check_log_dir_exists" +) +def test_append_metrics_file(_, tmp_path): """Test that the logger appends to the file instead of rewriting it on every save.""" logger = CSVLogger(tmp_path, name="test", version=0, flush_logs_every_n_steps=1) @@ -167,7 +187,11 @@ def test_append_columns(tmp_path): assert set(header.split(",")) == {"step", "a", "b", "c"} -def test_rewrite_with_new_header(tmp_path): +@mock.patch( + # Mock the existence check, so we can simulate appending to the metrics file + "lightning.fabric.loggers.csv_logs._ExperimentWriter._check_log_dir_exists" +) +def test_rewrite_with_new_header(_, tmp_path): # write a csv file manually with open(tmp_path / "metrics.csv", "w") as file: file.write("step,metric1,metric2\n") diff --git a/tests/tests_fabric/plugins/precision/test_amp_integration.py b/tests/tests_fabric/plugins/precision/test_amp_integration.py index c4f4a856007bc..5d88a7d9babb9 100644 --- a/tests/tests_fabric/plugins/precision/test_amp_integration.py +++ b/tests/tests_fabric/plugins/precision/test_amp_integration.py @@ -12,13 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. """Integration tests for Automatic Mixed Precision (AMP) training.""" -import sys import pytest import torch import torch.nn as nn from lightning.fabric import Fabric, seed_everything -from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2 from tests_fabric.helpers.runif import RunIf @@ -40,11 +38,6 @@ def forward(self, x): return output -@pytest.mark.xfail( - # https://github.com/pytorch/pytorch/issues/116056 - sys.platform == "win32" and _TORCH_GREATER_EQUAL_2_2, - reason="Windows + DDP issue in PyTorch 2.2", -) @pytest.mark.parametrize( ("accelerator", "precision", "expected_dtype"), [ diff --git a/tests/tests_fabric/plugins/precision/test_transformer_engine.py b/tests/tests_fabric/plugins/precision/test_transformer_engine.py index 63ae0bbb3a822..c003715dead8a 100644 --- a/tests/tests_fabric/plugins/precision/test_transformer_engine.py +++ b/tests/tests_fabric/plugins/precision/test_transformer_engine.py @@ -101,11 +101,9 @@ def __init__(self): assert isinstance(model.l2, torch.nn.LayerNorm) assert isinstance(model.l3.l, torch.nn.Linear) - class TELinearMock(Mock): - ... + class TELinearMock(Mock): ... - class TELayerNormMock(Mock): - ... + class TELayerNormMock(Mock): ... transformer_engine_mock.pytorch.Linear = TELinearMock transformer_engine_mock.pytorch.LayerNorm = TELayerNormMock diff --git a/tests/tests_fabric/strategies/launchers/test_multiprocessing_integration.py b/tests/tests_fabric/strategies/launchers/test_multiprocessing_integration.py index c986ee0db91bb..c3bc7d3c2c6cd 100644 --- a/tests/tests_fabric/strategies/launchers/test_multiprocessing_integration.py +++ b/tests/tests_fabric/strategies/launchers/test_multiprocessing_integration.py @@ -11,13 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License.
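The new `test_manual_versioning_file_exists` test and the mocked `_ExperimentWriter._check_log_dir_exists` above imply that the CSV writer now warns and clears stale files when reusing a non-empty version directory. A rough sketch of such a check, assuming the attribute names used in the tests (the real implementation in `lightning.fabric.loggers.csv_logs` may differ):

```python
import os

from lightning.fabric.utilities.rank_zero import rank_zero_warn


class _ExperimentWriterSketch:
    def __init__(self, log_dir: str) -> None:
        self.log_dir = log_dir
        self.metrics_file_path = os.path.join(log_dir, "metrics.csv")

    def _check_log_dir_exists(self) -> None:
        # Warn about a non-empty directory and drop the stale metrics file so the new run starts clean
        if os.path.isdir(self.log_dir) and os.listdir(self.log_dir):
            rank_zero_warn(
                f"Experiment logs directory {self.log_dir} exists and is not empty."
                " Previous log files in this directory will be deleted when the new ones are saved!"
            )
            if os.path.exists(self.metrics_file_path):
                os.remove(self.metrics_file_path)
```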
-import sys import pytest import torch import torch.nn as nn from lightning.fabric import Fabric -from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2 from tests_fabric.helpers.runif import RunIf @@ -31,11 +29,6 @@ def __init__(self): self.register_buffer("buffer", torch.ones(3)) -@pytest.mark.xfail( - # https://github.com/pytorch/pytorch/issues/116056 - sys.platform == "win32" and _TORCH_GREATER_EQUAL_2_2, - reason="Windows + DDP issue in PyTorch 2.2", -) @pytest.mark.parametrize("strategy", ["ddp_spawn", pytest.param("ddp_fork", marks=RunIf(skip_windows=True))]) def test_memory_sharing_disabled(strategy): """Test that the multiprocessing launcher disables memory sharing on model parameters and buffers to avoid race diff --git a/tests/tests_fabric/strategies/test_ddp_integration.py b/tests/tests_fabric/strategies/test_ddp_integration.py index 251d1870d90ac..65eaacde2ff2c 100644 --- a/tests/tests_fabric/strategies/test_ddp_integration.py +++ b/tests/tests_fabric/strategies/test_ddp_integration.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import os -import sys from copy import deepcopy from unittest import mock from unittest.mock import Mock @@ -20,7 +19,7 @@ import pytest import torch from lightning.fabric import Fabric -from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0, _TORCH_GREATER_EQUAL_2_2 +from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0 from torch.nn.parallel.distributed import DistributedDataParallel from tests_fabric.helpers.runif import RunIf @@ -28,11 +27,6 @@ from tests_fabric.test_fabric import BoringModel -@pytest.mark.xfail( - # https://github.com/pytorch/pytorch/issues/116056 - sys.platform == "win32" and _TORCH_GREATER_EQUAL_2_2, - reason="Windows + DDP issue in PyTorch 2.2", -) @pytest.mark.parametrize( "accelerator", [ diff --git a/tests/tests_fabric/strategies/test_fsdp.py b/tests/tests_fabric/strategies/test_fsdp.py index ddf259e3b1f74..c6d3135f621bb 100644 --- a/tests/tests_fabric/strategies/test_fsdp.py +++ b/tests/tests_fabric/strategies/test_fsdp.py @@ -29,7 +29,7 @@ _has_meta_device_parameters, _is_sharded_checkpoint, ) -from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_1 +from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_1, _TORCH_GREATER_EQUAL_2_2 from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, FullyShardedDataParallel, MixedPrecision from torch.optim import Adam @@ -241,13 +241,12 @@ def test_fsdp_save_checkpoint_storage_options(tmp_path): @RunIf(min_torch="2.0.0") -@mock.patch("torch.distributed.checkpoint.save_state_dict", return_value=MagicMock()) @mock.patch("lightning.fabric.strategies.fsdp.FSDPStrategy.broadcast", lambda _, x: x) -@mock.patch("lightning.fabric.strategies.fsdp._get_full_state_dict_context", return_value=MagicMock()) -@mock.patch("lightning.fabric.strategies.fsdp._get_sharded_state_dict_context", return_value=MagicMock()) -@mock.patch("lightning.fabric.strategies.fsdp.torch.save", return_value=Mock()) -@mock.patch("lightning.fabric.strategies.fsdp.shutil", return_value=MagicMock()) -def test_fsdp_save_checkpoint_path_exists(shutil_mock, torch_save_mock, __, ___, ____, tmp_path): +@mock.patch("lightning.fabric.strategies.fsdp._get_full_state_dict_context") +@mock.patch("lightning.fabric.strategies.fsdp._get_sharded_state_dict_context") +@mock.patch("lightning.fabric.strategies.fsdp.torch.save") 
+@mock.patch("lightning.fabric.strategies.fsdp.shutil") +def test_fsdp_save_checkpoint_path_exists(shutil_mock, torch_save_mock, __, ___, tmp_path): strategy = FSDPStrategy(state_dict_type="full") # state_dict_type='full', path exists, path is not a sharded checkpoint: error @@ -278,6 +277,11 @@ def test_fsdp_save_checkpoint_path_exists(shutil_mock, torch_save_mock, __, ___, torch_save_mock.assert_called_once() strategy = FSDPStrategy(state_dict_type="sharded") + save_mock = mock.patch( + "torch.distributed.checkpoint.save" + if _TORCH_GREATER_EQUAL_2_2 + else "torch.distributed.checkpoint.save_state_dict" + ) # state_dict_type='sharded', path exists, path is a folder: no error (overwrite) path = tmp_path / "not-empty-2" @@ -285,7 +289,8 @@ def test_fsdp_save_checkpoint_path_exists(shutil_mock, torch_save_mock, __, ___, (path / "file").touch() model = Mock(spec=FullyShardedDataParallel) model.modules.return_value = [model] - strategy.save_checkpoint(path=path, state={"model": model}) + with save_mock: + strategy.save_checkpoint(path=path, state={"model": model}) assert (path / "file").exists() # state_dict_type='sharded', path exists, path is a file: no error (overwrite) @@ -293,7 +298,8 @@ def test_fsdp_save_checkpoint_path_exists(shutil_mock, torch_save_mock, __, ___, path.touch() model = Mock(spec=FullyShardedDataParallel) model.modules.return_value = [model] - strategy.save_checkpoint(path=path, state={"model": model}) + with save_mock: + strategy.save_checkpoint(path=path, state={"model": model}) assert path.is_dir() diff --git a/tests/tests_fabric/strategies/test_fsdp_integration.py b/tests/tests_fabric/strategies/test_fsdp_integration.py index 42fc1ba399e69..5743d261e1077 100644 --- a/tests/tests_fabric/strategies/test_fsdp_integration.py +++ b/tests/tests_fabric/strategies/test_fsdp_integration.py @@ -385,12 +385,10 @@ def test_setup_with_orig_params_and_multiple_param_groups(): torch.nn.Linear(10, 10, bias=False), torch.nn.Linear(5, 2, bias=False), ) - optimizer = torch.optim.Adam( - [ - {"params": model[0].parameters(), "lr": 1e-2}, - {"params": model[1].parameters(), "lr": 1e-6}, - ] - ) + optimizer = torch.optim.Adam([ + {"params": model[0].parameters(), "lr": 1e-2}, + {"params": model[1].parameters(), "lr": 1e-6}, + ]) # set up model and optimizer jointly wrapped_model, wrapped_optimizer = fabric.setup(model, optimizer) diff --git a/tests/tests_fabric/test_cli.py b/tests/tests_fabric/test_cli.py index 596318b4b619f..c44dc4361943d 100644 --- a/tests/tests_fabric/test_cli.py +++ b/tests/tests_fabric/test_cli.py @@ -133,16 +133,14 @@ def test_cli_torchrun_defaults(monkeypatch, fake_script): with pytest.raises(SystemExit) as e: _run_model.main([fake_script]) assert e.value.code == 0 - torchrun_mock.main.assert_called_with( - [ - "--nproc_per_node=1", - "--nnodes=1", - "--node_rank=0", - "--master_addr=127.0.0.1", - "--master_port=29400", - fake_script, - ] - ) + torchrun_mock.main.assert_called_with([ + "--nproc_per_node=1", + "--nnodes=1", + "--node_rank=0", + "--master_addr=127.0.0.1", + "--master_port=29400", + fake_script, + ]) @pytest.mark.parametrize( @@ -163,16 +161,14 @@ def test_cli_torchrun_num_processes_launched(_, devices, expected, monkeypatch, with pytest.raises(SystemExit) as e: _run_model.main([fake_script, "--accelerator", "cuda", "--devices", devices]) assert e.value.code == 0 - torchrun_mock.main.assert_called_with( - [ - f"--nproc_per_node={expected}", - "--nnodes=1", - "--node_rank=0", - "--master_addr=127.0.0.1", - "--master_port=29400", - fake_script, - ] - 
) + torchrun_mock.main.assert_called_with([ + f"--nproc_per_node={expected}", + "--nnodes=1", + "--node_rank=0", + "--master_addr=127.0.0.1", + "--master_port=29400", + fake_script, + ]) def test_cli_through_fabric_entry_point(): @@ -181,6 +177,7 @@ def test_cli_through_fabric_entry_point(): message = "Usage: fabric run model [OPTIONS] SCRIPT [SCRIPT_ARGS]" assert message in result.stdout or message in result.stderr + @pytest.mark.skipif("lightning.fabric" == "lightning_fabric", reason="standalone package") def test_cli_through_lightning_entry_point(): result = subprocess.run("lightning run model --help", capture_output=True, text=True, shell=True) diff --git a/tests/tests_fabric/utilities/test_cloud_io.py b/tests/tests_fabric/utilities/test_cloud_io.py index 5169d0fa94d77..d502199da1493 100644 --- a/tests/tests_fabric/utilities/test_cloud_io.py +++ b/tests/tests_fabric/utilities/test_cloud_io.py @@ -22,8 +22,7 @@ def test_get_filesystem_custom_filesystem(): _DUMMY_PRFEIX = "dummy" - class DummyFileSystem(LocalFileSystem): - ... + class DummyFileSystem(LocalFileSystem): ... fsspec.register_implementation(_DUMMY_PRFEIX, DummyFileSystem, clobber=True) output_file = os.path.join(f"{_DUMMY_PRFEIX}://", "tmpdir/tmp_file") diff --git a/tests/tests_fabric/utilities/test_distributed.py b/tests/tests_fabric/utilities/test_distributed.py index 4badff47403eb..ad32c936a1c4d 100644 --- a/tests/tests_fabric/utilities/test_distributed.py +++ b/tests/tests_fabric/utilities/test_distributed.py @@ -1,6 +1,5 @@ import functools import os -import sys from functools import partial from pathlib import Path from unittest import mock @@ -18,7 +17,6 @@ _sync_ddp, is_shared_filesystem, ) -from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2 from tests_fabric.helpers.runif import RunIf @@ -120,11 +118,6 @@ def test_collective_operations(devices, process): spawn_launch(process, devices) -@pytest.mark.xfail( - # https://github.com/pytorch/pytorch/issues/116056 - sys.platform == "win32" and _TORCH_GREATER_EQUAL_2_2, - reason="Windows + DDP issue in PyTorch 2.2", -) @pytest.mark.flaky(reruns=3) # flaky with "process 0 terminated with signal SIGABRT" (GLOO) def test_is_shared_filesystem(tmp_path, monkeypatch): # In the non-distributed case, every location is interpreted as 'shared' diff --git a/tests/tests_fabric/utilities/test_spike.py b/tests/tests_fabric/utilities/test_spike.py index ab7d9f474383d..9739540af7f18 100644 --- a/tests/tests_fabric/utilities/test_spike.py +++ b/tests/tests_fabric/utilities/test_spike.py @@ -4,7 +4,6 @@ import pytest import torch from lightning.fabric import Fabric -from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2 from lightning.fabric.utilities.spike import _TORCHMETRICS_GREATER_EQUAL_1_0_0, SpikeDetection, TrainingSpikeException @@ -29,11 +28,6 @@ def spike_detection_test(fabric, global_rank_spike, spike_value, should_raise): ) -@pytest.mark.xfail( - # https://github.com/pytorch/pytorch/issues/116056 - sys.platform == "win32" and _TORCH_GREATER_EQUAL_2_2, - reason="Windows + DDP issue in PyTorch 2.2", -) @pytest.mark.flaky(max_runs=3) @pytest.mark.parametrize( ("global_rank_spike", "num_devices", "spike_value", "finite_only"), diff --git a/tests/tests_fabric/utilities/test_warnings.py b/tests/tests_fabric/utilities/test_warnings.py index 1f34bf6e2a938..b7989d85b5932 100644 --- a/tests/tests_fabric/utilities/test_warnings.py +++ b/tests/tests_fabric/utilities/test_warnings.py @@ -16,6 +16,7 @@ Needs to be run outside of `pytest` as it captures all 
the warnings. """ + import importlib import inspect import os diff --git a/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py b/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py index 0e462b2c20e24..6f6ad8f2227dd 100644 --- a/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py +++ b/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py @@ -377,9 +377,12 @@ def get_metrics(self, trainer, pl_module): items = super().get_metrics(trainer, model) del items["v_num"] # this is equivalent to mocking `set_postfix` as this method gets called every time - self.calls[trainer.state.fn].append( - (trainer.state.stage, trainer.current_epoch, trainer.global_step, items) - ) + self.calls[trainer.state.fn].append(( + trainer.state.stage, + trainer.current_epoch, + trainer.global_step, + items, + )) return items class MyModel(BoringModel): diff --git a/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py b/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py index ee920105a6716..cbe07164fb427 100644 --- a/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py +++ b/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py @@ -334,13 +334,16 @@ def test_tqdm_progress_bar_value_on_colab(tmp_path): assert trainer.progress_bar_callback.refresh_rate == 19 -@pytest.mark.parametrize(("refresh_rate", "env_value", "expected"), [ - (0, 1, 1), - (1, 0, 1), - (1, 1, 1), - (2, 1, 2), - (1, 2, 2), -]) +@pytest.mark.parametrize( + ("refresh_rate", "env_value", "expected"), + [ + (0, 1, 1), + (1, 0, 1), + (1, 1, 1), + (2, 1, 2), + (1, 2, 2), + ], +) def test_tqdm_progress_bar_refresh_rate_via_env_variable(refresh_rate, env_value, expected): with mock.patch.dict(os.environ, {"TQDM_MINITERS": str(env_value)}): bar = TQDMProgressBar(refresh_rate=refresh_rate) @@ -538,9 +541,12 @@ def test_tqdm_progress_bar_print_disabled(tqdm_write, mock_print, tmp_path): trainer.test(model, verbose=False) trainer.predict(model) - mock_print.assert_has_calls( - [call("training_step", end=""), call("validation_step", file=ANY), call("test_step"), call("predict_step")] - ) + mock_print.assert_has_calls([ + call("training_step", end=""), + call("validation_step", file=ANY), + call("test_step"), + call("predict_step"), + ]) tqdm_write.assert_not_called() @@ -694,9 +700,12 @@ def get_metrics(self, trainer, pl_module): items = super().get_metrics(trainer, model) del items["v_num"] # this is equivalent to mocking `set_postfix` as this method gets called every time - self.calls[trainer.state.fn].append( - (trainer.state.stage, trainer.current_epoch, trainer.global_step, items) - ) + self.calls[trainer.state.fn].append(( + trainer.state.stage, + trainer.current_epoch, + trainer.global_step, + items, + )) return items class MyModel(BoringModel): diff --git a/tests/tests_pytorch/callbacks/test_finetuning_callback.py b/tests/tests_pytorch/callbacks/test_finetuning_callback.py index e9952d18a522a..9c3ae5f587548 100644 --- a/tests/tests_pytorch/callbacks/test_finetuning_callback.py +++ b/tests/tests_pytorch/callbacks/test_finetuning_callback.py @@ -282,9 +282,10 @@ def test_complex_nested_model(): directly themselves rather than exclusively their submodules containing parameters.""" model = nn.Sequential( - OrderedDict( - [("encoder", nn.Sequential(ConvBlockParam(3, 64), ConvBlock(64, 128))), ("decoder", ConvBlock(128, 10))] - ) + OrderedDict([ + ("encoder", nn.Sequential(ConvBlockParam(3, 64), ConvBlock(64, 128))), + ("decoder", ConvBlock(128, 10)), + ]) ) # There 
are 10 leaf modules or parent modules w/ parameters in the test model diff --git a/tests/tests_pytorch/callbacks/test_prediction_writer.py b/tests/tests_pytorch/callbacks/test_prediction_writer.py index 3dad25f783f34..343716d332279 100644 --- a/tests/tests_pytorch/callbacks/test_prediction_writer.py +++ b/tests/tests_pytorch/callbacks/test_prediction_writer.py @@ -89,20 +89,16 @@ def test_prediction_writer_batch_indices(num_workers): trainer = Trainer(limit_predict_batches=4, callbacks=writer) trainer.predict(model, dataloaders=dataloader) - writer.write_on_batch_end.assert_has_calls( - [ - call(trainer, model, ANY, [0, 1, 2, 3], ANY, 0, 0), - call(trainer, model, ANY, [4, 5, 6, 7], ANY, 1, 0), - call(trainer, model, ANY, [8, 9, 10, 11], ANY, 2, 0), - call(trainer, model, ANY, [12, 13, 14, 15], ANY, 3, 0), - ] - ) - - writer.write_on_epoch_end.assert_has_calls( - [ - call(trainer, model, ANY, [[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]]), - ] - ) + writer.write_on_batch_end.assert_has_calls([ + call(trainer, model, ANY, [0, 1, 2, 3], ANY, 0, 0), + call(trainer, model, ANY, [4, 5, 6, 7], ANY, 1, 0), + call(trainer, model, ANY, [8, 9, 10, 11], ANY, 2, 0), + call(trainer, model, ANY, [12, 13, 14, 15], ANY, 3, 0), + ]) + + writer.write_on_epoch_end.assert_has_calls([ + call(trainer, model, ANY, [[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]]), + ]) def test_batch_level_batch_indices(): @@ -119,11 +115,9 @@ def on_predict_epoch_end(self, *args, **kwargs): trainer = Trainer(limit_predict_batches=4, callbacks=writer) trainer.predict(model, dataloaders=dataloader, return_predictions=False) - writer.write_on_batch_end.assert_has_calls( - [ - call(trainer, model, ANY, [0, 1, 2, 3], ANY, 0, 0), - call(trainer, model, ANY, [4, 5, 6, 7], ANY, 1, 0), - call(trainer, model, ANY, [8, 9, 10, 11], ANY, 2, 0), - call(trainer, model, ANY, [12, 13, 14, 15], ANY, 3, 0), - ] - ) + writer.write_on_batch_end.assert_has_calls([ + call(trainer, model, ANY, [0, 1, 2, 3], ANY, 0, 0), + call(trainer, model, ANY, [4, 5, 6, 7], ANY, 1, 0), + call(trainer, model, ANY, [8, 9, 10, 11], ANY, 2, 0), + call(trainer, model, ANY, [12, 13, 14, 15], ANY, 3, 0), + ]) diff --git a/tests/tests_pytorch/callbacks/test_pruning.py b/tests/tests_pytorch/callbacks/test_pruning.py index 0c21c67781eb8..0842e7968fccb 100644 --- a/tests/tests_pytorch/callbacks/test_pruning.py +++ b/tests/tests_pytorch/callbacks/test_pruning.py @@ -33,9 +33,11 @@ class TestModel(BoringModel): def __init__(self): super().__init__() self.layer = Sequential( - OrderedDict( - [("mlp_1", nn.Linear(32, 32)), ("mlp_2", nn.Linear(32, 32, bias=False)), ("mlp_3", nn.Linear(32, 2))] - ) + OrderedDict([ + ("mlp_1", nn.Linear(32, 32)), + ("mlp_2", nn.Linear(32, 32, bias=False)), + ("mlp_3", nn.Linear(32, 2)), + ]) ) def training_step(self, batch, batch_idx): diff --git a/tests/tests_pytorch/callbacks/test_spike.py b/tests/tests_pytorch/callbacks/test_spike.py index b3bc54bd6e836..f4d0c946cefa0 100644 --- a/tests/tests_pytorch/callbacks/test_spike.py +++ b/tests/tests_pytorch/callbacks/test_spike.py @@ -3,7 +3,6 @@ import pytest import torch -from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2 from lightning.fabric.utilities.spike import _TORCHMETRICS_GREATER_EQUAL_1_0_0, TrainingSpikeException from lightning.pytorch import LightningModule, Trainer from lightning.pytorch.callbacks.spike import SpikeDetection @@ -47,11 +46,6 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): 
super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx) -@pytest.mark.xfail( - # https://github.com/pytorch/pytorch/issues/116056 - sys.platform == "win32" and _TORCH_GREATER_EQUAL_2_2, - reason="Windows + DDP issue in PyTorch 2.2", -) @pytest.mark.flaky(max_runs=3) @pytest.mark.parametrize( ("global_rank_spike", "num_devices", "spike_value", "finite_only"), diff --git a/tests/tests_pytorch/callbacks/test_throughput_monitor.py b/tests/tests_pytorch/callbacks/test_throughput_monitor.py index 74be939da671c..9467e45e2fa80 100644 --- a/tests/tests_pytorch/callbacks/test_throughput_monitor.py +++ b/tests/tests_pytorch/callbacks/test_throughput_monitor.py @@ -160,7 +160,8 @@ def test_throughput_monitor_fit_no_length_fn(tmp_path): ] -def test_throughput_monitor_fit_gradient_accumulation(tmp_path): +@pytest.mark.parametrize("log_every_n_steps", [1, 3]) +def test_throughput_monitor_fit_gradient_accumulation(log_every_n_steps, tmp_path): logger_mock = Mock() logger_mock.save_dir = tmp_path monitor = ThroughputMonitor(length_fn=lambda x: 3 * 2, batch_size_fn=lambda x: 3, window_size=4, separator="|") @@ -174,26 +175,8 @@ def test_throughput_monitor_fit_gradient_accumulation(tmp_path): limit_train_batches=5, limit_val_batches=0, max_epochs=2, - log_every_n_steps=3, + log_every_n_steps=log_every_n_steps, accumulate_grad_batches=2, - num_sanity_val_steps=2, - enable_checkpointing=False, - enable_model_summary=False, - enable_progress_bar=False, - ) - with pytest.raises(ValueError, match="not divisible"): - trainer.fit(model) - - trainer = Trainer( - devices=1, - logger=logger_mock, - callbacks=monitor, - limit_train_batches=5, - limit_val_batches=0, - max_epochs=2, - log_every_n_steps=1, - accumulate_grad_batches=2, - num_sanity_val_steps=2, enable_checkpointing=False, enable_model_summary=False, enable_progress_bar=False, @@ -211,9 +194,19 @@ def test_throughput_monitor_fit_gradient_accumulation(tmp_path): "train|device|flops_per_sec": 10.0, "train|device|mfu": 0.1, } - assert logger_mock.log_metrics.mock_calls == [ + + all_log_calls = [ call( - metrics={"train|time": 2.5, "train|batches": 2, "train|samples": 6, "train|lengths": 12, "epoch": 0}, step=0 + metrics={ + # The very first batch doesn't have the *_per_sec metrics yet + **(expected if log_every_n_steps > 1 else {}), + "train|time": 2.5, + "train|batches": 2, + "train|samples": 6, + "train|lengths": 12, + "epoch": 0, + }, + step=0, ), call( metrics={ @@ -271,6 +264,8 @@ def test_throughput_monitor_fit_gradient_accumulation(tmp_path): step=5, ), ] + expected_log_calls = all_log_calls[(log_every_n_steps - 1) :: log_every_n_steps] + assert logger_mock.log_metrics.mock_calls == expected_log_calls @pytest.mark.parametrize("fn", ["validate", "test", "predict"]) diff --git a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py index fb2c8d8e35a93..21f926bac9cab 100644 --- a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py +++ b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py @@ -885,13 +885,11 @@ def test_default_checkpoint_behavior(tmpdir): assert len(results) == 1 save_dir = tmpdir / "checkpoints" save_weights_only = trainer.checkpoint_callback.save_weights_only - save_mock.assert_has_calls( - [ - call(save_dir / "epoch=0-step=5.ckpt", save_weights_only), - call(save_dir / "epoch=1-step=10.ckpt", save_weights_only), - call(save_dir / "epoch=2-step=15.ckpt", save_weights_only), - ] - ) + save_mock.assert_has_calls([ + call(save_dir / 
"epoch=0-step=5.ckpt", save_weights_only), + call(save_dir / "epoch=1-step=10.ckpt", save_weights_only), + call(save_dir / "epoch=2-step=15.ckpt", save_weights_only), + ]) ckpts = os.listdir(save_dir) assert len(ckpts) == 1 assert ckpts[0] == "epoch=2-step=15.ckpt" diff --git a/tests/tests_pytorch/core/test_datamodules.py b/tests/tests_pytorch/core/test_datamodules.py index 2e021537c93b9..b7597bbd25fe6 100644 --- a/tests/tests_pytorch/core/test_datamodules.py +++ b/tests/tests_pytorch/core/test_datamodules.py @@ -291,12 +291,10 @@ def test_dm_init_from_datasets_dataloaders(iterable): dm = LightningDataModule.from_datasets(train_ds_sequence, batch_size=4, num_workers=0) with mock.patch("lightning.pytorch.core.datamodule.DataLoader") as dl_mock: dm.train_dataloader() - dl_mock.assert_has_calls( - [ - call(train_ds_sequence[0], batch_size=4, shuffle=not iterable, num_workers=0, pin_memory=True), - call(train_ds_sequence[1], batch_size=4, shuffle=not iterable, num_workers=0, pin_memory=True), - ] - ) + dl_mock.assert_has_calls([ + call(train_ds_sequence[0], batch_size=4, shuffle=not iterable, num_workers=0, pin_memory=True), + call(train_ds_sequence[1], batch_size=4, shuffle=not iterable, num_workers=0, pin_memory=True), + ]) with pytest.raises(MisconfigurationException, match="`val_dataloader` must be implemented"): _ = dm.val_dataloader() with pytest.raises(MisconfigurationException, match="`test_dataloader` must be implemented"): @@ -321,16 +319,14 @@ def test_dm_init_from_datasets_dataloaders(iterable): dm.val_dataloader() dm.test_dataloader() dm.predict_dataloader() - dl_mock.assert_has_calls( - [ - call(valid_dss[0], batch_size=4, shuffle=False, num_workers=0, pin_memory=True), - call(valid_dss[1], batch_size=4, shuffle=False, num_workers=0, pin_memory=True), - call(test_dss[0], batch_size=4, shuffle=False, num_workers=0, pin_memory=True), - call(test_dss[1], batch_size=4, shuffle=False, num_workers=0, pin_memory=True), - call(predict_dss[0], batch_size=4, shuffle=False, num_workers=0, pin_memory=True), - call(predict_dss[1], batch_size=4, shuffle=False, num_workers=0, pin_memory=True), - ] - ) + dl_mock.assert_has_calls([ + call(valid_dss[0], batch_size=4, shuffle=False, num_workers=0, pin_memory=True), + call(valid_dss[1], batch_size=4, shuffle=False, num_workers=0, pin_memory=True), + call(test_dss[0], batch_size=4, shuffle=False, num_workers=0, pin_memory=True), + call(test_dss[1], batch_size=4, shuffle=False, num_workers=0, pin_memory=True), + call(predict_dss[0], batch_size=4, shuffle=False, num_workers=0, pin_memory=True), + call(predict_dss[1], batch_size=4, shuffle=False, num_workers=0, pin_memory=True), + ]) def test_dm_init_from_datasets_with_init_params(): diff --git a/tests/tests_pytorch/core/test_metric_result_integration.py b/tests/tests_pytorch/core/test_metric_result_integration.py index 26488307e134f..28c3068eb785b 100644 --- a/tests/tests_pytorch/core/test_metric_result_integration.py +++ b/tests/tests_pytorch/core/test_metric_result_integration.py @@ -415,9 +415,9 @@ def __init__(self): self.metrics_list = ModuleList([DummyMetric() for _ in range(2)]) self.metrics_dict = ModuleDict({"a": DummyMetric(), "b": DummyMetric()}) self.metrics_collection_dict = MetricCollection({"a": DummyMetric(), "b": DummyMetric()}) - self.metrics_collection_dict_nested = ModuleDict( - {"a": ModuleList([ModuleDict({"b": DummyMetric()}), DummyMetric()])} - ) + self.metrics_collection_dict_nested = ModuleDict({ + "a": ModuleList([ModuleDict({"b": DummyMetric()}), DummyMetric()]) + }) def 
training_step(self, batch, batch_idx): loss = super().training_step(batch, batch_idx) diff --git a/tests/tests_pytorch/core/test_saving.py b/tests/tests_pytorch/core/test_saving.py index 07529333ac112..e20d5cb803b85 100644 --- a/tests/tests_pytorch/core/test_saving.py +++ b/tests/tests_pytorch/core/test_saving.py @@ -120,19 +120,23 @@ def test_load_from_checkpoint_warn_on_empty_state_dict(tmp_path): assert model.device.type == "cpu" -@pytest.mark.parametrize(("strict", "strict_loading", "expected"), [ - (None, None, True), - (None, True, True), - (None, False, False), - (True, None, True), - (True, True, True), - (True, False, "error"), - (False, None, False), - (False, True, "error"), - (False, False, False), -]) +@pytest.mark.parametrize( + ("strict", "strict_loading", "expected"), + [ + (None, None, True), + (None, True, True), + (None, False, False), + (True, None, True), + (True, True, True), + (True, False, "error"), + (False, None, False), + (False, True, "error"), + (False, False, False), + ], +) def test_load_from_checkpoint_strict(strict, strict_loading, expected, tmp_path): """Test that strict loading works both with the `strict` argument and the model's `strict_loading` attribute.""" + class LoadingModel(BoringModel): def __init__(self): super().__init__() diff --git a/tests/tests_pytorch/deprecated_api/test_no_removal_version.py b/tests/tests_pytorch/deprecated_api/test_no_removal_version.py index 79d45c3d5b801..d12254c6794fe 100644 --- a/tests/tests_pytorch/deprecated_api/test_no_removal_version.py +++ b/tests/tests_pytorch/deprecated_api/test_no_removal_version.py @@ -12,8 +12,7 @@ def test_configure_sharded_model(): class MyModel(BoringModel): - def configure_sharded_model(self) -> None: - ... + def configure_sharded_model(self) -> None: ... model = MyModel() trainer = Trainer(devices=1, accelerator="cpu", fast_dev_run=1) @@ -21,8 +20,7 @@ def configure_sharded_model(self) -> None: trainer.fit(model) class MyModelBoth(MyModel): - def configure_model(self): - ... + def configure_model(self): ... 
model = MyModelBoth() with pytest.raises( diff --git a/tests/tests_pytorch/graveyard/test_precision.py b/tests/tests_pytorch/graveyard/test_precision.py index 17bca3315ea64..818e4cf8e9733 100644 --- a/tests/tests_pytorch/graveyard/test_precision.py +++ b/tests/tests_pytorch/graveyard/test_precision.py @@ -1,3 +1,6 @@ +import pytest + + def test_precision_plugin_renamed_imports(): # base class from lightning.pytorch.plugins import PrecisionPlugin as PrecisionPlugin2 @@ -9,6 +12,10 @@ def test_precision_plugin_renamed_imports(): assert issubclass(PrecisionPlugin1, Precision) assert issubclass(PrecisionPlugin2, Precision) + for plugin_cls in [PrecisionPlugin0, PrecisionPlugin1, PrecisionPlugin2]: + with pytest.warns(DeprecationWarning, match="The `PrecisionPlugin` is deprecated"): + plugin_cls() + # bitsandbytes from lightning.pytorch.plugins import BitsandbytesPrecisionPlugin as BnbPlugin2 from lightning.pytorch.plugins.precision import BitsandbytesPrecisionPlugin as BnbPlugin1 @@ -39,6 +46,10 @@ def test_precision_plugin_renamed_imports(): assert issubclass(DoublePlugin1, DoublePrecision) assert issubclass(DoublePlugin2, DoublePrecision) + for plugin_cls in [DoublePlugin0, DoublePlugin1, DoublePlugin2]: + with pytest.warns(DeprecationWarning, match="The `DoublePrecisionPlugin` is deprecated"): + plugin_cls() + # fsdp from lightning.pytorch.plugins import FSDPPrecisionPlugin as FSDPPlugin2 from lightning.pytorch.plugins.precision import FSDPPrecisionPlugin as FSDPPlugin1 @@ -49,6 +60,10 @@ def test_precision_plugin_renamed_imports(): assert issubclass(FSDPPlugin1, FSDPPrecision) assert issubclass(FSDPPlugin2, FSDPPrecision) + for plugin_cls in [FSDPPlugin0, FSDPPlugin1, FSDPPlugin2]: + with pytest.warns(DeprecationWarning, match="The `FSDPPrecisionPlugin` is deprecated"): + plugin_cls(precision="16-mixed") + # half from lightning.pytorch.plugins import HalfPrecisionPlugin as HalfPlugin2 from lightning.pytorch.plugins.precision import HalfPrecisionPlugin as HalfPlugin1 @@ -59,6 +74,10 @@ def test_precision_plugin_renamed_imports(): assert issubclass(HalfPlugin1, HalfPrecision) assert issubclass(HalfPlugin2, HalfPrecision) + for plugin_cls in [HalfPlugin0, HalfPlugin1, HalfPlugin2]: + with pytest.warns(DeprecationWarning, match="The `HalfPrecisionPlugin` is deprecated"): + plugin_cls() + # mixed from lightning.pytorch.plugins import MixedPrecisionPlugin as MixedPlugin2 from lightning.pytorch.plugins.precision import MixedPrecisionPlugin as MixedPlugin1 @@ -69,6 +88,10 @@ def test_precision_plugin_renamed_imports(): assert issubclass(MixedPlugin1, MixedPrecision) assert issubclass(MixedPlugin2, MixedPrecision) + for plugin_cls in [MixedPlugin0, MixedPlugin1, MixedPlugin2]: + with pytest.warns(DeprecationWarning, match="The `MixedPrecisionPlugin` is deprecated"): + plugin_cls(precision="bf16-mixed", device="cuda:0") + # transformer_engine from lightning.pytorch.plugins import TransformerEnginePrecisionPlugin as TEPlugin2 from lightning.pytorch.plugins.precision import TransformerEnginePrecisionPlugin as TEPlugin1 diff --git a/tests/tests_pytorch/loggers/conftest.py b/tests/tests_pytorch/loggers/conftest.py index 40dec3a4e7fbe..0a0923b6902a6 100644 --- a/tests/tests_pytorch/loggers/conftest.py +++ b/tests/tests_pytorch/loggers/conftest.py @@ -38,7 +38,7 @@ def mlflow_mock(monkeypatch): mlflow.tracking = mlflow_tracking mlflow.entities = mlflow_entities - monkeypatch.setattr("lightning.pytorch.loggers.mlflow._MLFLOW_AVAILABLE", True), + 
(monkeypatch.setattr("lightning.pytorch.loggers.mlflow._MLFLOW_AVAILABLE", True),) return mlflow @@ -141,6 +141,10 @@ def __setitem__(self, key, value): neptune_utils.stringify_unsupported = Mock() monkeypatch.setitem(sys.modules, "neptune.utils", neptune_utils) + neptune_exceptions = ModuleType("exceptions") + neptune_exceptions.InactiveRunException = Exception + monkeypatch.setitem(sys.modules, "neptune.exceptions", neptune_exceptions) + neptune.handler = neptune_handler neptune.types = neptune_types neptune.utils = neptune_utils diff --git a/tests/tests_pytorch/loggers/test_csv.py b/tests/tests_pytorch/loggers/test_csv.py index 30f7f94c1516a..a03b8a7d62f78 100644 --- a/tests/tests_pytorch/loggers/test_csv.py +++ b/tests/tests_pytorch/loggers/test_csv.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import os +from unittest import mock from unittest.mock import MagicMock import fsspec @@ -51,6 +52,21 @@ def test_manual_versioning(tmp_path): assert logger.version == 1 +def test_manual_versioning_file_exists(tmp_path): + """Test that a warning is emitted and existing files get overwritten.""" + + # Simulate an existing 'version_0' from a previous run + (tmp_path / "exp" / "version_0").mkdir(parents=True) + previous_metrics_file = tmp_path / "exp" / "version_0" / "metrics.csv" + previous_metrics_file.touch() + + logger = CSVLogger(save_dir=tmp_path, name="exp", version=0) + assert previous_metrics_file.exists() + with pytest.warns(UserWarning, match="Experiment logs directory .* exists and is not empty"): + _ = logger.experiment + assert not previous_metrics_file.exists() + + def test_named_version(tmp_path): """Verify that manual versioning works for string versions, e.g. '2020-02-05-162402'.""" exp_name = "exp" @@ -148,6 +164,10 @@ def test_metrics_reset_after_save(tmp_path): assert not logger.experiment.metrics +@mock.patch( + # Mock the existence check, so we can simulate appending to the metrics file + "lightning.fabric.loggers.csv_logs._ExperimentWriter._check_log_dir_exists" +) -def test_append_metrics_file(tmp_path): +def test_append_metrics_file(_, tmp_path): """Test that the logger appends to the file instead of rewriting it on every save.""" logger = CSVLogger(tmp_path, name="test", version=0, flush_logs_every_n_steps=1) diff --git a/tests/tests_pytorch/loggers/test_logger.py b/tests/tests_pytorch/loggers/test_logger.py index 56df40557da66..a4029c43104ed 100644 --- a/tests/tests_pytorch/loggers/test_logger.py +++ b/tests/tests_pytorch/loggers/test_logger.py @@ -261,8 +261,7 @@ def __init__(self, hparams: Dict[str, Any]) -> None: super().__init__() self.save_hyperparameters(hparams) - class _Test: - ... + class _Test: ...
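The `graveyard/test_precision.py` changes earlier in this patch assert that instantiating the renamed `*PrecisionPlugin` classes emits a `DeprecationWarning`. A simplified sketch of the kind of shim those tests expect (hypothetical; Lightning's actual shims live in its graveyard module and may use its own deprecation helper):

```python
import warnings
from typing import Any

from lightning.pytorch.plugins.precision import Precision


class PrecisionPlugin(Precision):
    """Deprecated alias that forwards to ``Precision`` while warning the caller."""

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        warnings.warn(
            "The `PrecisionPlugin` is deprecated. Use `Precision` instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        super().__init__(*args, **kwargs)
```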
same_params = {1: 1, "2": 2, "three": 3.0, "test": _Test(), "4": torch.tensor(4)} model = TestModel(same_params) diff --git a/tests/tests_pytorch/loggers/test_neptune.py b/tests/tests_pytorch/loggers/test_neptune.py index ea9cdd9285b0f..13941c18db8e8 100644 --- a/tests/tests_pytorch/loggers/test_neptune.py +++ b/tests/tests_pytorch/loggers/test_neptune.py @@ -262,12 +262,10 @@ def test_after_save_checkpoint(neptune_mock): run_instance_mock.__getitem__.assert_any_call(f"{model_key_prefix}/checkpoints/model1") run_instance_mock.__getitem__.assert_any_call(f"{model_key_prefix}/checkpoints/model2/with/slashes") - run_attr_mock.upload.assert_has_calls( - [ - call(os.path.join(models_root_dir, "model1")), - call(os.path.join(models_root_dir, "model2/with/slashes")), - ] - ) + run_attr_mock.upload.assert_has_calls([ + call(os.path.join(models_root_dir, "model1")), + call(os.path.join(models_root_dir, "model2/with/slashes")), + ]) def test_save_dir(neptune_mock): @@ -305,3 +303,13 @@ def test_get_full_model_names_from_exp_structure(): } expected_keys = {"lvl1_1/lvl2/lvl3_1", "lvl1_1/lvl2/lvl3_2", "lvl1_2"} assert NeptuneLogger._get_full_model_names_from_exp_structure(input_dict, "foo/bar") == expected_keys + + +def test_inactive_run(neptune_mock, tmp_path): + from neptune.exceptions import InactiveRunException + + logger, run_instance_mock, _ = _get_logger_with_mocks(api_key="test", project="project") + run_instance_mock.__setitem__.side_effect = InactiveRunException + + # this should work without any exceptions + _fit_and_test(logger=logger, model=BoringModel(), tmp_path=tmp_path) diff --git a/tests/tests_pytorch/loggers/test_wandb.py b/tests/tests_pytorch/loggers/test_wandb.py index bb6618d4e80b4..16c69d2c6d773 100644 --- a/tests/tests_pytorch/loggers/test_wandb.py +++ b/tests/tests_pytorch/loggers/test_wandb.py @@ -436,9 +436,10 @@ def test_wandb_log_media(wandb_mock, tmp_path): wandb_mock.init().log.reset_mock() logger.log_image(key="samples", images=["1.jpg", "2.jpg"], step=5) wandb_mock.Image.assert_called_with("2.jpg") - wandb_mock.init().log.assert_called_once_with( - {"samples": [wandb_mock.Image(), wandb_mock.Image()], "trainer/global_step": 5} - ) + wandb_mock.init().log.assert_called_once_with({ + "samples": [wandb_mock.Image(), wandb_mock.Image()], + "trainer/global_step": 5, + }) # test log_image with captions wandb_mock.init().log.reset_mock() @@ -465,9 +466,10 @@ def test_wandb_log_media(wandb_mock, tmp_path): wandb_mock.init().log.reset_mock() logger.log_audio(key="samples", audios=["1.mp3", "2.mp3"], step=5) wandb_mock.Audio.assert_called_with("2.mp3") - wandb_mock.init().log.assert_called_once_with( - {"samples": [wandb_mock.Audio(), wandb_mock.Audio()], "trainer/global_step": 5} - ) + wandb_mock.init().log.assert_called_once_with({ + "samples": [wandb_mock.Audio(), wandb_mock.Audio()], + "trainer/global_step": 5, + }) # test log_audio with captions wandb_mock.init().log.reset_mock() @@ -494,9 +496,10 @@ def test_wandb_log_media(wandb_mock, tmp_path): wandb_mock.init().log.reset_mock() logger.log_video(key="samples", videos=["1.mp4", "2.mp4"], step=5) wandb_mock.Video.assert_called_with("2.mp4") - wandb_mock.init().log.assert_called_once_with( - {"samples": [wandb_mock.Video(), wandb_mock.Video()], "trainer/global_step": 5} - ) + wandb_mock.init().log.assert_called_once_with({ + "samples": [wandb_mock.Video(), wandb_mock.Video()], + "trainer/global_step": 5, + }) # test log_video with captions wandb_mock.init().log.reset_mock() diff --git 
a/tests/tests_pytorch/loops/test_evaluation_loop.py b/tests/tests_pytorch/loops/test_evaluation_loop.py index 7e740999cc679..33708bd542c9f 100644 --- a/tests/tests_pytorch/loops/test_evaluation_loop.py +++ b/tests/tests_pytorch/loops/test_evaluation_loop.py @@ -200,11 +200,9 @@ def test_invalid_dataloader_idx_raises_step(tmp_path): trainer = Trainer(default_root_dir=tmp_path, fast_dev_run=True) class ExtraDataloaderIdx(BoringModel): - def validation_step(self, batch, batch_idx, dataloader_idx): - ... + def validation_step(self, batch, batch_idx, dataloader_idx): ... - def test_step(self, batch, batch_idx, dataloader_idx): - ... + def test_step(self, batch, batch_idx, dataloader_idx): ... model = ExtraDataloaderIdx() with pytest.raises(RuntimeError, match="have included `dataloader_idx` in `ExtraDataloaderIdx.validation_step"): @@ -213,22 +211,18 @@ def test_step(self, batch, batch_idx, dataloader_idx): trainer.test(model) class GoodDefault(BoringModel): - def validation_step(self, batch, batch_idx, dataloader_idx=0): - ... + def validation_step(self, batch, batch_idx, dataloader_idx=0): ... - def test_step(self, batch, batch_idx, dataloader_idx=0): - ... + def test_step(self, batch, batch_idx, dataloader_idx=0): ... model = GoodDefault() trainer.validate(model) trainer.test(model) class ExtraDlIdxOtherName(BoringModel): - def validation_step(self, batch, batch_idx, dl_idx): - ... + def validation_step(self, batch, batch_idx, dl_idx): ... - def test_step(self, batch, batch_idx, dl_idx): - ... + def test_step(self, batch, batch_idx, dl_idx): ... model = ExtraDlIdxOtherName() # different names are not supported @@ -251,22 +245,18 @@ def test_dataloader(self): trainer.test(model) class IgnoringModel(MultipleDataloader): - def validation_step(self, batch, batch_idx, *_): - ... + def validation_step(self, batch, batch_idx, *_): ... - def test_step(self, batch, batch_idx, *_): - ... + def test_step(self, batch, batch_idx, *_): ... model = IgnoringModel() trainer.validate(model) trainer.test(model) class IgnoringModel2(MultipleDataloader): - def validation_step(self, batch, batch_idx, **_): - ... + def validation_step(self, batch, batch_idx, **_): ... - def test_step(self, batch, batch_idx, **_): - ... + def test_step(self, batch, batch_idx, **_): ... model = IgnoringModel2() with pytest.raises(RuntimeError, match="no `dataloader_idx` argument in `IgnoringModel2.validation_step"): @@ -279,11 +269,9 @@ def test_invalid_dataloader_idx_raises_batch_start(tmp_path): trainer = Trainer(default_root_dir=tmp_path, fast_dev_run=True) class ExtraDataloaderIdx(BoringModel): - def on_validation_batch_start(self, batch, batch_idx, dataloader_idx): - ... + def on_validation_batch_start(self, batch, batch_idx, dataloader_idx): ... - def on_test_batch_start(self, batch, batch_idx, dataloader_idx): - ... + def on_test_batch_start(self, batch, batch_idx, dataloader_idx): ... model = ExtraDataloaderIdx() with pytest.raises( @@ -294,22 +282,18 @@ def on_test_batch_start(self, batch, batch_idx, dataloader_idx): trainer.test(model) class GoodDefault(BoringModel): - def on_validation_batch_start(self, batch, batch_idx, dataloader_idx=0): - ... + def on_validation_batch_start(self, batch, batch_idx, dataloader_idx=0): ... - def on_test_batch_start(self, batch, batch_idx, dataloader_idx=0): - ... + def on_test_batch_start(self, batch, batch_idx, dataloader_idx=0): ... 
model = GoodDefault() trainer.validate(model) trainer.test(model) class ExtraDlIdxOtherName(BoringModel): - def on_validation_batch_start(self, batch, batch_idx, dl_idx): - ... + def on_validation_batch_start(self, batch, batch_idx, dl_idx): ... - def on_test_batch_start(self, batch, batch_idx, dl_idx): - ... + def on_test_batch_start(self, batch, batch_idx, dl_idx): ... model = ExtraDlIdxOtherName() # different names are not supported @@ -319,17 +303,13 @@ def on_test_batch_start(self, batch, batch_idx, dl_idx): trainer.test(model) class MultipleDataloader(BoringModel): - def validation_step(self, batch, batch_idx, dataloader_idx=0): - ... + def validation_step(self, batch, batch_idx, dataloader_idx=0): ... - def test_step(self, batch, batch_idx, dataloader_idx=0): - ... + def test_step(self, batch, batch_idx, dataloader_idx=0): ... - def on_validation_batch_start(self, batch, batch_idx): - ... + def on_validation_batch_start(self, batch, batch_idx): ... - def on_test_batch_start(self, batch, batch_idx): - ... + def on_test_batch_start(self, batch, batch_idx): ... def val_dataloader(self): return [super().val_dataloader(), super().val_dataloader()] @@ -346,22 +326,18 @@ def test_dataloader(self): trainer.test(model) class IgnoringModel(MultipleDataloader): - def on_validation_batch_start(self, batch, batch_idx, *_): - ... + def on_validation_batch_start(self, batch, batch_idx, *_): ... - def on_test_batch_start(self, batch, batch_idx, *_): - ... + def on_test_batch_start(self, batch, batch_idx, *_): ... model = IgnoringModel() trainer.validate(model) trainer.test(model) class IgnoringModel2(MultipleDataloader): - def on_validation_batch_start(self, batch, batch_idx, **_): - ... + def on_validation_batch_start(self, batch, batch_idx, **_): ... - def on_test_batch_start(self, batch, batch_idx, **_): - ... + def on_test_batch_start(self, batch, batch_idx, **_): ... model = IgnoringModel2() with pytest.raises(RuntimeError, match="no `dataloader_idx` argument in `IgnoringModel2.on_validation_batch_start"): @@ -374,11 +350,9 @@ def test_invalid_dataloader_idx_raises_batch_end(tmp_path): trainer = Trainer(default_root_dir=tmp_path, fast_dev_run=True) class ExtraDataloaderIdx(BoringModel): - def on_validation_batch_end(self, outputs, batch, batch_idx, dataloader_idx): - ... + def on_validation_batch_end(self, outputs, batch, batch_idx, dataloader_idx): ... - def on_test_batch_end(self, outputs, batch, batch_idx, dataloader_idx): - ... + def on_test_batch_end(self, outputs, batch, batch_idx, dataloader_idx): ... model = ExtraDataloaderIdx() with pytest.raises( @@ -389,22 +363,18 @@ def on_test_batch_end(self, outputs, batch, batch_idx, dataloader_idx): trainer.test(model) class GoodDefault(BoringModel): - def on_validation_batch_end(self, outputs, batch, batch_idx, dataloader_idx=0): - ... + def on_validation_batch_end(self, outputs, batch, batch_idx, dataloader_idx=0): ... - def on_test_batch_end(self, outputs, batch, batch_idx, dataloader_idx=0): - ... + def on_test_batch_end(self, outputs, batch, batch_idx, dataloader_idx=0): ... model = GoodDefault() trainer.validate(model) trainer.test(model) class ExtraDlIdxOtherName(BoringModel): - def on_validation_batch_end(self, outputs, batch, batch_idx, dl_idx): - ... + def on_validation_batch_end(self, outputs, batch, batch_idx, dl_idx): ... - def on_test_batch_end(self, outputs, batch, batch_idx, dl_idx): - ... + def on_test_batch_end(self, outputs, batch, batch_idx, dl_idx): ... 
model = ExtraDlIdxOtherName() # different names are not supported @@ -414,17 +384,13 @@ def on_test_batch_end(self, outputs, batch, batch_idx, dl_idx): trainer.test(model) class MultipleDataloader(BoringModel): - def validation_step(self, batch, batch_idx, dataloader_idx=0): - ... + def validation_step(self, batch, batch_idx, dataloader_idx=0): ... - def test_step(self, batch, batch_idx, dataloader_idx=0): - ... + def test_step(self, batch, batch_idx, dataloader_idx=0): ... - def on_validation_batch_end(self, outputs, batch, batch_idx): - ... + def on_validation_batch_end(self, outputs, batch, batch_idx): ... - def on_test_batch_end(self, outputs, batch, batch_idx): - ... + def on_test_batch_end(self, outputs, batch, batch_idx): ... def val_dataloader(self): return [super().val_dataloader(), super().val_dataloader()] @@ -441,22 +407,18 @@ def test_dataloader(self): trainer.test(model) class IgnoringModel(MultipleDataloader): - def on_validation_batch_end(self, outputs, batch, batch_idx, *_): - ... + def on_validation_batch_end(self, outputs, batch, batch_idx, *_): ... - def on_test_batch_end(self, outputs, batch, batch_idx, *_): - ... + def on_test_batch_end(self, outputs, batch, batch_idx, *_): ... model = IgnoringModel() trainer.validate(model) trainer.test(model) class IgnoringModel2(MultipleDataloader): - def on_validation_batch_end(self, outputs, batch, batch_idx, **_): - ... + def on_validation_batch_end(self, outputs, batch, batch_idx, **_): ... - def on_test_batch_end(self, outputs, batch, batch_idx, **_): - ... + def on_test_batch_end(self, outputs, batch, batch_idx, **_): ... model = IgnoringModel2() with pytest.raises(RuntimeError, match="no `dataloader_idx` argument in `IgnoringModel2.on_validation_batch_end"): diff --git a/tests/tests_pytorch/loops/test_prediction_loop.py b/tests/tests_pytorch/loops/test_prediction_loop.py index 5f6b9d6328d55..d25e5936f9ce1 100644 --- a/tests/tests_pytorch/loops/test_prediction_loop.py +++ b/tests/tests_pytorch/loops/test_prediction_loop.py @@ -12,10 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. import itertools -import sys import pytest -from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2 from lightning.pytorch import LightningModule, Trainer from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset from lightning.pytorch.overrides.distributed import _IndexBatchSamplerWrapper @@ -52,11 +50,6 @@ def predict_step(self, batch, batch_idx): assert trainer.predict_loop.predictions == [] -@pytest.mark.xfail( - # https://github.com/pytorch/pytorch/issues/116056 - sys.platform == "win32" and _TORCH_GREATER_EQUAL_2_2, - reason="Windows + DDP issue in PyTorch 2.2", -) @pytest.mark.parametrize("use_distributed_sampler", [False, True]) def test_prediction_loop_batch_sampler_set_epoch_called(tmp_path, use_distributed_sampler): """Tests that set_epoch is called on the dataloader's batch sampler (if any) during prediction.""" @@ -144,23 +137,20 @@ def test_invalid_dataloader_idx_raises_step(tmp_path): trainer = Trainer(default_root_dir=tmp_path, fast_dev_run=True) class ExtraDataloaderIdx(BoringModel): - def predict_step(self, batch, batch_idx, dataloader_idx): - ... + def predict_step(self, batch, batch_idx, dataloader_idx): ... 
model = ExtraDataloaderIdx() with pytest.raises(RuntimeError, match="have included `dataloader_idx` in `ExtraDataloaderIdx.predict_step"): trainer.predict(model) class GoodDefault(BoringModel): - def predict_step(self, batch, batch_idx, dataloader_idx=0): - ... + def predict_step(self, batch, batch_idx, dataloader_idx=0): ... model = GoodDefault() trainer.predict(model) class ExtraDlIdxOtherName(BoringModel): - def predict_step(self, batch, batch_idx, dl_idx): - ... + def predict_step(self, batch, batch_idx, dl_idx): ... model = ExtraDlIdxOtherName() # different names are not supported @@ -168,8 +158,7 @@ def predict_step(self, batch, batch_idx, dl_idx): trainer.predict(model) class MultipleDataloader(BoringModel): - def predict_step(self, batch, batch_idx): - ... + def predict_step(self, batch, batch_idx): ... def predict_dataloader(self): return [super().predict_dataloader(), super().predict_dataloader()] @@ -179,15 +168,13 @@ def predict_dataloader(self): trainer.predict(model) class IgnoringModel(MultipleDataloader): - def predict_step(self, batch, batch_idx, *_): - ... + def predict_step(self, batch, batch_idx, *_): ... model = IgnoringModel() trainer.predict(model) class IgnoringModel2(MultipleDataloader): - def predict_step(self, batch, batch_idx, **_): - ... + def predict_step(self, batch, batch_idx, **_): ... model = IgnoringModel2() with pytest.raises(RuntimeError, match="no `dataloader_idx` argument in `IgnoringModel2.predict_step"): @@ -198,8 +185,7 @@ def test_invalid_dataloader_idx_raises_batch_start(tmp_path): trainer = Trainer(default_root_dir=tmp_path, fast_dev_run=True) class ExtraDataloaderIdx(BoringModel): - def on_predict_batch_start(self, batch, batch_idx, dataloader_idx): - ... + def on_predict_batch_start(self, batch, batch_idx, dataloader_idx): ... model = ExtraDataloaderIdx() with pytest.raises( @@ -208,15 +194,13 @@ def on_predict_batch_start(self, batch, batch_idx, dataloader_idx): trainer.predict(model) class GoodDefault(BoringModel): - def on_predict_batch_start(self, batch, batch_idx, dataloader_idx=0): - ... + def on_predict_batch_start(self, batch, batch_idx, dataloader_idx=0): ... model = GoodDefault() trainer.predict(model) class ExtraDlIdxOtherName(BoringModel): - def on_predict_batch_start(self, batch, batch_idx, dl_idx): - ... + def on_predict_batch_start(self, batch, batch_idx, dl_idx): ... model = ExtraDlIdxOtherName() # different names are not supported @@ -224,8 +208,7 @@ def on_predict_batch_start(self, batch, batch_idx, dl_idx): trainer.predict(model) class MultipleDataloader(BoringModel): - def on_predict_batch_start(self, batch, batch_idx): - ... + def on_predict_batch_start(self, batch, batch_idx): ... def predict_dataloader(self): return [super().predict_dataloader(), super().predict_dataloader()] @@ -237,15 +220,13 @@ def predict_dataloader(self): trainer.predict(model) class IgnoringModel(MultipleDataloader): - def on_predict_batch_start(self, batch, batch_idx, *_): - ... + def on_predict_batch_start(self, batch, batch_idx, *_): ... model = IgnoringModel() trainer.predict(model) class IgnoringModel2(MultipleDataloader): - def on_predict_batch_start(self, batch, batch_idx, **_): - ... + def on_predict_batch_start(self, batch, batch_idx, **_): ... 
model = IgnoringModel2() with pytest.raises(RuntimeError, match="no `dataloader_idx` argument in `IgnoringModel2.on_predict_batch_start"): @@ -256,8 +237,7 @@ def test_invalid_dataloader_idx_raises_batch_end(tmp_path): trainer = Trainer(default_root_dir=tmp_path, fast_dev_run=True) class ExtraDataloaderIdx(BoringModel): - def on_predict_batch_end(self, outputs, batch, batch_idx, dataloader_idx): - ... + def on_predict_batch_end(self, outputs, batch, batch_idx, dataloader_idx): ... model = ExtraDataloaderIdx() with pytest.raises( @@ -266,15 +246,13 @@ def on_predict_batch_end(self, outputs, batch, batch_idx, dataloader_idx): trainer.predict(model) class GoodDefault(BoringModel): - def on_predict_batch_end(self, outputs, batch, batch_idx, dataloader_idx=0): - ... + def on_predict_batch_end(self, outputs, batch, batch_idx, dataloader_idx=0): ... model = GoodDefault() trainer.predict(model) class ExtraDlIdxOtherName(BoringModel): - def on_predict_batch_end(self, outputs, batch, batch_idx, dl_idx): - ... + def on_predict_batch_end(self, outputs, batch, batch_idx, dl_idx): ... model = ExtraDlIdxOtherName() # different names are not supported @@ -282,8 +260,7 @@ def on_predict_batch_end(self, outputs, batch, batch_idx, dl_idx): trainer.predict(model) class MultipleDataloader(BoringModel): - def on_predict_batch_end(self, outputs, batch, batch_idx): - ... + def on_predict_batch_end(self, outputs, batch, batch_idx): ... def predict_dataloader(self): return [super().predict_dataloader(), super().predict_dataloader()] @@ -293,15 +270,13 @@ def predict_dataloader(self): trainer.predict(model) class IgnoringModel(MultipleDataloader): - def on_predict_batch_end(self, outputs, batch, batch_idx, *_): - ... + def on_predict_batch_end(self, outputs, batch, batch_idx, *_): ... model = IgnoringModel() trainer.predict(model) class IgnoringModel2(MultipleDataloader): - def on_predict_batch_end(self, outputs, batch, batch_idx, **_): - ... + def on_predict_batch_end(self, outputs, batch, batch_idx, **_): ... 
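# --- illustrative sketch (drawn from the predict-loop tests above, not an addition to the patch):
# with multiple predict dataloaders, the step and batch-level hooks must either name
# `dataloader_idx` exactly (with a default so the single-dataloader case still works) or absorb it
# positionally via *args; a differently named parameter or a **kwargs-only signature raises the
# RuntimeError these tests assert on.
from lightning.pytorch import LightningModule

class CompliantPredictModule(LightningModule):
    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        # receives the index of the active predict dataloader
        ...

    def on_predict_batch_end(self, outputs, batch, batch_idx, dataloader_idx=0):
        # the same rule applies to on_predict_batch_start/on_predict_batch_end
        ...
# --- end of sketch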
model = IgnoringModel2() with pytest.raises(RuntimeError, match="no `dataloader_idx` argument in `IgnoringModel2.on_predict_batch_end"): diff --git a/tests/tests_pytorch/loops/test_training_epoch_loop.py b/tests/tests_pytorch/loops/test_training_epoch_loop.py index 06f27ab322530..16ed3842e3a96 100644 --- a/tests/tests_pytorch/loops/test_training_epoch_loop.py +++ b/tests/tests_pytorch/loops/test_training_epoch_loop.py @@ -15,11 +15,15 @@ from unittest.mock import Mock, patch import pytest +import torch +from lightning.fabric.utilities.warnings import PossibleUserWarning +from lightning.pytorch.callbacks import ModelCheckpoint from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.trainer.trainer import Trainer +from lightning_utilities.test.warning import no_warning_call -def test_no_val_on_train_epoch_loop_restart(tmpdir): +def test_no_val_on_train_epoch_loop_restart(tmp_path): """Test that training validation loop doesn't get triggered at the beginning of a restart.""" trainer_kwargs = { "max_epochs": 1, @@ -31,7 +35,7 @@ def test_no_val_on_train_epoch_loop_restart(tmpdir): trainer = Trainer(**trainer_kwargs) model = BoringModel() trainer.fit(model) - ckpt_path = str(tmpdir / "last.ckpt") + ckpt_path = str(tmp_path / "last.ckpt") trainer.save_checkpoint(ckpt_path) trainer_kwargs["max_epochs"] = 2 @@ -157,3 +161,59 @@ def optimizer_step(self, epoch, batch_idx, *args, **kwargs): model = MyModel() trainer.fit(model) assert model.last_batch_idx == 3 + + +def test_resume_mid_epoch_warning(tmp_path): + """Test that resuming from a mid-epoch checkpoint raises a warning unless the dataloader is stateful.""" + + class NotStatefulIterable: + def __init__(self): + self.index = 0 + + def __iter__(self): + for i in range(self.index, len(self)): + yield torch.ones(2, 32) * i + + def __len__(self): + return 3 + + class StatefulIterable(NotStatefulIterable): + def state_dict(self): + return {"index": self.index} + + def load_state_dict(self, state_dict): + self.index = state_dict["index"] + + trainer_kwargs = { + "default_root_dir": tmp_path, + "accelerator": "cpu", + "max_epochs": 1, + "enable_model_summary": False, + "enable_progress_bar": False, + "logger": False, + } + + def train_and_resume(dataloader, resume_step, expected_warning): + # Initial training + checkpoint_dir = tmp_path / "checkpoints" + trainer = Trainer( + **trainer_kwargs, + callbacks=ModelCheckpoint(dirpath=checkpoint_dir, every_n_train_steps=1, save_top_k=-1), + ) + trainer.fit(BoringModel(), dataloader) + + # Resume + trainer = Trainer(**trainer_kwargs, enable_checkpointing=False) + resume_from = checkpoint_dir / f"epoch=0-step={resume_step}.ckpt" + warn_assert = pytest.warns if expected_warning else no_warning_call + with warn_assert(PossibleUserWarning, match="resuming from a checkpoint that ended before"): + trainer.fit(BoringModel(), dataloader, ckpt_path=resume_from) + + # Resume mid-epoch, no stateful dataloader -> warning + train_and_resume(dataloader=NotStatefulIterable(), resume_step=1, expected_warning=True) + + # Resume end-of-epoch, no stateful dataloader -> no warning + train_and_resume(dataloader=NotStatefulIterable(), resume_step=3, expected_warning=False) + + # Resume mid-epoch, stateful dataloader -> no warning + train_and_resume(dataloader=StatefulIterable(), resume_step=1, expected_warning=False) diff --git a/tests/tests_pytorch/models/test_amp.py b/tests/tests_pytorch/models/test_amp.py index 5608c02026072..d51b139d6118d 100644 --- a/tests/tests_pytorch/models/test_amp.py +++ 
b/tests/tests_pytorch/models/test_amp.py @@ -12,13 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. import os -import sys from unittest import mock import pytest import torch from lightning.fabric.plugins.environments import SLURMEnvironment -from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2 from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset from torch.utils.data import DataLoader @@ -55,16 +53,7 @@ def _assert_autocast_enabled(self): [ ("single_device", "16-mixed", 1), ("single_device", "bf16-mixed", 1), - pytest.param( - "ddp_spawn", - "16-mixed", - 2, - marks=pytest.mark.xfail( - # https://github.com/pytorch/pytorch/issues/116056 - sys.platform == "win32" and _TORCH_GREATER_EQUAL_2_2, - reason="Windows + DDP issue in PyTorch 2.2", - ), - ), + ("ddp_spawn", "16-mixed", 2), pytest.param("ddp_spawn", "bf16-mixed", 2, marks=RunIf(skip_windows=True)), ], ) diff --git a/tests/tests_pytorch/models/test_hooks.py b/tests/tests_pytorch/models/test_hooks.py index 9d02d37368c46..727f08b3e2d02 100644 --- a/tests/tests_pytorch/models/test_hooks.py +++ b/tests/tests_pytorch/models/test_hooks.py @@ -263,52 +263,50 @@ def _auto_train_batch(trainer, model, batches, device, current_epoch=0, current_ using_deepspeed = kwargs.get("strategy") == "deepspeed" out = [] for i in range(current_batch, batches): - out.extend( - [ - {"name": "on_before_batch_transfer", "args": (ANY, 0)}, - {"name": "transfer_batch_to_device", "args": (ANY, device, 0)}, - {"name": "on_after_batch_transfer", "args": (ANY, 0)}, - {"name": "Callback.on_train_batch_start", "args": (trainer, model, ANY, i)}, - {"name": "on_train_batch_start", "args": (ANY, i)}, - {"name": "forward", "args": (ANY,)}, - {"name": "training_step", "args": (ANY, i)}, - {"name": "Callback.on_before_zero_grad", "args": (trainer, model, ANY)}, - {"name": "on_before_zero_grad", "args": (ANY,)}, - {"name": "optimizer_zero_grad", "args": (current_epoch, i, ANY)}, - {"name": "Callback.on_before_backward", "args": (trainer, model, ANY)}, - {"name": "on_before_backward", "args": (ANY,)}, - # DeepSpeed handles backward internally - *([{"name": "backward", "args": (ANY,)}] if not using_deepspeed else []), - {"name": "Callback.on_after_backward", "args": (trainer, model)}, - {"name": "on_after_backward"}, - # note: unscaling happens here in the case of AMP - {"name": "Callback.on_before_optimizer_step", "args": (trainer, model, ANY)}, - {"name": "on_before_optimizer_step", "args": (ANY,)}, - { - "name": "clip_gradients", - "args": (ANY,), - "kwargs": {"gradient_clip_val": None, "gradient_clip_algorithm": None}, - }, - { - "name": "configure_gradient_clipping", - "args": (ANY,), - "kwargs": {"gradient_clip_val": None, "gradient_clip_algorithm": None}, - }, - # this is after because it refers to the `LightningModule.optimizer_step` hook which encapsulates - # the actual call to `Precision.optimizer_step` - { - "name": "optimizer_step", - "args": (current_epoch, i, ANY, ANY), - }, - *( - [{"name": "lr_scheduler_step", "args": (ANY, None)}] - if i == (trainer.num_training_batches - 1) - else [] - ), - {"name": "Callback.on_train_batch_end", "args": (trainer, model, {"loss": ANY}, ANY, i)}, - {"name": "on_train_batch_end", "args": ({"loss": ANY}, ANY, i)}, - ] - ) + out.extend([ + {"name": "on_before_batch_transfer", "args": (ANY, 0)}, + {"name": "transfer_batch_to_device", "args": (ANY, device, 0)}, + {"name": 
"on_after_batch_transfer", "args": (ANY, 0)}, + {"name": "Callback.on_train_batch_start", "args": (trainer, model, ANY, i)}, + {"name": "on_train_batch_start", "args": (ANY, i)}, + {"name": "forward", "args": (ANY,)}, + {"name": "training_step", "args": (ANY, i)}, + {"name": "Callback.on_before_zero_grad", "args": (trainer, model, ANY)}, + {"name": "on_before_zero_grad", "args": (ANY,)}, + {"name": "optimizer_zero_grad", "args": (current_epoch, i, ANY)}, + {"name": "Callback.on_before_backward", "args": (trainer, model, ANY)}, + {"name": "on_before_backward", "args": (ANY,)}, + # DeepSpeed handles backward internally + *([{"name": "backward", "args": (ANY,)}] if not using_deepspeed else []), + {"name": "Callback.on_after_backward", "args": (trainer, model)}, + {"name": "on_after_backward"}, + # note: unscaling happens here in the case of AMP + {"name": "Callback.on_before_optimizer_step", "args": (trainer, model, ANY)}, + {"name": "on_before_optimizer_step", "args": (ANY,)}, + { + "name": "clip_gradients", + "args": (ANY,), + "kwargs": {"gradient_clip_val": None, "gradient_clip_algorithm": None}, + }, + { + "name": "configure_gradient_clipping", + "args": (ANY,), + "kwargs": {"gradient_clip_val": None, "gradient_clip_algorithm": None}, + }, + # this is after because it refers to the `LightningModule.optimizer_step` hook which encapsulates + # the actual call to `Precision.optimizer_step` + { + "name": "optimizer_step", + "args": (current_epoch, i, ANY, ANY), + }, + *( + [{"name": "lr_scheduler_step", "args": (ANY, None)}] + if i == (trainer.num_training_batches - 1) + else [] + ), + {"name": "Callback.on_train_batch_end", "args": (trainer, model, {"loss": ANY}, ANY, i)}, + {"name": "on_train_batch_end", "args": ({"loss": ANY}, ANY, i)}, + ]) return out @staticmethod @@ -316,30 +314,28 @@ def _manual_train_batch(trainer, model, batches, device, **kwargs): using_deepspeed = kwargs.get("strategy") == "deepspeed" out = [] for i in range(batches): - out.extend( - [ - {"name": "on_before_batch_transfer", "args": (ANY, 0)}, - {"name": "transfer_batch_to_device", "args": (ANY, device, 0)}, - {"name": "on_after_batch_transfer", "args": (ANY, 0)}, - {"name": "Callback.on_train_batch_start", "args": (trainer, model, ANY, i)}, - {"name": "on_train_batch_start", "args": (ANY, i)}, - {"name": "forward", "args": (ANY,)}, - {"name": "Callback.on_before_backward", "args": (trainer, model, ANY)}, - {"name": "on_before_backward", "args": (ANY,)}, - # DeepSpeed handles backward internally - *([{"name": "backward", "args": (ANY,)}] if not using_deepspeed else []), - {"name": "Callback.on_after_backward", "args": (trainer, model)}, - {"name": "on_after_backward"}, - # `manual_backward` calls the previous 3 - {"name": "manual_backward", "args": (ANY,)}, - {"name": "closure"}, - {"name": "Callback.on_before_optimizer_step", "args": (trainer, model, ANY)}, - {"name": "on_before_optimizer_step", "args": (ANY,)}, - {"name": "training_step", "args": (ANY, i)}, - {"name": "Callback.on_train_batch_end", "args": (trainer, model, {"loss": ANY}, ANY, i)}, - {"name": "on_train_batch_end", "args": ({"loss": ANY}, ANY, i)}, - ] - ) + out.extend([ + {"name": "on_before_batch_transfer", "args": (ANY, 0)}, + {"name": "transfer_batch_to_device", "args": (ANY, device, 0)}, + {"name": "on_after_batch_transfer", "args": (ANY, 0)}, + {"name": "Callback.on_train_batch_start", "args": (trainer, model, ANY, i)}, + {"name": "on_train_batch_start", "args": (ANY, i)}, + {"name": "forward", "args": (ANY,)}, + {"name": 
"Callback.on_before_backward", "args": (trainer, model, ANY)}, + {"name": "on_before_backward", "args": (ANY,)}, + # DeepSpeed handles backward internally + *([{"name": "backward", "args": (ANY,)}] if not using_deepspeed else []), + {"name": "Callback.on_after_backward", "args": (trainer, model)}, + {"name": "on_after_backward"}, + # `manual_backward` calls the previous 3 + {"name": "manual_backward", "args": (ANY,)}, + {"name": "closure"}, + {"name": "Callback.on_before_optimizer_step", "args": (trainer, model, ANY)}, + {"name": "on_before_optimizer_step", "args": (ANY,)}, + {"name": "training_step", "args": (ANY, i)}, + {"name": "Callback.on_train_batch_end", "args": (trainer, model, {"loss": ANY}, ANY, i)}, + {"name": "on_train_batch_end", "args": ({"loss": ANY}, ANY, i)}, + ]) return out @staticmethod @@ -357,55 +353,47 @@ def _eval_batch(fn, trainer, model, batches, key, device): out = [] outputs = {key: ANY} for i in range(batches): - out.extend( - [ - {"name": "on_before_batch_transfer", "args": (ANY, 0)}, - {"name": "transfer_batch_to_device", "args": (ANY, device, 0)}, - {"name": "on_after_batch_transfer", "args": (ANY, 0)}, - {"name": f"Callback.on_{fn}_batch_start", "args": (trainer, model, ANY, i)}, - {"name": f"on_{fn}_batch_start", "args": (ANY, i)}, - {"name": "forward", "args": (ANY,)}, - {"name": f"{fn}_step", "args": (ANY, i)}, - {"name": f"Callback.on_{fn}_batch_end", "args": (trainer, model, outputs, ANY, i)}, - {"name": f"on_{fn}_batch_end", "args": (outputs, ANY, i)}, - ] - ) + out.extend([ + {"name": "on_before_batch_transfer", "args": (ANY, 0)}, + {"name": "transfer_batch_to_device", "args": (ANY, device, 0)}, + {"name": "on_after_batch_transfer", "args": (ANY, 0)}, + {"name": f"Callback.on_{fn}_batch_start", "args": (trainer, model, ANY, i)}, + {"name": f"on_{fn}_batch_start", "args": (ANY, i)}, + {"name": "forward", "args": (ANY,)}, + {"name": f"{fn}_step", "args": (ANY, i)}, + {"name": f"Callback.on_{fn}_batch_end", "args": (trainer, model, outputs, ANY, i)}, + {"name": f"on_{fn}_batch_end", "args": (outputs, ANY, i)}, + ]) return out @staticmethod def _predict_batch(trainer, model, batches, device): out = [] for i in range(batches): - out.extend( - [ - {"name": "on_before_batch_transfer", "args": (ANY, 0)}, - {"name": "transfer_batch_to_device", "args": (ANY, device, 0)}, - {"name": "on_after_batch_transfer", "args": (ANY, 0)}, - {"name": "Callback.on_predict_batch_start", "args": (trainer, model, ANY, i)}, - {"name": "on_predict_batch_start", "args": (ANY, i)}, - {"name": "forward", "args": (ANY,)}, - {"name": "predict_step", "args": (ANY, i)}, - {"name": "Callback.on_predict_batch_end", "args": (trainer, model, ANY, ANY, i)}, - {"name": "on_predict_batch_end", "args": (ANY, ANY, i)}, - ] - ) + out.extend([ + {"name": "on_before_batch_transfer", "args": (ANY, 0)}, + {"name": "transfer_batch_to_device", "args": (ANY, device, 0)}, + {"name": "on_after_batch_transfer", "args": (ANY, 0)}, + {"name": "Callback.on_predict_batch_start", "args": (trainer, model, ANY, i)}, + {"name": "on_predict_batch_start", "args": (ANY, i)}, + {"name": "forward", "args": (ANY,)}, + {"name": "predict_step", "args": (ANY, i)}, + {"name": "Callback.on_predict_batch_end", "args": (trainer, model, ANY, ANY, i)}, + {"name": "on_predict_batch_end", "args": (ANY, ANY, i)}, + ]) return out # override so that it gets called - def configure_model(self): - ... + def configure_model(self): ... # override so that it gets called - def on_validation_model_train(self): - ... 
+ def on_validation_model_train(self): ... # override so that it gets called - def on_test_model_train(self): - ... + def on_test_model_train(self): ... # override so that it gets called - def on_predict_model_train(self): - ... + def on_predict_model_train(self): ... @pytest.mark.parametrize( diff --git a/tests/tests_pytorch/models/test_hparams.py b/tests/tests_pytorch/models/test_hparams.py index fd8b417da5621..c22ab1228575a 100644 --- a/tests/tests_pytorch/models/test_hparams.py +++ b/tests/tests_pytorch/models/test_hparams.py @@ -258,8 +258,7 @@ def __init__(self, test_arg, test_arg2): def test_class_nesting(): class MyModule(LightningModule): - def forward(self): - ... + def forward(self): ... # make sure PL modules are always nn.Module a = MyModule() @@ -349,8 +348,7 @@ def __init__(self, *args, dict_conf=OmegaConf.create({"my_param": "something"}), else: - class DictConfSubClassBoringModel: - ... + class DictConfSubClassBoringModel: ... @pytest.mark.parametrize( diff --git a/tests/tests_pytorch/models/test_torchscript.py b/tests/tests_pytorch/models/test_torchscript.py index 806b4db998af1..17bbaaa77997d 100644 --- a/tests/tests_pytorch/models/test_torchscript.py +++ b/tests/tests_pytorch/models/test_torchscript.py @@ -139,8 +139,7 @@ def test_torchscript_save_load_custom_filesystem(tmpdir, modelclass): _DUMMY_PRFEIX = "dummy" _PREFIX_SEPARATOR = "://" - class DummyFileSystem(LocalFileSystem): - ... + class DummyFileSystem(LocalFileSystem): ... fsspec.register_implementation(_DUMMY_PRFEIX, DummyFileSystem, clobber=True) diff --git a/tests/tests_pytorch/plugins/precision/test_bitsandbytes.py b/tests/tests_pytorch/plugins/precision/test_bitsandbytes.py index e65de4255c915..a88e38d6303e4 100644 --- a/tests/tests_pytorch/plugins/precision/test_bitsandbytes.py +++ b/tests/tests_pytorch/plugins/precision/test_bitsandbytes.py @@ -49,8 +49,7 @@ class MyModel(LightningModule): def configure_model(self): self.l = torch.nn.Linear(1, 3) - def test_step(self, *_): - ... + def test_step(self, *_): ... model = MyModel() trainer.test(model, [0]) diff --git a/tests/tests_pytorch/plugins/precision/test_half.py b/tests/tests_pytorch/plugins/precision/test_half.py index cc6bd8af950fe..d51392a00e3d0 100644 --- a/tests/tests_pytorch/plugins/precision/test_half.py +++ b/tests/tests_pytorch/plugins/precision/test_half.py @@ -90,8 +90,7 @@ def configure_model(self): # this is under the `module_init_context` assert self.l.weight.dtype == expected_dtype - def test_step(self, *_): - ... + def test_step(self, *_): ... model = MyModel() trainer = Trainer(barebones=True, precision=precision) diff --git a/tests/tests_pytorch/plugins/precision/test_transformer_engine.py b/tests/tests_pytorch/plugins/precision/test_transformer_engine.py index ec799ea5c358a..7c92ff47d909a 100644 --- a/tests/tests_pytorch/plugins/precision/test_transformer_engine.py +++ b/tests/tests_pytorch/plugins/precision/test_transformer_engine.py @@ -64,8 +64,7 @@ def configure_model(self): self.l = torch.nn.Linear(8, 16) assert self.l.weight.dtype == torch.float16 - def test_step(self, *_): - ... + def test_step(self, *_): ... 
model = MyModel() trainer = Trainer(barebones=True, precision="transformer-engine-float16") diff --git a/tests/tests_pytorch/serve/test_servable_module_validator.py b/tests/tests_pytorch/serve/test_servable_module_validator.py index 9d299da460770..578619ece75eb 100644 --- a/tests/tests_pytorch/serve/test_servable_module_validator.py +++ b/tests/tests_pytorch/serve/test_servable_module_validator.py @@ -1,9 +1,7 @@ -import sys from typing import Dict import pytest import torch -from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2 from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.serve.servable_module_validator import ServableModule, ServableModuleValidator @@ -38,11 +36,6 @@ def test_servable_module_validator(): callback.on_train_start(Trainer(accelerator="cpu"), model) -@pytest.mark.xfail( - # https://github.com/pytorch/pytorch/issues/116056 - sys.platform == "win32" and _TORCH_GREATER_EQUAL_2_2, - reason="Windows + DDP issue in PyTorch 2.2", -) @pytest.mark.flaky(reruns=3) def test_servable_module_validator_with_trainer(tmpdir): callback = ServableModuleValidator() diff --git a/tests/tests_pytorch/strategies/launchers/test_multiprocessing.py b/tests/tests_pytorch/strategies/launchers/test_multiprocessing.py index d70a4776b81bd..beab1e20f46b3 100644 --- a/tests/tests_pytorch/strategies/launchers/test_multiprocessing.py +++ b/tests/tests_pytorch/strategies/launchers/test_multiprocessing.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import os -import sys from multiprocessing import Process from unittest import mock from unittest.mock import ANY, Mock, call, patch @@ -20,7 +19,6 @@ import pytest import torch from lightning.fabric.plugins import ClusterEnvironment -from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2 from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.strategies import DDPStrategy @@ -196,11 +194,6 @@ def on_fit_start(self) -> None: assert torch.equal(self.layer.weight.data, self.tied_layer.weight.data) -@pytest.mark.xfail( - # https://github.com/pytorch/pytorch/issues/116056 - sys.platform == "win32" and _TORCH_GREATER_EQUAL_2_2, - reason="Windows + DDP issue in PyTorch 2.2", -) def test_memory_sharing_disabled(): """Test that the multiprocessing launcher disables memory sharing on model parameters and buffers to avoid race conditions on model updates.""" @@ -221,11 +214,6 @@ def test_check_for_missing_main_guard(): launcher.launch(function=Mock()) -@pytest.mark.xfail( - # https://github.com/pytorch/pytorch/issues/116056 - sys.platform == "win32" and _TORCH_GREATER_EQUAL_2_2, - reason="Windows + DDP issue in PyTorch 2.2", -) def test_fit_twice_raises(): model = BoringModel() trainer = Trainer( diff --git a/tests/tests_pytorch/strategies/scripts/cli_script.py b/tests/tests_pytorch/strategies/scripts/cli_script.py index 757a3c75ce34c..e63727d8f3300 100644 --- a/tests/tests_pytorch/strategies/scripts/cli_script.py +++ b/tests/tests_pytorch/strategies/scripts/cli_script.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""A trivial script that wraps a LightningCLI around the BoringModel and BoringDataModule.""" + from lightning.pytorch.cli import LightningCLI from lightning.pytorch.demos.boring_classes import BoringDataModule, BoringModel diff --git a/tests/tests_pytorch/strategies/test_fsdp.py b/tests/tests_pytorch/strategies/test_fsdp.py index a8da202e1de8e..f5513f49fd82c 100644 --- a/tests/tests_pytorch/strategies/test_fsdp.py +++ b/tests/tests_pytorch/strategies/test_fsdp.py @@ -17,6 +17,7 @@ from lightning.fabric.utilities.imports import ( _TORCH_GREATER_EQUAL_2_0, _TORCH_GREATER_EQUAL_2_1, + _TORCH_GREATER_EQUAL_2_2, ) from lightning.fabric.utilities.load import _load_distributed_checkpoint from lightning.pytorch import Trainer @@ -801,13 +802,12 @@ def test_save_checkpoint_storage_options(tmp_path): @RunIf(min_torch="2.0.0") -@mock.patch("torch.distributed.checkpoint.save_state_dict", return_value=MagicMock()) @mock.patch("lightning.pytorch.strategies.fsdp.FSDPStrategy.broadcast", lambda _, x: x) -@mock.patch("lightning.pytorch.strategies.fsdp._get_full_state_dict_context", return_value=MagicMock()) -@mock.patch("lightning.pytorch.strategies.fsdp._get_sharded_state_dict_context", return_value=MagicMock()) -@mock.patch("lightning.fabric.plugins.io.torch_io._atomic_save", return_value=Mock()) -@mock.patch("lightning.pytorch.strategies.fsdp.shutil", return_value=MagicMock()) -def test_fsdp_save_checkpoint_path_exists(shutil_mock, torch_save_mock, __, ___, ____, tmp_path): +@mock.patch("lightning.pytorch.strategies.fsdp._get_full_state_dict_context") +@mock.patch("lightning.pytorch.strategies.fsdp._get_sharded_state_dict_context") +@mock.patch("lightning.fabric.plugins.io.torch_io._atomic_save") +@mock.patch("lightning.pytorch.strategies.fsdp.shutil") +def test_fsdp_save_checkpoint_path_exists(shutil_mock, torch_save_mock, __, ___, tmp_path): strategy = FSDPStrategy(state_dict_type="full") # state_dict_type='full', path exists, path is not a sharded checkpoint: error @@ -839,13 +839,20 @@ def test_fsdp_save_checkpoint_path_exists(shutil_mock, torch_save_mock, __, ___, strategy = FSDPStrategy(state_dict_type="sharded") + save_mock = mock.patch( + "torch.distributed.checkpoint.save" + if _TORCH_GREATER_EQUAL_2_2 + else "torch.distributed.checkpoint.save_state_dict" + ) + # state_dict_type='sharded', path exists, path is a folder: no error (overwrite) path = tmp_path / "not-empty-2" path.mkdir() (path / "file").touch() model = Mock(spec=FullyShardedDataParallel) model.modules.return_value = [model] - strategy.save_checkpoint({"state_dict": {}, "optimizer_states": {"": {}}}, filepath=path) + with save_mock: + strategy.save_checkpoint({"state_dict": {}, "optimizer_states": {"": {}}}, filepath=path) assert (path / "file").exists() # state_dict_type='sharded', path exists, path is a file: no error (overwrite) @@ -853,7 +860,8 @@ def test_fsdp_save_checkpoint_path_exists(shutil_mock, torch_save_mock, __, ___, path.touch() model = Mock(spec=FullyShardedDataParallel) model.modules.return_value = [model] - strategy.save_checkpoint({"state_dict": {}, "optimizer_states": {"": {}}}, filepath=path) + with save_mock: + strategy.save_checkpoint({"state_dict": {}, "optimizer_states": {"": {}}}, filepath=path) assert path.is_dir() diff --git a/tests/tests_pytorch/strategies/test_single_device.py b/tests/tests_pytorch/strategies/test_single_device.py index 29d842f17d3b2..5b10c4a17d726 100644 --- a/tests/tests_pytorch/strategies/test_single_device.py +++ b/tests/tests_pytorch/strategies/test_single_device.py @@ -62,8 
+62,7 @@ def test_single_gpu(): assert cuda_memory < model.start_cuda_memory -class MockOptimizer: - ... +class MockOptimizer: ... def test_strategy_pickle(): diff --git a/tests/tests_pytorch/test_cli.py b/tests/tests_pytorch/test_cli.py index fc3793a07b6ed..71b2ff45cb328 100644 --- a/tests/tests_pytorch/test_cli.py +++ b/tests/tests_pytorch/test_cli.py @@ -692,11 +692,9 @@ def __init__(self): super().__init__() self.layer = torch.nn.Linear(32, 2) - def training_step(self, *_): - ... + def training_step(self, *_): ... - def train_dataloader(self): - ... + def train_dataloader(self): ... # did not define `configure_optimizers` @@ -1312,16 +1310,15 @@ def test_lightning_cli_reinstantiate_trainer(): assert cli.trainer.max_epochs is None - class TestCallback(Callback): - ... + class TestCallback(Callback): ... # make sure a new trainer can be easily created trainer = cli.instantiate_trainer(max_epochs=123, callbacks=[TestCallback()]) # the new config is used assert trainer.max_epochs == 123 - assert {c.__class__ for c in trainer.callbacks} == {c.__class__ for c in cli.trainer.callbacks}.union( - {TestCallback} - ) + assert {c.__class__ for c in trainer.callbacks} == {c.__class__ for c in cli.trainer.callbacks}.union({ + TestCallback + }) # the existing config is not updated assert cli.config_init["trainer"]["max_epochs"] is None @@ -1539,8 +1536,7 @@ class MyTrainer(Trainer): def __init__(self): super().__init__() - class MyCallback(Callback): - ... + class MyCallback(Callback): ... match = "MyTrainer` class does not expose the `callbacks" with mock.patch("sys.argv", ["any.py"]), pytest.warns(UserWarning, match=match): diff --git a/tests/tests_pytorch/trainer/connectors/test_callback_connector.py b/tests/tests_pytorch/trainer/connectors/test_callback_connector.py index 26e8adea493b9..a1c2426452429 100644 --- a/tests/tests_pytorch/trainer/connectors/test_callback_connector.py +++ b/tests/tests_pytorch/trainer/connectors/test_callback_connector.py @@ -224,8 +224,7 @@ def _attach_callbacks(trainer_callbacks, model_callbacks): ) assert trainer.callbacks == [progress_bar, early_stopping1, lr_monitor, grad_accumulation, early_stopping2] - class CustomProgressBar(TQDMProgressBar): - ... + class CustomProgressBar(TQDMProgressBar): ... custom_progress_bar = CustomProgressBar() # a custom callback that overrides ours diff --git a/tests/tests_pytorch/trainer/connectors/test_data_connector.py b/tests/tests_pytorch/trainer/connectors/test_data_connector.py index e9684657dd3c5..a820a3d6ee786 100644 --- a/tests/tests_pytorch/trainer/connectors/test_data_connector.py +++ b/tests/tests_pytorch/trainer/connectors/test_data_connector.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import sys from re import escape from typing import Sized from unittest import mock @@ -20,7 +19,6 @@ import lightning.fabric import pytest from lightning.fabric.utilities.distributed import DistributedSamplerWrapper -from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2 from lightning.fabric.utilities.warnings import PossibleUserWarning from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringDataModule, BoringModel, RandomDataset @@ -125,11 +123,6 @@ def on_train_end(self): self.ctx.__exit__(None, None, None) -@pytest.mark.xfail( - # https://github.com/pytorch/pytorch/issues/116056 - sys.platform == "win32" and _TORCH_GREATER_EQUAL_2_2, - reason="Windows + DDP issue in PyTorch 2.2", -) @pytest.mark.parametrize("num_workers", [0, 1, 2]) def test_dataloader_persistent_workers_performance_warning(num_workers, tmp_path): """Test that when the multiprocessing start-method is 'spawn', we recommend setting `persistent_workers=True`.""" @@ -430,18 +423,17 @@ def test_error_raised_with_float_limited_eval_batches(): (DataLoader(dataset=RandomDataset(32, 64), sampler=list(range(64))), False), (CombinedLoader(DataLoader(dataset=RandomDataset(32, 64), shuffle=True)), True), ( - CombinedLoader( - [DataLoader(dataset=RandomDataset(32, 64)), DataLoader(dataset=RandomDataset(32, 64), shuffle=True)] - ), + CombinedLoader([ + DataLoader(dataset=RandomDataset(32, 64)), + DataLoader(dataset=RandomDataset(32, 64), shuffle=True), + ]), True, ), ( - CombinedLoader( - { - "dl1": DataLoader(dataset=RandomDataset(32, 64)), - "dl2": DataLoader(dataset=RandomDataset(32, 64), shuffle=True), - } - ), + CombinedLoader({ + "dl1": DataLoader(dataset=RandomDataset(32, 64)), + "dl2": DataLoader(dataset=RandomDataset(32, 64), shuffle=True), + }), True, ), ], diff --git a/tests/tests_pytorch/trainer/flags/test_inference_mode.py b/tests/tests_pytorch/trainer/flags/test_inference_mode.py index 969188354edb0..c262f0ca33806 100644 --- a/tests/tests_pytorch/trainer/flags/test_inference_mode.py +++ b/tests/tests_pytorch/trainer/flags/test_inference_mode.py @@ -54,8 +54,7 @@ def test_no_grad_context(): class Foo: @_no_grad_context - def run(self): - ... + def run(self): ... f = Foo() with pytest.raises(TypeError, match="Foo` needs to be a Loop"): @@ -63,8 +62,7 @@ def run(self): class Foo(_Loop): @_no_grad_context - def run(self): - ... + def run(self): ... f = Foo(trainer) with pytest.raises(TypeError, match="Foo.inference_mode` needs to be defined"): @@ -76,8 +74,7 @@ def __init__(self): self.inference_mode = False @_no_grad_context - def run(self): - ... + def run(self): ... f = Foo() with mock.patch("torch.no_grad") as no_grad_mock: diff --git a/tests/tests_pytorch/trainer/logging_/test_eval_loop_logging.py b/tests/tests_pytorch/trainer/logging_/test_eval_loop_logging.py index 75ee16c70c2da..7b732057cd029 100644 --- a/tests/tests_pytorch/trainer/logging_/test_eval_loop_logging.py +++ b/tests/tests_pytorch/trainer/logging_/test_eval_loop_logging.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Test logging in the evaluation loop.""" + import collections import itertools import os diff --git a/tests/tests_pytorch/trainer/logging_/test_logger_connector.py b/tests/tests_pytorch/trainer/logging_/test_logger_connector.py index d7d29a857a9be..cf8590feb2998 100644 --- a/tests/tests_pytorch/trainer/logging_/test_logger_connector.py +++ b/tests/tests_pytorch/trainer/logging_/test_logger_connector.py @@ -236,31 +236,27 @@ def test_fx_validator_integration(tmpdir): ) trainer.fit(model) - not_supported.update( - { - # `lightning_module` ref is now present from the `fit` call - "test_dataloader": "You can't", - "on_test_model_eval": "You can't", - "on_test_model_train": "You can't", - "on_test_end": "You can't", - } - ) + not_supported.update({ + # `lightning_module` ref is now present from the `fit` call + "test_dataloader": "You can't", + "on_test_model_eval": "You can't", + "on_test_model_train": "You can't", + "on_test_end": "You can't", + }) trainer.test(model, verbose=False) not_supported.update({k: "result collection is not registered yet" for k in not_supported}) - not_supported.update( - { - "predict_dataloader": "result collection is not registered yet", - "on_predict_model_eval": "result collection is not registered yet", - "on_predict_start": "result collection is not registered yet", - "on_predict_epoch_start": "result collection is not registered yet", - "on_predict_batch_start": "result collection is not registered yet", - "predict_step": "result collection is not registered yet", - "on_predict_batch_end": "result collection is not registered yet", - "on_predict_epoch_end": "result collection is not registered yet", - "on_predict_end": "result collection is not registered yet", - } - ) + not_supported.update({ + "predict_dataloader": "result collection is not registered yet", + "on_predict_model_eval": "result collection is not registered yet", + "on_predict_start": "result collection is not registered yet", + "on_predict_epoch_start": "result collection is not registered yet", + "on_predict_batch_start": "result collection is not registered yet", + "predict_step": "result collection is not registered yet", + "on_predict_batch_end": "result collection is not registered yet", + "on_predict_epoch_end": "result collection is not registered yet", + "on_predict_end": "result collection is not registered yet", + }) trainer.predict(model) diff --git a/tests/tests_pytorch/trainer/logging_/test_loop_logging.py b/tests/tests_pytorch/trainer/logging_/test_loop_logging.py index c3a6b9abe9c9a..e43f3b8d6bffd 100644 --- a/tests/tests_pytorch/trainer/logging_/test_loop_logging.py +++ b/tests/tests_pytorch/trainer/logging_/test_loop_logging.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Test logging in the training loop.""" + import inspect from unittest import mock from unittest.mock import ANY diff --git a/tests/tests_pytorch/trainer/logging_/test_train_loop_logging.py b/tests/tests_pytorch/trainer/logging_/test_train_loop_logging.py index bb17220610366..74d65ce3336f2 100644 --- a/tests/tests_pytorch/trainer/logging_/test_train_loop_logging.py +++ b/tests/tests_pytorch/trainer/logging_/test_train_loop_logging.py @@ -15,7 +15,6 @@ import collections import itertools -import sys from re import escape from unittest import mock from unittest.mock import call @@ -23,7 +22,6 @@ import numpy as np import pytest import torch -from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2 from lightning.pytorch import Trainer, callbacks from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint, TQDMProgressBar from lightning.pytorch.core.module import LightningModule @@ -348,15 +346,7 @@ def validation_step(self, batch, batch_idx): ("devices", "accelerator"), [ (1, "cpu"), - pytest.param( - 2, - "cpu", - marks=pytest.mark.xfail( - # https://github.com/pytorch/pytorch/issues/116056 - sys.platform == "win32" and _TORCH_GREATER_EQUAL_2_2, - reason="Windows + DDP issue in PyTorch 2.2", - ), - ), + (2, "cpu"), pytest.param(2, "gpu", marks=RunIf(min_cuda_gpus=2)), ], ) @@ -743,16 +733,14 @@ def training_step(self, batch, batch_idx): ) trainer.fit(model) - mock_log_metrics.assert_has_calls( - [ - call(metrics={"foo_step": 0.0, "epoch": 0}, step=0), - call(metrics={"foo_step": 0.0, "epoch": 0}, step=1), - call(metrics={"foo_epoch": 0.0, "epoch": 0}, step=1), - call(metrics={"foo_step": 0.0, "epoch": 1}, step=2), - call(metrics={"foo_step": 0.0, "epoch": 1}, step=3), - call(metrics={"foo_epoch": 0.0, "epoch": 1}, step=3), - ] - ) + mock_log_metrics.assert_has_calls([ + call(metrics={"foo_step": 0.0, "epoch": 0}, step=0), + call(metrics={"foo_step": 0.0, "epoch": 0}, step=1), + call(metrics={"foo_epoch": 0.0, "epoch": 0}, step=1), + call(metrics={"foo_step": 0.0, "epoch": 1}, step=2), + call(metrics={"foo_step": 0.0, "epoch": 1}, step=3), + call(metrics={"foo_epoch": 0.0, "epoch": 1}, step=3), + ]) @mock.patch("lightning.pytorch.loggers.TensorBoardLogger.log_metrics") diff --git a/tests/tests_pytorch/trainer/optimization/test_multiple_optimizers.py b/tests/tests_pytorch/trainer/optimization/test_multiple_optimizers.py index f8fecf20859c3..813cbc295deed 100644 --- a/tests/tests_pytorch/trainer/optimization/test_multiple_optimizers.py +++ b/tests/tests_pytorch/trainer/optimization/test_multiple_optimizers.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Tests to ensure that the behaviours related to multiple optimizers works.""" + import lightning.pytorch as pl import pytest import torch diff --git a/tests/tests_pytorch/trainer/optimization/test_optimizers.py b/tests/tests_pytorch/trainer/optimization/test_optimizers.py index 60b987c88b625..185a4c5fda8e4 100644 --- a/tests/tests_pytorch/trainer/optimization/test_optimizers.py +++ b/tests/tests_pytorch/trainer/optimization/test_optimizers.py @@ -234,7 +234,8 @@ def test_optimizer_return_options(tmpdir): # opt list of dictionaries model.automatic_optimization = False model.configure_optimizers = lambda: [ - {"optimizer": opt_a, "lr_scheduler": scheduler_a}, {"optimizer": opt_b, "lr_scheduler": scheduler_a} + {"optimizer": opt_a, "lr_scheduler": scheduler_a}, + {"optimizer": opt_b, "lr_scheduler": scheduler_a}, ] opt, lr_sched = _init_optimizers_and_lr_schedulers(model) assert len(opt) == len(lr_sched) == 2 @@ -559,14 +560,11 @@ class CustomEpochScheduler: def __init__(self, optimizer): self.optimizer = optimizer - def step(self, epoch): - ... + def step(self, epoch): ... - def state_dict(self): - ... + def state_dict(self): ... - def load_state_dict(self, state_dict): - ... + def load_state_dict(self, state_dict): ... class CustomBoringModel(BoringModel): def lr_scheduler_step(self, scheduler: int, metric): @@ -611,8 +609,7 @@ class CustomScheduler: def __init__(self, optimizer): self.optimizer = optimizer - def step(self): - ... + def step(self): ... class CustomBoringModel(BoringModel): def configure_optimizers(self): @@ -637,11 +634,9 @@ def __init__(self, optimizer): def step(self, foobar): # breaks the API, forces user to override `lr_scheduler_step` ... - def state_dict(self): - ... + def state_dict(self): ... - def load_state_dict(self, state_dict): - ... + def load_state_dict(self, state_dict): ... class CustomBoringModel(BoringModel): def configure_optimizers(self): @@ -653,8 +648,7 @@ def configure_optimizers(self): model.trainer = Trainer() if override: - def lr_scheduler_step(*_): - ... + def lr_scheduler_step(*_): ... 
# the user did override the hook, no error model.lr_scheduler_step = lr_scheduler_step diff --git a/tests/tests_pytorch/utilities/test_all_gather_grad.py b/tests/tests_pytorch/utilities/test_all_gather_grad.py index 82c395e8c91e2..016dffc26c403 100644 --- a/tests/tests_pytorch/utilities/test_all_gather_grad.py +++ b/tests/tests_pytorch/utilities/test_all_gather_grad.py @@ -58,17 +58,15 @@ class TestModel(BoringModel): def on_train_epoch_end(self): losses = torch.rand(2, 2).t() - gathered_loss = self.all_gather( - { - "losses_tensor_int": losses.int(), - "losses_tensor_float": losses, - "losses_tensor_list": [losses, losses], - "losses_np_ndarray": np.array([1, 2, 3]), - "losses_bool": [True, False], - "losses_float": [0.0, 1.0, 2.0], - "losses_int": [0, 1, 2], - } - ) + gathered_loss = self.all_gather({ + "losses_tensor_int": losses.int(), + "losses_tensor_float": losses, + "losses_tensor_list": [losses, losses], + "losses_np_ndarray": np.array([1, 2, 3]), + "losses_bool": [True, False], + "losses_float": [0.0, 1.0, 2.0], + "losses_int": [0, 1, 2], + }) assert gathered_loss["losses_tensor_int"][0].dtype == torch.int32 assert gathered_loss["losses_tensor_float"][0].dtype == torch.float assert gathered_loss["losses_np_ndarray"][0].dtype == torch.int64 diff --git a/tests/tests_pytorch/utilities/test_combined_loader.py b/tests/tests_pytorch/utilities/test_combined_loader.py index 71ebc6e69a7a2..74f5c1330a089 100644 --- a/tests/tests_pytorch/utilities/test_combined_loader.py +++ b/tests/tests_pytorch/utilities/test_combined_loader.py @@ -442,14 +442,12 @@ def __init__(self, data_source, name) -> None: self.name = name dataset = CustomDataset(range(10)) - combined_loader = CombinedLoader( - { - "a": DataLoader(CustomDataset(range(10))), - "b": DataLoader(dataset, sampler=CustomSampler(dataset, "custom_sampler")), - "c": {"c": DataLoader(CustomDataset(range(10))), "d": DataLoader(CustomDataset(range(10)))}, - "d": [DataLoader(CustomDataset(range(10))), DataLoader(CustomDataset(range(10)))], - } - ) + combined_loader = CombinedLoader({ + "a": DataLoader(CustomDataset(range(10))), + "b": DataLoader(dataset, sampler=CustomSampler(dataset, "custom_sampler")), + "c": {"c": DataLoader(CustomDataset(range(10))), "d": DataLoader(CustomDataset(range(10)))}, + "d": [DataLoader(CustomDataset(range(10))), DataLoader(CustomDataset(range(10)))], + }) model = BoringModel() trainer = Trainer(use_distributed_sampler=use_distributed_sampler, strategy="ddp", accelerator="cpu", devices=2) trainer.strategy.connect(model) diff --git a/tests/tests_pytorch/utilities/test_signature_utils.py b/tests/tests_pytorch/utilities/test_signature_utils.py index e281f2f6fc8e6..e453459670360 100644 --- a/tests/tests_pytorch/utilities/test_signature_utils.py +++ b/tests/tests_pytorch/utilities/test_signature_utils.py @@ -4,31 +4,27 @@ def test_param_in_hook_signature(): class LightningModule: - def validation_step(self, dataloader_iter): - ... + def validation_step(self, dataloader_iter): ... model = LightningModule() assert is_param_in_hook_signature(model.validation_step, "dataloader_iter", explicit=True) class LightningModule: @torch.no_grad() - def validation_step(self, dataloader_iter): - ... + def validation_step(self, dataloader_iter): ... model = LightningModule() assert is_param_in_hook_signature(model.validation_step, "dataloader_iter", explicit=True) class LightningModule: - def validation_step(self, *args): - ... + def validation_step(self, *args): ... 
model = LightningModule() assert not is_param_in_hook_signature(model.validation_step, "dataloader_iter", explicit=True) assert is_param_in_hook_signature(model.validation_step, "dataloader_iter", explicit=False) class LightningModule: - def validation_step(self, a, b): - ... + def validation_step(self, a, b): ... model = LightningModule() assert not is_param_in_hook_signature(model.validation_step, "dataloader_iter", min_args=3) diff --git a/tests/tests_pytorch/utilities/test_warnings.py b/tests/tests_pytorch/utilities/test_warnings.py index a280f618f4505..8c385a2d2e49f 100644 --- a/tests/tests_pytorch/utilities/test_warnings.py +++ b/tests/tests_pytorch/utilities/test_warnings.py @@ -16,6 +16,7 @@ Needs to be run outside of `pytest` as it captures all the warnings. """ + import importlib import os import warnings
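# --- illustrative sketch of the signature-inspection helper exercised directly above; the import
# path is an assumption, since the test file's import block is not shown in this hunk.
from lightning.pytorch.utilities.signature_utils import is_param_in_hook_signature

class _Demo:
    def validation_step(self, *args): ...

demo = _Demo()
# a bare *args matches the parameter implicitly, but not when an explicit name is required
assert not is_param_in_hook_signature(demo.validation_step, "dataloader_iter", explicit=True)
assert is_param_in_hook_signature(demo.validation_step, "dataloader_iter", explicit=False)
# --- end of sketch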