diff --git a/trapdata/api/datasets.py b/trapdata/api/datasets.py index 57cf9ba1..472e4b7f 100644 --- a/trapdata/api/datasets.py +++ b/trapdata/api/datasets.py @@ -1,12 +1,25 @@ +import os +import time import typing +from io import BytesIO +import requests import torch import torch.utils.data import torchvision +from PIL import Image from trapdata.common.logs import logger -from .schemas import DetectionResponse, SourceImage +from .schemas import ( + AntennaPipelineProcessingTask, + AntennaTasksListResponse, + DetectionResponse, + SourceImage, +) + +if typing.TYPE_CHECKING: + from trapdata.settings import Settings class LocalizationImageDataset(torch.utils.data.Dataset): @@ -87,3 +100,261 @@ def __getitem__(self, idx): # return (ids_batch, image_batch) return (source_image.id, detection_idx), image_data + + +class RESTDataset(torch.utils.data.IterableDataset): + """ + An IterableDataset that fetches tasks from a REST API endpoint and loads images. + + The dataset continuously polls the API for tasks, loads the associated images, + and yields them as PyTorch tensors along with metadata. + + IMPORTANT: This dataset assumes the API endpoint atomically removes tasks from + the queue when fetched (like RabbitMQ, SQS, Redis LPOP). This means multiple + DataLoader workers are SAFE and won't process duplicate tasks. Each worker + independently fetches different tasks from the shared queue. + + With num_workers > 0: + Worker 1: GET /tasks → receives [1,2,3,4], removed from queue + Worker 2: GET /tasks → receives [5,6,7,8], removed from queue + No duplicates, safe for parallel processing + """ + + def __init__( + self, + base_url: str, + job_id: int, + batch_size: int = 1, + image_transforms: typing.Optional[torchvision.transforms.Compose] = None, + auth_token: typing.Optional[str] = None, + ): + """ + Initialize the REST dataset. + + Args: + base_url: Base URL for the API including /api/v2 (e.g., "http://localhost:8000/api/v2") + job_id: The job ID to fetch tasks for + batch_size: Number of tasks to request per batch + image_transforms: Optional transforms to apply to loaded images + auth_token: API authentication token + """ + super().__init__() + self.base_url = base_url + self.job_id = job_id + self.batch_size = batch_size + self.image_transforms = image_transforms or torchvision.transforms.ToTensor() + self.auth_token = auth_token or os.environ.get("ANTENNA_API_TOKEN") + + def _fetch_tasks(self) -> list[AntennaPipelineProcessingTask]: + """ + Fetch a batch of tasks from the REST API. + + Returns: + List of tasks (possibly empty if queue is drained) + + Raises: + requests.RequestException: If the request fails (network error, etc.) + """ + url = f"{self.base_url.rstrip('/')}/jobs/{self.job_id}/tasks" + params = {"batch": self.batch_size} + + headers = {} + if self.auth_token: + headers["Authorization"] = f"Token {self.auth_token}" + + response = requests.get( + url, + params=params, + timeout=30, + headers=headers, + ) + response.raise_for_status() + + # Parse and validate response with Pydantic + tasks_response = AntennaTasksListResponse.model_validate(response.json()) + return tasks_response.tasks # Empty list is valid (queue drained) + + def _load_image(self, image_url: str) -> typing.Optional[torch.Tensor]: + """ + Load an image from a URL and convert it to a PyTorch tensor. 
+ + Args: + image_url: URL of the image to load + + Returns: + Image as a PyTorch tensor, or None if loading failed + """ + try: + response = requests.get(image_url, timeout=30) + response.raise_for_status() + image = Image.open(BytesIO(response.content)) + + # Convert to RGB if necessary + if image.mode != "RGB": + image = image.convert("RGB") + + # Apply transforms + image_tensor = self.image_transforms(image) + return image_tensor + except Exception as e: + logger.error(f"Failed to load image from {image_url}: {e}") + return None + + def __iter__(self): + """ + Iterate over tasks from the REST API. + + Yields: + Dictionary containing: + - image: PyTorch tensor of the loaded image + - reply_subject: Reply subject for the task + - batch_index: Index of the image in the batch + - job_id: Job ID + - image_id: Image ID + """ + try: + # Get worker info for debugging + worker_info = torch.utils.data.get_worker_info() + worker_id = worker_info.id if worker_info else 0 + num_workers = worker_info.num_workers if worker_info else 1 + + logger.info( + f"Worker {worker_id}/{num_workers} starting iteration for job {self.job_id}" + ) + + while True: + try: + tasks = self._fetch_tasks() + except requests.RequestException as e: + # Fetch failed - retry after delay + logger.warning( + f"Worker {worker_id}: Fetch failed ({e}), retrying in 5s" + ) + time.sleep(5) + continue + + if not tasks: + # Queue is empty - job complete + logger.info( + f"Worker {worker_id}: No more tasks for job {self.job_id}" + ) + break + + for task in tasks: + errors = [] + # Load the image + # _, t = log_time() + image_tensor = ( + self._load_image(task.image_url) if task.image_url else None + ) + # _, t = t(f"Loaded image from {image_url}") + + if image_tensor is None: + errors.append("failed to load image") + + if errors: + logger.warning( + f"Worker {worker_id}: Errors in task for image '{task.image_id}': {', '.join(errors)}" + ) + + # Yield the data row + row = { + "image": image_tensor, + "reply_subject": task.reply_subject, + "image_id": task.image_id, + "image_url": task.image_url, + } + if errors: + row["error"] = "; ".join(errors) if errors else None + yield row + + logger.info(f"Worker {worker_id}: Iterator finished") + except Exception as e: + logger.error(f"Worker {worker_id}: Exception in iterator: {e}") + raise + + +def rest_collate_fn(batch: list[dict]) -> dict: + """ + Custom collate function that separates failed and successful items. 
+ + Returns a dict with: + - images: Stacked tensor of valid images (only present if there are successful items) + - reply_subjects: List of reply subjects for valid images + - image_ids: List of image IDs for valid images + - image_urls: List of image URLs for valid images + - failed_items: List of dicts with metadata for failed items + + When all items in the batch have failed, the returned dict will only contain: + - reply_subjects: empty list + - image_ids: empty list + - failed_items: list of failure metadata + """ + successful = [] + failed = [] + + for item in batch: + if item["image"] is None or item.get("error"): + # Failed item + failed.append( + { + "reply_subject": item["reply_subject"], + "image_id": item["image_id"], + "image_url": item.get("image_url"), + "error": item.get("error", "Unknown error"), + } + ) + else: + # Successful item + successful.append(item) + + # Collate successful items + if successful: + result = { + "images": torch.stack([item["image"] for item in successful]), + "reply_subjects": [item["reply_subject"] for item in successful], + "image_ids": [item["image_id"] for item in successful], + "image_urls": [item.get("image_url") for item in successful], + } + else: + # Empty batch - all failed + result = { + "reply_subjects": [], + "image_ids": [], + } + + result["failed_items"] = failed + + return result + + +def get_rest_dataloader( + job_id: int, + settings: "Settings", +) -> torch.utils.data.DataLoader: + """ + Create a DataLoader that fetches tasks from Antenna API. + + Note: num_workers > 0 is SAFE here (unlike local file reading) because: + - Antenna API provides atomic task dequeue (work queue pattern) + - No shared file handles between workers + - Each worker gets different tasks automatically + - Parallel downloads improve throughput for I/O-bound work + + Args: + job_id: Job ID to fetch tasks for + settings: Settings object with antenna_api_* configuration + """ + dataset = RESTDataset( + base_url=settings.antenna_api_base_url, + job_id=job_id, + batch_size=settings.antenna_api_batch_size, + auth_token=settings.antenna_api_auth_token, + ) + + return torch.utils.data.DataLoader( + dataset, + batch_size=settings.localization_batch_size, + num_workers=settings.num_workers, + collate_fn=rest_collate_fn, + ) diff --git a/trapdata/api/models/classification.py b/trapdata/api/models/classification.py index 482c4ac3..f000e0d9 100644 --- a/trapdata/api/models/classification.py +++ b/trapdata/api/models/classification.py @@ -54,6 +54,10 @@ def __init__( "detections" ) + def reset(self, detections: typing.Iterable[DetectionResponse]): + self.detections = list(detections) + self.results = [] + def get_dataset(self): return ClassificationImageDataset( source_images=self.source_images, @@ -117,19 +121,12 @@ def save_results( for image_id, detection_idx, predictions in zip( image_ids, detection_idxes, batch_output ): - detection = self.detections[detection_idx] - assert detection.source_image_id == image_id - - classification = ClassificationResponse( - classification=self.get_best_label(predictions), - scores=predictions.scores, - logits=predictions.logit, - inference_time=seconds_per_item, - algorithm=AlgorithmReference(name=self.name, key=self.get_key()), - timestamp=datetime.datetime.now(), - terminal=self.terminal, + self.update_detection_classification( + seconds_per_item, + image_id, + detection_idx, + predictions, ) - self.update_classification(detection, classification) self.results = self.detections logger.info(f"Saving {len(self.results)} detections 
with classifications") @@ -149,6 +146,28 @@ def update_classification( f"Total classifications: {len(detection.classifications)}" ) + def update_detection_classification( + self, + seconds_per_item: float, + image_id: str, + detection_idx: int, + predictions: ClassifierResult, + ) -> DetectionResponse: + detection = self.detections[detection_idx] + assert detection.source_image_id == image_id + + classification = ClassificationResponse( + classification=self.get_best_label(predictions), + scores=predictions.scores, + logits=predictions.logit, + inference_time=seconds_per_item, + algorithm=AlgorithmReference(name=self.name, key=self.get_key()), + timestamp=datetime.datetime.now(), + terminal=self.terminal, + ) + self.update_classification(detection, classification) + return detection + def run(self) -> list[DetectionResponse]: logger.info( f"Starting {self.__class__.__name__} run with {len(self.results)} " diff --git a/trapdata/api/models/localization.py b/trapdata/api/models/localization.py index 600fc9f7..9ec1acd5 100644 --- a/trapdata/api/models/localization.py +++ b/trapdata/api/models/localization.py @@ -1,4 +1,3 @@ -import concurrent.futures import datetime import typing @@ -17,6 +16,10 @@ def __init__(self, source_images: typing.Iterable[SourceImage], *args, **kwargs) self.results: list[DetectionResponse] = [] super().__init__(*args, **kwargs) + def reset(self, source_images: typing.Iterable[SourceImage]): + self.source_images = source_images + self.results = [] + def get_dataset(self): return LocalizationImageDataset( self.source_images, self.get_transforms(), batch_size=self.batch_size @@ -43,15 +46,9 @@ def save_detection(image_id, coords): ) return detection - with concurrent.futures.ThreadPoolExecutor() as executor: - futures = [] - for image_id, image_output in zip(item_ids, batch_output): - for coords in image_output: - future = executor.submit(save_detection, image_id, coords) - futures.append(future) - - for future in concurrent.futures.as_completed(futures): - detection = future.result() + for image_id, image_output in zip(item_ids, batch_output): + for coords in image_output: + detection = save_detection(image_id, coords) detections.append(detection) self.results += detections diff --git a/trapdata/api/schemas.py b/trapdata/api/schemas.py index a8b682ac..30fd6186 100644 --- a/trapdata/api/schemas.py +++ b/trapdata/api/schemas.py @@ -282,6 +282,38 @@ class PipelineResultsResponse(pydantic.BaseModel): config: PipelineConfigRequest = PipelineConfigRequest() +class AntennaPipelineProcessingTask(pydantic.BaseModel): + """ + A task representing a single image or detection to be processed in an async pipeline. + """ + + id: str + image_id: str + image_url: str + reply_subject: str | None = None # The NATS subject to send the result to + # TODO: Do we need these? 
+ # detections: list[DetectionRequest] | None = None + # config: PipelineRequestConfigParameters | dict | None = None + + +class AntennaJobListItem(pydantic.BaseModel): + """A single job item from the Antenna jobs list API response.""" + + id: int + + +class AntennaJobsListResponse(pydantic.BaseModel): + """Response from Antenna API GET /api/v2/jobs with ids_only=1.""" + + results: list[AntennaJobListItem] + + +class AntennaTasksListResponse(pydantic.BaseModel): + """Response from Antenna API GET /api/v2/jobs/{job_id}/tasks.""" + + tasks: list[AntennaPipelineProcessingTask] + + class PipelineStageParam(pydantic.BaseModel): """A configurable parameter of a stage of a pipeline.""" @@ -310,6 +342,26 @@ class PipelineConfigResponse(pydantic.BaseModel): stages: list[PipelineStage] = [] +class AntennaTaskResultError(pydantic.BaseModel): + """Error result for a single Antenna task that failed to process.""" + + error: str + image_id: str | None = None + + +class AntennaTaskResult(pydantic.BaseModel): + """Result for a single Antenna task, either success or error.""" + + reply_subject: str | None = None + result: PipelineResultsResponse | AntennaTaskResultError + + +class AntennaTaskResults(pydantic.BaseModel): + """Batch of task results to post back to Antenna API.""" + + results: list[AntennaTaskResult] = pydantic.Field(default_factory=list) + + class ProcessingServiceInfoResponse(pydantic.BaseModel): """Information about the processing service.""" diff --git a/trapdata/api/tests/test_worker.py b/trapdata/api/tests/test_worker.py new file mode 100644 index 00000000..81df4bd0 --- /dev/null +++ b/trapdata/api/tests/test_worker.py @@ -0,0 +1,565 @@ +"""Integration tests for the REST worker and related utilities. + +These tests validate the Antenna API contract and run real ML inference through +the worker's unique code path (RESTDataset → rest_collate_fn → batch processing). +Only external service dependencies are mocked - ML models and image loading are real. 
+""" + +import pathlib +from unittest import TestCase +from unittest.mock import MagicMock + +import torch +from fastapi.testclient import TestClient + +from trapdata.api.datasets import RESTDataset, rest_collate_fn +from trapdata.api.schemas import ( + AntennaPipelineProcessingTask, + AntennaTaskResult, + AntennaTaskResultError, + PipelineResultsResponse, +) +from trapdata.api.tests import antenna_api_server +from trapdata.api.tests.antenna_api_server import app as antenna_app +from trapdata.api.tests.image_server import StaticFileTestServer +from trapdata.api.tests.utils import get_test_image_urls, patch_antenna_api_requests +from trapdata.cli.worker import _get_jobs, _process_job +from trapdata.tests import TEST_IMAGES_BASE_PATH + + +# --------------------------------------------------------------------------- +# TestRestCollateFn - Unit tests for collation logic +# --------------------------------------------------------------------------- + + +class TestRestCollateFn: + """Tests for rest_collate_fn which separates successful/failed items.""" + + def test_all_successful(self): + batch = [ + { + "image": torch.rand(3, 64, 64), + "reply_subject": "subj1", + "image_id": "img1", + "image_url": "http://example.com/1.jpg", + }, + { + "image": torch.rand(3, 64, 64), + "reply_subject": "subj2", + "image_id": "img2", + "image_url": "http://example.com/2.jpg", + }, + ] + result = rest_collate_fn(batch) + + assert "images" in result + assert result["images"].shape == (2, 3, 64, 64) + assert result["image_ids"] == ["img1", "img2"] + assert result["reply_subjects"] == ["subj1", "subj2"] + assert result["failed_items"] == [] + + def test_all_failed(self): + batch = [ + { + "image": None, + "reply_subject": "subj1", + "image_id": "img1", + "image_url": "http://example.com/1.jpg", + "error": "download failed", + }, + { + "image": None, + "reply_subject": "subj2", + "image_id": "img2", + "image_url": "http://example.com/2.jpg", + "error": "timeout", + }, + ] + result = rest_collate_fn(batch) + + assert "images" not in result + assert result["image_ids"] == [] + assert result["reply_subjects"] == [] + assert len(result["failed_items"]) == 2 + assert result["failed_items"][0]["image_id"] == "img1" + assert result["failed_items"][1]["error"] == "timeout" + + def test_mixed(self): + batch = [ + { + "image": torch.rand(3, 64, 64), + "reply_subject": "subj1", + "image_id": "img1", + "image_url": "http://example.com/1.jpg", + }, + { + "image": None, + "reply_subject": "subj2", + "image_id": "img2", + "image_url": "http://example.com/2.jpg", + "error": "404", + }, + ] + result = rest_collate_fn(batch) + + assert result["images"].shape == (1, 3, 64, 64) + assert result["image_ids"] == ["img1"] + assert len(result["failed_items"]) == 1 + assert result["failed_items"][0]["image_id"] == "img2" + + def test_single_item(self): + batch = [ + { + "image": torch.rand(3, 32, 32), + "reply_subject": "subj1", + "image_id": "img1", + "image_url": "http://example.com/1.jpg", + }, + ] + result = rest_collate_fn(batch) + + assert result["images"].shape == (1, 3, 32, 32) + assert result["image_ids"] == ["img1"] + assert result["failed_items"] == [] + + +# --------------------------------------------------------------------------- +# TestRESTDatasetIntegration - Integration tests with real image loading +# --------------------------------------------------------------------------- + + +class TestRESTDatasetIntegration(TestCase): + """Integration tests for RESTDataset that fetch tasks and load real images.""" + + @classmethod + 
def setUpClass(cls): + # Setup file server for test images + cls.test_images_dir = pathlib.Path(TEST_IMAGES_BASE_PATH) + cls.file_server = StaticFileTestServer(cls.test_images_dir) + cls.file_server.start() # Start server and keep it running for all tests + + # Setup mock Antenna API + cls.antenna_client = TestClient(antenna_app) + + @classmethod + def tearDownClass(cls): + cls.file_server.stop() + + def setUp(self): + # Reset state between tests + antenna_api_server.reset() + + def _make_dataset(self, job_id: int = 42, batch_size: int = 2) -> RESTDataset: + """Create a RESTDataset pointing to the mock API.""" + return RESTDataset( + base_url="http://testserver/api/v2", + job_id=job_id, + batch_size=batch_size, + auth_token="test-token", + ) + + def test_fetches_and_loads_images(self): + """RESTDataset fetches tasks and loads images from URLs.""" + # Setup mock API job with real image URLs + image_urls = get_test_image_urls( + self.file_server, self.test_images_dir, subdir="vermont", num=2 + ) + tasks = [ + AntennaPipelineProcessingTask( + id=f"task_{i}", + image_id=f"img_{i}", + image_url=url, + reply_subject=f"reply_{i}", + ) + for i, url in enumerate(image_urls) + ] + antenna_api_server.setup_job(job_id=1, tasks=tasks) + + # Create dataset and iterate + with patch_antenna_api_requests(self.antenna_client): + dataset = self._make_dataset(job_id=1, batch_size=2) + rows = list(dataset) + + # Validate images actually loaded + assert len(rows) == 2 + assert all(r["image"] is not None for r in rows) + assert all(isinstance(r["image"], torch.Tensor) for r in rows) + assert rows[0]["image_id"] == "img_0" + assert rows[1]["image_id"] == "img_1" + + def test_image_failure(self): + """Invalid image URL produces error row with image=None.""" + tasks = [ + AntennaPipelineProcessingTask( + id="task_bad", + image_id="img_bad", + image_url="http://invalid-url.test/bad.jpg", + reply_subject="reply_bad", + ) + ] + antenna_api_server.setup_job(job_id=2, tasks=tasks) + + with patch_antenna_api_requests(self.antenna_client): + dataset = self._make_dataset(job_id=2) + rows = list(dataset) + + assert len(rows) == 1 + assert rows[0]["image"] is None + assert "error" in rows[0] + + def test_empty_queue(self): + """First fetch returns empty tasks → iterator stops immediately.""" + antenna_api_server.setup_job(job_id=3, tasks=[]) + + with patch_antenna_api_requests(self.antenna_client): + dataset = self._make_dataset(job_id=3) + rows = list(dataset) + + assert rows == [] + + def test_multiple_batches(self): + """Dataset fetches multiple batches until queue is empty.""" + # Setup job with 3 images (all available in vermont dir), batch size 2 + image_urls = get_test_image_urls( + self.file_server, self.test_images_dir, subdir="vermont", num=3 + ) + tasks = [ + AntennaPipelineProcessingTask( + id=f"task_{i}", + image_id=f"img_{i}", + image_url=url, + reply_subject=f"reply_{i}", + ) + for i, url in enumerate(image_urls) + ] + antenna_api_server.setup_job(job_id=4, tasks=tasks) + + with patch_antenna_api_requests(self.antenna_client): + dataset = self._make_dataset(job_id=4, batch_size=2) + rows = list(dataset) + + # Should get all 3 images (batch1: 2 images, batch2: 1 image) + assert len(rows) == 3 + assert all(r["image"] is not None for r in rows) + + +# --------------------------------------------------------------------------- +# TestGetJobsIntegration - Integration tests for job fetching +# --------------------------------------------------------------------------- + + +class TestGetJobsIntegration(TestCase): + 
"""Integration tests for _get_jobs() with mock Antenna API.""" + + @classmethod + def setUpClass(cls): + cls.antenna_client = TestClient(antenna_app) + + def setUp(self): + antenna_api_server.reset() + + def test_returns_job_ids(self): + """Successfully fetches list of job IDs.""" + # Setup jobs in queue + antenna_api_server.setup_job(10, []) + antenna_api_server.setup_job(20, []) + antenna_api_server.setup_job(30, []) + + with patch_antenna_api_requests(self.antenna_client): + result = _get_jobs("http://testserver/api/v2", "test-token", "moths_2024") + + assert result == [10, 20, 30] + + def test_empty_queue(self): + """Empty job queue returns empty list.""" + with patch_antenna_api_requests(self.antenna_client): + result = _get_jobs("http://testserver/api/v2", "test-token", "moths_2024") + + assert result == [] + + def test_query_params_sent(self): + """Request includes correct query parameters.""" + # This test validates the query params are sent by checking the function works + # The mock API checks the params internally + antenna_api_server.setup_job(1, []) + + with patch_antenna_api_requests(self.antenna_client): + result = _get_jobs("http://testserver/api/v2", "test-token", "my_pipeline") + + assert isinstance(result, list) + + +# --------------------------------------------------------------------------- +# TestProcessJobIntegration - Integration tests with real ML inference +# --------------------------------------------------------------------------- + + +class TestProcessJobIntegration(TestCase): + """Integration tests for _process_job() with real detector and classifier.""" + + @classmethod + def setUpClass(cls): + cls.test_images_dir = pathlib.Path(TEST_IMAGES_BASE_PATH) + cls.file_server = StaticFileTestServer(cls.test_images_dir) + cls.file_server.start() # Start server and keep it running for all tests + cls.antenna_client = TestClient(antenna_app) + + @classmethod + def tearDownClass(cls): + cls.file_server.stop() + + def setUp(self): + antenna_api_server.reset() + + def _make_settings(self): + """Create mock settings for worker.""" + settings = MagicMock() + settings.antenna_api_base_url = "http://testserver/api/v2" + settings.antenna_api_auth_token = "test-token" + settings.antenna_api_batch_size = 2 + settings.num_workers = 0 # Disable multiprocessing for tests + settings.localization_batch_size = 2 # Real integer for batch processing + return settings + + def test_empty_queue(self): + """No tasks in queue → returns False.""" + antenna_api_server.setup_job(job_id=100, tasks=[]) + + with patch_antenna_api_requests(self.antenna_client): + result = _process_job( + "quebec_vermont_moths_2023", 100, self._make_settings() + ) + + assert result is False + + def test_processes_batch_with_real_inference(self): + """Worker fetches tasks, loads images, runs ML, posts results.""" + # Setup job with 2 test images + image_urls = get_test_image_urls( + self.file_server, self.test_images_dir, subdir="vermont", num=2 + ) + tasks = [ + AntennaPipelineProcessingTask( + id=f"task_{i}", + image_id=f"img_{i}", + image_url=url, + reply_subject=f"reply_{i}", + ) + for i, url in enumerate(image_urls) + ] + antenna_api_server.setup_job(job_id=101, tasks=tasks) + + # Run worker + with patch_antenna_api_requests(self.antenna_client): + result = _process_job( + "quebec_vermont_moths_2023", 101, self._make_settings() + ) + + # Validate processing succeeded + assert result is True + + # Validate results were posted + posted_results = antenna_api_server.get_posted_results(101) + assert 
len(posted_results) == 2 + + # Validate schema compliance + for task_result in posted_results: + assert isinstance(task_result, AntennaTaskResult) + assert isinstance(task_result.result, PipelineResultsResponse) + + # Validate structure + response = task_result.result + assert response.pipeline == "quebec_vermont_moths_2023" + assert response.total_time > 0 + assert len(response.source_images) == 1 + assert len(response.detections) >= 0 # May be 0 if no moths + + def test_handles_failed_items(self): + """Failed image downloads produce AntennaTaskResultError.""" + tasks = [ + AntennaPipelineProcessingTask( + id="task_fail", + image_id="img_fail", + image_url="http://invalid-url.test/image.jpg", + reply_subject="reply_fail", + ) + ] + antenna_api_server.setup_job(job_id=102, tasks=tasks) + + with patch_antenna_api_requests(self.antenna_client): + _process_job("quebec_vermont_moths_2023", 102, self._make_settings()) + + posted_results = antenna_api_server.get_posted_results(102) + assert len(posted_results) == 1 + assert isinstance(posted_results[0].result, AntennaTaskResultError) + assert posted_results[0].result.error # Error message should not be empty + + def test_mixed_batch_success_and_failures(self): + """Batch with some successful and some failed images.""" + # One valid image, one invalid + valid_url = get_test_image_urls( + self.file_server, self.test_images_dir, subdir="vermont", num=1 + )[0] + + tasks = [ + AntennaPipelineProcessingTask( + id="task_good", + image_id="img_good", + image_url=valid_url, + reply_subject="reply_good", + ), + AntennaPipelineProcessingTask( + id="task_bad", + image_id="img_bad", + image_url="http://invalid-url.test/bad.jpg", + reply_subject="reply_bad", + ), + ] + antenna_api_server.setup_job(job_id=103, tasks=tasks) + + with patch_antenna_api_requests(self.antenna_client): + result = _process_job( + "quebec_vermont_moths_2023", 103, self._make_settings() + ) + + assert result is True + posted_results = antenna_api_server.get_posted_results(103) + assert len(posted_results) == 2 + + # One success, one error + success_results = [ + r for r in posted_results if isinstance(r.result, PipelineResultsResponse) + ] + error_results = [ + r for r in posted_results if isinstance(r.result, AntennaTaskResultError) + ] + assert len(success_results) == 1 + assert len(error_results) == 1 + + +# --------------------------------------------------------------------------- +# TestWorkerEndToEnd - Full workflow integration tests +# --------------------------------------------------------------------------- + + +class TestWorkerEndToEnd(TestCase): + """End-to-end integration tests for complete worker workflow.""" + + @classmethod + def setUpClass(cls): + cls.test_images_dir = pathlib.Path(TEST_IMAGES_BASE_PATH) + cls.file_server = StaticFileTestServer(cls.test_images_dir) + cls.file_server.start() # Start server and keep it running for all tests + cls.antenna_client = TestClient(antenna_app) + + @classmethod + def tearDownClass(cls): + cls.file_server.stop() + + def setUp(self): + antenna_api_server.reset() + + def _make_settings(self): + settings = MagicMock() + settings.antenna_api_base_url = "http://testserver/api/v2" + settings.antenna_api_auth_token = "test-token" + settings.antenna_api_batch_size = 2 + settings.num_workers = 0 + settings.localization_batch_size = 2 # Real integer for batch processing + return settings + + def test_full_workflow_with_real_inference(self): + """ + Complete workflow: fetch jobs → fetch tasks → load images → + run detection → run 
classification → post results. + """ + # Setup job with 2 test images + image_urls = get_test_image_urls( + self.file_server, self.test_images_dir, subdir="vermont", num=2 + ) + tasks = [ + AntennaPipelineProcessingTask( + id=f"task_{i}", + image_id=f"img_{i}", + image_url=url, + reply_subject=f"reply_{i}", + ) + for i, url in enumerate(image_urls) + ] + antenna_api_server.setup_job(job_id=200, tasks=tasks) + + # Step 1: Get jobs + with patch_antenna_api_requests(self.antenna_client): + job_ids = _get_jobs( + "http://testserver/api/v2", + "test-token", + "quebec_vermont_moths_2023", + ) + + assert 200 in job_ids + + # Step 2: Process job + with patch_antenna_api_requests(self.antenna_client): + result = _process_job( + "quebec_vermont_moths_2023", 200, self._make_settings() + ) + + assert result is True + + # Step 3: Validate results posted + posted_results = antenna_api_server.get_posted_results(200) + assert len(posted_results) == 2 + + # Validate all results are valid + for task_result in posted_results: + assert isinstance(task_result, AntennaTaskResult) + assert task_result.reply_subject is not None + + # Should be success results + assert isinstance(task_result.result, PipelineResultsResponse) + response = task_result.result + + # Validate pipeline response structure + assert response.pipeline == "quebec_vermont_moths_2023" + assert response.total_time > 0 + assert len(response.source_images) == 1 + + # Validate detections structure (may be empty if no moths) + assert isinstance(response.detections, list) + if response.detections: + detection = response.detections[0] + assert detection.bbox is not None + assert detection.source_image_id is not None + + def test_multiple_batches_processed(self): + """Job with more tasks than batch size processes in multiple batches.""" + # Setup job with 3 images (all available in vermont dir), batch size 2 + image_urls = get_test_image_urls( + self.file_server, self.test_images_dir, subdir="vermont", num=3 + ) + tasks = [ + AntennaPipelineProcessingTask( + id=f"task_{i}", + image_id=f"img_{i}", + image_url=url, + reply_subject=f"reply_{i}", + ) + for i, url in enumerate(image_urls) + ] + antenna_api_server.setup_job(job_id=201, tasks=tasks) + + with patch_antenna_api_requests(self.antenna_client): + result = _process_job( + "quebec_vermont_moths_2023", 201, self._make_settings() + ) + + assert result is True + + # All 3 results should be posted (batch1: 2, batch2: 1) + posted_results = antenna_api_server.get_posted_results(201) + assert len(posted_results) == 3 + + # All should be successful + assert all( + isinstance(r.result, PipelineResultsResponse) for r in posted_results + ) diff --git a/trapdata/cli/base.py b/trapdata/cli/base.py index f53cb651..433ccd2e 100644 --- a/trapdata/cli/base.py +++ b/trapdata/cli/base.py @@ -96,5 +96,36 @@ def run_api(port: int = 2000): uvicorn.run("trapdata.api.api:app", host="0.0.0.0", port=port, reload=True) +@cli.command("worker") +def worker( + pipelines: Optional[list[str]] = typer.Option( + None, + help="List of pipelines to use for processing (e.g., moth_binary, panama_moths_2024, etc.) or all if not specified.", + ), +): + """ + Run the worker to process images from the REST API queue. 
+ """ + from trapdata.api.api import CLASSIFIER_CHOICES + + if not pipelines: + pipelines = list(CLASSIFIER_CHOICES.keys()) + + # Validate that each pipeline is in CLASSIFIER_CHOICES + invalid_pipelines = [ + pipeline for pipeline in pipelines if pipeline not in CLASSIFIER_CHOICES.keys() + ] + + if invalid_pipelines: + raise typer.BadParameter( + f"Invalid pipeline(s): {', '.join(invalid_pipelines)}. " + f"Must be one of: {', '.join(CLASSIFIER_CHOICES.keys())}" + ) + + from trapdata.cli.worker import run_worker + + run_worker(pipelines=pipelines) + + if __name__ == "__main__": cli() diff --git a/trapdata/cli/tests/__init__.py b/trapdata/cli/tests/__init__.py new file mode 100644 index 00000000..f1fd207f --- /dev/null +++ b/trapdata/cli/tests/__init__.py @@ -0,0 +1 @@ +# CLI tests module diff --git a/trapdata/cli/tests/test_worker_batching.py b/trapdata/cli/tests/test_worker_batching.py new file mode 100644 index 00000000..ec389c7a --- /dev/null +++ b/trapdata/cli/tests/test_worker_batching.py @@ -0,0 +1,114 @@ +"""Unit tests for batched classification in worker.py. + +This test validates that the worker correctly batches multiple detection crops +together for classification instead of processing them one at a time. +""" + +import torch +import torchvision.transforms +from unittest.mock import MagicMock, patch + + +def test_batched_classification(): + """Test that worker batches all crops together for classification.""" + from trapdata.cli.worker import _process_job + from trapdata.api.schemas import DetectionResponse, BBox + + # Mock the dataloader to return a batch with detections + mock_detector_results = [ + DetectionResponse( + source_image_id="img1", + bbox=BBox(x1=10, y1=10, x2=50, y2=50), + score=0.9, + classifications=[], + ), + DetectionResponse( + source_image_id="img1", + bbox=BBox(x1=60, y1=60, x2=100, y2=100), + score=0.85, + classifications=[], + ), + DetectionResponse( + source_image_id="img1", + bbox=BBox(x1=110, y1=110, x2=150, y2=150), + score=0.8, + classifications=[], + ), + ] + + # Mock classifier + mock_classifier = MagicMock() + mock_classifier.get_transforms.return_value = torchvision.transforms.Compose([ + torchvision.transforms.Resize((128, 128)), + torchvision.transforms.ToTensor(), + ]) + + # Track how many times predict_batch is called and with what batch sizes + predict_batch_calls = [] + + def mock_predict_batch(batch): + predict_batch_calls.append(batch.shape[0]) # Record batch size + # Return dummy output for each item in batch + return torch.rand(batch.shape[0], 10) # 10 classes + + mock_classifier.predict_batch = mock_predict_batch + mock_classifier.post_process_batch.return_value = [ + MagicMock(labels=["class1"] * 10, logit=[0.0] * 10, scores=[0.1] * 10) + for _ in range(3) + ] + mock_classifier.update_detection_classification.return_value = mock_detector_results[0] + mock_classifier.reset = MagicMock() + + # Mock detector + mock_detector = MagicMock() + mock_detector.results = mock_detector_results + mock_detector.predict_batch.return_value = [] + mock_detector.post_process_batch.return_value = [] + mock_detector.save_results = MagicMock() + mock_detector.reset = MagicMock() + + # Create a simple batch with one image + image_tensor = torch.rand(3, 200, 200) + mock_batch = { + "images": [image_tensor], + "image_ids": ["img1"], + "reply_subjects": ["subj1"], + "image_urls": ["http://example.com/img1.jpg"], + "failed_items": [], + } + + # Mock the dataloader to return our batch + mock_loader = [mock_batch] + + # Mock settings + mock_settings = MagicMock() 
+ mock_settings.antenna_api_base_url = "http://localhost:8000/api/v2" + mock_settings.antenna_api_auth_token = "test_token" + + # Patch dependencies + with patch("trapdata.cli.worker.get_rest_dataloader", return_value=mock_loader), \ + patch("trapdata.cli.worker.CLASSIFIER_CHOICES", {"test_pipeline": MagicMock(return_value=mock_classifier)}), \ + patch("trapdata.cli.worker.APIMothDetector", return_value=mock_detector), \ + patch("trapdata.cli.worker.post_batch_results", return_value=True): + + # Run the worker + _process_job(pipeline="test_pipeline", job_id=1, settings=mock_settings) + + # Verify that predict_batch was called exactly once (batched) + assert len(predict_batch_calls) == 1, ( + f"Expected predict_batch to be called once (batched), " + f"but it was called {len(predict_batch_calls)} times" + ) + + # Verify the batch size was 3 (all crops together) + assert predict_batch_calls[0] == 3, ( + f"Expected batch size of 3, but got {predict_batch_calls[0]}" + ) + + print("✓ Batched classification test passed!") + print(f" - predict_batch called {len(predict_batch_calls)} time(s)") + print(f" - Batch size: {predict_batch_calls[0]} crops") + + +if __name__ == "__main__": + test_batched_classification() diff --git a/trapdata/cli/worker.py b/trapdata/cli/worker.py new file mode 100644 index 00000000..894ee519 --- /dev/null +++ b/trapdata/cli/worker.py @@ -0,0 +1,320 @@ +"""Worker to process images from the REST API queue.""" + +import datetime +import time +from typing import List + +import numpy as np +import requests +import torch +import torchvision.transforms + +from trapdata.api.api import CLASSIFIER_CHOICES +from trapdata.api.datasets import get_rest_dataloader +from trapdata.api.models.localization import APIMothDetector +from trapdata.api.schemas import ( + AntennaJobsListResponse, + AntennaTaskResult, + AntennaTaskResultError, + DetectionResponse, + PipelineResultsResponse, + SourceImageResponse, +) +from trapdata.common.logs import logger +from trapdata.common.utils import log_time +from trapdata.settings import Settings, read_settings + +SLEEP_TIME_SECONDS = 5 + + +def post_batch_results( + base_url: str, job_id: int, results: list[AntennaTaskResult], auth_token: str | None = None +) -> bool: + """ + Post batch results back to the API. + + Args: + base_url: Base URL for the API + job_id: Job ID + results: List of AntennaTaskResult objects + auth_token: API authentication token + + Returns: + True if successful, False otherwise + """ + url = f"{base_url.rstrip('/')}/jobs/{job_id}/result/" + + headers = {} + if auth_token: + headers["Authorization"] = f"Token {auth_token}" + + payload = [r.model_dump(mode="json") for r in results] + + try: + response = requests.post(url, json=payload, headers=headers, timeout=60) + response.raise_for_status() + logger.info(f"Successfully posted {len(results)} results to {url}") + return True + except requests.RequestException as e: + logger.error(f"Failed to post results to {url}: {e}") + return False + + +def _get_jobs(base_url: str, auth_token: str, pipeline_slug: str) -> list[int]: + """Fetch job ids from the API for the given pipeline. + + Calls: GET {base_url}/jobs?pipeline__slug=&ids_only=1 + + Returns a list of job ids (possibly empty) on success or error. 
+ """ + try: + url = f"{base_url.rstrip('/')}/jobs" + params = {"pipeline__slug": pipeline_slug, "ids_only": 1, "incomplete_only": 1} + + headers = {} + if auth_token: + headers["Authorization"] = f"Token {auth_token}" + + resp = requests.get(url, params=params, headers=headers, timeout=30) + resp.raise_for_status() + + # Parse and validate response with Pydantic + jobs_response = AntennaJobsListResponse.model_validate(resp.json()) + return [job.id for job in jobs_response.results] + except requests.RequestException as e: + logger.error(f"Failed to fetch jobs from {base_url}: {e}") + return [] + except Exception as e: + logger.error(f"Failed to parse jobs response: {e}") + return [] + + +def run_worker(pipelines: List[str]): + """Run the worker to process images from the REST API queue.""" + settings = read_settings() + + # Validate auth token + if not settings.antenna_api_auth_token: + raise ValueError( + "AMI_ANTENNA_API_AUTH_TOKEN environment variable must be set. " + "Get your auth token from your Antenna project settings." + ) + + # TODO CGJS: Support a list of pipelines + while True: + # TODO CGJS: Support pulling and prioritizing single image tasks, which are used in interactive testing + # These should probably come from a dedicated endpoint and should preempt batch jobs under the assumption that they + # would run on the same GPU. + any_jobs = False + for pipeline in pipelines: + logger.info(f"Checking for jobs for pipeline {pipeline}") + jobs = _get_jobs( + base_url=settings.antenna_api_base_url, + auth_token=settings.antenna_api_auth_token, + pipeline_slug=pipeline, + ) + for job_id in jobs: + logger.info(f"Processing job {job_id} with pipeline {pipeline}") + any_work_done = _process_job( + pipeline=pipeline, + job_id=job_id, + settings=settings, + ) + any_jobs = any_jobs or any_work_done + + if not any_jobs: + logger.info(f"No jobs found, sleeping for {SLEEP_TIME_SECONDS} seconds") + time.sleep(SLEEP_TIME_SECONDS) + + +@torch.no_grad() +def _process_job(pipeline: str, job_id: int, settings: Settings) -> bool: + """Run the worker to process images from the REST API queue. 
+ + Args: + pipeline: Pipeline name to use for processing (e.g., moth_binary, panama_moths_2024) + job_id: Job ID to process + settings: Settings object with antenna_api_* configuration + Returns: + True if any work was done, False otherwise + """ + did_work = False + loader = get_rest_dataloader(job_id=job_id, settings=settings) + classifier = None + detector = None + + torch.cuda.empty_cache() + items = 0 + + total_detection_time = 0.0 + total_classification_time = 0.0 + total_save_time = 0.0 + total_dl_time = 0.0 + all_detections = [] + _, t = log_time() + + for i, batch in enumerate(loader): + dt, t = t("Finished loading batch") + total_dl_time += dt + if not batch: + logger.warning(f"Batch {i+1} is empty, skipping") + continue + + # Defer instantiation of detector and classifier until we have data + if not classifier: + classifier_class = CLASSIFIER_CHOICES[pipeline] + classifier = classifier_class(source_images=[], detections=[]) + detector = APIMothDetector([]) + assert detector is not None, "Detector not initialized" + assert classifier is not None, "Classifier not initialized" + detector.reset([]) + did_work = True + + # Extract data from dictionary batch + images = batch.get("images", []) + image_ids = batch.get("image_ids", []) + reply_subjects = batch.get("reply_subjects", []) + image_urls = batch.get("image_urls", []) + + # Track start time for this batch + batch_start_time = datetime.datetime.now() + + logger.info(f"Processing batch {i+1}") + # output is dict of "boxes", "labels", "scores" + batch_output = [] + if len(images) > 0: + batch_output = detector.predict_batch(images) + + items += len(batch_output) + logger.info(f"Total items processed so far: {items}") + batch_output = list(detector.post_process_batch(batch_output)) + + # Convert image_ids to list if needed + if isinstance(image_ids, (np.ndarray, torch.Tensor)): + image_ids = image_ids.tolist() + + # TODO CGJS: Add seconds per item calculation for both detector and classifier + detector.save_results( + item_ids=image_ids, + batch_output=batch_output, + seconds_per_item=0, + ) + dt, t = t("Finished detection") + total_detection_time += dt + + # Group detections by image_id + image_detections: dict[str, list[DetectionResponse]] = { + img_id: [] for img_id in image_ids + } + image_tensors = dict(zip(image_ids, images)) + + classifier.reset(detector.results) + + # Batch classification: collect all crops first, then classify together + if detector.results: + # Step 1: Collect all crops and metadata + crops = [] + detection_metadata = [] # (idx, image_id, detection) + + for idx, dresp in enumerate(detector.results): + image_tensor = image_tensors[dresp.source_image_id] + bbox = dresp.bbox + + # Validate bbox dimensions to avoid empty crops + if bbox.y1 >= bbox.y2 or bbox.x1 >= bbox.x2: + logger.warning( + f"Skipping detection {idx} with invalid bbox: " + f"({bbox.x1}, {bbox.y1}, {bbox.x2}, {bbox.y2})" + ) + continue + + # Crop the image tensor using the bbox + crop = image_tensor[ + :, int(bbox.y1) : int(bbox.y2), int(bbox.x1) : int(bbox.x2) + ] + + # Convert tensor to PIL Image for transforms (same as API pipeline) + # Transforms expect PIL images and handle resizing to model's input_size + crop_pil = torchvision.transforms.ToPILImage()(crop) + + # Apply classifier transforms (resizes to uniform size) + crop_transformed = classifier.get_transforms()(crop_pil) + + crops.append(crop_transformed) + detection_metadata.append((idx, dresp.source_image_id, dresp)) + + # Step 2: Stack crops into a batch tensor (only if we have 
valid crops) + if crops: + batched_crops = torch.stack(crops) + + # Step 3: Run batched classification (single GPU call) + classifier_out = classifier.predict_batch(batched_crops) + classifier_out = classifier.post_process_batch(classifier_out) + + # Step 4: Map results back to detections + for (idx, image_id, dresp), predictions in zip(detection_metadata, classifier_out): + detection = classifier.update_detection_classification( + seconds_per_item=0, + image_id=image_id, + detection_idx=idx, + predictions=predictions, + ) + image_detections[image_id].append(detection) + all_detections.append(detection) + + ct, t = t("Finished classification") + total_classification_time += ct + + # Calculate batch processing time + batch_end_time = datetime.datetime.now() + batch_elapsed = (batch_end_time - batch_start_time).total_seconds() + + # Post results back to the API with PipelineResponse for each image + batch_results: list[AntennaTaskResult] = [] + for reply_subject, image_id, image_url in zip( + reply_subjects, image_ids, image_urls + ): + # Create SourceImageResponse for this image + source_image = SourceImageResponse(id=image_id, url=image_url) + + # Create PipelineResultsResponse + pipeline_response = PipelineResultsResponse( + pipeline=pipeline, + source_images=[source_image], + detections=image_detections[image_id], + total_time=batch_elapsed / len(image_ids) if image_ids else 0, # Avoid division by zero + ) + + batch_results.append( + AntennaTaskResult( + reply_subject=reply_subject, + result=pipeline_response, + ) + ) + failed_items = batch.get("failed_items") + if failed_items: + for failed_item in failed_items: + batch_results.append( + AntennaTaskResult( + reply_subject=failed_item.get("reply_subject"), + result=AntennaTaskResultError( + error=failed_item.get("error", "Unknown error"), + image_id=failed_item.get("image_id"), + ), + ) + ) + + post_batch_results( + settings.antenna_api_base_url, + job_id, + batch_results, + settings.antenna_api_auth_token, + ) + st, t = t("Finished posting results") + total_save_time += st + + logger.info( + f"Done, detections: {len(all_detections)}. Detecting time: {total_detection_time}, " + f"classification time: {total_classification_time}, dl time: {total_dl_time}, save time: {total_save_time}" + ) + return did_work diff --git a/trapdata/settings.py b/trapdata/settings.py index 6ce566ed..c13b2768 100644 --- a/trapdata/settings.py +++ b/trapdata/settings.py @@ -36,6 +36,9 @@ class Settings(BaseSettings): localization_batch_size: int = 2 classification_batch_size: int = 20 num_workers: int = 1 + antenna_api_base_url: str = "http://localhost:8000/api/v2" + antenna_api_auth_token: str = "" + antenna_api_batch_size: int = 4 @pydantic.field_validator("image_base_path", "user_data_path") def validate_path(cls, v): @@ -143,6 +146,24 @@ class Config: "kivy_type": "numeric", "kivy_section": "performance", }, + "antenna_api_base_url": { + "title": "Antenna API Base URL", + "description": "URL to the Antenna platform API for worker processing (should include /api/v2)", + "kivy_type": "string", + "kivy_section": "antenna", + }, + "antenna_api_auth_token": { + "title": "Antenna API Token", + "description": "Authentication token for your Antenna project", + "kivy_type": "string", + "kivy_section": "antenna", + }, + "antenna_api_batch_size": { + "title": "Antenna API Batch Size", + "description": "Number of tasks to fetch from Antenna per batch", + "kivy_type": "numeric", + "kivy_section": "antenna", + }, } @classmethod
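
Usage sketch (appendix to the diff, not patch content): a minimal example of how the new worker entry point and Antenna settings introduced above might be wired together. The AMI_ANTENNA_API_BASE_URL variable name is an assumption based on the AMI_ prefix implied by the AMI_ANTENNA_API_AUTH_TOKEN error message in run_worker(); the pipeline slug is just one of the keys expected in CLASSIFIER_CHOICES.

    import os

    from trapdata.cli.worker import run_worker

    # These map to the new Settings fields antenna_api_base_url / antenna_api_auth_token
    # added in trapdata/settings.py (env var names assumed, see note above).
    os.environ.setdefault("AMI_ANTENNA_API_BASE_URL", "http://localhost:8000/api/v2")
    os.environ.setdefault("AMI_ANTENNA_API_AUTH_TOKEN", "<project-token>")

    # Poll the Antenna API for incomplete jobs for one pipeline and process them
    # until interrupted (run_worker loops indefinitely, sleeping when the queue is empty).
    run_worker(pipelines=["quebec_vermont_moths_2023"])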