Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AIP-82 Handle trigger serialization #45562

Open
wants to merge 21 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 16 additions & 19 deletions airflow/dag_processing/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
import json
import logging
import traceback
from typing import TYPE_CHECKING, Any, NamedTuple
from typing import TYPE_CHECKING, Any, NamedTuple, cast

from sqlalchemy import and_, delete, exists, func, insert, select, tuple_
from sqlalchemy.exc import OperationalError
Expand All @@ -52,9 +52,8 @@
from airflow.models.dagwarning import DagWarningType
from airflow.models.errors import ParseImportError
from airflow.models.trigger import Trigger
from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetNameRef, AssetUriRef
from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetNameRef, AssetUriRef, AssetWatcher
from airflow.serialization.serialized_objects import BaseSerialization
from airflow.triggers.base import BaseTrigger
from airflow.utils.retries import MAX_DB_RETRIES, run_with_db_retries
from airflow.utils.sqlalchemy import with_row_locks
from airflow.utils.timezone import utcnow
Expand All @@ -68,7 +67,6 @@

from airflow.models.dagwarning import DagWarning
from airflow.serialization.serialized_objects import MaybeSerializedDAG
from airflow.triggers.base import BaseTrigger
from airflow.typing_compat import Self

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -737,16 +735,19 @@ def add_asset_trigger_references(
# Update references from assets being used
refs_to_add: dict[tuple[str, str], set[int]] = {}
refs_to_remove: dict[tuple[str, str], set[int]] = {}
triggers: dict[int, BaseTrigger] = {}
triggers: dict[int, dict] = {}

# Optimization: if no asset collected, skip fetching active assets
active_assets = _find_active_assets(self.assets.keys(), session=session) if self.assets else {}

for name_uri, asset in self.assets.items():
# If the asset belongs to a DAG that is not active or is paused, consider that no watcher is associated with it
asset_watchers = asset.watchers if name_uri in active_assets else []
trigger_hash_to_trigger_dict: dict[int, BaseTrigger] = {
self._get_base_trigger_hash(trigger): trigger for trigger in asset_watchers
asset_watchers: list[AssetWatcher] = asset.watchers if name_uri in active_assets else []
trigger_hash_to_trigger_dict: dict[int, dict] = {
self._get_trigger_hash(
cast(dict, watcher.trigger)["classpath"], cast(dict, watcher.trigger)["kwargs"]
): cast(dict, watcher.trigger)
vincbeck marked this conversation as resolved.
Show resolved Hide resolved
for watcher in asset_watchers
}
triggers.update(trigger_hash_to_trigger_dict)
trigger_hash_from_asset: set[int] = set(trigger_hash_to_trigger_dict.keys())
Expand All @@ -773,7 +774,10 @@ def add_asset_trigger_references(
}

all_trigger_keys: set[tuple[str, str]] = {
self._encrypt_trigger_kwargs(triggers[trigger_hash])
(
triggers[trigger_hash]["classpath"],
Trigger.encrypt_kwargs(triggers[trigger_hash]["kwargs"]),
)
for trigger_hashes in refs_to_add.values()
for trigger_hash in trigger_hashes
}
Expand All @@ -790,7 +794,9 @@ def add_asset_trigger_references(
new_trigger_models = [
trigger
for trigger in [
Trigger.from_object(triggers[trigger_hash])
Trigger(
classpath=triggers[trigger_hash]["classpath"], kwargs=triggers[trigger_hash]["kwargs"]
)
for trigger_hash in all_trigger_hashes
if trigger_hash not in orm_triggers
]
Expand Down Expand Up @@ -826,11 +832,6 @@ def add_asset_trigger_references(
if (asset_model.name, asset_model.uri) not in self.assets:
asset_model.triggers = []

@staticmethod
def _encrypt_trigger_kwargs(trigger: BaseTrigger) -> tuple[str, str]:
    """
    Build the ``(classpath, encrypted kwargs)`` pair for a trigger.

    The trigger is serialized into its classpath and kwargs; the kwargs are
    then passed through ``Trigger.encrypt_kwargs`` and returned alongside the
    unchanged classpath.
    """
    trigger_classpath, trigger_kwargs = trigger.serialize()
    encrypted_kwargs = Trigger.encrypt_kwargs(trigger_kwargs)
    return trigger_classpath, encrypted_kwargs

@staticmethod
def _get_trigger_hash(classpath: str, kwargs: dict[str, Any]) -> int:
"""
Expand All @@ -842,7 +843,3 @@ def _get_trigger_hash(classpath: str, kwargs: dict[str, Any]) -> int:
This is not true for event driven scheduling.
"""
return hash((classpath, json.dumps(BaseSerialization.serialize(kwargs)).encode("utf-8")))

def _get_base_trigger_hash(self, trigger: BaseTrigger) -> int:
    """
    Compute the hash of a trigger instance.

    The trigger is serialized into its classpath and kwargs, which are then
    handed to ``_get_trigger_hash`` to produce the hash value.
    """
    trigger_classpath, trigger_kwargs = trigger.serialize()
    trigger_hash = self._get_trigger_hash(trigger_classpath, trigger_kwargs)
    return trigger_hash
7 changes: 3 additions & 4 deletions airflow/example_dags/example_asset_with_watchers.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,14 @@
from __future__ import annotations

import os
import tempfile

from airflow.decorators import task
from airflow.models.baseoperator import chain
from airflow.models.dag import DAG
from airflow.providers.standard.triggers.file import FileTrigger
from airflow.sdk.definitions.asset import Asset
from airflow.sdk.definitions.asset import Asset, AssetWatcher
vincbeck marked this conversation as resolved.
Show resolved Hide resolved

file_path = tempfile.NamedTemporaryFile().name
file_path = "/tmp/test"
vincbeck marked this conversation as resolved.
Show resolved Hide resolved

with DAG(
dag_id="example_create_file",
Expand All @@ -44,7 +43,7 @@ def create_file():
chain(create_file())

trigger = FileTrigger(filepath=file_path, poke_interval=10)
asset = Asset("example_asset", watchers=[trigger])
asset = Asset("example_asset", watchers=[AssetWatcher(name="test_file_watcher", trigger=trigger)])

with DAG(
dag_id="example_asset_with_watchers",
Expand Down
12 changes: 12 additions & 0 deletions airflow/serialization/schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@
{"type": "null"},
{ "$ref": "#/definitions/dict" }
]
},
"watchers": {
"type": "array",
"items": { "$ref": "#/definitions/trigger" }
}
},
"required": [ "uri", "extra" ]
Expand Down Expand Up @@ -126,6 +130,14 @@
],
"additionalProperties": false
},
"trigger": {
"type": "object",
"properties": {
"classpath": { "type": "string" },
"kwargs": { "$ref": "#/definitions/dict" }
},
"required": [ "classpath", "kwargs" ]
},
"dict": {
"description": "A python dictionary containing values of any type",
"type": "object"
Expand Down
44 changes: 41 additions & 3 deletions airflow/serialization/serialized_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
AssetAny,
AssetRef,
AssetUniqueKey,
AssetWatcher,
BaseAsset,
)
from airflow.sdk.definitions.baseoperator import BaseOperator as TaskSDKBaseOperator
Expand Down Expand Up @@ -251,13 +252,34 @@ def encode_asset_condition(var: BaseAsset) -> dict[str, Any]:
:meta private:
"""
if isinstance(var, Asset):
return {

def _encode_watcher(watcher: AssetWatcher):
return {
"name": watcher.name,
"trigger": _encode_trigger(watcher.trigger),
}

def _encode_trigger(trigger: BaseTrigger | dict):
if isinstance(trigger, dict):
return trigger
classpath, kwargs = trigger.serialize()
return {
"classpath": classpath,
"kwargs": kwargs,
}

asset = {
"__type": DAT.ASSET,
"name": var.name,
"uri": var.uri,
"group": var.group,
"extra": var.extra,
}

if len(var.watchers) > 0:
asset["watchers"] = [_encode_watcher(watcher) for watcher in var.watchers]

return asset
if isinstance(var, AssetAlias):
return {"__type": DAT.ASSET_ALIAS, "name": var.name, "group": var.group}
if isinstance(var, AssetAll):
Expand All @@ -283,7 +305,17 @@ def decode_asset_condition(var: dict[str, Any]) -> BaseAsset:
"""
dat = var["__type"]
if dat == DAT.ASSET:
return Asset(name=var["name"], uri=var["uri"], group=var["group"], extra=var["extra"])
serialized_watchers = var["watchers"] if "watchers" in var else []
return Asset(
name=var["name"],
uri=var["uri"],
group=var["group"],
extra=var["extra"],
vincbeck marked this conversation as resolved.
Show resolved Hide resolved
watchers=[
AssetWatcher(name=watcher["name"], trigger=watcher["trigger"])
for watcher in serialized_watchers
],
)
if dat == DAT.ASSET_ALL:
return AssetAll(*(decode_asset_condition(x) for x in var["objects"]))
if dat == DAT.ASSET_ANY:
Expand Down Expand Up @@ -874,7 +906,13 @@ def deserialize(cls, encoded_var: Any) -> Any:
elif type_ == DAT.XCOM_REF:
return _XComRef(var) # Delay deserializing XComArg objects until we have the entire DAG.
elif type_ == DAT.ASSET:
return Asset(**var)
watchers = var.pop("watchers", [])
return Asset(
**var,
watchers=[
AssetWatcher(name=watcher["name"], trigger=watcher["trigger"]) for watcher in watchers
],
)
elif type_ == DAT.ASSET_ALIAS:
return AssetAlias(**var)
elif type_ == DAT.ASSET_ANY:
Expand Down
22 changes: 17 additions & 5 deletions task_sdk/src/airflow/sdk/definitions/asset/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
"AssetNameRef",
"AssetRef",
"AssetUriRef",
"AssetWatcher",
]


Expand Down Expand Up @@ -257,6 +258,17 @@ def iter_dag_dependencies(self, *, source: str, target: str) -> Iterator[DagDepe
raise NotImplementedError


@attrs.define(frozen=True)
class AssetWatcher:
"""A representation of an asset watcher. The name uniquely identifies the watch."""
vincbeck marked this conversation as resolved.
Show resolved Hide resolved

name: str
# This attribute serves a double purpose. For a "normal" asset instance
# loaded from a DAG, this holds the trigger used to monitor an external resource.
# For an asset recreated from a serialized DAG, however, this holds the serialized data of the trigger.
vincbeck marked this conversation as resolved.
Show resolved Hide resolved
trigger: BaseTrigger | dict


@attrs.define(init=False, unsafe_hash=False)
class Asset(os.PathLike, BaseAsset):
"""A representation of data asset dependencies between workflows."""
Expand All @@ -276,7 +288,7 @@ class Asset(os.PathLike, BaseAsset):
factory=dict,
converter=_set_extra_default,
)
watchers: list[BaseTrigger] = attrs.field(
watchers: list[AssetWatcher] = attrs.field(
factory=list,
)

Expand All @@ -291,7 +303,7 @@ def __init__(
*,
group: str = ...,
extra: dict | None = None,
watchers: list[BaseTrigger] = ...,
watchers: list[AssetWatcher] = ...,
) -> None:
"""Canonical; both name and uri are provided."""

Expand All @@ -302,7 +314,7 @@ def __init__(
*,
group: str = ...,
extra: dict | None = None,
watchers: list[BaseTrigger] = ...,
watchers: list[AssetWatcher] = ...,
) -> None:
"""It's possible to only provide the name, either by keyword or as the only positional argument."""

Expand All @@ -313,7 +325,7 @@ def __init__(
uri: str,
group: str = ...,
extra: dict | None = None,
watchers: list[BaseTrigger] = ...,
watchers: list[AssetWatcher] = ...,
) -> None:
"""It's possible to only provide the URI as a keyword argument."""

Expand All @@ -324,7 +336,7 @@ def __init__(
*,
group: str | None = None,
extra: dict | None = None,
watchers: list[BaseTrigger] | None = None,
watchers: list[AssetWatcher] | None = None,
) -> None:
if name is None and uri is None:
raise TypeError("Asset() requires either 'name' or 'uri'")
Expand Down
8 changes: 6 additions & 2 deletions tests/dag_processing/test_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
from airflow.models.serialized_dag import SerializedDagModel
from airflow.operators.empty import EmptyOperator
from airflow.providers.standard.triggers.temporal import TimeDeltaTrigger
from airflow.sdk.definitions.asset import Asset
from airflow.sdk.definitions.asset import Asset, AssetWatcher
from airflow.serialization.serialized_objects import LazyDeserializedDAG, SerializedDAG
from airflow.utils import timezone as tz
from airflow.utils.session import create_session
Expand Down Expand Up @@ -131,7 +131,11 @@ def per_test(self) -> Generator:
)
def test_add_asset_trigger_references(self, is_active, is_paused, expected_num_triggers, dag_maker):
trigger = TimeDeltaTrigger(timedelta(seconds=0))
asset = Asset("test_add_asset_trigger_references_asset", watchers=[trigger])
classpath, kwargs = trigger.serialize()
asset = Asset(
"test_add_asset_trigger_references_asset",
watchers=[AssetWatcher(name="test", trigger={"classpath": classpath, "kwargs": kwargs})],
)

with dag_maker(dag_id="test_add_asset_trigger_references_dag", schedule=[asset]) as dag:
EmptyOperator(task_id="mytask")
Expand Down
12 changes: 11 additions & 1 deletion tests/serialization/test_serialized_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@
from airflow.models.xcom_arg import XComArg
from airflow.operators.empty import EmptyOperator
from airflow.providers.standard.operators.python import PythonOperator
from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAliasEvent, AssetUniqueKey
from airflow.providers.standard.triggers.file import FileTrigger
from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAliasEvent, AssetUniqueKey, AssetWatcher
from airflow.sdk.execution_time.context import OutletEventAccessor, OutletEventAccessors
from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding
from airflow.serialization.serialized_objects import BaseSerialization
Expand Down Expand Up @@ -254,6 +255,15 @@ def __len__(self) -> int:
lambda a, b: len(a) == len(b) and isinstance(b, list),
),
(Asset(uri="test://asset1", name="test"), DAT.ASSET, equals),
(
Asset(
uri="test://asset1",
name="test",
watchers=[AssetWatcher(name="test", trigger=FileTrigger(filepath="/tmp"))],
),
DAT.ASSET,
equals,
),
(SimpleTaskInstance.from_ti(ti=TI), DAT.SIMPLE_TASK_INSTANCE, equals),
(
Connection(conn_id="TEST_ID", uri="mysql://"),
Expand Down