Skip to content

Commit 22b15a2

Browse files
vincbeckLefteris Gilmaz
authored and
Lefteris Gilmaz
committed
AIP-82 Use hash instead of repr (apache#44797)
1 parent 4097c9c commit 22b15a2

File tree

1 file changed

+47
-28
lines changed

1 file changed

+47
-28
lines changed

airflow/dag_processing/collection.py

+47-28
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,10 @@
2727

2828
from __future__ import annotations
2929

30+
import json
3031
import logging
3132
import traceback
32-
from typing import TYPE_CHECKING, NamedTuple
33+
from typing import TYPE_CHECKING, Any, NamedTuple
3334

3435
from sqlalchemy import and_, delete, exists, func, select, tuple_
3536
from sqlalchemy.exc import OperationalError
@@ -50,6 +51,7 @@
5051
from airflow.models.errors import ParseImportError
5152
from airflow.models.trigger import Trigger
5253
from airflow.sdk.definitions.asset import Asset, AssetAlias
54+
from airflow.serialization.serialized_objects import BaseSerialization
5355
from airflow.triggers.base import BaseTrigger
5456
from airflow.utils.retries import MAX_DB_RETRIES, run_with_db_retries
5557
from airflow.utils.sqlalchemy import with_row_locks
@@ -64,6 +66,7 @@
6466

6567
from airflow.models.dagwarning import DagWarning
6668
from airflow.serialization.serialized_objects import MaybeSerializedDAG
69+
from airflow.triggers.base import BaseTrigger
6770
from airflow.typing_compat import Self
6871

6972
log = logging.getLogger(__name__)
@@ -652,50 +655,50 @@ def add_asset_trigger_references(
652655
self, assets: dict[tuple[str, str], AssetModel], *, session: Session
653656
) -> None:
654657
# Update references from assets being used
655-
refs_to_add: dict[tuple[str, str], set[str]] = {}
656-
refs_to_remove: dict[tuple[str, str], set[str]] = {}
657-
triggers: dict[str, BaseTrigger] = {}
658+
refs_to_add: dict[tuple[str, str], set[int]] = {}
659+
refs_to_remove: dict[tuple[str, str], set[int]] = {}
660+
triggers: dict[int, BaseTrigger] = {}
658661

659662
# Optimization: if no asset collected, skip fetching active assets
660663
active_assets = _find_active_assets(self.assets.keys(), session=session) if self.assets else {}
661664

662665
for name_uri, asset in self.assets.items():
663666
# If the asset belong to a DAG not active or paused, consider there is no watcher associated to it
664667
asset_watchers = asset.watchers if name_uri in active_assets else []
665-
trigger_repr_to_trigger_dict: dict[str, BaseTrigger] = {
666-
repr(trigger): trigger for trigger in asset_watchers
668+
trigger_hash_to_trigger_dict: dict[int, BaseTrigger] = {
669+
self._get_base_trigger_hash(trigger): trigger for trigger in asset_watchers
667670
}
668-
triggers.update(trigger_repr_to_trigger_dict)
669-
trigger_repr_from_asset: set[str] = set(trigger_repr_to_trigger_dict.keys())
671+
triggers.update(trigger_hash_to_trigger_dict)
672+
trigger_hash_from_asset: set[int] = set(trigger_hash_to_trigger_dict.keys())
670673

671674
asset_model = assets[name_uri]
672-
trigger_repr_from_asset_model: set[str] = {
673-
BaseTrigger.repr(trigger.classpath, trigger.kwargs) for trigger in asset_model.triggers
675+
trigger_hash_from_asset_model: set[int] = {
676+
self._get_trigger_hash(trigger.classpath, trigger.kwargs) for trigger in asset_model.triggers
674677
}
675678

676679
# Optimization: no diff between the DB and DAG definitions, no update needed
677-
if trigger_repr_from_asset == trigger_repr_from_asset_model:
680+
if trigger_hash_from_asset == trigger_hash_from_asset_model:
678681
continue
679682

680-
diff_to_add = trigger_repr_from_asset - trigger_repr_from_asset_model
681-
diff_to_remove = trigger_repr_from_asset_model - trigger_repr_from_asset
683+
diff_to_add = trigger_hash_from_asset - trigger_hash_from_asset_model
684+
diff_to_remove = trigger_hash_from_asset_model - trigger_hash_from_asset
682685
if diff_to_add:
683686
refs_to_add[name_uri] = diff_to_add
684687
if diff_to_remove:
685688
refs_to_remove[name_uri] = diff_to_remove
686689

687690
if refs_to_add:
688-
all_trigger_reprs: set[str] = {
689-
trigger_repr for trigger_reprs in refs_to_add.values() for trigger_repr in trigger_reprs
691+
all_trigger_hashes: set[int] = {
692+
trigger_hash for trigger_hashes in refs_to_add.values() for trigger_hash in trigger_hashes
690693
}
691694

692695
all_trigger_keys: set[tuple[str, str]] = {
693-
self._encrypt_trigger_kwargs(triggers[trigger_repr])
694-
for trigger_reprs in refs_to_add.values()
695-
for trigger_repr in trigger_reprs
696+
self._encrypt_trigger_kwargs(triggers[trigger_hash])
697+
for trigger_hashes in refs_to_add.values()
698+
for trigger_hash in trigger_hashes
696699
}
697-
orm_triggers: dict[str, Trigger] = {
698-
BaseTrigger.repr(trigger.classpath, trigger.kwargs): trigger
700+
orm_triggers: dict[int, Trigger] = {
701+
self._get_trigger_hash(trigger.classpath, trigger.kwargs): trigger
699702
for trigger in session.scalars(
700703
select(Trigger).where(
701704
tuple_(Trigger.classpath, Trigger.encrypted_kwargs).in_(all_trigger_keys)
@@ -707,32 +710,32 @@ def add_asset_trigger_references(
707710
new_trigger_models = [
708711
trigger
709712
for trigger in [
710-
Trigger.from_object(triggers[trigger_repr])
711-
for trigger_repr in all_trigger_reprs
712-
if trigger_repr not in orm_triggers
713+
Trigger.from_object(triggers[trigger_hash])
714+
for trigger_hash in all_trigger_hashes
715+
if trigger_hash not in orm_triggers
713716
]
714717
]
715718
session.add_all(new_trigger_models)
716719
orm_triggers.update(
717-
(BaseTrigger.repr(trigger.classpath, trigger.kwargs), trigger)
720+
(self._get_trigger_hash(trigger.classpath, trigger.kwargs), trigger)
718721
for trigger in new_trigger_models
719722
)
720723

721724
# Add new references
722-
for name_uri, trigger_reprs in refs_to_add.items():
725+
for name_uri, trigger_hashes in refs_to_add.items():
723726
asset_model = assets[name_uri]
724727
asset_model.triggers.extend(
725-
[orm_triggers.get(trigger_repr) for trigger_repr in trigger_reprs]
728+
[orm_triggers.get(trigger_hash) for trigger_hash in trigger_hashes]
726729
)
727730

728731
if refs_to_remove:
729732
# Remove old references
730-
for name_uri, trigger_reprs in refs_to_remove.items():
733+
for name_uri, trigger_hashes in refs_to_remove.items():
731734
asset_model = assets[name_uri]
732735
asset_model.triggers = [
733736
trigger
734737
for trigger in asset_model.triggers
735-
if BaseTrigger.repr(trigger.classpath, trigger.kwargs) not in trigger_reprs
738+
if self._get_trigger_hash(trigger.classpath, trigger.kwargs) not in trigger_hashes
736739
]
737740

738741
# Remove references from assets no longer used
@@ -747,3 +750,19 @@ def add_asset_trigger_references(
747750
def _encrypt_trigger_kwargs(trigger: BaseTrigger) -> tuple[str, str]:
748751
classpath, kwargs = trigger.serialize()
749752
return classpath, Trigger.encrypt_kwargs(kwargs)
753+
754+
@staticmethod
755+
def _get_trigger_hash(classpath: str, kwargs: dict[str, Any]) -> int:
756+
"""
757+
Return the hash of the trigger classpath and kwargs. This is used to uniquely identify a trigger.
758+
759+
We do not want to move this logic in a `__hash__` method in `BaseTrigger` because we do not want to
760+
make the triggers hashable. The reason being, when the triggerer retrieve the list of triggers, we do
761+
not want it dedupe them. When used to defer tasks, 2 triggers can have the same classpath and kwargs.
762+
This is not true for event driven scheduling.
763+
"""
764+
return hash((classpath, json.dumps(BaseSerialization.serialize(kwargs)).encode("utf-8")))
765+
766+
def _get_base_trigger_hash(self, trigger: BaseTrigger) -> int:
767+
classpath, kwargs = trigger.serialize()
768+
return self._get_trigger_hash(classpath, kwargs)

0 commit comments

Comments
 (0)