-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtrain_masked.py
122 lines (103 loc) · 5.05 KB
/
train_masked.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import argparse
import time
import os
import sys
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from datasets import FCVID, ACTNET_tokens, YLIMED_tokens, miniKinetics_tokens
from model import ModelGCN_maskedblock_video as Model
# Command-line interface for the masked-token GCN video-classification trainer.
# Typo fix: description previously read "Clasmsification".
parser = argparse.ArgumentParser(description='GCN Video Classification')
parser.add_argument('--gcn_layers', type=int, default=2, help='number of gcn layers')
parser.add_argument('--dataset', default='minikinetics', choices=['fcvid', 'minikinetics', 'actnet', 'ylimed'])
parser.add_argument('--dataset_root', default='/home/dimidask/Projects/miniKinetics_130k', help='dataset root directory')
parser.add_argument('--lr', type=float, default=1e-3, help='initial learning rate')
parser.add_argument('--milestones', nargs="+", type=int, default=[50, 100], help='milestones of learning decay')
parser.add_argument('--num_epochs', type=int, default=200, help='number of epochs to train')
parser.add_argument('--batch_size', type=int, default=64, help='batch size')
parser.add_argument('--num_objects', type=int, default=50, help='number of objects with best DoC')
parser.add_argument('--num_workers', type=int, default=4, help='number of workers for data loader')
parser.add_argument('--ext_method', default='VIT', choices=['VIT', 'CLIP'], help='Extraction method for features')
parser.add_argument('--resume', default=None, help='checkpoint to resume training')
parser.add_argument('--save_interval', type=int, default=100, help='interval for saving models (epochs)')
parser.add_argument('--save_folder', default='weights', help='directory to save checkpoints')
parser.add_argument('-v', '--verbose', action='store_true', help='show details')
args = parser.parse_args()
def compute_tokens(tokens):
    """Build a binary target mask marking the top 25% of tokens by total score.

    Scores are obtained by summing ``tokens`` over dims 1 and 2; the result has
    the same shape as that reduced tensor, with 1s at the top-k positions along
    dim 1 (k = floor(0.25 * width)) and 0s elsewhere.
    """
    # Collapse dims 1 and 2 so each remaining position gets a single score.
    scores = tokens.sum(dim=(1, 2))
    # Keep the top quarter (floor) of positions along dim 1.
    keep = int(0.25 * scores.shape[1])
    _, keep_idx = torch.topk(scores, k=keep, dim=1)
    # One-hot the selected indices into a fresh zero mask.
    mask = torch.zeros_like(scores)
    return mask.scatter(1, keep_idx, 1)
def train(model, loader, crit, opt, sched, device):
    """Run one optimization epoch and return the mean per-batch loss.

    Each batch is unpacked as (features, tokens, _); the token target is
    derived via compute_tokens before moving both to ``device``. The LR
    scheduler is stepped once per epoch, after all batches.
    """
    running_loss = 0.0
    for feats, tokens, _ in loader:
        target = compute_tokens(tokens)
        feats, target = feats.to(device), target.to(device)

        opt.zero_grad()
        loss = crit(model(feats), target)
        loss.backward()
        opt.step()

        running_loss += loss.item()

    # Per-epoch (not per-batch) LR decay.
    sched.step()
    return running_loss / len(loader)
def main():
    """Build the dataset, model, and optimizer from CLI args, then train.

    Checkpoints (model/optimizer/scheduler state plus epoch and loss) are
    written to ``args.save_folder`` every ``args.save_interval`` epochs.
    """
    if not os.path.exists(args.save_folder):
        os.mkdir(args.save_folder)

    # Dataset class and matching loss criterion per --dataset choice.
    builders = {
        'fcvid': (FCVID, nn.BCEWithLogitsLoss),
        'actnet': (ACTNET_tokens, nn.BCEWithLogitsLoss),
        'minikinetics': (miniKinetics_tokens, nn.CrossEntropyLoss),
        'ylimed': (YLIMED_tokens, nn.BCEWithLogitsLoss),
    }
    if args.dataset not in builders:
        sys.exit("Unknown dataset!")
    dataset_cls, crit_cls = builders[args.dataset]
    dataset = dataset_cls(args.dataset_root, is_train=True, ext_method=args.ext_method)
    crit = crit_cls()

    device = torch.device('cuda:0')
    loader = DataLoader(dataset, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=True)

    if args.verbose:
        print("running on {}".format(device))
        print("num samples={}".format(len(dataset)))
        print("missing videos={}".format(dataset.num_missing))

    start_epoch = 0
    model = Model(args.gcn_layers, dataset.NUM_FEATS, args.batch_size, dataset.NUM_FRAMES, dataset.NUM_BOXES).to(device)
    opt = optim.Adam(model.parameters(), lr=args.lr)
    sched = optim.lr_scheduler.MultiStepLR(opt, milestones=args.milestones)

    # Optionally restore model/optimizer/scheduler state and resume epoch count.
    if args.resume:
        ckpt = torch.load(args.resume)
        start_epoch = ckpt['epoch']
        model.load_state_dict(ckpt['model_state_dict'])
        opt.load_state_dict(ckpt['opt_state_dict'])
        sched.load_state_dict(ckpt['sched_state_dict'])
        if args.verbose:
            print("resuming from epoch {}".format(start_epoch))

    model.train()
    for epoch in range(start_epoch, args.num_epochs):
        t0 = time.perf_counter()
        loss = train(model, loader, crit, opt, sched, device)
        t1 = time.perf_counter()

        if (epoch + 1) % args.save_interval == 0:
            ckpt_name = 'model-{}-tokens_clip-{:03d}.pt'.format(args.dataset, epoch + 1)
            ckpt_path = os.path.join(args.save_folder, ckpt_name)
            torch.save({
                'epoch': epoch + 1,
                'loss': loss,
                'model_state_dict': model.state_dict(),
                'opt_state_dict': opt.state_dict(),
                'sched_state_dict': sched.state_dict()
            }, ckpt_path)

        if args.verbose:
            print("[epoch {}] loss={} dt={:.2f}sec".format(epoch + 1, loss, t1 - t0))


if __name__ == '__main__':
    main()