Skip to content

Commit 2746dcb

Browse files
authored
Fix mypy errors for task_instance access in provider triggers (#68685)
1 parent 88b8928 commit 2746dcb

4 files changed

Lines changed: 114 additions & 84 deletions

File tree

  • providers
    • amazon/src/airflow/providers/amazon/aws/triggers
    • cncf/kubernetes/src/airflow/providers/cncf/kubernetes/triggers
    • google/src/airflow/providers/google/cloud/triggers

providers/amazon/src/airflow/providers/amazon/aws/triggers/emr.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -423,19 +423,22 @@ def get_task_instance(self, *, session: Session) -> TaskInstance:
423423
"""Get the task instance for the current trigger (Airflow 2.x compatibility)."""
424424
from sqlalchemy import select
425425

426+
ti = self.task_instance
427+
if ti is None:
428+
raise RuntimeError("task_instance is not set on the trigger")
426429
query = select(TaskInstance).where(
427-
TaskInstance.dag_id == self.task_instance.dag_id,
428-
TaskInstance.task_id == self.task_instance.task_id,
429-
TaskInstance.run_id == self.task_instance.run_id,
430-
TaskInstance.map_index == self.task_instance.map_index,
430+
TaskInstance.dag_id == ti.dag_id,
431+
TaskInstance.task_id == ti.task_id,
432+
TaskInstance.run_id == ti.run_id,
433+
TaskInstance.map_index == ti.map_index,
431434
)
432435
task_instance = session.scalars(query).one_or_none()
433436
if task_instance is None:
434437
raise ValueError(
435-
f"TaskInstance with dag_id: {self.task_instance.dag_id}, "
436-
f"task_id: {self.task_instance.task_id}, "
437-
f"run_id: {self.task_instance.run_id} and "
438-
f"map_index: {self.task_instance.map_index} is not found"
438+
f"TaskInstance with dag_id: {ti.dag_id}, "
439+
f"task_id: {ti.task_id}, "
440+
f"run_id: {ti.run_id} and "
441+
f"map_index: {ti.map_index} is not found"
439442
)
440443
return task_instance
441444

providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/triggers/pod.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -440,21 +440,24 @@ def pod_manager(self) -> AsyncPodManager:
440440
@provide_session
def get_task_instance(self, *, session: Session) -> TaskInstance:
    """
    Get the task instance for this trigger from the database (Airflow 2.x only).

    :param session: Sqlalchemy session used to run the lookup query.
    :return: The ``TaskInstance`` row whose dag_id, task_id, run_id and map_index
        match ``self.task_instance``.
    :raises RuntimeError: If ``task_instance`` is not set on the trigger.
    :raises AirflowException: If no matching ``TaskInstance`` row is found.
    """
    ti = self.task_instance
    if ti is None:
        raise RuntimeError("task_instance is not set on the trigger")
    task_instance = session.scalar(
        select(TaskInstance).where(
            TaskInstance.dag_id == ti.dag_id,
            TaskInstance.task_id == ti.task_id,
            TaskInstance.run_id == ti.run_id,
            TaskInstance.map_index == ti.map_index,
        )
    )
    if task_instance is None:
        # Unlike logging calls, exceptions never apply %-style arguments, so the
        # message is rendered eagerly instead of being passed as a template + args.
        raise AirflowException(
            f"TaskInstance with dag_id: {ti.dag_id}, task_id: {ti.task_id}, "
            f"run_id: {ti.run_id} and map_index: {ti.map_index} is not found"
        )
    return task_instance
460463

providers/google/src/airflow/providers/google/cloud/triggers/bigquery.py

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -120,42 +120,48 @@ async def on_kill(self) -> None:
120120

121121
@provide_session
def get_task_instance(self, *, session: Session) -> TaskInstance:
    """
    Get the task instance for this trigger from the database.

    :param session: Sqlalchemy session used to run the lookup query.
    :return: The ``TaskInstance`` row whose dag_id, task_id, run_id and map_index
        match ``self.task_instance``.
    :raises RuntimeError: If ``task_instance`` is not set on the trigger.
    :raises AirflowException: If no matching ``TaskInstance`` row is found.
    """
    ti = self.task_instance
    if ti is None:
        raise RuntimeError("task_instance is not set on the trigger")
    task_instance = session.scalar(
        select(TaskInstance).where(
            TaskInstance.dag_id == ti.dag_id,
            TaskInstance.task_id == ti.task_id,
            TaskInstance.run_id == ti.run_id,
            TaskInstance.map_index == ti.map_index,
        )
    )
    if task_instance is None:
        # Unlike logging calls, exceptions never apply %-style arguments, so the
        # message is rendered eagerly instead of being passed as a template + args.
        raise AirflowException(
            f"TaskInstance with dag_id: {ti.dag_id}, task_id: {ti.task_id}, "
            f"run_id: {ti.run_id} and map_index: {ti.map_index} is not found"
        )
    return task_instance
140143

141144
async def get_task_state(self):
    """
    Look up the state of this trigger's task via ``RuntimeTaskInstance.get_task_states``.

    :return: The entry found under ``[run_id][task_id]`` in the task-states response.
    :raises RuntimeError: If ``task_instance`` is not set on the trigger.
    :raises AirflowException: If the response has no entry for this run_id/task_id.
    """
    # Validate before the local import so an unset task_instance is reported as
    # such, regardless of whether the import below succeeds.
    ti = self.task_instance
    if ti is None:
        raise RuntimeError("task_instance is not set on the trigger")

    from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance

    # get_task_states is a synchronous call; wrap it so it can be awaited here.
    task_states_response = await sync_to_async(RuntimeTaskInstance.get_task_states)(
        dag_id=ti.dag_id,
        task_ids=[ti.task_id],
        run_ids=[ti.run_id],
        map_index=ti.map_index,
    )
    try:
        task_state = task_states_response[ti.run_id][ti.task_id]
    except Exception as err:
        # Exceptions never apply %-style arguments, so render the message eagerly;
        # chain the lookup failure so the original cause stays in the traceback.
        raise AirflowException(
            f"TaskInstance with dag_id: {ti.dag_id}, task_id: {ti.task_id}, "
            f"run_id: {ti.run_id} and map_index: {ti.map_index} is not found"
        ) from err
    return task_state
161167

providers/google/src/airflow/providers/google/cloud/triggers/dataproc.py

Lines changed: 69 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -141,42 +141,48 @@ def get_task_instance(self, *, session: Session) -> TaskInstance:
141141
142142
:param session: Sqlalchemy session
143143
"""
144+
ti = self.task_instance
145+
if ti is None:
146+
raise RuntimeError("task_instance is not set on the trigger")
144147
task_instance = session.scalar(
145148
select(TaskInstance).where(
146-
TaskInstance.dag_id == self.task_instance.dag_id,
147-
TaskInstance.task_id == self.task_instance.task_id,
148-
TaskInstance.run_id == self.task_instance.run_id,
149-
TaskInstance.map_index == self.task_instance.map_index,
149+
TaskInstance.dag_id == ti.dag_id,
150+
TaskInstance.task_id == ti.task_id,
151+
TaskInstance.run_id == ti.run_id,
152+
TaskInstance.map_index == ti.map_index,
150153
)
151154
)
152155
if task_instance is None:
153156
raise AirflowException(
154157
"TaskInstance with dag_id: %s,task_id: %s, run_id: %s and map_index: %s is not found",
155-
self.task_instance.dag_id,
156-
self.task_instance.task_id,
157-
self.task_instance.run_id,
158-
self.task_instance.map_index,
158+
ti.dag_id,
159+
ti.task_id,
160+
ti.run_id,
161+
ti.map_index,
159162
)
160163
return task_instance
161164

162165
async def get_task_state(self):
    """
    Look up the state of this trigger's task via ``RuntimeTaskInstance.get_task_states``.

    :return: The entry found under ``[run_id][task_id]`` in the task-states response.
    :raises RuntimeError: If ``task_instance`` is not set on the trigger.
    :raises AirflowException: If the response has no entry for this run_id/task_id.
    """
    # Validate before the local import so an unset task_instance is reported as
    # such, regardless of whether the import below succeeds.
    ti = self.task_instance
    if ti is None:
        raise RuntimeError("task_instance is not set on the trigger")

    from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance

    # get_task_states is a synchronous call; wrap it so it can be awaited here.
    task_states_response = await sync_to_async(RuntimeTaskInstance.get_task_states)(
        dag_id=ti.dag_id,
        task_ids=[ti.task_id],
        run_ids=[ti.run_id],
        map_index=ti.map_index,
    )
    try:
        task_state = task_states_response[ti.run_id][ti.task_id]
    except Exception as err:
        # Exceptions never apply %-style arguments, so render the message eagerly;
        # chain the lookup failure so the original cause stays in the traceback.
        raise AirflowException(
            f"TaskInstance with dag_id: {ti.dag_id}, task_id: {ti.task_id}, "
            f"run_id: {ti.run_id} and map_index: {ti.map_index} is not found"
        ) from err
    return task_state
182188

@@ -293,42 +299,48 @@ def get_task_instance(self, *, session: Session) -> TaskInstance:
293299
294300
:param session: Sqlalchemy session
295301
"""
302+
ti = self.task_instance
303+
if ti is None:
304+
raise RuntimeError("task_instance is not set on the trigger")
296305
task_instance = session.scalar(
297306
select(TaskInstance).where(
298-
TaskInstance.dag_id == self.task_instance.dag_id,
299-
TaskInstance.task_id == self.task_instance.task_id,
300-
TaskInstance.run_id == self.task_instance.run_id,
301-
TaskInstance.map_index == self.task_instance.map_index,
307+
TaskInstance.dag_id == ti.dag_id,
308+
TaskInstance.task_id == ti.task_id,
309+
TaskInstance.run_id == ti.run_id,
310+
TaskInstance.map_index == ti.map_index,
302311
)
303312
)
304313
if task_instance is None:
305314
raise RuntimeError(
306315
"TaskInstance with dag_id: %s,task_id: %s, run_id: %s and map_index: %s is not found",
307-
self.task_instance.dag_id,
308-
self.task_instance.task_id,
309-
self.task_instance.run_id,
310-
self.task_instance.map_index,
316+
ti.dag_id,
317+
ti.task_id,
318+
ti.run_id,
319+
ti.map_index,
311320
)
312321
return task_instance
313322

314323
async def get_task_state(self):
    """
    Look up the state of this trigger's task via ``RuntimeTaskInstance.get_task_states``.

    :return: The entry found under ``[run_id][task_id]`` in the task-states response.
    :raises RuntimeError: If ``task_instance`` is not set on the trigger, or if the
        response has no entry for this run_id/task_id.
    """
    # Validate before the local import so an unset task_instance is reported as
    # such, regardless of whether the import below succeeds.
    ti = self.task_instance
    if ti is None:
        raise RuntimeError("task_instance is not set on the trigger")

    from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance

    # get_task_states is a synchronous call; wrap it so it can be awaited here.
    task_states_response = await sync_to_async(RuntimeTaskInstance.get_task_states)(
        dag_id=ti.dag_id,
        task_ids=[ti.task_id],
        run_ids=[ti.run_id],
        map_index=ti.map_index,
    )
    try:
        task_state = task_states_response[ti.run_id][ti.task_id]
    except Exception as err:
        # Exceptions never apply %-style arguments, so render the message eagerly;
        # chain the lookup failure so the original cause stays in the traceback.
        raise RuntimeError(
            f"TaskInstance with dag_id: {ti.dag_id}, task_id: {ti.task_id}, "
            f"run_id: {ti.run_id} and map_index: {ti.map_index} is not found"
        ) from err
    return task_state
334346

@@ -432,42 +444,48 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
432444

433445
@provide_session
def get_task_instance(self, *, session: Session) -> TaskInstance:
    """
    Get the task instance for this trigger from the database.

    :param session: Sqlalchemy session used to run the lookup query.
    :return: The ``TaskInstance`` row whose dag_id, task_id, run_id and map_index
        match ``self.task_instance``.
    :raises RuntimeError: If ``task_instance`` is not set on the trigger.
    :raises AirflowException: If no matching ``TaskInstance`` row is found.
    """
    ti = self.task_instance
    if ti is None:
        raise RuntimeError("task_instance is not set on the trigger")
    task_instance = session.scalar(
        select(TaskInstance).where(
            TaskInstance.dag_id == ti.dag_id,
            TaskInstance.task_id == ti.task_id,
            TaskInstance.run_id == ti.run_id,
            TaskInstance.map_index == ti.map_index,
        )
    )
    if task_instance is None:
        # Unlike logging calls, exceptions never apply %-style arguments, so the
        # message is rendered eagerly instead of being passed as a template + args.
        raise AirflowException(
            f"TaskInstance with dag_id: {ti.dag_id},task_id: {ti.task_id}, "
            f"run_id: {ti.run_id} and map_index: {ti.map_index} is not found."
        )
    return task_instance
452467

453468
async def get_task_state(self):
    """
    Look up the state of this trigger's task via ``RuntimeTaskInstance.get_task_states``.

    :return: The entry found under ``[run_id][task_id]`` in the task-states response.
    :raises RuntimeError: If ``task_instance`` is not set on the trigger.
    :raises AirflowException: If the response has no entry for this run_id/task_id.
    """
    # Validate before the local import so an unset task_instance is reported as
    # such, regardless of whether the import below succeeds.
    ti = self.task_instance
    if ti is None:
        raise RuntimeError("task_instance is not set on the trigger")

    from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance

    # get_task_states is a synchronous call; wrap it so it can be awaited here.
    task_states_response = await sync_to_async(RuntimeTaskInstance.get_task_states)(
        dag_id=ti.dag_id,
        task_ids=[ti.task_id],
        run_ids=[ti.run_id],
        map_index=ti.map_index,
    )
    try:
        task_state = task_states_response[ti.run_id][ti.task_id]
    except Exception as err:
        # Exceptions never apply %-style arguments, so render the message eagerly;
        # chain the lookup failure so the original cause stays in the traceback.
        raise AirflowException(
            f"TaskInstance with dag_id: {ti.dag_id}, task_id: {ti.task_id}, "
            f"run_id: {ti.run_id} and map_index: {ti.map_index} is not found"
        ) from err
    return task_state
473491

0 commit comments

Comments
 (0)