-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_game_pipeline.py
151 lines (133 loc) · 6.9 KB
/
train_game_pipeline.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import argparse
import os
import sys
import torch
from grb.dataset import Dataset
from grb.dataset import GRB_SUPPORTED_DATASETS
from grb.trainer import Trainer
from grb.utils import Logger
from model.gnn import GCN_bn_moe
from grb.utils.normalize import GCNAdjNorm
def str2bool(value):
    """Parse a command-line boolean string.

    BUG FIX: argparse with ``type=bool`` treats ANY non-empty string —
    including "False" — as True, so ``--layer_norm False`` silently kept
    layer norm enabled. This converter understands the usual spellings.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognized boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.lower()
    if lowered in ("yes", "true", "t", "y", "1"):
        return True
    if lowered in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected, got {!r}.".format(value))


def build_arg_parser():
    """Build the CLI parser for the training pipeline."""
    parser = argparse.ArgumentParser(description='Training GML models in pipeline.')
    # Dataset settings
    parser.add_argument("--dataset", type=str, default="grb-cora")
    parser.add_argument("--data_dir", type=str, default="../data/")
    parser.add_argument("--feat_norm", type=str, default="arctan")
    # Model settings
    parser.add_argument("--model", nargs='+', default=None)
    parser.add_argument("--save_dir", type=str, default="../normal_moe/")
    parser.add_argument("--log_dir", type=str, default="../pipeline/logs/")
    # BUG FIX: args.config_dir was read below (path resolution and
    # sys.path.append) but the argument was never declared, which raised
    # AttributeError at runtime. Default mirrors the other pipeline dirs —
    # TODO(review): confirm the intended default path.
    parser.add_argument("--config_dir", type=str, default="../pipeline/config/")
    parser.add_argument("--save_name", type=str, default="gat_model.pt")
    # Training settings
    parser.add_argument("--gpu", type=int, default=0, help="gpu")
    parser.add_argument("--n_train", type=int, default=1)
    parser.add_argument("--n_epoch", type=int, default=1000, help="Training epoch.")
    parser.add_argument("--lr", type=float, default=0.01, help="Learning rate.")
    parser.add_argument("--eval_every", type=int, default=1)
    parser.add_argument("--save_after", type=int, default=0)
    parser.add_argument("--train_mode", type=str, default="inductive")
    parser.add_argument("--eval_metric", type=str, default="acc")
    parser.add_argument("--early_stop", action="store_true")
    parser.add_argument("--early_stop_patience", type=int, default=500)
    parser.add_argument("--lr_scheduler", action="store_true")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--verbose", action="store_true")
    # MoE model settings
    parser.add_argument('--hidden_features', type=int, default=96, choices=[64, 96, 128])
    parser.add_argument('--n_layers', type=int, default=3)
    parser.add_argument('--n_heads', type=int, default=4)
    parser.add_argument('--feat_dropout', type=float, default=0.5)
    parser.add_argument('--attn_dropout', type=float, default=0.5)
    parser.add_argument('--dropout', type=float, default=0.5)
    parser.add_argument('--adj_norm_func', type=str, default='GCNAdjNorm')
    parser.add_argument('--layer_norm', type=str2bool, default=True, choices=[True, False])
    parser.add_argument('--residual', type=str2bool, default=False)
    parser.add_argument('--num_experts', type=int, default=3, help='number of experts in GAT_moe')
    parser.add_argument('--noisy_gating', type=str2bool, default=True, help='whether to use noisy gating')
    parser.add_argument('--k', type=int, default=1, help='number of experts to use')
    return parser


def resolve_paths(args):
    """Append the dataset name to each directory argument (at most once)
    and make sure the save directory exists. Mutates and returns *args*."""
    if args.dataset not in args.data_dir:
        args.data_dir = os.path.join(args.data_dir, args.dataset)
    if args.dataset not in args.save_dir:
        args.save_dir = os.path.join(args.save_dir, args.dataset)
    os.makedirs(args.save_dir, exist_ok=True)
    if args.dataset not in args.config_dir:
        args.config_dir = os.path.join(args.config_dir, args.dataset)
    if args.dataset not in args.log_dir:
        args.log_dir = os.path.join(args.log_dir, args.dataset)
    return args


def main():
    """Train each configured model ``--n_train`` times, reloading the saved
    checkpoint after each run to report validation/test scores."""
    args = resolve_paths(build_arg_parser().parse_args())
    device = "cuda:{}".format(args.gpu) if args.gpu >= 0 else "cpu"
    print(args)

    # The project-local ``config`` module lives under config_dir; it supplies
    # model/optimizer/loss/metric factories and the default model list.
    sys.path.append(args.config_dir)
    import config

    # Guard clause: fail fast on unsupported datasets.
    if args.dataset not in GRB_SUPPORTED_DATASETS:
        print("{} dataset not supported.".format(args.dataset))
        sys.exit(1)
    dataset = Dataset(name=args.dataset,
                      data_dir=args.data_dir,
                      mode='full',
                      feat_norm=args.feat_norm,
                      verbose=True)

    model_list = args.model if args.model is not None else config.model_list_basic
    terminal_out = sys.stdout
    for model_name in model_list:
        logger = Logger(file_dir=args.log_dir,
                        file_name="train_{}.out".format(model_name),
                        stream=terminal_out)
        # Mirror all prints for this model into its per-model log file.
        sys.stdout = logger
        print("*" * 80)
        print("Training {} model...........".format(model_name))
        for i in range(args.n_train):
            print("{} time training..........".format(i))
            # BUG FIX: splitext keeps everything before the extension; the
            # old split('.')[0] truncated save names containing dots.
            save_name = "{}_{}.pt".format(os.path.splitext(args.save_name)[0], i)
            model, train_params = config.build_model(model_name=model_name,
                                                     num_features=dataset.num_features,
                                                     num_classes=dataset.num_classes)
            # Per-model train_params override the CLI defaults.
            optimizer = config.build_optimizer(model=model,
                                               lr=train_params.get("lr", args.lr))
            loss = config.build_loss()
            eval_metric = config.build_metric()
            trainer = Trainer(dataset=dataset,
                              optimizer=optimizer,
                              loss=loss,
                              lr_scheduler=args.lr_scheduler,
                              eval_metric=eval_metric,
                              early_stop=train_params.get("early_stop", args.early_stop),
                              early_stop_patience=train_params.get("early_stop_patience",
                                                                   args.early_stop_patience),
                              device=device)
            trainer.train(model=model,
                          n_epoch=train_params.get("n_epoch", args.n_epoch),
                          save_dir=os.path.join(args.save_dir, model_name),
                          save_name=save_name,
                          eval_every=args.eval_every,
                          save_after=args.save_after,
                          train_mode=train_params.get("train_mode", args.train_mode),
                          verbose=args.verbose,
                          if_save=True)
            # Reload the checkpoint the trainer saved before scoring it.
            model = torch.load(os.path.join(args.save_dir, model_name, save_name),
                               map_location=device)
            val_score = trainer.evaluate(model, dataset.val_mask)
            test_score = trainer.evaluate(model, dataset.test_mask)
            print("*" * 80)
            print("Val ACC of {}: {:.4f}".format(model_name, val_score))
            print("Test ACC of {}: {:.4f}".format(model_name, test_score))
            del model, trainer
    # BUG FIX: restore stdout so output after the loop is not silently
    # written to the last model's log file.
    sys.stdout = terminal_out
    print("Training completed.")


if __name__ == '__main__':
    main()