Skip to content

Commit 35fca81

Browse files
committed
Add dynamic task mapping into TaskSDK runtime
The big change here (other than just moving code around) is to introduce a conceptual separation between Definition/Execution time and Scheduler time. This means that the expansion of tasks (creating the TaskInstance rows with different map_index values) is now done on the scheduler, and we now deserialize to different classes. For example, when we deserialize the `DictOfListsExpandInput` it gets turned into an instance of `SchedulerDictOfListsExpandInput`. This is primarily designed so that DB access is kept 100% out of the TaskSDK. Some of the changes here are on the "wat" side of the scale, and this is mostly designed to not break 100% of our tests, and we have #45549 to look at that more holistically. To support the "reduce" style task which takes as input a sequence of all the pushed (mapped) XCom values, and to keep the previous behaviour of not loading all values into memory at once, we have added a new HEAD route to the Task Execution interface that returns the number of mapped XCom values so that it is possible to implement `__len__` on the new LazyXComSequence class. This change also changes when and where in the TaskSDK execution time code we render templates and send RTIF fields to the server. This is needed because calling `render_templates` also expands the Mapped operator. As a result the `startup` call parses the dag, renders templates and performs the runtime checks (currently checking Inlets and Outlets with the API server) and returns the context. This context is important as the `ti.task` _in that context_ is unmapped if required. I have deleted a tranche of tests from tests/models that were to do with runtime behaviour and are now tested in the TaskSDK instead.
1 parent e8be1bf commit 35fca81

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

54 files changed

+2256
-2252
lines changed

airflow/api_fastapi/execution_api/datamodels/taskinstance.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,7 @@ class TaskInstance(StrictBaseModel):
203203
dag_id: str
204204
run_id: str
205205
try_number: int
206-
map_index: int = -1
206+
map_index: int | None = None
207207
hostname: str | None = None
208208

209209

airflow/api_fastapi/execution_api/routes/xcoms.py

+73-18
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@
2121
import logging
2222
from typing import Annotated
2323

24-
from fastapi import Body, HTTPException, Query, status
24+
from fastapi import Body, Depends, HTTPException, Query, Response, status
25+
from sqlalchemy.sql.selectable import Select
2526

2627
from airflow.api_fastapi.common.db.common import SessionDep
2728
from airflow.api_fastapi.common.router import AirflowRouter
@@ -30,6 +31,7 @@
3031
from airflow.api_fastapi.execution_api.datamodels.xcom import XComResponse
3132
from airflow.models.taskmap import TaskMap
3233
from airflow.models.xcom import BaseXCom
34+
from airflow.utils.db import get_query_count
3335

