diff --git a/airflow/cli/commands/local_commands/dag_processor_command.py b/airflow/cli/commands/local_commands/dag_processor_command.py index af2b65ff49b9f..653c5f6bf577f 100644 --- a/airflow/cli/commands/local_commands/dag_processor_command.py +++ b/airflow/cli/commands/local_commands/dag_processor_command.py @@ -39,6 +39,7 @@ def _create_dag_processor_job_runner(args: Any) -> DagProcessorJobRunner: job=Job(), processor=DagFileProcessorManager( processor_timeout=processor_timeout_seconds, + dag_directory=args.subdir, max_runs=args.num_runs, ), ) diff --git a/airflow/dag_processing/bundles/manager.py b/airflow/dag_processing/bundles/manager.py index 49cccc02849e8..0a17ab3f8c5f5 100644 --- a/airflow/dag_processing/bundles/manager.py +++ b/airflow/dag_processing/bundles/manager.py @@ -64,23 +64,6 @@ def parse_config(self) -> None: "Bundle config is not a list. Check config value" " for section `dag_bundles` and key `backends`." ) - - # example dags! - if conf.getboolean("core", "LOAD_EXAMPLES"): - from airflow import example_dags - - example_dag_folder = next(iter(example_dags.__path__)) - backends.append( - { - "name": "example_dags", - "classpath": "airflow.dag_processing.bundles.local.LocalDagBundle", - "kwargs": { - "local_folder": example_dag_folder, - "refresh_interval": conf.getint("scheduler", "dag_dir_list_interval"), - }, - } - ) - seen = set() for cfg in backends: name = cfg["name"] diff --git a/airflow/dag_processing/collection.py b/airflow/dag_processing/collection.py index 6e0b627198995..65d6dbea77bf4 100644 --- a/airflow/dag_processing/collection.py +++ b/airflow/dag_processing/collection.py @@ -74,14 +74,11 @@ log = logging.getLogger(__name__) -def _create_orm_dags( - bundle_name: str, dags: Iterable[MaybeSerializedDAG], *, session: Session -) -> Iterator[DagModel]: +def _create_orm_dags(dags: Iterable[MaybeSerializedDAG], *, session: Session) -> Iterator[DagModel]: for dag in dags: orm_dag = DagModel(dag_id=dag.dag_id) if dag.is_paused_upon_creation is not None: orm_dag.is_paused = dag.is_paused_upon_creation - orm_dag.bundle_name = bundle_name log.info("Creating ORM DAG for %s", dag.dag_id) session.add(orm_dag) yield orm_dag @@ -273,8 +270,6 @@ def _update_import_errors(files_parsed: set[str], import_errors: dict[str, str], def update_dag_parsing_results_in_db( - bundle_name: str, - bundle_version: str | None, dags: Collection[MaybeSerializedDAG], import_errors: dict[str, str], warnings: set[DagWarning], @@ -312,7 +307,7 @@ def update_dag_parsing_results_in_db( ) log.debug("Calling the DAG.bulk_sync_to_db method") try: - DAG.bulk_write_to_db(bundle_name, bundle_version, dags, session=session) + DAG.bulk_write_to_db(dags, session=session) # Write Serialized DAGs to DB, capturing errors # Write Serialized DAGs to DB, capturing errors for dag in dags: @@ -351,8 +346,6 @@ class DagModelOperation(NamedTuple): """Collect DAG objects and perform database operations for them.""" dags: dict[str, MaybeSerializedDAG] - bundle_name: str - bundle_version: str | None def find_orm_dags(self, *, session: Session) -> dict[str, DagModel]: """Find existing DagModel objects from DAG objects.""" @@ -372,8 +365,7 @@ def add_dags(self, *, session: Session) -> dict[str, DagModel]: orm_dags.update( (model.dag_id, model) for model in _create_orm_dags( - bundle_name=self.bundle_name, - dags=(dag for dag_id, dag in self.dags.items() if dag_id not in orm_dags), + (dag for dag_id, dag in self.dags.items() if dag_id not in orm_dags), session=session, ) ) @@ -438,8 +430,6 @@ def update_dags( dm.timetable_summary = dag.timetable.summary dm.timetable_description = dag.timetable.description dm.asset_expression = dag.timetable.asset_condition.as_expression() - dm.bundle_name = self.bundle_name - dm.latest_bundle_version = self.bundle_version last_automated_run: DagRun | None = run_info.latest_runs.get(dag.dag_id) if last_automated_run is None: diff --git a/airflow/dag_processing/manager.py b/airflow/dag_processing/manager.py index a66788dc8fbe3..99257b991114e 100644 --- a/airflow/dag_processing/manager.py +++ b/airflow/dag_processing/manager.py @@ -51,7 +51,6 @@ from airflow.dag_processing.processor import DagFileParsingResult, DagFileProcessorProcess from airflow.models.dag import DagModel from airflow.models.dagbag import DagPriorityParsingRequest -from airflow.models.dagbundle import DagBundleModel from airflow.models.dagwarning import DagWarning from airflow.models.db_callback_request import DbCallbackRequest from airflow.models.errors import ParseImportError @@ -69,7 +68,7 @@ set_new_process_group, ) from airflow.utils.retries import retry_db_transaction -from airflow.utils.session import NEW_SESSION, create_session, provide_session +from airflow.utils.session import NEW_SESSION, provide_session from airflow.utils.sqlalchemy import prohibit_commit, with_row_locks if TYPE_CHECKING: @@ -77,8 +76,6 @@ from sqlalchemy.orm import Session - from airflow.dag_processing.bundles.base import BaseDagBundle - class DagParsingStat(NamedTuple): """Information on processing progress.""" @@ -102,13 +99,6 @@ class DagFileStat: log = logging.getLogger("airflow.processor_manager") -class DagFileInfo(NamedTuple): - """Information about a DAG file.""" - - path: str # absolute path of the file - bundle_name: str - - class DagFileProcessorAgent(LoggingMixin, MultiprocessingStartMethodMixin): """ Agent for DAG file processing. @@ -119,6 +109,8 @@ class DagFileProcessorAgent(LoggingMixin, MultiprocessingStartMethodMixin): This class runs in the main `airflow scheduler` process when standalone_dag_processor is not enabled. + :param dag_directory: Directory where DAG definitions are kept. All + files in file_paths should be under this directory :param max_runs: The number of times to parse and schedule each file. -1 for unlimited. :param processor_timeout: How long to wait before timing out a DAG file processor @@ -126,10 +118,12 @@ class DagFileProcessorAgent(LoggingMixin, MultiprocessingStartMethodMixin): def __init__( self, + dag_directory: os.PathLike, max_runs: int, processor_timeout: timedelta, ): super().__init__() + self._dag_directory: os.PathLike = dag_directory self._max_runs = max_runs self._processor_timeout = processor_timeout self._process: multiprocessing.Process | None = None @@ -152,6 +146,7 @@ def start(self) -> None: process = context.Process( target=type(self)._run_processor_manager, args=( + self._dag_directory, self._max_runs, self._processor_timeout, child_signal_conn, @@ -176,6 +171,7 @@ def get_callbacks_pipe(self) -> MultiprocessingConnection: @staticmethod def _run_processor_manager( + dag_directory: os.PathLike, max_runs: int, processor_timeout: timedelta, signal_conn: MultiprocessingConnection, @@ -188,6 +184,7 @@ def _run_processor_manager( setproctitle("airflow scheduler -- DagFileProcessorManager") reload_configuration_for_dag_processing() processor_manager = DagFileProcessorManager( + dag_directory=dag_directory, max_runs=max_runs, processor_timeout=processor_timeout.total_seconds(), signal_conn=signal_conn, @@ -306,12 +303,15 @@ class DagFileProcessorManager: processors finish, more are launched. The files are processed over and over again, but no more often than the specified interval. + :param dag_directory: Directory where DAG definitions are kept. All + files in file_paths should be under this directory :param max_runs: The number of times to parse and schedule each file. -1 for unlimited. :param processor_timeout: How long to wait before timing out a DAG file processor :param signal_conn: connection to communicate signal with processor agent. """ + _dag_directory: os.PathLike[str] = attrs.field(validator=_resolve_path) max_runs: int processor_timeout: float = attrs.field(factory=_config_int_factory("core", "dag_file_processor_timeout")) selector: selectors.BaseSelector = attrs.field(factory=selectors.DefaultSelector) @@ -329,6 +329,7 @@ class DagFileProcessorManager: factory=_config_int_factory("scheduler", "min_file_process_interval") ) stale_dag_threshold: float = attrs.field(factory=_config_int_factory("scheduler", "stale_dag_threshold")) + last_dag_dir_refresh_time: float = attrs.field(default=0, init=False) log: logging.Logger = attrs.field(default=log, init=False) @@ -341,16 +342,11 @@ class DagFileProcessorManager: heartbeat: Callable[[], None] = attrs.field(default=lambda: None) """An overridable heartbeat called once every time around the loop""" - _file_paths: list[DagFileInfo] = attrs.field(factory=list, init=False) - _file_path_queue: deque[DagFileInfo] = attrs.field(factory=deque, init=False) - _file_stats: dict[DagFileInfo, DagFileStat] = attrs.field( - factory=lambda: defaultdict(DagFileStat), init=False - ) - - _dag_bundles: list[BaseDagBundle] = attrs.field(factory=list, init=False) - _bundle_versions: dict[str, str] = attrs.field(factory=dict, init=False) + _file_paths: list[str] = attrs.field(factory=list, init=False) + _file_path_queue: deque[str] = attrs.field(factory=deque, init=False) + _file_stats: dict[str, DagFileStat] = attrs.field(factory=lambda: defaultdict(DagFileStat), init=False) - _processors: dict[DagFileInfo, DagFileProcessorProcess] = attrs.field(factory=dict, init=False) + _processors: dict[str, DagFileProcessorProcess] = attrs.field(factory=dict, init=False) _parsing_start_time: float = attrs.field(init=False) _num_run: int = attrs.field(default=0, init=False) @@ -397,17 +393,14 @@ def run(self): self.log.info("Processing files using up to %s processes at a time ", self._parallelism) self.log.info("Process each file at most once every %s seconds", self._file_process_interval) - # TODO: AIP-66 move to report by bundle self.log.info( - # "Checking for new files in %s every %s seconds", self._dag_directory, self.dag_dir_list_interval - # ) + self.log.info( + "Checking for new files in %s every %s seconds", self._dag_directory, self.dag_dir_list_interval + ) from airflow.dag_processing.bundles.manager import DagBundlesManager DagBundlesManager().sync_bundles_to_db() - self.log.info("Getting all DAG bundles") - self._dag_bundles = list(DagBundlesManager().get_all_dag_bundles()) - return self._run_parsing_loop() def _scan_stale_dags(self): @@ -420,6 +413,7 @@ def _scan_stale_dags(self): } self.deactivate_stale_dags( last_parsed=last_parsed, + dag_directory=self.get_dag_directory(), stale_dag_threshold=self.stale_dag_threshold, ) self._last_deactivate_stale_dags_time = time.monotonic() @@ -427,16 +421,14 @@ def _scan_stale_dags(self): @provide_session def deactivate_stale_dags( self, - last_parsed: dict[DagFileInfo, datetime | None], + last_parsed: dict[str, datetime | None], + dag_directory: str, stale_dag_threshold: int, session: Session = NEW_SESSION, ): """Detect and deactivate DAGs which are no longer present in files.""" to_deactivate = set() - query = select( - DagModel.dag_id, DagModel.bundle_name, DagModel.fileloc, DagModel.last_parsed_time - ).where(DagModel.is_active) - # TODO: AIP-66 by bundle! + query = select(DagModel.dag_id, DagModel.fileloc, DagModel.last_parsed_time).where(DagModel.is_active) dags_parsed = session.execute(query) for dag in dags_parsed: @@ -444,11 +436,9 @@ def deactivate_stale_dags( # last_parsed_time is the processor_timeout. Longer than that indicates that the DAG is # no longer present in the file. We have a stale_dag_threshold configured to prevent a # significant delay in deactivation of stale dags when a large timeout is configured - dag_file_path = DagFileInfo(path=dag.fileloc, bundle_name=dag.bundle_name) if ( - dag_file_path in last_parsed - and (dag.last_parsed_time + timedelta(seconds=stale_dag_threshold)) - < last_parsed[dag_file_path] + dag.fileloc in last_parsed + and (dag.last_parsed_time + timedelta(seconds=stale_dag_threshold)) < last_parsed[dag.fileloc] ): self.log.info("DAG %s is missing and will be deactivated.", dag.dag_id) to_deactivate.add(dag.dag_id) @@ -481,9 +471,9 @@ def _run_parsing_loop(self): self.heartbeat() - self._kill_timed_out_processors() + refreshed_dag_dir = self._refresh_dag_dir() - self._refresh_dag_bundles() + self._kill_timed_out_processors() if not self._file_path_queue: # Generate more file paths to process if we processed all the files already. Note for this to @@ -491,7 +481,7 @@ def _run_parsing_loop(self): # cleared all files added as a result of callbacks self.prepare_file_path_queue() self.emit_metrics() - else: + elif refreshed_dag_dir: # if new files found in dag dir, add them self.add_new_file_path_to_queue() @@ -582,8 +572,6 @@ def _read_from_direct_scheduler_conn(self, conn: MultiprocessingConnection) -> b def _refresh_requested_filelocs(self) -> None: """Refresh filepaths from dag dir as requested by users via APIs.""" - return - # TODO: AIP-66 make bundle aware - fileloc will be relative (eventually), thus not unique in order to know what file to repase # Get values from DB table filelocs = self._get_priority_filelocs() for fileloc in filelocs: @@ -621,9 +609,6 @@ def _fetch_callbacks( def _add_callback_to_queue(self, request: CallbackRequest): self.log.debug("Queuing %s CallbackRequest: %s", type(request).__name__, request) - self.log.warning("Callbacks are not implemented yet!") - # TODO: AIP-66 make callbacks bundle aware - return self._callback_to_execute[request.full_filepath].append(request) if request.full_filepath in self._file_path_queue: # Remove file paths matching request.full_filepath from self._file_path_queue @@ -646,70 +631,25 @@ def _get_priority_filelocs(cls, session: Session = NEW_SESSION): session.delete(request) return filelocs - def _refresh_dag_bundles(self): - """Refresh DAG bundles, if required.""" - now = timezone.utcnow() - - self.log.info("Refreshing DAG bundles") - - for bundle in self._dag_bundles: - # TODO: AIP-66 test to make sure we get a fresh record from the db and it's not cached - with create_session() as session: - bundle_model = session.get(DagBundleModel, bundle.name) - elapsed_time_since_refresh = ( - now - (bundle_model.last_refreshed or timezone.utc_epoch()) - ).total_seconds() - current_version = bundle.get_current_version() - if ( - not elapsed_time_since_refresh > bundle.refresh_interval - ) and bundle_model.latest_version == current_version: - self.log.info("Not time to refresh %s", bundle.name) - continue - - try: - bundle.refresh() - except Exception: - self.log.exception("Error refreshing bundle %s", bundle.name) - continue - - bundle_model.last_refreshed = now - - new_version = bundle.get_current_version() - if bundle.supports_versioning: - # We can short-circuit the rest of the refresh if the version hasn't changed - # and we've already fully "refreshed" this bundle before in this dag processor. - if current_version == new_version and bundle.name in self._bundle_versions: - self.log.debug("Bundle %s version not changed after refresh", bundle.name) - continue - - bundle_model.latest_version = new_version - - self.log.info("Version changed for %s, new version: %s", bundle.name, new_version) - - bundle_file_paths = self._find_files_in_bundle(bundle) - - new_file_paths = [f for f in self._file_paths if f.bundle_name != bundle.name] - new_file_paths.extend( - DagFileInfo(path=path, bundle_name=bundle.name) for path in bundle_file_paths - ) - self.set_file_paths(new_file_paths) - - self.deactivate_deleted_dags(bundle_file_paths) - self.clear_nonexistent_import_errors() - - self._bundle_versions[bundle.name] = bundle.get_current_version() + def _refresh_dag_dir(self) -> bool: + """Refresh file paths from dag dir if we haven't done it for too long.""" + now = time.monotonic() + elapsed_time_since_refresh = now - self.last_dag_dir_refresh_time + if elapsed_time_since_refresh <= self.dag_dir_list_interval: + return False - def _find_files_in_bundle(self, bundle: BaseDagBundle) -> list[str]: - """Refresh file paths from bundle dir.""" # Build up a list of Python files that could contain DAGs - self.log.info("Searching for files in %s at %s", bundle.name, bundle.path) - file_paths = list_py_file_paths(bundle.path) - self.log.info("Found %s files for bundle %s", len(file_paths), bundle.name) - - return file_paths + self.log.info("Searching for files in %s", self._dag_directory) + self._file_paths = list_py_file_paths(self._dag_directory) + self.last_dag_dir_refresh_time = now + self.log.info("There are %s files in %s", len(self._file_paths), self._dag_directory) + self.set_file_paths(self._file_paths) - def deactivate_deleted_dags(self, file_paths: set[str]) -> None: - """Deactivate DAGs that come from files that are no longer present.""" + try: + self.log.debug("Removing old import errors") + self.clear_nonexistent_import_errors() + except Exception: + self.log.exception("Error removing old import errors") def _iter_dag_filelocs(fileloc: str) -> Iterator[str]: """ @@ -728,11 +668,12 @@ def _iter_dag_filelocs(fileloc: str) -> Iterator[str]: except zipfile.BadZipFile: self.log.exception("There was an error accessing ZIP file %s %s", fileloc) - dag_filelocs = {full_loc for path in file_paths for full_loc in _iter_dag_filelocs(path)} + dag_filelocs = {full_loc for path in self._file_paths for full_loc in _iter_dag_filelocs(path)} - # TODO: AIP-66: make bundle aware, as fileloc won't be unique long term. DagModel.deactivate_deleted_dags(dag_filelocs) + return True + def _print_stat(self): """Occasionally print out stats about how fast the files are getting processed.""" if 0 < self.print_stats_interval < time.monotonic() - self.last_stat_print_time: @@ -749,18 +690,15 @@ def clear_nonexistent_import_errors(self, session=NEW_SESSION): :param session: session for ORM operations """ self.log.debug("Removing old import errors") - try: - query = delete(ParseImportError) + query = delete(ParseImportError) - if self._file_paths: - query = query.where( - ParseImportError.filename.notin_([f.path for f in self._file_paths]), - ) + if self._file_paths: + query = query.where( + ParseImportError.filename.notin_(self._file_paths), + ) - session.execute(query.execution_options(synchronize_session="fetch")) - session.commit() - except Exception: - self.log.exception("Error removing old import errors") + session.execute(query.execution_options(synchronize_session="fetch")) + session.commit() def _log_file_processing_stats(self, known_file_paths): """ @@ -798,7 +736,7 @@ def _log_file_processing_stats(self, known_file_paths): proc = self._processors.get(file_path) num_dags = stat.num_dags num_errors = stat.import_errors - file_name = Path(file_path.path).stem + file_name = Path(file_path).stem processor_pid = proc.pid if proc else None processor_start_time = proc.start_time if proc else None runtime = (now - processor_start_time) if processor_start_time else None @@ -855,9 +793,15 @@ def _log_file_processing_stats(self, known_file_paths): self.log.info(log_str) - def set_file_paths(self, new_file_paths: list[DagFileInfo]): + def get_dag_directory(self) -> str | None: + """Return the dag_directory as a string.""" + if self._dag_directory is not None: + return os.fspath(self._dag_directory) + return None + + def set_file_paths(self, new_file_paths): """ - Update this with a new set of DagFilePaths to DAG definition files. + Update this with a new set of paths to DAG definition files. :param new_file_paths: list of paths to DAG definition files :return: None @@ -868,10 +812,9 @@ def set_file_paths(self, new_file_paths: list[DagFileInfo]): self._file_path_queue = deque(x for x in self._file_path_queue if x in new_file_paths) Stats.gauge("dag_processing.file_path_queue_size", len(self._file_path_queue)) - # TODO: AIP-66 make callbacks bundle aware - # callback_paths_to_del = [x for x in self._callback_to_execute if x not in new_file_paths] - # for path_to_del in callback_paths_to_del: - # del self._callback_to_execute[path_to_del] + callback_paths_to_del = [x for x in self._callback_to_execute if x not in new_file_paths] + for path_to_del in callback_paths_to_del: + del self._callback_to_execute[path_to_del] # Stop processors that are working on deleted files filtered_processors = {} @@ -895,35 +838,33 @@ def set_file_paths(self, new_file_paths: list[DagFileInfo]): def _collect_results(self, session: Session = NEW_SESSION): # TODO: Use an explicit session in this fn finished = [] - for dag_file, proc in self._processors.items(): + for path, proc in self._processors.items(): if not proc.is_ready: # This processor hasn't finished yet, or we haven't read all the output from it yet continue - finished.append(dag_file) + finished.append(path) # Collect the DAGS and import errors into the DB, emit metrics etc. - self._file_stats[dag_file] = process_parse_results( + self._file_stats[path] = process_parse_results( run_duration=time.time() - proc.start_time, finish_time=timezone.utcnow(), - run_count=self._file_stats[dag_file].run_count, - bundle_name=dag_file.bundle_name, - bundle_version=self._bundle_versions[dag_file.bundle_name], + run_count=self._file_stats[path].run_count, parsing_result=proc.parsing_result, + path=path, session=session, ) - for dag_file in finished: - self._processors.pop(dag_file) + for path in finished: + self._processors.pop(path) - def _create_process(self, dag_file: DagFileInfo) -> DagFileProcessorProcess: + def _create_process(self, file_path): id = uuid7() - # callback_to_execute_for_file = self._callback_to_execute.pop(file_path, []) - callback_to_execute_for_file: list[CallbackRequest] = [] + callback_to_execute_for_file = self._callback_to_execute.pop(file_path, []) return DagFileProcessorProcess.start( id=id, - path=dag_file.path, + path=file_path, callbacks=callback_to_execute_for_file, selector=self.selector, ) @@ -973,7 +914,7 @@ def prepare_file_path_queue(self): for file_path in self._file_paths: if is_mtime_mode: try: - files_with_mtime[file_path] = os.path.getmtime(file_path.path) + files_with_mtime[file_path] = os.path.getmtime(file_path) except FileNotFoundError: self.log.warning("Skipping processing of missing file: %s", file_path) self._file_stats.pop(file_path, None) @@ -1032,8 +973,7 @@ def prepare_file_path_queue(self): ) self.log.debug( - "Queuing the following files for processing:\n\t%s", - "\n\t".join(f.path for f in files_paths_to_queue), + "Queuing the following files for processing:\n\t%s", "\n\t".join(files_paths_to_queue) ) self._add_paths_to_queue(files_paths_to_queue, False) Stats.incr("dag_processing.file_path_queue_update_count") @@ -1071,7 +1011,7 @@ def _kill_timed_out_processors(self): for proc in processors_to_remove: self._processors.pop(proc) - def _add_paths_to_queue(self, file_paths_to_enqueue: list[DagFileInfo], add_at_front: bool): + def _add_paths_to_queue(self, file_paths_to_enqueue: list[str], add_at_front: bool): """Add stuff to the back or front of the file queue, unless it's already present.""" new_file_paths = list(p for p in file_paths_to_enqueue if p not in self._file_path_queue) if add_at_front: @@ -1151,8 +1091,7 @@ def process_parse_results( run_duration: float, finish_time: datetime, run_count: int, - bundle_name: str, - bundle_version: str | None, + path: str, parsing_result: DagFileParsingResult | None, session: Session, ) -> DagFileStat: @@ -1163,18 +1102,15 @@ def process_parse_results( run_count=run_count + 1, ) - # TODO: AIP-66 emit metrics - # file_name = Path(dag_file.path).stem - # Stats.timing(f"dag_processing.last_duration.{file_name}", stat.last_duration) - # Stats.timing("dag_processing.last_duration", stat.last_duration, tags={"file_name": file_name}) + file_name = Path(path).stem + Stats.timing(f"dag_processing.last_duration.{file_name}", stat.last_duration) + Stats.timing("dag_processing.last_duration", stat.last_duration, tags={"file_name": file_name}) if parsing_result is None: stat.import_errors = 1 else: # record DAGs and import errors to database update_dag_parsing_results_in_db( - bundle_name=bundle_name, - bundle_version=bundle_version, dags=parsing_result.serialized_dags, import_errors=parsing_result.import_errors or {}, warnings=set(parsing_result.warnings or []), diff --git a/airflow/jobs/scheduler_job_runner.py b/airflow/jobs/scheduler_job_runner.py index 4718b824830b7..17643b0212195 100644 --- a/airflow/jobs/scheduler_job_runner.py +++ b/airflow/jobs/scheduler_job_runner.py @@ -30,6 +30,7 @@ from datetime import timedelta from functools import lru_cache, partial from itertools import groupby +from pathlib import Path from typing import TYPE_CHECKING, Any, Callable from deprecated import deprecated @@ -921,6 +922,7 @@ def _execute(self) -> int | None: processor_timeout = timedelta(seconds=processor_timeout_seconds) if not self._standalone_dag_processor and not self.processor_agent: self.processor_agent = DagFileProcessorAgent( + dag_directory=Path(self.subdir), max_runs=self.num_times_parse_dags, processor_timeout=processor_timeout, ) diff --git a/airflow/models/dag.py b/airflow/models/dag.py index 82b4ca70819b9..4abf24af52537 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -770,16 +770,6 @@ def get_is_paused(self, session=NEW_SESSION) -> None: """Return a boolean indicating whether this DAG is paused.""" return session.scalar(select(DagModel.is_paused).where(DagModel.dag_id == self.dag_id)) - @provide_session - def get_bundle_name(self, session=NEW_SESSION) -> None: - """Return the bundle name this DAG is in.""" - return session.scalar(select(DagModel.bundle_name).where(DagModel.dag_id == self.dag_id)) - - @provide_session - def get_latest_bundle_version(self, session=NEW_SESSION) -> None: - """Return the bundle name this DAG is in.""" - return session.scalar(select(DagModel.latest_bundle_version).where(DagModel.dag_id == self.dag_id)) - @methodtools.lru_cache(maxsize=None) @classmethod def get_serialized_fields(cls): @@ -1842,8 +1832,6 @@ def create_dagrun( @provide_session def bulk_write_to_db( cls, - bundle_name: str, - bundle_version: str | None, dags: Collection[MaybeSerializedDAG], session: Session = NEW_SESSION, ): @@ -1859,9 +1847,7 @@ def bulk_write_to_db( from airflow.dag_processing.collection import AssetModelOperation, DagModelOperation log.info("Sync %s DAGs", len(dags)) - dag_op = DagModelOperation( - bundle_name=bundle_name, bundle_version=bundle_version, dags={d.dag_id: d for d in dags} - ) # type: ignore[misc] + dag_op = DagModelOperation({dag.dag_id: dag for dag in dags}) # type: ignore[misc] orm_dags = dag_op.add_dags(session=session) dag_op.update_dags(orm_dags, session=session) @@ -1887,10 +1873,7 @@ def sync_to_db(self, session=NEW_SESSION): :return: None """ - # TODO: AIP-66 should this be in the model? - bundle_name = self.get_bundle_name(session=session) - bundle_version = self.get_latest_bundle_version(session=session) - self.bulk_write_to_db(bundle_name, bundle_version, [self], session=session) + self.bulk_write_to_db([self], session=session) def get_default_view(self): """Allow backward compatible jinja2 templates.""" diff --git a/airflow/models/dagbag.py b/airflow/models/dagbag.py index 26f58b26929ae..7d0d2efc1bf6e 100644 --- a/airflow/models/dagbag.py +++ b/airflow/models/dagbag.py @@ -568,17 +568,11 @@ def collect_dags( # Ensure dag_folder is a str -- it may have been a pathlib.Path dag_folder = correct_maybe_zipped(str(dag_folder)) - - files_to_parse = list_py_file_paths(dag_folder, safe_mode=safe_mode) - - if include_examples: - from airflow import example_dags - - example_dag_folder = next(iter(example_dags.__path__)) - - files_to_parse.extend(list_py_file_paths(example_dag_folder, safe_mode=safe_mode)) - - for filepath in files_to_parse: + for filepath in list_py_file_paths( + dag_folder, + safe_mode=safe_mode, + include_examples=include_examples, + ): try: file_parse_start_dttm = timezone.utcnow() found_dags = self.process_file(filepath, only_if_updated=only_if_updated, safe_mode=safe_mode) @@ -632,13 +626,11 @@ def dagbag_report(self): return report @provide_session - def sync_to_db(self, bundle_name: str, bundle_version: str | None, session: Session = NEW_SESSION): + def sync_to_db(self, session: Session = NEW_SESSION): """Save attributes about list of DAG to the DB.""" from airflow.dag_processing.collection import update_dag_parsing_results_in_db update_dag_parsing_results_in_db( - bundle_name, - bundle_version, self.dags.values(), # type: ignore[arg-type] # We should create a proto for DAG|LazySerializedDAG self.import_errors, self.dag_warnings, diff --git a/airflow/utils/file.py b/airflow/utils/file.py index 7e72b26792e9e..5c3e454e294af 100644 --- a/airflow/utils/file.py +++ b/airflow/utils/file.py @@ -245,6 +245,7 @@ def find_path_from_directory( def list_py_file_paths( directory: str | os.PathLike[str] | None, safe_mode: bool = conf.getboolean("core", "DAG_DISCOVERY_SAFE_MODE", fallback=True), + include_examples: bool | None = None, ) -> list[str]: """ Traverse a directory and look for Python files. @@ -254,8 +255,11 @@ def list_py_file_paths( contains Airflow DAG definitions. If not provided, use the core.DAG_DISCOVERY_SAFE_MODE configuration setting. If not set, default to safe. + :param include_examples: include example DAGs :return: a list of paths to Python files in the specified directory """ + if include_examples is None: + include_examples = conf.getboolean("core", "LOAD_EXAMPLES") file_paths: list[str] = [] if directory is None: file_paths = [] @@ -263,6 +267,11 @@ def list_py_file_paths( file_paths = [str(directory)] elif os.path.isdir(directory): file_paths.extend(find_dag_file_paths(directory, safe_mode)) + if include_examples: + from airflow import example_dags + + example_dag_folder = next(iter(example_dags.__path__)) + file_paths.extend(list_py_file_paths(example_dag_folder, safe_mode, include_examples=False)) return file_paths diff --git a/providers/tests/fab/auth_manager/api_endpoints/test_dag_run_endpoint.py b/providers/tests/fab/auth_manager/api_endpoints/test_dag_run_endpoint.py index e745d3d655bdc..33df2e9bbc407 100644 --- a/providers/tests/fab/auth_manager/api_endpoints/test_dag_run_endpoint.py +++ b/providers/tests/fab/auth_manager/api_endpoints/test_dag_run_endpoint.py @@ -20,7 +20,7 @@ import pytest -from airflow.models.dag import DagModel +from airflow.models.dag import DAG, DagModel from airflow.models.dagrun import DagRun from airflow.models.param import Param from airflow.providers.fab.www.security import permissions @@ -125,26 +125,29 @@ def setup_attrs(self, configured_app) -> None: clear_db_serialized_dags() clear_db_dags() - @pytest.fixture(autouse=True) - def create_dag(self, dag_maker, setup_attrs): - with dag_maker( - "TEST_DAG_ID", schedule=None, params={"validated_number": Param(1, minimum=1, maximum=10)} - ): - pass - - dag_maker.sync_dagbag_to_db() - def teardown_method(self) -> None: clear_db_runs() clear_db_dags() clear_db_serialized_dags() + def _create_dag(self, dag_id): + dag_instance = DagModel(dag_id=dag_id) + dag_instance.is_active = True + with create_session() as session: + session.add(dag_instance) + dag = DAG(dag_id=dag_id, schedule=None, params={"validated_number": Param(1, minimum=1, maximum=10)}) + self.app.dag_bag.bag_dag(dag) + self.app.dag_bag.sync_to_db() + return dag_instance + def _create_test_dag_run(self, state=DagRunState.RUNNING, extra_dag=False, commit=True, idx_start=1): dag_runs = [] dags = [] triggered_by_kwargs = {"triggered_by": DagRunTriggeredByType.TEST} if AIRFLOW_V_3_0_PLUS else {} for i in range(idx_start, idx_start + 2): + if i == 1: + dags.append(DagModel(dag_id="TEST_DAG_ID", is_active=True)) dagrun_model = DagRun( dag_id="TEST_DAG_ID", run_id=f"TEST_DAG_RUN_ID_{i}", @@ -244,6 +247,7 @@ def test_should_return_accessible_with_tilde_as_dag_id_and_dag_level_permissions class TestPostDagRun(TestDagRunEndpoint): def test_dagrun_trigger_with_dag_level_permissions(self): + self._create_dag("TEST_DAG_ID") response = self.client.post( "api/v1/dags/TEST_DAG_ID/dagRuns", json={"conf": {"validated_number": 1}}, @@ -256,6 +260,7 @@ def test_dagrun_trigger_with_dag_level_permissions(self): ["test_dag_view_only", "test_view_dags", "test_granular_permissions"], ) def test_should_raises_403_unauthorized(self, username): + self._create_dag("TEST_DAG_ID") response = self.client.post( "api/v1/dags/TEST_DAG_ID/dagRuns", json={ diff --git a/providers/tests/fab/auth_manager/api_endpoints/test_dag_source_endpoint.py b/providers/tests/fab/auth_manager/api_endpoints/test_dag_source_endpoint.py index 66cd6477c9e9b..8461e5bf0f170 100644 --- a/providers/tests/fab/auth_manager/api_endpoints/test_dag_source_endpoint.py +++ b/providers/tests/fab/auth_manager/api_endpoints/test_dag_source_endpoint.py @@ -18,6 +18,7 @@ import ast import os +from typing import TYPE_CHECKING import pytest @@ -25,12 +26,7 @@ from airflow.providers.fab.www.security import permissions from providers.tests.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user -from tests_common.test_utils.db import ( - clear_db_dag_code, - clear_db_dags, - clear_db_serialized_dags, - parse_and_sync_to_db, -) +from tests_common.test_utils.db import clear_db_dag_code, clear_db_dags, clear_db_serialized_dags from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS pytestmark = [ @@ -38,7 +34,11 @@ pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3.0+"), ] +if TYPE_CHECKING: + from airflow.models.dag import DAG +ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir, os.pardir)) +EXAMPLE_DAG_FILE = os.path.join("airflow", "example_dags", "example_bash_operator.py") EXAMPLE_DAG_ID = "example_bash_operator" TEST_DAG_ID = "latest_only" NOT_READABLE_DAG_ID = "latest_only_with_trigger" @@ -97,9 +97,9 @@ def _get_dag_file_docstring(fileloc: str) -> str | None: return docstring def test_should_respond_403_not_readable(self, url_safe_serializer): - parse_and_sync_to_db(os.devnull, include_examples=True) - dagbag = DagBag(read_dags_from_db=True) - dag = dagbag.get_dag(NOT_READABLE_DAG_ID) + dagbag = DagBag(dag_folder=EXAMPLE_DAG_FILE) + dagbag.sync_to_db() + dag: DAG = dagbag.dags[NOT_READABLE_DAG_ID] response = self.client.get( f"/api/v1/dagSources/{dag.dag_id}", @@ -114,9 +114,9 @@ def test_should_respond_403_not_readable(self, url_safe_serializer): assert read_dag.status_code == 403 def test_should_respond_403_some_dags_not_readable_in_the_file(self, url_safe_serializer): - parse_and_sync_to_db(os.devnull, include_examples=True) - dagbag = DagBag(read_dags_from_db=True) - dag = dagbag.get_dag(TEST_MULTIPLE_DAGS_ID) + dagbag = DagBag(dag_folder=EXAMPLE_DAG_FILE) + dagbag.sync_to_db() + dag: DAG = dagbag.dags[TEST_MULTIPLE_DAGS_ID] response = self.client.get( f"/api/v1/dagSources/{dag.dag_id}", diff --git a/providers/tests/fab/auth_manager/conftest.py b/providers/tests/fab/auth_manager/conftest.py index f26a08d19c3e3..d400a7b86a027 100644 --- a/providers/tests/fab/auth_manager/conftest.py +++ b/providers/tests/fab/auth_manager/conftest.py @@ -16,14 +16,11 @@ # under the License. from __future__ import annotations -import os - import pytest from airflow.www import app from tests_common.test_utils.config import conf_vars -from tests_common.test_utils.db import parse_and_sync_to_db from tests_common.test_utils.decorators import dont_initialize_flask_app_submodules @@ -75,5 +72,5 @@ def set_auth_role_public(request): def dagbag(): from airflow.models import DagBag - parse_and_sync_to_db(os.devnull, include_examples=True) - return DagBag(read_dags_from_db=True) + DagBag(include_examples=True, read_dags_from_db=False).sync_to_db() + return DagBag(include_examples=True, read_dags_from_db=True) diff --git a/providers/tests/openlineage/plugins/test_utils.py b/providers/tests/openlineage/plugins/test_utils.py index 490061db4f594..e1e3e8732690a 100644 --- a/providers/tests/openlineage/plugins/test_utils.py +++ b/providers/tests/openlineage/plugins/test_utils.py @@ -28,7 +28,7 @@ from openlineage.client.utils import RedactMixin from pkg_resources import parse_version -from airflow.models import DagModel +from airflow.models import DAG as AIRFLOW_DAG, DagModel from airflow.providers.common.compat.assets import Asset from airflow.providers.openlineage.plugins.facets import AirflowDebugRunFacet from airflow.providers.openlineage.utils.utils import ( @@ -91,14 +91,12 @@ def test_get_airflow_debug_facet_logging_set_to_debug(mock_debug_mode, mock_get_ @pytest.mark.db_test -def test_get_dagrun_start_end(dag_maker): +def test_get_dagrun_start_end(): start_date = datetime.datetime(2022, 1, 1) end_date = datetime.datetime(2022, 1, 1, hour=2) - with dag_maker("test", start_date=start_date, end_date=end_date, schedule="@once") as dag: - pass - dag_maker.sync_dagbag_to_db() + dag = AIRFLOW_DAG("test", start_date=start_date, end_date=end_date, schedule="@once") + AIRFLOW_DAG.bulk_write_to_db([dag]) dag_model = DagModel.get_dagmodel(dag.dag_id) - run_id = str(uuid.uuid1()) triggered_by_kwargs = {"triggered_by": DagRunTriggeredByType.TEST} if AIRFLOW_V_3_0_PLUS else {} dagrun = dag.create_dagrun( diff --git a/tests/api_connexion/conftest.py b/tests/api_connexion/conftest.py index fa5d4c9250065..39e8da7212ae7 100644 --- a/tests/api_connexion/conftest.py +++ b/tests/api_connexion/conftest.py @@ -16,14 +16,11 @@ # under the License. from __future__ import annotations -import os - import pytest from airflow.www import app from tests_common.test_utils.config import conf_vars -from tests_common.test_utils.db import parse_and_sync_to_db from tests_common.test_utils.decorators import dont_initialize_flask_app_submodules @@ -67,5 +64,5 @@ def session(): def dagbag(): from airflow.models import DagBag - parse_and_sync_to_db(os.devnull, include_examples=True) - return DagBag(read_dags_from_db=True) + DagBag(include_examples=True, read_dags_from_db=False).sync_to_db() + return DagBag(include_examples=True, read_dags_from_db=True) diff --git a/tests/api_connexion/endpoints/test_dag_parsing.py b/tests/api_connexion/endpoints/test_dag_parsing.py index 0052732b7dfb9..1df80a905d92e 100644 --- a/tests/api_connexion/endpoints/test_dag_parsing.py +++ b/tests/api_connexion/endpoints/test_dag_parsing.py @@ -17,6 +17,7 @@ from __future__ import annotations import os +from typing import TYPE_CHECKING import pytest from sqlalchemy import select @@ -25,14 +26,17 @@ from airflow.models.dagbag import DagPriorityParsingRequest from tests_common.test_utils.api_connexion_utils import create_user, delete_user -from tests_common.test_utils.db import clear_db_dag_parsing_requests, parse_and_sync_to_db +from tests_common.test_utils.db import clear_db_dag_parsing_requests pytestmark = pytest.mark.db_test +if TYPE_CHECKING: + from airflow.models.dag import DAG ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir, os.pardir)) EXAMPLE_DAG_FILE = os.path.join("airflow", "example_dags", "example_bash_operator.py") -TEST_DAG_ID = "example_bash_operator" +EXAMPLE_DAG_ID = "example_bash_operator" +TEST_DAG_ID = "latest_only" NOT_READABLE_DAG_ID = "latest_only_with_trigger" TEST_MULTIPLE_DAGS_ID = "asset_produces_1" @@ -68,9 +72,9 @@ def clear_db(): clear_db_dag_parsing_requests() def test_201_and_400_requests(self, url_safe_serializer, session): - parse_and_sync_to_db(EXAMPLE_DAG_FILE) - dagbag = DagBag(read_dags_from_db=True) - test_dag = dagbag.get_dag(TEST_DAG_ID) + dagbag = DagBag(dag_folder=EXAMPLE_DAG_FILE) + dagbag.sync_to_db() + test_dag: DAG = dagbag.dags[TEST_DAG_ID] url = f"/api/v1/parseDagFile/{url_safe_serializer.dumps(test_dag.fileloc)}" response = self.client.put( diff --git a/tests/api_connexion/endpoints/test_dag_run_endpoint.py b/tests/api_connexion/endpoints/test_dag_run_endpoint.py index b6609c162390d..4f322bb8f0a6c 100644 --- a/tests/api_connexion/endpoints/test_dag_run_endpoint.py +++ b/tests/api_connexion/endpoints/test_dag_run_endpoint.py @@ -24,7 +24,6 @@ import time_machine from airflow.api_connexion.exceptions import EXCEPTIONS_LINK_MAP -from airflow.dag_processing.bundles.manager import DagBundlesManager from airflow.models.asset import AssetEvent, AssetModel from airflow.models.dag import DAG, DagModel from airflow.models.dagrun import DagRun @@ -89,9 +88,8 @@ def _create_dag(self, dag_id): with create_session() as session: session.add(dag_instance) dag = DAG(dag_id=dag_id, schedule=None, params={"validated_number": Param(1, minimum=1, maximum=10)}) - DagBundlesManager().sync_bundles_to_db() self.app.dag_bag.bag_dag(dag) - self.app.dag_bag.sync_to_db("dags-folder", None) + self.app.dag_bag.sync_to_db() return dag_instance def _create_test_dag_run(self, state=DagRunState.RUNNING, extra_dag=False, commit=True, idx_start=1): diff --git a/tests/api_connexion/endpoints/test_dag_source_endpoint.py b/tests/api_connexion/endpoints/test_dag_source_endpoint.py index c7e3b12c7e6fd..a907d2704c6e7 100644 --- a/tests/api_connexion/endpoints/test_dag_source_endpoint.py +++ b/tests/api_connexion/endpoints/test_dag_source_endpoint.py @@ -16,8 +16,6 @@ # under the License. from __future__ import annotations -import os - import pytest from sqlalchemy import select @@ -26,12 +24,13 @@ from airflow.models.serialized_dag import SerializedDagModel from tests_common.test_utils.api_connexion_utils import assert_401, create_user, delete_user -from tests_common.test_utils.db import clear_db_dags, parse_and_sync_to_db +from tests_common.test_utils.db import clear_db_dags pytestmark = pytest.mark.db_test -TEST_DAG_ID = "example_bash_operator" +EXAMPLE_DAG_ID = "example_bash_operator" +TEST_DAG_ID = "latest_only" @pytest.fixture(scope="module") @@ -52,9 +51,9 @@ def configured_app(minimal_app_for_api): @pytest.fixture def test_dag(): - parse_and_sync_to_db(os.devnull, include_examples=True) - dagbag = DagBag(read_dags_from_db=True) - return dagbag.get_dag(TEST_DAG_ID) + dagbag = DagBag(include_examples=True) + dagbag.sync_to_db() + return dagbag.dags[TEST_DAG_ID] class TestGetSource: diff --git a/tests/api_connexion/endpoints/test_extra_link_endpoint.py b/tests/api_connexion/endpoints/test_extra_link_endpoint.py index e4aa7895ac662..8cb8dd4e030c1 100644 --- a/tests/api_connexion/endpoints/test_extra_link_endpoint.py +++ b/tests/api_connexion/endpoints/test_extra_link_endpoint.py @@ -22,7 +22,6 @@ import pytest from airflow.api_connexion.exceptions import EXCEPTIONS_LINK_MAP -from airflow.dag_processing.bundles.manager import DagBundlesManager from airflow.models.dag import DAG from airflow.models.dagbag import DagBag from airflow.models.xcom import XCom @@ -74,10 +73,9 @@ def setup_attrs(self, configured_app, session) -> None: self.dag = self._create_dag() - DagBundlesManager().sync_bundles_to_db() self.app.dag_bag = DagBag(os.devnull, include_examples=False) self.app.dag_bag.dags = {self.dag.dag_id: self.dag} - self.app.dag_bag.sync_to_db("dags-folder", None) + self.app.dag_bag.sync_to_db() triggered_by_kwargs = {"triggered_by": DagRunTriggeredByType.TEST} if AIRFLOW_V_3_0_PLUS else {} self.dag.create_dagrun( diff --git a/tests/api_connexion/endpoints/test_mapped_task_instance_endpoint.py b/tests/api_connexion/endpoints/test_mapped_task_instance_endpoint.py index 7a4d802f6cc15..a4a14fa3b421d 100644 --- a/tests/api_connexion/endpoints/test_mapped_task_instance_endpoint.py +++ b/tests/api_connexion/endpoints/test_mapped_task_instance_endpoint.py @@ -24,7 +24,6 @@ import pytest from airflow.api_connexion.exceptions import EXCEPTIONS_LINK_MAP -from airflow.dag_processing.bundles.manager import DagBundlesManager from airflow.models import TaskInstance from airflow.models.baseoperator import BaseOperator from airflow.models.dagbag import DagBag @@ -123,10 +122,9 @@ def create_dag_runs_with_mapped_tasks(self, dag_maker, session, dags=None): setattr(ti, "start_date", DEFAULT_DATETIME_1) session.add(ti) - DagBundlesManager().sync_bundles_to_db() self.app.dag_bag = DagBag(os.devnull, include_examples=False) self.app.dag_bag.dags = {dag_id: dag_maker.dag} - self.app.dag_bag.sync_to_db("dags-folder", None) + self.app.dag_bag.sync_to_db() session.flush() mapped.expand_mapped_task(dr.run_id, session=session) diff --git a/tests/api_connexion/endpoints/test_task_endpoint.py b/tests/api_connexion/endpoints/test_task_endpoint.py index 7e593731d539d..d1662943604ac 100644 --- a/tests/api_connexion/endpoints/test_task_endpoint.py +++ b/tests/api_connexion/endpoints/test_task_endpoint.py @@ -22,7 +22,6 @@ import pytest -from airflow.dag_processing.bundles.manager import DagBundlesManager from airflow.models import DagBag from airflow.models.dag import DAG from airflow.models.expandinput import EXPAND_INPUT_EMPTY @@ -81,15 +80,13 @@ def setup_dag(self, configured_app): task5 = EmptyOperator(task_id=self.unscheduled_task_id2, params={"is_unscheduled": True}) task1 >> task2 task4 >> task5 - - DagBundlesManager().sync_bundles_to_db() dag_bag = DagBag(os.devnull, include_examples=False) dag_bag.dags = { dag.dag_id: dag, mapped_dag.dag_id: mapped_dag, unscheduled_dag.dag_id: unscheduled_dag, } - dag_bag.sync_to_db("dags-folder", None) + dag_bag.sync_to_db() configured_app.dag_bag = dag_bag # type:ignore @staticmethod diff --git a/tests/api_connexion/endpoints/test_task_instance_endpoint.py b/tests/api_connexion/endpoints/test_task_instance_endpoint.py index b5079c47aa17e..a4089c9785a98 100644 --- a/tests/api_connexion/endpoints/test_task_instance_endpoint.py +++ b/tests/api_connexion/endpoints/test_task_instance_endpoint.py @@ -1224,7 +1224,7 @@ def test_should_respond_200(self, main_dag, task_instances, request_dag, payload task_instances=task_instances, update_extras=False, ) - self.app.dag_bag.sync_to_db("dags-folder", None) + self.app.dag_bag.sync_to_db() response = self.client.post( f"/api/v1/dags/{request_dag}/clearTaskInstances", environ_overrides={"REMOTE_USER": "test"}, @@ -1246,7 +1246,7 @@ def test_clear_taskinstance_is_called_with_queued_dr_state(self, mock_clearti, s self.create_task_instances(session) dag_id = "example_python_operator" payload = {"reset_dag_runs": True, "dry_run": False} - self.app.dag_bag.sync_to_db("dags-folder", None) + self.app.dag_bag.sync_to_db() response = self.client.post( f"/api/v1/dags/{dag_id}/clearTaskInstances", environ_overrides={"REMOTE_USER": "test"}, @@ -1266,7 +1266,7 @@ def test_clear_taskinstance_is_called_with_invalid_task_ids(self, session): assert dagrun.state == "running" payload = {"dry_run": False, "reset_dag_runs": True, "task_ids": [""]} - self.app.dag_bag.sync_to_db("dags-folder", None) + self.app.dag_bag.sync_to_db() response = self.client.post( f"/api/v1/dags/{dag_id}/clearTaskInstances", environ_overrides={"REMOTE_USER": "test"}, @@ -1693,7 +1693,7 @@ def test_should_raise_400_for_naive_and_bad_datetime(self, payload, expected, se task_instances=task_instances, update_extras=False, ) - self.app.dag_bag.sync_to_db("dags-folder", None) + self.app.dag_bag.sync_to_db() response = self.client.post( "/api/v1/dags/example_python_operator/clearTaskInstances", environ_overrides={"REMOTE_USER": "test"}, diff --git a/tests/api_fastapi/conftest.py b/tests/api_fastapi/conftest.py index 71f42756dd564..2928a4d829c70 100644 --- a/tests/api_fastapi/conftest.py +++ b/tests/api_fastapi/conftest.py @@ -16,15 +16,11 @@ # under the License. from __future__ import annotations -import os - import pytest from fastapi.testclient import TestClient from airflow.api_fastapi.app import create_app -from tests_common.test_utils.db import parse_and_sync_to_db - @pytest.fixture def test_client(): @@ -46,5 +42,6 @@ def create_test_client(apps="all"): def dagbag(): from airflow.models import DagBag - parse_and_sync_to_db(os.devnull, include_examples=True) - return DagBag(read_dags_from_db=True) + dagbag_instance = DagBag(include_examples=True, read_dags_from_db=False) + dagbag_instance.sync_to_db() + return dagbag_instance diff --git a/tests/api_fastapi/core_api/routes/public/test_dag_parsing.py b/tests/api_fastapi/core_api/routes/public/test_dag_parsing.py index fe78c193d389b..b937f66803f31 100644 --- a/tests/api_fastapi/core_api/routes/public/test_dag_parsing.py +++ b/tests/api_fastapi/core_api/routes/public/test_dag_parsing.py @@ -17,6 +17,7 @@ from __future__ import annotations import os +from typing import TYPE_CHECKING import pytest from sqlalchemy import select @@ -25,15 +26,19 @@ from airflow.models.dagbag import DagPriorityParsingRequest from airflow.utils.session import provide_session -from tests_common.test_utils.db import clear_db_dag_parsing_requests, parse_and_sync_to_db +from tests_common.test_utils.db import clear_db_dag_parsing_requests pytestmark = pytest.mark.db_test +if TYPE_CHECKING: + from airflow.models.dag import DAG + class TestDagParsingEndpoint: ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir, os.pardir)) EXAMPLE_DAG_FILE = os.path.join("airflow", "example_dags", "example_bash_operator.py") - TEST_DAG_ID = "example_bash_operator" + EXAMPLE_DAG_ID = "example_bash_operator" + TEST_DAG_ID = "latest_only" NOT_READABLE_DAG_ID = "latest_only_with_trigger" TEST_MULTIPLE_DAGS_ID = "asset_produces_1" @@ -50,9 +55,9 @@ def teardown_method(self) -> None: self.clear_db() def test_201_and_400_requests(self, url_safe_serializer, session, test_client): - parse_and_sync_to_db(self.EXAMPLE_DAG_FILE) - dagbag = DagBag(read_dags_from_db=True) - test_dag = dagbag.get_dag(self.TEST_DAG_ID) + dagbag = DagBag(dag_folder=self.EXAMPLE_DAG_FILE) + dagbag.sync_to_db() + test_dag: DAG = dagbag.dags[self.TEST_DAG_ID] url = f"/public/parseDagFile/{url_safe_serializer.dumps(test_dag.fileloc)}" response = test_client.put(url, headers={"Accept": "application/json"}) diff --git a/tests/api_fastapi/core_api/routes/public/test_dag_run.py b/tests/api_fastapi/core_api/routes/public/test_dag_run.py index 9b7c9211fddef..75fec91e1a0c7 100644 --- a/tests/api_fastapi/core_api/routes/public/test_dag_run.py +++ b/tests/api_fastapi/core_api/routes/public/test_dag_run.py @@ -141,7 +141,7 @@ def setup(request, dag_maker, session=None): logical_date=LOGICAL_DATE4, ) - dag_maker.sync_dagbag_to_db() + dag_maker.dagbag.sync_to_db() dag_maker.dag_model dag_maker.dag_model.has_task_concurrency_limits = True session.merge(ti1) diff --git a/tests/api_fastapi/core_api/routes/public/test_dag_sources.py b/tests/api_fastapi/core_api/routes/public/test_dag_sources.py index dce7981872085..4e8a5990ecb26 100644 --- a/tests/api_fastapi/core_api/routes/public/test_dag_sources.py +++ b/tests/api_fastapi/core_api/routes/public/test_dag_sources.py @@ -28,7 +28,7 @@ from airflow.models.dagcode import DagCode from airflow.models.serialized_dag import SerializedDagModel -from tests_common.test_utils.db import clear_db_dags, parse_and_sync_to_db +from tests_common.test_utils.db import clear_db_dags pytestmark = pytest.mark.db_test @@ -36,13 +36,14 @@ # Example bash operator located here: airflow/example_dags/example_bash_operator.py EXAMPLE_DAG_FILE = os.path.join("airflow", "example_dags", "example_bash_operator.py") -TEST_DAG_ID = "example_bash_operator" +TEST_DAG_ID = "latest_only" @pytest.fixture def test_dag(): - parse_and_sync_to_db(EXAMPLE_DAG_FILE, include_examples=False) - return DagBag(read_dags_from_db=True).get_dag(TEST_DAG_ID) + dagbag = DagBag(include_examples=True) + dagbag.sync_to_db() + return dagbag.dags[TEST_DAG_ID] class TestGetDAGSource: @@ -130,7 +131,9 @@ def test_should_respond_200_version(self, test_client, accept, session, test_dag "version_number": 2, } - def test_should_respond_406_unsupport_mime_type(self, test_client, test_dag): + def test_should_respond_406_unsupport_mime_type(self, test_client): + dagbag = DagBag(include_examples=True) + dagbag.sync_to_db() response = test_client.get( f"{API_PREFIX}/{TEST_DAG_ID}", headers={"Accept": "text/html"}, diff --git a/tests/api_fastapi/core_api/routes/public/test_dag_tags.py b/tests/api_fastapi/core_api/routes/public/test_dag_tags.py index 7d7720c76f773..784bb480c431b 100644 --- a/tests/api_fastapi/core_api/routes/public/test_dag_tags.py +++ b/tests/api_fastapi/core_api/routes/public/test_dag_tags.py @@ -111,7 +111,6 @@ def setup(self, dag_maker, session=None) -> None: ): EmptyOperator(task_id=TASK_ID) - dag_maker.sync_dagbag_to_db() dag_maker.create_dagrun(state=DagRunState.FAILED) with dag_maker( @@ -128,7 +127,7 @@ def setup(self, dag_maker, session=None) -> None: self._create_deactivated_paused_dag(session) self._create_dag_tags(session) - dag_maker.sync_dagbag_to_db() + dag_maker.dagbag.sync_to_db() dag_maker.dag_model.has_task_concurrency_limits = True session.merge(dag_maker.dag_model) session.commit() diff --git a/tests/api_fastapi/core_api/routes/public/test_dags.py b/tests/api_fastapi/core_api/routes/public/test_dags.py index b79c23eb36b4e..05a49c44dfeea 100644 --- a/tests/api_fastapi/core_api/routes/public/test_dags.py +++ b/tests/api_fastapi/core_api/routes/public/test_dags.py @@ -127,7 +127,7 @@ def setup(self, dag_maker, session=None) -> None: self._create_deactivated_paused_dag(session) self._create_dag_tags(session) - dag_maker.sync_dagbag_to_db() + dag_maker.dagbag.sync_to_db() dag_maker.dag_model.has_task_concurrency_limits = True session.merge(dag_maker.dag_model) session.commit() @@ -409,7 +409,7 @@ def _create_dag_for_deletion( ti = dr.get_task_instances()[0] ti.set_state(TaskInstanceState.RUNNING) - dag_maker.sync_dagbag_to_db() + dag_maker.dagbag.sync_to_db() @pytest.mark.parametrize( "dag_id, dag_display_name, status_code_delete, status_code_details, has_running_dagruns, is_create_dag", diff --git a/tests/api_fastapi/core_api/routes/public/test_extra_links.py b/tests/api_fastapi/core_api/routes/public/test_extra_links.py index 89aef555bbe20..6a2ca4650075e 100644 --- a/tests/api_fastapi/core_api/routes/public/test_extra_links.py +++ b/tests/api_fastapi/core_api/routes/public/test_extra_links.py @@ -21,7 +21,6 @@ import pytest -from airflow.dag_processing.bundles.manager import DagBundlesManager from airflow.models.dag import DAG from airflow.models.dagbag import DagBag from airflow.models.xcom import XCom @@ -68,11 +67,10 @@ def setup(self, test_client, session=None) -> None: self.dag = self._create_dag() - DagBundlesManager().sync_bundles_to_db() dag_bag = DagBag(os.devnull, include_examples=False) dag_bag.dags = {self.dag.dag_id: self.dag} test_client.app.state.dag_bag = dag_bag - dag_bag.sync_to_db("dags-folder", None) + dag_bag.sync_to_db() triggered_by_kwargs = {"triggered_by": DagRunTriggeredByType.TEST} diff --git a/tests/api_fastapi/core_api/routes/public/test_task_instances.py b/tests/api_fastapi/core_api/routes/public/test_task_instances.py index f0c3c888807ea..d62d37944348c 100644 --- a/tests/api_fastapi/core_api/routes/public/test_task_instances.py +++ b/tests/api_fastapi/core_api/routes/public/test_task_instances.py @@ -26,7 +26,6 @@ import pytest from sqlalchemy import select -from airflow.dag_processing.bundles.manager import DagBundlesManager from airflow.jobs.job import Job from airflow.jobs.triggerer_job_runner import TriggererJobRunner from airflow.models import DagRun, TaskInstance @@ -523,10 +522,9 @@ def create_dag_runs_with_mapped_tasks(self, dag_maker, session, dags=None): setattr(ti, "start_date", DEFAULT_DATETIME_1) session.add(ti) - DagBundlesManager().sync_bundles_to_db() dagbag = DagBag(os.devnull, include_examples=False) dagbag.dags = {dag_id: dag_maker.dag} - dagbag.sync_to_db("dags-folder", None) + dagbag.sync_to_db() session.flush() mapped.expand_mapped_task(dr.run_id, session=session) @@ -1859,7 +1857,7 @@ def test_should_respond_200( task_instances=task_instances, update_extras=False, ) - self.dagbag.sync_to_db("dags-folder", None) + self.dagbag.sync_to_db() response = test_client.post( f"/public/dags/{request_dag}/clearTaskInstances", json=payload, @@ -1873,7 +1871,7 @@ def test_clear_taskinstance_is_called_with_queued_dr_state(self, mock_clearti, t self.create_task_instances(session) dag_id = "example_python_operator" payload = {"reset_dag_runs": True, "dry_run": False} - self.dagbag.sync_to_db("dags-folder", None) + self.dagbag.sync_to_db() response = test_client.post( f"/public/dags/{dag_id}/clearTaskInstances", json=payload, @@ -1893,7 +1891,7 @@ def test_clear_taskinstance_is_called_with_invalid_task_ids(self, test_client, s assert dagrun.state == "running" payload = {"dry_run": False, "reset_dag_runs": True, "task_ids": [""]} - self.dagbag.sync_to_db("dags-folder", None) + self.dagbag.sync_to_db() response = test_client.post( f"/public/dags/{dag_id}/clearTaskInstances", json=payload, @@ -1943,7 +1941,7 @@ def test_should_respond_200_with_reset_dag_run(self, test_client, session): update_extras=False, dag_run_state=DagRunState.FAILED, ) - self.dagbag.sync_to_db("dags-folder", None) + self.dagbag.sync_to_db() response = test_client.post( f"/public/dags/{dag_id}/clearTaskInstances", json=payload, @@ -2028,7 +2026,7 @@ def test_should_respond_200_with_dag_run_id(self, test_client, session): update_extras=False, dag_run_state=State.FAILED, ) - self.dagbag.sync_to_db("dags-folder", None) + self.dagbag.sync_to_db() response = test_client.post( f"/public/dags/{dag_id}/clearTaskInstances", json=payload, @@ -2084,7 +2082,7 @@ def test_should_respond_200_with_include_past(self, test_client, session): update_extras=False, dag_run_state=State.FAILED, ) - self.dagbag.sync_to_db("dags-folder", None) + self.dagbag.sync_to_db() response = test_client.post( f"/public/dags/{dag_id}/clearTaskInstances", json=payload, @@ -2166,7 +2164,7 @@ def test_should_respond_200_with_include_future(self, test_client, session): update_extras=False, dag_run_state=State.FAILED, ) - self.dagbag.sync_to_db("dags-folder", None) + self.dagbag.sync_to_db() response = test_client.post( f"/public/dags/{dag_id}/clearTaskInstances", json=payload, @@ -2314,7 +2312,7 @@ def test_should_raise_400_for_naive_and_bad_datetime(self, test_client, session, task_instances=task_instances, update_extras=False, ) - self.dagbag.sync_to_db("dags-folder", None) + self.dagbag.sync_to_db() response = test_client.post( "/public/dags/example_python_operator/clearTaskInstances", json=payload, diff --git a/tests/api_fastapi/core_api/routes/ui/test_assets.py b/tests/api_fastapi/core_api/routes/ui/test_assets.py index 4935b25014ea7..cbe1b292f6d7e 100644 --- a/tests/api_fastapi/core_api/routes/ui/test_assets.py +++ b/tests/api_fastapi/core_api/routes/ui/test_assets.py @@ -46,7 +46,7 @@ def test_next_run_assets(test_client, dag_maker): EmptyOperator(task_id="task1") dag_maker.create_dagrun() - dag_maker.sync_dagbag_to_db() + dag_maker.dagbag.sync_to_db() response = test_client.get("/ui/next_run_assets/upstream") diff --git a/tests/api_fastapi/core_api/routes/ui/test_dashboard.py b/tests/api_fastapi/core_api/routes/ui/test_dashboard.py index 416e016e25c95..93317eaa67088 100644 --- a/tests/api_fastapi/core_api/routes/ui/test_dashboard.py +++ b/tests/api_fastapi/core_api/routes/ui/test_dashboard.py @@ -95,7 +95,7 @@ def make_dag_runs(dag_maker, session, time_machine): for ti in run2.task_instances: ti.state = TaskInstanceState.FAILED - dag_maker.sync_dagbag_to_db() + dag_maker.dagbag.sync_to_db() time_machine.move_to("2023-07-02T00:00:00+00:00", tick=False) diff --git a/tests/api_fastapi/core_api/routes/ui/test_structure.py b/tests/api_fastapi/core_api/routes/ui/test_structure.py index 1f865a140c5be..9732269944dfb 100644 --- a/tests/api_fastapi/core_api/routes/ui/test_structure.py +++ b/tests/api_fastapi/core_api/routes/ui/test_structure.py @@ -59,7 +59,7 @@ def make_dag(dag_maker, session, time_machine): ): TriggerDagRunOperator(task_id="trigger_dag_run_operator", trigger_dag_id=DAG_ID) - dag_maker.sync_dagbag_to_db() + dag_maker.dagbag.sync_to_db() with dag_maker( dag_id=DAG_ID, @@ -78,7 +78,7 @@ def make_dag(dag_maker, session, time_machine): >> EmptyOperator(task_id="task_2") ) - dag_maker.sync_dagbag_to_db() + dag_maker.dagbag.sync_to_db() class TestStructureDataEndpoint: diff --git a/tests/cli/commands/remote_commands/test_asset_command.py b/tests/cli/commands/remote_commands/test_asset_command.py index 067bc9e8e1e2b..69906d1813a29 100644 --- a/tests/cli/commands/remote_commands/test_asset_command.py +++ b/tests/cli/commands/remote_commands/test_asset_command.py @@ -21,15 +21,15 @@ import contextlib import io import json -import os import typing import pytest from airflow.cli import cli_parser from airflow.cli.commands.remote_commands import asset_command +from airflow.models.dagbag import DagBag -from tests_common.test_utils.db import clear_db_dags, clear_db_runs, parse_and_sync_to_db +from tests_common.test_utils.db import clear_db_dags, clear_db_runs if typing.TYPE_CHECKING: from argparse import ArgumentParser @@ -39,7 +39,7 @@ @pytest.fixture(scope="module", autouse=True) def prepare_examples(): - parse_and_sync_to_db(os.devnull, include_examples=True) + DagBag(include_examples=True).sync_to_db() yield clear_db_runs() clear_db_dags() diff --git a/tests/cli/commands/remote_commands/test_backfill_command.py b/tests/cli/commands/remote_commands/test_backfill_command.py index 38b0169a0e7dd..e252d4db1c478 100644 --- a/tests/cli/commands/remote_commands/test_backfill_command.py +++ b/tests/cli/commands/remote_commands/test_backfill_command.py @@ -18,7 +18,6 @@ from __future__ import annotations import argparse -import os from datetime import datetime from unittest import mock @@ -27,10 +26,11 @@ import airflow.cli.commands.remote_commands.backfill_command from airflow.cli import cli_parser +from airflow.models import DagBag from airflow.models.backfill import ReprocessBehavior from airflow.utils import timezone -from tests_common.test_utils.db import clear_db_backfills, clear_db_dags, clear_db_runs, parse_and_sync_to_db +from tests_common.test_utils.db import clear_db_backfills, clear_db_dags, clear_db_runs DEFAULT_DATE = timezone.make_aware(datetime(2015, 1, 1), timezone=timezone.utc) if pendulum.__version__.startswith("3"): @@ -48,7 +48,8 @@ class TestCliBackfill: @classmethod def setup_class(cls): - parse_and_sync_to_db(os.devnull, include_examples=True) + cls.dagbag = DagBag(include_examples=True) + cls.dagbag.sync_to_db() cls.parser = cli_parser.get_parser() @classmethod diff --git a/tests/cli/commands/remote_commands/test_dag_command.py b/tests/cli/commands/remote_commands/test_dag_command.py index 0db0ac83df02f..dab4d0da6caa0 100644 --- a/tests/cli/commands/remote_commands/test_dag_command.py +++ b/tests/cli/commands/remote_commands/test_dag_command.py @@ -50,7 +50,7 @@ from tests.models import TEST_DAGS_FOLDER from tests_common.test_utils.config import conf_vars -from tests_common.test_utils.db import clear_db_dags, clear_db_runs, parse_and_sync_to_db +from tests_common.test_utils.db import clear_db_dags, clear_db_runs DEFAULT_DATE = timezone.make_aware(datetime(2015, 1, 1), timezone=timezone.utc) if pendulum.__version__.startswith("3"): @@ -68,7 +68,8 @@ class TestCliDags: @classmethod def setup_class(cls): - parse_and_sync_to_db(os.devnull, include_examples=True) + cls.dagbag = DagBag(include_examples=True) + cls.dagbag.sync_to_db() cls.parser = cli_parser.get_parser() @classmethod @@ -206,7 +207,8 @@ def test_next_execution(self, tmp_path): with time_machine.travel(DEFAULT_DATE): clear_db_dags() - parse_and_sync_to_db(tmp_path, include_examples=False) + self.dagbag = DagBag(dag_folder=tmp_path, include_examples=False) + self.dagbag.sync_to_db() default_run = DEFAULT_DATE future_run = default_run + timedelta(days=5) @@ -253,7 +255,8 @@ def test_next_execution(self, tmp_path): # Rebuild Test DB for other tests clear_db_dags() - parse_and_sync_to_db(os.devnull, include_examples=True) + TestCliDags.dagbag = DagBag(include_examples=True) + TestCliDags.dagbag.sync_to_db() @conf_vars({("core", "load_examples"): "true"}) def test_cli_report(self): @@ -402,24 +405,24 @@ def test_cli_list_jobs_with_args(self): def test_pause(self): args = self.parser.parse_args(["dags", "pause", "example_bash_operator"]) dag_command.dag_pause(args) - assert DagModel.get_dagmodel("example_bash_operator").is_paused + assert self.dagbag.dags["example_bash_operator"].get_is_paused() dag_command.dag_unpause(args) - assert not DagModel.get_dagmodel("example_bash_operator").is_paused + assert not self.dagbag.dags["example_bash_operator"].get_is_paused() @mock.patch("airflow.cli.commands.remote_commands.dag_command.ask_yesno") def test_pause_regex(self, mock_yesno): args = self.parser.parse_args(["dags", "pause", "^example_.*$", "--treat-dag-id-as-regex"]) dag_command.dag_pause(args) mock_yesno.assert_called_once() - assert DagModel.get_dagmodel("example_bash_decorator").is_paused - assert DagModel.get_dagmodel("example_kubernetes_executor").is_paused - assert DagModel.get_dagmodel("example_xcom_args").is_paused + assert self.dagbag.dags["example_bash_decorator"].get_is_paused() + assert self.dagbag.dags["example_kubernetes_executor"].get_is_paused() + assert self.dagbag.dags["example_xcom_args"].get_is_paused() args = self.parser.parse_args(["dags", "unpause", "^example_.*$", "--treat-dag-id-as-regex"]) dag_command.dag_unpause(args) - assert not DagModel.get_dagmodel("example_bash_decorator").is_paused - assert not DagModel.get_dagmodel("example_kubernetes_executor").is_paused - assert not DagModel.get_dagmodel("example_xcom_args").is_paused + assert not self.dagbag.dags["example_bash_decorator"].get_is_paused() + assert not self.dagbag.dags["example_kubernetes_executor"].get_is_paused() + assert not self.dagbag.dags["example_xcom_args"].get_is_paused() @mock.patch("airflow.cli.commands.remote_commands.dag_command.ask_yesno") def test_pause_regex_operation_cancelled(self, ask_yesno, capsys): diff --git a/tests/cli/commands/remote_commands/test_task_command.py b/tests/cli/commands/remote_commands/test_task_command.py index 9a4c606caa469..843d6817cdcc5 100644 --- a/tests/cli/commands/remote_commands/test_task_command.py +++ b/tests/cli/commands/remote_commands/test_task_command.py @@ -53,7 +53,7 @@ from airflow.utils.types import DagRunType from tests_common.test_utils.config import conf_vars -from tests_common.test_utils.db import clear_db_pools, clear_db_runs, parse_and_sync_to_db +from tests_common.test_utils.db import clear_db_pools, clear_db_runs from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS if AIRFLOW_V_3_0_PLUS: @@ -97,12 +97,12 @@ class TestCliTasks: @pytest.fixture(autouse=True) def setup_class(cls): logging.config.dictConfig(DEFAULT_LOGGING_CONFIG) - parse_and_sync_to_db(os.devnull, include_examples=True) + cls.dagbag = DagBag(include_examples=True) cls.parser = cli_parser.get_parser() clear_db_runs() - cls.dagbag = DagBag(read_dags_from_db=True) cls.dag = cls.dagbag.get_dag(cls.dag_id) + cls.dagbag.sync_to_db() data_interval = cls.dag.timetable.infer_manual_data_interval(run_after=DEFAULT_DATE) triggered_by_kwargs = {"triggered_by": DagRunTriggeredByType.CLI} if AIRFLOW_V_3_0_PLUS else {} cls.dag_run = cls.dag.create_dagrun( @@ -164,7 +164,7 @@ def test_cli_test_different_path(self, session, tmp_path): with conf_vars({("core", "dags_folder"): orig_dags_folder.as_posix()}): dagbag = DagBag(include_examples=False) dag = dagbag.get_dag("test_dags_folder") - dagbag.sync_to_db("dags-folder", None, session=session) + dagbag.sync_to_db(session=session) logical_date = pendulum.now("UTC") data_interval = dag.timetable.infer_manual_data_interval(run_after=logical_date) diff --git a/tests/conftest.py b/tests/conftest.py index 8e41e35e35d06..de13fe99c4bf6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -16,20 +16,15 @@ # under the License. from __future__ import annotations -import json import logging import os import sys -from contextlib import contextmanager from typing import TYPE_CHECKING import pytest from tests_common.test_utils.log_handlers import non_pytest_handlers -if TYPE_CHECKING: - from pathlib import Path - # We should set these before loading _any_ of the rest of airflow so that the # unit test mode config is set as early as possible. assert "airflow" not in sys.modules, "No airflow module can be imported before these lines" @@ -86,37 +81,6 @@ def clear_all_logger_handlers(): remove_all_non_pytest_log_handlers() -@pytest.fixture -def testing_dag_bundle(): - from airflow.models.dagbundle import DagBundleModel - from airflow.utils.session import create_session - - with create_session() as session: - if session.query(DagBundleModel).filter(DagBundleModel.name == "testing").count() == 0: - testing = DagBundleModel(name="testing") - session.add(testing) - - -@pytest.fixture -def configure_testing_dag_bundle(): - """Configure the testing DAG bundle with the provided path, and disable the DAGs folder bundle.""" - from tests_common.test_utils.config import conf_vars - - @contextmanager - def _config_bundle(path_to_parse: Path | str): - bundle_config = [ - { - "name": "testing", - "classpath": "airflow.dag_processing.bundles.local.LocalDagBundle", - "kwargs": {"local_folder": str(path_to_parse), "refresh_interval": 0}, - } - ] - with conf_vars({("dag_bundles", "backends"): json.dumps(bundle_config)}): - yield - - return _config_bundle - - if TYPE_CHECKING: # Static checkers do not know about pytest fixtures' types and return, # In case if them distributed through third party packages. diff --git a/tests/dag_processing/bundles/test_dag_bundle_manager.py b/tests/dag_processing/bundles/test_dag_bundle_manager.py index eee7dee211472..47b192efc198f 100644 --- a/tests/dag_processing/bundles/test_dag_bundle_manager.py +++ b/tests/dag_processing/bundles/test_dag_bundle_manager.py @@ -68,9 +68,7 @@ ) def test_parse_bundle_config(value, expected): """Test that bundle_configs are read from configuration.""" - envs = {"AIRFLOW__CORE__LOAD_EXAMPLES": "False"} - if value: - envs["AIRFLOW__DAG_BUNDLES__BACKENDS"] = value + envs = {"AIRFLOW__DAG_BUNDLES__BACKENDS": value} if value else {} cm = nullcontext() exp_fail = False if isinstance(expected, str): @@ -135,7 +133,6 @@ def clear_db(): @pytest.mark.db_test -@conf_vars({("core", "LOAD_EXAMPLES"): "False"}) def test_sync_bundles_to_db(clear_db): def _get_bundle_names_and_active(): with create_session() as session: @@ -170,14 +167,3 @@ def test_view_url(version): with patch.object(BaseDagBundle, "view_url") as view_url_mock: bundle_manager.view_url("my-test-bundle", version=version) view_url_mock.assert_called_once_with(version=version) - - -def test_example_dags_bundle_added(): - manager = DagBundlesManager() - manager.parse_config() - assert "example_dags" in manager._bundle_config - - with conf_vars({("core", "LOAD_EXAMPLES"): "False"}): - manager = DagBundlesManager() - manager.parse_config() - assert "example_dags" not in manager._bundle_config diff --git a/tests/dag_processing/test_collection.py b/tests/dag_processing/test_collection.py index 8ef07514bbfbe..a248904cbefcc 100644 --- a/tests/dag_processing/test_collection.py +++ b/tests/dag_processing/test_collection.py @@ -185,7 +185,7 @@ def dag_to_lazy_serdag(self, dag: DAG) -> LazyDeserializedDAG: @pytest.mark.usefixtures("clean_db") # sync_perms in fab has bad session commit hygiene def test_sync_perms_syncs_dag_specific_perms_on_update( - self, monkeypatch, spy_agency: SpyAgency, session, time_machine, testing_dag_bundle + self, monkeypatch, spy_agency: SpyAgency, session, time_machine ): """ Test that dagbag.sync_to_db will sync DAG specific permissions when a DAG is @@ -210,7 +210,7 @@ def _sync_to_db(): sync_perms_spy.reset_calls() time_machine.shift(20) - update_dag_parsing_results_in_db("testing", None, [dag], dict(), set(), session) + update_dag_parsing_results_in_db([dag], dict(), set(), session) _sync_to_db() spy_agency.assert_spy_called_with(sync_perms_spy, dag, session=session) @@ -228,9 +228,7 @@ def _sync_to_db(): @patch.object(SerializedDagModel, "write_dag") @patch("airflow.models.dag.DAG.bulk_write_to_db") - def test_sync_to_db_is_retried( - self, mock_bulk_write_to_db, mock_s10n_write_dag, testing_dag_bundle, session - ): + def test_sync_to_db_is_retried(self, mock_bulk_write_to_db, mock_s10n_write_dag, session): """Test that important DB operations in db sync are retried on OperationalError""" serialized_dags_count = session.query(func.count(SerializedDagModel.dag_id)).scalar() @@ -246,16 +244,14 @@ def test_sync_to_db_is_retried( mock_bulk_write_to_db.side_effect = side_effect mock_session = mock.MagicMock() - update_dag_parsing_results_in_db( - "testing", None, dags=dags, import_errors={}, warnings=set(), session=mock_session - ) + update_dag_parsing_results_in_db(dags=dags, import_errors={}, warnings=set(), session=mock_session) # Test that 3 attempts were made to run 'DAG.bulk_write_to_db' successfully mock_bulk_write_to_db.assert_has_calls( [ - mock.call("testing", None, mock.ANY, session=mock.ANY), - mock.call("testing", None, mock.ANY, session=mock.ANY), - mock.call("testing", None, mock.ANY, session=mock.ANY), + mock.call(mock.ANY, session=mock.ANY), + mock.call(mock.ANY, session=mock.ANY), + mock.call(mock.ANY, session=mock.ANY), ] ) # Assert that rollback is called twice (i.e. whenever OperationalError occurs) @@ -272,7 +268,7 @@ def test_sync_to_db_is_retried( serialized_dags_count = session.query(func.count(SerializedDagModel.dag_id)).scalar() assert serialized_dags_count == 0 - def test_serialized_dags_are_written_to_db_on_sync(self, testing_dag_bundle, session): + def test_serialized_dags_are_written_to_db_on_sync(self, session): """ Test that when dagbag.sync_to_db is called the DAGs are Serialized and written to DB even when dagbag.read_dags_from_db is False @@ -282,14 +278,14 @@ def test_serialized_dags_are_written_to_db_on_sync(self, testing_dag_bundle, ses dag = DAG(dag_id="test") - update_dag_parsing_results_in_db("testing", None, [dag], dict(), set(), session) + update_dag_parsing_results_in_db([dag], dict(), set(), session) new_serialized_dags_count = session.query(func.count(SerializedDagModel.dag_id)).scalar() assert new_serialized_dags_count == 1 @patch.object(SerializedDagModel, "write_dag") def test_serialized_dag_errors_are_import_errors( - self, mock_serialize, caplog, session, dag_import_error_listener, testing_dag_bundle + self, mock_serialize, caplog, session, dag_import_error_listener ): """ Test that errors serializing a DAG are recorded as import_errors in the DB @@ -302,7 +298,7 @@ def test_serialized_dag_errors_are_import_errors( dag.fileloc = "abc.py" import_errors = {} - update_dag_parsing_results_in_db("testing", None, [dag], import_errors, set(), session) + update_dag_parsing_results_in_db([dag], import_errors, set(), session) assert "SerializationError" in caplog.text # Should have been edited in place @@ -324,7 +320,7 @@ def test_serialized_dag_errors_are_import_errors( assert len(dag_import_error_listener.existing) == 0 assert dag_import_error_listener.new["abc.py"] == import_error.stacktrace - def test_new_import_error_replaces_old(self, session, dag_import_error_listener, testing_dag_bundle): + def test_new_import_error_replaces_old(self, session, dag_import_error_listener): """ Test that existing import error is updated and new record not created for a dag with the same filename @@ -340,8 +336,6 @@ def test_new_import_error_replaces_old(self, session, dag_import_error_listener, prev_error_id = prev_error.id update_dag_parsing_results_in_db( - bundle_name="testing", - bundle_version=None, dags=[], import_errors={"abc.py": "New error"}, warnings=set(), @@ -359,7 +353,7 @@ def test_new_import_error_replaces_old(self, session, dag_import_error_listener, assert len(dag_import_error_listener.existing) == 1 assert dag_import_error_listener.existing["abc.py"] == prev_error.stacktrace - def test_remove_error_clears_import_error(self, testing_dag_bundle, session): + def test_remove_error_clears_import_error(self, session): # Pre-condition: there is an import error for the dag file filename = "abc.py" prev_error = ParseImportError( @@ -387,7 +381,7 @@ def test_remove_error_clears_import_error(self, testing_dag_bundle, session): dag.fileloc = filename import_errors = {} - update_dag_parsing_results_in_db("testing", None, [dag], import_errors, set(), session) + update_dag_parsing_results_in_db([dag], import_errors, set(), session) dag_model: DagModel = session.get(DagModel, (dag.dag_id,)) assert dag_model.has_import_errors is False @@ -479,7 +473,7 @@ def _sync_perms(): ], ) @pytest.mark.usefixtures("clean_db") - def test_dagmodel_properties(self, attrs, expected, session, time_machine, testing_dag_bundle): + def test_dagmodel_properties(self, attrs, expected, session, time_machine): """Test that properties on the dag model are correctly set when dealing with a LazySerializedDag""" dt = tz.datetime(2020, 1, 5, 0, 0, 0) time_machine.move_to(dt, tick=False) @@ -499,7 +493,7 @@ def test_dagmodel_properties(self, attrs, expected, session, time_machine, testi session.add(dr1) session.commit() - update_dag_parsing_results_in_db("testing", None, [self.dag_to_lazy_serdag(dag)], {}, set(), session) + update_dag_parsing_results_in_db([self.dag_to_lazy_serdag(dag)], {}, set(), session) orm_dag = session.get(DagModel, ("dag",)) @@ -511,21 +505,14 @@ def test_dagmodel_properties(self, attrs, expected, session, time_machine, testi assert orm_dag.last_parsed_time == dt - def test_existing_dag_is_paused_upon_creation(self, testing_dag_bundle, session): + def test_existing_dag_is_paused_upon_creation(self, session): dag = DAG("dag_paused", schedule=None) - update_dag_parsing_results_in_db("testing", None, [self.dag_to_lazy_serdag(dag)], {}, set(), session) + update_dag_parsing_results_in_db([self.dag_to_lazy_serdag(dag)], {}, set(), session) orm_dag = session.get(DagModel, ("dag_paused",)) assert orm_dag.is_paused is False dag = DAG("dag_paused", schedule=None, is_paused_upon_creation=True) - update_dag_parsing_results_in_db("testing", None, [self.dag_to_lazy_serdag(dag)], {}, set(), session) + update_dag_parsing_results_in_db([self.dag_to_lazy_serdag(dag)], {}, set(), session) # Since the dag existed before, it should not follow the pause flag upon creation orm_dag = session.get(DagModel, ("dag_paused",)) assert orm_dag.is_paused is False - - def test_bundle_name_and_version_are_stored(self, testing_dag_bundle, session): - dag = DAG("mydag", schedule=None) - update_dag_parsing_results_in_db("testing", "1.0", [self.dag_to_lazy_serdag(dag)], {}, set(), session) - orm_dag = session.get(DagModel, "mydag") - assert orm_dag.bundle_name == "testing" - assert orm_dag.latest_bundle_version == "1.0" diff --git a/tests/dag_processing/test_manager.py b/tests/dag_processing/test_manager.py index 7608cbbd32766..2cc8a43e05450 100644 --- a/tests/dag_processing/test_manager.py +++ b/tests/dag_processing/test_manager.py @@ -19,7 +19,6 @@ import io import itertools -import json import logging import multiprocessing import os @@ -44,18 +43,15 @@ from airflow.callbacks.callback_requests import DagCallbackRequest from airflow.config_templates.airflow_local_settings import DEFAULT_LOGGING_CONFIG -from airflow.dag_processing.bundles.manager import DagBundlesManager from airflow.dag_processing.manager import ( - DagFileInfo, DagFileProcessorAgent, DagFileProcessorManager, DagFileStat, ) from airflow.dag_processing.processor import DagFileProcessorProcess -from airflow.models import DAG, DagBag, DagModel, DbCallbackRequest +from airflow.models import DagBag, DagModel, DbCallbackRequest from airflow.models.asset import TaskOutletAssetReference from airflow.models.dag_version import DagVersion -from airflow.models.dagbundle import DagBundleModel from airflow.models.dagcode import DagCode from airflow.models.serialized_dag import SerializedDagModel from airflow.utils import timezone @@ -82,10 +78,6 @@ DEFAULT_DATE = timezone.datetime(2016, 1, 1) -def _get_dag_file_paths(files: list[str]) -> list[DagFileInfo]: - return [DagFileInfo(bundle_name="testing", path=f) for f in files] - - class TestDagFileProcessorManager: @pytest.fixture(autouse=True) def _disable_examples(self): @@ -109,6 +101,9 @@ def teardown_class(self): clear_db_callbacks() clear_db_import_errors() + def run_processor_manager_one_loop(self, manager: DagFileProcessorManager) -> None: + manager._run_parsing_loop() + def mock_processor(self) -> DagFileProcessorProcess: proc = MagicMock() proc.create_time.return_value = time.time() @@ -129,51 +124,49 @@ def clear_parse_import_errors(self): @pytest.mark.usefixtures("clear_parse_import_errors") @conf_vars({("core", "load_examples"): "False"}) - def test_remove_file_clears_import_error(self, tmp_path, configure_testing_dag_bundle): + def test_remove_file_clears_import_error(self, tmp_path): path_to_parse = tmp_path / "temp_dag.py" # Generate original import error path_to_parse.write_text("an invalid airflow DAG") - with configure_testing_dag_bundle(path_to_parse): - manager = DagFileProcessorManager( - max_runs=1, - processor_timeout=365 * 86_400, - ) + manager = DagFileProcessorManager( + dag_directory=path_to_parse.parent, + max_runs=1, + processor_timeout=365 * 86_400, + ) - with create_session() as session: - manager.run() + with create_session() as session: + self.run_processor_manager_one_loop(manager) - import_errors = session.query(ParseImportError).all() - assert len(import_errors) == 1 + import_errors = session.query(ParseImportError).all() + assert len(import_errors) == 1 - path_to_parse.unlink() + path_to_parse.unlink() - # Rerun the parser once the dag file has been removed - manager.run() - import_errors = session.query(ParseImportError).all() + # Rerun the parser once the dag file has been removed + self.run_processor_manager_one_loop(manager) + import_errors = session.query(ParseImportError).all() - assert len(import_errors) == 0 - session.rollback() + assert len(import_errors) == 0 + session.rollback() @conf_vars({("core", "load_examples"): "False"}) def test_max_runs_when_no_files(self, tmp_path): - with conf_vars({("core", "dags_folder"): str(tmp_path)}): - manager = DagFileProcessorManager(max_runs=1) - manager.run() + manager = DagFileProcessorManager(dag_directory=tmp_path, max_runs=1) - # TODO: AIP-66 no asserts? + self.run_processor_manager_one_loop(manager) def test_start_new_processes_with_same_filepath(self): """ Test that when a processor already exist with a filepath, a new processor won't be created with that filepath. The filepath will just be removed from the list. """ - manager = DagFileProcessorManager(max_runs=1) + manager = DagFileProcessorManager(dag_directory="directory", max_runs=1) - file_1 = DagFileInfo(bundle_name="testing", path="file_1.py") - file_2 = DagFileInfo(bundle_name="testing", path="file_2.py") - file_3 = DagFileInfo(bundle_name="testing", path="file_3.py") + file_1 = "file_1.py" + file_2 = "file_2.py" + file_3 = "file_3.py" manager._file_path_queue = deque([file_1, file_2, file_3]) # Mock that only one processor exists. This processor runs with 'file_1' @@ -192,47 +185,49 @@ def test_start_new_processes_with_same_filepath(self): assert deque([file_3]) == manager._file_path_queue def test_set_file_paths_when_processor_file_path_not_in_new_file_paths(self): - """Ensure processors and file stats are removed when the file path is not in the new file paths""" - manager = DagFileProcessorManager(max_runs=1) - file = DagFileInfo(bundle_name="testing", path="missing_file.txt") + manager = DagFileProcessorManager(dag_directory="directory", max_runs=1) + + mock_processor = MagicMock() + mock_processor.stop.side_effect = AttributeError("DagFileProcessor object has no attribute stop") + mock_processor.terminate.side_effect = None - manager._processors[file] = MagicMock() - manager._file_stats[file] = DagFileStat() + manager._processors["missing_file.txt"] = mock_processor + manager._file_stats["missing_file.txt"] = DagFileStat() manager.set_file_paths(["abc.txt"]) assert manager._processors == {} - assert file not in manager._file_stats + assert "missing_file.txt" not in manager._file_stats def test_set_file_paths_when_processor_file_path_is_in_new_file_paths(self): - manager = DagFileProcessorManager(max_runs=1) - file = DagFileInfo(bundle_name="testing", path="abc.txt") + manager = DagFileProcessorManager(dag_directory="directory", max_runs=1) + mock_processor = MagicMock() + mock_processor.stop.side_effect = AttributeError("DagFileProcessor object has no attribute stop") + mock_processor.terminate.side_effect = None - manager._processors[file] = mock_processor + manager._processors["abc.txt"] = mock_processor - manager.set_file_paths([file]) - assert manager._processors == {file: mock_processor} + manager.set_file_paths(["abc.txt"]) + assert manager._processors == {"abc.txt": mock_processor} @conf_vars({("scheduler", "file_parsing_sort_mode"): "alphabetical"}) def test_file_paths_in_queue_sorted_alphabetically(self): """Test dag files are sorted alphabetically""" - file_names = ["file_3.py", "file_2.py", "file_4.py", "file_1.py"] - dag_files = _get_dag_file_paths(file_names) - ordered_dag_files = _get_dag_file_paths(sorted(file_names)) + dag_files = ["file_3.py", "file_2.py", "file_4.py", "file_1.py"] - manager = DagFileProcessorManager(max_runs=1) + manager = DagFileProcessorManager(dag_directory="directory", max_runs=1) manager.set_file_paths(dag_files) assert manager._file_path_queue == deque() manager.prepare_file_path_queue() - assert manager._file_path_queue == deque(ordered_dag_files) + assert manager._file_path_queue == deque(["file_1.py", "file_2.py", "file_3.py", "file_4.py"]) @conf_vars({("scheduler", "file_parsing_sort_mode"): "random_seeded_by_host"}) def test_file_paths_in_queue_sorted_random_seeded_by_host(self): """Test files are randomly sorted and seeded by host name""" - dag_files = _get_dag_file_paths(["file_3.py", "file_2.py", "file_4.py", "file_1.py"]) + dag_files = ["file_3.py", "file_2.py", "file_4.py", "file_1.py"] - manager = DagFileProcessorManager(max_runs=1) + manager = DagFileProcessorManager(dag_directory="directory", max_runs=1) manager.set_file_paths(dag_files) assert manager._file_path_queue == deque() @@ -252,50 +247,45 @@ def test_file_paths_in_queue_sorted_random_seeded_by_host(self): def test_file_paths_in_queue_sorted_by_modified_time(self, mock_getmtime): """Test files are sorted by modified time""" paths_with_mtime = {"file_3.py": 3.0, "file_2.py": 2.0, "file_4.py": 5.0, "file_1.py": 4.0} - dag_files = _get_dag_file_paths(paths_with_mtime.keys()) + dag_files = list(paths_with_mtime.keys()) mock_getmtime.side_effect = list(paths_with_mtime.values()) - manager = DagFileProcessorManager(max_runs=1) + manager = DagFileProcessorManager(dag_directory="directory", max_runs=1) manager.set_file_paths(dag_files) assert manager._file_path_queue == deque() manager.prepare_file_path_queue() - ordered_files = _get_dag_file_paths(["file_4.py", "file_1.py", "file_3.py", "file_2.py"]) - assert manager._file_path_queue == deque(ordered_files) + assert manager._file_path_queue == deque(["file_4.py", "file_1.py", "file_3.py", "file_2.py"]) @conf_vars({("scheduler", "file_parsing_sort_mode"): "modified_time"}) @mock.patch("airflow.utils.file.os.path.getmtime") def test_file_paths_in_queue_excludes_missing_file(self, mock_getmtime): """Check that a file is not enqueued for processing if it has been deleted""" - dag_files = _get_dag_file_paths(["file_3.py", "file_2.py", "file_4.py"]) + dag_files = ["file_3.py", "file_2.py", "file_4.py"] mock_getmtime.side_effect = [1.0, 2.0, FileNotFoundError()] - manager = DagFileProcessorManager(max_runs=1) + manager = DagFileProcessorManager(dag_directory="directory", max_runs=1) manager.set_file_paths(dag_files) manager.prepare_file_path_queue() - - ordered_files = _get_dag_file_paths(["file_2.py", "file_3.py"]) - assert manager._file_path_queue == deque(ordered_files) + assert manager._file_path_queue == deque(["file_2.py", "file_3.py"]) @conf_vars({("scheduler", "file_parsing_sort_mode"): "modified_time"}) @mock.patch("airflow.utils.file.os.path.getmtime") def test_add_new_file_to_parsing_queue(self, mock_getmtime): """Check that new file is added to parsing queue""" - dag_files = _get_dag_file_paths(["file_1.py", "file_2.py", "file_3.py"]) + dag_files = ["file_1.py", "file_2.py", "file_3.py"] mock_getmtime.side_effect = [1.0, 2.0, 3.0] - manager = DagFileProcessorManager(max_runs=1) + manager = DagFileProcessorManager(dag_directory="directory", max_runs=1) manager.set_file_paths(dag_files) manager.prepare_file_path_queue() - ordered_files = _get_dag_file_paths(["file_3.py", "file_2.py", "file_1.py"]) - assert manager._file_path_queue == deque(ordered_files) + assert manager._file_path_queue == deque(["file_3.py", "file_2.py", "file_1.py"]) - manager.set_file_paths([*dag_files, DagFileInfo(bundle_name="testing", path="file_4.py")]) + manager.set_file_paths([*dag_files, "file_4.py"]) manager.add_new_file_path_to_queue() - ordered_files = _get_dag_file_paths(["file_4.py", "file_3.py", "file_2.py", "file_1.py"]) - assert manager._file_path_queue == deque(ordered_files) + assert manager._file_path_queue == deque(["file_4.py", "file_3.py", "file_2.py", "file_1.py"]) @conf_vars({("scheduler", "file_parsing_sort_mode"): "modified_time"}) @mock.patch("airflow.utils.file.os.path.getmtime") @@ -305,16 +295,15 @@ def test_recently_modified_file_is_parsed_with_mtime_mode(self, mock_getmtime): """ freezed_base_time = timezone.datetime(2020, 1, 5, 0, 0, 0) initial_file_1_mtime = (freezed_base_time - timedelta(minutes=5)).timestamp() - dag_file = DagFileInfo(bundle_name="testing", path="file_1.py") - dag_files = [dag_file] + dag_files = ["file_1.py"] mock_getmtime.side_effect = [initial_file_1_mtime] - manager = DagFileProcessorManager(max_runs=3) + manager = DagFileProcessorManager(dag_directory="directory", max_runs=3) # let's say the DAG was just parsed 10 seconds before the Freezed time last_finish_time = freezed_base_time - timedelta(seconds=10) manager._file_stats = { - dag_file: DagFileStat(1, 0, last_finish_time, 1.0, 1, 1), + "file_1.py": DagFileStat(1, 0, last_finish_time, 1.0, 1, 1), } with time_machine.travel(freezed_base_time): manager.set_file_paths(dag_files) @@ -334,14 +323,13 @@ def test_recently_modified_file_is_parsed_with_mtime_mode(self, mock_getmtime): mock_getmtime.side_effect = [file_1_new_mtime_ts] manager.prepare_file_path_queue() # Check that file is added to the queue even though file was just recently passed - assert manager._file_path_queue == deque(dag_files) + assert manager._file_path_queue == deque(["file_1.py"]) assert last_finish_time < file_1_new_mtime assert ( manager._file_process_interval - > (freezed_base_time - manager._file_stats[dag_file].last_finish_time).total_seconds() + > (freezed_base_time - manager._file_stats["file_1.py"].last_finish_time).total_seconds() ) - @pytest.mark.skip("AIP-66: parsing requests are not bundle aware yet") def test_file_paths_in_queue_sorted_by_priority(self): from airflow.models.dagbag import DagPriorityParsingRequest @@ -363,27 +351,25 @@ def test_file_paths_in_queue_sorted_by_priority(self): parsing_request_after = session2.query(DagPriorityParsingRequest).get(parsing_request.id) assert parsing_request_after is None - def test_scan_stale_dags(self, testing_dag_bundle): + def test_scan_stale_dags(self): """ Ensure that DAGs are marked inactive when the file is parsed but the DagModel.last_parsed_time is not updated. """ manager = DagFileProcessorManager( + dag_directory="directory", max_runs=1, processor_timeout=10 * 60, ) - test_dag_path = DagFileInfo( - bundle_name="testing", - path=str(TEST_DAG_FOLDER / "test_example_bash_operator.py"), - ) - dagbag = DagBag(test_dag_path.path, read_dags_from_db=False, include_examples=False) + test_dag_path = str(TEST_DAG_FOLDER / "test_example_bash_operator.py") + dagbag = DagBag(test_dag_path, read_dags_from_db=False, include_examples=False) with create_session() as session: # Add stale DAG to the DB dag = dagbag.get_dag("test_example_bash_operator") dag.last_parsed_time = timezone.utcnow() - DAG.bulk_write_to_db("testing", None, [dag]) + dag.sync_to_db() SerializedDagModel.write_dag(dag) # Add DAG to the file_parsing_stats @@ -400,7 +386,7 @@ def test_scan_stale_dags(self, testing_dag_bundle): active_dag_count = ( session.query(func.count(DagModel.dag_id)) - .filter(DagModel.is_active, DagModel.fileloc == test_dag_path.path) + .filter(DagModel.is_active, DagModel.fileloc == test_dag_path) .scalar() ) assert active_dag_count == 1 @@ -409,7 +395,7 @@ def test_scan_stale_dags(self, testing_dag_bundle): active_dag_count = ( session.query(func.count(DagModel.dag_id)) - .filter(DagModel.is_active, DagModel.fileloc == test_dag_path.path) + .filter(DagModel.is_active, DagModel.fileloc == test_dag_path) .scalar() ) assert active_dag_count == 0 @@ -424,11 +410,11 @@ def test_scan_stale_dags(self, testing_dag_bundle): assert serialized_dag_count == 1 def test_kill_timed_out_processors_kill(self): - manager = DagFileProcessorManager(max_runs=1, processor_timeout=5) + manager = DagFileProcessorManager(dag_directory="directory", max_runs=1, processor_timeout=5) processor = self.mock_processor() processor._process.create_time.return_value = timezone.make_aware(datetime.min).timestamp() - manager._processors = {DagFileInfo(bundle_name="testing", path="abc.txt"): processor} + manager._processors = {"abc.txt": processor} with mock.patch.object(type(processor), "kill") as mock_kill: manager._kill_timed_out_processors() mock_kill.assert_called_once_with(signal.SIGKILL) @@ -436,13 +422,14 @@ def test_kill_timed_out_processors_kill(self): def test_kill_timed_out_processors_no_kill(self): manager = DagFileProcessorManager( + dag_directory=TEST_DAG_FOLDER, max_runs=1, processor_timeout=5, ) processor = self.mock_processor() processor._process.create_time.return_value = timezone.make_aware(datetime.max).timestamp() - manager._processors = {DagFileInfo(bundle_name="testing", path="abc.txt"): processor} + manager._processors = {"abc.txt": processor} with mock.patch.object(type(processor), "kill") as mock_kill: manager._kill_timed_out_processors() mock_kill.assert_not_called() @@ -485,7 +472,7 @@ def test_serialize_callback_requests(self, callbacks, path, child_comms_fd, expe @conf_vars({("core", "load_examples"): "False"}) @pytest.mark.execution_timeout(10) - def test_dag_with_system_exit(self, configure_testing_dag_bundle): + def test_dag_with_system_exit(self): """ Test to check that a DAG with a system.exit() doesn't break the scheduler. """ @@ -497,9 +484,9 @@ def test_dag_with_system_exit(self, configure_testing_dag_bundle): clear_db_dags() clear_db_serialized_dags() - with configure_testing_dag_bundle(dag_directory): - manager = DagFileProcessorManager(max_runs=1) - manager.run() + manager = DagFileProcessorManager(dag_directory=dag_directory, max_runs=1) + + manager._run_parsing_loop() # Three files in folder should be processed assert sum(stat.run_count for stat in manager._file_stats.values()) == 3 @@ -509,7 +496,7 @@ def test_dag_with_system_exit(self, configure_testing_dag_bundle): @conf_vars({("core", "load_examples"): "False"}) @pytest.mark.execution_timeout(30) - def test_pipe_full_deadlock(self, configure_testing_dag_bundle): + def test_pipe_full_deadlock(self): dag_filepath = TEST_DAG_FOLDER / "test_scheduler_dags.py" child_pipe, parent_pipe = multiprocessing.Pipe() @@ -546,43 +533,38 @@ def keep_pipe_full(pipe, exit_event): thread = threading.Thread(target=keep_pipe_full, args=(parent_pipe, exit_event)) - with configure_testing_dag_bundle(dag_filepath): - manager = DagFileProcessorManager( - # A reasonable large number to ensure that we trigger the deadlock - max_runs=100, - processor_timeout=5, - signal_conn=child_pipe, - # Make it loop sub-processes quickly. Need to be non-zero to exercise the bug, else it finishes - # too quickly - file_process_interval=0.01, - ) + manager = DagFileProcessorManager( + dag_directory=dag_filepath, + # A reasonable large number to ensure that we trigger the deadlock + max_runs=100, + processor_timeout=5, + signal_conn=child_pipe, + # Make it loop sub-processes quickly. Need to be non-zero to exercise the bug, else it finishes + # too quickly + file_process_interval=0.01, + ) - try: - thread.start() - - # If this completes without hanging, then the test is good! - with mock.patch.object( - DagFileProcessorProcess, - "start", - side_effect=lambda *args, **kwargs: self.mock_processor(), - ): - manager.run() - exit_event.set() - finally: - logger.info("Closing pipes") - parent_pipe.close() - child_pipe.close() - logger.info("Closed pipes") - logger.info("Joining thread") - thread.join(timeout=1.0) - logger.info("Joined thread") + try: + thread.start() + + # If this completes without hanging, then the test is good! + with mock.patch.object( + DagFileProcessorProcess, "start", side_effect=lambda *args, **kwargs: self.mock_processor() + ): + manager.run() + exit_event.set() + finally: + logger.info("Closing pipes") + parent_pipe.close() + child_pipe.close() + logger.info("Closed pipes") + logger.info("Joining thread") + thread.join(timeout=1.0) + logger.info("Joined thread") @conf_vars({("core", "load_examples"): "False"}) @mock.patch("airflow.dag_processing.manager.Stats.timing") - @pytest.mark.skip("AIP-66: stats are not implemented yet") - def test_send_file_processing_statsd_timing( - self, statsd_timing_mock, tmp_path, configure_testing_dag_bundle - ): + def test_send_file_processing_statsd_timing(self, statsd_timing_mock, tmp_path): path_to_parse = tmp_path / "temp_dag.py" dag_code = textwrap.dedent( """ @@ -592,11 +574,11 @@ def test_send_file_processing_statsd_timing( ) path_to_parse.write_text(dag_code) - with configure_testing_dag_bundle(tmp_path): - manager = DagFileProcessorManager(max_runs=1) - manager.run() + manager = DagFileProcessorManager(dag_directory=path_to_parse.parent, max_runs=1) + self.run_processor_manager_one_loop(manager) last_runtime = manager._file_stats[os.fspath(path_to_parse)].last_duration + statsd_timing_mock.assert_has_calls( [ mock.call("dag_processing.last_duration.temp_dag", last_runtime), @@ -605,19 +587,17 @@ def test_send_file_processing_statsd_timing( any_order=True, ) - def test_refresh_dags_dir_doesnt_delete_zipped_dags(self, tmp_path, configure_testing_dag_bundle): + def test_refresh_dags_dir_doesnt_delete_zipped_dags(self, tmp_path): """Test DagFileProcessorManager._refresh_dag_dir method""" + manager = DagFileProcessorManager(dag_directory=TEST_DAG_FOLDER, max_runs=1) dagbag = DagBag(dag_folder=tmp_path, include_examples=False) zipped_dag_path = os.path.join(TEST_DAGS_FOLDER, "test_zip.zip") dagbag.process_file(zipped_dag_path) dag = dagbag.get_dag("test_zip_dag") - DAG.bulk_write_to_db("testing", None, [dag]) + dag.sync_to_db() SerializedDagModel.write_dag(dag) - - with configure_testing_dag_bundle(zipped_dag_path): - manager = DagFileProcessorManager(max_runs=1) - manager.run() - + manager.last_dag_dir_refresh_time = time.monotonic() - 10 * 60 + manager._refresh_dag_dir() # Assert dag not deleted in SDM assert SerializedDagModel.has_dag("test_zip_dag") # assert code not deleted @@ -625,23 +605,20 @@ def test_refresh_dags_dir_doesnt_delete_zipped_dags(self, tmp_path, configure_te # assert dag still active assert dag.get_is_active() - def test_refresh_dags_dir_deactivates_deleted_zipped_dags(self, tmp_path, configure_testing_dag_bundle): + def test_refresh_dags_dir_deactivates_deleted_zipped_dags(self, tmp_path): """Test DagFileProcessorManager._refresh_dag_dir method""" + manager = DagFileProcessorManager(dag_directory=TEST_DAG_FOLDER, max_runs=1) dagbag = DagBag(dag_folder=tmp_path, include_examples=False) zipped_dag_path = os.path.join(TEST_DAGS_FOLDER, "test_zip.zip") dagbag.process_file(zipped_dag_path) dag = dagbag.get_dag("test_zip_dag") dag.sync_to_db() SerializedDagModel.write_dag(dag) - - # TODO: this test feels a bit fragile - pointing at the zip directly causes the test to fail - # TODO: jed look at this more closely - bagbad then process_file?! + manager.last_dag_dir_refresh_time = time.monotonic() - 10 * 60 # Mock might_contain_dag to mimic deleting the python file from the zip with mock.patch("airflow.dag_processing.manager.might_contain_dag", return_value=False): - with configure_testing_dag_bundle(TEST_DAGS_FOLDER): - manager = DagFileProcessorManager(max_runs=1) - manager.run() + manager._refresh_dag_dir() # Deleting the python file should not delete SDM for versioning sake assert SerializedDagModel.has_dag("test_zip_dag") @@ -658,7 +635,7 @@ def test_refresh_dags_dir_deactivates_deleted_zipped_dags(self, tmp_path, config ("scheduler", "standalone_dag_processor"): "True", } ) - def test_fetch_callbacks_from_database(self, tmp_path, configure_testing_dag_bundle): + def test_fetch_callbacks_from_database(self, tmp_path): dag_filepath = TEST_DAG_FOLDER / "test_on_failure_callback_dag.py" callback1 = DagCallbackRequest( @@ -678,12 +655,13 @@ def test_fetch_callbacks_from_database(self, tmp_path, configure_testing_dag_bun session.add(DbCallbackRequest(callback=callback1, priority_weight=11)) session.add(DbCallbackRequest(callback=callback2, priority_weight=10)) - with configure_testing_dag_bundle(tmp_path): - manager = DagFileProcessorManager(max_runs=1, standalone_dag_processor=True) + manager = DagFileProcessorManager( + dag_directory=os.fspath(tmp_path), max_runs=1, standalone_dag_processor=True + ) - with create_session() as session: - manager.run() - assert session.query(DbCallbackRequest).count() == 0 + with create_session() as session: + self.run_processor_manager_one_loop(manager) + assert session.query(DbCallbackRequest).count() == 0 @conf_vars( { @@ -692,7 +670,7 @@ def test_fetch_callbacks_from_database(self, tmp_path, configure_testing_dag_bun ("core", "load_examples"): "False", } ) - def test_fetch_callbacks_from_database_max_per_loop(self, tmp_path, configure_testing_dag_bundle): + def test_fetch_callbacks_from_database_max_per_loop(self, tmp_path): """Test DagFileProcessorManager._fetch_callbacks method""" dag_filepath = TEST_DAG_FOLDER / "test_on_failure_callback_dag.py" @@ -706,16 +684,15 @@ def test_fetch_callbacks_from_database_max_per_loop(self, tmp_path, configure_te ) session.add(DbCallbackRequest(callback=callback, priority_weight=i)) - with configure_testing_dag_bundle(tmp_path): - manager = DagFileProcessorManager(max_runs=1) + manager = DagFileProcessorManager(dag_directory=tmp_path, max_runs=1) - with create_session() as session: - manager.run() - assert session.query(DbCallbackRequest).count() == 3 + with create_session() as session: + self.run_processor_manager_one_loop(manager) + assert session.query(DbCallbackRequest).count() == 3 - with create_session() as session: - manager.run() - assert session.query(DbCallbackRequest).count() == 1 + with create_session() as session: + self.run_processor_manager_one_loop(manager) + assert session.query(DbCallbackRequest).count() == 1 @conf_vars( { @@ -723,7 +700,7 @@ def test_fetch_callbacks_from_database_max_per_loop(self, tmp_path, configure_te ("core", "load_examples"): "False", } ) - def test_fetch_callbacks_from_database_not_standalone(self, tmp_path, configure_testing_dag_bundle): + def test_fetch_callbacks_from_database_not_standalone(self, tmp_path): dag_filepath = TEST_DAG_FOLDER / "test_on_failure_callback_dag.py" with create_session() as session: @@ -735,23 +712,22 @@ def test_fetch_callbacks_from_database_not_standalone(self, tmp_path, configure_ ) session.add(DbCallbackRequest(callback=callback, priority_weight=10)) - with configure_testing_dag_bundle(tmp_path): - manager = DagFileProcessorManager(max_runs=1) - manager.run() + manager = DagFileProcessorManager(dag_directory=tmp_path, max_runs=1) + + self.run_processor_manager_one_loop(manager) # Verify no callbacks removed from database. with create_session() as session: assert session.query(DbCallbackRequest).count() == 1 - @pytest.mark.skip("AIP-66: callbacks are not implemented yet") def test_callback_queue(self, tmp_path): # given manager = DagFileProcessorManager( + dag_directory=TEST_DAG_FOLDER, max_runs=1, processor_timeout=365 * 86_400, ) - dag1_path = DagFileInfo(bundle_name="testing", path="/green_eggs/ham/file1.py") dag1_req1 = DagCallbackRequest( full_filepath="/green_eggs/ham/file1.py", dag_id="dag1", @@ -767,7 +743,6 @@ def test_callback_queue(self, tmp_path): msg=None, ) - dag2_path = DagFileInfo(bundle_name="testing", path="/green_eggs/ham/file2.py") dag2_req1 = DagCallbackRequest( full_filepath="/green_eggs/ham/file2.py", dag_id="dag2", @@ -781,7 +756,7 @@ def test_callback_queue(self, tmp_path): manager._add_callback_to_queue(dag2_req1) # then - requests should be in manager's queue, with dag2 ahead of dag1 (because it was added last) - assert manager._file_path_queue == deque([dag2_path, dag1_path]) + assert manager._file_path_queue == deque([dag2_req1.full_filepath, dag1_req1.full_filepath]) assert set(manager._callback_to_execute.keys()) == { dag1_req1.full_filepath, dag2_req1.full_filepath, @@ -812,106 +787,24 @@ def test_callback_queue(self, tmp_path): # And removed from the queue assert dag1_req1.full_filepath not in manager._callback_to_execute - def test_dag_with_assets(self, session, configure_testing_dag_bundle): + def test_dag_with_assets(self, session): """'Integration' test to ensure that the assets get parsed and stored correctly for parsed dags.""" test_dag_path = str(TEST_DAG_FOLDER / "test_assets.py") - with configure_testing_dag_bundle(test_dag_path): - manager = DagFileProcessorManager( - max_runs=1, - processor_timeout=365 * 86_400, - ) - manager.run() + manager = DagFileProcessorManager( + dag_directory=test_dag_path, + max_runs=1, + processor_timeout=365 * 86_400, + ) + + self.run_processor_manager_one_loop(manager) dag_model = session.get(DagModel, ("dag_with_skip_task")) assert dag_model.task_outlet_asset_references == [ TaskOutletAssetReference(asset_id=mock.ANY, dag_id="dag_with_skip_task", task_id="skip_task") ] - def test_bundles_are_refreshed(self): - """ - Ensure bundles are refreshed by the manager, when necessary. - - - refresh if the bundle hasn't been refreshed in the refresh_interval - - when the latest_version in the db doesn't match the version this parser knows about - """ - config = [ - { - "name": "bundleone", - "classpath": "airflow.dag_processing.bundles.local.LocalDagBundle", - "kwargs": {"local_folder": "/dev/null", "refresh_interval": 0}, - }, - { - "name": "bundletwo", - "classpath": "airflow.dag_processing.bundles.local.LocalDagBundle", - "kwargs": {"local_folder": "/dev/null", "refresh_interval": 300}, - }, - ] - - bundleone = MagicMock() - bundleone.name = "bundleone" - bundleone.refresh_interval = 0 - bundleone.get_current_version.return_value = None - bundletwo = MagicMock() - bundletwo.name = "bundletwo" - bundletwo.refresh_interval = 300 - bundletwo.get_current_version.return_value = None - - with conf_vars({("dag_bundles", "backends"): json.dumps(config)}): - DagBundlesManager().sync_bundles_to_db() - with mock.patch( - "airflow.dag_processing.bundles.manager.DagBundlesManager" - ) as mock_bundle_manager: - mock_bundle_manager.return_value._bundle_config = {"bundleone": None, "bundletwo": None} - mock_bundle_manager.return_value.get_all_dag_bundles.return_value = [bundleone, bundletwo] - manager = DagFileProcessorManager(max_runs=1) - manager.run() - bundleone.refresh.assert_called_once() - bundletwo.refresh.assert_called_once() - - # Now, we should only refresh bundleone, as haven't hit the refresh_interval for bundletwo - bundleone.reset_mock() - bundletwo.reset_mock() - manager.run() - bundleone.refresh.assert_called_once() - bundletwo.refresh.assert_not_called() - - # however, if the version doesn't match, we should still refresh - bundletwo.reset_mock() - bundletwo.get_current_version.return_value = "123" - manager.run() - bundletwo.refresh.assert_called_once() - - def test_bundles_versions_are_stored(self): - config = [ - { - "name": "mybundle", - "classpath": "airflow.dag_processing.bundles.local.LocalDagBundle", - "kwargs": {"local_folder": "/dev/null", "refresh_interval": 0}, - }, - ] - - mybundle = MagicMock() - mybundle.name = "bundleone" - mybundle.refresh_interval = 0 - mybundle.supports_versioning = True - mybundle.get_current_version.return_value = "123" - - with conf_vars({("dag_bundles", "backends"): json.dumps(config)}): - DagBundlesManager().sync_bundles_to_db() - with mock.patch( - "airflow.dag_processing.bundles.manager.DagBundlesManager" - ) as mock_bundle_manager: - mock_bundle_manager.return_value._bundle_config = {"bundleone": None} - mock_bundle_manager.return_value.get_all_dag_bundles.return_value = [mybundle] - manager = DagFileProcessorManager(max_runs=1) - manager.run() - - with create_session() as session: - model = session.get(DagBundleModel, "bundleone") - assert model.latest_version == "123" - class TestDagFileProcessorAgent: @pytest.fixture(autouse=True) @@ -922,12 +815,14 @@ def _disable_examples(self): def test_launch_process(self): from airflow.configuration import conf + test_dag_path = TEST_DAG_FOLDER / "test_scheduler_dags.py" + log_file_loc = conf.get("logging", "DAG_PROCESSOR_MANAGER_LOG_LOCATION") with suppress(OSError): os.remove(log_file_loc) # Starting dag processing with 0 max_runs to avoid redundant operations. - processor_agent = DagFileProcessorAgent(0, timedelta(days=365)) + processor_agent = DagFileProcessorAgent(test_dag_path, 0, timedelta(days=365)) processor_agent.start() processor_agent._process.join() @@ -935,25 +830,25 @@ def test_launch_process(self): assert os.path.isfile(log_file_loc) def test_get_callbacks_pipe(self): - processor_agent = DagFileProcessorAgent(1, timedelta(days=365)) + processor_agent = DagFileProcessorAgent("", 1, timedelta(days=365)) processor_agent._parent_signal_conn = Mock() retval = processor_agent.get_callbacks_pipe() assert retval == processor_agent._parent_signal_conn def test_get_callbacks_pipe_no_parent_signal_conn(self): - processor_agent = DagFileProcessorAgent(1, timedelta(days=365)) + processor_agent = DagFileProcessorAgent("", 1, timedelta(days=365)) processor_agent._parent_signal_conn = None with pytest.raises(ValueError, match="Process not started"): processor_agent.get_callbacks_pipe() def test_heartbeat_no_parent_signal_conn(self): - processor_agent = DagFileProcessorAgent(1, timedelta(days=365)) + processor_agent = DagFileProcessorAgent("", 1, timedelta(days=365)) processor_agent._parent_signal_conn = None with pytest.raises(ValueError, match="Process not started"): processor_agent.heartbeat() def test_heartbeat_poll_eof_error(self): - processor_agent = DagFileProcessorAgent(1, timedelta(days=365)) + processor_agent = DagFileProcessorAgent("", 1, timedelta(days=365)) processor_agent._parent_signal_conn = Mock() processor_agent._parent_signal_conn.poll.return_value = True processor_agent._parent_signal_conn.recv = Mock() @@ -962,7 +857,7 @@ def test_heartbeat_poll_eof_error(self): assert ret_val is None def test_heartbeat_poll_connection_error(self): - processor_agent = DagFileProcessorAgent(1, timedelta(days=365)) + processor_agent = DagFileProcessorAgent("", 1, timedelta(days=365)) processor_agent._parent_signal_conn = Mock() processor_agent._parent_signal_conn.poll.return_value = True processor_agent._parent_signal_conn.recv = Mock() @@ -971,7 +866,7 @@ def test_heartbeat_poll_connection_error(self): assert ret_val is None def test_heartbeat_poll_process_message(self): - processor_agent = DagFileProcessorAgent(1, timedelta(days=365)) + processor_agent = DagFileProcessorAgent("", 1, timedelta(days=365)) processor_agent._parent_signal_conn = Mock() processor_agent._parent_signal_conn.poll.side_effect = [True, False] processor_agent._parent_signal_conn.recv = Mock() @@ -982,13 +877,13 @@ def test_heartbeat_poll_process_message(self): def test_process_message_invalid_type(self): message = "xyz" - processor_agent = DagFileProcessorAgent(1, timedelta(days=365)) + processor_agent = DagFileProcessorAgent("", 1, timedelta(days=365)) with pytest.raises(RuntimeError, match="Unexpected message received of type str"): processor_agent._process_message(message) @mock.patch("airflow.utils.process_utils.reap_process_group") def test_heartbeat_manager_process_restart(self, mock_pg, monkeypatch): - processor_agent = DagFileProcessorAgent(1, timedelta(days=365)) + processor_agent = DagFileProcessorAgent("", 1, timedelta(days=365)) processor_agent._parent_signal_conn = Mock() processor_agent._process = MagicMock() monkeypatch.setattr(processor_agent._process, "pid", 1234) @@ -1004,7 +899,7 @@ def test_heartbeat_manager_process_restart(self, mock_pg, monkeypatch): @mock.patch("time.monotonic") @mock.patch("airflow.dag_processing.manager.reap_process_group") def test_heartbeat_manager_process_reap(self, mock_pg, mock_time_monotonic, mock_stats): - processor_agent = DagFileProcessorAgent(1, timedelta(days=365)) + processor_agent = DagFileProcessorAgent("", 1, timedelta(days=365)) processor_agent._parent_signal_conn = Mock() processor_agent._process = Mock() processor_agent._process.pid = 12345 @@ -1025,7 +920,7 @@ def test_heartbeat_manager_process_reap(self, mock_pg, mock_time_monotonic, mock processor_agent.start.assert_called() def test_heartbeat_manager_end_no_process(self): - processor_agent = DagFileProcessorAgent(1, timedelta(days=365)) + processor_agent = DagFileProcessorAgent("", 1, timedelta(days=365)) processor_agent._process = Mock() processor_agent._process.__bool__ = Mock(return_value=False) processor_agent._process.side_effect = [None] @@ -1036,25 +931,26 @@ def test_heartbeat_manager_end_no_process(self): processor_agent._process.join.assert_not_called() @pytest.mark.execution_timeout(5) - def test_terminate(self, tmp_path, configure_testing_dag_bundle): - with configure_testing_dag_bundle(tmp_path): - processor_agent = DagFileProcessorAgent(-1, timedelta(days=365)) + def test_terminate(self, tmp_path): + processor_agent = DagFileProcessorAgent(tmp_path, -1, timedelta(days=365)) - processor_agent.start() - try: - processor_agent.terminate() + processor_agent.start() + try: + processor_agent.terminate() - processor_agent._process.join(timeout=1) - assert processor_agent._process.is_alive() is False - assert processor_agent._process.exitcode == 0 - except Exception: - reap_process_group(processor_agent._process.pid, logger=logger) - raise + processor_agent._process.join(timeout=1) + assert processor_agent._process.is_alive() is False + assert processor_agent._process.exitcode == 0 + except Exception: + reap_process_group(processor_agent._process.pid, logger=logger) + raise @conf_vars({("logging", "dag_processor_manager_log_stdout"): "True"}) def test_log_to_stdout(self, capfd): + test_dag_path = TEST_DAG_FOLDER / "test_scheduler_dags.py" + # Starting dag processing with 0 max_runs to avoid redundant operations. - processor_agent = DagFileProcessorAgent(0, timedelta(days=365)) + processor_agent = DagFileProcessorAgent(test_dag_path, 0, timedelta(days=365)) processor_agent.start() processor_agent._process.join() @@ -1065,8 +961,10 @@ def test_log_to_stdout(self, capfd): @conf_vars({("logging", "dag_processor_manager_log_stdout"): "False"}) def test_not_log_to_stdout(self, capfd): + test_dag_path = TEST_DAG_FOLDER / "test_scheduler_dags.py" + # Starting dag processing with 0 max_runs to avoid redundant operations. - processor_agent = DagFileProcessorAgent(0, timedelta(days=365)) + processor_agent = DagFileProcessorAgent(test_dag_path, 0, timedelta(days=365)) processor_agent.start() processor_agent._process.join() diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py index 328ae1742fb4a..b9db6f2751d5f 100644 --- a/tests/jobs/test_scheduler_job.py +++ b/tests/jobs/test_scheduler_job.py @@ -69,6 +69,7 @@ from airflow.serialization.serialized_objects import SerializedDAG from airflow.timetables.base import DataInterval from airflow.utils import timezone +from airflow.utils.file import list_py_file_paths from airflow.utils.session import create_session, provide_session from airflow.utils.state import DagRunState, State, TaskInstanceState from airflow.utils.types import DagRunType @@ -106,7 +107,6 @@ ELASTIC_DAG_FILE = os.path.join(PERF_DAGS_FOLDER, "elastic_dag.py") TEST_DAG_FOLDER = os.environ["AIRFLOW__CORE__DAGS_FOLDER"] -EXAMPLE_DAGS_FOLDER = airflow.example_dags.__path__[0] DEFAULT_DATE = timezone.datetime(2016, 1, 1) DEFAULT_LOGICAL_DATE = timezone.coerce_datetime(DEFAULT_DATE) TRY_NUMBER = 1 @@ -119,6 +119,12 @@ def disable_load_example(): yield +@pytest.fixture +def load_examples(): + with conf_vars({("core", "load_examples"): "True"}): + yield + + # Patch the MockExecutor into the dict of known executors in the Loader @contextlib.contextmanager def _loader_mock(mock_executors): @@ -573,32 +579,26 @@ def test_execute_task_instances_backfill_tasks_will_execute(self, dag_maker): session.rollback() @conf_vars({("scheduler", "standalone_dag_processor"): "False"}) - def test_setup_callback_sink_not_standalone_dag_processor( - self, mock_executors, configure_testing_dag_bundle - ): - with configure_testing_dag_bundle(os.devnull): - scheduler_job = Job() - self.job_runner = SchedulerJobRunner(job=scheduler_job, num_runs=1) - self.job_runner._execute() + def test_setup_callback_sink_not_standalone_dag_processor(self, mock_executors): + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull, num_runs=1) + self.job_runner._execute() - assert isinstance(scheduler_job.executor.callback_sink, PipeCallbackSink) + assert isinstance(scheduler_job.executor.callback_sink, PipeCallbackSink) @conf_vars({("scheduler", "standalone_dag_processor"): "False"}) - def test_setup_callback_sink_not_standalone_dag_processor_multiple_executors( - self, mock_executors, configure_testing_dag_bundle - ): - with configure_testing_dag_bundle(os.devnull): - scheduler_job = Job() - self.job_runner = SchedulerJobRunner(job=scheduler_job, num_runs=1) - self.job_runner._execute() + def test_setup_callback_sink_not_standalone_dag_processor_multiple_executors(self, mock_executors): + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull, num_runs=1) + self.job_runner._execute() - for executor in scheduler_job.executors: - assert isinstance(executor.callback_sink, PipeCallbackSink) + for executor in scheduler_job.executors: + assert isinstance(executor.callback_sink, PipeCallbackSink) @conf_vars({("scheduler", "standalone_dag_processor"): "True"}) def test_setup_callback_sink_standalone_dag_processor(self, mock_executors): scheduler_job = Job() - self.job_runner = SchedulerJobRunner(job=scheduler_job, num_runs=1) + self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull, num_runs=1) self.job_runner._execute() assert isinstance(scheduler_job.executor.callback_sink, DatabaseCallbackSink) @@ -606,7 +606,7 @@ def test_setup_callback_sink_standalone_dag_processor(self, mock_executors): @conf_vars({("scheduler", "standalone_dag_processor"): "True"}) def test_setup_callback_sink_standalone_dag_processor_multiple_executors(self, mock_executors): scheduler_job = Job() - self.job_runner = SchedulerJobRunner(job=scheduler_job, num_runs=1) + self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull, num_runs=1) self.job_runner._execute() for executor in scheduler_job.executors: @@ -615,40 +615,37 @@ def test_setup_callback_sink_standalone_dag_processor_multiple_executors(self, m @conf_vars({("scheduler", "standalone_dag_processor"): "True"}) def test_executor_start_called(self, mock_executors): scheduler_job = Job() - self.job_runner = SchedulerJobRunner(job=scheduler_job, num_runs=1) + self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull, num_runs=1) self.job_runner._execute() scheduler_job.executor.start.assert_called_once() for executor in scheduler_job.executors: executor.start.assert_called_once() - def test_executor_job_id_assigned(self, mock_executors, configure_testing_dag_bundle): - with configure_testing_dag_bundle(os.devnull): - scheduler_job = Job() - self.job_runner = SchedulerJobRunner(job=scheduler_job, num_runs=1) - self.job_runner._execute() + def test_executor_job_id_assigned(self, mock_executors): + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull, num_runs=1) + self.job_runner._execute() - assert scheduler_job.executor.job_id == scheduler_job.id - for executor in scheduler_job.executors: - assert executor.job_id == scheduler_job.id + assert scheduler_job.executor.job_id == scheduler_job.id + for executor in scheduler_job.executors: + assert executor.job_id == scheduler_job.id - def test_executor_heartbeat(self, mock_executors, configure_testing_dag_bundle): - with configure_testing_dag_bundle(os.devnull): - scheduler_job = Job() - self.job_runner = SchedulerJobRunner(job=scheduler_job, num_runs=1) - self.job_runner._execute() + def test_executor_heartbeat(self, mock_executors): + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull, num_runs=1) + self.job_runner._execute() - for executor in scheduler_job.executors: - executor.heartbeat.assert_called_once() + for executor in scheduler_job.executors: + executor.heartbeat.assert_called_once() - def test_executor_events_processed(self, mock_executors, configure_testing_dag_bundle): - with configure_testing_dag_bundle(os.devnull): - scheduler_job = Job() - self.job_runner = SchedulerJobRunner(job=scheduler_job, num_runs=1) - self.job_runner._execute() + def test_executor_events_processed(self, mock_executors): + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull, num_runs=1) + self.job_runner._execute() - for executor in scheduler_job.executors: - executor.get_event_buffer.assert_called_once() + for executor in scheduler_job.executors: + executor.get_event_buffer.assert_called_once() def test_executor_debug_dump(self, mock_executors): scheduler_job = Job() @@ -2894,17 +2891,17 @@ def my_task(): ... self.job_runner._do_scheduling(session) assert session.query(DagRun).one().state == run_state - def test_dagrun_root_after_dagrun_unfinished(self, mock_executor, testing_dag_bundle): + def test_dagrun_root_after_dagrun_unfinished(self, mock_executor): """ DagRuns with one successful and one future root task -> SUCCESS Noted: the DagRun state could be still in running state during CI. """ dagbag = DagBag(TEST_DAG_FOLDER, include_examples=False) - dagbag.sync_to_db("testing", None) + dagbag.sync_to_db() dag_id = "test_dagrun_states_root_future" dag = dagbag.get_dag(dag_id) - DAG.bulk_write_to_db("testing", None, [dag]) + dag.sync_to_db() scheduler_job = Job() self.job_runner = SchedulerJobRunner(job=scheduler_job, num_runs=2, subdir=dag.fileloc) @@ -2923,7 +2920,7 @@ def test_dagrun_root_after_dagrun_unfinished(self, mock_executor, testing_dag_bu {("scheduler", "standalone_dag_processor"): "True"}, ], ) - def test_scheduler_start_date(self, configs, testing_dag_bundle): + def test_scheduler_start_date(self, configs): """ Test that the scheduler respects start_dates, even when DAGs have run """ @@ -2938,7 +2935,7 @@ def test_scheduler_start_date(self, configs, testing_dag_bundle): # Deactivate other dags in this file other_dag = dagbag.get_dag("test_task_start_date_scheduling") other_dag.is_paused_upon_creation = True - DAG.bulk_write_to_db("testing", None, [other_dag]) + other_dag.sync_to_db() scheduler_job = Job( executor=self.null_exec, ) @@ -2984,7 +2981,7 @@ def test_scheduler_start_date(self, configs, testing_dag_bundle): {("scheduler", "standalone_dag_processor"): "True"}, ], ) - def test_scheduler_task_start_date(self, configs, testing_dag_bundle): + def test_scheduler_task_start_date(self, configs): """ Test that the scheduler respects task start dates that are different from DAG start dates """ @@ -3003,7 +3000,7 @@ def test_scheduler_task_start_date(self, configs, testing_dag_bundle): other_dag.is_paused_upon_creation = True dagbag.bag_dag(dag=other_dag) - dagbag.sync_to_db("testing", None) + dagbag.sync_to_db() scheduler_job = Job( executor=self.null_exec, @@ -3530,6 +3527,50 @@ def test_dag_get_active_runs(self, dag_maker): assert logical_date == running_date, "Running Date must match Execution Date" + def test_list_py_file_paths(self): + """ + [JIRA-1357] Test the 'list_py_file_paths' function used by the + scheduler to list and load DAGs. + """ + detected_files = set() + expected_files = set() + # No_dags is empty, _invalid_ is ignored by .airflowignore + ignored_files = { + "no_dags.py", + "test_invalid_cron.py", + "test_invalid_dup_task.py", + "test_ignore_this.py", + "test_invalid_param.py", + "test_invalid_param2.py", + "test_invalid_param3.py", + "test_invalid_param4.py", + "test_nested_dag.py", + "test_imports.py", + "__init__.py", + } + for root, _, files in os.walk(TEST_DAG_FOLDER): + for file_name in files: + if file_name.endswith((".py", ".zip")): + if file_name not in ignored_files: + expected_files.add(f"{root}/{file_name}") + for file_path in list_py_file_paths(TEST_DAG_FOLDER, include_examples=False): + detected_files.add(file_path) + assert detected_files == expected_files + + ignored_files = { + "helper.py", + } + example_dag_folder = airflow.example_dags.__path__[0] + for root, _, files in os.walk(example_dag_folder): + for file_name in files: + if file_name.endswith((".py", ".zip")): + if file_name not in ["__init__.py"] and file_name not in ignored_files: + expected_files.add(os.path.join(root, file_name)) + detected_files.clear() + for file_path in list_py_file_paths(TEST_DAG_FOLDER, include_examples=True): + detected_files.add(file_path) + assert detected_files == expected_files + def test_adopt_or_reset_orphaned_tasks_nothing(self): """Try with nothing.""" scheduler_job = Job() @@ -4166,7 +4207,7 @@ def test_scheduler_create_dag_runs_does_not_raise_error(self, caplog, dag_maker) "DAG 'test_scheduler_create_dag_runs_does_not_raise_error' not found in serialized_dag table", ] - def test_bulk_write_to_db_external_trigger_dont_skip_scheduled_run(self, dag_maker, testing_dag_bundle): + def test_bulk_write_to_db_external_trigger_dont_skip_scheduled_run(self, dag_maker): """ Test that externally triggered Dag Runs should not affect (by skipping) next scheduled DAG runs @@ -4224,7 +4265,7 @@ def test_bulk_write_to_db_external_trigger_dont_skip_scheduled_run(self, dag_mak ) assert dr is not None # Run DAG.bulk_write_to_db -- this is run when in DagFileProcessor.process_file - DAG.bulk_write_to_db("testing", None, [dag], session=session) + DAG.bulk_write_to_db([dag], session=session) # Test that 'dag_model.next_dagrun' has not been changed because of newly created external # triggered DagRun. @@ -4517,7 +4558,7 @@ def test_do_schedule_max_active_runs_and_manual_trigger(self, dag_maker, mock_ex dag_version = DagVersion.get_latest_version(dag.dag_id) dag_run = dag_maker.create_dagrun(state=State.QUEUED, session=session, dag_version=dag_version) - DAG.bulk_write_to_db("testing", None, [dag], session=session) # Update the date fields + dag.sync_to_db(session=session) # Update the date fields scheduler_job = Job() self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull) @@ -5612,11 +5653,11 @@ def test_find_and_purge_zombies_nothing(self): self.job_runner._find_and_purge_zombies() executor.callback_sink.send.assert_not_called() - def test_find_and_purge_zombies(self, session, testing_dag_bundle): - dagfile = os.path.join(EXAMPLE_DAGS_FOLDER, "example_branch_operator.py") - dagbag = DagBag(dagfile) + def test_find_and_purge_zombies(self, load_examples, session): + dagbag = DagBag(TEST_DAG_FOLDER, read_dags_from_db=False) + dag = dagbag.get_dag("example_branch_operator") - DAG.bulk_write_to_db("testing", None, [dag]) + dag.sync_to_db() data_interval = dag.infer_automated_data_interval(DEFAULT_LOGICAL_DATE) triggered_by_kwargs = {"triggered_by": DagRunTriggeredByType.TEST} if AIRFLOW_V_3_0_PLUS else {} dag_run = dag.create_dagrun( @@ -5668,74 +5709,70 @@ def test_find_and_purge_zombies(self, session, testing_dag_bundle): assert callback_request.ti.run_id == ti.run_id assert callback_request.ti.map_index == ti.map_index - def test_zombie_message(self, testing_dag_bundle, session): + def test_zombie_message(self, load_examples): """ Check that the zombie message comes out as expected """ dagbag = DagBag(TEST_DAG_FOLDER, read_dags_from_db=False) - dagfile = os.path.join(EXAMPLE_DAGS_FOLDER, "example_branch_operator.py") - dagbag = DagBag(dagfile) - dag = dagbag.get_dag("example_branch_operator") - DAG.bulk_write_to_db("testing", None, [dag]) - - session.query(Job).delete() - - data_interval = dag.infer_automated_data_interval(DEFAULT_LOGICAL_DATE) - triggered_by_kwargs = {"triggered_by": DagRunTriggeredByType.TEST} if AIRFLOW_V_3_0_PLUS else {} - dag_run = dag.create_dagrun( - state=DagRunState.RUNNING, - logical_date=DEFAULT_DATE, - run_type=DagRunType.SCHEDULED, - session=session, - data_interval=data_interval, - **triggered_by_kwargs, - ) - - scheduler_job = Job(executor=MockExecutor()) - self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull) - self.job_runner.processor_agent = mock.MagicMock() + with create_session() as session: + session.query(Job).delete() + dag = dagbag.get_dag("example_branch_operator") + dag.sync_to_db() - # We will provision 2 tasks so we can check we only find zombies from this scheduler - tasks_to_setup = ["branching", "run_this_first"] + data_interval = dag.infer_automated_data_interval(DEFAULT_LOGICAL_DATE) + triggered_by_kwargs = {"triggered_by": DagRunTriggeredByType.TEST} if AIRFLOW_V_3_0_PLUS else {} + dag_run = dag.create_dagrun( + state=DagRunState.RUNNING, + logical_date=DEFAULT_DATE, + run_type=DagRunType.SCHEDULED, + session=session, + data_interval=data_interval, + **triggered_by_kwargs, + ) - for task_id in tasks_to_setup: - task = dag.get_task(task_id=task_id) - ti = TaskInstance(task, run_id=dag_run.run_id, state=State.RUNNING) - ti.queued_by_job_id = 999 + scheduler_job = Job(executor=MockExecutor()) + self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull) + self.job_runner.processor_agent = mock.MagicMock() - session.add(ti) - session.flush() + # We will provision 2 tasks so we can check we only find zombies from this scheduler + tasks_to_setup = ["branching", "run_this_first"] - assert task.task_id == "run_this_first" # Make sure we have the task/ti we expect + for task_id in tasks_to_setup: + task = dag.get_task(task_id=task_id) + ti = TaskInstance(task, run_id=dag_run.run_id, state=State.RUNNING) + ti.queued_by_job_id = 999 - ti.queued_by_job_id = scheduler_job.id - session.flush() + session.add(ti) + session.flush() - zombie_message = self.job_runner._generate_zombie_message_details(ti) - assert zombie_message == { - "DAG Id": "example_branch_operator", - "Task Id": "run_this_first", - "Run Id": "scheduled__2016-01-01T00:00:00+00:00", - } + assert task.task_id == "run_this_first" # Make sure we have the task/ti we expect - ti.hostname = "10.10.10.10" - ti.map_index = 2 - ti.external_executor_id = "abcdefg" - - zombie_message = self.job_runner._generate_zombie_message_details(ti) - assert zombie_message == { - "DAG Id": "example_branch_operator", - "Task Id": "run_this_first", - "Run Id": "scheduled__2016-01-01T00:00:00+00:00", - "Hostname": "10.10.10.10", - "Map Index": 2, - "External Executor Id": "abcdefg", - } + ti.queued_by_job_id = scheduler_job.id + session.flush() - def test_find_zombies_handle_failure_callbacks_are_correctly_passed_to_dag_processor( - self, testing_dag_bundle - ): + zombie_message = self.job_runner._generate_zombie_message_details(ti) + assert zombie_message == { + "DAG Id": "example_branch_operator", + "Task Id": "run_this_first", + "Run Id": "scheduled__2016-01-01T00:00:00+00:00", + } + + ti.hostname = "10.10.10.10" + ti.map_index = 2 + ti.external_executor_id = "abcdefg" + + zombie_message = self.job_runner._generate_zombie_message_details(ti) + assert zombie_message == { + "DAG Id": "example_branch_operator", + "Task Id": "run_this_first", + "Run Id": "scheduled__2016-01-01T00:00:00+00:00", + "Hostname": "10.10.10.10", + "Map Index": 2, + "External Executor Id": "abcdefg", + } + + def test_find_zombies_handle_failure_callbacks_are_correctly_passed_to_dag_processor(self): """ Check that the same set of failure callback with zombies are passed to the dag file processors until the next zombie detection logic is invoked. @@ -5747,7 +5784,7 @@ def test_find_zombies_handle_failure_callbacks_are_correctly_passed_to_dag_proce ) session.query(Job).delete() dag = dagbag.get_dag("test_example_bash_operator") - DAG.bulk_write_to_db("testing", None, [dag]) + dag.sync_to_db() data_interval = dag.infer_automated_data_interval(DEFAULT_LOGICAL_DATE) triggered_by_kwargs = {"triggered_by": DagRunTriggeredByType.TEST} if AIRFLOW_V_3_0_PLUS else {} dag_run = dag.create_dagrun( @@ -5793,11 +5830,11 @@ def test_find_zombies_handle_failure_callbacks_are_correctly_passed_to_dag_proce callback_requests[0].ti = None assert expected_failure_callback_requests[0] == callback_requests[0] - def test_cleanup_stale_dags(self, testing_dag_bundle): + def test_cleanup_stale_dags(self): dagbag = DagBag(TEST_DAG_FOLDER, read_dags_from_db=False) with create_session() as session: dag = dagbag.get_dag("test_example_bash_operator") - DAG.bulk_write_to_db("testing", None, [dag]) + dag.sync_to_db() dm = DagModel.get_current("test_example_bash_operator") # Make it "stale". dm.last_parsed_time = timezone.utcnow() - timedelta(minutes=11) @@ -5805,7 +5842,7 @@ def test_cleanup_stale_dags(self, testing_dag_bundle): # This one should remain active. dag = dagbag.get_dag("test_start_date_scheduling") - DAG.bulk_write_to_db("testing", None, [dag]) + dag.sync_to_db() session.flush() @@ -5887,13 +5924,13 @@ def watch_heartbeat(*args, **kwargs): @pytest.mark.long_running @pytest.mark.parametrize("dag_id", ["test_mapped_classic", "test_mapped_taskflow"]) - def test_mapped_dag(self, dag_id, session, testing_dag_bundle): + def test_mapped_dag(self, dag_id, session): """End-to-end test of a simple mapped dag""" # Use SequentialExecutor for more predictable test behaviour from airflow.executors.sequential_executor import SequentialExecutor dagbag = DagBag(dag_folder=TEST_DAGS_FOLDER, include_examples=False) - dagbag.sync_to_db("testing", None) + dagbag.sync_to_db() dagbag.process_file(str(TEST_DAGS_FOLDER / f"{dag_id}.py")) dag = dagbag.get_dag(dag_id) assert dag @@ -5920,14 +5957,14 @@ def test_mapped_dag(self, dag_id, session, testing_dag_bundle): dr.refresh_from_db(session) assert dr.state == DagRunState.SUCCESS - def test_should_mark_empty_task_as_success(self, testing_dag_bundle): + def test_should_mark_empty_task_as_success(self): dag_file = os.path.join( os.path.dirname(os.path.realpath(__file__)), "../dags/test_only_empty_tasks.py" ) # Write DAGs to dag and serialized_dag table dagbag = DagBag(dag_folder=dag_file, include_examples=False, read_dags_from_db=False) - dagbag.sync_to_db("testing", None) + dagbag.sync_to_db() scheduler_job = Job() self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull) @@ -5998,7 +6035,7 @@ def test_should_mark_empty_task_as_success(self, testing_dag_bundle): assert duration is None @pytest.mark.need_serialized_dag - def test_catchup_works_correctly(self, dag_maker, testing_dag_bundle): + def test_catchup_works_correctly(self, dag_maker): """Test that catchup works correctly""" session = settings.Session() with dag_maker( @@ -6033,7 +6070,7 @@ def test_catchup_works_correctly(self, dag_maker, testing_dag_bundle): session.flush() dag.catchup = False - DAG.bulk_write_to_db("testing", None, [dag]) + dag.sync_to_db() assert not dag.catchup dm = DagModel.get_dagmodel(dag.dag_id) @@ -6242,7 +6279,7 @@ def test_misconfigured_dags_doesnt_crash_scheduler(self, session, dag_maker, cap # Check if the second dagrun was created assert DagRun.find(dag_id="testdag2", session=session) - def test_activate_referenced_assets_with_no_existing_warning(self, session, testing_dag_bundle): + def test_activate_referenced_assets_with_no_existing_warning(self, session): dag_warnings = session.query(DagWarning).all() assert dag_warnings == [] @@ -6255,7 +6292,7 @@ def test_activate_referenced_assets_with_no_existing_warning(self, session, test asset1_2 = Asset(name="it's also a duplicate", uri="s3://bucket/key/1", extra=asset_extra) dag1 = DAG(dag_id=dag_id1, start_date=DEFAULT_DATE, schedule=[asset1, asset1_1, asset1_2]) - DAG.bulk_write_to_db("testing", None, [dag1], session=session) + DAG.bulk_write_to_db([dag1], session=session) asset_models = session.scalars(select(AssetModel)).all() assert len(asset_models) == 3 @@ -6276,7 +6313,7 @@ def test_activate_referenced_assets_with_no_existing_warning(self, session, test "dy associated to 'asset1'" ) - def test_activate_referenced_assets_with_existing_warnings(self, session, testing_dag_bundle): + def test_activate_referenced_assets_with_existing_warnings(self, session): dag_ids = [f"test_asset_dag{i}" for i in range(1, 4)] asset1_name = "asset1" asset_extra = {"foo": "bar"} @@ -6295,7 +6332,7 @@ def test_activate_referenced_assets_with_existing_warnings(self, session, testin dag2 = DAG(dag_id=dag_ids[1], start_date=DEFAULT_DATE) dag3 = DAG(dag_id=dag_ids[2], start_date=DEFAULT_DATE, schedule=[asset1_2]) - DAG.bulk_write_to_db("testing", None, [dag1, dag2, dag3], session=session) + DAG.bulk_write_to_db([dag1, dag2, dag3], session=session) asset_models = session.scalars(select(AssetModel)).all() @@ -6329,9 +6366,7 @@ def test_activate_referenced_assets_with_existing_warnings(self, session, testin "name is already associated to 's3://bucket/key/1'" ) - def test_activate_referenced_assets_with_multiple_conflict_asset_in_one_dag( - self, session, testing_dag_bundle - ): + def test_activate_referenced_assets_with_multiple_conflict_asset_in_one_dag(self, session): dag_id = "test_asset_dag" asset1_name = "asset1" asset_extra = {"foo": "bar"} @@ -6344,7 +6379,7 @@ def test_activate_referenced_assets_with_multiple_conflict_asset_in_one_dag( ) dag1 = DAG(dag_id=dag_id, start_date=DEFAULT_DATE, schedule=schedule) - DAG.bulk_write_to_db("testing", None, [dag1], session=session) + DAG.bulk_write_to_db([dag1], session=session) asset_models = session.scalars(select(AssetModel)).all() @@ -6434,9 +6469,7 @@ def per_test(self) -> Generator: (93, 10, 10), # 10 DAGs with 10 tasks per DAG file. ], ) - def test_execute_queries_count_with_harvested_dags( - self, expected_query_count, dag_count, task_count, testing_dag_bundle - ): + def test_execute_queries_count_with_harvested_dags(self, expected_query_count, dag_count, task_count): with ( mock.patch.dict( "os.environ", @@ -6463,7 +6496,7 @@ def test_execute_queries_count_with_harvested_dags( ): dagruns = [] dagbag = DagBag(dag_folder=ELASTIC_DAG_FILE, include_examples=False, read_dags_from_db=False) - dagbag.sync_to_db("testing", None) + dagbag.sync_to_db() dag_ids = dagbag.dag_ids dagbag = DagBag(read_dags_from_db=True) @@ -6533,7 +6566,7 @@ def test_execute_queries_count_with_harvested_dags( ], ) def test_process_dags_queries_count( - self, expected_query_counts, dag_count, task_count, start_ago, schedule, shape, testing_dag_bundle + self, expected_query_counts, dag_count, task_count, start_ago, schedule, shape ): with ( mock.patch.dict( @@ -6559,7 +6592,7 @@ def test_process_dags_queries_count( ), ): dagbag = DagBag(dag_folder=ELASTIC_DAG_FILE, include_examples=False) - dagbag.sync_to_db("testing", None) + dagbag.sync_to_db() mock_agent = mock.MagicMock() diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py index a63b05b0cd343..6eaa4e3ac3aeb 100644 --- a/tests/models/test_dag.py +++ b/tests/models/test_dag.py @@ -623,7 +623,7 @@ def test_dagtag_repr(self): repr(t) for t in session.query(DagTag).filter(DagTag.dag_id == "dag-test-dagtag").all() } - def test_bulk_write_to_db(self, testing_dag_bundle): + def test_bulk_write_to_db(self): clear_db_dags() dags = [ DAG(f"dag-bulk-sync-{i}", schedule=None, start_date=DEFAULT_DATE, tags=["test-dag"]) @@ -631,7 +631,7 @@ def test_bulk_write_to_db(self, testing_dag_bundle): ] with assert_queries_count(6): - DAG.bulk_write_to_db("testing", None, dags) + DAG.bulk_write_to_db(dags) with create_session() as session: assert {"dag-bulk-sync-0", "dag-bulk-sync-1", "dag-bulk-sync-2", "dag-bulk-sync-3"} == { row[0] for row in session.query(DagModel.dag_id).all() @@ -648,14 +648,14 @@ def test_bulk_write_to_db(self, testing_dag_bundle): # Re-sync should do fewer queries with assert_queries_count(9): - DAG.bulk_write_to_db("testing", None, dags) + DAG.bulk_write_to_db(dags) with assert_queries_count(9): - DAG.bulk_write_to_db("testing", None, dags) + DAG.bulk_write_to_db(dags) # Adding tags for dag in dags: dag.tags.add("test-dag2") with assert_queries_count(10): - DAG.bulk_write_to_db("testing", None, dags) + DAG.bulk_write_to_db(dags) with create_session() as session: assert {"dag-bulk-sync-0", "dag-bulk-sync-1", "dag-bulk-sync-2", "dag-bulk-sync-3"} == { row[0] for row in session.query(DagModel.dag_id).all() @@ -674,7 +674,7 @@ def test_bulk_write_to_db(self, testing_dag_bundle): for dag in dags: dag.tags.remove("test-dag") with assert_queries_count(10): - DAG.bulk_write_to_db("testing", None, dags) + DAG.bulk_write_to_db(dags) with create_session() as session: assert {"dag-bulk-sync-0", "dag-bulk-sync-1", "dag-bulk-sync-2", "dag-bulk-sync-3"} == { row[0] for row in session.query(DagModel.dag_id).all() @@ -693,7 +693,7 @@ def test_bulk_write_to_db(self, testing_dag_bundle): for dag in dags: dag.tags = set() with assert_queries_count(10): - DAG.bulk_write_to_db("testing", None, dags) + DAG.bulk_write_to_db(dags) with create_session() as session: assert {"dag-bulk-sync-0", "dag-bulk-sync-1", "dag-bulk-sync-2", "dag-bulk-sync-3"} == { row[0] for row in session.query(DagModel.dag_id).all() @@ -703,7 +703,7 @@ def test_bulk_write_to_db(self, testing_dag_bundle): for row in session.query(DagModel.last_parsed_time).all(): assert row[0] is not None - def test_bulk_write_to_db_single_dag(self, testing_dag_bundle): + def test_bulk_write_to_db_single_dag(self): """ Test bulk_write_to_db for a single dag using the index optimized query """ @@ -714,7 +714,7 @@ def test_bulk_write_to_db_single_dag(self, testing_dag_bundle): ] with assert_queries_count(6): - DAG.bulk_write_to_db("testing", None, dags) + DAG.bulk_write_to_db(dags) with create_session() as session: assert {"dag-bulk-sync-0"} == {row[0] for row in session.query(DagModel.dag_id).all()} assert { @@ -726,11 +726,11 @@ def test_bulk_write_to_db_single_dag(self, testing_dag_bundle): # Re-sync should do fewer queries with assert_queries_count(8): - DAG.bulk_write_to_db("testing", None, dags) + DAG.bulk_write_to_db(dags) with assert_queries_count(8): - DAG.bulk_write_to_db("testing", None, dags) + DAG.bulk_write_to_db(dags) - def test_bulk_write_to_db_multiple_dags(self, testing_dag_bundle): + def test_bulk_write_to_db_multiple_dags(self): """ Test bulk_write_to_db for multiple dags which does not use the index optimized query """ @@ -741,7 +741,7 @@ def test_bulk_write_to_db_multiple_dags(self, testing_dag_bundle): ] with assert_queries_count(6): - DAG.bulk_write_to_db("testing", None, dags) + DAG.bulk_write_to_db(dags) with create_session() as session: assert {"dag-bulk-sync-0", "dag-bulk-sync-1", "dag-bulk-sync-2", "dag-bulk-sync-3"} == { row[0] for row in session.query(DagModel.dag_id).all() @@ -758,26 +758,26 @@ def test_bulk_write_to_db_multiple_dags(self, testing_dag_bundle): # Re-sync should do fewer queries with assert_queries_count(9): - DAG.bulk_write_to_db("testing", None, dags) + DAG.bulk_write_to_db(dags) with assert_queries_count(9): - DAG.bulk_write_to_db("testing", None, dags) + DAG.bulk_write_to_db(dags) @pytest.mark.parametrize("interval", [None, "@daily"]) - def test_bulk_write_to_db_interval_save_runtime(self, testing_dag_bundle, interval): + def test_bulk_write_to_db_interval_save_runtime(self, interval): mock_active_runs_of_dags = mock.MagicMock(side_effect=DagRun.active_runs_of_dags) with mock.patch.object(DagRun, "active_runs_of_dags", mock_active_runs_of_dags): dags_null_timetable = [ DAG("dag-interval-None", schedule=None, start_date=TEST_DATE), DAG("dag-interval-test", schedule=interval, start_date=TEST_DATE), ] - DAG.bulk_write_to_db("testing", None, dags_null_timetable, session=settings.Session()) + DAG.bulk_write_to_db(dags_null_timetable, session=settings.Session()) if interval: mock_active_runs_of_dags.assert_called_once() else: mock_active_runs_of_dags.assert_not_called() @pytest.mark.parametrize("state", [DagRunState.RUNNING, DagRunState.QUEUED]) - def test_bulk_write_to_db_max_active_runs(self, testing_dag_bundle, state): + def test_bulk_write_to_db_max_active_runs(self, state): """ Test that DagModel.next_dagrun_create_after is set to NULL when the dag cannot be created due to max active runs being hit. @@ -793,7 +793,7 @@ def test_bulk_write_to_db_max_active_runs(self, testing_dag_bundle, state): session = settings.Session() dag.clear() - DAG.bulk_write_to_db("testing", None, [dag], session=session) + DAG.bulk_write_to_db([dag], session=session) model = session.get(DagModel, dag.dag_id) @@ -810,17 +810,17 @@ def test_bulk_write_to_db_max_active_runs(self, testing_dag_bundle, state): **triggered_by_kwargs, ) assert dr is not None - DAG.bulk_write_to_db("testing", None, [dag]) + DAG.bulk_write_to_db([dag]) model = session.get(DagModel, dag.dag_id) # We signal "at max active runs" by saying this run is never eligible to be created assert model.next_dagrun_create_after is None # test that bulk_write_to_db again doesn't update next_dagrun_create_after - DAG.bulk_write_to_db("testing", None, [dag]) + DAG.bulk_write_to_db([dag]) model = session.get(DagModel, dag.dag_id) assert model.next_dagrun_create_after is None - def test_bulk_write_to_db_has_import_error(self, testing_dag_bundle): + def test_bulk_write_to_db_has_import_error(self): """ Test that DagModel.has_import_error is set to false if no import errors. """ @@ -830,7 +830,7 @@ def test_bulk_write_to_db_has_import_error(self, testing_dag_bundle): session = settings.Session() dag.clear() - DAG.bulk_write_to_db("testing", None, [dag], session=session) + DAG.bulk_write_to_db([dag], session=session) model = session.get(DagModel, dag.dag_id) @@ -844,14 +844,14 @@ def test_bulk_write_to_db_has_import_error(self, testing_dag_bundle): # assert assert model.has_import_errors # parse - DAG.bulk_write_to_db("testing", None, [dag]) + DAG.bulk_write_to_db([dag]) model = session.get(DagModel, dag.dag_id) # assert that has_import_error is now false assert not model.has_import_errors session.close() - def test_bulk_write_to_db_assets(self, testing_dag_bundle): + def test_bulk_write_to_db_assets(self): """ Ensure that assets referenced in a dag are correctly loaded into the database. """ @@ -877,7 +877,7 @@ def test_bulk_write_to_db_assets(self, testing_dag_bundle): session = settings.Session() dag1.clear() - DAG.bulk_write_to_db("testing", None, [dag1, dag2], session=session) + DAG.bulk_write_to_db([dag1, dag2], session=session) session.commit() stored_assets = {x.uri: x for x in session.query(AssetModel).all()} asset1_orm = stored_assets[a1.uri] @@ -908,7 +908,7 @@ def test_bulk_write_to_db_assets(self, testing_dag_bundle): EmptyOperator(task_id=task_id, dag=dag1, outlets=[a2]) dag2 = DAG(dag_id=dag_id2, start_date=DEFAULT_DATE, schedule=None) EmptyOperator(task_id=task_id, dag=dag2) - DAG.bulk_write_to_db("testing", None, [dag1, dag2], session=session) + DAG.bulk_write_to_db([dag1, dag2], session=session) session.commit() session.expunge_all() stored_assets = {x.uri: x for x in session.query(AssetModel).all()} @@ -937,7 +937,7 @@ def _find_assets_activation(session) -> tuple[list[AssetModel], list[AssetModel] ).all() return [a for a, v in assets if not v], [a for a, v in assets if v] - def test_bulk_write_to_db_does_not_activate(self, dag_maker, testing_dag_bundle, session): + def test_bulk_write_to_db_does_not_activate(self, dag_maker, session): """ Assets are not activated on write, but later in the scheduler by the SchedulerJob. """ @@ -950,14 +950,14 @@ def test_bulk_write_to_db_does_not_activate(self, dag_maker, testing_dag_bundle, dag1 = DAG(dag_id="assets-1", start_date=DEFAULT_DATE, schedule=[asset1]) BashOperator(dag=dag1, task_id="task", bash_command="echo 1", outlets=[asset3]) - DAG.bulk_write_to_db("testing", None, [dag1], session=session) + DAG.bulk_write_to_db([dag1], session=session) assert session.scalars(select(AssetModel).order_by(AssetModel.uri)).all() == [asset1, asset3] assert session.scalars(select(AssetActive)).all() == [] dag1 = DAG(dag_id="assets-1", start_date=DEFAULT_DATE, schedule=[asset1, asset2]) BashOperator(dag=dag1, task_id="task", bash_command="echo 1", outlets=[asset3, asset4]) - DAG.bulk_write_to_db("testing", None, [dag1], session=session) + DAG.bulk_write_to_db([dag1], session=session) assert session.scalars(select(AssetModel).order_by(AssetModel.uri)).all() == [ asset1, @@ -967,7 +967,7 @@ def test_bulk_write_to_db_does_not_activate(self, dag_maker, testing_dag_bundle, ] assert session.scalars(select(AssetActive)).all() == [] - def test_bulk_write_to_db_asset_aliases(self, testing_dag_bundle): + def test_bulk_write_to_db_asset_aliases(self): """ Ensure that asset aliases referenced in a dag are correctly loaded into the database. """ @@ -983,7 +983,7 @@ def test_bulk_write_to_db_asset_aliases(self, testing_dag_bundle): dag2 = DAG(dag_id=dag_id2, start_date=DEFAULT_DATE, schedule=None) EmptyOperator(task_id=task_id, dag=dag2, outlets=[asset_alias_2_2, asset_alias_3]) session = settings.Session() - DAG.bulk_write_to_db("testing", None, [dag1, dag2], session=session) + DAG.bulk_write_to_db([dag1, dag2], session=session) session.commit() stored_asset_alias_models = {x.name: x for x in session.query(AssetAliasModel).all()} @@ -2431,7 +2431,7 @@ def test_dags_needing_dagruns_asset_triggered_dag_info_queued_times(self, sessio assert first_queued_time == DEFAULT_DATE assert last_queued_time == DEFAULT_DATE + timedelta(hours=1) - def test_asset_expression(self, testing_dag_bundle, session: Session) -> None: + def test_asset_expression(self, session: Session) -> None: dag = DAG( dag_id="test_dag_asset_expression", schedule=AssetAny( @@ -2449,7 +2449,7 @@ def test_asset_expression(self, testing_dag_bundle, session: Session) -> None: ), start_date=datetime.datetime.min, ) - DAG.bulk_write_to_db("testing", None, [dag], session=session) + DAG.bulk_write_to_db([dag], session=session) expression = session.scalars(select(DagModel.asset_expression).filter_by(dag_id=dag.dag_id)).one() assert expression == { diff --git a/tests/models/test_dagcode.py b/tests/models/test_dagcode.py index ac7330b1aa581..05598703476a5 100644 --- a/tests/models/test_dagcode.py +++ b/tests/models/test_dagcode.py @@ -41,17 +41,8 @@ def make_example_dags(module): """Loads DAGs from a module for test.""" - # TODO: AIP-66 dedup with tests/models/test_serdag - from airflow.models.dagbundle import DagBundleModel - from airflow.utils.session import create_session - - with create_session() as session: - if session.query(DagBundleModel).filter(DagBundleModel.name == "testing").count() == 0: - testing = DagBundleModel(name="testing") - session.add(testing) - dagbag = DagBag(module.__path__[0]) - DAG.bulk_write_to_db("testing", None, dagbag.dags.values()) + DAG.bulk_write_to_db(dagbag.dags.values()) return dagbag.dags diff --git a/tests/models/test_dagrun.py b/tests/models/test_dagrun.py index f398be8ad88c4..13614d6abeaea 100644 --- a/tests/models/test_dagrun.py +++ b/tests/models/test_dagrun.py @@ -518,7 +518,7 @@ def test_on_success_callback_when_task_skipped(self, session): assert dag_run.state == DagRunState.SUCCESS mock_on_success.assert_called_once() - def test_dagrun_update_state_with_handle_callback_success(self, testing_dag_bundle, session): + def test_dagrun_update_state_with_handle_callback_success(self, session): def on_success_callable(context): assert context["dag_run"].dag_id == "test_dagrun_update_state_with_handle_callback_success" @@ -528,7 +528,7 @@ def on_success_callable(context): start_date=datetime.datetime(2017, 1, 1), on_success_callback=on_success_callable, ) - DAG.bulk_write_to_db("testing", None, dags=[dag], session=session) + DAG.bulk_write_to_db(dags=[dag], session=session) dag_task1 = EmptyOperator(task_id="test_state_succeeded1", dag=dag) dag_task2 = EmptyOperator(task_id="test_state_succeeded2", dag=dag) @@ -556,7 +556,7 @@ def on_success_callable(context): msg="success", ) - def test_dagrun_update_state_with_handle_callback_failure(self, testing_dag_bundle, session): + def test_dagrun_update_state_with_handle_callback_failure(self, session): def on_failure_callable(context): assert context["dag_run"].dag_id == "test_dagrun_update_state_with_handle_callback_failure" @@ -566,7 +566,7 @@ def on_failure_callable(context): start_date=datetime.datetime(2017, 1, 1), on_failure_callback=on_failure_callable, ) - DAG.bulk_write_to_db("testing", None, dags=[dag], session=session) + DAG.bulk_write_to_db(dags=[dag], session=session) dag_task1 = EmptyOperator(task_id="test_state_succeeded1", dag=dag) dag_task2 = EmptyOperator(task_id="test_state_failed2", dag=dag) diff --git a/tests/models/test_serialized_dag.py b/tests/models/test_serialized_dag.py index 94835fbd5e5d7..60d26959b079b 100644 --- a/tests/models/test_serialized_dag.py +++ b/tests/models/test_serialized_dag.py @@ -49,16 +49,8 @@ # To move it to a shared module. def make_example_dags(module): """Loads DAGs from a module for test.""" - from airflow.models.dagbundle import DagBundleModel - from airflow.utils.session import create_session - - with create_session() as session: - if session.query(DagBundleModel).filter(DagBundleModel.name == "testing").count() == 0: - testing = DagBundleModel(name="testing") - session.add(testing) - dagbag = DagBag(module.__path__[0]) - DAG.bulk_write_to_db("testing", None, dagbag.dags.values()) + DAG.bulk_write_to_db(dagbag.dags.values()) return dagbag.dags @@ -185,13 +177,13 @@ def test_read_all_dags_only_picks_the_latest_serdags(self, session): # assert only the latest SDM is returned assert len(sdags) != len(serialized_dags2) - def test_bulk_sync_to_db(self, testing_dag_bundle): + def test_bulk_sync_to_db(self): dags = [ DAG("dag_1", schedule=None), DAG("dag_2", schedule=None), DAG("dag_3", schedule=None), ] - DAG.bulk_write_to_db("testing", None, dags) + DAG.bulk_write_to_db(dags) # we also write to dag_version and dag_code tables # in dag_version. with assert_queries_count(24): diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py index bb3edc53cdd28..73f5908b707cf 100644 --- a/tests/models/test_taskinstance.py +++ b/tests/models/test_taskinstance.py @@ -2280,7 +2280,7 @@ def test_success_callback_no_race_condition(self, create_task_instance): ti.refresh_from_db() assert ti.state == State.SUCCESS - def test_outlet_assets(self, create_task_instance, testing_dag_bundle): + def test_outlet_assets(self, create_task_instance): """ Verify that when we have an outlet asset on a task, and the task completes successfully, an AssetDagRunQueue is logged. @@ -2291,7 +2291,7 @@ def test_outlet_assets(self, create_task_instance, testing_dag_bundle): session = settings.Session() dagbag = DagBag(dag_folder=example_assets.__file__) dagbag.collect_dags(only_if_updated=False, safe_mode=False) - dagbag.sync_to_db("testing", None, session=session) + dagbag.sync_to_db(session=session) asset_models = session.scalars(select(AssetModel)).all() SchedulerJobRunner._activate_referenced_assets(asset_models, session=session) @@ -2343,7 +2343,7 @@ def test_outlet_assets(self, create_task_instance, testing_dag_bundle): event.timestamp < adrq_timestamp for (adrq_timestamp,) in adrq_timestamps ), f"Some items in {[str(t) for t in adrq_timestamps]} are earlier than {event.timestamp}" - def test_outlet_assets_failed(self, create_task_instance, testing_dag_bundle): + def test_outlet_assets_failed(self, create_task_instance): """ Verify that when we have an outlet asset on a task, and the task failed, an AssetDagRunQueue is not logged, and an AssetEvent is @@ -2355,7 +2355,7 @@ def test_outlet_assets_failed(self, create_task_instance, testing_dag_bundle): session = settings.Session() dagbag = DagBag(dag_folder=test_assets.__file__) dagbag.collect_dags(only_if_updated=False, safe_mode=False) - dagbag.sync_to_db("testing", None, session=session) + dagbag.sync_to_db(session=session) run_id = str(uuid4()) dr = DagRun(dag_with_fail_task.dag_id, run_id=run_id, run_type="anything") session.merge(dr) @@ -2397,7 +2397,7 @@ def raise_an_exception(placeholder: int): task_instance.run() assert task_instance.current_state() == TaskInstanceState.SUCCESS - def test_outlet_assets_skipped(self, testing_dag_bundle): + def test_outlet_assets_skipped(self): """ Verify that when we have an outlet asset on a task, and the task is skipped, an AssetDagRunQueue is not logged, and an AssetEvent is @@ -2409,7 +2409,7 @@ def test_outlet_assets_skipped(self, testing_dag_bundle): session = settings.Session() dagbag = DagBag(dag_folder=test_assets.__file__) dagbag.collect_dags(only_if_updated=False, safe_mode=False) - dagbag.sync_to_db("testing", None, session=session) + dagbag.sync_to_db(session=session) asset_models = session.scalars(select(AssetModel)).all() SchedulerJobRunner._activate_referenced_assets(asset_models, session=session) diff --git a/tests/operators/test_trigger_dagrun.py b/tests/operators/test_trigger_dagrun.py index 74d8371e5d78e..a8a6b3c262903 100644 --- a/tests/operators/test_trigger_dagrun.py +++ b/tests/operators/test_trigger_dagrun.py @@ -26,6 +26,7 @@ from airflow.exceptions import AirflowException, DagRunAlreadyExists, TaskDeferred from airflow.models.dag import DagModel +from airflow.models.dagbag import DagBag from airflow.models.dagrun import DagRun from airflow.models.log import Log from airflow.models.taskinstance import TaskInstance @@ -36,8 +37,6 @@ from airflow.utils.state import DagRunState, State, TaskInstanceState from airflow.utils.types import DagRunType -from tests_common.test_utils.db import parse_and_sync_to_db - pytestmark = pytest.mark.db_test DEFAULT_DATE = datetime(2019, 1, 1, tzinfo=timezone.utc) @@ -72,6 +71,11 @@ def setup_method(self): session.add(DagModel(dag_id=TRIGGERED_DAG_ID, fileloc=self._tmpfile)) session.commit() + def re_sync_triggered_dag_to_db(self, dag, dag_maker): + dagbag = DagBag(self.f_name, read_dags_from_db=False, include_examples=False) + dagbag.bag_dag(dag) + dagbag.sync_to_db(session=dag_maker.session) + def teardown_method(self): """Cleanup state after testing in DB.""" with create_session() as session: @@ -116,10 +120,9 @@ def test_trigger_dagrun(self, dag_maker): """Test TriggerDagRunOperator.""" with dag_maker( TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True - ): + ) as dag: task = TriggerDagRunOperator(task_id="test_task", trigger_dag_id=TRIGGERED_DAG_ID) - dag_maker.sync_dagbag_to_db() - parse_and_sync_to_db(self.f_name) + self.re_sync_triggered_dag_to_db(dag, dag_maker) dag_maker.create_dagrun() task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) @@ -131,14 +134,13 @@ def test_trigger_dagrun(self, dag_maker): def test_trigger_dagrun_custom_run_id(self, dag_maker): with dag_maker( TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True - ): + ) as dag: task = TriggerDagRunOperator( task_id="test_task", trigger_dag_id=TRIGGERED_DAG_ID, trigger_run_id="custom_run_id", ) - dag_maker.sync_dagbag_to_db() - parse_and_sync_to_db(self.f_name) + self.re_sync_triggered_dag_to_db(dag, dag_maker) dag_maker.create_dagrun() task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) @@ -152,14 +154,13 @@ def test_trigger_dagrun_with_logical_date(self, dag_maker): custom_logical_date = timezone.datetime(2021, 1, 2, 3, 4, 5) with dag_maker( TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True - ): + ) as dag: task = TriggerDagRunOperator( task_id="test_trigger_dagrun_with_logical_date", trigger_dag_id=TRIGGERED_DAG_ID, logical_date=custom_logical_date, ) - dag_maker.sync_dagbag_to_db() - parse_and_sync_to_db(self.f_name) + self.re_sync_triggered_dag_to_db(dag, dag_maker) dag_maker.create_dagrun() task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) @@ -176,7 +177,7 @@ def test_trigger_dagrun_twice(self, dag_maker): run_id = f"manual__{utc_now.isoformat()}" with dag_maker( TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True - ): + ) as dag: task = TriggerDagRunOperator( task_id="test_trigger_dagrun_with_logical_date", trigger_dag_id=TRIGGERED_DAG_ID, @@ -186,8 +187,7 @@ def test_trigger_dagrun_twice(self, dag_maker): reset_dag_run=True, wait_for_completion=True, ) - dag_maker.sync_dagbag_to_db() - parse_and_sync_to_db(self.f_name) + self.re_sync_triggered_dag_to_db(dag, dag_maker) dag_maker.create_dagrun() dag_run = DagRun( dag_id=TRIGGERED_DAG_ID, @@ -213,7 +213,7 @@ def test_trigger_dagrun_with_scheduled_dag_run(self, dag_maker): run_id = f"scheduled__{utc_now.isoformat()}" with dag_maker( TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True - ): + ) as dag: task = TriggerDagRunOperator( task_id="test_trigger_dagrun_with_logical_date", trigger_dag_id=TRIGGERED_DAG_ID, @@ -223,8 +223,7 @@ def test_trigger_dagrun_with_scheduled_dag_run(self, dag_maker): reset_dag_run=True, wait_for_completion=True, ) - dag_maker.sync_dagbag_to_db() - parse_and_sync_to_db(self.f_name) + self.re_sync_triggered_dag_to_db(dag, dag_maker) dag_maker.create_dagrun() run_id = f"scheduled__{utc_now.isoformat()}" dag_run = DagRun( @@ -249,14 +248,13 @@ def test_trigger_dagrun_with_templated_logical_date(self, dag_maker): """Test TriggerDagRunOperator with templated logical_date.""" with dag_maker( TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True - ): + ) as dag: task = TriggerDagRunOperator( task_id="test_trigger_dagrun_with_str_logical_date", trigger_dag_id=TRIGGERED_DAG_ID, logical_date="{{ logical_date }}", ) - dag_maker.sync_dagbag_to_db() - parse_and_sync_to_db(self.f_name) + self.re_sync_triggered_dag_to_db(dag, dag_maker) dag_maker.create_dagrun() task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) @@ -272,13 +270,12 @@ def test_trigger_dagrun_with_templated_trigger_dag_id(self, dag_maker): """Test TriggerDagRunOperator with templated trigger dag id.""" with dag_maker( TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True - ): + ) as dag: task = TriggerDagRunOperator( task_id="__".join(["test_trigger_dagrun_with_templated_trigger_dag_id", TRIGGERED_DAG_ID]), trigger_dag_id="{{ ti.task_id.rsplit('.', 1)[-1].split('__')[-1] }}", ) - dag_maker.sync_dagbag_to_db() - parse_and_sync_to_db(self.f_name) + self.re_sync_triggered_dag_to_db(dag, dag_maker) dag_maker.create_dagrun() task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) @@ -294,14 +291,13 @@ def test_trigger_dagrun_operator_conf(self, dag_maker): """Test passing conf to the triggered DagRun.""" with dag_maker( TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True - ): + ) as dag: task = TriggerDagRunOperator( task_id="test_trigger_dagrun_with_str_logical_date", trigger_dag_id=TRIGGERED_DAG_ID, conf={"foo": "bar"}, ) - dag_maker.sync_dagbag_to_db() - parse_and_sync_to_db(self.f_name) + self.re_sync_triggered_dag_to_db(dag, dag_maker) dag_maker.create_dagrun() task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) @@ -314,14 +310,13 @@ def test_trigger_dagrun_operator_templated_invalid_conf(self, dag_maker): """Test passing a conf that is not JSON Serializable raise error.""" with dag_maker( TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True - ): + ) as dag: task = TriggerDagRunOperator( task_id="test_trigger_dagrun_with_invalid_conf", trigger_dag_id=TRIGGERED_DAG_ID, conf={"foo": "{{ dag.dag_id }}", "datetime": timezone.utcnow()}, ) - dag_maker.sync_dagbag_to_db() - parse_and_sync_to_db(self.f_name) + self.re_sync_triggered_dag_to_db(dag, dag_maker) dag_maker.create_dagrun() with pytest.raises(AirflowException, match="^conf parameter should be JSON Serializable$"): task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) @@ -330,14 +325,13 @@ def test_trigger_dagrun_operator_templated_conf(self, dag_maker): """Test passing a templated conf to the triggered DagRun.""" with dag_maker( TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True - ): + ) as dag: task = TriggerDagRunOperator( task_id="test_trigger_dagrun_with_str_logical_date", trigger_dag_id=TRIGGERED_DAG_ID, conf={"foo": "{{ dag.dag_id }}"}, ) - dag_maker.sync_dagbag_to_db() - parse_and_sync_to_db(self.f_name) + self.re_sync_triggered_dag_to_db(dag, dag_maker) dag_maker.create_dagrun() task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) @@ -351,7 +345,7 @@ def test_trigger_dagrun_with_reset_dag_run_false(self, dag_maker): logical_date = DEFAULT_DATE with dag_maker( TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True - ): + ) as dag: task = TriggerDagRunOperator( task_id="test_task", trigger_dag_id=TRIGGERED_DAG_ID, @@ -359,8 +353,7 @@ def test_trigger_dagrun_with_reset_dag_run_false(self, dag_maker): logical_date=None, reset_dag_run=False, ) - dag_maker.sync_dagbag_to_db() - parse_and_sync_to_db(self.f_name) + self.re_sync_triggered_dag_to_db(dag, dag_maker) dag_maker.create_dagrun() task.run(start_date=logical_date, end_date=logical_date, ignore_ti_state=True) task.run(start_date=logical_date, end_date=logical_date, ignore_ti_state=True) @@ -384,7 +377,7 @@ def test_trigger_dagrun_with_reset_dag_run_false_fail( logical_date = DEFAULT_DATE with dag_maker( TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True - ): + ) as dag: task = TriggerDagRunOperator( task_id="test_task", trigger_dag_id=TRIGGERED_DAG_ID, @@ -392,8 +385,7 @@ def test_trigger_dagrun_with_reset_dag_run_false_fail( logical_date=trigger_logical_date, reset_dag_run=False, ) - dag_maker.sync_dagbag_to_db() - parse_and_sync_to_db(self.f_name) + self.re_sync_triggered_dag_to_db(dag, dag_maker) dag_maker.create_dagrun() task.run(start_date=logical_date, end_date=logical_date, ignore_ti_state=True) @@ -405,7 +397,7 @@ def test_trigger_dagrun_with_skip_when_already_exists(self, dag_maker): logical_date = DEFAULT_DATE with dag_maker( TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True - ): + ) as dag: task = TriggerDagRunOperator( task_id="test_task", trigger_dag_id=TRIGGERED_DAG_ID, @@ -413,8 +405,7 @@ def test_trigger_dagrun_with_skip_when_already_exists(self, dag_maker): reset_dag_run=False, skip_when_already_exists=True, ) - dag_maker.sync_dagbag_to_db() - parse_and_sync_to_db(self.f_name) + self.re_sync_triggered_dag_to_db(dag, dag_maker) dr: DagRun = dag_maker.create_dagrun() task.run(start_date=logical_date, end_date=logical_date, ignore_ti_state=True) assert dr.get_task_instance("test_task").state == TaskInstanceState.SUCCESS @@ -437,7 +428,7 @@ def test_trigger_dagrun_with_reset_dag_run_true( logical_date = DEFAULT_DATE with dag_maker( TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True - ): + ) as dag: task = TriggerDagRunOperator( task_id="test_task", trigger_dag_id=TRIGGERED_DAG_ID, @@ -445,8 +436,7 @@ def test_trigger_dagrun_with_reset_dag_run_true( logical_date=trigger_logical_date, reset_dag_run=True, ) - dag_maker.sync_dagbag_to_db() - parse_and_sync_to_db(self.f_name) + self.re_sync_triggered_dag_to_db(dag, dag_maker) dag_maker.create_dagrun() task.run(start_date=logical_date, end_date=logical_date, ignore_ti_state=True) task.run(start_date=logical_date, end_date=logical_date, ignore_ti_state=True) @@ -461,7 +451,7 @@ def test_trigger_dagrun_with_wait_for_completion_true(self, dag_maker): logical_date = DEFAULT_DATE with dag_maker( TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True - ): + ) as dag: task = TriggerDagRunOperator( task_id="test_task", trigger_dag_id=TRIGGERED_DAG_ID, @@ -470,8 +460,7 @@ def test_trigger_dagrun_with_wait_for_completion_true(self, dag_maker): poke_interval=10, allowed_states=[State.QUEUED], ) - dag_maker.sync_dagbag_to_db() - parse_and_sync_to_db(self.f_name) + self.re_sync_triggered_dag_to_db(dag, dag_maker) dag_maker.create_dagrun() task.run(start_date=logical_date, end_date=logical_date) @@ -484,7 +473,7 @@ def test_trigger_dagrun_with_wait_for_completion_true_fail(self, dag_maker): logical_date = DEFAULT_DATE with dag_maker( TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True - ): + ) as dag: task = TriggerDagRunOperator( task_id="test_task", trigger_dag_id=TRIGGERED_DAG_ID, @@ -493,8 +482,7 @@ def test_trigger_dagrun_with_wait_for_completion_true_fail(self, dag_maker): poke_interval=10, failed_states=[State.QUEUED], ) - dag_maker.sync_dagbag_to_db() - parse_and_sync_to_db(self.f_name) + self.re_sync_triggered_dag_to_db(dag, dag_maker) dag_maker.create_dagrun() with pytest.raises(AirflowException): task.run(start_date=logical_date, end_date=logical_date) @@ -504,13 +492,12 @@ def test_trigger_dagrun_triggering_itself(self, dag_maker): logical_date = DEFAULT_DATE with dag_maker( TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True - ): + ) as dag: task = TriggerDagRunOperator( task_id="test_task", trigger_dag_id=TEST_DAG_ID, ) - dag_maker.sync_dagbag_to_db() - parse_and_sync_to_db(self.f_name) + self.re_sync_triggered_dag_to_db(dag, dag_maker) dag_maker.create_dagrun() task.run(start_date=logical_date, end_date=logical_date) @@ -529,7 +516,7 @@ def test_trigger_dagrun_with_wait_for_completion_true_defer_false(self, dag_make logical_date = DEFAULT_DATE with dag_maker( TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True - ): + ) as dag: task = TriggerDagRunOperator( task_id="test_task", trigger_dag_id=TRIGGERED_DAG_ID, @@ -539,8 +526,7 @@ def test_trigger_dagrun_with_wait_for_completion_true_defer_false(self, dag_make allowed_states=[State.QUEUED], deferrable=False, ) - dag_maker.sync_dagbag_to_db() - parse_and_sync_to_db(self.f_name) + self.re_sync_triggered_dag_to_db(dag, dag_maker) dag_maker.create_dagrun() task.run(start_date=logical_date, end_date=logical_date) @@ -553,7 +539,7 @@ def test_trigger_dagrun_with_wait_for_completion_true_defer_true(self, dag_maker logical_date = DEFAULT_DATE with dag_maker( TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True - ): + ) as dag: task = TriggerDagRunOperator( task_id="test_task", trigger_dag_id=TRIGGERED_DAG_ID, @@ -563,8 +549,7 @@ def test_trigger_dagrun_with_wait_for_completion_true_defer_true(self, dag_maker allowed_states=[State.QUEUED], deferrable=True, ) - dag_maker.sync_dagbag_to_db() - parse_and_sync_to_db(self.f_name) + self.re_sync_triggered_dag_to_db(dag, dag_maker) dag_maker.create_dagrun() task.run(start_date=logical_date, end_date=logical_date) @@ -586,7 +571,7 @@ def test_trigger_dagrun_with_wait_for_completion_true_defer_true_failure(self, d logical_date = DEFAULT_DATE with dag_maker( TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True - ): + ) as dag: task = TriggerDagRunOperator( task_id="test_task", trigger_dag_id=TRIGGERED_DAG_ID, @@ -596,8 +581,7 @@ def test_trigger_dagrun_with_wait_for_completion_true_defer_true_failure(self, d allowed_states=[State.SUCCESS], deferrable=True, ) - dag_maker.sync_dagbag_to_db() - parse_and_sync_to_db(self.f_name) + self.re_sync_triggered_dag_to_db(dag, dag_maker) dag_maker.create_dagrun() task.run(start_date=logical_date, end_date=logical_date) @@ -623,7 +607,7 @@ def test_trigger_dagrun_with_wait_for_completion_true_defer_true_failure_2(self, logical_date = DEFAULT_DATE with dag_maker( TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True - ): + ) as dag: task = TriggerDagRunOperator( task_id="test_task", trigger_dag_id=TRIGGERED_DAG_ID, @@ -634,8 +618,7 @@ def test_trigger_dagrun_with_wait_for_completion_true_defer_true_failure_2(self, failed_states=[State.QUEUED], deferrable=True, ) - dag_maker.sync_dagbag_to_db() - parse_and_sync_to_db(self.f_name) + self.re_sync_triggered_dag_to_db(dag, dag_maker) dag_maker.create_dagrun() task.run(start_date=logical_date, end_date=logical_date) @@ -665,7 +648,7 @@ def test_dagstatetrigger_logical_dates(self, trigger_logical_date, dag_maker): """Ensure that the DagStateTrigger is called with the triggered DAG's logical date.""" with dag_maker( TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True - ): + ) as dag: task = TriggerDagRunOperator( task_id="test_task", trigger_dag_id=TRIGGERED_DAG_ID, @@ -675,8 +658,7 @@ def test_dagstatetrigger_logical_dates(self, trigger_logical_date, dag_maker): allowed_states=[DagRunState.QUEUED], deferrable=True, ) - dag_maker.sync_dagbag_to_db() - parse_and_sync_to_db(self.f_name) + self.re_sync_triggered_dag_to_db(dag, dag_maker) dag_maker.create_dagrun() mock_task_defer = mock.MagicMock(side_effect=task.defer) @@ -695,7 +677,7 @@ def test_dagstatetrigger_logical_dates_with_clear_and_reset(self, dag_maker): """Check DagStateTrigger is called with the triggered DAG's logical date on subsequent defers.""" with dag_maker( TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True - ): + ) as dag: task = TriggerDagRunOperator( task_id="test_task", trigger_dag_id=TRIGGERED_DAG_ID, @@ -706,8 +688,7 @@ def test_dagstatetrigger_logical_dates_with_clear_and_reset(self, dag_maker): deferrable=True, reset_dag_run=True, ) - dag_maker.sync_dagbag_to_db() - parse_and_sync_to_db(self.f_name) + self.re_sync_triggered_dag_to_db(dag, dag_maker) dag_maker.create_dagrun() mock_task_defer = mock.MagicMock(side_effect=task.defer) @@ -746,7 +727,7 @@ def test_trigger_dagrun_with_no_failed_state(self, dag_maker): logical_date = DEFAULT_DATE with dag_maker( TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True - ): + ) as dag: task = TriggerDagRunOperator( task_id="test_task", trigger_dag_id=TRIGGERED_DAG_ID, @@ -755,8 +736,7 @@ def test_trigger_dagrun_with_no_failed_state(self, dag_maker): poke_interval=10, failed_states=[], ) - dag_maker.sync_dagbag_to_db() - parse_and_sync_to_db(self.f_name) + self.re_sync_triggered_dag_to_db(dag, dag_maker) dag_maker.create_dagrun() assert task.failed_states == [] diff --git a/tests/sensors/test_external_task_sensor.py b/tests/sensors/test_external_task_sensor.py index c73980ca24bd5..ba3cfd6b4480c 100644 --- a/tests/sensors/test_external_task_sensor.py +++ b/tests/sensors/test_external_task_sensor.py @@ -80,7 +80,7 @@ def clean_db(): @pytest.fixture -def dag_zip_maker(testing_dag_bundle): +def dag_zip_maker(): class DagZipMaker: def __call__(self, *dag_files): self.__dag_files = [os.sep.join([TEST_DAGS_FOLDER.__str__(), dag_file]) for dag_file in dag_files] @@ -98,7 +98,7 @@ def __enter__(self): for dag_file in self.__dag_files: zf.write(dag_file, os.path.basename(dag_file)) dagbag = DagBag(dag_folder=self.__tmp_dir, include_examples=False) - dagbag.sync_to_db("testing", None) + dagbag.sync_to_db() return dagbag def __exit__(self, exc_type, exc_val, exc_tb): diff --git a/tests/utils/test_file.py b/tests/utils/test_file.py index 6b5b5b550047c..da1c6fa2fc87f 100644 --- a/tests/utils/test_file.py +++ b/tests/utils/test_file.py @@ -25,18 +25,11 @@ import pytest from airflow.utils import file as file_utils -from airflow.utils.file import ( - correct_maybe_zipped, - find_path_from_directory, - list_py_file_paths, - open_maybe_zipped, -) +from airflow.utils.file import correct_maybe_zipped, find_path_from_directory, open_maybe_zipped from tests.models import TEST_DAGS_FOLDER from tests_common.test_utils.config import conf_vars -TEST_DAG_FOLDER = os.environ["AIRFLOW__CORE__DAGS_FOLDER"] - def might_contain_dag(file_path: str, zip_file: zipfile.ZipFile | None = None): return False @@ -219,31 +212,6 @@ def test_get_modules_from_invalid_file(self): assert len(modules) == 0 - def test_list_py_file_paths(self): - detected_files = set() - expected_files = set() - # No_dags is empty, _invalid_ is ignored by .airflowignore - ignored_files = { - "no_dags.py", - "test_invalid_cron.py", - "test_invalid_dup_task.py", - "test_ignore_this.py", - "test_invalid_param.py", - "test_invalid_param2.py", - "test_invalid_param3.py", - "test_invalid_param4.py", - "test_nested_dag.py", - "test_imports.py", - "__init__.py", - } - for root, _, files in os.walk(TEST_DAG_FOLDER): - for file_name in files: - if file_name.endswith((".py", ".zip")): - if file_name not in ignored_files: - expected_files.add(f"{root}/{file_name}") - detected_files = set(list_py_file_paths(TEST_DAG_FOLDER)) - assert detected_files == expected_files - @pytest.mark.parametrize( "edge_filename, expected_modification", diff --git a/tests/www/views/conftest.py b/tests/www/views/conftest.py index 68516b4f4c865..7f524b2377e39 100644 --- a/tests/www/views/conftest.py +++ b/tests/www/views/conftest.py @@ -17,7 +17,6 @@ # under the License. from __future__ import annotations -import os from collections.abc import Generator from contextlib import contextmanager from typing import Any, NamedTuple @@ -32,7 +31,6 @@ from tests_common.test_utils.api_connexion_utils import delete_user from tests_common.test_utils.config import conf_vars -from tests_common.test_utils.db import parse_and_sync_to_db from tests_common.test_utils.decorators import dont_initialize_flask_app_submodules from tests_common.test_utils.www import ( client_with_login, @@ -49,8 +47,8 @@ def session(): @pytest.fixture(autouse=True, scope="module") def examples_dag_bag(session): - parse_and_sync_to_db(os.devnull, include_examples=True) - dag_bag = DagBag(read_dags_from_db=True) + DagBag(include_examples=True).sync_to_db() + dag_bag = DagBag(include_examples=True, read_dags_from_db=True) session.commit() return dag_bag diff --git a/tests/www/views/test_views_acl.py b/tests/www/views/test_views_acl.py index 21379550dd1fd..0f430b9b6dca7 100644 --- a/tests/www/views/test_views_acl.py +++ b/tests/www/views/test_views_acl.py @@ -293,9 +293,9 @@ def test_dag_autocomplete_success(client_all_dags): "dag_display_name": None, }, {"name": "example_setup_teardown_taskflow", "type": "dag", "dag_display_name": None}, + {"name": "test_mapped_taskflow", "type": "dag", "dag_display_name": None}, {"name": "tutorial_taskflow_api", "type": "dag", "dag_display_name": None}, {"name": "tutorial_taskflow_api_virtualenv", "type": "dag", "dag_display_name": None}, - {"name": "tutorial_taskflow_templates", "type": "dag", "dag_display_name": None}, ] assert resp.json == expected diff --git a/tests/www/views/test_views_decorators.py b/tests/www/views/test_views_decorators.py index 6c0a0a7d78172..5815510b07f6e 100644 --- a/tests/www/views/test_views_decorators.py +++ b/tests/www/views/test_views_decorators.py @@ -17,8 +17,6 @@ # under the License. from __future__ import annotations -import os - import pytest from airflow.models import DagBag, Variable @@ -26,7 +24,7 @@ from airflow.utils.state import State from airflow.utils.types import DagRunType -from tests_common.test_utils.db import clear_db_runs, clear_db_variables, parse_and_sync_to_db +from tests_common.test_utils.db import clear_db_runs, clear_db_variables from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS from tests_common.test_utils.www import ( _check_last_log, @@ -44,8 +42,8 @@ @pytest.fixture(scope="module") def dagbag(): - parse_and_sync_to_db(os.devnull, include_examples=True) - return DagBag(read_dags_from_db=True) + DagBag(include_examples=True, read_dags_from_db=False).sync_to_db() + return DagBag(include_examples=True, read_dags_from_db=True) @pytest.fixture(scope="module") diff --git a/tests/www/views/test_views_log.py b/tests/www/views/test_views_log.py index 735c9a569e093..be492487a4c8b 100644 --- a/tests/www/views/test_views_log.py +++ b/tests/www/views/test_views_log.py @@ -20,7 +20,6 @@ import copy import logging import logging.config -import os import pathlib import shutil import sys @@ -128,7 +127,7 @@ def _reset_modules_after_every_test(backup_modules): @pytest.fixture(autouse=True) -def dags(log_app, create_dummy_dag, testing_dag_bundle, session): +def dags(log_app, create_dummy_dag, session): dag, _ = create_dummy_dag( dag_id=DAG_ID, task_id=TASK_ID, @@ -144,10 +143,10 @@ def dags(log_app, create_dummy_dag, testing_dag_bundle, session): session=session, ) - bag = DagBag(os.devnull, include_examples=False) + bag = DagBag(include_examples=False) bag.bag_dag(dag=dag) bag.bag_dag(dag=dag_removed) - bag.sync_to_db("testing", None, session=session) + bag.sync_to_db(session=session) log_app.dag_bag = bag yield dag, dag_removed diff --git a/tests/www/views/test_views_tasks.py b/tests/www/views/test_views_tasks.py index 10f42aca6c4f2..8289228b81dfe 100644 --- a/tests/www/views/test_views_tasks.py +++ b/tests/www/views/test_views_tasks.py @@ -597,7 +597,7 @@ def heartbeat(self): def test_delete_dag_button_for_dag_on_scheduler_only(admin_client, dag_maker): with dag_maker() as dag: EmptyOperator(task_id="task") - dag_maker.sync_dagbag_to_db() + dag.sync_to_db() # The delete-dag URL should be generated correctly test_dag_id = dag.dag_id resp = admin_client.get("/", follow_redirects=True) @@ -606,12 +606,12 @@ def test_delete_dag_button_for_dag_on_scheduler_only(admin_client, dag_maker): @pytest.fixture -def new_dag_to_delete(testing_dag_bundle): +def new_dag_to_delete(): dag = DAG( "new_dag_to_delete", is_paused_upon_creation=True, schedule="0 * * * *", start_date=DEFAULT_DATE ) session = settings.Session() - DAG.bulk_write_to_db("testing", None, [dag], session=session) + dag.sync_to_db(session=session) return dag diff --git a/tests_common/pytest_plugin.py b/tests_common/pytest_plugin.py index b4d3c089a3740..43b4733e38811 100644 --- a/tests_common/pytest_plugin.py +++ b/tests_common/pytest_plugin.py @@ -935,19 +935,6 @@ def create_dagrun_after(self, dagrun, **kwargs): **kwargs, ) - def sync_dagbag_to_db(self): - if not AIRFLOW_V_3_0_PLUS: - self.dagbag.sync_to_db() - return - - from airflow.models.dagbundle import DagBundleModel - - if self.session.query(DagBundleModel).filter(DagBundleModel.name == "dag_maker").count() == 0: - self.session.add(DagBundleModel(name="dag_maker")) - self.session.commit() - - self.dagbag.sync_to_db("dag_maker", None) - def __call__( self, dag_id="test_dag", diff --git a/tests_common/test_utils/db.py b/tests_common/test_utils/db.py index 58ad46372d711..8c5d59751f55c 100644 --- a/tests_common/test_utils/db.py +++ b/tests_common/test_utils/db.py @@ -17,8 +17,6 @@ # under the License. from __future__ import annotations -from typing import TYPE_CHECKING - from airflow.jobs.job import Job from airflow.models import ( Connection, @@ -53,28 +51,15 @@ ) from tests_common.test_utils.version_compat import AIRFLOW_V_2_10_PLUS, AIRFLOW_V_3_0_PLUS -if TYPE_CHECKING: - from pathlib import Path - def _bootstrap_dagbag(): from airflow.models.dag import DAG from airflow.models.dagbag import DagBag - if AIRFLOW_V_3_0_PLUS: - from airflow.dag_processing.bundles.manager import DagBundlesManager - with create_session() as session: - if AIRFLOW_V_3_0_PLUS: - DagBundlesManager().sync_bundles_to_db(session=session) - session.commit() - dagbag = DagBag() # Save DAGs in the ORM - if AIRFLOW_V_3_0_PLUS: - dagbag.sync_to_db(bundle_name="dags-folder", bundle_version=None, session=session) - else: - dagbag.sync_to_db(session=session) + dagbag.sync_to_db(session=session) # Deactivate the unknown ones DAG.deactivate_unknown_dags(dagbag.dags.keys(), session=session) @@ -107,25 +92,6 @@ def initial_db_init(): get_auth_manager().init() -def parse_and_sync_to_db(folder: Path | str, include_examples: bool = False): - from airflow.models.dagbag import DagBag - - if AIRFLOW_V_3_0_PLUS: - from airflow.dag_processing.bundles.manager import DagBundlesManager - - with create_session() as session: - if AIRFLOW_V_3_0_PLUS: - DagBundlesManager().sync_bundles_to_db(session=session) - session.commit() - - dagbag = DagBag(dag_folder=folder, include_examples=include_examples) - if AIRFLOW_V_3_0_PLUS: - dagbag.sync_to_db("dags-folder", None, session) - else: - dagbag.sync_to_db(session=session) # type: ignore[call-arg] - session.commit() - - def clear_db_runs(): with create_session() as session: session.query(Job).delete()