Skip to content

Commit

Permalink
Fix code and link to dag processor
Browse files Browse the repository at this point in the history
  • Loading branch information
ephraimbuddy committed Jan 15, 2025
1 parent de27d21 commit 969dbbd
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 35 deletions.
9 changes: 9 additions & 0 deletions airflow/dag_processing/bundles/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,15 @@ def __init__(self, *, name: str, refresh_interval: int, version: str | None = No
self.name = name
self.version = version
self.refresh_interval = refresh_interval
self.is_initialized: bool = False

def initialize(self) -> None:
"""
Initialize the bundle.
This method is called by the DAG processor before the bundle is used.
"""
self.is_initialized = True

@property
def _dag_bundle_root_storage_path(self) -> Path:
Expand Down
33 changes: 11 additions & 22 deletions airflow/dag_processing/bundles/git.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,18 +118,14 @@ def __init__(
self.hook = GitHook(git_conn_id=self.git_conn_id)
self.repo_url = self.hook.repo_url

def _clone_from(self, to_path: Path, bare: bool = False) -> Repo:
self.log.info("Cloning %s to %s", self.repo_url, to_path)
return Repo.clone_from(self.repo_url, to_path, bare=bare, env=self.hook.env)

def _initialize(self):
self._clone_bare_repo_if_required()
self._ensure_version_in_bare_repo()
self._clone_repo_if_required()
self.repo.git.checkout(self.tracking_ref)
if self.version:
if not self._has_version(self.repo, self.version):
self._fetch_repo()
self.repo.remotes.origin.fetch()
self.repo.head.set_reference(self.repo.commit(self.version))
self.repo.head.reset(index=True, working_tree=True)
else:
Expand All @@ -146,19 +142,26 @@ def initialize(self) -> None:
)
else:
self._initialize()
super().initialize()

def _clone_repo_if_required(self) -> None:
if not os.path.exists(self.repo_path):
self._clone_from(
self.log.info("Cloning repository to %s from %s", self.repo_path, self.bare_repo_path)
Repo.clone_from(
url=self.bare_repo_path,
to_path=self.repo_path,
)

self.repo = Repo(self.repo_path)

def _clone_bare_repo_if_required(self) -> None:
if not os.path.exists(self.bare_repo_path):
self._clone_from(
self.log.info("Cloning bare repository to %s", self.bare_repo_path)
Repo.clone_from(
url=self.repo_url,
to_path=self.bare_repo_path,
bare=True,
env=self.hook.env,
)
self.bare_repo = Repo(self.bare_repo_path)

Expand Down Expand Up @@ -197,32 +200,18 @@ def _has_version(repo: Repo, version: str) -> bool:
except BadName:
return False

def _fetch_repo(self):
if self.hook.env:
with self.repo.git.custom_environment(GIT_SSH_COMMAND=self.hook.env.get("GIT_SSH_COMMAND")):
self.repo.remotes.origin.fetch()
else:
self.repo.remotes.origin.fetch()

def _fetch_bare_repo(self):
if self.hook.env:
with self.bare_repo.git.custom_environment(GIT_SSH_COMMAND=self.hook.env.get("GIT_SSH_COMMAND")):
self.bare_repo.remotes.origin.fetch("+refs/heads/*:refs/heads/*")
else:
self.bare_repo.remotes.origin.fetch("+refs/heads/*:refs/heads/*")

def _pull_repo(self):
if self.hook.env:
with self.repo.git.custom_environment(GIT_SSH_COMMAND=self.hook.env.get("GIT_SSH_COMMAND")):
self.repo.remotes.origin.pull()
else:
self.repo.remotes.origin.pull()

def refresh(self) -> None:
if self.version:
raise AirflowException("Refreshing a specific version is not supported")
self._fetch_bare_repo()
self._pull_repo()
self.repo.remotes.origin.pull()

def _convert_git_ssh_url_to_https(self) -> str:
if not self.repo_url.startswith("git@"):
Expand Down
1 change: 1 addition & 0 deletions airflow/dag_processing/bundles/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ def parse_config(self) -> None:
class_ = import_string(cfg["classpath"])
kwargs = cfg["kwargs"]
self._bundle_config[name] = (class_, kwargs)
self.log.info("DAG bundles loaded: %s", ", ".join(self._bundle_config.keys()))

@provide_session
def sync_bundles_to_db(self, *, session: Session = NEW_SESSION) -> None:
Expand Down
4 changes: 4 additions & 0 deletions airflow/dag_processing/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -653,6 +653,10 @@ def _refresh_dag_bundles(self):
self.log.info("Refreshing DAG bundles")

for bundle in self._dag_bundles:
# TODO: AIP-66 handle errors in the case of incomplete cloning? And test this.
# What if the cloning/refreshing took too long(longer than the dag processor timeout)
if not bundle.is_initialized:
bundle.initialize()
# TODO: AIP-66 test to make sure we get a fresh record from the db and it's not cached
with create_session() as session:
bundle_model = session.get(DagBundleModel, bundle.name)
Expand Down
16 changes: 3 additions & 13 deletions tests/dag_processing/test_dag_bundles.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,6 @@ def test_get_current_version(self, mock_githook, git_repo):

@mock.patch("airflow.dag_processing.bundles.git.GitHook")
def test_get_specific_version(self, mock_githook, git_repo):
mock_githook.get_conn.return_value = mock.MagicMock()
repo_path, repo = git_repo
mock_githook.return_value.repo_url = repo_path
starting_commit = repo.head.commit
Expand All @@ -242,7 +241,6 @@ def test_get_specific_version(self, mock_githook, git_repo):

@mock.patch("airflow.dag_processing.bundles.git.GitHook")
def test_get_tag_version(self, mock_githook, git_repo):
mock_githook.get_conn.return_value = mock.MagicMock()
repo_path, repo = git_repo
mock_githook.return_value.repo_url = repo_path
starting_commit = repo.head.commit
Expand Down Expand Up @@ -272,7 +270,6 @@ def test_get_tag_version(self, mock_githook, git_repo):

@mock.patch("airflow.dag_processing.bundles.git.GitHook")
def test_get_latest(self, mock_githook, git_repo):
mock_githook.get_conn.return_value = mock.MagicMock()
repo_path, repo = git_repo
mock_githook.return_value.repo_url = repo_path
starting_commit = repo.head.commit
Expand All @@ -293,7 +290,6 @@ def test_get_latest(self, mock_githook, git_repo):

@mock.patch("airflow.dag_processing.bundles.git.GitHook")
def test_refresh(self, mock_githook, git_repo):
mock_githook.get_conn.return_value = mock.MagicMock()
repo_path, repo = git_repo
mock_githook.return_value.repo_url = repo_path
starting_commit = repo.head.commit
Expand Down Expand Up @@ -321,7 +317,6 @@ def test_refresh(self, mock_githook, git_repo):

@mock.patch("airflow.dag_processing.bundles.git.GitHook")
def test_head(self, mock_githook, git_repo):
mock_githook.get_conn.return_value = mock.MagicMock()
repo_path, repo = git_repo
mock_githook.return_value.repo_url = repo_path

Expand All @@ -332,7 +327,6 @@ def test_head(self, mock_githook, git_repo):

@mock.patch("airflow.dag_processing.bundles.git.GitHook")
def test_version_not_found(self, mock_githook, git_repo):
mock_githook.get_conn.return_value = mock.MagicMock()
repo_path, repo = git_repo
mock_githook.return_value.repo_url = repo_path
bundle = GitDagBundle(
Expand All @@ -347,7 +341,6 @@ def test_version_not_found(self, mock_githook, git_repo):

@mock.patch("airflow.dag_processing.bundles.git.GitHook")
def test_subdir(self, mock_githook, git_repo):
mock_githook.get_conn.return_value = mock.MagicMock()
repo_path, repo = git_repo
mock_githook.return_value.repo_url = repo_path

Expand Down Expand Up @@ -387,21 +380,19 @@ def test_raises_when_no_repo_url(self):

@mock.patch("airflow.dag_processing.bundles.git.GitHook")
@mock.patch("airflow.dag_processing.bundles.git.Repo")
@mock.patch.object(GitDagBundle, "_clone_from")
def test_with_path_as_repo_url(self, mock_clone_from, mock_gitRepo, mock_githook):
def test_with_path_as_repo_url(self, mock_gitRepo, mock_githook):
bundle = GitDagBundle(
name="test",
refresh_interval=300,
git_conn_id=CONN_ONLY_PATH,
tracking_ref=GIT_DEFAULT_BRANCH,
)
bundle.initialize()
assert mock_clone_from.call_count == 2
assert mock_gitRepo.clone_from.call_count == 2
assert mock_gitRepo.return_value.git.checkout.call_count == 1

@mock.patch("airflow.dag_processing.bundles.git.GitHook")
@mock.patch("airflow.dag_processing.bundles.git.Repo")
def test_refresh_with_git_connection(self, mock_gitRepo, mock_hook):
def test_refresh_with_git_connection(self, mock_gitRepo):
bundle = GitDagBundle(
name="test",
refresh_interval=300,
Expand All @@ -423,7 +414,6 @@ def test_refresh_with_git_connection(self, mock_gitRepo, mock_hook):
)
@mock.patch("airflow.dag_processing.bundles.git.GitHook")
def test_repo_url_validation_for_ssh(self, mock_hook, repo_url, session):
mock_hook.get_conn.return_value = mock.MagicMock()
mock_hook.return_value.repo_url = repo_url
bundle = GitDagBundle(
name="test",
Expand Down

0 comments on commit 969dbbd

Please sign in to comment.