diff --git a/Makefile b/Makefile index 00d4c776..efa112e6 100644 --- a/Makefile +++ b/Makefile @@ -27,13 +27,16 @@ wheel: clean # 'make compile_abs' compiles 'kernel_abs.cpp' into 'libkernel_abs.so' without building the whole wheel package. # This is useful for development and debugging of individual kernels. compile_%: - bisheng -fPIC -shared -xcce -DMEMORY_BASE -O2 -std=c++17 \ + mkdir -p build/lib/ + bisheng -fPIC -shared -xcce -DREGISTER_BASE -O2 -std=gnu++17 \ -I$(CSRC_KERNEL_DIR) \ -I$(PTO_LIB_PATH)/include \ - --npu-arch=dav-2201 \ - -Wno-ignored-attributes \ + --cce-aicore-arch=dav-c310 \ + -mllvm -cce-aicore-stack-size=0x8000 \ + -mllvm -cce-aicore-function-stack-size=0x8000 \ + -Wno-ignored-attributes \ $(CSRC_KERNEL_DIR)/kernel_$*.cpp \ - -o libkernel_$*.so + -o build/lib/libkernel_$*.so install: @@ -47,3 +50,13 @@ test: test_tri_inv: pytest tests/test_tri_inv_*.py + + +run_abs_a5: compile_abs + python scripts/data_gen_abs.py + g++ -o main_abs examples/a5/main_abs.cpp \ + -L$(shell pwd)/build/lib/ -L$(ASCEND_TOOLKIT_HOME)/lib64/ \ + -lkernel_abs -lacl_rt -I$(ASCEND_TOOLKIT_HOME)/include/ \ + -I$(shell pwd)/examples/a5/ -Wno-ignored-attributes + LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:$(shell pwd)/build/lib/ cannsim record \ + --soc=Ascend950 ./main_abs diff --git a/csrc/kernel/kernel_abs.cpp b/csrc/kernel/kernel_abs.cpp index f7792cff..3d4d54c3 100644 --- a/csrc/kernel/kernel_abs.cpp +++ b/csrc/kernel/kernel_abs.cpp @@ -37,6 +37,11 @@ AICORE void runTAbs(__gm__ T* x, __gm__ T* z, uint32_t total_size) { const uint32_t num_aiv_cores = get_block_num(); const uint32_t aiv_core_id = get_block_idx(); + if (get_subblockid() != 0) { + // Only subblock 0 is used in this kernel + return; + } + constexpr uint32_t UB_ZERO_ADDR = 0; constexpr uint32_t TILE_SIZE_IN_BYTES = TILE_SIZE * sizeof(T); const uint32_t num_tiles = (total_size + TILE_SIZE - 1) / TILE_SIZE; @@ -132,5 +137,5 @@ extern "C" __global__ AICORE void vabs_fp32(GM_ADDR x, GM_ADDR z, extern "C" void call_vabs_fp16(uint32_t blockDim, void* stream, uint8_t* x, uint8_t* y, uint32_t in_length) { - vabs_fp16<<>>(x, y, in_length); + vabs_fp16<<>>(x, y, in_length); } diff --git a/examples/a5/data_utils.h b/examples/a5/data_utils.h new file mode 100644 index 00000000..ea45595f --- /dev/null +++ b/examples/a5/data_utils.h @@ -0,0 +1,256 @@ +/** + * @file data_utils.h + * @brief Common functions used to read, write and print data. + */ +#pragma once + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +enum class PrintDataType { + DT_UNDEFINED = -1, + FLOAT = 0, + HALF = 1, + INT8_T = 2, + INT32_T = 3, + UINT8_T = 4, + INT16_T = 6, + UINT16_T = 7, + UINT32_T = 8, + INT64_T = 9, + UINT64_T = 10, + DOUBLE = 11, + BOOL = 12, + STRING = 13, + COMPLEX64 = 16, + COMPLEX128 = 17, + BF16 = 27 +}; + +#define INFO_LOG(fmt, args...) fprintf(stdout, "[INFO] " fmt "\n", ##args) +#define WARN_LOG(fmt, args...) fprintf(stdout, "[WARN] " fmt "\n", ##args) +#define ERROR_LOG(fmt, args...) fprintf(stdout, "[ERROR] " fmt "\n", ##args) +#define CHECK_ACL(x) \ + do { \ + const aclError __ret = x; \ + if (__ret != ACL_ERROR_NONE) { \ + std::cerr << __FILE__ << ":" << __LINE__ << " aclError:" << __ret \ + << std::endl; \ + } \ + } while (0); + +/** + * @brief Read data from file. + * + * @param [in] file_path File path. + * @param [out] buffer Pointer to the buffer where the data is read. + * @param [in] size Size of the file and the buffer. + * @return Boolean indicating if the data read was successful or not. + */ +bool ReadFile(const std::string& file_path, void* buffer, size_t size) { + struct stat s_buf; + const int file_status = stat(file_path.data(), &s_buf); + if (file_status == -1) { + ERROR_LOG("Failed to read file."); + return false; + } + if (S_ISREG(s_buf.st_mode) == 0) { + ERROR_LOG("File does not exist: %s", file_path.c_str()); + return false; + } + + std::ifstream file; + file.open(file_path, std::ios::binary); + if (!file.is_open()) { + ERROR_LOG("Failed to open file. Path = %s", file_path.c_str()); + return false; + } + + std::filebuf* const buf = file.rdbuf(); + const size_t read_size = buf->pubseekoff(0, std::ios::end, std::ios::in); + if (read_size == 0) { + ERROR_LOG("%s: File is empty.", file_path.c_str()); + file.close(); + return false; + } + if (read_size > size) { + ERROR_LOG("%s: File size is larger than the buffer size.", + file_path.c_str()); + file.close(); + return false; + } + buf->pubseekpos(0, std::ios::in); + buf->sgetn(static_cast(buffer), read_size); + file.close(); + return true; +} + +/** + * @brief Write data to file. + * + * @param [in] file_path File path. + * @param [in] buffer Data to write to file. + * @param [in] size Size to write. + * @return Boolean indicating if the data write was successful or not. + */ +bool WriteFile(const std::string& file_path, const void* buffer, size_t size) { + if (buffer == nullptr) { + ERROR_LOG("Cannot write file from a nullptr buffer."); + return false; + } + + const int fd = + open(file_path.c_str(), O_RDWR | O_CREAT | O_TRUNC, S_IRUSR | S_IWRITE); + if (fd < 0) { + ERROR_LOG("Failed to open file. Path = %s", file_path.c_str()); + return false; + } + + const size_t writeSize = write(fd, buffer, size); + (void)close(fd); + if (writeSize != size) { + ERROR_LOG("Failed to write file."); + return false; + } + + return true; +} + +/// @private +template +void DoPrintData(const T* data, size_t count, size_t elements_per_row) { + assert(elements_per_row != 0); + for (size_t i = 0; i < count; ++i) { + if constexpr (std::is_same::value || + std::is_same::value) { + // cout treats int8 as char and doesn't output its numeric + // representation + std::cout << std::setw(10) << static_cast(data[i]); + } else { + std::cout << std::setw(10) << data[i]; + } + if (i % elements_per_row == elements_per_row - 1) { + std::cout << std::endl; + } + } +} + +/// @private +void DoPrintHalfData(const aclFloat16* data, size_t count, + size_t elements_per_row) { + assert(elements_per_row != 0); + for (size_t i = 0; i < count; ++i) { + std::cout << std::setw(10) << std::setprecision(6) + << aclFloat16ToFloat(data[i]); + if (i % elements_per_row == elements_per_row - 1) { + std::cout << std::endl; + } + } +} + +/** + * @brief Print array content. + * + * @param [in] data Pointer to the array. + * @param [in] count Number of elements to print. + * @param [in] data_type Data type of the elements. + * @param [in] elements_per_row Number of elements to be printed in a single + * row. + */ +void PrintData(const void* data, size_t count, PrintDataType data_type, + size_t elements_per_row = 16) { + if (data == nullptr) { + ERROR_LOG("Cannot print a nullptr buffer."); + return; + } + + switch (data_type) { + case PrintDataType::BOOL: + DoPrintData(reinterpret_cast(data), count, elements_per_row); + break; + case PrintDataType::INT8_T: + DoPrintData(reinterpret_cast(data), count, + elements_per_row); + break; + case PrintDataType::UINT8_T: + DoPrintData(reinterpret_cast(data), count, + elements_per_row); + break; + case PrintDataType::INT16_T: + DoPrintData(reinterpret_cast(data), count, + elements_per_row); + break; + case PrintDataType::UINT16_T: + DoPrintData(reinterpret_cast(data), count, + elements_per_row); + break; + case PrintDataType::INT32_T: + DoPrintData(reinterpret_cast(data), count, + elements_per_row); + break; + case PrintDataType::UINT32_T: + DoPrintData(reinterpret_cast(data), count, + elements_per_row); + break; + case PrintDataType::INT64_T: + DoPrintData(reinterpret_cast(data), count, + elements_per_row); + break; + case PrintDataType::UINT64_T: + DoPrintData(reinterpret_cast(data), count, + elements_per_row); + break; + case PrintDataType::HALF: + DoPrintHalfData(reinterpret_cast(data), count, + elements_per_row); + break; + case PrintDataType::FLOAT: + DoPrintData(reinterpret_cast(data), count, + elements_per_row); + break; + case PrintDataType::DOUBLE: + DoPrintData(reinterpret_cast(data), count, + elements_per_row); + break; + default: + ERROR_LOG("Unsupported type."); + } + std::cout << std::endl; +} + +/** + * @brief Prints beginning and end of a given vector. + * + * @param [in] data Pointer to the array. + * @param [in] dt Data type of the elements. + * @param [in] elems_to_print Number of elements to print both from the + * beginning and end. + * @param [in] vector_len Total number of elements in the vector. + * @param [in] msg Additional message printed at the beginning. + */ +template +void PrintVector(const T* data, PrintDataType dt, size_t elems_to_print, + size_t vector_len, std::string msg = "") { + std::cout << "==========================================" << std::endl; + if (msg != "") { + std::cout << msg << std::endl; + } + if (2 * elems_to_print >= vector_len) { + PrintData(data, vector_len, dt); + } else { + PrintData(data, elems_to_print, dt); + std::cout << "\t..." << std::endl; + const size_t tail_start = vector_len - elems_to_print; + PrintData(data + tail_start, elems_to_print, dt); + } + std::cout << "==========================================" << std::endl; +} diff --git a/examples/a5/main_abs.cpp b/examples/a5/main_abs.cpp new file mode 100644 index 00000000..045c49e7 --- /dev/null +++ b/examples/a5/main_abs.cpp @@ -0,0 +1,76 @@ +/** + * + * @file main_abs.cpp + * @brief Example of using the `abs` kernel. + */ + +#include + +#include "data_utils.h" + +extern "C" void call_vabs_fp16(uint32_t blockDim, aclrtStream stream, void* x, + void* z, uint32_t num_elements); + +/// Number of elements in input vectors. +constexpr size_t VABS_TOTAL_LENGTH = 8 * 128; + +int32_t main(int32_t argc, char* argv[]) { + uint32_t blockDim; + if (argc > 2) { + std::cerr << "Usage: ./" << argv[0] << " " << std::endl; + return 1; + } else if (argc == 2) { + blockDim = std::stoul(argv[1]); + std::cout << "[vabs] Use input BlockDim: " << blockDim << std::endl; + } else { + std::cout << "[vabs] Use default BlockDim: 8" << std::endl; + blockDim = 8; + } + + constexpr size_t inputByteSize = VABS_TOTAL_LENGTH * sizeof(uint16_t); + constexpr size_t outputByteSize = VABS_TOTAL_LENGTH * sizeof(uint16_t); + + CHECK_ACL(aclInit(nullptr)); + aclrtContext context; + const int32_t device_id = 0; + CHECK_ACL(aclrtSetDevice(device_id)); + CHECK_ACL(aclrtCreateContext(&context, device_id)); + aclrtStream stream = nullptr; + CHECK_ACL(aclrtCreateStream(&stream)); + + uint8_t *xHost, *zHost; + uint8_t *xDevice, *zDevice; + CHECK_ACL(aclrtMallocHost((void**)&xHost, inputByteSize)); + CHECK_ACL(aclrtMallocHost((void**)&zHost, outputByteSize)); + CHECK_ACL( + aclrtMalloc((void**)&xDevice, inputByteSize, ACL_MEM_MALLOC_HUGE_FIRST)); + CHECK_ACL( + aclrtMalloc((void**)&zDevice, outputByteSize, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./input/input_x.bin", xHost, inputByteSize); + PrintVector((uint16_t*)xHost, PrintDataType::HALF, 16, VABS_TOTAL_LENGTH, + "Input X"); + CHECK_ACL(aclrtMemcpy(xDevice, inputByteSize, xHost, inputByteSize, + ACL_MEMCPY_HOST_TO_DEVICE)); + + std::cout << "Init vabs_fp16 kernel" << std::endl; + call_vabs_fp16(blockDim, stream, xDevice, zDevice, VABS_TOTAL_LENGTH); + CHECK_ACL(aclrtSynchronizeStream(stream)); + + CHECK_ACL(aclrtMemcpy(zHost, outputByteSize, zDevice, outputByteSize, + ACL_MEMCPY_DEVICE_TO_HOST)); + PrintVector((uint16_t*)zHost, PrintDataType::HALF, 16, VABS_TOTAL_LENGTH, + "Output"); + WriteFile("vabs_output.bin", zHost, outputByteSize); + + CHECK_ACL(aclrtFree(xDevice)); + CHECK_ACL(aclrtFree(zDevice)); + CHECK_ACL(aclrtFreeHost(xHost)); + CHECK_ACL(aclrtFreeHost(zHost)); + + CHECK_ACL(aclrtDestroyStream(stream)); + CHECK_ACL(aclrtDestroyContext(context)); + CHECK_ACL(aclrtResetDevice(device_id)); + CHECK_ACL(aclFinalize()); + return 0; +} diff --git a/scripts/data_gen_abs.py b/scripts/data_gen_abs.py new file mode 100644 index 00000000..d688d6cd --- /dev/null +++ b/scripts/data_gen_abs.py @@ -0,0 +1,14 @@ +#!/usr/bin/python3 +# -*- coding:utf-8 -*- +# Copyright 2026 Huawei Technologies Co., Ltd +from pathlib import Path + +import numpy as np + +if __name__ == "__main__": + shape = [8, 128] + + rng = np.random.default_rng(seed=42) + input_x = rng.uniform(-100, 100, shape).astype(np.float16) + Path("./input").mkdir(parents=True, exist_ok=True) + input_x.tofile("./input/input_x.bin")