diff --git a/airflow/api_fastapi/execution_api/datamodels/taskinstance.py b/airflow/api_fastapi/execution_api/datamodels/taskinstance.py index 675a76d431392..8776d1271df81 100644 --- a/airflow/api_fastapi/execution_api/datamodels/taskinstance.py +++ b/airflow/api_fastapi/execution_api/datamodels/taskinstance.py @@ -203,7 +203,7 @@ class TaskInstance(StrictBaseModel): dag_id: str run_id: str try_number: int - map_index: int = -1 + map_index: int | None = None hostname: str | None = None diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py index 9db6d058adeb5..b79819bedcc5e 100644 --- a/airflow/decorators/base.py +++ b/airflow/decorators/base.py @@ -55,8 +55,6 @@ from airflow.utils.types import NOTSET if TYPE_CHECKING: - from sqlalchemy.orm import Session - from airflow.models.expandinput import ( ExpandInput, OperatorExpandArgument, @@ -184,7 +182,9 @@ def __init__( kwargs_to_upstream: dict[str, Any] | None = None, **kwargs, ) -> None: - task_id = get_unique_task_id(task_id, kwargs.get("dag"), kwargs.get("task_group")) + if not getattr(self, "_BaseOperator__from_mapped", False): + # If we are being created from calling unmap(), then don't mangle the task id + task_id = get_unique_task_id(task_id, kwargs.get("dag"), kwargs.get("task_group")) self.python_callable = python_callable kwargs_to_upstream = kwargs_to_upstream or {} op_args = op_args or [] @@ -218,10 +218,10 @@ def __init__( The function signature broke while assigning defaults to context key parameters. The decorator is replacing the signature - > {python_callable.__name__}({', '.join(str(param) for param in signature.parameters.values())}) + > {python_callable.__name__}({", ".join(str(param) for param in signature.parameters.values())}) with - > {python_callable.__name__}({', '.join(str(param) for param in parameters)}) + > {python_callable.__name__}({", ".join(str(param) for param in parameters)}) which isn't valid: {err} """ @@ -568,13 +568,11 @@ def __attrs_post_init__(self): super(DecoratedMappedOperator, DecoratedMappedOperator).__attrs_post_init__(self) XComArg.apply_upstream_relationship(self, self.op_kwargs_expand_input.value) - def _expand_mapped_kwargs( - self, context: Mapping[str, Any], session: Session, *, include_xcom: bool - ) -> tuple[Mapping[str, Any], set[int]]: + def _expand_mapped_kwargs(self, context: Mapping[str, Any]) -> tuple[Mapping[str, Any], set[int]]: # We only use op_kwargs_expand_input so this must always be empty. 
         if self.expand_input is not EXPAND_INPUT_EMPTY:
             raise AssertionError(f"unexpected expand_input: {self.expand_input}")
-        op_kwargs, resolved_oids = super()._expand_mapped_kwargs(context, session, include_xcom=include_xcom)
+        op_kwargs, resolved_oids = super()._expand_mapped_kwargs(context)
         return {"op_kwargs": op_kwargs}, resolved_oids
 
     def _get_unmap_kwargs(self, mapped_kwargs: Mapping[str, Any], *, strict: bool) -> dict[str, Any]:
diff --git a/airflow/models/abstractoperator.py b/airflow/models/abstractoperator.py
index b8fb54f6966fd..98fd977c59128 100644
--- a/airflow/models/abstractoperator.py
+++ b/airflow/models/abstractoperator.py
@@ -27,7 +27,6 @@
 
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException
-from airflow.models.expandinput import NotFullyPopulated
 from airflow.sdk.definitions._internal.abstractoperator import (
     AbstractOperator as TaskSDKAbstractOperator,
     NotMapped as NotMapped,  # Re-export this for compat
@@ -237,6 +236,7 @@ def expand_mapped_task(self, run_id: str, *, session: Session) -> tuple[Sequence
         )
 
         from airflow.models.baseoperator import BaseOperator as DBBaseOperator
+        from airflow.models.expandinput import NotFullyPopulated
 
         try:
             total_length: int | None = DBBaseOperator.get_mapped_ti_count(self, run_id, session=session)
diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py
index e84994c6d04e7..15b5a3105b47f 100644
--- a/airflow/models/baseoperator.py
+++ b/airflow/models/baseoperator.py
@@ -847,11 +847,20 @@ def _(cls, task: TaskSDKAbstractOperator, run_id: str, *, session: Session) -> i
     @get_mapped_ti_count.register(MappedOperator)
     @classmethod
     def _(cls, task: MappedOperator, run_id: str, *, session: Session) -> int:
-        from airflow.serialization.serialized_objects import _ExpandInputRef
+        from airflow.serialization.serialized_objects import BaseSerialization, _ExpandInputRef
 
         exp_input = task._get_specified_expand_input()
         if isinstance(exp_input, _ExpandInputRef):
             exp_input = exp_input.deref(task.dag)
+        # TODO: TaskSDK This is only needed to support `dag.test()` etc until we port it over to use the
+        # task sdk runner.
+        if not hasattr(exp_input, "get_total_map_length"):
+            exp_input = _ExpandInputRef(
+                type(exp_input).EXPAND_INPUT_TYPE,
+                BaseSerialization.deserialize(BaseSerialization.serialize(exp_input.value)),
+            )
+            exp_input = exp_input.deref(task.dag)
+
         current_count = exp_input.get_total_map_length(run_id, session=session)
 
         group = task.get_closest_mapped_task_group()
@@ -877,18 +886,24 @@ def _(cls, group: TaskGroup, run_id: str, *, session: Session) -> int:
         :raise NotFullyPopulated: If upstream tasks are not all complete yet.
         :return: Total number of mapped TIs this task should have.
         """
+        from airflow.serialization.serialized_objects import BaseSerialization, _ExpandInputRef
 
-        def iter_mapped_task_groups(group) -> Iterator[MappedTaskGroup]:
+        def iter_mapped_task_group_lengths(group) -> Iterator[int]:
             while group is not None:
                 if isinstance(group, MappedTaskGroup):
-                    yield group
+                    exp_input = group._expand_input
+                    # TODO: TaskSDK This is only needed to support `dag.test()` etc until we port it over to use the
+                    # task sdk runner.
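+                    # (Round-tripping through BaseSerialization and deref() swaps a Task SDK expand
+                    # input for the scheduler-side variant in airflow/models/expandinput.py, which is
+                    # what implements get_total_map_length.)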
+ if not hasattr(exp_input, "get_total_map_length"): + exp_input = _ExpandInputRef( + type(exp_input).EXPAND_INPUT_TYPE, + BaseSerialization.deserialize(BaseSerialization.serialize(exp_input.value)), + ) + exp_input = exp_input.deref(group.dag) + yield exp_input.get_total_map_length(run_id, session=session) group = group.parent_group - groups = iter_mapped_task_groups(group) - return functools.reduce( - operator.mul, - (g._expand_input.get_total_map_length(run_id, session=session) for g in groups), - ) + return functools.reduce(operator.mul, iter_mapped_task_group_lengths(group)) def chain(*tasks: DependencyMixin | Sequence[DependencyMixin]) -> None: diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py index 727746b9b0333..29f279d9d7cb6 100644 --- a/airflow/models/dagrun.py +++ b/airflow/models/dagrun.py @@ -63,7 +63,6 @@ from airflow.models.backfill import Backfill from airflow.models.base import Base, StringID from airflow.models.dag_version import DagVersion -from airflow.models.expandinput import NotFullyPopulated from airflow.models.taskinstance import TaskInstance as TI from airflow.models.tasklog import LogTemplate from airflow.models.taskmap import TaskMap @@ -1330,6 +1329,7 @@ def _check_for_removed_or_restored_tasks( """ from airflow.models.baseoperator import BaseOperator + from airflow.models.expandinput import NotFullyPopulated tis = self.get_task_instances(session=session) @@ -1467,6 +1467,7 @@ def _create_tasks( :param task_creator: Function to create task instances """ from airflow.models.baseoperator import BaseOperator + from airflow.models.expandinput import NotFullyPopulated map_indexes: Iterable[int] for task in tasks: @@ -1538,6 +1539,7 @@ def _revise_map_indexes_if_mapped(self, task: Operator, *, session: Session) -> for more details. """ from airflow.models.baseoperator import BaseOperator + from airflow.models.expandinput import NotFullyPopulated from airflow.settings import task_instance_mutation_hook try: diff --git a/airflow/models/expandinput.py b/airflow/models/expandinput.py index 8fb35f7032965..b0916dfa4cca6 100644 --- a/airflow/models/expandinput.py +++ b/airflow/models/expandinput.py @@ -17,108 +17,54 @@ # under the License. from __future__ import annotations -import collections.abc import functools import operator -from collections.abc import Iterable, Mapping, Sequence, Sized -from typing import TYPE_CHECKING, Any, NamedTuple, Union +from collections.abc import Iterable, Sized +from typing import TYPE_CHECKING, Any -import attr - -from airflow.sdk.definitions._internal.mixins import ResolveMixin -from airflow.utils.session import NEW_SESSION, provide_session +import attrs if TYPE_CHECKING: from sqlalchemy.orm import Session - from airflow.models.xcom_arg import XComArg - from airflow.sdk.types import Operator - from airflow.serialization.serialized_objects import _ExpandInputRef + from airflow.models.xcom_arg import SchedulerXComArg from airflow.typing_compat import TypeGuard -ExpandInput = Union["DictOfListsExpandInput", "ListOfDictsExpandInput"] - -# Each keyword argument to expand() can be an XComArg, sequence, or dict (not -# any mapping since we need the value to be ordered). -OperatorExpandArgument = Union["MappedArgument", "XComArg", Sequence, dict[str, Any]] - -# The single argument of expand_kwargs() can be an XComArg, or a list with each -# element being either an XComArg or a dict. 
-OperatorExpandKwargsArgument = Union["XComArg", Sequence[Union["XComArg", Mapping[str, Any]]]] - - -@attr.define(kw_only=True) -class MappedArgument(ResolveMixin): - """ - Stand-in stub for task-group-mapping arguments. - - This is very similar to an XComArg, but resolved differently. Declared here - (instead of in the task group module) to avoid import cycles. - """ - - _input: ExpandInput - _key: str - - def iter_references(self) -> Iterable[tuple[Operator, str]]: - yield from self._input.iter_references() - - @provide_session - def resolve( - self, context: Mapping[str, Any], *, include_xcom: bool = True, session: Session = NEW_SESSION - ) -> Any: - data, _ = self._input.resolve(context, session=session, include_xcom=include_xcom) - return data[self._key] - - -# To replace tedious isinstance() checks. -def is_mappable(v: Any) -> TypeGuard[OperatorExpandArgument]: - from airflow.models.xcom_arg import XComArg - - return isinstance(v, (MappedArgument, XComArg, Mapping, Sequence)) and not isinstance(v, str) - - -# To replace tedious isinstance() checks. -def _is_parse_time_mappable(v: OperatorExpandArgument) -> TypeGuard[Mapping | Sequence]: - from airflow.models.xcom_arg import XComArg - - return not isinstance(v, (MappedArgument, XComArg)) - - -# To replace tedious isinstance() checks. -def _needs_run_time_resolution(v: OperatorExpandArgument) -> TypeGuard[MappedArgument | XComArg]: - from airflow.models.xcom_arg import XComArg - - return isinstance(v, (MappedArgument, XComArg)) - - -class NotFullyPopulated(RuntimeError): - """ - Raise when ``get_map_lengths`` cannot populate all mapping metadata. - - This is generally due to not all upstream tasks have finished when the - function is called. - """ +from airflow.sdk.definitions._internal.expandinput import ( + DictOfListsExpandInput, + ExpandInput, + ListOfDictsExpandInput, + MappedArgument, + NotFullyPopulated, + OperatorExpandArgument, + OperatorExpandKwargsArgument, + is_mappable, +) - def __init__(self, missing: set[str]) -> None: - self.missing = missing +__all__ = [ + "DictOfListsExpandInput", + "ListOfDictsExpandInput", + "MappedArgument", + "NotFullyPopulated", + "OperatorExpandArgument", + "OperatorExpandKwargsArgument", + "is_mappable", +] - def __str__(self) -> str: - keys = ", ".join(repr(k) for k in sorted(self.missing)) - return f"Failed to populate all mapping metadata; missing: {keys}" +def _needs_run_time_resolution(v: OperatorExpandArgument) -> TypeGuard[MappedArgument | SchedulerXComArg]: + from airflow.models.xcom_arg import SchedulerXComArg -class DictOfListsExpandInput(NamedTuple): - """ - Storage type of a mapped operator's mapped kwargs. + return isinstance(v, (MappedArgument, SchedulerXComArg)) - This is created from ``expand(**kwargs)``. 
- """ - value: dict[str, OperatorExpandArgument] +@attrs.define +class SchedulerDictOfListsExpandInput: + value: dict def _iter_parse_time_resolved_kwargs(self) -> Iterable[tuple[str, Sized]]: """Generate kwargs with values available on parse-time.""" - return ((k, v) for k, v in self.value.items() if _is_parse_time_mappable(v)) + return ((k, v) for k, v in self.value.items() if not _needs_run_time_resolution(v)) def get_parse_time_mapped_ti_count(self) -> int: if not self.value: @@ -164,150 +110,34 @@ def get_total_map_length(self, run_id: str, *, session: Session) -> int: lengths = self._get_map_lengths(run_id, session=session) return functools.reduce(operator.mul, (lengths[name] for name in self.value), 1) - def _expand_mapped_field( - self, key: str, value: Any, context: Mapping[str, Any], *, session: Session, include_xcom: bool - ) -> Any: - if _needs_run_time_resolution(value): - value = ( - value.resolve(context, session=session, include_xcom=include_xcom) - if include_xcom - else str(value) - ) - map_index = context["ti"].map_index - if map_index < 0: - raise RuntimeError("can't resolve task-mapping argument without expanding") - all_lengths = self._get_map_lengths(context["run_id"], session=session) - def _find_index_for_this_field(index: int) -> int: - # Need to use the original user input to retain argument order. - for mapped_key in reversed(self.value): - mapped_length = all_lengths[mapped_key] - if mapped_length < 1: - raise RuntimeError(f"cannot expand field mapped to length {mapped_length!r}") - if mapped_key == key: - return index % mapped_length - index //= mapped_length - return -1 - - found_index = _find_index_for_this_field(map_index) - if found_index < 0: - return value - if isinstance(value, collections.abc.Sequence): - return value[found_index] - if not isinstance(value, dict): - raise TypeError(f"can't map over value of type {type(value)}") - for i, (k, v) in enumerate(value.items()): - if i == found_index: - return k, v - raise IndexError(f"index {map_index} is over mapped length") - - def iter_references(self) -> Iterable[tuple[Operator, str]]: - from airflow.models.xcom_arg import XComArg - - for x in self.value.values(): - if isinstance(x, XComArg): - yield from x.iter_references() - - def resolve( - self, context: Mapping[str, Any], session: Session, *, include_xcom: bool = True - ) -> tuple[Mapping[str, Any], set[int]]: - data = { - k: self._expand_mapped_field(k, v, context, session=session, include_xcom=include_xcom) - for k, v in self.value.items() - } - literal_keys = {k for k, _ in self._iter_parse_time_resolved_kwargs()} - resolved_oids = {id(v) for k, v in data.items() if k not in literal_keys} - return data, resolved_oids - - -def _describe_type(value: Any) -> str: - if value is None: - return "None" - return type(value).__name__ - - -class ListOfDictsExpandInput(NamedTuple): - """ - Storage type of a mapped operator's mapped kwargs. - - This is created from ``expand_kwargs(xcom_arg)``. 
- """ - - value: OperatorExpandKwargsArgument +@attrs.define +class SchedulerListOfDictsExpandInput: + value: list def get_parse_time_mapped_ti_count(self) -> int: - if isinstance(self.value, collections.abc.Sized): + if isinstance(self.value, Sized): return len(self.value) raise NotFullyPopulated({"expand_kwargs() argument"}) def get_total_map_length(self, run_id: str, *, session: Session) -> int: from airflow.models.xcom_arg import get_task_map_length - if isinstance(self.value, collections.abc.Sized): + if isinstance(self.value, Sized): return len(self.value) length = get_task_map_length(self.value, run_id, session=session) if length is None: raise NotFullyPopulated({"expand_kwargs() argument"}) return length - def iter_references(self) -> Iterable[tuple[Operator, str]]: - from airflow.models.xcom_arg import XComArg - - if isinstance(self.value, XComArg): - yield from self.value.iter_references() - else: - for x in self.value: - if isinstance(x, XComArg): - yield from x.iter_references() - - def resolve( - self, context: Mapping[str, Any], session: Session, *, include_xcom: bool = True - ) -> tuple[Mapping[str, Any], set[int]]: - map_index = context["ti"].map_index - if map_index < 0: - raise RuntimeError("can't resolve task-mapping argument without expanding") - - mapping: Any - if isinstance(self.value, collections.abc.Sized): - mapping = self.value[map_index] - if not isinstance(mapping, collections.abc.Mapping): - mapping = mapping.resolve(context, session, include_xcom=include_xcom) - elif include_xcom: - mappings = self.value.resolve(context, session, include_xcom=include_xcom) - if not isinstance(mappings, collections.abc.Sequence): - raise ValueError(f"expand_kwargs() expects a list[dict], not {_describe_type(mappings)}") - mapping = mappings[map_index] - - if not isinstance(mapping, collections.abc.Mapping): - raise ValueError(f"expand_kwargs() expects a list[dict], not list[{_describe_type(mapping)}]") - - for key in mapping: - if not isinstance(key, str): - raise ValueError( - f"expand_kwargs() input dict keys must all be str, " - f"but {key!r} is of type {_describe_type(key)}" - ) - # filter out parse time resolved values from the resolved_oids - resolved_oids = {id(v) for k, v in mapping.items() if not _is_parse_time_mappable(v)} - - return mapping, resolved_oids - EXPAND_INPUT_EMPTY = DictOfListsExpandInput({}) # Sentinel value. 
 _EXPAND_INPUT_TYPES = {
-    "dict-of-lists": DictOfListsExpandInput,
-    "list-of-dicts": ListOfDictsExpandInput,
+    "dict-of-lists": SchedulerDictOfListsExpandInput,
+    "list-of-dicts": SchedulerListOfDictsExpandInput,
 }
 
 
-def get_map_type_key(expand_input: ExpandInput | _ExpandInputRef) -> str:
-    from airflow.serialization.serialized_objects import _ExpandInputRef
-
-    if isinstance(expand_input, _ExpandInputRef):
-        return expand_input.key
-    return next(k for k, v in _EXPAND_INPUT_TYPES.items() if isinstance(expand_input, v))
-
-
 def create_expand_input(kind: str, value: Any) -> ExpandInput:
     return _EXPAND_INPUT_TYPES[kind](value)
diff --git a/airflow/models/serialized_dag.py b/airflow/models/serialized_dag.py
index 5da2f24957ff0..2787ce1b82993 100644
--- a/airflow/models/serialized_dag.py
+++ b/airflow/models/serialized_dag.py
@@ -291,7 +291,7 @@ def read_all_dags(cls, session: Session = NEW_SESSION) -> dict[str, SerializedDA
     @property
     def data(self) -> dict | None:
         # use __data_cache to avoid decompress and loads
-        if not hasattr(self, "__data_cache") or self.__data_cache is None:
+        if not hasattr(self, "_SerializedDagModel__data_cache") or self.__data_cache is None:
             if self._data_compressed:
                 self.__data_cache = json.loads(zlib.decompress(self._data_compressed))
             else:
diff --git a/airflow/models/xcom_arg.py b/airflow/models/xcom_arg.py
index 078a9e6ff5223..ceb422c7cbff0 100644
--- a/airflow/models/xcom_arg.py
+++ b/airflow/models/xcom_arg.py
@@ -17,9 +17,11 @@
 
 from __future__ import annotations
 
+from collections.abc import Sequence
 from functools import singledispatch
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any
 
+import attrs
 from sqlalchemy import func, or_, select
 from sqlalchemy.orm import Session
 
@@ -27,36 +29,98 @@
 from airflow.sdk.definitions.mappedoperator import MappedOperator
 from airflow.sdk.definitions.xcom_arg import (
     ConcatXComArg,
-    MapXComArg,
-    PlainXComArg,
     XComArg,
-    ZipXComArg,
 )
 from airflow.utils.db import exists_query
 from airflow.utils.state import State
+from airflow.utils.types import NOTSET
 from airflow.utils.xcom import XCOM_RETURN_KEY
 
 __all__ = ["XComArg", "get_task_map_length"]
 
 if TYPE_CHECKING:
+    from airflow.models.dag import DAG as SchedulerDAG
     from airflow.models.expandinput import OperatorExpandArgument
+    from airflow.models.operator import Operator
+    from airflow.typing_compat import Self
+
+
+@attrs.define
+class SchedulerXComArg:
+    @classmethod
+    def _deserialize(cls, data: dict[str, Any], dag: SchedulerDAG) -> Self:
+        """
+        Deserialize an XComArg.
+
+        The implementation should be the inverse function to ``serialize``,
+        implementing, given a data dict converted from this XComArg derivative,
+        how the original XComArg should be created. DAG serialization relies on
+        additional information added in ``serialize_xcom_arg`` to dispatch data
+        dicts to the correct ``_deserialize`` information, so this function does
+        not need to validate whether the incoming data contains correct keys.
+        """
+        ...
+ + +@attrs.define +class SchedulerPlainXComArg(SchedulerXComArg): + operator: Operator + key: str + + @classmethod + def _deserialize(cls, data: dict[str, Any], dag: SchedulerDAG) -> Self: + return cls(dag.get_task(data["task_id"]), data["key"]) + + +@attrs.define +class SchedulerMapXComArg(SchedulerXComArg): + arg: XComArg + callables: Sequence[str] + + @classmethod + def _deserialize(cls, data: dict[str, Any], dag: SchedulerDAG) -> Self: + # We are deliberately NOT deserializing the callables. These are shown + # in the UI, and displaying a function object is useless. + return cls(deserialize_xcom_arg(data["arg"], dag), data["callables"]) + + +@attrs.define +class SchedulerConcatXComArg(SchedulerXComArg): + args: Sequence[XComArg] + + @classmethod + def _deserialize(cls, data: dict[str, Any], dag: SchedulerDAG) -> Self: + return cls([deserialize_xcom_arg(arg, dag) for arg in data["args"]]) + + +@attrs.define +class SchedulerZipXComArg(SchedulerXComArg): + args: Sequence[XComArg] + fillvalue: Any + + @classmethod + def _deserialize(cls, data: dict[str, Any], dag: SchedulerDAG) -> Self: + return cls( + [deserialize_xcom_arg(arg, dag) for arg in data["args"]], + fillvalue=data.get("fillvalue", NOTSET), + ) @singledispatch def get_task_map_length(xcom_arg: OperatorExpandArgument, run_id: str, *, session: Session) -> int | None: # The base implementation -- specific XComArg subclasses have specialised implementations - raise NotImplementedError() + raise NotImplementedError(f"get_task_map_length not implemented for {type(xcom_arg)}") @get_task_map_length.register -def _(xcom_arg: PlainXComArg, run_id: str, *, session: Session): +def _(xcom_arg: SchedulerPlainXComArg, run_id: str, *, session: Session): from airflow.models.taskinstance import TaskInstance from airflow.models.taskmap import TaskMap from airflow.models.xcom import XCom dag_id = xcom_arg.operator.dag_id task_id = xcom_arg.operator.task_id - is_mapped = isinstance(xcom_arg.operator, MappedOperator) + is_mapped = xcom_arg.operator.is_mapped or isinstance(xcom_arg.operator, MappedOperator) if is_mapped: unfinished_ti_exists = exists_query( @@ -92,12 +156,12 @@ def _(xcom_arg: PlainXComArg, run_id: str, *, session: Session): @get_task_map_length.register -def _(xcom_arg: MapXComArg, run_id: str, *, session: Session): +def _(xcom_arg: SchedulerMapXComArg, run_id: str, *, session: Session): return get_task_map_length(xcom_arg.arg, run_id, session=session) @get_task_map_length.register -def _(xcom_arg: ZipXComArg, run_id: str, *, session: Session): +def _(xcom_arg: SchedulerZipXComArg, run_id: str, *, session: Session): all_lengths = (get_task_map_length(arg, run_id, session=session) for arg in xcom_arg.args) ready_lengths = [length for length in all_lengths if length is not None] if len(ready_lengths) != len(xcom_arg.args): @@ -114,3 +178,17 @@ def _(xcom_arg: ConcatXComArg, run_id: str, *, session: Session): if len(ready_lengths) != len(xcom_arg.args): return None # If any of the referenced XComs is not ready, we are not ready either. 
return sum(ready_lengths) + + +def deserialize_xcom_arg(data: dict[str, Any], dag: SchedulerDAG): + """DAG serialization interface.""" + klass = _XCOM_ARG_TYPES[data.get("type", "")] + return klass._deserialize(data, dag) + + +_XCOM_ARG_TYPES: dict[str, type[SchedulerXComArg]] = { + "": SchedulerPlainXComArg, + "concat": SchedulerConcatXComArg, + "map": SchedulerMapXComArg, + "zip": SchedulerZipXComArg, +} diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py index 89ea668f02ffe..26db391148a17 100644 --- a/airflow/serialization/serialized_objects.py +++ b/airflow/serialization/serialized_objects.py @@ -45,10 +45,10 @@ from airflow.models.expandinput import ( EXPAND_INPUT_EMPTY, create_expand_input, - get_map_type_key, ) from airflow.models.taskinstance import SimpleTaskInstance from airflow.models.taskinstancekey import TaskInstanceKey +from airflow.models.xcom_arg import SchedulerXComArg, deserialize_xcom_arg from airflow.providers_manager import ProvidersManager from airflow.sdk.definitions.asset import ( Asset, @@ -66,7 +66,7 @@ from airflow.sdk.definitions.mappedoperator import MappedOperator from airflow.sdk.definitions.param import Param, ParamsDict from airflow.sdk.definitions.taskgroup import MappedTaskGroup, TaskGroup -from airflow.sdk.definitions.xcom_arg import XComArg, deserialize_xcom_arg, serialize_xcom_arg +from airflow.sdk.definitions.xcom_arg import XComArg, serialize_xcom_arg from airflow.sdk.execution_time.context import OutletEventAccessor, OutletEventAccessors from airflow.serialization.dag_dependency import DagDependency from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding @@ -493,7 +493,7 @@ class _XComRef(NamedTuple): data: dict - def deref(self, dag: DAG) -> XComArg: + def deref(self, dag: DAG) -> SchedulerXComArg: return deserialize_xcom_arg(self.data, dag) @@ -1195,7 +1195,7 @@ def serialize_mapped_operator(cls, op: MappedOperator) -> dict[str, Any]: if TYPE_CHECKING: # Let Mypy check the input type for us! _ExpandInputRef.validate_expand_input_value(expansion_kwargs.value) serialized_op[op._expand_input_attr] = { - "type": get_map_type_key(expansion_kwargs), + "type": type(expansion_kwargs).EXPAND_INPUT_TYPE, "value": cls.serialize(expansion_kwargs.value), } @@ -1792,7 +1792,7 @@ def serialize_task_group(cls, task_group: TaskGroup) -> dict[str, Any] | None: if isinstance(task_group, MappedTaskGroup): expand_input = task_group._expand_input encoded["expand_input"] = { - "type": get_map_type_key(expand_input), + "type": type(expand_input).EXPAND_INPUT_TYPE, "value": cls.serialize(expand_input.value), } encoded["is_mapped"] = True diff --git a/airflow/utils/context.py b/airflow/utils/context.py index 0415542c6ca8c..62308605dbb37 100644 --- a/airflow/utils/context.py +++ b/airflow/utils/context.py @@ -66,7 +66,6 @@ from sqlalchemy.orm import Session from sqlalchemy.sql.expression import Select, TextClause - from airflow.sdk.definitions.baseoperator import BaseOperator from airflow.sdk.types import OutletEventAccessorsProtocol # NOTE: Please keep this in sync with the following: @@ -293,24 +292,6 @@ def context_merge(context: Context, *args: Any, **kwargs: Any) -> None: context.update(*args, **kwargs) -def context_update_for_unmapped(context: Context, task: BaseOperator) -> None: - """ - Update context after task unmapping. - - Since ``get_template_context()`` is called before unmapping, the context - contains information about the mapped task. 
We need to do some in-place - updates to ensure the template context reflects the unmapped task instead. - - :meta private: - """ - from airflow.sdk.definitions.param import process_params - - context["task"] = context["ti"].task = task - context["params"] = process_params( - context["dag"], task, context["dag_run"].conf, suppress_exception=False - ) - - def context_copy_partial(source: Context, keys: Container[str]) -> Context: """ Create a context by copying items under selected keys in ``source``. diff --git a/airflow/utils/setup_teardown.py b/airflow/utils/setup_teardown.py index 3108657d30ac2..32d19d316844c 100644 --- a/airflow/utils/setup_teardown.py +++ b/airflow/utils/setup_teardown.py @@ -23,8 +23,8 @@ if TYPE_CHECKING: from airflow.models.taskmixin import DependencyMixin - from airflow.models.xcom_arg import PlainXComArg from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator + from airflow.sdk.definitions.xcom_arg import PlainXComArg class BaseSetupTeardownContext: @@ -335,7 +335,7 @@ class SetupTeardownContext(BaseSetupTeardownContext): @staticmethod def add_task(task: AbstractOperator | PlainXComArg): """Add task to context manager.""" - from airflow.models.xcom_arg import PlainXComArg + from airflow.sdk.definitions.xcom_arg import PlainXComArg if not SetupTeardownContext.active: raise AirflowException("Cannot add task to context outside the context manager.") diff --git a/providers/standard/tests/provider_tests/standard/decorators/test_python.py b/providers/standard/tests/provider_tests/standard/decorators/test_python.py index 64914e3d149f8..1c95d626037d4 100644 --- a/providers/standard/tests/provider_tests/standard/decorators/test_python.py +++ b/providers/standard/tests/provider_tests/standard/decorators/test_python.py @@ -26,15 +26,12 @@ from airflow.decorators import setup, task as task_decorator, teardown from airflow.decorators.base import DecoratedMappedOperator from airflow.exceptions import AirflowException, XComNotFound -from airflow.models.baseoperator import BaseOperator -from airflow.models.dag import DAG -from airflow.models.expandinput import DictOfListsExpandInput from airflow.models.taskinstance import TaskInstance from airflow.models.taskmap import TaskMap -from airflow.models.xcom_arg import PlainXComArg, XComArg +from airflow.sdk import DAG, BaseOperator, TaskGroup, XComArg +from airflow.sdk.definitions.mappedoperator import MappedOperator from airflow.utils import timezone from airflow.utils.state import State -from airflow.utils.task_group import TaskGroup from airflow.utils.task_instance_session import set_current_task_instance_session from airflow.utils.trigger_rule import TriggerRule from airflow.utils.types import DagRunType @@ -44,9 +41,11 @@ from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS if AIRFLOW_V_3_0_PLUS: + from airflow.sdk.definitions._internal.expandinput import DictOfListsExpandInput from airflow.sdk.definitions.mappedoperator import MappedOperator from airflow.utils.types import DagRunTriggeredByType else: + from airflow.models.expandinput import DictOfListsExpandInput from airflow.models.mappedoperator import MappedOperator pytestmark = pytest.mark.db_test @@ -731,6 +730,8 @@ def double(number: int): def test_partial_mapped_decorator() -> None: + from airflow.sdk.definitions.xcom_arg import PlainXComArg + @task_decorator def product(number: int, multiple: int): return number * multiple @@ -787,7 +788,7 @@ def task2(arg1, arg2): ... 
     dec = run.task_instance_scheduling_decisions(session=session)
     assert [ti.task_id for ti in dec.schedulable_tis] == ["task2"]
     ti = dec.schedulable_tis[0]
-    unmapped = ti.task.unmap((ti.get_template_context(session), session))
+    unmapped = ti.task.unmap((ti.get_template_context(session),))
     assert set(unmapped.op_kwargs) == {"arg1", "arg2"}
 
 
@@ -817,7 +818,7 @@ def task2(arg1, arg2): ...
     dec = run.task_instance_scheduling_decisions(session=session)
     assert [ti.task_id for ti in dec.schedulable_tis] == ["task2", "task2"]
     for ti in dec.schedulable_tis:
-        unmapped = ti.task.unmap((ti.get_template_context(session), session))
+        unmapped = ti.task.unmap((ti.get_template_context(session),))
         assert unmapped.retry_delay == timedelta(seconds=30)
 
 
diff --git a/providers/standard/tests/provider_tests/standard/operators/test_datetime.py b/providers/standard/tests/provider_tests/standard/operators/test_datetime.py
index 67f72f2c6e2f0..62a31e3255aed 100644
--- a/providers/standard/tests/provider_tests/standard/operators/test_datetime.py
+++ b/providers/standard/tests/provider_tests/standard/operators/test_datetime.py
@@ -74,13 +74,13 @@ def base_tests_setup(self, dag_maker):
         self.branch_1.set_upstream(self.branch_op)
         self.branch_2.set_upstream(self.branch_op)
 
-        self.dr = dag_maker.create_dagrun(
-            run_id="manual__",
-            start_date=DEFAULT_DATE,
-            logical_date=DEFAULT_DATE,
-            state=State.RUNNING,
-            data_interval=(DEFAULT_DATE, DEFAULT_DATE),
-        )
+        self.dr = dag_maker.create_dagrun(
+            run_id="manual__",
+            start_date=DEFAULT_DATE,
+            logical_date=DEFAULT_DATE,
+            state=State.RUNNING,
+            data_interval=(DEFAULT_DATE, DEFAULT_DATE),
+        )
 
     def teardown_method(self):
         with create_session() as session:
diff --git a/providers/standard/tests/provider_tests/standard/operators/test_python.py b/providers/standard/tests/provider_tests/standard/operators/test_python.py
index b4320e7a19330..d8d78eb46fd6d 100644
--- a/providers/standard/tests/provider_tests/standard/operators/test_python.py
+++ b/providers/standard/tests/provider_tests/standard/operators/test_python.py
@@ -42,7 +42,6 @@
 from slugify import slugify
 
 from airflow.config_templates.airflow_local_settings import DEFAULT_LOGGING_CONFIG
-from airflow.decorators import task_group
 from airflow.exceptions import (
     AirflowException,
     DeserializingResultError,
@@ -112,8 +111,12 @@ def base_tests_setup(self, request, create_serialized_task_instance_of_operator,
         self.run_id = f"run_{slugify(request.node.name, max_length=40)}"
         self.ds_templated = self.default_date.date().isoformat()
         self.ti_maker = create_serialized_task_instance_of_operator
+        self.dag_maker = dag_maker
 
         self.dag_non_serialized = self.dag_maker(self.dag_id, template_searchpath=TEMPLATE_SEARCHPATH).dag
+        # We need to enter the context in order for the factory to create things
+        with self.dag_maker:
+            ...
clear_db_runs() yield clear_db_runs() @@ -140,6 +143,10 @@ def default_kwargs(**kwargs): return kwargs def create_dag_run(self) -> DagRun: + from airflow.models.serialized_dag import SerializedDagModel + + # Update the serialized DAG with any tasks added after initial dag was created + self.dag_maker.serialized_model = SerializedDagModel(self.dag_non_serialized) return self.dag_maker.create_dagrun( state=DagRunState.RUNNING, start_date=self.dag_maker.start_date, @@ -755,39 +762,6 @@ def test_xcom_push_skipped_tasks(self): "skipped": ["empty_task"] } - def test_mapped_xcom_push_skipped_tasks(self, session): - with self.dag_non_serialized: - - @task_group - def group(x): - short_op_push_xcom = ShortCircuitOperator( - task_id="push_xcom_from_shortcircuit", - python_callable=lambda arg: arg % 2 == 0, - op_kwargs={"arg": x}, - ) - empty_task = EmptyOperator(task_id="empty_task") - short_op_push_xcom >> empty_task - - group.expand(x=[0, 1]) - dr = self.create_dag_run() - decision = dr.task_instance_scheduling_decisions(session=session) - for ti in decision.schedulable_tis: - ti.run() - # dr.run(start_date=self.default_date, end_date=self.default_date) - tis = dr.get_task_instances() - - assert ( - tis[0].xcom_pull(task_ids="group.push_xcom_from_shortcircuit", key="return_value", map_indexes=0) - is True - ) - assert ( - tis[0].xcom_pull(task_ids="group.push_xcom_from_shortcircuit", key="skipmixin_key", map_indexes=0) - is None - ) - assert tis[0].xcom_pull( - task_ids="group.push_xcom_from_shortcircuit", key="skipmixin_key", map_indexes=1 - ) == {"skipped": ["group.empty_task"]} - virtualenv_string_args: list[str] = [] diff --git a/task_sdk/src/airflow/sdk/api/datamodels/_generated.py b/task_sdk/src/airflow/sdk/api/datamodels/_generated.py index 1d6d0eb4156c3..88e4d452d420a 100644 --- a/task_sdk/src/airflow/sdk/api/datamodels/_generated.py +++ b/task_sdk/src/airflow/sdk/api/datamodels/_generated.py @@ -239,7 +239,7 @@ class TaskInstance(BaseModel): dag_id: Annotated[str, Field(title="Dag Id")] run_id: Annotated[str, Field(title="Run Id")] try_number: Annotated[int, Field(title="Try Number")] - map_index: Annotated[int, Field(title="Map Index")] = -1 + map_index: Annotated[int | None, Field(title="Map Index")] = None hostname: Annotated[str | None, Field(title="Hostname")] = None @@ -273,6 +273,7 @@ class TIRunContext(BaseModel): max_tries: Annotated[int, Field(title="Max Tries")] variables: Annotated[list[VariableResponse] | None, Field(title="Variables")] = None connections: Annotated[list[ConnectionResponse] | None, Field(title="Connections")] = None + upstream_map_indexes: Annotated[dict[str, int] | None, Field(title="Upstream Max Indexes")] = None class TITerminalStatePayload(BaseModel): diff --git a/task_sdk/src/airflow/sdk/definitions/_internal/expandinput.py b/task_sdk/src/airflow/sdk/definitions/_internal/expandinput.py new file mode 100644 index 0000000000000..13db9bfa0cc3c --- /dev/null +++ b/task_sdk/src/airflow/sdk/definitions/_internal/expandinput.py @@ -0,0 +1,271 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import functools +import operator +from collections.abc import Iterable, Mapping, Sequence, Sized +from typing import TYPE_CHECKING, Any, ClassVar, Union + +import attrs + +from airflow.sdk.definitions._internal.mixins import ResolveMixin + +if TYPE_CHECKING: + from airflow.sdk.definitions.xcom_arg import XComArg + from airflow.sdk.types import Operator + from airflow.typing_compat import TypeGuard + +ExpandInput = Union["DictOfListsExpandInput", "ListOfDictsExpandInput"] + +# Each keyword argument to expand() can be an XComArg, sequence, or dict (not +# any mapping since we need the value to be ordered). +OperatorExpandArgument = Union["MappedArgument", "XComArg", Sequence, dict[str, Any]] + +# The single argument of expand_kwargs() can be an XComArg, or a list with each +# element being either an XComArg or a dict. +OperatorExpandKwargsArgument = Union["XComArg", Sequence[Union["XComArg", Mapping[str, Any]]]] + + +class NotFullyPopulated(RuntimeError): + """ + Raise when ``get_map_lengths`` cannot populate all mapping metadata. + + This is generally due to not all upstream tasks have finished when the + function is called. + """ + + def __init__(self, missing: set[str]) -> None: + self.missing = missing + + def __str__(self) -> str: + keys = ", ".join(repr(k) for k in sorted(self.missing)) + return f"Failed to populate all mapping metadata; missing: {keys}" + + +# To replace tedious isinstance() checks. +def is_mappable(v: Any) -> TypeGuard[OperatorExpandArgument]: + from airflow.sdk.definitions.xcom_arg import XComArg + + return isinstance(v, (MappedArgument, XComArg, Mapping, Sequence)) and not isinstance(v, str) + + +# To replace tedious isinstance() checks. +def _is_parse_time_mappable(v: OperatorExpandArgument) -> TypeGuard[Mapping | Sequence]: + from airflow.models.xcom_arg import XComArg + + return not isinstance(v, (MappedArgument, XComArg)) + + +# To replace tedious isinstance() checks. +def _needs_run_time_resolution(v: OperatorExpandArgument) -> TypeGuard[MappedArgument | XComArg]: + from airflow.models.xcom_arg import XComArg + + return isinstance(v, (MappedArgument, XComArg)) + + +@attrs.define(kw_only=True) +class MappedArgument(ResolveMixin): + """ + Stand-in stub for task-group-mapping arguments. + + This is very similar to an XComArg, but resolved differently. Declared here + (instead of in the task group module) to avoid import cycles. + """ + + _input: ExpandInput + _key: str + + def iter_references(self) -> Iterable[tuple[Operator, str]]: + yield from self._input.iter_references() + + def resolve(self, context: Mapping[str, Any]) -> Any: + data, _ = self._input.resolve(context) + return data[self._key] + + +@attrs.define() +class DictOfListsExpandInput(ResolveMixin): + """ + Storage type of a mapped operator's mapped kwargs. + + This is created from ``expand(**kwargs)``. 
+ """ + + value: dict[str, OperatorExpandArgument] + + EXPAND_INPUT_TYPE: ClassVar[str] = "dict-of-lists" + + def _iter_parse_time_resolved_kwargs(self) -> Iterable[tuple[str, Sized]]: + """Generate kwargs with values available on parse-time.""" + return ((k, v) for k, v in self.value.items() if _is_parse_time_mappable(v)) + + def get_parse_time_mapped_ti_count(self) -> int: + if not self.value: + return 0 + literal_values = [len(v) for _, v in self._iter_parse_time_resolved_kwargs()] + if len(literal_values) != len(self.value): + literal_keys = (k for k, _ in self._iter_parse_time_resolved_kwargs()) + raise NotFullyPopulated(set(self.value).difference(literal_keys)) + return functools.reduce(operator.mul, literal_values, 1) + + def _get_map_lengths( + self, resolved_vals: dict[str, Sized], upstream_map_indexes: dict[str, int] + ) -> dict[str, int]: + """ + Return dict of argument name to map length. + + If any arguments are not known right now (upstream task not finished), + they will not be present in the dict. + """ + + # TODO: This initiates one API call for each XComArg. Would it be + # more efficient to do one single call and unpack the value here? + def _get_length(k: str, v: OperatorExpandArgument) -> int | None: + from airflow.sdk.definitions.xcom_arg import XComArg, get_task_map_length + + if isinstance(v, XComArg): + return get_task_map_length(v, resolved_vals[k], upstream_map_indexes) + + # Unfortunately a user-defined TypeGuard cannot apply negative type + # narrowing. https://github.com/python/typing/discussions/1013 + if TYPE_CHECKING: + assert isinstance(v, Sized) + return len(v) + + map_lengths = { + k: res for k, v in self.value.items() if v is not None if (res := _get_length(k, v)) is not None + } + if len(map_lengths) < len(self.value): + raise NotFullyPopulated(set(self.value).difference(map_lengths)) + return map_lengths + + def _expand_mapped_field(self, key: str, value: Any, map_index: int, all_lengths: dict[str, int]) -> Any: + def _find_index_for_this_field(index: int) -> int: + # Need to use the original user input to retain argument order. + for mapped_key in reversed(self.value): + mapped_length = all_lengths[mapped_key] + if mapped_length < 1: + raise RuntimeError(f"cannot expand field mapped to length {mapped_length!r}") + if mapped_key == key: + return index % mapped_length + index //= mapped_length + return -1 + + found_index = _find_index_for_this_field(map_index) + if found_index < 0: + return value + if isinstance(value, Sequence): + return value[found_index] + if not isinstance(value, dict): + raise TypeError(f"can't map over value of type {type(value)}") + for i, (k, v) in enumerate(value.items()): + if i == found_index: + return k, v + raise IndexError(f"index {map_index} is over mapped length") + + def iter_references(self) -> Iterable[tuple[Operator, str]]: + from airflow.models.xcom_arg import XComArg + + for x in self.value.values(): + if isinstance(x, XComArg): + yield from x.iter_references() + + def resolve(self, context: Mapping[str, Any]) -> tuple[Mapping[str, Any], set[int]]: + map_index: int | None = context["ti"].map_index + if map_index is None or map_index < 0: + raise RuntimeError("can't resolve task-mapping argument without expanding") + + upstream_map_indexes = context.get("_upstream_map_indexes", {}) + + # TODO: This initiates one API call for each XComArg. Would it be + # more efficient to do one single call and unpack the value here? 
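+        # Values that still need runtime resolution (XComArg / MappedArgument) are resolved here;
+        # parse-time literals are passed through unchanged.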
+ resolved = { + k: v.resolve(context) if _needs_run_time_resolution(v) else v for k, v in self.value.items() + } + + all_lengths = self._get_map_lengths(resolved, upstream_map_indexes) + + data = {k: self._expand_mapped_field(k, v, map_index, all_lengths) for k, v in resolved.items()} + literal_keys = {k for k, _ in self._iter_parse_time_resolved_kwargs()} + resolved_oids = {id(v) for k, v in data.items() if k not in literal_keys} + return data, resolved_oids + + +def _describe_type(value: Any) -> str: + if value is None: + return "None" + return type(value).__name__ + + +@attrs.define() +class ListOfDictsExpandInput(ResolveMixin): + """ + Storage type of a mapped operator's mapped kwargs. + + This is created from ``expand_kwargs(xcom_arg)``. + """ + + value: OperatorExpandKwargsArgument + + EXPAND_INPUT_TYPE: ClassVar[str] = "list-of-dicts" + + def get_parse_time_mapped_ti_count(self) -> int: + if isinstance(self.value, Sized): + return len(self.value) + raise NotFullyPopulated({"expand_kwargs() argument"}) + + def iter_references(self) -> Iterable[tuple[Operator, str]]: + from airflow.models.xcom_arg import XComArg + + if isinstance(self.value, XComArg): + yield from self.value.iter_references() + else: + for x in self.value: + if isinstance(x, XComArg): + yield from x.iter_references() + + def resolve(self, context: Mapping[str, Any]) -> tuple[Mapping[str, Any], set[int]]: + map_index = context["ti"].map_index + if map_index < 0: + raise RuntimeError("can't resolve task-mapping argument without expanding") + + mapping: Any = None + if isinstance(self.value, Sized): + mapping = self.value[map_index] + if not isinstance(mapping, Mapping): + mapping = mapping.resolve(context) + else: + mappings = self.value.resolve(context) + if not isinstance(mappings, Sequence): + raise ValueError(f"expand_kwargs() expects a list[dict], not {_describe_type(mappings)}") + mapping = mappings[map_index] + + if not isinstance(mapping, Mapping): + raise ValueError(f"expand_kwargs() expects a list[dict], not list[{_describe_type(mapping)}]") + + for key in mapping: + if not isinstance(key, str): + raise ValueError( + f"expand_kwargs() input dict keys must all be str, " + f"but {key!r} is of type {_describe_type(key)}" + ) + # filter out parse time resolved values from the resolved_oids + resolved_oids = {id(v) for k, v in mapping.items() if not _is_parse_time_mappable(v)} + + return mapping, resolved_oids diff --git a/task_sdk/src/airflow/sdk/definitions/_internal/mixins.py b/task_sdk/src/airflow/sdk/definitions/_internal/mixins.py index fcd68ba20b2c6..93fd9431cbe38 100644 --- a/task_sdk/src/airflow/sdk/definitions/_internal/mixins.py +++ b/task_sdk/src/airflow/sdk/definitions/_internal/mixins.py @@ -133,7 +133,7 @@ def iter_references(self) -> Iterable[tuple[Operator, str]]: """ raise NotImplementedError - def resolve(self, context: Context, *, include_xcom: bool = True) -> Any: + def resolve(self, context: Context) -> Any: """ Resolve this value for runtime. 
diff --git a/task_sdk/src/airflow/sdk/definitions/_internal/templater.py b/task_sdk/src/airflow/sdk/definitions/_internal/templater.py index d7028d4d6bca7..b50c4dbb3cadb 100644 --- a/task_sdk/src/airflow/sdk/definitions/_internal/templater.py +++ b/task_sdk/src/airflow/sdk/definitions/_internal/templater.py @@ -50,7 +50,7 @@ class LiteralValue(ResolveMixin): def iter_references(self) -> Iterable[tuple[Operator, str]]: return () - def resolve(self, context: Context, *, include_xcom: bool = True) -> Any: + def resolve(self, context: Context) -> Any: return self.value @@ -179,7 +179,7 @@ def render_template( return self._render_object_storage_path(value, context, jinja_env) if resolve := getattr(value, "resolve", None): - return resolve(context, include_xcom=True) + return resolve(context) # Fast path for common built-in collections. if value.__class__ is tuple: diff --git a/task_sdk/src/airflow/sdk/definitions/dag.py b/task_sdk/src/airflow/sdk/definitions/dag.py index 5662d542859f7..3c8d99737ab8e 100644 --- a/task_sdk/src/airflow/sdk/definitions/dag.py +++ b/task_sdk/src/airflow/sdk/definitions/dag.py @@ -359,6 +359,13 @@ class DAG: # argument "description" in call to "DAG"`` etc), so for init=True args we use the `default=Factory()` # style + def __rich_repr__(self): + yield "dag_id", self.dag_id + yield "schedule", self.schedule + yield "#tasks", len(self.tasks) + + __rich_repr__.angular = True # type: ignore[attr-defined] + # NOTE: When updating arguments here, please also keep arguments in @dag() # below in sync. (Search for 'def dag(' in this file.) dag_id: str = attrs.field(kw_only=False, validator=attrs.validators.instance_of(str)) diff --git a/task_sdk/src/airflow/sdk/definitions/mappedoperator.py b/task_sdk/src/airflow/sdk/definitions/mappedoperator.py index 00bd2ab8ab2f9..58e84e86d3f7d 100644 --- a/task_sdk/src/airflow/sdk/definitions/mappedoperator.py +++ b/task_sdk/src/airflow/sdk/definitions/mappedoperator.py @@ -27,11 +27,6 @@ import methodtools from airflow.models.abstractoperator import NotMapped -from airflow.models.expandinput import ( - DictOfListsExpandInput, - ListOfDictsExpandInput, - is_mappable, -) from airflow.sdk.definitions._internal.abstractoperator import ( DEFAULT_EXECUTOR, DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST, @@ -47,13 +42,16 @@ DEFAULT_WEIGHT_RULE, AbstractOperator, ) +from airflow.sdk.definitions._internal.expandinput import ( + DictOfListsExpandInput, + ListOfDictsExpandInput, + is_mappable, +) from airflow.sdk.definitions._internal.types import NOTSET from airflow.serialization.enums import DagAttributeTypes from airflow.task.priority_strategy import PriorityWeightStrategy, validate_and_load_priority_weight_strategy -from airflow.triggers.base import StartTriggerArgs from airflow.typing_compat import Literal from airflow.utils.helpers import is_container, prevent_duplicates -from airflow.utils.task_instance_session import get_current_task_instance_session from airflow.utils.xcom import XCOM_RETURN_KEY if TYPE_CHECKING: @@ -61,7 +59,6 @@ import jinja2 # Slow import. 
import pendulum - from sqlalchemy.orm.session import Session from airflow.models.abstractoperator import ( TaskStateChangeCallback, @@ -78,6 +75,7 @@ from airflow.sdk.definitions.param import ParamsDict from airflow.sdk.types import Operator from airflow.ti_deps.deps.base_ti_dep import BaseTIDep + from airflow.triggers.base import StartTriggerArgs from airflow.typing_compat import TypeGuard from airflow.utils.context import Context from airflow.utils.operator_resources import Resources @@ -683,16 +681,14 @@ def serialize_for_task_group(self) -> tuple[DagAttributeTypes, Any]: """Implement DAGNode.""" return DagAttributeTypes.OP, self.task_id - def _expand_mapped_kwargs( - self, context: Mapping[str, Any], session: Session, *, include_xcom: bool - ) -> tuple[Mapping[str, Any], set[int]]: + def _expand_mapped_kwargs(self, context: Mapping[str, Any]) -> tuple[Mapping[str, Any], set[int]]: """ Get the kwargs to create the unmapped operator. This exists because taskflow operators expand against op_kwargs, not the entire operator kwargs dict. """ - return self._get_specified_expand_input().resolve(context, session, include_xcom=include_xcom) + return self._get_specified_expand_input().resolve(context) def _get_unmap_kwargs(self, mapped_kwargs: Mapping[str, Any], *, strict: bool) -> dict[str, Any]: """ @@ -726,70 +722,7 @@ def _get_unmap_kwargs(self, mapped_kwargs: Mapping[str, Any], *, strict: bool) - "params": params, } - def expand_start_from_trigger(self, *, context: Context, session: Session) -> bool: - """ - Get the start_from_trigger value of the current abstract operator. - - MappedOperator uses this to unmap start_from_trigger to decide whether to start the task - execution directly from triggerer. - - :meta private: - """ - # start_from_trigger only makes sense when start_trigger_args exists. - if not self.start_trigger_args: - return False - - mapped_kwargs, _ = self._expand_mapped_kwargs(context, session, include_xcom=False) - if self._disallow_kwargs_override: - prevent_duplicates( - self.partial_kwargs, - mapped_kwargs, - fail_reason="unmappable or already specified", - ) - - # Ordering is significant; mapped kwargs should override partial ones. - return mapped_kwargs.get( - "start_from_trigger", self.partial_kwargs.get("start_from_trigger", self.start_from_trigger) - ) - - def expand_start_trigger_args(self, *, context: Context, session: Session) -> StartTriggerArgs | None: - """ - Get the kwargs to create the unmapped start_trigger_args. - - This method is for allowing mapped operator to start execution from triggerer. - """ - if not self.start_trigger_args: - return None - - mapped_kwargs, _ = self._expand_mapped_kwargs(context, session, include_xcom=False) - if self._disallow_kwargs_override: - prevent_duplicates( - self.partial_kwargs, - mapped_kwargs, - fail_reason="unmappable or already specified", - ) - - # Ordering is significant; mapped kwargs should override partial ones. 
- trigger_kwargs = mapped_kwargs.get( - "trigger_kwargs", - self.partial_kwargs.get("trigger_kwargs", self.start_trigger_args.trigger_kwargs), - ) - next_kwargs = mapped_kwargs.get( - "next_kwargs", - self.partial_kwargs.get("next_kwargs", self.start_trigger_args.next_kwargs), - ) - timeout = mapped_kwargs.get( - "trigger_timeout", self.partial_kwargs.get("trigger_timeout", self.start_trigger_args.timeout) - ) - return StartTriggerArgs( - trigger_cls=self.start_trigger_args.trigger_cls, - trigger_kwargs=trigger_kwargs, - next_method=self.start_trigger_args.next_method, - next_kwargs=next_kwargs, - timeout=timeout, - ) - - def unmap(self, resolve: None | Mapping[str, Any] | tuple[Context, Session]) -> BaseOperator: + def unmap(self, resolve: None | Mapping[str, Any]) -> BaseOperator: """ Get the "normal" Operator after applying the current mapping. @@ -798,30 +731,21 @@ def unmap(self, resolve: None | Mapping[str, Any] | tuple[Context, Session]) -> not a class (i.e. this DAG has been deserialized), this returns a SerializedBaseOperator that "looks like" the actual unmapping result. - If *resolve* is a two-tuple (context, session), the information is used - to resolve the mapped arguments into init arguments. If it is a mapping, - no resolving happens, the mapping directly provides those init arguments - resolved from mapped kwargs. - :meta private: """ if isinstance(self.operator_class, type): if isinstance(resolve, Mapping): kwargs = resolve elif resolve is not None: - kwargs, _ = self._expand_mapped_kwargs(*resolve, include_xcom=True) + kwargs, _ = self._expand_mapped_kwargs(*resolve) else: raise RuntimeError("cannot unmap a non-serialized operator without context") kwargs = self._get_unmap_kwargs(kwargs, strict=self._disallow_kwargs_override) is_setup = kwargs.pop("is_setup", False) is_teardown = kwargs.pop("is_teardown", False) on_failure_fail_dagrun = kwargs.pop("on_failure_fail_dagrun", False) + kwargs["task_id"] = self.task_id op = self.operator_class(**kwargs, _airflow_from_mapped=True) - # We need to overwrite task_id here because BaseOperator further - # mangles the task_id based on the task hierarchy (namely, group_id - # is prepended, and '__N' appended to deduplicate). This is hacky, - # but better than duplicating the whole mangling logic. - op.task_id = self.task_id op.is_setup = is_setup op.is_teardown = is_teardown op.on_failure_fail_dagrun = on_failure_fail_dagrun @@ -856,7 +780,7 @@ def prepare_for_execution(self) -> MappedOperator: def iter_mapped_dependencies(self) -> Iterator[Operator]: """Upstream dependencies that provide XComs used by this task for task mapping.""" - from airflow.models.xcom_arg import XComArg + from airflow.sdk.definitions.xcom_arg import XComArg for operator, _ in XComArg.iter_xcom_references(self._get_specified_expand_input()): yield operator @@ -886,23 +810,14 @@ def render_template_fields( :param context: Context dict with values to apply on content. :param jinja_env: Jinja environment to use for rendering. """ - from airflow.utils.context import context_update_for_unmapped + from airflow.sdk.execution_time.context import context_update_for_unmapped if not jinja_env: jinja_env = self.get_template_env() - # We retrieve the session here, stored by _run_raw_task in set_current_task_session - # context manager - we cannot pass the session via @provide_session because the signature - # of render_template_fields is defined by BaseOperator and there are already many subclasses - # overriding it, so changing the signature is not an option. 
However render_template_fields is - # always executed within "_run_raw_task" so we make sure that _run_raw_task uses the - # set_current_task_session context manager to store the session in the current task. - session = get_current_task_instance_session() - - mapped_kwargs, seen_oids = self._expand_mapped_kwargs(context, session, include_xcom=True) + mapped_kwargs, seen_oids = self._expand_mapped_kwargs(context) unmapped_task = self.unmap(mapped_kwargs) - # TODO: Task-SDK: remove arg-type ignore once Kaxil's PR lands - context_update_for_unmapped(context, unmapped_task) # type: ignore[arg-type] + context_update_for_unmapped(context, unmapped_task) # Since the operators that extend `BaseOperator` are not subclasses of # `MappedOperator`, we need to call `_do_render_template_fields` from diff --git a/task_sdk/src/airflow/sdk/definitions/param.py b/task_sdk/src/airflow/sdk/definitions/param.py index cd3ccec26a48a..d9eec82a147fc 100644 --- a/task_sdk/src/airflow/sdk/definitions/param.py +++ b/task_sdk/src/airflow/sdk/definitions/param.py @@ -67,8 +67,7 @@ def _check_json(value): json.dumps(value) except Exception: raise ParamValidationError( - "All provided parameters must be json-serializable. " - f"The value '{value}' is not serializable." + f"All provided parameters must be json-serializable. The value '{value}' is not serializable." ) def resolve(self, value: Any = NOTSET, suppress_exception: bool = False) -> Any: @@ -294,7 +293,7 @@ def __init__(self, current_dag: DAG, name: str, default: Any = NOTSET): def iter_references(self) -> Iterable[tuple[Operator, str]]: return () - def resolve(self, context: Context, *, include_xcom: bool = True) -> Any: + def resolve(self, context: Context) -> Any: """Pull DagParam value from DagRun context. This method is run during ``op.execute()``.""" with contextlib.suppress(KeyError): if context["dag_run"].conf: diff --git a/task_sdk/src/airflow/sdk/definitions/xcom_arg.py b/task_sdk/src/airflow/sdk/definitions/xcom_arg.py index 436cd9d005012..7a2130a784c89 100644 --- a/task_sdk/src/airflow/sdk/definitions/xcom_arg.py +++ b/task_sdk/src/airflow/sdk/definitions/xcom_arg.py @@ -20,30 +20,27 @@ import contextlib import inspect import itertools -from collections.abc import Iterable, Iterator, Mapping, Sequence -from typing import TYPE_CHECKING, Any, Callable, Union, overload +from collections.abc import Iterable, Iterator, Mapping, Sequence, Sized +from functools import singledispatch +from typing import TYPE_CHECKING, Any, Callable, overload from airflow.exceptions import AirflowException, XComNotFound from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator from airflow.sdk.definitions._internal.mixins import DependencyMixin, ResolveMixin from airflow.sdk.definitions._internal.types import NOTSET, ArgNotSet -from airflow.utils.session import NEW_SESSION, provide_session from airflow.utils.setup_teardown import SetupTeardownContext from airflow.utils.trigger_rule import TriggerRule from airflow.utils.xcom import XCOM_RETURN_KEY if TYPE_CHECKING: - from sqlalchemy.orm import Session - from airflow.sdk.definitions.baseoperator import BaseOperator - from airflow.sdk.definitions.dag import DAG from airflow.sdk.types import Operator from airflow.utils.edgemodifier import EdgeModifier # Callable objects contained by MapXComArg. We only accept callables from # the user, but deserialize them into strings in a serialized XComArg for # safety (those callables are arbitrary user code). 
-MapCallables = Sequence[Union[Callable[[Any], Any], str]] +MapCallables = Sequence[Callable[[Any], Any]] class XComArg(ResolveMixin, DependencyMixin): @@ -168,20 +165,6 @@ def _serialize(self) -> dict[str, Any]: """ raise NotImplementedError() - @classmethod - def _deserialize(cls, data: dict[str, Any], dag: DAG) -> XComArg: - """ - Deserialize an XComArg. - - The implementation should be the inverse function to ``serialize``, - implementing given a data dict converted from this XComArg derivative, - how the original XComArg should be created. DAG serialization relies on - additional information added in ``serialize_xcom_arg`` to dispatch data - dicts to the correct ``_deserialize`` information, so this function does - not need to validate whether the incoming data contains correct keys. - """ - raise NotImplementedError() - def map(self, f: Callable[[Any], Any]) -> MapXComArg: return MapXComArg(self, [f]) @@ -191,9 +174,7 @@ def zip(self, *others: XComArg, fillvalue: Any = NOTSET) -> ZipXComArg: def concat(self, *others: XComArg) -> ConcatXComArg: return ConcatXComArg([self, *others]) - def resolve( - self, context: Mapping[str, Any], session: Session | None = None, *, include_xcom: bool = True - ) -> Any: + def resolve(self, context: Mapping[str, Any]) -> Any: raise NotImplementedError() def __enter__(self): @@ -266,7 +247,7 @@ def __str__(self) -> str: **Example**: to use XComArg at BashOperator:: - BashOperator(cmd=f"... { xcomarg } ...") + BashOperator(cmd=f"... {xcomarg} ...") :return: """ @@ -285,10 +266,6 @@ def __str__(self) -> str: def _serialize(self) -> dict[str, Any]: return {"task_id": self.operator.task_id, "key": self.key} - @classmethod - def _deserialize(cls, data: dict[str, Any], dag: DAG) -> XComArg: - return cls(dag.get_task(data["task_id"]), data["key"]) - @property def is_setup(self) -> bool: return self.operator.is_setup @@ -354,10 +331,7 @@ def concat(self, *others: XComArg) -> ConcatXComArg: raise ValueError("cannot concatenate non-return XCom") return super().concat(*others) - # TODO: Task-SDK: Remove session argument once everything is ported over to Task SDK - def resolve( - self, context: Mapping[str, Any], session: Session | None = None, *, include_xcom: bool = True - ) -> Any: + def resolve(self, context: Mapping[str, Any]) -> Any: ti = context["ti"] task_id = self.operator.task_id map_indexes = context.get("_upstream_map_indexes", {}).get(task_id) @@ -448,12 +422,6 @@ def _serialize(self) -> dict[str, Any]: "callables": [inspect.getsource(c) if callable(c) else c for c in self.callables], } - @classmethod - def _deserialize(cls, data: dict[str, Any], dag: DAG) -> XComArg: - # We are deliberately NOT deserializing the callables. These are shown - # in the UI, and displaying a function object is useless. - return cls(deserialize_xcom_arg(data["arg"], dag), data["callables"]) - def iter_references(self) -> Iterator[tuple[Operator, str]]: yield from self.arg.iter_references() @@ -461,12 +429,8 @@ def map(self, f: Callable[[Any], Any]) -> MapXComArg: # Flatten arg.map(f1).map(f2) into one MapXComArg. 
return MapXComArg(self.arg, [*self.callables, f]) - # TODO: Task-SDK: Remove session argument once everything is ported over to Task SDK - @provide_session - def resolve( - self, context: Mapping[str, Any], session: Session = NEW_SESSION, *, include_xcom: bool = True - ) -> Any: - value = self.arg.resolve(context, session=session, include_xcom=include_xcom) + def resolve(self, context: Mapping[str, Any]) -> Any: + value = self.arg.resolve(context) if not isinstance(value, (Sequence, dict)): raise ValueError(f"XCom map expects sequence or dict, not {type(value).__name__}") return _MapResult(value, self.callables) @@ -525,22 +489,12 @@ def _serialize(self) -> dict[str, Any]: return {"args": args} return {"args": args, "fillvalue": self.fillvalue} - @classmethod - def _deserialize(cls, data: dict[str, Any], dag: DAG) -> XComArg: - return cls( - [deserialize_xcom_arg(arg, dag) for arg in data["args"]], - fillvalue=data.get("fillvalue", NOTSET), - ) - def iter_references(self) -> Iterator[tuple[Operator, str]]: for arg in self.args: yield from arg.iter_references() - @provide_session - def resolve( - self, context: Mapping[str, Any], session: Session = NEW_SESSION, *, include_xcom: bool = True - ) -> Any: - values = [arg.resolve(context, session=session, include_xcom=include_xcom) for arg in self.args] + def resolve(self, context: Mapping[str, Any]) -> Any: + values = [arg.resolve(context) for arg in self.args] for value in values: if not isinstance(value, (Sequence, dict)): raise ValueError(f"XCom zip expects sequence or dict, not {type(value).__name__}") @@ -594,10 +548,6 @@ def __repr__(self) -> str: def _serialize(self) -> dict[str, Any]: return {"args": [serialize_xcom_arg(arg) for arg in self.args]} - @classmethod - def _deserialize(cls, data: dict[str, Any], dag: DAG) -> XComArg: - return cls([deserialize_xcom_arg(arg, dag) for arg in data["args"]]) - def iter_references(self) -> Iterator[tuple[Operator, str]]: for arg in self.args: yield from arg.iter_references() @@ -606,11 +556,8 @@ def concat(self, *others: XComArg) -> ConcatXComArg: # Flatten foo.concat(x).concat(y) into one call. return ConcatXComArg([*self.args, *others]) - @provide_session - def resolve( - self, context: Mapping[str, Any], session: Session = NEW_SESSION, *, include_xcom: bool = True - ) -> Any: - values = [arg.resolve(context, session=session, include_xcom=include_xcom) for arg in self.args] + def resolve(self, context: Mapping[str, Any]) -> Any: + values = [arg.resolve(context) for arg in self.args] for value in values: if not isinstance(value, (Sequence, dict)): raise ValueError(f"XCom concat expects sequence or dict, not {type(value).__name__}") @@ -633,7 +580,44 @@ def serialize_xcom_arg(value: XComArg) -> dict[str, Any]: return value._serialize() -def deserialize_xcom_arg(data: dict[str, Any], dag: DAG) -> XComArg: - """DAG serialization interface.""" - klass = _XCOM_ARG_TYPES[data.get("type", "")] - return klass._deserialize(data, dag) +@singledispatch +def get_task_map_length( + xcom_arg: XComArg, resolved_val: Sized, upstream_map_indexes: dict[str, int] +) -> int | None: + # The base implementation -- specific XComArg subclasses have specialised implementations + raise NotImplementedError() + + +@get_task_map_length.register +def _(xcom_arg: PlainXComArg, resolved_val: Sized, upstream_map_indexes: dict[str, int]): + task_id = xcom_arg.operator.task_id + + if xcom_arg.operator.is_mapped: + # TODO: How to tell if all the upstream TIs finished? 
+ pass + return (upstream_map_indexes.get(task_id) or 1) * len(resolved_val) + + +@get_task_map_length.register +def _(xcom_arg: MapXComArg, resolved_val: Sized, upstream_map_indexes: dict[str, int]): + return get_task_map_length(xcom_arg.arg, resolved_val, upstream_map_indexes) + + +@get_task_map_length.register +def _(xcom_arg: ZipXComArg, resolved_val: Sized, upstream_map_indexes: dict[str, int]): + all_lengths = (get_task_map_length(arg, resolved_val, upstream_map_indexes) for arg in xcom_arg.args) + ready_lengths = [length for length in all_lengths if length is not None] + if len(ready_lengths) != len(xcom_arg.args): + return None # If any of the referenced XComs is not ready, we are not ready either. + if isinstance(xcom_arg.fillvalue, ArgNotSet): + return min(ready_lengths) + return max(ready_lengths) + + +@get_task_map_length.register +def _(xcom_arg: ConcatXComArg, resolved_val: Sized, upstream_map_indexes: dict[str, int]): + all_lengths = (get_task_map_length(arg, resolved_val, upstream_map_indexes) for arg in xcom_arg.args) + ready_lengths = [length for length in all_lengths if length is not None] + if len(ready_lengths) != len(xcom_arg.args): + return None # If any of the referenced XComs is not ready, we are not ready either. + return sum(ready_lengths) diff --git a/task_sdk/src/airflow/sdk/execution_time/context.py b/task_sdk/src/airflow/sdk/execution_time/context.py index 984919ea1c86b..0674d27320a6c 100644 --- a/task_sdk/src/airflow/sdk/execution_time/context.py +++ b/task_sdk/src/airflow/sdk/execution_time/context.py @@ -42,6 +42,7 @@ if TYPE_CHECKING: from uuid import UUID + from airflow.sdk.definitions.baseoperator import BaseOperator from airflow.sdk.definitions.connection import Connection from airflow.sdk.definitions.context import Context from airflow.sdk.definitions.variable import Variable @@ -315,3 +316,21 @@ def set_current_context(context: Context) -> Generator[Context, None, None]: expected=context, got=expected_state, ) + + +def context_update_for_unmapped(context: Context, task: BaseOperator) -> None: + """ + Update context after task unmapping. + + Since ``get_template_context()`` is called before unmapping, the context + contains information about the mapped task. We need to do some in-place + updates to ensure the template context reflects the unmapped task instead. + + :meta private: + """ + from airflow.sdk.definitions.param import process_params + + context["task"] = context["ti"].task = task + context["params"] = process_params( + context["dag"], task, context["dag_run"].conf, suppress_exception=False + ) diff --git a/task_sdk/src/airflow/sdk/execution_time/task_runner.py b/task_sdk/src/airflow/sdk/execution_time/task_runner.py index a0f6189b89b2d..d0d3063c5f08b 100644 --- a/task_sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task_sdk/src/airflow/sdk/execution_time/task_runner.py @@ -83,6 +83,16 @@ class RuntimeTaskInstance(TaskInstance): max_tries: int = 0 """The maximum number of retries for the task.""" + def __rich_repr__(self): + yield "id", self.id + yield "task_id", self.task_id + yield "dag_id", self.dag_id + yield "run_id", self.run_id + yield "max_tries", self.max_tries + yield "task", type(self.task) + + __rich_repr__.angular = True # type: ignore[attr-defined] + def get_template_context(self) -> Context: # TODO: Move this to `airflow.sdk.execution_time.context` # once we port the entire context logic from airflow/utils/context.py ? 
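A note on the `get_task_map_length` helpers added to `xcom_arg.py` earlier in this patch (a hedged reading of the new dispatch, not code from the patch): a zipped argument without a fillvalue is only as long as its shortest input, with a fillvalue it stretches to the longest, and `concat` sums its inputs. A plain-Python sketch of the same arithmetic, with the two upstream values invented for illustration::

    from itertools import zip_longest

    letters = ["a", "b", "c", "d"]  # an upstream XCom of length 4
    numbers = [1, 2]                # an upstream XCom of length 2

    # zip() without a fillvalue stops at the shortest input -> 2 expansions
    assert len(list(zip(letters, numbers))) == min(len(letters), len(numbers)) == 2

    # zip() with a fillvalue pads the shorter input -> 4 expansions
    assert len(list(zip_longest(letters, numbers, fillvalue=None))) == max(len(letters), len(numbers)) == 4

    # concat() chains its inputs -> 4 + 2 == 6 expansions
    assert len(letters + numbers) == sum([len(letters), len(numbers)]) == 6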
@@ -123,8 +133,8 @@ def get_template_context(self) -> Context: }, "conn": ConnectionAccessor(), } - if self._ti_context_from_server: - dag_run = self._ti_context_from_server.dag_run + if from_server := self._ti_context_from_server: + dag_run = from_server.dag_run logical_date = dag_run.logical_date ds = logical_date.strftime("%Y-%m-%d") @@ -160,6 +170,11 @@ def get_template_context(self) -> Context: } context.update(context_from_server) + if from_server.upstream_map_indexes is not None: + # We stash this in here for later use, but we purposefully don't want to document it's + # existence. Should this be a private attribute on RuntimeTI instead perhaps? + context["_upstream_map_indexes"] = from_server.upstream_map_indexes # type: ignore [typeddict-unknown-key] + return context def render_templates( @@ -190,10 +205,7 @@ def render_templates( # unmapped BaseOperator created by this function! This is because the # MappedOperator is useless for template rendering, and we need to be # able to access the unmapped task instead. - original_task.render_template_fields(context, jinja_env) - # TODO: Add support for rendering templates in the MappedOperator - # if isinstance(self.task, MappedOperator): - # self.task = context["ti"].task + self.task.render_template_fields(context, jinja_env) return original_task @@ -538,11 +550,11 @@ def run(ti: RuntimeTaskInstance, log: Logger): context = ti.get_template_context() with set_current_context(context): jinja_env = ti.task.dag.get_template_env() - ti.task = ti.render_templates(context=context, jinja_env=jinja_env) + ti.render_templates(context=context, jinja_env=jinja_env) # TODO: Get things from _execute_task_with_callbacks # - Pre Execute # etc - result = _execute_task(context, ti.task) + result = _execute_task(context, ti) _push_xcom_if_needed(result, ti) @@ -622,10 +634,11 @@ def run(ti: RuntimeTaskInstance, log: Logger): SUPERVISOR_COMMS.send_request(msg=msg, log=log) -def _execute_task(context: Context, task: BaseOperator): +def _execute_task(context: Context, ti: RuntimeTaskInstance): """Execute Task (optionally with a Timeout) and push Xcom results.""" from airflow.exceptions import AirflowTaskTimeout + task = ti.task if task.execution_timeout: # TODO: handle timeout in case of deferral from airflow.utils.timeout import timeout diff --git a/task_sdk/src/airflow/sdk/types.py b/task_sdk/src/airflow/sdk/types.py index fd02104fb2fc4..5c1a19d67dea1 100644 --- a/task_sdk/src/airflow/sdk/types.py +++ b/task_sdk/src/airflow/sdk/types.py @@ -54,7 +54,7 @@ class RuntimeTaskInstanceProtocol(Protocol): dag_id: str run_id: str try_number: int - map_index: int + map_index: int | None max_tries: int hostname: str | None = None diff --git a/task_sdk/tests/conftest.py b/task_sdk/tests/conftest.py index cc4bc4f96148a..ab00ed543b299 100644 --- a/task_sdk/tests/conftest.py +++ b/task_sdk/tests/conftest.py @@ -35,6 +35,9 @@ from structlog.typing import EventDict, WrappedLogger from airflow.sdk.api.datamodels._generated import TIRunContext + from airflow.sdk.definitions.baseoperator import BaseOperator + from airflow.sdk.execution_time.comms import StartupDetails + from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance @pytest.hookimpl() @@ -236,3 +239,133 @@ def mock_supervisor_comms(): "airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True ) as supervisor_comms: yield supervisor_comms + + +@pytest.fixture +def mocked_parse(spy_agency): + """ + Fixture to set up an inline DAG and use it in a stubbed `parse` function. 
Use this fixture if you
+    want to isolate and test `parse` or `run` logic without having to define a DAG file.
+
+    This fixture returns a helper function `set_dag` that:
+    1. Creates an inline DAG with the given `dag_id` and `task` (limited to one task)
+    2. Constructs a `RuntimeTaskInstance` based on the provided `StartupDetails` and task.
+    3. Stubs the `parse` function using `spy_agency`, to return the mocked `RuntimeTaskInstance`.
+
+    After adding the fixture in your test function signature, you can use it like this ::
+
+        mocked_parse(
+            StartupDetails(
+                ti=TaskInstance(id=uuid7(), task_id="hello", dag_id="super_basic_run", run_id="c", try_number=1),
+                file="",
+                requests_fd=0,
+            ),
+            "example_dag_id",
+            CustomOperator(task_id="hello"),
+        )
+    """
+
+    def set_dag(what: StartupDetails, dag_id: str, task: BaseOperator) -> RuntimeTaskInstance:
+        from airflow.sdk.definitions.dag import DAG
+        from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance, parse
+        from airflow.utils import timezone
+
+        if not task.has_dag():
+            dag = DAG(dag_id=dag_id, start_date=timezone.datetime(2024, 12, 3))
+            task.dag = dag
+            task = dag.task_dict[task.task_id]
+        if what.ti_context.dag_run.conf:
+            dag.params = what.ti_context.dag_run.conf  # type: ignore[assignment]
+        ti = RuntimeTaskInstance.model_construct(
+            **what.ti.model_dump(exclude_unset=True),
+            task=task,
+            _ti_context_from_server=what.ti_context,
+            max_tries=what.ti_context.max_tries,
+        )
+        if hasattr(parse, "spy"):
+            spy_agency.unspy(parse)
+        spy_agency.spy_on(parse, call_fake=lambda _: ti)
+        return ti
+
+    return set_dag
+
+
+@pytest.fixture
+def create_runtime_ti(mocked_parse, make_ti_context):
+    """
+    Fixture to create a Runtime TaskInstance for testing purposes without defining a dag file.
+
+    It mimics the behavior of the `parse` function by creating a `RuntimeTaskInstance` based on the provided
+    `StartupDetails` (formed from arguments) and task. This allows you to test the logic of a task without
+    having to define a DAG file, parse it, get context from the server, etc.
+
+    Example usage: ::
+
+        def test_custom_task_instance(create_runtime_ti):
+            class MyTaskOperator(BaseOperator):
+                def execute(self, context):
+                    assert context["dag_run"].run_id == "test_run"
+
+            task = MyTaskOperator(task_id="test_task")
+            ti = create_runtime_ti(task, context_from_server=make_ti_context(run_id="test_run"))
+            # Further test logic...
+ """ + from uuid6 import uuid7 + + from airflow.sdk.api.datamodels._generated import TaskInstance + from airflow.sdk.execution_time.comms import BundleInfo, StartupDetails + + def _create_task_instance( + task: BaseOperator, + dag_id: str = "test_dag", + run_id: str = "test_run", + logical_date: str | datetime = "2024-12-01T01:00:00Z", + data_interval_start: str | datetime = "2024-12-01T00:00:00Z", + data_interval_end: str | datetime = "2024-12-01T01:00:00Z", + start_date: str | datetime = "2024-12-01T01:00:00Z", + run_type: str = "manual", + try_number: int = 1, + map_index: int | None = None, + upstream_map_indexes: dict[str, int] | None = None, + ti_id=None, + conf=None, + ) -> RuntimeTaskInstance: + if not ti_id: + ti_id = uuid7() + + if task.has_dag(): + dag_id = task.dag.dag_id + + ti_context = make_ti_context( + dag_id=dag_id, + run_id=run_id, + logical_date=logical_date, + data_interval_start=data_interval_start, + data_interval_end=data_interval_end, + start_date=start_date, + run_type=run_type, + conf=conf, + ) + + if upstream_map_indexes is not None: + ti_context.upstream_map_indexes = upstream_map_indexes + + startup_details = StartupDetails( + ti=TaskInstance( + id=ti_id, + task_id=task.task_id, + dag_id=dag_id, + run_id=run_id, + try_number=try_number, + map_index=map_index, + ), + dag_rel_path="", + bundle_info=BundleInfo(name="anything", version="any"), + requests_fd=0, + ti_context=ti_context, + ) + + ti = mocked_parse(startup_details, dag_id, task) + return ti + + return _create_task_instance diff --git a/task_sdk/tests/definitions/test_baseoperator.py b/task_sdk/tests/definitions/test_baseoperator.py index af6bf592f5373..35f33818dc198 100644 --- a/task_sdk/tests/definitions/test_baseoperator.py +++ b/task_sdk/tests/definitions/test_baseoperator.py @@ -621,26 +621,3 @@ def _do_render(): assert expected_log in caplog.text if not_expected_log: assert not_expected_log not in caplog.text - - -def test_find_mapped_dependants_in_another_group(): - from airflow.decorators import task as task_decorator - from airflow.sdk import TaskGroup - - @task_decorator - def gen(x): - return list(range(x)) - - @task_decorator - def add(x, y): - return x + y - - with DAG(dag_id="test"): - with TaskGroup(group_id="g1"): - gen_result = gen(3) - with TaskGroup(group_id="g2"): - add_result = add.partial(y=1).expand(x=gen_result) - - # breakpoint() - dependants = list(gen_result.operator.iter_mapped_dependants()) - assert dependants == [add_result.operator] diff --git a/task_sdk/tests/definitions/test_mappedoperator.py b/task_sdk/tests/definitions/test_mappedoperator.py index eeb79f31b4d47..588e0d7516cfa 100644 --- a/task_sdk/tests/definitions/test_mappedoperator.py +++ b/task_sdk/tests/definitions/test_mappedoperator.py @@ -17,16 +17,21 @@ # under the License. 
from __future__ import annotations +import json from datetime import datetime, timedelta +from typing import TYPE_CHECKING, Callable +from unittest import mock import pendulum import pytest +import structlog +from airflow.sdk.api.datamodels._generated import TerminalTIState from airflow.sdk.definitions.baseoperator import BaseOperator from airflow.sdk.definitions.dag import DAG from airflow.sdk.definitions.mappedoperator import MappedOperator -from airflow.sdk.definitions.param import ParamsDict from airflow.sdk.definitions.xcom_arg import XComArg +from airflow.sdk.execution_time.comms import GetXCom, SetXCom, SucceedTask, TaskState, XComResult from airflow.utils.trigger_rule import TriggerRule from tests_common.test_utils.mapping import expand_mapped_task # noqa: F401 @@ -120,9 +125,6 @@ def test_map_xcom_arg(): assert task1.downstream_list == [mapped] -# def test_map_xcom_arg_multiple_upstream_xcoms(dag_maker, session): - - def test_partial_on_instance() -> None: """`.partial` on an instance should fail -- it's only designed to be called on classes""" with pytest.raises(TypeError): @@ -154,15 +156,6 @@ def test_partial_on_invalid_pool_slots_raises() -> None: MockOperator.partial(task_id="pool_slots_test", pool="test", pool_slots="a").expand(arg1=[1, 2, 3]) -# def test_expand_mapped_task_instance(dag_maker, session, num_existing_tis, expected): - - -# def test_expand_mapped_task_failed_state_in_db(dag_maker, session): - - -# def test_expand_mapped_task_instance_skipped_on_zero(dag_maker, session): - - def test_mapped_task_applies_default_args_classic(): with DAG("test", default_args={"execution_timeout": timedelta(minutes=30)}) as dag: MockOperator(task_id="simple", arg1=None, arg2=0) @@ -191,43 +184,120 @@ def mapped(arg): @pytest.mark.parametrize( - "dag_params, task_params, expected_partial_params", + ("callable", "expected"), [ - pytest.param(None, None, ParamsDict(), id="none"), - pytest.param({"a": -1}, None, ParamsDict({"a": -1}), id="dag"), - pytest.param(None, {"b": -2}, ParamsDict({"b": -2}), id="task"), - pytest.param({"a": -1}, {"b": -2}, ParamsDict({"a": -1, "b": -2}), id="merge"), + pytest.param( + lambda partial, output1: partial.expand( + map_template=output1, map_static=output1, file_template=["/path/to/file.ext"] + ), + # Note to the next person to come across this. In #32272 we changed expand_kwargs so that it + # resolves the mapped template when it's in `expand_kwargs()`, but we _didn't_ do the same for + # things in `expand()`. This feels like a bug to me (ashb) but I am not changing that now, I have + # just moved and parametrized this test. 
+ "{{ ds }}", + id="expand", + ), + pytest.param( + lambda partial, output1: partial.expand_kwargs( + [{"map_template": "{{ ds }}", "map_static": "{{ ds }}", "file_template": "/path/to/file.ext"}] + ), + "2024-12-01", + id="expand_kwargs", + ), ], ) -def test_mapped_expand_against_params(dag_params, task_params, expected_partial_params): - with DAG("test", params=dag_params) as dag: - MockOperator.partial(task_id="t", params=task_params).expand(params=[{"c": "x"}, {"d": 1}]) - - t = dag.get_task("t") - assert isinstance(t, MappedOperator) - assert t.params == expected_partial_params - assert t.expand_input.value == {"params": [{"c": "x"}, {"d": 1}]} - +def test_mapped_render_template_fields_validating_operator( + tmp_path, create_runtime_ti, mock_supervisor_comms, callable, expected: bool +): + file_template_dir = tmp_path / "path" / "to" + file_template_dir.mkdir(parents=True, exist_ok=True) + file_template = file_template_dir / "file.ext" + file_template.write_text("loaded data") + + class MyOperator(BaseOperator): + template_fields = ("partial_template", "map_template", "file_template") + template_ext = (".ext",) + + def __init__( + self, partial_template, partial_static, map_template, map_static, file_template, **kwargs + ): + for value in [partial_template, partial_static, map_template, map_static, file_template]: + assert isinstance(value, str), "value should have been resolved before unmapping" + super().__init__(**kwargs) + self.partial_template = partial_template + self.partial_static = partial_static + self.map_template = map_template + self.map_static = map_static + self.file_template = file_template + + def execute(self, context): + pass -# def test_mapped_render_template_fields_validating_operator(dag_maker, session, tmp_path): + with DAG("test_dag", template_searchpath=tmp_path.__fspath__()): + task1 = BaseOperator(task_id="op1") + mapped = MyOperator.partial( + task_id="a", partial_template="{{ ti.task_id }}", partial_static="{{ ti.task_id }}" + ) + mapped = callable(mapped, task1.output) + mock_supervisor_comms.get_message.return_value = XComResult(key="return_value", value='["{{ ds }}"]') -# def test_mapped_expand_kwargs_render_template_fields_validating_operator(dag_maker, session, tmp_path): + mapped_ti = create_runtime_ti(task=mapped, map_index=0, upstream_map_indexes={task1.task_id: 1}) + assert isinstance(mapped_ti.task, MappedOperator) + mapped_ti.task.render_template_fields(context=mapped_ti.get_template_context()) + assert isinstance(mapped_ti.task, MyOperator) -# def test_mapped_render_nested_template_fields(dag_maker, session): + assert mapped_ti.task.partial_template == "a", "Should be templated!" + assert mapped_ti.task.partial_static == "{{ ti.task_id }}", "Should not be templated!" + assert mapped_ti.task.map_template == expected + assert mapped_ti.task.map_static == "{{ ds }}", "Should not be templated!" + assert mapped_ti.task.file_template == "loaded data", "Should be templated!" 
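For context on why `file_template` ends up as "loaded data" in the test above: a templated field whose value ends with one of the operator's `template_ext` extensions is treated as a path, looked up on the DAG's `template_searchpath`, and the file contents are rendered instead of the literal string. A minimal plain-Jinja sketch of that lookup, independent of Airflow (the `render_field` helper and temp directory are invented for the illustration)::

    import tempfile
    from pathlib import Path

    import jinja2

    searchpath = Path(tempfile.mkdtemp())
    (searchpath / "file.ext").write_text("loaded data")

    env = jinja2.Environment(loader=jinja2.FileSystemLoader(searchpath))
    template_ext = (".ext",)

    def render_field(value: str, context: dict) -> str:
        # Values ending in a known template extension are loaded from disk first.
        if value.endswith(template_ext):
            return env.get_template(value).render(context)
        return env.from_string(value).render(context)

    assert render_field("file.ext", {}) == "loaded data"
    assert render_field("{{ ds }}", {"ds": "2024-12-01"}) == "2024-12-01"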
-# def test_expand_kwargs_mapped_task_instance(dag_maker, session, num_existing_tis, expected): +def test_mapped_render_nested_template_fields(create_runtime_ti, mock_supervisor_comms): + with DAG("test_dag"): + mapped = MockOperatorWithNestedFields.partial( + task_id="t", arg2=NestedFields(field_1="{{ ti.task_id }}", field_2="value_2") + ).expand(arg1=["{{ ti.task_id }}", ["s", "{{ ti.task_id }}"]]) + ti = create_runtime_ti(task=mapped, map_index=0, upstream_map_indexes={}) + ti.task.render_template_fields(context=ti.get_template_context()) + assert ti.task.arg1 == "t" + assert ti.task.arg2.field_1 == "t" + assert ti.task.arg2.field_2 == "value_2" -# def test_expand_mapped_task_instance_with_named_index( + ti = create_runtime_ti(task=mapped, map_index=1, upstream_map_indexes={}) + ti.task.render_template_fields(context=ti.get_template_context()) + assert ti.task.arg1 == ["s", "t"] + assert ti.task.arg2.field_1 == "t" + assert ti.task.arg2.field_2 == "value_2" -# def test_expand_mapped_task_task_instance_mutation_hook(dag_maker, session, create_mapped_task) -> None: +@pytest.mark.parametrize( + ("map_index", "expected"), + [ + pytest.param(0, "2024-12-01", id="0"), + pytest.param(1, 2, id="1"), + ], +) +def test_expand_kwargs_render_template_fields_validating_operator( + map_index, expected, create_runtime_ti, mock_supervisor_comms +): + with DAG("test_dag"): + task1 = BaseOperator(task_id="op1") + mapped = MockOperator.partial(task_id="a", arg2="{{ ti.task_id }}").expand_kwargs(task1.output) + mock_supervisor_comms.get_message.return_value = XComResult( + key="return_value", value=json.dumps([{"arg1": "{{ ds }}"}, {"arg1": 2}]) + ) -# def test_expand_kwargs_render_template_fields_validating_operator(dag_maker, session, map_index, expected): + ti = create_runtime_ti(task=mapped, map_index=map_index, upstream_map_indexes={}) + assert isinstance(ti.task, MappedOperator) + ti.task.render_template_fields(context=ti.get_template_context()) + assert isinstance(ti.task, MockOperator) + assert ti.task.arg1 == expected + assert ti.task.arg2 == "a" def test_xcomarg_property_of_mapped_operator(): @@ -252,7 +322,29 @@ def test_set_xcomarg_dependencies_with_mapped_operator(): assert op2 in op5.upstream_list -# def test_all_xcomargs_from_mapped_tasks_are_consumable(dag_maker, session): +def test_all_xcomargs_from_mapped_tasks_are_consumable(create_runtime_ti, mock_supervisor_comms): + with DAG("test_all_xcomargs_from_mapped_tasks_are_consumable") as dag: + op1 = MockOperator.partial(task_id="op1").expand(arg1=[1, 2, 3]) + downstream = MockOperator(task_id="op2", arg1=op1.output) + + mock_supervisor_comms.get_message.return_value = XComResult( + key="return_value", value=json.dumps([1, 2, 3]) + ) + + ti = create_runtime_ti(task=downstream, upstream_map_indexes={"op1": None}) + ti.task.render_template_fields(context=ti.get_template_context()) + mock_supervisor_comms.send_request.assert_called_once_with( + log=mock.ANY, + msg=GetXCom( + key="return_value", + dag_id=dag.dag_id, + run_id="test_run", + task_id=op1.task_id, + map_index=None, + ), + ) + + assert ti.task.arg1 == [1, 2, 3] def test_task_mapping_with_task_group_context(): @@ -299,3 +391,324 @@ def test_task_mapping_with_explicit_task_group(): assert finish.upstream_list == [mapped] assert mapped.downstream_list == [finish] + + +RunTI = Callable[[DAG, str, int], TerminalTIState] + + +@pytest.fixture +def run_ti(create_runtime_ti, mock_supervisor_comms) -> RunTI: + def run(dag: DAG, task_id: str, map_index: int): + """Run the task and return the 
state that the SDK sent as the result for easier asserts""" + from airflow.sdk.execution_time.task_runner import run + + log = structlog.get_logger() + mock_supervisor_comms.send_request.reset_mock() + ti = create_runtime_ti(dag.task_dict[task_id], map_index=map_index) + run(ti, log) + for call in mock_supervisor_comms.send_request.mock_calls: + msg = call.kwargs["msg"] + if isinstance(msg, (TaskState, SucceedTask)): + return msg.state + raise RuntimeError("Unable to find call to TaskState") + + return run + + +def test_map_cross_product(run_ti: RunTI, mock_supervisor_comms): + outputs = [] + + with DAG(dag_id="cross_product") as dag: + + @dag.task + def emit_numbers(): + return [1, 2] + + @dag.task + def emit_letters(): + return {"a": "x", "b": "y", "c": "z"} + + @dag.task + def show(number, letter): + outputs.append((number, letter)) + + show.expand(number=emit_numbers(), letter=emit_letters()) + + def xcom_get(): + # TODO: Tidy this after #45927 is reopened and fixed properly + last_request = mock_supervisor_comms.send_request.mock_calls[-1].kwargs["msg"] + if not isinstance(last_request, GetXCom): + return mock.DEFAULT + task = dag.get_task(last_request.task_id) + value = json.dumps(task.python_callable()) + return XComResult(key="return_value", value=value) + + mock_supervisor_comms.get_message.side_effect = xcom_get + + states = [run_ti(dag, "show", map_index) for map_index in range(6)] + assert states == [TerminalTIState.SUCCESS] * 6 + assert outputs == [ + (1, ("a", "x")), + (1, ("b", "y")), + (1, ("c", "z")), + (2, ("a", "x")), + (2, ("b", "y")), + (2, ("c", "z")), + ] + + +def test_map_product_same(run_ti: RunTI, mock_supervisor_comms): + """Test a mapped task can refer to the same source multiple times.""" + outputs = [] + + with DAG(dag_id="product_same") as dag: + + @dag.task + def emit_numbers(): + return [1, 2] + + @dag.task + def show(a, b): + outputs.append((a, b)) + + emit_task = emit_numbers() + show.expand(a=emit_task, b=emit_task) + + def xcom_get(): + # TODO: Tidy this after #45927 is reopened and fixed properly + last_request = mock_supervisor_comms.send_request.mock_calls[-1].kwargs["msg"] + if not isinstance(last_request, GetXCom): + return mock.DEFAULT + task = dag.get_task(last_request.task_id) + value = json.dumps(task.python_callable()) + return XComResult(key="return_value", value=value) + + mock_supervisor_comms.get_message.side_effect = xcom_get + + states = [run_ti(dag, "show", map_index) for map_index in range(4)] + assert states == [TerminalTIState.SUCCESS] * 4 + assert outputs == [(1, 1), (1, 2), (2, 1), (2, 2)] + + +class NestedFields: + """Nested fields for testing purposes.""" + + def __init__(self, field_1, field_2): + self.field_1 = field_1 + self.field_2 = field_2 + + +class MockOperatorWithNestedFields(BaseOperator): + """Operator with nested fields for testing purposes.""" + + template_fields = ("arg1", "arg2") + + def __init__(self, arg1: str = "", arg2: NestedFields | None = None, **kwargs): + super().__init__(**kwargs) + self.arg1 = arg1 + self.arg2 = arg2 + + def _render_nested_template_fields(self, content, context, jinja_env, seen_oids) -> None: + if id(content) not in seen_oids: + template_fields: tuple | None = None + + if isinstance(content, NestedFields): + template_fields = ("field_1", "field_2") + + if template_fields: + seen_oids.add(id(content)) + self._do_render_template_fields(content, template_fields, context, jinja_env, seen_oids) + return + + super()._render_nested_template_fields(content, context, jinja_env, seen_oids) + + 
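The `seen_oids` set threaded through `_render_nested_template_fields` above is what keeps a nested object from being rendered more than once when it is reachable from several template fields. A stripped-down illustration of that `id()`-based guard, outside Airflow (the `Nested` class and `render_fields` helper are invented for the example)::

    import jinja2

    env = jinja2.Environment()

    class Nested:
        def __init__(self, field_1, field_2):
            self.field_1 = field_1
            self.field_2 = field_2

    def render_fields(obj, field_names, context, seen_oids):
        if id(obj) in seen_oids:
            return  # already handled via another attribute pointing at the same object
        seen_oids.add(id(obj))
        for name in field_names:
            value = getattr(obj, name)
            if isinstance(value, str):
                setattr(obj, name, env.from_string(value).render(context))
            elif isinstance(value, Nested):
                render_fields(value, ("field_1", "field_2"), context, seen_oids)

    shared = Nested("{{ task_id }}", "value_2")
    holder = Nested(shared, shared)  # the same nested object hangs off both fields
    render_fields(holder, ("field_1", "field_2"), {"task_id": "t"}, set())
    assert shared.field_1 == "t" and shared.field_2 == "value_2"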
+def test_find_mapped_dependants_in_another_group(): + from airflow.decorators import task as task_decorator + from airflow.sdk import TaskGroup + + @task_decorator + def gen(x): + return list(range(x)) + + @task_decorator + def add(x, y): + return x + y + + with DAG(dag_id="test"): + with TaskGroup(group_id="g1"): + gen_result = gen(3) + with TaskGroup(group_id="g2"): + add_result = add.partial(y=1).expand(x=gen_result) + + dependants = list(gen_result.operator.iter_mapped_dependants()) + assert dependants == [add_result.operator] + + +@pytest.mark.parametrize( + "partial_params, mapped_params, expected", + [ + pytest.param(None, [{"a": 1}], [{"a": 1}], id="simple"), + pytest.param({"b": 2}, [{"a": 1}], [{"a": 1, "b": 2}], id="merge"), + pytest.param({"b": 2}, [{"a": 1, "b": 3}], [{"a": 1, "b": 3}], id="override"), + pytest.param({"b": 2}, [{"a": 1, "b": 3}, {"b": 1}], [{"a": 1, "b": 3}, {"b": 1}], id="multiple"), + ], +) +def test_mapped_expand_against_params(create_runtime_ti, partial_params, mapped_params, expected): + with DAG("test"): + task = BaseOperator.partial(task_id="t", params=partial_params).expand(params=mapped_params) + + for map_index, expansion in enumerate(expected): + mapped_ti = create_runtime_ti(task=task, map_index=map_index) + mapped_ti.task.render_template_fields(context=mapped_ti.get_template_context()) + assert mapped_ti.task.params == expansion + + +def test_operator_mapped_task_group_receives_value(create_runtime_ti, mock_supervisor_comms): + # Test the runtime expansion behaviour of mapped task groups + mapped operators + results = {} + + from airflow.decorators import task_group + + with DAG("test") as dag: + + @dag.task + def t(value, *, ti=None): + results[(ti.task_id, ti.map_index)] = value + return value + + @task_group + def tg(va): + # Each expanded group has one t1 and t2 each. + t1 = t.override(task_id="t1")(va) + t2 = t.override(task_id="t2")(t1) + + with pytest.raises(NotImplementedError) as ctx: + t.override(task_id="t4").expand(value=va) + assert str(ctx.value) == "operator expansion in an expanded task group is not yet supported" + + return t2 + + # The group is mapped by 3. + t2 = tg.expand(va=[["a", "b"], [4], ["z"]]) + + # Aggregates results from task group. + t.override(task_id="t3")(t2) + + def xcom_get(): + # TODO: Tidy this after #45927 is reopened and fixed properly + last_request = mock_supervisor_comms.send_request.mock_calls[-1].kwargs["msg"] + if not isinstance(last_request, GetXCom): + return mock.DEFAULT + key = (last_request.task_id, last_request.map_index) + if key in expected_values: + value = expected_values[key] + return XComResult(key="return_value", value=json.dumps(value)) + elif last_request.map_index is None: + # Get all mapped XComValues for this ti + value = [v for k, v in expected_values.items() if k[0] == last_request.task_id] + return XComResult(key="return_value", value=json.dumps(value)) + return mock.DEFAULT + + mock_supervisor_comms.get_message.side_effect = xcom_get + + expected_values = { + ("tg.t1", 0): ["a", "b"], + ("tg.t1", 1): [4], + ("tg.t1", 2): ["z"], + ("tg.t2", 0): ["a", "b"], + ("tg.t2", 1): [4], + ("tg.t2", 2): ["z"], + ("t3", None): [["a", "b"], [4], ["z"]], + } + + # We hard-code the number of expansions here as the server is in charge of that. 
+ expansion_per_task_id = { + "tg.t1": range(3), + "tg.t2": range(3), + "t3": [None], + } + for task in dag.tasks: + for map_index in expansion_per_task_id[task.task_id]: + mapped_ti = create_runtime_ti(task=task.prepare_for_execution(), map_index=map_index) + context = mapped_ti.get_template_context() + mapped_ti.task.render_template_fields(context) + mapped_ti.task.execute(context) + assert results == expected_values + + +@pytest.mark.xfail(reason="SkipMixin hasn't been ported over to use the Task Execution API yet") +def test_mapped_xcom_push_skipped_tasks(create_runtime_ti, mock_supervisor_comms): + from airflow.decorators import task_group + from airflow.operators.empty import EmptyOperator + + if TYPE_CHECKING: + from airflow.providers.standard.operators.python import ShortCircuitOperator + else: + ShortCircuitOperator = pytest.importorskip( + "airflow.providers.standard.operators.python" + ).ShortCircuitOperator + + with DAG("test") as dag: + + @task_group + def group(x): + short_op_push_xcom = ShortCircuitOperator( + task_id="push_xcom_from_shortcircuit", + python_callable=lambda arg: arg % 2 == 0, + op_kwargs={"arg": x}, + ) + empty_task = EmptyOperator(task_id="empty_task") + short_op_push_xcom >> empty_task + + group.expand(x=[0, 1]) + + for task in dag.tasks: + for map_index in range(2): + ti = create_runtime_ti(task=task.prepare_for_execution(), map_index=map_index) + context = ti.get_template_context() + ti.task.render_template_fields(context) + ti.task.execute(context) + + assert ti + # TODO: these tests might not be right + mock_supervisor_comms.send_request.assert_has_calls( + [ + SetXCom( + key="skipmixin_key", + value=None, + dag_id=ti.dag_id, + run_id=ti.run_id, + task_id="group.push_xcom_from_shortcircuit", + map_index=0, + ), + SetXCom( + key="return_value", + value=True, + dag_id=ti.dag_id, + run_id=ti.run_id, + task_id="group.push_xcom_from_shortcircuit", + map_index=0, + ), + SetXCom( + key="skipmixin_key", + value={"skipped": ["group.empty_task"]}, + dag_id=ti.dag_id, + run_id=ti.run_id, + task_id="group.push_xcom_from_shortcircuit", + map_index=1, + ), + ] + ) + # + # assert ( + # tis[0].xcom_pull(task_ids="group.push_xcom_from_shortcircuit", key="return_value", map_indexes=0) + # is True + # ) + # assert ( + # tis[0].xcom_pull(task_ids="group.push_xcom_from_shortcircuit", key="skipmixin_key", map_indexes=0) + # is None + # ) + # assert tis[0].xcom_pull( + # task_ids="group.push_xcom_from_shortcircuit", key="skipmixin_key", map_indexes=1 + # ) == {"skipped": ["group.empty_task"]} diff --git a/task_sdk/tests/definitions/test_xcom_arg.py b/task_sdk/tests/definitions/test_xcom_arg.py new file mode 100644 index 0000000000000..d4964cf250b03 --- /dev/null +++ b/task_sdk/tests/definitions/test_xcom_arg.py @@ -0,0 +1,375 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import json +from typing import Callable +from unittest import mock + +import pytest +import structlog +from pytest_unordered import unordered + +from airflow.exceptions import AirflowSkipException +from airflow.sdk.api.datamodels._generated import TerminalTIState +from airflow.sdk.definitions.dag import DAG +from airflow.sdk.execution_time.comms import GetXCom, SucceedTask, TaskState, XComResult + +log = structlog.get_logger() + +RunTI = Callable[[DAG, str, int], TerminalTIState] + + +@pytest.fixture +def run_ti(create_runtime_ti, mock_supervisor_comms) -> RunTI: + def run(dag: DAG, task_id: str, map_index: int): + """Run the task and return the state that the SDK sent as the result for easier asserts""" + from airflow.sdk.execution_time.task_runner import run + + mock_supervisor_comms.send_request.reset_mock() + ti = create_runtime_ti(dag.task_dict[task_id], map_index=map_index) + run(ti, log) + for call in mock_supervisor_comms.send_request.mock_calls: + msg = call.kwargs["msg"] + if isinstance(msg, (TaskState, SucceedTask)): + return msg.state + raise RuntimeError("Unable to find call to TaskState") + + return run + + +def test_xcom_map(run_ti: RunTI, mock_supervisor_comms): + results = set() + with DAG("test") as dag: + + @dag.task + def push(): + return ["a", "b", "c"] + + @dag.task + def pull(value): + results.add(value) + + pull.expand_kwargs(push().map(lambda v: {"value": v * 2})) + + # The function passed to "map" is *NOT* a task. + assert set(dag.task_dict) == {"push", "pull"} + + # Mock xcom result from push task + mock_supervisor_comms.get_message.return_value = XComResult(key="return_value", value='["a", "b", "c"]') + + for map_index in range(3): + assert run_ti(dag, "pull", map_index) == TerminalTIState.SUCCESS + + assert results == {"aa", "bb", "cc"} + + +def test_xcom_map_transform_to_none(run_ti: RunTI, mock_supervisor_comms): + results = set() + + with DAG("test") as dag: + + @dag.task() + def push(): + return ["a", "b", "c"] + + @dag.task() + def pull(value): + results.add(value) + + def c_to_none(v): + if v == "c": + return None + return v + + pull.expand(value=push().map(c_to_none)) + + # Mock xcom result from push task + mock_supervisor_comms.get_message.return_value = XComResult(key="return_value", value='["a", "b", "c"]') + + # Run "pull". This should automatically convert "c" to None. + for map_index in range(3): + assert run_ti(dag, "pull", map_index) == TerminalTIState.SUCCESS + + assert results == {"a", "b", None} + + +def test_xcom_convert_to_kwargs_fails_task(run_ti: RunTI, mock_supervisor_comms, captured_logs): + results = set() + + with DAG("test") as dag: + + @dag.task() + def push(): + return ["a", "b", "c"] + + @dag.task() + def pull(value): + results.add(value) + + def c_to_none(v): + if v == "c": + return None + return {"value": v} + + pull.expand_kwargs(push().map(c_to_none)) + + # Mock xcom result from push task + mock_supervisor_comms.get_message.return_value = XComResult(key="return_value", value='["a", "b", "c"]') + + # The first two "pull" tis should succeed. + for map_index in range(2): + assert run_ti(dag, "pull", map_index) == TerminalTIState.SUCCESS + + # But the third one fails because the map() result cannot be used as kwargs. 
+ assert run_ti(dag, "pull", 2) == TerminalTIState.FAILED + + assert captured_logs == unordered( + [ + { + "event": "Task failed with exception", + "level": "error", + "timestamp": mock.ANY, + "exception": [ + { + "exc_type": "ValueError", + "exc_value": "expand_kwargs() expects a list[dict], not list[None]", + "frames": mock.ANY, + "is_cause": False, + "syntax_error": None, + } + ], + }, + ] + ) + + +def test_xcom_map_error_fails_task(mock_supervisor_comms, run_ti, captured_logs): + with DAG("test") as dag: + + @dag.task() + def push(): + return ["a", "b", "c"] + + @dag.task() + def pull(value): + print(value) + + def does_not_work_with_c(v): + if v == "c": + raise RuntimeError("nope") + return {"value": v * 2} + + pull.expand_kwargs(push().map(does_not_work_with_c)) + + # Mock xcom result from push task + mock_supervisor_comms.get_message.return_value = XComResult(key="return_value", value='["a", "b", "c"]') + # The third one (for "c") will fail. + assert run_ti(dag, "pull", 2) == TerminalTIState.FAILED + + assert captured_logs == unordered( + [ + { + "event": "Task failed with exception", + "level": "error", + "timestamp": mock.ANY, + "exception": [ + { + "exc_type": "RuntimeError", + "exc_value": "nope", + "frames": mock.ANY, + "is_cause": False, + "syntax_error": None, + } + ], + }, + ] + ) + + +def test_xcom_map_nest(mock_supervisor_comms, run_ti): + results = set() + + with DAG("test") as dag: + + @dag.task() + def push(): + return ["a", "b", "c"] + + @dag.task() + def pull(value): + results.add(value) + + converted = push().map(lambda v: v * 2).map(lambda v: {"value": v}) + pull.expand_kwargs(converted) + + # Mock xcom result from push task + mock_supervisor_comms.get_message.return_value = XComResult(key="return_value", value='["a", "b", "c"]') + + # Now "pull" should apply the mapping functions in order. + for map_index in range(3): + assert run_ti(dag, "pull", map_index) == TerminalTIState.SUCCESS + assert results == {"aa", "bb", "cc"} + + +def test_xcom_map_zip_nest(mock_supervisor_comms, run_ti): + results = set() + + with DAG("test") as dag: + + @dag.task + def push_letters(): + return ["a", "b", "c", "d"] + + @dag.task + def push_numbers(): + return [1, 2, 3, 4] + + @dag.task + def pull(value): + results.add(value) + + doubled = push_numbers().map(lambda v: v * 2) + combined = doubled.zip(push_letters()) + + def convert_zipped(zipped): + letter, number = zipped + return letter * number + + pull.expand(value=combined.map(convert_zipped)) + + def xcom_get(): + # TODO: Tidy this after #45927 is reopened and fixed properly + last_request = mock_supervisor_comms.send_request.mock_calls[-1].kwargs["msg"] + if not isinstance(last_request, GetXCom): + return mock.DEFAULT + if last_request.task_id == "push_letters": + value = json.dumps(push_letters.function()) + return XComResult(key="return_value", value=value) + if last_request.task_id == "push_numbers": + value = json.dumps(push_numbers.function()) + return XComResult(key="return_value", value=value) + return mock.DEFAULT + + mock_supervisor_comms.get_message.side_effect = xcom_get + + # Run "pull". 
+ for map_index in range(4): + assert run_ti(dag, "pull", map_index) == TerminalTIState.SUCCESS + + assert results == {"aa", "bbbb", "cccccc", "dddddddd"} + + +def test_xcom_map_raise_to_skip(run_ti, mock_supervisor_comms): + result = [] + + with DAG("test") as dag: + + @dag.task() + def push(): + return ["a", "b", "c"] + + @dag.task() + def forward(value): + result.append(value) + + def skip_c(v): + if v == "c": + raise AirflowSkipException() + return {"value": v} + + forward.expand_kwargs(push().map(skip_c)) + + # Mock xcom result from push task + mock_supervisor_comms.get_message.return_value = XComResult(key="return_value", value='["a", "b", "c"]') + + # Run "forward". This should automatically skip "c". + states = [run_ti(dag, "forward", map_index) for map_index in range(3)] + + assert states == [TerminalTIState.SUCCESS, TerminalTIState.SUCCESS, TerminalTIState.SKIPPED] + + assert result == ["a", "b"] + + +def test_xcom_concat(run_ti, mock_supervisor_comms): + from airflow.sdk.definitions.xcom_arg import _ConcatResult + + agg_results = set() + all_results = None + + with DAG("test") as dag: + + @dag.task + def push_letters(): + return ["a", "b", "c"] + + @dag.task + def push_numbers(): + return [1, 2] + + @dag.task + def pull_one(value): + agg_results.add(value) + + @dag.task + def pull_all(value): + assert isinstance(value, _ConcatResult) + assert value[0] == "a" + assert value[1] == "b" + assert value[2] == "c" + assert value[3] == 1 + assert value[4] == 2 + with pytest.raises(IndexError): + value[5] + assert value[-5] == "a" + assert value[-4] == "b" + assert value[-3] == "c" + assert value[-2] == 1 + assert value[-1] == 2 + with pytest.raises(IndexError): + value[-6] + nonlocal all_results + all_results = list(value) + + pushed_values = push_letters().concat(push_numbers()) + + pull_one.expand(value=pushed_values) + pull_all(pushed_values) + + def xcom_get(): + # TODO: Tidy this after #45927 is reopened and fixed properly + last_request = mock_supervisor_comms.send_request.mock_calls[-1].kwargs["msg"] + if not isinstance(last_request, GetXCom): + return mock.DEFAULT + if last_request.task_id == "push_letters": + value = json.dumps(push_letters.function()) + return XComResult(key="return_value", value=value) + if last_request.task_id == "push_numbers": + value = json.dumps(push_numbers.function()) + return XComResult(key="return_value", value=value) + return mock.DEFAULT + + mock_supervisor_comms.get_message.side_effect = xcom_get + + # Run "pull_one" and "pull_all". 
+ assert run_ti(dag, "pull_all", None) == TerminalTIState.SUCCESS + assert all_results == ["a", "b", "c", 1, 2] + + states = [run_ti(dag, "pull_one", map_index) for map_index in range(5)] + assert states == [TerminalTIState.SUCCESS] * 5 + assert agg_results == {"a", "b", "c", 1, 2} diff --git a/task_sdk/tests/execution_time/conftest.py b/task_sdk/tests/execution_time/conftest.py index ac0c21246c1ce..4a537373363aa 100644 --- a/task_sdk/tests/execution_time/conftest.py +++ b/task_sdk/tests/execution_time/conftest.py @@ -18,14 +18,6 @@ from __future__ import annotations import sys -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from datetime import datetime - - from airflow.sdk.definitions.baseoperator import BaseOperator - from airflow.sdk.execution_time.comms import StartupDetails - from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance import pytest @@ -39,117 +31,3 @@ def disable_capturing(): sys.stderr = sys.__stderr__ yield sys.stdin, sys.stdout, sys.stderr = old_in, old_out, old_err - - -@pytest.fixture -def mocked_parse(spy_agency): - """ - Fixture to set up an inline DAG and use it in a stubbed `parse` function. Use this fixture if you - want to isolate and test `parse` or `run` logic without having to define a DAG file. - - This fixture returns a helper function `set_dag` that: - 1. Creates an in line DAG with the given `dag_id` and `task` (limited to one task) - 2. Constructs a `RuntimeTaskInstance` based on the provided `StartupDetails` and task. - 3. Stubs the `parse` function using `spy_agency`, to return the mocked `RuntimeTaskInstance`. - - After adding the fixture in your test function signature, you can use it like this :: - - mocked_parse( - StartupDetails( - ti=TaskInstance(id=uuid7(), task_id="hello", dag_id="super_basic_run", run_id="c", try_number=1), - file="", - requests_fd=0, - ), - "example_dag_id", - CustomOperator(task_id="hello"), - ) - """ - - def set_dag(what: StartupDetails, dag_id: str, task: BaseOperator) -> RuntimeTaskInstance: - from airflow.sdk.definitions.dag import DAG - from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance, parse - from airflow.utils import timezone - - dag = DAG(dag_id=dag_id, start_date=timezone.datetime(2024, 12, 3)) - if what.ti_context.dag_run.conf: - dag.params = what.ti_context.dag_run.conf # type: ignore[assignment] - task.dag = dag - t = dag.task_dict[task.task_id] - ti = RuntimeTaskInstance.model_construct( - **what.ti.model_dump(exclude_unset=True), - task=t, - _ti_context_from_server=what.ti_context, - max_tries=what.ti_context.max_tries, - ) - spy_agency.spy_on(parse, call_fake=lambda _: ti) - return ti - - return set_dag - - -@pytest.fixture -def create_runtime_ti(mocked_parse, make_ti_context): - """ - Fixture to create a Runtime TaskInstance for testing purposes without defining a dag file. - - It mimics the behavior of the `parse` function by creating a `RuntimeTaskInstance` based on the provided - `StartupDetails` (formed from arguments) and task. This allows you to test the logic of a task without - having to define a DAG file, parse it, get context from the server, etc. - - Example usage: :: - - def test_custom_task_instance(create_runtime_ti): - class MyTaskOperator(BaseOperator): - def execute(self, context): - assert context["dag_run"].run_id == "test_run" - - task = MyTaskOperator(task_id="test_task") - ti = create_runtime_ti(task, context_from_server=make_ti_context(run_id="test_run")) - # Further test logic... 
- """ - from uuid6 import uuid7 - - from airflow.sdk.api.datamodels._generated import TaskInstance - from airflow.sdk.execution_time.comms import BundleInfo, StartupDetails - - def _create_task_instance( - task: BaseOperator, - dag_id: str = "test_dag", - run_id: str = "test_run", - logical_date: str | datetime = "2024-12-01T01:00:00Z", - data_interval_start: str | datetime = "2024-12-01T00:00:00Z", - data_interval_end: str | datetime = "2024-12-01T01:00:00Z", - start_date: str | datetime = "2024-12-01T01:00:00Z", - run_type: str = "manual", - try_number: int = 1, - conf=None, - ti_id=None, - ) -> RuntimeTaskInstance: - if not ti_id: - ti_id = uuid7() - - ti_context = make_ti_context( - dag_id=dag_id, - run_id=run_id, - logical_date=logical_date, - data_interval_start=data_interval_start, - data_interval_end=data_interval_end, - start_date=start_date, - run_type=run_type, - conf=conf, - ) - - startup_details = StartupDetails( - ti=TaskInstance( - id=ti_id, task_id=task.task_id, dag_id=dag_id, run_id=run_id, try_number=try_number - ), - dag_rel_path="", - bundle_info=BundleInfo(name="anything", version="any"), - requests_fd=0, - ti_context=ti_context, - ) - - ti = mocked_parse(startup_details, dag_id, task) - return ti - - return _create_task_instance diff --git a/task_sdk/tests/execution_time/test_task_runner.py b/task_sdk/tests/execution_time/test_task_runner.py index 99e380754d4d0..8c280ce0634ee 100644 --- a/task_sdk/tests/execution_time/test_task_runner.py +++ b/task_sdk/tests/execution_time/test_task_runner.py @@ -1027,7 +1027,7 @@ def execute(self, context): dag_id="test_dag", run_id="test_run", task_id=task_id, - map_index=-1, + map_index=None, ), ) diff --git a/task_sdk/tests/test_log.py b/task_sdk/tests/test_log.py index bf00f33e9a7ec..338e672ef35bf 100644 --- a/task_sdk/tests/test_log.py +++ b/task_sdk/tests/test_log.py @@ -50,7 +50,7 @@ def test_json_rendering(captured_logs): assert isinstance(captured_logs[0], bytes) assert json.loads(captured_logs[0]) == { "event": "A test message with a Pydantic class", - "pydantic_class": "TaskInstance(id=UUID('ffec3c8e-2898-46f8-b7d5-3cc571577368'), task_id='test_task', dag_id='test_dag', run_id='test_run', try_number=1, map_index=-1, hostname=None)", + "pydantic_class": "TaskInstance(id=UUID('ffec3c8e-2898-46f8-b7d5-3cc571577368'), task_id='test_task', dag_id='test_dag', run_id='test_run', try_number=1, map_index=None, hostname=None)", "timestamp": unittest.mock.ANY, "level": "info", } diff --git a/tests/decorators/test_task_group.py b/tests/decorators/test_task_group.py index fb38c95759a26..971b274a80f71 100644 --- a/tests/decorators/test_task_group.py +++ b/tests/decorators/test_task_group.py @@ -215,6 +215,7 @@ def tg(a, b): @pytest.mark.db_test +@pytest.mark.need_serialized_dag def test_task_group_expand_kwargs_with_upstream(dag_maker, session, caplog): with dag_maker() as dag: @@ -239,6 +240,7 @@ def t2(): @pytest.mark.db_test +@pytest.mark.need_serialized_dag def test_task_group_expand_with_upstream(dag_maker, session, caplog): with dag_maker() as dag: diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py index 78fffd45fd098..affb2662aa834 100644 --- a/tests/models/test_dag.py +++ b/tests/models/test_dag.py @@ -150,13 +150,13 @@ def test_dags_bundle(configure_testing_dag_bundle): def _create_dagrun( dag: DAG, *, - logical_date: DateTime, + logical_date: DateTime | datetime.datetime, data_interval: DataInterval, run_type: DagRunType, state: DagRunState = DagRunState.RUNNING, start_date: datetime.datetime | None = None, + 
**kwargs, ) -> DagRun: - triggered_by_kwargs: dict = {"triggered_by": DagRunTriggeredByType.TEST} if AIRFLOW_V_3_0_PLUS else {} run_id = dag.timetable.generate_run_id( run_type=run_type, logical_date=logical_date, # type: ignore @@ -169,7 +169,8 @@ def _create_dagrun( run_type=run_type, state=state, start_date=start_date, - **triggered_by_kwargs, + triggered_by=DagRunTriggeredByType.TEST, + **kwargs, ) @@ -1539,15 +1540,14 @@ def consumer(value): PythonOperator.partial(task_id=task_id, python_callable=consumer).expand(op_args=make_arg_lists()) session = dag_maker.session - dagrun_1 = dag.create_dagrun( - run_id="backfill", + dagrun_1 = _create_dagrun( + dag, run_type=DagRunType.BACKFILL_JOB, - state=State.FAILED, + state=DagRunState.FAILED, start_date=DEFAULT_DATE, logical_date=DEFAULT_DATE, - session=session, data_interval=(DEFAULT_DATE, DEFAULT_DATE), - triggered_by=DagRunTriggeredByType.TEST, + session=session, ) # Get the (de)serialized MappedOperator mapped = dag.get_task(task_id) diff --git a/tests/models/test_dagrun.py b/tests/models/test_dagrun.py index b5afa6301f4a6..d12cef3a0c8cf 100644 --- a/tests/models/test_dagrun.py +++ b/tests/models/test_dagrun.py @@ -57,7 +57,7 @@ if AIRFLOW_V_3_0_PLUS: from airflow.utils.types import DagRunTriggeredByType -pytestmark = pytest.mark.db_test +pytestmark = [pytest.mark.db_test, pytest.mark.need_serialized_dag] if TYPE_CHECKING: @@ -329,7 +329,7 @@ def test_dagrun_deadlock(self, dag_maker, session): assert dr.state == DagRunState.RUNNING ti_op2.set_state(state=None, session=session) - op2.trigger_rule = "invalid" # type: ignore + ti_op2.task.trigger_rule = "invalid" # type: ignore dr.update_state(session=session) assert dr.state == DagRunState.FAILED @@ -738,7 +738,6 @@ def mutate_task_instance(task_instance): (None, False), ], ) - @pytest.mark.need_serialized_dag def test_depends_on_past(self, dag_maker, session, prev_ti_state, is_ti_schedulable): # DAG tests depends_on_past dependencies with dag_maker( @@ -781,7 +780,6 @@ def test_depends_on_past(self, dag_maker, session, prev_ti_state, is_ti_schedula (None, False), ], ) - @pytest.mark.need_serialized_dag def test_wait_for_downstream(self, dag_maker, session, prev_ti_state, is_ti_schedulable): dag_id = "test_wait_for_downstream" @@ -1069,7 +1067,6 @@ def test_expand_mapped_task_instance_at_create(is_noop, dag_maker, session): assert indices == [(0,), (1,), (2,), (3,)] -@pytest.mark.need_serialized_dag @pytest.mark.parametrize("is_noop", [True, False]) def test_expand_mapped_task_instance_task_decorator(is_noop, dag_maker, session): with mock.patch("airflow.settings.task_instance_mutation_hook") as mock_mut: @@ -1093,7 +1090,6 @@ def mynameis(arg): assert indices == [(0,), (1,), (2,), (3,)] -@pytest.mark.need_serialized_dag def test_mapped_literal_verify_integrity(dag_maker, session): """Test that when the length of a mapped literal changes we remove extra TIs""" @@ -1126,7 +1122,6 @@ def task_2(arg2): ... assert indices == [(0, None), (1, None), (2, TaskInstanceState.REMOVED), (3, TaskInstanceState.REMOVED)] -@pytest.mark.need_serialized_dag def test_mapped_literal_to_xcom_arg_verify_integrity(dag_maker, session): """Test that when we change from literal to a XComArg the TIs are removed""" @@ -1160,7 +1155,6 @@ def task_2(arg2): ... ] -@pytest.mark.need_serialized_dag def test_mapped_literal_length_increase_adds_additional_ti(dag_maker, session): """Test that when the length of mapped literal increases, additional ti is added""" @@ -1203,7 +1197,6 @@ def task_2(arg2): ... 
] -@pytest.mark.need_serialized_dag def test_mapped_literal_length_reduction_adds_removed_state(dag_maker, session): """Test that when the length of mapped literal reduces, removed state is added""" @@ -1244,7 +1237,6 @@ def task_2(arg2): ... ] -@pytest.mark.need_serialized_dag def test_mapped_length_increase_at_runtime_adds_additional_tis(dag_maker, session): """Test that when the length of mapped literal increases at runtime, additional ti is added""" # Variable.set(key="arg1", value=[1, 2, 3]) @@ -1296,7 +1288,6 @@ def task_2(arg2): ... ] -@pytest.mark.need_serialized_dag def test_mapped_literal_length_reduction_at_runtime_adds_removed_state(dag_maker, session): """ Test that when the length of mapped literal reduces at runtime, the missing task instances @@ -1384,7 +1375,6 @@ def task_2(arg2): ... assert len(decision.schedulable_tis) == 2 -@pytest.mark.need_serialized_dag def test_calls_to_verify_integrity_with_mapped_task_zero_length_at_runtime(dag_maker, session, caplog): """ Test zero length reduction in mapped task at runtime with calls to dagrun.verify_integrity @@ -1447,7 +1437,6 @@ def task_2(arg2): ... ) -@pytest.mark.need_serialized_dag def test_mapped_mixed_literal_not_expanded_at_create(dag_maker, session): literal = [1, 2, 3, 4] with dag_maker(session=session): @@ -1624,8 +1613,6 @@ def consumer(*args): def test_mapped_task_all_finish_before_downstream(dag_maker, session): - result = None - with dag_maker(session=session) as dag: @dag.task @@ -1638,8 +1625,8 @@ def double(value): @dag.task def consumer(value): - nonlocal result - result = list(value) + ... + # result = list(value) consumer(value=double.expand(value=make_list())) @@ -1653,26 +1640,29 @@ def _task_ids(tis): assert _task_ids(decision.schedulable_tis) == ["make_list"] # After make_list is run, double is expanded. - decision.schedulable_tis[0].run(verbose=False, session=session) + ti = decision.schedulable_tis[0] + ti.state = TaskInstanceState.SUCCESS + session.add(TaskMap.from_task_instance_xcom(ti, [1, 2])) + session.flush() + decision = dr.task_instance_scheduling_decisions(session=session) assert _task_ids(decision.schedulable_tis) == ["double", "double"] # Running just one of the mapped tis does not make downstream schedulable. - decision.schedulable_tis[0].run(verbose=False, session=session) + ti = decision.schedulable_tis[0] + ti.state = TaskInstanceState.SUCCESS + session.flush() + decision = dr.task_instance_scheduling_decisions(session=session) assert _task_ids(decision.schedulable_tis) == ["double"] - # Downstream is schedulable after all mapped tis are run. - decision.schedulable_tis[0].run(verbose=False, session=session) + # Downstream is schedulable after all mapped tis are run. + ti = decision.schedulable_tis[0] + ti.state = TaskInstanceState.SUCCESS + session.flush() decision = dr.task_instance_scheduling_decisions(session=session) assert _task_ids(decision.schedulable_tis) == ["consumer"] - # We should be able to get all values aggregated from mapped upstreams.
- decision.schedulable_tis[0].run(verbose=False, session=session) - decision = dr.task_instance_scheduling_decisions(session=session) - assert decision.schedulable_tis == [] - assert result == [2, 4] - def test_schedule_tis_map_index(dag_maker, session): with dag_maker(session=session, dag_id="test"): @@ -1750,6 +1740,7 @@ def test_schedule_tis_empty_operator_try_number(dag_maker, session: Session): assert empty_ti.try_number == 1 +@pytest.mark.xfail(reason="We can't keep this behaviour with remote workers where the scheduler can't reach xcom") def test_schedule_tis_start_trigger_through_expand(dag_maker, session): """ Test that an operator with start_trigger_args set can be directly deferred during scheduling. @@ -1904,28 +1895,18 @@ def do_something_else(i): @pytest.mark.parametrize( "partial_params, mapped_params, expected", [ - pytest.param(None, [{"a": 1}], [[("a", 1)]], id="simple"), - pytest.param({"b": 2}, [{"a": 1}], [[("a", 1), ("b", 2)]], id="merge"), - pytest.param({"b": 2}, [{"a": 1, "b": 3}], [[("a", 1), ("b", 3)]], id="override"), + pytest.param(None, [{"a": 1}], 1, id="simple"), + pytest.param({"b": 2}, [{"a": 1}], 1, id="merge"), + pytest.param({"b": 2}, [{"a": 1, "b": 3}], 1, id="override"), ], ) def test_mapped_expand_against_params(dag_maker, partial_params, mapped_params, expected): - results = [] - - class PullOperator(BaseOperator): - def execute(self, context): - results.append(sorted(context["params"].items())) - with dag_maker(): - PullOperator.partial(task_id="t", params=partial_params).expand(params=mapped_params) + BaseOperator.partial(task_id="t", params=partial_params).expand(params=mapped_params) dr: DagRun = dag_maker.create_dagrun() decision = dr.task_instance_scheduling_decisions() - - for ti in decision.schedulable_tis: - ti.run() - - assert sorted(results) == expected + assert len(decision.schedulable_tis) == expected def test_mapped_task_group_expands(dag_maker, session): @@ -1970,9 +1951,7 @@ def test_operator_mapped_task_group_receives_value(dag_maker, session): with dag_maker(session=session): @task - def t(value, *, ti=None): - results[(ti.task_id, ti.map_index)] = value - return value + def t(value): ...
@task_group def tg(va): @@ -1994,24 +1973,29 @@ def tg(va): dr: DagRun = dag_maker.create_dagrun() - results = {} + results = set() decision = dr.task_instance_scheduling_decisions(session=session) for ti in decision.schedulable_tis: - ti.run() - assert results == {("tg.t1", 0): ["a", "b"], ("tg.t1", 1): [4], ("tg.t1", 2): ["z"]} + results.add((ti.task_id, ti.map_index)) + ti.state = TaskInstanceState.SUCCESS + session.flush() + assert results == {("tg.t1", 0), ("tg.t1", 1), ("tg.t1", 2)} - results = {} + results.clear() decision = dr.task_instance_scheduling_decisions(session=session) for ti in decision.schedulable_tis: - ti.run() - assert results == {("tg.t2", 0): ["a", "b"], ("tg.t2", 1): [4], ("tg.t2", 2): ["z"]} + results.add((ti.task_id, ti.map_index)) + ti.state = TaskInstanceState.SUCCESS + session.flush() + assert results == {("tg.t2", 0), ("tg.t2", 1), ("tg.t2", 2)} - results = {} + results.clear() decision = dr.task_instance_scheduling_decisions(session=session) for ti in decision.schedulable_tis: - ti.run() - assert len(results) == 1 - assert list(results[("t3", -1)]) == [["a", "b"], [4], ["z"]] + results.add((ti.task_id, ti.map_index)) + ti.state = TaskInstanceState.SUCCESS + session.flush() + assert results == {("t3", -1)} def test_mapping_against_empty_list(dag_maker, session): @@ -2079,13 +2063,15 @@ def print_value(value): decision = dr1.task_instance_scheduling_decisions(session=session) assert len(decision.schedulable_tis) == 2 for ti in decision.schedulable_tis: - ti.run(session=session) + ti.state = TaskInstanceState.SUCCESS + session.flush() # Now print_value in dr2 can run decision = dr2.task_instance_scheduling_decisions(session=session) assert len(decision.schedulable_tis) == 2 for ti in decision.schedulable_tis: - ti.run(session=session) + ti.state = TaskInstanceState.SUCCESS + session.flush() # Both runs are finished now. decision = dr1.task_instance_scheduling_decisions(session=session) @@ -2094,6 +2080,69 @@ def print_value(value): assert len(decision.unfinished_tis) == 0 +def test_xcom_map_skip_raised(dag_maker, session): + result = None + + with dag_maker(session=session) as dag: + # Note: this doesn't actually run this dag, the callbacks are for reference only. + + @dag.task() + def push(): + return ["a", "b", "c"] + + @dag.task() + def forward(value): + return value + + @dag.task(trigger_rule=TriggerRule.ALL_DONE) + def collect(value): + nonlocal result + result = list(value) + + def skip_c(v): + ... + # if v == "c": + # raise AirflowSkipException + # return {"value": v} + + collect(value=forward.expand_kwargs(push().map(skip_c))) + + dr: DagRun = dag_maker.create_dagrun(session=session) + + def _task_ids(tis): + return [(ti.task_id, ti.map_index) for ti in tis] + + # Check that when forward w/ map_index=2 ends up skipping, that the collect task can still be + # scheduled! + + # Run "push". + decision = dr.task_instance_scheduling_decisions(session=session) + assert _task_ids(decision.schedulable_tis) == [("push", -1)] + ti = decision.schedulable_tis[0] + ti.state = TaskInstanceState.SUCCESS + session.add(TaskMap.from_task_instance_xcom(ti, push.function())) + session.flush() + + decision = dr.task_instance_scheduling_decisions(session=session) + assert _task_ids(decision.schedulable_tis) == [ + ("forward", 0), + ("forward", 1), + ("forward", 2), + ] + # Run "forward". "c"/index 2 is skipped. 
Runtime behaviour checked in test_xcom_map_raise_to_skip in + # TaskSDK + for ti, state in zip( + decision.schedulable_tis, + [TaskInstanceState.SUCCESS, TaskInstanceState.SUCCESS, TaskInstanceState.SKIPPED], + ): + ti.state = state + session.flush() + + # Now "collect" should only get "a" and "b". + decision = dr.task_instance_scheduling_decisions(session=session) + assert _task_ids(decision.schedulable_tis) == [("collect", -1)] + + def test_clearing_task_and_moving_from_non_mapped_to_mapped(dag_maker, session): """ Test that clearing a task and moving from non-mapped to mapped clears existing @@ -2365,7 +2414,7 @@ def make_task(task_id, dag): with dag_maker() as dag: for line in input: - tasks = [make_task(x, dag) for x in line.split(" >> ")] + tasks = [make_task(x, dag_maker.dag) for x in line.split(" >> ")] reduce(lambda x, y: x >> y, tasks) dr = dag_maker.create_dagrun() diff --git a/tests/models/test_mappedoperator.py b/tests/models/test_mappedoperator.py index 2d06dc6216f84..0d7e37f90e50c 100644 --- a/tests/models/test_mappedoperator.py +++ b/tests/models/test_mappedoperator.py @@ -32,19 +32,12 @@ from airflow.models.taskinstance import TaskInstance from airflow.models.taskmap import TaskMap from airflow.providers.standard.operators.python import PythonOperator -from airflow.sdk.definitions.mappedoperator import MappedOperator from airflow.utils.state import TaskInstanceState from airflow.utils.task_group import TaskGroup -from airflow.utils.task_instance_session import set_current_task_instance_session -from airflow.utils.xcom import XCOM_RETURN_KEY from tests.models import DEFAULT_DATE from tests_common.test_utils.mapping import expand_mapped_task -from tests_common.test_utils.mock_operators import ( - MockOperator, - MockOperatorWithNestedFields, - NestedFields, -) +from tests_common.test_utils.mock_operators import MockOperator pytestmark = pytest.mark.db_test @@ -255,151 +248,6 @@ def test_expand_mapped_task_instance_skipped_on_zero(dag_maker, session): assert indices == [(-1, TaskInstanceState.SKIPPED)] -class _RenderTemplateFieldsValidationOperator(BaseOperator): - template_fields = ( - "partial_template", - "map_template_xcom", - "map_template_literal", - "map_template_file", - ) - template_ext = (".ext",) - - fields_to_test = [ - "partial_template", - "partial_static", - "map_template_xcom", - "map_template_literal", - "map_static", - "map_template_file", - ] - - def __init__( - self, - partial_template, - partial_static, - map_template_xcom, - map_template_literal, - map_static, - map_template_file, - **kwargs, - ): - for field in self.fields_to_test: - setattr(self, field, value := locals()[field]) - assert isinstance(value, str), "value should have been resolved before unmapping" - super().__init__(**kwargs) - - def execute(self, context): - pass - - -def test_mapped_render_template_fields_validating_operator(dag_maker, session, tmp_path): - file_template_dir = tmp_path / "path" / "to" - file_template_dir.mkdir(parents=True, exist_ok=True) - file_template = file_template_dir / "file.ext" - file_template.write_text("loaded data") - - with set_current_task_instance_session(session=session): - with dag_maker(session=session, template_searchpath=tmp_path.__fspath__()): - task1 = BaseOperator(task_id="op1") - output1 = task1.output - mapped = _RenderTemplateFieldsValidationOperator.partial( - task_id="a", partial_template="{{ ti.task_id }}", partial_static="{{ ti.task_id }}" - ).expand( - map_static=output1, - map_template_literal=["{{ ds }}"], - map_template_xcom=output1, - 
map_template_file=["/path/to/file.ext"], - ) - - dr = dag_maker.create_dagrun() - ti: TaskInstance = dr.get_task_instance(task1.task_id, session=session) - ti.xcom_push(key=XCOM_RETURN_KEY, value=["{{ ds }}"], session=session) - session.add( - TaskMap( - dag_id=dr.dag_id, - task_id=task1.task_id, - run_id=dr.run_id, - map_index=-1, - length=1, - keys=None, - ) - ) - session.flush() - - mapped_ti: TaskInstance = dr.get_task_instance(mapped.task_id, session=session) - mapped_ti.map_index = 0 - assert isinstance(mapped_ti.task, MappedOperator) - mapped.render_template_fields(context=mapped_ti.get_template_context(session=session)) - assert isinstance(mapped_ti.task, _RenderTemplateFieldsValidationOperator) - - assert mapped_ti.task.partial_template == "a", "Should be rendered!" - assert mapped_ti.task.partial_static == "{{ ti.task_id }}", "Should not be rendered!" - assert mapped_ti.task.map_static == "{{ ds }}", "Should not be rendered!" - assert mapped_ti.task.map_template_literal == "2016-01-01", "Should be rendered!" - assert mapped_ti.task.map_template_xcom == "{{ ds }}", "XCom resolved but not double rendered!" - assert mapped_ti.task.map_template_file == "loaded data", "Should be rendered!" - - -def test_mapped_expand_kwargs_render_template_fields_validating_operator(dag_maker, session, tmp_path): - file_template_dir = tmp_path / "path" / "to" - file_template_dir.mkdir(parents=True, exist_ok=True) - file_template = file_template_dir / "file.ext" - file_template.write_text("loaded data") - - with set_current_task_instance_session(session=session): - with dag_maker(session=session, template_searchpath=tmp_path.__fspath__()): - mapped = _RenderTemplateFieldsValidationOperator.partial( - task_id="a", partial_template="{{ ti.task_id }}", partial_static="{{ ti.task_id }}" - ).expand_kwargs( - [ - { - "map_template_literal": "{{ ds }}", - "map_static": "{{ ds }}", - "map_template_file": "/path/to/file.ext", - # This field is not tested since XCom inside a literal list - # is not rendered (matching BaseOperator rendering behavior). - "map_template_xcom": "", - } - ] - ) - - dr = dag_maker.create_dagrun() - mapped_ti: TaskInstance = dr.get_task_instance(mapped.task_id, session=session, map_index=0) - assert isinstance(mapped_ti.task, MappedOperator) - mapped.render_template_fields(context=mapped_ti.get_template_context(session=session)) - assert isinstance(mapped_ti.task, _RenderTemplateFieldsValidationOperator) - - assert mapped_ti.task.partial_template == "a", "Should be rendered!" - assert mapped_ti.task.partial_static == "{{ ti.task_id }}", "Should not be rendered!" - assert mapped_ti.task.map_template_literal == "2016-01-01", "Should be rendered!" - assert mapped_ti.task.map_static == "{{ ds }}", "Should not be rendered!" - assert mapped_ti.task.map_template_file == "loaded data", "Should be rendered!" 
- - -def test_mapped_render_nested_template_fields(dag_maker, session): - with dag_maker(session=session): - MockOperatorWithNestedFields.partial( - task_id="t", arg2=NestedFields(field_1="{{ ti.task_id }}", field_2="value_2") - ).expand(arg1=["{{ ti.task_id }}", ["s", "{{ ti.task_id }}"]]) - - dr = dag_maker.create_dagrun() - decision = dr.task_instance_scheduling_decisions() - tis = {(ti.task_id, ti.map_index): ti for ti in decision.schedulable_tis} - assert len(tis) == 2 - - ti = tis[("t", 0)] - ti.run(session=session) - assert ti.task.arg1 == "t" - assert ti.task.arg2.field_1 == "t" - assert ti.task.arg2.field_2 == "value_2" - - ti = tis[("t", 1)] - ti.run(session=session) - assert ti.task.arg1 == ["s", "t"] - assert ti.task.arg2.field_1 == "t" - assert ti.task.arg2.field_2 == "value_2" - - @pytest.mark.parametrize( ["num_existing_tis", "expected"], ( @@ -467,6 +315,45 @@ def test_expand_kwargs_mapped_task_instance(dag_maker, session, num_existing_tis assert indices == expected +def test_map_product_expansion(dag_maker, session): + """Test the cross-product effect of mapping two inputs""" + outputs = [] + + with dag_maker(dag_id="product", session=session) as dag: + + @dag.task + def emit_numbers(): + return [1, 2] + + @dag.task + def emit_letters(): + return {"a": "x", "b": "y", "c": "z"} + + @dag.task + def show(number, letter): + outputs.append((number, letter)) + + show.expand(number=emit_numbers(), letter=emit_letters()) + + dr = dag_maker.create_dagrun() + for fn in (emit_numbers, emit_letters): + session.add( + TaskMap( + dag_id=dr.dag_id, + task_id=fn.__name__, + run_id=dr.run_id, + map_index=-1, + length=len(fn.function()), + keys=None, + ) + ) + + session.flush() + show_task = dag.get_task("show") + mapped_tis, max_map_index = TaskMap.expand_mapped_task(show_task, dr.run_id, session=session) + assert max_map_index + 1 == len(mapped_tis) == 6 + + def _create_mapped_with_name_template_classic(*, task_id, map_names, template): class HasMapName(BaseOperator): def __init__(self, *, map_name: str, **kwargs): @@ -591,68 +478,6 @@ def test_expand_mapped_task_task_instance_mutation_hook(dag_maker, session, crea assert call.args[0].map_index == expected_map_index[index] -@pytest.mark.parametrize( - "map_index, expected", - [ - pytest.param(0, "2016-01-01", id="0"), - pytest.param(1, 2, id="1"), - ], -) -def test_expand_kwargs_render_template_fields_validating_operator(dag_maker, session, map_index, expected): - with set_current_task_instance_session(session=session): - with dag_maker(session=session): - task1 = BaseOperator(task_id="op1") - mapped = MockOperator.partial(task_id="a", arg2="{{ ti.task_id }}").expand_kwargs(task1.output) - - dr = dag_maker.create_dagrun() - ti: TaskInstance = dr.get_task_instance(task1.task_id, session=session) - - ti.xcom_push(key=XCOM_RETURN_KEY, value=[{"arg1": "{{ ds }}"}, {"arg1": 2}], session=session) - - session.add( - TaskMap( - dag_id=dr.dag_id, - task_id=task1.task_id, - run_id=dr.run_id, - map_index=-1, - length=2, - keys=None, - ) - ) - session.flush() - - ti: TaskInstance = dr.get_task_instance(mapped.task_id, session=session) - ti.refresh_from_task(mapped) - ti.map_index = map_index - assert isinstance(ti.task, MappedOperator) - mapped.render_template_fields(context=ti.get_template_context(session=session)) - assert isinstance(ti.task, MockOperator) - assert ti.task.arg1 == expected - assert ti.task.arg2 == "a" - - -def test_all_xcomargs_from_mapped_tasks_are_consumable(dag_maker, session): - class PushXcomOperator(MockOperator): - def 
__init__(self, arg1, **kwargs): - super().__init__(arg1=arg1, **kwargs) - - def execute(self, context): - return self.arg1 - - class ConsumeXcomOperator(PushXcomOperator): - def execute(self, context): - assert set(self.arg1) == {1, 2, 3} - - with dag_maker("test_all_xcomargs_from_mapped_tasks_are_consumable"): - op1 = PushXcomOperator.partial(task_id="op1").expand(arg1=[1, 2, 3]) - ConsumeXcomOperator(task_id="op2", arg1=op1.output) - - dr = dag_maker.create_dagrun() - tis = dr.get_task_instances(session=session) - for ti in tis: - ti.run() - - class TestMappedSetupTeardown: @staticmethod def get_states(dr): diff --git a/tests/models/test_skipmixin.py b/tests/models/test_skipmixin.py index 7107520930246..bc0913e3bbabf 100644 --- a/tests/models/test_skipmixin.py +++ b/tests/models/test_skipmixin.py @@ -120,6 +120,7 @@ def get_state(ti): assert executed_states == expected_states + @pytest.mark.need_serialized_dag def test_mapped_tasks_skip_all_except(self, dag_maker): with dag_maker("dag_test_skip_all_except") as dag: diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py index 628945ebf2aaa..534a9ebd7940b 100644 --- a/tests/models/test_taskinstance.py +++ b/tests/models/test_taskinstance.py @@ -59,7 +59,7 @@ from airflow.models.dag import DAG from airflow.models.dagbag import DagBag from airflow.models.dagrun import DagRun -from airflow.models.expandinput import EXPAND_INPUT_EMPTY, NotFullyPopulated +from airflow.models.expandinput import EXPAND_INPUT_EMPTY from airflow.models.pool import Pool from airflow.models.renderedtifields import RenderedTaskInstanceFields from airflow.models.serialized_dag import SerializedDagModel @@ -4923,80 +4923,6 @@ def show(value): ti.run() assert outputs == expected_outputs - def test_map_product(self, dag_maker, session): - outputs = [] - - with dag_maker(dag_id="product", session=session) as dag: - - @dag.task - def emit_numbers(): - return [1, 2] - - @dag.task - def emit_letters(): - return {"a": "x", "b": "y", "c": "z"} - - @dag.task - def show(number, letter): - outputs.append((number, letter)) - - show.expand(number=emit_numbers(), letter=emit_letters()) - - dag_run = dag_maker.create_dagrun() - for task_id in ["emit_numbers", "emit_letters"]: - ti = dag_run.get_task_instance(task_id, session=session) - ti.refresh_from_task(dag.get_task(task_id)) - ti.run() - - show_task = dag.get_task("show") - mapped_tis, max_map_index = TaskMap.expand_mapped_task(show_task, dag_run.run_id, session=session) - assert max_map_index + 1 == len(mapped_tis) == 6 - - for ti in sorted(mapped_tis, key=operator.attrgetter("map_index")): - ti.refresh_from_task(show_task) - ti.run() - assert outputs == [ - (1, ("a", "x")), - (1, ("b", "y")), - (1, ("c", "z")), - (2, ("a", "x")), - (2, ("b", "y")), - (2, ("c", "z")), - ] - - def test_map_product_same(self, dag_maker, session): - """Test a mapped task can refer to the same source multiple times.""" - outputs = [] - - with dag_maker(dag_id="product_same", session=session) as dag: - - @dag.task - def emit_numbers(): - return [1, 2] - - @dag.task - def show(a, b): - outputs.append((a, b)) - - emit_task = emit_numbers() - show.expand(a=emit_task, b=emit_task) - - dag_run = dag_maker.create_dagrun() - ti = dag_run.get_task_instance("emit_numbers", session=session) - ti.refresh_from_task(dag.get_task("emit_numbers")) - ti.run() - - show_task = dag.get_task("show") - with pytest.raises(NotFullyPopulated): - assert show_task.get_parse_time_mapped_ti_count() - mapped_tis, max_map_index = 
TaskMap.expand_mapped_task(show_task, dag_run.run_id, session=session) - assert max_map_index + 1 == len(mapped_tis) == 4 - - for ti in sorted(mapped_tis, key=operator.attrgetter("map_index")): - ti.refresh_from_task(show_task) - ti.run() - assert outputs == [(1, 1), (1, 2), (2, 1), (2, 2)] - def test_map_literal_cross_product(self, dag_maker, session): """Test a mapped task with literal cross product args expand properly.""" outputs = [] diff --git a/tests/models/test_taskmap.py b/tests/models/test_taskmap.py new file mode 100644 index 0000000000000..c14dc9928b562 --- /dev/null +++ b/tests/models/test_taskmap.py @@ -0,0 +1,76 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import pytest + +from airflow.models.taskinstance import TaskInstance +from airflow.models.taskmap import TaskMap, TaskMapVariant +from airflow.operators.empty import EmptyOperator + +pytestmark = pytest.mark.db_test + + +def test_task_map_from_task_instance_xcom(): + task = EmptyOperator(task_id="test_task") + ti = TaskInstance(task=task, run_id="test_run", map_index=0) + ti.dag_id = "test_dag" + value = {"key1": "value1", "key2": "value2"} + + # Test case where run_id is not None + task_map = TaskMap.from_task_instance_xcom(ti, value) + assert task_map.dag_id == ti.dag_id + assert task_map.task_id == ti.task_id + assert task_map.run_id == ti.run_id + assert task_map.map_index == ti.map_index + assert task_map.length == len(value) + assert task_map.keys == list(value) + + # Test case where run_id is None + ti.run_id = None + with pytest.raises(ValueError, match="cannot record task map for unrun task instance"): + TaskMap.from_task_instance_xcom(ti, value) + + +def test_task_map_with_invalid_task_instance(): + task = EmptyOperator(task_id="test_task") + ti = TaskInstance(task=task, run_id=None, map_index=0) + ti.dag_id = "test_dag" + + # Define some arbitrary XCom-like value data + value = {"example_key": "example_value"} + + with pytest.raises(ValueError, match="cannot record task map for unrun task instance"): + TaskMap.from_task_instance_xcom(ti, value) + + +def test_task_map_variant(): + # Test case where keys is None + task_map = TaskMap( + dag_id="test_dag", + task_id="test_task", + run_id="test_run", + map_index=0, + length=3, + keys=None, + ) + assert task_map.variant == TaskMapVariant.LIST + + # Test case where keys is not None + task_map.keys = ["key1", "key2"] + assert task_map.variant == TaskMapVariant.DICT diff --git a/tests/models/test_xcom_arg_map.py b/tests/models/test_xcom_arg_map.py deleted file mode 100644 index a654ef1dd4a2e..0000000000000 --- a/tests/models/test_xcom_arg_map.py +++ /dev/null @@ -1,433 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. 
See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from __future__ import annotations - -import pytest - -from airflow.exceptions import AirflowSkipException -from airflow.models.taskinstance import TaskInstance -from airflow.models.taskmap import TaskMap, TaskMapVariant -from airflow.operators.empty import EmptyOperator -from airflow.utils.state import TaskInstanceState -from airflow.utils.trigger_rule import TriggerRule - -pytestmark = pytest.mark.db_test - - -def test_xcom_map(dag_maker, session): - results = set() - with dag_maker(session=session) as dag: - - @dag.task - def push(): - return ["a", "b", "c"] - - @dag.task - def pull(value): - results.add(value) - - pull.expand_kwargs(push().map(lambda v: {"value": v * 2})) - - # The function passed to "map" is *NOT* a task. - assert set(dag.task_dict) == {"push", "pull"} - - dr = dag_maker.create_dagrun(session=session) - - # Run "push". - decision = dr.task_instance_scheduling_decisions(session=session) - for ti in decision.schedulable_tis: - ti.run(session=session) - session.commit() - - # Run "pull". - decision = dr.task_instance_scheduling_decisions(session=session) - tis = {(ti.task_id, ti.map_index): ti for ti in decision.schedulable_tis} - assert sorted(tis) == [("pull", 0), ("pull", 1), ("pull", 2)] - for ti in tis.values(): - ti.run(session=session) - - assert results == {"aa", "bb", "cc"} - - -def test_xcom_map_transform_to_none(dag_maker, session): - results = set() - - with dag_maker(session=session) as dag: - - @dag.task() - def push(): - return ["a", "b", "c"] - - @dag.task() - def pull(value): - results.add(value) - - def c_to_none(v): - if v == "c": - return None - return v - - pull.expand(value=push().map(c_to_none)) - - dr = dag_maker.create_dagrun(session=session) - - # Run "push". - decision = dr.task_instance_scheduling_decisions(session=session) - for ti in decision.schedulable_tis: - ti.run(session=session) - - # Run "pull". This should automatically convert "c" to None. - decision = dr.task_instance_scheduling_decisions(session=session) - for ti in decision.schedulable_tis: - ti.run(session=session) - assert results == {"a", "b", None} - - -def test_xcom_convert_to_kwargs_fails_task(dag_maker, session): - results = set() - - with dag_maker(session=session) as dag: - - @dag.task() - def push(): - return ["a", "b", "c"] - - @dag.task() - def pull(value): - results.add(value) - - def c_to_none(v): - if v == "c": - return None - return {"value": v} - - pull.expand_kwargs(push().map(c_to_none)) - - dr = dag_maker.create_dagrun(session=session) - - # Run "push". - decision = dr.task_instance_scheduling_decisions(session=session) - for ti in decision.schedulable_tis: - ti.run(session=session) - - # Prepare to run "pull"... 
- decision = dr.task_instance_scheduling_decisions(session=session) - tis = {(ti.task_id, ti.map_index): ti for ti in decision.schedulable_tis} - - # The first two "pull" tis should also succeed. - tis[("pull", 0)].run(session=session) - tis[("pull", 1)].run(session=session) - - # But the third one fails because the map() result cannot be used as kwargs. - with pytest.raises(ValueError) as ctx: - tis[("pull", 2)].run(session=session) - assert str(ctx.value) == "expand_kwargs() expects a list[dict], not list[None]" - - assert [tis[("pull", i)].state for i in range(3)] == [ - TaskInstanceState.SUCCESS, - TaskInstanceState.SUCCESS, - TaskInstanceState.FAILED, - ] - - -def test_xcom_map_error_fails_task(dag_maker, session): - with dag_maker(session=session) as dag: - - @dag.task() - def push(): - return ["a", "b", "c"] - - @dag.task() - def pull(value): - print(value) - - def does_not_work_with_c(v): - if v == "c": - raise ValueError("nope") - return {"value": v * 2} - - pull.expand_kwargs(push().map(does_not_work_with_c)) - - dr = dag_maker.create_dagrun(session=session) - - # The "push" task should not fail. - decision = dr.task_instance_scheduling_decisions(session=session) - for ti in decision.schedulable_tis: - ti.run(session=session) - assert [ti.state for ti in decision.schedulable_tis] == [TaskInstanceState.SUCCESS] - - # Prepare to run "pull"... - decision = dr.task_instance_scheduling_decisions(session=session) - tis = {(ti.task_id, ti.map_index): ti for ti in decision.schedulable_tis} - - # The first two "pull" tis should also succeed. - tis[("pull", 0)].run(session=session) - tis[("pull", 1)].run(session=session) - - # But the third one (for "c") will fail. - with pytest.raises(ValueError) as ctx: - tis[("pull", 2)].run(session=session) - assert str(ctx.value) == "nope" - - assert [tis[("pull", i)].state for i in range(3)] == [ - TaskInstanceState.SUCCESS, - TaskInstanceState.SUCCESS, - TaskInstanceState.FAILED, - ] - - -def test_task_map_from_task_instance_xcom(): - task = EmptyOperator(task_id="test_task") - ti = TaskInstance(task=task, run_id="test_run", map_index=0) - ti.dag_id = "test_dag" - value = {"key1": "value1", "key2": "value2"} - - # Test case where run_id is not None - task_map = TaskMap.from_task_instance_xcom(ti, value) - assert task_map.dag_id == ti.dag_id - assert task_map.task_id == ti.task_id - assert task_map.run_id == ti.run_id - assert task_map.map_index == ti.map_index - assert task_map.length == len(value) - assert task_map.keys == list(value) - - # Test case where run_id is None - ti.run_id = None - with pytest.raises(ValueError, match="cannot record task map for unrun task instance"): - TaskMap.from_task_instance_xcom(ti, value) - - -def test_task_map_with_invalid_task_instance(): - task = EmptyOperator(task_id="test_task") - ti = TaskInstance(task=task, run_id=None, map_index=0) - ti.dag_id = "test_dag" - - # Define some arbitrary XCom-like value data - value = {"example_key": "example_value"} - - with pytest.raises(ValueError, match="cannot record task map for unrun task instance"): - TaskMap.from_task_instance_xcom(ti, value) - - -def test_task_map_variant(): - # Test case where keys is None - task_map = TaskMap( - dag_id="test_dag", - task_id="test_task", - run_id="test_run", - map_index=0, - length=3, - keys=None, - ) - assert task_map.variant == TaskMapVariant.LIST - - # Test case where keys is not None - task_map.keys = ["key1", "key2"] - assert task_map.variant == TaskMapVariant.DICT - - -def test_xcom_map_raise_to_skip(dag_maker, session): - 
result = None - - with dag_maker(session=session) as dag: - - @dag.task() - def push(): - return ["a", "b", "c"] - - @dag.task() - def forward(value): - return value - - @dag.task(trigger_rule=TriggerRule.ALL_DONE) - def collect(value): - nonlocal result - result = list(value) - - def skip_c(v): - if v == "c": - raise AirflowSkipException - return {"value": v} - - collect(value=forward.expand_kwargs(push().map(skip_c))) - - dr = dag_maker.create_dagrun(session=session) - - # Run "push". - decision = dr.task_instance_scheduling_decisions(session=session) - for ti in decision.schedulable_tis: - ti.run(session=session) - - # Run "forward". This should automatically skip "c". - decision = dr.task_instance_scheduling_decisions(session=session) - for ti in decision.schedulable_tis: - ti.run(session=session) - - # Now "collect" should only get "a" and "b". - decision = dr.task_instance_scheduling_decisions(session=session) - for ti in decision.schedulable_tis: - ti.run(session=session) - assert result == ["a", "b"] - - -def test_xcom_map_nest(dag_maker, session): - results = set() - - with dag_maker(session=session) as dag: - - @dag.task() - def push(): - return ["a", "b", "c"] - - @dag.task() - def pull(value): - results.add(value) - - converted = push().map(lambda v: v * 2).map(lambda v: {"value": v}) - pull.expand_kwargs(converted) - - dr = dag_maker.create_dagrun(session=session) - - # Run "push". - decision = dr.task_instance_scheduling_decisions(session=session) - for ti in decision.schedulable_tis: - ti.run(session=session) - - session.flush() - session.commit() - - # Now "pull" should apply the mapping functions in order. - decision = dr.task_instance_scheduling_decisions(session=session) - for ti in decision.schedulable_tis: - ti.run(session=session) - assert results == {"aa", "bb", "cc"} - - -def test_xcom_map_zip_nest(dag_maker, session): - results = set() - - with dag_maker(session=session) as dag: - - @dag.task - def push_letters(): - return ["a", "b", "c", "d"] - - @dag.task - def push_numbers(): - return [1, 2, 3, 4] - - @dag.task - def pull(value): - results.add(value) - - doubled = push_numbers().map(lambda v: v * 2) - combined = doubled.zip(push_letters()) - - def convert_zipped(zipped): - letter, number = zipped - return letter * number - - pull.expand(value=combined.map(convert_zipped)) - - dr = dag_maker.create_dagrun(session=session) - - # Run "push_letters" and "push_numbers". - decision = dr.task_instance_scheduling_decisions(session=session) - assert decision.schedulable_tis - assert all(ti.task_id.startswith("push_") for ti in decision.schedulable_tis) - for ti in decision.schedulable_tis: - ti.run(session=session) - session.commit() - - # Run "pull". 
- decision = dr.task_instance_scheduling_decisions(session=session) - assert decision.schedulable_tis - assert all(ti.task_id == "pull" for ti in decision.schedulable_tis) - for ti in decision.schedulable_tis: - ti.run(session=session) - - assert results == {"aa", "bbbb", "cccccc", "dddddddd"} - - -def test_xcom_concat(dag_maker, session): - from airflow.sdk.definitions.xcom_arg import _ConcatResult - - agg_results = set() - all_results = None - - with dag_maker(session=session) as dag: - - @dag.task - def push_letters(): - return ["a", "b", "c"] - - @dag.task - def push_numbers(): - return [1, 2] - - @dag.task - def pull_one(value): - agg_results.add(value) - - @dag.task - def pull_all(value): - assert isinstance(value, _ConcatResult) - assert value[0] == "a" - assert value[1] == "b" - assert value[2] == "c" - assert value[3] == 1 - assert value[4] == 2 - with pytest.raises(IndexError): - value[5] - assert value[-5] == "a" - assert value[-4] == "b" - assert value[-3] == "c" - assert value[-2] == 1 - assert value[-1] == 2 - with pytest.raises(IndexError): - value[-6] - nonlocal all_results - all_results = list(value) - - pushed_values = push_letters().concat(push_numbers()) - - pull_one.expand(value=pushed_values) - pull_all(pushed_values) - - dr = dag_maker.create_dagrun(session=session) - - # Run "push_letters" and "push_numbers". - decision = dr.task_instance_scheduling_decisions(session=session) - assert len(decision.schedulable_tis) == 2 - assert all(ti.task_id.startswith("push_") for ti in decision.schedulable_tis) - for ti in decision.schedulable_tis: - ti.run(session=session) - session.commit() - - # Run "pull_one" and "pull_all". - decision = dr.task_instance_scheduling_decisions(session=session) - assert len(decision.schedulable_tis) == 6 - assert all(ti.task_id.startswith("pull_") for ti in decision.schedulable_tis) - for ti in decision.schedulable_tis: - ti.run(session=session) - - assert agg_results == {"a", "b", "c", 1, 2} - assert all_results == ["a", "b", "c", 1, 2] - - decision = dr.task_instance_scheduling_decisions(session=session) - assert not decision.schedulable_tis diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py index 2cd1ce14a5073..5b73f2a71b0c4 100644 --- a/tests/serialization/test_dag_serialization.py +++ b/tests/serialization/test_dag_serialization.py @@ -40,7 +40,7 @@ from typing import TYPE_CHECKING from unittest import mock -import attr +import attrs import pendulum import pytest from dateutil.relativedelta import FR, relativedelta @@ -693,7 +693,7 @@ def validate_deserialized_task( } else: # Promised to be mapped by the assert above. assert isinstance(serialized_task, MappedOperator) - fields_to_check = {f.name for f in attr.fields(MappedOperator)} + fields_to_check = {f.name for f in attrs.fields(MappedOperator)} fields_to_check -= { "map_index_template", # Matching logic in BaseOperator.get_serialized_fields(). @@ -706,6 +706,7 @@ def validate_deserialized_task( # Checked separately. 
"operator_class", "partial_kwargs", + "expand_input", } fields_to_check |= {"deps"} @@ -752,6 +753,9 @@ def validate_deserialized_task( original_partial_kwargs = {**default_partial_kwargs, **task.partial_kwargs} assert serialized_partial_kwargs == original_partial_kwargs + # ExpandInputs have different classes between scheduler and definition + assert attrs.asdict(serialized_task.expand_input) == attrs.asdict(task.expand_input) + @pytest.mark.parametrize( "dag_start_date, task_start_date, expected_task_start_date", [ diff --git a/tests/ti_deps/deps/test_mapped_task_upstream_dep.py b/tests/ti_deps/deps/test_mapped_task_upstream_dep.py index e91ce905a129a..450399b620f00 100644 --- a/tests/ti_deps/deps/test_mapped_task_upstream_dep.py +++ b/tests/ti_deps/deps/test_mapped_task_upstream_dep.py @@ -28,7 +28,7 @@ from airflow.ti_deps.deps.mapped_task_upstream_dep import MappedTaskUpstreamDep from airflow.utils.state import TaskInstanceState -pytestmark = pytest.mark.db_test +pytestmark = [pytest.mark.db_test, pytest.mark.need_serialized_dag] if TYPE_CHECKING: from sqlalchemy.orm.session import Session diff --git a/tests_common/pytest_plugin.py b/tests_common/pytest_plugin.py index 904b05b946a76..2d8d8447a2915 100644 --- a/tests_common/pytest_plugin.py +++ b/tests_common/pytest_plugin.py @@ -788,7 +788,12 @@ def __enter__(self): self.dag.__enter__() if self.want_serialized: - return lazy_object_proxy.Proxy(self._serialized_dag) + + class DAGProxy(lazy_object_proxy.Proxy): + # Make `@dag.task` decorator work when need_serialized_dag marker is set + task = self.dag.task + + return DAGProxy(self._serialized_dag) return self.dag def _serialized_dag(self): @@ -868,6 +873,9 @@ def __exit__(self, type, value, traceback): if self.want_activate_assets: self._activate_assets() if sdm: + sdm._SerializedDagModel__data_cache = ( + self.serialized_model._SerializedDagModel__data_cache + ) self.serialized_model = sdm else: self.session.merge(self.serialized_model) @@ -936,9 +944,12 @@ def create_dagrun(self, *, logical_date=None, **kwargs): kwargs.pop("triggered_by", None) kwargs["execution_date"] = logical_date + if self.want_serialized: + dag = self.serialized_model.dag self.dag_run = dag.create_dagrun(**kwargs) for ti in self.dag_run.task_instances: - ti.refresh_from_task(dag.get_task(ti.task_id)) + # This need to always operate on the _real_ dag + ti.refresh_from_task(self.dag.get_task(ti.task_id)) if self.want_serialized: self.session.commit() return self.dag_run diff --git a/tests_common/test_utils/mock_operators.py b/tests_common/test_utils/mock_operators.py index a785dffad7e2d..abd72b1d2a950 100644 --- a/tests_common/test_utils/mock_operators.py +++ b/tests_common/test_utils/mock_operators.py @@ -17,7 +17,7 @@ from __future__ import annotations from collections.abc import Sequence -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING import attr @@ -27,8 +27,6 @@ from tests_common.test_utils.compat import BaseOperatorLink if TYPE_CHECKING: - import jinja2 - from airflow.sdk.definitions.context import Context @@ -46,48 +44,6 @@ def execute(self, context: Context): pass -class NestedFields: - """Nested fields for testing purposes.""" - - def __init__(self, field_1, field_2): - self.field_1 = field_1 - self.field_2 = field_2 - - -class MockOperatorWithNestedFields(BaseOperator): - """Operator with nested fields for testing purposes.""" - - template_fields: Sequence[str] = ("arg1", "arg2") - - def __init__(self, arg1: str = "", arg2: NestedFields | None = None, **kwargs): - 
super().__init__(**kwargs) - self.arg1 = arg1 - self.arg2 = arg2 - - def _render_nested_template_fields( - self, - content: Any, - context: Context, - jinja_env: jinja2.Environment, - seen_oids: set, - ) -> None: - if id(content) not in seen_oids: - template_fields: tuple | None = None - - if isinstance(content, NestedFields): - template_fields = ("field_1", "field_2") - - if template_fields: - seen_oids.add(id(content)) - self._do_render_template_fields(content, template_fields, context, jinja_env, seen_oids) - return - - super()._render_nested_template_fields(content, context, jinja_env, seen_oids) - - def execute(self, context: Context): - pass - - class AirflowLink(BaseOperatorLink): """Operator Link for Apache Airflow Website."""
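The tests in this patch repeatedly replace ti.run() with a scheduler-only pattern: the upstream task instance is marked SUCCESS, the mapping length it would have pushed to XCom is recorded with TaskMap.from_task_instance_xcom, and the DagRun's scheduling decisions are re-queried. Below is a minimal sketch of that pattern, not part of the change itself; it assumes the dag_maker and session fixtures from tests_common.pytest_plugin and uses a hypothetical push/pull task pair.

import pytest

from airflow.models.taskmap import TaskMap
from airflow.utils.state import TaskInstanceState

pytestmark = [pytest.mark.db_test, pytest.mark.need_serialized_dag]


def test_mapped_downstream_expands_without_running_tasks(dag_maker, session):
    with dag_maker(session=session) as dag:

        @dag.task
        def push():
            return [1, 2]

        @dag.task
        def pull(value): ...

        pull.expand(value=push())

    dr = dag_maker.create_dagrun(session=session)

    # Only "push" is schedulable at first.
    decision = dr.task_instance_scheduling_decisions(session=session)
    ti = decision.schedulable_tis[0]

    # Instead of ti.run(): mark the TI finished and record the map length it
    # would have pushed to XCom, so the scheduler can expand "pull".
    ti.state = TaskInstanceState.SUCCESS
    session.add(TaskMap.from_task_instance_xcom(ti, [1, 2]))
    session.flush()

    # "pull" now expands into one TI per mapped value.
    decision = dr.task_instance_scheduling_decisions(session=session)
    assert sorted(ti.map_index for ti in decision.schedulable_tis) == [0, 1]

Because the task callables never execute in these suites, assertions that used to collect runtime values (the nonlocal result lists removed above) are rewritten against task_id/map_index pairs, with runtime behaviour covered separately in the Task SDK tests.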