diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 6b9d9d2335e8f..00410a512136b 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -161,6 +161,9 @@ Dockerfile.ci @potiuk @ashb @gopidesupavan @amoghrajesh @jscheffl @bugraoz93 @ja # AIP-72 - Task SDK # Python SDK /task-sdk/ @ashb @kaxil @amoghrajesh +/task-sdk/**/executor.py @dabla +/task-sdk/**/iterableoperator.py @dabla +/task-sdk/**/partitionedoperator.py @dabla # Golang SDK /go-sdk/ @ashb @amoghrajesh diff --git a/airflow-core/src/airflow/serialization/definitions/mappedoperator.py b/airflow-core/src/airflow/serialization/definitions/mappedoperator.py index 1cf6d357e651a..3d3b165244c74 100644 --- a/airflow-core/src/airflow/serialization/definitions/mappedoperator.py +++ b/airflow-core/src/airflow/serialization/definitions/mappedoperator.py @@ -515,6 +515,11 @@ def _(task: SerializedBaseOperator | TaskSDKBaseOperator, run_id: str, *, sessio def _(task: SerializedMappedOperator | TaskSDKMappedOperator, run_id: str, *, session: Session) -> int: from airflow.serialization.serialized_objects import BaseSerialization, _ExpandInputRef + partition_size = task.partial_kwargs.get("partition_size") + + if partition_size is not None: + return partition_size + exp_input = task._get_specified_expand_input() # TODO (GH-52141): 'task' here should be scheduler-bound and returns scheduler expand input. 
if not hasattr(exp_input, "get_total_map_length"): diff --git a/airflow-core/tests/unit/serialization/test_dag_serialization.py b/airflow-core/tests/unit/serialization/test_dag_serialization.py index 33db3d8d906d4..0d37e8237c862 100644 --- a/airflow-core/tests/unit/serialization/test_dag_serialization.py +++ b/airflow-core/tests/unit/serialization/test_dag_serialization.py @@ -2698,6 +2698,7 @@ def test_operator_expand_serde(): "template_ext": [".sh", ".bash"], "template_fields_renderers": {"bash_command": "bash", "env": "json"}, "ui_color": "#f0ede4", + "_apply_upstream_relationship": True, "_disallow_kwargs_override": False, "_expand_input_attr": "expand_input", } @@ -2741,6 +2742,7 @@ def test_operator_expand_xcomarg_serde(): }, "task_id": "task_2", "template_fields": ["arg1", "arg2"], + "_apply_upstream_relationship": True, "_disallow_kwargs_override": False, "_expand_input_attr": "expand_input", } @@ -2796,6 +2798,7 @@ def test_operator_expand_kwargs_literal_serde(strict): }, "task_id": "task_2", "template_fields": ["arg1", "arg2"], + "_apply_upstream_relationship": True, "_disallow_kwargs_override": strict, "_expand_input_attr": "expand_input", } @@ -2843,6 +2846,7 @@ def test_operator_expand_kwargs_xcomarg_serde(strict): }, "task_id": "task_2", "template_fields": ["arg1", "arg2"], + "_apply_upstream_relationship": True, "_disallow_kwargs_override": strict, "_expand_input_attr": "expand_input", } @@ -2951,6 +2955,7 @@ def x(arg1, arg2, arg3): "op_args": "py", "op_kwargs": "py", }, + "_apply_upstream_relationship": True, "_disallow_kwargs_override": False, "_expand_input_attr": "op_kwargs_expand_input", "python_callable_name": "test_taskflow_expand_serde..x", @@ -3044,6 +3049,7 @@ def x(arg1, arg2, arg3): "op_args": "py", "op_kwargs": "py", }, + "_apply_upstream_relationship": True, "_disallow_kwargs_override": strict, "_expand_input_attr": "op_kwargs_expand_input", } @@ -3155,6 +3161,7 @@ def operator_extra_links(self): "partial_kwargs": { "retry_delay": 
{"__type": "timedelta", "__var": 300.0}, }, + "_apply_upstream_relationship": True, "_disallow_kwargs_override": False, "_expand_input_attr": "expand_input", "_operator_extra_links": {"airflow": "_link_AirflowLink2"}, diff --git a/devel-common/src/tests_common/test_utils/mock_context.py b/devel-common/src/tests_common/test_utils/mock_context.py index 4e7aa7884f259..589bd29f99861 100644 --- a/devel-common/src/tests_common/test_utils/mock_context.py +++ b/devel-common/src/tests_common/test_utils/mock_context.py @@ -20,14 +20,29 @@ from typing import TYPE_CHECKING, Any from unittest import mock +from airflow.models import DagRun +from airflow.utils.types import DagRunType + from tests_common.test_utils.compat import Context from tests_common.test_utils.taskinstance import create_task_instance +from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_1_PLUS + +if AIRFLOW_V_3_1_PLUS: + from airflow.sdk import timezone +else: + from airflow.utils import timezone # type: ignore[attr-defined,no-redef] if TYPE_CHECKING: from sqlalchemy.orm import Session -def mock_context(task) -> Context: +def generate_run_id() -> str: + if AIRFLOW_V_3_0_PLUS: + return DagRun.generate_run_id(run_type=DagRunType.MANUAL, run_after=timezone.utcnow()) + return DagRun.generate_run_id(run_type=DagRunType.MANUAL, execution_date=timezone.utcnow()) # type: ignore[call-arg] + + +def mock_context(task, run_id: str | None = None) -> Context: from airflow.models import TaskInstance from airflow.utils.session import NEW_SESSION @@ -64,6 +79,6 @@ def xcom_push(self, key: str, value: Any, session: Session = NEW_SESSION, **kwar values[key] = value values["ti"] = create_task_instance(task, dag_version_id=mock.MagicMock(), ti_type=MockedTaskInstance) + values["run_id"] = generate_run_id() if run_id is None else run_id - # See https://github.com/python/mypy/issues/8890 - mypy does not support passing typed dict to TypedDict return Context(values) # type: ignore[misc] diff --git 
a/task-sdk/.pre-commit-config.yaml b/task-sdk/.pre-commit-config.yaml index 100a6e6490849..835300df1e71a 100644 --- a/task-sdk/.pre-commit-config.yaml +++ b/task-sdk/.pre-commit-config.yaml @@ -39,7 +39,9 @@ repos: ^src/airflow/sdk/definitions/asset/__init__\.py$| ^src/airflow/sdk/definitions/asset/decorators\.py$| ^src/airflow/sdk/definitions/taskgroup\.py$| + ^src/airflow/sdk/definitions/iterableoperator\.py$| ^src/airflow/sdk/definitions/mappedoperator\.py$| + ^src/airflow/sdk/definitions/partitionedoperator\.py$| ^src/airflow/sdk/definitions/deadline\.py$| ^src/airflow/sdk/definitions/dag\.py$| ^src/airflow/sdk/definitions/_internal/types\.py$| diff --git a/task-sdk/docs/deferred-vs-async-operators.rst b/task-sdk/docs/deferred-vs-async-operators.rst index 4d77deea81d78..adf2b48896097 100644 --- a/task-sdk/docs/deferred-vs-async-operators.rst +++ b/task-sdk/docs/deferred-vs-async-operators.rst @@ -196,7 +196,7 @@ concurrently using ``asyncio.gather`` while limiting concurrency with a semaphor .. note:: - The upcoming *Dynamic Task Iteration* feature will simplify patterns like this. + The new :ref:`Dynamic Task Iteration <sdk-dynamic-task-mapping-vs-iteration>` feature will simplify patterns like this. Instead of manually managing concurrency with constructs such as ``asyncio.gather`` and ``asyncio.Semaphore``, authors will be able to iterate over asynchronous results directly in downstream tasks while still benefiting diff --git a/task-sdk/docs/dynamic-task-mapping-vs-iteration.rst b/task-sdk/docs/dynamic-task-mapping-vs-iteration.rst new file mode 100644 index 0000000000000..b1ec0a8240a36 --- /dev/null +++ b/task-sdk/docs/dynamic-task-mapping-vs-iteration.rst @@ -0,0 +1,394 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership.
The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +.. _sdk-dynamic-task-mapping-vs-iteration: + +Dynamic Task Mapping vs Dynamic Task Iteration +================================================ + +.. versionadded:: 3.2.0 + +Airflow provides two complementary ways to process collections of data: + +- **Dynamic Task Mapping (DTM)** distributes work **across multiple workers**. + Each item becomes a separate Task Instance that can run on a different worker, + giving you horizontal scalability and per-item observability. + +- **Dynamic Task Iteration (DTI)** improves concurrency **within a single task**. + All items are processed inside one Task Instance on one worker, eliminating + scheduling overhead and — when combined with async operators — enabling true + I/O multiplexing through a shared event loop. + +In short: **DTM spreads load across workers; DTI speeds up work within one worker.** + +While both approaches allow you to apply an operation over a collection, +they differ significantly in execution model, scheduler impact, and observability. +This page explains the trade-offs and when to use each. + +Real-World Motivation +--------------------- + +Consider a workflow that downloads ~17,000 XML files from an SFTP server and loads +them into a data warehouse. Community benchmarks demonstrate the dramatic performance +difference between the two approaches: + +.. 
list-table:: + :header-rows: 1 + + * - Approach + - Execution Time + * - Dynamic Task Mapping with mapped ``SFTPOperator`` + - 3 h 25 m + * - Sync ``@task`` with ``SFTPHook`` (sequential loop) + - 1 h 21 m + * - Async ``@task`` with ``SFTPHookAsync`` (concurrent loop) + - 8 m 29 s + * - Async ``@task`` with ``SFTPHookAsync`` and connection pooling + - 3 m 32 s + +The ~60× improvement stems from eliminating per-item scheduling overhead and +sharing a single event loop for concurrent I/O. This is the kind of workload +where DTI excels: many small, I/O-bound operations processed within one task. + +Dynamic Task Mapping (DTM) +-------------------------- + +Dynamic Task Mapping allows you to expand a single task definition into multiple +Task Instances (TIs). + +For more details, see :doc:`dynamic task mapping <dynamic-task-mapping>`. + +Key characteristics: + +- Each item in the iterable creates a separate Task Instance. +- The scheduler is responsible for creating and managing all mapped tasks. +- Tasks can run in parallel across multiple worker slots. +- Fine-grained retry, logging, and observability per item. +- Well suited for workloads where each item should be independently scheduled and tracked. + +The following example fetches user data from a REST API. Each user ID becomes +a separate Task Instance, individually scheduled, retried, and visible in the UI: + +.. code-block:: python + + from datetime import datetime + + from airflow.providers.http.operators.http import HttpOperator + from airflow.sdk import DAG, task + + + @task + def list_user_ids(): + return [1, 2, 3, 4, 5] + + + with DAG(dag_id="dtm-http-example", start_date=datetime(2022, 1, 1)) as dag: + HttpOperator( + task_id="fetch_user", + http_conn_id="api_default", + method="GET", + endpoint="/users/{{ item }}", + ).expand(item=list_user_ids()) + +With five user IDs the scheduler creates five Task Instances, each occupying
This is fine for small lists, but for thousands of items the +scheduler and database overhead becomes significant. + +Dynamic Task Iteration (DTI) +---------------------------- + +Dynamic Task Iteration allows you to iterate over an iterable (typically an XCom result) +*within a single Task Instance*, applying an operator multiple times without creating +separate Task Instances. + +This means that iteration happens inside the task execution itself rather than at the +scheduler level. + +Key characteristics: + +- A single Task Instance processes all items in the iterable. +- No task expansion; the scheduler manages only one task. +- Lower scheduler overhead compared to DTM. +- Iterations share the same execution context (e.g., memory, event loop). +- Particularly well suited for async operators and high-throughput workloads. + +The same user-fetching problem can be solved with DTI. Here, a single Task +Instance processes all user IDs sequentially using the sync +:class:`~airflow.providers.http.hooks.http.HttpHook`: + +.. code-block:: python + + from datetime import datetime + + from airflow.providers.http.hooks.http import HttpHook + from airflow.sdk import DAG, task + + + @task + def list_user_ids(): + return [1, 2, 3, 4, 5] + + + @task + def fetch_user(user_id: int): + hook = HttpHook(http_conn_id="api_default", method="GET") + response = hook.run(endpoint=f"/users/{user_id}") + return response.json() + + + with DAG(dag_id="dti-sync-http-example", start_date=datetime(2022, 1, 1)) as dag: + fetch_user.iterate(user_id=list_user_ids()) + +The scheduler only manages a single task. With sync tasks, iterations are +executed in a multi-threaded fashion, which eliminates scheduling overhead +and can speed up compute-bound workloads. However, for I/O-bound operations +like HTTP requests, multi-threading alone does not provide the same +performance benefits as async multiplexing — threads still block on each +request individually rather than sharing a single event loop. 
+ +To truly **multiplex** I/O-bound operations, use an async task with +:class:`~airflow.providers.http.hooks.http.HttpAsyncHook`: + +.. code-block:: python + + from datetime import datetime + + from airflow.providers.http.hooks.http import HttpAsyncHook + from airflow.sdk import DAG, task + + + @task + def list_user_ids(): + return [1, 2, 3, 4, 5] + + + @task + async def fetch_user(user_id: int): + hook = HttpAsyncHook(http_conn_id="api_default", method="GET") + async with hook.session() as session: + response = await session.run(endpoint=f"/users/{user_id}") + return await response.json() + + + with DAG(dag_id="dti-async-http-example", start_date=datetime(2022, 1, 1)) as dag: + fetch_user.iterate(user_id=list_user_ids()) + +When ``iterate()`` is used with an async task, all iterations share the same +event loop, enabling true multiplexing of I/O-bound operations without any +manual concurrency management by the DAG author. For five user IDs the +difference is negligible, but for hundreds or thousands of items the +concurrent approach is dramatically faster — see the +:ref:`benchmarks above `. + +Why Dynamic Task Iteration? +--------------------------- + +DTI is designed to address limitations of Dynamic Task Mapping in specific scenarios: + +- **Scheduler scalability**: + DTM creates one Task Instance per item, which can put pressure on the scheduler + for very large datasets. DTI avoids this by keeping execution within a single task. + +- **Async multiplexing**: + With Python-native async support in Airflow 3.2, DTI allows multiple + operations to share the same event loop within a single Task Instance. + This enables efficient multiplexing of I/O-bound workloads. + +- **Lower overhead**: + No need to serialize, schedule, and track thousands of Task Instances. + +- **Triggerer and deferrable-operator bottleneck**: + Deferrable operators delegate async work to triggerers, which store yielded + events directly in the Airflow metadata database. 
Unlike workers, triggerers + cannot leverage a custom XCom backend to offload large payloads. This makes + triggerers a bottleneck for sustained high-load async execution or workloads + that return large results. Dynamic Task Mapping with deferrable operators + amplifies the problem further. DTI sidesteps triggerers entirely — iterations + execute on workers, which scale more effectively and support custom XCom + backends. + + For more on deferred vs async trade-offs, see :doc:`deferred-vs-async-operators`. + +DTI is especially useful for patterns such as: + +- API pagination +- Bulk HTTP or database calls +- High-throughput async workloads +- Streaming or lazily-evaluated XCom results + +Hooks as Building Blocks +^^^^^^^^^^^^^^^^^^^^^^^^ + +DTI encourages a pattern where DAG authors call **hooks** directly from +``@task``-decorated functions rather than relying on operators. Operators are +wrappers around hooks and sometimes expose only a subset of the hook's +capabilities. By calling hooks directly, users gain full control over +concurrency, error handling, and batching. + +For example, instead of using ``HttpOperator`` in deferrable mode (which +delegates to the triggerer for a single request at a time), an async +``@task`` can call :class:`~airflow.providers.http.hooks.http.HttpAsyncHook` +directly to perform many concurrent requests. With DTI, the framework +handles the iteration, concurrency, and event-loop management +automatically — the DAG author only writes the per-item logic. + +This "hooks as building blocks" approach is especially powerful with async +hooks, where the shared event loop enables concurrent I/O without any +manual ``asyncio.gather`` or ``asyncio.Semaphore`` management. + +For more examples of calling async hooks directly from tasks, see +:doc:`deferred-vs-async-operators`. + +Comparison +---------- + +.. 
list-table:: + :header-rows: 1 + + * - Aspect + - Dynamic Task Mapping (DTM) + - Dynamic Task Iteration (DTI) + * - Task Instances + - One per item + - Single Task Instance + * - Scheduler load + - High for large iterables + - Minimal + * - Execution model + - Distributed across workers + - In-process iteration + * - Concurrency + - Parallel tasks + - Sync or async within one task + * - Async support + - Limited (per task) + - Strong (shared event loop, multiplexing) + * - Retry behavior + - Per item + - Entire task retries + * - Observability + - Per item in UI + - Aggregated in a single task + * - Triggerer dependency + - Deferrable mapped tasks rely on triggerers + - No triggerers involved + * - XCom backend + - Workers support custom XCom backends + - Workers support custom XCom backends (triggerers do not) + * - Use case + - Independent, trackable units of work + - High-throughput or streaming workloads + +When to Use Dynamic Task Mapping +-------------------------------- + +Prefer DTM when: + +- Each item must be independently tracked in the UI. +- You need fine-grained retries per item. +- Tasks are long-running or resource-intensive. +- Work should be distributed across multiple workers. +- Scheduling decisions should be made per item. + +When to Use Dynamic Task Iteration +----------------------------------- + +Prefer DTI when: + +- You are processing large numbers of small items. +- Scheduler overhead becomes a concern. +- You are using async operators and want to leverage a shared event loop. +- Workloads are I/O-bound and benefit from multiplexing. +- Fine-grained observability per item is not required. + +When **not** to use DTI +----------------------- + +Avoid Dynamic Task Iteration when: + +- You need per-item retries or failure isolation. +- Each item represents a long-running or heavy computation. +- You require detailed visibility per item in the Airflow UI. +- Work must be distributed across multiple worker nodes. 
+ +Note that DTI retries the **entire task** on failure — all items are reprocessed +from the beginning. This trade-off is generally acceptable when total processing +time is short (e.g., minutes rather than hours), but it may be undesirable for +workloads where individual items are expensive to reprocess. + +.. tip:: + + DTI is a **third execution option** alongside Dynamic Task Mapping and + deferrable operators. It is not intended as a replacement for either. + Triggerers remain the right choice for long-running polling or waiting tasks + (e.g., monitoring a remote job or waiting for a Kubernetes pod to complete). + +Combining DTM and DTI (Dynamic Task Partitioning) +-------------------------------------------------- + +.. note:: + + Dynamic Task Partitioning is a planned future feature that will build on + top of Dynamic Task Iteration once DTI is fully implemented. The pattern + described here is not yet available. + +DTM and DTI are not mutually exclusive in principle. A future *Dynamic Task +Partitioning* pattern could use DTM to split a large dataset into +coarse-grained chunks, where each mapped task processes its chunk using DTI. + +For example, downloading 17,000 files could be partitioned into 17 chunks of +1,000 files each. DTM would create one task per chunk, and DTI would iterate +within each chunk using a shared event loop for concurrent I/O. + +This pattern would provide: + +- **Coarse-grained retry**: if a chunk fails, only that chunk is retried — not all 17,000 items. +- **Reduced scheduler load**: the scheduler manages chunks (e.g., 17 tasks) instead of individual items (17,000 tasks). +- **High throughput within each chunk**: async I/O processes items concurrently inside each task. + +Relationship with Async Operators +---------------------------------- + +DTI complements async operators introduced in Airflow 3.2. + +- Async operators allow concurrent I/O within a single task. 
+- DTI allows you to *apply an operator repeatedly* over a dataset within that same task. + +Together, they enable patterns such as: + +- Efficient API pagination +- Concurrent request batching +- Streaming data processing + +Unlike Dynamic Task Mapping, where each mapped task runs in its own execution context, +DTI allows all iterations to share the same event loop, enabling true multiplexing. + +Because DTI executes on workers rather than triggerers, it also benefits from the +full worker environment: custom XCom backends, Edge Worker support, and the +scalability of execution frameworks such as Celery. + +For more details on async execution, see :doc:`deferred-vs-async-operators`. + +Future Outlook +-------------- + +As Python's async ecosystem evolves, DTI tasks will benefit from improved +introspection and tooling. For example, Python 3.14 introduces new +`asyncio introspection capabilities `_ +that could eventually enable structured progress reporting in the Airflow UI +for DTI tasks — providing per-item visibility without the overhead of per-item +task instances. diff --git a/task-sdk/docs/index.rst b/task-sdk/docs/index.rst index d1b26544c8855..dd97d5e116717 100644 --- a/task-sdk/docs/index.rst +++ b/task-sdk/docs/index.rst @@ -172,6 +172,7 @@ For the full public API reference, see the :doc:`api` page. 
examples dynamic-task-mapping + dynamic-task-mapping-vs-iteration deferred-vs-async-operators api concepts diff --git a/task-sdk/pyproject.toml b/task-sdk/pyproject.toml index 6e4a0b1017d01..ea97395a4cdc7 100644 --- a/task-sdk/pyproject.toml +++ b/task-sdk/pyproject.toml @@ -57,6 +57,8 @@ dependencies = [ "httpx>=0.27.0", "jinja2>=3.1.5", "methodtools>=0.4.7", + # Only needed on Python < 3.12 + 'more-itertools>=9.0.0;python_version<"3.12"', "msgspec>=0.19.0", "python-dateutil>=2.7.0", "psutil>=6.1.0", diff --git a/task-sdk/src/airflow/sdk/bases/decorator.py b/task-sdk/src/airflow/sdk/bases/decorator.py index 8634fbe99647c..5e2b9ea154790 100644 --- a/task-sdk/src/airflow/sdk/bases/decorator.py +++ b/task-sdk/src/airflow/sdk/bases/decorator.py @@ -28,14 +28,10 @@ import attr import typing_extensions -from airflow.sdk import TriggerRule, timezone +from airflow.sdk import TriggerRule from airflow.sdk.bases.operator import ( BASEOPERATOR_ARGS_EXPECTED_TYPES, BaseOperator, - coerce_resources, - coerce_timedelta, - get_merged_defaults, - parse_retries, ) from airflow.sdk.definitions._internal.contextmanager import DagContext, TaskGroupContext from airflow.sdk.definitions._internal.decorators import remove_task_decorator @@ -50,7 +46,6 @@ from airflow.sdk.definitions.context import KNOWN_CONTEXT_KEYS from airflow.sdk.definitions.mappedoperator import ( MappedOperator, - ensure_xcomarg_return_value, prevent_duplicates, ) from airflow.sdk.definitions.xcom_arg import XComArg @@ -64,6 +59,7 @@ from airflow.sdk.definitions.context import Context from airflow.sdk.definitions.dag import DAG from airflow.sdk.definitions.mappedoperator import ValidationSource + from airflow.sdk.definitions.partitionedoperator import DecoratedPartitionedOperator from airflow.sdk.definitions.taskgroup import TaskGroup @@ -585,108 +581,29 @@ def expand_kwargs(self, kwargs: OperatorExpandKwargsArgument, *, strict: bool = raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}") 
return self._expand(ListOfDictsExpandInput(kwargs), strict=strict) - def _expand(self, expand_input: ExpandInput, *, strict: bool) -> XComArg: - ensure_xcomarg_return_value(expand_input.value) - - task_kwargs = self.kwargs.copy() - dag = task_kwargs.pop("dag", None) or DagContext.get_current() - task_group = task_kwargs.pop("task_group", None) or TaskGroupContext.get_current(dag) - - default_args, partial_params = get_merged_defaults( - dag=dag, - task_group=task_group, - task_params=task_kwargs.pop("params", None), - task_default_args=task_kwargs.pop("default_args", None), + def _expand( + self, + expand_input: ExpandInput, + *, + strict: bool, + apply_upstream_relationship: bool = True, + ) -> XComArg: + operator = self.partition(size=0)._expand( + expand_input, strict=strict, apply_upstream_relationship=apply_upstream_relationship ) - partial_kwargs: dict[str, Any] = { - "is_setup": self.is_setup, - "is_teardown": self.is_teardown, - "on_failure_fail_dagrun": self.on_failure_fail_dagrun, - } - base_signature = inspect.signature(BaseOperator) - ignore = { - "default_args", # This is target we are working on now. - "kwargs", # A common name for a keyword argument. - "do_xcom_push", # In the same boat as `multiple_outputs` - "multiple_outputs", # We will use `self.multiple_outputs` instead. - "params", # Already handled above `partial_params`. - "task_concurrency", # Deprecated(replaced by `max_active_tis_per_dag`). - } - partial_keys = set(base_signature.parameters) - ignore - partial_kwargs.update({key: value for key, value in default_args.items() if key in partial_keys}) - partial_kwargs.update(task_kwargs) - - task_id = get_unique_task_id(partial_kwargs.pop("task_id"), dag, task_group) - if task_group: - task_id = task_group.child_id(task_id) - - # Logic here should be kept in sync with BaseOperatorMeta.partial(). 
- if partial_kwargs.get("wait_for_downstream"): - partial_kwargs["depends_on_past"] = True - start_date = timezone.convert_to_utc(partial_kwargs.pop("start_date", None)) - end_date = timezone.convert_to_utc(partial_kwargs.pop("end_date", None)) - if "pool_slots" in partial_kwargs: - if partial_kwargs["pool_slots"] < 1: - dag_str = "" - if dag: - dag_str = f" in dag {dag.dag_id}" - raise ValueError(f"pool slots for {task_id}{dag_str} cannot be less than 1") - - for fld, convert in ( - ("retries", parse_retries), - ("retry_delay", coerce_timedelta), - ("max_retry_delay", coerce_timedelta), - ("resources", coerce_resources), - ): - if (v := partial_kwargs.get(fld, NOTSET)) is not NOTSET: - partial_kwargs[fld] = convert(v) + return XComArg(operator=operator) - partial_kwargs.setdefault("executor_config", {}) - partial_kwargs.setdefault("op_args", []) - partial_kwargs.setdefault("op_kwargs", {}) + def iterate(self, **mapped_kwargs: OperatorExpandArgument) -> XComArg: + return self.partition(size=0).iterate(**mapped_kwargs) - # Mypy does not work well with a subclassed attrs class :( - _MappedOperator = cast("Any", DecoratedMappedOperator) + def iterate_kwargs(self, kwargs: OperatorExpandKwargsArgument, *, strict: bool = True) -> XComArg: + return self.partition(size=0).iterate_kwargs(kwargs, strict=strict) - try: - operator_name = self.operator_class.custom_operator_name # type: ignore - except AttributeError: - operator_name = self.operator_class.__name__ - - operator = _MappedOperator( - operator_class=self.operator_class, - expand_input=EXPAND_INPUT_EMPTY, # Don't use this; mapped values go to op_kwargs_expand_input. 
- partial_kwargs=partial_kwargs, - task_id=task_id, - params=partial_params, - operator_extra_links=self.operator_class.operator_extra_links, - template_ext=self.operator_class.template_ext, - template_fields=self.operator_class.template_fields, - template_fields_renderers=self.operator_class.template_fields_renderers, - ui_color=self.operator_class.ui_color, - ui_fgcolor=self.operator_class.ui_fgcolor, - is_empty=False, - is_sensor=self.operator_class._is_sensor, - can_skip_downstream=self.operator_class._can_skip_downstream, - task_module=self.operator_class.__module__, - task_type=self.operator_class.__name__, - operator_name=operator_name, - dag=dag, - task_group=task_group, - start_date=start_date, - end_date=end_date, - multiple_outputs=self.multiple_outputs, - python_callable=self.function, - op_kwargs_expand_input=expand_input, - disallow_kwargs_override=strict, - # Different from classic operators, kwargs passed to a taskflow - # task's expand() contribute to the op_kwargs operator argument, not - # the operator arguments themselves, and should expand against it. 
- expand_input_attr="op_kwargs_expand_input", - start_trigger_args=self.operator_class.start_trigger_args, - start_from_trigger=self.operator_class.start_from_trigger, - ) - return XComArg(operator=operator) + def partition(self, size: int) -> DecoratedPartitionedOperator: + """Return a DecoratedPartitionedOperator for partitioned mapping.""" + from airflow.sdk.definitions.partitionedoperator import DecoratedPartitionedOperator + + return DecoratedPartitionedOperator(operator_partial=self, size=size) def partial(self, **kwargs: Any) -> _TaskDecorator[FParams, FReturn, OperatorSubclass]: self._validate_arg_names("partial", kwargs) diff --git a/task-sdk/src/airflow/sdk/bases/operator.py b/task-sdk/src/airflow/sdk/bases/operator.py index 8d6de54eb6d3d..fb53cf7ceccdc 100644 --- a/task-sdk/src/airflow/sdk/bases/operator.py +++ b/task-sdk/src/airflow/sdk/bases/operator.py @@ -22,6 +22,7 @@ import collections.abc import contextlib import copy +import functools import inspect import sys import warnings @@ -65,7 +66,7 @@ from airflow.sdk.definitions.edges import EdgeModifier from airflow.sdk.definitions.mappedoperator import OperatorPartial, validate_mapping_kwargs from airflow.sdk.definitions.param import ParamsDict -from airflow.sdk.exceptions import RemovedInAirflow4Warning +from airflow.sdk.exceptions import RemovedInAirflow4Warning, TaskDeferred # Databases do not support arbitrary precision integers, so we need to limit the range of priority weights. 
# postgres: -2147483648 to +2147483647 (see https://www.postgresql.org/docs/current/datatype-numeric.html) @@ -221,6 +222,13 @@ def event_loop() -> Generator[AbstractEventLoop]: asyncio.set_event_loop(None) +# TODO: Once AIP-88 is implemented, multiple events could be returned +async def run_trigger(trigger: BaseTrigger) -> Any | None: + async for event in trigger.run(): + return event + return None + + class _PartialDescriptor: """A descriptor that guards against ``.partial`` being called on Task objects.""" @@ -375,8 +383,6 @@ def partial( partial_kwargs.update((k, v) for k, v in OPERATOR_DEFAULTS.items() if k not in partial_kwargs) # Post-process arguments. Should be kept in sync with _TaskDecorator.expand(). - if "task_concurrency" in kwargs: # Reject deprecated option. - raise TypeError("unexpected argument: task_concurrency") if start_date := partial_kwargs.get("start_date", None): partial_kwargs["start_date"] = timezone.convert_to_utc(start_date) if end_date := partial_kwargs.get("end_date", None): @@ -819,6 +825,8 @@ class derived from this one results in the creation of a task object, key in the returned dictionary result. If False and do_xcom_push is True, pushes a single XCom. :param task_group: The TaskGroup to which the task should belong. This is typically provided when not using a TaskGroup as a context manager. + :param task_concurrency: The maximum number of threads that will be used when the operator is used + with Dynamic Task Iteration (default is the number of threads available on the executor). 
:param doc: Add documentation or notes to your Task objects that is visible in Task Instance details View in the Webserver :param doc_md: Add documentation (in Markdown format) or notes to your Task objects @@ -1068,6 +1076,7 @@ def __init__( inlets: Any | None = None, outlets: Any | None = None, task_group: TaskGroup | None = None, + task_concurrency: int | None = None, doc: str | None = None, doc_md: str | None = None, doc_json: str | None = None, @@ -1091,6 +1100,7 @@ def __init__( super().__init__() self.task_group = task_group + self.task_concurrency = task_concurrency kwargs.pop("_airflow_mapped_validation_only", None) if kwargs: @@ -1537,6 +1547,7 @@ def get_serialized_fields(cls): "_BaseOperator__from_mapped", "on_failure_fail_dagrun", "task_group", + "task_concurrency", "_task_type", "operator_extra_links", "on_execute_callback", @@ -1669,12 +1680,10 @@ def defer( raise TaskDeferred(trigger=trigger, method_name=method_name, kwargs=kwargs, timeout=timeout) - def resume_execution(self, next_method: str, next_kwargs: dict[str, Any] | None, context: Context): + def next_callable(self, next_method: str, next_kwargs: dict[str, Any] | None): """Entrypoint method called by the Task Runner (instead of execute) when this task is resumed.""" from airflow.sdk.exceptions import TaskDeferralError, TaskDeferralTimeout - if next_kwargs is None: - next_kwargs = {} # __fail__ is a special signal value for next_method that indicates # this task was scheduled specifically to fail. 
@@ -1688,7 +1697,17 @@ def resume_execution(self, next_method: str, next_kwargs: dict[str, Any] | None, raise TaskDeferralError(error) # Grab the callable off the Operator/Task and add in any kwargs execute_callable = getattr(self, next_method) - return execute_callable(context, **next_kwargs) + if next_kwargs: + return functools.partial(execute_callable, **next_kwargs) + return execute_callable + + def resume_execution(self, next_method: str, next_kwargs: dict[str, Any] | None, context: Context): + """Entrypoint method called by the Task Runner (instead of execute) when this task is resumed.""" + if next_kwargs is None: + next_kwargs = {} + + execute_callable = self.next_callable(next_method, next_kwargs) + return execute_callable(context) def dry_run(self) -> None: """Perform dry run for the operator - just render template fields.""" @@ -1766,6 +1785,52 @@ def execute(self, context): return loop.run_until_complete(self.aexecute(context)) +class DecoratedDeferredAsyncOperator(BaseAsyncOperator): + """ + A decorator operator that wraps another deferred BaseOperator instance. + + Implements the async aexecute() method while delegating all other behavior. 
+ """ + + def __init__(self, *, operator: BaseOperator, task_deferred: TaskDeferred, **kwargs: Any): + super().__init__(task_id=operator.task_id, **kwargs) + self._operator = operator + self._task_deferred = task_deferred + + async def aexecute(self, context): + from airflow.sdk.execution_time.callback_runner import create_executable_runner + from airflow.sdk.execution_time.context import context_get_outlet_events + + while True: + event = await run_trigger(self._task_deferred.trigger) + + self.log.debug("event: %s", event) + + if not event: + return None + + self.log.debug("next_method: %s", self._task_deferred.method_name) + + if not self._task_deferred.method_name: + return None + + try: + next_method = self._operator.next_callable( + self._task_deferred.method_name, + self._task_deferred.kwargs, + ) + outlet_events = context_get_outlet_events(context) + runner = create_executable_runner( + func=next_method, + outlet_events=outlet_events, + logger=self.log, + ) + return runner.run(context, event.payload) + except TaskDeferred as task_deferred: + self._task_deferred = task_deferred + continue + + def chain(*tasks: DependencyMixin | Sequence[DependencyMixin]) -> None: r""" Given a number of tasks, builds a dependency chain. diff --git a/task-sdk/src/airflow/sdk/definitions/_internal/expandinput.py b/task-sdk/src/airflow/sdk/definitions/_internal/expandinput.py index b6ffbd2214253..6641dc95ae9ff 100644 --- a/task-sdk/src/airflow/sdk/definitions/_internal/expandinput.py +++ b/task-sdk/src/airflow/sdk/definitions/_internal/expandinput.py @@ -17,21 +17,20 @@ # under the License. 
from __future__ import annotations +from abc import ABC, abstractmethod from collections.abc import Iterable, Mapping, Sequence, Sized from typing import TYPE_CHECKING, Any, ClassVar, Union import attrs from airflow.sdk.definitions._internal.mixins import ResolveMixin +from airflow.sdk.definitions.xcom_arg import XComArg if TYPE_CHECKING: from typing import TypeGuard - from airflow.sdk.definitions.xcom_arg import XComArg from airflow.sdk.types import Operator -ExpandInput = Union["DictOfListsExpandInput", "ListOfDictsExpandInput"] - # Each keyword argument to expand() can be an XComArg, sequence, or dict (not # any mapping since we need the value to be ordered). OperatorExpandArgument = Union["MappedArgument", "XComArg", Sequence, dict[str, Any]] @@ -79,6 +78,66 @@ def _needs_run_time_resolution(v: OperatorExpandArgument) -> TypeGuard[MappedArg return isinstance(v, (MappedArgument, XComArg)) +class ExpandInput(ABC, ResolveMixin): + EXPAND_INPUT_TYPE: ClassVar[str] + + @property + @abstractmethod + def value(self) -> Any: + """The value of the expand input.""" + ... 
+ + def iter_values(self, context: Mapping[str, Any]) -> Iterable[Any]: + raise NotImplementedError() + + def resolve(self, context: Mapping[str, Any]) -> Any: + raise NotImplementedError() + + +class DecoratedExpandInput(ExpandInput): + EXPAND_INPUT_TYPE: ClassVar[str] = "decorated" + + def __init__(self, expand_input: ExpandInput): + self.delegate = expand_input + + @property + def value(self) -> Any: + return self.delegate.value + + def iter_references(self) -> Iterable[tuple[Operator, str]]: + return self.delegate.iter_references() + + def iter_values(self, context: Mapping[str, Any]) -> Iterable[dict]: + return map( + lambda value: {"op_kwargs": value}, + self.delegate.iter_values(context), + ) + + def resolve(self, context: Mapping[str, Any]) -> tuple[Mapping[str, Any], set[int]]: + return self.delegate.resolve(context) + + +class PartitionedExpandInput(DecoratedExpandInput): + """ + ExpandInput that partitions another ExpandInput into N chunks. + + This affects mapping cardinality, NOT resolve-time behavior. + """ + + EXPAND_INPUT_TYPE: ClassVar[str] = "partitioned" + + def __init__(self, expand_input: ExpandInput, size: int): + super().__init__(expand_input=expand_input) + self.size = size + + def iter_values(self, context: Mapping[str, Any]) -> Iterable[dict]: + map_index = context["ti"].map_index + + for index, item in enumerate(self.delegate.iter_values(context)): + if index % self.size == map_index: + yield item + + @attrs.define(kw_only=True) class MappedArgument(ResolveMixin): """ @@ -107,7 +166,7 @@ def resolve(self, context: Mapping[str, Any]) -> Any: @attrs.define() -class DictOfListsExpandInput(ResolveMixin): +class DictOfListsExpandInput(ExpandInput): """ Storage type of a mapped operator's mapped kwargs. 
@@ -184,6 +243,19 @@ def iter_references(self) -> Iterable[tuple[Operator, str]]: if isinstance(x, XComArg): yield from x.iter_references() + def iter_values(self, context: Mapping[str, Any]) -> Iterable[Any]: + from airflow.sdk.definitions.xcom_arg import XComArg + + resolved = {k: v.resolve(context) if isinstance(v, XComArg) else v for k, v in self.value.items()} + keys = list(resolved) + for items in zip( + *( + v if hasattr(v, "__iter__") and not isinstance(v, (str, bytes)) else (v,) + for v in (resolved[k] for k in keys) + ) + ): + yield dict(zip(keys, items)) + def resolve(self, context: Mapping[str, Any]) -> tuple[Mapping[str, Any], set[int]]: map_index: int | None = context["ti"].map_index if map_index is None or map_index < 0: @@ -217,7 +289,7 @@ def _describe_type(value: Any) -> str: @attrs.define() -class ListOfDictsExpandInput(ResolveMixin): +class ListOfDictsExpandInput(ExpandInput): """ Storage type of a mapped operator's mapped kwargs. @@ -238,12 +310,22 @@ def iter_references(self) -> Iterable[tuple[Operator, str]]: if isinstance(x, XComArg): yield from x.iter_references() + def iter_values(self, context: Mapping[str, Any]) -> Iterable[Any]: + if isinstance(self.value, XComArg): + for item in self.value.resolve(context): + yield item + else: + for item in self.value: + if isinstance(item, XComArg): + yield from item.resolve(context) + else: + yield item + def resolve(self, context: Mapping[str, Any]) -> tuple[Mapping[str, Any], set[int]]: map_index = context["ti"].map_index - if map_index < 0: + if map_index is None or map_index < 0: raise RuntimeError("can't resolve task-mapping argument without expanding") - mapping: Any = None if isinstance(self.value, Sized): mapping = self.value[map_index] if not isinstance(mapping, Mapping): diff --git a/task-sdk/src/airflow/sdk/definitions/_internal/mixins.py b/task-sdk/src/airflow/sdk/definitions/_internal/mixins.py index e186ef97e64dd..e47846499f9b8 100644 --- 
a/task-sdk/src/airflow/sdk/definitions/_internal/mixins.py +++ b/task-sdk/src/airflow/sdk/definitions/_internal/mixins.py @@ -135,6 +135,24 @@ def iter_references(self) -> Iterable[tuple[Operator, str]]: """ raise NotImplementedError + def iter_values(self, context: Context) -> Iterable[Any]: + """ + Yield individual values for task expansion during Dynamic Task Iteration. + + Called by :class:`~airflow.sdk.definitions.iterableoperator.IterableOperator` during + execution to enumerate all the items this expand input resolves to. Each yielded value + represents the keyword arguments passed to one mapped task instance. Any deferred XCom + references embedded in the expand input are resolved against the provided runtime + context before being yielded. + + :param context: The runtime task execution context used to resolve XCom references. + Must expose the current task instance under the ``ti`` key so that upstream + XCom values can be fetched. + + :meta private: + """ + raise NotImplementedError() + def resolve(self, context: Context) -> Any: """ Resolve this value for runtime. diff --git a/task-sdk/src/airflow/sdk/definitions/iterableoperator.py b/task-sdk/src/airflow/sdk/definitions/iterableoperator.py new file mode 100644 index 0000000000000..aca71253c2dbc --- /dev/null +++ b/task-sdk/src/airflow/sdk/definitions/iterableoperator.py @@ -0,0 +1,454 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import asyncio +import copy +import os +from collections import deque +from collections.abc import Iterable, Mapping, Sequence +from concurrent.futures import Future +from itertools import chain +from typing import TYPE_CHECKING, Any + +try: + # Python 3.12+ + from itertools import batched # type: ignore[attr-defined] +except ImportError: + from more_itertools import batched # type: ignore[no-redef] + +try: + # Python 3.11+ + BaseExceptionGroup +except NameError: + from exceptiongroup import BaseExceptionGroup + +from airflow.sdk import TaskInstanceState, timezone, BaseXCom +from airflow.sdk.bases.operator import BaseOperator, DecoratedDeferredAsyncOperator, event_loop +from airflow.sdk.definitions._internal.expandinput import PartitionedExpandInput +from airflow.sdk.definitions.mappedoperator import MappedOperator +from airflow.sdk.definitions.xcom_arg import MapXComArg, XComArg # noqa: F401 +from airflow.sdk.exceptions import ( + AirflowRescheduleTaskInstanceException, + AirflowTaskTimeout, + TaskDeferred, +) +from airflow.sdk.execution_time.executor import ConcurrentExecutor, TaskExecutor, collect_futures +from airflow.sdk.execution_time.task_runner import MappedTaskInstance + +if TYPE_CHECKING: + import jinja2 + + from airflow.sdk.definitions._internal.expandinput import ExpandInput + from airflow.sdk.definitions.context import Context + from airflow.sdk.execution_time.lazy_sequence import XComIterable + + +class IterableOperator(BaseOperator): + """Object representing an iterable operator in a DAG.""" + + 
_operator: MappedOperator + expand_input: ExpandInput + partial_kwargs: dict[str, Any] + shallow_copy_attrs: Sequence[str] = ( + "_operator", + "expand_input", + "partial_kwargs", + "_log", + ) + + def __init__( + self, + *, + operator: MappedOperator, + expand_input: ExpandInput, + **kwargs, + ): + super().__init__( + **{ + **kwargs, + "task_id": operator.task_id, + "owner": operator.owner, + "email": operator.email, + "email_on_retry": operator.email_on_retry, + "email_on_failure": operator.email_on_failure, + "retries": 0, # We should not retry the IterableOperator, only the mapped ti's should be retried + "retry_delay": operator.retry_delay, + "retry_exponential_backoff": operator.retry_exponential_backoff, + "max_retry_delay": operator.max_retry_delay, + "start_date": operator.start_date, + "end_date": operator.end_date, + "depends_on_past": operator.depends_on_past, + "ignore_first_depends_on_past": operator.ignore_first_depends_on_past, + "wait_for_past_depends_before_skipping": operator.wait_for_past_depends_before_skipping, + "wait_for_downstream": operator.wait_for_downstream, + "dag": operator.dag, + "priority_weight": operator.priority_weight, + "queue": operator.queue, + "pool": operator.pool, + "pool_slots": operator.pool_slots, + "execution_timeout": None, + "trigger_rule": operator.trigger_rule, + "resources": operator.resources, + "run_as_user": operator.run_as_user, + "map_index_template": operator.map_index_template, + "max_active_tis_per_dag": operator.max_active_tis_per_dag, + "max_active_tis_per_dagrun": operator.max_active_tis_per_dagrun, + "executor": operator.executor, + "executor_config": operator.executor_config, + "inlets": operator.inlets, + "outlets": operator.outlets, + "task_group": operator.task_group, + "doc": operator.doc, + "doc_md": operator.doc_md, + "doc_json": operator.doc_json, + "doc_yaml": operator.doc_yaml, + "doc_rst": operator.doc_rst, + "task_display_name": operator.task_display_name, + "allow_nested_operators": 
operator.allow_nested_operators, + } + ) + self._operator = operator + self.expand_input = expand_input + self.partial_kwargs = operator.partial_kwargs or {} + self.partial_kwargs.pop("partition_size", None) + self.max_workers = self.partial_kwargs.pop("task_concurrency", None) or os.cpu_count() or 1 + self._number_of_tasks: int = 0 + XComArg.apply_upstream_relationship(self, self.expand_input.value) + + @property + def task_type(self) -> str: + """@property: type of the task.""" + return self._operator.__class__.__name__ + + @property + def timeout(self) -> float | None: + if self.execution_timeout: + return self.execution_timeout.total_seconds() + return None + + def _do_render_template_fields( + self, + parent: Any, + template_fields: Iterable[str], + context: Context, + jinja_env: jinja2.Environment, + seen_oids: set[int], + ) -> None: + # IterableOperator doesn't need to render template fields as the actual operator's template fields + # will be rendered in the TaskExecutor when running each mapped task instance. 
+ pass + + def _get_specified_expand_input(self) -> ExpandInput: + return self.expand_input + + def _unmap_operator( + self, context: Context, mapped_kwargs: Context, jinja_env: jinja2.Environment + ) -> BaseOperator: + from airflow.sdk.execution_time.context import context_update_for_unmapped + + self._number_of_tasks += 1 + unmapped_task = self._operator.unmap(mapped_kwargs) + # Make sure deferred operators will always raise a DeferredTask exception when executed + unmapped_task.start_from_trigger = False + context_update_for_unmapped(context, unmapped_task) + + unmapped_task._do_render_template_fields( + parent=unmapped_task, + template_fields=self._operator.template_fields, + context=context, + jinja_env=jinja_env, + seen_oids=set(), + ) + return unmapped_task + + def _xcom_push(self, task: MappedTaskInstance, value: Any) -> None: + if task.xcom_pushed: + self.log.debug( + "XCom already pushed for task_id %s with index %s", + task.task_id, + task.index, + ) + else: + self.log.debug( + "Pushing XCom for task_id %s with index %s", + task.task_id, + task.index, + ) + + task.xcom_push(key=BaseXCom.XCOM_RETURN_KEY, value=value) + + def _run_tasks( + self, + context: Context, + tasks: Iterable[MappedTaskInstance], + ) -> XComIterable | None: + exceptions: list[BaseException] = [] + reschedule_date = timezone.utcnow() + prev_futures_count = 0 + futures: dict[Future | asyncio.futures.Future, MappedTaskInstance] = {} + deferred_tasks: deque[MappedTaskInstance] = deque() + failed_tasks: deque[MappedTaskInstance] = deque() + chunked_tasks = batched(tasks, self.max_workers) + do_xcom_push = True + + self.log.info("Running tasks with %d workers", self.max_workers) + + with event_loop() as loop: + with ConcurrentExecutor(loop=loop, max_workers=self.max_workers) as executor: + for task in next(chunked_tasks, []): + do_xcom_push = task.do_xcom_push + if task.is_async: + future = executor.submit(self._run_async_operator, context, task) + else: + future = 
executor.submit(self._run_operator, context, task) + futures[future] = task + + while futures: + futures_count = len(futures) + + if futures_count != prev_futures_count: + self.log.info("Number of remaining futures: %s", futures_count) + prev_futures_count = futures_count + + for future in collect_futures(loop, list(futures.keys())): + task = futures.pop(future) + + try: + if isinstance(future, asyncio.futures.Future): + result = future.result() + else: + result = future.result(timeout=self.timeout) + + self.log.debug("result: %s", result) + + if result is not None and task.do_xcom_push: + self._xcom_push( + task=task, + value=result, + ) + except TaskDeferred as task_deferred: + operator = DecoratedDeferredAsyncOperator( + operator=task.task, task_deferred=task_deferred + ) + # map_index is guaranteed to be int in MappedTaskInstance due to validation in __init__ + deferred_tasks.append( + self._create_mapped_task( + run_id=task.run_id, + index=task.index, + map_index=task.map_index, # type: ignore[arg-type] + try_number=task.try_number, + operator=operator, + ) + ) + except asyncio.TimeoutError as e: + self.log.warning("A timeout occurred for task_id %s", task.task_id) + if task.next_try_number > (self.retries or 0): + exceptions.append(AirflowTaskTimeout(e)) + else: + reschedule_date = min(reschedule_date, task.next_retry_datetime()) + failed_tasks.append(task) + except AirflowRescheduleTaskInstanceException as e: + reschedule_date = min(reschedule_date, e.reschedule_date) + self.log.exception( + "An exception occurred for task_id %s with index %s, it has been rescheduled at %s", + task.task_id, + task.index, + reschedule_date, + ) + failed_tasks.append(e.task) + except Exception as e: + self.log.exception( + "An exception occurred for task_id %s with index %s", + task.task_id, + task.index, + ) + exceptions.append(e) + + if len(futures) < self.max_workers: + chunked_tasks = chain(list(deferred_tasks), chunked_tasks) + deferred_tasks.clear() + + for task in 
next(chunked_tasks, []): + if task.is_async: + future = executor.submit(self._run_async_operator, context, task) + else: + future = executor.submit(self._run_operator, context, task) + futures[future] = task + + if not failed_tasks: + if exceptions: + raise BaseExceptionGroup("Multiple sub-task failures", exceptions) + if do_xcom_push: + from airflow.sdk.execution_time.lazy_sequence import XComIterable + + return XComIterable( + task_id=self.task_id, + dag_id=self.dag_id, + run_id=context["run_id"], + length=self._number_of_tasks, + map_index=context["ti"].map_index, + ) + return None + + # If the retry time is still in the future we defer the operator so the worker + # slot is released. If the retry time has already passed we immediately re-run + # the failed tasks without deferring. + if reschedule_date > timezone.utcnow(): + # TODO: This is tricky as that import doesn't exist in Task SDK + from airflow.providers.standard.triggers.temporal import DateTimeTrigger + + self.defer( + trigger=DateTimeTrigger(reschedule_date), + method_name=self.execute_failed_tasks.__name__, + kwargs={ + "failed_tasks": {failed_task.index for failed_task in failed_tasks}, + "try_number": next(iter(failed_tasks)).try_number, + }, + ) + + return self._run_tasks(context=context, tasks=list(failed_tasks)) + + def _run_operator(self, context: Context, task_instance: MappedTaskInstance): + with TaskExecutor(task_instance=task_instance) as executor: + return executor.run( + context={ + **dict(context), + **{ + "ti": task_instance, + "task_instance": task_instance, + }, + } + ) + + async def _run_async_operator(self, context: Context, task_instance: MappedTaskInstance): + async with TaskExecutor(task_instance=task_instance) as executor: + return await executor.arun( + context={ + **dict(context), + **{ + "ti": task_instance, + "task_instance": task_instance, + }, + } + ) + + def _create_task( + self, + context: Context, + index: int, + mapped_kwargs: Context, + jinja_env: jinja2.Environment, 
+ try_number: int = 0, + ) -> MappedTaskInstance: + run_id = context["ti"].run_id + map_index = context["ti"].map_index + operator = self._unmap_operator(context.copy(), mapped_kwargs, jinja_env) + return self._create_mapped_task( + run_id=run_id, map_index=map_index, index=index, try_number=try_number, operator=operator + ) + + def _create_mapped_task( + self, run_id: str, map_index: int | None, index: int, try_number: int, operator: BaseOperator + ) -> MappedTaskInstance: + return MappedTaskInstance.model_construct( + task_id=operator.task_id, + dag_id=operator.dag_id, + run_id=run_id, + map_index=map_index, + index=index, + max_tries=operator.retries, + start_date=self.start_date, + state=TaskInstanceState.SCHEDULED.value, + is_mapped=True, + task=operator, + try_number=try_number, + xcom_pushed=False, + ) + + def execute(self, context: Context): + jinja_env = self.get_template_env(dag=self.dag) + tasks = ( + self._create_task( + context=context, + index=index, + mapped_kwargs=value, + jinja_env=jinja_env, + ) + for index, value in enumerate(self.expand_input.iter_values(context=context)) + ) + return self._run_tasks(context=context, tasks=tasks) + + def execute_failed_tasks( + self, + context: Context, + try_number: int, + failed_tasks: set[int], + event: dict[Any, Any], + ): + jinja_env = self.get_template_env(dag=self.dag) + tasks = ( + self._create_task( + context=context, + index=index, + try_number=try_number, + jinja_env=jinja_env, + mapped_kwargs=value, + ) + for index, value in enumerate(self.expand_input.iter_values(context=context)) + if index in failed_tasks + ) + return self._run_tasks(context=context, tasks=tasks) + + +class MappedIterableOperator(MappedOperator): + """A thin wrapper around an existing MappedOperator that unmaps an MappedOperator within an IterableOperator.""" + + def __init__( + self, + mapped_operator: MappedOperator, + expand_input: ExpandInput, + partition_size: int, + ): + self.delegate = mapped_operator + 
self.delegate.partial_kwargs["partition_size"] = partition_size + self.expand_input = expand_input + self._apply_upstream_relationship = True + self.__attrs_post_init__() + + def __getattr__(self, name): + if name.startswith("__") and name.endswith("__"): + raise AttributeError(f"{type(self).__name__!r} object has no attribute {name!r}") + return getattr(self.delegate, name) + + def prepare_for_execution(self) -> MappedOperator: + return self + + @property + def partition_size(self) -> int: + return self.partial_kwargs.get("partition_size", 0) + + def __repr__(self): + return f"" + + def unmap(self, resolve: Mapping[str, Any]) -> BaseOperator: + return IterableOperator( + operator=copy.deepcopy(self.delegate), + expand_input=PartitionedExpandInput(self.expand_input, self.partition_size), + _airflow_from_mapped=True, + ) diff --git a/task-sdk/src/airflow/sdk/definitions/mappedoperator.py b/task-sdk/src/airflow/sdk/definitions/mappedoperator.py index 0faa2ab6f1850..0bb9b74a81263 100644 --- a/task-sdk/src/airflow/sdk/definitions/mappedoperator.py +++ b/task-sdk/src/airflow/sdk/definitions/mappedoperator.py @@ -21,7 +21,7 @@ import copy import warnings from collections.abc import Collection, Iterable, Iterator, Mapping, Sequence -from typing import TYPE_CHECKING, Any, Literal, TypeGuard +from typing import TYPE_CHECKING, Any, Literal, TypeGuard, cast import attrs import methodtools @@ -64,13 +64,15 @@ OperatorExpandArgument, OperatorExpandKwargsArgument, ) + from airflow.sdk.definitions.iterableoperator import IterableOperator from airflow.sdk.definitions.operator_resources import Resources from airflow.sdk.definitions.param import ParamsDict + from airflow.sdk.definitions.partitionedoperator import PartitionedOperator from airflow.sdk.definitions.retry_policy import RetryPolicy from airflow.sdk.types import WeightRuleParam from airflow.triggers.base import StartTriggerArgs -ValidationSource = Literal["expand"] | Literal["partial"] +ValidationSource = Literal["expand"] 
| Literal["iterate"] | Literal["partial"] def validate_mapping_kwargs(op: type[BaseOperator], func: ValidationSource, value: dict[str, Any]) -> None: @@ -213,67 +215,32 @@ def expand_kwargs(self, kwargs: OperatorExpandKwargsArgument, *, strict: bool = raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}") return self._expand(ListOfDictsExpandInput(kwargs), strict=strict) - def _expand(self, expand_input: ExpandInput, *, strict: bool) -> MappedOperator: - from airflow.providers.standard.operators.empty import EmptyOperator - from airflow.sdk import BaseSensorOperator - from airflow.sdk.bases.skipmixin import SkipMixin - - self._expand_called = True - ensure_xcomarg_return_value(expand_input.value) - - partial_kwargs = self.kwargs.copy() - task_id = partial_kwargs.pop("task_id") - dag = partial_kwargs.pop("dag") - task_group = partial_kwargs.pop("task_group") - start_date = partial_kwargs.pop("start_date", None) - end_date = partial_kwargs.pop("end_date", None) - start_from_trigger = ( - partial_kwargs["start_from_trigger"] - if "start_from_trigger" in partial_kwargs - else getattr(self.operator_class, "start_from_trigger", False) - ) - start_trigger_args = ( - partial_kwargs["start_trigger_args"] - if "start_trigger_args" in partial_kwargs - else getattr(self.operator_class, "start_trigger_args", None) + def _expand( + self, + expand_input: ExpandInput, + *, + strict: bool, + apply_upstream_relationship: bool = True, + ) -> MappedOperator: + return self.partition(size=0)._expand( + expand_input, strict=strict, apply_upstream_relationship=apply_upstream_relationship ) - try: - operator_name = self.operator_class.custom_operator_name # type: ignore - except AttributeError: - operator_name = self.operator_class.__name__ - - op = MappedOperator( - operator_class=self.operator_class, - expand_input=expand_input, - partial_kwargs=partial_kwargs, - task_id=task_id, - params=self.params, - 
operator_extra_links=self.operator_class.operator_extra_links, - template_ext=self.operator_class.template_ext, - template_fields=self.operator_class.template_fields, - template_fields_renderers=self.operator_class.template_fields_renderers, - ui_color=self.operator_class.ui_color, - ui_fgcolor=self.operator_class.ui_fgcolor, - is_empty=issubclass(self.operator_class, EmptyOperator), - is_sensor=issubclass(self.operator_class, BaseSensorOperator), - can_skip_downstream=issubclass(self.operator_class, SkipMixin), - task_module=self.operator_class.__module__, - task_type=self.operator_class.__name__, - operator_name=operator_name, - dag=dag, - task_group=task_group, - start_date=start_date, - end_date=end_date, - disallow_kwargs_override=strict, - # For classic operators, this points to expand_input because kwargs - # to BaseOperator.expand() contribute to operator arguments. - expand_input_attr="expand_input", - # TODO: Move these to task SDK's BaseOperator and remove getattr - start_trigger_args=start_trigger_args, - start_from_trigger=start_from_trigger, - ) - return op + def iterate(self, **mapped_kwargs: OperatorExpandArgument) -> IterableOperator: + operator = self.partition(size=0).iterate(**mapped_kwargs) + return cast("IterableOperator", operator) + + def iterate_kwargs( + self, kwargs: OperatorExpandKwargsArgument, *, strict: bool = True + ) -> IterableOperator: + operator = self.partition(size=0).iterate_kwargs(kwargs, strict=strict) + return cast("IterableOperator", operator) + + def partition(self, size: int) -> PartitionedOperator: + """Return a PartitionedOperator for partitioned mapping.""" + from airflow.sdk.definitions.partitionedoperator import PartitionedOperator + + return PartitionedOperator(operator_partial=self, size=size) @attrs.define( @@ -322,6 +289,7 @@ class MappedOperator(AbstractOperator): end_date: pendulum.DateTime | None upstream_task_ids: set[str] = attrs.field(factory=set, init=False) downstream_task_ids: set[str] = 
attrs.field(factory=set, init=False) + _apply_upstream_relationship: bool = attrs.field(alias="apply_upstream_relationship", default=True) _disallow_kwargs_override: bool """Whether execution fails if ``expand_input`` has duplicates to ``partial_kwargs``. @@ -343,19 +311,26 @@ def __repr__(self): return f"" def __attrs_post_init__(self): - from airflow.sdk.definitions.xcom_arg import XComArg - - if self.get_closest_mapped_task_group() is not None: - raise NotImplementedError("operator expansion in an expanded task group is not yet supported") - - if self.task_group: - self.task_group.add(self) - if self.dag: - self.dag.add_task(self) - XComArg.apply_upstream_relationship(self, self._get_specified_expand_input().value) - for k, v in self.partial_kwargs.items(): - if k in self.template_fields: - XComArg.apply_upstream_relationship(self, v) + # When _apply_upstream_relationship is False (i.e. IterableOperator), we intentionally + # skip the *entire* body — not just XComArg.apply_upstream_relationship. + # IterableOperator creates in-memory MappedOperator instances solely to drive task + # expansion; they must NOT be registered with the DAG or task group because Airflow + # treats the IterableOperator itself as the single real task instance in the DB. + # Calling dag.add_task() or task_group.add() here would raise duplicate-task errors. 
+ if self._apply_upstream_relationship: + from airflow.sdk.definitions.xcom_arg import XComArg + + if self.get_closest_mapped_task_group() is not None: + raise NotImplementedError("operator expansion in an expanded task group is not yet supported") + + if self.task_group: + self.task_group.add(self) + if self.dag: + self.dag.add_task(self) + XComArg.apply_upstream_relationship(self, self._get_specified_expand_input().value) + for k, v in self.partial_kwargs.items(): + if k in self.template_fields: + XComArg.apply_upstream_relationship(self, v) @methodtools.lru_cache(maxsize=None) @classmethod diff --git a/task-sdk/src/airflow/sdk/definitions/partitionedoperator.py b/task-sdk/src/airflow/sdk/definitions/partitionedoperator.py new file mode 100644 index 0000000000000..c4d29d432ff40 --- /dev/null +++ b/task-sdk/src/airflow/sdk/definitions/partitionedoperator.py @@ -0,0 +1,505 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
@attrs.define(kw_only=True, repr=False)
class PartitionableOperator(Generic[T], metaclass=ABCMeta):
    """
    Intermediate abstraction for partitioned mapping.

    This class decorates an OperatorPartial and stores partition configuration for partitioned mapping.
    It is used to facilitate partitioned expansion of operators, allowing tasks to be mapped over partitions
    of data and then iterate over the partitioned data.

    :param operator_partial: The partial operator to be partitioned.
    :param size: The number of partitions to create.
    """

    # T is the flavour of "partial" being wrapped: OperatorPartial for classic
    # operators, _TaskDecorator for TaskFlow tasks (see the two subclasses).
    operator_partial: T
    # Number of partitions; subclasses use size > 0 to choose a
    # MappedIterableOperator over a plain IterableOperator.
    size: int

    @property
    def operator_class(self) -> type[BaseOperator]:
        # Delegate to the wrapped partial; both T flavours expose operator_class.
        return self.operator_partial.operator_class

    @property
    def kwargs(self) -> dict[str, Any]:
        # The "partial" (non-mapped) keyword arguments captured at .partial() time.
        return self.operator_partial.kwargs

    @abstractmethod
    def iterate(self, **mapped_kwargs: OperatorExpandArgument) -> Any:
        """
        Iterate the operator over the provided mapped keyword arguments.

        :param mapped_kwargs: Keyword arguments to expand against.
        :return: An expanded operator or XComArg, depending on the subclass implementation.
        """

    @abstractmethod
    def iterate_kwargs(self, kwargs: OperatorExpandKwargsArgument, *, strict: bool = True) -> Any:
        """
        Iterate the operator over a list of dictionaries or XComArg.

        :param kwargs: List of dicts or XComArg to expand against.
        :param strict: Whether to enforce strict argument checking.
        :return: An expanded operator or XComArg, depending on the subclass implementation.
        """

    @abstractmethod
    def _iterate(
        self,
        expand_input: ExpandInput,
        *,
        strict: bool,
    ) -> IterableOperator | MappedIterableOperator:
        """
        Create an iterable operator for the given expansion input.

        This method calls the _expand method first to get a MappedOperator based on expansion input,
        then wraps it in either an IterableOperator or MappedIterableOperator depending on the partition size.

        :param expand_input: The input to iterate against.
        :param strict: Whether to enforce strict argument checking.
        :return: An IterableOperator or MappedIterableOperator.
        """

    @abstractmethod
    def _expand(
        self,
        expand_input: ExpandInput,
        *,
        strict: bool,
        apply_upstream_relationship: bool = True,
    ) -> MappedOperator:
        """
        Create a mapped operator for the given expansion input.

        :param expand_input: The input to expand against.
        :param strict: Whether to enforce strict argument checking.
        :param apply_upstream_relationship: Whether to apply upstream relationships.
        :return: A MappedOperator instance.
        """
@attrs.define(kw_only=True, repr=False)
class PartitionedOperator(PartitionableOperator[OperatorPartial]):
    """
    Concrete implementation of PartitionableOperator for classic (non-decorated) operators.

    This class wraps an OperatorPartial and provides partitioned expansion and iteration logic
    for classic Airflow operators. It enables mapping tasks over partitions of data, supporting
    both direct expansion via keyword arguments and expansion via a list of dictionaries or XComArg.

    :param operator_partial: The OperatorPartial instance to be partitioned and expanded.
    :param size: The number of partitions to create for mapping.
    """

    @property
    def params(self) -> ParamsDict | dict:
        return self.operator_partial.params

    @property
    def _expand_called(self) -> bool:
        # Delegate the "expand was called" bookkeeping to the wrapped partial so
        # it behaves exactly as a direct OperatorPartial.expand() call would.
        return self.operator_partial._expand_called

    @_expand_called.setter
    def _expand_called(self, value: bool) -> None:
        self.operator_partial._expand_called = value

    def iterate(self, **mapped_kwargs: OperatorExpandArgument) -> IterableOperator | MappedIterableOperator:
        """
        Iterate the operator over the provided mapped keyword arguments.

        :param mapped_kwargs: Keyword arguments to expand against.
        :raises TypeError: If no keyword arguments are provided.
        :return: An IterableOperator or MappedIterableOperator.
        """
        if not mapped_kwargs:
            raise TypeError("no arguments to iterate against")

        validate_mapping_kwargs(self.operator_class, "iterate", mapped_kwargs)
        prevent_duplicates(
            self.kwargs,
            mapped_kwargs,
            fail_reason="unmappable or already specified",
        )
        # Since the input is already checked at parse time, we can set strict
        # to False to skip the checks on execution.
        expand_input = DictOfListsExpandInput(mapped_kwargs)
        return self._iterate(expand_input, strict=False)

    def iterate_kwargs(
        self, kwargs: OperatorExpandKwargsArgument, *, strict: bool = True
    ) -> IterableOperator | MappedIterableOperator:
        """
        Iterate the operator over a list of dictionaries or XComArg.

        :param kwargs: List of dicts or XComArg to expand against.
        :param strict: Whether to enforce strict argument checking.
        :raises TypeError: If *kwargs* (or one of its items) is not of an accepted type.
        :return: An IterableOperator or MappedIterableOperator.
        """
        if isinstance(kwargs, Sequence):
            for item in kwargs:
                if not isinstance(item, (XComArg, Mapping)):
                    # BUGFIX: report the offending item's type, not the container's
                    # type, which previously produced messages like "not list".
                    raise TypeError(f"expected XComArg or list[dict], not {type(item).__name__}")
        elif not isinstance(kwargs, XComArg):
            raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}")

        expand_input = ListOfDictsExpandInput(kwargs)
        return self._iterate(expand_input, strict=strict)

    def _iterate(
        self,
        expand_input: ExpandInput,
        *,
        strict: bool,
    ) -> IterableOperator | MappedIterableOperator:
        """
        Create an iterable operator for the given expansion input.

        :param expand_input: The input to iterate against.
        :param strict: Whether to enforce strict argument checking.
        :return: A MappedIterableOperator when a positive partition size is
            configured, otherwise a plain IterableOperator.
        """
        # Local import to avoid a circular dependency with iterableoperator.
        from airflow.sdk.definitions.iterableoperator import IterableOperator, MappedIterableOperator

        # Build the MappedOperator without wiring upstream relationships; the
        # iterable wrapper owns that responsibility.
        operator = self._expand(expand_input, strict=strict, apply_upstream_relationship=False)

        if self.size > 0:
            return MappedIterableOperator(
                mapped_operator=operator,
                expand_input=expand_input,
                partition_size=self.size,
            )
        return IterableOperator(
            operator=operator,
            expand_input=expand_input,
        )

    def _expand(
        self,
        expand_input: ExpandInput,
        *,
        strict: bool,
        apply_upstream_relationship: bool = True,
    ) -> MappedOperator:
        """
        Create a mapped operator for the given expansion input.

        :param expand_input: The input to expand against.
        :param strict: Whether to enforce strict argument checking.
        :param apply_upstream_relationship: Whether to apply upstream relationships.
        :return: A MappedOperator instance.
        """
        # Local imports: provider packages must not be imported at module load time.
        from airflow.providers.standard.operators.empty import EmptyOperator
        from airflow.providers.standard.utils.skipmixin import SkipMixin
        from airflow.sdk import BaseSensorOperator

        self._expand_called = True
        ensure_xcomarg_return_value(expand_input.value)

        partial_kwargs = self.kwargs.copy()
        task_id = partial_kwargs.pop("task_id")
        dag = partial_kwargs.pop("dag")
        task_group = partial_kwargs.pop("task_group")
        start_date = partial_kwargs.pop("start_date", None)
        end_date = partial_kwargs.pop("end_date", None)
        # Prefer an explicitly supplied value, falling back to the class attribute.
        start_from_trigger = partial_kwargs.get(
            "start_from_trigger", getattr(self.operator_class, "start_from_trigger", False)
        )
        start_trigger_args = partial_kwargs.get(
            "start_trigger_args", getattr(self.operator_class, "start_trigger_args", None)
        )

        try:
            operator_name = self.operator_class.custom_operator_name  # type: ignore
        except AttributeError:
            operator_name = self.operator_class.__name__

        return MappedOperator(
            operator_class=self.operator_class,
            expand_input=expand_input,
            partial_kwargs=partial_kwargs,
            task_id=task_id,
            params=self.params,
            operator_extra_links=self.operator_class.operator_extra_links,
            template_ext=self.operator_class.template_ext,
            template_fields=self.operator_class.template_fields,
            template_fields_renderers=self.operator_class.template_fields_renderers,
            ui_color=self.operator_class.ui_color,
            ui_fgcolor=self.operator_class.ui_fgcolor,
            is_empty=issubclass(self.operator_class, EmptyOperator),
            is_sensor=issubclass(self.operator_class, BaseSensorOperator),
            can_skip_downstream=issubclass(self.operator_class, SkipMixin),
            task_module=self.operator_class.__module__,
            task_type=self.operator_class.__name__,
            operator_name=operator_name,
            dag=dag,
            task_group=task_group,
            start_date=start_date,
            end_date=end_date,
            disallow_kwargs_override=strict,
            # For classic operators, this points to expand_input because kwargs
            # to BaseOperator.expand() contribute to operator arguments.
            expand_input_attr="expand_input",
            start_from_trigger=start_from_trigger,
            start_trigger_args=start_trigger_args,
            apply_upstream_relationship=apply_upstream_relationship,
        )
@attrs.define(kw_only=True, repr=False)
class DecoratedPartitionedOperator(PartitionableOperator[_TaskDecorator]):
    """
    Concrete implementation of PartitionableOperator for decorated (TaskFlow) operators.

    This class wraps a _TaskDecorator and provides partitioned expansion and iteration logic
    for TaskFlow-style decorated Airflow operators. It enables mapping decorated tasks over
    partitions of data, returning XComArg objects for downstream dependencies and supporting
    both direct expansion via keyword arguments and expansion via a list of dictionaries or XComArg.

    :param operator_partial: The _TaskDecorator instance to be partitioned and expanded.
    :param size: The number of partitions to create for mapping.
    """

    @property
    def is_setup(self) -> bool:
        return self.operator_partial.is_setup

    @property
    def is_teardown(self) -> bool:
        return self.operator_partial.is_teardown

    @property
    def function(self) -> Callable[FParams, FReturn]:
        return self.operator_partial.function

    @property
    def operator_class(self) -> type[OperatorSubclass]:
        return self.operator_partial.operator_class

    @property
    def multiple_outputs(self) -> bool:
        return self.operator_partial.multiple_outputs

    @property
    def on_failure_fail_dagrun(self) -> bool:
        return self.operator_partial.on_failure_fail_dagrun

    def _validate_arg_names(self, func: ValidationSource, kwargs: dict[str, Any]):
        # Delegate argument-name validation to the wrapped decorator.
        self.operator_partial._validate_arg_names(func, kwargs)

    def iterate(self, **map_kwargs: OperatorExpandArgument) -> XComArg:
        """
        Iterate the decorated task over the provided mapped keyword arguments.

        :param map_kwargs: Keyword arguments to expand against.
        :raises ValueError: For task-generated mapping with trigger rule 'always',
            or an explicit trigger rule on a teardown task.
        :raises TypeError: If no keyword arguments are provided.
        :return: An XComArg referencing the iterable operator's output.
        """
        if self.kwargs.get("trigger_rule") == TriggerRule.ALWAYS and any(
            isinstance(expanded, XComArg) for expanded in map_kwargs.values()
        ):
            raise ValueError(
                "Task-generated iterating within a task using 'iterate' is not allowed with trigger rule 'always'."
            )
        if not map_kwargs:
            # BUGFIX: the message previously said "expand" although this API is iterate();
            # keep it consistent with PartitionedOperator.iterate.
            raise TypeError("no arguments to iterate against")
        # "expand" is the ValidationSource used for mapped kwargs validation.
        self._validate_arg_names("expand", map_kwargs)
        prevent_duplicates(self.kwargs, map_kwargs, fail_reason="mapping already partial")
        # Since the input is already checked at parse time, we can set strict
        # to False to skip the checks on execution.
        if self.is_teardown:
            if "trigger_rule" in self.kwargs:
                raise ValueError("Trigger rule not configurable for teardown tasks.")
            self.kwargs.update(trigger_rule=TriggerRule.ALL_DONE_SETUP_SUCCESS)
        expand_input = DictOfListsExpandInput(map_kwargs)
        operator = self._iterate(expand_input, strict=False)
        return XComArg(operator=operator)

    def iterate_kwargs(self, kwargs: OperatorExpandKwargsArgument, *, strict: bool = True) -> XComArg:
        """
        Iterate the decorated task over a list of dictionaries or XComArg.

        :param kwargs: List of dicts or XComArg to expand against.
        :param strict: Whether to enforce strict argument checking.
        :raises ValueError: For task-generated mapping with trigger rule 'always'.
        :raises TypeError: If *kwargs* (or one of its items) is not of an accepted type.
        :return: An XComArg referencing the iterable operator's output.
        """
        if (
            self.kwargs.get("trigger_rule") == TriggerRule.ALWAYS
            and not isinstance(kwargs, XComArg)
            and any(
                isinstance(v, XComArg)
                for kwarg in kwargs
                if not isinstance(kwarg, XComArg)
                for v in kwarg.values()
            )
        ):
            raise ValueError(
                "Task-generated iterating within a task using 'iterate_kwargs' is not allowed with trigger rule 'always'."
            )
        if isinstance(kwargs, Sequence):
            for item in kwargs:
                if not isinstance(item, (XComArg, Mapping)):
                    # BUGFIX: report the offending item's type, not the container's.
                    raise TypeError(f"expected XComArg or list[dict], not {type(item).__name__}")
        elif not isinstance(kwargs, XComArg):
            raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}")
        expand_input = ListOfDictsExpandInput(kwargs)
        operator = self._iterate(expand_input, strict=strict)
        return XComArg(operator=operator)

    def _iterate(
        self,
        expand_input: ExpandInput,
        *,
        strict: bool,
    ) -> IterableOperator | MappedIterableOperator:
        """
        Create an iterable operator for the given expansion input.

        :param expand_input: The input to iterate against.
        :param strict: Whether to enforce strict argument checking.
        :return: A MappedIterableOperator when a positive partition size is
            configured, otherwise a plain IterableOperator.
        """
        from airflow.sdk.definitions.iterableoperator import IterableOperator, MappedIterableOperator

        operator = self._expand(expand_input, strict=strict, apply_upstream_relationship=False)

        if self.size > 0:
            return MappedIterableOperator(
                mapped_operator=operator,
                # Taskflow expansion input is wrapped so it resolves through op_kwargs.
                expand_input=DecoratedExpandInput(expand_input),
                partition_size=self.size,
            )
        return IterableOperator(operator=operator, expand_input=DecoratedExpandInput(expand_input))

    def _expand(
        self,
        expand_input: ExpandInput,
        *,
        strict: bool,
        apply_upstream_relationship: bool = True,
    ) -> MappedOperator:
        """
        Create a DecoratedMappedOperator for the given expansion input.

        :param expand_input: The input to expand against.
        :param strict: Whether to enforce strict argument checking.
        :param apply_upstream_relationship: Whether to apply upstream relationships.
        :return: A DecoratedMappedOperator instance.
        """
        ensure_xcomarg_return_value(expand_input.value)

        task_kwargs = self.kwargs.copy()
        dag = task_kwargs.pop("dag", None) or DagContext.get_current()
        task_group = task_kwargs.pop("task_group", None) or TaskGroupContext.get_current(dag)

        default_args, partial_params = get_merged_defaults(
            dag=dag,
            task_group=task_group,
            task_params=task_kwargs.pop("params", None),
            task_default_args=task_kwargs.pop("default_args", None),
        )
        partial_kwargs: dict[str, Any] = {
            "is_setup": self.is_setup,
            "is_teardown": self.is_teardown,
            "on_failure_fail_dagrun": self.on_failure_fail_dagrun,
        }
        base_signature = inspect.signature(BaseOperator)
        ignore = {
            "default_args",  # This is the target we are working on now.
            "kwargs",  # A common name for a keyword argument.
            "do_xcom_push",  # In the same boat as `multiple_outputs`
            "multiple_outputs",  # We will use `self.multiple_outputs` instead.
            "params",  # Already handled above `partial_params`.
            "task_concurrency",  # Deprecated(replaced by `max_active_tis_per_dag`).
        }
        partial_keys = set(base_signature.parameters) - ignore
        partial_kwargs.update({key: value for key, value in default_args.items() if key in partial_keys})
        partial_kwargs.update(task_kwargs)

        task_id = get_unique_task_id(partial_kwargs.pop("task_id"), dag, task_group)
        if task_group:
            task_id = task_group.child_id(task_id)

        # Logic here should be kept in sync with BaseOperatorMeta.partial().
        if partial_kwargs.get("wait_for_downstream"):
            partial_kwargs["depends_on_past"] = True
        start_date = timezone.convert_to_utc(partial_kwargs.pop("start_date", None))
        end_date = timezone.convert_to_utc(partial_kwargs.pop("end_date", None))
        if "pool_slots" in partial_kwargs:
            if partial_kwargs["pool_slots"] < 1:
                dag_str = ""
                if dag:
                    dag_str = f" in dag {dag.dag_id}"
                raise ValueError(f"pool slots for {task_id}{dag_str} cannot be less than 1")

        # Normalize the user-supplied values the same way BaseOperator would.
        for fld, convert in (
            ("retries", parse_retries),
            ("retry_delay", coerce_timedelta),
            ("max_retry_delay", coerce_timedelta),
            ("resources", coerce_resources),
        ):
            if (v := partial_kwargs.get(fld, NOTSET)) is not NOTSET:
                partial_kwargs[fld] = convert(v)

        partial_kwargs.setdefault("executor_config", {})
        partial_kwargs.setdefault("op_args", [])
        partial_kwargs.setdefault("op_kwargs", {})

        try:
            operator_name = self.operator_class.custom_operator_name  # type: ignore
        except AttributeError:
            operator_name = self.operator_class.__name__

        return DecoratedMappedOperator(
            operator_class=self.operator_class,
            expand_input=EXPAND_INPUT_EMPTY,  # Don't use this; mapped values go to op_kwargs_expand_input.
            partial_kwargs=partial_kwargs,
            task_id=task_id,
            params=partial_params,
            operator_extra_links=self.operator_class.operator_extra_links,
            template_ext=self.operator_class.template_ext,
            template_fields=self.operator_class.template_fields,
            template_fields_renderers=self.operator_class.template_fields_renderers,
            ui_color=self.operator_class.ui_color,
            ui_fgcolor=self.operator_class.ui_fgcolor,
            is_empty=False,
            is_sensor=self.operator_class._is_sensor,
            can_skip_downstream=self.operator_class._can_skip_downstream,
            task_module=self.operator_class.__module__,
            task_type=self.operator_class.__name__,
            operator_name=operator_name,
            dag=dag,
            task_group=task_group,
            start_date=start_date,
            end_date=end_date,
            multiple_outputs=self.multiple_outputs,
            python_callable=self.function,
            op_kwargs_expand_input=expand_input,
            disallow_kwargs_override=strict,
            # Different from classic operators, kwargs passed to a taskflow
            # task's expand() contribute to the op_kwargs operator argument, not
            # the operator arguments themselves, and should expand against it.
            expand_input_attr="op_kwargs_expand_input",
            start_trigger_args=self.operator_class.start_trigger_args,
            start_from_trigger=self.operator_class.start_from_trigger,
            apply_upstream_relationship=apply_upstream_relationship,
        )
def iter_values(self, context: Mapping[str, Any]) -> Iterable[Any]:
    """
    Yield the resolved XCom value(s) for this arg.

    Iterables are unpacked element by element, except str/bytes/dict which
    are treated as a single value and yielded whole.
    """
    value = self.resolve(context)

    # str/bytes/dict are "atomic" here even though they are technically iterable.
    treat_as_single = isinstance(value, (str, bytes, dict)) or not isinstance(value, Iterable)
    if treat_as_single:
        yield value
    else:
        yield from value
class AirflowRescheduleTaskInstanceException(AirflowRescheduleException):
    """
    Raise when the task should be re-scheduled for a specific TaskInstance at a later time.

    :param task: The task instance that should be rescheduled
    """

    def __init__(self, task: MappedTaskInstance):
        # Reschedule at the instance's own computed retry time; this honours
        # exponential backoff and max_retry_delay via next_retry_datetime().
        super().__init__(reschedule_date=task.next_retry_datetime())
        # Keep a reference so the caller can requeue exactly this instance.
        self.task = task
+from __future__ import annotations + +import asyncio +import contextvars +import inspect +import logging +import time +from asyncio import AbstractEventLoop, Semaphore +from collections.abc import Callable, Generator +from concurrent.futures import Future, ThreadPoolExecutor, as_completed +from typing import TYPE_CHECKING, Any, cast + +from airflow.sdk import BaseAsyncOperator, BaseOperator, TaskInstanceState, timezone +from airflow.sdk.bases.operator import ExecutorSafeguard +from airflow.sdk.definitions._internal.logging_mixin import LoggingMixin +from airflow.sdk.exceptions import AirflowRescheduleTaskInstanceException, TaskDeferred +from airflow.sdk.execution_time.callback_runner import create_executable_runner +from airflow.sdk.execution_time.context import context_get_outlet_events +from airflow.sdk.execution_time.task_runner import ( + RuntimeTaskInstance, + _execute_task, + _run_task_state_change_callbacks, +) + +if TYPE_CHECKING: + from structlog.typing import FilteringBoundLogger as Logger + + from airflow.sdk import Context + from airflow.sdk.execution_time.task_runner import MappedTaskInstance + + +def collect_futures( + loop: AbstractEventLoop, futures: list[Any] +) -> Generator[Future | asyncio.futures.Future, None, None]: + """ + Yield futures as they complete (sync or async). + + :param loop: The asyncio event loop to use for async tasks + :param futures: List of Future or asyncio.futures.Future objects to collect + :return: Generator yielding Future or asyncio.futures.Future objects as they complete + """ + yield from as_completed(f for f in futures if isinstance(f, Future)) + + async_tasks = [f for f in futures if isinstance(f, asyncio.futures.Future)] + + if async_tasks: + for task, _ in zip( + async_tasks, + loop.run_until_complete(asyncio.gather(*async_tasks, return_exceptions=True)), + ): + yield task + + +class ConcurrentExecutor: + """ + Executes both sync and async functions concurrently. + + Sync functions run in a ThreadPoolExecutor. 
class ConcurrentExecutor:
    """
    Run plain callables and coroutines side by side.

    Synchronous callables are dispatched to a ThreadPoolExecutor, while
    coroutines are scheduled on the supplied asyncio event loop, throttled by
    a semaphore allowing at most ``max_workers`` concurrent awaits.
    """

    def __init__(self, loop: AbstractEventLoop, max_workers: int = 4):
        self._loop = loop
        self._semaphore = Semaphore(max_workers)
        self._thread_pool = ThreadPoolExecutor(max_workers=max_workers)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Block until all in-flight thread work has drained.
        if self._thread_pool:
            self._thread_pool.shutdown(wait=True)

    def submit(self, func: Callable, *args, **kwargs):
        """Submit work; returns a concurrent Future (sync) or an asyncio Task (async)."""
        is_coro_obj = inspect.iscoroutine(func)
        if not is_coro_obj and not inspect.iscoroutinefunction(func):
            # Plain callable: hand it straight to the thread pool.
            return self._thread_pool.submit(func, *args, **kwargs)

        # Either an already-created coroutine object, or a coroutine function
        # we must call to obtain one.
        awaitable = func if is_coro_obj else func(*args, **kwargs)

        async def _bounded():
            # The semaphore caps how many coroutines run at once.
            async with self._semaphore:
                return await awaitable

        return self._loop.create_task(_bounded())
class TaskExecutor(LoggingMixin):
    """Base class to run an operator or trigger with given task context and task instance."""

    def __init__(
        self,
        task_instance: MappedTaskInstance,
    ):
        super().__init__()
        # The single mapped instance this executor is responsible for.
        self.task_instance = task_instance
        self._result: Any | None = None
        # Set by __enter__; used by __exit__ to report elapsed wall time.
        self._start_time: float | None = None

    @property
    def dag_id(self) -> str:
        return self.task_instance.dag_id

    @property
    def task_id(self) -> str:
        return self.task_instance.task_id

    @property
    def task_index(self) -> int:
        # Position of this instance within the mapped/partitioned expansion.
        return self.task_instance.index

    @property
    def xcom_key(self):
        return self.task_instance.xcom_key

    @property
    def operator(self) -> BaseOperator:
        return self.task_instance.task

    @property
    def is_async(self) -> bool:
        # True when the underlying operator executes via aexecute().
        return self.task_instance.is_async

    def run(self, context: Context):
        # Synchronous execution path; delegates to the task runner's _execute_task.
        return _execute_task(context, self.task_instance, self.log)

    async def arun(self, context: Context):
        # Asynchronous execution path for BaseAsyncOperator subclasses.
        return await _execute_async_task(context, self.task_instance, self.log)

    def __enter__(self):
        # Start the attempt timer and announce the attempt.
        self._start_time = time.monotonic()

        if self.log.isEnabledFor(logging.INFO):
            self.log.info(
                "Attempting running task %s of %s for %s with index %s in %s mode.",
                self.task_instance.try_number,
                self.operator.retries,
                self.task_instance.task_id,
                self.task_index,
                "async" if self.is_async else "sync",
            )
        return self

    async def __aenter__(self):
        # Async context entry shares the sync bookkeeping.
        return self.__enter__()

    def __exit__(self, exc_type, exc_value, traceback):
        # Translate the attempt outcome into task-instance state:
        # - TaskDeferred propagates unchanged (handled by the caller);
        # - retries exhausted -> FAILED, original exception re-raised;
        # - otherwise -> UP_FOR_RESCHEDULE via AirflowRescheduleTaskInstanceException.
        elapsed = time.monotonic() - self._start_time if self._start_time else 0.0

        if exc_value:
            if not isinstance(exc_value, TaskDeferred):
                if self.task_instance.next_try_number > self.task_instance.max_tries:
                    self.log.error(
                        "Task instance %s for %s failed after %s attempts in %.2f seconds due to: %s",
                        self.task_index,
                        self.task_instance.task_id,
                        self.task_instance.max_tries,
                        elapsed,
                        exc_value,
                    )
                    self.task_instance.state = TaskInstanceState.FAILED
                    raise exc_value
                # Still retryable: bump the attempt counter and ask the caller
                # to reschedule this specific instance at its next retry time.
                self.task_instance.try_number += 1
                self.task_instance.end_date = timezone.utcnow()
                self.task_instance.state = TaskInstanceState.UP_FOR_RESCHEDULE
                raise AirflowRescheduleTaskInstanceException(task=self.task_instance)
            raise exc_value

        self.task_instance.state = TaskInstanceState.SUCCESS
        if self.log.isEnabledFor(logging.INFO):
            # NOTE(review): this reports next_try_number as the attempt count —
            # confirm the intended off-by-one versus try_number.
            self.log.info(
                "Task instance %s for %s finished successfully in %s attempts in %.2f seconds",
                self.task_index,
                self.task_instance.task_id,
                self.task_instance.next_try_number,
                elapsed,
            )

    async def __aexit__(self, exc_type, exc_value, traceback):
        # Async exit shares the sync outcome handling.
        self.__exit__(exc_type, exc_value, traceback)
async def _execute_async_task(context: Context, ti: RuntimeTaskInstance, log: Logger):
    """
    Execute Task (optionally with a Timeout) and push Xcom results.

    :param context: The task execution context.
    :param ti: The runtime task instance whose task must be a BaseAsyncOperator.
    :param log: Logger used by the hook runners.
    :raises asyncio.TimeoutError: When execution_timeout elapses; on_kill() is
        invoked on the task before re-raising.
    :return: The result of the task's aexecute().
    """
    # set-up
    task = cast("BaseAsyncOperator", ti.task)
    execute = task.aexecute  # here we must use aexecute instead of execute

    # async tasks can't originate from deferred operator, so no need to check next_method

    ctx = contextvars.copy_context()
    # Populate the context var so ExecutorSafeguard doesn't complain
    ctx.run(ExecutorSafeguard.tracker.set, task)

    outlet_events = context_get_outlet_events(context)

    if (pre_execute_hook := task._pre_execute_hook) is not None:
        create_executable_runner(pre_execute_hook, outlet_events, logger=log).run(context)
    if getattr(pre_execute_hook := task.pre_execute, "__func__", None) is not BaseOperator.pre_execute:
        create_executable_runner(pre_execute_hook, outlet_events, logger=log).run(context)

    _run_task_state_change_callbacks(task, "on_execute_callback", context, log)

    async def _run_in_context(coro_func, *args, **kwargs):
        """Run async function in contextvars context with optional timeout."""
        # The coroutine must be *created* inside the copied context so the
        # ExecutorSafeguard tracker is visible to it.
        coro_in_ctx = ctx.run(lambda: coro_func(*args, **kwargs))

        if task.execution_timeout:
            return await asyncio.wait_for(coro_in_ctx, timeout=task.execution_timeout.total_seconds())
        return await coro_in_ctx

    try:
        result = await _run_in_context(execute, context=context)
    except asyncio.TimeoutError:
        # Give the operator a chance to clean up before the timeout propagates.
        task.on_kill()
        raise

    if (post_execute_hook := task._post_execute_hook) is not None:
        create_executable_runner(post_execute_hook, outlet_events, logger=log).run(context, result)
    if getattr(post_execute_hook := task.post_execute, "__func__", None) is not BaseOperator.post_execute:
        # BUGFIX: an overridden post_execute receives the execute result too
        # (post_execute(context, result)); previously only context was passed,
        # diverging from the hook call just above and the sync task runner.
        create_executable_runner(post_execute_hook, outlet_events, logger=log).run(context, result)

    return result
XCom.deserialize_value(_XComWrapper(msg.root)) +class XComIterable(Sequence): + """An iterable that lazily fetches XCom values one by one instead of loading all at once.""" + + def __init__(self, task_id: str, dag_id: str, run_id: str, length: int, map_index: int | None = None): + self.task_id = task_id + self.dag_id = dag_id + self.run_id = run_id + self.length = length + self.map_index = map_index + + def __iter__(self) -> Iterator: + return _XComIterator(self) + + def __len__(self) -> int: + return self.length + + @overload + def __getitem__(self, key: int) -> Any: ... + + @overload + def __getitem__(self, key: slice) -> Sequence[Any]: ... + + def __getitem__(self, key: int | slice) -> Any | Sequence[Any]: + """Allow direct indexing so this works like a sequence.""" + if isinstance(key, slice): + start, stop, step = key.indices(len(self)) + return [self[i] for i in range(start, stop, step)] + + if not (0 <= key < self.length): + raise IndexError(key) + + return XCom.get_one( + key=f"{BaseXCom.XCOM_RETURN_KEY}_{key}", + dag_id=self.dag_id, + task_id=self.task_id, + run_id=self.run_id, + map_index=self.map_index, + ) + + def serialize(self) -> dict: + """Ensure the object is JSON serializable.""" + return { + "task_id": self.task_id, + "dag_id": self.dag_id, + "run_id": self.run_id, + "length": self.length, + "map_index": self.map_index, + } + + @classmethod + def deserialize(cls, data: dict, version: int): + """Ensure the object is JSON deserializable.""" + return XComIterable(**data) + + +class _XComIterator: + """Iterator for XComIterable.""" + + def __init__(self, iterable: XComIterable): + self._iterable = iterable + self._index = 0 + + def __iter__(self): + return self + + def __next__(self): + if self._index >= len(self._iterable): + raise StopIteration + + value = self._iterable[self._index] + self._index += 1 + return value + + def _coerce_slice_index(value: Any) -> int | None: """ Check slice attribute's type and convert it to int. 
class MappedTaskInstance(RuntimeTaskInstance, LoggingMixin):
    """Mapped task instance to run an operator which handles XCom's in memory."""

    # Position of this instance within the mapped/partitioned expansion; all
    # XCom keys for this instance are suffixed with it.
    index: int
    # Flipped to True once a return-value XCom has been pushed.
    xcom_pushed: bool = Field(default=False)

    def __init__(self, /, **data: Any):
        super().__init__(**data)

        # A valid expansion slot is required for key suffixing to make sense.
        if self.index is None or self.index < 0:
            raise ValueError("MappedTaskInstance requires index >= 0")

    def xcom_pull(
        self,
        task_ids: str | Iterable[str] | None = None,
        dag_id: str | None = None,
        key: str = BaseXCom.XCOM_RETURN_KEY,
        include_prior_dates: bool = False,
        *,
        map_indexes: int | Iterable[int] | None | ArgNotSet = NOTSET,
        default: Any = None,
        run_id: str | None = None,
    ) -> Any:
        # Pull under the index-suffixed key, mirroring xcom_push below.
        return super().xcom_pull(
            task_ids=task_ids,
            dag_id=dag_id,
            key=f"{key}_{self.index}",
            include_prior_dates=include_prior_dates,
            map_indexes=map_indexes,
            default=default,
            run_id=run_id,
        )

    def xcom_push(
        self,
        key: str,
        value: Any,
    ):
        # Push under the index-suffixed key so each expansion slot has its own entry.
        super().xcom_push(key=f"{key}_{self.index}", value=value)
        if key == BaseXCom.XCOM_RETURN_KEY:
            self.xcom_pushed = True

    def next_retry_datetime(self):
        """
        Get datetime of the next retry if the task instance fails.

        For exponential backoff, retry_delay is used as base and will be converted to seconds.
        """
        from airflow.sdk.definitions._internal.abstractoperator import MAX_RETRY_DELAY

        delay = self.task.retry_delay
        if self.task.retry_exponential_backoff:
            try:
                # If the min_backoff calculation is below 1, it will be converted to 0 via int. Thus,
                # we must round up prior to converting to an int, otherwise a divide by zero error
                # will occur in the modded_hash calculation.
                # this probably gives unexpected results if a task instance has previously been cleared,
                # because try_number can increase without bound
                min_backoff = math.ceil(delay.total_seconds() * (2 ** (self.try_number - 1)))
            except OverflowError:
                min_backoff = MAX_RETRY_DELAY
                self.log.warning(
                    "OverflowError occurred while calculating min_backoff, using MAX_RETRY_DELAY for min_backoff."
                )

            # In the case when delay.total_seconds() is 0, min_backoff will not be rounded up to 1.
            # To address this, we impose a lower bound of 1 on min_backoff. This effectively makes
            # the ceiling function unnecessary, but the ceiling function was retained to avoid
            # introducing a breaking change.
            if min_backoff < 1:
                min_backoff = 1

            # deterministic per task instance
            ti_hash = int(
                hashlib.sha1(
                    f"{self.dag_id}#{self.task_id}#{self.logical_date}#{self.try_number}".encode(),
                    usedforsecurity=False,
                ).hexdigest(),
                16,
            )
            # between 1 and 1.0 * delay * (2^retry_number)
            modded_hash = min_backoff + ti_hash % min_backoff
            # timedelta has a maximum representable value. The exponentiation
            # here means this value can be exceeded after a certain number
            # of tries (around 50 if the initial delay is 1s, even fewer if
            # the delay is larger). Cap the value here before creating a
            # timedelta object so the operation doesn't fail with "OverflowError".
            delay_backoff_in_seconds = min(modded_hash, MAX_RETRY_DELAY)
            delay = timedelta(seconds=delay_backoff_in_seconds)
            if self.task.max_retry_delay:
                delay = min(self.task.max_retry_delay, delay)
        return self.end_date + delay

    @property
    def is_async(self) -> bool:
        # True when the underlying operator runs through the async execution path.
        return self.task.is_async

    @property
    def next_try_number(self) -> int:
        return self.try_number + 1

    @property
    def do_xcom_push(self) -> bool:
        return self.task.do_xcom_push
+ + On each call to ``execute_complete``, it can either return a result or + raise ``TaskDeferred`` again (up to ``max_deferrals`` times). + """ + + template_fields = () + + def __init__(self, *, max_deferrals: int = 0, final_result: str = "done", **kwargs): + super().__init__(**kwargs) + self.max_deferrals = max_deferrals + self.final_result = final_result + self._deferral_count = 0 + + def execute(self, context): + self.defer( + trigger=MockTrigger(payload="initial"), + method_name="execute_complete", + ) + + def execute_complete(self, context, event=None): + from airflow.sdk.exceptions import TaskDeferred + + self._deferral_count += 1 + if self._deferral_count <= self.max_deferrals: + raise TaskDeferred( + trigger=MockTrigger(payload=f"deferred_{self._deferral_count}"), + method_name="execute_complete", + ) + return self.final_result + + +class TestDecoratedDeferredAsyncOperator: + """Tests for DecoratedDeferredAsyncOperator.aexecute.""" + + @staticmethod + def _make_operator(task_id="test_task", max_deferrals=0, final_result="done"): + """Create a DecoratedDeferredAsyncOperator wrapping a MockDeferredOperator.""" + from airflow.sdk.exceptions import TaskDeferred + + with DAG(dag_id="test_dag"): + inner = MockDeferredOperator( + task_id=task_id, + max_deferrals=max_deferrals, + final_result=final_result, + ) + + task_deferred = TaskDeferred( + trigger=MockTrigger(payload="initial"), + method_name="execute_complete", + ) + + return DecoratedDeferredAsyncOperator(operator=inner, task_deferred=task_deferred) + + @pytest.mark.asyncio + async def test_single_deferral_returns_result(self): + """A single deferral cycle should return the final result from execute_complete.""" + operator = self._make_operator(final_result="success") + result = await operator.aexecute(context={}) + assert result == "success" + + @pytest.mark.asyncio + async def test_returns_none_when_trigger_yields_no_event(self): + """When the trigger yields nothing, aexecute should return None.""" + 
from airflow.sdk.exceptions import TaskDeferred + + with DAG(dag_id="test_dag"): + inner = MockDeferredOperator(task_id="no_event_task") + + # Use a trigger whose run() yields nothing + empty_trigger = MockTrigger(payload=None) + task_deferred = TaskDeferred(trigger=empty_trigger, method_name="execute_complete") + operator = DecoratedDeferredAsyncOperator(operator=inner, task_deferred=task_deferred) + + result = await operator.aexecute(context={}) + assert result is None + + @pytest.mark.asyncio + async def test_returns_none_when_method_name_is_none(self): + """When method_name is falsy, aexecute should return None even if trigger fires.""" + from airflow.sdk.exceptions import TaskDeferred + + with DAG(dag_id="test_dag"): + inner = MockDeferredOperator(task_id="no_method_task") + + task_deferred = TaskDeferred(trigger=MockTrigger(payload="event"), method_name="") + operator = DecoratedDeferredAsyncOperator(operator=inner, task_deferred=task_deferred) + + result = await operator.aexecute(context={}) + assert result is None + + @pytest.mark.asyncio + async def test_multiple_consecutive_deferrals(self): + """Consecutive deferrals should be handled iteratively, not recursively.""" + operator = self._make_operator(max_deferrals=5, final_result="after_5") + result = await operator.aexecute(context={}) + assert result == "after_5" + + @pytest.mark.asyncio + async def test_many_deferrals_do_not_cause_recursion_error(self): + """A large number of deferrals must not blow the stack (no unbounded recursion).""" + operator = self._make_operator(max_deferrals=200, final_result="survived") + # This would hit RecursionError with the old recursive implementation + result = await operator.aexecute(context={}) + assert result == "survived" + + @pytest.mark.asyncio + async def test_deferral_updates_task_deferred_state(self): + """Each deferral should update _task_deferred on the operator.""" + operator = self._make_operator(max_deferrals=2, final_result="final") + result = await 
operator.aexecute(context={}) + assert result == "final" + # After completion, _task_deferred should reflect the last deferral's trigger + assert operator._task_deferred.trigger._payload == "deferred_2" + + @pytest.mark.asyncio + @pytest.mark.parametrize( + ("max_deferrals", "expected"), + [ + (0, "immediate"), + (1, "after_one"), + (3, "after_three"), + ], + ids=["no-redeferral", "one-redeferral", "three-redeferrals"], + ) + async def test_parametrized_deferral_counts(self, max_deferrals, expected): + """Varying numbers of deferrals should all resolve correctly.""" + operator = self._make_operator(max_deferrals=max_deferrals, final_result=expected) + result = await operator.aexecute(context={}) + assert result == expected diff --git a/task-sdk/tests/task_sdk/definitions/_internal/test_expandinput.py b/task-sdk/tests/task_sdk/definitions/_internal/test_expandinput.py new file mode 100644 index 0000000000000..886a87b20b449 --- /dev/null +++ b/task-sdk/tests/task_sdk/definitions/_internal/test_expandinput.py @@ -0,0 +1,52 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +import pytest +from task_sdk.definitions.conftest import make_xcom_arg + +from airflow.sdk.definitions._internal.expandinput import DictOfListsExpandInput, ListOfDictsExpandInput + + +class TestExpandInput: + @pytest.mark.parametrize( + ("actual", "expected"), + [ + ({"a": 1}, [{"a": 1}]), + ({"a": [1, 2, 3]}, [{"a": 1}, {"a": 2}, {"a": 3}]), + ({"a": "hello"}, [{"a": "hello"}]), + ({"a": [1, 2], "b": [10, 20]}, [{"a": 1, "b": 10}, {"a": 2, "b": 20}]), + ({"a": (x for x in [1, 2])}, [{"a": 1}, {"a": 2}]), + ({"a": make_xcom_arg([1, 2])}, [{"a": 1}, {"a": 2}]), + ], + ) + def test_dict_of_lists_expand_input_iter_values(self, actual, expected): + result = list(DictOfListsExpandInput(actual).iter_values({})) + assert result == expected + + @pytest.mark.parametrize( + ("actual", "expected"), + [ + ([{"a": 1}, {"a": 2}], [{"a": 1}, {"a": 2}]), + ([{"a": 1, "b": 2}], [{"a": 1, "b": 2}]), + ([], []), + ([make_xcom_arg([{"a": 1}, {"a": 2}])], [{"a": 1}, {"a": 2}]), # XComArg input + ], + ) + def test_list_of_dicts_expand_input_iter_values(self, actual, expected): + result = list(ListOfDictsExpandInput(actual).iter_values({})) + assert result == expected diff --git a/task-sdk/tests/task_sdk/definitions/conftest.py b/task-sdk/tests/task_sdk/definitions/conftest.py index 3f89f34b4d2da..e5cefa1e3e8ac 100644 --- a/task-sdk/tests/task_sdk/definitions/conftest.py +++ b/task-sdk/tests/task_sdk/definitions/conftest.py @@ -17,11 +17,12 @@ # under the License. 
from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any import pytest import structlog +from airflow.sdk import BaseOperator, XComArg from airflow.sdk.execution_time.comms import SucceedTask, TaskState if TYPE_CHECKING: @@ -47,3 +48,10 @@ def run(dag: DAG, task_id: str, map_index: int): raise RuntimeError("Unable to find call to TaskState") return run + + +def make_xcom_arg(values: Any) -> XComArg: + op = BaseOperator(task_id="upstream") + xcom_arg = XComArg(op) + xcom_arg.resolve = lambda *a, **kw: values + return xcom_arg diff --git a/task-sdk/tests/task_sdk/definitions/test_iterableoperator.py b/task-sdk/tests/task_sdk/definitions/test_iterableoperator.py new file mode 100644 index 0000000000000..ed98861095be1 --- /dev/null +++ b/task-sdk/tests/task_sdk/definitions/test_iterableoperator.py @@ -0,0 +1,464 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +import copy +from typing import TYPE_CHECKING + +try: + # Python 3.11+ + BaseExceptionGroup +except NameError: + from exceptiongroup import BaseExceptionGroup + +import pytest + +from airflow.sdk import DAG, BaseOperator, BaseXCom +from airflow.sdk.definitions._internal.abstractoperator import DEFAULT_RETRIES +from airflow.sdk.definitions._internal.expandinput import DictOfListsExpandInput, ListOfDictsExpandInput +from airflow.sdk.definitions.iterableoperator import IterableOperator +from airflow.sdk.execution_time.xcom import XCom + +from tests_common.test_utils.mock_context import mock_context + +if TYPE_CHECKING: + from airflow.sdk.definitions._internal.expandinput import ExpandInput + from airflow.sdk.definitions.mappedoperator import MappedOperator + + from tests_common.test_utils.compat import Context + + +class MockOperator(BaseOperator): + """Mock operator for testing IterableOperator expansion.""" + + def __init__(self, arg1=None, arg2=None, arg3=None, fail_on_first_attempt=False, **kwargs): + super().__init__(**kwargs) + self.arg1 = arg1 + self.arg2 = arg2 + self.arg3 = arg3 + self.fail_on_first_attempt = fail_on_first_attempt + + def execute(self, context): + """Execute the operator and return passed arguments as tuple if do_xcom_push is True.""" + expected = copy.deepcopy(context) + + if self.fail_on_first_attempt: + self.fail_on_first_attempt = False + raise RuntimeError + if not self.do_xcom_push: + return None + result = self.arg1, self.arg2, self.arg3 + + assert context == expected, "Context was unexpectedly mutated during task execution" + return result + + +@pytest.fixture +def mock_xcom_get_one(monkeypatch: pytest.MonkeyPatch): + """ + Fixture that mocks XCom.get_one using monkeypatch for proper cleanup. + + Captures values pushed via context["ti"].xcom_push in the order they arrive, + then serves them back by index when XComIterable calls + XCom.get_one(key=f"{task_id}_{idx}", ...). 
+ """ + + def _mock_xcom(context: Context): + pushed_values: list = [] + + original_push = context["ti"].xcom_push + + def capturing_push(key: str, value, **kwargs) -> None: + pushed_values.append(value) + original_push(key=key, value=value, **kwargs) + + monkeypatch.setattr(context["ti"], "xcom_push", capturing_push) + + task_id = context["ti"].task_id + + def mock_get_one(**kwargs): + key = kwargs.get("key", "") + prefix = f"{task_id}_" + if key.startswith(prefix): + try: + idx = int(key[len(prefix):]) + if 0 <= idx < len(pushed_values): + return pushed_values[idx] + except (ValueError, TypeError): + pass + return None + + monkeypatch.setattr(XCom, "get_one", mock_get_one) + + return _mock_xcom + + +class TestIterableOperator: + @classmethod + def create_mapped_operator( + cls, + dag: DAG, + expand_input: ExpandInput, + task_id: str = "my_task", + retries: int = DEFAULT_RETRIES, + do_xcom_push: bool = True, + task_concurrency: int | None = None, + ) -> MappedOperator: + """ + Create a MappedOperator and assign it to a DAG. 
+ + :param expand_input: The input to expand + :param dag: The DAG to assign the operator to + :param task_id: Task ID for the operator + :param do_xcom_push: Whether to push XCom (default True) + """ + return MockOperator.partial( + task_id=task_id, + dag=dag, + retries=retries, + task_concurrency=task_concurrency, + do_xcom_push=do_xcom_push, + )._expand( + expand_input, + strict=True, + apply_upstream_relationship=False, + ) + + @classmethod + def create_iterable_operator( + cls, + dag: DAG, + expand_input: ExpandInput, + task_id: str = "my_task", + task_concurrency: int | None = None, + retries: int = DEFAULT_RETRIES, + do_xcom_push: bool = True, + ) -> IterableOperator: + """Create an IterableOperator with a MappedOperator and ExpandInput.""" + mapped_op = cls.create_mapped_operator( + dag=dag, + expand_input=expand_input, + task_id=task_id, + retries=retries, + do_xcom_push=do_xcom_push, + task_concurrency=task_concurrency, + ) + return IterableOperator( + operator=mapped_op, + expand_input=expand_input, + dag=dag, + ) + + @pytest.mark.db_test + @pytest.mark.parametrize( + ("actual", "expected"), + [ + ([{"a": 1}, {"a": 2}], [{"a": 1}, {"a": 2}]), + ([{"a": 1, "b": 2}], [{"a": 1, "b": 2}]), + ([], []), + ], + ) + def test_list_of_dicts_expand_input_iter_values(self, dag_maker, session, actual, expected): + """Test IterableOperator with ListOfDictsExpandInput expand_input.""" + if not actual: + pytest.skip("Empty list case tested separately") + + with dag_maker(session=session) as dag: + expand_input = ListOfDictsExpandInput(actual) + iterable_op = self.create_iterable_operator(dag, expand_input) + + result = list(iterable_op.expand_input.iter_values({})) + assert result == expected + + @pytest.mark.db_test + def test_list_of_dicts_empty(self, dag_maker, session): + """Test IterableOperator with empty list.""" + with dag_maker(session=session) as dag: + expand_input = ListOfDictsExpandInput([]) + iterable_op = self.create_iterable_operator(dag, expand_input) + 
+ result = list(iterable_op.expand_input.iter_values({})) + assert result == [] + + @pytest.mark.db_test + @pytest.mark.parametrize( + ("actual", "expected"), + [ + ({"a": 1}, [{"a": 1}]), + ({"a": [1, 2, 3]}, [{"a": 1}, {"a": 2}, {"a": 3}]), + ({"a": "hello"}, [{"a": "hello"}]), + ({"a": [1, 2], "b": [10, 20]}, [{"a": 1, "b": 10}, {"a": 2, "b": 20}]), + ({"a": [1, 2]}, [{"a": 1}, {"a": 2}]), # Convert generator to list for testing + ], + ) + def test_dict_of_lists_expand_input_iter_values(self, dag_maker, session, actual, expected): + """Test IterableOperator with DictOfListsExpandInput expand_input.""" + with dag_maker(session=session) as dag: + expand_input = DictOfListsExpandInput(actual) + iterable_op = self.create_iterable_operator(dag, expand_input) + + result = list(iterable_op.expand_input.iter_values({})) + assert result == expected + + @pytest.mark.db_test + def test_task_type(self, dag_maker, session): + """Test that IterableOperator correctly reports task_type.""" + with dag_maker(session=session) as dag: + expand_input = ListOfDictsExpandInput([{"a": 1}]) + iterable_op = self.create_iterable_operator(dag, expand_input) + + assert isinstance(iterable_op, IterableOperator) + assert iterable_op.task_type == "MappedOperator" + + @pytest.mark.db_test + def test_task_id(self, dag_maker, session): + """Test that IterableOperator inherits task_id from operator.""" + with dag_maker(session=session) as dag: + task_id = "my_task" + expand_input = ListOfDictsExpandInput([{"a": 1}]) + iterable_op = self.create_iterable_operator(dag, expand_input, task_id=task_id) + + assert iterable_op.task_id == task_id + + @pytest.mark.db_test + def test_with_task_concurrency(self, dag_maker, session): + """Test that IterableOperator respects task_concurrency parameter.""" + with dag_maker(session=session) as dag: + expand_input = ListOfDictsExpandInput([{"a": 1}]) + iterable_op = self.create_iterable_operator(dag, expand_input, task_concurrency=4) + + assert 
iterable_op.max_workers == 4 + + @pytest.mark.db_test + def test_expand_input_stored(self, dag_maker, session): + """Test that IterableOperator stores expand_input correctly.""" + with dag_maker(session=session) as dag: + expand_input_data = ListOfDictsExpandInput([{"a": 1}, {"a": 2}]) + iterable_op = self.create_iterable_operator(dag, expand_input_data) + + assert iterable_op.expand_input is expand_input_data + assert isinstance(iterable_op.expand_input, (ListOfDictsExpandInput, DictOfListsExpandInput)) + + @pytest.mark.db_test + def test_partial_kwargs_stored(self, dag_maker, session): + """Test that IterableOperator stores partial_kwargs from operator.""" + with dag_maker(session=session) as dag: + expand_input = ListOfDictsExpandInput([{"a": 1}]) + iterable_op = self.create_iterable_operator(dag, expand_input) + + assert hasattr(iterable_op, "partial_kwargs") + assert isinstance(iterable_op.partial_kwargs, dict) + + @pytest.mark.db_test + def test_xcom_push_delegates_to_task_when_not_pushed(self, dag_maker, session): + """_xcom_push delegates to task.xcom_push only when xcom_pushed is False.""" + from unittest import mock + + with dag_maker(session=session) as dag: + expand_input = ListOfDictsExpandInput([{"arg1": 1}]) + iterable_op = self.create_iterable_operator(dag, expand_input) + + task = mock.MagicMock() + task.xcom_pushed = False + task.task_id = "my_task" + task.index = 0 + + iterable_op._xcom_push(task=task, value="result_value") + + task.xcom_push.assert_called_once_with(key=BaseXCom.XCOM_RETURN_KEY, value="result_value") + + @pytest.mark.db_test + def test_xcom_push_skips_when_already_pushed(self, dag_maker, session): + """_xcom_push skips pushing when xcom_pushed is already True.""" + from unittest import mock + + with dag_maker(session=session) as dag: + expand_input = ListOfDictsExpandInput([{"arg1": 1}]) + iterable_op = self.create_iterable_operator(dag, expand_input) + + task = mock.MagicMock() + task.xcom_pushed = True + task.task_id = 
"my_task" + task.index = 0 + + iterable_op._xcom_push(task=task, value="result_value") + + task.xcom_push.assert_not_called() + + @pytest.mark.db_test + def test_execute_list_of_dicts(self, dag_maker, session, mock_xcom_get_one): + """Test executing IterableOperator with ListOfDictsExpandInput.""" + with dag_maker(session=session) as dag: + expand_input = ListOfDictsExpandInput([{"arg1": 1}, {"arg1": 2}]) + iterable_op = self.create_iterable_operator(dag, expand_input, task_id="exec_list_of_dicts") + + context = mock_context(task=iterable_op) + mock_xcom_get_one(context) + result = iterable_op.execute(context=context) + materialized = list(result) + assert materialized == [(1, None, None), (2, None, None)] + + @pytest.mark.db_test + def test_execute_dict_of_lists(self, dag_maker, session, mock_xcom_get_one): + """Test executing IterableOperator with DictOfListsExpandInput.""" + with dag_maker(session=session) as dag: + expand_input = DictOfListsExpandInput({"arg1": [1, 2, 3]}) + iterable_op = self.create_iterable_operator(dag, expand_input, task_id="exec_dict_of_lists") + + context = mock_context(task=iterable_op) + mock_xcom_get_one(context) + result = iterable_op.execute(context=context) + materialized = list(result) + assert materialized == [(1, None, None), (2, None, None), (3, None, None)] + + @pytest.mark.db_test + def test_execute_empty_list_of_dicts(self, dag_maker, session, mock_xcom_get_one): + """Test executing IterableOperator with empty ListOfDictsExpandInput.""" + with dag_maker(session=session) as dag: + expand_input = ListOfDictsExpandInput([]) + iterable_op = self.create_iterable_operator(dag, expand_input, task_id="exec_empty") + + context = mock_context(task=iterable_op) + mock_xcom_get_one(context) + result = iterable_op.execute(context=context) + materialized = list(result) + assert materialized == [] + + @pytest.mark.db_test + def test_execute_multiple_key_dict_of_lists(self, dag_maker, session, mock_xcom_get_one): + """Test executing 
IterableOperator with multiple keys in DictOfListsExpandInput.""" + with dag_maker(session=session) as dag: + expand_input = DictOfListsExpandInput({"arg1": [1, 2], "arg2": [10, 20], "arg3": ["x", "y"]}) + iterable_op = self.create_iterable_operator(dag, expand_input, task_id="exec_multi_key") + + context = mock_context(task=iterable_op) + mock_xcom_get_one(context) + result = iterable_op.execute(context=context) + materialized = list(result) + assert materialized == [(1, 10, "x"), (2, 20, "y")] + + @pytest.mark.db_test + def test_execute_with_task_concurrency_setting(self, dag_maker, session, mock_xcom_get_one): + """Test executing IterableOperator with task_concurrency parameter.""" + with dag_maker(session=session) as dag: + expand_input = ListOfDictsExpandInput([{"arg1": 1}, {"arg1": 2}, {"arg1": 3}]) + iterable_op = self.create_iterable_operator( + dag, expand_input, task_id="exec_concurrency", task_concurrency=2 + ) + + context = mock_context(task=iterable_op) + mock_xcom_get_one(context) + result = iterable_op.execute(context=context) + materialized = list(result) + assert materialized == [(1, None, None), (2, None, None), (3, None, None)] + assert iterable_op.max_workers == 2 + + @pytest.mark.db_test + def test_execute_all_parameters(self, dag_maker, session, mock_xcom_get_one): + """Test executing IterableOperator with all arg1, arg2, arg3 parameters.""" + with dag_maker(session=session) as dag: + expand_input = ListOfDictsExpandInput( + [ + {"arg1": 1, "arg2": 10, "arg3": 100}, + {"arg1": 2, "arg2": 20, "arg3": 200}, + ] + ) + iterable_op = self.create_iterable_operator(dag, expand_input, task_id="exec_all_args") + + context = mock_context(task=iterable_op) + mock_xcom_get_one(context) + result = iterable_op.execute(context=context) + materialized = list(result) + assert materialized == [(1, 10, 100), (2, 20, 200)] + + @pytest.mark.db_test + def test_execute_with_do_xcom_push_false(self, dag_maker, session): + """Test executing IterableOperator when 
do_xcom_push is False.""" + with dag_maker(session=session) as dag: + expand_input = ListOfDictsExpandInput([{"arg1": 1}, {"arg1": 2}]) + iterable_op = self.create_iterable_operator( + dag, expand_input, task_id="no_xcom_push", do_xcom_push=False + ) + + context = mock_context(task=iterable_op) + result = iterable_op.execute(context=context) + + assert result is None + + @pytest.mark.db_test + def test_execute_with_failed_tasks_but_no_retries(self, dag_maker, session, mock_xcom_get_one): + """Test executing IterableOperator where tasks fail but no retries are available. + + This test verifies that: + 1. Tasks with fail_on_first_attempt=True raise an exception on first attempt + 2. When no retries are configured (retries=0), the exception propagates and is not retried + 3. The BaseExceptionGroup is raised containing the task failure + """ + with dag_maker(session=session) as dag: + expand_input = ListOfDictsExpandInput( + [ + {"arg1": 1, "arg2": 10}, + {"arg1": 2, "arg2": 20, "fail_on_first_attempt": True}, + {"arg1": 3, "arg2": 30}, + ] + ) + iterable_op = self.create_iterable_operator( + dag, + expand_input, + task_id="exec_with_failures", + retries=0, + ) + + context = mock_context(task=iterable_op) + mock_xcom_get_one(context) + with pytest.raises(BaseExceptionGroup): + iterable_op.execute(context=context) + + @pytest.mark.db_test + def test_execute_with_failed_tasks_and_expired_reschedule_date( + self, dag_maker, session, mock_xcom_get_one + ): + """Test executing IterableOperator where certain map_index tasks fail on first attempt and are retried. + + This test verifies that: + 1. Tasks with fail_on_first_attempt=True raise an exception on first attempt (try_number == 0) + 2. Failed tasks are retried immediately without deferring (since reschedule_date is expired) + 3. 
Retried tasks succeed on subsequent attempts (try_number > 0) and produce the expected output + """ + with dag_maker(session=session) as dag: + expand_input = ListOfDictsExpandInput( + [ + {"arg1": 1, "arg2": 10}, + {"arg1": 2, "arg2": 20, "fail_on_first_attempt": True}, + {"arg1": 3, "arg2": 30}, + ] + ) + iterable_op = self.create_iterable_operator( + dag, + expand_input, + task_id="exec_with_failures", + retries=1, + ) + + context = mock_context(task=iterable_op) + mock_xcom_get_one(context) + result = iterable_op.execute(context=context) + materialized = list(result) + + assert len(materialized) == 3 + assert materialized == [(1, 10, None), (2, 20, None), (3, 30, None)] diff --git a/task-sdk/tests/task_sdk/definitions/test_partitionedoperator.py b/task-sdk/tests/task_sdk/definitions/test_partitionedoperator.py new file mode 100644 index 0000000000000..8324798a4b844 --- /dev/null +++ b/task-sdk/tests/task_sdk/definitions/test_partitionedoperator.py @@ -0,0 +1,77 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +from collections import defaultdict +from collections.abc import Callable +from unittest import mock + +from airflow.sdk import DAG, TaskInstanceState +from airflow.sdk.bases.xcom import BaseXCom +from airflow.sdk.execution_time.comms import ( + GetTICount, + GetXCom, + GetXComSequenceSlice, + TICount, + XComResult, + XComSequenceSliceResult, +) + +RunTI = Callable[[DAG, str, int], TaskInstanceState] + + +class TestPartitionedOperator: + def test_partition_iterate(self, run_ti: RunTI, mock_supervisor_comms): + """Test a partitioned task which iterates on its set of values.""" + outputs = defaultdict(list) + numbers = list(range(10)) + + with DAG(dag_id="product_same") as dag: + + @dag.task + def emit_numbers(): + return numbers + + @dag.task + def show(number, **context): + map_index = str(context["ti"].map_index) + outputs[map_index].append(number) + return number + + emit_task = emit_numbers() + show.partition(size=2).iterate(number=emit_task) + + def mock_comms(msg): + if isinstance(msg, GetXCom): + if msg.task_id == "emit_numbers": + return XComResult(key=BaseXCom.XCOM_RETURN_KEY, value=numbers) + elif isinstance(msg, GetXComSequenceSlice): + if msg.task_id == "emit_numbers": + return XComSequenceSliceResult(root=numbers) + elif isinstance(msg, GetTICount): + if msg.task_ids and msg.task_ids[0] == "show": + return TICount(count=2) + return TICount(count=1) + return mock.DEFAULT + + mock_supervisor_comms.send.side_effect = mock_comms + + states = [run_ti(dag, "show", map_index) for map_index in range(2)] + assert states == [TaskInstanceState.SUCCESS] * 2 + assert set(outputs["0"]) == {0, 2, 4, 6, 8} + assert set(outputs["1"]) == {1, 3, 5, 7, 9} diff --git a/task-sdk/tests/task_sdk/definitions/test_xcom_arg.py b/task-sdk/tests/task_sdk/definitions/test_xcom_arg.py index af487851b07cb..e6b39c8bb253e 100644 --- a/task-sdk/tests/task_sdk/definitions/test_xcom_arg.py +++ b/task-sdk/tests/task_sdk/definitions/test_xcom_arg.py @@
-22,6 +22,7 @@ import pytest import structlog +from task_sdk.definitions.conftest import make_xcom_arg from airflow.sdk import TaskInstanceState from airflow.sdk.bases.xcom import BaseXCom @@ -350,3 +351,39 @@ def xcom_get(msg): states = [run_ti(dag, "pull_one", map_index) for map_index in range(5)] assert states == [TaskInstanceState.SUCCESS] * 5 assert agg_results == {"a", "b", "c", 1, 2} + + +class TestXComArg: + @pytest.mark.parametrize( + ("actual", "expected"), + [ + (1, [1]), # scalar + ("hello", ["hello"]), # string + ([1, 2, 3], [1, 2, 3]), # list + ((x for x in [4, 5]), [4, 5]), # generator + ], + ) + def test_plain_xcomarg_iter_values(self, actual, expected): + xcom_arg = make_xcom_arg(actual) + result = list(xcom_arg.iter_values({})) + assert result == expected + + def test_map_xcomarg_iter_values(self): + base = make_xcom_arg([1, 2, 3]) + mapped = base.map(lambda x: x * 10) + result = list(mapped.iter_values({})) + assert result == [10, 20, 30] + + def test_zip_xcomarg_iter_values(self): + a = make_xcom_arg([1, 2]) + b = make_xcom_arg([10, 20]) + zipped = a.zip(b) + result = list(zipped.iter_values({})) + assert result == [(1, 10), (2, 20)] + + def test_concat_xcomarg_iter_values(self): + a = make_xcom_arg([1, 2]) + b = make_xcom_arg([10, 20]) + concatenated = a.concat(b) + result = list(concatenated.iter_values({})) + assert result == [1, 2, 10, 20] diff --git a/task-sdk/tests/task_sdk/execution_time/conftest.py b/task-sdk/tests/task_sdk/execution_time/conftest.py index 4a537373363aa..7fd017ce83660 100644 --- a/task-sdk/tests/task_sdk/execution_time/conftest.py +++ b/task-sdk/tests/task_sdk/execution_time/conftest.py @@ -18,8 +18,15 @@ from __future__ import annotations import sys +from datetime import timedelta +from unittest import mock import pytest +from uuid6 import uuid7 + +from airflow.sdk import BaseAsyncOperator, BaseOperator, timezone +from airflow.sdk.api.datamodels._generated import TaskInstanceState +from 
@pytest.fixture
def make_mapped_ti():
    """Factory fixture producing MappedTaskInstance objects for testing.

    Returns a keyword-only factory so call sites stay self-documenting; the
    instance is built with ``model_construct`` to bypass full Pydantic
    validation, and ``task`` is an autospec'd mock of the (async) operator.
    """

    def _make_mapped_ti(
        *,
        task_id: str = "my_task",
        dag_id: str = "my_dag",
        run_id: str = "run_1",
        map_index: int = -1,
        index: int | None = 0,
        try_number: int = 0,
        max_tries: int = 3,
        is_async: bool = False,
        # FIX: annotation was `timedelta = None`, which contradicts the None
        # default; now `timedelta | None` to match max_retry_delay's style.
        retry_delay: timedelta | None = None,
        retry_exponential_backoff: bool = False,
        max_retry_delay: timedelta | None = None,
        end_date=None,
        start_date=None,
        logical_date=None,
        do_xcom_push: bool = True,
    ) -> MappedTaskInstance:
        """Create a MappedTaskInstance via model_construct to bypass full Pydantic validation."""
        # Resolve None-able defaults here so every call gets fresh objects.
        if retry_delay is None:
            retry_delay = timedelta(seconds=300)
        if end_date is None:
            end_date = timezone.datetime(2024, 12, 3, 10, 0, 0)
        if start_date is None:
            start_date = timezone.datetime(2024, 12, 3, 9, 55, 0)
        if logical_date is None:
            logical_date = timezone.datetime(2024, 12, 3, 0, 0, 0)

        # autospec against the real operator class so attribute typos fail loudly
        operator_cls = BaseAsyncOperator if is_async else BaseOperator
        operator = mock.create_autospec(operator_cls, instance=True)
        operator.task_id = task_id
        operator.dag_id = dag_id
        operator.is_async = is_async
        operator.retries = max_tries
        operator.retry_delay = retry_delay
        operator.retry_exponential_backoff = retry_exponential_backoff
        operator.max_retry_delay = max_retry_delay
        operator.do_xcom_push = do_xcom_push

        return MappedTaskInstance.model_construct(
            id=uuid7(),
            task_id=task_id,
            dag_id=dag_id,
            run_id=run_id,
            map_index=map_index,
            index=index,
            try_number=try_number,
            max_tries=max_tries,
            state=TaskInstanceState.SCHEDULED,
            is_mapped=True,
            task=operator,
            xcom_pushed=False,
            dag_version_id=uuid7(),
            end_date=end_date,
            start_date=start_date,
            logical_date=logical_date,
        )

    return _make_mapped_ti
class TestConcurrentExecutor:
    """Dispatch behavior of ConcurrentExecutor: sync work to the thread pool, async work to the loop."""

    def test_submit_sync_function_returns_future(self):
        """Sync callables are dispatched to the thread pool and return a concurrent.futures.Future."""
        with event_loop() as loop:
            with ConcurrentExecutor(loop=loop, max_workers=2) as executor:
                future = executor.submit(lambda: 42)
                assert isinstance(future, Future)
                assert future.result() == 42

    def test_submit_async_coroutine_function_returns_task(self):
        """Async callables are scheduled on the event loop and return an asyncio.Task."""
        with event_loop() as loop:
            with ConcurrentExecutor(loop=loop, max_workers=2) as executor:

                async def async_fn():
                    return "async_result"

                task = executor.submit(async_fn)
                assert isinstance(task, asyncio.Task)
                result = loop.run_until_complete(task)
                assert result == "async_result"

    def test_submit_coroutine_object_returns_task(self):
        """Passing a coroutine object (not a function) directly is also scheduled on the event loop."""
        with event_loop() as loop:
            with ConcurrentExecutor(loop=loop, max_workers=2) as executor:

                async def async_fn():
                    return "coro_result"

                coro = async_fn()
                task = executor.submit(coro)
                assert isinstance(task, asyncio.Task)
                result = loop.run_until_complete(task)
                assert result == "coro_result"

    def test_submit_sync_function_propagates_exception(self):
        """Exceptions raised inside sync callables are propagated when the future is resolved."""

        # FIX: a plain raising function instead of the opaque
        # `lambda: (_ for _ in ()).throw(ValueError(...))` generator trick;
        # observable behavior is identical (future.result() raises).
        def boom():
            raise ValueError("boom")

        with event_loop() as loop:
            with ConcurrentExecutor(loop=loop, max_workers=2) as executor:
                future = executor.submit(boom)
                assert isinstance(future, Future)
                with pytest.raises(ValueError, match="boom"):
                    future.result()

    def test_submit_async_function_propagates_exception(self):
        """Exceptions raised inside async callables are propagated when the task is awaited."""
        with event_loop() as loop:
            with ConcurrentExecutor(loop=loop, max_workers=2) as executor:

                async def failing():
                    raise RuntimeError("async boom")

                task = executor.submit(failing)
                with pytest.raises(RuntimeError, match="async boom"):
                    loop.run_until_complete(task)

    def test_semaphore_limits_concurrent_async_tasks(self):
        """The semaphore prevents more than max_workers coroutines from running simultaneously."""
        concurrency_high_watermark = 0
        running = 0

        async def count_concurrent():
            nonlocal concurrency_high_watermark, running
            running += 1
            concurrency_high_watermark = max(concurrency_high_watermark, running)
            # yield control so other submitted coroutines get a chance to overlap
            await asyncio.sleep(0)
            running -= 1

        max_workers = 2
        with event_loop() as loop:
            with ConcurrentExecutor(loop=loop, max_workers=max_workers) as executor:
                tasks = [executor.submit(count_concurrent) for _ in range(6)]
                loop.run_until_complete(asyncio.gather(*tasks))

        assert concurrency_high_watermark <= max_workers

    def test_exit_shuts_down_thread_pool(self):
        """__exit__ calls shutdown on the thread pool."""
        with event_loop() as loop:
            executor = ConcurrentExecutor(loop=loop, max_workers=2)
            with mock.patch.object(
                executor._thread_pool, "shutdown", wraps=executor._thread_pool.shutdown
            ) as shutdown_mock:
                with executor:
                    pass
                shutdown_mock.assert_called_once_with(wait=True)

    def test_context_manager_returns_self(self):
        """__enter__ returns the executor instance itself."""
        with event_loop() as loop:
            executor = ConcurrentExecutor(loop=loop, max_workers=2)
            with executor as ctx:
                assert ctx is executor


class TestCollectFutures:
    """collect_futures must drain both thread-pool futures and asyncio tasks."""

    def test_yields_sync_futures(self):
        """collect_futures yields completed concurrent.futures.Future objects."""
        with event_loop() as loop:
            f1: Future = Future()
            f2: Future = Future()
            f1.set_result("a")
            f2.set_result("b")

            results = list(collect_futures(loop, [f1, f2]))
            assert set(results) == {f1, f2}

    def test_yields_async_tasks(self):
        """collect_futures yields completed asyncio.Task objects."""
        with event_loop() as loop:

            async def coro(val):
                return val

            t1 = loop.create_task(coro("x"))
            t2 = loop.create_task(coro("y"))

            results = list(collect_futures(loop, [t1, t2]))
            assert set(results) == {t1, t2}

    def test_yields_mixed_futures_and_tasks(self):
        """collect_futures handles a mix of concurrent.futures.Future and asyncio.Task."""
        with event_loop() as loop:
            f: Future = Future()
            f.set_result(1)

            async def coro():
                return 2

            t = loop.create_task(coro())

            results = list(collect_futures(loop, [f, t]))
            assert len(results) == 2
            assert f in results
            assert t in results

    def test_empty_list_yields_nothing(self):
        """collect_futures with an empty list yields nothing."""
        with event_loop() as loop:
            results = list(collect_futures(loop, []))
            assert results == []


class TestTaskExecutor:
    """TaskExecutor property delegation, sync/async context-manager state transitions, and retry logic."""

    def test_dag_id_property(self, make_mapped_ti):
        ti = make_mapped_ti(dag_id="my_dag")
        executor = TaskExecutor(task_instance=ti)
        assert executor.dag_id == "my_dag"

    def test_task_id_property(self, make_mapped_ti):
        ti = make_mapped_ti(task_id="my_task")
        executor = TaskExecutor(task_instance=ti)
        assert executor.task_id == "my_task"

    def test_task_index(self, make_mapped_ti):
        ti = make_mapped_ti(index=3)
        executor = TaskExecutor(task_instance=ti)
        assert executor.task_index == ti.index
        assert executor.task_index == 3

    def test_xcom_key_property_returns_xcom_key(self, make_mapped_ti):
        ti = make_mapped_ti(task_id="op", index=2)
        executor = TaskExecutor(task_instance=ti)
        assert executor.xcom_key == ti.xcom_key
        assert executor.xcom_key == "op_2"

    def test_operator_property(self, make_mapped_ti):
        ti = make_mapped_ti()
        executor = TaskExecutor(task_instance=ti)
        assert executor.operator is ti.task

    def test_is_async_property_sync(self, make_mapped_ti):
        ti = make_mapped_ti(is_async=False)
        executor = TaskExecutor(task_instance=ti)
        assert executor.is_async is False

    def test_is_async_property_async(self, make_mapped_ti):
        ti = make_mapped_ti(is_async=True)
        executor = TaskExecutor(task_instance=ti)
        assert executor.is_async is True

    def test_enter_sets_start_time(self, make_mapped_ti):
        ti = make_mapped_ti()
        executor = TaskExecutor(task_instance=ti)
        assert executor._start_time is None
        executor.__enter__()
        assert executor._start_time is not None

    def test_enter_returns_self(self, make_mapped_ti):
        ti = make_mapped_ti()
        executor = TaskExecutor(task_instance=ti)
        with executor as ctx:
            assert ctx is executor

    def test_exit_success_sets_state(self, make_mapped_ti):
        """__exit__ without an exception marks the task instance as SUCCESS."""
        ti = make_mapped_ti()
        with TaskExecutor(task_instance=ti):
            pass  # no exception
        assert ti.state == TaskInstanceState.SUCCESS

    def test_exit_with_task_deferred_reraises(self, make_mapped_ti):
        """TaskDeferred must propagate unchanged through __exit__."""
        ti = make_mapped_ti()
        trigger = mock.Mock()
        deferred = TaskDeferred(trigger=trigger, method_name="resume")

        with pytest.raises(TaskDeferred):
            with TaskExecutor(task_instance=ti):
                raise deferred

    def test_exit_reschedules_when_retries_remain(self, make_mapped_ti):
        """
        When a retryable exception occurs and retries are not exhausted,
        the task state is set to UP_FOR_RESCHEDULE and
        AirflowRescheduleTaskInstanceException is raised.
        """
        # try_number=0 → next_try_number=1; max_tries=3 → 1 <= 3, reschedule
        ti = make_mapped_ti(try_number=0, max_tries=3)

        with pytest.raises(AirflowRescheduleTaskInstanceException):
            with TaskExecutor(task_instance=ti):
                raise RuntimeError("transient failure")

        assert ti.state == TaskInstanceState.UP_FOR_RESCHEDULE
        assert ti.try_number == 1  # incremented

    def test_exit_fails_when_retries_exhausted(self, make_mapped_ti):
        """
        When a retryable exception occurs and all retries are exhausted,
        the task state is set to FAILED and the original exception is re-raised.
        """
        # try_number=3 → next_try_number=4; max_tries=3 → 4 > 3, fail
        ti = make_mapped_ti(try_number=3, max_tries=3)
        original_error = RuntimeError("permanent failure")

        with pytest.raises(RuntimeError, match="permanent failure"):
            with TaskExecutor(task_instance=ti):
                raise original_error

        assert ti.state == TaskInstanceState.FAILED

    @pytest.mark.parametrize(
        ("try_number", "max_tries", "should_fail"),
        [
            (0, 0, True),  # next=1, max=0 → 1>0 → fail
            (0, 1, False),  # next=1, max=1 → 1<=1 → reschedule
            (1, 1, True),  # next=2, max=1 → 2>1 → fail
            (2, 3, False),  # next=3, max=3 → 3<=3 → reschedule
            (3, 3, True),  # next=4, max=3 → 4>3 → fail
        ],
    )
    def test_exit_retry_boundary(self, make_mapped_ti, try_number, max_tries, should_fail):
        """Exhaustive boundary checks for the retry/fail decision in __exit__."""
        ti = make_mapped_ti(try_number=try_number, max_tries=max_tries)
        if should_fail:
            with pytest.raises(RuntimeError):
                with TaskExecutor(task_instance=ti):
                    raise RuntimeError("err")
            assert ti.state == TaskInstanceState.FAILED
        else:
            with pytest.raises(AirflowRescheduleTaskInstanceException):
                with TaskExecutor(task_instance=ti):
                    raise RuntimeError("err")
            assert ti.state == TaskInstanceState.UP_FOR_RESCHEDULE

    def test_run_delegates_to_execute_task(self, make_mapped_ti):
        """run() must call _execute_task with the given context."""
        ti = make_mapped_ti()
        task = BaseOperator(task_id="test_task")
        get_inline_dag("test_dag", task)
        context = mock_context(task)
        executor = TaskExecutor(task_instance=ti)

        with mock.patch(
            "airflow.sdk.execution_time.executor._execute_task",
            autospec=True,
            return_value="result",
        ) as mock_execute:
            result = executor.run(context)

        mock_execute.assert_called_once_with(context, ti, executor.log)
        assert result == "result"

    @pytest.mark.asyncio
    async def test_arun_delegates_to_execute_async_task(self, make_mapped_ti):
        """arun() must call _execute_async_task with the given context."""
        ti = make_mapped_ti(is_async=True)
        task = BaseOperator(task_id="test_task")
        get_inline_dag("test_dag", task)
        context = mock_context(task)
        executor = TaskExecutor(task_instance=ti)

        with mock.patch(
            "airflow.sdk.execution_time.executor._execute_async_task",
            new=mock.AsyncMock(return_value="async_result"),
        ) as mock_async_execute:
            result = await executor.arun(context)

        mock_async_execute.assert_called_once_with(context, ti, executor.log)
        assert result == "async_result"

    @pytest.mark.asyncio
    async def test_async_context_manager_enter_returns_self(self, make_mapped_ti):
        ti = make_mapped_ti()
        executor = TaskExecutor(task_instance=ti)
        async with executor as ctx:
            assert ctx is executor

    @pytest.mark.asyncio
    async def test_async_context_manager_exit_success(self, make_mapped_ti):
        ti = make_mapped_ti()
        async with TaskExecutor(task_instance=ti):
            pass
        assert ti.state == TaskInstanceState.SUCCESS

    @pytest.mark.asyncio
    async def test_async_context_manager_exit_reschedules(self, make_mapped_ti):
        ti = make_mapped_ti(try_number=0, max_tries=3)
        with pytest.raises(AirflowRescheduleTaskInstanceException):
            async with TaskExecutor(task_instance=ti):
                raise RuntimeError("async transient failure")
        assert ti.state == TaskInstanceState.UP_FOR_RESCHEDULE


class TestMappedTaskInstance:
    """MappedTaskInstance XCom key suffixing, retry scheduling, and derived properties."""

    def test_xcom_push_delegates_with_index_suffix(self, make_mapped_ti):
        """xcom_push delegates to RuntimeTaskInstance with key suffixed by _{index}."""
        ti = make_mapped_ti(index=3)

        with mock.patch(
            "airflow.sdk.execution_time.task_runner._xcom_push", autospec=True
        ) as mock_push:
            ti.xcom_push(key="result", value="ok")

        mock_push.assert_called_once_with(ti, "result_3", "ok")

    def test_xcom_push_sets_xcom_pushed_only_for_default_key(self, make_mapped_ti):
        """xcom_push sets xcom_pushed to True only when pushing with XCOM_RETURN_KEY."""
        from airflow.sdk.bases.xcom import BaseXCom

        ti = make_mapped_ti(index=2)
        assert ti.xcom_pushed is False

        with mock.patch("airflow.sdk.execution_time.task_runner._xcom_push", autospec=True):
            # Push with default key — should set xcom_pushed to True
            ti.xcom_push(key=BaseXCom.XCOM_RETURN_KEY, value="value1")

        assert ti.xcom_pushed is True

    def test_xcom_push_does_not_set_flag_for_custom_key(self, make_mapped_ti):
        """xcom_push does NOT set xcom_pushed when pushing with a custom key."""
        ti = make_mapped_ti(index=1)
        assert ti.xcom_pushed is False

        with mock.patch("airflow.sdk.execution_time.task_runner._xcom_push", autospec=True):
            # Push with custom key — should NOT set xcom_pushed
            ti.xcom_push(key="custom_key", value="value2")

        assert ti.xcom_pushed is False

    def test_xcom_pull_delegates_with_index_suffix(self, make_mapped_ti):
        """xcom_pull delegates to RuntimeTaskInstance with key suffixed by _{index}."""
        ti = make_mapped_ti(index=5)

        # NOTE(review): patches the first base class of MappedTaskInstance —
        # fragile if the MRO ever changes; confirm the base is RuntimeTaskInstance.
        with mock.patch.object(
            ti.__class__.__bases__[0],
            "xcom_pull",
            autospec=True,
            return_value="pulled_value",
        ) as mock_pull:
            result = ti.xcom_pull(key="result")

        mock_pull.assert_called_once_with(
            ti,
            task_ids=None,
            dag_id=None,
            key="result_5",
            include_prior_dates=False,
            map_indexes=mock.ANY,
            default=None,
            run_id=None,
        )
        assert result == "pulled_value"

    def test_next_retry_datetime_without_exponential_backoff(self, make_mapped_ti):
        # end_date defaults to 10:00:00, so a flat 30s delay lands at 10:00:30
        ti = make_mapped_ti(retry_delay=timedelta(seconds=30), retry_exponential_backoff=False)

        assert ti.next_retry_datetime() == timezone.datetime(2024, 12, 3, 10, 0, 30)

    def test_next_retry_datetime_exponential_backoff_honors_max_retry_delay(self, make_mapped_ti):
        # Backoff would exceed 5s, so max_retry_delay caps the wait at 10:00:05.
        ti = make_mapped_ti(
            try_number=2,
            retry_delay=timedelta(seconds=10),
            retry_exponential_backoff=True,
            max_retry_delay=timedelta(seconds=5),
        )

        assert ti.next_retry_datetime() == timezone.datetime(2024, 12, 3, 10, 0, 5)

    def test_properties(self, make_mapped_ti):
        ti = make_mapped_ti(index=7, try_number=4, is_async=True, do_xcom_push=False)

        assert ti.is_async is True
        assert ti.next_try_number == 5
        assert ti.xcom_key == "my_task_7"
        assert ti.do_xcom_push is False
"msgspec" }, { name = "opentelemetry-api" }, { name = "packaging" }, @@ -8518,6 +8519,7 @@ requires-dist = [ { name = "jinja2", specifier = ">=3.1.5" }, { name = "jsonschema", specifier = ">=4.19.1" }, { name = "methodtools", specifier = ">=0.4.7" }, + { name = "more-itertools", marker = "python_full_version < '3.12'", specifier = ">=9.0.0" }, { name = "msgspec", specifier = ">=0.19.0" }, { name = "opentelemetry-api", specifier = ">=1.27.0" }, { name = "opentelemetry-api", marker = "extra == 'all'", specifier = ">=1.27.0" },