273 changes: 272 additions & 1 deletion trapdata/api/datasets.py
@@ -1,12 +1,25 @@
import os
import time
import typing
from io import BytesIO

import requests
import torch
import torch.utils.data
import torchvision
from PIL import Image

from trapdata.common.logs import logger

from .schemas import DetectionResponse, SourceImage
from .schemas import (
AntennaPipelineProcessingTask,
AntennaTasksListResponse,
DetectionResponse,
SourceImage,
)

if typing.TYPE_CHECKING:
from trapdata.settings import Settings


class LocalizationImageDataset(torch.utils.data.Dataset):
@@ -87,3 +100,261 @@ def __getitem__(self, idx):

# return (ids_batch, image_batch)
return (source_image.id, detection_idx), image_data


class RESTDataset(torch.utils.data.IterableDataset):
"""
An IterableDataset that fetches tasks from a REST API endpoint and loads images.

The dataset continuously polls the API for tasks, loads the associated images,
and yields them as PyTorch tensors along with metadata.

IMPORTANT: This dataset assumes the API endpoint atomically removes tasks from
the queue when fetched (like RabbitMQ, SQS, Redis LPOP). This means multiple
DataLoader workers are SAFE and won't process duplicate tasks. Each worker
independently fetches different tasks from the shared queue.

With num_workers > 0:
Worker 1: GET /tasks → receives [1,2,3,4], removed from queue
Worker 2: GET /tasks → receives [5,6,7,8], removed from queue
No duplicates, safe for parallel processing
"""

def __init__(
self,
base_url: str,
job_id: int,
batch_size: int = 1,
image_transforms: typing.Optional[torchvision.transforms.Compose] = None,
auth_token: typing.Optional[str] = None,
):
"""
Initialize the REST dataset.

Args:
base_url: Base URL for the API including /api/v2 (e.g., "http://localhost:8000/api/v2")
job_id: The job ID to fetch tasks for
batch_size: Number of tasks to request per batch
image_transforms: Optional transforms to apply to loaded images
auth_token: API authentication token
"""
super().__init__()
self.base_url = base_url
self.job_id = job_id
self.batch_size = batch_size
self.image_transforms = image_transforms or torchvision.transforms.ToTensor()
self.auth_token = auth_token or os.environ.get("ANTENNA_API_TOKEN")

def _fetch_tasks(self) -> list[AntennaPipelineProcessingTask]:
"""
Fetch a batch of tasks from the REST API.

Returns:
List of tasks (possibly empty if queue is drained)

Raises:
requests.RequestException: If the request fails (network error, etc.)
"""
url = f"{self.base_url.rstrip('/')}/jobs/{self.job_id}/tasks"
params = {"batch": self.batch_size}

headers = {}
if self.auth_token:
headers["Authorization"] = f"Token {self.auth_token}"

response = requests.get(
url,
params=params,
timeout=30,
headers=headers,
)
response.raise_for_status()

# Parse and validate response with Pydantic
tasks_response = AntennaTasksListResponse.model_validate(response.json())
return tasks_response.tasks # Empty list is valid (queue drained)

def _load_image(self, image_url: str) -> typing.Optional[torch.Tensor]:
"""
Load an image from a URL and convert it to a PyTorch tensor.

Args:
image_url: URL of the image to load

Returns:
Image as a PyTorch tensor, or None if loading failed
"""
try:
response = requests.get(image_url, timeout=30)
response.raise_for_status()
image = Image.open(BytesIO(response.content))

# Convert to RGB if necessary
if image.mode != "RGB":
image = image.convert("RGB")

# Apply transforms
image_tensor = self.image_transforms(image)
return image_tensor
except Exception as e:
logger.error(f"Failed to load image from {image_url}: {e}")
return None

def __iter__(self):
"""
Iterate over tasks from the REST API.

Yields:
Dictionary containing:
- image: PyTorch tensor of the loaded image (None if loading failed)
- reply_subject: Reply subject for the task
- image_id: Image ID
- image_url: Image URL
- error: Error description (only present if loading failed)
"""
try:
# Get worker info for debugging
worker_info = torch.utils.data.get_worker_info()
worker_id = worker_info.id if worker_info else 0
num_workers = worker_info.num_workers if worker_info else 1

logger.info(
f"Worker {worker_id}/{num_workers} starting iteration for job {self.job_id}"
)

while True:
try:
tasks = self._fetch_tasks()
except requests.RequestException as e:
# Fetch failed - retry after delay
logger.warning(
f"Worker {worker_id}: Fetch failed ({e}), retrying in 5s"
)
time.sleep(5)
continue

if not tasks:
# Queue is empty - job complete
logger.info(
f"Worker {worker_id}: No more tasks for job {self.job_id}"
)
break

for task in tasks:
errors = []
# Load the image
# _, t = log_time()
image_tensor = (
self._load_image(task.image_url) if task.image_url else None
)
# _, t = t(f"Loaded image from {image_url}")

if image_tensor is None:
errors.append("failed to load image")

if errors:
logger.warning(
f"Worker {worker_id}: Errors in task for image '{task.image_id}': {', '.join(errors)}"
)

# Yield the data row
row = {
"image": image_tensor,
"reply_subject": task.reply_subject,
"image_id": task.image_id,
"image_url": task.image_url,
}
if errors:
row["error"] = "; ".join(errors) if errors else None
yield row

logger.info(f"Worker {worker_id}: Iterator finished")
except Exception as e:
logger.error(f"Worker {worker_id}: Exception in iterator: {e}")
raise


def rest_collate_fn(batch: list[dict]) -> dict:
"""
Custom collate function that separates failed and successful items.

Returns a dict with:
- images: Stacked tensor of valid images (only present if there are successful items)
- reply_subjects: List of reply subjects for valid images
- image_ids: List of image IDs for valid images
- image_urls: List of image URLs for valid images
- failed_items: List of dicts with metadata for failed items

When all items in the batch have failed, the returned dict will only contain:
- reply_subjects: empty list
- image_ids: empty list
- failed_items: list of failure metadata
"""
successful = []
failed = []

for item in batch:
if item["image"] is None or item.get("error"):
# Failed item
failed.append(
{
"reply_subject": item["reply_subject"],
"image_id": item["image_id"],
"image_url": item.get("image_url"),
"error": item.get("error", "Unknown error"),
}
)
else:
# Successful item
successful.append(item)

# Collate successful items
if successful:
result = {
"images": torch.stack([item["image"] for item in successful]),
"reply_subjects": [item["reply_subject"] for item in successful],
"image_ids": [item["image_id"] for item in successful],
"image_urls": [item.get("image_url") for item in successful],
}
else:
# Empty batch - all failed
result = {
"reply_subjects": [],
"image_ids": [],
}

result["failed_items"] = failed

return result


def get_rest_dataloader(
job_id: int,
settings: "Settings",
) -> torch.utils.data.DataLoader:
"""
Create a DataLoader that fetches tasks from Antenna API.

Note: num_workers > 0 is SAFE here (unlike local file reading) because:
- Antenna API provides atomic task dequeue (work queue pattern)
- No shared file handles between workers
- Each worker gets different tasks automatically
- Parallel downloads improve throughput for I/O-bound work

Args:
job_id: Job ID to fetch tasks for
settings: Settings object with antenna_api_* configuration
"""
dataset = RESTDataset(
base_url=settings.antenna_api_base_url,
job_id=job_id,
batch_size=settings.antenna_api_batch_size,
auth_token=settings.antenna_api_auth_token,
)

return torch.utils.data.DataLoader(
dataset,
batch_size=settings.localization_batch_size,
num_workers=settings.num_workers,
collate_fn=rest_collate_fn,
)
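For context, here is a minimal usage sketch (not part of this diff) showing how the dataloader and rest_collate_fn output above might be consumed. It assumes a `settings` object exposing the antenna_api_* and dataloader fields referenced in get_rest_dataloader, and a hypothetical `run_inference` callable:

# Usage sketch only -- not part of this PR. `run_inference` and the settings
# fields are assumptions based on the code above.
from trapdata.api.datasets import get_rest_dataloader
from trapdata.common.logs import logger

def process_job(job_id, settings, run_inference):
    dataloader = get_rest_dataloader(job_id=job_id, settings=settings)
    for batch in dataloader:
        # rest_collate_fn separates items whose image download failed.
        for failed in batch["failed_items"]:
            logger.warning(f"Skipping {failed['image_id']}: {failed['error']}")
        # "images" is only present when at least one item in the batch succeeded.
        if "images" in batch:
            outputs = run_inference(batch["images"])
            for reply_subject, image_id, output in zip(
                batch["reply_subjects"], batch["image_ids"], outputs
            ):
                ...  # post `output` back for `image_id` via `reply_subject`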
43 changes: 31 additions & 12 deletions trapdata/api/models/classification.py
@@ -54,6 +54,10 @@ def __init__(
"detections"
)

def reset(self, detections: typing.Iterable[DetectionResponse]):
self.detections = list(detections)
self.results = []

def get_dataset(self):
return ClassificationImageDataset(
source_images=self.source_images,
@@ -117,19 +121,12 @@ def save_results(
for image_id, detection_idx, predictions in zip(
image_ids, detection_idxes, batch_output
):
detection = self.detections[detection_idx]
assert detection.source_image_id == image_id

classification = ClassificationResponse(
classification=self.get_best_label(predictions),
scores=predictions.scores,
logits=predictions.logit,
inference_time=seconds_per_item,
algorithm=AlgorithmReference(name=self.name, key=self.get_key()),
timestamp=datetime.datetime.now(),
terminal=self.terminal,
self.update_detection_classification(
seconds_per_item,
image_id,
detection_idx,
predictions,
)
self.update_classification(detection, classification)

self.results = self.detections
logger.info(f"Saving {len(self.results)} detections with classifications")
@@ -149,6 +146,28 @@ def update_classification(
f"Total classifications: {len(detection.classifications)}"
)

def update_detection_classification(
self,
seconds_per_item: float,
image_id: str,
detection_idx: int,
predictions: ClassifierResult,
) -> DetectionResponse:
detection = self.detections[detection_idx]
assert detection.source_image_id == image_id

classification = ClassificationResponse(
classification=self.get_best_label(predictions),
scores=predictions.scores,
logits=predictions.logit,
inference_time=seconds_per_item,
algorithm=AlgorithmReference(name=self.name, key=self.get_key()),
timestamp=datetime.datetime.now(),
terminal=self.terminal,
)
self.update_classification(detection, classification)
return detection

def run(self) -> list[DetectionResponse]:
logger.info(
f"Starting {self.__class__.__name__} run with {len(self.results)} "
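A brief sketch (not part of this diff) of what the new reset() helper enables: a classifier instance, and its loaded weights, can be pointed at a new set of detections without being reconstructed. The class name and constructor arguments below are assumptions:

# Hypothetical reuse pattern; `SomeClassifier` stands in for any classifier in
# this module that gains the reset() method added above.
classifier = SomeClassifier(source_images=source_images, detections=first_detections)
first_results = classifier.run()

# Point the same instance at new detections without reloading the model.
classifier.reset(new_detections)
next_results = classifier.run()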
17 changes: 7 additions & 10 deletions trapdata/api/models/localization.py
@@ -1,4 +1,3 @@
import concurrent.futures
import datetime
import typing

@@ -17,6 +16,10 @@ def __init__(self, source_images: typing.Iterable[SourceImage], *args, **kwargs)
self.results: list[DetectionResponse] = []
super().__init__(*args, **kwargs)

def reset(self, source_images: typing.Iterable[SourceImage]):
self.source_images = source_images
self.results = []

def get_dataset(self):
return LocalizationImageDataset(
self.source_images, self.get_transforms(), batch_size=self.batch_size
@@ -43,15 +46,9 @@ def save_detection(image_id, coords):
)
return detection

with concurrent.futures.ThreadPoolExecutor() as executor:
futures = []
for image_id, image_output in zip(item_ids, batch_output):
for coords in image_output:
future = executor.submit(save_detection, image_id, coords)
futures.append(future)

for future in concurrent.futures.as_completed(futures):
detection = future.result()
for image_id, image_output in zip(item_ids, batch_output):
for coords in image_output:
detection = save_detection(image_id, coords)
detections.append(detection)

self.results += detections
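The localization change mirrors the classification one: reset() lets a detector instance be reused for a new batch of source images, and detections are now collected in a plain sequential loop rather than a thread pool. A short hypothetical sketch; the detector class name and its run() entry point are assumed, not shown in this diff:

# Hypothetical reuse pattern; `SomeDetector` is a placeholder for the
# localization model class that gains the reset() method added above.
detector = SomeDetector(source_images=first_batch)
detections = detector.run()

# Reuse the same loaded model for the next batch of source images.
detector.reset(next_batch)
more_detections = detector.run()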