Skip to content

Commit bc02d3d

Browse files
committed
fixup! Use githook
1 parent c582d74 commit bc02d3d

File tree

2 files changed

+179
-47
lines changed

2 files changed

+179
-47
lines changed

airflow/dag_processing/bundles/git.py

+9-5
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,8 @@ def _process_git_auth_url(self):
7474
return
7575
if self.auth_token and self.repo_url.startswith("https://"):
7676
self.repo_url = self.repo_url.replace("https://", f"https://{self.auth_token}@")
77+
elif not self.repo_url.startswith("git@") or not self.repo_url.startswith("https://"):
78+
self.repo_url = os.path.expanduser(self.repo_url)
7779

7880
def get_conn(self):
7981
"""
@@ -143,11 +145,13 @@ def _initialize(self):
143145
def initialize(self) -> None:
144146
if not self.repo_url:
145147
raise AirflowException(f"Connection {self.git_conn_id} doesn't have a git_repo_url")
146-
if self.repo_url.startswith("git@"):
147-
if not self.repo_url.startswith("git@") and not self.repo_url.endswith(".git"):
148-
raise AirflowException(
149-
f"Invalid git URL: {self.repo_url}. URL must start with git@ and end with .git"
150-
)
148+
if isinstance(self.repo_url, os.PathLike):
149+
self._initialize()
150+
elif not self.repo_url.startswith("git@") or not self.repo_url.endswith(".git"):
151+
raise AirflowException(
152+
f"Invalid git URL: {self.repo_url}. URL must start with git@ and end with .git"
153+
)
154+
elif self.repo_url.startswith("git@"):
151155
with self.hook.get_conn():
152156
self._initialize()
153157
else:

tests/dag_processing/test_dag_bundles.py

+170-42
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,14 @@
2626

2727
from airflow.dag_processing.bundles.base import BaseDagBundle
2828
from airflow.dag_processing.bundles.dagfolder import DagsFolderDagBundle
29-
from airflow.dag_processing.bundles.git import GitDagBundle
29+
from airflow.dag_processing.bundles.git import GitDagBundle, GitHook, SSHHook
3030
from airflow.dag_processing.bundles.local import LocalDagBundle
3131
from airflow.exceptions import AirflowException
32+
from airflow.models import Connection
33+
from airflow.utils import db
3234

3335
from tests_common.test_utils.config import conf_vars
36+
from tests_common.test_utils.db import clear_db_connections
3437

3538

3639
@pytest.fixture(autouse=True)
@@ -107,28 +110,125 @@ def git_repo(tmp_path_factory):
107110
return (directory, repo)
108111

109112

113+
AIRFLOW_HTTPS_URL = "https://github.com/apache/airflow.git"
114+
AIRFLOW_GIT = "[email protected]:apache/airflow.git"
115+
ACCESS_TOKEN = "my_access_token"
116+
CONN_DEFAULT = "git_default"
117+
CONN_HTTPS = "my_git_conn"
118+
CONN_HTTPS_PASSWORD = "my_git_conn_https_password"
119+
CONN_ONLY_PATH = "my_git_conn_only_path"
120+
CONN_NO_REPO_URL = "my_git_conn_no_repo_url"
121+
122+
123+
class TestGitHook:
124+
@classmethod
125+
def teardown_class(cls) -> None:
126+
clear_db_connections()
127+
128+
@classmethod
129+
def setup_class(cls) -> None:
130+
db.merge_conn(
131+
Connection(
132+
conn_id=CONN_DEFAULT,
133+
host="github.com",
134+
conn_type="git",
135+
extra={"git_repo_url": AIRFLOW_GIT},
136+
)
137+
)
138+
db.merge_conn(
139+
Connection(
140+
conn_id=CONN_HTTPS,
141+
host="github.com",
142+
conn_type="git",
143+
extra={"git_repo_url": AIRFLOW_HTTPS_URL, "git_access_token": ACCESS_TOKEN},
144+
)
145+
)
146+
db.merge_conn(
147+
Connection(
148+
conn_id=CONN_HTTPS_PASSWORD,
149+
host="github.com",
150+
conn_type="git",
151+
password=ACCESS_TOKEN,
152+
extra={"git_repo_url": AIRFLOW_HTTPS_URL},
153+
)
154+
)
155+
db.merge_conn(
156+
Connection(
157+
conn_id=CONN_ONLY_PATH,
158+
host="github.com",
159+
conn_type="git",
160+
extra={"git_repo_url": "path/to/repo"},
161+
)
162+
)
163+
164+
@pytest.mark.parametrize(
165+
"conn_id, expected_repo_url",
166+
[
167+
(CONN_DEFAULT, AIRFLOW_GIT),
168+
(CONN_HTTPS, f"https://{ACCESS_TOKEN}@github.com/apache/airflow.git"),
169+
(CONN_HTTPS_PASSWORD, f"https://{ACCESS_TOKEN}@github.com/apache/airflow.git"),
170+
(CONN_ONLY_PATH, "path/to/repo"),
171+
],
172+
)
173+
def test_correct_repo_urls(self, conn_id, expected_repo_url):
174+
hook = GitHook(git_conn_id=conn_id)
175+
assert hook.repo_url == expected_repo_url
176+
177+
@mock.patch.object(SSHHook, "get_conn")
178+
def test_connection_made_to_ssh_hook(self, mock_ssh_hook_get_conn):
179+
hook = GitHook(git_conn_id=CONN_DEFAULT)
180+
hook.get_conn()
181+
mock_ssh_hook_get_conn.assert_called_once_with()
182+
183+
110184
class TestGitDagBundle:
185+
@classmethod
186+
def teardown_class(cls) -> None:
187+
clear_db_connections()
188+
189+
@classmethod
190+
def setup_class(cls) -> None:
191+
db.merge_conn(
192+
Connection(
193+
conn_id="git_default",
194+
host="github.com",
195+
conn_type="git",
196+
extra={"git_repo_url": "[email protected]:apache/airflow.git"},
197+
)
198+
)
199+
db.merge_conn(
200+
Connection(
201+
conn_id=CONN_NO_REPO_URL,
202+
host="github.com",
203+
conn_type="git",
204+
extra="{}",
205+
)
206+
)
207+
111208
def test_supports_versioning(self):
112209
assert GitDagBundle.supports_versioning is True
113210

114211
def test_uses_dag_bundle_root_storage_path(self, git_repo):
115212
repo_path, repo = git_repo
116-
bundle = GitDagBundle(
117-
name="test", refresh_interval=300, repo_url=repo_path, tracking_ref=GIT_DEFAULT_BRANCH
118-
)
213+
bundle = GitDagBundle(name="test", refresh_interval=300, tracking_ref=GIT_DEFAULT_BRANCH)
119214
assert str(bundle._dag_bundle_root_storage_path) in str(bundle.path)
120215

121-
def test_get_current_version(self, git_repo):
216+
@mock.patch("airflow.dag_processing.bundles.git.GitHook")
217+
def test_get_current_version(self, mock_githook, git_repo):
218+
mock_githook.get_conn.return_value = mock.MagicMock()
122219
repo_path, repo = git_repo
123-
bundle = GitDagBundle(
124-
name="test", refresh_interval=300, repo_url=repo_path, tracking_ref=GIT_DEFAULT_BRANCH
125-
)
220+
mock_githook.return_value.repo_url = repo_path
221+
bundle = GitDagBundle(name="test", refresh_interval=300, tracking_ref=GIT_DEFAULT_BRANCH)
222+
126223
bundle.initialize()
127224

128225
assert bundle.get_current_version() == repo.head.commit.hexsha
129226

130-
def test_get_specific_version(self, git_repo):
227+
@mock.patch("airflow.dag_processing.bundles.git.GitHook")
228+
def test_get_specific_version(self, mock_githook, git_repo):
229+
mock_githook.get_conn.return_value = mock.MagicMock()
131230
repo_path, repo = git_repo
231+
mock_githook.return_value.repo_url = repo_path
132232
starting_commit = repo.head.commit
133233

134234
# Add new file to the repo
@@ -142,7 +242,6 @@ def test_get_specific_version(self, git_repo):
142242
name="test",
143243
refresh_interval=300,
144244
version=starting_commit.hexsha,
145-
repo_url=repo_path,
146245
tracking_ref=GIT_DEFAULT_BRANCH,
147246
)
148247
bundle.initialize()
@@ -152,8 +251,11 @@ def test_get_specific_version(self, git_repo):
152251
files_in_repo = {f.name for f in bundle.path.iterdir() if f.is_file()}
153252
assert {"test_dag.py"} == files_in_repo
154253

155-
def test_get_tag_version(self, git_repo):
254+
@mock.patch("airflow.dag_processing.bundles.git.GitHook")
255+
def test_get_tag_version(self, mock_githook, git_repo):
256+
mock_githook.get_conn.return_value = mock.MagicMock()
156257
repo_path, repo = git_repo
258+
mock_githook.return_value.repo_url = repo_path
157259
starting_commit = repo.head.commit
158260

159261
# add tag
@@ -171,7 +273,6 @@ def test_get_tag_version(self, git_repo):
171273
name="test",
172274
refresh_interval=300,
173275
version="test",
174-
repo_url=repo_path,
175276
tracking_ref=GIT_DEFAULT_BRANCH,
176277
)
177278
bundle.initialize()
@@ -180,8 +281,11 @@ def test_get_tag_version(self, git_repo):
180281
files_in_repo = {f.name for f in bundle.path.iterdir() if f.is_file()}
181282
assert {"test_dag.py"} == files_in_repo
182283

183-
def test_get_latest(self, git_repo):
284+
@mock.patch("airflow.dag_processing.bundles.git.GitHook")
285+
def test_get_latest(self, mock_githook, git_repo):
286+
mock_githook.get_conn.return_value = mock.MagicMock()
184287
repo_path, repo = git_repo
288+
mock_githook.return_value.repo_url = repo_path
185289
starting_commit = repo.head.commit
186290

187291
file_path = repo_path / "new_test.py"
@@ -190,23 +294,22 @@ def test_get_latest(self, git_repo):
190294
repo.index.add([file_path])
191295
repo.index.commit("Another commit")
192296

193-
bundle = GitDagBundle(
194-
name="test", refresh_interval=300, repo_url=repo_path, tracking_ref=GIT_DEFAULT_BRANCH
195-
)
297+
bundle = GitDagBundle(name="test", refresh_interval=300, tracking_ref=GIT_DEFAULT_BRANCH)
196298
bundle.initialize()
197299

198300
assert bundle.get_current_version() != starting_commit.hexsha
199301

200302
files_in_repo = {f.name for f in bundle.path.iterdir() if f.is_file()}
201303
assert {"test_dag.py", "new_test.py"} == files_in_repo
202304

203-
def test_refresh(self, git_repo):
305+
@mock.patch("airflow.dag_processing.bundles.git.GitHook")
306+
def test_refresh(self, mock_githook, git_repo):
307+
mock_githook.get_conn.return_value = mock.MagicMock()
204308
repo_path, repo = git_repo
309+
mock_githook.return_value.repo_url = repo_path
205310
starting_commit = repo.head.commit
206311

207-
bundle = GitDagBundle(
208-
name="test", refresh_interval=300, repo_url=repo_path, tracking_ref=GIT_DEFAULT_BRANCH
209-
)
312+
bundle = GitDagBundle(name="test", refresh_interval=300, tracking_ref=GIT_DEFAULT_BRANCH)
210313
bundle.initialize()
211314

212315
assert bundle.get_current_version() == starting_commit.hexsha
@@ -227,29 +330,37 @@ def test_refresh(self, git_repo):
227330
files_in_repo = {f.name for f in bundle.path.iterdir() if f.is_file()}
228331
assert {"test_dag.py", "new_test.py"} == files_in_repo
229332

230-
def test_head(self, git_repo):
333+
@mock.patch("airflow.dag_processing.bundles.git.GitHook")
334+
def test_head(self, mock_githook, git_repo):
335+
mock_githook.get_conn.return_value = mock.MagicMock()
231336
repo_path, repo = git_repo
337+
mock_githook.return_value.repo_url = repo_path
232338

233339
repo.create_head("test")
234-
bundle = GitDagBundle(name="test", refresh_interval=300, repo_url=repo_path, tracking_ref="test")
340+
bundle = GitDagBundle(name="test", refresh_interval=300, tracking_ref="test")
235341
bundle.initialize()
236342
assert bundle.repo.head.ref.name == "test"
237343

238-
def test_version_not_found(self, git_repo):
344+
@mock.patch("airflow.dag_processing.bundles.git.GitHook")
345+
def test_version_not_found(self, mock_githook, git_repo):
346+
mock_githook.get_conn.return_value = mock.MagicMock()
239347
repo_path, repo = git_repo
348+
mock_githook.return_value.repo_url = repo_path
240349
bundle = GitDagBundle(
241350
name="test",
242351
refresh_interval=300,
243352
version="not_found",
244-
repo_url=repo_path,
245353
tracking_ref=GIT_DEFAULT_BRANCH,
246354
)
247355

248356
with pytest.raises(AirflowException, match="Version not_found not found in the repository"):
249357
bundle.initialize()
250358

251-
def test_subdir(self, git_repo):
359+
@mock.patch("airflow.dag_processing.bundles.git.GitHook")
360+
def test_subdir(self, mock_githook, git_repo):
361+
mock_githook.get_conn.return_value = mock.MagicMock()
252362
repo_path, repo = git_repo
363+
mock_githook.return_value.repo_url = repo_path
253364

254365
subdir = "somesubdir"
255366
subdir_path = repo_path / subdir
@@ -264,7 +375,6 @@ def test_subdir(self, git_repo):
264375
bundle = GitDagBundle(
265376
name="test",
266377
refresh_interval=300,
267-
repo_url=repo_path,
268378
tracking_ref=GIT_DEFAULT_BRANCH,
269379
subdir=subdir,
270380
)
@@ -274,44 +384,62 @@ def test_subdir(self, git_repo):
274384
assert str(bundle.path).endswith(subdir)
275385
assert {"some_new_file.py"} == files_in_repo
276386

277-
@mock.patch("airflow.providers.ssh.hooks.ssh.SSHHook")
387+
def test_raises_when_no_repo_url(self):
388+
bundle = GitDagBundle(
389+
name="test",
390+
refresh_interval=300,
391+
git_conn_id=CONN_NO_REPO_URL,
392+
tracking_ref=GIT_DEFAULT_BRANCH,
393+
)
394+
with pytest.raises(
395+
AirflowException, match=f"Connection {CONN_NO_REPO_URL} doesn't have a git_repo_url"
396+
):
397+
bundle.initialize()
398+
399+
@mock.patch("airflow.dag_processing.bundles.git.GitHook")
278400
@mock.patch("airflow.dag_processing.bundles.git.Repo")
279-
def test_with_ssh_conn_id(self, mock_gitRepo, mock_hook):
280-
repo_url = "[email protected]:apache/airflow.git"
281-
conn_id = "ssh_default"
401+
@mock.patch.object(GitDagBundle, "_clone_from")
402+
def test_with_path_as_repo_url(self, mock_clone_from, mock_gitRepo, mock_githook):
282403
bundle = GitDagBundle(
283-
repo_url=repo_url,
284404
name="test",
285405
refresh_interval=300,
286-
ssh_conn_kwargs={"ssh_conn_id": "ssh_default"},
406+
git_conn_id=CONN_ONLY_PATH,
287407
tracking_ref=GIT_DEFAULT_BRANCH,
288408
)
289409
bundle.initialize()
290-
mock_hook.assert_called_once_with(ssh_conn_id=conn_id)
410+
assert mock_clone_from.call_count == 2
411+
assert mock_gitRepo.return_value.git.checkout.call_count == 1
291412

292-
@mock.patch("airflow.providers.ssh.hooks.ssh.SSHHook")
413+
@mock.patch("airflow.dag_processing.bundles.git.GitHook")
293414
@mock.patch("airflow.dag_processing.bundles.git.Repo")
294-
def test_refresh_with_ssh_connection(self, mock_gitRepo, mock_hook):
295-
repo_url = "[email protected]:apache/airflow.git"
415+
def test_refresh_with_git_connection(self, mock_gitRepo, mock_hook):
296416
bundle = GitDagBundle(
297-
repo_url=repo_url,
298417
name="test",
299418
refresh_interval=300,
300-
ssh_conn_kwargs={"ssh_conn_id": "ssh_default"},
419+
git_conn_id="git_default",
301420
tracking_ref=GIT_DEFAULT_BRANCH,
302421
)
303422
bundle.initialize()
304423
bundle.refresh()
305424
# check remotes called twice. one at initialize and one at refresh above
306425
assert mock_gitRepo.return_value.remotes.origin.fetch.call_count == 2
307426

308-
def test_repo_url_starts_with_git_when_using_ssh_conn_id(self):
309-
repo_url = "https://github.com/apache/airflow"
427+
@pytest.mark.parametrize(
428+
"repo_url",
429+
[
430+
pytest.param("https://github.com/apache/airflow", id="https_url"),
431+
pytest.param("airflow@example:apache/airflow.git", id="does_not_start_with_git_at"),
432+
pytest.param("git@example:apache/airflow", id="does_not_end_with_dot_git"),
433+
],
434+
)
435+
@mock.patch("airflow.dag_processing.bundles.git.GitHook")
436+
def test_repo_url_starts_with_git_when_using_ssh_conn_id(self, mock_hook, repo_url, session):
437+
mock_hook.get_conn.return_value = mock.MagicMock()
438+
mock_hook.return_value.repo_url = repo_url
310439
bundle = GitDagBundle(
311-
repo_url=repo_url,
312440
name="test",
313441
refresh_interval=300,
314-
ssh_conn_kwargs={"ssh_conn_id": "ssh_default"},
442+
git_conn_id="git_default",
315443
tracking_ref=GIT_DEFAULT_BRANCH,
316444
)
317445
with pytest.raises(

0 commit comments

Comments
 (0)