-
Notifications
You must be signed in to change notification settings - Fork 21
/
Copy pathtrain.py
173 lines (144 loc) · 6.71 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
import _init_paths
import os
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
import torch.backends.cudnn as cudnn
from model.net import get_model
from dataloader.triplet_img_loader import get_loader
from utils.gen_utils import make_dir_if_not_exist
from utils.vis_utils import vis_with_paths, vis_with_paths_and_bboxes
from config.base_config import cfg, cfg_from_file
def main():
    """Entry point: build the model, run the per-epoch test/train loop,
    and periodically write checkpoints.

    Relies on the module-level ``args`` and ``device`` set up in the
    ``__main__`` block.
    """
    torch.manual_seed(1)
    if args.cuda:
        torch.cuda.manual_seed(1)
        cudnn.benchmark = True

    exp_dir = os.path.join(args.result_dir, args.exp_name)
    make_dir_if_not_exist(exp_dir)

    # Build the model; get_model may return None (e.g. unknown config).
    model = get_model(args, device)
    if model is None:
        return

    # Optimize only the parameters that require gradients.
    params = [{'params': [p]} for p in model.parameters() if p.requires_grad]
    criterion = torch.nn.MarginRankingLoss(margin=args.margin)
    optimizer = optim.Adam(params, lr=args.lr)

    for epoch in range(1, args.epochs + 1):
        # Loaders are rebuilt every epoch — presumably to resample
        # triplets each time; verify against get_loader if changing this.
        train_data_loader, test_data_loader = get_loader(args)

        # Evaluate first, then train for this epoch.
        test(test_data_loader, model, criterion)
        train(train_data_loader, model, criterion, optimizer, epoch)

        # Checkpoint every ckp_freq epochs; "epoch" stores the NEXT epoch
        # so a resumed run continues where this one left off.
        if epoch % args.ckp_freq == 0:
            model_to_save = {
                "epoch": epoch + 1,
                'state_dict': model.state_dict(),
            }
            file_name = os.path.join(exp_dir, "checkpoint_" + str(epoch) + ".pth")
            save_checkpoint(model_to_save, file_name)
def train(data, model, criterion, optimizer, epoch):
    """Run one training epoch over ``data``.

    Args:
        data: iterable of (anchor, pos, neg) image batch triplets.
        model: network returning embeddings (E1, E2, E3) for the triplet.
        criterion: MarginRankingLoss instance.
        optimizer: optimizer over the model's trainable parameters.
        epoch: current epoch number (used only for logging).

    Uses the module-level ``args`` (train_log_step) and ``device``.
    """
    print("******** Training ********")
    total_loss = 0
    model.train()
    for batch_idx, img_triplet in enumerate(data):
        anchor_img, pos_img, neg_img = img_triplet
        anchor_img, pos_img, neg_img = anchor_img.to(device), pos_img.to(device), neg_img.to(device)
        E1, E2, E3 = model(anchor_img, pos_img, neg_img)
        dist_E1_E2 = F.pairwise_distance(E1, E2, 2)
        dist_E1_E3 = F.pairwise_distance(E1, E3, 2)
        # target = -1 tells MarginRankingLoss to push dist(anchor, pos)
        # below dist(anchor, neg) by at least the margin.
        target = torch.full_like(dist_E1_E2, -1)
        loss = criterion(dist_E1_E2, dist_E1_E3, target)
        # .item() detaches the scalar: accumulating the loss tensor itself
        # would keep every batch's autograd graph alive (memory leak).
        total_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        log_step = args.train_log_step
        if (batch_idx % log_step == 0) and (batch_idx != 0):
            print('Train Epoch: {} [{}/{}] \t Loss: {:.4f}'.format(epoch, batch_idx, len(data), total_loss / log_step))
            total_loss = 0
    print("****************")
