Skip to content

Commit

Permalink
add a config for Mask2Former+BEiT-Adapter-L
Browse files Browse the repository at this point in the history
ade20k 160k iters without coco-stuff
  • Loading branch information
czczup committed Feb 7, 2023
1 parent f19aa7b commit 19ed23e
Show file tree
Hide file tree
Showing 2 changed files with 173 additions and 11 deletions.
31 changes: 20 additions & 11 deletions segmentation/configs/ade20k/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,21 @@ The ADE20K semantic segmentation dataset contains more than 20K scene-centric

## Results and Models

There are two training strategies for the ADE20K dataset:

- `strategy 1`: ADE20K 160k iterations.

- `strategy 2`: COCO-Stuff 80k + ADE20K 80k iterations.

In other words, if the filename contains `80k`, that means `strategy 2`.

**DeiT Pre-train (ImageNet-1K, supervised)**

| Method | Backbone | Pre-train | Batch Size | Lr schd | Crop Size | mIoU (SS) | mIoU (MS) | #Param | Config | Download |
|:-------:|:-------------:|:---------------------------------------------------------------------------------:|:----------:|:-------:|:---------:|:---------:|:---------:|:------:|:---------------------------------------------------------------:|:----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------:|
| UperNet | ViT-Adapter-T | [DeiT-T](https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth) | 8x2 | 160k | 512 | 42.6 | 43.6 | 36M | [config](./upernet_deit_adapter_tiny_512_160k_ade20k.py) | [model](https://github.com/czczup/ViT-Adapter/releases/download/v0.3.1/upernet_deit_adapter_tiny_512_160_ade20k.pth.tar) \| [log](https://drive.google.com/file/d/1wG_6iIaVirmqLGDZt_2rtzp_ZtNV2D4O/view?usp=sharing) |
| UperNet | ViT-Adapter-S | [DeiT-S](https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth) | 8x2 | 160k | 512 | 46.2 | 47.1 | 58M | [config](./upernet_deit_adapter_small_512_160k_ade20k.py) | [model](https://github.com/czczup/ViT-Adapter/releases/download/v0.3.1/upernet_deit_adapter_small_512_160k_ade20k.pth) \| [log]() |
| UperNet | ViT-Adapter-B | [DeiT-B](https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth) | 8x2 | 160k | 512 | 48.8 | 49.7 | 134M | [config](./upernet_deit_adapter_base_512_160k_ade20k.py) | [model](https://github.com/czczup/ViT-Adapter/releases/download/v0.3.1/upernet_deit_adapter_base_512_160k_ade20k.pth.tar) \| [log](https://drive.google.com/file/d/12xHSW7_VYnzQSNzGu2EPuh6BBuozWUTn/view?usp=sharing) |
| Method | Backbone | Pre-train | Batch Size | Lr schd | Crop Size | mIoU (SS) | mIoU (MS) | #Param | Config | Download |
|:-------:|:-------------:|:---------------------------------------------------------------------------------:|:----------:|:-------:|:---------:|:---------:|:---------:|:------:|:---------------------------------------------------------:|:----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------:|
| UperNet | ViT-Adapter-T | [DeiT-T](https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth) | 8x2 | 160k | 512 | 42.6 | 43.6 | 36M | [config](./upernet_deit_adapter_tiny_512_160k_ade20k.py) | [model](https://github.com/czczup/ViT-Adapter/releases/download/v0.3.1/upernet_deit_adapter_tiny_512_160_ade20k.pth.tar) \| [log](https://drive.google.com/file/d/1wG_6iIaVirmqLGDZt_2rtzp_ZtNV2D4O/view?usp=sharing) |
| UperNet | ViT-Adapter-S | [DeiT-S](https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth) | 8x2 | 160k | 512 | 46.2 | 47.1 | 58M | [config](./upernet_deit_adapter_small_512_160k_ade20k.py) | [model](https://github.com/czczup/ViT-Adapter/releases/download/v0.3.1/upernet_deit_adapter_small_512_160k_ade20k.pth) \| [log]() |
| UperNet | ViT-Adapter-B | [DeiT-B](https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth) | 8x2 | 160k | 512 | 48.8 | 49.7 | 134M | [config](./upernet_deit_adapter_base_512_160k_ade20k.py) | [model](https://github.com/czczup/ViT-Adapter/releases/download/v0.3.1/upernet_deit_adapter_base_512_160k_ade20k.pth.tar) \| [log](https://drive.google.com/file/d/12xHSW7_VYnzQSNzGu2EPuh6BBuozWUTn/view?usp=sharing) |

**AugReg Pre-train (ImageNet-22K, supervised)**

Expand All @@ -26,9 +34,10 @@ The ADE20K semantic segmentation dataset contains more than 20K scene-centric

**BEiT Pre-train (ImageNet-22K, MIM)**

| Method | Backbone | Pre-train | Batch Size | Lr schd | Crop Size | mIoU (SS) | mIoU (MS) | #Param | Config | Download |
|:-----------:|:-------------:|:------------------------------------------------------------------------------------------------------------------------------------------:|:----------:|:-------:|:---------:|:------------------------------------------------------------------------------------------:|:------------------------------------------------------------------------------------------:|:------:|:----------------------------------------------------------------:|:-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------:|
| UperNet | ViT-Adapter-L | [BEiT-L](https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_224_pt22k_ft22k.pth) | 8x2 | 160k | 640 | [58.0](https://drive.google.com/file/d/1KsV4QPfoRi5cj2hjCzy8VfWih8xCTrE3/view?usp=sharing) | [58.4](https://drive.google.com/file/d/1haeTUvQhKCM7hunVdK60yxULbRH7YYBK/view?usp=sharing) | 451M | [config](./upernet_beit_adapter_large_640_160k_ade20k_ss.py) | [model](https://github.com/czczup/ViT-Adapter/releases/download/v0.2.1/upernet_beit_adapter_large_640_160k_ade20k.pth.tar) \| [log](https://github.com/czczup/ViT-Adapter/releases/download/v0.2.1/20220313_233147.log) |
| Mask2Former | ViT-Adapter-L | [BEiT-L](https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_224_pt22k_ft22k.pth) | 8x2 | 160k | 640 | [58.3](https://drive.google.com/file/d/1jj56lSbc2s4ZNc-Hi-w6o-OSS99oi-_g/view?usp=sharing) | [59.0](https://drive.google.com/file/d/1hgpZB5gsyd7LTS7Aay2CbHmlY10nafCw/view?usp=sharing) | 568M | [config](./mask2former_beit_adapter_large_640_160k_ade20k_ss.py) | [model](https://github.com/czczup/ViT-Adapter/releases/download/v0.2.2/mask2former_beit_adapter_large_640_160k_ade20k.zip) \| [log](https://github.com/czczup/ViT-Adapter/releases/download/v0.2.2/20220426_003454.log) |
| Mask2Former | ViT-Adapter-L | [BEiT-L+COCO-Stuff](https://github.com/czczup/ViT-Adapter/releases/download/v0.2.6/mask2former_beit_adapter_large_896_80k_cocostuff164k.zip) | 16x1 | 80k | 896 | [59.4](https://drive.google.com/file/d/1B_1XSwdnLhjJeUmn1g_nxfvGJpYmYWHa/view?usp=sharing) | [60.5](https://drive.google.com/file/d/1UtjmgcYKR-2h116oQXklUYOVcTw15woM/view?usp=sharing) | 571M | [config](./mask2former_beit_adapter_large_896_80k_ade20k_ss.py) | [model](https://github.com/czczup/ViT-Adapter/releases/download/v0.2.0/mask2former_beit_adapter_large_896_80k_ade20k.zip) \| [log](https://github.com/czczup/ViT-Adapter/releases/download/v0.2.0/20220430_154104.log) |
| Mask2Former | ViT-Adapter-L | [BEiTv2-L+COCO-Stuff](https://github.com/czczup/ViT-Adapter/releases/download/v0.3.1/mask2former_beitv2_adapter_large_896_80k_cocostuff164k.zip) | 16x1 | 80k | 896 | 61.2 | 61.5 | 571M | [config](./mask2former_beitv2_adapter_large_896_80k_ade20k_ss.py) | [model](https://github.com/czczup/ViT-Adapter/releases/download/v0.3.1/mask2former_beitv2_adapter_large_896_80k_ade20k.zip) \| [log](https://github.com/czczup/ViT-Adapter/releases/download/v0.3.1/20220915_112635.log) |
| Method | Backbone | Pre-train | Batch Size | Lr schd | Crop Size | mIoU (SS) | mIoU (MS) | #Param | Config | Download |
|:-----------:|:-------------:|:------------------------------------------------------------------------------------------------------------------------------------------------:|:----------:|:-------:|:---------:|:------------------------------------------------------------------------------------------:|:------------------------------------------------------------------------------------------:|:------:|:------------------------------------------------------------------:|:------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------:|
| UperNet | ViT-Adapter-L | [BEiT-L](https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_224_pt22k_ft22k.pth) | 8x2 | 160k | 640 | [58.0](https://drive.google.com/file/d/1KsV4QPfoRi5cj2hjCzy8VfWih8xCTrE3/view?usp=sharing) | [58.4](https://drive.google.com/file/d/1haeTUvQhKCM7hunVdK60yxULbRH7YYBK/view?usp=sharing) | 451M | [config](./upernet_beit_adapter_large_640_160k_ade20k_ss.py) | [model](https://github.com/czczup/ViT-Adapter/releases/download/v0.2.1/upernet_beit_adapter_large_640_160k_ade20k.pth.tar) \| [log](https://github.com/czczup/ViT-Adapter/releases/download/v0.2.1/20220313_233147.log) |
| Mask2Former | ViT-Adapter-L | [BEiT-L](https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_224_pt22k_ft22k.pth) | 8x2 | 160k | 640 | [58.3](https://drive.google.com/file/d/1jj56lSbc2s4ZNc-Hi-w6o-OSS99oi-_g/view?usp=sharing) | [59.0](https://drive.google.com/file/d/1hgpZB5gsyd7LTS7Aay2CbHmlY10nafCw/view?usp=sharing) | 568M | [config](./mask2former_beit_adapter_large_640_160k_ade20k_ss.py) | [model](https://github.com/czczup/ViT-Adapter/releases/download/v0.2.2/mask2former_beit_adapter_large_640_160k_ade20k.zip) \| [log](https://github.com/czczup/ViT-Adapter/releases/download/v0.2.2/20220426_003454.log) |
| Mask2Former | ViT-Adapter-L | [BEiT-L+COCO-Stuff](https://github.com/czczup/ViT-Adapter/releases/download/v0.2.6/mask2former_beit_adapter_large_896_80k_cocostuff164k.zip) | 16x1 | 80k | 896 | [59.4](https://drive.google.com/file/d/1B_1XSwdnLhjJeUmn1g_nxfvGJpYmYWHa/view?usp=sharing) | [60.5](https://drive.google.com/file/d/1UtjmgcYKR-2h116oQXklUYOVcTw15woM/view?usp=sharing) | 571M | [config](./mask2former_beit_adapter_large_896_80k_ade20k_ss.py) | [model](https://github.com/czczup/ViT-Adapter/releases/download/v0.2.0/mask2former_beit_adapter_large_896_80k_ade20k.zip) \| [log](https://github.com/czczup/ViT-Adapter/releases/download/v0.2.0/20220430_154104.log) |
| Mask2Former | ViT-Adapter-L | BEiTv2 | 16x1 | 160k | 896 | - | - | 571M | [config](./mask2former_beitv2_adapter_large_896_160k_ade20k_ss.py) | - |
| Mask2Former | ViT-Adapter-L | [BEiTv2-L+COCO-Stuff](https://github.com/czczup/ViT-Adapter/releases/download/v0.3.1/mask2former_beitv2_adapter_large_896_80k_cocostuff164k.zip) | 16x1 | 80k | 896 | 61.2 | 61.5 | 571M | [config](./mask2former_beitv2_adapter_large_896_80k_ade20k_ss.py) | [model](https://github.com/czczup/ViT-Adapter/releases/download/v0.3.1/mask2former_beitv2_adapter_large_896_80k_ade20k.zip) \| [log](https://github.com/czczup/ViT-Adapter/releases/download/v0.3.1/20220915_112635.log) |
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
# Copyright (c) Shanghai AI Lab. All rights reserved.
_base_ = [
'../_base_/models/mask2former_beit.py',
'../_base_/datasets/ade20k.py',
'../_base_/default_runtime.py',
'../_base_/schedules/schedule_160k.py'
]
crop_size = (896, 896)
# pretrained = 'https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_large_patch16_224_pt1k_ft21k.pth'
pretrained = 'pretrained/beitv2_large_patch16_224_pt1k_ft21k.pth'
model = dict(
type='EncoderDecoderMask2Former',
pretrained=pretrained,
backbone=dict(
type='BEiTAdapter',
img_size=896,
patch_size=16,
embed_dim=1024,
depth=24,
num_heads=16,
mlp_ratio=4,
qkv_bias=True,
use_abs_pos_emb=False,
use_rel_pos_bias=True,
init_values=1e-6,
drop_path_rate=0.3,
conv_inplane=64,
n_points=4,
deform_num_heads=16,
cffn_ratio=0.25,
deform_ratio=0.5,
with_cp=True, # set with_cp=True to save memory
interaction_indexes=[[0, 5], [6, 11], [12, 17], [18, 23]],
),
decode_head=dict(
in_channels=[1024, 1024, 1024, 1024],
feat_channels=1024,
out_channels=1024,
num_queries=200,
pixel_decoder=dict(
type='MSDeformAttnPixelDecoder',
num_outs=3,
norm_cfg=dict(type='GN', num_groups=32),
act_cfg=dict(type='ReLU'),
encoder=dict(
type='DetrTransformerEncoder',
num_layers=6,
transformerlayers=dict(
type='BaseTransformerLayer',
attn_cfgs=dict(
type='MultiScaleDeformableAttention',
embed_dims=1024,
num_heads=32,
num_levels=3,
num_points=4,
im2col_step=64,
dropout=0.0,
batch_first=False,
norm_cfg=None,
init_cfg=None),
ffn_cfgs=dict(
type='FFN',
embed_dims=1024,
feedforward_channels=4096,
num_fcs=2,
ffn_drop=0.0,
with_cp=True, # set with_cp=True to save memory
act_cfg=dict(type='ReLU', inplace=True)),
operation_order=('self_attn', 'norm', 'ffn', 'norm')),
init_cfg=None),
positional_encoding=dict(
type='SinePositionalEncoding', num_feats=512, normalize=True),
init_cfg=None),
positional_encoding=dict(
type='SinePositionalEncoding', num_feats=512, normalize=True),
transformer_decoder=dict(
type='DetrTransformerDecoder',
return_intermediate=True,
num_layers=9,
transformerlayers=dict(
type='DetrTransformerDecoderLayer',
attn_cfgs=dict(
type='MultiheadAttention',
embed_dims=1024,
num_heads=32,
attn_drop=0.0,
proj_drop=0.0,
dropout_layer=None,
batch_first=False),
ffn_cfgs=dict(
embed_dims=1024,
feedforward_channels=4096,
num_fcs=2,
act_cfg=dict(type='ReLU', inplace=True),
ffn_drop=0.0,
dropout_layer=None,
with_cp=True, # set with_cp=True to save memory
add_identity=True),
feedforward_channels=4096,
operation_order=('cross_attn', 'norm', 'self_attn', 'norm',
'ffn', 'norm')),
init_cfg=None)
),
test_cfg=dict(mode='slide', crop_size=crop_size, stride=(512, 512))
)
# dataset settings
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', reduce_zero_label=True),
dict(type='Resize', img_scale=(3584, 896), ratio_range=(0.5, 2.0)),
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
dict(type='RandomFlip', prob=0.5),
dict(type='PhotoMetricDistortion'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
dict(type='ToMask'),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_semantic_seg', 'gt_masks', 'gt_labels'])
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(3584, 896),
# img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='ResizeToMultiple', size_divisor=32),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
])
]
optimizer = dict(_delete_=True, type='AdamW', lr=2e-5, betas=(0.9, 0.999), weight_decay=0.05,
constructor='LayerDecayOptimizerConstructor',
paramwise_cfg=dict(num_layers=24, layer_decay_rate=0.90))
lr_config = dict(_delete_=True,
policy='poly',
warmup='linear',
warmup_iters=1500,
warmup_ratio=1e-6,
power=1.0, min_lr=0.0, by_epoch=False)
data = dict(samples_per_gpu=1,
train=dict(pipeline=train_pipeline),
val=dict(pipeline=test_pipeline),
test=dict(pipeline=test_pipeline))
runner = dict(type='IterBasedRunner')
checkpoint_config = dict(by_epoch=False, interval=1000, max_keep_ckpts=1)
evaluation = dict(interval=8000, metric='mIoU', save_best='mIoU')

0 comments on commit 19ed23e

Please sign in to comment.