Skip to content

ObjectDetection/InstanceSegmentationTask: fix support for non-RGB images #2752

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions tests/trainers/test_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Any

import pytest
import torch
from lightning.pytorch import Trainer
from pytest import MonkeyPatch

Expand Down Expand Up @@ -122,3 +123,17 @@ def test_freeze_backbone(self, model_name: str) -> None:
model=model_name, backbone='resnet18', freeze_backbone=True
)
assert not all([param.requires_grad for param in model.model.parameters()])

@pytest.mark.parametrize('model_name', ['faster-rcnn', 'fcos', 'retinanet'])
@pytest.mark.parametrize('in_channels', [1, 4])
def test_multispectral_support(self, model_name: str, in_channels: int) -> None:
model = ObjectDetectionTask(
model=model_name,
backbone='resnet18',
num_classes=2,
in_channels=in_channels,
)
model.eval()
sample = [torch.randn(in_channels, 224, 224)]
with torch.inference_mode():
model(sample)
9 changes: 9 additions & 0 deletions tests/trainers/test_instance_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Any

import pytest
import torch
from lightning.pytorch import Trainer
from pytest import MonkeyPatch

Expand Down Expand Up @@ -123,3 +124,11 @@ def test_freeze_backbone(self) -> None:
for head in ['rpn', 'roi_heads']:
for param in getattr(task.model, head).parameters():
assert param.requires_grad is True

@pytest.mark.parametrize('in_channels', [1, 4])
def test_multispectral_support(self, in_channels: int) -> None:
model = InstanceSegmentationTask(in_channels=in_channels, num_classes=2)
model.eval()
sample = [torch.randn(in_channels, 224, 224)]
with torch.inference_mode():
model(sample)
4 changes: 4 additions & 0 deletions torchgeo/trainers/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

from ..datasets import RGBBandsMissingError, unbind_samples
from .base import BaseTask
from .utils import GeneralizedRCNNTransformNoOp

BACKBONE_LAT_DIM_MAP = {
'resnet18': 512,
Expand Down Expand Up @@ -152,6 +153,7 @@ def configure_models(self) -> None:
rpn_anchor_generator=anchor_generator,
box_roi_pool=roi_pooler,
)
self.model.transform = GeneralizedRCNNTransformNoOp()
elif model == 'fcos':
kwargs['extra_blocks'] = feature_pyramid_network.LastLevelP6P7(256, 256)
kwargs['norm_layer'] = (
Expand All @@ -171,6 +173,7 @@ def configure_models(self) -> None:
self.model = torchvision.models.detection.FCOS(
model_backbone, num_classes, anchor_generator=anchor_generator
)
self.model.transform = GeneralizedRCNNTransformNoOp()
elif model == 'retinanet':
kwargs['extra_blocks'] = feature_pyramid_network.LastLevelP6P7(
latent_dim, 256
Expand Down Expand Up @@ -205,6 +208,7 @@ def configure_models(self) -> None:
anchor_generator=anchor_generator,
head=head,
)
self.model.transform = GeneralizedRCNNTransformNoOp()
else:
raise ValueError(f"Model type '{model}' is not valid.")

Expand Down
7 changes: 7 additions & 0 deletions torchgeo/trainers/instance_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

from ..datasets import RGBBandsMissingError, unbind_samples
from .base import BaseTask
from .utils import GeneralizedRCNNTransformNoOp


class InstanceSegmentationTask(BaseTask):
Expand All @@ -47,6 +48,9 @@ def __init__(
) -> None:
"""Initialize a new InstanceSegmentationTask instance.

Note that we disable the internal normalize+resize transform of the MaskRCNN model.
Please ensure your images are appropriately resized before passing them to the model.

Args:
model: Name of the model to use.
backbone: Name of the backbone to use.
Expand Down Expand Up @@ -87,7 +91,10 @@ def configure_models(self) -> None:
weights=weights,
num_classes=num_classes,
weights_backbone=weights_backbone,
image_mean=[0],
image_std=[1],
)
self.model.transform = GeneralizedRCNNTransformNoOp()
else:
msg = f"Invalid backbone type '{backbone}'. Supported backbone: 'resnet50'"
raise ValueError(msg)
Expand Down
18 changes: 18 additions & 0 deletions torchgeo/trainers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,24 @@
import torch.nn as nn
from torch import Tensor
from torch.nn.modules import Conv2d, Module
from torchvision.models.detection.transform import GeneralizedRCNNTransform


class GeneralizedRCNNTransformNoOp(GeneralizedRCNNTransform):
"""GeneralizedRCNNTransform without the normalize and resize ops.

.. versionadded:: 0.8
"""

def __init__(self) -> None:
"""Initialize a new GeneralizedRCNNTransformNoOp instance."""
super().__init__(min_size=0, max_size=0, image_mean=[0], image_std=[1])

def resize(
self, image: Tensor, target: dict[str, Tensor] | None = None
) -> tuple[Tensor, dict[str, Tensor] | None]:
"""Skip resizing and return the image and target."""
return image, target


def extract_backbone(path: str) -> tuple[str, 'OrderedDict[str, Tensor]']:
Expand Down
Loading