Skip to content

Commit 69a2a0b

Browse files
authored
Merge pull request #148 from kbase/dev-service
Submit job to JAWS
2 parents 3e959f7 + 4eeb567 commit 69a2a0b

File tree

2 files changed

+65
-21
lines changed

2 files changed

+65
-21
lines changed

cdmtaskservice/jobflows/nersc_jaws.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -130,9 +130,9 @@ async def download_complete(self, job: models.AdminJobDetails):
130130
models.JobState.JOB_SUBMITTING,
131131
timestamp.utcdatetime()
132132
)
133-
await self._coman.run_coroutine(self._download_complete(job))
133+
await self._coman.run_coroutine(self._submit_jaws_job(job))
134134

135-
async def _download_complete(self, job: models.AdminJobDetails):
135+
async def _submit_jaws_job(self, job: models.AdminJobDetails):
136136
logr = logging.getLogger(__name__)
137137
try:
138138
# TODO PERF configure file download concurrency

cdmtaskservice/nersc/manager.py

+63-19
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import logging
1111
from pathlib import Path
1212
from sfapi_client import AsyncClient
13+
from sfapi_client.exceptions import SfApiError
1314
from sfapi_client.paths import AsyncRemotePath
1415
from sfapi_client.compute import Machine, AsyncCompute
1516
import sys
@@ -46,17 +47,26 @@
4647
_SEC_PER_GB = 2 * 60 # may want to make this configurable
4748

4849
_CTS_SCRATCH_ROOT_DIR = Path("cdm_task_service")
49-
_JOB_FILES = "files"
50-
_MANIFESTS = "manifests"
50+
_JOB_FILES = Path("files")
51+
_JOB_MANIFESTS = Path("manifests")
5152
_MANIFEST_FILE_PREFIX = "manifest-"
5253

5354

54-
_JAWS_CONF_FILENAME = "jaws.conf"
55+
_JAWS_CONF_FILENAME = "jaws_cts.conf"
5556
_JAWS_CONF_TEMPLATE = """
5657
[USER]
5758
token = {token}
5859
default_team = {group}
5960
"""
61+
_JAWS_COMMAND_TEMPLATE = f"""
62+
module use /global/cfs/projectdirs/kbase/jaws/modulefiles
63+
module load jaws
64+
export JAWS_USER_CONFIG=~/{_JAWS_CONF_FILENAME}
65+
jaws submit --quiet --tag {{job_id}} {{wdlpath}} {{inputjsonpath}} {{site}}
66+
"""
67+
_JAWS_SITE_PERLMUTTER = "kbase" # add lawrencium later, maybe
68+
_JAWS_INPUT_WDL = "input.wdl"
69+
_JAWS_INPUT_JSON = "input.json"
6070

6171

