Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Ehance & Fix] Support any slice interval for indexing(__getitem__) in eager/static mode #69827

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions paddle/fluid/pybind/slice_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include "paddle/phi/core/compat/convert_utils.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/funcs/common_infer_shape_functions.h"
#include "paddle/phi/kernels/funcs/strided_slice.h"
#include "pybind11/numpy.h"
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
Expand Down Expand Up @@ -143,11 +144,9 @@ static int _PySlice_GetIndices(PySliceObject* r,
"tensor(int) and numpy(int) in slice item, but received %s.",
std::string(Py_TYPE(r->start)->tp_name)));
}
if (*start < 0) *start += length;
*start = std::max(*start, static_cast<Py_ssize_t>(0));
}
if (r->stop == Py_None) {
*stop = *step < 0 ? -1 : length;
*stop = *step < 0 ? -length - 1 : length;
} else {
if (PyCheckInteger(r->stop) || IsNumpyType(r->stop)) {
*stop = PyLong_AsLong(r->stop);
Expand All @@ -159,9 +158,13 @@ static int _PySlice_GetIndices(PySliceObject* r,
"tensor(int) and numpy(int) in slice item, but received %s.",
std::string(Py_TYPE(r->stop)->tp_name)));
}
if (0 < *step && *stop < 0) *stop += length;
*stop = std::min(*stop, length);
}

