Skip to content

Commit

Permalink
[feat] add torchvision models interface
Browse files Browse the repository at this point in the history
  • Loading branch information
BrandonHanx committed Apr 10, 2023
1 parent ffe43f7 commit 72b2f37
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 7 deletions.
3 changes: 3 additions & 0 deletions src/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
resnet50,
resnet50_fc512,
)
from .tvmodels import mobilenet_v3_small, vgg16


__model_factory = {
Expand All @@ -18,6 +19,8 @@
"resnet34_fc512": resnet34_fc512,
"resnet50": resnet50,
"resnet50_fc512": resnet50_fc512,
"mobilenet_v3_small": mobilenet_v3_small,
"vgg16": vgg16,
}


Expand Down
62 changes: 62 additions & 0 deletions src/models/tvmodels.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# Copyright (c) EEEM071, University of Surrey

import torch.nn as nn
import torchvision.models as tvmodels


__all__ = ["mobilenet_v3_small", "vgg16"]


class TorchVisionModel(nn.Module):
def __init__(self, name, num_classes, loss, pretrained, **kwargs):
super().__init__()

self.loss = loss
self.backbone = tvmodels.__dict__[name](pretrained=pretrained)
self.feature_dim = self.backbone.classifier[0].in_features

# overwrite the classifier used for ImageNet pretrianing
# nn.Identity() will do nothing, it's just a place-holder
self.backbone.classifier = nn.Identity()
self.classifier = nn.Linear(self.feature_dim, num_classes)

def forward(self, x):
v = self.backbone(x)

if not self.training:
return v

y = self.classifier(v)

if self.loss == {"xent"}:
return y
elif self.loss == {"xent", "htri"}:
return y, v
else:
raise KeyError(f"Unsupported loss: {self.loss}")


def vgg16(num_classes, loss={"xent"}, pretrained=True, **kwargs):
model = TorchVisionModel(
"vgg16",
num_classes=num_classes,
loss=loss,
pretrained=pretrained,
**kwargs,
)
return model


def mobilenet_v3_small(num_classes, loss={"xent"}, pretrained=True, **kwargs):
model = TorchVisionModel(
"mobilenet_v3_small",
num_classes=num_classes,
loss=loss,
pretrained=pretrained,
**kwargs,
)
return model


# Define any models supported by torchvision bellow
# https://pytorch.org/vision/0.11/models.html
4 changes: 2 additions & 2 deletions test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
python main.py \
-s veri \
-t veri \
-a resnet18 \
-a mobilenet_v3_small \
--height 224 \
--width 224 \
--test-batch-size 100 \
--evaluate \
--save-dir logs/eval-resnet18-veri
--save-dir logs/eval-mobilenet_v3_small-veri
10 changes: 5 additions & 5 deletions train.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
python main.py \
-s veri \
-t veri \
-a resnet18 \
-a mobilenet_v3_small \
--height 224 \
--width 224 \
--optim amsgrad \
--lr 0.0003 \
--max-epoch 60 \
--stepsize 20 40 \
--train-batch-size 64 \
--max-epoch 30 \
--stepsize 10 20 \
--train-batch-size 256 \
--test-batch-size 100 \
--save-dir logs/resnet18-veri
--save-dir logs/mobilenet_v3_small-veri

0 comments on commit 72b2f37

Please sign in to comment.