Skip to content

Commit

Permalink
FEATURE: Allowing to start jobs from shell
Browse files Browse the repository at this point in the history
  • Loading branch information
d-krupke committed Feb 27, 2024
1 parent 6ad53ba commit f9f8d7a
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 33 deletions.
41 changes: 15 additions & 26 deletions src/slurminade/bundling.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import logging
import typing
from collections import defaultdict
from pathlib import Path

from .dispatcher import (
Dispatcher,
Expand Down Expand Up @@ -40,17 +41,14 @@ class TaskBuffer:
def __init__(self):
self._tasks = defaultdict(list)

def add(self, task: FunctionCall, options: SlurmOptions) -> int:
self._tasks[options].append(task)
return len(self._tasks[options])
def add(self, task: FunctionCall, options: SlurmOptions, entry_point: Path) -> int:
self._tasks[(entry_point, options)].append(task)
return len(self._tasks[(entry_point, options)])

def items(self):
for opt, tasks in self._tasks.items():
for (entry_point, opt), tasks in self._tasks.items():
if tasks:
yield opt, tasks

def get(self, options: SlurmOptions) -> typing.List[FunctionCall]:
return self._tasks[options]
yield entry_point, opt, tasks

def clear(self):
self._tasks.clear()
Expand Down Expand Up @@ -78,10 +76,9 @@ def __init__(self, max_size: int):
self.max_size = max_size
self.subdispatcher = get_dispatcher()
self._tasks = TaskBuffer()
self._batch_guard = BatchGuard()
self._all_job_ids = []

def flush(self, options: typing.Optional[SlurmOptions] = None) -> typing.List[int]:
def flush(self) -> typing.List[int]:
"""
Distribute all buffered tasks. Return the job ids used.
This method is called automatically when the context is exited.
Expand All @@ -91,20 +88,11 @@ def flush(self, options: typing.Optional[SlurmOptions] = None) -> typing.List[in
:return: A list of job ids.
"""
job_ids = []
if options is None:
for opt, tasks in self._tasks.items():
while tasks:
job_id = self.subdispatcher(tasks[: self.max_size], opt)
job_ids.append(job_id)
tasks = tasks[self.max_size :]

else:
tasks = self._tasks.get(options)
self._batch_guard.report_flush(len(tasks))
while len(tasks) > self.max_size:
job_id = self.subdispatcher(tasks[: self.max_size], options)
for entry_point, opt, tasks in self._tasks.items():
while tasks:
job_id = self.subdispatcher(tasks[: self.max_size], opt, entry_point)
job_ids.append(job_id)
tasks = tasks[: self.max_size]
tasks = tasks[self.max_size :]
self._tasks.clear()
self._all_job_ids.extend(job_ids)
return job_ids
Expand All @@ -124,20 +112,21 @@ def add(self, func: SlurmFunction, *args, **kwargs):
:return: None
"""
self._dispatch(
[FunctionCall(func.func_id, args, kwargs)], func.special_slurm_opts
[FunctionCall(func.func_id, args, kwargs)], func.special_slurm_opts, func.get_entry_point()
)

def _dispatch(
self,
funcs: typing.Iterable[FunctionCall],
options: SlurmOptions,
entry_point: Path,
block: bool = False,
) -> JobReference:
if block:
# if blocking, we don't buffer, but dispatch immediately
return self.subdispatcher._dispatch(funcs, options, block=True)
return self.subdispatcher._dispatch(funcs, options, entry_point, block=True)
for func in funcs:
self._tasks.add(func, options)
self._tasks.add(func, options, entry_point)
return BundlingJobReference()

def srun(
Expand Down
15 changes: 11 additions & 4 deletions src/slurminade/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def _dispatch(
self,
funcs: typing.Iterable[FunctionCall],
options: SlurmOptions,
entry_point: Path,
block: bool = False,
) -> JobReference:
"""
Expand Down Expand Up @@ -96,6 +97,7 @@ def __call__(
self,
funcs: typing.Union[FunctionCall, typing.Iterable[FunctionCall]],
options: SlurmOptions,
entry_point: Path,
block: bool = False,
) -> JobReference:
"""
Expand All @@ -108,7 +110,7 @@ def __call__(
funcs = [funcs]
funcs = list(funcs)
self._log_dispatch(funcs, options)
return self._dispatch(funcs, options, block)
return self._dispatch(funcs, options, entry_point, block)

def is_sequential(self):
"""
Expand Down Expand Up @@ -156,6 +158,7 @@ def _dispatch(
self,
funcs: typing.Iterable[FunctionCall],
options: SlurmOptions, # noqa: ARG002
entry_point: Path,
block: bool = False, # noqa: ARG002
) -> JobReference:
dispatch_guard()
Expand Down Expand Up @@ -251,6 +254,7 @@ def _dispatch(
self,
funcs: typing.Iterable[FunctionCall],
options: SlurmOptions,
entry_point: Path,
block: bool = False,
) -> SlurmJobReference:
dispatch_guard()
Expand All @@ -264,7 +268,7 @@ def _dispatch(
options.add_dependencies(self._join_dependencies, "afterany")
slurm = self._create_slurm_api(options)
command = create_slurminade_command(
get_entry_point(), funcs, self.max_arg_length
entry_point, funcs, self.max_arg_length
)
logging.getLogger("slurminade").debug(command)
if block:
Expand Down Expand Up @@ -347,11 +351,12 @@ def _dispatch(
self,
funcs: typing.Iterable[FunctionCall],
options: SlurmOptions, # noqa: ARG002
entry_point: Path, # noqa: ARG002
block: bool = False, # noqa: ARG002
) -> int:
dispatch_guard()
command = create_slurminade_command(
get_entry_point(), funcs, self.max_arg_length
entry_point, funcs, self.max_arg_length
)
os.system(command)
return -1
Expand Down Expand Up @@ -400,6 +405,7 @@ def _dispatch(
self,
funcs: typing.Iterable[FunctionCall],
options: SlurmOptions, # noqa: ARG002
entry_point: Path, # noqa: ARG002
block: bool = False, # noqa: ARG002
) -> LocalJobReference:
dispatch_guard()
Expand Down Expand Up @@ -466,6 +472,7 @@ def set_dispatcher(dispatcher: Dispatcher) -> None:
def dispatch(
funcs: typing.Union[FunctionCall, typing.Iterable[FunctionCall]],
options: SlurmOptions,
entry_point: Path,
block: bool = False,
) -> JobReference:
"""
Expand All @@ -479,7 +486,7 @@ def dispatch(
if not FunctionMap.check_id(func.func_id):
msg = f"Function '{func.func_id}' cannot be called from the given entry point."
raise KeyError(msg)
return get_dispatcher()(funcs, options, block)
return get_dispatcher()(funcs, options, entry_point, block)


def srun(
Expand Down
19 changes: 17 additions & 2 deletions src/slurminade/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
import subprocess
import typing
from enum import Enum

from pathlib import Path
import logging
from .dispatcher import FunctionCall, dispatch, get_dispatcher
from .function_map import FunctionMap
from .function_map import FunctionMap, get_entry_point
from .guard import guard_recursive_distribution
from .job_reference import JobReference
from .options import SlurmOptions
Expand Down Expand Up @@ -48,6 +49,7 @@ def __init__(
self.func = func
self.func_id = func_id
self.call_policy = call_policy
self.defining_file = Path(inspect.getfile(func))

def update_options(self, conf: typing.Dict[str, typing.Any]):
self.special_slurm_opts.update(conf)
Expand Down Expand Up @@ -122,6 +124,17 @@ def __call__(self, *args, **kwargs):
else:
msg = "Unknown call policy."
raise RuntimeError(msg)

def get_entry_point(self) -> Path:
    """
    Resolve the entry point script to use when dispatching this function.

    The module-level ``get_entry_point()`` (imported from ``function_map``)
    is asked first. If it raises ``FileNotFoundError``, the file in which
    the wrapped function was defined (``self.defining_file``) is used
    instead, and this fallback is reported on the ``slurminade`` logger at
    debug level.

    :return: The path of the entry point.
    """
    try:
        # Inside a method body this name resolves to the module-level
        # function, not to this method, so there is no recursion here.
        entry_point = get_entry_point()
    except FileNotFoundError:
        entry_point = self.defining_file
        logging.getLogger("slurminade").debug(
            "Using defining file %s as entry point.", entry_point
        )
    return entry_point

def distribute(self, *args, **kwargs) -> JobReference:
"""
Expand All @@ -136,6 +149,7 @@ def distribute(self, *args, **kwargs) -> JobReference:
return dispatch(
[FunctionCall(self.func_id, args, kwargs)],
self.special_slurm_opts,
entry_point=self.get_entry_point(),
block=False,
)

Expand All @@ -151,6 +165,7 @@ def distribute_and_wait(self, *args, **kwargs) -> JobReference:
return dispatch(
[FunctionCall(self.func_id, args, kwargs)],
self.special_slurm_opts,
entry_point=self.get_entry_point(),
block=True,
)

Expand Down
2 changes: 1 addition & 1 deletion src/slurminade/function_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ class FunctionCall:
A function call to be dispatched.
"""

def __init__(self, func_id, args: typing.Tuple, kwargs: typing.Dict):
def __init__(self, func_id: str, args: typing.Tuple, kwargs: typing.Dict):
self.func_id = func_id # the function id, as in FunctionMap
self.args = args # the positional arguments for the call
self.kwargs = kwargs # the keyword arguments for the call
Expand Down

0 comments on commit f9f8d7a

Please sign in to comment.