diff --git a/airflow-core/newsfragments/66394.significant.rst b/airflow-core/newsfragments/66394.significant.rst new file mode 100644 index 0000000000000..d3ded83dc2a2f --- /dev/null +++ b/airflow-core/newsfragments/66394.significant.rst @@ -0,0 +1,10 @@ +Add a required ``msg`` keyword argument to the four task instance listener hooks (``on_task_instance_running``, ``on_task_instance_success``, ``on_task_instance_failed``, ``on_task_instance_skipped``). + +The ``msg`` arg carries short canonical context for the state change +(``"started"``, ``"success"``, ``"skipped"``, ``"failed"``, ``"up_for_retry"`` +from the worker; ``"manually_set_to_*"`` when the state was changed via the +API). This mirrors the DagRun listener pattern (#56272) and lets listener +implementations route or filter events without re-deriving intent from other +fields. Existing ``hookimpl`` implementations that don't declare ``msg`` +continue to work unchanged — pluggy passes only the parameters each +implementation declares. diff --git a/airflow-core/src/airflow/api_fastapi/core_api/services/public/task_instances.py b/airflow-core/src/airflow/api_fastapi/core_api/services/public/task_instances.py index 65e8260c2f47a..33b37f89b71ed 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/services/public/task_instances.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/services/public/task_instances.py @@ -78,15 +78,20 @@ def _emit_state_listener_hooks(updated_tis: list[TI], new_state: str | TaskInsta for ti in updated_tis: try: if new_state == TaskInstanceState.SUCCESS: - get_listener_manager().hook.on_task_instance_success(previous_state=None, task_instance=ti) + get_listener_manager().hook.on_task_instance_success( + previous_state=None, task_instance=ti, msg="manually_set_to_success" + ) elif new_state == TaskInstanceState.FAILED: get_listener_manager().hook.on_task_instance_failed( previous_state=None, task_instance=ti, error=f"TaskInstance's state was manually set to `{TaskInstanceState.FAILED}`.", + msg="manually_set_to_failed", ) elif new_state == TaskInstanceState.SKIPPED: - get_listener_manager().hook.on_task_instance_skipped(previous_state=None, task_instance=ti) + get_listener_manager().hook.on_task_instance_skipped( + previous_state=None, task_instance=ti, msg="manually_set_to_skipped" + ) except Exception: log.exception("error calling listener") diff --git a/airflow-core/src/airflow/example_dags/plugins/event_listener.py b/airflow-core/src/airflow/example_dags/plugins/event_listener.py index 91af9f5ccc6df..0ec64fd7eb7b3 100644 --- a/airflow-core/src/airflow/example_dags/plugins/event_listener.py +++ b/airflow-core/src/airflow/example_dags/plugins/event_listener.py @@ -31,16 +31,20 @@ # [START howto_listen_ti_running_task] @hookimpl def on_task_instance_running( - previous_state: TaskInstanceState | None, task_instance: RuntimeTaskInstance | TaskInstance + previous_state: TaskInstanceState | None, + task_instance: RuntimeTaskInstance | TaskInstance, + msg: str, ): """ Called when task state changes to RUNNING. previous_task_state and task_instance object can be used to retrieve more information about current - task_instance that is running, its dag_run, task and dag information. + task_instance that is running, its dag_run, task and dag information. ``msg`` carries a short + canonical context (e.g. ``"started"``) so listeners can route or filter events without + re-deriving intent from other fields. """ print("Task instance is in running state") - print(" Previous state of the Task instance:", previous_state) + print(f" Previous state: {previous_state}, msg: {msg}") name: str = task_instance.task_id @@ -65,7 +69,9 @@ def on_task_instance_running( # [START howto_listen_ti_success_task] @hookimpl def on_task_instance_success( - previous_state: TaskInstanceState | None, task_instance: RuntimeTaskInstance | TaskInstance + previous_state: TaskInstanceState | None, + task_instance: RuntimeTaskInstance | TaskInstance, + msg: str, ): """ Called when task state changes to SUCCESS. @@ -75,9 +81,12 @@ def on_task_instance_success( A RuntimeTaskInstance is provided in most cases, except when the task's state change is triggered through the API. In that case, the TaskInstance available on the API server will be provided instead. + + ``msg`` carries a short canonical context (e.g. ``"success"`` from the worker, + ``"manually_set_to_success"`` when the state was changed via the API). """ print("Task instance in success state") - print(" Previous state of the Task instance:", previous_state) + print(f" Previous state: {previous_state}, msg: {msg}") if isinstance(task_instance, TaskInstance): print("Task instance's state was changed through the API.") @@ -100,6 +109,7 @@ def on_task_instance_failed( previous_state: TaskInstanceState | None, task_instance: RuntimeTaskInstance | TaskInstance, error: None | str | BaseException, + msg: str, ): """ Called when task state changes to FAILED. @@ -109,8 +119,12 @@ def on_task_instance_failed( A RuntimeTaskInstance is provided in most cases, except when the task's state change is triggered through the API. In that case, the TaskInstance available on the API server will be provided instead. + + ``msg`` distinguishes failure paths without inspecting ``error``: + ``"failed"`` (terminal), ``"up_for_retry"`` (will retry), or + ``"manually_set_to_failed"`` (API-driven state change). """ - print("Task instance in failure state") + print(f"Task instance in failure state (msg={msg})") if isinstance(task_instance, TaskInstance): print("Task instance's state was changed through the API.") @@ -138,7 +152,9 @@ def on_task_instance_failed( # [START howto_listen_ti_skipped_task] @hookimpl def on_task_instance_skipped( - previous_state: TaskInstanceState | None, task_instance: RuntimeTaskInstance | TaskInstance + previous_state: TaskInstanceState | None, + task_instance: RuntimeTaskInstance | TaskInstance, + msg: str, ): """ Called when a task instance skips itself during execution. @@ -153,8 +169,11 @@ def on_task_instance_skipped( For comprehensive tracking of skipped tasks, use DAG-level listeners (on_dag_run_success/on_dag_run_failed) which may have access to all task states. + + ``msg`` carries the canonical context (``"skipped"`` or + ``"manually_set_to_skipped"``). """ - print("Task instance was skipped") + print(f"Task instance was skipped (msg={msg})") if isinstance(task_instance, TaskInstance): print("Task instance's state was changed through the API.") diff --git a/airflow-core/src/airflow/models/taskinstance.py b/airflow-core/src/airflow/models/taskinstance.py index e5b19f2768da9..814d8e5977216 100644 --- a/airflow-core/src/airflow/models/taskinstance.py +++ b/airflow-core/src/airflow/models/taskinstance.py @@ -1773,7 +1773,10 @@ def fetch_handle_failure_context( try: get_listener_manager().hook.on_task_instance_failed( - previous_state=TaskInstanceState.RUNNING, task_instance=ti, error=error + previous_state=TaskInstanceState.RUNNING, + task_instance=ti, + error=error, + msg="up_for_retry", ) except Exception: log.exception("error calling listener") diff --git a/shared/listeners/src/airflow_shared/listeners/spec/taskinstance.py b/shared/listeners/src/airflow_shared/listeners/spec/taskinstance.py index d3450d6b05aa7..7a8e5473aaf4c 100644 --- a/shared/listeners/src/airflow_shared/listeners/spec/taskinstance.py +++ b/shared/listeners/src/airflow_shared/listeners/spec/taskinstance.py @@ -35,16 +35,35 @@ def on_task_instance_running( previous_state: TaskInstanceState | None, task_instance: RuntimeTaskInstance | TaskInstance, + msg: str, ): - """Execute when task state changes to RUNNING. previous_state can be None.""" + """ + Execute when task state changes to RUNNING. previous_state can be None. + + :param previous_state: Previous state of the task instance (can be None) + :param task_instance: The task instance object + :param msg: Short canonical context for the state change. Always ``"started"`` + for this hook. Mirrors the DagRun listener pattern so listeners can route + or filter events without re-deriving intent from other fields. + """ @hookspec def on_task_instance_success( previous_state: TaskInstanceState | None, task_instance: RuntimeTaskInstance | TaskInstance, + msg: str, ): - """Execute when task state changes to SUCCESS. previous_state can be None.""" + """ + Execute when task state changes to SUCCESS. previous_state can be None. + + :param previous_state: Previous state of the task instance (can be None) + :param task_instance: The task instance object (RuntimeTaskInstance when called + from task execution context, TaskInstance when called from API server) + :param msg: Short canonical context for the state change. ``"success"`` when + emitted from the worker; ``"manually_set_to_success"`` when the state was + changed via the API. + """ @hookspec @@ -52,14 +71,26 @@ def on_task_instance_failed( previous_state: TaskInstanceState | None, task_instance: RuntimeTaskInstance | TaskInstance, error: None | str | BaseException, + msg: str, ): - """Execute when task state changes to FAIL. previous_state can be None.""" + """ + Execute when task state changes to FAIL. previous_state can be None. + + :param previous_state: Previous state of the task instance (can be None) + :param task_instance: The task instance object (RuntimeTaskInstance when called + from task execution context, TaskInstance when called from API server) + :param error: The exception or error message that caused the failure + :param msg: Short canonical context distinguishing failure paths without + inspecting ``error``. ``"failed"`` (terminal), ``"up_for_retry"`` (will + retry), or ``"manually_set_to_failed"`` (API-driven state change). + """ @hookspec def on_task_instance_skipped( previous_state: TaskInstanceState | None, task_instance: RuntimeTaskInstance | TaskInstance, + msg: str, ): """ Execute when a task instance skips itself during execution. @@ -78,4 +109,7 @@ def on_task_instance_skipped( :param previous_state: Previous state of the task instance (can be None) :param task_instance: The task instance object (RuntimeTaskInstance when called from task execution context, TaskInstance when called from API server) + :param msg: Short canonical context for the state change. ``"skipped"`` when + emitted from the worker; ``"manually_set_to_skipped"`` when the state was + changed via the API. """ diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index 56ba8343c648b..da1c379dc00c4 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -1175,7 +1175,7 @@ def _prepare(ti: RuntimeTaskInstance, log: Logger, context: Context) -> ToSuperv try: # TODO: Call pre execute etc. get_listener_manager().hook.on_task_instance_running( - previous_state=TaskInstanceState.QUEUED, task_instance=ti + previous_state=TaskInstanceState.QUEUED, task_instance=ti, msg="started" ) except Exception: log.exception("error calling listener") @@ -1900,7 +1900,7 @@ def finalize( _run_task_state_change_callbacks(task, "on_success_callback", context, log) try: get_listener_manager().hook.on_task_instance_success( - previous_state=TaskInstanceState.RUNNING, task_instance=ti + previous_state=TaskInstanceState.RUNNING, task_instance=ti, msg="success" ) except Exception: log.exception("error calling listener") @@ -1908,7 +1908,7 @@ def finalize( _run_task_state_change_callbacks(task, "on_skipped_callback", context, log) try: get_listener_manager().hook.on_task_instance_skipped( - previous_state=TaskInstanceState.RUNNING, task_instance=ti + previous_state=TaskInstanceState.RUNNING, task_instance=ti, msg="skipped" ) except Exception: log.exception("error calling listener") @@ -1916,7 +1916,10 @@ def finalize( _run_task_state_change_callbacks(task, "on_retry_callback", context, log) try: get_listener_manager().hook.on_task_instance_failed( - previous_state=TaskInstanceState.RUNNING, task_instance=ti, error=error + previous_state=TaskInstanceState.RUNNING, + task_instance=ti, + error=error, + msg="up_for_retry", ) except Exception: log.exception("error calling listener") @@ -1926,7 +1929,10 @@ def finalize( _run_task_state_change_callbacks(task, "on_failure_callback", context, log) try: get_listener_manager().hook.on_task_instance_failed( - previous_state=TaskInstanceState.RUNNING, task_instance=ti, error=error + previous_state=TaskInstanceState.RUNNING, + task_instance=ti, + error=error, + msg="failed", ) except Exception: log.exception("error calling listener") diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py index 630aff9094ed1..e03cadc35b51e 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py +++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py @@ -3814,30 +3814,35 @@ def return_num(num): class TestTaskRunnerCallsListeners: class CustomListener: def __init__(self): - self.state = [] + self.state: list[TaskInstanceState] = [] self.component = None self.error = None + self.msgs: list[str] = [] @hookimpl def on_starting(self, component): self.component = component @hookimpl - def on_task_instance_running(self, previous_state, task_instance): + def on_task_instance_running(self, previous_state, task_instance, msg: str): self.state.append(TaskInstanceState.RUNNING) + self.msgs.append(msg) @hookimpl - def on_task_instance_success(self, previous_state, task_instance): + def on_task_instance_success(self, previous_state, task_instance, msg: str): self.state.append(TaskInstanceState.SUCCESS) + self.msgs.append(msg) @hookimpl - def on_task_instance_failed(self, previous_state, task_instance, error): + def on_task_instance_failed(self, previous_state, task_instance, error, msg: str): self.state.append(TaskInstanceState.FAILED) self.error = error + self.msgs.append(msg) @hookimpl - def on_task_instance_skipped(self, previous_state, task_instance): + def on_task_instance_skipped(self, previous_state, task_instance, msg: str): self.state.append(TaskInstanceState.SKIPPED) + self.msgs.append(msg) @hookimpl def before_stopping(self, component): @@ -3940,6 +3945,7 @@ def execute(self, context): finalize(runtime_ti, state, context, log) assert listener.state == [TaskInstanceState.RUNNING, TaskInstanceState.SUCCESS] + assert listener.msgs == ["started", "success"] @pytest.mark.parametrize( "exception", @@ -3982,6 +3988,7 @@ def execute(self, context): assert listener.state == [TaskInstanceState.RUNNING, TaskInstanceState.FAILED] assert listener.error == error + assert listener.msgs == ["started", "failed"] def test_task_runner_calls_listeners_skipped(self, mocked_parse, mock_supervisor_comms, listener_manager): listener = self.CustomListener() @@ -4013,6 +4020,51 @@ def execute(self, context): finalize(runtime_ti, state, context, log) assert listener.state == [TaskInstanceState.RUNNING, TaskInstanceState.SKIPPED] + assert listener.msgs == ["started", "skipped"] + + @pytest.mark.parametrize( + ("should_retry", "expected_state", "expected_failure_msg"), + [ + pytest.param( + True, + TaskInstanceState.UP_FOR_RETRY, + "up_for_retry", + id="up_for_retry-when-retries-remain", + ), + pytest.param( + False, + TaskInstanceState.FAILED, + "failed", + id="failed-when-no-retries-remain", + ), + ], + ) + def test_task_runner_listener_msg_distinguishes_retry_vs_terminal( + self, + create_runtime_ti, + listener_manager, + should_retry, + expected_state, + expected_failure_msg, + ): + """The ``msg`` arg lets a listener distinguish ``up_for_retry`` from + terminal ``failed`` without inspecting ``error`` or task config.""" + listener = self.CustomListener() + listener_manager(listener) + + class CustomOperator(BaseOperator): + def execute(self, context): + raise ValueError("boom") + + task = CustomOperator(task_id="task_listener_msg_paths") + runtime_ti = create_runtime_ti(dag_id="dag", task=task, should_retry=should_retry) + log = mock.MagicMock() + context = runtime_ti.get_template_context() + state, _, error = run(runtime_ti, context, log) + finalize(runtime_ti, state, context, log, error) + + assert state == expected_state + assert listener.msgs == ["started", expected_failure_msg] def test_listener_access_outlet_event_on_running_and_success( self, mocked_parse, mock_supervisor_comms, listener_manager