Skip to content

Add SwinTransformer2D,3D #89

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 22 additions & 1 deletion connectomics/config/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
# -----------------------------------------------------------------------------
_C.SYSTEM = CN()

_C.SYSTEM.NUM_GPUS = 4
_C.SYSTEM.NUM_GPUS = 1
_C.SYSTEM.NUM_CPUS = 4
# Run distributed training using DistributedDataparallel model
_C.SYSTEM.DISTRIBUTED = False
Expand Down Expand Up @@ -109,6 +109,27 @@
# Predict an auxiliary output (only works with 2D DeeplabV3)
_C.MODEL.AUX_OUT = False

## EXCLUSIVE TO SWINTRANSFORMERS

_C.MODEL.PATCH_SIZE = (4,4,4)
_C.MODEL.DEPTHS = [2,2,2,2,2]
_C.MODEL.NUM_HEADS = [3,6,12,24,24]
_C.MODEL.WINDOW_SIIE = (2,7,7)
_C.MODEL.MLP_RATIO = 4.
_C.MODEL.QKV_BIAS = True
_C.MODEL.QK_SCALE = None
_C.MODEL.DROP_RATE = 0.
_C.MODEL.ATTN_DROP_RATE = 0.
_C.MODEL.DROP_PATH_RATE = 0.2
_C.MODEL.USE_CONV = False
_C.MODEL.PATCH_NORM = False
_C.MODEL.FROZEN_STAGES = -1
_C.MODEL.USE_CHECKPOINT = False
_C.MODEL.EMBED_DIM = 96
_C.MODEL.DOWNSAMPLE_BEFORE = [True, True, True, True]
_C.MODEL.SWIN_ISOTROPY = [True, True, True, True]


# -----------------------------------------------------------------------------
# Dataset
# -----------------------------------------------------------------------------
Expand Down
19 changes: 15 additions & 4 deletions connectomics/model/arch/fpn.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def __init__(self,
self.filters = filters
self.depth = len(filters)

assert len(isotropy) == self.depth
# assert len(isotropy) == self.depth
if is_isotropic:
isotropy = [True] * self.depth
self.isotropy = isotropy
Expand All @@ -77,21 +77,26 @@ def __init__(self,
'attention': attn,
}
backbone_kwargs.update(self.shared_kwargs)
self.is_swin = False
if backbone_type == 'swintransformer3d':
backbone_kwargs.update(kwargs)
self.shared_kwargs['norm_mode'] = 'layer'
self.is_swin = True

self.backbone = build_backbone(
backbone_type, feature_keys, **backbone_kwargs)
self.feature_keys = feature_keys

self.latplanes = filters[0]
self.latlayers = nn.ModuleList([
conv3d_norm_act(x, self.latplanes, kernel_size=1, padding=0,
conv3d_norm_act(x, self.latplanes, kernel_size=1, padding=0, is_swin=self.is_swin,
**self.shared_kwargs) for x in filters])

self.smooth = nn.ModuleList()
for i in range(self.depth):
kernel_size, padding = self._get_kernel_size(isotropy[i])
self.smooth.append(conv3d_norm_act(
self.latplanes, self.latplanes, kernel_size=kernel_size,
self.latplanes, self.latplanes, kernel_size=kernel_size, is_swin=self.is_swin,
padding=padding, **self.shared_kwargs))

self.conv_out = self._get_io_conv(out_channel, isotropy[0])
Expand All @@ -100,6 +105,7 @@ def __init__(self,
model_init(self, init_mode)

def forward(self, x):
self.x_size = x.size()
z = self.backbone(x)
return self._forward_main(z)

Expand All @@ -113,6 +119,11 @@ def _forward_main(self, z):
out = self._up_smooth_add(out, features[i-1], self.smooth[i])
out = self.smooth[0](out)
out = self.conv_out(out)
if self.is_swin:
b,c,d,h,w = self.x_size
_b,_c,_d,_h,_w = out.size()
if _d != d or _h != h or _w != w:
out = F.interpolate(out,size=(d,h,w),mode='trilinear')
return out

def _up_smooth_add(self, x, y, smooth):
Expand All @@ -138,4 +149,4 @@ def _get_io_conv(self, out_channel, is_isotropic):
return conv3d_norm_act(
self.filters[0], out_channel, kernel_size_io, padding=padding_io,
pad_mode=self.shared_kwargs['pad_mode'], bias=True,
act_mode='none', norm_mode='none')
act_mode='none', norm_mode='none',is_swin=self.is_swin,)
1 change: 1 addition & 0 deletions connectomics/model/backbone/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@
from .resnet import ResNet3D
from .repvgg import RepVGG3D, RepVGGBlock3D
from .botnet import BotNet3D
from .swintr import SwinTransformer2D,SwinTransformer3D
13 changes: 11 additions & 2 deletions connectomics/model/backbone/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,26 +7,35 @@
from .repvgg import RepVGG3D
from .botnet import BotNet3D
from .efficientnet import EfficientNet3D
from .swintr import SwinTransformer2D,SwinTransformer3D
from ..utils.misc import IntermediateLayerGetter

backbone_dict = {
'resnet': ResNet3D,
'repvgg': RepVGG3D,
'botnet': BotNet3D,
'efficientnet': EfficientNet3D,
'swintransformer2d': SwinTransformer2D,
'swintransformer3d': SwinTransformer3D,
}


def build_backbone(backbone_type: str,
feat_keys: List[str],
**kwargs):
assert backbone_type in ['resnet', 'repvgg', 'botnet', 'efficientnet']
assert backbone_type in ['resnet', 'repvgg', 'botnet', 'efficientnet','swintransformer2d','swintransformer3d']
return_layers = {'layer0': feat_keys[0],
'layer1': feat_keys[1],
'layer2': feat_keys[2],
'layer3': feat_keys[3],
'layer4': feat_keys[4]}

backbone = backbone_dict[backbone_type](**kwargs)
assert len(feat_keys) == backbone.num_stages
if backbone_type[:15] =='swintransformer':
if backbone.use_conv:
assert len(feat_keys) == backbone.num_layers + 2
else:
assert len(feat_keys) == backbone.num_layers + 1
else:
assert len(feat_keys) == backbone.num_stages
return IntermediateLayerGetter(backbone, return_layers)
Loading