Merged
Changes from 9 of 17 commits

Commits
456b025
fix: return images as list in rest_collate_fn to support variable sizes
mihow Feb 10, 2026
54d9944
fix: handle batch processing errors per-batch instead of crashing job
mihow Feb 10, 2026
7e289b4
perf: batch classification crops in worker instead of N individual GP…
mihow Feb 10, 2026
23a90e3
perf: increase default localization_batch_size and num_workers for GPU
mihow Feb 10, 2026
8557298
feat: spawn one worker process per GPU for multi-GPU inference
mihow Feb 10, 2026
0627bd7
fix: clarify batch size log messages to distinguish worker vs inference
mihow Feb 10, 2026
123b24e
perf: increase DataLoader prefetch_factor to 4 for worker
mihow Feb 10, 2026
96b19f6
perf: download images concurrently within each DataLoader worker
mihow Feb 10, 2026
57bfed2
docs: clarify types of workers
mihow Feb 12, 2026
d9ae653
refactor: make ThreadPoolExecutor a class member, simplify with map()
mihow Feb 13, 2026
7abad40
docs: fix stale docstring in rest_collate_fn
mihow Feb 13, 2026
4baa3d7
fix: defensive batch_results init and strict=True in error handler
mihow Feb 13, 2026
27d9da2
revert: remove prefetch_factor=4 override, use PyTorch default
mihow Feb 13, 2026
3319c60
docs: document data loading pipeline concurrency layers
mihow Feb 13, 2026
ed49153
perf: tune defaults for higher GPU utilization
mihow Feb 13, 2026
5163e44
fix: lazily init unpicklable objects in RESTDataset for num_workers>0
mihow Feb 17, 2026
0857522
docs: fix stale num_workers default in datasets.py docstring (4 → 2)
mihow Feb 17, 2026
86 changes: 64 additions & 22 deletions trapdata/antenna/datasets.py
@@ -1,6 +1,7 @@
 """Dataset classes for streaming tasks from the Antenna API."""
 
 import typing
+from concurrent.futures import ThreadPoolExecutor, as_completed
 from io import BytesIO
 
 import requests
@@ -32,9 +33,9 @@ class RESTDataset(torch.utils.data.IterableDataset):
     DataLoader workers are SAFE and won't process duplicate tasks. Each worker
     independently fetches different tasks from the shared queue.
 
-    With num_workers > 0:
-        Worker 1: GET /tasks → receives [1,2,3,4], removed from queue
-        Worker 2: GET /tasks → receives [5,6,7,8], removed from queue
+    With DataLoader num_workers > 0 (I/O subprocesses, not AMI instances):
+        Subprocess 1: GET /tasks → receives [1,2,3,4], removed from queue
+        Subprocess 2: GET /tasks → receives [5,6,7,8], removed from queue
         No duplicates, safe for parallel processing
     """

@@ -120,17 +121,55 @@ def _load_image(self, image_url: str) -> torch.Tensor | None:
             logger.error(f"Failed to load image from {image_url}: {e}")
             return None
 
+    def _load_images_threaded(
+        self,
+        tasks: list[AntennaPipelineProcessingTask],
+    ) -> dict[str, torch.Tensor | None]:
+        """Download images for a batch of tasks using concurrent threads.
+
+        Image downloads are I/O-bound (network latency, not CPU), so threads
+        provide near-linear speedup without the overhead of extra processes.
+        The HTTP session's connection pool is thread-safe and reuses TCP
+        connections across threads.
+
+        Args:
+            tasks: List of tasks whose images should be downloaded.
+
+        Returns:
+            Mapping from image_id to tensor (or None on failure), preserving
+            the order needed by the caller.
+        """
+        results: dict[str, torch.Tensor | None] = {}
+
+        def _download(
+            task: AntennaPipelineProcessingTask,
+        ) -> tuple[str, torch.Tensor | None]:
+            tensor = self._load_image(task.image_url) if task.image_url else None
+            return (task.image_id, tensor)
+
+        max_threads = min(len(tasks), 8)
+        with ThreadPoolExecutor(max_workers=max_threads) as executor:
+            futures = {executor.submit(_download, t): t for t in tasks}
+            for future in as_completed(futures):
+                image_id, tensor = future.result()
+                results[image_id] = tensor
+
+        return results
+
     def __iter__(self):
         """
         Iterate over tasks from the REST API.
 
+        Each API fetch returns a batch of tasks. Images for the entire batch
+        are downloaded concurrently using threads (see _load_images_threaded),
+        then yielded one at a time for the DataLoader to collate.
+
         Yields:
             Dictionary containing:
                 - image: PyTorch tensor of the loaded image
                 - reply_subject: Reply subject for the task
                 - batch_index: Index of the image in the batch
                 - job_id: Job ID
                 - image_id: Image ID
                 - image_url: Source URL
         """
         worker_id = 0  # Initialize before try block to avoid UnboundLocalError
         try:
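Commit d9ae653 in this PR later reworks the helper above ("make ThreadPoolExecutor a class member, simplify with map()"). A sketch of that shape, with details assumed from the commit message rather than copied from the final diff:

```python
from concurrent.futures import ThreadPoolExecutor

class RESTDatasetOutline:
    # Hypothetical outline; _load_image is defined on the real class.

    def __init__(self) -> None:
        # One pool per dataset instance, reused across batches instead of
        # being created and torn down on every call.
        self._executor = ThreadPoolExecutor(max_workers=8)

    def _load_images_threaded(self, tasks):
        def _download(task):
            tensor = self._load_image(task.image_url) if task.image_url else None
            return (task.image_id, tensor)

        # map() returns results in input order, so no as_completed
        # bookkeeping is needed to rebuild the mapping.
        return dict(self._executor.map(_download, tasks))
```

map() also re-raises a worker's exception when its result is consumed, which keeps error handling in one place.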
@@ -140,7 +179,7 @@ def __iter__(self):
             num_workers = worker_info.num_workers if worker_info else 1
 
             logger.info(
-                f"Worker {worker_id}/{num_workers} starting iteration for job {self.job_id}"
+                f"DataLoader subprocess {worker_id}/{num_workers} starting iteration for job {self.job_id}"
             )
 
             while True:
@@ -160,14 +199,12 @@ def __iter__(self):
                     )
                     break
 
+                # Download all images concurrently
+                image_map = self._load_images_threaded(tasks)
+
                 for task in tasks:
+                    image_tensor = image_map.get(task.image_id)
                     errors = []
-                    # Load the image
-                    # _, t = log_time()
-                    image_tensor = (
-                        self._load_image(task.image_url) if task.image_url else None
-                    )
-                    # _, t = t(f"Loaded image from {image_url}")
 
                     if image_tensor is None:
                         errors.append("failed to load image")
@@ -231,7 +268,7 @@ def rest_collate_fn(batch: list[dict]) -> dict:
     # Collate successful items
     if successful:
         result = {
-            "images": torch.stack([item["image"] for item in successful]),
+            "images": [item["image"] for item in successful],
             "reply_subjects": [item["reply_subject"] for item in successful],
             "image_ids": [item["image_id"] for item in successful],
            "image_urls": [item.get("image_url") for item in successful],
@@ -255,10 +292,10 @@ def get_rest_dataloader(
     """
     Create a DataLoader that fetches tasks from Antenna API.
 
-    Note: num_workers > 0 is SAFE here (unlike local file reading) because:
+    Note: DataLoader num_workers > 0 is SAFE here (unlike local file reading) because:
     - Antenna API provides atomic task dequeue (work queue pattern)
-    - No shared file handles between workers
-    - Each worker gets different tasks automatically
+    - No shared file handles between subprocesses
+    - Each subprocess gets different tasks automatically
     - Parallel downloads improve throughput for I/O-bound work
 
     Args:
@@ -272,9 +309,14 @@
         batch_size=settings.antenna_api_batch_size,
     )
 
-    return torch.utils.data.DataLoader(
-        dataset,
-        batch_size=settings.localization_batch_size,
-        num_workers=settings.num_workers,
-        collate_fn=rest_collate_fn,
-    )
+    dataloader_kwargs: dict = {
+        "batch_size": settings.localization_batch_size,
+        "num_workers": settings.num_workers,
+        "collate_fn": rest_collate_fn,
+    }
+    if settings.num_workers > 0:
+        # Prefetch more batches so the next batch is already downloading
+        # while the GPU processes the current one. Default is 2.
+        dataloader_kwargs["prefetch_factor"] = 4
+
+    return torch.utils.data.DataLoader(dataset, **dataloader_kwargs)
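On the prefetch override: PyTorch's prefetch_factor counts batches per worker, so total read-ahead scales with num_workers. A self-contained illustration with assumed values:

```python
import torch

class TinyDataset(torch.utils.data.Dataset):
    def __len__(self) -> int:
        return 64

    def __getitem__(self, idx: int) -> torch.Tensor:
        return torch.zeros(3, 8, 8)

# With 2 workers * prefetch_factor 4, up to 8 batches are buffered ahead
# of the consumer, keeping downloads overlapped with GPU inference.
# (PyTorch's default is prefetch_factor=2 when num_workers > 0.)
loader = torch.utils.data.DataLoader(
    TinyDataset(),
    batch_size=4,
    num_workers=2,
    prefetch_factor=4,
)
```

Note that the commit list above includes 27d9da2 ("revert: remove prefetch_factor=4 override, use PyTorch default"); the hunk shown here reflects only the first 9 commits.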
6 changes: 4 additions & 2 deletions trapdata/antenna/tests/test_worker.py
@@ -55,7 +55,8 @@ def test_all_successful(self):
         result = rest_collate_fn(batch)
 
         assert "images" in result
-        assert result["images"].shape == (2, 3, 64, 64)
+        assert len(result["images"]) == 2
+        assert result["images"][0].shape == (3, 64, 64)
         assert result["image_ids"] == ["img1", "img2"]
         assert result["reply_subjects"] == ["subj1", "subj2"]
         assert result["failed_items"] == []
@@ -104,7 +105,8 @@ def test_mixed(self):
         ]
         result = rest_collate_fn(batch)
 
-        assert result["images"].shape == (1, 3, 64, 64)
+        assert len(result["images"]) == 1
+        assert result["images"][0].shape == (3, 64, 64)
         assert result["image_ids"] == ["img1"]
         assert len(result["failed_items"]) == 1
         assert result["failed_items"][0]["image_id"] == "img2"