Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions airflow-core/newsfragments/66394.significant.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
Add a required ``msg`` keyword argument to the four task instance listener hooks (``on_task_instance_running``, ``on_task_instance_success``, ``on_task_instance_failed``, ``on_task_instance_skipped``).

The ``msg`` argument carries a short, canonical description of the state change:
``"started"``, ``"success"``, ``"skipped"``, ``"failed"``, or ``"up_for_retry"``
when the hook is emitted by the worker, and ``"manually_set_to_success"``,
``"manually_set_to_failed"``, or ``"manually_set_to_skipped"`` when the state was
changed via the API. This mirrors the DagRun listener pattern (#56272) and lets
listener implementations route or filter events without re-deriving intent from
other fields. Existing ``hookimpl`` implementations that don't declare ``msg``
continue to work unchanged — pluggy passes only the parameters each
implementation declares.
Original file line number Diff line number Diff line change
Expand Up @@ -78,15 +78,20 @@ def _emit_state_listener_hooks(updated_tis: list[TI], new_state: str | TaskInsta
for ti in updated_tis:
try:
if new_state == TaskInstanceState.SUCCESS:
get_listener_manager().hook.on_task_instance_success(previous_state=None, task_instance=ti)
get_listener_manager().hook.on_task_instance_success(
previous_state=None, task_instance=ti, msg="manually_set_to_success"
)
elif new_state == TaskInstanceState.FAILED:
get_listener_manager().hook.on_task_instance_failed(
previous_state=None,
task_instance=ti,
error=f"TaskInstance's state was manually set to `{TaskInstanceState.FAILED}`.",
msg="manually_set_to_failed",
)
elif new_state == TaskInstanceState.SKIPPED:
get_listener_manager().hook.on_task_instance_skipped(previous_state=None, task_instance=ti)
get_listener_manager().hook.on_task_instance_skipped(
previous_state=None, task_instance=ti, msg="manually_set_to_skipped"
)
except Exception:
log.exception("error calling listener")

Expand Down
35 changes: 27 additions & 8 deletions airflow-core/src/airflow/example_dags/plugins/event_listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,16 +31,20 @@
# [START howto_listen_ti_running_task]
@hookimpl
def on_task_instance_running(
previous_state: TaskInstanceState | None, task_instance: RuntimeTaskInstance | TaskInstance
previous_state: TaskInstanceState | None,
task_instance: RuntimeTaskInstance | TaskInstance,
msg: str,
):
"""
Called when task state changes to RUNNING.

The previous_state and task_instance objects can be used to retrieve more information about the current
task_instance that is running, its dag_run, task and dag information.
task_instance that is running, its dag_run, task and dag information. ``msg`` carries a short
canonical context (e.g. ``"started"``) so listeners can route or filter events without
re-deriving intent from other fields.
"""
print("Task instance is in running state")
print(" Previous state of the Task instance:", previous_state)
print(f" Previous state: {previous_state}, msg: {msg}")

name: str = task_instance.task_id

Expand All @@ -65,7 +69,9 @@ def on_task_instance_running(
# [START howto_listen_ti_success_task]
@hookimpl
def on_task_instance_success(
previous_state: TaskInstanceState | None, task_instance: RuntimeTaskInstance | TaskInstance
previous_state: TaskInstanceState | None,
task_instance: RuntimeTaskInstance | TaskInstance,
msg: str,
):
"""
Called when task state changes to SUCCESS.
Expand All @@ -75,9 +81,12 @@ def on_task_instance_success(

A RuntimeTaskInstance is provided in most cases, except when the task's state change is triggered
through the API. In that case, the TaskInstance available on the API server will be provided instead.

``msg`` carries a short canonical context (e.g. ``"success"`` from the worker,
``"manually_set_to_success"`` when the state was changed via the API).
"""
print("Task instance in success state")
print(" Previous state of the Task instance:", previous_state)
print(f" Previous state: {previous_state}, msg: {msg}")

if isinstance(task_instance, TaskInstance):
print("Task instance's state was changed through the API.")
Expand All @@ -100,6 +109,7 @@ def on_task_instance_failed(
previous_state: TaskInstanceState | None,
task_instance: RuntimeTaskInstance | TaskInstance,
error: None | str | BaseException,
msg: str,
):
"""
Called when task state changes to FAILED.
Expand All @@ -109,8 +119,12 @@ def on_task_instance_failed(

A RuntimeTaskInstance is provided in most cases, except when the task's state change is triggered
through the API. In that case, the TaskInstance available on the API server will be provided instead.

``msg`` distinguishes failure paths without inspecting ``error``:
``"failed"`` (terminal), ``"up_for_retry"`` (will retry), or
``"manually_set_to_failed"`` (API-driven state change).
"""
print("Task instance in failure state")
print(f"Task instance in failure state (msg={msg})")

if isinstance(task_instance, TaskInstance):
print("Task instance's state was changed through the API.")
Expand Down Expand Up @@ -138,7 +152,9 @@ def on_task_instance_failed(
# [START howto_listen_ti_skipped_task]
@hookimpl
def on_task_instance_skipped(
previous_state: TaskInstanceState | None, task_instance: RuntimeTaskInstance | TaskInstance
previous_state: TaskInstanceState | None,
task_instance: RuntimeTaskInstance | TaskInstance,
msg: str,
):
"""
Called when a task instance skips itself during execution.
Expand All @@ -153,8 +169,11 @@ def on_task_instance_skipped(

For comprehensive tracking of skipped tasks, use DAG-level listeners
(on_dag_run_success/on_dag_run_failed) which may have access to all task states.

``msg`` carries the canonical context (``"skipped"`` or
``"manually_set_to_skipped"``).
"""
print("Task instance was skipped")
print(f"Task instance was skipped (msg={msg})")

if isinstance(task_instance, TaskInstance):
print("Task instance's state was changed through the API.")
Expand Down
5 changes: 4 additions & 1 deletion airflow-core/src/airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -1773,7 +1773,10 @@ def fetch_handle_failure_context(

try:
get_listener_manager().hook.on_task_instance_failed(
previous_state=TaskInstanceState.RUNNING, task_instance=ti, error=error
previous_state=TaskInstanceState.RUNNING,
task_instance=ti,
error=error,
msg="up_for_retry",
)
except Exception:
log.exception("error calling listener")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,31 +35,62 @@
def on_task_instance_running(
previous_state: TaskInstanceState | None,
task_instance: RuntimeTaskInstance | TaskInstance,
msg: str,
):
"""Execute when task state changes to RUNNING. previous_state can be None."""
"""
Execute when task state changes to RUNNING. previous_state can be None.

:param previous_state: Previous state of the task instance (can be None)
:param task_instance: The task instance object
:param msg: Short canonical context for the state change. Always ``"started"``
for this hook. Mirrors the DagRun listener pattern so listeners can route
or filter events without re-deriving intent from other fields.
"""


@hookspec
def on_task_instance_success(
previous_state: TaskInstanceState | None,
task_instance: RuntimeTaskInstance | TaskInstance,
msg: str,
):
"""Execute when task state changes to SUCCESS. previous_state can be None."""
"""
Execute when task state changes to SUCCESS. previous_state can be None.

:param previous_state: Previous state of the task instance (can be None)
:param task_instance: The task instance object (RuntimeTaskInstance when called
from task execution context, TaskInstance when called from API server)
:param msg: Short canonical context for the state change. ``"success"`` when
emitted from the worker; ``"manually_set_to_success"`` when the state was
changed via the API.
"""


@hookspec
def on_task_instance_failed(
previous_state: TaskInstanceState | None,
task_instance: RuntimeTaskInstance | TaskInstance,
error: None | str | BaseException,
msg: str,
):
"""Execute when task state changes to FAILED. previous_state can be None."""
"""
Execute when task state changes to FAILED. previous_state can be None.

:param previous_state: Previous state of the task instance (can be None)
:param task_instance: The task instance object (RuntimeTaskInstance when called
from task execution context, TaskInstance when called from API server)
:param error: The exception or error message that caused the failure
:param msg: Short canonical context distinguishing failure paths without
inspecting ``error``. ``"failed"`` (terminal), ``"up_for_retry"`` (will
retry), or ``"manually_set_to_failed"`` (API-driven state change).
"""


@hookspec
def on_task_instance_skipped(
previous_state: TaskInstanceState | None,
task_instance: RuntimeTaskInstance | TaskInstance,
msg: str,
):
"""
Execute when a task instance skips itself during execution.
Expand All @@ -78,4 +109,7 @@ def on_task_instance_skipped(
:param previous_state: Previous state of the task instance (can be None)
:param task_instance: The task instance object (RuntimeTaskInstance when called
from task execution context, TaskInstance when called from API server)
:param msg: Short canonical context for the state change. ``"skipped"`` when
emitted from the worker; ``"manually_set_to_skipped"`` when the state was
changed via the API.
"""
16 changes: 11 additions & 5 deletions task-sdk/src/airflow/sdk/execution_time/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1175,7 +1175,7 @@ def _prepare(ti: RuntimeTaskInstance, log: Logger, context: Context) -> ToSuperv
try:
# TODO: Call pre execute etc.
get_listener_manager().hook.on_task_instance_running(
previous_state=TaskInstanceState.QUEUED, task_instance=ti
previous_state=TaskInstanceState.QUEUED, task_instance=ti, msg="started"
)
except Exception:
log.exception("error calling listener")
Expand Down Expand Up @@ -1900,23 +1900,26 @@ def finalize(
_run_task_state_change_callbacks(task, "on_success_callback", context, log)
try:
get_listener_manager().hook.on_task_instance_success(
previous_state=TaskInstanceState.RUNNING, task_instance=ti
previous_state=TaskInstanceState.RUNNING, task_instance=ti, msg="success"
)
except Exception:
log.exception("error calling listener")
elif state == TaskInstanceState.SKIPPED:
_run_task_state_change_callbacks(task, "on_skipped_callback", context, log)
try:
get_listener_manager().hook.on_task_instance_skipped(
previous_state=TaskInstanceState.RUNNING, task_instance=ti
previous_state=TaskInstanceState.RUNNING, task_instance=ti, msg="skipped"
)
except Exception:
log.exception("error calling listener")
elif state == TaskInstanceState.UP_FOR_RETRY:
_run_task_state_change_callbacks(task, "on_retry_callback", context, log)
try:
get_listener_manager().hook.on_task_instance_failed(
previous_state=TaskInstanceState.RUNNING, task_instance=ti, error=error
previous_state=TaskInstanceState.RUNNING,
task_instance=ti,
error=error,
msg="up_for_retry",
)
except Exception:
log.exception("error calling listener")
Expand All @@ -1926,7 +1929,10 @@ def finalize(
_run_task_state_change_callbacks(task, "on_failure_callback", context, log)
try:
get_listener_manager().hook.on_task_instance_failed(
previous_state=TaskInstanceState.RUNNING, task_instance=ti, error=error
previous_state=TaskInstanceState.RUNNING,
task_instance=ti,
error=error,
msg="failed",
)
except Exception:
log.exception("error calling listener")
Expand Down
62 changes: 57 additions & 5 deletions task-sdk/tests/task_sdk/execution_time/test_task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -3814,30 +3814,35 @@ def return_num(num):
class TestTaskRunnerCallsListeners:
class CustomListener:
def __init__(self):
self.state = []
self.state: list[TaskInstanceState] = []
self.component = None
self.error = None
self.msgs: list[str] = []

@hookimpl
def on_starting(self, component):
self.component = component

@hookimpl
def on_task_instance_running(self, previous_state, task_instance):
def on_task_instance_running(self, previous_state, task_instance, msg: str):
self.state.append(TaskInstanceState.RUNNING)
self.msgs.append(msg)

@hookimpl
def on_task_instance_success(self, previous_state, task_instance):
def on_task_instance_success(self, previous_state, task_instance, msg: str):
self.state.append(TaskInstanceState.SUCCESS)
self.msgs.append(msg)

@hookimpl
def on_task_instance_failed(self, previous_state, task_instance, error):
def on_task_instance_failed(self, previous_state, task_instance, error, msg: str):
self.state.append(TaskInstanceState.FAILED)
self.error = error
self.msgs.append(msg)

@hookimpl
def on_task_instance_skipped(self, previous_state, task_instance):
def on_task_instance_skipped(self, previous_state, task_instance, msg: str):
self.state.append(TaskInstanceState.SKIPPED)
self.msgs.append(msg)

@hookimpl
def before_stopping(self, component):
Expand Down Expand Up @@ -3940,6 +3945,7 @@ def execute(self, context):
finalize(runtime_ti, state, context, log)

assert listener.state == [TaskInstanceState.RUNNING, TaskInstanceState.SUCCESS]
assert listener.msgs == ["started", "success"]

@pytest.mark.parametrize(
"exception",
Expand Down Expand Up @@ -3982,6 +3988,7 @@ def execute(self, context):

assert listener.state == [TaskInstanceState.RUNNING, TaskInstanceState.FAILED]
assert listener.error == error
assert listener.msgs == ["started", "failed"]

def test_task_runner_calls_listeners_skipped(self, mocked_parse, mock_supervisor_comms, listener_manager):
listener = self.CustomListener()
Expand Down Expand Up @@ -4013,6 +4020,51 @@ def execute(self, context):
finalize(runtime_ti, state, context, log)

assert listener.state == [TaskInstanceState.RUNNING, TaskInstanceState.SKIPPED]
assert listener.msgs == ["started", "skipped"]

@pytest.mark.parametrize(
("should_retry", "expected_state", "expected_failure_msg"),
[
pytest.param(
True,
TaskInstanceState.UP_FOR_RETRY,
"up_for_retry",
id="up_for_retry-when-retries-remain",
),
pytest.param(
False,
TaskInstanceState.FAILED,
"failed",
id="failed-when-no-retries-remain",
),
],
)
def test_task_runner_listener_msg_distinguishes_retry_vs_terminal(
self,
create_runtime_ti,
listener_manager,
should_retry,
expected_state,
expected_failure_msg,
):
"""The ``msg`` arg lets a listener distinguish ``up_for_retry`` from
terminal ``failed`` without inspecting ``error`` or task config."""
listener = self.CustomListener()
listener_manager(listener)

class CustomOperator(BaseOperator):
def execute(self, context):
raise ValueError("boom")

task = CustomOperator(task_id="task_listener_msg_paths")
runtime_ti = create_runtime_ti(dag_id="dag", task=task, should_retry=should_retry)
log = mock.MagicMock()
context = runtime_ti.get_template_context()
state, _, error = run(runtime_ti, context, log)
finalize(runtime_ti, state, context, log, error)

assert state == expected_state
assert listener.msgs == ["started", expected_failure_msg]

def test_listener_access_outlet_event_on_running_and_success(
self, mocked_parse, mock_supervisor_comms, listener_manager
Expand Down
Loading