Skip to content

Commit f117ce1

Browse files
committed
fix(orchestration): use base64 encoding to prevent quote stripping in docker exec
1 parent 4dfaae0 commit f117ce1

7 files changed

Lines changed: 226 additions & 57 deletions

File tree

src/sparkrun/orchestration/executor.py

Lines changed: 26 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,16 @@
1010
from __future__ import annotations
1111

1212
import logging
13+
import shlex
1314
from abc import ABC, abstractmethod
1415
from dataclasses import dataclass
1516

1617
from scitrera_app_framework.util import ext_parse_bool
1718

19+
from sparkrun.scripts import read_script
20+
from sparkrun.utils import merge_env
21+
from sparkrun.utils.shell import b64_encode_cmd
22+
1823
logger = logging.getLogger(__name__)
1924

2025
# Default executor settings for DGX Spark GPU workloads.
@@ -78,7 +83,9 @@ def from_chain(cls, chain) -> ExecutorConfig:
7883
# still means "not set" and should fall back.
7984
def _get(key):
8085
v = chain.get(key)
81-
return v if v is not None else EXECUTOR_DEFAULTS.get(key)
86+
val = v if v is not None else EXECUTOR_DEFAULTS.get(key)
87+
logger.debug("ExecutorConfig resolve: %s=%r (from chain: %r)", key, val, v)
88+
return val
8289

