Skip to content
Merged
118 changes: 106 additions & 12 deletions ami/main/api/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from django.forms import BooleanField, CharField, IntegerField
from django.utils import timezone
from django_filters.rest_framework import DjangoFilterBackend
from drf_spectacular.utils import OpenApiParameter, extend_schema
from rest_framework import exceptions as api_exceptions
from rest_framework import filters, serializers, status, viewsets
from rest_framework.decorators import action
Expand Down Expand Up @@ -549,7 +550,7 @@ class SourceImageCollectionViewSet(DefaultViewSet):
)
serializer_class = SourceImageCollectionSerializer

filterset_fields = ["project", "method"]
filterset_fields = ["method"]
ordering_fields = [
"created_at",
"updated_at",
Expand All @@ -562,11 +563,16 @@ class SourceImageCollectionViewSet(DefaultViewSet):

def get_queryset(self) -> QuerySet:
    """Return collections annotated with occurrence/taxa counts,
    optionally narrowed to one project via the ``project_id`` query param.

    An unknown ``project_id`` is silently ignored (the unfiltered
    queryset is returned), matching the other project-filterable viewsets.
    """
    classification_threshold = get_active_classification_threshold(self.request)
    queryset: QuerySet = super().get_queryset()
    project_id = self.request.query_params.get("project_id")
    if project_id is not None:
        project = Project.objects.filter(id=project_id).first()
        if project:
            queryset = queryset.filter(project=project)
    # Annotate counts using the active classification threshold so the
    # numbers match what the UI displays for the current score filter.
    queryset = queryset.with_occurrences_count(  # type: ignore
        classification_threshold=classification_threshold
    ).with_taxa_count(classification_threshold=classification_threshold)
    return queryset

Expand Down Expand Up @@ -647,6 +653,19 @@ def remove(self, request, pk=None):
}
)

@extend_schema(
    parameters=[
        OpenApiParameter(
            name="project_id",
            description="Filter by project ID",
            required=False,
            type=int,
        )
    ]
)
def list(self, request, *args, **kwargs):
    """Standard list endpoint; this override exists only so the
    ``project_id`` query parameter appears in the generated OpenAPI schema."""
    return super().list(request, *args, **kwargs)


class SourceImageUploadViewSet(DefaultViewSet):
"""
Expand Down Expand Up @@ -1207,16 +1226,25 @@ class ClassificationViewSet(DefaultViewSet):

class SummaryView(GenericAPIView):
permission_classes = [IsActiveStaffOrReadOnly]
filterset_fields = ["project"]

@extend_schema(
parameters=[
OpenApiParameter(
name="project_id",
description="Filter by project ID",
required=False,
type=int,
)
]
)
def get(self, request):
"""
Return counts of all models.
"""
project_id = request.query_params.get("project")
project_id = request.query_params.get("project_id")
confidence_threshold = get_active_classification_threshold(request)
if project_id:
project = Project.objects.get(id=project_id)
project = Project.objects.filter(id=project_id).first()
data = {
"projects_count": Project.objects.count(), # @TODO filter by current user, here and everywhere!
"deployments_count": Deployment.objects.filter(project=project).count(),
Expand Down Expand Up @@ -1358,13 +1386,35 @@ class SiteViewSet(DefaultViewSet):

queryset = Site.objects.all()
serializer_class = SiteSerializer
filterset_fields = ["project", "deployments"]
filterset_fields = ["deployments"]
ordering_fields = [
"created_at",
"updated_at",
"name",
]

def get_queryset(self) -> QuerySet:
    """Return sites, optionally narrowed to one project via the
    ``project_id`` query parameter.

    An unknown ``project_id`` is silently ignored and the unfiltered
    queryset is returned.
    """
    queryset: QuerySet = super().get_queryset()
    requested_project_id = self.request.query_params.get("project_id")
    if requested_project_id is None:
        return queryset
    project = Project.objects.filter(id=requested_project_id).first()
    return queryset.filter(project=project) if project else queryset

@extend_schema(
    parameters=[
        OpenApiParameter(
            name="project_id",
            type=int,
            required=False,
            description="Filter by project ID",
        )
    ]
)
def list(self, request, *args, **kwargs):
    """Standard list endpoint; overridden only to document the
    ``project_id`` query parameter in the generated OpenAPI schema."""
    return super().list(request, *args, **kwargs)


class DeviceViewSet(DefaultViewSet):
"""
Expand All @@ -1373,13 +1423,35 @@ class DeviceViewSet(DefaultViewSet):

queryset = Device.objects.all()
serializer_class = DeviceSerializer
filterset_fields = ["project", "deployments"]
filterset_fields = ["deployments"]
ordering_fields = [
"created_at",
"updated_at",
"name",
]

