27
27
28
28
from __future__ import annotations
29
29
30
+ import json
30
31
import logging
31
32
import traceback
32
- from typing import TYPE_CHECKING , NamedTuple
33
+ from typing import TYPE_CHECKING , Any , NamedTuple
33
34
34
35
from sqlalchemy import and_ , delete , exists , func , select , tuple_
35
36
from sqlalchemy .exc import OperationalError
50
51
from airflow .models .errors import ParseImportError
51
52
from airflow .models .trigger import Trigger
52
53
from airflow .sdk .definitions .asset import Asset , AssetAlias
54
+ from airflow .serialization .serialized_objects import BaseSerialization
53
55
from airflow .triggers .base import BaseTrigger
54
56
from airflow .utils .retries import MAX_DB_RETRIES , run_with_db_retries
55
57
from airflow .utils .sqlalchemy import with_row_locks
64
66
65
67
from airflow .models .dagwarning import DagWarning
66
68
from airflow .serialization .serialized_objects import MaybeSerializedDAG
69
+ from airflow .triggers .base import BaseTrigger
67
70
from airflow .typing_compat import Self
68
71
69
72
log = logging .getLogger (__name__ )
@@ -652,50 +655,50 @@ def add_asset_trigger_references(
652
655
self , assets : dict [tuple [str , str ], AssetModel ], * , session : Session
653
656
) -> None :
654
657
# Update references from assets being used
655
- refs_to_add : dict [tuple [str , str ], set [str ]] = {}
656
- refs_to_remove : dict [tuple [str , str ], set [str ]] = {}
657
- triggers : dict [str , BaseTrigger ] = {}
658
+ refs_to_add : dict [tuple [str , str ], set [int ]] = {}
659
+ refs_to_remove : dict [tuple [str , str ], set [int ]] = {}
660
+ triggers : dict [int , BaseTrigger ] = {}
658
661
659
662
# Optimization: if no asset collected, skip fetching active assets
660
663
active_assets = _find_active_assets (self .assets .keys (), session = session ) if self .assets else {}
661
664
662
665
for name_uri , asset in self .assets .items ():
663
666
# If the asset belongs to a DAG that is not active or is paused, consider that no watcher is associated with it
664
667
asset_watchers = asset .watchers if name_uri in active_assets else []
665
- trigger_repr_to_trigger_dict : dict [str , BaseTrigger ] = {
666
- repr (trigger ): trigger for trigger in asset_watchers
668
+ trigger_hash_to_trigger_dict : dict [int , BaseTrigger ] = {
669
+ self . _get_base_trigger_hash (trigger ): trigger for trigger in asset_watchers
667
670
}
668
- triggers .update (trigger_repr_to_trigger_dict )
669
- trigger_repr_from_asset : set [str ] = set (trigger_repr_to_trigger_dict .keys ())
671
+ triggers .update (trigger_hash_to_trigger_dict )
672
+ trigger_hash_from_asset : set [int ] = set (trigger_hash_to_trigger_dict .keys ())
670
673
671
674
asset_model = assets [name_uri ]
672
- trigger_repr_from_asset_model : set [str ] = {
673
- BaseTrigger . repr (trigger .classpath , trigger .kwargs ) for trigger in asset_model .triggers
675
+ trigger_hash_from_asset_model : set [int ] = {
676
+ self . _get_trigger_hash (trigger .classpath , trigger .kwargs ) for trigger in asset_model .triggers
674
677
}
675
678
676
679
# Optimization: no diff between the DB and DAG definitions, no update needed
677
- if trigger_repr_from_asset == trigger_repr_from_asset_model :
680
+ if trigger_hash_from_asset == trigger_hash_from_asset_model :
678
681
continue
679
682
680
- diff_to_add = trigger_repr_from_asset - trigger_repr_from_asset_model
681
- diff_to_remove = trigger_repr_from_asset_model - trigger_repr_from_asset
683
+ diff_to_add = trigger_hash_from_asset - trigger_hash_from_asset_model
684
+ diff_to_remove = trigger_hash_from_asset_model - trigger_hash_from_asset
682
685
if diff_to_add :
683
686
refs_to_add [name_uri ] = diff_to_add
684
687
if diff_to_remove :
685
688
refs_to_remove [name_uri ] = diff_to_remove
686
689
687
690
if refs_to_add :
688
- all_trigger_reprs : set [str ] = {
689
- trigger_repr for trigger_reprs in refs_to_add .values () for trigger_repr in trigger_reprs
691
+ all_trigger_hashes : set [int ] = {
692
+ trigger_hash for trigger_hashes in refs_to_add .values () for trigger_hash in trigger_hashes
690
693
}
691
694
692
695
all_trigger_keys : set [tuple [str , str ]] = {
693
- self ._encrypt_trigger_kwargs (triggers [trigger_repr ])
694
- for trigger_reprs in refs_to_add .values ()
695
- for trigger_repr in trigger_reprs
696
+ self ._encrypt_trigger_kwargs (triggers [trigger_hash ])
697
+ for trigger_hashes in refs_to_add .values ()
698
+ for trigger_hash in trigger_hashes
696
699
}
697
- orm_triggers : dict [str , Trigger ] = {
698
- BaseTrigger . repr (trigger .classpath , trigger .kwargs ): trigger
700
+ orm_triggers : dict [int , Trigger ] = {
701
+ self . _get_trigger_hash (trigger .classpath , trigger .kwargs ): trigger
699
702
for trigger in session .scalars (
700
703
select (Trigger ).where (
701
704
tuple_ (Trigger .classpath , Trigger .encrypted_kwargs ).in_ (all_trigger_keys )
@@ -707,32 +710,32 @@ def add_asset_trigger_references(
707
710
new_trigger_models = [
708
711
trigger
709
712
for trigger in [
710
- Trigger .from_object (triggers [trigger_repr ])
711
- for trigger_repr in all_trigger_reprs
712
- if trigger_repr not in orm_triggers
713
+ Trigger .from_object (triggers [trigger_hash ])
714
+ for trigger_hash in all_trigger_hashes
715
+ if trigger_hash not in orm_triggers
713
716
]
714
717
]
715
718
session .add_all (new_trigger_models )
716
719
orm_triggers .update (
717
- (BaseTrigger . repr (trigger .classpath , trigger .kwargs ), trigger )
720
+ (self . _get_trigger_hash (trigger .classpath , trigger .kwargs ), trigger )
718
721
for trigger in new_trigger_models
719
722
)
720
723
721
724
# Add new references
722
- for name_uri , trigger_reprs in refs_to_add .items ():
725
+ for name_uri , trigger_hashes in refs_to_add .items ():
723
726
asset_model = assets [name_uri ]
724
727
asset_model .triggers .extend (
725
- [orm_triggers .get (trigger_repr ) for trigger_repr in trigger_reprs ]
728
+ [orm_triggers .get (trigger_hash ) for trigger_hash in trigger_hashes ]
726
729
)
727
730
728
731
if refs_to_remove :
729
732
# Remove old references
730
- for name_uri , trigger_reprs in refs_to_remove .items ():
733
+ for name_uri , trigger_hashes in refs_to_remove .items ():
731
734
asset_model = assets [name_uri ]
732
735
asset_model .triggers = [
733
736
trigger
734
737
for trigger in asset_model .triggers
735
- if BaseTrigger . repr (trigger .classpath , trigger .kwargs ) not in trigger_reprs
738
+ if self . _get_trigger_hash (trigger .classpath , trigger .kwargs ) not in trigger_hashes
736
739
]
737
740
738
741
# Remove references from assets no longer used
@@ -747,3 +750,19 @@ def add_asset_trigger_references(
747
750
def _encrypt_trigger_kwargs(trigger: BaseTrigger) -> tuple[str, str]:
    """
    Build the ``(classpath, encrypted kwargs)`` pair for a trigger.

    The trigger is serialized into its classpath and kwargs, and the kwargs are then
    passed through ``Trigger.encrypt_kwargs``.

    :param trigger: Trigger to serialize.
    :return: Tuple of the trigger classpath and its encrypted kwargs.
    """
    trigger_classpath, plain_kwargs = trigger.serialize()
    encrypted_kwargs = Trigger.encrypt_kwargs(plain_kwargs)
    return trigger_classpath, encrypted_kwargs
753
+
754
@staticmethod
def _get_trigger_hash(classpath: str, kwargs: dict[str, Any]) -> int:
    """
    Return the hash of the trigger classpath and kwargs. This is used to uniquely identify a trigger.

    The kwargs are run through ``BaseSerialization.serialize`` and dumped to JSON with sorted keys, so
    two kwargs dicts holding the same items in a different insertion order produce the same hash.

    The value comes from the built-in ``hash`` applied to a tuple of ``str`` and ``bytes``, which is
    salted per interpreter process. Only compare hashes computed within the same process, and never
    persist them.

    We do not want to move this logic into a `__hash__` method in `BaseTrigger` because we do not want to
    make the triggers hashable. The reason being, when the triggerer retrieves the list of triggers, we do
    not want it to dedupe them. When used to defer tasks, 2 triggers can have the same classpath and kwargs.
    This is not true for event driven scheduling.

    :param classpath: Classpath of the trigger.
    :param kwargs: Kwargs of the trigger.
    :return: Hash identifying the ``(classpath, kwargs)`` pair.
    """
    # sort_keys makes the JSON text independent of dict insertion order (applies to nested dicts too).
    serialized_kwargs = json.dumps(BaseSerialization.serialize(kwargs), sort_keys=True)
    return hash((classpath, serialized_kwargs.encode("utf-8")))
765
+
766
def _get_base_trigger_hash(self, trigger: BaseTrigger) -> int:
    """
    Return the hash of a trigger object.

    The trigger is serialized into its classpath and kwargs, which are then handed to
    ``_get_trigger_hash``.

    :param trigger: Trigger to hash.
    :return: Hash computed by ``_get_trigger_hash`` for the serialized trigger.
    """
    trigger_classpath, trigger_kwargs = trigger.serialize()
    return self._get_trigger_hash(trigger_classpath, trigger_kwargs)
0 commit comments