1010from django .forms import BooleanField , CharField , IntegerField
1111from django .utils import timezone
1212from django_filters .rest_framework import DjangoFilterBackend
13+ from drf_spectacular .utils import extend_schema
1314from rest_framework import exceptions as api_exceptions
1415from rest_framework import filters , serializers , status , viewsets
1516from rest_framework .decorators import action
2425from ami .base .pagination import LimitOffsetPaginationWithPermissions
2526from ami .base .permissions import IsActiveStaffOrReadOnly
2627from ami .base .serializers import FilterParamsSerializer , SingleParamSerializer
27- from ami .utils .requests import get_active_classification_threshold
28+ from ami .utils .requests import get_active_classification_threshold , get_active_project , project_id_doc_param
2829from ami .utils .storages import ConnectionTestResult
2930
3031from ..models import (
@@ -137,7 +138,6 @@ class DeploymentViewSet(DefaultViewSet):
137138 """
138139
139140 queryset = Deployment .objects .select_related ("project" , "device" , "research_site" )
140- filterset_fields = ["project" ]
141141 ordering_fields = [
142142 "created_at" ,
143143 "updated_at" ,
@@ -160,7 +160,9 @@ def get_serializer_class(self):
160160
161161 def get_queryset (self ) -> QuerySet :
162162 qs = super ().get_queryset ()
163-
163+ project = get_active_project (self .request )
164+ if project :
165+ qs = qs .filter (project = project )
164166 num_example_captures = 10
165167 if self .action == "retrieve" :
166168 qs = qs .prefetch_related (
@@ -204,6 +206,10 @@ def sync(self, _request, pk=None) -> Response:
204206 else :
205207 raise api_exceptions .ValidationError (detail = "Deployment must have a data source to sync captures from" )
206208
209+ @extend_schema (parameters = [project_id_doc_param ])
210+ def list (self , request , * args , ** kwargs ):
211+ return super ().list (request , * args , ** kwargs )
212+
207213
208214class EventViewSet (DefaultViewSet ):
209215 """
@@ -212,7 +218,7 @@ class EventViewSet(DefaultViewSet):
212218
213219 queryset = Event .objects .all ()
214220 serializer_class = EventSerializer
215- filterset_fields = ["deployment" , "project" ]
221+ filterset_fields = ["deployment" ]
216222 ordering_fields = [
217223 "created_at" ,
218224 "updated_at" ,
@@ -237,6 +243,9 @@ def get_serializer_class(self):
237243
238244 def get_queryset (self ) -> QuerySet :
239245 qs : QuerySet = super ().get_queryset ()
246+ project = get_active_project (self .request )
247+ if project :
248+ qs = qs .filter (project = project )
240249 qs = qs .filter (deployment__isnull = False )
241250 qs = qs .annotate (
242251 duration = models .F ("end" ) - models .F ("start" ),
@@ -363,6 +372,10 @@ def timeline(self, request, pk=None):
363372 )
364373 return Response (serializer .data )
365374
375+ @extend_schema (parameters = [project_id_doc_param ])
376+ def list (self , request , * args , ** kwargs ):
377+ return super ().list (request , * args , ** kwargs )
378+
366379
367380class SourceImageViewSet (DefaultViewSet ):
368381 """
@@ -549,7 +562,7 @@ class SourceImageCollectionViewSet(DefaultViewSet):
549562 )
550563 serializer_class = SourceImageCollectionSerializer
551564
552- filterset_fields = ["project" , " method" ]
565+ filterset_fields = ["method" ]
553566 ordering_fields = [
554567 "created_at" ,
555568 "updated_at" ,
@@ -562,11 +575,14 @@ class SourceImageCollectionViewSet(DefaultViewSet):
562575
563576 def get_queryset (self ) -> QuerySet :
564577 classification_threshold = get_active_classification_threshold (self .request )
565- queryset = (
566- super ()
567- .get_queryset ()
568- .with_occurrences_count (classification_threshold = classification_threshold ) # type: ignore
569- .with_taxa_count (classification_threshold = classification_threshold )
578+ query_set : QuerySet = super ().get_queryset ()
579+ project = get_active_project (self .request )
580+ if project :
581+ query_set = query_set .filter (project = project )
582+ queryset = query_set .with_occurrences_count (
583+ classification_threshold = classification_threshold
584+ ).with_taxa_count ( # type: ignore
585+ classification_threshold = classification_threshold
570586 )
571587 return queryset
572588
@@ -647,6 +663,10 @@ def remove(self, request, pk=None):
647663 }
648664 )
649665
666+ @extend_schema (parameters = [project_id_doc_param ])
667+ def list (self , request , * args , ** kwargs ):
668+ return super ().list (request , * args , ** kwargs )
669+
650670
651671class SourceImageUploadViewSet (DefaultViewSet ):
652672 """
@@ -903,7 +923,6 @@ class OccurrenceViewSet(DefaultViewSet):
903923 filterset_fields = [
904924 "event" ,
905925 "deployment" ,
906- "project" ,
907926 "determination__rank" ,
908927 "detections__source_image" ,
909928 ]
@@ -933,7 +952,10 @@ def get_serializer_class(self):
933952 return OccurrenceSerializer
934953
935954 def get_queryset (self ) -> QuerySet :
955+ project = get_active_project (self .request )
936956 qs = super ().get_queryset ()
957+ if project :
958+ qs = qs .filter (project = project )
937959 qs = qs .select_related (
938960 "determination" ,
939961 "deployment" ,
@@ -961,6 +983,10 @@ def get_queryset(self) -> QuerySet:
961983
962984 return qs
963985
986+ @extend_schema (parameters = [project_id_doc_param ])
987+ def list (self , request , * args , ** kwargs ):
988+ return super ().list (request , * args , ** kwargs )
989+
964990
965991class TaxonViewSet (DefaultViewSet ):
966992 """
@@ -1042,23 +1068,22 @@ def filter_taxa_by_observed(self, queryset: QuerySet) -> tuple[QuerySet, bool]:
10421068 """
10431069
10441070 occurrence_id = self .request .query_params .get ("occurrence" )
1045- project_id = self . request . query_params . get ( "project" ) or self .request . query_params . get ( "occurrences__project" )
1071+ project = get_active_project ( self .request )
10461072 deployment_id = self .request .query_params .get ("deployment" ) or self .request .query_params .get (
10471073 "occurrences__deployment"
10481074 )
10491075 event_id = self .request .query_params .get ("event" ) or self .request .query_params .get ("occurrences__event" )
10501076 collection_id = self .request .query_params .get ("collection" )
10511077
1052- filter_active = any ([occurrence_id , project_id , deployment_id , event_id , collection_id ])
1078+ filter_active = any ([occurrence_id , project , deployment_id , event_id , collection_id ])
10531079
1054- if not project_id :
1080+ if not project :
10551081 # Raise a 400 if no project is specified
10561082 raise api_exceptions .ValidationError (detail = "A project must be specified" )
10571083
10581084 queryset = super ().get_queryset ()
10591085 try :
1060- if project_id :
1061- project = Project .objects .get (id = project_id )
1086+ if project :
10621087 queryset = queryset .filter (occurrences__project = project )
10631088 if occurrence_id :
10641089 occurrence = Occurrence .objects .get (id = occurrence_id )
@@ -1180,6 +1205,10 @@ def get_queryset(self) -> QuerySet:
11801205
11811206 return qs
11821207
1208+ @extend_schema (parameters = [project_id_doc_param ])
1209+ def list (self , request , * args , ** kwargs ):
1210+ return super ().list (request , * args , ** kwargs )
1211+
11831212 # def retrieve(self, request: Request, *args, **kwargs) -> Response:
11841213 # """
11851214 # Override the serializer to include the recursive occurrences count
@@ -1207,16 +1236,15 @@ class ClassificationViewSet(DefaultViewSet):
12071236
12081237class SummaryView (GenericAPIView ):
12091238 permission_classes = [IsActiveStaffOrReadOnly ]
1210- filterset_fields = ["project" ]
12111239
1240+ @extend_schema (parameters = [project_id_doc_param ])
12121241 def get (self , request ):
12131242 """
12141243 Return counts of all models.
12151244 """
1216- project_id = request . query_params . get ( "project" )
1245+ project = get_active_project ( request )
12171246 confidence_threshold = get_active_classification_threshold (request )
1218- if project_id :
1219- project = Project .objects .get (id = project_id )
1247+ if project :
12201248 data = {
12211249 "projects_count" : Project .objects .count (), # @TODO filter by current user, here and everywhere!
12221250 "deployments_count" : Deployment .objects .filter (project = project ).count (),
@@ -1358,13 +1386,24 @@ class SiteViewSet(DefaultViewSet):
13581386
13591387 queryset = Site .objects .all ()
13601388 serializer_class = SiteSerializer
1361- filterset_fields = ["project" , " deployments" ]
1389+ filterset_fields = ["deployments" ]
13621390 ordering_fields = [
13631391 "created_at" ,
13641392 "updated_at" ,
13651393 "name" ,
13661394 ]
13671395
1396+ def get_queryset (self ) -> QuerySet :
1397+ query_set : QuerySet = super ().get_queryset ()
1398+ project = get_active_project (self .request )
1399+ if project :
1400+ query_set = query_set .filter (project = project )
1401+ return query_set
1402+
1403+ @extend_schema (parameters = [project_id_doc_param ])
1404+ def list (self , request , * args , ** kwargs ):
1405+ return super ().list (request , * args , ** kwargs )
1406+
13681407
13691408class DeviceViewSet (DefaultViewSet ):
13701409 """
@@ -1373,13 +1412,24 @@ class DeviceViewSet(DefaultViewSet):
13731412
13741413 queryset = Device .objects .all ()
13751414 serializer_class = DeviceSerializer
1376- filterset_fields = ["project" , " deployments" ]
1415+ filterset_fields = ["deployments" ]
13771416 ordering_fields = [
13781417 "created_at" ,
13791418 "updated_at" ,
13801419 "name" ,
13811420 ]
13821421
1422+ def get_queryset (self ) -> QuerySet :
1423+ query_set : QuerySet = super ().get_queryset ()
1424+ project = get_active_project (self .request )
1425+ if project :
1426+ query_set = query_set .filter (project = project )
1427+ return query_set
1428+
1429+ @extend_schema (parameters = [project_id_doc_param ])
1430+ def list (self , request , * args , ** kwargs ):
1431+ return super ().list (request , * args , ** kwargs )
1432+
13831433
13841434class StorageSourceConnectionTestSerializer (serializers .Serializer ):
13851435 subdir = serializers .CharField (required = False , allow_null = True )
@@ -1393,13 +1443,23 @@ class StorageSourceViewSet(DefaultViewSet):
13931443
13941444 queryset = S3StorageSource .objects .all ()
13951445 serializer_class = StorageSourceSerializer
1396- filterset_fields = ["project" , " deployments" ]
1446+ filterset_fields = ["deployments" ]
13971447 ordering_fields = [
13981448 "created_at" ,
13991449 "updated_at" ,
14001450 "name" ,
14011451 ]
14021452
1453+ def get_queryset (self ) -> QuerySet :
1454+ query_set : QuerySet = super ().get_queryset ()
1455+ project = get_active_project (self .request )
1456+ query_set = query_set .filter (project = project )
1457+ return query_set
1458+
1459+ @extend_schema (parameters = [project_id_doc_param ])
1460+ def list (self , request , * args , ** kwargs ):
1461+ return super ().list (request , * args , ** kwargs )
1462+
14031463 @action (detail = True , methods = ["post" ], name = "test" , serializer_class = StorageSourceConnectionTestSerializer )
14041464 def test (self , request : Request , pk = None ) -> Response :
14051465 """
0 commit comments