diff --git a/tests/conf/vhr10_ins_seg.yaml b/tests/conf/vhr10_ins_seg.yaml
index 042454a716b..2f602f4823f 100644
--- a/tests/conf/vhr10_ins_seg.yaml
+++ b/tests/conf/vhr10_ins_seg.yaml
@@ -9,6 +9,6 @@ data:
   init_args:
     batch_size: 1
     num_workers: 0
-    patch_size: 4
+    patch_size: 256
   dict_kwargs:
     root: 'tests/data/vhr10'
diff --git a/tests/trainers/test_detection.py b/tests/trainers/test_detection.py
index 28b8b5eab3c..f45a6a33803 100644
--- a/tests/trainers/test_detection.py
+++ b/tests/trainers/test_detection.py
@@ -5,6 +5,7 @@
 from typing import Any
 
 import pytest
+import torch
 from lightning.pytorch import Trainer
 from pytest import MonkeyPatch
@@ -122,3 +123,17 @@ def test_freeze_backbone(self, model_name: str) -> None:
             model=model_name, backbone='resnet18', freeze_backbone=True
         )
         assert not all([param.requires_grad for param in model.model.parameters()])
+
+    @pytest.mark.parametrize('model_name', ['faster-rcnn', 'fcos', 'retinanet'])
+    @pytest.mark.parametrize('in_channels', [1, 4])
+    def test_multispectral_support(self, model_name: str, in_channels: int) -> None:
+        model = ObjectDetectionTask(
+            model=model_name,
+            backbone='resnet18',
+            num_classes=2,
+            in_channels=in_channels,
+        )
+        model.eval()
+        sample = [torch.randn(in_channels, 224, 224)]
+        with torch.inference_mode():
+            model(sample)
diff --git a/tests/trainers/test_instance_segmentation.py b/tests/trainers/test_instance_segmentation.py
index f422a7cd720..8ce3603a3d9 100644
--- a/tests/trainers/test_instance_segmentation.py
+++ b/tests/trainers/test_instance_segmentation.py
@@ -5,6 +5,7 @@
 from typing import Any
 
 import pytest
+import torch
 from lightning.pytorch import Trainer
 from pytest import MonkeyPatch
@@ -123,3 +124,11 @@ def test_freeze_backbone(self) -> None:
         for head in ['rpn', 'roi_heads']:
             for param in getattr(task.model, head).parameters():
                 assert param.requires_grad is True
+
+    @pytest.mark.parametrize('in_channels', [1, 4])
+    def test_multispectral_support(self, in_channels: int) -> None:
+        model = InstanceSegmentationTask(in_channels=in_channels, num_classes=2)
+        model.eval()
+        sample = [torch.randn(in_channels, 224, 224)]
+        with torch.inference_mode():
+            model(sample)
diff --git a/torchgeo/trainers/detection.py b/torchgeo/trainers/detection.py
index 3a97faadf96..added13ec6d 100644
--- a/torchgeo/trainers/detection.py
+++ b/torchgeo/trainers/detection.py
@@ -24,6 +24,7 @@
 from ..datasets import RGBBandsMissingError, unbind_samples
 from .base import BaseTask
+from .utils import GeneralizedRCNNTransformNoOp
 
 BACKBONE_LAT_DIM_MAP = {
     'resnet18': 512,
@@ -152,6 +153,7 @@ def configure_models(self) -> None:
                 rpn_anchor_generator=anchor_generator,
                 box_roi_pool=roi_pooler,
             )
+            self.model.transform = GeneralizedRCNNTransformNoOp()
         elif model == 'fcos':
             kwargs['extra_blocks'] = feature_pyramid_network.LastLevelP6P7(256, 256)
             kwargs['norm_layer'] = (
@@ -171,6 +173,7 @@ def configure_models(self) -> None:
             self.model = torchvision.models.detection.FCOS(
                 model_backbone, num_classes, anchor_generator=anchor_generator
             )
+            self.model.transform = GeneralizedRCNNTransformNoOp()
         elif model == 'retinanet':
             kwargs['extra_blocks'] = feature_pyramid_network.LastLevelP6P7(
                 latent_dim, 256
@@ -205,6 +208,7 @@ def configure_models(self) -> None:
                 anchor_generator=anchor_generator,
                 head=head,
             )
+            self.model.transform = GeneralizedRCNNTransformNoOp()
         else:
             raise ValueError(f"Model type '{model}' is not valid.")
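
Note: with the no-op transform installed, detectors built by ObjectDetectionTask no longer apply torchvision's 3-channel ImageNet normalization or internal resizing, so non-RGB inputs pass straight through; resizing is now the caller's responsibility. A minimal usage sketch mirroring the new test (the 4-channel, 224x224 input is an illustrative assumption, not part of the diff):

import torch

from torchgeo.trainers import ObjectDetectionTask

# Build a 4-channel detector; the patch above swaps in GeneralizedRCNNTransformNoOp.
task = ObjectDetectionTask(
    model='faster-rcnn', backbone='resnet18', in_channels=4, num_classes=2
)
task.eval()
images = [torch.randn(4, 224, 224)]  # images must be pre-resized by the caller
with torch.inference_mode():
    predictions = task(images)  # in eval mode: list of dicts with 'boxes', 'labels', 'scores'
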
diff --git a/torchgeo/trainers/instance_segmentation.py b/torchgeo/trainers/instance_segmentation.py
index 3fb7812b9a2..b8105a33ca4 100644
--- a/torchgeo/trainers/instance_segmentation.py
+++ b/torchgeo/trainers/instance_segmentation.py
@@ -22,6 +22,7 @@
 from ..datasets import RGBBandsMissingError, unbind_samples
 from .base import BaseTask
+from .utils import GeneralizedRCNNTransformNoOp
 
 
 class InstanceSegmentationTask(BaseTask):
@@ -47,6 +48,9 @@ def __init__(
     ) -> None:
         """Initialize a new InstanceSegmentationTask instance.
 
+        Note that we disable the internal normalize+resize transform of the MaskRCNN model.
+        Please ensure your images are appropriately resized before passing them to the model.
+
         Args:
             model: Name of the model to use.
             backbone: Name of the backbone to use.
@@ -87,7 +91,10 @@ def configure_models(self) -> None:
                 weights=weights,
                 num_classes=num_classes,
                 weights_backbone=weights_backbone,
+                image_mean=[0],
+                image_std=[1],
             )
+            self.model.transform = GeneralizedRCNNTransformNoOp()
         else:
             msg = f"Invalid backbone type '{backbone}'. Supported backbone: 'resnet50'"
             raise ValueError(msg)
diff --git a/torchgeo/trainers/utils.py b/torchgeo/trainers/utils.py
index 10da4ba452f..bfc19811fcd 100644
--- a/torchgeo/trainers/utils.py
+++ b/torchgeo/trainers/utils.py
@@ -11,6 +11,24 @@
 import torch.nn as nn
 from torch import Tensor
 from torch.nn.modules import Conv2d, Module
+from torchvision.models.detection.transform import GeneralizedRCNNTransform
+
+
+class GeneralizedRCNNTransformNoOp(GeneralizedRCNNTransform):  # type: ignore[misc]
+    """GeneralizedRCNNTransform without the normalize and resize ops.
+
+    .. versionadded:: 0.7.1
+    """
+
+    def __init__(self) -> None:
+        """Initialize a new GeneralizedRCNNTransformNoOp instance."""
+        super().__init__(min_size=0, max_size=0, image_mean=[0], image_std=[1])
+
+    def resize(
+        self, image: Tensor, target: dict[str, Tensor] | None = None
+    ) -> tuple[Tensor, dict[str, Tensor] | None]:
+        """Skip resizing and return the image and target."""
+        return image, target
 
 
 def extract_backbone(path: str) -> tuple[str, 'OrderedDict[str, Tensor]']:
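
Note: a small sketch of what the new transform changes, using the identifiers added above (the input shape and printed result are illustrative assumptions). GeneralizedRCNNTransformNoOp still batches and pads images to a stride-divisible size, but the resize override skips rescaling to min_size/max_size, and with image_mean=[0] and image_std=[1] the normalization step becomes an identity, so it works for any channel count:

import torch

from torchgeo.trainers.utils import GeneralizedRCNNTransformNoOp

transform = GeneralizedRCNNTransformNoOp()
images = [torch.randn(4, 200, 300)]  # any channel count; floating point required
image_list, _ = transform(images)
# Only padding to a size divisible by 32 remains; no resize, no ImageNet stats.
print(image_list.tensors.shape)  # expected: torch.Size([1, 4, 224, 320])
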