|
26 | 26 | import pytest
|
27 | 27 | from uuid6 import uuid7
|
28 | 28 |
|
29 |
| -from airflow.exceptions import AirflowSkipException |
| 29 | +from airflow.exceptions import AirflowFailException, AirflowSensorTimeout, AirflowSkipException |
30 | 30 | from airflow.sdk import DAG, BaseOperator
|
31 | 31 | from airflow.sdk.api.datamodels._generated import TaskInstance, TerminalTIState
|
32 | 32 | from airflow.sdk.execution_time.comms import DeferTask, SetRenderedFields, StartupDetails, TaskState
|
@@ -333,6 +333,55 @@ def __init__(self, *args, **kwargs):
|
333 | 333 | )
|
334 | 334 |
|
335 | 335 |
|
| 336 | +@pytest.mark.parametrize( |
| 337 | + ["dag_id", "task_id", "fail_with_exception"], |
| 338 | + [ |
| 339 | + pytest.param( |
| 340 | + "basic_failed", "fail-exception", AirflowFailException("Oops. Failing by AirflowFailException!") |
| 341 | + ), |
| 342 | + pytest.param( |
| 343 | + "basic_failed2", |
| 344 | + "sensor-timeout-exception", |
| 345 | + AirflowSensorTimeout("Oops. Failing by AirflowSensorTimeout!"), |
| 346 | + ), |
| 347 | + ], |
| 348 | +) |
| 349 | +def test_run_basic_failed(time_machine, mocked_parse, dag_id, task_id, fail_with_exception, make_ti_context): |
| 350 | + """Test running a basic task that marks itself as failed by raising exception.""" |
| 351 | + |
| 352 | + class CustomOperator(BaseOperator): |
| 353 | + def __init__(self, e, *args, **kwargs): |
| 354 | + super().__init__(*args, **kwargs) |
| 355 | + self.e = e |
| 356 | + |
| 357 | + def execute(self, context): |
| 358 | + print(f"raising exception {self.e}") |
| 359 | + raise self.e |
| 360 | + |
| 361 | + task = CustomOperator(task_id=task_id, e=fail_with_exception) |
| 362 | + |
| 363 | + what = StartupDetails( |
| 364 | + ti=TaskInstance(id=uuid7(), task_id=task_id, dag_id=dag_id, run_id="c", try_number=1), |
| 365 | + file="", |
| 366 | + requests_fd=0, |
| 367 | + ti_context=make_ti_context(), |
| 368 | + ) |
| 369 | + |
| 370 | + ti = mocked_parse(what, dag_id, task) |
| 371 | + |
| 372 | + instant = timezone.datetime(2024, 12, 3, 10, 0) |
| 373 | + time_machine.move_to(instant, tick=False) |
| 374 | + |
| 375 | + with mock.patch( |
| 376 | + "airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True |
| 377 | + ) as mock_supervisor_comms: |
| 378 | + run(ti, log=mock.MagicMock()) |
| 379 | + |
| 380 | + mock_supervisor_comms.send_request.assert_called_once_with( |
| 381 | + msg=TaskState(state=TerminalTIState.FAILED, end_date=instant), log=mock.ANY |
| 382 | + ) |
| 383 | + |
| 384 | + |
336 | 385 | class TestRuntimeTaskInstance:
|
337 | 386 | def test_get_context_without_ti_context_from_server(self, mocked_parse, make_ti_context):
|
338 | 387 | """Test get_template_context without ti_context_from_server."""
|
|
0 commit comments