Batch classification predictions in worker for GPU efficiency #105
Closed
Copilot wants to merge 5 commits into carlosg/pulldl from
Conversation
Co-authored-by: mihow <158175+mihow@users.noreply.github.com>
Copilot (AI) changed the title from [WIP] Optimize batch classification predictions for GPU efficiency to Batch classification predictions in worker for GPU efficiency on Jan 28, 2026
mihow added a commit that referenced this pull request on Feb 10, 2026
perf: batch classification crops in worker instead of N individual GPU calls
Collect all detection crops, apply classifier transforms (which resize to uniform input_size), then run a single batched predict_batch call. Skips detections with invalid bounding boxes (y1 >= y2 or x1 >= x2). This supersedes PR #105.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
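For illustration, a minimal sketch of the crop-collection step this commit describes, assuming detections arrive as (x1, y1, x2, y2) pixel boxes over a CHW image tensor; the function name and surrounding structure are hypothetical, not the repository's actual code.

```python
import torch


def collect_valid_crops(image: torch.Tensor, boxes):
    """Extract classifier crops from one (C, H, W) image tensor, skipping
    degenerate boxes where y1 >= y2 or x1 >= x2 (zero or negative area)."""
    crops, kept = [], []
    for i, (x1, y1, x2, y2) in enumerate(boxes):
        if y1 >= y2 or x1 >= x2:
            continue  # invalid bounding box: skip it, as described in the commit
        crops.append(image[:, y1:y2, x1:x2])  # crop by pixel coordinates
        kept.append(i)
    return crops, kept
```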
Collaborator
Covered by #110
mihow added a commit that referenced this pull request on Feb 17, 2026
…classification (#110)

* fix: return images as list in rest_collate_fn to support variable sizes
  torch.stack requires all tensors to be the same size, which crashes when a batch contains images of different resolutions (e.g. 3420x6080 and 2160x4096). FasterRCNN natively accepts a list of variable-sized images, and predict_batch already handles this code path.

* fix: handle batch processing errors per-batch instead of crashing job
  Wrap the batch processing loop body in try/except so a single failed batch doesn't kill the entire job. On failure, error results are posted back to Antenna for each image in the batch so tasks don't get stuck in the queue. Also downgrade post_batch_results failure from a raised exception to a logged error to avoid losing progress on subsequent batches.

* perf: batch classification crops in worker instead of N individual GPU calls
  Collect all detection crops, apply classifier transforms (which resize to uniform input_size), then run a single batched predict_batch call. Skips detections with invalid bounding boxes (y1 >= y2 or x1 >= x2). This supersedes PR #105.

* perf: increase default localization_batch_size and num_workers for GPU
  localization_batch_size 2 → 8 and num_workers 1 → 4. The old defaults were far too conservative for 24GB VRAM GPUs. These can still be overridden via the AMI_LOCALIZATION_BATCH_SIZE and AMI_NUM_WORKERS env vars.

* feat: spawn one worker process per GPU for multi-GPU inference
  run_worker() now detects torch.cuda.device_count() and uses torch.multiprocessing.spawn to launch one worker per GPU. Each worker pins itself to a specific GPU via set_device(). Single-GPU and CPU-only machines keep the existing single-process behavior with no overhead. Also fixes get_device() to use current_device() instead of bare "cuda" so that models load onto the correct GPU in spawned workers.

* fix: clarify batch size log messages to distinguish worker vs inference
  Worker batch logs now show the image count ("Processing worker batch 3 (8 images)") and model inference logs include the model name ("Preparing FasterRCNN inference dataloader (batch_size=4, single worker mode)").

* perf: increase DataLoader prefetch_factor to 4 for worker
  The default prefetch_factor of 2 means the DataLoader only prepares 2 batches ahead. With GPU inference taking ~2s and image downloads taking ~30s per batch, the GPU idles waiting for data. Bumping to 4 keeps the download pipeline fuller so the next batch is more likely to be ready.

* perf: download images concurrently within each DataLoader worker
  Previously each DataLoader worker downloaded images sequentially: fetch task metadata, then download image 1, download image 2, ... download image N, yield all. With 32 tasks per API fetch, this meant ~30s of serial HTTP requests before a single image reached the GPU. Now _load_images_threaded() uses a ThreadPoolExecutor (up to 8 threads) to download all images in a task batch concurrently. Threads are ideal here because image downloads are I/O-bound (network latency), not CPU-bound, and the requests Session's connection pool is thread-safe.
  This stacks with the existing DataLoader num_workers parallelism:
  - num_workers: N independent DataLoader processes, each polling the API
  - ThreadPoolExecutor: within each process, M concurrent image downloads
  - prefetch_factor: DataLoader queues future batches while the GPU is busy
  Expected improvement: download time for a batch of 32 images drops from ~30s (sequential) to ~4-5s (8 concurrent threads), keeping the GPU fed.

* docs: clarify types of workers

* refactor: make ThreadPoolExecutor a class member, simplify with map()
  - Keep the thread pool alive across batches instead of recreating it
  - Use executor.map() instead of submit + as_completed
  - Fix docstring: requests.Session is not formally thread-safe
  Addresses review comments from @carlosgjs.

* docs: fix stale docstring in rest_collate_fn
  "Stacked tensor" → "List of image tensors" to match the actual return type after the variable-size image change.

* fix: defensive batch_results init and strict=True in error handler
  - Initialize batch_results before the try block to prevent a potential NameError if a future refactor introduces a path between the try/except and the post-results code.
  - Add strict=True to zip(reply_subjects, image_ids) in the except block so length mismatches raise immediately rather than silently dropping error reports.

* revert: remove prefetch_factor=4 override, use PyTorch default
  No measured improvement over the default (2). The override just increases memory usage without demonstrated benefit.

* docs: document data loading pipeline concurrency layers
  Add a module-level overview mapping the three concurrency layers (GPU processes, DataLoader workers, thread pool) to their settings, what work each layer does, and what runs under the GIL. Includes a "not yet benchmarked" section so future contributors know which knobs are tuned empirically vs speculatively.

* perf: tune defaults for higher GPU utilization
  - antenna_api_batch_size: 4 → 16 (fetch enough tasks per API call to fill at least one GPU batch without an extra round trip)
  - num_workers: 4 → 2 (each worker prefetches more with larger API batches; fewer workers reduce CPU/RAM overhead)

* fix: lazily init unpicklable objects in RESTDataset for num_workers>0
  ThreadPoolExecutor (which holds a SimpleQueue) and requests.Session objects created in __init__ cannot be pickled. PyTorch DataLoader with num_workers>0 uses spawn, which pickles the dataset to send to worker subprocesses. Move session and executor creation to _ensure_sessions(), called on first use in each worker process. Add a regression test that creates a DataLoader with num_workers=2 and verifies workers can start.

* docs: fix stale num_workers default in datasets.py docstring (4 → 2)
  Matches the actual default changed in ed49153.

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
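The first commit above replaces tensor stacking with a plain list in the detector's collate function. A minimal sketch of that pattern, assuming each dataset item is a dict with "image" and "metadata" keys; the real rest_collate_fn signature in the repository may differ.

```python
import torch


def rest_collate_fn(batch: list[dict]) -> tuple[list[torch.Tensor], list[dict]]:
    """Keep images as a list instead of torch.stack-ing them.

    torch.stack requires identical shapes, which fails when a batch mixes
    e.g. 3420x6080 and 2160x4096 images. Torchvision's FasterRCNN accepts a
    plain list of (C, H, W) tensors, so no padding or resizing is needed here.
    """
    images = [item["image"] for item in batch]       # variable-sized tensors
    metadata = [item["metadata"] for item in batch]  # per-image task info
    return images, metadata
```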
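The multi-GPU commit launches one worker process per GPU with torch.multiprocessing.spawn and pins each process to its device. A hedged sketch of that shape; worker_main's body is a placeholder for model loading and the task-polling loop, not the actual run_worker() implementation.

```python
import torch
import torch.multiprocessing as mp


def worker_main(gpu_index: int):
    """Entry point for one worker process, pinned to a single GPU."""
    if torch.cuda.is_available():
        # Pin this process to its GPU so later model loads land on it.
        torch.cuda.set_device(gpu_index)
        device = torch.device("cuda", torch.cuda.current_device())
    else:
        device = torch.device("cpu")
    # ... load models onto `device` and run the task-polling loop ...
    print(f"worker {gpu_index} using {device}")


def run_worker():
    n_gpus = torch.cuda.device_count()
    if n_gpus > 1:
        # One process per GPU; each receives its index as the first argument.
        mp.spawn(worker_main, nprocs=n_gpus)
    else:
        # Single-GPU and CPU-only machines keep the single-process path.
        worker_main(0)


if __name__ == "__main__":
    run_worker()
```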
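The num_workers>0 fix moves unpicklable members out of __init__ because the DataLoader pickles the dataset when spawning worker processes. A sketch of that lazy-initialization pattern combined with the threaded downloads; only the _ensure_sessions and _load_images_threaded names come from the commit messages, the rest is assumed.

```python
from concurrent.futures import ThreadPoolExecutor

import requests
from torch.utils.data import IterableDataset


class RESTDataset(IterableDataset):
    """Illustrative dataset fetching tasks over HTTP.

    __iter__ is omitted here; it would poll the API for task batches and
    yield the downloaded images.
    """

    def __init__(self, api_url: str, download_threads: int = 8):
        self.api_url = api_url
        self.download_threads = download_threads
        # Do NOT create the Session or ThreadPoolExecutor here: with
        # num_workers > 0 the DataLoader pickles the dataset to send it to
        # spawned worker subprocesses, and neither object can be pickled.
        self._session = None
        self._executor = None

    def _ensure_sessions(self):
        """Lazily create unpicklable members on first use in each worker."""
        if self._session is None:
            self._session = requests.Session()
        if self._executor is None:
            self._executor = ThreadPoolExecutor(max_workers=self.download_threads)

    def _download(self, url: str) -> bytes:
        return self._session.get(url, timeout=30).content

    def _load_images_threaded(self, urls: list[str]) -> list[bytes]:
        self._ensure_sessions()
        # Downloads are I/O-bound, so threads overlap the network latency;
        # executor.map keeps results in input order.
        return list(self._executor.map(self._download, urls))
```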
Worker code from PR #94 classifies each detection crop individually (batch size = 1), resulting in N GPU calls for N detections. This creates high kernel launch and memory transfer overhead.
Implementation
Batches all crops together following the same pattern used by the DB queue and API pipelines:
Transforms resize each crop to the model's fixed input_size before batching, so torch.stack() works without padding.
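A rough sketch of this pattern; classifier, transform, and predict_batch here stand in for the actual objects in trapdata/api/models/classification.py and are not the exact repository API.

```python
import torch


def classify_crops_batched(crops, transform, classifier):
    """Classify all detection crops in one forward pass.

    Each transform resizes a crop to the classifier's fixed input_size, so
    the resulting tensors share a shape and torch.stack needs no padding.
    """
    if not crops:
        return []
    batch = torch.stack([transform(crop) for crop in crops])  # (N, C, H, W)
    # One predict_batch call instead of N single-crop calls.
    return classifier.predict_batch(batch)
```

One forward pass amortizes kernel launches and host-to-device transfers across all N crops.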
Changes
- trapdata/cli/worker.py: Collect crops, apply transforms, batch classify
- trapdata/api/models/classification.py: Add reset() and update_detection_classification() methods
- trapdata/api/models/localization.py: Add reset() method
- trapdata/cli/base.py: Add worker CLI command
- trapdata/settings.py: Add Antenna API settings (base_url, auth_token, batch_size)
- trapdata/api/: Add REST dataset, schemas, and supporting code for Antenna integration
Edge Cases
Testing
- trapdata/cli/tests/test_worker_batching.py: Verifies predict_batch is called once with batch size = N
- trapdata/api/tests/test_worker.py: Integration tests for REST worker
Expected Impact
For images with 10+ detections: an N× reduction in GPU call overhead, since one batched call replaces N individual calls.