10
10
import logging
11
11
from pathlib import Path
12
12
from sfapi_client import AsyncClient
13
+ from sfapi_client .exceptions import SfApiError
13
14
from sfapi_client .paths import AsyncRemotePath
14
15
from sfapi_client .compute import Machine , AsyncCompute
15
16
import sys
46
47
# Seconds allotted per gigabyte; the name suggests a transfer/runtime allowance,
# but the usage site is not visible in this chunk — TODO confirm.
_SEC_PER_GB = 2 * 60  # may want to make this configurable

# Root directory for all CTS job data, joined onto a NERSC scratch root
# (the DTN $SCRATCH or the Perlmutter $SCRATCH).
_CTS_SCRATCH_ROOT_DIR = Path("cdm_task_service")
# Per-job subdirectories, kept as relative Paths so they can be joined onto
# either scratch root.
_JOB_FILES = Path("files")          # staged job input files
_JOB_MANIFESTS = Path("manifests")  # generated manifest files
# Manifest files are named f"{_MANIFEST_FILE_PREFIX}{n}" with n counting from 1.
_MANIFEST_FILE_PREFIX = "manifest-"

# Name of the JAWS config file; presumably written to the user's home
# directory, since the command template below reads it from ~ — TODO confirm.
_JAWS_CONF_FILENAME = "jaws_cts.conf"
_JAWS_CONF_TEMPLATE = """
[USER]
token = {token}
default_team = {group}
"""
# NOTE: this is an f-string evaluated at import time. The single-braced
# {_JAWS_CONF_FILENAME} is interpolated here, while the double-braced fields
# ({{job_id}} etc.) survive as {job_id}-style placeholders for a later
# str.format() call at submission time.
_JAWS_COMMAND_TEMPLATE = f"""
module use /global/cfs/projectdirs/kbase/jaws/modulefiles
module load jaws
export JAWS_USER_CONFIG=~/{_JAWS_CONF_FILENAME}
jaws submit --quiet --tag {{job_id}} {{wdlpath}} {{inputjsonpath}} {{site}}
"""
_JAWS_SITE_PERLMUTTER = "kbase"  # add lawrencium later, maybe
# File names for the generated WDL and its inputs JSON inside the job's
# scratch directory.
_JAWS_INPUT_WDL = "input.wdl"
_JAWS_INPUT_JSON = "input.json"
60
70
61
71
62
72
# TODO PROD add start and end time to task output and record
@@ -184,7 +194,8 @@ async def _setup_remote_code(self, file_group: str, jaws_token: str, jaws_group:
184
194
),
185
195
chmod = "600"
186
196
))
187
- scratch = tg .create_task (self ._set_up_dtn_scratch (cli , file_group ))
197
+ pm_scratch = tg .create_task (perlmutter .run ("echo $SCRATCH" ))
198
+ dtn_scratch = tg .create_task (self ._set_up_dtn_scratch (cli , file_group ))
188
199
if _PIP_DEPENDENCIES :
189
200
deps = " " .join (
190
201
# may need to do something else if module doesn't have __version__
@@ -198,7 +209,11 @@ async def _setup_remote_code(self, file_group: str, jaws_token: str, jaws_group:
198
209
+ f"pip install { deps } " # adding notapackage causes a failure
199
210
)
200
211
tg .create_task (dt .run (command ))
201
- self ._dtn_scratch = scratch .result ()
212
+ self ._dtn_scratch = dtn_scratch .result ()
213
+ self ._perlmutter_scratch = Path (pm_scratch .result ().strip ())
214
+ logging .getLogger (__name__ ).info (
215
+ f"NERSC perlmutter scratch path: { self ._perlmutter_scratch } "
216
+ )
202
217
203
218
async def _set_up_dtn_scratch (self , client : AsyncClient , file_group : str ) -> Path :
204
219
dt = await client .compute (_DT_TARGET )
@@ -208,7 +223,7 @@ async def _set_up_dtn_scratch(self, client: AsyncClient, file_group: str) -> Pat
208
223
raise ValueError ("Unable to determine $SCRATCH variable for NERSC dtns" )
209
224
logging .getLogger (__name__ ).info (f"NERSC DTN scratch path: { scratch } " )
210
225
await dt .run (
211
- f"{ _DT_WORKAROUND } ; set -e; chgrp { file_group } { scratch } ; chmod g+rs { scratch } "
226
+ f"{ _DT_WORKAROUND } ; set -e; chgrp { file_group } { scratch } ; chmod g+rsx { scratch } "
212
227
)
213
228
return Path (scratch )
214
229
@@ -394,26 +409,58 @@ async def run_JAWS(self, job: models.Job, file_download_concurrency: int = 10) -
394
409
_check_int (file_download_concurrency , "file_download_concurrency" )
395
410
if not _not_falsy (job , "job" ).job_input .inputs_are_S3File ():
396
411
raise ValueError ("Job files must be S3File objects" )
412
+ cli = self ._client_provider ()
413
+ await self ._generate_and_load_job_files_to_nersc (cli , job , file_download_concurrency )
414
+ perl = await cli .compute (Machine .perlmutter )
415
+ pre = self ._perlmutter_scratch / _CTS_SCRATCH_ROOT_DIR / job .id
416
+ try :
417
+ res = await perl .run (_JAWS_COMMAND_TEMPLATE .format (
418
+ job_id = job .id ,
419
+ wdlpath = pre / _JAWS_INPUT_WDL ,
420
+ inputjsonpath = pre / _JAWS_INPUT_JSON ,
421
+ site = _JAWS_SITE_PERLMUTTER
422
+ ))
423
+ except SfApiError as e :
424
+ # TODO ERRORHANDLING if jaws provides valid json parse it and return just the detail
425
+ #try:
426
+ # j = json.loads(f"{e}")
427
+ # if "detail" in j:
428
+ # raise ValueError(f"JAWS error: {j['detail']}") from e
429
+ raise ValueError (f"JAWS error: { e } " ) from e
430
+ #except json.JSONDecodeError as je:
431
+ # raise ValueError(f"JAWS returned invalid JSON ({je}) in error: {e}") from e
432
+ try :
433
+ j = json .loads (res )
434
+ if "run_id" not in j :
435
+ raise ValueError (f"JAWS returned no run_id in JSON { res } " )
436
+ run_id = j ["run_id" ]
437
+ logging .getLogger (__name__ ).info (
438
+ f"Submitted JAWS job with run id { run_id } for job { job .id } "
439
+ )
440
+ return run_id
441
+ except json .JSONDecodeError as e :
442
+ raise ValueError (f"JAWS returned invalid JSON: { e } \n { res } " ) from e
443
+
444
+ async def _generate_and_load_job_files_to_nersc (
445
+ self , cli : AsyncClient , job : models .Job , concurrency : int
446
+ ):
397
447
manifest_files = generate_manifest_files (job )
398
448
manifest_file_paths = self ._get_manifest_file_paths (job .id , len (manifest_files ))
399
- fmap = {m : self . _localize_s3_path ( job . id , m .file ) for m in job .job_input .input_files }
449
+ fmap = {m : _JOB_FILES / m .file for m in job .job_input .input_files }
400
450
wdljson = wdl .generate_wdl (job , fmap , manifest_file_paths )
401
- uploads = {fp : f for fp , f in zip (manifest_file_paths , manifest_files )}
402
451
pre = self ._dtn_scratch / _CTS_SCRATCH_ROOT_DIR / job .id
403
- wdlpath = pre / "input.wdl"
404
- jsonpath = pre / "input.json"
405
- uploads [wdlpath ] = wdljson .wdl
406
- uploads [jsonpath ] = json .dumps (wdljson .input_json , indent = 4 )
407
- cli = self ._client_provider ()
452
+ downloads = {pre / fp : f for fp , f in zip (manifest_file_paths , manifest_files )}
453
+ downloads [pre / _JAWS_INPUT_WDL ] = wdljson .wdl
454
+ downloads [pre / _JAWS_INPUT_JSON ] = json .dumps (wdljson .input_json , indent = 4 )
408
455
dt = await cli .compute (_DT_TARGET )
409
- semaphore = asyncio .Semaphore (file_download_concurrency )
456
+ semaphore = asyncio .Semaphore (concurrency )
410
457
async def sem_coro (coro ):
411
458
async with semaphore :
412
459
return await coro
413
460
coros = []
414
461
try :
415
462
async with asyncio .TaskGroup () as tg :
416
- for path , file in uploads .items ():
463
+ for path , file in downloads .items ():
417
464
coros .append (self ._upload_file_to_nersc (
418
465
dt , path , bio = io .BytesIO (file .encode ())
419
466
))
@@ -425,11 +472,8 @@ async def sem_coro(coro):
425
472
# otherwise you can get coroutine never awaited warnings if a failure occurs
426
473
for c in coros :
427
474
c .close ()
428
- # TODO NEXT run jaws job
429
- return "fake_job_id"
430
475
431
476
def _get_manifest_file_paths(self, job_id: str, count: int) -> list[Path]:
    """
    Return the relative paths of a job's manifest files.

    job_id - the job ID. Currently unused (paths are now relative to the
        job's scratch directory rather than absolute); kept for interface
        stability.
    count - the number of manifest files.

    Paths are relative (see _JOB_MANIFESTS) and numbered from 1,
    e.g. manifests/manifest-1.
    """
    # range(1, count + 1) is empty for count <= 0, so the former
    # `if count == 0: return []` special case is unnecessary.
    return [_JOB_MANIFESTS / f"{_MANIFEST_FILE_PREFIX}{c}" for c in range(1, count + 1)]
0 commit comments