Skip to content

Commit cfc2157

Browse files
authored
AIP-72: Handling failed TI state for AirflowFailException & AirflowSensorTimeout (#44954)
related: #44414 We already have support (almost complete) for handling terminal states on the task execution side as well as on the task SDK client side, and the failed state is one of the terminal states. This PR extends the task runner's `run` function to handle the cases in which a task has to be failed: `AirflowFailException` and `AirflowSensorTimeout`. It is functionally very similar to #44786. As part of failing a task, several other things also need to be done, such as: - Callbacks: which will eventually be converted to teardown tasks - Retries: handled in #44351 - Unmapping TIs: #44351 - Handling task history: will be handled by #44952 - Handling downstream tasks and non-teardown tasks: will be handled by #44951 ### Testing performed #### End to End with Postman 1. Run Airflow with breeze and run any DAG ![image](https://github.com/user-attachments/assets/fafc89ea-4e28-4802-912b-d72bf401d94b) 2. Log in to the metadata DB and get the "id" for your task instance from the TI table ![image](https://github.com/user-attachments/assets/75440f0f-f62a-4277-a2e6-cb78bd666dd4) 3. Send a request to `fail` your task ![image](https://github.com/user-attachments/assets/5991e944-f416-4b79-9954-15f1a6ebdd79) Or using curl: ``` curl --location --request PATCH 'http://localhost:29091/execution/task-instances/0193cec2-f46b-7348-9c27-9869d835dc7b/state' \ --header 'Content-Type: application/json' \ --data '{ "state": "failed", "end_date": "2024-10-31T12:00:00Z" }' ``` 4. Refresh the Airflow UI to see that the task is in the failed state. ![image](https://github.com/user-attachments/assets/bb866dc6-e1d6-435e-abe4-2d04c97280ad)
1 parent f631bef commit cfc2157

File tree

5 files changed

+96
-2
lines changed

5 files changed

+96
-2
lines changed

airflow/api_fastapi/execution_api/routes/task_instances.py

+4
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,10 @@ def ti_update_state(
201201

202202
if isinstance(ti_patch_payload, TITerminalStatePayload):
203203
query = TI.duration_expression_update(ti_patch_payload.end_date, query, session.bind)
204+
query = query.values(state=ti_patch_payload.state)
205+
if ti_patch_payload.state == State.FAILED:
206+
# clear the next_method and next_kwargs
207+
query = query.values(next_method=None, next_kwargs=None)
204208
elif isinstance(ti_patch_payload, TIDeferredStatePayload):
205209
# Calculate timeout if it was passed
206210
timeout = None

task_sdk/src/airflow/sdk/execution_time/task_runner.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -283,7 +283,16 @@ def run(ti: RuntimeTaskInstance, log: Logger):
283283
...
284284
except (AirflowFailException, AirflowSensorTimeout):
285285
# If AirflowFailException is raised, task should not retry.
286-
...
286+
# If a sensor in reschedule mode reaches timeout, task should not retry.
287+
288+
# TODO: Handle fail_stop here: https://github.com/apache/airflow/issues/44951
289+
# TODO: Handle addition to Log table: https://github.com/apache/airflow/issues/44952
290+
msg = TaskState(
291+
state=TerminalTIState.FAILED,
292+
end_date=datetime.now(tz=timezone.utc),
293+
)
294+
295+
# TODO: Run task failure callbacks here
287296
except (AirflowTaskTimeout, AirflowException, AirflowTaskTerminated):
288297
...
289298
except SystemExit:

task_sdk/tests/execution_time/test_supervisor.py

+2
Original file line numberDiff line numberDiff line change
@@ -854,6 +854,8 @@ def watched_subprocess(self, mocker):
854854
{"ok": True},
855855
id="set_xcom_with_map_index",
856856
),
857+
# we aren't adding all states under TerminalTIState here, because this test's scope is only to check
858+
# if it can handle TaskState message
857859
pytest.param(
858860
TaskState(state=TerminalTIState.SKIPPED, end_date=timezone.parse("2024-10-31T12:00:00Z")),
859861
b"",

task_sdk/tests/execution_time/test_task_runner.py

+50-1
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
import pytest
2727
from uuid6 import uuid7
2828

29-
from airflow.exceptions import AirflowSkipException
29+
from airflow.exceptions import AirflowFailException, AirflowSensorTimeout, AirflowSkipException
3030
from airflow.sdk import DAG, BaseOperator
3131
from airflow.sdk.api.datamodels._generated import TaskInstance, TerminalTIState
3232
from airflow.sdk.execution_time.comms import DeferTask, SetRenderedFields, StartupDetails, TaskState
@@ -333,6 +333,55 @@ def __init__(self, *args, **kwargs):
333333
)
334334

335335

336+
@pytest.mark.parametrize(
337+
["dag_id", "task_id", "fail_with_exception"],
338+
[
339+
pytest.param(
340+
"basic_failed", "fail-exception", AirflowFailException("Oops. Failing by AirflowFailException!")
341+
),
342+
pytest.param(
343+
"basic_failed2",
344+
"sensor-timeout-exception",
345+
AirflowSensorTimeout("Oops. Failing by AirflowSensorTimeout!"),
346+
),
347+
],
348+
)
349+
def test_run_basic_failed(time_machine, mocked_parse, dag_id, task_id, fail_with_exception, make_ti_context):
350+
"""Test running a basic task that marks itself as failed by raising exception."""
351+
352+
class CustomOperator(BaseOperator):
353+
def __init__(self, e, *args, **kwargs):
354+
super().__init__(*args, **kwargs)
355+
self.e = e
356+
357+
def execute(self, context):
358+
print(f"raising exception {self.e}")
359+
raise self.e
360+
361+
task = CustomOperator(task_id=task_id, e=fail_with_exception)
362+
363+
what = StartupDetails(
364+
ti=TaskInstance(id=uuid7(), task_id=task_id, dag_id=dag_id, run_id="c", try_number=1),
365+
file="",
366+
requests_fd=0,
367+
ti_context=make_ti_context(),
368+
)
369+
370+
ti = mocked_parse(what, dag_id, task)
371+
372+
instant = timezone.datetime(2024, 12, 3, 10, 0)
373+
time_machine.move_to(instant, tick=False)
374+
375+
with mock.patch(
376+
"airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True
377+
) as mock_supervisor_comms:
378+
run(ti, log=mock.MagicMock())
379+
380+
mock_supervisor_comms.send_request.assert_called_once_with(
381+
msg=TaskState(state=TerminalTIState.FAILED, end_date=instant), log=mock.ANY
382+
)
383+
384+
336385
class TestRuntimeTaskInstance:
337386
def test_get_context_without_ti_context_from_server(self, mocked_parse, make_ti_context):
338387
"""Test get_template_context without ti_context_from_server."""

tests/api_fastapi/execution_api/routes/test_task_instances.py

+30
Original file line numberDiff line numberDiff line change
@@ -455,6 +455,36 @@ def test_ti_heartbeat_update(self, client, session, create_task_instance, time_m
455455
session.refresh(ti)
456456
assert ti.last_heartbeat_at == time_now.add(minutes=10)
457457

458+
def test_ti_update_state_to_failed_table_check(self, client, session, create_task_instance):
459+
from math import ceil
460+
461+
ti = create_task_instance(
462+
task_id="test_ti_update_state_to_failed_table_check",
463+
state=State.RUNNING,
464+
)
465+
ti.start_date = DEFAULT_START_DATE
466+
session.commit()
467+
468+
response = client.patch(
469+
f"/execution/task-instances/{ti.id}/state",
470+
json={
471+
"state": State.FAILED,
472+
"end_date": DEFAULT_END_DATE.isoformat(),
473+
},
474+
)
475+
476+
assert response.status_code == 204
477+
assert response.text == ""
478+
479+
session.expire_all()
480+
481+
ti = session.get(TaskInstance, ti.id)
482+
assert ti.state == State.FAILED
483+
assert ti.next_method is None
484+
assert ti.next_kwargs is None
485+
# TODO: remove/amend this once https://github.com/apache/airflow/pull/45002 is merged
486+
assert ceil(ti.duration) == 3600.00
487+
458488

459489
class TestTIPutRTIF:
460490
def setup_method(self):

0 commit comments

Comments
 (0)