-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathtrain.py
160 lines (137 loc) · 6.21 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
from argparse import ArgumentParser
import torch
from datamodule.heart import HeartDecathlonDataModule
from datamodule.hippocampus import HippocampusDecathlonDataModule
from datamodule.iseg import ISeg2017DataModule
from datamodule.luna import LUNA16DataModule
from module.segcaps import SegCaps2D, SegCaps3D
from module.ucaps import UCaps3D
from module.mod_ucaps import ModifiedUCaps3D
from module.unet import UNetModule
from monai.utils import set_determinism
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
# Call example:
# python train.py --gpus 1 --model_name ucaps --num_workers 4 --max_epochs 20000 --check_val_every_n_epoch 100 --log_dir=../logs --root_dir=/home/ubuntu/
# Training entry point: parses CLI arguments, builds the datamodule and model
# selected by --dataset / --model_name, then trains with PyTorch Lightning,
# checkpointing on validation Dice.
if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--root_dir", type=str, default="/root")
    parser.add_argument("--cache_rate", type=float, default=None)
    parser.add_argument("--cache_dir", type=str, default=None)

    # Training options
    train_parser = parser.add_argument_group("Training config")
    train_parser.add_argument("--log_dir", type=str, default="/mnt/vinai/logs")
    train_parser.add_argument("--model_name", type=str, default="ucaps", help="ucaps / segcaps-2d / segcaps-3d / unet")
    train_parser.add_argument(
        "--dataset", type=str, default="iseg2017", help="iseg2017 / task02_heart / task04_hippocampus / luna16"
    )
    train_parser.add_argument("--train_patch_size", nargs="+", type=int, default=[32, 32, 32])
    train_parser.add_argument("--fold", type=int, default=0)
    train_parser.add_argument("--num_workers", type=int, default=4)
    train_parser.add_argument("--batch_size", type=int, default=1)
    train_parser.add_argument(
        "--num_samples", type=int, default=1, help="Effective batch size: batch_size x num_samples"
    )
    train_parser.add_argument("--balance_sampling", type=int, default=1)
    train_parser.add_argument("--use_class_weight", type=int, default=0)
    parser = Trainer.add_argparse_args(parser)

    # Single source of truth for the supported models: used both to attach
    # model-specific CLI arguments and to construct the chosen network below.
    # (Replaces two parallel if/elif chains that silently fell through on an
    # unknown name and crashed later with a NameError.)
    model_classes = {
        "ucaps": UCaps3D,
        "segcaps-2d": SegCaps2D,
        "segcaps-3d": SegCaps3D,
        "unet": UNetModule,
        "modified-ucaps": ModifiedUCaps3D,
    }

    # Pre-parse --model_name so the right model class can extend the parser
    # with its own arguments before the full parse.
    temp_args, _ = parser.parse_known_args()
    if temp_args.model_name not in model_classes:
        raise ValueError(
            f"Unknown --model_name {temp_args.model_name!r}; expected one of {sorted(model_classes)}"
        )
    parser, model_parser = model_classes[temp_args.model_name].add_model_specific_args(parser)

    args = parser.parse_args()
    dict_args = vars(args)

    print(f"{args.model_name} config:")
    for a in model_parser._group_actions:
        print("\t{}:\t{}".format(a.dest, dict_args[a.dest]))
    print("Training config:")
    for a in train_parser._group_actions:
        print("\t{}:\t{}".format(a.dest, dict_args[a.dest]))

    # Improve reproducibility
    set_determinism(seed=0)
    seed_everything(0, workers=True)

    # Set up datamodule
    if args.dataset == "iseg2017":
        # NOTE(review): assumes args.gpus is an int — Lightning also accepts
        # str/list forms of --gpus, which would break this arithmetic; confirm.
        data_module = ISeg2017DataModule(
            **dict_args,
            n_val_replication=args.gpus - 1,
        )
    else:
        datamodule_classes = {
            "task02_heart": HeartDecathlonDataModule,
            "task04_hippocampus": HippocampusDecathlonDataModule,
            "luna16": LUNA16DataModule,
        }
        if args.dataset not in datamodule_classes:
            # Fail fast: previously an unknown dataset fell through a bare
            # `else: pass` and crashed later with NameError on data_module.
            raise ValueError(
                f"Unknown --dataset {args.dataset!r}; expected one of "
                f"{['iseg2017'] + sorted(datamodule_classes)}"
            )
        data_module = datamodule_classes[args.dataset](**dict_args)

    # Optional per-class loss weighting computed by the datamodule.
    if args.use_class_weight:
        class_weight = torch.tensor(data_module.class_weight).float()
    else:
        class_weight = None

    # Initialise the LightningModule (model_name was validated above).
    net = model_classes[args.model_name](**dict_args, class_weight=class_weight)

    # Set up loggers and checkpoints. iseg2017 has no CV folds here, so its
    # log directory name omits the fold index.
    if args.dataset == "iseg2017":
        tb_logger = TensorBoardLogger(save_dir=args.log_dir, name=f"{args.model_name}_{args.dataset}", log_graph=True)
    else:
        tb_logger = TensorBoardLogger(
            save_dir=args.log_dir, name=f"{args.model_name}_{args.dataset}_{args.fold}", log_graph=True
        )
    # Keep the two best checkpoints by validation Dice, plus the last epoch.
    checkpoint_callback = ModelCheckpoint(
        filename="{epoch}-{val_dice:.4f}", monitor="val_dice", save_top_k=2, mode="max", save_last=True
    )
    earlystopping_callback = EarlyStopping(monitor="val_dice", patience=20, mode="max")

    trainer = Trainer.from_argparse_args(
        args,
        benchmark=True,
        logger=tb_logger,
        callbacks=[checkpoint_callback, earlystopping_callback],
        num_sanity_val_steps=1,
        terminate_on_nan=True,
    )

    trainer.fit(net, datamodule=data_module)

    print("Best model path ", checkpoint_callback.best_model_path)
    # best_model_score is None when no validation epoch ever completed;
    # guard against AttributeError on .item().
    if checkpoint_callback.best_model_score is not None:
        print("Best val mean dice ", checkpoint_callback.best_model_score.item())