Skip to content

Commit 97ee5a7

Browse files
authored
Merge pull request #122 from jlapenna/feat/name-alias
feat/name alias
2 parents ca57d92 + 40fbf65 commit 97ee5a7

2 files changed

Lines changed: 72 additions & 0 deletions

File tree

src/sparkrun/cli/_run.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,9 @@
3535
@click.argument("recipe_name", type=RECIPE_NAME)
3636
@host_options
3737
@recipe_override_options
38+
@click.option(
39+
"--container-name", "cluster_id_override", default=None, hidden=True, help="Override deterministic cluster ID (static container name)"
40+
)
3841
@click.option("--solo", is_flag=True, help="Force single-node mode", hidden=True)
3942
@click.option("--port", type=int, default=None, help="Override serve port")
4043
@click.option("--served-model-name", default=None, help="Override served model name")
@@ -90,6 +93,7 @@ def run(
9093
hosts,
9194
hosts_file,
9295
cluster_name,
96+
cluster_id_override,
9397
solo,
9498
port,
9599
tensor_parallel,
@@ -271,6 +275,25 @@ def run(
271275
if labels_override:
272276
cli_executor_opts["labels"] = list(labels_override)
273277

278+
# Also extract executor-specific keys from -o/--option overrides
279+
executor_keys = {
280+
"auto_remove",
281+
"restart_policy",
282+
"privileged",
283+
"gpus",
284+
"ipc",
285+
"shm_size",
286+
"network",
287+
"user",
288+
"security_opt",
289+
"cap_add",
290+
"ulimit",
291+
"devices",
292+
"memory_limit",
293+
}
294+
for key in list(overrides.keys()):
295+
if key in executor_keys:
296+
cli_executor_opts[key] = overrides.pop(key)
274297
# --- Diagnostics setup ---
275298
diag = None
276299
if diagnostics_path:
@@ -323,6 +346,7 @@ def run(
323346
dashboard=dashboard,
324347
init_port=init_port,
325348
topology=cluster_cfg.topology,
349+
cluster_id_override=cluster_id_override,
326350
executor_config=cli_executor_opts,
327351
extra_docker_opts=list(executor_args) if executor_args else None,
328352
rootless=not rootful,

tests/test_name_alias.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
from click.testing import CliRunner
2+
from sparkrun.cli import main
3+
from unittest.mock import MagicMock
4+
5+
6+
def test_run_with_name_override(monkeypatch):
7+
runner = CliRunner()
8+
9+
# Mock launch_inference in its original module, since _run.py imports it locally
10+
mock_launch = MagicMock()
11+
mock_result = MagicMock()
12+
mock_result.cluster_id = "test-cluster-id"
13+
mock_result.serve_command = "vllm serve"
14+
mock_result.runtime_info = {}
15+
mock_result.rc = 0
16+
mock_result.head_host = "localhost"
17+
mock_result.host_list = ["localhost"]
18+
mock_launch.return_value = mock_result
19+
monkeypatch.setattr("sparkrun.core.launcher.launch_inference", mock_launch)
20+
21+
monkeypatch.setattr("sparkrun.core.launcher.post_launch_lifecycle", MagicMock())
22+
monkeypatch.setattr("sparkrun.cli._run._resolve_hosts_or_exit", lambda *args, **kwargs: (["localhost"], None))
23+
mock_recipe = MagicMock()
24+
mock_recipe.runtime = "vllm"
25+
mock_recipe.model = "test-model"
26+
mock_recipe.validate.return_value = []
27+
mock_recipe.mode = "solo"
28+
mock_recipe.build_config_chain.return_value = {"port": 8000}
29+
monkeypatch.setattr("sparkrun.cli._run._load_recipe", lambda *args, **kwargs: (mock_recipe, "path", None))
30+
31+
mock_ret = MagicMock()
32+
mock_ret.topology = None
33+
mock_ret.resolve_transfer_config.return_value = (None, None, None, None)
34+
monkeypatch.setattr("sparkrun.cli._run.resolve_cluster_config", lambda *args, **kwargs: mock_ret)
35+
36+
mock_runtime = MagicMock()
37+
mock_runtime.runtime_name = "vllm"
38+
mock_runtime.resolve_container.return_value = "img:latest"
39+
mock_runtime.validate_recipe.return_value = []
40+
monkeypatch.setattr("sparkrun.core.bootstrap.get_runtime", lambda *args, **kwargs: mock_runtime)
41+
monkeypatch.setattr("sparkrun.cli._run.validate_and_prepare_hosts", lambda *args, **kwargs: (["localhost"], True))
42+
monkeypatch.setattr("sparkrun.cli._run._display_vram_estimate", lambda *args, **kwargs: None)
43+
44+
result = runner.invoke(main, ["run", "test-recipe", "--container-name", "custom-cluster-id", "--solo"])
45+
46+
assert result.exit_code == 0
47+
args, kwargs = mock_launch.call_args
48+
assert kwargs.get("cluster_id_override") == "custom-cluster-id"

0 commit comments

Comments
 (0)