Skip to content

Commit

Permalink
AIP-82 Use hash instead of repr
Browse files Browse the repository at this point in the history
  • Loading branch information
vincbeck committed Dec 11, 2024
1 parent 7bce01b commit 6711957
Showing 1 changed file with 46 additions and 29 deletions.
75 changes: 46 additions & 29 deletions airflow/dag_processing/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@

import itertools
import logging
from typing import TYPE_CHECKING, NamedTuple
from typing import TYPE_CHECKING, Any, NamedTuple

from sqlalchemy import func, select, tuple_
from sqlalchemy.orm import joinedload, load_only
Expand All @@ -47,7 +47,7 @@
from airflow.models.dagrun import DagRun
from airflow.models.trigger import Trigger
from airflow.sdk.definitions.asset import Asset, AssetAlias
from airflow.triggers.base import BaseTrigger
from airflow.serialization.serialized_objects import BaseSerialization
from airflow.utils.sqlalchemy import with_row_locks
from airflow.utils.timezone import utcnow
from airflow.utils.types import DagRunType
Expand All @@ -58,6 +58,7 @@
from sqlalchemy.orm import Session
from sqlalchemy.sql import Select

from airflow.triggers.base import BaseTrigger
from airflow.typing_compat import Self

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -452,50 +453,50 @@ def add_asset_trigger_references(
self, assets: dict[tuple[str, str], AssetModel], *, session: Session
) -> None:
# Update references from assets being used
refs_to_add: dict[tuple[str, str], set[str]] = {}
refs_to_remove: dict[tuple[str, str], set[str]] = {}
triggers: dict[str, BaseTrigger] = {}
refs_to_add: dict[tuple[str, str], set[int]] = {}
refs_to_remove: dict[tuple[str, str], set[int]] = {}
triggers: dict[int, BaseTrigger] = {}

# Optimization: if no asset collected, skip fetching active assets
active_assets = _find_active_assets(self.assets.keys(), session=session) if self.assets else {}

