# example_classification.py — graph classification on NCI1 with
# Just-Balance pooling (148 lines, 114 loc, 4.42 KB).
import torch
import torch.nn.functional as F
from torch.nn import Linear
from torch_geometric.data import Data
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GINConv, DenseGINConv
from torch_geometric.nn.models.mlp import MLP
from torch_geometric.transforms import BaseTransform
from torch_geometric.utils import to_dense_batch
from torch_geometric.utils import (to_dense_batch,
                                   get_laplacian,
                                   to_dense_adj,
                                   dense_to_sparse)

# Local imports
from just_balance import just_balance_pool
class NormalizeAdj(BaseTransform):
    r"""Replace a graph's adjacency with a normalized, weighted version.

    Applies the transformation

        A  -->  I - delta * L_sym

    where ``L_sym`` is the symmetrically normalized Laplacian returned by
    ``get_laplacian(..., normalization='sym')``. The dense result is converted
    back to sparse form and written into ``data.edge_index`` and
    ``data.edge_weight``.

    Args:
        delta: Scaling factor applied to the Laplacian (default ``0.85``).
    """

    def __init__(self, delta: float = 0.85) -> None:
        super().__init__()
        self.delta = delta

    def forward(self, data: Data) -> Data:
        """Normalize ``data``'s adjacency in place and return ``data``."""
        num_nodes = data.num_nodes
        # Pass the node count explicitly. Otherwise both helpers infer it
        # from the largest index in edge_index. Isolated trailing nodes would
        # then be dropped and the subtraction from eye(num_nodes) below would
        # fail on a shape mismatch.
        edge_index, edge_weight = get_laplacian(
            data.edge_index, data.edge_weight,
            normalization='sym', num_nodes=num_nodes)
        # to_dense_adj returns a batched (1, N, N) tensor; drop the batch dim.
        L = to_dense_adj(edge_index, edge_attr=edge_weight,
                         max_num_nodes=num_nodes).squeeze(0)
        A_norm = torch.eye(num_nodes, dtype=L.dtype, device=L.device) - self.delta * L
        data.edge_index, data.edge_weight = dense_to_sparse(A_norm)
        return data
