Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AIP-72: Handling up_for_retry task instance states for AirflowTaskTimeout and AirflowException #44981

Closed
wants to merge 10 commits into from
2 changes: 2 additions & 0 deletions airflow/api_fastapi/execution_api/routes/task_instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,8 @@ def ti_update_state(

if isinstance(ti_patch_payload, TITerminalStatePayload):
query = TI.duration_expression_update(ti_patch_payload.end_date, query, session.bind)
# clear the next_method and next_kwargs for all terminal states, as we do not want retries to pick them
query = query.values(state=ti_patch_payload.state, next_method=None, next_kwargs=None)
elif isinstance(ti_patch_payload, TIDeferredStatePayload):
# Calculate timeout if it was passed
timeout = None
Expand Down
4 changes: 2 additions & 2 deletions airflow/utils/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ class TerminalTIState(str, Enum):
FAILED = "failed"
SKIPPED = "skipped" # A user can raise a AirflowSkipException from a task & it will be marked as skipped
REMOVED = "removed"
UP_FOR_RETRY = "up_for_retry" # We do not need to do anything actionable for this state, hence it is a terminal state.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"terminal state" means the task is done; up for retry does not really feel like a terminal state.... since it's going to be retried....

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this highlights an ambiguity / conflict. the TI-try is done, but the TI is not


def __str__(self) -> str:
return self.value
Expand All @@ -50,7 +51,6 @@ class IntermediateTIState(str, Enum):
SCHEDULED = "scheduled"
QUEUED = "queued"
RESTARTING = "restarting"
UP_FOR_RETRY = "up_for_retry"
UP_FOR_RESCHEDULE = "up_for_reschedule"
UPSTREAM_FAILED = "upstream_failed"
DEFERRED = "deferred"
Expand Down Expand Up @@ -80,7 +80,7 @@ class TaskInstanceState(str, Enum):
SUCCESS = TerminalTIState.SUCCESS # Task completed
RESTARTING = IntermediateTIState.RESTARTING # External request to restart (e.g. cleared when running)
FAILED = TerminalTIState.FAILED # Task errored out
UP_FOR_RETRY = IntermediateTIState.UP_FOR_RETRY # Task failed but has retries left
UP_FOR_RETRY = TerminalTIState.UP_FOR_RETRY # Task failed but has retries left
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be IntermediateTIState, i.e. the task will be "retried" as opposed to success, failed, skipped etc -- where it TI is completed

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think in DM we discussed UPSTREAM_FAILED should be TerminalTIState not UP_FOR_RETRY

UP_FOR_RESCHEDULE = IntermediateTIState.UP_FOR_RESCHEDULE # A waiting `reschedule` sensor
UPSTREAM_FAILED = IntermediateTIState.UPSTREAM_FAILED # One or more upstream deps failed
SKIPPED = TerminalTIState.SKIPPED # Skipped by branching or some other mechanism
Expand Down
2 changes: 1 addition & 1 deletion task_sdk/src/airflow/sdk/api/datamodels/_generated.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ class IntermediateTIState(str, Enum):
SCHEDULED = "scheduled"
QUEUED = "queued"
RESTARTING = "restarting"
UP_FOR_RETRY = "up_for_retry"
UP_FOR_RESCHEDULE = "up_for_reschedule"
UPSTREAM_FAILED = "upstream_failed"
DEFERRED = "deferred"
Expand Down Expand Up @@ -119,6 +118,7 @@ class TerminalTIState(str, Enum):
FAILED = "failed"
SKIPPED = "skipped"
REMOVED = "removed"
UP_FOR_RETRY = "up_for_retry"


class ValidationError(BaseModel):
Expand Down
23 changes: 21 additions & 2 deletions task_sdk/src/airflow/sdk/execution_time/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,9 +283,28 @@ def run(ti: RuntimeTaskInstance, log: Logger):
...
except (AirflowFailException, AirflowSensorTimeout):
# If AirflowFailException is raised, task should not retry.
# If a sensor in reschedule mode reaches timeout, task should not retry.

# TODO: Handle fail_stop here: https://github.com/apache/airflow/issues/44951
# TODO: Handle addition to Log table: https://github.com/apache/airflow/issues/44952
msg = TaskState(
state=TerminalTIState.FAILED,
end_date=datetime.now(tz=timezone.utc),
)

# TODO: Run task failure callbacks here
except (AirflowTaskTimeout, AirflowException):
# TODO: handle the case of up_for_retry here
...
except (AirflowTaskTimeout, AirflowException, AirflowTaskTerminated):
...
except AirflowTaskTerminated:
# External state updates are already handled with `ti_heartbeat` and will be
# updated already be another UI API. So, these exceptions should ideally never be thrown.
# If these are thrown, we should mark the TI state as failed.
msg = TaskState(
state=TerminalTIState.FAILED,
end_date=datetime.now(tz=timezone.utc),
)
# TODO: Run task failure callbacks here
except SystemExit:
...
except BaseException:
Expand Down
8 changes: 8 additions & 0 deletions task_sdk/tests/execution_time/test_supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -862,6 +862,14 @@ def watched_subprocess(self, mocker):
"",
id="patch_task_instance_to_skipped",
),
pytest.param(
TaskState(state=TerminalTIState.FAILED, end_date=timezone.parse("2024-10-31T12:00:00Z")),
b"",
"",
(),
"",
id="patch_task_instance_to_failed",
),
],
)
def test_handle_requests(
Expand Down
58 changes: 57 additions & 1 deletion task_sdk/tests/execution_time/test_task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,12 @@
import pytest
from uuid6 import uuid7

from airflow.exceptions import AirflowSkipException
from airflow.exceptions import (
AirflowFailException,
AirflowSensorTimeout,
AirflowSkipException,
AirflowTaskTerminated,
)
from airflow.sdk import DAG, BaseOperator
from airflow.sdk.api.datamodels._generated import TaskInstance, TerminalTIState
from airflow.sdk.execution_time.comms import DeferTask, SetRenderedFields, StartupDetails, TaskState
Expand Down Expand Up @@ -333,6 +338,57 @@ def __init__(self, *args, **kwargs):
)


