Skip to content

Commit 02b6a8c

Browse files
authored
Add binary classification to worker processing (#116)
* Add binary classification to worker * refactor * add logging
1 parent 8a59673 commit 02b6a8c

1 file changed

Lines changed: 108 additions & 5 deletions

File tree

trapdata/antenna/worker.py

Lines changed: 108 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@
1111
from trapdata.antenna.client import get_full_service_name, get_jobs, post_batch_results
1212
from trapdata.antenna.datasets import get_rest_dataloader
1313
from trapdata.antenna.schemas import AntennaTaskResult, AntennaTaskResultError
14-
from trapdata.api.api import CLASSIFIER_CHOICES
14+
from trapdata.api.api import CLASSIFIER_CHOICES, should_filter_detections
15+
from trapdata.api.models.classification import MothClassifierBinary
1516
from trapdata.api.models.localization import APIMothDetector
1617
from trapdata.api.schemas import (
1718
DetectionResponse,
@@ -125,6 +126,79 @@ def _worker_loop(gpu_id: int, pipelines: list[str]):
125126
time.sleep(SLEEP_TIME_SECONDS)
126127

127128

129+
def _apply_binary_classification(
    binary_filter: "MothClassifierBinary",
    detector_results: list[DetectionResponse],
    image_tensors: dict[str, torch.Tensor],
    image_detections: dict[str, list[DetectionResponse]],
) -> tuple[list[DetectionResponse], list[DetectionResponse]]:
    """Run the binary classifier and split detections into moth / non-moth.

    Each detection with a usable bounding box is cropped out of its source
    image tensor, passed through the binary filter's transforms, and
    classified in a single batch.

    Args:
        binary_filter: The binary classifier instance. It is reset with
            ``detector_results`` before any prediction is made.
        detector_results: Detections produced by the object detector.
        image_tensors: Mapping of source image ID to that image's tensor.
        image_detections: Mapping of source image ID to a list of
            detections. Mutated in place: non-moth detections are appended
            here; moth detections are not.

    Returns:
        Tuple of ``(moth_detections, non_moth_detections)``. Detections with
        an empty or inverted bounding box are skipped with a warning and
        appear in neither list, as do detections whose inspected label
        matches neither binary label.
    """
    binary_filter.reset(detector_results)

    tensor_to_pil = torchvision.transforms.ToPILImage()
    preprocess = binary_filter.get_transforms()

    # Build the batch of crops, remembering which detector_results index each
    # crop came from so predictions can be mapped back afterwards.
    prepared_crops = []
    kept_indices = []
    for det_idx, det in enumerate(detector_results):
        source_tensor = image_tensors[det.source_image_id]
        box = det.bbox
        top, bottom = int(box.y1), int(box.y2)
        left, right = int(box.x1), int(box.x2)
        if top >= bottom or left >= right:
            logger.warning(
                f"Skipping binary classification {det_idx} with invalid bbox: "
                f"({left},{top})->({right},{bottom})"
            )
            continue
        # First axis is kept whole; second axis is sliced by y, third by x
        # (presumably a (C, H, W) tensor -- confirm against the dataloader).
        region = source_tensor[:, top:bottom, left:right]
        prepared_crops.append(preprocess(tensor_to_pil(region)))
        kept_indices.append(det_idx)

    moths = []
    non_moths = []

    # Nothing valid to classify: both result lists stay empty.
    if not prepared_crops:
        return moths, non_moths

    raw_output = binary_filter.predict_batch(torch.stack(prepared_crops))
    predictions = binary_filter.post_process_batch(raw_output)

    for batch_pos, det_idx in enumerate(kept_indices):
        source_det = detector_results[det_idx]
        detection = binary_filter.update_detection_classification(
            seconds_per_item=0,
            image_id=source_det.source_image_id,
            detection_idx=det_idx,
            predictions=predictions[batch_pos],
        )

        # Route the detection by its binary label. The unconditional break
        # means only the first classification on the detection is inspected.
        # NOTE(review): the nesting of this break was reconstructed from a
        # de-indented diff -- confirm it sits at loop-body level.
        for classification in detection.classifications:
            if classification.classification == binary_filter.positive_binary_label:
                moths.append(detection)
            elif (
                classification.classification == binary_filter.negative_binary_label
            ):
                non_moths.append(detection)
                image_detections[detection.source_image_id].append(detection)
            break

    return moths, non_moths
200+
201+
128202
@torch.no_grad()
129203
def _process_job(
130204
pipeline: str,
@@ -151,6 +225,17 @@ def _process_job(
151225
classifier = None
152226
detector = None
153227

228+
# Check if binary filtering is needed once for the entire job
229+
classifier_class = CLASSIFIER_CHOICES[pipeline]
230+
use_binary_filter = should_filter_detections(classifier_class)
231+
binary_filter = None
232+
if use_binary_filter:
233+
binary_filter = MothClassifierBinary(
234+
source_images=[],
235+
detections=[],
236+
terminal=False,
237+
)
238+
154239
if torch.cuda.is_available():
155240
torch.cuda.empty_cache()
156241
items = 0
@@ -171,7 +256,6 @@ def _process_job(
171256

172257
# Defer instantiation of detector and classifier until we have data
173258
if not classifier:
174-
classifier_class = CLASSIFIER_CHOICES[pipeline]
175259
classifier = classifier_class(source_images=[], detections=[])
176260
detector = APIMothDetector([])
177261
assert detector is not None, "Detector not initialized"
@@ -233,14 +317,30 @@ def _process_job(
233317
}
234318
image_tensors = dict(zip(image_ids, images, strict=True))
235319

236-
classifier.reset(detector.results)
320+
# Apply binary classification filter if needed
321+
detector_results = detector.results
322+
323+
if use_binary_filter:
324+
assert binary_filter is not None, "Binary filter not initialized"
325+
detections_for_terminal_classifier, detections_to_return = (
326+
_apply_binary_classification(
327+
binary_filter, detector_results, image_tensors, image_detections
328+
)
329+
)
330+
else:
331+
# No binary filtering, send all detections to terminal classifier
332+
detections_for_terminal_classifier = detector_results
333+
detections_to_return = []
334+
335+
# Run terminal classifier on filtered detections
336+
classifier.reset(detections_for_terminal_classifier)
237337
to_pil = torchvision.transforms.ToPILImage()
238338
classify_transforms = classifier.get_transforms()
239339

240340
# Collect and transform all crops for batched classification
241341
crops = []
242342
valid_indices = []
243-
for idx, dresp in enumerate(detector.results):
343+
for idx, dresp in enumerate(detections_for_terminal_classifier):
244344
image_tensor = image_tensors[dresp.source_image_id]
245345
bbox = dresp.bbox
246346
y1, y2 = int(bbox.y1), int(bbox.y2)
@@ -263,7 +363,7 @@ def _process_job(
263363
classifier_out = classifier.post_process_batch(classifier_out)
264364

265365
for crop_i, idx in enumerate(valid_indices):
266-
dresp = detector.results[idx]
366+
dresp = detections_for_terminal_classifier[idx]
267367
detection = classifier.update_detection_classification(
268368
seconds_per_item=0,
269369
image_id=dresp.source_image_id,
@@ -273,6 +373,9 @@ def _process_job(
273373
image_detections[dresp.source_image_id].append(detection)
274374
all_detections.append(detection)
275375

376+
# Add non-moth detections to all_detections
377+
all_detections.extend(detections_to_return)
378+
276379
ct, t = t("Finished classification")
277380
total_classification_time += ct
278381

0 commit comments

Comments
 (0)