Add dynamic task mapping into TaskSDK runtime
The big change here (other than just moving code around) is to introduce a
conceptual separation between Definition/Execution time and Scheduler time.

This means that the expansion of tasks (creating the TaskInstance rows with
different map_index values) is now done on the scheduler, and serialized DAGs
now deserialize to different classes there. For example, when we deserialize a
`DictOfListsExpandInput` it becomes an instance of
SchedulerDictOfListsExpandInput. This is primarily designed so that DB access
is kept 100% out of the TaskSDK.
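
To make the split concrete, here is a minimal sketch of the two sides
(class bodies are simplified for illustration; only `get_total_map_length`
is a method name that actually appears in this diff):

```python
import math
from dataclasses import dataclass


@dataclass
class DictOfListsExpandInput:
    """Definition/Execution time: a pure description of the expansion. No DB access."""

    value: dict[str, list]


@dataclass
class SchedulerDictOfListsExpandInput:
    """Scheduler time: the class the above deserializes into on the scheduler."""

    value: dict[str, list]

    def get_total_map_length(self, run_id: str, *, session) -> int:
        # Literal lists resolve directly; XComArg references would be resolved
        # against the database here (elided in this sketch). expand() maps over
        # the cross product, so the lengths multiply.
        return math.prod(len(v) for v in self.value.values())
```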

Some of the changes here are on the "wat" side of the scale; they are mostly
there to avoid breaking 100% of our tests, and we have #45549 to look at this
more holistically.

To support the "reduce"-style task, which takes as input a sequence of all the
pushed (mapped) XCom values, while keeping the previous behaviour of not loading
all values into memory at once, we have added a new HEAD route to the Task
Execution interface that returns the number of mapped XCom values, so that it
is possible to implement `__len__` on the new LazyXComSequence class.
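
As a sketch of how that count gets used (the class name LazyXComSequence is
from this commit, but the method bodies and client interface below are assumed
for illustration):

```python
class LazyXComSequence:
    """Sequence over all mapped XCom values for one task, fetched lazily."""

    def __init__(self, client, dag_id: str, run_id: str, task_id: str, key: str):
        self._client = client
        self._ref = (dag_id, run_id, task_id, key)

    def __len__(self) -> int:
        # HEAD /xcoms/{dag_id}/{run_id}/{task_id}/{key} returns the count in
        # the Content-Range header ("map_indexes <n>") without sending values.
        return self._client.xcoms.head(*self._ref)

    def __getitem__(self, map_index: int):
        # Only one mapped value is fetched (and deserialized) at a time.
        return self._client.xcoms.get(*self._ref, map_index=map_index).value
```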

I have deleted a tranche of tests from tests/models that were to do with
runtime behaviour and are now tested in the TaskSDK instead.
ashb committed Feb 6, 2025
1 parent 7af4571 commit ce56376
Showing 63 changed files with 2,406 additions and 2,336 deletions.
10 changes: 8 additions & 2 deletions airflow/api_fastapi/execution_api/app.py
@@ -77,8 +77,14 @@ def custom_openapi() -> dict:

def get_extra_schemas() -> dict[str, dict]:
"""Get all the extra schemas that are not part of the main FastAPI app."""
from airflow.api_fastapi.execution_api.datamodels import taskinstance
from airflow.api_fastapi.execution_api.datamodels.taskinstance import TaskInstance
from airflow.executors.workloads import BundleInfo
from airflow.utils.state import TerminalTIState

return {
"TaskInstance": taskinstance.TaskInstance.model_json_schema(),
"TaskInstance": TaskInstance.model_json_schema(),
"BundleInfo": BundleInfo.model_json_schema(),
# Include the combined state enum too. In the datamodels we separate out SUCCESS from the other states
# as that has different payload requirements
"TerminalTIState": {"type": "string", "enum": list(TerminalTIState)},
}
21 changes: 14 additions & 7 deletions airflow/api_fastapi/execution_api/datamodels/taskinstance.py
@@ -14,11 +14,11 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

import uuid
from datetime import timedelta
from enum import Enum
from typing import Annotated, Any, Literal, Union

from pydantic import (
@@ -60,15 +60,20 @@ class TIEnterRunningPayload(StrictBaseModel):
"""When the task started executing"""


# Create an enum to give a nice name in the generated datamodels
class TerminalStateNonSuccess(str, Enum):
"""TaskInstance states that can be reported without extra information."""

FAILED = TerminalTIState.FAILED
SKIPPED = TerminalTIState.SKIPPED
REMOVED = TerminalTIState.REMOVED
FAIL_WITHOUT_RETRY = TerminalTIState.FAIL_WITHOUT_RETRY


class TITerminalStatePayload(StrictBaseModel):
"""Schema for updating TaskInstance to a terminal state except SUCCESS state."""

state: Literal[
TerminalTIState.FAILED,
TerminalTIState.SKIPPED,
TerminalTIState.REMOVED,
TerminalTIState.FAIL_WITHOUT_RETRY,
]
state: TerminalStateNonSuccess

end_date: UtcDateTime
"""When the task completed executing"""
@@ -242,6 +247,8 @@ class TIRunContext(BaseModel):
connections: Annotated[list[ConnectionResponse], Field(default_factory=list)]
"""Connections that can be accessed by the task instance."""

upstream_map_indexes: dict[str, int] | None = None


class PrevSuccessfulDagRunResponse(BaseModel):
"""Schema for response with previous successful DagRun information for Task Template Context."""
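The TerminalStateNonSuccess enum above replaces an inline Literal[...] so that
generated clients get a named type. A minimal pydantic comparison (toy models,
not from this commit) shows the difference in the emitted JSON schema:

```python
from enum import Enum
from typing import Literal

from pydantic import BaseModel


class Color(str, Enum):
    RED = "red"
    BLUE = "blue"


class WithEnum(BaseModel):
    state: Color


class WithLiteral(BaseModel):
    state: Literal["red", "blue"]


# The Enum version produces a named, reusable "$defs" component ("Color") that
# client generators turn into a real type; the Literal stays anonymous/inline.
print(WithEnum.model_json_schema()["$defs"].keys())  # dict_keys(['Color'])
print(WithLiteral.model_json_schema()["properties"]["state"])  # inline enum
```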
2 changes: 2 additions & 0 deletions airflow/api_fastapi/execution_api/routes/task_instances.py
@@ -65,6 +65,7 @@
status.HTTP_409_CONFLICT: {"description": "The TI is already in the requested state"},
status.HTTP_422_UNPROCESSABLE_ENTITY: {"description": "Invalid payload for the state transition"},
},
response_model_exclude_unset=True,
)
def ti_run(
task_instance_id: UUID, ti_run_payload: Annotated[TIEnterRunningPayload, Body()], session: SessionDep
@@ -149,6 +150,7 @@ def ti_run(
DR.run_type,
DR.conf,
DR.logical_date,
DR.external_trigger,
).filter_by(dag_id=dag_id, run_id=run_id)
).one_or_none()

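The new `response_model_exclude_unset=True` flag means fields the handler never
set are omitted from the response body rather than serialized as nulls or
defaults. A minimal FastAPI sketch (toy model, not from this commit):

```python
from fastapi import FastAPI
from pydantic import BaseModel

app = FastAPI()


class Item(BaseModel):
    name: str
    note: str | None = None


@app.get("/item", response_model=Item, response_model_exclude_unset=True)
def get_item() -> Item:
    # "note" is never set, so the JSON response is just {"name": "x"}
    return Item(name="x")
```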
91 changes: 73 additions & 18 deletions airflow/api_fastapi/execution_api/routes/xcoms.py
@@ -21,7 +21,8 @@
import logging
from typing import Annotated

from fastapi import Body, HTTPException, Query, status
from fastapi import Body, Depends, HTTPException, Query, Response, status
from sqlalchemy.sql.selectable import Select

from airflow.api_fastapi.common.db.common import SessionDep
from airflow.api_fastapi.common.router import AirflowRouter
@@ -30,6 +31,7 @@
from airflow.api_fastapi.execution_api.datamodels.xcom import XComResponse
from airflow.models.taskmap import TaskMap
from airflow.models.xcom import BaseXCom
from airflow.utils.db import get_query_count

# TODO: Add dependency on JWT token
router = AirflowRouter(
@@ -42,20 +44,15 @@
log = logging.getLogger(__name__)


@router.get(
"/{dag_id}/{run_id}/{task_id}/{key}",
responses={status.HTTP_404_NOT_FOUND: {"description": "XCom not found"}},
)
def get_xcom(
async def xcom_query(
dag_id: str,
run_id: str,
task_id: str,
key: str,
token: deps.TokenDep,
session: SessionDep,
map_index: Annotated[int, Query()] = -1,
) -> XComResponse:
"""Get an Airflow XCom from database - not other XCom Backends."""
token: deps.TokenDep,
map_index: Annotated[int | None, Query()] = None,
) -> Select:
if not has_xcom_access(dag_id, run_id, task_id, key, token):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
@@ -65,29 +62,87 @@ def get_xcom(
},
)

# We use `BaseXCom.get_many` to fetch XComs directly from the database, bypassing the XCom Backend.
# This avoids deserialization via the backend (e.g., from a remote storage like S3) and instead
# retrieves the raw serialized value from the database. By not relying on `XCom.get_many` or `XCom.get_one`
# (which automatically deserializes using the backend), we avoid potential
# performance hits from retrieving large data files into the API server.
query = BaseXCom.get_many(
run_id=run_id,
key=key,
task_ids=task_id,
dag_ids=dag_id,
map_indexes=map_index,
limit=1,
session=session,
)
return query.with_entities(BaseXCom.value)


@router.head(
"/{dag_id}/{run_id}/{task_id}/{key}",
responses={
status.HTTP_404_NOT_FOUND: {"description": "XCom not found"},
status.HTTP_200_OK: {
"description": "Metadata about the number of matching XCom values",
"headers": {
"Content-Range": {
"pattern": r"^map_indexes \d+$",
"description": "The number of (mapped) XCom values found for this task.",
},
},
},
},
description="Return the count of the number of XCom values found via the Content-Range response header",
)
def head_xcom(
response: Response,
token: deps.TokenDep,
session: SessionDep,
xcom_query: Annotated[Select, Depends(xcom_query)],
map_index: Annotated[int | None, Query()] = None,
) -> None:
"""Get the count of XComs from database - not other XCom Backends."""
if map_index is not None:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={"reason": "invalid_request", "message": "Cannot specify map_index in a HEAD request"},
)

count = get_query_count(xcom_query, session=session)
# Tell the caller how many items in this query. We define a custom range unit (HTTP spec only defines
# "bytes" but we can add our own)
response.headers["Content-Range"] = f"map_indexes {count}"


@router.get(
"/{dag_id}/{run_id}/{task_id}/{key}",
responses={status.HTTP_404_NOT_FOUND: {"description": "XCom not found"}},
description="Get a single XCom Value",
)
def get_xcom(
session: SessionDep,
dag_id: str,
run_id: str,
task_id: str,
key: str,
xcom_query: Annotated[Select, Depends(xcom_query)],
map_index: Annotated[int, Query()] = -1,
) -> XComResponse:
"""Get an Airflow XCom from database - not other XCom Backends."""
# The shared xcom_query dependency allows map_index to be omitted. This endpoint should always
# return just a single item, so we override that query value here

xcom_query = xcom_query.filter_by(map_index=map_index)

# We use `BaseXCom.get_many` to fetch XComs directly from the database, bypassing the XCom Backend.
# This avoids deserialization via the backend (e.g., from a remote storage like S3) and instead
# retrieves the raw serialized value from the database. By not relying on `XCom.get_many` or `XCom.get_one`
# (which automatically deserializes using the backend), we avoid potential
# performance hits from retrieving large data files into the API server.

result = query.with_entities(BaseXCom.value).first()
result = xcom_query.limit(1).first()

if result is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail={
"reason": "not_found",
"message": f"XCom with key '{key}' not found for task '{task_id}' in DAG '{dag_id}'",
"message": f"XCom with {key=} {map_index=} not found for task {task_id!r} in DAG run {run_id!r} of {dag_id!r}",
},
)

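From the worker side, the count advertised by the HEAD route can be read
without fetching any values. A hypothetical client snippet (the URL shape is
taken from the route above; everything else is assumed for illustration):

```python
import re

import httpx


def get_mapped_xcom_count(base_url: str, dag_id: str, run_id: str, task_id: str, key: str) -> int:
    """Return how many mapped XCom values exist, using only response headers."""
    resp = httpx.head(f"{base_url}/xcoms/{dag_id}/{run_id}/{task_id}/{key}")
    resp.raise_for_status()
    # The server uses a custom range unit, e.g. "Content-Range: map_indexes 7"
    match = re.fullmatch(r"map_indexes (\d+)", resp.headers["Content-Range"])
    if match is None:
        raise ValueError("unexpected Content-Range header")
    return int(match.group(1))
```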
16 changes: 7 additions & 9 deletions airflow/decorators/base.py
@@ -55,8 +55,6 @@
from airflow.utils.types import NOTSET

if TYPE_CHECKING:
from sqlalchemy.orm import Session

from airflow.models.expandinput import (
ExpandInput,
OperatorExpandArgument,
@@ -184,7 +182,9 @@ def __init__(
kwargs_to_upstream: dict[str, Any] | None = None,
**kwargs,
) -> None:
task_id = get_unique_task_id(task_id, kwargs.get("dag"), kwargs.get("task_group"))
if not getattr(self, "_BaseOperator__from_mapped", False):
# If we are being created from calling unmap(), then don't mangle the task id
task_id = get_unique_task_id(task_id, kwargs.get("dag"), kwargs.get("task_group"))
self.python_callable = python_callable
kwargs_to_upstream = kwargs_to_upstream or {}
op_args = op_args or []
@@ -218,10 +218,10 @@ def __init__(
The function signature broke while assigning defaults to context key parameters.
The decorator is replacing the signature
> {python_callable.__name__}({', '.join(str(param) for param in signature.parameters.values())})
> {python_callable.__name__}({", ".join(str(param) for param in signature.parameters.values())})
with
> {python_callable.__name__}({', '.join(str(param) for param in parameters)})
> {python_callable.__name__}({", ".join(str(param) for param in parameters)})
which isn't valid: {err}
"""
@@ -568,13 +568,11 @@ def __attrs_post_init__(self):
super(DecoratedMappedOperator, DecoratedMappedOperator).__attrs_post_init__(self)
XComArg.apply_upstream_relationship(self, self.op_kwargs_expand_input.value)

def _expand_mapped_kwargs(
self, context: Mapping[str, Any], session: Session, *, include_xcom: bool
) -> tuple[Mapping[str, Any], set[int]]:
def _expand_mapped_kwargs(self, context: Mapping[str, Any]) -> tuple[Mapping[str, Any], set[int]]:
# We only use op_kwargs_expand_input so this must always be empty.
if self.expand_input is not EXPAND_INPUT_EMPTY:
raise AssertionError(f"unexpected expand_input: {self.expand_input}")
op_kwargs, resolved_oids = super()._expand_mapped_kwargs(context, session, include_xcom=include_xcom)
op_kwargs, resolved_oids = super()._expand_mapped_kwargs(context)
return {"op_kwargs": op_kwargs}, resolved_oids

def _get_unmap_kwargs(self, mapped_kwargs: Mapping[str, Any], *, strict: bool) -> dict[str, Any]:
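The `getattr(self, "_BaseOperator__from_mapped", False)` check above relies on
Python's double-underscore name mangling: an attribute written as
`self.__from_mapped` inside BaseOperator is stored under the mangled name. A
toy illustration (class names reused for clarity, bodies assumed):

```python
class BaseOperator:
    def __init__(self):
        self.__from_mapped = True  # actually stored as _BaseOperator__from_mapped


class Decorated(BaseOperator):
    def created_from_unmap(self) -> bool:
        # Subclasses must use the mangled name to read the flag.
        return getattr(self, "_BaseOperator__from_mapped", False)


assert Decorated().created_from_unmap() is True
```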
2 changes: 1 addition & 1 deletion airflow/models/abstractoperator.py
@@ -27,7 +27,6 @@

from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.models.expandinput import NotFullyPopulated
from airflow.sdk.definitions._internal.abstractoperator import (
AbstractOperator as TaskSDKAbstractOperator,
NotMapped as NotMapped, # Re-export this for compat
@@ -237,6 +236,7 @@ def expand_mapped_task(self, run_id: str, *, session: Session) -> tuple[Sequence
)

from airflow.models.baseoperator import BaseOperator as DBBaseOperator
from airflow.models.expandinput import NotFullyPopulated

try:
total_length: int | None = DBBaseOperator.get_mapped_ti_count(self, run_id, session=session)
31 changes: 23 additions & 8 deletions airflow/models/baseoperator.py
@@ -848,11 +848,20 @@ def _(cls, task: TaskSDKAbstractOperator, run_id: str, *, session: Session) -> i
@get_mapped_ti_count.register(MappedOperator)
@classmethod
def _(cls, task: MappedOperator, run_id: str, *, session: Session) -> int:
from airflow.serialization.serialized_objects import _ExpandInputRef
from airflow.serialization.serialized_objects import BaseSerialization, _ExpandInputRef

exp_input = task._get_specified_expand_input()
if isinstance(exp_input, _ExpandInputRef):
exp_input = exp_input.deref(task.dag)
# TODO: TaskSDK This is only needed to support `dag.test()` etc until we port it over to use the
# task sdk runner.
if not hasattr(exp_input, "get_total_map_length"):
exp_input = _ExpandInputRef(
type(exp_input).EXPAND_INPUT_TYPE,
BaseSerialization.deserialize(BaseSerialization.serialize(exp_input.value)),
)
exp_input = exp_input.deref(task.dag)

current_count = exp_input.get_total_map_length(run_id, session=session)

group = task.get_closest_mapped_task_group()
@@ -878,18 +887,24 @@ def _(cls, group: TaskGroup, run_id: str, *, session: Session) -> int:
:raise NotFullyPopulated: If upstream tasks are not all complete yet.
:return: Total number of mapped TIs this task should have.
"""
from airflow.serialization.serialized_objects import BaseSerialization, _ExpandInputRef

def iter_mapped_task_groups(group) -> Iterator[MappedTaskGroup]:
def iter_mapped_task_group_lengths(group) -> Iterator[int]:
while group is not None:
if isinstance(group, MappedTaskGroup):
yield group
exp_input = group._expand_input
# TODO: TaskSDK This is only needed to support `dag.test()` etc until we port it over to use the
# task sdk runner.
if not hasattr(exp_input, "get_total_map_length"):
exp_input = _ExpandInputRef(
type(exp_input).EXPAND_INPUT_TYPE,
BaseSerialization.deserialize(BaseSerialization.serialize(exp_input.value)),
)
exp_input = exp_input.deref(group.dag)
yield exp_input.get_total_map_length(run_id, session=session)
group = group.parent_group

groups = iter_mapped_task_groups(group)
return functools.reduce(
operator.mul,
(g._expand_input.get_total_map_length(run_id, session=session) for g in groups),
)
return functools.reduce(operator.mul, iter_mapped_task_group_lengths(group))


def chain(*tasks: DependencyMixin | Sequence[DependencyMixin]) -> None:
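Why the reduce over `operator.mul`: each level of mapped task-group nesting
fans out every instance of the level above, so the total TaskInstance count for
a task is the product of the map lengths up its group chain. For example:

```python
import functools
import operator

# Map lengths from the task's own mapped group up through its parent groups,
# e.g. a group expanded over 2 values containing a group expanded over 3.
lengths = [3, 2]

total_tis = functools.reduce(operator.mul, lengths)
assert total_tis == 6  # every outer instance gets its own 3 inner instances
```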
4 changes: 3 additions & 1 deletion airflow/models/dagrun.py
@@ -63,7 +63,6 @@
from airflow.models.backfill import Backfill
from airflow.models.base import Base, StringID
from airflow.models.dag_version import DagVersion
from airflow.models.expandinput import NotFullyPopulated
from airflow.models.taskinstance import TaskInstance as TI
from airflow.models.tasklog import LogTemplate
from airflow.models.taskmap import TaskMap
@@ -1354,6 +1353,7 @@ def _check_for_removed_or_restored_tasks(
"""
from airflow.models.baseoperator import BaseOperator
from airflow.models.expandinput import NotFullyPopulated

tis = self.get_task_instances(session=session)

@@ -1491,6 +1491,7 @@ def _create_tasks(
:param task_creator: Function to create task instances
"""
from airflow.models.baseoperator import BaseOperator
from airflow.models.expandinput import NotFullyPopulated

map_indexes: Iterable[int]
for task in tasks:
@@ -1562,6 +1563,7 @@ def _revise_map_indexes_if_mapped(self, task: Operator, *, session: Session) ->
for more details.
"""
from airflow.models.baseoperator import BaseOperator
from airflow.models.expandinput import NotFullyPopulated
from airflow.settings import task_instance_mutation_hook

try:
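The repeated pattern in this file (and in abstractoperator.py above) of
importing NotFullyPopulated inside the function body is deliberate: a deferred
import keeps scheduler/DB-only modules off the module-level import path. A
generic sketch of the pattern (function shape illustrative):

```python
def expand_mapped_task(run_id: str, *, session) -> None:
    # Deferred import: only code paths that actually run scheduler-side
    # expansion need to be able to import the DB-backed module.
    from airflow.models.expandinput import NotFullyPopulated

    try:
        ...  # compute mapped lengths for this run
    except NotFullyPopulated:
        ...  # upstream map counts are not known yet
```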