
Commit

Accept task_key as an argument in DatabricksTaskBaseOperator
hardeybisey committed Dec 16, 2024
1 parent dbff6e3 commit 55100c3
Showing 4 changed files with 61 additions and 52 deletions.
41 changes: 23 additions & 18 deletions providers/src/airflow/providers/databricks/operators/databricks.py
@@ -19,6 +19,7 @@
 
 from __future__ import annotations
 
+import hashlib
 import time
 from abc import ABC, abstractmethod
 from collections.abc import Sequence
@@ -986,6 +987,7 @@ def __init__(
         self,
         caller: str = "DatabricksTaskBaseOperator",
         databricks_conn_id: str = "databricks_default",
+        databricks_task_key: str | None = None,
         databricks_retry_args: dict[Any, Any] | None = None,
         databricks_retry_delay: int = 1,
         databricks_retry_limit: int = 3,
@@ -1000,6 +1002,7 @@
     ):
         self.caller = caller
         self.databricks_conn_id = databricks_conn_id
+        self._databricks_task_key = databricks_task_key
         self.databricks_retry_args = databricks_retry_args
         self.databricks_retry_delay = databricks_retry_delay
         self.databricks_retry_limit = databricks_retry_limit
@@ -1037,17 +1040,21 @@ def _get_hook(self, caller: str) -> DatabricksHook:
             caller=caller,
         )
 
-    def _get_databricks_task_id(self, task_id: str) -> str:
-        """Get the databricks task ID using dag_id and task_id. Removes illegal characters."""
-        task_id = f"{self.dag_id}__{task_id.replace('.', '__')}"
-        if len(task_id) > 100:
-            self.log.warning(
-                "The generated task_key '%s' exceeds 100 characters and will be truncated by the Databricks API. "
-                "This will cause failure when trying to monitor the task. task_key is generated by ",
-                "concatenating dag_id and task_id.",
-                task_id,
+    @cached_property
+    def databricks_task_key(self) -> str:
+        return self._generate_databricks_task_key()
+
+    def _generate_databricks_task_key(self, task_id: str | None = None) -> str:
+        """Create a databricks task key using the hash of dag_id and task_id."""
+        if self._databricks_task_key is None:
+            self.log.info(
+                "No databricks_task_key provided. Generating task key using the hash value of dag_id+task_id."
             )
-        return task_id
+            task_id = task_id or self.task_id
+            task_key = f"{self.dag_id}__{task_id}".encode()
+            self._databricks_task_key = hashlib.md5(task_key).hexdigest()
+            self.log.info("Generated databricks task key: %s", task_key)
+        return self._databricks_task_key
 
     @property
     def _databricks_workflow_task_group(self) -> DatabricksWorkflowTaskGroup | None:
@@ -1077,7 +1084,7 @@ def _get_task_base_json(self) -> dict[str, Any]:
     def _get_run_json(self) -> dict[str, Any]:
         """Get run json to be used for task submissions."""
         run_json = {
-            "run_name": self._get_databricks_task_id(self.task_id),
+            "run_name": self.databricks_task_key,
             **self._get_task_base_json(),
         }
         if self.new_cluster and self.existing_cluster_id:
@@ -1127,19 +1134,17 @@ def _get_current_databricks_task(self) -> dict[str, Any]:
         # building the {task_key: task} map below.
         sorted_task_runs = sorted(tasks, key=lambda x: x["start_time"])
 
-        return {task["task_key"]: task for task in sorted_task_runs}[
-            self._get_databricks_task_id(self.task_id)
-        ]
+        return {task["task_key"]: task for task in sorted_task_runs}[self.databricks_task_key]
 
     def _convert_to_databricks_workflow_task(
         self, relevant_upstreams: list[BaseOperator], context: Context | None = None
     ) -> dict[str, object]:
         """Convert the operator to a Databricks workflow task that can be a task in a workflow."""
         base_task_json = self._get_task_base_json()
         result = {
-            "task_key": self._get_databricks_task_id(self.task_id),
+            "task_key": self.databricks_task_key,
             "depends_on": [
-                {"task_key": self._get_databricks_task_id(task_id)}
+                {"task_key": self._generate_databricks_task_key(task_id)}
                 for task_id in self.upstream_task_ids
                 if task_id in relevant_upstreams
             ],
@@ -1172,7 +1177,7 @@ def monitor_databricks_job(self) -> None:
         run_state = RunState(**run["state"])
         self.log.info(
             "Current state of the the databricks task %s is %s",
-            self._get_databricks_task_id(self.task_id),
+            self.databricks_task_key,
             run_state.life_cycle_state,
         )
         if self.deferrable and not run_state.is_terminal:
@@ -1194,7 +1199,7 @@ def monitor_databricks_job(self) -> None:
             run_state = RunState(**run["state"])
             self.log.info(
                 "Current state of the databricks task %s is %s",
-                self._get_databricks_task_id(self.task_id),
+                self.databricks_task_key,
                 run_state.life_cycle_state,
             )
         self._handle_terminal_run_state(run_state)
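
For context, a minimal usage sketch of the operator change above. This is not part of the commit: the DAG id, connection id, and notebook path are placeholders, and the derived-key expression simply mirrors the _generate_databricks_task_key logic shown in the diff.

import hashlib
from datetime import datetime

from airflow import DAG
from airflow.providers.databricks.operators.databricks import DatabricksTaskOperator

with DAG(dag_id="example_dag", start_date=datetime(2024, 1, 1), schedule=None):
    # Explicit key: passed through unchanged and used as the Databricks task_key / run_name.
    explicit = DatabricksTaskOperator(
        task_id="notebook_task",
        databricks_conn_id="databricks_default",
        databricks_task_key="my_stable_task_key",  # new argument introduced by this commit
        task_config={"notebook_task": {"notebook_path": "/Shared/example"}},  # placeholder config
    )

    # No key supplied: the operator falls back to md5(f"{dag_id}__{task_id}").
    derived = DatabricksTaskOperator(
        task_id="derived_task",
        databricks_conn_id="databricks_default",
        task_config={"notebook_task": {"notebook_path": "/Shared/example"}},  # placeholder config
    )

# The key generated for the second task would be:
print(hashlib.md5(b"example_dag__derived_task").hexdigest())
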
22 changes: 7 additions & 15 deletions providers/src/airflow/providers/databricks/plugins/databricks_workflow.py
@@ -44,6 +44,8 @@
 if TYPE_CHECKING:
     from sqlalchemy.orm.session import Session
 
+    from airflow.providers.databricks.operators.databricks import DatabricksTaskBaseOperator
+
 
 REPAIR_WAIT_ATTEMPTS = os.getenv("DATABRICKS_REPAIR_WAIT_ATTEMPTS", 20)
 REPAIR_WAIT_DELAY = os.getenv("DATABRICKS_REPAIR_WAIT_DELAY", 0.5)
@@ -57,18 +59,8 @@ def get_auth_decorator():
     return auth.has_access_dag("POST", DagAccessEntity.RUN)
 
 
-def _get_databricks_task_id(task: BaseOperator) -> str:
-    """
-    Get the databricks task ID using dag_id and task_id. removes illegal characters.
-
-    :param task: The task to get the databricks task ID for.
-    :return: The databricks task ID.
-    """
-    return f"{task.dag_id}__{task.task_id.replace('.', '__')}"
-
-
 def get_databricks_task_ids(
-    group_id: str, task_map: dict[str, BaseOperator], log: logging.Logger
+    group_id: str, task_map: dict[str, DatabricksTaskBaseOperator], log: logging.Logger
 ) -> list[str]:
     """
     Return a list of all Databricks task IDs for a dictionary of Airflow tasks.
@@ -83,7 +75,7 @@ def get_databricks_task_ids(
     for task_id, task in task_map.items():
         if task_id == f"{group_id}.launch":
             continue
-        databricks_task_id = _get_databricks_task_id(task)
+        databricks_task_id = task.databricks_task_key
         log.debug("databricks task id for task %s is %s", task_id, databricks_task_id)
         task_ids.append(databricks_task_id)
     return task_ids
@@ -112,7 +104,7 @@ def _clear_task_instances(
     dag = airflow_app.dag_bag.get_dag(dag_id)
     log.debug("task_ids %s to clear", str(task_ids))
     dr: DagRun = _get_dagrun(dag, run_id, session=session)
-    tis_to_clear = [ti for ti in dr.get_task_instances() if _get_databricks_task_id(ti) in task_ids]
+    tis_to_clear = [ti for ti in dr.get_task_instances() if ti.databricks_task_key in task_ids]
     clear_task_instances(tis_to_clear, session)
 
 
@@ -327,7 +319,7 @@ def get_tasks_to_run(self, ti_key: TaskInstanceKey, operator: BaseOperator, log:
 
         tasks_to_run = {ti: t for ti, t in task_group_sub_tasks if ti in failed_and_skipped_tasks}
 
-        return ",".join(get_databricks_task_ids(task_group.group_id, tasks_to_run, log))
+        return ",".join(get_databricks_task_ids(task_group.group_id, tasks_to_run, log))  # type: ignore[arg-type]
 
     @staticmethod
     def _get_failed_and_skipped_tasks(dr: DagRun) -> list[str]:
@@ -390,7 +382,7 @@ def get_link(
             "databricks_conn_id": metadata.conn_id,
             "databricks_run_id": metadata.run_id,
             "run_id": ti_key.run_id,
-            "tasks_to_repair": _get_databricks_task_id(task),
+            "tasks_to_repair": task.databricks_task_key,
         }
         return url_for("RepairDatabricksTasks.repair", **query_params)
 
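
On the plugin side, the repair and clear helpers no longer re-derive the key from dag_id and task_id; they read databricks_task_key straight off each operator. A rough illustration, mirroring the updated plugin test later in this commit (the dag and group ids are placeholders):

from unittest.mock import MagicMock

from airflow.providers.databricks.plugins.databricks_workflow import get_databricks_task_ids

# Each task object now carries its own databricks_task_key; the helper just collects them.
task_map = {
    "task1": MagicMock(dag_id="test_dag", task_id="task1", databricks_task_key="task1_key"),
    "task2": MagicMock(dag_id="test_dag", task_id="task2", databricks_task_key="task2_key"),
}

assert get_databricks_task_ids("test_group", task_map, MagicMock()) == ["task1_key", "task2_key"]
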
28 changes: 27 additions & 1 deletion providers/tests/databricks/operators/test_databricks.py
@@ -17,6 +17,7 @@
 # under the License.
 from __future__ import annotations
 
+import hashlib
 from datetime import datetime, timedelta
 from typing import Any
 from unittest import mock
@@ -2216,8 +2217,9 @@ def test_convert_to_databricks_workflow_task(self):
 
         task_json = operator._convert_to_databricks_workflow_task(relevant_upstreams)
 
+        task_key = hashlib.md5(b"example_dag__test_task").hexdigest()
         expected_json = {
-            "task_key": "example_dag__test_task",
+            "task_key": task_key,
             "depends_on": [],
             "timeout_seconds": 0,
             "email_notifications": {},
@@ -2317,3 +2319,27 @@ def test_get_task_base_json(self):
 
         assert operator.task_config == task_config
         assert task_base_json == task_config
+
+    def test_generate_databricks_task_key(self):
+        task_config = {}
+        operator = DatabricksTaskOperator(
+            task_id="test_task",
+            databricks_conn_id="test_conn_id",
+            task_config=task_config,
+        )
+
+        task_key = f"{operator.dag_id}__{operator.task_id}".encode()
+        expected_task_key = hashlib.md5(task_key).hexdigest()
+        assert expected_task_key == operator.databricks_task_key
+
+    def test_user_databricks_task_key(self):
+        task_config = {}
+        operator = DatabricksTaskOperator(
+            task_id="test_task",
+            databricks_conn_id="test_conn_id",
+            databricks_task_key="test_task_key",
+            task_config=task_config,
+        )
+        expected_task_key = "test_task_key"
+
+        assert expected_task_key == operator.databricks_task_key
22 changes: 4 additions & 18 deletions providers/tests/databricks/plugins/test_databricks_workflow.py
@@ -32,7 +32,6 @@
     WorkflowJobRepairSingleTaskLink,
     WorkflowJobRunLink,
     _get_dagrun,
-    _get_databricks_task_id,
     _get_launch_task_key,
     _repair_task,
     get_databricks_task_ids,
@@ -50,30 +49,17 @@
 DATABRICKS_CONN_ID = "databricks_default"
 DATABRICKS_RUN_ID = 12345
 GROUP_ID = "test_group"
+LOG = MagicMock()
 TASK_MAP = {
-    "task1": MagicMock(dag_id=DAG_ID, task_id="task1"),
-    "task2": MagicMock(dag_id=DAG_ID, task_id="task2"),
+    "task1": MagicMock(dag_id=DAG_ID, task_id="task1", databricks_task_key="task1_key"),
+    "task2": MagicMock(dag_id=DAG_ID, task_id="task2", databricks_task_key="task2_key"),
 }
-LOG = MagicMock()
 
 
-@pytest.mark.parametrize(
-    "task, expected_id",
-    [
-        (MagicMock(dag_id="dag1", task_id="task.1"), "dag1__task__1"),
-        (MagicMock(dag_id="dag2", task_id="task_1"), "dag2__task_1"),
-    ],
-)
-def test_get_databricks_task_id(task, expected_id):
-    result = _get_databricks_task_id(task)
-
-    assert result == expected_id
-
-
 def test_get_databricks_task_ids():
     result = get_databricks_task_ids(GROUP_ID, TASK_MAP, LOG)
 
-    expected_ids = ["test_dag__task1", "test_dag__task2"]
+    expected_ids = ["task1_key", "task2_key"]
     assert result == expected_ids
 
 
