Skip to content

Commit 046dbc1

Browse files
xmunozewdurbin
authored andcommitted
Add wipe-out functionality (#7202)
* Add wipe-out functionality Related: #7133 * Call list explicitly
1 parent d4dbeed commit 046dbc1

File tree

5 files changed

+114
-10
lines changed

5 files changed

+114
-10
lines changed

tests/common/db/malware.py

+18-3
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,13 @@
2020
MalwareCheckObjectType,
2121
MalwareCheckState,
2222
MalwareCheckType,
23+
MalwareVerdict,
24+
VerdictClassification,
25+
VerdictConfidence,
2326
)
2427

2528
from .base import WarehouseFactory
29+
from .packaging import FileFactory
2630

2731

2832
class MalwareCheckFactory(WarehouseFactory):
@@ -33,9 +37,20 @@ class Meta:
3337
version = 1
3438
short_description = factory.fuzzy.FuzzyText(length=80)
3539
long_description = factory.fuzzy.FuzzyText(length=300)
36-
check_type = factory.fuzzy.FuzzyChoice([e for e in MalwareCheckType])
37-
hooked_object = factory.fuzzy.FuzzyChoice([e for e in MalwareCheckObjectType])
38-
state = factory.fuzzy.FuzzyChoice([e for e in MalwareCheckState])
40+
check_type = factory.fuzzy.FuzzyChoice(list(MalwareCheckType))
41+
hooked_object = factory.fuzzy.FuzzyChoice(list(MalwareCheckObjectType))
42+
state = factory.fuzzy.FuzzyChoice(list(MalwareCheckState))
3943
created = factory.fuzzy.FuzzyNaiveDateTime(
4044
datetime.datetime.utcnow() - datetime.timedelta(days=7)
4145
)
46+
47+
48+
class MalwareVerdictFactory(WarehouseFactory):
49+
class Meta:
50+
model = MalwareVerdict
51+
52+
check = factory.SubFactory(MalwareCheckFactory)
53+
release_file = factory.SubFactory(FileFactory)
54+
classification = factory.fuzzy.FuzzyChoice(list(VerdictClassification))
55+
confidence = factory.fuzzy.FuzzyChoice(list(VerdictConfidence))
56+
message = factory.fuzzy.FuzzyText(length=80)

tests/unit/admin/views/test_checks.py

+18-4
Original file line numberDiff line numberDiff line change
@@ -67,27 +67,41 @@ def test_get_check_not_found(self, db_request):
6767

6868

6969
class TestChangeCheckState:
70-
def test_change_to_enabled(self, db_request):
70+
@pytest.mark.parametrize(
71+
("final_state"), [MalwareCheckState.disabled, MalwareCheckState.wiped_out]
72+
)
73+
def test_change_to_valid_state(self, db_request, final_state):
7174
check = MalwareCheckFactory.create(
7275
name="MyCheck", state=MalwareCheckState.disabled
7376
)
7477

75-
db_request.POST = {"id": check.id, "check_state": "enabled"}
78+
db_request.POST = {"id": check.id, "check_state": final_state.value}
7679
db_request.matchdict["check_name"] = check.name
7780

7881
db_request.session = pretend.stub(
7982
flash=pretend.call_recorder(lambda *a, **kw: None)
8083
)
84+
wipe_out_recorder = pretend.stub(
85+
delay=pretend.call_recorder(lambda *a, **kw: None)
86+
)
87+
db_request.task = pretend.call_recorder(lambda *a, **kw: wipe_out_recorder)
88+
8189
db_request.route_path = pretend.call_recorder(
8290
lambda *a, **kw: "/admin/checks/MyCheck/change_state"
8391
)
8492

8593
views.change_check_state(db_request)
8694

8795
assert db_request.session.flash.calls == [
88-
pretend.call("Changed 'MyCheck' check to 'enabled'!", queue="success")
96+
pretend.call(
97+
"Changed 'MyCheck' check to '%s'!" % final_state.value, queue="success"
98+
)
8999
]
90-
assert check.state == MalwareCheckState.enabled
100+
101+
assert check.state == final_state
102+
103+
if final_state == MalwareCheckState.wiped_out:
104+
assert wipe_out_recorder.delay.calls == [pretend.call("MyCheck")]
91105

92106
def test_change_to_invalid_state(self, db_request):
93107
check = MalwareCheckFactory.create(name="MyCheck")

tests/unit/malware/test_tasks.py

+54-2
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@
1919
import warehouse.malware.checks as checks
2020

