26
26
27
27
from airflow .dag_processing .bundles .base import BaseDagBundle
28
28
from airflow .dag_processing .bundles .dagfolder import DagsFolderDagBundle
29
- from airflow .dag_processing .bundles .git import GitDagBundle
29
+ from airflow .dag_processing .bundles .git import GitDagBundle , GitHook , SSHHook
30
30
from airflow .dag_processing .bundles .local import LocalDagBundle
31
31
from airflow .exceptions import AirflowException
32
+ from airflow .models import Connection
33
+ from airflow .utils import db
32
34
33
35
from tests_common .test_utils .config import conf_vars
36
+ from tests_common .test_utils .db import clear_db_connections
34
37
35
38
36
39
@pytest .fixture (autouse = True )
@@ -107,28 +110,125 @@ def git_repo(tmp_path_factory):
107
110
return (directory , repo )
108
111
109
112
113
+ AIRFLOW_HTTPS_URL = "https://github.com/apache/airflow.git"
114
+ AIRFLOW_GIT = "[email protected] :apache/airflow.git"
115
+ ACCESS_TOKEN = "my_access_token"
116
+ CONN_DEFAULT = "git_default"
117
+ CONN_HTTPS = "my_git_conn"
118
+ CONN_HTTPS_PASSWORD = "my_git_conn_https_password"
119
+ CONN_ONLY_PATH = "my_git_conn_only_path"
120
+ CONN_NO_REPO_URL = "my_git_conn_no_repo_url"
121
+
122
+
123
+ class TestGitHook :
124
+ @classmethod
125
+ def teardown_class (cls ) -> None :
126
+ clear_db_connections ()
127
+
128
+ @classmethod
129
+ def setup_class (cls ) -> None :
130
+ db .merge_conn (
131
+ Connection (
132
+ conn_id = CONN_DEFAULT ,
133
+ host = "github.com" ,
134
+ conn_type = "git" ,
135
+ extra = {"git_repo_url" : AIRFLOW_GIT },
136
+ )
137
+ )
138
+ db .merge_conn (
139
+ Connection (
140
+ conn_id = CONN_HTTPS ,
141
+ host = "github.com" ,
142
+ conn_type = "git" ,
143
+ extra = {"git_repo_url" : AIRFLOW_HTTPS_URL , "git_access_token" : ACCESS_TOKEN },
144
+ )
145
+ )
146
+ db .merge_conn (
147
+ Connection (
148
+ conn_id = CONN_HTTPS_PASSWORD ,
149
+ host = "github.com" ,
150
+ conn_type = "git" ,
151
+ password = ACCESS_TOKEN ,
152
+ extra = {"git_repo_url" : AIRFLOW_HTTPS_URL },
153
+ )
154
+ )
155
+ db .merge_conn (
156
+ Connection (
157
+ conn_id = CONN_ONLY_PATH ,
158
+ host = "github.com" ,
159
+ conn_type = "git" ,
160
+ extra = {"git_repo_url" : "path/to/repo" },
161
+ )
162
+ )
163
+
164
+ @pytest .mark .parametrize (
165
+ "conn_id, expected_repo_url" ,
166
+ [
167
+ (CONN_DEFAULT , AIRFLOW_GIT ),
168
+ (CONN_HTTPS , f"https://{ ACCESS_TOKEN } @github.com/apache/airflow.git" ),
169
+ (CONN_HTTPS_PASSWORD , f"https://{ ACCESS_TOKEN } @github.com/apache/airflow.git" ),
170
+ (CONN_ONLY_PATH , "path/to/repo" ),
171
+ ],
172
+ )
173
+ def test_correct_repo_urls (self , conn_id , expected_repo_url ):
174
+ hook = GitHook (git_conn_id = conn_id )
175
+ assert hook .repo_url == expected_repo_url
176
+
177
+ @mock .patch .object (SSHHook , "get_conn" )
178
+ def test_connection_made_to_ssh_hook (self , mock_ssh_hook_get_conn ):
179
+ hook = GitHook (git_conn_id = CONN_DEFAULT )
180
+ hook .get_conn ()
181
+ mock_ssh_hook_get_conn .assert_called_once_with ()
182
+
183
+
110
184
class TestGitDagBundle :
185
+ @classmethod
186
+ def teardown_class (cls ) -> None :
187
+ clear_db_connections ()
188
+
189
+ @classmethod
190
+ def setup_class (cls ) -> None :
191
+ db .merge_conn (
192
+ Connection (
193
+ conn_id = "git_default" ,
194
+ host = "github.com" ,
195
+ conn_type = "git" ,
196
+ extra = {
"git_repo_url" :
"[email protected] :apache/airflow.git" },
197
+ )
198
+ )
199
+ db .merge_conn (
200
+ Connection (
201
+ conn_id = CONN_NO_REPO_URL ,
202
+ host = "github.com" ,
203
+ conn_type = "git" ,
204
+ extra = "{}" ,
205
+ )
206
+ )
207
+
111
208
def test_supports_versioning (self ):
112
209
assert GitDagBundle .supports_versioning is True
113
210
114
211
def test_uses_dag_bundle_root_storage_path (self , git_repo ):
115
212
repo_path , repo = git_repo
116
- bundle = GitDagBundle (
117
- name = "test" , refresh_interval = 300 , repo_url = repo_path , tracking_ref = GIT_DEFAULT_BRANCH
118
- )
213
+ bundle = GitDagBundle (name = "test" , refresh_interval = 300 , tracking_ref = GIT_DEFAULT_BRANCH )
119
214
assert str (bundle ._dag_bundle_root_storage_path ) in str (bundle .path )
120
215
121
- def test_get_current_version (self , git_repo ):
216
+ @mock .patch ("airflow.dag_processing.bundles.git.GitHook" )
217
+ def test_get_current_version (self , mock_githook , git_repo ):
218
+ mock_githook .get_conn .return_value = mock .MagicMock ()
122
219
repo_path , repo = git_repo
123
- bundle = GitDagBundle (
124
- name = "test" , refresh_interval = 300 , repo_url = repo_path , tracking_ref = GIT_DEFAULT_BRANCH
125
- )
220
+ mock_githook . return_value . repo_url = repo_path
221
+ bundle = GitDagBundle ( name = "test" , refresh_interval = 300 , tracking_ref = GIT_DEFAULT_BRANCH )
222
+
126
223
bundle .initialize ()
127
224
128
225
assert bundle .get_current_version () == repo .head .commit .hexsha
129
226
130
- def test_get_specific_version (self , git_repo ):
227
+ @mock .patch ("airflow.dag_processing.bundles.git.GitHook" )
228
+ def test_get_specific_version (self , mock_githook , git_repo ):
229
+ mock_githook .get_conn .return_value = mock .MagicMock ()
131
230
repo_path , repo = git_repo
231
+ mock_githook .return_value .repo_url = repo_path
132
232
starting_commit = repo .head .commit
133
233
134
234
# Add new file to the repo
@@ -142,7 +242,6 @@ def test_get_specific_version(self, git_repo):
142
242
name = "test" ,
143
243
refresh_interval = 300 ,
144
244
version = starting_commit .hexsha ,
145
- repo_url = repo_path ,
146
245
tracking_ref = GIT_DEFAULT_BRANCH ,
147
246
)
148
247
bundle .initialize ()
@@ -152,8 +251,11 @@ def test_get_specific_version(self, git_repo):
152
251
files_in_repo = {f .name for f in bundle .path .iterdir () if f .is_file ()}
153
252
assert {"test_dag.py" } == files_in_repo
154
253
155
- def test_get_tag_version (self , git_repo ):
254
+ @mock .patch ("airflow.dag_processing.bundles.git.GitHook" )
255
+ def test_get_tag_version (self , mock_githook , git_repo ):
256
+ mock_githook .get_conn .return_value = mock .MagicMock ()
156
257
repo_path , repo = git_repo
258
+ mock_githook .return_value .repo_url = repo_path
157
259
starting_commit = repo .head .commit
158
260
159
261
# add tag
@@ -171,7 +273,6 @@ def test_get_tag_version(self, git_repo):
171
273
name = "test" ,
172
274
refresh_interval = 300 ,
173
275
version = "test" ,
174
- repo_url = repo_path ,
175
276
tracking_ref = GIT_DEFAULT_BRANCH ,
176
277
)
177
278
bundle .initialize ()
@@ -180,8 +281,11 @@ def test_get_tag_version(self, git_repo):
180
281
files_in_repo = {f .name for f in bundle .path .iterdir () if f .is_file ()}
181
282
assert {"test_dag.py" } == files_in_repo
182
283
183
- def test_get_latest (self , git_repo ):
284
+ @mock .patch ("airflow.dag_processing.bundles.git.GitHook" )
285
+ def test_get_latest (self , mock_githook , git_repo ):
286
+ mock_githook .get_conn .return_value = mock .MagicMock ()
184
287
repo_path , repo = git_repo
288
+ mock_githook .return_value .repo_url = repo_path
185
289
starting_commit = repo .head .commit
186
290
187
291
file_path = repo_path / "new_test.py"
@@ -190,23 +294,22 @@ def test_get_latest(self, git_repo):
190
294
repo .index .add ([file_path ])
191
295
repo .index .commit ("Another commit" )
192
296
193
- bundle = GitDagBundle (
194
- name = "test" , refresh_interval = 300 , repo_url = repo_path , tracking_ref = GIT_DEFAULT_BRANCH
195
- )
297
+ bundle = GitDagBundle (name = "test" , refresh_interval = 300 , tracking_ref = GIT_DEFAULT_BRANCH )
196
298
bundle .initialize ()
197
299
198
300
assert bundle .get_current_version () != starting_commit .hexsha
199
301
200
302
files_in_repo = {f .name for f in bundle .path .iterdir () if f .is_file ()}
201
303
assert {"test_dag.py" , "new_test.py" } == files_in_repo
202
304
203
- def test_refresh (self , git_repo ):
305
+ @mock .patch ("airflow.dag_processing.bundles.git.GitHook" )
306
+ def test_refresh (self , mock_githook , git_repo ):
307
+ mock_githook .get_conn .return_value = mock .MagicMock ()
204
308
repo_path , repo = git_repo
309
+ mock_githook .return_value .repo_url = repo_path
205
310
starting_commit = repo .head .commit
206
311
207
- bundle = GitDagBundle (
208
- name = "test" , refresh_interval = 300 , repo_url = repo_path , tracking_ref = GIT_DEFAULT_BRANCH
209
- )
312
+ bundle = GitDagBundle (name = "test" , refresh_interval = 300 , tracking_ref = GIT_DEFAULT_BRANCH )
210
313
bundle .initialize ()
211
314
212
315
assert bundle .get_current_version () == starting_commit .hexsha
@@ -227,29 +330,37 @@ def test_refresh(self, git_repo):
227
330
files_in_repo = {f .name for f in bundle .path .iterdir () if f .is_file ()}
228
331
assert {"test_dag.py" , "new_test.py" } == files_in_repo
229
332
230
- def test_head (self , git_repo ):
333
+ @mock .patch ("airflow.dag_processing.bundles.git.GitHook" )
334
+ def test_head (self , mock_githook , git_repo ):
335
+ mock_githook .get_conn .return_value = mock .MagicMock ()
231
336
repo_path , repo = git_repo
337
+ mock_githook .return_value .repo_url = repo_path
232
338
233
339
repo .create_head ("test" )
234
- bundle = GitDagBundle (name = "test" , refresh_interval = 300 , repo_url = repo_path , tracking_ref = "test" )
340
+ bundle = GitDagBundle (name = "test" , refresh_interval = 300 , tracking_ref = "test" )
235
341
bundle .initialize ()
236
342
assert bundle .repo .head .ref .name == "test"
237
343
238
- def test_version_not_found (self , git_repo ):
344
+ @mock .patch ("airflow.dag_processing.bundles.git.GitHook" )
345
+ def test_version_not_found (self , mock_githook , git_repo ):
346
+ mock_githook .get_conn .return_value = mock .MagicMock ()
239
347
repo_path , repo = git_repo
348
+ mock_githook .return_value .repo_url = repo_path
240
349
bundle = GitDagBundle (
241
350
name = "test" ,
242
351
refresh_interval = 300 ,
243
352
version = "not_found" ,
244
- repo_url = repo_path ,
245
353
tracking_ref = GIT_DEFAULT_BRANCH ,
246
354
)
247
355
248
356
with pytest .raises (AirflowException , match = "Version not_found not found in the repository" ):
249
357
bundle .initialize ()
250
358
251
- def test_subdir (self , git_repo ):
359
+ @mock .patch ("airflow.dag_processing.bundles.git.GitHook" )
360
+ def test_subdir (self , mock_githook , git_repo ):
361
+ mock_githook .get_conn .return_value = mock .MagicMock ()
252
362
repo_path , repo = git_repo
363
+ mock_githook .return_value .repo_url = repo_path
253
364
254
365
subdir = "somesubdir"
255
366
subdir_path = repo_path / subdir
@@ -264,7 +375,6 @@ def test_subdir(self, git_repo):
264
375
bundle = GitDagBundle (
265
376
name = "test" ,
266
377
refresh_interval = 300 ,
267
- repo_url = repo_path ,
268
378
tracking_ref = GIT_DEFAULT_BRANCH ,
269
379
subdir = subdir ,
270
380
)
@@ -274,44 +384,62 @@ def test_subdir(self, git_repo):
274
384
assert str (bundle .path ).endswith (subdir )
275
385
assert {"some_new_file.py" } == files_in_repo
276
386
277
- @mock .patch ("airflow.providers.ssh.hooks.ssh.SSHHook" )
387
+ def test_raises_when_no_repo_url (self ):
388
+ bundle = GitDagBundle (
389
+ name = "test" ,
390
+ refresh_interval = 300 ,
391
+ git_conn_id = CONN_NO_REPO_URL ,
392
+ tracking_ref = GIT_DEFAULT_BRANCH ,
393
+ )
394
+ with pytest .raises (
395
+ AirflowException , match = f"Connection { CONN_NO_REPO_URL } doesn't have a git_repo_url"
396
+ ):
397
+ bundle .initialize ()
398
+
399
+ @mock .patch ("airflow.dag_processing.bundles.git.GitHook" )
278
400
@mock .patch ("airflow.dag_processing.bundles.git.Repo" )
279
- def test_with_ssh_conn_id (self , mock_gitRepo , mock_hook ):
280
- repo_url = "[email protected] :apache/airflow.git"
281
- conn_id = "ssh_default"
401
+ @mock .patch .object (GitDagBundle , "_clone_from" )
402
+ def test_with_path_as_repo_url (self , mock_clone_from , mock_gitRepo , mock_githook ):
282
403
bundle = GitDagBundle (
283
- repo_url = repo_url ,
284
404
name = "test" ,
285
405
refresh_interval = 300 ,
286
- ssh_conn_kwargs = { "ssh_conn_id" : "ssh_default" } ,
406
+ git_conn_id = CONN_ONLY_PATH ,
287
407
tracking_ref = GIT_DEFAULT_BRANCH ,
288
408
)
289
409
bundle .initialize ()
290
- mock_hook .assert_called_once_with (ssh_conn_id = conn_id )
410
+ assert mock_clone_from .call_count == 2
411
+ assert mock_gitRepo .return_value .git .checkout .call_count == 1
291
412
292
- @mock .patch ("airflow.providers.ssh.hooks.ssh.SSHHook " )
413
+ @mock .patch ("airflow.dag_processing.bundles.git.GitHook " )
293
414
@mock .patch ("airflow.dag_processing.bundles.git.Repo" )
294
- def test_refresh_with_ssh_connection (self , mock_gitRepo , mock_hook ):
295
- repo_url = "[email protected] :apache/airflow.git"
415
+ def test_refresh_with_git_connection (self , mock_gitRepo , mock_hook ):
296
416
bundle = GitDagBundle (
297
- repo_url = repo_url ,
298
417
name = "test" ,
299
418
refresh_interval = 300 ,
300
- ssh_conn_kwargs = { "ssh_conn_id" : "ssh_default" } ,
419
+ git_conn_id = "git_default" ,
301
420
tracking_ref = GIT_DEFAULT_BRANCH ,
302
421
)
303
422
bundle .initialize ()
304
423
bundle .refresh ()
305
424
# check remotes called twice. one at initialize and one at refresh above
306
425
assert mock_gitRepo .return_value .remotes .origin .fetch .call_count == 2
307
426
308
- def test_repo_url_starts_with_git_when_using_ssh_conn_id (self ):
309
- repo_url = "https://github.com/apache/airflow"
427
+ @pytest .mark .parametrize (
428
+ "repo_url" ,
429
+ [
430
+ pytest .param ("https://github.com/apache/airflow" , id = "https_url" ),
431
+ pytest .param ("airflow@example:apache/airflow.git" , id = "does_not_start_with_git_at" ),
432
+ pytest .param ("git@example:apache/airflow" , id = "does_not_end_with_dot_git" ),
433
+ ],
434
+ )
435
+ @mock .patch ("airflow.dag_processing.bundles.git.GitHook" )
436
+ def test_repo_url_starts_with_git_when_using_ssh_conn_id (self , mock_hook , repo_url , session ):
437
+ mock_hook .get_conn .return_value = mock .MagicMock ()
438
+ mock_hook .return_value .repo_url = repo_url
310
439
bundle = GitDagBundle (
311
- repo_url = repo_url ,
312
440
name = "test" ,
313
441
refresh_interval = 300 ,
314
- ssh_conn_kwargs = { "ssh_conn_id" : "ssh_default" } ,
442
+ git_conn_id = "git_default" ,
315
443
tracking_ref = GIT_DEFAULT_BRANCH ,
316
444
)
317
445
with pytest .raises (
0 commit comments