Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
456b025
fix: return images as list in rest_collate_fn to support variable sizes
mihow Feb 10, 2026
54d9944
fix: handle batch processing errors per-batch instead of crashing job
mihow Feb 10, 2026
7e289b4
perf: batch classification crops in worker instead of N individual GP…
mihow Feb 10, 2026
23a90e3
perf: increase default localization_batch_size and num_workers for GPU
mihow Feb 10, 2026
8557298
feat: spawn one worker process per GPU for multi-GPU inference
mihow Feb 10, 2026
0627bd7
fix: clarify batch size log messages to distinguish worker vs inference
mihow Feb 10, 2026
123b24e
perf: increase DataLoader prefetch_factor to 4 for worker
mihow Feb 10, 2026
96b19f6
perf: download images concurrently within each DataLoader worker
mihow Feb 10, 2026
57bfed2
docs: clarify types of workers
mihow Feb 12, 2026
d9ae653
refactor: make ThreadPoolExecutor a class member, simplify with map()
mihow Feb 13, 2026
7abad40
docs: fix stale docstring in rest_collate_fn
mihow Feb 13, 2026
4baa3d7
fix: defensive batch_results init and strict=True in error handler
mihow Feb 13, 2026
27d9da2
revert: remove prefetch_factor=4 override, use PyTorch default
mihow Feb 13, 2026
3319c60
docs: document data loading pipeline concurrency layers
mihow Feb 13, 2026
ed49153
perf: tune defaults for higher GPU utilization
mihow Feb 13, 2026
5163e44
fix: lazily init unpicklable objects in RESTDataset for num_workers>0
mihow Feb 17, 2026
0857522
docs: fix stale num_workers default in datasets.py docstring (4 → 2)
mihow Feb 17, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion trapdata/antenna/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ def rest_collate_fn(batch: list[dict]) -> dict:
# Collate successful items
if successful:
result = {
"images": torch.stack([item["image"] for item in successful]),
"images": [item["image"] for item in successful],
"reply_subjects": [item["reply_subject"] for item in successful],
"image_ids": [item["image_id"] for item in successful],
"image_urls": [item.get("image_url") for item in successful],
Expand Down
6 changes: 4 additions & 2 deletions trapdata/antenna/tests/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ def test_all_successful(self):
result = rest_collate_fn(batch)

assert "images" in result
assert result["images"].shape == (2, 3, 64, 64)
assert len(result["images"]) == 2
assert result["images"][0].shape == (3, 64, 64)
assert result["image_ids"] == ["img1", "img2"]
assert result["reply_subjects"] == ["subj1", "subj2"]
assert result["failed_items"] == []
Expand Down Expand Up @@ -104,7 +105,8 @@ def test_mixed(self):
]
result = rest_collate_fn(batch)

assert result["images"].shape == (1, 3, 64, 64)
assert len(result["images"]) == 1
assert result["images"][0].shape == (3, 64, 64)
assert result["image_ids"] == ["img1"]
assert len(result["failed_items"]) == 1
assert result["failed_items"][0]["image_id"] == "img2"
Expand Down
211 changes: 124 additions & 87 deletions trapdata/antenna/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import numpy as np
import torch
import torchvision

