Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ephraimbuddy committed Dec 18, 2024
1 parent dc2d2ec commit 9537677
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 7 deletions.
11 changes: 7 additions & 4 deletions airflow/dag_processing/bundles/git.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,17 +42,18 @@ class GitDagBundle(BaseDagBundle, LoggingMixin):
:param repo_url: URL of the git repository
:param tracking_ref: Branch or tag for this DAG bundle
:param subdir: Subdirectory within the repository where the DAGs are stored (Optional)
:param ssh_conn_id: Connection ID for SSH connection to the repository (Optional)
"""

supports_versioning = True

def __init__(
self,
*,
repo_url: str | os.PathLike = "",
tracking_ref: str,
subdir: str | None = None,
ssh_conn_id: str | None = None,
repo_url: str | os.PathLike = "",
**kwargs,
) -> None:
super().__init__(**kwargs)
Expand Down Expand Up @@ -178,9 +179,11 @@ def _refresh():
self.bare_repo.remotes.origin.fetch("+refs/heads/*:refs/heads/*")
self.repo.remotes.origin.pull()

if self.env:
_refresh()
elif self.ssh_conn_id:
if self.ssh_conn_id:
if self.env:
with self.repo.git.custom_environment(**self.env):
_refresh()
return
temp_key_file_path = None
try:
if not self.key_file:
Expand Down
87 changes: 84 additions & 3 deletions tests/dag_processing/test_dag_bundles.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import tempfile
from pathlib import Path
from unittest import mock
from unittest.mock import MagicMock

import pytest
from git import Repo
Expand Down Expand Up @@ -282,12 +283,92 @@ def test_with_ssh_conn_id(self, mock_gitRepo, mock_hook):
key_filepath = "/path/to/keyfile"
mock_hook.return_value.key_file = key_filepath
mock_hook.return_value.remote_host = repo_url
bundle = GitDagBundle(name="test", ssh_conn_id="ssh_default", tracking_ref=GIT_DEFAULT_BRANCH)
bundle = GitDagBundle(
name="test",
refresh_interval=300,
ssh_conn_id="ssh_default",
tracking_ref=GIT_DEFAULT_BRANCH,
)
assert bundle.env == {}
bundle.initialize()
mock_hook.assert_called_once_with(ssh_conn_id=conn_id)

def test_no_key_file_and_no_private_key_raises_for_ssh_conn(self):
bundle = GitDagBundle(name="test", ssh_conn_id="ssh_default", tracking_ref=GIT_DEFAULT_BRANCH)
@mock.patch("airflow.providers.ssh.hooks.ssh.SSHHook")
def test_no_key_file_and_no_private_key_raises_for_ssh_conn(self, mock_hook):
    """Expect ``initialize`` to fail with the "no private key" error.

    ``SSHHook`` is patched so no real SSH connection lookup takes place.
    """
    bundle_kwargs = {
        "name": "test",
        "refresh_interval": 300,
        "ssh_conn_id": "ssh_default",
        "tracking_ref": GIT_DEFAULT_BRANCH,
    }
    git_bundle = GitDagBundle(**bundle_kwargs)

    expected_error = "No private key present in connection"
    with pytest.raises(AirflowException, match=expected_error):
        git_bundle.initialize()

@mock.patch("airflow.providers.ssh.hooks.ssh.SSHHook")
@mock.patch("airflow.dag_processing.bundles.git.Repo")
@mock.patch("airflow.dag_processing.bundles.git.os")
def test_temporary_file_removed_after_initialization(self, mock_os, mock_gitRepo, mock_hook):
    """Check that ``os.remove`` is called exactly once by ``initialize``.

    The mocked hook exposes no key file, while the mocked connection extras
    carry a ``private_key`` entry. Presumably the bundle writes that key to a
    temporary file and deletes it afterwards -- the bundle code is not part of
    this test, so confirm against ``GitDagBundle.initialize``.
    """
    hook = mock_hook.return_value
    hook.key_file = None
    hook.remote_host = "[email protected]:apache/airflow.git"
    # The private key is only available through the connection extras.
    hook.get_connection.return_value = MagicMock(extra_dejson={"private_key": "private"})

    git_bundle = GitDagBundle(
        name="test",
        refresh_interval=300,
        ssh_conn_id="ssh_default",
        tracking_ref=GIT_DEFAULT_BRANCH,
    )
    git_bundle.initialize()

    assert mock_os.remove.call_count == 1

@mock.patch("airflow.providers.ssh.hooks.ssh.SSHHook")
@mock.patch("airflow.dag_processing.bundles.git.Repo")
def test_refresh_with_an_existing_env(self, mock_gitRepo, mock_hook):
    """Check that ``refresh`` fetches under the bundle's preset ``env``.

    After ``initialize``, ``env`` is overwritten with a ``GIT_SSH_COMMAND``
    mapping; ``refresh`` is then expected to fetch once more and to invoke
    ``custom_environment`` with exactly that mapping.
    """
    hook = mock_hook.return_value
    hook.key_file = "/path/to/keyfile"
    hook.remote_host = "[email protected]:apache/airflow.git"

    git_bundle = GitDagBundle(
        name="test",
        refresh_interval=300,
        ssh_conn_id="ssh_default",
        tracking_ref=GIT_DEFAULT_BRANCH,
    )
    git_bundle.initialize()
    git_bundle.env = {"GIT_SSH_COMMAND": "ssh -i /path/to/keyfile -o IdentitiesOnly=yes"}
    git_bundle.refresh()

    repo = mock_gitRepo.return_value
    # Two fetches in total: one from initialize(), one from refresh().
    assert repo.remotes.origin.fetch.call_count == 2
    # The most recent custom_environment call must carry the preset env.
    repo.git.custom_environment.assert_called_with(**git_bundle.env)

def test_refresh_with_conn_id_raises_when_bundle_not_initialized(self):
    """Expect ``refresh`` to raise when called before ``initialize`` on an SSH bundle."""
    git_bundle = GitDagBundle(
        name="test",
        refresh_interval=300,
        ssh_conn_id="ssh_default",
        tracking_ref=GIT_DEFAULT_BRANCH,
    )
    expected_error = "Missing private key, please initialize the bundle first"
    with pytest.raises(AirflowException, match=expected_error):
        git_bundle.refresh()

@mock.patch("airflow.providers.ssh.hooks.ssh.SSHHook")
@mock.patch("airflow.dag_processing.bundles.git.Repo")
@mock.patch("airflow.dag_processing.bundles.git.os")
def test_temporary_file_removed_in_refresh(self, mock_os, mock_gitRepo, mock_hook):
    """Check that ``initialize`` and ``refresh`` each call ``os.remove`` once.

    The mocked hook exposes no key file and the mocked connection extras carry
    a ``private_key`` entry. Presumably each operation writes that key to a
    temporary file and deletes it afterwards -- confirm against the bundle
    implementation, which is not part of this test.
    """
    repo_url = "[email protected]:apache/airflow.git"
    ssh_hook = mock_hook.return_value
    ssh_hook.key_file = None
    ssh_hook.remote_host = repo_url
    conn = MagicMock()
    conn.extra_dejson = {"private_key": "private"}
    ssh_hook.get_connection.return_value = conn
    bundle = GitDagBundle(
        name="test",
        refresh_interval=300,
        ssh_conn_id="ssh_default",
        tracking_ref=GIT_DEFAULT_BRANCH,
    )

    bundle.initialize()
    # Pin the intermediate state so the final count of 2 cannot be reached by
    # initialize() removing twice while refresh() removes nothing.
    assert mock_os.remove.call_count == 1

    bundle.refresh()
    # One removal from initialize() plus one from refresh().
    assert mock_os.remove.call_count == 2

0 comments on commit 9537677

Please sign in to comment.