def get_queryset(self) -> QuerySet:
    """Return devices, optionally narrowed to one project via the
    ``project_id`` query parameter.

    An unknown ``project_id`` is silently ignored and the unfiltered
    queryset is returned.
    """
    queryset: QuerySet = super().get_queryset()
    requested_project_id = self.request.query_params.get("project_id")
    if requested_project_id is None:
        return queryset
    project = Project.objects.filter(id=requested_project_id).first()
    return queryset.filter(project=project) if project else queryset

@extend_schema(
    parameters=[
        OpenApiParameter(
            name="project_id",
            type=int,
            required=False,
            description="Filter by project ID",
        )
    ]
)
def list(self, request, *args, **kwargs):
    """Standard list endpoint; overridden only to document the
    ``project_id`` query parameter in the generated OpenAPI schema."""
    return super().list(request, *args, **kwargs)


class StorageSourceConnectionTestSerializer(serializers.Serializer):
subdir = serializers.CharField(required=False, allow_null=True)
Expand All @@ -1393,13 +1465,35 @@ class StorageSourceViewSet(DefaultViewSet):

queryset = S3StorageSource.objects.all()
serializer_class = StorageSourceSerializer
filterset_fields = ["project", "deployments"]
filterset_fields = ["deployments"]
ordering_fields = [
"created_at",
"updated_at",
"name",
]

def get_queryset(self) -> QuerySet:
    """Return storage sources, optionally narrowed to one project via the
    ``project_id`` query parameter.

    An unknown ``project_id`` is silently ignored and the unfiltered
    queryset is returned.
    """
    queryset: QuerySet = super().get_queryset()
    requested_project_id = self.request.query_params.get("project_id")
    if requested_project_id is None:
        return queryset
    project = Project.objects.filter(id=requested_project_id).first()
    return queryset.filter(project=project) if project else queryset

@extend_schema(
    parameters=[
        OpenApiParameter(
            name="project_id",
            type=int,
            required=False,
            description="Filter by project ID",
        )
    ]
)
def list(self, request, *args, **kwargs):
    """Standard list endpoint; overridden only to document the
    ``project_id`` query parameter in the generated OpenAPI schema."""
    return super().list(request, *args, **kwargs)

