train.py
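"""Training entry point (forked from XiaoxinHe/G-Retriever).

Fine-tunes a graph + LLM model on a chosen dataset, early-stops on validation
loss, reloads the best checkpoint, and writes test predictions to disk for
metric computation. Hyperparameters come from src.config.parse_args_llama.
"""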
import os
import wandb
import gc
from tqdm import tqdm
import torch
import json
import pandas as pd
from torch.utils.data import DataLoader
from torch.nn.utils import clip_grad_norm_

from src.model import load_model, llama_model_path
from src.dataset import load_dataset
from src.utils.evaluate import eval_funcs
from src.config import parse_args_llama
from src.utils.ckpt import _save_checkpoint, _reload_best_model
from src.utils.collate import collate_fn
from src.utils.seed import seed_everything
from src.utils.lr_schedule import adjust_learning_rate

def main(args):
    # Step 1: Set up wandb
    seed = args.seed
    wandb.init(project=f"{args.project}",
               name=f"{args.dataset}_{args.model_name}_seed{seed}",
               config=args)

    seed_everything(seed=args.seed)
    print(args)

    dataset = load_dataset[args.dataset]()
    idx_split = dataset.get_idx_split()

    # Step 2: Build Node Classification Dataset
    train_dataset = [dataset[i] for i in idx_split['train']]
    val_dataset = [dataset[i] for i in idx_split['val']]
    test_dataset = [dataset[i] for i in idx_split['test']]

    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, drop_last=True, pin_memory=True, shuffle=True, collate_fn=collate_fn)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, drop_last=False, pin_memory=True, shuffle=False, collate_fn=collate_fn)
    test_loader = DataLoader(test_dataset, batch_size=args.eval_batch_size, drop_last=False, pin_memory=True, shuffle=False, collate_fn=collate_fn)
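    # NOTE: drop_last=True keeps every training step at a full batch; the val and
    # test loaders keep all samples, and the test loader uses args.eval_batch_size.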

    # Step 3: Build Model
    args.llm_model_path = llama_model_path[args.llm_model_name]
    model = load_model[args.model_name](graph_type=dataset.graph_type, args=args, init_prompt=dataset.prompt)
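    # NOTE: load_model and llama_model_path appear to be lookup tables in src.model
    # (model name -> constructor, LLM name -> checkpoint path); the exact entries
    # depend on that module.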

    # Step 4: Set Optimizer
    params = [p for _, p in model.named_parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(
        [{'params': params, 'lr': args.lr, 'weight_decay': args.wd}, ],
        betas=(0.9, 0.95)
    )
    trainable_params, all_param = model.print_trainable_params()
    print(f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}")
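
    # NOTE: optimizer.step() runs on every batch below, so args.grad_steps does not
    # delay updates (no true gradient accumulation); it only sets how often
    # adjust_learning_rate re-scales the LR (a schedule defined in
    # src.utils.lr_schedule) and how often the running loss is logged to wandb.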

    # Step 5. Training
    num_training_steps = args.num_epochs * len(train_loader)
    progress_bar = tqdm(range(num_training_steps))
    best_val_loss = float('inf')

    for epoch in range(args.num_epochs):

        model.train()
        epoch_loss, accum_loss = 0., 0.

        for step, batch in enumerate(train_loader):

            optimizer.zero_grad()
            loss = model(batch)
            loss.backward()

            clip_grad_norm_(optimizer.param_groups[0]['params'], 0.1)

            if (step + 1) % args.grad_steps == 0:
                adjust_learning_rate(optimizer.param_groups[0], args.lr, step / len(train_loader) + epoch, args)

            optimizer.step()
            epoch_loss, accum_loss = epoch_loss + loss.item(), accum_loss + loss.item()

            if (step + 1) % args.grad_steps == 0:
                lr = optimizer.param_groups[0]["lr"]
                wandb.log({'Lr': lr})
                wandb.log({'Accum Loss': accum_loss / args.grad_steps})
                accum_loss = 0.

            progress_bar.update(1)
print(f"Epoch: {epoch}|{args.num_epochs}: Train Loss (Epoch Mean): {epoch_loss / len(train_loader)}")
wandb.log({'Train Loss (Epoch Mean)': epoch_loss / len(train_loader)})
val_loss = 0.
eval_output = []
model.eval()
with torch.no_grad():
for step, batch in enumerate(val_loader):
loss = model(batch)
val_loss += loss.item()
val_loss = val_loss/len(val_loader)
print(f"Epoch: {epoch}|{args.num_epochs}: Val Loss: {val_loss}")
wandb.log({'Val Loss': val_loss})
if val_loss < best_val_loss:
best_val_loss = val_loss
_save_checkpoint(model, optimizer, epoch, args, is_best=True)
best_epoch = epoch
print(f'Epoch {epoch} Val Loss {val_loss} Best Val Loss {best_val_loss} Best Epoch {best_epoch}')
if epoch - best_epoch >= args.patience:
print(f'Early stop at epoch {epoch}')
break
torch.cuda.empty_cache()
torch.cuda.reset_max_memory_allocated()
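    # NOTE: torch.cuda.reset_max_memory_allocated() (also used in __main__ below) is
    # deprecated in recent PyTorch releases in favor of torch.cuda.reset_peak_memory_stats().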

    # Step 6. Evaluating
    os.makedirs(f'{args.output_dir}/{args.dataset}', exist_ok=True)
    path = f'{args.output_dir}/{args.dataset}/model_name_{args.model_name}_llm_model_name_{args.llm_model_name}_llm_frozen_{args.llm_frozen}_max_txt_len_{args.max_txt_len}_max_new_tokens_{args.max_new_tokens}_gnn_model_name_{args.gnn_model_name}_patience_{args.patience}_num_epochs_{args.num_epochs}_seed{seed}.csv'
    print(f'path: {path}')
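    # NOTE: despite the .csv extension, the file below is written as JSON Lines
    # (one json.dumps per prediction), which eval_funcs[args.dataset] presumably parses.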

    model = _reload_best_model(model, args)
    model.eval()
    progress_bar_test = tqdm(range(len(test_loader)))
    with open(path, "w") as f:
        for step, batch in enumerate(test_loader):
            with torch.no_grad():
                output = model.inference(batch)
                df = pd.DataFrame(output)
                for _, row in df.iterrows():
                    f.write(json.dumps(dict(row)) + "\n")
            progress_bar_test.update(1)

    # Step 7. Post-processing & compute metrics
    acc = eval_funcs[args.dataset](path)
    print(f'Test Acc {acc}')
    wandb.log({'Test Acc': acc})


if __name__ == "__main__":

    args = parse_args_llama()
    main(args)

    torch.cuda.empty_cache()
    torch.cuda.reset_max_memory_allocated()
    gc.collect()