|
18 | 18 | import sys
|
19 | 19 | import typing
|
20 | 20 | from collections import namedtuple
|
21 |
| -from datetime import date, timedelta |
| 21 | +from datetime import date |
22 | 22 | from typing import Union
|
23 | 23 |
|
24 | 24 | import pytest
|
|
28 | 28 | from airflow.exceptions import AirflowException, XComNotFound
|
29 | 29 | from airflow.models.taskinstance import TaskInstance
|
30 | 30 | from airflow.models.taskmap import TaskMap
|
31 |
| -from airflow.sdk import DAG, BaseOperator, TaskGroup, XComArg |
32 |
| -from airflow.sdk.definitions.mappedoperator import MappedOperator |
33 | 31 | from airflow.utils import timezone
|
34 | 32 | from airflow.utils.state import State
|
35 | 33 | from airflow.utils.task_instance_session import set_current_task_instance_session
|
|
41 | 39 | from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS
|
42 | 40 |
|
43 | 41 | if AIRFLOW_V_3_0_PLUS:
|
| 42 | + from airflow.sdk import DAG, BaseOperator, TaskGroup, XComArg |
44 | 43 | from airflow.sdk.definitions._internal.expandinput import DictOfListsExpandInput
|
45 | 44 | from airflow.sdk.definitions.mappedoperator import MappedOperator
|
46 | 45 | from airflow.utils.types import DagRunTriggeredByType
|
47 | 46 | else:
|
| 47 | + from airflow.models.baseoperator import BaseOperator |
| 48 | + from airflow.models.dag import DAG # type: ignore[assignment] |
48 | 49 | from airflow.models.expandinput import DictOfListsExpandInput
|
49 | 50 | from airflow.models.mappedoperator import MappedOperator
|
| 51 | + from airflow.models.xcom_arg import XComArg |
| 52 | + from airflow.utils.task_group import TaskGroup |
50 | 53 |
|
51 | 54 | pytestmark = pytest.mark.db_test
|
52 | 55 |
|
@@ -794,36 +797,6 @@ def task2(arg1, arg2): ...
|
794 | 797 | assert set(unmapped.op_kwargs) == {"arg1", "arg2"}
|
795 | 798 |
|
796 | 799 |
|
797 |
| -def test_mapped_decorator_converts_partial_kwargs(dag_maker, session): |
798 |
| - with dag_maker(session=session): |
799 |
| - |
800 |
| - @task_decorator |
801 |
| - def task1(arg): |
802 |
| - return ["x" * arg] |
803 |
| - |
804 |
| - @task_decorator(retry_delay=30) |
805 |
| - def task2(arg1, arg2): ... |
806 |
| - |
807 |
| - task2.partial(arg1=1).expand(arg2=task1.expand(arg=[1, 2])) |
808 |
| - |
809 |
| - run = dag_maker.create_dagrun() |
810 |
| - |
811 |
| - # Expand and run task1. |
812 |
| - dec = run.task_instance_scheduling_decisions(session=session) |
813 |
| - assert [ti.task_id for ti in dec.schedulable_tis] == ["task1", "task1"] |
814 |
| - for ti in dec.schedulable_tis: |
815 |
| - ti.run(session=session) |
816 |
| - assert not isinstance(ti.task, MappedOperator) |
817 |
| - assert ti.task.retry_delay == timedelta(seconds=300) # Operator default. |
818 |
| - |
819 |
| - # Expand task2. |
820 |
| - dec = run.task_instance_scheduling_decisions(session=session) |
821 |
| - assert [ti.task_id for ti in dec.schedulable_tis] == ["task2", "task2"] |
822 |
| - for ti in dec.schedulable_tis: |
823 |
| - unmapped = ti.task.unmap((ti.get_template_context(session),)) |
824 |
| - assert unmapped.retry_delay == timedelta(seconds=30) |
825 |
| - |
826 |
| - |
827 | 800 | def test_mapped_render_template_fields(dag_maker, session):
|
828 | 801 | @task_decorator
|
829 | 802 | def fn(arg1, arg2): ...
|
|
0 commit comments