Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,36 @@
)
from airflow.providers.common.compat.sdk import AirflowException, BaseOperator, conf

# ResumableJobMixin ships in airflow.sdk, which only exists on Airflow 3, while this provider
# still targets apache-airflow>=2.11. Guard the import and fall back to a stub on Airflow 2;
# drop the fallback once the provider's minimum Airflow version is >=3.0.
try:
from airflow.sdk import ResumableJobMixin
except ImportError:

class ResumableJobMixin: # type: ignore[no-redef]
"""Airflow 2 fallback: no task_state_store, so the operator always submits a fresh batch."""

external_id_key: str = "remote_job_id"

def __init__(self, *, durable: bool = True, **kwargs: Any) -> None:
# Swallow ``durable`` so it doesn't reach BaseOperator; crash recovery is a no-op here.
super().__init__(**kwargs)
self.durable = durable

def execute_resumable(self, context):
external_id = self.submit_job(context)
self.poll_until_complete(external_id, context)
return self.get_job_result(external_id, context)


if TYPE_CHECKING:
from pydantic import JsonValue

from airflow.providers.common.compat.sdk import Context


class LivyOperator(BaseOperator):
class LivyOperator(ResumableJobMixin, BaseOperator):
"""
Wraps the Apache Livy batch REST API, allowing to submit a Spark application to the underlying cluster.

Expand Down Expand Up @@ -62,10 +87,15 @@ class LivyOperator(BaseOperator):
:param retry_args: Arguments which define the retry behaviour.
See Tenacity documentation at https://github.com/jd/tenacity
:param deferrable: Run operator in the deferrable mode
:param durable: When True (the default) and the operator waits synchronously
(``deferrable=False`` with ``polling_interval > 0``), the Livy batch id is persisted before
polling so a worker crash reconnects to the running batch on retry instead of submitting a
duplicate. Requires Airflow 3.3+ (task_state_store); a no-op on earlier versions.
"""

template_fields: Sequence[str] = ("spark_params",)
template_fields_renderers = {"spark_params": "json"}
external_id_key = "livy_batch_id"

def __init__(
self,
Expand Down Expand Up @@ -167,14 +197,17 @@ def execute(self, context: Context) -> Any:
cast("dict", self.spark_params["conf"]), context
)

if not self.deferrable and self._polling_interval > 0:
# Synchronous wait: route through the resumable mixin so a worker crash mid-poll
# reconnects to the running batch on retry instead of resubmitting a duplicate.
return self.execute_resumable(context)

_batch_id: int | str = self.hook.post_batch(**self.spark_params)
self._batch_id = _batch_id
self.log.info("Generated batch-id is %s", self._batch_id)

# Wait for the job to complete
# No polling requested: submit and return without waiting (nothing to reconnect to).
if not self.deferrable:
if self._polling_interval > 0:
self.poll_for_termination(self._batch_id)
context["ti"].xcom_push(key="app_id", value=self.hook.get_batch(self._batch_id)["appId"])
return self._batch_id

Expand Down Expand Up @@ -203,6 +236,31 @@ def execute(self, context: Context) -> Any:
context["ti"].xcom_push(key="app_id", value=self.hook.get_batch(self._batch_id)["appId"])
return self._batch_id

def submit_job(self, context: Context) -> JsonValue:
batch_id: int | str = self.hook.post_batch(**self.spark_params)
self._batch_id = batch_id
self.log.info("Generated batch-id is %s", batch_id)
return batch_id

def get_job_status(self, external_id: JsonValue, context: Context) -> str:
return self.hook.get_batch_state(cast("int | str", external_id), retry_args=self.retry_args).value

def is_job_active(self, status: str) -> bool:
return BatchState(status) not in self.hook.TERMINAL_STATES

def is_job_succeeded(self, status: str) -> bool:
return BatchState(status) == BatchState.SUCCESS

def poll_until_complete(self, external_id: JsonValue, context: Context) -> None:
# Set _batch_id so on_kill() can delete the batch after a reconnect (submit_job was skipped).
self._batch_id = cast("int | str", external_id)
self.poll_for_termination(self._batch_id)

def get_job_result(self, external_id: JsonValue, context: Context) -> Any:
batch_id = cast("int | str", external_id)
context["ti"].xcom_push(key="app_id", value=self.hook.get_batch(batch_id)["appId"])
return batch_id