8390
return cls(
8491
auto_remove=ext_parse_bool(_get("auto_remove")),
@@ -209,8 +216,6 @@ def generate_launch_script(
209216
210217
Absorbs ``scripts.py::generate_container_launch_script``.
211218
"""
212-
from sparkrun.utils import merge_env
213-
from sparkrun.scripts import read_script
214219

215220
all_env = merge_env(nccl_env, env)
216221
cleanup = self.stop_cmd(container_name)
@@ -243,21 +248,23 @@ def generate_exec_serve_script(
243248
244249
Absorbs ``scripts.py::generate_exec_serve_script``.
245250
"""
246-
from sparkrun.scripts import read_script
247251

248252
env_exports = ""
249253
if env:
250254
for key, value in sorted(env.items()):
251-
env_exports += "export %s='%s'; " % (key, value)
255+
env_exports += "export %s=%s; " % (key, shlex.quote(str(value)))
256+
257+
full_cmd = "%s%s" % (env_exports, serve_command)
252258

253-
escaped_cmd = serve_command.replace("'", "'\\''")
254-
full_cmd = "%s%s" % (env_exports, escaped_cmd)
259+
# Base64 encode the command to avoid all bash string-escaping/quoting bugs
260+
# when passing it into `docker exec ... bash -c "..."`
261+
b64_cmd = b64_encode_cmd(full_cmd)
255262

256263
script_name = "exec_serve_detached.sh" if detached else "exec_serve_foreground.sh"
257264
template = read_script(script_name)
258265
return template.format(
259-
container_name=container_name,
260-
full_cmd=full_cmd,
266+
container_name=shlex.quote(container_name),
267+
b64_cmd=b64_cmd,
261268
)
262269

263270
def generate_ray_head_script(
@@ -275,8 +282,6 @@ def generate_ray_head_script(
275282
276283
Absorbs ``scripts.py::generate_ray_head_script``.
277284
"""
278-
from sparkrun.utils import merge_env
279-
from sparkrun.scripts import read_script
280285

281286
all_env = merge_env({"RAY_memory_monitor_refresh_ms": "0"}, nccl_env, env)
282287

@@ -314,8 +319,6 @@ def generate_ray_worker_script(
314319
315320
Absorbs ``scripts.py::generate_ray_worker_script``.
316321
"""
317-
from sparkrun.utils import merge_env
318-
from sparkrun.scripts import read_script
319322

320323
all_env = merge_env({"RAY_memory_monitor_refresh_ms": "0"}, nccl_env, env)
321324

@@ -356,7 +359,6 @@ def generate_node_script(
356359
357360
Absorbs ``base.py::_generate_node_script``.
358361
"""
359-
from sparkrun.utils import merge_env
360362

361363
all_env = merge_env(nccl_env, env)
362364
cleanup = self.stop_cmd(container_name)
@@ -374,19 +376,24 @@ def generate_node_script(
374376
"#!/bin/bash\n"
375377
"set -uo pipefail\n"
376378
"\n"
377-
"echo 'Cleaning up existing container: %(name)s'\n"
379+
"printf 'Cleaning up existing container: %%s\\n' %(name)s\n"
378380
"%(cleanup)s\n"
379381
"\n"
380-
"echo 'Launching %(label)s: %(name)s'\n"
382+
"printf 'Launching %%s: %%s\\n' %(label)s %(name)s\n"
381383
"%(run_cmd)s\n"
382384
"\n"
383385
"# Verify container started\n"
384386
"sleep 1\n"
385387
"if docker ps --format '{{.Names}}' | grep -q '^%(name)s$'; then\n"
386-
" echo 'Container %(name)s launched successfully'\n"
388+
" printf 'Container %%s launched successfully\\n' %(name)s\n"
387389
"else\n"
388-
" echo 'ERROR: Container %(name)s failed to start' >&2\n"
390+
" printf 'ERROR: Container %%s failed to start\\n' %(name)s >&2\n"
389391
" docker logs %(name)s 2>&1 | tail -20 || true\n"
390392
" exit 1\n"
391393
"fi\n"
392-
) % {"name": container_name, "cleanup": cleanup, "run_cmd": run, "label": label}
394+
) % {
395+
"name": shlex.quote(container_name),
396+
"cleanup": cleanup,
397+
"run_cmd": run,
398+
"label": shlex.quote(label),
399+
}

src/sparkrun/orchestration/executor_docker.py

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import shlex
1212

1313
from sparkrun.orchestration.executor import Executor
14+
from sparkrun.utils.shell import b64_wrap_bash
1415

1516
logger = logging.getLogger(__name__)
1617

@@ -28,33 +29,34 @@ def _build_default_opts(self) -> list[str]:
2829
if cfg.gpus:
2930
opts.extend(["--gpus", cfg.gpus])
3031
if cfg.ipc:
31-
opts.append("--ipc=%s" % cfg.ipc)
32+
opts.append("--ipc=%s" % shlex.quote(cfg.ipc))
3233
if cfg.shm_size:
33-
opts.append("--shm-size=%s" % cfg.shm_size)
34+
opts.append("--shm-size=%s" % shlex.quote(cfg.shm_size))
3435
if cfg.network:
35-
opts.append("--network %s" % cfg.network)
36+
logger.debug("DockerExecutor using network: %s", cfg.network)
37+
opts.append("--network=%s" % shlex.quote(cfg.network))
3638
if cfg.user:
3739
if cfg.user == "$SHELL_USER":
3840
opts.extend(["--user", "$(id -u):$(id -g)"])
3941
opts.extend(["-v", "/etc/passwd:/etc/passwd:ro"])
4042
opts.extend(["-v", "/etc/group:/etc/group:ro"])
4143
opts.extend(["-e", "HOME=/tmp"])
4244
else:
43-
opts.extend(["--user", cfg.user])
45+
opts.extend(["--user", shlex.quote(cfg.user)])
4446
if cfg.security_opt:
4547
for opt in cfg.security_opt:
46-
opts.extend(["--security-opt", opt])
48+
opts.extend(["--security-opt", shlex.quote(opt)])
4749
if cfg.cap_add:
4850
for cap in cfg.cap_add:
49-
opts.extend(["--cap-add", cap])
51+
opts.extend(["--cap-add", shlex.quote(cap)])
5052
if cfg.ulimit:
5153
for ul in cfg.ulimit:
52-
opts.extend(["--ulimit", ul])
54+
opts.extend(["--ulimit", shlex.quote(ul)])
5355
if cfg.devices:
5456
for dev in cfg.devices:
55-
opts.extend(["--device", dev])
57+
opts.extend(["--device", shlex.quote(dev)])
5658
if cfg.memory_limit:
57-
opts.append("--memory=%s" % cfg.memory_limit)
59+
opts.append("--memory=%s" % shlex.quote(cfg.memory_limit))
5860

5961
return opts
6062

@@ -84,23 +86,23 @@ def run_cmd(
8486
parts.extend(["--restart", cfg.restart_policy])
8587

8688
if container_name:
87-
parts.extend(["--name", container_name])
89+
parts.extend(["--name", shlex.quote(container_name)])
8890

8991
if env:
9092
for key, value in sorted(env.items()):
91-
parts.extend(["-e", "%s=%s" % (key, value)])
93+
parts.extend(["-e", shlex.quote("%s=%s" % (key, value))])
9294

9395
if volumes:
9496
for host_path, container_path in sorted(volumes.items()):
95-
parts.extend(["-v", "%s:%s" % (host_path, container_path)])
97+
parts.extend(["-v", shlex.quote("%s:%s" % (host_path, container_path))])
9698

9799
if extra_opts:
98100
parts.extend(extra_opts)
99101

100102
parts.append(shlex.quote(image))
101103

102104
if command:
103-
parts.append(command)
105+
parts.extend(["bash", "-c", shlex.quote(b64_wrap_bash(command))])
104106

105107
result = " ".join(parts)
106108

@@ -125,9 +127,8 @@ def exec_cmd(
125127
parts.append("-d")
126128
if env:
127129
for key, value in sorted(env.items()):
128-
parts.extend(["-e", "%s=%s" % (key, value)])
129-
escaped_cmd = command.replace("'", "'\\''")
130-
parts.extend([shlex.quote(container_name), "bash", "-c", "'%s'" % escaped_cmd])
130+
parts.extend(["-e", shlex.quote("%s=%s" % (key, value))])
131+
parts.extend([shlex.quote(container_name), "bash", "-c", shlex.quote(b64_wrap_bash(command))])
131132
return " ".join(parts)
132133

133134
def stop_cmd(self, container_name: str, force: bool = True) -> str:
@@ -149,7 +150,7 @@ def logs_cmd(
149150
parts.append("-f")
150151
if tail is not None:
151152
parts.extend(["--tail", str(tail)])
152-
parts.append(container_name)
153+
parts.append(shlex.quote(container_name))
153154
return " ".join(parts)
154155

155156
def inspect_exists_cmd(self, image: str) -> str:

src/sparkrun/scripts/exec_serve_detached.sh

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
11
#!/bin/bash
# Template script: run a serve command inside an already-running container,
# detached. Placeholders are filled in by Python str.format in
# executor.py::generate_exec_serve_script -- so literal braces must not
# appear in comments here.
set -uo pipefail

printf "Executing serve command in container '%s' (detached)...\n" "{container_name}"
# Show the decoded command for operator visibility before launching it.
echo "--- Command ---"
printf '%s' '{b64_cmd}' | base64 -d --
echo -e "\n---------------"

# Decode the base64 payload into a script file inside the container, then
# launch it with nohup in the background; the PID is recorded so the
# watchdog below can track it.
docker exec {container_name} bash -c "printf '%s' '{b64_cmd}' | base64 -d -- > /tmp/sparkrun_serve.sh && nohup bash --noprofile --norc /tmp/sparkrun_serve.sh > /tmp/sparkrun_serve.log 2>&1 & echo \$! > /tmp/sparkrun_serve.pid"

# Watchdog: when serve process exits, kill sleep infinity (PID 1) so container exits
docker exec -d {container_name} bash -c 'SERVE_PID=$(cat /tmp/sparkrun_serve.pid); while kill -0 $SERVE_PID 2>/dev/null; do sleep 5; done; kill 1'
Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
#!/bin/bash
# Template script: run a serve command inside an already-running container,
# in the foreground (blocks until the command exits). Placeholders are
# filled in by Python str.format in executor.py::generate_exec_serve_script,
# so literal braces must not appear in comments here.
set -uo pipefail

printf "Executing serve command in container '%s'...\n" "{container_name}"
# Show the decoded command for operator visibility before launching it.
echo "--- Command ---"
printf '%s' '{b64_cmd}' | base64 -d --
echo -e "\n---------------"

# Decode the base64 payload into a script file inside the container and run
# it directly, attached to this terminal.
docker exec {container_name} bash -c "printf '%s' '{b64_cmd}' | base64 -d -- > /tmp/sparkrun_serve.sh && bash --noprofile --norc /tmp/sparkrun_serve.sh"

src/sparkrun/utils/shell.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,32 @@
55

66
from __future__ import annotations
77

8+
import base64
89
import re
910
import shlex
1011

1112

13+
def b64_encode_cmd(cmd: str) -> str:
    """Encode *cmd* as UTF-8 and return its base64 text representation.

    Useful when passing complex commands (e.g., with nested quotes or JSON)
    across SSH boundaries or into ``docker exec``, where ordinary shell
    escaping is fragile.
    """
    raw_bytes = cmd.encode("utf-8")
    return base64.b64encode(raw_bytes).decode("utf-8")
20+
21+
22+
def b64_wrap_bash(cmd: str) -> str:
    """Return a shell one-liner that decodes *cmd* from base64 and runs it.

    The produced string has the shape
    ``printf '%s' <b64> | base64 -d -- | bash --noprofile --norc``, which
    sidesteps all quoting/escaping issues in the transported command.
    """
    encoded = b64_encode_cmd(cmd)
    # printf (rather than echo) is safe for payloads starting with a dash;
    # the '--' keeps base64 from parsing the payload as options; and
    # --noprofile/--norc give the decoded script a clean bash environment.
    return "printf '%s' '" + encoded + "' | base64 -d -- | bash --noprofile --norc"
32+
33+
1234
def quote(value: str) -> str:
1335
"""Return a shell-safe quoted version of *value*.
1436

0 commit comments

Comments (0)