From 4d4636e8926938cf41be281bc438d9a273585e70 Mon Sep 17 00:00:00 2001 From: phpin57 <78471883+phpin57@users.noreply.github.com> Date: Sun, 19 Nov 2023 20:09:08 +0100 Subject: [PATCH 1/4] Update main.py --- main.py | 1 + 1 file changed, 1 insertion(+) diff --git a/main.py b/main.py index 011389c..e9adf22 100644 --- a/main.py +++ b/main.py @@ -240,4 +240,5 @@ def main(): if __name__ == "__main__": + print("optim en cours") main() From 341f93668aec29554736441f3eee7f8ab04f6dac Mon Sep 17 00:00:00 2001 From: phpin57 <78471883+phpin57@users.noreply.github.com> Date: Sun, 19 Nov 2023 21:35:07 +0100 Subject: [PATCH 2/4] Update main.py --- main.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/main.py b/main.py index e9adf22..a8479b9 100644 --- a/main.py +++ b/main.py @@ -213,7 +213,8 @@ def main(): ) # Setup optimizer - optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum) + optimizer = optim.Adam(model.parameters(), lr=args.lr, momentum=args.momentum) + print("optim Adam") # Loop over the epochs best_val_loss = 1e8 From 26d0dbf190270a38205b31a555b49b8ae3746e4e Mon Sep 17 00:00:00 2001 From: phpin57 <78471883+phpin57@users.noreply.github.com> Date: Sun, 19 Nov 2023 21:44:27 +0100 Subject: [PATCH 3/4] Update main.py --- main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/main.py b/main.py index a8479b9..f86902f 100644 --- a/main.py +++ b/main.py @@ -213,7 +213,7 @@ def main(): ) # Setup optimizer - optimizer = optim.Adam(model.parameters(), lr=args.lr, momentum=args.momentum) + optimizer = optim.Adam(model.parameters(), lr=args.lr) print("optim Adam") # Loop over the epochs From 5f1722d7ce20bc9b681afccf29833bd80f800373 Mon Sep 17 00:00:00 2001 From: phpin57 <78471883+phpin57@users.noreply.github.com> Date: Tue, 21 Nov 2023 10:57:06 +0100 Subject: [PATCH 4/4] Update model.py --- model.py | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/model.py b/model.py index 4e8ca4f..3544cde 100644 --- a/model.py +++ b/model.py @@ -1,23 +1,25 @@ import torch import torch.nn as nn import torch.nn.functional as F +from torchvision import transforms, models + nclasses = 250 class Net(nn.Module): - def __init__(self): + def __init__(self,num_classes=nclasses): super(Net, self).__init__() - self.conv1 = nn.Conv2d(3, 10, kernel_size=5) - self.conv2 = nn.Conv2d(10, 20, kernel_size=5) - self.conv3 = nn.Conv2d(20, 20, kernel_size=5) - self.fc1 = nn.Linear(320, 50) - self.fc2 = nn.Linear(50, nclasses) + resnet = models.resnet18(pretrained=True) + + # Remove the fully connected layer of ResNet-18 + self.features = nn.Sequential(*list(resnet.children())[:-1]) + + # Add your custom fully connected layer + self.fc = nn.Linear(resnet.fc.in_features, num_classes) def forward(self, x): - x = F.relu(F.max_pool2d(self.conv1(x), 2)) - x = F.relu(F.max_pool2d(self.conv2(x), 2)) - x = F.relu(F.max_pool2d(self.conv3(x), 2)) - x = x.view(-1, 320) - x = F.relu(self.fc1(x)) - return self.fc2(x) + x = self.features(x) + x = x.view(x.size(0), -1) + x = self.fc(x) + return x