From e5721cd8831cf66230eb1a93472f56ce94df9324 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 25 Jan 2024 00:45:12 +0100 Subject: [PATCH 1/5] exists --- src/lightning/fabric/plugins/io/torch_io.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/lightning/fabric/plugins/io/torch_io.py b/src/lightning/fabric/plugins/io/torch_io.py index 02de1aa274a32..58cdac4d0f4fb 100644 --- a/src/lightning/fabric/plugins/io/torch_io.py +++ b/src/lightning/fabric/plugins/io/torch_io.py @@ -91,6 +91,5 @@ def remove_checkpoint(self, path: _PATH) -> None: """ fs = get_filesystem(path) - if fs.exists(path): - fs.rm(path, recursive=True) - log.debug(f"Removed checkpoint: {path}") + fs.rm(path, recursive=True) + log.debug(f"Removed checkpoint: {path}") From cf73803e52c6c4ce85995db20c3453c38d7268e1 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Thu, 25 Jan 2024 01:11:28 +0100 Subject: [PATCH 2/5] tests --- .../callbacks/on_exception_checkpoint.py | 4 +- tests/tests_fabric/plugins/io/__init__.py | 0 .../tests_fabric/plugins/io/test_torch_io.py | 58 +++++++++++++++++++ .../checkpointing/test_model_checkpoint.py | 1 + 4 files changed, 62 insertions(+), 1 deletion(-) create mode 100644 tests/tests_fabric/plugins/io/__init__.py create mode 100644 tests/tests_fabric/plugins/io/test_torch_io.py diff --git a/src/lightning/pytorch/callbacks/on_exception_checkpoint.py b/src/lightning/pytorch/callbacks/on_exception_checkpoint.py index d92dd352d355b..7744d37959830 100644 --- a/src/lightning/pytorch/callbacks/on_exception_checkpoint.py +++ b/src/lightning/pytorch/callbacks/on_exception_checkpoint.py @@ -67,4 +67,6 @@ def on_exception(self, trainer: "pl.Trainer", *_: Any, **__: Any) -> None: @override def teardown(self, trainer: "pl.Trainer", *_: Any, **__: Any) -> None: - trainer.strategy.remove_checkpoint(self.ckpt_path) + if os.path.exists(self.ckpt_path): + # only exists if there was an exception + trainer.strategy.remove_checkpoint(self.ckpt_path) diff --git a/tests/tests_fabric/plugins/io/__init__.py b/tests/tests_fabric/plugins/io/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/tests_fabric/plugins/io/test_torch_io.py b/tests/tests_fabric/plugins/io/test_torch_io.py new file mode 100644 index 0000000000000..54ef53b857671 --- /dev/null +++ b/tests/tests_fabric/plugins/io/test_torch_io.py @@ -0,0 +1,58 @@ +# Copyright The Lightning AI team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest + +from lightning.fabric.plugins.io import TorchCheckpointIO + + +def test_remove_checkpoint(tmp_path): + """Test that the IO can remove folders, files, and symlinks.""" + io = TorchCheckpointIO() + + # Path does not exist + with pytest.raises(FileNotFoundError): + io.remove_checkpoint("not_exist.txt") + + # Single file + file = tmp_path / "file.txt" + file.touch() + io.remove_checkpoint(file) + assert not file.exists() + + # Symlink + file = tmp_path / "file.txt" + file.touch() + link = tmp_path / "link.txt" + link.symlink_to(file) + io.remove_checkpoint(link) + assert file.exists() + assert not link.is_symlink() + file.unlink() + + # Broken Symlink + file_not_exists = tmp_path / "not_exist.txt" + link = tmp_path / "link.txt" + link.symlink_to(file_not_exists) + assert not file_not_exists.exists() + io.remove_checkpoint(link) + assert not link.is_symlink() + + # Folder with contents + folder = tmp_path / "folder" + nested_folder = folder / "nested_folder" + nested_folder.mkdir(parents=True) + file = nested_folder / "file.txt" + file.touch() + io.remove_checkpoint(folder) + assert not folder.exists() diff --git a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py index fb2c8d8e35a93..4082f957ec9e7 100644 --- a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py +++ b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py @@ -960,6 +960,7 @@ def on_validation_epoch_end(self): max_epochs=len(monitor), ) trainer.save_checkpoint = Mock() + trainer.strategy.remove_checkpoint = Mock() trainer.fit(model) From 680983e72238f16215309180abd9e2b5da7eaad6 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 25 Jan 2024 00:12:51 +0000 Subject: [PATCH 3/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/tests_fabric/plugins/io/test_torch_io.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/tests_fabric/plugins/io/test_torch_io.py b/tests/tests_fabric/plugins/io/test_torch_io.py index 54ef53b857671..e9c30252d8cfa 100644 --- a/tests/tests_fabric/plugins/io/test_torch_io.py +++ b/tests/tests_fabric/plugins/io/test_torch_io.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import pytest - from lightning.fabric.plugins.io import TorchCheckpointIO From fc3dfb386791df39e5918c02734e045121db3874 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Thu, 25 Jan 2024 01:14:59 +0100 Subject: [PATCH 4/5] chlog --- src/lightning/fabric/CHANGELOG.md | 3 +++ src/lightning/pytorch/CHANGELOG.md | 3 +++ 2 files changed, 6 insertions(+) diff --git a/src/lightning/fabric/CHANGELOG.md b/src/lightning/fabric/CHANGELOG.md index 9d7ac2b075a36..418b695ab4dda 100644 --- a/src/lightning/fabric/CHANGELOG.md +++ b/src/lightning/fabric/CHANGELOG.md @@ -41,6 +41,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - The columns in the `metrics.csv` file produced by `CSVLogger` are now sorted alphabetically ([#19159](https://github.com/Lightning-AI/lightning/pull/19159)) +- `TorchCheckpointIO.remove_checkpoint()` no longer silently passes if the given file does not exist ([#19344](https://github.com/Lightning-AI/lightning/pull/19344)) + + ### Deprecated - diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index b3118ac90cbeb..16ba44f36aebc 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -53,6 +53,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Reverted back to creating a checkpoint copy when `ModelCheckpoint(save_last=True)` instead of creating a symbolic link ([#19191](https://github.com/Lightning-AI/lightning/pull/19191)) +- `TorchCheckpointIO.remove_checkpoint()` no longer silently passes if the given file does not exist ([#19344](https://github.com/Lightning-AI/lightning/pull/19344)) + + ### Deprecated - Deprecated all precision plugin classes under `lightning.pytorch.plugins` with the suffix `Plugin` in the name ([#18840](https://github.com/Lightning-AI/lightning/pull/18840)) From 1fc8f5747562672a04e5d3691018db3d13ff2ecd Mon Sep 17 00:00:00 2001 From: awaelchli Date: Thu, 25 Jan 2024 01:25:00 +0100 Subject: [PATCH 5/5] oldest version of fsspec does not support pathlib --- src/lightning/fabric/plugins/io/torch_io.py | 4 ++-- tests/tests_fabric/plugins/io/test_torch_io.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/lightning/fabric/plugins/io/torch_io.py b/src/lightning/fabric/plugins/io/torch_io.py index 58cdac4d0f4fb..db8d91c7e330a 100644 --- a/src/lightning/fabric/plugins/io/torch_io.py +++ b/src/lightning/fabric/plugins/io/torch_io.py @@ -90,6 +90,6 @@ def remove_checkpoint(self, path: _PATH) -> None: path: Path to checkpoint """ - fs = get_filesystem(path) - fs.rm(path, recursive=True) + fs = get_filesystem(str(path)) + fs.rm(str(path), recursive=True) log.debug(f"Removed checkpoint: {path}") diff --git a/tests/tests_fabric/plugins/io/test_torch_io.py b/tests/tests_fabric/plugins/io/test_torch_io.py index e9c30252d8cfa..aa20e8331d51d 100644 --- a/tests/tests_fabric/plugins/io/test_torch_io.py +++ b/tests/tests_fabric/plugins/io/test_torch_io.py @@ -21,7 +21,7 @@ def test_remove_checkpoint(tmp_path): # Path does not exist with pytest.raises(FileNotFoundError): - io.remove_checkpoint("not_exist.txt") + io.remove_checkpoint("does_not_exist.txt") # Single file file = tmp_path / "file.txt"