From 88cd30c30f0366e8d49403a2cd040d540c07f5a7 Mon Sep 17 00:00:00 2001 From: Ephraim Anierobi Date: Wed, 15 Jan 2025 16:29:43 +0100 Subject: [PATCH] Apply suggestions from review and link to dag processor --- airflow/dag_processing/bundles/base.py | 9 +++++++ airflow/dag_processing/bundles/git.py | 33 ++++++++--------------- airflow/dag_processing/bundles/manager.py | 1 + airflow/dag_processing/manager.py | 4 +++ tests/dag_processing/test_dag_bundles.py | 16 +++-------- 5 files changed, 28 insertions(+), 35 deletions(-) diff --git a/airflow/dag_processing/bundles/base.py b/airflow/dag_processing/bundles/base.py index ea560f1be26b0..4a575f0f83c51 100644 --- a/airflow/dag_processing/bundles/base.py +++ b/airflow/dag_processing/bundles/base.py @@ -50,6 +50,15 @@ def __init__(self, *, name: str, refresh_interval: int, version: str | None = No self.name = name self.version = version self.refresh_interval = refresh_interval + self.is_initialized: bool = False + + def initialize(self) -> None: + """ + Initialize the bundle. + + This method is called by the DAG processor before the bundle is used. 
+ """ + self.is_initialized = True @property def _dag_bundle_root_storage_path(self) -> Path: diff --git a/airflow/dag_processing/bundles/git.py b/airflow/dag_processing/bundles/git.py index 01ec9e6ba6884..4b2a19de36484 100644 --- a/airflow/dag_processing/bundles/git.py +++ b/airflow/dag_processing/bundles/git.py @@ -118,10 +118,6 @@ def __init__( self.hook = GitHook(git_conn_id=self.git_conn_id) self.repo_url = self.hook.repo_url - def _clone_from(self, to_path: Path, bare: bool = False) -> Repo: - self.log.info("Cloning %s to %s", self.repo_url, to_path) - return Repo.clone_from(self.repo_url, to_path, bare=bare, env=self.hook.env) - def _initialize(self): self._clone_bare_repo_if_required() self._ensure_version_in_bare_repo() @@ -129,7 +125,7 @@ def _initialize(self): self.repo.git.checkout(self.tracking_ref) if self.version: if not self._has_version(self.repo, self.version): - self._fetch_repo() + self.repo.remotes.origin.fetch() self.repo.head.set_reference(self.repo.commit(self.version)) self.repo.head.reset(index=True, working_tree=True) else: @@ -146,19 +142,26 @@ def initialize(self) -> None: ) else: self._initialize() + super().initialize() def _clone_repo_if_required(self) -> None: if not os.path.exists(self.repo_path): - self._clone_from( + self.log.info("Cloning repository to %s from %s", self.repo_path, self.bare_repo_path) + Repo.clone_from( + url=self.bare_repo_path, to_path=self.repo_path, ) + self.repo = Repo(self.repo_path) def _clone_bare_repo_if_required(self) -> None: if not os.path.exists(self.bare_repo_path): - self._clone_from( + self.log.info("Cloning bare repository to %s", self.bare_repo_path) + Repo.clone_from( + url=self.repo_url, to_path=self.bare_repo_path, bare=True, + env=self.hook.env, ) self.bare_repo = Repo(self.bare_repo_path) @@ -197,13 +200,6 @@ def _has_version(repo: Repo, version: str) -> bool: except BadName: return False - def _fetch_repo(self): - if self.hook.env: - with 
self.repo.git.custom_environment(GIT_SSH_COMMAND=self.hook.env.get("GIT_SSH_COMMAND")): - self.repo.remotes.origin.fetch() - else: - self.repo.remotes.origin.fetch() - def _fetch_bare_repo(self): if self.hook.env: with self.bare_repo.git.custom_environment(GIT_SSH_COMMAND=self.hook.env.get("GIT_SSH_COMMAND")): @@ -211,18 +207,11 @@ def _fetch_bare_repo(self): else: self.bare_repo.remotes.origin.fetch("+refs/heads/*:refs/heads/*") - def _pull_repo(self): - if self.hook.env: - with self.repo.git.custom_environment(GIT_SSH_COMMAND=self.hook.env.get("GIT_SSH_COMMAND")): - self.repo.remotes.origin.pull() - else: - self.repo.remotes.origin.pull() - def refresh(self) -> None: if self.version: raise AirflowException("Refreshing a specific version is not supported") self._fetch_bare_repo() - self._pull_repo() + self.repo.remotes.origin.pull() def _convert_git_ssh_url_to_https(self) -> str: if not self.repo_url.startswith("git@"): diff --git a/airflow/dag_processing/bundles/manager.py b/airflow/dag_processing/bundles/manager.py index 1ae751f8d3304..ad1ebc5889159 100644 --- a/airflow/dag_processing/bundles/manager.py +++ b/airflow/dag_processing/bundles/manager.py @@ -96,6 +96,7 @@ def parse_config(self) -> None: class_ = import_string(cfg["classpath"]) kwargs = cfg["kwargs"] self._bundle_config[name] = (class_, kwargs) + self.log.info("DAG bundles loaded: %s", ", ".join(self._bundle_config.keys())) @provide_session def sync_bundles_to_db(self, *, session: Session = NEW_SESSION) -> None: diff --git a/airflow/dag_processing/manager.py b/airflow/dag_processing/manager.py index 0b831d238d1c3..28619be5b9359 100644 --- a/airflow/dag_processing/manager.py +++ b/airflow/dag_processing/manager.py @@ -653,6 +653,10 @@ def _refresh_dag_bundles(self): self.log.info("Refreshing DAG bundles") for bundle in self._dag_bundles: + # TODO: AIP-66 handle errors in the case of incomplete cloning? And test this. 
+ # What if the cloning/refreshing takes too long (longer than the dag processor timeout)? + if not bundle.is_initialized: + bundle.initialize() # TODO: AIP-66 test to make sure we get a fresh record from the db and it's not cached with create_session() as session: bundle_model = session.get(DagBundleModel, bundle.name) diff --git a/tests/dag_processing/test_dag_bundles.py b/tests/dag_processing/test_dag_bundles.py index 6032136610b27..49b7da1a03a92 100644 --- a/tests/dag_processing/test_dag_bundles.py +++ b/tests/dag_processing/test_dag_bundles.py @@ -215,7 +215,6 @@ def test_get_current_version(self, mock_githook, git_repo): @mock.patch("airflow.dag_processing.bundles.git.GitHook") def test_get_specific_version(self, mock_githook, git_repo): - mock_githook.get_conn.return_value = mock.MagicMock() repo_path, repo = git_repo mock_githook.return_value.repo_url = repo_path starting_commit = repo.head.commit @@ -242,7 +241,6 @@ def test_get_specific_version(self, mock_githook, git_repo): @mock.patch("airflow.dag_processing.bundles.git.GitHook") def test_get_tag_version(self, mock_githook, git_repo): - mock_githook.get_conn.return_value = mock.MagicMock() repo_path, repo = git_repo mock_githook.return_value.repo_url = repo_path starting_commit = repo.head.commit @@ -272,7 +270,6 @@ def test_get_tag_version(self, mock_githook, git_repo): @mock.patch("airflow.dag_processing.bundles.git.GitHook") def test_get_latest(self, mock_githook, git_repo): - mock_githook.get_conn.return_value = mock.MagicMock() repo_path, repo = git_repo mock_githook.return_value.repo_url = repo_path starting_commit = repo.head.commit @@ -293,7 +290,6 @@ def test_get_latest(self, mock_githook, git_repo): @mock.patch("airflow.dag_processing.bundles.git.GitHook") def test_refresh(self, mock_githook, git_repo): - mock_githook.get_conn.return_value = mock.MagicMock() repo_path, repo = git_repo mock_githook.return_value.repo_url = repo_path starting_commit = repo.head.commit @@ -321,7 +317,6 @@ def
test_refresh(self, mock_githook, git_repo): @mock.patch("airflow.dag_processing.bundles.git.GitHook") def test_head(self, mock_githook, git_repo): - mock_githook.get_conn.return_value = mock.MagicMock() repo_path, repo = git_repo mock_githook.return_value.repo_url = repo_path @@ -332,7 +327,6 @@ def test_head(self, mock_githook, git_repo): @mock.patch("airflow.dag_processing.bundles.git.GitHook") def test_version_not_found(self, mock_githook, git_repo): - mock_githook.get_conn.return_value = mock.MagicMock() repo_path, repo = git_repo mock_githook.return_value.repo_url = repo_path bundle = GitDagBundle( @@ -347,7 +341,6 @@ def test_version_not_found(self, mock_githook, git_repo): @mock.patch("airflow.dag_processing.bundles.git.GitHook") def test_subdir(self, mock_githook, git_repo): - mock_githook.get_conn.return_value = mock.MagicMock() repo_path, repo = git_repo mock_githook.return_value.repo_url = repo_path @@ -387,8 +380,7 @@ def test_raises_when_no_repo_url(self): @mock.patch("airflow.dag_processing.bundles.git.GitHook") @mock.patch("airflow.dag_processing.bundles.git.Repo") - @mock.patch.object(GitDagBundle, "_clone_from") - def test_with_path_as_repo_url(self, mock_clone_from, mock_gitRepo, mock_githook): + def test_with_path_as_repo_url(self, mock_gitRepo, mock_githook): bundle = GitDagBundle( name="test", refresh_interval=300, @@ -396,12 +388,11 @@ def test_with_path_as_repo_url(self, mock_clone_from, mock_gitRepo, mock_githook tracking_ref=GIT_DEFAULT_BRANCH, ) bundle.initialize() - assert mock_clone_from.call_count == 2 + assert mock_gitRepo.clone_from.call_count == 2 assert mock_gitRepo.return_value.git.checkout.call_count == 1 - @mock.patch("airflow.dag_processing.bundles.git.GitHook") @mock.patch("airflow.dag_processing.bundles.git.Repo") - def test_refresh_with_git_connection(self, mock_gitRepo, mock_hook): + def test_refresh_with_git_connection(self, mock_gitRepo): bundle = GitDagBundle( name="test", refresh_interval=300, @@ -423,7 +414,6 @@ def 
test_refresh_with_git_connection(self, mock_gitRepo, mock_hook): ) @mock.patch("airflow.dag_processing.bundles.git.GitHook") def test_repo_url_validation_for_ssh(self, mock_hook, repo_url, session): - mock_hook.get_conn.return_value = mock.MagicMock() mock_hook.return_value.repo_url = repo_url bundle = GitDagBundle( name="test",