removed redundant code in resnet32
thomasverelst committed Jun 22, 2020
1 parent b586cd5 commit be1024c
Showing 1 changed file with 0 additions and 50 deletions.
classification/models/resnet_32x32.py (0 additions, 50 deletions)
@@ -11,56 +11,6 @@
from torch.autograd import Variable
from models.resnet_util import *

class BasicBlock(nn.Module):
    """Standard residual block """
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, sparse=False):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride
        self.sparse = sparse

        if sparse:
            # in the resnet basic block, the first convolution is already strided, so mask_stride = 1
            self.masker = dynconv.MaskUnit(channels=inplanes, stride=stride, dilate_stride=1)

        self.fast = False

    def forward(self, input):
        x, meta = input
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)

        if not self.sparse:
            out = self.conv1(x)
            out = self.bn1(out)
            out = self.relu(out)
            out = self.conv2(out)
            out = self.bn2(out)
            out += identity
        else:
            assert meta is not None
            m = self.masker(x, meta)
            mask_dilate, mask = m['dilate'], m['std']

            x = dynconv.conv3x3(self.conv1, x, None, mask_dilate)
            x = dynconv.bn_relu(self.bn1, self.relu, x, mask_dilate)
            x = dynconv.conv3x3(self.conv2, x, mask_dilate, mask)
            x = dynconv.bn_relu(self.bn2, None, x, mask)
            out = identity + dynconv.apply_mask(x, mask)

        out = self.relu(out)
        return out, meta



########################################
# Original ResNet #
########################################
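For context, the deleted BasicBlock's dense path (sparse=False) uses only standard PyTorch modules, and its forward expects an (x, meta) tuple and returns (out, meta). Below is a minimal usage sketch, not part of the repository, using a hypothetical stand-in for the conv3x3 helper that the file otherwise imports from models.resnet_util:

import torch
import torch.nn as nn

def conv3x3(in_planes, out_planes, stride=1):
    # hypothetical stand-in for the conv3x3 helper imported from models.resnet_util
    return nn.Conv2d(in_planes, out_planes, kernel_size=3,
                     stride=stride, padding=1, bias=False)

# (paste the BasicBlock class from the diff above here)

block = BasicBlock(inplanes=16, planes=16, sparse=False)
x = torch.randn(2, 16, 32, 32)   # CIFAR-sized feature map
out, meta = block((x, None))     # the dense path never reads meta, so None is fine
print(out.shape)                 # torch.Size([2, 16, 32, 32])

The deletion is presumably safe because an equivalent BasicBlock is provided by models.resnet_util, which the file wildcard-imports at the top, making this local copy the "redundant code" named in the commit message.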
