Commit 0349947 (parent: 0854a12)
Showing 4 changed files with 309 additions and 0 deletions.
config.py (filename inferred from `from config import load_config` in the training script)

import argparse


def str2bool(v):
    # argparse's `type=bool` calls bool() on the raw string, so any non-empty
    # value (including `--cuda False`) would parse as True; accept explicit
    # truthy/falsy strings instead.
    if isinstance(v, bool):
        return v
    return v.lower() in ('true', '1', 'yes')


def load_config():
    parser = argparse.ArgumentParser()

    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--num_workers', type=int, default=4)
    parser.add_argument('--lr', type=float, default=0.1)
    parser.add_argument('--weight_decay', type=float, default=1e-4)
    parser.add_argument('--momentum', type=float, default=0.9)
    parser.add_argument('--cuda', type=str2bool, default=True)
    parser.add_argument('--epochs', type=int, default=310)
    parser.add_argument('--print_intervals', type=int, default=100)
    parser.add_argument('--evaluation', type=str2bool, default=False)
    parser.add_argument('--checkpoints', type=str, default=None, help='model checkpoints path')
    parser.add_argument('--device_num', type=int, default=1)
    parser.add_argument('--gradient_clip', type=float, default=2.)

    return parser.parse_args()
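
A quick usage sketch for the parser above (load_config() reads sys.argv, so this is meant to run as a script; with no CLI overrides it just reports the defaults):

from config import load_config

args = load_config()
print(args.lr, args.cuda, args.evaluation)  # 0.1 True False with no overrides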

Training script (filename not shown in this diff; it wires together config.py, preprocess.py, and model.py)

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import os
from config import load_config
from preprocess import load_data
from model import ResNet50


def save_checkpoint(best_acc, model, optimizer, args, epoch):
    print('Best Model Saving...')
    if args.device_num > 1:
        # assumes the model was wrapped in nn.DataParallel; this script does
        # not wrap it, so this branch only matters if a caller does so
        model_state_dict = model.module.state_dict()
    else:
        model_state_dict = model.state_dict()

    torch.save({
        'model_state_dict': model_state_dict,
        'global_epoch': epoch,
        'optimizer_state_dict': optimizer.state_dict(),
        'best_acc': best_acc,
    }, os.path.join('checkpoints', 'checkpoint_model_best.pth'))


def _train(epoch, train_loader, model, optimizer, criterion, args):
    model.train()

    losses = 0.
    acc = 0.
    total = 0.
    for idx, (data, target) in enumerate(train_loader):
        if args.cuda:
            data, target = data.cuda(), target.cuda()

        output = model(data)
        # softmax is monotonic, so the argmax matches the raw logits
        _, pred = F.softmax(output, dim=-1).max(1)
        acc += pred.eq(target).sum().item()
        total += target.size(0)

        optimizer.zero_grad()
        loss = criterion(output, target)
        # use .item() so the running sum does not keep the autograd graph alive
        losses += loss.item()
        loss.backward()
        if args.gradient_clip > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.gradient_clip)
        optimizer.step()

        if idx % args.print_intervals == 0 and idx != 0:
            print('[Epoch: {0:4d}], Loss: {1:.3f}, Acc: {2:.3f}, Correct {3} / Total {4}'.format(
                epoch, losses / (idx + 1), acc / total * 100., acc, total))


def _eval(epoch, test_loader, model, args):
    model.eval()

    acc = 0.
    with torch.no_grad():
        for data, target in test_loader:
            if args.cuda:
                data, target = data.cuda(), target.cuda()
            output = model(data)
            _, pred = F.softmax(output, dim=-1).max(1)

            acc += pred.eq(target).sum().item()

    print('[Epoch: {0:4d}], Acc: {1:.3f}'.format(epoch, acc / len(test_loader.dataset) * 100.))

    return acc / len(test_loader.dataset) * 100.


def main(args):
    train_loader, test_loader = load_data(args)
    # NOTE: preprocess.py loads CIFAR-10, but ResNet50() defaults to ImageNet
    # settings (7x7 stem, 14x14/7x7 MHSA sizes, 1000 classes); enable the
    # CIFAR-10 variants commented out in model.py before training on CIFAR-10.
    model = ResNet50()

    # move the model to the GPU before building the optimizer so that loaded
    # optimizer state (e.g. momentum buffers) ends up on the same device
    if args.cuda:
        model = model.cuda()

    optimizer = optim.SGD(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, momentum=args.momentum)

    if not os.path.isdir('checkpoints'):
        os.mkdir('checkpoints')

    if args.checkpoints is not None:
        checkpoints = torch.load(os.path.join('checkpoints', args.checkpoints))
        model.load_state_dict(checkpoints['model_state_dict'])
        optimizer.load_state_dict(checkpoints['optimizer_state_dict'])
        start_epoch = checkpoints['global_epoch'] + 1  # resume from the epoch after the saved one
    else:
        start_epoch = 1

    if not args.evaluation:
        criterion = nn.CrossEntropyLoss()
        # scheduler state is not checkpointed, so a resumed run restarts the cosine schedule
        lr_scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=0.0001)

        global_acc = 0.
        for epoch in range(start_epoch, args.epochs + 1):
            _train(epoch, train_loader, model, optimizer, criterion, args)
            best_acc = _eval(epoch, test_loader, model, args)
            if global_acc < best_acc:
                global_acc = best_acc
                save_checkpoint(best_acc, model, optimizer, args, epoch)

            lr_scheduler.step()
            print('Current Learning Rate: {}'.format(lr_scheduler.get_last_lr()))
    else:
        _eval(start_epoch, test_loader, model, args)


if __name__ == '__main__':
    args = load_config()
    main(args)
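
Example invocations, as a sketch (the training script's filename is not shown in this diff, so `main.py` here is an assumption):

# train from scratch
#   python main.py --batch_size 128 --lr 0.1 --gradient_clip 2.0
# resume from the best checkpoint
#   python main.py --checkpoints checkpoint_model_best.pth
# evaluation only
#   python main.py --evaluation True --checkpoints checkpoint_model_best.pth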

model.py (filename inferred from `from model import ResNet50` in the training script)

import torch
import torch.nn as nn
import torch.nn.functional as F


def get_n_params(model):
    # total parameter count; rewritten to avoid shadowing the `nn` module
    # alias, which the original loop reused as a local variable
    return sum(p.numel() for p in model.parameters())

class MHSA(nn.Module):
    def __init__(self, n_dims, width=14, height=14):
        super(MHSA, self).__init__()

        self.query = nn.Conv2d(n_dims, n_dims, kernel_size=1)
        self.key = nn.Conv2d(n_dims, n_dims, kernel_size=1)
        self.value = nn.Conv2d(n_dims, n_dims, kernel_size=1)

        # learned relative position encodings, one per spatial axis
        self.rel_h = nn.Parameter(torch.randn([1, n_dims, height, 1]), requires_grad=True)
        self.rel_w = nn.Parameter(torch.randn([1, n_dims, 1, width]), requires_grad=True)

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        # tensors are NCHW, so unpack as (batch, channels, height, width);
        # the original unpacked width before height, which only worked because
        # the feature maps here are square
        n_batch, C, height, width = x.size()
        q = self.query(x).view(n_batch, C, -1)
        k = self.key(x).view(n_batch, C, -1)
        v = self.value(x).view(n_batch, C, -1)

        # content-content term: dot products between query and key positions
        content_content = torch.bmm(q.permute(0, 2, 1), k)

        # content-position term: queries against the relative position encodings
        content_position = (self.rel_h + self.rel_w).view(1, C, -1).permute(0, 2, 1)
        content_position = torch.matmul(content_position, q)

        energy = content_content + content_position
        attention = self.softmax(energy)

        out = torch.bmm(v, attention.permute(0, 2, 1))
        out = out.view(n_batch, C, height, width)

        return out
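
A minimal shape check for the block above (a sketch; assumes the file is model.py, as inferred from the training script's imports):

import torch
from model import MHSA

mhsa = MHSA(n_dims=512, width=14, height=14)
x = torch.randn(2, 512, 14, 14)   # (batch, channels, height, width)
print(mhsa(x).shape)              # torch.Size([2, 512, 14, 14]) -- shape is preserved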

class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_planes, planes, stride=1, mhsa=False):
        super(Bottleneck, self).__init__()

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        if not mhsa:
            self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, stride=stride, bias=False)
        else:
            # replace the 3x3 convolution with self-attention; a stride-2
            # block downsamples with average pooling after the attention
            if stride == 2:
                self.conv2 = nn.Sequential(
                    MHSA(planes, width=14, height=14),
                    # MHSA(planes, width=8, height=8),  # for CIFAR10
                    nn.AvgPool2d(2, 2),
                )
            else:
                self.conv2 = nn.Sequential(
                    MHSA(planes, width=7, height=7),
                    # MHSA(planes, width=4, height=4),  # for CIFAR10
                )
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion * planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride),
                nn.BatchNorm2d(self.expansion * planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out
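
A sketch of how a stride-2 MHSA bottleneck behaves (assumes the classes above; the AvgPool2d halves the spatial size that the attention preserved):

import torch
from model import Bottleneck

block = Bottleneck(in_planes=1024, planes=512, stride=2, mhsa=True)
x = torch.randn(2, 1024, 14, 14)
print(block(x).shape)  # torch.Size([2, 2048, 7, 7])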

# reference
# https://github.com/kuangliu/pytorch-cifar/blob/master/models/resnet.py
class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=1000):
        super(ResNet, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        # self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)  # for CIFAR10
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)  # for ImageNet

        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        # only the last stage swaps its 3x3 convolutions for self-attention
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2, mhsa=True)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Sequential(
            nn.Dropout(0.3),  # architectures deeper than ResNet-200 use dropout_rate 0.2
            nn.Linear(512 * block.expansion, num_classes)
        )

    def _make_layer(self, block, planes, num_blocks, stride=1, mhsa=False):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride, mhsa))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.maxpool(out)  # for ImageNet

        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)

        out = self.avgpool(out)
        out = torch.flatten(out, 1)
        out = self.fc(out)
        return out


def ResNet50():
    return ResNet(Bottleneck, [3, 4, 6, 3], num_classes=1000)

def main():
    # quick smoke test: forward a dummy ImageNet-sized batch, then count parameters
    model = ResNet50()
    x = torch.randn([2, 3, 224, 224])
    print(model(x).size())
    print(get_n_params(model))


# if __name__ == '__main__':
#     main()

preprocess.py (filename inferred from `from preprocess import load_data` in the training script)

from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
import torchvision.transforms as transforms


def load_data(args):
    # standard CIFAR-10 augmentation: pad-and-crop plus horizontal flip,
    # then normalize with the usual CIFAR-10 channel statistics
    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    train_dataset = CIFAR10('./data', train=True, transform=train_transform, download=True)

    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)

    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    test_dataset = CIFAR10('./data', train=False, transform=test_transform, download=True)

    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)

    return train_loader, test_loader
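
A loader smoke test, as a sketch (assumes config.py and preprocess.py as above; run as a script since load_config() reads sys.argv):

from config import load_config
from preprocess import load_data

args = load_config()
train_loader, test_loader = load_data(args)
images, labels = next(iter(train_loader))
print(images.shape, labels.shape)  # e.g. torch.Size([128, 3, 32, 32]) torch.Size([128])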