From 0d71a70912efea5c97d24d235e73405d88e29410 Mon Sep 17 00:00:00 2001 From: raphaelauv Date: Thu, 25 Jul 2024 18:25:43 +0200 Subject: [PATCH] introduce `fail_policy` --- airflow/decorators/__init__.pyi | 5 +- airflow/example_dags/example_sensors.py | 22 ++++-- airflow/exceptions.py | 4 + airflow/sensors/base.py | 74 +++++++++-------- .../src/airflow/providers/ftp/sensors/ftp.py | 3 +- .../airflow/providers/sftp/sensors/sftp.py | 4 +- .../standard/sensors/external_task.py | 47 ++++------- .../providers/standard/sensors/filesystem.py | 4 +- .../providers/standard/sensors/time_delta.py | 14 +--- .../standard/triggers/external_task.py | 9 ++- providers/tests/ftp/sensors/test_ftp.py | 19 +++-- providers/tests/http/sensors/test_http.py | 6 +- providers/tests/sftp/sensors/test_sftp.py | 4 +- .../standard/triggers/test_external_task.py | 2 +- tests/decorators/test_sensor.py | 11 ++- tests/sensors/test_base.py | 53 +++++++------ tests/sensors/test_external_task_sensor.py | 79 ++++++++++++------- 17 files changed, 206 insertions(+), 154 deletions(-) diff --git a/airflow/decorators/__init__.pyi b/airflow/decorators/__init__.pyi index 49504fa388a79..38222a52a227d 100644 --- a/airflow/decorators/__init__.pyi +++ b/airflow/decorators/__init__.pyi @@ -41,6 +41,7 @@ from airflow.decorators.short_circuit import short_circuit_task from airflow.decorators.task_group import task_group from airflow.models.dag import dag from airflow.providers.cncf.kubernetes.secret import Secret +from airflow.sensors.base import FailPolicy from airflow.typing_compat import Literal # Please keep this in sync with __init__.py's __all__. 
@@ -690,7 +691,7 @@ class TaskDecoratorCollection: *, poke_interval: float = ..., timeout: float = ..., - soft_fail: bool = False, + fail_policy: FailPolicy = ..., mode: str = ..., exponential_backoff: bool = False, max_wait: timedelta | float | None = None, @@ -702,7 +703,7 @@ class TaskDecoratorCollection: :param poke_interval: Time in seconds that the job should wait in between each try :param timeout: Time, in seconds before the task times out and fails. - :param soft_fail: Set to true to mark the task as SKIPPED on failure + :param fail_policy: Rule by which the sensor skips itself or ignores poke errors. Options are: ``{ none | skip_on_any_error | skip_on_timeout | ignore_error }`` (see ``airflow.sensors.base.FailPolicy``). :param mode: How the sensor operates. Options are: ``{ poke | reschedule }``, default is ``poke``. When set to ``poke`` the sensor is taking up a worker slot for its diff --git a/airflow/example_dags/example_sensors.py b/airflow/example_dags/example_sensors.py index 39d7b8d29635f..6ca71527112e6 100644 --- a/airflow/example_dags/example_sensors.py +++ b/airflow/example_dags/example_sensors.py @@ -23,6 +23,7 @@ from airflow.models.dag import DAG from airflow.providers.standard.operators.bash import BashOperator +from airflow.providers.standard.sensors.base import FailPolicy from airflow.providers.standard.sensors.bash import BashSensor from airflow.providers.standard.sensors.filesystem import FileSensor from airflow.providers.standard.sensors.python import PythonSensor @@ -68,7 +69,7 @@ def failure_callable(): t2 = TimeSensor( task_id="timeout_after_second_date_in_the_future", timeout=1, - soft_fail=True, + fail_policy=FailPolicy.SKIP_ON_TIMEOUT, target_time=(datetime.datetime.now(tz=datetime.timezone.utc) + datetime.timedelta(hours=1)).time(), ) # [END example_time_sensors] @@ -81,7 +82,7 @@ def failure_callable(): t2a = TimeSensorAsync( task_id="timeout_after_second_date_in_the_future_async", timeout=1, - soft_fail=True, + fail_policy=FailPolicy.SKIP_ON_TIMEOUT, target_time=(datetime.datetime.now(tz=datetime.timezone.utc) + datetime.timedelta(hours=1)).time(), ) # [END example_time_sensors_async] @@ -89,7 +90,12
@@ def failure_callable(): # [START example_bash_sensors] t3 = BashSensor(task_id="Sensor_succeeds", bash_command="exit 0") - t4 = BashSensor(task_id="Sensor_fails_after_3_seconds", timeout=3, soft_fail=True, bash_command="exit 1") + t4 = BashSensor( + task_id="Sensor_fails_after_3_seconds", + timeout=3, + fail_policy=FailPolicy.SKIP_ON_TIMEOUT, + bash_command="exit 1", + ) # [END example_bash_sensors] t5 = BashOperator(task_id="remove_file", bash_command="rm -rf /tmp/temporary_file_for_testing") @@ -112,13 +118,19 @@ def failure_callable(): t9 = PythonSensor(task_id="success_sensor_python", python_callable=success_callable) t10 = PythonSensor( - task_id="failure_timeout_sensor_python", timeout=3, soft_fail=True, python_callable=failure_callable + task_id="failure_timeout_sensor_python", + timeout=3, + fail_policy=FailPolicy.SKIP_ON_TIMEOUT, + python_callable=failure_callable, ) # [END example_python_sensors] # [START example_day_of_week_sensor] t11 = DayOfWeekSensor( - task_id="week_day_sensor_failing_on_timeout", timeout=3, soft_fail=True, week_day=WeekDay.MONDAY + task_id="week_day_sensor_failing_on_timeout", + timeout=3, + fail_policy=FailPolicy.SKIP_ON_TIMEOUT, + week_day=WeekDay.MONDAY, ) # [END example_day_of_week_sensor] diff --git a/airflow/exceptions.py b/airflow/exceptions.py index 5e32c00c7d5da..96e1601728a30 100644 --- a/airflow/exceptions.py +++ b/airflow/exceptions.py @@ -70,6 +70,10 @@ class AirflowSensorTimeout(AirflowException): """Raise when there is a timeout on sensor polling.""" +class AirflowPokeFailException(AirflowException): + """Raise when a sensor must not try to poke again.""" + + class AirflowRescheduleException(AirflowException): """ Raise when the task should be re-scheduled at a later time. 
diff --git a/airflow/sensors/base.py b/airflow/sensors/base.py index 39172ce64afd7..58c83a4effcb6 100644 --- a/airflow/sensors/base.py +++ b/airflow/sensors/base.py @@ -18,6 +18,7 @@ from __future__ import annotations import datetime +import enum import functools import hashlib import time @@ -32,7 +33,7 @@ from airflow.configuration import conf from airflow.exceptions import ( AirflowException, - AirflowFailException, + AirflowPokeFailException, AirflowRescheduleException, AirflowSensorTimeout, AirflowSkipException, @@ -81,6 +82,23 @@ def __init__(self, is_done: bool, xcom_value: Any | None = None) -> None: def __bool__(self) -> bool: return self.is_done +class FailPolicy(str, enum.Enum): + """Class with sensor's fail policies.""" + + # If the poke method raises an exception, the sensor will not be skipped. + NONE = "none" + + # If the poke method raises any exception, the sensor will be skipped. + SKIP_ON_ANY_ERROR = "skip_on_any_error" + + # If the poke method raises AirflowSensorTimeout, AirflowTaskTimeout, AirflowPokeFailException + # or AirflowSkipException, the sensor will be skipped. + SKIP_ON_TIMEOUT = "skip_on_timeout" + + # If the poke method raises an exception different from AirflowSensorTimeout, AirflowTaskTimeout, + # AirflowSkipException or AirflowPokeFailException, the sensor will ignore the exception and re-poke until timeout. + IGNORE_ERROR = "ignore_error" + class BaseSensorOperator(BaseOperator, SkipMixin): """ @@ -89,8 +107,6 @@ class BaseSensorOperator(BaseOperator, SkipMixin): Sensor operators keep executing at a time interval and succeed when a criteria is met and fail if and when they time out. - :param soft_fail: Set to true to mark the task as SKIPPED on failure. - Mutually exclusive with never_fail. :param poke_interval: Time that the job should wait in between each try. Can be ``timedelta`` or ``float`` seconds. :param timeout: Time elapsed before the task times out and fails.
@@ -118,13 +134,10 @@ class BaseSensorOperator(BaseOperator, SkipMixin): :param exponential_backoff: allow progressive longer waits between pokes by using exponential backoff algorithm :param max_wait: maximum wait interval between pokes, can be ``timedelta`` or ``float`` seconds - :param silent_fail: If true, and poke method raises an exception different from - AirflowSensorTimeout, AirflowTaskTimeout, AirflowSkipException - and AirflowFailException, the sensor will log the error and continue - its execution. Otherwise, the sensor task fails, and it can be retried - based on the provided `retries` parameter. - :param never_fail: If true, and poke method raises an exception, sensor will be skipped. - Mutually exclusive with soft_fail. + :param fail_policy: defines the rule by which sensor skip itself. Options are: + ``{ none | skip_on_any_error | skip_on_timeout | ignore_error }`` + default is ``none``. Options can be set as string or + using the constants defined in the static class ``airflow.sensors.base.FailPolicy`` """ ui_color: str = "#e6f1f2" @@ -139,26 +152,19 @@ def __init__( *, poke_interval: timedelta | float = 60, timeout: timedelta | float = conf.getfloat("sensors", "default_timeout"), - soft_fail: bool = False, mode: str = "poke", exponential_backoff: bool = False, max_wait: timedelta | float | None = None, - silent_fail: bool = False, - never_fail: bool = False, + fail_policy: str = FailPolicy.NONE, **kwargs, ) -> None: super().__init__(**kwargs) self.poke_interval = self._coerce_poke_interval(poke_interval).total_seconds() - self.soft_fail = soft_fail self.timeout: int | float = self._coerce_timeout(timeout).total_seconds() self.mode = mode self.exponential_backoff = exponential_backoff self.max_wait = self._coerce_max_wait(max_wait) - if soft_fail is True and never_fail is True: - raise ValueError("soft_fail and never_fail are mutually exclusive, you can not provide both.") - - self.silent_fail = silent_fail - self.never_fail = never_fail + 
self.fail_policy = fail_policy self._validate_input_values() @staticmethod @@ -266,21 +272,20 @@ def run_duration() -> float: except ( AirflowSensorTimeout, AirflowTaskTimeout, - AirflowFailException, + AirflowPokeFailException, + AirflowSkipException, ) as e: - if self.soft_fail: - raise AirflowSkipException("Skipping due to soft_fail is set to True.") from e - elif self.never_fail: - raise AirflowSkipException("Skipping due to never_fail is set to True.") from e - raise e - except AirflowSkipException as e: + if self.fail_policy == FailPolicy.SKIP_ON_TIMEOUT: + raise AirflowSkipException("Skipping due fail_policy set to SKIP_ON_TIMEOUT.") from e + elif self.fail_policy == FailPolicy.SKIP_ON_ANY_ERROR: + raise AirflowSkipException("Skipping due to SKIP_ON_ANY_ERROR is set to True.") from e raise e except Exception as e: - if self.silent_fail: + if self.fail_policy == FailPolicy.IGNORE_ERROR: self.log.error("Sensor poke failed: \n %s", traceback.format_exc()) poke_return = False - elif self.never_fail: - raise AirflowSkipException("Skipping due to never_fail is set to True.") from e + elif self.fail_policy == FailPolicy.SKIP_ON_ANY_ERROR: + raise AirflowSkipException("Skipping due to SKIP_ON_ANY_ERROR is set to True.") from e else: raise e @@ -290,13 +295,13 @@ def run_duration() -> float: break if run_duration() > self.timeout: - # If sensor is in soft fail mode but times out raise AirflowSkipException. + # If sensor is in SKIP_ON_TIMEOUT mode but times out it raise AirflowSkipException. message = ( f"Sensor has timed out; run duration of {run_duration()} seconds exceeds " f"the specified timeout of {self.timeout}." 
) - if self.soft_fail: + if self.fail_policy == FailPolicy.SKIP_ON_TIMEOUT: raise AirflowSkipException(message) else: raise AirflowSensorTimeout(message) @@ -319,9 +324,12 @@ def resume_execution(self, next_method: str, next_kwargs: dict[str, Any] | None, try: return super().resume_execution(next_method, next_kwargs, context) except TaskDeferralTimeout as e: - raise AirflowSensorTimeout(*e.args) from e + if self.fail_policy == FailPolicy.SKIP_ON_ANY_ERROR: + raise AirflowSkipException(str(e)) from e + else: + raise AirflowSensorTimeout(*e.args) from e except (AirflowException, TaskDeferralError) as e: - if self.soft_fail: + if self.fail_policy == FailPolicy.SKIP_ON_ANY_ERROR: raise AirflowSkipException(str(e)) from e raise diff --git a/providers/src/airflow/providers/ftp/sensors/ftp.py b/providers/src/airflow/providers/ftp/sensors/ftp.py index 9d384c889c7ba..50a289f2a432f 100644 --- a/providers/src/airflow/providers/ftp/sensors/ftp.py +++ b/providers/src/airflow/providers/ftp/sensors/ftp.py @@ -22,6 +22,7 @@ from collections.abc import Sequence from typing import TYPE_CHECKING +from airflow.exceptions import AirflowPokeFailException from airflow.providers.ftp.hooks.ftp import FTPHook, FTPSHook from airflow.sensors.base import BaseSensorOperator @@ -83,7 +84,7 @@ def poke(self, context: Context) -> bool: if (error_code != 550) and ( self.fail_on_transient_errors or (error_code not in self.transient_errors) ): - raise e + raise AirflowPokeFailException from e return False diff --git a/providers/src/airflow/providers/sftp/sensors/sftp.py b/providers/src/airflow/providers/sftp/sensors/sftp.py index 9a5cb14345282..98bf4199fb428 100644 --- a/providers/src/airflow/providers/sftp/sensors/sftp.py +++ b/providers/src/airflow/providers/sftp/sensors/sftp.py @@ -27,7 +27,7 @@ from paramiko.sftp import SFTP_NO_SUCH_FILE from airflow.configuration import conf -from airflow.exceptions import AirflowException +from airflow.exceptions import AirflowException, 
AirflowPokeFailException from airflow.providers.sftp.hooks.sftp import SFTPHook from airflow.providers.sftp.triggers.sftp import SFTPTrigger from airflow.sensors.base import BaseSensorOperator, PokeReturnValue @@ -99,7 +99,7 @@ def poke(self, context: Context) -> PokeReturnValue | bool: self.log.info("Found File %s last modified: %s", actual_file_to_check, mod_time) except OSError as e: if e.errno != SFTP_NO_SUCH_FILE: - raise AirflowException from e + raise AirflowPokeFailException from e continue if self.newer_than: diff --git a/providers/src/airflow/providers/standard/sensors/external_task.py b/providers/src/airflow/providers/standard/sensors/external_task.py index ff43ff2f463f7..560d598213720 100644 --- a/providers/src/airflow/providers/standard/sensors/external_task.py +++ b/providers/src/airflow/providers/standard/sensors/external_task.py @@ -24,7 +24,7 @@ from typing import TYPE_CHECKING, Any, Callable, ClassVar from airflow.configuration import conf -from airflow.exceptions import AirflowException, AirflowSkipException +from airflow.exceptions import AirflowPokeFailException, AirflowSkipException from airflow.models.baseoperatorlink import BaseOperatorLink from airflow.models.dag import DagModel from airflow.models.dagbag import DagBag @@ -176,7 +176,7 @@ def __init__( total_states = set(self.allowed_states + self.skipped_states + self.failed_states) if len(total_states) != len(self.allowed_states) + len(self.skipped_states) + len(self.failed_states): - raise AirflowException( + raise ValueError( "Duplicate values provided across allowed_states, skipped_states and failed_states." ) @@ -287,32 +287,18 @@ def poke(self, context: Context, session: Session = NEW_SESSION) -> bool: # Fail if anything in the list has failed. if count_failed > 0: if self.external_task_ids: - if self.soft_fail: - raise AirflowSkipException( - f"Some of the external tasks {self.external_task_ids} " - f"in DAG {self.external_dag_id} failed. Skipping due to soft_fail." 
- ) - raise AirflowException( + raise AirflowPokeFailException( f"Some of the external tasks {self.external_task_ids} " f"in DAG {self.external_dag_id} failed." ) elif self.external_task_group_id: - if self.soft_fail: - raise AirflowSkipException( - f"The external task_group '{self.external_task_group_id}' " - f"in DAG '{self.external_dag_id}' failed. Skipping due to soft_fail." - ) - raise AirflowException( + raise AirflowPokeFailException( f"The external task_group '{self.external_task_group_id}' " f"in DAG '{self.external_dag_id}' failed." ) else: - if self.soft_fail: - raise AirflowSkipException( - f"The external DAG {self.external_dag_id} failed. Skipping due to soft_fail." - ) - raise AirflowException(f"The external DAG {self.external_dag_id} failed.") + raise AirflowPokeFailException(f"The external DAG {self.external_dag_id} failed.") count_skipped = -1 if self.skipped_states: @@ -355,7 +341,7 @@ def execute(self, context: Context) -> None: logical_dates=self._get_dttm_filter(context), allowed_states=self.allowed_states, poke_interval=self.poll_interval, - soft_fail=self.soft_fail, + fail_policy=self.fail_policy, ), method_name="execute_complete", ) @@ -365,30 +351,27 @@ def execute_complete(self, context, event=None): if event["status"] == "success": self.log.info("External tasks %s has executed successfully.", self.external_task_ids) elif event["status"] == "skipped": - raise AirflowSkipException("External job has skipped skipping.") + raise AirflowPokeFailException("External job has skipped skipping.") else: - if self.soft_fail: - raise AirflowSkipException("External job has failed skipping.") - else: - raise AirflowException( - "Error occurred while trying to retrieve task status. Please, check the " - "name of executed task and Dag." - ) + raise AirflowPokeFailException( + "Error occurred while trying to retrieve task status. Please, check the " + "name of executed task and Dag." 
+ ) def _check_for_existence(self, session) -> None: dag_to_wait = DagModel.get_current(self.external_dag_id, session) if not dag_to_wait: - raise AirflowException(f"The external DAG {self.external_dag_id} does not exist.") + raise AirflowPokeFailException(f"The external DAG {self.external_dag_id} does not exist.") if not os.path.exists(correct_maybe_zipped(dag_to_wait.fileloc)): - raise AirflowException(f"The external DAG {self.external_dag_id} was deleted.") + raise AirflowPokeFailException(f"The external DAG {self.external_dag_id} was deleted.") if self.external_task_ids: refreshed_dag_info = DagBag(dag_to_wait.fileloc).get_dag(self.external_dag_id) for external_task_id in self.external_task_ids: if not refreshed_dag_info.has_task(external_task_id): - raise AirflowException( + raise AirflowPokeFailException( f"The external task {external_task_id} in " f"DAG {self.external_dag_id} does not exist." ) @@ -396,7 +379,7 @@ def _check_for_existence(self, session) -> None: if self.external_task_group_id: refreshed_dag_info = DagBag(dag_to_wait.fileloc).get_dag(self.external_dag_id) if not refreshed_dag_info.has_task_group(self.external_task_group_id): - raise AirflowException( + raise AirflowPokeFailException( f"The external task group '{self.external_task_group_id}' in " f"DAG '{self.external_dag_id}' does not exist." 
) diff --git a/providers/src/airflow/providers/standard/sensors/filesystem.py b/providers/src/airflow/providers/standard/sensors/filesystem.py index 650787c485e65..771c700a29c9c 100644 --- a/providers/src/airflow/providers/standard/sensors/filesystem.py +++ b/providers/src/airflow/providers/standard/sensors/filesystem.py @@ -26,7 +26,7 @@ from typing import TYPE_CHECKING, Any from airflow.configuration import conf -from airflow.exceptions import AirflowException +from airflow.exceptions import AirflowPokeFailException from airflow.providers.standard.hooks.filesystem import FSHook from airflow.providers.standard.triggers.file import FileTrigger from airflow.sensors.base import BaseSensorOperator @@ -150,5 +150,5 @@ def execute(self, context: Context) -> None: def execute_complete(self, context: Context, event: bool | None = None) -> None: if not event: - raise AirflowException("%s task failed as %s not found.", self.task_id, self.filepath) + raise AirflowPokeFailException("%s task failed as %s not found.", self.task_id, self.filepath) self.log.info("%s completed successfully as %s found.", self.task_id, self.filepath) diff --git a/providers/src/airflow/providers/standard/sensors/time_delta.py b/providers/src/airflow/providers/standard/sensors/time_delta.py index 6b09a361efadf..13eaecf4d41cf 100644 --- a/providers/src/airflow/providers/standard/sensors/time_delta.py +++ b/providers/src/airflow/providers/standard/sensors/time_delta.py @@ -24,7 +24,6 @@ from packaging.version import Version from airflow.configuration import conf -from airflow.exceptions import AirflowSkipException from airflow.providers.standard.triggers.temporal import DateTimeTrigger, TimeDeltaTrigger from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS from airflow.sensors.base import BaseSensorOperator @@ -89,15 +88,10 @@ def execute(self, context: Context) -> bool | NoReturn: if timezone.utcnow() > target_dttm: # If the target datetime is in the past, return immediately return 
True - try: - if AIRFLOW_V_3_0_PLUS: - trigger = DateTimeTrigger(moment=target_dttm, end_from_trigger=self.end_from_trigger) - else: - trigger = DateTimeTrigger(moment=target_dttm) - except (TypeError, ValueError) as e: - if self.soft_fail: - raise AirflowSkipException("Skipping due to soft_fail is set to True.") from e - raise + if AIRFLOW_V_3_0_PLUS: + trigger = DateTimeTrigger(moment=target_dttm, end_from_trigger=self.end_from_trigger) + else: + trigger = DateTimeTrigger(moment=target_dttm) # todo: remove backcompat when min airflow version greater than 2.11 timeout: int | float | timedelta diff --git a/providers/src/airflow/providers/standard/triggers/external_task.py b/providers/src/airflow/providers/standard/triggers/external_task.py index a54729fa69081..52c1ad2fc8197 100644 --- a/providers/src/airflow/providers/standard/triggers/external_task.py +++ b/providers/src/airflow/providers/standard/triggers/external_task.py @@ -26,6 +26,7 @@ from airflow.models import DagRun from airflow.providers.standard.utils.sensor_helper import _get_count from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS +from airflow.sensors.base import FailPolicy from airflow.triggers.base import BaseTrigger, TriggerEvent from airflow.utils.session import NEW_SESSION, provide_session @@ -49,7 +50,7 @@ class WorkflowTrigger(BaseTrigger): :param skipped_states: States considered as skipped for external tasks. :param allowed_states: States considered as successful for external tasks. :param poke_interval: The interval (in seconds) for poking the external tasks. - :param soft_fail: If True, the trigger will not fail the entire DAG on external task failure. 
+ """ def __init__( @@ -63,7 +64,7 @@ def __init__( skipped_states: typing.Iterable[str] | None = None, allowed_states: typing.Iterable[str] | None = None, poke_interval: float = 2.0, - soft_fail: bool = False, + fail_policy: str = FailPolicy.NONE, **kwargs, ): self.external_dag_id = external_dag_id @@ -74,7 +75,7 @@ def __init__( self.allowed_states = allowed_states self.logical_dates = logical_dates self.poke_interval = poke_interval - self.soft_fail = soft_fail + self.fail_policy = fail_policy self.execution_dates = execution_dates super().__init__(**kwargs) @@ -96,7 +97,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]: "allowed_states": self.allowed_states, **_dates, "poke_interval": self.poke_interval, - "soft_fail": self.soft_fail, + "fail_policy": self.fail_policy, }, ) diff --git a/providers/tests/ftp/sensors/test_ftp.py b/providers/tests/ftp/sensors/test_ftp.py index 0a71fbd594c8f..0a8620079cd8b 100644 --- a/providers/tests/ftp/sensors/test_ftp.py +++ b/providers/tests/ftp/sensors/test_ftp.py @@ -22,8 +22,10 @@ import pytest +from airflow.exceptions import AirflowPokeFailException, AirflowSkipException from airflow.providers.ftp.hooks.ftp import FTPHook from airflow.providers.ftp.sensors.ftp import FTPSensor +from airflow.sensors.base import FailPolicy class TestFTPSensor: @@ -51,10 +53,10 @@ def test_poke_fails_due_error(self, mock_hook): "530: Login authentication failed" ) - with pytest.raises(error_perm) as ctx: + with pytest.raises(AirflowPokeFailException) as ctx: op.execute(None) - assert "530" in str(ctx.value) + assert "530" in str(ctx.value.__cause__) @mock.patch("airflow.providers.ftp.sensors.ftp.FTPHook", spec=FTPHook) def test_poke_fail_on_transient_error(self, mock_hook): @@ -64,20 +66,25 @@ def test_poke_fail_on_transient_error(self, mock_hook): "434: Host unavailable" ) - with pytest.raises(error_perm) as ctx: + with pytest.raises(AirflowPokeFailException) as ctx: op.execute(None) - assert "434" in str(ctx.value) + assert "434" in 
str(ctx.value.__cause__) @mock.patch("airflow.providers.ftp.sensors.ftp.FTPHook", spec=FTPHook) def test_poke_fail_on_transient_error_and_skip(self, mock_hook): - op = FTPSensor(path="foobar.json", ftp_conn_id="bob_ftp", task_id="test_task") + op = FTPSensor( + path="foobar.json", + ftp_conn_id="bob_ftp", + task_id="test_task", + fail_policy=FailPolicy.SKIP_ON_TIMEOUT, + ) mock_hook.return_value.__enter__.return_value.get_mod_time.side_effect = error_perm( "434: Host unavailable" ) - with pytest.raises(error_perm): + with pytest.raises(AirflowSkipException): op.execute(None) @mock.patch("airflow.providers.ftp.sensors.ftp.FTPHook", spec=FTPHook) diff --git a/providers/tests/http/sensors/test_http.py b/providers/tests/http/sensors/test_http.py index 78a11e15bb7c1..b2dbd08e3751b 100644 --- a/providers/tests/http/sensors/test_http.py +++ b/providers/tests/http/sensors/test_http.py @@ -23,7 +23,11 @@ import pytest import requests -from airflow.exceptions import AirflowException, AirflowSensorTimeout, TaskDeferred +from airflow.exceptions import ( + AirflowException, + AirflowSensorTimeout, + TaskDeferred, +) from airflow.models.dag import DAG from airflow.providers.http.operators.http import HttpOperator from airflow.providers.http.sensors.http import HttpSensor diff --git a/providers/tests/sftp/sensors/test_sftp.py b/providers/tests/sftp/sensors/test_sftp.py index 4d1be081af16c..0d209169a7083 100644 --- a/providers/tests/sftp/sensors/test_sftp.py +++ b/providers/tests/sftp/sensors/test_sftp.py @@ -25,7 +25,7 @@ from paramiko.sftp import SFTP_FAILURE, SFTP_NO_SUCH_FILE from pendulum import datetime as pendulum_datetime, timezone -from airflow.exceptions import AirflowException +from airflow.exceptions import AirflowPokeFailException from airflow.providers.sftp.sensors.sftp import SFTPSensor from airflow.sensors.base import PokeReturnValue @@ -58,7 +58,7 @@ def test_sftp_failure(self, sftp_hook_mock): sftp_sensor = SFTPSensor(task_id="unit_test", 
path="/path/to/file/1970-01-01.txt") context = {"ds": "1970-01-01"} - with pytest.raises(AirflowException): + with pytest.raises(AirflowPokeFailException): sftp_sensor.poke(context) def test_hook_not_created_during_init(self): diff --git a/providers/tests/standard/triggers/test_external_task.py b/providers/tests/standard/triggers/test_external_task.py index 8d5ce456be6b2..396c64731cfd7 100644 --- a/providers/tests/standard/triggers/test_external_task.py +++ b/providers/tests/standard/triggers/test_external_task.py @@ -226,7 +226,7 @@ def test_serialization(self): "skipped_states": None, "allowed_states": self.STATES, "poke_interval": 5, - "soft_fail": False, + "fail_policy": "none", } diff --git a/tests/decorators/test_sensor.py b/tests/decorators/test_sensor.py index 9e21e11d8b6a2..7301aa4120e34 100644 --- a/tests/decorators/test_sensor.py +++ b/tests/decorators/test_sensor.py @@ -24,6 +24,11 @@ from airflow.exceptions import AirflowSensorTimeout from airflow.models import XCom from airflow.sensors.base import PokeReturnValue + +from tests_common.test_utils.compat import AIRFLOW_V_2_10_PLUS, ignore_provider_compatibility_error + +with ignore_provider_compatibility_error("2.10.0", __file__): + from airflow.sensors.base import FailPolicy from airflow.utils.state import State pytestmark = pytest.mark.db_test @@ -141,8 +146,9 @@ def dummy_f(): if ti.task_id == "dummy_f": assert ti.state == State.NONE + @pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="FailPolicy present from Airflow 2.10.0") def test_basic_sensor_soft_fail(self, dag_maker): - @task.sensor(timeout=0, soft_fail=True) + @task.sensor(timeout=0, fail_policy=FailPolicy.SKIP_ON_TIMEOUT) def sensor_f(): return PokeReturnValue(is_done=False, xcom_value="xcom_value") @@ -165,8 +171,9 @@ def dummy_f(): if ti.task_id == "dummy_f": assert ti.state == State.NONE + @pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="FailPolicy present from Airflow 2.10.0") def test_basic_sensor_soft_fail_returns_bool(self, 
dag_maker): - @task.sensor(timeout=0, soft_fail=True) + @task.sensor(timeout=0, fail_policy=FailPolicy.SKIP_ON_TIMEOUT) def sensor_f(): return False diff --git a/tests/sensors/test_base.py b/tests/sensors/test_base.py index 348062394fb2f..44ebe59956213 100644 --- a/tests/sensors/test_base.py +++ b/tests/sensors/test_base.py @@ -27,6 +27,7 @@ from airflow.exceptions import ( AirflowException, AirflowFailException, + AirflowPokeFailException, AirflowRescheduleException, AirflowSensorTimeout, AirflowSkipException, @@ -52,7 +53,7 @@ from airflow.providers.celery.executors.celery_kubernetes_executor import CeleryKubernetesExecutor from airflow.providers.cncf.kubernetes.executors.kubernetes_executor import KubernetesExecutor from airflow.providers.cncf.kubernetes.executors.local_kubernetes_executor import LocalKubernetesExecutor -from airflow.sensors.base import BaseSensorOperator, PokeReturnValue, poke_mode_only +from airflow.sensors.base import BaseSensorOperator, FailPolicy, PokeReturnValue, poke_mode_only from airflow.ti_deps.deps.ready_to_reschedule import ReadyToRescheduleDep from airflow.utils import timezone from airflow.utils.session import create_session @@ -97,7 +98,7 @@ def __init__(self, return_value=False, **kwargs): self.return_value = return_value def execute_complete(self, context, event=None): - raise AirflowException("Should be skipped") + raise AirflowException() class DummySensorWithXcomValue(BaseSensorOperator): @@ -180,8 +181,8 @@ def test_fail(self, make_sensor): if ti.task_id == DUMMY_OP: assert ti.state == State.NONE - def test_soft_fail(self, make_sensor): - sensor, dr = make_sensor(False, soft_fail=True) + def test_skip_on_timeout(self, make_sensor): + sensor, dr = make_sensor(False, fail_policy=FailPolicy.SKIP_ON_TIMEOUT) self._run(sensor) tis = dr.get_task_instances() @@ -196,8 +197,8 @@ def test_soft_fail(self, make_sensor): "exception_cls", (ValueError,), ) - def test_soft_fail_with_exception(self, make_sensor, exception_cls): - sensor, dr 
= make_sensor(False, soft_fail=True) + def test_skip_on_timeout_with_exception(self, make_sensor, exception_cls): + sensor, dr = make_sensor(False, fail_policy=FailPolicy.SKIP_ON_TIMEOUT) sensor.poke = Mock(side_effect=[exception_cls(None)]) with pytest.raises(ValueError): self._run(sensor) @@ -215,11 +216,11 @@ def test_soft_fail_with_exception(self, make_sensor, exception_cls): ( AirflowSensorTimeout, AirflowTaskTimeout, - AirflowFailException, + AirflowPokeFailException, ), ) - def test_soft_fail_with_skip_exception(self, make_sensor, exception_cls): - sensor, dr = make_sensor(False, soft_fail=True) + def test_skip_on_timeout_with_skip_exception(self, make_sensor, exception_cls): + sensor, dr = make_sensor(False, fail_policy=FailPolicy.SKIP_ON_TIMEOUT) sensor.poke = Mock(side_effect=[exception_cls(None)]) self._run(sensor) @@ -233,10 +234,10 @@ def test_soft_fail_with_skip_exception(self, make_sensor, exception_cls): @pytest.mark.parametrize( "exception_cls", - (AirflowSensorTimeout, AirflowTaskTimeout, AirflowFailException, Exception), + (AirflowSensorTimeout, AirflowTaskTimeout, AirflowFailException, AirflowPokeFailException, Exception), ) - def test_never_fail_with_skip_exception(self, make_sensor, exception_cls): - sensor, dr = make_sensor(False, never_fail=True) + def test_skip_on_any_error_with_skip_exception(self, make_sensor, exception_cls): + sensor, dr = make_sensor(False, fail_policy=FailPolicy.SKIP_ON_ANY_ERROR) sensor.poke = Mock(side_effect=[exception_cls(None)]) self._run(sensor) @@ -248,9 +249,12 @@ def test_never_fail_with_skip_exception(self, make_sensor, exception_cls): if ti.task_id == DUMMY_OP: assert ti.state == State.NONE - def test_soft_fail_with_retries(self, make_sensor): + def test_skip_on_timeout_with_retries(self, make_sensor): sensor, dr = make_sensor( - return_value=False, soft_fail=True, retries=1, retry_delay=timedelta(milliseconds=1) + return_value=False, + fail_policy=FailPolicy.SKIP_ON_TIMEOUT, + retries=1, + 
retry_delay=timedelta(milliseconds=1), ) # first run times out and task instance is skipped @@ -358,9 +362,13 @@ def _get_tis(): assert sensor_ti.state == State.FAILED assert dummy_ti.state == State.NONE - def test_soft_fail_with_reschedule(self, make_sensor, time_machine, session): + def test_skip_on_timeout_with_reschedule(self, make_sensor, time_machine, session): sensor, dr = make_sensor( - return_value=False, poke_interval=10, timeout=5, soft_fail=True, mode="reschedule" + return_value=False, + poke_interval=10, + timeout=5, + fail_policy=FailPolicy.SKIP_ON_TIMEOUT, + mode="reschedule", ) def _get_tis(): @@ -1193,7 +1201,7 @@ def test_reschedule_and_retry_timeout_and_silent_fail(self, make_sensor, time_ma retries=2, retry_delay=timedelta(seconds=3), mode="reschedule", - silent_fail=True, + fail_policy=FailPolicy.IGNORE_ERROR, ) def _get_sensor_ti(): @@ -1402,14 +1410,15 @@ def test_poke_mode_only_bad_poke(self): class TestAsyncSensor: @pytest.mark.parametrize( - "soft_fail, expected_exception", + "fail_policy, expected_exception", [ - (True, AirflowSkipException), - (False, AirflowException), + (FailPolicy.SKIP_ON_TIMEOUT, AirflowException), + (FailPolicy.SKIP_ON_ANY_ERROR, AirflowSkipException), + (FailPolicy.NONE, AirflowException), ], ) - def test_fail_after_resuming_deferred_sensor(self, soft_fail, expected_exception): - async_sensor = DummyAsyncSensor(task_id="dummy_async_sensor", soft_fail=soft_fail) + def test_fail_after_resuming_deferred_sensor(self, fail_policy, expected_exception): + async_sensor = DummyAsyncSensor(task_id="dummy_async_sensor", fail_policy=fail_policy) ti = TaskInstance(task=async_sensor) ti.next_method = "execute_complete" with pytest.raises(expected_exception): diff --git a/tests/sensors/test_external_task_sensor.py b/tests/sensors/test_external_task_sensor.py index ba3cfd6b4480c..c0cd414d6b4f8 100644 --- a/tests/sensors/test_external_task_sensor.py +++ b/tests/sensors/test_external_task_sensor.py @@ -30,7 +30,13 @@ from airflow 
import settings from airflow.decorators import task as task_deco -from airflow.exceptions import AirflowException, AirflowSensorTimeout, AirflowSkipException, TaskDeferred +from airflow.exceptions import ( + AirflowException, + AirflowPokeFailException, + AirflowSensorTimeout, + AirflowSkipException, + TaskDeferred, +) from airflow.models import DagBag, DagRun, TaskInstance from airflow.models.dag import DAG from airflow.models.serialized_dag import SerializedDagModel @@ -53,10 +59,17 @@ from airflow.utils.types import DagRunType from tests.models import TEST_DAGS_FOLDER +from tests_common.test_utils.compat import ( + AIRFLOW_V_2_10_PLUS, + ignore_provider_compatibility_error, +) from tests_common.test_utils.db import clear_db_runs from tests_common.test_utils.mock_operators import MockOperator from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS +with ignore_provider_compatibility_error("2.10.0", __file__): + from airflow.sensors.base import FailPolicy + if AIRFLOW_V_3_0_PLUS: from airflow.utils.types import DagRunTriggeredByType @@ -251,7 +264,7 @@ def test_external_task_group_not_exists_without_check_existence(self): dag=self.dag, poke_interval=0.1, ) - with pytest.raises(AirflowException, match="Sensor has timed out"): + with pytest.raises(AirflowSensorTimeout, match="Sensor has timed out"): op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) def test_external_task_group_sensor_success(self): @@ -278,13 +291,13 @@ def test_external_task_group_sensor_failed_states(self): dag=self.dag, ) with pytest.raises( - AirflowException, + AirflowPokeFailException, match=f"The external task_group '{TEST_TASK_GROUP_ID}' in DAG '{TEST_DAG_ID}' failed.", ): op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) def test_catch_overlap_allowed_failed_state(self): - with pytest.raises(AirflowException): + with pytest.raises(ValueError): ExternalTaskSensor( task_id="test_external_task_sensor_check", 
external_dag_id=TEST_DAG_ID, @@ -328,12 +341,13 @@ def test_external_task_sensor_failed_states_as_success(self, caplog): error_message = rf"Some of the external tasks \['{TEST_TASK_ID}'\] in DAG {TEST_DAG_ID} failed\." with caplog.at_level(logging.INFO, logger=op.log.name): caplog.clear() - with pytest.raises(AirflowException, match=error_message): + with pytest.raises(AirflowPokeFailException, match=error_message): op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) assert ( f"Poking for tasks ['{TEST_TASK_ID}'] in dag {TEST_DAG_ID} on {DEFAULT_DATE.isoformat()} ... " ) in caplog.messages + @pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="FailPolicy present from Airflow 2.10.0") def test_external_task_sensor_soft_fail_failed_states_as_skipped(self): self.add_time_sensor() op = ExternalTaskSensor( @@ -342,7 +356,7 @@ def test_external_task_sensor_soft_fail_failed_states_as_skipped(self): external_task_id=TEST_TASK_ID, allowed_states=[State.FAILED], failed_states=[State.SUCCESS], - soft_fail=True, + fail_policy=FailPolicy.SKIP_ON_TIMEOUT, dag=self.dag, ) @@ -431,7 +445,7 @@ def test_external_task_sensor_failed_states_as_success_mulitple_task_ids(self, c ) with caplog.at_level(logging.INFO, logger=op.log.name): caplog.clear() - with pytest.raises(AirflowException, match=error_message): + with pytest.raises(AirflowPokeFailException, match=error_message): op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) assert ( f"Poking for tasks ['{TEST_TASK_ID}', '{TEST_TASK_ID_ALTERNATE}'] " @@ -476,6 +490,7 @@ def test_external_dag_sensor_log(self, caplog): op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) assert (f"Poking for DAG 'other_dag' on {DEFAULT_DATE.isoformat()} ... 
") in caplog.messages + @pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="FailPolicy present from Airflow 2.10.0") def test_external_dag_sensor_soft_fail_as_skipped(self): other_dag = DAG("other_dag", default_args=self.args, end_date=DEFAULT_DATE, schedule="@once") triggered_by_kwargs = {"triggered_by": DagRunTriggeredByType.TEST} if AIRFLOW_V_3_0_PLUS else {} @@ -493,7 +508,7 @@ def test_external_dag_sensor_soft_fail_as_skipped(self): external_task_id=None, allowed_states=[State.FAILED], failed_states=[State.SUCCESS], - soft_fail=True, + fail_policy=FailPolicy.SKIP_ON_TIMEOUT, dag=self.dag, ) @@ -600,12 +615,12 @@ def test_external_task_sensor_fn_multiple_logical_dates(self): dag=dag, ) - # We need to test for an AirflowException explicitly since + # We need to test for an AirflowPokeFailException explicitly since # AirflowSensorTimeout is a subclass that will be raised if this does # not execute properly. - with pytest.raises(AirflowException) as ex_ctx: + with pytest.raises(AirflowPokeFailException) as ex_ctx: task_chain_with_failure.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) - assert type(ex_ctx.value) is AirflowException + assert type(ex_ctx.value) is AirflowPokeFailException def test_external_task_sensor_delta(self): self.add_time_sensor() @@ -839,7 +854,7 @@ def test_external_task_group_with_mapped_tasks_failed_states(self): dag=self.dag, ) with pytest.raises( - AirflowException, + AirflowPokeFailException, match=f"The external task_group '{TEST_TASK_GROUP_ID}' in DAG '{TEST_DAG_ID}' failed.", ): op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) @@ -864,6 +879,7 @@ def test_external_task_group_when_there_is_no_TIs(self): ignore_ti_state=True, ) + @pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="FailPolicy present from Airflow 2.10.0") @pytest.mark.parametrize( "kwargs, expected_message", ( @@ -890,14 +906,14 @@ def test_external_task_group_when_there_is_no_TIs(self): ), ) 
@pytest.mark.parametrize( - "soft_fail, expected_exception", + "fail_policy, expected_exception", ( ( - False, - AirflowException, + FailPolicy.NONE, + AirflowPokeFailException, ), ( - True, + FailPolicy.SKIP_ON_TIMEOUT, AirflowSkipException, ), ), @@ -905,7 +921,7 @@ def test_external_task_group_when_there_is_no_TIs(self): @mock.patch("airflow.providers.standard.sensors.external_task.ExternalTaskSensor.get_count") @mock.patch("airflow.providers.standard.sensors.external_task.ExternalTaskSensor._get_dttm_filter") def test_fail_poke( - self, _get_dttm_filter, get_count, soft_fail, expected_exception, kwargs, expected_message + self, _get_dttm_filter, get_count, fail_policy, expected_exception, kwargs, expected_message ): _get_dttm_filter.return_value = [] get_count.return_value = 1 @@ -914,13 +930,16 @@ def test_fail_poke( external_dag_id=TEST_DAG_ID, allowed_states=["success"], dag=self.dag, - soft_fail=soft_fail, + fail_policy=fail_policy, deferrable=False, **kwargs, ) + if fail_policy == FailPolicy.SKIP_ON_TIMEOUT: + expected_message = "Skipping due fail_policy set to SKIP_ON_TIMEOUT." 
with pytest.raises(expected_exception, match=expected_message): op.execute(context={}) + @pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="FailPolicy present from Airflow 2.10.0") @pytest.mark.parametrize( "response_get_current, response_exists, kwargs, expected_message", ( @@ -947,15 +966,15 @@ def test_fail_poke( ), ) @pytest.mark.parametrize( - "soft_fail, expected_exception", + "fail_policy, expected_exception", ( ( - False, - AirflowException, + FailPolicy.NONE, + AirflowPokeFailException, ), ( - True, - AirflowException, + FailPolicy.SKIP_ON_TIMEOUT, + AirflowSkipException, ), ), ) @@ -969,7 +988,7 @@ def test_fail__check_for_existence( exists, get_dag, _get_dttm_filter, - soft_fail, + fail_policy, expected_exception, response_get_current, response_exists, @@ -988,10 +1007,12 @@ def test_fail__check_for_existence( external_dag_id=TEST_DAG_ID, allowed_states=["success"], dag=self.dag, - soft_fail=soft_fail, + fail_policy=fail_policy, check_existence=True, **kwargs, ) + if fail_policy == FailPolicy.SKIP_ON_TIMEOUT: + expected_message = "Skipping due fail_policy set to SKIP_ON_TIMEOUT." 
with pytest.raises(expected_exception, match=expected_message): op.execute(context={}) @@ -1020,7 +1041,7 @@ def test_defer_and_fire_task_state_trigger(self): assert isinstance(exc.value.trigger, WorkflowTrigger), "Trigger is not a WorkflowTrigger" def test_defer_and_fire_failed_state_trigger(self): - """Tests that an AirflowException is raised in case of error event""" + """Tests that an AirflowPokeFailException is raised in case of error event""" sensor = ExternalTaskSensor( task_id=TASK_ID, external_task_id=EXTERNAL_TASK_ID, @@ -1028,13 +1049,13 @@ def test_defer_and_fire_failed_state_trigger(self): deferrable=True, ) - with pytest.raises(AirflowException): + with pytest.raises(AirflowPokeFailException): sensor.execute_complete( context=mock.MagicMock(), event={"status": "error", "message": "test failure message"} ) def test_defer_and_fire_timeout_state_trigger(self): - """Tests that an AirflowException is raised in case of timeout event""" + """Tests that an AirflowPokeFailException is raised in case of timeout event""" sensor = ExternalTaskSensor( task_id=TASK_ID, external_task_id=EXTERNAL_TASK_ID, @@ -1042,7 +1063,7 @@ def test_defer_and_fire_timeout_state_trigger(self): deferrable=True, ) - with pytest.raises(AirflowException): + with pytest.raises(AirflowPokeFailException): sensor.execute_complete( context=mock.MagicMock(), event={"status": "timeout", "message": "Dag was not started within 1 minute, assuming fail."},