
Commit 89c03e6

hardeybisey authored and Lefteris Gilmaz committed
Accept task_key as an argument in DatabricksNotebookOperator (apache#44960)
This PR introduces the ability for users to explicitly specify databricks_task_key as a parameter for the DatabricksNotebookOperator. If databricks_task_key is not provided, a default value is generated using the hash of the dag_id and task_id.

Key Changes:
- Users can now define databricks_task_key explicitly.
- When not provided, the key defaults to a deterministic hash based on dag_id and task_id.

Fixes: apache#41816
Fixes: apache#44250
Related: apache#43106
1 parent 24679b7 commit 89c03e6
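
For illustration only (not part of the commit): a minimal sketch of how the new parameter might be used in a DAG. The DAG id, notebook paths, and task ids below are assumed placeholders; only databricks_task_key and the operator import path come from this change.

# Illustrative usage sketch; DAG id, notebook paths, and task ids are assumptions.
from airflow import DAG
from airflow.providers.databricks.operators.databricks import DatabricksNotebookOperator

with DAG(dag_id="example_databricks_dag", schedule=None) as dag:
    # Explicit key: passed through to the Databricks API as the task_key, verbatim.
    explicit_key_task = DatabricksNotebookOperator(
        task_id="notebook_task",
        notebook_path="/Shared/example_notebook",
        source="WORKSPACE",
        databricks_conn_id="databricks_default",
        databricks_task_key="my_stable_task_key",
    )
    # No key provided: the operator falls back to md5 of "dag_id__task_id".
    default_key_task = DatabricksNotebookOperator(
        task_id="another_notebook_task",
        notebook_path="/Shared/another_notebook",
        source="WORKSPACE",
        databricks_conn_id="databricks_default",
    )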

4 files changed: +63 -52 lines changed

providers/src/airflow/providers/databricks/operators/databricks.py

+25 -18
@@ -19,6 +19,7 @@
 
 from __future__ import annotations
 
+import hashlib
 import time
 from abc import ABC, abstractmethod
 from collections.abc import Sequence
@@ -966,6 +967,8 @@ class DatabricksTaskBaseOperator(BaseOperator, ABC):
 
     :param caller: The name of the caller operator to be used in the logs.
     :param databricks_conn_id: The name of the Airflow connection to use.
+    :param databricks_task_key: An optional task_key used to refer to the task by Databricks API. By
+        default this will be set to the hash of ``dag_id + task_id``.
     :param databricks_retry_args: An optional dictionary with arguments passed to ``tenacity.Retrying`` class.
     :param databricks_retry_delay: Number of seconds to wait between retries.
     :param databricks_retry_limit: Amount of times to retry if the Databricks backend is unreachable.
@@ -986,6 +989,7 @@ def __init__(
         self,
         caller: str = "DatabricksTaskBaseOperator",
         databricks_conn_id: str = "databricks_default",
+        databricks_task_key: str = "",
         databricks_retry_args: dict[Any, Any] | None = None,
         databricks_retry_delay: int = 1,
         databricks_retry_limit: int = 3,
@@ -1000,6 +1004,7 @@ def __init__(
     ):
         self.caller = caller
         self.databricks_conn_id = databricks_conn_id
+        self._databricks_task_key = databricks_task_key
         self.databricks_retry_args = databricks_retry_args
         self.databricks_retry_delay = databricks_retry_delay
         self.databricks_retry_limit = databricks_retry_limit
@@ -1037,17 +1042,21 @@ def _get_hook(self, caller: str) -> DatabricksHook:
             caller=caller,
         )
 
-    def _get_databricks_task_id(self, task_id: str) -> str:
-        """Get the databricks task ID using dag_id and task_id. Removes illegal characters."""
-        task_id = f"{self.dag_id}__{task_id.replace('.', '__')}"
-        if len(task_id) > 100:
-            self.log.warning(
-                "The generated task_key '%s' exceeds 100 characters and will be truncated by the Databricks API. "
-                "This will cause failure when trying to monitor the task. task_key is generated by ",
-                "concatenating dag_id and task_id.",
-                task_id,
+    @cached_property
+    def databricks_task_key(self) -> str:
+        return self._generate_databricks_task_key()
+
+    def _generate_databricks_task_key(self, task_id: str | None = None) -> str:
+        """Create a databricks task key using the hash of dag_id and task_id."""
+        if not self._databricks_task_key or len(self._databricks_task_key) > 100:
+            self.log.info(
+                "databricks_task_key has not be provided or the provided one exceeds 100 characters and will be truncated by the Databricks API. This will cause failure when trying to monitor the task. A task_key will be generated using the hash value of dag_id+task_id"
             )
-        return task_id
+            task_id = task_id or self.task_id
+            task_key = f"{self.dag_id}__{task_id}".encode()
+            self._databricks_task_key = hashlib.md5(task_key).hexdigest()
+            self.log.info("Generated databricks task_key: %s", self._databricks_task_key)
+        return self._databricks_task_key
 
     @property
     def _databricks_workflow_task_group(self) -> DatabricksWorkflowTaskGroup | None:
@@ -1077,7 +1086,7 @@ def _get_task_base_json(self) -> dict[str, Any]:
     def _get_run_json(self) -> dict[str, Any]:
         """Get run json to be used for task submissions."""
         run_json = {
-            "run_name": self._get_databricks_task_id(self.task_id),
+            "run_name": self.databricks_task_key,
             **self._get_task_base_json(),
         }
         if self.new_cluster and self.existing_cluster_id:
@@ -1127,19 +1136,17 @@ def _get_current_databricks_task(self) -> dict[str, Any]:
         # building the {task_key: task} map below.
         sorted_task_runs = sorted(tasks, key=lambda x: x["start_time"])
 
-        return {task["task_key"]: task for task in sorted_task_runs}[
-            self._get_databricks_task_id(self.task_id)
-        ]
+        return {task["task_key"]: task for task in sorted_task_runs}[self.databricks_task_key]
 
     def _convert_to_databricks_workflow_task(
         self, relevant_upstreams: list[BaseOperator], context: Context | None = None
     ) -> dict[str, object]:
         """Convert the operator to a Databricks workflow task that can be a task in a workflow."""
         base_task_json = self._get_task_base_json()
         result = {
-            "task_key": self._get_databricks_task_id(self.task_id),
+            "task_key": self.databricks_task_key,
             "depends_on": [
-                {"task_key": self._get_databricks_task_id(task_id)}
+                {"task_key": self._generate_databricks_task_key(task_id)}
                 for task_id in self.upstream_task_ids
                 if task_id in relevant_upstreams
             ],
@@ -1172,7 +1179,7 @@ def monitor_databricks_job(self) -> None:
         run_state = RunState(**run["state"])
         self.log.info(
             "Current state of the the databricks task %s is %s",
-            self._get_databricks_task_id(self.task_id),
+            self.databricks_task_key,
             run_state.life_cycle_state,
         )
         if self.deferrable and not run_state.is_terminal:
@@ -1194,7 +1201,7 @@ def monitor_databricks_job(self) -> None:
             run_state = RunState(**run["state"])
             self.log.info(
                 "Current state of the databricks task %s is %s",
-                self._get_databricks_task_id(self.task_id),
+                self.databricks_task_key,
                 run_state.life_cycle_state,
             )
         self._handle_terminal_run_state(run_state)
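
As a quick reference (not part of the diff), a standalone sketch of the fallback key derivation implemented by _generate_databricks_task_key above; the helper name default_task_key is ours, not the provider's.

# Sketch of the fallback: with no (or an over-long) databricks_task_key, the operator
# hashes "dag_id__task_id" so the key stays under Databricks' 100-character limit.
import hashlib

def default_task_key(dag_id: str, task_id: str) -> str:
    return hashlib.md5(f"{dag_id}__{task_id}".encode()).hexdigest()

print(default_task_key("example_dag", "test_task"))  # 32-character hex digest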

providers/src/airflow/providers/databricks/plugins/databricks_workflow.py

+7 -15
@@ -44,6 +44,8 @@
 if TYPE_CHECKING:
     from sqlalchemy.orm.session import Session
 
+    from airflow.providers.databricks.operators.databricks import DatabricksTaskBaseOperator
+
 
 REPAIR_WAIT_ATTEMPTS = os.getenv("DATABRICKS_REPAIR_WAIT_ATTEMPTS", 20)
 REPAIR_WAIT_DELAY = os.getenv("DATABRICKS_REPAIR_WAIT_DELAY", 0.5)
@@ -57,18 +59,8 @@ def get_auth_decorator():
     return auth.has_access_dag("POST", DagAccessEntity.RUN)
 
 
-def _get_databricks_task_id(task: BaseOperator) -> str:
-    """
-    Get the databricks task ID using dag_id and task_id. removes illegal characters.
-
-    :param task: The task to get the databricks task ID for.
-    :return: The databricks task ID.
-    """
-    return f"{task.dag_id}__{task.task_id.replace('.', '__')}"
-
-
 def get_databricks_task_ids(
-    group_id: str, task_map: dict[str, BaseOperator], log: logging.Logger
+    group_id: str, task_map: dict[str, DatabricksTaskBaseOperator], log: logging.Logger
 ) -> list[str]:
     """
     Return a list of all Databricks task IDs for a dictionary of Airflow tasks.
@@ -83,7 +75,7 @@ def get_databricks_task_ids(
     for task_id, task in task_map.items():
         if task_id == f"{group_id}.launch":
             continue
-        databricks_task_id = _get_databricks_task_id(task)
+        databricks_task_id = task.databricks_task_key
         log.debug("databricks task id for task %s is %s", task_id, databricks_task_id)
         task_ids.append(databricks_task_id)
     return task_ids
@@ -112,7 +104,7 @@ def _clear_task_instances(
     dag = airflow_app.dag_bag.get_dag(dag_id)
     log.debug("task_ids %s to clear", str(task_ids))
     dr: DagRun = _get_dagrun(dag, run_id, session=session)
-    tis_to_clear = [ti for ti in dr.get_task_instances() if _get_databricks_task_id(ti) in task_ids]
+    tis_to_clear = [ti for ti in dr.get_task_instances() if ti.databricks_task_key in task_ids]
     clear_task_instances(tis_to_clear, session)
 
 
@@ -327,7 +319,7 @@ def get_tasks_to_run(self, ti_key: TaskInstanceKey, operator: BaseOperator, log:
 
         tasks_to_run = {ti: t for ti, t in task_group_sub_tasks if ti in failed_and_skipped_tasks}
 
-        return ",".join(get_databricks_task_ids(task_group.group_id, tasks_to_run, log))
+        return ",".join(get_databricks_task_ids(task_group.group_id, tasks_to_run, log))  # type: ignore[arg-type]
 
     @staticmethod
     def _get_failed_and_skipped_tasks(dr: DagRun) -> list[str]:
@@ -390,7 +382,7 @@ def get_link(
             "databricks_conn_id": metadata.conn_id,
             "databricks_run_id": metadata.run_id,
             "run_id": ti_key.run_id,
-            "tasks_to_repair": _get_databricks_task_id(task),
+            "tasks_to_repair": task.databricks_task_key,
         }
         return url_for("RepairDatabricksTasks.repair", **query_params)
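
For context (an illustrative sketch mirroring the test fixtures further down, not the plugin's real call site): the plugin helpers now read databricks_task_key straight off each operator instead of recomputing it from dag_id and task_id.

# Illustrative only; MagicMock stands in for Databricks operators in a task map.
from unittest.mock import MagicMock

task_map = {
    "group.launch": MagicMock(databricks_task_key="launch_key"),  # launch task is skipped
    "group.task1": MagicMock(databricks_task_key="key_1"),
}
# get_databricks_task_ids collects the keys of all non-launch tasks, as sketched here.
keys = [t.databricks_task_key for tid, t in task_map.items() if tid != "group.launch"]
print(keys)  # ['key_1']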

providers/tests/databricks/operators/test_databricks.py

+27 -1
@@ -17,6 +17,7 @@
 # under the License.
 from __future__ import annotations
 
+import hashlib
 from datetime import datetime, timedelta
 from typing import Any
 from unittest import mock
@@ -2216,8 +2217,9 @@ def test_convert_to_databricks_workflow_task(self):
 
         task_json = operator._convert_to_databricks_workflow_task(relevant_upstreams)
 
+        task_key = hashlib.md5(b"example_dag__test_task").hexdigest()
         expected_json = {
-            "task_key": "example_dag__test_task",
+            "task_key": task_key,
             "depends_on": [],
             "timeout_seconds": 0,
             "email_notifications": {},
@@ -2317,3 +2319,27 @@ def test_get_task_base_json(self):
 
         assert operator.task_config == task_config
         assert task_base_json == task_config
+
+    def test_generate_databricks_task_key(self):
+        task_config = {}
+        operator = DatabricksTaskOperator(
+            task_id="test_task",
+            databricks_conn_id="test_conn_id",
+            task_config=task_config,
+        )
+
+        task_key = f"{operator.dag_id}__{operator.task_id}".encode()
+        expected_task_key = hashlib.md5(task_key).hexdigest()
+        assert expected_task_key == operator.databricks_task_key
+
+    def test_user_databricks_task_key(self):
+        task_config = {}
+        operator = DatabricksTaskOperator(
+            task_id="test_task",
+            databricks_conn_id="test_conn_id",
+            databricks_task_key="test_task_key",
+            task_config=task_config,
+        )
+        expected_task_key = "test_task_key"
+
+        assert expected_task_key == operator.databricks_task_key

providers/tests/databricks/plugins/test_databricks_workflow.py

+4 -18
@@ -32,7 +32,6 @@
     WorkflowJobRepairSingleTaskLink,
     WorkflowJobRunLink,
     _get_dagrun,
-    _get_databricks_task_id,
     _get_launch_task_key,
     _repair_task,
     get_databricks_task_ids,
@@ -50,30 +49,17 @@
 DATABRICKS_CONN_ID = "databricks_default"
 DATABRICKS_RUN_ID = 12345
 GROUP_ID = "test_group"
+LOG = MagicMock()
 TASK_MAP = {
-    "task1": MagicMock(dag_id=DAG_ID, task_id="task1"),
-    "task2": MagicMock(dag_id=DAG_ID, task_id="task2"),
+    "task1": MagicMock(dag_id=DAG_ID, task_id="task1", databricks_task_key="task_key1"),
+    "task2": MagicMock(dag_id=DAG_ID, task_id="task2", databricks_task_key="task_key2"),
 }
-LOG = MagicMock()
-
-
-@pytest.mark.parametrize(
-    "task, expected_id",
-    [
-        (MagicMock(dag_id="dag1", task_id="task.1"), "dag1__task__1"),
-        (MagicMock(dag_id="dag2", task_id="task_1"), "dag2__task_1"),
-    ],
-)
-def test_get_databricks_task_id(task, expected_id):
-    result = _get_databricks_task_id(task)
-
-    assert result == expected_id
 
 
 def test_get_databricks_task_ids():
     result = get_databricks_task_ids(GROUP_ID, TASK_MAP, LOG)
 
-    expected_ids = ["test_dag__task1", "test_dag__task2"]
+    expected_ids = ["task_key1", "task_key2"]
     assert result == expected_ids
 
 