-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
59 lines (45 loc) · 1.84 KB
/
Copy pathtrain.py
File metadata and controls
59 lines (45 loc) · 1.84 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import pickle
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
# Expected contents of data.pickle: a list of samples, each one a
# three-element list [answer, color1, color2], for example
#     [[3, (50, 120, 38), (85, 200, 5)],
#      [0, (220, 52, 135), (200, 83, 121)],
#      ...]
# NOTE(review): pickle.load can execute arbitrary code while unpickling;
# only point this at a data.pickle you produced yourself / trust.
with open('data.pickle', mode='rb') as f:
    data = pickle.load(f)

# Full-batch loader: every epoch yields a single shuffled batch that
# contains all samples (batch_size equals the dataset length).
dataloader = DataLoader(
    data,
    shuffle=True,
    batch_size=len(data),
)
def _as_float_tensor(batch_field, device):
    """Return one collated batch field as a float32 tensor on ``device``.

    With PyTorch's default collate function, a per-sample tuple such as
    ``(50, 120, 38)`` arrives as a list/tuple of per-position tensors, each of
    shape ``(batch,)``. Those are stacked along dim 1 so the result has one row
    per sample. A field that is already a tensor is only cast and moved.
    """
    if isinstance(batch_field, (list, tuple)):
        batch_field = torch.stack([torch.as_tensor(t) for t in batch_field], dim=1)
    return batch_field.to(device=device, dtype=torch.float32)


def train(model, n_epoch, criterion, optimizer, scheduler=None, scale=1, epoch_step=200,
          loader=None, device=None):
    """Train ``model`` so the distance between its two outputs matches a scaled target.

    Each batch is unpacked as ``(answer, color1, color2)``. Both colour inputs
    are passed through ``model``, and the loss is
    ``criterion(pairwise_distance(model(color1), model(color2)), answer * scale)``.

    Args:
        model: Module to train. It must also provide ``show_space(save=, epoch=)``
            and the attributes ``param_dir``, ``model_name`` and
            ``train_history_dir``, which are used for checkpoints and history.
        n_epoch: Number of passes over ``loader``.
        criterion: Loss callable taking ``(distances, targets)`` and returning
            a scalar tensor.
        optimizer: Optimizer over ``model``'s parameters.
        scheduler: Optional LR scheduler. ``step()`` is called after every
            optimizer step.
        scale: Multiplier applied to ``answer`` to form the regression target.
        epoch_step: Every ``epoch_step`` epochs, print the loss, call
            ``model.show_space`` and save a checkpoint.
        loader: Iterable of batches. Defaults to the module-level ``dataloader``.
        device: Device for the input tensors. Defaults to the device of the
            model's first parameter, or ``'cuda'`` if the model has none.

    Side effects:
        Writes ``<param_dir><model_name> epoch <k>.pt`` for k = 0 and for every
        ``epoch_step``-th epoch. Writes ``<train_history_dir>train history.pickle``
        containing the per-epoch mean batch loss; the value is NaN for an epoch
        in which the loader yielded no batches.
    """
    if loader is None:
        loader = dataloader
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cuda')

    # Save the initial (untrained) state so later checkpoints can be compared to it.
    model.show_space(save=True, epoch=0)
    torch.save(model.state_dict(), model.param_dir + f'{model.model_name} epoch 0.pt')

    loss_list = []
    for epoch in range(n_epoch):
        running_loss = 0.0
        n_batches = 0
        for answer, color1, color2 in loader:
            answer = _as_float_tensor(answer, device)
            color1 = _as_float_tensor(color1, device)
            color2 = _as_float_tensor(color2, device)

            optimizer.zero_grad()
            output1 = model(color1)
            output2 = model(color2)
            dist = nn.functional.pairwise_distance(output1, output2)
            loss = criterion(dist, answer * scale)
            loss.backward()
            optimizer.step()
            if scheduler is not None:
                scheduler.step()
            running_loss += loss.item()
            n_batches += 1

        # Mean loss over the batches actually seen this epoch. This is counted
        # explicitly, so iterables without __len__ and empty loaders also work.
        epoch_loss = running_loss / n_batches if n_batches else float('nan')
        loss_list.append(epoch_loss)
        if (epoch + 1) % epoch_step == 0:
            print('Epoch {} loss: {:.4f}'.format(epoch + 1, epoch_loss))
            model.show_space(save=True, epoch=epoch + 1)
            torch.save(model.state_dict(), model.param_dir + f'{model.model_name} epoch {epoch + 1}.pt')

    with open(model.train_history_dir + 'train history.pickle', 'wb') as f:
        pickle.dump(loss_list, f, pickle.HIGHEST_PROTOCOL)
    print("Successfully finished.")