Skip to content
14 changes: 13 additions & 1 deletion airflow-core/src/airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -1633,7 +1633,19 @@ def defer_task(self, session: Session = NEW_SESSION) -> bool:
assert isinstance(self.task, Operator)

if start_trigger_args := self.start_trigger_args:
    from airflow.serialization.enums import Encoding

    def _normalize(d: Any) -> Any:
        """Recursively convert ``Encoding`` enum dict keys to their string values.

        trigger_kwargs arrives with Encoding enum keys from BaseSerialization.
        On Python 3.10, ``str(Encoding.TYPE)`` returns ``"Encoding.TYPE"`` rather
        than ``"__type"``, so enum keys are converted to their ``.value`` before
        the dict is passed to serde.serialize. Non-dict values are returned
        unchanged.
        """
        if isinstance(d, dict):
            return {
                (k.value if isinstance(k, Encoding) else str(k)): _normalize(v)
                for k, v in d.items()
            }
        return d

    trigger_kwargs = _normalize(start_trigger_args.trigger_kwargs or {})
    timeout = start_trigger_args.timeout

# Calculate timeout too if it was passed
Expand Down
10 changes: 2 additions & 8 deletions airflow-core/src/airflow/serialization/serialized_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,17 +251,11 @@ def serialize_kwargs(key: str) -> Any:

def _decode_start_trigger_args(var: dict[str, Any]) -> StartTriggerArgs:
    """Decode a StartTriggerArgs from its serialized dict form.

    ``trigger_kwargs`` and ``next_kwargs`` are intentionally kept in their raw
    serialized form here; deserialization happens later, when the trigger is
    actually started. ``timeout`` is stored as seconds and rehydrated into a
    ``timedelta`` (``None``/0 → ``None``).
    """
    return StartTriggerArgs(
        trigger_cls=var["trigger_cls"],
        trigger_kwargs=var["trigger_kwargs"],
        next_method=var["next_method"],
        next_kwargs=var["next_kwargs"],
        timeout=datetime.timedelta(seconds=var["timeout"]) if var["timeout"] else None,
    )

Expand Down
10 changes: 1 addition & 9 deletions airflow-core/tests/unit/jobs/test_triggerer_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -646,15 +646,7 @@ async def test_trigger_kwargs_serialization_cleanup(self, session):
session.commit()

stored_kwargs = trigger_orm.kwargs
assert stored_kwargs == {
"Encoding.TYPE": "dict",
"Encoding.VAR": {
"dict": {"Encoding.TYPE": "dict", "Encoding.VAR": {}},
"list": [],
"simple": "test",
"tuple": {"Encoding.TYPE": "tuple", "Encoding.VAR": []},
},
}
assert stored_kwargs == kw

runner = TriggerRunner()
runner.to_create.append(
Expand Down
48 changes: 48 additions & 0 deletions airflow-core/tests/unit/serialization/test_dag_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -815,6 +815,8 @@ def validate_deserialized_task(
"on_failure_fail_dagrun",
"_needs_expansion",
"_is_sensor",
# trigger_kwargs is kept as raw JSON after deserialization; checked separately
"start_trigger_args",
}
else: # Promised to be mapped by the assert above.
assert isinstance(serialized_task, SerializedMappedOperator)
Expand Down Expand Up @@ -855,6 +857,20 @@ def validate_deserialized_task(
else:
assert serialized_task.resources == task.resources

# start_trigger_args: trigger_kwargs is kept as raw JSON after deserialization;
# compare after deserializing both sides
if task.start_trigger_args is not None:
from airflow.serialization.serialized_objects import BaseSerialization

s = serialized_task.start_trigger_args
o = task.start_trigger_args
assert s.trigger_cls == o.trigger_cls
assert s.next_method == o.next_method
Comment thread
amoghrajesh marked this conversation as resolved.
assert s.timeout == o.timeout
assert BaseSerialization.deserialize(s.trigger_kwargs or {}) == BaseSerialization.deserialize(
o.trigger_kwargs or {}
)

assert [ensure_serialized_asset(i) for i in task.inlets] == serialized_task.inlets
assert [ensure_serialized_asset(o) for o in task.outlets] == serialized_task.outlets

Expand Down Expand Up @@ -2627,6 +2643,38 @@ def execute_complete(self):
}
assert tasks[1]["__var"]["start_from_trigger"] is True

def test_trigger_kwargs_not_deserialised_through_serdag(self):
    """trigger_kwargs are not deserialized when loading a serialized DAG.

    After a serialize/deserialize round trip, ``start_trigger_args.trigger_kwargs``
    must still be in BaseSerialization's raw encoded form.
    """

    class TestOperator(BaseOperator):
        # Class-level start-from-trigger configuration picked up by serialization.
        start_trigger_args = StartTriggerArgs(
            trigger_cls="airflow.providers.standard.triggers.temporal.TimeDeltaTrigger",
            trigger_kwargs={"delta": timedelta(seconds=2)},
            next_method="execute_complete",
            next_kwargs=None,
            timeout=None,
        )
        start_from_trigger = True

        def execute_complete(self):
            pass

    dag = DAG(dag_id="test_dag_kwargs_raw", schedule=None, start_date=datetime(2023, 11, 9))
    with dag:
        TestOperator(task_id="test_task")

    serialized = DagSerialization.to_dict(dag)
    deserialized_dag = DagSerialization.from_dict(serialized)

    task = deserialized_dag.get_task("test_task")
    # Raw encoded form: Encoding.TYPE/Encoding.VAR string keys, not a rehydrated dict.
    assert task.start_trigger_args.trigger_kwargs == {
        "__type": "dict",
        "__var": {"delta": {"__type": "timedelta", "__var": 2.0}},
    }


def test_kubernetes_optional():
"""Test that serialization module loads without kubernetes, but deserialization of PODs requires it"""
Expand Down
Loading