Skip to content

Commit

Permalink
task sdk: call on_task_instance_* listeners
Browse files Browse the repository at this point in the history
Signed-off-by: Maciej Obuchowski <[email protected]>
  • Loading branch information
mobuchowski committed Jan 2, 2025
1 parent 48a5a0a commit 81b886b
Show file tree
Hide file tree
Showing 33 changed files with 523 additions and 235 deletions.
3 changes: 2 additions & 1 deletion airflow/api_fastapi/execution_api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,9 @@ def custom_openapi() -> dict:

def get_extra_schemas() -> dict[str, dict]:
"""Get all the extra schemas that are not part of the main FastAPI app."""
from airflow.api_fastapi.execution_api.datamodels import taskinstance
from airflow.api_fastapi.execution_api.datamodels import dagrun, taskinstance

return {
"TaskInstance": taskinstance.TaskInstance.model_json_schema(),
"DagRun": dagrun.DagRun.model_json_schema(),
}
35 changes: 35 additions & 0 deletions airflow/api_fastapi/execution_api/datamodels/dagrun.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# This model is not used in the API, but it is included in generated OpenAPI schema
# for use in the client SDKs.
from __future__ import annotations

from airflow.api_fastapi.common.types import UtcDateTime
from airflow.api_fastapi.core_api.base import BaseModel


class DagRun(BaseModel):
"""Schema for TaskInstance model with minimal required fields needed for OL for now."""

id: int
dag_id: str
run_id: str
logical_date: UtcDateTime
data_interval_start: UtcDateTime
data_interval_end: UtcDateTime
clear_number: int
5 changes: 5 additions & 0 deletions airflow/api_fastapi/execution_api/datamodels/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ class TaskInstance(BaseModel):
run_id: str
try_number: int
map_index: int | None = None
start_date: UtcDateTime


class DagRun(BaseModel):
Expand All @@ -180,6 +181,7 @@ class DagRun(BaseModel):
data_interval_end: UtcDateTime | None
start_date: UtcDateTime
end_date: UtcDateTime | None
clear_number: int
run_type: DagRunType
conf: Annotated[dict[str, Any], Field(default_factory=dict)]

Expand All @@ -190,6 +192,9 @@ class TIRunContext(BaseModel):
dag_run: DagRun
"""DAG run information for the task instance."""

task_reschedule_count: Annotated[int, Field(default=0)]
"""How many times the task has been rescheduled."""

variables: Annotated[list[VariableResponse], Field(default_factory=list)]
"""Variables that can be accessed by the task instance."""

Expand Down
23 changes: 20 additions & 3 deletions airflow/api_fastapi/execution_api/routes/task_instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

from fastapi import Body, HTTPException, status
from pydantic import JsonValue
from sqlalchemy import update
from sqlalchemy import func, update
from sqlalchemy.exc import NoResultFound, SQLAlchemyError
from sqlalchemy.sql import select

