Skip to content
10 changes: 2 additions & 8 deletions airflow-core/src/airflow/serialization/serialized_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,17 +251,11 @@ def serialize_kwargs(key: str) -> Any:

def _decode_start_trigger_args(var: dict[str, Any]) -> StartTriggerArgs:
"""Decode a StartTriggerArgs."""

def deserialize_kwargs(key: str) -> Any:
if (val := var[key]) is None:
return None
return BaseSerialization.deserialize(val)

return StartTriggerArgs(
trigger_cls=var["trigger_cls"],
trigger_kwargs=deserialize_kwargs("trigger_kwargs"),
trigger_kwargs=var["trigger_kwargs"],
next_method=var["next_method"],
next_kwargs=deserialize_kwargs("next_kwargs"),
next_kwargs=var["next_kwargs"],
timeout=datetime.timedelta(seconds=var["timeout"]) if var["timeout"] else None,
)

Expand Down
32 changes: 32 additions & 0 deletions airflow-core/tests/unit/serialization/test_dag_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -2627,6 +2627,38 @@ def execute_complete(self):
}
assert tasks[1]["__var"]["start_from_trigger"] is True

def test_trigger_kwargs_not_deserialised_through_serdag(self):
Comment thread
amoghrajesh marked this conversation as resolved.
"""trigger_kwargs are not deserialized when loading a serialized DAG."""

class TestOperator(BaseOperator):
start_trigger_args = StartTriggerArgs(
trigger_cls="airflow.providers.standard.triggers.temporal.TimeDeltaTrigger",
trigger_kwargs={"delta": timedelta(seconds=2)},
next_method="execute_complete",
next_kwargs=None,
timeout=None,
)
start_from_trigger = True

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def execute_complete(self):
pass

dag = DAG(dag_id="test_dag_kwargs_raw", schedule=None, start_date=datetime(2023, 11, 9))
with dag:
TestOperator(task_id="test_task")

serialized = DagSerialization.to_dict(dag)
deserialized_dag = DagSerialization.from_dict(serialized)

task = deserialized_dag.get_task("test_task")
assert task.start_trigger_args.trigger_kwargs == {
"__type": "dict",
"__var": {"delta": {"__type": "timedelta", "__var": 2.0}},
}


def test_kubernetes_optional():
"""Test that serialization module loads without kubernetes, but deserialization of PODs requires it"""
Expand Down
Loading