Accept task_key as an argument in DatabricksNotebookOperator #44960

Merged
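In practice this change lets a DAG author supply a stable databricks_task_key instead of relying on the generated default. A minimal usage sketch, assuming a databricks provider release that includes this change; the DAG id, notebook path, cluster id, and task key below are placeholders rather than values taken from the PR:

import pendulum
from airflow import DAG
from airflow.providers.databricks.operators.databricks import DatabricksNotebookOperator

with DAG(
    dag_id="example_databricks_task_key",  # placeholder DAG id
    start_date=pendulum.datetime(2025, 1, 1, tz="UTC"),
    schedule=None,
):
    DatabricksNotebookOperator(
        task_id="run_notebook",
        databricks_conn_id="databricks_default",
        notebook_path="/Shared/example_notebook",  # placeholder notebook path
        source="WORKSPACE",
        existing_cluster_id="1234-567890-abcde123",  # placeholder cluster id
        databricks_task_key="my_stable_task_key",  # new argument accepted by this PR
    )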
43 changes: 25 additions & 18 deletions providers/src/airflow/providers/databricks/operators/databricks.py
@@ -19,6 +19,7 @@

from __future__ import annotations

import hashlib
import time
from abc import ABC, abstractmethod
from collections.abc import Sequence
@@ -966,6 +967,8 @@ class DatabricksTaskBaseOperator(BaseOperator, ABC):

:param caller: The name of the caller operator to be used in the logs.
:param databricks_conn_id: The name of the Airflow connection to use.
:param databricks_task_key: An optional task_key used by the Databricks API to refer to the task. By
    default, this is set to the MD5 hash of ``dag_id + task_id``.
:param databricks_retry_args: An optional dictionary with arguments passed to ``tenacity.Retrying`` class.
:param databricks_retry_delay: Number of seconds to wait between retries.
:param databricks_retry_limit: Amount of times to retry if the Databricks backend is unreachable.
@@ -986,6 +989,7 @@ def __init__(
self,
caller: str = "DatabricksTaskBaseOperator",
databricks_conn_id: str = "databricks_default",
databricks_task_key: str = "",
databricks_retry_args: dict[Any, Any] | None = None,
databricks_retry_delay: int = 1,
databricks_retry_limit: int = 3,
Expand All @@ -1000,6 +1004,7 @@ def __init__(
):
self.caller = caller
self.databricks_conn_id = databricks_conn_id
self._databricks_task_key = databricks_task_key
self.databricks_retry_args = databricks_retry_args
self.databricks_retry_delay = databricks_retry_delay
self.databricks_retry_limit = databricks_retry_limit
@@ -1037,17 +1042,21 @@ def _get_hook(self, caller: str) -> DatabricksHook:
caller=caller,
)

    def _get_databricks_task_id(self, task_id: str) -> str:
        """Get the databricks task ID using dag_id and task_id. Removes illegal characters."""
        task_id = f"{self.dag_id}__{task_id.replace('.', '__')}"
        if len(task_id) > 100:
            self.log.warning(
                "The generated task_key '%s' exceeds 100 characters and will be truncated by the Databricks API. "
                "This will cause failure when trying to monitor the task. task_key is generated by "
                "concatenating dag_id and task_id.",
                task_id,
            )
        return task_id
    @cached_property
    def databricks_task_key(self) -> str:
        return self._generate_databricks_task_key()

    def _generate_databricks_task_key(self, task_id: str | None = None) -> str:
        """Create a databricks task key using the hash of dag_id and task_id."""
        if not self._databricks_task_key or len(self._databricks_task_key) > 100:
            self.log.info(
                "databricks_task_key has not been provided or the provided one exceeds 100 characters and "
                "will be truncated by the Databricks API. This will cause failure when trying to monitor the "
                "task. A task_key will be generated using the hash of dag_id + task_id."
            )
            task_id = task_id or self.task_id
            task_key = f"{self.dag_id}__{task_id}".encode()
            self._databricks_task_key = hashlib.md5(task_key).hexdigest()
            self.log.info("Generated databricks task_key: %s", self._databricks_task_key)
        return self._databricks_task_key

@property
def _databricks_workflow_task_group(self) -> DatabricksWorkflowTaskGroup | None:
@@ -1077,7 +1086,7 @@ def _get_task_base_json(self) -> dict[str, Any]:
def _get_run_json(self) -> dict[str, Any]:
"""Get run json to be used for task submissions."""
run_json = {
"run_name": self._get_databricks_task_id(self.task_id),
"run_name": self.databricks_task_key,
**self._get_task_base_json(),
}
if self.new_cluster and self.existing_cluster_id:
@@ -1127,19 +1136,17 @@ def _get_current_databricks_task(self) -> dict[str, Any]:
# building the {task_key: task} map below.
sorted_task_runs = sorted(tasks, key=lambda x: x["start_time"])

return {task["task_key"]: task for task in sorted_task_runs}[
self._get_databricks_task_id(self.task_id)
]
return {task["task_key"]: task for task in sorted_task_runs}[self.databricks_task_key]

def _convert_to_databricks_workflow_task(
self, relevant_upstreams: list[BaseOperator], context: Context | None = None
) -> dict[str, object]:
"""Convert the operator to a Databricks workflow task that can be a task in a workflow."""
base_task_json = self._get_task_base_json()
result = {
"task_key": self._get_databricks_task_id(self.task_id),
"task_key": self.databricks_task_key,
"depends_on": [
{"task_key": self._get_databricks_task_id(task_id)}
{"task_key": self._generate_databricks_task_key(task_id)}
for task_id in self.upstream_task_ids
if task_id in relevant_upstreams
],
@@ -1172,7 +1179,7 @@ def monitor_databricks_job(self) -> None:
run_state = RunState(**run["state"])
self.log.info(
"Current state of the the databricks task %s is %s",
self._get_databricks_task_id(self.task_id),
self.databricks_task_key,
run_state.life_cycle_state,
)
if self.deferrable and not run_state.is_terminal:
Expand All @@ -1194,7 +1201,7 @@ def monitor_databricks_job(self) -> None:
run_state = RunState(**run["state"])
self.log.info(
"Current state of the databricks task %s is %s",
self._get_databricks_task_id(self.task_id),
self.databricks_task_key,
run_state.life_cycle_state,
)
self._handle_terminal_run_state(run_state)
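For reference, the default key produced by the new _generate_databricks_task_key above can be reproduced outside the operator. A minimal sketch; the helper name is illustrative and not part of the diff:

import hashlib

def expected_default_task_key(dag_id: str, task_id: str) -> str:
    # Mirrors _generate_databricks_task_key: md5 hex digest of "dag_id__task_id".
    return hashlib.md5(f"{dag_id}__{task_id}".encode()).hexdigest()

# Matches the value asserted in test_convert_to_databricks_workflow_task below.
print(expected_default_task_key("example_dag", "test_task"))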
providers/src/airflow/providers/databricks/plugins/databricks_workflow.py
@@ -44,6 +44,8 @@
if TYPE_CHECKING:
from sqlalchemy.orm.session import Session

from airflow.providers.databricks.operators.databricks import DatabricksTaskBaseOperator


REPAIR_WAIT_ATTEMPTS = os.getenv("DATABRICKS_REPAIR_WAIT_ATTEMPTS", 20)
REPAIR_WAIT_DELAY = os.getenv("DATABRICKS_REPAIR_WAIT_DELAY", 0.5)
@@ -57,18 +59,8 @@ def get_auth_decorator():
return auth.has_access_dag("POST", DagAccessEntity.RUN)


def _get_databricks_task_id(task: BaseOperator) -> str:
"""
Get the databricks task ID using dag_id and task_id. removes illegal characters.

:param task: The task to get the databricks task ID for.
:return: The databricks task ID.
"""
return f"{task.dag_id}__{task.task_id.replace('.', '__')}"


def get_databricks_task_ids(
group_id: str, task_map: dict[str, BaseOperator], log: logging.Logger
group_id: str, task_map: dict[str, DatabricksTaskBaseOperator], log: logging.Logger
) -> list[str]:
"""
Return a list of all Databricks task IDs for a dictionary of Airflow tasks.
Expand All @@ -83,7 +75,7 @@ def get_databricks_task_ids(
for task_id, task in task_map.items():
if task_id == f"{group_id}.launch":
continue
databricks_task_id = _get_databricks_task_id(task)
databricks_task_id = task.databricks_task_key
log.debug("databricks task id for task %s is %s", task_id, databricks_task_id)
task_ids.append(databricks_task_id)
return task_ids
@@ -112,7 +104,7 @@ def _clear_task_instances(
dag = airflow_app.dag_bag.get_dag(dag_id)
log.debug("task_ids %s to clear", str(task_ids))
dr: DagRun = _get_dagrun(dag, run_id, session=session)
tis_to_clear = [ti for ti in dr.get_task_instances() if _get_databricks_task_id(ti) in task_ids]
tis_to_clear = [ti for ti in dr.get_task_instances() if ti.databricks_task_key in task_ids]
clear_task_instances(tis_to_clear, session)


@@ -327,7 +319,7 @@ def get_tasks_to_run(self, ti_key: TaskInstanceKey, operator: BaseOperator, log:

tasks_to_run = {ti: t for ti, t in task_group_sub_tasks if ti in failed_and_skipped_tasks}

return ",".join(get_databricks_task_ids(task_group.group_id, tasks_to_run, log))
return ",".join(get_databricks_task_ids(task_group.group_id, tasks_to_run, log)) # type: ignore[arg-type]

@staticmethod
def _get_failed_and_skipped_tasks(dr: DagRun) -> list[str]:
@@ -390,7 +382,7 @@ def get_link(
"databricks_conn_id": metadata.conn_id,
"databricks_run_id": metadata.run_id,
"run_id": ti_key.run_id,
"tasks_to_repair": _get_databricks_task_id(task),
"tasks_to_repair": task.databricks_task_key,
}
return url_for("RepairDatabricksTasks.repair", **query_params)

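On the plugin side, the net effect is that the repair helpers read each operator's databricks_task_key instead of recomputing the key from dag_id and task_id, which is why _get_databricks_task_id could be removed. A hedged sketch of that behaviour; the function name is illustrative, not from the diff:

from airflow.providers.databricks.operators.databricks import DatabricksTaskBaseOperator

def collect_task_keys(group_id: str, task_map: dict[str, DatabricksTaskBaseOperator]) -> list[str]:
    # Mirrors get_databricks_task_ids after this change: skip the workflow's launch
    # task and collect the key exposed by each Databricks task operator.
    return [
        task.databricks_task_key
        for task_id, task in task_map.items()
        if task_id != f"{group_id}.launch"
    ]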
28 changes: 27 additions & 1 deletion providers/tests/databricks/operators/test_databricks.py
@@ -17,6 +17,7 @@
# under the License.
from __future__ import annotations

import hashlib
from datetime import datetime, timedelta
from typing import Any
from unittest import mock
@@ -2216,8 +2217,9 @@ def test_convert_to_databricks_workflow_task(self):

task_json = operator._convert_to_databricks_workflow_task(relevant_upstreams)

task_key = hashlib.md5(b"example_dag__test_task").hexdigest()
expected_json = {
"task_key": "example_dag__test_task",
"task_key": task_key,
"depends_on": [],
"timeout_seconds": 0,
"email_notifications": {},
@@ -2317,3 +2319,27 @@ def test_get_task_base_json(self):

assert operator.task_config == task_config
assert task_base_json == task_config

def test_generate_databricks_task_key(self):
task_config = {}
operator = DatabricksTaskOperator(
task_id="test_task",
databricks_conn_id="test_conn_id",
task_config=task_config,
)

task_key = f"{operator.dag_id}__{operator.task_id}".encode()
expected_task_key = hashlib.md5(task_key).hexdigest()
assert expected_task_key == operator.databricks_task_key

def test_user_databricks_task_key(self):
task_config = {}
operator = DatabricksTaskOperator(
task_id="test_task",
databricks_conn_id="test_conn_id",
databricks_task_key="test_task_key",
task_config=task_config,
)
expected_task_key = "test_task_key"

assert expected_task_key == operator.databricks_task_key
22 changes: 4 additions & 18 deletions providers/tests/databricks/plugins/test_databricks_workflow.py
@@ -32,7 +32,6 @@
WorkflowJobRepairSingleTaskLink,
WorkflowJobRunLink,
_get_dagrun,
_get_databricks_task_id,
_get_launch_task_key,
_repair_task,
get_databricks_task_ids,
@@ -50,30 +49,17 @@
DATABRICKS_CONN_ID = "databricks_default"
DATABRICKS_RUN_ID = 12345
GROUP_ID = "test_group"
LOG = MagicMock()
TASK_MAP = {
"task1": MagicMock(dag_id=DAG_ID, task_id="task1"),
"task2": MagicMock(dag_id=DAG_ID, task_id="task2"),
"task1": MagicMock(dag_id=DAG_ID, task_id="task1", databricks_task_key="task_key1"),
"task2": MagicMock(dag_id=DAG_ID, task_id="task2", databricks_task_key="task_key2"),
}
LOG = MagicMock()


@pytest.mark.parametrize(
"task, expected_id",
[
(MagicMock(dag_id="dag1", task_id="task.1"), "dag1__task__1"),
(MagicMock(dag_id="dag2", task_id="task_1"), "dag2__task_1"),
],
)
def test_get_databricks_task_id(task, expected_id):
result = _get_databricks_task_id(task)

assert result == expected_id


def test_get_databricks_task_ids():
result = get_databricks_task_ids(GROUP_ID, TASK_MAP, LOG)

expected_ids = ["test_dag__task1", "test_dag__task2"]
expected_ids = ["task_key1", "task_key2"]
assert result == expected_ids