Expand Down Expand Up @@ -73,9 +73,9 @@ def ti_run(
# We only use UUID above for validation purposes
ti_id_str = str(task_instance_id)

old = select(TI.state, TI.dag_id, TI.run_id).where(TI.id == ti_id_str).with_for_update()
old = select(TI.state, TI.dag_id, TI.run_id, TI.try_number).where(TI.id == ti_id_str).with_for_update()
try:
(previous_state, dag_id, run_id) = session.execute(old).one()
(previous_state, dag_id, run_id, try_number) = session.execute(old).one()
except NoResultFound:
log.error("Task Instance %s not found", ti_id_str)
raise HTTPException(
Expand Down Expand Up @@ -135,6 +135,7 @@ def ti_run(
DR.data_interval_end,
DR.start_date,
DR.end_date,
DR.clear_number,
DR.run_type,
DR.conf,
DR.logical_date,
Expand All @@ -144,8 +145,24 @@ def ti_run(
if not dr:
raise ValueError(f"DagRun with dag_id={dag_id} and run_id={run_id} not found.")

task_reschedule_count = (
session.query(
func.count(TaskReschedule.id) # or any other primary key column
)
.filter(
TaskReschedule.dag_id == dag_id,
TaskReschedule.task_id == ti_id_str,
TaskReschedule.run_id == run_id,
# TaskReschedule.map_index == ti.map_index, # TODO: Handle mapped tasks
TaskReschedule.try_number == try_number,
)
.scalar()
or 0
)

return TIRunContext(
dag_run=DagRun.model_validate(dr, from_attributes=True),
task_reschedule_count=task_reschedule_count,
# TODO: Add variables and connections that are needed (and has perms) for the task
variables=[],
connections=[],
Expand Down
2 changes: 1 addition & 1 deletion airflow/dag_processing/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def _execute_callbacks(
dagbag: DagBag, callback_requests: list[CallbackRequest], log: FilteringBoundLogger
) -> None:
for request in callback_requests:
log.debug("Processing Callback Request", request=request)
log.debug("Processing Callback Request", request=request.to_json())
if isinstance(request, TaskCallbackRequest):
raise NotImplementedError(
"Haven't coded Task callback yet - https://github.com/apache/airflow/issues/44354!"
Expand Down
44 changes: 15 additions & 29 deletions airflow/example_dags/plugins/event_listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,13 @@

if TYPE_CHECKING:
from airflow.models.dagrun import DagRun
from airflow.models.taskinstance import TaskInstance
from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
from airflow.utils.state import TaskInstanceState


# [START howto_listen_ti_running_task]
@hookimpl
def on_task_instance_running(previous_state: TaskInstanceState, task_instance: TaskInstance, session):
def on_task_instance_running(previous_state: TaskInstanceState, task_instance: RuntimeTaskInstance):
"""
This method is called when task state changes to RUNNING.
Through callback, parameters like previous_task_state, task_instance object can be accessed.
Expand All @@ -39,14 +39,11 @@ def on_task_instance_running(previous_state: TaskInstanceState, task_instance: T
print("Task instance is in running state")
print(" Previous state of the Task instance:", previous_state)

state: TaskInstanceState = task_instance.state
name: str = task_instance.task_id
start_date = task_instance.start_date

dagrun = task_instance.dag_run
dagrun_status = dagrun.state
context = task_instance.get_template_context()

task = task_instance.task
task = context["task"]

if TYPE_CHECKING:
assert task
Expand All @@ -55,16 +52,16 @@ def on_task_instance_running(previous_state: TaskInstanceState, task_instance: T
dag_name = None
if dag:
dag_name = dag.dag_id
print(f"Current task name:{name} state:{state} start_date:{start_date}")
print(f"Dag name:{dag_name} and current dag run status:{dagrun_status}")
print(f"Current task name:{name}")
print(f"Dag name:{dag_name}")


# [END howto_listen_ti_running_task]


# [START howto_listen_ti_success_task]
@hookimpl
def on_task_instance_success(previous_state: TaskInstanceState, task_instance: TaskInstance, session):
def on_task_instance_success(previous_state: TaskInstanceState, task_instance: RuntimeTaskInstance):
"""
This method is called when task state changes to SUCCESS.
Through callback, parameters like previous_task_state, task_instance object can be accessed.
Expand All @@ -74,14 +71,10 @@ def on_task_instance_success(previous_state: TaskInstanceState, task_instance: T
print("Task instance in success state")
print(" Previous state of the Task instance:", previous_state)

dag_id = task_instance.dag_id
hostname = task_instance.hostname
operator = task_instance.operator
context = task_instance.get_template_context()
operator = context["task"]

dagrun = task_instance.dag_run
queued_at = dagrun.queued_at
print(f"Dag name:{dag_id} queued_at:{queued_at}")
print(f"Task hostname:{hostname} operator:{operator}")
print(f"Task operator:{operator}")


# [END howto_listen_ti_success_task]
Expand All @@ -90,7 +83,7 @@ def on_task_instance_success(previous_state: TaskInstanceState, task_instance: T
# [START howto_listen_ti_failure_task]
@hookimpl
def on_task_instance_failed(
previous_state: TaskInstanceState, task_instance: TaskInstance, error: None | str | BaseException, session
previous_state: TaskInstanceState, task_instance: RuntimeTaskInstance, error: None | str | BaseException
):
"""
This method is called when task state changes to FAILED.
Expand All @@ -100,21 +93,14 @@ def on_task_instance_failed(
"""
print("Task instance in failure state")

start_date = task_instance.start_date
end_date = task_instance.end_date
duration = task_instance.duration

dagrun = task_instance.dag_run

task = task_instance.task
context = task_instance.get_template_context()
task = context["task"]

if TYPE_CHECKING:
assert task

dag = task.dag

print(f"Task start:{start_date} end:{end_date} duration:{duration}")
print(f"Task:{task} dag:{dag} dagrun:{dagrun}")
print("Task start")
print(f"Task:{task}")
if error:
print(f"Failure caused by {error}")

Expand Down
14 changes: 14 additions & 0 deletions airflow/executors/workloads.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import os
import uuid
from datetime import datetime
from pathlib import Path
from typing import TYPE_CHECKING, Literal, Union

Expand Down Expand Up @@ -49,6 +50,7 @@ class TaskInstance(BaseModel):
run_id: str
try_number: int
map_index: int | None = None
start_date: datetime

# TODO: Task-SDK: Can we replace TastInstanceKey with just the uuid across the codebase?
@property
Expand All @@ -64,6 +66,15 @@ def key(self) -> TaskInstanceKey:
)


class DagRun(BaseModel):
id: int
dag_id: str
run_id: str
logical_date: datetime
data_interval_start: datetime
data_interval_end: datetime


class ExecuteTask(BaseActivity):
"""Execute the given Task."""

Expand All @@ -83,6 +94,9 @@ def make(cls, ti: TIModel, dag_path: Path | None = None) -> ExecuteTask:

from airflow.utils.helpers import log_filename_template_renderer

if not ti.start_date:
ti.start_date = datetime.now()

ser_ti = TaskInstance.model_validate(ti, from_attributes=True)

dag_path = dag_path or Path(ti.dag_run.dag_model.relative_fileloc)
Expand Down
8 changes: 7 additions & 1 deletion airflow/listeners/listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,13 @@ class ListenerManager:
"""Manage listener registration and provides hook property for calling them."""

def __init__(self):
from airflow.listeners.spec import asset, dagrun, importerrors, lifecycle, taskinstance
from airflow.listeners.spec import (
asset,
dagrun,
importerrors,
lifecycle,
taskinstance,
)

self.pm = pluggy.PluginManager("airflow")
self.pm.add_hookcall_monitoring(_before_hookcall, _after_hookcall)
Expand Down
15 changes: 4 additions & 11 deletions airflow/listeners/spec/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,33 +22,26 @@
from pluggy import HookspecMarker

if TYPE_CHECKING:
from sqlalchemy.orm.session import Session

from airflow.models.taskinstance import TaskInstance
from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
from airflow.utils.state import TaskInstanceState

hookspec = HookspecMarker("airflow")


@hookspec
def on_task_instance_running(
previous_state: TaskInstanceState | None, task_instance: TaskInstance, session: Session | None
):
def on_task_instance_running(previous_state: TaskInstanceState | None, task_instance: RuntimeTaskInstance):
"""Execute when task state changes to RUNNING. previous_state can be None."""


@hookspec
def on_task_instance_success(
previous_state: TaskInstanceState | None, task_instance: TaskInstance, session: Session | None
):
def on_task_instance_success(previous_state: TaskInstanceState | None, task_instance: RuntimeTaskInstance):
"""Execute when task state changes to SUCCESS. previous_state can be None."""


@hookspec
def on_task_instance_failed(
previous_state: TaskInstanceState | None,
task_instance: TaskInstance,
task_instance: RuntimeTaskInstance,
error: None | str | BaseException,
session: Session | None,
):
"""Execute when task state changes to FAIL. previous_state can be None."""
8 changes: 4 additions & 4 deletions airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,12 +355,12 @@ def _run_raw_task(
if not test_mode:
_add_log(event=ti.state, task_instance=ti, session=session)
if ti.state == TaskInstanceState.SUCCESS:
ti._register_asset_changes(events=context["outlet_events"], session=session)
ti._register_asset_changes(events=context["outlet_events"])

TaskInstance.save_to_db(ti=ti, session=session)
if ti.state == TaskInstanceState.SUCCESS:
get_listener_manager().hook.on_task_instance_success(
previous_state=TaskInstanceState.RUNNING, task_instance=ti, session=session
previous_state=TaskInstanceState.RUNNING, task_instance=ti
)

return None
Expand Down Expand Up @@ -2890,7 +2890,7 @@ def signal_handler(signum, frame):

# Run on_task_instance_running event
get_listener_manager().hook.on_task_instance_running(
previous_state=TaskInstanceState.QUEUED, task_instance=self, session=session
previous_state=TaskInstanceState.QUEUED, task_instance=self
)

def _render_map_index(context: Context, *, jinja_env: jinja2.Environment | None) -> str | None:
Expand Down Expand Up @@ -3132,7 +3132,7 @@ def fetch_handle_failure_context(
callbacks = task.on_retry_callback if task else None

get_listener_manager().hook.on_task_instance_failed(
previous_state=TaskInstanceState.RUNNING, task_instance=ti, error=error, session=session
previous_state=TaskInstanceState.RUNNING, task_instance=ti, error=error
)

return {
Expand Down
Loading

0 comments on commit 81b886b

Please sign in to comment.