Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
229 changes: 132 additions & 97 deletions task-sdk/src/airflow/sdk/bases/resumablejobmixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,10 @@
from airflow.sdk.bases.operator import BaseOperatorMeta

if TYPE_CHECKING:
from collections.abc import Callable

from pydantic import JsonValue
from structlog.typing import FilteringBoundLogger

from airflow.sdk.definitions.context import Context
from airflow.sdk.types import Logger
Expand Down Expand Up @@ -101,26 +104,12 @@ def __init__(self, *, durable: bool = True, **kwargs: Any) -> None:

def execute_resumable(self, context: Context) -> Any:
"""
Core of the resumable execution logic. Call this from execute() when reconnection is supported.

On initial run: submits the job, persists the external ID to task_state_store, then polls.
Crash-safe submit-and-poll for synchronous operators. Call from ``execute()``.

Behaviour on retry:
- On retry with active job: skips submission, reconnects to the running job.
- On retry with succeeded job: skips submission and polling, returns result immediately.
- On retry with failed job: falls through and resubmits fresh.

Known limitation: there is a small window between ``submit_job`` returning and
``task_state_store.set`` completing. If the worker dies in that gap, the next retry still
holds the previous (terminal) ID and will resubmit a fresh job rather than reconnecting.
Closing this window would require atomic "submit + persist", which is not possible across
an external system boundary.
Binds the operator's ``submit_job`` / ``get_job_status`` / … methods to
:func:`resume_or_submit`, which owns the persist + three-state reconnect logic. See that
function for the behaviour and its known submit-vs-persist limitation.
"""
if not self.durable:
external_id = self.submit_job(context)
self.poll_until_complete(external_id, context)
return self.get_job_result(external_id, context)

stats_tags = {"operator": type(self).__name__}
# The task is team-scoped in multi-team deployments; surface team_name on the
# resumable_job metrics via the running task instance's stats tags (omitted when
Expand All @@ -129,85 +118,20 @@ def execute_resumable(self, context: Context) -> Any:
if ti is not None and (team_name := ti.stats_tags.get("team_name")):
stats_tags["team_name"] = team_name

reconnect_to: Any = None
already_succeeded_id: Any = None

with tracer.start_as_current_span("resumable_job.resume_decision") as span:
span.set_attribute("operator", type(self).__name__)
span.set_attribute("resumable.external_id_key", self.external_id_key)

task_state_store = context.get("task_state_store")

if task_state_store is None:
span.set_attribute("resumable.decision", "no_task_state_store")
self.log.warning(
"task_state_store not available in context, crash recovery is disabled for this run"
)
else:
external_id = task_state_store.get(self.external_id_key)
if external_id:
stats.incr("resumable_job.reconnect_attempt", tags=stats_tags)

status = self.get_job_status(external_id, context)

span.set_attribute("resumable.external_id", str(external_id))
span.set_attribute("resumable.prior_status", status)

if self.is_job_active(status):
# Job is still running, skip submission and reconnect to it.
span.set_attribute("resumable.decision", "reconnect")
stats.incr("resumable_job.reconnect_success", tags=stats_tags)
self.log.info(
"Reconnecting to existing job",
external_id_key=self.external_id_key,
external_id=external_id,
status=status,
)
reconnect_to = external_id
elif self.is_job_succeeded(status):
# Job already finished successfully, skip polling and return result directly.
span.set_attribute("resumable.decision", "already_succeeded")
stats.incr("resumable_job.already_succeeded", tags=stats_tags)
self.log.info(
"Job already completed successfully, skipping resubmission",
external_id_key=self.external_id_key,
external_id=external_id,
)
already_succeeded_id = external_id
else:
# Job is in a terminal failed state, fall through and submit a new job.
span.set_attribute("resumable.decision", "terminal_resubmit")
stats.incr("resumable_job.terminal_resubmit", tags=stats_tags)
self.log.warning(
"Prior job in terminal state, resubmitting fresh",
external_id_key=self.external_id_key,
external_id=external_id,
status=status,
)
else:
span.set_attribute("resumable.decision", "fresh_submit")
stats.incr("resumable_job.fresh_submit", tags=stats_tags)
self.log.debug(
"No stored external ID found; submitting fresh job",
external_id_key=self.external_id_key,
)

if reconnect_to is not None:
return self.poll_until_complete(reconnect_to, context)
if already_succeeded_id is not None:
return self.get_job_result(already_succeeded_id, context)
external_id = self.submit_job(context)

if task_state_store is not None and external_id is not None:
task_state_store.set(self.external_id_key, external_id)
self.log.debug(
"Persisted external ID to task store",
external_id_key=self.external_id_key,
external_id=external_id,
)

self.poll_until_complete(external_id, context)
return self.get_job_result(external_id, context)
return resume_or_submit(
durable=self.durable,
external_id_key=self.external_id_key,
task_state_store=context.get("task_state_store"),
submit=lambda: self.submit_job(context),
get_status=lambda external_id: self.get_job_status(external_id, context),
is_active=self.is_job_active,
is_succeeded=self.is_job_succeeded,
poll=lambda external_id: self.poll_until_complete(external_id, context),
get_result=lambda external_id: self.get_job_result(external_id, context),
log=self.log,
operator_name=type(self).__name__,
stats_tags=stats_tags,
)

def submit_job(self, context: Context) -> JsonValue:
"""
Expand Down Expand Up @@ -254,3 +178,114 @@ def poll_until_complete(self, external_id: JsonValue, context: Context) -> None:
def get_job_result(self, external_id: JsonValue, context: Context) -> Any:
    """
    Return the job result after completion. Return None if not applicable.

    This base implementation always raises ``NotImplementedError``; presumably concrete
    operators are expected to override it (confirm against the class documentation).

    :param external_id: Identifier of the external job whose result is requested.
    :param context: The task execution context.
    :raises NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError


