diff --git a/mmcv/ops/csrc/common/box_iou_rotated_utils.hpp b/mmcv/ops/csrc/common/box_iou_rotated_utils.hpp index a8453eaa8d..32663ab9ef 100644 --- a/mmcv/ops/csrc/common/box_iou_rotated_utils.hpp +++ b/mmcv/ops/csrc/common/box_iou_rotated_utils.hpp @@ -187,17 +187,18 @@ HOST_DEVICE_INLINE int convex_hull_graham(const Point (&p)[24], // (essentially sorting according to angles) // If the angles are the same, sort according to their distance to origin T dist[24]; - for (int i = 0; i < num_in; i++) { - dist[i] = dot_2d(q[i], q[i]); - } #ifdef __CUDACC__ // CUDA version // In the future, we can potentially use thrust // for sorting here to improve speed (though not guaranteed) + for (int i = 0; i < num_in; i++) { + dist[i] = dot_2d(q[i], q[i]); + dist[i] = sqrtf(float(dist[i])) + 1e-6; + } for (int i = 1; i < num_in - 1; i++) { for (int j = i + 1; j < num_in; j++) { - T crossProduct = cross_2d(q[i], q[j]); + T crossProduct = cross_2d(q[i] * (1 / dist[i]), q[j] * (1 / dist[j])); if ((crossProduct < -1e-6) || (fabs(crossProduct) < 1e-6 && dist[i] > dist[j])) { auto q_tmp = q[i]; @@ -211,18 +212,21 @@ HOST_DEVICE_INLINE int convex_hull_graham(const Point (&p)[24], } #else // CPU version + // compute distance to origin after sort, since the points are now different. std::sort(q + 1, q + num_in, [](const Point& A, const Point& B) -> bool { - T temp = cross_2d(A, B); + const T dot_A = sqrtf(float(dot_2d(A, A))) + 1e-6; + const T dot_B = sqrtf(float(dot_2d(B, B))) + 1e-6; + T temp = cross_2d(A * (1 / dot_A), B * (1 / dot_B)); if (fabs(temp) < 1e-6) { - return dot_2d(A, A) < dot_2d(B, B); + return dot_A < dot_B; } else { return temp > 0; } }); - // compute distance to origin after sort, since the points are now different. for (int i = 0; i < num_in; i++) { dist[i] = dot_2d(q[i], q[i]); + dist[i] = sqrtf(float(dist[i])); } #endif diff --git a/mmcv/ops/csrc/pytorch/cuda/box_iou_quadri_cuda.cu b/mmcv/ops/csrc/pytorch/cuda/box_iou_quadri_cuda.cu index 25b6819a79..b8e27d2cb9 100644 --- a/mmcv/ops/csrc/pytorch/cuda/box_iou_quadri_cuda.cu +++ b/mmcv/ops/csrc/pytorch/cuda/box_iou_quadri_cuda.cu @@ -17,7 +17,7 @@ void box_iou_quadri_cuda(const Tensor boxes1, const Tensor boxes2, Tensor ious, box_iou_quadri_cuda_kernel <<>>( num_boxes1, num_boxes2, boxes1.data_ptr(), - boxes2.data_ptr(), (scalar_t*)ious.data_ptr(), + boxes2.data_ptr(), (scalar_t *)ious.data_ptr(), mode_flag, aligned); AT_CUDA_CHECK(cudaGetLastError()); } diff --git a/mmcv/ops/csrc/pytorch/cuda/box_iou_rotated_cuda.cu b/mmcv/ops/csrc/pytorch/cuda/box_iou_rotated_cuda.cu index 3c13e06237..e055b0fdf7 100644 --- a/mmcv/ops/csrc/pytorch/cuda/box_iou_rotated_cuda.cu +++ b/mmcv/ops/csrc/pytorch/cuda/box_iou_rotated_cuda.cu @@ -19,7 +19,7 @@ void box_iou_rotated_cuda(const Tensor boxes1, const Tensor boxes2, Tensor ious, box_iou_rotated_cuda_kernel <<>>( num_boxes1, num_boxes2, boxes1.data_ptr(), - boxes2.data_ptr(), (scalar_t*)ious.data_ptr(), + boxes2.data_ptr(), (scalar_t *)ious.data_ptr(), mode_flag, aligned); AT_CUDA_CHECK(cudaGetLastError()); }