Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
b8c8569
Add logits field to Detections table
mihow Mar 13, 2025
8812878
Save raw model logits to results (from species classifier)
mihow Mar 13, 2025
65ee9a8
Add logits to export
mihow Mar 13, 2025
71545ca
feat: add panama plus model
Yuyan-C Apr 7, 2025
16072fc
Add model for Kenya & Uganda moths from UKCEH / Turing (#78)
mihow Apr 15, 2025
18b132e
feat: add ood score
Yuyan-C Apr 22, 2025
57e7098
feat: add ood_score to ClassificationResponse
Yuyan-C Apr 23, 2025
95a8a7b
feat: format prediction result with ClassifierResult
Yuyan-C Apr 23, 2025
131c896
cleanup panama plus class
Yuyan-C Apr 23, 2025
fbcab94
chore: remove print statements, add logging
mihow Apr 28, 2025
038d777
chore: reorder imports
mihow Apr 28, 2025
b8feba6
Merge branch 'main' of https://github.com/RolnickLab/ami-data-manager…
mihow Apr 28, 2025
601cb19
feat: allow setting log level from environment variable
mihow Apr 28, 2025
c1516f3
fix: set terminal/intermediate in classification responses
mihow Apr 28, 2025
0773eb1
feat: ensure ood scores are between 0 & 1 and inverted
mihow Apr 28, 2025
9143e14
Merge branch 'feat/add-classification-features-to-response' of https:…
mihow Apr 30, 2025
4d77814
fix: don't scale the OOD score
mihow May 1, 2025
cfe0210
Check for MPS device
mihow Mar 21, 2023
81a2f07
Fall back to CPU for object detection
mihow Jul 10, 2023
28e8719
feat: add new panama model
Yuyan-C May 26, 2025
15f8595
feat: add panama new model
Yuyan-C May 26, 2025
82fcb7d
feat: add panama new to gradio
Yuyan-C May 26, 2025
6268b32
chore: version the panama plus model name
mihow May 26, 2025
4366fbb
fix: typo in the renamed title
mihow May 26, 2025
7bb3f2d
Support data import to AMI platform DB
Jul 11, 2023
2515e0b
feat: new exporter using PipelineResultsResponse schema for Antenna
mihow Aug 7, 2025
6e12c44
feat: add deployment data to api export format
mihow Aug 7, 2025
078f200
fix: format of api export should match pipelineresultsresponse
mihow Aug 7, 2025
036a687
feat: incomplete support for category maps in the api exports
mihow Aug 7, 2025
4a60a30
chore: clean up
mihow Aug 7, 2025
ef0d9ef
feat: require the user to specify a valid pipeline name for import
mihow Aug 7, 2025
914136e
feat: split occurrence exports into multiple files
mihow Aug 7, 2025
3a947a3
Merge branch 'feat/add-logits' of https://github.com/RolnickLab/ami-d…
mihow Aug 7, 2025
8c85f89
feat: command to reprocess existing detections
mihow Aug 7, 2025
0022024
Merge branch 'platform-export-clean' of https://github.com/RolnickLab…
mihow Aug 8, 2025
bac5e95
feat[exports]: add logits & cnn features to export for antenna
mihow Aug 8, 2025
1d91553
feat[export]: use actual model name used for algorithm in exports
mihow Aug 8, 2025
7317fcd
feat[export]: update name of export command and add default pipeline
mihow Aug 8, 2025
cbc1f44
fix[export]: allow detections with no label (incomplete reprocessing)
mihow Aug 8, 2025
a7019d2
fix[exports]: add missing deployments to export
mihow Aug 8, 2025
4c5b13e
fix[exports]: correct algorithm keys in export
mihow Aug 8, 2025
6a1b16f
fix[exports]: use algorithm key not pipeline key
mihow Aug 8, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions trapdata/api/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
MothClassifierGlobal,
MothClassifierPanama,
MothClassifierPanama2024,
MothClassifierPanamaPlus2025,
MothClassifierQuebecVermont,
MothClassifierTuringAnguilla,
MothClassifierTuringCostaRica,
Expand All @@ -39,6 +40,7 @@


CLASSIFIER_CHOICES = {
"panama_plus_moths_2025": MothClassifierPanamaPlus2025,
"panama_moths_2023": MothClassifierPanama,
"panama_moths_2024": MothClassifierPanama2024,
"quebec_vermont_moths_2023": MothClassifierQuebecVermont,
Expand Down
117 changes: 91 additions & 26 deletions trapdata/api/models/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
TuringAnguillaSpeciesClassifier,
TuringCostaRicaSpeciesClassifier,
UKDenmarkMothSpeciesClassifier2024,
PanamaPlusWithOODClassifier2025,
)

from ..datasets import ClassificationImageDataset
Expand All @@ -25,6 +26,11 @@
SourceImage,
)
from .base import APIInferenceBaseClass
from trapdata.ml.models.base import ClassifierResult

