ObjectDetection/InstanceSegmentationTask: fix support for non-RGB images #2752
Conversation
Pull Request Overview
This PR fixes multispectral support in the ObjectDetectionTask by overriding the default transform parameters in the detector models. The changes update the initialization of FasterRCNN, FCOS, and RetinaNet with custom parameters (min_size, max_size, image_mean, and image_std) that enable multispectral inputs, and a new test is added to validate this functionality.
- Updated transform parameters for multispectral support in three detection model constructors.
- Added a new test in tests/trainers/test_detection.py to check multispectral behavior.
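To make the override concrete, here is a minimal sketch (not the torchgeo source itself) of how a torchvision Faster R-CNN can be built for multispectral input; the 4-channel stem swap and the specific parameter values are illustrative assumptions:

```python
# Sketch: a torchvision Faster R-CNN accepting 4-band imagery by overriding
# the built-in GeneralizedRCNNTransform parameters. Values are illustrative.
import torch
from torch import nn
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
from torchvision.models.detection.faster_rcnn import FasterRCNN

in_channels = 4  # e.g. RGB + NIR
backbone = resnet_fpn_backbone(
    backbone_name="resnet18", weights=None, trainable_layers=3
)
# Replace the stem so the backbone accepts in_channels bands instead of 3.
backbone.body.conv1 = nn.Conv2d(
    in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False
)

model = FasterRCNN(
    backbone,
    num_classes=11,
    # Override the default ImageNet RGB stats with per-band no-op stats.
    image_mean=[0.0] * in_channels,
    image_std=[1.0] * in_channels,
    min_size=800,
    max_size=800,
)

model.eval()
with torch.no_grad():
    predictions = model([torch.rand(in_channels, 224, 224)])
```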
Reviewed Changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated no comments.
File | Description
---|---
torchgeo/trainers/detection.py | Updated model constructors to override transform parameters for multispectral data.
tests/trainers/test_detection.py | Added a test case to validate multispectral support with a non-RGB input channel.
Can you also check instance segmentation?
FYI, I confirmed no issues with OD on my 4-channel dataset.

OK, whilst the error is resolved, the loss for the OD models I train is always zero. Are there validated results I can reproduce? Note this might just be my datasets, which I recently updated for the new format.
This by default was resizing all imagery to a min of 800. What transforms are you using to preprocess your imagery?
@adamjstewart fixed the instance segmentation task. It had the same issue. |
@isaaccorley can you elaborate on this, given that typically chip_size = 224?

Torchvision Faster-RCNN and MaskRCNN have a `GeneralizedRCNNTransform` baked in. One trick that works well for object detection in remote sensing is to simply resize your small patches to be larger. This may be why you're getting poor performance.
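For example, a user-side preprocessing step along these lines (an illustrative sketch, not torchgeo API; Kornia scales the bounding boxes together with the images):

```python
# Sketch: upsample small chips (e.g. 224x224) to a larger size before
# detection, letting Kornia resize the boxes consistently with the images.
import torch
import kornia.augmentation as K

resize = K.AugmentationSequential(
    K.Resize((800, 800)),  # upscale 224x224 patches to 800x800
    data_keys=["input", "bbox_xyxy"],
)

images = torch.rand(2, 4, 224, 224)                     # 4-band chips
boxes = torch.tensor([[[10.0, 20.0, 60.0, 90.0]]] * 2)  # (B, N, 4) xyxy boxes
out_images, out_boxes = resize(images, boxes)
print(out_images.shape)  # torch.Size([2, 4, 800, 800])
```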
@isaaccorley good to know! Perhaps we should document this?
This PR basically removes this transform, so a user can decide which Kornia normalize and resize transforms they want to apply themselves.
I've ruled out issues with my dataset, and the remaining differences I see between my legacy implementation and this implementation are details such as the anchor sizes I've utilised. I suggest we merge this approach and then, as a follow-up (and pending a suitable test dataset), work on further optimisations in another PR.
LGTM, thanks for the fix!
Except for the tests...
(Force-pushed from 7c10564 to e4c6832.)
As discussed in Slack with @isaaccorley, better results are achieved with the torchvision defaults for norm, together with min_size=800, max_size=800.

Noting allenai/rslearn#171, where they choose 800 for both, which makes sense as we usually deal with square images in RS.
Some benchmarking based on the defaults using VHR10 for 25 epochs:

```python
task = ObjectDetectionTask(
    model="faster-rcnn",
    backbone="resnet18",
    weights=True,
    in_channels=3,
    num_classes=11,
    trainable_layers=3,
    lr=1e-3,
    patience=10,
    freeze_backbone=False,
)
```

Results for each configuration (metric plots omitted here):

- This PR
- Equal sizes - as in rslearn, but actually skipped
- Torchvision defaults

Ideally we just don't transform of course - will check if we can use the …
(Force-pushed from c196441 to bd5d31d.)
```python
min_size=800,
max_size=800,
```
Is this a hard constraint? It would be nice to be as flexible as possible here.
I actually think it might be better to pass this in as an arg, image_size or similar. However as mentioned it would be preferable if we could avoid this processing altogether.
I'll make it an input arg
Can't we make the range huge and let the datamodules control this?
Do all the datamodules support image_size as an input arg? I have it as one for mine and always perform a resize - definitely want control over this
Not consistently, no, but we could. We definitely could for all the object detection/instance segmentation ones, there aren't that many.
Let me try updating the datamodules with controllable sizes and see if I can get the models to converge. Ideally we could set this range to be a no-op, which is why I originally set it to (1, 4096).
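As a sketch of the "pass it in as an arg" idea discussed in this thread (a hypothetical signature, not the current torchgeo API), the resize bounds could be surfaced on the constructor:

```python
# Hypothetical sketch: expose the detector's resize target as an argument
# instead of hard-coding 800. `image_size` and this helper are illustrative.
from torchvision.models.detection import fasterrcnn_resnet50_fpn

def build_detector(num_classes: int, image_size: int = 800):
    # With square inputs, equal bounds resize to exactly image_size x image_size.
    return fasterrcnn_resnet50_fpn(
        weights=None,
        num_classes=num_classes,
        min_size=image_size,
        max_size=image_size,
    )

model = build_detector(num_classes=11, image_size=1024)
```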
One of the wonderful surprises of torchvision's detector models is that a `GeneralizedRCNNTransform` gets added under the hood, which defaults to ImageNet RGB mean/std normalization plus dynamic resizing in the range of (800, 1333). This PR fixes this by loading pretrained weights but overriding this transform to simply subtract 0 and divide by 1, which is a no-op, and changes the dynamic resize to allow for a min/max input shape in the range of (1, 4096).
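In torchvision terms, the override described above can be sketched as follows (illustrative, not a verbatim excerpt from this PR):

```python
# Sketch: replace a detector's default transform with a no-op normalization
# (subtract 0, divide by 1) and a wide (1, 4096) resize range.
from torchvision.models.detection import fasterrcnn_resnet50_fpn
from torchvision.models.detection.transform import GeneralizedRCNNTransform

model = fasterrcnn_resnet50_fpn(weights="DEFAULT")  # keep pretrained weights
in_channels = 3  # these COCO weights expect 3 bands; shown for shape only
model.transform = GeneralizedRCNNTransform(
    min_size=1,
    max_size=4096,
    image_mean=[0.0] * in_channels,  # subtract 0 -> no mean shift
    image_std=[1.0] * in_channels,   # divide by 1 -> no scaling
)
```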
Alternatives considered:
I attempted to simply replace `model.transform` with `nn.Identity()`, but this doesn't work because the detection models pass multiple args to the transform, which throws an error.

Fixes #2749