diff --git a/airflow/cli/commands/local_commands/dag_processor_command.py b/airflow/cli/commands/local_commands/dag_processor_command.py index 653c5f6bf577f..af2b65ff49b9f 100644 --- a/airflow/cli/commands/local_commands/dag_processor_command.py +++ b/airflow/cli/commands/local_commands/dag_processor_command.py @@ -39,7 +39,6 @@ def _create_dag_processor_job_runner(args: Any) -> DagProcessorJobRunner: job=Job(), processor=DagFileProcessorManager( processor_timeout=processor_timeout_seconds, - dag_directory=args.subdir, max_runs=args.num_runs, ), ) diff --git a/airflow/dag_processing/bundles/manager.py b/airflow/dag_processing/bundles/manager.py index 0a17ab3f8c5f5..49cccc02849e8 100644 --- a/airflow/dag_processing/bundles/manager.py +++ b/airflow/dag_processing/bundles/manager.py @@ -64,6 +64,23 @@ def parse_config(self) -> None: "Bundle config is not a list. Check config value" " for section `dag_bundles` and key `backends`." ) + + # example dags! + if conf.getboolean("core", "LOAD_EXAMPLES"): + from airflow import example_dags + + example_dag_folder = next(iter(example_dags.__path__)) + backends.append( + { + "name": "example_dags", + "classpath": "airflow.dag_processing.bundles.local.LocalDagBundle", + "kwargs": { + "local_folder": example_dag_folder, + "refresh_interval": conf.getint("scheduler", "dag_dir_list_interval"), + }, + } + ) + seen = set() for cfg in backends: name = cfg["name"] diff --git a/airflow/dag_processing/collection.py b/airflow/dag_processing/collection.py index 65d6dbea77bf4..6e0b627198995 100644 --- a/airflow/dag_processing/collection.py +++ b/airflow/dag_processing/collection.py @@ -74,11 +74,14 @@ log = logging.getLogger(__name__) -def _create_orm_dags(dags: Iterable[MaybeSerializedDAG], *, session: Session) -> Iterator[DagModel]: +def _create_orm_dags( + bundle_name: str, dags: Iterable[MaybeSerializedDAG], *, session: Session +) -> Iterator[DagModel]: for dag in dags: orm_dag = DagModel(dag_id=dag.dag_id) if dag.is_paused_upon_creation is not None: orm_dag.is_paused = dag.is_paused_upon_creation + orm_dag.bundle_name = bundle_name log.info("Creating ORM DAG for %s", dag.dag_id) session.add(orm_dag) yield orm_dag @@ -270,6 +273,8 @@ def _update_import_errors(files_parsed: set[str], import_errors: dict[str, str], def update_dag_parsing_results_in_db( + bundle_name: str, + bundle_version: str | None, dags: Collection[MaybeSerializedDAG], import_errors: dict[str, str], warnings: set[DagWarning], @@ -307,7 +312,7 @@ def update_dag_parsing_results_in_db( ) log.debug("Calling the DAG.bulk_sync_to_db method") try: - DAG.bulk_write_to_db(dags, session=session) + DAG.bulk_write_to_db(bundle_name, bundle_version, dags, session=session) # Write Serialized DAGs to DB, capturing errors # Write Serialized DAGs to DB, capturing errors for dag in dags: @@ -346,6 +351,8 @@ class DagModelOperation(NamedTuple): """Collect DAG objects and perform database operations for them.""" dags: dict[str, MaybeSerializedDAG] + bundle_name: str + bundle_version: str | None def find_orm_dags(self, *, session: Session) -> dict[str, DagModel]: """Find existing DagModel objects from DAG objects.""" @@ -365,7 +372,8 @@ def add_dags(self, *, session: Session) -> dict[str, DagModel]: orm_dags.update( (model.dag_id, model) for model in _create_orm_dags( - (dag for dag_id, dag in self.dags.items() if dag_id not in orm_dags), + bundle_name=self.bundle_name, + dags=(dag for dag_id, dag in self.dags.items() if dag_id not in orm_dags), session=session, ) ) @@ -430,6 +438,8 @@ def update_dags( 
dm.timetable_summary = dag.timetable.summary dm.timetable_description = dag.timetable.description dm.asset_expression = dag.timetable.asset_condition.as_expression() + dm.bundle_name = self.bundle_name + dm.latest_bundle_version = self.bundle_version last_automated_run: DagRun | None = run_info.latest_runs.get(dag.dag_id) if last_automated_run is None: diff --git a/airflow/dag_processing/manager.py b/airflow/dag_processing/manager.py index 99257b991114e..a66788dc8fbe3 100644 --- a/airflow/dag_processing/manager.py +++ b/airflow/dag_processing/manager.py @@ -51,6 +51,7 @@ from airflow.dag_processing.processor import DagFileParsingResult, DagFileProcessorProcess from airflow.models.dag import DagModel from airflow.models.dagbag import DagPriorityParsingRequest +from airflow.models.dagbundle import DagBundleModel from airflow.models.dagwarning import DagWarning from airflow.models.db_callback_request import DbCallbackRequest from airflow.models.errors import ParseImportError @@ -68,7 +69,7 @@ set_new_process_group, ) from airflow.utils.retries import retry_db_transaction -from airflow.utils.session import NEW_SESSION, provide_session +from airflow.utils.session import NEW_SESSION, create_session, provide_session from airflow.utils.sqlalchemy import prohibit_commit, with_row_locks if TYPE_CHECKING: @@ -76,6 +77,8 @@ from sqlalchemy.orm import Session + from airflow.dag_processing.bundles.base import BaseDagBundle + class DagParsingStat(NamedTuple): """Information on processing progress.""" @@ -99,6 +102,13 @@ class DagFileStat: log = logging.getLogger("airflow.processor_manager") +class DagFileInfo(NamedTuple): + """Information about a DAG file.""" + + path: str # absolute path of the file + bundle_name: str + + class DagFileProcessorAgent(LoggingMixin, MultiprocessingStartMethodMixin): """ Agent for DAG file processing. @@ -109,8 +119,6 @@ class DagFileProcessorAgent(LoggingMixin, MultiprocessingStartMethodMixin): This class runs in the main `airflow scheduler` process when standalone_dag_processor is not enabled. - :param dag_directory: Directory where DAG definitions are kept. All - files in file_paths should be under this directory :param max_runs: The number of times to parse and schedule each file. -1 for unlimited. 
:param processor_timeout: How long to wait before timing out a DAG file processor @@ -118,12 +126,10 @@ class DagFileProcessorAgent(LoggingMixin, MultiprocessingStartMethodMixin): def __init__( self, - dag_directory: os.PathLike, max_runs: int, processor_timeout: timedelta, ): super().__init__() - self._dag_directory: os.PathLike = dag_directory self._max_runs = max_runs self._processor_timeout = processor_timeout self._process: multiprocessing.Process | None = None @@ -146,7 +152,6 @@ def start(self) -> None: process = context.Process( target=type(self)._run_processor_manager, args=( - self._dag_directory, self._max_runs, self._processor_timeout, child_signal_conn, @@ -171,7 +176,6 @@ def get_callbacks_pipe(self) -> MultiprocessingConnection: @staticmethod def _run_processor_manager( - dag_directory: os.PathLike, max_runs: int, processor_timeout: timedelta, signal_conn: MultiprocessingConnection, @@ -184,7 +188,6 @@ def _run_processor_manager( setproctitle("airflow scheduler -- DagFileProcessorManager") reload_configuration_for_dag_processing() processor_manager = DagFileProcessorManager( - dag_directory=dag_directory, max_runs=max_runs, processor_timeout=processor_timeout.total_seconds(), signal_conn=signal_conn, @@ -303,15 +306,12 @@ class DagFileProcessorManager: processors finish, more are launched. The files are processed over and over again, but no more often than the specified interval. - :param dag_directory: Directory where DAG definitions are kept. All - files in file_paths should be under this directory :param max_runs: The number of times to parse and schedule each file. -1 for unlimited. :param processor_timeout: How long to wait before timing out a DAG file processor :param signal_conn: connection to communicate signal with processor agent. 
""" - _dag_directory: os.PathLike[str] = attrs.field(validator=_resolve_path) max_runs: int processor_timeout: float = attrs.field(factory=_config_int_factory("core", "dag_file_processor_timeout")) selector: selectors.BaseSelector = attrs.field(factory=selectors.DefaultSelector) @@ -329,7 +329,6 @@ class DagFileProcessorManager: factory=_config_int_factory("scheduler", "min_file_process_interval") ) stale_dag_threshold: float = attrs.field(factory=_config_int_factory("scheduler", "stale_dag_threshold")) - last_dag_dir_refresh_time: float = attrs.field(default=0, init=False) log: logging.Logger = attrs.field(default=log, init=False) @@ -342,11 +341,16 @@ class DagFileProcessorManager: heartbeat: Callable[[], None] = attrs.field(default=lambda: None) """An overridable heartbeat called once every time around the loop""" - _file_paths: list[str] = attrs.field(factory=list, init=False) - _file_path_queue: deque[str] = attrs.field(factory=deque, init=False) - _file_stats: dict[str, DagFileStat] = attrs.field(factory=lambda: defaultdict(DagFileStat), init=False) + _file_paths: list[DagFileInfo] = attrs.field(factory=list, init=False) + _file_path_queue: deque[DagFileInfo] = attrs.field(factory=deque, init=False) + _file_stats: dict[DagFileInfo, DagFileStat] = attrs.field( + factory=lambda: defaultdict(DagFileStat), init=False + ) + + _dag_bundles: list[BaseDagBundle] = attrs.field(factory=list, init=False) + _bundle_versions: dict[str, str] = attrs.field(factory=dict, init=False) - _processors: dict[str, DagFileProcessorProcess] = attrs.field(factory=dict, init=False) + _processors: dict[DagFileInfo, DagFileProcessorProcess] = attrs.field(factory=dict, init=False) _parsing_start_time: float = attrs.field(init=False) _num_run: int = attrs.field(default=0, init=False) @@ -393,14 +397,17 @@ def run(self): self.log.info("Processing files using up to %s processes at a time ", self._parallelism) self.log.info("Process each file at most once every %s seconds", self._file_process_interval) - self.log.info( - "Checking for new files in %s every %s seconds", self._dag_directory, self.dag_dir_list_interval - ) + # TODO: AIP-66 move to report by bundle self.log.info( + # "Checking for new files in %s every %s seconds", self._dag_directory, self.dag_dir_list_interval + # ) from airflow.dag_processing.bundles.manager import DagBundlesManager DagBundlesManager().sync_bundles_to_db() + self.log.info("Getting all DAG bundles") + self._dag_bundles = list(DagBundlesManager().get_all_dag_bundles()) + return self._run_parsing_loop() def _scan_stale_dags(self): @@ -413,7 +420,6 @@ def _scan_stale_dags(self): } self.deactivate_stale_dags( last_parsed=last_parsed, - dag_directory=self.get_dag_directory(), stale_dag_threshold=self.stale_dag_threshold, ) self._last_deactivate_stale_dags_time = time.monotonic() @@ -421,14 +427,16 @@ def _scan_stale_dags(self): @provide_session def deactivate_stale_dags( self, - last_parsed: dict[str, datetime | None], - dag_directory: str, + last_parsed: dict[DagFileInfo, datetime | None], stale_dag_threshold: int, session: Session = NEW_SESSION, ): """Detect and deactivate DAGs which are no longer present in files.""" to_deactivate = set() - query = select(DagModel.dag_id, DagModel.fileloc, DagModel.last_parsed_time).where(DagModel.is_active) + query = select( + DagModel.dag_id, DagModel.bundle_name, DagModel.fileloc, DagModel.last_parsed_time + ).where(DagModel.is_active) + # TODO: AIP-66 by bundle! 
        dags_parsed = session.execute(query)

        for dag in dags_parsed:
@@ -436,9 +444,11 @@ def deactivate_stale_dags(
             # last_parsed_time is the processor_timeout. Longer than that indicates that the DAG is
             # no longer present in the file. We have a stale_dag_threshold configured to prevent a
             # significant delay in deactivation of stale dags when a large timeout is configured
+            dag_file_path = DagFileInfo(path=dag.fileloc, bundle_name=dag.bundle_name)
             if (
-                dag.fileloc in last_parsed
-                and (dag.last_parsed_time + timedelta(seconds=stale_dag_threshold)) < last_parsed[dag.fileloc]
+                dag_file_path in last_parsed
+                and (dag.last_parsed_time + timedelta(seconds=stale_dag_threshold))
+                < last_parsed[dag_file_path]
             ):
                 self.log.info("DAG %s is missing and will be deactivated.", dag.dag_id)
                 to_deactivate.add(dag.dag_id)
@@ -471,17 +481,17 @@ def _run_parsing_loop(self):

             self.heartbeat()

-            refreshed_dag_dir = self._refresh_dag_dir()
-
             self._kill_timed_out_processors()

+            self._refresh_dag_bundles()
+
             if not self._file_path_queue:
                 # Generate more file paths to process if we processed all the files already. Note for this to
                 # clear down, we must have cleared all files found from scanning the dags dir _and_ have
                 # cleared all files added as a result of callbacks
                 self.prepare_file_path_queue()
                 self.emit_metrics()
-            elif refreshed_dag_dir:
+            else:
                 # if new files found in dag dir, add them
                 self.add_new_file_path_to_queue()
@@ -572,6 +582,8 @@ def _read_from_direct_scheduler_conn(self, conn: MultiprocessingConnection) -> b

     def _refresh_requested_filelocs(self) -> None:
         """Refresh filepaths from dag dir as requested by users via APIs."""
+        return
+        # TODO: AIP-66 make bundle aware - fileloc will be relative (eventually), thus not unique, in order to know what file to reparse
         # Get values from DB table
         filelocs = self._get_priority_filelocs()
         for fileloc in filelocs:
@@ -609,6 +621,9 @@ def _fetch_callbacks(

     def _add_callback_to_queue(self, request: CallbackRequest):
         self.log.debug("Queuing %s CallbackRequest: %s", type(request).__name__, request)
+        self.log.warning("Callbacks are not implemented yet!")
+        # TODO: AIP-66 make callbacks bundle aware
+        return
         self._callback_to_execute[request.full_filepath].append(request)
         if request.full_filepath in self._file_path_queue:
             # Remove file paths matching request.full_filepath from self._file_path_queue
@@ -631,25 +646,70 @@ def _get_priority_filelocs(cls, session: Session = NEW_SESSION):
                 session.delete(request)
         return filelocs

-    def _refresh_dag_dir(self) -> bool:
-        """Refresh file paths from dag dir if we haven't done it for too long."""
-        now = time.monotonic()
-        elapsed_time_since_refresh = now - self.last_dag_dir_refresh_time
-        if elapsed_time_since_refresh <= self.dag_dir_list_interval:
-            return False
+    def _refresh_dag_bundles(self):
+        """Refresh DAG bundles, if required."""
+        now = timezone.utcnow()

-        # Build up a list of Python files that could contain DAGs
-        self.log.info("Searching for files in %s", self._dag_directory)
-        self._file_paths = list_py_file_paths(self._dag_directory)
-        self.last_dag_dir_refresh_time = now
-        self.log.info("There are %s files in %s", len(self._file_paths), self._dag_directory)
-        self.set_file_paths(self._file_paths)
+        self.log.info("Refreshing DAG bundles")
+
+        for bundle in self._dag_bundles:
+            # TODO: AIP-66 test to make sure we get a fresh record from the db and it's not cached
+            with create_session() as session:
+                bundle_model = session.get(DagBundleModel, bundle.name)
+                elapsed_time_since_refresh = (
+                    now -
(bundle_model.last_refreshed or timezone.utc_epoch()) + ).total_seconds() + current_version = bundle.get_current_version() + if ( + not elapsed_time_since_refresh > bundle.refresh_interval + ) and bundle_model.latest_version == current_version: + self.log.info("Not time to refresh %s", bundle.name) + continue - try: - self.log.debug("Removing old import errors") + try: + bundle.refresh() + except Exception: + self.log.exception("Error refreshing bundle %s", bundle.name) + continue + + bundle_model.last_refreshed = now + + new_version = bundle.get_current_version() + if bundle.supports_versioning: + # We can short-circuit the rest of the refresh if the version hasn't changed + # and we've already fully "refreshed" this bundle before in this dag processor. + if current_version == new_version and bundle.name in self._bundle_versions: + self.log.debug("Bundle %s version not changed after refresh", bundle.name) + continue + + bundle_model.latest_version = new_version + + self.log.info("Version changed for %s, new version: %s", bundle.name, new_version) + + bundle_file_paths = self._find_files_in_bundle(bundle) + + new_file_paths = [f for f in self._file_paths if f.bundle_name != bundle.name] + new_file_paths.extend( + DagFileInfo(path=path, bundle_name=bundle.name) for path in bundle_file_paths + ) + self.set_file_paths(new_file_paths) + + self.deactivate_deleted_dags(bundle_file_paths) self.clear_nonexistent_import_errors() - except Exception: - self.log.exception("Error removing old import errors") + + self._bundle_versions[bundle.name] = bundle.get_current_version() + + def _find_files_in_bundle(self, bundle: BaseDagBundle) -> list[str]: + """Refresh file paths from bundle dir.""" + # Build up a list of Python files that could contain DAGs + self.log.info("Searching for files in %s at %s", bundle.name, bundle.path) + file_paths = list_py_file_paths(bundle.path) + self.log.info("Found %s files for bundle %s", len(file_paths), bundle.name) + + return file_paths + + def deactivate_deleted_dags(self, file_paths: set[str]) -> None: + """Deactivate DAGs that come from files that are no longer present.""" def _iter_dag_filelocs(fileloc: str) -> Iterator[str]: """ @@ -668,12 +728,11 @@ def _iter_dag_filelocs(fileloc: str) -> Iterator[str]: except zipfile.BadZipFile: self.log.exception("There was an error accessing ZIP file %s %s", fileloc) - dag_filelocs = {full_loc for path in self._file_paths for full_loc in _iter_dag_filelocs(path)} + dag_filelocs = {full_loc for path in file_paths for full_loc in _iter_dag_filelocs(path)} + # TODO: AIP-66: make bundle aware, as fileloc won't be unique long term. 
DagModel.deactivate_deleted_dags(dag_filelocs) - return True - def _print_stat(self): """Occasionally print out stats about how fast the files are getting processed.""" if 0 < self.print_stats_interval < time.monotonic() - self.last_stat_print_time: @@ -690,15 +749,18 @@ def clear_nonexistent_import_errors(self, session=NEW_SESSION): :param session: session for ORM operations """ self.log.debug("Removing old import errors") - query = delete(ParseImportError) + try: + query = delete(ParseImportError) - if self._file_paths: - query = query.where( - ParseImportError.filename.notin_(self._file_paths), - ) + if self._file_paths: + query = query.where( + ParseImportError.filename.notin_([f.path for f in self._file_paths]), + ) - session.execute(query.execution_options(synchronize_session="fetch")) - session.commit() + session.execute(query.execution_options(synchronize_session="fetch")) + session.commit() + except Exception: + self.log.exception("Error removing old import errors") def _log_file_processing_stats(self, known_file_paths): """ @@ -736,7 +798,7 @@ def _log_file_processing_stats(self, known_file_paths): proc = self._processors.get(file_path) num_dags = stat.num_dags num_errors = stat.import_errors - file_name = Path(file_path).stem + file_name = Path(file_path.path).stem processor_pid = proc.pid if proc else None processor_start_time = proc.start_time if proc else None runtime = (now - processor_start_time) if processor_start_time else None @@ -793,15 +855,9 @@ def _log_file_processing_stats(self, known_file_paths): self.log.info(log_str) - def get_dag_directory(self) -> str | None: - """Return the dag_directory as a string.""" - if self._dag_directory is not None: - return os.fspath(self._dag_directory) - return None - - def set_file_paths(self, new_file_paths): + def set_file_paths(self, new_file_paths: list[DagFileInfo]): """ - Update this with a new set of paths to DAG definition files. + Update this with a new set of DagFilePaths to DAG definition files. :param new_file_paths: list of paths to DAG definition files :return: None @@ -812,9 +868,10 @@ def set_file_paths(self, new_file_paths): self._file_path_queue = deque(x for x in self._file_path_queue if x in new_file_paths) Stats.gauge("dag_processing.file_path_queue_size", len(self._file_path_queue)) - callback_paths_to_del = [x for x in self._callback_to_execute if x not in new_file_paths] - for path_to_del in callback_paths_to_del: - del self._callback_to_execute[path_to_del] + # TODO: AIP-66 make callbacks bundle aware + # callback_paths_to_del = [x for x in self._callback_to_execute if x not in new_file_paths] + # for path_to_del in callback_paths_to_del: + # del self._callback_to_execute[path_to_del] # Stop processors that are working on deleted files filtered_processors = {} @@ -838,33 +895,35 @@ def set_file_paths(self, new_file_paths): def _collect_results(self, session: Session = NEW_SESSION): # TODO: Use an explicit session in this fn finished = [] - for path, proc in self._processors.items(): + for dag_file, proc in self._processors.items(): if not proc.is_ready: # This processor hasn't finished yet, or we haven't read all the output from it yet continue - finished.append(path) + finished.append(dag_file) # Collect the DAGS and import errors into the DB, emit metrics etc. 
- self._file_stats[path] = process_parse_results( + self._file_stats[dag_file] = process_parse_results( run_duration=time.time() - proc.start_time, finish_time=timezone.utcnow(), - run_count=self._file_stats[path].run_count, + run_count=self._file_stats[dag_file].run_count, + bundle_name=dag_file.bundle_name, + bundle_version=self._bundle_versions[dag_file.bundle_name], parsing_result=proc.parsing_result, - path=path, session=session, ) - for path in finished: - self._processors.pop(path) + for dag_file in finished: + self._processors.pop(dag_file) - def _create_process(self, file_path): + def _create_process(self, dag_file: DagFileInfo) -> DagFileProcessorProcess: id = uuid7() - callback_to_execute_for_file = self._callback_to_execute.pop(file_path, []) + # callback_to_execute_for_file = self._callback_to_execute.pop(file_path, []) + callback_to_execute_for_file: list[CallbackRequest] = [] return DagFileProcessorProcess.start( id=id, - path=file_path, + path=dag_file.path, callbacks=callback_to_execute_for_file, selector=self.selector, ) @@ -914,7 +973,7 @@ def prepare_file_path_queue(self): for file_path in self._file_paths: if is_mtime_mode: try: - files_with_mtime[file_path] = os.path.getmtime(file_path) + files_with_mtime[file_path] = os.path.getmtime(file_path.path) except FileNotFoundError: self.log.warning("Skipping processing of missing file: %s", file_path) self._file_stats.pop(file_path, None) @@ -973,7 +1032,8 @@ def prepare_file_path_queue(self): ) self.log.debug( - "Queuing the following files for processing:\n\t%s", "\n\t".join(files_paths_to_queue) + "Queuing the following files for processing:\n\t%s", + "\n\t".join(f.path for f in files_paths_to_queue), ) self._add_paths_to_queue(files_paths_to_queue, False) Stats.incr("dag_processing.file_path_queue_update_count") @@ -1011,7 +1071,7 @@ def _kill_timed_out_processors(self): for proc in processors_to_remove: self._processors.pop(proc) - def _add_paths_to_queue(self, file_paths_to_enqueue: list[str], add_at_front: bool): + def _add_paths_to_queue(self, file_paths_to_enqueue: list[DagFileInfo], add_at_front: bool): """Add stuff to the back or front of the file queue, unless it's already present.""" new_file_paths = list(p for p in file_paths_to_enqueue if p not in self._file_path_queue) if add_at_front: @@ -1091,7 +1151,8 @@ def process_parse_results( run_duration: float, finish_time: datetime, run_count: int, - path: str, + bundle_name: str, + bundle_version: str | None, parsing_result: DagFileParsingResult | None, session: Session, ) -> DagFileStat: @@ -1102,15 +1163,18 @@ def process_parse_results( run_count=run_count + 1, ) - file_name = Path(path).stem - Stats.timing(f"dag_processing.last_duration.{file_name}", stat.last_duration) - Stats.timing("dag_processing.last_duration", stat.last_duration, tags={"file_name": file_name}) + # TODO: AIP-66 emit metrics + # file_name = Path(dag_file.path).stem + # Stats.timing(f"dag_processing.last_duration.{file_name}", stat.last_duration) + # Stats.timing("dag_processing.last_duration", stat.last_duration, tags={"file_name": file_name}) if parsing_result is None: stat.import_errors = 1 else: # record DAGs and import errors to database update_dag_parsing_results_in_db( + bundle_name=bundle_name, + bundle_version=bundle_version, dags=parsing_result.serialized_dags, import_errors=parsing_result.import_errors or {}, warnings=set(parsing_result.warnings or []), diff --git a/airflow/jobs/scheduler_job_runner.py b/airflow/jobs/scheduler_job_runner.py index 17643b0212195..4718b824830b7 
100644
--- a/airflow/jobs/scheduler_job_runner.py
+++ b/airflow/jobs/scheduler_job_runner.py
@@ -30,7 +30,6 @@
 from datetime import timedelta
 from functools import lru_cache, partial
 from itertools import groupby
-from pathlib import Path
 from typing import TYPE_CHECKING, Any, Callable

 from deprecated import deprecated
@@ -922,7 +921,6 @@ def _execute(self) -> int | None:
         processor_timeout = timedelta(seconds=processor_timeout_seconds)
         if not self._standalone_dag_processor and not self.processor_agent:
             self.processor_agent = DagFileProcessorAgent(
-                dag_directory=Path(self.subdir),
                 max_runs=self.num_times_parse_dags,
                 processor_timeout=processor_timeout,
             )
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 4abf24af52537..82b4ca70819b9 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -770,6 +770,16 @@ def get_is_paused(self, session=NEW_SESSION) -> None:
         """Return a boolean indicating whether this DAG is paused."""
         return session.scalar(select(DagModel.is_paused).where(DagModel.dag_id == self.dag_id))

+    @provide_session
+    def get_bundle_name(self, session=NEW_SESSION) -> str | None:
+        """Return the bundle name this DAG is in."""
+        return session.scalar(select(DagModel.bundle_name).where(DagModel.dag_id == self.dag_id))
+
+    @provide_session
+    def get_latest_bundle_version(self, session=NEW_SESSION) -> str | None:
+        """Return the latest version of the bundle this DAG is in."""
+        return session.scalar(select(DagModel.latest_bundle_version).where(DagModel.dag_id == self.dag_id))
+
     @methodtools.lru_cache(maxsize=None)
     @classmethod
     def get_serialized_fields(cls):
@@ -1832,6 +1842,8 @@ def create_dagrun(
     @provide_session
     def bulk_write_to_db(
         cls,
+        bundle_name: str,
+        bundle_version: str | None,
         dags: Collection[MaybeSerializedDAG],
         session: Session = NEW_SESSION,
     ):
@@ -1847,7 +1859,9 @@ def bulk_write_to_db(
         from airflow.dag_processing.collection import AssetModelOperation, DagModelOperation

         log.info("Sync %s DAGs", len(dags))
-        dag_op = DagModelOperation({dag.dag_id: dag for dag in dags})  # type: ignore[misc]
+        dag_op = DagModelOperation(
+            bundle_name=bundle_name, bundle_version=bundle_version, dags={d.dag_id: d for d in dags}
+        )  # type: ignore[misc]

         orm_dags = dag_op.add_dags(session=session)
         dag_op.update_dags(orm_dags, session=session)
@@ -1873,7 +1887,10 @@ def sync_to_db(self, session=NEW_SESSION):

         :return: None
         """
-        self.bulk_write_to_db([self], session=session)
+        # TODO: AIP-66 should this be in the model?
+ bundle_name = self.get_bundle_name(session=session) + bundle_version = self.get_latest_bundle_version(session=session) + self.bulk_write_to_db(bundle_name, bundle_version, [self], session=session) def get_default_view(self): """Allow backward compatible jinja2 templates.""" diff --git a/airflow/models/dagbag.py b/airflow/models/dagbag.py index 7d0d2efc1bf6e..26f58b26929ae 100644 --- a/airflow/models/dagbag.py +++ b/airflow/models/dagbag.py @@ -568,11 +568,17 @@ def collect_dags( # Ensure dag_folder is a str -- it may have been a pathlib.Path dag_folder = correct_maybe_zipped(str(dag_folder)) - for filepath in list_py_file_paths( - dag_folder, - safe_mode=safe_mode, - include_examples=include_examples, - ): + + files_to_parse = list_py_file_paths(dag_folder, safe_mode=safe_mode) + + if include_examples: + from airflow import example_dags + + example_dag_folder = next(iter(example_dags.__path__)) + + files_to_parse.extend(list_py_file_paths(example_dag_folder, safe_mode=safe_mode)) + + for filepath in files_to_parse: try: file_parse_start_dttm = timezone.utcnow() found_dags = self.process_file(filepath, only_if_updated=only_if_updated, safe_mode=safe_mode) @@ -626,11 +632,13 @@ def dagbag_report(self): return report @provide_session - def sync_to_db(self, session: Session = NEW_SESSION): + def sync_to_db(self, bundle_name: str, bundle_version: str | None, session: Session = NEW_SESSION): """Save attributes about list of DAG to the DB.""" from airflow.dag_processing.collection import update_dag_parsing_results_in_db update_dag_parsing_results_in_db( + bundle_name, + bundle_version, self.dags.values(), # type: ignore[arg-type] # We should create a proto for DAG|LazySerializedDAG self.import_errors, self.dag_warnings, diff --git a/airflow/utils/file.py b/airflow/utils/file.py index 5c3e454e294af..7e72b26792e9e 100644 --- a/airflow/utils/file.py +++ b/airflow/utils/file.py @@ -245,7 +245,6 @@ def find_path_from_directory( def list_py_file_paths( directory: str | os.PathLike[str] | None, safe_mode: bool = conf.getboolean("core", "DAG_DISCOVERY_SAFE_MODE", fallback=True), - include_examples: bool | None = None, ) -> list[str]: """ Traverse a directory and look for Python files. @@ -255,11 +254,8 @@ def list_py_file_paths( contains Airflow DAG definitions. If not provided, use the core.DAG_DISCOVERY_SAFE_MODE configuration setting. If not set, default to safe. 
- :param include_examples: include example DAGs :return: a list of paths to Python files in the specified directory """ - if include_examples is None: - include_examples = conf.getboolean("core", "LOAD_EXAMPLES") file_paths: list[str] = [] if directory is None: file_paths = [] @@ -267,11 +263,6 @@ def list_py_file_paths( file_paths = [str(directory)] elif os.path.isdir(directory): file_paths.extend(find_dag_file_paths(directory, safe_mode)) - if include_examples: - from airflow import example_dags - - example_dag_folder = next(iter(example_dags.__path__)) - file_paths.extend(list_py_file_paths(example_dag_folder, safe_mode, include_examples=False)) return file_paths diff --git a/providers/tests/fab/auth_manager/api_endpoints/test_dag_run_endpoint.py b/providers/tests/fab/auth_manager/api_endpoints/test_dag_run_endpoint.py index 33df2e9bbc407..e745d3d655bdc 100644 --- a/providers/tests/fab/auth_manager/api_endpoints/test_dag_run_endpoint.py +++ b/providers/tests/fab/auth_manager/api_endpoints/test_dag_run_endpoint.py @@ -20,7 +20,7 @@ import pytest -from airflow.models.dag import DAG, DagModel +from airflow.models.dag import DagModel from airflow.models.dagrun import DagRun from airflow.models.param import Param from airflow.providers.fab.www.security import permissions @@ -125,29 +125,26 @@ def setup_attrs(self, configured_app) -> None: clear_db_serialized_dags() clear_db_dags() + @pytest.fixture(autouse=True) + def create_dag(self, dag_maker, setup_attrs): + with dag_maker( + "TEST_DAG_ID", schedule=None, params={"validated_number": Param(1, minimum=1, maximum=10)} + ): + pass + + dag_maker.sync_dagbag_to_db() + def teardown_method(self) -> None: clear_db_runs() clear_db_dags() clear_db_serialized_dags() - def _create_dag(self, dag_id): - dag_instance = DagModel(dag_id=dag_id) - dag_instance.is_active = True - with create_session() as session: - session.add(dag_instance) - dag = DAG(dag_id=dag_id, schedule=None, params={"validated_number": Param(1, minimum=1, maximum=10)}) - self.app.dag_bag.bag_dag(dag) - self.app.dag_bag.sync_to_db() - return dag_instance - def _create_test_dag_run(self, state=DagRunState.RUNNING, extra_dag=False, commit=True, idx_start=1): dag_runs = [] dags = [] triggered_by_kwargs = {"triggered_by": DagRunTriggeredByType.TEST} if AIRFLOW_V_3_0_PLUS else {} for i in range(idx_start, idx_start + 2): - if i == 1: - dags.append(DagModel(dag_id="TEST_DAG_ID", is_active=True)) dagrun_model = DagRun( dag_id="TEST_DAG_ID", run_id=f"TEST_DAG_RUN_ID_{i}", @@ -247,7 +244,6 @@ def test_should_return_accessible_with_tilde_as_dag_id_and_dag_level_permissions class TestPostDagRun(TestDagRunEndpoint): def test_dagrun_trigger_with_dag_level_permissions(self): - self._create_dag("TEST_DAG_ID") response = self.client.post( "api/v1/dags/TEST_DAG_ID/dagRuns", json={"conf": {"validated_number": 1}}, @@ -260,7 +256,6 @@ def test_dagrun_trigger_with_dag_level_permissions(self): ["test_dag_view_only", "test_view_dags", "test_granular_permissions"], ) def test_should_raises_403_unauthorized(self, username): - self._create_dag("TEST_DAG_ID") response = self.client.post( "api/v1/dags/TEST_DAG_ID/dagRuns", json={ diff --git a/providers/tests/fab/auth_manager/api_endpoints/test_dag_source_endpoint.py b/providers/tests/fab/auth_manager/api_endpoints/test_dag_source_endpoint.py index 8461e5bf0f170..66cd6477c9e9b 100644 --- a/providers/tests/fab/auth_manager/api_endpoints/test_dag_source_endpoint.py +++ b/providers/tests/fab/auth_manager/api_endpoints/test_dag_source_endpoint.py @@ -18,7 +18,6 @@ 
import ast import os -from typing import TYPE_CHECKING import pytest @@ -26,7 +25,12 @@ from airflow.providers.fab.www.security import permissions from providers.tests.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user -from tests_common.test_utils.db import clear_db_dag_code, clear_db_dags, clear_db_serialized_dags +from tests_common.test_utils.db import ( + clear_db_dag_code, + clear_db_dags, + clear_db_serialized_dags, + parse_and_sync_to_db, +) from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS pytestmark = [ @@ -34,11 +38,7 @@ pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3.0+"), ] -if TYPE_CHECKING: - from airflow.models.dag import DAG -ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir, os.pardir)) -EXAMPLE_DAG_FILE = os.path.join("airflow", "example_dags", "example_bash_operator.py") EXAMPLE_DAG_ID = "example_bash_operator" TEST_DAG_ID = "latest_only" NOT_READABLE_DAG_ID = "latest_only_with_trigger" @@ -97,9 +97,9 @@ def _get_dag_file_docstring(fileloc: str) -> str | None: return docstring def test_should_respond_403_not_readable(self, url_safe_serializer): - dagbag = DagBag(dag_folder=EXAMPLE_DAG_FILE) - dagbag.sync_to_db() - dag: DAG = dagbag.dags[NOT_READABLE_DAG_ID] + parse_and_sync_to_db(os.devnull, include_examples=True) + dagbag = DagBag(read_dags_from_db=True) + dag = dagbag.get_dag(NOT_READABLE_DAG_ID) response = self.client.get( f"/api/v1/dagSources/{dag.dag_id}", @@ -114,9 +114,9 @@ def test_should_respond_403_not_readable(self, url_safe_serializer): assert read_dag.status_code == 403 def test_should_respond_403_some_dags_not_readable_in_the_file(self, url_safe_serializer): - dagbag = DagBag(dag_folder=EXAMPLE_DAG_FILE) - dagbag.sync_to_db() - dag: DAG = dagbag.dags[TEST_MULTIPLE_DAGS_ID] + parse_and_sync_to_db(os.devnull, include_examples=True) + dagbag = DagBag(read_dags_from_db=True) + dag = dagbag.get_dag(TEST_MULTIPLE_DAGS_ID) response = self.client.get( f"/api/v1/dagSources/{dag.dag_id}", diff --git a/providers/tests/fab/auth_manager/conftest.py b/providers/tests/fab/auth_manager/conftest.py index d400a7b86a027..f26a08d19c3e3 100644 --- a/providers/tests/fab/auth_manager/conftest.py +++ b/providers/tests/fab/auth_manager/conftest.py @@ -16,11 +16,14 @@ # under the License. 
from __future__ import annotations +import os + import pytest from airflow.www import app from tests_common.test_utils.config import conf_vars +from tests_common.test_utils.db import parse_and_sync_to_db from tests_common.test_utils.decorators import dont_initialize_flask_app_submodules @@ -72,5 +75,5 @@ def set_auth_role_public(request): def dagbag(): from airflow.models import DagBag - DagBag(include_examples=True, read_dags_from_db=False).sync_to_db() - return DagBag(include_examples=True, read_dags_from_db=True) + parse_and_sync_to_db(os.devnull, include_examples=True) + return DagBag(read_dags_from_db=True) diff --git a/providers/tests/openlineage/plugins/test_utils.py b/providers/tests/openlineage/plugins/test_utils.py index e1e3e8732690a..490061db4f594 100644 --- a/providers/tests/openlineage/plugins/test_utils.py +++ b/providers/tests/openlineage/plugins/test_utils.py @@ -28,7 +28,7 @@ from openlineage.client.utils import RedactMixin from pkg_resources import parse_version -from airflow.models import DAG as AIRFLOW_DAG, DagModel +from airflow.models import DagModel from airflow.providers.common.compat.assets import Asset from airflow.providers.openlineage.plugins.facets import AirflowDebugRunFacet from airflow.providers.openlineage.utils.utils import ( @@ -91,12 +91,14 @@ def test_get_airflow_debug_facet_logging_set_to_debug(mock_debug_mode, mock_get_ @pytest.mark.db_test -def test_get_dagrun_start_end(): +def test_get_dagrun_start_end(dag_maker): start_date = datetime.datetime(2022, 1, 1) end_date = datetime.datetime(2022, 1, 1, hour=2) - dag = AIRFLOW_DAG("test", start_date=start_date, end_date=end_date, schedule="@once") - AIRFLOW_DAG.bulk_write_to_db([dag]) + with dag_maker("test", start_date=start_date, end_date=end_date, schedule="@once") as dag: + pass + dag_maker.sync_dagbag_to_db() dag_model = DagModel.get_dagmodel(dag.dag_id) + run_id = str(uuid.uuid1()) triggered_by_kwargs = {"triggered_by": DagRunTriggeredByType.TEST} if AIRFLOW_V_3_0_PLUS else {} dagrun = dag.create_dagrun( diff --git a/tests/api_connexion/conftest.py b/tests/api_connexion/conftest.py index 39e8da7212ae7..fa5d4c9250065 100644 --- a/tests/api_connexion/conftest.py +++ b/tests/api_connexion/conftest.py @@ -16,11 +16,14 @@ # under the License. 
from __future__ import annotations +import os + import pytest from airflow.www import app from tests_common.test_utils.config import conf_vars +from tests_common.test_utils.db import parse_and_sync_to_db from tests_common.test_utils.decorators import dont_initialize_flask_app_submodules @@ -64,5 +67,5 @@ def session(): def dagbag(): from airflow.models import DagBag - DagBag(include_examples=True, read_dags_from_db=False).sync_to_db() - return DagBag(include_examples=True, read_dags_from_db=True) + parse_and_sync_to_db(os.devnull, include_examples=True) + return DagBag(read_dags_from_db=True) diff --git a/tests/api_connexion/endpoints/test_dag_parsing.py b/tests/api_connexion/endpoints/test_dag_parsing.py index 1df80a905d92e..0052732b7dfb9 100644 --- a/tests/api_connexion/endpoints/test_dag_parsing.py +++ b/tests/api_connexion/endpoints/test_dag_parsing.py @@ -17,7 +17,6 @@ from __future__ import annotations import os -from typing import TYPE_CHECKING import pytest from sqlalchemy import select @@ -26,17 +25,14 @@ from airflow.models.dagbag import DagPriorityParsingRequest from tests_common.test_utils.api_connexion_utils import create_user, delete_user -from tests_common.test_utils.db import clear_db_dag_parsing_requests +from tests_common.test_utils.db import clear_db_dag_parsing_requests, parse_and_sync_to_db pytestmark = pytest.mark.db_test -if TYPE_CHECKING: - from airflow.models.dag import DAG ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir, os.pardir)) EXAMPLE_DAG_FILE = os.path.join("airflow", "example_dags", "example_bash_operator.py") -EXAMPLE_DAG_ID = "example_bash_operator" -TEST_DAG_ID = "latest_only" +TEST_DAG_ID = "example_bash_operator" NOT_READABLE_DAG_ID = "latest_only_with_trigger" TEST_MULTIPLE_DAGS_ID = "asset_produces_1" @@ -72,9 +68,9 @@ def clear_db(): clear_db_dag_parsing_requests() def test_201_and_400_requests(self, url_safe_serializer, session): - dagbag = DagBag(dag_folder=EXAMPLE_DAG_FILE) - dagbag.sync_to_db() - test_dag: DAG = dagbag.dags[TEST_DAG_ID] + parse_and_sync_to_db(EXAMPLE_DAG_FILE) + dagbag = DagBag(read_dags_from_db=True) + test_dag = dagbag.get_dag(TEST_DAG_ID) url = f"/api/v1/parseDagFile/{url_safe_serializer.dumps(test_dag.fileloc)}" response = self.client.put( diff --git a/tests/api_connexion/endpoints/test_dag_run_endpoint.py b/tests/api_connexion/endpoints/test_dag_run_endpoint.py index 4f322bb8f0a6c..b6609c162390d 100644 --- a/tests/api_connexion/endpoints/test_dag_run_endpoint.py +++ b/tests/api_connexion/endpoints/test_dag_run_endpoint.py @@ -24,6 +24,7 @@ import time_machine from airflow.api_connexion.exceptions import EXCEPTIONS_LINK_MAP +from airflow.dag_processing.bundles.manager import DagBundlesManager from airflow.models.asset import AssetEvent, AssetModel from airflow.models.dag import DAG, DagModel from airflow.models.dagrun import DagRun @@ -88,8 +89,9 @@ def _create_dag(self, dag_id): with create_session() as session: session.add(dag_instance) dag = DAG(dag_id=dag_id, schedule=None, params={"validated_number": Param(1, minimum=1, maximum=10)}) + DagBundlesManager().sync_bundles_to_db() self.app.dag_bag.bag_dag(dag) - self.app.dag_bag.sync_to_db() + self.app.dag_bag.sync_to_db("dags-folder", None) return dag_instance def _create_test_dag_run(self, state=DagRunState.RUNNING, extra_dag=False, commit=True, idx_start=1): diff --git a/tests/api_connexion/endpoints/test_dag_source_endpoint.py b/tests/api_connexion/endpoints/test_dag_source_endpoint.py index a907d2704c6e7..c7e3b12c7e6fd 100644 --- 
a/tests/api_connexion/endpoints/test_dag_source_endpoint.py +++ b/tests/api_connexion/endpoints/test_dag_source_endpoint.py @@ -16,6 +16,8 @@ # under the License. from __future__ import annotations +import os + import pytest from sqlalchemy import select @@ -24,13 +26,12 @@ from airflow.models.serialized_dag import SerializedDagModel from tests_common.test_utils.api_connexion_utils import assert_401, create_user, delete_user -from tests_common.test_utils.db import clear_db_dags +from tests_common.test_utils.db import clear_db_dags, parse_and_sync_to_db pytestmark = pytest.mark.db_test -EXAMPLE_DAG_ID = "example_bash_operator" -TEST_DAG_ID = "latest_only" +TEST_DAG_ID = "example_bash_operator" @pytest.fixture(scope="module") @@ -51,9 +52,9 @@ def configured_app(minimal_app_for_api): @pytest.fixture def test_dag(): - dagbag = DagBag(include_examples=True) - dagbag.sync_to_db() - return dagbag.dags[TEST_DAG_ID] + parse_and_sync_to_db(os.devnull, include_examples=True) + dagbag = DagBag(read_dags_from_db=True) + return dagbag.get_dag(TEST_DAG_ID) class TestGetSource: diff --git a/tests/api_connexion/endpoints/test_extra_link_endpoint.py b/tests/api_connexion/endpoints/test_extra_link_endpoint.py index 8cb8dd4e030c1..e4aa7895ac662 100644 --- a/tests/api_connexion/endpoints/test_extra_link_endpoint.py +++ b/tests/api_connexion/endpoints/test_extra_link_endpoint.py @@ -22,6 +22,7 @@ import pytest from airflow.api_connexion.exceptions import EXCEPTIONS_LINK_MAP +from airflow.dag_processing.bundles.manager import DagBundlesManager from airflow.models.dag import DAG from airflow.models.dagbag import DagBag from airflow.models.xcom import XCom @@ -73,9 +74,10 @@ def setup_attrs(self, configured_app, session) -> None: self.dag = self._create_dag() + DagBundlesManager().sync_bundles_to_db() self.app.dag_bag = DagBag(os.devnull, include_examples=False) self.app.dag_bag.dags = {self.dag.dag_id: self.dag} - self.app.dag_bag.sync_to_db() + self.app.dag_bag.sync_to_db("dags-folder", None) triggered_by_kwargs = {"triggered_by": DagRunTriggeredByType.TEST} if AIRFLOW_V_3_0_PLUS else {} self.dag.create_dagrun( diff --git a/tests/api_connexion/endpoints/test_mapped_task_instance_endpoint.py b/tests/api_connexion/endpoints/test_mapped_task_instance_endpoint.py index a4a14fa3b421d..7a4d802f6cc15 100644 --- a/tests/api_connexion/endpoints/test_mapped_task_instance_endpoint.py +++ b/tests/api_connexion/endpoints/test_mapped_task_instance_endpoint.py @@ -24,6 +24,7 @@ import pytest from airflow.api_connexion.exceptions import EXCEPTIONS_LINK_MAP +from airflow.dag_processing.bundles.manager import DagBundlesManager from airflow.models import TaskInstance from airflow.models.baseoperator import BaseOperator from airflow.models.dagbag import DagBag @@ -122,9 +123,10 @@ def create_dag_runs_with_mapped_tasks(self, dag_maker, session, dags=None): setattr(ti, "start_date", DEFAULT_DATETIME_1) session.add(ti) + DagBundlesManager().sync_bundles_to_db() self.app.dag_bag = DagBag(os.devnull, include_examples=False) self.app.dag_bag.dags = {dag_id: dag_maker.dag} - self.app.dag_bag.sync_to_db() + self.app.dag_bag.sync_to_db("dags-folder", None) session.flush() mapped.expand_mapped_task(dr.run_id, session=session) diff --git a/tests/api_connexion/endpoints/test_task_endpoint.py b/tests/api_connexion/endpoints/test_task_endpoint.py index d1662943604ac..7e593731d539d 100644 --- a/tests/api_connexion/endpoints/test_task_endpoint.py +++ b/tests/api_connexion/endpoints/test_task_endpoint.py @@ -22,6 +22,7 @@ import pytest +from 
airflow.dag_processing.bundles.manager import DagBundlesManager from airflow.models import DagBag from airflow.models.dag import DAG from airflow.models.expandinput import EXPAND_INPUT_EMPTY @@ -80,13 +81,15 @@ def setup_dag(self, configured_app): task5 = EmptyOperator(task_id=self.unscheduled_task_id2, params={"is_unscheduled": True}) task1 >> task2 task4 >> task5 + + DagBundlesManager().sync_bundles_to_db() dag_bag = DagBag(os.devnull, include_examples=False) dag_bag.dags = { dag.dag_id: dag, mapped_dag.dag_id: mapped_dag, unscheduled_dag.dag_id: unscheduled_dag, } - dag_bag.sync_to_db() + dag_bag.sync_to_db("dags-folder", None) configured_app.dag_bag = dag_bag # type:ignore @staticmethod diff --git a/tests/api_connexion/endpoints/test_task_instance_endpoint.py b/tests/api_connexion/endpoints/test_task_instance_endpoint.py index a4089c9785a98..b5079c47aa17e 100644 --- a/tests/api_connexion/endpoints/test_task_instance_endpoint.py +++ b/tests/api_connexion/endpoints/test_task_instance_endpoint.py @@ -1224,7 +1224,7 @@ def test_should_respond_200(self, main_dag, task_instances, request_dag, payload task_instances=task_instances, update_extras=False, ) - self.app.dag_bag.sync_to_db() + self.app.dag_bag.sync_to_db("dags-folder", None) response = self.client.post( f"/api/v1/dags/{request_dag}/clearTaskInstances", environ_overrides={"REMOTE_USER": "test"}, @@ -1246,7 +1246,7 @@ def test_clear_taskinstance_is_called_with_queued_dr_state(self, mock_clearti, s self.create_task_instances(session) dag_id = "example_python_operator" payload = {"reset_dag_runs": True, "dry_run": False} - self.app.dag_bag.sync_to_db() + self.app.dag_bag.sync_to_db("dags-folder", None) response = self.client.post( f"/api/v1/dags/{dag_id}/clearTaskInstances", environ_overrides={"REMOTE_USER": "test"}, @@ -1266,7 +1266,7 @@ def test_clear_taskinstance_is_called_with_invalid_task_ids(self, session): assert dagrun.state == "running" payload = {"dry_run": False, "reset_dag_runs": True, "task_ids": [""]} - self.app.dag_bag.sync_to_db() + self.app.dag_bag.sync_to_db("dags-folder", None) response = self.client.post( f"/api/v1/dags/{dag_id}/clearTaskInstances", environ_overrides={"REMOTE_USER": "test"}, @@ -1693,7 +1693,7 @@ def test_should_raise_400_for_naive_and_bad_datetime(self, payload, expected, se task_instances=task_instances, update_extras=False, ) - self.app.dag_bag.sync_to_db() + self.app.dag_bag.sync_to_db("dags-folder", None) response = self.client.post( "/api/v1/dags/example_python_operator/clearTaskInstances", environ_overrides={"REMOTE_USER": "test"}, diff --git a/tests/api_fastapi/conftest.py b/tests/api_fastapi/conftest.py index 2928a4d829c70..71f42756dd564 100644 --- a/tests/api_fastapi/conftest.py +++ b/tests/api_fastapi/conftest.py @@ -16,11 +16,15 @@ # under the License. 
from __future__ import annotations +import os + import pytest from fastapi.testclient import TestClient from airflow.api_fastapi.app import create_app +from tests_common.test_utils.db import parse_and_sync_to_db + @pytest.fixture def test_client(): @@ -42,6 +46,5 @@ def create_test_client(apps="all"): def dagbag(): from airflow.models import DagBag - dagbag_instance = DagBag(include_examples=True, read_dags_from_db=False) - dagbag_instance.sync_to_db() - return dagbag_instance + parse_and_sync_to_db(os.devnull, include_examples=True) + return DagBag(read_dags_from_db=True) diff --git a/tests/api_fastapi/core_api/routes/public/test_dag_parsing.py b/tests/api_fastapi/core_api/routes/public/test_dag_parsing.py index b937f66803f31..fe78c193d389b 100644 --- a/tests/api_fastapi/core_api/routes/public/test_dag_parsing.py +++ b/tests/api_fastapi/core_api/routes/public/test_dag_parsing.py @@ -17,7 +17,6 @@ from __future__ import annotations import os -from typing import TYPE_CHECKING import pytest from sqlalchemy import select @@ -26,19 +25,15 @@ from airflow.models.dagbag import DagPriorityParsingRequest from airflow.utils.session import provide_session -from tests_common.test_utils.db import clear_db_dag_parsing_requests +from tests_common.test_utils.db import clear_db_dag_parsing_requests, parse_and_sync_to_db pytestmark = pytest.mark.db_test -if TYPE_CHECKING: - from airflow.models.dag import DAG - class TestDagParsingEndpoint: ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir, os.pardir)) EXAMPLE_DAG_FILE = os.path.join("airflow", "example_dags", "example_bash_operator.py") - EXAMPLE_DAG_ID = "example_bash_operator" - TEST_DAG_ID = "latest_only" + TEST_DAG_ID = "example_bash_operator" NOT_READABLE_DAG_ID = "latest_only_with_trigger" TEST_MULTIPLE_DAGS_ID = "asset_produces_1" @@ -55,9 +50,9 @@ def teardown_method(self) -> None: self.clear_db() def test_201_and_400_requests(self, url_safe_serializer, session, test_client): - dagbag = DagBag(dag_folder=self.EXAMPLE_DAG_FILE) - dagbag.sync_to_db() - test_dag: DAG = dagbag.dags[self.TEST_DAG_ID] + parse_and_sync_to_db(self.EXAMPLE_DAG_FILE) + dagbag = DagBag(read_dags_from_db=True) + test_dag = dagbag.get_dag(self.TEST_DAG_ID) url = f"/public/parseDagFile/{url_safe_serializer.dumps(test_dag.fileloc)}" response = test_client.put(url, headers={"Accept": "application/json"}) diff --git a/tests/api_fastapi/core_api/routes/public/test_dag_run.py b/tests/api_fastapi/core_api/routes/public/test_dag_run.py index 75fec91e1a0c7..9b7c9211fddef 100644 --- a/tests/api_fastapi/core_api/routes/public/test_dag_run.py +++ b/tests/api_fastapi/core_api/routes/public/test_dag_run.py @@ -141,7 +141,7 @@ def setup(request, dag_maker, session=None): logical_date=LOGICAL_DATE4, ) - dag_maker.dagbag.sync_to_db() + dag_maker.sync_dagbag_to_db() dag_maker.dag_model dag_maker.dag_model.has_task_concurrency_limits = True session.merge(ti1) diff --git a/tests/api_fastapi/core_api/routes/public/test_dag_sources.py b/tests/api_fastapi/core_api/routes/public/test_dag_sources.py index 4e8a5990ecb26..dce7981872085 100644 --- a/tests/api_fastapi/core_api/routes/public/test_dag_sources.py +++ b/tests/api_fastapi/core_api/routes/public/test_dag_sources.py @@ -28,7 +28,7 @@ from airflow.models.dagcode import DagCode from airflow.models.serialized_dag import SerializedDagModel -from tests_common.test_utils.db import clear_db_dags +from tests_common.test_utils.db import clear_db_dags, parse_and_sync_to_db pytestmark = pytest.mark.db_test @@ -36,14 +36,13 @@ # 
Example bash operator located here: airflow/example_dags/example_bash_operator.py EXAMPLE_DAG_FILE = os.path.join("airflow", "example_dags", "example_bash_operator.py") -TEST_DAG_ID = "latest_only" +TEST_DAG_ID = "example_bash_operator" @pytest.fixture def test_dag(): - dagbag = DagBag(include_examples=True) - dagbag.sync_to_db() - return dagbag.dags[TEST_DAG_ID] + parse_and_sync_to_db(EXAMPLE_DAG_FILE, include_examples=False) + return DagBag(read_dags_from_db=True).get_dag(TEST_DAG_ID) class TestGetDAGSource: @@ -131,9 +130,7 @@ def test_should_respond_200_version(self, test_client, accept, session, test_dag "version_number": 2, } - def test_should_respond_406_unsupport_mime_type(self, test_client): - dagbag = DagBag(include_examples=True) - dagbag.sync_to_db() + def test_should_respond_406_unsupport_mime_type(self, test_client, test_dag): response = test_client.get( f"{API_PREFIX}/{TEST_DAG_ID}", headers={"Accept": "text/html"}, diff --git a/tests/api_fastapi/core_api/routes/public/test_dag_tags.py b/tests/api_fastapi/core_api/routes/public/test_dag_tags.py index 784bb480c431b..7d7720c76f773 100644 --- a/tests/api_fastapi/core_api/routes/public/test_dag_tags.py +++ b/tests/api_fastapi/core_api/routes/public/test_dag_tags.py @@ -111,6 +111,7 @@ def setup(self, dag_maker, session=None) -> None: ): EmptyOperator(task_id=TASK_ID) + dag_maker.sync_dagbag_to_db() dag_maker.create_dagrun(state=DagRunState.FAILED) with dag_maker( @@ -127,7 +128,7 @@ def setup(self, dag_maker, session=None) -> None: self._create_deactivated_paused_dag(session) self._create_dag_tags(session) - dag_maker.dagbag.sync_to_db() + dag_maker.sync_dagbag_to_db() dag_maker.dag_model.has_task_concurrency_limits = True session.merge(dag_maker.dag_model) session.commit() diff --git a/tests/api_fastapi/core_api/routes/public/test_dags.py b/tests/api_fastapi/core_api/routes/public/test_dags.py index 05a49c44dfeea..b79c23eb36b4e 100644 --- a/tests/api_fastapi/core_api/routes/public/test_dags.py +++ b/tests/api_fastapi/core_api/routes/public/test_dags.py @@ -127,7 +127,7 @@ def setup(self, dag_maker, session=None) -> None: self._create_deactivated_paused_dag(session) self._create_dag_tags(session) - dag_maker.dagbag.sync_to_db() + dag_maker.sync_dagbag_to_db() dag_maker.dag_model.has_task_concurrency_limits = True session.merge(dag_maker.dag_model) session.commit() @@ -409,7 +409,7 @@ def _create_dag_for_deletion( ti = dr.get_task_instances()[0] ti.set_state(TaskInstanceState.RUNNING) - dag_maker.dagbag.sync_to_db() + dag_maker.sync_dagbag_to_db() @pytest.mark.parametrize( "dag_id, dag_display_name, status_code_delete, status_code_details, has_running_dagruns, is_create_dag", diff --git a/tests/api_fastapi/core_api/routes/public/test_extra_links.py b/tests/api_fastapi/core_api/routes/public/test_extra_links.py index 6a2ca4650075e..89aef555bbe20 100644 --- a/tests/api_fastapi/core_api/routes/public/test_extra_links.py +++ b/tests/api_fastapi/core_api/routes/public/test_extra_links.py @@ -21,6 +21,7 @@ import pytest +from airflow.dag_processing.bundles.manager import DagBundlesManager from airflow.models.dag import DAG from airflow.models.dagbag import DagBag from airflow.models.xcom import XCom @@ -67,10 +68,11 @@ def setup(self, test_client, session=None) -> None: self.dag = self._create_dag() + DagBundlesManager().sync_bundles_to_db() dag_bag = DagBag(os.devnull, include_examples=False) dag_bag.dags = {self.dag.dag_id: self.dag} test_client.app.state.dag_bag = dag_bag - dag_bag.sync_to_db() + dag_bag.sync_to_db("dags-folder", 
None) triggered_by_kwargs = {"triggered_by": DagRunTriggeredByType.TEST} diff --git a/tests/api_fastapi/core_api/routes/public/test_task_instances.py b/tests/api_fastapi/core_api/routes/public/test_task_instances.py index d62d37944348c..f0c3c888807ea 100644 --- a/tests/api_fastapi/core_api/routes/public/test_task_instances.py +++ b/tests/api_fastapi/core_api/routes/public/test_task_instances.py @@ -26,6 +26,7 @@ import pytest from sqlalchemy import select +from airflow.dag_processing.bundles.manager import DagBundlesManager from airflow.jobs.job import Job from airflow.jobs.triggerer_job_runner import TriggererJobRunner from airflow.models import DagRun, TaskInstance @@ -522,9 +523,10 @@ def create_dag_runs_with_mapped_tasks(self, dag_maker, session, dags=None): setattr(ti, "start_date", DEFAULT_DATETIME_1) session.add(ti) + DagBundlesManager().sync_bundles_to_db() dagbag = DagBag(os.devnull, include_examples=False) dagbag.dags = {dag_id: dag_maker.dag} - dagbag.sync_to_db() + dagbag.sync_to_db("dags-folder", None) session.flush() mapped.expand_mapped_task(dr.run_id, session=session) @@ -1857,7 +1859,7 @@ def test_should_respond_200( task_instances=task_instances, update_extras=False, ) - self.dagbag.sync_to_db() + self.dagbag.sync_to_db("dags-folder", None) response = test_client.post( f"/public/dags/{request_dag}/clearTaskInstances", json=payload, @@ -1871,7 +1873,7 @@ def test_clear_taskinstance_is_called_with_queued_dr_state(self, mock_clearti, t self.create_task_instances(session) dag_id = "example_python_operator" payload = {"reset_dag_runs": True, "dry_run": False} - self.dagbag.sync_to_db() + self.dagbag.sync_to_db("dags-folder", None) response = test_client.post( f"/public/dags/{dag_id}/clearTaskInstances", json=payload, @@ -1891,7 +1893,7 @@ def test_clear_taskinstance_is_called_with_invalid_task_ids(self, test_client, s assert dagrun.state == "running" payload = {"dry_run": False, "reset_dag_runs": True, "task_ids": [""]} - self.dagbag.sync_to_db() + self.dagbag.sync_to_db("dags-folder", None) response = test_client.post( f"/public/dags/{dag_id}/clearTaskInstances", json=payload, @@ -1941,7 +1943,7 @@ def test_should_respond_200_with_reset_dag_run(self, test_client, session): update_extras=False, dag_run_state=DagRunState.FAILED, ) - self.dagbag.sync_to_db() + self.dagbag.sync_to_db("dags-folder", None) response = test_client.post( f"/public/dags/{dag_id}/clearTaskInstances", json=payload, @@ -2026,7 +2028,7 @@ def test_should_respond_200_with_dag_run_id(self, test_client, session): update_extras=False, dag_run_state=State.FAILED, ) - self.dagbag.sync_to_db() + self.dagbag.sync_to_db("dags-folder", None) response = test_client.post( f"/public/dags/{dag_id}/clearTaskInstances", json=payload, @@ -2082,7 +2084,7 @@ def test_should_respond_200_with_include_past(self, test_client, session): update_extras=False, dag_run_state=State.FAILED, ) - self.dagbag.sync_to_db() + self.dagbag.sync_to_db("dags-folder", None) response = test_client.post( f"/public/dags/{dag_id}/clearTaskInstances", json=payload, @@ -2164,7 +2166,7 @@ def test_should_respond_200_with_include_future(self, test_client, session): update_extras=False, dag_run_state=State.FAILED, ) - self.dagbag.sync_to_db() + self.dagbag.sync_to_db("dags-folder", None) response = test_client.post( f"/public/dags/{dag_id}/clearTaskInstances", json=payload, @@ -2312,7 +2314,7 @@ def test_should_raise_400_for_naive_and_bad_datetime(self, test_client, session, task_instances=task_instances, update_extras=False, ) - self.dagbag.sync_to_db() 
+ self.dagbag.sync_to_db("dags-folder", None) response = test_client.post( "/public/dags/example_python_operator/clearTaskInstances", json=payload, diff --git a/tests/api_fastapi/core_api/routes/ui/test_assets.py b/tests/api_fastapi/core_api/routes/ui/test_assets.py index cbe1b292f6d7e..4935b25014ea7 100644 --- a/tests/api_fastapi/core_api/routes/ui/test_assets.py +++ b/tests/api_fastapi/core_api/routes/ui/test_assets.py @@ -46,7 +46,7 @@ def test_next_run_assets(test_client, dag_maker): EmptyOperator(task_id="task1") dag_maker.create_dagrun() - dag_maker.dagbag.sync_to_db() + dag_maker.sync_dagbag_to_db() response = test_client.get("/ui/next_run_assets/upstream") diff --git a/tests/api_fastapi/core_api/routes/ui/test_dashboard.py b/tests/api_fastapi/core_api/routes/ui/test_dashboard.py index 93317eaa67088..416e016e25c95 100644 --- a/tests/api_fastapi/core_api/routes/ui/test_dashboard.py +++ b/tests/api_fastapi/core_api/routes/ui/test_dashboard.py @@ -95,7 +95,7 @@ def make_dag_runs(dag_maker, session, time_machine): for ti in run2.task_instances: ti.state = TaskInstanceState.FAILED - dag_maker.dagbag.sync_to_db() + dag_maker.sync_dagbag_to_db() time_machine.move_to("2023-07-02T00:00:00+00:00", tick=False) diff --git a/tests/api_fastapi/core_api/routes/ui/test_structure.py b/tests/api_fastapi/core_api/routes/ui/test_structure.py index 9732269944dfb..1f865a140c5be 100644 --- a/tests/api_fastapi/core_api/routes/ui/test_structure.py +++ b/tests/api_fastapi/core_api/routes/ui/test_structure.py @@ -59,7 +59,7 @@ def make_dag(dag_maker, session, time_machine): ): TriggerDagRunOperator(task_id="trigger_dag_run_operator", trigger_dag_id=DAG_ID) - dag_maker.dagbag.sync_to_db() + dag_maker.sync_dagbag_to_db() with dag_maker( dag_id=DAG_ID, @@ -78,7 +78,7 @@ def make_dag(dag_maker, session, time_machine): >> EmptyOperator(task_id="task_2") ) - dag_maker.dagbag.sync_to_db() + dag_maker.sync_dagbag_to_db() class TestStructureDataEndpoint: diff --git a/tests/cli/commands/remote_commands/test_asset_command.py b/tests/cli/commands/remote_commands/test_asset_command.py index 69906d1813a29..067bc9e8e1e2b 100644 --- a/tests/cli/commands/remote_commands/test_asset_command.py +++ b/tests/cli/commands/remote_commands/test_asset_command.py @@ -21,15 +21,15 @@ import contextlib import io import json +import os import typing import pytest from airflow.cli import cli_parser from airflow.cli.commands.remote_commands import asset_command -from airflow.models.dagbag import DagBag -from tests_common.test_utils.db import clear_db_dags, clear_db_runs +from tests_common.test_utils.db import clear_db_dags, clear_db_runs, parse_and_sync_to_db if typing.TYPE_CHECKING: from argparse import ArgumentParser @@ -39,7 +39,7 @@ @pytest.fixture(scope="module", autouse=True) def prepare_examples(): - DagBag(include_examples=True).sync_to_db() + parse_and_sync_to_db(os.devnull, include_examples=True) yield clear_db_runs() clear_db_dags() diff --git a/tests/cli/commands/remote_commands/test_backfill_command.py b/tests/cli/commands/remote_commands/test_backfill_command.py index e252d4db1c478..38b0169a0e7dd 100644 --- a/tests/cli/commands/remote_commands/test_backfill_command.py +++ b/tests/cli/commands/remote_commands/test_backfill_command.py @@ -18,6 +18,7 @@ from __future__ import annotations import argparse +import os from datetime import datetime from unittest import mock @@ -26,11 +27,10 @@ import airflow.cli.commands.remote_commands.backfill_command from airflow.cli import cli_parser -from airflow.models import DagBag from 
airflow.models.backfill import ReprocessBehavior from airflow.utils import timezone -from tests_common.test_utils.db import clear_db_backfills, clear_db_dags, clear_db_runs +from tests_common.test_utils.db import clear_db_backfills, clear_db_dags, clear_db_runs, parse_and_sync_to_db DEFAULT_DATE = timezone.make_aware(datetime(2015, 1, 1), timezone=timezone.utc) if pendulum.__version__.startswith("3"): @@ -48,8 +48,7 @@ class TestCliBackfill: @classmethod def setup_class(cls): - cls.dagbag = DagBag(include_examples=True) - cls.dagbag.sync_to_db() + parse_and_sync_to_db(os.devnull, include_examples=True) cls.parser = cli_parser.get_parser() @classmethod diff --git a/tests/cli/commands/remote_commands/test_dag_command.py b/tests/cli/commands/remote_commands/test_dag_command.py index dab4d0da6caa0..0db0ac83df02f 100644 --- a/tests/cli/commands/remote_commands/test_dag_command.py +++ b/tests/cli/commands/remote_commands/test_dag_command.py @@ -50,7 +50,7 @@ from tests.models import TEST_DAGS_FOLDER from tests_common.test_utils.config import conf_vars -from tests_common.test_utils.db import clear_db_dags, clear_db_runs +from tests_common.test_utils.db import clear_db_dags, clear_db_runs, parse_and_sync_to_db DEFAULT_DATE = timezone.make_aware(datetime(2015, 1, 1), timezone=timezone.utc) if pendulum.__version__.startswith("3"): @@ -68,8 +68,7 @@ class TestCliDags: @classmethod def setup_class(cls): - cls.dagbag = DagBag(include_examples=True) - cls.dagbag.sync_to_db() + parse_and_sync_to_db(os.devnull, include_examples=True) cls.parser = cli_parser.get_parser() @classmethod @@ -207,8 +206,7 @@ def test_next_execution(self, tmp_path): with time_machine.travel(DEFAULT_DATE): clear_db_dags() - self.dagbag = DagBag(dag_folder=tmp_path, include_examples=False) - self.dagbag.sync_to_db() + parse_and_sync_to_db(tmp_path, include_examples=False) default_run = DEFAULT_DATE future_run = default_run + timedelta(days=5) @@ -255,8 +253,7 @@ def test_next_execution(self, tmp_path): # Rebuild Test DB for other tests clear_db_dags() - TestCliDags.dagbag = DagBag(include_examples=True) - TestCliDags.dagbag.sync_to_db() + parse_and_sync_to_db(os.devnull, include_examples=True) @conf_vars({("core", "load_examples"): "true"}) def test_cli_report(self): @@ -405,24 +402,24 @@ def test_cli_list_jobs_with_args(self): def test_pause(self): args = self.parser.parse_args(["dags", "pause", "example_bash_operator"]) dag_command.dag_pause(args) - assert self.dagbag.dags["example_bash_operator"].get_is_paused() + assert DagModel.get_dagmodel("example_bash_operator").is_paused dag_command.dag_unpause(args) - assert not self.dagbag.dags["example_bash_operator"].get_is_paused() + assert not DagModel.get_dagmodel("example_bash_operator").is_paused @mock.patch("airflow.cli.commands.remote_commands.dag_command.ask_yesno") def test_pause_regex(self, mock_yesno): args = self.parser.parse_args(["dags", "pause", "^example_.*$", "--treat-dag-id-as-regex"]) dag_command.dag_pause(args) mock_yesno.assert_called_once() - assert self.dagbag.dags["example_bash_decorator"].get_is_paused() - assert self.dagbag.dags["example_kubernetes_executor"].get_is_paused() - assert self.dagbag.dags["example_xcom_args"].get_is_paused() + assert DagModel.get_dagmodel("example_bash_decorator").is_paused + assert DagModel.get_dagmodel("example_kubernetes_executor").is_paused + assert DagModel.get_dagmodel("example_xcom_args").is_paused args = self.parser.parse_args(["dags", "unpause", "^example_.*$", "--treat-dag-id-as-regex"]) dag_command.dag_unpause(args) - 
assert not self.dagbag.dags["example_bash_decorator"].get_is_paused() - assert not self.dagbag.dags["example_kubernetes_executor"].get_is_paused() - assert not self.dagbag.dags["example_xcom_args"].get_is_paused() + assert not DagModel.get_dagmodel("example_bash_decorator").is_paused + assert not DagModel.get_dagmodel("example_kubernetes_executor").is_paused + assert not DagModel.get_dagmodel("example_xcom_args").is_paused @mock.patch("airflow.cli.commands.remote_commands.dag_command.ask_yesno") def test_pause_regex_operation_cancelled(self, ask_yesno, capsys): diff --git a/tests/cli/commands/remote_commands/test_task_command.py b/tests/cli/commands/remote_commands/test_task_command.py index 843d6817cdcc5..9a4c606caa469 100644 --- a/tests/cli/commands/remote_commands/test_task_command.py +++ b/tests/cli/commands/remote_commands/test_task_command.py @@ -53,7 +53,7 @@ from airflow.utils.types import DagRunType from tests_common.test_utils.config import conf_vars -from tests_common.test_utils.db import clear_db_pools, clear_db_runs +from tests_common.test_utils.db import clear_db_pools, clear_db_runs, parse_and_sync_to_db from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS if AIRFLOW_V_3_0_PLUS: @@ -97,12 +97,12 @@ class TestCliTasks: @pytest.fixture(autouse=True) def setup_class(cls): logging.config.dictConfig(DEFAULT_LOGGING_CONFIG) - cls.dagbag = DagBag(include_examples=True) + parse_and_sync_to_db(os.devnull, include_examples=True) cls.parser = cli_parser.get_parser() clear_db_runs() + cls.dagbag = DagBag(read_dags_from_db=True) cls.dag = cls.dagbag.get_dag(cls.dag_id) - cls.dagbag.sync_to_db() data_interval = cls.dag.timetable.infer_manual_data_interval(run_after=DEFAULT_DATE) triggered_by_kwargs = {"triggered_by": DagRunTriggeredByType.CLI} if AIRFLOW_V_3_0_PLUS else {} cls.dag_run = cls.dag.create_dagrun( @@ -164,7 +164,7 @@ def test_cli_test_different_path(self, session, tmp_path): with conf_vars({("core", "dags_folder"): orig_dags_folder.as_posix()}): dagbag = DagBag(include_examples=False) dag = dagbag.get_dag("test_dags_folder") - dagbag.sync_to_db(session=session) + dagbag.sync_to_db("dags-folder", None, session=session) logical_date = pendulum.now("UTC") data_interval = dag.timetable.infer_manual_data_interval(run_after=logical_date) diff --git a/tests/conftest.py b/tests/conftest.py index de13fe99c4bf6..8e41e35e35d06 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -16,15 +16,20 @@ # under the License. from __future__ import annotations +import json import logging import os import sys +from contextlib import contextmanager from typing import TYPE_CHECKING import pytest from tests_common.test_utils.log_handlers import non_pytest_handlers +if TYPE_CHECKING: + from pathlib import Path + # We should set these before loading _any_ of the rest of airflow so that the # unit test mode config is set as early as possible. 
assert "airflow" not in sys.modules, "No airflow module can be imported before these lines" @@ -81,6 +86,37 @@ def clear_all_logger_handlers(): remove_all_non_pytest_log_handlers() +@pytest.fixture +def testing_dag_bundle(): + from airflow.models.dagbundle import DagBundleModel + from airflow.utils.session import create_session + + with create_session() as session: + if session.query(DagBundleModel).filter(DagBundleModel.name == "testing").count() == 0: + testing = DagBundleModel(name="testing") + session.add(testing) + + +@pytest.fixture +def configure_testing_dag_bundle(): + """Configure the testing DAG bundle with the provided path, and disable the DAGs folder bundle.""" + from tests_common.test_utils.config import conf_vars + + @contextmanager + def _config_bundle(path_to_parse: Path | str): + bundle_config = [ + { + "name": "testing", + "classpath": "airflow.dag_processing.bundles.local.LocalDagBundle", + "kwargs": {"local_folder": str(path_to_parse), "refresh_interval": 0}, + } + ] + with conf_vars({("dag_bundles", "backends"): json.dumps(bundle_config)}): + yield + + return _config_bundle + + if TYPE_CHECKING: # Static checkers do not know about pytest fixtures' types and return, # In case if them distributed through third party packages. diff --git a/tests/dag_processing/bundles/test_dag_bundle_manager.py b/tests/dag_processing/bundles/test_dag_bundle_manager.py index 47b192efc198f..eee7dee211472 100644 --- a/tests/dag_processing/bundles/test_dag_bundle_manager.py +++ b/tests/dag_processing/bundles/test_dag_bundle_manager.py @@ -68,7 +68,9 @@ ) def test_parse_bundle_config(value, expected): """Test that bundle_configs are read from configuration.""" - envs = {"AIRFLOW__DAG_BUNDLES__BACKENDS": value} if value else {} + envs = {"AIRFLOW__CORE__LOAD_EXAMPLES": "False"} + if value: + envs["AIRFLOW__DAG_BUNDLES__BACKENDS"] = value cm = nullcontext() exp_fail = False if isinstance(expected, str): @@ -133,6 +135,7 @@ def clear_db(): @pytest.mark.db_test +@conf_vars({("core", "LOAD_EXAMPLES"): "False"}) def test_sync_bundles_to_db(clear_db): def _get_bundle_names_and_active(): with create_session() as session: @@ -167,3 +170,14 @@ def test_view_url(version): with patch.object(BaseDagBundle, "view_url") as view_url_mock: bundle_manager.view_url("my-test-bundle", version=version) view_url_mock.assert_called_once_with(version=version) + + +def test_example_dags_bundle_added(): + manager = DagBundlesManager() + manager.parse_config() + assert "example_dags" in manager._bundle_config + + with conf_vars({("core", "LOAD_EXAMPLES"): "False"}): + manager = DagBundlesManager() + manager.parse_config() + assert "example_dags" not in manager._bundle_config diff --git a/tests/dag_processing/test_collection.py b/tests/dag_processing/test_collection.py index a248904cbefcc..8ef07514bbfbe 100644 --- a/tests/dag_processing/test_collection.py +++ b/tests/dag_processing/test_collection.py @@ -185,7 +185,7 @@ def dag_to_lazy_serdag(self, dag: DAG) -> LazyDeserializedDAG: @pytest.mark.usefixtures("clean_db") # sync_perms in fab has bad session commit hygiene def test_sync_perms_syncs_dag_specific_perms_on_update( - self, monkeypatch, spy_agency: SpyAgency, session, time_machine + self, monkeypatch, spy_agency: SpyAgency, session, time_machine, testing_dag_bundle ): """ Test that dagbag.sync_to_db will sync DAG specific permissions when a DAG is @@ -210,7 +210,7 @@ def _sync_to_db(): sync_perms_spy.reset_calls() time_machine.shift(20) - update_dag_parsing_results_in_db([dag], dict(), set(), session) + 
update_dag_parsing_results_in_db("testing", None, [dag], dict(), set(), session) _sync_to_db() spy_agency.assert_spy_called_with(sync_perms_spy, dag, session=session) @@ -228,7 +228,9 @@ def _sync_to_db(): @patch.object(SerializedDagModel, "write_dag") @patch("airflow.models.dag.DAG.bulk_write_to_db") - def test_sync_to_db_is_retried(self, mock_bulk_write_to_db, mock_s10n_write_dag, session): + def test_sync_to_db_is_retried( + self, mock_bulk_write_to_db, mock_s10n_write_dag, testing_dag_bundle, session + ): """Test that important DB operations in db sync are retried on OperationalError""" serialized_dags_count = session.query(func.count(SerializedDagModel.dag_id)).scalar() @@ -244,14 +246,16 @@ def test_sync_to_db_is_retried(self, mock_bulk_write_to_db, mock_s10n_write_dag, mock_bulk_write_to_db.side_effect = side_effect mock_session = mock.MagicMock() - update_dag_parsing_results_in_db(dags=dags, import_errors={}, warnings=set(), session=mock_session) + update_dag_parsing_results_in_db( + "testing", None, dags=dags, import_errors={}, warnings=set(), session=mock_session + ) # Test that 3 attempts were made to run 'DAG.bulk_write_to_db' successfully mock_bulk_write_to_db.assert_has_calls( [ - mock.call(mock.ANY, session=mock.ANY), - mock.call(mock.ANY, session=mock.ANY), - mock.call(mock.ANY, session=mock.ANY), + mock.call("testing", None, mock.ANY, session=mock.ANY), + mock.call("testing", None, mock.ANY, session=mock.ANY), + mock.call("testing", None, mock.ANY, session=mock.ANY), ] ) # Assert that rollback is called twice (i.e. whenever OperationalError occurs) @@ -268,7 +272,7 @@ def test_sync_to_db_is_retried(self, mock_bulk_write_to_db, mock_s10n_write_dag, serialized_dags_count = session.query(func.count(SerializedDagModel.dag_id)).scalar() assert serialized_dags_count == 0 - def test_serialized_dags_are_written_to_db_on_sync(self, session): + def test_serialized_dags_are_written_to_db_on_sync(self, testing_dag_bundle, session): """ Test that when dagbag.sync_to_db is called the DAGs are Serialized and written to DB even when dagbag.read_dags_from_db is False @@ -278,14 +282,14 @@ def test_serialized_dags_are_written_to_db_on_sync(self, session): dag = DAG(dag_id="test") - update_dag_parsing_results_in_db([dag], dict(), set(), session) + update_dag_parsing_results_in_db("testing", None, [dag], dict(), set(), session) new_serialized_dags_count = session.query(func.count(SerializedDagModel.dag_id)).scalar() assert new_serialized_dags_count == 1 @patch.object(SerializedDagModel, "write_dag") def test_serialized_dag_errors_are_import_errors( - self, mock_serialize, caplog, session, dag_import_error_listener + self, mock_serialize, caplog, session, dag_import_error_listener, testing_dag_bundle ): """ Test that errors serializing a DAG are recorded as import_errors in the DB @@ -298,7 +302,7 @@ def test_serialized_dag_errors_are_import_errors( dag.fileloc = "abc.py" import_errors = {} - update_dag_parsing_results_in_db([dag], import_errors, set(), session) + update_dag_parsing_results_in_db("testing", None, [dag], import_errors, set(), session) assert "SerializationError" in caplog.text # Should have been edited in place @@ -320,7 +324,7 @@ def test_serialized_dag_errors_are_import_errors( assert len(dag_import_error_listener.existing) == 0 assert dag_import_error_listener.new["abc.py"] == import_error.stacktrace - def test_new_import_error_replaces_old(self, session, dag_import_error_listener): + def test_new_import_error_replaces_old(self, session, dag_import_error_listener, 
testing_dag_bundle): """ Test that existing import error is updated and new record not created for a dag with the same filename @@ -336,6 +340,8 @@ def test_new_import_error_replaces_old(self, session, dag_import_error_listener) prev_error_id = prev_error.id update_dag_parsing_results_in_db( + bundle_name="testing", + bundle_version=None, dags=[], import_errors={"abc.py": "New error"}, warnings=set(), @@ -353,7 +359,7 @@ def test_new_import_error_replaces_old(self, session, dag_import_error_listener) assert len(dag_import_error_listener.existing) == 1 assert dag_import_error_listener.existing["abc.py"] == prev_error.stacktrace - def test_remove_error_clears_import_error(self, session): + def test_remove_error_clears_import_error(self, testing_dag_bundle, session): # Pre-condition: there is an import error for the dag file filename = "abc.py" prev_error = ParseImportError( @@ -381,7 +387,7 @@ def test_remove_error_clears_import_error(self, session): dag.fileloc = filename import_errors = {} - update_dag_parsing_results_in_db([dag], import_errors, set(), session) + update_dag_parsing_results_in_db("testing", None, [dag], import_errors, set(), session) dag_model: DagModel = session.get(DagModel, (dag.dag_id,)) assert dag_model.has_import_errors is False @@ -473,7 +479,7 @@ def _sync_perms(): ], ) @pytest.mark.usefixtures("clean_db") - def test_dagmodel_properties(self, attrs, expected, session, time_machine): + def test_dagmodel_properties(self, attrs, expected, session, time_machine, testing_dag_bundle): """Test that properties on the dag model are correctly set when dealing with a LazySerializedDag""" dt = tz.datetime(2020, 1, 5, 0, 0, 0) time_machine.move_to(dt, tick=False) @@ -493,7 +499,7 @@ def test_dagmodel_properties(self, attrs, expected, session, time_machine): session.add(dr1) session.commit() - update_dag_parsing_results_in_db([self.dag_to_lazy_serdag(dag)], {}, set(), session) + update_dag_parsing_results_in_db("testing", None, [self.dag_to_lazy_serdag(dag)], {}, set(), session) orm_dag = session.get(DagModel, ("dag",)) @@ -505,14 +511,21 @@ def test_dagmodel_properties(self, attrs, expected, session, time_machine): assert orm_dag.last_parsed_time == dt - def test_existing_dag_is_paused_upon_creation(self, session): + def test_existing_dag_is_paused_upon_creation(self, testing_dag_bundle, session): dag = DAG("dag_paused", schedule=None) - update_dag_parsing_results_in_db([self.dag_to_lazy_serdag(dag)], {}, set(), session) + update_dag_parsing_results_in_db("testing", None, [self.dag_to_lazy_serdag(dag)], {}, set(), session) orm_dag = session.get(DagModel, ("dag_paused",)) assert orm_dag.is_paused is False dag = DAG("dag_paused", schedule=None, is_paused_upon_creation=True) - update_dag_parsing_results_in_db([self.dag_to_lazy_serdag(dag)], {}, set(), session) + update_dag_parsing_results_in_db("testing", None, [self.dag_to_lazy_serdag(dag)], {}, set(), session) # Since the dag existed before, it should not follow the pause flag upon creation orm_dag = session.get(DagModel, ("dag_paused",)) assert orm_dag.is_paused is False + + def test_bundle_name_and_version_are_stored(self, testing_dag_bundle, session): + dag = DAG("mydag", schedule=None) + update_dag_parsing_results_in_db("testing", "1.0", [self.dag_to_lazy_serdag(dag)], {}, set(), session) + orm_dag = session.get(DagModel, "mydag") + assert orm_dag.bundle_name == "testing" + assert orm_dag.latest_bundle_version == "1.0" diff --git a/tests/dag_processing/test_manager.py b/tests/dag_processing/test_manager.py index 
2cc8a43e05450..7608cbbd32766 100644 --- a/tests/dag_processing/test_manager.py +++ b/tests/dag_processing/test_manager.py @@ -19,6 +19,7 @@ import io import itertools +import json import logging import multiprocessing import os @@ -43,15 +44,18 @@ from airflow.callbacks.callback_requests import DagCallbackRequest from airflow.config_templates.airflow_local_settings import DEFAULT_LOGGING_CONFIG +from airflow.dag_processing.bundles.manager import DagBundlesManager from airflow.dag_processing.manager import ( + DagFileInfo, DagFileProcessorAgent, DagFileProcessorManager, DagFileStat, ) from airflow.dag_processing.processor import DagFileProcessorProcess -from airflow.models import DagBag, DagModel, DbCallbackRequest +from airflow.models import DAG, DagBag, DagModel, DbCallbackRequest from airflow.models.asset import TaskOutletAssetReference from airflow.models.dag_version import DagVersion +from airflow.models.dagbundle import DagBundleModel from airflow.models.dagcode import DagCode from airflow.models.serialized_dag import SerializedDagModel from airflow.utils import timezone @@ -78,6 +82,10 @@ DEFAULT_DATE = timezone.datetime(2016, 1, 1) +def _get_dag_file_paths(files: list[str]) -> list[DagFileInfo]: + return [DagFileInfo(bundle_name="testing", path=f) for f in files] + + class TestDagFileProcessorManager: @pytest.fixture(autouse=True) def _disable_examples(self): @@ -101,9 +109,6 @@ def teardown_class(self): clear_db_callbacks() clear_db_import_errors() - def run_processor_manager_one_loop(self, manager: DagFileProcessorManager) -> None: - manager._run_parsing_loop() - def mock_processor(self) -> DagFileProcessorProcess: proc = MagicMock() proc.create_time.return_value = time.time() @@ -124,49 +129,51 @@ def clear_parse_import_errors(self): @pytest.mark.usefixtures("clear_parse_import_errors") @conf_vars({("core", "load_examples"): "False"}) - def test_remove_file_clears_import_error(self, tmp_path): + def test_remove_file_clears_import_error(self, tmp_path, configure_testing_dag_bundle): path_to_parse = tmp_path / "temp_dag.py" # Generate original import error path_to_parse.write_text("an invalid airflow DAG") - manager = DagFileProcessorManager( - dag_directory=path_to_parse.parent, - max_runs=1, - processor_timeout=365 * 86_400, - ) + with configure_testing_dag_bundle(path_to_parse): + manager = DagFileProcessorManager( + max_runs=1, + processor_timeout=365 * 86_400, + ) - with create_session() as session: - self.run_processor_manager_one_loop(manager) + with create_session() as session: + manager.run() - import_errors = session.query(ParseImportError).all() - assert len(import_errors) == 1 + import_errors = session.query(ParseImportError).all() + assert len(import_errors) == 1 - path_to_parse.unlink() + path_to_parse.unlink() - # Rerun the parser once the dag file has been removed - self.run_processor_manager_one_loop(manager) - import_errors = session.query(ParseImportError).all() + # Rerun the parser once the dag file has been removed + manager.run() + import_errors = session.query(ParseImportError).all() - assert len(import_errors) == 0 - session.rollback() + assert len(import_errors) == 0 + session.rollback() @conf_vars({("core", "load_examples"): "False"}) def test_max_runs_when_no_files(self, tmp_path): - manager = DagFileProcessorManager(dag_directory=tmp_path, max_runs=1) + with conf_vars({("core", "dags_folder"): str(tmp_path)}): + manager = DagFileProcessorManager(max_runs=1) + manager.run() - self.run_processor_manager_one_loop(manager) + # TODO: AIP-66 no asserts? 
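Taken together, the fixtures and refactored tests above converge on one pattern: the DAG folder is supplied through the configured "testing" bundle rather than a dag_directory argument, the manager is built with only max_runs (plus an optional processor_timeout), and a full pass is driven by manager.run() instead of the old _run_parsing_loop() helper. A minimal sketch of that pattern, assuming the fixtures and imports already available in this test module (the test name and file contents are placeholders):

    def test_bundle_backed_parsing(tmp_path, configure_testing_dag_bundle):
        # Reuse the invalid-DAG trick from test_remove_file_clears_import_error above.
        (tmp_path / "temp_dag.py").write_text("an invalid airflow DAG")

        with configure_testing_dag_bundle(tmp_path):
            # No dag_directory kwarg any more; files are discovered through the configured bundle.
            manager = DagFileProcessorManager(max_runs=1)
            manager.run()  # drives one full pass, replacing the old _run_parsing_loop() helper

        # One parsing pass over the single file in the bundle.
        assert sum(stat.run_count for stat in manager._file_stats.values()) == 1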
def test_start_new_processes_with_same_filepath(self): """ Test that when a processor already exist with a filepath, a new processor won't be created with that filepath. The filepath will just be removed from the list. """ - manager = DagFileProcessorManager(dag_directory="directory", max_runs=1) + manager = DagFileProcessorManager(max_runs=1) - file_1 = "file_1.py" - file_2 = "file_2.py" - file_3 = "file_3.py" + file_1 = DagFileInfo(bundle_name="testing", path="file_1.py") + file_2 = DagFileInfo(bundle_name="testing", path="file_2.py") + file_3 = DagFileInfo(bundle_name="testing", path="file_3.py") manager._file_path_queue = deque([file_1, file_2, file_3]) # Mock that only one processor exists. This processor runs with 'file_1' @@ -185,49 +192,47 @@ def test_start_new_processes_with_same_filepath(self): assert deque([file_3]) == manager._file_path_queue def test_set_file_paths_when_processor_file_path_not_in_new_file_paths(self): - manager = DagFileProcessorManager(dag_directory="directory", max_runs=1) - - mock_processor = MagicMock() - mock_processor.stop.side_effect = AttributeError("DagFileProcessor object has no attribute stop") - mock_processor.terminate.side_effect = None + """Ensure processors and file stats are removed when the file path is not in the new file paths""" + manager = DagFileProcessorManager(max_runs=1) + file = DagFileInfo(bundle_name="testing", path="missing_file.txt") - manager._processors["missing_file.txt"] = mock_processor - manager._file_stats["missing_file.txt"] = DagFileStat() + manager._processors[file] = MagicMock() + manager._file_stats[file] = DagFileStat() manager.set_file_paths(["abc.txt"]) assert manager._processors == {} - assert "missing_file.txt" not in manager._file_stats + assert file not in manager._file_stats def test_set_file_paths_when_processor_file_path_is_in_new_file_paths(self): - manager = DagFileProcessorManager(dag_directory="directory", max_runs=1) - + manager = DagFileProcessorManager(max_runs=1) + file = DagFileInfo(bundle_name="testing", path="abc.txt") mock_processor = MagicMock() - mock_processor.stop.side_effect = AttributeError("DagFileProcessor object has no attribute stop") - mock_processor.terminate.side_effect = None - manager._processors["abc.txt"] = mock_processor + manager._processors[file] = mock_processor - manager.set_file_paths(["abc.txt"]) - assert manager._processors == {"abc.txt": mock_processor} + manager.set_file_paths([file]) + assert manager._processors == {file: mock_processor} @conf_vars({("scheduler", "file_parsing_sort_mode"): "alphabetical"}) def test_file_paths_in_queue_sorted_alphabetically(self): """Test dag files are sorted alphabetically""" - dag_files = ["file_3.py", "file_2.py", "file_4.py", "file_1.py"] + file_names = ["file_3.py", "file_2.py", "file_4.py", "file_1.py"] + dag_files = _get_dag_file_paths(file_names) + ordered_dag_files = _get_dag_file_paths(sorted(file_names)) - manager = DagFileProcessorManager(dag_directory="directory", max_runs=1) + manager = DagFileProcessorManager(max_runs=1) manager.set_file_paths(dag_files) assert manager._file_path_queue == deque() manager.prepare_file_path_queue() - assert manager._file_path_queue == deque(["file_1.py", "file_2.py", "file_3.py", "file_4.py"]) + assert manager._file_path_queue == deque(ordered_dag_files) @conf_vars({("scheduler", "file_parsing_sort_mode"): "random_seeded_by_host"}) def test_file_paths_in_queue_sorted_random_seeded_by_host(self): """Test files are randomly sorted and seeded by host name""" - dag_files = ["file_3.py", 
"file_2.py", "file_4.py", "file_1.py"] + dag_files = _get_dag_file_paths(["file_3.py", "file_2.py", "file_4.py", "file_1.py"]) - manager = DagFileProcessorManager(dag_directory="directory", max_runs=1) + manager = DagFileProcessorManager(max_runs=1) manager.set_file_paths(dag_files) assert manager._file_path_queue == deque() @@ -247,45 +252,50 @@ def test_file_paths_in_queue_sorted_random_seeded_by_host(self): def test_file_paths_in_queue_sorted_by_modified_time(self, mock_getmtime): """Test files are sorted by modified time""" paths_with_mtime = {"file_3.py": 3.0, "file_2.py": 2.0, "file_4.py": 5.0, "file_1.py": 4.0} - dag_files = list(paths_with_mtime.keys()) + dag_files = _get_dag_file_paths(paths_with_mtime.keys()) mock_getmtime.side_effect = list(paths_with_mtime.values()) - manager = DagFileProcessorManager(dag_directory="directory", max_runs=1) + manager = DagFileProcessorManager(max_runs=1) manager.set_file_paths(dag_files) assert manager._file_path_queue == deque() manager.prepare_file_path_queue() - assert manager._file_path_queue == deque(["file_4.py", "file_1.py", "file_3.py", "file_2.py"]) + ordered_files = _get_dag_file_paths(["file_4.py", "file_1.py", "file_3.py", "file_2.py"]) + assert manager._file_path_queue == deque(ordered_files) @conf_vars({("scheduler", "file_parsing_sort_mode"): "modified_time"}) @mock.patch("airflow.utils.file.os.path.getmtime") def test_file_paths_in_queue_excludes_missing_file(self, mock_getmtime): """Check that a file is not enqueued for processing if it has been deleted""" - dag_files = ["file_3.py", "file_2.py", "file_4.py"] + dag_files = _get_dag_file_paths(["file_3.py", "file_2.py", "file_4.py"]) mock_getmtime.side_effect = [1.0, 2.0, FileNotFoundError()] - manager = DagFileProcessorManager(dag_directory="directory", max_runs=1) + manager = DagFileProcessorManager(max_runs=1) manager.set_file_paths(dag_files) manager.prepare_file_path_queue() - assert manager._file_path_queue == deque(["file_2.py", "file_3.py"]) + + ordered_files = _get_dag_file_paths(["file_2.py", "file_3.py"]) + assert manager._file_path_queue == deque(ordered_files) @conf_vars({("scheduler", "file_parsing_sort_mode"): "modified_time"}) @mock.patch("airflow.utils.file.os.path.getmtime") def test_add_new_file_to_parsing_queue(self, mock_getmtime): """Check that new file is added to parsing queue""" - dag_files = ["file_1.py", "file_2.py", "file_3.py"] + dag_files = _get_dag_file_paths(["file_1.py", "file_2.py", "file_3.py"]) mock_getmtime.side_effect = [1.0, 2.0, 3.0] - manager = DagFileProcessorManager(dag_directory="directory", max_runs=1) + manager = DagFileProcessorManager(max_runs=1) manager.set_file_paths(dag_files) manager.prepare_file_path_queue() - assert manager._file_path_queue == deque(["file_3.py", "file_2.py", "file_1.py"]) + ordered_files = _get_dag_file_paths(["file_3.py", "file_2.py", "file_1.py"]) + assert manager._file_path_queue == deque(ordered_files) - manager.set_file_paths([*dag_files, "file_4.py"]) + manager.set_file_paths([*dag_files, DagFileInfo(bundle_name="testing", path="file_4.py")]) manager.add_new_file_path_to_queue() - assert manager._file_path_queue == deque(["file_4.py", "file_3.py", "file_2.py", "file_1.py"]) + ordered_files = _get_dag_file_paths(["file_4.py", "file_3.py", "file_2.py", "file_1.py"]) + assert manager._file_path_queue == deque(ordered_files) @conf_vars({("scheduler", "file_parsing_sort_mode"): "modified_time"}) @mock.patch("airflow.utils.file.os.path.getmtime") @@ -295,15 +305,16 @@ def 
test_recently_modified_file_is_parsed_with_mtime_mode(self, mock_getmtime): """ freezed_base_time = timezone.datetime(2020, 1, 5, 0, 0, 0) initial_file_1_mtime = (freezed_base_time - timedelta(minutes=5)).timestamp() - dag_files = ["file_1.py"] + dag_file = DagFileInfo(bundle_name="testing", path="file_1.py") + dag_files = [dag_file] mock_getmtime.side_effect = [initial_file_1_mtime] - manager = DagFileProcessorManager(dag_directory="directory", max_runs=3) + manager = DagFileProcessorManager(max_runs=3) # let's say the DAG was just parsed 10 seconds before the Freezed time last_finish_time = freezed_base_time - timedelta(seconds=10) manager._file_stats = { - "file_1.py": DagFileStat(1, 0, last_finish_time, 1.0, 1, 1), + dag_file: DagFileStat(1, 0, last_finish_time, 1.0, 1, 1), } with time_machine.travel(freezed_base_time): manager.set_file_paths(dag_files) @@ -323,13 +334,14 @@ def test_recently_modified_file_is_parsed_with_mtime_mode(self, mock_getmtime): mock_getmtime.side_effect = [file_1_new_mtime_ts] manager.prepare_file_path_queue() # Check that file is added to the queue even though file was just recently passed - assert manager._file_path_queue == deque(["file_1.py"]) + assert manager._file_path_queue == deque(dag_files) assert last_finish_time < file_1_new_mtime assert ( manager._file_process_interval - > (freezed_base_time - manager._file_stats["file_1.py"].last_finish_time).total_seconds() + > (freezed_base_time - manager._file_stats[dag_file].last_finish_time).total_seconds() ) + @pytest.mark.skip("AIP-66: parsing requests are not bundle aware yet") def test_file_paths_in_queue_sorted_by_priority(self): from airflow.models.dagbag import DagPriorityParsingRequest @@ -351,25 +363,27 @@ def test_file_paths_in_queue_sorted_by_priority(self): parsing_request_after = session2.query(DagPriorityParsingRequest).get(parsing_request.id) assert parsing_request_after is None - def test_scan_stale_dags(self): + def test_scan_stale_dags(self, testing_dag_bundle): """ Ensure that DAGs are marked inactive when the file is parsed but the DagModel.last_parsed_time is not updated. 
""" manager = DagFileProcessorManager( - dag_directory="directory", max_runs=1, processor_timeout=10 * 60, ) - test_dag_path = str(TEST_DAG_FOLDER / "test_example_bash_operator.py") - dagbag = DagBag(test_dag_path, read_dags_from_db=False, include_examples=False) + test_dag_path = DagFileInfo( + bundle_name="testing", + path=str(TEST_DAG_FOLDER / "test_example_bash_operator.py"), + ) + dagbag = DagBag(test_dag_path.path, read_dags_from_db=False, include_examples=False) with create_session() as session: # Add stale DAG to the DB dag = dagbag.get_dag("test_example_bash_operator") dag.last_parsed_time = timezone.utcnow() - dag.sync_to_db() + DAG.bulk_write_to_db("testing", None, [dag]) SerializedDagModel.write_dag(dag) # Add DAG to the file_parsing_stats @@ -386,7 +400,7 @@ def test_scan_stale_dags(self): active_dag_count = ( session.query(func.count(DagModel.dag_id)) - .filter(DagModel.is_active, DagModel.fileloc == test_dag_path) + .filter(DagModel.is_active, DagModel.fileloc == test_dag_path.path) .scalar() ) assert active_dag_count == 1 @@ -395,7 +409,7 @@ def test_scan_stale_dags(self): active_dag_count = ( session.query(func.count(DagModel.dag_id)) - .filter(DagModel.is_active, DagModel.fileloc == test_dag_path) + .filter(DagModel.is_active, DagModel.fileloc == test_dag_path.path) .scalar() ) assert active_dag_count == 0 @@ -410,11 +424,11 @@ def test_scan_stale_dags(self): assert serialized_dag_count == 1 def test_kill_timed_out_processors_kill(self): - manager = DagFileProcessorManager(dag_directory="directory", max_runs=1, processor_timeout=5) + manager = DagFileProcessorManager(max_runs=1, processor_timeout=5) processor = self.mock_processor() processor._process.create_time.return_value = timezone.make_aware(datetime.min).timestamp() - manager._processors = {"abc.txt": processor} + manager._processors = {DagFileInfo(bundle_name="testing", path="abc.txt"): processor} with mock.patch.object(type(processor), "kill") as mock_kill: manager._kill_timed_out_processors() mock_kill.assert_called_once_with(signal.SIGKILL) @@ -422,14 +436,13 @@ def test_kill_timed_out_processors_kill(self): def test_kill_timed_out_processors_no_kill(self): manager = DagFileProcessorManager( - dag_directory=TEST_DAG_FOLDER, max_runs=1, processor_timeout=5, ) processor = self.mock_processor() processor._process.create_time.return_value = timezone.make_aware(datetime.max).timestamp() - manager._processors = {"abc.txt": processor} + manager._processors = {DagFileInfo(bundle_name="testing", path="abc.txt"): processor} with mock.patch.object(type(processor), "kill") as mock_kill: manager._kill_timed_out_processors() mock_kill.assert_not_called() @@ -472,7 +485,7 @@ def test_serialize_callback_requests(self, callbacks, path, child_comms_fd, expe @conf_vars({("core", "load_examples"): "False"}) @pytest.mark.execution_timeout(10) - def test_dag_with_system_exit(self): + def test_dag_with_system_exit(self, configure_testing_dag_bundle): """ Test to check that a DAG with a system.exit() doesn't break the scheduler. 
""" @@ -484,9 +497,9 @@ def test_dag_with_system_exit(self): clear_db_dags() clear_db_serialized_dags() - manager = DagFileProcessorManager(dag_directory=dag_directory, max_runs=1) - - manager._run_parsing_loop() + with configure_testing_dag_bundle(dag_directory): + manager = DagFileProcessorManager(max_runs=1) + manager.run() # Three files in folder should be processed assert sum(stat.run_count for stat in manager._file_stats.values()) == 3 @@ -496,7 +509,7 @@ def test_dag_with_system_exit(self): @conf_vars({("core", "load_examples"): "False"}) @pytest.mark.execution_timeout(30) - def test_pipe_full_deadlock(self): + def test_pipe_full_deadlock(self, configure_testing_dag_bundle): dag_filepath = TEST_DAG_FOLDER / "test_scheduler_dags.py" child_pipe, parent_pipe = multiprocessing.Pipe() @@ -533,38 +546,43 @@ def keep_pipe_full(pipe, exit_event): thread = threading.Thread(target=keep_pipe_full, args=(parent_pipe, exit_event)) - manager = DagFileProcessorManager( - dag_directory=dag_filepath, - # A reasonable large number to ensure that we trigger the deadlock - max_runs=100, - processor_timeout=5, - signal_conn=child_pipe, - # Make it loop sub-processes quickly. Need to be non-zero to exercise the bug, else it finishes - # too quickly - file_process_interval=0.01, - ) - - try: - thread.start() + with configure_testing_dag_bundle(dag_filepath): + manager = DagFileProcessorManager( + # A reasonable large number to ensure that we trigger the deadlock + max_runs=100, + processor_timeout=5, + signal_conn=child_pipe, + # Make it loop sub-processes quickly. Need to be non-zero to exercise the bug, else it finishes + # too quickly + file_process_interval=0.01, + ) - # If this completes without hanging, then the test is good! - with mock.patch.object( - DagFileProcessorProcess, "start", side_effect=lambda *args, **kwargs: self.mock_processor() - ): - manager.run() - exit_event.set() - finally: - logger.info("Closing pipes") - parent_pipe.close() - child_pipe.close() - logger.info("Closed pipes") - logger.info("Joining thread") - thread.join(timeout=1.0) - logger.info("Joined thread") + try: + thread.start() + + # If this completes without hanging, then the test is good! 
+ with mock.patch.object( + DagFileProcessorProcess, + "start", + side_effect=lambda *args, **kwargs: self.mock_processor(), + ): + manager.run() + exit_event.set() + finally: + logger.info("Closing pipes") + parent_pipe.close() + child_pipe.close() + logger.info("Closed pipes") + logger.info("Joining thread") + thread.join(timeout=1.0) + logger.info("Joined thread") @conf_vars({("core", "load_examples"): "False"}) @mock.patch("airflow.dag_processing.manager.Stats.timing") - def test_send_file_processing_statsd_timing(self, statsd_timing_mock, tmp_path): + @pytest.mark.skip("AIP-66: stats are not implemented yet") + def test_send_file_processing_statsd_timing( + self, statsd_timing_mock, tmp_path, configure_testing_dag_bundle + ): path_to_parse = tmp_path / "temp_dag.py" dag_code = textwrap.dedent( """ @@ -574,11 +592,11 @@ def test_send_file_processing_statsd_timing(self, statsd_timing_mock, tmp_path): ) path_to_parse.write_text(dag_code) - manager = DagFileProcessorManager(dag_directory=path_to_parse.parent, max_runs=1) + with configure_testing_dag_bundle(tmp_path): + manager = DagFileProcessorManager(max_runs=1) + manager.run() - self.run_processor_manager_one_loop(manager) last_runtime = manager._file_stats[os.fspath(path_to_parse)].last_duration - statsd_timing_mock.assert_has_calls( [ mock.call("dag_processing.last_duration.temp_dag", last_runtime), @@ -587,17 +605,19 @@ def test_send_file_processing_statsd_timing(self, statsd_timing_mock, tmp_path): any_order=True, ) - def test_refresh_dags_dir_doesnt_delete_zipped_dags(self, tmp_path): + def test_refresh_dags_dir_doesnt_delete_zipped_dags(self, tmp_path, configure_testing_dag_bundle): """Test DagFileProcessorManager._refresh_dag_dir method""" - manager = DagFileProcessorManager(dag_directory=TEST_DAG_FOLDER, max_runs=1) dagbag = DagBag(dag_folder=tmp_path, include_examples=False) zipped_dag_path = os.path.join(TEST_DAGS_FOLDER, "test_zip.zip") dagbag.process_file(zipped_dag_path) dag = dagbag.get_dag("test_zip_dag") - dag.sync_to_db() + DAG.bulk_write_to_db("testing", None, [dag]) SerializedDagModel.write_dag(dag) - manager.last_dag_dir_refresh_time = time.monotonic() - 10 * 60 - manager._refresh_dag_dir() + + with configure_testing_dag_bundle(zipped_dag_path): + manager = DagFileProcessorManager(max_runs=1) + manager.run() + + # Assert dag not deleted in SDM assert SerializedDagModel.has_dag("test_zip_dag") # assert code not deleted @@ -605,20 +625,23 @@ def test_refresh_dags_dir_doesnt_delete_zipped_dags(self, tmp_path): # assert dag still active assert dag.get_is_active() - def test_refresh_dags_dir_deactivates_deleted_zipped_dags(self, tmp_path): + def test_refresh_dags_dir_deactivates_deleted_zipped_dags(self, tmp_path, configure_testing_dag_bundle): """Test DagFileProcessorManager._refresh_dag_dir method""" - manager = DagFileProcessorManager(dag_directory=TEST_DAG_FOLDER, max_runs=1) dagbag = DagBag(dag_folder=tmp_path, include_examples=False) zipped_dag_path = os.path.join(TEST_DAGS_FOLDER, "test_zip.zip") dagbag.process_file(zipped_dag_path) dag = dagbag.get_dag("test_zip_dag") dag.sync_to_db() SerializedDagModel.write_dag(dag) - manager.last_dag_dir_refresh_time = time.monotonic() - 10 * 60 + + # TODO: this test feels a bit fragile - pointing at the zip directly causes the test to fail + # TODO: jed look at this more closely - dagbag then process_file?!
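The zip-handling tests above also show the bundle-aware write path: DAG.bulk_write_to_db and update_dag_parsing_results_in_db now take the bundle name and an optional bundle version ahead of the DAGs themselves, matching the mock.call("testing", None, ...) assertions in test_sync_to_db_is_retried. A minimal sketch of the new calling convention, assuming a live metadata DB session from create_session (the dag id and version string are placeholders):

    from airflow.dag_processing.collection import update_dag_parsing_results_in_db
    from airflow.models import DAG
    from airflow.utils.session import create_session

    dag = DAG("mydag", schedule=None)

    with create_session() as session:
        # Bundle name and version now lead the positional arguments.
        DAG.bulk_write_to_db("testing", None, [dag], session=session)

        # The parsing-results helper follows the same convention (it writes the
        # serialized DAG and import errors as well).
        update_dag_parsing_results_in_db(
            bundle_name="testing",
            bundle_version="1.0",
            dags=[dag],
            import_errors={},
            warnings=set(),
            session=session,
        )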
# Mock might_contain_dag to mimic deleting the python file from the zip with mock.patch("airflow.dag_processing.manager.might_contain_dag", return_value=False): - manager._refresh_dag_dir() + with configure_testing_dag_bundle(TEST_DAGS_FOLDER): + manager = DagFileProcessorManager(max_runs=1) + manager.run() # Deleting the python file should not delete SDM for versioning sake assert SerializedDagModel.has_dag("test_zip_dag") @@ -635,7 +658,7 @@ def test_refresh_dags_dir_deactivates_deleted_zipped_dags(self, tmp_path): ("scheduler", "standalone_dag_processor"): "True", } ) - def test_fetch_callbacks_from_database(self, tmp_path): + def test_fetch_callbacks_from_database(self, tmp_path, configure_testing_dag_bundle): dag_filepath = TEST_DAG_FOLDER / "test_on_failure_callback_dag.py" callback1 = DagCallbackRequest( @@ -655,13 +678,12 @@ def test_fetch_callbacks_from_database(self, tmp_path): session.add(DbCallbackRequest(callback=callback1, priority_weight=11)) session.add(DbCallbackRequest(callback=callback2, priority_weight=10)) - manager = DagFileProcessorManager( - dag_directory=os.fspath(tmp_path), max_runs=1, standalone_dag_processor=True - ) + with configure_testing_dag_bundle(tmp_path): + manager = DagFileProcessorManager(max_runs=1, standalone_dag_processor=True) - with create_session() as session: - self.run_processor_manager_one_loop(manager) - assert session.query(DbCallbackRequest).count() == 0 + with create_session() as session: + manager.run() + assert session.query(DbCallbackRequest).count() == 0 @conf_vars( { @@ -670,7 +692,7 @@ def test_fetch_callbacks_from_database(self, tmp_path): ("core", "load_examples"): "False", } ) - def test_fetch_callbacks_from_database_max_per_loop(self, tmp_path): + def test_fetch_callbacks_from_database_max_per_loop(self, tmp_path, configure_testing_dag_bundle): """Test DagFileProcessorManager._fetch_callbacks method""" dag_filepath = TEST_DAG_FOLDER / "test_on_failure_callback_dag.py" @@ -684,15 +706,16 @@ def test_fetch_callbacks_from_database_max_per_loop(self, tmp_path): ) session.add(DbCallbackRequest(callback=callback, priority_weight=i)) - manager = DagFileProcessorManager(dag_directory=tmp_path, max_runs=1) + with configure_testing_dag_bundle(tmp_path): + manager = DagFileProcessorManager(max_runs=1) - with create_session() as session: - self.run_processor_manager_one_loop(manager) - assert session.query(DbCallbackRequest).count() == 3 + with create_session() as session: + manager.run() + assert session.query(DbCallbackRequest).count() == 3 - with create_session() as session: - self.run_processor_manager_one_loop(manager) - assert session.query(DbCallbackRequest).count() == 1 + with create_session() as session: + manager.run() + assert session.query(DbCallbackRequest).count() == 1 @conf_vars( { @@ -700,7 +723,7 @@ def test_fetch_callbacks_from_database_max_per_loop(self, tmp_path): ("core", "load_examples"): "False", } ) - def test_fetch_callbacks_from_database_not_standalone(self, tmp_path): + def test_fetch_callbacks_from_database_not_standalone(self, tmp_path, configure_testing_dag_bundle): dag_filepath = TEST_DAG_FOLDER / "test_on_failure_callback_dag.py" with create_session() as session: @@ -712,22 +735,23 @@ def test_fetch_callbacks_from_database_not_standalone(self, tmp_path): ) session.add(DbCallbackRequest(callback=callback, priority_weight=10)) - manager = DagFileProcessorManager(dag_directory=tmp_path, max_runs=1) - - self.run_processor_manager_one_loop(manager) + with configure_testing_dag_bundle(tmp_path): + manager = 
DagFileProcessorManager(max_runs=1) + manager.run() # Verify no callbacks removed from database. with create_session() as session: assert session.query(DbCallbackRequest).count() == 1 + @pytest.mark.skip("AIP-66: callbacks are not implemented yet") def test_callback_queue(self, tmp_path): # given manager = DagFileProcessorManager( - dag_directory=TEST_DAG_FOLDER, max_runs=1, processor_timeout=365 * 86_400, ) + dag1_path = DagFileInfo(bundle_name="testing", path="/green_eggs/ham/file1.py") dag1_req1 = DagCallbackRequest( full_filepath="/green_eggs/ham/file1.py", dag_id="dag1", @@ -743,6 +767,7 @@ def test_callback_queue(self, tmp_path): msg=None, ) + dag2_path = DagFileInfo(bundle_name="testing", path="/green_eggs/ham/file2.py") dag2_req1 = DagCallbackRequest( full_filepath="/green_eggs/ham/file2.py", dag_id="dag2", @@ -756,7 +781,7 @@ def test_callback_queue(self, tmp_path): manager._add_callback_to_queue(dag2_req1) # then - requests should be in manager's queue, with dag2 ahead of dag1 (because it was added last) - assert manager._file_path_queue == deque([dag2_req1.full_filepath, dag1_req1.full_filepath]) + assert manager._file_path_queue == deque([dag2_path, dag1_path]) assert set(manager._callback_to_execute.keys()) == { dag1_req1.full_filepath, dag2_req1.full_filepath, @@ -787,24 +812,106 @@ def test_callback_queue(self, tmp_path): # And removed from the queue assert dag1_req1.full_filepath not in manager._callback_to_execute - def test_dag_with_assets(self, session): + def test_dag_with_assets(self, session, configure_testing_dag_bundle): """'Integration' test to ensure that the assets get parsed and stored correctly for parsed dags.""" test_dag_path = str(TEST_DAG_FOLDER / "test_assets.py") - manager = DagFileProcessorManager( - dag_directory=test_dag_path, - max_runs=1, - processor_timeout=365 * 86_400, - ) - - self.run_processor_manager_one_loop(manager) + with configure_testing_dag_bundle(test_dag_path): + manager = DagFileProcessorManager( + max_runs=1, + processor_timeout=365 * 86_400, + ) + manager.run() dag_model = session.get(DagModel, ("dag_with_skip_task")) assert dag_model.task_outlet_asset_references == [ TaskOutletAssetReference(asset_id=mock.ANY, dag_id="dag_with_skip_task", task_id="skip_task") ] + def test_bundles_are_refreshed(self): + """ + Ensure bundles are refreshed by the manager, when necessary. 
+ + - refresh if the bundle hasn't been refreshed in the refresh_interval + - when the latest_version in the db doesn't match the version this parser knows about + """ + config = [ + { + "name": "bundleone", + "classpath": "airflow.dag_processing.bundles.local.LocalDagBundle", + "kwargs": {"local_folder": "/dev/null", "refresh_interval": 0}, + }, + { + "name": "bundletwo", + "classpath": "airflow.dag_processing.bundles.local.LocalDagBundle", + "kwargs": {"local_folder": "/dev/null", "refresh_interval": 300}, + }, + ] + + bundleone = MagicMock() + bundleone.name = "bundleone" + bundleone.refresh_interval = 0 + bundleone.get_current_version.return_value = None + bundletwo = MagicMock() + bundletwo.name = "bundletwo" + bundletwo.refresh_interval = 300 + bundletwo.get_current_version.return_value = None + + with conf_vars({("dag_bundles", "backends"): json.dumps(config)}): + DagBundlesManager().sync_bundles_to_db() + with mock.patch( + "airflow.dag_processing.bundles.manager.DagBundlesManager" + ) as mock_bundle_manager: + mock_bundle_manager.return_value._bundle_config = {"bundleone": None, "bundletwo": None} + mock_bundle_manager.return_value.get_all_dag_bundles.return_value = [bundleone, bundletwo] + manager = DagFileProcessorManager(max_runs=1) + manager.run() + bundleone.refresh.assert_called_once() + bundletwo.refresh.assert_called_once() + + # Now, we should only refresh bundleone, as haven't hit the refresh_interval for bundletwo + bundleone.reset_mock() + bundletwo.reset_mock() + manager.run() + bundleone.refresh.assert_called_once() + bundletwo.refresh.assert_not_called() + + # however, if the version doesn't match, we should still refresh + bundletwo.reset_mock() + bundletwo.get_current_version.return_value = "123" + manager.run() + bundletwo.refresh.assert_called_once() + + def test_bundles_versions_are_stored(self): + config = [ + { + "name": "mybundle", + "classpath": "airflow.dag_processing.bundles.local.LocalDagBundle", + "kwargs": {"local_folder": "/dev/null", "refresh_interval": 0}, + }, + ] + + mybundle = MagicMock() + mybundle.name = "bundleone" + mybundle.refresh_interval = 0 + mybundle.supports_versioning = True + mybundle.get_current_version.return_value = "123" + + with conf_vars({("dag_bundles", "backends"): json.dumps(config)}): + DagBundlesManager().sync_bundles_to_db() + with mock.patch( + "airflow.dag_processing.bundles.manager.DagBundlesManager" + ) as mock_bundle_manager: + mock_bundle_manager.return_value._bundle_config = {"bundleone": None} + mock_bundle_manager.return_value.get_all_dag_bundles.return_value = [mybundle] + manager = DagFileProcessorManager(max_runs=1) + manager.run() + + with create_session() as session: + model = session.get(DagBundleModel, "bundleone") + assert model.latest_version == "123" + class TestDagFileProcessorAgent: @pytest.fixture(autouse=True) @@ -815,14 +922,12 @@ def _disable_examples(self): def test_launch_process(self): from airflow.configuration import conf - test_dag_path = TEST_DAG_FOLDER / "test_scheduler_dags.py" - log_file_loc = conf.get("logging", "DAG_PROCESSOR_MANAGER_LOG_LOCATION") with suppress(OSError): os.remove(log_file_loc) # Starting dag processing with 0 max_runs to avoid redundant operations. 
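The bundle definitions exercised by test_bundles_are_refreshed and test_bundles_versions_are_stored above are plain JSON under the dag_bundles/backends option; each entry names the bundle, the classpath of its implementation, and the kwargs it is constructed with. A minimal sketch of overriding that option in a test, mirroring the config literals used above (the local folder and refresh interval are placeholders):

    import json

    from airflow.dag_processing.bundles.manager import DagBundlesManager
    from tests_common.test_utils.config import conf_vars

    config = [
        {
            "name": "bundleone",
            "classpath": "airflow.dag_processing.bundles.local.LocalDagBundle",
            "kwargs": {"local_folder": "/dev/null", "refresh_interval": 0},
        },
    ]

    with conf_vars({("dag_bundles", "backends"): json.dumps(config)}):
        # Register the configured bundles in the metadata DB so their versions can be tracked.
        DagBundlesManager().sync_bundles_to_db()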
- processor_agent = DagFileProcessorAgent(test_dag_path, 0, timedelta(days=365)) + processor_agent = DagFileProcessorAgent(0, timedelta(days=365)) processor_agent.start() processor_agent._process.join() @@ -830,25 +935,25 @@ def test_launch_process(self): assert os.path.isfile(log_file_loc) def test_get_callbacks_pipe(self): - processor_agent = DagFileProcessorAgent("", 1, timedelta(days=365)) + processor_agent = DagFileProcessorAgent(1, timedelta(days=365)) processor_agent._parent_signal_conn = Mock() retval = processor_agent.get_callbacks_pipe() assert retval == processor_agent._parent_signal_conn def test_get_callbacks_pipe_no_parent_signal_conn(self): - processor_agent = DagFileProcessorAgent("", 1, timedelta(days=365)) + processor_agent = DagFileProcessorAgent(1, timedelta(days=365)) processor_agent._parent_signal_conn = None with pytest.raises(ValueError, match="Process not started"): processor_agent.get_callbacks_pipe() def test_heartbeat_no_parent_signal_conn(self): - processor_agent = DagFileProcessorAgent("", 1, timedelta(days=365)) + processor_agent = DagFileProcessorAgent(1, timedelta(days=365)) processor_agent._parent_signal_conn = None with pytest.raises(ValueError, match="Process not started"): processor_agent.heartbeat() def test_heartbeat_poll_eof_error(self): - processor_agent = DagFileProcessorAgent("", 1, timedelta(days=365)) + processor_agent = DagFileProcessorAgent(1, timedelta(days=365)) processor_agent._parent_signal_conn = Mock() processor_agent._parent_signal_conn.poll.return_value = True processor_agent._parent_signal_conn.recv = Mock() @@ -857,7 +962,7 @@ def test_heartbeat_poll_eof_error(self): assert ret_val is None def test_heartbeat_poll_connection_error(self): - processor_agent = DagFileProcessorAgent("", 1, timedelta(days=365)) + processor_agent = DagFileProcessorAgent(1, timedelta(days=365)) processor_agent._parent_signal_conn = Mock() processor_agent._parent_signal_conn.poll.return_value = True processor_agent._parent_signal_conn.recv = Mock() @@ -866,7 +971,7 @@ def test_heartbeat_poll_connection_error(self): assert ret_val is None def test_heartbeat_poll_process_message(self): - processor_agent = DagFileProcessorAgent("", 1, timedelta(days=365)) + processor_agent = DagFileProcessorAgent(1, timedelta(days=365)) processor_agent._parent_signal_conn = Mock() processor_agent._parent_signal_conn.poll.side_effect = [True, False] processor_agent._parent_signal_conn.recv = Mock() @@ -877,13 +982,13 @@ def test_heartbeat_poll_process_message(self): def test_process_message_invalid_type(self): message = "xyz" - processor_agent = DagFileProcessorAgent("", 1, timedelta(days=365)) + processor_agent = DagFileProcessorAgent(1, timedelta(days=365)) with pytest.raises(RuntimeError, match="Unexpected message received of type str"): processor_agent._process_message(message) @mock.patch("airflow.utils.process_utils.reap_process_group") def test_heartbeat_manager_process_restart(self, mock_pg, monkeypatch): - processor_agent = DagFileProcessorAgent("", 1, timedelta(days=365)) + processor_agent = DagFileProcessorAgent(1, timedelta(days=365)) processor_agent._parent_signal_conn = Mock() processor_agent._process = MagicMock() monkeypatch.setattr(processor_agent._process, "pid", 1234) @@ -899,7 +1004,7 @@ def test_heartbeat_manager_process_restart(self, mock_pg, monkeypatch): @mock.patch("time.monotonic") @mock.patch("airflow.dag_processing.manager.reap_process_group") def test_heartbeat_manager_process_reap(self, mock_pg, mock_time_monotonic, mock_stats): - processor_agent = 
DagFileProcessorAgent("", 1, timedelta(days=365)) + processor_agent = DagFileProcessorAgent(1, timedelta(days=365)) processor_agent._parent_signal_conn = Mock() processor_agent._process = Mock() processor_agent._process.pid = 12345 @@ -920,7 +1025,7 @@ def test_heartbeat_manager_process_reap(self, mock_pg, mock_time_monotonic, mock processor_agent.start.assert_called() def test_heartbeat_manager_end_no_process(self): - processor_agent = DagFileProcessorAgent("", 1, timedelta(days=365)) + processor_agent = DagFileProcessorAgent(1, timedelta(days=365)) processor_agent._process = Mock() processor_agent._process.__bool__ = Mock(return_value=False) processor_agent._process.side_effect = [None] @@ -931,26 +1036,25 @@ def test_heartbeat_manager_end_no_process(self): processor_agent._process.join.assert_not_called() @pytest.mark.execution_timeout(5) - def test_terminate(self, tmp_path): - processor_agent = DagFileProcessorAgent(tmp_path, -1, timedelta(days=365)) + def test_terminate(self, tmp_path, configure_testing_dag_bundle): + with configure_testing_dag_bundle(tmp_path): + processor_agent = DagFileProcessorAgent(-1, timedelta(days=365)) - processor_agent.start() - try: - processor_agent.terminate() + processor_agent.start() + try: + processor_agent.terminate() - processor_agent._process.join(timeout=1) - assert processor_agent._process.is_alive() is False - assert processor_agent._process.exitcode == 0 - except Exception: - reap_process_group(processor_agent._process.pid, logger=logger) - raise + processor_agent._process.join(timeout=1) + assert processor_agent._process.is_alive() is False + assert processor_agent._process.exitcode == 0 + except Exception: + reap_process_group(processor_agent._process.pid, logger=logger) + raise @conf_vars({("logging", "dag_processor_manager_log_stdout"): "True"}) def test_log_to_stdout(self, capfd): - test_dag_path = TEST_DAG_FOLDER / "test_scheduler_dags.py" - # Starting dag processing with 0 max_runs to avoid redundant operations. - processor_agent = DagFileProcessorAgent(test_dag_path, 0, timedelta(days=365)) + processor_agent = DagFileProcessorAgent(0, timedelta(days=365)) processor_agent.start() processor_agent._process.join() @@ -961,10 +1065,8 @@ def test_log_to_stdout(self, capfd): @conf_vars({("logging", "dag_processor_manager_log_stdout"): "False"}) def test_not_log_to_stdout(self, capfd): - test_dag_path = TEST_DAG_FOLDER / "test_scheduler_dags.py" - # Starting dag processing with 0 max_runs to avoid redundant operations. 
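As the surrounding hunks show, DagFileProcessorAgent likewise loses its path argument: it is now constructed from just the number of runs and the processor timeout, with DAG locations coming from the configured bundles. A minimal sketch of the new constructor call, mirroring the agent tests above:

    from datetime import timedelta

    from airflow.dag_processing.manager import DagFileProcessorAgent

    # 0 max_runs keeps the run cheap, as in the tests above.
    processor_agent = DagFileProcessorAgent(0, timedelta(days=365))
    processor_agent.start()
    processor_agent._process.join()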
- processor_agent = DagFileProcessorAgent(test_dag_path, 0, timedelta(days=365)) + processor_agent = DagFileProcessorAgent(0, timedelta(days=365)) processor_agent.start() processor_agent._process.join() diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py index b9db6f2751d5f..497ae4fb1d441 100644 --- a/tests/jobs/test_scheduler_job.py +++ b/tests/jobs/test_scheduler_job.py @@ -107,6 +107,7 @@ ELASTIC_DAG_FILE = os.path.join(PERF_DAGS_FOLDER, "elastic_dag.py") TEST_DAG_FOLDER = os.environ["AIRFLOW__CORE__DAGS_FOLDER"] +EXAMPLE_DAGS_FOLDER = airflow.example_dags.__path__[0] DEFAULT_DATE = timezone.datetime(2016, 1, 1) DEFAULT_LOGICAL_DATE = timezone.coerce_datetime(DEFAULT_DATE) TRY_NUMBER = 1 @@ -119,12 +120,6 @@ def disable_load_example(): yield -@pytest.fixture -def load_examples(): - with conf_vars({("core", "load_examples"): "True"}): - yield - - # Patch the MockExecutor into the dict of known executors in the Loader @contextlib.contextmanager def _loader_mock(mock_executors): @@ -579,26 +574,32 @@ def test_execute_task_instances_backfill_tasks_will_execute(self, dag_maker): session.rollback() @conf_vars({("scheduler", "standalone_dag_processor"): "False"}) - def test_setup_callback_sink_not_standalone_dag_processor(self, mock_executors): - scheduler_job = Job() - self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull, num_runs=1) - self.job_runner._execute() + def test_setup_callback_sink_not_standalone_dag_processor( + self, mock_executors, configure_testing_dag_bundle + ): + with configure_testing_dag_bundle(os.devnull): + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job, num_runs=1) + self.job_runner._execute() - assert isinstance(scheduler_job.executor.callback_sink, PipeCallbackSink) + assert isinstance(scheduler_job.executor.callback_sink, PipeCallbackSink) @conf_vars({("scheduler", "standalone_dag_processor"): "False"}) - def test_setup_callback_sink_not_standalone_dag_processor_multiple_executors(self, mock_executors): - scheduler_job = Job() - self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull, num_runs=1) - self.job_runner._execute() + def test_setup_callback_sink_not_standalone_dag_processor_multiple_executors( + self, mock_executors, configure_testing_dag_bundle + ): + with configure_testing_dag_bundle(os.devnull): + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job, num_runs=1) + self.job_runner._execute() - for executor in scheduler_job.executors: - assert isinstance(executor.callback_sink, PipeCallbackSink) + for executor in scheduler_job.executors: + assert isinstance(executor.callback_sink, PipeCallbackSink) @conf_vars({("scheduler", "standalone_dag_processor"): "True"}) def test_setup_callback_sink_standalone_dag_processor(self, mock_executors): scheduler_job = Job() - self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull, num_runs=1) + self.job_runner = SchedulerJobRunner(job=scheduler_job, num_runs=1) self.job_runner._execute() assert isinstance(scheduler_job.executor.callback_sink, DatabaseCallbackSink) @@ -606,7 +607,7 @@ def test_setup_callback_sink_standalone_dag_processor(self, mock_executors): @conf_vars({("scheduler", "standalone_dag_processor"): "True"}) def test_setup_callback_sink_standalone_dag_processor_multiple_executors(self, mock_executors): scheduler_job = Job() - self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull, num_runs=1) + self.job_runner = 
SchedulerJobRunner(job=scheduler_job, num_runs=1) self.job_runner._execute() for executor in scheduler_job.executors: @@ -615,37 +616,40 @@ def test_setup_callback_sink_standalone_dag_processor_multiple_executors(self, m @conf_vars({("scheduler", "standalone_dag_processor"): "True"}) def test_executor_start_called(self, mock_executors): scheduler_job = Job() - self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull, num_runs=1) + self.job_runner = SchedulerJobRunner(job=scheduler_job, num_runs=1) self.job_runner._execute() scheduler_job.executor.start.assert_called_once() for executor in scheduler_job.executors: executor.start.assert_called_once() - def test_executor_job_id_assigned(self, mock_executors): - scheduler_job = Job() - self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull, num_runs=1) - self.job_runner._execute() + def test_executor_job_id_assigned(self, mock_executors, configure_testing_dag_bundle): + with configure_testing_dag_bundle(os.devnull): + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job, num_runs=1) + self.job_runner._execute() - assert scheduler_job.executor.job_id == scheduler_job.id - for executor in scheduler_job.executors: - assert executor.job_id == scheduler_job.id + assert scheduler_job.executor.job_id == scheduler_job.id + for executor in scheduler_job.executors: + assert executor.job_id == scheduler_job.id - def test_executor_heartbeat(self, mock_executors): - scheduler_job = Job() - self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull, num_runs=1) - self.job_runner._execute() + def test_executor_heartbeat(self, mock_executors, configure_testing_dag_bundle): + with configure_testing_dag_bundle(os.devnull): + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job, num_runs=1) + self.job_runner._execute() - for executor in scheduler_job.executors: - executor.heartbeat.assert_called_once() + for executor in scheduler_job.executors: + executor.heartbeat.assert_called_once() - def test_executor_events_processed(self, mock_executors): - scheduler_job = Job() - self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull, num_runs=1) - self.job_runner._execute() + def test_executor_events_processed(self, mock_executors, configure_testing_dag_bundle): + with configure_testing_dag_bundle(os.devnull): + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job, num_runs=1) + self.job_runner._execute() - for executor in scheduler_job.executors: - executor.get_event_buffer.assert_called_once() + for executor in scheduler_job.executors: + executor.get_event_buffer.assert_called_once() def test_executor_debug_dump(self, mock_executors): scheduler_job = Job() @@ -2891,17 +2895,17 @@ def my_task(): ... self.job_runner._do_scheduling(session) assert session.query(DagRun).one().state == run_state - def test_dagrun_root_after_dagrun_unfinished(self, mock_executor): + def test_dagrun_root_after_dagrun_unfinished(self, mock_executor, testing_dag_bundle): """ DagRuns with one successful and one future root task -> SUCCESS Noted: the DagRun state could be still in running state during CI. 
""" dagbag = DagBag(TEST_DAG_FOLDER, include_examples=False) - dagbag.sync_to_db() + dagbag.sync_to_db("testing", None) dag_id = "test_dagrun_states_root_future" dag = dagbag.get_dag(dag_id) - dag.sync_to_db() + DAG.bulk_write_to_db("testing", None, [dag]) scheduler_job = Job() self.job_runner = SchedulerJobRunner(job=scheduler_job, num_runs=2, subdir=dag.fileloc) @@ -2920,7 +2924,7 @@ def test_dagrun_root_after_dagrun_unfinished(self, mock_executor): {("scheduler", "standalone_dag_processor"): "True"}, ], ) - def test_scheduler_start_date(self, configs): + def test_scheduler_start_date(self, configs, testing_dag_bundle): """ Test that the scheduler respects start_dates, even when DAGs have run """ @@ -2935,7 +2939,7 @@ def test_scheduler_start_date(self, configs): # Deactivate other dags in this file other_dag = dagbag.get_dag("test_task_start_date_scheduling") other_dag.is_paused_upon_creation = True - other_dag.sync_to_db() + DAG.bulk_write_to_db("testing", None, [other_dag]) scheduler_job = Job( executor=self.null_exec, ) @@ -2981,7 +2985,7 @@ def test_scheduler_start_date(self, configs): {("scheduler", "standalone_dag_processor"): "True"}, ], ) - def test_scheduler_task_start_date(self, configs): + def test_scheduler_task_start_date(self, configs, testing_dag_bundle): """ Test that the scheduler respects task start dates that are different from DAG start dates """ @@ -3000,7 +3004,7 @@ def test_scheduler_task_start_date(self, configs): other_dag.is_paused_upon_creation = True dagbag.bag_dag(dag=other_dag) - dagbag.sync_to_db() + dagbag.sync_to_db("testing", None) scheduler_job = Job( executor=self.null_exec, @@ -3553,21 +3557,7 @@ def test_list_py_file_paths(self): if file_name.endswith((".py", ".zip")): if file_name not in ignored_files: expected_files.add(f"{root}/{file_name}") - for file_path in list_py_file_paths(TEST_DAG_FOLDER, include_examples=False): - detected_files.add(file_path) - assert detected_files == expected_files - - ignored_files = { - "helper.py", - } - example_dag_folder = airflow.example_dags.__path__[0] - for root, _, files in os.walk(example_dag_folder): - for file_name in files: - if file_name.endswith((".py", ".zip")): - if file_name not in ["__init__.py"] and file_name not in ignored_files: - expected_files.add(os.path.join(root, file_name)) - detected_files.clear() - for file_path in list_py_file_paths(TEST_DAG_FOLDER, include_examples=True): + for file_path in list_py_file_paths(TEST_DAG_FOLDER): detected_files.add(file_path) assert detected_files == expected_files @@ -4207,7 +4197,7 @@ def test_scheduler_create_dag_runs_does_not_raise_error(self, caplog, dag_maker) "DAG 'test_scheduler_create_dag_runs_does_not_raise_error' not found in serialized_dag table", ] - def test_bulk_write_to_db_external_trigger_dont_skip_scheduled_run(self, dag_maker): + def test_bulk_write_to_db_external_trigger_dont_skip_scheduled_run(self, dag_maker, testing_dag_bundle): """ Test that externally triggered Dag Runs should not affect (by skipping) next scheduled DAG runs @@ -4265,7 +4255,7 @@ def test_bulk_write_to_db_external_trigger_dont_skip_scheduled_run(self, dag_mak ) assert dr is not None # Run DAG.bulk_write_to_db -- this is run when in DagFileProcessor.process_file - DAG.bulk_write_to_db([dag], session=session) + DAG.bulk_write_to_db("testing", None, [dag], session=session) # Test that 'dag_model.next_dagrun' has not been changed because of newly created external # triggered DagRun. 
@@ -4558,7 +4548,7 @@ def test_do_schedule_max_active_runs_and_manual_trigger(self, dag_maker, mock_ex dag_version = DagVersion.get_latest_version(dag.dag_id) dag_run = dag_maker.create_dagrun(state=State.QUEUED, session=session, dag_version=dag_version) - dag.sync_to_db(session=session) # Update the date fields + DAG.bulk_write_to_db("testing", None, [dag], session=session) # Update the date fields scheduler_job = Job() self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull) @@ -5653,11 +5643,11 @@ def test_find_and_purge_zombies_nothing(self): self.job_runner._find_and_purge_zombies() executor.callback_sink.send.assert_not_called() - def test_find_and_purge_zombies(self, load_examples, session): - dagbag = DagBag(TEST_DAG_FOLDER, read_dags_from_db=False) - + def test_find_and_purge_zombies(self, session, testing_dag_bundle): + dagfile = os.path.join(EXAMPLE_DAGS_FOLDER, "example_branch_operator.py") + dagbag = DagBag(dagfile) dag = dagbag.get_dag("example_branch_operator") - dag.sync_to_db() + DAG.bulk_write_to_db("testing", None, [dag]) data_interval = dag.infer_automated_data_interval(DEFAULT_LOGICAL_DATE) triggered_by_kwargs = {"triggered_by": DagRunTriggeredByType.TEST} if AIRFLOW_V_3_0_PLUS else {} dag_run = dag.create_dagrun( @@ -5709,70 +5699,74 @@ def test_find_and_purge_zombies(self, load_examples, session): assert callback_request.ti.run_id == ti.run_id assert callback_request.ti.map_index == ti.map_index - def test_zombie_message(self, load_examples): + def test_zombie_message(self, testing_dag_bundle, session): """ Check that the zombie message comes out as expected """ dagbag = DagBag(TEST_DAG_FOLDER, read_dags_from_db=False) - with create_session() as session: - session.query(Job).delete() - dag = dagbag.get_dag("example_branch_operator") - dag.sync_to_db() - - data_interval = dag.infer_automated_data_interval(DEFAULT_LOGICAL_DATE) - triggered_by_kwargs = {"triggered_by": DagRunTriggeredByType.TEST} if AIRFLOW_V_3_0_PLUS else {} - dag_run = dag.create_dagrun( - state=DagRunState.RUNNING, - logical_date=DEFAULT_DATE, - run_type=DagRunType.SCHEDULED, - session=session, - data_interval=data_interval, - **triggered_by_kwargs, - ) + dagfile = os.path.join(EXAMPLE_DAGS_FOLDER, "example_branch_operator.py") + dagbag = DagBag(dagfile) + dag = dagbag.get_dag("example_branch_operator") + DAG.bulk_write_to_db("testing", None, [dag]) - scheduler_job = Job(executor=MockExecutor()) - self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull) - self.job_runner.processor_agent = mock.MagicMock() + session.query(Job).delete() - # We will provision 2 tasks so we can check we only find zombies from this scheduler - tasks_to_setup = ["branching", "run_this_first"] + data_interval = dag.infer_automated_data_interval(DEFAULT_LOGICAL_DATE) + triggered_by_kwargs = {"triggered_by": DagRunTriggeredByType.TEST} if AIRFLOW_V_3_0_PLUS else {} + dag_run = dag.create_dagrun( + state=DagRunState.RUNNING, + logical_date=DEFAULT_DATE, + run_type=DagRunType.SCHEDULED, + session=session, + data_interval=data_interval, + **triggered_by_kwargs, + ) - for task_id in tasks_to_setup: - task = dag.get_task(task_id=task_id) - ti = TaskInstance(task, run_id=dag_run.run_id, state=State.RUNNING) - ti.queued_by_job_id = 999 + scheduler_job = Job(executor=MockExecutor()) + self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull) + self.job_runner.processor_agent = mock.MagicMock() - session.add(ti) - session.flush() + # We will provision 2 tasks so we can check we only 
find zombies from this scheduler + tasks_to_setup = ["branching", "run_this_first"] - assert task.task_id == "run_this_first" # Make sure we have the task/ti we expect + for task_id in tasks_to_setup: + task = dag.get_task(task_id=task_id) + ti = TaskInstance(task, run_id=dag_run.run_id, state=State.RUNNING) + ti.queued_by_job_id = 999 - ti.queued_by_job_id = scheduler_job.id + session.add(ti) session.flush() - zombie_message = self.job_runner._generate_zombie_message_details(ti) - assert zombie_message == { - "DAG Id": "example_branch_operator", - "Task Id": "run_this_first", - "Run Id": "scheduled__2016-01-01T00:00:00+00:00", - } - - ti.hostname = "10.10.10.10" - ti.map_index = 2 - ti.external_executor_id = "abcdefg" - - zombie_message = self.job_runner._generate_zombie_message_details(ti) - assert zombie_message == { - "DAG Id": "example_branch_operator", - "Task Id": "run_this_first", - "Run Id": "scheduled__2016-01-01T00:00:00+00:00", - "Hostname": "10.10.10.10", - "Map Index": 2, - "External Executor Id": "abcdefg", - } - - def test_find_zombies_handle_failure_callbacks_are_correctly_passed_to_dag_processor(self): + assert task.task_id == "run_this_first" # Make sure we have the task/ti we expect + + ti.queued_by_job_id = scheduler_job.id + session.flush() + + zombie_message = self.job_runner._generate_zombie_message_details(ti) + assert zombie_message == { + "DAG Id": "example_branch_operator", + "Task Id": "run_this_first", + "Run Id": "scheduled__2016-01-01T00:00:00+00:00", + } + + ti.hostname = "10.10.10.10" + ti.map_index = 2 + ti.external_executor_id = "abcdefg" + + zombie_message = self.job_runner._generate_zombie_message_details(ti) + assert zombie_message == { + "DAG Id": "example_branch_operator", + "Task Id": "run_this_first", + "Run Id": "scheduled__2016-01-01T00:00:00+00:00", + "Hostname": "10.10.10.10", + "Map Index": 2, + "External Executor Id": "abcdefg", + } + + def test_find_zombies_handle_failure_callbacks_are_correctly_passed_to_dag_processor( + self, testing_dag_bundle + ): """ Check that the same set of failure callback with zombies are passed to the dag file processors until the next zombie detection logic is invoked. @@ -5784,7 +5778,7 @@ def test_find_zombies_handle_failure_callbacks_are_correctly_passed_to_dag_proce ) session.query(Job).delete() dag = dagbag.get_dag("test_example_bash_operator") - dag.sync_to_db() + DAG.bulk_write_to_db("testing", None, [dag]) data_interval = dag.infer_automated_data_interval(DEFAULT_LOGICAL_DATE) triggered_by_kwargs = {"triggered_by": DagRunTriggeredByType.TEST} if AIRFLOW_V_3_0_PLUS else {} dag_run = dag.create_dagrun( @@ -5830,11 +5824,11 @@ def test_find_zombies_handle_failure_callbacks_are_correctly_passed_to_dag_proce callback_requests[0].ti = None assert expected_failure_callback_requests[0] == callback_requests[0] - def test_cleanup_stale_dags(self): + def test_cleanup_stale_dags(self, testing_dag_bundle): dagbag = DagBag(TEST_DAG_FOLDER, read_dags_from_db=False) with create_session() as session: dag = dagbag.get_dag("test_example_bash_operator") - dag.sync_to_db() + DAG.bulk_write_to_db("testing", None, [dag]) dm = DagModel.get_current("test_example_bash_operator") # Make it "stale". dm.last_parsed_time = timezone.utcnow() - timedelta(minutes=11) @@ -5842,7 +5836,7 @@ def test_cleanup_stale_dags(self): # This one should remain active. 
dag = dagbag.get_dag("test_start_date_scheduling") - dag.sync_to_db() + DAG.bulk_write_to_db("testing", None, [dag]) session.flush() @@ -5924,13 +5918,13 @@ def watch_heartbeat(*args, **kwargs): @pytest.mark.long_running @pytest.mark.parametrize("dag_id", ["test_mapped_classic", "test_mapped_taskflow"]) - def test_mapped_dag(self, dag_id, session): + def test_mapped_dag(self, dag_id, session, testing_dag_bundle): """End-to-end test of a simple mapped dag""" # Use SequentialExecutor for more predictable test behaviour from airflow.executors.sequential_executor import SequentialExecutor dagbag = DagBag(dag_folder=TEST_DAGS_FOLDER, include_examples=False) - dagbag.sync_to_db() + dagbag.sync_to_db("testing", None) dagbag.process_file(str(TEST_DAGS_FOLDER / f"{dag_id}.py")) dag = dagbag.get_dag(dag_id) assert dag @@ -5957,14 +5951,14 @@ def test_mapped_dag(self, dag_id, session): dr.refresh_from_db(session) assert dr.state == DagRunState.SUCCESS - def test_should_mark_empty_task_as_success(self): + def test_should_mark_empty_task_as_success(self, testing_dag_bundle): dag_file = os.path.join( os.path.dirname(os.path.realpath(__file__)), "../dags/test_only_empty_tasks.py" ) # Write DAGs to dag and serialized_dag table dagbag = DagBag(dag_folder=dag_file, include_examples=False, read_dags_from_db=False) - dagbag.sync_to_db() + dagbag.sync_to_db("testing", None) scheduler_job = Job() self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull) @@ -6035,7 +6029,7 @@ def test_should_mark_empty_task_as_success(self): assert duration is None @pytest.mark.need_serialized_dag - def test_catchup_works_correctly(self, dag_maker): + def test_catchup_works_correctly(self, dag_maker, testing_dag_bundle): """Test that catchup works correctly""" session = settings.Session() with dag_maker( @@ -6070,7 +6064,7 @@ def test_catchup_works_correctly(self, dag_maker): session.flush() dag.catchup = False - dag.sync_to_db() + DAG.bulk_write_to_db("testing", None, [dag]) assert not dag.catchup dm = DagModel.get_dagmodel(dag.dag_id) @@ -6279,7 +6273,7 @@ def test_misconfigured_dags_doesnt_crash_scheduler(self, session, dag_maker, cap # Check if the second dagrun was created assert DagRun.find(dag_id="testdag2", session=session) - def test_activate_referenced_assets_with_no_existing_warning(self, session): + def test_activate_referenced_assets_with_no_existing_warning(self, session, testing_dag_bundle): dag_warnings = session.query(DagWarning).all() assert dag_warnings == [] @@ -6292,7 +6286,7 @@ def test_activate_referenced_assets_with_no_existing_warning(self, session): asset1_2 = Asset(name="it's also a duplicate", uri="s3://bucket/key/1", extra=asset_extra) dag1 = DAG(dag_id=dag_id1, start_date=DEFAULT_DATE, schedule=[asset1, asset1_1, asset1_2]) - DAG.bulk_write_to_db([dag1], session=session) + DAG.bulk_write_to_db("testing", None, [dag1], session=session) asset_models = session.scalars(select(AssetModel)).all() assert len(asset_models) == 3 @@ -6313,7 +6307,7 @@ def test_activate_referenced_assets_with_no_existing_warning(self, session): "dy associated to 'asset1'" ) - def test_activate_referenced_assets_with_existing_warnings(self, session): + def test_activate_referenced_assets_with_existing_warnings(self, session, testing_dag_bundle): dag_ids = [f"test_asset_dag{i}" for i in range(1, 4)] asset1_name = "asset1" asset_extra = {"foo": "bar"} @@ -6332,7 +6326,7 @@ def test_activate_referenced_assets_with_existing_warnings(self, session): dag2 = DAG(dag_id=dag_ids[1], start_date=DEFAULT_DATE) dag3 = 
DAG(dag_id=dag_ids[2], start_date=DEFAULT_DATE, schedule=[asset1_2]) - DAG.bulk_write_to_db([dag1, dag2, dag3], session=session) + DAG.bulk_write_to_db("testing", None, [dag1, dag2, dag3], session=session) asset_models = session.scalars(select(AssetModel)).all() @@ -6366,7 +6360,9 @@ def test_activate_referenced_assets_with_existing_warnings(self, session): "name is already associated to 's3://bucket/key/1'" ) - def test_activate_referenced_assets_with_multiple_conflict_asset_in_one_dag(self, session): + def test_activate_referenced_assets_with_multiple_conflict_asset_in_one_dag( + self, session, testing_dag_bundle + ): dag_id = "test_asset_dag" asset1_name = "asset1" asset_extra = {"foo": "bar"} @@ -6379,7 +6375,7 @@ def test_activate_referenced_assets_with_multiple_conflict_asset_in_one_dag(self ) dag1 = DAG(dag_id=dag_id, start_date=DEFAULT_DATE, schedule=schedule) - DAG.bulk_write_to_db([dag1], session=session) + DAG.bulk_write_to_db("testing", None, [dag1], session=session) asset_models = session.scalars(select(AssetModel)).all() @@ -6469,7 +6465,9 @@ def per_test(self) -> Generator: (93, 10, 10), # 10 DAGs with 10 tasks per DAG file. ], ) - def test_execute_queries_count_with_harvested_dags(self, expected_query_count, dag_count, task_count): + def test_execute_queries_count_with_harvested_dags( + self, expected_query_count, dag_count, task_count, testing_dag_bundle + ): with ( mock.patch.dict( "os.environ", @@ -6496,7 +6494,7 @@ def test_execute_queries_count_with_harvested_dags(self, expected_query_count, d ): dagruns = [] dagbag = DagBag(dag_folder=ELASTIC_DAG_FILE, include_examples=False, read_dags_from_db=False) - dagbag.sync_to_db() + dagbag.sync_to_db("testing", None) dag_ids = dagbag.dag_ids dagbag = DagBag(read_dags_from_db=True) @@ -6566,7 +6564,7 @@ def test_execute_queries_count_with_harvested_dags(self, expected_query_count, d ], ) def test_process_dags_queries_count( - self, expected_query_counts, dag_count, task_count, start_ago, schedule, shape + self, expected_query_counts, dag_count, task_count, start_ago, schedule, shape, testing_dag_bundle ): with ( mock.patch.dict( @@ -6592,7 +6590,7 @@ def test_process_dags_queries_count( ), ): dagbag = DagBag(dag_folder=ELASTIC_DAG_FILE, include_examples=False) - dagbag.sync_to_db() + dagbag.sync_to_db("testing", None) mock_agent = mock.MagicMock() diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py index 6eaa4e3ac3aeb..a63b05b0cd343 100644 --- a/tests/models/test_dag.py +++ b/tests/models/test_dag.py @@ -623,7 +623,7 @@ def test_dagtag_repr(self): repr(t) for t in session.query(DagTag).filter(DagTag.dag_id == "dag-test-dagtag").all() } - def test_bulk_write_to_db(self): + def test_bulk_write_to_db(self, testing_dag_bundle): clear_db_dags() dags = [ DAG(f"dag-bulk-sync-{i}", schedule=None, start_date=DEFAULT_DATE, tags=["test-dag"]) @@ -631,7 +631,7 @@ def test_bulk_write_to_db(self): ] with assert_queries_count(6): - DAG.bulk_write_to_db(dags) + DAG.bulk_write_to_db("testing", None, dags) with create_session() as session: assert {"dag-bulk-sync-0", "dag-bulk-sync-1", "dag-bulk-sync-2", "dag-bulk-sync-3"} == { row[0] for row in session.query(DagModel.dag_id).all() @@ -648,14 +648,14 @@ def test_bulk_write_to_db(self): # Re-sync should do fewer queries with assert_queries_count(9): - DAG.bulk_write_to_db(dags) + DAG.bulk_write_to_db("testing", None, dags) with assert_queries_count(9): - DAG.bulk_write_to_db(dags) + DAG.bulk_write_to_db("testing", None, dags) # Adding tags for dag in dags: dag.tags.add("test-dag2") 
with assert_queries_count(10): - DAG.bulk_write_to_db(dags) + DAG.bulk_write_to_db("testing", None, dags) with create_session() as session: assert {"dag-bulk-sync-0", "dag-bulk-sync-1", "dag-bulk-sync-2", "dag-bulk-sync-3"} == { row[0] for row in session.query(DagModel.dag_id).all() @@ -674,7 +674,7 @@ def test_bulk_write_to_db(self): for dag in dags: dag.tags.remove("test-dag") with assert_queries_count(10): - DAG.bulk_write_to_db(dags) + DAG.bulk_write_to_db("testing", None, dags) with create_session() as session: assert {"dag-bulk-sync-0", "dag-bulk-sync-1", "dag-bulk-sync-2", "dag-bulk-sync-3"} == { row[0] for row in session.query(DagModel.dag_id).all() @@ -693,7 +693,7 @@ def test_bulk_write_to_db(self): for dag in dags: dag.tags = set() with assert_queries_count(10): - DAG.bulk_write_to_db(dags) + DAG.bulk_write_to_db("testing", None, dags) with create_session() as session: assert {"dag-bulk-sync-0", "dag-bulk-sync-1", "dag-bulk-sync-2", "dag-bulk-sync-3"} == { row[0] for row in session.query(DagModel.dag_id).all() @@ -703,7 +703,7 @@ def test_bulk_write_to_db(self): for row in session.query(DagModel.last_parsed_time).all(): assert row[0] is not None - def test_bulk_write_to_db_single_dag(self): + def test_bulk_write_to_db_single_dag(self, testing_dag_bundle): """ Test bulk_write_to_db for a single dag using the index optimized query """ @@ -714,7 +714,7 @@ def test_bulk_write_to_db_single_dag(self): ] with assert_queries_count(6): - DAG.bulk_write_to_db(dags) + DAG.bulk_write_to_db("testing", None, dags) with create_session() as session: assert {"dag-bulk-sync-0"} == {row[0] for row in session.query(DagModel.dag_id).all()} assert { @@ -726,11 +726,11 @@ def test_bulk_write_to_db_single_dag(self): # Re-sync should do fewer queries with assert_queries_count(8): - DAG.bulk_write_to_db(dags) + DAG.bulk_write_to_db("testing", None, dags) with assert_queries_count(8): - DAG.bulk_write_to_db(dags) + DAG.bulk_write_to_db("testing", None, dags) - def test_bulk_write_to_db_multiple_dags(self): + def test_bulk_write_to_db_multiple_dags(self, testing_dag_bundle): """ Test bulk_write_to_db for multiple dags which does not use the index optimized query """ @@ -741,7 +741,7 @@ def test_bulk_write_to_db_multiple_dags(self): ] with assert_queries_count(6): - DAG.bulk_write_to_db(dags) + DAG.bulk_write_to_db("testing", None, dags) with create_session() as session: assert {"dag-bulk-sync-0", "dag-bulk-sync-1", "dag-bulk-sync-2", "dag-bulk-sync-3"} == { row[0] for row in session.query(DagModel.dag_id).all() @@ -758,26 +758,26 @@ def test_bulk_write_to_db_multiple_dags(self): # Re-sync should do fewer queries with assert_queries_count(9): - DAG.bulk_write_to_db(dags) + DAG.bulk_write_to_db("testing", None, dags) with assert_queries_count(9): - DAG.bulk_write_to_db(dags) + DAG.bulk_write_to_db("testing", None, dags) @pytest.mark.parametrize("interval", [None, "@daily"]) - def test_bulk_write_to_db_interval_save_runtime(self, interval): + def test_bulk_write_to_db_interval_save_runtime(self, testing_dag_bundle, interval): mock_active_runs_of_dags = mock.MagicMock(side_effect=DagRun.active_runs_of_dags) with mock.patch.object(DagRun, "active_runs_of_dags", mock_active_runs_of_dags): dags_null_timetable = [ DAG("dag-interval-None", schedule=None, start_date=TEST_DATE), DAG("dag-interval-test", schedule=interval, start_date=TEST_DATE), ] - DAG.bulk_write_to_db(dags_null_timetable, session=settings.Session()) + DAG.bulk_write_to_db("testing", None, dags_null_timetable, session=settings.Session()) if interval: 
mock_active_runs_of_dags.assert_called_once() else: mock_active_runs_of_dags.assert_not_called() @pytest.mark.parametrize("state", [DagRunState.RUNNING, DagRunState.QUEUED]) - def test_bulk_write_to_db_max_active_runs(self, state): + def test_bulk_write_to_db_max_active_runs(self, testing_dag_bundle, state): """ Test that DagModel.next_dagrun_create_after is set to NULL when the dag cannot be created due to max active runs being hit. @@ -793,7 +793,7 @@ def test_bulk_write_to_db_max_active_runs(self, state): session = settings.Session() dag.clear() - DAG.bulk_write_to_db([dag], session=session) + DAG.bulk_write_to_db("testing", None, [dag], session=session) model = session.get(DagModel, dag.dag_id) @@ -810,17 +810,17 @@ def test_bulk_write_to_db_max_active_runs(self, state): **triggered_by_kwargs, ) assert dr is not None - DAG.bulk_write_to_db([dag]) + DAG.bulk_write_to_db("testing", None, [dag]) model = session.get(DagModel, dag.dag_id) # We signal "at max active runs" by saying this run is never eligible to be created assert model.next_dagrun_create_after is None # test that bulk_write_to_db again doesn't update next_dagrun_create_after - DAG.bulk_write_to_db([dag]) + DAG.bulk_write_to_db("testing", None, [dag]) model = session.get(DagModel, dag.dag_id) assert model.next_dagrun_create_after is None - def test_bulk_write_to_db_has_import_error(self): + def test_bulk_write_to_db_has_import_error(self, testing_dag_bundle): """ Test that DagModel.has_import_error is set to false if no import errors. """ @@ -830,7 +830,7 @@ def test_bulk_write_to_db_has_import_error(self): session = settings.Session() dag.clear() - DAG.bulk_write_to_db([dag], session=session) + DAG.bulk_write_to_db("testing", None, [dag], session=session) model = session.get(DagModel, dag.dag_id) @@ -844,14 +844,14 @@ def test_bulk_write_to_db_has_import_error(self): # assert assert model.has_import_errors # parse - DAG.bulk_write_to_db([dag]) + DAG.bulk_write_to_db("testing", None, [dag]) model = session.get(DagModel, dag.dag_id) # assert that has_import_error is now false assert not model.has_import_errors session.close() - def test_bulk_write_to_db_assets(self): + def test_bulk_write_to_db_assets(self, testing_dag_bundle): """ Ensure that assets referenced in a dag are correctly loaded into the database. """ @@ -877,7 +877,7 @@ def test_bulk_write_to_db_assets(self): session = settings.Session() dag1.clear() - DAG.bulk_write_to_db([dag1, dag2], session=session) + DAG.bulk_write_to_db("testing", None, [dag1, dag2], session=session) session.commit() stored_assets = {x.uri: x for x in session.query(AssetModel).all()} asset1_orm = stored_assets[a1.uri] @@ -908,7 +908,7 @@ def test_bulk_write_to_db_assets(self): EmptyOperator(task_id=task_id, dag=dag1, outlets=[a2]) dag2 = DAG(dag_id=dag_id2, start_date=DEFAULT_DATE, schedule=None) EmptyOperator(task_id=task_id, dag=dag2) - DAG.bulk_write_to_db([dag1, dag2], session=session) + DAG.bulk_write_to_db("testing", None, [dag1, dag2], session=session) session.commit() session.expunge_all() stored_assets = {x.uri: x for x in session.query(AssetModel).all()} @@ -937,7 +937,7 @@ def _find_assets_activation(session) -> tuple[list[AssetModel], list[AssetModel] ).all() return [a for a, v in assets if not v], [a for a, v in assets if v] - def test_bulk_write_to_db_does_not_activate(self, dag_maker, session): + def test_bulk_write_to_db_does_not_activate(self, dag_maker, testing_dag_bundle, session): """ Assets are not activated on write, but later in the scheduler by the SchedulerJob. 
""" @@ -950,14 +950,14 @@ def test_bulk_write_to_db_does_not_activate(self, dag_maker, session): dag1 = DAG(dag_id="assets-1", start_date=DEFAULT_DATE, schedule=[asset1]) BashOperator(dag=dag1, task_id="task", bash_command="echo 1", outlets=[asset3]) - DAG.bulk_write_to_db([dag1], session=session) + DAG.bulk_write_to_db("testing", None, [dag1], session=session) assert session.scalars(select(AssetModel).order_by(AssetModel.uri)).all() == [asset1, asset3] assert session.scalars(select(AssetActive)).all() == [] dag1 = DAG(dag_id="assets-1", start_date=DEFAULT_DATE, schedule=[asset1, asset2]) BashOperator(dag=dag1, task_id="task", bash_command="echo 1", outlets=[asset3, asset4]) - DAG.bulk_write_to_db([dag1], session=session) + DAG.bulk_write_to_db("testing", None, [dag1], session=session) assert session.scalars(select(AssetModel).order_by(AssetModel.uri)).all() == [ asset1, @@ -967,7 +967,7 @@ def test_bulk_write_to_db_does_not_activate(self, dag_maker, session): ] assert session.scalars(select(AssetActive)).all() == [] - def test_bulk_write_to_db_asset_aliases(self): + def test_bulk_write_to_db_asset_aliases(self, testing_dag_bundle): """ Ensure that asset aliases referenced in a dag are correctly loaded into the database. """ @@ -983,7 +983,7 @@ def test_bulk_write_to_db_asset_aliases(self): dag2 = DAG(dag_id=dag_id2, start_date=DEFAULT_DATE, schedule=None) EmptyOperator(task_id=task_id, dag=dag2, outlets=[asset_alias_2_2, asset_alias_3]) session = settings.Session() - DAG.bulk_write_to_db([dag1, dag2], session=session) + DAG.bulk_write_to_db("testing", None, [dag1, dag2], session=session) session.commit() stored_asset_alias_models = {x.name: x for x in session.query(AssetAliasModel).all()} @@ -2431,7 +2431,7 @@ def test_dags_needing_dagruns_asset_triggered_dag_info_queued_times(self, sessio assert first_queued_time == DEFAULT_DATE assert last_queued_time == DEFAULT_DATE + timedelta(hours=1) - def test_asset_expression(self, session: Session) -> None: + def test_asset_expression(self, testing_dag_bundle, session: Session) -> None: dag = DAG( dag_id="test_dag_asset_expression", schedule=AssetAny( @@ -2449,7 +2449,7 @@ def test_asset_expression(self, session: Session) -> None: ), start_date=datetime.datetime.min, ) - DAG.bulk_write_to_db([dag], session=session) + DAG.bulk_write_to_db("testing", None, [dag], session=session) expression = session.scalars(select(DagModel.asset_expression).filter_by(dag_id=dag.dag_id)).one() assert expression == { diff --git a/tests/models/test_dagcode.py b/tests/models/test_dagcode.py index 05598703476a5..ac7330b1aa581 100644 --- a/tests/models/test_dagcode.py +++ b/tests/models/test_dagcode.py @@ -41,8 +41,17 @@ def make_example_dags(module): """Loads DAGs from a module for test.""" + # TODO: AIP-66 dedup with tests/models/test_serdag + from airflow.models.dagbundle import DagBundleModel + from airflow.utils.session import create_session + + with create_session() as session: + if session.query(DagBundleModel).filter(DagBundleModel.name == "testing").count() == 0: + testing = DagBundleModel(name="testing") + session.add(testing) + dagbag = DagBag(module.__path__[0]) - DAG.bulk_write_to_db(dagbag.dags.values()) + DAG.bulk_write_to_db("testing", None, dagbag.dags.values()) return dagbag.dags diff --git a/tests/models/test_dagrun.py b/tests/models/test_dagrun.py index 13614d6abeaea..f398be8ad88c4 100644 --- a/tests/models/test_dagrun.py +++ b/tests/models/test_dagrun.py @@ -518,7 +518,7 @@ def test_on_success_callback_when_task_skipped(self, session): assert 
dag_run.state == DagRunState.SUCCESS mock_on_success.assert_called_once() - def test_dagrun_update_state_with_handle_callback_success(self, session): + def test_dagrun_update_state_with_handle_callback_success(self, testing_dag_bundle, session): def on_success_callable(context): assert context["dag_run"].dag_id == "test_dagrun_update_state_with_handle_callback_success" @@ -528,7 +528,7 @@ def on_success_callable(context): start_date=datetime.datetime(2017, 1, 1), on_success_callback=on_success_callable, ) - DAG.bulk_write_to_db(dags=[dag], session=session) + DAG.bulk_write_to_db("testing", None, dags=[dag], session=session) dag_task1 = EmptyOperator(task_id="test_state_succeeded1", dag=dag) dag_task2 = EmptyOperator(task_id="test_state_succeeded2", dag=dag) @@ -556,7 +556,7 @@ def on_success_callable(context): msg="success", ) - def test_dagrun_update_state_with_handle_callback_failure(self, session): + def test_dagrun_update_state_with_handle_callback_failure(self, testing_dag_bundle, session): def on_failure_callable(context): assert context["dag_run"].dag_id == "test_dagrun_update_state_with_handle_callback_failure" @@ -566,7 +566,7 @@ def on_failure_callable(context): start_date=datetime.datetime(2017, 1, 1), on_failure_callback=on_failure_callable, ) - DAG.bulk_write_to_db(dags=[dag], session=session) + DAG.bulk_write_to_db("testing", None, dags=[dag], session=session) dag_task1 = EmptyOperator(task_id="test_state_succeeded1", dag=dag) dag_task2 = EmptyOperator(task_id="test_state_failed2", dag=dag) diff --git a/tests/models/test_serialized_dag.py b/tests/models/test_serialized_dag.py index 60d26959b079b..94835fbd5e5d7 100644 --- a/tests/models/test_serialized_dag.py +++ b/tests/models/test_serialized_dag.py @@ -49,8 +49,16 @@ # To move it to a shared module. def make_example_dags(module): """Loads DAGs from a module for test.""" + from airflow.models.dagbundle import DagBundleModel + from airflow.utils.session import create_session + + with create_session() as session: + if session.query(DagBundleModel).filter(DagBundleModel.name == "testing").count() == 0: + testing = DagBundleModel(name="testing") + session.add(testing) + dagbag = DagBag(module.__path__[0]) - DAG.bulk_write_to_db(dagbag.dags.values()) + DAG.bulk_write_to_db("testing", None, dagbag.dags.values()) return dagbag.dags @@ -177,13 +185,13 @@ def test_read_all_dags_only_picks_the_latest_serdags(self, session): # assert only the latest SDM is returned assert len(sdags) != len(serialized_dags2) - def test_bulk_sync_to_db(self): + def test_bulk_sync_to_db(self, testing_dag_bundle): dags = [ DAG("dag_1", schedule=None), DAG("dag_2", schedule=None), DAG("dag_3", schedule=None), ] - DAG.bulk_write_to_db(dags) + DAG.bulk_write_to_db("testing", None, dags) # we also write to dag_version and dag_code tables # in dag_version. with assert_queries_count(24): diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py index 73f5908b707cf..bb3edc53cdd28 100644 --- a/tests/models/test_taskinstance.py +++ b/tests/models/test_taskinstance.py @@ -2280,7 +2280,7 @@ def test_success_callback_no_race_condition(self, create_task_instance): ti.refresh_from_db() assert ti.state == State.SUCCESS - def test_outlet_assets(self, create_task_instance): + def test_outlet_assets(self, create_task_instance, testing_dag_bundle): """ Verify that when we have an outlet asset on a task, and the task completes successfully, an AssetDagRunQueue is logged. 
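For reference, a minimal sketch of the DagBag flow the test helpers above follow when syncing under a named bundle; the folder path and bundle name are illustrative, not part of the patch:

from airflow.models.dagbag import DagBag
from airflow.models.dagbundle import DagBundleModel
from airflow.utils.session import create_session

with create_session() as session:
    # mirror the helpers above: make sure the bundle row exists before syncing
    if session.query(DagBundleModel).filter(DagBundleModel.name == "testing").count() == 0:
        session.add(DagBundleModel(name="testing"))

dagbag = DagBag("/path/to/dags", include_examples=False)  # illustrative folder
dagbag.sync_to_db("testing", None)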
@@ -2291,7 +2291,7 @@ def test_outlet_assets(self, create_task_instance): session = settings.Session() dagbag = DagBag(dag_folder=example_assets.__file__) dagbag.collect_dags(only_if_updated=False, safe_mode=False) - dagbag.sync_to_db(session=session) + dagbag.sync_to_db("testing", None, session=session) asset_models = session.scalars(select(AssetModel)).all() SchedulerJobRunner._activate_referenced_assets(asset_models, session=session) @@ -2343,7 +2343,7 @@ def test_outlet_assets(self, create_task_instance): event.timestamp < adrq_timestamp for (adrq_timestamp,) in adrq_timestamps ), f"Some items in {[str(t) for t in adrq_timestamps]} are earlier than {event.timestamp}" - def test_outlet_assets_failed(self, create_task_instance): + def test_outlet_assets_failed(self, create_task_instance, testing_dag_bundle): """ Verify that when we have an outlet asset on a task, and the task failed, an AssetDagRunQueue is not logged, and an AssetEvent is @@ -2355,7 +2355,7 @@ def test_outlet_assets_failed(self, create_task_instance): session = settings.Session() dagbag = DagBag(dag_folder=test_assets.__file__) dagbag.collect_dags(only_if_updated=False, safe_mode=False) - dagbag.sync_to_db(session=session) + dagbag.sync_to_db("testing", None, session=session) run_id = str(uuid4()) dr = DagRun(dag_with_fail_task.dag_id, run_id=run_id, run_type="anything") session.merge(dr) @@ -2397,7 +2397,7 @@ def raise_an_exception(placeholder: int): task_instance.run() assert task_instance.current_state() == TaskInstanceState.SUCCESS - def test_outlet_assets_skipped(self): + def test_outlet_assets_skipped(self, testing_dag_bundle): """ Verify that when we have an outlet asset on a task, and the task is skipped, an AssetDagRunQueue is not logged, and an AssetEvent is @@ -2409,7 +2409,7 @@ def test_outlet_assets_skipped(self): session = settings.Session() dagbag = DagBag(dag_folder=test_assets.__file__) dagbag.collect_dags(only_if_updated=False, safe_mode=False) - dagbag.sync_to_db(session=session) + dagbag.sync_to_db("testing", None, session=session) asset_models = session.scalars(select(AssetModel)).all() SchedulerJobRunner._activate_referenced_assets(asset_models, session=session) diff --git a/tests/operators/test_trigger_dagrun.py b/tests/operators/test_trigger_dagrun.py index a8a6b3c262903..74d8371e5d78e 100644 --- a/tests/operators/test_trigger_dagrun.py +++ b/tests/operators/test_trigger_dagrun.py @@ -26,7 +26,6 @@ from airflow.exceptions import AirflowException, DagRunAlreadyExists, TaskDeferred from airflow.models.dag import DagModel -from airflow.models.dagbag import DagBag from airflow.models.dagrun import DagRun from airflow.models.log import Log from airflow.models.taskinstance import TaskInstance @@ -37,6 +36,8 @@ from airflow.utils.state import DagRunState, State, TaskInstanceState from airflow.utils.types import DagRunType +from tests_common.test_utils.db import parse_and_sync_to_db + pytestmark = pytest.mark.db_test DEFAULT_DATE = datetime(2019, 1, 1, tzinfo=timezone.utc) @@ -71,11 +72,6 @@ def setup_method(self): session.add(DagModel(dag_id=TRIGGERED_DAG_ID, fileloc=self._tmpfile)) session.commit() - def re_sync_triggered_dag_to_db(self, dag, dag_maker): - dagbag = DagBag(self.f_name, read_dags_from_db=False, include_examples=False) - dagbag.bag_dag(dag) - dagbag.sync_to_db(session=dag_maker.session) - def teardown_method(self): """Cleanup state after testing in DB.""" with create_session() as session: @@ -120,9 +116,10 @@ def test_trigger_dagrun(self, dag_maker): """Test TriggerDagRunOperator.""" with 
dag_maker( TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True - ) as dag: + ): task = TriggerDagRunOperator(task_id="test_task", trigger_dag_id=TRIGGERED_DAG_ID) - self.re_sync_triggered_dag_to_db(dag, dag_maker) + dag_maker.sync_dagbag_to_db() + parse_and_sync_to_db(self.f_name) dag_maker.create_dagrun() task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) @@ -134,13 +131,14 @@ def test_trigger_dagrun(self, dag_maker): def test_trigger_dagrun_custom_run_id(self, dag_maker): with dag_maker( TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True - ) as dag: + ): task = TriggerDagRunOperator( task_id="test_task", trigger_dag_id=TRIGGERED_DAG_ID, trigger_run_id="custom_run_id", ) - self.re_sync_triggered_dag_to_db(dag, dag_maker) + dag_maker.sync_dagbag_to_db() + parse_and_sync_to_db(self.f_name) dag_maker.create_dagrun() task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) @@ -154,13 +152,14 @@ def test_trigger_dagrun_with_logical_date(self, dag_maker): custom_logical_date = timezone.datetime(2021, 1, 2, 3, 4, 5) with dag_maker( TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True - ) as dag: + ): task = TriggerDagRunOperator( task_id="test_trigger_dagrun_with_logical_date", trigger_dag_id=TRIGGERED_DAG_ID, logical_date=custom_logical_date, ) - self.re_sync_triggered_dag_to_db(dag, dag_maker) + dag_maker.sync_dagbag_to_db() + parse_and_sync_to_db(self.f_name) dag_maker.create_dagrun() task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) @@ -177,7 +176,7 @@ def test_trigger_dagrun_twice(self, dag_maker): run_id = f"manual__{utc_now.isoformat()}" with dag_maker( TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True - ) as dag: + ): task = TriggerDagRunOperator( task_id="test_trigger_dagrun_with_logical_date", trigger_dag_id=TRIGGERED_DAG_ID, @@ -187,7 +186,8 @@ def test_trigger_dagrun_twice(self, dag_maker): reset_dag_run=True, wait_for_completion=True, ) - self.re_sync_triggered_dag_to_db(dag, dag_maker) + dag_maker.sync_dagbag_to_db() + parse_and_sync_to_db(self.f_name) dag_maker.create_dagrun() dag_run = DagRun( dag_id=TRIGGERED_DAG_ID, @@ -213,7 +213,7 @@ def test_trigger_dagrun_with_scheduled_dag_run(self, dag_maker): run_id = f"scheduled__{utc_now.isoformat()}" with dag_maker( TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True - ) as dag: + ): task = TriggerDagRunOperator( task_id="test_trigger_dagrun_with_logical_date", trigger_dag_id=TRIGGERED_DAG_ID, @@ -223,7 +223,8 @@ def test_trigger_dagrun_with_scheduled_dag_run(self, dag_maker): reset_dag_run=True, wait_for_completion=True, ) - self.re_sync_triggered_dag_to_db(dag, dag_maker) + dag_maker.sync_dagbag_to_db() + parse_and_sync_to_db(self.f_name) dag_maker.create_dagrun() run_id = f"scheduled__{utc_now.isoformat()}" dag_run = DagRun( @@ -248,13 +249,14 @@ def test_trigger_dagrun_with_templated_logical_date(self, dag_maker): """Test TriggerDagRunOperator with templated logical_date.""" with dag_maker( TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True - ) as dag: + ): task = TriggerDagRunOperator( task_id="test_trigger_dagrun_with_str_logical_date", trigger_dag_id=TRIGGERED_DAG_ID, logical_date="{{ logical_date }}", ) - self.re_sync_triggered_dag_to_db(dag, dag_maker) + dag_maker.sync_dagbag_to_db() + parse_and_sync_to_db(self.f_name) 
dag_maker.create_dagrun() task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) @@ -270,12 +272,13 @@ def test_trigger_dagrun_with_templated_trigger_dag_id(self, dag_maker): """Test TriggerDagRunOperator with templated trigger dag id.""" with dag_maker( TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True - ) as dag: + ): task = TriggerDagRunOperator( task_id="__".join(["test_trigger_dagrun_with_templated_trigger_dag_id", TRIGGERED_DAG_ID]), trigger_dag_id="{{ ti.task_id.rsplit('.', 1)[-1].split('__')[-1] }}", ) - self.re_sync_triggered_dag_to_db(dag, dag_maker) + dag_maker.sync_dagbag_to_db() + parse_and_sync_to_db(self.f_name) dag_maker.create_dagrun() task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) @@ -291,13 +294,14 @@ def test_trigger_dagrun_operator_conf(self, dag_maker): """Test passing conf to the triggered DagRun.""" with dag_maker( TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True - ) as dag: + ): task = TriggerDagRunOperator( task_id="test_trigger_dagrun_with_str_logical_date", trigger_dag_id=TRIGGERED_DAG_ID, conf={"foo": "bar"}, ) - self.re_sync_triggered_dag_to_db(dag, dag_maker) + dag_maker.sync_dagbag_to_db() + parse_and_sync_to_db(self.f_name) dag_maker.create_dagrun() task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) @@ -310,13 +314,14 @@ def test_trigger_dagrun_operator_templated_invalid_conf(self, dag_maker): """Test passing a conf that is not JSON Serializable raise error.""" with dag_maker( TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True - ) as dag: + ): task = TriggerDagRunOperator( task_id="test_trigger_dagrun_with_invalid_conf", trigger_dag_id=TRIGGERED_DAG_ID, conf={"foo": "{{ dag.dag_id }}", "datetime": timezone.utcnow()}, ) - self.re_sync_triggered_dag_to_db(dag, dag_maker) + dag_maker.sync_dagbag_to_db() + parse_and_sync_to_db(self.f_name) dag_maker.create_dagrun() with pytest.raises(AirflowException, match="^conf parameter should be JSON Serializable$"): task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) @@ -325,13 +330,14 @@ def test_trigger_dagrun_operator_templated_conf(self, dag_maker): """Test passing a templated conf to the triggered DagRun.""" with dag_maker( TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True - ) as dag: + ): task = TriggerDagRunOperator( task_id="test_trigger_dagrun_with_str_logical_date", trigger_dag_id=TRIGGERED_DAG_ID, conf={"foo": "{{ dag.dag_id }}"}, ) - self.re_sync_triggered_dag_to_db(dag, dag_maker) + dag_maker.sync_dagbag_to_db() + parse_and_sync_to_db(self.f_name) dag_maker.create_dagrun() task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) @@ -345,7 +351,7 @@ def test_trigger_dagrun_with_reset_dag_run_false(self, dag_maker): logical_date = DEFAULT_DATE with dag_maker( TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True - ) as dag: + ): task = TriggerDagRunOperator( task_id="test_task", trigger_dag_id=TRIGGERED_DAG_ID, @@ -353,7 +359,8 @@ def test_trigger_dagrun_with_reset_dag_run_false(self, dag_maker): logical_date=None, reset_dag_run=False, ) - self.re_sync_triggered_dag_to_db(dag, dag_maker) + dag_maker.sync_dagbag_to_db() + parse_and_sync_to_db(self.f_name) dag_maker.create_dagrun() task.run(start_date=logical_date, end_date=logical_date, ignore_ti_state=True) task.run(start_date=logical_date, 
end_date=logical_date, ignore_ti_state=True) @@ -377,7 +384,7 @@ def test_trigger_dagrun_with_reset_dag_run_false_fail( logical_date = DEFAULT_DATE with dag_maker( TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True - ) as dag: + ): task = TriggerDagRunOperator( task_id="test_task", trigger_dag_id=TRIGGERED_DAG_ID, @@ -385,7 +392,8 @@ def test_trigger_dagrun_with_reset_dag_run_false_fail( logical_date=trigger_logical_date, reset_dag_run=False, ) - self.re_sync_triggered_dag_to_db(dag, dag_maker) + dag_maker.sync_dagbag_to_db() + parse_and_sync_to_db(self.f_name) dag_maker.create_dagrun() task.run(start_date=logical_date, end_date=logical_date, ignore_ti_state=True) @@ -397,7 +405,7 @@ def test_trigger_dagrun_with_skip_when_already_exists(self, dag_maker): logical_date = DEFAULT_DATE with dag_maker( TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True - ) as dag: + ): task = TriggerDagRunOperator( task_id="test_task", trigger_dag_id=TRIGGERED_DAG_ID, @@ -405,7 +413,8 @@ def test_trigger_dagrun_with_skip_when_already_exists(self, dag_maker): reset_dag_run=False, skip_when_already_exists=True, ) - self.re_sync_triggered_dag_to_db(dag, dag_maker) + dag_maker.sync_dagbag_to_db() + parse_and_sync_to_db(self.f_name) dr: DagRun = dag_maker.create_dagrun() task.run(start_date=logical_date, end_date=logical_date, ignore_ti_state=True) assert dr.get_task_instance("test_task").state == TaskInstanceState.SUCCESS @@ -428,7 +437,7 @@ def test_trigger_dagrun_with_reset_dag_run_true( logical_date = DEFAULT_DATE with dag_maker( TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True - ) as dag: + ): task = TriggerDagRunOperator( task_id="test_task", trigger_dag_id=TRIGGERED_DAG_ID, @@ -436,7 +445,8 @@ def test_trigger_dagrun_with_reset_dag_run_true( logical_date=trigger_logical_date, reset_dag_run=True, ) - self.re_sync_triggered_dag_to_db(dag, dag_maker) + dag_maker.sync_dagbag_to_db() + parse_and_sync_to_db(self.f_name) dag_maker.create_dagrun() task.run(start_date=logical_date, end_date=logical_date, ignore_ti_state=True) task.run(start_date=logical_date, end_date=logical_date, ignore_ti_state=True) @@ -451,7 +461,7 @@ def test_trigger_dagrun_with_wait_for_completion_true(self, dag_maker): logical_date = DEFAULT_DATE with dag_maker( TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True - ) as dag: + ): task = TriggerDagRunOperator( task_id="test_task", trigger_dag_id=TRIGGERED_DAG_ID, @@ -460,7 +470,8 @@ def test_trigger_dagrun_with_wait_for_completion_true(self, dag_maker): poke_interval=10, allowed_states=[State.QUEUED], ) - self.re_sync_triggered_dag_to_db(dag, dag_maker) + dag_maker.sync_dagbag_to_db() + parse_and_sync_to_db(self.f_name) dag_maker.create_dagrun() task.run(start_date=logical_date, end_date=logical_date) @@ -473,7 +484,7 @@ def test_trigger_dagrun_with_wait_for_completion_true_fail(self, dag_maker): logical_date = DEFAULT_DATE with dag_maker( TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True - ) as dag: + ): task = TriggerDagRunOperator( task_id="test_task", trigger_dag_id=TRIGGERED_DAG_ID, @@ -482,7 +493,8 @@ def test_trigger_dagrun_with_wait_for_completion_true_fail(self, dag_maker): poke_interval=10, failed_states=[State.QUEUED], ) - self.re_sync_triggered_dag_to_db(dag, dag_maker) + dag_maker.sync_dagbag_to_db() + parse_and_sync_to_db(self.f_name) dag_maker.create_dagrun() with 
pytest.raises(AirflowException): task.run(start_date=logical_date, end_date=logical_date) @@ -492,12 +504,13 @@ def test_trigger_dagrun_triggering_itself(self, dag_maker): logical_date = DEFAULT_DATE with dag_maker( TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True - ) as dag: + ): task = TriggerDagRunOperator( task_id="test_task", trigger_dag_id=TEST_DAG_ID, ) - self.re_sync_triggered_dag_to_db(dag, dag_maker) + dag_maker.sync_dagbag_to_db() + parse_and_sync_to_db(self.f_name) dag_maker.create_dagrun() task.run(start_date=logical_date, end_date=logical_date) @@ -516,7 +529,7 @@ def test_trigger_dagrun_with_wait_for_completion_true_defer_false(self, dag_make logical_date = DEFAULT_DATE with dag_maker( TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True - ) as dag: + ): task = TriggerDagRunOperator( task_id="test_task", trigger_dag_id=TRIGGERED_DAG_ID, @@ -526,7 +539,8 @@ def test_trigger_dagrun_with_wait_for_completion_true_defer_false(self, dag_make allowed_states=[State.QUEUED], deferrable=False, ) - self.re_sync_triggered_dag_to_db(dag, dag_maker) + dag_maker.sync_dagbag_to_db() + parse_and_sync_to_db(self.f_name) dag_maker.create_dagrun() task.run(start_date=logical_date, end_date=logical_date) @@ -539,7 +553,7 @@ def test_trigger_dagrun_with_wait_for_completion_true_defer_true(self, dag_maker logical_date = DEFAULT_DATE with dag_maker( TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True - ) as dag: + ): task = TriggerDagRunOperator( task_id="test_task", trigger_dag_id=TRIGGERED_DAG_ID, @@ -549,7 +563,8 @@ def test_trigger_dagrun_with_wait_for_completion_true_defer_true(self, dag_maker allowed_states=[State.QUEUED], deferrable=True, ) - self.re_sync_triggered_dag_to_db(dag, dag_maker) + dag_maker.sync_dagbag_to_db() + parse_and_sync_to_db(self.f_name) dag_maker.create_dagrun() task.run(start_date=logical_date, end_date=logical_date) @@ -571,7 +586,7 @@ def test_trigger_dagrun_with_wait_for_completion_true_defer_true_failure(self, d logical_date = DEFAULT_DATE with dag_maker( TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True - ) as dag: + ): task = TriggerDagRunOperator( task_id="test_task", trigger_dag_id=TRIGGERED_DAG_ID, @@ -581,7 +596,8 @@ def test_trigger_dagrun_with_wait_for_completion_true_defer_true_failure(self, d allowed_states=[State.SUCCESS], deferrable=True, ) - self.re_sync_triggered_dag_to_db(dag, dag_maker) + dag_maker.sync_dagbag_to_db() + parse_and_sync_to_db(self.f_name) dag_maker.create_dagrun() task.run(start_date=logical_date, end_date=logical_date) @@ -607,7 +623,7 @@ def test_trigger_dagrun_with_wait_for_completion_true_defer_true_failure_2(self, logical_date = DEFAULT_DATE with dag_maker( TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True - ) as dag: + ): task = TriggerDagRunOperator( task_id="test_task", trigger_dag_id=TRIGGERED_DAG_ID, @@ -618,7 +634,8 @@ def test_trigger_dagrun_with_wait_for_completion_true_defer_true_failure_2(self, failed_states=[State.QUEUED], deferrable=True, ) - self.re_sync_triggered_dag_to_db(dag, dag_maker) + dag_maker.sync_dagbag_to_db() + parse_and_sync_to_db(self.f_name) dag_maker.create_dagrun() task.run(start_date=logical_date, end_date=logical_date) @@ -648,7 +665,7 @@ def test_dagstatetrigger_logical_dates(self, trigger_logical_date, dag_maker): """Ensure that the DagStateTrigger is called with the triggered DAG's logical 
date.""" with dag_maker( TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True - ) as dag: + ): task = TriggerDagRunOperator( task_id="test_task", trigger_dag_id=TRIGGERED_DAG_ID, @@ -658,7 +675,8 @@ def test_dagstatetrigger_logical_dates(self, trigger_logical_date, dag_maker): allowed_states=[DagRunState.QUEUED], deferrable=True, ) - self.re_sync_triggered_dag_to_db(dag, dag_maker) + dag_maker.sync_dagbag_to_db() + parse_and_sync_to_db(self.f_name) dag_maker.create_dagrun() mock_task_defer = mock.MagicMock(side_effect=task.defer) @@ -677,7 +695,7 @@ def test_dagstatetrigger_logical_dates_with_clear_and_reset(self, dag_maker): """Check DagStateTrigger is called with the triggered DAG's logical date on subsequent defers.""" with dag_maker( TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True - ) as dag: + ): task = TriggerDagRunOperator( task_id="test_task", trigger_dag_id=TRIGGERED_DAG_ID, @@ -688,7 +706,8 @@ def test_dagstatetrigger_logical_dates_with_clear_and_reset(self, dag_maker): deferrable=True, reset_dag_run=True, ) - self.re_sync_triggered_dag_to_db(dag, dag_maker) + dag_maker.sync_dagbag_to_db() + parse_and_sync_to_db(self.f_name) dag_maker.create_dagrun() mock_task_defer = mock.MagicMock(side_effect=task.defer) @@ -727,7 +746,7 @@ def test_trigger_dagrun_with_no_failed_state(self, dag_maker): logical_date = DEFAULT_DATE with dag_maker( TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True - ) as dag: + ): task = TriggerDagRunOperator( task_id="test_task", trigger_dag_id=TRIGGERED_DAG_ID, @@ -736,7 +755,8 @@ def test_trigger_dagrun_with_no_failed_state(self, dag_maker): poke_interval=10, failed_states=[], ) - self.re_sync_triggered_dag_to_db(dag, dag_maker) + dag_maker.sync_dagbag_to_db() + parse_and_sync_to_db(self.f_name) dag_maker.create_dagrun() assert task.failed_states == [] diff --git a/tests/sensors/test_external_task_sensor.py b/tests/sensors/test_external_task_sensor.py index ba3cfd6b4480c..c73980ca24bd5 100644 --- a/tests/sensors/test_external_task_sensor.py +++ b/tests/sensors/test_external_task_sensor.py @@ -80,7 +80,7 @@ def clean_db(): @pytest.fixture -def dag_zip_maker(): +def dag_zip_maker(testing_dag_bundle): class DagZipMaker: def __call__(self, *dag_files): self.__dag_files = [os.sep.join([TEST_DAGS_FOLDER.__str__(), dag_file]) for dag_file in dag_files] @@ -98,7 +98,7 @@ def __enter__(self): for dag_file in self.__dag_files: zf.write(dag_file, os.path.basename(dag_file)) dagbag = DagBag(dag_folder=self.__tmp_dir, include_examples=False) - dagbag.sync_to_db() + dagbag.sync_to_db("testing", None) return dagbag def __exit__(self, exc_type, exc_val, exc_tb): diff --git a/tests/www/views/conftest.py b/tests/www/views/conftest.py index 7f524b2377e39..68516b4f4c865 100644 --- a/tests/www/views/conftest.py +++ b/tests/www/views/conftest.py @@ -17,6 +17,7 @@ # under the License. 
diff --git a/tests/www/views/conftest.py b/tests/www/views/conftest.py
index 7f524b2377e39..68516b4f4c865 100644
--- a/tests/www/views/conftest.py
+++ b/tests/www/views/conftest.py
@@ -17,6 +17,7 @@
 # under the License.
 from __future__ import annotations
 
+import os
 from collections.abc import Generator
 from contextlib import contextmanager
 from typing import Any, NamedTuple
@@ -31,6 +32,7 @@
 
 from tests_common.test_utils.api_connexion_utils import delete_user
 from tests_common.test_utils.config import conf_vars
+from tests_common.test_utils.db import parse_and_sync_to_db
 from tests_common.test_utils.decorators import dont_initialize_flask_app_submodules
 from tests_common.test_utils.www import (
     client_with_login,
@@ -47,8 +49,8 @@ def session():
 
 @pytest.fixture(autouse=True, scope="module")
 def examples_dag_bag(session):
-    DagBag(include_examples=True).sync_to_db()
-    dag_bag = DagBag(include_examples=True, read_dags_from_db=True)
+    parse_and_sync_to_db(os.devnull, include_examples=True)
+    dag_bag = DagBag(read_dags_from_db=True)
     session.commit()
     return dag_bag
 
diff --git a/tests/www/views/test_views_acl.py b/tests/www/views/test_views_acl.py
index 0f430b9b6dca7..21379550dd1fd 100644
--- a/tests/www/views/test_views_acl.py
+++ b/tests/www/views/test_views_acl.py
@@ -293,9 +293,9 @@ def test_dag_autocomplete_success(client_all_dags):
             "dag_display_name": None,
         },
         {"name": "example_setup_teardown_taskflow", "type": "dag", "dag_display_name": None},
-        {"name": "test_mapped_taskflow", "type": "dag", "dag_display_name": None},
         {"name": "tutorial_taskflow_api", "type": "dag", "dag_display_name": None},
         {"name": "tutorial_taskflow_api_virtualenv", "type": "dag", "dag_display_name": None},
+        {"name": "tutorial_taskflow_templates", "type": "dag", "dag_display_name": None},
     ]
     assert resp.json == expected
 
diff --git a/tests/www/views/test_views_decorators.py b/tests/www/views/test_views_decorators.py
index 5815510b07f6e..6c0a0a7d78172 100644
--- a/tests/www/views/test_views_decorators.py
+++ b/tests/www/views/test_views_decorators.py
@@ -17,6 +17,8 @@
 # under the License.
 from __future__ import annotations
 
+import os
+
 import pytest
 
 from airflow.models import DagBag, Variable
@@ -24,7 +26,7 @@
 from airflow.utils.state import State
 from airflow.utils.types import DagRunType
 
-from tests_common.test_utils.db import clear_db_runs, clear_db_variables
+from tests_common.test_utils.db import clear_db_runs, clear_db_variables, parse_and_sync_to_db
 from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS
 from tests_common.test_utils.www import (
     _check_last_log,
@@ -42,8 +44,8 @@
 
 @pytest.fixture(scope="module")
 def dagbag():
-    DagBag(include_examples=True, read_dags_from_db=False).sync_to_db()
-    return DagBag(include_examples=True, read_dags_from_db=True)
+    parse_and_sync_to_db(os.devnull, include_examples=True)
+    return DagBag(read_dags_from_db=True)
 
 
 @pytest.fixture(scope="module")
diff --git a/tests/www/views/test_views_log.py b/tests/www/views/test_views_log.py
index be492487a4c8b..735c9a569e093 100644
--- a/tests/www/views/test_views_log.py
+++ b/tests/www/views/test_views_log.py
@@ -20,6 +20,7 @@
 import copy
 import logging
 import logging.config
+import os
 import pathlib
 import shutil
 import sys
@@ -127,7 +128,7 @@ def _reset_modules_after_every_test(backup_modules):
 
 
 @pytest.fixture(autouse=True)
-def dags(log_app, create_dummy_dag, session):
+def dags(log_app, create_dummy_dag, testing_dag_bundle, session):
     dag, _ = create_dummy_dag(
         dag_id=DAG_ID,
         task_id=TASK_ID,
@@ -143,10 +144,10 @@ def dags(log_app, create_dummy_dag, session):
         session=session,
     )
 
-    bag = DagBag(include_examples=False)
+    bag = DagBag(os.devnull, include_examples=False)
     bag.bag_dag(dag=dag)
     bag.bag_dag(dag=dag_removed)
-    bag.sync_to_db(session=session)
+    bag.sync_to_db("testing", None, session=session)
 
     log_app.dag_bag = bag
     yield dag, dag_removed
diff --git a/tests/www/views/test_views_tasks.py b/tests/www/views/test_views_tasks.py
index 8289228b81dfe..10f42aca6c4f2 100644
--- a/tests/www/views/test_views_tasks.py
+++ b/tests/www/views/test_views_tasks.py
@@ -597,7 +597,7 @@ def heartbeat(self):
 def test_delete_dag_button_for_dag_on_scheduler_only(admin_client, dag_maker):
     with dag_maker() as dag:
         EmptyOperator(task_id="task")
-    dag.sync_to_db()
+    dag_maker.sync_dagbag_to_db()
     # The delete-dag URL should be generated correctly
     test_dag_id = dag.dag_id
     resp = admin_client.get("/", follow_redirects=True)
@@ -606,12 +606,12 @@ def test_delete_dag_button_for_dag_on_scheduler_only(admin_client, dag_maker):
 
 
 @pytest.fixture
-def new_dag_to_delete():
+def new_dag_to_delete(testing_dag_bundle):
     dag = DAG(
         "new_dag_to_delete", is_paused_upon_creation=True, schedule="0 * * * *", start_date=DEFAULT_DATE
     )
     session = settings.Session()
-    dag.sync_to_db(session=session)
+    DAG.bulk_write_to_db("testing", None, [dag], session=session)
     return dag
 
 
diff --git a/tests_common/pytest_plugin.py b/tests_common/pytest_plugin.py
index 43b4733e38811..b4d3c089a3740 100644
--- a/tests_common/pytest_plugin.py
+++ b/tests_common/pytest_plugin.py
@@ -935,6 +935,19 @@ def create_dagrun_after(self, dagrun, **kwargs):
                 **kwargs,
             )
 
+        def sync_dagbag_to_db(self):
+            if not AIRFLOW_V_3_0_PLUS:
+                self.dagbag.sync_to_db()
+                return
+
+            from airflow.models.dagbundle import DagBundleModel
+
+            if self.session.query(DagBundleModel).filter(DagBundleModel.name == "dag_maker").count() == 0:
+                self.session.add(DagBundleModel(name="dag_maker"))
+                self.session.commit()
+
+            self.dagbag.sync_to_db("dag_maker", None)
+
         def __call__(
             self,
             dag_id="test_dag",
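For context, a minimal usage sketch of the two helpers the tests above now rely on; the test bodies are illustrative only and not taken from this patch. `dag_maker.sync_dagbag_to_db()` replaces direct `dag.sync_to_db()` / `re_sync_triggered_dag_to_db()` calls and registers a "dag_maker" bundle row on Airflow 3, while `parse_and_sync_to_db()` replaces the old `DagBag(...).sync_to_db()` one-liners:

    import os

    from tests_common.test_utils.db import parse_and_sync_to_db


    def test_with_dag_maker(dag_maker):
        with dag_maker("my_dag", serialized=True):
            ...
        # Writes the serialized DAG to the DB under the "dag_maker" bundle (Airflow 3),
        # or falls back to plain dagbag.sync_to_db() on Airflow 2.
        dag_maker.sync_dagbag_to_db()


    def test_with_example_dags():
        # Parses the example DAGs and syncs them under the "dags-folder" bundle.
        parse_and_sync_to_db(os.devnull, include_examples=True)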
diff --git a/tests_common/test_utils/db.py b/tests_common/test_utils/db.py
index 8c5d59751f55c..58ad46372d711 100644
--- a/tests_common/test_utils/db.py
+++ b/tests_common/test_utils/db.py
@@ -17,6 +17,8 @@
 # under the License.
 from __future__ import annotations
 
+from typing import TYPE_CHECKING
+
 from airflow.jobs.job import Job
 from airflow.models import (
     Connection,
@@ -51,15 +53,28 @@
 )
 from tests_common.test_utils.version_compat import AIRFLOW_V_2_10_PLUS, AIRFLOW_V_3_0_PLUS
 
+if TYPE_CHECKING:
+    from pathlib import Path
+
 
 def _bootstrap_dagbag():
     from airflow.models.dag import DAG
     from airflow.models.dagbag import DagBag
 
+    if AIRFLOW_V_3_0_PLUS:
+        from airflow.dag_processing.bundles.manager import DagBundlesManager
+
     with create_session() as session:
+        if AIRFLOW_V_3_0_PLUS:
+            DagBundlesManager().sync_bundles_to_db(session=session)
+            session.commit()
+
         dagbag = DagBag()
         # Save DAGs in the ORM
-        dagbag.sync_to_db(session=session)
+        if AIRFLOW_V_3_0_PLUS:
+            dagbag.sync_to_db(bundle_name="dags-folder", bundle_version=None, session=session)
+        else:
+            dagbag.sync_to_db(session=session)
 
         # Deactivate the unknown ones
         DAG.deactivate_unknown_dags(dagbag.dags.keys(), session=session)
@@ -92,6 +107,25 @@ def initial_db_init():
     get_auth_manager().init()
 
 
+def parse_and_sync_to_db(folder: Path | str, include_examples: bool = False):
+    from airflow.models.dagbag import DagBag
+
+    if AIRFLOW_V_3_0_PLUS:
+        from airflow.dag_processing.bundles.manager import DagBundlesManager
+
+    with create_session() as session:
+        if AIRFLOW_V_3_0_PLUS:
+            DagBundlesManager().sync_bundles_to_db(session=session)
+            session.commit()
+
+        dagbag = DagBag(dag_folder=folder, include_examples=include_examples)
+        if AIRFLOW_V_3_0_PLUS:
+            dagbag.sync_to_db("dags-folder", None, session)
+        else:
+            dagbag.sync_to_db(session=session)  # type: ignore[call-arg]
+        session.commit()
+
+
 def clear_db_runs():
     with create_session() as session:
         session.query(Job).delete()