-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmain.py
executable file
·89 lines (71 loc) · 4.87 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import os
import os.path as osp
import shutil
import sys
import hydra
import logging
from pytorch_lightning import callbacks
import torch
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
from datautil.waymo_dataset import WaymoDataset, waymo_collate_fn, waymo_worker_fn
from model.pl_module import SceneTransformer
# Use the file-system sharing strategy for tensors passed between DataLoader
# worker processes; avoids "too many open files" errors from the default
# file-descriptor strategy when many workers are spawned.
torch.multiprocessing.set_sharing_strategy('file_system')
@hydra.main(config_path='./conf', config_name='config.yaml')
def main(cfg):
    """Train, validate, or test a SceneTransformer model on Waymo data.

    Args:
        cfg: Hydra-resolved config. ``cfg.mode`` selects the branch
            ('train' | 'validate' | 'test'); dataset/model settings live
            under ``cfg.dataset`` and ``cfg.model``.

    Raises:
        KeyError: if ``cfg.mode`` is not one of the supported modes.
    """
    pl.seed_everything(cfg.seed)

    # Mirror Lightning's console logging into a file inside the Hydra run dir.
    logger = logging.getLogger("pytorch_lightning")
    logger.addHandler(logging.FileHandler("terminal.log"))

    # Hydra changes the CWD to its per-run output dir; dataset paths in the
    # config are relative to the original project root.
    pwd = hydra.utils.get_original_cwd()
    print('Current Path: ', pwd)

    model = SceneTransformer(cfg)

    # Keep the 3 best checkpoints by validation loss and by minFDE (plus last).
    checkpoint_callback = ModelCheckpoint(mode='min', monitor='val/loss', auto_insert_metric_name=True, verbose=True, save_last=True, save_top_k=3)
    checkpoint_callback_2 = ModelCheckpoint(mode='min', monitor='val/minfde', auto_insert_metric_name=True, verbose=True, save_last=True, save_top_k=3)
    # NOTE(review): the original constructed an EarlyStopping callback here but
    # never added it to the Trainer's callback list, so it had no effect; the
    # dead local has been removed. Add callbacks.EarlyStopping(...) to the
    # 'callbacks' list below if early stopping is actually desired.
    lr_monitor = LearningRateMonitor(logging_interval='step')

    trainer_args = {'max_epochs': cfg.max_epochs,
                    'gpus': cfg.gpu_ids,
                    'accelerator': 'gpu',
                    'val_check_interval': cfg.dataset.train.val_check_interval, 'limit_train_batches': cfg.dataset.train.limit_train_batches,
                    'limit_val_batches': cfg.dataset.train.limit_val_batches,
                    'log_every_n_steps': cfg.max_epochs, 'auto_lr_find': True,
                    'callbacks': [checkpoint_callback, checkpoint_callback_2, lr_monitor]}
    if cfg.resume:
        # NOTE(review): cfg.resume is passed both here (deprecated Trainer arg)
        # and as ckpt_path to trainer.fit() below; one of the two is redundant
        # depending on the pytorch_lightning version — confirm and drop one.
        trainer_args['resume_from_checkpoint'] = cfg.resume
        trainer_args['num_sanity_val_steps'] = 0
    trainer = pl.Trainer(**trainer_args)

    def _collate(batch):
        # Shared collate for train/valid: same cropping/masking settings.
        return waymo_collate_fn(batch, halfwidth=cfg.dataset.halfwidth, only_veh=cfg.dataset.only_veh, hidden=cfg.dataset.hidden, time_steps=cfg.model.time_steps, current_step=cfg.model.current_step)

    if cfg.mode == 'train':
        # Shuffle queue sized to the batch size; validation set is unshuffled.
        dataset_train = WaymoDataset(osp.join(pwd, cfg.dataset.train.tfrecords), osp.join(pwd, cfg.dataset.train.idxs), shuffle_queue_size=cfg.dataset.train.batchsize)
        dloader_train = DataLoader(dataset_train, batch_size=cfg.dataset.train.batchsize,
                                   collate_fn=_collate,
                                   num_workers=cfg.dataset.train.batchsize // 2)
        dataset_valid = WaymoDataset(osp.join(pwd, cfg.dataset.valid.tfrecords), osp.join(pwd, cfg.dataset.valid.idxs), shuffle_queue_size=None)
        dloader_valid = DataLoader(dataset_valid, batch_size=cfg.dataset.valid.batchsize,
                                   collate_fn=_collate,
                                   num_workers=cfg.dataset.valid.batchsize // 2)
        if cfg.resume:
            trainer.fit(model, dloader_train, dloader_valid, ckpt_path=cfg.resume)
        else:
            trainer.fit(model, dloader_train, dloader_valid)
    elif cfg.mode == 'validate':
        dataset_valid = WaymoDataset(osp.join(pwd, cfg.dataset.valid.tfrecords), osp.join(pwd, cfg.dataset.valid.idxs), shuffle_queue_size=None)
        dloader_valid = DataLoader(dataset_valid, batch_size=cfg.dataset.valid.batchsize,
                                   collate_fn=_collate,
                                   worker_init_fn=waymo_worker_fn,
                                   num_workers=cfg.dataset.valid.batchsize)
        # Validation restores weights from the configured test checkpoint.
        model = model.load_from_checkpoint(checkpoint_path=cfg.dataset.test.ckpt_path, cfg=cfg)
        trainer.validate(model, dataloaders=dloader_valid, verbose=True)
    elif cfg.mode == 'test':
        # Test-time collate omits halfwidth/only_veh (different preprocessing).
        dataset_test = WaymoDataset(osp.join(pwd, cfg.dataset.test.tfrecords), osp.join(pwd, cfg.dataset.test.idxs), shuffle_queue_size=None)
        dloader_test = DataLoader(dataset_test, batch_size=cfg.dataset.test.batchsize, collate_fn=lambda b: waymo_collate_fn(b, time_steps=cfg.model.time_steps, current_step=cfg.model.current_step, hidden=cfg.dataset.hidden), worker_init_fn=waymo_worker_fn, num_workers=cfg.dataset.test.batchsize)
        # Recreate a clean results directory inside the Hydra run dir.
        results_dir = os.path.join(os.getcwd(), 'results')
        if os.path.exists(results_dir):
            shutil.rmtree(results_dir)
        os.makedirs(results_dir)
        model = model.load_from_checkpoint(checkpoint_path=cfg.dataset.test.ckpt_path, cfg=cfg)
        trainer.test(model, dataloaders=dloader_test, verbose=True)
    else:
        # Fail loudly with the offending value instead of a bare KeyError.
        raise KeyError(f'unknown cfg.mode: {cfg.mode!r} (expected train/validate/test)')
# Script entry point: propagate main()'s return value as the exit status.
if __name__ == "__main__":
    raise SystemExit(main())