link.py
"""
Differences compared to MichSchli/RelationPrediction
* Report raw metrics instead of filtered metrics.
* By default, we use uniform edge sampling instead of the neighbor-based edge
  sampling used in the author's code. In practice, we find it achieves similar MRR.
"""
import argparse

import torch as th
import torch.nn as nn
import torch.nn.functional as F
from dgl.data.knowledge_graph import FB15k237Dataset
from dgl.dataloading import GraphDataLoader

from link_utils import preprocess, SubgraphIterator, calc_mrr
from model import RGCN

class LinkPredict(nn.Module):
    def __init__(self, in_dim, num_rels, h_dim=500, num_bases=100, dropout=0.2, reg_param=0.01):
        super(LinkPredict, self).__init__()
        self.rgcn = RGCN(in_dim, h_dim, h_dim, num_rels * 2, regularizer="bdd",
                         num_bases=num_bases, dropout=dropout, self_loop=True)
        self.dropout = nn.Dropout(dropout)
        self.reg_param = reg_param
        # one learnable DistMult embedding vector per relation
        self.w_relation = nn.Parameter(th.Tensor(num_rels, h_dim))
        nn.init.xavier_uniform_(self.w_relation,
                                gain=nn.init.calculate_gain('relu'))

    def calc_score(self, embedding, triplets):
        # DistMult: score(s, r, o) = sum_k s_k * r_k * o_k
        s = embedding[triplets[:, 0]]
        r = self.w_relation[triplets[:, 1]]
        o = embedding[triplets[:, 2]]
        score = th.sum(s * r * o, dim=1)
        return score

    def forward(self, g, nids):
        return self.dropout(self.rgcn(g, nids=nids))

    def regularization_loss(self, embedding):
        return th.mean(embedding.pow(2)) + th.mean(self.w_relation.pow(2))

    def get_loss(self, embed, triplets, labels):
        # each row in the triplets is a 3-tuple of (source, relation, destination)
        score = self.calc_score(embed, triplets)
        predict_loss = F.binary_cross_entropy_with_logits(score, labels)
        reg_loss = self.regularization_loss(embed)
        return predict_loss + self.reg_param * reg_loss

def main(args):
    data = FB15k237Dataset(reverse=False)
    graph = data[0]
    num_nodes = graph.num_nodes()
    num_rels = data.num_rels

    train_g, test_g = preprocess(graph, num_rels)
    test_nids = th.arange(0, num_nodes)
    test_mask = graph.edata['test_mask']
    subg_iter = SubgraphIterator(train_g, num_rels, args.edge_sampler)
    dataloader = GraphDataLoader(subg_iter, batch_size=1, collate_fn=lambda x: x[0])

    # Prepare data for metric computation
    src, dst = graph.edges()
    triplets = th.stack([src, graph.edata['etype'], dst], dim=1)

    model = LinkPredict(num_nodes, num_rels)
    optimizer = th.optim.Adam(model.parameters(), lr=1e-2)

    if args.gpu >= 0 and th.cuda.is_available():
        device = th.device(args.gpu)
    else:
        device = th.device('cpu')
    model = model.to(device)

    best_mrr = 0
    model_state_file = 'model_state.pth'
    # each batch is one sampled training subgraph; the loop index doubles as the epoch counter
    for epoch, batch_data in enumerate(dataloader):
        model.train()
        g, train_nids, edges, labels = batch_data
        g = g.to(device)
        train_nids = train_nids.to(device)
        edges = edges.to(device)
        labels = labels.to(device)

        embed = model(g, train_nids)
        loss = model.get_loss(embed, edges, labels)
        optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # clip gradients
        optimizer.step()

        print("Epoch {:04d} | Loss {:.4f} | Best MRR {:.4f}".format(epoch, loss.item(), best_mrr))

        if (epoch + 1) % 500 == 0:
            # perform validation on CPU because full graph is too large
            model = model.cpu()
            model.eval()
            print("start eval")
            embed = model(test_g, test_nids)
            mrr = calc_mrr(embed, model.w_relation, test_mask, triplets,
                           batch_size=500, eval_p=args.eval_protocol)
            # save best model
            if best_mrr < mrr:
                best_mrr = mrr
                th.save({'state_dict': model.state_dict(), 'epoch': epoch}, model_state_file)

            model = model.to(device)
print("Start testing:")
# use best model checkpoint
checkpoint = th.load(model_state_file)
model = model.cpu() # test on CPU
model.eval()
model.load_state_dict(checkpoint['state_dict'])
print("Using best epoch: {}".format(checkpoint['epoch']))
embed = model(test_g, test_nids)
calc_mrr(embed, model.w_relation, test_mask, triplets,
batch_size=500, eval_p=args.eval_protocol)

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='RGCN for link prediction')
    parser.add_argument("--gpu", type=int, default=-1,
                        help="GPU device to use; -1 for CPU")
    parser.add_argument("--eval-protocol", type=str, default='filtered',
                        choices=['filtered', 'raw'],
                        help="Whether to use 'filtered' or 'raw' MRR for evaluation")
    parser.add_argument("--edge-sampler", type=str, default='uniform',
                        choices=['uniform', 'neighbor'],
                        help="Type of edge sampler: 'uniform' or 'neighbor'. "
                             "The original implementation uses the neighbor sampler.")

    args = parser.parse_args()
    print(args)
    main(args)