Skip to content
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions airflow-core/tests/unit/dag_processing/test_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1969,6 +1969,19 @@ def get_type_names(union_type):
"UpdateHITLDetail",
"GetHITLDetailResponse",
"SetRenderedMapIndex",
# AIP-103 task/asset state — Dag processor has no task execution context.
"GetTaskState",
"SetTaskState",
"DeleteTaskState",
"ClearTaskState",
"GetAssetStateByName",
"GetAssetStateByUri",
"SetAssetStateByName",
"SetAssetStateByUri",
"DeleteAssetStateByName",
"DeleteAssetStateByUri",
"ClearAssetStateByName",
"ClearAssetStateByUri",
}

in_task_runner_but_not_in_dag_processing_process = {
Expand All @@ -1987,6 +2000,9 @@ def get_type_names(union_type):
"InactiveAssetsResult",
"CreateHITLDetailPayload",
"HITLDetailRequestResult",
# AIP-103 task/asset state results — worker-only responses to the above messages.
"TaskStateResult",
"AssetStateResult",
}

supervisor_diff = supervisor_types - manager_types - in_supervisor_but_not_in_manager
Expand Down
16 changes: 16 additions & 0 deletions airflow-core/tests/unit/jobs/test_triggerer_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -1727,6 +1727,19 @@ def get_type_names(union_type):
"CreateHITLDetailPayload",
"SetRenderedMapIndex",
"GetDag",
# AIP-103 task/asset state — triggerer has no task execution context.
"GetTaskState",
"SetTaskState",
"DeleteTaskState",
"ClearTaskState",
"GetAssetStateByName",
"GetAssetStateByUri",
"SetAssetStateByName",
"SetAssetStateByUri",
"DeleteAssetStateByName",
"DeleteAssetStateByUri",
"ClearAssetStateByName",
"ClearAssetStateByUri",
}

in_task_but_not_in_trigger_runner = {
Expand All @@ -1747,6 +1760,9 @@ def get_type_names(union_type):
"PreviousTIResult",
"HITLDetailRequestResult",
"DagResult",
# AIP-103 task/asset state results — worker-only responses to the above messages.
"TaskStateResult",
"AssetStateResult",
}

supervisor_diff = (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@
AIRFLOW_V_3_0_PLUS,
AIRFLOW_V_3_1_PLUS,
AIRFLOW_V_3_2_PLUS,
AIRFLOW_V_3_3_PLUS,
NOTSET,
)

Expand Down Expand Up @@ -1131,6 +1132,11 @@ def test_virtualenv_serializable_context_fields(self, create_task_instance):
"inlet_events",
"outlet_events",
}
if AIRFLOW_V_3_3_PLUS:
# AIP-103: task_state is a live accessor backed by the supervisor pipe —
# not serializable and meaningless in a virtualenv subprocess.
# asset_state is excluded via its absence: only present when a task has inlets.
intentionally_excluded_context_keys.add("task_state")

ti = create_task_instance(dag_id=self.dag_id, task_id=self.task_id, schedule=None)
context = ti.get_template_context()
Expand Down
3 changes: 3 additions & 0 deletions scripts/ci/prek/check_template_context_variable_in_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@
"data_interval_start",
"prev_data_interval_start_success",
"prev_data_interval_end_success",
# AIP-103: task_state/asset_state aren't documented yet. Will be done in a later PR.
"task_state",
"asset_state",
Comment thread
amoghrajesh marked this conversation as resolved.
}


