2 changes: 2 additions & 0 deletions annotator/uniformer/__init__.py
@@ -18,10 +18,12 @@ def __init__(self):
        if not os.path.exists(modelpath):
            from basicsr.utils.download_util import load_file_from_url
            load_file_from_url(checkpoint_file, model_dir=annotator_ckpts_path)
        # config.py (ADE20K by default)
        config_file = os.path.join(os.path.dirname(annotator_ckpts_path), "uniformer", "exp", "upernet_global_small", "config.py")
        self.model = init_segmentor(config_file, modelpath).cuda()

    def __call__(self, img):
        result = inference_segmentor(self.model, img)
        # palette: cityscapes, ade, voc
        res_img = show_result_pyplot(self.model, img, result, get_palette('ade'), opacity=1)
        return res_img
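A minimal usage sketch of the annotator above (not part of the diff). It assumes the class shown here is exported as UniformerDetector, as in upstream ControlNet, and the input path is illustrative:

import cv2
from annotator.uniformer import UniformerDetector

segmentor = UniformerDetector()  # downloads the checkpoint on first use
img = cv2.imread('input.png')[:, :, ::-1]  # hypothetical input image, BGR -> RGB
seg_map = segmentor(img)  # colorized segmentation rendered with the ADE20K palette
cv2.imwrite('seg.png', seg_map[:, :, ::-1])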
@@ -0,0 +1,5 @@
_base_ = ['./mask2former_r50_8xb2-90k_cityscapes-512x1024.py']

model = dict(
backbone=dict(
depth=101))
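For context, a sketch of how this _base_ inheritance resolves, assuming the standard mmengine config loader: loading the R-101 file yields the full R-50 config below with only backbone.depth overridden. The file path is illustrative, since the new file's name is not shown in this diff.

from mmengine.config import Config

# Illustrative path; the actual file name of this config is not shown in the diff.
cfg = Config.fromfile('mask2former_r101_8xb2-90k_cityscapes-512x1024.py')
print(cfg.model.backbone.type)   # 'ResNetV1c', inherited from the r50 base config
print(cfg.model.backbone.depth)  # 101, overridden by this file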
@@ -0,0 +1,183 @@
_base_ = ['../../configs/_base_/default_runtime.py', '../../configs/_base_/datasets/cityscapes.py']

crop_size = (512, 1024)
num_classes = 19
model = dict(
type='EncoderDecoder',
backbone=dict(
type='ResNetV1c',
depth=50,
num_stages=4,
out_indices=(0, 1, 2, 3),
frozen_stages=-1,
norm_cfg=dict(type='SyncBN', requires_grad=False),
style='pytorch'),
decode_head=dict(
type='Mask2FormerHead',
in_channels=[256, 512, 1024, 2048],
strides=[4, 8, 16, 32],
feat_channels=256,
out_channels=256,
num_classes=num_classes,
num_queries=100,
num_transformer_feat_level=3,
align_corners=False,
pixel_decoder=dict(
type='mmdet.MSDeformAttnPixelDecoder',
num_outs=3,
norm_cfg=dict(type='GN', num_groups=32),
act_cfg=dict(type='ReLU'),
encoder=dict( # DeformableDetrTransformerEncoder
num_layers=6,
layer_cfg=dict( # DeformableDetrTransformerEncoderLayer
self_attn_cfg=dict( # MultiScaleDeformableAttention
embed_dims=256,
num_heads=8,
num_levels=3,
num_points=4,
im2col_step=64,
dropout=0.0,
batch_first=True,
norm_cfg=None,
init_cfg=None),
ffn_cfg=dict(
embed_dims=256,
feedforward_channels=1024,
num_fcs=2,
ffn_drop=0.0,
act_cfg=dict(type='ReLU', inplace=True))),
init_cfg=None),
positional_encoding=dict( # SinePositionalEncoding
num_feats=128, normalize=True),
init_cfg=None),
enforce_decoder_input_project=False,
positional_encoding=dict( # SinePositionalEncoding
num_feats=128, normalize=True),
transformer_decoder=dict( # Mask2FormerTransformerDecoder
return_intermediate=True,
num_layers=9,
layer_cfg=dict( # Mask2FormerTransformerDecoderLayer
self_attn_cfg=dict( # MultiheadAttention
embed_dims=256,
num_heads=8,
attn_drop=0.0,
proj_drop=0.0,
dropout_layer=None,
batch_first=True),
cross_attn_cfg=dict( # MultiheadAttention
embed_dims=256,
num_heads=8,
attn_drop=0.0,
proj_drop=0.0,
dropout_layer=None,
batch_first=True),
ffn_cfg=dict(
embed_dims=256,
feedforward_channels=2048,
num_fcs=2,
act_cfg=dict(type='ReLU', inplace=True),
ffn_drop=0.0,
dropout_layer=None,
add_identity=True)),
init_cfg=None),
loss_cls=dict(
type='mmdet.CrossEntropyLoss',
use_sigmoid=False,
loss_weight=2.0,
reduction='mean',
class_weight=[1.0] * num_classes + [0.1]),
loss_mask=dict(
type='mmdet.CrossEntropyLoss',
use_sigmoid=True,
reduction='mean',
loss_weight=5.0),
loss_dice=dict(
type='mmdet.DiceLoss',
use_sigmoid=True,
activate=True,
reduction='mean',
naive_dice=True,
eps=1.0,
loss_weight=5.0),
train_cfg=dict(
num_points=12544,
oversample_ratio=3.0,
importance_sample_ratio=0.75,
assigner=dict(
type='mmdet.HungarianAssigner',
match_costs=[
dict(type='mmdet.ClassificationCost', weight=2.0),
dict(
type='mmdet.CrossEntropyLossCost',
weight=5.0,
use_sigmoid=True),
dict(
type='mmdet.DiceCost',
weight=5.0,
pred_act=True,
eps=1.0)
]),
sampler=dict(type='mmdet.MaskPseudoSampler'))),
train_cfg=dict())

# dataset config
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations'),
dict(
type='RandomChoiceResize',
scales=[int(1024 * x * 0.1) for x in range(5, 21)],
resize_type='ResizeShortestEdge',
max_size=4096),
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
dict(type='RandomFlip', prob=0.5),
dict(type='PhotoMetricDistortion'),
dict(type='PackSegInputs')
]
train_dataloader = dict(dataset=dict(pipeline=train_pipeline))

# optimizer
embed_multi = dict(lr_mult=1.0, decay_mult=0.0)
optimizer = dict(
type='AdamW', lr=0.0001, weight_decay=0.05, eps=1e-8, betas=(0.9, 0.999))
optim_wrapper = dict(
type='OptimWrapper',
optimizer=optimizer,
clip_grad=dict(max_norm=0.01, norm_type=2),
paramwise_cfg=dict(
custom_keys={
'backbone': dict(lr_mult=0.1, decay_mult=1.0),
'query_embed': embed_multi,
'query_feat': embed_multi,
'level_embed': embed_multi,
},
norm_decay_mult=0.0))
# learning policy
param_scheduler = [
dict(
type='PolyLR',
eta_min=0,
power=0.9,
begin=0,
end=90000,
by_epoch=False)
]

# training schedule for 90k
train_cfg = dict(type='IterBasedTrainLoop', max_iters=90000, val_interval=5000)
val_cfg = dict(type='ValLoop')
default_hooks = dict(
timer=dict(type='IterTimerHook'),
logger=dict(type='LoggerHook', interval=50, log_metric_by_epoch=False),
param_scheduler=dict(type='ParamSchedulerHook'),
checkpoint=dict(
type='CheckpointHook', by_epoch=False, interval=5000,
save_best='mIoU'),
sampler_seed=dict(type='DistSamplerSeedHook'),
visualization=dict(type='SegVisualizationHook'))

# Default setting for scaling LR automatically
# - `enable` means enable scaling LR automatically
# or not by default.
# - `base_batch_size` = (8 GPUs) x (2 samples per GPU).
auto_scale_lr = dict(enable=False, base_batch_size=16)
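A side note on the last line: under mmengine's linear scaling rule (applied only when enable=True), the learning rate is rescaled by the ratio of the actual total batch size to base_batch_size. A small worked example with a hypothetical 4-GPU setup:

base_lr = 1e-4              # AdamW lr defined above
base_batch_size = 16        # 8 GPUs x 2 samples per GPU
actual_batch_size = 4 * 2   # hypothetical: 4 GPUs x 2 samples per GPU
scaled_lr = base_lr * actual_batch_size / base_batch_size
print(scaled_lr)            # 5e-05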
@@ -0,0 +1,42 @@
_base_ = ['./mask2former_swin-t_8xb2-90k_cityscapes-512x1024.py']
pretrained = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_large_patch4_window12_384_22k_20220412-6580f57d.pth' # noqa

depths = [2, 2, 18, 2]
model = dict(
backbone=dict(
pretrain_img_size=384,
embed_dims=192,
depths=depths,
num_heads=[6, 12, 24, 48],
window_size=12,
init_cfg=dict(type='Pretrained', checkpoint=pretrained)),
decode_head=dict(in_channels=[192, 384, 768, 1536]))

# set all layers in the backbone to lr_mult=0.1
# set all norm layers, position embedding,
# query embedding and level embedding to decay_mult=0.0
backbone_norm_multi = dict(lr_mult=0.1, decay_mult=0.0)
backbone_embed_multi = dict(lr_mult=0.1, decay_mult=0.0)
embed_multi = dict(lr_mult=1.0, decay_mult=0.0)
custom_keys = {
'backbone': dict(lr_mult=0.1, decay_mult=1.0),
'backbone.patch_embed.norm': backbone_norm_multi,
'backbone.norm': backbone_norm_multi,
'absolute_pos_embed': backbone_embed_multi,
'relative_position_bias_table': backbone_embed_multi,
'query_embed': embed_multi,
'query_feat': embed_multi,
'level_embed': embed_multi
}
custom_keys.update({
f'backbone.stages.{stage_id}.blocks.{block_id}.norm': backbone_norm_multi
for stage_id, num_blocks in enumerate(depths)
for block_id in range(num_blocks)
})
custom_keys.update({
f'backbone.stages.{stage_id}.downsample.norm': backbone_norm_multi
for stage_id in range(len(depths) - 1)
})
# optimizer
optim_wrapper = dict(
paramwise_cfg=dict(custom_keys=custom_keys, norm_decay_mult=0.0))
@@ -0,0 +1,52 @@
_base_ = ['./mask2former_r50_8xb2-90k_cityscapes-512x1024.py']
pretrained = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_tiny_patch4_window7_224_20220317-1cdeb081.pth' # noqa
depths = [2, 2, 6, 2]
model = dict(
backbone=dict(
_delete_=True,
type='SwinTransformer',
embed_dims=96,
depths=depths,
num_heads=[3, 6, 12, 24],
window_size=7,
mlp_ratio=4,
qkv_bias=True,
qk_scale=None,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.3,
patch_norm=True,
out_indices=(0, 1, 2, 3),
with_cp=False,
frozen_stages=-1,
init_cfg=dict(type='Pretrained', checkpoint=pretrained)),
decode_head=dict(in_channels=[96, 192, 384, 768]))

# set all layers in the backbone to lr_mult=0.1
# set all norm layers, position embedding,
# query embedding and level embedding to decay_mult=0.0
backbone_norm_multi = dict(lr_mult=0.1, decay_mult=0.0)
backbone_embed_multi = dict(lr_mult=0.1, decay_mult=0.0)
embed_multi = dict(lr_mult=1.0, decay_mult=0.0)
custom_keys = {
'backbone': dict(lr_mult=0.1, decay_mult=1.0),
'backbone.patch_embed.norm': backbone_norm_multi,
'backbone.norm': backbone_norm_multi,
'absolute_pos_embed': backbone_embed_multi,
'relative_position_bias_table': backbone_embed_multi,
'query_embed': embed_multi,
'query_feat': embed_multi,
'level_embed': embed_multi
}
custom_keys.update({
f'backbone.stages.{stage_id}.blocks.{block_id}.norm': backbone_norm_multi
for stage_id, num_blocks in enumerate(depths)
for block_id in range(num_blocks)
})
custom_keys.update({
f'backbone.stages.{stage_id}.downsample.norm': backbone_norm_multi
for stage_id in range(len(depths) - 1)
})
# optimizer
optim_wrapper = dict(
paramwise_cfg=dict(custom_keys=custom_keys, norm_decay_mult=0.0))
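To make the two dict comprehensions above concrete, here is a small sketch that expands the generated keys for the Swin-T depths [2, 2, 6, 2]: twelve per-block norm keys plus the downsample norms of the first three stages, all mapped to lr_mult=0.1, decay_mult=0.0.

depths = [2, 2, 6, 2]
block_norm_keys = [
    f'backbone.stages.{stage_id}.blocks.{block_id}.norm'
    for stage_id, num_blocks in enumerate(depths)
    for block_id in range(num_blocks)
]
downsample_norm_keys = [
    f'backbone.stages.{stage_id}.downsample.norm'
    for stage_id in range(len(depths) - 1)
]
print(len(block_norm_keys))    # 12 (2 + 2 + 6 + 2)
print(downsample_norm_keys)    # stages 0, 1 and 2 only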
11 changes: 11 additions & 0 deletions cldm/cldm.py
@@ -331,10 +331,21 @@ def apply_model(self, x_noisy, t, cond, *args, **kwargs):

        cond_txt = torch.cat(cond['c_crossattn'], 1)

        # NOTE: disabled alternative: build a separate "obstruction" prompt embedding for the control model
        # obstruction_txt = None
        # if cond['obstruction_c_crossattn'] is not None:
        #     obstruction_txt = torch.cat(cond['obstruction_c_crossattn'], 1)

        if cond['c_concat'] is None:
            eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=None, only_mid_control=self.only_mid_control)
        else:
            # NOTE: the disabled branch below would condition the control model on obstruction_txt instead of cond_txt
            control = self.control_model(x=x_noisy, hint=torch.cat(cond['c_concat'], 1), timesteps=t, context=cond_txt)
            # if obstruction_txt is not None:
            #     control = self.control_model(x=x_noisy, hint=torch.cat(cond['c_concat'], 1), timesteps=t, context=obstruction_txt)
            # else:
            #     control = self.control_model(x=x_noisy, hint=torch.cat(cond['c_concat'], 1), timesteps=t, context=cond_txt)

            control = [c * scale for c, scale in zip(control, self.control_scales)]
            eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=control, only_mid_control=self.only_mid_control)
