Skip to content

Commit 06a95d0

Browse files
committed
add pytorch skips
1 parent b1e74db commit 06a95d0

File tree

7 files changed

+53
-2
lines changed

7 files changed

+53
-2
lines changed

tests/tests_pytorch/callbacks/test_spike.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import pytest
55
import torch
6+
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2
67
from lightning.fabric.utilities.spike import _TORCHMETRICS_GREATER_EQUAL_1_0_0, TrainingSpikeException
78
from lightning.pytorch import LightningModule, Trainer
89
from lightning.pytorch.callbacks.spike import SpikeDetection
@@ -46,6 +47,11 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
4647
super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)
4748

4849

50+
@pytest.mark.xfail(
51+
# https://github.com/pytorch/pytorch/issues/116056
52+
sys.platform == "win32" and _TORCH_GREATER_EQUAL_2_2,
53+
reason="Windows + DDP issue in PyTorch 2.2",
54+
)
4955
@pytest.mark.flaky(max_runs=3)
5056
@pytest.mark.parametrize(
5157
("global_rank_spike", "num_devices", "spike_value", "finite_only"),

tests/tests_pytorch/loops/test_prediction_loop.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,13 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import itertools
15+
import sys
1516

1617
import pytest
1718
from lightning.pytorch import LightningModule, Trainer
1819
from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset
1920
from lightning.pytorch.overrides.distributed import _IndexBatchSamplerWrapper
21+
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2
2022
from torch.utils.data import DataLoader, DistributedSampler, SequentialSampler
2123

2224

@@ -50,6 +52,11 @@ def predict_step(self, batch, batch_idx):
5052
assert trainer.predict_loop.predictions == []
5153

5254

55+
@pytest.mark.xfail(
56+
# https://github.com/pytorch/pytorch/issues/116056
57+
sys.platform == "win32" and _TORCH_GREATER_EQUAL_2_2,
58+
reason="Windows + DDP issue in PyTorch 2.2",
59+
)
5360
@pytest.mark.parametrize("use_distributed_sampler", [False, True])
5461
def test_prediction_loop_batch_sampler_set_epoch_called(tmp_path, use_distributed_sampler):
5562
"""Tests that set_epoch is called on the dataloader's batch sampler (if any) during prediction."""

tests/tests_pytorch/models/test_amp.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,13 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import os
15+
import sys
1516
from unittest import mock
1617

1718
import pytest
1819
import torch
1920
from lightning.fabric.plugins.environments import SLURMEnvironment
21+
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2
2022
from lightning.pytorch import Trainer
2123
from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset
2224
from torch.utils.data import DataLoader
@@ -53,7 +55,11 @@ def _assert_autocast_enabled(self):
5355
[
5456
("single_device", "16-mixed", 1),
5557
("single_device", "bf16-mixed", 1),
56-
("ddp_spawn", "16-mixed", 2),
58+
pytest.param("ddp_spawn", "16-mixed", 2, marks=pytest.mark.xfail(
59+
# https://github.com/pytorch/pytorch/issues/116056
60+
sys.platform == "win32" and _TORCH_GREATER_EQUAL_2_2,
61+
reason="Windows + DDP issue in PyTorch 2.2",
62+
)),
5763
pytest.param("ddp_spawn", "bf16-mixed", 2, marks=RunIf(skip_windows=True)),
5864
],
5965
)

tests/tests_pytorch/serve/test_servable_module_validator.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
1+
import sys
12
from typing import Dict
23

34
import pytest
45
import torch
56
from lightning.pytorch import Trainer
67
from lightning.pytorch.demos.boring_classes import BoringModel
78
from lightning.pytorch.serve.servable_module_validator import ServableModule, ServableModuleValidator
9+
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2
810
from torch import Tensor
911

1012

@@ -36,6 +38,11 @@ def test_servable_module_validator():
3638
callback.on_train_start(Trainer(accelerator="cpu"), model)
3739

3840

41+
@pytest.mark.xfail(
42+
# https://github.com/pytorch/pytorch/issues/116056
43+
sys.platform == "win32" and _TORCH_GREATER_EQUAL_2_2,
44+
reason="Windows + DDP issue in PyTorch 2.2",
45+
)
3946
@pytest.mark.flaky(reruns=3)
4047
def test_servable_module_validator_with_trainer(tmpdir):
4148
callback = ServableModuleValidator()

