# just_balance.py
import torch

EPS = 1e-15


def just_balance_pool(x, adj, s, mask=None, normalize=True):
r"""The Just Balance pooling operator from the `"Simplifying Clustering with
Graph Neural Networks" <https://arxiv.org/abs/2207.08779>`_ paper
.. math::
\mathbf{X}^{\prime} &= {\mathrm{softmax}(\mathbf{S})}^{\top} \cdot
\mathbf{X}
\mathbf{A}^{\prime} &= {\mathrm{softmax}(\mathbf{S})}^{\top} \cdot
\mathbf{A} \cdot \mathrm{softmax}(\mathbf{S})
based on dense learned assignments :math:`\mathbf{S} \in \mathbb{R}^{B
\times N \times C}`.
Returns the pooled node feature matrix, the coarsened and symmetrically
normalized adjacency matrix and the following auxiliary objective:
.. math::
\mathcal{L} = - {\mathrm{Tr}(\sqrt{\mathbf{S}^{\top} \mathbf{S}})}
Args:
x (Tensor): Node feature tensor :math:`\mathbf{X} \in \mathbb{R}^{B \times N \times F}`
with batch-size :math:`B`, (maximum) number of nodes :math:`N`
for each graph, and feature dimension :math:`F`.
adj (Tensor): Symmetrically normalized adjacency tensor
:math:`\mathbf{A} \in \mathbb{R}^{B \times N \times N}`.
s (Tensor): Assignment tensor :math:`\mathbf{S} \in \mathbb{R}^{B \times N \times C}`
with number of clusters :math:`C`. The softmax does not have to be
applied beforehand, since it is executed within this method.
mask (BoolTensor, optional): Mask matrix
:math:`\mathbf{M} \in {\{ 0, 1 \}}^{B \times N}` indicating
the valid nodes for each graph. (default: :obj:`None`)
:rtype: (:class:`Tensor`, :class:`Tensor`, :class:`Tensor`,
:class:`Tensor`)
"""
    # Promote unbatched (2D) inputs to a batch of size one.
    x = x.unsqueeze(0) if x.dim() == 2 else x
    adj = adj.unsqueeze(0) if adj.dim() == 2 else adj
    s = s.unsqueeze(0) if s.dim() == 2 else s

    (batch_size, num_nodes, _), k = x.size(), s.size(-1)

    s = torch.softmax(s, dim=-1)

    # Zero out features and assignments of padded (invalid) nodes.
    if mask is not None:
        mask = mask.view(batch_size, num_nodes, 1).to(x.dtype)
        x, s = x * mask, s * mask

    # Pooled features X' = S^T X and coarsened adjacency A' = S^T A S.
    out = torch.matmul(s.transpose(1, 2), x)
    out_adj = torch.matmul(torch.matmul(s.transpose(1, 2), adj), s)

    # Balance loss: L = -Tr(sqrt(S^T S)), averaged over the batch.
    ss = torch.matmul(s.transpose(1, 2), s)
    ss_sqrt = torch.sqrt(ss + EPS)
    loss = torch.mean(-_rank3_trace(ss_sqrt))
    if normalize:
        loss = loss / torch.sqrt(torch.tensor(num_nodes * k))

    # Fix and normalize the coarsened adjacency matrix: remove self-loops,
    # then apply symmetric normalization D^{-1/2} A' D^{-1/2}.
    ind = torch.arange(k, device=out_adj.device)
    out_adj[:, ind, ind] = 0
    d = torch.einsum('ijk->ij', out_adj)  # Row-wise degree of each cluster.
    d = torch.sqrt(torch.clamp(d, min=EPS))[:, None]
    out_adj = (out_adj / d) / d.transpose(1, 2)

    return out, out_adj, loss


def _rank3_trace(x):
    # Batched trace: the trace of each matrix in a rank-3 tensor.
    return torch.einsum('ijj->i', x)
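

# A minimal usage sketch (an assumption, not part of the original module):
# it builds a random batch with hypothetical sizes B, N, F, C to show the
# expected input shapes and outputs of `just_balance_pool`. A real pipeline
# would pass a symmetrically normalized adjacency, as documented above; a
# random symmetric matrix is enough for a shape check here.
if __name__ == "__main__":
    B, N, F, C = 2, 10, 16, 4  # hypothetical batch/node/feature/cluster sizes

    x = torch.randn(B, N, F)
    adj = torch.rand(B, N, N)
    adj = (adj + adj.transpose(1, 2)) / 2  # make it symmetric
    s = torch.randn(B, N, C)  # raw assignment logits; softmax is applied inside

    out, out_adj, loss = just_balance_pool(x, adj, s)
    print(out.shape)      # torch.Size([2, 4, 16])
    print(out_adj.shape)  # torch.Size([2, 4, 4])
    print(loss.item())    # scalar auxiliary loss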