Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prevent using trigger_rule=TriggerRule.ALWAYS in a task-generated mapping within mapped task groups #43368

Merged
merged 1 commit into from
Nov 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,11 @@ The grid view also provides visibility into your mapped tasks in the details pan

Although we show a "reduce" task here (``sum_it``), you don't have to have one; the mapped tasks will still be executed even if they have no downstream tasks.

.. warning:: ``TriggerRule.ALWAYS`` cannot be used in expanded tasks

Assigning ``trigger_rule=TriggerRule.ALWAYS`` in expanded tasks is forbidden, because such a task is executed immediately, at which point its expanded parameters are still undefined.
This is enforced when the DAG is parsed, and an error is raised if you try to use it.

Task-generated Mapping
----------------------

Expand Down
22 changes: 18 additions & 4 deletions task_sdk/src/airflow/sdk/definitions/taskgroup.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
TaskAlreadyInTaskGroup,
)
from airflow.sdk.definitions.node import DAGNode
from airflow.utils.trigger_rule import TriggerRule

if TYPE_CHECKING:
from airflow.models.expandinput import ExpandInput
Expand Down Expand Up @@ -195,10 +196,15 @@ def task_group(self, value: TaskGroup | None):

def __iter__(self):
for child in self.children.values():
if isinstance(child, TaskGroup):
yield from child
else:
yield child
yield from self._iter_child(child)

@staticmethod
def _iter_child(child):
    """
    Flatten a single child node.

    A ``TaskGroup`` child is iterated, and everything it yields is yielded
    here in turn; any other child is yielded as-is.
    """
    if not isinstance(child, TaskGroup):
        yield child
        return
    yield from child

def add(self, task: DAGNode) -> DAGNode:
"""
Expand Down Expand Up @@ -574,6 +580,14 @@ def __init__(self, *, expand_input: ExpandInput, **kwargs: Any) -> None:
super().__init__(**kwargs)
self._expand_input = expand_input

def __iter__(self):
    """
    Iterate over the nodes of this mapped task group, checking trigger rules.

    Each direct child is checked just before it is flattened through
    ``_iter_child``, so the check happens lazily as iteration proceeds.

    :raises ValueError: If a direct child is an operator whose ``trigger_rule``
        equals ``TriggerRule.ALWAYS``.
    """
    # Imported locally, as in the original implementation.
    from airflow.models.abstractoperator import AbstractOperator

    for node in self.children.values():
        is_operator = isinstance(node, AbstractOperator)
        if is_operator and node.trigger_rule == TriggerRule.ALWAYS:
            raise ValueError("Tasks in a mapped task group cannot have trigger_rule set to 'ALWAYS'")
        yield from self._iter_child(node)

def iter_mapped_dependencies(self) -> Iterator[DAGNode]:
"""Upstream dependencies that provide XComs used by this mapped task group."""
from airflow.models.xcom_arg import XComArg
Expand Down
25 changes: 24 additions & 1 deletion tests/decorators/test_task_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,11 @@
import pendulum
import pytest

from airflow.decorators import dag, task_group
from airflow.decorators import dag, task, task_group
from airflow.models.expandinput import DictOfListsExpandInput, ListOfDictsExpandInput, MappedArgument
from airflow.operators.empty import EmptyOperator
from airflow.utils.task_group import MappedTaskGroup
from airflow.utils.trigger_rule import TriggerRule


def test_task_group_with_overridden_kwargs():
Expand Down Expand Up @@ -133,6 +134,28 @@ def tg():
assert str(ctx.value) == "no arguments to expand against"


@pytest.mark.db_test
def test_expand_fail_trigger_rule_always(dag_maker, session):
    """Expanding a task group that holds a ``TriggerRule.ALWAYS`` task must raise ``ValueError``."""

    @dag(schedule=None, start_date=pendulum.datetime(2022, 1, 1))
    def pipeline():
        @task
        def get_param():
            return ["a", "b", "c"]

        @task(trigger_rule=TriggerRule.ALWAYS)
        def t1(param):
            return param

        @task_group()
        def tg(param):
            t1(param)

        tg.expand(param=get_param())

    # The DAG function body only runs when ``pipeline()`` is invoked; without this
    # call nothing is expanded and the test would pass without checking anything.
    with pytest.raises(
        ValueError, match="Tasks in a mapped task group cannot have trigger_rule set to 'ALWAYS'"
    ):
        pipeline()


def test_expand_create_mapped():
saved = {}

Expand Down