@pytest.mark.parametrize(
["dag_id", "task_id", "fail_with_exception"],
[
pytest.param(
"basic_failed", "fail-exception", AirflowFailException("Oops. Failing by AirflowFailException!")
),
pytest.param(
"basic_failed2",
"sensor-timeout-exception",
AirflowSensorTimeout("Oops. Failing by AirflowSensorTimeout!"),
),
pytest.param(
"basic_failed3",
"task-terminated-exception",
AirflowTaskTerminated("Oops. Failing by AirflowTaskTerminated!"),
),
],
)
def test_run_basic_failed(time_machine, mocked_parse, dag_id, task_id, fail_with_exception, make_ti_context):
"""Test running a basic task that marks itself as failed by raising exception."""
from airflow.providers.standard.operators.python import PythonOperator

task = PythonOperator(
task_id=task_id,
python_callable=lambda: (_ for _ in ()).throw(
fail_with_exception,
),
)

what = StartupDetails(
ti=TaskInstance(id=uuid7(), task_id=task_id, dag_id=dag_id, run_id="c", try_number=1),
file="",
requests_fd=0,
ti_context=make_ti_context(),
)

ti = mocked_parse(what, dag_id, task)

instant = timezone.datetime(2024, 12, 3, 10, 0)
time_machine.move_to(instant, tick=False)

with mock.patch(
"airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True
) as mock_supervisor_comms:
run(ti, log=mock.MagicMock())

mock_supervisor_comms.send_request.assert_called_once_with(
msg=TaskState(state=TerminalTIState.FAILED, end_date=instant), log=mock.ANY
)


class TestRuntimeTaskInstance:
def test_get_context_without_ti_context_from_server(self, mocked_parse, make_ti_context):
"""Test get_template_context without ti_context_from_server."""
Expand Down
31 changes: 31 additions & 0 deletions tests/api_fastapi/execution_api/routes/test_task_instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ def teardown_method(self):
(State.SUCCESS, DEFAULT_END_DATE, State.SUCCESS),
(State.FAILED, DEFAULT_END_DATE, State.FAILED),
(State.SKIPPED, DEFAULT_END_DATE, State.SKIPPED),
(State.UP_FOR_RETRY, DEFAULT_END_DATE, State.UP_FOR_RETRY),
],
)
def test_ti_update_state_to_terminal(
Expand Down Expand Up @@ -455,6 +456,36 @@ def test_ti_heartbeat_update(self, client, session, create_task_instance, time_m
session.refresh(ti)
assert ti.last_heartbeat_at == time_now.add(minutes=10)

def test_ti_update_state_to_failed_table_check(self, client, session, create_task_instance):
from math import ceil

ti = create_task_instance(
task_id="test_ti_update_state_to_terminal",
start_date=DEFAULT_START_DATE,
state=State.RUNNING,
)
ti.start_date = DEFAULT_START_DATE
session.commit()

response = client.patch(
f"/execution/task-instances/{ti.id}/state",
json={
"state": State.FAILED,
"end_date": DEFAULT_END_DATE.isoformat(),
},
)

assert response.status_code == 204
assert response.text == ""

session.expire_all()

ti = session.get(TaskInstance, ti.id)
assert ti.state == State.FAILED
assert ti.next_method is None
assert ti.next_kwargs is None
assert ceil(ti.duration) == 3600.00


class TestTIPutRTIF:
def setup_method(self):
Expand Down