Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
awaelchli committed Jan 20, 2024
1 parent d16d432 commit 3b46cde
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions tests/tests_fabric/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def fake_script(tmp_path):

@mock.patch.dict(os.environ, os.environ.copy(), clear=True)
def test_cli_env_vars_defaults(monkeypatch, fake_script):
monkeypatch.setattr(torch.distributed, "run", Mock())
monkeypatch.setattr(torch.distributed, "run", Mock(), raising=False)
with pytest.raises(SystemExit) as e:
_run_model.main([fake_script])
assert e.value.code == 0
Expand All @@ -51,7 +51,7 @@ def test_cli_env_vars_defaults(monkeypatch, fake_script):
@mock.patch.dict(os.environ, os.environ.copy(), clear=True)
@mock.patch("lightning.fabric.accelerators.cuda.num_cuda_devices", return_value=2)
def test_cli_env_vars_accelerator(_, accelerator, monkeypatch, fake_script):
monkeypatch.setattr(torch.distributed, "run", Mock())
monkeypatch.setattr(torch.distributed, "run", Mock(), raising=False)
with pytest.raises(SystemExit) as e:
_run_model.main([fake_script, "--accelerator", accelerator])
assert e.value.code == 0
Expand All @@ -62,7 +62,7 @@ def test_cli_env_vars_accelerator(_, accelerator, monkeypatch, fake_script):
@mock.patch.dict(os.environ, os.environ.copy(), clear=True)
@mock.patch("lightning.fabric.accelerators.cuda.num_cuda_devices", return_value=2)
def test_cli_env_vars_strategy(_, strategy, monkeypatch, fake_script):
monkeypatch.setattr(torch.distributed, "run", Mock())
monkeypatch.setattr(torch.distributed, "run", Mock(), raising=False)
with pytest.raises(SystemExit) as e:
_run_model.main([fake_script, "--strategy", strategy])
assert e.value.code == 0
Expand All @@ -89,7 +89,7 @@ def test_cli_env_vars_unsupported_strategy(strategy, fake_script):
@mock.patch.dict(os.environ, os.environ.copy(), clear=True)
@mock.patch("lightning.fabric.accelerators.cuda.num_cuda_devices", return_value=2)
def test_cli_env_vars_devices_cuda(_, devices, monkeypatch, fake_script):
monkeypatch.setattr(torch.distributed, "run", Mock())
monkeypatch.setattr(torch.distributed, "run", Mock(), raising=False)
with pytest.raises(SystemExit) as e:
_run_model.main([fake_script, "--accelerator", "cuda", "--devices", devices])
assert e.value.code == 0
Expand All @@ -100,7 +100,7 @@ def test_cli_env_vars_devices_cuda(_, devices, monkeypatch, fake_script):
@pytest.mark.parametrize("accelerator", ["mps", "gpu"])
@mock.patch.dict(os.environ, os.environ.copy(), clear=True)
def test_cli_env_vars_devices_mps(accelerator, monkeypatch, fake_script):
monkeypatch.setattr(torch.distributed, "run", Mock())
monkeypatch.setattr(torch.distributed, "run", Mock(), raising=False)
with pytest.raises(SystemExit) as e:
_run_model.main([fake_script, "--accelerator", accelerator])
assert e.value.code == 0
Expand All @@ -110,7 +110,7 @@ def test_cli_env_vars_devices_mps(accelerator, monkeypatch, fake_script):
@pytest.mark.parametrize("num_nodes", ["1", "2", "3"])
@mock.patch.dict(os.environ, os.environ.copy(), clear=True)
def test_cli_env_vars_num_nodes(num_nodes, monkeypatch, fake_script):
monkeypatch.setattr(torch.distributed, "run", Mock())
monkeypatch.setattr(torch.distributed, "run", Mock(), raising=False)
with pytest.raises(SystemExit) as e:
_run_model.main([fake_script, "--num-nodes", num_nodes])
assert e.value.code == 0
Expand All @@ -120,7 +120,7 @@ def test_cli_env_vars_num_nodes(num_nodes, monkeypatch, fake_script):
@pytest.mark.parametrize("precision", ["64-true", "64", "32-true", "32", "16-mixed", "bf16-mixed"])
@mock.patch.dict(os.environ, os.environ.copy(), clear=True)
def test_cli_env_vars_precision(precision, monkeypatch, fake_script):
monkeypatch.setattr(torch.distributed, "run", Mock())
monkeypatch.setattr(torch.distributed, "run", Mock(), raising=False)
with pytest.raises(SystemExit) as e:
_run_model.main([fake_script, "--precision", precision])
assert e.value.code == 0
Expand All @@ -130,7 +130,7 @@ def test_cli_env_vars_precision(precision, monkeypatch, fake_script):
@mock.patch.dict(os.environ, os.environ.copy(), clear=True)
def test_cli_torchrun_defaults(monkeypatch, fake_script):
torchrun_mock = Mock()
monkeypatch.setattr(torch.distributed, "run", torchrun_mock)
monkeypatch.setattr(torch.distributed, "run", torchrun_mock, raising=False)
with pytest.raises(SystemExit) as e:
_run_model.main([fake_script])
assert e.value.code == 0
Expand Down Expand Up @@ -160,7 +160,7 @@ def test_cli_torchrun_defaults(monkeypatch, fake_script):
@mock.patch("lightning.fabric.accelerators.cuda.num_cuda_devices", return_value=5)
def test_cli_torchrun_num_processes_launched(_, devices, expected, monkeypatch, fake_script):
torchrun_mock = Mock()
monkeypatch.setattr(torch.distributed, "run", torchrun_mock)
monkeypatch.setattr(torch.distributed, "run", torchrun_mock, raising=False)
with pytest.raises(SystemExit) as e:
_run_model.main([fake_script, "--accelerator", "cuda", "--devices", devices])
assert e.value.code == 0
Expand Down

0 comments on commit 3b46cde

Please sign in to comment.