-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathGTVConv.py
189 lines (146 loc) · 6.28 KB
/
GTVConv.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
import tensorflow as tf
from tensorflow.keras import backend as K
from spektral.layers import ops
from spektral.layers.convolutional.conv import Conv
class GTVConv(Conv):
    r"""
    A graph total variation convolutional layer (GTVConv) from the paper

    > [Clustering with Total Variation Graph Neural Networks](https://arxiv.org/abs/2211.06218)
    > Jonas Berg Hansen and Filippo Maria Bianchi

    **Mode**: single, batch.

    This layer computes
    $$
        \X' = \sigma\left[\left(\I - \delta{\hat{\Lb}_\mathbf{\Gamma}}\right) \X \W \right]
    $$
    where \(\hat{\Lb}_\mathbf{\Gamma}\) is the Laplacian of the re-weighted
    adjacency matrix \(\mathbf{\Gamma}\), whose entries are the original
    adjacency weights divided by `max(||x_i - x_j||_1, epsilon)`, with
    `x_i` being the i-th row of \(\X \W\).

    **Input**

    - Node features of shape `([batch], n_nodes, n_node_features)`;
    - Adjacency matrix of shape `([batch], n_nodes, n_nodes)`; in single mode
    the adjacency matrix may also be a `tf.SparseTensor`.

    **Output**

    - Node features with the same shape as the input, but with the last
    dimension changed to `channels`.

    **Arguments**

    - `channels`: number of output channels;
    - `delta_coeff`: step size for gradient descent of GTV;
    - `epsilon`: small number used to numerically stabilize the computation of
    new adjacency weights;
    - `activation`: activation function;
    - `use_bias`: bool, add a bias vector to the output;
    - `kernel_initializer`: initializer for the weights;
    - `bias_initializer`: initializer for the bias vector;
    - `kernel_regularizer`: regularization applied to the weights;
    - `bias_regularizer`: regularization applied to the bias vector;
    - `activity_regularizer`: regularization applied to the output;
    - `kernel_constraint`: constraint applied to the weights;
    - `bias_constraint`: constraint applied to the bias vector.
    """

    def __init__(
        self,
        channels,
        delta_coeff=1.0,
        epsilon=1e-3,
        activation=None,
        use_bias=True,
        kernel_initializer="he_normal",
        bias_initializer="zeros",
        kernel_regularizer=None,
        bias_regularizer=None,
        activity_regularizer=None,
        kernel_constraint=None,
        bias_constraint=None,
        **kwargs
    ):
        super().__init__(
            activation=activation,
            use_bias=use_bias,
            kernel_initializer=kernel_initializer,
            bias_initializer=bias_initializer,
            kernel_regularizer=kernel_regularizer,
            bias_regularizer=bias_regularizer,
            activity_regularizer=activity_regularizer,
            kernel_constraint=kernel_constraint,
            bias_constraint=bias_constraint,
            **kwargs
        )
        self.channels = channels
        self.delta_coeff = delta_coeff
        self.epsilon = epsilon

    def build(self, input_shape):
        """Create the kernel (and optional bias) from the node-feature shape.

        `input_shape` is the list of shapes of the layer inputs; the first
        entry is the node-feature shape, whose last dimension fixes the
        number of kernel rows.
        """
        assert len(input_shape) >= 2
        input_dim = input_shape[0][-1]

        self.kernel = self.add_weight(
            shape=(input_dim, self.channels),
            initializer=self.kernel_initializer,
            name="kernel",
            regularizer=self.kernel_regularizer,
            constraint=self.kernel_constraint,
        )
        if self.use_bias:
            self.bias = self.add_weight(
                shape=(self.channels,),
                initializer=self.bias_initializer,
                name="bias",
                regularizer=self.bias_regularizer,
                constraint=self.bias_constraint,
            )
        self.built = True

    def call(self, inputs, mask=None):
        """Apply the layer to `inputs = (x, a)`.

        Raises:
            ValueError: if the detected data mode is neither single nor batch.
        """
        x, a = inputs

        mode = ops.autodetect_mode(x, a)
        if mode not in (ops.modes.SINGLE, ops.modes.BATCH):
            # Fail with a clear message instead of an UnboundLocalError on
            # `output` further down.
            raise ValueError(
                "GTVConv only supports single and batch mode, "
                "but the inputs were detected as mode {}.".format(mode)
            )

        # Update node features. The total-variation weights below are
        # computed on the transformed features, not on the raw input.
        x = K.dot(x, self.kernel)

        if mode == ops.modes.SINGLE:
            output = self._call_single(x, a)
        else:
            output = self._call_batch(x, a)

        if self.use_bias:
            output = K.bias_add(output, self.bias)
        if mask is not None:
            output *= mask[0]
        output = self.activation(output)

        return output

    @property
    def config(self):
        """Layer-specific constructor arguments merged into `get_config()`."""
        return {
            "channels": self.channels,
            "delta_coeff": self.delta_coeff,
            "epsilon": self.epsilon,
        }

    def _call_single(self, x, a):
        """Propagate features in single mode; `a` may be dense or sparse."""
        if not K.is_sparse(a):
            return self._dense_propagate(x, a)

        index_i = a.indices[:, 0]
        index_j = a.indices[:, 1]
        n_nodes = tf.shape(a, out_type=index_i.dtype)[0]

        # L1 distance between the features of the two endpoints of each edge
        abs_diff = tf.math.reduce_sum(
            tf.math.abs(tf.gather(x, index_i) - tf.gather(x, index_j)), axis=-1
        )

        # Compute new adjacency matrix: rescale every stored edge weight
        gamma = tf.sparse.map_values(
            tf.multiply, a, 1 / tf.math.maximum(abs_diff, self.epsilon)
        )

        # Compute degree matrix from gamma matrix
        d_gamma = tf.sparse.SparseTensor(
            tf.stack([tf.range(n_nodes)] * 2, axis=1),
            tf.sparse.reduce_sum(gamma, axis=-1),
            [n_nodes, n_nodes],
        )

        # Compute Laplacian: L = D_gamma - Gamma
        laplacian = tf.sparse.add(
            d_gamma, tf.sparse.map_values(tf.multiply, gamma, -1.0)
        )

        # Compute adjusted Laplacian: L_adjusted = I - delta * L.
        # The identity takes the Laplacian's dtype so non-float32 inputs work.
        propagator = tf.sparse.add(
            tf.sparse.eye(n_nodes, dtype=laplacian.dtype),
            tf.sparse.map_values(tf.multiply, laplacian, -self.delta_coeff),
        )

        # Aggregate features with adjusted Laplacian
        return ops.modal_dot(propagator, x)

    def _call_batch(self, x, a):
        """Propagate features in batch mode with a dense adjacency tensor."""
        return self._dense_propagate(x, a)

    def _dense_propagate(self, x, a):
        """Compute `(I - delta * L_gamma) x` for a dense adjacency `a`.

        The two innermost axes of `a` are the node axes, so the same code
        serves single mode (rank 2) and batch mode (rank 3).
        """
        # Pairwise L1 distances: entry (..., i, j) is sum_f |x_i,f - x_j,f|
        abs_diff = tf.math.reduce_sum(
            tf.math.abs(tf.expand_dims(x, -2) - tf.expand_dims(x, -3)), axis=-1
        )

        # Re-weighted adjacency; epsilon bounds the denominator away from zero
        gamma = a / tf.math.maximum(abs_diff, self.epsilon)
        degrees = tf.math.reduce_sum(gamma, axis=-1)

        # L = D_gamma - Gamma: negate gamma, then overwrite the diagonal with
        # (degree - self-loop weight)
        laplacian = tf.linalg.set_diag(
            -gamma, degrees - tf.linalg.diag_part(gamma)
        )

        # L_adjusted = I - delta * L; identity matches the Laplacian's dtype
        n_nodes = tf.shape(a)[-1]
        propagator = (
            tf.eye(n_nodes, dtype=laplacian.dtype) - self.delta_coeff * laplacian
        )

        return tf.matmul(propagator, x)