From 0148e34640518758cf1402baa641b6580fb66d75 Mon Sep 17 00:00:00 2001
From: Hussein Awala
Date: Wed, 27 Nov 2024 00:49:04 +0100
Subject: [PATCH] feat(airflow): add on_finish_action to decide what to do when the job finishes (#100)

---
 spark_on_k8s/airflow/operators.py | 58 +++++++++++++++++++++++--------
 1 file changed, 43 insertions(+), 15 deletions(-)

diff --git a/spark_on_k8s/airflow/operators.py b/spark_on_k8s/airflow/operators.py
index cbaef13..839fe8e 100644
--- a/spark_on_k8s/airflow/operators.py
+++ b/spark_on_k8s/airflow/operators.py
@@ -9,6 +9,8 @@
 from spark_on_k8s.airflow.operator_links import SparkOnK8SOperatorLink
 from spark_on_k8s.airflow.triggers import SparkOnK8STrigger
 from spark_on_k8s.k8s.sync_client import KubernetesClientManager
+from spark_on_k8s.utils.app_manager import SparkAppManager
+from spark_on_k8s.utils.spark_app_status import SparkAppStatus
 
 if TYPE_CHECKING:
     from typing import Literal
@@ -18,7 +20,6 @@
     from airflow.utils.context import Context
 
     from spark_on_k8s.client import ExecutorInstances, PodResources
-    from spark_on_k8s.utils.app_manager import SparkAppManager
     from spark_on_k8s.utils.types import ConfigMap
 
 
@@ -37,6 +38,12 @@ def create_client(self):
         return k8s_hook.get_conn()
 
 
+class OnFinishAction(str, Enum):
+    KEEP = "keep"
+    DELETE = "delete"
+    KEEP_IF_FAILED = "keep_if_failed"
+
+
 class OnKillAction(str, Enum):
     KEEP = "keep"
     DELETE = "delete"
@@ -82,8 +89,10 @@ class SparkOnK8SOperator(BaseOperator):
         poll_interval (int, optional): Poll interval for checking the Spark application status.
             Defaults to 10.
         deferrable (bool, optional): Whether the operator is deferrable. Defaults to False.
-        on_kill_action (Literal["keep", "delete", "kill"], optional): Action to take when the
-            operator is killed. Defaults to "delete".
+        on_finish_action (OnFinishAction, optional): Action to take when the operator finishes.
+            Defaults to "keep".
+        on_kill_action (OnKillAction, optional): Action to take when the operator is killed.
+            Defaults to "delete".
         startup_timeout (int, optional): Timeout for the Spark application to start.
             Defaults to 0 (no timeout).
         **kwargs: Other keyword arguments for BaseOperator.
@@ -152,7 +161,9 @@ def __init__(
         kubernetes_conn_id: str = "kubernetes_default",
         poll_interval: int = 10,
         deferrable: bool = False,
-        on_kill_action: Literal["keep", "delete", "kill"] = OnKillAction.DELETE,
+        # TODO: Change to OnFinishAction.KEEP_IF_FAILED in the stable version
+        on_finish_action: OnFinishAction = OnFinishAction.KEEP,
+        on_kill_action: OnKillAction = OnKillAction.DELETE,
         startup_timeout: int = 0,
         **kwargs,
     ):
@@ -190,6 +201,7 @@ def __init__(
         self.kubernetes_conn_id = kubernetes_conn_id
         self.poll_interval = poll_interval
         self.deferrable = deferrable
+        self.on_finish_action = on_finish_action
         self.on_kill_action = on_kill_action
         self.startup_timeout = startup_timeout
 
@@ -244,8 +256,6 @@ def _persist_spark_history_ui_link(self, context: Context):
         )
 
     def _try_to_adopt_job(self, context: Context, spark_app_manager: SparkAppManager) -> bool:
-        from spark_on_k8s.utils.spark_app_status import SparkAppStatus
-
         xcom_driver_namespace = context["ti"].xcom_pull(
             dag_id=context["ti"].dag_id,
             task_ids=context["ti"].task_id,
@@ -345,6 +355,24 @@ def _submit_new_job(self, context: Context):
             **submit_app_kwargs,
         )
 
+    def _clean_up(self, spark_app_manager: SparkAppManager, app_status: SparkAppStatus):
+        if self.on_finish_action == OnFinishAction.KEEP or (
+            self.on_finish_action == OnFinishAction.KEEP_IF_FAILED and app_status == SparkAppStatus.Failed
+        ):
+            return
+        if self.on_finish_action == OnFinishAction.DELETE or (
+            self.on_finish_action == OnFinishAction.KEEP_IF_FAILED and app_status != SparkAppStatus.Failed
+        ):
+            self.log.info("Deleting Spark application...")
+            spark_app_manager.delete_app(
+                namespace=self.namespace,
+                pod_name=self._driver_pod_name,
+            )
+        else:
+            self.log.error(
+                f"Something went wrong: on_finish_action={self.on_finish_action}, app_status={app_status}"
+            )
+
     def execute(self, context: Context):
         from spark_on_k8s.utils.app_manager import SparkAppManager
 
@@ -402,27 +430,27 @@ def execute(self, context: Context):
                 pod_name=self._driver_pod_name,
             )
         self._persist_spark_history_ui_link(context)
-        if app_status == "Succeeded":
+        self._clean_up(spark_app_manager, app_status)
+        if app_status == SparkAppStatus.Succeeded:
             return app_status
         raise AirflowException(f"The job finished with status: {app_status}")
 
     def execute_complete(self, context: Context, event: dict, **kwargs):
         self.namespace = event["namespace"]
         self._driver_pod_name = event["pod_name"]
+        k8s_client_manager = _AirflowKubernetesClientManager(
+            kubernetes_conn_id=self.kubernetes_conn_id,
+        )
+        spark_app_manager = SparkAppManager(
+            k8s_client_manager=k8s_client_manager,
+        )
         if self.app_waiter == "log":
-            from spark_on_k8s.utils.app_manager import SparkAppManager
-
-            k8s_client_manager = _AirflowKubernetesClientManager(
-                kubernetes_conn_id=self.kubernetes_conn_id,
-            )
-            spark_app_manager = SparkAppManager(
-                k8s_client_manager=k8s_client_manager,
-            )
             spark_app_manager.stream_logs(
                 namespace=event["namespace"],
                 pod_name=event["pod_name"],
            )
         self._persist_spark_history_ui_link(context)
+        self._clean_up(spark_app_manager, event["status"])
         if event["status"] == "Succeeded":
            return event["status"]
         if event["status"] == "error":
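
Usage note for reviewers: below is a minimal sketch of a DAG that opts into the new parameter. `SparkOnK8SOperator`, `OnFinishAction`, `on_finish_action`, and `kubernetes_conn_id` come from this patch; the DAG id, namespace, and the `image`/`app_path` arguments and their values are hypothetical placeholders, assumed from typical spark-on-k8s usage rather than shown in this diff.

    # Illustrative sketch only, not part of the patch.
    from __future__ import annotations

    from datetime import datetime

    from airflow import DAG
    from spark_on_k8s.airflow.operators import OnFinishAction, SparkOnK8SOperator

    with DAG(
        dag_id="spark_pi_example",  # hypothetical DAG id
        start_date=datetime(2024, 1, 1),
        schedule=None,
        catchup=False,
    ):
        SparkOnK8SOperator(
            task_id="spark_pi",
            namespace="spark",  # hypothetical namespace
            image="my-registry/spark-app:latest",  # assumed constructor argument
            app_path="local:///opt/spark/app/job.py",  # assumed constructor argument
            kubernetes_conn_id="kubernetes_default",
            # Keep the driver pod only when the application fails, so its logs
            # stay available for debugging; successful runs are cleaned up.
            on_finish_action=OnFinishAction.KEEP_IF_FAILED,
        )

Since OnFinishAction subclasses str and _clean_up compares with ==, passing the raw string "keep_if_failed" should behave identically to passing the enum member.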
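For completeness, the _clean_up decision matrix added by this patch can be exercised directly. The sketch below is not part of the patch: it assumes the operator can be instantiated with only the arguments shown (your version may require different constructor arguments), and it substitutes a unittest.mock.MagicMock for the real SparkAppManager so the delete_app call can be observed.

    # Hedged test sketch, not part of the patch: probes _clean_up for each
    # (on_finish_action, app_status) pair and records whether delete_app ran.
    from unittest.mock import MagicMock

    from spark_on_k8s.airflow.operators import OnFinishAction, SparkOnK8SOperator
    from spark_on_k8s.utils.spark_app_status import SparkAppStatus

    def deletes(action: OnFinishAction, status: SparkAppStatus) -> bool:
        op = SparkOnK8SOperator(
            task_id="probe",
            image="my-registry/spark-app:latest",  # assumed constructor argument
            app_path="local:///opt/spark/app/job.py",  # assumed constructor argument
            on_finish_action=action,
        )
        # _clean_up reads these attributes, which execute() normally sets.
        op.namespace = "spark"
        op._driver_pod_name = "driver-pod"
        manager = MagicMock()  # stands in for a real SparkAppManager
        op._clean_up(manager, status)
        return manager.delete_app.called

    assert not deletes(OnFinishAction.KEEP, SparkAppStatus.Succeeded)        # keep: never delete
    assert deletes(OnFinishAction.DELETE, SparkAppStatus.Failed)             # delete: always delete
    assert not deletes(OnFinishAction.KEEP_IF_FAILED, SparkAppStatus.Failed) # keep failed apps
    assert deletes(OnFinishAction.KEEP_IF_FAILED, SparkAppStatus.Succeeded)  # delete the rest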