Skip to content

Commit

Permalink
Use SSH to authenticate GitDagBundle (apache#44976)
Browse files Browse the repository at this point in the history
* Use SSH to authenticate GitDagBundle

This uses SSH hook to authenticate GitDagBundle when provided.

* Add tests

* Account for remotes with ssh

* renames

* fix tests

* Refactor code

* Use githook

* fixup! Use githook

* Populate the connection form with git type connection

* Mark test_dag_bundles as db test

* Add names to the extra items

* Update airflow/dag_processing/bundles/git.py

Co-authored-by: Felix Uellendall <[email protected]>

* Fix refresh

* Apply suggestions from code review

Co-authored-by: Jed Cunningham <[email protected]>

* Remove ssh hook inheritance

* fixup! Remove ssh hook inheritance

* Apply suggestions from code review

Co-authored-by: Jed Cunningham <[email protected]>

* Fix code and link to dag processor

* Apply suggestions from code review

Co-authored-by: Jed Cunningham <[email protected]>

---------

Co-authored-by: Felix Uellendall <[email protected]>
Co-authored-by: Jed Cunningham <[email protected]>
  • Loading branch information
3 people authored and dauinh committed Jan 23, 2025
1 parent d291335 commit 725900a
Show file tree
Hide file tree
Showing 7 changed files with 361 additions and 47 deletions.
10 changes: 10 additions & 0 deletions airflow/dag_processing/bundles/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,16 @@ def __init__(self, *, name: str, refresh_interval: int, version: str | None = No
self.name = name
self.version = version
self.refresh_interval = refresh_interval
self.is_initialized: bool = False

def initialize(self) -> None:
    """
    Mark this bundle as initialized.

    The DAG processor calls this before it uses the bundle, which lets
    costly setup work be postponed until that moment.
    """
    self.is_initialized = True

@property
def _dag_bundle_root_storage_path(self) -> Path:
Expand Down
107 changes: 96 additions & 11 deletions airflow/dag_processing/bundles/git.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,80 +17,159 @@

from __future__ import annotations

import json
import os
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any
from urllib.parse import urlparse

from git import Repo
from git.exc import BadName

from airflow.dag_processing.bundles.base import BaseDagBundle
from airflow.exceptions import AirflowException
from airflow.hooks.base import BaseHook
from airflow.utils.log.logging_mixin import LoggingMixin

if TYPE_CHECKING:
from pathlib import Path


class GitDagBundle(BaseDagBundle):
class GitHook(BaseHook):
    """
    Hook for git repositories.

    The repository location and credentials are read from an Airflow
    connection: ``host`` holds the repository URL, ``password`` an optional
    access token and the ``key_file`` extra an optional path to an SSH key.

    :param git_conn_id: Connection ID for the connection to the repository
    """

    conn_name_attr = "git_conn_id"
    default_conn_name = "git_default"
    conn_type = "git"
    hook_name = "GIT"

    @classmethod
    def get_ui_field_behaviour(cls) -> dict[str, Any]:
        """Return the connection form customisation: hidden fields, labels and placeholders."""
        return {
            "hidden_fields": ["schema"],
            "relabeling": {
                "login": "Username",
                "host": "Repository URL",
                "password": "Access Token (optional)",
            },
            "placeholders": {
                "extra": json.dumps(
                    {
                        "key_file": "optional/path/to/keyfile",
                    }
                )
            },
        }

    def __init__(self, git_conn_id="git_default", *args, **kwargs):
        """
        Load the connection and derive ``repo_url``, ``auth_token``, ``key_file`` and ``env``.

        ``env`` is the extra environment to pass to git; it is empty unless
        the connection supplies a ``key_file``.
        """
        # Function-scope import: only this method needs shlex.
        import shlex

        super().__init__()
        connection = self.get_connection(git_conn_id)
        self.repo_url = connection.host
        self.auth_token = connection.password
        self.key_file = connection.extra_dejson.get("key_file")
        self.env: dict[str, str] = {}
        if self.key_file:
            # git runs GIT_SSH_COMMAND through the shell, so the key path must be
            # quoted or a path with spaces/metacharacters breaks the command.
            # Plain paths are returned unchanged by shlex.quote.
            quoted_key_file = shlex.quote(str(self.key_file))
            self.env["GIT_SSH_COMMAND"] = f"ssh -i {quoted_key_file} -o IdentitiesOnly=yes"
        self._process_git_auth_url()

    def _process_git_auth_url(self):
        """
        Normalise ``repo_url`` in place.

        An ``https://`` URL gets the access token embedded when one is set;
        anything that is neither an ``https://`` nor a ``git@`` URL is treated
        as a filesystem path and has a leading ``~`` expanded.
        """
        if not isinstance(self.repo_url, str):
            return
        if self.auth_token and self.repo_url.startswith("https://"):
            # count=1: only the scheme prefix is rewritten, never a later
            # "https://" that happens to appear inside the URL.
            self.repo_url = self.repo_url.replace("https://", f"https://{self.auth_token}@", 1)
        elif not self.repo_url.startswith(("git@", "https://")):
            self.repo_url = os.path.expanduser(self.repo_url)


class GitDagBundle(BaseDagBundle, LoggingMixin):
"""
git DAG bundle - exposes a git repository as a DAG bundle.
Instead of cloning the repository every time, we clone the repository once into a bare repo from the source
and then do a clone for each version from there.
:param repo_url: URL of the git repository
:param tracking_ref: Branch or tag for this DAG bundle
:param subdir: Subdirectory within the repository where the DAGs are stored (Optional)
:param git_conn_id: Connection ID for SSH/token based connection to the repository (Optional)
"""

supports_versioning = True

def __init__(
    self,
    *,
    tracking_ref: str,
    subdir: str | None = None,
    git_conn_id: str = "git_default",
    **kwargs,
) -> None:
    """
    Record the bundle settings and resolve the repository URL from the connection.

    :param tracking_ref: Branch or tag for this DAG bundle
    :param subdir: Subdirectory within the repository where the DAGs are stored (Optional)
    :param git_conn_id: Connection ID used to build the ``GitHook`` (Optional)
    :param kwargs: Forwarded unchanged to the base bundle constructor
    """
    super().__init__(**kwargs)
    self.tracking_ref = tracking_ref
    self.subdir = subdir

    # One shared bare clone per bundle name, plus one working clone per
    # version (or per tracking ref when no version is pinned).
    self.bare_repo_path = self._dag_bundle_root_storage_path / "git" / self.name
    self.repo_path = (
        self._dag_bundle_root_storage_path / "git" / (self.name + f"+{self.version or self.tracking_ref}")
    )
    self.git_conn_id = git_conn_id
    # The repository URL is no longer a constructor argument; it comes from the hook.
    self.hook = GitHook(git_conn_id=self.git_conn_id)
    self.repo_url = self.hook.repo_url

def _initialize(self):
self._clone_bare_repo_if_required()
self._ensure_version_in_bare_repo()
self._clone_repo_if_required()
self.repo.git.checkout(self.tracking_ref)

if self.version:
if not self._has_version(self.repo, self.version):
self.repo.remotes.origin.fetch()

self.repo.head.set_reference(self.repo.commit(self.version))
self.repo.head.reset(index=True, working_tree=True)
else:
self.refresh()

def initialize(self) -> None:
    """
    Validate the repository URL, prepare the clones and mark the bundle initialized.

    :raises AirflowException: if the connection yields no repository URL, or
        if the URL is a string that does not start with ``git@`` and end
        with ``.git``
    """
    url = self.repo_url
    if not url:
        raise AirflowException(f"Connection {self.git_conn_id} doesn't have a git_repo_url")

    # Path-like URLs skip the format check; only string URLs are validated.
    if not isinstance(url, os.PathLike):
        is_ssh_style = url.startswith("git@") and url.endswith(".git")
        if not is_ssh_style:
            raise AirflowException(
                f"Invalid git URL: {self.repo_url}. URL must start with git@ and end with .git"
            )

    self._initialize()
    super().initialize()

def _clone_repo_if_required(self) -> None:
    """Clone the working copy from the bare repository unless it exists, then open it as ``self.repo``."""
    destination = self.repo_path
    if not os.path.exists(destination):
        origin_path = self.bare_repo_path
        self.log.info("Cloning repository to %s from %s", destination, origin_path)
        Repo.clone_from(url=origin_path, to_path=destination)

    # Opened in both cases so self.repo is always set afterwards.
    self.repo = Repo(destination)

def _clone_bare_repo_if_required(self) -> None:
    """Clone ``repo_url`` into a bare repository unless it exists, then open it as ``self.bare_repo``."""
    destination = self.bare_repo_path
    if not os.path.exists(destination):
        self.log.info("Cloning bare repository to %s", destination)
        # The hook's env is passed to git for this clone; it is empty
        # unless the connection supplied a key file.
        clone_options = {"bare": True, "env": self.hook.env}
        Repo.clone_from(url=self.repo_url, to_path=destination, **clone_options)
    self.bare_repo = Repo(destination)

def _ensure_version_in_bare_repo(self) -> None:
if not self.version:
return
if not self._has_version(self.bare_repo, self.version):
self.bare_repo.remotes.origin.fetch("+refs/heads/*:refs/heads/*")
self._fetch_bare_repo()
if not self._has_version(self.bare_repo, self.version):
raise AirflowException(f"Version {self.version} not found in the repository")

Expand Down Expand Up @@ -121,11 +200,17 @@ def _has_version(repo: Repo, version: str) -> bool:
except BadName:
return False

def _fetch_bare_repo(self):
if self.hook.env:
with self.bare_repo.git.custom_environment(GIT_SSH_COMMAND=self.hook.env.get("GIT_SSH_COMMAND")):
self.bare_repo.remotes.origin.fetch("+refs/heads/*:refs/heads/*")
else:
self.bare_repo.remotes.origin.fetch("+refs/heads/*:refs/heads/*")

def refresh(self) -> None:
    """
    Update the bare repository from its origin, then pull into the working copy.

    :raises AirflowException: if the bundle is pinned to a specific version
    """
    if self.version:
        raise AirflowException("Refreshing a specific version is not supported")

    # Fetch only through _fetch_bare_repo so the hook's SSH environment is
    # applied; a direct origin.fetch here would bypass it and fetch twice.
    self._fetch_bare_repo()
    self.repo.remotes.origin.pull()

def _convert_git_ssh_url_to_https(self) -> str:
Expand Down
1 change: 1 addition & 0 deletions airflow/dag_processing/bundles/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ def parse_config(self) -> None:
class_ = import_string(cfg["classpath"])
kwargs = cfg["kwargs"]
self._bundle_config[name] = (class_, kwargs)
self.log.info("DAG bundles loaded: %s", ", ".join(self._bundle_config.keys()))

@provide_session
def sync_bundles_to_db(self, *, session: Session = NEW_SESSION) -> None:
Expand Down
44 changes: 44 additions & 0 deletions airflow/dag_processing/bundles/provider.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

---
package-name: apache-airflow-providers-bundles
name: GIT
description: |
`GIT <https://git-scm.com/>`__
state: not-ready
source-date-epoch: 1726861127
# note that those versions are maintained by release manager - do not update them manually
versions:
- 1.0.0

dependencies:
- apache-airflow-providers-ssh

integrations:
- integration-name: GIT (Git)

hooks:
- integration-name: GIT
python-modules:
- airflow.dag_processing.bundles.git


connection-types:
- hook-class-name: airflow.dag_processing.bundles.git.GitHook
connection-type: git
4 changes: 4 additions & 0 deletions airflow/dag_processing/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -653,6 +653,10 @@ def _refresh_dag_bundles(self):
self.log.info("Refreshing DAG bundles")

for bundle in self._dag_bundles:
# TODO: AIP-66 handle errors in the case of incomplete cloning? And test this.
# What if the cloning/refreshing takes too long (longer than the dag processor timeout)?
if not bundle.is_initialized:
bundle.initialize()
# TODO: AIP-66 test to make sure we get a fresh record from the db and it's not cached
with create_session() as session:
bundle_model = session.get(DagBundleModel, bundle.name)
Expand Down
5 changes: 5 additions & 0 deletions airflow/providers_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,9 @@ def _create_customized_form_field_behaviours_schema_validator():


def _check_builtin_provider_prefix(provider_package: str, class_name: str) -> bool:
if "bundles" in provider_package:
# TODO: AIP-66: remove this when this package is moved to providers directory
return True
if provider_package.startswith("apache-airflow"):
provider_path = provider_package[len("apache-") :].replace("-", ".")
if not class_name.startswith(provider_path):
Expand Down Expand Up @@ -676,6 +679,8 @@ def _discover_all_airflow_builtin_providers_from_local_sources(self) -> None:
self._add_provider_info_from_local_source_files_on_path(path)
except Exception as e:
log.warning("Error when loading 'provider.yaml' files from %s airflow sources: %s", path, e)
# TODO: AIP-66: Remove this when the package is moved to providers
self._add_provider_info_from_local_source_files_on_path("airflow/dag_processing")

def _add_provider_info_from_local_source_files_on_path(self, path) -> None:
"""
Expand Down
Loading

0 comments on commit 725900a

Please sign in to comment.