2121
from warehouse.malware.models import MalwareCheck, MalwareCheckState, MalwareVerdict
22-
from warehouse.malware.tasks import run_check, sync_checks
22+
from warehouse.malware.tasks import remove_verdicts, run_check, sync_checks
2323

24-
from ...common.db.malware import MalwareCheckFactory
24+
from ...common.db.malware import MalwareCheckFactory, MalwareVerdictFactory
2525
from ...common.db.packaging import FileFactory, ProjectFactory, ReleaseFactory
2626

2727

@@ -255,3 +255,55 @@ def test_only_wiped_out(self, db_session):
255255
from codebase."
256256
),
257257
]
258+
259+
260+
class TestRemoveVerdicts:
261+
def test_no_verdicts(self, db_session):
262+
check = MalwareCheckFactory.create()
263+
264+
request = pretend.stub(
265+
db=db_session,
266+
log=pretend.stub(info=pretend.call_recorder(lambda *args, **kwargs: None),),
267+
)
268+
task = pretend.stub()
269+
remove_verdicts(task, request, check.name)
270+
271+
assert request.log.info.calls == [
272+
pretend.call(
273+
"Removing 0 malware verdicts associated with %s version 1." % check.name
274+
),
275+
]
276+
277+
@pytest.mark.parametrize(("check_with_verdicts"), [True, False])
278+
def test_many_verdicts(self, db_session, check_with_verdicts):
279+
check0 = MalwareCheckFactory.create()
280+
check1 = MalwareCheckFactory.create()
281+
project = ProjectFactory.create(name="foo")
282+
release = ReleaseFactory.create(project=project)
283+
file0 = FileFactory.create(release=release, filename="foo.bar")
284+
num_verdicts = 10
285+
286+
for i in range(num_verdicts):
287+
MalwareVerdictFactory.create(check=check1, release_file=file0)
288+
289+
request = pretend.stub(
290+
db=db_session,
291+
log=pretend.stub(info=pretend.call_recorder(lambda *args, **kwargs: None),),
292+
)
293+
294+
task = pretend.stub()
295+
296+
if check_with_verdicts:
297+
wiped_out_check = check1
298+
else:
299+
wiped_out_check = check0
300+
num_verdicts = 0
301+
302+
remove_verdicts(task, request, wiped_out_check.name)
303+
304+
assert request.log.info.calls == [
305+
pretend.call(
306+
"Removing %d malware verdicts associated with %s version 1."
307+
% (num_verdicts, wiped_out_check.name)
308+
),
309+
]

warehouse/admin/views/checks.py

+3
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from sqlalchemy.orm.exc import NoResultFound
1616

1717
from warehouse.malware.models import MalwareCheck, MalwareCheckState
18+
from warehouse.malware.tasks import remove_verdicts
1819

1920

2021
@view_config(
@@ -80,6 +81,8 @@ def change_check_state(request):
8081
except (AttributeError, KeyError):
8182
request.session.flash("Invalid check state provided.", queue="error")
8283
else:
84+
if check.state == MalwareCheckState.wiped_out:
85+
request.task(remove_verdicts).delay(check.name)
8386
request.session.flash(
8487
f"Changed {check.name!r} check to {check.state.value!r}!", queue="success"
8588
)

warehouse/malware/tasks.py

+21-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
import warehouse.malware.checks as checks
1616

17-
from warehouse.malware.models import MalwareCheck, MalwareCheckState
17+
from warehouse.malware.models import MalwareCheck, MalwareCheckState, MalwareVerdict
1818
from warehouse.malware.utils import get_check_fields
1919
from warehouse.tasks import task
2020

@@ -86,3 +86,23 @@ def sync_checks(task, request):
8686
request.log.info("Adding new %s to the database." % check_name)
8787
fields = get_check_fields(check)
8888
request.db.add(MalwareCheck(**fields))
89+
90+
91+
@task(bind=True, ignore_result=True, acks_late=True)
92+
def remove_verdicts(task, request, check_name):
93+
check_ids = (
94+
request.db.query(MalwareCheck.id, MalwareCheck.version)
95+
.filter(MalwareCheck.name == check_name)
96+
.all()
97+
)
98+
99+
for check_id, check_version in check_ids:
100+
query = request.db.query(MalwareVerdict).filter(
101+
MalwareVerdict.check_id == check_id
102+
)
103+
num_verdicts = query.count()
104+
request.log.info(
105+
"Removing %d malware verdicts associated with %s version %d."
106+
% (num_verdicts, check_name, check_version)
107+
)
108+
query.delete(synchronize_session=False)

0 commit comments

Comments
 (0)