From 72ab1d54fdfe1587d7cb1238461332e808c1bcf8 Mon Sep 17 00:00:00 2001
From: Jed Cunningham <66968678+jedcunningham@users.noreply.github.com>
Date: Thu, 9 Jan 2025 09:55:02 -0700
Subject: [PATCH] AIP-66: Add support for parsing DAG bundles (#45371)

Let's start parsing DAG bundles! This moves us away from parsing a single
local directory to being able to parse many different bundles, including
optional support for versioning.

This is just the basics - it keeps the parsing loop largely untouched. We
still have a single list of "dag files" to parse, and a queue of them.
However, instead of just a path, this list and queue now contain
`DagFileInfo`s, which hold both a local path and the bundle it's from.

There are a number of things that are not fully functional at this stage,
like versioned callbacks. These will be refactored later - there is enough
churn with the basics as it is (particularly with the number of test
changes).

Co-authored-by: Daniel Standish <15932138+dstandish@users.noreply.github.com>
---
 .../local_commands/dag_processor_command.py    |   1 -
 airflow/dag_processing/bundles/manager.py      |  17 +
 airflow/dag_processing/collection.py           |  16 +-
 airflow/dag_processing/manager.py              | 230 +++++----
 airflow/jobs/scheduler_job_runner.py           |   2 -
 airflow/models/dag.py                          |  21 +-
 airflow/models/dagbag.py                       |  20 +-
 airflow/utils/file.py                          |   9 -
 .../api_endpoints/test_dag_run_endpoint.py     |  25 +-
 .../api_endpoints/test_dag_source_endpoint.py  |  24 +-
 providers/tests/fab/auth_manager/conftest.py   |   7 +-
 .../tests/openlineage/plugins/test_utils.py    |  10 +-
 tests/api_connexion/conftest.py                |   7 +-
 .../endpoints/test_dag_parsing.py              |  14 +-
 .../endpoints/test_dag_run_endpoint.py         |   4 +-
 .../endpoints/test_dag_source_endpoint.py      |  13 +-
 .../endpoints/test_extra_link_endpoint.py      |   4 +-
 .../test_mapped_task_instance_endpoint.py      |   4 +-
 .../endpoints/test_task_endpoint.py            |   5 +-
 .../endpoints/test_task_instance_endpoint.py   |   8 +-
 tests/api_fastapi/conftest.py                  |   9 +-
 .../routes/public/test_dag_parsing.py          |  15 +-
 .../core_api/routes/public/test_dag_run.py     |   2 +-
 .../routes/public/test_dag_sources.py          |  13 +-
 .../core_api/routes/public/test_dag_tags.py    |   3 +-
 .../core_api/routes/public/test_dags.py        |   4 +-
 .../routes/public/test_extra_links.py          |   4 +-
 .../routes/public/test_task_instances.py       |  20 +-
 .../core_api/routes/ui/test_assets.py          |   2 +-
 .../core_api/routes/ui/test_dashboard.py       |   2 +-
 .../core_api/routes/ui/test_structure.py       |   4 +-
 .../remote_commands/test_asset_command.py      |   6 +-
 .../remote_commands/test_backfill_command.py   |   7 +-
 .../remote_commands/test_dag_command.py        |  27 +-
 .../remote_commands/test_task_command.py       |   8 +-
 tests/conftest.py                              |  36 ++
 .../bundles/test_dag_bundle_manager.py         |  16 +-
 tests/dag_processing/test_collection.py        |  51 +-
 tests/dag_processing/test_manager.py           | 456 +++++++++++-------
 tests/jobs/test_scheduler_job.py               | 276 ++++++-----
 tests/models/test_dag.py                       |  68 +--
 tests/models/test_dagcode.py                   |  11 +-
 tests/models/test_dagrun.py                    |   8 +-
 tests/models/test_serialized_dag.py            |  14 +-
 tests/models/test_taskinstance.py              |  12 +-
 tests/operators/test_trigger_dagrun.py         | 128 ++---
 tests/sensors/test_external_task_sensor.py     |   4 +-
 tests/www/views/conftest.py                    |   6 +-
 tests/www/views/test_views_acl.py              |   2 +-
 tests/www/views/test_views_decorators.py       |   8 +-
 tests/www/views/test_views_log.py              |   7 +-
 tests/www/views/test_views_tasks.py            |   6 +-
 tests_common/pytest_plugin.py                  |  13 +
 tests_common/test_utils/db.py                  |  36 +-
 54 files changed, 1043 insertions(+), 682 deletions(-)

diff --git
a/airflow/cli/commands/local_commands/dag_processor_command.py b/airflow/cli/commands/local_commands/dag_processor_command.py index 653c5f6bf577f..af2b65ff49b9f 100644 --- a/airflow/cli/commands/local_commands/dag_processor_command.py +++ b/airflow/cli/commands/local_commands/dag_processor_command.py @@ -39,7 +39,6 @@ def _create_dag_processor_job_runner(args: Any) -> DagProcessorJobRunner: job=Job(), processor=DagFileProcessorManager( processor_timeout=processor_timeout_seconds, - dag_directory=args.subdir, max_runs=args.num_runs, ), ) diff --git a/airflow/dag_processing/bundles/manager.py b/airflow/dag_processing/bundles/manager.py index 0a17ab3f8c5f5..49cccc02849e8 100644 --- a/airflow/dag_processing/bundles/manager.py +++ b/airflow/dag_processing/bundles/manager.py @@ -64,6 +64,23 @@ def parse_config(self) -> None: "Bundle config is not a list. Check config value" " for section `dag_bundles` and key `backends`." ) + + # example dags! + if conf.getboolean("core", "LOAD_EXAMPLES"): + from airflow import example_dags + + example_dag_folder = next(iter(example_dags.__path__)) + backends.append( + { + "name": "example_dags", + "classpath": "airflow.dag_processing.bundles.local.LocalDagBundle", + "kwargs": { + "local_folder": example_dag_folder, + "refresh_interval": conf.getint("scheduler", "dag_dir_list_interval"), + }, + } + ) + seen = set() for cfg in backends: name = cfg["name"] diff --git a/airflow/dag_processing/collection.py b/airflow/dag_processing/collection.py index 65d6dbea77bf4..6e0b627198995 100644 --- a/airflow/dag_processing/collection.py +++ b/airflow/dag_processing/collection.py @@ -74,11 +74,14 @@ log = logging.getLogger(__name__) -def _create_orm_dags(dags: Iterable[MaybeSerializedDAG], *, session: Session) -> Iterator[DagModel]: +def _create_orm_dags( + bundle_name: str, dags: Iterable[MaybeSerializedDAG], *, session: Session +) -> Iterator[DagModel]: for dag in dags: orm_dag = DagModel(dag_id=dag.dag_id) if dag.is_paused_upon_creation is not None: orm_dag.is_paused = dag.is_paused_upon_creation + orm_dag.bundle_name = bundle_name log.info("Creating ORM DAG for %s", dag.dag_id) session.add(orm_dag) yield orm_dag @@ -270,6 +273,8 @@ def _update_import_errors(files_parsed: set[str], import_errors: dict[str, str], def update_dag_parsing_results_in_db( + bundle_name: str, + bundle_version: str | None, dags: Collection[MaybeSerializedDAG], import_errors: dict[str, str], warnings: set[DagWarning], @@ -307,7 +312,7 @@ def update_dag_parsing_results_in_db( ) log.debug("Calling the DAG.bulk_sync_to_db method") try: - DAG.bulk_write_to_db(dags, session=session) + DAG.bulk_write_to_db(bundle_name, bundle_version, dags, session=session) # Write Serialized DAGs to DB, capturing errors # Write Serialized DAGs to DB, capturing errors for dag in dags: @@ -346,6 +351,8 @@ class DagModelOperation(NamedTuple): """Collect DAG objects and perform database operations for them.""" dags: dict[str, MaybeSerializedDAG] + bundle_name: str + bundle_version: str | None def find_orm_dags(self, *, session: Session) -> dict[str, DagModel]: """Find existing DagModel objects from DAG objects.""" @@ -365,7 +372,8 @@ def add_dags(self, *, session: Session) -> dict[str, DagModel]: orm_dags.update( (model.dag_id, model) for model in _create_orm_dags( - (dag for dag_id, dag in self.dags.items() if dag_id not in orm_dags), + bundle_name=self.bundle_name, + dags=(dag for dag_id, dag in self.dags.items() if dag_id not in orm_dags), session=session, ) ) @@ -430,6 +438,8 @@ def update_dags( dm.timetable_summary = 
dag.timetable.summary dm.timetable_description = dag.timetable.description dm.asset_expression = dag.timetable.asset_condition.as_expression() + dm.bundle_name = self.bundle_name + dm.latest_bundle_version = self.bundle_version last_automated_run: DagRun | None = run_info.latest_runs.get(dag.dag_id) if last_automated_run is None: diff --git a/airflow/dag_processing/manager.py b/airflow/dag_processing/manager.py index 99257b991114e..a66788dc8fbe3 100644 --- a/airflow/dag_processing/manager.py +++ b/airflow/dag_processing/manager.py @@ -51,6 +51,7 @@ from airflow.dag_processing.processor import DagFileParsingResult, DagFileProcessorProcess from airflow.models.dag import DagModel from airflow.models.dagbag import DagPriorityParsingRequest +from airflow.models.dagbundle import DagBundleModel from airflow.models.dagwarning import DagWarning from airflow.models.db_callback_request import DbCallbackRequest from airflow.models.errors import ParseImportError @@ -68,7 +69,7 @@ set_new_process_group, ) from airflow.utils.retries import retry_db_transaction -from airflow.utils.session import NEW_SESSION, provide_session +from airflow.utils.session import NEW_SESSION, create_session, provide_session from airflow.utils.sqlalchemy import prohibit_commit, with_row_locks if TYPE_CHECKING: @@ -76,6 +77,8 @@ from sqlalchemy.orm import Session + from airflow.dag_processing.bundles.base import BaseDagBundle + class DagParsingStat(NamedTuple): """Information on processing progress.""" @@ -99,6 +102,13 @@ class DagFileStat: log = logging.getLogger("airflow.processor_manager") +class DagFileInfo(NamedTuple): + """Information about a DAG file.""" + + path: str # absolute path of the file + bundle_name: str + + class DagFileProcessorAgent(LoggingMixin, MultiprocessingStartMethodMixin): """ Agent for DAG file processing. @@ -109,8 +119,6 @@ class DagFileProcessorAgent(LoggingMixin, MultiprocessingStartMethodMixin): This class runs in the main `airflow scheduler` process when standalone_dag_processor is not enabled. - :param dag_directory: Directory where DAG definitions are kept. All - files in file_paths should be under this directory :param max_runs: The number of times to parse and schedule each file. -1 for unlimited. 
:param processor_timeout: How long to wait before timing out a DAG file processor @@ -118,12 +126,10 @@ class DagFileProcessorAgent(LoggingMixin, MultiprocessingStartMethodMixin): def __init__( self, - dag_directory: os.PathLike, max_runs: int, processor_timeout: timedelta, ): super().__init__() - self._dag_directory: os.PathLike = dag_directory self._max_runs = max_runs self._processor_timeout = processor_timeout self._process: multiprocessing.Process | None = None @@ -146,7 +152,6 @@ def start(self) -> None: process = context.Process( target=type(self)._run_processor_manager, args=( - self._dag_directory, self._max_runs, self._processor_timeout, child_signal_conn, @@ -171,7 +176,6 @@ def get_callbacks_pipe(self) -> MultiprocessingConnection: @staticmethod def _run_processor_manager( - dag_directory: os.PathLike, max_runs: int, processor_timeout: timedelta, signal_conn: MultiprocessingConnection, @@ -184,7 +188,6 @@ def _run_processor_manager( setproctitle("airflow scheduler -- DagFileProcessorManager") reload_configuration_for_dag_processing() processor_manager = DagFileProcessorManager( - dag_directory=dag_directory, max_runs=max_runs, processor_timeout=processor_timeout.total_seconds(), signal_conn=signal_conn, @@ -303,15 +306,12 @@ class DagFileProcessorManager: processors finish, more are launched. The files are processed over and over again, but no more often than the specified interval. - :param dag_directory: Directory where DAG definitions are kept. All - files in file_paths should be under this directory :param max_runs: The number of times to parse and schedule each file. -1 for unlimited. :param processor_timeout: How long to wait before timing out a DAG file processor :param signal_conn: connection to communicate signal with processor agent. 
""" - _dag_directory: os.PathLike[str] = attrs.field(validator=_resolve_path) max_runs: int processor_timeout: float = attrs.field(factory=_config_int_factory("core", "dag_file_processor_timeout")) selector: selectors.BaseSelector = attrs.field(factory=selectors.DefaultSelector) @@ -329,7 +329,6 @@ class DagFileProcessorManager: factory=_config_int_factory("scheduler", "min_file_process_interval") ) stale_dag_threshold: float = attrs.field(factory=_config_int_factory("scheduler", "stale_dag_threshold")) - last_dag_dir_refresh_time: float = attrs.field(default=0, init=False) log: logging.Logger = attrs.field(default=log, init=False) @@ -342,11 +341,16 @@ class DagFileProcessorManager: heartbeat: Callable[[], None] = attrs.field(default=lambda: None) """An overridable heartbeat called once every time around the loop""" - _file_paths: list[str] = attrs.field(factory=list, init=False) - _file_path_queue: deque[str] = attrs.field(factory=deque, init=False) - _file_stats: dict[str, DagFileStat] = attrs.field(factory=lambda: defaultdict(DagFileStat), init=False) + _file_paths: list[DagFileInfo] = attrs.field(factory=list, init=False) + _file_path_queue: deque[DagFileInfo] = attrs.field(factory=deque, init=False) + _file_stats: dict[DagFileInfo, DagFileStat] = attrs.field( + factory=lambda: defaultdict(DagFileStat), init=False + ) + + _dag_bundles: list[BaseDagBundle] = attrs.field(factory=list, init=False) + _bundle_versions: dict[str, str] = attrs.field(factory=dict, init=False) - _processors: dict[str, DagFileProcessorProcess] = attrs.field(factory=dict, init=False) + _processors: dict[DagFileInfo, DagFileProcessorProcess] = attrs.field(factory=dict, init=False) _parsing_start_time: float = attrs.field(init=False) _num_run: int = attrs.field(default=0, init=False) @@ -393,14 +397,17 @@ def run(self): self.log.info("Processing files using up to %s processes at a time ", self._parallelism) self.log.info("Process each file at most once every %s seconds", self._file_process_interval) - self.log.info( - "Checking for new files in %s every %s seconds", self._dag_directory, self.dag_dir_list_interval - ) + # TODO: AIP-66 move to report by bundle self.log.info( + # "Checking for new files in %s every %s seconds", self._dag_directory, self.dag_dir_list_interval + # ) from airflow.dag_processing.bundles.manager import DagBundlesManager DagBundlesManager().sync_bundles_to_db() + self.log.info("Getting all DAG bundles") + self._dag_bundles = list(DagBundlesManager().get_all_dag_bundles()) + return self._run_parsing_loop() def _scan_stale_dags(self): @@ -413,7 +420,6 @@ def _scan_stale_dags(self): } self.deactivate_stale_dags( last_parsed=last_parsed, - dag_directory=self.get_dag_directory(), stale_dag_threshold=self.stale_dag_threshold, ) self._last_deactivate_stale_dags_time = time.monotonic() @@ -421,14 +427,16 @@ def _scan_stale_dags(self): @provide_session def deactivate_stale_dags( self, - last_parsed: dict[str, datetime | None], - dag_directory: str, + last_parsed: dict[DagFileInfo, datetime | None], stale_dag_threshold: int, session: Session = NEW_SESSION, ): """Detect and deactivate DAGs which are no longer present in files.""" to_deactivate = set() - query = select(DagModel.dag_id, DagModel.fileloc, DagModel.last_parsed_time).where(DagModel.is_active) + query = select( + DagModel.dag_id, DagModel.bundle_name, DagModel.fileloc, DagModel.last_parsed_time + ).where(DagModel.is_active) + # TODO: AIP-66 by bundle! 
dags_parsed = session.execute(query) for dag in dags_parsed: @@ -436,9 +444,11 @@ def deactivate_stale_dags( # last_parsed_time is the processor_timeout. Longer than that indicates that the DAG is # no longer present in the file. We have a stale_dag_threshold configured to prevent a # significant delay in deactivation of stale dags when a large timeout is configured + dag_file_path = DagFileInfo(path=dag.fileloc, bundle_name=dag.bundle_name) if ( - dag.fileloc in last_parsed - and (dag.last_parsed_time + timedelta(seconds=stale_dag_threshold)) < last_parsed[dag.fileloc] + dag_file_path in last_parsed + and (dag.last_parsed_time + timedelta(seconds=stale_dag_threshold)) + < last_parsed[dag_file_path] ): self.log.info("DAG %s is missing and will be deactivated.", dag.dag_id) to_deactivate.add(dag.dag_id) @@ -471,17 +481,17 @@ def _run_parsing_loop(self): self.heartbeat() - refreshed_dag_dir = self._refresh_dag_dir() - self._kill_timed_out_processors() + self._refresh_dag_bundles() + if not self._file_path_queue: # Generate more file paths to process if we processed all the files already. Note for this to # clear down, we must have cleared all files found from scanning the dags dir _and_ have # cleared all files added as a result of callbacks self.prepare_file_path_queue() self.emit_metrics() - elif refreshed_dag_dir: + else: # if new files found in dag dir, add them self.add_new_file_path_to_queue() @@ -572,6 +582,8 @@ def _read_from_direct_scheduler_conn(self, conn: MultiprocessingConnection) -> b def _refresh_requested_filelocs(self) -> None: """Refresh filepaths from dag dir as requested by users via APIs.""" + return + # TODO: AIP-66 make bundle aware - fileloc will be relative (eventually), thus not unique in order to know what file to repase # Get values from DB table filelocs = self._get_priority_filelocs() for fileloc in filelocs: @@ -609,6 +621,9 @@ def _fetch_callbacks( def _add_callback_to_queue(self, request: CallbackRequest): self.log.debug("Queuing %s CallbackRequest: %s", type(request).__name__, request) + self.log.warning("Callbacks are not implemented yet!") + # TODO: AIP-66 make callbacks bundle aware + return self._callback_to_execute[request.full_filepath].append(request) if request.full_filepath in self._file_path_queue: # Remove file paths matching request.full_filepath from self._file_path_queue @@ -631,25 +646,70 @@ def _get_priority_filelocs(cls, session: Session = NEW_SESSION): session.delete(request) return filelocs - def _refresh_dag_dir(self) -> bool: - """Refresh file paths from dag dir if we haven't done it for too long.""" - now = time.monotonic() - elapsed_time_since_refresh = now - self.last_dag_dir_refresh_time - if elapsed_time_since_refresh <= self.dag_dir_list_interval: - return False + def _refresh_dag_bundles(self): + """Refresh DAG bundles, if required.""" + now = timezone.utcnow() - # Build up a list of Python files that could contain DAGs - self.log.info("Searching for files in %s", self._dag_directory) - self._file_paths = list_py_file_paths(self._dag_directory) - self.last_dag_dir_refresh_time = now - self.log.info("There are %s files in %s", len(self._file_paths), self._dag_directory) - self.set_file_paths(self._file_paths) + self.log.info("Refreshing DAG bundles") + + for bundle in self._dag_bundles: + # TODO: AIP-66 test to make sure we get a fresh record from the db and it's not cached + with create_session() as session: + bundle_model = session.get(DagBundleModel, bundle.name) + elapsed_time_since_refresh = ( + now - 
(bundle_model.last_refreshed or timezone.utc_epoch()) + ).total_seconds() + current_version = bundle.get_current_version() + if ( + not elapsed_time_since_refresh > bundle.refresh_interval + ) and bundle_model.latest_version == current_version: + self.log.info("Not time to refresh %s", bundle.name) + continue - try: - self.log.debug("Removing old import errors") + try: + bundle.refresh() + except Exception: + self.log.exception("Error refreshing bundle %s", bundle.name) + continue + + bundle_model.last_refreshed = now + + new_version = bundle.get_current_version() + if bundle.supports_versioning: + # We can short-circuit the rest of the refresh if the version hasn't changed + # and we've already fully "refreshed" this bundle before in this dag processor. + if current_version == new_version and bundle.name in self._bundle_versions: + self.log.debug("Bundle %s version not changed after refresh", bundle.name) + continue + + bundle_model.latest_version = new_version + + self.log.info("Version changed for %s, new version: %s", bundle.name, new_version) + + bundle_file_paths = self._find_files_in_bundle(bundle) + + new_file_paths = [f for f in self._file_paths if f.bundle_name != bundle.name] + new_file_paths.extend( + DagFileInfo(path=path, bundle_name=bundle.name) for path in bundle_file_paths + ) + self.set_file_paths(new_file_paths) + + self.deactivate_deleted_dags(bundle_file_paths) self.clear_nonexistent_import_errors() - except Exception: - self.log.exception("Error removing old import errors") + + self._bundle_versions[bundle.name] = bundle.get_current_version() + + def _find_files_in_bundle(self, bundle: BaseDagBundle) -> list[str]: + """Refresh file paths from bundle dir.""" + # Build up a list of Python files that could contain DAGs + self.log.info("Searching for files in %s at %s", bundle.name, bundle.path) + file_paths = list_py_file_paths(bundle.path) + self.log.info("Found %s files for bundle %s", len(file_paths), bundle.name) + + return file_paths + + def deactivate_deleted_dags(self, file_paths: set[str]) -> None: + """Deactivate DAGs that come from files that are no longer present.""" def _iter_dag_filelocs(fileloc: str) -> Iterator[str]: """ @@ -668,12 +728,11 @@ def _iter_dag_filelocs(fileloc: str) -> Iterator[str]: except zipfile.BadZipFile: self.log.exception("There was an error accessing ZIP file %s %s", fileloc) - dag_filelocs = {full_loc for path in self._file_paths for full_loc in _iter_dag_filelocs(path)} + dag_filelocs = {full_loc for path in file_paths for full_loc in _iter_dag_filelocs(path)} + # TODO: AIP-66: make bundle aware, as fileloc won't be unique long term. 
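# --- Illustrative sketch, not part of this patch: the two skip conditions used by
# _refresh_dag_bundles above. A bundle is left alone only if its refresh interval has
# not yet elapsed AND the version recorded in the database still matches the bundle's
# current version. All values below are made up.
from datetime import datetime, timedelta, timezone

now = datetime.now(timezone.utc)
last_refreshed = now - timedelta(seconds=30)  # stands in for DagBundleModel.last_refreshed
refresh_interval = 300                        # stands in for bundle.refresh_interval (seconds)
db_version = "abc123"                         # stands in for DagBundleModel.latest_version
current_version = "abc123"                    # stands in for bundle.get_current_version()

elapsed = (now - last_refreshed).total_seconds()
if not elapsed > refresh_interval and db_version == current_version:
    print("Not time to refresh")              # mirrors the log message in the code above
# --- end sketch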
DagModel.deactivate_deleted_dags(dag_filelocs) - return True - def _print_stat(self): """Occasionally print out stats about how fast the files are getting processed.""" if 0 < self.print_stats_interval < time.monotonic() - self.last_stat_print_time: @@ -690,15 +749,18 @@ def clear_nonexistent_import_errors(self, session=NEW_SESSION): :param session: session for ORM operations """ self.log.debug("Removing old import errors") - query = delete(ParseImportError) + try: + query = delete(ParseImportError) - if self._file_paths: - query = query.where( - ParseImportError.filename.notin_(self._file_paths), - ) + if self._file_paths: + query = query.where( + ParseImportError.filename.notin_([f.path for f in self._file_paths]), + ) - session.execute(query.execution_options(synchronize_session="fetch")) - session.commit() + session.execute(query.execution_options(synchronize_session="fetch")) + session.commit() + except Exception: + self.log.exception("Error removing old import errors") def _log_file_processing_stats(self, known_file_paths): """ @@ -736,7 +798,7 @@ def _log_file_processing_stats(self, known_file_paths): proc = self._processors.get(file_path) num_dags = stat.num_dags num_errors = stat.import_errors - file_name = Path(file_path).stem + file_name = Path(file_path.path).stem processor_pid = proc.pid if proc else None processor_start_time = proc.start_time if proc else None runtime = (now - processor_start_time) if processor_start_time else None @@ -793,15 +855,9 @@ def _log_file_processing_stats(self, known_file_paths): self.log.info(log_str) - def get_dag_directory(self) -> str | None: - """Return the dag_directory as a string.""" - if self._dag_directory is not None: - return os.fspath(self._dag_directory) - return None - - def set_file_paths(self, new_file_paths): + def set_file_paths(self, new_file_paths: list[DagFileInfo]): """ - Update this with a new set of paths to DAG definition files. + Update this with a new set of DagFilePaths to DAG definition files. :param new_file_paths: list of paths to DAG definition files :return: None @@ -812,9 +868,10 @@ def set_file_paths(self, new_file_paths): self._file_path_queue = deque(x for x in self._file_path_queue if x in new_file_paths) Stats.gauge("dag_processing.file_path_queue_size", len(self._file_path_queue)) - callback_paths_to_del = [x for x in self._callback_to_execute if x not in new_file_paths] - for path_to_del in callback_paths_to_del: - del self._callback_to_execute[path_to_del] + # TODO: AIP-66 make callbacks bundle aware + # callback_paths_to_del = [x for x in self._callback_to_execute if x not in new_file_paths] + # for path_to_del in callback_paths_to_del: + # del self._callback_to_execute[path_to_del] # Stop processors that are working on deleted files filtered_processors = {} @@ -838,33 +895,35 @@ def set_file_paths(self, new_file_paths): def _collect_results(self, session: Session = NEW_SESSION): # TODO: Use an explicit session in this fn finished = [] - for path, proc in self._processors.items(): + for dag_file, proc in self._processors.items(): if not proc.is_ready: # This processor hasn't finished yet, or we haven't read all the output from it yet continue - finished.append(path) + finished.append(dag_file) # Collect the DAGS and import errors into the DB, emit metrics etc. 
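# --- Illustrative sketch, not part of this patch: the per-bundle swap performed in
# _refresh_dag_bundles before calling set_file_paths above. Only the refreshed bundle's
# entries are replaced; files belonging to other bundles stay untouched. DagFileInfo is
# the NamedTuple added in manager.py; the data here is made up.
from typing import NamedTuple


class DagFileInfo(NamedTuple):
    path: str
    bundle_name: str


current = [
    DagFileInfo(path="/bundles/a/dag1.py", bundle_name="a"),
    DagFileInfo(path="/bundles/b/dag2.py", bundle_name="b"),
]
refreshed = "a"
found = ["/bundles/a/dag1.py", "/bundles/a/dag3.py"]  # dag3.py newly discovered in bundle "a"

new_file_paths = [f for f in current if f.bundle_name != refreshed]
new_file_paths.extend(DagFileInfo(path=p, bundle_name=refreshed) for p in found)
assert {f.bundle_name for f in new_file_paths} == {"a", "b"}
# --- end sketch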
- self._file_stats[path] = process_parse_results( + self._file_stats[dag_file] = process_parse_results( run_duration=time.time() - proc.start_time, finish_time=timezone.utcnow(), - run_count=self._file_stats[path].run_count, + run_count=self._file_stats[dag_file].run_count, + bundle_name=dag_file.bundle_name, + bundle_version=self._bundle_versions[dag_file.bundle_name], parsing_result=proc.parsing_result, - path=path, session=session, ) - for path in finished: - self._processors.pop(path) + for dag_file in finished: + self._processors.pop(dag_file) - def _create_process(self, file_path): + def _create_process(self, dag_file: DagFileInfo) -> DagFileProcessorProcess: id = uuid7() - callback_to_execute_for_file = self._callback_to_execute.pop(file_path, []) + # callback_to_execute_for_file = self._callback_to_execute.pop(file_path, []) + callback_to_execute_for_file: list[CallbackRequest] = [] return DagFileProcessorProcess.start( id=id, - path=file_path, + path=dag_file.path, callbacks=callback_to_execute_for_file, selector=self.selector, ) @@ -914,7 +973,7 @@ def prepare_file_path_queue(self): for file_path in self._file_paths: if is_mtime_mode: try: - files_with_mtime[file_path] = os.path.getmtime(file_path) + files_with_mtime[file_path] = os.path.getmtime(file_path.path) except FileNotFoundError: self.log.warning("Skipping processing of missing file: %s", file_path) self._file_stats.pop(file_path, None) @@ -973,7 +1032,8 @@ def prepare_file_path_queue(self): ) self.log.debug( - "Queuing the following files for processing:\n\t%s", "\n\t".join(files_paths_to_queue) + "Queuing the following files for processing:\n\t%s", + "\n\t".join(f.path for f in files_paths_to_queue), ) self._add_paths_to_queue(files_paths_to_queue, False) Stats.incr("dag_processing.file_path_queue_update_count") @@ -1011,7 +1071,7 @@ def _kill_timed_out_processors(self): for proc in processors_to_remove: self._processors.pop(proc) - def _add_paths_to_queue(self, file_paths_to_enqueue: list[str], add_at_front: bool): + def _add_paths_to_queue(self, file_paths_to_enqueue: list[DagFileInfo], add_at_front: bool): """Add stuff to the back or front of the file queue, unless it's already present.""" new_file_paths = list(p for p in file_paths_to_enqueue if p not in self._file_path_queue) if add_at_front: @@ -1091,7 +1151,8 @@ def process_parse_results( run_duration: float, finish_time: datetime, run_count: int, - path: str, + bundle_name: str, + bundle_version: str | None, parsing_result: DagFileParsingResult | None, session: Session, ) -> DagFileStat: @@ -1102,15 +1163,18 @@ def process_parse_results( run_count=run_count + 1, ) - file_name = Path(path).stem - Stats.timing(f"dag_processing.last_duration.{file_name}", stat.last_duration) - Stats.timing("dag_processing.last_duration", stat.last_duration, tags={"file_name": file_name}) + # TODO: AIP-66 emit metrics + # file_name = Path(dag_file.path).stem + # Stats.timing(f"dag_processing.last_duration.{file_name}", stat.last_duration) + # Stats.timing("dag_processing.last_duration", stat.last_duration, tags={"file_name": file_name}) if parsing_result is None: stat.import_errors = 1 else: # record DAGs and import errors to database update_dag_parsing_results_in_db( + bundle_name=bundle_name, + bundle_version=bundle_version, dags=parsing_result.serialized_dags, import_errors=parsing_result.import_errors or {}, warnings=set(parsing_result.warnings or []), diff --git a/airflow/jobs/scheduler_job_runner.py b/airflow/jobs/scheduler_job_runner.py index 17643b0212195..4718b824830b7 
100644 --- a/airflow/jobs/scheduler_job_runner.py +++ b/airflow/jobs/scheduler_job_runner.py @@ -30,7 +30,6 @@ from datetime import timedelta from functools import lru_cache, partial from itertools import groupby -from pathlib import Path from typing import TYPE_CHECKING, Any, Callable from deprecated import deprecated @@ -922,7 +921,6 @@ def _execute(self) -> int | None: processor_timeout = timedelta(seconds=processor_timeout_seconds) if not self._standalone_dag_processor and not self.processor_agent: self.processor_agent = DagFileProcessorAgent( - dag_directory=Path(self.subdir), max_runs=self.num_times_parse_dags, processor_timeout=processor_timeout, ) diff --git a/airflow/models/dag.py b/airflow/models/dag.py index 4abf24af52537..82b4ca70819b9 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -770,6 +770,16 @@ def get_is_paused(self, session=NEW_SESSION) -> None: """Return a boolean indicating whether this DAG is paused.""" return session.scalar(select(DagModel.is_paused).where(DagModel.dag_id == self.dag_id)) + @provide_session + def get_bundle_name(self, session=NEW_SESSION) -> None: + """Return the bundle name this DAG is in.""" + return session.scalar(select(DagModel.bundle_name).where(DagModel.dag_id == self.dag_id)) + + @provide_session + def get_latest_bundle_version(self, session=NEW_SESSION) -> None: + """Return the bundle name this DAG is in.""" + return session.scalar(select(DagModel.latest_bundle_version).where(DagModel.dag_id == self.dag_id)) + @methodtools.lru_cache(maxsize=None) @classmethod def get_serialized_fields(cls): @@ -1832,6 +1842,8 @@ def create_dagrun( @provide_session def bulk_write_to_db( cls, + bundle_name: str, + bundle_version: str | None, dags: Collection[MaybeSerializedDAG], session: Session = NEW_SESSION, ): @@ -1847,7 +1859,9 @@ def bulk_write_to_db( from airflow.dag_processing.collection import AssetModelOperation, DagModelOperation log.info("Sync %s DAGs", len(dags)) - dag_op = DagModelOperation({dag.dag_id: dag for dag in dags}) # type: ignore[misc] + dag_op = DagModelOperation( + bundle_name=bundle_name, bundle_version=bundle_version, dags={d.dag_id: d for d in dags} + ) # type: ignore[misc] orm_dags = dag_op.add_dags(session=session) dag_op.update_dags(orm_dags, session=session) @@ -1873,7 +1887,10 @@ def sync_to_db(self, session=NEW_SESSION): :return: None """ - self.bulk_write_to_db([self], session=session) + # TODO: AIP-66 should this be in the model? 
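# --- Illustrative sketch, not part of this patch: the new call shape for
# DAG.bulk_write_to_db after this change, with bundle name and version leading the
# argument list. Running it for real requires an initialised Airflow metadata database,
# so treat this purely as a signature illustration; "dags-folder" is the bundle name
# the tests in this PR use.
from airflow.models.dag import DAG

dag = DAG(dag_id="example", schedule=None)
DAG.bulk_write_to_db(
    bundle_name="dags-folder",  # which bundle these DAGs belong to
    bundle_version=None,        # None for unversioned bundles such as a local folder
    dags=[dag],
)
# --- end sketch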
+ bundle_name = self.get_bundle_name(session=session) + bundle_version = self.get_latest_bundle_version(session=session) + self.bulk_write_to_db(bundle_name, bundle_version, [self], session=session) def get_default_view(self): """Allow backward compatible jinja2 templates.""" diff --git a/airflow/models/dagbag.py b/airflow/models/dagbag.py index 7d0d2efc1bf6e..26f58b26929ae 100644 --- a/airflow/models/dagbag.py +++ b/airflow/models/dagbag.py @@ -568,11 +568,17 @@ def collect_dags( # Ensure dag_folder is a str -- it may have been a pathlib.Path dag_folder = correct_maybe_zipped(str(dag_folder)) - for filepath in list_py_file_paths( - dag_folder, - safe_mode=safe_mode, - include_examples=include_examples, - ): + + files_to_parse = list_py_file_paths(dag_folder, safe_mode=safe_mode) + + if include_examples: + from airflow import example_dags + + example_dag_folder = next(iter(example_dags.__path__)) + + files_to_parse.extend(list_py_file_paths(example_dag_folder, safe_mode=safe_mode)) + + for filepath in files_to_parse: try: file_parse_start_dttm = timezone.utcnow() found_dags = self.process_file(filepath, only_if_updated=only_if_updated, safe_mode=safe_mode) @@ -626,11 +632,13 @@ def dagbag_report(self): return report @provide_session - def sync_to_db(self, session: Session = NEW_SESSION): + def sync_to_db(self, bundle_name: str, bundle_version: str | None, session: Session = NEW_SESSION): """Save attributes about list of DAG to the DB.""" from airflow.dag_processing.collection import update_dag_parsing_results_in_db update_dag_parsing_results_in_db( + bundle_name, + bundle_version, self.dags.values(), # type: ignore[arg-type] # We should create a proto for DAG|LazySerializedDAG self.import_errors, self.dag_warnings, diff --git a/airflow/utils/file.py b/airflow/utils/file.py index 5c3e454e294af..7e72b26792e9e 100644 --- a/airflow/utils/file.py +++ b/airflow/utils/file.py @@ -245,7 +245,6 @@ def find_path_from_directory( def list_py_file_paths( directory: str | os.PathLike[str] | None, safe_mode: bool = conf.getboolean("core", "DAG_DISCOVERY_SAFE_MODE", fallback=True), - include_examples: bool | None = None, ) -> list[str]: """ Traverse a directory and look for Python files. @@ -255,11 +254,8 @@ def list_py_file_paths( contains Airflow DAG definitions. If not provided, use the core.DAG_DISCOVERY_SAFE_MODE configuration setting. If not set, default to safe. 
- :param include_examples: include example DAGs :return: a list of paths to Python files in the specified directory """ - if include_examples is None: - include_examples = conf.getboolean("core", "LOAD_EXAMPLES") file_paths: list[str] = [] if directory is None: file_paths = [] @@ -267,11 +263,6 @@ def list_py_file_paths( file_paths = [str(directory)] elif os.path.isdir(directory): file_paths.extend(find_dag_file_paths(directory, safe_mode)) - if include_examples: - from airflow import example_dags - - example_dag_folder = next(iter(example_dags.__path__)) - file_paths.extend(list_py_file_paths(example_dag_folder, safe_mode, include_examples=False)) return file_paths diff --git a/providers/tests/fab/auth_manager/api_endpoints/test_dag_run_endpoint.py b/providers/tests/fab/auth_manager/api_endpoints/test_dag_run_endpoint.py index 33df2e9bbc407..e745d3d655bdc 100644 --- a/providers/tests/fab/auth_manager/api_endpoints/test_dag_run_endpoint.py +++ b/providers/tests/fab/auth_manager/api_endpoints/test_dag_run_endpoint.py @@ -20,7 +20,7 @@ import pytest -from airflow.models.dag import DAG, DagModel +from airflow.models.dag import DagModel from airflow.models.dagrun import DagRun from airflow.models.param import Param from airflow.providers.fab.www.security import permissions @@ -125,29 +125,26 @@ def setup_attrs(self, configured_app) -> None: clear_db_serialized_dags() clear_db_dags() + @pytest.fixture(autouse=True) + def create_dag(self, dag_maker, setup_attrs): + with dag_maker( + "TEST_DAG_ID", schedule=None, params={"validated_number": Param(1, minimum=1, maximum=10)} + ): + pass + + dag_maker.sync_dagbag_to_db() + def teardown_method(self) -> None: clear_db_runs() clear_db_dags() clear_db_serialized_dags() - def _create_dag(self, dag_id): - dag_instance = DagModel(dag_id=dag_id) - dag_instance.is_active = True - with create_session() as session: - session.add(dag_instance) - dag = DAG(dag_id=dag_id, schedule=None, params={"validated_number": Param(1, minimum=1, maximum=10)}) - self.app.dag_bag.bag_dag(dag) - self.app.dag_bag.sync_to_db() - return dag_instance - def _create_test_dag_run(self, state=DagRunState.RUNNING, extra_dag=False, commit=True, idx_start=1): dag_runs = [] dags = [] triggered_by_kwargs = {"triggered_by": DagRunTriggeredByType.TEST} if AIRFLOW_V_3_0_PLUS else {} for i in range(idx_start, idx_start + 2): - if i == 1: - dags.append(DagModel(dag_id="TEST_DAG_ID", is_active=True)) dagrun_model = DagRun( dag_id="TEST_DAG_ID", run_id=f"TEST_DAG_RUN_ID_{i}", @@ -247,7 +244,6 @@ def test_should_return_accessible_with_tilde_as_dag_id_and_dag_level_permissions class TestPostDagRun(TestDagRunEndpoint): def test_dagrun_trigger_with_dag_level_permissions(self): - self._create_dag("TEST_DAG_ID") response = self.client.post( "api/v1/dags/TEST_DAG_ID/dagRuns", json={"conf": {"validated_number": 1}}, @@ -260,7 +256,6 @@ def test_dagrun_trigger_with_dag_level_permissions(self): ["test_dag_view_only", "test_view_dags", "test_granular_permissions"], ) def test_should_raises_403_unauthorized(self, username): - self._create_dag("TEST_DAG_ID") response = self.client.post( "api/v1/dags/TEST_DAG_ID/dagRuns", json={ diff --git a/providers/tests/fab/auth_manager/api_endpoints/test_dag_source_endpoint.py b/providers/tests/fab/auth_manager/api_endpoints/test_dag_source_endpoint.py index 8461e5bf0f170..66cd6477c9e9b 100644 --- a/providers/tests/fab/auth_manager/api_endpoints/test_dag_source_endpoint.py +++ b/providers/tests/fab/auth_manager/api_endpoints/test_dag_source_endpoint.py @@ -18,7 +18,6 @@ 
import ast import os -from typing import TYPE_CHECKING import pytest @@ -26,7 +25,12 @@ from airflow.providers.fab.www.security import permissions from providers.tests.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user -from tests_common.test_utils.db import clear_db_dag_code, clear_db_dags, clear_db_serialized_dags +from tests_common.test_utils.db import ( + clear_db_dag_code, + clear_db_dags, + clear_db_serialized_dags, + parse_and_sync_to_db, +) from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS pytestmark = [ @@ -34,11 +38,7 @@ pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3.0+"), ] -if TYPE_CHECKING: - from airflow.models.dag import DAG -ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir, os.pardir)) -EXAMPLE_DAG_FILE = os.path.join("airflow", "example_dags", "example_bash_operator.py") EXAMPLE_DAG_ID = "example_bash_operator" TEST_DAG_ID = "latest_only" NOT_READABLE_DAG_ID = "latest_only_with_trigger" @@ -97,9 +97,9 @@ def _get_dag_file_docstring(fileloc: str) -> str | None: return docstring def test_should_respond_403_not_readable(self, url_safe_serializer): - dagbag = DagBag(dag_folder=EXAMPLE_DAG_FILE) - dagbag.sync_to_db() - dag: DAG = dagbag.dags[NOT_READABLE_DAG_ID] + parse_and_sync_to_db(os.devnull, include_examples=True) + dagbag = DagBag(read_dags_from_db=True) + dag = dagbag.get_dag(NOT_READABLE_DAG_ID) response = self.client.get( f"/api/v1/dagSources/{dag.dag_id}", @@ -114,9 +114,9 @@ def test_should_respond_403_not_readable(self, url_safe_serializer): assert read_dag.status_code == 403 def test_should_respond_403_some_dags_not_readable_in_the_file(self, url_safe_serializer): - dagbag = DagBag(dag_folder=EXAMPLE_DAG_FILE) - dagbag.sync_to_db() - dag: DAG = dagbag.dags[TEST_MULTIPLE_DAGS_ID] + parse_and_sync_to_db(os.devnull, include_examples=True) + dagbag = DagBag(read_dags_from_db=True) + dag = dagbag.get_dag(TEST_MULTIPLE_DAGS_ID) response = self.client.get( f"/api/v1/dagSources/{dag.dag_id}", diff --git a/providers/tests/fab/auth_manager/conftest.py b/providers/tests/fab/auth_manager/conftest.py index d400a7b86a027..f26a08d19c3e3 100644 --- a/providers/tests/fab/auth_manager/conftest.py +++ b/providers/tests/fab/auth_manager/conftest.py @@ -16,11 +16,14 @@ # under the License. 
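# --- Illustrative sketch, not part of this patch: the test pattern these changes
# standardise on - parse DAG files with the parse_and_sync_to_db helper (added to
# tests_common in this PR), then read the serialized DAGs back from the database.
# A configured test database is assumed.
import os

from airflow.models import DagBag
from tests_common.test_utils.db import parse_and_sync_to_db

parse_and_sync_to_db(os.devnull, include_examples=True)  # parse the example DAGs and sync them
dagbag = DagBag(read_dags_from_db=True)                  # then work with the serialized copies
dag = dagbag.get_dag("example_bash_operator")
# --- end sketch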
from __future__ import annotations +import os + import pytest from airflow.www import app from tests_common.test_utils.config import conf_vars +from tests_common.test_utils.db import parse_and_sync_to_db from tests_common.test_utils.decorators import dont_initialize_flask_app_submodules @@ -72,5 +75,5 @@ def set_auth_role_public(request): def dagbag(): from airflow.models import DagBag - DagBag(include_examples=True, read_dags_from_db=False).sync_to_db() - return DagBag(include_examples=True, read_dags_from_db=True) + parse_and_sync_to_db(os.devnull, include_examples=True) + return DagBag(read_dags_from_db=True) diff --git a/providers/tests/openlineage/plugins/test_utils.py b/providers/tests/openlineage/plugins/test_utils.py index e1e3e8732690a..490061db4f594 100644 --- a/providers/tests/openlineage/plugins/test_utils.py +++ b/providers/tests/openlineage/plugins/test_utils.py @@ -28,7 +28,7 @@ from openlineage.client.utils import RedactMixin from pkg_resources import parse_version -from airflow.models import DAG as AIRFLOW_DAG, DagModel +from airflow.models import DagModel from airflow.providers.common.compat.assets import Asset from airflow.providers.openlineage.plugins.facets import AirflowDebugRunFacet from airflow.providers.openlineage.utils.utils import ( @@ -91,12 +91,14 @@ def test_get_airflow_debug_facet_logging_set_to_debug(mock_debug_mode, mock_get_ @pytest.mark.db_test -def test_get_dagrun_start_end(): +def test_get_dagrun_start_end(dag_maker): start_date = datetime.datetime(2022, 1, 1) end_date = datetime.datetime(2022, 1, 1, hour=2) - dag = AIRFLOW_DAG("test", start_date=start_date, end_date=end_date, schedule="@once") - AIRFLOW_DAG.bulk_write_to_db([dag]) + with dag_maker("test", start_date=start_date, end_date=end_date, schedule="@once") as dag: + pass + dag_maker.sync_dagbag_to_db() dag_model = DagModel.get_dagmodel(dag.dag_id) + run_id = str(uuid.uuid1()) triggered_by_kwargs = {"triggered_by": DagRunTriggeredByType.TEST} if AIRFLOW_V_3_0_PLUS else {} dagrun = dag.create_dagrun( diff --git a/tests/api_connexion/conftest.py b/tests/api_connexion/conftest.py index 39e8da7212ae7..fa5d4c9250065 100644 --- a/tests/api_connexion/conftest.py +++ b/tests/api_connexion/conftest.py @@ -16,11 +16,14 @@ # under the License. 
from __future__ import annotations +import os + import pytest from airflow.www import app from tests_common.test_utils.config import conf_vars +from tests_common.test_utils.db import parse_and_sync_to_db from tests_common.test_utils.decorators import dont_initialize_flask_app_submodules @@ -64,5 +67,5 @@ def session(): def dagbag(): from airflow.models import DagBag - DagBag(include_examples=True, read_dags_from_db=False).sync_to_db() - return DagBag(include_examples=True, read_dags_from_db=True) + parse_and_sync_to_db(os.devnull, include_examples=True) + return DagBag(read_dags_from_db=True) diff --git a/tests/api_connexion/endpoints/test_dag_parsing.py b/tests/api_connexion/endpoints/test_dag_parsing.py index 1df80a905d92e..0052732b7dfb9 100644 --- a/tests/api_connexion/endpoints/test_dag_parsing.py +++ b/tests/api_connexion/endpoints/test_dag_parsing.py @@ -17,7 +17,6 @@ from __future__ import annotations import os -from typing import TYPE_CHECKING import pytest from sqlalchemy import select @@ -26,17 +25,14 @@ from airflow.models.dagbag import DagPriorityParsingRequest from tests_common.test_utils.api_connexion_utils import create_user, delete_user -from tests_common.test_utils.db import clear_db_dag_parsing_requests +from tests_common.test_utils.db import clear_db_dag_parsing_requests, parse_and_sync_to_db pytestmark = pytest.mark.db_test -if TYPE_CHECKING: - from airflow.models.dag import DAG ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir, os.pardir)) EXAMPLE_DAG_FILE = os.path.join("airflow", "example_dags", "example_bash_operator.py") -EXAMPLE_DAG_ID = "example_bash_operator" -TEST_DAG_ID = "latest_only" +TEST_DAG_ID = "example_bash_operator" NOT_READABLE_DAG_ID = "latest_only_with_trigger" TEST_MULTIPLE_DAGS_ID = "asset_produces_1" @@ -72,9 +68,9 @@ def clear_db(): clear_db_dag_parsing_requests() def test_201_and_400_requests(self, url_safe_serializer, session): - dagbag = DagBag(dag_folder=EXAMPLE_DAG_FILE) - dagbag.sync_to_db() - test_dag: DAG = dagbag.dags[TEST_DAG_ID] + parse_and_sync_to_db(EXAMPLE_DAG_FILE) + dagbag = DagBag(read_dags_from_db=True) + test_dag = dagbag.get_dag(TEST_DAG_ID) url = f"/api/v1/parseDagFile/{url_safe_serializer.dumps(test_dag.fileloc)}" response = self.client.put( diff --git a/tests/api_connexion/endpoints/test_dag_run_endpoint.py b/tests/api_connexion/endpoints/test_dag_run_endpoint.py index 4f322bb8f0a6c..b6609c162390d 100644 --- a/tests/api_connexion/endpoints/test_dag_run_endpoint.py +++ b/tests/api_connexion/endpoints/test_dag_run_endpoint.py @@ -24,6 +24,7 @@ import time_machine from airflow.api_connexion.exceptions import EXCEPTIONS_LINK_MAP +from airflow.dag_processing.bundles.manager import DagBundlesManager from airflow.models.asset import AssetEvent, AssetModel from airflow.models.dag import DAG, DagModel from airflow.models.dagrun import DagRun @@ -88,8 +89,9 @@ def _create_dag(self, dag_id): with create_session() as session: session.add(dag_instance) dag = DAG(dag_id=dag_id, schedule=None, params={"validated_number": Param(1, minimum=1, maximum=10)}) + DagBundlesManager().sync_bundles_to_db() self.app.dag_bag.bag_dag(dag) - self.app.dag_bag.sync_to_db() + self.app.dag_bag.sync_to_db("dags-folder", None) return dag_instance def _create_test_dag_run(self, state=DagRunState.RUNNING, extra_dag=False, commit=True, idx_start=1): diff --git a/tests/api_connexion/endpoints/test_dag_source_endpoint.py b/tests/api_connexion/endpoints/test_dag_source_endpoint.py index a907d2704c6e7..c7e3b12c7e6fd 100644 --- 
a/tests/api_connexion/endpoints/test_dag_source_endpoint.py +++ b/tests/api_connexion/endpoints/test_dag_source_endpoint.py @@ -16,6 +16,8 @@ # under the License. from __future__ import annotations +import os + import pytest from sqlalchemy import select @@ -24,13 +26,12 @@ from airflow.models.serialized_dag import SerializedDagModel from tests_common.test_utils.api_connexion_utils import assert_401, create_user, delete_user -from tests_common.test_utils.db import clear_db_dags +from tests_common.test_utils.db import clear_db_dags, parse_and_sync_to_db pytestmark = pytest.mark.db_test -EXAMPLE_DAG_ID = "example_bash_operator" -TEST_DAG_ID = "latest_only" +TEST_DAG_ID = "example_bash_operator" @pytest.fixture(scope="module") @@ -51,9 +52,9 @@ def configured_app(minimal_app_for_api): @pytest.fixture def test_dag(): - dagbag = DagBag(include_examples=True) - dagbag.sync_to_db() - return dagbag.dags[TEST_DAG_ID] + parse_and_sync_to_db(os.devnull, include_examples=True) + dagbag = DagBag(read_dags_from_db=True) + return dagbag.get_dag(TEST_DAG_ID) class TestGetSource: diff --git a/tests/api_connexion/endpoints/test_extra_link_endpoint.py b/tests/api_connexion/endpoints/test_extra_link_endpoint.py index 8cb8dd4e030c1..e4aa7895ac662 100644 --- a/tests/api_connexion/endpoints/test_extra_link_endpoint.py +++ b/tests/api_connexion/endpoints/test_extra_link_endpoint.py @@ -22,6 +22,7 @@ import pytest from airflow.api_connexion.exceptions import EXCEPTIONS_LINK_MAP +from airflow.dag_processing.bundles.manager import DagBundlesManager from airflow.models.dag import DAG from airflow.models.dagbag import DagBag from airflow.models.xcom import XCom @@ -73,9 +74,10 @@ def setup_attrs(self, configured_app, session) -> None: self.dag = self._create_dag() + DagBundlesManager().sync_bundles_to_db() self.app.dag_bag = DagBag(os.devnull, include_examples=False) self.app.dag_bag.dags = {self.dag.dag_id: self.dag} - self.app.dag_bag.sync_to_db() + self.app.dag_bag.sync_to_db("dags-folder", None) triggered_by_kwargs = {"triggered_by": DagRunTriggeredByType.TEST} if AIRFLOW_V_3_0_PLUS else {} self.dag.create_dagrun( diff --git a/tests/api_connexion/endpoints/test_mapped_task_instance_endpoint.py b/tests/api_connexion/endpoints/test_mapped_task_instance_endpoint.py index a4a14fa3b421d..7a4d802f6cc15 100644 --- a/tests/api_connexion/endpoints/test_mapped_task_instance_endpoint.py +++ b/tests/api_connexion/endpoints/test_mapped_task_instance_endpoint.py @@ -24,6 +24,7 @@ import pytest from airflow.api_connexion.exceptions import EXCEPTIONS_LINK_MAP +from airflow.dag_processing.bundles.manager import DagBundlesManager from airflow.models import TaskInstance from airflow.models.baseoperator import BaseOperator from airflow.models.dagbag import DagBag @@ -122,9 +123,10 @@ def create_dag_runs_with_mapped_tasks(self, dag_maker, session, dags=None): setattr(ti, "start_date", DEFAULT_DATETIME_1) session.add(ti) + DagBundlesManager().sync_bundles_to_db() self.app.dag_bag = DagBag(os.devnull, include_examples=False) self.app.dag_bag.dags = {dag_id: dag_maker.dag} - self.app.dag_bag.sync_to_db() + self.app.dag_bag.sync_to_db("dags-folder", None) session.flush() mapped.expand_mapped_task(dr.run_id, session=session) diff --git a/tests/api_connexion/endpoints/test_task_endpoint.py b/tests/api_connexion/endpoints/test_task_endpoint.py index d1662943604ac..7e593731d539d 100644 --- a/tests/api_connexion/endpoints/test_task_endpoint.py +++ b/tests/api_connexion/endpoints/test_task_endpoint.py @@ -22,6 +22,7 @@ import pytest +from 
airflow.dag_processing.bundles.manager import DagBundlesManager from airflow.models import DagBag from airflow.models.dag import DAG from airflow.models.expandinput import EXPAND_INPUT_EMPTY @@ -80,13 +81,15 @@ def setup_dag(self, configured_app): task5 = EmptyOperator(task_id=self.unscheduled_task_id2, params={"is_unscheduled": True}) task1 >> task2 task4 >> task5 + + DagBundlesManager().sync_bundles_to_db() dag_bag = DagBag(os.devnull, include_examples=False) dag_bag.dags = { dag.dag_id: dag, mapped_dag.dag_id: mapped_dag, unscheduled_dag.dag_id: unscheduled_dag, } - dag_bag.sync_to_db() + dag_bag.sync_to_db("dags-folder", None) configured_app.dag_bag = dag_bag # type:ignore @staticmethod diff --git a/tests/api_connexion/endpoints/test_task_instance_endpoint.py b/tests/api_connexion/endpoints/test_task_instance_endpoint.py index a4089c9785a98..b5079c47aa17e 100644 --- a/tests/api_connexion/endpoints/test_task_instance_endpoint.py +++ b/tests/api_connexion/endpoints/test_task_instance_endpoint.py @@ -1224,7 +1224,7 @@ def test_should_respond_200(self, main_dag, task_instances, request_dag, payload task_instances=task_instances, update_extras=False, ) - self.app.dag_bag.sync_to_db() + self.app.dag_bag.sync_to_db("dags-folder", None) response = self.client.post( f"/api/v1/dags/{request_dag}/clearTaskInstances", environ_overrides={"REMOTE_USER": "test"}, @@ -1246,7 +1246,7 @@ def test_clear_taskinstance_is_called_with_queued_dr_state(self, mock_clearti, s self.create_task_instances(session) dag_id = "example_python_operator" payload = {"reset_dag_runs": True, "dry_run": False} - self.app.dag_bag.sync_to_db() + self.app.dag_bag.sync_to_db("dags-folder", None) response = self.client.post( f"/api/v1/dags/{dag_id}/clearTaskInstances", environ_overrides={"REMOTE_USER": "test"}, @@ -1266,7 +1266,7 @@ def test_clear_taskinstance_is_called_with_invalid_task_ids(self, session): assert dagrun.state == "running" payload = {"dry_run": False, "reset_dag_runs": True, "task_ids": [""]} - self.app.dag_bag.sync_to_db() + self.app.dag_bag.sync_to_db("dags-folder", None) response = self.client.post( f"/api/v1/dags/{dag_id}/clearTaskInstances", environ_overrides={"REMOTE_USER": "test"}, @@ -1693,7 +1693,7 @@ def test_should_raise_400_for_naive_and_bad_datetime(self, payload, expected, se task_instances=task_instances, update_extras=False, ) - self.app.dag_bag.sync_to_db() + self.app.dag_bag.sync_to_db("dags-folder", None) response = self.client.post( "/api/v1/dags/example_python_operator/clearTaskInstances", environ_overrides={"REMOTE_USER": "test"}, diff --git a/tests/api_fastapi/conftest.py b/tests/api_fastapi/conftest.py index 2928a4d829c70..71f42756dd564 100644 --- a/tests/api_fastapi/conftest.py +++ b/tests/api_fastapi/conftest.py @@ -16,11 +16,15 @@ # under the License. 
from __future__ import annotations +import os + import pytest from fastapi.testclient import TestClient from airflow.api_fastapi.app import create_app +from tests_common.test_utils.db import parse_and_sync_to_db + @pytest.fixture def test_client(): @@ -42,6 +46,5 @@ def create_test_client(apps="all"): def dagbag(): from airflow.models import DagBag - dagbag_instance = DagBag(include_examples=True, read_dags_from_db=False) - dagbag_instance.sync_to_db() - return dagbag_instance + parse_and_sync_to_db(os.devnull, include_examples=True) + return DagBag(read_dags_from_db=True) diff --git a/tests/api_fastapi/core_api/routes/public/test_dag_parsing.py b/tests/api_fastapi/core_api/routes/public/test_dag_parsing.py index b937f66803f31..fe78c193d389b 100644 --- a/tests/api_fastapi/core_api/routes/public/test_dag_parsing.py +++ b/tests/api_fastapi/core_api/routes/public/test_dag_parsing.py @@ -17,7 +17,6 @@ from __future__ import annotations import os -from typing import TYPE_CHECKING import pytest from sqlalchemy import select @@ -26,19 +25,15 @@ from airflow.models.dagbag import DagPriorityParsingRequest from airflow.utils.session import provide_session -from tests_common.test_utils.db import clear_db_dag_parsing_requests +from tests_common.test_utils.db import clear_db_dag_parsing_requests, parse_and_sync_to_db pytestmark = pytest.mark.db_test -if TYPE_CHECKING: - from airflow.models.dag import DAG - class TestDagParsingEndpoint: ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir, os.pardir)) EXAMPLE_DAG_FILE = os.path.join("airflow", "example_dags", "example_bash_operator.py") - EXAMPLE_DAG_ID = "example_bash_operator" - TEST_DAG_ID = "latest_only" + TEST_DAG_ID = "example_bash_operator" NOT_READABLE_DAG_ID = "latest_only_with_trigger" TEST_MULTIPLE_DAGS_ID = "asset_produces_1" @@ -55,9 +50,9 @@ def teardown_method(self) -> None: self.clear_db() def test_201_and_400_requests(self, url_safe_serializer, session, test_client): - dagbag = DagBag(dag_folder=self.EXAMPLE_DAG_FILE) - dagbag.sync_to_db() - test_dag: DAG = dagbag.dags[self.TEST_DAG_ID] + parse_and_sync_to_db(self.EXAMPLE_DAG_FILE) + dagbag = DagBag(read_dags_from_db=True) + test_dag = dagbag.get_dag(self.TEST_DAG_ID) url = f"/public/parseDagFile/{url_safe_serializer.dumps(test_dag.fileloc)}" response = test_client.put(url, headers={"Accept": "application/json"}) diff --git a/tests/api_fastapi/core_api/routes/public/test_dag_run.py b/tests/api_fastapi/core_api/routes/public/test_dag_run.py index 75fec91e1a0c7..9b7c9211fddef 100644 --- a/tests/api_fastapi/core_api/routes/public/test_dag_run.py +++ b/tests/api_fastapi/core_api/routes/public/test_dag_run.py @@ -141,7 +141,7 @@ def setup(request, dag_maker, session=None): logical_date=LOGICAL_DATE4, ) - dag_maker.dagbag.sync_to_db() + dag_maker.sync_dagbag_to_db() dag_maker.dag_model dag_maker.dag_model.has_task_concurrency_limits = True session.merge(ti1) diff --git a/tests/api_fastapi/core_api/routes/public/test_dag_sources.py b/tests/api_fastapi/core_api/routes/public/test_dag_sources.py index 4e8a5990ecb26..dce7981872085 100644 --- a/tests/api_fastapi/core_api/routes/public/test_dag_sources.py +++ b/tests/api_fastapi/core_api/routes/public/test_dag_sources.py @@ -28,7 +28,7 @@ from airflow.models.dagcode import DagCode from airflow.models.serialized_dag import SerializedDagModel -from tests_common.test_utils.db import clear_db_dags +from tests_common.test_utils.db import clear_db_dags, parse_and_sync_to_db pytestmark = pytest.mark.db_test @@ -36,14 +36,13 @@ # 
Example bash operator located here: airflow/example_dags/example_bash_operator.py EXAMPLE_DAG_FILE = os.path.join("airflow", "example_dags", "example_bash_operator.py") -TEST_DAG_ID = "latest_only" +TEST_DAG_ID = "example_bash_operator" @pytest.fixture def test_dag(): - dagbag = DagBag(include_examples=True) - dagbag.sync_to_db() - return dagbag.dags[TEST_DAG_ID] + parse_and_sync_to_db(EXAMPLE_DAG_FILE, include_examples=False) + return DagBag(read_dags_from_db=True).get_dag(TEST_DAG_ID) class TestGetDAGSource: @@ -131,9 +130,7 @@ def test_should_respond_200_version(self, test_client, accept, session, test_dag "version_number": 2, } - def test_should_respond_406_unsupport_mime_type(self, test_client): - dagbag = DagBag(include_examples=True) - dagbag.sync_to_db() + def test_should_respond_406_unsupport_mime_type(self, test_client, test_dag): response = test_client.get( f"{API_PREFIX}/{TEST_DAG_ID}", headers={"Accept": "text/html"}, diff --git a/tests/api_fastapi/core_api/routes/public/test_dag_tags.py b/tests/api_fastapi/core_api/routes/public/test_dag_tags.py index 784bb480c431b..7d7720c76f773 100644 --- a/tests/api_fastapi/core_api/routes/public/test_dag_tags.py +++ b/tests/api_fastapi/core_api/routes/public/test_dag_tags.py @@ -111,6 +111,7 @@ def setup(self, dag_maker, session=None) -> None: ): EmptyOperator(task_id=TASK_ID) + dag_maker.sync_dagbag_to_db() dag_maker.create_dagrun(state=DagRunState.FAILED) with dag_maker( @@ -127,7 +128,7 @@ def setup(self, dag_maker, session=None) -> None: self._create_deactivated_paused_dag(session) self._create_dag_tags(session) - dag_maker.dagbag.sync_to_db() + dag_maker.sync_dagbag_to_db() dag_maker.dag_model.has_task_concurrency_limits = True session.merge(dag_maker.dag_model) session.commit() diff --git a/tests/api_fastapi/core_api/routes/public/test_dags.py b/tests/api_fastapi/core_api/routes/public/test_dags.py index 05a49c44dfeea..b79c23eb36b4e 100644 --- a/tests/api_fastapi/core_api/routes/public/test_dags.py +++ b/tests/api_fastapi/core_api/routes/public/test_dags.py @@ -127,7 +127,7 @@ def setup(self, dag_maker, session=None) -> None: self._create_deactivated_paused_dag(session) self._create_dag_tags(session) - dag_maker.dagbag.sync_to_db() + dag_maker.sync_dagbag_to_db() dag_maker.dag_model.has_task_concurrency_limits = True session.merge(dag_maker.dag_model) session.commit() @@ -409,7 +409,7 @@ def _create_dag_for_deletion( ti = dr.get_task_instances()[0] ti.set_state(TaskInstanceState.RUNNING) - dag_maker.dagbag.sync_to_db() + dag_maker.sync_dagbag_to_db() @pytest.mark.parametrize( "dag_id, dag_display_name, status_code_delete, status_code_details, has_running_dagruns, is_create_dag", diff --git a/tests/api_fastapi/core_api/routes/public/test_extra_links.py b/tests/api_fastapi/core_api/routes/public/test_extra_links.py index 6a2ca4650075e..89aef555bbe20 100644 --- a/tests/api_fastapi/core_api/routes/public/test_extra_links.py +++ b/tests/api_fastapi/core_api/routes/public/test_extra_links.py @@ -21,6 +21,7 @@ import pytest +from airflow.dag_processing.bundles.manager import DagBundlesManager from airflow.models.dag import DAG from airflow.models.dagbag import DagBag from airflow.models.xcom import XCom @@ -67,10 +68,11 @@ def setup(self, test_client, session=None) -> None: self.dag = self._create_dag() + DagBundlesManager().sync_bundles_to_db() dag_bag = DagBag(os.devnull, include_examples=False) dag_bag.dags = {self.dag.dag_id: self.dag} test_client.app.state.dag_bag = dag_bag - dag_bag.sync_to_db() + dag_bag.sync_to_db("dags-folder", 
None) triggered_by_kwargs = {"triggered_by": DagRunTriggeredByType.TEST} diff --git a/tests/api_fastapi/core_api/routes/public/test_task_instances.py b/tests/api_fastapi/core_api/routes/public/test_task_instances.py index d62d37944348c..f0c3c888807ea 100644 --- a/tests/api_fastapi/core_api/routes/public/test_task_instances.py +++ b/tests/api_fastapi/core_api/routes/public/test_task_instances.py @@ -26,6 +26,7 @@ import pytest from sqlalchemy import select +from airflow.dag_processing.bundles.manager import DagBundlesManager from airflow.jobs.job import Job from airflow.jobs.triggerer_job_runner import TriggererJobRunner from airflow.models import DagRun, TaskInstance @@ -522,9 +523,10 @@ def create_dag_runs_with_mapped_tasks(self, dag_maker, session, dags=None): setattr(ti, "start_date", DEFAULT_DATETIME_1) session.add(ti) + DagBundlesManager().sync_bundles_to_db() dagbag = DagBag(os.devnull, include_examples=False) dagbag.dags = {dag_id: dag_maker.dag} - dagbag.sync_to_db() + dagbag.sync_to_db("dags-folder", None) session.flush() mapped.expand_mapped_task(dr.run_id, session=session) @@ -1857,7 +1859,7 @@ def test_should_respond_200( task_instances=task_instances, update_extras=False, ) - self.dagbag.sync_to_db() + self.dagbag.sync_to_db("dags-folder", None) response = test_client.post( f"/public/dags/{request_dag}/clearTaskInstances", json=payload, @@ -1871,7 +1873,7 @@ def test_clear_taskinstance_is_called_with_queued_dr_state(self, mock_clearti, t self.create_task_instances(session) dag_id = "example_python_operator" payload = {"reset_dag_runs": True, "dry_run": False} - self.dagbag.sync_to_db() + self.dagbag.sync_to_db("dags-folder", None) response = test_client.post( f"/public/dags/{dag_id}/clearTaskInstances", json=payload, @@ -1891,7 +1893,7 @@ def test_clear_taskinstance_is_called_with_invalid_task_ids(self, test_client, s assert dagrun.state == "running" payload = {"dry_run": False, "reset_dag_runs": True, "task_ids": [""]} - self.dagbag.sync_to_db() + self.dagbag.sync_to_db("dags-folder", None) response = test_client.post( f"/public/dags/{dag_id}/clearTaskInstances", json=payload, @@ -1941,7 +1943,7 @@ def test_should_respond_200_with_reset_dag_run(self, test_client, session): update_extras=False, dag_run_state=DagRunState.FAILED, ) - self.dagbag.sync_to_db() + self.dagbag.sync_to_db("dags-folder", None) response = test_client.post( f"/public/dags/{dag_id}/clearTaskInstances", json=payload, @@ -2026,7 +2028,7 @@ def test_should_respond_200_with_dag_run_id(self, test_client, session): update_extras=False, dag_run_state=State.FAILED, ) - self.dagbag.sync_to_db() + self.dagbag.sync_to_db("dags-folder", None) response = test_client.post( f"/public/dags/{dag_id}/clearTaskInstances", json=payload, @@ -2082,7 +2084,7 @@ def test_should_respond_200_with_include_past(self, test_client, session): update_extras=False, dag_run_state=State.FAILED, ) - self.dagbag.sync_to_db() + self.dagbag.sync_to_db("dags-folder", None) response = test_client.post( f"/public/dags/{dag_id}/clearTaskInstances", json=payload, @@ -2164,7 +2166,7 @@ def test_should_respond_200_with_include_future(self, test_client, session): update_extras=False, dag_run_state=State.FAILED, ) - self.dagbag.sync_to_db() + self.dagbag.sync_to_db("dags-folder", None) response = test_client.post( f"/public/dags/{dag_id}/clearTaskInstances", json=payload, @@ -2312,7 +2314,7 @@ def test_should_raise_400_for_naive_and_bad_datetime(self, test_client, session, task_instances=task_instances, update_extras=False, ) - self.dagbag.sync_to_db() 
+ self.dagbag.sync_to_db("dags-folder", None) response = test_client.post( "/public/dags/example_python_operator/clearTaskInstances", json=payload, diff --git a/tests/api_fastapi/core_api/routes/ui/test_assets.py b/tests/api_fastapi/core_api/routes/ui/test_assets.py index cbe1b292f6d7e..4935b25014ea7 100644 --- a/tests/api_fastapi/core_api/routes/ui/test_assets.py +++ b/tests/api_fastapi/core_api/routes/ui/test_assets.py @@ -46,7 +46,7 @@ def test_next_run_assets(test_client, dag_maker): EmptyOperator(task_id="task1") dag_maker.create_dagrun() - dag_maker.dagbag.sync_to_db() + dag_maker.sync_dagbag_to_db() response = test_client.get("/ui/next_run_assets/upstream") diff --git a/tests/api_fastapi/core_api/routes/ui/test_dashboard.py b/tests/api_fastapi/core_api/routes/ui/test_dashboard.py index 93317eaa67088..416e016e25c95 100644 --- a/tests/api_fastapi/core_api/routes/ui/test_dashboard.py +++ b/tests/api_fastapi/core_api/routes/ui/test_dashboard.py @@ -95,7 +95,7 @@ def make_dag_runs(dag_maker, session, time_machine): for ti in run2.task_instances: ti.state = TaskInstanceState.FAILED - dag_maker.dagbag.sync_to_db() + dag_maker.sync_dagbag_to_db() time_machine.move_to("2023-07-02T00:00:00+00:00", tick=False) diff --git a/tests/api_fastapi/core_api/routes/ui/test_structure.py b/tests/api_fastapi/core_api/routes/ui/test_structure.py index 9732269944dfb..1f865a140c5be 100644 --- a/tests/api_fastapi/core_api/routes/ui/test_structure.py +++ b/tests/api_fastapi/core_api/routes/ui/test_structure.py @@ -59,7 +59,7 @@ def make_dag(dag_maker, session, time_machine): ): TriggerDagRunOperator(task_id="trigger_dag_run_operator", trigger_dag_id=DAG_ID) - dag_maker.dagbag.sync_to_db() + dag_maker.sync_dagbag_to_db() with dag_maker( dag_id=DAG_ID, @@ -78,7 +78,7 @@ def make_dag(dag_maker, session, time_machine): >> EmptyOperator(task_id="task_2") ) - dag_maker.dagbag.sync_to_db() + dag_maker.sync_dagbag_to_db() class TestStructureDataEndpoint: diff --git a/tests/cli/commands/remote_commands/test_asset_command.py b/tests/cli/commands/remote_commands/test_asset_command.py index 69906d1813a29..067bc9e8e1e2b 100644 --- a/tests/cli/commands/remote_commands/test_asset_command.py +++ b/tests/cli/commands/remote_commands/test_asset_command.py @@ -21,15 +21,15 @@ import contextlib import io import json +import os import typing import pytest from airflow.cli import cli_parser from airflow.cli.commands.remote_commands import asset_command -from airflow.models.dagbag import DagBag -from tests_common.test_utils.db import clear_db_dags, clear_db_runs +from tests_common.test_utils.db import clear_db_dags, clear_db_runs, parse_and_sync_to_db if typing.TYPE_CHECKING: from argparse import ArgumentParser @@ -39,7 +39,7 @@ @pytest.fixture(scope="module", autouse=True) def prepare_examples(): - DagBag(include_examples=True).sync_to_db() + parse_and_sync_to_db(os.devnull, include_examples=True) yield clear_db_runs() clear_db_dags() diff --git a/tests/cli/commands/remote_commands/test_backfill_command.py b/tests/cli/commands/remote_commands/test_backfill_command.py index e252d4db1c478..38b0169a0e7dd 100644 --- a/tests/cli/commands/remote_commands/test_backfill_command.py +++ b/tests/cli/commands/remote_commands/test_backfill_command.py @@ -18,6 +18,7 @@ from __future__ import annotations import argparse +import os from datetime import datetime from unittest import mock @@ -26,11 +27,10 @@ import airflow.cli.commands.remote_commands.backfill_command from airflow.cli import cli_parser -from airflow.models import DagBag from 
airflow.models.backfill import ReprocessBehavior from airflow.utils import timezone -from tests_common.test_utils.db import clear_db_backfills, clear_db_dags, clear_db_runs +from tests_common.test_utils.db import clear_db_backfills, clear_db_dags, clear_db_runs, parse_and_sync_to_db DEFAULT_DATE = timezone.make_aware(datetime(2015, 1, 1), timezone=timezone.utc) if pendulum.__version__.startswith("3"): @@ -48,8 +48,7 @@ class TestCliBackfill: @classmethod def setup_class(cls): - cls.dagbag = DagBag(include_examples=True) - cls.dagbag.sync_to_db() + parse_and_sync_to_db(os.devnull, include_examples=True) cls.parser = cli_parser.get_parser() @classmethod diff --git a/tests/cli/commands/remote_commands/test_dag_command.py b/tests/cli/commands/remote_commands/test_dag_command.py index dab4d0da6caa0..0db0ac83df02f 100644 --- a/tests/cli/commands/remote_commands/test_dag_command.py +++ b/tests/cli/commands/remote_commands/test_dag_command.py @@ -50,7 +50,7 @@ from tests.models import TEST_DAGS_FOLDER from tests_common.test_utils.config import conf_vars -from tests_common.test_utils.db import clear_db_dags, clear_db_runs +from tests_common.test_utils.db import clear_db_dags, clear_db_runs, parse_and_sync_to_db DEFAULT_DATE = timezone.make_aware(datetime(2015, 1, 1), timezone=timezone.utc) if pendulum.__version__.startswith("3"): @@ -68,8 +68,7 @@ class TestCliDags: @classmethod def setup_class(cls): - cls.dagbag = DagBag(include_examples=True) - cls.dagbag.sync_to_db() + parse_and_sync_to_db(os.devnull, include_examples=True) cls.parser = cli_parser.get_parser() @classmethod @@ -207,8 +206,7 @@ def test_next_execution(self, tmp_path): with time_machine.travel(DEFAULT_DATE): clear_db_dags() - self.dagbag = DagBag(dag_folder=tmp_path, include_examples=False) - self.dagbag.sync_to_db() + parse_and_sync_to_db(tmp_path, include_examples=False) default_run = DEFAULT_DATE future_run = default_run + timedelta(days=5) @@ -255,8 +253,7 @@ def test_next_execution(self, tmp_path): # Rebuild Test DB for other tests clear_db_dags() - TestCliDags.dagbag = DagBag(include_examples=True) - TestCliDags.dagbag.sync_to_db() + parse_and_sync_to_db(os.devnull, include_examples=True) @conf_vars({("core", "load_examples"): "true"}) def test_cli_report(self): @@ -405,24 +402,24 @@ def test_cli_list_jobs_with_args(self): def test_pause(self): args = self.parser.parse_args(["dags", "pause", "example_bash_operator"]) dag_command.dag_pause(args) - assert self.dagbag.dags["example_bash_operator"].get_is_paused() + assert DagModel.get_dagmodel("example_bash_operator").is_paused dag_command.dag_unpause(args) - assert not self.dagbag.dags["example_bash_operator"].get_is_paused() + assert not DagModel.get_dagmodel("example_bash_operator").is_paused @mock.patch("airflow.cli.commands.remote_commands.dag_command.ask_yesno") def test_pause_regex(self, mock_yesno): args = self.parser.parse_args(["dags", "pause", "^example_.*$", "--treat-dag-id-as-regex"]) dag_command.dag_pause(args) mock_yesno.assert_called_once() - assert self.dagbag.dags["example_bash_decorator"].get_is_paused() - assert self.dagbag.dags["example_kubernetes_executor"].get_is_paused() - assert self.dagbag.dags["example_xcom_args"].get_is_paused() + assert DagModel.get_dagmodel("example_bash_decorator").is_paused + assert DagModel.get_dagmodel("example_kubernetes_executor").is_paused + assert DagModel.get_dagmodel("example_xcom_args").is_paused args = self.parser.parse_args(["dags", "unpause", "^example_.*$", "--treat-dag-id-as-regex"]) dag_command.dag_unpause(args) - 
assert not self.dagbag.dags["example_bash_decorator"].get_is_paused() - assert not self.dagbag.dags["example_kubernetes_executor"].get_is_paused() - assert not self.dagbag.dags["example_xcom_args"].get_is_paused() + assert not DagModel.get_dagmodel("example_bash_decorator").is_paused + assert not DagModel.get_dagmodel("example_kubernetes_executor").is_paused + assert not DagModel.get_dagmodel("example_xcom_args").is_paused @mock.patch("airflow.cli.commands.remote_commands.dag_command.ask_yesno") def test_pause_regex_operation_cancelled(self, ask_yesno, capsys): diff --git a/tests/cli/commands/remote_commands/test_task_command.py b/tests/cli/commands/remote_commands/test_task_command.py index 843d6817cdcc5..9a4c606caa469 100644 --- a/tests/cli/commands/remote_commands/test_task_command.py +++ b/tests/cli/commands/remote_commands/test_task_command.py @@ -53,7 +53,7 @@ from airflow.utils.types import DagRunType from tests_common.test_utils.config import conf_vars -from tests_common.test_utils.db import clear_db_pools, clear_db_runs +from tests_common.test_utils.db import clear_db_pools, clear_db_runs, parse_and_sync_to_db from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS if AIRFLOW_V_3_0_PLUS: @@ -97,12 +97,12 @@ class TestCliTasks: @pytest.fixture(autouse=True) def setup_class(cls): logging.config.dictConfig(DEFAULT_LOGGING_CONFIG) - cls.dagbag = DagBag(include_examples=True) + parse_and_sync_to_db(os.devnull, include_examples=True) cls.parser = cli_parser.get_parser() clear_db_runs() + cls.dagbag = DagBag(read_dags_from_db=True) cls.dag = cls.dagbag.get_dag(cls.dag_id) - cls.dagbag.sync_to_db() data_interval = cls.dag.timetable.infer_manual_data_interval(run_after=DEFAULT_DATE) triggered_by_kwargs = {"triggered_by": DagRunTriggeredByType.CLI} if AIRFLOW_V_3_0_PLUS else {} cls.dag_run = cls.dag.create_dagrun( @@ -164,7 +164,7 @@ def test_cli_test_different_path(self, session, tmp_path): with conf_vars({("core", "dags_folder"): orig_dags_folder.as_posix()}): dagbag = DagBag(include_examples=False) dag = dagbag.get_dag("test_dags_folder") - dagbag.sync_to_db(session=session) + dagbag.sync_to_db("dags-folder", None, session=session) logical_date = pendulum.now("UTC") data_interval = dag.timetable.infer_manual_data_interval(run_after=logical_date) diff --git a/tests/conftest.py b/tests/conftest.py index de13fe99c4bf6..8e41e35e35d06 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -16,15 +16,20 @@ # under the License. from __future__ import annotations +import json import logging import os import sys +from contextlib import contextmanager from typing import TYPE_CHECKING import pytest from tests_common.test_utils.log_handlers import non_pytest_handlers +if TYPE_CHECKING: + from pathlib import Path + # We should set these before loading _any_ of the rest of airflow so that the # unit test mode config is set as early as possible. 
assert "airflow" not in sys.modules, "No airflow module can be imported before these lines" @@ -81,6 +86,37 @@ def clear_all_logger_handlers(): remove_all_non_pytest_log_handlers() +@pytest.fixture +def testing_dag_bundle(): + from airflow.models.dagbundle import DagBundleModel + from airflow.utils.session import create_session + + with create_session() as session: + if session.query(DagBundleModel).filter(DagBundleModel.name == "testing").count() == 0: + testing = DagBundleModel(name="testing") + session.add(testing) + + +@pytest.fixture +def configure_testing_dag_bundle(): + """Configure the testing DAG bundle with the provided path, and disable the DAGs folder bundle.""" + from tests_common.test_utils.config import conf_vars + + @contextmanager + def _config_bundle(path_to_parse: Path | str): + bundle_config = [ + { + "name": "testing", + "classpath": "airflow.dag_processing.bundles.local.LocalDagBundle", + "kwargs": {"local_folder": str(path_to_parse), "refresh_interval": 0}, + } + ] + with conf_vars({("dag_bundles", "backends"): json.dumps(bundle_config)}): + yield + + return _config_bundle + + if TYPE_CHECKING: # Static checkers do not know about pytest fixtures' types and return, # In case if them distributed through third party packages. diff --git a/tests/dag_processing/bundles/test_dag_bundle_manager.py b/tests/dag_processing/bundles/test_dag_bundle_manager.py index 47b192efc198f..eee7dee211472 100644 --- a/tests/dag_processing/bundles/test_dag_bundle_manager.py +++ b/tests/dag_processing/bundles/test_dag_bundle_manager.py @@ -68,7 +68,9 @@ ) def test_parse_bundle_config(value, expected): """Test that bundle_configs are read from configuration.""" - envs = {"AIRFLOW__DAG_BUNDLES__BACKENDS": value} if value else {} + envs = {"AIRFLOW__CORE__LOAD_EXAMPLES": "False"} + if value: + envs["AIRFLOW__DAG_BUNDLES__BACKENDS"] = value cm = nullcontext() exp_fail = False if isinstance(expected, str): @@ -133,6 +135,7 @@ def clear_db(): @pytest.mark.db_test +@conf_vars({("core", "LOAD_EXAMPLES"): "False"}) def test_sync_bundles_to_db(clear_db): def _get_bundle_names_and_active(): with create_session() as session: @@ -167,3 +170,14 @@ def test_view_url(version): with patch.object(BaseDagBundle, "view_url") as view_url_mock: bundle_manager.view_url("my-test-bundle", version=version) view_url_mock.assert_called_once_with(version=version) + + +def test_example_dags_bundle_added(): + manager = DagBundlesManager() + manager.parse_config() + assert "example_dags" in manager._bundle_config + + with conf_vars({("core", "LOAD_EXAMPLES"): "False"}): + manager = DagBundlesManager() + manager.parse_config() + assert "example_dags" not in manager._bundle_config diff --git a/tests/dag_processing/test_collection.py b/tests/dag_processing/test_collection.py index a248904cbefcc..8ef07514bbfbe 100644 --- a/tests/dag_processing/test_collection.py +++ b/tests/dag_processing/test_collection.py @@ -185,7 +185,7 @@ def dag_to_lazy_serdag(self, dag: DAG) -> LazyDeserializedDAG: @pytest.mark.usefixtures("clean_db") # sync_perms in fab has bad session commit hygiene def test_sync_perms_syncs_dag_specific_perms_on_update( - self, monkeypatch, spy_agency: SpyAgency, session, time_machine + self, monkeypatch, spy_agency: SpyAgency, session, time_machine, testing_dag_bundle ): """ Test that dagbag.sync_to_db will sync DAG specific permissions when a DAG is @@ -210,7 +210,7 @@ def _sync_to_db(): sync_perms_spy.reset_calls() time_machine.shift(20) - update_dag_parsing_results_in_db([dag], dict(), set(), session) + 
update_dag_parsing_results_in_db("testing", None, [dag], dict(), set(), session) _sync_to_db() spy_agency.assert_spy_called_with(sync_perms_spy, dag, session=session) @@ -228,7 +228,9 @@ def _sync_to_db(): @patch.object(SerializedDagModel, "write_dag") @patch("airflow.models.dag.DAG.bulk_write_to_db") - def test_sync_to_db_is_retried(self, mock_bulk_write_to_db, mock_s10n_write_dag, session): + def test_sync_to_db_is_retried( + self, mock_bulk_write_to_db, mock_s10n_write_dag, testing_dag_bundle, session + ): """Test that important DB operations in db sync are retried on OperationalError""" serialized_dags_count = session.query(func.count(SerializedDagModel.dag_id)).scalar() @@ -244,14 +246,16 @@ def test_sync_to_db_is_retried(self, mock_bulk_write_to_db, mock_s10n_write_dag, mock_bulk_write_to_db.side_effect = side_effect mock_session = mock.MagicMock() - update_dag_parsing_results_in_db(dags=dags, import_errors={}, warnings=set(), session=mock_session) + update_dag_parsing_results_in_db( + "testing", None, dags=dags, import_errors={}, warnings=set(), session=mock_session + ) # Test that 3 attempts were made to run 'DAG.bulk_write_to_db' successfully mock_bulk_write_to_db.assert_has_calls( [ - mock.call(mock.ANY, session=mock.ANY), - mock.call(mock.ANY, session=mock.ANY), - mock.call(mock.ANY, session=mock.ANY), + mock.call("testing", None, mock.ANY, session=mock.ANY), + mock.call("testing", None, mock.ANY, session=mock.ANY), + mock.call("testing", None, mock.ANY, session=mock.ANY), ] ) # Assert that rollback is called twice (i.e. whenever OperationalError occurs) @@ -268,7 +272,7 @@ def test_sync_to_db_is_retried(self, mock_bulk_write_to_db, mock_s10n_write_dag, serialized_dags_count = session.query(func.count(SerializedDagModel.dag_id)).scalar() assert serialized_dags_count == 0 - def test_serialized_dags_are_written_to_db_on_sync(self, session): + def test_serialized_dags_are_written_to_db_on_sync(self, testing_dag_bundle, session): """ Test that when dagbag.sync_to_db is called the DAGs are Serialized and written to DB even when dagbag.read_dags_from_db is False @@ -278,14 +282,14 @@ def test_serialized_dags_are_written_to_db_on_sync(self, session): dag = DAG(dag_id="test") - update_dag_parsing_results_in_db([dag], dict(), set(), session) + update_dag_parsing_results_in_db("testing", None, [dag], dict(), set(), session) new_serialized_dags_count = session.query(func.count(SerializedDagModel.dag_id)).scalar() assert new_serialized_dags_count == 1 @patch.object(SerializedDagModel, "write_dag") def test_serialized_dag_errors_are_import_errors( - self, mock_serialize, caplog, session, dag_import_error_listener + self, mock_serialize, caplog, session, dag_import_error_listener, testing_dag_bundle ): """ Test that errors serializing a DAG are recorded as import_errors in the DB @@ -298,7 +302,7 @@ def test_serialized_dag_errors_are_import_errors( dag.fileloc = "abc.py" import_errors = {} - update_dag_parsing_results_in_db([dag], import_errors, set(), session) + update_dag_parsing_results_in_db("testing", None, [dag], import_errors, set(), session) assert "SerializationError" in caplog.text # Should have been edited in place @@ -320,7 +324,7 @@ def test_serialized_dag_errors_are_import_errors( assert len(dag_import_error_listener.existing) == 0 assert dag_import_error_listener.new["abc.py"] == import_error.stacktrace - def test_new_import_error_replaces_old(self, session, dag_import_error_listener): + def test_new_import_error_replaces_old(self, session, dag_import_error_listener, 
testing_dag_bundle): """ Test that existing import error is updated and new record not created for a dag with the same filename @@ -336,6 +340,8 @@ def test_new_import_error_replaces_old(self, session, dag_import_error_listener) prev_error_id = prev_error.id update_dag_parsing_results_in_db( + bundle_name="testing", + bundle_version=None, dags=[], import_errors={"abc.py": "New error"}, warnings=set(), @@ -353,7 +359,7 @@ def test_new_import_error_replaces_old(self, session, dag_import_error_listener) assert len(dag_import_error_listener.existing) == 1 assert dag_import_error_listener.existing["abc.py"] == prev_error.stacktrace - def test_remove_error_clears_import_error(self, session): + def test_remove_error_clears_import_error(self, testing_dag_bundle, session): # Pre-condition: there is an import error for the dag file filename = "abc.py" prev_error = ParseImportError( @@ -381,7 +387,7 @@ def test_remove_error_clears_import_error(self, session): dag.fileloc = filename import_errors = {} - update_dag_parsing_results_in_db([dag], import_errors, set(), session) + update_dag_parsing_results_in_db("testing", None, [dag], import_errors, set(), session) dag_model: DagModel = session.get(DagModel, (dag.dag_id,)) assert dag_model.has_import_errors is False @@ -473,7 +479,7 @@ def _sync_perms(): ], ) @pytest.mark.usefixtures("clean_db") - def test_dagmodel_properties(self, attrs, expected, session, time_machine): + def test_dagmodel_properties(self, attrs, expected, session, time_machine, testing_dag_bundle): """Test that properties on the dag model are correctly set when dealing with a LazySerializedDag""" dt = tz.datetime(2020, 1, 5, 0, 0, 0) time_machine.move_to(dt, tick=False) @@ -493,7 +499,7 @@ def test_dagmodel_properties(self, attrs, expected, session, time_machine): session.add(dr1) session.commit() - update_dag_parsing_results_in_db([self.dag_to_lazy_serdag(dag)], {}, set(), session) + update_dag_parsing_results_in_db("testing", None, [self.dag_to_lazy_serdag(dag)], {}, set(), session) orm_dag = session.get(DagModel, ("dag",)) @@ -505,14 +511,21 @@ def test_dagmodel_properties(self, attrs, expected, session, time_machine): assert orm_dag.last_parsed_time == dt - def test_existing_dag_is_paused_upon_creation(self, session): + def test_existing_dag_is_paused_upon_creation(self, testing_dag_bundle, session): dag = DAG("dag_paused", schedule=None) - update_dag_parsing_results_in_db([self.dag_to_lazy_serdag(dag)], {}, set(), session) + update_dag_parsing_results_in_db("testing", None, [self.dag_to_lazy_serdag(dag)], {}, set(), session) orm_dag = session.get(DagModel, ("dag_paused",)) assert orm_dag.is_paused is False dag = DAG("dag_paused", schedule=None, is_paused_upon_creation=True) - update_dag_parsing_results_in_db([self.dag_to_lazy_serdag(dag)], {}, set(), session) + update_dag_parsing_results_in_db("testing", None, [self.dag_to_lazy_serdag(dag)], {}, set(), session) # Since the dag existed before, it should not follow the pause flag upon creation orm_dag = session.get(DagModel, ("dag_paused",)) assert orm_dag.is_paused is False + + def test_bundle_name_and_version_are_stored(self, testing_dag_bundle, session): + dag = DAG("mydag", schedule=None) + update_dag_parsing_results_in_db("testing", "1.0", [self.dag_to_lazy_serdag(dag)], {}, set(), session) + orm_dag = session.get(DagModel, "mydag") + assert orm_dag.bundle_name == "testing" + assert orm_dag.latest_bundle_version == "1.0" diff --git a/tests/dag_processing/test_manager.py b/tests/dag_processing/test_manager.py index 
2cc8a43e05450..7608cbbd32766 100644 --- a/tests/dag_processing/test_manager.py +++ b/tests/dag_processing/test_manager.py @@ -19,6 +19,7 @@ import io import itertools +import json import logging import multiprocessing import os @@ -43,15 +44,18 @@ from airflow.callbacks.callback_requests import DagCallbackRequest from airflow.config_templates.airflow_local_settings import DEFAULT_LOGGING_CONFIG +from airflow.dag_processing.bundles.manager import DagBundlesManager from airflow.dag_processing.manager import ( + DagFileInfo, DagFileProcessorAgent, DagFileProcessorManager, DagFileStat, ) from airflow.dag_processing.processor import DagFileProcessorProcess -from airflow.models import DagBag, DagModel, DbCallbackRequest +from airflow.models import DAG, DagBag, DagModel, DbCallbackRequest from airflow.models.asset import TaskOutletAssetReference from airflow.models.dag_version import DagVersion +from airflow.models.dagbundle import DagBundleModel from airflow.models.dagcode import DagCode from airflow.models.serialized_dag import SerializedDagModel from airflow.utils import timezone @@ -78,6 +82,10 @@ DEFAULT_DATE = timezone.datetime(2016, 1, 1) +def _get_dag_file_paths(files: list[str]) -> list[DagFileInfo]: + return [DagFileInfo(bundle_name="testing", path=f) for f in files] + + class TestDagFileProcessorManager: @pytest.fixture(autouse=True) def _disable_examples(self): @@ -101,9 +109,6 @@ def teardown_class(self): clear_db_callbacks() clear_db_import_errors() - def run_processor_manager_one_loop(self, manager: DagFileProcessorManager) -> None: - manager._run_parsing_loop() - def mock_processor(self) -> DagFileProcessorProcess: proc = MagicMock() proc.create_time.return_value = time.time() @@ -124,49 +129,51 @@ def clear_parse_import_errors(self): @pytest.mark.usefixtures("clear_parse_import_errors") @conf_vars({("core", "load_examples"): "False"}) - def test_remove_file_clears_import_error(self, tmp_path): + def test_remove_file_clears_import_error(self, tmp_path, configure_testing_dag_bundle): path_to_parse = tmp_path / "temp_dag.py" # Generate original import error path_to_parse.write_text("an invalid airflow DAG") - manager = DagFileProcessorManager( - dag_directory=path_to_parse.parent, - max_runs=1, - processor_timeout=365 * 86_400, - ) + with configure_testing_dag_bundle(path_to_parse): + manager = DagFileProcessorManager( + max_runs=1, + processor_timeout=365 * 86_400, + ) - with create_session() as session: - self.run_processor_manager_one_loop(manager) + with create_session() as session: + manager.run() - import_errors = session.query(ParseImportError).all() - assert len(import_errors) == 1 + import_errors = session.query(ParseImportError).all() + assert len(import_errors) == 1 - path_to_parse.unlink() + path_to_parse.unlink() - # Rerun the parser once the dag file has been removed - self.run_processor_manager_one_loop(manager) - import_errors = session.query(ParseImportError).all() + # Rerun the parser once the dag file has been removed + manager.run() + import_errors = session.query(ParseImportError).all() - assert len(import_errors) == 0 - session.rollback() + assert len(import_errors) == 0 + session.rollback() @conf_vars({("core", "load_examples"): "False"}) def test_max_runs_when_no_files(self, tmp_path): - manager = DagFileProcessorManager(dag_directory=tmp_path, max_runs=1) + with conf_vars({("core", "dags_folder"): str(tmp_path)}): + manager = DagFileProcessorManager(max_runs=1) + manager.run() - self.run_processor_manager_one_loop(manager) + # TODO: AIP-66 no asserts? 
def test_start_new_processes_with_same_filepath(self): """ Test that when a processor already exist with a filepath, a new processor won't be created with that filepath. The filepath will just be removed from the list. """ - manager = DagFileProcessorManager(dag_directory="directory", max_runs=1) + manager = DagFileProcessorManager(max_runs=1) - file_1 = "file_1.py" - file_2 = "file_2.py" - file_3 = "file_3.py" + file_1 = DagFileInfo(bundle_name="testing", path="file_1.py") + file_2 = DagFileInfo(bundle_name="testing", path="file_2.py") + file_3 = DagFileInfo(bundle_name="testing", path="file_3.py") manager._file_path_queue = deque([file_1, file_2, file_3]) # Mock that only one processor exists. This processor runs with 'file_1' @@ -185,49 +192,47 @@ def test_start_new_processes_with_same_filepath(self): assert deque([file_3]) == manager._file_path_queue def test_set_file_paths_when_processor_file_path_not_in_new_file_paths(self): - manager = DagFileProcessorManager(dag_directory="directory", max_runs=1) - - mock_processor = MagicMock() - mock_processor.stop.side_effect = AttributeError("DagFileProcessor object has no attribute stop") - mock_processor.terminate.side_effect = None + """Ensure processors and file stats are removed when the file path is not in the new file paths""" + manager = DagFileProcessorManager(max_runs=1) + file = DagFileInfo(bundle_name="testing", path="missing_file.txt") - manager._processors["missing_file.txt"] = mock_processor - manager._file_stats["missing_file.txt"] = DagFileStat() + manager._processors[file] = MagicMock() + manager._file_stats[file] = DagFileStat() manager.set_file_paths(["abc.txt"]) assert manager._processors == {} - assert "missing_file.txt" not in manager._file_stats + assert file not in manager._file_stats def test_set_file_paths_when_processor_file_path_is_in_new_file_paths(self): - manager = DagFileProcessorManager(dag_directory="directory", max_runs=1) - + manager = DagFileProcessorManager(max_runs=1) + file = DagFileInfo(bundle_name="testing", path="abc.txt") mock_processor = MagicMock() - mock_processor.stop.side_effect = AttributeError("DagFileProcessor object has no attribute stop") - mock_processor.terminate.side_effect = None - manager._processors["abc.txt"] = mock_processor + manager._processors[file] = mock_processor - manager.set_file_paths(["abc.txt"]) - assert manager._processors == {"abc.txt": mock_processor} + manager.set_file_paths([file]) + assert manager._processors == {file: mock_processor} @conf_vars({("scheduler", "file_parsing_sort_mode"): "alphabetical"}) def test_file_paths_in_queue_sorted_alphabetically(self): """Test dag files are sorted alphabetically""" - dag_files = ["file_3.py", "file_2.py", "file_4.py", "file_1.py"] + file_names = ["file_3.py", "file_2.py", "file_4.py", "file_1.py"] + dag_files = _get_dag_file_paths(file_names) + ordered_dag_files = _get_dag_file_paths(sorted(file_names)) - manager = DagFileProcessorManager(dag_directory="directory", max_runs=1) + manager = DagFileProcessorManager(max_runs=1) manager.set_file_paths(dag_files) assert manager._file_path_queue == deque() manager.prepare_file_path_queue() - assert manager._file_path_queue == deque(["file_1.py", "file_2.py", "file_3.py", "file_4.py"]) + assert manager._file_path_queue == deque(ordered_dag_files) @conf_vars({("scheduler", "file_parsing_sort_mode"): "random_seeded_by_host"}) def test_file_paths_in_queue_sorted_random_seeded_by_host(self): """Test files are randomly sorted and seeded by host name""" - dag_files = ["file_3.py", 
"file_2.py", "file_4.py", "file_1.py"] + dag_files = _get_dag_file_paths(["file_3.py", "file_2.py", "file_4.py", "file_1.py"]) - manager = DagFileProcessorManager(dag_directory="directory", max_runs=1) + manager = DagFileProcessorManager(max_runs=1) manager.set_file_paths(dag_files) assert manager._file_path_queue == deque() @@ -247,45 +252,50 @@ def test_file_paths_in_queue_sorted_random_seeded_by_host(self): def test_file_paths_in_queue_sorted_by_modified_time(self, mock_getmtime): """Test files are sorted by modified time""" paths_with_mtime = {"file_3.py": 3.0, "file_2.py": 2.0, "file_4.py": 5.0, "file_1.py": 4.0} - dag_files = list(paths_with_mtime.keys()) + dag_files = _get_dag_file_paths(paths_with_mtime.keys()) mock_getmtime.side_effect = list(paths_with_mtime.values()) - manager = DagFileProcessorManager(dag_directory="directory", max_runs=1) + manager = DagFileProcessorManager(max_runs=1) manager.set_file_paths(dag_files) assert manager._file_path_queue == deque() manager.prepare_file_path_queue() - assert manager._file_path_queue == deque(["file_4.py", "file_1.py", "file_3.py", "file_2.py"]) + ordered_files = _get_dag_file_paths(["file_4.py", "file_1.py", "file_3.py", "file_2.py"]) + assert manager._file_path_queue == deque(ordered_files) @conf_vars({("scheduler", "file_parsing_sort_mode"): "modified_time"}) @mock.patch("airflow.utils.file.os.path.getmtime") def test_file_paths_in_queue_excludes_missing_file(self, mock_getmtime): """Check that a file is not enqueued for processing if it has been deleted""" - dag_files = ["file_3.py", "file_2.py", "file_4.py"] + dag_files = _get_dag_file_paths(["file_3.py", "file_2.py", "file_4.py"]) mock_getmtime.side_effect = [1.0, 2.0, FileNotFoundError()] - manager = DagFileProcessorManager(dag_directory="directory", max_runs=1) + manager = DagFileProcessorManager(max_runs=1) manager.set_file_paths(dag_files) manager.prepare_file_path_queue() - assert manager._file_path_queue == deque(["file_2.py", "file_3.py"]) + + ordered_files = _get_dag_file_paths(["file_2.py", "file_3.py"]) + assert manager._file_path_queue == deque(ordered_files) @conf_vars({("scheduler", "file_parsing_sort_mode"): "modified_time"}) @mock.patch("airflow.utils.file.os.path.getmtime") def test_add_new_file_to_parsing_queue(self, mock_getmtime): """Check that new file is added to parsing queue""" - dag_files = ["file_1.py", "file_2.py", "file_3.py"] + dag_files = _get_dag_file_paths(["file_1.py", "file_2.py", "file_3.py"]) mock_getmtime.side_effect = [1.0, 2.0, 3.0] - manager = DagFileProcessorManager(dag_directory="directory", max_runs=1) + manager = DagFileProcessorManager(max_runs=1) manager.set_file_paths(dag_files) manager.prepare_file_path_queue() - assert manager._file_path_queue == deque(["file_3.py", "file_2.py", "file_1.py"]) + ordered_files = _get_dag_file_paths(["file_3.py", "file_2.py", "file_1.py"]) + assert manager._file_path_queue == deque(ordered_files) - manager.set_file_paths([*dag_files, "file_4.py"]) + manager.set_file_paths([*dag_files, DagFileInfo(bundle_name="testing", path="file_4.py")]) manager.add_new_file_path_to_queue() - assert manager._file_path_queue == deque(["file_4.py", "file_3.py", "file_2.py", "file_1.py"]) + ordered_files = _get_dag_file_paths(["file_4.py", "file_3.py", "file_2.py", "file_1.py"]) + assert manager._file_path_queue == deque(ordered_files) @conf_vars({("scheduler", "file_parsing_sort_mode"): "modified_time"}) @mock.patch("airflow.utils.file.os.path.getmtime") @@ -295,15 +305,16 @@ def 
test_recently_modified_file_is_parsed_with_mtime_mode(self, mock_getmtime): """ freezed_base_time = timezone.datetime(2020, 1, 5, 0, 0, 0) initial_file_1_mtime = (freezed_base_time - timedelta(minutes=5)).timestamp() - dag_files = ["file_1.py"] + dag_file = DagFileInfo(bundle_name="testing", path="file_1.py") + dag_files = [dag_file] mock_getmtime.side_effect = [initial_file_1_mtime] - manager = DagFileProcessorManager(dag_directory="directory", max_runs=3) + manager = DagFileProcessorManager(max_runs=3) # let's say the DAG was just parsed 10 seconds before the Freezed time last_finish_time = freezed_base_time - timedelta(seconds=10) manager._file_stats = { - "file_1.py": DagFileStat(1, 0, last_finish_time, 1.0, 1, 1), + dag_file: DagFileStat(1, 0, last_finish_time, 1.0, 1, 1), } with time_machine.travel(freezed_base_time): manager.set_file_paths(dag_files) @@ -323,13 +334,14 @@ def test_recently_modified_file_is_parsed_with_mtime_mode(self, mock_getmtime): mock_getmtime.side_effect = [file_1_new_mtime_ts] manager.prepare_file_path_queue() # Check that file is added to the queue even though file was just recently passed - assert manager._file_path_queue == deque(["file_1.py"]) + assert manager._file_path_queue == deque(dag_files) assert last_finish_time < file_1_new_mtime assert ( manager._file_process_interval - > (freezed_base_time - manager._file_stats["file_1.py"].last_finish_time).total_seconds() + > (freezed_base_time - manager._file_stats[dag_file].last_finish_time).total_seconds() ) + @pytest.mark.skip("AIP-66: parsing requests are not bundle aware yet") def test_file_paths_in_queue_sorted_by_priority(self): from airflow.models.dagbag import DagPriorityParsingRequest @@ -351,25 +363,27 @@ def test_file_paths_in_queue_sorted_by_priority(self): parsing_request_after = session2.query(DagPriorityParsingRequest).get(parsing_request.id) assert parsing_request_after is None - def test_scan_stale_dags(self): + def test_scan_stale_dags(self, testing_dag_bundle): """ Ensure that DAGs are marked inactive when the file is parsed but the DagModel.last_parsed_time is not updated. 
""" manager = DagFileProcessorManager( - dag_directory="directory", max_runs=1, processor_timeout=10 * 60, ) - test_dag_path = str(TEST_DAG_FOLDER / "test_example_bash_operator.py") - dagbag = DagBag(test_dag_path, read_dags_from_db=False, include_examples=False) + test_dag_path = DagFileInfo( + bundle_name="testing", + path=str(TEST_DAG_FOLDER / "test_example_bash_operator.py"), + ) + dagbag = DagBag(test_dag_path.path, read_dags_from_db=False, include_examples=False) with create_session() as session: # Add stale DAG to the DB dag = dagbag.get_dag("test_example_bash_operator") dag.last_parsed_time = timezone.utcnow() - dag.sync_to_db() + DAG.bulk_write_to_db("testing", None, [dag]) SerializedDagModel.write_dag(dag) # Add DAG to the file_parsing_stats @@ -386,7 +400,7 @@ def test_scan_stale_dags(self): active_dag_count = ( session.query(func.count(DagModel.dag_id)) - .filter(DagModel.is_active, DagModel.fileloc == test_dag_path) + .filter(DagModel.is_active, DagModel.fileloc == test_dag_path.path) .scalar() ) assert active_dag_count == 1 @@ -395,7 +409,7 @@ def test_scan_stale_dags(self): active_dag_count = ( session.query(func.count(DagModel.dag_id)) - .filter(DagModel.is_active, DagModel.fileloc == test_dag_path) + .filter(DagModel.is_active, DagModel.fileloc == test_dag_path.path) .scalar() ) assert active_dag_count == 0 @@ -410,11 +424,11 @@ def test_scan_stale_dags(self): assert serialized_dag_count == 1 def test_kill_timed_out_processors_kill(self): - manager = DagFileProcessorManager(dag_directory="directory", max_runs=1, processor_timeout=5) + manager = DagFileProcessorManager(max_runs=1, processor_timeout=5) processor = self.mock_processor() processor._process.create_time.return_value = timezone.make_aware(datetime.min).timestamp() - manager._processors = {"abc.txt": processor} + manager._processors = {DagFileInfo(bundle_name="testing", path="abc.txt"): processor} with mock.patch.object(type(processor), "kill") as mock_kill: manager._kill_timed_out_processors() mock_kill.assert_called_once_with(signal.SIGKILL) @@ -422,14 +436,13 @@ def test_kill_timed_out_processors_kill(self): def test_kill_timed_out_processors_no_kill(self): manager = DagFileProcessorManager( - dag_directory=TEST_DAG_FOLDER, max_runs=1, processor_timeout=5, ) processor = self.mock_processor() processor._process.create_time.return_value = timezone.make_aware(datetime.max).timestamp() - manager._processors = {"abc.txt": processor} + manager._processors = {DagFileInfo(bundle_name="testing", path="abc.txt"): processor} with mock.patch.object(type(processor), "kill") as mock_kill: manager._kill_timed_out_processors() mock_kill.assert_not_called() @@ -472,7 +485,7 @@ def test_serialize_callback_requests(self, callbacks, path, child_comms_fd, expe @conf_vars({("core", "load_examples"): "False"}) @pytest.mark.execution_timeout(10) - def test_dag_with_system_exit(self): + def test_dag_with_system_exit(self, configure_testing_dag_bundle): """ Test to check that a DAG with a system.exit() doesn't break the scheduler. 
""" @@ -484,9 +497,9 @@ def test_dag_with_system_exit(self): clear_db_dags() clear_db_serialized_dags() - manager = DagFileProcessorManager(dag_directory=dag_directory, max_runs=1) - - manager._run_parsing_loop() + with configure_testing_dag_bundle(dag_directory): + manager = DagFileProcessorManager(max_runs=1) + manager.run() # Three files in folder should be processed assert sum(stat.run_count for stat in manager._file_stats.values()) == 3 @@ -496,7 +509,7 @@ def test_dag_with_system_exit(self): @conf_vars({("core", "load_examples"): "False"}) @pytest.mark.execution_timeout(30) - def test_pipe_full_deadlock(self): + def test_pipe_full_deadlock(self, configure_testing_dag_bundle): dag_filepath = TEST_DAG_FOLDER / "test_scheduler_dags.py" child_pipe, parent_pipe = multiprocessing.Pipe() @@ -533,38 +546,43 @@ def keep_pipe_full(pipe, exit_event): thread = threading.Thread(target=keep_pipe_full, args=(parent_pipe, exit_event)) - manager = DagFileProcessorManager( - dag_directory=dag_filepath, - # A reasonable large number to ensure that we trigger the deadlock - max_runs=100, - processor_timeout=5, - signal_conn=child_pipe, - # Make it loop sub-processes quickly. Need to be non-zero to exercise the bug, else it finishes - # too quickly - file_process_interval=0.01, - ) - - try: - thread.start() + with configure_testing_dag_bundle(dag_filepath): + manager = DagFileProcessorManager( + # A reasonable large number to ensure that we trigger the deadlock + max_runs=100, + processor_timeout=5, + signal_conn=child_pipe, + # Make it loop sub-processes quickly. Need to be non-zero to exercise the bug, else it finishes + # too quickly + file_process_interval=0.01, + ) - # If this completes without hanging, then the test is good! - with mock.patch.object( - DagFileProcessorProcess, "start", side_effect=lambda *args, **kwargs: self.mock_processor() - ): - manager.run() - exit_event.set() - finally: - logger.info("Closing pipes") - parent_pipe.close() - child_pipe.close() - logger.info("Closed pipes") - logger.info("Joining thread") - thread.join(timeout=1.0) - logger.info("Joined thread") + try: + thread.start() + + # If this completes without hanging, then the test is good! 
+ with mock.patch.object( + DagFileProcessorProcess, + "start", + side_effect=lambda *args, **kwargs: self.mock_processor(), + ): + manager.run() + exit_event.set() + finally: + logger.info("Closing pipes") + parent_pipe.close() + child_pipe.close() + logger.info("Closed pipes") + logger.info("Joining thread") + thread.join(timeout=1.0) + logger.info("Joined thread") @conf_vars({("core", "load_examples"): "False"}) @mock.patch("airflow.dag_processing.manager.Stats.timing") - def test_send_file_processing_statsd_timing(self, statsd_timing_mock, tmp_path): + @pytest.mark.skip("AIP-66: stats are not implemented yet") + def test_send_file_processing_statsd_timing( + self, statsd_timing_mock, tmp_path, configure_testing_dag_bundle + ): path_to_parse = tmp_path / "temp_dag.py" dag_code = textwrap.dedent( """ @@ -574,11 +592,11 @@ def test_send_file_processing_statsd_timing(self, statsd_timing_mock, tmp_path): ) path_to_parse.write_text(dag_code) - manager = DagFileProcessorManager(dag_directory=path_to_parse.parent, max_runs=1) + with configure_testing_dag_bundle(tmp_path): + manager = DagFileProcessorManager(max_runs=1) + manager.run() - self.run_processor_manager_one_loop(manager) last_runtime = manager._file_stats[os.fspath(path_to_parse)].last_duration - statsd_timing_mock.assert_has_calls( [ mock.call("dag_processing.last_duration.temp_dag", last_runtime), @@ -587,17 +605,19 @@ def test_send_file_processing_statsd_timing(self, statsd_timing_mock, tmp_path): any_order=True, ) - def test_refresh_dags_dir_doesnt_delete_zipped_dags(self, tmp_path): + def test_refresh_dags_dir_doesnt_delete_zipped_dags(self, tmp_path, configure_testing_dag_bundle): """Test DagFileProcessorManager._refresh_dag_dir method""" - manager = DagFileProcessorManager(dag_directory=TEST_DAG_FOLDER, max_runs=1) dagbag = DagBag(dag_folder=tmp_path, include_examples=False) zipped_dag_path = os.path.join(TEST_DAGS_FOLDER, "test_zip.zip") dagbag.process_file(zipped_dag_path) dag = dagbag.get_dag("test_zip_dag") - dag.sync_to_db() + DAG.bulk_write_to_db("testing", None, [dag]) SerializedDagModel.write_dag(dag) - manager.last_dag_dir_refresh_time = time.monotonic() - 10 * 60 - manager._refresh_dag_dir() + + with configure_testing_dag_bundle(zipped_dag_path): + manager = DagFileProcessorManager(max_runs=1) + manager.run() + # Assert dag not deleted in SDM assert SerializedDagModel.has_dag("test_zip_dag") # assert code not deleted @@ -605,20 +625,23 @@ def test_refresh_dags_dir_doesnt_delete_zipped_dags(self, tmp_path): # assert dag still active assert dag.get_is_active() - def test_refresh_dags_dir_deactivates_deleted_zipped_dags(self, tmp_path): + def test_refresh_dags_dir_deactivates_deleted_zipped_dags(self, tmp_path, configure_testing_dag_bundle): """Test DagFileProcessorManager._refresh_dag_dir method""" - manager = DagFileProcessorManager(dag_directory=TEST_DAG_FOLDER, max_runs=1) dagbag = DagBag(dag_folder=tmp_path, include_examples=False) zipped_dag_path = os.path.join(TEST_DAGS_FOLDER, "test_zip.zip") dagbag.process_file(zipped_dag_path) dag = dagbag.get_dag("test_zip_dag") dag.sync_to_db() SerializedDagModel.write_dag(dag) - manager.last_dag_dir_refresh_time = time.monotonic() - 10 * 60 + + # TODO: this test feels a bit fragile - pointing at the zip directly causes the test to fail + # TODO: jed look at this more closely - bagbad then process_file?! 
# Mock might_contain_dag to mimic deleting the python file from the zip with mock.patch("airflow.dag_processing.manager.might_contain_dag", return_value=False): - manager._refresh_dag_dir() + with configure_testing_dag_bundle(TEST_DAGS_FOLDER): + manager = DagFileProcessorManager(max_runs=1) + manager.run() # Deleting the python file should not delete SDM for versioning sake assert SerializedDagModel.has_dag("test_zip_dag") @@ -635,7 +658,7 @@ def test_refresh_dags_dir_deactivates_deleted_zipped_dags(self, tmp_path): ("scheduler", "standalone_dag_processor"): "True", } ) - def test_fetch_callbacks_from_database(self, tmp_path): + def test_fetch_callbacks_from_database(self, tmp_path, configure_testing_dag_bundle): dag_filepath = TEST_DAG_FOLDER / "test_on_failure_callback_dag.py" callback1 = DagCallbackRequest( @@ -655,13 +678,12 @@ def test_fetch_callbacks_from_database(self, tmp_path): session.add(DbCallbackRequest(callback=callback1, priority_weight=11)) session.add(DbCallbackRequest(callback=callback2, priority_weight=10)) - manager = DagFileProcessorManager( - dag_directory=os.fspath(tmp_path), max_runs=1, standalone_dag_processor=True - ) + with configure_testing_dag_bundle(tmp_path): + manager = DagFileProcessorManager(max_runs=1, standalone_dag_processor=True) - with create_session() as session: - self.run_processor_manager_one_loop(manager) - assert session.query(DbCallbackRequest).count() == 0 + with create_session() as session: + manager.run() + assert session.query(DbCallbackRequest).count() == 0 @conf_vars( { @@ -670,7 +692,7 @@ def test_fetch_callbacks_from_database(self, tmp_path): ("core", "load_examples"): "False", } ) - def test_fetch_callbacks_from_database_max_per_loop(self, tmp_path): + def test_fetch_callbacks_from_database_max_per_loop(self, tmp_path, configure_testing_dag_bundle): """Test DagFileProcessorManager._fetch_callbacks method""" dag_filepath = TEST_DAG_FOLDER / "test_on_failure_callback_dag.py" @@ -684,15 +706,16 @@ def test_fetch_callbacks_from_database_max_per_loop(self, tmp_path): ) session.add(DbCallbackRequest(callback=callback, priority_weight=i)) - manager = DagFileProcessorManager(dag_directory=tmp_path, max_runs=1) + with configure_testing_dag_bundle(tmp_path): + manager = DagFileProcessorManager(max_runs=1) - with create_session() as session: - self.run_processor_manager_one_loop(manager) - assert session.query(DbCallbackRequest).count() == 3 + with create_session() as session: + manager.run() + assert session.query(DbCallbackRequest).count() == 3 - with create_session() as session: - self.run_processor_manager_one_loop(manager) - assert session.query(DbCallbackRequest).count() == 1 + with create_session() as session: + manager.run() + assert session.query(DbCallbackRequest).count() == 1 @conf_vars( { @@ -700,7 +723,7 @@ def test_fetch_callbacks_from_database_max_per_loop(self, tmp_path): ("core", "load_examples"): "False", } ) - def test_fetch_callbacks_from_database_not_standalone(self, tmp_path): + def test_fetch_callbacks_from_database_not_standalone(self, tmp_path, configure_testing_dag_bundle): dag_filepath = TEST_DAG_FOLDER / "test_on_failure_callback_dag.py" with create_session() as session: @@ -712,22 +735,23 @@ def test_fetch_callbacks_from_database_not_standalone(self, tmp_path): ) session.add(DbCallbackRequest(callback=callback, priority_weight=10)) - manager = DagFileProcessorManager(dag_directory=tmp_path, max_runs=1) - - self.run_processor_manager_one_loop(manager) + with configure_testing_dag_bundle(tmp_path): + manager = 
DagFileProcessorManager(max_runs=1) + manager.run() # Verify no callbacks removed from database. with create_session() as session: assert session.query(DbCallbackRequest).count() == 1 + @pytest.mark.skip("AIP-66: callbacks are not implemented yet") def test_callback_queue(self, tmp_path): # given manager = DagFileProcessorManager( - dag_directory=TEST_DAG_FOLDER, max_runs=1, processor_timeout=365 * 86_400, ) + dag1_path = DagFileInfo(bundle_name="testing", path="/green_eggs/ham/file1.py") dag1_req1 = DagCallbackRequest( full_filepath="/green_eggs/ham/file1.py", dag_id="dag1", @@ -743,6 +767,7 @@ def test_callback_queue(self, tmp_path): msg=None, ) + dag2_path = DagFileInfo(bundle_name="testing", path="/green_eggs/ham/file2.py") dag2_req1 = DagCallbackRequest( full_filepath="/green_eggs/ham/file2.py", dag_id="dag2", @@ -756,7 +781,7 @@ def test_callback_queue(self, tmp_path): manager._add_callback_to_queue(dag2_req1) # then - requests should be in manager's queue, with dag2 ahead of dag1 (because it was added last) - assert manager._file_path_queue == deque([dag2_req1.full_filepath, dag1_req1.full_filepath]) + assert manager._file_path_queue == deque([dag2_path, dag1_path]) assert set(manager._callback_to_execute.keys()) == { dag1_req1.full_filepath, dag2_req1.full_filepath, @@ -787,24 +812,106 @@ def test_callback_queue(self, tmp_path): # And removed from the queue assert dag1_req1.full_filepath not in manager._callback_to_execute - def test_dag_with_assets(self, session): + def test_dag_with_assets(self, session, configure_testing_dag_bundle): """'Integration' test to ensure that the assets get parsed and stored correctly for parsed dags.""" test_dag_path = str(TEST_DAG_FOLDER / "test_assets.py") - manager = DagFileProcessorManager( - dag_directory=test_dag_path, - max_runs=1, - processor_timeout=365 * 86_400, - ) - - self.run_processor_manager_one_loop(manager) + with configure_testing_dag_bundle(test_dag_path): + manager = DagFileProcessorManager( + max_runs=1, + processor_timeout=365 * 86_400, + ) + manager.run() dag_model = session.get(DagModel, ("dag_with_skip_task")) assert dag_model.task_outlet_asset_references == [ TaskOutletAssetReference(asset_id=mock.ANY, dag_id="dag_with_skip_task", task_id="skip_task") ] + def test_bundles_are_refreshed(self): + """ + Ensure bundles are refreshed by the manager, when necessary. 
+ + - refresh if the bundle hasn't been refreshed in the refresh_interval + - when the latest_version in the db doesn't match the version this parser knows about + """ + config = [ + { + "name": "bundleone", + "classpath": "airflow.dag_processing.bundles.local.LocalDagBundle", + "kwargs": {"local_folder": "/dev/null", "refresh_interval": 0}, + }, + { + "name": "bundletwo", + "classpath": "airflow.dag_processing.bundles.local.LocalDagBundle", + "kwargs": {"local_folder": "/dev/null", "refresh_interval": 300}, + }, + ] + + bundleone = MagicMock() + bundleone.name = "bundleone" + bundleone.refresh_interval = 0 + bundleone.get_current_version.return_value = None + bundletwo = MagicMock() + bundletwo.name = "bundletwo" + bundletwo.refresh_interval = 300 + bundletwo.get_current_version.return_value = None + + with conf_vars({("dag_bundles", "backends"): json.dumps(config)}): + DagBundlesManager().sync_bundles_to_db() + with mock.patch( + "airflow.dag_processing.bundles.manager.DagBundlesManager" + ) as mock_bundle_manager: + mock_bundle_manager.return_value._bundle_config = {"bundleone": None, "bundletwo": None} + mock_bundle_manager.return_value.get_all_dag_bundles.return_value = [bundleone, bundletwo] + manager = DagFileProcessorManager(max_runs=1) + manager.run() + bundleone.refresh.assert_called_once() + bundletwo.refresh.assert_called_once() + + # Now, we should only refresh bundleone, as haven't hit the refresh_interval for bundletwo + bundleone.reset_mock() + bundletwo.reset_mock() + manager.run() + bundleone.refresh.assert_called_once() + bundletwo.refresh.assert_not_called() + + # however, if the version doesn't match, we should still refresh + bundletwo.reset_mock() + bundletwo.get_current_version.return_value = "123" + manager.run() + bundletwo.refresh.assert_called_once() + + def test_bundles_versions_are_stored(self): + config = [ + { + "name": "mybundle", + "classpath": "airflow.dag_processing.bundles.local.LocalDagBundle", + "kwargs": {"local_folder": "/dev/null", "refresh_interval": 0}, + }, + ] + + mybundle = MagicMock() + mybundle.name = "bundleone" + mybundle.refresh_interval = 0 + mybundle.supports_versioning = True + mybundle.get_current_version.return_value = "123" + + with conf_vars({("dag_bundles", "backends"): json.dumps(config)}): + DagBundlesManager().sync_bundles_to_db() + with mock.patch( + "airflow.dag_processing.bundles.manager.DagBundlesManager" + ) as mock_bundle_manager: + mock_bundle_manager.return_value._bundle_config = {"bundleone": None} + mock_bundle_manager.return_value.get_all_dag_bundles.return_value = [mybundle] + manager = DagFileProcessorManager(max_runs=1) + manager.run() + + with create_session() as session: + model = session.get(DagBundleModel, "bundleone") + assert model.latest_version == "123" + class TestDagFileProcessorAgent: @pytest.fixture(autouse=True) @@ -815,14 +922,12 @@ def _disable_examples(self): def test_launch_process(self): from airflow.configuration import conf - test_dag_path = TEST_DAG_FOLDER / "test_scheduler_dags.py" - log_file_loc = conf.get("logging", "DAG_PROCESSOR_MANAGER_LOG_LOCATION") with suppress(OSError): os.remove(log_file_loc) # Starting dag processing with 0 max_runs to avoid redundant operations. 
- processor_agent = DagFileProcessorAgent(test_dag_path, 0, timedelta(days=365)) + processor_agent = DagFileProcessorAgent(0, timedelta(days=365)) processor_agent.start() processor_agent._process.join() @@ -830,25 +935,25 @@ def test_launch_process(self): assert os.path.isfile(log_file_loc) def test_get_callbacks_pipe(self): - processor_agent = DagFileProcessorAgent("", 1, timedelta(days=365)) + processor_agent = DagFileProcessorAgent(1, timedelta(days=365)) processor_agent._parent_signal_conn = Mock() retval = processor_agent.get_callbacks_pipe() assert retval == processor_agent._parent_signal_conn def test_get_callbacks_pipe_no_parent_signal_conn(self): - processor_agent = DagFileProcessorAgent("", 1, timedelta(days=365)) + processor_agent = DagFileProcessorAgent(1, timedelta(days=365)) processor_agent._parent_signal_conn = None with pytest.raises(ValueError, match="Process not started"): processor_agent.get_callbacks_pipe() def test_heartbeat_no_parent_signal_conn(self): - processor_agent = DagFileProcessorAgent("", 1, timedelta(days=365)) + processor_agent = DagFileProcessorAgent(1, timedelta(days=365)) processor_agent._parent_signal_conn = None with pytest.raises(ValueError, match="Process not started"): processor_agent.heartbeat() def test_heartbeat_poll_eof_error(self): - processor_agent = DagFileProcessorAgent("", 1, timedelta(days=365)) + processor_agent = DagFileProcessorAgent(1, timedelta(days=365)) processor_agent._parent_signal_conn = Mock() processor_agent._parent_signal_conn.poll.return_value = True processor_agent._parent_signal_conn.recv = Mock() @@ -857,7 +962,7 @@ def test_heartbeat_poll_eof_error(self): assert ret_val is None def test_heartbeat_poll_connection_error(self): - processor_agent = DagFileProcessorAgent("", 1, timedelta(days=365)) + processor_agent = DagFileProcessorAgent(1, timedelta(days=365)) processor_agent._parent_signal_conn = Mock() processor_agent._parent_signal_conn.poll.return_value = True processor_agent._parent_signal_conn.recv = Mock() @@ -866,7 +971,7 @@ def test_heartbeat_poll_connection_error(self): assert ret_val is None def test_heartbeat_poll_process_message(self): - processor_agent = DagFileProcessorAgent("", 1, timedelta(days=365)) + processor_agent = DagFileProcessorAgent(1, timedelta(days=365)) processor_agent._parent_signal_conn = Mock() processor_agent._parent_signal_conn.poll.side_effect = [True, False] processor_agent._parent_signal_conn.recv = Mock() @@ -877,13 +982,13 @@ def test_heartbeat_poll_process_message(self): def test_process_message_invalid_type(self): message = "xyz" - processor_agent = DagFileProcessorAgent("", 1, timedelta(days=365)) + processor_agent = DagFileProcessorAgent(1, timedelta(days=365)) with pytest.raises(RuntimeError, match="Unexpected message received of type str"): processor_agent._process_message(message) @mock.patch("airflow.utils.process_utils.reap_process_group") def test_heartbeat_manager_process_restart(self, mock_pg, monkeypatch): - processor_agent = DagFileProcessorAgent("", 1, timedelta(days=365)) + processor_agent = DagFileProcessorAgent(1, timedelta(days=365)) processor_agent._parent_signal_conn = Mock() processor_agent._process = MagicMock() monkeypatch.setattr(processor_agent._process, "pid", 1234) @@ -899,7 +1004,7 @@ def test_heartbeat_manager_process_restart(self, mock_pg, monkeypatch): @mock.patch("time.monotonic") @mock.patch("airflow.dag_processing.manager.reap_process_group") def test_heartbeat_manager_process_reap(self, mock_pg, mock_time_monotonic, mock_stats): - processor_agent = 
DagFileProcessorAgent("", 1, timedelta(days=365)) + processor_agent = DagFileProcessorAgent(1, timedelta(days=365)) processor_agent._parent_signal_conn = Mock() processor_agent._process = Mock() processor_agent._process.pid = 12345 @@ -920,7 +1025,7 @@ def test_heartbeat_manager_process_reap(self, mock_pg, mock_time_monotonic, mock processor_agent.start.assert_called() def test_heartbeat_manager_end_no_process(self): - processor_agent = DagFileProcessorAgent("", 1, timedelta(days=365)) + processor_agent = DagFileProcessorAgent(1, timedelta(days=365)) processor_agent._process = Mock() processor_agent._process.__bool__ = Mock(return_value=False) processor_agent._process.side_effect = [None] @@ -931,26 +1036,25 @@ def test_heartbeat_manager_end_no_process(self): processor_agent._process.join.assert_not_called() @pytest.mark.execution_timeout(5) - def test_terminate(self, tmp_path): - processor_agent = DagFileProcessorAgent(tmp_path, -1, timedelta(days=365)) + def test_terminate(self, tmp_path, configure_testing_dag_bundle): + with configure_testing_dag_bundle(tmp_path): + processor_agent = DagFileProcessorAgent(-1, timedelta(days=365)) - processor_agent.start() - try: - processor_agent.terminate() + processor_agent.start() + try: + processor_agent.terminate() - processor_agent._process.join(timeout=1) - assert processor_agent._process.is_alive() is False - assert processor_agent._process.exitcode == 0 - except Exception: - reap_process_group(processor_agent._process.pid, logger=logger) - raise + processor_agent._process.join(timeout=1) + assert processor_agent._process.is_alive() is False + assert processor_agent._process.exitcode == 0 + except Exception: + reap_process_group(processor_agent._process.pid, logger=logger) + raise @conf_vars({("logging", "dag_processor_manager_log_stdout"): "True"}) def test_log_to_stdout(self, capfd): - test_dag_path = TEST_DAG_FOLDER / "test_scheduler_dags.py" - # Starting dag processing with 0 max_runs to avoid redundant operations. - processor_agent = DagFileProcessorAgent(test_dag_path, 0, timedelta(days=365)) + processor_agent = DagFileProcessorAgent(0, timedelta(days=365)) processor_agent.start() processor_agent._process.join() @@ -961,10 +1065,8 @@ def test_log_to_stdout(self, capfd): @conf_vars({("logging", "dag_processor_manager_log_stdout"): "False"}) def test_not_log_to_stdout(self, capfd): - test_dag_path = TEST_DAG_FOLDER / "test_scheduler_dags.py" - # Starting dag processing with 0 max_runs to avoid redundant operations. 
- processor_agent = DagFileProcessorAgent(test_dag_path, 0, timedelta(days=365)) + processor_agent = DagFileProcessorAgent(0, timedelta(days=365)) processor_agent.start() processor_agent._process.join() diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py index b9db6f2751d5f..497ae4fb1d441 100644 --- a/tests/jobs/test_scheduler_job.py +++ b/tests/jobs/test_scheduler_job.py @@ -107,6 +107,7 @@ ELASTIC_DAG_FILE = os.path.join(PERF_DAGS_FOLDER, "elastic_dag.py") TEST_DAG_FOLDER = os.environ["AIRFLOW__CORE__DAGS_FOLDER"] +EXAMPLE_DAGS_FOLDER = airflow.example_dags.__path__[0] DEFAULT_DATE = timezone.datetime(2016, 1, 1) DEFAULT_LOGICAL_DATE = timezone.coerce_datetime(DEFAULT_DATE) TRY_NUMBER = 1 @@ -119,12 +120,6 @@ def disable_load_example(): yield -@pytest.fixture -def load_examples(): - with conf_vars({("core", "load_examples"): "True"}): - yield - - # Patch the MockExecutor into the dict of known executors in the Loader @contextlib.contextmanager def _loader_mock(mock_executors): @@ -579,26 +574,32 @@ def test_execute_task_instances_backfill_tasks_will_execute(self, dag_maker): session.rollback() @conf_vars({("scheduler", "standalone_dag_processor"): "False"}) - def test_setup_callback_sink_not_standalone_dag_processor(self, mock_executors): - scheduler_job = Job() - self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull, num_runs=1) - self.job_runner._execute() + def test_setup_callback_sink_not_standalone_dag_processor( + self, mock_executors, configure_testing_dag_bundle + ): + with configure_testing_dag_bundle(os.devnull): + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job, num_runs=1) + self.job_runner._execute() - assert isinstance(scheduler_job.executor.callback_sink, PipeCallbackSink) + assert isinstance(scheduler_job.executor.callback_sink, PipeCallbackSink) @conf_vars({("scheduler", "standalone_dag_processor"): "False"}) - def test_setup_callback_sink_not_standalone_dag_processor_multiple_executors(self, mock_executors): - scheduler_job = Job() - self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull, num_runs=1) - self.job_runner._execute() + def test_setup_callback_sink_not_standalone_dag_processor_multiple_executors( + self, mock_executors, configure_testing_dag_bundle + ): + with configure_testing_dag_bundle(os.devnull): + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job, num_runs=1) + self.job_runner._execute() - for executor in scheduler_job.executors: - assert isinstance(executor.callback_sink, PipeCallbackSink) + for executor in scheduler_job.executors: + assert isinstance(executor.callback_sink, PipeCallbackSink) @conf_vars({("scheduler", "standalone_dag_processor"): "True"}) def test_setup_callback_sink_standalone_dag_processor(self, mock_executors): scheduler_job = Job() - self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull, num_runs=1) + self.job_runner = SchedulerJobRunner(job=scheduler_job, num_runs=1) self.job_runner._execute() assert isinstance(scheduler_job.executor.callback_sink, DatabaseCallbackSink) @@ -606,7 +607,7 @@ def test_setup_callback_sink_standalone_dag_processor(self, mock_executors): @conf_vars({("scheduler", "standalone_dag_processor"): "True"}) def test_setup_callback_sink_standalone_dag_processor_multiple_executors(self, mock_executors): scheduler_job = Job() - self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull, num_runs=1) + self.job_runner = 
SchedulerJobRunner(job=scheduler_job, num_runs=1) self.job_runner._execute() for executor in scheduler_job.executors: @@ -615,37 +616,40 @@ def test_setup_callback_sink_standalone_dag_processor_multiple_executors(self, m @conf_vars({("scheduler", "standalone_dag_processor"): "True"}) def test_executor_start_called(self, mock_executors): scheduler_job = Job() - self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull, num_runs=1) + self.job_runner = SchedulerJobRunner(job=scheduler_job, num_runs=1) self.job_runner._execute() scheduler_job.executor.start.assert_called_once() for executor in scheduler_job.executors: executor.start.assert_called_once() - def test_executor_job_id_assigned(self, mock_executors): - scheduler_job = Job() - self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull, num_runs=1) - self.job_runner._execute() + def test_executor_job_id_assigned(self, mock_executors, configure_testing_dag_bundle): + with configure_testing_dag_bundle(os.devnull): + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job, num_runs=1) + self.job_runner._execute() - assert scheduler_job.executor.job_id == scheduler_job.id - for executor in scheduler_job.executors: - assert executor.job_id == scheduler_job.id + assert scheduler_job.executor.job_id == scheduler_job.id + for executor in scheduler_job.executors: + assert executor.job_id == scheduler_job.id - def test_executor_heartbeat(self, mock_executors): - scheduler_job = Job() - self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull, num_runs=1) - self.job_runner._execute() + def test_executor_heartbeat(self, mock_executors, configure_testing_dag_bundle): + with configure_testing_dag_bundle(os.devnull): + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job, num_runs=1) + self.job_runner._execute() - for executor in scheduler_job.executors: - executor.heartbeat.assert_called_once() + for executor in scheduler_job.executors: + executor.heartbeat.assert_called_once() - def test_executor_events_processed(self, mock_executors): - scheduler_job = Job() - self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull, num_runs=1) - self.job_runner._execute() + def test_executor_events_processed(self, mock_executors, configure_testing_dag_bundle): + with configure_testing_dag_bundle(os.devnull): + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job, num_runs=1) + self.job_runner._execute() - for executor in scheduler_job.executors: - executor.get_event_buffer.assert_called_once() + for executor in scheduler_job.executors: + executor.get_event_buffer.assert_called_once() def test_executor_debug_dump(self, mock_executors): scheduler_job = Job() @@ -2891,17 +2895,17 @@ def my_task(): ... self.job_runner._do_scheduling(session) assert session.query(DagRun).one().state == run_state - def test_dagrun_root_after_dagrun_unfinished(self, mock_executor): + def test_dagrun_root_after_dagrun_unfinished(self, mock_executor, testing_dag_bundle): """ DagRuns with one successful and one future root task -> SUCCESS Noted: the DagRun state could be still in running state during CI. 
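Several scheduler tests above now wrap the run in configure_testing_dag_bundle(os.devnull) instead of passing subdir=os.devnull to SchedulerJobRunner. The fixture body is outside this hunk; a plausible sketch, assuming it simply routes parsing through a single local bundle rooted at the given path (the real fixture presumably lives in tests_common/pytest_plugin.py and may differ):

    import json
    from contextlib import contextmanager

    from tests_common.test_utils.config import conf_vars

    @contextmanager
    def configure_testing_dag_bundle(path):
        # Assumption: one LocalDagBundle named "testing" pointed at `path`.
        config = [
            {
                "name": "testing",
                "classpath": "airflow.dag_processing.bundles.local.LocalDagBundle",
                "kwargs": {"local_folder": str(path), "refresh_interval": 0},
            }
        ]
        with conf_vars({("dag_bundles", "backends"): json.dumps(config)}):
            yield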
""" dagbag = DagBag(TEST_DAG_FOLDER, include_examples=False) - dagbag.sync_to_db() + dagbag.sync_to_db("testing", None) dag_id = "test_dagrun_states_root_future" dag = dagbag.get_dag(dag_id) - dag.sync_to_db() + DAG.bulk_write_to_db("testing", None, [dag]) scheduler_job = Job() self.job_runner = SchedulerJobRunner(job=scheduler_job, num_runs=2, subdir=dag.fileloc) @@ -2920,7 +2924,7 @@ def test_dagrun_root_after_dagrun_unfinished(self, mock_executor): {("scheduler", "standalone_dag_processor"): "True"}, ], ) - def test_scheduler_start_date(self, configs): + def test_scheduler_start_date(self, configs, testing_dag_bundle): """ Test that the scheduler respects start_dates, even when DAGs have run """ @@ -2935,7 +2939,7 @@ def test_scheduler_start_date(self, configs): # Deactivate other dags in this file other_dag = dagbag.get_dag("test_task_start_date_scheduling") other_dag.is_paused_upon_creation = True - other_dag.sync_to_db() + DAG.bulk_write_to_db("testing", None, [other_dag]) scheduler_job = Job( executor=self.null_exec, ) @@ -2981,7 +2985,7 @@ def test_scheduler_start_date(self, configs): {("scheduler", "standalone_dag_processor"): "True"}, ], ) - def test_scheduler_task_start_date(self, configs): + def test_scheduler_task_start_date(self, configs, testing_dag_bundle): """ Test that the scheduler respects task start dates that are different from DAG start dates """ @@ -3000,7 +3004,7 @@ def test_scheduler_task_start_date(self, configs): other_dag.is_paused_upon_creation = True dagbag.bag_dag(dag=other_dag) - dagbag.sync_to_db() + dagbag.sync_to_db("testing", None) scheduler_job = Job( executor=self.null_exec, @@ -3553,21 +3557,7 @@ def test_list_py_file_paths(self): if file_name.endswith((".py", ".zip")): if file_name not in ignored_files: expected_files.add(f"{root}/{file_name}") - for file_path in list_py_file_paths(TEST_DAG_FOLDER, include_examples=False): - detected_files.add(file_path) - assert detected_files == expected_files - - ignored_files = { - "helper.py", - } - example_dag_folder = airflow.example_dags.__path__[0] - for root, _, files in os.walk(example_dag_folder): - for file_name in files: - if file_name.endswith((".py", ".zip")): - if file_name not in ["__init__.py"] and file_name not in ignored_files: - expected_files.add(os.path.join(root, file_name)) - detected_files.clear() - for file_path in list_py_file_paths(TEST_DAG_FOLDER, include_examples=True): + for file_path in list_py_file_paths(TEST_DAG_FOLDER): detected_files.add(file_path) assert detected_files == expected_files @@ -4207,7 +4197,7 @@ def test_scheduler_create_dag_runs_does_not_raise_error(self, caplog, dag_maker) "DAG 'test_scheduler_create_dag_runs_does_not_raise_error' not found in serialized_dag table", ] - def test_bulk_write_to_db_external_trigger_dont_skip_scheduled_run(self, dag_maker): + def test_bulk_write_to_db_external_trigger_dont_skip_scheduled_run(self, dag_maker, testing_dag_bundle): """ Test that externally triggered Dag Runs should not affect (by skipping) next scheduled DAG runs @@ -4265,7 +4255,7 @@ def test_bulk_write_to_db_external_trigger_dont_skip_scheduled_run(self, dag_mak ) assert dr is not None # Run DAG.bulk_write_to_db -- this is run when in DagFileProcessor.process_file - DAG.bulk_write_to_db([dag], session=session) + DAG.bulk_write_to_db("testing", None, [dag], session=session) # Test that 'dag_model.next_dagrun' has not been changed because of newly created external # triggered DagRun. 
@@ -4558,7 +4548,7 @@ def test_do_schedule_max_active_runs_and_manual_trigger(self, dag_maker, mock_ex dag_version = DagVersion.get_latest_version(dag.dag_id) dag_run = dag_maker.create_dagrun(state=State.QUEUED, session=session, dag_version=dag_version) - dag.sync_to_db(session=session) # Update the date fields + DAG.bulk_write_to_db("testing", None, [dag], session=session) # Update the date fields scheduler_job = Job() self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull) @@ -5653,11 +5643,11 @@ def test_find_and_purge_zombies_nothing(self): self.job_runner._find_and_purge_zombies() executor.callback_sink.send.assert_not_called() - def test_find_and_purge_zombies(self, load_examples, session): - dagbag = DagBag(TEST_DAG_FOLDER, read_dags_from_db=False) - + def test_find_and_purge_zombies(self, session, testing_dag_bundle): + dagfile = os.path.join(EXAMPLE_DAGS_FOLDER, "example_branch_operator.py") + dagbag = DagBag(dagfile) dag = dagbag.get_dag("example_branch_operator") - dag.sync_to_db() + DAG.bulk_write_to_db("testing", None, [dag]) data_interval = dag.infer_automated_data_interval(DEFAULT_LOGICAL_DATE) triggered_by_kwargs = {"triggered_by": DagRunTriggeredByType.TEST} if AIRFLOW_V_3_0_PLUS else {} dag_run = dag.create_dagrun( @@ -5709,70 +5699,74 @@ def test_find_and_purge_zombies(self, load_examples, session): assert callback_request.ti.run_id == ti.run_id assert callback_request.ti.map_index == ti.map_index - def test_zombie_message(self, load_examples): + def test_zombie_message(self, testing_dag_bundle, session): """ Check that the zombie message comes out as expected """ dagbag = DagBag(TEST_DAG_FOLDER, read_dags_from_db=False) - with create_session() as session: - session.query(Job).delete() - dag = dagbag.get_dag("example_branch_operator") - dag.sync_to_db() - - data_interval = dag.infer_automated_data_interval(DEFAULT_LOGICAL_DATE) - triggered_by_kwargs = {"triggered_by": DagRunTriggeredByType.TEST} if AIRFLOW_V_3_0_PLUS else {} - dag_run = dag.create_dagrun( - state=DagRunState.RUNNING, - logical_date=DEFAULT_DATE, - run_type=DagRunType.SCHEDULED, - session=session, - data_interval=data_interval, - **triggered_by_kwargs, - ) + dagfile = os.path.join(EXAMPLE_DAGS_FOLDER, "example_branch_operator.py") + dagbag = DagBag(dagfile) + dag = dagbag.get_dag("example_branch_operator") + DAG.bulk_write_to_db("testing", None, [dag]) - scheduler_job = Job(executor=MockExecutor()) - self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull) - self.job_runner.processor_agent = mock.MagicMock() + session.query(Job).delete() - # We will provision 2 tasks so we can check we only find zombies from this scheduler - tasks_to_setup = ["branching", "run_this_first"] + data_interval = dag.infer_automated_data_interval(DEFAULT_LOGICAL_DATE) + triggered_by_kwargs = {"triggered_by": DagRunTriggeredByType.TEST} if AIRFLOW_V_3_0_PLUS else {} + dag_run = dag.create_dagrun( + state=DagRunState.RUNNING, + logical_date=DEFAULT_DATE, + run_type=DagRunType.SCHEDULED, + session=session, + data_interval=data_interval, + **triggered_by_kwargs, + ) - for task_id in tasks_to_setup: - task = dag.get_task(task_id=task_id) - ti = TaskInstance(task, run_id=dag_run.run_id, state=State.RUNNING) - ti.queued_by_job_id = 999 + scheduler_job = Job(executor=MockExecutor()) + self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull) + self.job_runner.processor_agent = mock.MagicMock() - session.add(ti) - session.flush() + # We will provision 2 tasks so we can check we only 
find zombies from this scheduler + tasks_to_setup = ["branching", "run_this_first"] - assert task.task_id == "run_this_first" # Make sure we have the task/ti we expect + for task_id in tasks_to_setup: + task = dag.get_task(task_id=task_id) + ti = TaskInstance(task, run_id=dag_run.run_id, state=State.RUNNING) + ti.queued_by_job_id = 999 - ti.queued_by_job_id = scheduler_job.id + session.add(ti) session.flush() - zombie_message = self.job_runner._generate_zombie_message_details(ti) - assert zombie_message == { - "DAG Id": "example_branch_operator", - "Task Id": "run_this_first", - "Run Id": "scheduled__2016-01-01T00:00:00+00:00", - } - - ti.hostname = "10.10.10.10" - ti.map_index = 2 - ti.external_executor_id = "abcdefg" - - zombie_message = self.job_runner._generate_zombie_message_details(ti) - assert zombie_message == { - "DAG Id": "example_branch_operator", - "Task Id": "run_this_first", - "Run Id": "scheduled__2016-01-01T00:00:00+00:00", - "Hostname": "10.10.10.10", - "Map Index": 2, - "External Executor Id": "abcdefg", - } - - def test_find_zombies_handle_failure_callbacks_are_correctly_passed_to_dag_processor(self): + assert task.task_id == "run_this_first" # Make sure we have the task/ti we expect + + ti.queued_by_job_id = scheduler_job.id + session.flush() + + zombie_message = self.job_runner._generate_zombie_message_details(ti) + assert zombie_message == { + "DAG Id": "example_branch_operator", + "Task Id": "run_this_first", + "Run Id": "scheduled__2016-01-01T00:00:00+00:00", + } + + ti.hostname = "10.10.10.10" + ti.map_index = 2 + ti.external_executor_id = "abcdefg" + + zombie_message = self.job_runner._generate_zombie_message_details(ti) + assert zombie_message == { + "DAG Id": "example_branch_operator", + "Task Id": "run_this_first", + "Run Id": "scheduled__2016-01-01T00:00:00+00:00", + "Hostname": "10.10.10.10", + "Map Index": 2, + "External Executor Id": "abcdefg", + } + + def test_find_zombies_handle_failure_callbacks_are_correctly_passed_to_dag_processor( + self, testing_dag_bundle + ): """ Check that the same set of failure callback with zombies are passed to the dag file processors until the next zombie detection logic is invoked. @@ -5784,7 +5778,7 @@ def test_find_zombies_handle_failure_callbacks_are_correctly_passed_to_dag_proce ) session.query(Job).delete() dag = dagbag.get_dag("test_example_bash_operator") - dag.sync_to_db() + DAG.bulk_write_to_db("testing", None, [dag]) data_interval = dag.infer_automated_data_interval(DEFAULT_LOGICAL_DATE) triggered_by_kwargs = {"triggered_by": DagRunTriggeredByType.TEST} if AIRFLOW_V_3_0_PLUS else {} dag_run = dag.create_dagrun( @@ -5830,11 +5824,11 @@ def test_find_zombies_handle_failure_callbacks_are_correctly_passed_to_dag_proce callback_requests[0].ti = None assert expected_failure_callback_requests[0] == callback_requests[0] - def test_cleanup_stale_dags(self): + def test_cleanup_stale_dags(self, testing_dag_bundle): dagbag = DagBag(TEST_DAG_FOLDER, read_dags_from_db=False) with create_session() as session: dag = dagbag.get_dag("test_example_bash_operator") - dag.sync_to_db() + DAG.bulk_write_to_db("testing", None, [dag]) dm = DagModel.get_current("test_example_bash_operator") # Make it "stale". dm.last_parsed_time = timezone.utcnow() - timedelta(minutes=11) @@ -5842,7 +5836,7 @@ def test_cleanup_stale_dags(self): # This one should remain active. 
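test_cleanup_stale_dags marks a DAG as stale by pushing DagModel.last_parsed_time eleven minutes into the past while leaving a second DAG current. A sketch of the staleness condition this exercises, with the threshold assumed for illustration rather than read from the scheduler config:

    from datetime import timedelta

    from airflow.utils import timezone

    def is_stale(last_parsed_time, threshold=timedelta(minutes=10)) -> bool:
        # Assumed check: a DAG not re-parsed within the threshold gets cleaned up.
        return last_parsed_time is not None and last_parsed_time < timezone.utcnow() - threshold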
dag = dagbag.get_dag("test_start_date_scheduling") - dag.sync_to_db() + DAG.bulk_write_to_db("testing", None, [dag]) session.flush() @@ -5924,13 +5918,13 @@ def watch_heartbeat(*args, **kwargs): @pytest.mark.long_running @pytest.mark.parametrize("dag_id", ["test_mapped_classic", "test_mapped_taskflow"]) - def test_mapped_dag(self, dag_id, session): + def test_mapped_dag(self, dag_id, session, testing_dag_bundle): """End-to-end test of a simple mapped dag""" # Use SequentialExecutor for more predictable test behaviour from airflow.executors.sequential_executor import SequentialExecutor dagbag = DagBag(dag_folder=TEST_DAGS_FOLDER, include_examples=False) - dagbag.sync_to_db() + dagbag.sync_to_db("testing", None) dagbag.process_file(str(TEST_DAGS_FOLDER / f"{dag_id}.py")) dag = dagbag.get_dag(dag_id) assert dag @@ -5957,14 +5951,14 @@ def test_mapped_dag(self, dag_id, session): dr.refresh_from_db(session) assert dr.state == DagRunState.SUCCESS - def test_should_mark_empty_task_as_success(self): + def test_should_mark_empty_task_as_success(self, testing_dag_bundle): dag_file = os.path.join( os.path.dirname(os.path.realpath(__file__)), "../dags/test_only_empty_tasks.py" ) # Write DAGs to dag and serialized_dag table dagbag = DagBag(dag_folder=dag_file, include_examples=False, read_dags_from_db=False) - dagbag.sync_to_db() + dagbag.sync_to_db("testing", None) scheduler_job = Job() self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull) @@ -6035,7 +6029,7 @@ def test_should_mark_empty_task_as_success(self): assert duration is None @pytest.mark.need_serialized_dag - def test_catchup_works_correctly(self, dag_maker): + def test_catchup_works_correctly(self, dag_maker, testing_dag_bundle): """Test that catchup works correctly""" session = settings.Session() with dag_maker( @@ -6070,7 +6064,7 @@ def test_catchup_works_correctly(self, dag_maker): session.flush() dag.catchup = False - dag.sync_to_db() + DAG.bulk_write_to_db("testing", None, [dag]) assert not dag.catchup dm = DagModel.get_dagmodel(dag.dag_id) @@ -6279,7 +6273,7 @@ def test_misconfigured_dags_doesnt_crash_scheduler(self, session, dag_maker, cap # Check if the second dagrun was created assert DagRun.find(dag_id="testdag2", session=session) - def test_activate_referenced_assets_with_no_existing_warning(self, session): + def test_activate_referenced_assets_with_no_existing_warning(self, session, testing_dag_bundle): dag_warnings = session.query(DagWarning).all() assert dag_warnings == [] @@ -6292,7 +6286,7 @@ def test_activate_referenced_assets_with_no_existing_warning(self, session): asset1_2 = Asset(name="it's also a duplicate", uri="s3://bucket/key/1", extra=asset_extra) dag1 = DAG(dag_id=dag_id1, start_date=DEFAULT_DATE, schedule=[asset1, asset1_1, asset1_2]) - DAG.bulk_write_to_db([dag1], session=session) + DAG.bulk_write_to_db("testing", None, [dag1], session=session) asset_models = session.scalars(select(AssetModel)).all() assert len(asset_models) == 3 @@ -6313,7 +6307,7 @@ def test_activate_referenced_assets_with_no_existing_warning(self, session): "dy associated to 'asset1'" ) - def test_activate_referenced_assets_with_existing_warnings(self, session): + def test_activate_referenced_assets_with_existing_warnings(self, session, testing_dag_bundle): dag_ids = [f"test_asset_dag{i}" for i in range(1, 4)] asset1_name = "asset1" asset_extra = {"foo": "bar"} @@ -6332,7 +6326,7 @@ def test_activate_referenced_assets_with_existing_warnings(self, session): dag2 = DAG(dag_id=dag_ids[1], start_date=DEFAULT_DATE) dag3 = 
DAG(dag_id=dag_ids[2], start_date=DEFAULT_DATE, schedule=[asset1_2]) - DAG.bulk_write_to_db([dag1, dag2, dag3], session=session) + DAG.bulk_write_to_db("testing", None, [dag1, dag2, dag3], session=session) asset_models = session.scalars(select(AssetModel)).all() @@ -6366,7 +6360,9 @@ def test_activate_referenced_assets_with_existing_warnings(self, session): "name is already associated to 's3://bucket/key/1'" ) - def test_activate_referenced_assets_with_multiple_conflict_asset_in_one_dag(self, session): + def test_activate_referenced_assets_with_multiple_conflict_asset_in_one_dag( + self, session, testing_dag_bundle + ): dag_id = "test_asset_dag" asset1_name = "asset1" asset_extra = {"foo": "bar"} @@ -6379,7 +6375,7 @@ def test_activate_referenced_assets_with_multiple_conflict_asset_in_one_dag(self ) dag1 = DAG(dag_id=dag_id, start_date=DEFAULT_DATE, schedule=schedule) - DAG.bulk_write_to_db([dag1], session=session) + DAG.bulk_write_to_db("testing", None, [dag1], session=session) asset_models = session.scalars(select(AssetModel)).all() @@ -6469,7 +6465,9 @@ def per_test(self) -> Generator: (93, 10, 10), # 10 DAGs with 10 tasks per DAG file. ], ) - def test_execute_queries_count_with_harvested_dags(self, expected_query_count, dag_count, task_count): + def test_execute_queries_count_with_harvested_dags( + self, expected_query_count, dag_count, task_count, testing_dag_bundle + ): with ( mock.patch.dict( "os.environ", @@ -6496,7 +6494,7 @@ def test_execute_queries_count_with_harvested_dags(self, expected_query_count, d ): dagruns = [] dagbag = DagBag(dag_folder=ELASTIC_DAG_FILE, include_examples=False, read_dags_from_db=False) - dagbag.sync_to_db() + dagbag.sync_to_db("testing", None) dag_ids = dagbag.dag_ids dagbag = DagBag(read_dags_from_db=True) @@ -6566,7 +6564,7 @@ def test_execute_queries_count_with_harvested_dags(self, expected_query_count, d ], ) def test_process_dags_queries_count( - self, expected_query_counts, dag_count, task_count, start_ago, schedule, shape + self, expected_query_counts, dag_count, task_count, start_ago, schedule, shape, testing_dag_bundle ): with ( mock.patch.dict( @@ -6592,7 +6590,7 @@ def test_process_dags_queries_count( ), ): dagbag = DagBag(dag_folder=ELASTIC_DAG_FILE, include_examples=False) - dagbag.sync_to_db() + dagbag.sync_to_db("testing", None) mock_agent = mock.MagicMock() diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py index 6eaa4e3ac3aeb..a63b05b0cd343 100644 --- a/tests/models/test_dag.py +++ b/tests/models/test_dag.py @@ -623,7 +623,7 @@ def test_dagtag_repr(self): repr(t) for t in session.query(DagTag).filter(DagTag.dag_id == "dag-test-dagtag").all() } - def test_bulk_write_to_db(self): + def test_bulk_write_to_db(self, testing_dag_bundle): clear_db_dags() dags = [ DAG(f"dag-bulk-sync-{i}", schedule=None, start_date=DEFAULT_DATE, tags=["test-dag"]) @@ -631,7 +631,7 @@ def test_bulk_write_to_db(self): ] with assert_queries_count(6): - DAG.bulk_write_to_db(dags) + DAG.bulk_write_to_db("testing", None, dags) with create_session() as session: assert {"dag-bulk-sync-0", "dag-bulk-sync-1", "dag-bulk-sync-2", "dag-bulk-sync-3"} == { row[0] for row in session.query(DagModel.dag_id).all() @@ -648,14 +648,14 @@ def test_bulk_write_to_db(self): # Re-sync should do fewer queries with assert_queries_count(9): - DAG.bulk_write_to_db(dags) + DAG.bulk_write_to_db("testing", None, dags) with assert_queries_count(9): - DAG.bulk_write_to_db(dags) + DAG.bulk_write_to_db("testing", None, dags) # Adding tags for dag in dags: dag.tags.add("test-dag2") 
with assert_queries_count(10): - DAG.bulk_write_to_db(dags) + DAG.bulk_write_to_db("testing", None, dags) with create_session() as session: assert {"dag-bulk-sync-0", "dag-bulk-sync-1", "dag-bulk-sync-2", "dag-bulk-sync-3"} == { row[0] for row in session.query(DagModel.dag_id).all() @@ -674,7 +674,7 @@ def test_bulk_write_to_db(self): for dag in dags: dag.tags.remove("test-dag") with assert_queries_count(10): - DAG.bulk_write_to_db(dags) + DAG.bulk_write_to_db("testing", None, dags) with create_session() as session: assert {"dag-bulk-sync-0", "dag-bulk-sync-1", "dag-bulk-sync-2", "dag-bulk-sync-3"} == { row[0] for row in session.query(DagModel.dag_id).all() @@ -693,7 +693,7 @@ def test_bulk_write_to_db(self): for dag in dags: dag.tags = set() with assert_queries_count(10): - DAG.bulk_write_to_db(dags) + DAG.bulk_write_to_db("testing", None, dags) with create_session() as session: assert {"dag-bulk-sync-0", "dag-bulk-sync-1", "dag-bulk-sync-2", "dag-bulk-sync-3"} == { row[0] for row in session.query(DagModel.dag_id).all() @@ -703,7 +703,7 @@ def test_bulk_write_to_db(self): for row in session.query(DagModel.last_parsed_time).all(): assert row[0] is not None - def test_bulk_write_to_db_single_dag(self): + def test_bulk_write_to_db_single_dag(self, testing_dag_bundle): """ Test bulk_write_to_db for a single dag using the index optimized query """ @@ -714,7 +714,7 @@ def test_bulk_write_to_db_single_dag(self): ] with assert_queries_count(6): - DAG.bulk_write_to_db(dags) + DAG.bulk_write_to_db("testing", None, dags) with create_session() as session: assert {"dag-bulk-sync-0"} == {row[0] for row in session.query(DagModel.dag_id).all()} assert { @@ -726,11 +726,11 @@ def test_bulk_write_to_db_single_dag(self): # Re-sync should do fewer queries with assert_queries_count(8): - DAG.bulk_write_to_db(dags) + DAG.bulk_write_to_db("testing", None, dags) with assert_queries_count(8): - DAG.bulk_write_to_db(dags) + DAG.bulk_write_to_db("testing", None, dags) - def test_bulk_write_to_db_multiple_dags(self): + def test_bulk_write_to_db_multiple_dags(self, testing_dag_bundle): """ Test bulk_write_to_db for multiple dags which does not use the index optimized query """ @@ -741,7 +741,7 @@ def test_bulk_write_to_db_multiple_dags(self): ] with assert_queries_count(6): - DAG.bulk_write_to_db(dags) + DAG.bulk_write_to_db("testing", None, dags) with create_session() as session: assert {"dag-bulk-sync-0", "dag-bulk-sync-1", "dag-bulk-sync-2", "dag-bulk-sync-3"} == { row[0] for row in session.query(DagModel.dag_id).all() @@ -758,26 +758,26 @@ def test_bulk_write_to_db_multiple_dags(self): # Re-sync should do fewer queries with assert_queries_count(9): - DAG.bulk_write_to_db(dags) + DAG.bulk_write_to_db("testing", None, dags) with assert_queries_count(9): - DAG.bulk_write_to_db(dags) + DAG.bulk_write_to_db("testing", None, dags) @pytest.mark.parametrize("interval", [None, "@daily"]) - def test_bulk_write_to_db_interval_save_runtime(self, interval): + def test_bulk_write_to_db_interval_save_runtime(self, testing_dag_bundle, interval): mock_active_runs_of_dags = mock.MagicMock(side_effect=DagRun.active_runs_of_dags) with mock.patch.object(DagRun, "active_runs_of_dags", mock_active_runs_of_dags): dags_null_timetable = [ DAG("dag-interval-None", schedule=None, start_date=TEST_DATE), DAG("dag-interval-test", schedule=interval, start_date=TEST_DATE), ] - DAG.bulk_write_to_db(dags_null_timetable, session=settings.Session()) + DAG.bulk_write_to_db("testing", None, dags_null_timetable, session=settings.Session()) if interval: 
mock_active_runs_of_dags.assert_called_once() else: mock_active_runs_of_dags.assert_not_called() @pytest.mark.parametrize("state", [DagRunState.RUNNING, DagRunState.QUEUED]) - def test_bulk_write_to_db_max_active_runs(self, state): + def test_bulk_write_to_db_max_active_runs(self, testing_dag_bundle, state): """ Test that DagModel.next_dagrun_create_after is set to NULL when the dag cannot be created due to max active runs being hit. @@ -793,7 +793,7 @@ def test_bulk_write_to_db_max_active_runs(self, state): session = settings.Session() dag.clear() - DAG.bulk_write_to_db([dag], session=session) + DAG.bulk_write_to_db("testing", None, [dag], session=session) model = session.get(DagModel, dag.dag_id) @@ -810,17 +810,17 @@ def test_bulk_write_to_db_max_active_runs(self, state): **triggered_by_kwargs, ) assert dr is not None - DAG.bulk_write_to_db([dag]) + DAG.bulk_write_to_db("testing", None, [dag]) model = session.get(DagModel, dag.dag_id) # We signal "at max active runs" by saying this run is never eligible to be created assert model.next_dagrun_create_after is None # test that bulk_write_to_db again doesn't update next_dagrun_create_after - DAG.bulk_write_to_db([dag]) + DAG.bulk_write_to_db("testing", None, [dag]) model = session.get(DagModel, dag.dag_id) assert model.next_dagrun_create_after is None - def test_bulk_write_to_db_has_import_error(self): + def test_bulk_write_to_db_has_import_error(self, testing_dag_bundle): """ Test that DagModel.has_import_error is set to false if no import errors. """ @@ -830,7 +830,7 @@ def test_bulk_write_to_db_has_import_error(self): session = settings.Session() dag.clear() - DAG.bulk_write_to_db([dag], session=session) + DAG.bulk_write_to_db("testing", None, [dag], session=session) model = session.get(DagModel, dag.dag_id) @@ -844,14 +844,14 @@ def test_bulk_write_to_db_has_import_error(self): # assert assert model.has_import_errors # parse - DAG.bulk_write_to_db([dag]) + DAG.bulk_write_to_db("testing", None, [dag]) model = session.get(DagModel, dag.dag_id) # assert that has_import_error is now false assert not model.has_import_errors session.close() - def test_bulk_write_to_db_assets(self): + def test_bulk_write_to_db_assets(self, testing_dag_bundle): """ Ensure that assets referenced in a dag are correctly loaded into the database. """ @@ -877,7 +877,7 @@ def test_bulk_write_to_db_assets(self): session = settings.Session() dag1.clear() - DAG.bulk_write_to_db([dag1, dag2], session=session) + DAG.bulk_write_to_db("testing", None, [dag1, dag2], session=session) session.commit() stored_assets = {x.uri: x for x in session.query(AssetModel).all()} asset1_orm = stored_assets[a1.uri] @@ -908,7 +908,7 @@ def test_bulk_write_to_db_assets(self): EmptyOperator(task_id=task_id, dag=dag1, outlets=[a2]) dag2 = DAG(dag_id=dag_id2, start_date=DEFAULT_DATE, schedule=None) EmptyOperator(task_id=task_id, dag=dag2) - DAG.bulk_write_to_db([dag1, dag2], session=session) + DAG.bulk_write_to_db("testing", None, [dag1, dag2], session=session) session.commit() session.expunge_all() stored_assets = {x.uri: x for x in session.query(AssetModel).all()} @@ -937,7 +937,7 @@ def _find_assets_activation(session) -> tuple[list[AssetModel], list[AssetModel] ).all() return [a for a, v in assets if not v], [a for a, v in assets if v] - def test_bulk_write_to_db_does_not_activate(self, dag_maker, session): + def test_bulk_write_to_db_does_not_activate(self, dag_maker, testing_dag_bundle, session): """ Assets are not activated on write, but later in the scheduler by the SchedulerJob. 
""" @@ -950,14 +950,14 @@ def test_bulk_write_to_db_does_not_activate(self, dag_maker, session): dag1 = DAG(dag_id="assets-1", start_date=DEFAULT_DATE, schedule=[asset1]) BashOperator(dag=dag1, task_id="task", bash_command="echo 1", outlets=[asset3]) - DAG.bulk_write_to_db([dag1], session=session) + DAG.bulk_write_to_db("testing", None, [dag1], session=session) assert session.scalars(select(AssetModel).order_by(AssetModel.uri)).all() == [asset1, asset3] assert session.scalars(select(AssetActive)).all() == [] dag1 = DAG(dag_id="assets-1", start_date=DEFAULT_DATE, schedule=[asset1, asset2]) BashOperator(dag=dag1, task_id="task", bash_command="echo 1", outlets=[asset3, asset4]) - DAG.bulk_write_to_db([dag1], session=session) + DAG.bulk_write_to_db("testing", None, [dag1], session=session) assert session.scalars(select(AssetModel).order_by(AssetModel.uri)).all() == [ asset1, @@ -967,7 +967,7 @@ def test_bulk_write_to_db_does_not_activate(self, dag_maker, session): ] assert session.scalars(select(AssetActive)).all() == [] - def test_bulk_write_to_db_asset_aliases(self): + def test_bulk_write_to_db_asset_aliases(self, testing_dag_bundle): """ Ensure that asset aliases referenced in a dag are correctly loaded into the database. """ @@ -983,7 +983,7 @@ def test_bulk_write_to_db_asset_aliases(self): dag2 = DAG(dag_id=dag_id2, start_date=DEFAULT_DATE, schedule=None) EmptyOperator(task_id=task_id, dag=dag2, outlets=[asset_alias_2_2, asset_alias_3]) session = settings.Session() - DAG.bulk_write_to_db([dag1, dag2], session=session) + DAG.bulk_write_to_db("testing", None, [dag1, dag2], session=session) session.commit() stored_asset_alias_models = {x.name: x for x in session.query(AssetAliasModel).all()} @@ -2431,7 +2431,7 @@ def test_dags_needing_dagruns_asset_triggered_dag_info_queued_times(self, sessio assert first_queued_time == DEFAULT_DATE assert last_queued_time == DEFAULT_DATE + timedelta(hours=1) - def test_asset_expression(self, session: Session) -> None: + def test_asset_expression(self, testing_dag_bundle, session: Session) -> None: dag = DAG( dag_id="test_dag_asset_expression", schedule=AssetAny( @@ -2449,7 +2449,7 @@ def test_asset_expression(self, session: Session) -> None: ), start_date=datetime.datetime.min, ) - DAG.bulk_write_to_db([dag], session=session) + DAG.bulk_write_to_db("testing", None, [dag], session=session) expression = session.scalars(select(DagModel.asset_expression).filter_by(dag_id=dag.dag_id)).one() assert expression == { diff --git a/tests/models/test_dagcode.py b/tests/models/test_dagcode.py index 05598703476a5..ac7330b1aa581 100644 --- a/tests/models/test_dagcode.py +++ b/tests/models/test_dagcode.py @@ -41,8 +41,17 @@ def make_example_dags(module): """Loads DAGs from a module for test.""" + # TODO: AIP-66 dedup with tests/models/test_serdag + from airflow.models.dagbundle import DagBundleModel + from airflow.utils.session import create_session + + with create_session() as session: + if session.query(DagBundleModel).filter(DagBundleModel.name == "testing").count() == 0: + testing = DagBundleModel(name="testing") + session.add(testing) + dagbag = DagBag(module.__path__[0]) - DAG.bulk_write_to_db(dagbag.dags.values()) + DAG.bulk_write_to_db("testing", None, dagbag.dags.values()) return dagbag.dags diff --git a/tests/models/test_dagrun.py b/tests/models/test_dagrun.py index 13614d6abeaea..f398be8ad88c4 100644 --- a/tests/models/test_dagrun.py +++ b/tests/models/test_dagrun.py @@ -518,7 +518,7 @@ def test_on_success_callback_when_task_skipped(self, session): assert 
dag_run.state == DagRunState.SUCCESS mock_on_success.assert_called_once() - def test_dagrun_update_state_with_handle_callback_success(self, session): + def test_dagrun_update_state_with_handle_callback_success(self, testing_dag_bundle, session): def on_success_callable(context): assert context["dag_run"].dag_id == "test_dagrun_update_state_with_handle_callback_success" @@ -528,7 +528,7 @@ def on_success_callable(context): start_date=datetime.datetime(2017, 1, 1), on_success_callback=on_success_callable, ) - DAG.bulk_write_to_db(dags=[dag], session=session) + DAG.bulk_write_to_db("testing", None, dags=[dag], session=session) dag_task1 = EmptyOperator(task_id="test_state_succeeded1", dag=dag) dag_task2 = EmptyOperator(task_id="test_state_succeeded2", dag=dag) @@ -556,7 +556,7 @@ def on_success_callable(context): msg="success", ) - def test_dagrun_update_state_with_handle_callback_failure(self, session): + def test_dagrun_update_state_with_handle_callback_failure(self, testing_dag_bundle, session): def on_failure_callable(context): assert context["dag_run"].dag_id == "test_dagrun_update_state_with_handle_callback_failure" @@ -566,7 +566,7 @@ def on_failure_callable(context): start_date=datetime.datetime(2017, 1, 1), on_failure_callback=on_failure_callable, ) - DAG.bulk_write_to_db(dags=[dag], session=session) + DAG.bulk_write_to_db("testing", None, dags=[dag], session=session) dag_task1 = EmptyOperator(task_id="test_state_succeeded1", dag=dag) dag_task2 = EmptyOperator(task_id="test_state_failed2", dag=dag) diff --git a/tests/models/test_serialized_dag.py b/tests/models/test_serialized_dag.py index 60d26959b079b..94835fbd5e5d7 100644 --- a/tests/models/test_serialized_dag.py +++ b/tests/models/test_serialized_dag.py @@ -49,8 +49,16 @@ # To move it to a shared module. def make_example_dags(module): """Loads DAGs from a module for test.""" + from airflow.models.dagbundle import DagBundleModel + from airflow.utils.session import create_session + + with create_session() as session: + if session.query(DagBundleModel).filter(DagBundleModel.name == "testing").count() == 0: + testing = DagBundleModel(name="testing") + session.add(testing) + dagbag = DagBag(module.__path__[0]) - DAG.bulk_write_to_db(dagbag.dags.values()) + DAG.bulk_write_to_db("testing", None, dagbag.dags.values()) return dagbag.dags @@ -177,13 +185,13 @@ def test_read_all_dags_only_picks_the_latest_serdags(self, session): # assert only the latest SDM is returned assert len(sdags) != len(serialized_dags2) - def test_bulk_sync_to_db(self): + def test_bulk_sync_to_db(self, testing_dag_bundle): dags = [ DAG("dag_1", schedule=None), DAG("dag_2", schedule=None), DAG("dag_3", schedule=None), ] - DAG.bulk_write_to_db(dags) + DAG.bulk_write_to_db("testing", None, dags) # we also write to dag_version and dag_code tables # in dag_version. with assert_queries_count(24): diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py index 73f5908b707cf..bb3edc53cdd28 100644 --- a/tests/models/test_taskinstance.py +++ b/tests/models/test_taskinstance.py @@ -2280,7 +2280,7 @@ def test_success_callback_no_race_condition(self, create_task_instance): ti.refresh_from_db() assert ti.state == State.SUCCESS - def test_outlet_assets(self, create_task_instance): + def test_outlet_assets(self, create_task_instance, testing_dag_bundle): """ Verify that when we have an outlet asset on a task, and the task completes successfully, an AssetDagRunQueue is logged. 
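As in the make_example_dags helpers above, tests that write DAGs to the database directly first need a DagBundleModel row named "testing" to exist, and DagBag.sync_to_db likewise takes the bundle name and version now. Condensed, that setup looks like this (the dag folder path is illustrative):

    from airflow.models.dagbag import DagBag
    from airflow.models.dagbundle import DagBundleModel
    from airflow.utils.session import create_session

    # Ensure the bundle row exists, mirroring the make_example_dags helpers.
    with create_session() as session:
        if session.query(DagBundleModel).filter(DagBundleModel.name == "testing").count() == 0:
            session.add(DagBundleModel(name="testing"))

    dagbag = DagBag(dag_folder="/path/to/dags")  # illustrative path
    dagbag.sync_to_db("testing", None)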
@@ -2291,7 +2291,7 @@ def test_outlet_assets(self, create_task_instance): session = settings.Session() dagbag = DagBag(dag_folder=example_assets.__file__) dagbag.collect_dags(only_if_updated=False, safe_mode=False) - dagbag.sync_to_db(session=session) + dagbag.sync_to_db("testing", None, session=session) asset_models = session.scalars(select(AssetModel)).all() SchedulerJobRunner._activate_referenced_assets(asset_models, session=session) @@ -2343,7 +2343,7 @@ def test_outlet_assets(self, create_task_instance): event.timestamp < adrq_timestamp for (adrq_timestamp,) in adrq_timestamps ), f"Some items in {[str(t) for t in adrq_timestamps]} are earlier than {event.timestamp}" - def test_outlet_assets_failed(self, create_task_instance): + def test_outlet_assets_failed(self, create_task_instance, testing_dag_bundle): """ Verify that when we have an outlet asset on a task, and the task failed, an AssetDagRunQueue is not logged, and an AssetEvent is @@ -2355,7 +2355,7 @@ def test_outlet_assets_failed(self, create_task_instance): session = settings.Session() dagbag = DagBag(dag_folder=test_assets.__file__) dagbag.collect_dags(only_if_updated=False, safe_mode=False) - dagbag.sync_to_db(session=session) + dagbag.sync_to_db("testing", None, session=session) run_id = str(uuid4()) dr = DagRun(dag_with_fail_task.dag_id, run_id=run_id, run_type="anything") session.merge(dr) @@ -2397,7 +2397,7 @@ def raise_an_exception(placeholder: int): task_instance.run() assert task_instance.current_state() == TaskInstanceState.SUCCESS - def test_outlet_assets_skipped(self): + def test_outlet_assets_skipped(self, testing_dag_bundle): """ Verify that when we have an outlet asset on a task, and the task is skipped, an AssetDagRunQueue is not logged, and an AssetEvent is @@ -2409,7 +2409,7 @@ def test_outlet_assets_skipped(self): session = settings.Session() dagbag = DagBag(dag_folder=test_assets.__file__) dagbag.collect_dags(only_if_updated=False, safe_mode=False) - dagbag.sync_to_db(session=session) + dagbag.sync_to_db("testing", None, session=session) asset_models = session.scalars(select(AssetModel)).all() SchedulerJobRunner._activate_referenced_assets(asset_models, session=session) diff --git a/tests/operators/test_trigger_dagrun.py b/tests/operators/test_trigger_dagrun.py index a8a6b3c262903..74d8371e5d78e 100644 --- a/tests/operators/test_trigger_dagrun.py +++ b/tests/operators/test_trigger_dagrun.py @@ -26,7 +26,6 @@ from airflow.exceptions import AirflowException, DagRunAlreadyExists, TaskDeferred from airflow.models.dag import DagModel -from airflow.models.dagbag import DagBag from airflow.models.dagrun import DagRun from airflow.models.log import Log from airflow.models.taskinstance import TaskInstance @@ -37,6 +36,8 @@ from airflow.utils.state import DagRunState, State, TaskInstanceState from airflow.utils.types import DagRunType +from tests_common.test_utils.db import parse_and_sync_to_db + pytestmark = pytest.mark.db_test DEFAULT_DATE = datetime(2019, 1, 1, tzinfo=timezone.utc) @@ -71,11 +72,6 @@ def setup_method(self): session.add(DagModel(dag_id=TRIGGERED_DAG_ID, fileloc=self._tmpfile)) session.commit() - def re_sync_triggered_dag_to_db(self, dag, dag_maker): - dagbag = DagBag(self.f_name, read_dags_from_db=False, include_examples=False) - dagbag.bag_dag(dag) - dagbag.sync_to_db(session=dag_maker.session) - def teardown_method(self): """Cleanup state after testing in DB.""" with create_session() as session: @@ -120,9 +116,10 @@ def test_trigger_dagrun(self, dag_maker): """Test TriggerDagRunOperator.""" with 
dag_maker( TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True - ) as dag: + ): task = TriggerDagRunOperator(task_id="test_task", trigger_dag_id=TRIGGERED_DAG_ID) - self.re_sync_triggered_dag_to_db(dag, dag_maker) + dag_maker.sync_dagbag_to_db() + parse_and_sync_to_db(self.f_name) dag_maker.create_dagrun() task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) @@ -134,13 +131,14 @@ def test_trigger_dagrun(self, dag_maker): def test_trigger_dagrun_custom_run_id(self, dag_maker): with dag_maker( TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True - ) as dag: + ): task = TriggerDagRunOperator( task_id="test_task", trigger_dag_id=TRIGGERED_DAG_ID, trigger_run_id="custom_run_id", ) - self.re_sync_triggered_dag_to_db(dag, dag_maker) + dag_maker.sync_dagbag_to_db() + parse_and_sync_to_db(self.f_name) dag_maker.create_dagrun() task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) @@ -154,13 +152,14 @@ def test_trigger_dagrun_with_logical_date(self, dag_maker): custom_logical_date = timezone.datetime(2021, 1, 2, 3, 4, 5) with dag_maker( TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True - ) as dag: + ): task = TriggerDagRunOperator( task_id="test_trigger_dagrun_with_logical_date", trigger_dag_id=TRIGGERED_DAG_ID, logical_date=custom_logical_date, ) - self.re_sync_triggered_dag_to_db(dag, dag_maker) + dag_maker.sync_dagbag_to_db() + parse_and_sync_to_db(self.f_name) dag_maker.create_dagrun() task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) @@ -177,7 +176,7 @@ def test_trigger_dagrun_twice(self, dag_maker): run_id = f"manual__{utc_now.isoformat()}" with dag_maker( TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True - ) as dag: + ): task = TriggerDagRunOperator( task_id="test_trigger_dagrun_with_logical_date", trigger_dag_id=TRIGGERED_DAG_ID, @@ -187,7 +186,8 @@ def test_trigger_dagrun_twice(self, dag_maker): reset_dag_run=True, wait_for_completion=True, ) - self.re_sync_triggered_dag_to_db(dag, dag_maker) + dag_maker.sync_dagbag_to_db() + parse_and_sync_to_db(self.f_name) dag_maker.create_dagrun() dag_run = DagRun( dag_id=TRIGGERED_DAG_ID, @@ -213,7 +213,7 @@ def test_trigger_dagrun_with_scheduled_dag_run(self, dag_maker): run_id = f"scheduled__{utc_now.isoformat()}" with dag_maker( TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True - ) as dag: + ): task = TriggerDagRunOperator( task_id="test_trigger_dagrun_with_logical_date", trigger_dag_id=TRIGGERED_DAG_ID, @@ -223,7 +223,8 @@ def test_trigger_dagrun_with_scheduled_dag_run(self, dag_maker): reset_dag_run=True, wait_for_completion=True, ) - self.re_sync_triggered_dag_to_db(dag, dag_maker) + dag_maker.sync_dagbag_to_db() + parse_and_sync_to_db(self.f_name) dag_maker.create_dagrun() run_id = f"scheduled__{utc_now.isoformat()}" dag_run = DagRun( @@ -248,13 +249,14 @@ def test_trigger_dagrun_with_templated_logical_date(self, dag_maker): """Test TriggerDagRunOperator with templated logical_date.""" with dag_maker( TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True - ) as dag: + ): task = TriggerDagRunOperator( task_id="test_trigger_dagrun_with_str_logical_date", trigger_dag_id=TRIGGERED_DAG_ID, logical_date="{{ logical_date }}", ) - self.re_sync_triggered_dag_to_db(dag, dag_maker) + dag_maker.sync_dagbag_to_db() + parse_and_sync_to_db(self.f_name) 
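The TriggerDagRunOperator tests drop the class-local re_sync_triggered_dag_to_db helper in favour of dag_maker.sync_dagbag_to_db() plus parse_and_sync_to_db imported from tests_common.test_utils.db. The helper's body is outside this hunk; a plausible sketch, assuming it just parses the given file and syncs it under a bundle (the bundle name below is illustrative, not necessarily what the real helper uses):

    from airflow.models.dagbag import DagBag

    def parse_and_sync_to_db(file_path, include_examples: bool = False):
        # Parse the file (or example DAGs) and write the result under a bundle.
        dagbag = DagBag(dag_folder=file_path, include_examples=include_examples)
        dagbag.sync_to_db("testing", None)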
dag_maker.create_dagrun() task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) @@ -270,12 +272,13 @@ def test_trigger_dagrun_with_templated_trigger_dag_id(self, dag_maker): """Test TriggerDagRunOperator with templated trigger dag id.""" with dag_maker( TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True - ) as dag: + ): task = TriggerDagRunOperator( task_id="__".join(["test_trigger_dagrun_with_templated_trigger_dag_id", TRIGGERED_DAG_ID]), trigger_dag_id="{{ ti.task_id.rsplit('.', 1)[-1].split('__')[-1] }}", ) - self.re_sync_triggered_dag_to_db(dag, dag_maker) + dag_maker.sync_dagbag_to_db() + parse_and_sync_to_db(self.f_name) dag_maker.create_dagrun() task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) @@ -291,13 +294,14 @@ def test_trigger_dagrun_operator_conf(self, dag_maker): """Test passing conf to the triggered DagRun.""" with dag_maker( TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True - ) as dag: + ): task = TriggerDagRunOperator( task_id="test_trigger_dagrun_with_str_logical_date", trigger_dag_id=TRIGGERED_DAG_ID, conf={"foo": "bar"}, ) - self.re_sync_triggered_dag_to_db(dag, dag_maker) + dag_maker.sync_dagbag_to_db() + parse_and_sync_to_db(self.f_name) dag_maker.create_dagrun() task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) @@ -310,13 +314,14 @@ def test_trigger_dagrun_operator_templated_invalid_conf(self, dag_maker): """Test passing a conf that is not JSON Serializable raise error.""" with dag_maker( TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True - ) as dag: + ): task = TriggerDagRunOperator( task_id="test_trigger_dagrun_with_invalid_conf", trigger_dag_id=TRIGGERED_DAG_ID, conf={"foo": "{{ dag.dag_id }}", "datetime": timezone.utcnow()}, ) - self.re_sync_triggered_dag_to_db(dag, dag_maker) + dag_maker.sync_dagbag_to_db() + parse_and_sync_to_db(self.f_name) dag_maker.create_dagrun() with pytest.raises(AirflowException, match="^conf parameter should be JSON Serializable$"): task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) @@ -325,13 +330,14 @@ def test_trigger_dagrun_operator_templated_conf(self, dag_maker): """Test passing a templated conf to the triggered DagRun.""" with dag_maker( TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True - ) as dag: + ): task = TriggerDagRunOperator( task_id="test_trigger_dagrun_with_str_logical_date", trigger_dag_id=TRIGGERED_DAG_ID, conf={"foo": "{{ dag.dag_id }}"}, ) - self.re_sync_triggered_dag_to_db(dag, dag_maker) + dag_maker.sync_dagbag_to_db() + parse_and_sync_to_db(self.f_name) dag_maker.create_dagrun() task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) @@ -345,7 +351,7 @@ def test_trigger_dagrun_with_reset_dag_run_false(self, dag_maker): logical_date = DEFAULT_DATE with dag_maker( TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True - ) as dag: + ): task = TriggerDagRunOperator( task_id="test_task", trigger_dag_id=TRIGGERED_DAG_ID, @@ -353,7 +359,8 @@ def test_trigger_dagrun_with_reset_dag_run_false(self, dag_maker): logical_date=None, reset_dag_run=False, ) - self.re_sync_triggered_dag_to_db(dag, dag_maker) + dag_maker.sync_dagbag_to_db() + parse_and_sync_to_db(self.f_name) dag_maker.create_dagrun() task.run(start_date=logical_date, end_date=logical_date, ignore_ti_state=True) task.run(start_date=logical_date, 
end_date=logical_date, ignore_ti_state=True) @@ -377,7 +384,7 @@ def test_trigger_dagrun_with_reset_dag_run_false_fail( logical_date = DEFAULT_DATE with dag_maker( TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True - ) as dag: + ): task = TriggerDagRunOperator( task_id="test_task", trigger_dag_id=TRIGGERED_DAG_ID, @@ -385,7 +392,8 @@ def test_trigger_dagrun_with_reset_dag_run_false_fail( logical_date=trigger_logical_date, reset_dag_run=False, ) - self.re_sync_triggered_dag_to_db(dag, dag_maker) + dag_maker.sync_dagbag_to_db() + parse_and_sync_to_db(self.f_name) dag_maker.create_dagrun() task.run(start_date=logical_date, end_date=logical_date, ignore_ti_state=True) @@ -397,7 +405,7 @@ def test_trigger_dagrun_with_skip_when_already_exists(self, dag_maker): logical_date = DEFAULT_DATE with dag_maker( TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True - ) as dag: + ): task = TriggerDagRunOperator( task_id="test_task", trigger_dag_id=TRIGGERED_DAG_ID, @@ -405,7 +413,8 @@ def test_trigger_dagrun_with_skip_when_already_exists(self, dag_maker): reset_dag_run=False, skip_when_already_exists=True, ) - self.re_sync_triggered_dag_to_db(dag, dag_maker) + dag_maker.sync_dagbag_to_db() + parse_and_sync_to_db(self.f_name) dr: DagRun = dag_maker.create_dagrun() task.run(start_date=logical_date, end_date=logical_date, ignore_ti_state=True) assert dr.get_task_instance("test_task").state == TaskInstanceState.SUCCESS @@ -428,7 +437,7 @@ def test_trigger_dagrun_with_reset_dag_run_true( logical_date = DEFAULT_DATE with dag_maker( TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True - ) as dag: + ): task = TriggerDagRunOperator( task_id="test_task", trigger_dag_id=TRIGGERED_DAG_ID, @@ -436,7 +445,8 @@ def test_trigger_dagrun_with_reset_dag_run_true( logical_date=trigger_logical_date, reset_dag_run=True, ) - self.re_sync_triggered_dag_to_db(dag, dag_maker) + dag_maker.sync_dagbag_to_db() + parse_and_sync_to_db(self.f_name) dag_maker.create_dagrun() task.run(start_date=logical_date, end_date=logical_date, ignore_ti_state=True) task.run(start_date=logical_date, end_date=logical_date, ignore_ti_state=True) @@ -451,7 +461,7 @@ def test_trigger_dagrun_with_wait_for_completion_true(self, dag_maker): logical_date = DEFAULT_DATE with dag_maker( TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True - ) as dag: + ): task = TriggerDagRunOperator( task_id="test_task", trigger_dag_id=TRIGGERED_DAG_ID, @@ -460,7 +470,8 @@ def test_trigger_dagrun_with_wait_for_completion_true(self, dag_maker): poke_interval=10, allowed_states=[State.QUEUED], ) - self.re_sync_triggered_dag_to_db(dag, dag_maker) + dag_maker.sync_dagbag_to_db() + parse_and_sync_to_db(self.f_name) dag_maker.create_dagrun() task.run(start_date=logical_date, end_date=logical_date) @@ -473,7 +484,7 @@ def test_trigger_dagrun_with_wait_for_completion_true_fail(self, dag_maker): logical_date = DEFAULT_DATE with dag_maker( TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True - ) as dag: + ): task = TriggerDagRunOperator( task_id="test_task", trigger_dag_id=TRIGGERED_DAG_ID, @@ -482,7 +493,8 @@ def test_trigger_dagrun_with_wait_for_completion_true_fail(self, dag_maker): poke_interval=10, failed_states=[State.QUEUED], ) - self.re_sync_triggered_dag_to_db(dag, dag_maker) + dag_maker.sync_dagbag_to_db() + parse_and_sync_to_db(self.f_name) dag_maker.create_dagrun() with 
pytest.raises(AirflowException): task.run(start_date=logical_date, end_date=logical_date) @@ -492,12 +504,13 @@ def test_trigger_dagrun_triggering_itself(self, dag_maker): logical_date = DEFAULT_DATE with dag_maker( TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True - ) as dag: + ): task = TriggerDagRunOperator( task_id="test_task", trigger_dag_id=TEST_DAG_ID, ) - self.re_sync_triggered_dag_to_db(dag, dag_maker) + dag_maker.sync_dagbag_to_db() + parse_and_sync_to_db(self.f_name) dag_maker.create_dagrun() task.run(start_date=logical_date, end_date=logical_date) @@ -516,7 +529,7 @@ def test_trigger_dagrun_with_wait_for_completion_true_defer_false(self, dag_make logical_date = DEFAULT_DATE with dag_maker( TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True - ) as dag: + ): task = TriggerDagRunOperator( task_id="test_task", trigger_dag_id=TRIGGERED_DAG_ID, @@ -526,7 +539,8 @@ def test_trigger_dagrun_with_wait_for_completion_true_defer_false(self, dag_make allowed_states=[State.QUEUED], deferrable=False, ) - self.re_sync_triggered_dag_to_db(dag, dag_maker) + dag_maker.sync_dagbag_to_db() + parse_and_sync_to_db(self.f_name) dag_maker.create_dagrun() task.run(start_date=logical_date, end_date=logical_date) @@ -539,7 +553,7 @@ def test_trigger_dagrun_with_wait_for_completion_true_defer_true(self, dag_maker logical_date = DEFAULT_DATE with dag_maker( TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True - ) as dag: + ): task = TriggerDagRunOperator( task_id="test_task", trigger_dag_id=TRIGGERED_DAG_ID, @@ -549,7 +563,8 @@ def test_trigger_dagrun_with_wait_for_completion_true_defer_true(self, dag_maker allowed_states=[State.QUEUED], deferrable=True, ) - self.re_sync_triggered_dag_to_db(dag, dag_maker) + dag_maker.sync_dagbag_to_db() + parse_and_sync_to_db(self.f_name) dag_maker.create_dagrun() task.run(start_date=logical_date, end_date=logical_date) @@ -571,7 +586,7 @@ def test_trigger_dagrun_with_wait_for_completion_true_defer_true_failure(self, d logical_date = DEFAULT_DATE with dag_maker( TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True - ) as dag: + ): task = TriggerDagRunOperator( task_id="test_task", trigger_dag_id=TRIGGERED_DAG_ID, @@ -581,7 +596,8 @@ def test_trigger_dagrun_with_wait_for_completion_true_defer_true_failure(self, d allowed_states=[State.SUCCESS], deferrable=True, ) - self.re_sync_triggered_dag_to_db(dag, dag_maker) + dag_maker.sync_dagbag_to_db() + parse_and_sync_to_db(self.f_name) dag_maker.create_dagrun() task.run(start_date=logical_date, end_date=logical_date) @@ -607,7 +623,7 @@ def test_trigger_dagrun_with_wait_for_completion_true_defer_true_failure_2(self, logical_date = DEFAULT_DATE with dag_maker( TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True - ) as dag: + ): task = TriggerDagRunOperator( task_id="test_task", trigger_dag_id=TRIGGERED_DAG_ID, @@ -618,7 +634,8 @@ def test_trigger_dagrun_with_wait_for_completion_true_defer_true_failure_2(self, failed_states=[State.QUEUED], deferrable=True, ) - self.re_sync_triggered_dag_to_db(dag, dag_maker) + dag_maker.sync_dagbag_to_db() + parse_and_sync_to_db(self.f_name) dag_maker.create_dagrun() task.run(start_date=logical_date, end_date=logical_date) @@ -648,7 +665,7 @@ def test_dagstatetrigger_logical_dates(self, trigger_logical_date, dag_maker): """Ensure that the DagStateTrigger is called with the triggered DAG's logical 
date.""" with dag_maker( TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True - ) as dag: + ): task = TriggerDagRunOperator( task_id="test_task", trigger_dag_id=TRIGGERED_DAG_ID, @@ -658,7 +675,8 @@ def test_dagstatetrigger_logical_dates(self, trigger_logical_date, dag_maker): allowed_states=[DagRunState.QUEUED], deferrable=True, ) - self.re_sync_triggered_dag_to_db(dag, dag_maker) + dag_maker.sync_dagbag_to_db() + parse_and_sync_to_db(self.f_name) dag_maker.create_dagrun() mock_task_defer = mock.MagicMock(side_effect=task.defer) @@ -677,7 +695,7 @@ def test_dagstatetrigger_logical_dates_with_clear_and_reset(self, dag_maker): """Check DagStateTrigger is called with the triggered DAG's logical date on subsequent defers.""" with dag_maker( TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True - ) as dag: + ): task = TriggerDagRunOperator( task_id="test_task", trigger_dag_id=TRIGGERED_DAG_ID, @@ -688,7 +706,8 @@ def test_dagstatetrigger_logical_dates_with_clear_and_reset(self, dag_maker): deferrable=True, reset_dag_run=True, ) - self.re_sync_triggered_dag_to_db(dag, dag_maker) + dag_maker.sync_dagbag_to_db() + parse_and_sync_to_db(self.f_name) dag_maker.create_dagrun() mock_task_defer = mock.MagicMock(side_effect=task.defer) @@ -727,7 +746,7 @@ def test_trigger_dagrun_with_no_failed_state(self, dag_maker): logical_date = DEFAULT_DATE with dag_maker( TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True - ) as dag: + ): task = TriggerDagRunOperator( task_id="test_task", trigger_dag_id=TRIGGERED_DAG_ID, @@ -736,7 +755,8 @@ def test_trigger_dagrun_with_no_failed_state(self, dag_maker): poke_interval=10, failed_states=[], ) - self.re_sync_triggered_dag_to_db(dag, dag_maker) + dag_maker.sync_dagbag_to_db() + parse_and_sync_to_db(self.f_name) dag_maker.create_dagrun() assert task.failed_states == [] diff --git a/tests/sensors/test_external_task_sensor.py b/tests/sensors/test_external_task_sensor.py index ba3cfd6b4480c..c73980ca24bd5 100644 --- a/tests/sensors/test_external_task_sensor.py +++ b/tests/sensors/test_external_task_sensor.py @@ -80,7 +80,7 @@ def clean_db(): @pytest.fixture -def dag_zip_maker(): +def dag_zip_maker(testing_dag_bundle): class DagZipMaker: def __call__(self, *dag_files): self.__dag_files = [os.sep.join([TEST_DAGS_FOLDER.__str__(), dag_file]) for dag_file in dag_files] @@ -98,7 +98,7 @@ def __enter__(self): for dag_file in self.__dag_files: zf.write(dag_file, os.path.basename(dag_file)) dagbag = DagBag(dag_folder=self.__tmp_dir, include_examples=False) - dagbag.sync_to_db() + dagbag.sync_to_db("testing", None) return dagbag def __exit__(self, exc_type, exc_val, exc_tb): diff --git a/tests/www/views/conftest.py b/tests/www/views/conftest.py index 7f524b2377e39..68516b4f4c865 100644 --- a/tests/www/views/conftest.py +++ b/tests/www/views/conftest.py @@ -17,6 +17,7 @@ # under the License. 
from __future__ import annotations +import os from collections.abc import Generator from contextlib import contextmanager from typing import Any, NamedTuple @@ -31,6 +32,7 @@ from tests_common.test_utils.api_connexion_utils import delete_user from tests_common.test_utils.config import conf_vars +from tests_common.test_utils.db import parse_and_sync_to_db from tests_common.test_utils.decorators import dont_initialize_flask_app_submodules from tests_common.test_utils.www import ( client_with_login, @@ -47,8 +49,8 @@ def session(): @pytest.fixture(autouse=True, scope="module") def examples_dag_bag(session): - DagBag(include_examples=True).sync_to_db() - dag_bag = DagBag(include_examples=True, read_dags_from_db=True) + parse_and_sync_to_db(os.devnull, include_examples=True) + dag_bag = DagBag(read_dags_from_db=True) session.commit() return dag_bag diff --git a/tests/www/views/test_views_acl.py b/tests/www/views/test_views_acl.py index 0f430b9b6dca7..21379550dd1fd 100644 --- a/tests/www/views/test_views_acl.py +++ b/tests/www/views/test_views_acl.py @@ -293,9 +293,9 @@ def test_dag_autocomplete_success(client_all_dags): "dag_display_name": None, }, {"name": "example_setup_teardown_taskflow", "type": "dag", "dag_display_name": None}, - {"name": "test_mapped_taskflow", "type": "dag", "dag_display_name": None}, {"name": "tutorial_taskflow_api", "type": "dag", "dag_display_name": None}, {"name": "tutorial_taskflow_api_virtualenv", "type": "dag", "dag_display_name": None}, + {"name": "tutorial_taskflow_templates", "type": "dag", "dag_display_name": None}, ] assert resp.json == expected diff --git a/tests/www/views/test_views_decorators.py b/tests/www/views/test_views_decorators.py index 5815510b07f6e..6c0a0a7d78172 100644 --- a/tests/www/views/test_views_decorators.py +++ b/tests/www/views/test_views_decorators.py @@ -17,6 +17,8 @@ # under the License. 
from __future__ import annotations +import os + import pytest from airflow.models import DagBag, Variable @@ -24,7 +26,7 @@ from airflow.utils.state import State from airflow.utils.types import DagRunType -from tests_common.test_utils.db import clear_db_runs, clear_db_variables +from tests_common.test_utils.db import clear_db_runs, clear_db_variables, parse_and_sync_to_db from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS from tests_common.test_utils.www import ( _check_last_log, @@ -42,8 +44,8 @@ @pytest.fixture(scope="module") def dagbag(): - DagBag(include_examples=True, read_dags_from_db=False).sync_to_db() - return DagBag(include_examples=True, read_dags_from_db=True) + parse_and_sync_to_db(os.devnull, include_examples=True) + return DagBag(read_dags_from_db=True) @pytest.fixture(scope="module") diff --git a/tests/www/views/test_views_log.py b/tests/www/views/test_views_log.py index be492487a4c8b..735c9a569e093 100644 --- a/tests/www/views/test_views_log.py +++ b/tests/www/views/test_views_log.py @@ -20,6 +20,7 @@ import copy import logging import logging.config +import os import pathlib import shutil import sys @@ -127,7 +128,7 @@ def _reset_modules_after_every_test(backup_modules): @pytest.fixture(autouse=True) -def dags(log_app, create_dummy_dag, session): +def dags(log_app, create_dummy_dag, testing_dag_bundle, session): dag, _ = create_dummy_dag( dag_id=DAG_ID, task_id=TASK_ID, @@ -143,10 +144,10 @@ def dags(log_app, create_dummy_dag, session): session=session, ) - bag = DagBag(include_examples=False) + bag = DagBag(os.devnull, include_examples=False) bag.bag_dag(dag=dag) bag.bag_dag(dag=dag_removed) - bag.sync_to_db(session=session) + bag.sync_to_db("testing", None, session=session) log_app.dag_bag = bag yield dag, dag_removed diff --git a/tests/www/views/test_views_tasks.py b/tests/www/views/test_views_tasks.py index 8289228b81dfe..10f42aca6c4f2 100644 --- a/tests/www/views/test_views_tasks.py +++ b/tests/www/views/test_views_tasks.py @@ -597,7 +597,7 @@ def heartbeat(self): def test_delete_dag_button_for_dag_on_scheduler_only(admin_client, dag_maker): with dag_maker() as dag: EmptyOperator(task_id="task") - dag.sync_to_db() + dag_maker.sync_dagbag_to_db() # The delete-dag URL should be generated correctly test_dag_id = dag.dag_id resp = admin_client.get("/", follow_redirects=True) @@ -606,12 +606,12 @@ def test_delete_dag_button_for_dag_on_scheduler_only(admin_client, dag_maker): @pytest.fixture -def new_dag_to_delete(): +def new_dag_to_delete(testing_dag_bundle): dag = DAG( "new_dag_to_delete", is_paused_upon_creation=True, schedule="0 * * * *", start_date=DEFAULT_DATE ) session = settings.Session() - dag.sync_to_db(session=session) + DAG.bulk_write_to_db("testing", None, [dag], session=session) return dag diff --git a/tests_common/pytest_plugin.py b/tests_common/pytest_plugin.py index 43b4733e38811..b4d3c089a3740 100644 --- a/tests_common/pytest_plugin.py +++ b/tests_common/pytest_plugin.py @@ -935,6 +935,19 @@ def create_dagrun_after(self, dagrun, **kwargs): **kwargs, ) + def sync_dagbag_to_db(self): + if not AIRFLOW_V_3_0_PLUS: + self.dagbag.sync_to_db() + return + + from airflow.models.dagbundle import DagBundleModel + + if self.session.query(DagBundleModel).filter(DagBundleModel.name == "dag_maker").count() == 0: + self.session.add(DagBundleModel(name="dag_maker")) + self.session.commit() + + self.dagbag.sync_to_db("dag_maker", None) + def __call__( self, dag_id="test_dag", diff --git a/tests_common/test_utils/db.py b/tests_common/test_utils/db.py 
index 8c5d59751f55c..58ad46372d711 100644
--- a/tests_common/test_utils/db.py
+++ b/tests_common/test_utils/db.py
@@ -17,6 +17,8 @@
 # under the License.
 from __future__ import annotations
 
+from typing import TYPE_CHECKING
+
 from airflow.jobs.job import Job
 from airflow.models import (
     Connection,
@@ -51,15 +53,28 @@
 )
 from tests_common.test_utils.version_compat import AIRFLOW_V_2_10_PLUS, AIRFLOW_V_3_0_PLUS
 
+if TYPE_CHECKING:
+    from pathlib import Path
+
 
 def _bootstrap_dagbag():
     from airflow.models.dag import DAG
     from airflow.models.dagbag import DagBag
 
+    if AIRFLOW_V_3_0_PLUS:
+        from airflow.dag_processing.bundles.manager import DagBundlesManager
+
     with create_session() as session:
+        if AIRFLOW_V_3_0_PLUS:
+            DagBundlesManager().sync_bundles_to_db(session=session)
+            session.commit()
+
         dagbag = DagBag()
         # Save DAGs in the ORM
-        dagbag.sync_to_db(session=session)
+        if AIRFLOW_V_3_0_PLUS:
+            dagbag.sync_to_db(bundle_name="dags-folder", bundle_version=None, session=session)
+        else:
+            dagbag.sync_to_db(session=session)
 
         # Deactivate the unknown ones
         DAG.deactivate_unknown_dags(dagbag.dags.keys(), session=session)
@@ -92,6 +107,25 @@ def initial_db_init():
     get_auth_manager().init()
 
 
+def parse_and_sync_to_db(folder: Path | str, include_examples: bool = False):
+    from airflow.models.dagbag import DagBag
+
+    if AIRFLOW_V_3_0_PLUS:
+        from airflow.dag_processing.bundles.manager import DagBundlesManager
+
+    with create_session() as session:
+        if AIRFLOW_V_3_0_PLUS:
+            DagBundlesManager().sync_bundles_to_db(session=session)
+            session.commit()
+
+        dagbag = DagBag(dag_folder=folder, include_examples=include_examples)
+        if AIRFLOW_V_3_0_PLUS:
+            dagbag.sync_to_db("dags-folder", None, session)
+        else:
+            dagbag.sync_to_db(session=session)  # type: ignore[call-arg]
+        session.commit()
+
+
 def clear_db_runs():
     with create_session() as session:
         session.query(Job).delete()
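
For anyone reviewing the test changes, here is a minimal usage sketch of the new parse_and_sync_to_db helper, mirroring the tests/www/views/conftest.py change above. The fixture name and scope are illustrative rather than part of the patch; it only assumes the helper added to tests_common/test_utils/db.py.

import os

import pytest

from airflow.models.dagbag import DagBag
from tests_common.test_utils.db import parse_and_sync_to_db


@pytest.fixture(scope="module")
def examples_dag_bag():
    # Parse only the example DAGs (os.devnull skips any real dags folder) and
    # serialize them to the DB; on Airflow 3 they land under the "dags-folder" bundle.
    parse_and_sync_to_db(os.devnull, include_examples=True)
    # Tests then read the serialized DAGs back from the DB.
    return DagBag(read_dags_from_db=True)

This is the same pattern the patch applies in tests/www/views/conftest.py and tests/www/views/test_views_decorators.py.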
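The dag_maker fixture also gains a sync_dagbag_to_db() helper (tests_common/pytest_plugin.py above), so tests no longer call dag.sync_to_db() directly. Below is a rough sketch of how a test might exercise it, assuming Airflow 3 and the standard dag_maker and session fixtures; the dag id and the assertion are illustrative only, not taken from the patch.

from airflow.models.dagbundle import DagBundleModel


def test_sync_dagbag_to_db_registers_bundle(dag_maker, session):
    with dag_maker("example_sync_dag"):
        pass  # the DAG body does not matter here; only the sync path is exercised

    # On Airflow 3 the helper creates the "dag_maker" DagBundleModel on first use
    # and serializes the DAGs under that bundle; on 2.x it falls back to sync_to_db().
    dag_maker.sync_dagbag_to_db()

    assert session.query(DagBundleModel).filter(DagBundleModel.name == "dag_maker").count() == 1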
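Where a test only has an in-memory DAG object and no DagBag, the bundle name and version are now the leading positional arguments to DAG.bulk_write_to_db, as in the test_views_tasks.py change. A hedged, version-guarded sketch follows; the function name is illustrative, and the "testing" bundle is assumed to already exist (the testing_dag_bundle fixture referenced in the hunks above creates it).

from airflow.models.dag import DAG
from airflow.utils.session import create_session
from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS


def persist_single_dag(dag: DAG, bundle_name: str = "testing") -> None:
    """Write one in-memory DAG to the DB, passing bundle info only on Airflow 3."""
    with create_session() as session:
        if AIRFLOW_V_3_0_PLUS:
            # bundle_name and bundle_version now come first
            DAG.bulk_write_to_db(bundle_name, None, [dag], session=session)
        else:
            DAG.bulk_write_to_db([dag], session=session)

DagBag.sync_to_db follows the same convention, which is why call sites like dagbag.sync_to_db("testing", None) appear throughout the test changes above.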