6272
# TODO PROD add start and end time to task output and record
@@ -184,7 +194,8 @@ async def _setup_remote_code(self, file_group: str, jaws_token: str, jaws_group:
184194
),
185195
chmod = "600"
186196
))
187-
scratch = tg.create_task(self._set_up_dtn_scratch(cli, file_group))
197+
pm_scratch = tg.create_task(perlmutter.run("echo $SCRATCH"))
198+
dtn_scratch = tg.create_task(self._set_up_dtn_scratch(cli, file_group))
188199
if _PIP_DEPENDENCIES:
189200
deps = " ".join(
190201
# may need to do something else if module doesn't have __version__
@@ -198,7 +209,11 @@ async def _setup_remote_code(self, file_group: str, jaws_token: str, jaws_group:
198209
+ f"pip install {deps}" # adding notapackage causes a failure
199210
)
200211
tg.create_task(dt.run(command))
201-
self._dtn_scratch = scratch.result()
212+
self._dtn_scratch = dtn_scratch.result()
213+
self._perlmutter_scratch = Path(pm_scratch.result().strip())
214+
logging.getLogger(__name__).info(
215+
f"NERSC perlmutter scratch path: {self._perlmutter_scratch}"
216+
)
202217

203218
async def _set_up_dtn_scratch(self, client: AsyncClient, file_group: str) -> Path:
204219
dt = await client.compute(_DT_TARGET)
@@ -208,7 +223,7 @@ async def _set_up_dtn_scratch(self, client: AsyncClient, file_group: str) -> Pat
208223
raise ValueError("Unable to determine $SCRATCH variable for NERSC dtns")
209224
logging.getLogger(__name__).info(f"NERSC DTN scratch path: {scratch}")
210225
await dt.run(
211-
f"{_DT_WORKAROUND}; set -e; chgrp {file_group} {scratch}; chmod g+rs {scratch}"
226+
f"{_DT_WORKAROUND}; set -e; chgrp {file_group} {scratch}; chmod g+rsx {scratch}"
212227
)
213228
return Path(scratch)
214229

@@ -394,26 +409,58 @@ async def run_JAWS(self, job: models.Job, file_download_concurrency: int = 10) -
394409
_check_int(file_download_concurrency, "file_download_concurrency")
395410
if not _not_falsy(job, "job").job_input.inputs_are_S3File():
396411
raise ValueError("Job files must be S3File objects")
412+
cli = self._client_provider()
413+
await self._generate_and_load_job_files_to_nersc(cli, job, file_download_concurrency)
414+
perl = await cli.compute(Machine.perlmutter)
415+
pre = self._perlmutter_scratch / _CTS_SCRATCH_ROOT_DIR / job.id
416+
try:
417+
res = await perl.run(_JAWS_COMMAND_TEMPLATE.format(
418+
job_id=job.id,
419+
wdlpath=pre / _JAWS_INPUT_WDL,
420+
inputjsonpath=pre / _JAWS_INPUT_JSON,
421+
site=_JAWS_SITE_PERLMUTTER
422+
))
423+
except SfApiError as e:
424+
# TODO ERRORHANDLING if jaws provides valid json parse it and return just the detail
425+
#try:
426+
# j = json.loads(f"{e}")
427+
# if "detail" in j:
428+
# raise ValueError(f"JAWS error: {j['detail']}") from e
429+
raise ValueError(f"JAWS error: {e}") from e
430+
#except json.JSONDecodeError as je:
431+
# raise ValueError(f"JAWS returned invalid JSON ({je}) in error: {e}") from e
432+
try:
433+
j = json.loads(res)
434+
if "run_id" not in j:
435+
raise ValueError(f"JAWS returned no run_id in JSON {res}")
436+
run_id = j["run_id"]
437+
logging.getLogger(__name__).info(
438+
f"Submitted JAWS job with run id {run_id} for job {job.id}"
439+
)
440+
return run_id
441+
except json.JSONDecodeError as e:
442+
raise ValueError(f"JAWS returned invalid JSON: {e}\n{res}") from e
443+
444+
async def _generate_and_load_job_files_to_nersc(
445+
self, cli: AsyncClient, job: models.Job, concurrency: int
446+
):
397447
manifest_files = generate_manifest_files(job)
398448
manifest_file_paths = self._get_manifest_file_paths(job.id, len(manifest_files))
399-
fmap = {m: self._localize_s3_path(job.id, m.file) for m in job.job_input.input_files}
449+
fmap = {m: _JOB_FILES / m.file for m in job.job_input.input_files}
400450
wdljson = wdl.generate_wdl(job, fmap, manifest_file_paths)
401-
uploads = {fp: f for fp, f in zip(manifest_file_paths, manifest_files)}
402451
pre = self._dtn_scratch / _CTS_SCRATCH_ROOT_DIR / job.id
403-
wdlpath = pre / "input.wdl"
404-
jsonpath = pre / "input.json"
405-
uploads[wdlpath] = wdljson.wdl
406-
uploads[jsonpath] = json.dumps(wdljson.input_json, indent=4)
407-
cli = self._client_provider()
452+
downloads = {pre / fp: f for fp, f in zip(manifest_file_paths, manifest_files)}
453+
downloads[pre / _JAWS_INPUT_WDL] = wdljson.wdl
454+
downloads[pre / _JAWS_INPUT_JSON] = json.dumps(wdljson.input_json, indent=4)
408455
dt = await cli.compute(_DT_TARGET)
409-
semaphore = asyncio.Semaphore(file_download_concurrency)
456+
semaphore = asyncio.Semaphore(concurrency)
410457
async def sem_coro(coro):
411458
async with semaphore:
412459
return await coro
413460
coros = []
414461
try:
415462
async with asyncio.TaskGroup() as tg:
416-
for path, file in uploads.items():
463+
for path, file in downloads.items():
417464
coros.append(self._upload_file_to_nersc(
418465
dt, path, bio=io.BytesIO(file.encode())
419466
))
@@ -425,11 +472,8 @@ async def sem_coro(coro):
425472
# otherwise you can get coroutine never awaited warnings if a failure occurs
426473
for c in coros:
427474
c.close()
428-
# TODO NEXT run jaws job
429-
return "fake_job_id"
430475

431476
def _get_manifest_file_paths(self, job_id: str, count: int) -> list[Path]:
432477
if count == 0:
433478
return []
434-
pre = self._dtn_scratch / _CTS_SCRATCH_ROOT_DIR / job_id / _MANIFESTS
435-
return [pre / f"{_MANIFEST_FILE_PREFIX}{c}" for c in range(1, count + 1)]
479+
return [_JOB_MANIFESTS / f"{_MANIFEST_FILE_PREFIX}{c}" for c in range(1, count + 1)]

0 commit comments

Comments (0)