diff --git a/torchx/components/utils.py b/torchx/components/utils.py index e64069a14..7709f1553 100644 --- a/torchx/components/utils.py +++ b/torchx/components/utils.py @@ -104,19 +104,26 @@ def sh( entrypoint: the entrypoint to use for the command (defaults to sh) """ - escaped_args = " ".join(shlex.quote(arg) for arg in args) + escaped_args = [shlex.quote(arg) for arg in args] if env is None: env = {} env.setdefault("LOGLEVEL", os.getenv("LOGLEVEL", "WARNING")) + if entrypoint is not None: + resolved_entrypoint = entrypoint + resolved_args = escaped_args + else: + resolved_entrypoint = "sh" + resolved_args = ["-c", " ".join(escaped_args)] + return specs.AppDef( name="sh", roles=[ specs.Role( name="sh", image=image, - entrypoint=entrypoint or "sh", - args=["-c", escaped_args], + entrypoint=resolved_entrypoint, + args=resolved_args, num_replicas=num_replicas, resource=specs.resource(cpu=cpu, gpu=gpu, memMB=memMB, h=h), env=env,