diff --git a/airflow-core/src/airflow/models/taskinstance.py b/airflow-core/src/airflow/models/taskinstance.py index e5b19f2768da9..ba0b75747c4a2 100644 --- a/airflow-core/src/airflow/models/taskinstance.py +++ b/airflow-core/src/airflow/models/taskinstance.py @@ -88,6 +88,7 @@ from airflow.models.taskmap import TaskMap from airflow.models.taskreschedule import TaskReschedule from airflow.models.xcom import XCOM_RETURN_KEY, LazyXComSelectSequence, XComModel +from airflow.serialization.enums import stringify_encoding_keys from airflow.settings import task_instance_mutation_hook from airflow.task.priority_strategy import validate_and_load_priority_weight_strategy from airflow.ti_deps.dep_context import DepContext @@ -1691,7 +1692,7 @@ def defer_task(self, session: Session = NEW_SESSION) -> bool: self.state = TaskInstanceState.DEFERRED self.trigger_id = trigger_row.id self.next_method = start_trigger_args.next_method - self.next_kwargs = start_trigger_args.next_kwargs or {} + self.next_kwargs = stringify_encoding_keys(start_trigger_args.next_kwargs or {}) self.start_date = timezone.utcnow() # If an execution_timeout is set, set the timeout to the minimum of diff --git a/airflow-core/src/airflow/models/trigger.py b/airflow-core/src/airflow/models/trigger.py index 2949262532ef0..a6ee583032a88 100644 --- a/airflow-core/src/airflow/models/trigger.py +++ b/airflow-core/src/airflow/models/trigger.py @@ -35,6 +35,7 @@ from airflow.models.asset import AssetWatcherModel from airflow.models.base import Base from airflow.models.taskinstance import TaskInstance +from airflow.serialization.enums import stringify_encoding_keys as _stringify_encoding_keys from airflow.triggers.base import BaseTaskEndEvent from airflow.utils.retries import run_with_db_retries from airflow.utils.session import NEW_SESSION, provide_session @@ -146,7 +147,7 @@ def encrypt_kwargs(kwargs: dict[str, Any]) -> str: from airflow.models.crypto import get_fernet from airflow.sdk.serde import serialize - serialized_kwargs = serialize(kwargs) + serialized_kwargs = serialize(_stringify_encoding_keys(kwargs)) return get_fernet().encrypt(json.dumps(serialized_kwargs).encode("utf-8")).decode("utf-8") @staticmethod diff --git a/airflow-core/src/airflow/serialization/enums.py b/airflow-core/src/airflow/serialization/enums.py index 5fdd4d6698793..ae4c1249cab09 100644 --- a/airflow-core/src/airflow/serialization/enums.py +++ b/airflow-core/src/airflow/serialization/enums.py @@ -20,6 +20,7 @@ from __future__ import annotations from enum import Enum, unique +from typing import Any # Fields of an encoded object in serialization. @@ -31,6 +32,31 @@ class Encoding(str, Enum): VAR = "__var" +def stringify_encoding_keys(d: Any) -> Any: + """ + Convert BaseSerialization Encoding enum keys to their string values recursively. + + Python 3.10 compatibility: str(Encoding.TYPE) returns "Encoding.TYPE" on 3.10 + instead of "__type__" (3.10 is still the default CI target). serde.serialize + uses str(k) for dict keys, so without this conversion the encrypted blob ends up + with "Encoding.TYPE" keys that neither serde._convert nor the BaseSerialization + fallback can read back. + """ + if isinstance(d, dict): + return { + (k.value if isinstance(k, Encoding) else str(k)): stringify_encoding_keys(v) for k, v in d.items() + } + if isinstance(d, list): + return [stringify_encoding_keys(i) for i in d] + if isinstance(d, tuple): + converted = [stringify_encoding_keys(i) for i in d] + # namedtuples require positional args, not a single list argument + if hasattr(d, "_fields"): + return type(d)(*converted) + return tuple(converted) + return d + + # Supported types for encoding. primitives and list are not encoded. @unique class DagAttributeTypes(str, Enum): diff --git a/airflow-core/src/airflow/serialization/serialized_objects.py b/airflow-core/src/airflow/serialization/serialized_objects.py index a0295dc237003..7764323fca7cc 100644 --- a/airflow-core/src/airflow/serialization/serialized_objects.py +++ b/airflow-core/src/airflow/serialization/serialized_objects.py @@ -251,17 +251,11 @@ def serialize_kwargs(key: str) -> Any: def _decode_start_trigger_args(var: dict[str, Any]) -> StartTriggerArgs: """Decode a StartTriggerArgs.""" - - def deserialize_kwargs(key: str) -> Any: - if (val := var[key]) is None: - return None - return BaseSerialization.deserialize(val) - return StartTriggerArgs( trigger_cls=var["trigger_cls"], - trigger_kwargs=deserialize_kwargs("trigger_kwargs"), + trigger_kwargs=var["trigger_kwargs"], next_method=var["next_method"], - next_kwargs=deserialize_kwargs("next_kwargs"), + next_kwargs=var["next_kwargs"], timeout=datetime.timedelta(seconds=var["timeout"]) if var["timeout"] else None, ) diff --git a/airflow-core/src/airflow/triggers/base.py b/airflow-core/src/airflow/triggers/base.py index b37448edfea19..f39b62facf7b2 100644 --- a/airflow-core/src/airflow/triggers/base.py +++ b/airflow-core/src/airflow/triggers/base.py @@ -120,9 +120,16 @@ def task_instance(self, value: TaskInstance | None) -> None: # does not build a template context, so render_template_fields is # never called and empty template_fields is safe. start_trigger_args = getattr(self.task, "start_trigger_args", None) - trigger_kwarg_keys = ( - set((start_trigger_args.trigger_kwargs or {}).keys()) if start_trigger_args else set() - ) + if start_trigger_args: + from airflow.serialization.enums import Encoding + + raw = start_trigger_args.trigger_kwargs or {} + # trigger_kwargs may be BaseSerialization-encoded; extract inner dict keys + if isinstance(raw, dict) and Encoding.TYPE in raw: + raw = raw.get(Encoding.VAR) or {} + trigger_kwarg_keys = set(raw.keys()) + else: + trigger_kwarg_keys = set() if trigger_kwarg_keys: self.template_fields = tuple( f for f in self.task.template_fields if f in trigger_kwarg_keys and hasattr(self, f) @@ -256,9 +263,11 @@ def hash(classpath: str, kwargs: dict[str, Any]) -> int: We do not want to have this logic in ``BaseTrigger`` because, when used to defer tasks, 2 triggers can have the same classpath and kwargs. This is not true for event driven scheduling. """ + from airflow.serialization.encoders import encode_trigger from airflow.serialization.serialized_objects import BaseSerialization - return hash((classpath, json.dumps(BaseSerialization.serialize(kwargs)).encode("utf-8"))) + normalized = encode_trigger({"classpath": classpath, "kwargs": kwargs})["kwargs"] + return hash((classpath, json.dumps(BaseSerialization.serialize(normalized)).encode("utf-8"))) class TriggerEvent(BaseModel): diff --git a/airflow-core/tests/unit/jobs/test_triggerer_job.py b/airflow-core/tests/unit/jobs/test_triggerer_job.py index 752247c6d1752..9ef0d28f33f9e 100644 --- a/airflow-core/tests/unit/jobs/test_triggerer_job.py +++ b/airflow-core/tests/unit/jobs/test_triggerer_job.py @@ -851,15 +851,7 @@ async def test_trigger_kwargs_serialization_cleanup(self, session): session.commit() stored_kwargs = trigger_orm.kwargs - assert stored_kwargs == { - "Encoding.TYPE": "dict", - "Encoding.VAR": { - "dict": {"Encoding.TYPE": "dict", "Encoding.VAR": {}}, - "list": [], - "simple": "test", - "tuple": {"Encoding.TYPE": "tuple", "Encoding.VAR": []}, - }, - } + assert stored_kwargs == kw runner = TriggerRunner() runner.to_create.append( diff --git a/airflow-core/tests/unit/models/test_dagrun.py b/airflow-core/tests/unit/models/test_dagrun.py index 34e8acbaacf60..478e4f5ea4d42 100644 --- a/airflow-core/tests/unit/models/test_dagrun.py +++ b/airflow-core/tests/unit/models/test_dagrun.py @@ -54,6 +54,7 @@ from airflow.models.taskinstance import TaskInstance, TaskInstanceNote, clear_task_instances from airflow.models.taskmap import TaskMap from airflow.models.taskreschedule import TaskReschedule +from airflow.models.trigger import Trigger from airflow.providers.standard.operators.bash import BashOperator from airflow.providers.standard.operators.empty import EmptyOperator from airflow.providers.standard.operators.python import PythonOperator, ShortCircuitOperator @@ -2298,6 +2299,87 @@ def execute_complete(self): assert ti.state == TaskInstanceState.DEFERRED +@pytest.mark.need_serialized_dag +def test_schedule_tis_start_trigger_next_kwargs_round_trip(dag_maker, session): + """next_kwargs with encoded values (timedelta) must survive the defer_task round-trip.""" + import datetime + + from airflow.sdk.serde import deserialize + + class TestOperator(BaseOperator): + start_trigger_args = StartTriggerArgs( + trigger_cls="airflow.triggers.testing.SuccessTrigger", + trigger_kwargs={}, + next_method="execute_complete", + next_kwargs={"delay": datetime.timedelta(seconds=30)}, + timeout=None, + ) + start_from_trigger = True + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def execute_complete(self): + pass + + with dag_maker(session=session): + TestOperator(task_id="test_task") + + dr: DagRun = dag_maker.create_dagrun() + ti = dr.get_task_instance("test_task") + ti.task = dr.dag.get_task("test_task") + dr.schedule_tis((ti,), session=session) + + assert ti.state == TaskInstanceState.DEFERRED + assert deserialize(ti.next_kwargs) == {"delay": datetime.timedelta(seconds=30)} + + +@pytest.mark.need_serialized_dag +def test_schedule_tis_start_trigger_kwargs_e2e(dag_maker, session): + """ + End to end test of scheduler defer_task with non-trivial trigger_kwargs (timedelta) -> + Trigger row -> Trigger.kwargs returns correct Python objects. + + Covers the path: BaseSerialization encodes trigger_kwargs with Encoding enum keys, + defer_task passes them to Trigger as kwargs which calls encrypt_kwargs -> _stringify_encoding_keys + -> serde.serialize -> stores them. + + On reading, serde.deserialize + _convert must reconstruct the original values. + """ + import datetime + + class TestOperator(BaseOperator): + start_trigger_args = StartTriggerArgs( + trigger_cls="airflow.triggers.testing.SuccessTrigger", + trigger_kwargs={"delta": datetime.timedelta(seconds=2)}, + next_method="execute_complete", + next_kwargs=None, + timeout=None, + ) + start_from_trigger = True + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def execute_complete(self): + pass + + with dag_maker(session=session): + TestOperator(task_id="test_task") + + dr: DagRun = dag_maker.create_dagrun() + ti = dr.get_task_instance("test_task") + ti.task = dr.dag.get_task("test_task") + dr.schedule_tis((ti,), session=session) + + assert ti.state == TaskInstanceState.DEFERRED + + trigger_row = session.get(Trigger, ti.trigger_id) + assert trigger_row is not None + # trigger_kwargs must round-trip correctly through encrypt_kwargs → _decrypt_kwargs + assert trigger_row.kwargs == {"delta": datetime.timedelta(seconds=2)} + + def test_schedule_tis_empty_operator_try_number(dag_maker, session: Session): """ When empty operator is not actually run, then we need to increment the try_number, diff --git a/airflow-core/tests/unit/serialization/test_dag_serialization.py b/airflow-core/tests/unit/serialization/test_dag_serialization.py index 33db3d8d906d4..fd333d5d79918 100644 --- a/airflow-core/tests/unit/serialization/test_dag_serialization.py +++ b/airflow-core/tests/unit/serialization/test_dag_serialization.py @@ -817,6 +817,8 @@ def validate_deserialized_task( "on_failure_fail_dagrun", "_needs_expansion", "_is_sensor", + # trigger_kwargs is kept as raw JSON after deserialization; checked separately + "start_trigger_args", } else: # Promised to be mapped by the assert above. assert isinstance(serialized_task, SerializedMappedOperator) @@ -857,6 +859,20 @@ def validate_deserialized_task( else: assert serialized_task.resources == task.resources + # start_trigger_args: trigger_kwargs is kept as raw BaseSerialization-encoded form + # after deserialization. Compare the encoded forms directly — s.trigger_kwargs is + # exactly BaseSerialization.serialize(o.trigger_kwargs) since _encode_start_trigger_args + # serializes it and _decode_start_trigger_args keeps it raw. + if task.start_trigger_args is not None: + from airflow.serialization.serialized_objects import BaseSerialization + + s = serialized_task.start_trigger_args + o = task.start_trigger_args + assert s.trigger_cls == o.trigger_cls + assert s.next_method == o.next_method + assert s.timeout == o.timeout + assert s.trigger_kwargs == BaseSerialization.serialize(o.trigger_kwargs or {}) + assert [ensure_serialized_asset(i) for i in task.inlets] == serialized_task.inlets assert [ensure_serialized_asset(o) for o in task.outlets] == serialized_task.outlets @@ -2630,6 +2646,42 @@ def execute_complete(self): } assert tasks[1]["__var"]["start_from_trigger"] is True + def test_trigger_kwargs_not_deserialised_through_serdag(self): + """trigger_kwargs and next_kwargs are kept as raw BaseSerialization JSON when loading a serialized DAG.""" + + class TestOperator(BaseOperator): + start_trigger_args = StartTriggerArgs( + trigger_cls="airflow.providers.standard.triggers.temporal.TimeDeltaTrigger", + trigger_kwargs={"delta": timedelta(seconds=2)}, + next_method="execute_complete", + next_kwargs={"resume_after": timedelta(seconds=5)}, + timeout=None, + ) + start_from_trigger = True + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def execute_complete(self): + pass + + dag = DAG(dag_id="test_dag_kwargs_raw", schedule=None, start_date=datetime(2023, 11, 9)) + with dag: + TestOperator(task_id="test_task") + + serialized = DagSerialization.to_dict(dag) + deserialized_dag = DagSerialization.from_dict(serialized) + + task = deserialized_dag.get_task("test_task") + assert task.start_trigger_args.trigger_kwargs == { + "__type": "dict", + "__var": {"delta": {"__type": "timedelta", "__var": 2.0}}, + } + assert task.start_trigger_args.next_kwargs == { + "__type": "dict", + "__var": {"resume_after": {"__type": "timedelta", "__var": 5.0}}, + } + def test_kubernetes_optional(): """Test that serialization module loads without kubernetes, but deserialization of PODs requires it"""