Skip to content

Commit a1779ad

Browse files
committed
Add correction to edge weights definition
1 parent bef7eea commit a1779ad

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

project/utils/deepinteract_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -496,7 +496,7 @@ def convert_df_to_dgl_graph(df: pd.DataFrame, input_file: str, knn: int,
496496
# Positional encoding for each edge (used for sequentially-ordered inputs like proteins)
497497
graph.edata['f'] = torch.sin((graph.edges()[0] - graph.edges()[1]).float()).reshape(-1, 1) # [num_edges, 1]
498498
# Normalized edge weights (according to Euclidean distance)
499-
edge_weights = min_max_normalize_tensor(torch.sum(node_coords[srcs] - node_coords[dsts] ** 2, 1)).reshape(-1, 1)
499+
edge_weights = min_max_normalize_tensor(torch.sum((node_coords[srcs] - node_coords[dsts]) ** 2, 1)).reshape(-1, 1)
500500
graph.edata['f'] = torch.cat((graph.edata['f'], edge_weights), dim=1) # [num_edges, 1]
501501
# Geometric edge features derived above
502502
graph.edata['f'] = torch.cat((graph.edata['f'], edge_dist_feats), dim=1) # Distance: [num_edges, num_rbf] if full

0 commit comments

Comments
 (0)