Add dynamic task mapping into TaskSDK runtime
The big change here (other than just moving code around) is to introduce a
conceptual separation between Definition/Execution time and Scheduler time.

This means that the expansion of tasks (creating the TaskInstance rows with
different map_index values) is now done on the scheduler, and we now
deserialize to different classes. For example, when we deserialize the
`DictOfListsExpandInput` it gets turned into an instance of
SchedulerDictOfListsExpandInput. This is primarily designed so that DB access
is kept 100% out of the TaskSDK.
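To make the split concrete, here is a minimal sketch of the two-sided design, assuming plain dataclasses; apart from the two class names taken from above, every name here is illustrative rather than the commit's actual code:

from __future__ import annotations

from dataclasses import dataclass
from typing import Any


@dataclass
class DictOfListsExpandInput:
    """Definition/Execution-time side: plain data, never touches the DB."""

    value: dict[str, Any]


@dataclass
class SchedulerDictOfListsExpandInput:
    """Scheduler-time counterpart, deserialized from the same payload."""

    value: dict[str, Any]

    def get_total_map_length(self, run_id: str, *, session: Any = None) -> int:
        # expand() over a dict of lists has cross-product semantics:
        # expand(x=[1, 2], y=["a", "b", "c"]) -> 6 TaskInstance rows.
        # Real XComArg inputs would need `session` for a DB lookup;
        # literal lists do not.
        total = 1
        for v in self.value.values():
            total *= len(v)
        return total


sched = SchedulerDictOfListsExpandInput({"x": [1, 2], "y": ["a", "b", "c"]})
assert sched.get_total_map_length("run-1") == 6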

Some of the changes here are on the "wat" side of the scale; they are mostly
designed to avoid breaking 100% of our tests, and #45549 tracks looking at
this more holistically.
ashb committed Jan 31, 2025
1 parent 53e1723 commit ee19629
Showing 45 changed files with 1,802 additions and 1,519 deletions.
@@ -203,7 +203,7 @@ class TaskInstance(StrictBaseModel):
     dag_id: str
     run_id: str
     try_number: int
-    map_index: int = -1
+    map_index: int | None = None
     hostname: str | None = None
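The schema change above is the user-visible half of this: the Execution API model now says "not a mapped task" with None instead of the -1 sentinel stored in DB rows. A small sketch of the difference, assuming an ordinary Pydantic model stands in for StrictBaseModel:

from pydantic import BaseModel


class TaskInstance(BaseModel):
    dag_id: str
    run_id: str
    try_number: int
    map_index: int | None = None  # was `int = -1`; None now means "not mapped"
    hostname: str | None = None


ti = TaskInstance(dag_id="d", run_id="r", try_number=1)
assert ti.map_index is None  # callers test for None rather than -1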
16 changes: 7 additions & 9 deletions airflow/decorators/base.py
@@ -55,8 +55,6 @@
 from airflow.utils.types import NOTSET

 if TYPE_CHECKING:
-    from sqlalchemy.orm import Session
-
     from airflow.models.expandinput import (
         ExpandInput,
         OperatorExpandArgument,
@@ -184,7 +182,9 @@ def __init__(
         kwargs_to_upstream: dict[str, Any] | None = None,
         **kwargs,
     ) -> None:
-        task_id = get_unique_task_id(task_id, kwargs.get("dag"), kwargs.get("task_group"))
+        if not getattr(self, "_BaseOperator__from_mapped", False):
+            # If we are being created from calling unmap(), then don't mangle the task id
+            task_id = get_unique_task_id(task_id, kwargs.get("dag"), kwargs.get("task_group"))
         self.python_callable = python_callable
         kwargs_to_upstream = kwargs_to_upstream or {}
         op_args = op_args or []
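The new guard matters because get_unique_task_id de-duplicates task ids within a DAG, so calling it again while unmap() reconstructs the operator would rename the task. A toy illustration — the signature and the __N suffix scheme here are simplified assumptions, not Airflow's exact implementation:

def get_unique_task_id(task_id: str, existing_ids: set[str]) -> str:
    # Simplified stand-in: append __N until the id is free.
    if task_id not in existing_ids:
        return task_id
    n = 1
    while f"{task_id}__{n}" in existing_ids:
        n += 1
    return f"{task_id}__{n}"


ids = {"extract"}
# "extract" is already registered, so a second call — which is what happens
# when unmap() rebuilds the operator — would mangle the id. The
# _BaseOperator__from_mapped check skips exactly this re-mangling.
assert get_unique_task_id("extract", ids) == "extract__1"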
@@ -218,10 +218,10 @@ def __init__(
                 The function signature broke while assigning defaults to context key parameters.

                 The decorator is replacing the signature
-                > {python_callable.__name__}({', '.join(str(param) for param in signature.parameters.values())})
+                > {python_callable.__name__}({", ".join(str(param) for param in signature.parameters.values())})

                 with
-                > {python_callable.__name__}({', '.join(str(param) for param in parameters)})
+                > {python_callable.__name__}({", ".join(str(param) for param in parameters)})

                 which isn't valid: {err}
                 """
@@ -568,13 +568,11 @@ def __attrs_post_init__(self):
         super(DecoratedMappedOperator, DecoratedMappedOperator).__attrs_post_init__(self)
         XComArg.apply_upstream_relationship(self, self.op_kwargs_expand_input.value)

-    def _expand_mapped_kwargs(
-        self, context: Mapping[str, Any], session: Session, *, include_xcom: bool
-    ) -> tuple[Mapping[str, Any], set[int]]:
+    def _expand_mapped_kwargs(self, context: Mapping[str, Any]) -> tuple[Mapping[str, Any], set[int]]:
         # We only use op_kwargs_expand_input so this must always be empty.
         if self.expand_input is not EXPAND_INPUT_EMPTY:
             raise AssertionError(f"unexpected expand_input: {self.expand_input}")
-        op_kwargs, resolved_oids = super()._expand_mapped_kwargs(context, session, include_xcom=include_xcom)
+        op_kwargs, resolved_oids = super()._expand_mapped_kwargs(context)
         return {"op_kwargs": op_kwargs}, resolved_oids

     def _get_unmap_kwargs(self, mapped_kwargs: Mapping[str, Any], *, strict: bool) -> dict[str, Any]:
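Dropping session and include_xcom from the signature is the point of the change: under the TaskSDK runtime, mapped kwargs resolve purely from the runtime context. A sketch of the new call shape, where the wrapping function is hypothetical:

from typing import Any, Mapping


def render_decorated_mapped_task(op: Any, context: Mapping[str, Any]) -> dict[str, Any]:
    # Before: op._expand_mapped_kwargs(context, session, include_xcom=...).
    # Now everything needed, including resolved XCom values, arrives via
    # `context`, so no SQLAlchemy session crosses into the TaskSDK.
    mapped_kwargs, _resolved_oids = op._expand_mapped_kwargs(context)
    return op._get_unmap_kwargs(mapped_kwargs, strict=False)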
2 changes: 1 addition & 1 deletion airflow/models/abstractoperator.py
@@ -27,7 +27,6 @@

 from airflow.configuration import conf
 from airflow.exceptions import AirflowException
-from airflow.models.expandinput import NotFullyPopulated
 from airflow.sdk.definitions._internal.abstractoperator import (
     AbstractOperator as TaskSDKAbstractOperator,
     NotMapped as NotMapped,  # Re-export this for compat
@@ -237,6 +236,7 @@ def expand_mapped_task(self, run_id: str, *, session: Session) -> tuple[Sequence
         )

         from airflow.models.baseoperator import BaseOperator as DBBaseOperator
+        from airflow.models.expandinput import NotFullyPopulated

         try:
             total_length: int | None = DBBaseOperator.get_mapped_ti_count(self, run_id, session=session)
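Moving the NotFullyPopulated import from module level into the function body is a pattern this commit repeats in several files: scheduler-only, DB-adjacent modules stay out of the TaskSDK's import graph. A minimal sketch of the idea — the function here is illustrative, only the import and the exception are real:

def count_mapped_tis_or_none(task, run_id, session):
    # Deferred import: merely importing this module no longer drags in the
    # scheduler-side expandinput module.
    from airflow.models.expandinput import NotFullyPopulated

    try:
        return task.get_mapped_ti_count(run_id, session=session)
    except NotFullyPopulated:
        return None  # upstream lengths are unknown until those tasks finish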
31 changes: 23 additions & 8 deletions airflow/models/baseoperator.py
@@ -847,11 +847,20 @@ def _(cls, task: TaskSDKAbstractOperator, run_id: str, *, session: Session) -> int:
     @get_mapped_ti_count.register(MappedOperator)
     @classmethod
     def _(cls, task: MappedOperator, run_id: str, *, session: Session) -> int:
-        from airflow.serialization.serialized_objects import _ExpandInputRef
+        from airflow.serialization.serialized_objects import BaseSerialization, _ExpandInputRef

         exp_input = task._get_specified_expand_input()
         if isinstance(exp_input, _ExpandInputRef):
             exp_input = exp_input.deref(task.dag)
+        # TODO: TaskSDK This is only needed to support `dag.test()` etc until we port it over to use the
+        # task sdk runner.
+        if not hasattr(exp_input, "get_total_map_length"):
+            exp_input = _ExpandInputRef(
+                type(exp_input).EXPAND_INPUT_TYPE,
+                BaseSerialization.deserialize(BaseSerialization.serialize(exp_input.value)),
+            )
+            exp_input = exp_input.deref(task.dag)

         current_count = exp_input.get_total_map_length(run_id, session=session)

         group = task.get_closest_mapped_task_group()
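The hasattr shim reads as: if we were handed a raw TaskSDK expand input (which dag.test() still produces), round-trip it through serialization so that _ExpandInputRef.deref returns the scheduler-side class that actually implements get_total_map_length. Factored out as a sketch, using only the calls that appear in the hunk above:

from airflow.serialization.serialized_objects import BaseSerialization, _ExpandInputRef


def as_scheduler_expand_input(exp_input, dag):
    # Hypothetical helper: the commit inlines this logic instead.
    if hasattr(exp_input, "get_total_map_length"):
        return exp_input  # already the scheduler-side variant
    ref = _ExpandInputRef(
        type(exp_input).EXPAND_INPUT_TYPE,
        BaseSerialization.deserialize(BaseSerialization.serialize(exp_input.value)),
    )
    # deref() rebuilds e.g. DictOfListsExpandInput as its Scheduler* counterpart.
    return ref.deref(dag)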
@@ -877,18 +886,24 @@ def _(cls, group: TaskGroup, run_id: str, *, session: Session) -> int:
         :raise NotFullyPopulated: If upstream tasks are not all complete yet.
         :return: Total number of mapped TIs this task should have.
         """
+        from airflow.serialization.serialized_objects import BaseSerialization, _ExpandInputRef

-        def iter_mapped_task_groups(group) -> Iterator[MappedTaskGroup]:
+        def iter_mapped_task_group_lengths(group) -> Iterator[int]:
             while group is not None:
                 if isinstance(group, MappedTaskGroup):
-                    yield group
+                    exp_input = group._expand_input
+                    # TODO: TaskSDK This is only needed to support `dag.test()` etc until we port it over to use the
+                    # task sdk runner.
+                    if not hasattr(exp_input, "get_total_map_length"):
+                        exp_input = _ExpandInputRef(
+                            type(exp_input).EXPAND_INPUT_TYPE,
+                            BaseSerialization.deserialize(BaseSerialization.serialize(exp_input.value)),
+                        )
+                        exp_input = exp_input.deref(group.dag)
+                    yield exp_input.get_total_map_length(run_id, session=session)
                 group = group.parent_group

-        groups = iter_mapped_task_groups(group)
-        return functools.reduce(
-            operator.mul,
-            (g._expand_input.get_total_map_length(run_id, session=session) for g in groups),
-        )
+        return functools.reduce(operator.mul, iter_mapped_task_group_lengths(group))


 def chain(*tasks: DependencyMixin | Sequence[DependencyMixin]) -> None:
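The rewritten helper yields one expansion length per enclosing mapped task group, and the reduce multiplies them, because nested mapped groups form a cross product. A worked example:

import functools
import operator

# A task inside a mapped group of length 3, itself nested in a mapped group
# of length 2, gets 3 * 2 = 6 TaskInstances.
lengths = [3, 2]  # innermost group first, following parent_group upwards
assert functools.reduce(operator.mul, lengths) == 6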
4 changes: 3 additions & 1 deletion airflow/models/dagrun.py
@@ -63,7 +63,6 @@
 from airflow.models.backfill import Backfill
 from airflow.models.base import Base, StringID
 from airflow.models.dag_version import DagVersion
-from airflow.models.expandinput import NotFullyPopulated
 from airflow.models.taskinstance import TaskInstance as TI
 from airflow.models.tasklog import LogTemplate
 from airflow.models.taskmap import TaskMap
@@ -1330,6 +1329,7 @@ def _check_for_removed_or_restored_tasks(
         """
         from airflow.models.baseoperator import BaseOperator
+        from airflow.models.expandinput import NotFullyPopulated

         tis = self.get_task_instances(session=session)
@@ -1467,6 +1467,7 @@ def _create_tasks(
         :param task_creator: Function to create task instances
         """
         from airflow.models.baseoperator import BaseOperator
+        from airflow.models.expandinput import NotFullyPopulated

         map_indexes: Iterable[int]
         for task in tasks:
@@ -1538,6 +1539,7 @@ def _revise_map_indexes_if_mapped(self, task: Operator, *, session: Session) ->
         for more details.
         """
         from airflow.models.baseoperator import BaseOperator
+        from airflow.models.expandinput import NotFullyPopulated
         from airflow.settings import task_instance_mutation_hook

         try: