diff --git a/airflow/api_fastapi/execution_api/routes/task_instances.py b/airflow/api_fastapi/execution_api/routes/task_instances.py
index 3a1545283e81b..9fe23d2f45e0f 100644
--- a/airflow/api_fastapi/execution_api/routes/task_instances.py
+++ b/airflow/api_fastapi/execution_api/routes/task_instances.py
@@ -201,6 +201,8 @@ def ti_update_state(
     if isinstance(ti_patch_payload, TITerminalStatePayload):
         query = TI.duration_expression_update(ti_patch_payload.end_date, query, session.bind)
+        # clear the next_method and next_kwargs for all terminal states, as we do not want retries to pick them
+        query = query.values(state=ti_patch_payload.state, next_method=None, next_kwargs=None)
     elif isinstance(ti_patch_payload, TIDeferredStatePayload):
         # Calculate timeout if it was passed
         timeout = None
diff --git a/airflow/utils/state.py b/airflow/utils/state.py
index e4e2e9db8a587..7027d65ef8e81 100644
--- a/airflow/utils/state.py
+++ b/airflow/utils/state.py
@@ -39,6 +39,7 @@ class TerminalTIState(str, Enum):
     FAILED = "failed"
     SKIPPED = "skipped"  # A user can raise a AirflowSkipException from a task & it will be marked as skipped
     REMOVED = "removed"
+    UP_FOR_RETRY = "up_for_retry"  # We do not need to do anything actionable for this state, hence it is a terminal state.
 
     def __str__(self) -> str:
         return self.value
@@ -50,7 +51,6 @@ class IntermediateTIState(str, Enum):
     SCHEDULED = "scheduled"
     QUEUED = "queued"
     RESTARTING = "restarting"
-    UP_FOR_RETRY = "up_for_retry"
     UP_FOR_RESCHEDULE = "up_for_reschedule"
     UPSTREAM_FAILED = "upstream_failed"
     DEFERRED = "deferred"
@@ -80,7 +80,7 @@ class TaskInstanceState(str, Enum):
     SUCCESS = TerminalTIState.SUCCESS  # Task completed
     RESTARTING = IntermediateTIState.RESTARTING  # External request to restart (e.g. cleared when running)
     FAILED = TerminalTIState.FAILED  # Task errored out
-    UP_FOR_RETRY = IntermediateTIState.UP_FOR_RETRY  # Task failed but has retries left
+    UP_FOR_RETRY = TerminalTIState.UP_FOR_RETRY  # Task failed but has retries left
     UP_FOR_RESCHEDULE = IntermediateTIState.UP_FOR_RESCHEDULE  # A waiting `reschedule` sensor
     UPSTREAM_FAILED = IntermediateTIState.UPSTREAM_FAILED  # One or more upstream deps failed
     SKIPPED = TerminalTIState.SKIPPED  # Skipped by branching or some other mechanism
diff --git a/task_sdk/src/airflow/sdk/api/datamodels/_generated.py b/task_sdk/src/airflow/sdk/api/datamodels/_generated.py
index 5a103e78fc0ff..a3937e48086dc 100644
--- a/task_sdk/src/airflow/sdk/api/datamodels/_generated.py
+++ b/task_sdk/src/airflow/sdk/api/datamodels/_generated.py
@@ -63,7 +63,6 @@ class IntermediateTIState(str, Enum):
     SCHEDULED = "scheduled"
     QUEUED = "queued"
     RESTARTING = "restarting"
-    UP_FOR_RETRY = "up_for_retry"
     UP_FOR_RESCHEDULE = "up_for_reschedule"
     UPSTREAM_FAILED = "upstream_failed"
     DEFERRED = "deferred"
@@ -119,6 +118,7 @@ class TerminalTIState(str, Enum):
     FAILED = "failed"
     SKIPPED = "skipped"
     REMOVED = "removed"
+    UP_FOR_RETRY = "up_for_retry"
 
 
 class ValidationError(BaseModel):
diff --git a/task_sdk/src/airflow/sdk/execution_time/task_runner.py b/task_sdk/src/airflow/sdk/execution_time/task_runner.py
index 92f400d46e2bb..50f89ff064eb9 100644
--- a/task_sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task_sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -283,9 +283,32 @@ def run(ti: RuntimeTaskInstance, log: Logger):
         ...
     except (AirflowFailException, AirflowSensorTimeout):
         # If AirflowFailException is raised, task should not retry.
-        ...
-    except (AirflowTaskTimeout, AirflowException, AirflowTaskTerminated):
-        ...
+        # If a sensor in reschedule mode reaches timeout, task should not retry.
+
+        # TODO: Handle fail_stop here: https://github.com/apache/airflow/issues/44951
+        # TODO: Handle addition to Log table: https://github.com/apache/airflow/issues/44952
+        msg = TaskState(
+            state=TerminalTIState.FAILED,
+            end_date=datetime.now(tz=timezone.utc),
+        )
+
+        # TODO: Run task failure callbacks here
+    except (AirflowTaskTimeout, AirflowException):
+        msg = TaskState(
+            state=TerminalTIState.UP_FOR_RETRY,
+            end_date=datetime.now(tz=timezone.utc),
+        )
+
+        # TODO: Run task retry callbacks here
+    except AirflowTaskTerminated:
+        # External state updates are already handled with `ti_heartbeat` and will already be
+        # updated by another UI API. So, these exceptions should ideally never be thrown.
+        # If these are thrown, we should mark the TI state as failed.
+        msg = TaskState(
+            state=TerminalTIState.FAILED,
+            end_date=datetime.now(tz=timezone.utc),
+        )
+        # TODO: Run task failure callbacks here
     except SystemExit:
         ...
     except BaseException:
         ...
diff --git a/task_sdk/tests/execution_time/test_supervisor.py b/task_sdk/tests/execution_time/test_supervisor.py
index 70f9e26486408..17c336fdcd7fe 100644
--- a/task_sdk/tests/execution_time/test_supervisor.py
+++ b/task_sdk/tests/execution_time/test_supervisor.py
@@ -862,6 +862,24 @@ def watched_subprocess(self, mocker):
                 "",
                 id="patch_task_instance_to_skipped",
             ),
