|
| 1 | +from click.testing import CliRunner |
| 2 | +from sparkrun.cli import main |
| 3 | +from unittest.mock import MagicMock |
| 4 | + |
| 5 | + |
| 6 | +def test_run_with_name_override(monkeypatch): |
| 7 | + runner = CliRunner() |
| 8 | + |
| 9 | + # Mock launch_inference in its original module, since _run.py imports it locally |
| 10 | + mock_launch = MagicMock() |
| 11 | + mock_result = MagicMock() |
| 12 | + mock_result.cluster_id = "test-cluster-id" |
| 13 | + mock_result.serve_command = "vllm serve" |
| 14 | + mock_result.runtime_info = {} |
| 15 | + mock_result.rc = 0 |
| 16 | + mock_result.head_host = "localhost" |
| 17 | + mock_result.host_list = ["localhost"] |
| 18 | + mock_launch.return_value = mock_result |
| 19 | + monkeypatch.setattr("sparkrun.core.launcher.launch_inference", mock_launch) |
| 20 | + |
| 21 | + monkeypatch.setattr("sparkrun.core.launcher.post_launch_lifecycle", MagicMock()) |
| 22 | + monkeypatch.setattr("sparkrun.cli._run._resolve_hosts_or_exit", lambda *args, **kwargs: (["localhost"], None)) |
| 23 | + mock_recipe = MagicMock() |
| 24 | + mock_recipe.runtime = "vllm" |
| 25 | + mock_recipe.model = "test-model" |
| 26 | + mock_recipe.validate.return_value = [] |
| 27 | + mock_recipe.mode = "solo" |
| 28 | + mock_recipe.build_config_chain.return_value = {"port": 8000} |
| 29 | + monkeypatch.setattr("sparkrun.cli._run._load_recipe", lambda *args, **kwargs: (mock_recipe, "path", None)) |
| 30 | + |
| 31 | + mock_ret = MagicMock() |
| 32 | + mock_ret.topology = None |
| 33 | + mock_ret.resolve_transfer_config.return_value = (None, None, None, None) |
| 34 | + monkeypatch.setattr("sparkrun.cli._run.resolve_cluster_config", lambda *args, **kwargs: mock_ret) |
| 35 | + |
| 36 | + mock_runtime = MagicMock() |
| 37 | + mock_runtime.runtime_name = "vllm" |
| 38 | + mock_runtime.resolve_container.return_value = "img:latest" |
| 39 | + mock_runtime.validate_recipe.return_value = [] |
| 40 | + monkeypatch.setattr("sparkrun.core.bootstrap.get_runtime", lambda *args, **kwargs: mock_runtime) |
| 41 | + monkeypatch.setattr("sparkrun.cli._run.validate_and_prepare_hosts", lambda *args, **kwargs: (["localhost"], True)) |
| 42 | + monkeypatch.setattr("sparkrun.cli._run._display_vram_estimate", lambda *args, **kwargs: None) |
| 43 | + |
| 44 | + result = runner.invoke(main, ["run", "test-recipe", "--container-name", "custom-cluster-id", "--solo"]) |
| 45 | + |
| 46 | + assert result.exit_code == 0 |
| 47 | + args, kwargs = mock_launch.call_args |
| 48 | + assert kwargs.get("cluster_id_override") == "custom-cluster-id" |
0 commit comments