Expand Down
105 changes: 105 additions & 0 deletions task-sdk/src/airflow/sdk/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@
API_VERSION,
AssetEventsResponse,
AssetResponse,
AssetStatePutBody,
AssetStateResponse,
ConnectionResponse,
DagResponse,
DagRun,
Expand All @@ -56,6 +58,8 @@
PrevSuccessfulDagRunResponse,
TaskBreadcrumbsResponse,
TaskInstanceState,
TaskStatePutBody,
TaskStateResponse,
TaskStatesResponse,
TerminalStateNonSuccess,
TIDeferredStatePayload,
Expand Down Expand Up @@ -639,6 +643,95 @@ def get_sequence_slice(
return XComSequenceSliceResponse.model_validate_json(resp.read())


class TaskStateOperations:
__slots__ = ("client",)

def __init__(self, client: Client):
self.client = client

def get(self, ti_id: uuid.UUID, key: str) -> TaskStateResponse | ErrorResponse:
"""Get a task state value from the API server."""
try:
resp = self.client.get(f"state/ti/{ti_id}/{key}")
except ServerResponseError as e:
if e.response.status_code == HTTPStatus.NOT_FOUND:
log.debug("Task state key not found", ti_id=ti_id, key=key)
return ErrorResponse(error=ErrorType.TASK_STATE_NOT_FOUND, detail={"key": key})
raise
return TaskStateResponse.model_validate_json(resp.read())

def set(self, ti_id: uuid.UUID, key: str, value: str) -> OKResponse:
"""Set a task state value via the API server."""
body = TaskStatePutBody(value=value)
self.client.put(f"state/ti/{ti_id}/{key}", content=body.model_dump_json())
return OKResponse(ok=True)

def delete(self, ti_id: uuid.UUID, key: str) -> OKResponse:
"""Delete a single task state key via the API server."""
self.client.delete(f"state/ti/{ti_id}/{key}")
return OKResponse(ok=True)

def clear(self, ti_id: uuid.UUID, all_map_indices: bool = False) -> OKResponse:
"""Clear all task state keys for a task instance via the API server."""
params = {"all_map_indices": "true"} if all_map_indices else {}
self.client.delete(f"state/ti/{ti_id}", params=params)
return OKResponse(ok=True)


class AssetStateOperations:
__slots__ = ("client",)

def __init__(self, client: Client):
self.client = client

def _resolve_endpoint(
self, op: str, *, key: str | None = None, name: str | None = None, uri: str | None = None
) -> tuple[str, dict[str, str]]:
if name:
params: dict[str, str] = {"name": name}
endpoint = f"state/asset/by-name/{op}"
elif uri:
params = {"uri": uri}
endpoint = f"state/asset/by-uri/{op}"
else:
raise ValueError("Either `name` or `uri` must be provided")
Comment thread
Lee-W marked this conversation as resolved.
if key is not None:
params["key"] = key
return endpoint, params

def get(
self, key: str, *, name: str | None = None, uri: str | None = None
) -> AssetStateResponse | ErrorResponse:
"""Get an asset state value from the API server."""
endpoint, params = self._resolve_endpoint("value", key=key, name=name, uri=uri)
try:
resp = self.client.get(endpoint, params=params)
except ServerResponseError as e:
if e.response.status_code == HTTPStatus.NOT_FOUND:
log.debug("Asset state key not found", name=name, uri=uri, key=key)
return ErrorResponse(error=ErrorType.ASSET_STATE_NOT_FOUND, detail={"key": key})
raise
return AssetStateResponse.model_validate_json(resp.read())

def set(self, key: str, value: str, *, name: str | None = None, uri: str | None = None) -> OKResponse:
"""Set an asset state value via the API server."""
endpoint, params = self._resolve_endpoint("value", key=key, name=name, uri=uri)
self.client.put(endpoint, params=params, content=AssetStatePutBody(value=value).model_dump_json())
return OKResponse(ok=True)

def delete(self, key: str, *, name: str | None = None, uri: str | None = None) -> OKResponse:
"""Delete a single asset state key via the API server."""
endpoint, params = self._resolve_endpoint("value", key=key, name=name, uri=uri)
self.client.delete(endpoint, params=params)
return OKResponse(ok=True)

def clear(self, *, name: str | None = None, uri: str | None = None) -> OKResponse:
"""Clear all state keys for an asset via the API server."""
endpoint, params = self._resolve_endpoint("clear", name=name, uri=uri)
self.client.delete(endpoint, params=params)
return OKResponse(ok=True)


class AssetOperations:
__slots__ = ("client",)

Expand Down Expand Up @@ -1052,6 +1145,18 @@ def asset_events(self) -> AssetEventOperations:
"""Operations related to Asset Events."""
return AssetEventOperations(self)

@lru_cache() # type: ignore[misc]
@property
def task_state(self) -> TaskStateOperations:
"""Operations related to task state."""
return TaskStateOperations(self)

@lru_cache() # type: ignore[misc]
@property
def asset_state(self) -> AssetStateOperations:
"""Operations related to asset state."""
return AssetStateOperations(self)

@lru_cache() # type: ignore[misc]
@property
def hitl(self):
Expand Down
8 changes: 7 additions & 1 deletion task-sdk/src/airflow/sdk/definitions/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,11 @@

from airflow.sdk.bases.operator import BaseOperator
from airflow.sdk.definitions.dag import DAG
from airflow.sdk.execution_time.context import InletEventsAccessors
from airflow.sdk.execution_time.context import (
AssetStateAccessors,
InletEventsAccessors,
TaskStateAccessor,
)
from airflow.sdk.types import (
DagRunProtocol,
Operator,
Expand Down Expand Up @@ -72,6 +76,8 @@ class Context(TypedDict, total=False):
task_reschedule_count: int
task_instance: RuntimeTaskInstanceProtocol
task_instance_key_str: str
task_state: TaskStateAccessor
Comment thread
amoghrajesh marked this conversation as resolved.
asset_state: AssetStateAccessors
# `templates_dict` is only set in PythonOperator
templates_dict: NotRequired[dict[str, Any] | None]
test_mode: bool
Expand Down
2 changes: 2 additions & 0 deletions task-sdk/src/airflow/sdk/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ class ErrorType(enum.Enum):
VARIABLE_NOT_FOUND = "VARIABLE_NOT_FOUND"
XCOM_NOT_FOUND = "XCOM_NOT_FOUND"
ASSET_NOT_FOUND = "ASSET_NOT_FOUND"
TASK_STATE_NOT_FOUND = "TASK_STATE_NOT_FOUND"
ASSET_STATE_NOT_FOUND = "ASSET_STATE_NOT_FOUND"
DAGRUN_ALREADY_EXISTS = "DAGRUN_ALREADY_EXISTS"
GENERIC_ERROR = "GENERIC_ERROR"
API_SERVER_ERROR = "API_SERVER_ERROR"
Expand Down
Loading
Loading