Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from typing import Literal

from airflow.api_fastapi.core_api.base import StrictBaseModel

CallbackTerminalState = Literal["success", "failed"]


class CallbackTerminalStatePayload(StrictBaseModel):
"""Payload for transitioning a callback from RUNNING to a terminal state."""

state: CallbackTerminalState
output: str | None = None
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
asset_events,
asset_state,
assets,
callbacks,
connections,
dag_runs,
dags,
Expand All @@ -44,6 +45,7 @@

authenticated_router.include_router(assets.router, prefix="/assets", tags=["Assets"])
authenticated_router.include_router(asset_events.router, prefix="/asset-events", tags=["Asset Events"])
authenticated_router.include_router(callbacks.router, prefix="/callbacks", tags=["Callbacks"])
authenticated_router.include_router(connections.router, prefix="/connections", tags=["Connections"])
authenticated_router.include_router(dag_runs.router, prefix="/dag-runs", tags=["Dag Runs"])
authenticated_router.include_router(dags.router, prefix="/dags", tags=["Dags"])
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from typing import Annotated
from uuid import UUID

import structlog
from cadwyn import VersionedAPIRouter
from fastapi import Body, HTTPException, Response, Security, status
from structlog.contextvars import bind_contextvars

from airflow.api_fastapi.auth.tokens import JWTGenerator
from airflow.api_fastapi.common.db.common import SessionDep
from airflow.api_fastapi.execution_api.datamodels.callback import CallbackTerminalStatePayload
from airflow.api_fastapi.execution_api.datamodels.token import TIToken
from airflow.api_fastapi.execution_api.deps import DepContainer
from airflow.api_fastapi.execution_api.security import CurrentTIToken, ExecutionAPIRoute, require_auth
from airflow.models.callback import Callback
from airflow.utils.state import CallbackState

log = structlog.get_logger(__name__)

router = VersionedAPIRouter(route_class=ExecutionAPIRoute)


def _require_self(token: TIToken, callback_id: UUID) -> None:
"""Mirror the ``ti:self`` enforcement from security.py for callback routes."""
if str(token.id) != str(callback_id):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Token subject does not match callback id",
)


@router.post(
"/{callback_id}/run",
status_code=status.HTTP_204_NO_CONTENT,
dependencies=[Security(require_auth, scopes=["token:execution", "token:workload"])],
responses={
status.HTTP_403_FORBIDDEN: {"description": "Token subject does not match callback id"},
status.HTTP_404_NOT_FOUND: {"description": "Callback not found"},
status.HTTP_409_CONFLICT: {"description": "Callback is not in a state that can be marked running"},
},
)
def callback_run(
callback_id: UUID,
response: Response,
session: SessionDep,
services=DepContainer,
token: TIToken = CurrentTIToken,
) -> Response:
"""
Mark a callback as RUNNING.

Mirrors ``PATCH /task-instances/{id}/run``: this is the single endpoint that
accepts a workload-scoped token and atomically (a) transitions the callback
from QUEUED to RUNNING and (b) issues a fresh execution-scoped token via the
``Refreshed-API-Token`` response header. All subsequent supervisor calls hit
execution-only routes.
"""
bind_contextvars(callback_id=str(callback_id))
_require_self(token, callback_id)

callback = session.get(Callback, callback_id)
if callback is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail={"reason": "not_found", "message": "Callback not found"},
)

# Allow QUEUED → RUNNING transition; treat RUNNING as idempotent so a retried
# supervisor start does not 409. Anything else (PENDING / SCHEDULED / terminal) rejects.
if callback.state == CallbackState.RUNNING:
log.info("Duplicate start request received from %s", callback.id)
elif callback.state == CallbackState.QUEUED:
callback.state = CallbackState.RUNNING
session.add(callback)
else:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail={
"reason": "invalid_state",
"message": "Callback was not in a state where it could be marked running",
"current_state": callback.state,
},
)

if token.claims.scope == "workload":
generator: JWTGenerator = services.get(JWTGenerator)
execution_token = generator.generate(extras={"sub": str(callback_id), "scope": "execution"})
response.headers["Refreshed-API-Token"] = execution_token

response.status_code = status.HTTP_204_NO_CONTENT
return response


