@@ -141,42 +141,48 @@ def get_task_instance(self, *, session: Session) -> TaskInstance:
141141
142142 :param session: Sqlalchemy session
143143 """
144+ ti = self .task_instance
145+ if ti is None :
146+ raise RuntimeError ("task_instance is not set on the trigger" )
144147 task_instance = session .scalar (
145148 select (TaskInstance ).where (
146- TaskInstance .dag_id == self . task_instance .dag_id ,
147- TaskInstance .task_id == self . task_instance .task_id ,
148- TaskInstance .run_id == self . task_instance .run_id ,
149- TaskInstance .map_index == self . task_instance .map_index ,
149+ TaskInstance .dag_id == ti .dag_id ,
150+ TaskInstance .task_id == ti .task_id ,
151+ TaskInstance .run_id == ti .run_id ,
152+ TaskInstance .map_index == ti .map_index ,
150153 )
151154 )
152155 if task_instance is None :
153156 raise AirflowException (
154157 "TaskInstance with dag_id: %s,task_id: %s, run_id: %s and map_index: %s is not found" ,
155- self . task_instance .dag_id ,
156- self . task_instance .task_id ,
157- self . task_instance .run_id ,
158- self . task_instance .map_index ,
158+ ti .dag_id ,
159+ ti .task_id ,
160+ ti .run_id ,
161+ ti .map_index ,
159162 )
160163 return task_instance
161164
162165 async def get_task_state (self ):
163166 from airflow .sdk .execution_time .task_runner import RuntimeTaskInstance
164167
168+ ti = self .task_instance
169+ if ti is None :
170+ raise RuntimeError ("task_instance is not set on the trigger" )
165171 task_states_response = await sync_to_async (RuntimeTaskInstance .get_task_states )(
166- dag_id = self . task_instance .dag_id ,
167- task_ids = [self . task_instance .task_id ],
168- run_ids = [self . task_instance .run_id ],
169- map_index = self . task_instance .map_index ,
172+ dag_id = ti .dag_id ,
173+ task_ids = [ti .task_id ],
174+ run_ids = [ti .run_id ],
175+ map_index = ti .map_index ,
170176 )
171177 try :
172- task_state = task_states_response [self . task_instance . run_id ][self . task_instance .task_id ]
178+ task_state = task_states_response [ti . run_id ][ti .task_id ]
173179 except Exception :
174180 raise AirflowException (
175181 "TaskInstance with dag_id: %s, task_id: %s, run_id: %s and map_index: %s is not found" ,
176- self . task_instance .dag_id ,
177- self . task_instance .task_id ,
178- self . task_instance .run_id ,
179- self . task_instance .map_index ,
182+ ti .dag_id ,
183+ ti .task_id ,
184+ ti .run_id ,
185+ ti .map_index ,
180186 )
181187 return task_state
182188
@@ -293,42 +299,48 @@ def get_task_instance(self, *, session: Session) -> TaskInstance:
293299
294300 :param session: Sqlalchemy session
295301 """
302+ ti = self .task_instance
303+ if ti is None :
304+ raise RuntimeError ("task_instance is not set on the trigger" )
296305 task_instance = session .scalar (
297306 select (TaskInstance ).where (
298- TaskInstance .dag_id == self . task_instance .dag_id ,
299- TaskInstance .task_id == self . task_instance .task_id ,
300- TaskInstance .run_id == self . task_instance .run_id ,
301- TaskInstance .map_index == self . task_instance .map_index ,
307+ TaskInstance .dag_id == ti .dag_id ,
308+ TaskInstance .task_id == ti .task_id ,
309+ TaskInstance .run_id == ti .run_id ,
310+ TaskInstance .map_index == ti .map_index ,
302311 )
303312 )
304313 if task_instance is None :
305314 raise RuntimeError (
306315 "TaskInstance with dag_id: %s,task_id: %s, run_id: %s and map_index: %s is not found" ,
307- self . task_instance .dag_id ,
308- self . task_instance .task_id ,
309- self . task_instance .run_id ,
310- self . task_instance .map_index ,
316+ ti .dag_id ,
317+ ti .task_id ,
318+ ti .run_id ,
319+ ti .map_index ,
311320 )
312321 return task_instance
313322
314323 async def get_task_state (self ):
315324 from airflow .sdk .execution_time .task_runner import RuntimeTaskInstance
316325
326+ ti = self .task_instance
327+ if ti is None :
328+ raise RuntimeError ("task_instance is not set on the trigger" )
317329 task_states_response = await sync_to_async (RuntimeTaskInstance .get_task_states )(
318- dag_id = self . task_instance .dag_id ,
319- task_ids = [self . task_instance .task_id ],
320- run_ids = [self . task_instance .run_id ],
321- map_index = self . task_instance .map_index ,
330+ dag_id = ti .dag_id ,
331+ task_ids = [ti .task_id ],
332+ run_ids = [ti .run_id ],
333+ map_index = ti .map_index ,
322334 )
323335 try :
324- task_state = task_states_response [self . task_instance . run_id ][self . task_instance .task_id ]
336+ task_state = task_states_response [ti . run_id ][ti .task_id ]
325337 except Exception :
326338 raise RuntimeError (
327339 "TaskInstance with dag_id: %s, task_id: %s, run_id: %s and map_index: %s is not found" ,
328- self . task_instance .dag_id ,
329- self . task_instance .task_id ,
330- self . task_instance .run_id ,
331- self . task_instance .map_index ,
340+ ti .dag_id ,
341+ ti .task_id ,
342+ ti .run_id ,
343+ ti .map_index ,
332344 )
333345 return task_state
334346
@@ -432,42 +444,48 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
432444
433445 @provide_session
434446 def get_task_instance (self , * , session : Session ) -> TaskInstance :
447+ ti = self .task_instance
448+ if ti is None :
449+ raise RuntimeError ("task_instance is not set on the trigger" )
435450 task_instance = session .scalar (
436451 select (TaskInstance ).where (
437- TaskInstance .dag_id == self . task_instance .dag_id ,
438- TaskInstance .task_id == self . task_instance .task_id ,
439- TaskInstance .run_id == self . task_instance .run_id ,
440- TaskInstance .map_index == self . task_instance .map_index ,
452+ TaskInstance .dag_id == ti .dag_id ,
453+ TaskInstance .task_id == ti .task_id ,
454+ TaskInstance .run_id == ti .run_id ,
455+ TaskInstance .map_index == ti .map_index ,
441456 )
442457 )
443458 if task_instance is None :
444459 raise AirflowException (
445460 "TaskInstance with dag_id: %s,task_id: %s, run_id: %s and map_index: %s is not found." ,
446- self . task_instance .dag_id ,
447- self . task_instance .task_id ,
448- self . task_instance .run_id ,
449- self . task_instance .map_index ,
461+ ti .dag_id ,
462+ ti .task_id ,
463+ ti .run_id ,
464+ ti .map_index ,
450465 )
451466 return task_instance
452467
453468 async def get_task_state (self ):
454469 from airflow .sdk .execution_time .task_runner import RuntimeTaskInstance
455470
471+ ti = self .task_instance
472+ if ti is None :
473+ raise RuntimeError ("task_instance is not set on the trigger" )
456474 task_states_response = await sync_to_async (RuntimeTaskInstance .get_task_states )(
457- dag_id = self . task_instance .dag_id ,
458- task_ids = [self . task_instance .task_id ],
459- run_ids = [self . task_instance .run_id ],
460- map_index = self . task_instance .map_index ,
475+ dag_id = ti .dag_id ,
476+ task_ids = [ti .task_id ],
477+ run_ids = [ti .run_id ],
478+ map_index = ti .map_index ,
461479 )
462480 try :
463- task_state = task_states_response [self . task_instance . run_id ][self . task_instance .task_id ]
481+ task_state = task_states_response [ti . run_id ][ti .task_id ]
464482 except Exception :
465483 raise AirflowException (
466484 "TaskInstance with dag_id: %s, task_id: %s, run_id: %s and map_index: %s is not found" ,
467- self . task_instance .dag_id ,
468- self . task_instance .task_id ,
469- self . task_instance .run_id ,
470- self . task_instance .map_index ,
485+ ti .dag_id ,
486+ ti .task_id ,
487+ ti .run_id ,
488+ ti .map_index ,
471489 )
472490 return task_state
473491
0 commit comments