From 626c3428c89ee3765f4c88f3e9930a03a53faf64 Mon Sep 17 00:00:00 2001 From: Jed Cunningham Date: Fri, 3 Jan 2025 10:55:35 -0500 Subject: [PATCH] Let manager manage the db session --- airflow/dag_processing/manager.py | 60 ++++++++++++++++--------------- airflow/models/dagbundle.py | 11 ------ 2 files changed, 31 insertions(+), 40 deletions(-) diff --git a/airflow/dag_processing/manager.py b/airflow/dag_processing/manager.py index 73f22e5847795..57621b09d831f 100644 --- a/airflow/dag_processing/manager.py +++ b/airflow/dag_processing/manager.py @@ -69,7 +69,7 @@ set_new_process_group, ) from airflow.utils.retries import retry_db_transaction -from airflow.utils.session import NEW_SESSION, provide_session +from airflow.utils.session import NEW_SESSION, create_session, provide_session from airflow.utils.sqlalchemy import prohibit_commit, with_row_locks if TYPE_CHECKING: @@ -654,40 +654,42 @@ def _refresh_dag_bundles(self): for bundle in self._dag_bundles: # TODO: AIP-66 test to make sure we get a fresh record from the db and it's not cached - bundle_model = DagBundleModel.get(bundle.name) + with create_session() as session: + bundle_model = session.query(DagBundleModel).get(bundle.name) elapsed_time_since_refresh = ( now - (bundle_model.last_refreshed or timezone.utc_epoch()) ).total_seconds() - if elapsed_time_since_refresh > bundle.refresh_interval: + if not elapsed_time_since_refresh > bundle.refresh_interval: # or bundle_model.version != bundle.get_current_version(): - # TODO: AIP-66 locking / dealing with multiple processors - self.log.info("Time to refresh %s", bundle.name) - old_version = bundle.get_current_version() - bundle.refresh() - bundle_model.last_refreshed = now - - if old_version != bundle.get_current_version(): - self.log.info( - "Version changed for %s, new version: %s", bundle.name, bundle.get_current_version() - ) - bundle_file_paths = self._refresh_dag_dir(bundle) - # remove all files from the bundle, then add the new ones - self._file_paths = [f for f in self._file_paths if f.bundle_name != bundle_model.name] - self._file_paths.extend( - DagFilePath(path=path, bundle_name=bundle_model.name) for path in bundle_file_paths + self.log.info("Not time to refresh %s", bundle.name) + continue + + # TODO: AIP-66 locking / dealing with multiple processors + self.log.info("Time to refresh %s", bundle.name) + old_version = bundle.get_current_version() + bundle.refresh() + bundle_model.last_refreshed = now + + if old_version != bundle.get_current_version(): + self.log.info( + "Version changed for %s, new version: %s", bundle.name, bundle.get_current_version() ) + bundle_file_paths = self._refresh_dag_dir(bundle) + # remove all files from the bundle, then add the new ones + self._file_paths = [f for f in self._file_paths if f.bundle_name != bundle_model.name] + self._file_paths.extend( + DagFilePath(path=path, bundle_name=bundle_model.name) for path in bundle_file_paths + ) - try: - self.log.debug("Removing old import errors") - self.clear_nonexistent_import_errors() - except Exception: - self.log.exception("Error removing old import errors") - - self._bundle_versions[bundle_model.name] = bundle.get_current_version() - self.log.info("Found %s files for bundle %s", len(bundle_file_paths), bundle.name) - # TODO: AIP-66 detect if version changed and update accordingly - else: - self.log.info("Not time to refresh %s", bundle.name) + try: + self.log.debug("Removing old import errors") + self.clear_nonexistent_import_errors() + except Exception: + self.log.exception("Error removing old import errors") + + self._bundle_versions[bundle_model.name] = bundle.get_current_version() + self.log.info("Found %s files for bundle %s", len(bundle_file_paths), bundle.name) + # TODO: AIP-66 detect if version changed and update accordingly def _refresh_dag_dir(self, bundle: BaseDagBundle) -> list[str]: """Refresh file paths from bundle dir.""" diff --git a/airflow/models/dagbundle.py b/airflow/models/dagbundle.py index 43f6396c5bf5f..08429db0b0bcb 100644 --- a/airflow/models/dagbundle.py +++ b/airflow/models/dagbundle.py @@ -16,17 +16,11 @@ # under the License. from __future__ import annotations -from typing import TYPE_CHECKING - from sqlalchemy import Boolean, Column, String from airflow.models.base import Base, StringID -from airflow.utils.session import NEW_SESSION, provide_session from airflow.utils.sqlalchemy import UtcDateTime -if TYPE_CHECKING: - from sqlalchemy.orm import Session - class DagBundleModel(Base): """ @@ -47,8 +41,3 @@ class DagBundleModel(Base): def __init__(self, *, name: str): self.name = name - - @staticmethod - @provide_session - def get(name: str, session: Session = NEW_SESSION) -> DagBundleModel: - return session.query(DagBundleModel).get(name)