model.py
import torch
import torch.nn as nn
from torch_scatter import scatter


class GNNLayer(torch.nn.Module):
    def __init__(self, in_dim, out_dim, attn_dim, n_rel, act=lambda x: x):
        super(GNNLayer, self).__init__()
        self.n_rel = n_rel
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.attn_dim = attn_dim
        self.act = act
        # relation embeddings over 2*n_rel+1 relation ids
        self.rela_embed = nn.Embedding(2*n_rel+1, in_dim)
        # attention parameters over (source entity, edge relation, query relation)
        self.Ws_attn = nn.Linear(in_dim, attn_dim, bias=False)
        self.Wr_attn = nn.Linear(in_dim, attn_dim, bias=False)
        self.Wqr_attn = nn.Linear(in_dim, attn_dim)
        self.w_alpha = nn.Linear(attn_dim, 1)
        # output projection for the aggregated messages
        self.W_h = nn.Linear(in_dim, out_dim, bias=False)

    def forward(self, q_sub, q_rel, r_idx, hidden, edges, n_node, shortcut=False):
        # edges: [h, r, t] index triples, one row per edge in the sampled subgraph
        sub = edges[:, 0]
        rel = edges[:, 1]
        obj = edges[:, 2]

        hs = hidden[sub]                      # hidden state of each edge's head entity
        hr = self.rela_embed(rel)             # relation embedding of each edge
        h_qr = self.rela_embed(q_rel)[r_idx]  # query-relation embedding of each edge, looked up via its batch index

        # attention-weighted message for each edge
        message = hs * hr
        alpha = torch.sigmoid(self.w_alpha(nn.ReLU()(self.Ws_attn(hs) + self.Wr_attn(hr) + self.Wqr_attn(h_qr))))
        message = alpha * message
        # aggregate messages onto their tail entities
        message_agg = scatter(message, index=obj, dim=0, dim_size=n_node, reduce='sum')

        # get new hidden representations
        hidden_new = self.act(self.W_h(message_agg))
        if shortcut: hidden_new = hidden_new + hidden  # residual connection
        return hidden_new
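
# Per-edge computation in GNNLayer.forward, in summary:
#   hidden: [n_node, in_dim]; edges: [n_edge, 3] with columns (head, rel, tail).
#   For each edge e = (s, r, o) under query relation q:
#       alpha_e = sigmoid(w_alpha(ReLU(Ws_attn(h_s) + Wr_attn(h_r) + Wqr_attn(h_q))))
#       m_e     = alpha_e * (h_s * h_r)
#   The m_e are scatter-summed onto their tail entities o ([n_node, in_dim]), projected
#   by W_h, and passed through `act`, giving new hidden states of shape [n_node, out_dim].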


class GNN_auto(torch.nn.Module):
    def __init__(self, params, loader):
        super(GNN_auto, self).__init__()
        self.params = params
        self.n_layer = params.n_layer
        self.hidden_dim = params.hidden_dim
        self.attn_dim = params.attn_dim
        self.n_rel = params.n_rel
        self.n_ent = params.n_ent
        self.loader = loader

        acts = {'relu': nn.ReLU(), 'tanh': torch.tanh, 'idd': lambda x: x}
        act = acts[params.act]

        # stacked GNN propagation layers
        self.gnn_layers = []
        for i in range(self.n_layer):
            self.gnn_layers.append(GNNLayer(self.hidden_dim, self.hidden_dim, self.attn_dim, self.n_rel, act=act))
        self.gnn_layers = nn.ModuleList(self.gnn_layers)

        self.dropout = nn.Dropout(params.dropout)
        # GRU gate that updates node states across propagation steps
        self.gate = nn.GRU(self.hidden_dim, self.hidden_dim)

        # query-node initialization and readout options
        if self.params.initializer == 'relation':
            self.query_rela_embed = nn.Embedding(2*self.n_rel+1, self.hidden_dim)
        if self.params.readout == 'linear':
            if self.params.concatHidden:
                self.W_final = nn.Linear(self.hidden_dim * (self.n_layer+1), 1, bias=False)
            else:
                self.W_final = nn.Linear(self.hidden_dim, 1, bias=False)

    def forward(self, q_sub, q_rel, subgraph_data, mode='train'):
        '''Forward pass with extra propagation over the sampled subgraph.'''
        n = len(q_sub)
        batch_idxs, abs_idxs, query_sub_idxs, edge_batch_idxs, batch_sampled_edges = subgraph_data
        n_node = len(batch_idxs)

        h0 = torch.zeros((1, n_node, self.hidden_dim)).cuda()
        hidden = torch.zeros(n_node, self.hidden_dim).cuda()

        # initialize the hidden states of the query head entities
        if self.params.initializer == 'binary':
            hidden[query_sub_idxs, :] = 1
        elif self.params.initializer == 'relation':
            hidden[query_sub_idxs, :] = self.query_rela_embed(q_rel)

        # optionally store the hidden states of every layer for the readout
        if self.params.concatHidden: hidden_list = [hidden]

        # propagation
        for i in range(self.n_layer):
            hidden = self.gnn_layers[i](q_sub, q_rel, edge_batch_idxs, hidden, batch_sampled_edges, n_node,
                                        shortcut=self.params.shortcut)
            # act_signal is a binary (0/1) tensor: 1 for entities that have not been
            # activated yet (all-zero hidden state), 0 for activated entities
            act_signal = (hidden.sum(-1) == 0).detach().int()
            hidden = self.dropout(hidden)
            hidden, h0 = self.gate(hidden.unsqueeze(0), h0)
            hidden = hidden.squeeze(0)
            # zero out non-activated entities after the GRU update
            hidden = hidden * (1-act_signal).unsqueeze(-1)
            h0 = h0 * (1-act_signal).unsqueeze(-1).unsqueeze(0)
            if self.params.concatHidden: hidden_list.append(hidden)

        # readout
        if self.params.readout == 'linear':
            if self.params.concatHidden: hidden = torch.cat(hidden_list, dim=-1)
            scores = self.W_final(hidden).squeeze(-1)
        elif self.params.readout == 'multiply':
            if self.params.concatHidden: hidden = torch.cat(hidden_list, dim=-1)
            scores = torch.sum(hidden * hidden[query_sub_idxs][batch_idxs], dim=-1)

        # re-index the per-node scores into a [batch, n_ent] score matrix
        scores_all = torch.zeros((n, self.loader.n_ent)).cuda()
        scores_all[batch_idxs, abs_idxs] = scores
        return scores_all
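

# Minimal usage sketch: a toy run of GNN_auto on a hand-built one-query subgraph.
# The `params`/`loader` objects below are stand-ins for the ones built by the
# repository's training script, so the field names and values here are assumptions;
# a CUDA device and torch_scatter are required (the model calls .cuda() directly).
if __name__ == '__main__':
    from types import SimpleNamespace

    n_ent, n_rel = 100, 10
    params = SimpleNamespace(n_layer=3, hidden_dim=32, attn_dim=5, n_rel=n_rel, n_ent=n_ent,
                             act='relu', dropout=0.1, initializer='binary',
                             readout='linear', concatHidden=False, shortcut=False)
    loader = SimpleNamespace(n_ent=n_ent)
    model = GNN_auto(params, loader).cuda()

    # one query (head entity 0, relation 2) over a 5-node subgraph with 4 edges;
    # subgraph_data follows the (batch_idxs, abs_idxs, query_sub_idxs,
    # edge_batch_idxs, batch_sampled_edges) layout unpacked in forward()
    q_sub = torch.tensor([0]).cuda()
    q_rel = torch.tensor([2]).cuda()
    batch_idxs = torch.zeros(5, dtype=torch.long).cuda()       # every node belongs to query 0
    abs_idxs = torch.arange(5).cuda()                           # global entity id of each node
    query_sub_idxs = torch.tensor([0]).cuda()                   # node index of the query head
    edge_batch_idxs = torch.zeros(4, dtype=torch.long).cuda()   # every edge belongs to query 0
    batch_sampled_edges = torch.tensor([[0, 1, 1], [1, 2, 2], [2, 3, 3], [3, 4, 4]]).cuda()

    scores_all = model(q_sub, q_rel,
                       (batch_idxs, abs_idxs, query_sub_idxs, edge_batch_idxs, batch_sampled_edges))
    print(scores_all.shape)  # expected: torch.Size([1, 100])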