From baeea081d585586f2e717446349dafdd359dc936 Mon Sep 17 00:00:00 2001 From: Ephraim Anierobi Date: Wed, 11 Dec 2024 14:05:30 +0100 Subject: [PATCH 01/19] Use SSH to authenticate GitDagBundle This uses SSH hook to authenticate GitDagBundle when provided. --- airflow/dag_processing/bundles/git.py | 61 ++++++++++++++++++++---- tests/dag_processing/test_dag_bundles.py | 23 +++++---- 2 files changed, 68 insertions(+), 16 deletions(-) diff --git a/airflow/dag_processing/bundles/git.py b/airflow/dag_processing/bundles/git.py index d731f65db3b88..63f105b9d98e1 100644 --- a/airflow/dag_processing/bundles/git.py +++ b/airflow/dag_processing/bundles/git.py @@ -18,6 +18,7 @@ from __future__ import annotations import os +import tempfile from typing import TYPE_CHECKING from urllib.parse import urlparse @@ -25,13 +26,14 @@ from git.exc import BadName from airflow.dag_processing.bundles.base import BaseDagBundle -from airflow.exceptions import AirflowException +from airflow.exceptions import AirflowException, AirflowOptionalProviderFeatureException +from airflow.utils.log.logging_mixin import LoggingMixin if TYPE_CHECKING: from pathlib import Path -class GitDagBundle(BaseDagBundle): +class GitDagBundle(BaseDagBundle, LoggingMixin): """ git DAG bundle - exposes a git repository as a DAG bundle. @@ -45,16 +47,30 @@ class GitDagBundle(BaseDagBundle): supports_versioning = True - def __init__(self, *, repo_url: str, tracking_ref: str, subdir: str | None = None, **kwargs) -> None: + def __init__( + self, + *, + tracking_ref: str, + subdir: str | None = None, + ssh_conn_id: str | None = None, + repo_url: str | os.PathLike = "", + **kwargs, + ) -> None: super().__init__(**kwargs) self.repo_url = repo_url self.tracking_ref = tracking_ref self.subdir = subdir - + self.ssh_conn_id = ssh_conn_id self.bare_repo_path = self._dag_bundle_root_storage_path / "git" / self.name self.repo_path = ( self._dag_bundle_root_storage_path / "git" / (self.name + f"+{self.version or self.tracking_ref}") ) + self.env: dict[str, str] = {} + + def _clone_from(self, to_path: Path, bare: bool = False) -> Repo: + return Repo.clone_from(self.repo_url, to_path, bare=bare, env=self.env) + + def _init_bundle(self): self._clone_bare_repo_if_required() self._ensure_version_in_bare_repo() self._clone_repo_if_required() @@ -69,18 +85,47 @@ def __init__(self, *, repo_url: str, tracking_ref: str, subdir: str | None = Non else: self.refresh() + def init_bundle(self) -> None: + if self.ssh_conn_id: + try: + from airflow.providers.ssh.hooks.ssh import SSHHook + except ImportError as e: + raise AirflowOptionalProviderFeatureException(e) + ssh_hook = SSHHook(ssh_conn_id=self.ssh_conn_id) + ssh_hook.get_conn() + temp_key_file_path = None + try: + if not ssh_hook.key_file: + conn = ssh_hook.get_connection(self.ssh_conn_id) + private_key = conn.extra_dejson.get("private_key") + if not private_key: + raise AirflowException("No private key present in connection") + with tempfile.NamedTemporaryFile(delete=False) as key_file: + temp_key_file_path = key_file.name + key_file.write(private_key.encode("utf-8")) + self.env["GIT_SSH_COMMAND"] = f"ssh -i {temp_key_file_path} -o IdentitiesOnly=yes" + else: + self.env["GIT_SSH_COMMAND"] = f"ssh -i {ssh_hook.key_file} -o IdentitiesOnly=yes" + if ssh_hook.remote_host: + self.log.info("Using repo URL defined in the SSH connection") + self.repo_url = ssh_hook.remote_host + self._init_bundle() + finally: + if temp_key_file_path: + os.remove(temp_key_file_path) + else: + self._init_bundle() + def _clone_repo_if_required(self) -> None: if not os.path.exists(self.repo_path): - Repo.clone_from( - url=self.bare_repo_path, + self._clone_from( to_path=self.repo_path, ) self.repo = Repo(self.repo_path) def _clone_bare_repo_if_required(self) -> None: if not os.path.exists(self.bare_repo_path): - Repo.clone_from( - url=self.repo_url, + self._clone_from( to_path=self.bare_repo_path, bare=True, ) diff --git a/tests/dag_processing/test_dag_bundles.py b/tests/dag_processing/test_dag_bundles.py index d450a56131361..21e0a99a3fe7c 100644 --- a/tests/dag_processing/test_dag_bundles.py +++ b/tests/dag_processing/test_dag_bundles.py @@ -123,6 +123,7 @@ def test_get_current_version(self, git_repo): bundle = GitDagBundle( name="test", refresh_interval=300, repo_url=repo_path, tracking_ref=GIT_DEFAULT_BRANCH ) + bundle.init_bundle() assert bundle.get_current_version() == repo.head.commit.hexsha @@ -144,6 +145,7 @@ def test_get_specific_version(self, git_repo): repo_url=repo_path, tracking_ref=GIT_DEFAULT_BRANCH, ) + bundle.init_bundle() assert bundle.get_current_version() == starting_commit.hexsha @@ -172,7 +174,7 @@ def test_get_tag_version(self, git_repo): repo_url=repo_path, tracking_ref=GIT_DEFAULT_BRANCH, ) - + bundle.init_bundle() assert bundle.get_current_version() == starting_commit.hexsha files_in_repo = {f.name for f in bundle.path.iterdir() if f.is_file()} @@ -191,6 +193,7 @@ def test_get_latest(self, git_repo): bundle = GitDagBundle( name="test", refresh_interval=300, repo_url=repo_path, tracking_ref=GIT_DEFAULT_BRANCH ) + bundle.init_bundle() assert bundle.get_current_version() != starting_commit.hexsha @@ -204,6 +207,7 @@ def test_refresh(self, git_repo): bundle = GitDagBundle( name="test", refresh_interval=300, repo_url=repo_path, tracking_ref=GIT_DEFAULT_BRANCH ) + bundle.init_bundle() assert bundle.get_current_version() == starting_commit.hexsha @@ -228,19 +232,21 @@ def test_head(self, git_repo): repo.create_head("test") bundle = GitDagBundle(name="test", refresh_interval=300, repo_url=repo_path, tracking_ref="test") + bundle.init_bundle() assert bundle.repo.head.ref.name == "test" def test_version_not_found(self, git_repo): repo_path, repo = git_repo + bundle = GitDagBundle( + name="test", + refresh_interval=300, + version="not_found", + repo_url=repo_path, + tracking_ref=GIT_DEFAULT_BRANCH, + ) with pytest.raises(AirflowException, match="Version not_found not found in the repository"): - GitDagBundle( - name="test", - refresh_interval=300, - version="not_found", - repo_url=repo_path, - tracking_ref=GIT_DEFAULT_BRANCH, - ) + bundle.init_bundle() def test_subdir(self, git_repo): repo_path, repo = git_repo @@ -262,6 +268,7 @@ def test_subdir(self, git_repo): tracking_ref=GIT_DEFAULT_BRANCH, subdir=subdir, ) + bundle.init_bundle() files_in_repo = {f.name for f in bundle.path.iterdir() if f.is_file()} assert str(bundle.path).endswith(subdir) From b64392412405c76fe7052ea21f2721a9bf7720b1 Mon Sep 17 00:00:00 2001 From: Ephraim Anierobi Date: Tue, 17 Dec 2024 17:50:17 +0100 Subject: [PATCH 02/19] Add tests --- tests/dag_processing/test_dag_bundles.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/tests/dag_processing/test_dag_bundles.py b/tests/dag_processing/test_dag_bundles.py index 21e0a99a3fe7c..a563a45f61253 100644 --- a/tests/dag_processing/test_dag_bundles.py +++ b/tests/dag_processing/test_dag_bundles.py @@ -274,6 +274,25 @@ def test_subdir(self, git_repo): assert str(bundle.path).endswith(subdir) assert {"some_new_file.py"} == files_in_repo + @mock.patch("airflow.providers.ssh.hooks.ssh.SSHHook") + @mock.patch("airflow.dag_processing.bundles.git.Repo") + def test_with_ssh_conn_id(self, mock_gitRepo, mock_hook): + repo_url = "git@github.com:apache/airflow.git" + conn_id = "ssh_default" + key_filepath = "/path/to/keyfile" + mock_hook.return_value.key_file = key_filepath + mock_hook.return_value.remote_host = repo_url + bundle = GitDagBundle(name="test", ssh_conn_id="ssh_default", tracking_ref=GIT_DEFAULT_BRANCH) + assert bundle.env == {} + bundle.init_bundle() + mock_hook.assert_called_once_with(ssh_conn_id=conn_id) + assert bundle.env == {"GIT_SSH_COMMAND": f"ssh -i {key_filepath} -o IdentitiesOnly=yes"} + + def test_no_key_file_and_no_private_key_raises_for_ssh_conn(self): + bundle = GitDagBundle(name="test", ssh_conn_id="ssh_default", tracking_ref=GIT_DEFAULT_BRANCH) + with pytest.raises(AirflowException, match="No private key present in connection"): + bundle.init_bundle() + @pytest.mark.parametrize( "repo_url, expected_url", [ From af5d3d9ef46898f238753bad31a02bf5767593a7 Mon Sep 17 00:00:00 2001 From: Ephraim Anierobi Date: Wed, 18 Dec 2024 10:38:46 +0100 Subject: [PATCH 03/19] Account for remotes with ssh --- airflow/dag_processing/bundles/git.py | 40 ++++++++++++++++++++---- tests/dag_processing/test_dag_bundles.py | 1 - 2 files changed, 34 insertions(+), 7 deletions(-) diff --git a/airflow/dag_processing/bundles/git.py b/airflow/dag_processing/bundles/git.py index 63f105b9d98e1..73a9a41d291d8 100644 --- a/airflow/dag_processing/bundles/git.py +++ b/airflow/dag_processing/bundles/git.py @@ -66,6 +66,8 @@ def __init__( self._dag_bundle_root_storage_path / "git" / (self.name + f"+{self.version or self.tracking_ref}") ) self.env: dict[str, str] = {} + self.pkey: str | None = None + self.key_file: str | None = None def _clone_from(self, to_path: Path, bare: bool = False) -> Repo: return Repo.clone_from(self.repo_url, to_path, bare=bare, env=self.env) @@ -94,22 +96,25 @@ def init_bundle(self) -> None: ssh_hook = SSHHook(ssh_conn_id=self.ssh_conn_id) ssh_hook.get_conn() temp_key_file_path = None + self.key_file = ssh_hook.key_file try: - if not ssh_hook.key_file: + if not self.key_file: conn = ssh_hook.get_connection(self.ssh_conn_id) - private_key = conn.extra_dejson.get("private_key") + private_key: str | None = conn.extra_dejson.get("private_key") if not private_key: raise AirflowException("No private key present in connection") + self.pkey = private_key with tempfile.NamedTemporaryFile(delete=False) as key_file: temp_key_file_path = key_file.name - key_file.write(private_key.encode("utf-8")) + key_file.write(self.pkey.encode("utf-8")) self.env["GIT_SSH_COMMAND"] = f"ssh -i {temp_key_file_path} -o IdentitiesOnly=yes" else: - self.env["GIT_SSH_COMMAND"] = f"ssh -i {ssh_hook.key_file} -o IdentitiesOnly=yes" + self.env["GIT_SSH_COMMAND"] = f"ssh -i {self.key_file} -o IdentitiesOnly=yes" if ssh_hook.remote_host: self.log.info("Using repo URL defined in the SSH connection") self.repo_url = ssh_hook.remote_host self._init_bundle() + self.env = {} finally: if temp_key_file_path: os.remove(temp_key_file_path) @@ -170,8 +175,31 @@ def refresh(self) -> None: if self.version: raise AirflowException("Refreshing a specific version is not supported") - self.bare_repo.remotes.origin.fetch("+refs/heads/*:refs/heads/*") - self.repo.remotes.origin.pull() + def _refresh(): + self.bare_repo.remotes.origin.fetch("+refs/heads/*:refs/heads/*") + self.repo.remotes.origin.pull() + + if self.env: + _refresh() + elif self.ssh_conn_id: + temp_key_file_path = None + try: + if not self.key_file: + if not self.pkey: + raise AirflowException("Missing private key, please initialize the bundle first") + with tempfile.NamedTemporaryFile(delete=False) as key_file: + temp_key_file_path = key_file.name + key_file.write(self.pkey.encode("utf-8")) + GIT_SSH_COMMAND = f"ssh -i {temp_key_file_path} -o IdentitiesOnly=yes" + else: + GIT_SSH_COMMAND = f"ssh -i {self.key_file} -o IdentitiesOnly=yes" + with self.repo.git.custom_environment(GIT_SSH_COMMAND=GIT_SSH_COMMAND): + _refresh() + finally: + if temp_key_file_path: + os.remove(temp_key_file_path) + else: + _refresh() def _convert_git_ssh_url_to_https(self) -> str: if not self.repo_url.startswith("git@"): diff --git a/tests/dag_processing/test_dag_bundles.py b/tests/dag_processing/test_dag_bundles.py index a563a45f61253..ddf4e3cd2b0fb 100644 --- a/tests/dag_processing/test_dag_bundles.py +++ b/tests/dag_processing/test_dag_bundles.py @@ -286,7 +286,6 @@ def test_with_ssh_conn_id(self, mock_gitRepo, mock_hook): assert bundle.env == {} bundle.init_bundle() mock_hook.assert_called_once_with(ssh_conn_id=conn_id) - assert bundle.env == {"GIT_SSH_COMMAND": f"ssh -i {key_filepath} -o IdentitiesOnly=yes"} def test_no_key_file_and_no_private_key_raises_for_ssh_conn(self): bundle = GitDagBundle(name="test", ssh_conn_id="ssh_default", tracking_ref=GIT_DEFAULT_BRANCH) From 4b482699cd9ad641741b1c9307583b858dd130d5 Mon Sep 17 00:00:00 2001 From: Ephraim Anierobi Date: Wed, 18 Dec 2024 10:45:35 +0100 Subject: [PATCH 04/19] renames --- airflow/dag_processing/bundles/git.py | 8 ++++---- tests/dag_processing/test_dag_bundles.py | 20 ++++++++++---------- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/airflow/dag_processing/bundles/git.py b/airflow/dag_processing/bundles/git.py index 73a9a41d291d8..d3e084772a454 100644 --- a/airflow/dag_processing/bundles/git.py +++ b/airflow/dag_processing/bundles/git.py @@ -72,7 +72,7 @@ def __init__( def _clone_from(self, to_path: Path, bare: bool = False) -> Repo: return Repo.clone_from(self.repo_url, to_path, bare=bare, env=self.env) - def _init_bundle(self): + def _initialize(self): self._clone_bare_repo_if_required() self._ensure_version_in_bare_repo() self._clone_repo_if_required() @@ -87,7 +87,7 @@ def _init_bundle(self): else: self.refresh() - def init_bundle(self) -> None: + def initialize(self) -> None: if self.ssh_conn_id: try: from airflow.providers.ssh.hooks.ssh import SSHHook @@ -113,13 +113,13 @@ def init_bundle(self) -> None: if ssh_hook.remote_host: self.log.info("Using repo URL defined in the SSH connection") self.repo_url = ssh_hook.remote_host - self._init_bundle() + self._initialize() self.env = {} finally: if temp_key_file_path: os.remove(temp_key_file_path) else: - self._init_bundle() + self._initialize() def _clone_repo_if_required(self) -> None: if not os.path.exists(self.repo_path): diff --git a/tests/dag_processing/test_dag_bundles.py b/tests/dag_processing/test_dag_bundles.py index ddf4e3cd2b0fb..e7b36267b878b 100644 --- a/tests/dag_processing/test_dag_bundles.py +++ b/tests/dag_processing/test_dag_bundles.py @@ -123,7 +123,7 @@ def test_get_current_version(self, git_repo): bundle = GitDagBundle( name="test", refresh_interval=300, repo_url=repo_path, tracking_ref=GIT_DEFAULT_BRANCH ) - bundle.init_bundle() + bundle.initialize() assert bundle.get_current_version() == repo.head.commit.hexsha @@ -145,7 +145,7 @@ def test_get_specific_version(self, git_repo): repo_url=repo_path, tracking_ref=GIT_DEFAULT_BRANCH, ) - bundle.init_bundle() + bundle.initialize() assert bundle.get_current_version() == starting_commit.hexsha @@ -174,7 +174,7 @@ def test_get_tag_version(self, git_repo): repo_url=repo_path, tracking_ref=GIT_DEFAULT_BRANCH, ) - bundle.init_bundle() + bundle.initialize() assert bundle.get_current_version() == starting_commit.hexsha files_in_repo = {f.name for f in bundle.path.iterdir() if f.is_file()} @@ -193,7 +193,7 @@ def test_get_latest(self, git_repo): bundle = GitDagBundle( name="test", refresh_interval=300, repo_url=repo_path, tracking_ref=GIT_DEFAULT_BRANCH ) - bundle.init_bundle() + bundle.initialize() assert bundle.get_current_version() != starting_commit.hexsha @@ -207,7 +207,7 @@ def test_refresh(self, git_repo): bundle = GitDagBundle( name="test", refresh_interval=300, repo_url=repo_path, tracking_ref=GIT_DEFAULT_BRANCH ) - bundle.init_bundle() + bundle.initialize() assert bundle.get_current_version() == starting_commit.hexsha @@ -232,7 +232,7 @@ def test_head(self, git_repo): repo.create_head("test") bundle = GitDagBundle(name="test", refresh_interval=300, repo_url=repo_path, tracking_ref="test") - bundle.init_bundle() + bundle.initialize() assert bundle.repo.head.ref.name == "test" def test_version_not_found(self, git_repo): @@ -246,7 +246,7 @@ def test_version_not_found(self, git_repo): ) with pytest.raises(AirflowException, match="Version not_found not found in the repository"): - bundle.init_bundle() + bundle.initialize() def test_subdir(self, git_repo): repo_path, repo = git_repo @@ -268,7 +268,7 @@ def test_subdir(self, git_repo): tracking_ref=GIT_DEFAULT_BRANCH, subdir=subdir, ) - bundle.init_bundle() + bundle.initialize() files_in_repo = {f.name for f in bundle.path.iterdir() if f.is_file()} assert str(bundle.path).endswith(subdir) @@ -284,13 +284,13 @@ def test_with_ssh_conn_id(self, mock_gitRepo, mock_hook): mock_hook.return_value.remote_host = repo_url bundle = GitDagBundle(name="test", ssh_conn_id="ssh_default", tracking_ref=GIT_DEFAULT_BRANCH) assert bundle.env == {} - bundle.init_bundle() + bundle.initialize() mock_hook.assert_called_once_with(ssh_conn_id=conn_id) def test_no_key_file_and_no_private_key_raises_for_ssh_conn(self): bundle = GitDagBundle(name="test", ssh_conn_id="ssh_default", tracking_ref=GIT_DEFAULT_BRANCH) with pytest.raises(AirflowException, match="No private key present in connection"): - bundle.init_bundle() + bundle.initialize() @pytest.mark.parametrize( "repo_url, expected_url", From 30f4e8acca6a8d1cbcc360d89fdfa159df20b57c Mon Sep 17 00:00:00 2001 From: Ephraim Anierobi Date: Wed, 18 Dec 2024 15:19:08 +0100 Subject: [PATCH 05/19] fix tests --- airflow/dag_processing/bundles/git.py | 11 +-- tests/dag_processing/test_dag_bundles.py | 87 +++++++++++++++++++++++- 2 files changed, 91 insertions(+), 7 deletions(-) diff --git a/airflow/dag_processing/bundles/git.py b/airflow/dag_processing/bundles/git.py index d3e084772a454..2ccbabdc03b57 100644 --- a/airflow/dag_processing/bundles/git.py +++ b/airflow/dag_processing/bundles/git.py @@ -43,6 +43,7 @@ class GitDagBundle(BaseDagBundle, LoggingMixin): :param repo_url: URL of the git repository :param tracking_ref: Branch or tag for this DAG bundle :param subdir: Subdirectory within the repository where the DAGs are stored (Optional) + :param ssh_conn_id: Connection ID for SSH connection to the repository (Optional) """ supports_versioning = True @@ -50,10 +51,10 @@ class GitDagBundle(BaseDagBundle, LoggingMixin): def __init__( self, *, + repo_url: str | os.PathLike = "", tracking_ref: str, subdir: str | None = None, ssh_conn_id: str | None = None, - repo_url: str | os.PathLike = "", **kwargs, ) -> None: super().__init__(**kwargs) @@ -179,9 +180,11 @@ def _refresh(): self.bare_repo.remotes.origin.fetch("+refs/heads/*:refs/heads/*") self.repo.remotes.origin.pull() - if self.env: - _refresh() - elif self.ssh_conn_id: + if self.ssh_conn_id: + if self.env: + with self.repo.git.custom_environment(**self.env): + _refresh() + return temp_key_file_path = None try: if not self.key_file: diff --git a/tests/dag_processing/test_dag_bundles.py b/tests/dag_processing/test_dag_bundles.py index e7b36267b878b..1de7593ed6aab 100644 --- a/tests/dag_processing/test_dag_bundles.py +++ b/tests/dag_processing/test_dag_bundles.py @@ -20,6 +20,7 @@ import tempfile from pathlib import Path from unittest import mock +from unittest.mock import MagicMock import pytest from git import Repo @@ -282,16 +283,96 @@ def test_with_ssh_conn_id(self, mock_gitRepo, mock_hook): key_filepath = "/path/to/keyfile" mock_hook.return_value.key_file = key_filepath mock_hook.return_value.remote_host = repo_url - bundle = GitDagBundle(name="test", ssh_conn_id="ssh_default", tracking_ref=GIT_DEFAULT_BRANCH) + bundle = GitDagBundle( + name="test", + refresh_interval=300, + ssh_conn_id="ssh_default", + tracking_ref=GIT_DEFAULT_BRANCH, + ) assert bundle.env == {} bundle.initialize() mock_hook.assert_called_once_with(ssh_conn_id=conn_id) - def test_no_key_file_and_no_private_key_raises_for_ssh_conn(self): - bundle = GitDagBundle(name="test", ssh_conn_id="ssh_default", tracking_ref=GIT_DEFAULT_BRANCH) + @mock.patch("airflow.providers.ssh.hooks.ssh.SSHHook") + def test_no_key_file_and_no_private_key_raises_for_ssh_conn(self, mock_hook): + bundle = GitDagBundle( + name="test", refresh_interval=300, ssh_conn_id="ssh_default", tracking_ref=GIT_DEFAULT_BRANCH + ) with pytest.raises(AirflowException, match="No private key present in connection"): bundle.initialize() + @mock.patch("airflow.providers.ssh.hooks.ssh.SSHHook") + @mock.patch("airflow.dag_processing.bundles.git.Repo") + @mock.patch("airflow.dag_processing.bundles.git.os") + def test_temporary_file_removed_after_initialization(self, mock_os, mock_gitRepo, mock_hook): + repo_url = "git@github.com:apache/airflow.git" + ssh_hook = mock_hook.return_value + ssh_hook.key_file = None + ssh_hook.remote_host = repo_url + conn = MagicMock() + conn.extra_dejson = {"private_key": "private"} + ssh_hook.get_connection.return_value = conn + bundle = GitDagBundle( + name="test", + refresh_interval=300, + ssh_conn_id="ssh_default", + tracking_ref=GIT_DEFAULT_BRANCH, + ) + bundle.initialize() + assert mock_os.remove.call_count == 1 + + @mock.patch("airflow.providers.ssh.hooks.ssh.SSHHook") + @mock.patch("airflow.dag_processing.bundles.git.Repo") + def test_refresh_with_an_existing_env(self, mock_gitRepo, mock_hook): + repo_url = "git@github.com:apache/airflow.git" + key_filepath = "/path/to/keyfile" + mock_hook.return_value.key_file = key_filepath + mock_hook.return_value.remote_host = repo_url + bundle = GitDagBundle( + name="test", + refresh_interval=300, + ssh_conn_id="ssh_default", + tracking_ref=GIT_DEFAULT_BRANCH, + ) + bundle.initialize() + bundle.env = {"GIT_SSH_COMMAND": "ssh -i /path/to/keyfile -o IdentitiesOnly=yes"} + bundle.refresh() + + # check remotes called twice. one at initialize and one at refresh above + assert mock_gitRepo.return_value.remotes.origin.fetch.call_count == 2 + # assert remotes called with custom env + mock_gitRepo.return_value.git.custom_environment.assert_called_with(**bundle.env) + + def test_refresh_with_conn_id_raises_when_bundle_not_initialized(self): + bundle = GitDagBundle( + name="test", refresh_interval=300, ssh_conn_id="ssh_default", tracking_ref=GIT_DEFAULT_BRANCH + ) + with pytest.raises(AirflowException, match="Missing private key, please initialize the bundle first"): + bundle.refresh() + + @mock.patch("airflow.providers.ssh.hooks.ssh.SSHHook") + @mock.patch("airflow.dag_processing.bundles.git.Repo") + @mock.patch("airflow.dag_processing.bundles.git.os") + def test_temporary_file_removed_in_refresh(self, mock_os, mock_gitRepo, mock_hook): + repo_url = "git@github.com:apache/airflow.git" + ssh_hook = mock_hook.return_value + ssh_hook.key_file = None + ssh_hook.remote_host = repo_url + conn = MagicMock() + conn.extra_dejson = {"private_key": "private"} + ssh_hook.get_connection.return_value = conn + bundle = GitDagBundle( + name="test", + refresh_interval=300, + ssh_conn_id="ssh_default", + tracking_ref=GIT_DEFAULT_BRANCH, + ) + bundle.initialize() + # Check os.remove called. + bundle.refresh() + # Check os.remove called twice. Once in initialization and another in the method + assert mock_os.remove.call_count == 2 + @pytest.mark.parametrize( "repo_url, expected_url", [ From b6032ae6e92aaed08952b341bf6e296c3d17af71 Mon Sep 17 00:00:00 2001 From: Ephraim Anierobi Date: Thu, 19 Dec 2024 11:45:39 +0100 Subject: [PATCH 06/19] Refactor code --- airflow/dag_processing/bundles/git.py | 94 ++++++++---------------- tests/dag_processing/test_dag_bundles.py | 80 ++++---------------- 2 files changed, 43 insertions(+), 131 deletions(-) diff --git a/airflow/dag_processing/bundles/git.py b/airflow/dag_processing/bundles/git.py index 2ccbabdc03b57..a2af7ff8630bd 100644 --- a/airflow/dag_processing/bundles/git.py +++ b/airflow/dag_processing/bundles/git.py @@ -19,7 +19,7 @@ import os import tempfile -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from urllib.parse import urlparse from git import Repo @@ -51,27 +51,26 @@ class GitDagBundle(BaseDagBundle, LoggingMixin): def __init__( self, *, - repo_url: str | os.PathLike = "", + repo_url: str, tracking_ref: str, subdir: str | None = None, - ssh_conn_id: str | None = None, + ssh_conn_kwargs: dict[str, str] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) self.repo_url = repo_url self.tracking_ref = tracking_ref self.subdir = subdir - self.ssh_conn_id = ssh_conn_id self.bare_repo_path = self._dag_bundle_root_storage_path / "git" / self.name self.repo_path = ( self._dag_bundle_root_storage_path / "git" / (self.name + f"+{self.version or self.tracking_ref}") ) - self.env: dict[str, str] = {} - self.pkey: str | None = None - self.key_file: str | None = None + self.ssh_conn_kwargs = ssh_conn_kwargs + self.ssh_hook: Any | None = None def _clone_from(self, to_path: Path, bare: bool = False) -> Repo: - return Repo.clone_from(self.repo_url, to_path, bare=bare, env=self.env) + self.log.info("Cloning %s to %s", self.repo_url, to_path) + return Repo.clone_from(self.repo_url, to_path, bare=bare) def _initialize(self): self._clone_bare_repo_if_required() @@ -86,39 +85,24 @@ def _initialize(self): self.repo.head.set_reference(self.repo.commit(self.version)) self.repo.head.reset(index=True, working_tree=True) else: - self.refresh() + self._refresh() + + def _ssh_hook(self): + try: + from airflow.providers.ssh.hooks.ssh import SSHHook + except ImportError as e: + raise AirflowOptionalProviderFeatureException(e) + return SSHHook(**self.ssh_conn_kwargs) def initialize(self) -> None: - if self.ssh_conn_id: - try: - from airflow.providers.ssh.hooks.ssh import SSHHook - except ImportError as e: - raise AirflowOptionalProviderFeatureException(e) - ssh_hook = SSHHook(ssh_conn_id=self.ssh_conn_id) - ssh_hook.get_conn() - temp_key_file_path = None - self.key_file = ssh_hook.key_file - try: - if not self.key_file: - conn = ssh_hook.get_connection(self.ssh_conn_id) - private_key: str | None = conn.extra_dejson.get("private_key") - if not private_key: - raise AirflowException("No private key present in connection") - self.pkey = private_key - with tempfile.NamedTemporaryFile(delete=False) as key_file: - temp_key_file_path = key_file.name - key_file.write(self.pkey.encode("utf-8")) - self.env["GIT_SSH_COMMAND"] = f"ssh -i {temp_key_file_path} -o IdentitiesOnly=yes" - else: - self.env["GIT_SSH_COMMAND"] = f"ssh -i {self.key_file} -o IdentitiesOnly=yes" - if ssh_hook.remote_host: - self.log.info("Using repo URL defined in the SSH connection") - self.repo_url = ssh_hook.remote_host + if self.ssh_conn_kwargs: + if not self.repo_url.startswith("git@") and not self.repo_url.endswith(".git"): + raise AirflowException( + f"Invalid git URL: {self.repo_url}. URL must start with git@ and end with .git" + ) + self.ssh_hook = self._ssh_hook() + with self.ssh_hook.get_conn(): self._initialize() - self.env = {} - finally: - if temp_key_file_path: - os.remove(temp_key_file_path) else: self._initialize() @@ -172,37 +156,19 @@ def _has_version(repo: Repo, version: str) -> bool: except BadName: return False + def _refresh(self): + self.bare_repo.remotes.origin.fetch("+refs/heads/*:refs/heads/*") + self.repo.remotes.origin.pull() + def refresh(self) -> None: if self.version: raise AirflowException("Refreshing a specific version is not supported") - def _refresh(): - self.bare_repo.remotes.origin.fetch("+refs/heads/*:refs/heads/*") - self.repo.remotes.origin.pull() - - if self.ssh_conn_id: - if self.env: - with self.repo.git.custom_environment(**self.env): - _refresh() - return - temp_key_file_path = None - try: - if not self.key_file: - if not self.pkey: - raise AirflowException("Missing private key, please initialize the bundle first") - with tempfile.NamedTemporaryFile(delete=False) as key_file: - temp_key_file_path = key_file.name - key_file.write(self.pkey.encode("utf-8")) - GIT_SSH_COMMAND = f"ssh -i {temp_key_file_path} -o IdentitiesOnly=yes" - else: - GIT_SSH_COMMAND = f"ssh -i {self.key_file} -o IdentitiesOnly=yes" - with self.repo.git.custom_environment(GIT_SSH_COMMAND=GIT_SSH_COMMAND): - _refresh() - finally: - if temp_key_file_path: - os.remove(temp_key_file_path) + if self.ssh_hook: + with self.ssh_hook.get_conn(): + self._refresh() else: - _refresh() + self._refresh() def _convert_git_ssh_url_to_https(self) -> str: if not self.repo_url.startswith("git@"): diff --git a/tests/dag_processing/test_dag_bundles.py b/tests/dag_processing/test_dag_bundles.py index 1de7593ed6aab..b409b53b95aae 100644 --- a/tests/dag_processing/test_dag_bundles.py +++ b/tests/dag_processing/test_dag_bundles.py @@ -20,7 +20,6 @@ import tempfile from pathlib import Path from unittest import mock -from unittest.mock import MagicMock import pytest from git import Repo @@ -280,98 +279,45 @@ def test_subdir(self, git_repo): def test_with_ssh_conn_id(self, mock_gitRepo, mock_hook): repo_url = "git@github.com:apache/airflow.git" conn_id = "ssh_default" - key_filepath = "/path/to/keyfile" - mock_hook.return_value.key_file = key_filepath - mock_hook.return_value.remote_host = repo_url bundle = GitDagBundle( + repo_url=repo_url, name="test", refresh_interval=300, - ssh_conn_id="ssh_default", + ssh_conn_kwargs={"ssh_conn_id": "ssh_default"}, tracking_ref=GIT_DEFAULT_BRANCH, ) - assert bundle.env == {} bundle.initialize() mock_hook.assert_called_once_with(ssh_conn_id=conn_id) - @mock.patch("airflow.providers.ssh.hooks.ssh.SSHHook") - def test_no_key_file_and_no_private_key_raises_for_ssh_conn(self, mock_hook): - bundle = GitDagBundle( - name="test", refresh_interval=300, ssh_conn_id="ssh_default", tracking_ref=GIT_DEFAULT_BRANCH - ) - with pytest.raises(AirflowException, match="No private key present in connection"): - bundle.initialize() - - @mock.patch("airflow.providers.ssh.hooks.ssh.SSHHook") - @mock.patch("airflow.dag_processing.bundles.git.Repo") - @mock.patch("airflow.dag_processing.bundles.git.os") - def test_temporary_file_removed_after_initialization(self, mock_os, mock_gitRepo, mock_hook): - repo_url = "git@github.com:apache/airflow.git" - ssh_hook = mock_hook.return_value - ssh_hook.key_file = None - ssh_hook.remote_host = repo_url - conn = MagicMock() - conn.extra_dejson = {"private_key": "private"} - ssh_hook.get_connection.return_value = conn - bundle = GitDagBundle( - name="test", - refresh_interval=300, - ssh_conn_id="ssh_default", - tracking_ref=GIT_DEFAULT_BRANCH, - ) - bundle.initialize() - assert mock_os.remove.call_count == 1 - @mock.patch("airflow.providers.ssh.hooks.ssh.SSHHook") @mock.patch("airflow.dag_processing.bundles.git.Repo") - def test_refresh_with_an_existing_env(self, mock_gitRepo, mock_hook): + def test_refresh_with_ssh_connection(self, mock_gitRepo, mock_hook): repo_url = "git@github.com:apache/airflow.git" - key_filepath = "/path/to/keyfile" - mock_hook.return_value.key_file = key_filepath - mock_hook.return_value.remote_host = repo_url bundle = GitDagBundle( + repo_url=repo_url, name="test", refresh_interval=300, - ssh_conn_id="ssh_default", + ssh_conn_kwargs={"ssh_conn_id": "ssh_default"}, tracking_ref=GIT_DEFAULT_BRANCH, ) bundle.initialize() - bundle.env = {"GIT_SSH_COMMAND": "ssh -i /path/to/keyfile -o IdentitiesOnly=yes"} bundle.refresh() - # check remotes called twice. one at initialize and one at refresh above assert mock_gitRepo.return_value.remotes.origin.fetch.call_count == 2 - # assert remotes called with custom env - mock_gitRepo.return_value.git.custom_environment.assert_called_with(**bundle.env) - - def test_refresh_with_conn_id_raises_when_bundle_not_initialized(self): - bundle = GitDagBundle( - name="test", refresh_interval=300, ssh_conn_id="ssh_default", tracking_ref=GIT_DEFAULT_BRANCH - ) - with pytest.raises(AirflowException, match="Missing private key, please initialize the bundle first"): - bundle.refresh() - @mock.patch("airflow.providers.ssh.hooks.ssh.SSHHook") - @mock.patch("airflow.dag_processing.bundles.git.Repo") - @mock.patch("airflow.dag_processing.bundles.git.os") - def test_temporary_file_removed_in_refresh(self, mock_os, mock_gitRepo, mock_hook): - repo_url = "git@github.com:apache/airflow.git" - ssh_hook = mock_hook.return_value - ssh_hook.key_file = None - ssh_hook.remote_host = repo_url - conn = MagicMock() - conn.extra_dejson = {"private_key": "private"} - ssh_hook.get_connection.return_value = conn + def test_repo_url_starts_with_git_when_using_ssh_conn_id(self): + repo_url = "https://github.com/apache/airflow" bundle = GitDagBundle( + repo_url=repo_url, name="test", refresh_interval=300, - ssh_conn_id="ssh_default", + ssh_conn_kwargs={"ssh_conn_id": "ssh_default"}, tracking_ref=GIT_DEFAULT_BRANCH, ) - bundle.initialize() - # Check os.remove called. - bundle.refresh() - # Check os.remove called twice. Once in initialization and another in the method - assert mock_os.remove.call_count == 2 + with pytest.raises( + AirflowException, match=f"Invalid git URL: {repo_url}. URL must start with git@ and end with .git" + ): + bundle.initialize() @pytest.mark.parametrize( "repo_url, expected_url", From 369461b646e8d466bbe6352816e2f5f7b92a8e1b Mon Sep 17 00:00:00 2001 From: Ephraim Anierobi Date: Mon, 23 Dec 2024 12:47:27 +0100 Subject: [PATCH 07/19] Use githook --- airflow/dag_processing/bundles/git.py | 83 ++++++++++++++++---- airflow/dag_processing/bundles/provider.yaml | 46 +++++++++++ 2 files changed, 112 insertions(+), 17 deletions(-) create mode 100644 airflow/dag_processing/bundles/provider.yaml diff --git a/airflow/dag_processing/bundles/git.py b/airflow/dag_processing/bundles/git.py index a2af7ff8630bd..75fed74332d69 100644 --- a/airflow/dag_processing/bundles/git.py +++ b/airflow/dag_processing/bundles/git.py @@ -29,9 +29,65 @@ from airflow.exceptions import AirflowException, AirflowOptionalProviderFeatureException from airflow.utils.log.logging_mixin import LoggingMixin +try: + from airflow.providers.ssh.hooks.ssh import SSHHook +except ImportError as e: + raise AirflowOptionalProviderFeatureException(e) + if TYPE_CHECKING: from pathlib import Path + import paramiko + + +class GitHook(SSHHook): + """ + Hook for git repositories. + + :param git_conn_id: Connection ID for SSH connection to the repository + + """ + + conn_name_attr = "git_conn_id" + default_conn_name = "git_default" + conn_type = "git" + hook_name = "GIT" + + @classmethod + def get_ui_field_behaviour(cls) -> dict[str, Any]: + return { + "hidden_fields": ["schema"], + "relabeling": { + "login": "Username", + }, + } + + def __init__(self, git_conn_id="git_default", *args, **kwargs): + self.conn: paramiko.SSHClient | None = None + kwargs["ssh_conn_id"] = git_conn_id + super().__init__(*args, **kwargs) + connection = self.get_connection(git_conn_id) + self.repo_url = connection.extra_dejson.get("git_repo_url") + self.auth_token = connection.extra_dejson.get("git_access_token", None) or connection.password + self._process_git_auth_url() + + def _process_git_auth_url(self): + if not isinstance(self.repo_url, str): + return + if self.auth_token and self.repo_url.startswith("https://"): + self.repo_url = self.repo_url.replace("https://", f"https://{self.auth_token}@") + + def get_conn(self): + """ + Establish an SSH connection. + + Please use as a context manager to ensure the connection is closed after use. + :return: SSH connection + """ + if self.conn is None: + self.conn = super().get_conn() + return self.conn + class GitDagBundle(BaseDagBundle, LoggingMixin): """ @@ -51,22 +107,21 @@ class GitDagBundle(BaseDagBundle, LoggingMixin): def __init__( self, *, - repo_url: str, tracking_ref: str, subdir: str | None = None, - ssh_conn_kwargs: dict[str, str] | None = None, + git_conn_id: str = "git_default", **kwargs, ) -> None: super().__init__(**kwargs) - self.repo_url = repo_url self.tracking_ref = tracking_ref self.subdir = subdir self.bare_repo_path = self._dag_bundle_root_storage_path / "git" / self.name self.repo_path = ( self._dag_bundle_root_storage_path / "git" / (self.name + f"+{self.version or self.tracking_ref}") ) - self.ssh_conn_kwargs = ssh_conn_kwargs - self.ssh_hook: Any | None = None + self.git_conn_id = git_conn_id + self.hook = GitHook(git_conn_id=self.git_conn_id) + self.repo_url = self.hook.repo_url def _clone_from(self, to_path: Path, bare: bool = False) -> Repo: self.log.info("Cloning %s to %s", self.repo_url, to_path) @@ -87,21 +142,15 @@ def _initialize(self): else: self._refresh() - def _ssh_hook(self): - try: - from airflow.providers.ssh.hooks.ssh import SSHHook - except ImportError as e: - raise AirflowOptionalProviderFeatureException(e) - return SSHHook(**self.ssh_conn_kwargs) - def initialize(self) -> None: - if self.ssh_conn_kwargs: + if not self.repo_url: + raise AirflowException(f"Connection {self.git_conn_id} doesn't have a git_repo_url") + if self.repo_url.startswith("git@"): if not self.repo_url.startswith("git@") and not self.repo_url.endswith(".git"): raise AirflowException( f"Invalid git URL: {self.repo_url}. URL must start with git@ and end with .git" ) - self.ssh_hook = self._ssh_hook() - with self.ssh_hook.get_conn(): + with self.hook.get_conn(): self._initialize() else: self._initialize() @@ -164,8 +213,8 @@ def refresh(self) -> None: if self.version: raise AirflowException("Refreshing a specific version is not supported") - if self.ssh_hook: - with self.ssh_hook.get_conn(): + if self.hook: + with self.hook.get_conn(): self._refresh() else: self._refresh() diff --git a/airflow/dag_processing/bundles/provider.yaml b/airflow/dag_processing/bundles/provider.yaml new file mode 100644 index 0000000000000..fcbb81c97e095 --- /dev/null +++ b/airflow/dag_processing/bundles/provider.yaml @@ -0,0 +1,46 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +--- +package-name: apache-airflow-providers-git +name: GIT +description: | + `GIT`__ + +state: not-ready +source-date-epoch: 1726861127 +# note that those versions are maintained by release manager - do not update them manually +versions: + - 1.0.0 + +dependencies: + - apache-airflow-providers-ssh + - paramiko>=2.9.0 + - asyncssh>=2.12.0 + +integrations: + - integration-name: GIT (Git) + +hooks: + - integration-name: GIT + python-modules: + - airflow.dag_processing.bundles.git + + +connection-types: + - hook-class-name: airflow.dag_processing.bundles.git.GitHook + connection-type: git From f256e5970dce7aa10cf2e5d4e146515af21e5239 Mon Sep 17 00:00:00 2001 From: Ephraim Anierobi Date: Mon, 23 Dec 2024 18:19:28 +0100 Subject: [PATCH 08/19] fixup! Use githook --- airflow/dag_processing/bundles/git.py | 14 +- tests/dag_processing/test_dag_bundles.py | 212 ++++++++++++++++++----- 2 files changed, 179 insertions(+), 47 deletions(-) diff --git a/airflow/dag_processing/bundles/git.py b/airflow/dag_processing/bundles/git.py index 75fed74332d69..2374113bf9a14 100644 --- a/airflow/dag_processing/bundles/git.py +++ b/airflow/dag_processing/bundles/git.py @@ -76,6 +76,8 @@ def _process_git_auth_url(self): return if self.auth_token and self.repo_url.startswith("https://"): self.repo_url = self.repo_url.replace("https://", f"https://{self.auth_token}@") + elif not self.repo_url.startswith("git@") or not self.repo_url.startswith("https://"): + self.repo_url = os.path.expanduser(self.repo_url) def get_conn(self): """ @@ -145,11 +147,13 @@ def _initialize(self): def initialize(self) -> None: if not self.repo_url: raise AirflowException(f"Connection {self.git_conn_id} doesn't have a git_repo_url") - if self.repo_url.startswith("git@"): - if not self.repo_url.startswith("git@") and not self.repo_url.endswith(".git"): - raise AirflowException( - f"Invalid git URL: {self.repo_url}. URL must start with git@ and end with .git" - ) + if isinstance(self.repo_url, os.PathLike): + self._initialize() + elif not self.repo_url.startswith("git@") or not self.repo_url.endswith(".git"): + raise AirflowException( + f"Invalid git URL: {self.repo_url}. URL must start with git@ and end with .git" + ) + elif self.repo_url.startswith("git@"): with self.hook.get_conn(): self._initialize() else: diff --git a/tests/dag_processing/test_dag_bundles.py b/tests/dag_processing/test_dag_bundles.py index b409b53b95aae..bcc320910e77e 100644 --- a/tests/dag_processing/test_dag_bundles.py +++ b/tests/dag_processing/test_dag_bundles.py @@ -26,11 +26,14 @@ from airflow.dag_processing.bundles.base import BaseDagBundle from airflow.dag_processing.bundles.dagfolder import DagsFolderDagBundle -from airflow.dag_processing.bundles.git import GitDagBundle +from airflow.dag_processing.bundles.git import GitDagBundle, GitHook, SSHHook from airflow.dag_processing.bundles.local import LocalDagBundle from airflow.exceptions import AirflowException +from airflow.models import Connection +from airflow.utils import db from tests_common.test_utils.config import conf_vars +from tests_common.test_utils.db import clear_db_connections @pytest.fixture(autouse=True) @@ -107,28 +110,125 @@ def git_repo(tmp_path_factory): return (directory, repo) +AIRFLOW_HTTPS_URL = "https://github.com/apache/airflow.git" +AIRFLOW_GIT = "git@github.com:apache/airflow.git" +ACCESS_TOKEN = "my_access_token" +CONN_DEFAULT = "git_default" +CONN_HTTPS = "my_git_conn" +CONN_HTTPS_PASSWORD = "my_git_conn_https_password" +CONN_ONLY_PATH = "my_git_conn_only_path" +CONN_NO_REPO_URL = "my_git_conn_no_repo_url" + + +class TestGitHook: + @classmethod + def teardown_class(cls) -> None: + clear_db_connections() + + @classmethod + def setup_class(cls) -> None: + db.merge_conn( + Connection( + conn_id=CONN_DEFAULT, + host="github.com", + conn_type="git", + extra={"git_repo_url": AIRFLOW_GIT}, + ) + ) + db.merge_conn( + Connection( + conn_id=CONN_HTTPS, + host="github.com", + conn_type="git", + extra={"git_repo_url": AIRFLOW_HTTPS_URL, "git_access_token": ACCESS_TOKEN}, + ) + ) + db.merge_conn( + Connection( + conn_id=CONN_HTTPS_PASSWORD, + host="github.com", + conn_type="git", + password=ACCESS_TOKEN, + extra={"git_repo_url": AIRFLOW_HTTPS_URL}, + ) + ) + db.merge_conn( + Connection( + conn_id=CONN_ONLY_PATH, + host="github.com", + conn_type="git", + extra={"git_repo_url": "path/to/repo"}, + ) + ) + + @pytest.mark.parametrize( + "conn_id, expected_repo_url", + [ + (CONN_DEFAULT, AIRFLOW_GIT), + (CONN_HTTPS, f"https://{ACCESS_TOKEN}@github.com/apache/airflow.git"), + (CONN_HTTPS_PASSWORD, f"https://{ACCESS_TOKEN}@github.com/apache/airflow.git"), + (CONN_ONLY_PATH, "path/to/repo"), + ], + ) + def test_correct_repo_urls(self, conn_id, expected_repo_url): + hook = GitHook(git_conn_id=conn_id) + assert hook.repo_url == expected_repo_url + + @mock.patch.object(SSHHook, "get_conn") + def test_connection_made_to_ssh_hook(self, mock_ssh_hook_get_conn): + hook = GitHook(git_conn_id=CONN_DEFAULT) + hook.get_conn() + mock_ssh_hook_get_conn.assert_called_once_with() + + class TestGitDagBundle: + @classmethod + def teardown_class(cls) -> None: + clear_db_connections() + + @classmethod + def setup_class(cls) -> None: + db.merge_conn( + Connection( + conn_id="git_default", + host="github.com", + conn_type="git", + extra={"git_repo_url": "git@github.com:apache/airflow.git"}, + ) + ) + db.merge_conn( + Connection( + conn_id=CONN_NO_REPO_URL, + host="github.com", + conn_type="git", + extra="{}", + ) + ) + def test_supports_versioning(self): assert GitDagBundle.supports_versioning is True def test_uses_dag_bundle_root_storage_path(self, git_repo): repo_path, repo = git_repo - bundle = GitDagBundle( - name="test", refresh_interval=300, repo_url=repo_path, tracking_ref=GIT_DEFAULT_BRANCH - ) + bundle = GitDagBundle(name="test", refresh_interval=300, tracking_ref=GIT_DEFAULT_BRANCH) assert str(bundle._dag_bundle_root_storage_path) in str(bundle.path) - def test_get_current_version(self, git_repo): + @mock.patch("airflow.dag_processing.bundles.git.GitHook") + def test_get_current_version(self, mock_githook, git_repo): + mock_githook.get_conn.return_value = mock.MagicMock() repo_path, repo = git_repo - bundle = GitDagBundle( - name="test", refresh_interval=300, repo_url=repo_path, tracking_ref=GIT_DEFAULT_BRANCH - ) + mock_githook.return_value.repo_url = repo_path + bundle = GitDagBundle(name="test", refresh_interval=300, tracking_ref=GIT_DEFAULT_BRANCH) + bundle.initialize() assert bundle.get_current_version() == repo.head.commit.hexsha - def test_get_specific_version(self, git_repo): + @mock.patch("airflow.dag_processing.bundles.git.GitHook") + def test_get_specific_version(self, mock_githook, git_repo): + mock_githook.get_conn.return_value = mock.MagicMock() repo_path, repo = git_repo + mock_githook.return_value.repo_url = repo_path starting_commit = repo.head.commit # Add new file to the repo @@ -142,7 +242,6 @@ def test_get_specific_version(self, git_repo): name="test", refresh_interval=300, version=starting_commit.hexsha, - repo_url=repo_path, tracking_ref=GIT_DEFAULT_BRANCH, ) bundle.initialize() @@ -152,8 +251,11 @@ def test_get_specific_version(self, git_repo): files_in_repo = {f.name for f in bundle.path.iterdir() if f.is_file()} assert {"test_dag.py"} == files_in_repo - def test_get_tag_version(self, git_repo): + @mock.patch("airflow.dag_processing.bundles.git.GitHook") + def test_get_tag_version(self, mock_githook, git_repo): + mock_githook.get_conn.return_value = mock.MagicMock() repo_path, repo = git_repo + mock_githook.return_value.repo_url = repo_path starting_commit = repo.head.commit # add tag @@ -171,7 +273,6 @@ def test_get_tag_version(self, git_repo): name="test", refresh_interval=300, version="test", - repo_url=repo_path, tracking_ref=GIT_DEFAULT_BRANCH, ) bundle.initialize() @@ -180,8 +281,11 @@ def test_get_tag_version(self, git_repo): files_in_repo = {f.name for f in bundle.path.iterdir() if f.is_file()} assert {"test_dag.py"} == files_in_repo - def test_get_latest(self, git_repo): + @mock.patch("airflow.dag_processing.bundles.git.GitHook") + def test_get_latest(self, mock_githook, git_repo): + mock_githook.get_conn.return_value = mock.MagicMock() repo_path, repo = git_repo + mock_githook.return_value.repo_url = repo_path starting_commit = repo.head.commit file_path = repo_path / "new_test.py" @@ -190,9 +294,7 @@ def test_get_latest(self, git_repo): repo.index.add([file_path]) repo.index.commit("Another commit") - bundle = GitDagBundle( - name="test", refresh_interval=300, repo_url=repo_path, tracking_ref=GIT_DEFAULT_BRANCH - ) + bundle = GitDagBundle(name="test", refresh_interval=300, tracking_ref=GIT_DEFAULT_BRANCH) bundle.initialize() assert bundle.get_current_version() != starting_commit.hexsha @@ -200,13 +302,14 @@ def test_get_latest(self, git_repo): files_in_repo = {f.name for f in bundle.path.iterdir() if f.is_file()} assert {"test_dag.py", "new_test.py"} == files_in_repo - def test_refresh(self, git_repo): + @mock.patch("airflow.dag_processing.bundles.git.GitHook") + def test_refresh(self, mock_githook, git_repo): + mock_githook.get_conn.return_value = mock.MagicMock() repo_path, repo = git_repo + mock_githook.return_value.repo_url = repo_path starting_commit = repo.head.commit - bundle = GitDagBundle( - name="test", refresh_interval=300, repo_url=repo_path, tracking_ref=GIT_DEFAULT_BRANCH - ) + bundle = GitDagBundle(name="test", refresh_interval=300, tracking_ref=GIT_DEFAULT_BRANCH) bundle.initialize() assert bundle.get_current_version() == starting_commit.hexsha @@ -227,29 +330,37 @@ def test_refresh(self, git_repo): files_in_repo = {f.name for f in bundle.path.iterdir() if f.is_file()} assert {"test_dag.py", "new_test.py"} == files_in_repo - def test_head(self, git_repo): + @mock.patch("airflow.dag_processing.bundles.git.GitHook") + def test_head(self, mock_githook, git_repo): + mock_githook.get_conn.return_value = mock.MagicMock() repo_path, repo = git_repo + mock_githook.return_value.repo_url = repo_path repo.create_head("test") - bundle = GitDagBundle(name="test", refresh_interval=300, repo_url=repo_path, tracking_ref="test") + bundle = GitDagBundle(name="test", refresh_interval=300, tracking_ref="test") bundle.initialize() assert bundle.repo.head.ref.name == "test" - def test_version_not_found(self, git_repo): + @mock.patch("airflow.dag_processing.bundles.git.GitHook") + def test_version_not_found(self, mock_githook, git_repo): + mock_githook.get_conn.return_value = mock.MagicMock() repo_path, repo = git_repo + mock_githook.return_value.repo_url = repo_path bundle = GitDagBundle( name="test", refresh_interval=300, version="not_found", - repo_url=repo_path, tracking_ref=GIT_DEFAULT_BRANCH, ) with pytest.raises(AirflowException, match="Version not_found not found in the repository"): bundle.initialize() - def test_subdir(self, git_repo): + @mock.patch("airflow.dag_processing.bundles.git.GitHook") + def test_subdir(self, mock_githook, git_repo): + mock_githook.get_conn.return_value = mock.MagicMock() repo_path, repo = git_repo + mock_githook.return_value.repo_url = repo_path subdir = "somesubdir" subdir_path = repo_path / subdir @@ -264,7 +375,6 @@ def test_subdir(self, git_repo): bundle = GitDagBundle( name="test", refresh_interval=300, - repo_url=repo_path, tracking_ref=GIT_DEFAULT_BRANCH, subdir=subdir, ) @@ -274,30 +384,39 @@ def test_subdir(self, git_repo): assert str(bundle.path).endswith(subdir) assert {"some_new_file.py"} == files_in_repo - @mock.patch("airflow.providers.ssh.hooks.ssh.SSHHook") + def test_raises_when_no_repo_url(self): + bundle = GitDagBundle( + name="test", + refresh_interval=300, + git_conn_id=CONN_NO_REPO_URL, + tracking_ref=GIT_DEFAULT_BRANCH, + ) + with pytest.raises( + AirflowException, match=f"Connection {CONN_NO_REPO_URL} doesn't have a git_repo_url" + ): + bundle.initialize() + + @mock.patch("airflow.dag_processing.bundles.git.GitHook") @mock.patch("airflow.dag_processing.bundles.git.Repo") - def test_with_ssh_conn_id(self, mock_gitRepo, mock_hook): - repo_url = "git@github.com:apache/airflow.git" - conn_id = "ssh_default" + @mock.patch.object(GitDagBundle, "_clone_from") + def test_with_path_as_repo_url(self, mock_clone_from, mock_gitRepo, mock_githook): bundle = GitDagBundle( - repo_url=repo_url, name="test", refresh_interval=300, - ssh_conn_kwargs={"ssh_conn_id": "ssh_default"}, + git_conn_id=CONN_ONLY_PATH, tracking_ref=GIT_DEFAULT_BRANCH, ) bundle.initialize() - mock_hook.assert_called_once_with(ssh_conn_id=conn_id) + assert mock_clone_from.call_count == 2 + assert mock_gitRepo.return_value.git.checkout.call_count == 1 - @mock.patch("airflow.providers.ssh.hooks.ssh.SSHHook") + @mock.patch("airflow.dag_processing.bundles.git.GitHook") @mock.patch("airflow.dag_processing.bundles.git.Repo") - def test_refresh_with_ssh_connection(self, mock_gitRepo, mock_hook): - repo_url = "git@github.com:apache/airflow.git" + def test_refresh_with_git_connection(self, mock_gitRepo, mock_hook): bundle = GitDagBundle( - repo_url=repo_url, name="test", refresh_interval=300, - ssh_conn_kwargs={"ssh_conn_id": "ssh_default"}, + git_conn_id="git_default", tracking_ref=GIT_DEFAULT_BRANCH, ) bundle.initialize() @@ -305,13 +424,22 @@ def test_refresh_with_ssh_connection(self, mock_gitRepo, mock_hook): # check remotes called twice. one at initialize and one at refresh above assert mock_gitRepo.return_value.remotes.origin.fetch.call_count == 2 - def test_repo_url_starts_with_git_when_using_ssh_conn_id(self): - repo_url = "https://github.com/apache/airflow" + @pytest.mark.parametrize( + "repo_url", + [ + pytest.param("https://github.com/apache/airflow", id="https_url"), + pytest.param("airflow@example:apache/airflow.git", id="does_not_start_with_git_at"), + pytest.param("git@example:apache/airflow", id="does_not_end_with_dot_git"), + ], + ) + @mock.patch("airflow.dag_processing.bundles.git.GitHook") + def test_repo_url_starts_with_git_when_using_ssh_conn_id(self, mock_hook, repo_url, session): + mock_hook.get_conn.return_value = mock.MagicMock() + mock_hook.return_value.repo_url = repo_url bundle = GitDagBundle( - repo_url=repo_url, name="test", refresh_interval=300, - ssh_conn_kwargs={"ssh_conn_id": "ssh_default"}, + git_conn_id="git_default", tracking_ref=GIT_DEFAULT_BRANCH, ) with pytest.raises( From 5f7f87af17556bb380efe57ea4ad22cd72b540e9 Mon Sep 17 00:00:00 2001 From: Ephraim Anierobi Date: Fri, 27 Dec 2024 09:44:41 +0100 Subject: [PATCH 09/19] Populate the connection form with git type connection --- airflow/dag_processing/bundles/provider.yaml | 5 ++--- airflow/providers_manager.py | 5 +++++ 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/airflow/dag_processing/bundles/provider.yaml b/airflow/dag_processing/bundles/provider.yaml index fcbb81c97e095..32a2dec28a379 100644 --- a/airflow/dag_processing/bundles/provider.yaml +++ b/airflow/dag_processing/bundles/provider.yaml @@ -16,10 +16,10 @@ # under the License. --- -package-name: apache-airflow-providers-git +package-name: apache-airflow-providers-bundles name: GIT description: | - `GIT`__ + `GIT `__ state: not-ready source-date-epoch: 1726861127 @@ -30,7 +30,6 @@ versions: dependencies: - apache-airflow-providers-ssh - paramiko>=2.9.0 - - asyncssh>=2.12.0 integrations: - integration-name: GIT (Git) diff --git a/airflow/providers_manager.py b/airflow/providers_manager.py index 575306a840b79..955e678449fce 100644 --- a/airflow/providers_manager.py +++ b/airflow/providers_manager.py @@ -175,6 +175,9 @@ def _create_customized_form_field_behaviours_schema_validator(): def _check_builtin_provider_prefix(provider_package: str, class_name: str) -> bool: + if "bundles" in provider_package: + # TODO: remove this when this package is moved to providers directory + return True if provider_package.startswith("apache-airflow"): provider_path = provider_package[len("apache-") :].replace("-", ".") if not class_name.startswith(provider_path): @@ -676,6 +679,8 @@ def _discover_all_airflow_builtin_providers_from_local_sources(self) -> None: self._add_provider_info_from_local_source_files_on_path(path) except Exception as e: log.warning("Error when loading 'provider.yaml' files from %s airflow sources: %s", path, e) + # TODO: Remove this when the package is moved to providers + self._add_provider_info_from_local_source_files_on_path("airflow/dag_processing") def _add_provider_info_from_local_source_files_on_path(self, path) -> None: """ From e188b830b0813f62da1e4f2f1b1beef6c3cdaa5f Mon Sep 17 00:00:00 2001 From: Ephraim Anierobi Date: Fri, 27 Dec 2024 11:52:27 +0100 Subject: [PATCH 10/19] Mark test_dag_bundles as db test --- tests/dag_processing/test_dag_bundles.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/dag_processing/test_dag_bundles.py b/tests/dag_processing/test_dag_bundles.py index bcc320910e77e..16a994e89ef61 100644 --- a/tests/dag_processing/test_dag_bundles.py +++ b/tests/dag_processing/test_dag_bundles.py @@ -35,6 +35,8 @@ from tests_common.test_utils.config import conf_vars from tests_common.test_utils.db import clear_db_connections +pytestmark = pytest.mark.db_test + @pytest.fixture(autouse=True) def bundle_temp_dir(tmp_path): From 3e6591d1fdc078da190e01f6da05261b5b2035a4 Mon Sep 17 00:00:00 2001 From: Ephraim Anierobi Date: Mon, 6 Jan 2025 10:12:28 +0100 Subject: [PATCH 11/19] Add names to the extra items --- airflow/dag_processing/bundles/git.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/airflow/dag_processing/bundles/git.py b/airflow/dag_processing/bundles/git.py index 2374113bf9a14..05bdcfa658356 100644 --- a/airflow/dag_processing/bundles/git.py +++ b/airflow/dag_processing/bundles/git.py @@ -17,6 +17,7 @@ from __future__ import annotations +import json import os import tempfile from typing import TYPE_CHECKING, Any @@ -60,6 +61,15 @@ def get_ui_field_behaviour(cls) -> dict[str, Any]: "relabeling": { "login": "Username", }, + "placeholders": { + "extra": json.dumps( + { + "git_repo_url": "git@github.com:orgname/projectname.git", + "git_access_token": "optional_access_token_can_be_deleted", + "key_file": "optional/path/to/keyfile", + } + ) + }, } def __init__(self, git_conn_id="git_default", *args, **kwargs): From 7d10f3954fefe8112ce9de487e302b13937c8a82 Mon Sep 17 00:00:00 2001 From: Ephraim Anierobi Date: Mon, 6 Jan 2025 13:00:59 +0100 Subject: [PATCH 12/19] Update airflow/dag_processing/bundles/git.py Co-authored-by: Felix Uellendall --- airflow/dag_processing/bundles/git.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow/dag_processing/bundles/git.py b/airflow/dag_processing/bundles/git.py index 05bdcfa658356..c0ee897b31fa4 100644 --- a/airflow/dag_processing/bundles/git.py +++ b/airflow/dag_processing/bundles/git.py @@ -111,7 +111,7 @@ class GitDagBundle(BaseDagBundle, LoggingMixin): :param repo_url: URL of the git repository :param tracking_ref: Branch or tag for this DAG bundle :param subdir: Subdirectory within the repository where the DAGs are stored (Optional) - :param ssh_conn_id: Connection ID for SSH connection to the repository (Optional) + :param git_conn_id: Connection ID for SSH connection to the repository (Optional) """ supports_versioning = True From 5347787887bb135cb2e98fe928c1519abf1c75ff Mon Sep 17 00:00:00 2001 From: Ephraim Anierobi Date: Mon, 6 Jan 2025 13:06:48 +0100 Subject: [PATCH 13/19] Fix refresh --- airflow/dag_processing/bundles/git.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/airflow/dag_processing/bundles/git.py b/airflow/dag_processing/bundles/git.py index c0ee897b31fa4..9acdf7ccecc62 100644 --- a/airflow/dag_processing/bundles/git.py +++ b/airflow/dag_processing/bundles/git.py @@ -220,17 +220,13 @@ def _has_version(repo: Repo, version: str) -> bool: return False def _refresh(self): + if self.version: + raise AirflowException("Refreshing a specific version is not supported") self.bare_repo.remotes.origin.fetch("+refs/heads/*:refs/heads/*") self.repo.remotes.origin.pull() def refresh(self) -> None: - if self.version: - raise AirflowException("Refreshing a specific version is not supported") - - if self.hook: - with self.hook.get_conn(): - self._refresh() - else: + with self.hook.get_conn(): self._refresh() def _convert_git_ssh_url_to_https(self) -> str: From 0a5cd0d513da8da610dafa4eb781c5a5c63e7907 Mon Sep 17 00:00:00 2001 From: Ephraim Anierobi Date: Thu, 9 Jan 2025 10:44:16 +0100 Subject: [PATCH 14/19] Apply suggestions from code review Co-authored-by: Jed Cunningham <66968678+jedcunningham@users.noreply.github.com> --- airflow/dag_processing/bundles/git.py | 10 ++++------ airflow/providers_manager.py | 4 ++-- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/airflow/dag_processing/bundles/git.py b/airflow/dag_processing/bundles/git.py index 9acdf7ccecc62..c42e9271fa24b 100644 --- a/airflow/dag_processing/bundles/git.py +++ b/airflow/dag_processing/bundles/git.py @@ -108,10 +108,9 @@ class GitDagBundle(BaseDagBundle, LoggingMixin): Instead of cloning the repository every time, we clone the repository once into a bare repo from the source and then do a clone for each version from there. - :param repo_url: URL of the git repository :param tracking_ref: Branch or tag for this DAG bundle :param subdir: Subdirectory within the repository where the DAGs are stored (Optional) - :param git_conn_id: Connection ID for SSH connection to the repository (Optional) + :param git_conn_id: Connection ID for SSH/token based connection to the repository (Optional) """ supports_versioning = True @@ -157,13 +156,12 @@ def _initialize(self): def initialize(self) -> None: if not self.repo_url: raise AirflowException(f"Connection {self.git_conn_id} doesn't have a git_repo_url") - if isinstance(self.repo_url, os.PathLike): - self._initialize() - elif not self.repo_url.startswith("git@") or not self.repo_url.endswith(".git"): + if self.repo_url.startswith("git@") or not self.repo_url.endswith(".git"): raise AirflowException( f"Invalid git URL: {self.repo_url}. URL must start with git@ and end with .git" ) - elif self.repo_url.startswith("git@"): + + if self.repo_url.startswith("git@"): with self.hook.get_conn(): self._initialize() else: diff --git a/airflow/providers_manager.py b/airflow/providers_manager.py index 955e678449fce..9b39439384f56 100644 --- a/airflow/providers_manager.py +++ b/airflow/providers_manager.py @@ -176,7 +176,7 @@ def _create_customized_form_field_behaviours_schema_validator(): def _check_builtin_provider_prefix(provider_package: str, class_name: str) -> bool: if "bundles" in provider_package: - # TODO: remove this when this package is moved to providers directory + # TODO: AIP-66: remove this when this package is moved to providers directory return True if provider_package.startswith("apache-airflow"): provider_path = provider_package[len("apache-") :].replace("-", ".") @@ -679,7 +679,7 @@ def _discover_all_airflow_builtin_providers_from_local_sources(self) -> None: self._add_provider_info_from_local_source_files_on_path(path) except Exception as e: log.warning("Error when loading 'provider.yaml' files from %s airflow sources: %s", path, e) - # TODO: Remove this when the package is moved to providers + # TODO: AIP-66: Remove this when the package is moved to providers self._add_provider_info_from_local_source_files_on_path("airflow/dag_processing") def _add_provider_info_from_local_source_files_on_path(self, path) -> None: From 1fa75638c01fb0780a499b064c7531ab30396770 Mon Sep 17 00:00:00 2001 From: Ephraim Anierobi Date: Tue, 14 Jan 2025 13:42:34 +0100 Subject: [PATCH 15/19] Remove ssh hook inheritance --- airflow/dag_processing/bundles/git.py | 89 ++++++++++++------------ tests/dag_processing/test_dag_bundles.py | 26 ++----- 2 files changed, 50 insertions(+), 65 deletions(-) diff --git a/airflow/dag_processing/bundles/git.py b/airflow/dag_processing/bundles/git.py index c42e9271fa24b..01ec9e6ba6884 100644 --- a/airflow/dag_processing/bundles/git.py +++ b/airflow/dag_processing/bundles/git.py @@ -19,7 +19,6 @@ import json import os -import tempfile from typing import TYPE_CHECKING, Any from urllib.parse import urlparse @@ -27,21 +26,15 @@ from git.exc import BadName from airflow.dag_processing.bundles.base import BaseDagBundle -from airflow.exceptions import AirflowException, AirflowOptionalProviderFeatureException +from airflow.exceptions import AirflowException +from airflow.hooks.base import BaseHook from airflow.utils.log.logging_mixin import LoggingMixin -try: - from airflow.providers.ssh.hooks.ssh import SSHHook -except ImportError as e: - raise AirflowOptionalProviderFeatureException(e) - if TYPE_CHECKING: from pathlib import Path - import paramiko - -class GitHook(SSHHook): +class GitHook(BaseHook): """ Hook for git repositories. @@ -60,12 +53,12 @@ def get_ui_field_behaviour(cls) -> dict[str, Any]: "hidden_fields": ["schema"], "relabeling": { "login": "Username", + "host": "Repository URL", + "password": "Access Token (optional)", }, "placeholders": { "extra": json.dumps( { - "git_repo_url": "git@github.com:orgname/projectname.git", - "git_access_token": "optional_access_token_can_be_deleted", "key_file": "optional/path/to/keyfile", } ) @@ -73,12 +66,14 @@ def get_ui_field_behaviour(cls) -> dict[str, Any]: } def __init__(self, git_conn_id="git_default", *args, **kwargs): - self.conn: paramiko.SSHClient | None = None - kwargs["ssh_conn_id"] = git_conn_id - super().__init__(*args, **kwargs) + super().__init__() connection = self.get_connection(git_conn_id) - self.repo_url = connection.extra_dejson.get("git_repo_url") - self.auth_token = connection.extra_dejson.get("git_access_token", None) or connection.password + self.repo_url = connection.host + self.auth_token = connection.password + self.key_file = connection.extra_dejson.get("key_file") + self.env: dict[str, str] = {} + if self.key_file: + self.env["GIT_SSH_COMMAND"] = f"ssh -i {self.key_file} -o IdentitiesOnly=yes" self._process_git_auth_url() def _process_git_auth_url(self): @@ -89,17 +84,6 @@ def _process_git_auth_url(self): elif not self.repo_url.startswith("git@") or not self.repo_url.startswith("https://"): self.repo_url = os.path.expanduser(self.repo_url) - def get_conn(self): - """ - Establish an SSH connection. - - Please use as a context manager to ensure the connection is closed after use. - :return: SSH connection - """ - if self.conn is None: - self.conn = super().get_conn() - return self.conn - class GitDagBundle(BaseDagBundle, LoggingMixin): """ @@ -136,34 +120,30 @@ def __init__( def _clone_from(self, to_path: Path, bare: bool = False) -> Repo: self.log.info("Cloning %s to %s", self.repo_url, to_path) - return Repo.clone_from(self.repo_url, to_path, bare=bare) + return Repo.clone_from(self.repo_url, to_path, bare=bare, env=self.hook.env) def _initialize(self): self._clone_bare_repo_if_required() self._ensure_version_in_bare_repo() self._clone_repo_if_required() self.repo.git.checkout(self.tracking_ref) - if self.version: if not self._has_version(self.repo, self.version): - self.repo.remotes.origin.fetch() - + self._fetch_repo() self.repo.head.set_reference(self.repo.commit(self.version)) self.repo.head.reset(index=True, working_tree=True) else: - self._refresh() + self.refresh() def initialize(self) -> None: if not self.repo_url: raise AirflowException(f"Connection {self.git_conn_id} doesn't have a git_repo_url") - if self.repo_url.startswith("git@") or not self.repo_url.endswith(".git"): + if isinstance(self.repo_url, os.PathLike): + self._initialize() + elif not self.repo_url.startswith("git@") or not self.repo_url.endswith(".git"): raise AirflowException( f"Invalid git URL: {self.repo_url}. URL must start with git@ and end with .git" ) - - if self.repo_url.startswith("git@"): - with self.hook.get_conn(): - self._initialize() else: self._initialize() @@ -186,7 +166,7 @@ def _ensure_version_in_bare_repo(self) -> None: if not self.version: return if not self._has_version(self.bare_repo, self.version): - self.bare_repo.remotes.origin.fetch("+refs/heads/*:refs/heads/*") + self._fetch_bare_repo() if not self._has_version(self.bare_repo, self.version): raise AirflowException(f"Version {self.version} not found in the repository") @@ -217,15 +197,32 @@ def _has_version(repo: Repo, version: str) -> bool: except BadName: return False - def _refresh(self): - if self.version: - raise AirflowException("Refreshing a specific version is not supported") - self.bare_repo.remotes.origin.fetch("+refs/heads/*:refs/heads/*") - self.repo.remotes.origin.pull() + def _fetch_repo(self): + if self.hook.env: + with self.repo.git.custom_environment(GIT_SSH_COMMAND=self.hook.env.get("GIT_SSH_COMMAND")): + self.repo.remotes.origin.fetch() + else: + self.repo.remotes.origin.fetch() + + def _fetch_bare_repo(self): + if self.hook.env: + with self.bare_repo.git.custom_environment(GIT_SSH_COMMAND=self.hook.env.get("GIT_SSH_COMMAND")): + self.bare_repo.remotes.origin.fetch("+refs/heads/*:refs/heads/*") + else: + self.bare_repo.remotes.origin.fetch("+refs/heads/*:refs/heads/*") + + def _pull_repo(self): + if self.hook.env: + with self.repo.git.custom_environment(GIT_SSH_COMMAND=self.hook.env.get("GIT_SSH_COMMAND")): + self.repo.remotes.origin.pull() + else: + self.repo.remotes.origin.pull() def refresh(self) -> None: - with self.hook.get_conn(): - self._refresh() + if self.version: + raise AirflowException("Refreshing a specific version is not supported") + self._fetch_bare_repo() + self._pull_repo() def _convert_git_ssh_url_to_https(self) -> str: if not self.repo_url.startswith("git@"): diff --git a/tests/dag_processing/test_dag_bundles.py b/tests/dag_processing/test_dag_bundles.py index 16a994e89ef61..5478daa4e1aeb 100644 --- a/tests/dag_processing/test_dag_bundles.py +++ b/tests/dag_processing/test_dag_bundles.py @@ -26,7 +26,7 @@ from airflow.dag_processing.bundles.base import BaseDagBundle from airflow.dag_processing.bundles.dagfolder import DagsFolderDagBundle -from airflow.dag_processing.bundles.git import GitDagBundle, GitHook, SSHHook +from airflow.dag_processing.bundles.git import GitDagBundle, GitHook from airflow.dag_processing.bundles.local import LocalDagBundle from airflow.exceptions import AirflowException from airflow.models import Connection @@ -132,34 +132,31 @@ def setup_class(cls) -> None: db.merge_conn( Connection( conn_id=CONN_DEFAULT, - host="github.com", + host=AIRFLOW_GIT, conn_type="git", - extra={"git_repo_url": AIRFLOW_GIT}, ) ) db.merge_conn( Connection( conn_id=CONN_HTTPS, - host="github.com", + host=AIRFLOW_HTTPS_URL, + password=ACCESS_TOKEN, conn_type="git", - extra={"git_repo_url": AIRFLOW_HTTPS_URL, "git_access_token": ACCESS_TOKEN}, ) ) db.merge_conn( Connection( conn_id=CONN_HTTPS_PASSWORD, - host="github.com", + host=AIRFLOW_HTTPS_URL, conn_type="git", password=ACCESS_TOKEN, - extra={"git_repo_url": AIRFLOW_HTTPS_URL}, ) ) db.merge_conn( Connection( conn_id=CONN_ONLY_PATH, - host="github.com", + host="path/to/repo", conn_type="git", - extra={"git_repo_url": "path/to/repo"}, ) ) @@ -176,12 +173,6 @@ def test_correct_repo_urls(self, conn_id, expected_repo_url): hook = GitHook(git_conn_id=conn_id) assert hook.repo_url == expected_repo_url - @mock.patch.object(SSHHook, "get_conn") - def test_connection_made_to_ssh_hook(self, mock_ssh_hook_get_conn): - hook = GitHook(git_conn_id=CONN_DEFAULT) - hook.get_conn() - mock_ssh_hook_get_conn.assert_called_once_with() - class TestGitDagBundle: @classmethod @@ -193,17 +184,14 @@ def setup_class(cls) -> None: db.merge_conn( Connection( conn_id="git_default", - host="github.com", + host="git@github.com:apache/airflow.git", conn_type="git", - extra={"git_repo_url": "git@github.com:apache/airflow.git"}, ) ) db.merge_conn( Connection( conn_id=CONN_NO_REPO_URL, - host="github.com", conn_type="git", - extra="{}", ) ) From 455e3b776b242dae5a68b132b73b1e77d6762319 Mon Sep 17 00:00:00 2001 From: Ephraim Anierobi Date: Tue, 14 Jan 2025 14:36:16 +0100 Subject: [PATCH 16/19] fixup! Remove ssh hook inheritance --- tests/dag_processing/test_dag_bundles.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/tests/dag_processing/test_dag_bundles.py b/tests/dag_processing/test_dag_bundles.py index 5478daa4e1aeb..b91f0455ee1be 100644 --- a/tests/dag_processing/test_dag_bundles.py +++ b/tests/dag_processing/test_dag_bundles.py @@ -450,11 +450,18 @@ def test_repo_url_starts_with_git_when_using_ssh_conn_id(self, mock_hook, repo_u ], ) @mock.patch("airflow.dag_processing.bundles.git.Repo") - def test_view_url(self, mock_gitrepo, repo_url, expected_url): + def test_view_url(self, mock_gitrepo, repo_url, expected_url, session): + session.query(Connection).delete() + conn = Connection( + conn_id="git_default", + host=repo_url, + conn_type="git", + ) + session.add(conn) + session.commit() bundle = GitDagBundle( name="test", refresh_interval=300, - repo_url=repo_url, tracking_ref="main", ) view_url = bundle.view_url("0f0f0f") @@ -465,7 +472,6 @@ def test_view_url_returns_none_when_no_version_in_view_url(self, mock_gitrepo): bundle = GitDagBundle( name="test", refresh_interval=300, - repo_url="git@github.com:apache/airflow.git", tracking_ref="main", ) view_url = bundle.view_url(None) From de27d211deec7ba35e26d33b4330d671690b39aa Mon Sep 17 00:00:00 2001 From: Ephraim Anierobi Date: Wed, 15 Jan 2025 13:10:01 +0100 Subject: [PATCH 17/19] Apply suggestions from code review Co-authored-by: Jed Cunningham <66968678+jedcunningham@users.noreply.github.com> --- tests/dag_processing/test_dag_bundles.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/dag_processing/test_dag_bundles.py b/tests/dag_processing/test_dag_bundles.py index b91f0455ee1be..6032136610b27 100644 --- a/tests/dag_processing/test_dag_bundles.py +++ b/tests/dag_processing/test_dag_bundles.py @@ -205,7 +205,6 @@ def test_uses_dag_bundle_root_storage_path(self, git_repo): @mock.patch("airflow.dag_processing.bundles.git.GitHook") def test_get_current_version(self, mock_githook, git_repo): - mock_githook.get_conn.return_value = mock.MagicMock() repo_path, repo = git_repo mock_githook.return_value.repo_url = repo_path bundle = GitDagBundle(name="test", refresh_interval=300, tracking_ref=GIT_DEFAULT_BRANCH) @@ -423,7 +422,7 @@ def test_refresh_with_git_connection(self, mock_gitRepo, mock_hook): ], ) @mock.patch("airflow.dag_processing.bundles.git.GitHook") - def test_repo_url_starts_with_git_when_using_ssh_conn_id(self, mock_hook, repo_url, session): + def test_repo_url_validation_for_ssh(self, mock_hook, repo_url, session): mock_hook.get_conn.return_value = mock.MagicMock() mock_hook.return_value.repo_url = repo_url bundle = GitDagBundle( From 969dbbd4a7d0c2a96f305157bc7206e978972f6d Mon Sep 17 00:00:00 2001 From: Ephraim Anierobi Date: Wed, 15 Jan 2025 16:29:43 +0100 Subject: [PATCH 18/19] Fix code and link to dag processor --- airflow/dag_processing/bundles/base.py | 9 +++++++ airflow/dag_processing/bundles/git.py | 33 ++++++++--------------- airflow/dag_processing/bundles/manager.py | 1 + airflow/dag_processing/manager.py | 4 +++ tests/dag_processing/test_dag_bundles.py | 16 +++-------- 5 files changed, 28 insertions(+), 35 deletions(-) diff --git a/airflow/dag_processing/bundles/base.py b/airflow/dag_processing/bundles/base.py index ea560f1be26b0..4a575f0f83c51 100644 --- a/airflow/dag_processing/bundles/base.py +++ b/airflow/dag_processing/bundles/base.py @@ -50,6 +50,15 @@ def __init__(self, *, name: str, refresh_interval: int, version: str | None = No self.name = name self.version = version self.refresh_interval = refresh_interval + self.is_initialized: bool = False + + def initialize(self) -> None: + """ + Initialize the bundle. + + This method is called by the DAG processor before the bundle is used. + """ + self.is_initialized = True @property def _dag_bundle_root_storage_path(self) -> Path: diff --git a/airflow/dag_processing/bundles/git.py b/airflow/dag_processing/bundles/git.py index 01ec9e6ba6884..4b2a19de36484 100644 --- a/airflow/dag_processing/bundles/git.py +++ b/airflow/dag_processing/bundles/git.py @@ -118,10 +118,6 @@ def __init__( self.hook = GitHook(git_conn_id=self.git_conn_id) self.repo_url = self.hook.repo_url - def _clone_from(self, to_path: Path, bare: bool = False) -> Repo: - self.log.info("Cloning %s to %s", self.repo_url, to_path) - return Repo.clone_from(self.repo_url, to_path, bare=bare, env=self.hook.env) - def _initialize(self): self._clone_bare_repo_if_required() self._ensure_version_in_bare_repo() @@ -129,7 +125,7 @@ def _initialize(self): self.repo.git.checkout(self.tracking_ref) if self.version: if not self._has_version(self.repo, self.version): - self._fetch_repo() + self.repo.remotes.origin.fetch() self.repo.head.set_reference(self.repo.commit(self.version)) self.repo.head.reset(index=True, working_tree=True) else: @@ -146,19 +142,26 @@ def initialize(self) -> None: ) else: self._initialize() + super().initialize() def _clone_repo_if_required(self) -> None: if not os.path.exists(self.repo_path): - self._clone_from( + self.log.info("Cloning repository to %s from %s", self.repo_path, self.bare_repo_path) + Repo.clone_from( + url=self.bare_repo_path, to_path=self.repo_path, ) + self.repo = Repo(self.repo_path) def _clone_bare_repo_if_required(self) -> None: if not os.path.exists(self.bare_repo_path): - self._clone_from( + self.log.info("Cloning bare repository to %s", self.bare_repo_path) + Repo.clone_from( + url=self.repo_url, to_path=self.bare_repo_path, bare=True, + env=self.hook.env, ) self.bare_repo = Repo(self.bare_repo_path) @@ -197,13 +200,6 @@ def _has_version(repo: Repo, version: str) -> bool: except BadName: return False - def _fetch_repo(self): - if self.hook.env: - with self.repo.git.custom_environment(GIT_SSH_COMMAND=self.hook.env.get("GIT_SSH_COMMAND")): - self.repo.remotes.origin.fetch() - else: - self.repo.remotes.origin.fetch() - def _fetch_bare_repo(self): if self.hook.env: with self.bare_repo.git.custom_environment(GIT_SSH_COMMAND=self.hook.env.get("GIT_SSH_COMMAND")): @@ -211,18 +207,11 @@ def _fetch_bare_repo(self): else: self.bare_repo.remotes.origin.fetch("+refs/heads/*:refs/heads/*") - def _pull_repo(self): - if self.hook.env: - with self.repo.git.custom_environment(GIT_SSH_COMMAND=self.hook.env.get("GIT_SSH_COMMAND")): - self.repo.remotes.origin.pull() - else: - self.repo.remotes.origin.pull() - def refresh(self) -> None: if self.version: raise AirflowException("Refreshing a specific version is not supported") self._fetch_bare_repo() - self._pull_repo() + self.repo.remotes.origin.pull() def _convert_git_ssh_url_to_https(self) -> str: if not self.repo_url.startswith("git@"): diff --git a/airflow/dag_processing/bundles/manager.py b/airflow/dag_processing/bundles/manager.py index 1ae751f8d3304..ad1ebc5889159 100644 --- a/airflow/dag_processing/bundles/manager.py +++ b/airflow/dag_processing/bundles/manager.py @@ -96,6 +96,7 @@ def parse_config(self) -> None: class_ = import_string(cfg["classpath"]) kwargs = cfg["kwargs"] self._bundle_config[name] = (class_, kwargs) + self.log.info("DAG bundles loaded: %s", ", ".join(self._bundle_config.keys())) @provide_session def sync_bundles_to_db(self, *, session: Session = NEW_SESSION) -> None: diff --git a/airflow/dag_processing/manager.py b/airflow/dag_processing/manager.py index 0b831d238d1c3..9adbac5a336a9 100644 --- a/airflow/dag_processing/manager.py +++ b/airflow/dag_processing/manager.py @@ -653,6 +653,10 @@ def _refresh_dag_bundles(self): self.log.info("Refreshing DAG bundles") for bundle in self._dag_bundles: + # TODO: AIP-66 handle errors in the case of incomplete cloning? And test this. + # What if the cloning/refreshing took too long(longer than the dag processor timeout) + if not bundle.is_initialized: + bundle.initialize() # TODO: AIP-66 test to make sure we get a fresh record from the db and it's not cached with create_session() as session: bundle_model = session.get(DagBundleModel, bundle.name) diff --git a/tests/dag_processing/test_dag_bundles.py b/tests/dag_processing/test_dag_bundles.py index 6032136610b27..49b7da1a03a92 100644 --- a/tests/dag_processing/test_dag_bundles.py +++ b/tests/dag_processing/test_dag_bundles.py @@ -215,7 +215,6 @@ def test_get_current_version(self, mock_githook, git_repo): @mock.patch("airflow.dag_processing.bundles.git.GitHook") def test_get_specific_version(self, mock_githook, git_repo): - mock_githook.get_conn.return_value = mock.MagicMock() repo_path, repo = git_repo mock_githook.return_value.repo_url = repo_path starting_commit = repo.head.commit @@ -242,7 +241,6 @@ def test_get_specific_version(self, mock_githook, git_repo): @mock.patch("airflow.dag_processing.bundles.git.GitHook") def test_get_tag_version(self, mock_githook, git_repo): - mock_githook.get_conn.return_value = mock.MagicMock() repo_path, repo = git_repo mock_githook.return_value.repo_url = repo_path starting_commit = repo.head.commit @@ -272,7 +270,6 @@ def test_get_tag_version(self, mock_githook, git_repo): @mock.patch("airflow.dag_processing.bundles.git.GitHook") def test_get_latest(self, mock_githook, git_repo): - mock_githook.get_conn.return_value = mock.MagicMock() repo_path, repo = git_repo mock_githook.return_value.repo_url = repo_path starting_commit = repo.head.commit @@ -293,7 +290,6 @@ def test_get_latest(self, mock_githook, git_repo): @mock.patch("airflow.dag_processing.bundles.git.GitHook") def test_refresh(self, mock_githook, git_repo): - mock_githook.get_conn.return_value = mock.MagicMock() repo_path, repo = git_repo mock_githook.return_value.repo_url = repo_path starting_commit = repo.head.commit @@ -321,7 +317,6 @@ def test_refresh(self, mock_githook, git_repo): @mock.patch("airflow.dag_processing.bundles.git.GitHook") def test_head(self, mock_githook, git_repo): - mock_githook.get_conn.return_value = mock.MagicMock() repo_path, repo = git_repo mock_githook.return_value.repo_url = repo_path @@ -332,7 +327,6 @@ def test_head(self, mock_githook, git_repo): @mock.patch("airflow.dag_processing.bundles.git.GitHook") def test_version_not_found(self, mock_githook, git_repo): - mock_githook.get_conn.return_value = mock.MagicMock() repo_path, repo = git_repo mock_githook.return_value.repo_url = repo_path bundle = GitDagBundle( @@ -347,7 +341,6 @@ def test_version_not_found(self, mock_githook, git_repo): @mock.patch("airflow.dag_processing.bundles.git.GitHook") def test_subdir(self, mock_githook, git_repo): - mock_githook.get_conn.return_value = mock.MagicMock() repo_path, repo = git_repo mock_githook.return_value.repo_url = repo_path @@ -387,8 +380,7 @@ def test_raises_when_no_repo_url(self): @mock.patch("airflow.dag_processing.bundles.git.GitHook") @mock.patch("airflow.dag_processing.bundles.git.Repo") - @mock.patch.object(GitDagBundle, "_clone_from") - def test_with_path_as_repo_url(self, mock_clone_from, mock_gitRepo, mock_githook): + def test_with_path_as_repo_url(self, mock_gitRepo, mock_githook): bundle = GitDagBundle( name="test", refresh_interval=300, @@ -396,12 +388,11 @@ def test_with_path_as_repo_url(self, mock_clone_from, mock_gitRepo, mock_githook tracking_ref=GIT_DEFAULT_BRANCH, ) bundle.initialize() - assert mock_clone_from.call_count == 2 + assert mock_gitRepo.clone_from.call_count == 2 assert mock_gitRepo.return_value.git.checkout.call_count == 1 - @mock.patch("airflow.dag_processing.bundles.git.GitHook") @mock.patch("airflow.dag_processing.bundles.git.Repo") - def test_refresh_with_git_connection(self, mock_gitRepo, mock_hook): + def test_refresh_with_git_connection(self, mock_gitRepo): bundle = GitDagBundle( name="test", refresh_interval=300, @@ -423,7 +414,6 @@ def test_refresh_with_git_connection(self, mock_gitRepo, mock_hook): ) @mock.patch("airflow.dag_processing.bundles.git.GitHook") def test_repo_url_validation_for_ssh(self, mock_hook, repo_url, session): - mock_hook.get_conn.return_value = mock.MagicMock() mock_hook.return_value.repo_url = repo_url bundle = GitDagBundle( name="test", From 5b713830c1985ba89e56bf9914d051c28ee0277c Mon Sep 17 00:00:00 2001 From: Ephraim Anierobi Date: Thu, 16 Jan 2025 09:56:49 +0100 Subject: [PATCH 19/19] Apply suggestions from code review Co-authored-by: Jed Cunningham <66968678+jedcunningham@users.noreply.github.com> --- airflow/dag_processing/bundles/base.py | 3 ++- airflow/dag_processing/bundles/provider.yaml | 1 - 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/airflow/dag_processing/bundles/base.py b/airflow/dag_processing/bundles/base.py index 4a575f0f83c51..cf0467b372a4e 100644 --- a/airflow/dag_processing/bundles/base.py +++ b/airflow/dag_processing/bundles/base.py @@ -56,7 +56,8 @@ def initialize(self) -> None: """ Initialize the bundle. - This method is called by the DAG processor before the bundle is used. + This method is called by the DAG processor before the bundle is used, + and allows for deferring expensive operations until that point in time. """ self.is_initialized = True diff --git a/airflow/dag_processing/bundles/provider.yaml b/airflow/dag_processing/bundles/provider.yaml index 32a2dec28a379..9ca5d1479f28c 100644 --- a/airflow/dag_processing/bundles/provider.yaml +++ b/airflow/dag_processing/bundles/provider.yaml @@ -29,7 +29,6 @@ versions: dependencies: - apache-airflow-providers-ssh - - paramiko>=2.9.0 integrations: - integration-name: GIT (Git)