diff --git a/.prospector.yml b/.prospector.yml index e8a1a392..63fa1236 100644 --- a/.prospector.yml +++ b/.prospector.yml @@ -27,8 +27,8 @@ pep8: pylint: run: true options: - max-args: 10 - max-positional-arguments: 10 + max-args: 15 + max-positional-arguments: 15 disable: - import-error - import-outside-toplevel diff --git a/Makefile b/Makefile index b5f43c55..04a9baa6 100644 --- a/Makefile +++ b/Makefile @@ -23,6 +23,8 @@ build_cmake: clean build_wheel: export CMAKE_GENERATOR="Unix Makefiles" && pip wheel -v . --extra-index-url https://download.pytorch.org/whl/cpu +build_cpu: + python scripts/build_cpu.py # 'make compile_abs' compiles 'kernel_abs.cpp' into 'libkernel_abs.so' without building the whole wheel package. # This is useful for development and debugging of individual kernels. @@ -47,3 +49,6 @@ test: test_tri_inv: pytest tests/test_tri_inv_*.py + +test_cpu: + python scripts/test_cpu.py diff --git a/README.md b/README.md index fffb7f74..d68787d3 100644 --- a/README.md +++ b/README.md @@ -58,6 +58,21 @@ make test --- +## CPU Simulation + +A subset of the kernels supports CPU simulation. The supported kernels can be built with + +```bash +make build_cpu +``` +and tested with + +```bash +make test_cpu +``` + +--- + ## Repository Structure ``` diff --git a/csrc/host/pybind11.cpp b/csrc/host/pybind11.cpp index 248cce22..1033241a 100644 --- a/csrc/host/pybind11.cpp +++ b/csrc/host/pybind11.cpp @@ -11,26 +11,36 @@ for the full License text. 
#include "torch_abs.h" #include "torch_batch_matrix_square.h" -#include "torch_chunk_cumsum.h" #include "torch_csr_gather.h" -#include "torch_gdn_chunk_h.h" -#include "torch_gdn_chunk_o.h" -#include "torch_gdn_scaled_dot_kkt.h" -#include "torch_gdn_wy_fast.h" #include "torch_scan_ul1.h" #include "torch_simple_matmul.h" #include "torch_swiglu.h" #include "torch_tri_inv.h" #include "torch_tri_inv_ns.h" -#include "torch_tri_inv_rec_unroll.h" #include "torch_tri_inv_trick.h" +#ifndef __CPU_SIM +#include "torch_chunk_cumsum.h" +#include "torch_gdn_chunk_h.h" +#include "torch_gdn_chunk_o.h" +#include "torch_gdn_scaled_dot_kkt.h" +#include "torch_gdn_wy_fast.h" +#include "torch_tri_inv_rec_unroll.h" +#endif using namespace pto_isa_ops; +// Not really needed, but to ensure different modules in PyTorch and avoid any +// potential name clashing. This way, both can coexist in a single runtime. +#ifdef __CPU_SIM +#define MODULE_NAME pto_kernels_cpu +#else +#define MODULE_NAME pto_kernels_ops +#endif + /** * @brief Pybind11 module. 
*/ -PYBIND11_MODULE(pto_kernels_ops, m) { +PYBIND11_MODULE(MODULE_NAME, m) { m.doc() = "PTO-ISA Kernels"; m.def( "get_aic_cores", @@ -43,6 +53,24 @@ PYBIND11_MODULE(pto_kernels_ops, m) { }, pybind11::arg("device_id") = 0); m.def("pto_abs", &pto_isa_ops::run_abs); + m.def("pto_batch_matrix_square", &pto_isa_ops::run_batch_matrix_square); + m.def("pto_csr_gather", &pto_isa_ops::run_csr_gather); + m.def("pto_scan_ul1", &pto_isa_ops::run_scan_ul1); + m.def("pto_simple_matmul", &pto_isa_ops::run_simple_matmul); + m.def("pto_swiglu", &pto_isa_ops::run_swiglu, py::arg("x"), + py::arg("dim") = -1); + m.def("pto_tri_inv_trick", &pto_isa_ops::run_tri_inv_trick); + m.def("pto_tri_inv_ns", &pto_isa_ops::run_tri_inv_ns, py::arg("M"), + py::arg("num_iters") = 0, py::arg("scale_value") = 0.0f); + m.def("pto_tri_inv", &pto_isa_ops::run_tri_inv); +#ifndef __CPU_SIM + m.def("pto_tri_inv_rec_unroll", &pto_isa_ops::run_tri_inv_rec_unroll, + py::arg("M"), py::arg("cu_seqlens") = at::zeros({1}), + py::arg("is_bsnd_format") = false, + py::arg("dtype_out") = at::ScalarType::Half); + m.def("pto_tri_inv_ns", &pto_isa_ops::run_tri_inv_ns, py::arg("M"), + py::arg("num_iters") = 0, py::arg("scale_value") = 0.0f); + m.def("pto_tri_inv", &pto_isa_ops::run_tri_inv); m.def("pto_chunk_h", &pto_isa_ops::run_gdn_chunk_h, py::arg("K"), py::arg("W"), py::arg("U"), py::arg("G"), py::arg("cu_seqlens") = at::zeros({1}), py::arg("batch_size"), @@ -50,8 +78,6 @@ PYBIND11_MODULE(pto_kernels_ops, m) { m.def("pto_chunk_cumsum", &pto_isa_ops::run_chunk_cumsum, py::arg("g"), py::arg("batch_size"), py::arg("seq_len"), py::arg("cu_seqlens") = at::zeros({1})); - m.def("pto_batch_matrix_square", &pto_isa_ops::run_batch_matrix_square); - m.def("pto_csr_gather", &pto_isa_ops::run_csr_gather); m.def("pto_gdn_scaled_dot_kkt", &pto_isa_ops::run_gdn_scaled_dot_kkt, py::arg("K"), py::arg("Beta"), py::arg("G"), py::arg("Msk"), py::arg("batch_size"), py::arg("seq_len"), @@ -60,20 +86,9 @@ PYBIND11_MODULE(pto_kernels_ops, m) 
{ py::arg("K"), py::arg("V"), py::arg("S"), py::arg("G"), py::arg("Msk"), py::arg("batch_size"), py::arg("seq_len"), py::arg("cu_seqlens") = at::zeros({1})); - m.def("pto_scan_ul1", &pto_isa_ops::run_scan_ul1); - m.def("pto_simple_matmul", &pto_isa_ops::run_simple_matmul); - m.def("pto_swiglu", &pto_isa_ops::run_swiglu, py::arg("x"), - py::arg("dim") = -1); - m.def("pto_tri_inv_trick", &pto_isa_ops::run_tri_inv_trick); - m.def("pto_tri_inv_rec_unroll", &pto_isa_ops::run_tri_inv_rec_unroll, - py::arg("M"), py::arg("cu_seqlens") = at::zeros({1}), - py::arg("is_bsnd_format") = false, - py::arg("dtype_out") = at::ScalarType::Half); - m.def("pto_tri_inv_ns", &pto_isa_ops::run_tri_inv_ns, py::arg("M"), - py::arg("num_iters") = 0, py::arg("scale_value") = 0.0f); - m.def("pto_tri_inv", &pto_isa_ops::run_tri_inv); m.def("pto_gdn_wy_fast", &pto_isa_ops::run_gdn_wy_fast, py::arg("K"), py::arg("V"), py::arg("Beta"), py::arg("G"), py::arg("A"), py::arg("batch_size"), py::arg("seq_len"), py::arg("cu_seqlens") = at::zeros({1})); +#endif } diff --git a/csrc/host/torch_abs.h b/csrc/host/torch_abs.h index 32d607b2..0fbd5406 100644 --- a/csrc/host/torch_abs.h +++ b/csrc/host/torch_abs.h @@ -11,9 +11,14 @@ for the full License text. 
#include #include +#ifdef __CPU_SIM +#include "../kernel/kernel_abs.h" +#include "utils_cpu.h" +#else #include "aclrtlaunch_vabs_fp16.h" #include "aclrtlaunch_vabs_fp32.h" #include "utils.h" +#endif namespace pto_isa_ops { @@ -40,8 +45,10 @@ at::Tensor run_abs(const at::Tensor& x) { block_dim = total_tiles; } +#ifndef __CPU_SIM TORCH_CHECK(x.device().type() == DEVICE_TYPE, "pto_abs: tensor must be on NPU, got ", x.device()); +#endif TORCH_CHECK(dtype == at::kHalf || dtype == at::kFloat, "pto_abs: dtype must be fp16 or float32, got ", dtype); if (dtype == at::kHalf) { diff --git a/csrc/host/torch_batch_matrix_square.h b/csrc/host/torch_batch_matrix_square.h index c72b0d3f..17a86f1d 100644 --- a/csrc/host/torch_batch_matrix_square.h +++ b/csrc/host/torch_batch_matrix_square.h @@ -11,9 +11,14 @@ for the full License text. #include #include +#ifdef __CPU_SIM +#include "../kernel/kernel_batch_matrix_square.h" +#include "utils_cpu.h" +#else #include "aclrtlaunch_batch_matrix_square_fp16.h" #include "aclrtlaunch_batch_matrix_square_fp32.h" #include "utils.h" +#endif namespace pto_isa_ops { @@ -29,8 +34,10 @@ at::Tensor run_batch_matrix_square(const at::Tensor& x) { const auto dtype = x.options().dtype(); const auto dtype_out = at::kFloat; +#ifndef __CPU_SIM TORCH_CHECK(device.type() == DEVICE_TYPE, "batch_matrix_square: tensor must be on NPU, got ", device); +#endif TORCH_CHECK(dtype == at::kHalf || dtype == at::kFloat, "batch_matrix_square: dtype must be fp16 or float32, got ", dtype); diff --git a/csrc/host/torch_csr_gather.h b/csrc/host/torch_csr_gather.h index a1096e9e..1bb1e7c3 100644 --- a/csrc/host/torch_csr_gather.h +++ b/csrc/host/torch_csr_gather.h @@ -11,9 +11,14 @@ for the full License text. 
#include #include +#ifdef __CPU_SIM +#include "../kernel/kernel_csr_gather.h" +#include "utils_cpu.h" +#else #include "aclrtlaunch_csr_gather_fp16.h" #include "aclrtlaunch_csr_gather_fp32.h" #include "utils.h" +#endif namespace pto_isa_ops { @@ -52,12 +57,14 @@ at::Tensor run_csr_gather(const at::Tensor& values, const at::Tensor& indices, block_dim = total_tiles; } +#ifndef __CPU_SIM TORCH_CHECK(values.device().type() == DEVICE_TYPE, "csr_gather: tensor must be on NPU, got ", values.device()); TORCH_CHECK(indices.device().type() == DEVICE_TYPE, "csr_gather: tensor must be on NPU, got ", indices.device()); TORCH_CHECK(x.device().type() == DEVICE_TYPE, "csr_gather: tensor must be on NPU, got ", x.device()); +#endif TORCH_CHECK(dtype == at::kHalf || dtype == at::kFloat, "csr_gather: dtype must be fp16 or float32, got ", dtype); if (dtype == at::kHalf) { diff --git a/csrc/host/torch_scan_ul1.h b/csrc/host/torch_scan_ul1.h index c02b93bb..5be64af2 100644 --- a/csrc/host/torch_scan_ul1.h +++ b/csrc/host/torch_scan_ul1.h @@ -11,9 +11,14 @@ for the full License text. #include #include +#ifdef __CPU_SIM +#include "../kernel/kernel_scan_ul1.h" +#include "utils_cpu.h" +#else #include "aclrtlaunch_scan_ul1_fp16.h" #include "aclrtlaunch_scan_ul1_fp32.h" #include "utils.h" +#endif namespace pto_isa_ops { @@ -28,8 +33,10 @@ at::Tensor run_scan_ul1(const at::Tensor& x) { const auto dtype = x.options().dtype(); const auto dtype_out = at::kFloat; +#ifndef __CPU_SIM TORCH_CHECK(device.type() == DEVICE_TYPE, "scan_ul1: tensor must be on NPU, got ", device); +#endif TORCH_CHECK(dtype == at::kHalf || dtype == at::kFloat, "scan_ul1: dtype must be fp16 or float32, got ", dtype); diff --git a/csrc/host/torch_simple_matmul.h b/csrc/host/torch_simple_matmul.h index bd3c7b68..2604ee95 100644 --- a/csrc/host/torch_simple_matmul.h +++ b/csrc/host/torch_simple_matmul.h @@ -11,10 +11,15 @@ for the full License text. 
#include #include +#ifdef __CPU_SIM +#include "../kernel/kernel_simple_matmul.h" +#include "utils_cpu.h" +#else #include "aclrtlaunch_simple_matmul_bf16.h" #include "aclrtlaunch_simple_matmul_fp16.h" #include "aclrtlaunch_simple_matmul_fp32.h" #include "utils.h" +#endif namespace pto_isa_ops { @@ -30,10 +35,12 @@ at::Tensor run_simple_matmul(const at::Tensor& a, const at::Tensor& b) { const auto dtype = a.options().dtype(); const auto dtype_out = at::kFloat; +#ifndef __CPU_SIM TORCH_CHECK(device.type() == DEVICE_TYPE, "simple_matmul: tensor must be on NPU, got ", device); TORCH_CHECK(b.device().type() == DEVICE_TYPE, "simple_matmul: tensor must be on NPU, got ", b.device()); +#endif TORCH_CHECK( dtype == at::kHalf || dtype == at::kFloat || dtype == at::kBFloat16, "simple_matmul: dtype must be fp16, bf16, or float32, got ", dtype); diff --git a/csrc/host/torch_swiglu.h b/csrc/host/torch_swiglu.h index a965c159..40a1d6a7 100644 --- a/csrc/host/torch_swiglu.h +++ b/csrc/host/torch_swiglu.h @@ -13,8 +13,13 @@ for the full License text. #include +#ifdef __CPU_SIM +#include "../kernel/kernel_swiglu.h" +#include "utils_cpu.h" +#else #include "aclrtlaunch_swiglu_fp16.h" #include "utils.h" +#endif namespace pto_isa_ops { @@ -35,8 +40,10 @@ at::Tensor run_swiglu(const at::Tensor& x, int64_t dim = -1) { dim += x.dim(); } TORCH_CHECK(dim == 1, "swiglu: currently supports only dim=-1"); +#ifndef __CPU_SIM TORCH_CHECK(x.device().type() == DEVICE_TYPE, "swiglu: tensor must be on NPU, got ", x.device()); +#endif TORCH_CHECK(x.scalar_type() == at::kHalf, "swiglu: dtype must be fp16, got ", x.scalar_type()); TORCH_CHECK(x.is_contiguous(), "swiglu: expects a contiguous input tensor"); diff --git a/csrc/host/torch_tri_inv.h b/csrc/host/torch_tri_inv.h index 302204f4..d85479a1 100644 --- a/csrc/host/torch_tri_inv.h +++ b/csrc/host/torch_tri_inv.h @@ -11,9 +11,14 @@ for the full License text. 
#include #include +#ifdef __CPU_SIM +#include "../kernel/kernel_tri_inv_col_sweep.h" +#include "utils_cpu.h" +#else #include "aclrtlaunch_triv_inv_col_sweep_fp16.h" #include "aclrtlaunch_triv_inv_col_sweep_fp32.h" #include "utils.h" +#endif namespace pto_isa_ops { @@ -28,8 +33,10 @@ namespace pto_isa_ops { at::Tensor run_tri_inv(const at::Tensor& x) { const at::Device device = x.options().device(); const auto dtype = x.options().dtype(); +#ifndef __CPU_SIM TORCH_CHECK(device.type() == DEVICE_TYPE, "tri_inv: tensor must be on NPU, got ", device); +#endif TORCH_CHECK(x.dim() >= 2, "tri_inv: input tensor must have at least 2 dimensions, got ", x.dim()); diff --git a/csrc/host/torch_tri_inv_ns.h b/csrc/host/torch_tri_inv_ns.h index fafdd3b5..71facee2 100644 --- a/csrc/host/torch_tri_inv_ns.h +++ b/csrc/host/torch_tri_inv_ns.h @@ -13,8 +13,13 @@ for the full License text. #include +#ifdef __CPU_SIM +#include "../kernel/kernel_tri_inv_ns.h" +#include "utils_cpu.h" +#else #include "aclrtlaunch_tri_inv_ns_fp16.h" #include "utils.h" +#endif namespace pto_isa_ops { @@ -43,8 +48,10 @@ at::Tensor run_tri_inv_ns(const at::Tensor& M, uint32_t num_iters = 0, const auto dtype = M.options().dtype(); const auto dtype_out = at::kFloat; +#ifndef __CPU_SIM TORCH_CHECK(device.type() == DEVICE_TYPE, "tri_inv_ns: tensor must be on NPU, got ", device); +#endif TORCH_CHECK(dtype == at::kHalf, "tri_inv_ns: dtype must be fp16, got ", dtype); const uint32_t n = static_cast(M.size(-1)); diff --git a/csrc/host/torch_tri_inv_rec_unroll.h b/csrc/host/torch_tri_inv_rec_unroll.h index 10caab36..9b94f444 100644 --- a/csrc/host/torch_tri_inv_rec_unroll.h +++ b/csrc/host/torch_tri_inv_rec_unroll.h @@ -11,9 +11,14 @@ for the full License text. 
#include #include +#ifdef __CPU_SIM +#include "../kernel/kernel_tri_inv_rec_unroll.h" +#include "utils_cpu.h" +#else #include "aclrtlaunch_tri_inv_rec_unroll_fp16fp16.h" #include "aclrtlaunch_tri_inv_rec_unroll_fp16fp32.h" #include "utils.h" +#endif namespace pto_isa_ops { @@ -41,8 +46,10 @@ at::Tensor run_tri_inv_rec_unroll( const at::Device device = M.options().device(); const auto dtype = M.options().dtype(); +#ifndef __CPU_SIM TORCH_CHECK(device.type() == DEVICE_TYPE, "tri_inv_ns: tensor must be on NPU, got ", device); +#endif TORCH_CHECK(dtype_out == at::kHalf || dtype_out == at::kFloat, "tri_inv_rec_unroll: dtype_out must be fp16 or float32, got ", dtype_out); diff --git a/csrc/host/torch_tri_inv_trick.h b/csrc/host/torch_tri_inv_trick.h index e16e653e..68e227e0 100644 --- a/csrc/host/torch_tri_inv_trick.h +++ b/csrc/host/torch_tri_inv_trick.h @@ -11,8 +11,13 @@ for the full License text. #include #include +#ifdef __CPU_SIM +#include "../kernel/kernel_tri_inv_trick.h" +#include "utils_cpu.h" +#else #include "aclrtlaunch_tri_inv_trick_fp16.h" #include "utils.h" +#endif namespace pto_isa_ops { @@ -30,8 +35,10 @@ at::Tensor run_tri_inv_trick(const at::Tensor& M) { const uint32_t max_block_size = 16; const uint32_t matrix_size = static_cast(M.size(-1)); +#ifndef __CPU_SIM TORCH_CHECK(device.type() == DEVICE_TYPE, "tri_inv_ns: tensor must be on NPU, got ", device); +#endif TORCH_CHECK(dtype == at::kHalf, "tri_inv_trick: dtype must be fp16, got ", dtype); TORCH_CHECK(matrix_size == static_cast(M.size(-2)), diff --git a/csrc/host/utils_cpu.h b/csrc/host/utils_cpu.h new file mode 100644 index 00000000..581eee1d --- /dev/null +++ b/csrc/host/utils_cpu.h @@ -0,0 +1,133 @@ +/** +Copyright (c) 2026 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under +the terms and conditions of CANN Open Software License Agreement Version 2.0 +(the "License"). Please refer to the License for details. 
You may not use this +file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN "AS +IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING +BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A +PARTICULAR PURPOSE. See LICENSE in the root of the software repository for the +full text of the License. +*/ + +#ifndef EXTENSION_CSRC_UTILS_CPU_H +#define EXTENSION_CSRC_UTILS_CPU_H + +#include + +#include +#include +#include + +// Dummy definitions to be used by SOC_VERSION +#define Ascend910B1 1 +#define Ascend910B2 2 +#define Ascend910B3 3 +#define Ascend910B4 4 + +#if !defined(SOC_VERSION) +#define SOC_VERSION Ascend910B4 +#endif + +namespace pto_isa_ops { + +/** + * @brief Thread-local variables to store block information for CPU simulation. + */ +inline thread_local uint32_t g_block_num = 1; + +/** + * @brief Returns the number of Cube cores on the specified device. + * + * @param [in] device_id Device ID, default is 0. + * @return uint32_t Number of Cube cores on the specified device. + */ +inline uint32_t GetNumCubeCores(int32_t device_id = 0) { +#if (SOC_VERSION == Ascend910B1) || (SOC_VERSION == Ascend910B2) + return 24; +#elif (SOC_VERSION == Ascend910B3) || (SOC_VERSION == Ascend910B4) + return 20; +#else +#error "Unsupported SOC_VERSION value provided." +#endif + return 1; +} + +/** + * @brief Get the number of vector Cores. + * + * @param [in] device_id Device ID, default is 0. + * @return uint32_t Number of vector cores on the specified device. + */ +inline uint32_t GetNumVectorCores(int32_t device_id = 0) { + return 2 * GetNumCubeCores(); +} + +/** + * @brief Converts the type of input tensor to (uint8_t*) + * + * @param [in] tensor Input tensor + * @return uint8_t* Pointer of input tensor. 
+ */ +inline uint8_t* ConvertType(const at::Tensor& tensor) { + return reinterpret_cast(const_cast(tensor.storage().data())); +} + +/** + * @brief Converts any pointer type to (uint8_t*) + * + * @tparam T Input pointer type. + * @param [in] value Input pointer + * @return uint8_t* Converted pointer. + */ +template +inline uint8_t* ConvertType(T* value) { + return reinterpret_cast(const_cast*>(value)); +} + +/** + * @brief Identity conversion for non-pointer types. + * + * @tparam T Input type. + * @param [in] value Input value + * @return T Returns same value + */ +template +inline std::enable_if_t, T> ConvertType(T value) { + return value; +} + +/** + * @brief Converts types given a variadic list. + * + * @tparam Ts Variadic list of types + * + * @param [in] args Variadic list of input arguments + * @return Tuple of converted types. + */ +template +constexpr auto ConvertTypes(Ts&... args) { + return std::make_tuple(ConvertType(args)...); +} + +#define EXEC_KERNEL_CMD(kernel_name, blockdim, ...) \ + do { \ + auto converted_params = pto_isa_ops::ConvertTypes(__VA_ARGS__); \ + pto_isa_ops::g_block_num = blockdim; \ + for (uint32_t i = 0; i < static_cast(blockdim); ++i) { \ + pto::cpu_sim::ScopedExecutionContext ctx(i, 0, 1); \ + std::apply([&](auto&&... params) { kernel_name(params...); }, \ + converted_params); \ + } \ + } while (false) + +} // namespace pto_isa_ops + +/** + * @brief Global accessor for block number in CPU simulation. + * + * We need this function because pto/common/cpu_stub.hpp doesn't define it. + */ +extern "C" uint32_t get_block_num() { return pto_isa_ops::g_block_num; } + +#endif // EXTENSION_CSRC_UTILS_CPU_H diff --git a/csrc/kernel/kernel_abs.cpp b/csrc/kernel/kernel_abs.cpp index d75e729d..df6b8b73 100644 --- a/csrc/kernel/kernel_abs.cpp +++ b/csrc/kernel/kernel_abs.cpp @@ -7,7 +7,7 @@ See LICENSE in the root of the software repository: for the full License text. 
*/ -#include "kernel_utils.h" +#include "kernel_abs.h" using namespace pto; @@ -107,7 +107,7 @@ AICORE void runTAbs(__gm__ T* x, __gm__ T* z, uint32_t total_size) { extern "C" __global__ AICORE void vabs_fp16(GM_ADDR x, GM_ADDR z, uint32_t in_length) { -#if __CCE_AICORE__ == 220 && defined(__DAV_C220_VEC__) +#if defined(__DAV_C220_VEC__) constexpr uint32_t TILE_LEN = 128; runTAbs((__gm__ half*)x, (__gm__ half*)z, in_length); #else @@ -119,8 +119,7 @@ extern "C" __global__ AICORE void vabs_fp16(GM_ADDR x, GM_ADDR z, extern "C" __global__ AICORE void vabs_fp32(GM_ADDR x, GM_ADDR z, uint32_t in_length) { -#if __CCE_AICORE__ == 220 && defined(__DAV_C220_VEC__) - +#if defined(__DAV_C220_VEC__) constexpr uint32_t TILE_LEN = 128; runTAbs((__gm__ float*)x, (__gm__ float*)z, in_length); #else @@ -129,8 +128,3 @@ extern "C" __global__ AICORE void vabs_fp32(GM_ADDR x, GM_ADDR z, (void)in_length; #endif } - -extern "C" void call_vabs_fp16(uint32_t blockDim, void* stream, uint8_t* x, - uint8_t* y, uint32_t in_length) { - vabs_fp16<<>>(x, y, in_length); -} diff --git a/csrc/kernel/kernel_abs.h b/csrc/kernel/kernel_abs.h new file mode 100644 index 00000000..d008a983 --- /dev/null +++ b/csrc/kernel/kernel_abs.h @@ -0,0 +1,24 @@ +/** +Copyright (c) 2026 Huawei Technologies Co., Ltd. +All rights reserved. + +See LICENSE in the root of the software repository: +https://github.com/huawei-csl/pto-kernels/ +for the full License text. 
+*/ + +#ifndef CSRC_KERNEL_KERNEL_ABS_H +#define CSRC_KERNEL_KERNEL_ABS_H + +#include + +#include "kernel_utils.h" + +#ifdef __CPU_SIM +extern "C" __global__ AICORE void vabs_fp16(GM_ADDR x, GM_ADDR z, + uint32_t in_length); +extern "C" __global__ AICORE void vabs_fp32(GM_ADDR x, GM_ADDR z, + uint32_t in_length); +#endif + +#endif // CSRC_KERNEL_KERNEL_ABS_H diff --git a/csrc/kernel/kernel_batch_matrix_square.cpp b/csrc/kernel/kernel_batch_matrix_square.cpp index 310738cd..64f18190 100644 --- a/csrc/kernel/kernel_batch_matrix_square.cpp +++ b/csrc/kernel/kernel_batch_matrix_square.cpp @@ -7,7 +7,7 @@ See LICENSE in the root of the software repository: for the full License text. */ -#include "kernel_utils.h" +#include "kernel_batch_matrix_square.h" using namespace pto; @@ -101,8 +101,7 @@ AICORE void run_batch_matrix_square(__gm__ float* z, __gm__ InputT* x, extern "C" __global__ AICORE void batch_matrix_square_fp16( __gm__ void* z, __gm__ void* x, uint32_t matrix_size) { -#if (__CHECK_FEATURE_AT_PRECOMPILE) || \ - (__CCE_AICORE__ == 220 && defined(__DAV_C220_CUBE__)) // AIC +#if defined(__DAV_C220_CUBE__) // AIC run_batch_matrix_square((__gm__ float*)z, (__gm__ half*)x, matrix_size); #else // Nothing to do on AIV @@ -114,8 +113,7 @@ extern "C" __global__ AICORE void batch_matrix_square_fp16( extern "C" __global__ AICORE void batch_matrix_square_fp32( __gm__ void* z, __gm__ void* x, uint32_t matrix_size) { -#if (__CHECK_FEATURE_AT_PRECOMPILE) || \ - (__CCE_AICORE__ == 220 && defined(__DAV_C220_CUBE__)) // AIC +#if defined(__DAV_C220_CUBE__) // AIC run_batch_matrix_square((__gm__ float*)z, (__gm__ float*)x, matrix_size); diff --git a/csrc/kernel/kernel_batch_matrix_square.h b/csrc/kernel/kernel_batch_matrix_square.h new file mode 100644 index 00000000..8bc42402 --- /dev/null +++ b/csrc/kernel/kernel_batch_matrix_square.h @@ -0,0 +1,24 @@ +/** +Copyright (c) 2026 Huawei Technologies Co., Ltd. +All rights reserved. 
+ +See LICENSE in the root of the software repository: +https://github.com/huawei-csl/pto-kernels/ +for the full License text. +*/ + +#ifndef CSRC_KERNEL_KERNEL_BATCH_MATRIX_SQUARE_H +#define CSRC_KERNEL_KERNEL_BATCH_MATRIX_SQUARE_H + +#include + +#include "kernel_utils.h" + +#ifdef __CPU_SIM +extern "C" __global__ AICORE void batch_matrix_square_fp16( + __gm__ void* z, __gm__ void* x, uint32_t matrix_size); +extern "C" __global__ AICORE void batch_matrix_square_fp32( + __gm__ void* z, __gm__ void* x, uint32_t matrix_size); +#endif + +#endif // CSRC_KERNEL_KERNEL_BATCH_MATRIX_SQUARE_H diff --git a/csrc/kernel/kernel_csr_gather.cpp b/csrc/kernel/kernel_csr_gather.cpp index 75eae22d..4bd4905e 100644 --- a/csrc/kernel/kernel_csr_gather.cpp +++ b/csrc/kernel/kernel_csr_gather.cpp @@ -7,7 +7,7 @@ See LICENSE in the root of the software repository: for the full License text. */ -#include "kernel_utils.h" +#include "kernel_csr_gather.h" using namespace pto; @@ -116,10 +116,13 @@ AICORE void runTCsrGather(__gm__ T* values, __gm__ int32_t* indices, TileDataVal wTiles(remaining_elements); TileDataVal zTiles(remaining_elements); TileDataIdx idxTiles(remaining_elements); + TileDataIdx tmpTiles(remaining_elements); // Assign the UB address for each tile TASSIGN(valTiles, V_T_ADDR + stage * TILE_SIZE_IN_BYTES); TASSIGN(wTiles, W_T_ADDR + stage * 2 * TILE_SIZE_IDX_IN_BYTES); + TASSIGN(tmpTiles, W_T_ADDR + stage * 2 * TILE_SIZE_IDX_IN_BYTES + + TILE_SIZE_IDX_IN_BYTES); TASSIGN(zTiles, Z_T_ADDR + stage * TILE_SIZE_IN_BYTES); TASSIGN(idxTiles, IDX_T_ADDR + stage * TILE_SIZE_IDX_IN_BYTES); @@ -147,7 +150,11 @@ AICORE void runTCsrGather(__gm__ T* values, __gm__ int32_t* indices, pipe_barrier(PIPE_V); // Gather +#ifdef __CPU_SIM + TGATHER(wTiles, xTiles, idxTiles, tmpTiles); +#else TGATHER(wTiles, xTiles, idxTiles); +#endif // Signal end of gather to MTE2 (next load) set_flag(PIPE_V, PIPE_MTE2, ev1); @@ -206,7 +213,7 @@ extern "C" __global__ AICORE void csr_gather_fp16(GM_ADDR values, 
GM_ADDR indices, GM_ADDR x, GM_ADDR z, uint32_t x_size, uint32_t indices_size) { -#if __CCE_AICORE__ == 220 && defined(__DAV_C220_VEC__) +#if defined(__DAV_C220_VEC__) constexpr uint32_t TILE_SIZE = 512; constexpr uint32_t TILE_SIZE_X = 40960; @@ -220,7 +227,7 @@ extern "C" __global__ AICORE void csr_gather_fp32(GM_ADDR values, GM_ADDR indices, GM_ADDR x, GM_ADDR z, uint32_t x_size, uint32_t indices_size) { -#if __CCE_AICORE__ == 220 && defined(__DAV_C220_VEC__) +#if defined(__DAV_C220_VEC__) constexpr uint32_t TILE_SIZE = 512; constexpr uint32_t TILE_SIZE_X = 40960; diff --git a/csrc/kernel/kernel_csr_gather.h b/csrc/kernel/kernel_csr_gather.h new file mode 100644 index 00000000..88630cd1 --- /dev/null +++ b/csrc/kernel/kernel_csr_gather.h @@ -0,0 +1,28 @@ +/** +Copyright (c) 2026 Huawei Technologies Co., Ltd. +All rights reserved. + +See LICENSE in the root of the software repository: +https://github.com/huawei-csl/pto-kernels/ +for the full License text. +*/ + +#ifndef CSRC_KERNEL_KERNEL_CSR_GATHER_H +#define CSRC_KERNEL_KERNEL_CSR_GATHER_H + +#include + +#include "kernel_utils.h" + +#ifdef __CPU_SIM +extern "C" __global__ AICORE void csr_gather_fp16(GM_ADDR values, + GM_ADDR indices, GM_ADDR x, + GM_ADDR z, uint32_t x_size, + uint32_t indices_size); +extern "C" __global__ AICORE void csr_gather_fp32(GM_ADDR values, + GM_ADDR indices, GM_ADDR x, + GM_ADDR z, uint32_t x_size, + uint32_t indices_size); +#endif + +#endif // CSRC_KERNEL_KERNEL_CSR_GATHER_H diff --git a/csrc/kernel/kernel_scan_ul1.cpp b/csrc/kernel/kernel_scan_ul1.cpp index 91c27c5c..47575292 100644 --- a/csrc/kernel/kernel_scan_ul1.cpp +++ b/csrc/kernel/kernel_scan_ul1.cpp @@ -7,14 +7,10 @@ See LICENSE in the root of the software repository: for the full License text. */ -#include - -#include "kernel_utils.h" +#include "kernel_scan_ul1.h" using namespace pto; -constexpr unsigned UB_SIZE = 0x30000; // 192KB UB of A2A3 - /** * @brief Kernel implementation for scan operation on a single cube. 
* @@ -37,8 +33,7 @@ template AICORE void runKernelScanUl1(__gm__ InputT* x, __gm__ InputT* o, __gm__ InputT* u, __gm__ InputT* l, __gm__ OutputT* s) { -#if (__CHECK_FEATURE_AT_PRECOMPILE) || \ - (__CCE_AICORE__ == 220 && defined(__DAV_C220_CUBE__)) +#if defined(__DAV_C220_CUBE__) // Type definitions for different memory levels // GM @@ -46,6 +41,7 @@ AICORE void runKernelScanUl1(__gm__ InputT* x, __gm__ InputT* o, using Stride = pto::Stride<1, 1, 1, matrix_size, 1>; using GlobalDataIn = pto::GlobalTensor; using GlobalDataOut = pto::GlobalTensor; + using GlobalDataInFp32 = GlobalDataOut; // Used for temporary casting // L1 using TileL1In = @@ -69,7 +65,9 @@ AICORE void runKernelScanUl1(__gm__ InputT* x, __gm__ InputT* o, GlobalDataIn lGlobal(l); GlobalDataOut sGlobal(s); // Reuse output buffer for intermediate result C1 - GlobalDataIn c1GM(reinterpret_cast<__gm__ InputT*>(s)); + GlobalDataIn c1GMi(reinterpret_cast<__gm__ InputT*>(s)); + GlobalDataInFp32 c1GMi_fp32(s); + GlobalDataOut c1GMo(s); // Load data from GM to L1 TileL1In xL1; @@ -122,12 +120,15 @@ AICORE void runKernelScanUl1(__gm__ InputT* x, __gm__ InputT* o, wait_flag(PIPE_M, PIPE_FIX, EVENT_ID0); // Move C1 from L0C to L1 - // TMOV_FLOAT(c1L1, sL0); + if constexpr (!std::is_same_v) { + // TMOV directly from L0C to L1 handles downcasting automatically + TMOV(c1L1, sL0); + } // Move C1 from L0C to GM, in the float case, // we cannot move to L1 directly - // because of downcasting - TSTORE(c1GM, sL0); + // because of downcasting (TMOV doesn't support float->float from Acc to Mat) + TSTORE(c1GMo, sL0); // Wait for FP set_flag(PIPE_FIX, PIPE_MTE1, EVENT_ID0); @@ -152,11 +153,17 @@ AICORE void runKernelScanUl1(__gm__ InputT* x, __gm__ InputT* o, wait_flag(PIPE_FIX, PIPE_MTE2, EVENT_ID0); // Load C1 from GM to L1 - TLOAD(c1L1, c1GM); - - // Wait for load to be complete before moving C1 to L0 - set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID1); - wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID1); + if constexpr 
(std::is_same_v) { + TLOAD(c1L1, c1GMi); + // Wait for load to be complete before moving C1 to L0 + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID1); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID1); + } else { + // c1L1 was already moved directly from L0C to L1 via TMOV earlier + // Just sync the fixpipe (TMOV) with MTE1 + set_flag(PIPE_FIX, PIPE_MTE1, EVENT_ID1); + wait_flag(PIPE_FIX, PIPE_MTE1, EVENT_ID1); + } // Wait for matmul to complet before loading Ls and C1 set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); diff --git a/csrc/kernel/kernel_scan_ul1.h b/csrc/kernel/kernel_scan_ul1.h new file mode 100644 index 00000000..e8d69263 --- /dev/null +++ b/csrc/kernel/kernel_scan_ul1.h @@ -0,0 +1,28 @@ +/** +Copyright (c) 2026 Huawei Technologies Co., Ltd. +All rights reserved. + +See LICENSE in the root of the software repository: +https://github.com/huawei-csl/pto-kernels/ +for the full License text. +*/ + +#ifndef CSRC_KERNEL_KERNEL_SCAN_UL1_H +#define CSRC_KERNEL_KERNEL_SCAN_UL1_H + +#include + +#include "kernel_utils.h" + +#ifdef __CPU_SIM +extern "C" __global__ AICORE void scan_ul1_fp16(__gm__ void* x, __gm__ void* o, + __gm__ void* u, __gm__ void* l, + __gm__ void* s, + uint32_t matrix_size); +extern "C" __global__ AICORE void scan_ul1_fp32(__gm__ void* x, __gm__ void* o, + __gm__ void* u, __gm__ void* l, + __gm__ void* s, + uint32_t matrix_size); +#endif + +#endif // CSRC_KERNEL_KERNEL_SCAN_UL1_H diff --git a/csrc/kernel/kernel_simple_matmul.cpp b/csrc/kernel/kernel_simple_matmul.cpp index d3bfec52..18571247 100644 --- a/csrc/kernel/kernel_simple_matmul.cpp +++ b/csrc/kernel/kernel_simple_matmul.cpp @@ -7,7 +7,7 @@ See LICENSE in the root of the software repository: for the full License text. 
*/ -#include "kernel_utils.h" +#include "kernel_simple_matmul.h" using namespace pto; @@ -129,8 +129,7 @@ extern "C" __global__ AICORE void simple_matmul_fp16(__gm__ void* a, __gm__ void* b, __gm__ void* c, uint32_t matrix_size) { -#if (__CHECK_FEATURE_AT_PRECOMPILE) || \ - (__CCE_AICORE__ == 220 && defined(__DAV_C220_CUBE__)) +#if defined(__DAV_C220_CUBE__) run_simple_matmul((__gm__ half*)a, (__gm__ half*)b, (__gm__ float*)c, matrix_size); #endif @@ -140,8 +139,7 @@ extern "C" __global__ AICORE void simple_matmul_bf16(__gm__ void* a, __gm__ void* b, __gm__ void* c, uint32_t matrix_size) { -#if (__CHECK_FEATURE_AT_PRECOMPILE) || \ - (__CCE_AICORE__ == 220 && defined(__DAV_C220_CUBE__)) +#if defined(__DAV_C220_CUBE__) run_simple_matmul((__gm__ bfloat16_t*)a, (__gm__ bfloat16_t*)b, (__gm__ float*)c, matrix_size); #endif @@ -151,8 +149,7 @@ extern "C" __global__ AICORE void simple_matmul_fp32(__gm__ void* a, __gm__ void* b, __gm__ void* c, uint32_t matrix_size) { -#if (__CHECK_FEATURE_AT_PRECOMPILE) || \ - (__CCE_AICORE__ == 220 && defined(__DAV_C220_CUBE__)) +#if defined(__DAV_C220_CUBE__) run_simple_matmul((__gm__ float*)a, (__gm__ float*)b, (__gm__ float*)c, matrix_size); #endif diff --git a/csrc/kernel/kernel_simple_matmul.h b/csrc/kernel/kernel_simple_matmul.h new file mode 100644 index 00000000..3acd2dd0 --- /dev/null +++ b/csrc/kernel/kernel_simple_matmul.h @@ -0,0 +1,32 @@ +/** +Copyright (c) 2026 Huawei Technologies Co., Ltd. +All rights reserved. + +See LICENSE in the root of the software repository: +https://github.com/huawei-csl/pto-kernels/ +for the full License text. 
+*/ + +#ifndef CSRC_KERNEL_KERNEL_SIMPLE_MATMUL_H +#define CSRC_KERNEL_KERNEL_SIMPLE_MATMUL_H + +#include + +#include "kernel_utils.h" + +#ifdef __CPU_SIM +extern "C" __global__ AICORE void simple_matmul_fp16(__gm__ void* a, + __gm__ void* b, + __gm__ void* c, + uint32_t matrix_size); +extern "C" __global__ AICORE void simple_matmul_bf16(__gm__ void* a, + __gm__ void* b, + __gm__ void* c, + uint32_t matrix_size); +extern "C" __global__ AICORE void simple_matmul_fp32(__gm__ void* a, + __gm__ void* b, + __gm__ void* c, + uint32_t matrix_size); +#endif + +#endif // CSRC_KERNEL_KERNEL_SIMPLE_MATMUL_H diff --git a/csrc/kernel/kernel_swiglu.cpp b/csrc/kernel/kernel_swiglu.cpp index bcba6cf2..cd3e8815 100644 --- a/csrc/kernel/kernel_swiglu.cpp +++ b/csrc/kernel/kernel_swiglu.cpp @@ -7,7 +7,7 @@ See LICENSE in the root of the software repository: for the full License text. */ -#include "kernel_utils.h" +#include "kernel_swiglu.h" using namespace pto; @@ -59,7 +59,7 @@ static_assert(UB_SLOT_BYTES * 6 == UB_USABLE_BYTES, static_assert(Y_PONG + Y_BUFFER_BYTES <= UB_USABLE_BYTES, "SwiGLU UB layout exceeds usable UB."); -#if __CCE_AICORE__ == 220 && defined(__DAV_C220_VEC__) +#if defined(__DAV_C220_VEC__) namespace { @@ -256,7 +256,7 @@ template AICORE void runTSwiGLUTiled(__gm__ T* x, __gm__ T* y, uint32_t batch, uint32_t input_n, uint32_t num_cores, uint32_t vid, uint32_t row_tile_len) { -#if __CCE_AICORE__ == 220 && defined(__DAV_C220_VEC__) +#if defined(__DAV_C220_VEC__) constexpr uint32_t kTileRows = ELEMENTS_PER_TILE / kTileCols; static_assert(kTileRows * kTileCols == ELEMENTS_PER_TILE, "2D tile shape must match the UB vector tile capacity."); @@ -338,7 +338,7 @@ template AICORE void runTSwiGLUMainTiled(__gm__ T* x, __gm__ T* y, uint32_t batch, uint32_t input_n, uint32_t num_cores, uint32_t vid) { -#if __CCE_AICORE__ == 220 && defined(__DAV_C220_VEC__) +#if defined(__DAV_C220_VEC__) const uint32_t output_n = input_n >> 1; const TileConfig cfg = chooseTileConfig(batch, 
output_n, num_cores); @@ -368,7 +368,7 @@ AICORE void runTSwiGLUMainTiled(__gm__ T* x, __gm__ T* y, uint32_t batch, template AICORE void runTSwiGLU(__gm__ T* x, __gm__ T* y, uint32_t batch, uint32_t input_n, uint32_t num_cores, uint32_t vid) { -#if __CCE_AICORE__ == 220 && defined(__DAV_C220_VEC__) +#if defined(__DAV_C220_VEC__) set_mask_norm(); set_vector_mask(-1, -1); @@ -397,7 +397,7 @@ AICORE void runTSwiGLU(__gm__ T* x, __gm__ T* y, uint32_t batch, extern "C" __global__ AICORE void swiglu_fp16(GM_ADDR x, GM_ADDR y, uint32_t batch, uint32_t input_n) { -#if defined(__DAV_VEC__) +#if defined(__DAV_C220_VEC__) const uint32_t num_cores = get_block_num() * get_subblockdim(); const uint32_t vid = get_block_idx() * get_subblockdim() + get_subblockid(); runTSwiGLU((__gm__ half*)x, (__gm__ half*)y, batch, input_n, num_cores, @@ -410,8 +410,10 @@ extern "C" __global__ AICORE void swiglu_fp16(GM_ADDR x, GM_ADDR y, #endif } +#ifndef __CPU_SIM extern "C" void call_swiglu_kernel(uint32_t blockDim, void* stream, uint8_t* x, uint8_t* y, uint32_t batch, uint32_t input_n) { swiglu_fp16<<>>(x, y, batch, input_n); } +#endif diff --git a/csrc/kernel/kernel_swiglu.h b/csrc/kernel/kernel_swiglu.h new file mode 100644 index 00000000..4319d693 --- /dev/null +++ b/csrc/kernel/kernel_swiglu.h @@ -0,0 +1,22 @@ +/** +Copyright (c) 2026 Huawei Technologies Co., Ltd. +All rights reserved. + +See LICENSE in the root of the software repository: +https://github.com/huawei-csl/pto-kernels/ +for the full License text. 
+*/ + +#ifndef CSRC_KERNEL_KERNEL_SWIGLU_H +#define CSRC_KERNEL_KERNEL_SWIGLU_H + +#include + +#include "kernel_utils.h" + +#ifdef __CPU_SIM +extern "C" __global__ AICORE void swiglu_fp16(GM_ADDR x, GM_ADDR y, + uint32_t batch, uint32_t input_n); +#endif + +#endif // CSRC_KERNEL_KERNEL_SWIGLU_H diff --git a/csrc/kernel/kernel_tri_inv_col_sweep.cpp b/csrc/kernel/kernel_tri_inv_col_sweep.cpp index 4177d32f..d2b79c0c 100644 --- a/csrc/kernel/kernel_tri_inv_col_sweep.cpp +++ b/csrc/kernel/kernel_tri_inv_col_sweep.cpp @@ -7,7 +7,7 @@ See LICENSE in the root of the software repository: for the full License text. */ -#include "kernel_utils.h" +#include "kernel_tri_inv_col_sweep.h" using namespace pto; @@ -156,7 +156,7 @@ AICORE void runTTriInv(__gm__ T* vec_in, __gm__ T* vec_out, extern "C" __global__ AICORE void triv_inv_col_sweep_fp16( GM_ADDR x, GM_ADDR z, uint32_t in_length, uint32_t matrix_size) { -#if __CCE_AICORE__ == 220 && defined(__DAV_C220_VEC__) +#if defined(__DAV_C220_VEC__) if (matrix_size == 16) { runTTriInv((__gm__ half*)x, (__gm__ half*)z, in_length); @@ -172,7 +172,7 @@ extern "C" __global__ AICORE void triv_inv_col_sweep_fp16( extern "C" __global__ AICORE void triv_inv_col_sweep_fp32( GM_ADDR x, GM_ADDR z, uint32_t in_length, uint32_t matrix_size) { -#if __CCE_AICORE__ == 220 && defined(__DAV_C220_VEC__) +#if defined(__DAV_C220_VEC__) if (matrix_size == 16) { runTTriInv((__gm__ float*)x, (__gm__ float*)z, in_length); diff --git a/csrc/kernel/kernel_tri_inv_col_sweep.h b/csrc/kernel/kernel_tri_inv_col_sweep.h new file mode 100644 index 00000000..30aa831c --- /dev/null +++ b/csrc/kernel/kernel_tri_inv_col_sweep.h @@ -0,0 +1,26 @@ +/** +Copyright (c) 2026 Huawei Technologies Co., Ltd. +All rights reserved. + +See LICENSE in the root of the software repository: +https://github.com/huawei-csl/pto-kernels/ +for the full License text. 
+*/ + +#ifndef CSRC_KERNEL_KERNEL_TRI_INV_H +#define CSRC_KERNEL_KERNEL_TRI_INV_H + +#include + +#include "kernel_utils.h" + +#ifdef __CPU_SIM +extern "C" __global__ AICORE void triv_inv_col_sweep_fp16(GM_ADDR x, GM_ADDR z, + uint32_t in_length, + uint32_t matrix_size); +extern "C" __global__ AICORE void triv_inv_col_sweep_fp32(GM_ADDR x, GM_ADDR z, + uint32_t in_length, + uint32_t matrix_size); +#endif + +#endif // CSRC_KERNEL_KERNEL_TRI_INV_H diff --git a/csrc/kernel/kernel_tri_inv_ns.cpp b/csrc/kernel/kernel_tri_inv_ns.cpp index 2ffe7188..c76e00b4 100644 --- a/csrc/kernel/kernel_tri_inv_ns.cpp +++ b/csrc/kernel/kernel_tri_inv_ns.cpp @@ -7,7 +7,7 @@ See LICENSE in the root of the software repository: for the full License text. */ -#include "kernel_utils.h" +#include "kernel_tri_inv_ns.h" using namespace pto; @@ -44,7 +44,7 @@ AICORE inline void PrepareAuxiliaryMatrices( set_flag(PIPE_M, PIPE_FIX, EVENT_ID0); TMATMUL(c_l0_tile_1, a_l0_tile, b_l0_tile); // c_l0_1 = I pipe_barrier(PIPE_M); - TMATMUL_ACC(c_l0_tile_1, a_l0_tile, b_l0_tile); // c_l0_1 = 2*I + TMATMUL_ACC(c_l0_tile_1, c_l0_tile_1, a_l0_tile, b_l0_tile); // c_l0_1 = 2*I set_flag(PIPE_M, PIPE_FIX, EVENT_ID1); set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); @@ -209,8 +209,7 @@ template + +#include "kernel_utils.h" + +#ifdef __CPU_SIM +extern "C" __global__ AICORE void tri_inv_ns_fp16( + __gm__ void* tensor_out, __gm__ void* tensor_in, + __gm__ void* identity_neg_in, __gm__ void* identity_over_n_in, + uint32_t matrix_size, uint32_t num_iters, uint32_t num_matrices); +#endif + +#endif // CSRC_KERNEL_KERNEL_TRI_INV_NS_H diff --git a/csrc/kernel/kernel_tri_inv_rec_unroll.cpp b/csrc/kernel/kernel_tri_inv_rec_unroll.cpp index 614e6b22..5629f559 100644 --- a/csrc/kernel/kernel_tri_inv_rec_unroll.cpp +++ b/csrc/kernel/kernel_tri_inv_rec_unroll.cpp @@ -7,7 +7,7 @@ See LICENSE in the root of the software repository: for the full License text. 
*/ -#include "kernel_utils.h" +#include "kernel_tri_inv_rec_unroll.h" using namespace kernel_utils; using namespace pto; @@ -654,8 +654,7 @@ AICORE void runKernelTriInvRecUnroll(__gm__ OutputT* M_inv, __gm__ InputT* M, __gm__ InputT* I_neg, uint32_t total_tiles, uint32_t num_bsnd_heads = 0, __gm__ int32_t* cu_seqlens = nullptr) { -#if (__CHECK_FEATURE_AT_PRECOMPILE) || \ - (__CCE_AICORE__ == 220 && defined(__DAV_C220_CUBE__)) // Cube compilation +#if defined(__DAV_C220_CUBE__) // Cube compilation TriInvRecUnrollKernel(M_inv, M, I_neg, total_tiles, num_bsnd_heads, diff --git a/csrc/kernel/kernel_tri_inv_rec_unroll.h b/csrc/kernel/kernel_tri_inv_rec_unroll.h new file mode 100644 index 00000000..345cd333 --- /dev/null +++ b/csrc/kernel/kernel_tri_inv_rec_unroll.h @@ -0,0 +1,26 @@ +/** +Copyright (c) 2026 Huawei Technologies Co., Ltd. +All rights reserved. + +See LICENSE in the root of the software repository: +https://github.com/huawei-csl/pto-kernels/ +for the full License text. +*/ + +#ifndef CSRC_KERNEL_KERNEL_TRI_INV_REC_UNROLL_H +#define CSRC_KERNEL_KERNEL_TRI_INV_REC_UNROLL_H + +#include + +#include "kernel_utils.h" + +#ifdef __CPU_SIM +extern "C" __global__ AICORE void tri_inv_rec_unroll_fp16fp16( + __gm__ void* tensor_out, __gm__ void* tensor_in, uint32_t matrix_size, + uint32_t num_matrices, uint32_t num_bsnd_heads, __gm__ void* cu_seqlens); +extern "C" __global__ AICORE void tri_inv_rec_unroll_fp16fp32( + __gm__ void* tensor_out, __gm__ void* tensor_in, uint32_t matrix_size, + uint32_t num_matrices, uint32_t num_bsnd_heads, __gm__ void* cu_seqlens); +#endif + +#endif // CSRC_KERNEL_KERNEL_TRI_INV_REC_UNROLL_H diff --git a/csrc/kernel/kernel_tri_inv_trick.cpp b/csrc/kernel/kernel_tri_inv_trick.cpp index b62d3ae5..a5b88f1e 100644 --- a/csrc/kernel/kernel_tri_inv_trick.cpp +++ b/csrc/kernel/kernel_tri_inv_trick.cpp @@ -7,7 +7,7 @@ See LICENSE in the root of the software repository: for the full License text. 
*/ -#include "kernel_utils.h" +#include "kernel_tri_inv_trick.h" using namespace pto; @@ -23,8 +23,7 @@ template AICORE void runKernelTriInvTrick(__gm__ OutputT* M_inv, __gm__ InputT* M, __gm__ InputT* I_neg, uint32_t max_block_size) { -#if (__CHECK_FEATURE_AT_PRECOMPILE) || \ - (__CCE_AICORE__ == 220 && defined(__DAV_C220_CUBE__)) // Cube compilation +#if defined(__DAV_C220_CUBE__) // Cube compilation constexpr uint32_t TileLen = MatrixSize * MatrixSize; const uint32_t global_index = get_block_idx() * TileLen; diff --git a/csrc/kernel/kernel_tri_inv_trick.h b/csrc/kernel/kernel_tri_inv_trick.h new file mode 100644 index 00000000..54fea1f5 --- /dev/null +++ b/csrc/kernel/kernel_tri_inv_trick.h @@ -0,0 +1,25 @@ +/** +Copyright (c) 2026 Huawei Technologies Co., Ltd. +All rights reserved. + +See LICENSE in the root of the software repository: +https://github.com/huawei-csl/pto-kernels/ +for the full License text. +*/ + +#ifndef CSRC_KERNEL_KERNEL_TRI_INV_TRICK_H +#define CSRC_KERNEL_KERNEL_TRI_INV_TRICK_H + +#include + +#include "kernel_utils.h" + +#ifdef __CPU_SIM +extern "C" __global__ AICORE void tri_inv_trick_fp16(__gm__ void* tensor_out, + __gm__ void* tensor_in, + __gm__ void* identity_in, + uint32_t matrix_size, + uint32_t max_block_size); +#endif + +#endif // CSRC_KERNEL_KERNEL_TRI_INV_TRICK_H diff --git a/csrc/kernel/kernel_utils.h b/csrc/kernel/kernel_utils.h index 61acd5bf..08a268df 100644 --- a/csrc/kernel/kernel_utils.h +++ b/csrc/kernel/kernel_utils.h @@ -17,6 +17,20 @@ for the full License text. #endif // clang-format on +#ifdef __CPU_SIM +#define EVENT_ID1 1 +#define EVENT_ID2 2 +#define EVENT_ID3 3 +#define EVENT_ID4 4 +#define EVENT_ID5 5 +#define EVENT_ID6 6 +#define EVENT_ID7 7 +#define __DAV_C220_VEC__ +#define __DAV_C220_CUBE__ + +extern "C" uint32_t get_block_num(); +#endif + namespace kernel_utils { /** * @brief Do a sync step (set-wait flag) between two pipes. 
diff --git a/scripts/build_cpu.py b/scripts/build_cpu.py new file mode 100644 index 00000000..f639c9f7 --- /dev/null +++ b/scripts/build_cpu.py @@ -0,0 +1,67 @@ +import os +import sys + +from torch.utils.cpp_extension import load as load_cpp + +ASCEND_TOOLKIT_HOME = os.environ.get("ASCEND_TOOLKIT_HOME", "") +PTO_LIB_PATH = os.environ.get("PTO_LIB_PATH", ASCEND_TOOLKIT_HOME) + +# Get the directory of the current script +root_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) +build_path = os.path.join(root_path, "build/cpu_sim") +os.makedirs(build_path, exist_ok=True) + +sources = [ + os.path.join(root_path, "csrc/host/pybind11.cpp"), + os.path.join(root_path, "csrc/kernel/kernel_abs.cpp"), + os.path.join(root_path, "csrc/kernel/kernel_batch_matrix_square.cpp"), + os.path.join(root_path, "csrc/kernel/kernel_csr_gather.cpp"), + os.path.join(root_path, "csrc/kernel/kernel_scan_ul1.cpp"), + os.path.join(root_path, "csrc/kernel/kernel_simple_matmul.cpp"), + os.path.join(root_path, "csrc/kernel/kernel_swiglu.cpp"), + os.path.join(root_path, "csrc/kernel/kernel_tri_inv_col_sweep.cpp"), + os.path.join(root_path, "csrc/kernel/kernel_tri_inv_ns.cpp"), + # os.path.join(root_path, "csrc/kernel/kernel_tri_inv_rec_unroll.cpp"), + os.path.join(root_path, "csrc/kernel/kernel_tri_inv_trick.cpp"), +] + +module = load_cpp( + name="pto_kernels_cpu", + sources=sources, + extra_cflags=[ + "-std=c++23", + "-O2", + "-fPIC", + "-D__CPU_SIM", + "-DSOC_VERSION=Ascend910B4", + # This is not required, but reduces the number of warnings + "-Wno-narrowing", + ], + build_directory=build_path, + extra_include_paths=[f"{PTO_LIB_PATH}/include"], + verbose=True, + is_python_module=True, +) + +print("Built pto_kernels for CPU simulation.") + +sys.modules["pto_kernels"] = module + + +# Sanity tests, import declared functions +print("Test if the module contains custom ops...") + +from pto_kernels import ( # noqa + pto_abs, + pto_batch_matrix_square, + pto_csr_gather, + pto_scan_ul1, + 
pto_simple_matmul, + pto_swiglu, + pto_tri_inv, + pto_tri_inv_ns, + # pto_tri_inv_rec_unroll, + pto_tri_inv_trick, +) + +print("Imported custom ops successfully.") diff --git a/scripts/test_cpu.py b/scripts/test_cpu.py new file mode 100644 index 00000000..eb751b48 --- /dev/null +++ b/scripts/test_cpu.py @@ -0,0 +1,46 @@ +import os +import sys + +# importing Torch is required for the importlib to subsequently import custom ops +import torch # noqa: F401 +import importlib + +# Get the directory of the current script +root_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) +so_path = os.path.join(root_path, "build/cpu_sim/pto_kernels_cpu.so") + +# Import the module +spec = importlib.util.spec_from_file_location("pto_kernels_cpu", so_path) +if spec is None: + raise AssertionError(f"Failed to create spec for pto_kernels_cpu at {so_path}") +module = importlib.util.module_from_spec(spec) +if not isinstance(spec.loader, importlib.abc.Loader): + raise AssertionError("spec.loader is not a valid importlib Loader") +spec.loader.exec_module(module) + +# Make sure the module uses the name "pto_kernels" so it can be imported with that name, overriding any real NPU pto_kernels module +sys.modules["pto_kernels"] = module + +# Set the device to CPU +os.environ["NPU_DEVICE"] = "cpu" + +# Run the actual tests +import pytest # noqa + +test_dir = os.path.join(root_path, "tests") +pytest.main( + [ + os.path.join(test_dir, "test_abs.py"), + os.path.join(test_dir, "test_batch_matrix_square.py"), + os.path.join(test_dir, "test_csr_gather.py"), + os.path.join(test_dir, "test_scan_ul1.py"), + os.path.join(test_dir, "test_simple_matmul.py"), + os.path.join(test_dir, "test_swiglu.py"), + os.path.join(test_dir, "test_tri_inv_ns.py"), + os.path.join(test_dir, "test_tri_inv_trick.py"), + # Disabled for now + # os.path.join(test_dir, "test_tri_inv_col_sweep.py"), + # os.path.join(test_dir, "test_tri_inv_rec_unroll.py"), + # os.path.join(test_dir, 
"test_tri_inv_rec_unroll_variable_sequence_lengths.py"), + ] +) diff --git a/tests/conftest.py b/tests/conftest.py index 08c13d6e..18907f37 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,8 +3,9 @@ import torch NPU_DEVICE = os.environ.get("NPU_DEVICE", "npu:1") -torch.npu.config.allow_internal_format = False -torch.npu.set_device(NPU_DEVICE) +if NPU_DEVICE.startswith("npu:"): + torch.npu.config.allow_internal_format = False + torch.npu.set_device(NPU_DEVICE) @pytest.fixture(scope="session") diff --git a/tests/test_abs.py b/tests/test_abs.py index a1b12299..fb4ad49f 100644 --- a/tests/test_abs.py +++ b/tests/test_abs.py @@ -14,12 +14,12 @@ @pytest.mark.parametrize("size0", [1, 2, 3, 10, 20, 64, 128]) @pytest.mark.parametrize("size1", [1, 2, 3, 10, 20, 64, 128]) @pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=str) -def test_pto_abs(size0: int, size1: int, dtype: torch.dtype): +def test_pto_abs(npu_device: str, size0: int, size1: int, dtype: torch.dtype): size = [size0, size1] # Create random input tensors on CPU x = 2 * torch.rand(size, device="cpu", dtype=dtype) - 1 # Copy the input tensor to NPU - x_npu = x.npu() + x_npu = x.to(npu_device) # breakpoint() # Call the custom abs operator output = pto_abs(x_npu).cpu() diff --git a/tests/test_batch_matrix_square.py b/tests/test_batch_matrix_square.py index 4058cec7..a30112ef 100644 --- a/tests/test_batch_matrix_square.py +++ b/tests/test_batch_matrix_square.py @@ -19,12 +19,12 @@ @pytest.mark.parametrize("block_dim", [1, 2, 3, 5, 8, 11, 16, 37, 64, 128, 256]) @pytest.mark.parametrize("matrix_size", [16, 32, 64, 96, 128]) @pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=str) -def test_pto_batch_matrix_square(block_dim: int, matrix_size: int, dtype: torch.dtype): +def test_pto_batch_matrix_square( + npu_device: str, block_dim: int, matrix_size: int, dtype: torch.dtype +): x = torch.rand((block_dim, matrix_size, matrix_size), device="cpu", dtype=dtype) - x_npu = 
x.npu() - torch.npu.synchronize() + x_npu = x.to(npu_device) z_npu = pto_batch_matrix_square(x_npu) - torch.npu.synchronize() z = z_npu.cpu() ref = (x.to(torch.double) @ x.to(torch.double)).to(torch.float32) assert torch.allclose(z, ref) diff --git a/tests/test_csr_gather.py b/tests/test_csr_gather.py index cc7594b7..d0b1b0ec 100644 --- a/tests/test_csr_gather.py +++ b/tests/test_csr_gather.py @@ -25,15 +25,15 @@ def ref_csr_gather( @pytest.mark.parametrize("x_size, v_size", sweep_sizes) @pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=str) -def test_pto_csr_gather(x_size: int, v_size: int, dtype: torch.dtype): +def test_pto_csr_gather(npu_device: str, x_size: int, v_size: int, dtype: torch.dtype): # Create random input tensors on CPU x = torch.rand((x_size,), device="cpu", dtype=dtype) values = torch.rand((v_size,), device="cpu", dtype=dtype) indices = torch.randint(0, x_size, (v_size,), device="cpu", dtype=torch.int32) # Copy the input tensors to NPU - x_npu = x.npu() - values_npu = values.npu() - indices_npu = indices.npu() + x_npu = x.to(npu_device) + values_npu = values.to(npu_device) + indices_npu = indices.to(npu_device) # Call the custom csr_gather operator output = pto_csr_gather(values_npu, indices_npu, x_npu).cpu() # Compute the expected result using a reference implementation on CPU @@ -43,4 +43,4 @@ def test_pto_csr_gather(x_size: int, v_size: int, dtype: torch.dtype): if __name__ == "__main__": - test_pto_csr_gather(32768, 131072, torch.float16) + test_pto_csr_gather("npu:1", 32768, 131072, torch.float16) diff --git a/tests/test_scan_ul1.py b/tests/test_scan_ul1.py index 1c90f7c4..8d0dd45f 100644 --- a/tests/test_scan_ul1.py +++ b/tests/test_scan_ul1.py @@ -16,10 +16,10 @@ @pytest.mark.parametrize("scan_size", matrix_size) @pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=str) -def test_pto_scan_ul1(scan_size: int, dtype: torch.dtype): +def test_pto_scan_ul1(npu_device: str, scan_size: int, dtype: torch.dtype): 
a = torch.ones(scan_size, dtype=dtype) - a_npu = a.npu() + a_npu = a.to(npu_device) scan_npu = pto_scan_ul1(a_npu) ref = torch.cumsum(a.to(torch.float32), dim=0) @@ -28,4 +28,4 @@ def test_pto_scan_ul1(scan_size: int, dtype: torch.dtype): if __name__ == "__main__": - test_pto_scan_ul1(128 * 128, torch.float32) + test_pto_scan_ul1("npu:1", 128 * 128, torch.float32) diff --git a/tests/test_simple_matmul.py b/tests/test_simple_matmul.py index b5ab2162..7f98a7ba 100644 --- a/tests/test_simple_matmul.py +++ b/tests/test_simple_matmul.py @@ -15,13 +15,13 @@ @pytest.mark.parametrize( "dtype", [torch.float16, torch.bfloat16, torch.float32], ids=str ) -def test_pto_simple_matmul(matrix_size: int, dtype: torch.dtype): +def test_pto_simple_matmul(npu_device: str, matrix_size: int, dtype: torch.dtype): m, k, n = matrix_size, matrix_size, matrix_size a = torch.rand((m, k), device="cpu", dtype=dtype) b = torch.rand((k, n), device="cpu", dtype=dtype) - a_npu = a.npu() - b_npu = b.npu() + a_npu = a.to(npu_device) + b_npu = b.to(npu_device) c_npu = pto_simple_matmul(a_npu, b_npu) ref = torch.matmul(a.float(), b.float()) diff --git a/tests/test_swiglu.py b/tests/test_swiglu.py index 508078de..38c789ca 100644 --- a/tests/test_swiglu.py +++ b/tests/test_swiglu.py @@ -1,6 +1,5 @@ import pytest import torch -import torch_npu # noqa from pto_kernels import pto_swiglu @@ -37,13 +36,17 @@ def test_pto_swiglu_matches_reference_and_torch_npu( actual = pto_swiglu(x) expected = swiglu_ref(x.cpu()) - torch_npu_expected = torch_npu.npu_swiglu(x, dim=-1) torch.testing.assert_close(actual.cpu(), expected, rtol=1e-2, atol=1e-5) - torch.testing.assert_close(actual, torch_npu_expected, rtol=1e-2, atol=1e-5) + if npu_device.startswith("npu:"): + import torch_npu # noqa -def test_pto_swiglu_rejects_non_last_dim(npu_device): + torch_npu_expected = torch_npu.npu_swiglu(x, dim=-1) + torch.testing.assert_close(actual, torch_npu_expected, rtol=1e-2, atol=1e-5) + + +def 
test_pto_swiglu_rejects_non_last_dim(npu_device: str): x = torch.randn(2, 256, device=npu_device, dtype=DTYPE) with pytest.raises(RuntimeError, match="dim=-1"): diff --git a/tests/test_tri_inv_col_sweep.py b/tests/test_tri_inv_col_sweep.py index e36abb65..8dfe0852 100644 --- a/tests/test_tri_inv_col_sweep.py +++ b/tests/test_tri_inv_col_sweep.py @@ -73,6 +73,7 @@ def ones_np_tril(batch_size: int, n: int, dtype: np.dtype): (rand_np_tril, ones_np_tril), ) def test_tri_inv_col_sweep( + npu_device: str, batch_size: int, matrix_size: int, data_type: np.dtype, @@ -84,15 +85,12 @@ def test_tri_inv_col_sweep( # Convert input matrices from row-major order to column-major order input_x_cpu = input_x_cpu.transpose(0, 2, 1) - input_x = torch.from_numpy(input_x_cpu).npu() - expected = torch.from_numpy(expected_cpu).npu() + input_x = torch.from_numpy(input_x_cpu).to(npu_device) + expected = torch.from_numpy(expected_cpu).to(npu_device) - torch.npu.synchronize() actual = pto_tri_inv(input_x) - torch.npu.synchronize() # Transpose matrices back to row-major order actual = actual.transpose(2, 1) - torch.npu.synchronize() assert actual.shape == expected.shape, "Output shape does not match expected shape." 
assert torch.equal(actual, expected) @@ -106,6 +104,7 @@ def test_tri_inv_col_sweep( [(rand_np_tril, 1e-5, 1e-5), (ones_np_tril, 0, 0)], ) def test_tri_inv_col_sweep_np_linalg_inv( + npu_device: str, batch_size: int, matrix_size: int, data_type: np.dtype, @@ -119,12 +118,10 @@ def test_tri_inv_col_sweep_np_linalg_inv( # Convert input matrices from row-major order to column-major order input_x_cpu = input_x_cpu.transpose(0, 2, 1) - input_x = torch.from_numpy(input_x_cpu).npu() - golden_numpy_as_torch = torch.from_numpy(golden_numpy_cpu).npu() + input_x = torch.from_numpy(input_x_cpu).to(npu_device) + golden_numpy_as_torch = torch.from_numpy(golden_numpy_cpu).to(npu_device) - torch.npu.synchronize() actual = pto_tri_inv(input_x) - torch.npu.synchronize() # rtol must be scaled w.r.t to the input size, see Higham's paper, Eq. (2.3) # https://nhigham.com/wp-content/uploads/2023/08/high89t.pdf diff --git a/tests/test_tri_inv_ns.py b/tests/test_tri_inv_ns.py index 4968815b..b452e593 100644 --- a/tests/test_tri_inv_ns.py +++ b/tests/test_tri_inv_ns.py @@ -77,6 +77,7 @@ def default_num_iters(n: int) -> int: def _test_tri_inv_ns( + npu_device: str, U: torch.Tensor, atol: float, rtol: float, @@ -85,13 +86,11 @@ def _test_tri_inv_ns( U = U.to(torch.half) golden_cpu = linalg_inv(U) - U_npu = U.npu() + U_npu = U.to(npu_device) - torch.npu.synchronize() num_iters = max([int(2.0 * math.ceil(math.log2(U.shape[-1]))), 12]) # num_iters = 1 actual = pto_tri_inv_ns(U_npu, num_iters=num_iters) - torch.npu.synchronize() actual_cpu = actual.cpu().to(torch.float64) @@ -123,6 +122,7 @@ def _test_tri_inv_ns( ], ) def test_tri_inv_ns( + npu_device: str, n: int, block_dim_x: int, block_dim_y: int, @@ -132,4 +132,4 @@ def test_tri_inv_ns( ftol: float, ): U = matrix_gen(n, block_dim_x, block_dim_y) - _test_tri_inv_ns(U, atol, rtol, ftol) + _test_tri_inv_ns(npu_device, U, atol, rtol, ftol) diff --git a/tests/test_tri_inv_rec_unroll.py b/tests/test_tri_inv_rec_unroll.py index 977dcf26..b84a5dd6 
100644 --- a/tests/test_tri_inv_rec_unroll.py +++ b/tests/test_tri_inv_rec_unroll.py @@ -68,6 +68,7 @@ def linalg_inv(U: torch.tensor) -> torch.tensor: def _test_tri_inv_rec_unroll( + npu_device: str, U: torch.tensor, atol: float, rtol: float, @@ -79,13 +80,10 @@ def _test_tri_inv_rec_unroll( U = U.to(input_dtype) golden_cpu = linalg_inv(U) - U_npu = U.npu() + U_npu = U.to(npu_device) - torch.npu.synchronize() actual = pto_tri_inv_rec_unroll(U_npu, is_bsnd_format=False, dtype_out=output_dtype) - torch.npu.synchronize() actual_cpu = actual.cpu() - torch.npu.synchronize() actual_cpu = actual_cpu.to(torch.float64) frob_error = torch.sqrt( torch.sum((golden_cpu - actual_cpu) * (golden_cpu - actual_cpu)) @@ -101,6 +99,7 @@ def _test_tri_inv_rec_unroll( def _test_tri_inv_rec_unroll_bsnd( + npu_device: str, U: torch.tensor, B: int, S: int, @@ -118,16 +117,12 @@ def _test_tri_inv_rec_unroll_bsnd( # Transform to bsnd layout U = U.transpose(1, 2).contiguous().reshape(B, S, N, D) - torch.npu.synchronize() golden_cpu = golden_cpu.transpose(1, 2).contiguous().reshape(B, S, N, D) - U_npu = U.npu() + U_npu = U.to(npu_device) - torch.npu.synchronize() actual = pto_tri_inv_rec_unroll(U_npu, is_bsnd_format=True, dtype_out=output_dtype) - torch.npu.synchronize() actual_cpu = actual.cpu() - torch.npu.synchronize() actual_cpu = actual_cpu.to(torch.float64) frob_error = torch.sqrt( torch.sum((golden_cpu - actual_cpu) * (golden_cpu - actual_cpu)) @@ -163,6 +158,7 @@ def _test_tri_inv_rec_unroll_bsnd( ], ) def test_tri_inv_rec_unroll( + npu_device, n: int, block_dim_x: int, block_dim_y: int, @@ -174,7 +170,7 @@ def test_tri_inv_rec_unroll( output_dtype: torch.dtype, ): U = matrix_gen(n, block_dim_x, block_dim_y) - _test_tri_inv_rec_unroll(U, atol, rtol, ftol, input_dtype, output_dtype) + _test_tri_inv_rec_unroll(npu_device, U, atol, rtol, ftol, input_dtype, output_dtype) @pytest.mark.parametrize("B", [1, 4]) @@ -199,6 +195,7 @@ def test_tri_inv_rec_unroll( ], ) def 
test_tri_inv_rec_unroll_bsnd( + npu_device: str, B: int, S: int, N: int, @@ -215,5 +212,5 @@ def test_tri_inv_rec_unroll_bsnd( pytest.skip("Sequence length must be a multiple of chunk size C.") U = matrix_gen(C, B * S // C, N) _test_tri_inv_rec_unroll_bsnd( - U, B, S, N, C, atol, rtol, ftol, input_dtype, output_dtype + npu_device, U, B, S, N, C, atol, rtol, ftol, input_dtype, output_dtype ) diff --git a/tests/test_tri_inv_rec_unroll_variable_sequence_lengths.py b/tests/test_tri_inv_rec_unroll_variable_sequence_lengths.py index f781ac35..c24a74e1 100644 --- a/tests/test_tri_inv_rec_unroll_variable_sequence_lengths.py +++ b/tests/test_tri_inv_rec_unroll_variable_sequence_lengths.py @@ -60,6 +60,7 @@ def transpose_valid_chunks( def chunk_scaled_dot_kkt_fwd_emulated( + npu_device: str, k: torch.Tensor, beta: torch.Tensor, cu_seqlens: torch.Tensor, @@ -78,18 +79,25 @@ def chunk_scaled_dot_kkt_fwd_emulated( chunk_end = min(chunk_start + chunk_size, seq_end) actual_size = chunk_end - chunk_start k_chunk = ( - k[:, chunk_start:chunk_end].transpose(1, 2).to(torch.float32).npu() + k[:, chunk_start:chunk_end] + .transpose(1, 2) + .to(torch.float32) + .to(npu_device) ) beta_chunk = ( beta[:, chunk_start:chunk_end] .transpose(1, 2) .unsqueeze(-1) .to(torch.float32) - .npu() + .to(npu_device) ) scores = torch.matmul(k_chunk, k_chunk.transpose(-1, -2)) scores = torch.tril(scores * beta_chunk, diagonal=-1) # .to(k.dtype) - scores = torch.tril(torch.ones(scores.shape), diagonal=-1).to(k.dtype).npu() + scores = ( + torch.tril(torch.ones(scores.shape), diagonal=-1) + .to(k.dtype) + .to(npu_device) + ) A[:, chunk_start:chunk_end, :, :actual_size] = scores.transpose(1, 2) return A @@ -121,6 +129,7 @@ def all_ones_varlen_triangular_tensor( def build_variable_len_input( + npu_device: str, seq_lens: list[int], num_heads: int, chunk_size: int, @@ -148,13 +157,15 @@ def build_variable_len_input( ) beta = torch.randn((1, total_tokens, num_heads), dtype=torch.float16).sigmoid() packed_input = 
transpose_valid_chunks( - chunk_scaled_dot_kkt_fwd_emulated(k, beta, cu_seqlens_tensor, chunk_size), + chunk_scaled_dot_kkt_fwd_emulated( + npu_device, k, beta, cu_seqlens_tensor, chunk_size + ), cu_seqlens_tensor, chunk_size, ) else: raise RuntimeError(f"unknown matrix type to test: {matrix_type}") - return packed_input.contiguous().npu(), cu_seqlens_tensor.npu() + return packed_input.contiguous().to(npu_device), cu_seqlens_tensor.to(npu_device) def _reference_inverse( @@ -191,9 +202,7 @@ def _test_inverse_accuracy( ref = _reference_inverse(A, cu_seqlens, chunk_size) tri = pto_tri_inv_rec_unroll(A, cu_seqlens, True, output_dtype) - torch.npu.synchronize() tri = tri.to(torch.float32).cpu().to(torch.float64) - torch.npu.synchronize() assert torch.allclose(tri, ref, atol=atol, rtol=rtol) frob_error = torch.sqrt(torch.sum((ref - tri) ** 2) / torch.sum(ref**2)).item() @@ -211,6 +220,7 @@ def _test_inverse_accuracy( "matrix_type,atol,rtol,ftol", [("ones", 0, 0, 0), ("random", 1e-5, 5e-2, 1e-2)] ) def test_tri_inv_rec_unroll_variable_length( + npu_device: str, B: int, N: int, chunk_size: int, @@ -235,6 +245,7 @@ def test_tri_inv_rec_unroll_variable_length( default_feature_dim = 64 seq_lens = generate_random_sequence_lengths(B, total_tokens) packed_input, cu_seqlens = build_variable_len_input( + npu_device, seq_lens=seq_lens, num_heads=N, chunk_size=chunk_size, diff --git a/tests/test_tri_inv_trick.py b/tests/test_tri_inv_trick.py index 14a267dc..d78d7086 100644 --- a/tests/test_tri_inv_trick.py +++ b/tests/test_tri_inv_trick.py @@ -47,12 +47,13 @@ def block_random_matrix(n, block_dim_x, block_dim_y, scale=0.2): return torch.from_numpy(U) -def _test_tri_inv_trick(U: torch.tensor, atol: float, rtol: float, ftol: float): +def _test_tri_inv_trick( + npu_device: str, U: torch.tensor, atol: float, rtol: float, ftol: float +): n = U.shape[-1] U = U.to(torch.half) - U_npu = U.npu() - torch.npu.synchronize() + U_npu = U.to(npu_device) Identity = np.ones((n, n), dtype=np.double) 
Identity = np.triu(Identity) @@ -65,11 +66,8 @@ def _test_tri_inv_trick(U: torch.tensor, atol: float, rtol: float, ftol: float): ) golden_cpu = torch.from_numpy(golden_numpy) - torch.npu.synchronize() actual = pto_tri_inv_trick(U_npu) - torch.npu.synchronize() actual_cpu = actual.cpu() - torch.npu.synchronize() actual_cpu = actual_cpu.to(torch.float64) frob_error = torch.sqrt( torch.sum((golden_cpu - actual_cpu) * (golden_cpu - actual_cpu)) @@ -95,6 +93,7 @@ def _test_tri_inv_trick(U: torch.tensor, atol: float, rtol: float, ftol: float): ], ) def test_tri_inv_trick_ones( + npu_device: str, n: int, block_dim_x: int, block_dim_y: int, @@ -104,4 +103,4 @@ def test_tri_inv_trick_ones( ftol: float, ): U = matrix_gen(n, block_dim_x, block_dim_y) - _test_tri_inv_trick(U, atol, rtol, ftol) + _test_tri_inv_trick(npu_device, U, atol, rtol, ftol)