From 81b886b4abc5be39d72c4101e6aa0b25184023e0 Mon Sep 17 00:00:00 2001 From: Maciej Obuchowski Date: Sun, 29 Dec 2024 18:35:48 +0100 Subject: [PATCH] task sdk: call on_task_instance_* listeners Signed-off-by: Maciej Obuchowski --- airflow/api_fastapi/execution_api/app.py | 3 +- .../execution_api/datamodels/dagrun.py | 35 +++ .../execution_api/datamodels/taskinstance.py | 5 + .../execution_api/routes/task_instances.py | 23 +- airflow/dag_processing/processor.py | 2 +- .../example_dags/plugins/event_listener.py | 44 ++-- airflow/executors/workloads.py | 14 ++ airflow/listeners/listener.py | 8 +- airflow/listeners/spec/taskinstance.py | 15 +- airflow/models/taskinstance.py | 8 +- .../providers/openlineage/plugins/listener.py | 224 +++++++++++------- .../openlineage/utils/selective_enable.py | 8 +- .../providers/openlineage/utils/utils.py | 42 +--- .../openlineage/extractors/test_manager.py | 148 +++++++++++- .../openlineage/plugins/test_execution.py | 1 + .../openlineage/plugins/test_listener.py | 34 ++- .../airflow/sdk/api/datamodels/_generated.py | 3 + .../airflow/sdk/execution_time/supervisor.py | 1 + .../airflow/sdk/execution_time/task_runner.py | 20 ++ task_sdk/tests/conftest.py | 11 +- .../tests/execution_time/test_supervisor.py | 10 +- .../tests/execution_time/test_task_runner.py | 20 +- tests/conftest.py | 8 + tests/listeners/class_listener.py | 14 +- tests/listeners/empty_listener.py | 2 +- tests/listeners/file_write_listener.py | 8 +- tests/listeners/full_listener.py | 6 +- tests/listeners/partial_listener.py | 2 +- tests/listeners/slow_listener.py | 2 +- tests/listeners/throwing_listener.py | 2 +- tests/listeners/very_slow_listener.py | 2 +- tests/listeners/xcom_listener.py | 4 +- tests/plugins/test_plugins_manager.py | 29 ++- 33 files changed, 523 insertions(+), 235 deletions(-) create mode 100644 airflow/api_fastapi/execution_api/datamodels/dagrun.py diff --git a/airflow/api_fastapi/execution_api/app.py b/airflow/api_fastapi/execution_api/app.py index 61283dc2cf87f..702ee9ad6a423 100644 --- a/airflow/api_fastapi/execution_api/app.py +++ b/airflow/api_fastapi/execution_api/app.py @@ -77,8 +77,9 @@ def custom_openapi() -> dict: def get_extra_schemas() -> dict[str, dict]: """Get all the extra schemas that are not part of the main FastAPI app.""" - from airflow.api_fastapi.execution_api.datamodels import taskinstance + from airflow.api_fastapi.execution_api.datamodels import dagrun, taskinstance return { "TaskInstance": taskinstance.TaskInstance.model_json_schema(), + "DagRun": dagrun.DagRun.model_json_schema(), } diff --git a/airflow/api_fastapi/execution_api/datamodels/dagrun.py b/airflow/api_fastapi/execution_api/datamodels/dagrun.py new file mode 100644 index 0000000000000..f9f99c1c43b42 --- /dev/null +++ b/airflow/api_fastapi/execution_api/datamodels/dagrun.py @@ -0,0 +1,35 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This model is not used in the API, but it is included in generated OpenAPI schema +# for use in the client SDKs. +from __future__ import annotations + +from airflow.api_fastapi.common.types import UtcDateTime +from airflow.api_fastapi.core_api.base import BaseModel + + +class DagRun(BaseModel): + """Schema for TaskInstance model with minimal required fields needed for OL for now.""" + + id: int + dag_id: str + run_id: str + logical_date: UtcDateTime + data_interval_start: UtcDateTime + data_interval_end: UtcDateTime + clear_number: int diff --git a/airflow/api_fastapi/execution_api/datamodels/taskinstance.py b/airflow/api_fastapi/execution_api/datamodels/taskinstance.py index c1bf588c2bbd4..bbba320d0f886 100644 --- a/airflow/api_fastapi/execution_api/datamodels/taskinstance.py +++ b/airflow/api_fastapi/execution_api/datamodels/taskinstance.py @@ -164,6 +164,7 @@ class TaskInstance(BaseModel): run_id: str try_number: int map_index: int | None = None + start_date: UtcDateTime class DagRun(BaseModel): @@ -180,6 +181,7 @@ class DagRun(BaseModel): data_interval_end: UtcDateTime | None start_date: UtcDateTime end_date: UtcDateTime | None + clear_number: int run_type: DagRunType conf: Annotated[dict[str, Any], Field(default_factory=dict)] @@ -190,6 +192,9 @@ class TIRunContext(BaseModel): dag_run: DagRun """DAG run information for the task instance.""" + task_reschedule_count: Annotated[int, Field(default=0)] + """How many times the task has been rescheduled.""" + variables: Annotated[list[VariableResponse], Field(default_factory=list)] """Variables that can be accessed by the task instance.""" diff --git a/airflow/api_fastapi/execution_api/routes/task_instances.py b/airflow/api_fastapi/execution_api/routes/task_instances.py index 4956466ca707a..dadf165c96b89 100644 --- a/airflow/api_fastapi/execution_api/routes/task_instances.py +++ b/airflow/api_fastapi/execution_api/routes/task_instances.py @@ -23,7 +23,7 @@ from fastapi import Body, HTTPException, status from pydantic import JsonValue -from sqlalchemy import update +from sqlalchemy import func, update from sqlalchemy.exc import NoResultFound, SQLAlchemyError from sqlalchemy.sql import select @@ -73,9 +73,9 @@ def ti_run( # We only use UUID above for validation purposes ti_id_str = str(task_instance_id) - old = select(TI.state, TI.dag_id, TI.run_id).where(TI.id == ti_id_str).with_for_update() + old = select(TI.state, TI.dag_id, TI.run_id, TI.try_number).where(TI.id == ti_id_str).with_for_update() try: - (previous_state, dag_id, run_id) = session.execute(old).one() + (previous_state, dag_id, run_id, try_number) = session.execute(old).one() except NoResultFound: log.error("Task Instance %s not found", ti_id_str) raise HTTPException( @@ -135,6 +135,7 @@ def ti_run( DR.data_interval_end, DR.start_date, DR.end_date, + DR.clear_number, DR.run_type, DR.conf, DR.logical_date, @@ -144,8 +145,24 @@ def ti_run( if not dr: raise ValueError(f"DagRun with dag_id={dag_id} and run_id={run_id} not found.") + task_reschedule_count = ( + session.query( + func.count(TaskReschedule.id) # or any other primary key column + ) + .filter( + TaskReschedule.dag_id == dag_id, + TaskReschedule.task_id == ti_id_str, + TaskReschedule.run_id == run_id, + # TaskReschedule.map_index == ti.map_index, # TODO: Handle mapped tasks + TaskReschedule.try_number == try_number, + ) + .scalar() + or 0 + ) + return TIRunContext( dag_run=DagRun.model_validate(dr, from_attributes=True), + task_reschedule_count=task_reschedule_count, # TODO: Add variables and connections that are needed (and has perms) for the task variables=[], connections=[], diff --git a/airflow/dag_processing/processor.py b/airflow/dag_processing/processor.py index 981e00341a8e6..28f2c6d7f005b 100644 --- a/airflow/dag_processing/processor.py +++ b/airflow/dag_processing/processor.py @@ -111,7 +111,7 @@ def _execute_callbacks( dagbag: DagBag, callback_requests: list[CallbackRequest], log: FilteringBoundLogger ) -> None: for request in callback_requests: - log.debug("Processing Callback Request", request=request) + log.debug("Processing Callback Request", request=request.to_json()) if isinstance(request, TaskCallbackRequest): raise NotImplementedError( "Haven't coded Task callback yet - https://github.com/apache/airflow/issues/44354!" diff --git a/airflow/example_dags/plugins/event_listener.py b/airflow/example_dags/plugins/event_listener.py index 6d9fe2ff11735..b0001b0bc7e5d 100644 --- a/airflow/example_dags/plugins/event_listener.py +++ b/airflow/example_dags/plugins/event_listener.py @@ -23,13 +23,13 @@ if TYPE_CHECKING: from airflow.models.dagrun import DagRun - from airflow.models.taskinstance import TaskInstance + from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance from airflow.utils.state import TaskInstanceState # [START howto_listen_ti_running_task] @hookimpl -def on_task_instance_running(previous_state: TaskInstanceState, task_instance: TaskInstance, session): +def on_task_instance_running(previous_state: TaskInstanceState, task_instance: RuntimeTaskInstance): """ This method is called when task state changes to RUNNING. Through callback, parameters like previous_task_state, task_instance object can be accessed. @@ -39,14 +39,11 @@ def on_task_instance_running(previous_state: TaskInstanceState, task_instance: T print("Task instance is in running state") print(" Previous state of the Task instance:", previous_state) - state: TaskInstanceState = task_instance.state name: str = task_instance.task_id - start_date = task_instance.start_date - dagrun = task_instance.dag_run - dagrun_status = dagrun.state + context = task_instance.get_template_context() - task = task_instance.task + task = context["task"] if TYPE_CHECKING: assert task @@ -55,8 +52,8 @@ def on_task_instance_running(previous_state: TaskInstanceState, task_instance: T dag_name = None if dag: dag_name = dag.dag_id - print(f"Current task name:{name} state:{state} start_date:{start_date}") - print(f"Dag name:{dag_name} and current dag run status:{dagrun_status}") + print(f"Current task name:{name}") + print(f"Dag name:{dag_name}") # [END howto_listen_ti_running_task] @@ -64,7 +61,7 @@ def on_task_instance_running(previous_state: TaskInstanceState, task_instance: T # [START howto_listen_ti_success_task] @hookimpl -def on_task_instance_success(previous_state: TaskInstanceState, task_instance: TaskInstance, session): +def on_task_instance_success(previous_state: TaskInstanceState, task_instance: RuntimeTaskInstance): """ This method is called when task state changes to SUCCESS. Through callback, parameters like previous_task_state, task_instance object can be accessed. @@ -74,14 +71,10 @@ def on_task_instance_success(previous_state: TaskInstanceState, task_instance: T print("Task instance in success state") print(" Previous state of the Task instance:", previous_state) - dag_id = task_instance.dag_id - hostname = task_instance.hostname - operator = task_instance.operator + context = task_instance.get_template_context() + operator = context["task"] - dagrun = task_instance.dag_run - queued_at = dagrun.queued_at - print(f"Dag name:{dag_id} queued_at:{queued_at}") - print(f"Task hostname:{hostname} operator:{operator}") + print(f"Task operator:{operator}") # [END howto_listen_ti_success_task] @@ -90,7 +83,7 @@ def on_task_instance_success(previous_state: TaskInstanceState, task_instance: T # [START howto_listen_ti_failure_task] @hookimpl def on_task_instance_failed( - previous_state: TaskInstanceState, task_instance: TaskInstance, error: None | str | BaseException, session + previous_state: TaskInstanceState, task_instance: RuntimeTaskInstance, error: None | str | BaseException ): """ This method is called when task state changes to FAILED. @@ -100,21 +93,14 @@ def on_task_instance_failed( """ print("Task instance in failure state") - start_date = task_instance.start_date - end_date = task_instance.end_date - duration = task_instance.duration - - dagrun = task_instance.dag_run - - task = task_instance.task + context = task_instance.get_template_context() + task = context["task"] if TYPE_CHECKING: assert task - dag = task.dag - - print(f"Task start:{start_date} end:{end_date} duration:{duration}") - print(f"Task:{task} dag:{dag} dagrun:{dagrun}") + print("Task start") + print(f"Task:{task}") if error: print(f"Failure caused by {error}") diff --git a/airflow/executors/workloads.py b/airflow/executors/workloads.py index 0adb54cd6da27..26692d9b23987 100644 --- a/airflow/executors/workloads.py +++ b/airflow/executors/workloads.py @@ -18,6 +18,7 @@ import os import uuid +from datetime import datetime from pathlib import Path from typing import TYPE_CHECKING, Literal, Union @@ -49,6 +50,7 @@ class TaskInstance(BaseModel): run_id: str try_number: int map_index: int | None = None + start_date: datetime # TODO: Task-SDK: Can we replace TastInstanceKey with just the uuid across the codebase? @property @@ -64,6 +66,15 @@ def key(self) -> TaskInstanceKey: ) +class DagRun(BaseModel): + id: int + dag_id: str + run_id: str + logical_date: datetime + data_interval_start: datetime + data_interval_end: datetime + + class ExecuteTask(BaseActivity): """Execute the given Task.""" @@ -83,6 +94,9 @@ def make(cls, ti: TIModel, dag_path: Path | None = None) -> ExecuteTask: from airflow.utils.helpers import log_filename_template_renderer + if not ti.start_date: + ti.start_date = datetime.now() + ser_ti = TaskInstance.model_validate(ti, from_attributes=True) dag_path = dag_path or Path(ti.dag_run.dag_model.relative_fileloc) diff --git a/airflow/listeners/listener.py b/airflow/listeners/listener.py index 5e8fba55d4395..11918527ef252 100644 --- a/airflow/listeners/listener.py +++ b/airflow/listeners/listener.py @@ -46,7 +46,13 @@ class ListenerManager: """Manage listener registration and provides hook property for calling them.""" def __init__(self): - from airflow.listeners.spec import asset, dagrun, importerrors, lifecycle, taskinstance + from airflow.listeners.spec import ( + asset, + dagrun, + importerrors, + lifecycle, + taskinstance, + ) self.pm = pluggy.PluginManager("airflow") self.pm.add_hookcall_monitoring(_before_hookcall, _after_hookcall) diff --git a/airflow/listeners/spec/taskinstance.py b/airflow/listeners/spec/taskinstance.py index f012de0aac8ea..d66d6c83ce3de 100644 --- a/airflow/listeners/spec/taskinstance.py +++ b/airflow/listeners/spec/taskinstance.py @@ -22,33 +22,26 @@ from pluggy import HookspecMarker if TYPE_CHECKING: - from sqlalchemy.orm.session import Session - - from airflow.models.taskinstance import TaskInstance + from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance from airflow.utils.state import TaskInstanceState hookspec = HookspecMarker("airflow") @hookspec -def on_task_instance_running( - previous_state: TaskInstanceState | None, task_instance: TaskInstance, session: Session | None -): +def on_task_instance_running(previous_state: TaskInstanceState | None, task_instance: RuntimeTaskInstance): """Execute when task state changes to RUNNING. previous_state can be None.""" @hookspec -def on_task_instance_success( - previous_state: TaskInstanceState | None, task_instance: TaskInstance, session: Session | None -): +def on_task_instance_success(previous_state: TaskInstanceState | None, task_instance: RuntimeTaskInstance): """Execute when task state changes to SUCCESS. previous_state can be None.""" @hookspec def on_task_instance_failed( previous_state: TaskInstanceState | None, - task_instance: TaskInstance, + task_instance: RuntimeTaskInstance, error: None | str | BaseException, - session: Session | None, ): """Execute when task state changes to FAIL. previous_state can be None.""" diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index a5e50cb0d2cdb..52cdf89fb5c5a 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -355,12 +355,12 @@ def _run_raw_task( if not test_mode: _add_log(event=ti.state, task_instance=ti, session=session) if ti.state == TaskInstanceState.SUCCESS: - ti._register_asset_changes(events=context["outlet_events"], session=session) + ti._register_asset_changes(events=context["outlet_events"]) TaskInstance.save_to_db(ti=ti, session=session) if ti.state == TaskInstanceState.SUCCESS: get_listener_manager().hook.on_task_instance_success( - previous_state=TaskInstanceState.RUNNING, task_instance=ti, session=session + previous_state=TaskInstanceState.RUNNING, task_instance=ti ) return None @@ -2890,7 +2890,7 @@ def signal_handler(signum, frame): # Run on_task_instance_running event get_listener_manager().hook.on_task_instance_running( - previous_state=TaskInstanceState.QUEUED, task_instance=self, session=session + previous_state=TaskInstanceState.QUEUED, task_instance=self ) def _render_map_index(context: Context, *, jinja_env: jinja2.Environment | None) -> str | None: @@ -3132,7 +3132,7 @@ def fetch_handle_failure_context( callbacks = task.on_retry_callback if task else None get_listener_manager().hook.on_task_instance_failed( - previous_state=TaskInstanceState.RUNNING, task_instance=ti, error=error, session=session + previous_state=TaskInstanceState.RUNNING, task_instance=ti, error=error ) return { diff --git a/providers/src/airflow/providers/openlineage/plugins/listener.py b/providers/src/airflow/providers/openlineage/plugins/listener.py index aefd534f155e1..f93500d51d479 100644 --- a/providers/src/airflow/providers/openlineage/plugins/listener.py +++ b/providers/src/airflow/providers/openlineage/plugins/listener.py @@ -19,6 +19,7 @@ import logging import os from concurrent.futures import ProcessPoolExecutor +from datetime import datetime from typing import TYPE_CHECKING import psutil @@ -33,6 +34,7 @@ from airflow.providers.openlineage.plugins.adapter import OpenLineageAdapter, RunState from airflow.providers.openlineage.utils.utils import ( AIRFLOW_V_2_10_PLUS, + AIRFLOW_V_3_0_PLUS, get_airflow_dag_run_facet, get_airflow_debug_facet, get_airflow_job_facet, @@ -42,7 +44,6 @@ get_user_provided_run_facets, is_operator_disabled, is_selective_lineage_enabled, - is_ti_rescheduled_already, print_warning, ) from airflow.settings import configure_orm @@ -52,9 +53,9 @@ from airflow.utils.timeout import timeout if TYPE_CHECKING: - from sqlalchemy.orm import Session - from airflow.models import TaskInstance + from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance + from airflow.settings import Session _openlineage_listener: OpenLineageListener | None = None @@ -87,28 +88,58 @@ def __init__(self): self.extractor_manager = ExtractorManager() self.adapter = OpenLineageAdapter() - @hookimpl - def on_task_instance_running( - self, - previous_state: TaskInstanceState, - task_instance: TaskInstance, - session: Session, # This will always be QUEUED - ) -> None: - if not getattr(task_instance, "task", None) is not None: - self.log.warning( - "No task set for TI object task_id: %s - dag_id: %s - run_id %s", - task_instance.task_id, - task_instance.dag_id, - task_instance.run_id, - ) - return + if AIRFLOW_V_3_0_PLUS: - self.log.debug("OpenLineage listener got notification about task instance start") - dagrun = task_instance.dag_run - task = task_instance.task - if TYPE_CHECKING: - assert task - dag = task.dag + @hookimpl + def on_task_instance_running( + self, + previous_state: TaskInstanceState, + task_instance: RuntimeTaskInstance, + ): + if not getattr(task_instance, "task", None) is not None: + self.log.warning( + "No task set for TI object task_id: %s - dag_id: %s - run_id %s", + task_instance.task_id, + task_instance.dag_id, + task_instance.run_id, + ) + return + + self.log.debug("OpenLineage listener got notification about task instance start") + context = task_instance.get_template_context() + + task = context["task"] + if TYPE_CHECKING: + assert task + dagrun = context["dag_run"] + dag = context["dag"] + self._on_task_instance_running(task_instance, dag, dagrun, task) + else: + + @hookimpl + def on_task_instance_running( + self, + previous_state: TaskInstanceState, + task_instance: TaskInstance, + session: Session, # type: ignore[valid-type] + ) -> None: + if not getattr(task_instance, "task", None) is not None: + self.log.warning( + "No task set for TI object task_id: %s - dag_id: %s - run_id %s", + task_instance.task_id, + task_instance.dag_id, + task_instance.run_id, + ) + return + + self.log.debug("OpenLineage listener got notification about task instance start") + task = task_instance.task + if TYPE_CHECKING: + assert task + + self._on_task_instance_running(task_instance, task.dag, task_instance.dag_run, task) + + def _on_task_instance_running(self, task_instance: RuntimeTaskInstance | TaskInstance, dag, dagrun, task): if is_operator_disabled(task): self.log.debug( "Skipping OpenLineage event emission for operator `%s` " @@ -127,35 +158,34 @@ def on_task_instance_running( return # Needs to be calculated outside of inner method so that it gets cached for usage in fork processes + data_interval_start = dagrun.data_interval_start + if isinstance(data_interval_start, datetime): + data_interval_start = data_interval_start.isoformat() + data_interval_end = dagrun.data_interval_end + if isinstance(data_interval_end, datetime): + data_interval_end = data_interval_end.isoformat() + debug_facet = get_airflow_debug_facet() @print_warning(self.log) def on_running(): - # that's a workaround to detect task running from deferred state - # we return here because Airflow 2.3 needs task from deferred state - if task_instance.next_method is not None: - return - - if is_ti_rescheduled_already(task_instance): + context = task_instance.get_template_context() + if hasattr(context, "task_reschedule_count") and context["task_reschedule_count"] > 0: self.log.debug("Skipping this instance of rescheduled task - START event was emitted already") return parent_run_id = self.adapter.build_dag_run_id( dag_id=dag.dag_id, logical_date=dagrun.logical_date, - clear_number=dagrun.clear_number, + clear_number=0, ) - - if hasattr(task_instance, "logical_date"): - logical_date = task_instance.logical_date - else: - logical_date = task_instance.execution_date + start_date = task_instance.start_date if task_instance.start_date else timezone.utcnow() task_uuid = self.adapter.build_task_instance_run_id( dag_id=dag.dag_id, task_id=task.task_id, try_number=task_instance.try_number, - logical_date=logical_date, + logical_date=dagrun.logical_date, map_index=task_instance.map_index, ) event_type = RunState.RUNNING.value.lower() @@ -164,11 +194,6 @@ def on_running(): with Stats.timer(f"ol.extract.{event_type}.{operator_name}"): task_metadata = self.extractor_manager.extract_metadata(dagrun, task) - start_date = task_instance.start_date if task_instance.start_date else timezone.utcnow() - data_interval_start = ( - dagrun.data_interval_start.isoformat() if dagrun.data_interval_start else None - ) - data_interval_end = dagrun.data_interval_end.isoformat() if dagrun.data_interval_end else None redacted_event = self.adapter.start_task( run_id=task_uuid, job_name=get_job_name(task), @@ -195,17 +220,39 @@ def on_running(): self._execute(on_running, "on_running", use_fork=True) - @hookimpl - def on_task_instance_success( - self, previous_state: TaskInstanceState, task_instance: TaskInstance, session: Session - ) -> None: - self.log.debug("OpenLineage listener got notification about task instance success") + if AIRFLOW_V_3_0_PLUS: + + @hookimpl + def on_task_instance_success( + self, previous_state: TaskInstanceState, task_instance: RuntimeTaskInstance + ) -> None: + self.log.debug("OpenLineage listener got notification about task instance success") + + context = task_instance.get_template_context() + task = context["task"] + if TYPE_CHECKING: + assert task + dagrun = context["dag_run"] + dag = context["dag"] + self._on_task_instance_success(task_instance, dag, dagrun, task) - dagrun = task_instance.dag_run - task = task_instance.task - if TYPE_CHECKING: - assert task - dag = task.dag + else: + + @hookimpl + def on_task_instance_success( + self, + previous_state: TaskInstanceState, + task_instance: TaskInstance, + session: Session, # type: ignore[valid-type] + ) -> None: + self.log.debug("OpenLineage listener got notification about task instance success") + task = task_instance.task + if TYPE_CHECKING: + assert task + self._on_task_instance_success(task_instance, task.dag, task_instance.dag_run, task) + + def _on_task_instance_success(self, task_instance: RuntimeTaskInstance, dag, dagrun, task): + end_date = timezone.utcnow() if is_operator_disabled(task): self.log.debug( @@ -232,15 +279,11 @@ def on_success(): clear_number=dagrun.clear_number, ) - if hasattr(task_instance, "logical_date"): - logical_date = task_instance.logical_date - else: - logical_date = task_instance.execution_date task_uuid = self.adapter.build_task_instance_run_id( dag_id=dag.dag_id, task_id=task.task_id, try_number=_get_try_number_success(task_instance), - logical_date=logical_date, + logical_date=dagrun.logical_date, map_index=task_instance.map_index, ) event_type = RunState.COMPLETE.value.lower() @@ -251,8 +294,6 @@ def on_success(): dagrun, task, complete=True, task_instance=task_instance ) - end_date = task_instance.end_date if task_instance.end_date else timezone.utcnow() - redacted_event = self.adapter.complete_task( run_id=task_uuid, job_name=get_job_name(task), @@ -273,7 +314,7 @@ def on_success(): self._execute(on_success, "on_success", use_fork=True) - if AIRFLOW_V_2_10_PLUS: + if AIRFLOW_V_3_0_PLUS: @hookimpl def on_task_instance_failed( @@ -281,36 +322,54 @@ def on_task_instance_failed( previous_state: TaskInstanceState, task_instance: TaskInstance, error: None | str | BaseException, - session: Session, ) -> None: - self._on_task_instance_failed( - previous_state=previous_state, task_instance=task_instance, error=error, session=session - ) + self.log.debug("OpenLineage listener got notification about task instance failure") + context = task_instance.get_template_context() + task = context["task"] + if TYPE_CHECKING: + assert task + dagrun = context["dag_run"] + dag = context["dag"] + self._on_task_instance_failed(task_instance, dag, dagrun, task, error) + + elif AIRFLOW_V_2_10_PLUS: + @hookimpl + def on_task_instance_failed( + self, + previous_state: TaskInstanceState, + task_instance: TaskInstance, + error: None | str | BaseException, + session: Session, # type: ignore[valid-type] + ) -> None: + self.log.debug("OpenLineage listener got notification about task instance failure") + task = task_instance.task + if TYPE_CHECKING: + assert task + self._on_task_instance_failed(task_instance, task.dag, task_instance.dag_run, task, error) else: @hookimpl def on_task_instance_failed( - self, previous_state: TaskInstanceState, task_instance: TaskInstance, session: Session + self, + previous_state: TaskInstanceState, + task_instance: TaskInstance, + session: Session, # type: ignore[valid-type] ) -> None: - self._on_task_instance_failed( - previous_state=previous_state, task_instance=task_instance, error=None, session=session - ) + task = task_instance.task + if TYPE_CHECKING: + assert task + self._on_task_instance_failed(task_instance, task.dag, task_instance.dag_run, task) def _on_task_instance_failed( self, - previous_state: TaskInstanceState, - task_instance: TaskInstance, - session: Session, + task_instance: TaskInstance | RuntimeTaskInstance, + dag, + dagrun, + task, error: None | str | BaseException = None, ) -> None: - self.log.debug("OpenLineage listener got notification about task instance failure") - - dagrun = task_instance.dag_run - task = task_instance.task - if TYPE_CHECKING: - assert task - dag = task.dag + end_date = timezone.utcnow() if is_operator_disabled(task): self.log.debug( @@ -337,16 +396,11 @@ def on_failure(): clear_number=dagrun.clear_number, ) - if hasattr(task_instance, "logical_date"): - logical_date = task_instance.logical_date - else: - logical_date = task_instance.execution_date - task_uuid = self.adapter.build_task_instance_run_id( dag_id=dag.dag_id, task_id=task.task_id, try_number=task_instance.try_number, - logical_date=logical_date, + logical_date=dagrun.logical_date, map_index=task_instance.map_index, ) event_type = RunState.FAIL.value.lower() @@ -357,8 +411,6 @@ def on_failure(): dagrun, task, complete=True, task_instance=task_instance ) - end_date = task_instance.end_date if task_instance.end_date else timezone.utcnow() - redacted_event = self.adapter.fail_task( run_id=task_uuid, job_name=get_job_name(task), diff --git a/providers/src/airflow/providers/openlineage/utils/selective_enable.py b/providers/src/airflow/providers/openlineage/utils/selective_enable.py index a3c16a1e18da3..b0cd8a4455c88 100644 --- a/providers/src/airflow/providers/openlineage/utils/selective_enable.py +++ b/providers/src/airflow/providers/openlineage/utils/selective_enable.py @@ -18,7 +18,7 @@ from __future__ import annotations import logging -from typing import TypeVar +from typing import TYPE_CHECKING, TypeVar from airflow.models import DAG, Operator, Param from airflow.models.xcom_arg import XComArg @@ -28,6 +28,10 @@ DISABLE_OL_PARAM = Param(False, const=False) T = TypeVar("T", bound="DAG | Operator") +if TYPE_CHECKING: + from airflow.sdk.definitions.baseoperator import BaseOperator as SdkBaseOperator + + log = logging.getLogger(__name__) @@ -65,7 +69,7 @@ def disable_lineage(obj: T) -> T: return obj -def is_task_lineage_enabled(task: Operator) -> bool: +def is_task_lineage_enabled(task: Operator | SdkBaseOperator) -> bool: """Check if selective enable OpenLineage parameter is set to True on task level.""" if task.params.get(ENABLE_OL_PARAM_NAME) is False: log.debug( diff --git a/providers/src/airflow/providers/openlineage/utils/utils.py b/providers/src/airflow/providers/openlineage/utils/utils.py index 4408a833fba68..2cb998d47aa95 100644 --- a/providers/src/airflow/providers/openlineage/utils/utils.py +++ b/providers/src/airflow/providers/openlineage/utils/utils.py @@ -27,12 +27,11 @@ import attrs from openlineage.client.utils import RedactMixin -from sqlalchemy import exists from airflow import __version__ as AIRFLOW_VERSION # TODO: move this maybe to Airflow's logic? -from airflow.models import DAG, BaseOperator, DagRun, MappedOperator, TaskReschedule +from airflow.models import DAG, BaseOperator, DagRun, MappedOperator from airflow.providers.openlineage import ( __version__ as OPENLINEAGE_PROVIDER_VERSION, conf, @@ -52,7 +51,6 @@ is_task_lineage_enabled, ) from airflow.providers.openlineage.version_compat import AIRFLOW_V_2_10_PLUS, AIRFLOW_V_3_0_PLUS -from airflow.sensors.base import BaseSensorOperator from airflow.serialization.serialized_objects import SerializedBaseOperator from airflow.utils.context import AirflowContextDeprecationWarning from airflow.utils.log.secrets_masker import ( @@ -62,7 +60,11 @@ should_hide_value_for_key, ) from airflow.utils.module_loading import import_string -from airflow.utils.session import NEW_SESSION, provide_session + +try: + from airflow.sdk.definitions.baseoperator import BaseOperator as SdkBaseOperator +except ImportError: + SdkBaseOperator = BaseOperator # type: ignore[misc] if TYPE_CHECKING: from openlineage.client.event_v2 import Dataset as OpenLineageDataset @@ -90,7 +92,7 @@ def try_import_from_string(string: str) -> Any: return import_string(string) -def get_operator_class(task: BaseOperator) -> type: +def get_operator_class(task: BaseOperator | SdkBaseOperator) -> type: if task.__class__.__name__ in ("DecoratedMappedOperator", "MappedOperator"): return task.operator_class return task.__class__ @@ -153,7 +155,7 @@ def get_user_provided_run_facets(ti: TaskInstance, ti_state: TaskInstanceState) return custom_facets -def get_fully_qualified_class_name(operator: BaseOperator | MappedOperator) -> str: +def get_fully_qualified_class_name(operator: BaseOperator | MappedOperator | SdkBaseOperator) -> str: if isinstance(operator, (MappedOperator, SerializedBaseOperator)): # as in airflow.api_connexion.schemas.common_schema.ClassReferenceSchema return operator._task_module + "." + operator._task_type # type: ignore @@ -161,44 +163,22 @@ def get_fully_qualified_class_name(operator: BaseOperator | MappedOperator) -> s return op_class.__module__ + "." + op_class.__name__ -def is_operator_disabled(operator: BaseOperator | MappedOperator) -> bool: +def is_operator_disabled(operator: BaseOperator | MappedOperator | SdkBaseOperator) -> bool: return get_fully_qualified_class_name(operator) in conf.disabled_operators() -def is_selective_lineage_enabled(obj: DAG | BaseOperator | MappedOperator) -> bool: +def is_selective_lineage_enabled(obj: DAG | BaseOperator | MappedOperator | SdkBaseOperator) -> bool: """If selective enable is active check if DAG or Task is enabled to emit events.""" if not conf.selective_enable(): return True if isinstance(obj, DAG): return is_dag_lineage_enabled(obj) - elif isinstance(obj, (BaseOperator, MappedOperator)): + elif isinstance(obj, (BaseOperator, MappedOperator, SdkBaseOperator)): return is_task_lineage_enabled(obj) else: raise TypeError("is_selective_lineage_enabled can only be used on DAG or Operator objects") -@provide_session -def is_ti_rescheduled_already(ti: TaskInstance, session=NEW_SESSION): - if not isinstance(ti.task, BaseSensorOperator): - return False - - if not ti.task.reschedule: - return False - - return ( - session.query( - exists().where( - TaskReschedule.dag_id == ti.dag_id, - TaskReschedule.task_id == ti.task_id, - TaskReschedule.run_id == ti.run_id, - TaskReschedule.map_index == ti.map_index, - TaskReschedule.try_number == ti.try_number, - ) - ).scalar() - is True - ) - - class InfoJsonEncodable(dict): """ Airflow objects might not be json-encodable overall. diff --git a/providers/tests/openlineage/extractors/test_manager.py b/providers/tests/openlineage/extractors/test_manager.py index 6b5e0dedd96ec..39c24047053bf 100644 --- a/providers/tests/openlineage/extractors/test_manager.py +++ b/providers/tests/openlineage/extractors/test_manager.py @@ -19,6 +19,7 @@ import tempfile from typing import TYPE_CHECKING, Any +from unittest import mock from unittest.mock import MagicMock import pytest @@ -36,14 +37,19 @@ from airflow.providers.openlineage.extractors import OperatorLineage from airflow.providers.openlineage.extractors.manager import ExtractorManager from airflow.providers.openlineage.utils.utils import Asset +from airflow.utils import timezone from airflow.utils.state import State from tests_common.test_utils.compat import PythonOperator -from tests_common.test_utils.version_compat import AIRFLOW_V_2_10_PLUS +from tests_common.test_utils.version_compat import AIRFLOW_V_2_10_PLUS, AIRFLOW_V_3_0_PLUS if TYPE_CHECKING: + from datetime import datetime + from airflow.utils.context import Context + from task_sdk.tests.conftest import MakeTIContextCallable + if AIRFLOW_V_2_10_PLUS: @pytest.fixture @@ -61,6 +67,19 @@ def hook_lineage_collector(): hook._hook_lineage_collector = None +if AIRFLOW_V_3_0_PLUS: + from airflow.sdk.api.datamodels._generated import TaskInstance as SDKTaskInstance + from airflow.sdk.execution_time import task_runner + from airflow.sdk.execution_time.comms import StartupDetails + from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance, parse +else: + SDKTaskInstance = ... # type: ignore + task_runner = ... # type: ignore + StartupDetails = ... # type: ignore + RuntimeTaskInstance = ... # type: ignore + parse = ... # type: ignore + + @pytest.mark.parametrize( ("uri", "dataset"), ( @@ -297,7 +316,10 @@ def get_openlineage_facets_on_start(self): @pytest.mark.db_test -@pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="Hook lineage works in Airflow >= 2.10.0") +@pytest.mark.skipif( + not AIRFLOW_V_2_10_PLUS or AIRFLOW_V_3_0_PLUS, + reason="Test for hook level lineage in Airflow >= 2.10.0 < 3.0", +) def test_extractor_manager_gets_data_from_pythonoperator(session, dag_maker, hook_lineage_collector): path = None with tempfile.NamedTemporaryFile() as f: @@ -324,3 +346,125 @@ def use_read(): assert len(datasets.outputs) == 1 assert datasets.outputs[0].asset == Asset(uri=path) + + +@pytest.fixture +def mock_supervisor_comms(): + with mock.patch( + "airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True + ) as supervisor_comms: + yield supervisor_comms + + +@pytest.fixture +def mocked_parse(spy_agency): + """ + Fixture to set up an inline DAG and use it in a stubbed `parse` function. Use this fixture if you + want to isolate and test `parse` or `run` logic without having to define a DAG file. + + This fixture returns a helper function `set_dag` that: + 1. Creates an in line DAG with the given `dag_id` and `task` (limited to one task) + 2. Constructs a `RuntimeTaskInstance` based on the provided `StartupDetails` and task. + 3. Stubs the `parse` function using `spy_agency`, to return the mocked `RuntimeTaskInstance`. + + After adding the fixture in your test function signature, you can use it like this :: + + mocked_parse( + StartupDetails( + ti=TaskInstance(id=uuid7(), task_id="hello", dag_id="super_basic_run", run_id="c", try_number=1), + file="", + requests_fd=0, + ), + "example_dag_id", + CustomOperator(task_id="hello"), + ) + """ + + def set_dag(what: StartupDetails, dag_id: str, task: BaseOperator) -> RuntimeTaskInstance: + from task_sdk.tests.execution_time.test_task_runner import get_inline_dag + + dag = get_inline_dag(dag_id, task) + t = dag.task_dict[task.task_id] + ti = RuntimeTaskInstance.model_construct( + **what.ti.model_dump(exclude_unset=True), task=t, _ti_context_from_server=what.ti_context + ) + spy_agency.spy_on(parse, call_fake=lambda _: ti) + return ti + + return set_dag + + +@pytest.fixture +def make_ti_context() -> MakeTIContextCallable: + """Factory for creating TIRunContext objects.""" + from airflow.sdk.api.datamodels._generated import DagRun, TIRunContext + + def _make_context( + dag_id: str = "test_dag", + run_id: str = "test_run", + logical_date: str | datetime = "2024-12-01T01:00:00Z", + data_interval_start: str | datetime = "2024-12-01T00:00:00Z", + data_interval_end: str | datetime = "2024-12-01T01:00:00Z", + clear_number: int = 0, + start_date: str | datetime = "2024-12-01T01:00:00Z", + run_type: str = "manual", + task_reschedule_count: int = 0, + ) -> TIRunContext: + return TIRunContext( + dag_run=DagRun( + dag_id=dag_id, + run_id=run_id, + logical_date=logical_date, # type: ignore + data_interval_start=data_interval_start, # type: ignore + data_interval_end=data_interval_end, # type: ignore + clear_number=clear_number, # type: ignore + start_date=start_date, # type: ignore + run_type=run_type, # type: ignore + ), + task_reschedule_count=task_reschedule_count, + ) + + return _make_context + + +@pytest.mark.db_test +@pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Task SDK related test") +def test_extractor_manager_gets_data_from_pythonoperator_tasksdk( + session, hook_lineage_collector, mocked_parse, make_ti_context, mock_supervisor_comms +): + from airflow.models.taskinstance import uuid7 + + path = None + with tempfile.NamedTemporaryFile() as f: + path = f.name + + def use_read(): + storage_path = ObjectStoragePath(path) + with storage_path.open("w") as out: + out.write("test") + + task = PythonOperator(task_id="test_task_extractor_pythonoperator", python_callable=use_read) + + what = StartupDetails( + ti=SDKTaskInstance( + id=uuid7(), + task_id="test_task_extractor_pythonoperator", + dag_id="test_hookcollector_dag", + run_id="c", + try_number=1, + start_date=timezone.utcnow(), + ), + file="", + requests_fd=0, + ti_context=make_ti_context(), + ) + ti = mocked_parse(what, "test_hookcollector_dag", task) + + print(ti.__dict__) + + task_runner.run(ti, MagicMock()) + + datasets = hook_lineage_collector.collected_assets + + assert len(datasets.outputs) == 1 + assert datasets.outputs[0].asset == Asset(uri=path) diff --git a/providers/tests/openlineage/plugins/test_execution.py b/providers/tests/openlineage/plugins/test_execution.py index 039064e7053f6..33d79902c5266 100644 --- a/providers/tests/openlineage/plugins/test_execution.py +++ b/providers/tests/openlineage/plugins/test_execution.py @@ -116,6 +116,7 @@ def test_not_stalled_task_emits_proper_lineage(self): self.setup_job(task_name, run_id) events = get_sorted_events(tmp_dir) + log.error(events) assert has_value_in_events(events, ["inputs", "name"], "on-start") assert has_value_in_events(events, ["inputs", "name"], "on-complete") diff --git a/providers/tests/openlineage/plugins/test_listener.py b/providers/tests/openlineage/plugins/test_listener.py index eeca49c7f8d1b..d8786cb55ad22 100644 --- a/providers/tests/openlineage/plugins/test_listener.py +++ b/providers/tests/openlineage/plugins/test_listener.py @@ -271,7 +271,7 @@ def test_adapter_start_task_is_called_with_proper_arguments( mock_get_airflow_run_facet.return_value = {"airflow_run_facet": 3} mock_disabled.return_value = False - listener.on_task_instance_running(None, task_instance, None) + listener.on_task_instance_running(None, task_instance) listener.adapter.start_task.assert_called_once_with( run_id="2020-01-01T01:01:01.dag_id.task_id.1.-1", job_name="job_name", @@ -326,7 +326,7 @@ def test_adapter_fail_task_is_called_with_proper_arguments( expected_err_kwargs = {"error": err if AIRFLOW_V_2_10_PLUS else None} listener.on_task_instance_failed( - previous_state=None, task_instance=task_instance, session=None, **on_task_failed_listener_kwargs + previous_state=None, task_instance=task_instance, **on_task_failed_listener_kwargs ) listener.adapter.fail_task.assert_called_once_with( end_time="2023-01-03T13:01:01", @@ -372,7 +372,7 @@ def test_adapter_complete_task_is_called_with_proper_arguments( mock_get_airflow_run_facet.return_value = {"airflow": {"task": "..."}} mock_disabled.return_value = False - listener.on_task_instance_success(None, task_instance, None) + listener.on_task_instance_success(None, task_instance) # This run_id will be different as we did NOT simulate increase of the try_number attribute, # which happens in Airflow < 2.10. calls = listener.adapter.complete_task.call_args_list @@ -401,7 +401,7 @@ def test_on_task_instance_running_correctly_calls_openlineage_adapter_run_id_met parameters derived from the task instance. """ listener, task_instance = _create_listener_and_task_instance() - listener.on_task_instance_running(None, task_instance, None) + listener.on_task_instance_running(None, task_instance) listener.adapter.build_task_instance_run_id.assert_called_once_with( dag_id="dag_id", task_id="task_id", @@ -423,7 +423,7 @@ def test_on_task_instance_failed_correctly_calls_openlineage_adapter_run_id_meth on_task_failed_kwargs = {"error": ValueError("test")} if AIRFLOW_V_2_10_PLUS else {} listener.on_task_instance_failed( - previous_state=None, task_instance=task_instance, session=None, **on_task_failed_kwargs + previous_state=None, task_instance=task_instance, **on_task_failed_kwargs ) listener.adapter.build_task_instance_run_id.assert_called_once_with( dag_id="dag_id", @@ -443,7 +443,7 @@ def test_on_task_instance_success_correctly_calls_openlineage_adapter_run_id_met parameters derived from the task instance. """ listener, task_instance = _create_listener_and_task_instance() - listener.on_task_instance_success(None, task_instance, None) + listener.on_task_instance_success(None, task_instance) listener.adapter.build_task_instance_run_id.assert_called_once_with( dag_id="dag_id", task_id="task_id", @@ -527,7 +527,7 @@ def test_listener_on_task_instance_running_do_not_call_adapter_when_disabled_ope mock_get_airflow_run_facet.return_value = {"airflow_run_facet": 3} mock_disabled.return_value = True - listener.on_task_instance_running(None, task_instance, None) + listener.on_task_instance_running(None, task_instance) mock_disabled.assert_called_once_with(task_instance.task) listener.adapter.build_dag_run_id.assert_not_called() listener.adapter.build_task_instance_run_id.assert_not_called() @@ -548,7 +548,7 @@ def test_listener_on_task_instance_failed_do_not_call_adapter_when_disabled_oper on_task_failed_kwargs = {"error": ValueError("test")} if AIRFLOW_V_2_10_PLUS else {} listener.on_task_instance_failed( - previous_state=None, task_instance=task_instance, session=None, **on_task_failed_kwargs + previous_state=None, task_instance=task_instance, **on_task_failed_kwargs ) mock_disabled.assert_called_once_with(task_instance.task) listener.adapter.build_dag_run_id.assert_not_called() @@ -567,7 +567,7 @@ def test_listener_on_task_instance_success_do_not_call_adapter_when_disabled_ope mock_get_user_provided_run_facets.return_value = {"custom_facet": 2} mock_disabled.return_value = True - listener.on_task_instance_success(None, task_instance, None) + listener.on_task_instance_success(None, task_instance) mock_disabled.assert_called_once_with(task_instance.task) listener.adapter.build_dag_run_id.assert_not_called() listener.adapter.build_task_instance_run_id.assert_not_called() @@ -755,24 +755,22 @@ def test_listener_with_task_enabled( assert expected_dag_call_count == listener._executor.submit.call_count # run TaskInstance-related hooks for lineage enabled task - listener.on_task_instance_running(None, self.task_instance_1, None) - listener.on_task_instance_success(None, self.task_instance_1, None) + listener.on_task_instance_running(None, self.task_instance_1) + listener.on_task_instance_success(None, self.task_instance_1) listener.on_task_instance_failed( previous_state=None, task_instance=self.task_instance_1, - session=None, **on_task_failed_kwargs, ) assert expected_task_call_count == listener.extractor_manager.extract_metadata.call_count # run TaskInstance-related hooks for lineage disabled task - listener.on_task_instance_running(None, self.task_instance_2, None) - listener.on_task_instance_success(None, self.task_instance_2, None) + listener.on_task_instance_running(None, self.task_instance_2) + listener.on_task_instance_success(None, self.task_instance_2) listener.on_task_instance_failed( previous_state=None, task_instance=self.task_instance_2, - session=None, **on_task_failed_kwargs, ) @@ -817,10 +815,10 @@ def test_listener_with_dag_disabled_task_enabled( listener.on_dag_run_success(self.dagrun, msg="test success") # run TaskInstance-related hooks for lineage enabled task - listener.on_task_instance_running(None, self.task_instance_1, None) - listener.on_task_instance_success(None, self.task_instance_1, None) + listener.on_task_instance_running(None, self.task_instance_1) + listener.on_task_instance_success(None, self.task_instance_1) listener.on_task_instance_failed( - previous_state=None, task_instance=self.task_instance_1, session=None, **on_task_failed_kwargs + previous_state=None, task_instance=self.task_instance_1, **on_task_failed_kwargs ) assert expected_call_count == listener._executor.submit.call_count diff --git a/task_sdk/src/airflow/sdk/api/datamodels/_generated.py b/task_sdk/src/airflow/sdk/api/datamodels/_generated.py index ff4cc588ff564..b6298b626f2a4 100644 --- a/task_sdk/src/airflow/sdk/api/datamodels/_generated.py +++ b/task_sdk/src/airflow/sdk/api/datamodels/_generated.py @@ -178,6 +178,7 @@ class TaskInstance(BaseModel): dag_id: Annotated[str, Field(title="Dag Id")] run_id: Annotated[str, Field(title="Run Id")] try_number: Annotated[int, Field(title="Try Number")] + start_date: Annotated[datetime, Field(title="Start Date")] map_index: Annotated[int | None, Field(title="Map Index")] = None @@ -193,6 +194,7 @@ class DagRun(BaseModel): data_interval_end: Annotated[datetime | None, Field(title="Data Interval End")] = None start_date: Annotated[datetime, Field(title="Start Date")] end_date: Annotated[datetime | None, Field(title="End Date")] = None + clear_number: Annotated[int, Field(title="Clear Number")] run_type: DagRunType conf: Annotated[dict[str, Any] | None, Field(title="Conf")] = None @@ -207,6 +209,7 @@ class TIRunContext(BaseModel): """ dag_run: DagRun + task_reschedule_count: Annotated[int, Field(title="Task Reschedule Count")] variables: Annotated[list[VariableResponse] | None, Field(title="Variables")] = None connections: Annotated[list[ConnectionResponse] | None, Field(title="Connections")] = None diff --git a/task_sdk/src/airflow/sdk/execution_time/supervisor.py b/task_sdk/src/airflow/sdk/execution_time/supervisor.py index 811d1ce86a60d..bdbf9beeccc18 100644 --- a/task_sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task_sdk/src/airflow/sdk/execution_time/supervisor.py @@ -848,6 +848,7 @@ def supervise( Run a single task execution to completion. :param ti: The task instance to run. + :param dr: Current DagRun of the task instance. :param dag_path: The file path to the DAG. :param token: Authentication token for the API client. :param server: Base URL of the API server. diff --git a/task_sdk/src/airflow/sdk/execution_time/task_runner.py b/task_sdk/src/airflow/sdk/execution_time/task_runner.py index effea9f6a7e52..396dc7435bfbb 100644 --- a/task_sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task_sdk/src/airflow/sdk/execution_time/task_runner.py @@ -30,6 +30,7 @@ import structlog from pydantic import BaseModel, ConfigDict, Field, JsonValue, TypeAdapter +from airflow.listeners.listener import get_listener_manager from airflow.sdk.api.datamodels._generated import TaskInstance, TerminalTIState, TIRunContext from airflow.sdk.definitions.baseoperator import BaseOperator from airflow.sdk.execution_time.comms import ( @@ -45,6 +46,7 @@ XComResult, ) from airflow.sdk.execution_time.context import ConnectionAccessor +from airflow.utils.state import TaskInstanceState if TYPE_CHECKING: from structlog.typing import FilteringBoundLogger as Logger @@ -113,6 +115,7 @@ def get_template_context(self): "ts": ts, "ts_nodash": ts_nodash, "ts_nodash_with_tz": ts_nodash_with_tz, + "task_reschedule_count": self._ti_context_from_server.task_reschedule_count, } context.update(context_from_server) # TODO: We should use/move TypeDict from airflow.utils.context.Context @@ -381,6 +384,9 @@ def run(ti: RuntimeTaskInstance, log: Logger): try: # TODO: pre execute etc. # TODO: Get a real context object + get_listener_manager().hook.on_task_instance_running( + previous_state=TaskInstanceState.QUEUED, task_instance=ti + ) ti.task = ti.task.prepare_for_execution() context = ti.get_template_context() # TODO: Get things from _execute_task_with_callbacks @@ -412,6 +418,9 @@ def run(ti: RuntimeTaskInstance, log: Logger): _push_xcom_if_needed(result, ti) msg = TaskState(state=TerminalTIState.SUCCESS, end_date=datetime.now(tz=timezone.utc)) + get_listener_manager().hook.on_task_instance_success( + previous_state=TaskInstanceState.RUNNING, task_instance=ti + ) except TaskDeferred as defer: classpath, trigger_kwargs = defer.trigger.serialize() next_method = defer.method_name @@ -441,6 +450,11 @@ def run(ti: RuntimeTaskInstance, log: Logger): state=TerminalTIState.FAIL_WITHOUT_RETRY, end_date=datetime.now(tz=timezone.utc), ) + + get_listener_manager().hook.on_task_instance_failed( + previous_state=TaskInstanceState.RUNNING, task_instance=ti + ) + # TODO: Run task failure callbacks here except (AirflowTaskTimeout, AirflowException): # We should allow retries if the task has defined it. @@ -448,6 +462,9 @@ def run(ti: RuntimeTaskInstance, log: Logger): state=TerminalTIState.FAILED, end_date=datetime.now(tz=timezone.utc), ) + get_listener_manager().hook.on_task_instance_failed( + previous_state=TaskInstanceState.RUNNING, task_instance=ti + ) # TODO: Run task failure callbacks here except AirflowException: # TODO: handle the case of up_for_retry here @@ -455,6 +472,9 @@ def run(ti: RuntimeTaskInstance, log: Logger): state=TerminalTIState.FAILED, end_date=datetime.now(tz=timezone.utc), ) + get_listener_manager().hook.on_task_instance_failed( + previous_state=TaskInstanceState.RUNNING, task_instance=ti + ) except AirflowTaskTerminated: # External state updates are already handled with `ti_heartbeat` and will be # updated already be another UI API. So, these exceptions should ideally never be thrown. diff --git a/task_sdk/tests/conftest.py b/task_sdk/tests/conftest.py index 25d0a1b0061b6..d8adb41a9ef2b 100644 --- a/task_sdk/tests/conftest.py +++ b/task_sdk/tests/conftest.py @@ -142,8 +142,10 @@ def __call__( logical_date: str | datetime = ..., data_interval_start: str | datetime = ..., data_interval_end: str | datetime = ..., + clear_number: int = ..., start_date: str | datetime = ..., run_type: str = ..., + task_reschedule_count: int = ..., ) -> TIRunContext: ... @@ -157,6 +159,7 @@ def __call__( data_interval_end: str | datetime = ..., start_date: str | datetime = ..., run_type: str = ..., + task_reschedule_count: int = ..., ) -> dict[str, Any]: ... @@ -171,8 +174,10 @@ def _make_context( logical_date: str | datetime = "2024-12-01T01:00:00Z", data_interval_start: str | datetime = "2024-12-01T00:00:00Z", data_interval_end: str | datetime = "2024-12-01T01:00:00Z", + clear_number: int = 0, start_date: str | datetime = "2024-12-01T01:00:00Z", run_type: str = "manual", + task_reschedule_count: int = 0, ) -> TIRunContext: return TIRunContext( dag_run=DagRun( @@ -181,9 +186,11 @@ def _make_context( logical_date=logical_date, # type: ignore data_interval_start=data_interval_start, # type: ignore data_interval_end=data_interval_end, # type: ignore + clear_number=clear_number, # type: ignore start_date=start_date, # type: ignore run_type=run_type, # type: ignore - ) + ), + task_reschedule_count=task_reschedule_count, ) return _make_context @@ -201,6 +208,7 @@ def _make_context_dict( data_interval_end: str | datetime = "2024-12-01T01:00:00Z", start_date: str | datetime = "2024-12-01T00:00:00Z", run_type: str = "manual", + task_reschedule_count: int = 0, ) -> dict[str, Any]: context = make_ti_context( dag_id=dag_id, @@ -210,6 +218,7 @@ def _make_context_dict( data_interval_end=data_interval_end, start_date=start_date, run_type=run_type, + task_reschedule_count=task_reschedule_count, ) return context.model_dump(exclude_unset=True, mode="json") diff --git a/task_sdk/tests/execution_time/test_supervisor.py b/task_sdk/tests/execution_time/test_supervisor.py index 9cfe456962bb9..e765ebada1c83 100644 --- a/task_sdk/tests/execution_time/test_supervisor.py +++ b/task_sdk/tests/execution_time/test_supervisor.py @@ -228,11 +228,7 @@ def subprocess_main(): proc = WatchedSubprocess.start( path=os.devnull, what=TaskInstance( - id=ti_id, - task_id="b", - dag_id="c", - run_id="d", - try_number=1, + id=ti_id, task_id="b", dag_id="c", run_id="d", try_number=1, start_date=timezone.utcnow() ), client=sdk_client.Client(base_url="", dry_run=True, token=""), target=subprocess_main, @@ -320,7 +316,9 @@ def test_supervise_handles_deferred_task( def test_supervisor_handles_already_running_task(self): """Test that Supervisor prevents starting a Task Instance that is already running.""" - ti = TaskInstance(id=uuid7(), task_id="b", dag_id="c", run_id="d", try_number=1) + ti = TaskInstance( + id=uuid7(), task_id="b", dag_id="c", run_id="d", try_number=1, start_date=tz.utcnow() + ) # Mock API Server response indicating the TI is already running # The API Server would return a 409 Conflict status code if the TI is not diff --git a/task_sdk/tests/execution_time/test_task_runner.py b/task_sdk/tests/execution_time/test_task_runner.py index ebf03d3323315..dab466ef8092d 100644 --- a/task_sdk/tests/execution_time/test_task_runner.py +++ b/task_sdk/tests/execution_time/test_task_runner.py @@ -90,7 +90,9 @@ def mocked_parse(spy_agency): def set_dag(what: StartupDetails, dag_id: str, task: BaseOperator) -> RuntimeTaskInstance: dag = get_inline_dag(dag_id, task) t = dag.task_dict[task.task_id] - ti = RuntimeTaskInstance.model_construct(**what.ti.model_dump(exclude_unset=True), task=t) + ti = RuntimeTaskInstance.model_construct( + **what.ti.model_dump(exclude_unset=True), task=t, _ti_context_from_server=what.ti_context + ) spy_agency.spy_on(parse, call_fake=lambda _: ti) return ti @@ -141,7 +143,14 @@ def test_recv_StartupDetails(self): def test_parse(test_dags_dir: Path, make_ti_context): """Test that checks parsing of a basic dag with an un-mocked parse.""" what = StartupDetails( - ti=TaskInstance(id=uuid7(), task_id="a", dag_id="super_basic", run_id="c", try_number=1), + ti=TaskInstance( + id=uuid7(), + task_id="a", + dag_id="super_basic", + run_id="c", + try_number=1, + start_date=timezone.utcnow(), + ), file=str(test_dags_dir / "super_basic.py"), requests_fd=0, ti_context=make_ti_context(), @@ -723,7 +732,12 @@ def execute(self, context): task = CustomOperator(task_id="hello", do_xcom_push=do_xcom_push) ti = TaskInstance( - id=uuid7(), task_id=task.task_id, dag_id="xcom_push_flag", run_id="test_run", try_number=1 + id=uuid7(), + task_id=task.task_id, + dag_id="xcom_push_flag", + run_id="test_run", + try_number=1, + start_date=timezone.utcnow(), ) what = StartupDetails(ti=ti, file="", requests_fd=0, ti_context=make_ti_context()) diff --git a/tests/conftest.py b/tests/conftest.py index de13fe99c4bf6..004e8ac85db9e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -40,6 +40,14 @@ ] +@pytest.fixture(autouse=True) +def load_examples(): + from tests_common.test_utils.config import conf_vars + + with conf_vars({("core", "load_examples"): "False"}): + yield + + @pytest.hookimpl(tryfirst=True) def pytest_configure(config: pytest.Config) -> None: dep_path = [config.rootpath.joinpath("tests", "deprecations_ignore.yml")] diff --git a/tests/listeners/class_listener.py b/tests/listeners/class_listener.py index 90ff6ab975ea9..cda1ea09cff26 100644 --- a/tests/listeners/class_listener.py +++ b/tests/listeners/class_listener.py @@ -42,17 +42,15 @@ def before_stopping(self, component): self.state.append(DagRunState.SUCCESS) @hookimpl - def on_task_instance_running(self, previous_state, task_instance, session): + def on_task_instance_running(self, previous_state, task_instance): self.state.append(TaskInstanceState.RUNNING) @hookimpl - def on_task_instance_success(self, previous_state, task_instance, session): + def on_task_instance_success(self, previous_state, task_instance): self.state.append(TaskInstanceState.SUCCESS) @hookimpl - def on_task_instance_failed( - self, previous_state, task_instance, error: None | str | BaseException, session - ): + def on_task_instance_failed(self, previous_state, task_instance, error: None | str | BaseException): self.state.append(TaskInstanceState.FAILED) else: @@ -74,15 +72,15 @@ def before_stopping(self, component): self.state.append(DagRunState.SUCCESS) @hookimpl - def on_task_instance_running(self, previous_state, task_instance, session): + def on_task_instance_running(self, previous_state, task_instance): self.state.append(TaskInstanceState.RUNNING) @hookimpl - def on_task_instance_success(self, previous_state, task_instance, session): + def on_task_instance_success(self, previous_state, task_instance): self.state.append(TaskInstanceState.SUCCESS) @hookimpl - def on_task_instance_failed(self, previous_state, task_instance, session): + def on_task_instance_failed(self, previous_state, task_instance): self.state.append(TaskInstanceState.FAILED) diff --git a/tests/listeners/empty_listener.py b/tests/listeners/empty_listener.py index 0b298e95fe6d8..0a69f9ec1fa59 100644 --- a/tests/listeners/empty_listener.py +++ b/tests/listeners/empty_listener.py @@ -21,7 +21,7 @@ @hookimpl -def on_task_instance_running(previous_state, task_instance, session): +def on_task_instance_running(previous_state, task_instance): pass diff --git a/tests/listeners/file_write_listener.py b/tests/listeners/file_write_listener.py index 0ca026da7165f..c542ccacab5b3 100644 --- a/tests/listeners/file_write_listener.py +++ b/tests/listeners/file_write_listener.py @@ -34,17 +34,15 @@ def write(self, line: str): f.write(line + "\n") @hookimpl - def on_task_instance_running(self, previous_state, task_instance, session): + def on_task_instance_running(self, previous_state, task_instance): self.write("on_task_instance_running") @hookimpl - def on_task_instance_success(self, previous_state, task_instance, session): + def on_task_instance_success(self, previous_state, task_instance): self.write("on_task_instance_success") @hookimpl - def on_task_instance_failed( - self, previous_state, task_instance, error: None | str | BaseException, session - ): + def on_task_instance_failed(self, previous_state, task_instance, error: None | str | BaseException): self.write("on_task_instance_failed") @hookimpl diff --git a/tests/listeners/full_listener.py b/tests/listeners/full_listener.py index 229fdab676254..50701c822e6d6 100644 --- a/tests/listeners/full_listener.py +++ b/tests/listeners/full_listener.py @@ -40,17 +40,17 @@ def before_stopping(component): @hookimpl -def on_task_instance_running(previous_state, task_instance, session): +def on_task_instance_running(previous_state, task_instance): state.append(TaskInstanceState.RUNNING) @hookimpl -def on_task_instance_success(previous_state, task_instance, session): +def on_task_instance_success(previous_state, task_instance): state.append(TaskInstanceState.SUCCESS) @hookimpl -def on_task_instance_failed(previous_state, task_instance, error: None | str | BaseException, session): +def on_task_instance_failed(previous_state, task_instance, error: None | str | BaseException): state.append(TaskInstanceState.FAILED) diff --git a/tests/listeners/partial_listener.py b/tests/listeners/partial_listener.py index b4027e287565b..2bf1d117745db 100644 --- a/tests/listeners/partial_listener.py +++ b/tests/listeners/partial_listener.py @@ -24,7 +24,7 @@ @hookimpl -def on_task_instance_running(previous_state, task_instance, session): +def on_task_instance_running(previous_state, task_instance): state.append(State.RUNNING) diff --git a/tests/listeners/slow_listener.py b/tests/listeners/slow_listener.py index b366aa4d0cb50..b585b19650a43 100644 --- a/tests/listeners/slow_listener.py +++ b/tests/listeners/slow_listener.py @@ -22,5 +22,5 @@ @hookimpl -def on_task_instance_success(previous_state, task_instance, session): +def on_task_instance_success(previous_state, task_instance): time.sleep(3) diff --git a/tests/listeners/throwing_listener.py b/tests/listeners/throwing_listener.py index ae7345d395af9..eeb7d0ee6edf7 100644 --- a/tests/listeners/throwing_listener.py +++ b/tests/listeners/throwing_listener.py @@ -21,7 +21,7 @@ @hookimpl -def on_task_instance_success(previous_state, task_instance, session): +def on_task_instance_success(previous_state, task_instance): raise RuntimeError() diff --git a/tests/listeners/very_slow_listener.py b/tests/listeners/very_slow_listener.py index 28df43a2b8b6b..9752b8787b758 100644 --- a/tests/listeners/very_slow_listener.py +++ b/tests/listeners/very_slow_listener.py @@ -22,5 +22,5 @@ @hookimpl -def on_task_instance_success(previous_state, task_instance, session): +def on_task_instance_success(previous_state, task_instance): time.sleep(10) diff --git a/tests/listeners/xcom_listener.py b/tests/listeners/xcom_listener.py index a7ffc19178589..bbfbbba4e6563 100644 --- a/tests/listeners/xcom_listener.py +++ b/tests/listeners/xcom_listener.py @@ -30,13 +30,13 @@ def write(self, line: str): f.write(line + "\n") @hookimpl - def on_task_instance_running(self, previous_state, task_instance, session): + def on_task_instance_running(self, previous_state, task_instance): task_instance.xcom_push(key="listener", value="listener") task_instance.xcom_pull(task_ids=task_instance.task_id, key="listener") self.write("on_task_instance_running") @hookimpl - def on_task_instance_success(self, previous_state, task_instance, session): + def on_task_instance_success(self, previous_state, task_instance): read = task_instance.xcom_pull(task_ids=self.task_id, key="listener") self.write("on_task_instance_success") self.write(read) diff --git a/tests/plugins/test_plugins_manager.py b/tests/plugins/test_plugins_manager.py index 8618af26254bc..54ab88d2f60e0 100644 --- a/tests/plugins/test_plugins_manager.py +++ b/tests/plugins/test_plugins_manager.py @@ -358,19 +358,22 @@ class MacroPlugin(AirflowPlugin): def test_registering_plugin_listeners(self): from airflow import plugins_manager - with mock.patch("airflow.plugins_manager.plugins", []): - plugins_manager.load_plugins_from_plugin_directory() - plugins_manager.integrate_listener_plugins(get_listener_manager()) - - assert get_listener_manager().has_listeners - listeners = get_listener_manager().pm.get_plugins() - listener_names = [el.__name__ if inspect.ismodule(el) else qualname(el) for el in listeners] - # sort names as order of listeners is not guaranteed - assert sorted(listener_names) == [ - "airflow.example_dags.plugins.event_listener", - "tests.listeners.class_listener.ClassBasedListener", - "tests.listeners.empty_listener", - ] + try: + with mock.patch("airflow.plugins_manager.plugins", []): + plugins_manager.load_plugins_from_plugin_directory() + plugins_manager.integrate_listener_plugins(get_listener_manager()) + + assert get_listener_manager().has_listeners + listeners = get_listener_manager().pm.get_plugins() + listener_names = [el.__name__ if inspect.ismodule(el) else qualname(el) for el in listeners] + # sort names as order of listeners is not guaranteed + assert sorted(listener_names) == [ + "airflow.example_dags.plugins.event_listener", + "tests.listeners.class_listener.ClassBasedListener", + "tests.listeners.empty_listener", + ] + finally: + get_listener_manager().clear() def test_should_import_plugin_from_providers(self): from airflow import plugins_manager