Skip to content

Commit

Permalink
fix deformConv and modulatedDeformConv input kernel_size
Browse files Browse the repository at this point in the history
  • Loading branch information
wuzheyi1028 authored and momo609 committed Jun 18, 2024
1 parent 8552434 commit c151a35
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion mmcv/ops/deform_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def _npu_backward(ctx, grad_output):
grad_input, grad_weight, grad_offset_all, grad_bias = \
torch_npu.npu_deformable_conv2dbk(
input_tensor, grad_output, offset_out, weight, offset_all,
kernel_size=[weight.shape[3], weight.shape[2]],
kernel_size=[weight.shape[2], weight.shape[3]],
stride=[1, 1, ctx.stride[0], ctx.stride[1]],
padding=[ctx.padding[0], ctx.padding[0], ctx.padding[1],
ctx.padding[1]],
Expand Down
6 changes: 3 additions & 3 deletions mmcv/ops/modulated_deform_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def _npu_forward(ctx, input_tensor, offset, mask, weight, bias):
conv2d_bias = bias if len(bias) > 0 else None
sort_index_fp, sort_index_bp = \
ModulatedDeformConv2dFunction._calculate_sort_index(
kernel_w, kernel_h, ctx.deform_groups)
kernel_h, kernel_w, ctx.deform_groups)
select_offset = offset.index_select(1, sort_index_fp)
offset_all = torch.cat([select_offset, mask], dim=1)
import torch_npu
Expand All @@ -64,7 +64,7 @@ def _npu_forward(ctx, input_tensor, offset, mask, weight, bias):
weight,
offset_all,
conv2d_bias,
kernel_size=[kernel_w, kernel_h],
kernel_size=[kernel_h, kernel_w],
stride=[1, 1, ctx.stride[0], ctx.stride[1]],
padding=[
ctx.padding[0], ctx.padding[0], ctx.padding[1], ctx.padding[1]
Expand All @@ -87,7 +87,7 @@ def _npu_backward(ctx, grad_output):
grad_input, grad_weight, grad_offset_all, grad_bias = \
torch_npu.npu_deformable_conv2dbk(
input_tensor, grad_output, offset_out, weight, offset_all,
kernel_size=[weight.shape[3], weight.shape[2]],
kernel_size=[weight.shape[2], weight.shape[3]],
stride=[1, 1, ctx.stride[0], ctx.stride[1]],
padding=[ctx.padding[0], ctx.padding[0], ctx.padding[1],
ctx.padding[1]],
Expand Down

0 comments on commit c151a35

Please sign in to comment.