Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions src/sparkrun/cli/_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@
@click.argument("recipe_name", type=RECIPE_NAME)
@host_options
@recipe_override_options
@click.option(
"--container-name", "cluster_id_override", default=None, hidden=True, help="Override deterministic cluster ID (static container name)"
)
@click.option("--solo", is_flag=True, help="Force single-node mode", hidden=True)
@click.option("--port", type=int, default=None, help="Override serve port")
@click.option("--served-model-name", default=None, help="Override served model name")
Expand Down Expand Up @@ -75,6 +78,7 @@ def run(
hosts,
hosts_file,
cluster_name,
cluster_id_override,
solo,
port,
tensor_parallel,
Expand Down Expand Up @@ -256,6 +260,25 @@ def run(
if labels_override:
cli_executor_opts["labels"] = list(labels_override)

# Also extract executor-specific keys from -o/--option overrides
executor_keys = {
"auto_remove",
"restart_policy",
"privileged",
"gpus",
"ipc",
"shm_size",
"network",
"user",
"security_opt",
"cap_add",
"ulimit",
"devices",
"memory_limit",
}
for key in list(overrides.keys()):
if key in executor_keys:
cli_executor_opts[key] = overrides.pop(key)
# --- Diagnostics setup ---
diag = None
if diagnostics_path:
Expand Down Expand Up @@ -308,6 +331,7 @@ def run(
dashboard=dashboard,
init_port=init_port,
topology=cluster_cfg.topology,
cluster_id_override=cluster_id_override,
executor_config=cli_executor_opts,
extra_docker_opts=list(executor_args) if executor_args else None,
rootless=not rootful,
Expand Down
48 changes: 48 additions & 0 deletions tests/test_name_alias.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from click.testing import CliRunner
from sparkrun.cli import main
from unittest.mock import MagicMock


def test_run_with_name_override(monkeypatch):
    """The hidden --container-name flag must reach launch_inference as cluster_id_override."""
    runner = CliRunner()

    # _run.py imports launch_inference lazily inside the command body, so the
    # patch has to target the defining module, not the CLI module.
    launch_result = MagicMock()
    launch_result.cluster_id = "test-cluster-id"
    launch_result.serve_command = "vllm serve"
    launch_result.runtime_info = {}
    launch_result.rc = 0
    launch_result.head_host = "localhost"
    launch_result.host_list = ["localhost"]
    launch_stub = MagicMock(return_value=launch_result)
    monkeypatch.setattr("sparkrun.core.launcher.launch_inference", launch_stub)

    # Neutralize post-launch side effects and host resolution.
    monkeypatch.setattr("sparkrun.core.launcher.post_launch_lifecycle", MagicMock())
    monkeypatch.setattr("sparkrun.cli._run._resolve_hosts_or_exit", lambda *a, **kw: (["localhost"], None))

    # Minimal recipe stub: solo-mode vllm recipe that passes validation.
    recipe_stub = MagicMock()
    recipe_stub.runtime = "vllm"
    recipe_stub.model = "test-model"
    recipe_stub.validate.return_value = []
    recipe_stub.mode = "solo"
    recipe_stub.build_config_chain.return_value = {"port": 8000}
    monkeypatch.setattr("sparkrun.cli._run._load_recipe", lambda *a, **kw: (recipe_stub, "path", None))

    # Cluster config stub with no topology and no transfer settings.
    cluster_cfg_stub = MagicMock()
    cluster_cfg_stub.topology = None
    cluster_cfg_stub.resolve_transfer_config.return_value = (None, None, None, None)
    monkeypatch.setattr("sparkrun.cli._run.resolve_cluster_config", lambda *a, **kw: cluster_cfg_stub)

    # Runtime stub that resolves a container image and validates cleanly.
    runtime_stub = MagicMock()
    runtime_stub.runtime_name = "vllm"
    runtime_stub.resolve_container.return_value = "img:latest"
    runtime_stub.validate_recipe.return_value = []
    monkeypatch.setattr("sparkrun.core.bootstrap.get_runtime", lambda *a, **kw: runtime_stub)
    monkeypatch.setattr("sparkrun.cli._run.validate_and_prepare_hosts", lambda *a, **kw: (["localhost"], True))
    monkeypatch.setattr("sparkrun.cli._run._display_vram_estimate", lambda *a, **kw: None)

    result = runner.invoke(main, ["run", "test-recipe", "--container-name", "custom-cluster-id", "--solo"])

    assert result.exit_code == 0
    _, launch_kwargs = launch_stub.call_args
    assert launch_kwargs.get("cluster_id_override") == "custom-cluster-id"