diff --git a/mmcv/ops/csrc/pytorch/npu/assign_score_withk_npu.cpp b/mmcv/ops/csrc/pytorch/npu/assign_score_withk_npu.cpp new file mode 100644 index 0000000000..8f5db077a0 --- /dev/null +++ b/mmcv/ops/csrc/pytorch/npu/assign_score_withk_npu.cpp @@ -0,0 +1,44 @@ +#include "pytorch_npu_helper.hpp" + +using namespace NPU_NAME_SPACE; +using namespace std; + +void assign_score_withk_forward_npu(int B, int N0, int N1, int M, int K, int O, + int aggregate, const Tensor& points, + const Tensor& centers, + const Tensor& scores, + const Tensor& knn_idx, Tensor& output) { + at::Tensor points_trans = points.permute({0, 3, 1, 2}); + at::Tensor centers_trans = centers.permute({0, 3, 1, 2}); + + EXEC_NPU_CMD(aclnnAssignScoreWithk, points_trans, centers_trans, scores, knn_idx, B, N0, N1, M, K, O, aggregate, output); +} + +void assign_score_withk_forward_impl(int B, int N0, int N1, int M, int K, int O, + int aggregate, const Tensor& points, + const Tensor& centers, + const Tensor& scores, + const Tensor& knn_idx, Tensor& output); + +REGISTER_NPU_IMPL(assign_score_withk_forward_impl, assign_score_withk_forward_npu); + + +void assign_score_withk_backward_npu( + int B, int N0, int N1, int M, int K, int O, int aggregate, + const Tensor& grad_out, const Tensor& points, const Tensor& centers, + const Tensor& scores, const Tensor& knn_idx, Tensor& grad_points, + Tensor& grad_centers, Tensor& grad_scores) { + + at::Tensor grad_out_trans = grad_out.permute({0, 2, 3, 1}); + + EXEC_NPU_CMD(aclnnAssignScoreWithkGrad, grad_out_trans, points, centers, scores, knn_idx, B, N0, N1, M, K, O, aggregate, grad_scores, grad_points, grad_centers); +} + +void assign_score_withk_backward_impl( + int B, int N0, int N1, int M, int K, int O, int aggregate, + const Tensor& grad_out, const Tensor& points, const Tensor& centers, + const Tensor& scores, const Tensor& knn_idx, Tensor& grad_points, + Tensor& grad_centers, Tensor& grad_scores); + +REGISTER_NPU_IMPL(assign_score_withk_backward_impl, assign_score_withk_backward_npu); + diff --git a/mmcv/ops/csrc/pytorch/npu/border_align_npu.cpp b/mmcv/ops/csrc/pytorch/npu/border_align_npu.cpp new file mode 100644 index 0000000000..b1d0004e00 --- /dev/null +++ b/mmcv/ops/csrc/pytorch/npu/border_align_npu.cpp @@ -0,0 +1,53 @@ +#include "pytorch_npu_helper.hpp" + +using namespace NPU_NAME_SPACE; +using namespace std; + +void border_align_forward_impl(const Tensor &input, const Tensor &boxes, Tensor output, + Tensor argmax_idx, const int pool_size); + +void border_align_forward_npu(const Tensor &input, const Tensor &boxes, Tensor output, + Tensor argmax_idx, const int pool_size){ + TORCH_CHECK(input.size(0) == boxes.size(0), "The batch sizes of feature map and rois must be the same."); + TORCH_CHECK(input.size(1) % 4 == 0, "The number of channels must be divisible by 4."); + TORCH_CHECK(pool_size >= 2, "The pool size should be larger than 2."); + int32_t batch_size = input.size(0); + int32_t channels = input.size(1); + int32_t height = input.size(2); + int32_t width = input.size(3); + at::Tensor feature_map = input.permute({0, 2, 3, 1}).contiguous(); + at::Tensor rois_map = boxes.contiguous(); + at::Tensor temp_tensor = at::zeros({batch_size, height * width, pool_size + 1, channels}, input.options()); + EXEC_NPU_CMD(aclnnBorderAlign, feature_map, rois_map, pool_size, temp_tensor); + auto max_result = temp_tensor.max(-2); + at::Tensor output_ = std::get<0>(max_result).to(at::kFloat); + output_ = output_.reshape({batch_size, height * width, 4, channels / 4}).permute({0, 3, 1, 2}).contiguous(); + output.copy_(output_); + at::Tensor argmax_idx_ = std::get<1>(max_result).to(at::kInt); + argmax_idx_ = argmax_idx_.reshape({batch_size, height * width, 4, channels / 4}).permute({0, 3, 1, 2}).contiguous(); + argmax_idx.copy_(argmax_idx_); +} +REGISTER_NPU_IMPL(border_align_forward_impl, border_align_forward_npu); + + +void border_align_backward_impl(const Tensor &grad_output, const Tensor &boxes, + const Tensor &argmax_idx, Tensor grad_input, + const int pool_size); + +void border_align_backward_npu(const Tensor &grad_output, const Tensor &boxes, + const Tensor &argmax_idx, Tensor grad_input, + const int pool_size){ + TORCH_CHECK(grad_output.dim() == 4, "grad_out.dim() must be 4, but got: ", grad_output.dim()); + TORCH_CHECK(boxes.dim() == 3, "idx.dim() must be 3, but got: ", boxes.dim()); + TORCH_CHECK(argmax_idx.dim() == 4, "argmax_idx.dim() must be 4, but got: ", argmax_idx.dim()); + + int32_t batch_size = grad_output.size(0); + int32_t feat_channels = grad_output.size(1) * 4; + int32_t channels = grad_output.size(1); + int32_t box_size = boxes.size(1); + int32_t height = grad_input.size(2); + int32_t width = grad_input.size(3); + + EXEC_NPU_CMD(aclnnBorderAlignGrad, grad_output, boxes, argmax_idx, channels, box_size, height, width, pool_size, batch_size, grad_input); +} +REGISTER_NPU_IMPL(border_align_backward_impl, border_align_backward_npu); diff --git a/mmcv/ops/csrc/pytorch/npu/box_iou_quadri_npu.cpp b/mmcv/ops/csrc/pytorch/npu/box_iou_quadri_npu.cpp index 84e2c0a678..6baf44f448 100644 --- a/mmcv/ops/csrc/pytorch/npu/box_iou_quadri_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/box_iou_quadri_npu.cpp @@ -4,10 +4,11 @@ using namespace NPU_NAME_SPACE; using namespace std; void box_iou_quadri_impl(const Tensor boxes1, const Tensor boxes2, Tensor ious, - const int mode_flag, const bool aligned); + const int mode_flag, const bool aligned); void box_iou_quadri_npu(const Tensor boxes1, const Tensor boxes2, Tensor ious, - const int mode_flag, const bool aligned) { + const int mode_flag, const bool aligned) { + TORCH_CHECK(boxes1.size(1) == 8, "boxes1 must be 2D tensor (N, 8)"); TORCH_CHECK(boxes1.size(1) == 8, "boxes1 must be 2D tensor (N, 8)"); diff --git a/mmcv/ops/csrc/pytorch/npu/box_iou_rotated_npu.cpp b/mmcv/ops/csrc/pytorch/npu/box_iou_rotated_npu.cpp index 5b229a3926..d8b0bbaa67 100644 --- a/mmcv/ops/csrc/pytorch/npu/box_iou_rotated_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/box_iou_rotated_npu.cpp @@ -8,13 +8,14 @@ void box_iou_rotated_impl(const Tensor boxes1, const Tensor boxes2, Tensor ious, void box_iou_rotated_npu(const Tensor boxes1, const Tensor boxes2, Tensor ious, const int mode_flag, const bool aligned) { + TORCH_CHECK(boxes1.size(1) == 5, "boxes1 must be 2D tensor (N, 5)"); TORCH_CHECK(boxes1.size(1) == 5, "boxes1 must be 2D tensor (N, 5)"); auto trans = false; auto is_clockwise = false; EXEC_NPU_CMD(aclnnBoxesOverlapBev, boxes1, boxes2, trans, is_clockwise, - aligned, mode_flag, ious); + aligned, mode_flag, ious); return; } diff --git a/mmcv/ops/csrc/pytorch/npu/boxes_overlap_bev_npu.cpp b/mmcv/ops/csrc/pytorch/npu/boxes_overlap_bev_npu.cpp index 0d97df6ad7..6bc6273083 100644 --- a/mmcv/ops/csrc/pytorch/npu/boxes_overlap_bev_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/boxes_overlap_bev_npu.cpp @@ -10,17 +10,16 @@ void iou3d_boxes_overlap_bev_forward_impl(const int num_a, const Tensor boxes_a, void iou3d_boxes_overlap_bev_forward_npu(const int num_a, const Tensor boxes_a, const int num_b, const Tensor boxes_b, Tensor ans_overlap) { - TORCH_CHECK(boxes_a.size(1) == 7, "boxes_a must be 2D tensor (N, 7)"); - TORCH_CHECK(boxes_b.size(1) == 7, "boxes_b must be 2D tensor (N, 7)"); - auto trans = false; - auto is_clockwise = false; - auto aligned = false; - auto mode_flag = 2; - EXEC_NPU_CMD(aclnnBoxesOverlapBev, boxes_a, boxes_b, trans, is_clockwise, - aligned, mode_flag, ans_overlap); - return; + TORCH_CHECK(boxes_a.size(1) == 7, "boxes_a must be 2D tensor (N, 7)"); + TORCH_CHECK(boxes_b.size(1) == 7, "boxes_b must be 2D tensor (N, 7)"); + + auto trans = false; + auto is_clockwise = false; + auto aligned = false; + auto mode_flag = 2; + EXEC_NPU_CMD(aclnnBoxesOverlapBev, boxes_a, boxes_b, trans, is_clockwise, aligned, mode_flag, ans_overlap); + return; } -REGISTER_NPU_IMPL(iou3d_boxes_overlap_bev_forward_impl, - iou3d_boxes_overlap_bev_forward_npu); +REGISTER_NPU_IMPL(iou3d_boxes_overlap_bev_forward_impl, iou3d_boxes_overlap_bev_forward_npu); diff --git a/mmcv/ops/csrc/pytorch/npu/chamfer_distance_npu.cpp b/mmcv/ops/csrc/pytorch/npu/chamfer_distance_npu.cpp index 8b30fa15df..4f5c32dbec 100644 --- a/mmcv/ops/csrc/pytorch/npu/chamfer_distance_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/chamfer_distance_npu.cpp @@ -5,19 +5,34 @@ using namespace std; void chamfer_distance_forward_npu(Tensor XYZ1, Tensor XYZ2, Tensor dist1, Tensor dist2, Tensor idx1, Tensor idx2) { + bool is_half = XYZ1.scalar_type() == at::kHalf; at::Tensor xyz1 = at::ones_like(XYZ1); at::Tensor xyz2 = at::ones_like(XYZ2); + at::Tensor distf1 = at::ones_like(dist1); + at::Tensor distf2 = at::ones_like(dist2); xyz1 = XYZ1.transpose(1, 2).transpose(0, 1); xyz2 = XYZ2.transpose(1, 2).transpose(0, 1); + if (is_half) { + xyz1 = xyz1.to(at::kFloat); + xyz2 = xyz2.to(at::kFloat); + distf1 = dist1.to(at::kFloat); + distf2 = dist2.to(at::kFloat); + } OpCommand cmd; cmd.Name("ChamferDistance") .Input(xyz1) .Input(xyz2) - .Output(dist1) - .Output(dist2) + .Output(distf1) + .Output(distf2) .Output(idx1) .Output(idx2) .Run(); + if (is_half) { + distf1 = distf1.to(at::kHalf); + distf2 = distf2.to(at::kHalf); + } + dist1.copy_(distf1); + dist2.copy_(distf2); } void chamfer_distance_backward_npu(Tensor xyz1, Tensor xyz2, Tensor idx1, diff --git a/mmcv/ops/csrc/pytorch/npu/deform_roi_pool.cpp b/mmcv/ops/csrc/pytorch/npu/deform_roi_pool.cpp index 074e52d4f4..42de978e88 100644 --- a/mmcv/ops/csrc/pytorch/npu/deform_roi_pool.cpp +++ b/mmcv/ops/csrc/pytorch/npu/deform_roi_pool.cpp @@ -53,7 +53,7 @@ void deform_roi_pool_backward_npu(Tensor grad_output, Tensor input, Tensor rois, .Output(grad_offset) .Attr("output_size", output_size) .Attr("spatial_scale", spatial_scale) - .Attr("sample_ratio", sampling_ratio_) + .Attr("sampling_ratio", sampling_ratio_) .Attr("gamma", gamma) .Run(); } diff --git a/mmcv/ops/csrc/pytorch/npu/diff_iou_rotated_npu.cpp b/mmcv/ops/csrc/pytorch/npu/diff_iou_rotated_npu.cpp new file mode 100644 index 0000000000..7091a8ec3d --- /dev/null +++ b/mmcv/ops/csrc/pytorch/npu/diff_iou_rotated_npu.cpp @@ -0,0 +1,28 @@ +#include "pytorch_npu_helper.hpp" +using namespace NPU_NAME_SPACE; +using namespace std; + +Tensor diff_iou_rotated_sort_vertices_npu(Tensor vertices, + Tensor mask, + Tensor num_valid) { + TORCH_CHECK(vertices.dim() == 4, "vertices must be a 4D Tensor, but got: ", vertices.dim()); + TORCH_CHECK(mask.dim() == 3, "mask must be a 3D Tensor, but got: ", mask.dim()); + TORCH_CHECK(num_valid.dim() == 2, "num_valid must be a 2D Tensor, but got: ", num_valid.dim()); + + uint32_t B = vertices.size(0); + uint32_t N = vertices.size(1); + + at::Tensor sortedIdx = at::empty({B, N, 9}, num_valid.options()); + at::Tensor mask_fp = mask.to(at::kFloat); + + EXEC_NPU_CMD(aclnnDiffIouRotatedSortVertices, vertices, mask_fp, num_valid, sortedIdx); + + return sortedIdx; +} + +Tensor diff_iou_rotated_sort_vertices_forward_impl(Tensor vertices, + Tensor mask, + Tensor num_valid); + +REGISTER_NPU_IMPL(diff_iou_rotated_sort_vertices_forward_impl, + diff_iou_rotated_sort_vertices_npu); diff --git a/mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp b/mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp index 5030fed0e7..ef7df560c9 100644 --- a/mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp @@ -4,6 +4,21 @@ using namespace std; void sigmoid_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight, Tensor output, float gamma, float alpha) { + at::Tensor input_y = input; + at::Tensor output_y = output; + bool is_half = input.scalar_type() == at::kHalf; + if (is_half) { + input_y = input.to(at::kFloat); + output_y = output.to(at::kFloat); + } + int64_t weight_size = weight.size(0); + at::Tensor weight_y = at::ones_like(input_y); + if (weight_size > 0) { + weight_y = at::broadcast_to(weight, input.sizes()); + if (is_half) { + weight_y = weight_y.to(at::kFloat); + } + } int64_t n_class = input.size(1); at::Tensor target_y = at::ones_like(input); if (n_class == 1) { @@ -12,24 +27,26 @@ void sigmoid_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight, target_y = at::add(target_y, 1.0); } else { target_y = at::one_hot(target, n_class); + weight_y = at::mul(weight_y, target_y); + weight_y = at::sum(weight_y, 1, true); + weight_y = at::broadcast_to(weight_y, input.sizes()); } target_y = target_y.to(at::kInt); - int64_t weight_size = weight.size(0); - at::Tensor weight_y = at::ones_like(input); - if (weight_size > 0) { - weight_y = at::broadcast_to(weight, input.sizes()); - } OpCommand cmd; string reduction = "none"; cmd.Name("SigmoidFocalLoss") - .Input(input) + .Input(input_y) .Input(target_y) .Input(weight_y) - .Output(output) + .Output(output_y) .Attr("gamma", gamma) .Attr("alpha", alpha) .Attr("reduction", reduction) .Run(); + if (is_half) { + output_y = output_y.to(at::kHalf); + } + output.copy_(output_y); } void sigmoid_focal_loss_forward_impl(Tensor input, Tensor target, Tensor weight, @@ -38,34 +55,51 @@ void sigmoid_focal_loss_forward_impl(Tensor input, Tensor target, Tensor weight, void sigmoid_focal_loss_backward_npu(Tensor input, Tensor target, Tensor weight, Tensor grad_input, float gamma, float alpha) { + at::Tensor input_y = input; + at::Tensor grad_input_y = grad_input; + bool is_half = input.scalar_type() == at::kHalf; + if (is_half) { + input_y = input.to(at::kFloat); + grad_input_y = grad_input.to(at::kFloat); + } + int64_t weight_size = weight.size(0); + at::Tensor weight_y = at::ones_like(input_y); + if (weight_size > 0) { + weight_y = at::broadcast_to(weight, input.sizes()); + if (is_half) { + weight_y = weight_y.to(at::kFloat); + } + } int64_t n_class = input.size(1); at::Tensor target_y = at::ones_like(input); if (n_class == 1) { target_y = at::reshape(target, input.sizes()); } else { target_y = at::one_hot(target, n_class); + weight_y = at::mul(weight_y, target_y); + weight_y = at::sum(weight_y, 1, true); + weight_y = at::broadcast_to(weight_y, input.sizes()); target_y = at::mul(target_y, -1.0); target_y = at::add(target_y, 1.0); } target_y = target_y.to(at::kInt); at::Tensor grad_up = at::ones_like(input); - int64_t weight_size = weight.size(0); - at::Tensor weight_y = at::ones_like(input); - if (weight_size > 0) { - weight_y = at::broadcast_to(weight, input.sizes()); - } OpCommand cmd; string reduction = "none"; cmd.Name("SigmoidFocalLossGrad") - .Input(input) + .Input(input_y) .Input(target_y) .Input(grad_up) .Input(weight_y) - .Output(grad_input) + .Output(grad_input_y) .Attr("gamma", gamma) .Attr("alpha", alpha) .Attr("reduction", reduction) .Run(); + if (is_half) { + grad_input_y = grad_input_y.to(at::kHalf); + } + grad_input.copy_(grad_input_y); } void sigmoid_focal_loss_backward_impl(Tensor input, Tensor target, @@ -74,19 +108,30 @@ void sigmoid_focal_loss_backward_impl(Tensor input, Tensor target, void softmax_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight, Tensor output, float gamma, float alpha) { + at::Tensor input_y = input; + bool is_half = input.scalar_type() == at::kHalf; + if (is_half) { + input_y = input.to(at::kFloat); + } int64_t n_class = input.size(1); at::Tensor target_y = at::one_hot(target, n_class); target_y = target_y.to(at::kInt); int64_t weight_size = weight.size(0); - at::Tensor weight_y = at::ones_like(input); + at::Tensor weight_y = at::ones_like(input_y); if (weight_size > 0) { weight_y = at::broadcast_to(weight, input.sizes()); + if (is_half) { + weight_y = weight_y.to(at::kFloat); + } + weight_y = at::mul(weight_y, target_y); + weight_y = at::sum(weight_y, 1, true); + weight_y = at::broadcast_to(weight_y, input.sizes()); } - at::Tensor op_output = at::ones_like(input); + at::Tensor op_output = at::ones_like(input_y); OpCommand cmd; string reduction = "none"; cmd.Name("SoftmaxFocalLoss") - .Input(input) + .Input(input_y) .Input(target_y) .Input(weight_y) .Output(op_output) @@ -94,6 +139,9 @@ void softmax_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight, .Attr("alpha", alpha) .Attr("reduction", reduction) .Run(); + if (is_half) { + op_output = op_output.to(at::kHalf); + } int64_t n_batch = input.size(0); c10::SmallVector offsets = {0, 0}; c10::SmallVector sizes = {n_batch, 1}; @@ -124,27 +172,45 @@ void softmax_focal_loss_forward_impl(Tensor input, Tensor target, Tensor weight, void softmax_focal_loss_backward_npu(Tensor input, Tensor target, Tensor weight, Tensor buff, Tensor grad_input, float gamma, float alpha) { + at::Tensor input_y = input; + at::Tensor grad_input_y = grad_input; + bool is_half = input.scalar_type() == at::kHalf; + if (is_half) { + input_y = input.to(at::kFloat); + grad_input_y = grad_input.to(at::kFloat); + } int64_t n_class = input.size(1); at::Tensor target_y = at::one_hot(target, n_class); target_y = target_y.to(at::kInt); at::Tensor grad_up = at::ones_like(input); int64_t weight_size = weight.size(0); - at::Tensor weight_y = at::ones_like(input); + at::Tensor weight_y = at::ones_like(input_y); if (weight_size > 0) { weight_y = at::broadcast_to(weight, input.sizes()); + if (is_half) { + weight_y = weight_y.to(at::kFloat); + } + weight_y = at::mul(weight_y, target_y); + weight_y = at::sum(weight_y, 1, true); + weight_y = at::broadcast_to(weight_y, input.sizes()); } + grad_input_y = grad_input_y.fill_(0); OpCommand cmd; string reduction = "none"; cmd.Name("SoftmaxFocalLossGrad") - .Input(input) + .Input(input_y) .Input(target_y) .Input(grad_up) .Input(weight_y) - .Output(grad_input) + .Output(grad_input_y) .Attr("gamma", gamma) .Attr("alpha", alpha) .Attr("reduction", reduction) .Run(); + if (is_half) { + grad_input_y = grad_input_y.to(at::kHalf); + } + grad_input.copy_(grad_input_y); } void softmax_focal_loss_backward_impl(Tensor input, Tensor target, diff --git a/mmcv/ops/csrc/pytorch/npu/knn_npu.cpp b/mmcv/ops/csrc/pytorch/npu/knn_npu.cpp index f25f9cf623..d63977f508 100644 --- a/mmcv/ops/csrc/pytorch/npu/knn_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/knn_npu.cpp @@ -8,11 +8,11 @@ using namespace std; void knn_forward_npu(int b, int n, int m, int nsample, const Tensor xyz, const Tensor new_xyz, Tensor idx, Tensor dist2) { // transpose known from [B, N, 3] to [B, 3, N] - at::Tensor source = xyz.transpose(1, 2).contiguous(); + at::Tensor source = xyz.transpose(2, 1).contiguous(); at::Tensor target = new_xyz.contiguous(); bool is_from_knn = true; - EXEC_NPU_CMD(aclnnKnn, source, target, nsample, is_from_knn, idx, dist2); + EXEC_NPU_CMD(aclnnKnn, source, target, is_from_knn, nsample, dist2, idx); } void knn_forward_impl(int b, int n, int m, int nsample, const Tensor xyz, diff --git a/mmcv/ops/csrc/pytorch/npu/nms3d_normal_npu.cpp b/mmcv/ops/csrc/pytorch/npu/nms3d_normal_npu.cpp index 5d812fe047..6d2588a01d 100644 --- a/mmcv/ops/csrc/pytorch/npu/nms3d_normal_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/nms3d_normal_npu.cpp @@ -3,23 +3,24 @@ using namespace NPU_NAME_SPACE; void iou3d_nms3d_normal_forward_npu(const Tensor boxes, Tensor &keep, - Tensor &keep_num, - float nms_overlap_thresh) { + Tensor &num_out, float nms_overlap_thresh) { int32_t box_num = boxes.size(0); int32_t data_align = 16; int32_t mask_num = ((box_num - 1) / data_align + 1) * data_align; + const double iou_threshold = nms_overlap_thresh; at::Tensor mask = at::empty({box_num, mask_num}, boxes.options().dtype(at::kShort)); - EXEC_NPU_CMD(aclnnNms3dNormal, boxes, nms_overlap_thresh, mask); + EXEC_NPU_CMD(aclnnNms3dNormal, boxes, iou_threshold, mask); - keep = at::zeros({box_num}, mask.options()); - keep_num = at::zeros(1, mask.options()); - EXEC_NPU_CMD(aclnnGatherNms3dMask, mask, keep, keep_num); + Tensor keep_t = at::zeros({box_num}, mask.options()); + Tensor num_out_t = at::zeros(1, mask.options()); + EXEC_NPU_CMD(aclnnGatherNms3dMask, mask, keep_t, num_out_t); + num_out.fill_(num_out_t.item().toLong()); + keep.copy_(keep_t); } void iou3d_nms3d_normal_forward_impl(const Tensor boxes, Tensor &keep, - Tensor &keep_num, - float nms_overlap_thresh); + Tensor &num_out, float nms_overlap_thresh); REGISTER_NPU_IMPL(iou3d_nms3d_normal_forward_impl, iou3d_nms3d_normal_forward_npu); diff --git a/mmcv/ops/csrc/pytorch/npu/nms3d_npu.cpp b/mmcv/ops/csrc/pytorch/npu/nms3d_npu.cpp index 13fe6db860..a143ed07b5 100644 --- a/mmcv/ops/csrc/pytorch/npu/nms3d_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/nms3d_npu.cpp @@ -5,22 +5,26 @@ using namespace std; constexpr int32_t BOX_DIM = 7; -void iou3d_nms3d_forward_npu(const Tensor boxes, Tensor &keep, Tensor &keep_num, +void iou3d_nms3d_forward_npu(const Tensor boxes, Tensor &keep, Tensor &num_out, float nms_overlap_thresh) { TORCH_CHECK((boxes.sizes()[1] == BOX_DIM), "Input boxes shape should be (N, 7)"); int32_t box_num = boxes.size(0); int32_t data_align = 16; int32_t mask_num = ((box_num - 1) / data_align + 1) * data_align; + const double iou_threshold = nms_overlap_thresh; at::Tensor mask = at::empty({box_num, mask_num}, boxes.options().dtype(at::kShort)); - EXEC_NPU_CMD(aclnnNms3d, boxes, nms_overlap_thresh, mask); - keep = at::zeros({box_num}, mask.options()); - keep_num = at::zeros(1, mask.options()); - EXEC_NPU_CMD(aclnnGatherNms3dMask, mask, keep, keep_num); + EXEC_NPU_CMD(aclnnNms3d, boxes, iou_threshold, mask); + + Tensor keep_t = at::zeros({box_num}, mask.options()); + Tensor num_out_t = at::zeros(1, mask.options()); + EXEC_NPU_CMD(aclnnGatherNms3dMask, mask, keep_t, num_out_t); + num_out.fill_(num_out_t.item().toLong()); + keep.copy_(keep_t); } -void iou3d_nms3d_forward_impl(const Tensor boxes, Tensor &keep, - Tensor &keep_num, float nms_overlap_thresh); +void iou3d_nms3d_forward_impl(const Tensor boxes, Tensor &keep, Tensor &num_out, + float nms_overlap_thresh); REGISTER_NPU_IMPL(iou3d_nms3d_forward_impl, iou3d_nms3d_forward_npu); diff --git a/mmcv/ops/csrc/pytorch/npu/points_in_box_npu_all.cpp b/mmcv/ops/csrc/pytorch/npu/points_in_box_npu_all.cpp index e3cc9284cc..acf70d81c2 100644 --- a/mmcv/ops/csrc/pytorch/npu/points_in_box_npu_all.cpp +++ b/mmcv/ops/csrc/pytorch/npu/points_in_box_npu_all.cpp @@ -4,17 +4,17 @@ using namespace NPU_NAME_SPACE; using namespace std; void points_in_boxes_all_forward_impl_npu(int batch_size, int boxes_num, - int pts_num, const Tensor boxes, - const Tensor pts, - Tensor box_idx_of_points) { - c10::SmallVector output_size = {pts.size(0), pts.size(1), - boxes.size(1)}; - auto boxes_trans = boxes.transpose(1, 2).contiguous(); - EXEC_NPU_CMD(aclnnPointsInBoxAll, boxes_trans, pts, box_idx_of_points); + int pts_num, const Tensor boxes, + const Tensor pts, + Tensor box_idx_of_points) { + c10::SmallVector output_size = {pts.size(0), pts.size(1), boxes.size(1)}; + auto boxes_trans = boxes.transpose(1, 2).contiguous(); + EXEC_NPU_CMD(aclnnPointsInBoxAll, boxes_trans, pts, box_idx_of_points); } + void points_in_boxes_all_forward_impl(int batch_size, int boxes_num, - int pts_num, const Tensor boxes, - const Tensor pts, - Tensor box_idx_of_points); + int pts_num, const Tensor boxes, + const Tensor pts, + Tensor box_idx_of_points); REGISTER_NPU_IMPL(points_in_boxes_all_forward_impl, points_in_boxes_all_forward_impl_npu); diff --git a/mmcv/ops/csrc/pytorch/npu/roi_align_npu.cpp b/mmcv/ops/csrc/pytorch/npu/roi_align_npu.cpp index 2255d302e3..9d12f08906 100644 --- a/mmcv/ops/csrc/pytorch/npu/roi_align_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/roi_align_npu.cpp @@ -8,6 +8,8 @@ void roi_align_forward_npu(Tensor input, Tensor rois, Tensor output, Tensor argmax_y, Tensor argmax_x, int aligned_height, int aligned_width, float spatial_scale, int sampling_ratio, int pool_mode, bool aligned) { + TORCH_CHECK(input.scalar_type() == at::kFloat, + "input should be a float tensor"); int64_t roi_end_mode = 2; if (!aligned) { LOG(WARNING) << "The [aligned] attr in roi_align op is false"; @@ -34,6 +36,8 @@ void roi_align_backward_npu(Tensor grad_output, Tensor rois, Tensor argmax_y, int aligned_height, int aligned_width, float spatial_scale, int sampling_ratio, int pool_mode, bool aligned) { + TORCH_CHECK(grad_output.scalar_type() == at::kFloat, + "input should be a float tensor"); int64_t aligned_height_64 = aligned_height; int64_t aligned_width_64 = aligned_width; int64_t sampling_ratio_64 = sampling_ratio; diff --git a/mmcv/ops/csrc/pytorch/npu/roi_align_rotated_v2_npu.cpp b/mmcv/ops/csrc/pytorch/npu/roi_align_rotated_v2_npu.cpp index f9fac97397..64248ada45 100644 --- a/mmcv/ops/csrc/pytorch/npu/roi_align_rotated_v2_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/roi_align_rotated_v2_npu.cpp @@ -3,42 +3,61 @@ using namespace NPU_NAME_SPACE; using namespace std; -void roi_align_rotated_v2_forward_npu(const Tensor input, Tensor rois_map, - Tensor output, double spatial_scale, - int32_t sampling_ratio, - int32_t pooled_height, - int32_t pooled_width, bool aligned, - bool clockwise) { - at::Tensor feature_map = input.permute({0, 2, 3, 1}).contiguous(); +void roi_align_rotated_v2_forward_npu(const Tensor x, Tensor rois_map, + Tensor y, + int32_t pooled_h, + int32_t pooled_w, + double spatial_scale, + int32_t sampling_ratio, + bool aligned, + bool clockwise) { + at::Tensor feature_map = x.permute({0, 2, 3, 1}).contiguous(); at::Tensor rois = rois_map.permute({1, 0}).contiguous(); - EXEC_NPU_CMD(aclnnRoiAlignRotatedV2, feature_map, rois, spatial_scale, - sampling_ratio, pooled_height, pooled_width, aligned, clockwise, - output); + at_npu::native::OpCommand cmd; + cmd.Name("RoiAlignRotated") + .Input(feature_map) + .Input(rois) + .Output(y) + .Attr("pooled_h", static_cast(pooled_h)) + .Attr("pooled_w", static_cast(pooled_w)) + .Attr("spatial_scale", static_cast(spatial_scale)) + .Attr("sampling_ratio", static_cast(sampling_ratio)) + .Attr("aligned", aligned) + .Attr("clockwise", clockwise) + .Run(); } -void roi_align_rotated_v2_forward_impl(const Tensor input, Tensor rois, - Tensor output, double spatial_scale, - int32_t sampling_ratio, - int32_t pooled_height, - int32_t pooled_width, bool aligned, - bool clockwise); +void roi_align_rotated_v2_forward_impl(const Tensor x, Tensor rois, + Tensor y, + int32_t pooled_h, + int32_t pooled_w, + double spatial_scale, + int32_t sampling_ratio, + bool aligned, + bool clockwise); -REGISTER_NPU_IMPL(roi_align_rotated_v2_forward_impl, - roi_align_rotated_v2_forward_npu); +REGISTER_NPU_IMPL(roi_align_rotated_v2_forward_impl, roi_align_rotated_v2_forward_npu); -void roi_align_rotated_v2_backward_npu( - const Tensor input, Tensor rois, Tensor grad_output, Tensor grad_input, - int32_t pooled_height, int32_t pooled_width, double spatial_scale, - int32_t sampling_ratio, bool aligned, bool clockwise) { +void roi_align_rotated_v2_backward_npu(const Tensor input, Tensor rois, + Tensor grad_output, Tensor grad_input, + int32_t pooled_height, + int32_t pooled_width, + double spatial_scale, + int32_t sampling_ratio, + bool aligned, + bool clockwise) { EXEC_NPU_CMD(aclnnRoiAlignRotatedGradV2, input, rois, grad_output, - pooled_height, pooled_width, spatial_scale, sampling_ratio, - aligned, clockwise, grad_input); + pooled_height, pooled_width, spatial_scale, sampling_ratio, aligned, clockwise, + grad_input); } -void roi_align_rotated_v2_backward_impl( - const Tensor input, Tensor rois, Tensor grad_output, Tensor grad_input, - int32_t pooled_height, int32_t pooled_width, double spatial_scale, - int32_t sampling_ratio, bool aligned, bool clockwise); +void roi_align_rotated_v2_backward_impl(const Tensor input, Tensor rois, + Tensor grad_output, Tensor grad_input, + int32_t pooled_height, + int32_t pooled_width, + double spatial_scale, + int32_t sampling_ratio, + bool aligned, + bool clockwise); -REGISTER_NPU_IMPL(roi_align_rotated_v2_backward_impl, - roi_align_rotated_v2_backward_npu); +REGISTER_NPU_IMPL(roi_align_rotated_v2_backward_impl, roi_align_rotated_v2_backward_npu); diff --git a/mmcv/ops/csrc/pytorch/npu/roi_pool_npu.cpp b/mmcv/ops/csrc/pytorch/npu/roi_pool_npu.cpp index c7a11e8c6d..b7015439b9 100644 --- a/mmcv/ops/csrc/pytorch/npu/roi_pool_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/roi_pool_npu.cpp @@ -50,23 +50,29 @@ void roi_pool_backward_npu(Tensor grad_output, Tensor rois, Tensor argmax, int64_t pooled_height_64 = pooled_height; int64_t pooled_width_64 = pooled_width; int64_t pooled_channel = 1; + at::Tensor argmax_trans = argmax.transpose(1, 2).transpose(2, 3); + at::Tensor grad_output_trans = grad_output.transpose(1, 2).transpose(2, 3); at::Tensor roi_actual_num = at::empty_like(rois, rois.options().dtype(at::kInt)); - at::Tensor x = at::ones_like(grad_input); + at::Tensor x = at::ones_like(grad_input).transpose(1, 2).transpose(2, 3); + at::Tensor y = at::zeros_like(x); OpCommand cmd; cmd.Name("RoiPoolingGradWithArgMax") - .Input(grad_output) + .Input(grad_output_trans) .Input(x) .Input(rois) .Input(roi_actual_num) - .Input(argmax) - .Output(grad_input) + .Input(argmax_trans) + .Output(y) .Attr("pooled_h", pooled_height_64) .Attr("pooled_w", pooled_width_64) .Attr("spatial_scale_h", spatial_scale) .Attr("spatial_scale_w", spatial_scale) .Attr("pool_channel", pooled_channel) .Run(); + at::Tensor result = y.transpose(2, 3).transpose(1, 2); + at::Tensor res = result.contiguous(); + grad_input.copy_(res); } void roi_pool_forward_impl(Tensor input, Tensor rois, Tensor output, diff --git a/mmcv/ops/csrc/pytorch/npu/roiaware_pool3d.cpp b/mmcv/ops/csrc/pytorch/npu/roiaware_pool3d.cpp new file mode 100644 index 0000000000..50706df867 --- /dev/null +++ b/mmcv/ops/csrc/pytorch/npu/roiaware_pool3d.cpp @@ -0,0 +1,86 @@ +#include "pytorch_npu_helper.hpp" +using namespace NPU_NAME_SPACE; +using namespace std; + +void roiaware_pool3d_forward_npu(int boxes_num, int pts_num, int channels, + int max_pts_each_voxel, int out_x, int out_y, + int out_z, const Tensor rois, const Tensor pts, + const Tensor pts_feature, Tensor argmax, + Tensor pts_idx_of_voxels, + Tensor pooled_features, int pool_method) { + at::Tensor rois_cast = rois; + at::Tensor pts_cast = pts; + at::Tensor pts_feature_cast = pts_feature; + at::Tensor pooled_features_cast = pooled_features; + + auto dtype = rois.dtype(); + if (dtype == at::kHalf) { + rois_cast = rois_cast.to(at::kFloat); + pts_cast = pts_cast.to(at::kFloat); + pts_feature_cast = pts_feature_cast.to(at::kFloat); + pooled_features_cast = pooled_features_cast.to(at::kFloat); + } + + EXEC_NPU_CMD(aclnnRoiawarePool3d, rois_cast, pts_cast, pts_feature_cast, + pool_method, max_pts_each_voxel, out_x, out_y, out_z, argmax, + pts_idx_of_voxels, pooled_features_cast); + + if (dtype == at::kHalf) { + pooled_features_cast = pooled_features_cast.to(at::kHalf); + } + + pooled_features.copy_(pooled_features_cast); +} + +void roiaware_pool3d_backward_npu(int boxes_num, int out_x, int out_y, + int out_z, int channels, + int max_pts_each_voxel, + const Tensor pts_idx_of_voxels, + const Tensor argmax, const Tensor grad_out, + Tensor grad_in, int pool_method) +{ + int32_t npoints = grad_in.size(0); + + auto dtype = grad_out.dtype(); + at::Tensor grad_out_cast = grad_out; + at::Tensor grad_in_cast = grad_in; + + if (dtype == at::kHalf) { + grad_out_cast = grad_out.to(at::kFloat); + grad_in_cast = grad_in_cast.to(at::kFloat); + } + + if (pool_method == 0) { + // maxpool3d + EXEC_NPU_CMD(aclnnRoiawareMaxpool3dGrad, argmax, grad_out_cast, boxes_num, + out_x, out_y, out_z, channels, npoints, grad_in_cast); + } else if (pool_method == 1) { + // avgpool3d + EXEC_NPU_CMD(aclnnRoiawareAvgpool3dGrad, pts_idx_of_voxels, grad_out_cast, + boxes_num, out_x, out_y, out_z, channels, npoints, + max_pts_each_voxel, grad_in_cast); + } + + if (dtype == at::kHalf) { + grad_in_cast = grad_in_cast.to(at::kHalf); + } + + grad_in.copy_(grad_in_cast); +} + +void roiaware_pool3d_forward_impl(int boxes_num, int pts_num, int channels, + int max_pts_each_voxel, int out_x, int out_y, + int out_z, const Tensor rois, + const Tensor pts, const Tensor pts_feature, + Tensor argmax, Tensor pts_idx_of_voxels, + Tensor pooled_features, int pool_method); + +void roiaware_pool3d_backward_impl(int boxes_num, int out_x, int out_y, + int out_z, int channels, + int max_pts_each_voxel, + const Tensor pts_idx_of_voxels, + const Tensor argmax, const Tensor grad_out, + Tensor grad_in, int pool_method); + +REGISTER_NPU_IMPL(roiaware_pool3d_forward_impl, roiaware_pool3d_forward_npu); +REGISTER_NPU_IMPL(roiaware_pool3d_backward_impl, roiaware_pool3d_backward_npu); \ No newline at end of file diff --git a/mmcv/ops/csrc/pytorch/npu/stack_ball_query_npu.cpp b/mmcv/ops/csrc/pytorch/npu/stack_ball_query_npu.cpp index cd8c3ad8c9..92627df6e3 100644 --- a/mmcv/ops/csrc/pytorch/npu/stack_ball_query_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/stack_ball_query_npu.cpp @@ -8,9 +8,10 @@ void stack_ball_query_forward_npu(float max_radius, int nsample, const Tensor new_xyz_batch_cnt, const Tensor xyz, const Tensor xyz_batch_cnt, Tensor idx) { - at::Tensor xyz_transpose = xyz.transpose(0, 1).contiguous(); + at::Tensor xyz_transpose = xyz.transpose(0, 1).contiguous().to(at::kFloat); + at::Tensor new_xyz_fp32 = new_xyz.to(at::kFloat); double max_radius_double = double(max_radius); - EXEC_NPU_CMD(aclnnStackBallQuery, xyz_transpose, new_xyz, xyz_batch_cnt, + EXEC_NPU_CMD(aclnnStackBallQuery, xyz_transpose, new_xyz_fp32, xyz_batch_cnt, new_xyz_batch_cnt, max_radius_double, nsample, idx); } diff --git a/mmcv/ops/csrc/pytorch/npu/three_interpolate_npu.cpp b/mmcv/ops/csrc/pytorch/npu/three_interpolate_npu.cpp index f908755478..42d346f7d2 100644 --- a/mmcv/ops/csrc/pytorch/npu/three_interpolate_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/three_interpolate_npu.cpp @@ -12,17 +12,21 @@ void three_interpolate_forward_npu(int b, int c, int m, int n, TORCH_CHECK((originDtype == at::kFloat || originDtype == at::kHalf), "three_interpolate_forward ascend only support fp32 and fp16."); - auto point_c_trans = points.transpose(1, 2); - + auto point_c_trans = points.transpose(1, 2).to(at::kFloat); + auto weight_cast = weight.to(at::kFloat); + auto out_cast = out.to(at::kFloat); OpCommand cmd; cmd.Name("ThreeInterpolate") .Input(point_c_trans) .Input(idx) - .Input(weight) - .Output(out) + .Input(weight_cast) + .Output(out_cast) .Run(); - auto output = out.view({b, n, c}).transpose(1, 2); + if (originDtype == at::kHalf) { + out_cast = out_cast.to(at::kHalf); + } + auto output = out_cast.view({b, n, c}).transpose(1, 2); auto res = output.contiguous(); out.copy_(res); } @@ -34,12 +38,17 @@ void three_interpolate_backward_npu(int b, int c, int n, int m, TORCH_CHECK((originDtype == at::kFloat || originDtype == at::kHalf), "three_interpolate_backward ascend only support fp32 and fp16."); - auto grad_x = at::unsqueeze(grad_out, 3); - auto grad_y = at::unsqueeze(grad_points, 3); - - EXEC_NPU_CMD(aclnnThreeInterpolateBackward, grad_x, idx, weight, m, grad_y); + auto grad_x = at::unsqueeze(grad_out, 3).to(at::kFloat); + auto grad_y = at::unsqueeze(grad_points, 3).to(at::kFloat); + auto weight_cast = weight.to(at::kFloat); + EXEC_NPU_CMD(aclnnThreeInterpolateBackward, grad_x, idx, weight_cast, m, + grad_y); - auto output = at::squeeze(grad_y, 3); + auto grad_y_cast = grad_y; + if (originDtype == at::kHalf) { + grad_y_cast = grad_y.to(at::kHalf); + } + auto output = at::squeeze(grad_y_cast, 3); auto res = output.contiguous(); grad_points.copy_(res); } diff --git a/mmcv/ops/csrc/pytorch/npu/three_nn_npu.cpp b/mmcv/ops/csrc/pytorch/npu/three_nn_npu.cpp index 9766816f6c..0a6a10bc81 100644 --- a/mmcv/ops/csrc/pytorch/npu/three_nn_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/three_nn_npu.cpp @@ -7,21 +7,12 @@ using namespace std; void three_nn_forward_npu(int b, int n, int m, const Tensor unknown, const Tensor known, Tensor dist2, Tensor idx) { - // transpose known [B, N, 3] -> [B, 3, N] - at::Tensor source = known.transpose(1, 2).contiguous(); + at::Tensor source = known.contiguous(); at::Tensor target = unknown.contiguous(); - auto originDtype = source.scalar_type(); - if (originDtype == at::kHalf) { - source = source.to(at::kFloat); - target = target.to(at::kFloat); - } bool is_from_knn = false; - uint32_t nsample = 3; - EXEC_NPU_CMD(aclnnKnn, source, target, nsample, is_from_knn, idx, dist2); - if (originDtype == at::kHalf) { - dist2 = dist2.to(at::kHalf); - } + int nsample = 3; + EXEC_NPU_CMD(aclnnKnn, source, target, is_from_knn, nsample, dist2, idx); } void three_nn_forward_impl(int b, int n, int m, const Tensor unknown, diff --git a/mmcv/ops/csrc/pytorch/npu/voxelization_npu.cpp b/mmcv/ops/csrc/pytorch/npu/voxelization_npu.cpp index ffd9b4c43b..2abe7c8f95 100644 --- a/mmcv/ops/csrc/pytorch/npu/voxelization_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/voxelization_npu.cpp @@ -11,6 +11,11 @@ int hard_voxelize_forward_impl(const at::Tensor &points, at::Tensor &voxels, const int max_points, const int max_voxels, const int NDim = 3); +void dynamic_voxelize_forward_impl(const at::Tensor &points, at::Tensor &coors, + const std::vector voxel_size, + const std::vector coors_range, + const int NDim = 3); + int hard_voxelize_forward_npu(const at::Tensor &points, at::Tensor &voxels, at::Tensor &coors, at::Tensor &num_points_per_voxel, @@ -53,4 +58,34 @@ int hard_voxelize_forward_npu(const at::Tensor &points, at::Tensor &voxels, return voxel_num_int; } +void dynamic_voxelize_forward_npu(const at::Tensor &points, at::Tensor &coors, + const std::vector voxel_size, + const std::vector coors_range, + const int NDim = 3) { + uint32_t ptsNum = points.size(0); + uint32_t ptsFeature = points.size(1); + at::Tensor ptsTrans = at::transpose(points, 0, 1); + double coors_min_x = coors_range[0]; + double coors_min_y = coors_range[1]; + double coors_min_z = coors_range[2]; + double coors_max_x = coors_range[3]; + double coors_max_y = coors_range[4]; + double coors_max_z = coors_range[5]; + double voxel_x = voxel_size[0]; + double voxel_y = voxel_size[1]; + double voxel_z = voxel_size[2]; + int grid_x = std::round((coors_max_x - coors_min_x) / voxel_x); + int grid_y = std::round((coors_max_y - coors_min_y) / voxel_y); + int grid_z = std::round((coors_max_z - coors_min_z) / voxel_z); + + at::Tensor tmp_coors = + at::zeros({3, ptsNum}, points.options().dtype(at::kInt)); + EXEC_NPU_CMD(aclnnDynamicVoxelization, ptsTrans, coors_min_x, coors_min_y, + coors_min_z, voxel_x, voxel_y, voxel_z, grid_x, grid_y, grid_z, + tmp_coors); + tmp_coors.transpose_(0, 1); + coors.copy_(tmp_coors); +} + REGISTER_NPU_IMPL(hard_voxelize_forward_impl, hard_voxelize_forward_npu); +REGISTER_NPU_IMPL(dynamic_voxelize_forward_impl, dynamic_voxelize_forward_npu); diff --git a/mmcv/ops/csrc/pytorch/pybind.cpp b/mmcv/ops/csrc/pytorch/pybind.cpp index 864630f838..43e9b270f2 100644 --- a/mmcv/ops/csrc/pytorch/pybind.cpp +++ b/mmcv/ops/csrc/pytorch/pybind.cpp @@ -209,8 +209,8 @@ void roi_align_backward(Tensor grad_output, Tensor rois, Tensor argmax_y, int sampling_ratio, int pool_mode, bool aligned); void roi_align_rotated_v2_forward(Tensor input, Tensor rois, Tensor output, + int pooled_h, int pooled_w, double spatial_scale, int sampling_ratio, - int aligned_height, int aligned_width, bool aligned, bool clockwise); void roi_align_rotated_v2_backward(Tensor input, Tensor rois, @@ -342,6 +342,15 @@ void roi_align_rotated_backward(Tensor grad_output, Tensor rois, int sampling_ratio, bool aligned, bool clockwise); +void roi_align_rotated_v2_forward(Tensor input, Tensor rois, Tensor output, + int pooled_h, int pooled_w, + double spatial_scale, int sampling_ratio, + bool aligned, bool clockwise); + +void roi_align_rotated_v2_backward(Tensor input, Tensor rois, Tensor grad_output, Tensor grad_input, + int pooled_height, int pooled_width, double spatial_scale, + int sampling_ratio, bool aligned, bool clockwise); + std::vector dynamic_point_to_voxel_forward( const torch::Tensor &feats, const torch::Tensor &coors, const std::string &reduce_type); @@ -805,14 +814,14 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("sampling_ratio"), py::arg("aligned"), py::arg("clockwise")); m.def("roi_align_rotated_v2_forward", &roi_align_rotated_v2_forward, "roi_align_rotated_v2_forward", py::arg("input"), py::arg("rois"), - py::arg("output"), py::arg("spatial_scale"), py::arg("sampling_ratio"), - py::arg("pooled_height"), py::arg("pooled_width"), py::arg("aligned"), - py::arg("clockwise")); + py::arg("output"), py::arg("pooled_h"), py::arg("pooled_w"), + py::arg("spatial_scale"), py::arg("sampling_ratio"), + py::arg("aligned"), py::arg("clockwise")); m.def("roi_align_rotated_v2_backward", &roi_align_rotated_v2_backward, "roi_align_rotated_v2_backward", py::arg("input"), py::arg("rois"), py::arg("grad_output"), py::arg("grad_input"), py::arg("pooled_height"), - py::arg("pooled_width"), py::arg("spatial_scale"), - py::arg("sampling_ratio"), py::arg("aligned"), py::arg("clockwise")); + py::arg("pooled_width"), py::arg("spatial_scale"), py::arg("sampling_ratio"), + py::arg("aligned"), py::arg("clockwise")); m.def("dynamic_point_to_voxel_forward", &dynamic_point_to_voxel_forward, "dynamic_point_to_voxel_forward", py::arg("feats"), py::arg("coors"), py::arg("reduce_type")); diff --git a/mmcv/ops/csrc/pytorch/roi_align_rotated_v2.cpp b/mmcv/ops/csrc/pytorch/roi_align_rotated_v2.cpp index 9b0a623530..ec2a17bf77 100644 --- a/mmcv/ops/csrc/pytorch/roi_align_rotated_v2.cpp +++ b/mmcv/ops/csrc/pytorch/roi_align_rotated_v2.cpp @@ -2,41 +2,36 @@ #include "pytorch_cpp_helper.hpp" #include "pytorch_device_registry.hpp" -void roi_align_rotated_v2_forward_impl(Tensor input, Tensor rois, Tensor output, - double spatial_scale, int sampling_ratio, - int pooled_height, int pooled_width, - bool aligned, bool clockwise) { - DISPATCH_DEVICE_IMPL(roi_align_rotated_v2_forward_impl, input, rois, output, - spatial_scale, sampling_ratio, pooled_height, - pooled_width, aligned, clockwise); +void roi_align_rotated_v2_forward_impl(Tensor x, Tensor rois, Tensor y, + int pooled_h, int pooled_w, + double spatial_scale, int sampling_ratio, + bool aligned, bool clockwise) { + DISPATCH_DEVICE_IMPL(roi_align_rotated_v2_forward_impl, x, rois, y, + pooled_h, pooled_w, spatial_scale, sampling_ratio, + aligned, clockwise); } -void roi_align_rotated_v2_forward(Tensor input, Tensor rois, Tensor output, - double spatial_scale, int sampling_ratio, - int pooled_height, int pooled_width, - bool aligned, bool clockwise) { - roi_align_rotated_v2_forward_impl(input, rois, output, spatial_scale, - sampling_ratio, pooled_height, pooled_width, - aligned, clockwise); + +void roi_align_rotated_v2_forward(Tensor x, Tensor rois, Tensor y, + int pooled_h, int pooled_w, + double spatial_scale, int sampling_ratio, + bool aligned, bool clockwise) { + roi_align_rotated_v2_forward_impl(x, rois, y, pooled_h, pooled_w, + spatial_scale, sampling_ratio, aligned, clockwise); } -void roi_align_rotated_v2_backward_impl(Tensor input, Tensor rois, - Tensor grad_output, Tensor grad_input, - int pooled_height, int pooled_width, - double spatial_scale, - int sampling_ratio, bool aligned, - bool clockwise) { - DISPATCH_DEVICE_IMPL(roi_align_rotated_v2_backward_impl, input, rois, - grad_output, grad_input, pooled_height, pooled_width, - spatial_scale, sampling_ratio, aligned, clockwise); + +void roi_align_rotated_v2_backward_impl(Tensor input, Tensor rois, Tensor grad_output, Tensor grad_input, + int pooled_height, int pooled_width, double spatial_scale, + int sampling_ratio, bool aligned, bool clockwise) { + DISPATCH_DEVICE_IMPL(roi_align_rotated_v2_backward_impl, input, rois, grad_output, grad_input, + pooled_height, pooled_width, spatial_scale, sampling_ratio, aligned, clockwise); } -void roi_align_rotated_v2_backward(Tensor input, Tensor rois, - Tensor grad_output, Tensor grad_input, - int pooled_height, int pooled_width, - double spatial_scale, int sampling_ratio, - bool aligned, bool clockwise) { + +void roi_align_rotated_v2_backward(Tensor input, Tensor rois, Tensor grad_output, Tensor grad_input, + int pooled_height, int pooled_width, double spatial_scale, + int sampling_ratio, bool aligned, bool clockwise) { roi_align_rotated_v2_backward_impl(input, rois, grad_output, grad_input, - pooled_height, pooled_width, spatial_scale, - sampling_ratio, aligned, clockwise); + pooled_height, pooled_width, spatial_scale, sampling_ratio, aligned, clockwise); } diff --git a/mmcv/ops/deform_conv.py b/mmcv/ops/deform_conv.py index 1bde04d735..4804635b58 100644 --- a/mmcv/ops/deform_conv.py +++ b/mmcv/ops/deform_conv.py @@ -57,7 +57,7 @@ def _npu_backward(ctx, grad_output): grad_input, grad_weight, grad_offset_all, grad_bias = \ torch_npu.npu_deformable_conv2dbk( input_tensor, grad_output, offset_out, weight, offset_all, - kernel_size=[weight.shape[3], weight.shape[2]], + kernel_size=[weight.shape[2], weight.shape[3]], stride=[1, 1, ctx.stride[0], ctx.stride[1]], padding=[ctx.padding[0], ctx.padding[0], ctx.padding[1], ctx.padding[1]], diff --git a/mmcv/ops/fused_bias_leakyrelu.py b/mmcv/ops/fused_bias_leakyrelu.py index e23617fb3a..fe17d2db7b 100644 --- a/mmcv/ops/fused_bias_leakyrelu.py +++ b/mmcv/ops/fused_bias_leakyrelu.py @@ -258,7 +258,7 @@ def fused_bias_leakyrelu(input: torch.Tensor, torch.Tensor: Feature map after non-linear activation. """ - if not input.is_cuda: + if not input.is_cuda and input.device.type != 'npu': return bias_leakyrelu_ref(input, bias, negative_slope, scale) return FusedBiasLeakyReLUFunction.apply(input, bias.to(input.dtype), diff --git a/mmcv/ops/knn.py b/mmcv/ops/knn.py index 47ced04c6a..d56f74f2c9 100644 --- a/mmcv/ops/knn.py +++ b/mmcv/ops/knn.py @@ -62,6 +62,17 @@ def forward(ctx, B, npoint, _ = center_xyz.shape N = xyz.shape[1] + if xyz.device.type == 'npu': + dist2 = center_xyz.new_zeros((B, npoint, k)).float() + idx = center_xyz.new_zeros((B, npoint, k)).int() + ext_module.knn_forward( + xyz, center_xyz, idx, dist2, b=B, n=N, m=npoint, nsample=k) + zeros_idx = torch.zeros( + xyz.shape[0], center_xyz.shape[1], k, dtype=torch.int32).npu() + idx.where(dist2 >= 1e10, zeros_idx) + idx = idx.transpose(2, 1).contiguous() # [B, k, npoint] + return idx.int() + idx = center_xyz.new_zeros((B, npoint, k)).int() dist2 = center_xyz.new_zeros((B, npoint, k)).float() diff --git a/mmcv/ops/modulated_deform_conv.py b/mmcv/ops/modulated_deform_conv.py index 0c169009a5..b6e8c6d40a 100644 --- a/mmcv/ops/modulated_deform_conv.py +++ b/mmcv/ops/modulated_deform_conv.py @@ -55,7 +55,7 @@ def _npu_forward(ctx, input_tensor, offset, mask, weight, bias): conv2d_bias = bias if len(bias) > 0 else None sort_index_fp, sort_index_bp = \ ModulatedDeformConv2dFunction._calculate_sort_index( - kernel_w, kernel_h, ctx.deform_groups) + kernel_h, kernel_w, ctx.deform_groups) select_offset = offset.index_select(1, sort_index_fp) offset_all = torch.cat([select_offset, mask], dim=1) import torch_npu @@ -64,7 +64,7 @@ def _npu_forward(ctx, input_tensor, offset, mask, weight, bias): weight, offset_all, conv2d_bias, - kernel_size=[kernel_w, kernel_h], + kernel_size=[kernel_h, kernel_w], stride=[1, 1, ctx.stride[0], ctx.stride[1]], padding=[ ctx.padding[0], ctx.padding[0], ctx.padding[1], ctx.padding[1] @@ -87,7 +87,7 @@ def _npu_backward(ctx, grad_output): grad_input, grad_weight, grad_offset_all, grad_bias = \ torch_npu.npu_deformable_conv2dbk( input_tensor, grad_output, offset_out, weight, offset_all, - kernel_size=[weight.shape[3], weight.shape[2]], + kernel_size=[weight.shape[2], weight.shape[3]], stride=[1, 1, ctx.stride[0], ctx.stride[1]], padding=[ctx.padding[0], ctx.padding[0], ctx.padding[1], ctx.padding[1]], diff --git a/mmcv/ops/nms.py b/mmcv/ops/nms.py index 14df732482..946d016a70 100644 --- a/mmcv/ops/nms.py +++ b/mmcv/ops/nms.py @@ -411,6 +411,7 @@ def nms_rotated(dets: Tensor, input_labels = scores.new_empty(0, dtype=torch.int) else: input_labels = labels + if dets.device.type in ('npu', 'mlu'): order = scores.new_empty(0, dtype=torch.long) keep_inds = ext_module.nms_rotated(dets_cw, scores, order, dets_cw, diff --git a/mmcv/ops/pixel_group.py b/mmcv/ops/pixel_group.py index cf73e326da..c03f80a13f 100644 --- a/mmcv/ops/pixel_group.py +++ b/mmcv/ops/pixel_group.py @@ -9,6 +9,34 @@ ext_module = ext_loader.load_ext('_ext', ['pixel_group']) +def estimate_confidence(label: torch.Tensor, score: torch.Tensor, + label_num: int) -> List[List[float]]: + + import torch_npu + point_vector = torch.zeros((label_num, 2), + dtype=torch.float32).to(score.device) + + label_flat = label.flatten() + score_flat = score.flatten() + + mask = label_flat > 0 + valid_labels = label_flat[mask] + valid_scores = score_flat[mask] + + point_vector.index_add_( + 0, valid_labels, + torch.stack((valid_scores, torch.ones_like(valid_scores)), dim=1)) + + valid_mask = point_vector[:, 1] > 0 + point_vector[valid_mask, 0] /= point_vector[valid_mask, 1] + + point_vector_list = point_vector.tolist() + for l in range(1, label_num): + coords = (label == l).nonzero(as_tuple=False).float() + coords = coords[:, [1, 0]] + point_vector_list[l].extend(coords.flatten().tolist()) + + return point_vector_list def pixel_group( score: Union[np.ndarray, Tensor], @@ -59,6 +87,30 @@ def pixel_group( if isinstance(kernel_contour, np.ndarray): kernel_contour = torch.from_numpy(kernel_contour) + if score.device.type == 'npu': + import torch_npu + embedding_dim = embedding.shape[2] + kernel_vector = torch.zeros((kernel_region_num, embedding_dim), + dtype=torch.float32).to(score.device) + + for label in range(1, kernel_region_num): + label_mask = (kernel_label == label) + label_embeddings = embedding[label_mask] + kernel_vector[label, :] = label_embeddings.sum(dim=0) + vector_sum = label_mask.sum() + kernel_vector[label, :] /= vector_sum + + kernel_cv = kernel_vector[label, :] + valid_mask = (mask == 1) & (kernel_label == 0) + valid_embeddings = embedding[valid_mask] + distances = torch.sum((valid_embeddings - kernel_cv)**2, dim=1) + within_threshold = distances < distance_threshold**2 + + kernel_label[valid_mask] = torch.where(within_threshold, label, + kernel_label[valid_mask]) + + return estimate_confidence(kernel_label, score, kernel_region_num) + if torch.__version__ == 'parrots': label = ext_module.pixel_group( score, @@ -83,4 +135,4 @@ def pixel_group( kernel_label, kernel_contour, kernel_region_num, distance_threshold) - return pixel_assignment + return pixel_assignment \ No newline at end of file diff --git a/mmcv/ops/points_in_boxes.py b/mmcv/ops/points_in_boxes.py index 4915e6b573..3069867886 100644 --- a/mmcv/ops/points_in_boxes.py +++ b/mmcv/ops/points_in_boxes.py @@ -47,8 +47,11 @@ def points_in_boxes_part(points: Tensor, boxes: Tensor) -> Tensor: points_device = points.get_device() assert points_device == boxes.get_device(), \ 'Points and boxes should be put on the same device' - if torch.cuda.current_device() != points_device: - torch.cuda.set_device(points_device) + if points.device.type != 'npu': + if torch.cuda.current_device() != points_device: + torch.cuda.set_device(points_device) + elif points.device.type == 'npu': + boxes[:, :, 2] += boxes[:, :, 5] / 2.0 ext_module.points_in_boxes_part_forward(boxes.contiguous(), points.contiguous(), @@ -127,9 +130,9 @@ def points_in_boxes_all(points: Tensor, boxes: Tensor) -> Tensor: points_device = points.get_device() assert points_device == boxes.get_device(), \ 'Points and boxes should be put on the same device' - if torch.cuda.current_device() != points_device: - torch.cuda.set_device(points_device) - + if points.device.type != 'npu': + if torch.cuda.current_device() != points_device: + torch.cuda.set_device(points_device) ext_module.points_in_boxes_all_forward(boxes.contiguous(), points.contiguous(), box_idxs_of_pts) diff --git a/mmcv/ops/points_in_polygons.py b/mmcv/ops/points_in_polygons.py index e54b5a896d..8d3bc8dd48 100644 --- a/mmcv/ops/points_in_polygons.py +++ b/mmcv/ops/points_in_polygons.py @@ -19,6 +19,8 @@ def points_in_polygons(points: Tensor, polygons: Tensor) -> Tensor: polygons (torch.Tensor): It has shape (M, 8), indicating (x1, y1, x2, y2, x3, y3, x4, y4). M means the number of ground truth polygons. + constraints: The number of significant digits for the input-arguments + are between -10 and 10 when running on Ascend device. Returns: torch.Tensor: Return the result with the shape of (B, M), diff --git a/mmcv/ops/roi_align_rotated_v2.py b/mmcv/ops/roi_align_rotated_v2.py index 639fea3a23..80c97ae736 100644 --- a/mmcv/ops/roi_align_rotated_v2.py +++ b/mmcv/ops/roi_align_rotated_v2.py @@ -14,55 +14,55 @@ class RoIAlignRotatedV2Function(Function): @staticmethod - def symbolic(g, input, rois, spatial_scale, sampling_ratio, pooled_height, - pooled_width, aligned, clockwise): + def symbolic(g, x, rois, spatial_scale, sampling_ratio, pooled_h, + pooled_w, aligned, clockwise): return g.op( 'mmcv::MMCVRoIAlignRotatedV2', - input, + x, rois, + pooled_h=pooled_h, + pooled_w=pooled_w, spatial_scale_f=spatial_scale, sampling_ratio_i=sampling_ratio, - pooled_height=pooled_height, - pooled_width=pooled_width, aligned_i=aligned, clockwise_i=clockwise) @staticmethod def forward(ctx: Any, - input: torch.Tensor, + x: torch.Tensor, rois: torch.Tensor, + pooled_h: int, + pooled_w: int, spatial_scale: float, sampling_ratio: int, - pooled_height: int, - pooled_width: int, aligned: bool = True, clockwise: bool = False) -> torch.Tensor: - ctx.pooled_height = pooled_height - ctx.pooled_width = pooled_width + ctx.pooled_h = pooled_h + ctx.pooled_w = pooled_w ctx.spatial_scale = spatial_scale ctx.sampling_ratio = sampling_ratio ctx.aligned = aligned ctx.clockwise = clockwise - ctx.save_for_backward(input, rois) - ctx.feature_size = input.size() - batch_size, num_channels, data_height, data_width = input.size() + ctx.save_for_backward(x, rois) + ctx.feature_size = x.size() + batch_size, num_channels, data_height, data_width = x.size() num_rois = rois.size(0) - output = input.new_zeros(num_rois, ctx.pooled_height, ctx.pooled_width, + y = x.new_zeros(num_rois, ctx.pooled_h, ctx.pooled_w, num_channels) ext_module.roi_align_rotated_v2_forward( - input, + x, rois, - output, + y, + pooled_h=ctx.pooled_h, + pooled_w=ctx.pooled_w, spatial_scale=ctx.spatial_scale, sampling_ratio=ctx.sampling_ratio, - pooled_height=ctx.pooled_height, - pooled_width=ctx.pooled_width, aligned=ctx.aligned, clockwise=ctx.clockwise) - output = output.transpose(2, 3).transpose(1, 2).contiguous() - return output + y = y.transpose(2, 3).transpose(1, 2).contiguous() + return y @staticmethod def backward(ctx: Any, grad_output: torch.Tensor): @@ -74,7 +74,7 @@ def backward(ctx: Any, grad_output: torch.Tensor): input.size(0), input.size(2), input.size(3), input.size(1)) ext_module.roi_align_rotated_v2_backward( input, rois_trans, grad_output_trans, grad_input, - ctx.pooled_height, ctx.pooled_width, ctx.spatial_scale, + ctx.pooled_h, ctx.pooled_w, ctx.spatial_scale, ctx.sampling_ratio, ctx.aligned, ctx.clockwise) grad_input = grad_input.permute(0, 3, 1, 2).contiguous() @@ -134,31 +134,33 @@ class RoIAlignRotatedV2(nn.Module): }, cls_name='RoIAlignRotatedV2') def __init__(self, + pooled_h: int, + pooled_w: int, spatial_scale: float, sampling_ratio: int, - pooled_height: int, - pooled_width: int, aligned: bool = True, clockwise: bool = False): super().__init__() - self.pooled_height = int(pooled_height) - self.pooled_width = int(pooled_width) + self.pooled_h = int(pooled_h) + self.pooled_w = int(pooled_w) self.spatial_scale = float(spatial_scale) self.sampling_ratio = int(sampling_ratio) self.aligned = aligned self.clockwise = clockwise def forward(self, input: torch.Tensor, rois: torch.Tensor) -> torch.Tensor: - return RoIAlignRotatedV2Function.apply(input, rois, self.spatial_scale, + return RoIAlignRotatedV2Function.apply(input, rois, + self.pooled_h, + self.pooled_w, + self.spatial_scale, self.sampling_ratio, - self.pooled_height, - self.pooled_width, self.aligned, + self.aligned, self.clockwise) def __repr__(self): s = self.__class__.__name__ - s += f'(pooled_height={self.pooled_height}, ' + s += f'(pooled_h={self.pooled_h}, ' s += f'spatial_scale={self.spatial_scale}, ' s += f'sampling_ratio={self.sampling_ratio}, ' s += f'aligned={self.aligned}, ' diff --git a/mmcv/ops/scatter_points.py b/mmcv/ops/scatter_points.py index 7b03e8f1da..3061343d2a 100644 --- a/mmcv/ops/scatter_points.py +++ b/mmcv/ops/scatter_points.py @@ -36,10 +36,30 @@ def forward(ctx: Any, reduced from input features that share the same voxel coordinates. The second is voxel coordinates with shape [M, ndim]. """ + ctx.device = feats.device.type + if ctx.device == 'npu': + import mx_driving._C + voxel_idx = mx_driving._C.point_to_voxel(coors, [], [], 'XYZ') + unique_res = mx_driving._C.unique_voxel(voxel_idx) + num_voxels, uniqued_voxel_idx, prefix_sum, \ + argsort_coor, _ = unique_res + voxel_coors = \ + mx_driving._C.voxel_to_point(uniqued_voxel_idx, [], [], 'XYZ') + voxel_feats, \ + compare_mask = mx_driving._C.npu_dynamic_scatter(feats, coors, + prefix_sum, + argsort_coor, + num_voxels, + reduce_type) + ctx.reduce_type = reduce_type + ctx.feats_shape = feats.shape + ctx.save_for_backward(prefix_sum, argsort_coor, compare_mask) + ctx.mark_non_differentiable(voxel_coors) + return voxel_feats, voxel_coors + results = ext_module.dynamic_point_to_voxel_forward( feats, coors, reduce_type) - (voxel_feats, voxel_coors, point2voxel_map, - voxel_points_count) = results + voxel_feats, voxel_coors, point2voxel_map, voxel_points_count = results ctx.reduce_type = reduce_type ctx.save_for_backward(feats, voxel_feats, point2voxel_map, voxel_points_count) @@ -50,6 +70,19 @@ def forward(ctx: Any, def backward(ctx: Any, grad_voxel_feats: torch.Tensor, grad_voxel_coors: Optional[torch.Tensor] = None) -> tuple: + if ctx.device == 'npu': + import mx_driving._C + prefix_sum, argsort_coor, compare_mask = ctx.saved_tensors + grad_point_feats = torch.zeros( + ctx.feats_shape, + dtype=grad_voxel_feats.dtype, + device=grad_voxel_feats.device) + mx_driving._C.npu_dynamic_scatter_grad(grad_point_feats, + grad_voxel_feats.contiguous(), + prefix_sum, argsort_coor, + compare_mask, ctx.reduce_type) + return grad_point_feats, None, None + (feats, voxel_feats, point2voxel_map, voxel_points_count) = ctx.saved_tensors grad_feats = torch.zeros_like(feats) diff --git a/mmcv/ops/three_nn.py b/mmcv/ops/three_nn.py index d41b9789cf..db8bbdef11 100644 --- a/mmcv/ops/three_nn.py +++ b/mmcv/ops/three_nn.py @@ -34,6 +34,21 @@ def forward(ctx: Any, target: torch.Tensor, B, N, _ = target.size() m = source.size(1) + if source.device.type == 'npu': + # strict to fp32 + source = source.transpose(2, 1).contiguous() + dtype_ = source.dtype + if dtype_ == torch.float16: + target = target.float() + source = source.float() + dist2 = target.new_empty(B, N, 3) + idx = target.new_empty(B, N, 3, dtype=torch.int32) + ext_module.three_nn_forward( + target, source, dist2, idx, b=B, n=N, m=m) + dist2 = torch.sqrt(dist2) + if dtype_ == torch.float16: + dist2 = dist2.half() + return dist2, idx.int() dist2 = target.new_empty(B, N, 3) idx = target.new_empty(B, N, 3, dtype=torch.int32) diff --git a/tests/test_ops/test_assign_score_withk.py b/tests/test_ops/test_assign_score_withk.py index f8fc6ae626..d778121c74 100644 --- a/tests/test_ops/test_assign_score_withk.py +++ b/tests/test_ops/test_assign_score_withk.py @@ -3,93 +3,105 @@ import torch from mmcv.ops import assign_score_withk +from mmcv.utils import IS_CUDA_AVAILABLE, IS_NPU_AVAILABLE -@pytest.mark.skipif( - not torch.cuda.is_available(), reason='requires CUDA support') -def test_paconv_assign_scores(): - scores = torch.tensor([[[[0.06947571, 0.6065746], [0.28462553, 0.8378516], - [0.7595994, 0.97220325], [0.519155, 0.766185]], - [[0.15348864, 0.6051019], [0.21510637, 0.31916398], - [0.00236845, 0.5842595], [0.6783676, 0.5216348]]], - [[[0.23089725, 0.5568468], [0.7405102, 0.06438422], - [0.6887394, 0.22089851], [0.0502342, 0.79228795]], - [[0.44883424, 0.15427643], - [0.13817799, 0.34856772], [0.7989621, 0.33788306], - [0.15699774, 0.7693662]]]]).float().cuda() - scores.requires_grad_() - points = torch.tensor([[[[0.06001121, 0.92963666, 0.5753327, 0.7251477], - [0.53563064, 0.23129565, 0.92366195, 0.44261628]], - [[0.5770022, 0.56625944, 0.23560429, 0.11178821], - [0.7735967, 0.95678777, 0.25468266, 0.02895975]], - [[0.0589869, 0.09017515, 0.5977862, 0.02797985], - [0.603862, 0.35991007, 0.85761684, 0.3096559]], - [[0.22359002, 0.13983732, 0.5544243, 0.68863827], - [0.85646236, 0.75651926, 0.8638947, 0.83600986]], - [[0.45424145, 0.27458847, 0.6456112, 0.47162914], - [0.15773582, 0.47645122, 0.79964715, 0.3323908]], - [[0.8351399, 0.84696376, 0.9431732, 0.29418713], - [0.77168906, 0.6996871, 0.19354361, 0.03392768]], - [[0.30976456, 0.7074133, 0.581795, 0.976677], - [0.69656056, 0.07199162, 0.4708506, 0.29117996]], - [[0.5829035, 0.30201727, 0.76556486, 0.0935446], - [0.88030535, 0.16129416, 0.9242525, 0.49545723]]], - [[[0.50899494, 0.06482804, 0.44939405, 0.37704808], - [0.47028124, 0.11969638, 0.62823206, 0.28560323]], - [[0.40690207, 0.689753, 0.51636654, 0.23040164], - [0.06935787, 0.00488842, 0.22462702, 0.09182382]], - [[0.26611632, 0.00184339, 0.7730655, 0.5228131], - [0.87776035, 0.77895886, 0.2787183, 0.16620636]], - [[0.502574, 0.04039001, 0.5368497, 0.98379374], - [0.40973026, 0.3238272, 0.9733018, 0.13988364]], - [[0.04586202, 0.20983845, 0.20662665, 0.22270602], - [0.60387236, 0.5155574, 0.51237285, 0.6528438]], - [[0.45735973, 0.86821306, 0.61054605, 0.8370336], - [0.45193362, 0.3734138, 0.7825672, 0.5699416]], - [[0.44591594, 0.12447512, 0.09282011, 0.7055254], - [0.25223452, 0.46696228, 0.7051136, 0.892151]], - [[0.49615085, 0.47321403, 0.93138885, 0.7652197], - [0.38766378, 0.30332977, 0.23131835, - 0.02863514]]]]).float().cuda() - points.requires_grad_() - centers = torch.tensor([[[[0.83878064, 0.96658987, 0.8033424, 0.9598312], - [0.45035273, 0.8768925, 0.977736, 0.54547966]], - [[0.01041394, 0.597893, 0.36212963, 0.4410367], - [0.94879234, 0.8372817, 0.21237361, 0.67945415]], - [[0.5096087, 0.26401454, 0.60034937, 0.5417416], - [0.87591463, 0.546456, 0.4096033, 0.16373193]], - [[0.79547447, 0.1482386, 0.12840575, 0.45384115], - [0.5640288, 0.944541, 0.5745328, 0.73229736]], - [[0.93011934, 0.7406011, 0.62621707, 0.8677915], - [0.91563636, 0.3595413, 0.6678378, 0.6085383]], - [[0.22431666, 0.65617776, 0.7483924, 0.6263364], - [0.30968404, 0.78204364, 0.14899081, - 0.09628749]], - [[0.73675203, 0.72104895, 0.4648038, 0.6101647], - [0.7817645, 0.16572917, 0.3311919, 0.43407398]], - [[0.8193154, 0.09559608, 0.05978829, 0.90262103], - [0.4256065, 0.8165596, 0.8206446, 0.6604721]]], - [[[0.7159653, 0.18600845, 0.21433902, 0.3159626], - [0.3921569, 0.33221376, 0.5061177, 0.7961841]], - [[0.95338356, 0.04785997, 0.67185795, 0.6538394], - [0.4729132, 0.33404195, 0.17750603, 0.8445621]], - [[0.6755793, 0.16193843, 0.75943846, 0.92123103], - [0.2781859, 0.03114432, 0.710638, 0.52729136]], - [[0.8376105, 0.10858494, 0.13208169, 0.365772], - [0.5930795, 0.27390373, 0.14036089, 0.170403]], - [[0.3479789, 0.89855295, 0.04844379, 0.9871029], - [0.29781651, 0.0244137, 0.9179047, 0.8081611]], - [[0.12460887, 0.44991326, 0.19382608, 0.35037738], - [0.2773472, 0.4362057, 0.36757517, 0.5993509]], - [[0.29630446, 0.90046406, 0.5417113, 0.13510644], - [0.09623539, 0.04226565, 0.32001644, - 0.44358212]], - [[0.5274848, 0.82096446, 0.9415489, 0.7123748], - [0.7537517, 0.8086482, 0.85345286, - 0.7472754]]]]).float().cuda() - centers.requires_grad_() - knn_idx = torch.tensor([[[6, 7, 4, 6], [2, 4, 2, 4]], - [[7, 1, 3, 2], [6, 0, 2, 6]]]).long().cuda() +@pytest.mark.parametrize('device', [ + pytest.param( + 'cuda', + marks=pytest.mark.skipif( + not IS_CUDA_AVAILABLE, reason='requires CUDA support')), + pytest.param( + 'npu', + marks=pytest.mark.skipif( + not IS_NPU_AVAILABLE, reason='requires NPU support')) +]) +def test_paconv_assign_scores(device): + scores = torch.tensor( + [[[[0.06947571, 0.6065746], [0.28462553, 0.8378516], + [0.7595994, 0.97220325], [0.519155, 0.766185]], + [[0.15348864, 0.6051019], [0.21510637, 0.31916398], + [0.00236845, 0.5842595], [0.6783676, 0.5216348]]], + [[[0.23089725, 0.5568468], [0.7405102, 0.06438422], + [0.6887394, 0.22089851], [0.0502342, 0.79228795]], + [[0.44883424, 0.15427643], [0.13817799, 0.34856772], + [0.7989621, 0.33788306], [0.15699774, 0.7693662]]]], + device=device).float() + points = torch.tensor( + [[[[0.06001121, 0.92963666, 0.5753327, 0.7251477], + [0.53563064, 0.23129565, 0.92366195, 0.44261628]], + [[0.5770022, 0.56625944, 0.23560429, 0.11178821], + [0.7735967, 0.95678777, 0.25468266, 0.02895975]], + [[0.0589869, 0.09017515, 0.5977862, 0.02797985], + [0.603862, 0.35991007, 0.85761684, 0.3096559]], + [[0.22359002, 0.13983732, 0.5544243, 0.68863827], + [0.85646236, 0.75651926, 0.8638947, 0.83600986]], + [[0.45424145, 0.27458847, 0.6456112, 0.47162914], + [0.15773582, 0.47645122, 0.79964715, 0.3323908]], + [[0.8351399, 0.84696376, 0.9431732, 0.29418713], + [0.77168906, 0.6996871, 0.19354361, 0.03392768]], + [[0.30976456, 0.7074133, 0.581795, 0.976677], + [0.69656056, 0.07199162, 0.4708506, 0.29117996]], + [[0.5829035, 0.30201727, 0.76556486, 0.0935446], + [0.88030535, 0.16129416, 0.9242525, 0.49545723]]], + [[[0.50899494, 0.06482804, 0.44939405, 0.37704808], + [0.47028124, 0.11969638, 0.62823206, 0.28560323]], + [[0.40690207, 0.689753, 0.51636654, 0.23040164], + [0.06935787, 0.00488842, 0.22462702, 0.09182382]], + [[0.26611632, 0.00184339, 0.7730655, 0.5228131], + [0.87776035, 0.77895886, 0.2787183, 0.16620636]], + [[0.502574, 0.04039001, 0.5368497, 0.98379374], + [0.40973026, 0.3238272, 0.9733018, 0.13988364]], + [[0.04586202, 0.20983845, 0.20662665, 0.22270602], + [0.60387236, 0.5155574, 0.51237285, 0.6528438]], + [[0.45735973, 0.86821306, 0.61054605, 0.8370336], + [0.45193362, 0.3734138, 0.7825672, 0.5699416]], + [[0.44591594, 0.12447512, 0.09282011, 0.7055254], + [0.25223452, 0.46696228, 0.7051136, 0.892151]], + [[0.49615085, 0.47321403, 0.93138885, 0.7652197], + [0.38766378, 0.30332977, 0.23131835, 0.02863514]]]], + device=device).float() + centers = torch.tensor( + [[[[0.83878064, 0.96658987, 0.8033424, 0.9598312], + [0.45035273, 0.8768925, 0.977736, 0.54547966]], + [[0.01041394, 0.597893, 0.36212963, 0.4410367], + [0.94879234, 0.8372817, 0.21237361, 0.67945415]], + [[0.5096087, 0.26401454, 0.60034937, 0.5417416], + [0.87591463, 0.546456, 0.4096033, 0.16373193]], + [[0.79547447, 0.1482386, 0.12840575, 0.45384115], + [0.5640288, 0.944541, 0.5745328, 0.73229736]], + [[0.93011934, 0.7406011, 0.62621707, 0.8677915], + [0.91563636, 0.3595413, 0.6678378, 0.6085383]], + [[0.22431666, 0.65617776, 0.7483924, 0.6263364], + [0.30968404, 0.78204364, 0.14899081, 0.09628749]], + [[0.73675203, 0.72104895, 0.4648038, 0.6101647], + [0.7817645, 0.16572917, 0.3311919, 0.43407398]], + [[0.8193154, 0.09559608, 0.05978829, 0.90262103], + [0.4256065, 0.8165596, 0.8206446, 0.6604721]]], + [[[0.7159653, 0.18600845, 0.21433902, 0.3159626], + [0.3921569, 0.33221376, 0.5061177, 0.7961841]], + [[0.95338356, 0.04785997, 0.67185795, 0.6538394], + [0.4729132, 0.33404195, 0.17750603, 0.8445621]], + [[0.6755793, 0.16193843, 0.75943846, 0.92123103], + [0.2781859, 0.03114432, 0.710638, 0.52729136]], + [[0.8376105, 0.10858494, 0.13208169, 0.365772], + [0.5930795, 0.27390373, 0.14036089, 0.170403]], + [[0.3479789, 0.89855295, 0.04844379, 0.9871029], + [0.29781651, 0.0244137, 0.9179047, 0.8081611]], + [[0.12460887, 0.44991326, 0.19382608, 0.35037738], + [0.2773472, 0.4362057, 0.36757517, 0.5993509]], + [[0.29630446, 0.90046406, 0.5417113, 0.13510644], + [0.09623539, 0.04226565, 0.32001644, 0.44358212]], + [[0.5274848, 0.82096446, 0.9415489, 0.7123748], + [0.7537517, 0.8086482, 0.85345286, 0.7472754]]]], + device=device).float() + if device == 'cuda': + points.requires_grad_() + scores.requires_grad_() + centers.requires_grad_() + knn_idx = torch.tensor( + [[[6, 7, 4, 6], [2, 4, 2, 4]], [[7, 1, 3, 2], [6, 0, 2, 6]]], + device=device).long() aggregate = 'sum' expected_output = torch.tensor( [[[[-0.08134781, 0.03877336, -0.8212776, -0.2869547], @@ -117,69 +129,70 @@ def test_paconv_assign_scores(): loss = output.sum() loss.backward() expected_scores_grad = torch.tensor([[[[0.04288036, -0.18217683], - [-0.78873926, 0.7485497], - [-0.6866992, 0.05346543], - [0.04288036, -0.18217683]], + [-0.78873926, 0.7485497], + [-0.6866992, 0.05346543], + [0.04288036, -0.18217683]], [[-1.1407862, 0.13533896], - [-0.06964391, -0.22948086], - [-1.1407862, 0.13533896], - [-0.06964391, -0.22948086]]], - [[[-0.3363995, -2.212181], - [-1.1589496, -2.7724311], - [-0.9387654, -1.3163853], - [-1.4385346, -1.0614843]], + [-0.06964391, -0.22948086], + [-1.1407862, 0.13533896], + [-0.06964391, -0.22948086]]], + [[[-0.3363995, -2.212181], + [-1.1589496, -2.7724311], + [-0.9387654, -1.3163853], + [-1.4385346, -1.0614843]], [[-0.5048497, 1.4143617], - [-0.47332114, 0.6017133], - [-0.30974793, 1.1995442], - [-0.5048497, 1.4143617]]]]).float() + [-0.47332114, 0.6017133], + [-0.30974793, 1.1995442], + [-0.5048497, + 1.4143617]]]]).float() expected_points_grad = torch.tensor( [[[[0., 0., 0., 0.], [0., 0., 0., 0.]], [[0., 0., 0., 0.], [0., 0., 0., 0.]], [[0.15585709, 0.15585709, 0.15585709, 0.15585709], - [1.1893613, 1.1893613, 1.1893613, 1.1893613]], + [1.1893613, 1.1893613, 1.1893613, 1.1893613]], [[0., 0., 0., 0.], [0., 0., 0., 0.]], [[1.6530733, 1.6530733, 1.6530733, 1.6530733], - [1.8130021, 1.8130021, 1.8130021, 1.8130021]], + [1.8130021, 1.8130021, 1.8130021, 1.8130021]], [[0., 0., 0., 0.], [0., 0., 0., 0.]], [[0.58863074, 0.58863074, 0.58863074, 0.58863074], - [1.3727596, 1.3727596, 1.3727596, 1.3727596]], + [1.3727596, 1.3727596, 1.3727596, 1.3727596]], [[0.28462553, 0.28462553, 0.28462553, 0.28462553], - [0.8378516, 0.8378516, 0.8378516, 0.8378516]]], - [[[0.13817799, 0.13817799, 0.13817799, 0.13817799], - [0.34856772, 0.34856772, 0.34856772, 0.34856772]], + [0.8378516, 0.8378516, 0.8378516, 0.8378516]]], + [[[0.13817799, 0.13817799, 0.13817799, 0.13817799], + [0.34856772, 0.34856772, 0.34856772, 0.34856772]], [[0.7405102, 0.7405102, 0.7405102, 0.7405102], - [0.06438422, 0.06438422, 0.06438422, 0.06438422]], + [0.06438422, 0.06438422, 0.06438422, 0.06438422]], [[0.8491963, 0.8491963, 0.8491963, 0.8491963], - [1.1301711, 1.1301711, 1.1301711, 1.1301711]], + [1.1301711, 1.1301711, 1.1301711, 1.1301711]], [[0.6887394, 0.6887394, 0.6887394, 0.6887394], - [0.22089851, 0.22089851, 0.22089851, 0.22089851]], + [0.22089851, 0.22089851, 0.22089851, 0.22089851]], [[0., 0., 0., 0.], [0., 0., 0., 0.]], [[0., 0., 0., 0.], [0., 0., 0., 0.]], [[0.605832, 0.605832, 0.605832, 0.605832], - [0.92364264, 0.92364264, 0.92364264, 0.92364264]], + [0.92364264, 0.92364264, 0.92364264, 0.92364264]], [[0.23089725, 0.23089725, 0.23089725, 0.23089725], - [0.5568468, 0.5568468, 0.5568468, 0.5568468]]]]).float() + [0.5568468, 0.5568468, 0.5568468, 0.5568468]]]]).float() expected_centers_grad = torch.tensor( [[[[0., 0., 0., 0.], [0., 0., 0., 0.]], [[0., 0., 0., 0.], [0., 0., 0., 0.]], [[-1.0493311, -1.0493311, -1.0493311, -1.0493311], - [-2.0301602, -2.0301602, -2.0301602, -2.0301602]], + [-2.0301602, -2.0301602, -2.0301602, -2.0301602]], [[0., 0., 0., 0.], [0., 0., 0., 0.]], [[0., 0., 0., 0.], [0., 0., 0., 0.]], [[0., 0., 0., 0.], [0., 0., 0., 0.]], [[-1.6328557, -1.6328557, -1.6328557, -1.6328557], - [-3.1828144, -3.1828144, -3.1828144, -3.1828144]], + [-3.1828144, -3.1828144, -3.1828144, -3.1828144]], [[0., 0., 0., 0.], [0., 0., 0., 0.]]], - [[[0., 0., 0., 0.], [0., 0., 0., 0.]], + [[[0., 0., 0., 0.], [0., 0., 0., 0.]], [[0., 0., 0., 0.], [0., 0., 0., 0.]], [[0., 0., 0., 0.], [0., 0., 0., 0.]], [[0., 0., 0., 0.], [0., 0., 0., 0.]], [[0., 0., 0., 0.], [0., 0., 0., 0.]], [[0., 0., 0., 0.], [0., 0., 0., 0.]], [[-1.5429721, -1.5429721, -1.5429721, -1.5429721], - [-1.6100934, -1.6100934, -1.6100934, -1.6100934]], + [-1.6100934, -1.6100934, -1.6100934, -1.6100934]], [[-1.7103812, -1.7103812, -1.7103812, -1.7103812], - [-1.6344175, -1.6344175, -1.6344175, -1.6344175]]]]).float() + [-1.6344175, -1.6344175, -1.6344175, -1.6344175]]]]).float() assert torch.allclose( scores.grad.detach().cpu(), expected_scores_grad, atol=1e-6) assert torch.allclose( diff --git a/tests/test_ops/test_voxelization.py b/tests/test_ops/test_voxelization.py index 78282a8ad0..c31224913e 100644 --- a/tests/test_ops/test_voxelization.py +++ b/tests/test_ops/test_voxelization.py @@ -200,6 +200,9 @@ def test_voxelization_npu(device_type): points = voxel_dict['points'] points = torch.tensor(points) + max_num_points = -1 + dynamic_voxelization = Voxelization(voxel_size, point_cloud_range, + max_num_points) max_num_points = 1000 hard_voxelization = Voxelization(voxel_size, point_cloud_range, max_num_points) @@ -215,3 +218,15 @@ def test_voxelization_npu(device_type): assert np.all(coors == expected_coors) assert np.all(voxels == expected_voxels) assert np.all(num_points_per_voxel == expected_num_points_per_voxel) + + # test dynamic_voxelization on npu + coors = dynamic_voxelization.forward(points) + coors = coors.cpu().detach().numpy() + points = points.cpu().detach().numpy() + for i in range(expected_voxels.shape[0]): + indices = _get_voxel_points_indices(points, coors, expected_voxels[i]) + num_points_current_voxel = points[indices].shape[0] + assert num_points_current_voxel > 0 + assert np.all( + points[indices] == expected_coors[i][:num_points_current_voxel]) + assert num_points_current_voxel == expected_num_points_per_voxel[i]