diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/asset_state.py b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/asset_state.py new file mode 100644 index 0000000000000..ec773201c7e2f --- /dev/null +++ b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/asset_state.py @@ -0,0 +1,32 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from airflow.api_fastapi.core_api.base import StrictBaseModel + + +class AssetStateResponse(StrictBaseModel): + """Asset state value returned to a worker.""" + + value: str + + +class AssetStatePutBody(StrictBaseModel): + """Request body for setting an asset state value.""" + + value: str diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/task_state.py b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/task_state.py new file mode 100644 index 0000000000000..3200f3177af35 --- /dev/null +++ b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/task_state.py @@ -0,0 +1,32 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from airflow.api_fastapi.core_api.base import StrictBaseModel + + +class TaskStateResponse(StrictBaseModel): + """Task state value returned to a worker.""" + + value: str + + +class TaskStatePutBody(StrictBaseModel): + """Request body for setting a task state value.""" + + value: str diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/__init__.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/__init__.py index a076592d6471a..06f07aee82389 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/__init__.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/__init__.py @@ -21,6 +21,7 @@ from airflow.api_fastapi.execution_api.routes import ( asset_events, + asset_state, assets, connections, dag_runs, @@ -29,6 +30,7 @@ hitl, task_instances, task_reschedules, + task_state, variables, xcoms, ) @@ -52,5 +54,7 @@ authenticated_router.include_router(variables.router, prefix="/variables", tags=["Variables"]) authenticated_router.include_router(xcoms.router, prefix="/xcoms", tags=["XComs"]) authenticated_router.include_router(hitl.router, prefix="/hitlDetails", tags=["Human in the Loop"]) +authenticated_router.include_router(task_state.router, prefix="/state/ti", tags=["Task State"]) +authenticated_router.include_router(asset_state.router, prefix="/state/asset", tags=["Asset State"]) execution_api_router.include_router(authenticated_router) diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/asset_state.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/asset_state.py new file mode 100644 index 0000000000000..ba00260b16b2d --- /dev/null +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/asset_state.py @@ -0,0 +1,123 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Execution API routes for asset state. + +Asset state is keyed by asset *name* (not integer id) in the URL — asset names +are unique, and callers (task SDK accessors) have the name from their Asset +object without needing a DB lookup. The route resolves name → asset_id +internally for the state backend scope. + +Per-task asset registration checks are intentionally not implemented here +(deferred to AIP-93 — see TODO comment below). +""" + +from __future__ import annotations + +from typing import Annotated + +from cadwyn import VersionedAPIRouter +from fastapi import HTTPException, Query, status +from sqlalchemy import select + +from airflow._shared.state import AssetScope +from airflow.api_fastapi.common.db.common import SessionDep +from airflow.api_fastapi.execution_api.datamodels.asset_state import ( + AssetStatePutBody, + AssetStateResponse, +) +from airflow.api_fastapi.execution_api.security import ExecutionAPIRoute +from airflow.models.asset import AssetModel +from airflow.state import get_state_backend + +# TODO(AIP-103): enforce that the requesting task is registered with the asset +# (via task_inlet_asset_reference or task_outlet_asset_reference) before +# allowing reads/writes. Currently any task with a valid execution token can +# access any asset's state — the same gap exists in /assets and /asset-events. +# Proper fix is a unified asset-registration check across all asset routes, +# not just here. +router = VersionedAPIRouter( + route_class=ExecutionAPIRoute, + responses={ + status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"}, + status.HTTP_404_NOT_FOUND: {"description": "Not found"}, + }, +) + + +def _resolve_asset_id(name: str, session: SessionDep) -> int: + """Resolve asset name → integer asset_id, 404 if not found.""" + asset_id = session.scalar(select(AssetModel.id).where(AssetModel.name == name, AssetModel.active.has())) + if asset_id is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail={"reason": "not_found", "message": f"Asset {name!r} not found"}, + ) + return asset_id + + +@router.get("/value") +def get_asset_state( + name: Annotated[str, Query(min_length=1)], + key: Annotated[str, Query(min_length=1)], + session: SessionDep, +) -> AssetStateResponse: + """Get an asset state value.""" + asset_id = _resolve_asset_id(name, session) + value = get_state_backend().get(AssetScope(asset_id=asset_id), key, session=session) # type: ignore[call-arg] # @provide_session adds session kwarg at runtime; BaseStateBackend signature omits it so mypy can't see it + if value is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail={ + "reason": "not_found", + "message": f"Asset state key {key!r} not found", + }, + ) + return AssetStateResponse(value=value) + + +@router.put("/value", status_code=status.HTTP_204_NO_CONTENT) +def set_asset_state( + name: Annotated[str, Query(min_length=1)], + key: Annotated[str, Query(min_length=1)], + body: AssetStatePutBody, + session: SessionDep, +) -> None: + """Set an asset state value.""" + asset_id = _resolve_asset_id(name, session) + get_state_backend().set(AssetScope(asset_id=asset_id), key, body.value, session=session) # type: ignore[call-arg] # @provide_session adds session kwarg at runtime; BaseStateBackend signature omits it so mypy can't see it + + +@router.delete("/value", status_code=status.HTTP_204_NO_CONTENT) +def delete_asset_state( + name: Annotated[str, Query(min_length=1)], + key: Annotated[str, Query(min_length=1)], + session: SessionDep, +) -> None: + """Delete a single asset state key.""" + asset_id = _resolve_asset_id(name, session) + get_state_backend().delete(AssetScope(asset_id=asset_id), key, session=session) # type: ignore[call-arg] # @provide_session adds session kwarg at runtime; BaseStateBackend signature omits it so mypy can't see it + + +@router.delete("/clear", status_code=status.HTTP_204_NO_CONTENT) +def clear_asset_state( + name: Annotated[str, Query(min_length=1)], + session: SessionDep, +) -> None: + """Delete all state keys for an asset.""" + asset_id = _resolve_asset_id(name, session) + get_state_backend().clear(AssetScope(asset_id=asset_id), session=session) # type: ignore[call-arg] # @provide_session adds session kwarg at runtime; BaseStateBackend signature omits it so mypy can't see it diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_state.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_state.py new file mode 100644 index 0000000000000..db24109969c76 --- /dev/null +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_state.py @@ -0,0 +1,128 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import Annotated +from uuid import UUID + +from cadwyn import VersionedAPIRouter +from fastapi import HTTPException, Path, Query, Security, status +from sqlalchemy.orm import Session + +from airflow._shared.state import TaskScope +from airflow.api_fastapi.common.db.common import SessionDep +from airflow.api_fastapi.execution_api.datamodels.task_state import ( + TaskStatePutBody, + TaskStateResponse, +) +from airflow.api_fastapi.execution_api.security import ExecutionAPIRoute, require_auth +from airflow.models.taskinstance import TaskInstance as TI +from airflow.state import get_state_backend + +router = VersionedAPIRouter( + route_class=ExecutionAPIRoute, + responses={ + status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"}, + status.HTTP_403_FORBIDDEN: {"description": "Access denied"}, + status.HTTP_404_NOT_FOUND: {"description": "Not found"}, + }, + dependencies=[Security(require_auth, scopes=["ti:self"])], +) + + +def _get_task_scope_for_ti(task_instance_id: UUID, session: Session) -> TaskScope: + ti = session.get(TI, task_instance_id) + if ti is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail={ + "reason": "not_found", + "message": f"Task instance {task_instance_id} not found", + }, + ) + return TaskScope(dag_id=ti.dag_id, run_id=ti.run_id, task_id=ti.task_id, map_index=ti.map_index) + + +@router.get("/{task_instance_id}/{key}") +def get_task_state( + task_instance_id: UUID, + key: Annotated[str, Path(min_length=1)], + session: SessionDep, +) -> TaskStateResponse: + """Get value for a task state.""" + scope = _get_task_scope_for_ti(task_instance_id, session) + value = get_state_backend().get(scope, key, session=session) # type: ignore[call-arg] # @provide_session adds session kwarg at runtime; BaseStateBackend signature omits it so mypy can't see it + if value is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail={ + "reason": "not_found", + "message": f"Task state key {key!r} not found", + }, + ) + return TaskStateResponse(value=value) + + +@router.put("/{task_instance_id}/{key}", status_code=status.HTTP_204_NO_CONTENT) +def set_task_state( + task_instance_id: UUID, + key: Annotated[str, Path(min_length=1)], + body: TaskStatePutBody, + session: SessionDep, +) -> None: + """Set a task state key, creating or updating the row.""" + scope = _get_task_scope_for_ti(task_instance_id, session) + get_state_backend().set(scope, key, body.value, session=session) # type: ignore[call-arg] # @provide_session adds session kwarg at runtime; BaseStateBackend signature omits it so mypy can't see it + + +@router.delete("/{task_instance_id}/{key}", status_code=status.HTTP_204_NO_CONTENT) +def delete_task_state( + task_instance_id: UUID, + key: Annotated[str, Path(min_length=1)], + session: SessionDep, +) -> None: + """Delete a single task state key.""" + scope = _get_task_scope_for_ti(task_instance_id, session) + get_state_backend().delete(scope, key, session=session) # type: ignore[call-arg] # @provide_session adds session kwarg at runtime; BaseStateBackend signature omits it so mypy can't see it + + +@router.delete("/{task_instance_id}", status_code=status.HTTP_204_NO_CONTENT) +def clear_task_state( + task_instance_id: UUID, + session: SessionDep, + all_map_indices: Annotated[bool, Query()] = False, +) -> None: + """ + Delete all task state keys for this task instance. + + By default, only keys for the requesting TI's exact ``map_index`` are + cleared — same isolation as DELETE endpoint above. + + Pass ``?all_map_indices=true`` to wipe state for every mapped sibling of + the task in the same DAG run. This is intentionally fleet-wide: the + ``ti:self`` JWT authentication scope authenticates that the caller is + a legitimate member of the mapped task group, and grants it authority + to reset shared task state on behalf of the whole group. + The SDK only forwards this flag when the user calls ``task_state.clear(all_map_indices=True)`` + explicitly, so the expanded scope is always an explicit opt-in by the task author. + + For non-mapped tasks (``map_index=-1``), there is only ever one index, so + ``?all_map_indices=true`` is functionally identical to the default and is + accepted without error. + """ + scope = _get_task_scope_for_ti(task_instance_id, session) + get_state_backend().clear(scope, all_map_indices=all_map_indices, session=session) # type: ignore[call-arg] # @provide_session adds session kwarg at runtime; BaseStateBackend signature omits it so mypy can't see it diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py b/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py index 7ccaf3fdf15ca..dfa27f53ebd91 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py @@ -40,7 +40,7 @@ MovePreviousRunEndpoint, RemoveUpstreamMapIndexesField, ) -from airflow.api_fastapi.execution_api.versions.v2026_04_17 import AddTeamNameField +from airflow.api_fastapi.execution_api.versions.v2026_04_17 import AddStateEndpoints, AddTeamNameField from airflow.api_fastapi.execution_api.versions.v2026_06_16 import AddRetryPolicyFields bundle = VersionBundle( @@ -49,6 +49,7 @@ Version( "2026-04-17", AddTeamNameField, + AddStateEndpoints, ), Version( "2026-04-06", diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_04_17.py b/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_04_17.py index e7cd9d331a591..c08454c4786b8 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_04_17.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_04_17.py @@ -17,7 +17,7 @@ from __future__ import annotations -from cadwyn import ResponseInfo, VersionChange, convert_response_to_previous_version_for, schema +from cadwyn import ResponseInfo, VersionChange, convert_response_to_previous_version_for, endpoint, schema from airflow.api_fastapi.execution_api.datamodels.taskinstance import DagRun, TIRunContext @@ -34,3 +34,20 @@ def remove_team_name_field(response: ResponseInfo) -> None: # type: ignore[misc """Remove the ``team_name`` field from dag_run for older API versions.""" if "dag_run" in response.body and isinstance(response.body["dag_run"], dict): response.body["dag_run"].pop("team_name", None) + + +class AddStateEndpoints(VersionChange): + """Add task state and asset state CRUD endpoints.""" + + description = __doc__ + + instructions_to_migrate_to_previous_version = ( + endpoint("/state/ti/{task_instance_id}/{key}", ["GET"]).didnt_exist, + endpoint("/state/ti/{task_instance_id}/{key}", ["PUT"]).didnt_exist, + endpoint("/state/ti/{task_instance_id}/{key}", ["DELETE"]).didnt_exist, + endpoint("/state/ti/{task_instance_id}", ["DELETE"]).didnt_exist, + endpoint("/state/asset/value", ["GET"]).didnt_exist, + endpoint("/state/asset/value", ["PUT"]).didnt_exist, + endpoint("/state/asset/value", ["DELETE"]).didnt_exist, + endpoint("/state/asset/clear", ["DELETE"]).didnt_exist, + ) diff --git a/airflow-core/src/airflow/migrations/versions/0113_3_3_0_add_retry_policy_fields_to_ti.py b/airflow-core/src/airflow/migrations/versions/0113_3_3_0_add_retry_policy_fields_to_ti.py index 5fc8007e4178e..a390163b57dc0 100644 --- a/airflow-core/src/airflow/migrations/versions/0113_3_3_0_add_retry_policy_fields_to_ti.py +++ b/airflow-core/src/airflow/migrations/versions/0113_3_3_0_add_retry_policy_fields_to_ti.py @@ -30,7 +30,7 @@ is a metadata-only operation (no table rewrite). Revision ID: b8f3e4a1d2c9 -Revises: 9fabad868fdb +Revises: fde9ed84d07b Create Date: 2026-04-16 12:00:00.000000 """ diff --git a/airflow-core/src/airflow/state/__init__.py b/airflow-core/src/airflow/state/__init__.py index 89fa0801a06f0..109b9e9b37f04 100644 --- a/airflow-core/src/airflow/state/__init__.py +++ b/airflow-core/src/airflow/state/__init__.py @@ -17,6 +17,8 @@ # under the License. from __future__ import annotations +import threading + from airflow._shared.state import ( AssetScope as AssetScope, BaseStateBackend as BaseStateBackend, @@ -36,3 +38,17 @@ def resolve_state_backend() -> type[BaseStateBackend]: f"Your custom state backend `{clazz.__name__}` is not a subclass of `BaseStateBackend`." ) return clazz + + +_backend_instance: BaseStateBackend | None = None +_backend_lock = threading.Lock() + + +def get_state_backend() -> BaseStateBackend: + """Return a cached instance of the configured state backend.""" + global _backend_instance + if _backend_instance is None: + with _backend_lock: + if _backend_instance is None: + _backend_instance = resolve_state_backend()() + return _backend_instance diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_asset_state.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_asset_state.py new file mode 100644 index 0000000000000..8bdf6ac859965 --- /dev/null +++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_asset_state.py @@ -0,0 +1,183 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest +from sqlalchemy import delete, select + +from airflow.models.asset import AssetActive, AssetModel +from airflow.models.asset_state import AssetStateModel +from airflow.utils.session import create_session + +if TYPE_CHECKING: + from fastapi.testclient import TestClient + from sqlalchemy.orm import Session + +pytestmark = pytest.mark.db_test + + +@pytest.fixture(autouse=True) +def reset_state_tables(): + with create_session() as session: + session.execute(delete(AssetStateModel)) + session.execute(delete(AssetModel)) + + +@pytest.fixture +def asset(session: Session) -> AssetModel: + asset = AssetModel(name="test_asset", uri="s3://bucket/test", group="asset") + session.add(asset) + session.flush() + session.add(AssetActive.for_asset(asset)) + session.commit() + return asset + + +@pytest.fixture +def inactive_asset(session: Session) -> AssetModel: + """An asset row with no asset_active entry — simulates a removed asset.""" + asset = AssetModel(name="inactive_asset", uri="s3://bucket/inactive", group="asset") + session.add(asset) + session.commit() + return asset + + +_VALUE_URL = "/execution/state/asset/value" +_CLEAR_URL = "/execution/state/asset/clear" + + +class TestGetAssetState: + def test_get_returns_value(self, client: TestClient, asset: AssetModel): + client.put(_VALUE_URL, params={"name": asset.name, "key": "watermark"}, json={"value": "2026-04-29"}) + + response = client.get(_VALUE_URL, params={"name": asset.name, "key": "watermark"}) + + assert response.status_code == 200 + assert response.json() == {"value": "2026-04-29"} + + def test_get_missing_key_returns_404(self, client: TestClient, asset: AssetModel): + response = client.get(_VALUE_URL, params={"name": asset.name, "key": "never_set"}) + + assert response.status_code == 404 + assert response.json()["detail"]["reason"] == "not_found" + + def test_get_asset_name_with_slashes(self, client: TestClient, session): + slashed = AssetModel(name="team/sales/orders", uri="s3://bucket/team/sales", group="asset") + session.add(slashed) + session.flush() + session.add(AssetActive.for_asset(slashed)) + session.commit() + + client.put(_VALUE_URL, params={"name": slashed.name, "key": "wm"}, json={"value": "x"}) + response = client.get(_VALUE_URL, params={"name": slashed.name, "key": "wm"}) + + assert response.status_code == 200 + assert response.json() == {"value": "x"} + + +class TestPutAssetState: + def test_put_creates_row(self, client: TestClient, asset: AssetModel): + response = client.put( + _VALUE_URL, params={"name": asset.name, "key": "watermark"}, json={"value": "2026-04-29"} + ) + + assert response.status_code == 204 + with create_session() as session: + row = session.scalar( + select(AssetStateModel).where( + AssetStateModel.asset_id == asset.id, + AssetStateModel.key == "watermark", + ) + ) + assert row is not None + assert row.value == "2026-04-29" + + def test_put_overwrites_existing(self, client: TestClient, asset: AssetModel): + client.put(_VALUE_URL, params={"name": asset.name, "key": "watermark"}, json={"value": "2026-04-28"}) + + response = client.put( + _VALUE_URL, params={"name": asset.name, "key": "watermark"}, json={"value": "2026-04-29"} + ) + + assert response.status_code == 204 + assert client.get(_VALUE_URL, params={"name": asset.name, "key": "watermark"}).json() == { + "value": "2026-04-29" + } + + def test_put_empty_body_returns_422(self, client: TestClient, asset: AssetModel): + response = client.put(_VALUE_URL, params={"name": asset.name, "key": "watermark"}, json={}) + + assert response.status_code == 422 + + def test_put_extra_field_returns_422(self, client: TestClient, asset: AssetModel): + response = client.put( + _VALUE_URL, params={"name": asset.name, "key": "watermark"}, json={"value": "x", "extra": "y"} + ) + + assert response.status_code == 422 + + def test_put_unknown_asset_returns_404(self, client: TestClient): + response = client.put( + _VALUE_URL, params={"name": "nonexistent", "key": "watermark"}, json={"value": "x"} + ) + + assert response.status_code == 404 + assert "nonexistent" in response.json()["detail"]["message"] + + +class TestDeleteAssetState: + def test_delete_removes_key(self, client: TestClient, asset: AssetModel): + client.put(_VALUE_URL, params={"name": asset.name, "key": "watermark"}, json={"value": "2026-04-29"}) + + response = client.delete(_VALUE_URL, params={"name": asset.name, "key": "watermark"}) + + assert response.status_code == 204 + assert client.get(_VALUE_URL, params={"name": asset.name, "key": "watermark"}).status_code == 404 + + def test_delete_missing_key_is_noop(self, client: TestClient, asset: AssetModel): + response = client.delete(_VALUE_URL, params={"name": asset.name, "key": "never_existed"}) + + assert response.status_code == 204 + + +class TestClearAssetState: + def test_clear_removes_all_keys(self, client: TestClient, asset: AssetModel): + for k, v in [("watermark", "a"), ("last_id", "b"), ("schema_hash", "c")]: + client.put(_VALUE_URL, params={"name": asset.name, "key": k}, json={"value": v}) + + response = client.delete(_CLEAR_URL, params={"name": asset.name}) + + assert response.status_code == 204 + with create_session() as session: + row = session.scalar(select(AssetStateModel).where(AssetStateModel.asset_id == asset.id)) + assert row is None + + +class TestInactiveAssetRejected: + """An asset row without a corresponding asset_active entry is treated as not found.""" + + def test_get_inactive_asset_returns_404(self, client: TestClient, inactive_asset: AssetModel): + response = client.get(_VALUE_URL, params={"name": inactive_asset.name, "key": "watermark"}) + assert response.status_code == 404 + + def test_put_inactive_asset_returns_404(self, client: TestClient, inactive_asset: AssetModel): + response = client.put( + _VALUE_URL, params={"name": inactive_asset.name, "key": "watermark"}, json={"value": "x"} + ) + assert response.status_code == 404 diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_state.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_state.py new file mode 100644 index 0000000000000..8a66a0a23c739 --- /dev/null +++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_state.py @@ -0,0 +1,274 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import TYPE_CHECKING +from uuid import uuid4 + +import pytest +from fastapi import Request +from fastapi.testclient import TestClient +from sqlalchemy import delete, select + +from airflow._shared.timezones import timezone +from airflow.api_fastapi.app import cached_app +from airflow.api_fastapi.execution_api.datamodels.token import TIClaims, TIToken +from airflow.api_fastapi.execution_api.security import _jwt_bearer +from airflow.models.dagrun import DagRun +from airflow.models.task_state import TaskStateModel +from airflow.utils.session import create_session + +if TYPE_CHECKING: + from tests_common.pytest_plugin import CreateTaskInstance + + +pytestmark = pytest.mark.db_test + + +@pytest.fixture(autouse=True) +def reset_state_tables(): + with create_session() as session: + session.execute(delete(TaskStateModel)) + session.execute(delete(DagRun)) + + +def _api_url(ti_id, key: str | None = None) -> str: + base = f"/execution/state/ti/{ti_id}" + return f"{base}/{key}" if key else base + + +class TestGetTaskState: + def test_get_returns_value(self, client: TestClient, create_task_instance: CreateTaskInstance): + ti = create_task_instance() + client.put(_api_url(ti.id, "job_id"), json={"value": "spark_001"}) + + response = client.get(_api_url(ti.id, "job_id")) + + assert response.status_code == 200 + assert response.json() == {"value": "spark_001"} + + def test_get_missing_key_returns_404(self, client: TestClient, create_task_instance: CreateTaskInstance): + ti = create_task_instance() + + response = client.get(_api_url(ti.id, "never_set")) + + assert response.status_code == 404 + assert response.json()["detail"]["reason"] == "not_found" + + def test_get_missing_ti_returns_404(self, client: TestClient): + response = client.get(_api_url(uuid4(), "any_key")) + + assert response.status_code == 404 + assert "Task instance" in response.json()["detail"]["message"] + + +class TestPutTaskState: + def test_put_creates_row(self, client: TestClient, create_task_instance: CreateTaskInstance): + ti = create_task_instance() + + response = client.put(_api_url(ti.id, "job_id"), json={"value": "spark_001"}) + + assert response.status_code == 204 + with create_session() as session: + row = session.scalar( + select(TaskStateModel).where( + TaskStateModel.dag_id == ti.dag_id, + TaskStateModel.run_id == ti.run_id, + TaskStateModel.task_id == ti.task_id, + TaskStateModel.key == "job_id", + ) + ) + assert row is not None + assert row.value == "spark_001" + + def test_put_overwrites_existing(self, client: TestClient, create_task_instance: CreateTaskInstance): + ti = create_task_instance() + client.put(_api_url(ti.id, "job_id"), json={"value": "spark_001"}) + + response = client.put(_api_url(ti.id, "job_id"), json={"value": "spark_002"}) + + assert response.status_code == 204 + assert client.get(_api_url(ti.id, "job_id")).json() == {"value": "spark_002"} + + def test_put_empty_body_returns_422(self, client: TestClient, create_task_instance: CreateTaskInstance): + ti = create_task_instance() + + response = client.put(_api_url(ti.id, "job_id"), json={}) + + assert response.status_code == 422 + + def test_put_extra_field_returns_422(self, client: TestClient, create_task_instance: CreateTaskInstance): + ti = create_task_instance() + + response = client.put(_api_url(ti.id, "job_id"), json={"value": "x", "extra": "y"}) + + assert response.status_code == 422 + + def test_put_null_value_returns_422(self, client: TestClient, create_task_instance: CreateTaskInstance): + ti = create_task_instance() + + response = client.put(_api_url(ti.id, "job_id"), json={"value": None}) + + assert response.status_code == 422 + + def test_put_missing_ti_returns_404(self, client: TestClient): + response = client.put(_api_url(uuid4(), "job_id"), json={"value": "x"}) + + assert response.status_code == 404 + + +class TestDeleteTaskState: + def test_delete_removes_key(self, client: TestClient, create_task_instance: CreateTaskInstance): + ti = create_task_instance() + client.put(_api_url(ti.id, "job_id"), json={"value": "spark_001"}) + + response = client.delete(_api_url(ti.id, "job_id")) + + assert response.status_code == 204 + assert client.get(_api_url(ti.id, "job_id")).status_code == 404 + + def test_delete_missing_key_is_noop(self, client: TestClient, create_task_instance: CreateTaskInstance): + ti = create_task_instance() + + response = client.delete(_api_url(ti.id, "never_existed")) + + assert response.status_code == 204 + + def test_delete_only_targets_one_key(self, client: TestClient, create_task_instance: CreateTaskInstance): + ti = create_task_instance() + client.put(_api_url(ti.id, "job_id"), json={"value": "a"}) + client.put(_api_url(ti.id, "checkpoint"), json={"value": "b"}) + + client.delete(_api_url(ti.id, "job_id")) + + assert client.get(_api_url(ti.id, "job_id")).status_code == 404 + assert client.get(_api_url(ti.id, "checkpoint")).json() == {"value": "b"} + + +class TestClearTaskState: + def test_clear_removes_all_keys(self, client: TestClient, create_task_instance: CreateTaskInstance): + ti = create_task_instance() + for k, v in [("job_id", "a"), ("checkpoint", "b"), ("retry_count", "c")]: + client.put(_api_url(ti.id, k), json={"value": v}) + + response = client.delete(_api_url(ti.id)) + + assert response.status_code == 204 + with create_session() as session: + remaining = session.scalars( + select(TaskStateModel.key).where( + TaskStateModel.dag_id == ti.dag_id, + TaskStateModel.task_id == ti.task_id, + ) + ).all() + assert remaining == [] + + def test_clear_when_empty_is_noop(self, client: TestClient, create_task_instance: CreateTaskInstance): + ti = create_task_instance() + + response = client.delete(_api_url(ti.id)) + + assert response.status_code == 204 + + def _seed_fleet_rows(self, ti, indices: tuple[int, ...]) -> None: + with create_session() as session: + now = timezone.utcnow() + for idx in indices: + session.add( + TaskStateModel( + dag_run_id=ti.dag_run.id, + dag_id=ti.dag_id, + run_id=ti.run_id, + task_id=ti.task_id, + map_index=idx, + key="job_id", + value=f"app_{idx}", + updated_at=now, + ) + ) + session.commit() + + def test_clear_default_only_clears_this_map_index( + self, client: TestClient, create_task_instance: CreateTaskInstance + ): + """Clear without the query param only wipes the requesting TI's own map_index.""" + ti = create_task_instance(map_index=2) + self._seed_fleet_rows(ti, (0, 1, 2)) + + response = client.delete(_api_url(ti.id)) + + assert response.status_code == 204 + with create_session() as session: + remaining_indices = sorted( + session.scalars( + select(TaskStateModel.map_index).where( + TaskStateModel.dag_id == ti.dag_id, + TaskStateModel.task_id == ti.task_id, + ) + ).all() + ) + assert remaining_indices == [0, 1] + + def test_clear_with_all_map_indices_query_param_wipes_fleet( + self, client: TestClient, create_task_instance: CreateTaskInstance + ): + """Clear with ?all_map_indices=true wipes state for every mapped instance.""" + ti = create_task_instance(map_index=2) + self._seed_fleet_rows(ti, (0, 1, 2)) + + response = client.delete(_api_url(ti.id), params={"all_map_indices": "true"}) + + assert response.status_code == 204 + with create_session() as session: + remaining = session.scalars( + select(TaskStateModel).where( + TaskStateModel.dag_id == ti.dag_id, + TaskStateModel.task_id == ti.task_id, + ) + ).all() + assert remaining == [] + + +class TestTiSelfEnforcement: + @pytest.fixture + def wrong_ti_client(self): + """TestClient using the real require_auth, JWT bound to a random TI UUID.""" + app = cached_app(apps="execution") + other_ti_id = uuid4() + + async def mock_jwt(request: Request) -> TIToken: + return TIToken(id=other_ti_id, claims=TIClaims(scope="execution")) + + app.dependency_overrides[_jwt_bearer] = mock_jwt + with TestClient(app, headers={"Authorization": "Bearer fake"}) as client: + yield client + app.dependency_overrides.pop(_jwt_bearer, None) + + def test_get_with_wrong_ti_token_returns_403(self, wrong_ti_client: TestClient, create_task_instance): + ti = create_task_instance() + response = wrong_ti_client.get(_api_url(ti.id, "some_key")) + assert response.status_code == 403 + + def test_put_with_wrong_ti_token_returns_403(self, wrong_ti_client: TestClient, create_task_instance): + ti = create_task_instance() + response = wrong_ti_client.put(_api_url(ti.id, "some_key"), json={"value": "x"}) + assert response.status_code == 403 + + def test_clear_with_wrong_ti_token_returns_403(self, wrong_ti_client: TestClient, create_task_instance): + ti = create_task_instance() + response = wrong_ti_client.delete(_api_url(ti.id)) + assert response.status_code == 403 diff --git a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py index 5a8099c1ef3f3..b5b100154c389 100644 --- a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py +++ b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py @@ -63,6 +63,28 @@ class AssetProfile(BaseModel): type: Annotated[str, Field(title="Type")] +class AssetStatePutBody(BaseModel): + """ + Request body for setting an asset state value. + """ + + model_config = ConfigDict( + extra="forbid", + ) + value: Annotated[str, Field(title="Value")] + + +class AssetStateResponse(BaseModel): + """ + Asset state value returned to a worker. + """ + + model_config = ConfigDict( + extra="forbid", + ) + value: Annotated[str, Field(title="Value")] + + class ConnectionResponse(BaseModel): """ Connection schema for responses with fields that are needed for Runtime. @@ -345,6 +367,28 @@ class TaskInstanceState(str, Enum): DEFERRED = "deferred" +class TaskStatePutBody(BaseModel): + """ + Request body for setting a task state value. + """ + + model_config = ConfigDict( + extra="forbid", + ) + value: Annotated[str, Field(title="Value")] + + +class TaskStateResponse(BaseModel): + """ + Task state value returned to a worker. + """ + + model_config = ConfigDict( + extra="forbid", + ) + value: Annotated[str, Field(title="Value")] + + class TaskStatesResponse(BaseModel): """ Response for task states with run_id, task and state.