Skip to content
Merged
11 changes: 9 additions & 2 deletions ami/jobs/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
from django.db.models.query import QuerySet
from django.forms import IntegerField
from django.utils import timezone
from drf_spectacular.utils import extend_schema
from rest_framework.decorators import action
from rest_framework.response import Response

from ami.main.api.views import DefaultViewSet
from ami.utils.fields import url_boolean_param
from ami.utils.requests import get_active_project, project_id_doc_param

from .models import Job, JobState, MLJob
from .serializers import JobListSerializer, JobSerializer
Expand Down Expand Up @@ -35,7 +37,6 @@ class JobViewSet(DefaultViewSet):
"""

queryset = Job.objects.select_related(
"project",
"deployment",
"pipeline",
"source_image_collection",
Expand Down Expand Up @@ -128,7 +129,9 @@ def perform_create(self, serializer):

def get_queryset(self) -> QuerySet:
jobs = super().get_queryset()

project = get_active_project(self.request)
if project:
jobs = jobs.filter(project=project)
cutoff_hours = IntegerField(required=False, min_value=0).clean(
self.request.query_params.get("cutoff_hours", Job.FAILED_CUTOFF_HOURS)
)
Expand All @@ -138,3 +141,7 @@ def get_queryset(self) -> QuerySet:
status=JobState.failed_states(),
updated_at__lt=cutoff_datetime,
)

@extend_schema(parameters=[project_id_doc_param])
def list(self, request, *args, **kwargs):
return super().list(request, *args, **kwargs)
2 changes: 1 addition & 1 deletion ami/main/api/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,7 +503,7 @@ def get_occurrence_images(self, obj):

# request = self.context.get("request")
# project_id = request.query_params.get("project") if request else None
project_id = self.context["request"].query_params["project"]
project_id = self.context["request"].query_params["project_id"]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for finding this. When we optimize the Taxa view, I think we will be able to remove this function.

classification_threshold = get_active_classification_threshold(self.context["request"])

return obj.occurrence_images(
Expand Down
44 changes: 36 additions & 8 deletions ami/main/api/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,6 @@ class DeploymentViewSet(DefaultViewSet):
"""

