Skip to content

Commit e704480

Browse files
[Ehance & Fix] Support any slice interval for indexing(__getitem__) in eager/static mode (#69827)
* support any slice interval * fix bug * fix more bug
1 parent a4f5446 commit e704480

File tree

8 files changed

+410
-145
lines changed

8 files changed

+410
-145
lines changed

paddle/fluid/pybind/slice_utils.h

+8-5
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
#include "paddle/phi/core/compat/convert_utils.h"
3232
#include "paddle/phi/core/dense_tensor.h"
3333
#include "paddle/phi/kernels/funcs/common_infer_shape_functions.h"
34+
#include "paddle/phi/kernels/funcs/strided_slice.h"
3435
#include "pybind11/numpy.h"
3536
#include "pybind11/pybind11.h"
3637
#include "pybind11/stl.h"
@@ -143,11 +144,9 @@ static int _PySlice_GetIndices(PySliceObject* r,
143144
"tensor(int) and numpy(int) in slice item, but received %s.",
144145
std::string(Py_TYPE(r->start)->tp_name)));
145146
}
146-
if (*start < 0) *start += length;
147-
*start = std::max(*start, static_cast<Py_ssize_t>(0));
148147
}
149148
if (r->stop == Py_None) {
150-
*stop = *step < 0 ? -1 : length;
149+
*stop = *step < 0 ? -length - 1 : length;
151150
} else {
152151
if (PyCheckInteger(r->stop) || IsNumpyType(r->stop)) {
153152
*stop = PyLong_AsLong(r->stop);
@@ -159,9 +158,13 @@ static int _PySlice_GetIndices(PySliceObject* r,
159158
"tensor(int) and numpy(int) in slice item, but received %s.",
160159
std::string(Py_TYPE(r->stop)->tp_name)));
161160
}
162-
if (0 < *step && *stop < 0) *stop += length;
163-
*stop = std::min(*stop, length);
164161
}
162+
163+
// normalize start and stop
164+
bool dummy_zero_dim_out = false;
165+
phi::funcs::normalize_interval(
166+
*start, *stop, *step, length, start, stop, &dummy_zero_dim_out);
167+
// return value below seems to be useless...
165168
if (*stop > length) return -1;
166169
if (*start >= length) return -1;
167170
if (*step == 0) return -1;

paddle/phi/kernels/funcs/slice_utils.h

