Skip to content

Commit b551426

Browse files
authored
add a option to disable the plugin (#25)
This is especially useful when debugging potential incompatibilites with the vanilla collection
1 parent cd6e565 commit b551426

File tree

5 files changed

+81
-1
lines changed

5 files changed

+81
-1
lines changed

pytest_pytorch/plugin.py

+13-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
TORCH_AVAILABLE = False
1414

1515
warnings.warn(
16-
"Disabling the plugin 'pytest-pytorch', because 'torch' could not be imported."
16+
"Disabling the `pytest-pytorch` plugin, because 'torch' could not be imported."
1717
)
1818

1919

@@ -87,10 +87,22 @@ def collect(self):
8787
yield from super().collect()
8888

8989

90+
def pytest_addoption(parser, pluginmanager):
91+
parser.addoption(
92+
"--disable-pytest-pytorch",
93+
action="store_true",
94+
help="Disable the `pytest-pytorch` plugin",
95+
)
96+
return None
97+
98+
9099
def pytest_pycollect_makeitem(collector, name, obj):
91100
if not TORCH_AVAILABLE:
92101
return None
93102

103+
if collector.config.getoption("disable_pytest_pytorch"):
104+
return None
105+
94106
try:
95107
if not issubclass(obj, TestCaseTemplate) or obj is TestCaseTemplate:
96108
return None

tests/assets/test_disabled.py

+15
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
from torch.testing._internal.common_device_type import instantiate_device_type_tests
2+
from torch.testing._internal.common_utils import TestCase
3+
4+
5+
class TestFoo(TestCase):
6+
def test_bar(self, device):
7+
pass
8+
9+
10+
instantiate_device_type_tests(TestFoo, globals(), only_for="cpu")
11+
12+
13+
class TestSpam(TestCase):
14+
def test_ham(self):
15+
pass

tests/conftest.py

+4
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,10 @@ def collect_tests(testdir):
88
def collect_tests_(file: str, cmds: str):
99
testdir.copy_example(file)
1010
result = testdir.runpytest("--quiet", "--collect-only", *cmds)
11+
12+
if result.outlines[-1].startswith("no tests collected"):
13+
return set()
14+
1115
assert result.ret == pytest.ExitCode.OK
1216

1317
collection = set()

tests/test_cli.py

+7
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
import pytest
2+
3+
4+
@pytest.mark.parametrize("option", ["--disable-pytest-pytorch"])
5+
def test_disable_pytest_pytorch(testdir, option):
6+
result = testdir.runpytest("--help")
7+
assert option in "\n".join(result.outlines)

tests/test_plugin.py

+42
Original file line numberDiff line numberDiff line change
@@ -179,3 +179,45 @@ def test_op_infos(collect_tests, file, cmds, selection):
179179
def test_nested_names(collect_tests, file, cmds, selection):
180180
collection = collect_tests(file, cmds)
181181
assert collection == selection
182+
183+
184+
@make_parametrization(
185+
Config(
186+
selection=(
187+
"::TestFooCPU::test_bar_cpu",
188+
"::TestSpam::test_ham",
189+
),
190+
),
191+
Config(
192+
new_cmds="::TestFoo",
193+
selection=(),
194+
),
195+
Config(
196+
new_cmds="::TestFoo::test_bar",
197+
selection=(),
198+
),
199+
Config(
200+
new_cmds="::TestFooCPU",
201+
legacy_cmds=("-k", "TestFoo"),
202+
selection=("::TestFooCPU::test_bar_cpu",),
203+
),
204+
Config(
205+
new_cmds="::TestFooCPU::test_bar_cpu",
206+
legacy_cmds=("-k", "TestFoo and test_bar"),
207+
selection=("::TestFooCPU::test_bar_cpu",),
208+
),
209+
Config(
210+
new_cmds="::TestSpam",
211+
legacy_cmds=("-k", "TestSpam"),
212+
selection=("::TestSpam::test_ham",),
213+
),
214+
Config(
215+
new_cmds="::TestSpam::test_ham",
216+
legacy_cmds=("-k", "TestSpam and test_ham"),
217+
selection=("::TestSpam::test_ham",),
218+
),
219+
file="test_disabled.py",
220+
)
221+
def test_disabled(collect_tests, file, cmds, selection):
222+
collection = collect_tests(file, ("--disable-pytest-pytorch", *cmds))
223+
assert collection == selection

0 commit comments

Comments
 (0)