diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/assets.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/assets.py index 40397e44f43b9..385ec509f9fb1 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/assets.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/assets.py @@ -24,7 +24,7 @@ from airflow.api_fastapi.common.db.common import SessionDep from airflow.api_fastapi.execution_api.datamodels.asset import AssetResponse -from airflow.models.asset import AssetModel +from airflow.models.asset import AssetModel, expand_alias_to_assets router = APIRouter( responses={ @@ -58,6 +58,15 @@ def get_asset_by_uri( return AssetResponse.model_validate(asset) +@router.get("/by-alias") +def get_assets_by_alias( + alias_name: Annotated[str, Query(description="The name of the AssetAlias")], + session: SessionDep, +) -> list[AssetResponse]: + """Get all Airflow Assets resolved from an AssetAlias by `alias_name`.""" + return [AssetResponse.model_validate(a) for a in expand_alias_to_assets(alias_name, session=session)] + + def _raise_if_not_found(asset, msg): if asset is None: raise HTTPException( diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py b/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py index dfa27f53ebd91..e05bd22c2731e 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py @@ -40,7 +40,11 @@ MovePreviousRunEndpoint, RemoveUpstreamMapIndexesField, ) -from airflow.api_fastapi.execution_api.versions.v2026_04_17 import AddStateEndpoints, AddTeamNameField +from airflow.api_fastapi.execution_api.versions.v2026_04_17 import ( + AddAssetsByAliasEndpoint, + AddStateEndpoints, + AddTeamNameField, +) from airflow.api_fastapi.execution_api.versions.v2026_06_16 import AddRetryPolicyFields bundle = VersionBundle( @@ -50,6 +54,7 @@ "2026-04-17", 
AddTeamNameField, AddStateEndpoints, + AddAssetsByAliasEndpoint, ), Version( "2026-04-06", diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_04_17.py b/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_04_17.py index dd45b11825d80..41e63cf858cbd 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_04_17.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_04_17.py @@ -36,6 +36,14 @@ def remove_team_name_field(response: ResponseInfo) -> None: # type: ignore[misc response.body["dag_run"].pop("team_name", None) +class AddAssetsByAliasEndpoint(VersionChange): + """Add endpoint to resolve assets from an AssetAlias.""" + + description = __doc__ + + instructions_to_migrate_to_previous_version = (endpoint("/assets/by-alias", ["GET"]).didnt_exist,) + + class AddStateEndpoints(VersionChange): """Add task state and asset state CRUD endpoints.""" diff --git a/airflow-core/tests/unit/dag_processing/test_processor.py b/airflow-core/tests/unit/dag_processing/test_processor.py index be5005e4b235c..d2ac085b0b736 100644 --- a/airflow-core/tests/unit/dag_processing/test_processor.py +++ b/airflow-core/tests/unit/dag_processing/test_processor.py @@ -1947,6 +1947,7 @@ def get_type_names(union_type): "DeleteXCom", "GetAssetByName", "GetAssetByUri", + "GetAssetsByAlias", "GetAssetEventByAsset", "GetAssetEventByAssetAlias", "GetDagRun", @@ -1971,10 +1972,24 @@ def get_type_names(union_type): "UpdateHITLDetail", "GetHITLDetailResponse", "SetRenderedMapIndex", + # AIP-103 task/asset state — Dag processor has no task execution context. 
+ "GetTaskState", + "SetTaskState", + "DeleteTaskState", + "ClearTaskState", + "GetAssetStateByName", + "GetAssetStateByUri", + "SetAssetStateByName", + "SetAssetStateByUri", + "DeleteAssetStateByName", + "DeleteAssetStateByUri", + "ClearAssetStateByName", + "ClearAssetStateByUri", } in_task_runner_but_not_in_dag_processing_process = { "AssetResult", + "AssetsByAliasResult", "AssetEventsResult", "DagResult", "DagRunResult", @@ -1989,6 +2004,9 @@ def get_type_names(union_type): "InactiveAssetsResult", "CreateHITLDetailPayload", "HITLDetailRequestResult", + # AIP-103 task/asset state results — worker-only responses to the above messages. + "TaskStateResult", + "AssetStateResult", } supervisor_diff = supervisor_types - manager_types - in_supervisor_but_not_in_manager diff --git a/airflow-core/tests/unit/jobs/test_triggerer_job.py b/airflow-core/tests/unit/jobs/test_triggerer_job.py index 752247c6d1752..968858be3361e 100644 --- a/airflow-core/tests/unit/jobs/test_triggerer_job.py +++ b/airflow-core/tests/unit/jobs/test_triggerer_job.py @@ -1758,6 +1758,7 @@ def get_type_names(union_type): "DeferTask", "GetAssetByName", "GetAssetByUri", + "GetAssetsByAlias", "GetAssetEventByAsset", "GetAssetEventByAssetAlias", "GetDagRun", @@ -1780,10 +1781,24 @@ def get_type_names(union_type): "CreateHITLDetailPayload", "SetRenderedMapIndex", "GetDag", + # AIP-103 task/asset state — triggerer has no task execution context. 
+ "GetTaskState", + "SetTaskState", + "DeleteTaskState", + "ClearTaskState", + "GetAssetStateByName", + "GetAssetStateByUri", + "SetAssetStateByName", + "SetAssetStateByUri", + "DeleteAssetStateByName", + "DeleteAssetStateByUri", + "ClearAssetStateByName", + "ClearAssetStateByUri", } in_task_but_not_in_trigger_runner = { "AssetResult", + "AssetsByAliasResult", "AssetEventsResult", "DagRunResult", "SentFDs", @@ -1800,6 +1815,9 @@ def get_type_names(union_type): "PreviousTIResult", "HITLDetailRequestResult", "DagResult", + # AIP-103 task/asset state results — worker-only responses to the above messages. + "TaskStateResult", + "AssetStateResult", } supervisor_diff = ( diff --git a/providers/standard/tests/unit/standard/operators/test_python.py b/providers/standard/tests/unit/standard/operators/test_python.py index 437ae89ac26ff..8142157af47d6 100644 --- a/providers/standard/tests/unit/standard/operators/test_python.py +++ b/providers/standard/tests/unit/standard/operators/test_python.py @@ -71,6 +71,7 @@ AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_1_PLUS, AIRFLOW_V_3_2_PLUS, + AIRFLOW_V_3_3_PLUS, NOTSET, ) @@ -1131,6 +1132,11 @@ def test_virtualenv_serializable_context_fields(self, create_task_instance): "inlet_events", "outlet_events", } + if AIRFLOW_V_3_3_PLUS: + # AIP-103: task_state is a live accessor backed by the supervisor pipe — + # not serializable and meaningless in a virtualenv subprocess. + # asset_state is excluded via its absence: only present when a task has inlets. 
+ intentionally_excluded_context_keys.add("task_state") ti = create_task_instance(dag_id=self.dag_id, task_id=self.task_id, schedule=None) context = ti.get_template_context() diff --git a/scripts/ci/prek/check_template_context_variable_in_sync.py b/scripts/ci/prek/check_template_context_variable_in_sync.py index 96d661d5aab18..26508dd5d29c4 100755 --- a/scripts/ci/prek/check_template_context_variable_in_sync.py +++ b/scripts/ci/prek/check_template_context_variable_in_sync.py @@ -48,6 +48,9 @@ "data_interval_start", "prev_data_interval_start_success", "prev_data_interval_end_success", + # AIP-103: task_state/asset_state aren't documented yet. Will be done in a later PR. + "task_state", + "asset_state", } diff --git a/task-sdk/src/airflow/sdk/api/client.py b/task-sdk/src/airflow/sdk/api/client.py index 15513010062e9..493225b4699bd 100644 --- a/task-sdk/src/airflow/sdk/api/client.py +++ b/task-sdk/src/airflow/sdk/api/client.py @@ -46,6 +46,8 @@ API_VERSION, AssetEventsResponse, AssetResponse, + AssetStatePutBody, + AssetStateResponse, ConnectionResponse, DagResponse, DagRun, @@ -58,6 +60,8 @@ PrevSuccessfulDagRunResponse, TaskBreadcrumbsResponse, TaskInstanceState, + TaskStatePutBody, + TaskStateResponse, TaskStatesResponse, TerminalStateNonSuccess, TIDeferredStatePayload, @@ -80,6 +84,7 @@ from airflow.sdk.configuration import conf from airflow.sdk.exceptions import ErrorType, TaskAlreadyRunningError from airflow.sdk.execution_time.comms import ( + AssetsByAliasResult, CreateHITLDetailPayload, DRCount, ErrorResponse, @@ -662,6 +667,95 @@ def get_sequence_slice( return XComSequenceSliceResponse.model_validate_json(resp.read()) +class TaskStateOperations: + __slots__ = ("client",) + + def __init__(self, client: Client): + self.client = client + + def get(self, ti_id: uuid.UUID, key: str) -> TaskStateResponse | ErrorResponse: + """Get a task state value from the API server.""" + try: + resp = self.client.get(f"state/ti/{ti_id}/{key}") + except ServerResponseError as e: 
+ if e.response.status_code == HTTPStatus.NOT_FOUND: + log.debug("Task state key not found", ti_id=ti_id, key=key) + return ErrorResponse(error=ErrorType.TASK_STATE_NOT_FOUND, detail={"key": key}) + raise + return TaskStateResponse.model_validate_json(resp.read()) + + def set(self, ti_id: uuid.UUID, key: str, value: str) -> OKResponse: + """Set a task state value via the API server.""" + body = TaskStatePutBody(value=value) + self.client.put(f"state/ti/{ti_id}/{key}", content=body.model_dump_json()) + return OKResponse(ok=True) + + def delete(self, ti_id: uuid.UUID, key: str) -> OKResponse: + """Delete a single task state key via the API server.""" + self.client.delete(f"state/ti/{ti_id}/{key}") + return OKResponse(ok=True) + + def clear(self, ti_id: uuid.UUID, all_map_indices: bool = False) -> OKResponse: + """Clear all task state keys for a task instance via the API server.""" + params = {"all_map_indices": "true"} if all_map_indices else {} + self.client.delete(f"state/ti/{ti_id}", params=params) + return OKResponse(ok=True) + + +class AssetStateOperations: + __slots__ = ("client",) + + def __init__(self, client: Client): + self.client = client + + def _resolve_endpoint( + self, op: str, *, key: str | None = None, name: str | None = None, uri: str | None = None + ) -> tuple[str, dict[str, str]]: + if name: + params: dict[str, str] = {"name": name} + endpoint = f"state/asset/by-name/{op}" + elif uri: + params = {"uri": uri} + endpoint = f"state/asset/by-uri/{op}" + else: + raise ValueError("Either `name` or `uri` must be provided") + if key is not None: + params["key"] = key + return endpoint, params + + def get( + self, key: str, *, name: str | None = None, uri: str | None = None + ) -> AssetStateResponse | ErrorResponse: + """Get an asset state value from the API server.""" + endpoint, params = self._resolve_endpoint("value", key=key, name=name, uri=uri) + try: + resp = self.client.get(endpoint, params=params) + except ServerResponseError as e: + if 
e.response.status_code == HTTPStatus.NOT_FOUND: + log.debug("Asset state key not found", name=name, uri=uri, key=key) + return ErrorResponse(error=ErrorType.ASSET_STATE_NOT_FOUND, detail={"key": key}) + raise + return AssetStateResponse.model_validate_json(resp.read()) + + def set(self, key: str, value: str, *, name: str | None = None, uri: str | None = None) -> OKResponse: + """Set an asset state value via the API server.""" + endpoint, params = self._resolve_endpoint("value", key=key, name=name, uri=uri) + self.client.put(endpoint, params=params, content=AssetStatePutBody(value=value).model_dump_json()) + return OKResponse(ok=True) + + def delete(self, key: str, *, name: str | None = None, uri: str | None = None) -> OKResponse: + """Delete a single asset state key via the API server.""" + endpoint, params = self._resolve_endpoint("value", key=key, name=name, uri=uri) + self.client.delete(endpoint, params=params) + return OKResponse(ok=True) + + def clear(self, *, name: str | None = None, uri: str | None = None) -> OKResponse: + """Clear all state keys for an asset via the API server.""" + endpoint, params = self._resolve_endpoint("clear", name=name, uri=uri) + self.client.delete(endpoint, params=params) + return OKResponse(ok=True) + + class AssetOperations: __slots__ = ("client",) @@ -694,6 +788,13 @@ def get(self, name: str | None = None, uri: str | None = None) -> AssetResponse return AssetResponse.model_validate_json(resp.read()) + def get_by_alias(self, alias_name: str) -> AssetsByAliasResult: + """Get all Assets resolved from an AssetAlias.""" + resp = self.client.get("assets/by-alias", params={"alias_name": alias_name}) + return AssetsByAliasResult.from_asset_responses( + [AssetResponse.model_validate(a) for a in resp.json()] + ) + class AssetEventOperations: __slots__ = ("client",) @@ -1078,6 +1179,18 @@ def asset_events(self) -> AssetEventOperations: """Operations related to Asset Events.""" return AssetEventOperations(self) + @lru_cache() # type: 
ignore[misc] + @property + def task_state(self) -> TaskStateOperations: + """Operations related to task state.""" + return TaskStateOperations(self) + + @lru_cache() # type: ignore[misc] + @property + def asset_state(self) -> AssetStateOperations: + """Operations related to asset state.""" + return AssetStateOperations(self) + @lru_cache() # type: ignore[misc] @property def hitl(self): diff --git a/task-sdk/src/airflow/sdk/definitions/context.py b/task-sdk/src/airflow/sdk/definitions/context.py index b7a63284608aa..c422c3462982a 100644 --- a/task-sdk/src/airflow/sdk/definitions/context.py +++ b/task-sdk/src/airflow/sdk/definitions/context.py @@ -30,7 +30,11 @@ from airflow.sdk.bases.operator import BaseOperator from airflow.sdk.definitions.dag import DAG - from airflow.sdk.execution_time.context import InletEventsAccessors + from airflow.sdk.execution_time.context import ( + AssetStateAccessors, + InletEventsAccessors, + TaskStateAccessor, + ) from airflow.sdk.types import ( DagRunProtocol, Operator, @@ -72,6 +76,8 @@ class Context(TypedDict, total=False): task_reschedule_count: int task_instance: RuntimeTaskInstanceProtocol task_instance_key_str: str + task_state: TaskStateAccessor + asset_state: AssetStateAccessors # `templates_dict` is only set in PythonOperator templates_dict: NotRequired[dict[str, Any] | None] test_mode: bool diff --git a/task-sdk/src/airflow/sdk/exceptions.py b/task-sdk/src/airflow/sdk/exceptions.py index 7d42dad5d8502..ed3bb3f14938e 100644 --- a/task-sdk/src/airflow/sdk/exceptions.py +++ b/task-sdk/src/airflow/sdk/exceptions.py @@ -80,6 +80,8 @@ class ErrorType(enum.Enum): VARIABLE_NOT_FOUND = "VARIABLE_NOT_FOUND" XCOM_NOT_FOUND = "XCOM_NOT_FOUND" ASSET_NOT_FOUND = "ASSET_NOT_FOUND" + TASK_STATE_NOT_FOUND = "TASK_STATE_NOT_FOUND" + ASSET_STATE_NOT_FOUND = "ASSET_STATE_NOT_FOUND" DAGRUN_ALREADY_EXISTS = "DAGRUN_ALREADY_EXISTS" GENERIC_ERROR = "GENERIC_ERROR" API_SERVER_ERROR = "API_SERVER_ERROR" diff --git 
a/task-sdk/src/airflow/sdk/execution_time/comms.py b/task-sdk/src/airflow/sdk/execution_time/comms.py index 1e11e9636e56f..01528c728b1fe 100644 --- a/task-sdk/src/airflow/sdk/execution_time/comms.py +++ b/task-sdk/src/airflow/sdk/execution_time/comms.py @@ -69,6 +69,7 @@ AssetEventResponse, AssetEventsResponse, AssetResponse, + AssetStateResponse, BundleInfo, ConnectionResponse, DagResponse, @@ -81,6 +82,7 @@ TaskBreadcrumbsResponse, TaskInstance, TaskInstanceState, + TaskStateResponse, TaskStatesResponse, TIDeferredStatePayload, TIRescheduleStatePayload, @@ -561,6 +563,40 @@ def from_variable_response(cls, variable_response: VariableResponse) -> Variable return cls(**variable_response.model_dump(exclude_defaults=True), type="VariableResult") +class TaskStateResult(TaskStateResponse): + """Response to GetTaskState; wraps the generated API response for supervisor to worker comms.""" + + type: Literal["TaskStateResult"] = "TaskStateResult" + + @classmethod + def from_task_state_response(cls, resp: TaskStateResponse) -> TaskStateResult: + return cls(**resp.model_dump(exclude_defaults=True), type="TaskStateResult") + + +class AssetStateResult(AssetStateResponse): + """Response to GetAssetState; wraps the generated API response for supervisor to worker comms.""" + + type: Literal["AssetStateResult"] = "AssetStateResult" + + @classmethod + def from_asset_state_response(cls, resp: AssetStateResponse) -> AssetStateResult: + return cls(**resp.model_dump(exclude_defaults=True), type="AssetStateResult") + + +class AssetsByAliasResult(BaseModel): + """Response to GetAssetsByAlias; list of concrete assets resolved from an alias.""" + + assets: list[AssetResult] + type: Literal["AssetsByAliasResult"] = "AssetsByAliasResult" + + @classmethod + def from_asset_responses(cls, asset_responses: list[AssetResponse]) -> AssetsByAliasResult: + return cls( + assets=[AssetResult.from_asset_response(a) for a in asset_responses], + type="AssetsByAliasResult", + ) + + class 
DagRunResult(DagRun): type: Literal["DagRunResult"] = "DagRunResult" @@ -728,7 +764,9 @@ def from_api_response(cls, dag_response: DagResponse) -> DagResult: ToTask = Annotated[ AssetResult + | AssetsByAliasResult | AssetEventsResult + | AssetStateResult | ConnectionResult | DagRunResult | DagRunStateResult @@ -740,6 +778,7 @@ def from_api_response(cls, dag_response: DagResponse) -> DagResult: | SentFDs | StartupDetails | TaskRescheduleStartDate + | TaskStateResult | TICount | TaskBreadcrumbsResult | TaskStatesResult @@ -868,6 +907,79 @@ class DeleteXCom(BaseModel): type: Literal["DeleteXCom"] = "DeleteXCom" +class GetTaskState(BaseModel): + ti_id: UUID + key: str + type: Literal["GetTaskState"] = "GetTaskState" + + +class SetTaskState(BaseModel): + ti_id: UUID + key: str + value: str + type: Literal["SetTaskState"] = "SetTaskState" + + +class DeleteTaskState(BaseModel): + ti_id: UUID + key: str + type: Literal["DeleteTaskState"] = "DeleteTaskState" + + +class ClearTaskState(BaseModel): + ti_id: UUID + all_map_indices: bool = False + type: Literal["ClearTaskState"] = "ClearTaskState" + + +class GetAssetStateByName(BaseModel): + name: str + key: str + type: Literal["GetAssetStateByName"] = "GetAssetStateByName" + + +class GetAssetStateByUri(BaseModel): + uri: str + key: str + type: Literal["GetAssetStateByUri"] = "GetAssetStateByUri" + + +class SetAssetStateByName(BaseModel): + name: str + key: str + value: str + type: Literal["SetAssetStateByName"] = "SetAssetStateByName" + + +class SetAssetStateByUri(BaseModel): + uri: str + key: str + value: str + type: Literal["SetAssetStateByUri"] = "SetAssetStateByUri" + + +class DeleteAssetStateByName(BaseModel): + name: str + key: str + type: Literal["DeleteAssetStateByName"] = "DeleteAssetStateByName" + + +class DeleteAssetStateByUri(BaseModel): + uri: str + key: str + type: Literal["DeleteAssetStateByUri"] = "DeleteAssetStateByUri" + + +class ClearAssetStateByName(BaseModel): + name: str + type: 
Literal["ClearAssetStateByName"] = "ClearAssetStateByName" + + +class ClearAssetStateByUri(BaseModel): + uri: str + type: Literal["ClearAssetStateByUri"] = "ClearAssetStateByUri" + + class GetConnection(BaseModel): conn_id: str type: Literal["GetConnection"] = "GetConnection" @@ -957,6 +1069,11 @@ class GetAssetByUri(BaseModel): type: Literal["GetAssetByUri"] = "GetAssetByUri" +class GetAssetsByAlias(BaseModel): + alias_name: str + type: Literal["GetAssetsByAlias"] = "GetAssetsByAlias" + + class GetAssetEventByAsset(BaseModel): name: str | None uri: str | None @@ -1058,12 +1175,21 @@ class GetDag(BaseModel): ToSupervisor = Annotated[ - DeferTask + ClearAssetStateByName + | ClearAssetStateByUri + | ClearTaskState + | DeferTask + | DeleteAssetStateByName + | DeleteAssetStateByUri + | DeleteTaskState | DeleteXCom | GetAssetByName | GetAssetByUri + | GetAssetsByAlias | GetAssetEventByAsset | GetAssetEventByAssetAlias + | GetAssetStateByName + | GetAssetStateByUri | GetConnection | GetDagRun | GetDagRunState @@ -1073,6 +1199,7 @@ class GetDag(BaseModel): | GetPreviousDagRun | GetPreviousTI | GetTaskRescheduleStartDate + | GetTaskState | GetTICount | GetTaskBreadcrumbs | GetTaskStates @@ -1084,8 +1211,11 @@ class GetDag(BaseModel): | PutVariable | RescheduleTask | RetryTask + | SetAssetStateByName + | SetAssetStateByUri | SetRenderedFields | SetRenderedMapIndex + | SetTaskState | SetXCom | SkipDownstreamTasks | SucceedTask diff --git a/task-sdk/src/airflow/sdk/execution_time/context.py b/task-sdk/src/airflow/sdk/execution_time/context.py index 66c1f3aa8b7eb..7fdd8bddc7fc7 100644 --- a/task-sdk/src/airflow/sdk/execution_time/context.py +++ b/task-sdk/src/airflow/sdk/execution_time/context.py @@ -24,6 +24,7 @@ from datetime import datetime from functools import cache from typing import TYPE_CHECKING, Any, Generic, TypeVar, overload +from uuid import UUID import attrs import structlog @@ -45,8 +46,6 @@ from airflow.sdk.log import mask_secret if TYPE_CHECKING: - from uuid 
import UUID - from pydantic.types import JsonValue from typing_extensions import Self @@ -406,6 +405,229 @@ def get(self, key, default: Any = NOTSET) -> Any: raise +class TaskStateAccessor: + """Accessor for task state scoped to the current task instance. Available as ``context['task_state']`` at task execution time.""" + + def __init__(self, ti_id: UUID) -> None: + self._ti_id = ti_id + + def __eq__(self, other: object) -> bool: + if not isinstance(other, TaskStateAccessor): + return False + return self._ti_id == other._ti_id + + def __hash__(self) -> int: + return hash(self._ti_id) + + def __repr__(self) -> str: + return f"<TaskStateAccessor ti_id={self._ti_id}>" + + # TODO: ``__getattr__`` for jinja template access like ``{{ task_state.job_id }}`` + # is not implemented yet because it's unclear whether task state values will be + # used in templates. + + def get(self, key: str) -> str | None: + """Return the stored value, or ``None`` if the key does not exist.""" + from airflow.sdk.execution_time.comms import ErrorResponse, GetTaskState, TaskStateResult + from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + + resp = SUPERVISOR_COMMS.send(GetTaskState(ti_id=self._ti_id, key=key)) + if isinstance(resp, ErrorResponse) and resp.error != ErrorType.TASK_STATE_NOT_FOUND: + raise AirflowRuntimeError(resp) + if isinstance(resp, TaskStateResult): + return resp.value + return None + + def set(self, key: str, value: str) -> None: + """Write or overwrite the value for the given key.""" + from airflow.sdk.execution_time.comms import SetTaskState + from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + + SUPERVISOR_COMMS.send(SetTaskState(ti_id=self._ti_id, key=key, value=value)) + + def delete(self, key: str) -> None: + """Delete a single key. 
No-op if the key does not exist.""" + from airflow.sdk.execution_time.comms import DeleteTaskState + from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + + SUPERVISOR_COMMS.send(DeleteTaskState(ti_id=self._ti_id, key=key)) + + def clear(self, all_map_indices: bool = False) -> None: + """ + Delete all keys for this task instance. + + Pass ``all_map_indices=True`` to wipe state across every mapped + instance of the task (fleet-wide reset). Defaults to clearing only + this task instance's own state. + """ + from airflow.sdk.execution_time.comms import ClearTaskState + from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + + SUPERVISOR_COMMS.send(ClearTaskState(ti_id=self._ti_id, all_map_indices=all_map_indices)) + + +class AssetStateAccessor: + """ + Accessor for asset state scoped to a single asset. + + Obtained via ``context['asset_state'][MY_ASSET]`` or, as sugar for single-inlet + tasks, directly as ``context['asset_state']``. + """ + + def __init__(self, *, name: str | None = None, uri: str | None = None) -> None: + if not name and not uri: + raise ValueError("Either `name` or `uri` must be provided") + self._name = name + self._uri = uri + + def __eq__(self, other: object) -> bool: + if not isinstance(other, AssetStateAccessor): + return False + return self._name == other._name and self._uri == other._uri + + def __hash__(self) -> int: + return hash((self._name, self._uri)) + + def __repr__(self) -> str: + if self._name is not None: + return f"<AssetStateAccessor name={self._name!r}>" + return f"<AssetStateAccessor uri={self._uri!r}>" + + def get(self, key: str) -> str | None: + """Return the stored value, or ``None`` if the key does not exist.""" + from airflow.sdk.execution_time.comms import ( + AssetStateResult, + ErrorResponse, + GetAssetStateByName, + GetAssetStateByUri, + ToSupervisor, + ) + from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + + msg: ToSupervisor + if self._name: + msg = GetAssetStateByName(name=self._name, key=key) + elif self._uri: + msg = 
GetAssetStateByUri(uri=self._uri, key=key) + resp = SUPERVISOR_COMMS.send(msg) + if isinstance(resp, ErrorResponse) and resp.error != ErrorType.ASSET_STATE_NOT_FOUND: + raise AirflowRuntimeError(resp) + if isinstance(resp, AssetStateResult): + return resp.value + return None + + def set(self, key: str, value: str) -> None: + """Write or overwrite the value for the given key.""" + from airflow.sdk.execution_time.comms import SetAssetStateByName, SetAssetStateByUri, ToSupervisor + from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + + msg: ToSupervisor + if self._name: + msg = SetAssetStateByName(name=self._name, key=key, value=value) + elif self._uri: + msg = SetAssetStateByUri(uri=self._uri, key=key, value=value) + SUPERVISOR_COMMS.send(msg) + + def delete(self, key: str) -> None: + """Delete a single key. No-op if the key does not exist.""" + from airflow.sdk.execution_time.comms import ( + DeleteAssetStateByName, + DeleteAssetStateByUri, + ToSupervisor, + ) + from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + + msg: ToSupervisor + if self._name: + msg = DeleteAssetStateByName(name=self._name, key=key) + elif self._uri: + msg = DeleteAssetStateByUri(uri=self._uri, key=key) + SUPERVISOR_COMMS.send(msg) + + def clear(self) -> None: + """Delete all state keys for this asset.""" + from airflow.sdk.execution_time.comms import ClearAssetStateByName, ClearAssetStateByUri, ToSupervisor + from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + + msg: ToSupervisor + if self._name: + msg = ClearAssetStateByName(name=self._name) + elif self._uri: + msg = ClearAssetStateByUri(uri=self._uri) + SUPERVISOR_COMMS.send(msg) + + +class AssetStateAccessors: + """ + Mapping of asset state accessors for all concrete inlets of a task. + + Available as ``context['asset_state']``. Subscript by asset to get a per asset + accessor as: ``context['asset_state'][MY_ASSET].get('watermark')``. 
+ + For tasks with exactly one concrete inlet, the accessor methods (``get``, ``set``, + ``delete``, ``clear``) can be called directly without subscripting. + """ + + def __init__(self, inlets: list) -> None: + self._by_name: dict[str, AssetStateAccessor] = {} + self._by_uri: dict[str, AssetStateAccessor] = {} + + for inlet in inlets: + if isinstance(inlet, (Asset, AssetNameRef)): + self._by_name[inlet.name] = AssetStateAccessor(name=inlet.name) + elif isinstance(inlet, AssetUriRef): + self._by_uri[inlet.uri] = AssetStateAccessor(uri=inlet.uri) + elif isinstance(inlet, AssetAlias): + from airflow.sdk.execution_time.comms import AssetsByAliasResult, GetAssetsByAlias + from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + + resp = SUPERVISOR_COMMS.send(GetAssetsByAlias(alias_name=inlet.name)) + if isinstance(resp, AssetsByAliasResult): + for asset in resp.assets: + self._by_name[asset.name] = AssetStateAccessor(name=asset.name) + + self._total = len(self._by_name) + len(self._by_uri) + + def __getitem__(self, key: Asset | AssetNameRef | AssetUriRef) -> AssetStateAccessor: + try: + if isinstance(key, (Asset, AssetNameRef)): + return self._by_name[key.name] + if isinstance(key, AssetUriRef): + return self._by_uri[key.uri] + except KeyError: + raise KeyError(f"{key!r} is not in this task's inlets") + raise TypeError(f"Expected Asset, AssetNameRef, or AssetUriRef; got {type(key).__name__}") + + def _single_accessor(self) -> AssetStateAccessor: + if self._total != 1: + raise ValueError( + f"Task has {self._total} concrete inlets — use context['asset_state'][MY_ASSET] to specify which" + ) + if self._by_name: + return next(iter(self._by_name.values())) + return next(iter(self._by_uri.values())) + + def get(self, key: str) -> str | None: + """Return the stored value for the single-inlet task, or ``None`` if not found.""" + return self._single_accessor().get(key) + + def set(self, key: str, value: str) -> None: + """Write or overwrite the value for the 
single-inlet task.""" + self._single_accessor().set(key, value) + + def delete(self, key: str) -> None: + """Delete a single key for the single-inlet task.""" + self._single_accessor().delete(key) + + def clear(self) -> None: + """Delete all state keys for the single-inlet task.""" + self._single_accessor().clear() + + def __repr__(self) -> str: + parts = [f"name={k!r}" for k in self._by_name] + [f"uri={k!r}" for k in self._by_uri] + return f"<AssetStateAccessors {', '.join(parts)}>" + + class MacrosAccessor: """Wrapper to access Macros module lazily.""" diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py index cd25d9279571b..b8f2eeff04510 100644 --- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py @@ -64,12 +64,19 @@ from airflow.sdk.execution_time.comms import ( AssetEventsResult, AssetResult, + AssetStateResult, + ClearAssetStateByName, + ClearAssetStateByUri, + ClearTaskState, ConnectionResult, CreateHITLDetailPayload, DagResult, DagRunResult, DagRunStateResult, DeferTask, + DeleteAssetStateByName, + DeleteAssetStateByUri, + DeleteTaskState, DeleteVariable, DeleteXCom, ErrorResponse, @@ -77,6 +84,9 @@ GetAssetByUri, GetAssetEventByAsset, GetAssetEventByAssetAlias, + GetAssetsByAlias, + GetAssetStateByName, + GetAssetStateByUri, GetConnection, GetDag, GetDagRun, @@ -87,6 +97,7 @@ GetPrevSuccessfulDagRun, GetTaskBreadcrumbs, GetTaskRescheduleStartDate, + GetTaskState, GetTaskStates, GetTICount, GetVariable, @@ -97,20 +108,25 @@ HITLDetailRequestResult, InactiveAssetsResult, MaskSecret, + OKResponse, PrevSuccessfulDagRunResult, PutVariable, RescheduleTask, ResendLoggingFD, RetryTask, SentFDs, + SetAssetStateByName, + SetAssetStateByUri, SetRenderedFields, SetRenderedMapIndex, + SetTaskState, SetXCom, SkipDownstreamTasks, StartupDetails, SucceedTask, TaskBreadcrumbsResult, TaskState, + TaskStateResult, TaskStatesResult, ToSupervisor, TriggerDagRun, @@ -1509,6 +1525,8 
@@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger, req_id: dump_opts = {"exclude_unset": True} else: resp = asset_resp + elif isinstance(msg, GetAssetsByAlias): + resp = self.client.assets.get_by_alias(alias_name=msg.alias_name) elif isinstance(msg, GetAssetEventByAsset): asset_event_resp = self.client.asset_events.get( uri=msg.uri, @@ -1628,6 +1646,54 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger, req_id: dag_id=msg.dag_id, ) resp = DagResult.from_api_response(dag) + elif isinstance(msg, GetTaskState): + task_state = self.client.task_state.get(msg.ti_id, msg.key) + resp = ( + task_state + if isinstance(task_state, ErrorResponse) + else TaskStateResult.from_task_state_response(task_state) + ) + elif isinstance(msg, SetTaskState): + self.client.task_state.set(msg.ti_id, msg.key, msg.value) + resp = OKResponse(ok=True) + elif isinstance(msg, DeleteTaskState): + self.client.task_state.delete(msg.ti_id, msg.key) + resp = OKResponse(ok=True) + elif isinstance(msg, ClearTaskState): + self.client.task_state.clear(msg.ti_id, all_map_indices=msg.all_map_indices) + resp = OKResponse(ok=True) + elif isinstance(msg, GetAssetStateByName): + asset_state = self.client.asset_state.get(msg.key, name=msg.name) + resp = ( + asset_state + if isinstance(asset_state, ErrorResponse) + else AssetStateResult.from_asset_state_response(asset_state) + ) + elif isinstance(msg, GetAssetStateByUri): + asset_state = self.client.asset_state.get(msg.key, uri=msg.uri) + resp = ( + asset_state + if isinstance(asset_state, ErrorResponse) + else AssetStateResult.from_asset_state_response(asset_state) + ) + elif isinstance(msg, SetAssetStateByName): + self.client.asset_state.set(msg.key, msg.value, name=msg.name) + resp = OKResponse(ok=True) + elif isinstance(msg, SetAssetStateByUri): + self.client.asset_state.set(msg.key, msg.value, uri=msg.uri) + resp = OKResponse(ok=True) + elif isinstance(msg, DeleteAssetStateByName): + 
self.client.asset_state.delete(msg.key, name=msg.name) + resp = OKResponse(ok=True) + elif isinstance(msg, DeleteAssetStateByUri): + self.client.asset_state.delete(msg.key, uri=msg.uri) + resp = OKResponse(ok=True) + elif isinstance(msg, ClearAssetStateByName): + self.client.asset_state.clear(name=msg.name) + resp = OKResponse(ok=True) + elif isinstance(msg, ClearAssetStateByUri): + self.client.asset_state.clear(uri=msg.uri) + resp = OKResponse(ok=True) else: log.error("Unhandled request", msg=msg) self.send_msg( diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index 56ba8343c648b..7c318fc499ed6 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -111,10 +111,12 @@ ValidateInletsAndOutlets, ) from airflow.sdk.execution_time.context import ( + AssetStateAccessors, ConnectionAccessor, InletEventsAccessors, MacrosAccessor, OutletEventAccessors, + TaskStateAccessor, TriggeringAssetEventsAccessor, VariableAccessor, context_get_outlet_events, @@ -249,7 +251,12 @@ def get_template_context(self) -> Context: "value": VariableAccessor(deserialize_json=False), }, "conn": ConnectionAccessor(), + "task_state": TaskStateAccessor(ti_id=self.id), } + if any(isinstance(i, (Asset, AssetNameRef, AssetUriRef, AssetAlias)) for i in self.task.inlets): + self._cached_template_context["asset_state"] = AssetStateAccessors(self.task.inlets) + # AssetAlias inlets are resolved to their concrete assets at context build time + # via GetAssetsByAlias comms. If an alias maps to no active assets, it doesn't contribute to asset_state.
if TYPE_CHECKING: assert self._cached_template_context is not None if from_server: diff --git a/task-sdk/tests/task_sdk/api/test_client.py b/task-sdk/tests/task_sdk/api/test_client.py index 26e2a7e66bb6e..a179ff08436b2 100644 --- a/task-sdk/tests/task_sdk/api/test_client.py +++ b/task-sdk/tests/task_sdk/api/test_client.py @@ -36,6 +36,7 @@ from airflow.sdk.api.datamodels._generated import ( AssetEventsResponse, AssetResponse, + AssetStateResponse, ConnectionResponse, DagResponse, DagRunState, @@ -43,12 +44,14 @@ HITLDetailRequest, HITLDetailResponse, HITLUser, + TaskStateResponse, TerminalTIState, VariableResponse, XComResponse, ) from airflow.sdk.exceptions import ErrorType, TaskAlreadyRunningError from airflow.sdk.execution_time.comms import ( + AssetsByAliasResult, DeferTask, ErrorResponse, OKResponse, @@ -1221,6 +1224,37 @@ def handle_request(request: httpx.Request) -> httpx.Response: assert isinstance(result, ErrorResponse) assert result.error == ErrorType.ASSET_NOT_FOUND + def test_get_by_alias_returns_list(self): + def handle_request(request: httpx.Request) -> httpx.Response: + assert request.url.path == "/assets/by-alias" + assert request.url.params["alias_name"] == "my_alias" + return httpx.Response( + status_code=200, + json=[ + {"name": "asset_a", "uri": "s3://bucket/a", "group": "asset", "extra": {}}, + {"name": "asset_b", "uri": "s3://bucket/b", "group": "asset", "extra": {}}, + ], + ) + + client = make_client(transport=httpx.MockTransport(handle_request)) + result = client.assets.get_by_alias("my_alias") + + assert isinstance(result, AssetsByAliasResult) + assert len(result.assets) == 2 + assert isinstance(result.assets[0], AssetResponse) + assert isinstance(result.assets[1], AssetResponse) + assert result.assets[0].name == "asset_a" + assert result.assets[1].name == "asset_b" + + def test_get_by_alias_returns_empty_for_unknown_alias(self): + def handle_request(request: httpx.Request) -> httpx.Response: + return httpx.Response(status_code=200, 
json=[]) + + client = make_client(transport=httpx.MockTransport(handle_request)) + result = client.assets.get_by_alias("unknown_alias") + + assert result.assets == [] + class TestDagRunOperations: def test_trigger(self): @@ -1700,3 +1734,189 @@ def handle_request(request: httpx.Request) -> httpx.Response: with pytest.raises(ServerResponseError): client.dags.get(dag_id="test_dag") + + +class TestTaskStateOperations: + TI_ID = uuid7() + + def test_get_success(self): + def handle_request(request: httpx.Request) -> httpx.Response: + if request.url.path == f"/state/ti/{self.TI_ID}/job_id": + return httpx.Response(status_code=200, json={"value": "spark_app_001"}) + return httpx.Response(status_code=400) + + client = make_client(transport=httpx.MockTransport(handle_request)) + result = client.task_state.get(ti_id=self.TI_ID, key="job_id") + + assert isinstance(result, TaskStateResponse) + assert result.value == "spark_app_001" + + def test_get_returns_error_response_on_404(self): + def handle_request(request: httpx.Request) -> httpx.Response: + return httpx.Response( + status_code=404, + json={"detail": {"reason": "not_found", "message": "Task state key 'job_id' not found"}}, + ) + + client = make_client(transport=httpx.MockTransport(handle_request)) + result = client.task_state.get(ti_id=self.TI_ID, key="job_id") + assert isinstance(result, ErrorResponse) + assert result.error == ErrorType.TASK_STATE_NOT_FOUND + + def test_set_success(self): + def handle_request(request: httpx.Request) -> httpx.Response: + assert request.method == "PUT" + assert request.url.path == f"/state/ti/{self.TI_ID}/job_id" + assert b'"value":"spark_app_001"' in request.content + return httpx.Response(status_code=204) + + client = make_client(transport=httpx.MockTransport(handle_request)) + result = client.task_state.set(ti_id=self.TI_ID, key="job_id", value="spark_app_001") + assert result == OKResponse(ok=True) + + def test_delete_success(self): + def handle_request(request: httpx.Request) -> 
httpx.Response: + assert request.method == "DELETE" + assert request.url.path == f"/state/ti/{self.TI_ID}/job_id" + return httpx.Response(status_code=204) + + client = make_client(transport=httpx.MockTransport(handle_request)) + result = client.task_state.delete(ti_id=self.TI_ID, key="job_id") + assert result == OKResponse(ok=True) + + def test_clear_default_no_query_param(self): + def handle_request(request: httpx.Request) -> httpx.Response: + assert request.method == "DELETE" + assert request.url.path == f"/state/ti/{self.TI_ID}" + assert "all_map_indices" not in str(request.url.query) + return httpx.Response(status_code=204) + + client = make_client(transport=httpx.MockTransport(handle_request)) + result = client.task_state.clear(ti_id=self.TI_ID) + assert result == OKResponse(ok=True) + + def test_clear_all_map_indices_sends_query_param(self): + def handle_request(request: httpx.Request) -> httpx.Response: + assert "all_map_indices=true" in str(request.url.query) + return httpx.Response(status_code=204) + + client = make_client(transport=httpx.MockTransport(handle_request)) + result = client.task_state.clear(ti_id=self.TI_ID, all_map_indices=True) + assert result == OKResponse(ok=True) + + +class TestAssetStateOperations: + def test_get_by_name_success(self): + def handle_request(request: httpx.Request) -> httpx.Response: + if ( + request.url.path == "/state/asset/by-name/value" + and request.url.params["name"] == "test_asset" + and request.url.params["key"] == "watermark" + ): + return httpx.Response(status_code=200, json={"value": "2026-04-30T00:00:00Z"}) + return httpx.Response(status_code=400) + + client = make_client(transport=httpx.MockTransport(handle_request)) + result = client.asset_state.get(key="watermark", name="test_asset") + + assert isinstance(result, AssetStateResponse) + assert result.value == "2026-04-30T00:00:00Z" + + def test_get_by_uri_success(self): + def handle_request(request: httpx.Request) -> httpx.Response: + if ( + request.url.path 
== "/state/asset/by-uri/value" + and request.url.params["uri"] == "s3://bucket/key" + and request.url.params["key"] == "watermark" + ): + return httpx.Response(status_code=200, json={"value": "2026-04-30T00:00:00Z"}) + return httpx.Response(status_code=400) + + client = make_client(transport=httpx.MockTransport(handle_request)) + result = client.asset_state.get(key="watermark", uri="s3://bucket/key") + + assert isinstance(result, AssetStateResponse) + assert result.value == "2026-04-30T00:00:00Z" + + def test_get_returns_error_response_on_404(self): + def handle_request(request: httpx.Request) -> httpx.Response: + return httpx.Response( + status_code=404, + json={"detail": {"reason": "not_found", "message": "Asset state key 'watermark' not found"}}, + ) + + client = make_client(transport=httpx.MockTransport(handle_request)) + result = client.asset_state.get(key="watermark", name="test_asset") + assert isinstance(result, ErrorResponse) + assert result.error == ErrorType.ASSET_STATE_NOT_FOUND + + def test_set_by_name_success(self): + def handle_request(request: httpx.Request) -> httpx.Response: + assert request.method == "PUT" + assert request.url.path == "/state/asset/by-name/value" + assert request.url.params["name"] == "test_asset" + assert request.url.params["key"] == "watermark" + assert b'"value":"2026-04-30T00:00:00Z"' in request.content + return httpx.Response(status_code=204) + + client = make_client(transport=httpx.MockTransport(handle_request)) + result = client.asset_state.set(key="watermark", value="2026-04-30T00:00:00Z", name="test_asset") + assert result == OKResponse(ok=True) + + def test_set_by_uri_success(self): + def handle_request(request: httpx.Request) -> httpx.Response: + assert request.method == "PUT" + assert request.url.path == "/state/asset/by-uri/value" + assert request.url.params["uri"] == "s3://bucket/key" + assert request.url.params["key"] == "watermark" + return httpx.Response(status_code=204) + + client = 
make_client(transport=httpx.MockTransport(handle_request)) + result = client.asset_state.set(key="watermark", value="2026-04-30T00:00:00Z", uri="s3://bucket/key") + assert result == OKResponse(ok=True) + + def test_delete_by_name_success(self): + def handle_request(request: httpx.Request) -> httpx.Response: + assert request.method == "DELETE" + assert request.url.path == "/state/asset/by-name/value" + assert request.url.params["name"] == "test_asset" + assert request.url.params["key"] == "watermark" + return httpx.Response(status_code=204) + + client = make_client(transport=httpx.MockTransport(handle_request)) + result = client.asset_state.delete(key="watermark", name="test_asset") + assert result == OKResponse(ok=True) + + def test_delete_by_uri_success(self): + def handle_request(request: httpx.Request) -> httpx.Response: + assert request.method == "DELETE" + assert request.url.path == "/state/asset/by-uri/value" + assert request.url.params["uri"] == "s3://bucket/key" + assert request.url.params["key"] == "watermark" + return httpx.Response(status_code=204) + + client = make_client(transport=httpx.MockTransport(handle_request)) + result = client.asset_state.delete(key="watermark", uri="s3://bucket/key") + assert result == OKResponse(ok=True) + + def test_clear_by_name_success(self): + def handle_request(request: httpx.Request) -> httpx.Response: + assert request.method == "DELETE" + assert request.url.path == "/state/asset/by-name/clear" + assert request.url.params["name"] == "test_asset" + return httpx.Response(status_code=204) + + client = make_client(transport=httpx.MockTransport(handle_request)) + result = client.asset_state.clear(name="test_asset") + assert result == OKResponse(ok=True) + + def test_clear_by_uri_success(self): + def handle_request(request: httpx.Request) -> httpx.Response: + assert request.method == "DELETE" + assert request.url.path == "/state/asset/by-uri/clear" + assert request.url.params["uri"] == "s3://bucket/key" + return 
httpx.Response(status_code=204) + + client = make_client(transport=httpx.MockTransport(handle_request)) + result = client.asset_state.clear(uri="s3://bucket/key") + assert result == OKResponse(ok=True) diff --git a/task-sdk/tests/task_sdk/execution_time/test_context.py b/task-sdk/tests/task_sdk/execution_time/test_context.py index 1b0be13ab3ad9..ff0e6025c637d 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_context.py +++ b/task-sdk/tests/task_sdk/execution_time/test_context.py @@ -19,6 +19,7 @@ from unittest import mock from unittest.mock import MagicMock, patch +from uuid import UUID import pytest @@ -30,33 +31,55 @@ AssetAlias, AssetAliasEvent, AssetAliasUniqueKey, + AssetNameRef, AssetUniqueKey, + AssetUriRef, ) from airflow.sdk.definitions.connection import Connection from airflow.sdk.definitions.variable import Variable -from airflow.sdk.exceptions import AirflowNotFoundException, ErrorType +from airflow.sdk.exceptions import AirflowNotFoundException, AirflowRuntimeError, ErrorType from airflow.sdk.execution_time.comms import ( AssetEventDagRunReferenceResult, AssetEventResult, AssetEventSourceTaskInstance, AssetEventsResult, AssetResult, + AssetsByAliasResult, + AssetStateResult, + ClearAssetStateByName, + ClearAssetStateByUri, + ClearTaskState, ConnectionResult, DagRunResult, + DeleteAssetStateByName, + DeleteAssetStateByUri, + DeleteTaskState, ErrorResponse, GetAssetByName, GetAssetByUri, GetAssetEventByAsset, + GetAssetsByAlias, + GetAssetStateByName, + GetAssetStateByUri, GetDagRun, + GetTaskState, GetXCom, + OKResponse, + SetAssetStateByName, + SetAssetStateByUri, + SetTaskState, + TaskStateResult, VariableResult, XComResult, ) from airflow.sdk.execution_time.context import ( + AssetStateAccessor, + AssetStateAccessors, ConnectionAccessor, InletEventsAccessors, OutletEventAccessor, OutletEventAccessors, + TaskStateAccessor, TriggeringAssetEventsAccessor, VariableAccessor, _AssetRefResolutionMixin, @@ -1032,3 +1055,265 @@ def get_connection(self, 
conn_id): with pytest.raises(AirflowNotFoundException, match="isn't defined"): _get_connection("nonexistent_conn") + + +class TestTaskStateAccessor: + TI_ID = UUID("01900000-0000-0000-0000-000000000001") + + def test_get_returns_value(self, mock_supervisor_comms): + mock_supervisor_comms.send.return_value = TaskStateResult(value="app_001") + + result = TaskStateAccessor(ti_id=self.TI_ID).get("job_id") + + assert result == "app_001" + mock_supervisor_comms.send.assert_called_once_with(GetTaskState(ti_id=self.TI_ID, key="job_id")) + + def test_get_returns_none_on_404(self, mock_supervisor_comms): + mock_supervisor_comms.send.return_value = ErrorResponse( + error=ErrorType.TASK_STATE_NOT_FOUND, detail={"key": "missing_key"} + ) + + result = TaskStateAccessor(ti_id=self.TI_ID).get("missing_key") + + assert result is None + + def test_get_raises_on_error(self, mock_supervisor_comms): + mock_supervisor_comms.send.return_value = ErrorResponse( + error=ErrorType.GENERIC_ERROR, detail={"message": "server error"} + ) + + with pytest.raises(AirflowRuntimeError): + TaskStateAccessor(ti_id=self.TI_ID).get("some_key") + + def test_set_operation(self, mock_supervisor_comms): + mock_supervisor_comms.send.return_value = OKResponse(ok=True) + + TaskStateAccessor(ti_id=self.TI_ID).set("job_id", "app_001") + + mock_supervisor_comms.send.assert_called_once_with( + SetTaskState(ti_id=self.TI_ID, key="job_id", value="app_001") + ) + + def test_delete_operation(self, mock_supervisor_comms): + mock_supervisor_comms.send.return_value = OKResponse(ok=True) + + TaskStateAccessor(ti_id=self.TI_ID).delete("job_id") + + mock_supervisor_comms.send.assert_called_once_with(DeleteTaskState(ti_id=self.TI_ID, key="job_id")) + + def test_clear_default_sends_all_map_indices_false(self, mock_supervisor_comms): + mock_supervisor_comms.send.return_value = OKResponse(ok=True) + + TaskStateAccessor(ti_id=self.TI_ID).clear() + + mock_supervisor_comms.send.assert_called_once_with( + 
ClearTaskState(ti_id=self.TI_ID, all_map_indices=False) + ) + + def test_clear_all_map_indices_sends_flag_true(self, mock_supervisor_comms): + mock_supervisor_comms.send.return_value = OKResponse(ok=True) + + TaskStateAccessor(ti_id=self.TI_ID).clear(all_map_indices=True) + + mock_supervisor_comms.send.assert_called_once_with( + ClearTaskState(ti_id=self.TI_ID, all_map_indices=True) + ) + + +class TestAssetStateAccessor: + ASSET_NAME = "debug_watcher_asset" + ASSET_URI = "s3://bucket/key" + + def test_get_returns_value(self, mock_supervisor_comms): + mock_supervisor_comms.send.return_value = AssetStateResult(value="2026-04-30T00:00:00Z") + + result = AssetStateAccessor(name=self.ASSET_NAME).get("watermark") + + assert result == "2026-04-30T00:00:00Z" + mock_supervisor_comms.send.assert_called_once_with( + GetAssetStateByName(name=self.ASSET_NAME, key="watermark") + ) + + def test_get_returns_none_on_404(self, mock_supervisor_comms): + mock_supervisor_comms.send.return_value = ErrorResponse( + error=ErrorType.ASSET_STATE_NOT_FOUND, detail={"key": "missing_key"} + ) + + result = AssetStateAccessor(name=self.ASSET_NAME).get("missing_key") + + assert result is None + + def test_get_raises_on_error(self, mock_supervisor_comms): + mock_supervisor_comms.send.return_value = ErrorResponse( + error=ErrorType.GENERIC_ERROR, detail={"message": "server error"} + ) + + with pytest.raises(AirflowRuntimeError): + AssetStateAccessor(name=self.ASSET_NAME).get("some_key") + + def test_set_operation(self, mock_supervisor_comms): + mock_supervisor_comms.send.return_value = OKResponse(ok=True) + + AssetStateAccessor(name=self.ASSET_NAME).set("watermark", "2026-04-30T00:00:00Z") + + mock_supervisor_comms.send.assert_called_once_with( + SetAssetStateByName(name=self.ASSET_NAME, key="watermark", value="2026-04-30T00:00:00Z") + ) + + def test_delete_operation(self, mock_supervisor_comms): + mock_supervisor_comms.send.return_value = OKResponse(ok=True) + + 
AssetStateAccessor(name=self.ASSET_NAME).delete("watermark") + + mock_supervisor_comms.send.assert_called_once_with( + DeleteAssetStateByName(name=self.ASSET_NAME, key="watermark") + ) + + def test_clear_operation(self, mock_supervisor_comms): + mock_supervisor_comms.send.return_value = OKResponse(ok=True) + + AssetStateAccessor(name=self.ASSET_NAME).clear() + + mock_supervisor_comms.send.assert_called_once_with(ClearAssetStateByName(name=self.ASSET_NAME)) + + def test_get_by_uri(self, mock_supervisor_comms): + mock_supervisor_comms.send.return_value = AssetStateResult(value="2026-04-30T00:00:00Z") + + result = AssetStateAccessor(uri=self.ASSET_URI).get("watermark") + + assert result == "2026-04-30T00:00:00Z" + mock_supervisor_comms.send.assert_called_once_with( + GetAssetStateByUri(uri=self.ASSET_URI, key="watermark") + ) + + def test_set_by_uri(self, mock_supervisor_comms): + mock_supervisor_comms.send.return_value = OKResponse(ok=True) + + AssetStateAccessor(uri=self.ASSET_URI).set("watermark", "2026-04-30T00:00:00Z") + + mock_supervisor_comms.send.assert_called_once_with( + SetAssetStateByUri(uri=self.ASSET_URI, key="watermark", value="2026-04-30T00:00:00Z") + ) + + def test_delete_by_uri(self, mock_supervisor_comms): + mock_supervisor_comms.send.return_value = OKResponse(ok=True) + + AssetStateAccessor(uri=self.ASSET_URI).delete("watermark") + + mock_supervisor_comms.send.assert_called_once_with( + DeleteAssetStateByUri(uri=self.ASSET_URI, key="watermark") + ) + + def test_clear_by_uri(self, mock_supervisor_comms): + mock_supervisor_comms.send.return_value = OKResponse(ok=True) + + AssetStateAccessor(uri=self.ASSET_URI).clear() + + mock_supervisor_comms.send.assert_called_once_with(ClearAssetStateByUri(uri=self.ASSET_URI)) + + +class TestAssetStateAccessors: + ASSET_NAME = "my_asset" + ASSET_URI = "s3://bucket/key" + + def test_subscript_by_asset_routes_by_name(self, mock_supervisor_comms): + asset = Asset(name=self.ASSET_NAME, uri=f"s3://{self.ASSET_NAME}") + 
mock_supervisor_comms.send.return_value = AssetStateResult(value="v1") + + result = AssetStateAccessors([asset])[asset].get("watermark") + + assert result == "v1" + mock_supervisor_comms.send.assert_called_once_with( + GetAssetStateByName(name=self.ASSET_NAME, key="watermark") + ) + + def test_subscript_by_asset_name_ref(self, mock_supervisor_comms): + ref = AssetNameRef(name=self.ASSET_NAME) + mock_supervisor_comms.send.return_value = AssetStateResult(value="v2") + + result = AssetStateAccessors([ref])[ref].get("watermark") + + assert result == "v2" + mock_supervisor_comms.send.assert_called_once_with( + GetAssetStateByName(name=self.ASSET_NAME, key="watermark") + ) + + def test_subscript_by_uri_ref(self, mock_supervisor_comms): + ref = AssetUriRef(uri=self.ASSET_URI) + mock_supervisor_comms.send.return_value = AssetStateResult(value="v3") + + result = AssetStateAccessors([ref])[ref].get("watermark") + + assert result == "v3" + mock_supervisor_comms.send.assert_called_once_with( + GetAssetStateByUri(uri=self.ASSET_URI, key="watermark") + ) + + def test_get_single_inlet_simplified(self, mock_supervisor_comms): + asset = Asset(name=self.ASSET_NAME, uri=f"s3://{self.ASSET_NAME}") + mock_supervisor_comms.send.return_value = AssetStateResult(value="v4") + + result = AssetStateAccessors([asset]).get("watermark") + + assert result == "v4" + mock_supervisor_comms.send.assert_called_once_with( + GetAssetStateByName(name=self.ASSET_NAME, key="watermark") + ) + + def test_set_single_inlet_simplified(self, mock_supervisor_comms): + asset = Asset(name=self.ASSET_NAME, uri=f"s3://{self.ASSET_NAME}") + mock_supervisor_comms.send.return_value = OKResponse(ok=True) + + AssetStateAccessors([asset]).set("watermark", "2026-05-01") + + mock_supervisor_comms.send.assert_called_once_with( + SetAssetStateByName(name=self.ASSET_NAME, key="watermark", value="2026-05-01") + ) + + def test_delete_single_inlet_simplified(self, mock_supervisor_comms): + asset = Asset(name=self.ASSET_NAME, 
uri=f"s3://{self.ASSET_NAME}") + mock_supervisor_comms.send.return_value = OKResponse(ok=True) + + AssetStateAccessors([asset]).delete("watermark") + + mock_supervisor_comms.send.assert_called_once_with( + DeleteAssetStateByName(name=self.ASSET_NAME, key="watermark") + ) + + def test_clear_single_inlet_simplified(self, mock_supervisor_comms): + asset = Asset(name=self.ASSET_NAME, uri=f"s3://{self.ASSET_NAME}") + mock_supervisor_comms.send.return_value = OKResponse(ok=True) + + AssetStateAccessors([asset]).clear() + + mock_supervisor_comms.send.assert_called_once_with(ClearAssetStateByName(name=self.ASSET_NAME)) + + def test_double_reference_raises(self): + a1 = Asset(name="asset_one", uri="s3://one") + a2 = Asset(name="asset_two", uri="s3://two") + + with pytest.raises(ValueError, match="2 concrete inlets"): + AssetStateAccessors([a1, a2]).get("watermark") + + def test_alias_inlet_resolves_to_concrete_assets(self, mock_supervisor_comms): + alias = AssetAlias(name="my_alias") + mock_supervisor_comms.send.return_value = AssetsByAliasResult( + assets=[AssetResult(name="resolved_asset", uri="s3://bucket/resolved", group="asset")] + ) + mock_supervisor_comms.send.return_value = AssetsByAliasResult( + assets=[AssetResult(name="resolved_asset", uri="s3://bucket/resolved", group="asset")] + ) + + accessors = AssetStateAccessors([alias]) + + mock_supervisor_comms.send.assert_called_once_with(GetAssetsByAlias(alias_name="my_alias")) + resolved = Asset(name="resolved_asset", uri="s3://bucket/resolved") + assert resolved.name in accessors._by_name + + def test_alias_inlet_no_resolved_assets_contributes_nothing(self, mock_supervisor_comms): + alias = AssetAlias(name="empty_alias") + mock_supervisor_comms.send.return_value = AssetsByAliasResult(assets=[]) + + accessors = AssetStateAccessors([alias]) + + assert accessors._total == 0 diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py index 
f0d8e1a0b65d6..a8f97f81ac266 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py +++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py @@ -72,6 +72,11 @@ from airflow.sdk.execution_time.comms import ( AssetEventsResult, AssetResult, + AssetsByAliasResult, + AssetStateResult, + ClearAssetStateByName, + ClearAssetStateByUri, + ClearTaskState, CommsDecoder, ConnectionResult, CreateHITLDetailPayload, @@ -79,6 +84,9 @@ DagRunResult, DagRunStateResult, DeferTask, + DeleteAssetStateByName, + DeleteAssetStateByUri, + DeleteTaskState, DeleteVariable, DeleteXCom, DRCount, @@ -87,6 +95,9 @@ GetAssetByUri, GetAssetEventByAsset, GetAssetEventByAssetAlias, + GetAssetsByAlias, + GetAssetStateByName, + GetAssetStateByUri, GetConnection, GetDag, GetDagRun, @@ -98,6 +109,7 @@ GetPrevSuccessfulDagRun, GetTaskBreadcrumbs, GetTaskRescheduleStartDate, + GetTaskState, GetTaskStates, GetTICount, GetVariable, @@ -117,14 +129,18 @@ ResendLoggingFD, RetryTask, SentFDs, + SetAssetStateByName, + SetAssetStateByUri, SetRenderedFields, SetRenderedMapIndex, + SetTaskState, SetXCom, SkipDownstreamTasks, SucceedTask, TaskBreadcrumbsResult, TaskRescheduleStartDate, TaskState, + TaskStateResult, TaskStatesResult, TICount, ToSupervisor, @@ -1807,6 +1823,29 @@ class RequestTestCase: ), test_id="get_asset_by_uri", ), + RequestTestCase( + message=GetAssetsByAlias(alias_name="my_alias"), + expected_body={ + "assets": [ + { + "name": "asset_a", + "uri": "s3://bucket/a", + "group": "asset", + "extra": None, + "type": "AssetResult", + } + ], + "type": "AssetsByAliasResult", + }, + client_mock=ClientMock( + method_path="assets.get_by_alias", + kwargs={"alias_name": "my_alias"}, + response=AssetsByAliasResult( + assets=[AssetResult(name="asset_a", uri="s3://bucket/a", group="asset", extra=None)] + ), + ), + test_id="get_assets_by_alias", + ), RequestTestCase( message=GetAssetEventByAsset(uri="s3://bucket/obj", name="test"), expected_body={ @@ -2656,6 +2695,148 @@ class 
RequestTestCase: ), test_id="get_dag", ), + RequestTestCase( + message=GetTaskState(ti_id=TI_ID, key="job_id"), + test_id="get_task_state", + client_mock=ClientMock( + method_path="task_state.get", + args=(TI_ID, "job_id"), + response=TaskStateResult(value="spark_app_001"), + ), + expected_body={"value": "spark_app_001", "type": "TaskStateResult"}, + ), + RequestTestCase( + message=SetTaskState(ti_id=TI_ID, key="job_id", value="spark_app_001"), + test_id="set_task_state", + client_mock=ClientMock( + method_path="task_state.set", + args=(TI_ID, "job_id", "spark_app_001"), + response=OKResponse(ok=True), + ), + expected_body={"ok": True, "type": "OKResponse"}, + ), + RequestTestCase( + message=DeleteTaskState(ti_id=TI_ID, key="job_id"), + test_id="delete_task_state", + client_mock=ClientMock( + method_path="task_state.delete", + args=(TI_ID, "job_id"), + response=OKResponse(ok=True), + ), + expected_body={"ok": True, "type": "OKResponse"}, + ), + RequestTestCase( + message=ClearTaskState(ti_id=TI_ID), + test_id="clear_task_state", + client_mock=ClientMock( + method_path="task_state.clear", + args=(TI_ID,), + kwargs={"all_map_indices": False}, + response=OKResponse(ok=True), + ), + expected_body={"ok": True, "type": "OKResponse"}, + ), + RequestTestCase( + message=ClearTaskState(ti_id=TI_ID, all_map_indices=True), + test_id="clear_task_state_all_map_indices", + client_mock=ClientMock( + method_path="task_state.clear", + args=(TI_ID,), + kwargs={"all_map_indices": True}, + response=OKResponse(ok=True), + ), + expected_body={"ok": True, "type": "OKResponse"}, + ), + RequestTestCase( + message=GetAssetStateByName(name="debug_watcher_asset", key="watermark"), + test_id="get_asset_state_by_name", + client_mock=ClientMock( + method_path="asset_state.get", + args=("watermark",), + kwargs={"name": "debug_watcher_asset"}, + response=AssetStateResult(value="2026-04-30T00:00:00Z"), + ), + expected_body={"value": "2026-04-30T00:00:00Z", "type": "AssetStateResult"}, + ), + 
RequestTestCase( + message=GetAssetStateByUri(uri="s3://bucket/key", key="watermark"), + test_id="get_asset_state_by_uri", + client_mock=ClientMock( + method_path="asset_state.get", + args=("watermark",), + kwargs={"uri": "s3://bucket/key"}, + response=AssetStateResult(value="2026-04-30T00:00:00Z"), + ), + expected_body={"value": "2026-04-30T00:00:00Z", "type": "AssetStateResult"}, + ), + RequestTestCase( + message=SetAssetStateByName( + name="debug_watcher_asset", key="watermark", value="2026-04-30T00:00:00Z" + ), + test_id="set_asset_state_by_name", + client_mock=ClientMock( + method_path="asset_state.set", + args=("watermark", "2026-04-30T00:00:00Z"), + kwargs={"name": "debug_watcher_asset"}, + response=OKResponse(ok=True), + ), + expected_body={"ok": True, "type": "OKResponse"}, + ), + RequestTestCase( + message=SetAssetStateByUri(uri="s3://bucket/key", key="watermark", value="2026-04-30T00:00:00Z"), + test_id="set_asset_state_by_uri", + client_mock=ClientMock( + method_path="asset_state.set", + args=("watermark", "2026-04-30T00:00:00Z"), + kwargs={"uri": "s3://bucket/key"}, + response=OKResponse(ok=True), + ), + expected_body={"ok": True, "type": "OKResponse"}, + ), + RequestTestCase( + message=DeleteAssetStateByName(name="debug_watcher_asset", key="watermark"), + test_id="delete_asset_state_by_name", + client_mock=ClientMock( + method_path="asset_state.delete", + args=("watermark",), + kwargs={"name": "debug_watcher_asset"}, + response=OKResponse(ok=True), + ), + expected_body={"ok": True, "type": "OKResponse"}, + ), + RequestTestCase( + message=DeleteAssetStateByUri(uri="s3://bucket/key", key="watermark"), + test_id="delete_asset_state_by_uri", + client_mock=ClientMock( + method_path="asset_state.delete", + args=("watermark",), + kwargs={"uri": "s3://bucket/key"}, + response=OKResponse(ok=True), + ), + expected_body={"ok": True, "type": "OKResponse"}, + ), + RequestTestCase( + message=ClearAssetStateByName(name="debug_watcher_asset"), + 
test_id="clear_asset_state_by_name", + client_mock=ClientMock( + method_path="asset_state.clear", + args=(), + kwargs={"name": "debug_watcher_asset"}, + response=OKResponse(ok=True), + ), + expected_body={"ok": True, "type": "OKResponse"}, + ), + RequestTestCase( + message=ClearAssetStateByUri(uri="s3://bucket/key"), + test_id="clear_asset_state_by_uri", + client_mock=ClientMock( + method_path="asset_state.clear", + args=(), + kwargs={"uri": "s3://bucket/key"}, + response=OKResponse(ok=True), + ), + expected_body={"ok": True, "type": "OKResponse"}, + ), ] diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py index 630aff9094ed1..723ca42d93aa6 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py +++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py @@ -67,7 +67,7 @@ ) from airflow.sdk.bases.xcom import BaseXCom from airflow.sdk.definitions._internal.types import NOTSET, SET_DURING_EXECUTION, is_arg_set -from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetUniqueKey, Dataset, Model +from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetUniqueKey, AssetUriRef, Dataset, Model from airflow.sdk.definitions.param import DagParam from airflow.sdk.exceptions import ( AirflowException, @@ -86,31 +86,46 @@ from airflow.sdk.execution_time.comms import ( AssetEventResult, AssetEventsResult, + AssetResult, + AssetsByAliasResult, BundleInfo, + ClearAssetStateByName, + ClearTaskState, ConnectionResult, DagResult, DagRunStateResult, DeferTask, + DeleteAssetStateByName, + DeleteTaskState, DRCount, ErrorResponse, + GetAssetByUri, + GetAssetsByAlias, + GetAssetStateByName, + GetAssetStateByUri, GetConnection, GetDag, GetDagRunState, GetDRCount, GetPreviousDagRun, GetPreviousTI, + GetTaskState, GetTaskStates, GetTICount, GetVariable, GetXCom, GetXComSequenceSlice, + InactiveAssetsResult, MaskSecret, OKResponse, PreviousDagRunResult, PreviousTIResult, 
PrevSuccessfulDagRunResult, RescheduleTask, + SetAssetStateByName, + SetAssetStateByUri, SetRenderedFields, + SetTaskState, SetXCom, SkipDownstreamTasks, StartupDetails, @@ -120,6 +135,7 @@ TaskStatesResult, TICount, TriggerDagRun, + ValidateInletsAndOutlets, VariableResult, XComResult, XComSequenceSliceResult, @@ -129,6 +145,7 @@ InletEventsAccessors, MacrosAccessor, OutletEventAccessors, + TaskStateAccessor, TriggeringAssetEventsAccessor, VariableAccessor, ) @@ -1764,6 +1781,7 @@ def test_get_context_without_ti_context_from_server(self, mocked_parse, make_ti_ "run_id": "test_run", "task": task, "task_instance": runtime_ti, + "task_state": TaskStateAccessor(ti_id=ti_id), "ti": runtime_ti, } @@ -1809,6 +1827,7 @@ def test_get_context_with_ti_context_from_server(self, create_runtime_ti, mock_s "run_id": "test_run", "task": task, "task_instance": runtime_ti, + "task_state": TaskStateAccessor(ti_id=runtime_ti.id), "ti": runtime_ti, "dag_run": dr, "data_interval_end": timezone.datetime(2024, 12, 1, 1, 0, 0), @@ -4868,3 +4887,220 @@ def test_dag_add_result(create_runtime_ti, mock_supervisor_comms): dag_result=True, ) ) + + +class TestTaskInstanceStateOperations: + """Tests to verify that tasks can perform state operations (task / asset) via the supervisor.""" + + def test_task_can_set_and_get_state(self, create_runtime_ti, mock_supervisor_comms): + class MyOperator(BaseOperator): + def execute(self, context): + ts = context["task_state"] + ts.set("job_id", "spark_app_001") + return ts.get("job_id") + + task = MyOperator(task_id="t") + runtime_ti = create_runtime_ti(task=task) + + run(runtime_ti, context=runtime_ti.get_template_context(), log=mock.MagicMock()) + + mock_supervisor_comms.send.assert_any_call( + SetTaskState(ti_id=runtime_ti.id, key="job_id", value="spark_app_001") + ) + mock_supervisor_comms.send.assert_any_call(GetTaskState(ti_id=runtime_ti.id, key="job_id")) + + def test_task_can_delete_state(self, create_runtime_ti, mock_supervisor_comms): + class 
MyOperator(BaseOperator): + def execute(self, context): + context["task_state"].delete("job_id") + + task = MyOperator(task_id="t") + runtime_ti = create_runtime_ti(task=task) + + run(runtime_ti, context=runtime_ti.get_template_context(), log=mock.MagicMock()) + + mock_supervisor_comms.send.assert_any_call(DeleteTaskState(ti_id=runtime_ti.id, key="job_id")) + + @pytest.mark.parametrize( + ("call_kwargs", "expected_flag"), + [ + pytest.param({}, False, id="default"), + pytest.param({"all_map_indices": True}, True, id="fleet-wipe"), + ], + ) + def test_task_can_clear_state(self, call_kwargs, expected_flag, create_runtime_ti, mock_supervisor_comms): + class MyOperator(BaseOperator): + def execute(self, context): + context["task_state"].clear(**call_kwargs) + + task = MyOperator(task_id="t") + runtime_ti = create_runtime_ti(task=task) + run(runtime_ti, context=runtime_ti.get_template_context(), log=mock.MagicMock()) + mock_supervisor_comms.send.assert_any_call( + ClearTaskState(ti_id=runtime_ti.id, all_map_indices=expected_flag) + ) + + @staticmethod + def _watcher_side_effect(msg=None, *args, **kwargs): + actual = msg or (args[0] if args else None) + if isinstance(actual, ValidateInletsAndOutlets): + return InactiveAssetsResult(inactive_assets=[]) + if isinstance(actual, GetAssetByUri): + # GetAssetByUri has no .name field. Mirroring AssetModel behaviour: + # when only uri is provided, name defaults to uri. 
+ return AssetResult(name=actual.uri, uri=actual.uri, group="asset") + return OKResponse(ok=True) + + def test_asset_state_get_and_set(self, create_runtime_ti, mock_supervisor_comms): + watched = Asset(name="my_asset", uri="s3://bucket/data") + + class WatcherOperator(BaseOperator): + def execute(self, context): + context["asset_state"].set("watermark", "2026-04-30") + context["asset_state"].get("watermark") + + task = WatcherOperator(task_id="t", inlets=[watched]) + runtime_ti = create_runtime_ti(task=task) + mock_supervisor_comms.send.side_effect = TestTaskInstanceStateOperations._watcher_side_effect + + run(runtime_ti, context=runtime_ti.get_template_context(), log=mock.MagicMock()) + + mock_supervisor_comms.send.assert_any_call( + SetAssetStateByName(name="my_asset", key="watermark", value="2026-04-30") + ) + mock_supervisor_comms.send.assert_any_call(GetAssetStateByName(name="my_asset", key="watermark")) + + def test_asset_state_delete(self, create_runtime_ti, mock_supervisor_comms): + watched = Asset(name="my_asset", uri="s3://bucket/data") + + class WatcherOperator(BaseOperator): + def execute(self, context): + context["asset_state"].delete("watermark") + + task = WatcherOperator(task_id="t", inlets=[watched]) + runtime_ti = create_runtime_ti(task=task) + mock_supervisor_comms.send.side_effect = TestTaskInstanceStateOperations._watcher_side_effect + + run(runtime_ti, context=runtime_ti.get_template_context(), log=mock.MagicMock()) + + mock_supervisor_comms.send.assert_any_call(DeleteAssetStateByName(name="my_asset", key="watermark")) + + def test_asset_state_clear(self, create_runtime_ti, mock_supervisor_comms): + watched = Asset(name="my_asset", uri="s3://bucket/data") + + class WatcherOperator(BaseOperator): + def execute(self, context): + context["asset_state"].clear() + + task = WatcherOperator(task_id="t", inlets=[watched]) + runtime_ti = create_runtime_ti(task=task) + mock_supervisor_comms.send.side_effect = 
TestTaskInstanceStateOperations._watcher_side_effect + + run(runtime_ti, context=runtime_ti.get_template_context(), log=mock.MagicMock()) + + mock_supervisor_comms.send.assert_any_call(ClearAssetStateByName(name="my_asset")) + + def test_asset_state_uri_ref_inlet(self, create_runtime_ti, mock_supervisor_comms): + watched = AssetUriRef(uri="s3://bucket/data") + + class WatcherOperator(BaseOperator): + def execute(self, context): + context["asset_state"].set("watermark", "2026-04-30") + context["asset_state"].get("watermark") + + task = WatcherOperator(task_id="t", inlets=[watched]) + runtime_ti = create_runtime_ti(task=task) + mock_supervisor_comms.send.side_effect = TestTaskInstanceStateOperations._watcher_side_effect + + run(runtime_ti, context=runtime_ti.get_template_context(), log=mock.MagicMock()) + + mock_supervisor_comms.send.assert_any_call( + SetAssetStateByUri(uri="s3://bucket/data", key="watermark", value="2026-04-30") + ) + mock_supervisor_comms.send.assert_any_call( + GetAssetStateByUri(uri="s3://bucket/data", key="watermark") + ) + + def test_asset_state_alias_as_inlet(self, create_runtime_ti, mock_supervisor_comms): + alias = AssetAlias(name="my_alias") + resolved = Asset(name="resolved_asset", uri="s3://bucket/resolved") + + class WatcherOperator(BaseOperator): + def execute(self, context): + context["asset_state"][resolved].set("watermark", "2026-05-01") + + def side_effect(msg): + if isinstance(msg, GetAssetsByAlias): + return AssetsByAliasResult( + assets=[AssetResult(name=resolved.name, uri=resolved.uri, group="asset")] + ) + return TestTaskInstanceStateOperations._watcher_side_effect(msg) + + task = WatcherOperator(task_id="t", inlets=[alias]) + runtime_ti = create_runtime_ti(task=task) + mock_supervisor_comms.send.side_effect = side_effect + + run(runtime_ti, context=runtime_ti.get_template_context(), log=mock.MagicMock()) + + mock_supervisor_comms.send.assert_any_call( + SetAssetStateByName(name="resolved_asset", key="watermark", 
value="2026-05-01") + ) + + def test_asset_state_alias_inlet_no_resolved_assets(self, create_runtime_ti, mock_supervisor_comms): + alias = AssetAlias(name="empty_alias") + + class WatcherOperator(BaseOperator): + def execute(self, context): + # asset_state is in context but it is empty because alias resolved to nothing + assert "asset_state" in context + + def side_effect(msg): + if isinstance(msg, GetAssetsByAlias): + return AssetsByAliasResult(assets=[]) + return TestTaskInstanceStateOperations._watcher_side_effect(msg) + + task = WatcherOperator(task_id="t", inlets=[alias]) + runtime_ti = create_runtime_ti(task=task) + mock_supervisor_comms.send.side_effect = side_effect + + run(runtime_ti, context=runtime_ti.get_template_context(), log=mock.MagicMock()) + + def test_asset_state_keyed_access_single_inlet(self, create_runtime_ti, mock_supervisor_comms): + watched = Asset(name="my_asset", uri="s3://bucket/data") + + class WatcherOperator(BaseOperator): + def execute(self, context): + # accessing via asset name key + context["asset_state"][watched].set("watermark", "2026-05-01") + + task = WatcherOperator(task_id="t", inlets=[watched]) + runtime_ti = create_runtime_ti(task=task) + mock_supervisor_comms.send.side_effect = TestTaskInstanceStateOperations._watcher_side_effect + + run(runtime_ti, context=runtime_ti.get_template_context(), log=mock.MagicMock()) + + mock_supervisor_comms.send.assert_any_call( + SetAssetStateByName(name="my_asset", key="watermark", value="2026-05-01") + ) + + def test_asset_state_multi_inlet(self, create_runtime_ti, mock_supervisor_comms): + asset_a = Asset(name="asset_a", uri="s3://bucket/a") + asset_b = Asset(name="asset_b", uri="s3://bucket/b") + + class MultiInletOperator(BaseOperator): + def execute(self, context): + context["asset_state"][asset_a].set("watermark_a", "2026-05-01") + context["asset_state"][asset_b].set("watermark_b", "2026-05-02") + + task = MultiInletOperator(task_id="t", inlets=[asset_a, asset_b]) + runtime_ti = 
create_runtime_ti(task=task) + mock_supervisor_comms.send.side_effect = TestTaskInstanceStateOperations._watcher_side_effect + + run(runtime_ti, context=runtime_ti.get_template_context(), log=mock.MagicMock()) + + mock_supervisor_comms.send.assert_any_call( + SetAssetStateByName(name="asset_a", key="watermark_a", value="2026-05-01") + ) + mock_supervisor_comms.send.assert_any_call( + SetAssetStateByName(name="asset_b", key="watermark_b", value="2026-05-02") + )