Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MLU] Fix slice #1523

Closed
wants to merge 12 commits into from
2 changes: 1 addition & 1 deletion backends/mlu/kernels/slice_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ inline void CheckAndUpdateSliceAttrs(const phi::DDim in_dims,
&start,
&end,
&dummy_zero_out_dim);
if (end == -dim_value - 1) {
if (step < 0 && end == -dim_value - 1) {
end = -1;
}

Expand Down
75 changes: 34 additions & 41 deletions backends/mlu/kernels/strided_slice_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#include "kernels/funcs/mlu_baseop.h"
#include "kernels/funcs/mlu_funcs.h"
#include "paddle/phi/kernels/funcs/slice_utils.h"

namespace custom_kernel {
static void StridedSliceOutDims(const std::vector<int64_t>& starts,
Expand All @@ -40,6 +41,8 @@ static void StridedSliceOutDims(const std::vector<int64_t>& starts,
auto ret = std::find(decrease_axis.begin(), decrease_axis.end(), axes[i]);
if (ret != decrease_axis.end()) {
decrease_axis_affect = true;
start_index = in_dims[axes_index] - 1;
end_index = in_dims[axes_index];
}
}
if (decrease_axis_affect) {
Expand All @@ -61,39 +64,27 @@ static void StridedSliceOutDims(const std::vector<int64_t>& starts,
continue;
}

if (start_index < 0) {
start_index = start_index + axis_size;
start_index = std::max<int64_t>(start_index, 0);
}
if (end_index < 0) {
if (!(end_index == -1 && stride_index < 0)) { // skip None stop condition
end_index = end_index + axis_size;
if (end_index < 0) {
end_index = 0;
}
}
bool neg_dim_condition = false;
phi::funcs::normalize_interval(start_index,
end_index,
stride_index,
axis_size,
&start_index,
&end_index,
&neg_dim_condition);
if (stride_index < 0 && end_index == -axis_size - 1) {
end_index = -1;
}

if (stride_index < 0) {
start_index = start_index + 1;
end_index = end_index + 1;
int64_t out_dims_index;
if (neg_dim_condition) {
out_dims_index = 0;
} else {
int64_t step_size = std::abs(stride_index);
out_dims_index =
(std::abs(end_index - start_index) + step_size - 1) / step_size;
}

bool neg_dim_condition = ((stride_index < 0 && (start_index < end_index)) ||
(stride_index > 0 && (start_index > end_index)));
PADDLE_ENFORCE_EQ(neg_dim_condition,
false,
phi::errors::InvalidArgument(
"The start index and end index are invalid for their "
"corresponding stride."));

int64_t left =
std::max(static_cast<int64_t>(0), std::min(start_index, end_index));
int64_t right = std::min(axis_size, std::max(start_index, end_index));
int64_t step = std::abs(stride_index);

auto out_dims_index = (std::abs(right - left) + step - 1) / step;

out_dims_vector[axes_index] = out_dims_index;
}
}
Expand Down Expand Up @@ -122,21 +113,23 @@ static void StridedSliceFunctor(int64_t* starts,
decrease_axis.begin(), decrease_axis.end(), axes[axis_index]);
if (ret != decrease_axis.end()) {
decrease_axis_affect = true;
starts[axis_index] = axis_size - 1;
ends[axis_index] = axis_size;
}
}
// stride must not be zero
if (starts[axis_index] < 0) {
starts[axis_index] = starts[axis_index] + axis_size;
starts[axis_index] = std::max<int64_t>(starts[axis_index], 0);
}
if (ends[axis_index] < 0) {
if (!(ends[axis_index] == -1 &&
strides[axis_index] < 0)) { // skip None stop condition
ends[axis_index] = ends[axis_index] + axis_size;
if (ends[axis_index] < 0) {
ends[axis_index] = 0;
}
}
bool dummy_zero_dim_out = false;
phi::funcs::normalize_interval(starts[axis_index],
ends[axis_index],
strides[axis_index],
axis_size,
&starts[axis_index],
&ends[axis_index],
&dummy_zero_dim_out);
if (strides[axis_index] < 0 && ends[axis_index] == -axis_size - 1) {
// manually set the end to -1 when step < 0,
// which indicates that it can extend to the left endpoint.
ends[axis_index] = -1;
}
if (decrease_axis_affect) {
if (strides[axis_index] < 0) {
Expand Down
1 change: 0 additions & 1 deletion backends/mlu/tools/disable_ut_mlu
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,3 @@ test_rms_norm_op_mlu
test_sync_batch_norm_op_mlu
test_unsqueeze_op_mlu
test_LeNet_MNIST
test_strided_slice_op_mlu