@router.patch(
"/{callback_id}/state",
status_code=status.HTTP_204_NO_CONTENT,
responses={
status.HTTP_403_FORBIDDEN: {"description": "Token subject does not match callback id"},
status.HTTP_404_NOT_FOUND: {"description": "Callback not found"},
status.HTTP_409_CONFLICT: {"description": "Callback is not in RUNNING state"},
},
)
def callback_update_state(
callback_id: UUID,
payload: Annotated[CallbackTerminalStatePayload, Body()],
session: SessionDep,
token: TIToken = CurrentTIToken,
) -> Response:
"""Mark a RUNNING callback as SUCCESS or FAILED."""
bind_contextvars(callback_id=str(callback_id))
_require_self(token, callback_id)

callback = session.get(Callback, callback_id)
if callback is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail={"reason": "not_found", "message": "Callback not found"},
)

if callback.state != CallbackState.RUNNING:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail={
"reason": "invalid_state",
"message": "Callback was not in RUNNING state",
"current_state": callback.state,
},
)

callback.state = CallbackState(payload.state)
if payload.output is not None:
callback.output = payload.output
session.add(callback)

return Response(status_code=status.HTTP_204_NO_CONTENT)
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,16 @@
RemoveUpstreamMapIndexesField,
)
from airflow.api_fastapi.execution_api.versions.v2026_04_17 import AddStateEndpoints, AddTeamNameField
from airflow.api_fastapi.execution_api.versions.v2026_04_30 import AddCallbackEndpoints
from airflow.api_fastapi.execution_api.versions.v2026_06_16 import AddRetryPolicyFields

bundle = VersionBundle(
HeadVersion(),
Version("2026-06-16", AddRetryPolicyFields),
Version(
"2026-04-30",
AddCallbackEndpoints,
),
Version(
"2026-04-17",
AddTeamNameField,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

from cadwyn import VersionChange, endpoint


class AddCallbackEndpoints(VersionChange):
"""Add the ``POST /callbacks/{callback_id}/run`` and ``PATCH /callbacks/{callback_id}/state`` endpoints."""

description = __doc__

instructions_to_migrate_to_previous_version = (
endpoint("/callbacks/{callback_id}/run", ["POST"]).didnt_exist,
endpoint("/callbacks/{callback_id}/state", ["PATCH"]).didnt_exist,
)
27 changes: 20 additions & 7 deletions airflow-core/src/airflow/jobs/scheduler_job_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1249,7 +1249,7 @@ def process_executor_events(
if state in (CallbackState.RUNNING, CallbackState.FAILED, CallbackState.SUCCESS):
callback_keys_with_events.append(key)

# Handle callback state events
# Handle callback state events.
for callback_id in callback_keys_with_events:
state, info = event_buffer.pop(callback_id)
callback = session.get(Callback, callback_id)
Expand All @@ -1261,17 +1261,30 @@ def process_executor_events(
)
continue

# Callback state transitions are now driven by the supervisor through
# the Execution API (POST /callbacks/{id}/run, PATCH /callbacks/{id}/state).
# The in-process events from the executor are kept as a fallback safety
# net for cases where the supervisor crashed before reporting a terminal state

need_to_modify = False

if state == CallbackState.RUNNING:
callback.state = CallbackState.RUNNING
cls.logger().info("Callback %s is currently running", callback_id)
elif state == CallbackState.SUCCESS:
callback.state = CallbackState.SUCCESS
cls.logger().info("Callback %s completed successfully", callback_id)
if callback.state == CallbackState.RUNNING:
callback.state = CallbackState.SUCCESS
need_to_modify = True
elif state == CallbackState.FAILED:
callback.state = CallbackState.FAILED
callback.output = str(info) if info else "Execution failed"
cls.logger().error("Callback %s failed: %s", callback_id, callback.output)
session.add(callback)
callback_output = str(info) if info else "Execution failed"
cls.logger().error("Callback %s failed: %s", callback_id, callback_output)
if callback.state == CallbackState.RUNNING:
callback.state = CallbackState.FAILED
callback.output = callback_output
need_to_modify = True

if need_to_modify:
session.add(callback)

# Return if no finished tasks
if not tis_with_right_state:
Expand Down
Loading
Loading