Skip to content

Commit 25bdc52

Browse files
d4l3kfacebook-github-bot
authored andcommitted
torchx/workspace: workspace_opts + standardize dryrun_push_images and push_images (#619)
Summary: This makes a few changes intended to make it easier to swap out the workspace implementation for schedulers that support multiple image runtimes. I.e. LSF and Slurm support Docker, Singularity, native, etc * This adds a new `workspace_opts` so we can keep runopts attached to the workspace * This standardizes the `dryrun_push_images` and `push_images` that were used in DockerWorkspace so we can swap the workspace without running into any type issues. Pull Request resolved: #619 Test Plan: Refactor -- no functional changes. Existing tests/CI should suffice Reviewed By: kurman Differential Revision: D40446676 Pulled By: d4l3k fbshipit-source-id: 763eeaddb14282bf8bc189d979de2277fcd3b7d9
1 parent da19c5f commit 25bdc52

21 files changed

+149
-84
lines changed

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ Milestone: https://github.com/pytorch/torchx/milestones/3
7777
* Slurm jobs will by default launch in the current working directory to match `local_cwd` and workspace behavior. #372
7878
* Replicas now have their own log files and can be accessed programmatically. #373
7979
* Support for `comment`, `mail-user` and `constraint` fields. #391
80-
* Workspace support (prototype) - Slurm jobs can now be launched in isolated experiment directories. #416
80+
* WorkspaceMixin support (prototype) - Slurm jobs can now be launched in isolated experiment directories. #416
8181
* Kubernetes
8282
* Support for running jobs under service accounts. #408
8383
* Support for specifying instance types. #433

docs/source/workspace.rst

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ torchx.workspace
66

77
.. currentmodule:: torchx.workspace
88

9-
.. autoclass:: Workspace
9+
.. autoclass:: WorkspaceMixin
1010
:members:
1111

1212
.. autofunction:: walk_workspace
@@ -18,7 +18,7 @@ torchx.workspace.docker_workspace
1818
.. automodule:: torchx.workspace.docker_workspace
1919
.. currentmodule:: torchx.workspace.docker_workspace
2020

21-
.. autoclass:: DockerWorkspace
21+
.. autoclass:: DockerWorkspaceMixin
2222
:members:
2323
:private-members: _update_app_images, _push_images
2424

@@ -29,7 +29,7 @@ torchx.workspace.dir_workspace
2929
.. automodule:: torchx.workspace.dir_workspace
3030
.. currentmodule:: torchx.workspace.dir_workspace
3131

32-
.. autoclass:: DirWorkspace
32+
.. autoclass:: DirWorkspaceMixin
3333
:members:
3434

3535
.. fbcode::
@@ -40,6 +40,6 @@ torchx.workspace.dir_workspace
4040
.. automodule:: torchx.workspace.fb.jetter_workspace
4141
.. currentmodule:: torchx.workspace.fb.jetter_workspace
4242

43-
.. autoclass:: JetterWorkspace
43+
.. autoclass:: JetterWorkspaceMixin
4444
:members:
4545
:show-inheritance:

torchx/runner/api.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
from torchx.tracker.api import tracker_config_env_var_name, TRACKER_ENV_VAR_NAME
3535

3636
from torchx.util.types import none_throws
37-
from torchx.workspace.api import Workspace
37+
from torchx.workspace.api import WorkspaceMixin
3838

3939
from .config import get_config, get_configs
4040

@@ -363,7 +363,7 @@ def dryrun(
363363
with log_event("dryrun", scheduler, runcfg=json.dumps(cfg) if cfg else None):
364364
sched = self._scheduler(scheduler)
365365

366-
if workspace and isinstance(sched, Workspace):
366+
if workspace and isinstance(sched, WorkspaceMixin):
367367
role = app.roles[0]
368368
old_img = role.image
369369

torchx/runner/test/api_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from torchx.specs.finder import ComponentNotFoundException
2727

2828
from torchx.util.types import none_throws
29-
from torchx.workspace import Workspace
29+
from torchx.workspace import WorkspaceMixin
3030

3131

3232
GET_SCHEDULER_FACTORIES = "torchx.runner.api.get_scheduler_factories"
@@ -293,9 +293,9 @@ def test_dryrun_setup_trackers_as_env_variable(self, _) -> None:
293293
)
294294

295295
def test_dryrun_with_workspace(self, _) -> None:
296-
class TestScheduler(Scheduler, Workspace):
296+
class TestScheduler(WorkspaceMixin[None], Scheduler):
297297
def __init__(self, build_new_img: bool):
298-
Scheduler.__init__(self, backend="ignored", session_name="ignored")
298+
super().__init__(backend="ignored", session_name="ignored")
299299
self.build_new_img = build_new_img
300300

301301
def schedule(self, dryrun_info: AppDryRunInfo) -> str:

torchx/runner/test/config_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def log_iter(
6262
def list(self) -> List[ListAppResponse]:
6363
raise NotImplementedError()
6464

65-
def run_opts(self) -> runopts:
65+
def _run_opts(self) -> runopts:
6666
opts = runopts()
6767
opts.add(
6868
"i",

torchx/schedulers/api.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
RoleStatus,
2323
runopts,
2424
)
25-
from torchx.workspace.api import Workspace
25+
from torchx.workspace.api import WorkspaceMixin
2626

2727

2828
DAYS_IN_2_WEEKS = 14
@@ -138,7 +138,7 @@ def submit(
138138
"""
139139
if workspace:
140140
sched = self
141-
assert isinstance(sched, Workspace)
141+
assert isinstance(sched, WorkspaceMixin)
142142
role = app.roles[0]
143143
sched.build_workspace_and_update_role(role, workspace, cfg)
144144
dryrun_info = self.submit_dryrun(app, cfg)
@@ -189,6 +189,12 @@ def run_opts(self) -> runopts:
189189
Returns the run configuration options expected by the scheduler.
190190
Basically a ``--help`` for the ``run`` API.
191191
"""
192+
opts = self._run_opts()
193+
if isinstance(self, WorkspaceMixin):
194+
opts.update(self.workspace_opts())
195+
return opts
196+
197+
def _run_opts(self) -> runopts:
192198
return runopts()
193199

194200
@abc.abstractmethod

torchx/schedulers/aws_batch_scheduler.py

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,11 @@
4242
from typing import (
4343
Any,
4444
Callable,
45+
cast,
4546
Dict,
4647
Iterable,
4748
List,
49+
Mapping,
4850
Optional,
4951
Tuple,
5052
TYPE_CHECKING,
@@ -67,13 +69,14 @@
6769
AppDef,
6870
AppState,
6971
BindMount,
72+
CfgVal,
7073
DeviceMount,
7174
macros,
7275
Role,
7376
runopts,
7477
VolumeMount,
7578
)
76-
from torchx.workspace.docker_workspace import DockerWorkspace
79+
from torchx.workspace.docker_workspace import DockerWorkspaceMixin
7780
from typing_extensions import TypedDict
7881

7982
if TYPE_CHECKING:
@@ -246,7 +249,7 @@ class AWSBatchOpts(TypedDict, total=False):
246249
priority: Optional[int]
247250

248251

249-
class AWSBatchScheduler(Scheduler[AWSBatchOpts], DockerWorkspace):
252+
class AWSBatchScheduler(DockerWorkspaceMixin, Scheduler[AWSBatchOpts]):
250253
"""
251254
AWSBatchScheduler is a TorchX scheduling interface to AWS Batch.
252255
@@ -308,8 +311,7 @@ def __init__(
308311
log_client: Optional[Any] = None,
309312
docker_client: Optional["DockerClient"] = None,
310313
) -> None:
311-
Scheduler.__init__(self, "aws_batch", session_name)
312-
DockerWorkspace.__init__(self, docker_client)
314+
super().__init__("aws_batch", session_name, docker_client=docker_client)
313315

314316
# pyre-fixme[4]: Attribute annotation cannot be `Any`.
315317
self.__client = client
@@ -335,7 +337,7 @@ def schedule(self, dryrun_info: AppDryRunInfo[BatchJob]) -> str:
335337
assert cfg is not None, f"{dryrun_info} missing cfg"
336338

337339
images_to_push = dryrun_info.request.images_to_push
338-
self._push_images(images_to_push)
340+
self.push_images(images_to_push)
339341

340342
req = dryrun_info.request
341343
self._client.register_job_definition(**req.job_def)
@@ -370,7 +372,7 @@ def _submit_dryrun(self, app: AppDef, cfg: AWSBatchOpts) -> AppDryRunInfo[BatchJ
370372
name = make_unique(f"{app.name}{name_suffix}")
371373

372374
# map any local images to the remote image
373-
images_to_push = self._update_app_images(app, cfg.get("image_repo"))
375+
images_to_push = self.dryrun_push_images(app, cast(Mapping[str, CfgVal], cfg))
374376

375377
nodes = []
376378

@@ -450,14 +452,9 @@ def _cancel_existing(self, app_id: str) -> None:
450452
reason="killed via torchx CLI",
451453
)
452454

453-
def run_opts(self) -> runopts:
455+
def _run_opts(self) -> runopts:
454456
opts = runopts()
455457
opts.add("queue", type_=str, help="queue to schedule job in", required=True)
456-
opts.add(
457-
"image_repo",
458-
type_=str,
459-
help="The image repository to use when pushing patched images, must have push access. Ex: example.com/your/container",
460-
)
461458
opts.add(
462459
"share_id",
463460
type_=str,

torchx/schedulers/docker_scheduler.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
runopts,
3939
VolumeMount,
4040
)
41-
from torchx.workspace.docker_workspace import DockerWorkspace
41+
from torchx.workspace.docker_workspace import DockerWorkspaceMixin
4242
from typing_extensions import TypedDict
4343

4444

@@ -76,7 +76,7 @@ def __repr__(self) -> str:
7676
return str(self)
7777

7878

79-
LABEL_VERSION: str = DockerWorkspace.LABEL_VERSION
79+
LABEL_VERSION: str = DockerWorkspaceMixin.LABEL_VERSION
8080
LABEL_APP_ID: str = "torchx.pytorch.org/app-id"
8181
LABEL_ROLE_NAME: str = "torchx.pytorch.org/role-name"
8282
LABEL_REPLICA_ID: str = "torchx.pytorch.org/replica-id"
@@ -98,7 +98,7 @@ class DockerOpts(TypedDict, total=False):
9898
copy_env: Optional[List[str]]
9999

100100

101-
class DockerScheduler(Scheduler[DockerOpts], DockerWorkspace):
101+
class DockerScheduler(DockerWorkspaceMixin, Scheduler[DockerOpts]):
102102
"""
103103
DockerScheduler is a TorchX scheduling interface to Docker.
104104
@@ -143,8 +143,7 @@ class DockerScheduler(Scheduler[DockerOpts], DockerWorkspace):
143143
"""
144144

145145
def __init__(self, session_name: str) -> None:
146-
Scheduler.__init__(self, "docker", session_name)
147-
DockerWorkspace.__init__(self)
146+
super().__init__("docker", session_name)
148147

149148
def _ensure_network(self) -> None:
150149
import filelock
@@ -346,7 +345,7 @@ def _cancel_existing(self, app_id: str) -> None:
346345
for container in containers:
347346
container.stop()
348347

349-
def run_opts(self) -> runopts:
348+
def _run_opts(self) -> runopts:
350349
opts = runopts()
351350
opts.add(
352351
"copy_env",

torchx/schedulers/kubernetes_scheduler.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,17 @@
3333
import warnings
3434
from dataclasses import dataclass
3535
from datetime import datetime
36-
from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple, TYPE_CHECKING
36+
from typing import (
37+
Any,
38+
cast,
39+
Dict,
40+
Iterable,
41+
List,
42+
Mapping,
43+
Optional,
44+
Tuple,
45+
TYPE_CHECKING,
46+
)
3747

3848
import torchx
3949
import yaml
@@ -51,6 +61,7 @@
5161
AppDef,
5262
AppState,
5363
BindMount,
64+
CfgVal,
5465
DeviceMount,
5566
macros,
5667
ReplicaState,
@@ -61,7 +72,7 @@
6172
runopts,
6273
VolumeMount,
6374
)
64-
from torchx.workspace.docker_workspace import DockerWorkspace
75+
from torchx.workspace.docker_workspace import DockerWorkspaceMixin
6576
from typing_extensions import TypedDict
6677

6778

@@ -441,7 +452,7 @@ class KubernetesOpts(TypedDict, total=False):
441452
priority_class: Optional[str]
442453

443454

444-
class KubernetesScheduler(Scheduler[KubernetesOpts], DockerWorkspace):
455+
class KubernetesScheduler(DockerWorkspaceMixin, Scheduler[KubernetesOpts]):
445456
"""
446457
KubernetesScheduler is a TorchX scheduling interface to Kubernetes.
447458
@@ -535,8 +546,7 @@ def __init__(
535546
client: Optional["ApiClient"] = None,
536547
docker_client: Optional["DockerClient"] = None,
537548
) -> None:
538-
Scheduler.__init__(self, "kubernetes", session_name)
539-
DockerWorkspace.__init__(self, docker_client)
549+
super().__init__("kubernetes", session_name, docker_client=docker_client)
540550

541551
self._client = client
542552

@@ -575,7 +585,7 @@ def schedule(self, dryrun_info: AppDryRunInfo[KubernetesJob]) -> str:
575585
namespace = cfg.get("namespace") or "default"
576586

577587
images_to_push = dryrun_info.request.images_to_push
578-
self._push_images(images_to_push)
588+
self.push_images(images_to_push)
579589

580590
resource = dryrun_info.request.resource
581591
try:
@@ -605,7 +615,7 @@ def _submit_dryrun(
605615
raise TypeError(f"config value 'queue' must be a string, got {queue}")
606616

607617
# map any local images to the remote image
608-
images_to_push = self._update_app_images(app, cfg.get("image_repo"))
618+
images_to_push = self.dryrun_push_images(app, cast(Mapping[str, CfgVal], cfg))
609619

610620
service_account = cfg.get("service_account")
611621
assert service_account is None or isinstance(
@@ -642,7 +652,7 @@ def _cancel_existing(self, app_id: str) -> None:
642652
name=name,
643653
)
644654

645-
def run_opts(self) -> runopts:
655+
def _run_opts(self) -> runopts:
646656
opts = runopts()
647657
opts.add(
648658
"namespace",
@@ -656,11 +666,6 @@ def run_opts(self) -> runopts:
656666
help="Volcano queue to schedule job in",
657667
required=True,
658668
)
659-
opts.add(
660-
"image_repo",
661-
type_=str,
662-
help="The image repository to use when pushing patched images, must have push access. Ex: example.com/your/container",
663-
)
664669
opts.add(
665670
"service_account",
666671
type_=str,

torchx/schedulers/local_scheduler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -573,7 +573,7 @@ def __init__(
573573
self._base_log_dir: Optional[str] = None
574574
self._created_tmp_log_dir: bool = False
575575

576-
def run_opts(self) -> runopts:
576+
def _run_opts(self) -> runopts:
577577
opts = runopts()
578578
opts.add(
579579
"log_dir",

0 commit comments

Comments
 (0)