diff --git a/providers/apache/livy/src/airflow/providers/apache/livy/operators/livy.py b/providers/apache/livy/src/airflow/providers/apache/livy/operators/livy.py index b42de329d332b..b979f6c7ffbb1 100644 --- a/providers/apache/livy/src/airflow/providers/apache/livy/operators/livy.py +++ b/providers/apache/livy/src/airflow/providers/apache/livy/operators/livy.py @@ -29,11 +29,36 @@ ) from airflow.providers.common.compat.sdk import AirflowException, BaseOperator, conf +# ResumableJobMixin ships in airflow.sdk, which only exists on Airflow 3, while this provider +# still targets apache-airflow>=2.11. Guard the import and fall back to a stub on Airflow 2; +# drop the fallback once the provider's minimum Airflow version is >=3.0. +try: + from airflow.sdk import ResumableJobMixin +except ImportError: + + class ResumableJobMixin: # type: ignore[no-redef] + """Airflow 2 fallback: no task_state_store, so the operator always submits a fresh batch.""" + + external_id_key: str = "remote_job_id" + + def __init__(self, *, durable: bool = True, **kwargs: Any) -> None: + # Swallow ``durable`` so it doesn't reach BaseOperator; crash recovery is a no-op here. + super().__init__(**kwargs) + self.durable = durable + + def execute_resumable(self, context): + external_id = self.submit_job(context) + self.poll_until_complete(external_id, context) + return self.get_job_result(external_id, context) + + if TYPE_CHECKING: + from pydantic import JsonValue + from airflow.providers.common.compat.sdk import Context -class LivyOperator(BaseOperator): +class LivyOperator(ResumableJobMixin, BaseOperator): """ Wraps the Apache Livy batch REST API, allowing to submit a Spark application to the underlying cluster. @@ -62,10 +87,15 @@ class LivyOperator(BaseOperator): :param retry_args: Arguments which define the retry behaviour. See Tenacity documentation at https://github.com/jd/tenacity :param deferrable: Run operator in the deferrable mode + :param durable: When True (the default) and the operator waits synchronously + (``deferrable=False`` with ``polling_interval > 0``), the Livy batch id is persisted before + polling so a worker crash reconnects to the running batch on retry instead of submitting a + duplicate. Requires Airflow 3.3+ (task_state_store); a no-op on earlier versions. """ template_fields: Sequence[str] = ("spark_params",) template_fields_renderers = {"spark_params": "json"} + external_id_key = "livy_batch_id" def __init__( self, @@ -167,14 +197,17 @@ def execute(self, context: Context) -> Any: cast("dict", self.spark_params["conf"]), context ) + if not self.deferrable and self._polling_interval > 0: + # Synchronous wait: route through the resumable mixin so a worker crash mid-poll + # reconnects to the running batch on retry instead of resubmitting a duplicate. + return self.execute_resumable(context) + _batch_id: int | str = self.hook.post_batch(**self.spark_params) self._batch_id = _batch_id self.log.info("Generated batch-id is %s", self._batch_id) - # Wait for the job to complete + # No polling requested: submit and return without waiting (nothing to reconnect to). if not self.deferrable: - if self._polling_interval > 0: - self.poll_for_termination(self._batch_id) context["ti"].xcom_push(key="app_id", value=self.hook.get_batch(self._batch_id)["appId"]) return self._batch_id @@ -203,6 +236,31 @@ def execute(self, context: Context) -> Any: context["ti"].xcom_push(key="app_id", value=self.hook.get_batch(self._batch_id)["appId"]) return self._batch_id + def submit_job(self, context: Context) -> JsonValue: + batch_id: int | str = self.hook.post_batch(**self.spark_params) + self._batch_id = batch_id + self.log.info("Generated batch-id is %s", batch_id) + return batch_id + + def get_job_status(self, external_id: JsonValue, context: Context) -> str: + return self.hook.get_batch_state(cast("int | str", external_id), retry_args=self.retry_args).value + + def is_job_active(self, status: str) -> bool: + return BatchState(status) not in self.hook.TERMINAL_STATES + + def is_job_succeeded(self, status: str) -> bool: + return BatchState(status) == BatchState.SUCCESS + + def poll_until_complete(self, external_id: JsonValue, context: Context) -> None: + # Set _batch_id so on_kill() can delete the batch after a reconnect (submit_job was skipped). + self._batch_id = cast("int | str", external_id) + self.poll_for_termination(self._batch_id) + + def get_job_result(self, external_id: JsonValue, context: Context) -> Any: + batch_id = cast("int | str", external_id) + context["ti"].xcom_push(key="app_id", value=self.hook.get_batch(batch_id)["appId"]) + return batch_id + def poll_for_termination(self, batch_id: int | str) -> None: """ Pool Livy for batch termination. diff --git a/providers/apache/livy/tests/unit/apache/livy/operators/test_livy.py b/providers/apache/livy/tests/unit/apache/livy/operators/test_livy.py index 3851ad9ebb762..7c0419c450ba3 100644 --- a/providers/apache/livy/tests/unit/apache/livy/operators/test_livy.py +++ b/providers/apache/livy/tests/unit/apache/livy/operators/test_livy.py @@ -17,17 +17,20 @@ from __future__ import annotations import logging +from typing import Any from unittest.mock import MagicMock, patch import pytest from airflow.models import Connection from airflow.models.dag import DAG -from airflow.providers.apache.livy.hooks.livy import BatchState +from airflow.providers.apache.livy.hooks.livy import BatchState, LivyHook from airflow.providers.apache.livy.operators.livy import LivyOperator from airflow.providers.common.compat.sdk import AirflowException from airflow.utils import timezone +from tests_common.test_utils.version_compat import AIRFLOW_V_3_3_PLUS + DEFAULT_DATE = timezone.datetime(2017, 1, 1) BATCH_ID = 100 APP_ID = "application_1433865536131_34483" @@ -607,3 +610,125 @@ def test_spark_params_templating(create_task_instance_of_operator, session): "py_files": "literal-py-files", "queue": "literal-queue", } + + +class FakeTaskStateStore: + """In-memory task state store for tests.""" + + def __init__(self, stored: dict[str, Any] | None = None): + self._store: dict[str, Any] = dict(stored or {}) + + def get(self, key: str) -> Any: + return self._store.get(key) + + def set(self, key: str, value: Any) -> None: + self._store[key] = value + + +@pytest.mark.skipif( + not AIRFLOW_V_3_3_PLUS, + reason="ResumableJobMixin reconnect requires task_state_store, available in Airflow 3.3+", +) +class TestLivyOperatorResumable: + """Crash-safe synchronous wait (deferrable=False, polling_interval>0) via ResumableJobMixin.""" + + def _make_operator(self, **kwargs) -> LivyOperator: + return LivyOperator(task_id="livy_resumable", file="sparkapp.jar", polling_interval=1, **kwargs) + + def _make_hook(self, batch_id: int = BATCH_ID) -> MagicMock: + hook = MagicMock() + hook.post_batch.return_value = batch_id + hook.get_batch.return_value = GET_BATCH + hook.TERMINAL_STATES = LivyHook.TERMINAL_STATES + return hook + + def test_first_run_persists_batch_id_before_polling(self): + operator = self._make_operator() + operator.hook = self._make_hook(batch_id=42) + task_store = FakeTaskStateStore() + persisted_before_poll = [] + operator.poll_until_complete = lambda external_id, context: persisted_before_poll.append( + task_store.get("livy_batch_id") + ) + + operator.execute(context={"task_state_store": task_store, "ti": MagicMock()}) + + operator.hook.post_batch.assert_called_once() + assert persisted_before_poll == [42] + + @pytest.mark.parametrize( + ("prior_status", "expect_submit", "expect_poll_id"), + [ + ("running", False, 1), # active -> reconnect to the existing batch + ("starting", False, 1), + ("success", False, None), # already succeeded -> return, no poll, no resubmit + ("dead", True, BATCH_ID), # terminal failure -> resubmit fresh + ("killed", True, BATCH_ID), + ("error", True, BATCH_ID), + ], + ) + def test_retry_behaviour_based_on_prior_batch_status(self, prior_status, expect_submit, expect_poll_id): + operator = self._make_operator() + operator.hook = self._make_hook() + task_store = FakeTaskStateStore({"livy_batch_id": 1}) + operator.get_job_status = lambda external_id, context: prior_status + polled = [] + operator.poll_until_complete = lambda external_id, context: polled.append(external_id) + + operator.execute(context={"task_state_store": task_store, "ti": MagicMock()}) + + if expect_submit: + operator.hook.post_batch.assert_called_once() + else: + operator.hook.post_batch.assert_not_called() + assert polled == ([] if expect_poll_id is None else [expect_poll_id]) + + def test_submits_fresh_when_task_state_store_unavailable(self): + operator = self._make_operator() + operator.hook = self._make_hook(batch_id=7) + polled = [] + operator.poll_until_complete = lambda external_id, context: polled.append(external_id) + + operator.execute(context={"ti": MagicMock()}) + + operator.hook.post_batch.assert_called_once() + assert polled == [7] + + def test_durable_false_submits_fresh_and_polls(self): + operator = self._make_operator(durable=False) + operator.hook = self._make_hook(batch_id=7) + task_store = FakeTaskStateStore({"livy_batch_id": 1}) + polled = [] + operator.poll_until_complete = lambda external_id, context: polled.append(external_id) + + operator.execute(context={"task_state_store": task_store, "ti": MagicMock()}) + + operator.hook.post_batch.assert_called_once() + assert polled == [7] + + def test_status_helpers_classify_real_batch_states(self): + operator = self._make_operator() + operator.hook = self._make_hook() + assert operator.is_job_active("running") is True + assert operator.is_job_active("starting") is True + assert operator.is_job_active("success") is False + assert operator.is_job_succeeded("success") is True + assert operator.is_job_succeeded("dead") is False + + def test_get_job_status_reads_batch_state_value(self): + operator = self._make_operator() + hook = self._make_hook() + hook.get_batch_state.return_value = BatchState.RUNNING + operator.hook = hook + + assert operator.get_job_status(BATCH_ID, {}) == "running" + + def test_poll_until_complete_sets_batch_id_for_on_kill(self): + operator = self._make_operator() + hook = self._make_hook() + hook.get_batch_state.return_value = BatchState.SUCCESS + operator.hook = hook + + operator.poll_until_complete(55, {}) + + assert operator._batch_id == 55