diff --git a/main.py b/main.py index 011389c..f86902f 100644 --- a/main.py +++ b/main.py @@ -213,7 +213,8 @@ def main(): ) # Setup optimizer - optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum) + optimizer = optim.Adam(model.parameters(), lr=args.lr) + print("optim Adam") # Loop over the epochs best_val_loss = 1e8 @@ -240,4 +241,5 @@ def main(): if __name__ == "__main__": + print("optim en cours") main() diff --git a/model.py b/model.py index 4e8ca4f..3544cde 100644 --- a/model.py +++ b/model.py @@ -1,23 +1,25 @@ import torch import torch.nn as nn import torch.nn.functional as F +from torchvision import transforms, models + nclasses = 250 class Net(nn.Module): - def __init__(self): + def __init__(self,num_classes=nclasses): super(Net, self).__init__() - self.conv1 = nn.Conv2d(3, 10, kernel_size=5) - self.conv2 = nn.Conv2d(10, 20, kernel_size=5) - self.conv3 = nn.Conv2d(20, 20, kernel_size=5) - self.fc1 = nn.Linear(320, 50) - self.fc2 = nn.Linear(50, nclasses) + resnet = models.resnet18(pretrained=True) + + # Remove the fully connected layer of ResNet-18 + self.features = nn.Sequential(*list(resnet.children())[:-1]) + + # Replace it with a fully connected classification head sized to num_classes + self.fc = nn.Linear(resnet.fc.in_features, num_classes) def forward(self, x): - x = F.relu(F.max_pool2d(self.conv1(x), 2)) - x = F.relu(F.max_pool2d(self.conv2(x), 2)) - x = F.relu(F.max_pool2d(self.conv3(x), 2)) - x = x.view(-1, 320) - x = F.relu(self.fc1(x)) - return self.fc2(x) + x = self.features(x) + x = x.view(x.size(0), -1) + x = self.fc(x) + return x