Skip to content

Commit

Permalink
feat(airflow): add on_finish_action to decide what to do when the job finishes (#100)
Browse files Browse the repository at this point in the history
  • Loading branch information
hussein-awala authored Nov 26, 2024
1 parent 8e65993 commit 0148e34
Showing 1 changed file with 43 additions and 15 deletions.
58 changes: 43 additions & 15 deletions spark_on_k8s/airflow/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from spark_on_k8s.airflow.operator_links import SparkOnK8SOperatorLink
from spark_on_k8s.airflow.triggers import SparkOnK8STrigger
from spark_on_k8s.k8s.sync_client import KubernetesClientManager
from spark_on_k8s.utils.app_manager import SparkAppManager
from spark_on_k8s.utils.spark_app_status import SparkAppStatus

if TYPE_CHECKING:
from typing import Literal
Expand All @@ -18,7 +20,6 @@

from airflow.utils.context import Context
from spark_on_k8s.client import ExecutorInstances, PodResources
from spark_on_k8s.utils.app_manager import SparkAppManager
from spark_on_k8s.utils.types import ConfigMap


Expand All @@ -37,6 +38,12 @@ def create_client(self):
return k8s_hook.get_conn()


class OnFinishAction(str, Enum):
    """Action the operator applies to the Spark application once the job finishes.

    Members also subclass ``str``, so each one compares equal to its plain
    string value (e.g. ``OnFinishAction.KEEP == "keep"``).
    """

    # Leave the Spark application in place, whatever its final status.
    KEEP = "keep"
    # Always delete the Spark application when the job finishes.
    DELETE = "delete"
    # Keep the application only when its final status is Failed; delete it otherwise.
    KEEP_IF_FAILED = "keep_if_failed"


class OnKillAction(str, Enum):
KEEP = "keep"
DELETE = "delete"
Expand Down Expand Up @@ -82,8 +89,10 @@ class SparkOnK8SOperator(BaseOperator):
poll_interval (int, optional): Poll interval for checking the Spark application status.
Defaults to 10.
deferrable (bool, optional): Whether the operator is deferrable. Defaults to False.
on_kill_action (Literal["keep", "delete", "kill"], optional): Action to take when the
operator is killed. Defaults to "delete".
on_finish_action (OnFinishAction, optional): Action to take when the operator finishes.
Defaults to "keep".
on_kill_action (OnKillAction, optional): Action to take when the operator is killed.
Defaults to "delete".
startup_timeout (int, optional): Timeout for the Spark application to start.
Defaults to 0 (no timeout).
**kwargs: Other keyword arguments for BaseOperator.
Expand Down Expand Up @@ -152,7 +161,9 @@ def __init__(
kubernetes_conn_id: str = "kubernetes_default",
poll_interval: int = 10,
deferrable: bool = False,
on_kill_action: Literal["keep", "delete", "kill"] = OnKillAction.DELETE,
# TODO: Change to OnFinishAction.KEEP_IF_FAILED in the stable version
on_finish_action: OnFinishAction = OnFinishAction.KEEP,
on_kill_action: OnKillAction = OnKillAction.DELETE,
startup_timeout: int = 0,
**kwargs,
):
Expand Down Expand Up @@ -190,6 +201,7 @@ def __init__(
self.kubernetes_conn_id = kubernetes_conn_id
self.poll_interval = poll_interval
self.deferrable = deferrable
self.on_finish_action = on_finish_action
self.on_kill_action = on_kill_action
self.startup_timeout = startup_timeout

Expand Down Expand Up @@ -244,8 +256,6 @@ def _persist_spark_history_ui_link(self, context: Context):
)

def _try_to_adopt_job(self, context: Context, spark_app_manager: SparkAppManager) -> bool:
from spark_on_k8s.utils.spark_app_status import SparkAppStatus

xcom_driver_namespace = context["ti"].xcom_pull(
dag_id=context["ti"].dag_id,
task_ids=context["ti"].task_id,
Expand Down Expand Up @@ -345,6 +355,24 @@ def _submit_new_job(self, context: Context):
**submit_app_kwargs,
)

def _clean_up(self, spark_app_manager: SparkAppManager, app_status: SparkAppStatus):
    """Apply ``on_finish_action`` to the driver pod once the job has finished.

    Args:
        spark_app_manager: Manager used to delete the Spark application.
        app_status: Final status reported for the Spark application.
    """
    action = self.on_finish_action
    keep_on_failure = action == OnFinishAction.KEEP_IF_FAILED

    # Unconditional keep: nothing to clean up.
    if action == OnFinishAction.KEEP:
        return
    # Conditional keep: a failed application is left in place.
    if keep_on_failure and app_status == SparkAppStatus.Failed:
        return

    should_delete = action == OnFinishAction.DELETE or (
        keep_on_failure and app_status != SparkAppStatus.Failed
    )
    if not should_delete:
        # Reached only when on_finish_action matches none of the known actions.
        self.log.error(
            f"Something went wrong: on_finish_action={self.on_finish_action}, app_status={app_status}"
        )
        return

    self.log.info("Deleting Spark application...")
    spark_app_manager.delete_app(
        namespace=self.namespace,
        pod_name=self._driver_pod_name,
    )

def execute(self, context: Context):
from spark_on_k8s.utils.app_manager import SparkAppManager

Expand Down Expand Up @@ -402,27 +430,27 @@ def execute(self, context: Context):
pod_name=self._driver_pod_name,
)
self._persist_spark_history_ui_link(context)
if app_status == "Succeeded":
self._clean_up(spark_app_manager, app_status)
if app_status == SparkAppStatus.Succeeded:
return app_status
raise AirflowException(f"The job finished with status: {app_status}")

def execute_complete(self, context: Context, event: dict, **kwargs):
self.namespace = event["namespace"]
self._driver_pod_name = event["pod_name"]
k8s_client_manager = _AirflowKubernetesClientManager(
kubernetes_conn_id=self.kubernetes_conn_id,
)
spark_app_manager = SparkAppManager(
k8s_client_manager=k8s_client_manager,
)
if self.app_waiter == "log":
from spark_on_k8s.utils.app_manager import SparkAppManager

k8s_client_manager = _AirflowKubernetesClientManager(
kubernetes_conn_id=self.kubernetes_conn_id,
)
spark_app_manager = SparkAppManager(
k8s_client_manager=k8s_client_manager,
)
spark_app_manager.stream_logs(
namespace=event["namespace"],
pod_name=event["pod_name"],
)
self._persist_spark_history_ui_link(context)
self._clean_up(spark_app_manager, event["status"])
if event["status"] == "Succeeded":
return event["status"]
if event["status"] == "error":
Expand Down

0 comments on commit 0148e34

Please sign in to comment.