tests/tests_pytorch/strategies/launchers/test_multiprocessing.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,15 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import os
15+
import sys
1516
from multiprocessing import Process
1617
from unittest import mock
1718
from unittest.mock import ANY, Mock, call, patch
1819

1920
import pytest
2021
import torch
2122
from lightning.fabric.plugins import ClusterEnvironment
23+
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2
2224
from lightning.pytorch import Trainer
2325
from lightning.pytorch.demos.boring_classes import BoringModel
2426
from lightning.pytorch.strategies import DDPStrategy
@@ -194,6 +196,11 @@ def on_fit_start(self) -> None:
194196
assert torch.equal(self.layer.weight.data, self.tied_layer.weight.data)
195197

196198

199+
@pytest.mark.xfail(
200+
# https://github.com/pytorch/pytorch/issues/116056
201+
sys.platform == "win32" and _TORCH_GREATER_EQUAL_2_2,
202+
reason="Windows + DDP issue in PyTorch 2.2",
203+
)
197204
def test_memory_sharing_disabled():
198205
"""Test that the multiprocessing launcher disables memory sharing on model parameters and buffers to avoid race
199206
conditions on model updates."""
@@ -214,6 +221,11 @@ def test_check_for_missing_main_guard():
214221
launcher.launch(function=Mock())
215222

216223

224+
@pytest.mark.xfail(
225+
# https://github.com/pytorch/pytorch/issues/116056
226+
sys.platform == "win32" and _TORCH_GREATER_EQUAL_2_2,
227+
reason="Windows + DDP issue in PyTorch 2.2",
228+
)
217229
def test_fit_twice_raises():
218230
model = BoringModel()
219231
trainer = Trainer(

tests/tests_pytorch/trainer/connectors/test_data_connector.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import sys
1415
from re import escape
1516
from typing import Sized
1617
from unittest import mock
@@ -19,6 +20,7 @@
1920
import lightning.fabric
2021
import pytest
2122
from lightning.fabric.utilities.distributed import DistributedSamplerWrapper
23+
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2
2224
from lightning.fabric.utilities.warnings import PossibleUserWarning
2325
from lightning.pytorch import Trainer
2426
from lightning.pytorch.demos.boring_classes import BoringDataModule, BoringModel, RandomDataset
@@ -123,6 +125,11 @@ def on_train_end(self):
123125
self.ctx.__exit__(None, None, None)
124126

125127

128+
@pytest.mark.xfail(
129+
# https://github.com/pytorch/pytorch/issues/116056
130+
sys.platform == "win32" and _TORCH_GREATER_EQUAL_2_2,
131+
reason="Windows + DDP issue in PyTorch 2.2",
132+
)
126133
@pytest.mark.parametrize("num_workers", [0, 1, 2])
127134
def test_dataloader_persistent_workers_performance_warning(num_workers, tmp_path):
128135
"""Test that when the multiprocessing start-method is 'spawn', we recommend setting `persistent_workers=True`."""

tests/tests_pytorch/trainer/logging_/test_train_loop_logging.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
import collections
1717
import itertools
18+
import sys
1819
from re import escape
1920
from unittest import mock
2021
from unittest.mock import call
@@ -30,6 +31,7 @@
3031
from lightning.pytorch.trainer.states import RunningStage
3132
from lightning.pytorch.utilities.exceptions import MisconfigurationException
3233
from lightning.pytorch.utilities.imports import _TORCHMETRICS_GREATER_EQUAL_0_11 as _TM_GE_0_11
34+
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2
3335
from lightning_utilities.test.warning import no_warning_call
3436
from torch import Tensor
3537
from torch.utils.data import DataLoader
@@ -346,7 +348,11 @@ def validation_step(self, batch, batch_idx):
346348
("devices", "accelerator"),
347349
[
348350
(1, "cpu"),
349-
(2, "cpu"),
351+
pytest.param(2, "cpu", marks=pytest.mark.xfail(
352+
# https://github.com/pytorch/pytorch/issues/116056
353+
sys.platform == "win32" and _TORCH_GREATER_EQUAL_2_2,
354+
reason="Windows + DDP issue in PyTorch 2.2",
355+
)),
350356
pytest.param(2, "gpu", marks=RunIf(min_cuda_gpus=2)),
351357
],
352358
)

0 commit comments

Comments
 (0)