14
14
import pretend
15
15
import pytest
16
16
17
- from sqlalchemy .orm .exc import NoResultFound
18
-
19
17
from warehouse .malware import tasks
20
18
from warehouse .malware .models import MalwareCheck , MalwareCheckState , MalwareVerdict
21
19
@@ -34,45 +32,86 @@ def test_success(self, db_request, monkeypatch):
34
32
name = "ExampleHookedCheck" , state = MalwareCheckState .Enabled
35
33
)
36
34
task = pretend .stub ()
37
- tasks .run_check (task , db_request , "ExampleHookedCheck" , file0 .id )
35
+ tasks .run_check (task , db_request , "ExampleHookedCheck" , obj_id = file0 .id )
38
36
39
37
assert db_request .route_url .calls == [
40
38
pretend .call ("packaging.file" , path = file0 .path )
41
39
]
42
40
assert db_request .db .query (MalwareVerdict ).one ()
43
41
44
- def test_disabled_check (self , db_request , monkeypatch ):
42
+ @pytest .mark .parametrize (("manually_triggered" ), [True , False ])
43
+ def test_evaluation_run (self , db_session , monkeypatch , manually_triggered ):
44
+ monkeypatch .setattr (tasks , "checks" , test_checks )
45
+ MalwareCheckFactory .create (
46
+ name = "ExampleScheduledCheck" , state = MalwareCheckState .Evaluation
47
+ )
48
+ ProjectFactory .create ()
49
+ task = pretend .stub ()
50
+
51
+ request = pretend .stub (
52
+ db = db_session ,
53
+ log = pretend .stub (info = pretend .call_recorder (lambda * args , ** kwargs : None )),
54
+ )
55
+
56
+ tasks .run_check (
57
+ task ,
58
+ request ,
59
+ "ExampleScheduledCheck" ,
60
+ manually_triggered = manually_triggered ,
61
+ )
62
+
63
+ if manually_triggered :
64
+ assert db_session .query (MalwareVerdict ).one ()
65
+ else :
66
+ assert request .log .info .calls == [
67
+ pretend .call (
68
+ "ExampleScheduledCheck is in the `evaluation` state and must be \
69
+ manually triggered to run."
70
+ )
71
+ ]
72
+ assert db_session .query (MalwareVerdict ).all () == []
73
+
74
+ def test_disabled_check (self , db_session , monkeypatch ):
45
75
monkeypatch .setattr (tasks , "checks" , test_checks )
46
76
MalwareCheckFactory .create (
47
77
name = "ExampleHookedCheck" , state = MalwareCheckState .Disabled
48
78
)
49
79
task = pretend .stub ()
80
+ request = pretend .stub (
81
+ db = db_session ,
82
+ log = pretend .stub (info = pretend .call_recorder (lambda * args , ** kwargs : None )),
83
+ )
50
84
51
85
file = FileFactory .create ()
52
86
53
- with pytest .raises (NoResultFound ):
54
- tasks .run_check (task , db_request , "ExampleHookedCheck" , file .id )
87
+ tasks .run_check (
88
+ task , request , "ExampleHookedCheck" , obj_id = file .id ,
89
+ )
90
+
91
+ assert request .log .info .calls == [
92
+ pretend .call ("Check ExampleHookedCheck isn't active. Aborting." )
93
+ ]
55
94
56
95
def test_missing_check (self , db_request , monkeypatch ):
57
96
monkeypatch .setattr (tasks , "checks" , test_checks )
58
97
task = pretend .stub ()
59
98
60
- file = FileFactory .create ()
61
-
62
99
with pytest .raises (AttributeError ):
63
- tasks .run_check (task , db_request , "DoesNotExistCheck" , file .id )
100
+ tasks .run_check (
101
+ task , db_request , "DoesNotExistCheck" ,
102
+ )
64
103
65
104
def test_retry (self , db_session , monkeypatch ):
105
+ monkeypatch .setattr (tasks , "checks" , test_checks )
66
106
exc = Exception ("Scan failed" )
67
107
68
108
def scan (self , ** kwargs ):
69
109
raise exc
70
110
71
- monkeypatch .setattr (tasks , "checks" , test_checks )
72
111
monkeypatch .setattr (tasks .checks .ExampleHookedCheck , "scan" , scan )
73
112
74
113
MalwareCheckFactory .create (
75
- name = "ExampleHookedCheck" , state = MalwareCheckState .Evaluation
114
+ name = "ExampleHookedCheck" , state = MalwareCheckState .Enabled
76
115
)
77
116
78
117
task = pretend .stub (
@@ -87,7 +126,7 @@ def scan(self, **kwargs):
87
126
file = FileFactory .create ()
88
127
89
128
with pytest .raises (celery .exceptions .Retry ):
90
- tasks .run_check (task , request , "ExampleHookedCheck" , file .id )
129
+ tasks .run_check (task , request , "ExampleHookedCheck" , obj_id = file .id )
91
130
92
131
assert request .log .error .calls == [
93
132
pretend .call ("Error executing check ExampleHookedCheck: Scan failed" )
@@ -108,9 +147,8 @@ def test_invalid_check_name(self, db_request, monkeypatch):
108
147
)
109
148
def test_run (self , db_session , num_objects , num_runs , monkeypatch ):
110
149
monkeypatch .setattr (tasks , "checks" , test_checks )
111
- files = []
112
150
for i in range (num_objects ):
113
- files . append ( FileFactory .create () )
151
+ FileFactory .create ()
114
152
115
153
MalwareCheckFactory .create (
116
154
name = "ExampleHookedCheck" , state = MalwareCheckState .Enabled
@@ -133,15 +171,14 @@ def test_run(self, db_session, num_objects, num_runs, monkeypatch):
133
171
pretend .call ("Running backfill on %d Files." % num_runs )
134
172
]
135
173
136
- assert enqueue_recorder .delay .calls == [
137
- pretend .call ("ExampleHookedCheck" , files [i ].id ) for i in range (num_runs )
138
- ]
174
+ assert len (enqueue_recorder .delay .calls ) == num_runs
139
175
140
176
141
177
class TestSyncChecks :
142
178
def test_no_updates (self , db_session , monkeypatch ):
143
179
monkeypatch .setattr (tasks , "checks" , test_checks )
144
180
monkeypatch .setattr (tasks .checks .ExampleScheduledCheck , "version" , 2 )
181
+
145
182
MalwareCheckFactory .create (
146
183
name = "ExampleHookedCheck" , state = MalwareCheckState .Disabled
147
184
)
0 commit comments