from trapdata.antenna.client import get_jobs, post_batch_results
from trapdata.antenna.datasets import get_rest_dataloader
Expand Down Expand Up @@ -121,99 +122,137 @@ def _process_job(
reply_subjects = batch.get("reply_subjects", [None] * len(images))
image_urls = batch.get("image_urls", [None] * len(images))

# Validate all arrays have same length before zipping
if len(image_ids) != len(images):
raise ValueError(
f"Length mismatch: image_ids ({len(image_ids)}) != images ({len(images)})"
)
if len(image_ids) != len(reply_subjects) or len(image_ids) != len(image_urls):
raise ValueError(
f"Length mismatch: image_ids ({len(image_ids)}), "
f"reply_subjects ({len(reply_subjects)}), image_urls ({len(image_urls)})"
)
try:
# Validate all arrays have same length before zipping
if len(image_ids) != len(images):
raise ValueError(
f"Length mismatch: image_ids ({len(image_ids)}) != images ({len(images)})"
)
if len(image_ids) != len(reply_subjects) or len(image_ids) != len(
image_urls
):
raise ValueError(
f"Length mismatch: image_ids ({len(image_ids)}), "
f"reply_subjects ({len(reply_subjects)}), image_urls ({len(image_urls)})"
)

# Track start time for this batch
batch_start_time = datetime.datetime.now()
# Track start time for this batch
batch_start_time = datetime.datetime.now()

logger.info(f"Processing batch {i + 1}")
# output is dict of "boxes", "labels", "scores"
batch_output = []
if len(images) > 0:
batch_output = detector.predict_batch(images)
logger.info(f"Processing batch {i + 1}")
# output is dict of "boxes", "labels", "scores"
batch_output = []
if len(images) > 0:
batch_output = detector.predict_batch(images)

items += len(batch_output)
logger.info(f"Total items processed so far: {items}")
batch_output = list(detector.post_process_batch(batch_output))
items += len(batch_output)
logger.info(f"Total items processed so far: {items}")
batch_output = list(detector.post_process_batch(batch_output))

# Convert image_ids to list if needed
if isinstance(image_ids, (np.ndarray, torch.Tensor)):
image_ids = image_ids.tolist()
# Convert image_ids to list if needed
if isinstance(image_ids, (np.ndarray, torch.Tensor)):
image_ids = image_ids.tolist()

# TODO CGJS: Add seconds per item calculation for both detector and classifier
detector.save_results(
item_ids=image_ids,
batch_output=batch_output,
seconds_per_item=0,
)
dt, t = t("Finished detection")
total_detection_time += dt

# Group detections by image_id
image_detections: dict[str, list[DetectionResponse]] = {
img_id: [] for img_id in image_ids
}
image_tensors = dict(zip(image_ids, images, strict=True))

classifier.reset(detector.results)

for idx, dresp in enumerate(detector.results):
image_tensor = image_tensors[dresp.source_image_id]
bbox = dresp.bbox
# crop the image tensor using the bbox
crop = image_tensor[
:, int(bbox.y1) : int(bbox.y2), int(bbox.x1) : int(bbox.x2)
]
crop = crop.unsqueeze(0) # add batch dimension
classifier_out = classifier.predict_batch(crop)
classifier_out = classifier.post_process_batch(classifier_out)
detection = classifier.update_detection_classification(
# TODO CGJS: Add seconds per item calculation for both detector and classifier
detector.save_results(
item_ids=image_ids,
batch_output=batch_output,
seconds_per_item=0,
image_id=dresp.source_image_id,
detection_idx=idx,
predictions=classifier_out[0],
)
image_detections[dresp.source_image_id].append(detection)
all_detections.append(detection)

ct, t = t("Finished classification")
total_classification_time += ct

# Calculate batch processing time
batch_end_time = datetime.datetime.now()
batch_elapsed = (batch_end_time - batch_start_time).total_seconds()

# Post results back to the API with PipelineResponse for each image
batch_results: list[AntennaTaskResult] = []
for reply_subject, image_id, image_url in zip(
reply_subjects, image_ids, image_urls, strict=True
):
# Create SourceImageResponse for this image
source_image = SourceImageResponse(id=image_id, url=image_url)

# Create PipelineResultsResponse
pipeline_response = PipelineResultsResponse(
pipeline=pipeline,
source_images=[source_image],
detections=image_detections[image_id],
total_time=batch_elapsed / len(image_ids), # Approximate time per image
)
dt, t = t("Finished detection")
total_detection_time += dt

# Group detections by image_id
image_detections: dict[str, list[DetectionResponse]] = {
img_id: [] for img_id in image_ids
}
image_tensors = dict(zip(image_ids, images, strict=True))

classifier.reset(detector.results)
to_pil = torchvision.transforms.ToPILImage()
classify_transforms = classifier.get_transforms()

# Collect and transform all crops for batched classification
crops = []
valid_indices = []
for idx, dresp in enumerate(detector.results):
image_tensor = image_tensors[dresp.source_image_id]
bbox = dresp.bbox
y1, y2 = int(bbox.y1), int(bbox.y2)
x1, x2 = int(bbox.x1), int(bbox.x2)
if y1 >= y2 or x1 >= x2:
logger.warning(
f"Skipping detection {idx} with invalid bbox: "
f"({x1},{y1})->({x2},{y2})"
)
continue
crop = image_tensor[:, y1:y2, x1:x2]
crop_pil = to_pil(crop)
crop_transformed = classify_transforms(crop_pil)
crops.append(crop_transformed)
valid_indices.append(idx)

if crops:
batched_crops = torch.stack(crops)
classifier_out = classifier.predict_batch(batched_crops)
classifier_out = classifier.post_process_batch(classifier_out)

for crop_i, idx in enumerate(valid_indices):
dresp = detector.results[idx]
detection = classifier.update_detection_classification(
seconds_per_item=0,
image_id=dresp.source_image_id,
detection_idx=idx,
predictions=classifier_out[crop_i],
)
image_detections[dresp.source_image_id].append(detection)
all_detections.append(detection)

ct, t = t("Finished classification")
total_classification_time += ct

# Calculate batch processing time
batch_end_time = datetime.datetime.now()
batch_elapsed = (batch_end_time - batch_start_time).total_seconds()

# Post results back to the API with PipelineResponse for each image
batch_results: list[AntennaTaskResult] = []
for reply_subject, image_id, image_url in zip(
reply_subjects, image_ids, image_urls, strict=True
):
# Create SourceImageResponse for this image
source_image = SourceImageResponse(id=image_id, url=image_url)

# Create PipelineResultsResponse
pipeline_response = PipelineResultsResponse(
pipeline=pipeline,
source_images=[source_image],
detections=image_detections[image_id],
total_time=batch_elapsed
/ len(image_ids), # Approximate time per image
)

batch_results.append(
AntennaTaskResult(
reply_subject=reply_subject,
result=pipeline_response,
batch_results.append(
AntennaTaskResult(
reply_subject=reply_subject,
result=pipeline_response,
)
)
)
except Exception as e:
logger.error(f"Batch {i + 1} failed during processing: {e}", exc_info=True)
# Report errors back to Antenna so tasks aren't stuck in the queue
batch_results = []
for reply_subject, image_id in zip(reply_subjects, image_ids):
batch_results.append(
AntennaTaskResult(
reply_subject=reply_subject,
result=AntennaTaskResultError(
error=f"Batch processing error: {e}",
image_id=image_id,
),
)
)

failed_items = batch.get("failed_items")
if failed_items:
for failed_item in failed_items:
Expand All @@ -236,12 +275,10 @@ def _process_job(
st, t = t("Finished posting results")

if not success:
error_msg = (
logger.error(
f"Failed to post {len(batch_results)} results for job {job_id} to "
f"{settings.antenna_api_base_url}. Batch processing data lost."
)
logger.error(error_msg)
raise RuntimeError(error_msg)

total_save_time += st

Expand Down