@action(detail=True, methods=["post"], name="test", serializer_class=StorageSourceConnectionTestSerializer)
def test(self, request: Request, pk=None) -> Response:
"""
Expand Down
132 changes: 130 additions & 2 deletions ami/main/tests.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,26 @@
import datetime
import logging

from django.db import connection
from django.db import connection, models
from django.test import TestCase
from rest_framework import status
from rest_framework.test import APIRequestFactory, APITestCase
from rich import print

from ami.main.models import Event, Occurrence, Project, Taxon, TaxonRank, group_images_into_events
from ami.main.models import (
Device,
Event,
Occurrence,
Project,
S3StorageSource,
Site,
SourceImage,
SourceImageCollection,
Taxon,
TaxonRank,
group_images_into_events,
)
from ami.ml.models.pipeline import Pipeline
from ami.tests.fixtures.main import create_captures, create_occurrences, create_taxa, setup_test_project
from ami.users.models import User

Expand Down Expand Up @@ -754,3 +768,117 @@ def test_update_subdir(self):
self.other_subdir: self.images_per_dir,
}
self.assertDictEqual(dict(counts), expected_counts)


class TestProjectSettingsFiltering(APITestCase):
    """Test Project Settings filter by project_id"""

    def setUp(self) -> None:
        # Create three independent projects, each populated with taxa,
        # captures, events, and occurrences, so filtering by one project
        # is meaningful and distinguishable from the others.
        for _ in range(3):
            project, deployment = setup_test_project(reuse=False)
            create_taxa(project=project)
            create_captures(deployment=deployment)
            group_images_into_events(deployment=deployment)
            create_occurrences(deployment=deployment, num=5)
        self.project_ids = [project.id for project in Project.objects.all()]

        self.user = User.objects.create_user(  # type: ignore
            email="testuser@insectai.org",
            is_staff=True,
        )
        self.factory = APIRequestFactory()
        self.client.force_authenticate(user=self.user)
        return super().setUp()

    def test_project_summary(self):
        """Summary endpoint counts should reflect only the requested project."""
        project_id = self.project_ids[1]
        end_point_url = f"/api/v2/status/summary/?project_id={project_id}"
        response = self.client.get(end_point_url)
        response_data = response.json()
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        project = Project.objects.get(pk=project_id)

        self.assertEqual(response_data["deployments_count"], project.deployments_count())
        self.assertEqual(
            response_data["taxa_count"],
            Taxon.objects.annotate(occurrences_count=models.Count("occurrences"))
            .filter(
                occurrences_count__gt=0,
                occurrences__determination_score__gte=0,
                occurrences__project=project,
            )
            .distinct()
            .count(),
        )
        self.assertEqual(
            response_data["events_count"],
            Event.objects.filter(deployment__project=project, deployment__isnull=False).count(),
        )
        self.assertEqual(
            response_data["captures_count"], SourceImage.objects.filter(deployment__project=project).count()
        )
        self.assertEqual(
            response_data["occurrences_count"],
            Occurrence.objects.filter(
                project=project,
                determination_score__gte=0,
                event__isnull=False,
            ).count(),
        )

    def test_project_collections(self):
        """Collections endpoint should return only the project's collections."""
        project_id = self.project_ids[1]
        project = Project.objects.get(pk=project_id)
        end_point_url = f"/api/v2/captures/collections/?project_id={project_id}"
        response = self.client.get(end_point_url)
        response_data = response.json()
        expected_project_collection_ids = {
            source_image_collection.id
            for source_image_collection in SourceImageCollection.objects.filter(project=project)
        }
        response_source_image_collection_ids = {result.get("id") for result in response_data["results"]}
        self.assertEqual(response_source_image_collection_ids, expected_project_collection_ids)

    def test_project_pipelines(self):
        """Pipelines endpoint should return only pipelines linked to the project."""
        project_id = self.project_ids[0]
        project = Project.objects.get(pk=project_id)
        end_point_url = f"/api/v2/ml/pipelines/?project_id={project_id}"
        response = self.client.get(end_point_url)
        response_data = response.json()

        expected_project_pipeline_ids = {pipeline.id for pipeline in Pipeline.objects.filter(projects=project)}
        response_pipeline_ids = {pipeline.get("id") for pipeline in response_data["results"]}
        self.assertEqual(response_pipeline_ids, expected_project_pipeline_ids)

    def test_project_storage(self):
        """Storage endpoint should return only the project's storage sources."""
        project_id = self.project_ids[0]
        project = Project.objects.get(pk=project_id)
        end_point_url = f"/api/v2/storage/?project_id={project_id}"
        response = self.client.get(end_point_url)
        response_data = response.json()
        expected_storage_ids = {storage.id for storage in S3StorageSource.objects.filter(project=project)}
        response_storage_ids = {storage.get("id") for storage in response_data["results"]}
        self.assertEqual(response_storage_ids, expected_storage_ids)

    def test_project_sites(self):
        """Sites endpoint should return only the project's sites."""
        project_id = self.project_ids[1]
        project = Project.objects.get(pk=project_id)
        end_point_url = f"/api/v2/deployments/sites/?project_id={project_id}"
        response = self.client.get(end_point_url)
        response_data = response.json()
        expected_site_ids = {site.id for site in Site.objects.filter(project=project)}
        response_site_ids = {site.get("id") for site in response_data["results"]}
        self.assertEqual(response_site_ids, expected_site_ids)

    def test_project_devices(self):
        """Devices endpoint should return only the project's devices."""
        project_id = self.project_ids[1]
        project = Project.objects.get(pk=project_id)
        end_point_url = f"/api/v2/deployments/devices/?project_id={project_id}"
        response = self.client.get(end_point_url)
        response_data = response.json()
        expected_device_ids = {device.id for device in Device.objects.filter(project=project)}
        response_device_ids = {device.get("id") for device in response_data["results"]}
        self.assertEqual(response_device_ids, expected_device_ids)
28 changes: 28 additions & 0 deletions ami/ml/views.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
from django.db.models.query import QuerySet
from drf_spectacular.utils import OpenApiParameter, extend_schema

from ami.main.api.views import DefaultViewSet
from ami.main.models import Project

from .models.algorithm import Algorithm
from .models.pipeline import Pipeline
Expand Down Expand Up @@ -35,6 +39,30 @@ class PipelineViewSet(DefaultViewSet):
"created_at",
"updated_at",
]

def get_queryset(self) -> QuerySet:  # @TBD
    """Return pipelines, optionally narrowed to one project via the
    ``project_id`` query parameter (pipelines are linked via M2M
    ``projects``).

    An unknown ``project_id`` is silently ignored and the unfiltered
    queryset is returned.
    """
    queryset: QuerySet = super().get_queryset()
    requested_project_id = self.request.query_params.get("project_id")
    if requested_project_id is None:
        return queryset
    project = Project.objects.filter(id=requested_project_id).first()
    return queryset.filter(projects=project) if project else queryset

# @TBD
@extend_schema(
    parameters=[
        OpenApiParameter(
            name="project_id",
            type=int,
            required=False,
            description="Filter by project ID",
        )
    ]
)
def list(self, request, *args, **kwargs):
    """Standard list endpoint; overridden only to document the
    ``project_id`` query parameter in the generated OpenAPI schema."""
    return super().list(request, *args, **kwargs)

# Don't enable projects filter until we can use the current users
# membership to filter the projects.
# filterset_fields = ["projects"]
2 changes: 1 addition & 1 deletion ami/tests/fixtures/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def create_ml_pipeline(project):
for algorithm_data in pipeline_data["algorithms"]:
algorithm, _ = Algorithm.objects.get_or_create(name=algorithm_data["name"], key=algorithm_data["key"])
pipeline.algorithms.add(algorithm)

pipeline.projects.add(project) # @TBD
pipeline.save()

return pipeline
Expand Down
Loading