Commit

add README.md for chase_db1 and potsdam
czczup committed Dec 21, 2022
1 parent ef394ed commit b582ded
Showing 8 changed files with 356 additions and 15 deletions.
3 changes: 3 additions & 0 deletions .gitignore
@@ -126,3 +126,6 @@ dmypy.json
.DS_Store

detection/visual/
trash/
setr/
swin/
2 changes: 2 additions & 0 deletions README.md
@@ -9,6 +9,8 @@

The official implementation of the paper "[Vision Transformer Adapter for Dense Predictions](https://arxiv.org/abs/2205.08534)".

https://user-images.githubusercontent.com/23737120/208140362-f2029060-eb16-4280-b85f-074006547a12.mp4

## News

(2022/10/20) ViT-Adapter is adopted by Zhang et al. and ranked 1st in the [UVO Challenge 2022](https://arxiv.org/pdf/2210.09629.pdf).\
2 changes: 1 addition & 1 deletion segmentation/README.md
@@ -41,7 +41,7 @@ Preparing ADE20K/Cityscapes/COCO Stuff/Pascal Context according to the [guideli
| DeiT | 2021 | Supervised | ImageNet-1K | [repo](https://github.com/facebookresearch/deit/blob/main/README_deit.md) | [paper](https://arxiv.org/abs/2012.12877) |
| AugReg | 2021 | Supervised | ImageNet-22K | [repo](https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py) | [paper](https://arxiv.org/abs/2106.10270) |
| BEiT | 2021 | MIM | ImageNet-22K | [repo](https://github.com/microsoft/unilm/tree/master/beit) | [paper](https://arxiv.org/abs/2106.08254) |
- | Uni-Perceiver | Supervised | 2022 | Multi-Modal | [repo](https://github.com/fundamentalvision/Uni-Perceiver) | [paper](https://openaccess.thecvf.com/content/CVPR2022/papers/Zhu_Uni-Perceiver_Pre-Training_Unified_Architecture_for_Generic_Perception_for_Zero-Shot_and_CVPR_2022_paper.pdf) |
+ | Uni-Perceiver | 2022 | Supervised | Multi-Modal | [repo](https://github.com/fundamentalvision/Uni-Perceiver) | [paper](https://openaccess.thecvf.com/content/CVPR2022/papers/Zhu_Uni-Perceiver_Pre-Training_Unified_Architecture_for_Generic_Perception_for_Zero-Shot_and_CVPR_2022_paper.pdf) |
| BEiTv2 | 2022 | MIM | ImageNet-22K | [repo](https://github.com/microsoft/unilm/tree/master/beit2) | [paper](https://arxiv.org/abs/2208.06366) |

## Results and Models
138 changes: 138 additions & 0 deletions segmentation/configs/_base_/models/mask2former_beit_chase_db1.py
@@ -0,0 +1,138 @@
# model_cfg
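# CHASE_DB1 is binary retinal-vessel segmentation: no "thing" (instance)
# classes, two "stuff" classes (background, vessel)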
num_things_classes = 0
num_stuff_classes = 2
num_classes = num_things_classes + num_stuff_classes
norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
type='EncoderDecoderMask2Former',
pretrained=None,
backbone=dict(
type='BEiT',
patch_size=16,
embed_dim=384,
depth=12,
num_heads=8,
mlp_ratio=4,
qkv_bias=True,
use_abs_pos_emb=True,
use_rel_pos_bias=False,
),
decode_head=dict(
type='Mask2FormerHead',
in_channels=[256, 512, 1024, 2048], # pass to pixel_decoder inside
# strides=[4, 8, 16, 32],
feat_channels=256,
out_channels=256,
in_index=[0, 1, 2, 3],
num_things_classes=num_things_classes,
num_stuff_classes=num_stuff_classes,
num_queries=100,
num_transformer_feat_level=3,
pixel_decoder=dict(
type='MSDeformAttnPixelDecoder',
num_outs=3,
norm_cfg=dict(type='GN', num_groups=32),
act_cfg=dict(type='ReLU'),
encoder=dict(
type='DetrTransformerEncoder',
num_layers=6,
transformerlayers=dict(
type='BaseTransformerLayer',
attn_cfgs=dict(
type='MultiScaleDeformableAttention',
embed_dims=256,
num_heads=8,
num_levels=3,
num_points=4,
im2col_step=64,
dropout=0.0,
batch_first=False,
norm_cfg=None,
init_cfg=None),
ffn_cfgs=dict(
type='FFN',
embed_dims=256,
feedforward_channels=1024,
num_fcs=2,
ffn_drop=0.0,
act_cfg=dict(type='ReLU', inplace=True)),
operation_order=('self_attn', 'norm', 'ffn', 'norm')),
init_cfg=None),
positional_encoding=dict(
type='SinePositionalEncoding', num_feats=128, normalize=True),
init_cfg=None),
enforce_decoder_input_project=False,
positional_encoding=dict(
type='SinePositionalEncoding', num_feats=128, normalize=True),
transformer_decoder=dict(
type='DetrTransformerDecoder',
return_intermediate=True,
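# 9 decoder layers = 3 rounds over the 3 multi-scale feature levels, as in the Mask2Former paper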
num_layers=9,
transformerlayers=dict(
type='DetrTransformerDecoderLayer',
attn_cfgs=dict(
type='MultiheadAttention',
embed_dims=256,
num_heads=8,
attn_drop=0.0,
proj_drop=0.0,
dropout_layer=None,
batch_first=False),
ffn_cfgs=dict(
embed_dims=256,
feedforward_channels=2048,
num_fcs=2,
act_cfg=dict(type='ReLU', inplace=True),
ffn_drop=0.0,
dropout_layer=None,
add_identity=True),
feedforward_channels=2048,
operation_order=('cross_attn', 'norm', 'self_attn', 'norm',
'ffn', 'norm')),
init_cfg=None),
loss_cls=dict(
type='CrossEntropyLoss',
use_sigmoid=False,
loss_weight=2.0,
reduction='mean',
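# one weight per real class, plus a down-weighted (0.1) entry for Mask2Former's "no object" class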
class_weight=[1.0] * num_classes + [0.1]),
loss_mask=dict(
type='CrossEntropyLoss',
use_sigmoid=True,
reduction='mean',
loss_weight=5.0),
loss_dice=dict(
type='DiceLoss',
use_sigmoid=True,
activate=True,
reduction='mean',
naive_dice=True,
eps=1.0,
loss_weight=5.0)),
train_cfg=dict(
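# following Mask2Former/PointRend, mask losses are evaluated on 12544 (= 112*112)
# sampled points per mask rather than on full-resolution masks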
num_points=12544,
oversample_ratio=3.0,
importance_sample_ratio=0.75,
assigner=dict(
type='MaskHungarianAssigner',
cls_cost=dict(type='ClassificationCost', weight=2.0),
mask_cost=dict(
type='CrossEntropyLossCost', weight=5.0, use_sigmoid=True),
dice_cost=dict(
type='DiceCost', weight=5.0, pred_act=True, eps=1.0)),
sampler=dict(type='MaskPseudoSampler')),
test_cfg=dict(
panoptic_on=True,
# For now, the dataset does not support
# evaluating semantic segmentation metric.
semantic_on=False,
instance_on=True,
# max_per_image is for instance segmentation.
max_per_image=100,
iou_thr=0.8,
# In Mask2Former's panoptic postprocessing,
# it will filter mask area where score is less than 0.5 .
filter_low_score=True),
init_cfg=None)

# find_unused_parameters = True
21 changes: 21 additions & 0 deletions segmentation/configs/chase_db1/README.md
@@ -0,0 +1,21 @@
# CHASE DB1

<!-- [ALGORITHM] -->

## Introduction

The training and validation sets of CHASE DB1 can be downloaded from [here](https://staffnet.kingston.ac.uk/~ku15565/CHASE_DB1/assets/CHASEDB1.zip).

To convert the CHASE DB1 dataset to MMSegmentation format, run the official MMSegmentation [conversion script](https://github.com/open-mmlab/mmsegmentation/blob/master/tools/convert_datasets/chase_db1.py):

```shell
python /path/to/convertor/chase_db1.py /path/to/CHASEDB1.zip
```

The script will create the required directory structure automatically.
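After conversion, the dataset is expected under `data/CHASE_DB1` in the standard MMSegmentation layout, roughly as sketched below (based on the mmseg dataset docs; verify against your mmseg version):

```
CHASE_DB1
├── images
│   ├── training
│   └── validation
└── annotations
    ├── training
    └── validation
```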

## Results and Models

| Method | Backbone | Pre-train | Batch Size | Lr schd | Crop Size | mDice | #Param | Config | Download |
|:-----------:|:-------------:|:---------:|:----------:|:-------:|:---------:|:---------:|:------:|:----------------------------------------------------------------:|:------------------------------------------------------:|
| Mask2Former | ViT-Adapter-L | BEiT-L | 4x4 | 40k | 128 | 89.4 | 350M | [config](./mask2former_beit_adapter_large_128_40k_chase_db1_ss.py) | [log](https://github.com/czczup/ViT-Adapter/issues/11) |
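As a usage sketch (assuming the mmseg-style `dist_train.sh` launcher in the `segmentation/` directory of this repo), a 4-GPU training run would look like:

```shell
# 4 GPUs x 4 images per GPU gives the 4x4 batch size reported above
bash dist_train.sh configs/chase_db1/mask2former_beit_adapter_large_128_40k_chase_db1_ss.py 4
```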
154 changes: 154 additions & 0 deletions segmentation/configs/chase_db1/mask2former_beit_adapter_large_128_40k_chase_db1_ss.py
@@ -0,0 +1,154 @@
# Copyright (c) Shanghai AI Lab. All rights reserved.
_base_ = [
'../_base_/models/mask2former_beit_chase_db1.py',
'../_base_/datasets/chase_db1.py',
'../_base_/default_runtime.py',
'../_base_/schedules/schedule_40k.py'
]
crop_size = (128, 128)
img_scale = (960, 999)
# pretrained = 'https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_224_pt22k_ft22k.pth'
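# download the BEiT-L checkpoint from the URL above and place it under pretrained/ before training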
pretrained = 'pretrained/beit_large_patch16_224_pt22k_ft22k.pth'
model = dict(
type='EncoderDecoderMask2Former',
pretrained=pretrained,
backbone=dict(
type='BEiTAdapter',
img_size=crop_size[0],
patch_size=16,
embed_dim=1024,
depth=24,
num_heads=16,
mlp_ratio=4,
qkv_bias=True,
use_abs_pos_emb=False,
use_rel_pos_bias=True,
init_values=1e-6,
drop_path_rate=0.3,
conv_inplane=64,
n_points=4,
deform_num_heads=16,
cffn_ratio=0.25,
deform_ratio=0.5,
with_cp=True, # set with_cp=True to save memory
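# the 24 BEiT blocks are split into four groups of six (see interaction_indexes
# below); the adapter injects and extracts features at each group boundary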
interaction_indexes=[[0, 5], [6, 11], [12, 17], [18, 23]],
),
decode_head=dict(
in_channels=[1024, 1024, 1024, 1024],
feat_channels=256,
out_channels=256,
num_queries=100,
pixel_decoder=dict(
type='MSDeformAttnPixelDecoder',
num_outs=3,
norm_cfg=dict(type='GN', num_groups=32),
act_cfg=dict(type='ReLU'),
encoder=dict(
type='DetrTransformerEncoder',
num_layers=6,
transformerlayers=dict(
type='BaseTransformerLayer',
attn_cfgs=dict(
type='MultiScaleDeformableAttention',
embed_dims=256,
num_heads=8,
num_levels=3,
num_points=4,
im2col_step=64,
dropout=0.0,
batch_first=False,
norm_cfg=None,
init_cfg=None),
ffn_cfgs=dict(
type='FFN',
embed_dims=256,
feedforward_channels=2048,
num_fcs=2,
ffn_drop=0.0,
with_cp=True, # set with_cp=True to save memory
act_cfg=dict(type='ReLU', inplace=True)),
operation_order=('self_attn', 'norm', 'ffn', 'norm')),
init_cfg=None),
positional_encoding=dict(
type='SinePositionalEncoding', num_feats=128, normalize=True),
init_cfg=None),
positional_encoding=dict(
type='SinePositionalEncoding', num_feats=128, normalize=True),
transformer_decoder=dict(
type='DetrTransformerDecoder',
return_intermediate=True,
num_layers=9,
transformerlayers=dict(
type='DetrTransformerDecoderLayer',
attn_cfgs=dict(
type='MultiheadAttention',
embed_dims=256,
num_heads=8,
attn_drop=0.0,
proj_drop=0.0,
dropout_layer=None,
batch_first=False),
ffn_cfgs=dict(
embed_dims=256,
feedforward_channels=2048,
num_fcs=2,
act_cfg=dict(type='ReLU', inplace=True),
ffn_drop=0.0,
dropout_layer=None,
with_cp=True, # set with_cp=True to save memory
add_identity=True),
feedforward_channels=2048,
operation_order=('cross_attn', 'norm', 'self_attn', 'norm',
'ffn', 'norm')),
init_cfg=None)
),
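# CHASE_DB1 images (999x960) are far larger than the 128x128 training crops, so inference
# slides a 128x128 window with stride 85, leaving a 43-pixel overlap between adjacent windows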
test_cfg=dict(mode='slide', crop_size=crop_size, stride=(85, 85))
)
# dataset settings
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations'),
dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)),
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
dict(type='RandomFlip', prob=0.5),
dict(type='PhotoMetricDistortion'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
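# ToMask converts the semantic map into the per-class binary masks
# (gt_masks/gt_labels) that Mask2Former's mask losses consume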
dict(type='ToMask'),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_semantic_seg', 'gt_masks', 'gt_labels'])
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=img_scale,
# img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='ResizeToMultiple', size_divisor=32),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
])
]
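# layer-wise lr decay for the 24-block BEiT backbone: a block k levels below the top
# trains at lr * 0.9^k, applied per parameter group by LayerDecayOptimizerConstructor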
optimizer = dict(_delete_=True, type='AdamW', lr=2e-5, betas=(0.9, 0.999), weight_decay=0.05,
constructor='LayerDecayOptimizerConstructor',
paramwise_cfg=dict(num_layers=24, layer_decay_rate=0.90))
lr_config = dict(_delete_=True,
policy='poly',
warmup='linear',
warmup_iters=1500,
warmup_ratio=1e-6,
power=1.0, min_lr=0.0, by_epoch=False)
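# poly schedule: after a 1500-iter linear warmup, lr decays as base_lr * (1 - iter/max_iter)^1.0 toward min_lr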
data = dict(samples_per_gpu=4,
train=dict(dataset=dict(pipeline=train_pipeline)),
val=dict(pipeline=test_pipeline),
test=dict(pipeline=test_pipeline))
runner = dict(type='IterBasedRunner')
checkpoint_config = dict(by_epoch=False, interval=1000, max_keep_ckpts=1)
evaluation = dict(interval=4000, metric='mDice', save_best='mDice')
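For reference, evaluating a trained checkpoint with this config would typically use the matching mmseg-style test launcher (a sketch; the script name and checkpoint path are assumptions):

```shell
# single-scale sliding-window evaluation on 4 GPUs; the .pth path is a placeholder
bash dist_test.sh configs/chase_db1/mask2former_beit_adapter_large_128_40k_chase_db1_ss.py \
    work_dirs/mask2former_beit_adapter_large_128_40k_chase_db1_ss/iter_40000.pth 4 --eval mDice
```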
23 changes: 23 additions & 0 deletions segmentation/configs/potsdam/README.md
@@ -0,0 +1,23 @@
# ISPRS Potsdam

<!-- [ALGORITHM] -->

## Introduction

The Potsdam dataset targets urban semantic segmentation and was used in the ISPRS 2D Semantic Labeling Contest - Potsdam.

The dataset can be requested at the challenge [homepage](https://www2.isprs.org/commissions/comm2/wg4/benchmark/data-request-form/). The `2_Ortho_RGB.zip` and `5_Labels_all_noBoundary.zip` are required.

For the Potsdam dataset, run the official MMSegmentation [conversion script](https://github.com/open-mmlab/mmsegmentation/blob/master/tools/convert_datasets/potsdam.py) to re-organize the downloaded files.

```shell
python /path/to/convertor/potsdam.py /path/to/potsdam
```

In the default setting, it will generate 3456 images for training and 2016 images for validation.
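After conversion, the expected layout (sketched from the mmseg dataset docs; verify against your mmseg version) is roughly:

```
potsdam
├── img_dir
│   ├── train
│   └── val
└── ann_dir
    ├── train
    └── val
```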

## Results and Models

| Method | Backbone | Pre-train | Batch Size | Lr schd | Crop Size | mIoU (SS) | #Param | Config | Download |
|:-----------:|:-------------:|:---------:|:----------:|:-------:|:---------:|:---------:|:------:|:----------------------------------------------------------------:|:------------------------------------------------------:|
| Mask2Former | ViT-Adapter-L | BEiT-L | 8x1 | 80k | 512 | 80.0 | 352M | [config](./mask2former_beit_adapter_large_512_80k_potsdam_ss.py) | [log](https://github.com/czczup/ViT-Adapter/issues/38) |