From e84d1e347b38af9a3df5767144150f9221b550eb Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Tue, 7 Jan 2025 17:21:27 +0800 Subject: [PATCH] Better compat in OL listener tests --- airflow/models/dag.py | 7 +- .../openlineage/plugins/test_listener.py | 48 +++++---- tests/models/test_taskinstance.py | 3 + tests/utils/test_log_handlers.py | 100 ++++++------------ tests/utils/test_sqlalchemy.py | 3 + tests/utils/test_state.py | 5 + tests/utils/test_types.py | 74 +++++-------- 7 files changed, 103 insertions(+), 137 deletions(-) diff --git a/airflow/models/dag.py b/airflow/models/dag.py index f4d48710d7ee06..74f613c0864b95 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -1745,7 +1745,7 @@ def create_dagrun( """ logical_date = timezone.coerce_datetime(logical_date) - if data_interval and not isinstance(data_interval, DataInterval): + if not isinstance(data_interval, DataInterval): data_interval = DataInterval(*map(timezone.coerce_datetime, data_interval)) if isinstance(run_type, DagRunType): @@ -1755,6 +1755,9 @@ def create_dagrun( else: raise ValueError(f"run_type should be a DagRunType, not {type(run_type)}") + if not isinstance(run_id, str): + raise ValueError(f"`run_id` should be a str, not {type(run_id)}") + # Prevent a manual run from using an ID that looks like a scheduled run. if run_type == DagRunType.MANUAL: if (inferred_run_type := DagRunType.from_run_id(run_id)) != DagRunType.MANUAL: @@ -1778,7 +1781,7 @@ def create_dagrun( dag=self, run_id=run_id, logical_date=logical_date, - start_date=start_date, + start_date=timezone.coerce_datetime(start_date), external_trigger=external_trigger, conf=conf, state=state, diff --git a/providers/tests/openlineage/plugins/test_listener.py b/providers/tests/openlineage/plugins/test_listener.py index 7d73d3243e226b..837873f439d24e 100644 --- a/providers/tests/openlineage/plugins/test_listener.py +++ b/providers/tests/openlineage/plugins/test_listener.py @@ -20,7 +20,7 @@ import uuid from concurrent.futures import Future from contextlib import suppress -from typing import Any, Callable +from typing import Callable from unittest import mock from unittest.mock import ANY, MagicMock, patch @@ -41,6 +41,7 @@ from tests_common.test_utils.compat import PythonOperator from tests_common.test_utils.config import conf_vars +from tests_common.test_utils.db import clear_db_runs from tests_common.test_utils.version_compat import AIRFLOW_V_2_10_PLUS, AIRFLOW_V_3_0_PLUS pytestmark = pytest.mark.db_test @@ -88,21 +89,20 @@ def test_listener_does_not_change_task_instance(render_mock, xcom_push_mock): ) t = TemplateOperator(task_id="template_op", dag=dag, do_xcom_push=True, df=dag.param("df")) run_id = str(uuid.uuid1()) - v3_kwargs = ( - { + if AIRFLOW_V_3_0_PLUS: + dagrun_kwargs = { "dag_version": None, + "logical_date": date, "triggered_by": types.DagRunTriggeredByType.TEST, } - if AIRFLOW_V_3_0_PLUS - else {} - ) + else: + dagrun_kwargs = {"execution_date": date} dag.create_dagrun( run_id=run_id, - logical_date=date, data_interval=(date, date), run_type=types.DagRunType.MANUAL, state=DagRunState.QUEUED, - **v3_kwargs, + **dagrun_kwargs, ) ti = TaskInstance(t, run_id=run_id) ti.check_and_change_state_before_execution() # make listener hook on running event @@ -155,8 +155,9 @@ def _create_test_dag_and_task(python_callable: Callable, scenario_name: str) -> :return: TaskInstance: The created TaskInstance object. - This function creates a DAG and a PythonOperator task with the provided python_callable. It generates a unique - run ID and creates a DAG run. This setup is useful for testing different scenarios in Airflow tasks. + This function creates a DAG and a PythonOperator task with the provided + python_callable. It generates a unique run ID and creates a DAG run. This + setup is useful for testing different scenarios in Airflow tasks. :Example: @@ -174,21 +175,20 @@ def sample_callable(**kwargs): ) t = PythonOperator(task_id=f"test_task_{scenario_name}", dag=dag, python_callable=python_callable) run_id = str(uuid.uuid1()) - v3_kwargs: dict[str, Any] = ( - { + if AIRFLOW_V_3_0_PLUS: + dagrun_kwargs: dict = { "dag_version": None, + "logical_date": date, "triggered_by": types.DagRunTriggeredByType.TEST, } - if AIRFLOW_V_3_0_PLUS - else {} - ) + else: + dagrun_kwargs = {"execution_date": date} dagrun = dag.create_dagrun( run_id=run_id, - logical_date=date, data_interval=(date, date), run_type=types.DagRunType.MANUAL, state=DagRunState.QUEUED, - **v3_kwargs, + **dagrun_kwargs, ) task_instance = TaskInstance(t, run_id=run_id) return dagrun, task_instance @@ -706,26 +706,28 @@ def simple_callable(**kwargs): task_id="test_task_selective_enable_2", dag=self.dag, python_callable=simple_callable ) run_id = str(uuid.uuid1()) - v3_kwargs = ( - { + if AIRFLOW_V_3_0_PLUS: + dagrun_kwargs = { "dag_version": None, "logical_date": date, "triggered_by": types.DagRunTriggeredByType.TEST, } - if AIRFLOW_V_3_0_PLUS - else {"execution_date": date} - ) + else: + dagrun_kwargs = {"execution_date": date} self.dagrun = self.dag.create_dagrun( run_id=run_id, data_interval=(date, date), run_type=types.DagRunType.MANUAL, state=DagRunState.QUEUED, - **v3_kwargs, + **dagrun_kwargs, ) # type: ignore self.task_instance_1 = TaskInstance(self.task_1, run_id=run_id, map_index=-1) self.task_instance_2 = TaskInstance(self.task_2, run_id=run_id, map_index=-1) self.task_instance_1.dag_run = self.task_instance_2.dag_run = self.dagrun + def teardown_method(self): + clear_db_runs() + @pytest.mark.parametrize( "selective_enable, enable_dag, expected_call_count", [ diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py index 73f5908b707cff..7b3102d4c4aea4 100644 --- a/tests/models/test_taskinstance.py +++ b/tests/models/test_taskinstance.py @@ -1748,6 +1748,8 @@ def test_xcom_pull_different_logical_date(self, create_task_instance): triggered_by_kwargs = {"triggered_by": DagRunTriggeredByType.TEST} if AIRFLOW_V_3_0_PLUS else {} dr = ti.task.dag.create_dagrun( run_id="test2", + run_type=DagRunType.MANUAL, + logical_date=exec_date, data_interval=(exec_date, exec_date), state=None, **triggered_by_kwargs, @@ -2022,6 +2024,7 @@ def test_get_num_running_task_instances(self, create_task_instance): triggered_by_kwargs = {"triggered_by": DagRunTriggeredByType.TEST} if AIRFLOW_V_3_0_PLUS else {} dr = ti1.task.dag.create_dagrun( logical_date=logical_date, + run_type=DagRunType.MANUAL, state=None, run_id="2", session=session, diff --git a/tests/utils/test_log_handlers.py b/tests/utils/test_log_handlers.py index 19a432bb737671..454af48d667632 100644 --- a/tests/utils/test_log_handlers.py +++ b/tests/utils/test_log_handlers.py @@ -36,7 +36,6 @@ from airflow.executors import executor_loader from airflow.jobs.job import Job from airflow.jobs.triggerer_job_runner import TriggererJobRunner -from airflow.models.dag import DAG from airflow.models.dagrun import DagRun from airflow.models.taskinstance import TaskInstance from airflow.models.trigger import Trigger @@ -56,10 +55,6 @@ from airflow.utils.types import DagRunType from tests_common.test_utils.config import conf_vars -from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS - -if AIRFLOW_V_3_0_PLUS: - from airflow.utils.types import DagRunTriggeredByType pytestmark = pytest.mark.db_test @@ -91,24 +86,17 @@ def test_default_task_logging_setup(self): handler = handlers[0] assert handler.name == FILE_TASK_HANDLER - def test_file_task_handler_when_ti_value_is_invalid(self): + def test_file_task_handler_when_ti_value_is_invalid(self, dag_maker): def task_callable(ti): ti.log.info("test") - dag = DAG("dag_for_testing_file_task_handler", schedule=None, start_date=DEFAULT_DATE) - triggered_by_kwargs = {"triggered_by": DagRunTriggeredByType.TEST} if AIRFLOW_V_3_0_PLUS else {} - dagrun = dag.create_dagrun( - run_type=DagRunType.MANUAL, - state=State.RUNNING, - logical_date=DEFAULT_DATE, - data_interval=dag.timetable.infer_manual_data_interval(run_after=DEFAULT_DATE), - **triggered_by_kwargs, - ) - task = PythonOperator( - task_id="task_for_testing_file_log_handler", - dag=dag, - python_callable=task_callable, - ) + with dag_maker("dag_for_testing_file_task_handler", schedule=None): + task = PythonOperator( + task_id="task_for_testing_file_log_handler", + python_callable=task_callable, + ) + + dagrun = dag_maker.create_dagrun() ti = TaskInstance(task=task, run_id=dagrun.run_id) logger = ti.log @@ -146,26 +134,22 @@ def task_callable(ti): # Remove the generated tmp log file. os.remove(log_filename) - def test_file_task_handler(self): + def test_file_task_handler(self, dag_maker, session): def task_callable(ti): ti.log.info("test") - dag = DAG("dag_for_testing_file_task_handler", schedule=None, start_date=DEFAULT_DATE) - triggered_by_kwargs = {"triggered_by": DagRunTriggeredByType.TEST} if AIRFLOW_V_3_0_PLUS else {} - dagrun = dag.create_dagrun( - run_type=DagRunType.MANUAL, - state=State.RUNNING, - logical_date=DEFAULT_DATE, - data_interval=dag.timetable.infer_manual_data_interval(run_after=DEFAULT_DATE), - **triggered_by_kwargs, - ) - task = PythonOperator( - task_id="task_for_testing_file_log_handler", - dag=dag, - python_callable=task_callable, - ) - ti = TaskInstance(task=task, run_id=dagrun.run_id) + with dag_maker("dag_for_testing_file_task_handler", schedule=None, session=session): + PythonOperator( + task_id="task_for_testing_file_log_handler", + python_callable=task_callable, + ) + + dagrun = dag_maker.create_dagrun() + + (ti,) = dagrun.get_task_instances() ti.try_number += 1 + session.merge(ti) + session.flush() logger = ti.log ti.log.disabled = False @@ -203,24 +187,16 @@ def task_callable(ti): # Remove the generated tmp log file. os.remove(log_filename) - def test_file_task_handler_running(self): + def test_file_task_handler_running(self, dag_maker): def task_callable(ti): ti.log.info("test") - dag = DAG("dag_for_testing_file_task_handler", schedule=None, start_date=DEFAULT_DATE) - task = PythonOperator( - task_id="task_for_testing_file_log_handler", - python_callable=task_callable, - dag=dag, - ) - triggered_by_kwargs = {"triggered_by": DagRunTriggeredByType.TEST} if AIRFLOW_V_3_0_PLUS else {} - dagrun = dag.create_dagrun( - run_type=DagRunType.MANUAL, - state=State.RUNNING, - logical_date=DEFAULT_DATE, - data_interval=dag.timetable.infer_manual_data_interval(run_after=DEFAULT_DATE), - **triggered_by_kwargs, - ) + with dag_maker("dag_for_testing_file_task_handler", schedule=None): + task = PythonOperator( + task_id="task_for_testing_file_log_handler", + python_callable=task_callable, + ) + dagrun = dag_maker.create_dagrun() ti = TaskInstance(task=task, run_id=dagrun.run_id) ti.try_number = 2 @@ -256,7 +232,7 @@ def task_callable(ti): # Remove the generated tmp log file. os.remove(log_filename) - def test_file_task_handler_rotate_size_limit(self): + def test_file_task_handler_rotate_size_limit(self, dag_maker): def reset_log_config(update_conf): import logging.config @@ -270,20 +246,12 @@ def task_callable(ti): max_bytes_size = 60000 update_conf = {"handlers": {"task": {"max_bytes": max_bytes_size, "backup_count": 1}}} reset_log_config(update_conf) - dag = DAG("dag_for_testing_file_task_handler_rotate_size_limit", start_date=DEFAULT_DATE) - task = PythonOperator( - task_id="task_for_testing_file_log_handler_rotate_size_limit", - python_callable=task_callable, - dag=dag, - ) - triggered_by_kwargs = {"triggered_by": DagRunTriggeredByType.TEST} if AIRFLOW_V_3_0_PLUS else {} - dagrun = dag.create_dagrun( - run_type=DagRunType.MANUAL, - state=State.RUNNING, - logical_date=DEFAULT_DATE, - data_interval=dag.timetable.infer_manual_data_interval(run_after=DEFAULT_DATE), - **triggered_by_kwargs, - ) + with dag_maker("dag_for_testing_file_task_handler_rotate_size_limit"): + task = PythonOperator( + task_id="task_for_testing_file_log_handler_rotate_size_limit", + python_callable=task_callable, + ) + dagrun = dag_maker.create_dagrun() ti = TaskInstance(task=task, run_id=dagrun.run_id) ti.try_number = 1 diff --git a/tests/utils/test_sqlalchemy.py b/tests/utils/test_sqlalchemy.py index 0e791de19552dc..5b30595e4a3050 100644 --- a/tests/utils/test_sqlalchemy.py +++ b/tests/utils/test_sqlalchemy.py @@ -42,6 +42,7 @@ ) from airflow.utils.state import State from airflow.utils.timezone import utcnow +from airflow.utils.types import DagRunType from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS @@ -81,6 +82,7 @@ def test_utc_transformations(self): triggered_by_kwargs = {"triggered_by": DagRunTriggeredByType.TEST} if AIRFLOW_V_3_0_PLUS else {} run = dag.create_dagrun( run_id=iso_date, + run_type=DagRunType.MANUAL, state=State.NONE, logical_date=logical_date, start_date=start_date, @@ -115,6 +117,7 @@ def test_process_bind_param_naive(self): with pytest.raises((ValueError, StatementError)): dag.create_dagrun( run_id=start_date.isoformat, + run_type=DagRunType.MANUAL, state=State.NONE, logical_date=start_date, start_date=start_date, diff --git a/tests/utils/test_state.py b/tests/utils/test_state.py index 5ad9f4a7044ad6..e99bc0bd083870 100644 --- a/tests/utils/test_state.py +++ b/tests/utils/test_state.py @@ -44,6 +44,11 @@ def test_dagrun_state_enum_escape(): dag = DAG(dag_id="test_dagrun_state_enum_escape", schedule=timedelta(days=1), start_date=DEFAULT_DATE) triggered_by_kwargs = {"triggered_by": DagRunTriggeredByType.TEST} if AIRFLOW_V_3_0_PLUS else {} dag.create_dagrun( + run_id=dag.timetable.generate_run_id( + run_type=DagRunType.SCHEDULED, + logical_date=DEFAULT_DATE, + data_interval=dag.timetable.infer_manual_data_interval(run_after=DEFAULT_DATE), + ), run_type=DagRunType.SCHEDULED, state=DagRunState.QUEUED, logical_date=DEFAULT_DATE, diff --git a/tests/utils/test_types.py b/tests/utils/test_types.py index c66c58079880a3..4a6831f40354d4 100644 --- a/tests/utils/test_types.py +++ b/tests/utils/test_types.py @@ -20,60 +20,42 @@ import pytest -from airflow.models.dag import DAG from airflow.models.dagrun import DagRun -from airflow.utils.session import create_session from airflow.utils.state import State from airflow.utils.types import DagRunType -from tests.models import DEFAULT_DATE -from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS - -if AIRFLOW_V_3_0_PLUS: - from airflow.utils.types import DagRunTriggeredByType - pytestmark = pytest.mark.db_test -def test_runtype_enum_escape(): +def test_runtype_enum_escape(dag_maker, session): """ Make sure DagRunType.SCHEDULE is converted to string 'scheduled' when referenced in DB query """ - with create_session() as session: - dag = DAG(dag_id="test_enum_dags", schedule=timedelta(days=1), start_date=DEFAULT_DATE) - data_interval = dag.timetable.infer_manual_data_interval(run_after=DEFAULT_DATE) - triggered_by_kwargs = {"triggered_by": DagRunTriggeredByType.TEST} if AIRFLOW_V_3_0_PLUS else {} - dag.create_dagrun( - run_type=DagRunType.SCHEDULED, - state=State.RUNNING, - logical_date=DEFAULT_DATE, - start_date=DEFAULT_DATE, - session=session, - data_interval=data_interval, - **triggered_by_kwargs, - ) - - query = session.query( - DagRun.dag_id, - DagRun.state, - DagRun.run_type, - ).filter( - DagRun.dag_id == dag.dag_id, - # make sure enum value can be used in filter queries - DagRun.run_type == DagRunType.SCHEDULED, - ) - assert str(query.statement.compile(compile_kwargs={"literal_binds": True})) == ( - "SELECT dag_run.dag_id, dag_run.state, dag_run.run_type \n" - "FROM dag_run \n" - "WHERE dag_run.dag_id = 'test_enum_dags' AND dag_run.run_type = 'scheduled'" - ) - - rows = query.all() - assert len(rows) == 1 - assert rows[0].dag_id == dag.dag_id - assert rows[0].state == State.RUNNING - # make sure value in db is stored as `scheduled`, not `DagRunType.SCHEDULED` - assert rows[0].run_type == "scheduled" - - session.rollback() + with dag_maker(dag_id="test_enum_dags", schedule=timedelta(days=1), session=session): + pass + dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED) + + query = session.query( + DagRun.dag_id, + DagRun.state, + DagRun.run_type, + ).filter( + DagRun.dag_id == "test_enum_dags", + # make sure enum value can be used in filter queries + DagRun.run_type == DagRunType.SCHEDULED, + ) + assert str(query.statement.compile(compile_kwargs={"literal_binds": True})) == ( + "SELECT dag_run.dag_id, dag_run.state, dag_run.run_type \n" + "FROM dag_run \n" + "WHERE dag_run.dag_id = 'test_enum_dags' AND dag_run.run_type = 'scheduled'" + ) + + rows = query.all() + assert len(rows) == 1 + assert rows[0].dag_id == "test_enum_dags" + assert rows[0].state == State.RUNNING + # make sure value in db is stored as `scheduled`, not `DagRunType.SCHEDULED` + assert rows[0].run_type == "scheduled" + + session.rollback()