Skip to content

Commit

Permalink
Let manager manage the db session
Browse files Browse the repository at this point in the history
  • Loading branch information
jedcunningham committed Jan 3, 2025
1 parent 0880baa commit 626c342
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 40 deletions.
60 changes: 31 additions & 29 deletions airflow/dag_processing/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@
set_new_process_group,
)
from airflow.utils.retries import retry_db_transaction
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.session import NEW_SESSION, create_session, provide_session
from airflow.utils.sqlalchemy import prohibit_commit, with_row_locks

if TYPE_CHECKING:
Expand Down Expand Up @@ -654,40 +654,42 @@ def _refresh_dag_bundles(self):

for bundle in self._dag_bundles:
# TODO: AIP-66 test to make sure we get a fresh record from the db and it's not cached
bundle_model = DagBundleModel.get(bundle.name)
with create_session() as session:
bundle_model = session.query(DagBundleModel).get(bundle.name)
elapsed_time_since_refresh = (
now - (bundle_model.last_refreshed or timezone.utc_epoch())
).total_seconds()
if elapsed_time_since_refresh > bundle.refresh_interval:
if not elapsed_time_since_refresh > bundle.refresh_interval:
# or bundle_model.version != bundle.get_current_version():
# TODO: AIP-66 locking / dealing with multiple processors
self.log.info("Time to refresh %s", bundle.name)
old_version = bundle.get_current_version()
bundle.refresh()
bundle_model.last_refreshed = now

if old_version != bundle.get_current_version():
self.log.info(
"Version changed for %s, new version: %s", bundle.name, bundle.get_current_version()
)
bundle_file_paths = self._refresh_dag_dir(bundle)
# remove all files from the bundle, then add the new ones
self._file_paths = [f for f in self._file_paths if f.bundle_name != bundle_model.name]
self._file_paths.extend(
DagFilePath(path=path, bundle_name=bundle_model.name) for path in bundle_file_paths
self.log.info("Not time to refresh %s", bundle.name)
continue

# TODO: AIP-66 locking / dealing with multiple processors
self.log.info("Time to refresh %s", bundle.name)
old_version = bundle.get_current_version()
bundle.refresh()
bundle_model.last_refreshed = now

if old_version != bundle.get_current_version():
self.log.info(
"Version changed for %s, new version: %s", bundle.name, bundle.get_current_version()
)
bundle_file_paths = self._refresh_dag_dir(bundle)
# remove all files from the bundle, then add the new ones
self._file_paths = [f for f in self._file_paths if f.bundle_name != bundle_model.name]
self._file_paths.extend(
DagFilePath(path=path, bundle_name=bundle_model.name) for path in bundle_file_paths
)

try:
self.log.debug("Removing old import errors")
self.clear_nonexistent_import_errors()
except Exception:
self.log.exception("Error removing old import errors")

self._bundle_versions[bundle_model.name] = bundle.get_current_version()
self.log.info("Found %s files for bundle %s", len(bundle_file_paths), bundle.name)
# TODO: AIP-66 detect if version changed and update accordingly
else:
self.log.info("Not time to refresh %s", bundle.name)
try:
self.log.debug("Removing old import errors")
self.clear_nonexistent_import_errors()
except Exception:
self.log.exception("Error removing old import errors")

self._bundle_versions[bundle_model.name] = bundle.get_current_version()
self.log.info("Found %s files for bundle %s", len(bundle_file_paths), bundle.name)
# TODO: AIP-66 detect if version changed and update accordingly

def _refresh_dag_dir(self, bundle: BaseDagBundle) -> list[str]:
"""Refresh file paths from bundle dir."""
Expand Down
11 changes: 0 additions & 11 deletions airflow/models/dagbundle.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,11 @@
# under the License.
from __future__ import annotations

from typing import TYPE_CHECKING

from sqlalchemy import Boolean, Column, String

from airflow.models.base import Base, StringID
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.sqlalchemy import UtcDateTime

if TYPE_CHECKING:
from sqlalchemy.orm import Session


class DagBundleModel(Base):
"""
Expand All @@ -47,8 +41,3 @@ class DagBundleModel(Base):

def __init__(self, *, name: str):
self.name = name

@staticmethod
@provide_session
def get(name: str, session: Session = NEW_SESSION) -> DagBundleModel:
return session.query(DagBundleModel).get(name)

0 comments on commit 626c342

Please sign in to comment.