for name_uri, asset in self.assets.items():
# If the asset belongs to a DAG that is inactive or paused, consider that no watcher is associated with it
asset_watchers = asset.watchers if name_uri in active_assets else []
trigger_repr_to_trigger_dict: dict[str, BaseTrigger] = {
repr(trigger): trigger for trigger in asset_watchers
trigger_hash_to_trigger_dict: dict[int, BaseTrigger] = {
self._get_base_trigger_hash(trigger): trigger for trigger in asset_watchers
}
triggers.update(trigger_repr_to_trigger_dict)
trigger_repr_from_asset: set[str] = set(trigger_repr_to_trigger_dict.keys())
triggers.update(trigger_hash_to_trigger_dict)
trigger_hash_from_asset: set[int] = set(trigger_hash_to_trigger_dict.keys())

asset_model = assets[name_uri]
trigger_repr_from_asset_model: set[str] = {
BaseTrigger.repr(trigger.classpath, trigger.kwargs) for trigger in asset_model.triggers
trigger_hash_from_asset_model: set[int] = {
self._get_trigger_hash(trigger.classpath, trigger.kwargs) for trigger in asset_model.triggers
}

# Optimization: no diff between the DB and DAG definitions, no update needed
if trigger_repr_from_asset == trigger_repr_from_asset_model:
if trigger_hash_from_asset == trigger_hash_from_asset_model:
continue

diff_to_add = trigger_repr_from_asset - trigger_repr_from_asset_model
diff_to_remove = trigger_repr_from_asset_model - trigger_repr_from_asset
diff_to_add = trigger_hash_from_asset - trigger_hash_from_asset_model
diff_to_remove = trigger_hash_from_asset_model - trigger_hash_from_asset
if diff_to_add:
refs_to_add[name_uri] = diff_to_add
if diff_to_remove:
refs_to_remove[name_uri] = diff_to_remove

if refs_to_add:
all_trigger_reprs: set[str] = {
trigger_repr for trigger_reprs in refs_to_add.values() for trigger_repr in trigger_reprs
all_trigger_hashes: set[int] = {
trigger_hash for trigger_hashes in refs_to_add.values() for trigger_hash in trigger_hashes
}

all_trigger_keys: set[tuple[str, str]] = {
self._encrypt_trigger_kwargs(triggers[trigger_repr])
for trigger_reprs in refs_to_add.values()
for trigger_repr in trigger_reprs
self._encrypt_trigger_kwargs(triggers[trigger_hash])
for trigger_hashes in refs_to_add.values()
for trigger_hash in trigger_hashes
}
orm_triggers: dict[str, Trigger] = {
BaseTrigger.repr(trigger.classpath, trigger.kwargs): trigger
orm_triggers: dict[int, Trigger] = {
self._get_trigger_hash(trigger.classpath, trigger.kwargs): trigger
for trigger in session.scalars(
select(Trigger).where(
tuple_(Trigger.classpath, Trigger.encrypted_kwargs).in_(all_trigger_keys)
Expand All @@ -507,32 +508,32 @@ def add_asset_trigger_references(
new_trigger_models = [
trigger
for trigger in [
Trigger.from_object(triggers[trigger_repr])
for trigger_repr in all_trigger_reprs
if trigger_repr not in orm_triggers
Trigger.from_object(triggers[trigger_hash])
for trigger_hash in all_trigger_hashes
if trigger_hash not in orm_triggers
]
]
session.add_all(new_trigger_models)
orm_triggers.update(
(BaseTrigger.repr(trigger.classpath, trigger.kwargs), trigger)
(self._get_trigger_hash(trigger.classpath, trigger.kwargs), trigger)
for trigger in new_trigger_models
)

# Add new references
for name_uri, trigger_reprs in refs_to_add.items():
for name_uri, trigger_hashes in refs_to_add.items():
asset_model = assets[name_uri]
asset_model.triggers.extend(
[orm_triggers.get(trigger_repr) for trigger_repr in trigger_reprs]
[orm_triggers.get(trigger_hash) for trigger_hash in trigger_hashes]
)

if refs_to_remove:
# Remove old references
for name_uri, trigger_reprs in refs_to_remove.items():
for name_uri, trigger_hashes in refs_to_remove.items():
asset_model = assets[name_uri]
asset_model.triggers = [
trigger
for trigger in asset_model.triggers
if BaseTrigger.repr(trigger.classpath, trigger.kwargs) not in trigger_reprs
if self._get_trigger_hash(trigger.classpath, trigger.kwargs) not in trigger_hashes
]

# Remove references from assets no longer used
Expand All @@ -547,3 +548,19 @@ def add_asset_trigger_references(
def _encrypt_trigger_kwargs(trigger: BaseTrigger) -> tuple[str, str]:
    """
    Build the ``(classpath, encrypted kwargs)`` pair for a trigger.

    The trigger is serialized first; its kwargs are then passed through
    ``Trigger.encrypt_kwargs`` while the classpath is returned unchanged.
    """
    trigger_classpath, trigger_kwargs = trigger.serialize()
    encrypted_kwargs = Trigger.encrypt_kwargs(trigger_kwargs)
    return trigger_classpath, encrypted_kwargs

@staticmethod
def _get_trigger_hash(classpath: str, kwargs: dict[str, Any]) -> int:
    """
    Return a hash of the trigger classpath and kwargs, used to uniquely identify a trigger.

    We do not want to move this logic into a ``__hash__`` method on ``BaseTrigger`` because we do not
    want to make triggers hashable: when the triggerer retrieves the list of triggers, we do not want
    it to dedupe them. When used to defer tasks, two triggers can have the same classpath and kwargs.
    This is not true for event-driven scheduling.

    :param classpath: Classpath of the trigger.
    :param kwargs: Keyword arguments of the trigger.
    :return: Hash identifying the ``(classpath, kwargs)`` pair. It is built with the built-in
        ``hash()``, so it is only meant for comparisons made within the same process.
    """
    import json

    # ``BaseSerialization.serialize`` returns a JSON-compatible structure (a dict for dict input),
    # which is not hashable: putting it directly in the hashed tuple raises ``TypeError``. Hash its
    # JSON text instead; ``sort_keys`` keeps the result independent of the kwargs insertion order.
    serialized_kwargs = json.dumps(BaseSerialization.serialize(kwargs), sort_keys=True)
    return hash((classpath, serialized_kwargs))

def _get_base_trigger_hash(self, trigger: BaseTrigger) -> int:
    """
    Return the identifying hash of a trigger object.

    The trigger is serialized into its classpath and kwargs, which are then handed to
    ``_get_trigger_hash``.
    """
    serialized_trigger = trigger.serialize()
    trigger_classpath, trigger_kwargs = serialized_trigger
    return self._get_trigger_hash(trigger_classpath, trigger_kwargs)

0 comments on commit 6711957

Please sign in to comment.