diff --git a/src/models/__init__.py b/src/models/__init__.py index 658b9ab..d56e2ab 100644 --- a/src/models/__init__.py +++ b/src/models/__init__.py @@ -8,6 +8,7 @@ resnet50, resnet50_fc512, ) +from .tvmodels import mobilenet_v3_small, vgg16 __model_factory = { @@ -18,6 +19,8 @@ "resnet34_fc512": resnet34_fc512, "resnet50": resnet50, "resnet50_fc512": resnet50_fc512, + "mobilenet_v3_small": mobilenet_v3_small, + "vgg16": vgg16, } diff --git a/src/models/tvmodels.py b/src/models/tvmodels.py new file mode 100644 index 0000000..795fc23 --- /dev/null +++ b/src/models/tvmodels.py @@ -0,0 +1,62 @@ +# Copyright (c) EEEM071, University of Surrey + +import torch.nn as nn +import torchvision.models as tvmodels + + +__all__ = ["mobilenet_v3_small", "vgg16"] + + +class TorchVisionModel(nn.Module): + def __init__(self, name, num_classes, loss, pretrained, **kwargs): + super().__init__() + + self.loss = loss + self.backbone = tvmodels.__dict__[name](pretrained=pretrained) + self.feature_dim = self.backbone.classifier[0].in_features + + # overwrite the classifier used for ImageNet pretraining + # nn.Identity() does nothing; it is just a placeholder + self.backbone.classifier = nn.Identity() + self.classifier = nn.Linear(self.feature_dim, num_classes) + + def forward(self, x): + v = self.backbone(x) + + if not self.training: + return v + + y = self.classifier(v) + + if self.loss == {"xent"}: + return y + elif self.loss == {"xent", "htri"}: + return y, v + else: + raise KeyError(f"Unsupported loss: {self.loss}") + + +def vgg16(num_classes, loss={"xent"}, pretrained=True, **kwargs): + model = TorchVisionModel( + "vgg16", + num_classes=num_classes, + loss=loss, + pretrained=pretrained, + **kwargs, + ) + return model + + +def mobilenet_v3_small(num_classes, loss={"xent"}, pretrained=True, **kwargs): + model = TorchVisionModel( + "mobilenet_v3_small", + num_classes=num_classes, + loss=loss, + pretrained=pretrained, + **kwargs, + ) + return model + + +# Define any models supported by 
torchvision below +# https://pytorch.org/vision/0.11/models.html diff --git a/test.sh b/test.sh index 3dd87c6..593adc9 100644 --- a/test.sh +++ b/test.sh @@ -3,9 +3,9 @@ python main.py \ -s veri \ -t veri \ --a resnet18 \ +-a mobilenet_v3_small \ --height 224 \ --width 224 \ --test-batch-size 100 \ --evaluate \ ---save-dir logs/eval-resnet18-veri +--save-dir logs/eval-mobilenet_v3_small-veri diff --git a/train.sh b/train.sh index 8338f04..a7a32d8 100644 --- a/train.sh +++ b/train.sh @@ -3,13 +3,13 @@ python main.py \ -s veri \ -t veri \ --a resnet18 \ +-a mobilenet_v3_small \ --height 224 \ --width 224 \ --optim amsgrad \ --lr 0.0003 \ ---max-epoch 60 \ ---stepsize 20 40 \ ---train-batch-size 64 \ +--max-epoch 30 \ +--stepsize 10 20 \ +--train-batch-size 256 \ --test-batch-size 100 \ ---save-dir logs/resnet18-veri +--save-dir logs/mobilenet_v3_small-veri