@@ -23,6 +23,164 @@ namespace phi {
23
23
24
24
namespace funcs {
25
25
26
+ /* *
27
+ * @brief Normalizes the slice interval [st, ed) with a given step and dimension
28
+ * size.
29
+ *
30
+ * This function adjusts the interval [st, ed) to fit within the bounds defined
31
+ * by the dimension size, taking into account the specified step. It handles
32
+ * both positive and negative steps and accounts for negative indices by
33
+ * converting them to equivalent positive indices within the dimension size.
34
+ *
35
+ * @tparam T The data type of the input parameters, which can be an integer or
36
+ * floating-point type.
37
+ * @param st The starting index of the interval.
38
+ * @param ed The ending index of the interval (exclusive).
39
+ * @param step The step size for iterating through the interval, which can be
40
+ * positive or negative.
41
+ * @param dim_size The size of the dimension, serving as the upper bound for
42
+ * valid indices.
43
+ * @param st_out Pointer to store the normalized starting index.
44
+ * @param ed_out Pointer to store the normalized ending index.
45
+ * @param zero_dim_out Pointer to a boolean flag that is set to true if the
46
+ * resulting interval is empty.
47
+ *
48
+ * @details
49
+ * - If `step > 0`, the function ensures that `st` and `ed` are adjusted to be
50
+ * within the range [0, dim_size).
51
+ * - If `step < 0`, the function adjusts `st` and `ed` to accommodate the
52
+ * reverse traversal of the interval.
53
+ * - Handles special cases where `st` and `ed` may be out of bounds or where
54
+ * `dim_size` is zero.
55
+ * - Uses pointer parameters for output to modify the values directly.
56
+ * - The function also handles scenarios involving negative indices, converting
57
+ * them appropriately.
58
+ *
59
+ * @example
60
+ * T st_out, ed_out;
61
+ * bool zero_dim;
62
+ * normalize_interval(-3, -2, 1, 4, &st_out, &ed_out, &zero_dim);
63
+ * // Results in: st_out = 1, ed_out = 2, zero_dim = false
64
+ *
65
+ * @note The function assumes that the pointers provided for output parameters
66
+ * are valid and non-null.
67
+ */
68
+ template <typename T>
69
+ void normalize_interval (
70
+ T st, T ed, T step, T dim_size, T* st_out, T* ed_out, bool * zero_dim_out) {
71
+ /* Normalize slice interval [st, ed) with given step and dim_size.
72
+ e.g. if given st = -3, ed = -2, step = 1, dim_size = 4,
73
+ then normalized st_out = 1(-3+4), st_ed = 2(-2+4).
74
+
75
+ This function is general enough and applicable
76
+ for both step > 0 and step < 0 scenarios.
77
+
78
+ Indicices dipicted as below:
79
+
80
+ ===============================================================
81
+ | 0 1 2 3 ... D-1 | D D+1 ...
82
+ ... -D-2 -D-1 | -D -D+1 -D+2 -D+3 ... -1 |
83
+ ===============================================================
84
+ */
85
+ // 0 dim size, just return
86
+ if (dim_size <= 0 ) {
87
+ *st_out = *ed_out = 0 ;
88
+ *zero_dim_out = true ;
89
+ return ;
90
+ }
91
+
92
+ if (step > 0 ) {
93
+ /* positive step */
94
+ // 0 dim size case 1
95
+ if (st >= dim_size) {
96
+ *st_out = *ed_out = 0 ;
97
+ *zero_dim_out = true ;
98
+ return ;
99
+ }
100
+
101
+ // 0 dim size case 2
102
+ if (ed <= -dim_size) {
103
+ *st_out = *ed_out = 0 ;
104
+ *zero_dim_out = true ;
105
+ return ;
106
+ }
107
+
108
+ // make st belongs: (-inf, -D-1)∪[0, D)
109
+ if (-dim_size <= st && st < 0 ) {
110
+ st += dim_size;
111
+ }
112
+ // make st belongs: [0, D)
113
+ st = std::max (st, static_cast <T>(0 ));
114
+
115
+ // make ed belongs: [0, +inf)
116
+ if (-dim_size <= ed && ed < 0 ) {
117
+ ed += dim_size;
118
+ }
119
+ // make ed belongs: [0, D]
120
+ ed = std::min (ed, dim_size);
121
+
122
+ // 0 dim size case 3
123
+ if (st >= ed) {
124
+ *st_out = *ed_out = 0 ;
125
+ *zero_dim_out = true ;
126
+ return ;
127
+ }
128
+ *st_out = st;
129
+ *ed_out = ed;
130
+ return ;
131
+
132
+ } else {
133
+ /* negative step */
134
+ // 0 dim size case 1
135
+ if (st <= -dim_size - 1 ) {
136
+ *st_out = *ed_out = 0 ;
137
+ *zero_dim_out = true ;
138
+ return ;
139
+ }
140
+
141
+ // 0 dim size case 2
142
+ if (ed >= dim_size - 1 ) {
143
+ *st_out = *ed_out = 0 ;
144
+ *zero_dim_out = true ;
145
+ return ;
146
+ }
147
+
148
+ // make st belongs: [0, D)∪[0, +inf)
149
+ if (-dim_size <= st && st < 0 ) {
150
+ st += dim_size;
151
+ }
152
+ // make st belongs: [0, D)
153
+ st = std::min (st, dim_size - 1 );
154
+
155
+ // make ed belongs: [-inf, -D)∪[0, D)
156
+ if (-dim_size <= ed && ed < 0 ) {
157
+ ed += dim_size;
158
+ }
159
+ // make ed belongs: [-D-1, -D)∪[0, D) ==> {-D-1}∪[0, D)
160
+ ed = std::max (ed, -dim_size - 1 );
161
+
162
+ if (ed == -dim_size - 1 ) {
163
+ // When ed=-D-1, it is symmetrical to when step is greater than 0 and
164
+ // ed=D.
165
+ *st_out = st;
166
+ *ed_out = ed;
167
+ return ;
168
+ }
169
+
170
+ // now only remain the case that ed belongs to: [0, D)
171
+ // 0 dim size case 3
172
+ if (ed >= st) {
173
+ *st_out = *ed_out = 0 ;
174
+ *zero_dim_out = true ;
175
+ return ;
176
+ }
177
+
178
+ *st_out = st;
179
+ *ed_out = ed;
180
+ return ;
181
+ }
182
+ }
183
+
26
184
template <typename T = int64_t >
27
185
inline void CheckAndUpdateSliceAttrs (const DDim in_dims,
28
186
const std::vector<T>& axes,
@@ -56,41 +214,17 @@ inline void CheckAndUpdateSliceAttrs(const DDim in_dims,
56
214
common::errors::InvalidArgument (
57
215
" Step should not be 0, but received step = %d." , step));
58
216
59
- T start = (*starts)[i] < 0 ? ((*starts)[i] + dim_value) : (*starts)[i];
60
- start = std::max (start, static_cast <T>(0 ));
61
-
62
- T end =
63
- 0 < step && (*ends)[i] < 0 ? ((*ends)[i] + dim_value) : (*ends)[i];
64
- end = std::min (end, dim_value);
65
-
66
- if (step > 0 ) {
67
- start = std::min (start, dim_value);
68
- end = std::max (end, static_cast <T>(0 ));
69
- PADDLE_ENFORCE_GE (
70
- end,
71
- start,
72
- common::errors::InvalidArgument (
73
- " When step > 0, end should be greater than start, but "
74
- " received end = %d, start = %d." ,
75
- end,
76
- start));
77
- } else {
78
- // NOTE(liym27): When step < 0, start should less and equal to
79
- // dim_value-1
80
- // "end is -1" means contain the 0-th element of this axis.
81
- start = std::min (start, dim_value - 1 );
82
- if (end < -1 ) {
83
- end += dim_value;
84
- }
85
- end = std::max (end, static_cast <T>(-1 ));
86
- PADDLE_ENFORCE_GE (
87
- start,
88
- end,
89
- common::errors::InvalidArgument (
90
- " When step < 0, start should be greater than end, but "
91
- " received start = %d, end = %d." ,
92
- start,
93
- end));
217
+ T start, end;
218
+ bool dummy_zero_out_dim = false ;
219
+ normalize_interval ((*starts)[i],
220
+ (*ends)[i],
221
+ step,
222
+ dim_value,
223
+ &start,
224
+ &end,
225
+ &dummy_zero_out_dim);
226
+ if (end == -dim_value - 1 ) {
227
+ end = -1 ;
94
228
}
95
229
96
230
(*starts)[i] = start;
@@ -117,24 +251,17 @@ inline void UpdateSliceAttrs(const DDim in_dims,
117
251
T dim_value = in_dims[axis];
118
252
if (dim_value > 0 ) {
119
253
T step = steps == nullptr ? 1 : (*steps)[i];
120
- T start = (*starts)[i] < 0 ? ((*starts)[i] + dim_value) : (*starts)[i];
121
- start = std::max (start, static_cast <T>(0 ));
122
- T end =
123
- 0 < step && (*ends)[i] < 0 ? ((*ends)[i] + dim_value) : (*ends)[i];
124
- end = std::min (end, dim_value);
125
-
126
- if (step > 0 ) {
127
- start = std::min (start, dim_value);
128
- end = std::max (end, static_cast <T>(0 ));
129
- } else {
130
- // NOTE: When step < 0, start should less and equal to
131
- // dim_value-1
132
- // "end is -1" means contain the 0-th element of this axis.
133
- start = std::min (start, dim_value - 1 );
134
- if (end < -1 ) {
135
- end += dim_value;
136
- }
137
- end = std::max (end, static_cast <T>(-1 ));
254
+ T start = (*starts)[i];
255
+ T end = (*ends)[i];
256
+
257
+ bool dummy_zero_out_dim = false ;
258
+ normalize_interval (
259
+ start, end, step, dim_value, &start, &end, &dummy_zero_out_dim);
260
+
261
+ // manually set the end to -1 when step < 0,
262
+ // which indicates that it can extend to the left endpoint.
263
+ if (end == -dim_value - 1 && step < 0 ) {
264
+ end = -1 ;
138
265
}
139
266
(*starts)[i] = start;
140
267
(*ends)[i] = end;
0 commit comments