queryset = Deployment.objects.select_related("project", "device", "research_site")
filterset_fields = ["project"]
ordering_fields = [
"created_at",
"updated_at",
Expand All @@ -161,7 +160,9 @@ def get_serializer_class(self):

def get_queryset(self) -> QuerySet:
qs = super().get_queryset()

project = get_active_project(self.request)
if project:
qs = qs.filter(project=project)
num_example_captures = 10
if self.action == "retrieve":
qs = qs.prefetch_related(
Expand Down Expand Up @@ -205,6 +206,10 @@ def sync(self, _request, pk=None) -> Response:
else:
raise api_exceptions.ValidationError(detail="Deployment must have a data source to sync captures from")

@extend_schema(parameters=[project_id_doc_param])
def list(self, request, *args, **kwargs):
return super().list(request, *args, **kwargs)


class EventViewSet(DefaultViewSet):
"""
Expand All @@ -213,7 +218,7 @@ class EventViewSet(DefaultViewSet):

queryset = Event.objects.all()
serializer_class = EventSerializer
filterset_fields = ["deployment", "project"]
filterset_fields = ["deployment"]
ordering_fields = [
"created_at",
"updated_at",
Expand All @@ -238,6 +243,9 @@ def get_serializer_class(self):

def get_queryset(self) -> QuerySet:
qs: QuerySet = super().get_queryset()
project = get_active_project(self.request)
if project:
qs = qs.filter(project=project)
qs = qs.filter(deployment__isnull=False)
qs = qs.annotate(
duration=models.F("end") - models.F("start"),
Expand Down Expand Up @@ -364,6 +372,10 @@ def timeline(self, request, pk=None):
)
return Response(serializer.data)

@extend_schema(parameters=[project_id_doc_param])
def list(self, request, *args, **kwargs):
return super().list(request, *args, **kwargs)


class SourceImageViewSet(DefaultViewSet):
"""
Expand Down Expand Up @@ -565,7 +577,8 @@ def get_queryset(self) -> QuerySet:
classification_threshold = get_active_classification_threshold(self.request)
query_set: QuerySet = super().get_queryset()
project = get_active_project(self.request)
query_set = query_set.filter(project=project)
if project:
query_set = query_set.filter(project=project)
queryset = query_set.with_occurrences_count(
classification_threshold=classification_threshold
).with_taxa_count( # type: ignore
Expand Down Expand Up @@ -910,7 +923,6 @@ class OccurrenceViewSet(DefaultViewSet):
filterset_fields = [
"event",
"deployment",
"project",
"determination__rank",
"detections__source_image",
]
Expand Down Expand Up @@ -940,7 +952,10 @@ def get_serializer_class(self):
return OccurrenceSerializer

def get_queryset(self) -> QuerySet:
project = get_active_project(self.request)
qs = super().get_queryset()
if project:
qs = qs.filter(project=project)
qs = qs.select_related(
"determination",
"deployment",
Expand Down Expand Up @@ -968,6 +983,10 @@ def get_queryset(self) -> QuerySet:

return qs

@extend_schema(parameters=[project_id_doc_param])
def list(self, request, *args, **kwargs):
return super().list(request, *args, **kwargs)


class TaxonViewSet(DefaultViewSet):
"""
Expand Down Expand Up @@ -1049,7 +1068,10 @@ def filter_taxa_by_observed(self, queryset: QuerySet) -> tuple[QuerySet, bool]:
"""

occurrence_id = self.request.query_params.get("occurrence")
project_id = self.request.query_params.get("project") or self.request.query_params.get("occurrences__project")
project_id = self.request.query_params.get("project") or self.request.query_params.get(
"occurrences__project"
) # @TBD
project_id = self.request.query_params.get("project_id")
deployment_id = self.request.query_params.get("deployment") or self.request.query_params.get(
"occurrences__deployment"
)
Expand Down Expand Up @@ -1187,6 +1209,10 @@ def get_queryset(self) -> QuerySet:

return qs

@extend_schema(parameters=[project_id_doc_param])
def list(self, request, *args, **kwargs):
return super().list(request, *args, **kwargs)

# def retrieve(self, request: Request, *args, **kwargs) -> Response:
# """
# Override the serializer to include the recursive occurrences count
Expand Down Expand Up @@ -1374,7 +1400,8 @@ class SiteViewSet(DefaultViewSet):
def get_queryset(self) -> QuerySet:
query_set: QuerySet = super().get_queryset()
project = get_active_project(self.request)
query_set = query_set.filter(project=project)
if project:
query_set = query_set.filter(project=project)
return query_set

@extend_schema(parameters=[project_id_doc_param])
Expand All @@ -1399,7 +1426,8 @@ class DeviceViewSet(DefaultViewSet):
def get_queryset(self) -> QuerySet:
query_set: QuerySet = super().get_queryset()
project = get_active_project(self.request)
query_set = query_set.filter(project=project)
if project:
query_set = query_set.filter(project=project)
return query_set

@extend_schema(parameters=[project_id_doc_param])
Expand Down
4 changes: 2 additions & 2 deletions ami/main/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,7 +549,7 @@ def setUp(self) -> None:
def test_occurrences_for_project(self):
# Test that occurrences are specific to each project
for project in [self.project_one, self.project_two]:
response = self.client.get(f"/api/v2/occurrences/?project={project.pk}")
response = self.client.get(f"/api/v2/occurrences/?project_id={project.pk}")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice work tracking down all the existing uses!!!

self.assertEqual(response.status_code, 200)
self.assertEqual(response.json()["count"], Occurrence.objects.filter(project=project).count())

Expand Down Expand Up @@ -592,7 +592,7 @@ def _test_taxa_for_project(self, project: Project):
"""
from ami.main.models import Taxon

response = self.client.get(f"/api/v2/taxa/?project={project.pk}")
response = self.client.get(f"/api/v2/taxa/?project_id={project.pk}")
self.assertEqual(response.status_code, 200)
project_occurred_taxa = Taxon.objects.filter(occurrences__project=project).distinct()
# project_any_taxa = Taxon.objects.filter(projects=project)
Expand Down
3 changes: 2 additions & 1 deletion ami/ml/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ class PipelineViewSet(DefaultViewSet):
def get_queryset(self) -> QuerySet:
query_set: QuerySet = super().get_queryset()
project = get_active_project(self.request)
query_set = query_set.filter(projects=project)
if project:
query_set = query_set.filter(projects=project)
return query_set

@extend_schema(parameters=[project_id_doc_param])
Expand Down
Loading