diff --git a/src/transformers/models/deprecated/graphormer/modeling_graphormer.py b/src/transformers/models/deprecated/graphormer/modeling_graphormer.py old mode 100755 new mode 100644 index b3e8ea742c8d..176a62fc0b1c --- a/src/transformers/models/deprecated/graphormer/modeling_graphormer.py +++ b/src/transformers/models/deprecated/graphormer/modeling_graphormer.py @@ -202,7 +202,7 @@ def forward( + self.out_degree_encoder(out_degree) ) - graph_token_feature = self.graph_token.weight.unsqueeze(0).repeat(n_graph, 1, 1) + graph_token_feature = self.graph_token.weight.unsqueeze(0).expand(n_graph, -1, -1) graph_node_feature = torch.cat([graph_token_feature, node_feature], dim=1)