From 42914eff3468e97394a1ed07ddb7e2d028ec1703 Mon Sep 17 00:00:00 2001 From: HarshSulakhe Date: Fri, 29 Oct 2021 18:28:17 +0530 Subject: [PATCH 1/6] Add SwinTransformer2D,3D --- connectomics/config/defaults.py | 23 + connectomics/model/arch/__init__.py | 3 + connectomics/model/arch/swintr.py | 1216 +++++++++++++++++++++++++++ connectomics/model/build.py | 33 +- 4 files changed, 1274 insertions(+), 1 deletion(-) create mode 100644 connectomics/model/arch/swintr.py diff --git a/connectomics/config/defaults.py b/connectomics/config/defaults.py index 436ba09d..887c65e7 100755 --- a/connectomics/config/defaults.py +++ b/connectomics/config/defaults.py @@ -109,6 +109,29 @@ # Predict an auxiliary output (only works with 2D DeeplabV3) _C.MODEL.AUX_OUT = False +## EXCLUSIVE TO SWINTRANSFORMERS + +_C.MODEL.PATCH_SIZE = (4,4,4) +_C.MODEL.DEPTHS = [2,2,6,2] +_C.MODEL.NUM_HEADS = [3,6,12,24] +_C.MODEL.WINDOW_SIIE = (2,7,7) +_C.MODEL.MLP_RATIO = 4. +_C.MODEL.QKV_BIAS = True +_C.MODEL.QK_SCALE = None +_C.MODEL.DROP_RATE = 0. +_C.MODEL.ATTN_DROP_RATE = 0. +_C.MODEL.DROP_PATH_RATE = 0.2 +#!!!!!!!!!!!!!!! +# _C.MODEL.NORM_LAYER = nn.LayerNorm +_C.MODEL.PATCH_NORM = False +_C.MODEL.FROZEN_STAGES = -1 +_C.MODEL.USE_CHECKPOINT = False +_C.MODEL.EMBED_DIM = 96 + +## EXCLUSIVE TO SWINTRANSFORMER2D +_C.MODEL.APE = False +_C.MODEL.OUT_INDICES = (0, 1, 2, 3) + # ----------------------------------------------------------------------------- # Dataset # ----------------------------------------------------------------------------- diff --git a/connectomics/model/arch/__init__.py b/connectomics/model/arch/__init__.py index 5f4d3092..30b9afa8 100755 --- a/connectomics/model/arch/__init__.py +++ b/connectomics/model/arch/__init__.py @@ -1,6 +1,7 @@ from .unet import UNet3D, UNet2D, UNetPlus3D from .fpn import FPN3D from .deeplab import DeepLabV3 +from .swintr import SwinTransformer3D,SwinTransformer2D __all__ = [ 'UNet3D', @@ -8,4 +9,6 @@ 'UNet2D', 'FPN3D', 'DeepLabV3', + 'SwinTransformer3D', + 'SwinTransformer2D' ] diff --git a/connectomics/model/arch/swintr.py b/connectomics/model/arch/swintr.py new file mode 100644 index 00000000..a691736b --- /dev/null +++ b/connectomics/model/arch/swintr.py @@ -0,0 +1,1216 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +import numpy as np +from timm.models.layers import DropPath, trunc_normal_,to_2tuple + +from mmcv.runner import load_checkpoint +from mmaction.utils import get_root_logger + +from functools import reduce, lru_cache +from operator import mul +from einops import rearrange + + +class Mlp(nn.Module): + """ Multilayer perceptron.""" + + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + +def window_partition_3d(x, window_size): + """ + Args: + x: (B, D, H, W, C) + window_size (tuple[int]): window size + Returns: + windows: (B*num_windows, window_size*window_size, C) + """ + B, D, H, W, C = x.shape + x = x.view(B, D // window_size[0], window_size[0], H // window_size[1], window_size[1], W // window_size[2], window_size[2], C) + windows = x.permute(0, 
1, 3, 5, 2, 4, 6, 7).contiguous().view(-1, reduce(mul, window_size), C) + return windows + +def window_reverse_3d(windows, window_size, B, D, H, W): + """ + Args: + windows: (B*num_windows, window_size, window_size, C) + window_size (tuple[int]): Window size + H (int): Height of image + W (int): Width of image + Returns: + x: (B, D, H, W, C) + """ + x = windows.view(B, D // window_size[0], H // window_size[1], W // window_size[2], window_size[0], window_size[1], window_size[2], -1) + x = x.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous().view(B, D, H, W, -1) + return x + +def get_window_size_3d(x_size, window_size, shift_size=None): + use_window_size = list(window_size) + if shift_size is not None: + use_shift_size = list(shift_size) + for i in range(len(x_size)): + if x_size[i] <= window_size[i]: + use_window_size[i] = x_size[i] + if shift_size is not None: + use_shift_size[i] = 0 + + if shift_size is None: + return tuple(use_window_size) + else: + return tuple(use_window_size), tuple(use_shift_size) + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + +class WindowAttention3D(nn.Module): + """ Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The temporal length, height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wd, Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1) * (2 * window_size[2] - 1), num_heads)) # 2*Wd-1 * 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_d = torch.arange(self.window_size[0]) + coords_h = torch.arange(self.window_size[1]) + coords_w = torch.arange(self.window_size[2]) + coords = torch.stack(torch.meshgrid(coords_d, coords_h, coords_w)) # 3, Wd, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 3, Wd*Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 3, Wd*Wh*Ww, Wd*Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wd*Wh*Ww, Wd*Wh*Ww, 3 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 2] += self.window_size[2] - 1 + + relative_coords[:, :, 0] *= (2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1) + relative_coords[:, :, 1] *= (2 * self.window_size[2] - 1) + relative_position_index = relative_coords.sum(-1) # Wd*Wh*Ww, Wd*Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ Forward function. + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, N, N) or None + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # B_, nH, N, C + + q = q * self.scale + attn = q @ k.transpose(-2, -1) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index[:N, :N].reshape(-1)].reshape( + N, N, -1) # Wd*Wh*Ww,Wd*Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wd*Wh*Ww, Wd*Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) # B_, nH, N, N + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + +class WindowAttention(nn.Module): + """ Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. 
Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ Forward function. + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + +class SwinTransformerBlock3D(nn.Module): + """ Swin Transformer Block. + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size (tuple[int]): Window size. + shift_size (tuple[int]): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. 
Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, num_heads, window_size=(2,7,7), shift_size=(0,0,0), + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_checkpoint=False): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + self.use_checkpoint=use_checkpoint + + assert 0 <= self.shift_size[0] < self.window_size[0], "shift_size must in 0-window_size" + assert 0 <= self.shift_size[1] < self.window_size[1], "shift_size must in 0-window_size" + assert 0 <= self.shift_size[2] < self.window_size[2], "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention3D( + dim, window_size=self.window_size, num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward_part1(self, x, mask_matrix): + B, D, H, W, C = x.shape + window_size, shift_size = get_window_size_3d((D, H, W), self.window_size, self.shift_size) + + x = self.norm1(x) + # pad feature maps to multiples of window size + pad_l = pad_t = pad_d0 = 0 + pad_d1 = (window_size[0] - D % window_size[0]) % window_size[0] + pad_b = (window_size[1] - H % window_size[1]) % window_size[1] + pad_r = (window_size[2] - W % window_size[2]) % window_size[2] + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b, pad_d0, pad_d1)) + _, Dp, Hp, Wp, _ = x.shape + # cyclic shift + if any(i > 0 for i in shift_size): + shifted_x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1], -shift_size[2]), dims=(1, 2, 3)) + attn_mask = mask_matrix + else: + shifted_x = x + attn_mask = None + # partition windows + x_windows = window_partition_3d(shifted_x, window_size) # B*nW, Wd*Wh*Ww, C + # W-MSA/SW-MSA + attn_windows = self.attn(x_windows, mask=attn_mask) # B*nW, Wd*Wh*Ww, C + # merge windows + attn_windows = attn_windows.view(-1, *(window_size+(C,))) + shifted_x = window_reverse_3d(attn_windows, window_size, B, Dp, Hp, Wp) # B D' H' W' C + # reverse cyclic shift + if any(i > 0 for i in shift_size): + x = torch.roll(shifted_x, shifts=(shift_size[0], shift_size[1], shift_size[2]), dims=(1, 2, 3)) + else: + x = shifted_x + + if pad_d1 >0 or pad_r > 0 or pad_b > 0: + x = x[:, :D, :H, :W, :].contiguous() + return x + + def forward_part2(self, x): + return self.drop_path(self.mlp(self.norm2(x))) + + def forward(self, x, mask_matrix): + """ Forward function. + Args: + x: Input feature, tensor size (B, D, H, W, C). + mask_matrix: Attention mask for cyclic shift. 
+ """ + + shortcut = x + if self.use_checkpoint: + x = checkpoint.checkpoint(self.forward_part1, x, mask_matrix) + else: + x = self.forward_part1(x, mask_matrix) + x = shortcut + self.drop_path(x) + + if self.use_checkpoint: + x = x + checkpoint.checkpoint(self.forward_part2, x) + else: + x = x + self.forward_part2(x) + + return x + +class SwinTransformerBlock(nn.Module): + """ Swin Transformer Block. + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + self.H = None + self.W = None + + def forward(self, x, mask_matrix): + """ Forward function. + Args: + x: Input feature, tensor size (B, H*W, C). + H, W: Spatial resolution of the input feature. + mask_matrix: Attention mask for cyclic shift. 
+ """ + B, L, C = x.shape + H, W = self.H, self.W + assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # pad feature maps to multiples of window size + pad_l = pad_t = 0 + pad_r = (self.window_size - W % self.window_size) % self.window_size + pad_b = (self.window_size - H % self.window_size) % self.window_size + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hp, Wp, _ = x.shape + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + attn_mask = mask_matrix + else: + shifted_x = x + attn_mask = None + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA + attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + +class PatchMerging3D(nn.Module): + + """ Patch Merging Layer + Args: + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + def __init__(self, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """ Forward function. + Args: + x: Input feature, tensor size (B, D, H, W, C). + """ + B, D, H, W, C = x.shape + + # padding + pad_input = (H % 2 == 1) or (W % 2 == 1) + if pad_input: + x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2)) + + x0 = x[:, :, 0::2, 0::2, :] # B D H/2 W/2 C + x1 = x[:, :, 1::2, 0::2, :] # B D H/2 W/2 C + x2 = x[:, :, 0::2, 1::2, :] # B D H/2 W/2 C + x3 = x[:, :, 1::2, 1::2, :] # B D H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B D H/2 W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + +class PatchMerging(nn.Module): + """ Patch Merging Layer + Args: + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + def __init__(self, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x, H, W): + """ Forward function. + Args: + x: Input feature, tensor size (B, H*W, C). + H, W: Spatial resolution of the input feature. 
+ """ + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + + x = x.view(B, H, W, C) + + # padding + pad_input = (H % 2 == 1) or (W % 2 == 1) + if pad_input: + x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2)) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + +@lru_cache() +def compute_mask(D, H, W, window_size, shift_size, device): + + img_mask = torch.zeros((1, D, H, W, 1), device=device) # 1 Dp Hp Wp 1 + cnt = 0 + for d in slice(-window_size[0]), slice(-window_size[0], -shift_size[0]), slice(-shift_size[0],None): + for h in slice(-window_size[1]), slice(-window_size[1], -shift_size[1]), slice(-shift_size[1],None): + for w in slice(-window_size[2]), slice(-window_size[2], -shift_size[2]), slice(-shift_size[2],None): + img_mask[:, d, h, w, :] = cnt + cnt += 1 + mask_windows = window_partition(img_mask, window_size) # nW, ws[0]*ws[1]*ws[2], 1 + mask_windows = mask_windows.squeeze(-1) # nW, ws[0]*ws[1]*ws[2] + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + return attn_mask + +class BasicLayer3D(nn.Module): + """ A basic Swin Transformer layer for one stage. + Args: + dim (int): Number of feature channels + depth (int): Depths of this stage. + num_heads (int): Number of attention head. + window_size (tuple[int]): Local window size. Default: (1,7,7). + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + """ + + def __init__(self, + dim, + depth, + num_heads, + window_size=(1,7,7), + mlp_ratio=4., + qkv_bias=False, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + norm_layer=nn.LayerNorm, + downsample=None, + use_checkpoint=False): + super().__init__() + self.window_size = window_size + self.shift_size = tuple(i // 2 for i in window_size) + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock3D( + dim=dim, + num_heads=num_heads, + window_size=window_size, + shift_size=(0,0,0) if (i % 2 == 0) else self.shift_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer, + use_checkpoint=use_checkpoint, + ) + for i in range(depth)]) + + self.downsample = downsample + if self.downsample is not None: + self.downsample = downsample(dim=dim, norm_layer=norm_layer) + + def forward(self, x): + """ Forward function. + Args: + x: Input feature, tensor size (B, C, D, H, W). 
+ """ + # calculate attention mask for SW-MSA + B, C, D, H, W = x.shape + window_size, shift_size = get_window_size_3d((D,H,W), self.window_size, self.shift_size) + x = rearrange(x, 'b c d h w -> b d h w c') + Dp = int(np.ceil(D / window_size[0])) * window_size[0] + Hp = int(np.ceil(H / window_size[1])) * window_size[1] + Wp = int(np.ceil(W / window_size[2])) * window_size[2] + attn_mask = compute_mask(Dp, Hp, Wp, window_size, shift_size, x.device) + for blk in self.blocks: + x = blk(x, attn_mask) + x = x.view(B, D, H, W, -1) + + if self.downsample is not None: + x = self.downsample(x) + x = rearrange(x, 'b d h w c -> b c d h w') + return x + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + Args: + dim (int): Number of feature channels + depth (int): Depths of this stage. + num_heads (int): Number of attention head. + window_size (int): Local window size. Default: 7. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__(self, + dim, + depth, + num_heads, + window_size=7, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + norm_layer=nn.LayerNorm, + downsample=None, + use_checkpoint=False): + super().__init__() + self.window_size = window_size + self.shift_size = window_size // 2 + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock( + dim=dim, + num_heads=num_heads, + window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x, H, W): + """ Forward function. + Args: + x: Input feature, tensor size (B, H*W, C). + H, W: Spatial resolution of the input feature. 
+ """ + + # calculate attention mask for SW-MSA + Hp = int(np.ceil(H / self.window_size)) * self.window_size + Wp = int(np.ceil(W / self.window_size)) * self.window_size + img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + for blk in self.blocks: + blk.H, blk.W = H, W + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, attn_mask) + else: + x = blk(x, attn_mask) + if self.downsample is not None: + x_down = self.downsample(x, H, W) + Wh, Ww = (H + 1) // 2, (W + 1) // 2 + return x, H, W, x_down, Wh, Ww + else: + return x, H, W, x, H, W + +class PatchEmbed3D(nn.Module): + """ Video to Patch Embedding. + Args: + patch_size (int): Patch token size. Default: (2,4,4). + in_channel (int): Number of input video channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + def __init__(self, patch_size=(2,4,4), in_channel=3, embed_dim=96, norm_layer=None): + super().__init__() + self.patch_size = patch_size + + self.in_channel = in_channel + self.embed_dim = embed_dim + + self.proj = nn.Conv3d(in_channel, embed_dim, kernel_size=patch_size, stride=patch_size) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + """Forward function.""" + # padding + _, _, D, H, W = x.size() + if W % self.patch_size[2] != 0: + x = F.pad(x, (0, self.patch_size[2] - W % self.patch_size[2])) + if H % self.patch_size[1] != 0: + x = F.pad(x, (0, 0, 0, self.patch_size[1] - H % self.patch_size[1])) + if D % self.patch_size[0] != 0: + x = F.pad(x, (0, 0, 0, 0, 0, self.patch_size[0] - D % self.patch_size[0])) + + x = self.proj(x) # B C D Wh Ww + if self.norm is not None: + D, Wh, Ww = x.size(2), x.size(3), x.size(4) + x = x.flatten(2).transpose(1, 2) + x = self.norm(x) + x = x.transpose(1, 2).view(-1, self.embed_dim, D, Wh, Ww) + + return x + +class PatchEmbed(nn.Module): + """ Image to Patch Embedding + Args: + patch_size (int): Patch token size. Default: 4. + in_channel (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. 
Default: None + """ + + def __init__(self, patch_size=4, in_channel=3, embed_dim=96, norm_layer=None): + super().__init__() + patch_size = to_2tuple(patch_size) + self.patch_size = patch_size + + self.in_channel = in_channel + self.embed_dim = embed_dim + + self.proj = nn.Conv2d(in_channel, embed_dim, kernel_size=patch_size, stride=patch_size) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + """Forward function.""" + # padding + _, _, H, W = x.size() + if W % self.patch_size[1] != 0: + x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1])) + if H % self.patch_size[0] != 0: + x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0])) + + x = self.proj(x) # B C Wh Ww + if self.norm is not None: + Wh, Ww = x.size(2), x.size(3) + x = x.flatten(2).transpose(1, 2) + x = self.norm(x) + x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww) + + return x + +class SwinTransformer3D(nn.Module): + """ Swin Transformer backbone. + A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - + https://arxiv.org/pdf/2103.14030 + Args: + patch_size (int | tuple(int)): Patch size. Default: (4,4,4). + in_channel (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + depths (tuple[int]): Depths of each Swin Transformer stage. + num_heads (tuple[int]): Number of attention head of each stage. + window_size (int): Window size. Default: 7. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: Truee + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. + drop_rate (float): Dropout rate. + attn_drop_rate (float): Attention dropout rate. Default: 0. + drop_path_rate (float): Stochastic depth rate. Default: 0.2. + norm_layer: Normalization layer. Default: nn.LayerNorm. + patch_norm (bool): If True, add normalization after patch embedding. Default: False. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. 
+ """ + + def __init__(self, + pretrained=None, + pretrained2d=True, + patch_size=(4,4,4), + in_channel=3, + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=(2,7,7), + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.2, + norm_layer=nn.LayerNorm, + patch_norm=False, + frozen_stages=-1, + use_checkpoint=False, + **kwargs): + super().__init__() + + self.pretrained = pretrained + self.pretrained2d = pretrained2d + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.patch_norm = patch_norm + self.frozen_stages = frozen_stages + self.window_size = window_size + self.patch_size = patch_size + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed3D( + patch_size=patch_size, in_channel=in_channel, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build layers + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = BasicLayer3D( + dim=int(embed_dim * 2**i_layer), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + norm_layer=norm_layer, + downsample=PatchMerging3D if i_layer= 0: + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + + if self.frozen_stages >= 1: + self.pos_drop.eval() + for i in range(0, self.frozen_stages): + m = self.layers[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def inflate_weights(self, logger): + """Inflate the swin2d parameters to swin3d. + The differences between swin3d and swin2d mainly lie in an extra + axis. To utilize the pretrained parameters in 2d model, + the weight of swin2d models should be inflated to fit in the shapes of + the 3d counterpart. + Args: + logger (logging.Logger): The logger used to print + debugging infomation. 
+ """ + checkpoint = torch.load(self.pretrained, map_location='cpu') + state_dict = checkpoint['model'] + + # delete relative_position_index since we always re-init it + relative_position_index_keys = [k for k in state_dict.keys() if "relative_position_index" in k] + for k in relative_position_index_keys: + del state_dict[k] + + # delete attn_mask since we always re-init it + attn_mask_keys = [k for k in state_dict.keys() if "attn_mask" in k] + for k in attn_mask_keys: + del state_dict[k] + + state_dict['patch_embed.proj.weight'] = state_dict['patch_embed.proj.weight'].unsqueeze(2).repeat(1,1,self.patch_size[0],1,1) / self.patch_size[0] + + # bicubic interpolate relative_position_bias_table if not match + relative_position_bias_table_keys = [k for k in state_dict.keys() if "relative_position_bias_table" in k] + for k in relative_position_bias_table_keys: + relative_position_bias_table_pretrained = state_dict[k] + relative_position_bias_table_current = self.state_dict()[k] + L1, nH1 = relative_position_bias_table_pretrained.size() + L2, nH2 = relative_position_bias_table_current.size() + L2 = (2*self.window_size[1]-1) * (2*self.window_size[2]-1) + wd = self.window_size[0] + if nH1 != nH2: + logger.warning(f"Error in loading {k}, passing") + else: + if L1 != L2: + S1 = int(L1 ** 0.5) + relative_position_bias_table_pretrained_resized = torch.nn.functional.interpolate( + relative_position_bias_table_pretrained.permute(1, 0).view(1, nH1, S1, S1), size=(2*self.window_size[1]-1, 2*self.window_size[2]-1), + mode='bicubic') + relative_position_bias_table_pretrained = relative_position_bias_table_pretrained_resized.view(nH2, L2).permute(1, 0) + state_dict[k] = relative_position_bias_table_pretrained.repeat(2*wd-1,1) + + msg = self.load_state_dict(state_dict, strict=False) + logger.info(msg) + logger.info(f"=> loaded successfully '{self.pretrained}'") + del checkpoint + torch.cuda.empty_cache() + + def init_weights(self, pretrained=None): + """Initialize the weights in backbone. + Args: + pretrained (str, optional): Path to pre-trained weights. + Defaults to None. + """ + def _init_weights(m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + if pretrained: + self.pretrained = pretrained + if isinstance(self.pretrained, str): + self.apply(_init_weights) + logger = get_root_logger() + logger.info(f'load model from: {self.pretrained}') + + if self.pretrained2d: + # Inflate 2D model into 3D model. + self.inflate_weights(logger) + else: + # Directly load 3D model. + load_checkpoint(self, self.pretrained, strict=False, logger=logger) + elif self.pretrained is None: + self.apply(_init_weights) + else: + raise TypeError('pretrained must be a str or None') + + def forward(self, x): + """Forward function.""" + x = self.patch_embed(x) + + x = self.pos_drop(x) + + for layer in self.layers: + x = layer(x.contiguous()) + + x = rearrange(x, 'n c d h w -> n d h w c') + x = self.norm(x) + x = rearrange(x, 'n d h w c -> n c d h w') + + return x + + def train(self, mode=True): + """Convert the model into training mode while keep layers freezed.""" + super(SwinTransformer3D, self).train(mode) + self._freeze_stages() + +class SwinTransformer2D(nn.Module): + """ Swin Transformer backbone. 
+ A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - + https://arxiv.org/pdf/2103.14030 + Args: + pretrain_img_size (int): Input image size for training the pretrained model, + used in absolute postion embedding. Default 224. + patch_size (int | tuple(int)): Patch size. Default: 4. + in_channel (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + depths (tuple[int]): Depths of each Swin Transformer stage. + num_heads (tuple[int]): Number of attention head of each stage. + window_size (int): Window size. Default: 7. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. + drop_rate (float): Dropout rate. + attn_drop_rate (float): Attention dropout rate. Default: 0. + drop_path_rate (float): Stochastic depth rate. Default: 0.2. + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False. + patch_norm (bool): If True, add normalization after patch embedding. Default: True. + out_indices (Sequence[int]): Output from which stages. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__(self, + pretrain_img_size=224, + patch_size=4, + in_channel=3, + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=7, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.2, + norm_layer=nn.LayerNorm, + ape=False, + patch_norm=True, + out_indices=(0, 1, 2, 3), + frozen_stages=-1, + use_checkpoint=False, + **kwargs): + super().__init__() + + self.pretrain_img_size = pretrain_img_size + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.out_indices = out_indices + self.frozen_stages = frozen_stages + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + patch_size=patch_size, in_channel=in_channel, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + + # absolute position embedding + if self.ape: + pretrain_img_size = to_2tuple(pretrain_img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [pretrain_img_size[0] // patch_size[0], pretrain_img_size[1] // patch_size[1]] + + self.absolute_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1])) + trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build layers + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = BasicLayer( + dim=int(embed_dim * 2 ** i_layer), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + norm_layer=norm_layer, + downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, + use_checkpoint=use_checkpoint) + 
self.layers.append(layer) + + num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)] + self.num_features = num_features + + # add a norm layer for each output + for i_layer in out_indices: + layer = norm_layer(num_features[i_layer]) + layer_name = f'norm{i_layer}' + self.add_module(layer_name, layer) + + self._freeze_stages() + + def _freeze_stages(self): + if self.frozen_stages >= 0: + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + + if self.frozen_stages >= 1 and self.ape: + self.absolute_pos_embed.requires_grad = False + + if self.frozen_stages >= 2: + self.pos_drop.eval() + for i in range(0, self.frozen_stages - 1): + m = self.layers[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def init_weights(self, pretrained=None): + """Initialize the weights in backbone. + Args: + pretrained (str, optional): Path to pre-trained weights. + Defaults to None. + """ + + def _init_weights(m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + if isinstance(pretrained, str): + self.apply(_init_weights) + logger = get_root_logger() + load_checkpoint(self, pretrained, strict=False, logger=logger) + elif pretrained is None: + self.apply(_init_weights) + else: + raise TypeError('pretrained must be a str or None') + + def forward(self, x): + """Forward function.""" + x = self.patch_embed(x) + + Wh, Ww = x.size(2), x.size(3) + if self.ape: + # interpolate the position embedding to the corresponding size + absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic') + x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C + else: + x = x.flatten(2).transpose(1, 2) + x = self.pos_drop(x) + + outs = [] + for i in range(self.num_layers): + layer = self.layers[i] + x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww) + + if i in self.out_indices: + norm_layer = getattr(self, f'norm{i}') + x_out = norm_layer(x_out) + + out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous() + outs.append(out) + + return tuple(outs) + + def train(self, mode=True): + """Convert the model into training mode while keep layers freezed.""" + super(SwinTransformer2D, self).train(mode) + self._freeze_stages() \ No newline at end of file diff --git a/connectomics/model/build.py b/connectomics/model/build.py index 149c32f4..788f2de6 100755 --- a/connectomics/model/build.py +++ b/connectomics/model/build.py @@ -2,7 +2,7 @@ import torch import torch.nn as nn -from .arch import UNet3D, UNet2D, FPN3D, DeepLabV3, UNetPlus3D +from .arch import UNet3D, UNet2D, FPN3D, DeepLabV3, UNetPlus3D, SwinTransformer2D, SwinTransformer3D from .backbone import RepVGG3D MODEL_MAP = { @@ -13,6 +13,8 @@ 'deeplabv3a': DeepLabV3, 'deeplabv3b': DeepLabV3, 'deeplabv3c': DeepLabV3, + 'swintransformer2d' : SwinTransformer2D, + 'swintransformer3d': SwinTransformer3D, } @@ -46,6 +48,35 @@ def build_model(cfg, device, rank=None): kwargs['name'] = model_arch kwargs['backbone_type'] = cfg.MODEL.BACKBONE kwargs['aux_out'] = cfg.MODEL.AUX_OUT + + if model_arch[:15] == 'swintransformer': + kwargs = { + 'patch_size': cfg.MODEL.PATCH_SIZE, + 'in_channel': cfg.MODEL.IN_PLANES, + 'depths': cfg.MODEL.DEPTHS, + 'num_heads': cfg.MODEL.NUM_HEADS, + 'window_size': cfg.MODEL.WINDOW_SIIE, + 'mlp_ratio': cfg.MODEL.MLP_RATIO, + 
'qkv_bias': cfg.MODEL.QKV_BIAS, + 'qk_scale': cfg.MODEL.QK_SCALE, + 'drop_rate': cfg.MODEL.DROP_RATE, + 'attn_drop_rate': cfg.MODEL.ATTN_DROP_RATE, + 'drop_path_rate': cfg.MODEL.DROP_PATH_RATE, + # 'norm_layer': cfg.MODEL.NORM_LAYER, + 'embed_dim': cfg.MODEL.EMBED_DIM, + 'patch_norm': cfg.MODEL.PATCH_NORM, + + 'frozen_stages': cfg.MODEL.FROZEN_STAGES, + 'use_checkpoint': cfg.MODEL.USE_CHECKPOINT, + # 'pretrain_img_size': cfg.MODEL.PRETRAIN_IMG_SIZE, + } + if model_arch[15:17] == '2d': + kwargs['pretrain_img_size'] = cfg.MODEL.PRETRAIN_IMG_SIZE + kwargs['ape'] = cfg.MODEL.APE + kwargs['out_indices'] = cfg.MODEL.OUT_INDICES + if model_arch[15:17] == '3d': + kwargs['pretrained'] = cfg.MODEL.PRETRAINED + kwargs['pretrained2d'] = cfg.MODEL.PRETRAINED2D model = MODEL_MAP[cfg.MODEL.ARCHITECTURE](**kwargs) print('model: ', model.__class__.__name__) From 9ca87e2df45351196dc63bc5dc35d5740c3ebd94 Mon Sep 17 00:00:00 2001 From: HarshSulakhe Date: Thu, 25 Nov 2021 02:57:33 +0530 Subject: [PATCH 2/6] Integration of SwinTransformer3D as a backbone for FPN, preliminary --- connectomics/config/defaults.py | 9 +- connectomics/model/arch/__init__.py | 3 - connectomics/model/arch/fpn.py | 11 +- connectomics/model/backbone/__init__.py | 1 + connectomics/model/backbone/build.py | 10 +- .../model/{arch => backbone}/swintr.py | 124 ++++++++++++------ connectomics/model/block/basic.py | 11 +- connectomics/model/build.py | 54 +++----- connectomics/model/utils/misc.py | 23 +++- tests/test_models.py | 16 +++ 10 files changed, 168 insertions(+), 94 deletions(-) rename connectomics/model/{arch => backbone}/swintr.py (93%) diff --git a/connectomics/config/defaults.py b/connectomics/config/defaults.py index 887c65e7..ddc7fae0 100755 --- a/connectomics/config/defaults.py +++ b/connectomics/config/defaults.py @@ -112,8 +112,8 @@ ## EXCLUSIVE TO SWINTRANSFORMERS _C.MODEL.PATCH_SIZE = (4,4,4) -_C.MODEL.DEPTHS = [2,2,6,2] -_C.MODEL.NUM_HEADS = [3,6,12,24] +_C.MODEL.DEPTHS = [2,2,2,2,2] +_C.MODEL.NUM_HEADS = [3,6,12,24,24] _C.MODEL.WINDOW_SIIE = (2,7,7) _C.MODEL.MLP_RATIO = 4. _C.MODEL.QKV_BIAS = True @@ -121,17 +121,12 @@ _C.MODEL.DROP_RATE = 0. _C.MODEL.ATTN_DROP_RATE = 0. _C.MODEL.DROP_PATH_RATE = 0.2 -#!!!!!!!!!!!!!!! 
# _C.MODEL.NORM_LAYER = nn.LayerNorm _C.MODEL.PATCH_NORM = False _C.MODEL.FROZEN_STAGES = -1 _C.MODEL.USE_CHECKPOINT = False _C.MODEL.EMBED_DIM = 96 -## EXCLUSIVE TO SWINTRANSFORMER2D -_C.MODEL.APE = False -_C.MODEL.OUT_INDICES = (0, 1, 2, 3) - # ----------------------------------------------------------------------------- # Dataset # ----------------------------------------------------------------------------- diff --git a/connectomics/model/arch/__init__.py b/connectomics/model/arch/__init__.py index 30b9afa8..5f4d3092 100755 --- a/connectomics/model/arch/__init__.py +++ b/connectomics/model/arch/__init__.py @@ -1,7 +1,6 @@ from .unet import UNet3D, UNet2D, UNetPlus3D from .fpn import FPN3D from .deeplab import DeepLabV3 -from .swintr import SwinTransformer3D,SwinTransformer2D __all__ = [ 'UNet3D', @@ -9,6 +8,4 @@ 'UNet2D', 'FPN3D', 'DeepLabV3', - 'SwinTransformer3D', - 'SwinTransformer2D' ] diff --git a/connectomics/model/arch/fpn.py b/connectomics/model/arch/fpn.py index f660f6fa..4147f6b3 100755 --- a/connectomics/model/arch/fpn.py +++ b/connectomics/model/arch/fpn.py @@ -77,6 +77,11 @@ def __init__(self, 'attention': attn, } backbone_kwargs.update(self.shared_kwargs) + self.swin = False + if backbone_type == 'swintransformer3d': + backbone_kwargs.update(kwargs) + self.shared_kwargs['norm_mode'] = 'layer' + self.swin = True self.backbone = build_backbone( backbone_type, feature_keys, **backbone_kwargs) @@ -84,14 +89,14 @@ def __init__(self, self.latplanes = filters[0] self.latlayers = nn.ModuleList([ - conv3d_norm_act(x, self.latplanes, kernel_size=1, padding=0, + conv3d_norm_act(x, self.latplanes, kernel_size=1, padding=0, swin=self.swin, **self.shared_kwargs) for x in filters]) self.smooth = nn.ModuleList() for i in range(self.depth): kernel_size, padding = self._get_kernel_size(isotropy[i]) self.smooth.append(conv3d_norm_act( - self.latplanes, self.latplanes, kernel_size=kernel_size, + self.latplanes, self.latplanes, kernel_size=kernel_size, swin=self.swin, padding=padding, **self.shared_kwargs)) self.conv_out = self._get_io_conv(out_channel, isotropy[0]) @@ -138,4 +143,4 @@ def _get_io_conv(self, out_channel, is_isotropic): return conv3d_norm_act( self.filters[0], out_channel, kernel_size_io, padding=padding_io, pad_mode=self.shared_kwargs['pad_mode'], bias=True, - act_mode='none', norm_mode='none') + act_mode='none', norm_mode='none',swin=self.swin,) diff --git a/connectomics/model/backbone/__init__.py b/connectomics/model/backbone/__init__.py index f2a34b32..097697d9 100755 --- a/connectomics/model/backbone/__init__.py +++ b/connectomics/model/backbone/__init__.py @@ -2,3 +2,4 @@ from .resnet import ResNet3D from .repvgg import RepVGG3D, RepVGGBlock3D from .botnet import BotNet3D +from .swintr import SwinTransformer2D,SwinTransformer3D \ No newline at end of file diff --git a/connectomics/model/backbone/build.py b/connectomics/model/backbone/build.py index 6ec06cc6..b46bf8aa 100755 --- a/connectomics/model/backbone/build.py +++ b/connectomics/model/backbone/build.py @@ -7,6 +7,7 @@ from .repvgg import RepVGG3D from .botnet import BotNet3D from .efficientnet import EfficientNet3D +from .swintr import SwinTransformer2D,SwinTransformer3D from ..utils.misc import IntermediateLayerGetter backbone_dict = { @@ -14,13 +15,15 @@ 'repvgg': RepVGG3D, 'botnet': BotNet3D, 'efficientnet': EfficientNet3D, + 'swintransformer2d': SwinTransformer2D, + 'swintransformer3d': SwinTransformer3D, } def build_backbone(backbone_type: str, feat_keys: List[str], **kwargs): - assert backbone_type in 
['resnet', 'repvgg', 'botnet', 'efficientnet'] + assert backbone_type in ['resnet', 'repvgg', 'botnet', 'efficientnet','swintransformer2d','swintransformer3d'] return_layers = {'layer0': feat_keys[0], 'layer1': feat_keys[1], 'layer2': feat_keys[2], @@ -28,5 +31,8 @@ def build_backbone(backbone_type: str, 'layer4': feat_keys[4]} backbone = backbone_dict[backbone_type](**kwargs) - assert len(feat_keys) == backbone.num_stages + if backbone_type[:15] =='swintransformer': + assert len(feat_keys) == backbone.num_layers + else: + assert len(feat_keys) == backbone.num_stages return IntermediateLayerGetter(backbone, return_layers) diff --git a/connectomics/model/arch/swintr.py b/connectomics/model/backbone/swintr.py similarity index 93% rename from connectomics/model/arch/swintr.py rename to connectomics/model/backbone/swintr.py index a691736b..cf6ad5f8 100644 --- a/connectomics/model/arch/swintr.py +++ b/connectomics/model/backbone/swintr.py @@ -1,3 +1,4 @@ +# Code adapted from https://github.com/microsoft/Swin-Transformer import torch import torch.nn as nn import torch.nn.functional as F @@ -5,8 +6,8 @@ import numpy as np from timm.models.layers import DropPath, trunc_normal_,to_2tuple -from mmcv.runner import load_checkpoint -from mmaction.utils import get_root_logger +# from mmcv.runner import load_checkpoint +# from mmaction.utils import get_root_logger from functools import reduce, lru_cache from operator import mul @@ -312,7 +313,6 @@ def __init__(self, dim, num_heads, window_size=(2,7,7), shift_size=(0,0,0), def forward_part1(self, x, mask_matrix): B, D, H, W, C = x.shape window_size, shift_size = get_window_size_3d((D, H, W), self.window_size, self.shift_size) - x = self.norm1(x) # pad feature maps to multiples of window size pad_l = pad_t = pad_d0 = 0 @@ -475,11 +475,17 @@ class PatchMerging3D(nn.Module): dim (int): Number of input channels. norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm """ - def __init__(self, dim, norm_layer=nn.LayerNorm): + def __init__(self, dim, norm_layer=nn.LayerNorm,isotropy=False): super().__init__() self.dim = dim - self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) - self.norm = norm_layer(4 * dim) + self.isotropy = isotropy + if self.isotropy: + self.reduction = nn.Linear(8 * dim, 2 * dim, bias=False) + self.norm = norm_layer(8*dim) + else: + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + def forward(self, x): """ Forward function. 
@@ -493,13 +499,27 @@ def forward(self, x): if pad_input: x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2)) - x0 = x[:, :, 0::2, 0::2, :] # B D H/2 W/2 C - x1 = x[:, :, 1::2, 0::2, :] # B D H/2 W/2 C - x2 = x[:, :, 0::2, 1::2, :] # B D H/2 W/2 C - x3 = x[:, :, 1::2, 1::2, :] # B D H/2 W/2 C - x = torch.cat([x0, x1, x2, x3], -1) # B D H/2 W/2 4*C + if self.isotropy: + x0 = x[:, 0::2, 0::2, 0::2, :] # B D/2 H/2 W/2 C + x1 = x[:, 0::2, 1::2, 0::2, :] # B D/2 H/2 W/2 C + x2 = x[:, 0::2, 0::2, 1::2, :] # B D/2 H/2 W/2 C + x3 = x[:, 0::2, 1::2, 1::2, :] # B D/2 H/2 W/2 C + x4 = x[:, 1::2, 0::2, 0::2, :] # B D/2 H/2 W/2 C + x5 = x[:, 1::2, 1::2, 0::2, :] # B D/2 H/2 W/2 C + x6 = x[:, 1::2, 0::2, 1::2, :] # B D/2 H/2 W/2 C + x7 = x[:, 1::2, 1::2, 1::2, :] # B D/2 H/2 W/2 C - x = self.norm(x) + x = torch.cat([x0, x1, x2, x3,x4, x5, x6, x7], -1) # B D/2 H/2 W/2 8*C + + else: + x0 = x[:, :, 0::2, 0::2, :] # B D H/2 W/2 C + x1 = x[:, :, 1::2, 0::2, :] # B D H/2 W/2 C + x2 = x[:, :, 0::2, 1::2, :] # B D H/2 W/2 C + x3 = x[:, :, 1::2, 1::2, :] # B D H/2 W/2 C + + x = torch.cat([x0, x1, x2, x3], -1) # B D H/2 W/2 4*C + + x = self.norm(x) x = self.reduction(x) return x @@ -555,7 +575,7 @@ def compute_mask(D, H, W, window_size, shift_size, device): for w in slice(-window_size[2]), slice(-window_size[2], -shift_size[2]), slice(-shift_size[2],None): img_mask[:, d, h, w, :] = cnt cnt += 1 - mask_windows = window_partition(img_mask, window_size) # nW, ws[0]*ws[1]*ws[2], 1 + mask_windows = window_partition_3d(img_mask, window_size) # nW, ws[0]*ws[1]*ws[2], 1 mask_windows = mask_windows.squeeze(-1) # nW, ws[0]*ws[1]*ws[2] attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) @@ -591,6 +611,7 @@ def __init__(self, drop_path=0., norm_layer=nn.LayerNorm, downsample=None, + isotropy=False, use_checkpoint=False): super().__init__() self.window_size = window_size @@ -618,7 +639,7 @@ def __init__(self, self.downsample = downsample if self.downsample is not None: - self.downsample = downsample(dim=dim, norm_layer=norm_layer) + self.downsample = downsample(dim=dim, norm_layer=norm_layer,isotropy=isotropy) def forward(self, x): """ Forward function. 
@@ -854,8 +875,8 @@ def __init__(self, patch_size=(4,4,4), in_channel=3, embed_dim=96, - depths=[2, 2, 6, 2], - num_heads=[3, 6, 12, 24], + depths=[2, 2, 2, 2, 2], + num_heads=[3, 6, 12, 18, 24], window_size=(2,7,7), mlp_ratio=4., qkv_bias=True, @@ -867,6 +888,7 @@ def __init__(self, patch_norm=False, frozen_stages=-1, use_checkpoint=False, + isotropy = [False,False,False,False,False], **kwargs): super().__init__() @@ -878,7 +900,10 @@ def __init__(self, self.frozen_stages = frozen_stages self.window_size = window_size self.patch_size = patch_size - + self.isotropy = isotropy + assert len(self.isotropy) == self.num_layers + assert len(num_heads) == self.num_layers + assert self.num_layers == 5 # split image into non-overlapping patches self.patch_embed = PatchEmbed3D( patch_size=patch_size, in_channel=in_channel, embed_dim=embed_dim, @@ -890,7 +915,7 @@ def __init__(self, dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule # build layers - self.layers = nn.ModuleList() + layers = nn.ModuleList() for i_layer in range(self.num_layers): layer = BasicLayer3D( dim=int(embed_dim * 2**i_layer), @@ -905,8 +930,15 @@ def __init__(self, drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], norm_layer=norm_layer, downsample=PatchMerging3D if i_layer n d h w c') + print("AFTER FIRST REARRANGE",x.size()) + x = self.norm(x) + print("AFTER NORM",x.size()) + x = rearrange(x, 'n d h w c -> n c d h w') + print("AFTER SECOND REARRANGE",x.size()) return x @@ -1081,7 +1125,7 @@ def __init__(self, out_indices=(0, 1, 2, 3), frozen_stages=-1, use_checkpoint=False, - **kwargs): + **_): super().__init__() self.pretrain_img_size = pretrain_img_size @@ -1174,11 +1218,11 @@ def _init_weights(m): nn.init.constant_(m.bias, 0) nn.init.constant_(m.weight, 1.0) - if isinstance(pretrained, str): - self.apply(_init_weights) - logger = get_root_logger() - load_checkpoint(self, pretrained, strict=False, logger=logger) - elif pretrained is None: + # if isinstance(pretrained, str): + # self.apply(_init_weights) + # logger = get_root_logger() + # load_checkpoint(self, pretrained, strict=False, logger=logger) + if pretrained is None: self.apply(_init_weights) else: raise TypeError('pretrained must be a str or None') diff --git a/connectomics/model/block/basic.py b/connectomics/model/block/basic.py index 1753ae0a..731e4ea0 100755 --- a/connectomics/model/block/basic.py +++ b/connectomics/model/block/basic.py @@ -5,7 +5,7 @@ import torch.nn as nn import torch.nn.functional as F -from ..utils import get_norm_2d, get_norm_3d, get_activation +from ..utils import get_norm_2d, get_norm_3d, get_activation, Rearrange def conv2d_norm_act(in_planes, planes, kernel_size=(3, 3), stride=1, groups=1, @@ -26,13 +26,18 @@ def conv2d_norm_act(in_planes, planes, kernel_size=(3, 3), stride=1, groups=1, def conv3d_norm_act(in_planes, planes, kernel_size=(3, 3, 3), stride=1, groups=1, dilation=(1, 1, 1), padding=(1, 1, 1), bias=False, pad_mode='replicate', - norm_mode='bn', act_mode='relu', return_list=False): + norm_mode='bn', act_mode='relu', return_list=False, swin=False): layers = [] layers += [nn.Conv3d(in_planes, planes, kernel_size=kernel_size, stride=stride, groups=groups, padding=padding, padding_mode=pad_mode, dilation=dilation, bias=bias)] - layers += [get_norm_3d(norm_mode, planes)] + if swin: + layers += [Rearrange()] + layers += [get_norm_3d(norm_mode, planes)] + layers += [Rearrange(before_norm=False)] + else: + layers += [get_norm_3d(norm_mode, planes)] layers += 
[get_activation(act_mode)] if return_list: # return a list of layers diff --git a/connectomics/model/build.py b/connectomics/model/build.py index 788f2de6..1a906c54 100755 --- a/connectomics/model/build.py +++ b/connectomics/model/build.py @@ -2,8 +2,8 @@ import torch import torch.nn as nn -from .arch import UNet3D, UNet2D, FPN3D, DeepLabV3, UNetPlus3D, SwinTransformer2D, SwinTransformer3D -from .backbone import RepVGG3D +from .arch import UNet3D, UNet2D, FPN3D, DeepLabV3, UNetPlus3D +from .backbone import RepVGG3D, SwinTransformer2D, SwinTransformer3D MODEL_MAP = { 'unet_3d': UNet3D, @@ -13,8 +13,6 @@ 'deeplabv3a': DeepLabV3, 'deeplabv3b': DeepLabV3, 'deeplabv3c': DeepLabV3, - 'swintransformer2d' : SwinTransformer2D, - 'swintransformer3d': SwinTransformer3D, } @@ -43,40 +41,30 @@ def build_model(cfg, device, rank=None): kwargs['deploy'] = cfg.MODEL.DEPLOY_MODE if cfg.MODEL.BACKBONE == 'botnet': kwargs['fmap_size'] = cfg.MODEL.INPUT_SIZE + if cfg.MODEL.BACKBONE == 'swintransformer3d': + swin_kwargs = { + 'patch_size': cfg.MODEL.PATCH_SIZE, + 'depths': cfg.MODEL.DEPTHS, + 'num_heads': cfg.MODEL.NUM_HEADS, + 'window_size': cfg.MODEL.WINDOW_SIIE, + 'mlp_ratio': cfg.MODEL.MLP_RATIO, + 'qkv_bias': cfg.MODEL.QKV_BIAS, + 'qk_scale': cfg.MODEL.QK_SCALE, + 'drop_rate': cfg.MODEL.DROP_RATE, + 'attn_drop_rate': cfg.MODEL.ATTN_DROP_RATE, + 'drop_path_rate': cfg.MODEL.DROP_PATH_RATE, + 'embed_dim': cfg.MODEL.EMBED_DIM, + 'patch_norm': cfg.MODEL.PATCH_NORM, + 'frozen_stages': cfg.MODEL.FROZEN_STAGES, + 'use_checkpoint': cfg.MODEL.USE_CHECKPOINT, + 'isotropy': cfg.MODEL.ISOTROPY, + } + kwargs.update(swin_kwargs) if model_arch[:7] == 'deeplab': kwargs['name'] = model_arch kwargs['backbone_type'] = cfg.MODEL.BACKBONE kwargs['aux_out'] = cfg.MODEL.AUX_OUT - - if model_arch[:15] == 'swintransformer': - kwargs = { - 'patch_size': cfg.MODEL.PATCH_SIZE, - 'in_channel': cfg.MODEL.IN_PLANES, - 'depths': cfg.MODEL.DEPTHS, - 'num_heads': cfg.MODEL.NUM_HEADS, - 'window_size': cfg.MODEL.WINDOW_SIIE, - 'mlp_ratio': cfg.MODEL.MLP_RATIO, - 'qkv_bias': cfg.MODEL.QKV_BIAS, - 'qk_scale': cfg.MODEL.QK_SCALE, - 'drop_rate': cfg.MODEL.DROP_RATE, - 'attn_drop_rate': cfg.MODEL.ATTN_DROP_RATE, - 'drop_path_rate': cfg.MODEL.DROP_PATH_RATE, - # 'norm_layer': cfg.MODEL.NORM_LAYER, - 'embed_dim': cfg.MODEL.EMBED_DIM, - 'patch_norm': cfg.MODEL.PATCH_NORM, - - 'frozen_stages': cfg.MODEL.FROZEN_STAGES, - 'use_checkpoint': cfg.MODEL.USE_CHECKPOINT, - # 'pretrain_img_size': cfg.MODEL.PRETRAIN_IMG_SIZE, - } - if model_arch[15:17] == '2d': - kwargs['pretrain_img_size'] = cfg.MODEL.PRETRAIN_IMG_SIZE - kwargs['ape'] = cfg.MODEL.APE - kwargs['out_indices'] = cfg.MODEL.OUT_INDICES - if model_arch[15:17] == '3d': - kwargs['pretrained'] = cfg.MODEL.PRETRAINED - kwargs['pretrained2d'] = cfg.MODEL.PRETRAINED2D model = MODEL_MAP[cfg.MODEL.ARCHITECTURE](**kwargs) print('model: ', model.__class__.__name__) diff --git a/connectomics/model/utils/misc.py b/connectomics/model/utils/misc.py index 183241ff..24b5221c 100755 --- a/connectomics/model/utils/misc.py +++ b/connectomics/model/utils/misc.py @@ -6,7 +6,7 @@ from torch import nn import torch.nn.functional as F from torch.jit.annotations import Dict - +from einops import rearrange class IntermediateLayerGetter(nn.ModuleDict): """ @@ -256,7 +256,7 @@ def get_norm_3d(norm: str, out_channels: int, bn_momentum: float = 0.1) -> nn.Mo Returns: nn.Module: the normalization layer """ - assert norm in ["bn", "sync_bn", "gn", "in", "none"], \ + assert norm in ["bn", "sync_bn", "gn", "in", "none","layer"], \ 
"Get unknown normalization layer key {}".format(norm) if norm == "gn": assert out_channels%8 == 0, "GN requires channels to separable into 8 groups" norm = { @@ -265,6 +265,7 @@ def get_norm_3d(norm: str, out_channels: int, bn_momentum: float = 0.1) -> nn.Mo "in": nn.InstanceNorm3d, "gn": lambda channels: nn.GroupNorm(8, channels), "none": nn.Identity, + "layer": nn.LayerNorm, }[norm] if norm in ["bn", "sync_bn", "in"]: return norm(out_channels, momentum=bn_momentum) @@ -282,7 +283,7 @@ def get_norm_2d(norm: str, out_channels: int, bn_momentum: float = 0.1) -> nn.Mo Returns: nn.Module: the normalization layer """ - assert norm in ["bn", "sync_bn", "gn", "in", "none"], \ + assert norm in ["bn", "sync_bn", "gn", "in", "none","layer"], \ "Get unknown normalization layer key {}".format(norm) norm = { "bn": nn.BatchNorm2d, @@ -290,6 +291,7 @@ def get_norm_2d(norm: str, out_channels: int, bn_momentum: float = 0.1) -> nn.Mo "in": nn.InstanceNorm2d, "gn": lambda channels: nn.GroupNorm(16, channels), "none": nn.Identity, + "layer": nn.LayerNorm, }[norm] if norm in ["bn", "sync_bn", "in"]: return norm(out_channels, momentum=bn_momentum) @@ -325,3 +327,18 @@ def get_norm_1d(norm: str, out_channels: int, bn_momentum: float = 0.1) -> nn.Mo def get_num_params(model): num_param = sum([param.nelement() for param in model.parameters()]) return num_param + +# ---------------------- +# Miscellanous Modules +# ---------------------- + +class Rearrange(nn.Module): + def __init__(self,before_norm=True): + super(Rearrange, self).__init__() + self.before_norm = before_norm + + def forward(self, x): + if self.before_norm: + return rearrange(x, 'n c d h w -> n d h w c') + else: + return rearrange(x, 'n d h w c -> n c d h w') \ No newline at end of file diff --git a/tests/test_models.py b/tests/test_models.py index 10114253..189185b2 100755 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -155,6 +155,22 @@ def test_build_fpn_with_botnet(self): y1 = model(x) self.assertTupleEqual(tuple(y1.shape), (2, 1, d, h, w)) + def test_build_fpn_with_swintransformer(self): + r"""Test building a 3D FPN model with BotNet3D backbone from configs. + """ + cfg = get_cfg_defaults() + cfg.MODEL.ARCHITECTURE = 'fpn_3d' + cfg.MODEL.BACKBONE = 'swintransformer3d' + cfg.MODEL.FILTERS = [192, 384, 768, 1536, 1536] + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = build_model(cfg, device).eval() + + d, h, w = cfg.MODEL.INPUT_SIZE + c = cfg.MODEL.IN_PLANES + x = torch.rand(2, c, d, h, w) + y1 = model(x) + self.assertTupleEqual(tuple(y1.shape), (2, c, d//4, h//8, w//8)) + def test_build_fpn_with_efficientnet(self): r"""Test building a 3D FPN model with EfficientNet3D backbone from configs. 
""" From c231c78c1e58d2de10e17342b5c2d470320da2b3 Mon Sep 17 00:00:00 2001 From: HarshSulakhe Date: Thu, 9 Dec 2021 01:34:17 +0530 Subject: [PATCH 3/6] Add interpolation layer for SwinTr in FPN --- connectomics/config/defaults.py | 2 +- connectomics/model/arch/fpn.py | 13 ++++++++----- connectomics/model/backbone/swintr.py | 20 +++++++++----------- connectomics/model/block/basic.py | 4 ++-- tests/test_models.py | 2 +- 5 files changed, 21 insertions(+), 20 deletions(-) diff --git a/connectomics/config/defaults.py b/connectomics/config/defaults.py index ddc7fae0..ccdb70b0 100755 --- a/connectomics/config/defaults.py +++ b/connectomics/config/defaults.py @@ -10,7 +10,7 @@ # ----------------------------------------------------------------------------- _C.SYSTEM = CN() -_C.SYSTEM.NUM_GPUS = 4 +_C.SYSTEM.NUM_GPUS = 1 _C.SYSTEM.NUM_CPUS = 4 # Run distributed training using DistributedDataparallel model _C.SYSTEM.DISTRIBUTED = False diff --git a/connectomics/model/arch/fpn.py b/connectomics/model/arch/fpn.py index 4147f6b3..5d33255c 100755 --- a/connectomics/model/arch/fpn.py +++ b/connectomics/model/arch/fpn.py @@ -77,11 +77,11 @@ def __init__(self, 'attention': attn, } backbone_kwargs.update(self.shared_kwargs) - self.swin = False + self.is_swin = False if backbone_type == 'swintransformer3d': backbone_kwargs.update(kwargs) self.shared_kwargs['norm_mode'] = 'layer' - self.swin = True + self.is_swin = True self.backbone = build_backbone( backbone_type, feature_keys, **backbone_kwargs) @@ -89,14 +89,14 @@ def __init__(self, self.latplanes = filters[0] self.latlayers = nn.ModuleList([ - conv3d_norm_act(x, self.latplanes, kernel_size=1, padding=0, swin=self.swin, + conv3d_norm_act(x, self.latplanes, kernel_size=1, padding=0, is_swin=self.is_swin, **self.shared_kwargs) for x in filters]) self.smooth = nn.ModuleList() for i in range(self.depth): kernel_size, padding = self._get_kernel_size(isotropy[i]) self.smooth.append(conv3d_norm_act( - self.latplanes, self.latplanes, kernel_size=kernel_size, swin=self.swin, + self.latplanes, self.latplanes, kernel_size=kernel_size, is_swin=self.is_swin, padding=padding, **self.shared_kwargs)) self.conv_out = self._get_io_conv(out_channel, isotropy[0]) @@ -118,6 +118,9 @@ def _forward_main(self, z): out = self._up_smooth_add(out, features[i-1], self.smooth[i]) out = self.smooth[0](out) out = self.conv_out(out) + if self.is_swin: + b,c,d,h,w = out.size() + out = F.interpolate(out,size=(4*d,8*h,8*w),mode='trilinear') return out def _up_smooth_add(self, x, y, smooth): @@ -143,4 +146,4 @@ def _get_io_conv(self, out_channel, is_isotropic): return conv3d_norm_act( self.filters[0], out_channel, kernel_size_io, padding=padding_io, pad_mode=self.shared_kwargs['pad_mode'], bias=True, - act_mode='none', norm_mode='none',swin=self.swin,) + act_mode='none', norm_mode='none',is_swin=self.is_swin,) diff --git a/connectomics/model/backbone/swintr.py b/connectomics/model/backbone/swintr.py index cf6ad5f8..bc57e19f 100644 --- a/connectomics/model/backbone/swintr.py +++ b/connectomics/model/backbone/swintr.py @@ -1047,28 +1047,26 @@ def _init_weights(m): def forward(self, x): """Forward function.""" - print("BEFORE PATCH EMBED",x.size()) x = self.patch_embed(x) - print("AFTER PATCH EMBED",x.size()) x = self.pos_drop(x) - print("AFTER POS_DROP",x.size()) - i = 0 - for layer in self.layers: - x = layer(x.contiguous()) - print("AFTER LAYER",i,x.size()) - i+=1 + x = self.layer0(x.contiguous()) + + x = self.layer1(x.contiguous()) + + x = self.layer2(x.contiguous()) + + x = 
self.layer3(x.contiguous()) + + x = self.layer4(x.contiguous()) x = rearrange(x, 'n c d h w -> n d h w c') - print("AFTER FIRST REARRANGE",x.size()) x = self.norm(x) - print("AFTER NORM",x.size()) x = rearrange(x, 'n d h w c -> n c d h w') - print("AFTER SECOND REARRANGE",x.size()) return x diff --git a/connectomics/model/block/basic.py b/connectomics/model/block/basic.py index 731e4ea0..dbe6fd7c 100755 --- a/connectomics/model/block/basic.py +++ b/connectomics/model/block/basic.py @@ -26,13 +26,13 @@ def conv2d_norm_act(in_planes, planes, kernel_size=(3, 3), stride=1, groups=1, def conv3d_norm_act(in_planes, planes, kernel_size=(3, 3, 3), stride=1, groups=1, dilation=(1, 1, 1), padding=(1, 1, 1), bias=False, pad_mode='replicate', - norm_mode='bn', act_mode='relu', return_list=False, swin=False): + norm_mode='bn', act_mode='relu', return_list=False, is_swin=False): layers = [] layers += [nn.Conv3d(in_planes, planes, kernel_size=kernel_size, stride=stride, groups=groups, padding=padding, padding_mode=pad_mode, dilation=dilation, bias=bias)] - if swin: + if is_swin: layers += [Rearrange()] layers += [get_norm_3d(norm_mode, planes)] layers += [Rearrange(before_norm=False)] diff --git a/tests/test_models.py b/tests/test_models.py index 189185b2..836291bf 100755 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -169,7 +169,7 @@ def test_build_fpn_with_swintransformer(self): c = cfg.MODEL.IN_PLANES x = torch.rand(2, c, d, h, w) y1 = model(x) - self.assertTupleEqual(tuple(y1.shape), (2, c, d//4, h//8, w//8)) + self.assertTupleEqual(tuple(y1.shape), (2, c, d, h, w)) def test_build_fpn_with_efficientnet(self): r"""Test building a 3D FPN model with EfficientNet3D backbone from configs. From 809f87ad3bed04a509a54e55124a7a8e43e0163a Mon Sep 17 00:00:00 2001 From: HarshSulakhe Date: Thu, 9 Dec 2021 01:47:56 +0530 Subject: [PATCH 4/6] Revert default gpus to 4 --- connectomics/config/defaults.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/connectomics/config/defaults.py b/connectomics/config/defaults.py index ccdb70b0..ddc7fae0 100755 --- a/connectomics/config/defaults.py +++ b/connectomics/config/defaults.py @@ -10,7 +10,7 @@ # ----------------------------------------------------------------------------- _C.SYSTEM = CN() -_C.SYSTEM.NUM_GPUS = 1 +_C.SYSTEM.NUM_GPUS = 4 _C.SYSTEM.NUM_CPUS = 4 # Run distributed training using DistributedDataparallel model _C.SYSTEM.DISTRIBUTED = False From 1b4064d8d73a433852eb000beef3beef3d8caa3b Mon Sep 17 00:00:00 2001 From: HarshSulakhe Date: Wed, 29 Dec 2021 15:27:39 -0500 Subject: [PATCH 5/6] Modify interpolation --- connectomics/model/arch/fpn.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/connectomics/model/arch/fpn.py b/connectomics/model/arch/fpn.py index 5d33255c..9204e60e 100755 --- a/connectomics/model/arch/fpn.py +++ b/connectomics/model/arch/fpn.py @@ -105,6 +105,7 @@ def __init__(self, model_init(self, init_mode) def forward(self, x): + self.x_size = x.size() z = self.backbone(x) return self._forward_main(z) @@ -119,8 +120,8 @@ def _forward_main(self, z): out = self.smooth[0](out) out = self.conv_out(out) if self.is_swin: - b,c,d,h,w = out.size() - out = F.interpolate(out,size=(4*d,8*h,8*w),mode='trilinear') + b,c,d,h,w = self.x_size + out = F.interpolate(out,size=(d,h,w),mode='trilinear') return out def _up_smooth_add(self, x, y, smooth): From 73402e654afde69a43a5836cc90a32ef75c75dc2 Mon Sep 17 00:00:00 2001 From: HarshSulakhe Date: Mon, 3 Jan 2022 03:32:11 +0530 Subject: [PATCH 6/6] Add conv 
improvement option for swintr, add test --- connectomics/config/defaults.py | 7 +- connectomics/model/arch/fpn.py | 6 +- connectomics/model/backbone/build.py | 5 +- connectomics/model/backbone/swintr.py | 164 +++++++++++++++++++------- connectomics/model/build.py | 4 +- tests/test_models.py | 30 ++++- 6 files changed, 164 insertions(+), 52 deletions(-) diff --git a/connectomics/config/defaults.py b/connectomics/config/defaults.py index ddc7fae0..eb0a5fc3 100755 --- a/connectomics/config/defaults.py +++ b/connectomics/config/defaults.py @@ -10,7 +10,7 @@ # ----------------------------------------------------------------------------- _C.SYSTEM = CN() -_C.SYSTEM.NUM_GPUS = 4 +_C.SYSTEM.NUM_GPUS = 1 _C.SYSTEM.NUM_CPUS = 4 # Run distributed training using DistributedDataparallel model _C.SYSTEM.DISTRIBUTED = False @@ -121,11 +121,14 @@ _C.MODEL.DROP_RATE = 0. _C.MODEL.ATTN_DROP_RATE = 0. _C.MODEL.DROP_PATH_RATE = 0.2 -# _C.MODEL.NORM_LAYER = nn.LayerNorm +_C.MODEL.USE_CONV = False _C.MODEL.PATCH_NORM = False _C.MODEL.FROZEN_STAGES = -1 _C.MODEL.USE_CHECKPOINT = False _C.MODEL.EMBED_DIM = 96 +_C.MODEL.DOWNSAMPLE_BEFORE = [True, True, True, True] +_C.MODEL.SWIN_ISOTROPY = [True, True, True, True] + # ----------------------------------------------------------------------------- # Dataset diff --git a/connectomics/model/arch/fpn.py b/connectomics/model/arch/fpn.py index 9204e60e..00c12e4b 100755 --- a/connectomics/model/arch/fpn.py +++ b/connectomics/model/arch/fpn.py @@ -54,7 +54,7 @@ def __init__(self, self.filters = filters self.depth = len(filters) - assert len(isotropy) == self.depth + # assert len(isotropy) == self.depth if is_isotropic: isotropy = [True] * self.depth self.isotropy = isotropy @@ -121,7 +121,9 @@ def _forward_main(self, z): out = self.conv_out(out) if self.is_swin: b,c,d,h,w = self.x_size - out = F.interpolate(out,size=(d,h,w),mode='trilinear') + _b,_c,_d,_h,_w = out.size() + if _d != d or _h != h or _w != w: + out = F.interpolate(out,size=(d,h,w),mode='trilinear') return out def _up_smooth_add(self, x, y, smooth): diff --git a/connectomics/model/backbone/build.py b/connectomics/model/backbone/build.py index b46bf8aa..6ee4099e 100755 --- a/connectomics/model/backbone/build.py +++ b/connectomics/model/backbone/build.py @@ -32,7 +32,10 @@ def build_backbone(backbone_type: str, backbone = backbone_dict[backbone_type](**kwargs) if backbone_type[:15] =='swintransformer': - assert len(feat_keys) == backbone.num_layers + if backbone.use_conv: + assert len(feat_keys) == backbone.num_layers + 2 + else: + assert len(feat_keys) == backbone.num_layers + 1 else: assert len(feat_keys) == backbone.num_stages return IntermediateLayerGetter(backbone, return_layers) diff --git a/connectomics/model/backbone/swintr.py b/connectomics/model/backbone/swintr.py index bc57e19f..004de661 100644 --- a/connectomics/model/backbone/swintr.py +++ b/connectomics/model/backbone/swintr.py @@ -172,7 +172,6 @@ def forward(self, x, mask=None): N, N, -1) # Wd*Wh*Ww,Wd*Wh*Ww,nH relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wd*Wh*Ww, Wd*Wh*Ww attn = attn + relative_position_bias.unsqueeze(0) # B_, nH, N, N - if mask is not None: nW = mask.shape[0] attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) @@ -612,17 +611,19 @@ def __init__(self, norm_layer=nn.LayerNorm, downsample=None, isotropy=False, + downsample_before=False, use_checkpoint=False): super().__init__() self.window_size = window_size self.shift_size = tuple(i // 2 for i in window_size) 
self.depth = depth self.use_checkpoint = use_checkpoint - + self.downsample_before = downsample_before + self.downsample = downsample # build blocks self.blocks = nn.ModuleList([ SwinTransformerBlock3D( - dim=dim, + dim=dim*2 if self.downsample_before and self.downsample else dim, num_heads=num_heads, window_size=window_size, shift_size=(0,0,0) if (i % 2 == 0) else self.shift_size, @@ -637,7 +638,6 @@ def __init__(self, ) for i in range(depth)]) - self.downsample = downsample if self.downsample is not None: self.downsample = downsample(dim=dim, norm_layer=norm_layer,isotropy=isotropy) @@ -647,19 +647,37 @@ def forward(self, x): x: Input feature, tensor size (B, C, D, H, W). """ # calculate attention mask for SW-MSA - B, C, D, H, W = x.shape - window_size, shift_size = get_window_size_3d((D,H,W), self.window_size, self.shift_size) - x = rearrange(x, 'b c d h w -> b d h w c') - Dp = int(np.ceil(D / window_size[0])) * window_size[0] - Hp = int(np.ceil(H / window_size[1])) * window_size[1] - Wp = int(np.ceil(W / window_size[2])) * window_size[2] - attn_mask = compute_mask(Dp, Hp, Wp, window_size, shift_size, x.device) - for blk in self.blocks: - x = blk(x, attn_mask) - x = x.view(B, D, H, W, -1) + if self.downsample_before: + B, C, D, H, W = x.shape + x = rearrange(x, 'b c d h w -> b d h w c') + if self.downsample is not None: + x = self.downsample(x) + B, D, H, W, C = x.shape + + window_size, shift_size = get_window_size_3d((D,H,W), self.window_size, self.shift_size) + # x = rearrange(x, 'b c d h w -> b d h w c') + Dp = int(np.ceil(D / window_size[0])) * window_size[0] + Hp = int(np.ceil(H / window_size[1])) * window_size[1] + Wp = int(np.ceil(W / window_size[2])) * window_size[2] + attn_mask = compute_mask(Dp, Hp, Wp, window_size, shift_size, x.device) + for blk in self.blocks: + x = blk(x, attn_mask) + x = x.view(B, D, H, W, -1) + + else: + B, C, D, H, W = x.shape + window_size, shift_size = get_window_size_3d((D,H,W), self.window_size, self.shift_size) + x = rearrange(x, 'b c d h w -> b d h w c') + Dp = int(np.ceil(D / window_size[0])) * window_size[0] + Hp = int(np.ceil(H / window_size[1])) * window_size[1] + Wp = int(np.ceil(W / window_size[2])) * window_size[2] + attn_mask = compute_mask(Dp, Hp, Wp, window_size, shift_size, x.device) + for blk in self.blocks: + x = blk(x, attn_mask) + x = x.view(B, D, H, W, -1) + if self.downsample is not None: + x = self.downsample(x) - if self.downsample is not None: - x = self.downsample(x) x = rearrange(x, 'b d h w c -> b c d h w') return x @@ -875,8 +893,8 @@ def __init__(self, patch_size=(4,4,4), in_channel=3, embed_dim=96, - depths=[2, 2, 2, 2, 2], - num_heads=[3, 6, 12, 18, 24], + depths=[2, 2, 2, 2], + num_heads=[3, 6, 12, 24], window_size=(2,7,7), mlp_ratio=4., qkv_bias=True, @@ -888,7 +906,9 @@ def __init__(self, patch_norm=False, frozen_stages=-1, use_checkpoint=False, - isotropy = [False,False,False,False,False], + swin_isotropy = [False,False,False,False], + use_conv = False, + downsample_before = [True,True,True,True], **kwargs): super().__init__() @@ -900,25 +920,27 @@ def __init__(self, self.frozen_stages = frozen_stages self.window_size = window_size self.patch_size = patch_size - self.isotropy = isotropy + self.isotropy = swin_isotropy assert len(self.isotropy) == self.num_layers assert len(num_heads) == self.num_layers - assert self.num_layers == 5 - # split image into non-overlapping patches - self.patch_embed = PatchEmbed3D( - patch_size=patch_size, in_channel=in_channel, embed_dim=embed_dim, - norm_layer=norm_layer if 
self.patch_norm else None) - - self.pos_drop = nn.Dropout(p=drop_rate) + assert len(downsample_before) == self.num_layers + self.use_conv = use_conv + if self.use_conv: + assert self.num_layers == 3 + else: + assert self.num_layers == 4 + # stochastic depth dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule - + dims = [embed_dim] + for _ in range(self.num_layers-1): + dims.append(embed_dim * 2**_) # build layers layers = nn.ModuleList() for i_layer in range(self.num_layers): layer = BasicLayer3D( - dim=int(embed_dim * 2**i_layer), + dim=dims[i_layer], depth=depths[i_layer], num_heads=num_heads[i_layer], window_size=window_size, @@ -929,16 +951,56 @@ def __init__(self, attn_drop=attn_drop_rate, drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], norm_layer=norm_layer, - downsample=PatchMerging3D if i_layer0 else None, + isotropy=self.isotropy[i_layer], use_checkpoint=use_checkpoint) layers.append(layer) - self.layer0 = layers[0] - self.layer1 = layers[1] - self.layer2 = layers[2] - self.layer3 = layers[3] - self.layer4 = layers[4] + # split image into non-overlapping patches + patch_embed = PatchEmbed3D( + patch_size=patch_size, in_channel=in_channel, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + + pos_drop = nn.Dropout(p=drop_rate) + + if self.use_conv: + # self.layer0 = nn.Sequential(nn.Conv3d(in_channels = in_channel,out_channels = in_channel,kernel_size = 3,stride = 1,padding = 1,padding_mode = 'replicate'), + # nn.Conv3d(in_channels = in_channel,out_channels = in_channel,kernel_size = 3,stride = 1,padding = 1,padding_mode = 'replicate'), + # nn.Conv3d(in_channels = in_channel,out_channels = in_channel,kernel_size = 3,stride = 1,padding = 1,padding_mode = 'replicate'), + # ) #ORIGINAL DIMENSIONS + # self.layer1 = nn.Sequential(patch_embed,pos_drop,) + # self.layer2 = layers[0] + # self.layer3 = layers[1] + # self.layer4 = nn.Sequential(layers[2]) + + + # self.layer0 = nn.Sequential(nn.Conv3d(in_channels = in_channel,out_channels = in_channel,kernel_size = 3,stride = 1,padding = 1,padding_mode = 'replicate'), + # nn.Conv3d(in_channels = in_channel,out_channels = in_channel,kernel_size = 3,stride = 1,padding = 1,padding_mode = 'replicate'), + # nn.Conv3d(in_channels = in_channel,out_channels = in_channel,kernel_size = 3,stride = 1,padding = 1,padding_mode = 'replicate'), + # patch_embed,pos_drop,) + # self.layer1 = layers[0] + # self.layer2 = layers[1] + # self.layer3 = layers[2] + # self.layer4 = layers[3] + + self.layer0 = nn.Sequential(nn.Conv3d(in_channels = in_channel,out_channels = in_channel,kernel_size = 3,stride = 1,padding = 1,padding_mode = 'replicate'), + nn.Conv3d(in_channels = in_channel,out_channels = in_channel,kernel_size = 3,stride = 1,padding = 1,padding_mode = 'replicate'), + nn.Conv3d(in_channels = in_channel,out_channels = in_channel,kernel_size = 3,stride = 1,padding = 1,padding_mode = 'replicate'), + ) + self.layer1 = nn.Sequential(patch_embed,pos_drop) + self.layer2 = layers[0] + self.layer3 = layers[1] + self.layer4 = layers[2] + + + else: + + self.layer0 = nn.Sequential(patch_embed,pos_drop) + self.layer1 = layers[0] + self.layer2 = layers[1] + self.layer3 = layers[2] + self.layer4 = layers[3] self.num_features = int(embed_dim * 2**(self.num_layers-1)) @@ -1047,20 +1109,34 @@ def _init_weights(m): def forward(self, x): """Forward function.""" - x = self.patch_embed(x) + if self.use_conv: + x = self.layer0(x.contiguous()) - x = self.pos_drop(x) - - x = 
self.layer0(x.contiguous()) + x = self.patch_embed(x) + + x = self.pos_drop(x) + + x = self.layer1(x.contiguous()) - x = self.layer1(x.contiguous()) + x = self.layer2(x.contiguous()) - x = self.layer2(x.contiguous()) + x = self.layer3(x.contiguous()) + + x = self.layer4(x.contiguous()) + else: + x = self.patch_embed(x) + + x = self.pos_drop(x) + + x = self.layer0(x.contiguous()) + + x = self.layer1(x.contiguous()) - x = self.layer3(x.contiguous()) + x = self.layer2(x.contiguous()) - x = self.layer4(x.contiguous()) + x = self.layer3(x.contiguous()) + x = self.layer4(x.contiguous()) x = rearrange(x, 'n c d h w -> n d h w c') diff --git a/connectomics/model/build.py b/connectomics/model/build.py index 1a906c54..b30c82f4 100755 --- a/connectomics/model/build.py +++ b/connectomics/model/build.py @@ -57,7 +57,9 @@ def build_model(cfg, device, rank=None): 'patch_norm': cfg.MODEL.PATCH_NORM, 'frozen_stages': cfg.MODEL.FROZEN_STAGES, 'use_checkpoint': cfg.MODEL.USE_CHECKPOINT, - 'isotropy': cfg.MODEL.ISOTROPY, + 'swin_isotropy': cfg.MODEL.SWIN_ISOTROPY, + 'use_conv': cfg.MODEL.USE_CONV, + 'downsample_before': cfg.MODEL.DOWNSAMPLE_BEFORE, } kwargs.update(swin_kwargs) diff --git a/tests/test_models.py b/tests/test_models.py index 836291bf..c99d274e 100755 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -155,13 +155,39 @@ def test_build_fpn_with_botnet(self): y1 = model(x) self.assertTupleEqual(tuple(y1.shape), (2, 1, d, h, w)) - def test_build_fpn_with_swintransformer(self): + def test_build_fpn_with_default_swintransformer(self): r"""Test building a 3D FPN model with BotNet3D backbone from configs. """ cfg = get_cfg_defaults() cfg.MODEL.ARCHITECTURE = 'fpn_3d' cfg.MODEL.BACKBONE = 'swintransformer3d' - cfg.MODEL.FILTERS = [192, 384, 768, 1536, 1536] + cfg.MODEL.FILTERS = [96, 96, 192, 384, 768] + cfg.MODEL.SWIN_ISOTROPY = [False,False,False,False] + cfg.MODEL.DEPTHS = [2,2,2,2] + cfg.MODEL.NUM_HEADS = [3,6,12,24] + cfg.MODEL.USE_CONV = False + cfg.MODEL.DOWNSAMPLE_BEFORE = [True,True,True,True] + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = build_model(cfg, device).eval() + + d, h, w = cfg.MODEL.INPUT_SIZE + c = cfg.MODEL.IN_PLANES + x = torch.rand(2, c, d, h, w) + y1 = model(x) + self.assertTupleEqual(tuple(y1.shape), (2, c, d, h, w)) + + def test_build_fpn_with_conv_swintransformer(self): + r"""Test building a 3D FPN model with BotNet3D backbone from configs. + """ + cfg = get_cfg_defaults() + cfg.MODEL.ARCHITECTURE = 'fpn_3d' + cfg.MODEL.BACKBONE = 'swintransformer3d' + cfg.MODEL.FILTERS = [1, 96, 96, 192, 384] + cfg.MODEL.SWIN_ISOTROPY = [False,False,False] + cfg.MODEL.DEPTHS = [2,2,2] + cfg.MODEL.NUM_HEADS = [3,6,12] + cfg.MODEL.USE_CONV = True + cfg.MODEL.DOWNSAMPLE_BEFORE = [True,True,True] device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = build_model(cfg, device).eval()
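For orientation, the second test exercises the new conv option end to end: with USE_CONV enabled the backbone prepends a three-convolution stem at full resolution (keeping the input channel count) and runs three transformer stages, so FILTERS must list five feature widths and DEPTHS/NUM_HEADS/SWIN_ISOTROPY/DOWNSAMPLE_BEFORE shrink to three entries. A minimal usage sketch along the same lines follows; the import paths are assumed from the test file and may differ slightly in the actual package layout.

import torch
from connectomics.config import get_cfg_defaults
from connectomics.model import build_model

cfg = get_cfg_defaults()
cfg.MODEL.ARCHITECTURE = 'fpn_3d'
cfg.MODEL.BACKBONE = 'swintransformer3d'
cfg.MODEL.USE_CONV = True
cfg.MODEL.FILTERS = [1, 96, 96, 192, 384]        # conv stem + three swin stages
cfg.MODEL.DEPTHS = [2, 2, 2]
cfg.MODEL.NUM_HEADS = [3, 6, 12]
cfg.MODEL.SWIN_ISOTROPY = [False, False, False]
cfg.MODEL.DOWNSAMPLE_BEFORE = [True, True, True]

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = build_model(cfg, device).eval()

d, h, w = cfg.MODEL.INPUT_SIZE
x = torch.rand(2, cfg.MODEL.IN_PLANES, d, h, w)
with torch.no_grad():
    y = model(x)    # FPN output, trilinearly interpolated back to the input resolution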