diff --git a/airflow/dag_processing/bundles/base.py b/airflow/dag_processing/bundles/base.py index ea560f1be26b0..cf0467b372a4e 100644 --- a/airflow/dag_processing/bundles/base.py +++ b/airflow/dag_processing/bundles/base.py @@ -50,6 +50,16 @@ def __init__(self, *, name: str, refresh_interval: int, version: str | None = No self.name = name self.version = version self.refresh_interval = refresh_interval + self.is_initialized: bool = False + + def initialize(self) -> None: + """ + Initialize the bundle. + + This method is called by the DAG processor before the bundle is used, + and allows for deferring expensive operations until that point in time. + """ + self.is_initialized = True @property def _dag_bundle_root_storage_path(self) -> Path: diff --git a/airflow/dag_processing/bundles/git.py b/airflow/dag_processing/bundles/git.py index d731f65db3b88..4b2a19de36484 100644 --- a/airflow/dag_processing/bundles/git.py +++ b/airflow/dag_processing/bundles/git.py @@ -17,8 +17,9 @@ from __future__ import annotations +import json import os -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from urllib.parse import urlparse from git import Repo @@ -26,63 +27,141 @@ from airflow.dag_processing.bundles.base import BaseDagBundle from airflow.exceptions import AirflowException +from airflow.hooks.base import BaseHook +from airflow.utils.log.logging_mixin import LoggingMixin if TYPE_CHECKING: from pathlib import Path -class GitDagBundle(BaseDagBundle): +class GitHook(BaseHook): + """ + Hook for git repositories. + + :param git_conn_id: Connection ID for SSH connection to the repository + + """ + + conn_name_attr = "git_conn_id" + default_conn_name = "git_default" + conn_type = "git" + hook_name = "GIT" + + @classmethod + def get_ui_field_behaviour(cls) -> dict[str, Any]: + return { + "hidden_fields": ["schema"], + "relabeling": { + "login": "Username", + "host": "Repository URL", + "password": "Access Token (optional)", + }, + "placeholders": { + "extra": json.dumps( + { + "key_file": "optional/path/to/keyfile", + } + ) + }, + } + + def __init__(self, git_conn_id="git_default", *args, **kwargs): + super().__init__() + connection = self.get_connection(git_conn_id) + self.repo_url = connection.host + self.auth_token = connection.password + self.key_file = connection.extra_dejson.get("key_file") + self.env: dict[str, str] = {} + if self.key_file: + self.env["GIT_SSH_COMMAND"] = f"ssh -i {self.key_file} -o IdentitiesOnly=yes" + self._process_git_auth_url() + + def _process_git_auth_url(self): + if not isinstance(self.repo_url, str): + return + if self.auth_token and self.repo_url.startswith("https://"): + self.repo_url = self.repo_url.replace("https://", f"https://{self.auth_token}@") + elif not self.repo_url.startswith("git@") or not self.repo_url.startswith("https://"): + self.repo_url = os.path.expanduser(self.repo_url) + + +class GitDagBundle(BaseDagBundle, LoggingMixin): """ git DAG bundle - exposes a git repository as a DAG bundle. Instead of cloning the repository every time, we clone the repository once into a bare repo from the source and then do a clone for each version from there. - :param repo_url: URL of the git repository :param tracking_ref: Branch or tag for this DAG bundle :param subdir: Subdirectory within the repository where the DAGs are stored (Optional) + :param git_conn_id: Connection ID for SSH/token based connection to the repository (Optional) """ supports_versioning = True - def __init__(self, *, repo_url: str, tracking_ref: str, subdir: str | None = None, **kwargs) -> None: + def __init__( + self, + *, + tracking_ref: str, + subdir: str | None = None, + git_conn_id: str = "git_default", + **kwargs, + ) -> None: super().__init__(**kwargs) - self.repo_url = repo_url self.tracking_ref = tracking_ref self.subdir = subdir - self.bare_repo_path = self._dag_bundle_root_storage_path / "git" / self.name self.repo_path = ( self._dag_bundle_root_storage_path / "git" / (self.name + f"+{self.version or self.tracking_ref}") ) + self.git_conn_id = git_conn_id + self.hook = GitHook(git_conn_id=self.git_conn_id) + self.repo_url = self.hook.repo_url + + def _initialize(self): self._clone_bare_repo_if_required() self._ensure_version_in_bare_repo() self._clone_repo_if_required() self.repo.git.checkout(self.tracking_ref) - if self.version: if not self._has_version(self.repo, self.version): self.repo.remotes.origin.fetch() - self.repo.head.set_reference(self.repo.commit(self.version)) self.repo.head.reset(index=True, working_tree=True) else: self.refresh() + def initialize(self) -> None: + if not self.repo_url: + raise AirflowException(f"Connection {self.git_conn_id} doesn't have a git_repo_url") + if isinstance(self.repo_url, os.PathLike): + self._initialize() + elif not self.repo_url.startswith("git@") or not self.repo_url.endswith(".git"): + raise AirflowException( + f"Invalid git URL: {self.repo_url}. URL must start with git@ and end with .git" + ) + else: + self._initialize() + super().initialize() + def _clone_repo_if_required(self) -> None: if not os.path.exists(self.repo_path): + self.log.info("Cloning repository to %s from %s", self.repo_path, self.bare_repo_path) Repo.clone_from( url=self.bare_repo_path, to_path=self.repo_path, ) + self.repo = Repo(self.repo_path) def _clone_bare_repo_if_required(self) -> None: if not os.path.exists(self.bare_repo_path): + self.log.info("Cloning bare repository to %s", self.bare_repo_path) Repo.clone_from( url=self.repo_url, to_path=self.bare_repo_path, bare=True, + env=self.hook.env, ) self.bare_repo = Repo(self.bare_repo_path) @@ -90,7 +169,7 @@ def _ensure_version_in_bare_repo(self) -> None: if not self.version: return if not self._has_version(self.bare_repo, self.version): - self.bare_repo.remotes.origin.fetch("+refs/heads/*:refs/heads/*") + self._fetch_bare_repo() if not self._has_version(self.bare_repo, self.version): raise AirflowException(f"Version {self.version} not found in the repository") @@ -121,11 +200,17 @@ def _has_version(repo: Repo, version: str) -> bool: except BadName: return False + def _fetch_bare_repo(self): + if self.hook.env: + with self.bare_repo.git.custom_environment(GIT_SSH_COMMAND=self.hook.env.get("GIT_SSH_COMMAND")): + self.bare_repo.remotes.origin.fetch("+refs/heads/*:refs/heads/*") + else: + self.bare_repo.remotes.origin.fetch("+refs/heads/*:refs/heads/*") + def refresh(self) -> None: if self.version: raise AirflowException("Refreshing a specific version is not supported") - - self.bare_repo.remotes.origin.fetch("+refs/heads/*:refs/heads/*") + self._fetch_bare_repo() self.repo.remotes.origin.pull() def _convert_git_ssh_url_to_https(self) -> str: diff --git a/airflow/dag_processing/bundles/manager.py b/airflow/dag_processing/bundles/manager.py index 1ae751f8d3304..ad1ebc5889159 100644 --- a/airflow/dag_processing/bundles/manager.py +++ b/airflow/dag_processing/bundles/manager.py @@ -96,6 +96,7 @@ def parse_config(self) -> None: class_ = import_string(cfg["classpath"]) kwargs = cfg["kwargs"] self._bundle_config[name] = (class_, kwargs) + self.log.info("DAG bundles loaded: %s", ", ".join(self._bundle_config.keys())) @provide_session def sync_bundles_to_db(self, *, session: Session = NEW_SESSION) -> None: diff --git a/airflow/dag_processing/bundles/provider.yaml b/airflow/dag_processing/bundles/provider.yaml new file mode 100644 index 0000000000000..9ca5d1479f28c --- /dev/null +++ b/airflow/dag_processing/bundles/provider.yaml @@ -0,0 +1,44 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +--- +package-name: apache-airflow-providers-bundles +name: GIT +description: | + `GIT `__ + +state: not-ready +source-date-epoch: 1726861127 +# note that those versions are maintained by release manager - do not update them manually +versions: + - 1.0.0 + +dependencies: + - apache-airflow-providers-ssh + +integrations: + - integration-name: GIT (Git) + +hooks: + - integration-name: GIT + python-modules: + - airflow.dag_processing.bundles.git + + +connection-types: + - hook-class-name: airflow.dag_processing.bundles.git.GitHook + connection-type: git diff --git a/airflow/dag_processing/manager.py b/airflow/dag_processing/manager.py index 96c7fe4f0ed5c..220b55edce69c 100644 --- a/airflow/dag_processing/manager.py +++ b/airflow/dag_processing/manager.py @@ -653,6 +653,10 @@ def _refresh_dag_bundles(self): self.log.info("Refreshing DAG bundles") for bundle in self._dag_bundles: + # TODO: AIP-66 handle errors in the case of incomplete cloning? And test this. + # What if the cloning/refreshing took too long(longer than the dag processor timeout) + if not bundle.is_initialized: + bundle.initialize() # TODO: AIP-66 test to make sure we get a fresh record from the db and it's not cached with create_session() as session: bundle_model = session.get(DagBundleModel, bundle.name) diff --git a/airflow/providers_manager.py b/airflow/providers_manager.py index 575306a840b79..9b39439384f56 100644 --- a/airflow/providers_manager.py +++ b/airflow/providers_manager.py @@ -175,6 +175,9 @@ def _create_customized_form_field_behaviours_schema_validator(): def _check_builtin_provider_prefix(provider_package: str, class_name: str) -> bool: + if "bundles" in provider_package: + # TODO: AIP-66: remove this when this package is moved to providers directory + return True if provider_package.startswith("apache-airflow"): provider_path = provider_package[len("apache-") :].replace("-", ".") if not class_name.startswith(provider_path): @@ -676,6 +679,8 @@ def _discover_all_airflow_builtin_providers_from_local_sources(self) -> None: self._add_provider_info_from_local_source_files_on_path(path) except Exception as e: log.warning("Error when loading 'provider.yaml' files from %s airflow sources: %s", path, e) + # TODO: AIP-66: Remove this when the package is moved to providers + self._add_provider_info_from_local_source_files_on_path("airflow/dag_processing") def _add_provider_info_from_local_source_files_on_path(self, path) -> None: """ diff --git a/tests/dag_processing/test_dag_bundles.py b/tests/dag_processing/test_dag_bundles.py index d450a56131361..49b7da1a03a92 100644 --- a/tests/dag_processing/test_dag_bundles.py +++ b/tests/dag_processing/test_dag_bundles.py @@ -26,11 +26,16 @@ from airflow.dag_processing.bundles.base import BaseDagBundle from airflow.dag_processing.bundles.dagfolder import DagsFolderDagBundle -from airflow.dag_processing.bundles.git import GitDagBundle +from airflow.dag_processing.bundles.git import GitDagBundle, GitHook from airflow.dag_processing.bundles.local import LocalDagBundle from airflow.exceptions import AirflowException +from airflow.models import Connection +from airflow.utils import db from tests_common.test_utils.config import conf_vars +from tests_common.test_utils.db import clear_db_connections + +pytestmark = pytest.mark.db_test @pytest.fixture(autouse=True) @@ -107,27 +112,111 @@ def git_repo(tmp_path_factory): return (directory, repo) +AIRFLOW_HTTPS_URL = "https://github.com/apache/airflow.git" +AIRFLOW_GIT = "git@github.com:apache/airflow.git" +ACCESS_TOKEN = "my_access_token" +CONN_DEFAULT = "git_default" +CONN_HTTPS = "my_git_conn" +CONN_HTTPS_PASSWORD = "my_git_conn_https_password" +CONN_ONLY_PATH = "my_git_conn_only_path" +CONN_NO_REPO_URL = "my_git_conn_no_repo_url" + + +class TestGitHook: + @classmethod + def teardown_class(cls) -> None: + clear_db_connections() + + @classmethod + def setup_class(cls) -> None: + db.merge_conn( + Connection( + conn_id=CONN_DEFAULT, + host=AIRFLOW_GIT, + conn_type="git", + ) + ) + db.merge_conn( + Connection( + conn_id=CONN_HTTPS, + host=AIRFLOW_HTTPS_URL, + password=ACCESS_TOKEN, + conn_type="git", + ) + ) + db.merge_conn( + Connection( + conn_id=CONN_HTTPS_PASSWORD, + host=AIRFLOW_HTTPS_URL, + conn_type="git", + password=ACCESS_TOKEN, + ) + ) + db.merge_conn( + Connection( + conn_id=CONN_ONLY_PATH, + host="path/to/repo", + conn_type="git", + ) + ) + + @pytest.mark.parametrize( + "conn_id, expected_repo_url", + [ + (CONN_DEFAULT, AIRFLOW_GIT), + (CONN_HTTPS, f"https://{ACCESS_TOKEN}@github.com/apache/airflow.git"), + (CONN_HTTPS_PASSWORD, f"https://{ACCESS_TOKEN}@github.com/apache/airflow.git"), + (CONN_ONLY_PATH, "path/to/repo"), + ], + ) + def test_correct_repo_urls(self, conn_id, expected_repo_url): + hook = GitHook(git_conn_id=conn_id) + assert hook.repo_url == expected_repo_url + + class TestGitDagBundle: + @classmethod + def teardown_class(cls) -> None: + clear_db_connections() + + @classmethod + def setup_class(cls) -> None: + db.merge_conn( + Connection( + conn_id="git_default", + host="git@github.com:apache/airflow.git", + conn_type="git", + ) + ) + db.merge_conn( + Connection( + conn_id=CONN_NO_REPO_URL, + conn_type="git", + ) + ) + def test_supports_versioning(self): assert GitDagBundle.supports_versioning is True def test_uses_dag_bundle_root_storage_path(self, git_repo): repo_path, repo = git_repo - bundle = GitDagBundle( - name="test", refresh_interval=300, repo_url=repo_path, tracking_ref=GIT_DEFAULT_BRANCH - ) + bundle = GitDagBundle(name="test", refresh_interval=300, tracking_ref=GIT_DEFAULT_BRANCH) assert str(bundle._dag_bundle_root_storage_path) in str(bundle.path) - def test_get_current_version(self, git_repo): + @mock.patch("airflow.dag_processing.bundles.git.GitHook") + def test_get_current_version(self, mock_githook, git_repo): repo_path, repo = git_repo - bundle = GitDagBundle( - name="test", refresh_interval=300, repo_url=repo_path, tracking_ref=GIT_DEFAULT_BRANCH - ) + mock_githook.return_value.repo_url = repo_path + bundle = GitDagBundle(name="test", refresh_interval=300, tracking_ref=GIT_DEFAULT_BRANCH) + + bundle.initialize() assert bundle.get_current_version() == repo.head.commit.hexsha - def test_get_specific_version(self, git_repo): + @mock.patch("airflow.dag_processing.bundles.git.GitHook") + def test_get_specific_version(self, mock_githook, git_repo): repo_path, repo = git_repo + mock_githook.return_value.repo_url = repo_path starting_commit = repo.head.commit # Add new file to the repo @@ -141,17 +230,19 @@ def test_get_specific_version(self, git_repo): name="test", refresh_interval=300, version=starting_commit.hexsha, - repo_url=repo_path, tracking_ref=GIT_DEFAULT_BRANCH, ) + bundle.initialize() assert bundle.get_current_version() == starting_commit.hexsha files_in_repo = {f.name for f in bundle.path.iterdir() if f.is_file()} assert {"test_dag.py"} == files_in_repo - def test_get_tag_version(self, git_repo): + @mock.patch("airflow.dag_processing.bundles.git.GitHook") + def test_get_tag_version(self, mock_githook, git_repo): repo_path, repo = git_repo + mock_githook.return_value.repo_url = repo_path starting_commit = repo.head.commit # add tag @@ -169,17 +260,18 @@ def test_get_tag_version(self, git_repo): name="test", refresh_interval=300, version="test", - repo_url=repo_path, tracking_ref=GIT_DEFAULT_BRANCH, ) - + bundle.initialize() assert bundle.get_current_version() == starting_commit.hexsha files_in_repo = {f.name for f in bundle.path.iterdir() if f.is_file()} assert {"test_dag.py"} == files_in_repo - def test_get_latest(self, git_repo): + @mock.patch("airflow.dag_processing.bundles.git.GitHook") + def test_get_latest(self, mock_githook, git_repo): repo_path, repo = git_repo + mock_githook.return_value.repo_url = repo_path starting_commit = repo.head.commit file_path = repo_path / "new_test.py" @@ -188,22 +280,22 @@ def test_get_latest(self, git_repo): repo.index.add([file_path]) repo.index.commit("Another commit") - bundle = GitDagBundle( - name="test", refresh_interval=300, repo_url=repo_path, tracking_ref=GIT_DEFAULT_BRANCH - ) + bundle = GitDagBundle(name="test", refresh_interval=300, tracking_ref=GIT_DEFAULT_BRANCH) + bundle.initialize() assert bundle.get_current_version() != starting_commit.hexsha files_in_repo = {f.name for f in bundle.path.iterdir() if f.is_file()} assert {"test_dag.py", "new_test.py"} == files_in_repo - def test_refresh(self, git_repo): + @mock.patch("airflow.dag_processing.bundles.git.GitHook") + def test_refresh(self, mock_githook, git_repo): repo_path, repo = git_repo + mock_githook.return_value.repo_url = repo_path starting_commit = repo.head.commit - bundle = GitDagBundle( - name="test", refresh_interval=300, repo_url=repo_path, tracking_ref=GIT_DEFAULT_BRANCH - ) + bundle = GitDagBundle(name="test", refresh_interval=300, tracking_ref=GIT_DEFAULT_BRANCH) + bundle.initialize() assert bundle.get_current_version() == starting_commit.hexsha @@ -223,27 +315,34 @@ def test_refresh(self, git_repo): files_in_repo = {f.name for f in bundle.path.iterdir() if f.is_file()} assert {"test_dag.py", "new_test.py"} == files_in_repo - def test_head(self, git_repo): + @mock.patch("airflow.dag_processing.bundles.git.GitHook") + def test_head(self, mock_githook, git_repo): repo_path, repo = git_repo + mock_githook.return_value.repo_url = repo_path repo.create_head("test") - bundle = GitDagBundle(name="test", refresh_interval=300, repo_url=repo_path, tracking_ref="test") + bundle = GitDagBundle(name="test", refresh_interval=300, tracking_ref="test") + bundle.initialize() assert bundle.repo.head.ref.name == "test" - def test_version_not_found(self, git_repo): + @mock.patch("airflow.dag_processing.bundles.git.GitHook") + def test_version_not_found(self, mock_githook, git_repo): repo_path, repo = git_repo + mock_githook.return_value.repo_url = repo_path + bundle = GitDagBundle( + name="test", + refresh_interval=300, + version="not_found", + tracking_ref=GIT_DEFAULT_BRANCH, + ) with pytest.raises(AirflowException, match="Version not_found not found in the repository"): - GitDagBundle( - name="test", - refresh_interval=300, - version="not_found", - repo_url=repo_path, - tracking_ref=GIT_DEFAULT_BRANCH, - ) + bundle.initialize() - def test_subdir(self, git_repo): + @mock.patch("airflow.dag_processing.bundles.git.GitHook") + def test_subdir(self, mock_githook, git_repo): repo_path, repo = git_repo + mock_githook.return_value.repo_url = repo_path subdir = "somesubdir" subdir_path = repo_path / subdir @@ -258,15 +357,75 @@ def test_subdir(self, git_repo): bundle = GitDagBundle( name="test", refresh_interval=300, - repo_url=repo_path, tracking_ref=GIT_DEFAULT_BRANCH, subdir=subdir, ) + bundle.initialize() files_in_repo = {f.name for f in bundle.path.iterdir() if f.is_file()} assert str(bundle.path).endswith(subdir) assert {"some_new_file.py"} == files_in_repo + def test_raises_when_no_repo_url(self): + bundle = GitDagBundle( + name="test", + refresh_interval=300, + git_conn_id=CONN_NO_REPO_URL, + tracking_ref=GIT_DEFAULT_BRANCH, + ) + with pytest.raises( + AirflowException, match=f"Connection {CONN_NO_REPO_URL} doesn't have a git_repo_url" + ): + bundle.initialize() + + @mock.patch("airflow.dag_processing.bundles.git.GitHook") + @mock.patch("airflow.dag_processing.bundles.git.Repo") + def test_with_path_as_repo_url(self, mock_gitRepo, mock_githook): + bundle = GitDagBundle( + name="test", + refresh_interval=300, + git_conn_id=CONN_ONLY_PATH, + tracking_ref=GIT_DEFAULT_BRANCH, + ) + bundle.initialize() + assert mock_gitRepo.clone_from.call_count == 2 + assert mock_gitRepo.return_value.git.checkout.call_count == 1 + + @mock.patch("airflow.dag_processing.bundles.git.Repo") + def test_refresh_with_git_connection(self, mock_gitRepo): + bundle = GitDagBundle( + name="test", + refresh_interval=300, + git_conn_id="git_default", + tracking_ref=GIT_DEFAULT_BRANCH, + ) + bundle.initialize() + bundle.refresh() + # check remotes called twice. one at initialize and one at refresh above + assert mock_gitRepo.return_value.remotes.origin.fetch.call_count == 2 + + @pytest.mark.parametrize( + "repo_url", + [ + pytest.param("https://github.com/apache/airflow", id="https_url"), + pytest.param("airflow@example:apache/airflow.git", id="does_not_start_with_git_at"), + pytest.param("git@example:apache/airflow", id="does_not_end_with_dot_git"), + ], + ) + @mock.patch("airflow.dag_processing.bundles.git.GitHook") + def test_repo_url_validation_for_ssh(self, mock_hook, repo_url, session): + mock_hook.return_value.repo_url = repo_url + bundle = GitDagBundle( + name="test", + refresh_interval=300, + git_conn_id="git_default", + tracking_ref=GIT_DEFAULT_BRANCH, + ) + with pytest.raises( + AirflowException, match=f"Invalid git URL: {repo_url}. URL must start with git@ and end with .git" + ): + bundle.initialize() + @pytest.mark.parametrize( "repo_url, expected_url", [ @@ -280,11 +439,18 @@ def test_subdir(self, git_repo): ], ) @mock.patch("airflow.dag_processing.bundles.git.Repo") - def test_view_url(self, mock_gitrepo, repo_url, expected_url): + def test_view_url(self, mock_gitrepo, repo_url, expected_url, session): + session.query(Connection).delete() + conn = Connection( + conn_id="git_default", + host=repo_url, + conn_type="git", + ) + session.add(conn) + session.commit() bundle = GitDagBundle( name="test", refresh_interval=300, - repo_url=repo_url, tracking_ref="main", ) view_url = bundle.view_url("0f0f0f") @@ -295,7 +461,6 @@ def test_view_url_returns_none_when_no_version_in_view_url(self, mock_gitrepo): bundle = GitDagBundle( name="test", refresh_interval=300, - repo_url="git@github.com:apache/airflow.git", tracking_ref="main", ) view_url = bundle.view_url(None)