Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions airflow-core/src/airflow/api_fastapi/execution_api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
)
from fastapi import FastAPI, Request, Response
from fastapi.responses import JSONResponse
from sqlalchemy.exc import SQLAlchemyError
from starlette.middleware.base import BaseHTTPMiddleware

from airflow.api_fastapi.auth.tokens import (
Expand Down Expand Up @@ -271,6 +272,19 @@ def handle_exceptions(request: Request, exc: Exception):
content["correlation-id"] = correlation_id
return JSONResponse(status_code=500, content=content)

@app.exception_handler(SQLAlchemyError)
def handle_database_exceptions(request: Request, exc: SQLAlchemyError):
    """
    Log a ``SQLAlchemyError`` and answer with a generic HTTP 500 JSON body.

    The response body only carries a fixed ``detail`` message; the exception
    text itself is written to the log and is not included in the response.
    When the incoming request has a ``correlation-id`` header, its value is
    echoed back in the body under the same key.

    :param request: The request that was being handled when the error occurred.
    :param exc: The database exception to log.
    :return: A ``JSONResponse`` with status code 500.
    """
    # The exception arrives as an argument, so its type/value/traceback are
    # handed to the logger explicitly to get the traceback of ``exc`` logged.
    logger.exception(
        "Database error handling request",
        path=request.url.path,
        method=request.method,
        exc_info=(type(exc), exc, exc.__traceback__),
    )

    payload: dict[str, str] = {"detail": "Database error occurred"}

    # Only echo the correlation id when the header is present and non-empty.
    correlation_id = request.headers.get("correlation-id")
    if correlation_id:
        payload["correlation-id"] = correlation_id

    return JSONResponse(status_code=500, content=payload)

return app


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from pydantic import JsonValue
from sqlalchemy import and_, func, or_, tuple_, update
from sqlalchemy.engine import CursorResult
from sqlalchemy.exc import NoResultFound, SQLAlchemyError
from sqlalchemy.exc import NoResultFound
from sqlalchemy.orm import joinedload
from sqlalchemy.sql import select
from structlog.contextvars import bind_contextvars
Expand Down Expand Up @@ -242,74 +242,66 @@ def ti_run(
retry_reason=None,
)

try:
result = session.execute(query)
log.info("Task instance state updated", rows_affected=getattr(result, "rowcount", 0))

dr = (
session.scalars(
select(DR)
.filter_by(dag_id=ti.dag_id, run_id=ti.run_id)
.options(joinedload(DR.consumed_asset_events))
)
.unique()
.one_or_none()
)
result = session.execute(query)
log.info("Task instance state updated", rows_affected=getattr(result, "rowcount", 0))

if not dr:
log.error("DagRun not found", dag_id=ti.dag_id, run_id=ti.run_id)
raise ValueError(f"DagRun with dag_id={ti.dag_id} and run_id={ti.run_id} not found.")

# Send the keys to the SDK so that the client requests to clear those XComs from the server.
# The reason we cannot do this here in the server is because we need to issue a purge on custom XCom backends
# too. With the current assumption, the workers ONLY have access to the custom XCom backends directly and they
# can issue the purge.

# However, do not clear it for deferral
xcom_keys = []
if not ti.next_method:
map_index = None if ti.map_index < 0 else ti.map_index
xcom_query = select(XComModel.key).where(
XComModel.dag_id == ti.dag_id,
XComModel.task_id == ti.task_id,
XComModel.run_id == ti.run_id,
)
if map_index is not None:
xcom_query = xcom_query.where(XComModel.map_index == map_index)
dr = (
session.scalars(
select(DR)
.filter_by(dag_id=ti.dag_id, run_id=ti.run_id)
.options(joinedload(DR.consumed_asset_events))
)
.unique()
.one_or_none()
)

xcom_keys = list(session.scalars(xcom_query))
task_reschedule_count = (
session.scalar(
select(func.count(TaskReschedule.id)).where(TaskReschedule.ti_id == task_instance_id)
)
or 0
if not dr:
log.error("DagRun not found", dag_id=ti.dag_id, run_id=ti.run_id)
raise ValueError(f"DagRun with dag_id={ti.dag_id} and run_id={ti.run_id} not found.")

# Send the keys to the SDK so that the client requests to clear those XComs from the server.
# The reason we cannot do this here in the server is because we need to issue a purge on custom XCom backends
# too. With the current assumption, the workers ONLY have access to the custom XCom backends directly and they
# can issue the purge.

# However, do not clear it for deferral
xcom_keys = []
if not ti.next_method:
map_index = None if ti.map_index < 0 else ti.map_index
xcom_query = select(XComModel.key).where(
XComModel.dag_id == ti.dag_id,
XComModel.task_id == ti.task_id,
XComModel.run_id == ti.run_id,
)
if map_index is not None:
xcom_query = xcom_query.where(XComModel.map_index == map_index)

from airflow.api_fastapi.execution_api.security import get_team_name_for_ti
xcom_keys = list(session.scalars(xcom_query))
task_reschedule_count = (
session.scalar(select(func.count(TaskReschedule.id)).where(TaskReschedule.ti_id == task_instance_id))
or 0
)

dr.team_name = get_team_name_for_ti(task_instance_id, session)
from airflow.api_fastapi.execution_api.security import get_team_name_for_ti

context = TIRunContext(
dag_run=dr,
task_reschedule_count=task_reschedule_count,
max_tries=ti.max_tries,
# TODO: Add variables and connections that are needed (and has perms) for the task
variables=[],
connections=[],
xcom_keys_to_clear=xcom_keys,
should_retry=_is_eligible_to_retry(previous_state, ti.try_number, ti.max_tries),
)
dr.team_name = get_team_name_for_ti(task_instance_id, session)

# Only set if they are non-null
if ti.next_method:
context.next_method = ti.next_method
context.next_kwargs = ti.next_kwargs
context.start_date = ti.start_date
except SQLAlchemyError:
log.exception("Error marking Task Instance state as running")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Database error occurred"
)
context = TIRunContext(
dag_run=dr,
task_reschedule_count=task_reschedule_count,
max_tries=ti.max_tries,
# TODO: Add variables and connections that are needed (and has perms) for the task
variables=[],
connections=[],
xcom_keys_to_clear=xcom_keys,
should_retry=_is_eligible_to_retry(previous_state, ti.try_number, ti.max_tries),
)

# Only set if they are non-null
if ti.next_method:
context.next_method = ti.next_method
context.next_kwargs = ti.next_kwargs
context.start_date = ti.start_date

# JWTReissueMiddleware also writes Refreshed-API-Token but skips workload tokens, so we set it here for the workload→execution swap.
if token.claims.scope == "workload":
Expand Down Expand Up @@ -435,33 +427,25 @@ def ti_update_state(
if ti is not None:
_handle_fail_fast_for_dag(ti=ti, dag_id=dag_id, session=session, dag_bag=dag_bag)

# TODO: Replace this with FastAPI's Custom Exception handling:
# https://fastapi.tiangolo.com/tutorial/handling-errors/#install-custom-exception-handlers
try:
result = session.execute(query)
log.info(
"Task instance state updated",
new_state=updated_state,
rows_affected=getattr(result, "rowcount", 0),
)
session.add(
Log(
event=updated_state.value,
task_id=task_id,
dag_id=dag_id,
run_id=run_id,
map_index=map_index,
try_number=try_number,
logical_date=logical_date,
owner=owners,
extra=json.dumps({"host_name": hostname}) if hostname else None,
)
)
except SQLAlchemyError as e:
log.error("Error updating Task Instance state", error=str(e))
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Database error occurred"
result = session.execute(query)
log.info(
"Task instance state updated",
new_state=updated_state,
rows_affected=getattr(result, "rowcount", 0),
)
session.add(
Log(
event=updated_state.value,
task_id=task_id,
dag_id=dag_id,
run_id=run_id,
map_index=map_index,
try_number=try_number,
logical_date=logical_date,
owner=owners,
extra=json.dumps({"host_name": hostname}) if hostname else None,
)
)

if updated_state == TaskInstanceState.SUCCESS:
if conf.getboolean("state_store", "clear_on_success"):
Expand Down
Loading