1111from trapdata .antenna .client import get_full_service_name , get_jobs , post_batch_results
1212from trapdata .antenna .datasets import get_rest_dataloader
1313from trapdata .antenna .schemas import AntennaTaskResult , AntennaTaskResultError
14- from trapdata .api .api import CLASSIFIER_CHOICES
14+ from trapdata .api .api import CLASSIFIER_CHOICES , should_filter_detections
15+ from trapdata .api .models .classification import MothClassifierBinary
1516from trapdata .api .models .localization import APIMothDetector
1617from trapdata .api .schemas import (
1718 DetectionResponse ,
@@ -125,6 +126,79 @@ def _worker_loop(gpu_id: int, pipelines: list[str]):
125126 time .sleep (SLEEP_TIME_SECONDS )
126127
127128
def _apply_binary_classification(
    binary_filter: "MothClassifierBinary",
    detector_results: list[DetectionResponse],
    image_tensors: dict[str, torch.Tensor],
    image_detections: dict[str, list[DetectionResponse]],
) -> tuple[list[DetectionResponse], list[DetectionResponse]]:
    """Split detector output into moth and non-moth detections.

    Every detection with a usable bounding box is cropped out of its source
    image tensor, the crops are classified by the binary filter in a single
    batch, and each detection is then sorted by the label of its first
    classification.

    Args:
        binary_filter: Binary classifier used to label the crops.
        detector_results: Detections produced by the object detector.
        image_tensors: Image tensors keyed by source image ID.
        image_detections: Per-image detection lists. Updated in place: every
            detection labelled with the negative binary label is appended to
            the list of its source image.

    Returns:
        Tuple of (moth_detections, non_moth_detections). A detection that was
        skipped for an invalid bounding box, or whose first classification
        matches neither binary label, appears in neither list.
    """
    binary_filter.reset(detector_results)

    pil_converter = torchvision.transforms.ToPILImage()
    preprocess = binary_filter.get_transforms()

    # Build the batch of crops, remembering which detector result each crop
    # came from so predictions can be mapped back afterwards.
    prepared_crops = []
    kept_indices = []
    for det_index, candidate in enumerate(detector_results):
        source_tensor = image_tensors[candidate.source_image_id]
        box = candidate.bbox
        y1, y2 = int(box.y1), int(box.y2)
        x1, x2 = int(box.x1), int(box.x2)
        if not (y1 < y2 and x1 < x2):
            # Zero or negative area: nothing to crop, so leave it out.
            logger.warning(
                f"Skipping binary classification {det_index} with invalid bbox: "
                f"({x1},{y1})->({x2},{y2})"
            )
            continue
        region = source_tensor[:, y1:y2, x1:x2]
        prepared_crops.append(preprocess(pil_converter(region)))
        kept_indices.append(det_index)

    moths = []
    non_moths = []
    if not prepared_crops:
        return moths, non_moths

    raw_output = binary_filter.predict_batch(torch.stack(prepared_crops))
    predictions = binary_filter.post_process_batch(raw_output)

    for position, det_index in enumerate(kept_indices):
        updated = binary_filter.update_detection_classification(
            seconds_per_item=0,
            image_id=detector_results[det_index].source_image_id,
            detection_idx=det_index,
            predictions=predictions[position],
        )

        # Only the first classification decides which bucket the detection
        # lands in; a detection without classifications is left out.
        leading = next(iter(updated.classifications), None)
        if leading is None:
            continue
        if leading.classification == binary_filter.positive_binary_label:
            moths.append(updated)
        elif leading.classification == binary_filter.negative_binary_label:
            non_moths.append(updated)
            image_detections[updated.source_image_id].append(updated)

    return moths, non_moths
200+
201+
128202@torch .no_grad ()
129203def _process_job (
130204 pipeline : str ,
@@ -151,6 +225,17 @@ def _process_job(
151225 classifier = None
152226 detector = None
153227
228+ # Check if binary filtering is needed once for the entire job
229+ classifier_class = CLASSIFIER_CHOICES [pipeline ]
230+ use_binary_filter = should_filter_detections (classifier_class )
231+ binary_filter = None
232+ if use_binary_filter :
233+ binary_filter = MothClassifierBinary (
234+ source_images = [],
235+ detections = [],
236+ terminal = False ,
237+ )
238+
154239 if torch .cuda .is_available ():
155240 torch .cuda .empty_cache ()
156241 items = 0
@@ -171,7 +256,6 @@ def _process_job(
171256
172257 # Defer instantiation of detector and classifier until we have data
173258 if not classifier :
174- classifier_class = CLASSIFIER_CHOICES [pipeline ]
175259 classifier = classifier_class (source_images = [], detections = [])
176260 detector = APIMothDetector ([])
177261 assert detector is not None , "Detector not initialized"
@@ -233,14 +317,30 @@ def _process_job(
233317 }
234318 image_tensors = dict (zip (image_ids , images , strict = True ))
235319
236- classifier .reset (detector .results )
320+ # Apply binary classification filter if needed
321+ detector_results = detector .results
322+
323+ if use_binary_filter :
324+ assert binary_filter is not None , "Binary filter not initialized"
325+ detections_for_terminal_classifier , detections_to_return = (
326+ _apply_binary_classification (
327+ binary_filter , detector_results , image_tensors , image_detections
328+ )
329+ )
330+ else :
331+ # No binary filtering, send all detections to terminal classifier
332+ detections_for_terminal_classifier = detector_results
333+ detections_to_return = []
334+
335+ # Run terminal classifier on filtered detections
336+ classifier .reset (detections_for_terminal_classifier )
237337 to_pil = torchvision .transforms .ToPILImage ()
238338 classify_transforms = classifier .get_transforms ()
239339
240340 # Collect and transform all crops for batched classification
241341 crops = []
242342 valid_indices = []
243- for idx , dresp in enumerate (detector . results ):
343+ for idx , dresp in enumerate (detections_for_terminal_classifier ):
244344 image_tensor = image_tensors [dresp .source_image_id ]
245345 bbox = dresp .bbox
246346 y1 , y2 = int (bbox .y1 ), int (bbox .y2 )
@@ -263,7 +363,7 @@ def _process_job(
263363 classifier_out = classifier .post_process_batch (classifier_out )
264364
265365 for crop_i , idx in enumerate (valid_indices ):
266- dresp = detector . results [idx ]
366+ dresp = detections_for_terminal_classifier [idx ]
267367 detection = classifier .update_detection_classification (
268368 seconds_per_item = 0 ,
269369 image_id = dresp .source_image_id ,
@@ -273,6 +373,9 @@ def _process_job(
273373 image_detections [dresp .source_image_id ].append (detection )
274374 all_detections .append (detection )
275375
376+ # Add non-moth detections to all_detections
377+ all_detections .extend (detections_to_return )
378+
276379 ct , t = t ("Finished classification" )
277380 total_classification_time += ct
278381
0 commit comments