fix: worker image size mismatches, per-batch error handling, batched classification#110
Conversation
torch.stack requires all tensors to be the same size, which crashes when a batch contains images of different resolutions (e.g. 3420x6080 and 2160x4096). FasterRCNN natively accepts a list of variable-sized images, and predict_batch already handles this code path. Co-Authored-By: Claude Opus 4.6 <[email protected]>
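Below is a minimal sketch of the collate approach this commit describes: keep per-image tensors in a list rather than stacking them, since FasterRCNN-style detectors accept `list[Tensor]` with varying (C, H, W) shapes. The function name and dict keys here are illustrative, not the repository's actual `rest_collate_fn`.

```python
import torch

def collate_variable_size(batch: list[dict]) -> dict:
    """Keep images as a list instead of torch.stack-ing them."""
    # torch.stack(images) would raise if shapes differ,
    # e.g. (3, 6080, 3420) vs (3, 4096, 2160).
    return {
        "image_ids": [item["image_id"] for item in batch],
        "images": [item["image"] for item in batch],  # list of (C, H, W) tensors
    }
```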
Wrap the batch processing loop body in try/except so a single failed batch doesn't kill the entire job. On failure, error results are posted back to Antenna for each image in the batch so tasks don't get stuck in the queue. Also downgrade post_batch_results failure from a raised exception to a logged error to avoid losing progress on subsequent batches. Co-Authored-By: Claude Opus 4.6 <[email protected]>
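A rough sketch of the per-batch error handling pattern described above; the helper names (`process_batch`, `build_error_result`, `post_batch_results`) are placeholders standing in for the worker's actual functions.

```python
import logging

logger = logging.getLogger(__name__)

def run_batches(batches, process_batch, build_error_result, post_batch_results):
    """One failed batch posts error results and the loop keeps going."""
    for i, batch in enumerate(batches):
        try:
            results = process_batch(batch)
        except Exception:
            logger.exception(f"Batch {i} failed; posting error results for its images")
            # One error result per image so queued tasks are released, not stuck.
            results = [build_error_result(image_id) for image_id in batch["image_ids"]]
        try:
            post_batch_results(results)
        except Exception:
            # Logged instead of raised so later batches are not lost.
            logger.error(f"Failed to post results for batch {i}; continuing")
```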
…GPU calls: Collect all detection crops, apply classifier transforms (which resize to a uniform input_size), then run a single batched predict_batch call. Skips detections with invalid bounding boxes (y1 >= y2 or x1 >= x2). This supersedes PR #105. Co-Authored-By: Claude Opus 4.6 <[email protected]>
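A simplified sketch of the batched-classification flow the commit describes. The `classifier.get_transforms()` and `predict_batch` names come from the PR text; the cropping, data types, and result merging below are assumptions for illustration only (the real worker converts crops to PIL before transforming).

```python
import torch

def classify_detections_batched(image: torch.Tensor, detections: list[dict], classifier) -> list[dict]:
    """Crop all valid detections, then classify them in one batched call."""
    transform = classifier.get_transforms()  # resizes crops to a uniform input_size
    crops, kept = [], []
    for det in detections:
        x1, y1, x2, y2 = det["bbox"]
        if y1 >= y2 or x1 >= x2:
            continue  # skip degenerate boxes instead of crashing
        crops.append(transform(image[:, int(y1):int(y2), int(x1):int(x2)]))
        kept.append(det)
    if crops:
        scores = classifier.predict_batch(torch.stack(crops))  # one GPU call for N crops
        for det, score in zip(kept, scores, strict=True):
            det["classification"] = score
    return detections
```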
📝 Walkthrough: Implements lazy, threaded image prefetching in RESTDataset (per-worker ThreadPoolExecutor and HTTP sessions), changes collate to return lists of per-image tensors, adds GPU-aware multiprocessing with per-GPU worker loops and batched detection→classification, and updates device selection and defaults.
Sequence Diagram

```mermaid
sequenceDiagram
    participant JobQueue as Job Queue
    participant DataLoader as DataLoader subprocess
    participant Worker as Worker (per-GPU)
    participant Detector as Detector Model
    participant Classifier as Classifier Model
    participant API as Antenna API
    JobQueue->>Worker: Deliver job / batch
    Worker->>DataLoader: Request images (batch)
    DataLoader->>DataLoader: Threaded download -> map image_id: tensor
    DataLoader-->>Worker: Return batch image tensors
    Worker->>Detector: Run detection on images
    Detector-->>Worker: Detections (bboxes)
    rect rgba(100,150,255,0.5)
        Worker->>Worker: Validate bboxes, extract crops, convert to PIL, apply transforms
        Worker->>Classifier: Send stacked batch of crops
        Classifier-->>Worker: Return classification scores
        Worker->>Worker: Merge scores into detections
    end
    alt Batch succeeds
        Worker->>API: Post PipelineResultsResponse(s)
        API-->>Worker: Ack
    else Batch fails
        Worker->>API: Post AntennaTaskResultError entries
        API-->>Worker: Ack
    end
```
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
trapdata/antenna/datasets.py (1)
197-238: ⚠️ Potential issue | 🟡 Minor: Docstring still says "Stacked tensor" but images are now a list of tensors.
Line 202 reads `images: Stacked tensor of valid images`, but the implementation on Line 234 now returns a plain list. Update the docstring to match.
Proposed fix:

```diff
- - images: Stacked tensor of valid images (only present if there are successful items)
+ - images: List of image tensors (only present if there are successful items)
```
🤖 Fix all issues with AI agents
In `@trapdata/antenna/worker.py`:
- Around line 241-254: The exception handler builds batch_results by iterating
zip(reply_subjects, image_ids) but doesn't use strict mode; update the zip call
in the except block to zip(reply_subjects, image_ids, strict=True) so any length
mismatch raises immediately and ensures every image_id has a corresponding
AntennaTaskResult/AntennaTaskResultError; locate the loop that constructs
AntennaTaskResult (references: reply_subjects, image_ids, AntennaTaskResult,
AntennaTaskResultError) and add strict=True to the zip invocation, then run
tests to confirm compatibility with the Python runtime.
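For context, the snippet below just demonstrates the `strict=True` behavior the prompt asks for; it is not code from the worker.

```python
reply_subjects = ["reply.a", "reply.b", "reply.c"]
image_ids = ["img-1", "img-2"]  # one subject has no matching image_id

# Plain zip silently stops at the shorter input, so "reply.c" would never
# get an error result posted back.
print(list(zip(reply_subjects, image_ids)))  # [('reply.a', 'img-1'), ('reply.b', 'img-2')]

# strict=True (Python 3.10+) turns the mismatch into an immediate error.
try:
    list(zip(reply_subjects, image_ids, strict=True))
except ValueError as exc:
    print(f"length mismatch detected: {exc}")
```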
🧹 Nitpick comments (3)
trapdata/antenna/worker.py (3)
165-169: `dict(zip(...))` silently drops images if `image_ids` contains duplicates.
Line 169 builds `image_tensors` via `dict(zip(image_ids, images, strict=True))`. If the batch ever contains duplicate `image_id` values, later tensors overwrite earlier ones, silently losing data. The same issue applies to `image_detections` on Lines 166-168, though that dict is pre-initialized with empty lists, so the effect there is that detections for the overwritten image would be posted under the surviving key's results. If duplicate IDs within a single batch are impossible by design, a defensive assertion would make that contract explicit.
Proposed defensive check:

```diff
 image_detections: dict[str, list[DetectionResponse]] = {
     img_id: [] for img_id in image_ids
 }
+if len(set(image_ids)) != len(image_ids):
+    logger.warning(f"Duplicate image_ids in batch: {image_ids}")
 image_tensors = dict(zip(image_ids, images, strict=True))
```
175-209: Detections with invalid bounding boxes are silently dropped from results.
Detections where `y1 >= y2 or x1 >= x2` are skipped for classification (correct), but they are also excluded from `image_detections` entirely (Line 208 is only reached for valid crops). This means these detections are never posted back to Antenna; the user won't see them at all, even though the detector found something. Consider whether unclassified detections with invalid bboxes should still be included in results (without classification data) so the caller has full visibility.
269-281: Downgrading `post_batch_results` failure to a log is intentional but means silent data loss.
Per the PR objectives, this is deliberate to avoid interrupting subsequent batches. The log message on Lines 279-281 is clear. However, consider emitting a metric or structured log field (e.g., `job_id`, `batch_index`, `num_results`) to make monitoring and alerting on data loss easier in production.
localization_batch_size 2 → 8 and num_workers 1 → 4. The old defaults were far too conservative for 24GB VRAM GPUs. These can still be overridden via AMI_LOCALIZATION_BATCH_SIZE and AMI_NUM_WORKERS env vars. Co-Authored-By: Claude Opus 4.6 <[email protected]>
run_worker() now detects torch.cuda.device_count() and uses torch.multiprocessing.spawn to launch one worker per GPU. Each worker pins itself to a specific GPU via set_device(). Single-GPU and CPU-only machines keep existing single-process behavior with no overhead. Also fixes get_device() to use current_device() instead of bare "cuda" so that models load onto the correct GPU in spawned workers. Co-Authored-By: Claude Opus 4.6 <[email protected]>
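A condensed sketch of the per-GPU spawn pattern this commit describes. `run_worker` matches the name in the commit message, but the body below is illustrative; the work-loop contents and the CPU-only fallback are omitted.

```python
import torch
import torch.multiprocessing as mp

def _worker_loop(gpu_index: int) -> None:
    """Each spawned process pins itself to one GPU before loading models."""
    torch.cuda.set_device(gpu_index)
    device = torch.device("cuda", torch.cuda.current_device())
    # ... load detector/classifier onto `device` and process batches ...

def run_worker() -> None:
    n_gpus = torch.cuda.device_count()
    if n_gpus > 1:
        # mp.spawn calls _worker_loop(i) once per process, i = 0..n_gpus-1.
        mp.spawn(_worker_loop, nprocs=n_gpus, join=True)
    elif n_gpus == 1:
        _worker_loop(0)
    else:
        ...  # CPU-only: keep the existing single-process behavior (not shown)
```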
Worker batch logs now show image count ("Processing worker batch 3 (8
images)") and model inference logs include the model name ("Preparing
FasterRCNN inference dataloader (batch_size=4, single worker mode)").
Co-Authored-By: Claude Opus 4.6 <[email protected]>
The default prefetch_factor of 2 means the DataLoader only prepares 2 batches ahead. With GPU inference taking ~2s and image downloads taking ~30s per batch, the GPU idles waiting for data. Bumping to 4 keeps the download pipeline fuller so the next batch is more likely to be ready. Co-Authored-By: Claude Opus 4.6 <[email protected]>
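For reference, this is roughly how `prefetch_factor` and `num_workers` combine in a PyTorch DataLoader; the toy dataset and values below are placeholders, not the project's RESTDataset.

```python
import torch
from torch.utils.data import DataLoader, Dataset

class ToyImageDataset(Dataset):
    """Stand-in for the real REST-backed dataset."""
    def __len__(self) -> int:
        return 64
    def __getitem__(self, idx: int) -> dict:
        return {"image_id": str(idx), "image": torch.zeros(3, 64, 64)}

def list_collate(batch: list[dict]) -> dict:
    return {"image_ids": [b["image_id"] for b in batch],
            "images": [b["image"] for b in batch]}

# Each worker keeps prefetch_factor batches ready, so with 4 workers and
# prefetch_factor=4 up to 16 batches of downloads can overlap GPU inference.
loader = DataLoader(
    ToyImageDataset(),
    batch_size=8,
    num_workers=4,
    prefetch_factor=4,  # only valid when num_workers > 0
    collate_fn=list_collate,
)
```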
Previously each DataLoader worker downloaded images sequentially: fetch task metadata, then download image 1, download image 2, ... download image N, then yield all. With 32 tasks per API fetch, this meant ~30s of serial HTTP requests before a single image reached the GPU. Now _load_images_threaded() uses a ThreadPoolExecutor (up to 8 threads) to download all images in a task batch concurrently. Threads are ideal here because image downloads are I/O-bound (network latency), not CPU-bound, and the requests Session's connection pool is thread-safe. This stacks with the existing DataLoader num_workers parallelism:

- num_workers: N independent DataLoader processes, each polling the API
- ThreadPoolExecutor: within each process, M concurrent image downloads
- prefetch_factor: DataLoader queues future batches while GPU is busy

Expected improvement: download time for a batch of 32 images drops from ~30s (sequential) to ~4-5s (8 concurrent threads), keeping the GPU fed. Co-Authored-By: Claude Opus 4.6 <[email protected]>
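A self-contained sketch of the concurrent-download idea (a later review comment questions sharing one Session across threads; see the per-thread variant further down). Function names and error handling here are illustrative, not the actual `_load_images_threaded`.

```python
from concurrent.futures import ThreadPoolExecutor

import requests

def download_batch(urls: list[str], max_workers: int = 8) -> dict[str, bytes | None]:
    """Download a task batch concurrently; the work is I/O-bound, so threads
    overlap network latency even under the GIL."""
    session = requests.Session()  # connection pool reused across downloads

    def _download(url: str) -> bytes | None:
        try:
            resp = session.get(url, timeout=30)
            resp.raise_for_status()
            return resp.content
        except requests.RequestException:
            return None  # caller records an error result for this image

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # executor.map preserves input order, so results zip back onto urls.
        return dict(zip(urls, executor.map(_download, urls), strict=True))
```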
carlosgjs left a comment:
Looks great, added a few comments
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
trapdata/antenna/datasets.py (1)
234-274: ⚠️ Potential issue | 🟡 Minor: Stale docstring: Line 239 says "Stacked tensor" but images are now a list.
The whole point of this PR is that `rest_collate_fn` now returns images as a `list[torch.Tensor]` (Line 271) to support variable-size inputs, but the docstring on Line 239 still reads "Stacked tensor of valid images."
Proposed fix:

```diff
- - images: Stacked tensor of valid images (only present if there are successful items)
+ - images: List of image tensors (variable sizes supported; only present if there are successful items)
```
🤖 Fix all issues with AI agents
In `@trapdata/antenna/datasets.py`:
- Around line 124-157: The docstring incorrectly claims the HTTP session is
thread-safe; fix by avoiding shared mutable Session across threads: update
_load_images_threaded to create and use a separate requests.Session (or
urllib3.PoolManager) per worker thread inside the inner _download function (or
via thread-local storage) and ensure the session is closed after use, and adjust
the docstring to remove the thread-safety claim and state the chosen approach;
relevant symbols: _load_images_threaded, _download, _load_image, and
self.image_fetch_session.
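One way to implement the per-thread Session the prompt describes, via thread-local storage; this is a generic sketch, not the project's code.

```python
import threading

import requests

_local = threading.local()

def _get_session() -> requests.Session:
    """Return a Session owned by the calling thread, created on first use."""
    session = getattr(_local, "session", None)
    if session is None:
        session = requests.Session()
        _local.session = session
    return session

def download(url: str) -> bytes:
    # Each thread in the pool gets its own Session, so no connection-pool
    # state is shared across threads.
    resp = _get_session().get(url, timeout=30)
    resp.raise_for_status()
    return resp.content
```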
🧹 Nitpick comments (2)
trapdata/antenna/worker.py (1)
300-327: `failed_items` handling outside the `try/except` is correct but fragile.
`batch_results` is defined inside the `try` block (Line 263) or inside the `except` block (Line 288). If a future refactor introduces a code path between the `try/except` and Line 300 where `batch_results` could be undefined, this would produce a `NameError`. Consider initializing `batch_results: list[AntennaTaskResult] = []` before the `try` to make the contract explicit.
Proposed defensive initialization:

```diff
+    batch_results: list[AntennaTaskResult] = []
+
     try:
         # Validate all arrays have same length before zipping
```

trapdata/antenna/datasets.py (1)
312-322: Heads-up on prefetch memory: `prefetch_factor=4` × `num_workers=4` = 16 batches buffered.
With `localization_batch_size=8`, that's up to 128 full-resolution images resident in CPU memory simultaneously. For 4K images (~100 MB as a float32 tensor each), this could peak around ~12 GB of CPU RAM. If the target deployment environment has limited memory, consider making `prefetch_factor` a setting or reducing it to 2 (PyTorch's default).
- Keep the thread pool alive across batches instead of recreating it
- Use executor.map() instead of submit + as_completed
- Fix docstring: requests.Session is not formally thread-safe

Addresses review comments from @carlosgjs. Co-Authored-By: Claude Opus 4.6 <[email protected]>
"Stacked tensor" → "List of image tensors" to match the actual return type after the variable-size image change. Co-Authored-By: Claude Opus 4.6 <[email protected]>
- Initialize batch_results before the try block to prevent a potential NameError if a future refactor introduces a path between try/except and the post-results code.
- Add strict=True to zip(reply_subjects, image_ids) in the except block so length mismatches raise immediately rather than silently dropping error reports.

Co-Authored-By: Claude Opus 4.6 <[email protected]>
No measured improvement over the default (2). The override just increases memory usage without demonstrated benefit. Co-Authored-By: Claude Opus 4.6 <[email protected]>
Add a module-level overview mapping the three concurrency layers (GPU processes, DataLoader workers, thread pool) to their settings, what work each layer does, and what runs under the GIL. Includes a "not yet benchmarked" section so future contributors know which knobs are tuned empirically vs speculatively. Co-Authored-By: Claude Opus 4.6 <[email protected]>
- antenna_api_batch_size: 4 → 16 (fetch enough tasks per API call to fill at least one GPU batch without an extra round trip)
- num_workers: 4 → 2 (each worker prefetches more with larger API batches; fewer workers reduces CPU/RAM overhead)

Co-Authored-By: Claude Opus 4.6 <[email protected]>
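For illustration, defaults like these are often expressed as an env-overridable settings class; the sketch below assumes pydantic-settings and mirrors the AMI_ env prefix mentioned elsewhere in the PR, but it is not the repository's actual Settings class.

```python
from pydantic_settings import BaseSettings, SettingsConfigDict

class WorkerSettings(BaseSettings):
    """Hypothetical settings; each field is overridable via an AMI_-prefixed env var."""
    model_config = SettingsConfigDict(env_prefix="AMI_")

    localization_batch_size: int = 8   # images per detector batch
    num_workers: int = 2               # DataLoader worker processes
    antenna_api_batch_size: int = 16   # tasks fetched per Antenna API call

# e.g. AMI_NUM_WORKERS=4 AMI_LOCALIZATION_BATCH_SIZE=16 override these at runtime.
settings = WorkerSettings()
```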
Actionable comments posted: 2
🤖 Fix all issues with AI agents
In `@trapdata/antenna/datasets.py`:
- Around line 22-23: The docstring in antenna/datasets.py incorrectly states
"Default: 4" for num_workers; align it with the actual default in
Settings.num_workers (trapdata/settings.py) which is 2 by updating the docstring
to "Default: 2" (or if you intended 4, change Settings.num_workers to 4) so the
comment Controlled by: settings.num_workers (AMI_NUM_WORKERS) matches the real
default value.
🧹 Nitpick comments (3)
trapdata/antenna/datasets.py (2)
132-133: ThreadPoolExecutor created in `__init__` may be fragile with fork-based DataLoader workers.
When `num_workers > 0`, the DataLoader forks the process that holds this dataset. The executor is created in `__init__` before forking. While this works today because threads are lazily spawned (no threads exist yet at fork time), it's a subtle invariant: if anything triggers `map()` or `submit()` on the parent before the fork, the forked executor's internal lock could be corrupted. A safer alternative is to lazily create the executor on first use inside `_load_images_threaded`.
Suggested lazy initialization:

```diff
-        # Reusable thread pool for concurrent image downloads
-        self._executor = ThreadPoolExecutor(max_workers=8)
+        # Created lazily in _load_images_threaded to be fork-safe
+        self._executor: ThreadPoolExecutor | None = None
 ...
     def _load_images_threaded(
         self,
         tasks: list[AntennaPipelineProcessingTask],
     ) -> dict[str, torch.Tensor | None]:
         ...
+        if self._executor is None:
+            self._executor = ThreadPoolExecutor(max_workers=8)
+
         def _download(
```
289-290: Redundant ternary inside an `if errors:` guard.
The `if errors else None` part is unreachable since we're already inside `if errors:`.
Nit fix:

```diff
 if errors:
-    row["error"] = "; ".join(errors) if errors else None
+    row["error"] = "; ".join(errors)
```

trapdata/antenna/worker.py (1)
315-327: POST request in `post_batch_results` lacks retry protection for transient failures.
The `get_http_session()` used by `post_batch_results` has retry logic configured, but it only applies to GET requests (`allowed_methods=["GET"]`). Since `post_batch_results` makes a POST request, transient network errors will cause permanent data loss without retry capability. Consider updating the retry strategy to include POST requests or implementing retry logic within `post_batch_results` itself.
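A sketch of one way to extend the retry policy to POST, as the comment suggests; the function name is hypothetical and the retry parameters are illustrative.

```python
import requests
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry

def make_session_with_post_retries() -> requests.Session:
    """Session whose adapter retries transient failures for GET and POST."""
    retry = Retry(
        total=5,
        backoff_factor=1.0,                       # 1s, 2s, 4s, ... between attempts
        status_forcelist=[429, 500, 502, 503, 504],
        allowed_methods=["GET", "POST"],          # urllib3's default excludes POST
    )
    session = requests.Session()
    adapter = HTTPAdapter(max_retries=retry)
    session.mount("https://", adapter)
    session.mount("http://", adapter)
    return session
```

Note that retrying POST is only safe if the endpoint tolerates duplicate submissions; otherwise retry logic inside post_batch_results with idempotency keys may be preferable.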
ThreadPoolExecutor (it contains a SimpleQueue) and requests.Session objects created in __init__ cannot be pickled. PyTorch DataLoader with num_workers>0 uses spawn, which pickles the dataset to send to worker subprocesses. Move session and executor creation to _ensure_sessions(), called on first use in each worker process. Add a regression test that creates a DataLoader with num_workers=2 and verifies workers can start. Co-Authored-By: Claude Opus 4.6 <[email protected]>
🧹 Nitpick comments (2)
trapdata/antenna/datasets.py (2)
195-199: `_ensure_sessions()` called from worker threads is racy (but practically safe).
`_load_image` is invoked by threads inside `_load_images_threaded`. If the defensive `_ensure_sessions()` on line 197 is reached before the main-thread call on line 241 completes (shouldn't happen with current control flow), two threads could both see `None` and create duplicate sessions. This is benign today because `_ensure_sessions` is always called on line 241 before `executor.map` starts, but it's fragile: a future caller invoking `_load_image` directly from a thread pool would hit the race. Consider removing the `_ensure_sessions()` call here (relying on the caller contract) or adding a `threading.Lock` guard.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@trapdata/antenna/datasets.py` around lines 195-199: The call to _ensure_sessions() from worker threads is racy; add a threading.Lock to serialize session creation: create self._sessions_lock (e.g., in the class __init__), and wrap the session-initialization logic inside _ensure_sessions() with a with self._sessions_lock: block (also protect reads/writes of self._image_fetch_session there) so concurrent _load_image calls cannot create duplicate sessions; alternatively, remove the _ensure_sessions() call from _load_image and rely on the caller contract, but prefer the lock approach to make _load_image/_load_images_threaded robust.
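A minimal sketch of the lock-guarded lazy initialization the prompt describes; class and attribute names mirror the review comment, but the body is illustrative.

```python
import threading
from concurrent.futures import ThreadPoolExecutor

import requests

class LazySessionsMixin:
    """Create the HTTP session and thread pool lazily, guarded by a lock,
    so concurrent first calls from worker threads cannot create duplicates."""

    def __init__(self) -> None:
        self._sessions_lock = threading.Lock()
        self._image_fetch_session: requests.Session | None = None
        self._executor: ThreadPoolExecutor | None = None

    def _ensure_sessions(self) -> None:
        with self._sessions_lock:
            if self._image_fetch_session is None:
                self._image_fetch_session = requests.Session()
            if self._executor is None:
                self._executor = ThreadPoolExecutor(max_workers=8)
```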
308-312: Redundant ternary; `errors` is always truthy here.
Line 310 is inside `if errors:`, so the `if errors else None` part of the ternary on line 311 can never be `None`.
Suggested simplification:

```diff
 if errors:
-    row["error"] = "; ".join(errors) if errors else None
+    row["error"] = "; ".join(errors)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@trapdata/antenna/datasets.py` around lines 308-312: The code sets row["error"] inside an if errors: block but uses a redundant ternary (if errors else None); in the block where errors is truthy, replace the ternary with a direct assignment so row["error"] = "; ".join(errors) (remove the "if errors else None" branch) in the same scope around the code that builds row (the variables row and errors and task.image_url are in this context).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Duplicate comments:
In `@trapdata/antenna/datasets.py`:
- Line 23: The docstring in antenna.datasets that currently states "Default: 4"
is inconsistent with the actual default used (Settings.num_workers == 2 and the
later comment at line 40); update the docstring to state "Default: 2" (or
reference Settings.num_workers) so it matches the implemented default, and
verify any mention of the default in the function/class docstring (e.g., the
Antenna dequeues/num_workers docstring) is consistent with Settings.num_workers
in settings.py.
---
Nitpick comments:
In `@trapdata/antenna/datasets.py`:
- Around line 195-199: The call to _ensure_sessions() from worker threads is
racy; add a threading.Lock to serialize session creation: create
self._sessions_lock (e.g., in the class __init__), and wrap the
session-initialization logic inside _ensure_sessions() with a with
self._sessions_lock: (also protect reads/writes of self._image_fetch_session
there) so concurrent _load_image calls cannot create duplicate sessions;
alternatively, remove the _ensure_sessions() call from _load_image and rely on
the caller contract, but prefer the lock approach to make
_load_image/_load_images_threaded robust.
- Around line 308-312: The code sets row["error"] inside an if errors: block but
uses a redundant ternary (if errors else None); in the block where errors is
truthy, replace the ternary with a direct assignment so row["error"] = ";
".join(errors) (remove the "if errors else None" branch) in the same scope
around the code that builds row (the variables row and errors and task.image_url
are in this context).
Matches the actual default changed in ed49153. Co-Authored-By: Claude Opus 4.6 <[email protected]>
Summary
Variable-size image crash: `rest_collate_fn` used `torch.stack`, which requires uniform tensor sizes. When a batch contained images of different resolutions (e.g. 3420x6080 and 2160x4096), the entire job crashed and restarted in a loop. Images are now passed as a list of tensors, which FasterRCNN natively accepts.

Per-batch error handling: A single failed batch previously killed the entire `_process_job` call, resetting the counter and retrying from scratch. Batch processing is now wrapped in try/except: errors are reported back to Antenna for the affected images, and the loop continues to the next batch. `post_batch_results` failure is also downgraded from a raised exception to a logged error.

Batched classification: Detection crops were classified one at a time (N GPU calls for N detections). Crops are now collected, transformed to uniform size via `classifier.get_transforms()`, stacked, and classified in a single `predict_batch` call. Invalid bounding boxes (y1 >= y2 or x1 >= x2) are skipped. Supersedes "Batch classification predictions in worker for GPU efficiency" (#105).

Test plan
- ami worker on images of varying resolutions

🤖 Generated with Claude Code