Skip to content

Commit

Permalink
add pytorch skips
Browse files Browse the repository at this point in the history
  • Loading branch information
awaelchli committed Jan 26, 2024
1 parent b1e74db commit 06a95d0
Show file tree
Hide file tree
Showing 7 changed files with 53 additions and 2 deletions.
6 changes: 6 additions & 0 deletions tests/tests_pytorch/callbacks/test_spike.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import pytest
import torch
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2
from lightning.fabric.utilities.spike import _TORCHMETRICS_GREATER_EQUAL_1_0_0, TrainingSpikeException
from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.callbacks.spike import SpikeDetection
Expand Down Expand Up @@ -46,6 +47,11 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)


@pytest.mark.xfail(
# https://github.com/pytorch/pytorch/issues/116056
sys.platform == "win32" and _TORCH_GREATER_EQUAL_2_2,
reason="Windows + DDP issue in PyTorch 2.2",
)
@pytest.mark.flaky(max_runs=3)
@pytest.mark.parametrize(
("global_rank_spike", "num_devices", "spike_value", "finite_only"),
Expand Down
7 changes: 7 additions & 0 deletions tests/tests_pytorch/loops/test_prediction_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
import sys

import pytest
from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset
from lightning.pytorch.overrides.distributed import _IndexBatchSamplerWrapper
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2
from torch.utils.data import DataLoader, DistributedSampler, SequentialSampler


Expand Down Expand Up @@ -50,6 +52,11 @@ def predict_step(self, batch, batch_idx):
assert trainer.predict_loop.predictions == []


@pytest.mark.xfail(
# https://github.com/pytorch/pytorch/issues/116056
sys.platform == "win32" and _TORCH_GREATER_EQUAL_2_2,
reason="Windows + DDP issue in PyTorch 2.2",
)
@pytest.mark.parametrize("use_distributed_sampler", [False, True])
def test_prediction_loop_batch_sampler_set_epoch_called(tmp_path, use_distributed_sampler):
"""Tests that set_epoch is called on the dataloader's batch sampler (if any) during prediction."""
Expand Down
8 changes: 7 additions & 1 deletion tests/tests_pytorch/models/test_amp.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
from unittest import mock

import pytest
import torch
from lightning.fabric.plugins.environments import SLURMEnvironment
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2
from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset
from torch.utils.data import DataLoader
Expand Down Expand Up @@ -53,7 +55,11 @@ def _assert_autocast_enabled(self):
[
("single_device", "16-mixed", 1),
("single_device", "bf16-mixed", 1),
("ddp_spawn", "16-mixed", 2),
pytest.param("ddp_spawn", "16-mixed", 2, marks=pytest.mark.xfail(
# https://github.com/pytorch/pytorch/issues/116056
sys.platform == "win32" and _TORCH_GREATER_EQUAL_2_2,
reason="Windows + DDP issue in PyTorch 2.2",
)),
pytest.param("ddp_spawn", "bf16-mixed", 2, marks=RunIf(skip_windows=True)),
],
)
Expand Down
7 changes: 7 additions & 0 deletions tests/tests_pytorch/serve/test_servable_module_validator.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import sys
from typing import Dict

import pytest
import torch
from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.serve.servable_module_validator import ServableModule, ServableModuleValidator
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2
from torch import Tensor


Expand Down Expand Up @@ -36,6 +38,11 @@ def test_servable_module_validator():
callback.on_train_start(Trainer(accelerator="cpu"), model)


@pytest.mark.xfail(
# https://github.com/pytorch/pytorch/issues/116056
sys.platform == "win32" and _TORCH_GREATER_EQUAL_2_2,
reason="Windows + DDP issue in PyTorch 2.2",
)
@pytest.mark.flaky(reruns=3)
def test_servable_module_validator_with_trainer(tmpdir):
callback = ServableModuleValidator()
Expand Down
12 changes: 12 additions & 0 deletions tests/tests_pytorch/strategies/launchers/test_multiprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
from multiprocessing import Process
from unittest import mock
from unittest.mock import ANY, Mock, call, patch

import pytest
import torch
from lightning.fabric.plugins import ClusterEnvironment
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2
from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.strategies import DDPStrategy
Expand Down Expand Up @@ -194,6 +196,11 @@ def on_fit_start(self) -> None:
assert torch.equal(self.layer.weight.data, self.tied_layer.weight.data)


@pytest.mark.xfail(
# https://github.com/pytorch/pytorch/issues/116056
sys.platform == "win32" and _TORCH_GREATER_EQUAL_2_2,
reason="Windows + DDP issue in PyTorch 2.2",
)
def test_memory_sharing_disabled():
"""Test that the multiprocessing launcher disables memory sharing on model parameters and buffers to avoid race
conditions on model updates."""
Expand All @@ -214,6 +221,11 @@ def test_check_for_missing_main_guard():
launcher.launch(function=Mock())


@pytest.mark.xfail(
# https://github.com/pytorch/pytorch/issues/116056
sys.platform == "win32" and _TORCH_GREATER_EQUAL_2_2,
reason="Windows + DDP issue in PyTorch 2.2",
)
def test_fit_twice_raises():
model = BoringModel()
trainer = Trainer(
Expand Down
7 changes: 7 additions & 0 deletions tests/tests_pytorch/trainer/connectors/test_data_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
from re import escape
from typing import Sized
from unittest import mock
Expand All @@ -19,6 +20,7 @@
import lightning.fabric
import pytest
from lightning.fabric.utilities.distributed import DistributedSamplerWrapper
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2
from lightning.fabric.utilities.warnings import PossibleUserWarning
from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringDataModule, BoringModel, RandomDataset
Expand Down Expand Up @@ -123,6 +125,11 @@ def on_train_end(self):
self.ctx.__exit__(None, None, None)


@pytest.mark.xfail(
# https://github.com/pytorch/pytorch/issues/116056
sys.platform == "win32" and _TORCH_GREATER_EQUAL_2_2,
reason="Windows + DDP issue in PyTorch 2.2",
)
@pytest.mark.parametrize("num_workers", [0, 1, 2])
def test_dataloader_persistent_workers_performance_warning(num_workers, tmp_path):
"""Test that when the multiprocessing start-method is 'spawn', we recommend setting `persistent_workers=True`."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import collections
import itertools
import sys
from re import escape
from unittest import mock
from unittest.mock import call
Expand All @@ -30,6 +31,7 @@
from lightning.pytorch.trainer.states import RunningStage
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.imports import _TORCHMETRICS_GREATER_EQUAL_0_11 as _TM_GE_0_11
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2
from lightning_utilities.test.warning import no_warning_call
from torch import Tensor
from torch.utils.data import DataLoader
Expand Down Expand Up @@ -346,7 +348,11 @@ def validation_step(self, batch, batch_idx):
("devices", "accelerator"),
[
(1, "cpu"),
(2, "cpu"),
pytest.param(2, "cpu", marks=pytest.mark.xfail(
# https://github.com/pytorch/pytorch/issues/116056
sys.platform == "win32" and _TORCH_GREATER_EQUAL_2_2,
reason="Windows + DDP issue in PyTorch 2.2",
)),
pytest.param(2, "gpu", marks=RunIf(min_cuda_gpus=2)),
],
)
Expand Down

0 comments on commit 06a95d0

Please sign in to comment.