diff --git a/backends/mlu/kernels/slice_kernel.cc b/backends/mlu/kernels/slice_kernel.cc index c3b7be68b..273548b4d 100644 --- a/backends/mlu/kernels/slice_kernel.cc +++ b/backends/mlu/kernels/slice_kernel.cc @@ -17,6 +17,164 @@ namespace custom_kernel { +/** + * @brief Normalizes the slice interval [st, ed) with a given step and dimension + * size. + * + * This function adjusts the interval [st, ed) to fit within the bounds defined + * by the dimension size, taking into account the specified step. It handles + * both positive and negative steps and accounts for negative indices by + * converting them to equivalent positive indices within the dimension size. + * + * @tparam T The data type of the input parameters, which can be an integer or + * floating-point type. + * @param st The starting index of the interval. + * @param ed The ending index of the interval (exclusive). + * @param step The step size for iterating through the interval, which can be + * positive or negative. + * @param dim_size The size of the dimension, serving as the upper bound for + * valid indices. + * @param st_out Pointer to store the normalized starting index. + * @param ed_out Pointer to store the normalized ending index. + * @param zero_dim_out Pointer to a boolean flag that is set to true if the + * resulting interval is empty. + * + * @details + * - If `step > 0`, the function ensures that `st` and `ed` are adjusted to be + * within the range [0, dim_size). + * - If `step < 0`, the function adjusts `st` and `ed` to accommodate the + * reverse traversal of the interval. + * - Handles special cases where `st` and `ed` may be out of bounds or where + * `dim_size` is zero. + * - Uses pointer parameters for output to modify the values directly. + * - The function also handles scenarios involving negative indices, converting + * them appropriately. + * + * @example + * T st_out, ed_out; + * bool zero_dim; + * normalize_interval(-3, -2, 1, 4, &st_out, &ed_out, &zero_dim); + * // Results in: st_out = 1, ed_out = 2, zero_dim = false + * + * @note The function assumes that the pointers provided for output parameters + * are valid and non-null. + */ +template +void normalize_interval( + T st, T ed, T step, T dim_size, T* st_out, T* ed_out, bool* zero_dim_out) { + /* Normalize slice interval [st, ed) with given step and dim_size. + e.g. if given st = -3, ed = -2, step = 1, dim_size = 4, + then normalized st_out = 1(-3+4), st_ed = 2(-2+4). + + This function is general enough and applicable + for both step > 0 and step < 0 scenarios. + + Indicices dipicted as below: + + =============================================================== + | 0 1 2 3 ... D-1 | D D+1 ... + ... -D-2 -D-1 | -D -D+1 -D+2 -D+3 ... -1 | + =============================================================== + */ + // 0 dim size, just return + if (dim_size <= 0) { + *st_out = *ed_out = 0; + *zero_dim_out = true; + return; + } + + if (step > 0) { + /* positive step */ + // 0 dim size case 1 + if (st >= dim_size) { + *st_out = *ed_out = 0; + *zero_dim_out = true; + return; + } + + // 0 dim size case 2 + if (ed <= -dim_size) { + *st_out = *ed_out = 0; + *zero_dim_out = true; + return; + } + + // make st belongs: (-inf, -D-1)∪[0, D) + if (-dim_size <= st && st < 0) { + st += dim_size; + } + // make st belongs: [0, D) + st = std::max(st, static_cast(0)); + + // make ed belongs: [0, +inf) + if (-dim_size <= ed && ed < 0) { + ed += dim_size; + } + // make ed belongs: [0, D] + ed = std::min(ed, dim_size); + + // 0 dim size case 3 + if (st >= ed) { + *st_out = *ed_out = 0; + *zero_dim_out = true; + return; + } + *st_out = st; + *ed_out = ed; + return; + + } else { + /* negative step */ + // 0 dim size case 1 + if (st <= -dim_size - 1) { + *st_out = *ed_out = 0; + *zero_dim_out = true; + return; + } + + // 0 dim size case 2 + if (ed >= dim_size - 1) { + *st_out = *ed_out = 0; + *zero_dim_out = true; + return; + } + + // make st belongs: [0, D)∪[0, +inf) + if (-dim_size <= st && st < 0) { + st += dim_size; + } + // make st belongs: [0, D) + st = std::min(st, dim_size - 1); + + // make ed belongs: [-inf, -D)∪[0, D) + if (-dim_size <= ed && ed < 0) { + ed += dim_size; + } + // make ed belongs: [-D-1, -D)∪[0, D) ==> {-D-1}∪[0, D) + ed = std::max(ed, -dim_size - 1); + + if (ed == -dim_size - 1) { + // When ed=-D-1, it is symmetrical to when step is greater than 0 and + // ed=D. + *st_out = st; + *ed_out = ed; + return; + } + + // now only remain the case that ed belongs to: [0, D) + // 0 dim size case 3 + if (ed >= st) { + *st_out = *ed_out = 0; + *zero_dim_out = true; + return; + } + + *st_out = st; + *ed_out = ed; + return; + } +} + void UpdateAttr(const phi::DDim& in_dims, const std::vector axes, const std::vector starts, @@ -76,47 +234,17 @@ inline void CheckAndUpdateSliceAttrs(const phi::DDim in_dims, if (dim_value > 0) { T step = steps == nullptr ? 1 : (*steps)[i]; - PADDLE_ENFORCE_NE( - step, - 0, - phi::errors::InvalidArgument( - "Step should not be 0, but received step = %d.", step)); - - T start = (*starts)[i] < 0 ? ((*starts)[i] + dim_value) : (*starts)[i]; - start = std::max(start, static_cast(0)); - - T end = - 0 < step && (*ends)[i] < 0 ? ((*ends)[i] + dim_value) : (*ends)[i]; - end = std::min(end, dim_value); - - if (step > 0) { - start = std::min(start, dim_value); - end = std::max(end, static_cast(0)); - PADDLE_ENFORCE_GE( - end, - start, - phi::errors::InvalidArgument( - "When step > 0, end should be greater than start, but " - "received end = %d, start = %d.", - end, - start)); - } else { - // NOTE(liym27): When step < 0, start should less and equal to - // dim_value-1 - // "end is -1" means contain the 0-th element of this axis. - start = std::min(start, dim_value - 1); - if (end < -1) { - end += dim_value; - } - end = std::max(end, static_cast(-1)); - PADDLE_ENFORCE_GE( - start, - end, - phi::errors::InvalidArgument( - "When step < 0, start should be greater than end, but " - "received start = %d, end = %d.", - start, - end)); + T start, end; + bool dummy_zero_out_dim = false; + normalize_interval((*starts)[i], + (*ends)[i], + step, + dim_value, + &start, + &end, + &dummy_zero_out_dim); + if (end == -dim_value - 1) { + end = -1; } (*starts)[i] = start; diff --git a/backends/mlu/tests/unittests/test_multinomial_op_mlu.py b/backends/mlu/tests/unittests/test_multinomial_op_mlu.py index 7a320c900..2170103b7 100644 --- a/backends/mlu/tests/unittests/test_multinomial_op_mlu.py +++ b/backends/mlu/tests/unittests/test_multinomial_op_mlu.py @@ -287,11 +287,6 @@ def test_dim_less_than_1(): self.assertRaises(ValueError, test_dim_less_than_1) - with self.assertRaises(ValueError): - prob = paddle.rand([20, 1000]) - prob[1:0] = 0 - out = paddle.multinomial(prob) - if __name__ == "__main__": unittest.main() diff --git a/backends/npu/tests/unittests/test_multinomial_op_npu.py b/backends/npu/tests/unittests/test_multinomial_op_npu.py index 45173e359..57bd44670 100644 --- a/backends/npu/tests/unittests/test_multinomial_op_npu.py +++ b/backends/npu/tests/unittests/test_multinomial_op_npu.py @@ -287,11 +287,6 @@ def test_dim_less_than_1(): self.assertRaises(ValueError, test_dim_less_than_1) - with self.assertRaises(ValueError): - prob = paddle.rand([20, 1000]) - prob[1:0] = 0 - out = paddle.multinomial(prob) - if __name__ == "__main__": unittest.main()