Skip to content

Commit

Permalink
fix modular
Browse files Browse the repository at this point in the history
  • Loading branch information
jadechoghari committed Dec 29, 2024
1 parent 6135be4 commit 350cafe
Show file tree
Hide file tree
Showing 11 changed files with 2,737 additions and 105 deletions.
1 change: 1 addition & 0 deletions docs/source/en/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,7 @@ Flax), PyTorch, and/or TensorFlow.
| [RoFormer](model_doc/roformer) ||||
| [RT-DETR](model_doc/rt_detr) ||||
| [RT-DETR-ResNet](model_doc/rt_detr_resnet) ||||
| [RtDetrV2](model_doc/rt_detr_v2) ||||
| [RWKV](model_doc/rwkv) ||||
| [SAM](model_doc/sam) ||||
| [SeamlessM4T](model_doc/seamless_m4t) ||||
Expand Down
10 changes: 3 additions & 7 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -730,7 +730,7 @@
"RoFormerTokenizer",
],
"models.rt_detr": ["RTDetrConfig", "RTDetrResNetConfig"],
"models.rt_detr_v2": ["RtDetrV2Config"],
"models.rt_detr_v2": ["RtDetrV2Config", "RtDetrV2ResNetConfig"],
"models.rwkv": ["RwkvConfig"],
"models.sam": [
"SamConfig",
Expand Down Expand Up @@ -3365,11 +3365,7 @@
]
)
_import_structure["models.rt_detr_v2"].extend(
[
"RtDetrV2ForObjectDetection",
"RtDetrV2Model",
"RtDetrV2PreTrainedModel"
]
["RtDetrV2ForObjectDetection", "RtDetrV2Model", "RtDetrV2PreTrainedModel"]
)
_import_structure["models.rwkv"].extend(
[
Expand Down Expand Up @@ -5743,7 +5739,7 @@
RTDetrConfig,
RTDetrResNetConfig,
)
from .models.rt_detr_v2 import RtDetrV2Config
from .models.rt_detr_v2 import RtDetrV2Config, RtDetrV2ResNetConfig
from .models.rwkv import RwkvConfig
from .models.sam import (
SamConfig,
Expand Down
5 changes: 4 additions & 1 deletion src/transformers/models/auto/configuration_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,7 @@
("rt_detr", "RTDetrConfig"),
("rt_detr_resnet", "RTDetrResNetConfig"),
("rt_detr_v2", "RtDetrV2Config"),
("rt_detr_v2_resnet", "RtDetrV2ResNetConfig"),
("rwkv", "RwkvConfig"),
("sam", "SamConfig"),
("seamless_m4t", "SeamlessM4TConfig"),
Expand Down Expand Up @@ -580,7 +581,8 @@
("roformer", "RoFormer"),
("rt_detr", "RT-DETR"),
("rt_detr_resnet", "RT-DETR-ResNet"),
("rt_detr_v2", "RtDetrV2"),
("rt_detr_v2", "RT-DETR-V2-ResNet"),
("rt_detr_v2_resnet", "RtDetrV2ResNetConfig"),
("rwkv", "RWKV"),
("sam", "SAM"),
("seamless_m4t", "SeamlessM4T"),
Expand Down Expand Up @@ -713,6 +715,7 @@
("siglip_vision_model", "siglip"),
("chinese_clip_vision_model", "chinese_clip"),
("rt_detr_resnet", "rt_detr"),
("rt_detr_v2_resnet", "rt_detr_v2"),
]
)

Expand Down
21 changes: 21 additions & 0 deletions src/transformers/models/rt_detr_v2/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure
from .configuration_rt_detr_v2 import *
from .modeling_rt_detr_v2 import *
450 changes: 450 additions & 0 deletions src/transformers/models/rt_detr_v2/configuration_rt_detr_v2.py

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@
from PIL import Image
from torchvision import transforms

from transformers import RTDetrImageProcessor
from transformers.models.rt_detr_v2.modular_rt_detr_v2 import RtDetrV2Config, RtDetrV2ForObjectDetection
from transformers import RTDetrImageProcessor, RtDetrV2Config, RtDetrV2ForObjectDetection
from transformers.utils import logging


Expand Down Expand Up @@ -558,7 +557,6 @@ def convert_rt_detr_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub
"ema"
]["module"]


# rename keys
for src, dest in create_rename_keys(config):
rename_key(state_dict, src, dest)
Expand All @@ -581,7 +579,6 @@ def convert_rt_detr_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub
model.load_state_dict(state_dict)
model.eval()


# load image processor
image_processor = RTDetrImageProcessor()

Expand Down
Loading

0 comments on commit 350cafe

Please sign in to comment.