3436
# TODO: Add dependency on JWT token
3537
router = AirflowRouter(
@@ -42,20 +44,15 @@
4244
log = logging.getLogger(__name__)
4345

4446

45-
@router.get(
46-
"/{dag_id}/{run_id}/{task_id}/{key}",
47-
responses={status.HTTP_404_NOT_FOUND: {"description": "XCom not found"}},
48-
)
49-
def get_xcom(
47+
async def xcom_query(
5048
dag_id: str,
5149
run_id: str,
5250
task_id: str,
5351
key: str,
54-
token: deps.TokenDep,
5552
session: SessionDep,
56-
map_index: Annotated[int, Query()] = -1,
57-
) -> XComResponse:
58-
"""Get an Airflow XCom from database - not other XCom Backends."""
53+
token: deps.TokenDep,
54+
map_index: Annotated[int | None, Query()] = None,
55+
) -> Select:
5956
if not has_xcom_access(dag_id, run_id, task_id, key, token):
6057
raise HTTPException(
6158
status_code=status.HTTP_403_FORBIDDEN,
@@ -65,29 +62,87 @@ def get_xcom(
6562
},
6663
)
6764

68-
# We use `BaseXCom.get_many` to fetch XComs directly from the database, bypassing the XCom Backend.
69-
# This avoids deserialization via the backend (e.g., from a remote storage like S3) and instead
70-
# retrieves the raw serialized value from the database. By not relying on `XCom.get_many` or `XCom.get_one`
71-
# (which automatically deserializes using the backend), we avoid potential
72-
# performance hits from retrieving large data files into the API server.
7365
query = BaseXCom.get_many(
7466
run_id=run_id,
7567
key=key,
7668
task_ids=task_id,
7769
dag_ids=dag_id,
7870
map_indexes=map_index,
79-
limit=1,
8071
session=session,
8172
)
73+
return query.with_entities(BaseXCom.value)
74+
75+
76+
@router.head(
77+
"/{dag_id}/{run_id}/{task_id}/{key}",
78+
responses={
79+
status.HTTP_404_NOT_FOUND: {"description": "XCom not found"},
80+
status.HTTP_200_OK: {
81+
"description": "Metadata about the number of matching XCom values",
82+
"headers": {
83+
"Content-Range": {
84+
"pattern": r"^map_indexes \d+$",
85+
"description": "The number of (mapped) XCom values found for this task.",
86+
},
87+
},
88+
},
89+
},
90+
description="Return the count of the number of XCom values found via the Content-Range response header",
91+
)
92+
def head_xcom(
93+
response: Response,
94+
token: deps.TokenDep,
95+
session: SessionDep,
96+
xcom_query: Annotated[Select, Depends(xcom_query)],
97+
map_index: Annotated[int | None, Query()] = None,
98+
) -> None:
99+
"""Get the count of XComs from database - not other XCom Backends."""
100+
if map_index is not None:
101+
raise HTTPException(
102+
status_code=status.HTTP_400_BAD_REQUEST,
103+
detail={"reason": "invalid_request", "message": "Cannot specify map_index in a HEAD request"},
104+
)
105+
106+
count = get_query_count(xcom_query, session=session)
107+
# Tell the caller how many items in this query. We define a custom range unit (HTTP spec only defines
108+
# "bytes" but we can add our own)
109+
response.headers["Content-Range"] = f"map_indexes {count}"
110+
111+
112+
@router.get(
113+
"/{dag_id}/{run_id}/{task_id}/{key}",
114+
responses={status.HTTP_404_NOT_FOUND: {"description": "XCom not found"}},
115+
description="Get a single XCom Value",
116+
)
117+
def get_xcom(
118+
session: SessionDep,
119+
dag_id: str,
120+
run_id: str,
121+
task_id: str,
122+
key: str,
123+
xcom_query: Annotated[Select, Depends(xcom_query)],
124+
map_index: Annotated[int, Query()] = -1,
125+
) -> XComResponse:
126+
"""Get an Airflow XCom from database - not other XCom Backends."""
127+
# The xcom_query allows no map_index to be passed. This endpoint should always return just a single item,
128+
# so we override that query value
129+
130+
xcom_query = xcom_query.filter_by(map_index=map_index)
131+
132+
# We use `BaseXCom.get_many` to fetch XComs directly from the database, bypassing the XCom Backend.
133+
# This avoids deserialization via the backend (e.g., from a remote storage like S3) and instead
134+
# retrieves the raw serialized value from the database. By not relying on `XCom.get_many` or `XCom.get_one`
135+
# (which automatically deserializes using the backend), we avoid potential
136+
# performance hits from retrieving large data files into the API server.
82137

83-
result = query.with_entities(BaseXCom.value).first()
138+
result = xcom_query.limit(1).first()
84139

85140
if result is None:
86141
raise HTTPException(
87142
status_code=status.HTTP_404_NOT_FOUND,
88143
detail={
89144
"reason": "not_found",
90-
"message": f"XCom with key '{key}' not found for task '{task_id}' in DAG '{dag_id}'",
145+
"message": f"XCom with {key=} {map_index=} not found for task {task_id!r} in DAG run {run_id!r} of {dag_id!r}",
91146
},
92147
)
93148

airflow/decorators/base.py

+7-9
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,6 @@
5555
from airflow.utils.types import NOTSET
5656

5757
if TYPE_CHECKING:
58-
from sqlalchemy.orm import Session
59-
6058
from airflow.models.expandinput import (
6159
ExpandInput,
6260
OperatorExpandArgument,
@@ -184,7 +182,9 @@ def __init__(
184182
kwargs_to_upstream: dict[str, Any] | None = None,
185183
**kwargs,
186184
) -> None:
187-
task_id = get_unique_task_id(task_id, kwargs.get("dag"), kwargs.get("task_group"))
185+
if not getattr(self, "_BaseOperator__from_mapped", False):
186+
# If we are being created from calling unmap(), then don't mangle the task id
187+
task_id = get_unique_task_id(task_id, kwargs.get("dag"), kwargs.get("task_group"))
188188
self.python_callable = python_callable
189189
kwargs_to_upstream = kwargs_to_upstream or {}
190190
op_args = op_args or []
@@ -218,10 +218,10 @@ def __init__(
218218
The function signature broke while assigning defaults to context key parameters.
219219
220220
The decorator is replacing the signature
221-
> {python_callable.__name__}({', '.join(str(param) for param in signature.parameters.values())})
221+
> {python_callable.__name__}({", ".join(str(param) for param in signature.parameters.values())})
222222
223223
with
224-
> {python_callable.__name__}({', '.join(str(param) for param in parameters)})
224+
> {python_callable.__name__}({", ".join(str(param) for param in parameters)})
225225
226226
which isn't valid: {err}
227227
"""
@@ -568,13 +568,11 @@ def __attrs_post_init__(self):
568568
super(DecoratedMappedOperator, DecoratedMappedOperator).__attrs_post_init__(self)
569569
XComArg.apply_upstream_relationship(self, self.op_kwargs_expand_input.value)
570570

571-
def _expand_mapped_kwargs(
572-
self, context: Mapping[str, Any], session: Session, *, include_xcom: bool
573-
) -> tuple[Mapping[str, Any], set[int]]:
571+
def _expand_mapped_kwargs(self, context: Mapping[str, Any]) -> tuple[Mapping[str, Any], set[int]]:
574572
# We only use op_kwargs_expand_input so this must always be empty.
575573
if self.expand_input is not EXPAND_INPUT_EMPTY:
576574
raise AssertionError(f"unexpected expand_input: {self.expand_input}")
577-
op_kwargs, resolved_oids = super()._expand_mapped_kwargs(context, session, include_xcom=include_xcom)
575+
op_kwargs, resolved_oids = super()._expand_mapped_kwargs(context)
578576
return {"op_kwargs": op_kwargs}, resolved_oids
579577

580578
def _get_unmap_kwargs(self, mapped_kwargs: Mapping[str, Any], *, strict: bool) -> dict[str, Any]:

airflow/models/abstractoperator.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727

2828
from airflow.configuration import conf
2929
from airflow.exceptions import AirflowException
30-
from airflow.models.expandinput import NotFullyPopulated
3130
from airflow.sdk.definitions._internal.abstractoperator import (
3231
AbstractOperator as TaskSDKAbstractOperator,
3332
NotMapped as NotMapped, # Re-export this for compat
@@ -237,6 +236,7 @@ def expand_mapped_task(self, run_id: str, *, session: Session) -> tuple[Sequence
237236
)
238237

239238
from airflow.models.baseoperator import BaseOperator as DBBaseOperator
239+
from airflow.models.expandinput import NotFullyPopulated
240240

241241
try:
242242
total_length: int | None = DBBaseOperator.get_mapped_ti_count(self, run_id, session=session)

airflow/models/baseoperator.py

+23-8
Original file line numberDiff line numberDiff line change
@@ -848,11 +848,20 @@ def _(cls, task: TaskSDKAbstractOperator, run_id: str, *, session: Session) -> i
848848
@get_mapped_ti_count.register(MappedOperator)
849849
@classmethod
850850
def _(cls, task: MappedOperator, run_id: str, *, session: Session) -> int:
851-
from airflow.serialization.serialized_objects import _ExpandInputRef
851+
from airflow.serialization.serialized_objects import BaseSerialization, _ExpandInputRef
852852

853853
exp_input = task._get_specified_expand_input()
854854
if isinstance(exp_input, _ExpandInputRef):
855855
exp_input = exp_input.deref(task.dag)
856+
# TODO: TaskSDK This is only needed to support `dag.test()` etc until we port it over to use the
857+
# task sdk runner.
858+
if not hasattr(exp_input, "get_total_map_length"):
859+
exp_input = _ExpandInputRef(
860+
type(exp_input).EXPAND_INPUT_TYPE,
861+
BaseSerialization.deserialize(BaseSerialization.serialize(exp_input.value)),
862+
)
863+
exp_input = exp_input.deref(task.dag)
864+
856865
current_count = exp_input.get_total_map_length(run_id, session=session)
857866

858867
group = task.get_closest_mapped_task_group()
@@ -878,18 +887,24 @@ def _(cls, group: TaskGroup, run_id: str, *, session: Session) -> int:
878887
:raise NotFullyPopulated: If upstream tasks are not all complete yet.
879888
:return: Total number of mapped TIs this task should have.
880889
"""
890+
from airflow.serialization.serialized_objects import BaseSerialization, _ExpandInputRef
881891

882-
def iter_mapped_task_groups(group) -> Iterator[MappedTaskGroup]:
892+
def iter_mapped_task_group_lengths(group) -> Iterator[int]:
883893
while group is not None:
884894
if isinstance(group, MappedTaskGroup):
885-
yield group
895+
exp_input = group._expand_input
896+
# TODO: TaskSDK This is only needed to support `dag.test()` etc until we port it over to use the
897+
# task sdk runner.
898+
if not hasattr(exp_input, "get_total_map_length"):
899+
exp_input = _ExpandInputRef(
900+
type(exp_input).EXPAND_INPUT_TYPE,
901+
BaseSerialization.deserialize(BaseSerialization.serialize(exp_input.value)),
902+
)
903+
exp_input = exp_input.deref(group.dag)
904+
yield exp_input.get_total_map_length(run_id, session=session)
886905
group = group.parent_group
887906

888-
groups = iter_mapped_task_groups(group)
889-
return functools.reduce(
890-
operator.mul,
891-
(g._expand_input.get_total_map_length(run_id, session=session) for g in groups),
892-
)
907+
return functools.reduce(operator.mul, iter_mapped_task_group_lengths(group))
893908

894909

895910
def chain(*tasks: DependencyMixin | Sequence[DependencyMixin]) -> None:

airflow/models/dagrun.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,6 @@
6363
from airflow.models.backfill import Backfill
6464
from airflow.models.base import Base, StringID
6565
from airflow.models.dag_version import DagVersion
66-
from airflow.models.expandinput import NotFullyPopulated
6766
from airflow.models.taskinstance import TaskInstance as TI
6867
from airflow.models.tasklog import LogTemplate
6968
from airflow.models.taskmap import TaskMap
@@ -1347,6 +1346,7 @@ def _check_for_removed_or_restored_tasks(
13471346
13481347
"""
13491348
from airflow.models.baseoperator import BaseOperator
1349+
from airflow.models.expandinput import NotFullyPopulated
13501350

13511351
tis = self.get_task_instances(session=session)
13521352

@@ -1484,6 +1484,7 @@ def _create_tasks(
14841484
:param task_creator: Function to create task instances
14851485
"""
14861486
from airflow.models.baseoperator import BaseOperator
1487+
from airflow.models.expandinput import NotFullyPopulated
14871488

14881489
map_indexes: Iterable[int]
14891490
for task in tasks:
@@ -1555,6 +1556,7 @@ def _revise_map_indexes_if_mapped(self, task: Operator, *, session: Session) ->
15551556
for more details.
15561557
"""
15571558
from airflow.models.baseoperator import BaseOperator
1559+
from airflow.models.expandinput import NotFullyPopulated
15581560
from airflow.settings import task_instance_mutation_hook
15591561

15601562
try:

0 commit comments

Comments
 (0)