diff --git a/.gitignore b/.gitignore index e17e6b7..3af0f19 100644 --- a/.gitignore +++ b/.gitignore @@ -128,3 +128,4 @@ dmypy.json # Pyre type checker .pyre/ .idea +examples/slurminade_example.txt diff --git a/README.rst b/README.rst index e11c802..8396322 100644 --- a/README.rst +++ b/README.rst @@ -359,6 +359,7 @@ The project is reasonably easy: Changes ------- +- 1.1.0: Slurminade can now be called from iPython, too! `exec` has been renamed `shell` to prevent confusion with the Python call `exec` which will evaluate a string as Python code. - 1.0.1: Dispatcher now return jobs references instead of job ids. This allows to do some fancier stuff in the future, when the jobs infos are only available a short time after the job has been submitted. - 0.10.1: FIX: Listing functions will no longer execute setup functions. - 0.10.0: `Batch` is now named `JobBundling`. There is a method `join` for easier synchronization. `exec` allows to executed commands just like `srun` and `sbatch`, but uniform syntax with other slurmified functions. Functions can now also be called with `distribute_and_wait`. If you call `python3 -m slurminade.check --partition YOUR_PARTITION --constraint YOUR_CONSTRAINT` you can check if your slurm configuration is running correctly. diff --git a/pyproject.toml b/pyproject.toml index 3a6d25d..2d3c635 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ where = ["src"] [project] name = "slurminade" -version = "1.0.1" +version = "1.1.0" authors = [ { name = "TU Braunschweig, IBR, Algorithms Group (Dominik Krupke)", email = "krupke@ibr.cs.tu-bs.de" }, ] @@ -69,7 +69,7 @@ extend-ignore = [ "PT004", # Incorrect, just usefixtures instead. "RUF009", # Too easy to get a false positive ] -target-version = "py37" +target-version = "py38" src = ["src"] unfixable = ["T20", "F841"] exclude = [] @@ -78,7 +78,7 @@ exclude = [] [tool.mypy] files = ["src", "tests"] mypy_path = ["$MYPY_CONFIG_FILE_DIR/src"] -python_version = "3.7" +python_version = "3.8" warn_unused_configs = true show_error_codes = true enable_error_code = ["ignore-without-code", "redundant-expr", "truthy-bool"] @@ -95,7 +95,7 @@ ignore_missing_imports = true [tool.pylint] -py-version = "3.7" +py-version = "3.8" jobs = "0" reports.output-format = "colorized" similarities.ignore-imports = "yes" diff --git a/src/slurminade/__init__.py b/src/slurminade/__init__.py index c022fe3..784beb9 100644 --- a/src/slurminade/__init__.py +++ b/src/slurminade/__init__.py @@ -55,7 +55,7 @@ def clean_up(): """ # flake8: noqa F401 -from .function import slurmify +from .function import slurmify, shell from .conf import update_default_configuration, set_default_configuration from .guard import ( set_dispatch_limit, @@ -94,6 +94,7 @@ def clean_up(): "TestDispatcher", "SubprocessDispatcher", "set_entry_point", + "shell", "node_setup", ] diff --git a/src/slurminade/bundling.py b/src/slurminade/bundling.py index 8f4f528..d91b8a3 100644 --- a/src/slurminade/bundling.py +++ b/src/slurminade/bundling.py @@ -4,6 +4,7 @@ import logging import typing from collections import defaultdict +from pathlib import Path from .dispatcher import ( Dispatcher, @@ -12,24 +13,24 @@ set_dispatcher, ) from .function import SlurmFunction -from .guard import BatchGuard -from .options import SlurmOptions from .job_reference import JobReference +from .options import SlurmOptions + class BundlingJobReference(JobReference): def __init__(self) -> None: super().__init__() - pass def get_job_id(self) -> typing.Optional[int]: return None - + def get_exit_code(self) -> typing.Optional[int]: return None - + def get_info(self) -> typing.Dict[str, typing.Any]: return {} + class TaskBuffer: """ A simple container to buffer all the tasks by their options. @@ -39,21 +40,19 @@ class TaskBuffer: def __init__(self): self._tasks = defaultdict(list) - def add(self, task: FunctionCall, options: SlurmOptions) -> int: - self._tasks[options].append(task) - return len(self._tasks[options]) + def add(self, task: FunctionCall, options: SlurmOptions, entry_point: Path) -> int: + self._tasks[(entry_point, options)].append(task) + return len(self._tasks[(entry_point, options)]) def items(self): - for opt, tasks in self._tasks.items(): + for (entry_point, opt), tasks in self._tasks.items(): if tasks: - yield opt, tasks - - def get(self, options: SlurmOptions) -> typing.List[FunctionCall]: - return self._tasks[options] + yield entry_point, opt, tasks def clear(self): self._tasks.clear() + class JobBundling(Dispatcher): """ The logic to buffer the function calls. It wraps the original dispatcher. @@ -76,10 +75,9 @@ def __init__(self, max_size: int): self.max_size = max_size self.subdispatcher = get_dispatcher() self._tasks = TaskBuffer() - self._batch_guard = BatchGuard() self._all_job_ids = [] - def flush(self, options: typing.Optional[SlurmOptions] = None) -> typing.List[int]: + def flush(self) -> typing.List[int]: """ Distribute all buffered tasks. Return the job ids used. This method is called automatically when the context is exited. @@ -89,24 +87,15 @@ def flush(self, options: typing.Optional[SlurmOptions] = None) -> typing.List[in :return: A list of job ids. """ job_ids = [] - if options is None: - for opt, tasks in self._tasks.items(): - while tasks: - job_id = self.subdispatcher(tasks[: self.max_size], opt) - job_ids.append(job_id) - tasks = tasks[self.max_size :] - - else: - tasks = self._tasks.get(options) - self._batch_guard.report_flush(len(tasks)) - while len(tasks) > self.max_size: - job_id = self.subdispatcher(tasks[: self.max_size], options) + for entry_point, opt, tasks in self._tasks.items(): + while tasks: + job_id = self.subdispatcher(tasks[: self.max_size], opt, entry_point) job_ids.append(job_id) - tasks = tasks[: self.max_size] + tasks = tasks[self.max_size :] self._tasks.clear() self._all_job_ids.extend(job_ids) return job_ids - + def get_all_job_ids(self): """ Return all job ids that have been used. @@ -122,20 +111,23 @@ def add(self, func: SlurmFunction, *args, **kwargs): :return: None """ self._dispatch( - [FunctionCall(func.func_id, args, kwargs)], func.special_slurm_opts + [FunctionCall(func.func_id, args, kwargs)], + func.special_slurm_opts, + func.get_entry_point(), ) def _dispatch( self, funcs: typing.Iterable[FunctionCall], options: SlurmOptions, + entry_point: Path, block: bool = False, ) -> JobReference: if block: # if blocking, we don't buffer, but dispatch immediately - return self.subdispatcher._dispatch(funcs, options, block=True) + return self.subdispatcher._dispatch(funcs, options, entry_point, block=True) for func in funcs: - self._tasks.add(func, options) + self._tasks.add(func, options, entry_point) return BundlingJobReference() def srun( @@ -189,13 +181,13 @@ def is_sequential(self): return self.subdispatcher.is_sequential() - class Batch(JobBundling): """ Compatibility alias for JobBundling. This is the old name. Deprecated. """ + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) logging.getLogger("slurminade").warning( "The `Batch` class has been renamed to `JobBundling`. Please update your code." - ) \ No newline at end of file + ) diff --git a/src/slurminade/check.py b/src/slurminade/check.py index 64cf52c..471de75 100644 --- a/src/slurminade/check.py +++ b/src/slurminade/check.py @@ -13,7 +13,7 @@ def _write_to_file(path, content): time.sleep(1) # get hostname and write it to the file hostname = socket.gethostname() - with open(path, "w") as file: + with Path(path).open("w") as file: print("Hello from ", hostname) file.write(content + "\n" + hostname) # wait a second for the file to be written @@ -43,15 +43,15 @@ def check_slurm(partition, constraint): # create a temporary folder for the slurm check with tempfile.TemporaryDirectory(dir=".") as tmpdir: - tmpdir = Path(tmpdir).resolve() - assert Path(tmpdir).exists() + tmpdir_ = Path(tmpdir).resolve() + assert tmpdir_.exists() # Check 1 - tmp_file_path = tmpdir / "check_1.txt" + tmp_file_path = tmpdir_ / "check_1.txt" _write_to_file.distribute_and_wait(str(tmp_file_path), "test") if not Path(tmp_file_path).exists(): msg = "Slurminade failed: The file was not written to the temporary directory." raise Exception(msg) - with open(tmp_file_path) as file: + with Path(tmp_file_path).open() as file: content = file.readlines() print( "Slurminade check 1 successful. Test was run on node", @@ -59,7 +59,7 @@ def check_slurm(partition, constraint): ) # Check 2 - tmp_file_path = tmpdir / "check_2.txt" + tmp_file_path = tmpdir_ / "check_2.txt" _write_to_file.distribute(str(tmp_file_path), "test") # wait up to 1 minutes for the file to be written for _ in range(60): @@ -69,7 +69,7 @@ def check_slurm(partition, constraint): if not Path(tmp_file_path).exists(): msg = "Slurminade failed: The file was not written to the temporary directory." raise Exception(msg) - with open(tmp_file_path) as file: + with Path(tmp_file_path).open() as file: content = file.readlines() print( "Slurminade check 2 successful. Test was run on node", @@ -79,7 +79,7 @@ def check_slurm(partition, constraint): join() # Check 3 - tmp_file_path = tmpdir / "check_3.txt" + tmp_file_path = tmpdir_ / "check_3.txt" srun(["touch", str(tmp_file_path)]) time.sleep(1) if not Path(tmp_file_path).exists(): @@ -89,12 +89,12 @@ def check_slurm(partition, constraint): tmp_file_path.unlink() # Check 4 - tmp_file_path = tmpdir / "check_4.txt" + tmp_file_path = tmpdir_ / "check_4.txt" _write_to_file.distribute_and_wait(str(tmp_file_path), "test") if not Path(tmp_file_path).exists(): msg = "Slurminade failed: The file was not written to the temporary directory." raise Exception(msg) - with open(tmp_file_path) as file: + with Path(tmp_file_path).open() as file: content = file.readlines() print( "Slurminade check 1 successful. Test was run on node", diff --git a/src/slurminade/conf.py b/src/slurminade/conf.py index 9409682..479d78d 100644 --- a/src/slurminade/conf.py +++ b/src/slurminade/conf.py @@ -12,10 +12,10 @@ __default_conf: typing.Dict = {} -def _load_conf(path): +def _load_conf(path: Path): try: - if os.path.isfile(path): - with open(path) as f: + if path.is_file(): + with path.open() as f: return json.load(f) else: return {} @@ -25,6 +25,11 @@ def _load_conf(path): def update_default_configuration(conf=None, **kwargs): + """ + Adds or updates the default configuration. + :param conf: A dictionary with the configuration. + :param kwargs: Configuration parameters. (alternative to giving a dictionary) + """ if conf: __default_conf.update(conf) if kwargs: @@ -32,18 +37,24 @@ def update_default_configuration(conf=None, **kwargs): def _load_default_conf(): - path = os.path.join(Path.home(), CONFIG_NAME) + path = Path.home() / CONFIG_NAME update_default_configuration(_load_conf(path)) if "XDG_CONFIG_HOME" in os.environ: - path = os.path.join(os.environ["XDG_CONFIG_HOME"], "slurminade", CONFIG_NAME) + path = Path(os.environ["XDG_CONFIG_HOME"]) / "slurminade" / CONFIG_NAME update_default_configuration(_load_conf(path)) - update_default_configuration(_load_conf(CONFIG_NAME)) + update_default_configuration(_load_conf(Path(CONFIG_NAME))) _load_default_conf() def set_default_configuration(conf=None, **kwargs): + """ + Replaces the default configuration. + This will overwrite the default configuration with the given one. + :param conf: A dictionary with the configuration. + :param kwargs: Configuration parameters. (alternative to giving a dictionary) + """ __default_conf = {} update_default_configuration(conf, **kwargs) diff --git a/src/slurminade/dispatcher.py b/src/slurminade/dispatcher.py index 059e354..0b43dc8 100644 --- a/src/slurminade/dispatcher.py +++ b/src/slurminade/dispatcher.py @@ -11,7 +11,9 @@ import shutil import subprocess import typing +from pathlib import Path from typing import Any, Dict, Optional + import simple_slurm from .conf import _get_conf @@ -26,6 +28,7 @@ from .job_reference import JobReference + class Dispatcher(abc.ABC): """ Abstract dispatcher to be inherited by all concrete dispatchers. @@ -38,6 +41,7 @@ def _dispatch( self, funcs: typing.Iterable[FunctionCall], options: SlurmOptions, + entry_point: Path, block: bool = False, ) -> JobReference: """ @@ -93,6 +97,7 @@ def __call__( self, funcs: typing.Union[FunctionCall, typing.Iterable[FunctionCall]], options: SlurmOptions, + entry_point: Path, block: bool = False, ) -> JobReference: """ @@ -105,7 +110,7 @@ def __call__( funcs = [funcs] funcs = list(funcs) self._log_dispatch(funcs, options) - return self._dispatch(funcs, options, block) + return self._dispatch(funcs, options, entry_point, block) def is_sequential(self): """ @@ -125,6 +130,7 @@ def join(self): msg = "Joining is not implemented for this dispatcher." raise NotImplementedError(msg) + class TestJobReference(JobReference): def get_job_id(self) -> None: return None @@ -135,6 +141,7 @@ def get_exit_code(self) -> None: def get_info(self) -> Dict[str, Any]: return {"info": "test"} + class TestDispatcher(Dispatcher): """ A dummy dispatcher that just prints the output. Primarily for debugging and testing. @@ -150,8 +157,9 @@ def __init__(self): def _dispatch( self, funcs: typing.Iterable[FunctionCall], - options: SlurmOptions, - block: bool = False, + options: SlurmOptions, # noqa: ARG002 + entry_point: Path, + block: bool = False, # noqa: ARG002 ) -> JobReference: dispatch_guard() funcs = list(funcs) @@ -167,15 +175,15 @@ def _cleanup(self, command): args = shlex.split(command) if args[-2] != "temp": return - filename = args[-1] - if os.path.exists(filename): - os.remove(filename) + filename = Path(args[-1]) + if filename.exists(): + filename.unlink() def srun( self, command: str, - conf: typing.Optional[typing.Dict] = None, - simple_slurm_kwargs: typing.Optional[typing.Dict] = None, + conf: typing.Optional[typing.Dict] = None, # noqa: ARG002 + simple_slurm_kwargs: typing.Optional[typing.Dict] = None, # noqa: ARG002 ): dispatch_guard() self.sruns.append(command) @@ -185,8 +193,8 @@ def srun( def sbatch( self, command: str, - conf: typing.Optional[typing.Dict] = None, - simple_slurm_kwargs: typing.Optional[typing.Dict] = None, + conf: typing.Optional[typing.Dict] = None, # noqa: ARG002 + simple_slurm_kwargs: typing.Optional[typing.Dict] = None, # noqa: ARG002 ): dispatch_guard() self.sbatches.append(command) @@ -196,6 +204,7 @@ def sbatch( def is_sequential(self): return True + class SlurmJobReference(JobReference): def __init__(self, job_id, exit_code, mode: str): self.job_id = job_id @@ -209,10 +218,13 @@ def get_exit_code(self) -> Optional[int]: return self.exit_code def get_info(self) -> Dict[str, Any]: - return {"job_id": self.job_id, - "exit_code": self.exit_code, - "on_slurm": True, - "mode": self.mode} + return { + "job_id": self.job_id, + "exit_code": self.exit_code, + "on_slurm": True, + "mode": self.mode, + } + class SlurmDispatcher(Dispatcher): """ @@ -242,6 +254,7 @@ def _dispatch( self, funcs: typing.Iterable[FunctionCall], options: SlurmOptions, + entry_point: Path, block: bool = False, ) -> SlurmJobReference: dispatch_guard() @@ -254,9 +267,7 @@ def _dispatch( if self._join_dependencies: options.add_dependencies(self._join_dependencies, "afterany") slurm = self._create_slurm_api(options) - command = create_slurminade_command( - get_entry_point(), funcs, self.max_arg_length - ) + command = create_slurminade_command(entry_point, funcs, self.max_arg_length) logging.getLogger("slurminade").debug(command) if block: ret = slurm.srun(command) @@ -306,11 +317,12 @@ def srun( ret = slurm.srun(command) return SlurmJobReference(None, ret, "srun") + class SubprocessJobReference(JobReference): def __init__(self): pass - def get_job_id(self) -> Optional[int]: + def get_job_id(self) -> Optional[int]: return None def get_exit_code(self) -> Optional[int]: @@ -319,6 +331,7 @@ def get_exit_code(self) -> Optional[int]: def get_info(self) -> Dict[str, Any]: return {"on_slurm": False} + class SubprocessDispatcher(Dispatcher): """ A dispatcher for debugging that distributes function calls using subprocesses. @@ -335,21 +348,20 @@ def __init__(self): def _dispatch( self, funcs: typing.Iterable[FunctionCall], - options: SlurmOptions, - block: bool = False, + options: SlurmOptions, # noqa: ARG002 + entry_point: Path, + block: bool = False, # noqa: ARG002 ) -> int: dispatch_guard() - command = create_slurminade_command( - get_entry_point(), funcs, self.max_arg_length - ) + command = create_slurminade_command(entry_point, funcs, self.max_arg_length) os.system(command) return -1 def srun( self, command: str, - conf: typing.Optional[typing.Dict] = None, - simple_slurm_kwargs: typing.Optional[typing.Dict] = None, + conf: typing.Optional[typing.Dict] = None, # noqa: ARG002 + simple_slurm_kwargs: typing.Optional[typing.Dict] = None, # noqa: ARG002 ): dispatch_guard() logging.getLogger("slurminade").debug("SRUN %s", command) @@ -358,21 +370,22 @@ def srun( def sbatch( self, command: str, - conf: typing.Optional[typing.Dict] = None, - simple_slurm_kwargs: typing.Optional[typing.Dict] = None, + conf: typing.Optional[typing.Dict] = None, # noqa: ARG002 + simple_slurm_kwargs: typing.Optional[typing.Dict] = None, # noqa: ARG002 ): self.srun(command) def is_sequential(self): return True + class LocalJobReference(JobReference): def get_job_id(self) -> None: return None - - def get_exit_code(self) -> None: + + def get_exit_code(self) -> None: return None - + def get_info(self) -> Dict[str, Any]: return {"on_slurm": False} @@ -387,8 +400,9 @@ class DirectCallDispatcher(Dispatcher): def _dispatch( self, funcs: typing.Iterable[FunctionCall], - options: SlurmOptions, - block: bool = False, + options: SlurmOptions, # noqa: ARG002 + entry_point: Path, # noqa: ARG002 + block: bool = False, # noqa: ARG002 ) -> LocalJobReference: dispatch_guard() for func in funcs: @@ -398,8 +412,8 @@ def _dispatch( def srun( self, command: str, - conf: typing.Optional[typing.Dict] = None, - simple_slurm_kwargs: typing.Optional[typing.Dict] = None, + conf: typing.Optional[typing.Dict] = None, # noqa: ARG002 + simple_slurm_kwargs: typing.Optional[typing.Dict] = None, # noqa: ARG002 ): dispatch_guard() subprocess.run(command, check=True) @@ -408,8 +422,8 @@ def srun( def sbatch( self, command: str, - conf: typing.Optional[typing.Dict] = None, - simple_slurm_kwargs: typing.Optional[typing.Dict] = None, + conf: typing.Optional[typing.Dict] = None, # noqa: ARG002 + simple_slurm_kwargs: typing.Optional[typing.Dict] = None, # noqa: ARG002 ): return self.srun(command) @@ -454,8 +468,9 @@ def set_dispatcher(dispatcher: Dispatcher) -> None: def dispatch( funcs: typing.Union[FunctionCall, typing.Iterable[FunctionCall]], options: SlurmOptions, + entry_point: Path, block: bool = False, -) -> int: +) -> JobReference: """ Distribute function calls with the current dispatcher. :param funcs: The functions calls to be distributed. @@ -464,17 +479,17 @@ def dispatch( """ funcs = list(funcs) if not isinstance(funcs, FunctionCall) else [funcs] for func in funcs: - if not FunctionMap.check_id(func.func_id): + if not FunctionMap.check_id(func.func_id, entry_point): msg = f"Function '{func.func_id}' cannot be called from the given entry point." raise KeyError(msg) - return get_dispatcher()(funcs, options, block) + return get_dispatcher()(funcs, options, entry_point, block) def srun( command: typing.Union[str, typing.List[str]], conf: typing.Union[SlurmOptions, typing.Dict, None] = None, simple_slurm_kwargs: typing.Optional[typing.Dict] = None, -) -> int: +) -> JobReference: """ Call srun with the current dispatcher. This command is directly executed and only terminates after completion. @@ -499,7 +514,7 @@ def sbatch( command: typing.Union[str, typing.List[str]], conf: typing.Union[SlurmOptions, typing.Dict, None] = None, simple_slurm_kwargs: typing.Optional[typing.Dict] = None, -) -> int: +) -> JobReference: """ The command is scheduled and the function returns immediately. :param command: A system command, e.g. `echo hello world > foobar.txt`. diff --git a/src/slurminade/execute.py b/src/slurminade/execute.py index 6153c8b..0c951ac 100644 --- a/src/slurminade/execute.py +++ b/src/slurminade/execute.py @@ -14,6 +14,7 @@ from .guard import prevent_distribution from .node_setup import disable_setup + @click.command() @click.option( "--root", @@ -40,7 +41,7 @@ def main(root, calls, fromfile, listfuncs): if listfuncs: disable_setup() set_entry_point(root) - with open(root) as f: + with Path(root).open() as f: code = "".join(f.readlines()) glob = dict(globals()) @@ -49,12 +50,12 @@ def main(root, calls, fromfile, listfuncs): exec(code, glob) if listfuncs: - print(json.dumps(FunctionMap.get_all_ids())) # noqa T201 + print(json.dumps(FunctionMap.get_all_ids())) # noqa: T201 return if calls: function_calls = json.loads(calls) elif fromfile: - with open(fromfile) as f: + with Path(fromfile).open() as f: logging.getLogger("slurminade").info( f"Reading function calls from {fromfile}." ) diff --git a/src/slurminade/function.py b/src/slurminade/function.py index ae30bfb..941776d 100644 --- a/src/slurminade/function.py +++ b/src/slurminade/function.py @@ -1,13 +1,16 @@ import inspect +import logging import subprocess import typing from enum import Enum +from pathlib import Path from .dispatcher import FunctionCall, dispatch, get_dispatcher -from .function_map import FunctionMap +from .function_map import FunctionMap, get_entry_point from .guard import guard_recursive_distribution -from .options import SlurmOptions from .job_reference import JobReference +from .options import SlurmOptions + class CallPolicy(Enum): """ @@ -47,12 +50,15 @@ def __init__( self.func = func self.func_id = func_id self.call_policy = call_policy + self.defining_file = Path(inspect.getfile(func)) def update_options(self, conf: typing.Dict[str, typing.Any]): self.special_slurm_opts.update(conf) def wait_for( - self, job_ids: typing.Union[JobReference, typing.Iterable[JobReference]], method: str = "afterany" + self, + job_ids: typing.Union[JobReference, typing.Iterable[JobReference]], + method: str = "afterany", ) -> "SlurmFunction": """ Add a dependency to a distribution. @@ -76,10 +82,15 @@ def wait_for( msg += " This is probably an error in your code." msg += " Maybe you are using `Batch` but flush outside of the `with` block?" raise RuntimeError(msg) - if any(jid.get_job_id() is None for jid in job_ids) and not get_dispatcher().is_sequential(): + if ( + any(jid.get_job_id() is None for jid in job_ids) + and not get_dispatcher().is_sequential() + ): msg = "Invalid job id. Not every dispatcher can directly return job ids, because it may not directly distribute them or doesn't distribute them at all." raise RuntimeError(msg) - sfunc.special_slurm_opts.add_dependencies(list(jid.get_job_id() for jid in job_ids), method) + sfunc.special_slurm_opts.add_dependencies( + [jid.get_job_id() for jid in job_ids], method + ) return sfunc def with_options(self, **kwargs) -> "SlurmFunction": @@ -115,7 +126,20 @@ def __call__(self, *args, **kwargs): msg = "Unknown call policy." raise RuntimeError(msg) - def distribute(self, *args, **kwargs) -> int: + def get_entry_point(self) -> Path: + """ + Returns the entry point for the function. + Either it is defined in the FunctionMap, or the defining file is used. + """ + try: + return get_entry_point() + except FileNotFoundError: + logging.getLogger("slurminade").debug( + "Using defining file %s as entry point.", self.defining_file + ) + return self.defining_file + + def distribute(self, *args, **kwargs) -> JobReference: """ Try to distribute function call. If slurm is not available, a direct function call will be performed. @@ -128,10 +152,11 @@ def distribute(self, *args, **kwargs) -> int: return dispatch( [FunctionCall(self.func_id, args, kwargs)], self.special_slurm_opts, + entry_point=self.get_entry_point(), block=False, ) - def distribute_and_wait(self, *args, **kwargs) -> int: + def distribute_and_wait(self, *args, **kwargs) -> JobReference: """ Distribute the function and wait for it to finish. :param args: The positional arguments. @@ -143,6 +168,7 @@ def distribute_and_wait(self, *args, **kwargs) -> int: return dispatch( [FunctionCall(self.func_id, args, kwargs)], self.special_slurm_opts, + entry_point=self.get_entry_point(), block=True, ) @@ -192,10 +218,35 @@ def dec(func) -> SlurmFunction: return dec -@slurmify() -def exec(cmd: typing.Union[str, typing.List[str]]): +def _slurmify( + allow_overwrite: bool, **args +) -> typing.Union[typing.Callable[[typing.Callable], SlurmFunction], SlurmFunction]: + """ + Decorator: Make a function distributable to slurm. + Usage: + + .. code-block:: python + + @slurmify_() + def func(a, b): + pass + + :param f: Function + :param args: Special slurm options for this function. + :return: A decorated function, callable with slurm. + """ + + def dec(func) -> SlurmFunction: + func_id = FunctionMap.register(func, allow_overwrite=allow_overwrite) + return SlurmFunction(args, func, func_id) + + return dec + + +@_slurmify(allow_overwrite=True) +def shell(cmd: typing.Union[str, typing.List[str]]): """ Execute a command. :param cmd: The command to be executed. """ - subprocess.run(cmd, check=True) + subprocess.run(cmd, check=True, shell=True) diff --git a/src/slurminade/function_call.py b/src/slurminade/function_call.py index 59e8574..83ef801 100644 --- a/src/slurminade/function_call.py +++ b/src/slurminade/function_call.py @@ -6,7 +6,7 @@ class FunctionCall: A function call to be dispatched. """ - def __init__(self, func_id, args: typing.Tuple, kwargs: typing.Dict): + def __init__(self, func_id: str, args: typing.Tuple, kwargs: typing.Dict): self.func_id = func_id # the function id, as in FunctionMap self.args = args # the positional arguments for the call self.kwargs = kwargs # the keyword arguments for the call diff --git a/src/slurminade/function_map.py b/src/slurminade/function_map.py index b5a7f92..a5ff94b 100644 --- a/src/slurminade/function_map.py +++ b/src/slurminade/function_map.py @@ -66,7 +66,7 @@ def check_compatibility(func: typing.Callable): raise ValueError(msg) @staticmethod - def register(func: typing.Callable) -> str: + def register(func: typing.Callable, allow_overwrite: bool = False) -> str: """ Register a function, allowing it to be called just by its id. :param func: The function to be stored. Needs to be a proper function. @@ -74,12 +74,16 @@ def register(func: typing.Callable) -> str: """ FunctionMap.check_compatibility(func) func_id = FunctionMap.get_id(func) - if func_id in FunctionMap._data: + if func_id in FunctionMap._data and not allow_overwrite: msg = "Multiple function definitions!" raise RuntimeError(msg) FunctionMap._data[func_id] = func return func_id + @staticmethod + def exists(func_id: str) -> bool: + return func_id in FunctionMap._data + @staticmethod def call( func_id: str, args: typing.Iterable, kwargs: typing.Dict[str, typing.Any] @@ -97,13 +101,13 @@ def call( return FunctionMap._data[func_id](*args, **kwargs) @staticmethod - def check_id(func_id: str) -> bool: + def check_id(func_id: str, entry_point: Path) -> bool: if func_id in FunctionMap._ids: return True - FunctionMap._ids = call_slurminade_to_get_function_ids(get_entry_point()) + FunctionMap._ids = call_slurminade_to_get_function_ids(entry_point) logging.getLogger("slurminade").info( "Entry point '%s' has functions %s", - get_entry_point(), + entry_point, list(FunctionMap._ids), ) return func_id in FunctionMap._ids @@ -134,7 +138,15 @@ def get_entry_point() -> Path: if FunctionMap.entry_point is None: import __main__ + # check if attribute __file__ is available + if not hasattr(__main__, "__file__"): + msg = "No entry point known." + raise FileNotFoundError(msg) + entry_point = __main__.__file__ + if not Path(entry_point).is_file() or Path(entry_point).suffix != ".py": + msg = "No entry point known." + raise FileNotFoundError(msg) set_entry_point(entry_point) assert FunctionMap.entry_point is not None diff --git a/src/slurminade/job_reference.py b/src/slurminade/job_reference.py index 4ff3af9..6f8dfd5 100644 --- a/src/slurminade/job_reference.py +++ b/src/slurminade/job_reference.py @@ -1,6 +1,7 @@ import abc from typing import Any, Dict, Optional + class JobReference(abc.ABC): @abc.abstractmethod def get_job_id(self) -> Optional[int]: @@ -12,4 +13,4 @@ def get_exit_code(self) -> Optional[int]: @abc.abstractmethod def get_info(self) -> Dict[str, Any]: - pass \ No newline at end of file + pass diff --git a/src/slurminade/node_setup.py b/src/slurminade/node_setup.py index fbe57be..da6c67a 100644 --- a/src/slurminade/node_setup.py +++ b/src/slurminade/node_setup.py @@ -5,6 +5,7 @@ _no_setup = False + def disable_setup(): """ Disable the setup function. This is useful for testing. @@ -12,6 +13,7 @@ def disable_setup(): global _no_setup _no_setup = True + def node_setup(func: typing.Callable): """ Decorator: Call this function on the node before running any function calls.