Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 38 additions & 22 deletions src/transformers/models/deprecated/graphormer/modeling_graphormer.py
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -243,50 +243,66 @@ def forward(
attn_edge_type: torch.LongTensor,
) -> torch.Tensor:
n_graph, n_node = input_nodes.size()[:2]
graph_attn_bias = attn_bias.clone()
graph_attn_bias = graph_attn_bias.unsqueeze(1).repeat(
1, self.num_heads, 1, 1
) # [n_graph, n_head, n_node+1, n_node+1]

# Broadcast attn_bias over the heads with expand() and materialize it with a single clone(), instead of clone+repeat; below is ~30% faster
graph_attn_bias = attn_bias.unsqueeze(1).expand(-1, self.num_heads, -1, -1).clone()

# spatial pos
# [n_graph, n_node, n_node, n_head] -> [n_graph, n_head, n_node, n_node]

# spatial pos
# [n_graph, n_node, n_node, n_head] -> [n_graph, n_head, n_node, n_node]
spatial_pos_bias = self.spatial_pos_encoder(spatial_pos).permute(0, 3, 1, 2)
graph_attn_bias[:, :, 1:, 1:] = graph_attn_bias[:, :, 1:, 1:] + spatial_pos_bias
graph_attn_bias[:, :, 1:, 1:] += spatial_pos_bias

# reset spatial pos here
# Reshape the virtual-distance weight once to [1, n_head, 1] (a view, no copy) so it broadcasts in both updates below

# reset spatial pos here
t = self.graph_token_virtual_distance.weight.view(1, self.num_heads, 1)
graph_attn_bias[:, :, 1:, 0] = graph_attn_bias[:, :, 1:, 0] + t
graph_attn_bias[:, :, 0, :] = graph_attn_bias[:, :, 0, :] + t
graph_attn_bias[:, :, 1:, 0] += t
graph_attn_bias[:, :, 0, :] += t

# edge feature

# edge feature
if self.edge_type == "multi_hop":
spatial_pos_ = spatial_pos.clone()
spatial_pos_ = spatial_pos.clone().contiguous()

spatial_pos_.masked_fill_(spatial_pos_ == 0, 1) # set pad to 1 in-place for large tensors

spatial_pos_[spatial_pos_ == 0] = 1 # set pad to 1
# keep 1 as 1, map spatial_pos_ > 1 to spatial_pos_ - 1 (torch.where returns a new tensor, so this step is not in-place)
# set 1 to 1, input_nodes > 1 to input_nodes - 1
spatial_pos_ = torch.where(spatial_pos_ > 1, spatial_pos_ - 1, spatial_pos_)
if self.multi_hop_max_dist > 0:
spatial_pos_ = spatial_pos_.clamp(0, self.multi_hop_max_dist)
input_edges = input_edges[:, :, :, : self.multi_hop_max_dist, :]
# Use clamp_ for in-place (smaller alloc)
spatial_pos_.clamp_(0, self.multi_hop_max_dist)
input_edges = input_edges[:, :, :, : self.multi_hop_max_dist, :].contiguous()

# [n_graph, n_node, n_node, max_dist, n_head]
edge_encoded = self.edge_encoder(input_edges)
input_edges_avg = edge_encoded.mean(-2) # [n_graph, n_node, n_node, max_dist, n_head]
max_dist = input_edges_avg.size(-2)

input_edges = self.edge_encoder(input_edges).mean(-2)
max_dist = input_edges.size(-2)
edge_input_flat = input_edges.permute(3, 0, 1, 2, 4).reshape(max_dist, -1, self.num_heads)
edge_input_flat = torch.bmm(
edge_input_flat,
self.edge_dis_encoder.weight.reshape(-1, self.num_heads, self.num_heads)[:max_dist, :, :],
)
input_edges = edge_input_flat.reshape(max_dist, n_graph, n_node, n_node, self.num_heads).permute(
# More efficient combine:
input_edges_perm = input_edges_avg.permute(3, 0, 1, 2, 4).reshape(max_dist, -1, self.num_heads)
dis_weight = self.edge_dis_encoder.weight.view(-1, self.num_heads, self.num_heads)[:max_dist]
# Batched matmul over the leading max_dist dimension; torch.bmm allocates its own output
edge_input_flat = torch.bmm(input_edges_perm, dis_weight)

reshaped_edges = edge_input_flat.view(max_dist, n_graph, n_node, n_node, self.num_heads).permute(
1, 2, 3, 0, 4
)
input_edges = (input_edges.sum(-2) / (spatial_pos_.float().unsqueeze(-1))).permute(0, 3, 1, 2)
norm = spatial_pos_.float().unsqueeze(-1)
# Sum over the max_dist dimension, then divide in place by norm (div_ reuses the buffer returned by sum)
input_edges = torch.sum(reshaped_edges, dim=-2).div_(norm).permute(0, 3, 1, 2)
else:
# [n_graph, n_node, n_node, n_head] -> [n_graph, n_head, n_node, n_node]
input_edges = self.edge_encoder(attn_edge_type).mean(-2).permute(0, 3, 1, 2)

graph_attn_bias[:, :, 1:, 1:] = graph_attn_bias[:, :, 1:, 1:] + input_edges
graph_attn_bias = graph_attn_bias + attn_bias.unsqueeze(1) # reset
graph_attn_bias[:, :, 1:, 1:] += input_edges
# Add attn_bias in place, broadcast across the head dimension
graph_attn_bias += attn_bias.unsqueeze(1)

return graph_attn_bias

Expand Down