+180-53
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,164 @@ namespace phi {
2323

2424
namespace funcs {
2525

26+
/**
27+
* @brief Normalizes the slice interval [st, ed) with a given step and dimension
28+
* size.
29+
*
30+
* This function adjusts the interval [st, ed) to fit within the bounds defined
31+
* by the dimension size, taking into account the specified step. It handles
32+
* both positive and negative steps and accounts for negative indices by
33+
* converting them to equivalent positive indices within the dimension size.
34+
*
35+
* @tparam T The data type of the input parameters, which can be an integer or
36+
* floating-point type.
37+
* @param st The starting index of the interval.
38+
* @param ed The ending index of the interval (exclusive).
39+
* @param step The step size for iterating through the interval, which can be
40+
* positive or negative.
41+
* @param dim_size The size of the dimension, serving as the upper bound for
42+
* valid indices.
43+
* @param st_out Pointer to store the normalized starting index.
44+
* @param ed_out Pointer to store the normalized ending index.
45+
* @param zero_dim_out Pointer to a boolean flag that is set to true if the
46+
* resulting interval is empty.
47+
*
48+
* @details
49+
* - If `step > 0`, the function ensures that `st` and `ed` are adjusted to be
50+
* within the range [0, dim_size).
51+
* - If `step < 0`, the function adjusts `st` and `ed` to accommodate the
52+
* reverse traversal of the interval.
53+
* - Handles special cases where `st` and `ed` may be out of bounds or where
54+
* `dim_size` is zero.
55+
* - Uses pointer parameters for output to modify the values directly.
56+
* - The function also handles scenarios involving negative indices, converting
57+
* them appropriately.
58+
*
59+
* @example
60+
* T st_out, ed_out;
61+
* bool zero_dim;
62+
* normalize_interval(-3, -2, 1, 4, &st_out, &ed_out, &zero_dim);
63+
* // Results in: st_out = 1, ed_out = 2, zero_dim = false
64+
*
65+
* @note The function assumes that the pointers provided for output parameters
66+
* are valid and non-null.
67+
*/
68+
template <typename T>
69+
void normalize_interval(
70+
T st, T ed, T step, T dim_size, T* st_out, T* ed_out, bool* zero_dim_out) {
71+
/* Normalize slice interval [st, ed) with given step and dim_size.
72+
e.g. if given st = -3, ed = -2, step = 1, dim_size = 4,
73+
then normalized st_out = 1(-3+4), st_ed = 2(-2+4).
74+
75+
This function is general enough and applicable
76+
for both step > 0 and step < 0 scenarios.
77+
78+
Indicices dipicted as below:
79+
80+
===============================================================
81+
| 0 1 2 3 ... D-1 | D D+1 ...
82+
... -D-2 -D-1 | -D -D+1 -D+2 -D+3 ... -1 |
83+
===============================================================
84+
*/
85+
// 0 dim size, just return
86+
if (dim_size <= 0) {
87+
*st_out = *ed_out = 0;
88+
*zero_dim_out = true;
89+
return;
90+
}
91+
92+
if (step > 0) {
93+
/* positive step */
94+
// 0 dim size case 1
95+
if (st >= dim_size) {
96+
*st_out = *ed_out = 0;
97+
*zero_dim_out = true;
98+
return;
99+
}
100+
101+
// 0 dim size case 2
102+
if (ed <= -dim_size) {
103+
*st_out = *ed_out = 0;
104+
*zero_dim_out = true;
105+
return;
106+
}
107+
108+
// make st belongs: (-inf, -D-1)∪[0, D)
109+
if (-dim_size <= st && st < 0) {
110+
st += dim_size;
111+
}
112+
// make st belongs: [0, D)
113+
st = std::max(st, static_cast<T>(0));
114+
115+
// make ed belongs: [0, +inf)
116+
if (-dim_size <= ed && ed < 0) {
117+
ed += dim_size;
118+
}
119+
// make ed belongs: [0, D]
120+
ed = std::min(ed, dim_size);
121+
122+
// 0 dim size case 3
123+
if (st >= ed) {
124+
*st_out = *ed_out = 0;
125+
*zero_dim_out = true;
126+
return;
127+
}
128+
*st_out = st;
129+
*ed_out = ed;
130+
return;
131+
132+
} else {
133+
/* negative step */
134+
// 0 dim size case 1
135+
if (st <= -dim_size - 1) {
136+
*st_out = *ed_out = 0;
137+
*zero_dim_out = true;
138+
return;
139+
}
140+
141+
// 0 dim size case 2
142+
if (ed >= dim_size - 1) {
143+
*st_out = *ed_out = 0;
144+
*zero_dim_out = true;
145+
return;
146+
}
147+
148+
// make st belongs: [0, D)∪[0, +inf)
149+
if (-dim_size <= st && st < 0) {
150+
st += dim_size;
151+
}
152+
// make st belongs: [0, D)
153+
st = std::min(st, dim_size - 1);
154+
155+
// make ed belongs: [-inf, -D)∪[0, D)
156+
if (-dim_size <= ed && ed < 0) {
157+
ed += dim_size;
158+
}
159+
// make ed belongs: [-D-1, -D)∪[0, D) ==> {-D-1}∪[0, D)
160+
ed = std::max(ed, -dim_size - 1);
161+
162+
if (ed == -dim_size - 1) {
163+
// When ed=-D-1, it is symmetrical to when step is greater than 0 and
164+
// ed=D.
165+
*st_out = st;
166+
*ed_out = ed;
167+
return;
168+
}
169+
170+
// now only remain the case that ed belongs to: [0, D)
171+
// 0 dim size case 3
172+
if (ed >= st) {
173+
*st_out = *ed_out = 0;
174+
*zero_dim_out = true;
175+
return;
176+
}
177+
178+
*st_out = st;
179+
*ed_out = ed;
180+
return;
181+
}
182+
}
183+
26184
template <typename T = int64_t>
27185
inline void CheckAndUpdateSliceAttrs(const DDim in_dims,
28186
const std::vector<T>& axes,
@@ -56,41 +214,17 @@ inline void CheckAndUpdateSliceAttrs(const DDim in_dims,
56214
common::errors::InvalidArgument(
57215
"Step should not be 0, but received step = %d.", step));
58216

59-
T start = (*starts)[i] < 0 ? ((*starts)[i] + dim_value) : (*starts)[i];
60-
start = std::max(start, static_cast<T>(0));
61-
62-
T end =
63-
0 < step && (*ends)[i] < 0 ? ((*ends)[i] + dim_value) : (*ends)[i];
64-
end = std::min(end, dim_value);
65-
66-
if (step > 0) {
67-
start = std::min(start, dim_value);
68-
end = std::max(end, static_cast<T>(0));
69-
PADDLE_ENFORCE_GE(
70-
end,
71-
start,
72-
common::errors::InvalidArgument(
73-
"When step > 0, end should be greater than start, but "
74-
"received end = %d, start = %d.",
75-
end,
76-
start));
77-
} else {
78-
// NOTE(liym27): When step < 0, start should less and equal to
79-
// dim_value-1
80-
// "end is -1" means contain the 0-th element of this axis.
81-
start = std::min(start, dim_value - 1);
82-
if (end < -1) {
83-
end += dim_value;
84-
}
85-
end = std::max(end, static_cast<T>(-1));
86-
PADDLE_ENFORCE_GE(
87-
start,
88-
end,
89-
common::errors::InvalidArgument(
90-
"When step < 0, start should be greater than end, but "
91-
"received start = %d, end = %d.",
92-
start,
93-
end));
217+
T start, end;
218+
bool dummy_zero_out_dim = false;
219+
normalize_interval((*starts)[i],
220+
(*ends)[i],
221+
step,
222+
dim_value,
223+
&start,
224+
&end,
225+
&dummy_zero_out_dim);
226+
if (end == -dim_value - 1) {
227+
end = -1;
94228
}
95229

96230
(*starts)[i] = start;
@@ -117,24 +251,17 @@ inline void UpdateSliceAttrs(const DDim in_dims,
117251
T dim_value = in_dims[axis];
118252
if (dim_value > 0) {
119253
T step = steps == nullptr ? 1 : (*steps)[i];
120-
T start = (*starts)[i] < 0 ? ((*starts)[i] + dim_value) : (*starts)[i];
121-
start = std::max(start, static_cast<T>(0));
122-
T end =
123-
0 < step && (*ends)[i] < 0 ? ((*ends)[i] + dim_value) : (*ends)[i];
124-
end = std::min(end, dim_value);
125-
126-
if (step > 0) {
127-
start = std::min(start, dim_value);
128-
end = std::max(end, static_cast<T>(0));
129-
} else {
130-
// NOTE: When step < 0, start should less and equal to
131-
// dim_value-1
132-
// "end is -1" means contain the 0-th element of this axis.
133-
start = std::min(start, dim_value - 1);
134-
if (end < -1) {
135-
end += dim_value;
136-
}
137-
end = std::max(end, static_cast<T>(-1));
254+
T start = (*starts)[i];
255+
T end = (*ends)[i];
256+
257+
bool dummy_zero_out_dim = false;
258+
normalize_interval(
259+
start, end, step, dim_value, &start, &end, &dummy_zero_out_dim);
260+
261+
// manually set the end to -1 when step < 0,
262+
// which indicates that it can extend to the left endpoint.
263+
if (end == -dim_value - 1 && step < 0) {
264+
end = -1;
138265
}
139266
(*starts)[i] = start;
140267
(*ends)[i] = end;

paddle/phi/kernels/funcs/strided_slice.h

+30-43
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#include "paddle/phi/kernels/funcs/eigen/common.h"
2626
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
2727
#include "paddle/phi/kernels/funcs/math_function.h"
28+
#include "paddle/phi/kernels/funcs/slice_utils.h"
2829

2930
namespace phi {
3031
namespace funcs {
@@ -73,39 +74,26 @@ static void StridedSliceOutDims(const std::vector<int64_t>& starts,
7374
continue;
7475
}
7576

76-
if (start_index < 0) {
77-
start_index = start_index + axis_size;
78-
start_index = std::max<int64_t>(start_index, 0);
79-
}
80-
if (end_index < 0) {
81-
if (!(end_index == -1 && stride_index < 0)) { // skip None stop condition
82-
end_index = end_index + axis_size;
83-
if (end_index < 0) {
84-
end_index = 0;
85-
}
86-
}
77+
bool neg_dim_condition = false;
78+
normalize_interval(start_index,
79+
end_index,
80+
stride_index,
81+
axis_size,
82+
&start_index,
83+
&end_index,
84+
&neg_dim_condition);
85+
if (end_index == -axis_size - 1) {
86+
end_index = -1;
8787
}
8888

89-
if (stride_index < 0) {
90-
start_index = start_index + 1;
91-
end_index = end_index + 1;
89+
int64_t out_dims_index;
90+
if (neg_dim_condition) {
91+
out_dims_index = 0;
92+
} else {
93+
int64_t step_size = std::abs(stride_index);
94+
out_dims_index =
95+
(std::abs(end_index - start_index) + step_size - 1) / step_size;
9296
}
93-
94-
bool neg_dim_condition = ((stride_index < 0 && (start_index < end_index)) ||
95-
(stride_index > 0 && (start_index > end_index)));
96-
PADDLE_ENFORCE_EQ(neg_dim_condition,
97-
false,
98-
errors::InvalidArgument(
99-
"The start index and end index are invalid for their "
100-
"corresponding stride."));
101-
102-
int64_t left =
103-
std::max(static_cast<int64_t>(0), std::min(start_index, end_index));
104-
int64_t right = std::min(axis_size, std::max(start_index, end_index));
105-
int64_t step = std::abs(stride_index);
106-
107-
auto out_dims_index = (std::abs(right - left) + step - 1) / step;
108-
10997
out_dims_vector[axes_index] = out_dims_index;
11098
}
11199
}
@@ -136,19 +124,18 @@ static void StridedSliceFunctor(int64_t* starts,
136124
decrease_axis_affect = true;
137125
}
138126
}
139-
// stride must not be zero
140-
if (starts[axis_index] < 0) {
141-
starts[axis_index] = starts[axis_index] + axis_size;
142-
starts[axis_index] = std::max<int64_t>(starts[axis_index], 0);
143-
}
144-
if (ends[axis_index] < 0) {
145-
if (!(ends[axis_index] == -1 &&
146-
strides[axis_index] < 0)) { // skip None stop condition
147-
ends[axis_index] = ends[axis_index] + axis_size;
148-
if (ends[axis_index] < 0) {
149-
ends[axis_index] = 0;
150-
}
151-
}
127+
bool dummy_zero_dim_out = false;
128+
normalize_interval(starts[axis_index],
129+
ends[axis_index],
130+
strides[axis_index],
131+
axis_size,
132+
&starts[axis_index],
133+
&ends[axis_index],
134+
&dummy_zero_dim_out);
135+
if (ends[axis_index] == -axis_size - 1) {
136+
// manually set the end to -1 when step < 0,
137+
// which indicates that it can extend to the left endpoint.
138+
ends[axis_index] = -1;
152139
}
153140
if (decrease_axis_affect) {
154141
if (strides[axis_index] < 0) {

0 commit comments

Comments
 (0)