def test(data, model, criterion):
    """Evaluate the model on ``data`` and print the mean triplet loss plus
    accuracies at several margin-fraction thresholds.

    Args:
        data: iterable of (anchor, pos, neg) image batch triplets.
        model: network returning embeddings (E1, E2, E3) for the triplet.
        criterion: MarginRankingLoss instance.

    Uses the module-level ``args`` (margin) and ``device``.
    """
    print("******** Testing ********")
    with torch.no_grad():
        model.eval()
        # accuracies[i] accumulates, per batch, the fraction of triplets where
        # dist(anchor, neg) exceeds dist(anchor, pos) by more than
        # acc_threshes[i] * margin.
        accuracies = [0, 0, 0]
        acc_threshes = [0, 0.2, 0.5]
        total_loss = 0
        for img_triplet in data:
            anchor_img, pos_img, neg_img = img_triplet
            anchor_img, pos_img, neg_img = anchor_img.to(device), pos_img.to(device), neg_img.to(device)
            E1, E2, E3 = model(anchor_img, pos_img, neg_img)
            dist_E1_E2 = F.pairwise_distance(E1, E2, 2)
            dist_E1_E3 = F.pairwise_distance(E1, E3, 2)
            # target = -1: positive distance should be the smaller one.
            target = torch.full_like(dist_E1_E2, -1)
            loss = criterion(dist_E1_E2, dist_E1_E3, target)
            # .item() keeps total_loss a plain float (also prints cleanly).
            total_loss += loss.item()
            for i in range(len(accuracies)):
                # Margin satisfied by more than the threshold fraction?
                prediction = (dist_E1_E3 - dist_E1_E2 - args.margin * acc_threshes[i]).cpu()
                prediction = prediction.view(prediction.numel())
                prediction = (prediction > 0).float()
                # .item() so accuracies stay plain floats, not 0-dim tensors.
                batch_acc = prediction.sum().item() / prediction.numel()
                accuracies[i] += batch_acc
        print('Test Loss: {}'.format(total_loss / len(data)))
        for i in range(len(accuracies)):
            print(
                'Test Accuracy with diff = {}% of margin: {}'.format(acc_threshes[i] * 100, accuracies[i] / len(data)))
    print("****************")
def save_checkpoint(state, file_name):
    """Serialize a checkpoint dict (epoch, state_dict, ...) to ``file_name``."""
    torch.save(state, file_name)
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='PyTorch Siamese Example')
    parser.add_argument('--result_dir', default='data', type=str,
                        help='Directory to store results')
    parser.add_argument('--exp_name', default='exp0', type=str,
                        help='name of experiment')
    parser.add_argument('--cuda', action='store_true', default=False,
                        help='enables CUDA training')
    parser.add_argument("--gpu_devices", type=int, nargs='+', default=None,
                        help="List of GPU Devices to train on")
    parser.add_argument('--epochs', type=int, default=10, metavar='N',
                        help='number of epochs to train (default: 10)')
    parser.add_argument('--ckp_freq', type=int, default=1, metavar='N',
                        help='Checkpoint Frequency (default: 1)')
    parser.add_argument('--batch_size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--lr', type=float, default=0.0001, metavar='LR',
                        help='learning rate (default: 0.0001)')
    parser.add_argument('--margin', type=float, default=1.0, metavar='M',
                        help='margin for triplet loss (default: 1.0)')
    parser.add_argument('--ckp', default=None, type=str,
                        help='path to load checkpoint')
    parser.add_argument('--dataset', type=str, default='mnist', metavar='M',
                        help='Dataset (default: mnist)')
    # Fix: help text previously claimed defaults of 3000/1000, which did not
    # match the actual defaults (50000/10000) below.
    parser.add_argument('--num_train_samples', type=int, default=50000, metavar='M',
                        help='number of training samples (default: 50000)')
    parser.add_argument('--num_test_samples', type=int, default=10000, metavar='M',
                        help='number of test samples (default: 10000)')
    parser.add_argument('--train_log_step', type=int, default=100, metavar='M',
                        help='Number of iterations after which to log the loss')

    # Note: names assigned at module level are already global, so no `global`
    # statement is needed here; train()/test()/main() read args and device
    # as module globals.
    args = parser.parse_args()
    # Only honor --cuda when a CUDA device is actually available.
    args.cuda = args.cuda and torch.cuda.is_available()
    cfg_from_file("config/test.yaml")

    if args.cuda:
        device = 'cuda'
        if args.gpu_devices is None:
            args.gpu_devices = [0]
    else:
        device = 'cpu'
    main()