diff --git a/backends/mlu/kernels/slice_kernel.cc b/backends/mlu/kernels/slice_kernel.cc index 273548b4d..2bd9d3e35 100644 --- a/backends/mlu/kernels/slice_kernel.cc +++ b/backends/mlu/kernels/slice_kernel.cc @@ -243,7 +243,7 @@ inline void CheckAndUpdateSliceAttrs(const phi::DDim in_dims, &start, &end, &dummy_zero_out_dim); - if (end == -dim_value - 1) { + if (step < 0 && end == -dim_value - 1) { end = -1; } diff --git a/backends/mlu/kernels/strided_slice_kernel.cc b/backends/mlu/kernels/strided_slice_kernel.cc index 32db94920..6396437f7 100644 --- a/backends/mlu/kernels/strided_slice_kernel.cc +++ b/backends/mlu/kernels/strided_slice_kernel.cc @@ -14,6 +14,7 @@ #include "kernels/funcs/mlu_baseop.h" #include "kernels/funcs/mlu_funcs.h" +#include "paddle/phi/kernels/funcs/slice_utils.h" namespace custom_kernel { static void StridedSliceOutDims(const std::vector& starts, @@ -40,6 +41,8 @@ static void StridedSliceOutDims(const std::vector& starts, auto ret = std::find(decrease_axis.begin(), decrease_axis.end(), axes[i]); if (ret != decrease_axis.end()) { decrease_axis_affect = true; + start_index = in_dims[axes_index] - 1; + end_index = in_dims[axes_index]; } } if (decrease_axis_affect) { @@ -61,39 +64,27 @@ static void StridedSliceOutDims(const std::vector& starts, continue; } - if (start_index < 0) { - start_index = start_index + axis_size; - start_index = std::max(start_index, 0); - } - if (end_index < 0) { - if (!(end_index == -1 && stride_index < 0)) { // skip None stop condition - end_index = end_index + axis_size; - if (end_index < 0) { - end_index = 0; - } - } + bool neg_dim_condition = false; + phi::funcs::normalize_interval(start_index, + end_index, + stride_index, + axis_size, + &start_index, + &end_index, + &neg_dim_condition); + if (stride_index < 0 && end_index == -axis_size - 1) { + end_index = -1; } - if (stride_index < 0) { - start_index = start_index + 1; - end_index = end_index + 1; + int64_t out_dims_index; + if (neg_dim_condition) { + out_dims_index = 0; + } else { + int64_t step_size = std::abs(stride_index); + out_dims_index = + (std::abs(end_index - start_index) + step_size - 1) / step_size; } - bool neg_dim_condition = ((stride_index < 0 && (start_index < end_index)) || - (stride_index > 0 && (start_index > end_index))); - PADDLE_ENFORCE_EQ(neg_dim_condition, - false, - phi::errors::InvalidArgument( - "The start index and end index are invalid for their " - "corresponding stride.")); - - int64_t left = - std::max(static_cast(0), std::min(start_index, end_index)); - int64_t right = std::min(axis_size, std::max(start_index, end_index)); - int64_t step = std::abs(stride_index); - - auto out_dims_index = (std::abs(right - left) + step - 1) / step; - out_dims_vector[axes_index] = out_dims_index; } } @@ -122,21 +113,23 @@ static void StridedSliceFunctor(int64_t* starts, decrease_axis.begin(), decrease_axis.end(), axes[axis_index]); if (ret != decrease_axis.end()) { decrease_axis_affect = true; + starts[axis_index] = axis_size - 1; + ends[axis_index] = axis_size; } } // stride must not be zero - if (starts[axis_index] < 0) { - starts[axis_index] = starts[axis_index] + axis_size; - starts[axis_index] = std::max(starts[axis_index], 0); - } - if (ends[axis_index] < 0) { - if (!(ends[axis_index] == -1 && - strides[axis_index] < 0)) { // skip None stop condition - ends[axis_index] = ends[axis_index] + axis_size; - if (ends[axis_index] < 0) { - ends[axis_index] = 0; - } - } + bool dummy_zero_dim_out = false; + phi::funcs::normalize_interval(starts[axis_index], + ends[axis_index], + strides[axis_index], + axis_size, + &starts[axis_index], + &ends[axis_index], + &dummy_zero_dim_out); + if (strides[axis_index] < 0 && ends[axis_index] == -axis_size - 1) { + // manually set the end to -1 when step < 0, + // which indicates that it can extend to the left endpoint. + ends[axis_index] = -1; } if (decrease_axis_affect) { if (strides[axis_index] < 0) { diff --git a/backends/mlu/tools/disable_ut_mlu b/backends/mlu/tools/disable_ut_mlu index cfb2175a9..e04306945 100755 --- a/backends/mlu/tools/disable_ut_mlu +++ b/backends/mlu/tools/disable_ut_mlu @@ -14,4 +14,3 @@ test_rms_norm_op_mlu test_sync_batch_norm_op_mlu test_unsqueeze_op_mlu test_LeNet_MNIST -test_strided_slice_op_mlu