From 333efcb6d0b60d7cceb7afc91bd96315cf211b0a Mon Sep 17 00:00:00 2001 From: Yanghao Li Date: Thu, 9 Jun 2022 14:49:30 -0700 Subject: [PATCH] ViTDet README and COCO configs Reviewed By: rbgirshick, wat3rBro, HannaMao Differential Revision: D36117941 fbshipit-source-id: 9608b390b958f2471fbdedfb5f97ae0a3c23e006 --- README.md | 2 +- configs/Misc/mmdet_mask_rcnn_R_50_FPN_1x.py | 5 +- configs/common/data/constants.py | 9 + configs/common/models/mask_rcnn_c4.py | 6 +- configs/common/models/mask_rcnn_fpn.py | 6 +- configs/common/models/mask_rcnn_vitdet.py | 59 +++++ configs/common/models/retinanet.py | 6 +- configs/common/optim.py | 13 ++ detectron2/modeling/__init__.py | 1 + detectron2/modeling/backbone/__init__.py | 2 +- detectron2/modeling/backbone/backbone.py | 6 +- detectron2/modeling/backbone/fpn.py | 5 +- detectron2/modeling/backbone/vit.py | 29 ++- detectron2/structures/image_list.py | 12 +- projects/README.md | 2 +- projects/ViTDet/README.md | 202 ++++++++++++++++++ .../cascade_mask_rcnn_mvitv2_b_in21k_100ep.py | 95 ++++++++ .../cascade_mask_rcnn_mvitv2_h_in21k_36ep.py | 39 ++++ .../cascade_mask_rcnn_mvitv2_l_in21k_50ep.py | 22 ++ .../cascade_mask_rcnn_swin_b_in21k_50ep.py | 50 +++++ .../cascade_mask_rcnn_swin_l_in21k_50ep.py | 15 ++ .../COCO/cascade_mask_rcnn_vitdet_b_100ep.py | 48 +++++ .../COCO/cascade_mask_rcnn_vitdet_h_75ep.py | 31 +++ .../COCO/cascade_mask_rcnn_vitdet_l_100ep.py | 23 ++ .../configs/COCO/mask_rcnn_vitdet_b_100ep.py | 38 ++++ .../configs/COCO/mask_rcnn_vitdet_h_75ep.py | 31 +++ .../configs/COCO/mask_rcnn_vitdet_l_100ep.py | 23 ++ .../ViTDet/configs/common/coco_loader_lsj.py | 22 ++ 28 files changed, 776 insertions(+), 26 deletions(-) create mode 100644 configs/common/data/constants.py create mode 100644 configs/common/models/mask_rcnn_vitdet.py create mode 100644 projects/ViTDet/README.md create mode 100644 projects/ViTDet/configs/COCO/cascade_mask_rcnn_mvitv2_b_in21k_100ep.py create mode 100644 projects/ViTDet/configs/COCO/cascade_mask_rcnn_mvitv2_h_in21k_36ep.py create mode 100644 projects/ViTDet/configs/COCO/cascade_mask_rcnn_mvitv2_l_in21k_50ep.py create mode 100644 projects/ViTDet/configs/COCO/cascade_mask_rcnn_swin_b_in21k_50ep.py create mode 100644 projects/ViTDet/configs/COCO/cascade_mask_rcnn_swin_l_in21k_50ep.py create mode 100644 projects/ViTDet/configs/COCO/cascade_mask_rcnn_vitdet_b_100ep.py create mode 100644 projects/ViTDet/configs/COCO/cascade_mask_rcnn_vitdet_h_75ep.py create mode 100644 projects/ViTDet/configs/COCO/cascade_mask_rcnn_vitdet_l_100ep.py create mode 100644 projects/ViTDet/configs/COCO/mask_rcnn_vitdet_b_100ep.py create mode 100644 projects/ViTDet/configs/COCO/mask_rcnn_vitdet_h_75ep.py create mode 100644 projects/ViTDet/configs/COCO/mask_rcnn_vitdet_l_100ep.py create mode 100644 projects/ViTDet/configs/common/coco_loader_lsj.py diff --git a/README.md b/README.md index d8b8d3ca95..edc60153f8 100644 --- a/README.md +++ b/README.md @@ -24,7 +24,7 @@ Explain Like I’m 5: Detectron2 | Using Machine Learning with Detec ## What's New * Includes new capabilities such as panoptic segmentation, Densepose, Cascade R-CNN, rotated bounding boxes, PointRend, - DeepLab, etc. + DeepLab, ViTDet, etc. * Used as a library to support building [research projects](projects/) on top of it. * Models can be exported to TorchScript format or Caffe2 format for deployment. * It [trains much faster](https://detectron2.readthedocs.io/notes/benchmarks.html). 
diff --git a/configs/Misc/mmdet_mask_rcnn_R_50_FPN_1x.py b/configs/Misc/mmdet_mask_rcnn_R_50_FPN_1x.py
index 0f2464be74..bdd49a4566 100644
--- a/configs/Misc/mmdet_mask_rcnn_R_50_FPN_1x.py
+++ b/configs/Misc/mmdet_mask_rcnn_R_50_FPN_1x.py
@@ -4,6 +4,7 @@
 from ..common.coco_schedule import lr_multiplier_1x as lr_multiplier
 from ..common.optim import SGD as optimizer
 from ..common.train import train
+from ..common.data.constants import constants
 
 from detectron2.modeling.mmdet_wrapper import MMDetDetector
 from detectron2.config import LazyCall as L
@@ -143,8 +144,8 @@
             ),
         ),
     ),
-    pixel_mean=[123.675, 116.280, 103.530],
-    pixel_std=[58.395, 57.120, 57.375],
+    pixel_mean=constants.imagenet_rgb256_mean,
+    pixel_std=constants.imagenet_rgb256_std,
 )
 
 dataloader.train.mapper.image_format = "RGB"  # torchvision pretrained model
diff --git a/configs/common/data/constants.py b/configs/common/data/constants.py
new file mode 100644
index 0000000000..be11cb5ac7
--- /dev/null
+++ b/configs/common/data/constants.py
@@ -0,0 +1,9 @@
+constants = dict(
+    imagenet_rgb256_mean=[123.675, 116.28, 103.53],
+    imagenet_rgb256_std=[58.395, 57.12, 57.375],
+    imagenet_bgr256_mean=[103.530, 116.280, 123.675],
+    # When using pre-trained models in Detectron1 or any MSRA models,
+    # std has been absorbed into its conv1 weights, so the std needs to be set to 1.
+    # Otherwise, you can use [57.375, 57.120, 58.395] (ImageNet std).
+    imagenet_bgr256_std=[1.0, 1.0, 1.0],
+)
diff --git a/configs/common/models/mask_rcnn_c4.py b/configs/common/models/mask_rcnn_c4.py
index a3dcf8be42..902d5b195f 100644
--- a/configs/common/models/mask_rcnn_c4.py
+++ b/configs/common/models/mask_rcnn_c4.py
@@ -13,6 +13,8 @@
     Res5ROIHeads,
 )
 
+from ..data.constants import constants
+
 model = L(GeneralizedRCNN)(
     backbone=L(ResNet)(
         stem=L(BasicStem)(in_channels=3, out_channels=64, norm="FrozenBN"),
@@ -82,7 +84,7 @@
             conv_dims=[256],
         ),
     ),
-    pixel_mean=[103.530, 116.280, 123.675],
-    pixel_std=[1.0, 1.0, 1.0],
+    pixel_mean=constants.imagenet_bgr256_mean,
+    pixel_std=constants.imagenet_bgr256_std,
     input_format="BGR",
 )
diff --git a/configs/common/models/mask_rcnn_fpn.py b/configs/common/models/mask_rcnn_fpn.py
index 744d5306f5..5e5c501cd1 100644
--- a/configs/common/models/mask_rcnn_fpn.py
+++ b/configs/common/models/mask_rcnn_fpn.py
@@ -15,6 +15,8 @@
     FastRCNNConvFCHead,
 )
 
+from ..data.constants import constants
+
 model = L(GeneralizedRCNN)(
     backbone=L(FPN)(
         bottom_up=L(ResNet)(
@@ -87,7 +89,7 @@
             conv_dims=[256, 256, 256, 256, 256],
         ),
     ),
-    pixel_mean=[103.530, 116.280, 123.675],
-    pixel_std=[1.0, 1.0, 1.0],
+    pixel_mean=constants.imagenet_bgr256_mean,
+    pixel_std=constants.imagenet_bgr256_std,
     input_format="BGR",
 )
diff --git a/configs/common/models/mask_rcnn_vitdet.py b/configs/common/models/mask_rcnn_vitdet.py
new file mode 100644
index 0000000000..d6f5244402
--- /dev/null
+++ b/configs/common/models/mask_rcnn_vitdet.py
@@ -0,0 +1,59 @@
+from functools import partial
+import torch.nn as nn
+from detectron2.config import LazyCall as L
+from detectron2.modeling import ViT, SimpleFeaturePyramid
+from detectron2.modeling.backbone.fpn import LastLevelMaxPool
+
+from .mask_rcnn_fpn import model
+from ..data.constants import constants
+
+model.pixel_mean = constants.imagenet_rgb256_mean
+model.pixel_std = constants.imagenet_rgb256_std
+model.input_format = "RGB"
+
+# Base
+embed_dim, depth, num_heads, dp = 768, 12, 12, 0.1
+# Creates a Simple Feature Pyramid from the ViT backbone
+model.backbone = L(SimpleFeaturePyramid)(
+    net=L(ViT)(  # Single-scale ViT backbone
+        img_size=1024,
+        patch_size=16,
+        embed_dim=embed_dim,
+        depth=depth,
+        num_heads=num_heads,
+        drop_path_rate=dp,
+        window_size=14,
+        mlp_ratio=4,
+        qkv_bias=True,
+        norm_layer=partial(nn.LayerNorm, eps=1e-6),
+        window_block_indexes=[
+            # 2, 5, 8, 11 for global attention
+            0,
+            1,
+            3,
+            4,
+            6,
+            7,
+            9,
+            10,
+        ],
+        residual_block_indexes=[],
+        use_rel_pos=True,
+        out_feature="last_feat",
+    ),
+    in_feature="${.net.out_feature}",
+    out_channels=256,
+    scale_factors=(4.0, 2.0, 1.0, 0.5),
+    top_block=L(LastLevelMaxPool)(),
+    norm="LN",
+    square_pad=1024,
+)
+
+model.roi_heads.box_head.conv_norm = model.roi_heads.mask_head.conv_norm = "LN"
+
+# 2conv in RPN:
+model.proposal_generator.head.conv_dims = [-1, -1]
+
+# 4conv1fc box head
+model.roi_heads.box_head.conv_dims = [256, 256, 256, 256]
+model.roi_heads.box_head.fc_dims = [1024]
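As a rough illustration of what this config builds (a sketch, not part of the patch; it assumes this patch is applied and the snippet runs from the detectron2 repo root), the single-scale ViT-B/16 output is rescaled into a multi-scale pyramid:

```python
import torch
from detectron2.config import LazyConfig, instantiate

# Build only the backbone: a SimpleFeaturePyramid over ViT-B/16.
cfg = LazyConfig.load("configs/common/models/mask_rcnn_vitdet.py")
backbone = instantiate(cfg.model.backbone)

# A 1024x1024 input gives one 64x64 ViT feature map (stride 16); the scale
# factors (4.0, 2.0, 1.0, 0.5) produce strides 4/8/16/32, and LastLevelMaxPool
# appends stride 64.
with torch.no_grad():
    feats = backbone(torch.zeros(1, 3, 1024, 1024))
for name, feat in feats.items():
    print(name, tuple(feat.shape))  # p2..p6, 256 channels each
```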
diff --git a/configs/common/models/retinanet.py b/configs/common/models/retinanet.py
index 83cfda4b60..784e5317f5 100644
--- a/configs/common/models/retinanet.py
+++ b/configs/common/models/retinanet.py
@@ -10,6 +10,8 @@
 from detectron2.modeling.matcher import Matcher
 from detectron2.modeling.meta_arch.retinanet import RetinaNetHead
 
+from ..data.constants import constants
+
 model = L(RetinaNet)(
     backbone=L(FPN)(
         bottom_up=L(ResNet)(
@@ -47,7 +49,7 @@
     head_in_features=["p3", "p4", "p5", "p6", "p7"],
     focal_loss_alpha=0.25,
     focal_loss_gamma=2.0,
-    pixel_mean=[103.530, 116.280, 123.675],
-    pixel_std=[1.0, 1.0, 1.0],
+    pixel_mean=constants.imagenet_bgr256_mean,
+    pixel_std=constants.imagenet_bgr256_std,
     input_format="BGR",
 )
diff --git a/configs/common/optim.py b/configs/common/optim.py
index d39d3aaa54..6cf43e835f 100644
--- a/configs/common/optim.py
+++ b/configs/common/optim.py
@@ -13,3 +13,16 @@
     momentum=0.9,
     weight_decay=1e-4,
 )
+
+
+AdamW = L(torch.optim.AdamW)(
+    params=L(get_default_optimizer_params)(
+        # params.model is meant to be set to the model object before instantiating
+        # the optimizer.
+        base_lr="${..lr}",
+        weight_decay_norm=0.0,
+    ),
+    lr=1e-4,
+    betas=(0.9, 0.999),
+    weight_decay=0.1,
+)
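The `params` field is itself lazily built; the training script is expected to attach the live model before instantiating the optimizer. A minimal sketch of that flow, mirroring what `tools/lazyconfig_train_net.py` does (illustrative, not part of the patch):

```python
from detectron2.config import LazyConfig, instantiate

cfg = LazyConfig.load("projects/ViTDet/configs/COCO/mask_rcnn_vitdet_b_100ep.py")
model = instantiate(cfg.model)

# get_default_optimizer_params enumerates the module's parameters (applying
# per-parameter overrides such as weight_decay_norm), so the model must be
# attached to the still-lazy optimizer config first.
cfg.optimizer.params.model = model
optimizer = instantiate(cfg.optimizer)
```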
diff --git a/detectron2/modeling/__init__.py b/detectron2/modeling/__init__.py
index 90a9282b83..4d949e222b 100644
--- a/detectron2/modeling/__init__.py
+++ b/detectron2/modeling/__init__.py
@@ -13,6 +13,7 @@
     make_stage,
     ViT,
     SimpleFeaturePyramid,
+    get_vit_lr_decay_rate,
     MViT,
     SwinTransformer,
 )
diff --git a/detectron2/modeling/backbone/__init__.py b/detectron2/modeling/backbone/__init__.py
index 6947fefb3c..5b3358a406 100644
--- a/detectron2/modeling/backbone/__init__.py
+++ b/detectron2/modeling/backbone/__init__.py
@@ -12,7 +12,7 @@
     make_stage,
     BottleneckBlock,
 )
-from .vit import ViT, SimpleFeaturePyramid
+from .vit import ViT, SimpleFeaturePyramid, get_vit_lr_decay_rate
 from .mvit import MViT
 from .swin import SwinTransformer
diff --git a/detectron2/modeling/backbone/backbone.py b/detectron2/modeling/backbone/backbone.py
index e3deec01f6..e1c765a6b3 100644
--- a/detectron2/modeling/backbone/backbone.py
+++ b/detectron2/modeling/backbone/backbone.py
@@ -49,11 +49,11 @@ def padding_constraints(self) -> Dict[str, int]:
         in :paper:vitdet). `padding_constraints` contains these optional items like:
         {
             "size_divisibility": int,
-            "square": int,
+            "square_size": int,
             # Future options are possible
         }
-        `size_divisibility` will read from here if presented and `square` indicates if requiring
-        inputs to be padded to square. Set to None if no specific padding constraints.
+        `size_divisibility` will be read from here if present, and `square_size` indicates the
+        square padding size if `square_size` > 0.
 
         TODO: use type of Dict[str, int] to avoid torchscript issues. The type of padding_constraints
         could be generalized as TypedDict (Python 3.8+) to support more types in the future.
diff --git a/detectron2/modeling/backbone/fpn.py b/detectron2/modeling/backbone/fpn.py
index 0ebaf95ea4..19d24e13f0 100644
--- a/detectron2/modeling/backbone/fpn.py
+++ b/detectron2/modeling/backbone/fpn.py
@@ -30,7 +30,7 @@ def __init__(
         norm="",
         top_block=None,
         fuse_type="sum",
-        square_pad=False,
+        square_pad=0,
     ):
         """
         Args:
@@ -54,6 +54,7 @@ def __init__(
         fuse_type (str): types for fusing the top down features and the lateral ones.
             It can be "sum" (default), which sums up element-wise; or "avg",
             which takes the element-wise mean of the two.
+        square_pad (int): If > 0, require input images to be padded to a specific square size.
         """
         super(FPN, self).__init__()
         assert isinstance(bottom_up, Backbone)
@@ -120,7 +121,7 @@ def size_divisibility(self):
 
     @property
     def padding_constraints(self):
-        return {"square": int(self._square_pad)}
+        return {"square_size": self._square_pad}
 
     def forward(self, x):
         """
diff --git a/detectron2/modeling/backbone/vit.py b/detectron2/modeling/backbone/vit.py
index 55dbc0dff7..230c99152f 100644
--- a/detectron2/modeling/backbone/vit.py
+++ b/detectron2/modeling/backbone/vit.py
@@ -22,7 +22,7 @@
 logger = logging.getLogger(__name__)
 
-__all__ = ["ViT", "SimpleFeaturePyramid"]
+__all__ = ["ViT", "SimpleFeaturePyramid", "get_vit_lr_decay_rate"]
 
 
 class Attention(nn.Module):
@@ -372,7 +372,7 @@ def __init__(
         scale_factors,
         top_block=None,
         norm="LN",
-        square_pad=False,
+        square_pad=0,
     ):
         """
         Args:
@@ -391,7 +391,7 @@ def __init__(
                 this block, and "in_feature", which is a string representing its input feature
                 (e.g., p5).
             norm (str): the normalization to use.
-            square_pad (bool): If true, require input images to be padded to square.
+            square_pad (int): If > 0, require input images to be padded to a specific square size.
         """
         super(SimpleFeaturePyramid, self).__init__()
         assert isinstance(net, Backbone)
@@ -469,7 +469,7 @@ def __init__(
     def padding_constraints(self):
         return {
             "size_divisibility": self._size_divisibility,
-            "square": int(self._square_pad),
+            "square_size": self._square_pad,
         }
 
     def forward(self, x):
@@ -499,3 +499,24 @@ def forward(self, x):
             results.extend(self.top_block(top_block_in_feature))
         assert len(self._out_features) == len(results)
         return {f: res for f, res in zip(self._out_features, results)}
+
+
+def get_vit_lr_decay_rate(name, lr_decay_rate=1.0, num_layers=12):
+    """
+    Calculate lr decay rate for different ViT blocks.
+
+    Args:
+        name (string): parameter name.
+        lr_decay_rate (float): base lr decay rate.
+        num_layers (int): number of ViT blocks.
+
+    Returns:
+        lr decay rate for the given parameter.
+    """
+    layer_id = num_layers + 1
+    if name.startswith("backbone"):
+        if ".pos_embed" in name or ".patch_embed" in name:
+            layer_id = 0
+        elif ".blocks." in name and ".residual." not in name:
+            layer_id = int(name[name.find(".blocks.") :].split(".")[2]) + 1
+
+    return lr_decay_rate ** (num_layers + 1 - layer_id)
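`get_vit_lr_decay_rate` maps patch/positional embeddings to depth 0, block `i` to depth `i + 1`, and all non-backbone parameters to depth `num_layers + 1`, so the returned multiplier decays geometrically toward the input. A small illustrative check (not part of the patch; the parameter names follow the `backbone.net.*` naming these configs produce):

```python
from detectron2.modeling import get_vit_lr_decay_rate

# ViT-B settings used by the COCO configs below: 12 blocks, decay rate 0.7.
for name in [
    "backbone.net.pos_embed",                 # depth 0  -> 0.7 ** 13
    "backbone.net.blocks.0.attn.qkv.weight",  # depth 1  -> 0.7 ** 12
    "backbone.net.blocks.11.mlp.fc1.weight",  # depth 12 -> 0.7 ** 1
    "roi_heads.box_head.fc1.weight",          # head     -> 0.7 ** 0 == 1.0
]:
    print(f"{name}: {get_vit_lr_decay_rate(name, lr_decay_rate=0.7, num_layers=12):.4f}")
```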
diff --git a/detectron2/structures/image_list.py b/detectron2/structures/image_list.py
index fc4ecca92c..9f4df5241b 100644
--- a/detectron2/structures/image_list.py
+++ b/detectron2/structures/image_list.py
@@ -72,10 +72,9 @@ def from_tensors(
                 This depends on the model and many models need a divisibility of 32.
             pad_value (float): value to pad.
             padding_constraints (optional[Dict]): If given, it would follow the format as
-                {"size_divisibility": int, "square": int}, where `size_divisibility` will overwrite
-                the above one if presented and `square` indicates if require inputs to be padded to
-                square.
-
+                {"size_divisibility": int, "square_size": int}, where `size_divisibility` will
+                overwrite the above one if present and `square_size` indicates the
+                square padding size if `square_size` > 0.
         Returns:
             an `ImageList`.
         """
@@ -90,9 +89,10 @@ def from_tensors(
         max_size = torch.stack(image_sizes_tensor).max(0).values
 
         if padding_constraints is not None:
-            if padding_constraints.get("square", 0) > 0:
+            square_size = padding_constraints.get("square_size", 0)
+            if square_size > 0:
                 # pad to square.
-                max_size[0] = max_size[1] = max_size.max()
+                max_size[0] = max_size[1] = square_size
             if "size_divisibility" in padding_constraints:
                 size_divisibility = padding_constraints["size_divisibility"]
                 if size_divisibility > 1:
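A quick sketch of the new behavior (illustrative, assuming this patch is applied): with `square_size` set, every image in the batch is padded to the same fixed square rather than to the per-batch maximum, which is what the absolute positional embeddings in ViTDet require.

```python
import torch
from detectron2.structures import ImageList

tensors = [torch.rand(3, 640, 800), torch.rand(3, 720, 480)]

# ViTDet-style padding: a fixed 1024x1024 square. square_size takes effect
# first; the divisibility of 32 is already satisfied by 1024.
batched = ImageList.from_tensors(
    tensors,
    pad_value=0.0,
    padding_constraints={"square_size": 1024, "size_divisibility": 32},
)
print(batched.tensor.shape)  # torch.Size([2, 3, 1024, 1024])
```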
diff --git a/projects/README.md b/projects/README.md
index c6ea107398..fc83d5e831 100644
--- a/projects/README.md
+++ b/projects/README.md
@@ -21,6 +21,7 @@ of support or stability as detectron2.
 + [Unbiased Teacher for Semi-Supervised Object Detection](https://github.com/facebookresearch/unbiased-teacher)
 + [Rethinking "Batch" in BatchNorm](Rethinking-BatchNorm/)
 + [Per-Pixel Classification is Not All You Need for Semantic Segmentation](https://github.com/facebookresearch/MaskFormer)
++ [Exploring Plain Vision Transformer Backbones for Object Detection](ViTDet/)
 
 ## External Projects
 
@@ -45,4 +46,3 @@ External projects in the community that use detectron2:
 + [Sparse R-CNN](https://github.com/PeizeSun/SparseR-CNN)
 + [BCNet](https://github.com/lkeab/BCNet), a bilayer decoupling instance segmentation method.
 + [DD3D](https://github.com/TRI-ML/dd3d), A fully convolutional 3D detector.
-
diff --git a/projects/ViTDet/README.md b/projects/ViTDet/README.md
new file mode 100644
index 0000000000..332ac35ac2
--- /dev/null
+++ b/projects/ViTDet/README.md
@@ -0,0 +1,202 @@
+# ViTDet: Exploring Plain Vision Transformer Backbones for Object Detection
+
+Yanghao Li, Hanzi Mao, Ross Girshick†, Kaiming He†
+
+[[`arXiv`](https://arxiv.org/abs/2203.16527)] [[`BibTeX`](#CitingViTDet)]
+
+In this repository, we provide configs and models in Detectron2 for ViTDet, as well as for MViTv2 and Swin backbones, with our implementation and settings as described in the [ViTDet](https://arxiv.org/abs/2203.16527) paper.
+
+
+## Pretrained Models
+
+### COCO
+
+#### Mask R-CNN
+
+| Name | pre-train | train time (s/im) | inference time (s/im) | train mem (GB) | box AP | mask AP | model id | download |
+|:---|:---|---:|---:|---:|---:|---:|---:|:---:|
+| ViTDet, ViT-B | IN1K, MAE | 0.314 | 0.079 | 10.9 | 51.6 | 45.9 | 325346929 | model |
+| ViTDet, ViT-L | IN1K, MAE | 0.603 | 0.125 | 20.9 | 55.5 | 49.2 | 325599698 | model |
+| ViTDet, ViT-H | IN1K, MAE | 1.098 | 0.178 | 31.5 | 56.7 | 50.2 | 329145471 | model |
+
+#### Cascade Mask R-CNN
+
+| Name | pre-train | train time (s/im) | inference time (s/im) | train mem (GB) | box AP | mask AP | model id | download |
+|:---|:---|---:|---:|---:|---:|---:|---:|:---:|
+| Swin-B | IN21K, sup | 0.389 | 0.077 | 8.7 | 53.9 | 46.2 | 342979038 | model |
+| Swin-L | IN21K, sup | 0.508 | 0.097 | 12.6 | 55.0 | 47.2 | 342979186 | model |
+| MViTv2-B | IN21K, sup | 0.475 | 0.090 | 8.9 | 55.6 | 48.1 | 325820315 | model |
+| MViTv2-L | IN21K, sup | 0.844 | 0.157 | 19.7 | 55.7 | 48.3 | 325607715 | model |
+| MViTv2-H | IN21K, sup | 1.655 | 0.285 | 18.4* | 55.9 | 48.3 | 326187358 | model |
+| ViTDet, ViT-B | IN1K, MAE | 0.362 | 0.089 | 12.3 | 54.0 | 46.7 | 325358525 | model |
+| ViTDet, ViT-L | IN1K, MAE | 0.643 | 0.142 | 22.3 | 57.6 | 50.0 | 328021305 | model |
+| ViTDet, ViT-H | IN1K, MAE | 1.137 | 0.196 | 32.9 | 58.7 | 51.0 | 328730692 | model |
+
+Note: Unlike the system-level comparisons in the paper, these models use a lower resolution (1024 instead of 1280) and standard NMS (instead of soft NMS). As a result, they have slightly lower box and mask AP.
+
+The above models were trained and measured on 8 nodes with 64 NVIDIA A100 GPUs in total. *: Activation checkpointing is used.
+
+
+## Training
+All configs can be trained with:
+
+```
+../../tools/lazyconfig_train_net.py --config-file configs/path/to/config.py
+```
+By default, we use 64 GPUs with a total batch size of 64 for training.
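+
+For example, to run the ViT-B Mask R-CNN config on a single 8-GPU machine, one can override the batch size on the command line and scale the learning rate linearly (illustrative values, not a tuned recipe; the schedule milestones would need a matching adjustment):
+
+```
+../../tools/lazyconfig_train_net.py --num-gpus 8 \
+  --config-file configs/COCO/mask_rcnn_vitdet_b_100ep.py \
+  dataloader.train.total_batch_size=8 optimizer.lr=1.25e-5
+```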
+
+## Evaluation
+Model evaluation can be done similarly:
+```
+../../tools/lazyconfig_train_net.py --config-file configs/path/to/config.py --eval-only train.init_checkpoint=/path/to/model_checkpoint
+```
+
+
+## <a name="CitingViTDet"></a>Citing ViTDet
+
+If you use ViTDet, please use the following BibTeX entry.
+
+```BibTeX
+@article{li2022exploring,
+  title={Exploring plain vision transformer backbones for object detection},
+  author={Li, Yanghao and Mao, Hanzi and Girshick, Ross and He, Kaiming},
+  journal={arXiv preprint arXiv:2203.16527},
+  year={2022}
+}
+```
diff --git a/projects/ViTDet/configs/COCO/cascade_mask_rcnn_mvitv2_b_in21k_100ep.py b/projects/ViTDet/configs/COCO/cascade_mask_rcnn_mvitv2_b_in21k_100ep.py
new file mode 100644
index 0000000000..9dba203086
--- /dev/null
+++ b/projects/ViTDet/configs/COCO/cascade_mask_rcnn_mvitv2_b_in21k_100ep.py
@@ -0,0 +1,95 @@
+from functools import partial
+import torch.nn as nn
+from fvcore.common.param_scheduler import MultiStepParamScheduler
+
+from detectron2 import model_zoo
+from detectron2.config import LazyCall as L
+from detectron2.solver import WarmupParamScheduler
+from detectron2.modeling import MViT
+from detectron2.layers import ShapeSpec
+from detectron2.modeling.box_regression import Box2BoxTransform
+from detectron2.modeling.matcher import Matcher
+from detectron2.modeling.roi_heads import (
+    FastRCNNOutputLayers,
+    FastRCNNConvFCHead,
+    CascadeROIHeads,
+)
+
+from ..common.coco_loader_lsj import dataloader
+
+model = model_zoo.get_config("common/models/mask_rcnn_fpn.py").model
+constants = model_zoo.get_config("common/data/constants.py").constants
+model.pixel_mean = constants.imagenet_rgb256_mean
+model.pixel_std = constants.imagenet_rgb256_std
+model.input_format = "RGB"
+model.backbone.bottom_up = L(MViT)(
+    embed_dim=96,
+    depth=24,
+    num_heads=1,
+    last_block_indexes=(1, 4, 20, 23),
+    residual_pooling=True,
+    drop_path_rate=0.4,
+    norm_layer=partial(nn.LayerNorm, eps=1e-6),
+    out_features=("scale2", "scale3", "scale4", "scale5"),
+)
+model.backbone.in_features = "${.bottom_up.out_features}"
+model.backbone.square_pad = 1024
+
+# New heads and LN
+model.backbone.norm = "LN"  # Use LN in FPN
+model.roi_heads.box_head.conv_norm = model.roi_heads.mask_head.conv_norm = "LN"
+
+# 2conv in RPN:
+model.proposal_generator.head.conv_dims = [-1, -1]
+
+# arguments that don't exist for Cascade R-CNN
+[model.roi_heads.pop(k) for k in ["box_head", "box_predictor", "proposal_matcher"]]
+model.roi_heads.update(
+    _target_=CascadeROIHeads,
+    box_heads=[
+        L(FastRCNNConvFCHead)(
+            input_shape=ShapeSpec(channels=256, height=7, width=7),
+            conv_dims=[256, 256, 256, 256],
+            fc_dims=[1024],
+            conv_norm="LN",
+        )
+        for _ in range(3)
+    ],
+    box_predictors=[
+        L(FastRCNNOutputLayers)(
+            input_shape=ShapeSpec(channels=1024),
+            test_score_thresh=0.05,
+            box2box_transform=L(Box2BoxTransform)(weights=(w1, w1, w2, w2)),
+            cls_agnostic_bbox_reg=True,
+            num_classes="${...num_classes}",
+        )
+        for (w1, w2) in [(10, 5), (20, 10), (30, 15)]
+    ],
+ proposal_matchers=[ + L(Matcher)(thresholds=[th], labels=[0, 1], allow_low_quality_matches=False) + for th in [0.5, 0.6, 0.7] + ], +) + +# Initialization and trainer settings +train = model_zoo.get_config("common/train.py").train +train.amp.enabled = True +train.ddp.fp16_compression = True +train.init_checkpoint = "detectron2://ImageNetPretrained/mvitv2/MViTv2_B_in21k.pyth" + +# Schedule +# 100 ep = 184375 iters * 64 images/iter / 118000 images/ep +train.max_iter = 184375 +lr_multiplier = L(WarmupParamScheduler)( + scheduler=L(MultiStepParamScheduler)( + values=[1.0, 0.1, 0.01], + milestones=[163889, 177546], + num_updates=train.max_iter, + ), + warmup_length=250 / train.max_iter, + warmup_factor=0.001, +) + +optimizer = model_zoo.get_config("common/optim.py").AdamW +optimizer.params.overrides = {"pos_embed": {"weight_decay": 0.0}} +optimizer.lr = 8e-5 diff --git a/projects/ViTDet/configs/COCO/cascade_mask_rcnn_mvitv2_h_in21k_36ep.py b/projects/ViTDet/configs/COCO/cascade_mask_rcnn_mvitv2_h_in21k_36ep.py new file mode 100644 index 0000000000..577045043b --- /dev/null +++ b/projects/ViTDet/configs/COCO/cascade_mask_rcnn_mvitv2_h_in21k_36ep.py @@ -0,0 +1,39 @@ +from fvcore.common.param_scheduler import MultiStepParamScheduler + +from detectron2.config import LazyCall as L +from detectron2.solver import WarmupParamScheduler + +from .cascade_mask_rcnn_mvitv2_b_in21k_100ep import ( + dataloader, + lr_multiplier, + model, + train, + optimizer, +) + +model.backbone.bottom_up.embed_dim = 192 +model.backbone.bottom_up.depth = 80 +model.backbone.bottom_up.num_heads = 3 +model.backbone.bottom_up.last_block_indexes = (3, 11, 71, 79) +model.backbone.bottom_up.drop_path_rate = 0.6 +model.backbone.bottom_up.use_act_checkpoint = True + + +train.init_checkpoint = "detectron2://ImageNetPretrained/mvitv2/MViTv2_H_in21k.pyth" + + +# 36 epochs +train.max_iter = 67500 +lr_multiplier = L(WarmupParamScheduler)( + scheduler=L(MultiStepParamScheduler)( + values=[1.0, 0.1, 0.01], + milestones=[ + 52500, + 62500, + 67500, + ], + ), + warmup_length=250 / train.max_iter, + warmup_factor=0.001, +) +optimizer.lr = 1.6e-4 diff --git a/projects/ViTDet/configs/COCO/cascade_mask_rcnn_mvitv2_l_in21k_50ep.py b/projects/ViTDet/configs/COCO/cascade_mask_rcnn_mvitv2_l_in21k_50ep.py new file mode 100644 index 0000000000..c64f0c18ae --- /dev/null +++ b/projects/ViTDet/configs/COCO/cascade_mask_rcnn_mvitv2_l_in21k_50ep.py @@ -0,0 +1,22 @@ +from .cascade_mask_rcnn_mvitv2_b_in21k_100ep import ( + dataloader, + lr_multiplier, + model, + train, + optimizer, +) + +model.backbone.bottom_up.embed_dim = 144 +model.backbone.bottom_up.depth = 48 +model.backbone.bottom_up.num_heads = 2 +model.backbone.bottom_up.last_block_indexes = (1, 7, 43, 47) +model.backbone.bottom_up.drop_path_rate = 0.5 + + +train.init_checkpoint = "detectron2://ImageNetPretrained/mvitv2/MViTv2_L_in21k.pyth" + +train.max_iter = train.max_iter // 2 # 100ep -> 50ep +lr_multiplier.scheduler.milestones = [ + milestone // 2 for milestone in lr_multiplier.scheduler.milestones +] +lr_multiplier.scheduler.num_updates = train.max_iter diff --git a/projects/ViTDet/configs/COCO/cascade_mask_rcnn_swin_b_in21k_50ep.py b/projects/ViTDet/configs/COCO/cascade_mask_rcnn_swin_b_in21k_50ep.py new file mode 100644 index 0000000000..b2aad98526 --- /dev/null +++ b/projects/ViTDet/configs/COCO/cascade_mask_rcnn_swin_b_in21k_50ep.py @@ -0,0 +1,50 @@ +from fvcore.common.param_scheduler import MultiStepParamScheduler + +from detectron2 import model_zoo +from detectron2.config import LazyCall as 
L +from detectron2.solver import WarmupParamScheduler +from detectron2.modeling import SwinTransformer + +from ..common.coco_loader_lsj import dataloader +from .cascade_mask_rcnn_mvitv2_b_in21k_100ep import model + +model.backbone.bottom_up = L(SwinTransformer)( + depths=[2, 2, 18, 2], + drop_path_rate=0.4, + embed_dim=128, + num_heads=[4, 8, 16, 32], +) +model.backbone.in_features = ("p0", "p1", "p2", "p3") +model.backbone.square_pad = 1024 + +# Initialization and trainer settings +train = model_zoo.get_config("common/train.py").train +train.amp.enabled = True +train.ddp.fp16_compression = True +train.init_checkpoint = "detectron2://ImageNetPretrained/swin/swin_base_patch4_window7_224_22k.pth" + +# Schedule +# 100 ep = 184375 iters * 64 images/iter / 118000 images/ep +train.max_iter = 184375 +lr_multiplier = L(WarmupParamScheduler)( + scheduler=L(MultiStepParamScheduler)( + values=[1.0, 0.1, 0.01], + milestones=[163889, 177546], + num_updates=train.max_iter, + ), + warmup_length=250 / train.max_iter, + warmup_factor=0.001, +) + +# Rescale schedule +train.max_iter = train.max_iter // 2 # 100ep -> 50ep +lr_multiplier.scheduler.milestones = [ + milestone // 2 for milestone in lr_multiplier.scheduler.milestones +] +lr_multiplier.scheduler.num_updates = train.max_iter + + +optimizer = model_zoo.get_config("common/optim.py").AdamW +optimizer.lr = 4e-5 +optimizer.weight_decay = 0.05 +optimizer.params.overrides = {"relative_position_bias_table": {"weight_decay": 0.0}} diff --git a/projects/ViTDet/configs/COCO/cascade_mask_rcnn_swin_l_in21k_50ep.py b/projects/ViTDet/configs/COCO/cascade_mask_rcnn_swin_l_in21k_50ep.py new file mode 100644 index 0000000000..60bc917b59 --- /dev/null +++ b/projects/ViTDet/configs/COCO/cascade_mask_rcnn_swin_l_in21k_50ep.py @@ -0,0 +1,15 @@ +from .cascade_mask_rcnn_swin_b_in21k_50ep import ( + dataloader, + lr_multiplier, + model, + train, + optimizer, +) + +model.backbone.bottom_up.depths = [2, 2, 18, 2] +model.backbone.bottom_up.drop_path_rate = 0.4 +model.backbone.bottom_up.embed_dim = 192 +model.backbone.bottom_up.num_heads = [6, 12, 24, 48] + + +train.init_checkpoint = "detectron2://ImageNetPretrained/swin/swin_large_patch4_window7_224_22k.pth" diff --git a/projects/ViTDet/configs/COCO/cascade_mask_rcnn_vitdet_b_100ep.py b/projects/ViTDet/configs/COCO/cascade_mask_rcnn_vitdet_b_100ep.py new file mode 100644 index 0000000000..95823ef4fb --- /dev/null +++ b/projects/ViTDet/configs/COCO/cascade_mask_rcnn_vitdet_b_100ep.py @@ -0,0 +1,48 @@ +from detectron2.config import LazyCall as L +from detectron2.layers import ShapeSpec +from detectron2.modeling.box_regression import Box2BoxTransform +from detectron2.modeling.matcher import Matcher +from detectron2.modeling.roi_heads import ( + FastRCNNOutputLayers, + FastRCNNConvFCHead, + CascadeROIHeads, +) + +from .mask_rcnn_vitdet_b_100ep import ( + dataloader, + lr_multiplier, + model, + train, + optimizer, + get_vit_lr_decay_rate, +) + +# arguments that don't exist for Cascade R-CNN +[model.roi_heads.pop(k) for k in ["box_head", "box_predictor", "proposal_matcher"]] + +model.roi_heads.update( + _target_=CascadeROIHeads, + box_heads=[ + L(FastRCNNConvFCHead)( + input_shape=ShapeSpec(channels=256, height=7, width=7), + conv_dims=[256, 256, 256, 256], + fc_dims=[1024], + conv_norm="LN", + ) + for _ in range(3) + ], + box_predictors=[ + L(FastRCNNOutputLayers)( + input_shape=ShapeSpec(channels=1024), + test_score_thresh=0.05, + box2box_transform=L(Box2BoxTransform)(weights=(w1, w1, w2, w2)), + cls_agnostic_bbox_reg=True, + 
num_classes="${...num_classes}", + ) + for (w1, w2) in [(10, 5), (20, 10), (30, 15)] + ], + proposal_matchers=[ + L(Matcher)(thresholds=[th], labels=[0, 1], allow_low_quality_matches=False) + for th in [0.5, 0.6, 0.7] + ], +) diff --git a/projects/ViTDet/configs/COCO/cascade_mask_rcnn_vitdet_h_75ep.py b/projects/ViTDet/configs/COCO/cascade_mask_rcnn_vitdet_h_75ep.py new file mode 100644 index 0000000000..34a553453c --- /dev/null +++ b/projects/ViTDet/configs/COCO/cascade_mask_rcnn_vitdet_h_75ep.py @@ -0,0 +1,31 @@ +from functools import partial + +from .cascade_mask_rcnn_vitdet_b_100ep import ( + dataloader, + lr_multiplier, + model, + train, + optimizer, + get_vit_lr_decay_rate, +) + +train.init_checkpoint = "detectron2://ImageNetPretrained/MAE/mae_pretrain_vit_huge_p14to16.pth" + +model.backbone.net.embed_dim = 1280 +model.backbone.net.depth = 32 +model.backbone.net.num_heads = 16 +model.backbone.net.drop_path_rate = 0.5 +# 7, 15, 23, 31 for global attention +model.backbone.net.window_block_indexes = ( + list(range(0, 7)) + list(range(8, 15)) + list(range(16, 23)) + list(range(24, 31)) +) + +optimizer.params.lr_factor_func = partial(get_vit_lr_decay_rate, lr_decay_rate=0.9, num_layers=32) +optimizer.params.overrides = {} +optimizer.params.weight_decay_norm = None + +train.max_iter = train.max_iter * 3 // 4 # 100ep -> 75ep +lr_multiplier.scheduler.milestones = [ + milestone * 3 // 4 for milestone in lr_multiplier.scheduler.milestones +] +lr_multiplier.scheduler.num_updates = train.max_iter diff --git a/projects/ViTDet/configs/COCO/cascade_mask_rcnn_vitdet_l_100ep.py b/projects/ViTDet/configs/COCO/cascade_mask_rcnn_vitdet_l_100ep.py new file mode 100644 index 0000000000..3ec259e0f9 --- /dev/null +++ b/projects/ViTDet/configs/COCO/cascade_mask_rcnn_vitdet_l_100ep.py @@ -0,0 +1,23 @@ +from functools import partial + +from .cascade_mask_rcnn_vitdet_b_100ep import ( + dataloader, + lr_multiplier, + model, + train, + optimizer, + get_vit_lr_decay_rate, +) + +train.init_checkpoint = "detectron2://ImageNetPretrained/MAE/mae_pretrain_vit_large.pth" + +model.backbone.net.embed_dim = 1024 +model.backbone.net.depth = 24 +model.backbone.net.num_heads = 16 +model.backbone.net.drop_path_rate = 0.4 +# 5, 11, 17, 23 for global attention +model.backbone.net.window_block_indexes = ( + list(range(0, 5)) + list(range(6, 11)) + list(range(12, 17)) + list(range(18, 23)) +) + +optimizer.params.lr_factor_func = partial(get_vit_lr_decay_rate, lr_decay_rate=0.8, num_layers=24) diff --git a/projects/ViTDet/configs/COCO/mask_rcnn_vitdet_b_100ep.py b/projects/ViTDet/configs/COCO/mask_rcnn_vitdet_b_100ep.py new file mode 100644 index 0000000000..7206525f4b --- /dev/null +++ b/projects/ViTDet/configs/COCO/mask_rcnn_vitdet_b_100ep.py @@ -0,0 +1,38 @@ +from functools import partial +from fvcore.common.param_scheduler import MultiStepParamScheduler + +from detectron2 import model_zoo +from detectron2.config import LazyCall as L +from detectron2.solver import WarmupParamScheduler +from detectron2.modeling.backbone.vit import get_vit_lr_decay_rate + +from ..common.coco_loader_lsj import dataloader + + +model = model_zoo.get_config("common/models/mask_rcnn_vitdet.py").model + +# Initialization and trainer settings +train = model_zoo.get_config("common/train.py").train +train.amp.enabled = True +train.ddp.fp16_compression = True +train.init_checkpoint = "detectron2://ImageNetPretrained/MAE/mae_pretrain_vit_base.pth" + + +# Schedule +# 100 ep = 184375 iters * 64 images/iter / 118000 images/ep +train.max_iter = 184375 + 
+lr_multiplier = L(WarmupParamScheduler)( + scheduler=L(MultiStepParamScheduler)( + values=[1.0, 0.1, 0.01], + milestones=[163889, 177546], + num_updates=train.max_iter, + ), + warmup_length=250 / train.max_iter, + warmup_factor=0.001, +) + +# Optimizer +optimizer = model_zoo.get_config("common/optim.py").AdamW +optimizer.params.lr_factor_func = partial(get_vit_lr_decay_rate, num_layers=12, lr_decay_rate=0.7) +optimizer.params.overrides = {"pos_embed": {"weight_decay": 0.0}} diff --git a/projects/ViTDet/configs/COCO/mask_rcnn_vitdet_h_75ep.py b/projects/ViTDet/configs/COCO/mask_rcnn_vitdet_h_75ep.py new file mode 100644 index 0000000000..9fe752c611 --- /dev/null +++ b/projects/ViTDet/configs/COCO/mask_rcnn_vitdet_h_75ep.py @@ -0,0 +1,31 @@ +from functools import partial + +from .mask_rcnn_vitdet_b_100ep import ( + dataloader, + lr_multiplier, + model, + train, + optimizer, + get_vit_lr_decay_rate, +) + +train.init_checkpoint = "detectron2://ImageNetPretrained/MAE/mae_pretrain_vit_huge_p14to16.pth" + +model.backbone.net.embed_dim = 1280 +model.backbone.net.depth = 32 +model.backbone.net.num_heads = 16 +model.backbone.net.drop_path_rate = 0.5 +# 7, 15, 23, 31 for global attention +model.backbone.net.window_block_indexes = ( + list(range(0, 7)) + list(range(8, 15)) + list(range(16, 23)) + list(range(24, 31)) +) + +optimizer.params.lr_factor_func = partial(get_vit_lr_decay_rate, lr_decay_rate=0.9, num_layers=32) +optimizer.params.overrides = {} +optimizer.params.weight_decay_norm = None + +train.max_iter = train.max_iter * 3 // 4 # 100ep -> 75ep +lr_multiplier.scheduler.milestones = [ + milestone * 3 // 4 for milestone in lr_multiplier.scheduler.milestones +] +lr_multiplier.scheduler.num_updates = train.max_iter diff --git a/projects/ViTDet/configs/COCO/mask_rcnn_vitdet_l_100ep.py b/projects/ViTDet/configs/COCO/mask_rcnn_vitdet_l_100ep.py new file mode 100644 index 0000000000..933b84eb4e --- /dev/null +++ b/projects/ViTDet/configs/COCO/mask_rcnn_vitdet_l_100ep.py @@ -0,0 +1,23 @@ +from functools import partial + +from .mask_rcnn_vitdet_b_100ep import ( + dataloader, + lr_multiplier, + model, + train, + optimizer, + get_vit_lr_decay_rate, +) + +train.init_checkpoint = "detectron2://ImageNetPretrained/MAE/mae_pretrain_vit_large.pth" + +model.backbone.net.embed_dim = 1024 +model.backbone.net.depth = 24 +model.backbone.net.num_heads = 16 +model.backbone.net.drop_path_rate = 0.4 +# 5, 11, 17, 23 for global attention +model.backbone.net.window_block_indexes = ( + list(range(0, 5)) + list(range(6, 11)) + list(range(12, 17)) + list(range(18, 23)) +) + +optimizer.params.lr_factor_func = partial(get_vit_lr_decay_rate, lr_decay_rate=0.8, num_layers=24) diff --git a/projects/ViTDet/configs/common/coco_loader_lsj.py b/projects/ViTDet/configs/common/coco_loader_lsj.py new file mode 100644 index 0000000000..e6c2f1e913 --- /dev/null +++ b/projects/ViTDet/configs/common/coco_loader_lsj.py @@ -0,0 +1,22 @@ +import detectron2.data.transforms as T +from detectron2 import model_zoo +from detectron2.config import LazyCall as L + +# Data using LSJ +image_size = 1024 +dataloader = model_zoo.get_config("common/data/coco.py").dataloader +dataloader.train.mapper.augmentations = [ + L(T.RandomFlip)(horizontal=True), # flip first + L(T.ResizeScale)( + min_scale=0.1, max_scale=2.0, target_height=image_size, target_width=image_size + ), + L(T.FixedSizeCrop)(crop_size=(image_size, image_size), pad=False), +] +dataloader.train.mapper.image_format = "RGB" +dataloader.train.total_batch_size = 64 +# recompute boxes due to 
cropping +dataloader.train.mapper.recompute_boxes = True + +dataloader.test.mapper.augmentations = [ + L(T.ResizeShortestEdge)(short_edge_length=image_size, max_size=image_size), +]
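
To make the LSJ recipe above concrete, here is a hedged sketch (not part of the patch) of what the training-time augmentations do to one image; the dummy input is illustrative:

```python
import numpy as np
import detectron2.data.transforms as T

image_size = 1024
augs = T.AugmentationList(
    [
        T.RandomFlip(horizontal=True),
        T.ResizeScale(
            min_scale=0.1, max_scale=2.0, target_height=image_size, target_width=image_size
        ),
        T.FixedSizeCrop(crop_size=(image_size, image_size), pad=False),
    ]
)

aug_input = T.AugInput(np.zeros((480, 640, 3), dtype=np.uint8))
transforms = augs(aug_input)  # mutates aug_input; returns the applied TransformList
# With pad=False the result is at most 1024x1024; the batch is padded to a
# fixed square later via the backbone's padding_constraints (square_pad=1024).
print(aug_input.image.shape)
```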