-from typing import Union, Tuple, Optional
-from torch_geometric.typing import (OptPairTensor, Adj, Size, NoneType,
-                                    OptTensor)
+from typing import Optional, Tuple, Union

 import torch
-from torch import Tensor
+import torch.nn as nn
 import torch.nn.functional as F
+from torch import Tensor
 from torch.nn import Parameter
-import torch.nn as nn
-from torch_sparse import SparseTensor, set_diag
-from torch_geometric.nn.dense.linear import Linear
 from torch_geometric.nn.conv import MessagePassing
-from torch_geometric.utils import remove_self_loops, add_self_loops, softmax

+# from torch_sparse import SparseTensor, set_diag
+from torch_geometric.nn.dense.linear import Linear
+from torch_geometric.typing import Adj, NoneType, OptPairTensor, OptTensor, Size
+from torch_geometric.utils import add_self_loops, remove_self_loops, softmax


 class GATConv(MessagePassing):
@@ -58,13 +57,22 @@ class GATConv(MessagePassing):
         **kwargs (optional): Additional arguments of
             :class:`torch_geometric.nn.conv.MessagePassing`.
     """
+
     _alpha: OptTensor

-    def __init__(self, in_channels: Union[int, Tuple[int, int]],
-                 out_channels: int, heads: int = 1, concat: bool = True,
-                 negative_slope: float = 0.2, dropout: float = 0.0,
-                 add_self_loops: bool = True, bias: bool = True, **kwargs):
-        kwargs.setdefault('aggr', 'add')
+    def __init__(
+        self,
+        in_channels: Union[int, Tuple[int, int]],
+        out_channels: int,
+        heads: int = 1,
+        concat: bool = True,
+        negative_slope: float = 0.2,
+        dropout: float = 0.0,
+        add_self_loops: bool = True,
+        bias: bool = True,
+        **kwargs,
+    ):
+        kwargs.setdefault("aggr", "add")
         super(GATConv, self).__init__(node_dim=0, **kwargs)

         self.in_channels = in_channels
@@ -91,7 +99,6 @@ def __init__(self, in_channels: Union[int, Tuple[int, int]],
         nn.init.xavier_normal_(self.lin_src.data, gain=1.414)
         self.lin_dst = self.lin_src

-
         # The learnable parameters to compute attention coefficients:
         self.att_src = Parameter(torch.Tensor(1, heads, out_channels))
         self.att_dst = Parameter(torch.Tensor(1, heads, out_channels))
@@ -117,12 +124,15 @@ def __init__(self, in_channels: Union[int, Tuple[int, int]],
         # glorot(self.att_dst)
         # # zeros(self.bias)

-    def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,
-                size: Size = None, return_attention_weights=None, attention=True, tied_attention=None):
-        # type: (Union[Tensor, OptPairTensor], Tensor, Size, NoneType) -> Tensor  # noqa
-        # type: (Union[Tensor, OptPairTensor], SparseTensor, Size, NoneType) -> Tensor  # noqa
-        # type: (Union[Tensor, OptPairTensor], Tensor, Size, bool) -> Tuple[Tensor, Tuple[Tensor, Tensor]]  # noqa
-        # type: (Union[Tensor, OptPairTensor], SparseTensor, Size, bool) -> Tuple[Tensor, SparseTensor]  # noqa
+    def forward(
+        self,
+        x: Union[Tensor, OptPairTensor],
+        edge_index: Adj,
+        size: Size = None,
+        return_attention_weights=None,
+        attention=True,
+        tied_attention=None,
+    ):
         r"""
         Args:
             return_attention_weights (bool, optional): If set to :obj:`True`,
@@ -161,7 +171,6 @@ def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,
         else:
             alpha = tied_attention

-
         if self.add_self_loops:
             if isinstance(edge_index, Tensor):
                 # We only want to add self-loops for nodes that appear both as
@@ -172,8 +181,10 @@ def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,
                 num_nodes = min(size) if size is not None else num_nodes
                 edge_index, _ = remove_self_loops(edge_index)
                 edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)
-            elif isinstance(edge_index, SparseTensor):
-                edge_index = set_diag(edge_index)
+            else:
+                raise ValueError(f"Received invalid type {type(edge_index)}")
+            # elif isinstance(edge_index, SparseTensor):
+            #     edge_index = set_diag(edge_index)

         # propagate_type: (x: OptPairTensor, alpha: OptPairTensor)
         out = self.propagate(edge_index, x=x, alpha=alpha, size=size)
@@ -193,26 +204,26 @@ def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,
         if isinstance(return_attention_weights, bool):
             if isinstance(edge_index, Tensor):
                 return out, (edge_index, alpha)
-            elif isinstance(edge_index, SparseTensor):
-                return out, edge_index.set_value(alpha, layout='coo')
+            else:
+                raise ValueError(f"Received invalid type {type(edge_index)}")
+            # elif isinstance(edge_index, SparseTensor):
+            #     return out, edge_index.set_value(alpha, layout="coo")
         else:
             return out

-    def message(self, x_j: Tensor, alpha_j: Tensor, alpha_i: OptTensor,
-                index: Tensor, ptr: OptTensor,
-                size_i: Optional[int]) -> Tensor:
+    def message(
+        self, x_j: Tensor, alpha_j: Tensor, alpha_i: OptTensor, index: Tensor, ptr: OptTensor, size_i: Optional[int]
+    ) -> Tensor:
         # Given edge-level attention coefficients for source and target nodes,
         # we simply need to sum them up to "emulate" concatenation:
         alpha = alpha_j if alpha_i is None else alpha_j + alpha_i

-        #alpha = F.leaky_relu(alpha, self.negative_slope)
+        # alpha = F.leaky_relu(alpha, self.negative_slope)
         alpha = torch.sigmoid(alpha)
         alpha = softmax(alpha, index, ptr, size_i)
         self._alpha = alpha  # Save for later use.
         alpha = F.dropout(alpha, p=self.dropout, training=self.training)
         return x_j * alpha.unsqueeze(-1)

     def __repr__(self):
-        return '{}({}, {}, heads={})'.format(self.__class__.__name__,
-                                             self.in_channels,
-                                             self.out_channels, self.heads)
+        return "{}({}, {}, heads={})".format(self.__class__.__name__, self.in_channels, self.out_channels, self.heads)
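
For reference, a minimal usage sketch of the reworked layer after this change. The module path, feature sizes, and random graph below are illustrative assumptions, not part of the diff; the only behaviour taken from the diff is that edge_index must be a dense Tensor (SparseTensor handling is commented out) and that attention scores pass through a sigmoid instead of LeakyReLU before the softmax.

import torch

from gat_conv import GATConv  # assumed module name; import from wherever this file lives

conv = GATConv(in_channels=16, out_channels=8, heads=4)  # concat=True by default -> output dim 4 * 8
x = torch.randn(100, 16)                                  # 100 nodes with 16 features each
edge_index = torch.randint(0, 100, (2, 500))              # dense COO edge list (the only supported type here)

out = conv(x, edge_index)                                 # -> [100, 32]
out, (ei, alpha) = conv(x, edge_index, return_attention_weights=True)  # per-edge attention coefficients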