from trapdata.ml.utils import StopWatch
import torch.utils.data
from sentry_sdk import start_transaction


class APIMothClassifier(
Expand Down Expand Up @@ -59,44 +65,64 @@ def get_dataset(self):
batch_size=self.batch_size,
)

def post_process_batch(self, logits: torch.Tensor):
def get_ood_score(self, preds):
    # Hook for subclasses to compute an out-of-distribution score from raw
    # predictions; the base classifier intentionally implements nothing here.
    # NOTE(review): presumably OOD-aware subclasses (e.g. the Panama Plus
    # classifier) override this — confirm against the subclass implementations.
    pass

def post_process_batch(
self, logits: torch.Tensor, features: torch.Tensor | None = None
):
"""
Return the labels, softmax/calibrated scores, and the original logits for
each image in the batch.

Almost like the base class method, but we need to return the logits as well.
each image in the batch, along with optional feature vectors.
"""
predictions = torch.nn.functional.softmax(logits, dim=1)
predictions = predictions.cpu().numpy()

if self.class_prior is None:
ood_scores = np.max(predictions, axis=-1)
else:
ood_scores = np.max(predictions - self.class_prior, axis=-1)

features = features.cpu() if features is not None else None
batch_results = []
for pred in predictions:
# Get all class indices and their corresponding scores

logits = logits.cpu().numpy()

for i, pred in enumerate(predictions):
class_indices = np.arange(len(pred))
scores = pred
labels = [self.category_map[i] for i in class_indices]
batch_results.append(list(zip(labels, scores, pred)))
ood_score = ood_scores[i]
logit = logits[i].tolist()
feature = features[i].tolist() if features is not None else None

result = ClassifierResult(
feature=feature,
labels=labels,
logit=logit,
scores=pred,
ood_score=ood_score,
)

logger.debug(f"Post-processing result batch: {batch_results}")
batch_results.append(result)

logger.debug(f"Post-processing result batch with {len(batch_results)} entries.")
return batch_results

def get_best_label(self, predictions):
"""
Convenience method to get the best label from the predictions, which are a list of tuples
in the order of the model's class index, NOT the values.
def predict_batch(self, batch, return_features: bool = False):
batch_input = batch.to(self.device, non_blocking=True)

This must not modify the predictions list!
if return_features:
features = self.get_features(batch_input)
logits = self.model(batch_input)
return logits, features

predictions look like:
[
('label1', score1, logit1),
('label2', score2, logit2),
...
]
"""
best_pred = max(predictions, key=lambda x: x[1])
best_label = best_pred[0]
logits = self.model(batch_input)
return logits, None

def get_best_label(self, predictions):
    """Return the label whose score is highest in a classifier result.

    ``predictions`` is expected to expose parallel ``labels`` and ``scores``
    sequences (as produced by ``post_process_batch``).
    """
    top_index = np.argmax(predictions.scores)
    return predictions.labels[top_index]

