Merged
Commits (17)
456b025
fix: return images as list in rest_collate_fn to support variable sizes
mihow Feb 10, 2026
54d9944
fix: handle batch processing errors per-batch instead of crashing job
mihow Feb 10, 2026
7e289b4
perf: batch classification crops in worker instead of N individual GP…
mihow Feb 10, 2026
23a90e3
perf: increase default localization_batch_size and num_workers for GPU
mihow Feb 10, 2026
8557298
feat: spawn one worker process per GPU for multi-GPU inference
mihow Feb 10, 2026
0627bd7
fix: clarify batch size log messages to distinguish worker vs inference
mihow Feb 10, 2026
123b24e
perf: increase DataLoader prefetch_factor to 4 for worker
mihow Feb 10, 2026
96b19f6
perf: download images concurrently within each DataLoader worker
mihow Feb 10, 2026
57bfed2
docs: clarify types of workers
mihow Feb 12, 2026
d9ae653
refactor: make ThreadPoolExecutor a class member, simplify with map()
mihow Feb 13, 2026
7abad40
docs: fix stale docstring in rest_collate_fn
mihow Feb 13, 2026
4baa3d7
fix: defensive batch_results init and strict=True in error handler
mihow Feb 13, 2026
27d9da2
revert: remove prefetch_factor=4 override, use PyTorch default
mihow Feb 13, 2026
3319c60
docs: document data loading pipeline concurrency layers
mihow Feb 13, 2026
ed49153
perf: tune defaults for higher GPU utilization
mihow Feb 13, 2026
5163e44
fix: lazily init unpicklable objects in RESTDataset for num_workers>0
mihow Feb 17, 2026
0857522
docs: fix stale num_workers default in datasets.py docstring (4 → 2)
mihow Feb 17, 2026
197 changes: 162 additions & 35 deletions trapdata/antenna/datasets.py
@@ -1,6 +1,69 @@
"""Dataset classes for streaming tasks from the Antenna API."""
"""Dataset and DataLoader for streaming tasks from the Antenna API.

Data loading pipeline overview
==============================

The pipeline has three layers of concurrency. Each layer is controlled by a
different setting and targets a different bottleneck.

::

┌──────────────────────────────────────────────────────────────────┐
│ GPU process (_worker_loop in worker.py) │
│ One per GPU. Runs detection → classification on batches. │
│ Controlled by: automatic (one per torch.cuda.device_count()) │
├──────────────────────────────────────────────────────────────────┤
│ DataLoader workers (num_workers subprocesses) │
│ Each subprocess runs its own RESTDataset.__iter__ loop: │
│ 1. GET /tasks → fetch batch of task metadata from Antenna │
│ 2. Download images (threaded, see below) │
│ 3. Yield individual (image_tensor, metadata) rows │
│ The DataLoader collates rows into GPU-sized batches. │
│ Controlled by: settings.num_workers (AMI_NUM_WORKERS) │
│ Default: 2. Safe >0 because Antenna dequeues atomically. │
├──────────────────────────────────────────────────────────────────┤
│ Thread pool (ThreadPoolExecutor inside each DataLoader worker) │
│ Downloads images concurrently *within* one API fetch batch. │
│ Each thread: HTTP GET → PIL open → RGB convert → ToTensor(). │
│ Controlled by: ThreadPoolExecutor(max_workers=8) on the class. │
│ Note: RGB conversion and ToTensor are GIL-bound (CPU). Only │
│ the network wait truly runs in parallel. A future optimisation │
│ could move transforms out of the thread. │
└──────────────────────────────────────────────────────────────────┘

Settings quick-reference (prefix with AMI_ as env vars):

localization_batch_size (default 8)
How many images the GPU processes at once (detection). Larger =
more GPU memory. These are full-resolution images (~4K).

num_workers (default 2)
DataLoader subprocesses. Each independently fetches tasks and
downloads images. More workers = more images prefetched for the
GPU, at the cost of CPU/RAM. With 0 workers, fetching and
inference are sequential (useful for debugging).

antenna_api_batch_size (default 16)
How many task URLs to request from Antenna per API call.
Determines how many images are downloaded concurrently per
thread pool invocation. Should be >= localization_batch_size
so one API call can fill at least one GPU batch without an
extra round trip.

prefetch_factor (PyTorch default: 2 when num_workers > 0)
Batches prefetched per worker. Not overridden here — the
default was tested and no improvement was measured by
increasing it (it just adds memory pressure).

What has NOT been benchmarked yet (as of 2026-02):
- Optimal num_workers / thread count combination
- Whether moving transforms out of threads helps throughput
- Whether multiple DataLoader workers + threads overlap well
or contend on the GIL
"""

import typing
from concurrent.futures import ThreadPoolExecutor
from io import BytesIO

import requests
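
A minimal sketch of the per-GPU layer described in the module docstring above, assuming a hypothetical _worker_loop(job_id, device) entry point (the real signature lives in worker.py and may differ)::

    import torch
    import torch.multiprocessing as mp

    from trapdata.antenna.worker import _worker_loop  # assumed import path

    def spawn_gpu_workers(job_id: int) -> None:
        # One process per visible GPU; fall back to a single CPU process.
        n = torch.cuda.device_count() or 1
        ctx = mp.get_context("spawn")  # CUDA state does not survive fork
        procs = [
            ctx.Process(target=_worker_loop, args=(job_id, f"cuda:{i}"))
            for i in range(n)
        ]
        for p in procs:
            p.start()
        for p in procs:
            p.join()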
@@ -32,9 +95,9 @@ class RESTDataset(torch.utils.data.IterableDataset):
DataLoader workers are SAFE and won't process duplicate tasks. Each worker
independently fetches different tasks from the shared queue.

With num_workers > 0:
Worker 1: GET /tasks → receives [1,2,3,4], removed from queue
Worker 2: GET /tasks → receives [5,6,7,8], removed from queue
With DataLoader num_workers > 0 (I/O subprocesses, not AMI instances):
Subprocess 1: GET /tasks → receives [1,2,3,4], removed from queue
Subprocess 2: GET /tasks → receives [5,6,7,8], removed from queue
No duplicates, safe for parallel processing
"""

@@ -58,20 +121,40 @@ def __init__(
"""
super().__init__()
self.base_url = base_url
self.auth_token = auth_token
self.job_id = job_id
self.batch_size = batch_size
self.image_transforms = image_transforms or torchvision.transforms.ToTensor()

# Create persistent sessions for connection pooling
self.api_session = get_http_session(auth_token)
self.image_fetch_session = get_http_session() # No auth for external image URLs
# These are created lazily in _ensure_sessions() because they contain
# unpicklable objects (ThreadPoolExecutor has a SimpleQueue) and
# PyTorch DataLoader with num_workers>0 pickles the dataset to send
# it to worker subprocesses.
self._api_session: requests.Session | None = None
self._image_fetch_session: requests.Session | None = None
self._executor: ThreadPoolExecutor | None = None

def _ensure_sessions(self) -> None:
"""Lazily create HTTP sessions and thread pool.

Called once per worker process on first use. This avoids pickling
issues with num_workers > 0 (SimpleQueue, socket objects, etc.).
"""
if self._api_session is None:
self._api_session = get_http_session(self.auth_token)
if self._image_fetch_session is None:
self._image_fetch_session = get_http_session()
if self._executor is None:
self._executor = ThreadPoolExecutor(max_workers=8)

def __del__(self):
"""Clean up HTTP sessions on dataset destruction."""
if hasattr(self, "api_session"):
self.api_session.close()
if hasattr(self, "image_fetch_session"):
self.image_fetch_session.close()
"""Clean up HTTP sessions and thread pool on dataset destruction."""
if self._executor is not None:
self._executor.shutdown(wait=False)
if self._api_session is not None:
self._api_session.close()
if self._image_fetch_session is not None:
self._image_fetch_session.close()

def _fetch_tasks(self) -> list[AntennaPipelineProcessingTask]:
"""
@@ -86,16 +169,22 @@ def _fetch_tasks(self) -> list[AntennaPipelineProcessingTask]:
url = f"{self.base_url.rstrip('/')}/jobs/{self.job_id}/tasks"
params = {"batch": self.batch_size}

response = self.api_session.get(url, params=params, timeout=30)
self._ensure_sessions()
assert self._api_session is not None
response = self._api_session.get(url, params=params, timeout=30)
response.raise_for_status()

# Parse and validate response with Pydantic
tasks_response = AntennaTasksListResponse.model_validate(response.json())
return tasks_response.tasks # Empty list is valid (queue drained)

def _load_image(self, image_url: str) -> torch.Tensor | None:
"""
Load an image from a URL and convert it to a PyTorch tensor.
"""Load an image from a URL and convert it to a PyTorch tensor.

Called from threads inside ``_load_images_threaded``. The HTTP
fetch is truly concurrent (network I/O releases the GIL), but
PIL decode, RGB conversion, and ``image_transforms`` (ToTensor)
are CPU-bound and serialised by the GIL.

Args:
image_url: URL of the image to load
@@ -105,7 +194,9 @@ def _load_image(self, image_url: str) -> torch.Tensor | None:
"""
try:
# Use dedicated session without auth for external images
response = self.image_fetch_session.get(image_url, timeout=30)
self._ensure_sessions()
assert self._image_fetch_session is not None
response = self._image_fetch_session.get(image_url, timeout=30)
response.raise_for_status()
image = Image.open(BytesIO(response.content))
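
One possible shape for the "move transforms out of the thread" idea flagged in the module docstring: threads fetch raw bytes only, and the CPU-bound decode happens outside the pool. A hypothetical, unbenchmarked sketch::

    def _fetch_bytes(self, image_url: str) -> bytes | None:
        # Network-only work; the GIL is released while waiting on I/O.
        response = self._image_fetch_session.get(image_url, timeout=30)
        response.raise_for_status()
        return response.content

    def _decode(self, raw: bytes) -> torch.Tensor:
        # CPU-bound decode and transform, kept out of the thread pool.
        image = Image.open(BytesIO(raw)).convert("RGB")
        return self.image_transforms(image)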

Expand All @@ -120,17 +211,51 @@ def _load_image(self, image_url: str) -> torch.Tensor | None:
logger.error(f"Failed to load image from {image_url}: {e}")
return None

def _load_images_threaded(
self,
tasks: list[AntennaPipelineProcessingTask],
) -> dict[str, torch.Tensor | None]:
"""Download images for a batch of tasks using concurrent threads.

Image downloads are I/O-bound (network latency, not CPU), so threads
provide near-linear speedup without the overhead of extra processes.
Note: ``requests.Session`` is not formally thread-safe, but the
underlying urllib3 connection pool handles concurrent socket access.
In practice shared read-only sessions work fine for GET requests;
if issues arise, switch to per-thread sessions.

Args:
tasks: List of tasks whose images should be downloaded.

Returns:
Mapping from image_id to tensor (or None on failure), preserving
the order needed by the caller.
"""

def _download(
task: AntennaPipelineProcessingTask,
) -> tuple[str, torch.Tensor | None]:
tensor = self._load_image(task.image_url) if task.image_url else None
return (task.image_id, tensor)

self._ensure_sessions()
assert self._executor is not None
return dict(self._executor.map(_download, tasks))

def __iter__(self):
"""
Iterate over tasks from the REST API.

Each API fetch returns a batch of tasks. Images for the entire batch
are downloaded concurrently using threads (see _load_images_threaded),
then yielded one at a time for the DataLoader to collate.

Yields:
Dictionary containing:
- image: PyTorch tensor of the loaded image
- reply_subject: Reply subject for the task
- batch_index: Index of the image in the batch
- job_id: Job ID
- image_id: Image ID
- image_url: Source URL
"""
worker_id = 0 # Initialize before try block to avoid UnboundLocalError
try:
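
If the shared session ever misbehaves under concurrency, the per-thread fallback mentioned in _load_images_threaded could look like this (hypothetical sketch)::

    import threading

    _thread_local = threading.local()

    def _per_thread_session(self) -> requests.Session:
        # Each pool thread lazily creates its own session, trading some
        # connection reuse for strict isolation between threads.
        if not hasattr(_thread_local, "session"):
            _thread_local.session = get_http_session()
        return _thread_local.session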
@@ -140,7 +265,7 @@ def __iter__(self):
num_workers = worker_info.num_workers if worker_info else 1

logger.info(
f"Worker {worker_id}/{num_workers} starting iteration for job {self.job_id}"
f"DataLoader subprocess {worker_id}/{num_workers} starting iteration for job {self.job_id}"
)

while True:
@@ -160,14 +285,12 @@ def __iter__(self):
)
break

# Download all images concurrently
image_map = self._load_images_threaded(tasks)

for task in tasks:
image_tensor = image_map.get(task.image_id)
errors = []
# Load the image
# _, t = log_time()
image_tensor = (
self._load_image(task.image_url) if task.image_url else None
)
# _, t = t(f"Loaded image from {image_url}")

if image_tensor is None:
errors.append("failed to load image")
@@ -199,7 +322,7 @@ def rest_collate_fn(batch: list[dict]) -> dict:
Custom collate function that separates failed and successful items.

Returns a dict with:
- images: Stacked tensor of valid images (only present if there are successful items)
- images: List of image tensors (only present if there are successful items)
- reply_subjects: List of reply subjects for valid images
- image_ids: List of image IDs for valid images
- image_urls: List of image URLs for valid images
@@ -231,7 +354,7 @@ def rest_collate_fn(batch: list[dict]) -> dict:
# Collate successful items
if successful:
result = {
"images": torch.stack([item["image"] for item in successful]),
"images": [item["image"] for item in successful],
"reply_subjects": [item["reply_subject"] for item in successful],
"image_ids": [item["image_id"] for item in successful],
"image_urls": [item.get("image_url") for item in successful],
@@ -252,18 +375,22 @@ def get_rest_dataloader(
job_id: int,
settings: "Settings",
) -> torch.utils.data.DataLoader:
"""
Create a DataLoader that fetches tasks from Antenna API.
"""Create a DataLoader that fetches tasks from Antenna API.

See the module docstring for an overview of the three concurrency
layers (GPU processes → DataLoader workers → thread pool) and which
settings control each.

Note: num_workers > 0 is SAFE here (unlike local file reading) because:
- Antenna API provides atomic task dequeue (work queue pattern)
- No shared file handles between workers
- Each worker gets different tasks automatically
- Parallel downloads improve throughput for I/O-bound work
DataLoader num_workers > 0 is safe here because Antenna dequeues
tasks atomically — each worker subprocess gets a unique set of tasks.

Args:
job_id: Job ID to fetch tasks for
settings: Settings object with antenna_api_* configuration
settings: Settings object. Relevant fields:
- antenna_api_base_url / antenna_api_auth_token
- antenna_api_batch_size (tasks per API call)
- localization_batch_size (images per GPU batch)
- num_workers (DataLoader subprocesses)
"""
dataset = RESTDataset(
base_url=settings.antenna_api_base_url,
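
A hypothetical call site for the returned DataLoader, consuming the keys produced by rest_collate_fn (job_id and settings values are placeholders)::

    loader = get_rest_dataloader(job_id=123, settings=settings)
    for batch in loader:
        # "images" is present only when at least one item succeeded, and
        # is now a list of CHW tensors rather than one stacked tensor.
        for image, image_id in zip(
            batch.get("images", []), batch.get("image_ids", [])
        ):
            ...  # run detection per image (or re-batch by matching size)
        for failed in batch["failed_items"]:
            ...  # report per-item errors back to Antenna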
29 changes: 27 additions & 2 deletions trapdata/antenna/tests/test_worker.py
@@ -34,6 +34,29 @@
# ---------------------------------------------------------------------------


class TestDataLoaderMultiWorker(TestCase):
"""DataLoader with num_workers > 0 must be able to start workers."""

def test_dataloader_starts_with_num_workers(self):
"""Creating an iterator pickles the dataset to send to worker subprocesses."""
dataset = RESTDataset(
base_url="http://localhost:1/api/v2",
auth_token="test-token",
job_id=1,
batch_size=4,
)
loader = torch.utils.data.DataLoader(
dataset,
batch_size=2,
num_workers=2,
collate_fn=rest_collate_fn,
)
# iter() pickles the dataset and spawns workers.
# If the dataset has unpicklable attributes this raises TypeError.
it = iter(loader)
del it


class TestRestCollateFn(TestCase):
"""Tests for rest_collate_fn which separates successful/failed items."""

@@ -55,7 +78,8 @@ def test_all_successful(self):
result = rest_collate_fn(batch)

assert "images" in result
assert result["images"].shape == (2, 3, 64, 64)
assert len(result["images"]) == 2
assert result["images"][0].shape == (3, 64, 64)
assert result["image_ids"] == ["img1", "img2"]
assert result["reply_subjects"] == ["subj1", "subj2"]
assert result["failed_items"] == []
@@ -104,7 +128,8 @@ def test_mixed(self):
]
result = rest_collate_fn(batch)

assert result["images"].shape == (1, 3, 64, 64)
assert len(result["images"]) == 1
assert result["images"][0].shape == (3, 64, 64)
assert result["image_ids"] == ["img1"]
assert len(result["failed_items"]) == 1
assert result["failed_items"][0]["image_id"] == "img2"