### Get the data
dataset = TUDataset(root="../data/TUDataset", name='NCI1', pre_transform=NormalizeAdj())
# Shuffle before splitting so the 90/10 split does not depend on the order in
# which graphs are stored. The fixed seed keeps the split (and the global RNG
# state used later, e.g. for weight initialization) reproducible across runs.
torch.manual_seed(0)
dataset = dataset.shuffle()
# Fractional slice bounds select the first 90% / last 10% of the graphs.
train_loader = DataLoader(dataset[:0.9], batch_size=32, shuffle=True)
test_loader = DataLoader(dataset[0.9:], batch_size=32)
### Model definition
class Net(torch.nn.Module):
    """GIN -> Just-Balance pooling -> dense GIN -> mean readout classifier.

    Pipeline (see ``forward``):
      1. ``conv1``: sparse GINConv on the input graph.
      2. ``mlp``: maps each node embedding to ``n_clusters`` cluster logits.
      3. ``just_balance_pool``: pools the dense node features and adjacency
         using those logits, and also returns an auxiliary loss.
      4. ``conv2``: DenseGINConv on the pooled graph.
      5. Mean over dim 1 of the pooled features, then a linear classifier.

    Input size, number of classes and number of clusters are read from the
    module-level ``dataset``. ``n_clusters`` is the average number of nodes
    per graph.

    Args:
        hidden_channels: Width of the GIN layers.
        mlp_units: Hidden layer sizes of the cluster-assignment MLP (any
            sequence of ints; a list as in earlier versions still works).
        mlp_act: Name of a ``torch.nn`` activation class, instantiated with
            ``inplace=True`` and used inside the MLP.
    """

    def __init__(self,
                 hidden_channels=64,
                 mlp_units=(16,),
                 mlp_act="ReLU"):
        super().__init__()

        num_features = dataset.num_features
        num_classes = dataset.num_classes
        # Average number of nodes per graph, used as the number of clusters.
        # Clamped to at least 1 so the assignment MLP always has an output.
        n_clusters = max(1, dataset._data.x.size(0) // len(dataset))
        act = getattr(torch.nn, mlp_act)(inplace=True)

        # First MP layer (sparse, on the input graph)
        self.conv1 = GINConv(
            torch.nn.Sequential(
                torch.nn.Linear(num_features, hidden_channels),
                torch.nn.ReLU(),
                torch.nn.Linear(hidden_channels, hidden_channels),
            )
        )

        # Cluster-assignment MLP: node embedding -> n_clusters logits.
        # Unpacking (instead of list concatenation) accepts lists and tuples.
        self.mlp = MLP([hidden_channels, *mlp_units, n_clusters], act=act, norm=None)

        # Second MP layer (dense, on the pooled graph)
        self.conv2 = DenseGINConv(
            torch.nn.Sequential(
                torch.nn.Linear(hidden_channels, hidden_channels),
                torch.nn.ReLU(),
                torch.nn.Linear(hidden_channels, hidden_channels),
            )
        )

        # Readout layer
        self.lin = torch.nn.Linear(hidden_channels, num_classes)

    def forward(self, x, edge_index, batch=None, edge_weight=None):
        """Classify a (mini-batch of) graph(s).

        Args:
            x: Node features.
            edge_index: Graph connectivity in COO form.
            batch: Optional node-to-graph assignment, forwarded to
                ``to_dense_batch`` / ``to_dense_adj``.
            edge_weight: Optional per-edge weights (e.g. those written by
                ``NormalizeAdj``) used as the entries of the dense adjacency.
                When ``None`` the dense adjacency is binary, as before.

        Returns:
            Tuple ``(log_probs, aux_loss)``: log-softmax class scores and the
            auxiliary loss returned by ``just_balance_pool``.
        """
        # First MP layer
        x = self.conv1(x, edge_index)

        # Transform to dense batch (padded features + validity mask)
        x, mask = to_dense_batch(x, batch)
        adj = to_dense_adj(edge_index, batch, edge_attr=edge_weight)

        # Cluster assignments (logits)
        s = self.mlp(x)

        # Pooling
        x_pool, adj_pool, aux_loss = just_balance_pool(x, adj, s, mask, normalize=True)

        # Second MP layer
        x = self.conv2(x_pool, adj_pool)

        # Global pooling. NOTE(review): assumes dim 1 is the cluster axis of
        # the pooled features, i.e. (batch, clusters, channels) — confirm.
        x = x.mean(dim=1)

        # Readout layer
        x = self.lin(x)
        return F.log_softmax(x, dim=-1), aux_loss
### Model setup
# Prefer the GPU when one is visible to torch; otherwise fall back to CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model = Net()
model.to(device)  # nn.Module.to moves the parameters in place and returns self
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)
def train():
    """Run one training epoch over ``train_loader``.

    Each mini-batch is optimized on the NLL classification loss plus the
    auxiliary loss returned by the model.

    Returns:
        float: Batch-size-weighted average of the per-batch loss over the
        training split.
    """
    model.train()
    total_loss = 0.0
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        output, aux_loss = model(data.x, data.edge_index, data.batch)
        loss = F.nll_loss(output, data.y.view(-1)) + aux_loss
        loss.backward()
        # Optional gradient clipping (disabled):
        # torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
        # Weight by batch size so the epoch value is a per-graph average.
        total_loss += data.y.size(0) * float(loss)
        optimizer.step()
    # Normalize by the number of *training* graphs. The previous
    # len(dataset) also counted the held-out split and under-reported the loss.
    return total_loss / len(train_loader.dataset)
@torch.no_grad()
def test(loader):
    """Return the fraction of graphs in ``loader`` that ``model`` classifies correctly.

    Runs in eval mode without gradient tracking. The predicted class for
    each graph is the index of its largest log-probability.
    """
    model.eval()
    n_correct = 0
    for batch_data in loader:
        batch_data = batch_data.to(device)
        log_probs, _ = model(batch_data.x, batch_data.edge_index, batch_data.batch)
        predicted = log_probs.max(dim=1).indices
        targets = batch_data.y.view(-1)
        n_correct += int((predicted == targets).sum())
    return n_correct / len(loader.dataset)
### Training loop
# NOTE: "Val Acc" is measured on test_loader (the held-out 10% split);
# there is no separate validation set.
NUM_EPOCHS = 500
best_val_acc = 0.0
for epoch in range(1, NUM_EPOCHS + 1):
    train_loss = train()
    val_acc = test(test_loader)
    # Previously initialised but never updated; track the best epoch.
    best_val_acc = max(best_val_acc, val_acc)
    print(f'Epoch: {epoch:03d}, Train Loss: {train_loss:.3f}, '
          f'Val Acc: {val_acc:.3f}')
print(f'Best Val Acc: {best_val_acc:.3f}')