Skip to content

Commit 313a98a

Browse files
committed
fixup! Add dynamic task mapping into TaskSDK runtime
1 parent 5ced426 commit 313a98a

File tree

6 files changed

+20
-43
lines changed

6 files changed

+20
-43
lines changed

airflow/models/skipmixin.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def skip(
9797
dag_id: str,
9898
run_id: str,
9999
tasks: Iterable[DAGNode],
100-
map_index: int = -1,
100+
map_index: int | None = -1,
101101
session: Session = NEW_SESSION,
102102
):
103103
"""
@@ -126,6 +126,9 @@ def skip(
126126
if task_id is not None:
127127
from airflow.models.xcom import XCom
128128

129+
if map_index is None:
130+
map_index = -1
131+
129132
XCom.set(
130133
key=XCOM_SKIPMIXIN_KEY,
131134
value={XCOM_SKIPMIXIN_SKIPPED: task_ids_list},

airflow/serialization/serialized_objects.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1792,7 +1792,7 @@ def serialize_task_group(cls, task_group: TaskGroup) -> dict[str, Any] | None:
17921792
if isinstance(task_group, MappedTaskGroup):
17931793
expand_input = task_group._expand_input
17941794
encoded["expand_input"] = {
1795-
"type": type(expand_input).EXPAND_INPUT_TYPE,
1795+
"type": expand_input.EXPAND_INPUT_TYPE,
17961796
"value": cls.serialize(expand_input.value),
17971797
}
17981798
encoded["is_mapped"] = True

providers/src/airflow/providers/cncf/kubernetes/operators/pod.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -489,7 +489,7 @@ def _get_ti_pod_labels(context: Context | None = None, include_try_number: bool
489489
}
490490

491491
map_index = ti.map_index
492-
if map_index >= 0:
492+
if map_index is not None and map_index >= 0:
493493
labels["map_index"] = str(map_index)
494494

495495
if include_try_number:

providers/src/airflow/providers/microsoft/azure/operators/msgraph.py

-1
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,6 @@ def pull_xcom(self, context: Context) -> list:
241241
key=self.key,
242242
task_ids=self.task_id,
243243
dag_id=self.dag_id,
244-
map_indexes=map_index,
245244
)
246245
or []
247246
)

providers/standard/tests/provider_tests/standard/decorators/test_python.py

+6-33
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import sys
1919
import typing
2020
from collections import namedtuple
21-
from datetime import date, timedelta
21+
from datetime import date
2222
from typing import Union
2323

2424
import pytest
@@ -28,8 +28,6 @@
2828
from airflow.exceptions import AirflowException, XComNotFound
2929
from airflow.models.taskinstance import TaskInstance
3030
from airflow.models.taskmap import TaskMap
31-
from airflow.sdk import DAG, BaseOperator, TaskGroup, XComArg
32-
from airflow.sdk.definitions.mappedoperator import MappedOperator
3331
from airflow.utils import timezone
3432
from airflow.utils.state import State
3533
from airflow.utils.task_instance_session import set_current_task_instance_session
@@ -41,12 +39,17 @@
4139
from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS
4240

4341
if AIRFLOW_V_3_0_PLUS:
42+
from airflow.sdk import DAG, BaseOperator, TaskGroup, XComArg
4443
from airflow.sdk.definitions._internal.expandinput import DictOfListsExpandInput
4544
from airflow.sdk.definitions.mappedoperator import MappedOperator
4645
from airflow.utils.types import DagRunTriggeredByType
4746
else:
47+
from airflow.models.baseoperator import BaseOperator
48+
from airflow.models.dag import DAG # type: ignore[assignment]
4849
from airflow.models.expandinput import DictOfListsExpandInput
4950
from airflow.models.mappedoperator import MappedOperator
51+
from airflow.models.xcom_arg import XComArg
52+
from airflow.utils.task_group import TaskGroup
5053

5154
pytestmark = pytest.mark.db_test
5255

@@ -794,36 +797,6 @@ def task2(arg1, arg2): ...
794797
assert set(unmapped.op_kwargs) == {"arg1", "arg2"}
795798

796799

797-
def test_mapped_decorator_converts_partial_kwargs(dag_maker, session):
798-
with dag_maker(session=session):
799-
800-
@task_decorator
801-
def task1(arg):
802-
return ["x" * arg]
803-
804-
@task_decorator(retry_delay=30)
805-
def task2(arg1, arg2): ...
806-
807-
task2.partial(arg1=1).expand(arg2=task1.expand(arg=[1, 2]))
808-
809-
run = dag_maker.create_dagrun()
810-
811-
# Expand and run task1.
812-
dec = run.task_instance_scheduling_decisions(session=session)
813-
assert [ti.task_id for ti in dec.schedulable_tis] == ["task1", "task1"]
814-
for ti in dec.schedulable_tis:
815-
ti.run(session=session)
816-
assert not isinstance(ti.task, MappedOperator)
817-
assert ti.task.retry_delay == timedelta(seconds=300) # Operator default.
818-
819-
# Expand task2.
820-
dec = run.task_instance_scheduling_decisions(session=session)
821-
assert [ti.task_id for ti in dec.schedulable_tis] == ["task2", "task2"]
822-
for ti in dec.schedulable_tis:
823-
unmapped = ti.task.unmap((ti.get_template_context(session),))
824-
assert unmapped.retry_delay == timedelta(seconds=30)
825-
826-
827800
def test_mapped_render_template_fields(dag_maker, session):
828801
@task_decorator
829802
def fn(arg1, arg2): ...

tests_common/test_utils/mock_context.py

+8-6
Original file line numberDiff line numberDiff line change
@@ -56,14 +56,16 @@ def xcom_pull(
5656
default: Any = None,
5757
run_id: str | None = None,
5858
) -> Any:
59-
if map_indexes:
60-
return values.get(
61-
f"{task_ids or self.task_id}_{dag_id or self.dag_id}_{key}_{map_indexes}", default
62-
)
63-
return values.get(f"{task_ids or self.task_id}_{dag_id or self.dag_id}_{key}", default)
59+
key = f"{self.task_id}_{self.dag_id}_{key}"
60+
if map_indexes is not None and (not isinstance(map_indexes, int) or map_indexes >= 0):
61+
key += f"_{map_indexes}"
62+
return values.get(key, default)
6463

6564
def xcom_push(self, key: str, value: Any, session: Session = NEW_SESSION, **kwargs) -> None:
66-
values[f"{self.task_id}_{self.dag_id}_{key}_{self.map_index}"] = value
65+
key = f"{self.task_id}_{self.dag_id}_{key}"
66+
if self.map_index is not None and self.map_index >= 0:
67+
key += f"_{self.map_index}"
68+
values[key] = value
6769

6870
values["ti"] = MockedTaskInstance(task=task)
6971

0 commit comments

Comments
 (0)