diff --git a/task-sdk/src/airflow/sdk/bases/resumablejobmixin.py b/task-sdk/src/airflow/sdk/bases/resumablejobmixin.py index 27533dbe8409c..c4d6848795ca8 100644 --- a/task-sdk/src/airflow/sdk/bases/resumablejobmixin.py +++ b/task-sdk/src/airflow/sdk/bases/resumablejobmixin.py @@ -24,7 +24,10 @@ from airflow.sdk.bases.operator import BaseOperatorMeta if TYPE_CHECKING: + from collections.abc import Callable + from pydantic import JsonValue + from structlog.typing import FilteringBoundLogger from airflow.sdk.definitions.context import Context from airflow.sdk.types import Logger @@ -101,26 +104,12 @@ def __init__(self, *, durable: bool = True, **kwargs: Any) -> None: def execute_resumable(self, context: Context) -> Any: """ - Core of the resumable execution logic. Call this from execute() when reconnection is supported. - - On initial run: submits the job, persists the external ID to task_state_store, then polls. + Crash-safe submit-and-poll for synchronous operators. Call from ``execute()``. - Behaviour on retry: - - On retry with active job: skips submission, reconnects to the running job. - - On retry with succeeded job: skips submission and polling, returns result immediately. - - On retry with failed job: falls through and resubmits fresh. - - Known limitation: there is a small window between ``submit_job`` returning and - ``task_state_store.set`` completing. If the worker dies in that gap, the next retry still - holds the previous (terminal) ID and will resubmit a fresh job rather than reconnecting. - Closing this window would require atomic "submit + persist", which is not possible across - an external system boundary. + Binds the operator's ``submit_job`` / ``get_job_status`` / … methods to + :func:`resume_or_submit`, which owns the persist + three-state reconnect logic. See that + function for the behaviour and its known submit-vs-persist limitation. """ - if not self.durable: - external_id = self.submit_job(context) - self.poll_until_complete(external_id, context) - return self.get_job_result(external_id, context) - stats_tags = {"operator": type(self).__name__} # The task is team-scoped in multi-team deployments; surface team_name on the # resumable_job metrics via the running task instance's stats tags (omitted when @@ -129,85 +118,20 @@ def execute_resumable(self, context: Context) -> Any: if ti is not None and (team_name := ti.stats_tags.get("team_name")): stats_tags["team_name"] = team_name - reconnect_to: Any = None - already_succeeded_id: Any = None - - with tracer.start_as_current_span("resumable_job.resume_decision") as span: - span.set_attribute("operator", type(self).__name__) - span.set_attribute("resumable.external_id_key", self.external_id_key) - - task_state_store = context.get("task_state_store") - - if task_state_store is None: - span.set_attribute("resumable.decision", "no_task_state_store") - self.log.warning( - "task_state_store not available in context, crash recovery is disabled for this run" - ) - else: - external_id = task_state_store.get(self.external_id_key) - if external_id: - stats.incr("resumable_job.reconnect_attempt", tags=stats_tags) - - status = self.get_job_status(external_id, context) - - span.set_attribute("resumable.external_id", str(external_id)) - span.set_attribute("resumable.prior_status", status) - - if self.is_job_active(status): - # Job is still running, skip submission and reconnect to it. - span.set_attribute("resumable.decision", "reconnect") - stats.incr("resumable_job.reconnect_success", tags=stats_tags) - self.log.info( - "Reconnecting to existing job", - external_id_key=self.external_id_key, - external_id=external_id, - status=status, - ) - reconnect_to = external_id - elif self.is_job_succeeded(status): - # Job already finished successfully, skip polling and return result directly. - span.set_attribute("resumable.decision", "already_succeeded") - stats.incr("resumable_job.already_succeeded", tags=stats_tags) - self.log.info( - "Job already completed successfully, skipping resubmission", - external_id_key=self.external_id_key, - external_id=external_id, - ) - already_succeeded_id = external_id - else: - # Job is in a terminal failed state, fall through and submit a new job. - span.set_attribute("resumable.decision", "terminal_resubmit") - stats.incr("resumable_job.terminal_resubmit", tags=stats_tags) - self.log.warning( - "Prior job in terminal state, resubmitting fresh", - external_id_key=self.external_id_key, - external_id=external_id, - status=status, - ) - else: - span.set_attribute("resumable.decision", "fresh_submit") - stats.incr("resumable_job.fresh_submit", tags=stats_tags) - self.log.debug( - "No stored external ID found; submitting fresh job", - external_id_key=self.external_id_key, - ) - - if reconnect_to is not None: - return self.poll_until_complete(reconnect_to, context) - if already_succeeded_id is not None: - return self.get_job_result(already_succeeded_id, context) - external_id = self.submit_job(context) - - if task_state_store is not None and external_id is not None: - task_state_store.set(self.external_id_key, external_id) - self.log.debug( - "Persisted external ID to task store", - external_id_key=self.external_id_key, - external_id=external_id, - ) - - self.poll_until_complete(external_id, context) - return self.get_job_result(external_id, context) + return resume_or_submit( + durable=self.durable, + external_id_key=self.external_id_key, + task_state_store=context.get("task_state_store"), + submit=lambda: self.submit_job(context), + get_status=lambda external_id: self.get_job_status(external_id, context), + is_active=self.is_job_active, + is_succeeded=self.is_job_succeeded, + poll=lambda external_id: self.poll_until_complete(external_id, context), + get_result=lambda external_id: self.get_job_result(external_id, context), + log=self.log, + operator_name=type(self).__name__, + stats_tags=stats_tags, + ) def submit_job(self, context: Context) -> JsonValue: """ @@ -254,3 +178,114 @@ def poll_until_complete(self, external_id: JsonValue, context: Context) -> None: def get_job_result(self, external_id: JsonValue, context: Context) -> Any: """Return the job result after completion. Return None if not applicable.""" raise NotImplementedError + + +def resume_or_submit( + *, + durable: bool, + external_id_key: str, + task_state_store: Any, + submit: Callable[[], JsonValue], + get_status: Callable[[JsonValue], str], + is_active: Callable[[str], bool], + is_succeeded: Callable[[str], bool], + poll: Callable[[JsonValue], Any], + get_result: Callable[[JsonValue], Any], + log: FilteringBoundLogger, + operator_name: str, + stats_tags: dict[str, str], +) -> Any: + """ + Submit an external job and poll for it, reconnecting to a prior run on retry. + + The reusable core of crash-safe submit-and-poll execution, independent of any operator. + ``ResumableJobMixin`` wraps it for synchronous operators, but it can equally be driven from the + task runner for operators whose wait happens outside ``execute()`` (e.g. ``TriggerDagRunOperator``, + which raises an exception and is polled by the runner). + + The callbacks are the external-system bindings: ``submit`` starts the job and returns its external + id, ``get_status`` reads the raw backend status, ``is_active`` / ``is_succeeded`` classify it, + ``poll`` blocks until terminal (raising on failure), ``get_result`` returns the result. On the + first run the external id is persisted to ``task_state_store`` before polling; on retry it is read + back and the job is reconnected (active), returned (succeeded), or resubmitted (terminal / missing). + + Known limitation: there is a small window between ``submit`` returning and the persist completing. + A crash in that gap resubmits fresh on the next retry rather than reconnecting; closing it would + require an atomic "submit + persist", which is not possible across an external system boundary. + """ + if not durable: + external_id = submit() + poll(external_id) + return get_result(external_id) + + reconnect_to: Any = None + already_succeeded_id: Any = None + + with tracer.start_as_current_span("resumable_job.resume_decision") as span: + span.set_attribute("operator", operator_name) + span.set_attribute("resumable.external_id_key", external_id_key) + + if task_state_store is None: + span.set_attribute("resumable.decision", "no_task_state_store") + log.warning("task_state_store not available in context, crash recovery is disabled for this run") + else: + external_id = task_state_store.get(external_id_key) + if external_id: + stats.incr("resumable_job.reconnect_attempt", tags=stats_tags) + + status = get_status(external_id) + + span.set_attribute("resumable.external_id", str(external_id)) + span.set_attribute("resumable.prior_status", status) + + if is_active(status): + span.set_attribute("resumable.decision", "reconnect") + stats.incr("resumable_job.reconnect_success", tags=stats_tags) + log.info( + "Reconnecting to existing job", + external_id_key=external_id_key, + external_id=external_id, + status=status, + ) + reconnect_to = external_id + elif is_succeeded(status): + span.set_attribute("resumable.decision", "already_succeeded") + stats.incr("resumable_job.already_succeeded", tags=stats_tags) + log.info( + "Job already completed successfully, skipping resubmission", + external_id_key=external_id_key, + external_id=external_id, + ) + already_succeeded_id = external_id + else: + span.set_attribute("resumable.decision", "terminal_resubmit") + stats.incr("resumable_job.terminal_resubmit", tags=stats_tags) + log.warning( + "Prior job in terminal state, resubmitting fresh", + external_id_key=external_id_key, + external_id=external_id, + status=status, + ) + else: + span.set_attribute("resumable.decision", "fresh_submit") + stats.incr("resumable_job.fresh_submit", tags=stats_tags) + log.debug( + "No stored external ID found; submitting fresh job", + external_id_key=external_id_key, + ) + + if reconnect_to is not None: + return poll(reconnect_to) + if already_succeeded_id is not None: + return get_result(already_succeeded_id) + + external_id = submit() + if task_state_store is not None and external_id is not None: + task_state_store.set(external_id_key, external_id) + log.debug( + "Persisted external ID to task store", + external_id_key=external_id_key, + external_id=external_id, + ) + poll(external_id) + return get_result(external_id)