-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_pipeline.py
128 lines (110 loc) · 5.55 KB
/
train_pipeline.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import argparse
import os
import sys
import torch
from grb.dataset import Dataset
from grb.dataset import GRB_SUPPORTED_DATASETS
from grb.trainer import Trainer
from grb.utils import Logger
if __name__ == '__main__':
    # Command-line pipeline that trains one or more GML models on a GRB
    # dataset, repeating each training `--n_train` times, logging per-model
    # output to a file, and reporting validation/test scores of the saved
    # checkpoint.
    parser = argparse.ArgumentParser(description='Training GML models in pipeline.')
    # Dataset settings
    parser.add_argument("--dataset", type=str, default="grb-cora")
    parser.add_argument("--data_dir", type=str, default="../data/")
    parser.add_argument("--feat_norm", type=str, default="arctan")
    # Model settings
    parser.add_argument("--model", nargs='+', default=None)
    parser.add_argument("--save_dir", type=str, default="../saved_models_test/")
    # BUG FIX: args.config_dir was read below (path joining and
    # sys.path.append) but the argument was never declared, so the script
    # raised AttributeError on every run. Default of "../pipeline/" is an
    # assumption based on the sibling log_dir default — TODO confirm against
    # the project layout.
    parser.add_argument("--config_dir", type=str, default="../pipeline/",
                        help="Directory containing the config.py used to build models.")
    parser.add_argument("--log_dir", type=str, default="../pipeline/logs/")
    parser.add_argument("--save_name", type=str, default="model.pt")
    # Training settings
    parser.add_argument("--gpu", type=int, default=0, help="gpu")
    parser.add_argument("--n_train", type=int, default=1)
    parser.add_argument("--n_epoch", type=int, default=1000, help="Training epoch.")
    parser.add_argument("--lr", type=float, default=0.01, help="Learning rate.")
    parser.add_argument("--eval_every", type=int, default=1)
    parser.add_argument("--save_after", type=int, default=0)
    parser.add_argument("--train_mode", type=str, default="inductive")
    parser.add_argument("--eval_metric", type=str, default="acc")
    parser.add_argument("--early_stop", action="store_true")
    parser.add_argument("--early_stop_patience", type=int, default=500)
    parser.add_argument("--lr_scheduler", action="store_true")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--verbose", action="store_true")
    args = parser.parse_args()

    # Negative --gpu selects CPU; otherwise use the given CUDA device index.
    if args.gpu >= 0:
        device = "cuda:{}".format(args.gpu)
    else:
        device = "cpu"

    # Append the dataset name to each directory unless the caller already
    # included it in the path.
    if args.dataset not in args.data_dir:
        args.data_dir = os.path.join(args.data_dir, args.dataset)
    if args.dataset not in args.save_dir:
        args.save_dir = os.path.join(args.save_dir, args.dataset)
    # exist_ok avoids the check-then-create race of the original
    # os.path.exists + makedirs pair.
    os.makedirs(args.save_dir, exist_ok=True)
    if args.dataset not in args.config_dir:
        args.config_dir = os.path.join(args.config_dir, args.dataset)
    if args.dataset not in args.log_dir:
        args.log_dir = os.path.join(args.log_dir, args.dataset)
    print(args)

    # config.py lives in the (dataset-specific) config directory; make it
    # importable, then import it. It must provide build_model,
    # build_optimizer, build_loss, build_metric and model_list_basic.
    sys.path.append(args.config_dir)
    import config

    if args.dataset in GRB_SUPPORTED_DATASETS:
        dataset = Dataset(name=args.dataset,
                          data_dir=args.data_dir,
                          mode='full',
                          feat_norm=args.feat_norm,
                          verbose=True)
    else:
        print("{} dataset not supported.".format(args.dataset))
        # sys.exit instead of the site-provided exit() builtin, which is not
        # guaranteed to exist in all launch modes (e.g. python -S, frozen apps).
        sys.exit(1)

    # Explicit --model list overrides the default list from config.py.
    if args.model is not None:
        model_list = args.model
    else:
        model_list = config.model_list_basic

    terminal_out = sys.stdout
    for model_name in model_list:
        # Tee all training output for this model to a per-model log file
        # while still streaming to the terminal.
        logger = Logger(file_dir=args.log_dir,
                        file_name="train_{}.out".format(model_name),
                        stream=terminal_out)
        sys.stdout = logger
        print("*" * 80)
        print("Training {} model...........".format(model_name))
        for i in range(args.n_train):
            print("{} time training..........".format(i))
            # splitext (instead of split('.')[0]) keeps names like
            # "my.model.pt" intact.
            save_name = os.path.splitext(args.save_name)[0] + "_{}.pt".format(i)
            model, train_params = config.build_model(model_name=model_name,
                                                     num_features=dataset.num_features,
                                                     num_classes=dataset.num_classes)
            # Per-model train_params override the CLI defaults where present.
            optimizer = config.build_optimizer(model=model,
                                               lr=train_params.get("lr", args.lr))
            loss = config.build_loss()
            eval_metric = config.build_metric()
            trainer = Trainer(dataset=dataset,
                              optimizer=optimizer,
                              loss=loss,
                              lr_scheduler=args.lr_scheduler,
                              eval_metric=eval_metric,
                              early_stop=train_params.get("early_stop", args.early_stop),
                              early_stop_patience=train_params.get("early_stop_patience",
                                                                   args.early_stop_patience),
                              device=device)
            trainer.train(model=model,
                          n_epoch=train_params.get("n_epoch", args.n_epoch),
                          save_dir=os.path.join(args.save_dir, model_name),
                          save_name=save_name,
                          eval_every=args.eval_every,
                          save_after=args.save_after,
                          train_mode=train_params.get("train_mode", args.train_mode),
                          verbose=args.verbose,
                          if_save=True)
            # Reload the best saved checkpoint and score it on val/test splits.
            model = torch.load(os.path.join(args.save_dir, model_name, save_name),
                               map_location=device)
            val_score = trainer.evaluate(model, dataset.val_mask)
            test_score = trainer.evaluate(model, dataset.test_mask)
            print("*" * 80)
            print("Val ACC of {}: {:.4f}".format(model_name, val_score))
            print("Test ACC of {}: {:.4f}".format(model_name, test_score))
            # Free the model/trainer before the next run to keep peak memory down.
            del model, trainer
    print("Training completed.")
    # Restore the real stdout; the original left it pointing at the last Logger.
    sys.stdout = terminal_out