Commit 4f73593

Add windows skips for Fabric
1 parent 4a51997 commit 4f73593

File tree

5 files changed: 31 additions, 0 deletions

tests/tests_fabric/plugins/precision/test_amp_integration.py
tests/tests_fabric/strategies/launchers/test_multiprocessing_integration.py
tests/tests_fabric/strategies/test_ddp_integration.py
tests/tests_fabric/utilities/test_distributed.py
tests/tests_fabric/utilities/test_spike.py
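
Every file in this commit gains the same guard: a pytest.mark.skipif marker whose condition is true only on Windows with PyTorch 2.2 or newer, with pytorch/pytorch#116056 linked as the cause. The condition is a plain boolean evaluated at collection time, so the tests still run everywhere else. As a rough sketch, a version flag like _TORCH_GREATER_EQUAL_2_2 is typically computed once at import time; this assumes the compare_version helper from lightning-utilities, and the actual definition in lightning.fabric.utilities.imports may differ:

# Sketch only: assumes lightning-utilities' compare_version helper; not
# necessarily the real definition in lightning.fabric.utilities.imports.
import operator

from lightning_utilities.core.imports import compare_version

# Evaluated once at import time: True iff the installed torch is >= 2.2.0.
_TORCH_GREATER_EQUAL_2_2 = compare_version("torch", operator.ge, "2.2.0")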

tests/tests_fabric/plugins/precision/test_amp_integration.py (7 additions, 0 deletions)

@@ -12,10 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Integration tests for Automatic Mixed Precision (AMP) training."""
+import sys
+
 import pytest
 import torch
 import torch.nn as nn
 from lightning.fabric import Fabric, seed_everything
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2
 
 from tests_fabric.helpers.runif import RunIf
 
@@ -37,6 +40,10 @@ def forward(self, x):
         return output
 
 
+@pytest.mark.skipif(
+    # https://github.com/pytorch/pytorch/issues/116056
+    sys.platform == "win32" and _TORCH_GREATER_EQUAL_2_2, reason="Windows + DDP issue in PyTorch 2.2"
+)
 @pytest.mark.parametrize(
     ("accelerator", "precision", "expected_dtype"),
     [

tests/tests_fabric/strategies/launchers/test_multiprocessing_integration.py (7 additions, 0 deletions)

@@ -11,10 +11,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import sys
+
 import pytest
 import torch
 import torch.nn as nn
 from lightning.fabric import Fabric
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2
 
 from tests_fabric.helpers.runif import RunIf
 
@@ -28,6 +31,10 @@ def __init__(self):
         self.register_buffer("buffer", torch.ones(3))
 
 
+@pytest.mark.skipif(
+    # https://github.com/pytorch/pytorch/issues/116056
+    sys.platform == "win32" and _TORCH_GREATER_EQUAL_2_2, reason="Windows + DDP issue in PyTorch 2.2"
+)
 @pytest.mark.parametrize("strategy", ["ddp_spawn", pytest.param("ddp_fork", marks=RunIf(skip_windows=True))])
 def test_memory_sharing_disabled(strategy):
     """Test that the multiprocessing launcher disables memory sharing on model parameters and buffers to avoid race

tests/tests_fabric/strategies/test_ddp_integration.py (6 additions, 0 deletions)

@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import sys
 from copy import deepcopy
 
 import pytest
@@ -19,8 +20,13 @@
 
 from tests_fabric.helpers.runif import RunIf
 from tests_fabric.strategies.test_single_device import _run_test_clip_gradients
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2
 
 
+@pytest.mark.skipif(
+    # https://github.com/pytorch/pytorch/issues/116056
+    sys.platform == "win32" and _TORCH_GREATER_EQUAL_2_2, reason="Windows + DDP issue in PyTorch 2.2"
+)
 @pytest.mark.parametrize(
     "accelerator",
     [

tests/tests_fabric/utilities/test_distributed.py (6 additions, 0 deletions)

@@ -1,5 +1,6 @@
 import functools
 import os
+import sys
 from functools import partial
 from pathlib import Path
 from unittest import mock
@@ -17,6 +18,7 @@
     _sync_ddp,
     is_shared_filesystem,
 )
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2
 
 from tests_fabric.helpers.runif import RunIf
 
@@ -118,6 +120,10 @@ def test_collective_operations(devices, process):
     spawn_launch(process, devices)
 
 
+@pytest.mark.skipif(
+    # https://github.com/pytorch/pytorch/issues/116056
+    sys.platform == "win32" and _TORCH_GREATER_EQUAL_2_2, reason="Windows + DDP issue in PyTorch 2.2"
+)
 @pytest.mark.flaky(reruns=3)  # flaky with "process 0 terminated with signal SIGABRT" (GLOO)
 def test_is_shared_filesystem(tmp_path, monkeypatch):
     # In the non-distributed case, every location is interpreted as 'shared'

tests/tests_fabric/utilities/test_spike.py (5 additions, 0 deletions)

@@ -4,6 +4,7 @@
 import pytest
 import torch
 from lightning.fabric import Fabric
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2
 from lightning.fabric.utilities.spike import _TORCHMETRICS_GREATER_EQUAL_1_0_0, SpikeDetection, TrainingSpikeException
 
 
@@ -28,6 +29,10 @@ def spike_detection_test(fabric, global_rank_spike, spike_value, should_raise):
     )
 
 
+@pytest.mark.skipif(
+    # https://github.com/pytorch/pytorch/issues/116056
+    sys.platform == "win32" and _TORCH_GREATER_EQUAL_2_2, reason="Windows + DDP issue in PyTorch 2.2"
+)
 @pytest.mark.flaky(max_runs=3)
 @pytest.mark.parametrize(
     ("global_rank_spike", "num_devices", "spike_value", "finite_only"),
