diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/callback.py b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/callback.py new file mode 100644 index 0000000000000..3e637e89adca8 --- /dev/null +++ b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/callback.py @@ -0,0 +1,30 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +from typing import Literal + +from airflow.api_fastapi.core_api.base import StrictBaseModel + +CallbackTerminalState = Literal["success", "failed"] + + +class CallbackTerminalStatePayload(StrictBaseModel): + """Payload for transitioning a callback from RUNNING to a terminal state.""" + + state: CallbackTerminalState + output: str | None = None diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/__init__.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/__init__.py index 06f07aee82389..b4bda17fb4d1a 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/__init__.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/__init__.py @@ -23,6 +23,7 @@ asset_events, asset_state, assets, + callbacks, connections, dag_runs, dags, @@ -44,6 +45,7 @@ authenticated_router.include_router(assets.router, prefix="/assets", tags=["Assets"]) authenticated_router.include_router(asset_events.router, prefix="/asset-events", tags=["Asset Events"]) +authenticated_router.include_router(callbacks.router, prefix="/callbacks", tags=["Callbacks"]) authenticated_router.include_router(connections.router, prefix="/connections", tags=["Connections"]) authenticated_router.include_router(dag_runs.router, prefix="/dag-runs", tags=["Dag Runs"]) authenticated_router.include_router(dags.router, prefix="/dags", tags=["Dags"]) diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/callbacks.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/callbacks.py new file mode 100644 index 0000000000000..6e018a373fee0 --- /dev/null +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/callbacks.py @@ -0,0 +1,153 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import Annotated +from uuid import UUID + +import structlog +from cadwyn import VersionedAPIRouter +from fastapi import Body, HTTPException, Response, Security, status +from structlog.contextvars import bind_contextvars + +from airflow.api_fastapi.auth.tokens import JWTGenerator +from airflow.api_fastapi.common.db.common import SessionDep +from airflow.api_fastapi.execution_api.datamodels.callback import CallbackTerminalStatePayload +from airflow.api_fastapi.execution_api.datamodels.token import TIToken +from airflow.api_fastapi.execution_api.deps import DepContainer +from airflow.api_fastapi.execution_api.security import CurrentTIToken, ExecutionAPIRoute, require_auth +from airflow.models.callback import Callback +from airflow.utils.state import CallbackState + +log = structlog.get_logger(__name__) + +router = VersionedAPIRouter(route_class=ExecutionAPIRoute) + + +def _require_self(token: TIToken, callback_id: UUID) -> None: + """Mirror the ``ti:self`` enforcement from security.py for callback routes.""" + if str(token.id) != str(callback_id): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Token subject does not match callback id", + ) + + +@router.post( + "/{callback_id}/run", + status_code=status.HTTP_204_NO_CONTENT, + dependencies=[Security(require_auth, scopes=["token:execution", "token:workload"])], + responses={ + 
status.HTTP_403_FORBIDDEN: {"description": "Token subject does not match callback id"}, + status.HTTP_404_NOT_FOUND: {"description": "Callback not found"}, + status.HTTP_409_CONFLICT: {"description": "Callback is not in a state that can be marked running"}, + }, +) +def callback_run( + callback_id: UUID, + response: Response, + session: SessionDep, + services=DepContainer, + token: TIToken = CurrentTIToken, +) -> Response: + """ + Mark a callback as RUNNING. + + Mirrors ``PATCH /task-instances/{id}/run``: this is the single endpoint that + accepts a workload-scoped token and atomically (a) transitions the callback + from QUEUED to RUNNING and (b) issues a fresh execution-scoped token via the + ``Refreshed-API-Token`` response header. All subsequent supervisor calls hit + execution-only routes. + """ + bind_contextvars(callback_id=str(callback_id)) + _require_self(token, callback_id) + + callback = session.get(Callback, callback_id) + if callback is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail={"reason": "not_found", "message": "Callback not found"}, + ) + + # Allow QUEUED → RUNNING transition; treat RUNNING as idempotent so a retried + # supervisor start does not 409. Any other state (PENDING / SCHEDULED / terminal) is rejected. 
+ if callback.state == CallbackState.RUNNING: + log.info("Duplicate start request received from %s", callback.id) + elif callback.state == CallbackState.QUEUED: + callback.state = CallbackState.RUNNING + session.add(callback) + else: + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, + detail={ + "reason": "invalid_state", + "message": "Callback was not in a state where it could be marked running", + "current_state": callback.state, + }, + ) + + if token.claims.scope == "workload": + generator: JWTGenerator = services.get(JWTGenerator) + execution_token = generator.generate(extras={"sub": str(callback_id), "scope": "execution"}) + response.headers["Refreshed-API-Token"] = execution_token + + response.status_code = status.HTTP_204_NO_CONTENT + return response + + +@router.patch( + "/{callback_id}/state", + status_code=status.HTTP_204_NO_CONTENT, + responses={ + status.HTTP_403_FORBIDDEN: {"description": "Token subject does not match callback id"}, + status.HTTP_404_NOT_FOUND: {"description": "Callback not found"}, + status.HTTP_409_CONFLICT: {"description": "Callback is not in RUNNING state"}, + }, +) +def callback_update_state( + callback_id: UUID, + payload: Annotated[CallbackTerminalStatePayload, Body()], + session: SessionDep, + token: TIToken = CurrentTIToken, +) -> Response: + """Mark a RUNNING callback as SUCCESS or FAILED.""" + bind_contextvars(callback_id=str(callback_id)) + _require_self(token, callback_id) + + callback = session.get(Callback, callback_id) + if callback is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail={"reason": "not_found", "message": "Callback not found"}, + ) + + if callback.state != CallbackState.RUNNING: + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, + detail={ + "reason": "invalid_state", + "message": "Callback was not in RUNNING state", + "current_state": callback.state, + }, + ) + + callback.state = CallbackState(payload.state) + if payload.output is not None: + 
callback.output = payload.output + session.add(callback) + + return Response(status_code=status.HTTP_204_NO_CONTENT) diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py b/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py index dfa27f53ebd91..33ce51f3c084d 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py @@ -41,11 +41,16 @@ RemoveUpstreamMapIndexesField, ) from airflow.api_fastapi.execution_api.versions.v2026_04_17 import AddStateEndpoints, AddTeamNameField +from airflow.api_fastapi.execution_api.versions.v2026_04_30 import AddCallbackEndpoints from airflow.api_fastapi.execution_api.versions.v2026_06_16 import AddRetryPolicyFields bundle = VersionBundle( HeadVersion(), Version("2026-06-16", AddRetryPolicyFields), + Version( + "2026-04-30", + AddCallbackEndpoints, + ), Version( "2026-04-17", AddTeamNameField, diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_04_30.py b/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_04_30.py new file mode 100644 index 0000000000000..50c33982a985b --- /dev/null +++ b/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_04_30.py @@ -0,0 +1,31 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from cadwyn import VersionChange, endpoint + + +class AddCallbackEndpoints(VersionChange): + """Add the ``POST /callbacks/{callback_id}/run`` and ``PATCH /callbacks/{callback_id}/state`` endpoints.""" + + description = __doc__ + + instructions_to_migrate_to_previous_version = ( + endpoint("/callbacks/{callback_id}/run", ["POST"]).didnt_exist, + endpoint("/callbacks/{callback_id}/state", ["PATCH"]).didnt_exist, + ) diff --git a/airflow-core/src/airflow/jobs/scheduler_job_runner.py b/airflow-core/src/airflow/jobs/scheduler_job_runner.py index 80897213c18b5..6f9779b1c17a1 100644 --- a/airflow-core/src/airflow/jobs/scheduler_job_runner.py +++ b/airflow-core/src/airflow/jobs/scheduler_job_runner.py @@ -1249,7 +1249,7 @@ def process_executor_events( if state in (CallbackState.RUNNING, CallbackState.FAILED, CallbackState.SUCCESS): callback_keys_with_events.append(key) - # Handle callback state events + # Handle callback state events. for callback_id in callback_keys_with_events: state, info = event_buffer.pop(callback_id) callback = session.get(Callback, callback_id) @@ -1261,17 +1261,30 @@ def process_executor_events( ) continue + # Callback state transitions are now driven by the supervisor through + # the Execution API (POST /callbacks/{id}/run, PATCH /callbacks/{id}/state). 
+ # The in-process events from the executor are kept as a fallback safety + # net for cases where the supervisor crashed before reporting a terminal state. + + need_to_modify = False + if state == CallbackState.RUNNING: - callback.state = CallbackState.RUNNING cls.logger().info("Callback %s is currently running", callback_id) elif state == CallbackState.SUCCESS: - callback.state = CallbackState.SUCCESS cls.logger().info("Callback %s completed successfully", callback_id) + if callback.state == CallbackState.RUNNING: + callback.state = CallbackState.SUCCESS + need_to_modify = True elif state == CallbackState.FAILED: - callback.state = CallbackState.FAILED - callback.output = str(info) if info else "Execution failed" - cls.logger().error("Callback %s failed: %s", callback_id, callback.output) - session.add(callback) + callback_output = str(info) if info else "Execution failed" + cls.logger().error("Callback %s failed: %s", callback_id, callback_output) + if callback.state == CallbackState.RUNNING: + callback.state = CallbackState.FAILED + callback.output = callback_output + need_to_modify = True + + if need_to_modify: + session.add(callback) # Return if no finished tasks if not tis_with_right_state: diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_callbacks.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_callbacks.py new file mode 100644 index 0000000000000..778ef9ef8ea14 --- /dev/null +++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_callbacks.py @@ -0,0 +1,264 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import Literal +from unittest import mock +from uuid import UUID, uuid4 + +import pytest +from fastapi import Request + +from airflow.api_fastapi.auth.tokens import JWTGenerator, JWTValidator +from airflow.api_fastapi.execution_api.app import lifespan +from airflow.api_fastapi.execution_api.datamodels.token import TIClaims, TIToken +from airflow.api_fastapi.execution_api.security import require_auth +from airflow.executors.workloads.callback import CallbackFetchMethod +from airflow.models.callback import ExecutorCallback +from airflow.utils.state import CallbackState + +from tests_common.test_utils.db import clear_db_callbacks + +pytestmark = pytest.mark.db_test + + +class _FakeCallbackDef: + """Minimal CallbackDefinitionProtocol stand-in for tests.""" + + path: str = "tests.fake.callback" + kwargs: dict = {} + executor: str | None = None + + def serialize(self) -> dict: + return {"path": self.path, "kwargs": self.kwargs, "executor": self.executor} + + +def _make_callback(state: CallbackState, session) -> ExecutorCallback: + cb = ExecutorCallback(callback_def=_FakeCallbackDef(), fetch_method=CallbackFetchMethod.IMPORT_PATH) + cb.state = state + session.add(cb) + session.commit() + return cb + + +def _override_require_auth(exec_app, scope: Literal["execution", "workload"] = "execution") -> None: + """Override require_auth to return a token whose sub matches the path callback_id.""" + + async def _token(request: Request) -> TIToken: + path_id = request.path_params.get("callback_id") + cb_id = UUID(path_id) if 
path_id else uuid4() + return TIToken(id=cb_id, claims=TIClaims(scope=scope)) + + exec_app.dependency_overrides[require_auth] = _token + + +@pytest.fixture +def _use_real_jwt_bearer(exec_app): + """Remove the mock require_auth override so the real JWT validation runs end-to-end.""" + exec_app.dependency_overrides.pop(require_auth, None) + + +class TestCallbackRun: + def setup_method(self): + clear_db_callbacks() + + def teardown_method(self): + clear_db_callbacks() + + @pytest.mark.parametrize("scope", ["workload", "execution"]) + def test_run_marks_callback_running_and_swaps_workload_token(self, client, exec_app, session, scope): + cb = _make_callback(CallbackState.QUEUED, session) + + mock_gen = mock.MagicMock(spec=JWTGenerator) + mock_gen.generate.return_value = "mock-execution-token" + lifespan.registry.register_value(JWTGenerator, mock_gen) + + _override_require_auth(exec_app, scope=scope) + + response = client.post(f"/execution/callbacks/{cb.id}/run") + + exec_app.dependency_overrides.pop(require_auth, None) + + assert response.status_code == 204 + + session.expire_all() + cb_after = session.get(ExecutorCallback, cb.id) + assert cb_after.state == CallbackState.RUNNING + + if scope == "workload": + assert response.headers["Refreshed-API-Token"] == "mock-execution-token" + mock_gen.generate.assert_called_once() + extras = mock_gen.generate.call_args.kwargs["extras"] + assert extras == {"sub": str(cb.id), "scope": "execution"} + else: + # Execution-scoped tokens skip the swap; the middleware handles refresh elsewhere. 
+ assert "Refreshed-API-Token" not in response.headers + mock_gen.generate.assert_not_called() + + @pytest.mark.parametrize( + "state", + [ + CallbackState.PENDING, + CallbackState.SCHEDULED, + CallbackState.SUCCESS, + CallbackState.FAILED, + ], + ) + def test_run_returns_409_for_non_runnable_state(self, client, exec_app, session, state): + cb = _make_callback(state, session) + _override_require_auth(exec_app, scope="workload") + + response = client.post(f"/execution/callbacks/{cb.id}/run") + exec_app.dependency_overrides.pop(require_auth, None) + + assert response.status_code == 409 + assert response.json()["detail"]["reason"] == "invalid_state" + assert response.json()["detail"]["current_state"] == state.value + + def test_run_returns_404_when_callback_missing(self, client, exec_app): + missing_id = uuid4() + + async def _token(request: Request) -> TIToken: + return TIToken(id=missing_id, claims=TIClaims(scope="workload")) + + exec_app.dependency_overrides[require_auth] = _token + + response = client.post(f"/execution/callbacks/{missing_id}/run") + exec_app.dependency_overrides.pop(require_auth, None) + + assert response.status_code == 404 + assert response.json()["detail"]["reason"] == "not_found" + + def test_run_rejects_mismatched_sub(self, client, exec_app, session): + cb = _make_callback(CallbackState.QUEUED, session) + + async def _token(request: Request) -> TIToken: + return TIToken(id=uuid4(), claims=TIClaims(scope="workload")) + + exec_app.dependency_overrides[require_auth] = _token + + response = client.post(f"/execution/callbacks/{cb.id}/run") + exec_app.dependency_overrides.pop(require_auth, None) + + assert response.status_code == 403 + assert response.json()["detail"] == "Token subject does not match callback id" + + +class TestCallbackUpdateState: + def setup_method(self): + clear_db_callbacks() + + def teardown_method(self): + clear_db_callbacks() + + @pytest.mark.parametrize( + ("payload_state", "expected_state"), + [ + ("success", 
CallbackState.SUCCESS), + ("failed", CallbackState.FAILED), + ], + ) + def test_update_state_writes_terminal_state( + self, client, exec_app, session, payload_state, expected_state + ): + cb = _make_callback(CallbackState.RUNNING, session) + _override_require_auth(exec_app, scope="execution") + + response = client.patch( + f"/execution/callbacks/{cb.id}/state", + json={"state": payload_state, "output": "an output"}, + ) + exec_app.dependency_overrides.pop(require_auth, None) + + assert response.status_code == 204 + session.expire_all() + cb_after = session.get(ExecutorCallback, cb.id) + assert cb_after.state == expected_state + assert cb_after.output == "an output" + + @pytest.mark.parametrize( + "state", + [ + CallbackState.QUEUED, + CallbackState.PENDING, + CallbackState.SCHEDULED, + CallbackState.SUCCESS, + CallbackState.FAILED, + ], + ) + def test_update_state_returns_409_when_not_running(self, client, exec_app, session, state): + cb = _make_callback(state, session) + _override_require_auth(exec_app, scope="execution") + + response = client.patch( + f"/execution/callbacks/{cb.id}/state", + json={"state": "success"}, + ) + exec_app.dependency_overrides.pop(require_auth, None) + + assert response.status_code == 409 + assert response.json()["detail"]["reason"] == "invalid_state" + + def test_update_state_returns_404_when_callback_missing(self, client, exec_app): + missing_id = uuid4() + _override_require_auth(exec_app, scope="execution") + + response = client.patch( + f"/execution/callbacks/{missing_id}/state", + json={"state": "success"}, + ) + exec_app.dependency_overrides.pop(require_auth, None) + + assert response.status_code == 404 + + def test_update_state_rejects_mismatched_sub(self, client, exec_app, session): + cb = _make_callback(CallbackState.RUNNING, session) + + async def _token(request: Request) -> TIToken: + return TIToken(id=uuid4(), claims=TIClaims(scope="execution")) + + exec_app.dependency_overrides[require_auth] = _token + + response = 
client.patch( + f"/execution/callbacks/{cb.id}/state", + json={"state": "success"}, + ) + exec_app.dependency_overrides.pop(require_auth, None) + + assert response.status_code == 403 + + @pytest.mark.usefixtures("_use_real_jwt_bearer") + def test_update_state_rejects_workload_scope(self, client, session): + cb = _make_callback(CallbackState.RUNNING, session) + + validator = mock.AsyncMock(spec=JWTValidator) + validator.avalidated_claims.return_value = { + "sub": str(cb.id), + "scope": "workload", + "exp": 9999999999, + "iat": 1000000000, + "nbf": 1000000000, + } + lifespan.registry.register_value(JWTValidator, validator) + + response = client.patch( + f"/execution/callbacks/{cb.id}/state", + json={"state": "success"}, + ) + + assert response.status_code == 403 + assert "Token type 'workload' not allowed" in response.json()["detail"] diff --git a/airflow-core/tests/unit/jobs/test_scheduler_job.py b/airflow-core/tests/unit/jobs/test_scheduler_job.py index 16c42e8acd029..d7a6a1b1bf440 100644 --- a/airflow-core/tests/unit/jobs/test_scheduler_job.py +++ b/airflow-core/tests/unit/jobs/test_scheduler_job.py @@ -658,6 +658,58 @@ def create_callback_in_state(state: CallbackState): assert session.get(ExecutorCallback, queued_callback.id).state == CallbackState.QUEUED assert session.get(ExecutorCallback, running_callback.id).state == CallbackState.RUNNING + @pytest.mark.parametrize( + ("db_state", "event_state", "expected_state"), + [ + # Executor events never advance a QUEUED callback; only the API (POST /run) can move it forward + (CallbackState.QUEUED, CallbackState.RUNNING, CallbackState.QUEUED), + (CallbackState.QUEUED, CallbackState.SUCCESS, CallbackState.QUEUED), + (CallbackState.QUEUED, CallbackState.FAILED, CallbackState.QUEUED), + (CallbackState.RUNNING, CallbackState.SUCCESS, CallbackState.SUCCESS), + (CallbackState.RUNNING, CallbackState.FAILED, CallbackState.FAILED), + # Stale events must not regress an already-terminal callback. 
The API + # path (POST /run, PATCH /state) is authoritative; events are a fallback. + (CallbackState.SUCCESS, CallbackState.RUNNING, CallbackState.SUCCESS), + (CallbackState.SUCCESS, CallbackState.FAILED, CallbackState.SUCCESS), + (CallbackState.FAILED, CallbackState.RUNNING, CallbackState.FAILED), + (CallbackState.FAILED, CallbackState.SUCCESS, CallbackState.FAILED), + # Already RUNNING in DB: a duplicate RUNNING event is a no-op. + (CallbackState.RUNNING, CallbackState.RUNNING, CallbackState.RUNNING), + ], + ) + def test_process_executor_events_writes_callback_state_forward_only( + self, dag_maker, session, db_state, event_state, expected_state + ): + def test_callback(): + pass + + with dag_maker(dag_id="test_callback_forward_only"): + pass + dag_run = dag_maker.create_dagrun() + + callback = Deadline( + deadline_time=timezone.utcnow(), + callback=SyncCallback(test_callback), + dagrun_id=dag_run.id, + deadline_alert_id=None, + ).callback + callback.state = db_state + callback.data["dag_run_id"] = dag_run.id + callback.data["dag_id"] = dag_run.dag_id + session.add(callback) + session.flush() + + executor = MockExecutor(do_update=False) + executor.event_buffer[callback.id] = (event_state, None) + + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(scheduler_job, executors=[executor]) + self.job_runner._process_executor_events(executor=executor, session=session) + + session.flush() + session.expire_all() + assert session.get(ExecutorCallback, callback.id).state == expected_state + @mock.patch("airflow.jobs.scheduler_job_runner.TaskCallbackRequest") @mock.patch("airflow._shared.observability.metrics.stats._get_backend") def test_process_executor_events_with_callback( diff --git a/task-sdk/src/airflow/sdk/api/client.py b/task-sdk/src/airflow/sdk/api/client.py index 15513010062e9..76ddfca091ac5 100644 --- a/task-sdk/src/airflow/sdk/api/client.py +++ b/task-sdk/src/airflow/sdk/api/client.py @@ -169,6 +169,7 @@ def getuser() -> str: _log_retry_warning = 
before_log(log, logging.WARNING) __all__ = [ + "CallbackOperations", "Client", "ConnectionOperations", "ServerResponseError", @@ -907,6 +908,36 @@ def get_detail_response(self, ti_id: uuid.UUID) -> HITLDetailResponse: return HITLDetailResponse.model_validate_json(resp.read()) +class CallbackOperations: + __slots__ = ("client",) + + def __init__(self, client: Client): + self.client = client + + def start(self, id: uuid.UUID | str) -> None: + """ + Mark a callback as RUNNING and exchange a workload token for an execution token. + + Mirrors ``TaskInstanceOperations.start``: this is the single API call that + accepts a workload-scoped token; the server returns the new execution token + via the ``Refreshed-API-Token`` response header and the Client's response + hook automatically swaps it onto subsequent requests. + """ + self.client.post(f"callbacks/{id}/run") + + def finish( + self, + id: uuid.UUID | str, + state: str, + output: str | None = None, + ) -> None: + """Tell the API server that this callback has reached a terminal state.""" + body: dict[str, Any] = {"state": state} + if output is not None: + body["output"] = output + self.client.patch(f"callbacks/{id}/state", json=body) + + class BearerAuth(httpx.Auth): def __init__(self, token: str): self.token: str = token @@ -1090,6 +1121,12 @@ def dags(self) -> DagsOperations: """Operations related to DAGs.""" return DagsOperations(self) + @lru_cache() # type: ignore[misc] + @property + def callbacks(self) -> CallbackOperations: + """Operations related to Callbacks.""" + return CallbackOperations(self) + # This is only used for parsing. 
ServerResponseError is raised instead class _ErrorBody(BaseModel): diff --git a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py index b5b100154c389..7acdcb90d001b 100644 --- a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py +++ b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py @@ -85,6 +85,23 @@ class AssetStateResponse(BaseModel): value: Annotated[str, Field(title="Value")] +class State(str, Enum): + SUCCESS = "success" + FAILED = "failed" + + +class CallbackTerminalStatePayload(BaseModel): + """ + Payload for transitioning a callback from RUNNING to a terminal state. + """ + + model_config = ConfigDict( + extra="forbid", + ) + state: Annotated[State, Field(title="State")] + output: Annotated[str | None, Field(title="Output")] = None + + class ConnectionResponse(BaseModel): """ Connection schema for responses with fields that are needed for Runtime. diff --git a/task-sdk/src/airflow/sdk/execution_time/callback_supervisor.py b/task-sdk/src/airflow/sdk/execution_time/callback_supervisor.py index 94d84193192db..4374da13b5cf7 100644 --- a/task-sdk/src/airflow/sdk/execution_time/callback_supervisor.py +++ b/task-sdk/src/airflow/sdk/execution_time/callback_supervisor.py @@ -352,6 +352,15 @@ def supervise_callback( logger = structlog.get_logger(logger_name="callback").bind() with _ensure_client(server, token, client=client) as client: + # Mark the callback as RUNNING via the API. This is the single endpoint + # that accepts a workload-scoped token; it returns a fresh execution + # token via the Refreshed-API-Token header which the Client's response + # hook adopts automatically for the rest of the run. 
+ client.callbacks.start(id) + + terminal_state = "failed" + terminal_output = None + try: process = CallbackSubprocess.start( id=id, @@ -372,9 +381,23 @@ def supervise_callback( exit_code=exit_code, duration=end - start, ) - if exit_code != 0: - raise RuntimeError(f"Callback subprocess exited with code {exit_code}") - return exit_code + if exit_code == 0: + terminal_state = "success" + else: + terminal_output = f"Callback subprocess exited with code {exit_code}" + except Exception as e: + terminal_output = f"Callback supervisor error: {type(e).__name__}: {e}" + log.exception("Callback supervisor error", workload_id=id) + raise finally: + try: + client.callbacks.finish(id, state=terminal_state, output=terminal_output) + except Exception: + log.exception("Failed to report final callback state", workload_id=id) + if log_path and log_file_descriptor: log_file_descriptor.close() + + if exit_code != 0: + raise RuntimeError(f"Callback subprocess exited with code {exit_code}") + return exit_code diff --git a/task-sdk/tests/task_sdk/api/test_client.py b/task-sdk/tests/task_sdk/api/test_client.py index 26e2a7e66bb6e..db35129afcef3 100644 --- a/task-sdk/tests/task_sdk/api/test_client.py +++ b/task-sdk/tests/task_sdk/api/test_client.py @@ -1700,3 +1700,81 @@ def handle_request(request: httpx.Request) -> httpx.Response: with pytest.raises(ServerResponseError): client.dags.get(dag_id="test_dag") + + +class TestCallbackOperations: + def test_start_posts_and_picks_up_refreshed_token(self): + """start() POSTs to /callbacks/{id}/run; the response hook applies the new bearer.""" + callback_id = uuid6.uuid7() + seen: dict[str, str | None] = {} + + def handle_request(request: httpx.Request) -> httpx.Response: + if request.url.path == f"/callbacks/{callback_id}/run": + seen["method"] = request.method + seen["auth"] = request.headers.get("Authorization") + return httpx.Response( + status_code=204, + headers={"Refreshed-API-Token": "new-execution-token"}, + ) + seen["follow_up_auth"] = 
request.headers.get("Authorization") + return httpx.Response(status_code=200, json={}) + + client = Client( + base_url="test://server", + token="initial-workload-token", + transport=httpx.MockTransport(handle_request), + ) + + client.callbacks.start(callback_id) + + assert seen["method"] == "POST" + assert seen["auth"] == "Bearer initial-workload-token" + + # A subsequent call should now carry the refreshed token, proving the + # response hook adopted Refreshed-API-Token onto the client's auth. + client.get(f"/callbacks/{callback_id}/something-else") + assert seen["follow_up_auth"] == "Bearer new-execution-token" + + def test_start_propagates_server_error(self): + callback_id = uuid6.uuid7() + + def handle_request(request: httpx.Request) -> httpx.Response: + return httpx.Response( + status_code=409, + json={"detail": {"reason": "invalid_state", "current_state": "success"}}, + ) + + client = make_client(transport=httpx.MockTransport(handle_request)) + + with pytest.raises(ServerResponseError) as exc_info: + client.callbacks.start(callback_id) + + assert exc_info.value.response.status_code == 409 + + @pytest.mark.parametrize( + ("state", "output"), + [ + ("success", None), + ("failed", "Callback subprocess exited with code 1"), + ], + ) + def test_finish_patches_state_endpoint(self, state, output): + callback_id = uuid6.uuid7() + seen: dict = {} + + def handle_request(request: httpx.Request) -> httpx.Response: + if request.url.path == f"/callbacks/{callback_id}/state": + seen["method"] = request.method + seen["body"] = json.loads(request.read()) + return httpx.Response(status_code=204) + return httpx.Response(status_code=400) + + client = make_client(transport=httpx.MockTransport(handle_request)) + client.callbacks.finish(callback_id, state=state, output=output) + + assert seen["method"] == "PATCH" + assert seen["body"]["state"] == state + if output is not None: + assert seen["body"]["output"] == output + else: + assert "output" not in seen["body"] diff --git 
a/task-sdk/tests/task_sdk/execution_time/test_callback_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_callback_supervisor.py index 8cb9fdcc8167a..51b4e32f0bbc4 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_callback_supervisor.py +++ b/task-sdk/tests/task_sdk/execution_time/test_callback_supervisor.py @@ -24,11 +24,17 @@ from operator import attrgetter from typing import Any from unittest.mock import patch +from uuid import uuid4 import pytest import structlog -from airflow.sdk.execution_time.callback_supervisor import CallbackSubprocess, execute_callback +from airflow.sdk.api.client import CallbackOperations, Client +from airflow.sdk.execution_time.callback_supervisor import ( + CallbackSubprocess, + execute_callback, + supervise_callback, +) from airflow.sdk.execution_time.comms import ( ConnectionResult, GetConnection, @@ -228,3 +234,106 @@ def test_handle_requests( if client_mock: mock_client_method.assert_called_once_with(*client_mock.args, **client_mock.kwargs) + + +class TestSuperviseCallback: + """supervise_callback drives every callback state transition through the API.""" + + def _make_mock_client(self, mocker): + client = mocker.Mock(spec=Client) + client.callbacks = mocker.Mock(spec=CallbackOperations) + return client + + def test_start_called_before_subprocess_then_finish_success(self, mocker): + cb_id = str(uuid4()) + client = self._make_mock_client(mocker) + + order: list[str] = [] + client.callbacks.start.side_effect = lambda _id: order.append("start_api") + client.callbacks.finish.side_effect = lambda _id, state, output=None: order.append(f"finish:{state}") + + proc = mocker.Mock() + proc.wait.return_value = 0 + + def _subprocess_start(**_): + order.append("subprocess_start") + return proc + + mocker.patch.object(CallbackSubprocess, "start", side_effect=_subprocess_start) + + exit_code = supervise_callback( + id=cb_id, + callback_path="tests.fake.callback", + callback_kwargs={}, + token="workload-token", + client=client, + ) 
+ + assert exit_code == 0 + client.callbacks.start.assert_called_once_with(cb_id) + client.callbacks.finish.assert_called_once_with(cb_id, state="success", output=None) + assert order == ["start_api", "subprocess_start", "finish:success"] + + def test_finish_called_with_failed_when_subprocess_exits_nonzero(self, mocker): + cb_id = str(uuid4()) + client = self._make_mock_client(mocker) + + proc = mocker.Mock() + proc.wait.return_value = 1 + mocker.patch.object(CallbackSubprocess, "start", return_value=proc) + + with pytest.raises(RuntimeError, match="exited with code 1"): + supervise_callback( + id=cb_id, + callback_path="tests.fake.callback", + callback_kwargs={}, + token="workload-token", + client=client, + ) + + client.callbacks.start.assert_called_once_with(cb_id) + client.callbacks.finish.assert_called_once() + kwargs = client.callbacks.finish.call_args.kwargs + assert kwargs["state"] == "failed" + assert "exited with code 1" in kwargs["output"] + + def test_finish_called_with_failed_when_subprocess_raises(self, mocker): + cb_id = str(uuid4()) + client = self._make_mock_client(mocker) + + mocker.patch.object(CallbackSubprocess, "start", side_effect=RuntimeError("boom")) + + with pytest.raises(RuntimeError, match="boom"): + supervise_callback( + id=cb_id, + callback_path="tests.fake.callback", + callback_kwargs={}, + token="workload-token", + client=client, + ) + + client.callbacks.finish.assert_called_once() + kwargs = client.callbacks.finish.call_args.kwargs + assert kwargs["state"] == "failed" + assert "RuntimeError" in kwargs["output"] + + def test_finish_failure_is_swallowed(self, mocker): + """If the terminal-state report fails, supervisor must still propagate the run result.""" + cb_id = str(uuid4()) + client = self._make_mock_client(mocker) + client.callbacks.finish.side_effect = RuntimeError("network down") + + proc = mocker.Mock() + proc.wait.return_value = 0 + mocker.patch.object(CallbackSubprocess, "start", return_value=proc) + + # Should not raise — 
finish() is best-effort because the executor's + # event channel is the safety net. + exit_code = supervise_callback( + id=cb_id, + callback_path="tests.fake.callback", + callback_kwargs={}, + token="workload-token", + client=client, + ) + assert exit_code == 0