Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions scripts/ci/prek/check_template_context_variable_in_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@
"data_interval_start",
"prev_data_interval_start_success",
"prev_data_interval_end_success",
# AIP-103: task_state/asset_state aren't documented yet. Will be done in a later PR.
"task_state",
"asset_state",
Comment thread
amoghrajesh marked this conversation as resolved.
}


Expand Down
105 changes: 105 additions & 0 deletions task-sdk/src/airflow/sdk/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@
API_VERSION,
AssetEventsResponse,
AssetResponse,
AssetStatePutBody,
AssetStateResponse,
ConnectionResponse,
DagResponse,
DagRun,
Expand All @@ -56,6 +58,8 @@
PrevSuccessfulDagRunResponse,
TaskBreadcrumbsResponse,
TaskInstanceState,
TaskStatePutBody,
TaskStateResponse,
TaskStatesResponse,
TerminalStateNonSuccess,
TIDeferredStatePayload,
Expand Down Expand Up @@ -639,6 +643,95 @@ def get_sequence_slice(
return XComSequenceSliceResponse.model_validate_json(resp.read())


class TaskStateOperations:
__slots__ = ("client",)

def __init__(self, client: Client):
self.client = client

def get(self, ti_id: uuid.UUID, key: str) -> TaskStateResponse | ErrorResponse:
"""Get a task state value from the API server."""
try:
resp = self.client.get(f"state/ti/{ti_id}/{key}")
except ServerResponseError as e:
if e.response.status_code == HTTPStatus.NOT_FOUND:
log.debug("Task state key not found", ti_id=ti_id, key=key)
return ErrorResponse(error=ErrorType.TASK_STATE_NOT_FOUND, detail={"key": key})
raise
return TaskStateResponse.model_validate_json(resp.read())

def set(self, ti_id: uuid.UUID, key: str, value: str) -> OKResponse:
"""Set a task state value via the API server."""
body = TaskStatePutBody(value=value)
self.client.put(f"state/ti/{ti_id}/{key}", content=body.model_dump_json())
return OKResponse(ok=True)

def delete(self, ti_id: uuid.UUID, key: str) -> OKResponse:
"""Delete a single task state key via the API server."""
self.client.delete(f"state/ti/{ti_id}/{key}")
return OKResponse(ok=True)

def clear(self, ti_id: uuid.UUID, all_map_indices: bool = False) -> OKResponse:
"""Clear all task state keys for a task instance via the API server."""
params = {"all_map_indices": "true"} if all_map_indices else {}
self.client.delete(f"state/ti/{ti_id}", params=params)
return OKResponse(ok=True)


class AssetStateOperations:
__slots__ = ("client",)

def __init__(self, client: Client):
self.client = client

def _resolve_endpoint(
self, op: str, *, key: str | None = None, name: str | None = None, uri: str | None = None
) -> tuple[str, dict[str, str]]:
if name:
params: dict[str, str] = {"name": name}
endpoint = f"state/asset/by-name/{op}"
elif uri:
params = {"uri": uri}
endpoint = f"state/asset/by-uri/{op}"
else:
raise ValueError("Either `name` or `uri` must be provided")
Comment thread
Lee-W marked this conversation as resolved.
if key is not None:
params["key"] = key
return endpoint, params

def get(
self, key: str, *, name: str | None = None, uri: str | None = None
) -> AssetStateResponse | ErrorResponse:
"""Get an asset state value from the API server."""
endpoint, params = self._resolve_endpoint("value", key=key, name=name, uri=uri)
try:
resp = self.client.get(endpoint, params=params)
except ServerResponseError as e:
if e.response.status_code == HTTPStatus.NOT_FOUND:
log.debug("Asset state key not found", name=name, uri=uri, key=key)
return ErrorResponse(error=ErrorType.ASSET_STATE_NOT_FOUND, detail={"key": key})
raise
return AssetStateResponse.model_validate_json(resp.read())

def set(self, key: str, value: str, *, name: str | None = None, uri: str | None = None) -> OKResponse:
"""Set an asset state value via the API server."""
endpoint, params = self._resolve_endpoint("value", key=key, name=name, uri=uri)
self.client.put(endpoint, params=params, content=AssetStatePutBody(value=value).model_dump_json())
return OKResponse(ok=True)

def delete(self, key: str, *, name: str | None = None, uri: str | None = None) -> OKResponse:
"""Delete a single asset state key via the API server."""
endpoint, params = self._resolve_endpoint("value", key=key, name=name, uri=uri)
self.client.delete(endpoint, params=params)
return OKResponse(ok=True)

def clear(self, *, name: str | None = None, uri: str | None = None) -> OKResponse:
"""Clear all state keys for an asset via the API server."""
endpoint, params = self._resolve_endpoint("clear", name=name, uri=uri)
self.client.delete(endpoint, params=params)
return OKResponse(ok=True)


class AssetOperations:
__slots__ = ("client",)

Expand Down Expand Up @@ -1052,6 +1145,18 @@ def asset_events(self) -> AssetEventOperations:
"""Operations related to Asset Events."""
return AssetEventOperations(self)

@lru_cache() # type: ignore[misc]
@property
def task_state(self) -> TaskStateOperations:
"""Operations related to task state."""
return TaskStateOperations(self)

@lru_cache() # type: ignore[misc]
@property
def asset_state(self) -> AssetStateOperations:
"""Operations related to asset state."""
return AssetStateOperations(self)

@lru_cache() # type: ignore[misc]
@property
def hitl(self):
Expand Down
8 changes: 7 additions & 1 deletion task-sdk/src/airflow/sdk/definitions/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,11 @@

from airflow.sdk.bases.operator import BaseOperator
from airflow.sdk.definitions.dag import DAG
from airflow.sdk.execution_time.context import InletEventsAccessors
from airflow.sdk.execution_time.context import (
AssetStateAccessors,
InletEventsAccessors,
TaskStateAccessor,
)
from airflow.sdk.types import (
DagRunProtocol,
Operator,
Expand Down Expand Up @@ -72,6 +76,8 @@ class Context(TypedDict, total=False):
task_reschedule_count: int
task_instance: RuntimeTaskInstanceProtocol
task_instance_key_str: str
task_state: TaskStateAccessor
Comment thread
amoghrajesh marked this conversation as resolved.
asset_state: AssetStateAccessors
# `templates_dict` is only set in PythonOperator
templates_dict: NotRequired[dict[str, Any] | None]
test_mode: bool
Expand Down
2 changes: 2 additions & 0 deletions task-sdk/src/airflow/sdk/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ class ErrorType(enum.Enum):
VARIABLE_NOT_FOUND = "VARIABLE_NOT_FOUND"
XCOM_NOT_FOUND = "XCOM_NOT_FOUND"
ASSET_NOT_FOUND = "ASSET_NOT_FOUND"
TASK_STATE_NOT_FOUND = "TASK_STATE_NOT_FOUND"
ASSET_STATE_NOT_FOUND = "ASSET_STATE_NOT_FOUND"
DAGRUN_ALREADY_EXISTS = "DAGRUN_ALREADY_EXISTS"
GENERIC_ERROR = "GENERIC_ERROR"
API_SERVER_ERROR = "API_SERVER_ERROR"
Expand Down
111 changes: 110 additions & 1 deletion task-sdk/src/airflow/sdk/execution_time/comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
AssetEventResponse,
AssetEventsResponse,
AssetResponse,
AssetStateResponse,
BundleInfo,
ConnectionResponse,
DagResponse,
Expand All @@ -81,6 +82,7 @@
TaskBreadcrumbsResponse,
TaskInstance,
TaskInstanceState,
TaskStateResponse,
TaskStatesResponse,
TIDeferredStatePayload,
TIRescheduleStatePayload,
Expand Down Expand Up @@ -545,6 +547,26 @@ def from_variable_response(cls, variable_response: VariableResponse) -> Variable
return cls(**variable_response.model_dump(exclude_defaults=True), type="VariableResult")


class TaskStateResult(TaskStateResponse):
"""Response to GetTaskState; wraps the generated API response for supervisor to worker comms."""

type: Literal["TaskStateResult"] = "TaskStateResult"

@classmethod
def from_task_state_response(cls, resp: TaskStateResponse) -> TaskStateResult:
return cls(**resp.model_dump(exclude_defaults=True), type="TaskStateResult")


class AssetStateResult(AssetStateResponse):
"""Response to GetAssetState; wraps the generated API response for supervisor to worker comms."""

type: Literal["AssetStateResult"] = "AssetStateResult"

@classmethod
def from_asset_state_response(cls, resp: AssetStateResponse) -> AssetStateResult:
return cls(**resp.model_dump(exclude_defaults=True), type="AssetStateResult")


class DagRunResult(DagRun):
type: Literal["DagRunResult"] = "DagRunResult"

Expand Down Expand Up @@ -713,6 +735,7 @@ def from_api_response(cls, dag_response: DagResponse) -> DagResult:
ToTask = Annotated[
AssetResult
| AssetEventsResult
| AssetStateResult
| ConnectionResult
| DagRunResult
| DagRunStateResult
Expand All @@ -724,6 +747,7 @@ def from_api_response(cls, dag_response: DagResponse) -> DagResult:
| SentFDs
| StartupDetails
| TaskRescheduleStartDate
| TaskStateResult
| TICount
| TaskBreadcrumbsResult
| TaskStatesResult
Expand Down Expand Up @@ -852,6 +876,79 @@ class DeleteXCom(BaseModel):
type: Literal["DeleteXCom"] = "DeleteXCom"


class GetTaskState(BaseModel):
ti_id: UUID
key: str
type: Literal["GetTaskState"] = "GetTaskState"


class SetTaskState(BaseModel):
ti_id: UUID
key: str
value: str
type: Literal["SetTaskState"] = "SetTaskState"


class DeleteTaskState(BaseModel):
ti_id: UUID
key: str
type: Literal["DeleteTaskState"] = "DeleteTaskState"


class ClearTaskState(BaseModel):
ti_id: UUID
all_map_indices: bool = False
type: Literal["ClearTaskState"] = "ClearTaskState"


class GetAssetStateByName(BaseModel):
name: str
key: str
type: Literal["GetAssetStateByName"] = "GetAssetStateByName"


class GetAssetStateByUri(BaseModel):
uri: str
key: str
type: Literal["GetAssetStateByUri"] = "GetAssetStateByUri"


class SetAssetStateByName(BaseModel):
name: str
key: str
value: str
type: Literal["SetAssetStateByName"] = "SetAssetStateByName"


class SetAssetStateByUri(BaseModel):
uri: str
key: str
value: str
type: Literal["SetAssetStateByUri"] = "SetAssetStateByUri"


class DeleteAssetStateByName(BaseModel):
name: str
key: str
type: Literal["DeleteAssetStateByName"] = "DeleteAssetStateByName"


class DeleteAssetStateByUri(BaseModel):
uri: str
key: str
type: Literal["DeleteAssetStateByUri"] = "DeleteAssetStateByUri"


class ClearAssetStateByName(BaseModel):
name: str
type: Literal["ClearAssetStateByName"] = "ClearAssetStateByName"


class ClearAssetStateByUri(BaseModel):
uri: str
type: Literal["ClearAssetStateByUri"] = "ClearAssetStateByUri"


class GetConnection(BaseModel):
conn_id: str
type: Literal["GetConnection"] = "GetConnection"
Expand Down Expand Up @@ -1042,12 +1139,20 @@ class GetDag(BaseModel):


ToSupervisor = Annotated[
DeferTask
ClearAssetStateByName
| ClearAssetStateByUri
| ClearTaskState
| DeferTask
| DeleteAssetStateByName
| DeleteAssetStateByUri
| DeleteTaskState
| DeleteXCom
| GetAssetByName
| GetAssetByUri
| GetAssetEventByAsset
| GetAssetEventByAssetAlias
| GetAssetStateByName
| GetAssetStateByUri
| GetConnection
| GetDagRun
| GetDagRunState
Expand All @@ -1057,6 +1162,7 @@ class GetDag(BaseModel):
| GetPreviousDagRun
| GetPreviousTI
| GetTaskRescheduleStartDate
| GetTaskState
| GetTICount
| GetTaskBreadcrumbs
| GetTaskStates
Expand All @@ -1068,8 +1174,11 @@ class GetDag(BaseModel):
| PutVariable
| RescheduleTask
| RetryTask
| SetAssetStateByName
| SetAssetStateByUri
| SetRenderedFields
| SetRenderedMapIndex
| SetTaskState
| SetXCom
| SkipDownstreamTasks
| SucceedTask
Expand Down
Loading
Loading