From c03540e2deef199f8390ae5bbfda869b1a6ce6d2 Mon Sep 17 00:00:00 2001 From: decade-afk <3995409050@qq.com> Date: Sun, 26 Jan 2025 10:54:25 +0800 Subject: [PATCH] update --- paddle/phi/backends/dynload/cusolver.h | 2 ++ paddle/phi/kernels/cpu/lu_solve_kernel.cc | 11 +++++++++-- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/paddle/phi/backends/dynload/cusolver.h b/paddle/phi/backends/dynload/cusolver.h index 74c64085ea721..adbc5cdf0b6e9 100644 --- a/paddle/phi/backends/dynload/cusolver.h +++ b/paddle/phi/backends/dynload/cusolver.h @@ -65,6 +65,8 @@ CUSOLVER_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUSOLVER_WRAP); #if CUDA_VERSION >= 9020 #define CUSOLVER_ROUTINE_EACH_R1(__macro) \ + __macro(cusolverDnSgetrs); \ + __macro(cusolverDnDgetrs); \ __macro(cusolverDnSpotrfBatched); \ __macro(cusolverDnDpotrfBatched); \ __macro(cusolverDnSpotrsBatched); \ diff --git a/paddle/phi/kernels/cpu/lu_solve_kernel.cc b/paddle/phi/kernels/cpu/lu_solve_kernel.cc index 59b8cb20f5b27..1d05365b5a345 100644 --- a/paddle/phi/kernels/cpu/lu_solve_kernel.cc +++ b/paddle/phi/kernels/cpu/lu_solve_kernel.cc @@ -47,11 +47,16 @@ void LuSolveKernel(const Context& dev_ctx, const auto& x_dims = x.dims(); const int64_t nrhs = x_dims[x_dims.size() - 1]; // Number of columns + // Get number of right-hand sides from x + const auto& x_dims = x.dims(); + const int64_t nrhs = x_dims[x_dims.size() - 1]; // Number of columns + // Allocate output tensor dev_ctx.template Alloc(out); // Copy RHS data to output (will be overwritten with solution) - std::copy_n(x.data(), x.numel(), out->data()); + // std::copy_n(x.data(), x.numel(), out->data()); + phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out); // Prepare LAPACK parameters char trans_char = (trans == "N") ? 'N' : ((trans == "T") ? 'T' : 'C'); @@ -64,9 +69,11 @@ void LuSolveKernel(const Context& dev_ctx, auto outdims = out->dims(); auto outrank = outdims.size(); auto batchsize = product(common::slice_ddim(outdims, 0, outrank - 2)); + auto out_data = out->data(); auto lu_data = lu.data(); auto pivots_data = pivots.data(); + for (int i = 0; i < batchsize; i++) { auto out_data_item = &out_data[i * n_int * n_int]; auto* lu_data_item = &lu_data[i * n_int * n_int]; @@ -79,7 +86,7 @@ void LuSolveKernel(const Context& dev_ctx, pivots_data_item, out_data_item, ldb, - info); + *info); PADDLE_ENFORCE_EQ( info,