
fix: worker image size mismatches, per-batch error handling, batched classification#110

Merged
mihow merged 17 commits into main from fix/size-mismatches on Feb 17, 2026

Conversation

mihow (Collaborator) commented on Feb 10, 2026

Summary

  • Variable-size image crash: rest_collate_fn used torch.stack which requires uniform tensor sizes. When a batch contained images of different resolutions (e.g. 3420x6080 and 2160x4096), the entire job crashed and restarted in a loop. Images are now passed as a list of tensors, which FasterRCNN natively accepts.

  • Per-batch error handling: A single failed batch previously killed the entire _process_job call, resetting the counter and retrying from scratch. Batch processing is now wrapped in try/except: errors are reported back to Antenna for the affected images, and the loop continues to the next batch. post_batch_results failure is also downgraded from a raised exception to a logged error.

  • Batched classification: Detection crops were classified one at a time (N GPU calls for N detections). Crops are now collected, transformed to a uniform size via classifier.get_transforms(), stacked, and classified in a single predict_batch call. Invalid bounding boxes (y1 >= y2 or x1 >= x2) are skipped. Supersedes #105 (Batch classification predictions in worker for GPU efficiency).

Test plan

  • All 11 existing worker tests pass
  • Manual test with ami worker on images of varying resolutions

🤖 Generated with Claude Code

Summary by CodeRabbit

  • New Features

    • Threaded concurrent image downloads per DataLoader worker and GPU-aware worker spawning for multi-GPU runs.
  • Bug Fixes

    • Improved batch-level error handling to avoid worker crashes and report failures gracefully.
    • Collation now returns a list of per-image tensors to handle mixed/failed items robustly.
  • Performance

    • Prefetching and batched classification increase throughput.
  • Chores

    • Increased default batch sizes and worker counts; updated docs and logs.
  • Tests

    • Updated tests to expect per-image list outputs.

mihow and others added 3 commits February 9, 2026 18:27
torch.stack requires all tensors to be the same size, which crashes
when a batch contains images of different resolutions (e.g. 3420x6080
and 2160x4096). FasterRCNN natively accepts a list of variable-sized
images, and predict_batch already handles this code path.

Co-Authored-By: Claude Opus 4.6 <[email protected]>
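
A minimal sketch of the list-based collate described in the commit above. The dict keys and batch-item structure are assumptions for illustration, not the exact rest_collate_fn in trapdata/antenna/datasets.py (which also separates failed items):

def rest_collate_fn(batch: list[dict]) -> dict:
    """Collate REST dataset items without torch.stack (sketch).

    torch.stack requires every tensor to share one shape, so a batch mixing
    e.g. 3420x6080 and 2160x4096 images would crash. Returning a plain list
    keeps variable-size images, which FasterRCNN accepts directly.
    """
    return {
        "image_ids": [item["image_id"] for item in batch],
        "images": [item["image"] for item in batch],  # list of tensors; shapes may differ
    }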
Wrap the batch processing loop body in try/except so a single failed
batch doesn't kill the entire job. On failure, error results are posted
back to Antenna for each image in the batch so tasks don't get stuck
in the queue. Also downgrade post_batch_results failure from a raised
exception to a logged error to avoid losing progress on subsequent
batches.

Co-Authored-By: Claude Opus 4.6 <[email protected]>
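
Roughly, the control flow now looks like the sketch below. post_batch_results is the function named in this PR; process_batch and make_error_result are hypothetical stand-ins for the detection/classification body and the per-image error construction:

for batch_index, batch in enumerate(dataloader):
    try:
        batch_results = process_batch(batch)  # detection + classification for this batch
    except Exception as exc:
        logger.exception(f"Batch {batch_index} failed: {exc}")
        # Post an error result per image so the tasks don't stay stuck in the queue.
        batch_results = [
            make_error_result(image_id, str(exc)) for image_id in batch["image_ids"]
        ]
    try:
        post_batch_results(batch_results)
    except Exception as exc:
        # Downgraded from a raised exception: one failed POST should not
        # abort the remaining batches.
        logger.error(f"Failed to post results for batch {batch_index}: {exc}")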
…U calls

Collect all detection crops, apply classifier transforms (which resize
to uniform input_size), then run a single batched predict_batch call.
Skips detections with invalid bounding boxes (y1 >= y2 or x1 >= x2).
This supersedes PR #105.

Co-Authored-By: Claude Opus 4.6 <[email protected]>
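
A sketch of the batched classification path. classifier.get_transforms() and predict_batch are named in the commit, but the detection/bbox structure and tensor indexing below are assumptions (the actual worker also converts crops to PIL images before applying the transforms):

import torch

def classify_detections(image_tensor, detections, classifier):
    """Classify every crop from one image in a single GPU call (sketch)."""
    transform = classifier.get_transforms()  # resizes each crop to the classifier's input_size
    crops, kept = [], []
    for det in detections:
        x1, y1, x2, y2 = det["bbox"]
        if y1 >= y2 or x1 >= x2:
            continue  # skip invalid boxes instead of crashing the whole batch
        crops.append(transform(image_tensor[:, y1:y2, x1:x2]))
        kept.append(det)
    if not crops:
        return []
    scores = classifier.predict_batch(torch.stack(crops))  # one call for N crops
    for det, score in zip(kept, scores, strict=True):
        det["classification"] = score
    return kept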
coderabbitai bot commented Feb 10, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

📝 Walkthrough

Implements lazy, threaded image prefetching in RESTDataset (per-worker ThreadPoolExecutor and HTTP sessions), changes collate to return lists of per-image tensors, adds GPU-aware multiprocessing with per-GPU worker loops and batched detection→classification, and updates device selection and defaults.

Changes

  • REST dataset & collate (trapdata/antenna/datasets.py): Adds lazy session & executor lifecycle management and _load_images_threaded(self, tasks) to parallel-download a batch of images (returning a map of image_id → tensor, or None for failed downloads); rest_collate_fn now returns images as a list of per-image tensors.
  • Worker: GPU multiprocessing & batched flow (trapdata/antenna/worker.py): Introduces torch.multiprocessing and mp.spawn to run one worker per GPU (with a single-loop fallback), adds _worker_loop(gpu_id, ...) with GPU pinning and GPU-scoped logging, wraps batch processing in try/except, validates bboxes, extracts/validates crops, stacks crops for batched classifier inference, and posts per-image results or structured batch errors.
  • Tests (trapdata/antenna/tests/test_worker.py): Adds TestDataLoaderMultiWorker to ensure the dataset is picklable for DataLoader num_workers > 0; updates rest_collate_fn tests to expect images as a list of tensors and assert per-item shapes.
  • Device selection utility (trapdata/ml/utils.py): get_device now prefers torch.cuda.current_device() when CUDA is available, returning a device pinned to the current CUDA index; explicit device_str handling is unchanged.
  • Model dataloader logging (trapdata/ml/models/base.py): Log messages in get_dataloader now include the model name and formatted parameters; behavior unchanged.
  • Settings defaults & metadata (trapdata/settings.py): Default values changed: localization_batch_size 2 → 8, num_workers 1 → 2, antenna_api_batch_size 4 → 16; the num_workers field title is now "DataLoader workers" and its description mentions the PyTorch DataLoader (image downloading & preprocessing).

Sequence Diagram

sequenceDiagram
    participant JobQueue as Job Queue
    participant DataLoader as DataLoader subprocess
    participant Worker as Worker (per-GPU)
    participant Detector as Detector Model
    participant Classifier as Classifier Model
    participant API as Antenna API

    JobQueue->>Worker: Deliver job / batch
    Worker->>DataLoader: Request images (batch)
    DataLoader->>DataLoader: Threaded download -> map image_id: tensor
    DataLoader-->>Worker: Return batch image tensors

    Worker->>Detector: Run detection on images
    Detector-->>Worker: Detections (bboxes)

    rect rgba(100,150,255,0.5)
    Worker->>Worker: Validate bboxes, extract crops, convert to PIL, apply transforms
    Worker->>Classifier: Send stacked batch of crops
    Classifier-->>Worker: Return classification scores
    Worker->>Worker: Merge scores into detections
    end

    alt Batch succeeds
        Worker->>API: Post PipelineResultsResponse(s)
        API-->>Worker: Ack
    else Batch fails
        Worker->>API: Post AntennaTaskResultError entries
        API-->>Worker: Ack
    end

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Poem

🐇
I hop through bytes and channel queues,
Prefetching pixels in threaded shoes.
GPUs hum as crops leap to the batch,
Lists replace stacks — a nimble catch.
I thump my foot; the pipeline chews.

🚥 Pre-merge checks: ✅ 4 passed

  • Description Check: ✅ Passed (check skipped; CodeRabbit's high-level summary is enabled).
  • Title Check: ✅ Passed. The title directly addresses the primary changes: fixing image size mismatches through list-based image output, implementing per-batch error handling to prevent job aborts, and implementing batched classification for efficiency.
  • Docstring Coverage: ✅ Passed. Docstring coverage is 84.21%, which meets the required threshold of 80.00%.
  • Merge Conflict Detection: ✅ Passed. No merge conflicts detected when merging into main.


coderabbitai bot left a comment

Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
trapdata/antenna/datasets.py (1)

197-238: ⚠️ Potential issue | 🟡 Minor

Docstring still says "Stacked tensor" but images are now a list of tensors.

Line 202 reads images: Stacked tensor of valid images but the implementation on Line 234 now returns a plain list. Update the docstring to match.

Proposed fix
-        - images: Stacked tensor of valid images (only present if there are successful items)
+        - images: List of image tensors (only present if there are successful items)
🤖 Fix all issues with AI agents
In `@trapdata/antenna/worker.py`:
- Around line 241-254: The exception handler builds batch_results by iterating
zip(reply_subjects, image_ids) but doesn't use strict mode; update the zip call
in the except block to zip(reply_subjects, image_ids, strict=True) so any length
mismatch raises immediately and ensures every image_id has a corresponding
AntennaTaskResult/AntennaTaskResultError; locate the loop that constructs
AntennaTaskResult (references: reply_subjects, image_ids, AntennaTaskResult,
AntennaTaskResultError) and add strict=True to the zip invocation, then run
tests to confirm compatibility with the Python runtime.
🧹 Nitpick comments (3)
trapdata/antenna/worker.py (3)

165-169: dict(zip(...)) silently drops images if image_ids contains duplicates.

Line 169 builds image_tensors via dict(zip(image_ids, images, strict=True)). If the batch ever contains duplicate image_id values, later tensors overwrite earlier ones, silently losing data. The same issue applies to image_detections on Lines 166-168; since that dict is pre-initialized with empty lists, the effect there is that detections for the overwritten image would be posted under the surviving key's results.

If duplicate IDs within a single batch are impossible by design, a defensive assertion would make that contract explicit.

Proposed defensive check
             image_detections: dict[str, list[DetectionResponse]] = {
                 img_id: [] for img_id in image_ids
             }
+            if len(set(image_ids)) != len(image_ids):
+                logger.warning(f"Duplicate image_ids in batch: {image_ids}")
             image_tensors = dict(zip(image_ids, images, strict=True))

175-209: Detections with invalid bounding boxes are silently dropped from results.

Detections where y1 >= y2 or x1 >= x2 are skipped for classification (correct), but they are also excluded from image_detections entirely (Line 208 is only reached for valid crops). This means these detections are never posted back to Antenna — the user won't see them at all, even though the detector found something. Consider whether unclassified detections with invalid bboxes should still be included in results (without classification data) so the caller has full visibility.


269-281: Downgrading post_batch_results failure to a log is intentional but means silent data loss.

Per the PR objectives, this is deliberate to avoid interrupting subsequent batches. The log message on Line 279-281 is clear. However, consider emitting a metric or structured log field (e.g., job_id, batch_index, num_results) to make monitoring and alerting on data loss easier in production.

mihow and others added 5 commits February 9, 2026 18:51
localization_batch_size 2 → 8 and num_workers 1 → 4. The old defaults
were far too conservative for 24GB VRAM GPUs. These can still be
overridden via AMI_LOCALIZATION_BATCH_SIZE and AMI_NUM_WORKERS env vars.

Co-Authored-By: Claude Opus 4.6 <[email protected]>
run_worker() now detects torch.cuda.device_count() and uses
torch.multiprocessing.spawn to launch one worker per GPU. Each worker
pins itself to a specific GPU via set_device(). Single-GPU and CPU-only
machines keep existing single-process behavior with no overhead.

Also fixes get_device() to use current_device() instead of bare "cuda"
so that models load onto the correct GPU in spawned workers.

Co-Authored-By: Claude Opus 4.6 <[email protected]>
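
A simplified sketch of the per-GPU spawning. run_worker and _worker_loop are the names used in the PR, while run_single_worker stands in for the pre-existing single-process loop and the settings argument is illustrative:

import torch
import torch.multiprocessing as mp

def _worker_loop(gpu_id: int, settings) -> None:
    if torch.cuda.is_available():
        # Pin this process to one GPU; models created afterwards land on it
        # because get_device() now resolves torch.cuda.current_device().
        torch.cuda.set_device(gpu_id)
    run_single_worker(settings)  # hypothetical: the existing job/batch loop

def run_worker(settings) -> None:
    n_gpus = torch.cuda.device_count()
    if n_gpus > 1:
        # mp.spawn passes the process index (0..n_gpus-1) as the first argument.
        mp.spawn(_worker_loop, args=(settings,), nprocs=n_gpus, join=True)
    else:
        _worker_loop(0, settings)  # single-GPU and CPU-only machines: unchanged path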
Worker batch logs now show image count ("Processing worker batch 3 (8
images)") and model inference logs include the model name ("Preparing
FasterRCNN inference dataloader (batch_size=4, single worker mode)").

Co-Authored-By: Claude Opus 4.6 <[email protected]>
The default prefetch_factor of 2 means the DataLoader only prepares 2
batches ahead. With GPU inference taking ~2s and image downloads taking
~30s per batch, the GPU idles waiting for data. Bumping to 4 keeps the
download pipeline fuller so the next batch is more likely to be ready.

Co-Authored-By: Claude Opus 4.6 <[email protected]>
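
For context, prefetch_factor plugs into the DataLoader like this (the values are illustrative, and a later commit in this PR drops the override back to PyTorch's default of 2):

from torch.utils.data import DataLoader

def make_inference_loader(dataset, collate_fn):
    # prefetch_factor only applies when num_workers > 0.
    return DataLoader(
        dataset,
        batch_size=8,        # localization_batch_size
        num_workers=4,       # DataLoader worker processes polling the API
        prefetch_factor=4,   # batches each worker keeps prepared ahead of the GPU
        collate_fn=collate_fn,
    )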
Previously each DataLoader worker downloaded images sequentially: fetch
task metadata, then download image 1, download image 2, ... download
image N, yield all. With 32 tasks per API fetch, this meant ~30s of
serial HTTP requests before a single image reached the GPU.

Now _load_images_threaded() uses a ThreadPoolExecutor (up to 8 threads)
to download all images in a task batch concurrently. Threads are ideal
here because image downloads are I/O-bound (network latency), not
CPU-bound, and the requests Session's connection pool is thread-safe.

This stacks with the existing DataLoader num_workers parallelism:
- num_workers: N independent DataLoader processes, each polling the API
- ThreadPoolExecutor: within each process, M concurrent image downloads
- prefetch_factor: DataLoader queues future batches while GPU is busy

Expected improvement: download time for a batch of 32 images drops from
~30s (sequential) to ~4-5s (8 concurrent threads), keeping the GPU fed.

Co-Authored-By: Claude Opus 4.6 <[email protected]>
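
A sketch of the threaded download step. _load_images_threaded and the 8-thread pool come from the commit, while _load_image and the task attribute names (image_id, image_url) are assumptions:

from concurrent.futures import ThreadPoolExecutor

def _load_images_threaded(self, tasks) -> dict:
    """Download every image in a task batch concurrently (method sketch)."""
    if getattr(self, "_executor", None) is None:
        # Long-lived pool, created lazily so the dataset stays picklable.
        self._executor = ThreadPoolExecutor(max_workers=8)

    def _download(task):
        try:
            return task.image_id, self._load_image(task.image_url)
        except Exception:
            return task.image_id, None  # failures are reported per image by the caller

    # executor.map yields results in task order; build an image_id -> tensor map.
    return dict(self._executor.map(_download, tasks))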
carlosgjs (Collaborator) left a comment

Looks great, added a few comments

coderabbitai bot left a comment

Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
trapdata/antenna/datasets.py (1)

234-274: ⚠️ Potential issue | 🟡 Minor

Stale docstring: Line 239 says "Stacked tensor" but images are now a list.

The whole point of this PR is that rest_collate_fn now returns images as a list[torch.Tensor] (Line 271) to support variable-size inputs. But the docstring on Line 239 still reads "Stacked tensor of valid images."

Proposed fix
-        - images: Stacked tensor of valid images (only present if there are successful items)
+        - images: List of image tensors (variable sizes supported; only present if there are successful items)
🤖 Fix all issues with AI agents
In `@trapdata/antenna/datasets.py`:
- Around line 124-157: The docstring incorrectly claims the HTTP session is
thread-safe; fix by avoiding shared mutable Session across threads: update
_load_images_threaded to create and use a separate requests.Session (or
urllib3.PoolManager) per worker thread inside the inner _download function (or
via thread-local storage) and ensure the session is closed after use, and adjust
the docstring to remove the thread-safety claim and state the chosen approach;
relevant symbols: _load_images_threaded, _download, _load_image, and
self.image_fetch_session.
🧹 Nitpick comments (2)
trapdata/antenna/worker.py (1)

300-327: failed_items handling outside the try/except is correct but fragile.

batch_results is defined inside the try block (Line 263) or inside the except block (Line 288). If a future refactor introduces a code path between the try/except and Line 300 where batch_results could be undefined, this would produce a NameError. Consider initializing batch_results: list[AntennaTaskResult] = [] before the try to make the contract explicit.

Proposed defensive initialization
+        batch_results: list[AntennaTaskResult] = []
+
         try:
             # Validate all arrays have same length before zipping
trapdata/antenna/datasets.py (1)

312-322: Heads-up on prefetch memory: prefetch_factor=4 × num_workers=4 = 16 batches buffered.

With localization_batch_size=8, that's up to 128 full-resolution images resident in CPU memory simultaneously. For 4K images (~100 MB as a float32 tensor each), this could peak around ~12 GB of CPU RAM. If the target deployment environment has limited memory, consider making prefetch_factor a setting or reducing it to 2 (PyTorch's default).

mihow and others added 6 commits February 12, 2026 16:22
- Keep the thread pool alive across batches instead of recreating it
- Use executor.map() instead of submit + as_completed
- Fix docstring: requests.Session is not formally thread-safe

Addresses review comments from @carlosgjs.

Co-Authored-By: Claude Opus 4.6 <[email protected]>
"Stacked tensor" → "List of image tensors" to match the actual
return type after the variable-size image change.

Co-Authored-By: Claude Opus 4.6 <[email protected]>
- Initialize batch_results before the try block to prevent potential
  NameError if a future refactor introduces a path between try/except
  and the post-results code.
- Add strict=True to zip(reply_subjects, image_ids) in the except
  block so length mismatches raise immediately rather than silently
  dropping error reports.

Co-Authored-By: Claude Opus 4.6 <[email protected]>
No measured improvement over the default (2). The override just
increases memory usage without demonstrated benefit.

Co-Authored-By: Claude Opus 4.6 <[email protected]>
Add a module-level overview mapping the three concurrency layers
(GPU processes, DataLoader workers, thread pool) to their settings,
what work each layer does, and what runs under the GIL.

Includes a "not yet benchmarked" section so future contributors
know which knobs are tuned empirically vs speculatively.

Co-Authored-By: Claude Opus 4.6 <[email protected]>
- antenna_api_batch_size: 4 → 16 (fetch enough tasks per API call to
  fill at least one GPU batch without an extra round trip)
- num_workers: 4 → 2 (each worker prefetches more with larger API
  batches; fewer workers reduces CPU/RAM overhead)

Co-Authored-By: Claude Opus 4.6 <[email protected]>
coderabbitai bot left a comment

Actionable comments posted: 2

🤖 Fix all issues with AI agents
In `@trapdata/antenna/datasets.py`:
- Around line 22-23: The docstring in antenna/datasets.py incorrectly states
"Default: 4" for num_workers; align it with the actual default in
Settings.num_workers (trapdata/settings.py) which is 2 by updating the docstring
to "Default: 2" (or if you intended 4, change Settings.num_workers to 4) so the
comment Controlled by: settings.num_workers (AMI_NUM_WORKERS) matches the real
default value.
🧹 Nitpick comments (3)
trapdata/antenna/datasets.py (2)

132-133: ThreadPoolExecutor created in __init__ may be fragile with fork-based DataLoader workers.

When num_workers > 0, the DataLoader forks the process that holds this dataset. The executor is created in __init__ before forking. While this works today because threads are lazily spawned (no threads exist yet at fork time), it's a subtle invariant — if anything triggers map() or submit() on the parent before the fork, the forked executor's internal lock could be corrupted.

A safer alternative is to lazily create the executor on first use inside _load_images_threaded:

Suggested lazy initialization
-        # Reusable thread pool for concurrent image downloads
-        self._executor = ThreadPoolExecutor(max_workers=8)
+        # Created lazily in _load_images_threaded to be fork-safe
+        self._executor: ThreadPoolExecutor | None = None

     ...

     def _load_images_threaded(
         self,
         tasks: list[AntennaPipelineProcessingTask],
     ) -> dict[str, torch.Tensor | None]:
         ...
+        if self._executor is None:
+            self._executor = ThreadPoolExecutor(max_workers=8)
+
         def _download(

289-290: Redundant ternary inside an if errors: guard.

The if errors else None part is unreachable since we're already inside if errors:.

Nit fix
                    if errors:
-                        row["error"] = "; ".join(errors) if errors else None
+                        row["error"] = "; ".join(errors)
trapdata/antenna/worker.py (1)

315-327: POST request in post_batch_results lacks retry protection for transient failures.

The get_http_session() used by post_batch_results has retry logic configured, but it only applies to GET requests (allowed_methods=["GET"]). Since post_batch_results makes a POST request, transient network errors will cause permanent data loss without retry capability. Consider updating the retry strategy to include POST requests or implementing retry logic within post_batch_results itself.
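
One way to extend the retry policy to POSTs, sketched with urllib3's Retry. The actual get_http_session configuration isn't shown in this PR, so the numbers here are illustrative, and retrying POSTs assumes the results endpoint tolerates duplicate submissions:

import requests
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry

def get_http_session() -> requests.Session:
    retry = Retry(
        total=3,
        backoff_factor=1,
        status_forcelist=[429, 502, 503, 504],
        allowed_methods=["GET", "POST"],  # include POST so posting results retries too
    )
    session = requests.Session()
    session.mount("https://", HTTPAdapter(max_retries=retry))
    session.mount("http://", HTTPAdapter(max_retries=retry))
    return session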

ThreadPoolExecutor (which uses a SimpleQueue internally) and requests.Session
objects created in __init__ cannot be pickled. PyTorch DataLoader with
num_workers>0 uses spawn, which pickles the dataset to send to worker
subprocesses.

Move session and executor creation to _ensure_sessions(), called on
first use in each worker process. Add regression test that creates a
DataLoader with num_workers=2 and verifies workers can start.

Co-Authored-By: Claude Opus 4.6 <[email protected]>
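
The lazy-initialization pattern looks roughly like this; RESTDataset, _ensure_sessions, and the attribute names match the PR discussion, but the constructor arguments are assumptions:

from concurrent.futures import ThreadPoolExecutor
import requests

class RESTDataset:
    def __init__(self, base_url: str):
        self.base_url = base_url
        # Nothing unpicklable is created here: the DataLoader can pickle the
        # dataset into its worker subprocesses before any session exists.
        self._image_fetch_session: requests.Session | None = None
        self._executor: ThreadPoolExecutor | None = None

    def _ensure_sessions(self) -> None:
        # Called on first use inside each worker process.
        if self._image_fetch_session is None:
            self._image_fetch_session = requests.Session()
        if self._executor is None:
            self._executor = ThreadPoolExecutor(max_workers=8)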
coderabbitai bot left a comment

🧹 Nitpick comments (2)
trapdata/antenna/datasets.py (2)

195-199: _ensure_sessions() called from worker threads is racy (but practically safe).

_load_image is invoked by threads inside _load_images_threaded. If the defensive _ensure_sessions() on line 197 is reached before the main-thread call on line 241 completes (shouldn't happen with current control flow), two threads could both see None and create duplicate sessions. This is benign today because _ensure_sessions is always called on line 241 before executor.map starts, but it's fragile — a future caller invoking _load_image directly from a thread pool would hit the race.

Consider removing the _ensure_sessions() call here (relying on the caller contract) or adding a threading.Lock guard.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@trapdata/antenna/datasets.py` around lines 195 - 199, The call to
_ensure_sessions() from worker threads is racy; add a threading.Lock to
serialize session creation: create self._sessions_lock (e.g., in the class
__init__), and wrap the session-initialization logic inside _ensure_sessions()
with a with self._sessions_lock: (also protect reads/writes of
self._image_fetch_session there) so concurrent _load_image calls cannot create
duplicate sessions; alternatively, remove the _ensure_sessions() call from
_load_image and rely on the caller contract, but prefer the lock approach to
make _load_image/_load_images_threaded robust.

308-312: Redundant ternary — errors is always truthy here.

Line 310 is inside if errors:, so the if errors else None part of the ternary on line 311 can never be None.

Suggested simplification
                 if errors:
-                    row["error"] = "; ".join(errors) if errors else None
+                    row["error"] = "; ".join(errors)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@trapdata/antenna/datasets.py` around lines 308 - 312, The code sets
row["error"] inside an if errors: block but uses a redundant ternary (if errors
else None); in the block where errors is truthy, replace the ternary with a
direct assignment so row["error"] = "; ".join(errors) (remove the "if errors
else None" branch) in the same scope around the code that builds row (the
variables row and errors and task.image_url are in this context).

Matches the actual default changed in ed49153.

Co-Authored-By: Claude Opus 4.6 <[email protected]>
mihow merged commit 30e22a6 into main on Feb 17, 2026
4 checks passed