diff --git a/trapdata/antenna/datasets.py b/trapdata/antenna/datasets.py index faf56b8..25b33ba 100644 --- a/trapdata/antenna/datasets.py +++ b/trapdata/antenna/datasets.py @@ -1,6 +1,69 @@ -"""Dataset classes for streaming tasks from the Antenna API.""" +"""Dataset and DataLoader for streaming tasks from the Antenna API. + +Data loading pipeline overview +============================== + +The pipeline has three layers of concurrency. Each layer is controlled by a +different setting and targets a different bottleneck. + +:: + + ┌──────────────────────────────────────────────────────────────────┐ + │ GPU process (_worker_loop in worker.py) │ + │ One per GPU. Runs detection → classification on batches. │ + │ Controlled by: automatic (one per torch.cuda.device_count()) │ + ├──────────────────────────────────────────────────────────────────┤ + │ DataLoader workers (num_workers subprocesses) │ + │ Each subprocess runs its own RESTDataset.__iter__ loop: │ + │ 1. GET /tasks → fetch batch of task metadata from Antenna │ + │ 2. Download images (threaded, see below) │ + │ 3. Yield individual (image_tensor, metadata) rows │ + │ The DataLoader collates rows into GPU-sized batches. │ + │ Controlled by: settings.num_workers (AMI_NUM_WORKERS) │ + │ Default: 2. Safe >0 because Antenna dequeues atomically. │ + ├──────────────────────────────────────────────────────────────────┤ + │ Thread pool (ThreadPoolExecutor inside each DataLoader worker) │ + │ Downloads images concurrently *within* one API fetch batch. │ + │ Each thread: HTTP GET → PIL open → RGB convert → ToTensor(). │ + │ Controlled by: ThreadPoolExecutor(max_workers=8) on the class. │ + │ Note: RGB conversion and ToTensor are GIL-bound (CPU). Only │ + │ the network wait truly runs in parallel. A future optimisation │ + │ could move transforms out of the thread. │ + └──────────────────────────────────────────────────────────────────┘ + +Settings quick-reference (prefix with AMI_ as env vars): + + localization_batch_size (default 8) + How many images the GPU processes at once (detection). Larger = + more GPU memory. These are full-resolution images (~4K). + + num_workers (default 2) + DataLoader subprocesses. Each independently fetches tasks and + downloads images. More workers = more images prefetched for the + GPU, at the cost of CPU/RAM. With 0 workers, fetching and + inference are sequential (useful for debugging). + + antenna_api_batch_size (default 16) + How many task URLs to request from Antenna per API call. + Determines how many images are downloaded concurrently per + thread pool invocation. Should be >= localization_batch_size + so one API call can fill at least one GPU batch without an + extra round trip. + + prefetch_factor (PyTorch default: 2 when num_workers > 0) + Batches prefetched per worker. Not overridden here — the + default was tested and no improvement was measured by + increasing it (it just adds memory pressure). + +What has NOT been benchmarked yet (as of 2026-02): + - Optimal num_workers / thread count combination + - Whether moving transforms out of threads helps throughput + - Whether multiple DataLoader workers + threads overlap well + or contend on the GIL +""" import typing +from concurrent.futures import ThreadPoolExecutor from io import BytesIO import requests @@ -32,9 +95,9 @@ class RESTDataset(torch.utils.data.IterableDataset): DataLoader workers are SAFE and won't process duplicate tasks. Each worker independently fetches different tasks from the shared queue. 
- With num_workers > 0: - Worker 1: GET /tasks → receives [1,2,3,4], removed from queue - Worker 2: GET /tasks → receives [5,6,7,8], removed from queue + With DataLoader num_workers > 0 (I/O subprocesses, not AMI instances): + Subprocess 1: GET /tasks → receives [1,2,3,4], removed from queue + Subprocess 2: GET /tasks → receives [5,6,7,8], removed from queue No duplicates, safe for parallel processing """ @@ -58,20 +121,40 @@ def __init__( """ super().__init__() self.base_url = base_url + self.auth_token = auth_token self.job_id = job_id self.batch_size = batch_size self.image_transforms = image_transforms or torchvision.transforms.ToTensor() - # Create persistent sessions for connection pooling - self.api_session = get_http_session(auth_token) - self.image_fetch_session = get_http_session() # No auth for external image URLs + # These are created lazily in _ensure_sessions() because they contain + # unpicklable objects (ThreadPoolExecutor has a SimpleQueue) and + # PyTorch DataLoader with num_workers>0 pickles the dataset to send + # it to worker subprocesses. + self._api_session: requests.Session | None = None + self._image_fetch_session: requests.Session | None = None + self._executor: ThreadPoolExecutor | None = None + + def _ensure_sessions(self) -> None: + """Lazily create HTTP sessions and thread pool. + + Called once per worker process on first use. This avoids pickling + issues with num_workers > 0 (SimpleQueue, socket objects, etc.). + """ + if self._api_session is None: + self._api_session = get_http_session(self.auth_token) + if self._image_fetch_session is None: + self._image_fetch_session = get_http_session() + if self._executor is None: + self._executor = ThreadPoolExecutor(max_workers=8) def __del__(self): - """Clean up HTTP sessions on dataset destruction.""" - if hasattr(self, "api_session"): - self.api_session.close() - if hasattr(self, "image_fetch_session"): - self.image_fetch_session.close() + """Clean up HTTP sessions and thread pool on dataset destruction.""" + if self._executor is not None: + self._executor.shutdown(wait=False) + if self._api_session is not None: + self._api_session.close() + if self._image_fetch_session is not None: + self._image_fetch_session.close() def _fetch_tasks(self) -> list[AntennaPipelineProcessingTask]: """ @@ -86,7 +169,9 @@ def _fetch_tasks(self) -> list[AntennaPipelineProcessingTask]: url = f"{self.base_url.rstrip('/')}/jobs/{self.job_id}/tasks" params = {"batch": self.batch_size} - response = self.api_session.get(url, params=params, timeout=30) + self._ensure_sessions() + assert self._api_session is not None + response = self._api_session.get(url, params=params, timeout=30) response.raise_for_status() # Parse and validate response with Pydantic @@ -94,8 +179,12 @@ def _fetch_tasks(self) -> list[AntennaPipelineProcessingTask]: return tasks_response.tasks # Empty list is valid (queue drained) def _load_image(self, image_url: str) -> torch.Tensor | None: - """ - Load an image from a URL and convert it to a PyTorch tensor. + """Load an image from a URL and convert it to a PyTorch tensor. + + Called from threads inside ``_load_images_threaded``. The HTTP + fetch is truly concurrent (network I/O releases the GIL), but + PIL decode, RGB conversion, and ``image_transforms`` (ToTensor) + are CPU-bound and serialised by the GIL. 
Args: image_url: URL of the image to load @@ -105,7 +194,9 @@ def _load_image(self, image_url: str) -> torch.Tensor | None: """ try: # Use dedicated session without auth for external images - response = self.image_fetch_session.get(image_url, timeout=30) + self._ensure_sessions() + assert self._image_fetch_session is not None + response = self._image_fetch_session.get(image_url, timeout=30) response.raise_for_status() image = Image.open(BytesIO(response.content)) @@ -120,17 +211,51 @@ def _load_image(self, image_url: str) -> torch.Tensor | None: logger.error(f"Failed to load image from {image_url}: {e}") return None + def _load_images_threaded( + self, + tasks: list[AntennaPipelineProcessingTask], + ) -> dict[str, torch.Tensor | None]: + """Download images for a batch of tasks using concurrent threads. + + Image downloads are I/O-bound (network latency, not CPU), so threads + provide near-linear speedup without the overhead of extra processes. + Note: ``requests.Session`` is not formally thread-safe, but the + underlying urllib3 connection pool handles concurrent socket access. + In practice shared read-only sessions work fine for GET requests; + if issues arise, switch to per-thread sessions. + + Args: + tasks: List of tasks whose images should be downloaded. + + Returns: + Mapping from image_id to tensor (or None on failure), preserving + the order needed by the caller. + """ + + def _download( + task: AntennaPipelineProcessingTask, + ) -> tuple[str, torch.Tensor | None]: + tensor = self._load_image(task.image_url) if task.image_url else None + return (task.image_id, tensor) + + self._ensure_sessions() + assert self._executor is not None + return dict(self._executor.map(_download, tasks)) + def __iter__(self): """ Iterate over tasks from the REST API. + Each API fetch returns a batch of tasks. Images for the entire batch + are downloaded concurrently using threads (see _load_images_threaded), + then yielded one at a time for the DataLoader to collate. + Yields: Dictionary containing: - image: PyTorch tensor of the loaded image - reply_subject: Reply subject for the task - - batch_index: Index of the image in the batch - - job_id: Job ID - image_id: Image ID + - image_url: Source URL """ worker_id = 0 # Initialize before try block to avoid UnboundLocalError try: @@ -140,7 +265,7 @@ def __iter__(self): num_workers = worker_info.num_workers if worker_info else 1 logger.info( - f"Worker {worker_id}/{num_workers} starting iteration for job {self.job_id}" + f"DataLoader subprocess {worker_id}/{num_workers} starting iteration for job {self.job_id}" ) while True: @@ -160,14 +285,12 @@ def __iter__(self): ) break + # Download all images concurrently + image_map = self._load_images_threaded(tasks) + for task in tasks: + image_tensor = image_map.get(task.image_id) errors = [] - # Load the image - # _, t = log_time() - image_tensor = ( - self._load_image(task.image_url) if task.image_url else None - ) - # _, t = t(f"Loaded image from {image_url}") if image_tensor is None: errors.append("failed to load image") @@ -199,7 +322,7 @@ def rest_collate_fn(batch: list[dict]) -> dict: Custom collate function that separates failed and successful items. 
Returns a dict with: - - images: Stacked tensor of valid images (only present if there are successful items) + - images: List of image tensors (only present if there are successful items) - reply_subjects: List of reply subjects for valid images - image_ids: List of image IDs for valid images - image_urls: List of image URLs for valid images @@ -231,7 +354,7 @@ def rest_collate_fn(batch: list[dict]) -> dict: # Collate successful items if successful: result = { - "images": torch.stack([item["image"] for item in successful]), + "images": [item["image"] for item in successful], "reply_subjects": [item["reply_subject"] for item in successful], "image_ids": [item["image_id"] for item in successful], "image_urls": [item.get("image_url") for item in successful], @@ -252,18 +375,22 @@ def get_rest_dataloader( job_id: int, settings: "Settings", ) -> torch.utils.data.DataLoader: - """ - Create a DataLoader that fetches tasks from Antenna API. + """Create a DataLoader that fetches tasks from Antenna API. + + See the module docstring for an overview of the three concurrency + layers (GPU processes → DataLoader workers → thread pool) and which + settings control each. - Note: num_workers > 0 is SAFE here (unlike local file reading) because: - - Antenna API provides atomic task dequeue (work queue pattern) - - No shared file handles between workers - - Each worker gets different tasks automatically - - Parallel downloads improve throughput for I/O-bound work + DataLoader num_workers > 0 is safe here because Antenna dequeues + tasks atomically — each worker subprocess gets a unique set of tasks. Args: job_id: Job ID to fetch tasks for - settings: Settings object with antenna_api_* configuration + settings: Settings object. Relevant fields: + - antenna_api_base_url / antenna_api_auth_token + - antenna_api_batch_size (tasks per API call) + - localization_batch_size (images per GPU batch) + - num_workers (DataLoader subprocesses) """ dataset = RESTDataset( base_url=settings.antenna_api_base_url, diff --git a/trapdata/antenna/tests/test_worker.py b/trapdata/antenna/tests/test_worker.py index 4a83958..e9919ca 100644 --- a/trapdata/antenna/tests/test_worker.py +++ b/trapdata/antenna/tests/test_worker.py @@ -34,6 +34,29 @@ # --------------------------------------------------------------------------- +class TestDataLoaderMultiWorker(TestCase): + """DataLoader with num_workers > 0 must be able to start workers.""" + + def test_dataloader_starts_with_num_workers(self): + """Creating an iterator pickles the dataset to send to worker subprocesses.""" + dataset = RESTDataset( + base_url="http://localhost:1/api/v2", + auth_token="test-token", + job_id=1, + batch_size=4, + ) + loader = torch.utils.data.DataLoader( + dataset, + batch_size=2, + num_workers=2, + collate_fn=rest_collate_fn, + ) + # iter() pickles the dataset and spawns workers. + # If the dataset has unpicklable attributes this raises TypeError. 
+ it = iter(loader) + del it + + class TestRestCollateFn(TestCase): """Tests for rest_collate_fn which separates successful/failed items.""" @@ -55,7 +78,8 @@ def test_all_successful(self): result = rest_collate_fn(batch) assert "images" in result - assert result["images"].shape == (2, 3, 64, 64) + assert len(result["images"]) == 2 + assert result["images"][0].shape == (3, 64, 64) assert result["image_ids"] == ["img1", "img2"] assert result["reply_subjects"] == ["subj1", "subj2"] assert result["failed_items"] == [] @@ -104,7 +128,8 @@ def test_mixed(self): ] result = rest_collate_fn(batch) - assert result["images"].shape == (1, 3, 64, 64) + assert len(result["images"]) == 1 + assert result["images"][0].shape == (3, 64, 64) assert result["image_ids"] == ["img1"] assert len(result["failed_items"]) == 1 assert result["failed_items"][0]["image_id"] == "img2" diff --git a/trapdata/antenna/worker.py b/trapdata/antenna/worker.py index 2fbf3b5..bba7905 100644 --- a/trapdata/antenna/worker.py +++ b/trapdata/antenna/worker.py @@ -5,6 +5,8 @@ import numpy as np import torch +import torch.multiprocessing as mp +import torchvision from trapdata.antenna.client import get_jobs, post_batch_results from trapdata.antenna.datasets import get_rest_dataloader @@ -24,7 +26,11 @@ def run_worker(pipelines: list[str]): - """Run the worker to process images from the REST API queue.""" + """Run the worker to process images from the REST API queue. + + Automatically spawns one AMI worker instance process per available GPU. + On single-GPU or CPU-only machines, runs in-process (no overhead). + """ settings = read_settings() # Validate auth token @@ -34,20 +40,57 @@ def run_worker(pipelines: list[str]): "Get your auth token from your Antenna project settings." ) + gpu_count = torch.cuda.device_count() + + if gpu_count > 1: + logger.info(f"Found {gpu_count} GPUs, spawning one AMI worker instance per GPU") + # Don't pass settings through mp.spawn — Settings contains enums that + # can't be pickled. Each child process calls read_settings() itself. + mp.spawn( + _worker_loop, + args=(pipelines,), + nprocs=gpu_count, + join=True, + ) + else: + if gpu_count == 1: + logger.info(f"Found 1 GPU: {torch.cuda.get_device_name(0)}") + else: + logger.info("No GPUs found, running on CPU") + _worker_loop(0, pipelines) + + +def _worker_loop(gpu_id: int, pipelines: list[str]): + """Main polling loop for a single AMI worker instance, pinned to a specific GPU. + + Args: + gpu_id: GPU index to pin this AMI worker instance to (0 for CPU-only). + pipelines: List of pipeline slugs to poll for jobs. + """ + settings = read_settings() + + if torch.cuda.is_available() and torch.cuda.device_count() > 0: + torch.cuda.set_device(gpu_id) + logger.info( + f"AMI worker instance {gpu_id} pinned to GPU {gpu_id}: {torch.cuda.get_device_name(gpu_id)}" + ) + while True: # TODO CGJS: Support pulling and prioritizing single image tasks, which are used in interactive testing # These should probably come from a dedicated endpoint and should preempt batch jobs under the assumption that they # would run on the same GPU. 
any_jobs = False for pipeline in pipelines: - logger.info(f"Checking for jobs for pipeline {pipeline}") + logger.info(f"[GPU {gpu_id}] Checking for jobs for pipeline {pipeline}") jobs = get_jobs( base_url=settings.antenna_api_base_url, auth_token=settings.antenna_api_auth_token, pipeline_slug=pipeline, ) for job_id in jobs: - logger.info(f"Processing job {job_id} with pipeline {pipeline}") + logger.info( + f"[GPU {gpu_id}] Processing job {job_id} with pipeline {pipeline}" + ) try: any_work_done = _process_job( pipeline=pipeline, @@ -57,13 +100,15 @@ def run_worker(pipelines: list[str]): any_jobs = any_jobs or any_work_done except Exception as e: logger.error( - f"Failed to process job {job_id} with pipeline {pipeline}: {e}", + f"[GPU {gpu_id}] Failed to process job {job_id} with pipeline {pipeline}: {e}", exc_info=True, ) # Continue to next job rather than crashing the worker if not any_jobs: - logger.info(f"No jobs found, sleeping for {SLEEP_TIME_SECONDS} seconds") + logger.info( + f"[GPU {gpu_id}] No jobs found, sleeping for {SLEEP_TIME_SECONDS} seconds" + ) time.sleep(SLEEP_TIME_SECONDS) @@ -121,99 +166,139 @@ def _process_job( reply_subjects = batch.get("reply_subjects", [None] * len(images)) image_urls = batch.get("image_urls", [None] * len(images)) - # Validate all arrays have same length before zipping - if len(image_ids) != len(images): - raise ValueError( - f"Length mismatch: image_ids ({len(image_ids)}) != images ({len(images)})" - ) - if len(image_ids) != len(reply_subjects) or len(image_ids) != len(image_urls): - raise ValueError( - f"Length mismatch: image_ids ({len(image_ids)}), " - f"reply_subjects ({len(reply_subjects)}), image_urls ({len(image_urls)})" - ) - - # Track start time for this batch - batch_start_time = datetime.datetime.now() - - logger.info(f"Processing batch {i + 1}") - # output is dict of "boxes", "labels", "scores" - batch_output = [] - if len(images) > 0: - batch_output = detector.predict_batch(images) + batch_results: list[AntennaTaskResult] = [] - items += len(batch_output) - logger.info(f"Total items processed so far: {items}") - batch_output = list(detector.post_process_batch(batch_output)) + try: + # Validate all arrays have same length before zipping + if len(image_ids) != len(images): + raise ValueError( + f"Length mismatch: image_ids ({len(image_ids)}) != images ({len(images)})" + ) + if len(image_ids) != len(reply_subjects) or len(image_ids) != len( + image_urls + ): + raise ValueError( + f"Length mismatch: image_ids ({len(image_ids)}), " + f"reply_subjects ({len(reply_subjects)}), image_urls ({len(image_urls)})" + ) - # Convert image_ids to list if needed - if isinstance(image_ids, (np.ndarray, torch.Tensor)): - image_ids = image_ids.tolist() + # Track start time for this batch + batch_start_time = datetime.datetime.now() - # TODO CGJS: Add seconds per item calculation for both detector and classifier - detector.save_results( - item_ids=image_ids, - batch_output=batch_output, - seconds_per_item=0, - ) - dt, t = t("Finished detection") - total_detection_time += dt - - # Group detections by image_id - image_detections: dict[str, list[DetectionResponse]] = { - img_id: [] for img_id in image_ids - } - image_tensors = dict(zip(image_ids, images, strict=True)) - - classifier.reset(detector.results) - - for idx, dresp in enumerate(detector.results): - image_tensor = image_tensors[dresp.source_image_id] - bbox = dresp.bbox - # crop the image tensor using the bbox - crop = image_tensor[ - :, int(bbox.y1) : int(bbox.y2), int(bbox.x1) : int(bbox.x2) - ] - 
crop = crop.unsqueeze(0) # add batch dimension - classifier_out = classifier.predict_batch(crop) - classifier_out = classifier.post_process_batch(classifier_out) - detection = classifier.update_detection_classification( - seconds_per_item=0, - image_id=dresp.source_image_id, - detection_idx=idx, - predictions=classifier_out[0], - ) - image_detections[dresp.source_image_id].append(detection) - all_detections.append(detection) + logger.info(f"Processing worker batch {i + 1} ({len(images)} images)") + # output is dict of "boxes", "labels", "scores" + batch_output = [] + if len(images) > 0: + batch_output = detector.predict_batch(images) - ct, t = t("Finished classification") - total_classification_time += ct + items += len(batch_output) + logger.info(f"Total items processed so far: {items}") + batch_output = list(detector.post_process_batch(batch_output)) - # Calculate batch processing time - batch_end_time = datetime.datetime.now() - batch_elapsed = (batch_end_time - batch_start_time).total_seconds() + # Convert image_ids to list if needed + if isinstance(image_ids, (np.ndarray, torch.Tensor)): + image_ids = image_ids.tolist() - # Post results back to the API with PipelineResponse for each image - batch_results: list[AntennaTaskResult] = [] - for reply_subject, image_id, image_url in zip( - reply_subjects, image_ids, image_urls, strict=True - ): - # Create SourceImageResponse for this image - source_image = SourceImageResponse(id=image_id, url=image_url) - - # Create PipelineResultsResponse - pipeline_response = PipelineResultsResponse( - pipeline=pipeline, - source_images=[source_image], - detections=image_detections[image_id], - total_time=batch_elapsed / len(image_ids), # Approximate time per image + # TODO CGJS: Add seconds per item calculation for both detector and classifier + detector.save_results( + item_ids=image_ids, + batch_output=batch_output, + seconds_per_item=0, ) + dt, t = t("Finished detection") + total_detection_time += dt + + # Group detections by image_id + image_detections: dict[str, list[DetectionResponse]] = { + img_id: [] for img_id in image_ids + } + image_tensors = dict(zip(image_ids, images, strict=True)) + + classifier.reset(detector.results) + to_pil = torchvision.transforms.ToPILImage() + classify_transforms = classifier.get_transforms() + + # Collect and transform all crops for batched classification + crops = [] + valid_indices = [] + for idx, dresp in enumerate(detector.results): + image_tensor = image_tensors[dresp.source_image_id] + bbox = dresp.bbox + y1, y2 = int(bbox.y1), int(bbox.y2) + x1, x2 = int(bbox.x1), int(bbox.x2) + if y1 >= y2 or x1 >= x2: + logger.warning( + f"Skipping detection {idx} with invalid bbox: " + f"({x1},{y1})->({x2},{y2})" + ) + continue + crop = image_tensor[:, y1:y2, x1:x2] + crop_pil = to_pil(crop) + crop_transformed = classify_transforms(crop_pil) + crops.append(crop_transformed) + valid_indices.append(idx) + + if crops: + batched_crops = torch.stack(crops) + classifier_out = classifier.predict_batch(batched_crops) + classifier_out = classifier.post_process_batch(classifier_out) + + for crop_i, idx in enumerate(valid_indices): + dresp = detector.results[idx] + detection = classifier.update_detection_classification( + seconds_per_item=0, + image_id=dresp.source_image_id, + detection_idx=idx, + predictions=classifier_out[crop_i], + ) + image_detections[dresp.source_image_id].append(detection) + all_detections.append(detection) + + ct, t = t("Finished classification") + total_classification_time += ct + + # Calculate batch 
processing time + batch_end_time = datetime.datetime.now() + batch_elapsed = (batch_end_time - batch_start_time).total_seconds() + + # Post results back to the API with PipelineResponse for each image + batch_results.clear() + for reply_subject, image_id, image_url in zip( + reply_subjects, image_ids, image_urls, strict=True + ): + # Create SourceImageResponse for this image + source_image = SourceImageResponse(id=image_id, url=image_url) + + # Create PipelineResultsResponse + pipeline_response = PipelineResultsResponse( + pipeline=pipeline, + source_images=[source_image], + detections=image_detections[image_id], + total_time=batch_elapsed + / len(image_ids), # Approximate time per image + ) - batch_results.append( - AntennaTaskResult( - reply_subject=reply_subject, - result=pipeline_response, + batch_results.append( + AntennaTaskResult( + reply_subject=reply_subject, + result=pipeline_response, + ) ) - ) + except Exception as e: + logger.error(f"Batch {i + 1} failed during processing: {e}", exc_info=True) + # Report errors back to Antenna so tasks aren't stuck in the queue + batch_results = [] + for reply_subject, image_id in zip(reply_subjects, image_ids, strict=True): + batch_results.append( + AntennaTaskResult( + reply_subject=reply_subject, + result=AntennaTaskResultError( + error=f"Batch processing error: {e}", + image_id=image_id, + ), + ) + ) + failed_items = batch.get("failed_items") if failed_items: for failed_item in failed_items: @@ -236,12 +321,10 @@ def _process_job( st, t = t("Finished posting results") if not success: - error_msg = ( + logger.error( f"Failed to post {len(batch_results)} results for job {job_id} to " f"{settings.antenna_api_base_url}. Batch processing data lost." ) - logger.error(error_msg) - raise RuntimeError(error_msg) total_save_time += st diff --git a/trapdata/ml/models/base.py b/trapdata/ml/models/base.py index bb7d1fa..1c694a7 100644 --- a/trapdata/ml/models/base.py +++ b/trapdata/ml/models/base.py @@ -244,11 +244,11 @@ def get_dataloader(self): """ if self.single: logger.info( - f"Preparing dataloader with batch size of {self.batch_size} in single worker mode." + f"Preparing {self.name} inference dataloader (batch_size={self.batch_size}, single worker mode)" ) else: logger.info( - f"Preparing dataloader with batch size of {self.batch_size} and {self.num_workers} workers." + f"Preparing {self.name} inference dataloader (batch_size={self.batch_size}, dataloader_workers={self.num_workers})" ) dataloader_args = { "num_workers": 0 if self.single else self.num_workers, diff --git a/trapdata/ml/utils.py b/trapdata/ml/utils.py index 3d52067..da09746 100644 --- a/trapdata/ml/utils.py +++ b/trapdata/ml/utils.py @@ -42,8 +42,14 @@ def get_device(device_str=None) -> torch.device: @TODO check Kivy settings to see if user forced use of CPU """ if not device_str: - device_str = "cuda" if torch.cuda.is_available() else "cpu" - device = torch.device(device_str) + if torch.cuda.is_available(): + # Use current_device() so mp.spawn workers that called + # torch.cuda.set_device(i) get the correct GPU index. 
+ device = torch.device("cuda", torch.cuda.current_device()) + else: + device = torch.device("cpu") + else: + device = torch.device(device_str) logger.info(f"Using device '{device}' for inference") return device diff --git a/trapdata/settings.py b/trapdata/settings.py index f4b83f1..0020bd5 100644 --- a/trapdata/settings.py +++ b/trapdata/settings.py @@ -33,14 +33,14 @@ class Settings(BaseSettings): default=ml.models.DEFAULT_FEATURE_EXTRACTOR ) classification_threshold: float = 0.6 - localization_batch_size: int = 2 + localization_batch_size: int = 8 classification_batch_size: int = 20 - num_workers: int = 1 + num_workers: int = 2 # Antenna API worker settings antenna_api_base_url: str = "http://localhost:8000/api/v2" antenna_api_auth_token: str = "" - antenna_api_batch_size: int = 4 + antenna_api_batch_size: int = 16 @pydantic.field_validator("image_base_path", "user_data_path") def validate_path(cls, v): @@ -143,8 +143,11 @@ class Config: "kivy_section": "performance", }, "num_workers": { - "title": "Number of workers", - "description": "Number of parallel workers for the PyTorch dataloader. See https://pytorch.org/docs/stable/data.html", + "title": "DataLoader workers", + "description": ( + "Number of parallel subprocesses for the PyTorch DataLoader (image downloading & preprocessing). " + "See https://pytorch.org/docs/stable/data.html" + ), "kivy_type": "numeric", "kivy_section": "performance", },