Skip to content

Commit

Permalink
[Add] ResNet50
Browse files Browse the repository at this point in the history
  • Loading branch information
USER authored and USER committed Feb 16, 2022
1 parent 9cccc48 commit f5e6a95
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 0 deletions.
33 changes: 33 additions & 0 deletions backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,38 @@ def _forward_impl(self, x: Tensor) -> Tensor:

return x

class ResNet50(torchvision.models.resnet.ResNet):
def __init__(self, track_bn=True):
def norm_layer(*args, **kwargs):
return nn.BatchNorm2d(*args, **kwargs, track_running_stats=track_bn)
super().__init__(torchvision.models.resnet.Bottleneck, [3, 4, 6, 3], norm_layer=norm_layer)
del self.fc
self.final_feat_dim = 2048

def load_imagenet_weights(self, progress=True):
state_dict = load_state_dict_from_url(torchvision.models.resnet.model_urls['resnet50'],
progress=progress)
missing, unexpected = self.load_state_dict(state_dict, strict=False)
if len(missing) > 0:
raise AssertionError('Model code may be incorrect')

def _forward_impl(self, x: Tensor) -> Tensor:
# See note [TorchScript super()]
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)

x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)

x = self.avgpool(x)
x = torch.flatten(x, 1)
# x = self.fc(x)

return x


##########################################################################################################
Expand Down Expand Up @@ -626,6 +658,7 @@ def forward(self, x, is_feat=False):
_backbone_class_map = {
'resnet10': ResNet10,
'resnet18': ResNet18,
'resnet50': ResNet50,
}


Expand Down
1 change: 1 addition & 0 deletions paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
BACKBONE_KEYS = {
'resnet10': 'resnet10',
'resnet18': 'resnet18',
'resnet50': 'resnet50',
}

MODEL_KEYS = {
Expand Down

0 comments on commit f5e6a95

Please sign in to comment.