From 350cafe6dac289363f44f8e18b2265bf676ff8da Mon Sep 17 00:00:00 2001 From: jadechoghari Date: Sun, 29 Dec 2024 08:47:47 +0100 Subject: [PATCH] fix modular --- docs/source/en/index.md | 1 + src/transformers/__init__.py | 10 +- .../models/auto/configuration_auto.py | 5 +- .../models/rt_detr_v2/__init__.py | 21 + .../rt_detr_v2/configuration_rt_detr_v2.py | 450 ++++ .../convert_rt_detr_v2_weights_to_hf.py | 5 +- .../models/rt_detr_v2/modeling_rt_detr_v2.py | 2177 +++++++++++++++++ .../models/rt_detr_v2/modular_rt_detr_v2.py | 145 +- src/transformers/utils/dummy_pt_objects.py | 21 + .../rtdetrv2/test_modeling_rt_detr_v2.py | 3 +- utils/check_docstrings.py | 4 +- 11 files changed, 2737 insertions(+), 105 deletions(-) create mode 100644 src/transformers/models/rt_detr_v2/__init__.py create mode 100644 src/transformers/models/rt_detr_v2/configuration_rt_detr_v2.py create mode 100644 src/transformers/models/rt_detr_v2/modeling_rt_detr_v2.py diff --git a/docs/source/en/index.md b/docs/source/en/index.md index dcecfc872d61d0..069ca33c62bc49 100644 --- a/docs/source/en/index.md +++ b/docs/source/en/index.md @@ -298,6 +298,7 @@ Flax), PyTorch, and/or TensorFlow. | [RoFormer](model_doc/roformer) | ✅ | ✅ | ✅ | | [RT-DETR](model_doc/rt_detr) | ✅ | ❌ | ❌ | | [RT-DETR-ResNet](model_doc/rt_detr_resnet) | ✅ | ❌ | ❌ | +| [RtDetrV2](model_doc/rt_detr_v2) | ✅ | ❌ | ❌ | | [RWKV](model_doc/rwkv) | ✅ | ❌ | ❌ | | [SAM](model_doc/sam) | ✅ | ✅ | ❌ | | [SeamlessM4T](model_doc/seamless_m4t) | ✅ | ❌ | ❌ | diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index cbdfebd1347498..7fb711bfa24df8 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -730,7 +730,7 @@ "RoFormerTokenizer", ], "models.rt_detr": ["RTDetrConfig", "RTDetrResNetConfig"], - "models.rt_detr_v2": ["RtDetrV2Config"], + "models.rt_detr_v2": ["RtDetrV2Config", "RtDetrV2ResNetConfig"], "models.rwkv": ["RwkvConfig"], "models.sam": [ "SamConfig", @@ -3365,11 +3365,7 @@ ] ) _import_structure["models.rt_detr_v2"].extend( - [ - "RtDetrV2ForObjectDetection", - "RtDetrV2Model", - "RtDetrV2PreTrainedModel" - ] + ["RtDetrV2ForObjectDetection", "RtDetrV2Model", "RtDetrV2PreTrainedModel"] ) _import_structure["models.rwkv"].extend( [ @@ -5743,7 +5739,7 @@ RTDetrConfig, RTDetrResNetConfig, ) - from .models.rt_detr_v2 import RtDetrV2Config + from .models.rt_detr_v2 import RtDetrV2Config, RtDetrV2ResNetConfig from .models.rwkv import RwkvConfig from .models.sam import ( SamConfig, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 52e03bcd0ec994..9dded31e6b49c6 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -252,6 +252,7 @@ ("rt_detr", "RTDetrConfig"), ("rt_detr_resnet", "RTDetrResNetConfig"), ("rt_detr_v2", "RtDetrV2Config"), + ("rt_detr_v2_resnet", "RtDetrV2ResNetConfig"), ("rwkv", "RwkvConfig"), ("sam", "SamConfig"), ("seamless_m4t", "SeamlessM4TConfig"), @@ -580,7 +581,8 @@ ("roformer", "RoFormer"), ("rt_detr", "RT-DETR"), ("rt_detr_resnet", "RT-DETR-ResNet"), - ("rt_detr_v2", "RtDetrV2"), + ("rt_detr_v2", "RT-DETR-V2-ResNet"), + ("rt_detr_v2_resnet", "RtDetrV2ResNetConfig"), ("rwkv", "RWKV"), ("sam", "SAM"), ("seamless_m4t", "SeamlessM4T"), @@ -713,6 +715,7 @@ ("siglip_vision_model", "siglip"), ("chinese_clip_vision_model", "chinese_clip"), ("rt_detr_resnet", "rt_detr"), + ("rt_detr_v2_resnet", "rt_detr_v2"), ] ) diff --git 
a/src/transformers/models/rt_detr_v2/__init__.py b/src/transformers/models/rt_detr_v2/__init__.py new file mode 100644 index 00000000000000..18d51ec296cea3 --- /dev/null +++ b/src/transformers/models/rt_detr_v2/__init__.py @@ -0,0 +1,21 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure +from .configuration_rt_detr_v2 import * +from .modeling_rt_detr_v2 import * diff --git a/src/transformers/models/rt_detr_v2/configuration_rt_detr_v2.py b/src/transformers/models/rt_detr_v2/configuration_rt_detr_v2.py new file mode 100644 index 00000000000000..1d7e71ae2f2971 --- /dev/null +++ b/src/transformers/models/rt_detr_v2/configuration_rt_detr_v2.py @@ -0,0 +1,450 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/rt_detr_v2/modular_rt_detr_v2.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_rt_detr_v2.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 + +from ...configuration_utils import PretrainedConfig +from ...utils import logging +from ...utils.backbone_utils import ( + BackboneConfigMixin, + get_aligned_output_features_output_indices, + verify_backbone_config_arguments, +) +from ..auto import CONFIG_MAPPING + + +logger = logging.get_logger(__name__) + + +class RtDetrV2ResNetConfig(BackboneConfigMixin, PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`RtDetrV2ResnetBackbone`]. It is used to instantiate an + ResNet model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the ResNet + [microsoft/resnet-50](https://huggingface.co/microsoft/resnet-50) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + num_channels (`int`, *optional*, defaults to 3): + The number of input channels. + embedding_size (`int`, *optional*, defaults to 64): + Dimensionality (hidden size) for the embedding layer. + hidden_sizes (`List[int]`, *optional*, defaults to `[256, 512, 1024, 2048]`): + Dimensionality (hidden size) at each stage. + depths (`List[int]`, *optional*, defaults to `[3, 4, 6, 3]`): + Depth (number of layers) for each stage. + layer_type (`str`, *optional*, defaults to `"bottleneck"`): + The layer to use, it can be either `"basic"` (used for smaller models, like resnet-18 or resnet-34) or + `"bottleneck"` (used for larger models like resnet-50 and above). 
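Taken together, the top-level `__init__.py` entries and the `configuration_auto.py` registrations from the hunks above are what expose the new classes through the public API and the auto classes. A minimal sketch, assuming a transformers build that includes this patch:

```python
from transformers import AutoConfig, RtDetrV2Config, RtDetrV2ForObjectDetection

config = RtDetrV2Config()                          # exported via `models.rt_detr_v2` in the top-level __init__
auto_config = AutoConfig.for_model("rt_detr_v2")   # resolved through the CONFIG_MAPPING_NAMES entry added above
model = RtDetrV2ForObjectDetection(config)
print(type(auto_config).__name__)                  # RtDetrV2Config
```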
+ hidden_act (`str`, *optional*, defaults to `"relu"`): + The non-linear activation function in each block. If string, `"gelu"`, `"relu"`, `"selu"` and `"gelu_new"` + are supported. + downsample_in_first_stage (`bool`, *optional*, defaults to `False`): + If `True`, the first stage will downsample the inputs using a `stride` of 2. + downsample_in_bottleneck (`bool`, *optional*, defaults to `False`): + If `True`, the first conv 1x1 in ResNetBottleNeckLayer will downsample the inputs using a `stride` of 2. + out_features (`List[str]`, *optional*): + If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc. + (depending on how many stages the model has). If unset and `out_indices` is set, will default to the + corresponding stages. If unset and `out_indices` is unset, will default to the last stage. Must be in the + same order as defined in the `stage_names` attribute. + out_indices (`List[int]`, *optional*): + If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how + many stages the model has). If unset and `out_features` is set, will default to the corresponding stages. + If unset and `out_features` is unset, will default to the last stage. Must be in the + same order as defined in the `stage_names` attribute. + + Example: + ```python + >>> from transformers import RtDetrV2ResNetConfig, RtDetrV2ResnetBackbone + + >>> # Initializing a ResNet resnet-50 style configuration + >>> configuration = RtDetrV2ResNetConfig() + + >>> # Initializing a model (with random weights) from the resnet-50 style configuration + >>> model = RtDetrV2ResnetBackbone(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ``` + """ + + model_type = "rt_detr_v2_resnet" + layer_types = ["basic", "bottleneck"] + + def __init__( + self, + num_channels=3, + embedding_size=64, + hidden_sizes=[256, 512, 1024, 2048], + depths=[3, 4, 6, 3], + layer_type="bottleneck", + hidden_act="relu", + downsample_in_first_stage=False, + downsample_in_bottleneck=False, + out_features=None, + out_indices=None, + **kwargs, + ): + super().__init__(**kwargs) + if layer_type not in self.layer_types: + raise ValueError(f"layer_type={layer_type} is not one of {','.join(self.layer_types)}") + self.num_channels = num_channels + self.embedding_size = embedding_size + self.hidden_sizes = hidden_sizes + self.depths = depths + self.layer_type = layer_type + self.hidden_act = hidden_act + self.downsample_in_first_stage = downsample_in_first_stage + self.downsample_in_bottleneck = downsample_in_bottleneck + self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(depths) + 1)] + self._out_features, self._out_indices = get_aligned_output_features_output_indices( + out_features=out_features, out_indices=out_indices, stage_names=self.stage_names + ) + + +class RtDetrV2Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`RtDetrV2Model`]. It is used to instantiate a + RT-DETR model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the RT-DETR + [checkpoing/todo](https://huggingface.co/checkpoing/todo) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. 
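Before the argument list below, a short sketch of how the two configuration classes defined in this file compose; the backbone values are illustrative (a resnet-18-like layout) and assume this patch is installed:

```python
from transformers import RtDetrV2Config, RtDetrV2ResNetConfig

backbone_config = RtDetrV2ResNetConfig(
    layer_type="basic",                   # smaller residual blocks, see `layer_type` above
    hidden_sizes=[64, 128, 256, 512],
    depths=[2, 2, 2, 2],
    out_indices=[2, 3, 4],                # feature maps handed to the hybrid encoder
)
config = RtDetrV2Config(
    backbone_config=backbone_config,
    encoder_in_channels=[128, 256, 512],  # must match the channels of the selected backbone stages
)
```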
+ + Args: + initializer_range (`float`, *optional*, defaults to 0.01): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + initializer_bias_prior_prob (`float`, *optional*): + The prior probability used by the bias initializer to initialize biases for `enc_score_head` and `class_embed`. + If `None`, `prior_prob` computed as `prior_prob = 1 / (num_labels + 1)` while initializing model weights. + layer_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the layer normalization layers. + batch_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the batch normalization layers. + backbone_config (`Dict`, *optional*, defaults to `RtDetrV2ResNetConfig()`): + The configuration of the backbone model. + backbone (`str`, *optional*): + Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this + will load the corresponding pretrained weights from the timm or transformers library. If `use_pretrained_backbone` + is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights. + use_pretrained_backbone (`bool`, *optional*, defaults to `False`): + Whether to use pretrained weights for the backbone. + use_timm_backbone (`bool`, *optional*, defaults to `False`): + Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers + library. + freeze_backbone_batch_norms (`bool`, *optional*, defaults to `True`): + Whether to freeze the batch normalization layers in the backbone. + backbone_kwargs (`dict`, *optional*): + Keyword arguments to be passed to AutoBackbone when loading from a checkpoint + e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set. + encoder_hidden_dim (`int`, *optional*, defaults to 256): + Dimension of the layers in hybrid encoder. + encoder_in_channels (`list`, *optional*, defaults to `[512, 1024, 2048]`): + Multi level features input for encoder. + feat_strides (`List[int]`, *optional*, defaults to `[8, 16, 32]`): + Strides used in each feature map. + encoder_layers (`int`, *optional*, defaults to 1): + Total of layers to be used by the encoder. + encoder_ffn_dim (`int`, *optional*, defaults to 1024): + Dimension of the "intermediate" (often named feed-forward) layer in decoder. + encoder_attention_heads (`int`, *optional*, defaults to 8): + Number of attention heads for each attention layer in the Transformer encoder. + dropout (`float`, *optional*, defaults to 0.0): + The ratio for all dropout layers. + activation_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for activations inside the fully connected layer. + encode_proj_layers (`List[int]`, *optional*, defaults to `[2]`): + Indexes of the projected layers to be used in the encoder. + positional_encoding_temperature (`int`, *optional*, defaults to 10000): + The temperature parameter used to create the positional encodings. + encoder_activation_function (`str`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"silu"` and `"gelu_new"` are supported. + activation_function (`str`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the general layer. If string, `"gelu"`, + `"relu"`, `"silu"` and `"gelu_new"` are supported. 
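The `positional_encoding_temperature` argument above feeds the usual DETR-style 2D sine/cosine embedding that the hybrid encoder builds for its flattened feature map. A self-contained sketch of that construction (mirroring the RT-DETR v1 helper; illustrative, not the generated code in this patch):

```python
import torch

def build_2d_sincos_position_embedding(width, height, embed_dim=256, temperature=10000.0):
    # One sin/cos pair per axis, so embed_dim must be divisible by 4.
    grid_w = torch.arange(width, dtype=torch.float32)
    grid_h = torch.arange(height, dtype=torch.float32)
    grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing="ij")
    pos_dim = embed_dim // 4
    omega = 1.0 / temperature ** (torch.arange(pos_dim, dtype=torch.float32) / pos_dim)
    out_w = grid_w.flatten()[..., None] @ omega[None]
    out_h = grid_h.flatten()[..., None] @ omega[None]
    return torch.concat([out_w.sin(), out_w.cos(), out_h.sin(), out_h.cos()], dim=1)[None, :, :]

print(build_2d_sincos_position_embedding(20, 20).shape)  # (1, 400, 256)
```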
+ eval_size (`Tuple[int, int]`, *optional*): + Height and width used to computes the effective height and width of the position embeddings after taking + into account the stride. + normalize_before (`bool`, *optional*, defaults to `False`): + Determine whether to apply layer normalization in the transformer encoder layer before self-attention and + feed-forward modules. + hidden_expansion (`float`, *optional*, defaults to 1.0): + Expansion ratio to enlarge the dimension size of RepVGGBlock and CSPRepLayer. + d_model (`int`, *optional*, defaults to 256): + Dimension of the layers exclude hybrid encoder. + num_queries (`int`, *optional*, defaults to 300): + Number of object queries. + decoder_in_channels (`list`, *optional*, defaults to `[256, 256, 256]`): + Multi level features dimension for decoder + decoder_ffn_dim (`int`, *optional*, defaults to 1024): + Dimension of the "intermediate" (often named feed-forward) layer in decoder. + num_feature_levels (`int`, *optional*, defaults to 3): + The number of input feature levels. + decoder_n_points (`int`, *optional*, defaults to 4): + The number of sampled keys in each feature level for each attention head in the decoder. + decoder_layers (`int`, *optional*, defaults to 6): + Number of decoder layers. + decoder_attention_heads (`int`, *optional*, defaults to 8): + Number of attention heads for each attention layer in the Transformer decoder. + decoder_activation_function (`str`, *optional*, defaults to `"relu"`): + The non-linear activation function (function or string) in the decoder. If string, `"gelu"`, + `"relu"`, `"silu"` and `"gelu_new"` are supported. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + num_denoising (`int`, *optional*, defaults to 100): + The total number of denoising tasks or queries to be used for contrastive denoising. + label_noise_ratio (`float`, *optional*, defaults to 0.5): + The fraction of denoising labels to which random noise should be added. + box_noise_scale (`float`, *optional*, defaults to 1.0): + Scale or magnitude of noise to be added to the bounding boxes. + learn_initial_query (`bool`, *optional*, defaults to `False`): + Indicates whether the initial query embeddings for the decoder should be learned during training + anchor_image_size (`Tuple[int, int]`, *optional*): + Height and width of the input image used during evaluation to generate the bounding box anchors. If None, automatic generate anchor is applied. + disable_custom_kernels (`bool`, *optional*, defaults to `True`): + Whether to disable custom kernels. + with_box_refine (`bool`, *optional*, defaults to `True`): + Whether to apply iterative bounding box refinement, where each decoder layer refines the bounding boxes + based on the predictions from the previous layer. + is_encoder_decoder (`bool`, *optional*, defaults to `True`): + Whether the architecture has an encoder decoder structure. + matcher_alpha (`float`, *optional*, defaults to 0.25): + Parameter alpha used by the Hungarian Matcher. + matcher_gamma (`float`, *optional*, defaults to 2.0): + Parameter gamma used by the Hungarian Matcher. + matcher_class_cost (`float`, *optional*, defaults to 2.0): + The relative weight of the class loss used by the Hungarian Matcher. + matcher_bbox_cost (`float`, *optional*, defaults to 5.0): + The relative weight of the bounding box loss used by the Hungarian Matcher. 
+ matcher_giou_cost (`float`, *optional*, defaults to 2.0): + The relative weight of the giou loss of used by the Hungarian Matcher. + use_focal_loss (`bool`, *optional*, defaults to `True`): + Parameter informing if focal focal should be used. + auxiliary_loss (`bool`, *optional*, defaults to `True`): + Whether auxiliary decoding losses (loss at each decoder layer) are to be used. + focal_loss_alpha (`float`, *optional*, defaults to 0.75): + Parameter alpha used to compute the focal loss. + focal_loss_gamma (`float`, *optional*, defaults to 2.0): + Parameter gamma used to compute the focal loss. + weight_loss_vfl (`float`, *optional*, defaults to 1.0): + Relative weight of the varifocal loss in the object detection loss. + weight_loss_bbox (`float`, *optional*, defaults to 5.0): + Relative weight of the L1 bounding box loss in the object detection loss. + weight_loss_giou (`float`, *optional*, defaults to 2.0): + Relative weight of the generalized IoU loss in the object detection loss. + eos_coefficient (`float`, *optional*, defaults to 0.0001): + Relative classification weight of the 'no-object' class in the object detection loss. + + Examples: + + ```python + >>> from transformers import RtDetrV2Config, RtDetrV2Model + + >>> # Initializing a RT-DETR configuration + >>> configuration = RtDetrV2Config() + + >>> # Initializing a model (with random weights) from the configuration + >>> model = RtDetrV2Model(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "rt_detr_v2" + layer_types = ["basic", "bottleneck"] + attribute_map = { + "hidden_size": "d_model", + "num_attention_heads": "encoder_attention_heads", + } + + def __init__( + self, + initializer_range=0.01, + initializer_bias_prior_prob=None, + layer_norm_eps=1e-5, + batch_norm_eps=1e-5, + # backbone + backbone_config=None, + backbone=None, + use_pretrained_backbone=False, + use_timm_backbone=False, + freeze_backbone_batch_norms=True, + backbone_kwargs=None, + # encoder HybridEncoder + encoder_hidden_dim=256, + encoder_in_channels=[512, 1024, 2048], + feat_strides=[8, 16, 32], + encoder_layers=1, + encoder_ffn_dim=1024, + encoder_attention_heads=8, + dropout=0.0, + activation_dropout=0.0, + encode_proj_layers=[2], + positional_encoding_temperature=10000, + encoder_activation_function="gelu", + activation_function="silu", + eval_size=None, + normalize_before=False, + hidden_expansion=1.0, + # decoder RtDetrV2Transformer + d_model=256, + num_queries=300, + decoder_in_channels=[256, 256, 256], + decoder_ffn_dim=1024, + num_feature_levels=3, + decoder_n_points=4, + decoder_layers=6, + decoder_attention_heads=8, + decoder_activation_function="relu", + attention_dropout=0.0, + num_denoising=100, + label_noise_ratio=0.5, + box_noise_scale=1.0, + learn_initial_query=False, + anchor_image_size=None, + disable_custom_kernels=True, + with_box_refine=True, + is_encoder_decoder=True, + # Loss + matcher_alpha=0.25, + matcher_gamma=2.0, + matcher_class_cost=2.0, + matcher_bbox_cost=5.0, + matcher_giou_cost=2.0, + use_focal_loss=True, + auxiliary_loss=True, + focal_loss_alpha=0.75, + focal_loss_gamma=2.0, + weight_loss_vfl=1.0, + weight_loss_bbox=5.0, + weight_loss_giou=2.0, + eos_coefficient=1e-4, + decoder_n_levels=3, # default value + decoder_offset_scale=0.5, # default value + **kwargs, + ): + super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs) + self.initializer_range = initializer_range + self.initializer_bias_prior_prob = initializer_bias_prior_prob + 
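The `focal_loss_alpha`/`focal_loss_gamma` defaults documented above enter a sigmoid focal loss of the standard form. A minimal illustrative sketch with those defaults (not the loss module used by the model):

```python
import torch
import torch.nn.functional as F

def sigmoid_focal_loss(logits, targets, alpha=0.75, gamma=2.0):
    # Standard sigmoid focal loss; alpha/gamma match the config defaults above.
    prob = logits.sigmoid()
    ce_loss = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
    p_t = prob * targets + (1 - prob) * (1 - targets)
    loss = ce_loss * ((1 - p_t) ** gamma)                    # down-weight easy examples
    alpha_t = alpha * targets + (1 - alpha) * (1 - targets)  # class-balance term
    return (alpha_t * loss).mean()

logits = torch.randn(4, 80)
targets = torch.zeros(4, 80)
targets[0, 3] = 1.0
print(sigmoid_focal_loss(logits, targets))
```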
self.layer_norm_eps = layer_norm_eps + self.batch_norm_eps = batch_norm_eps + # backbone + if backbone_config is None and backbone is None: + logger.info( + "`backbone_config` and `backbone` are `None`. Initializing the config with the default `RtDetrV2-ResNet` backbone." + ) + backbone_config = RtDetrV2ResNetConfig( + num_channels=3, + embedding_size=64, + hidden_sizes=[256, 512, 1024, 2048], + depths=[3, 4, 6, 3], + layer_type="bottleneck", + hidden_act="relu", + downsample_in_first_stage=False, + downsample_in_bottleneck=False, + out_features=None, + out_indices=[2, 3, 4], + ) + elif isinstance(backbone_config, dict): + backbone_model_type = backbone_config.pop("model_type") + config_class = CONFIG_MAPPING[backbone_model_type] + backbone_config = config_class.from_dict(backbone_config) + + verify_backbone_config_arguments( + use_timm_backbone=use_timm_backbone, + use_pretrained_backbone=use_pretrained_backbone, + backbone=backbone, + backbone_config=backbone_config, + backbone_kwargs=backbone_kwargs, + ) + + self.backbone_config = backbone_config + self.backbone = backbone + self.use_pretrained_backbone = use_pretrained_backbone + self.use_timm_backbone = use_timm_backbone + self.freeze_backbone_batch_norms = freeze_backbone_batch_norms + self.backbone_kwargs = backbone_kwargs + # encoder + self.encoder_hidden_dim = encoder_hidden_dim + self.encoder_in_channels = encoder_in_channels + self.feat_strides = feat_strides + self.encoder_attention_heads = encoder_attention_heads + self.encoder_ffn_dim = encoder_ffn_dim + self.dropout = dropout + self.activation_dropout = activation_dropout + self.encode_proj_layers = encode_proj_layers + self.encoder_layers = encoder_layers + self.positional_encoding_temperature = positional_encoding_temperature + self.eval_size = eval_size + self.normalize_before = normalize_before + self.encoder_activation_function = encoder_activation_function + self.activation_function = activation_function + self.hidden_expansion = hidden_expansion + # decoder + self.d_model = d_model + self.num_queries = num_queries + self.decoder_ffn_dim = decoder_ffn_dim + self.decoder_in_channels = decoder_in_channels + self.num_feature_levels = num_feature_levels + self.decoder_n_points = decoder_n_points + self.decoder_layers = decoder_layers + self.decoder_attention_heads = decoder_attention_heads + self.decoder_activation_function = decoder_activation_function + self.attention_dropout = attention_dropout + self.num_denoising = num_denoising + self.label_noise_ratio = label_noise_ratio + self.box_noise_scale = box_noise_scale + self.learn_initial_query = learn_initial_query + self.anchor_image_size = anchor_image_size + self.auxiliary_loss = auxiliary_loss + self.disable_custom_kernels = disable_custom_kernels + self.with_box_refine = with_box_refine + # Loss + self.matcher_alpha = matcher_alpha + self.matcher_gamma = matcher_gamma + self.matcher_class_cost = matcher_class_cost + self.matcher_bbox_cost = matcher_bbox_cost + self.matcher_giou_cost = matcher_giou_cost + self.use_focal_loss = use_focal_loss + self.focal_loss_alpha = focal_loss_alpha + self.focal_loss_gamma = focal_loss_gamma + self.weight_loss_vfl = weight_loss_vfl + self.weight_loss_bbox = weight_loss_bbox + self.weight_loss_giou = weight_loss_giou + self.eos_coefficient = eos_coefficient + + # add the new attributes with the given values or defaults + self.decoder_n_levels = decoder_n_levels + self.decoder_offset_scale = decoder_offset_scale + + @property + def num_attention_heads(self) -> int: + return 
self.encoder_attention_heads + + @property + def hidden_size(self) -> int: + return self.d_model + + @classmethod + def from_backbone_configs(cls, backbone_config: PretrainedConfig, **kwargs): + """Instantiate a [`RtDetrV2Config`] (or a derived class) from a pre-trained backbone model configuration and DETR model + configuration. + + Args: + backbone_config ([`PretrainedConfig`]): + The backbone configuration. + + Returns: + [`RtDetrV2Config`]: An instance of a configuration object + """ + return cls( + backbone_config=backbone_config, + **kwargs, + ) diff --git a/src/transformers/models/rt_detr_v2/convert_rt_detr_v2_weights_to_hf.py b/src/transformers/models/rt_detr_v2/convert_rt_detr_v2_weights_to_hf.py index c16be339335f04..c868698cca1d94 100644 --- a/src/transformers/models/rt_detr_v2/convert_rt_detr_v2_weights_to_hf.py +++ b/src/transformers/models/rt_detr_v2/convert_rt_detr_v2_weights_to_hf.py @@ -24,8 +24,7 @@ from PIL import Image from torchvision import transforms -from transformers import RTDetrImageProcessor -from transformers.models.rt_detr_v2.modular_rt_detr_v2 import RtDetrV2Config, RtDetrV2ForObjectDetection +from transformers import RTDetrImageProcessor, RtDetrV2Config, RtDetrV2ForObjectDetection from transformers.utils import logging @@ -558,7 +557,6 @@ def convert_rt_detr_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub "ema" ]["module"] - # rename keys for src, dest in create_rename_keys(config): rename_key(state_dict, src, dest) @@ -581,7 +579,6 @@ def convert_rt_detr_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub model.load_state_dict(state_dict) model.eval() - # load image processor image_processor = RTDetrImageProcessor() diff --git a/src/transformers/models/rt_detr_v2/modeling_rt_detr_v2.py b/src/transformers/models/rt_detr_v2/modeling_rt_detr_v2.py new file mode 100644 index 00000000000000..d6604e6c2eaeec --- /dev/null +++ b/src/transformers/models/rt_detr_v2/modeling_rt_detr_v2.py @@ -0,0 +1,2177 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/rt_detr_v2/modular_rt_detr_v2.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_rt_detr_v2.py file directly. One of our CI enforces this. 
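Two other ways to build the configuration defined above: passing a plain dict (routed through `CONFIG_MAPPING` via its `model_type` key, as in the `__init__` shown earlier) or using the `from_backbone_configs` classmethod. A hedged sketch assuming this patch:

```python
from transformers import RtDetrV2Config, RtDetrV2ResNetConfig

# A dict backbone_config is converted with CONFIG_MAPPING["rt_detr_v2_resnet"].from_dict(...)
config_from_dict = RtDetrV2Config(
    backbone_config={"model_type": "rt_detr_v2_resnet", "out_indices": [2, 3, 4]}
)

# Equivalent construction through the classmethod; extra kwargs are forwarded to RtDetrV2Config.
config_from_helper = RtDetrV2Config.from_backbone_configs(
    RtDetrV2ResNetConfig(out_indices=[2, 3, 4]), num_queries=300
)
print(type(config_from_dict.backbone_config).__name__)  # RtDetrV2ResNetConfig
```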
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +import math +import os +import warnings +from dataclasses import dataclass +from functools import lru_cache, partial, wraps +from pathlib import Path +from typing import Dict, List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +from torch import Tensor, nn +from torch.autograd import Function +from torch.autograd.function import once_differentiable + +from ...activations import ACT2CLS, ACT2FN +from ...image_transforms import center_to_corners_format, corners_to_center_format +from ...modeling_outputs import BaseModelOutput +from ...modeling_utils import PreTrainedModel +from ...utils import ( + ModelOutput, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_ninja_available, + is_torch_cuda_available, + logging, + replace_return_docstrings, +) +from ...utils.backbone_utils import load_backbone +from .configuration_rt_detr_v2 import RtDetrV2Config + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "RtDetrV2Config" + +MultiScaleDeformableAttention = None + + +class MultiScaleDeformableAttentionFunctionV2(Function): + @staticmethod + def forward( + context, + value, + value_spatial_shapes, + value_level_start_index, + sampling_locations, + attention_weights, + im2col_step, + ): + context.im2col_step = im2col_step + output = MultiScaleDeformableAttention.ms_deform_attn_forward( + value, + value_spatial_shapes, + value_level_start_index, + sampling_locations, + attention_weights, + context.im2col_step, + ) + context.save_for_backward( + value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights + ) + return output + + @staticmethod + @once_differentiable + def backward(context, grad_output): + ( + value, + value_spatial_shapes, + value_level_start_index, + sampling_locations, + attention_weights, + ) = context.saved_tensors + grad_value, grad_sampling_loc, grad_attn_weight = MultiScaleDeformableAttention.ms_deform_attn_backward( + value, + value_spatial_shapes, + value_level_start_index, + sampling_locations, + attention_weights, + grad_output, + context.im2col_step, + ) + + return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None + + +def load_cuda_kernels(): + from torch.utils.cpp_extension import load + + global MultiScaleDeformableAttention + + root = Path(__file__).resolve().parent.parent.parent / "kernels" / "deformable_detr" + src_files = [ + root / filename + for filename in [ + "vision.cpp", + os.path.join("cpu", "ms_deform_attn_cpu.cpp"), + os.path.join("cuda", "ms_deform_attn_cuda.cu"), + ] + ] + + MultiScaleDeformableAttention = load( + "MultiScaleDeformableAttention", + src_files, + with_cuda=True, + extra_include_paths=[str(root)], + extra_cflags=["-DWITH_CUDA=1"], + extra_cuda_cflags=[ + "-DCUDA_HAS_FP16=1", + "-D__CUDA_NO_HALF_OPERATORS__", + "-D__CUDA_NO_HALF_CONVERSIONS__", + "-D__CUDA_NO_HALF2_OPERATORS__", + ], + ) + + +def multi_scale_deformable_attention_v2( + value: Tensor, + value_spatial_shapes: Tensor, + sampling_locations: Tensor, + attention_weights: Tensor, + num_points_list: List[int], + method="default", +) -> Tensor: + batch_size, _, num_heads, hidden_dim = value.shape + _, num_queries, num_heads, num_levels, num_points = sampling_locations.shape + value_list = ( + value.permute(0, 2, 3, 1) + .flatten(0, 1) + .split([height.item() * width.item() for height, width in value_spatial_shapes], dim=-1) + ) + # sampling_offsets [8, 480, 8, 12, 2] + if method == "default": + sampling_grids = 2 * 
sampling_locations - 1 + elif method == "discrete": + sampling_grids = sampling_locations + sampling_grids = sampling_grids.permute(0, 2, 1, 3, 4).flatten(0, 1) + sampling_grids = sampling_grids.split(num_points_list, dim=-2) + sampling_value_list = [] + for level_id, (height, width) in enumerate(value_spatial_shapes): + # batch_size, height*width, num_heads, hidden_dim + # -> batch_size, height*width, num_heads*hidden_dim + # -> batch_size, num_heads*hidden_dim, height*width + # -> batch_size*num_heads, hidden_dim, height, width + value_l_ = value_list[level_id].reshape(batch_size * num_heads, hidden_dim, height, width) + # batch_size, num_queries, num_heads, num_points, 2 + # -> batch_size, num_heads, num_queries, num_points, 2 + # -> batch_size*num_heads, num_queries, num_points, 2 + sampling_grid_l_ = sampling_grids[level_id] + # batch_size*num_heads, hidden_dim, num_queries, num_points + if method == "default": + sampling_value_l_ = nn.functional.grid_sample( + value_l_, sampling_grid_l_, mode="bilinear", padding_mode="zeros", align_corners=False + ) + elif method == "discrete": + sampling_coord = (sampling_grid_l_ * torch.tensor([[width, height]], device=value.device) + 0.5).to( + torch.int64 + ) + + # Separate clamping for x and y coordinates + sampling_coord_x = sampling_coord[..., 0].clamp(0, width - 1) + sampling_coord_y = sampling_coord[..., 1].clamp(0, height - 1) + + # Combine the clamped coordinates + sampling_coord = torch.stack([sampling_coord_x, sampling_coord_y], dim=-1) + sampling_coord = sampling_coord.reshape(batch_size * num_heads, num_queries * num_points_list[level_id], 2) + sampling_idx = ( + torch.arange(sampling_coord.shape[0], device=value.device) + .unsqueeze(-1) + .repeat(1, sampling_coord.shape[1]) + ) + sampling_value_l_ = value_l_[sampling_idx, :, sampling_coord[..., 1], sampling_coord[..., 0]] + sampling_value_l_ = sampling_value_l_.permute(0, 2, 1).reshape( + batch_size * num_heads, hidden_dim, num_queries, num_points_list[level_id] + ) + sampling_value_list.append(sampling_value_l_) + # (batch_size, num_queries, num_heads, num_levels, num_points) + # -> (batch_size, num_heads, num_queries, num_levels, num_points) + # -> (batch_size, num_heads, 1, num_queries, num_levels*num_points) + attention_weights = attention_weights.permute(0, 2, 1, 3).reshape( + batch_size * num_heads, 1, num_queries, sum(num_points_list) + ) + output = ( + (torch.concat(sampling_value_list, dim=-1) * attention_weights) + .sum(-1) + .view(batch_size, num_heads * hidden_dim, num_queries) + ) + return output.transpose(1, 2).contiguous() + + +class RtDetrV2MultiscaleDeformableAttention(nn.Module): + """ + Multiscale deformable attention as proposed in Deformable DETR. + + RTDETRV2 version of multiscale deformable attention, extending the base implementation + with improved offset handling and initialization. 
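A shape walkthrough of the pure-PyTorch fallback defined above, with three feature levels and four points per level; the V2-specific `num_points_list` tells the function how to split the flattened sampling locations back per level. The import path assumes this patch and the tensor sizes are illustrative. (In the attention module below, these locations come from reference boxes plus predicted offsets scaled by `n_points_scale`, the box size and `offset_scale`.)

```python
import torch
from transformers.models.rt_detr_v2.modeling_rt_detr_v2 import multi_scale_deformable_attention_v2

batch_size, num_heads, head_dim, num_queries = 2, 8, 32, 10
spatial_shapes = torch.tensor([[32, 32], [16, 16], [8, 8]])          # (num_levels, 2) as (height, width)
num_points_list = [4, 4, 4]
seq_len = int((spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum())   # 1344 flattened positions
total_points = sum(num_points_list)

value = torch.rand(batch_size, seq_len, num_heads, head_dim)
sampling_locations = torch.rand(batch_size, num_queries, num_heads, total_points, 2)  # normalized to [0, 1]
attention_weights = torch.rand(batch_size, num_queries, num_heads, total_points).softmax(-1)

output = multi_scale_deformable_attention_v2(
    value, spatial_shapes, sampling_locations, attention_weights, num_points_list
)
print(output.shape)  # (2, 10, 256) = (batch_size, num_queries, num_heads * head_dim)
```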
+ """ + + def __init__(self, config: RtDetrV2Config, num_heads: int = None, n_points: int = None): + super().__init__() + num_heads = num_heads if num_heads is not None else config.decoder_attention_heads + n_points = n_points if n_points is not None else config.decoder_n_points + + kernel_loaded = MultiScaleDeformableAttention is not None + if is_torch_cuda_available() and is_ninja_available() and not kernel_loaded: + try: + load_cuda_kernels() + except Exception as e: + logger.warning(f"Could not load the custom kernel for multi-scale deformable attention: {e}") + + if config.d_model % num_heads != 0: + raise ValueError( + f"embed_dim (d_model) must be divisible by num_heads, but got {config.d_model} and {num_heads}" + ) + dim_per_head = config.d_model // num_heads + # check if dim_per_head is power of 2 + if not ((dim_per_head & (dim_per_head - 1) == 0) and dim_per_head != 0): + warnings.warn( + "You'd better set embed_dim (d_model) in RtDetrV2MultiscaleDeformableAttention to make the" + " dimension of each attention head a power of 2 which is more efficient in the authors' CUDA" + " implementation." + ) + + self.im2col_step = 64 + + self.d_model = config.d_model + self.n_levels = config.num_feature_levels + self.n_heads = num_heads + self.n_points = n_points + + self.sampling_offsets = nn.Linear(config.d_model, num_heads * self.n_levels * n_points * 2) + self.attention_weights = nn.Linear(config.d_model, num_heads * self.n_levels * n_points) + self.value_proj = nn.Linear(config.d_model, config.d_model) + self.output_proj = nn.Linear(config.d_model, config.d_model) + + self.disable_custom_kernels = config.disable_custom_kernels + + # V2-specific attributes + self.offset_scale = config.decoder_offset_scale + # Initialize n_points list and scale + n_points_list = [self.n_points for _ in range(self.n_levels)] + self.n_points_list = n_points_list + n_points_scale = [1 / n for n in n_points_list for _ in range(n)] + self.register_buffer("n_points_scale", torch.tensor(n_points_scale, dtype=torch.float32)) + + def with_pos_embed(self, tensor: torch.Tensor, position_embeddings: Optional[Tensor]): + return tensor if position_embeddings is None else tensor + position_embeddings + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states=None, + encoder_attention_mask=None, + position_embeddings: Optional[torch.Tensor] = None, + reference_points=None, + spatial_shapes=None, + level_start_index=None, + output_attentions: bool = False, + **kwargs, + ): + # Process inputs up to sampling locations calculation using parent class logic + if position_embeddings is not None: + hidden_states = self.with_pos_embed(hidden_states, position_embeddings) + + batch_size, num_queries, _ = hidden_states.shape + batch_size, sequence_length, _ = encoder_hidden_states.shape + if (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() != sequence_length: + raise ValueError( + "Make sure to align the spatial shapes with the sequence length of the encoder hidden states" + ) + + value = self.value_proj(encoder_hidden_states) + if attention_mask is not None: + value = value.masked_fill(~attention_mask[..., None], float(0)) + value = value.view(batch_size, sequence_length, self.n_heads, self.d_model // self.n_heads) + + # V2-specific sampling offsets shape + sampling_offsets = self.sampling_offsets(hidden_states).view( + batch_size, num_queries, self.n_heads, self.n_levels * self.n_points, 2 + ) + + attention_weights = 
self.attention_weights(hidden_states).view( + batch_size, num_queries, self.n_heads, self.n_levels * self.n_points + ) + attention_weights = F.softmax(attention_weights, -1) + + # V2-specific sampling locations calculation + if reference_points.shape[-1] == 2: + offset_normalizer = torch.stack([spatial_shapes[..., 1], spatial_shapes[..., 0]], -1) + sampling_locations = ( + reference_points[:, :, None, :, None, :] + + sampling_offsets / offset_normalizer[None, None, None, :, None, :] + ) + elif reference_points.shape[-1] == 4: + n_points_scale = self.n_points_scale.to(dtype=hidden_states.dtype).unsqueeze(-1) + offset = sampling_offsets * n_points_scale * reference_points[:, :, None, :, 2:] * self.offset_scale + sampling_locations = reference_points[:, :, None, :, :2] + offset + else: + raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}") + + # V2-specific attention implementation choice + if self.disable_custom_kernels: + output = multi_scale_deformable_attention_v2( + value, spatial_shapes, sampling_locations, attention_weights, self.n_points_list + ) + else: + try: + output = MultiScaleDeformableAttentionFunctionV2.apply( + value, + spatial_shapes, + level_start_index, + sampling_locations, + attention_weights, + self.im2col_step, + ) + except Exception: + output = multi_scale_deformable_attention_v2( + value, spatial_shapes, sampling_locations, attention_weights, self.n_points_list + ) + + output = self.output_proj(output) + return output, attention_weights + + +class RtDetrV2MultiheadAttention(nn.Module): + """ + Multi-headed attention from 'Attention Is All You Need' paper. + + Here, we add position embeddings to the queries and keys (as explained in the Deformable DETR paper). + """ + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + bias: bool = True, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + if self.head_dim * num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {num_heads})." 
+ ) + self.scaling = self.head_dim**-0.5 + + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + def _reshape(self, tensor: torch.Tensor, seq_len: int, batch_size: int): + return tensor.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def with_pos_embed(self, tensor: torch.Tensor, position_embeddings: Optional[Tensor]): + return tensor if position_embeddings is None else tensor + position_embeddings + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_embeddings: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + batch_size, target_len, embed_dim = hidden_states.size() + # add position embeddings to the hidden states before projecting to queries and keys + if position_embeddings is not None: + hidden_states_original = hidden_states + hidden_states = self.with_pos_embed(hidden_states, position_embeddings) + + # get queries, keys and values + query_states = self.q_proj(hidden_states) * self.scaling + key_states = self._reshape(self.k_proj(hidden_states), -1, batch_size) + value_states = self._reshape(self.v_proj(hidden_states_original), -1, batch_size) + + proj_shape = (batch_size * self.num_heads, -1, self.head_dim) + query_states = self._reshape(query_states, target_len, batch_size).view(*proj_shape) + key_states = key_states.view(*proj_shape) + value_states = value_states.view(*proj_shape) + + source_len = key_states.size(1) + + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + + if attn_weights.size() != (batch_size * self.num_heads, target_len, source_len): + raise ValueError( + f"Attention weights should be of size {(batch_size * self.num_heads, target_len, source_len)}, but is" + f" {attn_weights.size()}" + ) + + # expand attention_mask + if attention_mask is not None: + # [seq_len, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len] + attention_mask = attention_mask.expand(batch_size, 1, *attention_mask.size()) + + if attention_mask is not None: + if attention_mask.size() != (batch_size, 1, target_len, source_len): + raise ValueError( + f"Attention mask should be of size {(batch_size, 1, target_len, source_len)}, but is" + f" {attention_mask.size()}" + ) + attn_weights = attn_weights.view(batch_size, self.num_heads, target_len, source_len) + attention_mask + attn_weights = attn_weights.view(batch_size * self.num_heads, target_len, source_len) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if output_attentions: + # this operation is a bit awkward, but it's required to + # make sure that attn_weights keeps its gradient. 
+ # In order to do so, attn_weights have to reshaped + # twice and have to be reused in the following + attn_weights_reshaped = attn_weights.view(batch_size, self.num_heads, target_len, source_len) + attn_weights = attn_weights_reshaped.view(batch_size * self.num_heads, target_len, source_len) + else: + attn_weights_reshaped = None + + attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + + attn_output = torch.bmm(attn_probs, value_states) + + if attn_output.size() != (batch_size * self.num_heads, target_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(batch_size, self.num_heads, target_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.view(batch_size, self.num_heads, target_len, self.head_dim) + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(batch_size, target_len, embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped + + +class RtDetrV2DecoderLayer(nn.Module): + def __init__(self, config: RtDetrV2Config): + super().__init__() + # self-attention + self.self_attn = RtDetrV2MultiheadAttention( + embed_dim=config.d_model, + num_heads=config.decoder_attention_heads, + dropout=config.attention_dropout, + ) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.decoder_activation_function] + self.activation_dropout = config.activation_dropout + + self.self_attn_layer_norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps) + # override only the encoder attention module with v2 version + self.encoder_attn = RtDetrV2MultiscaleDeformableAttention(config) + self.encoder_attn_layer_norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps) + # feedforward neural networks + self.fc1 = nn.Linear(config.d_model, config.decoder_ffn_dim) + self.fc2 = nn.Linear(config.decoder_ffn_dim, config.d_model) + self.final_layer_norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Optional[torch.Tensor] = None, + reference_points=None, + spatial_shapes=None, + spatial_shapes_list=None, + level_start_index=None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + ): + """ + Args: + hidden_states (`torch.FloatTensor`): + Input to the layer of shape `(seq_len, batch, embed_dim)`. + position_embeddings (`torch.FloatTensor`, *optional*): + Position embeddings that are added to the queries and keys in the self-attention layer. + reference_points (`torch.FloatTensor`, *optional*): + Reference points. + spatial_shapes (`torch.LongTensor`, *optional*): + Spatial shapes. + level_start_index (`torch.LongTensor`, *optional*): + Level start index. + encoder_hidden_states (`torch.FloatTensor`): + cross attention input to the layer of shape `(seq_len, batch, embed_dim)` + encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size + `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative + values. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. 
+ """ + residual = hidden_states + + # Self Attention + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=encoder_attention_mask, + position_embeddings=position_embeddings, + output_attentions=output_attentions, + ) + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + second_residual = hidden_states + + # Cross-Attention + cross_attn_weights = None + hidden_states, cross_attn_weights = self.encoder_attn( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + position_embeddings=position_embeddings, + reference_points=reference_points, + spatial_shapes=spatial_shapes, + spatial_shapes_list=spatial_shapes_list, + level_start_index=level_start_index, + output_attentions=output_attentions, + ) + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = second_residual + hidden_states + + hidden_states = self.encoder_attn_layer_norm(hidden_states) + + # Fully Connected + residual = hidden_states + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.final_layer_norm(hidden_states) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights, cross_attn_weights) + + return outputs + + +@dataclass +class RtDetrV2DecoderOutput(ModelOutput): + """ + Base class for outputs of the RtDetrV2Decoder. This class adds two attributes to + BaseModelOutputWithCrossAttentions, namely: + - a stacked tensor of intermediate decoder hidden states (i.e. the output of each decoder layer) + - a stacked tensor of intermediate reference points. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`): + Stacked intermediate hidden states (output of each layer of the decoder). + intermediate_logits (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, sequence_length, config.num_labels)`): + Stacked intermediate logits (logits of each layer of the decoder). + intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, sequence_length, hidden_size)`): + Stacked intermediate reference points (reference points of each layer of the decoder). + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer + plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. 
Attentions weights after the attention softmax, used to compute the weighted average in + the self-attention heads. + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights of the decoder's cross-attention layer, after the attention softmax, + used to compute the weighted average in the cross-attention heads. + """ + + last_hidden_state: torch.FloatTensor = None + intermediate_hidden_states: torch.FloatTensor = None + intermediate_logits: torch.FloatTensor = None + intermediate_reference_points: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +class RtDetrV2ModelOutput(ModelOutput): + """ + Base class for outputs of the RT-DETR encoder-decoder model. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the decoder of the model. + intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`): + Stacked intermediate hidden states (output of each layer of the decoder). + intermediate_logits (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, sequence_length, config.num_labels)`): + Stacked intermediate logits (logits of each layer of the decoder). + intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`): + Stacked intermediate reference points (reference points of each layer of the decoder). + decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, num_queries, hidden_size)`. Hidden-states of the decoder at the output of each layer + plus the initial embedding outputs. + decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, num_queries, + num_queries)`. Attentions weights of the decoder, after the attention softmax, used to compute the weighted + average in the self-attention heads. + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_queries, num_heads, 4, 4)`. + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder of the model. 
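The fields documented here can be inspected directly on a randomly initialized model; a small sketch (class names assume this patch, shapes follow the descriptions above):

```python
import torch
from transformers import RtDetrV2Config, RtDetrV2Model

config = RtDetrV2Config()
model = RtDetrV2Model(config).eval()

pixel_values = torch.randn(1, 3, 640, 640)
with torch.no_grad():
    outputs = model(pixel_values, output_attentions=True, output_hidden_states=True)

print(outputs.last_hidden_state.shape)              # (1, num_queries, d_model) = (1, 300, 256)
print(outputs.intermediate_reference_points.shape)  # (1, decoder_layers, num_queries, 4)
```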
+ encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the encoder at the output of each + layer plus the initial embedding outputs. + encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_queries, num_heads, 4, 4)`. + Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + init_reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`): + Initial reference points sent through the Transformer decoder. + enc_topk_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`): + Predicted bounding boxes scores where the top `config.two_stage_num_proposals` scoring bounding boxes are + picked as region proposals in the encoder stage. Output of bounding box binary classification (i.e. + foreground and background). + enc_topk_bboxes (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`): + Logits of predicted bounding boxes coordinates in the encoder stage. + enc_outputs_class (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`): + Predicted bounding boxes scores where the top `config.two_stage_num_proposals` scoring bounding boxes are + picked as region proposals in the first stage. Output of bounding box binary classification (i.e. + foreground and background). + enc_outputs_coord_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`): + Logits of predicted bounding boxes coordinates in the first stage. + denoising_meta_values (`dict`): + Extra dictionary for the denoising related values + """ + + last_hidden_state: torch.FloatTensor = None + intermediate_hidden_states: torch.FloatTensor = None + intermediate_logits: torch.FloatTensor = None + intermediate_reference_points: torch.FloatTensor = None + decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor]] = None + encoder_last_hidden_state: Optional[torch.FloatTensor] = None + encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None + init_reference_points: torch.FloatTensor = None + enc_topk_logits: Optional[torch.FloatTensor] = None + enc_topk_bboxes: Optional[torch.FloatTensor] = None + enc_outputs_class: Optional[torch.FloatTensor] = None + enc_outputs_coord_logits: Optional[torch.FloatTensor] = None + denoising_meta_values: Optional[Dict] = None + + +@dataclass +class RtDetrV2ObjectDetectionOutput(ModelOutput): + """ + Output type of [`RtDetrV2ForObjectDetection`]. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` are provided)): + Total loss as a linear combination of a negative log-likehood (cross-entropy) for class prediction and a + bounding box loss. 
The latter is defined as a linear combination of the L1 loss and the generalized + scale-invariant IoU loss. + loss_dict (`Dict`, *optional*): + A dictionary containing the individual losses. Useful for logging. + logits (`torch.FloatTensor` of shape `(batch_size, num_queries, num_classes + 1)`): + Classification logits (including no-object) for all queries. + pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`): + Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These + values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding + possible padding). You can use [`~RtDetrV2ImageProcessor.post_process_object_detection`] to retrieve the + unnormalized (absolute) bounding boxes. + auxiliary_outputs (`list[Dict]`, *optional*): + Optional, only returned when auxiliary losses are activated (i.e. `config.auxiliary_loss` is set to `True`) + and labels are provided. It is a list of dictionaries containing the two above keys (`logits` and + `pred_boxes`) for each decoder layer. + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the decoder of the model. + intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`): + Stacked intermediate hidden states (output of each layer of the decoder). + intermediate_logits (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, config.num_labels)`): + Stacked intermediate logits (logits of each layer of the decoder). + intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`): + Stacked intermediate reference points (reference points of each layer of the decoder). + decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, num_queries, hidden_size)`. Hidden-states of the decoder at the output of each layer + plus the initial embedding outputs. + decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, num_queries, + num_queries)`. Attentions weights of the decoder, after the attention softmax, used to compute the weighted + average in the self-attention heads. + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_queries, num_heads, 4, 4)`. + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder of the model. 
+ encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the encoder at the output of each + layer plus the initial embedding outputs. + encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_queries, num_heads, 4, 4)`. + Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + init_reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`): + Initial reference points sent through the Transformer decoder. + enc_topk_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`): + Logits of predicted bounding boxes coordinates in the encoder. + enc_topk_bboxes (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`): + Logits of predicted bounding boxes coordinates in the encoder. + enc_outputs_class (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`): + Predicted bounding boxes scores where the top `config.two_stage_num_proposals` scoring bounding boxes are + picked as region proposals in the first stage. Output of bounding box binary classification (i.e. + foreground and background). + enc_outputs_coord_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`): + Logits of predicted bounding boxes coordinates in the first stage. + denoising_meta_values (`dict`): + Extra dictionary for the denoising related values + """ + + loss: Optional[torch.FloatTensor] = None + loss_dict: Optional[Dict] = None + logits: torch.FloatTensor = None + pred_boxes: torch.FloatTensor = None + auxiliary_outputs: Optional[List[Dict]] = None + last_hidden_state: torch.FloatTensor = None + intermediate_hidden_states: torch.FloatTensor = None + intermediate_logits: torch.FloatTensor = None + intermediate_reference_points: torch.FloatTensor = None + decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor]] = None + encoder_last_hidden_state: Optional[torch.FloatTensor] = None + encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None + init_reference_points: Optional[Tuple[torch.FloatTensor]] = None + enc_topk_logits: Optional[torch.FloatTensor] = None + enc_topk_bboxes: Optional[torch.FloatTensor] = None + enc_outputs_class: Optional[torch.FloatTensor] = None + enc_outputs_coord_logits: Optional[torch.FloatTensor] = None + denoising_meta_values: Optional[Dict] = None + + +class RtDetrV2FrozenBatchNorm2d(nn.Module): + """ + BatchNorm2d where the batch statistics and the affine parameters are fixed. 
+ + Copy-paste from torchvision.misc.ops with added eps before rqsrt, without which any other models than + torchvision.models.resnet[18,34,50,101] produce nans. + """ + + def __init__(self, n): + super().__init__() + self.register_buffer("weight", torch.ones(n)) + self.register_buffer("bias", torch.zeros(n)) + self.register_buffer("running_mean", torch.zeros(n)) + self.register_buffer("running_var", torch.ones(n)) + + def _load_from_state_dict( + self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ): + num_batches_tracked_key = prefix + "num_batches_tracked" + if num_batches_tracked_key in state_dict: + del state_dict[num_batches_tracked_key] + + super()._load_from_state_dict( + state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ) + + def forward(self, x): + # move reshapes to the beginning + # to make it user-friendly + weight = self.weight.reshape(1, -1, 1, 1) + bias = self.bias.reshape(1, -1, 1, 1) + running_var = self.running_var.reshape(1, -1, 1, 1) + running_mean = self.running_mean.reshape(1, -1, 1, 1) + epsilon = 1e-5 + scale = weight * (running_var + epsilon).rsqrt() + bias = bias - running_mean * scale + return x * scale + bias + + +def replace_batch_norm(model): + r""" + Recursively replace all `torch.nn.BatchNorm2d` with `RtDetrV2FrozenBatchNorm2d`. + + Args: + model (torch.nn.Module): + input model + """ + for name, module in model.named_children(): + if isinstance(module, nn.BatchNorm2d): + new_module = RtDetrV2FrozenBatchNorm2d(module.num_features) + + if not module.weight.device == torch.device("meta"): + new_module.weight.data.copy_(module.weight) + new_module.bias.data.copy_(module.bias) + new_module.running_mean.data.copy_(module.running_mean) + new_module.running_var.data.copy_(module.running_var) + + model._modules[name] = new_module + + if len(list(module.children())) > 0: + replace_batch_norm(module) + + +class RtDetrV2ConvEncoder(nn.Module): + """ + Convolutional backbone using the modeling_rt_detr_v2_resnet.py. + + nn.BatchNorm2d layers are replaced by RtDetrV2FrozenBatchNorm2d as defined above. 
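As a quick sanity check on the frozen batch norm above, the folding it performs can be compared against a standard `nn.BatchNorm2d` in eval mode. A minimal, standalone sketch (channel count, statistics, and input shapes are made up):

```python
import torch
from torch import nn

# RtDetrV2FrozenBatchNorm2d rewrites y = (x - mean) / sqrt(var + eps) * weight + bias
# as a single per-channel affine y = x * scale + shift.
bn = nn.BatchNorm2d(8, eps=1e-5).eval()
with torch.no_grad():
    bn.weight.uniform_(0.5, 1.5)
    bn.bias.uniform_(-0.5, 0.5)
    bn.running_mean.uniform_(-1.0, 1.0)
    bn.running_var.uniform_(0.5, 2.0)

x = torch.randn(2, 8, 4, 4)
scale = bn.weight * (bn.running_var + bn.eps).rsqrt()
shift = bn.bias - bn.running_mean * scale
folded = x * scale.reshape(1, -1, 1, 1) + shift.reshape(1, -1, 1, 1)
torch.testing.assert_close(folded, bn(x))  # matches eval-mode batch norm
```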
+ https://github.com/lyuwenyu/RT-DETR/blob/main/RtDetrV2_pytorch/src/nn/backbone/presnet.py#L142 + """ + + def __init__(self, config): + super().__init__() + + backbone = load_backbone(config) + + if config.freeze_backbone_batch_norms: + # replace batch norm by frozen batch norm + with torch.no_grad(): + replace_batch_norm(backbone) + self.model = backbone + self.intermediate_channel_sizes = self.model.channels + + def forward(self, pixel_values: torch.Tensor, pixel_mask: torch.Tensor): + # send pixel_values through the model to get list of feature maps + features = self.model(pixel_values).feature_maps + + out = [] + for feature_map in features: + # downsample pixel_mask to match shape of corresponding feature_map + mask = nn.functional.interpolate(pixel_mask[None].float(), size=feature_map.shape[-2:]).to(torch.bool)[0] + out.append((feature_map, mask)) + return out + + +class RtDetrV2ConvNormLayer(nn.Module): + def __init__(self, config, in_channels, out_channels, kernel_size, stride, padding=None, activation=None): + super().__init__() + self.conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size, + stride, + padding=(kernel_size - 1) // 2 if padding is None else padding, + bias=False, + ) + self.norm = nn.BatchNorm2d(out_channels, config.batch_norm_eps) + self.activation = nn.Identity() if activation is None else ACT2CLS[activation]() + + def forward(self, hidden_state): + hidden_state = self.conv(hidden_state) + hidden_state = self.norm(hidden_state) + hidden_state = self.activation(hidden_state) + return hidden_state + + +class RtDetrV2EncoderLayer(nn.Module): + def __init__(self, config: RtDetrV2Config): + super().__init__() + self.normalize_before = config.normalize_before + + # self-attention + self.self_attn = RtDetrV2MultiheadAttention( + embed_dim=config.encoder_hidden_dim, + num_heads=config.num_attention_heads, + dropout=config.dropout, + ) + self.self_attn_layer_norm = nn.LayerNorm(config.encoder_hidden_dim, eps=config.layer_norm_eps) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.encoder_activation_function] + self.activation_dropout = config.activation_dropout + self.fc1 = nn.Linear(config.encoder_hidden_dim, config.encoder_ffn_dim) + self.fc2 = nn.Linear(config.encoder_ffn_dim, config.encoder_hidden_dim) + self.final_layer_norm = nn.LayerNorm(config.encoder_hidden_dim, eps=config.layer_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + position_embeddings: torch.Tensor = None, + output_attentions: bool = False, + **kwargs, + ): + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative + values. + position_embeddings (`torch.FloatTensor`, *optional*): + Object queries (also called content embeddings), to be added to the hidden states. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. 
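Earlier in this hunk, `RtDetrV2ConvEncoder.forward` resizes the boolean `pixel_mask` to every backbone feature map so that padded positions stay masked at each scale. A standalone sketch of that step with made-up shapes:

```python
import torch
from torch import nn

pixel_mask = torch.ones(2, 64, 64, dtype=torch.bool)
pixel_mask[0, :, 48:] = False  # pretend the first image was padded on the right
feature_map = torch.randn(2, 256, 8, 8)

# Same interpolation as in RtDetrV2ConvEncoder.forward (nearest-neighbour by default).
mask = nn.functional.interpolate(pixel_mask[None].float(), size=feature_map.shape[-2:]).to(torch.bool)[0]
print(mask.shape)                    # torch.Size([2, 8, 8])
print(mask[0].sum(), mask[1].sum())  # tensor(48) tensor(64): 6 of 8 columns stay valid for image 0
```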
+ """ + residual = hidden_states + if self.normalize_before: + hidden_states = self.self_attn_layer_norm(hidden_states) + + hidden_states, attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_embeddings=position_embeddings, + output_attentions=output_attentions, + ) + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + if not self.normalize_before: + hidden_states = self.self_attn_layer_norm(hidden_states) + + if self.normalize_before: + hidden_states = self.final_layer_norm(hidden_states) + residual = hidden_states + + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + + hidden_states = self.fc2(hidden_states) + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + hidden_states = residual + hidden_states + if not self.normalize_before: + hidden_states = self.final_layer_norm(hidden_states) + + if self.training: + if torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any(): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +class RtDetrV2RepVggBlock(nn.Module): + """ + RepVGG architecture block introduced by the work "RepVGG: Making VGG-style ConvNets Great Again". + """ + + def __init__(self, config: RtDetrV2Config): + super().__init__() + + activation = config.activation_function + hidden_channels = int(config.encoder_hidden_dim * config.hidden_expansion) + self.conv1 = RtDetrV2ConvNormLayer(config, hidden_channels, hidden_channels, 3, 1, padding=1) + self.conv2 = RtDetrV2ConvNormLayer(config, hidden_channels, hidden_channels, 1, 1, padding=0) + self.activation = nn.Identity() if activation is None else ACT2CLS[activation]() + + def forward(self, x): + y = self.conv1(x) + self.conv2(x) + return self.activation(y) + + +class RtDetrV2CSPRepLayer(nn.Module): + """ + Cross Stage Partial (CSP) network layer with RepVGG blocks. 
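`RtDetrV2RepVggBlock` above sums a 3x3 and a 1x1 branch. Because convolution is linear, such parallel branches can in principle be reparameterized into a single 3x3 convolution at inference time (the RepVGG idea); the sketch below shows this for plain convolutions only and leaves out the batch-norm folding that the `RtDetrV2ConvNormLayer` branches would additionally need.

```python
import torch
import torch.nn.functional as F
from torch import nn

channels = 16
conv3 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
conv1 = nn.Conv2d(channels, channels, kernel_size=1, padding=0, bias=False)

x = torch.randn(1, channels, 8, 8)
two_branch = conv3(x) + conv1(x)

# Zero-pad the 1x1 kernel to 3x3 and add it to the 3x3 kernel.
fused_weight = conv3.weight + F.pad(conv1.weight, [1, 1, 1, 1])
single_branch = F.conv2d(x, fused_weight, padding=1)
torch.testing.assert_close(two_branch, single_branch, rtol=1e-4, atol=1e-5)
```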
+ """ + + def __init__(self, config: RtDetrV2Config): + super().__init__() + + in_channels = config.encoder_hidden_dim * 2 + out_channels = config.encoder_hidden_dim + num_blocks = 3 + activation = config.activation_function + + hidden_channels = int(out_channels * config.hidden_expansion) + self.conv1 = RtDetrV2ConvNormLayer(config, in_channels, hidden_channels, 1, 1, activation=activation) + self.conv2 = RtDetrV2ConvNormLayer(config, in_channels, hidden_channels, 1, 1, activation=activation) + self.bottlenecks = nn.Sequential(*[RtDetrV2RepVggBlock(config) for _ in range(num_blocks)]) + if hidden_channels != out_channels: + self.conv3 = RtDetrV2ConvNormLayer(config, hidden_channels, out_channels, 1, 1, activation=activation) + else: + self.conv3 = nn.Identity() + + def forward(self, hidden_state): + device = hidden_state.device + hidden_state_1 = self.conv1(hidden_state) + hidden_state_1 = self.bottlenecks(hidden_state_1).to(device) + hidden_state_2 = self.conv2(hidden_state).to(device) + return self.conv3(hidden_state_1 + hidden_state_2) + + +class RtDetrV2Encoder(nn.Module): + def __init__(self, config: RtDetrV2Config): + super().__init__() + + self.layers = nn.ModuleList([RtDetrV2EncoderLayer(config) for _ in range(config.encoder_layers)]) + + def forward(self, src, src_mask=None, pos_embed=None, output_attentions: bool = False) -> torch.Tensor: + hidden_states = src + for layer in self.layers: + hidden_states = layer( + hidden_states, + attention_mask=src_mask, + position_embeddings=pos_embed, + output_attentions=output_attentions, + ) + return hidden_states + + +class RtDetrV2HybridEncoder(nn.Module): + """ + Decoder consisting of a projection layer, a set of `RtDetrV2Encoder`, a top-down Feature Pyramid Network + (FPN) and a bottom-up Path Aggregation Network (PAN). 
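The hybrid encoder below flattens the selected feature maps and adds a fixed 2D sine-cosine position embedding built by `build_2d_sincos_position_embedding` (defined further down in the class). A standalone sketch of the same construction on a tiny grid:

```python
import torch

def sincos_2d(width, height, embed_dim=256, temperature=10000.0):
    grid_w = torch.arange(int(width), dtype=torch.float32)
    grid_h = torch.arange(int(height), dtype=torch.float32)
    grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing="ij")
    pos_dim = embed_dim // 4  # a quarter of the channels each for sin/cos of x and y
    omega = 1.0 / (temperature ** (torch.arange(pos_dim, dtype=torch.float32) / pos_dim))
    out_w = grid_w.flatten()[..., None] @ omega[None]
    out_h = grid_h.flatten()[..., None] @ omega[None]
    return torch.concat([out_w.sin(), out_w.cos(), out_h.sin(), out_h.cos()], dim=1)[None, :, :]

pos_embed = sincos_2d(width=4, height=3, embed_dim=256)
print(pos_embed.shape)  # torch.Size([1, 12, 256]), one embedding per flattened position
```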
More details on the paper: https://arxiv.org/abs/2304.08069 + + Args: + config: RtDetrV2Config + """ + + def __init__(self, config: RtDetrV2Config): + super().__init__() + self.config = config + self.in_channels = config.encoder_in_channels + self.feat_strides = config.feat_strides + self.encoder_hidden_dim = config.encoder_hidden_dim + self.encode_proj_layers = config.encode_proj_layers + self.positional_encoding_temperature = config.positional_encoding_temperature + self.eval_size = config.eval_size + self.out_channels = [self.encoder_hidden_dim for _ in self.in_channels] + self.out_strides = self.feat_strides + activation_function = config.activation_function + + # encoder transformer + self.encoder = nn.ModuleList([RtDetrV2Encoder(config) for _ in range(len(self.encode_proj_layers))]) + # top-down fpn + self.lateral_convs = nn.ModuleList() + self.fpn_blocks = nn.ModuleList() + for _ in range(len(self.in_channels) - 1, 0, -1): + self.lateral_convs.append( + RtDetrV2ConvNormLayer( + config, self.encoder_hidden_dim, self.encoder_hidden_dim, 1, 1, activation=activation_function + ) + ) + self.fpn_blocks.append(RtDetrV2CSPRepLayer(config)) + + # bottom-up pan + self.downsample_convs = nn.ModuleList() + self.pan_blocks = nn.ModuleList() + for _ in range(len(self.in_channels) - 1): + self.downsample_convs.append( + RtDetrV2ConvNormLayer( + config, self.encoder_hidden_dim, self.encoder_hidden_dim, 3, 2, activation=activation_function + ) + ) + self.pan_blocks.append(RtDetrV2CSPRepLayer(config)) + + @staticmethod + def build_2d_sincos_position_embedding( + width, height, embed_dim=256, temperature=10000.0, device="cpu", dtype=torch.float32 + ): + grid_w = torch.arange(int(width), dtype=dtype, device=device) + grid_h = torch.arange(int(height), dtype=dtype, device=device) + grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing="ij") + if embed_dim % 4 != 0: + raise ValueError("Embed dimension must be divisible by 4 for 2D sin-cos position embedding") + pos_dim = embed_dim // 4 + omega = torch.arange(pos_dim, dtype=dtype, device=device) / pos_dim + omega = 1.0 / (temperature**omega) + + out_w = grid_w.flatten()[..., None] @ omega[None] + out_h = grid_h.flatten()[..., None] @ omega[None] + + return torch.concat([out_w.sin(), out_w.cos(), out_h.sin(), out_h.cos()], dim=1)[None, :, :] + + def forward( + self, + inputs_embeds=None, + attention_mask=None, + position_embeddings=None, + spatial_shapes=None, + level_start_index=None, + valid_ratios=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + Args: + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Flattened feature map (output of the backbone + projection layer) that is passed to the encoder. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding pixel features. Mask values selected in `[0, 1]`: + - 1 for pixel features that are real (i.e. **not masked**), + - 0 for pixel features that are padding (i.e. **masked**). + [What are attention masks?](../glossary#attention-mask) + position_embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Position embeddings that are added to the queries and keys in each self-attention layer. + spatial_shapes (`torch.LongTensor` of shape `(num_feature_levels, 2)`): + Spatial shapes of each feature map. + level_start_index (`torch.LongTensor` of shape `(num_feature_levels)`): + Starting index of each feature map. 
+ valid_ratios (`torch.FloatTensor` of shape `(batch_size, num_feature_levels, 2)`): + Ratio of valid area in each feature level. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + hidden_states = inputs_embeds + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + # encoder + if self.config.encoder_layers > 0: + for i, enc_ind in enumerate(self.encode_proj_layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states[enc_ind],) + height, width = hidden_states[enc_ind].shape[2:] + # flatten [batch, channel, height, width] to [batch, height*width, channel] + src_flatten = hidden_states[enc_ind].flatten(2).permute(0, 2, 1) + if self.training or self.eval_size is None: + pos_embed = self.build_2d_sincos_position_embedding( + width, + height, + self.encoder_hidden_dim, + self.positional_encoding_temperature, + device=src_flatten.device, + dtype=src_flatten.dtype, + ) + else: + pos_embed = None + + layer_outputs = self.encoder[i]( + src_flatten, + pos_embed=pos_embed, + output_attentions=output_attentions, + ) + hidden_states[enc_ind] = ( + layer_outputs[0].permute(0, 2, 1).reshape(-1, self.encoder_hidden_dim, height, width).contiguous() + ) + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states[enc_ind],) + + # broadcasting and fusion + fpn_feature_maps = [hidden_states[-1]] + for idx in range(len(self.in_channels) - 1, 0, -1): + feat_high = fpn_feature_maps[0] + feat_low = hidden_states[idx - 1] + feat_high = self.lateral_convs[len(self.in_channels) - 1 - idx](feat_high) + fpn_feature_maps[0] = feat_high + upsample_feat = F.interpolate(feat_high, scale_factor=2.0, mode="nearest") + fps_map = self.fpn_blocks[len(self.in_channels) - 1 - idx](torch.concat([upsample_feat, feat_low], dim=1)) + fpn_feature_maps.insert(0, fps_map) + + fpn_states = [fpn_feature_maps[0]] + for idx in range(len(self.in_channels) - 1): + feat_low = fpn_states[-1] + feat_high = fpn_feature_maps[idx + 1] + downsample_feat = self.downsample_convs[idx](feat_low) + hidden_states = self.pan_blocks[idx]( + torch.concat([downsample_feat, feat_high.to(downsample_feat.device)], dim=1) + ) + fpn_states.append(hidden_states) + + if not return_dict: + return tuple(v for v in [fpn_states, encoder_states, all_attentions] if v is not None) + return BaseModelOutput(last_hidden_state=fpn_states, hidden_states=encoder_states, attentions=all_attentions) + + +def inverse_sigmoid(x, eps=1e-5): + x = x.clamp(min=0, max=1) + x1 = x.clamp(min=eps) + x2 = (1 - x).clamp(min=eps) + return torch.log(x1 / x2) + + +# taken from https://github.com/facebookresearch/detr/blob/master/models/detr.py +class 
RtDetrV2MLPPredictionHead(nn.Module): + """ + Very simple multi-layer perceptron (MLP, also called FFN), used to predict the normalized center coordinates, + height and width of a bounding box w.r.t. an image. + + Copied from https://github.com/facebookresearch/detr/blob/master/models/detr.py + Origin from https://github.com/lyuwenyu/RT-DETR/blob/94f5e16708329d2f2716426868ec89aa774af016/RtDetrV2_paddle/ppdet/modeling/transformers/utils.py#L453 + + """ + + def __init__(self, config, input_dim, d_model, output_dim, num_layers): + super().__init__() + self.num_layers = num_layers + h = [d_model] * (num_layers - 1) + self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = nn.functional.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + return x + + +def get_contrastive_denoising_training_group( + targets, + num_classes, + num_queries, + class_embed, + num_denoising_queries=100, + label_noise_ratio=0.5, + box_noise_scale=1.0, +): + """ + Creates a contrastive denoising training group using ground-truth samples. It adds noise to labels and boxes. + + Args: + targets (`List[dict]`): + The target objects, each containing 'class_labels' and 'boxes' for objects in an image. + num_classes (`int`): + Total number of classes in the dataset. + num_queries (`int`): + Number of query slots in the transformer. + class_embed (`callable`): + A function or a model layer to embed class labels. + num_denoising_queries (`int`, *optional*, defaults to 100): + Number of denoising queries. + label_noise_ratio (`float`, *optional*, defaults to 0.5): + Ratio of noise applied to labels. + box_noise_scale (`float`, *optional*, defaults to 1.0): + Scale of noise applied to bounding boxes. + Returns: + `tuple` comprising various elements: + - **input_query_class** (`torch.FloatTensor`) -- + Class queries with applied label noise. + - **input_query_bbox** (`torch.FloatTensor`) -- + Bounding box queries with applied box noise. + - **attn_mask** (`torch.FloatTensor`) -- + Attention mask for separating denoising and reconstruction queries. + - **denoising_meta_values** (`dict`) -- + Metadata including denoising positive indices, number of groups, and split sizes. + """ + + if num_denoising_queries <= 0: + return None, None, None, None + + num_ground_truths = [len(t["class_labels"]) for t in targets] + device = targets[0]["class_labels"].device + + max_gt_num = max(num_ground_truths) + if max_gt_num == 0: + return None, None, None, None + + num_groups_denoising_queries = num_denoising_queries // max_gt_num + num_groups_denoising_queries = 1 if num_groups_denoising_queries == 0 else num_groups_denoising_queries + # pad gt to max_num of a batch + batch_size = len(num_ground_truths) + + input_query_class = torch.full([batch_size, max_gt_num], num_classes, dtype=torch.int32, device=device) + input_query_bbox = torch.zeros([batch_size, max_gt_num, 4], device=device) + pad_gt_mask = torch.zeros([batch_size, max_gt_num], dtype=torch.bool, device=device) + + for i in range(batch_size): + num_gt = num_ground_truths[i] + if num_gt > 0: + input_query_class[i, :num_gt] = targets[i]["class_labels"] + input_query_bbox[i, :num_gt] = targets[i]["boxes"] + pad_gt_mask[i, :num_gt] = 1 + # each group has positive and negative queries. 
+ input_query_class = input_query_class.tile([1, 2 * num_groups_denoising_queries]) + input_query_bbox = input_query_bbox.tile([1, 2 * num_groups_denoising_queries, 1]) + pad_gt_mask = pad_gt_mask.tile([1, 2 * num_groups_denoising_queries]) + # positive and negative mask + negative_gt_mask = torch.zeros([batch_size, max_gt_num * 2, 1], device=device) + negative_gt_mask[:, max_gt_num:] = 1 + negative_gt_mask = negative_gt_mask.tile([1, num_groups_denoising_queries, 1]) + positive_gt_mask = 1 - negative_gt_mask + # contrastive denoising training positive index + positive_gt_mask = positive_gt_mask.squeeze(-1) * pad_gt_mask + denoise_positive_idx = torch.nonzero(positive_gt_mask)[:, 1] + denoise_positive_idx = torch.split( + denoise_positive_idx, [n * num_groups_denoising_queries for n in num_ground_truths] + ) + # total denoising queries + num_denoising_queries = int(max_gt_num * 2 * num_groups_denoising_queries) + + if label_noise_ratio > 0: + mask = torch.rand_like(input_query_class, dtype=torch.float) < (label_noise_ratio * 0.5) + # randomly put a new one here + new_label = torch.randint_like(mask, 0, num_classes, dtype=input_query_class.dtype) + input_query_class = torch.where(mask & pad_gt_mask, new_label, input_query_class) + + if box_noise_scale > 0: + known_bbox = center_to_corners_format(input_query_bbox) + diff = torch.tile(input_query_bbox[..., 2:] * 0.5, [1, 1, 2]) * box_noise_scale + rand_sign = torch.randint_like(input_query_bbox, 0, 2) * 2.0 - 1.0 + rand_part = torch.rand_like(input_query_bbox) + rand_part = (rand_part + 1.0) * negative_gt_mask + rand_part * (1 - negative_gt_mask) + rand_part *= rand_sign + known_bbox += rand_part * diff + known_bbox.clip_(min=0.0, max=1.0) + input_query_bbox = corners_to_center_format(known_bbox) + input_query_bbox = inverse_sigmoid(input_query_bbox) + + input_query_class = class_embed(input_query_class) + + target_size = num_denoising_queries + num_queries + attn_mask = torch.full([target_size, target_size], False, dtype=torch.bool, device=device) + # match query cannot see the reconstruction + attn_mask[num_denoising_queries:, :num_denoising_queries] = True + + # reconstructions cannot see each other + for i in range(num_groups_denoising_queries): + idx_block_start = max_gt_num * 2 * i + idx_block_end = max_gt_num * 2 * (i + 1) + attn_mask[idx_block_start:idx_block_end, :idx_block_start] = True + attn_mask[idx_block_start:idx_block_end, idx_block_end:num_denoising_queries] = True + + denoising_meta_values = { + "dn_positive_idx": denoise_positive_idx, + "dn_num_group": num_groups_denoising_queries, + "dn_num_split": [num_denoising_queries, num_queries], + } + + return input_query_class, input_query_bbox, attn_mask, denoising_meta_values + + +RtDetrV2_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`RtDetrV2Config`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. 
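The attention mask assembled at the end of `get_contrastive_denoising_training_group` above keeps each denoising group isolated from the others and hidden from the matching queries. A toy reconstruction of that mask with made-up sizes:

```python
import torch

max_gt_num, num_groups, num_queries = 2, 2, 3
num_denoising = max_gt_num * 2 * num_groups  # positive and negative queries per group
target_size = num_denoising + num_queries

attn_mask = torch.full((target_size, target_size), False, dtype=torch.bool)
attn_mask[num_denoising:, :num_denoising] = True  # matching queries cannot see denoising queries
for i in range(num_groups):
    start, end = max_gt_num * 2 * i, max_gt_num * 2 * (i + 1)
    attn_mask[start:end, :start] = True             # block earlier denoising groups
    attn_mask[start:end, end:num_denoising] = True  # block later denoising groups
print(attn_mask.int())
```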
Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +RtDetrV2_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`RtDetrV2ImageProcessor.__call__`] for details. + pixel_mask (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*): + Mask to avoid performing attention on padding pixel values. Mask values selected in `[0, 1]`: + + - 1 for pixels that are real (i.e. **not masked**), + - 0 for pixels that are padding (i.e. **masked**). + + [What are attention masks?](../glossary#attention-mask) + encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*): + Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`) + `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of + hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing the flattened feature map (output of the backbone + projection layer), you + can choose to directly pass a flattened representation of an image. + decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*): + Optionally, instead of initializing the queries with a tensor of zeros, you can choose to directly pass an + embedded representation. + labels (`List[Dict]` of len `(batch_size,)`, *optional*): + Labels for computing the bipartite matching loss. List of dicts, each dictionary containing at least the + following 2 keys: 'class_labels' and 'boxes' (the class labels and bounding boxes of an image in the batch + respectively). The class labels themselves should be a `torch.LongTensor` of len `(number of bounding boxes + in the image,)` and the boxes a `torch.FloatTensor` of shape `(number of bounding boxes in the image, 4)`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
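A hypothetical example of the `labels` structure described above: one dict per image, integer class ids, and boxes of shape `(num_boxes, 4)`. The boxes are assumed here to use the same normalized `(center_x, center_y, width, height)` convention as `pred_boxes`; in practice they are typically prepared by the image processor when annotations are passed.

```python
import torch

# Hypothetical training targets for a batch with a single image and two objects.
labels = [
    {
        "class_labels": torch.tensor([1, 7]),             # integer class ids
        "boxes": torch.tensor([[0.50, 0.55, 0.20, 0.30],  # assumed normalized (cx, cy, w, h)
                               [0.25, 0.40, 0.10, 0.10]]),
    }
]
# outputs = model(pixel_values=pixel_values, labels=labels)  # loss is populated when labels are given
```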
+""" + + +def compile_compatible_lru_cache(*lru_args, **lru_kwargs): + def decorator(func): + @wraps(func) + def wrapper(self, *args, **kwargs): + if not torch.compiler.is_compiling(): + # Cache the function only if the model is not being compiled + # check if the function is already cached, otherwise create it + if not hasattr(self, f"_cached_{func.__name__}"): + self.__setattr__( + f"_cached_{func.__name__}", lru_cache(*lru_args, **lru_kwargs)(func.__get__(self)) + ) + return self.__getattribute__(f"_cached_{func.__name__}")(*args, **kwargs) + else: + # Otherwise, just call the original function + return func(self, *args, **kwargs) + + return wrapper + + return decorator + + +def _get_clones(partial_module, N): + return nn.ModuleList([partial_module() for i in range(N)]) + + +class RtDetrV2PreTrainedModel(PreTrainedModel): + config_class = RtDetrV2Config + base_model_prefix = "rt_detr_v2" + main_input_name = "pixel_values" + _no_split_modules = [r"RtDetrV2ConvEncoder", r"RtDetrV2EncoderLayer", r"RtDetrV2DecoderLayer"] + + def _init_weights(self, module): + # this could be simplified like calling the base class function first then adding new stuff + """Initalize the weights + initialize linear layer biasreturn value according to a given probability value.""" + if isinstance(module, (RtDetrV2ForObjectDetection, RtDetrV2Decoder)): + if module.class_embed is not None: + for layer in module.class_embed: + prior_prob = self.config.initializer_bias_prior_prob or 1 / (self.config.num_labels + 1) + bias = float(-math.log((1 - prior_prob) / prior_prob)) + nn.init.xavier_uniform_(layer.weight) + nn.init.constant_(layer.bias, bias) + + if module.bbox_embed is not None: + for layer in module.bbox_embed: + nn.init.constant_(layer.layers[-1].weight, 0) + nn.init.constant_(layer.layers[-1].bias, 0) + + if isinstance(module, RtDetrV2MultiscaleDeformableAttention): + nn.init.constant_(module.sampling_offsets.weight.data, 0.0) + default_dtype = torch.get_default_dtype() + thetas = torch.arange(module.n_heads, dtype=torch.int64).to(default_dtype) * ( + 2.0 * math.pi / module.n_heads + ) + grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) + grid_init = ( + (grid_init / grid_init.abs().max(-1, keepdim=True)[0]) + .view(module.n_heads, 1, 1, 2) + .repeat(1, module.n_levels, module.n_points, 1) + ) + for i in range(module.n_points): + grid_init[:, :, i, :] *= i + 1 + with torch.no_grad(): + module.sampling_offsets.bias = nn.Parameter(grid_init.view(-1)) + nn.init.constant_(module.attention_weights.weight.data, 0.0) + nn.init.constant_(module.attention_weights.bias.data, 0.0) + nn.init.xavier_uniform_(module.value_proj.weight.data) + nn.init.constant_(module.value_proj.bias.data, 0.0) + nn.init.xavier_uniform_(module.output_proj.weight.data) + nn.init.constant_(module.output_proj.bias.data, 0.0) + + if isinstance(module, RtDetrV2Model): + prior_prob = self.config.initializer_bias_prior_prob or 1 / (self.config.num_labels + 1) + bias = float(-math.log((1 - prior_prob) / prior_prob)) + nn.init.xavier_uniform_(module.enc_score_head.weight) + nn.init.constant_(module.enc_score_head.bias, bias) + + if isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + + if hasattr(module, "weight_embedding") and self.config.learn_initial_query: + nn.init.xavier_uniform_(module.weight_embedding.weight) + if hasattr(module, "denoising_class_embed") and self.config.num_denoising > 0: + 
nn.init.xavier_uniform_(module.denoising_class_embed.weight) + + +class RtDetrV2Decoder(RtDetrV2PreTrainedModel): + def __init__(self, config: RtDetrV2Config): + super().__init__(config) + + self.dropout = config.dropout + self.layers = nn.ModuleList([RtDetrV2DecoderLayer(config) for _ in range(config.decoder_layers)]) + self.query_pos_head = RtDetrV2MLPPredictionHead(config, 4, 2 * config.d_model, config.d_model, num_layers=2) + + # hack implementation for iterative bounding box refinement and two-stage Deformable DETR + self.bbox_embed = None + self.class_embed = None + + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + inputs_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + position_embeddings=None, + reference_points=None, + spatial_shapes=None, + spatial_shapes_list=None, + level_start_index=None, + valid_ratios=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + Args: + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`): + The query embeddings that are passed into the decoder. + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention + of the decoder. + encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing cross-attention on padding pixel_values of the encoder. Mask values selected + in `[0, 1]`: + - 1 for pixels that are real (i.e. **not masked**), + - 0 for pixels that are padding (i.e. **masked**). + position_embeddings (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*): + Position embeddings that are added to the queries and keys in each self-attention layer. + reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)` is `as_two_stage` else `(batch_size, num_queries, 2)` or , *optional*): + Reference point in range `[0, 1]`, top-left (0,0), bottom-right (1, 1), including padding area. + spatial_shapes (`torch.FloatTensor` of shape `(num_feature_levels, 2)`): + Spatial shapes of the feature maps. + level_start_index (`torch.LongTensor` of shape `(num_feature_levels)`, *optional*): + Indexes for the start of each feature level. In range `[0, sequence_length]`. + valid_ratios (`torch.FloatTensor` of shape `(batch_size, num_feature_levels, 2)`, *optional*): + Ratio of valid area in each feature level. + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple. 
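The `query_pos_head` created in `RtDetrV2Decoder.__init__` is an `RtDetrV2MLPPredictionHead(config, 4, 2 * d_model, d_model, num_layers=2)`, i.e. a two-layer MLP that maps 4-dim reference boxes to `d_model`-dim position embeddings, recomputed from the refined reference points at every decoder layer. A standalone functional equivalent with made-up sizes:

```python
import torch
from torch import nn

d_model = 256
query_pos_head = nn.Sequential(  # Linear(4, 2*d_model) -> ReLU -> Linear(2*d_model, d_model)
    nn.Linear(4, 2 * d_model),
    nn.ReLU(),
    nn.Linear(2 * d_model, d_model),
)
reference_points = torch.rand(2, 300, 4)  # made-up batch of normalized reference boxes
position_embeddings = query_pos_head(reference_points)
print(position_embeddings.shape)          # torch.Size([2, 300, 256])
```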
+ """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if inputs_embeds is not None: + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None + intermediate = () + intermediate_reference_points = () + intermediate_logits = () + + reference_points = F.sigmoid(reference_points) + + # https://github.com/lyuwenyu/RT-DETR/blob/94f5e16708329d2f2716426868ec89aa774af016/RtDetrV2_pytorch/src/zoo/RtDetrV2/RtDetrV2_decoder.py#L252 + for idx, decoder_layer in enumerate(self.layers): + reference_points_input = reference_points.unsqueeze(2) + position_embeddings = self.query_pos_head(reference_points) + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = decoder_layer( + hidden_states, + position_embeddings=position_embeddings, + encoder_hidden_states=encoder_hidden_states, + reference_points=reference_points_input, + spatial_shapes=spatial_shapes, + spatial_shapes_list=spatial_shapes_list, + level_start_index=level_start_index, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + # hack implementation for iterative bounding box refinement + if self.bbox_embed is not None: + tmp = self.bbox_embed[idx](hidden_states) + new_reference_points = F.sigmoid(tmp + inverse_sigmoid(reference_points)) + reference_points = new_reference_points.detach() + + intermediate += (hidden_states,) + intermediate_reference_points += ( + (new_reference_points,) if self.bbox_embed is not None else (reference_points,) + ) + + if self.class_embed is not None: + logits = self.class_embed[idx](hidden_states) + intermediate_logits += (logits,) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if encoder_hidden_states is not None: + all_cross_attentions += (layer_outputs[2],) + + # Keep batch_size as first dimension + intermediate = torch.stack(intermediate, dim=1) + intermediate_reference_points = torch.stack(intermediate_reference_points, dim=1) + if self.class_embed is not None: + intermediate_logits = torch.stack(intermediate_logits, dim=1) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + intermediate, + intermediate_logits, + intermediate_reference_points, + all_hidden_states, + all_self_attns, + all_cross_attentions, + ] + if v is not None + ) + return RtDetrV2DecoderOutput( + last_hidden_state=hidden_states, + intermediate_hidden_states=intermediate, + intermediate_logits=intermediate_logits, + intermediate_reference_points=intermediate_reference_points, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=all_cross_attentions, + ) + + +@add_start_docstrings( + """ + RT-DETR Model (consisting of a backbone and encoder-decoder) outputting raw hidden states without any head on top. 
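The refinement step inside the decoder loop above operates in unconstrained space: the bbox head predicts a delta that is added to `inverse_sigmoid(reference_points)` and squashed back into `[0, 1]`. A minimal sketch, with a made-up delta standing in for the `bbox_embed[idx](hidden_states)` output:

```python
import torch

def inverse_sigmoid(x, eps=1e-5):
    x = x.clamp(min=0, max=1)
    return torch.log(x.clamp(min=eps) / (1 - x).clamp(min=eps))

reference_points = torch.tensor([[0.30, 0.40, 0.20, 0.25]])  # normalized (cx, cy, w, h)
delta = torch.tensor([[0.10, -0.05, 0.00, 0.02]])            # stand-in for the bbox head output
new_reference_points = torch.sigmoid(delta + inverse_sigmoid(reference_points))
print(new_reference_points)  # still inside [0, 1], slightly shifted from the input box
```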
+ """, + RtDetrV2_START_DOCSTRING, +) +class RtDetrV2Model(RtDetrV2PreTrainedModel): + def __init__(self, config: RtDetrV2Config): + super().__init__(config) + + # Create backbone + self.backbone = RtDetrV2ConvEncoder(config) + intermediate_channel_sizes = self.backbone.intermediate_channel_sizes + + # Create encoder input projection layers + # https://github.com/lyuwenyu/RT-DETR/blob/94f5e16708329d2f2716426868ec89aa774af016/RtDetrV2_pytorch/src/zoo/RtDetrV2/hybrid_encoder.py#L212 + num_backbone_outs = len(intermediate_channel_sizes) + encoder_input_proj_list = [] + for _ in range(num_backbone_outs): + in_channels = intermediate_channel_sizes[_] + encoder_input_proj_list.append( + nn.Sequential( + nn.Conv2d(in_channels, config.encoder_hidden_dim, kernel_size=1, bias=False), + nn.BatchNorm2d(config.encoder_hidden_dim), + ) + ) + self.encoder_input_proj = nn.ModuleList(encoder_input_proj_list) + + # Create encoder + self.encoder = RtDetrV2HybridEncoder(config) + + # denoising part + if config.num_denoising > 0: + self.denoising_class_embed = nn.Embedding( + config.num_labels + 1, config.d_model, padding_idx=config.num_labels + ) + + # decoder embedding + if config.learn_initial_query: + self.weight_embedding = nn.Embedding(config.num_queries, config.d_model) + + # encoder head + self.enc_output = nn.Sequential( + nn.Linear(config.d_model, config.d_model), + nn.LayerNorm(config.d_model, eps=config.layer_norm_eps), + ) + self.enc_score_head = nn.Linear(config.d_model, config.num_labels) + self.enc_bbox_head = RtDetrV2MLPPredictionHead(config, config.d_model, config.d_model, 4, num_layers=3) + + # init encoder output anchors and valid_mask + if config.anchor_image_size: + self.anchors, self.valid_mask = self.generate_anchors(dtype=self.dtype) + + # Create decoder input projection layers + # https://github.com/lyuwenyu/RT-DETR/blob/94f5e16708329d2f2716426868ec89aa774af016/RtDetrV2_pytorch/src/zoo/RtDetrV2/RtDetrV2_decoder.py#L412 + num_backbone_outs = len(config.decoder_in_channels) + decoder_input_proj_list = [] + for _ in range(num_backbone_outs): + in_channels = config.decoder_in_channels[_] + decoder_input_proj_list.append( + nn.Sequential( + nn.Conv2d(in_channels, config.d_model, kernel_size=1, bias=False), + nn.BatchNorm2d(config.d_model, config.batch_norm_eps), + ) + ) + for _ in range(config.num_feature_levels - num_backbone_outs): + decoder_input_proj_list.append( + nn.Sequential( + nn.Conv2d(in_channels, config.d_model, kernel_size=3, stride=2, padding=1, bias=False), + nn.BatchNorm2d(config.d_model, config.batch_norm_eps), + ) + ) + in_channels = config.d_model + self.decoder_input_proj = nn.ModuleList(decoder_input_proj_list) + # decoder + self.decoder = RtDetrV2Decoder(config) + + self.post_init() + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + def freeze_backbone(self): + for param in self.backbone.parameters(): + param.requires_grad_(False) + + def unfreeze_backbone(self): + for param in self.backbone.parameters(): + param.requires_grad_(True) + + @compile_compatible_lru_cache(maxsize=32) + def generate_anchors(self, spatial_shapes=None, grid_size=0.05, device="cpu", dtype=torch.float32): + if spatial_shapes is None: + spatial_shapes = [ + [int(self.config.anchor_image_size[0] / s), int(self.config.anchor_image_size[1] / s)] + for s in self.config.feat_strides + ] + anchors = [] + for level, (height, width) in enumerate(spatial_shapes): + grid_y, grid_x = torch.meshgrid( + torch.arange(end=height, dtype=dtype, device=device), + 
torch.arange(end=width, dtype=dtype, device=device), + indexing="ij", + ) + grid_xy = torch.stack([grid_x, grid_y], -1) + valid_wh = torch.tensor([width, height], device=device).to(dtype) + grid_xy = (grid_xy.unsqueeze(0) + 0.5) / valid_wh + wh = torch.ones_like(grid_xy) * grid_size * (2.0**level) + anchors.append(torch.concat([grid_xy, wh], -1).reshape(-1, height * width, 4)) + # define the valid range for anchor coordinates + eps = 1e-2 + anchors = torch.concat(anchors, 1) + valid_mask = ((anchors > eps) * (anchors < 1 - eps)).all(-1, keepdim=True) + anchors = torch.log(anchors / (1 - anchors)) + anchors = torch.where(valid_mask, anchors, torch.tensor(torch.finfo(dtype).max, dtype=dtype, device=device)) + + return anchors, valid_mask + + @add_start_docstrings_to_model_forward(RtDetrV2_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=RtDetrV2ModelOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + pixel_values: torch.FloatTensor, + pixel_mask: Optional[torch.LongTensor] = None, + encoder_outputs: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[List[dict]] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.FloatTensor], RtDetrV2ModelOutput]: + r""" + Returns: + + Examples: + + ```python + >>> from transformers import AutoImageProcessor, RtDetrV2Model + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> image_processor = AutoImageProcessor.from_pretrained("PekingU/RtDetrV2_r50vd") + >>> model = RtDetrV2Model.from_pretrained("PekingU/RtDetrV2_r50vd") + + >>> inputs = image_processor(images=image, return_tensors="pt") + + >>> outputs = model(**inputs) + + >>> last_hidden_states = outputs.last_hidden_state + >>> list(last_hidden_states.shape) + [1, 300, 256] + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + batch_size, num_channels, height, width = pixel_values.shape + device = pixel_values.device + + if pixel_mask is None: + pixel_mask = torch.ones(((batch_size, height, width)), device=device) + + features = self.backbone(pixel_values, pixel_mask) + + proj_feats = [self.encoder_input_proj[level](source) for level, (source, mask) in enumerate(features)] + + if encoder_outputs is None: + encoder_outputs = self.encoder( + proj_feats, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if output_hidden_states else None, + attentions=encoder_outputs[2] + if len(encoder_outputs) > 2 + else encoder_outputs[1] + if output_attentions + else None, + ) + + # Equivalent to def _get_encoder_input + # 
https://github.com/lyuwenyu/RT-DETR/blob/94f5e16708329d2f2716426868ec89aa774af016/RtDetrV2_pytorch/src/zoo/RtDetrV2/RtDetrV2_decoder.py#L412 + sources = [] + for level, source in enumerate(encoder_outputs[0]): + sources.append(self.decoder_input_proj[level](source)) + + # Lowest resolution feature maps are obtained via 3x3 stride 2 convolutions on the final stage + if self.config.num_feature_levels > len(sources): + _len_sources = len(sources) + sources.append(self.decoder_input_proj[_len_sources](encoder_outputs[0])[-1]) + for i in range(_len_sources + 1, self.config.num_feature_levels): + sources.append(self.decoder_input_proj[i](encoder_outputs[0][-1])) + + # Prepare encoder inputs (by flattening) + source_flatten = [] + spatial_shapes_list = [] + for level, source in enumerate(sources): + batch_size, num_channels, height, width = source.shape + spatial_shape = (height, width) + spatial_shapes_list.append(spatial_shape) + source = source.flatten(2).transpose(1, 2) + source_flatten.append(source) + source_flatten = torch.cat(source_flatten, 1) + spatial_shapes = torch.as_tensor(spatial_shapes_list, dtype=torch.long, device=source_flatten.device) + level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])) + + # prepare denoising training + if self.training and self.config.num_denoising > 0 and labels is not None: + ( + denoising_class, + denoising_bbox_unact, + attention_mask, + denoising_meta_values, + ) = get_contrastive_denoising_training_group( + targets=labels, + num_classes=self.config.num_labels, + num_queries=self.config.num_queries, + class_embed=self.denoising_class_embed, + num_denoising_queries=self.config.num_denoising, + label_noise_ratio=self.config.label_noise_ratio, + box_noise_scale=self.config.box_noise_scale, + ) + else: + denoising_class, denoising_bbox_unact, attention_mask, denoising_meta_values = None, None, None, None + + batch_size = len(source_flatten) + device = source_flatten.device + dtype = source_flatten.dtype + + # prepare input for decoder + if self.training or self.config.anchor_image_size is None: + # Pass spatial_shapes as tuple to make it hashable and make sure + # lru_cache is working for generate_anchors() + spatial_shapes_tuple = tuple(spatial_shapes_list) + anchors, valid_mask = self.generate_anchors(spatial_shapes_tuple, device=device, dtype=dtype) + else: + anchors, valid_mask = self.anchors, self.valid_mask + + anchors, valid_mask = anchors.to(device, dtype), valid_mask.to(device, dtype) + + # use the valid_mask to selectively retain values in the feature map where the mask is `True` + memory = valid_mask.to(source_flatten.dtype) * source_flatten + + output_memory = self.enc_output(memory) + + enc_outputs_class = self.enc_score_head(output_memory) + enc_outputs_coord_logits = self.enc_bbox_head(output_memory) + anchors + + _, topk_ind = torch.topk(enc_outputs_class.max(-1).values, self.config.num_queries, dim=1) + + reference_points_unact = enc_outputs_coord_logits.gather( + dim=1, index=topk_ind.unsqueeze(-1).repeat(1, 1, enc_outputs_coord_logits.shape[-1]) + ) + + enc_topk_bboxes = F.sigmoid(reference_points_unact) + if denoising_bbox_unact is not None: + reference_points_unact = torch.concat([denoising_bbox_unact, reference_points_unact], 1) + + enc_topk_logits = enc_outputs_class.gather( + dim=1, index=topk_ind.unsqueeze(-1).repeat(1, 1, enc_outputs_class.shape[-1]) + ) + + # extract region features + if self.config.learn_initial_query: + target = self.weight_embedding.tile([batch_size, 1, 1]) + 
else: + target = output_memory.gather(dim=1, index=topk_ind.unsqueeze(-1).repeat(1, 1, output_memory.shape[-1])) + target = target.detach() + + if denoising_class is not None: + target = torch.concat([denoising_class, target], 1) + + init_reference_points = reference_points_unact.detach() + + # decoder + decoder_outputs = self.decoder( + inputs_embeds=target, + encoder_hidden_states=source_flatten, + encoder_attention_mask=attention_mask, + reference_points=init_reference_points, + spatial_shapes=spatial_shapes, + spatial_shapes_list=spatial_shapes_list, + level_start_index=level_start_index, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if not return_dict: + enc_outputs = tuple( + value + for value in [enc_topk_logits, enc_topk_bboxes, enc_outputs_class, enc_outputs_coord_logits] + if value is not None + ) + dn_outputs = tuple(value if value is not None else None for value in [denoising_meta_values]) + tuple_outputs = decoder_outputs + encoder_outputs + (init_reference_points,) + enc_outputs + dn_outputs + + return tuple_outputs + + return RtDetrV2ModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + intermediate_hidden_states=decoder_outputs.intermediate_hidden_states, + intermediate_logits=decoder_outputs.intermediate_logits, + intermediate_reference_points=decoder_outputs.intermediate_reference_points, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + init_reference_points=init_reference_points, + enc_topk_logits=enc_topk_logits, + enc_topk_bboxes=enc_topk_bboxes, + enc_outputs_class=enc_outputs_class, + enc_outputs_coord_logits=enc_outputs_coord_logits, + denoising_meta_values=denoising_meta_values, + ) + + +@add_start_docstrings( + """ + RT-DETR Model (consisting of a backbone and encoder-decoder) outputting bounding boxes and logits to be further + decoded into scores and classes. 
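The encoder-side query selection in `RtDetrV2Model.forward` above keeps the `num_queries` positions with the highest maximum class logit and gathers their box logits as the initial reference points for the decoder. A toy reconstruction with made-up shapes:

```python
import torch

batch_size, seq_len, num_labels, num_queries = 2, 10, 5, 3
enc_outputs_class = torch.randn(batch_size, seq_len, num_labels)
enc_outputs_coord_logits = torch.randn(batch_size, seq_len, 4)

# Score each position by its best class logit, then keep the top `num_queries` of them.
_, topk_ind = torch.topk(enc_outputs_class.max(-1).values, num_queries, dim=1)
reference_points_unact = enc_outputs_coord_logits.gather(
    dim=1, index=topk_ind.unsqueeze(-1).repeat(1, 1, enc_outputs_coord_logits.shape[-1])
)
print(reference_points_unact.shape)  # torch.Size([2, 3, 4])
```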
+ """, + RtDetrV2_START_DOCSTRING, +) +class RtDetrV2ForObjectDetection(RtDetrV2PreTrainedModel): + # When using clones, all layers > 0 will be clones, but layer 0 *is* required + _tied_weights_keys = ["bbox_embed", "class_embed"] + # We can't initialize the model on meta device as some weights are modified during the initialization + _no_split_modules = None + + def __init__(self, config: RtDetrV2Config): + super().__init__(config) + # RTDETR encoder-decoder model + self.model = RtDetrV2Model(config) + + # Detection heads on top + self.class_embed = partial(nn.Linear, config.d_model, config.num_labels) + self.bbox_embed = partial(RtDetrV2MLPPredictionHead, config, config.d_model, config.d_model, 4, num_layers=3) + + # if two-stage, the last class_embed and bbox_embed is for region proposal generation + num_pred = config.decoder_layers + if config.with_box_refine: + self.class_embed = _get_clones(self.class_embed, num_pred) + self.bbox_embed = _get_clones(self.bbox_embed, num_pred) + else: + self.class_embed = nn.ModuleList([self.class_embed() for _ in range(num_pred)]) + self.bbox_embed = nn.ModuleList([self.bbox_embed() for _ in range(num_pred)]) + # here self.model.decoder.bbox_embed is null, but not self.bbox_embed + # fix bug + self.model.decoder.class_embed = self.class_embed + self.model.decoder.bbox_embed = self.bbox_embed + + # Initialize weights and apply final processing + self.post_init() + + @torch.jit.unused + def _set_aux_loss(self, outputs_class, outputs_coord): + # this is a workaround to make torchscript happy, as torchscript + # doesn't support dictionary with non-homogeneous values, such + # as a dict having both a Tensor and a list. + return [{"logits": a, "pred_boxes": b} for a, b in zip(outputs_class, outputs_coord)] + + @add_start_docstrings_to_model_forward(RtDetrV2_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=RtDetrV2ObjectDetectionOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + pixel_values: torch.FloatTensor, + pixel_mask: Optional[torch.LongTensor] = None, + encoder_outputs: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[List[dict]] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **loss_kwargs, + ) -> Union[Tuple[torch.FloatTensor], RtDetrV2ObjectDetectionOutput]: + r""" + labels (`List[Dict]` of len `(batch_size,)`, *optional*): + Labels for computing the bipartite matching loss. List of dicts, each dictionary containing at least the + following 2 keys: 'class_labels' and 'boxes' (the class labels and bounding boxes of an image in the batch + respectively). The class labels themselves should be a `torch.LongTensor` of len `(number of bounding boxes + in the image,)` and the boxes a `torch.FloatTensor` of shape `(number of bounding boxes in the image, 4)`. 
+ + Returns: + + Examples: + + ```python + >>> from transformers import RtDetrV2ImageProcessor, RtDetrV2ForObjectDetection + >>> from PIL import Image + >>> import requests + >>> import torch + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> image_processor = RtDetrV2ImageProcessor.from_pretrained("PekingU/RtDetrV2_r50vd") + >>> model = RtDetrV2ForObjectDetection.from_pretrained("PekingU/RtDetrV2_r50vd") + + >>> # prepare image for the model + >>> inputs = image_processor(images=image, return_tensors="pt") + + >>> # forward pass + >>> outputs = model(**inputs) + + >>> logits = outputs.logits + >>> list(logits.shape) + [1, 300, 80] + + >>> boxes = outputs.pred_boxes + >>> list(boxes.shape) + [1, 300, 4] + + >>> # convert outputs (bounding boxes and class logits) to Pascal VOC format (xmin, ymin, xmax, ymax) + >>> target_sizes = torch.tensor([image.size[::-1]]) + >>> results = image_processor.post_process_object_detection(outputs, threshold=0.9, target_sizes=target_sizes)[ + ... 0 + ... ] + + >>> for score, label, box in zip(results["scores"], results["labels"], results["boxes"]): + ... box = [round(i, 2) for i in box.tolist()] + ... print( + ... f"Detected {model.config.id2label[label.item()]} with confidence " + ... f"{round(score.item(), 3)} at location {box}" + ... ) + Detected sofa with confidence 0.97 at location [0.14, 0.38, 640.13, 476.21] + Detected cat with confidence 0.96 at location [343.38, 24.28, 640.14, 371.5] + Detected cat with confidence 0.958 at location [13.23, 54.18, 318.98, 472.22] + Detected remote with confidence 0.951 at location [40.11, 73.44, 175.96, 118.48] + Detected remote with confidence 0.924 at location [333.73, 76.58, 369.97, 186.99] + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + pixel_values, + pixel_mask=pixel_mask, + encoder_outputs=encoder_outputs, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + labels=labels, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + denoising_meta_values = ( + outputs.denoising_meta_values if return_dict else outputs[-1] if self.training else None + ) + + outputs_class = outputs.intermediate_logits if return_dict else outputs[2] + outputs_coord = outputs.intermediate_reference_points if return_dict else outputs[3] + + logits = outputs_class[:, -1] + pred_boxes = outputs_coord[:, -1] + + loss, loss_dict, auxiliary_outputs, enc_topk_logits, enc_topk_bboxes = None, None, None, None, None + if labels is not None: + if self.training and denoising_meta_values is not None: + enc_topk_logits = outputs.enc_topk_logits if return_dict else outputs[-5] + enc_topk_bboxes = outputs.enc_topk_bboxes if return_dict else outputs[-4] + loss, loss_dict, auxiliary_outputs = self.loss_function( + logits, + labels, + self.device, + pred_boxes, + self.config, + outputs_class, + outputs_coord, + enc_topk_logits=enc_topk_logits, + enc_topk_bboxes=enc_topk_bboxes, + denoising_meta_values=denoising_meta_values, + **loss_kwargs, + ) + + if not return_dict: + if auxiliary_outputs is not None: + output = (logits, pred_boxes) + (auxiliary_outputs,) + outputs + 
else: + output = (logits, pred_boxes) + outputs + return ((loss, loss_dict) + output) if loss is not None else output + + return RtDetrV2ObjectDetectionOutput( + loss=loss, + loss_dict=loss_dict, + logits=logits, + pred_boxes=pred_boxes, + auxiliary_outputs=auxiliary_outputs, + last_hidden_state=outputs.last_hidden_state, + intermediate_hidden_states=outputs.intermediate_hidden_states, + intermediate_logits=outputs.intermediate_logits, + intermediate_reference_points=outputs.intermediate_reference_points, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + init_reference_points=outputs.init_reference_points, + enc_topk_logits=outputs.enc_topk_logits, + enc_topk_bboxes=outputs.enc_topk_bboxes, + enc_outputs_class=outputs.enc_outputs_class, + enc_outputs_coord_logits=outputs.enc_outputs_coord_logits, + denoising_meta_values=outputs.denoising_meta_values, + ) diff --git a/src/transformers/models/rt_detr_v2/modular_rt_detr_v2.py b/src/transformers/models/rt_detr_v2/modular_rt_detr_v2.py index eaf9874192c493..a7dbabbf0e684c 100644 --- a/src/transformers/models/rt_detr_v2/modular_rt_detr_v2.py +++ b/src/transformers/models/rt_detr_v2/modular_rt_detr_v2.py @@ -5,8 +5,8 @@ import torch.nn.functional as F from torch import Tensor, nn -from ...modeling_utils import PreTrainedModel from ..rt_detr.configuration_rt_detr import RTDetrConfig +from ..rt_detr.configuration_rt_detr_resnet import RTDetrResNetConfig from ..rt_detr.modeling_rt_detr import ( MultiScaleDeformableAttentionFunction, RTDetrDecoder, @@ -14,16 +14,27 @@ RTDetrForObjectDetection, RTDetrModel, RTDetrMultiscaleDeformableAttention, + RTDetrPreTrainedModel, ) +class RtDetrV2ResNetConfig(RTDetrResNetConfig): + def __init__( + self, + **super_kwargs, + ): + super().__init__(**super_kwargs) + + class RtDetrV2Config(RTDetrConfig): model_type = "rt_detr_v2" - def __init__(self, - decoder_n_levels=3, # default value - decoder_offset_scale=0.5, # default value - **super_kwargs): + def __init__( + self, + decoder_n_levels=3, # default value + decoder_offset_scale=0.5, # default value + **super_kwargs, + ): # init the base RTDetrConfig class with additional parameters super().__init__(**super_kwargs) @@ -105,9 +116,11 @@ def multi_scale_deformable_attention_v2( ) return output.transpose(1, 2).contiguous() + class MultiScaleDeformableAttentionFunctionV2(MultiScaleDeformableAttentionFunction): pass + # the main change class RtDetrV2MultiscaleDeformableAttention(RTDetrMultiscaleDeformableAttention): """ @@ -115,13 +128,11 @@ class RtDetrV2MultiscaleDeformableAttention(RTDetrMultiscaleDeformableAttention) with improved offset handling and initialization. 
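+    Compared to v1, the predicted sampling offsets are scaled by `config.decoder_offset_scale`, and the number of
+    sampling points can differ per feature level (see `n_points_list` and the `n_points_scale` buffer below).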
""" - def __init__(self, config: RtDetrV2Config): + def __init__(self, config: RtDetrV2Config, num_heads: int = None, n_points: int = None): + num_heads = num_heads if num_heads is not None else config.decoder_attention_heads + n_points = n_points if n_points is not None else config.decoder_n_points # Initialize parent class with config parameters - super().__init__( - config=config, - num_heads=config.decoder_attention_heads, - n_points=config.decoder_n_points - ) + super().__init__(config=config, num_heads=num_heads, n_points=n_points) # V2-specific attributes self.offset_scale = config.decoder_offset_scale @@ -130,31 +141,6 @@ def __init__(self, config: RtDetrV2Config): self.n_points_list = n_points_list n_points_scale = [1 / n for n in n_points_list for _ in range(n)] self.register_buffer("n_points_scale", torch.tensor(n_points_scale, dtype=torch.float32)) - # Initialize weights with v2-specific pattern - self._reset_parameters() - def _reset_parameters(self): - """V2-specific initialization of parameters""" - nn.init.constant_(self.sampling_offsets.weight.data, 0.0) - default_dtype = torch.get_default_dtype() - thetas = torch.arange(self.n_heads, dtype=torch.int64).to(default_dtype) * (2.0 * math.pi / self.n_heads) - grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) - grid_init = ( - (grid_init / grid_init.abs().max(-1, keepdim=True)[0]) - .view(self.n_heads, 1, 1, 2) - .repeat(1, self.n_levels, self.n_points, 1) - ) - for i in range(self.n_points): - grid_init[:, :, i, :] *= i + 1 - - with torch.no_grad(): - self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1)) - - nn.init.constant_(self.attention_weights.weight.data, 0.0) - nn.init.constant_(self.attention_weights.bias.data, 0.0) - nn.init.xavier_uniform_(self.value_proj.weight.data) - nn.init.constant_(self.value_proj.bias.data, 0.0) - nn.init.xavier_uniform_(self.output_proj.weight.data) - nn.init.constant_(self.output_proj.bias.data, 0.0) def forward( self, @@ -240,63 +226,12 @@ def __init__(self, config: RtDetrV2Config): # override only the encoder attention module with v2 version self.encoder_attn = RtDetrV2MultiscaleDeformableAttention(config) -class RtDetrV2Decoder(RTDetrDecoder): - def __init__(self, config: RtDetrV2Config): - super().__init__(config) - self.layers = nn.ModuleList([RtDetrV2DecoderLayer(config) for _ in range(config.decoder_layers)]) - - -class RtDetrV2Model(RTDetrModel): - def __init__(self, config: RtDetrV2Config): - super().__init__(config) - # decoder - self.decoder = RtDetrV2Decoder(config) - -def _get_clones(partial_module, N): - return nn.ModuleList([partial_module() for i in range(N)]) - -class RtDetrV2ForObjectDetection(RTDetrForObjectDetection): - def __init__(self, config: RtDetrV2Config): - super().__init__(config) - # RTDETR encoder-decoder model - self.model = RtDetrV2Model(config) - # here self.model.decoder.bbox_embed is null, but not self.bbox_embed - # fix bug - self.model.decoder.class_embed = self.class_embed - self.model.decoder.bbox_embed = self.bbox_embed - -class RtDetrV2PreTrainedModel(PreTrainedModel): - def __init__(self, RTDetrV2Config): - super().__init__(config) - self.model = RTDetrV2Model(config) +class RtDetrV2PreTrainedModel(RTDetrPreTrainedModel): def _init_weights(self, module): # this could be simplified like calling the base class function first then adding new stuff - """Initalize the weights""" - nn.init.constant_(self.sampling_offsets.weight.data, 0.0) - default_dtype = torch.get_default_dtype() - thetas = torch.arange(self.n_heads, 
dtype=torch.int64).to(default_dtype) * (2.0 * math.pi / self.n_heads) - grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) - grid_init = ( - (grid_init / grid_init.abs().max(-1, keepdim=True)[0]) - .view(self.n_heads, 1, 1, 2) - .repeat(1, self.n_levels, self.n_points, 1) - ) - for i in range(self.n_points): - grid_init[:, :, i, :] *= i + 1 - - with torch.no_grad(): - self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1)) - - nn.init.constant_(self.attention_weights.weight.data, 0.0) - nn.init.constant_(self.attention_weights.bias.data, 0.0) - nn.init.xavier_uniform_(self.value_proj.weight.data) - nn.init.constant_(self.value_proj.bias.data, 0.0) - nn.init.xavier_uniform_(self.output_proj.weight.data) - nn.init.constant_(self.output_proj.bias.data, 0.0) - """initialize linear layer biasreturn value according to a given probability value.""" - if isinstance(module, (RTDetrForObjectDetection, RTDetrDecoder)): + if isinstance(module, (RtDetrV2ForObjectDetection, RtDetrV2Decoder)): if module.class_embed is not None: for layer in module.class_embed: prior_prob = self.config.initializer_bias_prior_prob or 1 / (self.config.num_labels + 1) @@ -309,7 +244,7 @@ def _init_weights(self, module): nn.init.constant_(layer.layers[-1].weight, 0) nn.init.constant_(layer.layers[-1].bias, 0) - if isinstance(module, RTDetrV2MultiscaleDeformableAttention): + if isinstance(module, RtDetrV2MultiscaleDeformableAttention): nn.init.constant_(module.sampling_offsets.weight.data, 0.0) default_dtype = torch.get_default_dtype() thetas = torch.arange(module.n_heads, dtype=torch.int64).to(default_dtype) * ( @@ -332,7 +267,7 @@ def _init_weights(self, module): nn.init.xavier_uniform_(module.output_proj.weight.data) nn.init.constant_(module.output_proj.bias.data, 0.0) - if isinstance(module, RTDetrV2Model): + if isinstance(module, RtDetrV2Model): prior_prob = self.config.initializer_bias_prior_prob or 1 / (self.config.num_labels + 1) bias = float(-math.log((1 - prior_prob) / prior_prob)) nn.init.xavier_uniform_(module.enc_score_head.weight) @@ -347,3 +282,31 @@ def _init_weights(self, module): nn.init.xavier_uniform_(module.weight_embedding.weight) if hasattr(module, "denoising_class_embed") and self.config.num_denoising > 0: nn.init.xavier_uniform_(module.denoising_class_embed.weight) + + +class RtDetrV2Decoder(RTDetrDecoder): + def __init__(self, config: RtDetrV2Config): + super().__init__(config) + self.layers = nn.ModuleList([RtDetrV2DecoderLayer(config) for _ in range(config.decoder_layers)]) + + +class RtDetrV2Model(RTDetrModel): + def __init__(self, config: RtDetrV2Config): + super().__init__(config) + # decoder + self.decoder = RtDetrV2Decoder(config) + + +def _get_clones(partial_module, N): + return nn.ModuleList([partial_module() for i in range(N)]) + + +class RtDetrV2ForObjectDetection(RTDetrForObjectDetection): + def __init__(self, config: RtDetrV2Config): + super().__init__(config) + # RTDETR encoder-decoder model + self.model = RtDetrV2Model(config) + # here self.model.decoder.bbox_embed is null, but not self.bbox_embed + # fix bug + self.model.decoder.class_embed = self.class_embed + self.model.decoder.bbox_embed = self.bbox_embed diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 922d67264bb142..bed68936c1babf 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -8324,6 +8324,27 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class 
RtDetrV2ForObjectDetection(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class RtDetrV2Model(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class RtDetrV2PreTrainedModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class RwkvForCausalLM(metaclass=DummyObject): _backends = ["torch"] diff --git a/tests/models/rtdetrv2/test_modeling_rt_detr_v2.py b/tests/models/rtdetrv2/test_modeling_rt_detr_v2.py index c31af4d8b5e212..452b22c12bded7 100644 --- a/tests/models/rtdetrv2/test_modeling_rt_detr_v2.py +++ b/tests/models/rtdetrv2/test_modeling_rt_detr_v2.py @@ -15,10 +15,12 @@ """Testing suite for the PyTorch RT_DETR model.""" import inspect +import json import math import tempfile import unittest +from huggingface_hub import hf_hub_download from parameterized import parameterized from transformers import ( @@ -163,7 +165,6 @@ def prepare_config_and_inputs(self): config.num_labels = self.num_labels return config, pixel_values, pixel_mask, labels - def get_rt_detr_v2_config(model_name: str) -> RtDetrV2Config: config = RtDetrV2Config() diff --git a/utils/check_docstrings.py b/utils/check_docstrings.py index a63ca59690f748..8e72ec8c3c2af9 100644 --- a/utils/check_docstrings.py +++ b/utils/check_docstrings.py @@ -975,7 +975,9 @@ def check_docstrings(overwrite: bool = False, check_all: bool = False): if modified_file_diff.a_path.startswith("src/transformers"): module_diff_files.add(modified_file_diff.a_path) # Diff from index to `main` - for modified_file_diff in repo.index.diff(repo.refs.main.commit): + default_branch = repo.active_branch.name if repo.active_branch else "main" + # Diff from index to the detected branch + for modified_file_diff in repo.index.diff(repo.refs[default_branch].commit): if modified_file_diff.a_path.startswith("src/transformers"): module_diff_files.add(modified_file_diff.a_path) # quick escape route: if there are no module files in the diff, skip this check
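
Note on the `check_docstrings.py` hunk above: below is a minimal, self-contained sketch of the branch-detection pattern it relies on, assuming GitPython is installed and `"path/to/transformers"` stands in for a local clone. The detached-HEAD case is handled with `try/except` here, since `repo.active_branch` raises `TypeError` in that situation rather than returning a falsy value.

```python
from git import Repo  # GitPython

repo = Repo("path/to/transformers")  # placeholder: path to a local clone

# Resolve the branch to diff against, falling back to "main" when HEAD is detached
# (repo.active_branch raises TypeError in that case instead of returning None).
try:
    default_branch = repo.active_branch.name
except TypeError:
    default_branch = "main"

# Collect files changed between the index and the tip commit of that branch,
# keeping only paths under src/transformers, as in the hunk above.
module_diff_files = {
    diff.a_path
    for diff in repo.index.diff(repo.refs[default_branch].commit)
    if diff.a_path.startswith("src/transformers")
}
print(sorted(module_diff_files))
```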