Skip to content
3 changes: 2 additions & 1 deletion airflow-core/src/airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@
from airflow.models.taskmap import TaskMap
from airflow.models.taskreschedule import TaskReschedule
from airflow.models.xcom import XCOM_RETURN_KEY, LazyXComSelectSequence, XComModel
from airflow.serialization.enums import stringify_encoding_keys
from airflow.settings import task_instance_mutation_hook
from airflow.task.priority_strategy import validate_and_load_priority_weight_strategy
from airflow.ti_deps.dep_context import DepContext
Expand Down Expand Up @@ -1691,7 +1692,7 @@ def defer_task(self, session: Session = NEW_SESSION) -> bool:
self.state = TaskInstanceState.DEFERRED
self.trigger_id = trigger_row.id
self.next_method = start_trigger_args.next_method
self.next_kwargs = start_trigger_args.next_kwargs or {}
self.next_kwargs = stringify_encoding_keys(start_trigger_args.next_kwargs or {})
self.start_date = timezone.utcnow()

# If an execution_timeout is set, set the timeout to the minimum of
Expand Down
3 changes: 2 additions & 1 deletion airflow-core/src/airflow/models/trigger.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from airflow.models.asset import AssetWatcherModel
from airflow.models.base import Base
from airflow.models.taskinstance import TaskInstance
from airflow.serialization.enums import stringify_encoding_keys as _stringify_encoding_keys
from airflow.triggers.base import BaseTaskEndEvent
from airflow.utils.retries import run_with_db_retries
from airflow.utils.session import NEW_SESSION, provide_session
Expand Down Expand Up @@ -146,7 +147,7 @@ def encrypt_kwargs(kwargs: dict[str, Any]) -> str:
from airflow.models.crypto import get_fernet
from airflow.sdk.serde import serialize

serialized_kwargs = serialize(kwargs)
serialized_kwargs = serialize(_stringify_encoding_keys(kwargs))
return get_fernet().encrypt(json.dumps(serialized_kwargs).encode("utf-8")).decode("utf-8")

@staticmethod
Expand Down
26 changes: 26 additions & 0 deletions airflow-core/src/airflow/serialization/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from __future__ import annotations

from enum import Enum, unique
from typing import Any


# Fields of an encoded object in serialization.
Expand All @@ -31,6 +32,31 @@ class Encoding(str, Enum):
VAR = "__var"


def stringify_encoding_keys(d: Any) -> Any:
    """
    Recursively replace ``Encoding`` enum dictionary keys with their plain string values.

    Python 3.10 compatibility: on 3.10, ``str()`` of an ``Encoding`` member returns the
    member's qualified name (e.g. ``"Encoding.TYPE"``) rather than its string value
    (3.10 is still the default CI target). ``serde.serialize`` applies ``str(k)`` to
    dict keys, so without this normalisation the encrypted blob would carry
    ``"Encoding.TYPE"`` keys that neither ``serde._convert`` nor the BaseSerialization
    fallback can read back.
    """
    if isinstance(d, dict):
        out = {}
        for key, value in d.items():
            normalized = key.value if isinstance(key, Encoding) else str(key)
            out[normalized] = stringify_encoding_keys(value)
        return out
    if isinstance(d, list):
        return [stringify_encoding_keys(item) for item in d]
    if isinstance(d, tuple):
        items = tuple(stringify_encoding_keys(item) for item in d)
        # namedtuples must be rebuilt with positional args so the subclass survives
        return type(d)(*items) if hasattr(d, "_fields") else items
    return d


# Supported types for encoding. primitives and list are not encoded.
@unique
class DagAttributeTypes(str, Enum):
Expand Down
10 changes: 2 additions & 8 deletions airflow-core/src/airflow/serialization/serialized_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,17 +251,11 @@ def serialize_kwargs(key: str) -> Any:

def _decode_start_trigger_args(var: dict[str, Any]) -> StartTriggerArgs:
    """
    Decode a StartTriggerArgs from its serialized dict form.

    ``trigger_kwargs`` and ``next_kwargs`` are intentionally kept in their raw
    BaseSerialization-encoded form (no eager ``BaseSerialization.deserialize``):
    consumers decode them at the point of use. ``timeout`` is stored as seconds
    and rehydrated into a ``timedelta`` (``None``/0 -> ``None``).
    """
    return StartTriggerArgs(
        trigger_cls=var["trigger_cls"],
        trigger_kwargs=var["trigger_kwargs"],
        next_method=var["next_method"],
        next_kwargs=var["next_kwargs"],
        timeout=datetime.timedelta(seconds=var["timeout"]) if var["timeout"] else None,
    )

