Skip to content
Open
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions task-sdk/src/airflow/sdk/execution_time/supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1680,6 +1680,7 @@ def start( # type: ignore[override]
log = structlog.get_logger(logger_name="task")

state, msg, error = run(ti, context, log)
context["exception"] = error
finalize(ti, state, context, log, error)
Comment thread
ASk1 marked this conversation as resolved.

# In the normal subprocess model, the task runner calls this before exiting.
Expand Down
50 changes: 50 additions & 0 deletions task-sdk/tests/task_sdk/execution_time/test_supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2805,6 +2805,56 @@ def _handle_request(self, msg, log, req_id):
assert isinstance(response, VariableResult)
assert response.value == "value"

def test_inprocess_failure_callback_receives_exception(
self,
monkeypatch,
make_ti_context,
):
"""Run a failing task via InProcessTestSupervisor and ensure the
`on_failure_callback` receives `context['exception']`.
"""
collected: list[BaseException | None] = [None]

class _Failure(Exception):
pass

def failure_callback(context):
collected[0] = context.get("exception")

class FailingOperator(BaseOperator):
def execute(self, context=None):
raise _Failure("boom")

task = FailingOperator(task_id="failing", on_failure_callback=failure_callback)

# Assign a minimal DAG to the operator so `task.dag` access succeeds
from airflow.sdk import DAG

task.dag = DAG(dag_id="test_dag")
Comment thread
ASk1 marked this conversation as resolved.

# Create a simple TaskInstance datamodel to pass to the supervisor
ti = TaskInstance(
id=uuid7(),
task_id=task.task_id,
dag_id="test_dag",
run_id="r",
try_number=1,
dag_version_id=uuid7(),
)

# Patch the API client used by InProcessTestSupervisor to return a predictable TI context
fake_client = MagicMock()
fake_client.task_instances.start.return_value = make_ti_context()
Comment thread
ASk1 marked this conversation as resolved.
Outdated
monkeypatch.setattr(
InProcessTestSupervisor, "_api_client", staticmethod(lambda dag=None: fake_client)
)

result = InProcessTestSupervisor.start(what=ti, task=task)

# Ensure the task failed and the callback saw the exception
assert isinstance(result.error, _Failure)
assert isinstance(collected[0], _Failure)


class TestInProcessClient:
def test_no_retries(self):
Expand Down