From 25350da45d5dc9ffbacb5378f901bfc76e1a7c12 Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Fri, 18 Oct 2024 16:09:15 +0800 Subject: [PATCH 1/6] Update Adam.py --- backends/sdaa/sdaa_ext/python/custom_parallel/Adam.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backends/sdaa/sdaa_ext/python/custom_parallel/Adam.py b/backends/sdaa/sdaa_ext/python/custom_parallel/Adam.py index b0064fc71..605fab102 100644 --- a/backends/sdaa/sdaa_ext/python/custom_parallel/Adam.py +++ b/backends/sdaa/sdaa_ext/python/custom_parallel/Adam.py @@ -292,7 +292,7 @@ def _append_optimize_op(self, block, param_and_grad): moment1_ = moment1 moment2_ = moment2 - _, _, _, _, _, _ = paddle._C_ops.adam_( + _, _, _, _, _, *_ = paddle._C_ops.adam_( param_, grad_, lr, From 840429a0f2be5308203bf7fc384262ec2bb02d9a Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Fri, 18 Oct 2024 16:09:36 +0800 Subject: [PATCH 2/6] Update AdamW.py --- backends/sdaa/sdaa_ext/python/custom_parallel/AdamW.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backends/sdaa/sdaa_ext/python/custom_parallel/AdamW.py b/backends/sdaa/sdaa_ext/python/custom_parallel/AdamW.py index 7b966c04a..3206ba05f 100644 --- a/backends/sdaa/sdaa_ext/python/custom_parallel/AdamW.py +++ b/backends/sdaa/sdaa_ext/python/custom_parallel/AdamW.py @@ -294,7 +294,7 @@ def _append_optimize_op(self, block, param_and_grad): moment1_ = moment1 moment2_ = moment2 - _, _, _, _, _, _ = paddle._C_ops.adamw_( + _, _, _, _, _, *_ = paddle._C_ops.adamw_( param_, grad_, lr, From d94d896fa7852ae28827cca185722b94337a673a Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Tue, 17 Dec 2024 16:39:59 +0800 Subject: [PATCH 3/6] fix_slice --- backends/mlu/kernels/slice_kernel.cc | 2 +- backends/mlu/kernels/strided_slice_kernel.cc | 70 ++++++++------------ backends/mlu/tools/disable_ut_mlu | 1 - 3 files changed, 30 insertions(+), 43 deletions(-) diff --git a/backends/mlu/kernels/slice_kernel.cc b/backends/mlu/kernels/slice_kernel.cc index 273548b4d..2bd9d3e35 100644 --- a/backends/mlu/kernels/slice_kernel.cc +++ b/backends/mlu/kernels/slice_kernel.cc @@ -243,7 +243,7 @@ inline void CheckAndUpdateSliceAttrs(const phi::DDim in_dims, &start, &end, &dummy_zero_out_dim); - if (end == -dim_value - 1) { + if (step < 0 && end == -dim_value - 1) { end = -1; } diff --git a/backends/mlu/kernels/strided_slice_kernel.cc b/backends/mlu/kernels/strided_slice_kernel.cc index 32db94920..92f3303b6 100644 --- a/backends/mlu/kernels/strided_slice_kernel.cc +++ b/backends/mlu/kernels/strided_slice_kernel.cc @@ -61,39 +61,27 @@ static void StridedSliceOutDims(const std::vector& starts, continue; } - if (start_index < 0) { - start_index = start_index + axis_size; - start_index = std::max(start_index, 0); - } - if (end_index < 0) { - if (!(end_index == -1 && stride_index < 0)) { // skip None stop condition - end_index = end_index + axis_size; - if (end_index < 0) { - end_index = 0; - } - } + bool neg_dim_condition = false; + normalize_interval(start_index, + end_index, + stride_index, + axis_size, + &start_index, + &end_index, + &neg_dim_condition); + if (stride_index < 0 && end_index == -axis_size - 1) { + end_index = -1; } - if (stride_index < 0) { - start_index = start_index + 1; - end_index = end_index + 1; + int64_t out_dims_index; + if (neg_dim_condition) { + out_dims_index = 0; + } else { + int64_t step_size = std::abs(stride_index); + out_dims_index = + (std::abs(end_index - start_index) + step_size - 1) / step_size; } - bool neg_dim_condition = ((stride_index < 0 && (start_index < end_index)) || - (stride_index > 0 && (start_index > end_index))); - PADDLE_ENFORCE_EQ(neg_dim_condition, - false, - phi::errors::InvalidArgument( - "The start index and end index are invalid for their " - "corresponding stride.")); - - int64_t left = - std::max(static_cast(0), std::min(start_index, end_index)); - int64_t right = std::min(axis_size, std::max(start_index, end_index)); - int64_t step = std::abs(stride_index); - - auto out_dims_index = (std::abs(right - left) + step - 1) / step; - out_dims_vector[axes_index] = out_dims_index; } } @@ -125,18 +113,18 @@ static void StridedSliceFunctor(int64_t* starts, } } // stride must not be zero - if (starts[axis_index] < 0) { - starts[axis_index] = starts[axis_index] + axis_size; - starts[axis_index] = std::max(starts[axis_index], 0); - } - if (ends[axis_index] < 0) { - if (!(ends[axis_index] == -1 && - strides[axis_index] < 0)) { // skip None stop condition - ends[axis_index] = ends[axis_index] + axis_size; - if (ends[axis_index] < 0) { - ends[axis_index] = 0; - } - } + bool dummy_zero_dim_out = false; + normalize_interval(starts[axis_index], + ends[axis_index], + strides[axis_index], + axis_size, + &starts[axis_index], + &ends[axis_index], + &dummy_zero_dim_out); + if (strides[axis_index] < 0 && ends[axis_index] == -axis_size - 1) { + // manually set the end to -1 when step < 0, + // which indicates that it can extend to the left endpoint. + ends[axis_index] = -1; } if (decrease_axis_affect) { if (strides[axis_index] < 0) { diff --git a/backends/mlu/tools/disable_ut_mlu b/backends/mlu/tools/disable_ut_mlu index cfb2175a9..e04306945 100755 --- a/backends/mlu/tools/disable_ut_mlu +++ b/backends/mlu/tools/disable_ut_mlu @@ -14,4 +14,3 @@ test_rms_norm_op_mlu test_sync_batch_norm_op_mlu test_unsqueeze_op_mlu test_LeNet_MNIST -test_strided_slice_op_mlu From 3c756672e6dbdb9686679e48fa8d6dae44929137 Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Tue, 17 Dec 2024 16:51:51 +0800 Subject: [PATCH 4/6] update header --- backends/mlu/kernels/strided_slice_kernel.cc | 29 ++++++++++---------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/backends/mlu/kernels/strided_slice_kernel.cc b/backends/mlu/kernels/strided_slice_kernel.cc index 92f3303b6..352bef548 100644 --- a/backends/mlu/kernels/strided_slice_kernel.cc +++ b/backends/mlu/kernels/strided_slice_kernel.cc @@ -14,6 +14,7 @@ #include "kernels/funcs/mlu_baseop.h" #include "kernels/funcs/mlu_funcs.h" +#include "paddle/phi/kernels/funcs/slice_utils.h" namespace custom_kernel { static void StridedSliceOutDims(const std::vector& starts, @@ -62,13 +63,13 @@ static void StridedSliceOutDims(const std::vector& starts, } bool neg_dim_condition = false; - normalize_interval(start_index, - end_index, - stride_index, - axis_size, - &start_index, - &end_index, - &neg_dim_condition); + phi::funcs::normalize_interval(start_index, + end_index, + stride_index, + axis_size, + &start_index, + &end_index, + &neg_dim_condition); if (stride_index < 0 && end_index == -axis_size - 1) { end_index = -1; } @@ -114,13 +115,13 @@ static void StridedSliceFunctor(int64_t* starts, } // stride must not be zero bool dummy_zero_dim_out = false; - normalize_interval(starts[axis_index], - ends[axis_index], - strides[axis_index], - axis_size, - &starts[axis_index], - &ends[axis_index], - &dummy_zero_dim_out); + phi::funcs::normalize_interval(starts[axis_index], + ends[axis_index], + strides[axis_index], + axis_size, + &starts[axis_index], + &ends[axis_index], + &dummy_zero_dim_out); if (strides[axis_index] < 0 && ends[axis_index] == -axis_size - 1) { // manually set the end to -1 when step < 0, // which indicates that it can extend to the left endpoint. From 52f2c4f1e6d34fdadb49c409e1123ca456552308 Mon Sep 17 00:00:00 2001 From: Your Name Date: Thu, 19 Dec 2024 11:26:22 +0800 Subject: [PATCH 5/6] fix slice --- backends/mlu/kernels/strided_slice_kernel.cc | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/backends/mlu/kernels/strided_slice_kernel.cc b/backends/mlu/kernels/strided_slice_kernel.cc index 352bef548..6396437f7 100644 --- a/backends/mlu/kernels/strided_slice_kernel.cc +++ b/backends/mlu/kernels/strided_slice_kernel.cc @@ -41,6 +41,8 @@ static void StridedSliceOutDims(const std::vector& starts, auto ret = std::find(decrease_axis.begin(), decrease_axis.end(), axes[i]); if (ret != decrease_axis.end()) { decrease_axis_affect = true; + start_index = in_dims[axes_index] - 1; + end_index = in_dims[axes_index]; } } if (decrease_axis_affect) { @@ -111,6 +113,8 @@ static void StridedSliceFunctor(int64_t* starts, decrease_axis.begin(), decrease_axis.end(), axes[axis_index]); if (ret != decrease_axis.end()) { decrease_axis_affect = true; + starts[axis_index] = axis_size - 1; + ends[axis_index] = axis_size; } } // stride must not be zero From b2272c45383dfec8df048a97a468fc56a7e8bd92 Mon Sep 17 00:00:00 2001 From: --global <--global> Date: Thu, 19 Dec 2024 14:07:29 +0800 Subject: [PATCH 6/6] empty