Skip to content

Commit 2a5c109

Browse files
[minor] restructure create method (#551)
* restructure create method * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fixes * remove unused line --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 565bb32 commit 2a5c109

File tree

1 file changed

+42
-15
lines changed

1 file changed

+42
-15
lines changed

executorlib/interactive/create.py

Lines changed: 42 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -83,14 +83,15 @@ def create_executor(
8383
of the individual function.
8484
init_function (None): optional function to preset arguments for functions which are submitted later
8585
"""
86-
check_init_function(block_allocation=block_allocation, init_function=init_function)
8786
if flux_executor is not None and backend != "flux_allocation":
8887
backend = "flux_allocation"
89-
check_pmi(backend=backend, pmi=flux_executor_pmi_mode)
90-
cores_per_worker = resource_dict.get("cores", 1)
91-
resource_dict["cache_directory"] = cache_directory
92-
resource_dict["hostname_localhost"] = hostname_localhost
9388
if backend == "flux_allocation":
89+
check_init_function(
90+
block_allocation=block_allocation, init_function=init_function
91+
)
92+
check_pmi(backend=backend, pmi=flux_executor_pmi_mode)
93+
resource_dict["cache_directory"] = cache_directory
94+
resource_dict["hostname_localhost"] = hostname_localhost
9495
check_oversubscribe(
9596
oversubscribe=resource_dict.get("openmpi_oversubscribe", False)
9697
)
@@ -100,40 +101,41 @@ def create_executor(
100101
return create_flux_allocation_executor(
101102
max_workers=max_workers,
102103
max_cores=max_cores,
103-
cores_per_worker=cores_per_worker,
104+
cache_directory=cache_directory,
104105
resource_dict=resource_dict,
105106
flux_executor=flux_executor,
106107
flux_executor_pmi_mode=flux_executor_pmi_mode,
107108
flux_executor_nesting=flux_executor_nesting,
108109
flux_log_files=flux_log_files,
110+
hostname_localhost=hostname_localhost,
109111
block_allocation=block_allocation,
110112
init_function=init_function,
111113
)
112114
elif backend == "slurm_allocation":
115+
check_pmi(backend=backend, pmi=flux_executor_pmi_mode)
113116
check_executor(executor=flux_executor)
114117
check_nested_flux_executor(nested_flux_executor=flux_executor_nesting)
115118
check_flux_log_files(flux_log_files=flux_log_files)
116119
return create_slurm_allocation_executor(
117120
max_workers=max_workers,
118121
max_cores=max_cores,
119-
cores_per_worker=cores_per_worker,
122+
cache_directory=cache_directory,
120123
resource_dict=resource_dict,
124+
hostname_localhost=hostname_localhost,
121125
block_allocation=block_allocation,
122126
init_function=init_function,
123127
)
124128
elif backend == "local":
129+
check_pmi(backend=backend, pmi=flux_executor_pmi_mode)
125130
check_executor(executor=flux_executor)
126131
check_nested_flux_executor(nested_flux_executor=flux_executor_nesting)
127132
check_flux_log_files(flux_log_files=flux_log_files)
128-
check_gpus_per_worker(gpus_per_worker=resource_dict.get("gpus_per_core", 0))
129-
check_command_line_argument_lst(
130-
command_line_argument_lst=resource_dict.get("slurm_cmd_args", [])
131-
)
132133
return create_local_executor(
133134
max_workers=max_workers,
134135
max_cores=max_cores,
135-
cores_per_worker=cores_per_worker,
136+
cache_directory=cache_directory,
136137
resource_dict=resource_dict,
138+
hostname_localhost=hostname_localhost,
137139
block_allocation=block_allocation,
138140
init_function=init_function,
139141
)
@@ -146,15 +148,25 @@ def create_executor(
146148
def create_flux_allocation_executor(
147149
max_workers: Optional[int] = None,
148150
max_cores: Optional[int] = None,
149-
cores_per_worker: int = 1,
151+
cache_directory: Optional[str] = None,
150152
resource_dict: dict = {},
151153
flux_executor=None,
152154
flux_executor_pmi_mode: Optional[str] = None,
153155
flux_executor_nesting: bool = False,
154156
flux_log_files: bool = False,
157+
hostname_localhost: Optional[bool] = None,
155158
block_allocation: bool = False,
156159
init_function: Optional[Callable] = None,
157160
) -> Union[InteractiveStepExecutor, InteractiveExecutor]:
161+
check_init_function(block_allocation=block_allocation, init_function=init_function)
162+
check_pmi(backend="flux_allocation", pmi=flux_executor_pmi_mode)
163+
cores_per_worker = resource_dict.get("cores", 1)
164+
resource_dict["cache_directory"] = cache_directory
165+
resource_dict["hostname_localhost"] = hostname_localhost
166+
check_oversubscribe(oversubscribe=resource_dict.get("openmpi_oversubscribe", False))
167+
check_command_line_argument_lst(
168+
command_line_argument_lst=resource_dict.get("slurm_cmd_args", [])
169+
)
158170
if "openmpi_oversubscribe" in resource_dict.keys():
159171
del resource_dict["openmpi_oversubscribe"]
160172
if "slurm_cmd_args" in resource_dict.keys():
@@ -193,11 +205,16 @@ def create_flux_allocation_executor(
193205
def create_slurm_allocation_executor(
194206
max_workers: Optional[int] = None,
195207
max_cores: Optional[int] = None,
196-
cores_per_worker: int = 1,
208+
cache_directory: Optional[str] = None,
197209
resource_dict: dict = {},
210+
hostname_localhost: Optional[bool] = None,
198211
block_allocation: bool = False,
199212
init_function: Optional[Callable] = None,
200213
) -> Union[InteractiveStepExecutor, InteractiveExecutor]:
214+
check_init_function(block_allocation=block_allocation, init_function=init_function)
215+
cores_per_worker = resource_dict.get("cores", 1)
216+
resource_dict["cache_directory"] = cache_directory
217+
resource_dict["hostname_localhost"] = hostname_localhost
201218
if block_allocation:
202219
resource_dict["init_function"] = init_function
203220
max_workers = validate_number_of_cores(
@@ -228,11 +245,21 @@ def create_slurm_allocation_executor(
228245
def create_local_executor(
229246
max_workers: Optional[int] = None,
230247
max_cores: Optional[int] = None,
231-
cores_per_worker: int = 1,
248+
cache_directory: Optional[str] = None,
232249
resource_dict: dict = {},
250+
hostname_localhost: Optional[bool] = None,
233251
block_allocation: bool = False,
234252
init_function: Optional[Callable] = None,
235253
) -> Union[InteractiveStepExecutor, InteractiveExecutor]:
254+
check_init_function(block_allocation=block_allocation, init_function=init_function)
255+
cores_per_worker = resource_dict.get("cores", 1)
256+
resource_dict["cache_directory"] = cache_directory
257+
resource_dict["hostname_localhost"] = hostname_localhost
258+
259+
check_gpus_per_worker(gpus_per_worker=resource_dict.get("gpus_per_core", 0))
260+
check_command_line_argument_lst(
261+
command_line_argument_lst=resource_dict.get("slurm_cmd_args", [])
262+
)
236263
if "threads_per_core" in resource_dict.keys():
237264
del resource_dict["threads_per_core"]
238265
if "gpus_per_core" in resource_dict.keys():

0 commit comments

Comments
 (0)