from typing import Union, Tuple, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Parameter
from torch_sparse import SparseTensor, set_diag

from torch_geometric.typing import (OptPairTensor, Adj, Size, NoneType,
                                    OptTensor)
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils import remove_self_loops, add_self_loops, softmax


class GATConv(MessagePassing):
    r"""The graph attentional operator from the `"Graph Attention Networks"
    <https://arxiv.org/abs/1710.10903>`_ paper

    .. math::
        \mathbf{x}^{\prime}_i = \alpha_{i,i}\mathbf{\Theta}\mathbf{x}_{i} +
        \sum_{j \in \mathcal{N}(i)} \alpha_{i,j}\mathbf{\Theta}\mathbf{x}_{j},

    where the attention coefficients :math:`\alpha_{i,j}` are computed as

    .. math::
        \alpha_{i,j} =
        \frac{
        \exp\left(\mathrm{LeakyReLU}\left(\mathbf{a}^{\top}
        [\mathbf{\Theta}\mathbf{x}_i \, \Vert \, \mathbf{\Theta}\mathbf{x}_j]
        \right)\right)}
        {\sum_{k \in \mathcal{N}(i) \cup \{ i \}}
        \exp\left(\mathrm{LeakyReLU}\left(\mathbf{a}^{\top}
        [\mathbf{\Theta}\mathbf{x}_i \, \Vert \, \mathbf{\Theta}\mathbf{x}_k]
        \right)\right)}.
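
    .. note::
        This implementation deviates from the formula above. As a sketch of
        what :meth:`message` below actually computes, the
        :math:`\mathrm{LeakyReLU}` is replaced by a sigmoid applied to the
        sum of the source and target attention logits:

        .. math::
            \alpha_{i,j} = \mathrm{softmax}_j\left(\sigma\left(
            \mathbf{a}_{\mathrm{src}}^{\top}\mathbf{\Theta}\mathbf{x}_j +
            \mathbf{a}_{\mathrm{dst}}^{\top}\mathbf{\Theta}\mathbf{x}_i
            \right)\right).
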
    Args:
        in_channels (int or tuple): Size of each input sample, or :obj:`-1` to
            derive the size from the first input(s) to the forward method.
            A tuple corresponds to the sizes of source and target
            dimensionalities.
        out_channels (int): Size of each output sample.
        heads (int, optional): Number of multi-head-attentions.
            (default: :obj:`1`)
        concat (bool, optional): If set to :obj:`False`, the multi-head
            attentions are averaged instead of concatenated.
            (default: :obj:`True`)
        negative_slope (float, optional): LeakyReLU angle of the negative
            slope. (default: :obj:`0.2`)
        dropout (float, optional): Dropout probability of the normalized
            attention coefficients which exposes each node to a stochastically
            sampled neighborhood during training. (default: :obj:`0`)
        add_self_loops (bool, optional): If set to :obj:`False`, will not add
            self-loops to the input graph. (default: :obj:`True`)
        bias (bool, optional): If set to :obj:`False`, the layer will not learn
            an additive bias. (default: :obj:`True`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`.
    """
    _alpha: OptTensor

    def __init__(self, in_channels: Union[int, Tuple[int, int]],
                 out_channels: int, heads: int = 1, concat: bool = True,
                 negative_slope: float = 0.2, dropout: float = 0.0,
                 add_self_loops: bool = True, bias: bool = True, **kwargs):
        kwargs.setdefault('aggr', 'add')
        super(GATConv, self).__init__(node_dim=0, **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.heads = heads
        self.concat = concat
        self.negative_slope = negative_slope
        self.dropout = dropout
        self.add_self_loops = add_self_loops

        # In case we are operating on bipartite graphs, we would apply
        # separate transformations 'lin_src' and 'lin_dst' to source and
        # target nodes (the original PyG-style implementation, kept here for
        # reference):
        # if isinstance(in_channels, int):
        #     self.lin_src = Linear(in_channels, heads * out_channels,
        #                           bias=False, weight_initializer='glorot')
        #     self.lin_dst = self.lin_src
        # else:
        #     self.lin_src = Linear(in_channels[0], heads * out_channels, False,
        #                           weight_initializer='glorot')
        #     self.lin_dst = Linear(in_channels[1], heads * out_channels, False,
        #                           weight_initializer='glorot')

        # Instead, this implementation uses a single raw weight matrix shared
        # by source and target nodes (it therefore assumes `in_channels` is an
        # int, not a tuple). The weight maps to `heads * out_channels` so that
        # the `.view(-1, H, C)` reshape in `forward` is valid for any number
        # of heads.
        self.lin_src = nn.Parameter(
            torch.zeros(size=(in_channels, heads * out_channels)))
        nn.init.xavier_normal_(self.lin_src.data, gain=1.414)
        self.lin_dst = self.lin_src

        # The learnable parameters to compute attention coefficients:
        self.att_src = Parameter(torch.Tensor(1, heads, out_channels))
        self.att_dst = Parameter(torch.Tensor(1, heads, out_channels))
        nn.init.xavier_normal_(self.att_src.data, gain=1.414)
        nn.init.xavier_normal_(self.att_dst.data, gain=1.414)

        # NOTE: the additive bias is currently disabled; the `bias` argument
        # has no effect. The original handling is kept here for reference:
        # if bias and concat:
        #     self.bias = Parameter(torch.Tensor(heads * out_channels))
        # elif bias and not concat:
        #     self.bias = Parameter(torch.Tensor(out_channels))
        # else:
        #     self.register_parameter('bias', None)

        self._alpha = None
        self.attentions = None

        # self.reset_parameters()

    # def reset_parameters(self):
    #     self.lin_src.reset_parameters()
    #     self.lin_dst.reset_parameters()
    #     glorot(self.att_src)
    #     glorot(self.att_dst)
    #     # zeros(self.bias)

    def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,
                size: Size = None, return_attention_weights=None,
                attention: bool = True, tied_attention=None):
        # type: (Union[Tensor, OptPairTensor], Tensor, Size, NoneType) -> Tensor  # noqa
        # type: (Union[Tensor, OptPairTensor], SparseTensor, Size, NoneType) -> Tensor  # noqa
        # type: (Union[Tensor, OptPairTensor], Tensor, Size, bool) -> Tuple[Tensor, Tuple[Tensor, Tensor]]  # noqa
        # type: (Union[Tensor, OptPairTensor], SparseTensor, Size, bool) -> Tuple[Tensor, SparseTensor]  # noqa
        r"""
        Args:
            return_attention_weights (bool, optional): If set to :obj:`True`,
                will additionally return the tuple
                :obj:`(edge_index, attention_weights)`, holding the computed
                attention weights for each edge. (default: :obj:`None`)
            attention (bool, optional): If set to :obj:`False`, skips the
                attention mechanism entirely and returns the linearly
                transformed features averaged over heads.
                (default: :obj:`True`)
            tied_attention (optional): Precomputed node-level attention
                logits (as stored in :obj:`self.attentions`) to reuse instead
                of computing them from the current features.
                (default: :obj:`None`)
        """
        H, C = self.heads, self.out_channels

        # We first transform the input node features. If a tuple is passed, we
        # transform source and target node features via separate weights.
        # Since `lin_src` and `lin_dst` are plain weight matrices here, we use
        # `torch.mm` rather than calling them as modules:
        if isinstance(x, Tensor):
            assert x.dim() == 2, "Static graphs not supported in 'GATConv'"
            # x_src = x_dst = self.lin_src(x).view(-1, H, C)
            x_src = x_dst = torch.mm(x, self.lin_src).view(-1, H, C)
        else:  # Tuple of source and target node features:
            x_src, x_dst = x
            assert x_src.dim() == 2, "Static graphs not supported in 'GATConv'"
            x_src = torch.mm(x_src, self.lin_src).view(-1, H, C)
            if x_dst is not None:
                x_dst = torch.mm(x_dst, self.lin_dst).view(-1, H, C)

        x = (x_src, x_dst)

        if not attention:
            return x[0].mean(dim=1)
            # return x[0].view(-1, self.heads * self.out_channels)

        if tied_attention is None:
            # Next, we compute node-level attention coefficients, both for
            # source and target nodes (if present):
            alpha_src = (x_src * self.att_src).sum(dim=-1)
            alpha_dst = None if x_dst is None else (x_dst * self.att_dst).sum(-1)
            alpha = (alpha_src, alpha_dst)
            self.attentions = alpha
        else:
            alpha = tied_attention

        if self.add_self_loops:
            if isinstance(edge_index, Tensor):
                # We only want to add self-loops for nodes that appear both as
                # source and target nodes:
                num_nodes = x_src.size(0)
                if x_dst is not None:
                    num_nodes = min(num_nodes, x_dst.size(0))
                num_nodes = min(size) if size is not None else num_nodes
                edge_index, _ = remove_self_loops(edge_index)
                edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)
            elif isinstance(edge_index, SparseTensor):
                edge_index = set_diag(edge_index)

        # propagate_type: (x: OptPairTensor, alpha: OptPairTensor)
        out = self.propagate(edge_index, x=x, alpha=alpha, size=size)

        alpha = self._alpha
        assert alpha is not None
        self._alpha = None

        if self.concat:
            out = out.view(-1, self.heads * self.out_channels)
        else:
            out = out.mean(dim=1)

        # if self.bias is not None:
        #     out += self.bias

        if isinstance(return_attention_weights, bool):
            if isinstance(edge_index, Tensor):
                return out, (edge_index, alpha)
            elif isinstance(edge_index, SparseTensor):
                return out, edge_index.set_value(alpha, layout='coo')
        else:
            return out

    def message(self, x_j: Tensor, alpha_j: Tensor, alpha_i: OptTensor,
                index: Tensor, ptr: OptTensor,
                size_i: Optional[int]) -> Tensor:
        # Given edge-level attention coefficients for source and target nodes,
        # we simply need to sum them up to "emulate" concatenation:
        alpha = alpha_j if alpha_i is None else alpha_j + alpha_i

        # NOTE: the original GAT applies a LeakyReLU here; this variant uses a
        # sigmoid instead before normalizing with softmax.
        # alpha = F.leaky_relu(alpha, self.negative_slope)
        alpha = torch.sigmoid(alpha)
        alpha = softmax(alpha, index, ptr, size_i)
        self._alpha = alpha  # Save for later use.
        alpha = F.dropout(alpha, p=self.dropout, training=self.training)
        return x_j * alpha.unsqueeze(-1)

    def __repr__(self):
        return '{}({}, {}, heads={})'.format(self.__class__.__name__,
                                             self.in_channels,
                                             self.out_channels, self.heads)
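

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustration only, not part of the layer itself).
# It assumes a toy graph with 4 nodes, 5 input features and a single
# attention head for simplicity; the names `conv`, `x` and `edge_index` are
# hypothetical and chosen here just for the example.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    conv = GATConv(in_channels=5, out_channels=8, heads=1, concat=True)

    x = torch.randn(4, 5)  # [num_nodes, in_channels]
    edge_index = torch.tensor([[0, 1, 2, 3],
                               [1, 0, 3, 2]])  # [2, num_edges]

    # Standard forward pass: [num_nodes, heads * out_channels]
    out = conv(x, edge_index)

    # Also return the per-edge attention coefficients:
    out, (edge_index_with_loops, alpha) = conv(
        x, edge_index, return_attention_weights=True)

    # Skip attention entirely and just get the transformed features,
    # averaged over heads: [num_nodes, out_channels]
    feats = conv(x, edge_index, attention=False)

    print(out.shape, alpha.shape, feats.shape)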