// normalize start and stop
bool dummy_zero_dim_out = false;
phi::funcs::normalize_interval(
*start, *stop, *step, length, start, stop, &dummy_zero_dim_out);
// return value below seems to be useless...
if (*stop > length) return -1;
if (*start >= length) return -1;
if (*step == 0) return -1;
Expand Down
233 changes: 180 additions & 53 deletions paddle/phi/kernels/funcs/slice_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,164 @@ namespace phi {

namespace funcs {

/**
* @brief Normalizes the slice interval [st, ed) with a given step and dimension
* size.
*
* This function adjusts the interval [st, ed) to fit within the bounds defined
* by the dimension size, taking into account the specified step. It handles
* both positive and negative steps and accounts for negative indices by
* converting them to equivalent positive indices within the dimension size.
*
* @tparam T The data type of the input parameters, which can be an integer or
* floating-point type.
* @param st The starting index of the interval.
* @param ed The ending index of the interval (exclusive).
* @param step The step size for iterating through the interval, which can be
* positive or negative.
* @param dim_size The size of the dimension, serving as the upper bound for
* valid indices.
* @param st_out Pointer to store the normalized starting index.
* @param ed_out Pointer to store the normalized ending index.
* @param zero_dim_out Pointer to a boolean flag that is set to true if the
* resulting interval is empty.
*
* @details
* - If `step > 0`, the function ensures that `st` and `ed` are adjusted to be
* within the range [0, dim_size).
* - If `step < 0`, the function adjusts `st` and `ed` to accommodate the
* reverse traversal of the interval.
* - Handles special cases where `st` and `ed` may be out of bounds or where
* `dim_size` is zero.
* - Uses pointer parameters for output to modify the values directly.
* - The function also handles scenarios involving negative indices, converting
* them appropriately.
*
* @example
* T st_out, ed_out;
* bool zero_dim;
* normalize_interval(-3, -2, 1, 4, &st_out, &ed_out, &zero_dim);
* // Results in: st_out = 1, ed_out = 2, zero_dim = false
*
* @note The function assumes that the pointers provided for output parameters
* are valid and non-null.
*/
template <typename T>
void normalize_interval(
T st, T ed, T step, T dim_size, T* st_out, T* ed_out, bool* zero_dim_out) {
/* Normalize slice interval [st, ed) with given step and dim_size.
e.g. if given st = -3, ed = -2, step = 1, dim_size = 4,
then normalized st_out = 1(-3+4), st_ed = 2(-2+4).

This function is general enough and applicable
for both step > 0 and step < 0 scenarios.

Indicices dipicted as below:

===============================================================
| 0 1 2 3 ... D-1 | D D+1 ...
... -D-2 -D-1 | -D -D+1 -D+2 -D+3 ... -1 |
===============================================================
*/
// 0 dim size, just return
if (dim_size <= 0) {
*st_out = *ed_out = 0;
*zero_dim_out = true;
return;
}

if (step > 0) {
/* positive step */
// 0 dim size case 1
if (st >= dim_size) {
*st_out = *ed_out = 0;
*zero_dim_out = true;
return;
}

// 0 dim size case 2
if (ed <= -dim_size) {
*st_out = *ed_out = 0;
*zero_dim_out = true;
return;
}

// make st belongs: (-inf, -D-1)∪[0, D)
if (-dim_size <= st && st < 0) {
st += dim_size;
}
// make st belongs: [0, D)
st = std::max(st, static_cast<T>(0));

// make ed belongs: [0, +inf)
if (-dim_size <= ed && ed < 0) {
ed += dim_size;
}
// make ed belongs: [0, D]
ed = std::min(ed, dim_size);

// 0 dim size case 3
if (st >= ed) {
*st_out = *ed_out = 0;
*zero_dim_out = true;
return;
}
*st_out = st;
*ed_out = ed;
return;

} else {
/* negative step */
// 0 dim size case 1
if (st <= -dim_size - 1) {
*st_out = *ed_out = 0;
*zero_dim_out = true;
return;
}

// 0 dim size case 2
if (ed >= dim_size - 1) {
*st_out = *ed_out = 0;
*zero_dim_out = true;
return;
}

// make st belongs: [0, D)∪[0, +inf)
if (-dim_size <= st && st < 0) {
st += dim_size;
}
// make st belongs: [0, D)
st = std::min(st, dim_size - 1);

// make ed belongs: [-inf, -D)∪[0, D)
if (-dim_size <= ed && ed < 0) {
ed += dim_size;
}
// make ed belongs: [-D-1, -D)∪[0, D) ==> {-D-1}∪[0, D)
ed = std::max(ed, -dim_size - 1);

if (ed == -dim_size - 1) {
// When ed=-D-1, it is symmetrical to when step is greater than 0 and
// ed=D.
*st_out = st;
*ed_out = ed;
return;
}

// now only remain the case that ed belongs to: [0, D)
// 0 dim size case 3
if (ed >= st) {
*st_out = *ed_out = 0;
*zero_dim_out = true;
return;
}

*st_out = st;
*ed_out = ed;
return;
}
}

template <typename T = int64_t>
inline void CheckAndUpdateSliceAttrs(const DDim in_dims,
const std::vector<T>& axes,
Expand Down Expand Up @@ -56,41 +214,17 @@ inline void CheckAndUpdateSliceAttrs(const DDim in_dims,
common::errors::InvalidArgument(
"Step should not be 0, but received step = %d.", step));

T start = (*starts)[i] < 0 ? ((*starts)[i] + dim_value) : (*starts)[i];
start = std::max(start, static_cast<T>(0));

T end =
0 < step && (*ends)[i] < 0 ? ((*ends)[i] + dim_value) : (*ends)[i];
end = std::min(end, dim_value);

if (step > 0) {
start = std::min(start, dim_value);
end = std::max(end, static_cast<T>(0));
PADDLE_ENFORCE_GE(
end,
start,
common::errors::InvalidArgument(
"When step > 0, end should be greater than start, but "
"received end = %d, start = %d.",
end,
start));
} else {
// NOTE(liym27): When step < 0, start should less and equal to
// dim_value-1
// "end is -1" means contain the 0-th element of this axis.
start = std::min(start, dim_value - 1);
if (end < -1) {
end += dim_value;
}
end = std::max(end, static_cast<T>(-1));
PADDLE_ENFORCE_GE(
start,
end,
common::errors::InvalidArgument(
"When step < 0, start should be greater than end, but "
"received start = %d, end = %d.",
start,
end));
T start, end;
bool dummy_zero_out_dim = false;
normalize_interval((*starts)[i],
(*ends)[i],
step,
dim_value,
&start,
&end,
&dummy_zero_out_dim);
if (end == -dim_value - 1) {
end = -1;
}

(*starts)[i] = start;
Expand All @@ -117,24 +251,17 @@ inline void UpdateSliceAttrs(const DDim in_dims,
T dim_value = in_dims[axis];
if (dim_value > 0) {
T step = steps == nullptr ? 1 : (*steps)[i];
T start = (*starts)[i] < 0 ? ((*starts)[i] + dim_value) : (*starts)[i];
start = std::max(start, static_cast<T>(0));
T end =
0 < step && (*ends)[i] < 0 ? ((*ends)[i] + dim_value) : (*ends)[i];
end = std::min(end, dim_value);

if (step > 0) {
start = std::min(start, dim_value);
end = std::max(end, static_cast<T>(0));
} else {
// NOTE: When step < 0, start should less and equal to
// dim_value-1
// "end is -1" means contain the 0-th element of this axis.
start = std::min(start, dim_value - 1);
if (end < -1) {
end += dim_value;
}
end = std::max(end, static_cast<T>(-1));
T start = (*starts)[i];
T end = (*ends)[i];

bool dummy_zero_out_dim = false;
normalize_interval(
start, end, step, dim_value, &start, &end, &dummy_zero_out_dim);

// manually set the end to -1 when step < 0,
// which indicates that it can extend to the left endpoint.
if (end == -dim_value - 1 && step < 0) {
end = -1;
}
(*starts)[i] = start;
(*ends)[i] = end;
Expand Down
73 changes: 30 additions & 43 deletions paddle/phi/kernels/funcs/strided_slice.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/slice_utils.h"

namespace phi {
namespace funcs {
Expand Down Expand Up @@ -73,39 +74,26 @@ static void StridedSliceOutDims(const std::vector<int64_t>& starts,
continue;
}

if (start_index < 0) {
start_index = start_index + axis_size;
start_index = std::max<int64_t>(start_index, 0);
}
if (end_index < 0) {
if (!(end_index == -1 && stride_index < 0)) { // skip None stop condition
end_index = end_index + axis_size;
if (end_index < 0) {
end_index = 0;
}
}
bool neg_dim_condition = false;
normalize_interval(start_index,
end_index,
stride_index,
axis_size,
&start_index,
&end_index,
&neg_dim_condition);
if (end_index == -axis_size - 1) {
end_index = -1;
}

if (stride_index < 0) {
start_index = start_index + 1;
end_index = end_index + 1;
int64_t out_dims_index;
if (neg_dim_condition) {
out_dims_index = 0;
} else {
int64_t step_size = std::abs(stride_index);
out_dims_index =
(std::abs(end_index - start_index) + step_size - 1) / step_size;
}

bool neg_dim_condition = ((stride_index < 0 && (start_index < end_index)) ||
(stride_index > 0 && (start_index > end_index)));
PADDLE_ENFORCE_EQ(neg_dim_condition,
false,
errors::InvalidArgument(
"The start index and end index are invalid for their "
"corresponding stride."));

int64_t left =
std::max(static_cast<int64_t>(0), std::min(start_index, end_index));
int64_t right = std::min(axis_size, std::max(start_index, end_index));
int64_t step = std::abs(stride_index);

auto out_dims_index = (std::abs(right - left) + step - 1) / step;

out_dims_vector[axes_index] = out_dims_index;
}
}
Expand Down Expand Up @@ -136,19 +124,18 @@ static void StridedSliceFunctor(int64_t* starts,
decrease_axis_affect = true;
}
}
// stride must not be zero
if (starts[axis_index] < 0) {
starts[axis_index] = starts[axis_index] + axis_size;
starts[axis_index] = std::max<int64_t>(starts[axis_index], 0);
}
if (ends[axis_index] < 0) {
if (!(ends[axis_index] == -1 &&
strides[axis_index] < 0)) { // skip None stop condition
ends[axis_index] = ends[axis_index] + axis_size;
if (ends[axis_index] < 0) {
ends[axis_index] = 0;
}
}
bool dummy_zero_dim_out = false;
normalize_interval(starts[axis_index],
ends[axis_index],
strides[axis_index],
axis_size,
&starts[axis_index],
&ends[axis_index],
&dummy_zero_dim_out);
if (ends[axis_index] == -axis_size - 1) {
// manually set the end to -1 when step < 0,
// which indicates that it can extend to the left endpoint.
ends[axis_index] = -1;
}
if (decrease_axis_affect) {
if (strides[axis_index] < 0) {
Expand Down
Loading