Batch classification predictions in worker for GPU efficiency#105

Closed
Copilot wants to merge 5 commits into carlosg/pulldl from copilot/batch-classification-predictions

Conversation

Copilot AI (Contributor) commented Jan 28, 2026

Worker code from PR #94 classifies each detection crop individually (batch size = 1), resulting in N GPU calls for N detections. This creates high kernel launch and memory transfer overhead.

Implementation

Batches all crops together following the same pattern used by the DB queue and API pipelines:

# Before: N GPU calls
for dresp in detector.results:
    # ... crop the image for this detection ...
    crop = crop.unsqueeze(0)  # batch size = 1
    classifier_out = classifier.predict_batch(crop)

# After: 1 GPU call
crops = []
for dresp in detector.results:
    # ... crop the image for this detection ...
    crop_pil = torchvision.transforms.ToPILImage()(crop)
    crop_transformed = classifier.get_transforms()(crop_pil)  # resize to the model's input_size
    crops.append(crop_transformed)

if crops:
    batched_crops = torch.stack(crops)
    classifier_out = classifier.predict_batch(batched_crops)

Transforms resize each crop to the model's fixed input_size before batching, so torch.stack() works without padding.
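
For context, a typical torchvision pipeline that yields uniform (3, H, W) tensors looks like the sketch below; the specific transforms and the 300-pixel input size are assumptions for illustration, not necessarily what classifier.get_transforms() returns in this repo.

import torchvision.transforms as T

INPUT_SIZE = 300  # assumed value; the real size comes from the classifier's config

example_transforms = T.Compose([
    T.Resize((INPUT_SIZE, INPUT_SIZE)),   # every crop becomes the same spatial size
    T.ToTensor(),                         # PIL image -> float tensor in [0, 1]
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # ImageNet stats
])

# All transformed crops are (3, INPUT_SIZE, INPUT_SIZE), so torch.stack needs no padding.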

Changes

  • trapdata/cli/worker.py: Collect crops, apply transforms, batch classify
  • trapdata/api/models/classification.py: Add reset() and update_detection_classification() methods
  • trapdata/api/models/localization.py: Add reset() method
  • trapdata/cli/base.py: Add worker CLI command
  • trapdata/settings.py: Add Antenna API settings (base_url, auth_token, batch_size)
  • trapdata/api/: Add REST dataset, schemas, and supporting code for Antenna integration

Edge Cases

  • Skip detections with invalid bounding boxes (y1 ≥ y2 or x1 ≥ x2)
  • Handle empty crops list before stacking
  • Avoid division by zero when the batch has no valid images (see the sketch below)
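
A compact sketch of how those guards could sit in the crop-collection loop; the bbox attribute names are assumptions, not the exact worker code, and detector, classifier, and image_tensor come from the surrounding worker context.

import torch
import torchvision

valid_crops = []
for dresp in detector.results:
    bbox = dresp.bbox  # hypothetical attribute; the real result object may differ
    x1, y1, x2, y2 = int(bbox.x1), int(bbox.y1), int(bbox.x2), int(bbox.y2)
    if y1 >= y2 or x1 >= x2:
        continue  # skip invalid boxes instead of producing an empty crop
    crop = image_tensor[:, y1:y2, x1:x2]
    crop_pil = torchvision.transforms.ToPILImage()(crop)
    valid_crops.append(classifier.get_transforms()(crop_pil))

if valid_crops:  # guard against an empty list before stacking
    classifier_out = classifier.predict_batch(torch.stack(valid_crops))
    batch_size = len(valid_crops)  # use the valid count so averages never divide by zero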

Testing

  • trapdata/cli/tests/test_worker_batching.py: Verifies predict_batch called once with batch size = N
  • trapdata/api/tests/test_worker.py: Integration tests for REST worker

Expected Impact

For images with 10+ detections, the N individual GPU calls collapse into a single batched call, roughly an N× reduction in per-call overhead (kernel launches and host-to-device transfers).

Original prompt

This section describes the original issue you should resolve

<issue_title>Worker: Batch classification predictions for GPU efficiency</issue_title>
<issue_description>## Summary

The worker currently classifies each detection crop one at a time in a loop, which is inefficient on GPU. Crops can be batched together for a single forward pass.

Question: the classification step in the existing dataloader & classification pipeline already batches crops (defaults to 100 or 200). How does it do it without resizing or padding? Review the existing pipeline (the one that reads from the local DB queue and does batch processing).

Current behavior

In trapdata/cli/worker.py:212-229, each detection is processed individually:

for idx, dresp in enumerate(detector.results):
    # ... crop the image ...
    crop = crop.unsqueeze(0)  # batch size = 1
    classifier_out = classifier.predict_batch(crop)  # single item
    # ...

This causes:

  • High per-call overhead (GPU memory transfers, kernel launches)
  • Poor GPU utilization with batch size = 1
  • Slower processing, especially for images with many detections

Proposed solution

Collect all crops first, then classify in a single batched call:

# Collect all crops
crops = []
for dresp in detector.results:
    crop = image_tensor[:, int(bbox.y1):int(bbox.y2), int(bbox.x1):int(bbox.x2)]  # bbox comes from dresp
    crops.append(crop)

# Batch and classify all at once
if crops:
    # May need padding/resizing for uniform tensor sizes
    batched_crops = torch.stack(crops)  
    classifier_out = classifier.predict_batch(batched_crops)
    # ... map predictions back to detections ...

Considerations

  • Crops may have different sizes - need to resize or pad to uniform dimensions
  • Memory constraints - may need to sub-batch if too many detections
  • Need benchmarking to measure actual improvement
  • Should verify classification accuracy is unchanged after batching
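
One way to handle the memory consideration above is to chunk the crops into fixed-size sub-batches before stacking; the helper below is a sketch with an arbitrary chunk size, not code from the repository.

import torch

def classify_in_sub_batches(classifier, crops, sub_batch_size=64):
    """Run predict_batch over fixed-size chunks to bound peak GPU memory."""
    outputs = []
    for start in range(0, len(crops), sub_batch_size):
        chunk = torch.stack(crops[start:start + sub_batch_size])
        outputs.append(classifier.predict_batch(chunk))
    return outputs  # one prediction batch per chunk; concatenate as needed downstream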

Related

Test plan

  • Benchmark current vs batched implementation on sample images
  • Verify classification results are identical
  • Test with images containing many detections (10+, 50+)
  • Test memory usage with large batches</issue_description>

<agent_instructions>Question: the classification step in the existing dataloader & classification pipeline already batches crops (defaults to 100 or 200). How does it do it without resizing or padding? Review the existing pipeline (the one that reads from the local DB queue and does batch processing).</agent_instructions>

Comments on the Issue (you are @copilot in this section)



Copilot AI and others added 2 commits January 28, 2026 19:50
Co-authored-by: mihow <158175+mihow@users.noreply.github.com>
Co-authored-by: mihow <158175+mihow@users.noreply.github.com>
mihow changed the base branch from main to carlosg/pulldl January 28, 2026 19:52
Copilot AI and others added 2 commits January 28, 2026 19:54
Co-authored-by: mihow <158175+mihow@users.noreply.github.com>
Co-authored-by: mihow <158175+mihow@users.noreply.github.com>
Copilot AI changed the title from "[WIP] Optimize batch classification predictions for GPU efficiency" to "Batch classification predictions in worker for GPU efficiency" Jan 28, 2026
Copilot AI requested a review from mihow January 28, 2026 20:01
mihow added a commit that referenced this pull request Feb 10, 2026
perf: batch classification crops in worker instead of N individual GPU calls

Collect all detection crops, apply classifier transforms (which resize
to uniform input_size), then run a single batched predict_batch call.
Skips detections with invalid bounding boxes (y1 >= y2 or x1 >= x2).
This supersedes PR #105.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
mihow (Collaborator) commented Feb 10, 2026

Covered by #110

mihow closed this Feb 10, 2026
mihow added a commit that referenced this pull request Feb 17, 2026
…classification (#110)

* fix: return images as list in rest_collate_fn to support variable sizes

torch.stack requires all tensors to be the same size, which crashes
when a batch contains images of different resolutions (e.g. 3420x6080
and 2160x4096). FasterRCNN natively accepts a list of variable-sized
images, and predict_batch already handles this code path.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
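
For illustration, a collate function with this behaviour might look like the sketch below; the real rest_collate_fn and the batch item keys may differ, so treat the names as assumptions.

def rest_collate_fn(batch):
    # Sketch only: return images as a plain list rather than torch.stack-ing them,
    # because images in one batch can have different resolutions and FasterRCNN
    # accepts a list of variable-sized image tensors.
    images = [item["image"] for item in batch]        # assumed item keys
    metadata = [item["metadata"] for item in batch]
    return images, metadata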

* fix: handle batch processing errors per-batch instead of crashing job

Wrap the batch processing loop body in try/except so a single failed
batch doesn't kill the entire job. On failure, error results are posted
back to Antenna for each image in the batch so tasks don't get stuck
in the queue. Also downgrade post_batch_results failure from a raised
exception to a logged error to avoid losing progress on subsequent
batches.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
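
The guard described above is shaped roughly like the sketch below; process_batch, post_batch_results, and post_error_results are placeholder names, not the repo's actual functions.

import logging

logger = logging.getLogger(__name__)

for batch in dataloader:                    # placeholder loop over worker batches
    try:
        results = process_batch(batch)      # detection + classification (placeholder)
        post_batch_results(results)         # placeholder for posting back to Antenna
    except Exception:
        logger.exception("Batch failed; reporting errors so tasks are not stuck in the queue")
        post_error_results(batch)           # placeholder: one error result per image in the batch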

* perf: batch classification crops in worker instead of N individual GPU calls

Collect all detection crops, apply classifier transforms (which resize
to uniform input_size), then run a single batched predict_batch call.
Skips detections with invalid bounding boxes (y1 >= y2 or x1 >= x2).
This supersedes PR #105.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* perf: increase default localization_batch_size and num_workers for GPU

localization_batch_size 2 → 8 and num_workers 1 → 4. The old defaults
were far too conservative for 24GB VRAM GPUs. These can still be
overridden via AMI_LOCALIZATION_BATCH_SIZE and AMI_NUM_WORKERS env vars.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* feat: spawn one worker process per GPU for multi-GPU inference

run_worker() now detects torch.cuda.device_count() and uses
torch.multiprocessing.spawn to launch one worker per GPU. Each worker
pins itself to a specific GPU via set_device(). Single-GPU and CPU-only
machines keep existing single-process behavior with no overhead.

Also fixes get_device() to use current_device() instead of bare "cuda"
so that models load onto the correct GPU in spawned workers.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
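
The multi-GPU pattern described above corresponds roughly to the sketch below; run_single_worker and the settings argument are placeholders for the existing single-process loop.

import torch
import torch.multiprocessing as mp

def _worker_entry(gpu_index: int, settings):
    # Each spawned process pins itself to one GPU before loading models.
    torch.cuda.set_device(gpu_index)
    run_single_worker(settings)  # placeholder for the existing worker loop

def run_worker(settings):
    n_gpus = torch.cuda.device_count()
    if n_gpus > 1:
        # One process per GPU; spawn passes the process index as the first argument.
        mp.spawn(_worker_entry, args=(settings,), nprocs=n_gpus)
    else:
        # Single-GPU or CPU-only machines keep the single-process behavior.
        run_single_worker(settings)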

* fix: clarify batch size log messages to distinguish worker vs inference

Worker batch logs now show image count ("Processing worker batch 3 (8
images)") and model inference logs include the model name ("Preparing
FasterRCNN inference dataloader (batch_size=4, single worker mode)").

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* perf: increase DataLoader prefetch_factor to 4 for worker

The default prefetch_factor of 2 means the DataLoader only prepares 2
batches ahead. With GPU inference taking ~2s and image downloads taking
~30s per batch, the GPU idles waiting for data. Bumping to 4 keeps the
download pipeline fuller so the next batch is more likely to be ready.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* perf: download images concurrently within each DataLoader worker

Previously each DataLoader worker downloaded images sequentially: fetch
task metadata, then download image 1, download image 2, ... download
image N, yield all. With 32 tasks per API fetch, this meant ~30s of
serial HTTP requests before a single image reached the GPU.

Now _load_images_threaded() uses a ThreadPoolExecutor (up to 8 threads)
to download all images in a task batch concurrently. Threads are ideal
here because image downloads are I/O-bound (network latency), not
CPU-bound, and the requests Session's connection pool is thread-safe.

This stacks with the existing DataLoader num_workers parallelism:
- num_workers: N independent DataLoader processes, each polling the API
- ThreadPoolExecutor: within each process, M concurrent image downloads
- prefetch_factor: DataLoader queues future batches while GPU is busy

Expected improvement: download time for a batch of 32 images drops from
~30s (sequential) to ~4-5s (8 concurrent threads), keeping the GPU fed.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* docs: clarify types of workers

* refactor: make ThreadPoolExecutor a class member, simplify with map()

- Keep the thread pool alive across batches instead of recreating it
- Use executor.map() instead of submit + as_completed
- Fix docstring: requests.Session is not formally thread-safe

Addresses review comments from @carlosgjs.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
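
Combined with the refactor above, the download path might look roughly like this sketch; the class and method names are illustrative, not the actual trapdata code.

from concurrent.futures import ThreadPoolExecutor
import requests

class RESTDatasetSketch:
    """Illustrative only: concurrent image downloads within one DataLoader worker."""

    def __init__(self, max_download_threads: int = 8):
        # Note: requests.Session is not formally thread-safe; the real code documents this caveat.
        self.session = requests.Session()
        # Kept alive across batches instead of being recreated per call.
        self.executor = ThreadPoolExecutor(max_workers=max_download_threads)

    def _download(self, url: str) -> bytes:
        resp = self.session.get(url, timeout=30)
        resp.raise_for_status()
        return resp.content

    def _load_images_threaded(self, urls: list[str]) -> list[bytes]:
        # map() preserves input order, so results line up with the task list.
        return list(self.executor.map(self._download, urls))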

* docs: fix stale docstring in rest_collate_fn

"Stacked tensor" → "List of image tensors" to match the actual
return type after the variable-size image change.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* fix: defensive batch_results init and strict=True in error handler

- Initialize batch_results before the try block to prevent potential
  NameError if a future refactor introduces a path between try/except
  and the post-results code.
- Add strict=True to zip(reply_subjects, image_ids) in the except
  block so length mismatches raise immediately rather than silently
  dropping error reports.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* revert: remove prefetch_factor=4 override, use PyTorch default

No measured improvement over the default (2). The override just
increases memory usage without demonstrated benefit.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* docs: document data loading pipeline concurrency layers

Add a module-level overview mapping the three concurrency layers
(GPU processes, DataLoader workers, thread pool) to their settings,
what work each layer does, and what runs under the GIL.

Includes a "not yet benchmarked" section so future contributors
know which knobs are tuned empirically vs speculatively.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* perf: tune defaults for higher GPU utilization

- antenna_api_batch_size: 4 → 16 (fetch enough tasks per API call to
  fill at least one GPU batch without an extra round trip)
- num_workers: 4 → 2 (each worker prefetches more with larger API
  batches; fewer workers reduces CPU/RAM overhead)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* fix: lazily init unpicklable objects in RESTDataset for num_workers>0

The ThreadPoolExecutor (which holds a SimpleQueue) and the requests.Session objects created in
__init__ cannot be pickled. PyTorch DataLoader with num_workers>0 uses
spawn, which pickles the dataset to send to worker subprocesses.

Move session and executor creation to _ensure_sessions(), called on
first use in each worker process. Add regression test that creates a
DataLoader with num_workers=2 and verifies workers can start.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
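
The lazy-initialization fix described above might look roughly like this sketch, again with assumed names rather than the repo's exact code.

from concurrent.futures import ThreadPoolExecutor
import requests
from torch.utils.data import IterableDataset

class RESTDatasetSketch(IterableDataset):
    """Illustrative: defer unpicklable members until first use in each worker process."""

    def __init__(self, max_download_threads: int = 8):
        self.max_download_threads = max_download_threads
        self._session = None   # created lazily so the dataset itself stays picklable
        self._executor = None  # for DataLoader worker subprocesses (spawn start method)

    def _ensure_sessions(self):
        if self._session is None:
            self._session = requests.Session()
        if self._executor is None:
            self._executor = ThreadPoolExecutor(max_workers=self.max_download_threads)

    def __iter__(self):
        self._ensure_sessions()  # first use inside the worker process
        # ... fetch tasks from the API and yield downloaded images ...
        yield from ()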

* docs: fix stale num_workers default in datasets.py docstring (4 → 2)

Matches the actual default changed in ed49153.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