def poll_for_termination(self, batch_id: int | str) -> None:
"""
Pool Livy for batch termination.
Expand Down
127 changes: 126 additions & 1 deletion providers/apache/livy/tests/unit/apache/livy/operators/test_livy.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,20 @@
from __future__ import annotations

import logging
from typing import Any
from unittest.mock import MagicMock, patch

import pytest

from airflow.models import Connection
from airflow.models.dag import DAG
from airflow.providers.apache.livy.hooks.livy import BatchState
from airflow.providers.apache.livy.hooks.livy import BatchState, LivyHook
from airflow.providers.apache.livy.operators.livy import LivyOperator
from airflow.providers.common.compat.sdk import AirflowException
from airflow.utils import timezone

from tests_common.test_utils.version_compat import AIRFLOW_V_3_3_PLUS

DEFAULT_DATE = timezone.datetime(2017, 1, 1)
BATCH_ID = 100
APP_ID = "application_1433865536131_34483"
Expand Down Expand Up @@ -607,3 +610,125 @@ def test_spark_params_templating(create_task_instance_of_operator, session):
"py_files": "literal-py-files",
"queue": "literal-queue",
}


class FakeTaskStateStore:
"""In-memory task state store for tests."""

def __init__(self, stored: dict[str, Any] | None = None):
self._store: dict[str, Any] = dict(stored or {})

def get(self, key: str) -> Any:
return self._store.get(key)

def set(self, key: str, value: Any) -> None:
self._store[key] = value


@pytest.mark.skipif(
not AIRFLOW_V_3_3_PLUS,
reason="ResumableJobMixin reconnect requires task_state_store, available in Airflow 3.3+",
)
class TestLivyOperatorResumable:
"""Crash-safe synchronous wait (deferrable=False, polling_interval>0) via ResumableJobMixin."""

def _make_operator(self, **kwargs) -> LivyOperator:
return LivyOperator(task_id="livy_resumable", file="sparkapp.jar", polling_interval=1, **kwargs)

def _make_hook(self, batch_id: int = BATCH_ID) -> MagicMock:
hook = MagicMock()
hook.post_batch.return_value = batch_id
hook.get_batch.return_value = GET_BATCH
hook.TERMINAL_STATES = LivyHook.TERMINAL_STATES
return hook

def test_first_run_persists_batch_id_before_polling(self):
operator = self._make_operator()
operator.hook = self._make_hook(batch_id=42)
task_store = FakeTaskStateStore()
persisted_before_poll = []
operator.poll_until_complete = lambda external_id, context: persisted_before_poll.append(
task_store.get("livy_batch_id")
)

operator.execute(context={"task_state_store": task_store, "ti": MagicMock()})

operator.hook.post_batch.assert_called_once()
assert persisted_before_poll == [42]

@pytest.mark.parametrize(
("prior_status", "expect_submit", "expect_poll_id"),
[
("running", False, 1), # active -> reconnect to the existing batch
("starting", False, 1),
("success", False, None), # already succeeded -> return, no poll, no resubmit
("dead", True, BATCH_ID), # terminal failure -> resubmit fresh
("killed", True, BATCH_ID),
("error", True, BATCH_ID),
],
)
def test_retry_behaviour_based_on_prior_batch_status(self, prior_status, expect_submit, expect_poll_id):
operator = self._make_operator()
operator.hook = self._make_hook()
task_store = FakeTaskStateStore({"livy_batch_id": 1})
operator.get_job_status = lambda external_id, context: prior_status
polled = []
operator.poll_until_complete = lambda external_id, context: polled.append(external_id)

operator.execute(context={"task_state_store": task_store, "ti": MagicMock()})

if expect_submit:
operator.hook.post_batch.assert_called_once()
else:
operator.hook.post_batch.assert_not_called()
assert polled == ([] if expect_poll_id is None else [expect_poll_id])

def test_submits_fresh_when_task_state_store_unavailable(self):
operator = self._make_operator()
operator.hook = self._make_hook(batch_id=7)
polled = []
operator.poll_until_complete = lambda external_id, context: polled.append(external_id)

operator.execute(context={"ti": MagicMock()})

operator.hook.post_batch.assert_called_once()
assert polled == [7]

def test_durable_false_submits_fresh_and_polls(self):
operator = self._make_operator(durable=False)
operator.hook = self._make_hook(batch_id=7)
task_store = FakeTaskStateStore({"livy_batch_id": 1})
polled = []
operator.poll_until_complete = lambda external_id, context: polled.append(external_id)

operator.execute(context={"task_state_store": task_store, "ti": MagicMock()})

operator.hook.post_batch.assert_called_once()
assert polled == [7]

def test_status_helpers_classify_real_batch_states(self):
operator = self._make_operator()
operator.hook = self._make_hook()
assert operator.is_job_active("running") is True
assert operator.is_job_active("starting") is True
assert operator.is_job_active("success") is False
assert operator.is_job_succeeded("success") is True
assert operator.is_job_succeeded("dead") is False

def test_get_job_status_reads_batch_state_value(self):
operator = self._make_operator()
hook = self._make_hook()
hook.get_batch_state.return_value = BatchState.RUNNING
operator.hook = hook

assert operator.get_job_status(BATCH_ID, {}) == "running"

def test_poll_until_complete_sets_batch_id_for_on_kill(self):
operator = self._make_operator()
hook = self._make_hook()
hook.get_batch_state.return_value = BatchState.SUCCESS
operator.hook = hook

operator.poll_until_complete(55, {})

assert operator._batch_id == 55
Loading