Add dynamic task mapping into TaskSDK runtime
The big change here (other than just moving code around) is to introduce a
conceptual separation between Definition/Execution time and Scheduler time.

This means that the expansion of tasks (creating the TaskInstance rows with
different map_index values) is now done on the scheduler, and we now
deserialize to different classes. For example, when we deserialize the
`DictOfListsExpandInput`, it gets turned into an instance of
`SchedulerDictOfListsExpandInput`. This is primarily designed so that DB
access is kept 100% out of the TaskSDK.
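
As a rough sketch of that split (illustrative names only, apart from the two
class names above; the real dispatch lives in the serialization code): the
TaskSDK serializes just a type key plus the raw value, and the scheduler
rehydrates that into a DB-aware twin class:

    # Hypothetical sketch, not the actual Airflow implementation.
    class SchedulerDictOfListsExpandInput:
        """Scheduler-side twin of DictOfListsExpandInput; only it may touch the DB."""

        def __init__(self, value):
            self.value = value

        def get_total_map_length(self, run_id, *, session):
            # Uses the DB session to work out how many TIs to expand to.
            raise NotImplementedError

    _SCHEDULER_EXPAND_INPUTS = {
        "dict-of-lists": SchedulerDictOfListsExpandInput,  # key name is an assumption
    }

    def deserialize_expand_input(type_key, value):
        # Deserialization picks the scheduler-side class, so the TaskSDK
        # itself never needs a database connection.
        return _SCHEDULER_EXPAND_INPUTS[type_key](value)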

Some of the changes here are on the "wat" side of the scale; they are mostly
there to avoid breaking 100% of our tests, and we have #45549 to look at this
more holistically.

To support the "reduce" style task, which takes as input a sequence of all the
pushed (mapped) XCom values, and to keep the previous behaviour of not loading
all values into memory at once, we have added a new HEAD route to the Task
Execution interface that returns the number of mapped XCom values, making it
possible to implement `__len__` on the new LazyXComSequence class.
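
To illustrate, `__len__` only needs the count from that HEAD route. A minimal
sketch, assuming an httpx-style client (the exact client, URL shape, and class
internals in the TaskSDK may differ):

    import re

    import httpx  # assumption: any HTTP client with HEAD support would do

    class LazyXComSequenceSketch:
        def __init__(self, base_url, dag_id, run_id, task_id, key):
            self._url = f"{base_url}/xcoms/{dag_id}/{run_id}/{task_id}/{key}"

        def __len__(self):
            # HEAD has no body; the count travels in the Content-Range header,
            # using the custom "map_indexes" range unit added by this commit.
            resp = httpx.head(self._url)
            match = re.fullmatch(r"map_indexes (\d+)", resp.headers["Content-Range"])
            if match is None:
                raise ValueError("unexpected Content-Range header")
            return int(match.group(1))

        def __getitem__(self, map_index):
            # Individual values are still fetched one at a time, on demand, so
            # we never load every mapped XCom value into memory at once.
            return httpx.get(self._url, params={"map_index": map_index}).json()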

This change also alters when and where in the TaskSDK execution-time code we
render templates and send RTIF fields to the server. This is needed because
calling `render_templates` also expands the Mapped operator. As a result, the
`startup` call parses the dag, renders templates, and performs the runtime
checks (currently checking Inlets and Outlets with the API server), then
returns the context. This context is important as the `ti.task` _in that
context_ is unmapped if required.
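
In outline, the new ordering looks like this (a toy sketch; every name below
is an illustrative stand-in for the real TaskSDK runner code):

    class ToyTaskInstance:
        """Stand-in for the runtime TaskInstance; not the SDK's real class."""

        def __init__(self, task):
            self.task = task

        def get_template_context(self):
            return {"ti": self}

        def render_templates(self, *, context):
            # For a MappedOperator this also unmaps into the concrete task.
            unmap = getattr(self.task, "unmap", None)
            return unmap() if unmap else self.task

    def startup(task):
        ti = ToyTaskInstance(task)           # dag already parsed, task bound
        context = ti.get_template_context()
        ti.task = ti.render_templates(context=context)  # render + unmap
        # ...then send RTIF fields and run inlet/outlet checks with the API server
        return ti, context                   # ti.task in this context is unmapped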

I have deleted a tranche of tests from tests/models that were to do with
runtime behaviour and are now tested in the TaskSDK instead.
ashb committed Feb 4, 2025
1 parent e8be1bf commit 99f8c68
Showing 52 changed files with 2,251 additions and 2,191 deletions.
@@ -203,7 +203,7 @@ class TaskInstance(StrictBaseModel):
     dag_id: str
     run_id: str
     try_number: int
-    map_index: int = -1
+    map_index: int | None = None
     hostname: str | None = None
91 changes: 73 additions & 18 deletions airflow/api_fastapi/execution_api/routes/xcoms.py
@@ -21,7 +21,8 @@
 import logging
 from typing import Annotated
 
-from fastapi import Body, HTTPException, Query, status
+from fastapi import Body, Depends, HTTPException, Query, Response, status
+from sqlalchemy.sql.selectable import Select
 
 from airflow.api_fastapi.common.db.common import SessionDep
 from airflow.api_fastapi.common.router import AirflowRouter
@@ -30,6 +31,7 @@
 from airflow.api_fastapi.execution_api.datamodels.xcom import XComResponse
 from airflow.models.taskmap import TaskMap
 from airflow.models.xcom import BaseXCom
+from airflow.utils.db import get_query_count
 
 # TODO: Add dependency on JWT token
 router = AirflowRouter(
@@ -42,20 +44,15 @@
 log = logging.getLogger(__name__)
 
 
-@router.get(
-    "/{dag_id}/{run_id}/{task_id}/{key}",
-    responses={status.HTTP_404_NOT_FOUND: {"description": "XCom not found"}},
-)
-def get_xcom(
+async def xcom_query(
     dag_id: str,
     run_id: str,
     task_id: str,
     key: str,
-    token: deps.TokenDep,
     session: SessionDep,
-    map_index: Annotated[int, Query()] = -1,
-) -> XComResponse:
-    """Get an Airflow XCom from database - not other XCom Backends."""
+    token: deps.TokenDep,
+    map_index: Annotated[int | None, Query()] = None,
+) -> Select:
     if not has_xcom_access(dag_id, run_id, task_id, key, token):
         raise HTTPException(
             status_code=status.HTTP_403_FORBIDDEN,
@@ -65,29 +62,87 @@ def get_xcom(
             },
         )
 
-    # We use `BaseXCom.get_many` to fetch XComs directly from the database, bypassing the XCom Backend.
-    # This avoids deserialization via the backend (e.g., from a remote storage like S3) and instead
-    # retrieves the raw serialized value from the database. By not relying on `XCom.get_many` or `XCom.get_one`
-    # (which automatically deserializes using the backend), we avoid potential
-    # performance hits from retrieving large data files into the API server.
     query = BaseXCom.get_many(
         run_id=run_id,
         key=key,
         task_ids=task_id,
         dag_ids=dag_id,
         map_indexes=map_index,
-        limit=1,
         session=session,
     )
+    return query.with_entities(BaseXCom.value)
+
+
+@router.head(
+    "/{dag_id}/{run_id}/{task_id}/{key}",
+    responses={
+        status.HTTP_404_NOT_FOUND: {"description": "XCom not found"},
+        status.HTTP_200_OK: {
+            "description": "Metadata about the number of matching XCom values",
+            "headers": {
+                "Content-Range": {
+                    "pattern": r"^map_indexes \d+$",
+                    "description": "The number of (mapped) XCom values found for this task.",
+                },
+            },
+        },
+    },
+    description="Return the count of the number of XCom values found via the Content-Range response header",
+)
+def head_xcom(
+    response: Response,
+    token: deps.TokenDep,
+    session: SessionDep,
+    xcom_query: Annotated[Select, Depends(xcom_query)],
+    map_index: Annotated[int | None, Query()] = None,
+) -> None:
+    """Get the count of XComs from database - not other XCom Backends."""
+    if map_index is not None:
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail={"reason": "invalid_request", "message": "Cannot specify map_index in a HEAD request"},
+        )
+
+    count = get_query_count(xcom_query, session=session)
+    # Tell the caller how many items this query matches. We define a custom range unit (the HTTP spec
+    # only defines "bytes", but we can add our own)
+    response.headers["Content-Range"] = f"map_indexes {count}"
+
+
+@router.get(
+    "/{dag_id}/{run_id}/{task_id}/{key}",
+    responses={status.HTTP_404_NOT_FOUND: {"description": "XCom not found"}},
+    description="Get a single XCom Value",
+)
+def get_xcom(
+    session: SessionDep,
+    dag_id: str,
+    run_id: str,
+    task_id: str,
+    key: str,
+    xcom_query: Annotated[Select, Depends(xcom_query)],
+    map_index: Annotated[int, Query()] = -1,
+) -> XComResponse:
+    """Get an Airflow XCom from database - not other XCom Backends."""
+    # The xcom_query dependency allows map_index to be omitted; this endpoint should always return
+    # just a single item, so we override that query value here
+
+    xcom_query = xcom_query.filter_by(map_index=map_index)
+
+    # We use `BaseXCom.get_many` to fetch XComs directly from the database, bypassing the XCom Backend.
+    # This avoids deserialization via the backend (e.g., from a remote storage like S3) and instead
+    # retrieves the raw serialized value from the database. By not relying on `XCom.get_many` or `XCom.get_one`
+    # (which automatically deserializes using the backend), we avoid potential
+    # performance hits from retrieving large data files into the API server.
 
-    result = query.with_entities(BaseXCom.value).first()
+    result = xcom_query.limit(1).first()
 
     if result is None:
         raise HTTPException(
             status_code=status.HTTP_404_NOT_FOUND,
             detail={
                 "reason": "not_found",
-                "message": f"XCom with key '{key}' not found for task '{task_id}' in DAG '{dag_id}'",
+                "message": f"XCom with {key=} {map_index=} not found for task {task_id!r} in DAG run {run_id!r} of {dag_id!r}",
             },
         )
 
16 changes: 7 additions & 9 deletions airflow/decorators/base.py
@@ -55,8 +55,6 @@
 from airflow.utils.types import NOTSET
 
 if TYPE_CHECKING:
-    from sqlalchemy.orm import Session
-
     from airflow.models.expandinput import (
         ExpandInput,
         OperatorExpandArgument,
@@ -184,7 +182,9 @@ def __init__(
         kwargs_to_upstream: dict[str, Any] | None = None,
         **kwargs,
     ) -> None:
-        task_id = get_unique_task_id(task_id, kwargs.get("dag"), kwargs.get("task_group"))
+        if not getattr(self, "_BaseOperator__from_mapped", False):
+            # If we are being created from calling unmap(), then don't mangle the task id
+            task_id = get_unique_task_id(task_id, kwargs.get("dag"), kwargs.get("task_group"))
         self.python_callable = python_callable
         kwargs_to_upstream = kwargs_to_upstream or {}
         op_args = op_args or []
@@ -218,10 +218,10 @@ def __init__(
                 The function signature broke while assigning defaults to context key parameters.
 
                 The decorator is replacing the signature
-                > {python_callable.__name__}({', '.join(str(param) for param in signature.parameters.values())})
+                > {python_callable.__name__}({", ".join(str(param) for param in signature.parameters.values())})
 
                 with
-                > {python_callable.__name__}({', '.join(str(param) for param in parameters)})
+                > {python_callable.__name__}({", ".join(str(param) for param in parameters)})
 
                 which isn't valid: {err}
                 """
@@ -568,13 +568,11 @@ def __attrs_post_init__(self):
         super(DecoratedMappedOperator, DecoratedMappedOperator).__attrs_post_init__(self)
         XComArg.apply_upstream_relationship(self, self.op_kwargs_expand_input.value)
 
-    def _expand_mapped_kwargs(
-        self, context: Mapping[str, Any], session: Session, *, include_xcom: bool
-    ) -> tuple[Mapping[str, Any], set[int]]:
+    def _expand_mapped_kwargs(self, context: Mapping[str, Any]) -> tuple[Mapping[str, Any], set[int]]:
         # We only use op_kwargs_expand_input so this must always be empty.
         if self.expand_input is not EXPAND_INPUT_EMPTY:
             raise AssertionError(f"unexpected expand_input: {self.expand_input}")
-        op_kwargs, resolved_oids = super()._expand_mapped_kwargs(context, session, include_xcom=include_xcom)
+        op_kwargs, resolved_oids = super()._expand_mapped_kwargs(context)
         return {"op_kwargs": op_kwargs}, resolved_oids
 
     def _get_unmap_kwargs(self, mapped_kwargs: Mapping[str, Any], *, strict: bool) -> dict[str, Any]:
2 changes: 1 addition & 1 deletion airflow/models/abstractoperator.py
@@ -27,7 +27,6 @@
 
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException
-from airflow.models.expandinput import NotFullyPopulated
 from airflow.sdk.definitions._internal.abstractoperator import (
     AbstractOperator as TaskSDKAbstractOperator,
     NotMapped as NotMapped,  # Re-export this for compat
@@ -237,6 +236,7 @@ def expand_mapped_task(self, run_id: str, *, session: Session) -> tuple[Sequence
         )
 
         from airflow.models.baseoperator import BaseOperator as DBBaseOperator
+        from airflow.models.expandinput import NotFullyPopulated
 
         try:
             total_length: int | None = DBBaseOperator.get_mapped_ti_count(self, run_id, session=session)
31 changes: 23 additions & 8 deletions airflow/models/baseoperator.py
@@ -848,11 +848,20 @@ def _(cls, task: TaskSDKAbstractOperator, run_id: str, *, session: Session) -> i
     @get_mapped_ti_count.register(MappedOperator)
     @classmethod
     def _(cls, task: MappedOperator, run_id: str, *, session: Session) -> int:
-        from airflow.serialization.serialized_objects import _ExpandInputRef
+        from airflow.serialization.serialized_objects import BaseSerialization, _ExpandInputRef
 
         exp_input = task._get_specified_expand_input()
         if isinstance(exp_input, _ExpandInputRef):
             exp_input = exp_input.deref(task.dag)
+        # TODO: TaskSDK This is only needed to support `dag.test()` etc until we port it over to use the
+        # task sdk runner.
+        if not hasattr(exp_input, "get_total_map_length"):
+            exp_input = _ExpandInputRef(
+                type(exp_input).EXPAND_INPUT_TYPE,
+                BaseSerialization.deserialize(BaseSerialization.serialize(exp_input.value)),
+            )
+            exp_input = exp_input.deref(task.dag)
 
         current_count = exp_input.get_total_map_length(run_id, session=session)
 
         group = task.get_closest_mapped_task_group()
@@ -878,18 +887,24 @@ def _(cls, group: TaskGroup, run_id: str, *, session: Session) -> int:
         :raise NotFullyPopulated: If upstream tasks are not all complete yet.
         :return: Total number of mapped TIs this task should have.
         """
-        def iter_mapped_task_groups(group) -> Iterator[MappedTaskGroup]:
+        from airflow.serialization.serialized_objects import BaseSerialization, _ExpandInputRef
+
+        def iter_mapped_task_group_lengths(group) -> Iterator[int]:
             while group is not None:
                 if isinstance(group, MappedTaskGroup):
-                    yield group
+                    exp_input = group._expand_input
+                    # TODO: TaskSDK This is only needed to support `dag.test()` etc until we port it over to use the
+                    # task sdk runner.
+                    if not hasattr(exp_input, "get_total_map_length"):
+                        exp_input = _ExpandInputRef(
+                            type(exp_input).EXPAND_INPUT_TYPE,
+                            BaseSerialization.deserialize(BaseSerialization.serialize(exp_input.value)),
+                        )
+                        exp_input = exp_input.deref(group.dag)
+                    yield exp_input.get_total_map_length(run_id, session=session)
                 group = group.parent_group
 
-        groups = iter_mapped_task_groups(group)
-        return functools.reduce(
-            operator.mul,
-            (g._expand_input.get_total_map_length(run_id, session=session) for g in groups),
-        )
+        return functools.reduce(operator.mul, iter_mapped_task_group_lengths(group))
 
 
 def chain(*tasks: DependencyMixin | Sequence[DependencyMixin]) -> None:
Expand Down
4 changes: 3 additions & 1 deletion airflow/models/dagrun.py
@@ -63,7 +63,6 @@
 from airflow.models.backfill import Backfill
 from airflow.models.base import Base, StringID
 from airflow.models.dag_version import DagVersion
-from airflow.models.expandinput import NotFullyPopulated
 from airflow.models.taskinstance import TaskInstance as TI
 from airflow.models.tasklog import LogTemplate
 from airflow.models.taskmap import TaskMap
@@ -1347,6 +1346,7 @@ def _check_for_removed_or_restored_tasks(
         """
         from airflow.models.baseoperator import BaseOperator
+        from airflow.models.expandinput import NotFullyPopulated
 
         tis = self.get_task_instances(session=session)
 
@@ -1484,6 +1484,7 @@ def _create_tasks(
         :param task_creator: Function to create task instances
         """
         from airflow.models.baseoperator import BaseOperator
+        from airflow.models.expandinput import NotFullyPopulated
 
         map_indexes: Iterable[int]
         for task in tasks:
@@ -1555,6 +1556,7 @@ def _revise_map_indexes_if_mapped(self, task: Operator, *, session: Session) ->
         for more details.
         """
         from airflow.models.baseoperator import BaseOperator
+        from airflow.models.expandinput import NotFullyPopulated
         from airflow.settings import task_instance_mutation_hook
 
         try:
(Diffs for the remaining changed files are not shown.)
