-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathmodel.py
57 lines (43 loc) · 2.29 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn.models import InnerProductDecoder, VGAE
from torch_geometric.nn.conv import GCNConv
from torch_geometric.utils import negative_sampling, remove_self_loops, add_self_loops
class GCNEncoder(nn.Module):
    """GCN encoder that emits the two parameter tensors of a latent Gaussian.

    One shared ``GCNConv`` layer (followed by ReLU) feeds two parallel
    ``GCNConv`` heads: one producing the mean and one producing the spread
    term of the latent distribution.

    Args:
        in_channels: input feature size expected by the shared layer.
        hidden_channels: output width of the shared layer / input width of
            both heads.
        out_channels: output width of each head (the latent size).
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        # Attribute names and creation order are kept unchanged: they define
        # the state_dict keys and the order in which weights are initialised.
        self.gcn_shared = GCNConv(in_channels, hidden_channels)
        self.gcn_mu = GCNConv(hidden_channels, out_channels)
        self.gcn_logvar = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        """Run the shared layer, then both heads, over the same graph.

        Args:
            x: node feature tensor fed to the shared layer.
            edge_index: connectivity passed unchanged to all three layers.

        Returns:
            Tuple ``(mu, logvar)`` — the outputs of ``gcn_mu`` and
            ``gcn_logvar`` respectively.
        """
        shared = self.gcn_shared(x, edge_index).relu()
        # NOTE(review): despite the "logvar" name, PyG's VGAE appears to
        # interpret this second output as a log standard deviation — confirm
        # against the installed torch_geometric version.
        return self.gcn_mu(shared, edge_index), self.gcn_logvar(shared, edge_index)
class DeepVGAE(VGAE):
    """VGAE made of a ``GCNEncoder`` and an ``InnerProductDecoder``.

    Adds a link-prediction training objective (``loss``) and a no-grad
    evaluation helper (``single_test``) on top of the inherited VGAE API.

    Args:
        args: configuration object; must expose ``enc_in_channels``,
            ``enc_hidden_channels`` and ``enc_out_channels``.
    """

    def __init__(self, args):
        # Encoder is built before the decoder, as in the original, so that
        # parameter initialisation consumes the RNG in the same order.
        encoder = GCNEncoder(
            args.enc_in_channels,
            args.enc_hidden_channels,
            args.enc_out_channels,
        )
        super().__init__(encoder=encoder, decoder=InnerProductDecoder())

    def forward(self, x, edge_index):
        """Encode the graph and decode a score for every node pair.

        Returns:
            The result of ``self.decoder.forward_all`` on the latent
            embedding (a dense all-pairs output).
        """
        return self.decoder.forward_all(self.encode(x, edge_index))

    def loss(self, x, pos_edge_index, all_edge_index):
        """Return reconstruction loss (positive + sampled negative) plus scaled KL.

        Args:
            x: node features.
            pos_edge_index: edges used both for message passing and as
                positive reconstruction targets.
            all_edge_index: every known edge; negatives are sampled outside it.

        Returns:
            Scalar tensor ``pos_loss + neg_loss + kl_loss``, where the KL term
            is divided by the number of rows of ``x``.
        """
        eps = 1e-15  # keeps log() finite when a probability hits exactly 0

        z = self.encode(x, pos_edge_index)

        pos_prob = self.decoder(z, pos_edge_index, sigmoid=True)
        pos_loss = -torch.log(pos_prob + eps).mean()

        # Put a self-loop on every node into the exclusion set so that no
        # (i, i) pair is drawn as a negative; existing loops are stripped
        # first so they are not duplicated.
        excluded, _ = remove_self_loops(all_edge_index)
        excluded, _ = add_self_loops(excluded)
        neg_edge_index = negative_sampling(excluded, z.size(0), pos_edge_index.size(1))

        neg_prob = self.decoder(z, neg_edge_index, sigmoid=True)
        neg_loss = -torch.log(1 - neg_prob + eps).mean()

        kl_loss = 1 / x.size(0) * self.kl_loss()
        return pos_loss + neg_loss + kl_loss

    def single_test(self, x, train_pos_edge_index, test_pos_edge_index, test_neg_edge_index):
        """Embed with the training edges and score held-out edges, without gradients.

        Returns:
            The two values produced by the inherited ``self.test`` call
            (named ROC-AUC and average precision in the original code).
        """
        # NOTE(review): if the model is still in training mode, encode() may
        # sample a noisy embedding instead of using the mean; callers likely
        # need model.eval() beforehand — confirm.
        with torch.no_grad():
            z = self.encode(x, train_pos_edge_index)
            auc, ap = self.test(z, test_pos_edge_index, test_neg_edge_index)
        return auc, ap