Expand Down
17 changes: 13 additions & 4 deletions airflow-core/src/airflow/triggers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,16 @@ def task_instance(self, value: TaskInstance | None) -> None:
# does not build a template context, so render_template_fields is
# never called and empty template_fields is safe.
start_trigger_args = getattr(self.task, "start_trigger_args", None)
trigger_kwarg_keys = (
set((start_trigger_args.trigger_kwargs or {}).keys()) if start_trigger_args else set()
)
if start_trigger_args:
from airflow.serialization.enums import Encoding

raw = start_trigger_args.trigger_kwargs or {}
# trigger_kwargs may be BaseSerialization-encoded; extract inner dict keys
if isinstance(raw, dict) and Encoding.TYPE in raw:
raw = raw.get(Encoding.VAR) or {}
trigger_kwarg_keys = set(raw.keys())
else:
trigger_kwarg_keys = set()
if trigger_kwarg_keys:
self.template_fields = tuple(
f for f in self.task.template_fields if f in trigger_kwarg_keys and hasattr(self, f)
Expand Down Expand Up @@ -256,9 +263,11 @@ def hash(classpath: str, kwargs: dict[str, Any]) -> int:
We do not want to have this logic in ``BaseTrigger`` because, when used to defer tasks, 2 triggers
can have the same classpath and kwargs. This is not true for event driven scheduling.
"""
from airflow.serialization.encoders import encode_trigger
from airflow.serialization.serialized_objects import BaseSerialization

return hash((classpath, json.dumps(BaseSerialization.serialize(kwargs)).encode("utf-8")))
normalized = encode_trigger({"classpath": classpath, "kwargs": kwargs})["kwargs"]
return hash((classpath, json.dumps(BaseSerialization.serialize(normalized)).encode("utf-8")))


class TriggerEvent(BaseModel):
Expand Down
10 changes: 1 addition & 9 deletions airflow-core/tests/unit/jobs/test_triggerer_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -851,15 +851,7 @@ async def test_trigger_kwargs_serialization_cleanup(self, session):
session.commit()

stored_kwargs = trigger_orm.kwargs
assert stored_kwargs == {
"Encoding.TYPE": "dict",
"Encoding.VAR": {
"dict": {"Encoding.TYPE": "dict", "Encoding.VAR": {}},
"list": [],
"simple": "test",
"tuple": {"Encoding.TYPE": "tuple", "Encoding.VAR": []},
},
}
assert stored_kwargs == kw

runner = TriggerRunner()
runner.to_create.append(
Expand Down
82 changes: 82 additions & 0 deletions airflow-core/tests/unit/models/test_dagrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
from airflow.models.taskinstance import TaskInstance, TaskInstanceNote, clear_task_instances
from airflow.models.taskmap import TaskMap
from airflow.models.taskreschedule import TaskReschedule
from airflow.models.trigger import Trigger
from airflow.providers.standard.operators.bash import BashOperator
from airflow.providers.standard.operators.empty import EmptyOperator
from airflow.providers.standard.operators.python import PythonOperator, ShortCircuitOperator
Expand Down Expand Up @@ -2298,6 +2299,87 @@ def execute_complete(self):
assert ti.state == TaskInstanceState.DEFERRED


@pytest.mark.need_serialized_dag
def test_schedule_tis_start_trigger_next_kwargs_round_trip(dag_maker, session):
    """next_kwargs with encoded values (timedelta) must survive the defer_task round-trip."""
    import datetime

    from airflow.sdk.serde import deserialize

    class TestOperator(BaseOperator):
        # Task starts directly from the triggerer; next_kwargs carries a value
        # that requires encoding (timedelta) to exercise the round-trip.
        start_trigger_args = StartTriggerArgs(
            trigger_cls="airflow.triggers.testing.SuccessTrigger",
            trigger_kwargs={},
            next_method="execute_complete",
            next_kwargs={"delay": datetime.timedelta(seconds=30)},
            timeout=None,
        )
        start_from_trigger = True

        def execute_complete(self):
            pass

    with dag_maker(session=session):
        TestOperator(task_id="test_task")

    dr: DagRun = dag_maker.create_dagrun()
    ti = dr.get_task_instance("test_task")
    ti.task = dr.dag.get_task("test_task")
    dr.schedule_tis((ti,), session=session)

    assert ti.state == TaskInstanceState.DEFERRED
    # ti.next_kwargs is stored serde-encoded; deserialize must restore the timedelta.
    assert deserialize(ti.next_kwargs) == {"delay": datetime.timedelta(seconds=30)}


# NOTE(review): discussion thread on the test below was marked resolved.
@pytest.mark.need_serialized_dag
def test_schedule_tis_start_trigger_kwargs_e2e(dag_maker, session):
    """
    End to end test of scheduler defer_task with non-trivial trigger_kwargs (timedelta) ->
    Trigger row -> Trigger.kwargs returns correct Python objects.

    Covers the path: BaseSerialization encodes trigger_kwargs with Encoding enum keys,
    defer_task passes them to Trigger as kwargs which calls encrypt_kwargs -> _stringify_encoding_keys
    -> serde.serialize -> stores them.

    On reading, serde.deserialize + _convert must reconstruct the original values.
    """
    import datetime

    class TestOperator(BaseOperator):
        # timedelta forces a BaseSerialization-encoded value inside trigger_kwargs.
        start_trigger_args = StartTriggerArgs(
            trigger_cls="airflow.triggers.testing.SuccessTrigger",
            trigger_kwargs={"delta": datetime.timedelta(seconds=2)},
            next_method="execute_complete",
            next_kwargs=None,
            timeout=None,
        )
        start_from_trigger = True

        def execute_complete(self):
            pass

    with dag_maker(session=session):
        TestOperator(task_id="test_task")

    dr: DagRun = dag_maker.create_dagrun()
    ti = dr.get_task_instance("test_task")
    ti.task = dr.dag.get_task("test_task")
    dr.schedule_tis((ti,), session=session)

    assert ti.state == TaskInstanceState.DEFERRED

    trigger_row = session.get(Trigger, ti.trigger_id)
    assert trigger_row is not None
    # trigger_kwargs must round-trip correctly through encrypt_kwargs → _decrypt_kwargs
    assert trigger_row.kwargs == {"delta": datetime.timedelta(seconds=2)}


def test_schedule_tis_empty_operator_try_number(dag_maker, session: Session):
"""
When empty operator is not actually run, then we need to increment the try_number,
Expand Down
52 changes: 52 additions & 0 deletions airflow-core/tests/unit/serialization/test_dag_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -817,6 +817,8 @@ def validate_deserialized_task(
"on_failure_fail_dagrun",
"_needs_expansion",
"_is_sensor",
# trigger_kwargs is kept as raw JSON after deserialization; checked separately
"start_trigger_args",
}
else: # Promised to be mapped by the assert above.
assert isinstance(serialized_task, SerializedMappedOperator)
Expand Down Expand Up @@ -857,6 +859,20 @@ def validate_deserialized_task(
else:
assert serialized_task.resources == task.resources

# start_trigger_args: trigger_kwargs is kept as raw BaseSerialization-encoded form
# after deserialization. Compare the encoded forms directly — s.trigger_kwargs is
# exactly BaseSerialization.serialize(o.trigger_kwargs) since _encode_start_trigger_args
# serializes it and _decode_start_trigger_args keeps it raw.
if task.start_trigger_args is not None:
from airflow.serialization.serialized_objects import BaseSerialization

s = serialized_task.start_trigger_args
o = task.start_trigger_args
assert s.trigger_cls == o.trigger_cls
assert s.next_method == o.next_method
# NOTE(review): discussion thread on the assertions below was marked resolved.
assert s.timeout == o.timeout
assert s.trigger_kwargs == BaseSerialization.serialize(o.trigger_kwargs or {})

assert [ensure_serialized_asset(i) for i in task.inlets] == serialized_task.inlets
assert [ensure_serialized_asset(o) for o in task.outlets] == serialized_task.outlets

Expand Down Expand Up @@ -2630,6 +2646,42 @@ def execute_complete(self):
}
assert tasks[1]["__var"]["start_from_trigger"] is True

def test_trigger_kwargs_not_deserialised_through_serdag(self):
    """trigger_kwargs and next_kwargs are kept as raw BaseSerialization JSON when loading a serialized DAG."""

    class TestOperator(BaseOperator):
        # Both kwargs dicts hold timedeltas so the serialized form must carry
        # BaseSerialization's "__type"/"__var" envelope.
        start_trigger_args = StartTriggerArgs(
            trigger_cls="airflow.providers.standard.triggers.temporal.TimeDeltaTrigger",
            trigger_kwargs={"delta": timedelta(seconds=2)},
            next_method="execute_complete",
            next_kwargs={"resume_after": timedelta(seconds=5)},
            timeout=None,
        )
        start_from_trigger = True

        def execute_complete(self):
            pass

    dag = DAG(dag_id="test_dag_kwargs_raw", schedule=None, start_date=datetime(2023, 11, 9))
    with dag:
        TestOperator(task_id="test_task")

    serialized = DagSerialization.to_dict(dag)
    deserialized_dag = DagSerialization.from_dict(serialized)

    task = deserialized_dag.get_task("test_task")
    # The encoded envelope is preserved verbatim — no eager deserialization.
    assert task.start_trigger_args.trigger_kwargs == {
        "__type": "dict",
        "__var": {"delta": {"__type": "timedelta", "__var": 2.0}},
    }
    assert task.start_trigger_args.next_kwargs == {
        "__type": "dict",
        "__var": {"resume_after": {"__type": "timedelta", "__var": 5.0}},
    }


def test_kubernetes_optional():
"""Test that serialization module loads without kubernetes, but deserialization of PODs requires it"""
Expand Down
Loading