def save_results(
Expand All @@ -109,15 +135,16 @@ def save_results(
):
detection = self.detections[detection_idx]
assert detection.source_image_id == image_id
_labels, scores, logits = zip(*predictions)

classification = ClassificationResponse(
classification=self.get_best_label(predictions),
scores=scores,
logits=logits,
scores=predictions.scores,
ood_score=predictions.ood_score,
logits=predictions.logit,
features=predictions.feature,
inference_time=seconds_per_item,
algorithm=AlgorithmReference(name=self.name, key=self.get_key()),
timestamp=datetime.datetime.now(),
terminal=self.terminal,
)
self.update_classification(detection, classification)

Expand All @@ -139,12 +166,45 @@ def update_classification(
f"Total classifications: {len(detection.classifications)}"
)

@torch.no_grad()
def run(self) -> list[DetectionResponse]:
logger.info(
f"Starting {self.__class__.__name__} run with {len(self.results)} "
"detections"
)
super().run()
torch.cuda.empty_cache()

for i, batch in enumerate(self.dataloader):
if not batch:
logger.info(f"Batch {i+1} is empty, skipping")
continue

item_ids, batch_input = batch

logger.info(
f"Processing batch {i+1}, about {len(self.dataloader)} remaining"
)

with StopWatch() as batch_time:
with start_transaction(op="inference_batch", name=self.name):
logits, features = self.predict_batch(
batch_input, return_features=True
)

seconds_per_item = batch_time.duration / len(logits)

batch_output = list(self.post_process_batch(logits, features=features))
if isinstance(item_ids, (np.ndarray, torch.Tensor)):
item_ids = item_ids.tolist()

logger.info(f"Saving results from {len(item_ids)} items")
self.save_results(
item_ids,
batch_output,
seconds_per_item=seconds_per_item,
)
logger.info(f"{self.name} Batch -- Done")

logger.info(
f"Finished {self.__class__.__name__} run. "
f"Processed {len(self.results)} detections"
Expand Down Expand Up @@ -188,3 +248,8 @@ class MothClassifierTuringAnguilla(APIMothClassifier, TuringAnguillaSpeciesClass

class MothClassifierGlobal(APIMothClassifier, GlobalMothSpeciesClassifier):
    """Global moth species classifier exposed through the API.

    All behavior comes from the two parent classes; this class only binds
    the API inference mixin to the concrete model.
    """

    pass


class MothClassifierPanamaPlus2025(APIMothClassifier, PanamaPlusWithOODClassifier2025):
    """Panama Plus 2025 species classifier with out-of-distribution scoring.

    All behavior comes from the two parent classes; this class only binds
    the API inference mixin to the OOD-aware Panama Plus model.
    """

    pass
2 changes: 2 additions & 0 deletions trapdata/api/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@ class ClassificationResponse(pydantic.BaseModel):
),
repr=False, # Too long to display in the repr
)

ood_score: float | None = None
inference_time: float | None = None
algorithm: AlgorithmReference
terminal: bool = True
Expand Down
75 changes: 75 additions & 0 deletions trapdata/api/tests/test_ood.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import os
import pathlib
from unittest import TestCase
from fastapi.testclient import TestClient
from trapdata.api.api import PipelineChoice, PipelineRequest, PipelineResponse, app
from trapdata.api.schemas import SourceImageRequest
from trapdata.api.tests.image_server import StaticFileTestServer
from trapdata.tests import TEST_IMAGES_BASE_PATH


class TestFeatureExtractionAPI(TestCase):
    """End-to-end test of the OOD-score pipeline over the HTTP API.

    Serves local test images over HTTP, runs them through the
    ``panama_plus_moths_2025`` pipeline, and checks that classifications
    carry an out-of-distribution score.
    """

    @classmethod
    def setUpClass(cls):
        # Serve the repo's test images over HTTP so the pipeline can fetch
        # them by URL, the same way it would in production.
        cls.test_images_dir = pathlib.Path(TEST_IMAGES_BASE_PATH)
        cls.file_server = StaticFileTestServer(cls.test_images_dir)
        cls.client = TestClient(app)

    @classmethod
    def tearDownClass(cls):
        cls.file_server.stop()

    def get_local_test_images(self, num=1):
        """Return up to ``num`` SourceImageRequest objects for local images."""
        image_paths = [
            "panama/01-20231110214539-snapshot.jpg",
            "panama/01-20231111032659-snapshot.jpg",
            "panama/01-20231111015309-snapshot.jpg",
        ]
        return [
            SourceImageRequest(id="0", url=self.file_server.get_url(image_path))
            for image_path in image_paths[:num]
        ]

    def get_pipeline_response(
        self,
        pipeline_slug="panama_plus_moths_2025",
        num_images=1,
    ):
        """
        Utility method to send a pipeline request and return the parsed response.
        """
        test_images = self.get_local_test_images(num=num_images)
        pipeline_request = PipelineRequest(
            pipeline=PipelineChoice[pipeline_slug],
            source_images=test_images,
        )

        with self.file_server:
            response = self.client.post("/process", json=pipeline_request.model_dump())
        # Use unittest's assertion rather than a bare assert so the failure
        # message includes the actual status code.
        self.assertEqual(response.status_code, 200)
        return PipelineResponse(**response.json())

    def test_ood_scores_from_pipeline(self):
        """
        Run a local image through the pipeline and validate that terminal
        classifications carry an OOD score.
        """
        pipeline_response = self.get_pipeline_response()

        self.assertTrue(pipeline_response.detections, "No detections returned")
        for detection in pipeline_response.detections:
            for classification in detection.classifications:
                # Assumes the panama_plus pipeline's terminal (species)
                # classifier populates ood_score — TODO confirm that binary /
                # intermediate classifiers are allowed to leave it None.
                if classification.terminal:
                    self.assertIsNotNone(
                        classification.ood_score,
                        "Terminal classification is missing an OOD score",
                    )
                    self.assertIsInstance(classification.ood_score, float)
8 changes: 6 additions & 2 deletions trapdata/common/logs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,16 @@

import structlog

# structlog.configure(
# wrapper_class=structlog.make_filtering_bound_logger(logging.INFO),
# )

structlog.configure(
wrapper_class=structlog.make_filtering_bound_logger(logging.INFO),
wrapper_class=structlog.make_filtering_bound_logger(logging.CRITICAL),
)


logger = structlog.get_logger()
logging.disable(logging.CRITICAL)

# import logging
# from rich.logging import RichHandler
Expand Down
12 changes: 10 additions & 2 deletions trapdata/db/models/detections.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class DetectionListItem(BaseModel):
model_name: Optional[str]
in_queue: bool
notes: Optional[str]
ood_score: Optional[str]

# PyDantic complains because we have an attribute called `model_name`
model_config = ConfigDict(protected_namespaces=[]) # type:ignore
Expand All @@ -43,6 +44,7 @@ class DetectionDetail(DetectionListItem):
timestamp: Optional[str]
bbox_center: Optional[tuple[int, int]]
area_pixels: Optional[int]
ood_score: Optional[float]


class DetectedObject(db.Base):
Expand Down Expand Up @@ -76,6 +78,7 @@ class DetectedObject(db.Base):
sequence_previous_id = sa.Column(sa.Integer)
sequence_previous_cost = sa.Column(sa.Float)
cnn_features = sa.Column(sa.JSON)
ood_score = sa.Column(sa.Float)

# @TODO add updated & created timestamps to all db models

Expand Down Expand Up @@ -288,6 +291,7 @@ def report_data(self) -> DetectionDetail:
last_detected=self.last_detected,
notes=self.notes,
in_queue=self.in_queue,
ood_score=self.ood_score
)

