Skip to content

Commit 9181362

Browse files
Renamed project filter query parameter to project_id for Project Summary entities (#668)
* Renamed "project" filter query parameter to "project_id" for Project Summary entities * Fixed project association in create_ml_pipeline function * Added tests for filtering by "project_id" on Project Summary entities * Refactored open api project_id docs params, get_project logic and moved it to requests.py * Applied changes to the frontend * Updated all entities in the project page to filter by project_id * Updated tests to use project_id parameter instead of project * Removed occurrences__project query param
1 parent 284cb14 commit 9181362

8 files changed

Lines changed: 261 additions & 32 deletions

File tree

ami/jobs/views.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,13 @@
33
from django.db.models.query import QuerySet
44
from django.forms import IntegerField
55
from django.utils import timezone
6+
from drf_spectacular.utils import extend_schema
67
from rest_framework.decorators import action
78
from rest_framework.response import Response
89

910
from ami.main.api.views import DefaultViewSet
1011
from ami.utils.fields import url_boolean_param
12+
from ami.utils.requests import get_active_project, project_id_doc_param
1113

1214
from .models import Job, JobState, MLJob
1315
from .serializers import JobListSerializer, JobSerializer
@@ -35,7 +37,6 @@ class JobViewSet(DefaultViewSet):
3537
"""
3638

3739
queryset = Job.objects.select_related(
38-
"project",
3940
"deployment",
4041
"pipeline",
4142
"source_image_collection",
@@ -128,7 +129,9 @@ def perform_create(self, serializer):
128129

129130
def get_queryset(self) -> QuerySet:
130131
jobs = super().get_queryset()
131-
132+
project = get_active_project(self.request)
133+
if project:
134+
jobs = jobs.filter(project=project)
132135
cutoff_hours = IntegerField(required=False, min_value=0).clean(
133136
self.request.query_params.get("cutoff_hours", Job.FAILED_CUTOFF_HOURS)
134137
)
@@ -138,3 +141,7 @@ def get_queryset(self) -> QuerySet:
138141
status=JobState.failed_states(),
139142
updated_at__lt=cutoff_datetime,
140143
)
144+
145+
@extend_schema(parameters=[project_id_doc_param])
146+
def list(self, request, *args, **kwargs):
147+
return super().list(request, *args, **kwargs)

ami/main/api/serializers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -503,7 +503,7 @@ def get_occurrence_images(self, obj):
503503

504504
# request = self.context.get("request")
505505
# project_id = request.query_params.get("project") if request else None
506-
project_id = self.context["request"].query_params["project"]
506+
project_id = self.context["request"].query_params["project_id"]
507507
classification_threshold = get_active_classification_threshold(self.context["request"])
508508

509509
return obj.occurrence_images(

ami/main/api/views.py

Lines changed: 83 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from django.forms import BooleanField, CharField, IntegerField
1111
from django.utils import timezone
1212
from django_filters.rest_framework import DjangoFilterBackend
13+
from drf_spectacular.utils import extend_schema
1314
from rest_framework import exceptions as api_exceptions
1415
from rest_framework import filters, serializers, status, viewsets
1516
from rest_framework.decorators import action
@@ -24,7 +25,7 @@
2425
from ami.base.pagination import LimitOffsetPaginationWithPermissions
2526
from ami.base.permissions import IsActiveStaffOrReadOnly
2627
from ami.base.serializers import FilterParamsSerializer, SingleParamSerializer
27-
from ami.utils.requests import get_active_classification_threshold
28+
from ami.utils.requests import get_active_classification_threshold, get_active_project, project_id_doc_param
2829
from ami.utils.storages import ConnectionTestResult
2930

3031
from ..models import (
@@ -137,7 +138,6 @@ class DeploymentViewSet(DefaultViewSet):
137138
"""
138139

139140
queryset = Deployment.objects.select_related("project", "device", "research_site")
140-
filterset_fields = ["project"]
141141
ordering_fields = [
142142
"created_at",
143143
"updated_at",
@@ -160,7 +160,9 @@ def get_serializer_class(self):
160160

161161
def get_queryset(self) -> QuerySet:
162162
qs = super().get_queryset()
163-
163+
project = get_active_project(self.request)
164+
if project:
165+
qs = qs.filter(project=project)
164166
num_example_captures = 10
165167
if self.action == "retrieve":
166168
qs = qs.prefetch_related(
@@ -204,6 +206,10 @@ def sync(self, _request, pk=None) -> Response:
204206
else:
205207
raise api_exceptions.ValidationError(detail="Deployment must have a data source to sync captures from")
206208

209+
@extend_schema(parameters=[project_id_doc_param])
210+
def list(self, request, *args, **kwargs):
211+
return super().list(request, *args, **kwargs)
212+
207213

208214
class EventViewSet(DefaultViewSet):
209215
"""
@@ -212,7 +218,7 @@ class EventViewSet(DefaultViewSet):
212218

213219
queryset = Event.objects.all()
214220
serializer_class = EventSerializer
215-
filterset_fields = ["deployment", "project"]
221+
filterset_fields = ["deployment"]
216222
ordering_fields = [
217223
"created_at",
218224
"updated_at",
@@ -237,6 +243,9 @@ def get_serializer_class(self):
237243

238244
def get_queryset(self) -> QuerySet:
239245
qs: QuerySet = super().get_queryset()
246+
project = get_active_project(self.request)
247+
if project:
248+
qs = qs.filter(project=project)
240249
qs = qs.filter(deployment__isnull=False)
241250
qs = qs.annotate(
242251
duration=models.F("end") - models.F("start"),
@@ -363,6 +372,10 @@ def timeline(self, request, pk=None):
363372
)
364373
return Response(serializer.data)
365374

375+
@extend_schema(parameters=[project_id_doc_param])
376+
def list(self, request, *args, **kwargs):
377+
return super().list(request, *args, **kwargs)
378+
366379

367380
class SourceImageViewSet(DefaultViewSet):
368381
"""
@@ -549,7 +562,7 @@ class SourceImageCollectionViewSet(DefaultViewSet):
549562
)
550563
serializer_class = SourceImageCollectionSerializer
551564

552-
filterset_fields = ["project", "method"]
565+
filterset_fields = ["method"]
553566
ordering_fields = [
554567
"created_at",
555568
"updated_at",
@@ -562,11 +575,14 @@ class SourceImageCollectionViewSet(DefaultViewSet):
562575

563576
def get_queryset(self) -> QuerySet:
564577
classification_threshold = get_active_classification_threshold(self.request)
565-
queryset = (
566-
super()
567-
.get_queryset()
568-
.with_occurrences_count(classification_threshold=classification_threshold) # type: ignore
569-
.with_taxa_count(classification_threshold=classification_threshold)
578+
query_set: QuerySet = super().get_queryset()
579+
project = get_active_project(self.request)
580+
if project:
581+
query_set = query_set.filter(project=project)
582+
queryset = query_set.with_occurrences_count(
583+
classification_threshold=classification_threshold
584+
).with_taxa_count( # type: ignore
585+
classification_threshold=classification_threshold
570586
)
571587
return queryset
572588

@@ -647,6 +663,10 @@ def remove(self, request, pk=None):
647663
}
648664
)
649665

666+
@extend_schema(parameters=[project_id_doc_param])
667+
def list(self, request, *args, **kwargs):
668+
return super().list(request, *args, **kwargs)
669+
650670

651671
class SourceImageUploadViewSet(DefaultViewSet):
652672
"""
@@ -903,7 +923,6 @@ class OccurrenceViewSet(DefaultViewSet):
903923
filterset_fields = [
904924
"event",
905925
"deployment",
906-
"project",
907926
"determination__rank",
908927
"detections__source_image",
909928
]
@@ -933,7 +952,10 @@ def get_serializer_class(self):
933952
return OccurrenceSerializer
934953

935954
def get_queryset(self) -> QuerySet:
955+
project = get_active_project(self.request)
936956
qs = super().get_queryset()
957+
if project:
958+
qs = qs.filter(project=project)
937959
qs = qs.select_related(
938960
"determination",
939961
"deployment",
@@ -961,6 +983,10 @@ def get_queryset(self) -> QuerySet:
961983

962984
return qs
963985

986+
@extend_schema(parameters=[project_id_doc_param])
987+
def list(self, request, *args, **kwargs):
988+
return super().list(request, *args, **kwargs)
989+
964990

965991
class TaxonViewSet(DefaultViewSet):
966992
"""
@@ -1042,23 +1068,22 @@ def filter_taxa_by_observed(self, queryset: QuerySet) -> tuple[QuerySet, bool]:
10421068
"""
10431069

10441070
occurrence_id = self.request.query_params.get("occurrence")
1045-
project_id = self.request.query_params.get("project") or self.request.query_params.get("occurrences__project")
1071+
project = get_active_project(self.request)
10461072
deployment_id = self.request.query_params.get("deployment") or self.request.query_params.get(
10471073
"occurrences__deployment"
10481074
)
10491075
event_id = self.request.query_params.get("event") or self.request.query_params.get("occurrences__event")
10501076
collection_id = self.request.query_params.get("collection")
10511077

1052-
filter_active = any([occurrence_id, project_id, deployment_id, event_id, collection_id])
1078+
filter_active = any([occurrence_id, project, deployment_id, event_id, collection_id])
10531079

1054-
if not project_id:
1080+
if not project:
10551081
# Raise a 400 if no project is specified
10561082
raise api_exceptions.ValidationError(detail="A project must be specified")
10571083

10581084
queryset = super().get_queryset()
10591085
try:
1060-
if project_id:
1061-
project = Project.objects.get(id=project_id)
1086+
if project:
10621087
queryset = queryset.filter(occurrences__project=project)
10631088
if occurrence_id:
10641089
occurrence = Occurrence.objects.get(id=occurrence_id)
@@ -1180,6 +1205,10 @@ def get_queryset(self) -> QuerySet:
11801205

11811206
return qs
11821207

1208+
@extend_schema(parameters=[project_id_doc_param])
1209+
def list(self, request, *args, **kwargs):
1210+
return super().list(request, *args, **kwargs)
1211+
11831212
# def retrieve(self, request: Request, *args, **kwargs) -> Response:
11841213
# """
11851214
# Override the serializer to include the recursive occurrences count
@@ -1207,16 +1236,15 @@ class ClassificationViewSet(DefaultViewSet):
12071236

12081237
class SummaryView(GenericAPIView):
12091238
permission_classes = [IsActiveStaffOrReadOnly]
1210-
filterset_fields = ["project"]
12111239

1240+
@extend_schema(parameters=[project_id_doc_param])
12121241
def get(self, request):
12131242
"""
12141243
Return counts of all models.
12151244
"""
1216-
project_id = request.query_params.get("project")
1245+
project = get_active_project(request)
12171246
confidence_threshold = get_active_classification_threshold(request)
1218-
if project_id:
1219-
project = Project.objects.get(id=project_id)
1247+
if project:
12201248
data = {
12211249
"projects_count": Project.objects.count(), # @TODO filter by current user, here and everywhere!
12221250
"deployments_count": Deployment.objects.filter(project=project).count(),
@@ -1358,13 +1386,24 @@ class SiteViewSet(DefaultViewSet):
13581386

13591387
queryset = Site.objects.all()
13601388
serializer_class = SiteSerializer
1361-
filterset_fields = ["project", "deployments"]
1389+
filterset_fields = ["deployments"]
13621390
ordering_fields = [
13631391
"created_at",
13641392
"updated_at",
13651393
"name",
13661394
]
13671395

1396+
def get_queryset(self) -> QuerySet:
1397+
query_set: QuerySet = super().get_queryset()
1398+
project = get_active_project(self.request)
1399+
if project:
1400+
query_set = query_set.filter(project=project)
1401+
return query_set
1402+
1403+
@extend_schema(parameters=[project_id_doc_param])
1404+
def list(self, request, *args, **kwargs):
1405+
return super().list(request, *args, **kwargs)
1406+
13681407

13691408
class DeviceViewSet(DefaultViewSet):
13701409
"""
@@ -1373,13 +1412,24 @@ class DeviceViewSet(DefaultViewSet):
13731412

13741413
queryset = Device.objects.all()
13751414
serializer_class = DeviceSerializer
1376-
filterset_fields = ["project", "deployments"]
1415+
filterset_fields = ["deployments"]
13771416
ordering_fields = [
13781417
"created_at",
13791418
"updated_at",
13801419
"name",
13811420
]
13821421

1422+
def get_queryset(self) -> QuerySet:
1423+
query_set: QuerySet = super().get_queryset()
1424+
project = get_active_project(self.request)
1425+
if project:
1426+
query_set = query_set.filter(project=project)
1427+
return query_set
1428+
1429+
@extend_schema(parameters=[project_id_doc_param])
1430+
def list(self, request, *args, **kwargs):
1431+
return super().list(request, *args, **kwargs)
1432+
13831433

13841434
class StorageSourceConnectionTestSerializer(serializers.Serializer):
13851435
subdir = serializers.CharField(required=False, allow_null=True)
@@ -1393,13 +1443,23 @@ class StorageSourceViewSet(DefaultViewSet):
13931443

13941444
queryset = S3StorageSource.objects.all()
13951445
serializer_class = StorageSourceSerializer
1396-
filterset_fields = ["project", "deployments"]
1446+
filterset_fields = ["deployments"]
13971447
ordering_fields = [
13981448
"created_at",
13991449
"updated_at",
14001450
"name",
14011451
]
14021452

1453+
def get_queryset(self) -> QuerySet:
1454+
query_set: QuerySet = super().get_queryset()
1455+
project = get_active_project(self.request)
1456+
query_set = query_set.filter(project=project)
1457+
return query_set
1458+
1459+
@extend_schema(parameters=[project_id_doc_param])
1460+
def list(self, request, *args, **kwargs):
1461+
return super().list(request, *args, **kwargs)
1462+
14031463
@action(detail=True, methods=["post"], name="test", serializer_class=StorageSourceConnectionTestSerializer)
14041464
def test(self, request: Request, pk=None) -> Response:
14051465
"""

0 commit comments

Comments
 (0)