diff --git a/airflow-core/newsfragments/68954.bugfix.rst b/airflow-core/newsfragments/68954.bugfix.rst new file mode 100644 index 0000000000000..eca1675dadfe4 --- /dev/null +++ b/airflow-core/newsfragments/68954.bugfix.rst @@ -0,0 +1 @@ +Fix branch operators not skipping mapped or cleared downstream tasks when a custom XCom backend is configured. diff --git a/airflow-core/tests/unit/ti_deps/deps/test_not_previously_skipped_dep.py b/airflow-core/tests/unit/ti_deps/deps/test_not_previously_skipped_dep.py index 58ccce35dd896..cd5b364311bce 100644 --- a/airflow-core/tests/unit/ti_deps/deps/test_not_previously_skipped_dep.py +++ b/airflow-core/tests/unit/ti_deps/deps/test_not_previously_skipped_dep.py @@ -17,14 +17,17 @@ # under the License. from __future__ import annotations +from unittest import mock + import pendulum import pytest -from sqlalchemy import delete +from sqlalchemy import delete, select from airflow.models import DagRun, TaskInstance from airflow.models.xcom import XComModel from airflow.providers.standard.operators.empty import EmptyOperator from airflow.providers.standard.operators.python import BranchPythonOperator +from airflow.sdk.bases.xcom import BaseXCom from airflow.ti_deps.dep_context import DepContext from airflow.ti_deps.deps.not_previously_skipped_dep import ( XCOM_SKIPMIXIN_FOLLOWED, @@ -214,3 +217,65 @@ def test_unmapped_parent_skip_mapped_downstream(session, dag_maker): assert len(list(dep.get_dep_statuses(tis["op2"], DepContext(), session=session))) == 1 assert not dep.is_met(tis["op2"], session=session) assert tis["op2"].state == State.SKIPPED + + +def test_branch_skip_decision_bypasses_custom_xcom_backend(session, dag_maker): + """ + A value-externalizing custom XCom backend must not break branch-skip of + mapped/cleared downstream tasks. + + The branch decision is written through the real worker push path with such a + backend configured. It must be stored readably (not as the backend's opaque + pointer) so that NotPreviouslySkippedDep can skip a not-yet-expanded mapped + downstream task, which the worker does not skip directly. + + Regression test for https://github.com/apache/airflow/issues/50491. + """ + + class _PointerXComBackend(BaseXCom): + @staticmethod + def serialize_value(value, **kwargs): + return "xcom_s3://pointer" + + @staticmethod + def deserialize_value(result): + return "xcom_s3://pointer" + + start_date = pendulum.datetime(2020, 1, 1) + with dag_maker( + "test_skip_bypass_backend_dag", + schedule=None, + start_date=start_date, + session=session, + ): + op1 = BranchPythonOperator(task_id="op1", python_callable=lambda: "op3") + op2 = EmptyOperator(task_id="op2") + op3 = EmptyOperator(task_id="op3") + op1 >> [op2, op3] + + dr = dag_maker.create_dagrun(run_type=DagRunType.MANUAL, state=State.RUNNING) + tis = {ti.task_id: ti for ti in dr.task_instances} + + with mock.patch("airflow.sdk.execution_time.task_runner.XCom", _PointerXComBackend): + run_task_instance(tis["op1"], op1) + + stored = session.scalar( + select(XComModel.value).where( + XComModel.dag_id == dr.dag_id, + XComModel.task_id == "op1", + XComModel.run_id == dr.run_id, + XComModel.key == XCOM_SKIPMIXIN_KEY, + XComModel.map_index == -1, + ) + ) + + assert stored is not None + assert "xcom_s3://pointer" not in str(stored) + + tis["op2"].map_index = 0 + session.merge(tis["op2"]) + session.flush() + + dep = NotPreviouslySkippedDep() + assert len(list(dep.get_dep_statuses(tis["op2"], DepContext(), session=session))) == 1 + assert tis["op2"].state == State.SKIPPED diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index f3fee689928a0..f941d6a5d60dd 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -57,6 +57,7 @@ TIRunContext, ) from airflow.sdk.bases.operator import BaseOperator, ExecutorSafeguard +from airflow.sdk.bases.skipmixin import XCOM_SKIPMIXIN_KEY from airflow.sdk.bases.xcom import BaseXCom from airflow.sdk.configuration import conf from airflow.sdk.definitions._internal.dag_parsing_context import _airflow_parsing_context_manager @@ -880,6 +881,14 @@ def _xcom_push( # Private function, as we don't want to expose the ability to manually set `mapped_length` to SDK # consumers + if key == XCOM_SKIPMIXIN_KEY: + # The branch/skip decision is control-plane data the scheduler reads (via + # NotPreviouslySkippedDep) to skip mapped or cleared downstream tasks. It must + # bypass any custom XCom backend, which could externalize it into a pointer the + # scheduler cannot interpret, silently leaving those tasks unskipped (#50491). + _xcom_push_to_db(ti, key, value) + return + XCom.set( key=key, value=value,