diff --git a/models/position_encoding.py b/models/position_encoding.py index 73ae39edf..32ea9d458 100644 --- a/models/position_encoding.py +++ b/models/position_encoding.py @@ -37,7 +37,7 @@ def forward(self, tensor_list: NestedTensor): y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale - dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) + dim_t = torch_arange(self.num_pos_feats, dtype=torch.float32, device=x.device) dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) pos_x = x_embed[:, :, :, None] / dim_t @@ -87,3 +87,8 @@ def build_position_encoding(args): raise ValueError(f"not supported {args.position_embedding}") return position_embedding + + +@torch.fx.wrap +def torch_arange(x: int, dtype: torch.dtype, device: torch.device) -> torch.Tensor: + return torch.arange(x, dtype=dtype, device=device)