from dagster._core.launcher.base import (
    CheckRunHealthResult,
    LaunchRunContext,
-    ResumeRunContext,
    RunLauncher,
    WorkerStatus,
)
from dagster._core.origin import JobPythonOrigin
from dagster._core.storage.dagster_run import DagsterRun
from dagster._core.storage.tags import DOCKER_IMAGE_TAG
from dagster._core.utils import parse_env_var
-from dagster._grpc.types import ExecuteRunArgs, ResumeRunArgs
+from dagster._grpc.types import ExecuteRunArgs
from dagster._serdes import ConfigurableClass
from dagster._serdes.config_class import ConfigurableClassData
from typing_extensions import Self
@@ -115,64 +114,39 @@ def _get_semantic_job_name(self, dagster_run: DagsterRun) -> str:
            return list(dagster_run.asset_selection)[-1].to_user_string()
        return dagster_run.run_id

-    def _create_or_update_job_definition(
+    def _get_or_create_job_definition(
        self,
        client: scaleway.Client,
        run: DagsterRun,
        docker_image: str,
-        command: list[str],
    ) -> scw.JobDefinition:
        serverless_job_context = self.get_serverless_job_context(run)
        api = scw.JobsV1Alpha1API(client)

-        job_def_env = dict(
-            [parse_env_var(env_var) for env_var in serverless_job_context.env_vars]
-        )
-        job_def_env["DAGSTER_RUN_JOB_NAME"] = run.job_name
-        job_def_env["DAGSTER_RUN_ID"] = run.run_id
-        job_def_env["INPUT_JSON"] = command[-1]
-
-        api = scw.JobsV1Alpha1API(client)
-
-        wrapped_command = [COMMAND_WRAPPER] + command[:-1]
-        description = (
-            f"JobDefinition for {run.job_name}."
-            + " "
-            + "Created by the ServerlessJobRunLauncher from dagster-scaleway."
-        )
        job_def_name = self._get_semantic_job_name(run)

        for job_def in api.list_job_definitions_all():
            if job_def.name != job_def_name:
                continue

-            job_def = api.update_job_definition(
-                job_definition_id=job_def.id,
-                name=job_def_name,
-                image_uri=docker_image,
-                environment_variables=job_def_env,
-                command=" ".join(wrapped_command),
-                memory_limit=serverless_job_context.memory_limit,
-                cpu_limit=serverless_job_context.cpu_limit,
-                description=description,
+            return self._update_job_definition(
+                client, serverless_job_context, job_def, docker_image
            )

-            self._instance.report_engine_event(
-                message=f"Updated job {job_def.id} for Dagster run {run.run_id}",
-                dagster_run=run,
-                cls=self.__class__,
-            )
+        job_def_env = dict(
+            [parse_env_var(env_var) for env_var in serverless_job_context.env_vars]
+        )

-            return job_def
+        job_def_name = self._get_semantic_job_name(run)

        job_def = api.create_job_definition(
            name=job_def_name,
            image_uri=docker_image,
            environment_variables=job_def_env,
-            command=" ".join(wrapped_command),
+            command=COMMAND_WRAPPER,
            memory_limit=serverless_job_context.memory_limit,
            cpu_limit=serverless_job_context.cpu_limit,
-            description=description,
+            description="Created by the ServerlessJobRunLauncher from dagster-scaleway.",
            project_id=client.default_project_id,
        )

@@ -184,21 +158,70 @@ def _create_or_update_job_definition(
        return job_def

-    def _launch_serverless_job_with_command(
+    def _update_job_definition(
+        self,
+        client: scaleway.Client,
+        context: ScalewayServerlessJobContext,
+        job_def: scw.JobDefinition,
+        docker_image: str,
+    ) -> scw.JobDefinition:
+        """Reconcile the job definition with the desired state"""
+        api = scw.JobsV1Alpha1API(client)
+        has_changed = False
+
+        if job_def.image_uri != docker_image:
+            job_def.image_uri = docker_image
+            has_changed = True
+
+        if job_def.memory_limit != context.memory_limit:
+            job_def.memory_limit = context.memory_limit
+            has_changed = True
+
+        if job_def.cpu_limit != context.cpu_limit:
+            job_def.cpu_limit = context.cpu_limit
+            has_changed = True
+
+        if not has_changed:
+            return job_def
+
+        return api.update_job_definition(
+            job_definition_id=job_def.id,
+            image_uri=docker_image,
+            memory_limit=context.memory_limit,
+            cpu_limit=context.cpu_limit,
+        )
+
+    def _start_serverless_job_with_command(
        self, run: DagsterRun, docker_image: str, command: list[str]
    ):
        serverless_job_context = self.get_serverless_job_context(run)
        client = self._get_client(serverless_job_context)
        api = scw.JobsV1Alpha1API(client)

-        job_def = self._create_or_update_job_definition(
-            client, run, docker_image, command
+        job_def = self._get_or_create_job_definition(client, run, docker_image)
+
+        extra_env = {
+            "DAGSTER_RUN_JOB_NAME": run.job_name,
+            "DAGSTER_RUN_ID": run.run_id,
+            "INPUT_JSON": command[-1],
+        }
+        wrapped_command = [COMMAND_WRAPPER] + command[:-1]
+
+        start_response = api.start_job_definition(
+            job_definition_id=job_def.id,
+            command=" ".join(wrapped_command),
+            environment_variables=extra_env,
        )

-        job_run = api.start_job_definition(job_definition_id=job_def.id)
+        if len(start_response.job_runs) != 1:
+            raise RuntimeError(
+                f"Expected 1 job run to be created, got {len(start_response.job_runs)}"
+            )
+
+        job_run = start_response.job_runs[0]

        self._instance.report_engine_event(
-            message=f"Started job definition {job_def.name} with run id {job_run.id} for Dagster run {run.run_id}",
+            message=f"Started job definition {job_def.name} with job run id {job_run.id} for Dagster run {run.run_id}",
            dagster_run=run,
            cls=self.__class__,
        )
@@ -223,25 +246,11 @@ def launch_run(self, context: LaunchRunContext) -> None:
            instance_ref=self._instance.get_ref(),
        ).get_command_args()

-        self._launch_serverless_job_with_command(run, docker_image, command)
+        self._start_serverless_job_with_command(run, docker_image, list(command))

    @property
    def supports_resume_run(self):
-        # TODO?: check if we can resume a run
-        return True
-
-    def resume_run(self, context: ResumeRunContext) -> None:
-        run = context.dagster_run
-        job_code_origin = check.not_none(context.job_code_origin)
-        docker_image = self._get_docker_image(job_code_origin)
-
-        command = ResumeRunArgs(
-            job_origin=job_code_origin,
-            run_id=run.run_id,
-            instance_ref=self._instance.get_ref(),
-        ).get_command_args()
-
-        self._launch_serverless_job_with_command(run, docker_image, command)
+        return False

    def _get_scaleway_job_run_from_dagster_run(self, run) -> Optional[scw.JobRun]:
        if not run or run.is_finished:
@@ -303,18 +312,18 @@ def check_run_worker_health(self, run: DagsterRun):
        job_run = self._get_scaleway_job_run_from_dagster_run(run)
        if job_run is None:
            return CheckRunHealthResult(
-                WorkerStatus.NOT_FOUND,
+                status=WorkerStatus.NOT_FOUND,
                run_worker_id=run.run_id,
                msg=f"Unable to find Scaleway job run with id {run.run_id} for Dagster run {run.run_id}",
            )

-        health = CheckRunHealthResult(run_worker_id=run.run_id)
-        health.transient = job_run.state in scw.JOB_RUN_TRANSIENT_STATUSES
-        health.status = SERVERLESS_JOBS_STATES_TO_WORKER_STATUS.get(
-            job_run.state, WorkerStatus.UNKNOWN
+        health = CheckRunHealthResult(
+            status=SERVERLESS_JOBS_STATES_TO_WORKER_STATUS.get(
+                job_run.state, WorkerStatus.UNKNOWN
+            ),
+            run_worker_id=run.run_id,
+            msg=job_run.error_message,
+            transient=job_run.state in scw.JOB_RUN_TRANSIENT_STATUSES,
        )

-        if job_run.error_message:
-            health.msg = job_run.error_message
-
        return health