Skip to content

Commit

Permalink
Better compat in OL listener tests
Browse files Browse the repository at this point in the history
  • Loading branch information
uranusjr committed Jan 7, 2025
1 parent d08c0ed commit e84d1e3
Show file tree
Hide file tree
Showing 7 changed files with 103 additions and 137 deletions.
7 changes: 5 additions & 2 deletions airflow/models/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -1745,7 +1745,7 @@ def create_dagrun(
"""
logical_date = timezone.coerce_datetime(logical_date)

if data_interval and not isinstance(data_interval, DataInterval):
if not isinstance(data_interval, DataInterval):
data_interval = DataInterval(*map(timezone.coerce_datetime, data_interval))

if isinstance(run_type, DagRunType):
Expand All @@ -1755,6 +1755,9 @@ def create_dagrun(
else:
raise ValueError(f"run_type should be a DagRunType, not {type(run_type)}")

if not isinstance(run_id, str):
raise ValueError(f"`run_id` should be a str, not {type(run_id)}")

# Prevent a manual run from using an ID that looks like a scheduled run.
if run_type == DagRunType.MANUAL:
if (inferred_run_type := DagRunType.from_run_id(run_id)) != DagRunType.MANUAL:
Expand All @@ -1778,7 +1781,7 @@ def create_dagrun(
dag=self,
run_id=run_id,
logical_date=logical_date,
start_date=start_date,
start_date=timezone.coerce_datetime(start_date),
external_trigger=external_trigger,
conf=conf,
state=state,
Expand Down
48 changes: 25 additions & 23 deletions providers/tests/openlineage/plugins/test_listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import uuid
from concurrent.futures import Future
from contextlib import suppress
from typing import Any, Callable
from typing import Callable
from unittest import mock
from unittest.mock import ANY, MagicMock, patch

Expand All @@ -41,6 +41,7 @@

from tests_common.test_utils.compat import PythonOperator
from tests_common.test_utils.config import conf_vars
from tests_common.test_utils.db import clear_db_runs
from tests_common.test_utils.version_compat import AIRFLOW_V_2_10_PLUS, AIRFLOW_V_3_0_PLUS

pytestmark = pytest.mark.db_test
Expand Down Expand Up @@ -88,21 +89,20 @@ def test_listener_does_not_change_task_instance(render_mock, xcom_push_mock):
)
t = TemplateOperator(task_id="template_op", dag=dag, do_xcom_push=True, df=dag.param("df"))
run_id = str(uuid.uuid1())
v3_kwargs = (
{
if AIRFLOW_V_3_0_PLUS:
dagrun_kwargs = {
"dag_version": None,
"logical_date": date,
"triggered_by": types.DagRunTriggeredByType.TEST,
}
if AIRFLOW_V_3_0_PLUS
else {}
)
else:
dagrun_kwargs = {"execution_date": date}
dag.create_dagrun(
run_id=run_id,
logical_date=date,
data_interval=(date, date),
run_type=types.DagRunType.MANUAL,
state=DagRunState.QUEUED,
**v3_kwargs,
**dagrun_kwargs,
)
ti = TaskInstance(t, run_id=run_id)
ti.check_and_change_state_before_execution() # make listener hook on running event
Expand Down Expand Up @@ -155,8 +155,9 @@ def _create_test_dag_and_task(python_callable: Callable, scenario_name: str) ->
:return: TaskInstance: The created TaskInstance object.
This function creates a DAG and a PythonOperator task with the provided python_callable. It generates a unique
run ID and creates a DAG run. This setup is useful for testing different scenarios in Airflow tasks.
This function creates a DAG and a PythonOperator task with the provided
python_callable. It generates a unique run ID and creates a DAG run. This
setup is useful for testing different scenarios in Airflow tasks.
:Example:
Expand All @@ -174,21 +175,20 @@ def sample_callable(**kwargs):
)
t = PythonOperator(task_id=f"test_task_{scenario_name}", dag=dag, python_callable=python_callable)
run_id = str(uuid.uuid1())
v3_kwargs: dict[str, Any] = (
{
if AIRFLOW_V_3_0_PLUS:
dagrun_kwargs: dict = {
"dag_version": None,
"logical_date": date,
"triggered_by": types.DagRunTriggeredByType.TEST,
}
if AIRFLOW_V_3_0_PLUS
else {}
)
else:
dagrun_kwargs = {"execution_date": date}
dagrun = dag.create_dagrun(
run_id=run_id,
logical_date=date,
data_interval=(date, date),
run_type=types.DagRunType.MANUAL,
state=DagRunState.QUEUED,
**v3_kwargs,
**dagrun_kwargs,
)
task_instance = TaskInstance(t, run_id=run_id)
return dagrun, task_instance
Expand Down Expand Up @@ -706,26 +706,28 @@ def simple_callable(**kwargs):
task_id="test_task_selective_enable_2", dag=self.dag, python_callable=simple_callable
)
run_id = str(uuid.uuid1())
v3_kwargs = (
{
if AIRFLOW_V_3_0_PLUS:
dagrun_kwargs = {
"dag_version": None,
"logical_date": date,
"triggered_by": types.DagRunTriggeredByType.TEST,
}
if AIRFLOW_V_3_0_PLUS
else {"execution_date": date}
)
else:
dagrun_kwargs = {"execution_date": date}
self.dagrun = self.dag.create_dagrun(
run_id=run_id,
data_interval=(date, date),
run_type=types.DagRunType.MANUAL,
state=DagRunState.QUEUED,
**v3_kwargs,
**dagrun_kwargs,
) # type: ignore
self.task_instance_1 = TaskInstance(self.task_1, run_id=run_id, map_index=-1)
self.task_instance_2 = TaskInstance(self.task_2, run_id=run_id, map_index=-1)
self.task_instance_1.dag_run = self.task_instance_2.dag_run = self.dagrun

def teardown_method(self):
clear_db_runs()

@pytest.mark.parametrize(
"selective_enable, enable_dag, expected_call_count",
[
Expand Down
3 changes: 3 additions & 0 deletions tests/models/test_taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -1748,6 +1748,8 @@ def test_xcom_pull_different_logical_date(self, create_task_instance):
triggered_by_kwargs = {"triggered_by": DagRunTriggeredByType.TEST} if AIRFLOW_V_3_0_PLUS else {}
dr = ti.task.dag.create_dagrun(
run_id="test2",
run_type=DagRunType.MANUAL,
logical_date=exec_date,
data_interval=(exec_date, exec_date),
state=None,
**triggered_by_kwargs,
Expand Down Expand Up @@ -2022,6 +2024,7 @@ def test_get_num_running_task_instances(self, create_task_instance):
triggered_by_kwargs = {"triggered_by": DagRunTriggeredByType.TEST} if AIRFLOW_V_3_0_PLUS else {}
dr = ti1.task.dag.create_dagrun(
logical_date=logical_date,
run_type=DagRunType.MANUAL,
state=None,
run_id="2",
session=session,
Expand Down
100 changes: 34 additions & 66 deletions tests/utils/test_log_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
from airflow.executors import executor_loader
from airflow.jobs.job import Job
from airflow.jobs.triggerer_job_runner import TriggererJobRunner
from airflow.models.dag import DAG
from airflow.models.dagrun import DagRun
from airflow.models.taskinstance import TaskInstance
from airflow.models.trigger import Trigger
Expand All @@ -56,10 +55,6 @@
from airflow.utils.types import DagRunType

from tests_common.test_utils.config import conf_vars
from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS

if AIRFLOW_V_3_0_PLUS:
from airflow.utils.types import DagRunTriggeredByType

pytestmark = pytest.mark.db_test

Expand Down Expand Up @@ -91,24 +86,17 @@ def test_default_task_logging_setup(self):
handler = handlers[0]
assert handler.name == FILE_TASK_HANDLER

def test_file_task_handler_when_ti_value_is_invalid(self):
def test_file_task_handler_when_ti_value_is_invalid(self, dag_maker):
def task_callable(ti):
ti.log.info("test")

dag = DAG("dag_for_testing_file_task_handler", schedule=None, start_date=DEFAULT_DATE)
triggered_by_kwargs = {"triggered_by": DagRunTriggeredByType.TEST} if AIRFLOW_V_3_0_PLUS else {}
dagrun = dag.create_dagrun(
run_type=DagRunType.MANUAL,
state=State.RUNNING,
logical_date=DEFAULT_DATE,
data_interval=dag.timetable.infer_manual_data_interval(run_after=DEFAULT_DATE),
**triggered_by_kwargs,
)
task = PythonOperator(
task_id="task_for_testing_file_log_handler",
dag=dag,
python_callable=task_callable,
)
with dag_maker("dag_for_testing_file_task_handler", schedule=None):
task = PythonOperator(
task_id="task_for_testing_file_log_handler",
python_callable=task_callable,
)

dagrun = dag_maker.create_dagrun()
ti = TaskInstance(task=task, run_id=dagrun.run_id)

logger = ti.log
Expand Down Expand Up @@ -146,26 +134,22 @@ def task_callable(ti):
# Remove the generated tmp log file.
os.remove(log_filename)

def test_file_task_handler(self):
def test_file_task_handler(self, dag_maker, session):
def task_callable(ti):
ti.log.info("test")

dag = DAG("dag_for_testing_file_task_handler", schedule=None, start_date=DEFAULT_DATE)
triggered_by_kwargs = {"triggered_by": DagRunTriggeredByType.TEST} if AIRFLOW_V_3_0_PLUS else {}
dagrun = dag.create_dagrun(
run_type=DagRunType.MANUAL,
state=State.RUNNING,
logical_date=DEFAULT_DATE,
data_interval=dag.timetable.infer_manual_data_interval(run_after=DEFAULT_DATE),
**triggered_by_kwargs,
)
task = PythonOperator(
task_id="task_for_testing_file_log_handler",
dag=dag,
python_callable=task_callable,
)
ti = TaskInstance(task=task, run_id=dagrun.run_id)
with dag_maker("dag_for_testing_file_task_handler", schedule=None, session=session):
PythonOperator(
task_id="task_for_testing_file_log_handler",
python_callable=task_callable,
)

dagrun = dag_maker.create_dagrun()

(ti,) = dagrun.get_task_instances()
ti.try_number += 1
session.merge(ti)
session.flush()
logger = ti.log
ti.log.disabled = False

Expand Down Expand Up @@ -203,24 +187,16 @@ def task_callable(ti):
# Remove the generated tmp log file.
os.remove(log_filename)

def test_file_task_handler_running(self):
def test_file_task_handler_running(self, dag_maker):
def task_callable(ti):
ti.log.info("test")

dag = DAG("dag_for_testing_file_task_handler", schedule=None, start_date=DEFAULT_DATE)
task = PythonOperator(
task_id="task_for_testing_file_log_handler",
python_callable=task_callable,
dag=dag,
)
triggered_by_kwargs = {"triggered_by": DagRunTriggeredByType.TEST} if AIRFLOW_V_3_0_PLUS else {}
dagrun = dag.create_dagrun(
run_type=DagRunType.MANUAL,
state=State.RUNNING,
logical_date=DEFAULT_DATE,
data_interval=dag.timetable.infer_manual_data_interval(run_after=DEFAULT_DATE),
**triggered_by_kwargs,
)
with dag_maker("dag_for_testing_file_task_handler", schedule=None):
task = PythonOperator(
task_id="task_for_testing_file_log_handler",
python_callable=task_callable,
)
dagrun = dag_maker.create_dagrun()
ti = TaskInstance(task=task, run_id=dagrun.run_id)

ti.try_number = 2
Expand Down Expand Up @@ -256,7 +232,7 @@ def task_callable(ti):
# Remove the generated tmp log file.
os.remove(log_filename)

def test_file_task_handler_rotate_size_limit(self):
def test_file_task_handler_rotate_size_limit(self, dag_maker):
def reset_log_config(update_conf):
import logging.config

Expand All @@ -270,20 +246,12 @@ def task_callable(ti):
max_bytes_size = 60000
update_conf = {"handlers": {"task": {"max_bytes": max_bytes_size, "backup_count": 1}}}
reset_log_config(update_conf)
dag = DAG("dag_for_testing_file_task_handler_rotate_size_limit", start_date=DEFAULT_DATE)
task = PythonOperator(
task_id="task_for_testing_file_log_handler_rotate_size_limit",
python_callable=task_callable,
dag=dag,
)
triggered_by_kwargs = {"triggered_by": DagRunTriggeredByType.TEST} if AIRFLOW_V_3_0_PLUS else {}
dagrun = dag.create_dagrun(
run_type=DagRunType.MANUAL,
state=State.RUNNING,
logical_date=DEFAULT_DATE,
data_interval=dag.timetable.infer_manual_data_interval(run_after=DEFAULT_DATE),
**triggered_by_kwargs,
)
with dag_maker("dag_for_testing_file_task_handler_rotate_size_limit"):
task = PythonOperator(
task_id="task_for_testing_file_log_handler_rotate_size_limit",
python_callable=task_callable,
)
dagrun = dag_maker.create_dagrun()
ti = TaskInstance(task=task, run_id=dagrun.run_id)

ti.try_number = 1
Expand Down
3 changes: 3 additions & 0 deletions tests/utils/test_sqlalchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
)
from airflow.utils.state import State
from airflow.utils.timezone import utcnow
from airflow.utils.types import DagRunType

from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS

Expand Down Expand Up @@ -81,6 +82,7 @@ def test_utc_transformations(self):
triggered_by_kwargs = {"triggered_by": DagRunTriggeredByType.TEST} if AIRFLOW_V_3_0_PLUS else {}
run = dag.create_dagrun(
run_id=iso_date,
run_type=DagRunType.MANUAL,
state=State.NONE,
logical_date=logical_date,
start_date=start_date,
Expand Down Expand Up @@ -115,6 +117,7 @@ def test_process_bind_param_naive(self):
with pytest.raises((ValueError, StatementError)):
dag.create_dagrun(
run_id=start_date.isoformat,
run_type=DagRunType.MANUAL,
state=State.NONE,
logical_date=start_date,
start_date=start_date,
Expand Down
5 changes: 5 additions & 0 deletions tests/utils/test_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,11 @@ def test_dagrun_state_enum_escape():
dag = DAG(dag_id="test_dagrun_state_enum_escape", schedule=timedelta(days=1), start_date=DEFAULT_DATE)
triggered_by_kwargs = {"triggered_by": DagRunTriggeredByType.TEST} if AIRFLOW_V_3_0_PLUS else {}
dag.create_dagrun(
run_id=dag.timetable.generate_run_id(
run_type=DagRunType.SCHEDULED,
logical_date=DEFAULT_DATE,
data_interval=dag.timetable.infer_manual_data_interval(run_after=DEFAULT_DATE),
),
run_type=DagRunType.SCHEDULED,
state=DagRunState.QUEUED,
logical_date=DEFAULT_DATE,
Expand Down
Loading

0 comments on commit e84d1e3

Please sign in to comment.