def resume_or_submit(
    *,
    durable: bool,
    external_id_key: str,
    task_state_store: Any,
    submit: Callable[[], JsonValue],
    get_status: Callable[[JsonValue], str],
    is_active: Callable[[str], bool],
    is_succeeded: Callable[[str], bool],
    poll: Callable[[JsonValue], Any],
    get_result: Callable[[JsonValue], Any],
    log: FilteringBoundLogger,
    operator_name: str,
    stats_tags: dict[str, str],
) -> Any:
    """
    Submit an external job and poll for it, reconnecting to a prior run on retry.

    The reusable core of crash-safe submit-and-poll execution, independent of any operator.
    ``ResumableJobMixin`` wraps it for synchronous operators, but it can equally be driven from the
    task runner for operators whose wait happens outside ``execute()`` (e.g. ``TriggerDagRunOperator``,
    which raises an exception and is polled by the runner).

    Behaviour:

    - ``durable`` is false: submit, poll, return the result; nothing is read or persisted.
    - No ``task_state_store``: a warning is logged and a fresh job is submitted without persisting.
    - No stored external id: submit fresh, persist the id, poll, return the result.
    - Stored id, job still active: skip submission, poll the existing job, return its result.
    - Stored id, job succeeded: skip submission and polling, return its result.
    - Stored id, any other status: treated as terminal; submit fresh and overwrite the stored id.

    Every path returns the value of ``get_result``, so a retry that reconnects to a running job
    yields the same result as an uninterrupted run.

    Known limitation: there is a small window between ``submit`` returning and the persist completing.
    A crash in that gap resubmits fresh on the next retry rather than reconnecting; closing it would
    require an atomic "submit + persist", which is not possible across an external system boundary.

    :param durable: Whether to use the persist-and-reconnect logic at all.
    :param external_id_key: Key under which the external id is stored in ``task_state_store``.
    :param task_state_store: Object exposing ``get(key)`` and ``set(key, value)``, or ``None``.
    :param submit: Starts the job and returns its external id.
    :param get_status: Reads the raw backend status for an external id.
    :param is_active: Returns true when a status means the job is still running.
    :param is_succeeded: Returns true when a status means the job finished successfully.
    :param poll: Blocks until the job is terminal, raising on failure.
    :param get_result: Returns the job result for an external id.
    :param log: Logger that receives the decision messages.
    :param operator_name: Name recorded on the tracing span.
    :param stats_tags: Tags attached to every ``resumable_job.*`` counter.
    :return: Whatever ``get_result`` returns for the job that was run or reconnected to.
    """
    if not durable:
        external_id = submit()
        poll(external_id)
        return get_result(external_id)

    reconnect_to: Any = None
    already_succeeded_id: Any = None

    # Only the decision is traced; the (potentially long) polling happens outside the span.
    with tracer.start_as_current_span("resumable_job.resume_decision") as span:
        span.set_attribute("operator", operator_name)
        span.set_attribute("resumable.external_id_key", external_id_key)

        if task_state_store is None:
            span.set_attribute("resumable.decision", "no_task_state_store")
            log.warning("task_state_store not available in context, crash recovery is disabled for this run")
        else:
            external_id = task_state_store.get(external_id_key)
            # Any falsy stored value (None, "", 0, ...) is treated as "nothing stored".
            if external_id:
                stats.incr("resumable_job.reconnect_attempt", tags=stats_tags)

                status = get_status(external_id)

                span.set_attribute("resumable.external_id", str(external_id))
                span.set_attribute("resumable.prior_status", status)

                if is_active(status):
                    # Job is still running: skip submission and reconnect to it.
                    span.set_attribute("resumable.decision", "reconnect")
                    stats.incr("resumable_job.reconnect_success", tags=stats_tags)
                    log.info(
                        "Reconnecting to existing job",
                        external_id_key=external_id_key,
                        external_id=external_id,
                        status=status,
                    )
                    reconnect_to = external_id
                elif is_succeeded(status):
                    # Job already finished successfully: skip polling, fetch the result directly.
                    span.set_attribute("resumable.decision", "already_succeeded")
                    stats.incr("resumable_job.already_succeeded", tags=stats_tags)
                    log.info(
                        "Job already completed successfully, skipping resubmission",
                        external_id_key=external_id_key,
                        external_id=external_id,
                    )
                    already_succeeded_id = external_id
                else:
                    # Neither active nor succeeded: fall through and submit a new job.
                    span.set_attribute("resumable.decision", "terminal_resubmit")
                    stats.incr("resumable_job.terminal_resubmit", tags=stats_tags)
                    log.warning(
                        "Prior job in terminal state, resubmitting fresh",
                        external_id_key=external_id_key,
                        external_id=external_id,
                        status=status,
                    )
            else:
                span.set_attribute("resumable.decision", "fresh_submit")
                stats.incr("resumable_job.fresh_submit", tags=stats_tags)
                log.debug(
                    "No stored external ID found; submitting fresh job",
                    external_id_key=external_id_key,
                )

    if reconnect_to is not None:
        # Wait for the existing job, then return its result exactly like the fresh-submit path
        # does; returning ``poll``'s own value would drop the result on reconnect.
        poll(reconnect_to)
        return get_result(reconnect_to)
    if already_succeeded_id is not None:
        return get_result(already_succeeded_id)

    external_id = submit()
    if task_state_store is not None and external_id is not None:
        # Persist before polling so a crash while waiting can reconnect on the next retry.
        task_state_store.set(external_id_key, external_id)
        log.debug(
            "Persisted external ID to task store",
            external_id_key=external_id_key,
            external_id=external_id,
        )
    poll(external_id)
    return get_result(external_id)
Loading