+            pytest.param(
+                TaskState(state=TerminalTIState.FAILED, end_date=timezone.parse("2024-10-31T12:00:00Z")),
+                b"",
+                "",
+                (),
+                "",
+                id="patch_task_instance_to_failed",
+            ),
+            pytest.param(
+                TaskState(
+                    state=TerminalTIState.UP_FOR_RETRY, end_date=timezone.parse("2024-10-31T12:00:00Z")
+                ),
+                b"",
+                "",
+                (),
+                "",
+                id="patch_task_instance_to_retry",
+            ),
         ],
     )
     def test_handle_requests(
diff --git a/task_sdk/tests/execution_time/test_task_runner.py b/task_sdk/tests/execution_time/test_task_runner.py
index 2b812c92a7338..f4a3f2f0db6d7 100644
--- a/task_sdk/tests/execution_time/test_task_runner.py
+++ b/task_sdk/tests/execution_time/test_task_runner.py
@@ -26,7 +26,14 @@
 import pytest
 from uuid6 import uuid7
 
-from airflow.exceptions import AirflowSkipException
+from airflow.exceptions import (
+    AirflowException,
+    AirflowFailException,
+    AirflowSensorTimeout,
+    AirflowSkipException,
+    AirflowTaskTerminated,
+    AirflowTaskTimeout,
+)
 from airflow.sdk import DAG, BaseOperator
 from airflow.sdk.api.datamodels._generated import TaskInstance, TerminalTIState
 from airflow.sdk.execution_time.comms import DeferTask, SetRenderedFields, StartupDetails, TaskState
@@ -333,6 +340,103 @@ def __init__(self, *args, **kwargs):
         )
 
