Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions DEVELOPERS.md
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,37 @@ from sparkrun.orchestration.ssh import run_remote_script
result = run_remote_script(host, script_string, timeout=120, **ssh_kwargs)
```

### Shell Execution & Security

Sparkrun frequently generates bash scripts and Docker commands dynamically, interpolating user-provided inputs (such as container names, image names, or environment variables). To prevent shell injection and to handle spaces and special characters correctly, you MUST adhere to the following rules:

1. **Python `shlex.quote`**: When building commands in Python (e.g., `docker run` flags), wrap all interpolated values with `shlex.quote`:
```python
import shlex
cmd = f"docker run --name {shlex.quote(container_name)} {shlex.quote(image)}"
```

2. **Base64 Command Wrapping**: When passing complex commands (especially those with nested quotes or JSON) into `bash -c` or over SSH, use the `b64_encode_cmd` and `b64_wrap_bash` utilities from `sparkrun.utils.shell`:
```python
from sparkrun.utils.shell import b64_encode_cmd
b64_cmd = b64_encode_cmd("vllm serve --hf-overrides '{\"rope\": \"yarn\"}'")
# The bash script should decode and execute this:
# printf '%s' '{b64_cmd}' | base64 -d -- | bash --noprofile --norc
```

3. **Use `printf` instead of `echo`**: Inside generated bash scripts (`.sh` files), never use `echo` to output interpolated Python variables. If a value begins with a hyphen (e.g., `-n`), `echo` may treat it as one of its own flags, and some shells' `echo` also interprets backslash escapes in the data. Instead, use `printf` with an explicit format string and pass the value as an argument:
```bash
# DANGEROUS: echo "Launching {container_name}"
# SAFE:
printf "Launching %%s\n" "{container_name}"
```
*Note: In Python string formatting (used to populate the scripts), `%` must be escaped as `%%`.*

4. **Environment Variables**: When exporting variables in generated bash scripts, quote the interpolated value with `shlex.quote` on the Python side and add no extra quotes in the bash template — `shlex.quote` already produces a fully quoted shell token, so wrapping it again would embed literal quote characters in the value:
```python
env_lines.append(f"export MY_VAR={shlex.quote(val)}")
```

### Runtime Architecture

All runtimes extend `RuntimePlugin` (`runtimes/base.py`):
Expand Down
11 changes: 3 additions & 8 deletions src/sparkrun/cli/_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,9 +233,6 @@ def _render_install_script(slug, recipe_yaml, cluster_yaml, cluster_name, user_h
"""Render user-level install script (no sudo needed)."""
service_dir = "%s/.config/sparkrun/services/%s" % (user_home, slug)
clusters_dir = "%s/.config/sparkrun/clusters" % user_home
# Escape single quotes in YAML content for heredoc safety
recipe_yaml_escaped = recipe_yaml.replace("'", "'\\''")
cluster_yaml_escaped = cluster_yaml.replace("'", "'\\''")
return textwrap.dedent("""\
#!/usr/bin/env bash
set -euo pipefail
Expand All @@ -262,16 +259,14 @@ def _render_install_script(slug, recipe_yaml, cluster_yaml, cluster_name, user_h
service_dir=service_dir,
clusters_dir=clusters_dir,
cluster_name=cluster_name,
recipe_yaml=recipe_yaml_escaped,
cluster_yaml=cluster_yaml_escaped,
recipe_yaml=recipe_yaml,
cluster_yaml=cluster_yaml,
)


def _render_sudo_install_script(slug, unit_contents):
"""Render sudo install script (writes unit file, enables service)."""
unit_path = "/etc/systemd/system/sparkrun-%s.service" % slug
# Escape single quotes in unit contents for heredoc safety
unit_escaped = unit_contents.replace("'", "'\\''")
return textwrap.dedent("""\
#!/usr/bin/env bash
set -euo pipefail
Expand All @@ -289,7 +284,7 @@ def _render_sudo_install_script(slug, unit_contents):
""").format(
unit_path=unit_path,
slug=slug,
unit_contents=unit_escaped,
unit_contents=unit_contents,
)


Expand Down
23 changes: 23 additions & 0 deletions src/sparkrun/cli/_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
@click.argument("recipe_name", type=RECIPE_NAME)
@host_options
@recipe_override_options
@click.option("--name", "cluster_id_override", default=None, help="Override deterministic cluster ID (static container name)")
@click.option("--solo", is_flag=True, help="Force single-node mode", hidden=True)
@click.option("--port", type=int, default=None, help="Override serve port")
@click.option("--served-model-name", default=None, help="Override served model name")
Expand Down Expand Up @@ -75,6 +76,7 @@ def run(
hosts,
hosts_file,
cluster_name,
cluster_id_override,
solo,
port,
tensor_parallel,
Expand Down Expand Up @@ -256,6 +258,26 @@ def run(
if labels_override:
cli_executor_opts["labels"] = list(labels_override)

# Also extract executor-specific keys from -o/--option overrides
executor_keys = {
"auto_remove",
"restart_policy",
"privileged",
"gpus",
"ipc",
"shm_size",
"network",
"user",
"security_opt",
"cap_add",
"ulimit",
"devices",
"memory_limit",
}
for key in list(overrides.keys()):
if key in executor_keys:
cli_executor_opts[key] = overrides.pop(key)

# --- Diagnostics setup ---
diag = None
if diagnostics_path:
Expand Down Expand Up @@ -308,6 +330,7 @@ def run(
dashboard=dashboard,
init_port=init_port,
topology=cluster_cfg.topology,
cluster_id_override=cluster_id_override,
executor_config=cli_executor_opts,
extra_docker_opts=list(executor_args) if executor_args else None,
rootless=not rootful,
Expand Down
3 changes: 2 additions & 1 deletion src/sparkrun/core/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def launch_inference(
dashboard: bool = False,
init_port: int | None = None,
topology: str | None = None,
cluster_id_override: str | None = None,
# Executor config (dict for config chain layering)
executor_config: dict | None = None,
extra_docker_opts: list[str] | None = None,
Expand Down Expand Up @@ -179,7 +180,7 @@ def launch_inference(
serve_port = int(config_chain.get("port") or 8000)

# Derive deterministic cluster_id from recipe + (trimmed) hosts
cluster_id = generate_cluster_id(recipe, host_list, overrides=overrides)
cluster_id = cluster_id_override or generate_cluster_id(recipe, host_list, overrides=overrides)

# Resolve container image
container_image = runtime.resolve_container(recipe, overrides)
Expand Down
7 changes: 4 additions & 3 deletions src/sparkrun/orchestration/docker.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,10 @@ def docker_exec_cmd(
parts.append("-d")
if env:
for key, value in sorted(env.items()):
parts.extend(["-e", f"{key}={value}"])
escaped_cmd = command.replace("'", "'\\''")
parts.extend([shlex.quote(container_name), "bash", "-c", "'%s'" % escaped_cmd])
parts.extend(["-e", shlex.quote(f"{key}={value}")])
from sparkrun.utils.shell import b64_wrap_bash

parts.extend([shlex.quote(container_name), "bash", "-c", shlex.quote(b64_wrap_bash(command))])
return " ".join(parts)


Expand Down
45 changes: 26 additions & 19 deletions src/sparkrun/orchestration/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,16 @@
from __future__ import annotations

import logging
import shlex
from abc import ABC, abstractmethod
from dataclasses import dataclass

from scitrera_app_framework.util import ext_parse_bool

from sparkrun.scripts import read_script
from sparkrun.utils import merge_env
from sparkrun.utils.shell import b64_encode_cmd

logger = logging.getLogger(__name__)

# Default executor settings for DGX Spark GPU workloads.
Expand Down Expand Up @@ -82,7 +87,9 @@ def from_chain(cls, chain) -> ExecutorConfig:
# still means "not set" and should fall back.
def _get(key):
v = chain.get(key)
return v if v is not None else EXECUTOR_DEFAULTS.get(key)
val = v if v is not None else EXECUTOR_DEFAULTS.get(key)
logger.debug("ExecutorConfig resolve: %s=%r (from chain: %r)", key, val, v)
return val

return cls(
auto_remove=ext_parse_bool(_get("auto_remove")),
Expand Down Expand Up @@ -214,8 +221,6 @@ def generate_launch_script(

Absorbs ``scripts.py::generate_container_launch_script``.
"""
from sparkrun.utils import merge_env
from sparkrun.scripts import read_script

all_env = merge_env(nccl_env, env)
cleanup = self.stop_cmd(container_name)
Expand Down Expand Up @@ -248,21 +253,23 @@ def generate_exec_serve_script(

Absorbs ``scripts.py::generate_exec_serve_script``.
"""
from sparkrun.scripts import read_script

env_exports = ""
if env:
for key, value in sorted(env.items()):
env_exports += "export %s='%s'; " % (key, value)
env_exports += "export %s=%s; " % (key, shlex.quote(str(value)))

full_cmd = "%s%s" % (env_exports, serve_command)

escaped_cmd = serve_command.replace("'", "'\\''")
full_cmd = "%s%s" % (env_exports, escaped_cmd)
# Base64 encode the command to avoid all bash string-escaping/quoting bugs
# when passing it into `docker exec ... bash -c "..."`
b64_cmd = b64_encode_cmd(full_cmd)

script_name = "exec_serve_detached.sh" if detached else "exec_serve_foreground.sh"
template = read_script(script_name)
return template.format(
container_name=container_name,
full_cmd=full_cmd,
container_name=shlex.quote(container_name),
b64_cmd=b64_cmd,
)

def generate_ray_head_script(
Expand All @@ -281,8 +288,6 @@ def generate_ray_head_script(

Absorbs ``scripts.py::generate_ray_head_script``.
"""
from sparkrun.utils import merge_env
from sparkrun.scripts import read_script

all_env = merge_env({"RAY_memory_monitor_refresh_ms": "0"}, nccl_env, env)

Expand Down Expand Up @@ -322,8 +327,6 @@ def generate_ray_worker_script(

Absorbs ``scripts.py::generate_ray_worker_script``.
"""
from sparkrun.utils import merge_env
from sparkrun.scripts import read_script

all_env = merge_env({"RAY_memory_monitor_refresh_ms": "0"}, nccl_env, env)

Expand Down Expand Up @@ -365,7 +368,6 @@ def generate_node_script(

Absorbs ``base.py::_generate_node_script``.
"""
from sparkrun.utils import merge_env

all_env = merge_env(nccl_env, env)
cleanup = self.stop_cmd(container_name)
Expand All @@ -383,19 +385,24 @@ def generate_node_script(
"#!/bin/bash\n"
"set -uo pipefail\n"
"\n"
"echo 'Cleaning up existing container: %(name)s'\n"
"printf 'Cleaning up existing container: %%s\\n' %(name)s\n"
"%(cleanup)s\n"
"\n"
"echo 'Launching %(label)s: %(name)s'\n"
"printf 'Launching %%s: %%s\\n' %(label)s %(name)s\n"
"%(run_cmd)s\n"
"\n"
"# Verify container started\n"
"sleep 1\n"
"if docker ps --format '{{.Names}}' | grep -q '^%(name)s$'; then\n"
" echo 'Container %(name)s launched successfully'\n"
" printf 'Container %%s launched successfully\\n' %(name)s\n"
"else\n"
" echo 'ERROR: Container %(name)s failed to start' >&2\n"
" printf 'ERROR: Container %%s failed to start\\n' %(name)s >&2\n"
" docker logs %(name)s 2>&1 | tail -20 || true\n"
" exit 1\n"
"fi\n"
) % {"name": container_name, "cleanup": cleanup, "run_cmd": run, "label": label}
) % {
"name": shlex.quote(container_name),
"cleanup": cleanup,
"run_cmd": run,
"label": shlex.quote(label),
}
43 changes: 21 additions & 22 deletions src/sparkrun/orchestration/executor_docker.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,8 @@
from __future__ import annotations

import logging
import shlex

from sparkrun.orchestration.executor import Executor
from sparkrun.utils.shell import b64_wrap_bash, quote

logger = logging.getLogger(__name__)

Expand All @@ -28,31 +27,32 @@ def _build_default_opts(self) -> list[str]:
if cfg.gpus:
opts.extend(["--gpus", cfg.gpus])
if cfg.ipc:
opts.append("--ipc=%s" % cfg.ipc)
opts.append("--ipc=%s" % quote(cfg.ipc))
if cfg.shm_size:
opts.append("--shm-size=%s" % cfg.shm_size)
opts.append("--shm-size=%s" % quote(cfg.shm_size))
if cfg.network:
opts.append("--network %s" % cfg.network)
logger.debug("DockerExecutor using network: %s", cfg.network)
opts.append("--network=%s" % quote(cfg.network))
if cfg.user:
if cfg.user == "$SHELL_USER":
opts.extend(["--user", "$(id -u):$(id -g)"])
opts.extend(["-v", "/etc/passwd:/etc/passwd:ro"])
opts.extend(["-v", "/etc/group:/etc/group:ro"])
opts.extend(["-e", "HOME=/tmp"])
else:
opts.extend(["--user", cfg.user])
opts.extend(["--user", quote(cfg.user)])
if cfg.security_opt:
for opt in cfg.security_opt:
opts.extend(["--security-opt", opt])
opts.extend(["--security-opt", quote(opt)])
if cfg.cap_add:
for cap in cfg.cap_add:
opts.extend(["--cap-add", cap])
opts.extend(["--cap-add", quote(cap)])
if cfg.ulimit:
for ul in cfg.ulimit:
opts.extend(["--ulimit", ul])
opts.extend(["--ulimit", quote(ul)])
if cfg.devices:
for dev in cfg.devices:
opts.extend(["--device", dev])
opts.extend(["--device", quote(dev)])
if cfg.memory_limit:
opts.append("--memory=%s" % cfg.memory_limit)
if cfg.labels:
Expand Down Expand Up @@ -87,24 +87,24 @@ def run_cmd(
parts.extend(["--restart", cfg.restart_policy])

if container_name:
parts.extend(["--name", container_name])
parts.extend(["--name", quote(container_name)])

if env:
for key, value in sorted(env.items()):
parts.extend(["-e", "%s=%s" % (key, value)])
parts.extend(["-e", quote("%s=%s" % (key, value))])

if volumes:
for host_path, container_path in sorted(volumes.items()):
parts.extend(["-v", "%s:%s" % (host_path, container_path)])
parts.extend(["-v", quote("%s:%s" % (host_path, container_path))])

if extra_opts:
for opt in extra_opts:
parts.extend(shlex.quote(token) for token in shlex.split(opt))

parts.append(shlex.quote(image))
parts.append(quote(image))

if command:
parts.append(command)
parts.extend(["bash", "-c", quote(b64_wrap_bash(command))])

result = " ".join(parts)

Expand All @@ -129,14 +129,13 @@ def exec_cmd(
parts.append("-d")
if env:
for key, value in sorted(env.items()):
parts.extend(["-e", "%s=%s" % (key, value)])
escaped_cmd = command.replace("'", "'\\''")
parts.extend([shlex.quote(container_name), "bash", "-c", "'%s'" % escaped_cmd])
parts.extend(["-e", quote("%s=%s" % (key, value))])
parts.extend([quote(container_name), "bash", "-c", quote(b64_wrap_bash(command))])
return " ".join(parts)

def stop_cmd(self, container_name: str, force: bool = True) -> str:
"""Generate a docker stop/rm command string."""
quoted = shlex.quote(container_name)
quoted = quote(container_name)
if force:
return "docker rm -f %s 2>/dev/null || true" % quoted
return "docker stop %s 2>/dev/null || true" % quoted
Expand All @@ -153,13 +152,13 @@ def logs_cmd(
parts.append("-f")
if tail is not None:
parts.extend(["--tail", str(tail)])
parts.append(container_name)
parts.append(quote(container_name))
return " ".join(parts)

def inspect_exists_cmd(self, image: str) -> str:
"""Generate a command to check if a docker image exists locally."""
return "docker image inspect %s >/dev/null 2>&1" % shlex.quote(image)
return "docker image inspect %s >/dev/null 2>&1" % quote(image)

def pull_cmd(self, image: str) -> str:
"""Generate a ``docker pull`` command."""
return "docker pull %s" % shlex.quote(image)
return "docker pull %s" % quote(image)
Loading