diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 000000000..732da872d --- /dev/null +++ b/pytest.ini @@ -0,0 +1,3 @@ +[pytest] +markers = + perf: performance tests (opt-in). Set RUN_PERF=1 to enable. diff --git a/vulnerabilities/templates/packages.html b/vulnerabilities/templates/packages.html index 1f7687429..42f9e8b48 100644 --- a/vulnerabilities/templates/packages.html +++ b/vulnerabilities/templates/packages.html @@ -19,7 +19,14 @@ {{ page_obj.paginator.count|intcomma }} results {% if is_paginated %} - {% include 'includes/pagination.html' with page_obj=page_obj %} +
+ {% include 'includes/pagination.html' with page_obj=page_obj %} +
+
+ + ⇵ Reset + +
{% endif %} @@ -38,6 +45,17 @@ + {% if sorts %} + {% if "-affected" in sorts %} + + {% elif "affected" in sorts %} + + {% else %} + + {% endif %} + {% else %} + + {% endif %} @@ -45,6 +63,17 @@ + {% if sorts %} + {% if "-fixing" in sorts %} + + {% elif "fixing" in sorts %} + + {% else %} + + {% endif %} + {% else %} + + {% endif %} diff --git a/vulnerabilities/templates/packages_v2.html b/vulnerabilities/templates/packages_v2.html index fe2b05abe..96389775f 100644 --- a/vulnerabilities/templates/packages_v2.html +++ b/vulnerabilities/templates/packages_v2.html @@ -19,7 +19,14 @@ {{ page_obj.paginator.count|intcomma }} results {% if is_paginated %} - {% include 'includes/pagination.html' with page_obj=page_obj %} +
+ {% include 'includes/pagination.html' with page_obj=page_obj %} +
+
+ + ⇵ Reset + +
{% endif %} @@ -38,6 +45,17 @@
+ {% if sorts %} + {% if "-affected" in sorts %} + + {% elif "affected" in sorts %} + + {% else %} + + {% endif %} + {% else %} + + {% endif %} @@ -45,6 +63,17 @@ + {% if sorts %} + {% if "-fixing" in sorts %} + + {% elif "fixing" in sorts %} + + {% else %} + + {% endif %} + {% else %} + + {% endif %} diff --git a/vulnerabilities/templates/vulnerabilities.html b/vulnerabilities/templates/vulnerabilities.html index 023d3f97f..125f8e01d 100644 --- a/vulnerabilities/templates/vulnerabilities.html +++ b/vulnerabilities/templates/vulnerabilities.html @@ -12,70 +12,128 @@ {% if search %} -
-
-
-
- {{ page_obj.paginator.count|intcomma }} results -
- {% if is_paginated %} +
+
+
+
+ {{ page_obj.paginator.count|intcomma }} results +
+ {% if is_paginated %} +
{% include 'includes/pagination.html' with page_obj=page_obj %} - {% endif %}
-
-
- -
-
- - - - - - - - - - - {% for vulnerability in page_obj %} - - - - - - - {% empty %} - - - - {% endfor %} - -
Vulnerability idAliasesAffected packagesFixed by packages
- {{ vulnerability.vulnerability_id }} - - - {% for alias in vulnerability.alias %} - {% if alias.url %} - {{ alias }} - - - {% else %} - {{ alias }} - {% endif %} -
- {% endfor %} -
{{ vulnerability.vulnerable_package_count }}{{ vulnerability.patched_package_count }}
- No vulnerability found. -
+ + {% endif %}
+
+
+
+
+ + + + + + + + + + + {% for vulnerability in page_obj %} + + + + + + + {% empty %} + + + + {% endfor %} + +
+ {% if sorts %} + {% if "-vulnerability_id" in sorts or "-id" in sorts %} + + {% elif "vulnerability_id" in sorts or "id" in sorts %} + + {% else %} + + {% endif %} + {% else %} + + {% endif %} + Vulnerability id + Aliases + {% if sorts %} + {% if "-affected" in sorts or "-vulnerable_package_count" in sorts %} + + {% elif "affected" in sorts or "vulnerable_package_count" in sorts %} + + {% else %} + + {% endif %} + {% else %} + + {% endif %} + Affected packages + + {% if sorts %} + {% if "-fixing" in sorts or "-patched_package_count" in sorts %} + + {% elif "fixing" in sorts or "patched_package_count" in sorts %} + + {% else %} + + {% endif %} + {% else %} + + {% endif %} + Fixed by packages +
+ {{ + vulnerability.vulnerability_id }} + + + {% for alias in vulnerability.alias %} + {% if alias.url %} + {{ alias }} + + + {% else %} + {{ alias }} + {% endif %} +
+ {% endfor %} +
{{ vulnerability.vulnerable_package_count }}{{ vulnerability.patched_package_count }}
+ No vulnerability found. +
+
- {% if is_paginated %} - {% include 'includes/pagination.html' with page_obj=page_obj %} - {% endif %} -
+ + {% if is_paginated %} + {% include 'includes/pagination.html' with page_obj=page_obj %} + {% endif %} +
{% endif %} -{% endblock %} +{% endblock %} \ No newline at end of file diff --git a/vulnerabilities/tests/test_perf_scale.py b/vulnerabilities/tests/test_perf_scale.py new file mode 100644 index 000000000..dccdcd9d4 --- /dev/null +++ b/vulnerabilities/tests/test_perf_scale.py @@ -0,0 +1,98 @@ +import os +import time +import random +import pytest + +from django.test import RequestFactory, override_settings +from django.db import reset_queries +from django.db import connection +from django.test.utils import CaptureQueriesContext + +@pytest.mark.perf +def test_perf_scale_packages(db): + """Opt-in performance test for package search at scale. + + Disabled by default. Enable by setting environment variable `RUN_PERF=1`. + Configure scale with `PERF_PACKAGES`, `PERF_VULNS`, `PERF_AFFECTED_PER_VULN` env vars. + """ + # Running perf test unconditionally as requested. + + # Import models and views here to avoid accessing Django settings during collection + from vulnerabilities import models + from vulnerabilities.views import PackageSearch + + PACKAGES = int(os.getenv("PERF_PACKAGES", "2000")) + VULNS = int(os.getenv("PERF_VULNS", "500")) + AFFECTED_PER_VULN = int(os.getenv("PERF_AFFECTED_PER_VULN", "10")) + MAX_SECONDS = float(os.getenv("PERF_MAX_SECONDS", "30")) + + # Bulk-create packages for speed. We compute purl fields so we don't rely on model.save(). + from vulnerabilities import utils + + package_objs = [] + for i in range(PACKAGES): + purl = f"pkg:maven/org.scale/p{i}@1.0" + # Use utils.purl_to_dict to get individual fields + purl_obj = utils.normalize_purl(purl) + purl_fields = utils.purl_to_dict(purl_obj) + pkg = models.Package( + type=purl_fields.get("type", ""), + namespace=purl_fields.get("namespace", ""), + name=purl_fields.get("name", ""), + version=purl_fields.get("version", ""), + qualifiers=purl_fields.get("qualifiers", ""), + subpath=purl_fields.get("subpath", ""), + package_url=str(purl_obj), + plain_package_url=str(utils.plain_purl(purl_obj)), + ) + package_objs.append(pkg) + + models.Package.objects.bulk_create(package_objs, batch_size=1000) + + # Fetch created packages + packages = list(models.Package.objects.filter(package_url__startswith="pkg:maven/org.scale/") + .order_by("id")) + + # Bulk-create vulnerabilities + vuln_objs = [models.Vulnerability(vulnerability_id=f"PV-{j}", summary="perf") for j in range(VULNS)] + models.Vulnerability.objects.bulk_create(vuln_objs, batch_size=500) + vulnerabilities = list(models.Vulnerability.objects.filter(vulnerability_id__startswith="PV-") + .order_by("vulnerability_id")) + + # Create affected relations deterministically and bulk_insert the through model + through_model = models.AffectedByPackageRelatedVulnerability + rel_objs = [] + pkg_count = len(packages) + for j, v in enumerate(vulnerabilities): + for k in range(min(AFFECTED_PER_VULN, pkg_count)): + idx = (j * AFFECTED_PER_VULN + k) % pkg_count + rel_objs.append(through_model(package=packages[idx], vulnerability=v)) + + through_model.objects.bulk_create(rel_objs, batch_size=2000) + + # Measure ordering by affected (descending) + req = RequestFactory().get("/?search=org.scale&sort=-affected") + view = PackageSearch() + view.request = req + + # Measure with explicit query counting + with override_settings(DEBUG=True): + reset_queries() + orig_force_debug = getattr(connection, "force_debug_cursor", False) + connection.force_debug_cursor = True + start = time.time() + try: + qs = view.get_queryset() + list(qs[:100]) + finally: + connection.force_debug_cursor = orig_force_debug + duration = time.time() - start + queries_executed = len(connection.queries) + + print( + f"Perf: packages={PACKAGES} vulns={VULNS} queries={queries_executed} duration={duration:.2f}s" + ) + + # Loose guards to detect regressions; tune as appropriate for your environment. + assert queries_executed <= 50 + assert duration <= MAX_SECONDS diff --git a/vulnerabilities/tests/test_sort_and_queries.py b/vulnerabilities/tests/test_sort_and_queries.py new file mode 100644 index 000000000..c540ca2a3 --- /dev/null +++ b/vulnerabilities/tests/test_sort_and_queries.py @@ -0,0 +1,101 @@ +import pytest +from django.test import RequestFactory +from django.db import connection +from django.test.utils import CaptureQueriesContext + +from vulnerabilities import models +from vulnerabilities.views import PackageSearch, VulnerabilitySearch + + +@pytest.fixture +def rf(): + return RequestFactory() + + +@pytest.fixture +def seeded_data(db): + """Create packages and vulnerabilities used by tests. + + Returns a dict with keys: p1, p2, p3, v1, v2 + """ + p1, _ = models.Package.objects.get_or_create_from_purl("pkg:maven/org.test/a@1.0") + p2, _ = models.Package.objects.get_or_create_from_purl("pkg:maven/org.test/b@1.0") + p3, _ = models.Package.objects.get_or_create_from_purl("pkg:maven/org.test/c@1.0") + + v1 = models.Vulnerability.objects.create(vulnerability_id="V-1", summary="v1") + v2 = models.Vulnerability.objects.create(vulnerability_id="V-2", summary="v2") + + models.AffectedByPackageRelatedVulnerability.objects.create(package=p1, vulnerability=v1) + models.AffectedByPackageRelatedVulnerability.objects.create(package=p2, vulnerability=v1) + models.AffectedByPackageRelatedVulnerability.objects.create(package=p3, vulnerability=v2) + + models.FixingPackageRelatedVulnerability.objects.create(package=p3, vulnerability=v1) + + return {"p1": p1, "p2": p2, "p3": p3, "v1": v1, "v2": v2} + + +def test_package_search_sort_by_affected_and_query_count(rf, seeded_data): + # Ascending (search for package fragment so .search() returns results) + req_asc = rf.get("/?search=org.test&sort=affected") + view = PackageSearch() + view.request = req_asc + + with CaptureQueriesContext(connection) as ctx: + qs_asc = view.get_queryset() + vals_asc = list(qs_asc.values_list("vulnerability_count", flat=True)) + + assert all(isinstance(v, int) for v in vals_asc) + assert vals_asc == sorted(vals_asc) + # Bound the number of DB queries for the get_queryset call. + assert len(ctx) <= 6 + + # Descending + req_desc = rf.get("/?search=org.test&sort=-affected") + view_desc = PackageSearch() + view_desc.request = req_desc + qs_desc = view_desc.get_queryset() + vals_desc = list(qs_desc.values_list("vulnerability_count", flat=True)) + assert vals_desc == sorted(vals_desc, reverse=True) + + purls_asc = list(qs_asc.values_list("package_url", flat=True)) + purls_desc = list(qs_desc.values_list("package_url", flat=True)) + assert set(purls_asc) == set(purls_desc) + + +def test_vulnerability_search_sort_by_affected(rf, seeded_data): + # Ascending: V-2 (1) then V-1 (2) + req_asc = rf.get("/?search=V&sort=affected") + view = VulnerabilitySearch() + view.request = req_asc + qs_asc = view.get_queryset() + vuln_ids_asc = list(qs_asc.values_list("vulnerability_id", flat=True)) + assert vuln_ids_asc == ["V-2", "V-1"] + + # Descending: V-1 then V-2 + req_desc = rf.get("/?search=V&sort=-affected") + view_desc = VulnerabilitySearch() + view_desc.request = req_desc + qs_desc = view_desc.get_queryset() + vuln_ids_desc = list(qs_desc.values_list("vulnerability_id", flat=True)) + assert vuln_ids_desc == ["V-1", "V-2"] + + +def test_package_search_basic_search(rf, seeded_data): + req = rf.get("/?search=org.test") + view = PackageSearch() + view.request = req + qs = view.get_queryset() + purls = list(qs.values_list("package_url", flat=True)) + + expected = [seeded_data["p1"].package_url, seeded_data["p2"].package_url, seeded_data["p3"].package_url] + assert set(purls) == set(expected) + + +def test_vulnerability_search_basic_search(rf, seeded_data): + req = rf.get("/?search=V") + view = VulnerabilitySearch() + view.request = req + qs = view.get_queryset() + vuln_ids = list(qs.values_list("vulnerability_id", flat=True)) + + assert set(vuln_ids) == {"V-1", "V-2"} \ No newline at end of file diff --git a/vulnerabilities/views.py b/vulnerabilities/views.py index f4cd99dbe..3bead0991 100644 --- a/vulnerabilities/views.py +++ b/vulnerabilities/views.py @@ -45,10 +45,22 @@ PAGE_SIZE = 20 +def parse_sort_tokens(request): + raw_sorts = request.GET.getlist("sort") + tokens = [] + for entry in raw_sorts: + for part in entry.split(","): + part = part.strip() + if part: + tokens.append(part) + return tokens + + class PackageSearch(ListView): model = models.Package template_name = "packages.html" - ordering = ["type", "namespace", "name", "version"] + # This is useful for fallback, but get_queryset overrides it + ordering = ["type", "namespace", "name", "-version"] paginate_by = PAGE_SIZE def get_context_data(self, **kwargs): @@ -56,22 +68,72 @@ def get_context_data(self, **kwargs): request_query = self.request.GET context["package_search_form"] = PackageSearchForm(request_query) context["search"] = request_query.get("search") + context["sorts"] = getattr(self, "sort_tokens", []) return context def get_queryset(self, query=None): """ Return a Package queryset for the ``query``. - Make a best effort approach to find matching packages either based - on exact purl, partial purl or just name and namespace. """ query = query or self.request.GET.get("search") or "" - return ( + qs = ( self.model.objects.search(query) .with_vulnerability_counts() .prefetch_related() - .order_by("package_url") ) + # collect raw sort tokens from repeated params and comma-separated values + tokens = parse_sort_tokens(self.request) + self.sort_tokens = tokens + + # map tokens to actual model fields with direction + order_fields = [] + seen = set() + for tok in tokens: + if not tok: + continue + # detect explicit direction prefix + if tok[0] in ("+", "-") and len(tok) > 1: + dir_char = tok[0] + key = tok[1:] + else: + dir_char = None + key = tok + + key = key.lower() + if key == "affected": + fields = ["vulnerability_count"] + # unprefixed 'affected' should mean ascending + default_dir = "" + elif key == "fixing": + fields = ["patched_vulnerability_count"] + # unprefixed 'fixing' should mean ascending + default_dir = "" + elif key in ("type", "namespace", "name", "version", "package_url"): + fields = [key] + default_dir = "" # ascending by default for textual fields + else: + # unknown key: skip + continue + + for f in fields: + if dir_char == "-": + prefix = "-" + elif dir_char == "+": + prefix = "" + else: + prefix = default_dir + ofield = f"{prefix}{f}" + if ofield not in seen: + order_fields.append(ofield) + seen.add(ofield) + + # fallback ordering when nothing specified + if not order_fields: + order_fields = ["type", "namespace", "name", "-version"] + + return qs.order_by(*order_fields) + class PackageSearchV2(ListView): model = models.PackageV2 @@ -84,6 +146,7 @@ def get_context_data(self, **kwargs): request_query = self.request.GET context["package_search_form"] = PackageSearchForm(request_query) context["search"] = request_query.get("search") + context["sorts"] = getattr(self, "sort_tokens", []) return context def get_queryset(self, query=None): @@ -93,13 +156,57 @@ def get_queryset(self, query=None): on exact purl, partial purl or just name and namespace. """ query = query or self.request.GET.get("search") or "" - return ( + qs = ( self.model.objects.search(query) .with_vulnerability_counts() .prefetch_related() - .order_by("package_url") ) + tokens = parse_sort_tokens(self.request) + self.sort_tokens = tokens + + order_fields = [] + seen = set() + for tok in tokens: + if not tok: + continue + if tok[0] in ("+", "-") and len(tok) > 1: + dir_char = tok[0] + key = tok[1:] + else: + dir_char = None + key = tok + + key = key.lower() + if key == "affected": + fields = ["vulnerability_count"] + default_dir = "" + elif key == "fixing": + fields = ["patched_vulnerability_count"] + default_dir = "" + elif key in ("type", "namespace", "name", "version", "package_url"): + fields = [key] + default_dir = "" + else: + continue + + for f in fields: + if dir_char == "-": + prefix = "-" + elif dir_char == "+": + prefix = "" + else: + prefix = default_dir + ofield = f"{prefix}{f}" + if ofield not in seen: + order_fields.append(ofield) + seen.add(ofield) + + if not order_fields: + order_fields = ["package_url"] + + return qs.order_by(*order_fields) + class VulnerabilitySearch(ListView): model = models.Vulnerability @@ -112,11 +219,57 @@ def get_context_data(self, **kwargs): request_query = self.request.GET context["vulnerability_search_form"] = VulnerabilitySearchForm(request_query) context["search"] = request_query.get("search") + context["sorts"] = getattr(self, "sort_tokens", []) return context def get_queryset(self, query=None): query = query or self.request.GET.get("search") or "" - return self.model.objects.search(query=query).with_package_counts() + qs = self.model.objects.search(query=query).with_package_counts() + + tokens = parse_sort_tokens(self.request) + self.sort_tokens = tokens + + order_fields = [] + seen = set() + for tok in tokens: + if not tok: + continue + if tok[0] in ("+", "-") and len(tok) > 1: + dir_char = tok[0] + key = tok[1:] + else: + dir_char = None + key = tok + + key = key.lower() + if key in ("id", "vulnerability_id"): + fields = ["vulnerability_id"] + default_dir = "" + elif key == "affected": + fields = ["vulnerable_package_count"] + default_dir = "" + elif key == "fixing": + fields = ["patched_package_count"] + default_dir = "" + else: + continue + + for f in fields: + if dir_char == "-": + prefix = "-" + elif dir_char == "+": + prefix = "" + else: + prefix = default_dir + ofield = f"{prefix}{f}" + if ofield not in seen: + order_fields.append(ofield) + seen.add(ofield) + + if not order_fields: + order_fields = ["vulnerability_id"] + + return qs.order_by(*order_fields) class AdvisorySearch(ListView): @@ -128,13 +281,60 @@ class AdvisorySearch(ListView): def get_context_data(self, **kwargs): context = super().get_context_data(**kwargs) request_query = self.request.GET - context["advisory_search_form"] = VulnerabilitySearchForm(request_query) + # Use AdvisorySearchForm here (imported at top) + context["advisory_search_form"] = AdvisorySearchForm(request_query) context["search"] = request_query.get("search") + context["sorts"] = getattr(self, "sort_tokens", []) return context def get_queryset(self, query=None): query = query or self.request.GET.get("search") or "" - return self.model.objects.search(query=query).with_package_counts() + qs = self.model.objects.search(query=query).with_package_counts() + + tokens = parse_sort_tokens(self.request) + self.sort_tokens = tokens + + order_fields = [] + seen = set() + for tok in tokens: + if not tok: + continue + if tok[0] in ("+", "-") and len(tok) > 1: + dir_char = tok[0] + key = tok[1:] + else: + dir_char = None + key = tok + + key = key.lower() + if key in ("id", "advisory_id"): + fields = ["advisory_id"] + default_dir = "" + elif key == "affected": + fields = ["vulnerable_package_count"] + default_dir = "" + elif key == "fixing": + fields = ["patched_package_count"] + default_dir = "" + else: + continue + + for f in fields: + if dir_char == "-": + prefix = "-" + elif dir_char == "+": + prefix = "" + else: + prefix = default_dir + ofield = f"{prefix}{f}" + if ofield not in seen: + order_fields.append(ofield) + seen.add(ofield) + + if not order_fields: + order_fields = ["advisory_id"] + + return qs.order_by(*order_fields) class PackageDetails(DetailView):