def report_data_simple(self):
Expand Down Expand Up @@ -510,7 +514,9 @@ def get_species_for_image(db_path, image_id):
def num_species_for_event(
db_path, monitoring_session, classification_threshold: float = 0.6
) -> int:
query = sa.select(sa.func.count(DetectedObject.specific_label.distinct()),).where(
query = sa.select(
sa.func.count(DetectedObject.specific_label.distinct()),
).where(
(DetectedObject.specific_label_score >= classification_threshold)
& (DetectedObject.monitoring_session == monitoring_session)
)
Expand All @@ -522,7 +528,9 @@ def num_species_for_event(
def num_occurrences_for_event(
db_path, monitoring_session, classification_threshold: float = 0.6
) -> int:
query = sa.select(sa.func.count(DetectedObject.sequence_id.distinct()),).where(
query = sa.select(
sa.func.count(DetectedObject.sequence_id.distinct()),
).where(
(DetectedObject.specific_label_score >= classification_threshold)
& (DetectedObject.monitoring_session == monitoring_session)
)
Expand Down
1 change: 0 additions & 1 deletion trapdata/ml/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ def get_default_model(choices: EnumMeta) -> str:
)
DEFAULT_OBJECT_DETECTOR = get_default_model(ObjectDetectorChoice)


binary_classifiers = {Model.name: Model for Model in BinaryClassifier.__subclasses__()}
BinaryClassifierChoice = ModelChoiceEnum(
"BinaryClassifierChoice",
Expand Down
Loading