+
+@pytest.mark.parametrize(
+    ["dag_id", "task_id", "fail_with_exception"],
+    [
+        pytest.param(
+            "basic_failed", "fail-exception", AirflowFailException("Oops. Failing by AirflowFailException!")
+        ),
+        pytest.param(
+            "basic_failed2",
+            "sensor-timeout-exception",
+            AirflowSensorTimeout("Oops. Failing by AirflowSensorTimeout!"),
+        ),
+        pytest.param(
+            "basic_failed3",
+            "task-terminated-exception",
+            AirflowTaskTerminated("Oops. Failing by AirflowTaskTerminated!"),
+        ),
+    ],
+)
+def test_run_basic_failed(time_machine, mocked_parse, dag_id, task_id, fail_with_exception, make_ti_context):
+    """Test running a basic task that marks itself as failed by raising an exception."""
+    from airflow.providers.standard.operators.python import PythonOperator
+
+    task = PythonOperator(
+        task_id=task_id,
+        python_callable=lambda: (_ for _ in ()).throw(
+            fail_with_exception,
+        ),
+    )
+
+    what = StartupDetails(
+        ti=TaskInstance(id=uuid7(), task_id=task_id, dag_id=dag_id, run_id="c", try_number=1),
+        file="",
+        requests_fd=0,
+        ti_context=make_ti_context(),
+    )
+
+    ti = mocked_parse(what, dag_id, task)
+
+    instant = timezone.datetime(2024, 12, 3, 10, 0)
+    time_machine.move_to(instant, tick=False)
+
+    with mock.patch(
+        "airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True
+    ) as mock_supervisor_comms:
+        run(ti, log=mock.MagicMock())
+
+    mock_supervisor_comms.send_request.assert_called_once_with(
+        msg=TaskState(state=TerminalTIState.FAILED, end_date=instant), log=mock.ANY
+    )
+
+
+@pytest.mark.parametrize(
+    ["dag_id", "task_id", "retry_with_exception"],
+    [
+        pytest.param(
+            "basic_retry",
+            "task-timeout-exception",
+            AirflowTaskTimeout("Oops. Failing by AirflowTaskTimeout!"),
+        ),
+        pytest.param(
+            "basic_retry2", "airflow-exception", AirflowException("Oops. Failing by AirflowException!")
+        ),
+    ],
+)
+def test_run_basic_retry(time_machine, mocked_parse, dag_id, task_id, retry_with_exception, make_ti_context):
+    """Test running a basic task that sets itself to up_for_retry with various exceptions."""
+    from airflow.providers.standard.operators.python import PythonOperator
+
+    task = PythonOperator(
+        task_id=task_id,
+        python_callable=lambda: (_ for _ in ()).throw(
+            retry_with_exception,
+        ),
+    )
+
+    what = StartupDetails(
+        ti=TaskInstance(id=uuid7(), task_id=task_id, dag_id=dag_id, run_id="c", try_number=1),
+        file="",
+        requests_fd=0,
+        ti_context=make_ti_context(),
+    )
+
+    ti = mocked_parse(what, dag_id, task)
+
+    instant = timezone.datetime(2024, 12, 3, 10, 0)
+    time_machine.move_to(instant, tick=False)
+
+    with mock.patch(
+        "airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True
+    ) as mock_supervisor_comms:
+        run(ti, log=mock.MagicMock())
+
+    mock_supervisor_comms.send_request.assert_called_once_with(
+        msg=TaskState(state=TerminalTIState.UP_FOR_RETRY, end_date=instant), log=mock.ANY
+    )
+
 
 class TestRuntimeTaskInstance:
     def test_get_context_without_ti_context_from_server(self, mocked_parse, make_ti_context):
         """Test get_template_context without ti_context_from_server."""
diff --git a/tests/api_fastapi/execution_api/routes/test_task_instances.py b/tests/api_fastapi/execution_api/routes/test_task_instances.py
index e67d82a718cd6..e8b2dffa8b483 100644
--- a/tests/api_fastapi/execution_api/routes/test_task_instances.py
+++ b/tests/api_fastapi/execution_api/routes/test_task_instances.py
@@ -150,6 +150,7 @@ def teardown_method(self):
             (State.SUCCESS, DEFAULT_END_DATE, State.SUCCESS),
             (State.FAILED, DEFAULT_END_DATE, State.FAILED),
             (State.SKIPPED, DEFAULT_END_DATE, State.SKIPPED),
+            (State.UP_FOR_RETRY, DEFAULT_END_DATE, State.UP_FOR_RETRY),
         ],
     )
     def test_ti_update_state_to_terminal(
@@ -455,6 +456,36 @@ def test_ti_heartbeat_update(self, client, session, create_task_instance, time_machine):
         session.refresh(ti)
         assert ti.last_heartbeat_at == time_now.add(minutes=10)
 
+    def test_ti_update_state_to_failed_table_check(self, client, session, create_task_instance):
+        from math import ceil
+
+        ti = create_task_instance(
+            task_id="test_ti_update_state_to_terminal",
+            start_date=DEFAULT_START_DATE,
+            state=State.RUNNING,
+        )
+        ti.start_date = DEFAULT_START_DATE
+        session.commit()
+
+        response = client.patch(
+            f"/execution/task-instances/{ti.id}/state",
+            json={
+                "state": State.FAILED,
+                "end_date": DEFAULT_END_DATE.isoformat(),
+            },
+        )
+
+        assert response.status_code == 204
+        assert response.text == ""
+
+        session.expire_all()
+
+        ti = session.get(TaskInstance, ti.id)
+        assert ti.state == State.FAILED
+        assert ti.next_method is None
+        assert ti.next_kwargs is None
+        assert ceil(ti.duration) == 3600.00
+
 
 class TestTIPutRTIF:
     def setup_method(self):