
Add dynamic task mapping into TaskSDK runtime (apache#46032)
The big change here (other than just moving code around) is to introduce a
conceptual separation between Definition/Execution time and Scheduler time.

This means that the expansion of tasks (creating the TaskInstance rows with
different map_index values) is now done on the scheduler, and we now
deserialize to different classes. For example, when we deserialize the
`DictOfListsExpandInput` it gets turned into an instance of
SchedulerDictOfListsExpandInput. This is primarily designed so that DB access
is kept 100% out of the TaskSDK.
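Roughly, and only as a sketch (the method body below is a simplification of
the literal-list case, not the real implementation):

    import functools
    import operator

    class DictOfListsExpandInput:
        # TaskSDK: definition/execution time. Pure data, no DB access.
        def __init__(self, value):
            self.value = value  # e.g. {"x": [1, 2, 3], "y": ["a", "b"]}

    class SchedulerDictOfListsExpandInput:
        # Scheduler: the deserialized counterpart, allowed to talk to the DB.
        def __init__(self, value):
            self.value = value

        def get_total_map_length(self, run_id, *, session):
            # Cross-product semantics: expand(x=[1, 2, 3], y=["a", "b"])
            # expands to 3 * 2 = 6 TaskInstance rows. Mapped upstream values
            # would need DB lookups here, which is why this method lives on
            # the scheduler-side class.
            return functools.reduce(operator.mul, (len(v) for v in self.value.values()), 1)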

Some of the changes here are on the "wat" side of the scale; they are mostly
designed to avoid breaking 100% of our tests, and we have apache#45549 to look
at this more holistically.

To support the "reduce" style task, which takes as input a sequence of all the
pushed (mapped) XCom values, and to keep the previous behaviour of not loading
all values into memory at once, we have added a new HEAD route to the Task
Execution interface that returns the number of mapped XCom values, making it
possible to implement `__len__` on the new LazyXComSequence class.
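As a hedged sketch of the client side (the HTTP client and URL shape are
illustrative; the route and the "map_indexes" range unit are what this commit
adds), `__len__` can be implemented without fetching any values:

    import re
    import httpx

    class LazyXComSequence:
        def __init__(self, client: httpx.Client, url: str):
            self._client = client
            self._url = url  # .../xcoms/{dag_id}/{run_id}/{task_id}/{key}

        def __len__(self) -> int:
            # HEAD carries no body; the count travels in the Content-Range
            # header using the custom "map_indexes" range unit.
            resp = self._client.head(self._url)
            resp.raise_for_status()
            match = re.fullmatch(r"map_indexes (\d+)", resp.headers["Content-Range"])
            return int(match.group(1))

        def __getitem__(self, map_index: int):
            # Individual values are fetched lazily, one GET per item.
            resp = self._client.get(self._url, params={"map_index": map_index})
            resp.raise_for_status()
            return resp.json()["value"]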

I have deleted a tranche of tests from tests/models that were to do with
runtime behaviour and are now tested in the TaskSDK instead.
ashb authored and insomnes committed Feb 6, 2025
1 parent 76ca340 commit b58f1a4
Showing 63 changed files with 2,406 additions and 2,336 deletions.
10 changes: 8 additions & 2 deletions airflow/api_fastapi/execution_api/app.py
@@ -77,8 +77,14 @@ def custom_openapi() -> dict:

def get_extra_schemas() -> dict[str, dict]:
"""Get all the extra schemas that are not part of the main FastAPI app."""
from airflow.api_fastapi.execution_api.datamodels import taskinstance
from airflow.api_fastapi.execution_api.datamodels.taskinstance import TaskInstance
from airflow.executors.workloads import BundleInfo
from airflow.utils.state import TerminalTIState

return {
"TaskInstance": taskinstance.TaskInstance.model_json_schema(),
"TaskInstance": TaskInstance.model_json_schema(),
"BundleInfo": BundleInfo.model_json_schema(),
# Include the combined state enum too. In the datamodels we separate out SUCCESS from the other states
# as that has different payload requirements
"TerminalTIState": {"type": "string", "enum": list(TerminalTIState)},
}
21 changes: 14 additions & 7 deletions airflow/api_fastapi/execution_api/datamodels/taskinstance.py
@@ -14,11 +14,11 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

import uuid
from datetime import timedelta
from enum import Enum
from typing import Annotated, Any, Literal, Union

from pydantic import (
@@ -60,15 +60,20 @@ class TIEnterRunningPayload(StrictBaseModel):
"""When the task started executing"""


# Create an enum to give a nice name in the generated datamodels
class TerminalStateNonSuccess(str, Enum):
"""TaskInstance states that can be reported without extra information."""

FAILED = TerminalTIState.FAILED
SKIPPED = TerminalTIState.SKIPPED
REMOVED = TerminalTIState.REMOVED
FAIL_WITHOUT_RETRY = TerminalTIState.FAIL_WITHOUT_RETRY


class TITerminalStatePayload(StrictBaseModel):
"""Schema for updating TaskInstance to a terminal state except SUCCESS state."""

state: Literal[
TerminalTIState.FAILED,
TerminalTIState.SKIPPED,
TerminalTIState.REMOVED,
TerminalTIState.FAIL_WITHOUT_RETRY,
]
state: TerminalStateNonSuccess

end_date: UtcDateTime
"""When the task completed executing"""
@@ -242,6 +247,8 @@ class TIRunContext(BaseModel):
connections: Annotated[list[ConnectionResponse], Field(default_factory=list)]
"""Connections that can be accessed by the task instance."""

upstream_map_indexes: dict[str, int] | None = None


class PrevSuccessfulDagRunResponse(BaseModel):
"""Schema for response with previous successful DagRun information for Task Template Context."""
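To illustrate the new datamodels, a quick sketch of a payload the non-SUCCESS
terminal model accepts (the concrete values here are assumptions, not from
this commit):

    from datetime import datetime, timezone

    # SUCCESS is deliberately not a member of TerminalStateNonSuccess, so a
    # success report (which has extra payload requirements) cannot use this model.
    payload = TITerminalStatePayload(
        state=TerminalStateNonSuccess.FAILED,
        end_date=datetime(2025, 2, 6, tzinfo=timezone.utc),
    )
    assert payload.state == "failed"  # str-valued enum member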
2 changes: 2 additions & 0 deletions airflow/api_fastapi/execution_api/routes/task_instances.py
@@ -65,6 +65,7 @@
status.HTTP_409_CONFLICT: {"description": "The TI is already in the requested state"},
status.HTTP_422_UNPROCESSABLE_ENTITY: {"description": "Invalid payload for the state transition"},
},
response_model_exclude_unset=True,
)
def ti_run(
task_instance_id: UUID, ti_run_payload: Annotated[TIEnterRunningPayload, Body()], session: SessionDep
@@ -149,6 +150,7 @@ def ti_run(
DR.run_type,
DR.conf,
DR.logical_date,
DR.external_trigger,
).filter_by(dag_id=dag_id, run_id=run_id)
).one_or_none()

91 changes: 73 additions & 18 deletions airflow/api_fastapi/execution_api/routes/xcoms.py
@@ -21,7 +21,8 @@
import logging
from typing import Annotated

from fastapi import Body, HTTPException, Query, status
from fastapi import Body, Depends, HTTPException, Query, Response, status
from sqlalchemy.sql.selectable import Select

from airflow.api_fastapi.common.db.common import SessionDep
from airflow.api_fastapi.common.router import AirflowRouter
@@ -30,6 +31,7 @@
from airflow.api_fastapi.execution_api.datamodels.xcom import XComResponse
from airflow.models.taskmap import TaskMap
from airflow.models.xcom import BaseXCom
from airflow.utils.db import get_query_count

# TODO: Add dependency on JWT token
router = AirflowRouter(
@@ -42,20 +44,15 @@
log = logging.getLogger(__name__)


@router.get(
"/{dag_id}/{run_id}/{task_id}/{key}",
responses={status.HTTP_404_NOT_FOUND: {"description": "XCom not found"}},
)
def get_xcom(
async def xcom_query(
dag_id: str,
run_id: str,
task_id: str,
key: str,
token: deps.TokenDep,
session: SessionDep,
map_index: Annotated[int, Query()] = -1,
) -> XComResponse:
"""Get an Airflow XCom from database - not other XCom Backends."""
token: deps.TokenDep,
map_index: Annotated[int | None, Query()] = None,
) -> Select:
if not has_xcom_access(dag_id, run_id, task_id, key, token):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
@@ -65,29 +62,87 @@ def get_xcom(
},
)

# We use `BaseXCom.get_many` to fetch XComs directly from the database, bypassing the XCom Backend.
# This avoids deserialization via the backend (e.g., from a remote storage like S3) and instead
# retrieves the raw serialized value from the database. By not relying on `XCom.get_many` or `XCom.get_one`
# (which automatically deserializes using the backend), we avoid potential
# performance hits from retrieving large data files into the API server.
query = BaseXCom.get_many(
run_id=run_id,
key=key,
task_ids=task_id,
dag_ids=dag_id,
map_indexes=map_index,
limit=1,
session=session,
)
return query.with_entities(BaseXCom.value)


@router.head(
"/{dag_id}/{run_id}/{task_id}/{key}",
responses={
status.HTTP_404_NOT_FOUND: {"description": "XCom not found"},
status.HTTP_200_OK: {
"description": "Metadata about the number of matching XCom values",
"headers": {
"Content-Range": {
"pattern": r"^map_indexes \d+$",
"description": "The number of (mapped) XCom values found for this task.",
},
},
},
},
description="Return the count of the number of XCom values found via the Content-Range response header",
)
def head_xcom(
response: Response,
token: deps.TokenDep,
session: SessionDep,
xcom_query: Annotated[Select, Depends(xcom_query)],
map_index: Annotated[int | None, Query()] = None,
) -> None:
"""Get the count of XComs from database - not other XCom Backends."""
if map_index is not None:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={"reason": "invalid_request", "message": "Cannot specify map_index in a HEAD request"},
)

count = get_query_count(xcom_query, session=session)
# Tell the caller how many items match this query. We define a custom range unit (the HTTP spec only
# defines "bytes", but we can add our own)
response.headers["Content-Range"] = f"map_indexes {count}"


@router.get(
"/{dag_id}/{run_id}/{task_id}/{key}",
responses={status.HTTP_404_NOT_FOUND: {"description": "XCom not found"}},
description="Get a single XCom Value",
)
def get_xcom(
session: SessionDep,
dag_id: str,
run_id: str,
task_id: str,
key: str,
xcom_query: Annotated[Select, Depends(xcom_query)],
map_index: Annotated[int, Query()] = -1,
) -> XComResponse:
"""Get an Airflow XCom from database - not other XCom Backends."""
# The xcom_query dependency allows map_index to be omitted. This endpoint should always return just a
# single item, so we override that query value here

xcom_query = xcom_query.filter_by(map_index=map_index)

# We use `BaseXCom.get_many` to fetch XComs directly from the database, bypassing the XCom Backend.
# This avoids deserialization via the backend (e.g., from a remote storage like S3) and instead
# retrieves the raw serialized value from the database. By not relying on `XCom.get_many` or `XCom.get_one`
# (which automatically deserializes using the backend), we avoid potential
# performance hits from retrieving large data files into the API server.

result = query.with_entities(BaseXCom.value).first()
result = xcom_query.limit(1).first()

if result is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail={
"reason": "not_found",
"message": f"XCom with key '{key}' not found for task '{task_id}' in DAG '{dag_id}'",
"message": f"XCom with {key=} {map_index=} not found for task {task_id!r} in DAG run {run_id!r} of {dag_id!r}",
},
)

16 changes: 7 additions & 9 deletions airflow/decorators/base.py
@@ -55,8 +55,6 @@
from airflow.utils.types import NOTSET

if TYPE_CHECKING:
from sqlalchemy.orm import Session

from airflow.models.expandinput import (
ExpandInput,
OperatorExpandArgument,
@@ -184,7 +182,9 @@ def __init__(
kwargs_to_upstream: dict[str, Any] | None = None,
**kwargs,
) -> None:
task_id = get_unique_task_id(task_id, kwargs.get("dag"), kwargs.get("task_group"))
if not getattr(self, "_BaseOperator__from_mapped", False):
# If we are being created from calling unmap(), then don't mangle the task id
task_id = get_unique_task_id(task_id, kwargs.get("dag"), kwargs.get("task_group"))
self.python_callable = python_callable
kwargs_to_upstream = kwargs_to_upstream or {}
op_args = op_args or []
@@ -218,10 +218,10 @@ def __init__(
The function signature broke while assigning defaults to context key parameters.
The decorator is replacing the signature
> {python_callable.__name__}({', '.join(str(param) for param in signature.parameters.values())})
> {python_callable.__name__}({", ".join(str(param) for param in signature.parameters.values())})
with
> {python_callable.__name__}({', '.join(str(param) for param in parameters)})
> {python_callable.__name__}({", ".join(str(param) for param in parameters)})
which isn't valid: {err}
"""
@@ -568,13 +568,11 @@ def __attrs_post_init__(self):
super(DecoratedMappedOperator, DecoratedMappedOperator).__attrs_post_init__(self)
XComArg.apply_upstream_relationship(self, self.op_kwargs_expand_input.value)

def _expand_mapped_kwargs(
self, context: Mapping[str, Any], session: Session, *, include_xcom: bool
) -> tuple[Mapping[str, Any], set[int]]:
def _expand_mapped_kwargs(self, context: Mapping[str, Any]) -> tuple[Mapping[str, Any], set[int]]:
# We only use op_kwargs_expand_input so this must always be empty.
if self.expand_input is not EXPAND_INPUT_EMPTY:
raise AssertionError(f"unexpected expand_input: {self.expand_input}")
op_kwargs, resolved_oids = super()._expand_mapped_kwargs(context, session, include_xcom=include_xcom)
op_kwargs, resolved_oids = super()._expand_mapped_kwargs(context)
return {"op_kwargs": op_kwargs}, resolved_oids

def _get_unmap_kwargs(self, mapped_kwargs: Mapping[str, Any], *, strict: bool) -> dict[str, Any]:
2 changes: 1 addition & 1 deletion airflow/models/abstractoperator.py
@@ -27,7 +27,6 @@

from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.models.expandinput import NotFullyPopulated
from airflow.sdk.definitions._internal.abstractoperator import (
AbstractOperator as TaskSDKAbstractOperator,
NotMapped as NotMapped, # Re-export this for compat
@@ -237,6 +236,7 @@ def expand_mapped_task(self, run_id: str, *, session: Session) -> tuple[Sequence
)

from airflow.models.baseoperator import BaseOperator as DBBaseOperator
from airflow.models.expandinput import NotFullyPopulated

try:
total_length: int | None = DBBaseOperator.get_mapped_ti_count(self, run_id, session=session)
31 changes: 23 additions & 8 deletions airflow/models/baseoperator.py
@@ -848,11 +848,20 @@ def _(cls, task: TaskSDKAbstractOperator, run_id: str, *, session: Session) -> int:
@get_mapped_ti_count.register(MappedOperator)
@classmethod
def _(cls, task: MappedOperator, run_id: str, *, session: Session) -> int:
from airflow.serialization.serialized_objects import _ExpandInputRef
from airflow.serialization.serialized_objects import BaseSerialization, _ExpandInputRef

exp_input = task._get_specified_expand_input()
if isinstance(exp_input, _ExpandInputRef):
exp_input = exp_input.deref(task.dag)
# TODO: TaskSDK This is only needed to support `dag.test()` etc until we port it over to use the
# task sdk runner.
if not hasattr(exp_input, "get_total_map_length"):
exp_input = _ExpandInputRef(
type(exp_input).EXPAND_INPUT_TYPE,
BaseSerialization.deserialize(BaseSerialization.serialize(exp_input.value)),
)
exp_input = exp_input.deref(task.dag)

current_count = exp_input.get_total_map_length(run_id, session=session)

group = task.get_closest_mapped_task_group()
@@ -878,18 +887,24 @@ def _(cls, group: TaskGroup, run_id: str, *, session: Session) -> int:
:raise NotFullyPopulated: If upstream tasks are not all complete yet.
:return: Total number of mapped TIs this task should have.
"""
from airflow.serialization.serialized_objects import BaseSerialization, _ExpandInputRef

def iter_mapped_task_groups(group) -> Iterator[MappedTaskGroup]:
def iter_mapped_task_group_lengths(group) -> Iterator[int]:
while group is not None:
if isinstance(group, MappedTaskGroup):
yield group
exp_input = group._expand_input
# TODO: TaskSDK This is only needed to support `dag.test()` etc until we port it over to use the
# task sdk runner.
if not hasattr(exp_input, "get_total_map_length"):
exp_input = _ExpandInputRef(
type(exp_input).EXPAND_INPUT_TYPE,
BaseSerialization.deserialize(BaseSerialization.serialize(exp_input.value)),
)
exp_input = exp_input.deref(group.dag)
yield exp_input.get_total_map_length(run_id, session=session)
group = group.parent_group

groups = iter_mapped_task_groups(group)
return functools.reduce(
operator.mul,
(g._expand_input.get_total_map_length(run_id, session=session) for g in groups),
)
return functools.reduce(operator.mul, iter_mapped_task_group_lengths(group))


def chain(*tasks: DependencyMixin | Sequence[DependencyMixin]) -> None:
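Since iter_mapped_task_group_lengths yields one expansion length per enclosing
mapped task group and the reduce multiplies them, nesting mapped groups
multiplies the TaskInstance count. A tiny illustrative check with made-up
lengths:

    import functools
    import operator

    # A task in a mapped group that expands to 3, nested inside a mapped
    # group that expands to 2, gets 3 * 2 = 6 TaskInstance rows.
    assert functools.reduce(operator.mul, [3, 2]) == 6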
4 changes: 3 additions & 1 deletion airflow/models/dagrun.py
@@ -63,7 +63,6 @@
from airflow.models.backfill import Backfill
from airflow.models.base import Base, StringID
from airflow.models.dag_version import DagVersion
from airflow.models.expandinput import NotFullyPopulated
from airflow.models.taskinstance import TaskInstance as TI
from airflow.models.tasklog import LogTemplate
from airflow.models.taskmap import TaskMap
@@ -1354,6 +1353,7 @@ def _check_for_removed_or_restored_tasks(
"""
from airflow.models.baseoperator import BaseOperator
from airflow.models.expandinput import NotFullyPopulated

tis = self.get_task_instances(session=session)

@@ -1491,6 +1491,7 @@ def _create_tasks(
:param task_creator: Function to create task instances
"""
from airflow.models.baseoperator import BaseOperator
from airflow.models.expandinput import NotFullyPopulated

map_indexes: Iterable[int]
for task in tasks:
@@ -1562,6 +1563,7 @@ def _revise_map_indexes_if_mapped(self, task: Operator, *, session: Session) ->
for more details.
"""
from airflow.models.baseoperator import BaseOperator
from airflow.models.expandinput import NotFullyPopulated
from airflow.settings import task_instance_mutation_hook

try:
