From 0349947566fbbda65f17baa48c0cec4a03844e56 Mon Sep 17 00:00:00 2001
From: myeongjun
Date: Thu, 28 Jan 2021 16:54:27 +0900
Subject: [PATCH] init:

---
 config.py     |  20 +++++++
 main.py       | 114 ++++++++++++++++++++++++++++++++++++++
 model.py      | 150 ++++++++++++++++++++++++++++++++++++++++++++++++++
 preprocess.py |  25 +++++++++
 4 files changed, 309 insertions(+)
 create mode 100644 config.py
 create mode 100644 main.py
 create mode 100644 model.py
 create mode 100644 preprocess.py

diff --git a/config.py b/config.py
new file mode 100644
index 0000000..ce4cd85
--- /dev/null
+++ b/config.py
@@ -0,0 +1,20 @@
+import argparse
+
+
+def load_config():
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument('--batch_size', type=int, default=128)
+    parser.add_argument('--num_workers', type=int, default=4)
+    parser.add_argument('--lr', type=float, default=0.1)
+    parser.add_argument('--weight_decay', type=float, default=1e-4)
+    parser.add_argument('--momentum', type=float, default=0.9)
+    # argparse's type=bool parses any non-empty string as True, so boolean
+    # flags are declared as store_true/store_false actions instead
+    parser.add_argument('--cuda', dest='cuda', action='store_true', default=True)
+    parser.add_argument('--no_cuda', dest='cuda', action='store_false')
+    # 310 epochs = five cosine warm-restart cycles (10 + 20 + 40 + 80 + 160)
+    parser.add_argument('--epochs', type=int, default=310)
+    parser.add_argument('--print_intervals', type=int, default=100)
+    parser.add_argument('--evaluation', action='store_true', default=False)
+    parser.add_argument('--checkpoints', type=str, default=None, help='model checkpoints path')
+    parser.add_argument('--device_num', type=int, default=1)
+    parser.add_argument('--gradient_clip', type=float, default=2.)
+
+    return parser.parse_args()
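
A quick usage sketch (editorial note, not part of the patch): load_config() parses directly from sys.argv, so the flags above translate to command lines like the two in the comments below; the simulated argv is only for illustration.

# e.g.  python main.py --lr 0.05 --batch_size 256
# e.g.  python main.py --evaluation --checkpoints checkpoint_model_best.pth
import sys
from config import load_config

sys.argv = ['main.py', '--batch_size', '256', '--evaluation']  # simulate a command line
args = load_config()
print(args.batch_size, args.evaluation, args.cuda)  # 256 True True
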
diff --git a/main.py b/main.py
new file mode 100644
index 0000000..da004f3
--- /dev/null
+++ b/main.py
@@ -0,0 +1,114 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+
+import os
+from config import load_config
+from preprocess import load_data
+from model import ResNet50
+
+
+def save_checkpoint(best_acc, model, optimizer, args, epoch):
+    print('Best Model Saving...')
+    if args.device_num > 1:
+        # model.module covers the case where the model is wrapped (e.g. DataParallel)
+        model_state_dict = model.module.state_dict()
+    else:
+        model_state_dict = model.state_dict()
+
+    torch.save({
+        'model_state_dict': model_state_dict,
+        'global_epoch': epoch,
+        'optimizer_state_dict': optimizer.state_dict(),
+        'best_acc': best_acc,
+    }, os.path.join('checkpoints', 'checkpoint_model_best.pth'))
+
+
+def _train(epoch, train_loader, model, optimizer, criterion, args):
+    model.train()
+
+    losses = 0.
+    acc = 0
+    total = 0
+    for idx, (data, target) in enumerate(train_loader):
+        if args.cuda:
+            data, target = data.cuda(), target.cuda()
+
+        output = model(data)
+        # softmax is monotonic, so argmax over the raw logits would give the same prediction
+        _, pred = F.softmax(output, dim=-1).max(1)
+        acc += pred.eq(target).sum().item()
+        total += target.size(0)
+
+        optimizer.zero_grad()
+        loss = criterion(output, target)
+        losses += loss.item()  # .item() detaches the value so the graph is not kept alive
+        loss.backward()
+        if args.gradient_clip > 0:
+            torch.nn.utils.clip_grad_norm_(model.parameters(), args.gradient_clip)
+        optimizer.step()
+
+        if idx % args.print_intervals == 0 and idx != 0:
+            print('[Epoch: {0:4d}], Loss: {1:.3f}, Acc: {2:.3f}, Correct {3} / Total {4}'.format(
+                epoch, losses / (idx + 1), acc / total * 100., acc, total))
+
+
+def _eval(epoch, test_loader, model, args):
+    model.eval()
+
+    acc = 0
+    with torch.no_grad():
+        for data, target in test_loader:
+            if args.cuda:
+                data, target = data.cuda(), target.cuda()
+            output = model(data)
+            _, pred = F.softmax(output, dim=-1).max(1)
+
+            acc += pred.eq(target).sum().item()
+        print('[Epoch: {0:4d}], Acc: {1:.3f}'.format(epoch, acc / len(test_loader.dataset) * 100.))
+
+    return acc / len(test_loader.dataset) * 100.
+
+
+def main(args):
+    train_loader, test_loader = load_data(args)
+    model = ResNet50()  # note: 1000-way head; use num_classes=10 and the CIFAR-10 branches in model.py for CIFAR-10
+
+    # move the model to its device before building the optimizer, so any
+    # checkpointed optimizer state is restored onto the same device
+    if args.cuda:
+        model = model.cuda()
+
+    optimizer = optim.SGD(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, momentum=args.momentum)
+
+    if not os.path.isdir('checkpoints'):
+        os.mkdir('checkpoints')
+
+    if args.checkpoints is not None:
+        checkpoints = torch.load(os.path.join('checkpoints', args.checkpoints), map_location='cpu')
+        model.load_state_dict(checkpoints['model_state_dict'])
+        optimizer.load_state_dict(checkpoints['optimizer_state_dict'])
+        start_epoch = checkpoints['global_epoch'] + 1  # resume from the epoch after the saved one
+    else:
+        start_epoch = 1
+
+    if not args.evaluation:
+        criterion = nn.CrossEntropyLoss()
+        lr_scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=0.0001)
+
+        global_acc = 0.
+        for epoch in range(start_epoch, args.epochs + 1):
+            _train(epoch, train_loader, model, optimizer, criterion, args)
+            best_acc = _eval(epoch, test_loader, model, args)
+            if global_acc < best_acc:
+                global_acc = best_acc
+                save_checkpoint(best_acc, model, optimizer, args, epoch)
+
+            lr_scheduler.step()
+            print('Current Learning Rate: {}'.format(lr_scheduler.get_last_lr()))
+    else:
+        _eval(start_epoch, test_loader, model, args)
+
+
+if __name__ == '__main__':
+    args = load_config()
+    main(args)
\ No newline at end of file
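
One detail worth making explicit: the epochs default of 310 matches the CosineAnnealingWarmRestarts schedule above, which completes exactly five cycles of 10, 20, 40, 80, and 160 epochs. A minimal sketch, assuming only torch, that traces the restarts:

import torch

p = torch.nn.Parameter(torch.zeros(1))  # dummy parameter just to drive the scheduler
opt = torch.optim.SGD([p], lr=0.1, momentum=0.9)
sched = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(opt, T_0=10, T_mult=2, eta_min=0.0001)

for epoch in range(1, 311):
    sched.step()  # stepped once per epoch, as in main.py
    if epoch in (10, 30, 70, 150, 310):  # cycle boundaries: 10, +20, +40, +80, +160
        print(epoch, sched.get_last_lr())  # the lr jumps back toward the base lr at each restart
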
diff --git a/model.py b/model.py
new file mode 100644
index 0000000..47748ce
--- /dev/null
+++ b/model.py
@@ -0,0 +1,150 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def get_n_params(model):
+    # count parameters by multiplying out each tensor's shape
+    # (locals renamed so they no longer shadow the torch.nn import)
+    total = 0
+    for p in model.parameters():
+        n = 1
+        for s in p.size():
+            n *= s
+        total += n
+    return total
+
+
+class MHSA(nn.Module):
+    def __init__(self, n_dims, width=14, height=14):
+        super(MHSA, self).__init__()
+
+        self.query = nn.Conv2d(n_dims, n_dims, kernel_size=1)
+        self.key = nn.Conv2d(n_dims, n_dims, kernel_size=1)
+        self.value = nn.Conv2d(n_dims, n_dims, kernel_size=1)
+
+        # learned 2D position encodings, factored into a column term and a row term
+        self.rel_h = nn.Parameter(torch.randn([1, n_dims, height, 1]), requires_grad=True)
+        self.rel_w = nn.Parameter(torch.randn([1, n_dims, 1, width]), requires_grad=True)
+
+        self.softmax = nn.Softmax(dim=-1)
+
+    def forward(self, x):
+        n_batch, C, height, width = x.size()  # NCHW
+        q = self.query(x).view(n_batch, C, -1)
+        k = self.key(x).view(n_batch, C, -1)
+        v = self.value(x).view(n_batch, C, -1)
+
+        # content-content term: dot products between all query and key positions
+        content_content = torch.bmm(q.permute(0, 2, 1), k)
+
+        # content-position term: the position encodings scored against the queries
+        content_position = (self.rel_h + self.rel_w).view(1, C, -1).permute(0, 2, 1)
+        content_position = torch.matmul(content_position, q)
+
+        energy = content_content + content_position
+        attention = self.softmax(energy)
+
+        out = torch.bmm(v, attention.permute(0, 2, 1))
+        out = out.view(n_batch, C, height, width)
+
+        return out
+
+
+class Bottleneck(nn.Module):
+    expansion = 4
+
+    def __init__(self, in_planes, planes, stride=1, mhsa=False):
+        super(Bottleneck, self).__init__()
+
+        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
+        self.bn1 = nn.BatchNorm2d(planes)
+        if not mhsa:
+            self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, stride=stride, bias=False)
+        else:
+            # the 3x3 convolution is replaced with self-attention; a strided block
+            # keeps its downsampling via average pooling after the attention
+            if stride == 2:
+                self.conv2 = nn.Sequential(
+                    MHSA(planes, width=14, height=14),
+                    # MHSA(planes, width=8, height=8),  # for CIFAR10
+                    nn.AvgPool2d(2, 2),
+                )
+            else:
+                self.conv2 = nn.Sequential(
+                    MHSA(planes, width=7, height=7),
+                    # MHSA(planes, width=4, height=4),  # for CIFAR10
+                )
+        self.bn2 = nn.BatchNorm2d(planes)
+        self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False)
+        self.bn3 = nn.BatchNorm2d(self.expansion * planes)
+
+        self.shortcut = nn.Sequential()
+        if stride != 1 or in_planes != self.expansion * planes:
+            self.shortcut = nn.Sequential(
+                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride),
+                nn.BatchNorm2d(self.expansion * planes)
+            )
+
+    def forward(self, x):
+        out = F.relu(self.bn1(self.conv1(x)))
+        out = F.relu(self.bn2(self.conv2(out)))
+        out = self.bn3(self.conv3(out))
+        out += self.shortcut(x)
+        out = F.relu(out)
+        return out
+
+
+# reference
+# https://github.com/kuangliu/pytorch-cifar/blob/master/models/resnet.py
+class ResNet(nn.Module):
+    def __init__(self, block, num_blocks, num_classes=1000):
+        super(ResNet, self).__init__()
+        self.in_planes = 64
+
+        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
+        # self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)  # for CIFAR10
+        self.bn1 = nn.BatchNorm2d(64)
+        self.relu = nn.ReLU(inplace=True)
+        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)  # for ImageNet
+
+        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
+        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
+        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
+        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2, mhsa=True)
+
+        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+        self.fc = nn.Sequential(
+            nn.Dropout(0.3),  # architectures deeper than ResNet-200 use dropout_rate 0.2
+            nn.Linear(512 * block.expansion, num_classes)
+        )
+
+    def _make_layer(self, block, planes, num_blocks, stride=1, mhsa=False):
+        strides = [stride] + [1] * (num_blocks - 1)
+        layers = []
+        for stride in strides:
+            layers.append(block(self.in_planes, planes, stride, mhsa))
+            self.in_planes = planes * block.expansion
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        out = self.relu(self.bn1(self.conv1(x)))
+        out = self.maxpool(out)  # for ImageNet
+
+        out = self.layer1(out)
+        out = self.layer2(out)
+        out = self.layer3(out)
+        out = self.layer4(out)
+
+        out = self.avgpool(out)
+        out = torch.flatten(out, 1)
+        out = self.fc(out)
+        return out
+
+
+def ResNet50(num_classes=1000):
+    return ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes)
+
+
+def main():
+    model = ResNet50()
+    x = torch.randn([2, 3, 224, 224])
+    print(model(x).size())
+    print(get_n_params(model))
+
+
+# if __name__ == '__main__':
+#     main()
\ No newline at end of file
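
Because rel_h/rel_w are fixed-size parameters, each MHSA layer only accepts feature maps of exactly the spatial size it was constructed with. A shape-only sketch (no training, CPU is fine) that checks this and the default 224x224 ImageNet configuration:

import torch
from model import MHSA, ResNet50

attn = MHSA(512, width=14, height=14)
x = torch.randn(2, 512, 14, 14)           # must match the width/height given above
print(attn(x).shape)                      # torch.Size([2, 512, 14, 14])

model = ResNet50()                        # ImageNet configuration: 7x7 stem, 14/7 MHSA sizes
print(model(torch.randn(2, 3, 224, 224)).shape)  # torch.Size([2, 1000])
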
diff --git a/preprocess.py b/preprocess.py
new file mode 100644
index 0000000..33b53fd
--- /dev/null
+++ b/preprocess.py
@@ -0,0 +1,25 @@
+from torch.utils.data import DataLoader
+from torchvision.datasets import CIFAR10
+import torchvision.transforms as transforms
+
+
+def load_data(args):
+    # standard CIFAR-10 augmentation (pad-and-crop plus horizontal flip) and
+    # the usual CIFAR-10 channel mean/std; note that main.py builds the
+    # ImageNet-configured ResNet50, so switch model.py to its CIFAR-10
+    # branches when training on this data
+    train_transform = transforms.Compose([
+        transforms.RandomCrop(32, padding=4),
+        transforms.RandomHorizontalFlip(),
+        transforms.ToTensor(),
+        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
+    ])
+    train_dataset = CIFAR10('./data', train=True, transform=train_transform, download=True)
+
+    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)
+
+    test_transform = transforms.Compose([
+        transforms.ToTensor(),
+        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
+    ])
+    test_dataset = CIFAR10('./data', train=False, transform=test_transform, download=True)
+
+    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)
+
+    return train_loader, test_loader
\ No newline at end of file
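
Finally, a smoke test for the loaders (a sketch; the Namespace merely stands in for the parsed args, and the first call downloads CIFAR-10 to ./data):

from argparse import Namespace
from preprocess import load_data

args = Namespace(batch_size=128, num_workers=0)  # num_workers=0 keeps the check single-process
train_loader, test_loader = load_data(args)

images, labels = next(iter(train_loader))
print(images.shape, labels.shape)   # torch.Size([128, 3, 32, 32]) torch.Size([128])
print(len(train_loader.dataset), len(test_loader.dataset))  # 50000 10000
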