"""Thie file contains the main code for training the relevant networks.
Available models for training include:
- CLS model (with/without imagenet pretraining)
| - grconvnet
| - alexnet
- Grasp model (with/without imagenet pretraining)
| - alexnet
Comment or uncomment certain lines of code for swapping between
training CLS model and Grasping model.
E.g. Uncomment the lines with NO SPACE between '#' and the codes:
"Training for Grasping"
# Loss fn for CLS training
#loss = nn.CrossEntropyLoss()
# Loss fn for Grasping
loss = nn.MSELoss()
----->
# Loss fn for CLS training
loss = nn.CrossEntropyLoss()
# Loss fn for Grasping
#loss = nn.MSELoss()
"""
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torchvision.models import alexnet
from tqdm import tqdm
from multi_task_models.grcn_multi_alex import Multi_AlexnetMap_v3
from utils.paths import Path
from utils.parameters import Params
from data_processing.data_loader_v2 import DataLoader
from utils.utils import epoch_logger, log_writer, get_correct_cls_preds_from_map, get_acc
from utils.grasp_utils import get_correct_grasp_preds_from_map
from training.single_task.evaluation import get_cls_acc, get_grasp_acc
from training.single_task.loss import MapLoss, DistillationLoss
SEED = 42
params = Params()
paths = Path()
# Create <trained-models> directory
paths.create_model_path()
# Create directory for training logs
paths.create_log_path()
# Create subdirectory in <logs> for current model
paths.create_model_log_path()
# Set common seed
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
# Load model
#model = models.AlexnetMap_v3().to(params.DEVICE)
#model = modelsGr.GrConvMap_v1().to(params.DEVICE)
model = Multi_AlexnetMap_v3().to(params.DEVICE)
# Teacher model
# pretrained_alexnet = alexnet(pretrained=True).to(params.DEVICE)
# pretrained_alexnet.eval()
# for weights in pretrained_alexnet.features.parameters():
# weights.requires_grad = False
# Load checkpoint weights
# checkpoint_name = 'alexnetGrasp_depthconcat_convtrans_top5_v4.3'
# checkpoint_epoch = 50
# checkpoint_path = os.path.join(params.MODEL_PATH, checkpoint_name, '%s_epoch%s.pth' % (checkpoint_name, checkpoint_epoch))
# model.load_state_dict(torch.load(checkpoint_path))
# Create DataLoader class
data_loader = DataLoader(params.TRAIN_PATH, params.BATCH_SIZE, params.TRAIN_VAL_SPLIT, seed=SEED)
# Get number of training/validation steps
n_train, n_val = data_loader.get_train_val()
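# Steps [0, n_train) of each epoch are treated as training batches; the
# remaining steps are treated as validation (see 'step < n_train' below)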
# Training utils
optim = Adam(model.parameters(), lr=params.LR)
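# StepLR multiplies the learning rate by 0.5 every 25 scheduler.step() calls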
scheduler = torch.optim.lr_scheduler.StepLR(optim, 25, 0.5)
for epoch in tqdm(range(1, params.EPOCHS + 1)):
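    # The depth backbone stays frozen for the early epochs and is unfrozen at epoch 75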
if epoch == 75:
model.unfreeze_depth_backbone()
train_history = []
c_train_history = []
val_history = []
c_val_history = []
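    # Counters start at 1, presumably to avoid division by zero when computing accuracies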
train_total = 1
train_correct = 1
val_total = 1
val_correct = 1
# Data loop for CLS training
#for step, (img, map, label) in enumerate(data_loader.load_batch()):
# Data loop for Grasp training
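    # zip pairs one grasp batch with one CLS batch per step and stops at the
    # shorter of the two loaders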
image_data = enumerate(zip(data_loader.load_grasp_batch(), data_loader.load_batch()))
for step, ((img_grp, map_grp, label_grp), (img_cls, map_cls, label_cls)) in image_data:
optim.zero_grad()
output_cls = model(img_cls, is_grasp=False)
loss_cls = MapLoss(output_cls, map_cls)
output_grp = model(img_grp, is_grasp=True)
loss_grp = MapLoss(output_grp, map_grp)
        # Combined loss for joint CLS/Grasp training (the two weights sum to 2)
        loss = params.loss_weight * loss_grp + (2 - params.loss_weight) * loss_cls
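        # e.g. a loss_weight of 1.0 weighs both tasks equally, while 1.5 would
        # weigh the grasp loss at 1.5 and the CLS loss at 0.5 (illustrative
        # values; the actual weight comes from Params)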
# Distillation loss (experimental)
#distill_loss = DistillationLoss(img, model, pretrained_alexnet, model_s_type='alexnetMap', model_t_type='alexnet')
#loss = loss + distill_loss * params.DISTILL_ALPHA
if step < n_train:
loss.backward()
optim.step()
# Write loss to log file -- 'logs/<model_name>/<model_name>_log.txt'
# log_writer(params.MODEL_NAME, epoch, step, loss.item(), train=True)
            train_history.append(loss_grp.detach())
            c_train_history.append(loss_cls.detach())
            # Dummy prediction stats (per-step accuracy is not computed here)
correct, total = 0, 1
train_correct += correct
train_total += total
else:
# log_writer(params.MODEL_NAME, epoch, step, loss.item(), train=False)
            val_history.append(loss_grp.detach())
            c_val_history.append(loss_cls.detach())
            # Dummy prediction stats (per-step accuracy is not computed here)
correct, total = 0, 1
val_correct += correct
val_total += total
# Get testing accuracy stats (CLS / Grasp)
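    # Evaluation runs on epochs 1, 11, 21, ... (epoch % 10 == 1)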
    if epoch % 10 == 1:
model.eval()
c_train_acc, c_train_loss = get_cls_acc(model, include_depth=True, seed=SEED, dataset=params.TRAIN_PATH, truncation=None)
train_acc, train_loss = get_grasp_acc(model, include_depth=True, seed=SEED, dataset=params.TRAIN_PATH, truncation=None)
c_test_acc, c_test_loss = get_cls_acc(model, include_depth=True, seed=SEED, dataset=params.TEST_PATH, truncation=None)
        test_acc, test_loss = get_grasp_acc(model, include_depth=True, seed=SEED, dataset=params.TEST_PATH, truncation=None)
scheduler.step()
# Experimental
#params.DISTILL_ALPHA /= 2
model.train()
    # Training and validation accuracies -- a separate validation accuracy is
    # not computed, so the training accuracy is reused as a placeholder
    val_acc = train_acc  # get_acc(val_correct, val_total)
    c_val_acc = c_train_acc
# Write epoch loss stats to log file
    epoch_logger(params.MODEL_NAME, epoch, train_history, val_history, test_loss,
                 train_acc, val_acc, test_acc, c_train_history, c_val_history,
                 c_test_loss, c_train_acc, c_val_acc, c_test_acc)
    # Save checkpoint weights -- '<MODEL_LOG_PATH>/<model_name>_epoch<epoch>.pth'
torch.save(model.state_dict(), os.path.join(params.MODEL_LOG_PATH, f"{params.MODEL_NAME}_epoch{epoch}.pth"))
# Save final epoch model
torch.save(model.state_dict(), os.path.join(params.MODEL_LOG_PATH, f"{params.MODEL_NAME}_final.pth"))