Commit

init:
leaderj1001 committed Jan 28, 2021
1 parent 0854a12 commit 0349947
Showing 4 changed files with 309 additions and 0 deletions.
20 changes: 20 additions & 0 deletions config.py
@@ -0,0 +1,20 @@
import argparse


def str2bool(v):
    # argparse's type=bool treats any non-empty string (including 'False') as True,
    # so boolean command-line flags are parsed explicitly here
    return str(v).lower() in ('true', '1', 'yes')


def load_config():
parser = argparse.ArgumentParser()

parser.add_argument('--batch_size', type=int, default=128)
parser.add_argument('--num_workers', type=int, default=4)
parser.add_argument('--lr', type=float, default=0.1)
parser.add_argument('--weight_decay', type=float, default=1e-4)
parser.add_argument('--momentum', type=float, default=0.9)
    parser.add_argument('--cuda', type=str2bool, default=True)
parser.add_argument('--epochs', type=int, default=310)
parser.add_argument('--print_intervals', type=int, default=100)
    parser.add_argument('--evaluation', type=str2bool, default=False)
parser.add_argument('--checkpoints', type=str, default=None, help='model checkpoints path')
parser.add_argument('--device_num', type=int, default=1)
parser.add_argument('--gradient_clip', type=float, default=2.)

return parser.parse_args()
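
A quick sanity check of the boolean parsing (a minimal sketch; sys.argv is set by hand so the snippet runs without a real command line):

import sys
from config import load_config

sys.argv = ['main.py', '--evaluation', 'False']  # hypothetical invocation
args = load_config()
print(args.evaluation)  # False with str2bool; plain type=bool would have printed True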
114 changes: 114 additions & 0 deletions main.py
@@ -0,0 +1,114 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import os
from config import load_config
from preprocess import load_data
from model import ResNet50


def save_checkpoint(best_acc, model, optimizer, args, epoch):
print('Best Model Saving...')
if args.device_num > 1:
model_state_dict = model.module.state_dict()
else:
model_state_dict = model.state_dict()

torch.save({
'model_state_dict': model_state_dict,
'global_epoch': epoch,
'optimizer_state_dict': optimizer.state_dict(),
'best_acc': best_acc,
}, os.path.join('checkpoints', 'checkpoint_model_best.pth'))


def _train(epoch, train_loader, model, optimizer, criterion, args):
model.train()

losses = 0.
acc = 0.
total = 0.
for idx, (data, target) in enumerate(train_loader):
if args.cuda:
data, target = data.cuda(), target.cuda()

output = model(data)
_, pred = F.softmax(output, dim=-1).max(1)
acc += pred.eq(target).sum().item()
total += target.size(0)

optimizer.zero_grad()
loss = criterion(output, target)
        losses += loss.item()  # .item() detaches the value so the autograd graph is not kept across iterations
loss.backward()
if args.gradient_clip > 0:
torch.nn.utils.clip_grad_norm_(model.parameters(), args.gradient_clip)
optimizer.step()

if idx % args.print_intervals == 0 and idx != 0:
print('[Epoch: {0:4d}], Loss: {1:.3f}, Acc: {2:.3f}, Correct {3} / Total {4}'.format(epoch,
losses / (idx + 1),
acc / total * 100.,
acc, total))


def _eval(epoch, test_loader, model, args):
model.eval()

acc = 0.
with torch.no_grad():
for data, target in test_loader:
if args.cuda:
data, target = data.cuda(), target.cuda()
output = model(data)
_, pred = F.softmax(output, dim=-1).max(1)

acc += pred.eq(target).sum().item()
print('[Epoch: {0:4d}], Acc: {1:.3f}'.format(epoch, acc / len(test_loader.dataset) * 100.))

return acc / len(test_loader.dataset) * 100.


def main(args):
train_loader, test_loader = load_data(args)
model = ResNet50()
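    # NOTE: preprocess.py loads CIFAR-10 (10 classes, 32x32 images), while this
    # ResNet50 is configured for ImageNet (num_classes=1000, MHSA sizes 14 and 7);
    # for CIFAR-10, use num_classes=10 and the CIFAR variants commented in model.py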

optimizer = optim.SGD(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, momentum=args.momentum)

if not os.path.isdir('checkpoints'):
os.mkdir('checkpoints')

if args.checkpoints is not None:
checkpoints = torch.load(os.path.join('checkpoints', args.checkpoints))
model.load_state_dict(checkpoints['model_state_dict'])
optimizer.load_state_dict(checkpoints['optimizer_state_dict'])
start_epoch = checkpoints['global_epoch']
else:
start_epoch = 1

if args.cuda:
model = model.cuda()

if not args.evaluation:
criterion = nn.CrossEntropyLoss()
lr_scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=0.0001)

global_acc = 0.
for epoch in range(start_epoch, args.epochs + 1):
_train(epoch, train_loader, model, optimizer, criterion, args)
best_acc = _eval(epoch, test_loader, model, args)
if global_acc < best_acc:
global_acc = best_acc
save_checkpoint(best_acc, model, optimizer, args, epoch)

lr_scheduler.step()
print('Current Learning Rate: {}'.format(lr_scheduler.get_last_lr()))
else:
_eval(start_epoch, test_loader, model, args)


if __name__ == '__main__':
args = load_config()
main(args)
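
To resume or inspect a saved run, the file written by save_checkpoint() can be loaded directly (a minimal sketch using the keys defined above):

import torch

ckpt = torch.load('checkpoints/checkpoint_model_best.pth', map_location='cpu')
print(ckpt['global_epoch'], ckpt['best_acc'])
# model.load_state_dict(ckpt['model_state_dict']) restores the weights,
# which is what main() does when --checkpoints is given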
150 changes: 150 additions & 0 deletions model.py
@@ -0,0 +1,150 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


def get_n_params(model):
    # total number of parameter elements in the model
    return sum(p.numel() for p in model.parameters())


class MHSA(nn.Module):
def __init__(self, n_dims, width=14, height=14):
super(MHSA, self).__init__()

self.query = nn.Conv2d(n_dims, n_dims, kernel_size=1)
self.key = nn.Conv2d(n_dims, n_dims, kernel_size=1)
self.value = nn.Conv2d(n_dims, n_dims, kernel_size=1)

self.rel_h = nn.Parameter(torch.randn([1, n_dims, height, 1]), requires_grad=True)
self.rel_w = nn.Parameter(torch.randn([1, n_dims, 1, width]), requires_grad=True)

self.softmax = nn.Softmax(dim=-1)

def forward(self, x):
        n_batch, C, height, width = x.size()  # tensor layout is (N, C, H, W)
q = self.query(x).view(n_batch, C, -1)
k = self.key(x).view(n_batch, C, -1)
v = self.value(x).view(n_batch, C, -1)

content_content = torch.bmm(q.permute(0, 2, 1), k)

content_position = (self.rel_h + self.rel_w).view(1, C, -1).permute(0, 2, 1)
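        # note: matmul(content_position, q) yields r_i . q_j for the position term,
        # the transposed ordering relative to the content term's q_i . k_j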
content_position = torch.matmul(content_position, q)

energy = content_content + content_position
attention = self.softmax(energy)

out = torch.bmm(v, attention.permute(0, 2, 1))
        out = out.view(n_batch, C, height, width)

return out


class Bottleneck(nn.Module):
expansion = 4

def __init__(self, in_planes, planes, stride=1, mhsa=False):
super(Bottleneck, self).__init__()

self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
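        # when mhsa=True, the 3x3 convolution below is replaced by multi-head
        # self-attention; stride-2 stages downsample with AvgPool2d since MHSA
        # itself does not stride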
if not mhsa:
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, stride=stride, bias=False)
else:
if stride == 2:
self.conv2 = nn.Sequential(
MHSA(planes, width=14, height=14),
# MHSA(planes, width=8, height=8), # for CIFAR10
nn.AvgPool2d(2, 2),
)
else:
self.conv2 = nn.Sequential(
MHSA(planes, width=7, height=7),
# MHSA(planes, width=4, height=4), # for CIFAR10
)
self.bn2 = nn.BatchNorm2d(planes)
self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False)
self.bn3 = nn.BatchNorm2d(self.expansion * planes)

self.shortcut = nn.Sequential()
if stride != 1 or in_planes != self.expansion*planes:
self.shortcut = nn.Sequential(
nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride),
nn.BatchNorm2d(self.expansion*planes)
)

def forward(self, x):
out = F.relu(self.bn1(self.conv1(x)))
out = F.relu(self.bn2(self.conv2(out)))
out = self.bn3(self.conv3(out))
out += self.shortcut(x)
out = F.relu(out)
return out


# reference
# https://github.com/kuangliu/pytorch-cifar/blob/master/models/resnet.py
class ResNet(nn.Module):
def __init__(self, block, num_blocks, num_classes=1000):
super(ResNet, self).__init__()
self.in_planes = 64

self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
# self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(64)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) # for ImageNet

self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2, mhsa=True)

self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Sequential(
nn.Dropout(0.3), # All architecture deeper than ResNet-200 dropout_rate: 0.2
nn.Linear(512 * block.expansion, num_classes)
)

def _make_layer(self, block, planes, num_blocks, stride=1, mhsa=False):
strides = [stride] + [1]*(num_blocks-1)
layers = []
for idx, stride in enumerate(strides):
layers.append(block(self.in_planes, planes, stride, mhsa))
self.in_planes = planes * block.expansion
return nn.Sequential(*layers)

def forward(self, x):
out = self.relu(self.bn1(self.conv1(x)))
out = self.maxpool(out) # for ImageNet

out = self.layer1(out)
out = self.layer2(out)
out = self.layer3(out)
out = self.layer4(out)

out = self.avgpool(out)
out = torch.flatten(out, 1)
out = self.fc(out)
return out


def ResNet50():
return ResNet(Bottleneck, [3, 4, 6, 3], num_classes=1000)


def main():
model = ResNet50()
x = torch.randn([2, 3, 224, 224])
print(model(x).size())
print(get_n_params(model))


# if __name__ == '__main__':
# main()
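
For the attention block alone, a quick shape check (a minimal sketch; the sizes match layer4's 14x14 feature maps on 224x224 inputs):

import torch
from model import MHSA

mhsa = MHSA(n_dims=512, width=14, height=14)
x = torch.randn(2, 512, 14, 14)
print(mhsa(x).size())  # torch.Size([2, 512, 14, 14]) - output keeps the input shape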
25 changes: 25 additions & 0 deletions preprocess.py
@@ -0,0 +1,25 @@
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
import torchvision.transforms as transforms


def load_data(args):
train_transform = transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
train_dataset = CIFAR10('./data', train=True, transform=train_transform, download=True)

train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)

test_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
test_dataset = CIFAR10('./data', train=False, transform=test_transform, download=True)

test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)

return train_loader, test_loader
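
One batch from the resulting loaders (a minimal sketch; shapes follow the CIFAR-10 transforms above):

from config import load_config
from preprocess import load_data

args = load_config()
train_loader, test_loader = load_data(args)
images, labels = next(iter(train_loader))
print(images.size())  # torch.Size([128, 3, 32, 32]) with the default batch_size
print(labels[:8])     # integer class labels in [0, 9]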
