from __future__ import annotations

+ import hashlib
import time
from abc import ABC, abstractmethod
from collections.abc import Sequence
@@ -966,6 +967,8 @@ class DatabricksTaskBaseOperator(BaseOperator, ABC):
    :param caller: The name of the caller operator to be used in the logs.
    :param databricks_conn_id: The name of the Airflow connection to use.
+     :param databricks_task_key: An optional task_key used by the Databricks API to refer to the task.
+         By default this is set to the hash of ``dag_id + task_id``.
    :param databricks_retry_args: An optional dictionary with arguments passed to ``tenacity.Retrying`` class.
    :param databricks_retry_delay: Number of seconds to wait between retries.
    :param databricks_retry_limit: Amount of times to retry if the Databricks backend is unreachable.
@@ -986,6 +989,7 @@ def __init__(
        self,
        caller: str = "DatabricksTaskBaseOperator",
        databricks_conn_id: str = "databricks_default",
+         databricks_task_key: str = "",
        databricks_retry_args: dict[Any, Any] | None = None,
        databricks_retry_delay: int = 1,
        databricks_retry_limit: int = 3,
@@ -1000,6 +1004,7 @@ def __init__(
    ):
        self.caller = caller
        self.databricks_conn_id = databricks_conn_id
+         self._databricks_task_key = databricks_task_key
        self.databricks_retry_args = databricks_retry_args
        self.databricks_retry_delay = databricks_retry_delay
        self.databricks_retry_limit = databricks_retry_limit
@@ -1037,17 +1042,21 @@ def _get_hook(self, caller: str) -> DatabricksHook:
            caller=caller,
        )

-     def _get_databricks_task_id(self, task_id: str) -> str:
-         """Get the databricks task ID using dag_id and task_id. Removes illegal characters."""
-         task_id = f"{self.dag_id}__{task_id.replace('.', '__')}"
-         if len(task_id) > 100:
-             self.log.warning(
-                 "The generated task_key '%s' exceeds 100 characters and will be truncated by the Databricks API. "
-                 "This will cause failure when trying to monitor the task. task_key is generated by ",
-                 "concatenating dag_id and task_id.",
-                 task_id,
+     @cached_property
+     def databricks_task_key(self) -> str:
+         return self._generate_databricks_task_key()
+
+     def _generate_databricks_task_key(self, task_id: str | None = None) -> str:
+         """Create a databricks task key using the hash of dag_id and task_id."""
+         if not self._databricks_task_key or len(self._databricks_task_key) > 100:
+             self.log.info(
+                 "databricks_task_key has not been provided or the provided one exceeds 100 characters and will be truncated by the Databricks API. This will cause failure when trying to monitor the task. A task_key will be generated using the hash value of dag_id + task_id."
            )
-         return task_id
+             task_id = task_id or self.task_id
+             task_key = f"{self.dag_id}__{task_id}".encode()
+             self._databricks_task_key = hashlib.md5(task_key).hexdigest()
+             self.log.info("Generated databricks task_key: %s", self._databricks_task_key)
+         return self._databricks_task_key

    @property
    def _databricks_workflow_task_group(self) -> DatabricksWorkflowTaskGroup | None:
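For reference, a minimal sketch (outside the diff) of the fallback key derivation above: when no ``databricks_task_key`` is supplied, the key becomes the MD5 hex digest of ``dag_id__task_id``, which always fits the 100-character limit. The DAG and task names below are hypothetical.

```python
import hashlib

# Hypothetical dag_id/task_id, only to illustrate the fallback derivation.
dag_id, task_id = "example_dag", "example_task"

# Same scheme as _generate_databricks_task_key: hash "<dag_id>__<task_id>"
# and use the hex digest as the task_key sent to the Databricks API.
task_key = hashlib.md5(f"{dag_id}__{task_id}".encode()).hexdigest()

print(len(task_key), task_key)  # 32 <hex digest>, well under the 100-character limit
```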
@@ -1077,7 +1086,7 @@ def _get_task_base_json(self) -> dict[str, Any]:
    def _get_run_json(self) -> dict[str, Any]:
        """Get run json to be used for task submissions."""
        run_json = {
-             "run_name": self._get_databricks_task_id(self.task_id),
+             "run_name": self.databricks_task_key,
            **self._get_task_base_json(),
        }
        if self.new_cluster and self.existing_cluster_id:
@@ -1127,19 +1136,17 @@ def _get_current_databricks_task(self) -> dict[str, Any]:
        # building the {task_key: task} map below.
        sorted_task_runs = sorted(tasks, key=lambda x: x["start_time"])

-         return {task["task_key"]: task for task in sorted_task_runs}[
-             self._get_databricks_task_id(self.task_id)
-         ]
+         return {task["task_key"]: task for task in sorted_task_runs}[self.databricks_task_key]

    def _convert_to_databricks_workflow_task(
        self, relevant_upstreams: list[BaseOperator], context: Context | None = None
    ) -> dict[str, object]:
        """Convert the operator to a Databricks workflow task that can be a task in a workflow."""
        base_task_json = self._get_task_base_json()
        result = {
-             "task_key": self._get_databricks_task_id(self.task_id),
+             "task_key": self.databricks_task_key,
            "depends_on": [
-                 {"task_key": self._get_databricks_task_id(task_id)}
+                 {"task_key": self._generate_databricks_task_key(task_id)}
                for task_id in self.upstream_task_ids
                if task_id in relevant_upstreams
            ],
@@ -1172,7 +1179,7 @@ def monitor_databricks_job(self) -> None:
        run_state = RunState(**run["state"])
        self.log.info(
            "Current state of the the databricks task %s is %s",
-             self._get_databricks_task_id(self.task_id),
+             self.databricks_task_key,
            run_state.life_cycle_state,
        )
        if self.deferrable and not run_state.is_terminal:
@@ -1194,7 +1201,7 @@ def monitor_databricks_job(self) -> None:
            run_state = RunState(**run["state"])
            self.log.info(
                "Current state of the databricks task %s is %s",
-                 self._get_databricks_task_id(self.task_id),
+                 self.databricks_task_key,
                run_state.life_cycle_state,
            )
        self._handle_terminal_run_state(run_state)
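A hedged usage sketch of the new parameter: any operator built on DatabricksTaskBaseOperator (here DatabricksNotebookOperator, assuming it forwards keyword arguments to the base class; the connection, notebook path, and cluster id below are placeholders) can pass an explicit ``databricks_task_key`` instead of relying on the hashed default.

```python
from airflow.providers.databricks.operators.databricks import DatabricksNotebookOperator

# Hypothetical task definition; connection id, notebook path and cluster id are placeholders.
notebook_task = DatabricksNotebookOperator(
    task_id="run_notebook",
    databricks_conn_id="databricks_default",
    notebook_path="/Shared/example_notebook",
    source="WORKSPACE",
    existing_cluster_id="1234-567890-abcde123",
    # New in this change: an explicit task_key (at most 100 characters) used by the
    # Databricks API; if omitted, the MD5 hash of "<dag_id>__<task_id>" is used.
    databricks_task_key="run_notebook_key",
)
```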