From 4f45a18c22b96070f3a61f97a090b1ee870d475e Mon Sep 17 00:00:00 2001 From: Vladimir Loncar Date: Mon, 30 Mar 2026 09:30:50 +0000 Subject: [PATCH 01/17] Rough initial implementation, still not working --- CMakeLists.txt | 1 + csrc/host/pybind11.cpp | 3 + csrc/host/torch_histogram.h | 73 ++++++++++ csrc/kernel/kernel_histogram.cpp | 229 +++++++++++++++++++++++++++++++ tests/test_histogram.py | 27 ++++ 5 files changed, 333 insertions(+) create mode 100644 csrc/host/torch_histogram.h create mode 100644 csrc/kernel/kernel_histogram.cpp create mode 100644 tests/test_histogram.py diff --git a/CMakeLists.txt b/CMakeLists.txt index f24285ff..7c2f0e68 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -109,6 +109,7 @@ ascendc_library( SHARED csrc/kernel/kernel_tri_inv_col_sweep.cpp csrc/kernel/kernel_abs.cpp + csrc/kernel/kernel_histogram.cpp csrc/kernel/kernel_simple_matmul.cpp csrc/kernel/kernel_batch_matrix_square.cpp csrc/kernel/kernel_tri_inv_rec_unroll.cpp diff --git a/csrc/host/pybind11.cpp b/csrc/host/pybind11.cpp index bcd429c1..5ce3353a 100644 --- a/csrc/host/pybind11.cpp +++ b/csrc/host/pybind11.cpp @@ -11,6 +11,7 @@ for the full License text. #include "torch_abs.h" #include "torch_batch_matrix_square.h" +#include "torch_histogram.h" #include "torch_simple_matmul.h" #include "torch_tri_inv.h" #include "torch_tri_inv_rec_unroll.h" @@ -34,6 +35,8 @@ PYBIND11_MODULE(pto_kernels_ops, m) { }, pybind11::arg("device_id") = 0); m.def("pto_abs", &pto_isa_ops::run_abs); + m.def("pto_histogram", &pto_isa_ops::run_histogram, py::arg("x"), + py::arg("bins") = 100, py::arg("min") = 0.0, py::arg("max") = 0.0); m.def("pto_batch_matrix_square", &pto_isa_ops::run_batch_matrix_square); m.def("pto_simple_matmul", &pto_isa_ops::run_simple_matmul); m.def("pto_tri_inv_trick", &pto_isa_ops::run_tri_inv_trick); diff --git a/csrc/host/torch_histogram.h b/csrc/host/torch_histogram.h new file mode 100644 index 00000000..5a734073 --- /dev/null +++ b/csrc/host/torch_histogram.h @@ -0,0 +1,73 @@ +/** +Copyright (c) 2026 Huawei Technologies Co., Ltd. +All rights reserved. + +See LICENSE in the root of the software repository: +https://github.com/huawei-csl/pto-kernels/ +for the full License text. +*/ +#pragma once + +#include +#include + +#include "aclrtlaunch_vhistogram_fp16.h" +#include "aclrtlaunch_vhistogram_fp32.h" +#include "utils.h" + +namespace pto_isa_ops { + +/** + * @brief Computes the histogram of a tensor. + * + * @param [in] x Input tensor of dtype fp16 or fp32. + * @param [in] bins Number of histogram bins. + * @param [in] min_val Lower bound of the range. + * @param [in] max_val Upper bound of the range. + * @return at::Tensor Computed histogram tensor. + */ +at::Tensor run_histogram(const at::Tensor& x, int64_t bins = 100, + double min_val = 0.0, double max_val = 0.0) { + + const uint32_t total_len = x.numel(); + // FIXME: tile length is fixed to 64 for now + constexpr uint32_t TILE_LEN = 64; + const uint32_t total_tiles = total_len / TILE_LEN; + uint32_t num_cores = GetNumVectorCores(); + if (total_tiles < num_cores) { + num_cores = total_tiles; + } + + const auto dtype = x.options().dtype(); + const auto device = x.options().device(); + auto z_opts = at::TensorOptions() + .dtype(at::kInt) // Set data type to int32 for histogram counts + .device(device); + // Allocate a 1D tensor sized `[bins]` for the histogram. + at::Tensor z = at::zeros({bins}, z_opts); + at::Tensor z_local = at::zeros({num_cores, bins}, z_opts); + + const auto num_bins = static_cast(bins); + + if (min_val == 0.0 && max_val == 0.0) { + const double tensor_min = x.min().item(); + const double tensor_max = x.max().item(); + min_val = tensor_min; + max_val = tensor_max; + } + + const auto f_min_val = static_cast(min_val); + const auto f_max_val = static_cast(max_val); + + if (dtype == at::kHalf) { + EXEC_KERNEL_CMD(vhistogram_fp16, total_tiles, x, z, z_local, total_len, num_bins, + f_min_val, f_max_val); + } else if (dtype == at::kFloat) { + EXEC_KERNEL_CMD(vhistogram_fp32, total_tiles, x, z, z_local, total_len, num_bins, + f_min_val, f_max_val); + } else { + throw std::runtime_error("Unsupported dtype for `pto_histogram` kernel"); + } + return z; +} +} // namespace pto_isa_ops diff --git a/csrc/kernel/kernel_histogram.cpp b/csrc/kernel/kernel_histogram.cpp new file mode 100644 index 00000000..736780bf --- /dev/null +++ b/csrc/kernel/kernel_histogram.cpp @@ -0,0 +1,229 @@ +/** +Copyright (c) 2026 Huawei Technologies Co., Ltd. +All rights reserved. + +See LICENSE in the root of the software repository: +https://github.com/huawei-csl/pto-kernels/ +for the full License text. +*/ + +#if __CCE_AICORE__ == 220 && defined(__DAV_C220_VEC__) + +#define MEMORY_BASE + +#include +#include +#include "kernel_utils.h" + +#define GM_ADDR __gm__ uint8_t* // To avoid #include "kernel_operator.h" + +using namespace pto; + +template +AICORE void runTHistogram(__gm__ T* x, __gm__ int32_t* z, __gm__ int32_t* z_local, const uint32_t total_length, + const int32_t num_bins, const float min_val, const float max_val) { + + set_mask_norm(); + set_vector_mask(-1, -1); + + // --- Define Global Tensors --- + const uint32_t tile_num_elems = TILE_LEN * TILE_LEN; + using InputShape = pto::Shape<1, 1, 1, TILE_LEN, TILE_LEN>; + using InputStride = pto::Stride<1, 1, 1, TILE_LEN, 1>; + using InputGlobalData = pto::GlobalTensor; + using HistGlobalData = pto::GlobalTensor, pto::Stride<1, 1, 1, 1, 1>>; + + // --- Define UB Tiles --- + // Align num_bins for vector processing. Each int32_t is 4 bytes. 32-byte alignment. + const uint32_t num_bins_aligned = kernel_utils::CeilDiv(num_bins, (uint32_t)(32 / sizeof(int32_t))) * (32 / sizeof(int32_t)); + + // UB Memory Layout + constexpr uint32_t UB_X_TILES_ADDR = 0; + const uint32_t UB_CUR_MASK_ADDR = UB_X_TILES_ADDR + tile_num_elems * sizeof(T); + const uint32_t UB_CUR_MASK_I32_ADDR = UB_CUR_MASK_ADDR + tile_num_elems * sizeof(uint8_t); + const uint32_t UB_PREV_MASK_I32_ADDR = UB_CUR_MASK_I32_ADDR + tile_num_elems * sizeof(int32_t); + const uint32_t UB_BIN_MASK_I32_ADDR = UB_PREV_MASK_I32_ADDR + tile_num_elems * sizeof(int32_t); + const uint32_t UB_BIN_MASK_F32_ADDR = UB_BIN_MASK_I32_ADDR + tile_num_elems * sizeof(int32_t); + const uint32_t UB_ROW_SUM_ADDR = UB_BIN_MASK_F32_ADDR + tile_num_elems * sizeof(float); + const uint32_t UB_COUNT_ADDR = UB_ROW_SUM_ADDR + TILE_LEN * 8 * sizeof(float); + const uint32_t UB_LOCAL_HIST_ADDR = UB_COUNT_ADDR + 8 * sizeof(float); + + // Input tile + using InputTileData = Tile; + InputTileData xTiles(TILE_LEN, TILE_LEN); + TASSIGN(xTiles, UB_X_TILES_ADDR); + + // Mask tiles for binning + using MaskTileData = Tile; + MaskTileData current_mask(TILE_LEN, TILE_LEN); + TASSIGN(current_mask, UB_CUR_MASK_ADDR); + + // Tiles for reduction (counting) + using I32TileData = Tile; + I32TileData current_mask_i32(TILE_LEN, TILE_LEN); + TASSIGN(current_mask_i32, UB_CUR_MASK_I32_ADDR); + I32TileData prev_mask_i32(TILE_LEN, TILE_LEN); + TASSIGN(prev_mask_i32, UB_PREV_MASK_I32_ADDR); + I32TileData bin_mask_i32(TILE_LEN, TILE_LEN); + TASSIGN(bin_mask_i32, UB_BIN_MASK_I32_ADDR); + + // Tiles for reduction (counting) - float versions for reduction ops + using F32TileData = Tile; + F32TileData bin_mask_f32(TILE_LEN, TILE_LEN); + TASSIGN(bin_mask_f32, UB_BIN_MASK_F32_ADDR); + + using FloatRowSumTile = Tile; + FloatRowSumTile row_sum_f32_tile(TILE_LEN, 8); + TASSIGN(row_sum_f32_tile, UB_ROW_SUM_ADDR); + + using FloatCountTile = Tile; + FloatCountTile count_f32_tile(1, 8); + TASSIGN(count_f32_tile, UB_COUNT_ADDR); + + // Local histogram tile + constexpr uint32_t MAX_BINS = 8192; + using HistTile = Tile; + HistTile localHist(num_bins_aligned); + TASSIGN(localHist, UB_LOCAL_HIST_ADDR); + TEXPANDS(localHist, (int32_t)0); + + // --- Phase 1: Local histogram calculation --- + const uint32_t num_tiles_total = total_length / tile_num_elems; + const uint32_t num_tiles_per_core = num_tiles_total / get_block_num(); + const uint32_t start_tile_idx = get_block_idx() * num_tiles_per_core; + const uint32_t end_tile_idx = start_tile_idx + num_tiles_per_core; + + const float bin_width = (max_val - min_val) / num_bins; + + for (uint32_t tile_idx = start_tile_idx; tile_idx < end_tile_idx; ++tile_idx) { + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + + // Load input tile from GM to UB + const uint32_t offset = tile_idx * tile_num_elems; + InputGlobalData xGlobal(x + offset); + TLOAD(xTiles, xGlobal); + + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + + // Initialize prev_mask to all zeros for the first bin boundary check + TEXPANDS(prev_mask_i32, (int32_t)0); + + for (int32_t j = 0; j < num_bins; ++j) { + float bin_upper_bound = min_val + (j + 1) * bin_width; + CmpMode mode = CmpMode::LT; + + // The last bin is inclusive on the upper bound + if (j == num_bins - 1) { + bin_upper_bound = max_val; + mode = CmpMode::LE; + } + + // Create a mask for elements less than (or equal to) the bin's upper bound. + // The result of TCMPS is a tile where elements are 0 or 1. + TCMPS(current_mask, xTiles, static_cast(bin_upper_bound), mode); + + // Convert the uint8_t mask to int32_t. + TCVT(current_mask_i32, current_mask, RoundMode::CAST_NONE); + + // The elements in the current bin are those in the current mask but not the previous one. + // Should have been done with TXOR but that fails. Since prev_mask is a subset of current_mask, + // using TSUB is the same. + TSUB(bin_mask_i32, current_mask_i32, prev_mask_i32); + + // TROWSUM/TCOLSUM do not support int32_t, so convert to float for reduction. + TCVT(bin_mask_f32, bin_mask_i32, RoundMode::CAST_NONE); + + // Reduce the 2D tile to a single scalar value. + // This requires a temporary tile for the intermediate row sums. + TROWSUM(row_sum_f32_tile, bin_mask_f32, row_sum_f32_tile); // In-place temporary for some targets + TCOLSUM(count_f32_tile, row_sum_f32_tile, row_sum_f32_tile, true); + + set_flag(PIPE_V, PIPE_S, EVENT_ID0); + wait_flag(PIPE_V, PIPE_S, EVENT_ID0); + + // Add the count to the local histogram for the current bin. + // This part is scalar as we update one bin at a time. + float f_count = count_f32_tile.GetValue(0); + int32_t count = static_cast(f_count + 0.5f); // Round to nearest int + if (count > 0) { + int32_t current_bin_count = localHist.GetValue(j); + localHist.SetValue(j, current_bin_count + count); + } + + set_flag(PIPE_S, PIPE_V, EVENT_ID0); + wait_flag(PIPE_S, PIPE_V, EVENT_ID0); + + // The current mask becomes the previous mask for the next iteration. + TMOV(prev_mask_i32, current_mask_i32); + } + } + + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + + // Store local histogram to GM + const uint32_t local_hist_offset = get_block_idx() * num_bins; + HistGlobalData zLocalGlobal(z_local + local_hist_offset, {num_bins}); + TSTORE(zLocalGlobal, localHist); + + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + + // Barrier to ensure all local histograms are in GM before reduction phase + pipe_barrier(PIPE_ALL); + + // --- Phase 2: Reduction of local histograms --- + if (get_block_idx() == 0) { + // Block 0's local histogram is already in its UB. + // Now, add histograms from other blocks. + HistTile otherHist(num_bins_aligned); + TASSIGN(otherHist, UB_X_TILES_ADDR); // Reuse UB space from the beginning + + for (uint32_t i = 1; i < get_block_num(); ++i) { + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + + // Load other block's histogram + const uint32_t other_hist_offset = i * num_bins; + HistGlobalData otherHistGlobal(z_local + other_hist_offset, {num_bins}); + TLOAD(otherHist, otherHistGlobal); + + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + + // Add to the main histogram + TADD(localHist, localHist, otherHist); + } + + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + + // Store final histogram to z + HistGlobalData zGlobal(z, {num_bins}); + TSTORE(zGlobal, localHist); + + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + } +} + +extern "C" __global__ AICORE void vhistogram_fp16(GM_ADDR x, GM_ADDR z, GM_ADDR z_local, + const uint32_t in_length, + const int32_t num_bins, + const float min_val, const float max_val) { + constexpr unsigned TILE_LEN = 64; + runTHistogram((__gm__ half*)x, (__gm__ int32_t*)z, (__gm__ int32_t*)z_local, in_length, + num_bins, min_val, max_val); +} + +extern "C" __global__ AICORE void vhistogram_fp32(GM_ADDR x, GM_ADDR z, GM_ADDR z_local, + const uint32_t in_length, + const int32_t num_bins, + const float min_val, const float max_val) { + constexpr unsigned TILE_LEN = 64; + runTHistogram((__gm__ float*)x, (__gm__ int32_t*)z, (__gm__ int32_t*)z_local, in_length, + num_bins, min_val, max_val); +} +#endif diff --git a/tests/test_histogram.py b/tests/test_histogram.py new file mode 100644 index 00000000..aaa840e8 --- /dev/null +++ b/tests/test_histogram.py @@ -0,0 +1,27 @@ +# -------------------------------------------------------------------------------- +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# All rights reserved. +# See LICENSE in the root of the software repository: +# https://github.com/huawei-csl/pto-kernels/ +# for the full License text. +# -------------------------------------------------------------------------------- + +import torch +from pto_kernels import pto_histogram +import pytest + + +@pytest.mark.parametrize("num_blocks", [1, 2, 10, 20, 32, 64]) +@pytest.mark.parametrize("bins", [2, 4, 16, 50, 100]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=str) +def test_pto_histogram(num_blocks: int, bins: int, dtype: torch.dtype): + tile_len = 64 + length = [num_blocks * tile_len] + + x = torch.rand(length, device="cpu", dtype=dtype) + x_npu = x.npu() + + y_npu = pto_histogram(x_npu, bins=bins).cpu() + y_cpu = torch.histc(x, bins=bins) + + assert torch.equal(y_npu, y_cpu) From 47b7654ace11caa5031d68967527c2ddcd05291f Mon Sep 17 00:00:00 2001 From: Vladimir Loncar Date: Tue, 31 Mar 2026 09:43:12 +0000 Subject: [PATCH 02/17] Split into 2 phases, WIP --- csrc/host/torch_histogram.h | 29 ++-- csrc/kernel/kernel_histogram.cpp | 218 +++++++++++++++++-------------- tests/test_histogram.py | 6 +- 3 files changed, 145 insertions(+), 108 deletions(-) diff --git a/csrc/host/torch_histogram.h b/csrc/host/torch_histogram.h index 5a734073..429768ab 100644 --- a/csrc/host/torch_histogram.h +++ b/csrc/host/torch_histogram.h @@ -11,8 +11,10 @@ for the full License text. #include #include -#include "aclrtlaunch_vhistogram_fp16.h" -#include "aclrtlaunch_vhistogram_fp32.h" +#include "aclrtlaunch_vhistogram_local_fp16.h" +#include "aclrtlaunch_vhistogram_local_fp32.h" +#include "aclrtlaunch_vhistogram_reduce_fp16.h" +#include "aclrtlaunch_vhistogram_reduce_fp32.h" #include "utils.h" namespace pto_isa_ops { @@ -29,14 +31,15 @@ namespace pto_isa_ops { at::Tensor run_histogram(const at::Tensor& x, int64_t bins = 100, double min_val = 0.0, double max_val = 0.0) { + const uint32_t total_len = x.numel(); - // FIXME: tile length is fixed to 64 for now - constexpr uint32_t TILE_LEN = 64; - const uint32_t total_tiles = total_len / TILE_LEN; + // // FIXME: tile length is fixed to 64 for now + // constexpr uint32_t TILE_LEN = 64; + // const uint32_t total_tiles = total_len / TILE_LEN; uint32_t num_cores = GetNumVectorCores(); - if (total_tiles < num_cores) { - num_cores = total_tiles; - } + // if (total_tiles < num_cores) { + // num_cores = total_tiles; + // } const auto dtype = x.options().dtype(); const auto device = x.options().device(); @@ -59,15 +62,21 @@ at::Tensor run_histogram(const at::Tensor& x, int64_t bins = 100, const auto f_min_val = static_cast(min_val); const auto f_max_val = static_cast(max_val); + // Phase 1: Launch one kernel per core to compute local histograms if (dtype == at::kHalf) { - EXEC_KERNEL_CMD(vhistogram_fp16, total_tiles, x, z, z_local, total_len, num_bins, + EXEC_KERNEL_CMD(vhistogram_local_fp16, num_cores, x, z_local, total_len, num_bins, f_min_val, f_max_val); } else if (dtype == at::kFloat) { - EXEC_KERNEL_CMD(vhistogram_fp32, total_tiles, x, z, z_local, total_len, num_bins, + EXEC_KERNEL_CMD(vhistogram_local_fp32, num_cores, x, z_local, total_len, num_bins, f_min_val, f_max_val); } else { throw std::runtime_error("Unsupported dtype for `pto_histogram` kernel"); } + + // Phase 2: Launch a single kernel to reduce all local histograms + const uint32_t num_reduce_cores = 1; + EXEC_KERNEL_CMD(vhistogram_reduce_fp32, num_reduce_cores, z, z_local, num_bins, num_cores); + return z; } } // namespace pto_isa_ops diff --git a/csrc/kernel/kernel_histogram.cpp b/csrc/kernel/kernel_histogram.cpp index 736780bf..fd290fde 100644 --- a/csrc/kernel/kernel_histogram.cpp +++ b/csrc/kernel/kernel_histogram.cpp @@ -19,79 +19,94 @@ for the full License text. using namespace pto; +/** + * runTLocalHistogram - Phase 1: Local histogram calculation per core. + */ template -AICORE void runTHistogram(__gm__ T* x, __gm__ int32_t* z, __gm__ int32_t* z_local, const uint32_t total_length, - const int32_t num_bins, const float min_val, const float max_val) { - +AICORE void runTLocalHistogram(__gm__ T* x, __gm__ int32_t* z_local, const uint32_t total_length, + const int32_t num_bins, const float min_val, const float max_val) +{ set_mask_norm(); set_vector_mask(-1, -1); // --- Define Global Tensors --- - const uint32_t tile_num_elems = TILE_LEN * TILE_LEN; + constexpr uint32_t TILE_SIZE = TILE_LEN * TILE_LEN; using InputShape = pto::Shape<1, 1, 1, TILE_LEN, TILE_LEN>; using InputStride = pto::Stride<1, 1, 1, TILE_LEN, 1>; using InputGlobalData = pto::GlobalTensor; + + // Align num_bins for vector processing and GM 32-byte alignment. + const uint32_t num_bins_aligned = kernel_utils::CeilDiv(num_bins, 8) * 8; using HistGlobalData = pto::GlobalTensor, pto::Stride<1, 1, 1, 1, 1>>; - // --- Define UB Tiles --- - // Align num_bins for vector processing. Each int32_t is 4 bytes. 32-byte alignment. - const uint32_t num_bins_aligned = kernel_utils::CeilDiv(num_bins, (uint32_t)(32 / sizeof(int32_t))) * (32 / sizeof(int32_t)); - - // UB Memory Layout - constexpr uint32_t UB_X_TILES_ADDR = 0; - const uint32_t UB_CUR_MASK_ADDR = UB_X_TILES_ADDR + tile_num_elems * sizeof(T); - const uint32_t UB_CUR_MASK_I32_ADDR = UB_CUR_MASK_ADDR + tile_num_elems * sizeof(uint8_t); - const uint32_t UB_PREV_MASK_I32_ADDR = UB_CUR_MASK_I32_ADDR + tile_num_elems * sizeof(int32_t); - const uint32_t UB_BIN_MASK_I32_ADDR = UB_PREV_MASK_I32_ADDR + tile_num_elems * sizeof(int32_t); - const uint32_t UB_BIN_MASK_F32_ADDR = UB_BIN_MASK_I32_ADDR + tile_num_elems * sizeof(int32_t); - const uint32_t UB_ROW_SUM_ADDR = UB_BIN_MASK_F32_ADDR + tile_num_elems * sizeof(float); - const uint32_t UB_COUNT_ADDR = UB_ROW_SUM_ADDR + TILE_LEN * 8 * sizeof(float); - const uint32_t UB_LOCAL_HIST_ADDR = UB_COUNT_ADDR + 8 * sizeof(float); + // --- Define UB Tiles and Memory Layout --- + constexpr uint32_t MASK_COLS = TILE_LEN / 8; + constexpr uint32_t MASK_CAPACITY_COLS = 32; + + uint32_t addr = 0; + const uint32_t UB_X_ADDR = addr; addr += TILE_SIZE * sizeof(T); + const uint32_t UB_CUR_MASK_ADDR = addr; addr += TILE_LEN * MASK_CAPACITY_COLS * sizeof(uint8_t); + const uint32_t UB_CUR_F32_ADDR = addr; addr += TILE_SIZE * sizeof(float); + const uint32_t UB_PREV_F32_ADDR = addr; addr += TILE_SIZE * sizeof(float); + const uint32_t UB_BIN_F32_ADDR = addr; addr += TILE_SIZE * sizeof(float); + const uint32_t UB_ONE_ADDR = addr; addr += TILE_SIZE * sizeof(float); + const uint32_t UB_ZERO_ADDR = addr; addr += TILE_SIZE * sizeof(float); + const uint32_t UB_REDUCE_TMP_ADDR = addr; addr += TILE_SIZE * sizeof(float); + const uint32_t UB_RSUM_ADDR = addr; addr += TILE_LEN * 8 * sizeof(float); // 64x8 for alignment + const uint32_t UB_COUNT_ADDR = addr; addr += 8 * sizeof(float); + const uint32_t UB_LOCAL_HIST_ADDR = addr; // Input tile using InputTileData = Tile; InputTileData xTiles(TILE_LEN, TILE_LEN); - TASSIGN(xTiles, UB_X_TILES_ADDR); + TASSIGN(xTiles, UB_X_ADDR); - // Mask tiles for binning - using MaskTileData = Tile; - MaskTileData current_mask(TILE_LEN, TILE_LEN); + // Mask tile (packed bits) + using MaskTileData = Tile; + MaskTileData current_mask(TILE_LEN, MASK_COLS); TASSIGN(current_mask, UB_CUR_MASK_ADDR); - // Tiles for reduction (counting) - using I32TileData = Tile; - I32TileData current_mask_i32(TILE_LEN, TILE_LEN); - TASSIGN(current_mask_i32, UB_CUR_MASK_I32_ADDR); - I32TileData prev_mask_i32(TILE_LEN, TILE_LEN); - TASSIGN(prev_mask_i32, UB_PREV_MASK_I32_ADDR); - I32TileData bin_mask_i32(TILE_LEN, TILE_LEN); - TASSIGN(bin_mask_i32, UB_BIN_MASK_I32_ADDR); - - // Tiles for reduction (counting) - float versions for reduction ops + // Float conversion tiles using F32TileData = Tile; + F32TileData cur_f32(TILE_LEN, TILE_LEN); + TASSIGN(cur_f32, UB_CUR_F32_ADDR); + F32TileData prev_f32(TILE_LEN, TILE_LEN); + TASSIGN(prev_f32, UB_PREV_F32_ADDR); F32TileData bin_mask_f32(TILE_LEN, TILE_LEN); - TASSIGN(bin_mask_f32, UB_BIN_MASK_F32_ADDR); - - using FloatRowSumTile = Tile; - FloatRowSumTile row_sum_f32_tile(TILE_LEN, 8); - TASSIGN(row_sum_f32_tile, UB_ROW_SUM_ADDR); - - using FloatCountTile = Tile; - FloatCountTile count_f32_tile(1, 8); + TASSIGN(bin_mask_f32, UB_BIN_F32_ADDR); + + F32TileData one_tile(TILE_LEN, TILE_LEN); + TASSIGN(one_tile, UB_ONE_ADDR); + F32TileData zero_tile(TILE_LEN, TILE_LEN); + TASSIGN(zero_tile, UB_ZERO_ADDR); + TEXPANDS(one_tile, 1.0f); + TEXPANDS(zero_tile, 0.0f); + + F32TileData reduce_tmp(TILE_LEN, TILE_LEN); + TASSIGN(reduce_tmp, UB_REDUCE_TMP_ADDR); + + // Reduction result tiles. + // For RowMajor and NoneBox, Cols must be a multiple of 8 (32 bytes for float). + using FloatRowSumTile = Tile; + FloatRowSumTile row_sum_f32_tile(1); // Set ValidCol to 1 + TASSIGN(row_sum_f32_tile, UB_RSUM_ADDR); + + using FloatCountTile = Tile; + FloatCountTile count_f32_tile; TASSIGN(count_f32_tile, UB_COUNT_ADDR); // Local histogram tile constexpr uint32_t MAX_BINS = 8192; using HistTile = Tile; - HistTile localHist(num_bins_aligned); + HistTile localHist(static_cast(num_bins_aligned)); TASSIGN(localHist, UB_LOCAL_HIST_ADDR); TEXPANDS(localHist, (int32_t)0); // --- Phase 1: Local histogram calculation --- - const uint32_t num_tiles_total = total_length / tile_num_elems; - const uint32_t num_tiles_per_core = num_tiles_total / get_block_num(); + const uint32_t num_tiles_total = kernel_utils::CeilDiv(total_length, TILE_SIZE); + const uint32_t num_tiles_per_core = kernel_utils::CeilDiv(num_tiles_total, get_block_num()); const uint32_t start_tile_idx = get_block_idx() * num_tiles_per_core; - const uint32_t end_tile_idx = start_tile_idx + num_tiles_per_core; + const uint32_t end_tile_idx = (start_tile_idx + num_tiles_per_core > num_tiles_total) ? num_tiles_total : (start_tile_idx + num_tiles_per_core); const float bin_width = (max_val - min_val) / num_bins; @@ -99,131 +114,142 @@ AICORE void runTHistogram(__gm__ T* x, __gm__ int32_t* z, __gm__ int32_t* z_loca set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); - // Load input tile from GM to UB - const uint32_t offset = tile_idx * tile_num_elems; + const uint32_t offset = tile_idx * TILE_SIZE; InputGlobalData xGlobal(x + offset); TLOAD(xTiles, xGlobal); set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); - // Initialize prev_mask to all zeros for the first bin boundary check - TEXPANDS(prev_mask_i32, (int32_t)0); + TCMPS(current_mask, xTiles, static_cast(min_val), CmpMode::LT); + TSEL(prev_f32, current_mask, one_tile, zero_tile); for (int32_t j = 0; j < num_bins; ++j) { float bin_upper_bound = min_val + (j + 1) * bin_width; CmpMode mode = CmpMode::LT; - - // The last bin is inclusive on the upper bound if (j == num_bins - 1) { bin_upper_bound = max_val; mode = CmpMode::LE; } - // Create a mask for elements less than (or equal to) the bin's upper bound. - // The result of TCMPS is a tile where elements are 0 or 1. TCMPS(current_mask, xTiles, static_cast(bin_upper_bound), mode); + TSEL(cur_f32, current_mask, one_tile, zero_tile); - // Convert the uint8_t mask to int32_t. - TCVT(current_mask_i32, current_mask, RoundMode::CAST_NONE); - - // The elements in the current bin are those in the current mask but not the previous one. - // Should have been done with TXOR but that fails. Since prev_mask is a subset of current_mask, - // using TSUB is the same. - TSUB(bin_mask_i32, current_mask_i32, prev_mask_i32); + TSUB(bin_mask_f32, cur_f32, prev_f32); - // TROWSUM/TCOLSUM do not support int32_t, so convert to float for reduction. - TCVT(bin_mask_f32, bin_mask_i32, RoundMode::CAST_NONE); - - // Reduce the 2D tile to a single scalar value. - // This requires a temporary tile for the intermediate row sums. - TROWSUM(row_sum_f32_tile, bin_mask_f32, row_sum_f32_tile); // In-place temporary for some targets + TROWSUM(row_sum_f32_tile, bin_mask_f32, reduce_tmp); TCOLSUM(count_f32_tile, row_sum_f32_tile, row_sum_f32_tile, true); set_flag(PIPE_V, PIPE_S, EVENT_ID0); wait_flag(PIPE_V, PIPE_S, EVENT_ID0); - // Add the count to the local histogram for the current bin. - // This part is scalar as we update one bin at a time. float f_count = count_f32_tile.GetValue(0); - int32_t count = static_cast(f_count + 0.5f); // Round to nearest int + int32_t count = static_cast(f_count + 0.5f); if (count > 0) { - int32_t current_bin_count = localHist.GetValue(j); - localHist.SetValue(j, current_bin_count + count); + localHist.SetValue(j, localHist.GetValue(j) + count); } set_flag(PIPE_S, PIPE_V, EVENT_ID0); wait_flag(PIPE_S, PIPE_V, EVENT_ID0); - // The current mask becomes the previous mask for the next iteration. - TMOV(prev_mask_i32, current_mask_i32); + TMOV(prev_f32, cur_f32); } } set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); - // Store local histogram to GM - const uint32_t local_hist_offset = get_block_idx() * num_bins; - HistGlobalData zLocalGlobal(z_local + local_hist_offset, {num_bins}); + const uint32_t local_hist_offset = get_block_idx() * num_bins_aligned; + HistGlobalData zLocalGlobal(z_local + local_hist_offset, {static_cast(num_bins_aligned)}); TSTORE(zLocalGlobal, localHist); set_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); +} - // Barrier to ensure all local histograms are in GM before reduction phase - pipe_barrier(PIPE_ALL); +/** + * runTReduceHistogram - Phase 2: Reduction of local histograms from all cores. + */ +AICORE void runTReduceHistogram(__gm__ int32_t* z, __gm__ int32_t* z_local, + const int32_t num_bins, const uint32_t num_cores) { + set_mask_norm(); + set_vector_mask(-1, -1); + + const uint32_t num_bins_aligned = kernel_utils::CeilDiv(num_bins, 8) * 8; + using HistGlobalData = pto::GlobalTensor, pto::Stride<1, 1, 1, 1, 1>>; + + constexpr uint32_t UB_MAIN_HIST_ADDR = 0; + const uint32_t UB_OTHER_HIST_ADDR = UB_MAIN_HIST_ADDR + num_bins_aligned * sizeof(int32_t); + + constexpr uint32_t MAX_BINS = 8192; + using HistTile = Tile; - // --- Phase 2: Reduction of local histograms --- if (get_block_idx() == 0) { - // Block 0's local histogram is already in its UB. - // Now, add histograms from other blocks. - HistTile otherHist(num_bins_aligned); - TASSIGN(otherHist, UB_X_TILES_ADDR); // Reuse UB space from the beginning + HistTile mainHist(static_cast(num_bins_aligned)); + TASSIGN(mainHist, UB_MAIN_HIST_ADDR); - for (uint32_t i = 1; i < get_block_num(); ++i) { + HistGlobalData mainHistGlobal(z_local, {static_cast(num_bins_aligned)}); + TLOAD(mainHist, mainHistGlobal); + + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + + HistTile otherHist(static_cast(num_bins_aligned)); + TASSIGN(otherHist, UB_OTHER_HIST_ADDR); + + for (uint32_t i = 1; i < num_cores; ++i) { set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); - // Load other block's histogram - const uint32_t other_hist_offset = i * num_bins; - HistGlobalData otherHistGlobal(z_local + other_hist_offset, {num_bins}); + const uint32_t other_hist_offset = i * num_bins_aligned; + HistGlobalData otherHistGlobal(z_local + other_hist_offset, {static_cast(num_bins_aligned)}); TLOAD(otherHist, otherHistGlobal); set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); - // Add to the main histogram - TADD(localHist, localHist, otherHist); + TADD(mainHist, mainHist, otherHist); } set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); - // Store final histogram to z - HistGlobalData zGlobal(z, {num_bins}); - TSTORE(zGlobal, localHist); + // Create a new tile object sharing the same address to set ValidCol to num_bins for final store + HistTile finalHist(static_cast(num_bins)); + TASSIGN(finalHist, UB_MAIN_HIST_ADDR); + HistGlobalData zGlobal(z, {static_cast(num_bins)}); + TSTORE(zGlobal, finalHist); set_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); } } -extern "C" __global__ AICORE void vhistogram_fp16(GM_ADDR x, GM_ADDR z, GM_ADDR z_local, +extern "C" __global__ AICORE void vhistogram_local_fp16(GM_ADDR x, GM_ADDR z_local, const uint32_t in_length, const int32_t num_bins, const float min_val, const float max_val) { constexpr unsigned TILE_LEN = 64; - runTHistogram((__gm__ half*)x, (__gm__ int32_t*)z, (__gm__ int32_t*)z_local, in_length, + runTLocalHistogram((__gm__ half*)x, (__gm__ int32_t*)z_local, in_length, num_bins, min_val, max_val); } -extern "C" __global__ AICORE void vhistogram_fp32(GM_ADDR x, GM_ADDR z, GM_ADDR z_local, +extern "C" __global__ AICORE void vhistogram_local_fp32(GM_ADDR x, GM_ADDR z_local, const uint32_t in_length, const int32_t num_bins, const float min_val, const float max_val) { constexpr unsigned TILE_LEN = 64; - runTHistogram((__gm__ float*)x, (__gm__ int32_t*)z, (__gm__ int32_t*)z_local, in_length, + runTLocalHistogram((__gm__ float*)x, (__gm__ int32_t*)z_local, in_length, num_bins, min_val, max_val); } -#endif + +extern "C" __global__ AICORE void vhistogram_reduce_fp16(__gm__ int32_t* z, __gm__ int32_t* z_local, + const int32_t num_bins, const uint32_t num_cores) { + runTReduceHistogram(z, z_local, num_bins, num_cores); +} + +extern "C" __global__ AICORE void vhistogram_reduce_fp32(__gm__ int32_t* z, __gm__ int32_t* z_local, + const int32_t num_bins, const uint32_t num_cores) { + runTReduceHistogram(z, z_local, num_bins, num_cores); +} +#endif \ No newline at end of file diff --git a/tests/test_histogram.py b/tests/test_histogram.py index aaa840e8..e1f96bf5 100644 --- a/tests/test_histogram.py +++ b/tests/test_histogram.py @@ -11,8 +11,10 @@ import pytest -@pytest.mark.parametrize("num_blocks", [1, 2, 10, 20, 32, 64]) -@pytest.mark.parametrize("bins", [2, 4, 16, 50, 100]) +#@pytest.mark.parametrize("num_blocks", [1, 2, 10, 20, 32, 64]) +#@pytest.mark.parametrize("bins", [2, 4, 16, 50, 100]) +@pytest.mark.parametrize("num_blocks", [1]) +@pytest.mark.parametrize("bins", [64]) @pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=str) def test_pto_histogram(num_blocks: int, bins: int, dtype: torch.dtype): tile_len = 64 From 1f20f0fc4d4e7dffc08dc5706ee5ebdee7a0474a Mon Sep 17 00:00:00 2001 From: Vladimir Loncar Date: Wed, 1 Apr 2026 13:47:25 +0000 Subject: [PATCH 03/17] Implementation using AtomicAdd in TSTORE, WIP --- csrc/host/torch_histogram.h | 49 ++--- csrc/kernel/kernel_histogram.cpp | 331 +++++++++++++------------------ tests/test_histogram.py | 24 ++- 3 files changed, 174 insertions(+), 230 deletions(-) diff --git a/csrc/host/torch_histogram.h b/csrc/host/torch_histogram.h index 429768ab..74db3d63 100644 --- a/csrc/host/torch_histogram.h +++ b/csrc/host/torch_histogram.h @@ -11,10 +11,8 @@ for the full License text. #include #include -#include "aclrtlaunch_vhistogram_local_fp16.h" -#include "aclrtlaunch_vhistogram_local_fp32.h" -#include "aclrtlaunch_vhistogram_reduce_fp16.h" -#include "aclrtlaunch_vhistogram_reduce_fp32.h" +#include "aclrtlaunch_vhistogram_fp16.h" +#include "aclrtlaunch_vhistogram_fp32.h" #include "utils.h" namespace pto_isa_ops { @@ -30,26 +28,26 @@ namespace pto_isa_ops { */ at::Tensor run_histogram(const at::Tensor& x, int64_t bins = 100, double min_val = 0.0, double max_val = 0.0) { - - const uint32_t total_len = x.numel(); - // // FIXME: tile length is fixed to 64 for now - // constexpr uint32_t TILE_LEN = 64; - // const uint32_t total_tiles = total_len / TILE_LEN; - uint32_t num_cores = GetNumVectorCores(); - // if (total_tiles < num_cores) { - // num_cores = total_tiles; - // } + constexpr uint32_t TILE_LEN = 64; + constexpr uint32_t TILE_SIZE = TILE_LEN * TILE_LEN; + // const uint32_t block_dim = (total_len + TILE_SIZE - 1) / TILE_SIZE; + // const uint32_t block_dim = GetNumVectorCores();; + const uint32_t block_dim = 1; + + TORCH_CHECK(total_len / TILE_SIZE != 0, + "total number of elements must be divisible by 64 * 64"); + TORCH_CHECK(bins <= 1024, "bins must be <= 1024"); const auto dtype = x.options().dtype(); const auto device = x.options().device(); - auto z_opts = at::TensorOptions() - .dtype(at::kInt) // Set data type to int32 for histogram counts - .device(device); + auto z_opts = + at::TensorOptions() + .dtype(at::kInt) // Set data type to int32 for histogram counts + .device(device); // Allocate a 1D tensor sized `[bins]` for the histogram. at::Tensor z = at::zeros({bins}, z_opts); - at::Tensor z_local = at::zeros({num_cores, bins}, z_opts); - + const auto num_bins = static_cast(bins); if (min_val == 0.0 && max_val == 0.0) { @@ -61,22 +59,19 @@ at::Tensor run_histogram(const at::Tensor& x, int64_t bins = 100, const auto f_min_val = static_cast(min_val); const auto f_max_val = static_cast(max_val); + const float f_bin_width = (f_max_val - f_min_val) / (float)num_bins; - // Phase 1: Launch one kernel per core to compute local histograms + at::Tensor x_contig = x.contiguous(); // Just in case if (dtype == at::kHalf) { - EXEC_KERNEL_CMD(vhistogram_local_fp16, num_cores, x, z_local, total_len, num_bins, - f_min_val, f_max_val); + EXEC_KERNEL_CMD(vhistogram_fp16, block_dim, x_contig, z, total_len, + num_bins, f_min_val, f_max_val, f_bin_width); } else if (dtype == at::kFloat) { - EXEC_KERNEL_CMD(vhistogram_local_fp32, num_cores, x, z_local, total_len, num_bins, - f_min_val, f_max_val); + EXEC_KERNEL_CMD(vhistogram_fp32, block_dim, x_contig, z, total_len, + num_bins, f_min_val, f_max_val, f_bin_width); } else { throw std::runtime_error("Unsupported dtype for `pto_histogram` kernel"); } - // Phase 2: Launch a single kernel to reduce all local histograms - const uint32_t num_reduce_cores = 1; - EXEC_KERNEL_CMD(vhistogram_reduce_fp32, num_reduce_cores, z, z_local, num_bins, num_cores); - return z; } } // namespace pto_isa_ops diff --git a/csrc/kernel/kernel_histogram.cpp b/csrc/kernel/kernel_histogram.cpp index fd290fde..ce43bd2c 100644 --- a/csrc/kernel/kernel_histogram.cpp +++ b/csrc/kernel/kernel_histogram.cpp @@ -12,19 +12,19 @@ for the full License text. #define MEMORY_BASE #include -#include #include "kernel_utils.h" -#define GM_ADDR __gm__ uint8_t* // To avoid #include "kernel_operator.h" +#define GM_ADDR __gm__ uint8_t* using namespace pto; /** - * runTLocalHistogram - Phase 1: Local histogram calculation per core. + * runTHistogram - Local histogram calculation with Atomic Addition to Global Memory. */ template -AICORE void runTLocalHistogram(__gm__ T* x, __gm__ int32_t* z_local, const uint32_t total_length, - const int32_t num_bins, const float min_val, const float max_val) +AICORE void runTHistogram(__gm__ T* x, __gm__ int32_t* z, const uint32_t total_length, + const int32_t num_bins, const float min_val, const float max_val, + const float bin_width) { set_mask_norm(); set_vector_mask(-1, -1); @@ -35,221 +35,162 @@ AICORE void runTLocalHistogram(__gm__ T* x, __gm__ int32_t* z_local, const uint3 using InputStride = pto::Stride<1, 1, 1, TILE_LEN, 1>; using InputGlobalData = pto::GlobalTensor; - // Align num_bins for vector processing and GM 32-byte alignment. + // Align num_bins for vector processing (multiple of 8 for 32-byte alignment) const uint32_t num_bins_aligned = kernel_utils::CeilDiv(num_bins, 8) * 8; using HistGlobalData = pto::GlobalTensor, pto::Stride<1, 1, 1, 1, 1>>; - // --- Define UB Tiles and Memory Layout --- - constexpr uint32_t MASK_COLS = TILE_LEN / 8; - constexpr uint32_t MASK_CAPACITY_COLS = 32; - - uint32_t addr = 0; - const uint32_t UB_X_ADDR = addr; addr += TILE_SIZE * sizeof(T); - const uint32_t UB_CUR_MASK_ADDR = addr; addr += TILE_LEN * MASK_CAPACITY_COLS * sizeof(uint8_t); - const uint32_t UB_CUR_F32_ADDR = addr; addr += TILE_SIZE * sizeof(float); - const uint32_t UB_PREV_F32_ADDR = addr; addr += TILE_SIZE * sizeof(float); - const uint32_t UB_BIN_F32_ADDR = addr; addr += TILE_SIZE * sizeof(float); - const uint32_t UB_ONE_ADDR = addr; addr += TILE_SIZE * sizeof(float); - const uint32_t UB_ZERO_ADDR = addr; addr += TILE_SIZE * sizeof(float); - const uint32_t UB_REDUCE_TMP_ADDR = addr; addr += TILE_SIZE * sizeof(float); - const uint32_t UB_RSUM_ADDR = addr; addr += TILE_LEN * 8 * sizeof(float); // 64x8 for alignment - const uint32_t UB_COUNT_ADDR = addr; addr += 8 * sizeof(float); - const uint32_t UB_LOCAL_HIST_ADDR = addr; - - // Input tile - using InputTileData = Tile; - InputTileData xTiles(TILE_LEN, TILE_LEN); - TASSIGN(xTiles, UB_X_ADDR); - - // Mask tile (packed bits) - using MaskTileData = Tile; - MaskTileData current_mask(TILE_LEN, MASK_COLS); - TASSIGN(current_mask, UB_CUR_MASK_ADDR); - - // Float conversion tiles - using F32TileData = Tile; - F32TileData cur_f32(TILE_LEN, TILE_LEN); - TASSIGN(cur_f32, UB_CUR_F32_ADDR); - F32TileData prev_f32(TILE_LEN, TILE_LEN); - TASSIGN(prev_f32, UB_PREV_F32_ADDR); - F32TileData bin_mask_f32(TILE_LEN, TILE_LEN); - TASSIGN(bin_mask_f32, UB_BIN_F32_ADDR); - - F32TileData one_tile(TILE_LEN, TILE_LEN); - TASSIGN(one_tile, UB_ONE_ADDR); - F32TileData zero_tile(TILE_LEN, TILE_LEN); - TASSIGN(zero_tile, UB_ZERO_ADDR); - TEXPANDS(one_tile, 1.0f); - TEXPANDS(zero_tile, 0.0f); - - F32TileData reduce_tmp(TILE_LEN, TILE_LEN); - TASSIGN(reduce_tmp, UB_REDUCE_TMP_ADDR); - - // Reduction result tiles. - // For RowMajor and NoneBox, Cols must be a multiple of 8 (32 bytes for float). - using FloatRowSumTile = Tile; - FloatRowSumTile row_sum_f32_tile(1); // Set ValidCol to 1 - TASSIGN(row_sum_f32_tile, UB_RSUM_ADDR); - - using FloatCountTile = Tile; - FloatCountTile count_f32_tile; - TASSIGN(count_f32_tile, UB_COUNT_ADDR); - - // Local histogram tile - constexpr uint32_t MAX_BINS = 8192; - using HistTile = Tile; - HistTile localHist(static_cast(num_bins_aligned)); - TASSIGN(localHist, UB_LOCAL_HIST_ADDR); - TEXPANDS(localHist, (int32_t)0); - - // --- Phase 1: Local histogram calculation --- + // --- Work Distribution --- const uint32_t num_tiles_total = kernel_utils::CeilDiv(total_length, TILE_SIZE); const uint32_t num_tiles_per_core = kernel_utils::CeilDiv(num_tiles_total, get_block_num()); const uint32_t start_tile_idx = get_block_idx() * num_tiles_per_core; const uint32_t end_tile_idx = (start_tile_idx + num_tiles_per_core > num_tiles_total) ? num_tiles_total : (start_tile_idx + num_tiles_per_core); - const float bin_width = (max_val - min_val) / num_bins; - - for (uint32_t tile_idx = start_tile_idx; tile_idx < end_tile_idx; ++tile_idx) { - set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); - wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); - - const uint32_t offset = tile_idx * TILE_SIZE; - InputGlobalData xGlobal(x + offset); - TLOAD(xTiles, xGlobal); - - set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); - wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); - - TCMPS(current_mask, xTiles, static_cast(min_val), CmpMode::LT); - TSEL(prev_f32, current_mask, one_tile, zero_tile); - - for (int32_t j = 0; j < num_bins; ++j) { - float bin_upper_bound = min_val + (j + 1) * bin_width; - CmpMode mode = CmpMode::LT; - if (j == num_bins - 1) { - bin_upper_bound = max_val; - mode = CmpMode::LE; - } - - TCMPS(current_mask, xTiles, static_cast(bin_upper_bound), mode); - TSEL(cur_f32, current_mask, one_tile, zero_tile); - - TSUB(bin_mask_f32, cur_f32, prev_f32); - - TROWSUM(row_sum_f32_tile, bin_mask_f32, reduce_tmp); - TCOLSUM(count_f32_tile, row_sum_f32_tile, row_sum_f32_tile, true); - - set_flag(PIPE_V, PIPE_S, EVENT_ID0); - wait_flag(PIPE_V, PIPE_S, EVENT_ID0); - - float f_count = count_f32_tile.GetValue(0); - int32_t count = static_cast(f_count + 0.5f); - if (count > 0) { - localHist.SetValue(j, localHist.GetValue(j) + count); - } - - set_flag(PIPE_S, PIPE_V, EVENT_ID0); - wait_flag(PIPE_S, PIPE_V, EVENT_ID0); - - TMOV(prev_f32, cur_f32); - } - } - - set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); - wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); - - const uint32_t local_hist_offset = get_block_idx() * num_bins_aligned; - HistGlobalData zLocalGlobal(z_local + local_hist_offset, {static_cast(num_bins_aligned)}); - TSTORE(zLocalGlobal, localHist); - - set_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); - wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); -} - -/** - * runTReduceHistogram - Phase 2: Reduction of local histograms from all cores. - */ -AICORE void runTReduceHistogram(__gm__ int32_t* z, __gm__ int32_t* z_local, - const int32_t num_bins, const uint32_t num_cores) { - set_mask_norm(); - set_vector_mask(-1, -1); - - const uint32_t num_bins_aligned = kernel_utils::CeilDiv(num_bins, 8) * 8; - using HistGlobalData = pto::GlobalTensor, pto::Stride<1, 1, 1, 1, 1>>; - - constexpr uint32_t UB_MAIN_HIST_ADDR = 0; - const uint32_t UB_OTHER_HIST_ADDR = UB_MAIN_HIST_ADDR + num_bins_aligned * sizeof(int32_t); - - constexpr uint32_t MAX_BINS = 8192; - using HistTile = Tile; - - if (get_block_idx() == 0) { - HistTile mainHist(static_cast(num_bins_aligned)); - TASSIGN(mainHist, UB_MAIN_HIST_ADDR); - - HistGlobalData mainHistGlobal(z_local, {static_cast(num_bins_aligned)}); - TLOAD(mainHist, mainHistGlobal); - - set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); - wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); - - HistTile otherHist(static_cast(num_bins_aligned)); - TASSIGN(otherHist, UB_OTHER_HIST_ADDR); - - for (uint32_t i = 1; i < num_cores; ++i) { + // Only cores that have actual tiles to process perform the Atomic operation. + if (start_tile_idx < end_tile_idx) { + // --- Define UB Tiles and Memory Layout --- + constexpr uint32_t MASK_COLS = TILE_LEN / 8; + constexpr uint32_t MASK_CAPACITY_COLS = MASK_COLS > 32 ? MASK_COLS : 32; + + uint32_t addr = 0; + const uint32_t UB_X_ADDR = addr; addr += TILE_SIZE * sizeof(T); + const uint32_t UB_CUR_MASK_ADDR = addr; addr += TILE_LEN * MASK_CAPACITY_COLS * sizeof(uint8_t); + const uint32_t UB_CUR_F32_ADDR = addr; addr += TILE_SIZE * sizeof(float); + const uint32_t UB_PREV_F32_ADDR = addr; addr += TILE_SIZE * sizeof(float); + const uint32_t UB_BIN_F32_ADDR = addr; addr += TILE_SIZE * sizeof(float); + const uint32_t UB_ONE_ADDR = addr; addr += TILE_SIZE * sizeof(float); + const uint32_t UB_ZERO_ADDR = addr; addr += TILE_SIZE * sizeof(float); + const uint32_t UB_REDUCE_TMP_ADDR = addr; addr += TILE_SIZE * sizeof(float); + const uint32_t UB_RSUM_ADDR = addr; addr += TILE_LEN * 8 * sizeof(float); + const uint32_t UB_COUNT_ADDR = addr; addr += 8 * sizeof(float); + const uint32_t UB_LOCAL_HIST_ADDR = addr; + + using InputTileData = Tile; + InputTileData xTiles; + TASSIGN(xTiles, UB_X_ADDR); + + // Mask tile + //using MaskTileData = Tile; + //MaskTileData current_mask(TILE_LEN, MASK_COLS); + Tile current_mask; + TASSIGN(current_mask, UB_CUR_MASK_ADDR); + + // Float conversion tiles + using F32TileData = Tile; + F32TileData cur_f32; + TASSIGN(cur_f32, UB_CUR_F32_ADDR); + F32TileData prev_f32; + TASSIGN(prev_f32, UB_PREV_F32_ADDR); + F32TileData bin_mask_f32; + TASSIGN(bin_mask_f32, UB_BIN_F32_ADDR); + + F32TileData one_tile; + TASSIGN(one_tile, UB_ONE_ADDR); + TEXPANDS(one_tile, 1.0f); + F32TileData zero_tile; + TASSIGN(zero_tile, UB_ZERO_ADDR); + TEXPANDS(zero_tile, 0.0f); + + F32TileData reduce_tmp; + TASSIGN(reduce_tmp, UB_REDUCE_TMP_ADDR); + + using FloatRowSumTile = Tile; + FloatRowSumTile row_sum_f32_tile(1); + TASSIGN(row_sum_f32_tile, UB_RSUM_ADDR); + + using FloatCountTile = Tile; + FloatCountTile count_f32_tile(1); + TASSIGN(count_f32_tile, UB_COUNT_ADDR); + + // Local histogram tile in UB + constexpr uint32_t MAX_BINS_LIMIT = 1024; + using HistTile = Tile; + HistTile localHist(static_cast(num_bins_aligned)); + TASSIGN(localHist, UB_LOCAL_HIST_ADDR); + TEXPANDS(localHist, (int32_t)0); + + // Initial barrier to ensure UB layout and constants are set. + pipe_barrier(PIPE_ALL); + + // --- Main Calculation Loop --- + for (uint32_t tile_idx = start_tile_idx; tile_idx < end_tile_idx; ++tile_idx) { + const uint32_t offset = tile_idx * TILE_SIZE; + InputGlobalData xGlobal(x + offset); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); - - const uint32_t other_hist_offset = i * num_bins_aligned; - HistGlobalData otherHistGlobal(z_local + other_hist_offset, {static_cast(num_bins_aligned)}); - TLOAD(otherHist, otherHistGlobal); - + TLOAD(xTiles, xGlobal); set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); - TADD(mainHist, mainHist, otherHist); + TCMPS(current_mask, xTiles, static_cast(min_val), CmpMode::LT); + TSEL(prev_f32, current_mask, one_tile, zero_tile); + //TEXPANDS(prev_f32, 0.0f); + + for (int32_t j = 0; j < num_bins; ++j) { + float bin_upper_bound = min_val + (j + 1) * bin_width; + CmpMode mode = (j == num_bins - 1) ? CmpMode::LE : CmpMode::LT; + + TCMPS(current_mask, xTiles, static_cast(bin_upper_bound), mode); + TSEL(cur_f32, current_mask, one_tile, zero_tile); + TSUB(bin_mask_f32, cur_f32, prev_f32); + + TEXPANDS(row_sum_f32_tile, 0.0f); + TEXPANDS(reduce_tmp, 0.0f); + TROWSUM(row_sum_f32_tile, bin_mask_f32, reduce_tmp); + + TEXPANDS(count_f32_tile, 0.0f); + TEXPANDS(reduce_tmp, 0.0f); + TCOLSUM(count_f32_tile, row_sum_f32_tile, reduce_tmp, true); + + // Scalar move to update UB local histogram + set_flag(PIPE_V, PIPE_S, EVENT_ID0); + wait_flag(PIPE_V, PIPE_S, EVENT_ID0); + float f_count = count_f32_tile.GetValue(0); + int32_t count = static_cast(f_count + 0.5f); + if (count > 0) { + localHist.SetValue(j, localHist.GetValue(j) + count); + } + set_flag(PIPE_S, PIPE_V, EVENT_ID0); + wait_flag(PIPE_S, PIPE_V, EVENT_ID0); + + TMOV(prev_f32, cur_f32); + pipe_barrier(PIPE_ALL); + } } + pipe_barrier(PIPE_ALL); - set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); - wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); - - // Create a new tile object sharing the same address to set ValidCol to num_bins for final store + // --- Final Atomic Store to Global Memory --- HistTile finalHist(static_cast(num_bins)); - TASSIGN(finalHist, UB_MAIN_HIST_ADDR); - HistGlobalData zGlobal(z, {static_cast(num_bins)}); - TSTORE(zGlobal, finalHist); + TASSIGN(finalHist, UB_LOCAL_HIST_ADDR); + HistGlobalData zGlobal(z, {1, 1, 1, 1, num_bins}); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + // Perform Atomic Addition directly into the final result tensor + TSTORE(zGlobal, finalHist); set_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); } } -extern "C" __global__ AICORE void vhistogram_local_fp16(GM_ADDR x, GM_ADDR z_local, - const uint32_t in_length, - const int32_t num_bins, - const float min_val, const float max_val) { +extern "C" __global__ AICORE void vhistogram_fp16(GM_ADDR x, GM_ADDR z, + const uint32_t in_length, + const int32_t num_bins, + const float min_val, const float max_val, + const float bin_width) { constexpr unsigned TILE_LEN = 64; - runTLocalHistogram((__gm__ half*)x, (__gm__ int32_t*)z_local, in_length, - num_bins, min_val, max_val); + runTHistogram((__gm__ half*)x, (__gm__ int32_t*)z, in_length, + num_bins, min_val, max_val, bin_width); } -extern "C" __global__ AICORE void vhistogram_local_fp32(GM_ADDR x, GM_ADDR z_local, - const uint32_t in_length, - const int32_t num_bins, - const float min_val, const float max_val) { +extern "C" __global__ AICORE void vhistogram_fp32(GM_ADDR x, GM_ADDR z, + const uint32_t in_length, + const int32_t num_bins, + const float min_val, const float max_val, + const float bin_width) { constexpr unsigned TILE_LEN = 64; - runTLocalHistogram((__gm__ float*)x, (__gm__ int32_t*)z_local, in_length, - num_bins, min_val, max_val); + runTHistogram((__gm__ float*)x, (__gm__ int32_t*)z, in_length, + num_bins, min_val, max_val, bin_width); } -extern "C" __global__ AICORE void vhistogram_reduce_fp16(__gm__ int32_t* z, __gm__ int32_t* z_local, - const int32_t num_bins, const uint32_t num_cores) { - runTReduceHistogram(z, z_local, num_bins, num_cores); -} - -extern "C" __global__ AICORE void vhistogram_reduce_fp32(__gm__ int32_t* z, __gm__ int32_t* z_local, - const int32_t num_bins, const uint32_t num_cores) { - runTReduceHistogram(z, z_local, num_bins, num_cores); -} -#endif \ No newline at end of file +#endif diff --git a/tests/test_histogram.py b/tests/test_histogram.py index e1f96bf5..43318592 100644 --- a/tests/test_histogram.py +++ b/tests/test_histogram.py @@ -11,19 +11,27 @@ import pytest -#@pytest.mark.parametrize("num_blocks", [1, 2, 10, 20, 32, 64]) -#@pytest.mark.parametrize("bins", [2, 4, 16, 50, 100]) -@pytest.mark.parametrize("num_blocks", [1]) +# @pytest.mark.parametrize("num_blocks", [1, 2, 10, 20, 32, 64]) +# @pytest.mark.parametrize("bins", [2, 4, 16, 50, 100]) +# @pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=str) +@pytest.mark.parametrize("num_blocks", [64]) @pytest.mark.parametrize("bins", [64]) -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=str) +@pytest.mark.parametrize("dtype", [torch.float32], ids=str) def test_pto_histogram(num_blocks: int, bins: int, dtype: torch.dtype): tile_len = 64 length = [num_blocks * tile_len] - - x = torch.rand(length, device="cpu", dtype=dtype) + + # x = torch.rand(length, device="cpu", dtype=dtype) + x = torch.arange(length[0], device="cpu", dtype=dtype) x_npu = x.npu() - y_npu = pto_histogram(x_npu, bins=bins).cpu() + y_npu = pto_histogram(x_npu, bins=bins).cpu().float() + y_npu2 = pto_histogram(x_npu, bins=bins).cpu().float() + y_npu3 = pto_histogram(x_npu, bins=bins).cpu().float() + y_npu4 = pto_histogram(x_npu, bins=bins).cpu().float() + assert torch.equal(y_npu, y_npu2) + assert torch.equal(y_npu, y_npu3) + assert torch.equal(y_npu, y_npu4) y_cpu = torch.histc(x, bins=bins) - + assert torch.equal(y_npu, y_cpu) From f98f957cf43d70cafdedb623b40a2368e878743b Mon Sep 17 00:00:00 2001 From: Vladimir Loncar Date: Tue, 7 Apr 2026 12:29:46 +0000 Subject: [PATCH 04/17] A functional histogram implementation (yay) --- csrc/host/torch_histogram.h | 45 ++-- csrc/kernel/kernel_histogram.cpp | 388 +++++++++++++++++++------------ tests/test_histogram.py | 48 ++-- 3 files changed, 290 insertions(+), 191 deletions(-) diff --git a/csrc/host/torch_histogram.h b/csrc/host/torch_histogram.h index 74db3d63..590c09a8 100644 --- a/csrc/host/torch_histogram.h +++ b/csrc/host/torch_histogram.h @@ -11,8 +11,9 @@ for the full License text. #include #include -#include "aclrtlaunch_vhistogram_fp16.h" -#include "aclrtlaunch_vhistogram_fp32.h" +#include "aclrtlaunch_histogram_final.h" +#include "aclrtlaunch_histogram_fp16.h" +#include "aclrtlaunch_histogram_fp32.h" #include "utils.h" namespace pto_isa_ops { @@ -29,23 +30,30 @@ namespace pto_isa_ops { at::Tensor run_histogram(const at::Tensor& x, int64_t bins = 100, double min_val = 0.0, double max_val = 0.0) { const uint32_t total_len = x.numel(); - constexpr uint32_t TILE_LEN = 64; - constexpr uint32_t TILE_SIZE = TILE_LEN * TILE_LEN; - // const uint32_t block_dim = (total_len + TILE_SIZE - 1) / TILE_SIZE; - // const uint32_t block_dim = GetNumVectorCores();; - const uint32_t block_dim = 1; + constexpr uint32_t TILE_SIZE = 512; + const uint32_t block_dim = GetNumVectorCores(); - TORCH_CHECK(total_len / TILE_SIZE != 0, - "total number of elements must be divisible by 64 * 64"); - TORCH_CHECK(bins <= 1024, "bins must be <= 1024"); + TORCH_CHECK(total_len % TILE_SIZE == 0, + "total number of elements must be divisible by TILE_SIZE"); + TORCH_CHECK(bins <= 256, "bins must be <= 256"); + TORCH_CHECK(x.is_contiguous(), "x must be contiguous"); const auto dtype = x.options().dtype(); const auto device = x.options().device(); - auto z_opts = + + // Allocate a 1D tensor sized `[block_dim * bins]` for the local histogram + // counts. + auto z_local_opts = at::TensorOptions() - .dtype(at::kInt) // Set data type to int32 for histogram counts + .dtype( + at::kFloat) // Local (per-core) histogram counts will be floats .device(device); + at::Tensor z_local = at::zeros({block_dim * bins}, z_local_opts); + // Allocate a 1D tensor sized `[bins]` for the histogram. + auto z_opts = at::TensorOptions() + .dtype(at::kInt) // The final result will be int32 counts + .device(device); at::Tensor z = at::zeros({bins}, z_opts); const auto num_bins = static_cast(bins); @@ -59,19 +67,20 @@ at::Tensor run_histogram(const at::Tensor& x, int64_t bins = 100, const auto f_min_val = static_cast(min_val); const auto f_max_val = static_cast(max_val); - const float f_bin_width = (f_max_val - f_min_val) / (float)num_bins; - at::Tensor x_contig = x.contiguous(); // Just in case if (dtype == at::kHalf) { - EXEC_KERNEL_CMD(vhistogram_fp16, block_dim, x_contig, z, total_len, - num_bins, f_min_val, f_max_val, f_bin_width); + EXEC_KERNEL_CMD(histogram_fp16, block_dim, x, z_local, total_len, num_bins, + f_min_val, f_max_val); } else if (dtype == at::kFloat) { - EXEC_KERNEL_CMD(vhistogram_fp32, block_dim, x_contig, z, total_len, - num_bins, f_min_val, f_max_val, f_bin_width); + EXEC_KERNEL_CMD(histogram_fp32, block_dim, x, z_local, total_len, num_bins, + f_min_val, f_max_val); } else { throw std::runtime_error("Unsupported dtype for `pto_histogram` kernel"); } + const uint32_t reduce_dim = 1; + EXEC_KERNEL_CMD(histogram_final, reduce_dim, z_local, z, num_bins, block_dim); + return z; } } // namespace pto_isa_ops diff --git a/csrc/kernel/kernel_histogram.cpp b/csrc/kernel/kernel_histogram.cpp index ce43bd2c..10949efd 100644 --- a/csrc/kernel/kernel_histogram.cpp +++ b/csrc/kernel/kernel_histogram.cpp @@ -12,185 +12,273 @@ for the full License text. #define MEMORY_BASE #include + #include "kernel_utils.h" #define GM_ADDR __gm__ uint8_t* using namespace pto; +constexpr uint32_t DEFAULT_TILE_SIZE = 512; +constexpr uint32_t MAX_BINS = 256; +constexpr uint32_t MAX_BLOCKS = 64; + /** - * runTHistogram - Local histogram calculation with Atomic Addition to Global Memory. + * runTLocalHistogram - Local, per-core histogram calculation */ -template -AICORE void runTHistogram(__gm__ T* x, __gm__ int32_t* z, const uint32_t total_length, - const int32_t num_bins, const float min_val, const float max_val, - const float bin_width) -{ +template +AICORE void runTLocalHistogram(__gm__ T* x, __gm__ float* z_local, + const uint32_t total_length, + const int32_t num_bins, const float min_val, + const float max_val) { set_mask_norm(); set_vector_mask(-1, -1); // --- Define Global Tensors --- - constexpr uint32_t TILE_SIZE = TILE_LEN * TILE_LEN; - using InputShape = pto::Shape<1, 1, 1, TILE_LEN, TILE_LEN>; - using InputStride = pto::Stride<1, 1, 1, TILE_LEN, 1>; - using InputGlobalData = pto::GlobalTensor; - - // Align num_bins for vector processing (multiple of 8 for 32-byte alignment) - const uint32_t num_bins_aligned = kernel_utils::CeilDiv(num_bins, 8) * 8; - using HistGlobalData = pto::GlobalTensor, pto::Stride<1, 1, 1, 1, 1>>; + using InputGlobalData = pto::GlobalTensor, + pto::Stride<1, 1, 1, 1, 1>>; + using HistGlobalData = + pto::GlobalTensor, + pto::Stride<1, 1, 1, 1, 1>>; // --- Work Distribution --- - const uint32_t num_tiles_total = kernel_utils::CeilDiv(total_length, TILE_SIZE); - const uint32_t num_tiles_per_core = kernel_utils::CeilDiv(num_tiles_total, get_block_num()); - const uint32_t start_tile_idx = get_block_idx() * num_tiles_per_core; - const uint32_t end_tile_idx = (start_tile_idx + num_tiles_per_core > num_tiles_total) ? num_tiles_total : (start_tile_idx + num_tiles_per_core); + const uint32_t block_idx = get_block_idx(); + const uint32_t block_num = get_block_num(); + const uint32_t num_tiles_total = + kernel_utils::CeilDiv(total_length, TILE_SIZE); + const uint32_t num_tiles_per_core = + kernel_utils::CeilDiv(num_tiles_total, block_num); + const uint32_t start_idx = block_idx * num_tiles_per_core; + const uint32_t end_idx = (start_idx + num_tiles_per_core > num_tiles_total) + ? num_tiles_total + : (start_idx + num_tiles_per_core); - // Only cores that have actual tiles to process perform the Atomic operation. - if (start_tile_idx < end_tile_idx) { - // --- Define UB Tiles and Memory Layout --- - constexpr uint32_t MASK_COLS = TILE_LEN / 8; - constexpr uint32_t MASK_CAPACITY_COLS = MASK_COLS > 32 ? MASK_COLS : 32; + // --- Define UB Tiles and Memory Layout --- + uint32_t addr = 0; + const uint32_t UB_X_ADDR = addr; + addr += TILE_SIZE * sizeof(T); + const uint32_t UB_CUR_MASK_ADDR = addr; + addr += TILE_SIZE * sizeof(uint8_t); + const uint32_t UB_CUR_F32_ADDR = addr; + addr += TILE_SIZE * sizeof(float); + const uint32_t UB_PREV_F32_ADDR = addr; + addr += TILE_SIZE * sizeof(float); + const uint32_t UB_BIN_F32_ADDR = addr; + addr += TILE_SIZE * sizeof(float); + const uint32_t UB_ONE_ADDR = addr; + addr += TILE_SIZE * sizeof(float); + const uint32_t UB_ZERO_ADDR = addr; + addr += TILE_SIZE * sizeof(float); + const uint32_t UB_REDUCE_TMP_ADDR = addr; + addr += TILE_SIZE * sizeof(float); + const uint32_t UB_COUNT_ADDR = addr; + addr += 8 * sizeof(float); + const uint32_t UB_LOCAL_HIST_ADDR = addr; - uint32_t addr = 0; - const uint32_t UB_X_ADDR = addr; addr += TILE_SIZE * sizeof(T); - const uint32_t UB_CUR_MASK_ADDR = addr; addr += TILE_LEN * MASK_CAPACITY_COLS * sizeof(uint8_t); - const uint32_t UB_CUR_F32_ADDR = addr; addr += TILE_SIZE * sizeof(float); - const uint32_t UB_PREV_F32_ADDR = addr; addr += TILE_SIZE * sizeof(float); - const uint32_t UB_BIN_F32_ADDR = addr; addr += TILE_SIZE * sizeof(float); - const uint32_t UB_ONE_ADDR = addr; addr += TILE_SIZE * sizeof(float); - const uint32_t UB_ZERO_ADDR = addr; addr += TILE_SIZE * sizeof(float); - const uint32_t UB_REDUCE_TMP_ADDR = addr; addr += TILE_SIZE * sizeof(float); - const uint32_t UB_RSUM_ADDR = addr; addr += TILE_LEN * 8 * sizeof(float); - const uint32_t UB_COUNT_ADDR = addr; addr += 8 * sizeof(float); - const uint32_t UB_LOCAL_HIST_ADDR = addr; - - using InputTileData = Tile; - InputTileData xTiles; - TASSIGN(xTiles, UB_X_ADDR); - - // Mask tile - //using MaskTileData = Tile; - //MaskTileData current_mask(TILE_LEN, MASK_COLS); - Tile current_mask; - TASSIGN(current_mask, UB_CUR_MASK_ADDR); - - // Float conversion tiles - using F32TileData = Tile; - F32TileData cur_f32; - TASSIGN(cur_f32, UB_CUR_F32_ADDR); - F32TileData prev_f32; - TASSIGN(prev_f32, UB_PREV_F32_ADDR); - F32TileData bin_mask_f32; - TASSIGN(bin_mask_f32, UB_BIN_F32_ADDR); - - F32TileData one_tile; - TASSIGN(one_tile, UB_ONE_ADDR); - TEXPANDS(one_tile, 1.0f); - F32TileData zero_tile; - TASSIGN(zero_tile, UB_ZERO_ADDR); - TEXPANDS(zero_tile, 0.0f); - - F32TileData reduce_tmp; - TASSIGN(reduce_tmp, UB_REDUCE_TMP_ADDR); - - using FloatRowSumTile = Tile; - FloatRowSumTile row_sum_f32_tile(1); - TASSIGN(row_sum_f32_tile, UB_RSUM_ADDR); - - using FloatCountTile = Tile; - FloatCountTile count_f32_tile(1); - TASSIGN(count_f32_tile, UB_COUNT_ADDR); - - // Local histogram tile in UB - constexpr uint32_t MAX_BINS_LIMIT = 1024; - using HistTile = Tile; - HistTile localHist(static_cast(num_bins_aligned)); - TASSIGN(localHist, UB_LOCAL_HIST_ADDR); - TEXPANDS(localHist, (int32_t)0); - - // Initial barrier to ensure UB layout and constants are set. - pipe_barrier(PIPE_ALL); - - // --- Main Calculation Loop --- - for (uint32_t tile_idx = start_tile_idx; tile_idx < end_tile_idx; ++tile_idx) { - const uint32_t offset = tile_idx * TILE_SIZE; - InputGlobalData xGlobal(x + offset); - - set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); - wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); - TLOAD(xTiles, xGlobal); - set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); - wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); - - TCMPS(current_mask, xTiles, static_cast(min_val), CmpMode::LT); - TSEL(prev_f32, current_mask, one_tile, zero_tile); - //TEXPANDS(prev_f32, 0.0f); - - for (int32_t j = 0; j < num_bins; ++j) { - float bin_upper_bound = min_val + (j + 1) * bin_width; - CmpMode mode = (j == num_bins - 1) ? CmpMode::LE : CmpMode::LT; - - TCMPS(current_mask, xTiles, static_cast(bin_upper_bound), mode); - TSEL(cur_f32, current_mask, one_tile, zero_tile); - TSUB(bin_mask_f32, cur_f32, prev_f32); - - TEXPANDS(row_sum_f32_tile, 0.0f); - TEXPANDS(reduce_tmp, 0.0f); - TROWSUM(row_sum_f32_tile, bin_mask_f32, reduce_tmp); - - TEXPANDS(count_f32_tile, 0.0f); - TEXPANDS(reduce_tmp, 0.0f); - TCOLSUM(count_f32_tile, row_sum_f32_tile, reduce_tmp, true); - - // Scalar move to update UB local histogram - set_flag(PIPE_V, PIPE_S, EVENT_ID0); - wait_flag(PIPE_V, PIPE_S, EVENT_ID0); - float f_count = count_f32_tile.GetValue(0); - int32_t count = static_cast(f_count + 0.5f); - if (count > 0) { - localHist.SetValue(j, localHist.GetValue(j) + count); - } - set_flag(PIPE_S, PIPE_V, EVENT_ID0); - wait_flag(PIPE_S, PIPE_V, EVENT_ID0); - - TMOV(prev_f32, cur_f32); - pipe_barrier(PIPE_ALL); + InputGlobalData x_gm(x, {static_cast(total_length)}); + + using InputTileData = Tile; + InputTileData x_tile; + TASSIGN(x_tile, UB_X_ADDR); + + using MaskTileData = Tile; + MaskTileData current_mask; + TASSIGN(current_mask, UB_CUR_MASK_ADDR); + + // Float conversion tiles + using F32TileData = Tile; + F32TileData cur_f32; + TASSIGN(cur_f32, UB_CUR_F32_ADDR); + F32TileData prev_f32; + TASSIGN(prev_f32, UB_PREV_F32_ADDR); + F32TileData bin_mask_f32; + TASSIGN(bin_mask_f32, UB_BIN_F32_ADDR); + + F32TileData one_tile; + TASSIGN(one_tile, UB_ONE_ADDR); + TEXPANDS(one_tile, 1.0f); + F32TileData zero_tile; + TASSIGN(zero_tile, UB_ZERO_ADDR); + TEXPANDS(zero_tile, 0.0f); + + F32TileData reduce_tmp; + TASSIGN(reduce_tmp, UB_REDUCE_TMP_ADDR); + + using F32CountTile = + Tile; + F32CountTile count_f32_tile; + TASSIGN(count_f32_tile, UB_COUNT_ADDR); + + // Local histogram tile in UB + using HistTile = + Tile; + HistTile local_hist(num_bins); + TASSIGN(local_hist, UB_LOCAL_HIST_ADDR); + TEXPANDS(local_hist, 0.0f); + + const float bin_width = (max_val - min_val) / static_cast(num_bins); + + // --- Main Calculation Loop --- + for (uint32_t tile_idx = start_idx; tile_idx < end_idx; ++tile_idx) { + int offset = tile_idx * TILE_SIZE; + TASSIGN(x_gm, x + offset); + + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + TLOAD(x_tile, x_gm); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + + // Generate packed bit-mask + TCMPS(current_mask, x_tile, static_cast(min_val), CmpMode::LT); + // Select 1.0f or 0.0f based on the packed bit-mask + TSEL(prev_f32, current_mask, one_tile, zero_tile); + + for (int32_t j = 0; j < num_bins; ++j) { + float bin_upper_bound = min_val + (j + 1) * bin_width; + CmpMode mode = (j == num_bins - 1) ? CmpMode::LE : CmpMode::LT; + + TCMPS(current_mask, x_tile, static_cast(bin_upper_bound), mode); + TSEL(cur_f32, current_mask, one_tile, zero_tile); + TSUB(bin_mask_f32, cur_f32, prev_f32); + + // Reduce the selected tile to get the count of elements less than pivot + // in this tile + TEXPANDS(count_f32_tile, 0.0f); + TEXPANDS(reduce_tmp, 0.0f); + TROWSUM(count_f32_tile, bin_mask_f32, reduce_tmp); + + // Scalar move to update UB local histogram + set_flag(PIPE_V, PIPE_S, EVENT_ID0); + wait_flag(PIPE_V, PIPE_S, EVENT_ID0); + float f_count = count_f32_tile.GetValue(0); + if (f_count > 0.0f) { + local_hist.SetValue(j, local_hist.GetValue(j) + f_count); } + set_flag(PIPE_S, PIPE_V, EVENT_ID0); + wait_flag(PIPE_S, PIPE_V, EVENT_ID0); + + TMOV(prev_f32, cur_f32); } - pipe_barrier(PIPE_ALL); + } - // --- Final Atomic Store to Global Memory --- - HistTile finalHist(static_cast(num_bins)); - TASSIGN(finalHist, UB_LOCAL_HIST_ADDR); - HistGlobalData zGlobal(z, {1, 1, 1, 1, num_bins}); + // --- Final Store to Global Memory --- + HistGlobalData z_gm(z_local + block_idx * num_bins, + {static_cast(block_num) * num_bins}); + + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(z_gm, local_hist); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); +} + +// Template parameter to avoid "no function" kernel launch error +template +AICORE void runTHistogramFinal(__gm__ float* z_local, __gm__ int32_t* z, + const int32_t num_bins, + const int32_t num_blocks) { + set_mask_norm(); + set_vector_mask(-1, -1); + + if (get_block_idx() == 0) { + // --- Define Global Tensors --- + using InGlobalData = + pto::GlobalTensor, + pto::Stride<1, 1, 1, 1, 1>>; + using OutGlobalData = + pto::GlobalTensor, + pto::Stride<1, 1, 1, 1, 1>>; + + uint32_t addr = 0; + const uint32_t UB_IN_ADDR = addr; + addr += MAX_BLOCKS * MAX_BINS * sizeof(float); + const uint32_t UB_REDUCE_TMP_ADDR = addr; + addr += MAX_BINS * sizeof(float); + const uint32_t UB_FLOAT_OUT_ADDR = addr; + addr += MAX_BINS * sizeof(float); + const uint32_t UB_OUT_ADDR = addr; + addr += MAX_BINS * sizeof(int32_t); + + using InTile = Tile; + InTile in_tile( + {static_cast(num_blocks), static_cast(num_bins)}); + TASSIGN(in_tile, UB_IN_ADDR); + + using ReduceTmpTile = + Tile; + ReduceTmpTile reduce_tmp_tile(num_bins); + TASSIGN(reduce_tmp_tile, UB_REDUCE_TMP_ADDR); + TEXPANDS(reduce_tmp_tile, 0.0f); + + using FloatOutTile = + Tile; + FloatOutTile float_out_tile(num_bins); + TASSIGN(float_out_tile, UB_FLOAT_OUT_ADDR); + TEXPANDS(float_out_tile, 0.0f); + + using OutTile = Tile; + OutTile out_tile(num_bins); + TASSIGN(out_tile, UB_OUT_ADDR); + TEXPANDS(out_tile, static_cast(0)); + + // Load all block counts into UB row by row to match 2D tile padding + using InRowTile = + Tile; + InRowTile row_tile(static_cast(num_bins)); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + for (int32_t b = 0; b < num_blocks; ++b) { + InGlobalData z_local_gm(z_local + b * num_bins, {num_bins}); + TASSIGN(row_tile, UB_IN_ADDR + b * MAX_BINS * sizeof(float)); + TLOAD(row_tile, z_local_gm); + } + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + + // TCOLSUM reduces along the row dimension (num_blocks) + TCOLSUM(float_out_tile, in_tile, reduce_tmp_tile, true); + + TCVT(out_tile, float_out_tile, RoundMode::CAST_RINT); + + // --- Final Store to Global Memory --- + OutGlobalData z_gm(z, {num_bins}); set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); - // Perform Atomic Addition directly into the final result tensor - TSTORE(zGlobal, finalHist); + TSTORE(z_gm, out_tile); set_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); } } -extern "C" __global__ AICORE void vhistogram_fp16(GM_ADDR x, GM_ADDR z, - const uint32_t in_length, - const int32_t num_bins, - const float min_val, const float max_val, - const float bin_width) { - constexpr unsigned TILE_LEN = 64; - runTHistogram((__gm__ half*)x, (__gm__ int32_t*)z, in_length, - num_bins, min_val, max_val, bin_width); +extern "C" __global__ AICORE void histogram_fp16(GM_ADDR x, GM_ADDR z_local, + const uint32_t in_length, + const int32_t num_bins, + const float min_val, + const float max_val) { + runTLocalHistogram((__gm__ half*)x, + (__gm__ float*)z_local, in_length, + num_bins, min_val, max_val); +} + +extern "C" __global__ AICORE void histogram_fp32(GM_ADDR x, GM_ADDR z_local, + const uint32_t in_length, + const int32_t num_bins, + const float min_val, + const float max_val) { + runTLocalHistogram( + (__gm__ float*)x, (__gm__ float*)z_local, in_length, num_bins, min_val, + max_val); } -extern "C" __global__ AICORE void vhistogram_fp32(GM_ADDR x, GM_ADDR z, - const uint32_t in_length, - const int32_t num_bins, - const float min_val, const float max_val, - const float bin_width) { - constexpr unsigned TILE_LEN = 64; - runTHistogram((__gm__ float*)x, (__gm__ int32_t*)z, in_length, - num_bins, min_val, max_val, bin_width); +extern "C" __global__ AICORE void histogram_final(GM_ADDR z_local, GM_ADDR z, + const int32_t num_bins, + const int32_t num_blocks) { + runTHistogramFinal<0>((__gm__ float*)z_local, (__gm__ int32_t*)z, num_bins, + num_blocks); } #endif diff --git a/tests/test_histogram.py b/tests/test_histogram.py index 43318592..21b82662 100644 --- a/tests/test_histogram.py +++ b/tests/test_histogram.py @@ -11,27 +11,29 @@ import pytest -# @pytest.mark.parametrize("num_blocks", [1, 2, 10, 20, 32, 64]) -# @pytest.mark.parametrize("bins", [2, 4, 16, 50, 100]) -# @pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=str) -@pytest.mark.parametrize("num_blocks", [64]) -@pytest.mark.parametrize("bins", [64]) +@pytest.mark.parametrize("size", [1, 8, 50, 256, 1000]) +@pytest.mark.parametrize("bins", [2, 4, 32, 100, 256]) +# torch.float16 requires different boundary calculation to match torch implementeation @pytest.mark.parametrize("dtype", [torch.float32], ids=str) -def test_pto_histogram(num_blocks: int, bins: int, dtype: torch.dtype): - tile_len = 64 - length = [num_blocks * tile_len] - - # x = torch.rand(length, device="cpu", dtype=dtype) - x = torch.arange(length[0], device="cpu", dtype=dtype) - x_npu = x.npu() - - y_npu = pto_histogram(x_npu, bins=bins).cpu().float() - y_npu2 = pto_histogram(x_npu, bins=bins).cpu().float() - y_npu3 = pto_histogram(x_npu, bins=bins).cpu().float() - y_npu4 = pto_histogram(x_npu, bins=bins).cpu().float() - assert torch.equal(y_npu, y_npu2) - assert torch.equal(y_npu, y_npu3) - assert torch.equal(y_npu, y_npu4) - y_cpu = torch.histc(x, bins=bins) - - assert torch.equal(y_npu, y_cpu) +def test_pto_histogram(size: int, bins: int, dtype: torch.dtype): + # Tile size is fixed in the kernel + tile_size = 512 + num_cores = torch.npu.get_device_properties(0).vector_core_num + num_tiles = num_cores * tile_size + total_len = num_tiles * size + + x = torch.randint(high=bins, size=(total_len,), device="cpu", dtype=dtype) + + # Golden PyTorch implementation + y_cpu = torch.histc(x, bins=bins).float() + + # NPU kernel execution, test to see if any race conditions occur across multiple runs + x_npu = x.npu().contiguous() + y_npu = [] + repeat_runs = 100 + for _ in range(repeat_runs): + y_npu.append(pto_histogram(x_npu, bins=bins).cpu().float()) + + torch.npu.synchronize() + + assert all(torch.equal(y_cpu, y) for y in y_npu) From 19a4e338825c5336e9f846da7e63cf36f79b16ad Mon Sep 17 00:00:00 2001 From: Vladimir Loncar Date: Tue, 7 Apr 2026 12:49:58 +0000 Subject: [PATCH 05/17] Histogram implementation guide, step0 - count_less_than --- examples/jit_cpp/histogram/kernel_utils.h | 45 ++++ .../jit_util_count_less_than.py | 82 ++++++ .../kernel_count_less_than.cpp | 241 ++++++++++++++++++ .../kernel_count_less_than_atomic.cpp | 156 ++++++++++++ .../run_count_less_than.py | 65 +++++ 5 files changed, 589 insertions(+) create mode 100644 examples/jit_cpp/histogram/kernel_utils.h create mode 100644 examples/jit_cpp/histogram/step0_count_less_than/jit_util_count_less_than.py create mode 100644 examples/jit_cpp/histogram/step0_count_less_than/kernel_count_less_than.cpp create mode 100644 examples/jit_cpp/histogram/step0_count_less_than/kernel_count_less_than_atomic.cpp create mode 100644 examples/jit_cpp/histogram/step0_count_less_than/run_count_less_than.py diff --git a/examples/jit_cpp/histogram/kernel_utils.h b/examples/jit_cpp/histogram/kernel_utils.h new file mode 100644 index 00000000..0b6ae4c0 --- /dev/null +++ b/examples/jit_cpp/histogram/kernel_utils.h @@ -0,0 +1,45 @@ +/** +Copyright (c) 2026 Huawei Technologies Co., Ltd. +All rights reserved. + +See LICENSE in the root of the software repository: +https://github.com/huawei-csl/pto-kernels/ +for the full License text. +*/ +#pragma once + +#include + +namespace kernel_utils { +/** + * @brief Do a sync step (set-wait flag) between two pipes. + * + * @tparam SrcPipe The pipe that sets the flag. + * @tparam DstPipe The pipe that waits for the flag. + * @param [in] id The event id to sync for. + */ +template +AICORE inline void SetWaitFlag(uint32_t id) { + set_flag(SrcPipe, DstPipe, static_cast(id)); + wait_flag(SrcPipe, DstPipe, static_cast(id)); +} + +/** + * @brief Performs a division on two integral numbers and rounds the result up + * to the nearest integer. + * + * @tparam T1 Data type of dividend. + * @tparam T2 Data type of divisor. + * @param [in] value Dividend. + * @param [in] divisor Divisor. + * @return Result of division. + */ +template ::value && + std::is_integral::value, + int>::type = 0> +AICORE inline T1 CeilDiv(T1 value, T2 divisor) { + return (value + divisor - 1) / divisor; +} + +} // namespace kernel_utils diff --git a/examples/jit_cpp/histogram/step0_count_less_than/jit_util_count_less_than.py b/examples/jit_cpp/histogram/step0_count_less_than/jit_util_count_less_than.py new file mode 100644 index 00000000..d76054b4 --- /dev/null +++ b/examples/jit_cpp/histogram/step0_count_less_than/jit_util_count_less_than.py @@ -0,0 +1,82 @@ +import os +import subprocess +import ctypes + +import torch + +ASCEND_TOOLKIT_HOME = os.environ["ASCEND_TOOLKIT_HOME"] +PTO_LIB_PATH = os.environ["PTO_LIB_PATH"] + + +def compile_cpp(kernel_cpp: str, verbose: bool = False, timeout: int = 120) -> str: + lib_path = os.path.join(os.path.dirname(kernel_cpp), "count_less_than_jit.so") + + flags = [ + "-fPIC", + "-shared", + "-xcce", + "--npu-arch=dav-2201", + "-DMEMORY_BASE", # here hardcoded for A2A3; TODO: expose this option to jit interface + "-O2", + "-std=c++17", + f"-I{PTO_LIB_PATH}/include", + ] + + command = ["bisheng", *flags, kernel_cpp, "-o", lib_path] + if verbose: + print(f"compile {kernel_cpp} with command: \n", command) + + try: + subprocess.run( + command, + timeout=timeout, + check=True, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + ) + except Exception as e: + raise RuntimeError(f"Compile failed: {e}") from e + + if verbose: + print(f"generated {lib_path}") + return lib_path + + +def torch_to_ctypes(tensor): + return ctypes.c_void_p(tensor.data_ptr()) + + +def load_lib(lib_path, check_type=True): + lib_path = os.path.abspath(lib_path) + lib = ctypes.CDLL(lib_path) + + default_block_dim = torch.npu.get_device_properties().vector_core_num + + if check_type: + lib.count_less_than_fp32.argtypes = [ + ctypes.c_uint32, # blockDim + ctypes.c_void_p, # stream + ctypes.c_void_p, # x + ctypes.c_void_p, # z + ctypes.c_uint, # in_length + ctypes.c_float, # pivot + ] + lib.count_less_than_fp32.restype = None + + def count_func(x, z, pivot, block_dim=default_block_dim, stream_ptr=None): + if stream_ptr is None: + stream_ptr = torch.npu.current_stream()._as_parameter_ # noqa + N = x.numel() + lib.count_less_than_fp32( + block_dim, stream_ptr, torch_to_ctypes(x), torch_to_ctypes(z), N, pivot + ) + + return count_func + + +def jit_compile(src_path, clean_up=True): + lib_path = compile_cpp(src_path, verbose=True) + func = load_lib(lib_path) + if clean_up: + os.remove(lib_path) + return func diff --git a/examples/jit_cpp/histogram/step0_count_less_than/kernel_count_less_than.cpp b/examples/jit_cpp/histogram/step0_count_less_than/kernel_count_less_than.cpp new file mode 100644 index 00000000..8d35f5e2 --- /dev/null +++ b/examples/jit_cpp/histogram/step0_count_less_than/kernel_count_less_than.cpp @@ -0,0 +1,241 @@ +/** +Copyright (c) 2026 Huawei Technologies Co., Ltd. +All rights reserved. + +See LICENSE in the root of the software repository: +https://github.com/huawei-csl/pto-kernels/ +for the full License text. +*/ + +#include + +#include "../kernel_utils.h" +#include "acl/acl.h" + +using namespace pto; + +/** + * runTCountLessThan - Local count calculation using TCMPS and TSEL + */ +template +AICORE void runTCountLessThan(__gm__ T *x, __gm__ float *z_local, + const uint32_t total_length, const float pivot) { +#if __CCE_AICORE__ == 220 && defined(__DAV_C220_VEC__) + + set_mask_norm(); + set_vector_mask(-1, -1); + + // --- Define Global Tensors --- + using InputGlobalData = pto::GlobalTensor, + pto::Stride<1, 1, 1, 1, 1>>; + using OutGlobalData = + pto::GlobalTensor, + pto::Stride<1, 1, 1, 1, 1>>; + + // --- Work Distribution --- + const uint32_t block_idx = get_block_idx(); + const uint32_t block_num = get_block_num(); + const uint32_t num_tiles_total = + kernel_utils::CeilDiv(total_length, TILE_SIZE); + const uint32_t num_tiles_per_core = + kernel_utils::CeilDiv(num_tiles_total, block_num); + const uint32_t start_idx = block_idx * num_tiles_per_core; + const uint32_t end_idx = (start_idx + num_tiles_per_core > num_tiles_total) + ? num_tiles_total + : (start_idx + num_tiles_per_core); + + if (start_idx < end_idx) { + // --- Define UB Tiles and Memory Layout --- + uint32_t addr = 0; + const uint32_t UB_X_ADDR = addr; + addr += TILE_SIZE * sizeof(T); + const uint32_t UB_CUR_MASK_ADDR = addr; + addr += TILE_SIZE * sizeof(uint8_t); + const uint32_t UB_ONES_ADDR = addr; + addr += TILE_SIZE * sizeof(float); + const uint32_t UB_ZEROS_ADDR = addr; + addr += TILE_SIZE * sizeof(float); + const uint32_t UB_TSEL_OUT_ADDR = addr; + addr += TILE_SIZE * sizeof(float); + const uint32_t UB_REDUCE_TMP_ADDR = addr; + addr += TILE_SIZE * sizeof(float); + const uint32_t UB_COUNT_ADDR = addr; + addr += 8 * sizeof(float); + const uint32_t UB_TOTAL_COUNT_ADDR = addr; + addr += 8 * sizeof(float); + + using InputTileData = Tile; + InputTileData x_tile; + TASSIGN(x_tile, UB_X_ADDR); + + using MaskTileData = Tile; + MaskTileData current_mask; + TASSIGN(current_mask, UB_CUR_MASK_ADDR); + + // Float conversion tiles + using F32TileData = Tile; + F32TileData ones_tile; + TASSIGN(ones_tile, UB_ONES_ADDR); + TEXPANDS(ones_tile, 1.0f); + + F32TileData zeros_tile; + TASSIGN(zeros_tile, UB_ZEROS_ADDR); + TEXPANDS(zeros_tile, 0.0f); + + F32TileData tsel_out_tile; + TASSIGN(tsel_out_tile, UB_TSEL_OUT_ADDR); + + F32TileData reduce_tmp; + TASSIGN(reduce_tmp, UB_REDUCE_TMP_ADDR); + + using F32CountTile = + Tile; + F32CountTile count_f32_tile; + TASSIGN(count_f32_tile, UB_COUNT_ADDR); + + F32CountTile total_count_f32_tile; + TASSIGN(total_count_f32_tile, UB_TOTAL_COUNT_ADDR); + TEXPANDS(total_count_f32_tile, 0.0f); + + // --- Main Calculation Loop --- + for (uint32_t tile_idx = start_idx; tile_idx < end_idx; ++tile_idx) { + const uint32_t offset = tile_idx * TILE_SIZE; + InputGlobalData x_gm(x + offset, {static_cast(total_length)}); + + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + TLOAD(x_tile, x_gm); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + + // Generate packed bit-mask + TCMPS(current_mask, x_tile, static_cast(pivot), CmpMode::LT); + // Select 1.0f or 0.0f based on the packed bit-mask + TSEL(tsel_out_tile, current_mask, ones_tile, zeros_tile); + + // Reduce the selected tile to get the count of elements less than pivot + // in this tile + TEXPANDS(count_f32_tile, 0.0f); + TEXPANDS(reduce_tmp, 0.0f); + TROWSUM(count_f32_tile, tsel_out_tile, reduce_tmp); + + // Accumulate the count from this tile into the total count for this block + TADD(total_count_f32_tile, total_count_f32_tile, count_f32_tile); + } + + // --- Final Store to Global Memory --- + OutGlobalData z_local_gm(z_local + block_idx, + {static_cast(block_num)}); + + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(z_local_gm, total_count_f32_tile); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + } + +#endif +} + +// Template parameter to avoid "no function" kernel launch error +template +AICORE void runTCountFinal(__gm__ float *z_local, __gm__ int32_t *z, + const int32_t num_blocks) { +#if __CCE_AICORE__ == 220 && defined(__DAV_C220_VEC__) + + set_mask_norm(); + set_vector_mask(-1, -1); + + if (get_block_idx() == 0) { + // --- Define Global Tensors --- + using InGlobalData = + pto::GlobalTensor, + pto::Stride<1, 1, 1, 1, 1>>; + using OutGlobalData = pto::GlobalTensor, + pto::Stride<1, 1, 1, 1, 1>>; + + constexpr uint32_t MAX_BLOCKS = + 256; // Must be larger than number of AIV cores + + uint32_t addr = 0; + const uint32_t UB_IN_ADDR = addr; + addr += MAX_BLOCKS * sizeof(float); + const uint32_t UB_REDUCE_TMP_ADDR = addr; + addr += 8 * sizeof(float); + const uint32_t UB_FLOAT_OUT_ADDR = addr; + addr += 8 * sizeof(float); + const uint32_t UB_OUT_ADDR = addr; + addr += 8 * sizeof(int32_t); + + using InTile = Tile; + InTile in_tile(static_cast(num_blocks)); + TASSIGN(in_tile, UB_IN_ADDR); + + using ReduceTmpTile = Tile; + ReduceTmpTile reduce_tmp_tile; + TASSIGN(reduce_tmp_tile, UB_REDUCE_TMP_ADDR); + + using FloatOutTile = + Tile; + FloatOutTile float_out_tile; + TASSIGN(float_out_tile, UB_FLOAT_OUT_ADDR); + TEXPANDS(float_out_tile, 0.0f); + + using OutTile = Tile; + OutTile out_tile; + TASSIGN(out_tile, UB_OUT_ADDR); + TEXPANDS(out_tile, (int32_t)0); + + // Load all block counts into UB + InGlobalData z_local_gm(z_local, {num_blocks}); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + TLOAD(in_tile, z_local_gm); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + + TROWSUM(float_out_tile, in_tile, reduce_tmp_tile); + TCVT(out_tile, float_out_tile, RoundMode::CAST_RINT); + + OutGlobalData z_gm(z); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(z_gm, out_tile); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + } + +#endif +} + +__global__ AICORE void count_less_than_local(__gm__ void *x, + __gm__ void *z_local, + const uint32_t in_length, + const float pivot) { + constexpr unsigned TILE_SIZE = 512; + runTCountLessThan( + (__gm__ float *)x, (__gm__ float *)z_local, in_length, pivot); +} + +__global__ AICORE void count_final(__gm__ void *z_local, __gm__ void *z, + const int32_t num_blocks) { + runTCountFinal<0>((__gm__ float *)z_local, (__gm__ int32_t *)z, num_blocks); +} + +extern "C" void count_less_than_fp32(uint32_t num_blocks, void *stream, + uint8_t *x, uint8_t *z, uint32_t in_length, + float pivot) { + // Could have been allocated in Torch and passed here + // This is not a suggested practice for production code, we use it here to + // keep the python side interface cleaner + void *z_local = nullptr; + size_t size = num_blocks * sizeof(float); + aclrtMalloc(&z_local, size, ACL_MEM_MALLOC_HUGE_FIRST); + + count_less_than_local<<>>(x, z_local, in_length, + pivot); + count_final<<<1, nullptr, stream>>>(z_local, z, num_blocks); + + aclrtFree(z_local); +} diff --git a/examples/jit_cpp/histogram/step0_count_less_than/kernel_count_less_than_atomic.cpp b/examples/jit_cpp/histogram/step0_count_less_than/kernel_count_less_than_atomic.cpp new file mode 100644 index 00000000..822c2b09 --- /dev/null +++ b/examples/jit_cpp/histogram/step0_count_less_than/kernel_count_less_than_atomic.cpp @@ -0,0 +1,156 @@ +/** +Copyright (c) 2026 Huawei Technologies Co., Ltd. +All rights reserved. + +See LICENSE in the root of the software repository: +https://github.com/huawei-csl/pto-kernels/ +for the full License text. +*/ + +#include + +#include "../kernel_utils.h" + +using namespace pto; + +/** + * runTCountLessThan - Local count calculation with Atomic Addition to Global + * Memory. + */ +template +AICORE void runTCountLessThan(__gm__ T *x, __gm__ int32_t *z, + const uint32_t total_length, const float pivot) { +#if __CCE_AICORE__ == 220 && defined(__DAV_C220_VEC__) + set_mask_norm(); + set_vector_mask(-1, -1); + + // --- Define Global Tensors --- + using InputGlobalData = pto::GlobalTensor, + pto::Stride<1, 1, 1, 1, 1>>; + using OutGlobalData = pto::GlobalTensor, + pto::Stride<1, 1, 1, 1, 1>>; + + // --- Work Distribution --- + const uint32_t block_idx = get_block_idx(); + const uint32_t block_num = get_block_num(); + const uint32_t num_tiles_total = + kernel_utils::CeilDiv(total_length, TILE_SIZE); + const uint32_t num_tiles_per_core = + kernel_utils::CeilDiv(num_tiles_total, block_num); + const uint32_t start_idx = block_idx * num_tiles_per_core; + const uint32_t end_idx = (start_idx + num_tiles_per_core > num_tiles_total) + ? num_tiles_total + : (start_idx + num_tiles_per_core); + + if (start_idx < end_idx) { + // --- Define UB Tiles and Memory Layout --- + uint32_t addr = 0; + const uint32_t UB_X_ADDR = addr; + addr += TILE_SIZE * sizeof(T); + const uint32_t UB_CUR_MASK_ADDR = addr; + addr += TILE_SIZE * sizeof(uint8_t); + const uint32_t UB_ONES_ADDR = addr; + addr += TILE_SIZE * sizeof(float); + const uint32_t UB_ZEROS_ADDR = addr; + addr += TILE_SIZE * sizeof(float); + const uint32_t UB_TSEL_OUT_ADDR = addr; + addr += TILE_SIZE * sizeof(float); + const uint32_t UB_REDUCE_TMP_ADDR = addr; + addr += TILE_SIZE * sizeof(float); + const uint32_t UB_COUNT_ADDR = addr; + addr += 8 * sizeof(float); + const uint32_t UB_TOTAL_COUNT_ADDR = addr; + addr += 8 * sizeof(float); + const uint32_t UB_LOCAL_OUT_ADDR = addr; + + using InputTileData = Tile; + InputTileData x_tile; + TASSIGN(x_tile, UB_X_ADDR); + + using MaskTileData = Tile; + MaskTileData current_mask; + TASSIGN(current_mask, UB_CUR_MASK_ADDR); + + // Float conversion tiles + using F32TileData = Tile; + F32TileData ones_tile; + TASSIGN(ones_tile, UB_ONES_ADDR); + TEXPANDS(ones_tile, 1.0f); + + F32TileData zeros_tile; + TASSIGN(zeros_tile, UB_ZEROS_ADDR); + TEXPANDS(zeros_tile, 0.0f); + + F32TileData tsel_out_tile; + TASSIGN(tsel_out_tile, UB_TSEL_OUT_ADDR); + + F32TileData reduce_tmp; + TASSIGN(reduce_tmp, UB_REDUCE_TMP_ADDR); + + using F32CountTile = + Tile; + F32CountTile count_f32_tile; + TASSIGN(count_f32_tile, UB_COUNT_ADDR); + + F32CountTile total_count_f32_tile; + TASSIGN(total_count_f32_tile, UB_TOTAL_COUNT_ADDR); + TEXPANDS(total_count_f32_tile, 0.0f); + + using OutTile = Tile; + OutTile local_out; + TASSIGN(local_out, UB_LOCAL_OUT_ADDR); + TEXPANDS(local_out, (int32_t)0); + + // --- Main Calculation Loop --- + for (uint32_t tile_idx = start_idx; tile_idx < end_idx; ++tile_idx) { + const uint32_t offset = tile_idx * TILE_SIZE; + InputGlobalData x_gm(x + offset, {static_cast(total_length)}); + + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + TLOAD(x_tile, x_gm); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + + // Generate packed bit-mask + TCMPS(current_mask, x_tile, static_cast(pivot), CmpMode::LT); + // Select 1.0f or 0.0f based on the packed bit-mask + TSEL(tsel_out_tile, current_mask, ones_tile, zeros_tile); + + TEXPANDS(count_f32_tile, 0.0f); + TEXPANDS(reduce_tmp, 0.0f); + TROWSUM(count_f32_tile, tsel_out_tile, reduce_tmp); + + TADD(total_count_f32_tile, total_count_f32_tile, count_f32_tile); + } + + // Convert accumulated total into our UB local count + TCVT(local_out, total_count_f32_tile, RoundMode::CAST_RINT); + + // --- Final Atomic Store to Global Memory --- + OutGlobalData z_gm(z); + + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + // DOesn't do atomic adds + TSTORE(z_gm, local_out); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + } + +#endif +} + +__global__ AICORE void count_less_than(__gm__ void *x, __gm__ void *z, + const uint32_t in_length, + const float pivot) { + constexpr unsigned TILE_SIZE = 512; + runTCountLessThan((__gm__ float *)x, (__gm__ int32_t *)z, + in_length, pivot); +} + +extern "C" void count_less_than_fp32(uint32_t num_blocks, void *stream, + uint8_t *x, uint8_t *z, uint32_t in_length, + float pivot) { + count_less_than<<>>(x, z, in_length, pivot); +} diff --git a/examples/jit_cpp/histogram/step0_count_less_than/run_count_less_than.py b/examples/jit_cpp/histogram/step0_count_less_than/run_count_less_than.py new file mode 100644 index 00000000..a93a0614 --- /dev/null +++ b/examples/jit_cpp/histogram/step0_count_less_than/run_count_less_than.py @@ -0,0 +1,65 @@ +import random +import torch +import torch_npu # noqa + +from jit_util_count_less_than import jit_compile + + +def hist_ref(x, bins): + return torch.histc(x, bins=bins) + + +def random_2d_shape( + min_m=1, + max_m=2048, + min_n=1, + max_n=2048, +): + m = random.randint(min_m, max_m) + n = random.randint(min_n, max_n) + return [m, n] + + +def test_count_less_than(size_mult=1, repeat_runs=20, use_atomic_impl=False): + device = "npu:1" + dtype = torch.float32 + torch.npu.set_device(device) + + # Tile size is fixed in the kernel + tile_size = 512 + num_cores = torch.npu.get_device_properties(0).vector_core_num + num_tiles = num_cores * tile_size + total_len = num_tiles * size_mult + + # Create an input tensor bounded around our standard test pivots + x = torch.rand(size=(total_len,), device="npu", dtype=dtype).contiguous() + z = torch.zeros(1, device="npu", dtype=torch.int32) + pivot = torch.rand(1).item() + + # Golden PyTorch implementation + expected_count = (x < pivot).sum().to(torch.int32).cpu().item() + + if use_atomic_impl: + count_func = jit_compile("kernel_count_less_than_atomic.cpp") + else: + count_func = jit_compile("kernel_count_less_than.cpp") + + # NPU kernel execution, test to see if any race conditions occur across multiple runs + actual_count = [] + for _ in range(repeat_runs): + count_func(x, z, pivot, block_dim=num_cores) + actual_count.append(z.item()) + + torch.npu.synchronize() + + # Check for consistency across runs and correctness against the expected count + assert len(set(actual_count)) == 1, "Inconsistent results across runs" + assert ( + expected_count == actual_count[0] + ), f"Mismatch: expected {expected_count}, got {actual_count}" + + +if __name__ == "__main__": + # Atomic implementation has problems + # test_count_less_than(size_mult=64, repeat_runs=20, use_atomic_impl=True) + test_count_less_than(size_mult=64, repeat_runs=20, use_atomic_impl=False) From 1a8f55ef5df894118daa1b32a2c30ca314ad9f8d Mon Sep 17 00:00:00 2001 From: Vladimir Loncar Date: Tue, 7 Apr 2026 13:40:46 +0000 Subject: [PATCH 06/17] Minor cleanups --- .../step0_count_less_than/run_count_less_than.py | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/examples/jit_cpp/histogram/step0_count_less_than/run_count_less_than.py b/examples/jit_cpp/histogram/step0_count_less_than/run_count_less_than.py index a93a0614..7a27d5f5 100644 --- a/examples/jit_cpp/histogram/step0_count_less_than/run_count_less_than.py +++ b/examples/jit_cpp/histogram/step0_count_less_than/run_count_less_than.py @@ -5,21 +5,6 @@ from jit_util_count_less_than import jit_compile -def hist_ref(x, bins): - return torch.histc(x, bins=bins) - - -def random_2d_shape( - min_m=1, - max_m=2048, - min_n=1, - max_n=2048, -): - m = random.randint(min_m, max_m) - n = random.randint(min_n, max_n) - return [m, n] - - def test_count_less_than(size_mult=1, repeat_runs=20, use_atomic_impl=False): device = "npu:1" dtype = torch.float32 From 32d593432c8246099d57568871c0d5a6a7a96ed0 Mon Sep 17 00:00:00 2001 From: Vladimir Loncar Date: Tue, 7 Apr 2026 16:12:18 +0000 Subject: [PATCH 07/17] Histogram implementation guide, step1 - naive histogram --- .../jit_util_histogram.py | 99 +++++++ .../kernel_histogram.cpp | 279 ++++++++++++++++++ .../step1_naive_histogram/run_histogram.py | 53 ++++ 3 files changed, 431 insertions(+) create mode 100644 examples/jit_cpp/histogram/step1_naive_histogram/jit_util_histogram.py create mode 100644 examples/jit_cpp/histogram/step1_naive_histogram/kernel_histogram.cpp create mode 100644 examples/jit_cpp/histogram/step1_naive_histogram/run_histogram.py diff --git a/examples/jit_cpp/histogram/step1_naive_histogram/jit_util_histogram.py b/examples/jit_cpp/histogram/step1_naive_histogram/jit_util_histogram.py new file mode 100644 index 00000000..6c35e5d0 --- /dev/null +++ b/examples/jit_cpp/histogram/step1_naive_histogram/jit_util_histogram.py @@ -0,0 +1,99 @@ +import os +import subprocess +import ctypes + +import torch + +ASCEND_TOOLKIT_HOME = os.environ["ASCEND_TOOLKIT_HOME"] +PTO_LIB_PATH = os.environ["PTO_LIB_PATH"] + + +def compile_cpp(kernel_cpp: str, verbose: bool = False, timeout: int = 120) -> str: + lib_path = os.path.join(os.path.dirname(kernel_cpp), "histogram_jit.so") + + flags = [ + "-fPIC", + "-shared", + "-xcce", + "--npu-arch=dav-2201", + "-DMEMORY_BASE", # here hardcoded for A2A3; TODO: expose this option to jit interface + "-O2", + "-std=c++17", + f"-I{PTO_LIB_PATH}/include", + ] + + command = ["bisheng", *flags, kernel_cpp, "-o", lib_path] + if verbose: + print(f"compile {kernel_cpp} with command: \n", command) + + try: + subprocess.run( + command, + timeout=timeout, + check=True, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + ) + except Exception as e: + output = e.stdout.decode("utf-8", errors="replace") if e.stdout else "" + raise RuntimeError( + f"Compile failed with exit code {e.returncode}:\n{output}" + ) from e + + if verbose: + print(f"generated {lib_path}") + return lib_path + + +def torch_to_ctypes(tensor): + return ctypes.c_void_p(tensor.data_ptr()) + + +def load_lib(lib_path, check_type=True): + lib_path = os.path.abspath(lib_path) + lib = ctypes.CDLL(lib_path) + + default_block_dim = torch.npu.get_device_properties().vector_core_num + + if check_type: + lib.histogram_fp32.argtypes = [ + ctypes.c_uint32, # blockDim + ctypes.c_void_p, # stream + ctypes.c_void_p, # x + ctypes.c_void_p, # z_local + ctypes.c_void_p, # z + ctypes.c_uint, # in_length + ctypes.c_int, # bins + ctypes.c_float, # min_val + ctypes.c_float, # max_val + ] + lib.histogram_fp32.restype = None + + def hist_func( + x, z, bins, min_val, max_val, block_dim=default_block_dim, stream_ptr=None + ): + if stream_ptr is None: + stream_ptr = torch.npu.current_stream()._as_parameter_ # noqa + N = x.numel() + z_local = torch.zeros((block_dim, bins), device=x.device, dtype=torch.float32) + lib.histogram_fp32( + block_dim, + stream_ptr, + torch_to_ctypes(x), + torch_to_ctypes(z_local), + torch_to_ctypes(z), + N, + bins, + ctypes.c_float(min_val), + ctypes.c_float(max_val), + ) + + return hist_func + + +def jit_compile(src_path, clean_up=True): + lib_path = compile_cpp(src_path, verbose=True) + func = load_lib(lib_path, check_type=False) + if clean_up: + os.remove(lib_path) + return func diff --git a/examples/jit_cpp/histogram/step1_naive_histogram/kernel_histogram.cpp b/examples/jit_cpp/histogram/step1_naive_histogram/kernel_histogram.cpp new file mode 100644 index 00000000..71ca63a5 --- /dev/null +++ b/examples/jit_cpp/histogram/step1_naive_histogram/kernel_histogram.cpp @@ -0,0 +1,279 @@ +/** +Copyright (c) 2026 Huawei Technologies Co., Ltd. +All rights reserved. + +See LICENSE in the root of the software repository: +https://github.com/huawei-csl/pto-kernels/ +for the full License text. +*/ + +#include + +#include "../kernel_utils.h" + +using namespace pto; + +constexpr uint32_t DEFAULT_TILE_SIZE = 512; +constexpr uint32_t MAX_BINS = 256; +constexpr uint32_t MAX_BLOCKS = 64; + +/** + * runTLocalHistogram - Local, per-core histogram calculation + */ +template +AICORE void runTLocalHistogram(__gm__ T *x, __gm__ float *z_local, + const uint32_t total_length, + const int32_t num_bins, const float min_val, + const float max_val) { +#if __CCE_AICORE__ == 220 && defined(__DAV_C220_VEC__) + set_mask_norm(); + set_vector_mask(-1, -1); + + // --- Define Global Tensors --- + using InputGlobalData = pto::GlobalTensor, + pto::Stride<1, 1, 1, 1, 1>>; + using HistGlobalData = + pto::GlobalTensor, + pto::Stride<1, 1, 1, 1, 1>>; + + // --- Work Distribution --- + const uint32_t block_idx = get_block_idx(); + const uint32_t block_num = get_block_num(); + const uint32_t num_tiles_total = + kernel_utils::CeilDiv(total_length, TILE_SIZE); + const uint32_t num_tiles_per_core = + kernel_utils::CeilDiv(num_tiles_total, block_num); + const uint32_t start_idx = block_idx * num_tiles_per_core; + const uint32_t end_idx = (start_idx + num_tiles_per_core > num_tiles_total) + ? num_tiles_total + : (start_idx + num_tiles_per_core); + + // --- Define UB Tiles and Memory Layout --- + uint32_t addr = 0; + const uint32_t UB_X_ADDR = addr; + addr += TILE_SIZE * sizeof(T); + const uint32_t UB_CUR_MASK_ADDR = addr; + addr += TILE_SIZE * sizeof(uint8_t); + const uint32_t UB_CUR_F32_ADDR = addr; + addr += TILE_SIZE * sizeof(float); + const uint32_t UB_PREV_F32_ADDR = addr; + addr += TILE_SIZE * sizeof(float); + const uint32_t UB_BIN_F32_ADDR = addr; + addr += TILE_SIZE * sizeof(float); + const uint32_t UB_ONE_ADDR = addr; + addr += TILE_SIZE * sizeof(float); + const uint32_t UB_ZERO_ADDR = addr; + addr += TILE_SIZE * sizeof(float); + const uint32_t UB_REDUCE_TMP_ADDR = addr; + addr += TILE_SIZE * sizeof(float); + const uint32_t UB_COUNT_ADDR = addr; + addr += 8 * sizeof(float); + const uint32_t UB_LOCAL_HIST_ADDR = addr; + + InputGlobalData x_gm(x, {static_cast(total_length)}); + + using InputTileData = Tile; + InputTileData x_tile; + TASSIGN(x_tile, UB_X_ADDR); + + using MaskTileData = Tile; + MaskTileData current_mask; + TASSIGN(current_mask, UB_CUR_MASK_ADDR); + + // Float conversion tiles + using F32TileData = Tile; + F32TileData cur_f32; + TASSIGN(cur_f32, UB_CUR_F32_ADDR); + F32TileData prev_f32; + TASSIGN(prev_f32, UB_PREV_F32_ADDR); + F32TileData bin_mask_f32; + TASSIGN(bin_mask_f32, UB_BIN_F32_ADDR); + + F32TileData one_tile; + TASSIGN(one_tile, UB_ONE_ADDR); + TEXPANDS(one_tile, 1.0f); + F32TileData zero_tile; + TASSIGN(zero_tile, UB_ZERO_ADDR); + TEXPANDS(zero_tile, 0.0f); + + F32TileData reduce_tmp; + TASSIGN(reduce_tmp, UB_REDUCE_TMP_ADDR); + + using F32CountTile = + Tile; + F32CountTile count_f32_tile; + TASSIGN(count_f32_tile, UB_COUNT_ADDR); + + // Local histogram tile in UB + using HistTile = + Tile; + HistTile local_hist(num_bins); + TASSIGN(local_hist, UB_LOCAL_HIST_ADDR); + TEXPANDS(local_hist, 0.0f); + + const float bin_width = (max_val - min_val) / static_cast(num_bins); + + // --- Main Calculation Loop --- + for (uint32_t tile_idx = start_idx; tile_idx < end_idx; ++tile_idx) { + int offset = tile_idx * TILE_SIZE; + TASSIGN(x_gm, x + offset); + + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + TLOAD(x_tile, x_gm); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + + // Generate packed bit-mask + TCMPS(current_mask, x_tile, static_cast(min_val), CmpMode::LT); + // Select 1.0f or 0.0f based on the packed bit-mask + TSEL(prev_f32, current_mask, one_tile, zero_tile); + + for (int32_t j = 0; j < num_bins; ++j) { + float bin_upper_bound = min_val + (j + 1) * bin_width; + CmpMode mode = (j == num_bins - 1) ? CmpMode::LE : CmpMode::LT; + + TCMPS(current_mask, x_tile, static_cast(bin_upper_bound), mode); + TSEL(cur_f32, current_mask, one_tile, zero_tile); + TSUB(bin_mask_f32, cur_f32, prev_f32); + + // Reduce the selected tile to get the count of elements less than pivot + // in this tile + TEXPANDS(count_f32_tile, 0.0f); + TEXPANDS(reduce_tmp, 0.0f); + TROWSUM(count_f32_tile, bin_mask_f32, reduce_tmp); + + // Scalar move to update UB local histogram + set_flag(PIPE_V, PIPE_S, EVENT_ID0); + wait_flag(PIPE_V, PIPE_S, EVENT_ID0); + float f_count = count_f32_tile.GetValue(0); + if (f_count > 0.0f) { + local_hist.SetValue(j, local_hist.GetValue(j) + f_count); + } + set_flag(PIPE_S, PIPE_V, EVENT_ID0); + wait_flag(PIPE_S, PIPE_V, EVENT_ID0); + + TMOV(prev_f32, cur_f32); + } + } + + // --- Final Store to Global Memory --- + HistGlobalData z_gm(z_local + block_idx * num_bins, + {static_cast(block_num) * num_bins}); + + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(z_gm, local_hist); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + +#endif +} + +// Template parameter to avoid "no function" kernel launch error +template +AICORE void runTHistogramFinal(__gm__ float *z_local, __gm__ int32_t *z, + const int32_t num_bins, + const int32_t num_blocks) { +#if __CCE_AICORE__ == 220 && defined(__DAV_C220_VEC__) + set_mask_norm(); + set_vector_mask(-1, -1); + + if (get_block_idx() == 0) { + // --- Define Global Tensors --- + using InGlobalData = + pto::GlobalTensor, + pto::Stride<1, 1, 1, 1, 1>>; + using OutGlobalData = + pto::GlobalTensor, + pto::Stride<1, 1, 1, 1, 1>>; + + uint32_t addr = 0; + const uint32_t UB_IN_ADDR = addr; + addr += MAX_BLOCKS * MAX_BINS * sizeof(float); + const uint32_t UB_REDUCE_TMP_ADDR = addr; + addr += MAX_BINS * sizeof(float); + const uint32_t UB_FLOAT_OUT_ADDR = addr; + addr += MAX_BINS * sizeof(float); + const uint32_t UB_OUT_ADDR = addr; + addr += MAX_BINS * sizeof(int32_t); + + using InTile = Tile; + InTile in_tile( + {static_cast(num_blocks), static_cast(num_bins)}); + TASSIGN(in_tile, UB_IN_ADDR); + + using ReduceTmpTile = + Tile; + ReduceTmpTile reduce_tmp_tile(num_bins); + TASSIGN(reduce_tmp_tile, UB_REDUCE_TMP_ADDR); + TEXPANDS(reduce_tmp_tile, 0.0f); + + using FloatOutTile = + Tile; + FloatOutTile float_out_tile(num_bins); + TASSIGN(float_out_tile, UB_FLOAT_OUT_ADDR); + TEXPANDS(float_out_tile, 0.0f); + + using OutTile = Tile; + OutTile out_tile(num_bins); + TASSIGN(out_tile, UB_OUT_ADDR); + TEXPANDS(out_tile, static_cast(0)); + + // Load all block counts into UB row by row to match 2D tile padding + using InRowTile = + Tile; + InRowTile row_tile(static_cast(num_bins)); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + for (int32_t b = 0; b < num_blocks; ++b) { + InGlobalData z_local_gm(z_local + b * num_bins, {num_bins}); + TASSIGN(row_tile, UB_IN_ADDR + b * MAX_BINS * sizeof(float)); + TLOAD(row_tile, z_local_gm); + } + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + + // TCOLSUM reduces along the row dimension (num_blocks) + TCOLSUM(float_out_tile, in_tile, reduce_tmp_tile, true); + + TCVT(out_tile, float_out_tile, RoundMode::CAST_RINT); + + // --- Final Store to Global Memory --- + OutGlobalData z_gm(z, {num_bins}); + + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(z_gm, out_tile); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + } + +#endif +} + +__global__ AICORE void histogram_local_fp32( + __gm__ void *x, __gm__ void *z_local, const uint32_t in_length, + const int32_t num_bins, const float min_val, const float max_val) { + runTLocalHistogram( + (__gm__ float *)x, (__gm__ float *)z_local, in_length, num_bins, min_val, + max_val); +} + +__global__ AICORE void histogram_final(__gm__ void *z_local, __gm__ void *z, + const int32_t num_bins, + const int32_t num_blocks) { + runTHistogramFinal<0>((__gm__ float *)z_local, (__gm__ int32_t *)z, num_bins, + num_blocks); +} + +extern "C" void histogram_fp32(uint32_t num_blocks, void *stream, void *x, + void *z_local, void *z, const uint32_t in_length, + const int32_t num_bins, const float min_val, + const float max_val) { + histogram_local_fp32<<>>( + x, z_local, in_length, num_bins, min_val, max_val); + histogram_final<<<1, nullptr, stream>>>(z_local, z, num_bins, num_blocks); +} diff --git a/examples/jit_cpp/histogram/step1_naive_histogram/run_histogram.py b/examples/jit_cpp/histogram/step1_naive_histogram/run_histogram.py new file mode 100644 index 00000000..37fa9908 --- /dev/null +++ b/examples/jit_cpp/histogram/step1_naive_histogram/run_histogram.py @@ -0,0 +1,53 @@ +import random +import torch +import torch_npu # noqa + +from jit_util_histogram import jit_compile + + +def test_histogram(size_mult=1, repeat_runs=20): + device = "npu:1" + dtype = torch.float32 + torch.npu.set_device(device) + + # Tile size is fixed in the kernel + tile_size = 512 + num_cores = torch.npu.get_device_properties().vector_core_num + num_tiles = num_cores * tile_size + total_len = num_tiles * size_mult + + bins = 256 + min_val = 0.0 + max_val = 256.0 + + # Create an input tensor bounded around our standard test range + x = torch.rand(size=(total_len,), device="npu", dtype=dtype).contiguous() * max_val + z = torch.zeros(bins, device="npu", dtype=torch.int32) + + # Golden PyTorch implementation + expected_hist = torch.histc(x.cpu(), bins, min=min_val, max=max_val).to(torch.int32) + + hist_func = jit_compile("kernel_histogram.cpp") + + # NPU kernel execution, test to see if any race conditions occur across multiple runs + actual_hist = [] + for _ in range(repeat_runs): + z.zero_() + hist_func(x, z, bins, min_val, max_val, block_dim=num_cores) + actual_hist.append(z.cpu().clone()) + + torch.npu.synchronize() + + # Check for consistency across runs and correctness against the expected count + for i, hist in enumerate(actual_hist): + assert torch.equal( + hist, actual_hist[0] + ), f"Inconsistent results across runs at run {i}" + + assert torch.equal( + expected_hist, actual_hist[0] + ), "Mismatch between expected and actual histogram" + + +if __name__ == "__main__": + test_histogram(size_mult=64, repeat_runs=20) From 974053b78768e9fd89dad34cdac89a19d600e581 Mon Sep 17 00:00:00 2001 From: Vladimir Loncar Date: Tue, 7 Apr 2026 16:12:33 +0000 Subject: [PATCH 08/17] Histogram implementation guide, step2 - double buffering --- .../jit_util_histogram.py | 99 ++++++ .../kernel_histogram.cpp | 294 ++++++++++++++++++ .../step2_double_buffering/run_histogram.py | 53 ++++ 3 files changed, 446 insertions(+) create mode 100644 examples/jit_cpp/histogram/step2_double_buffering/jit_util_histogram.py create mode 100644 examples/jit_cpp/histogram/step2_double_buffering/kernel_histogram.cpp create mode 100644 examples/jit_cpp/histogram/step2_double_buffering/run_histogram.py diff --git a/examples/jit_cpp/histogram/step2_double_buffering/jit_util_histogram.py b/examples/jit_cpp/histogram/step2_double_buffering/jit_util_histogram.py new file mode 100644 index 00000000..6c35e5d0 --- /dev/null +++ b/examples/jit_cpp/histogram/step2_double_buffering/jit_util_histogram.py @@ -0,0 +1,99 @@ +import os +import subprocess +import ctypes + +import torch + +ASCEND_TOOLKIT_HOME = os.environ["ASCEND_TOOLKIT_HOME"] +PTO_LIB_PATH = os.environ["PTO_LIB_PATH"] + + +def compile_cpp(kernel_cpp: str, verbose: bool = False, timeout: int = 120) -> str: + lib_path = os.path.join(os.path.dirname(kernel_cpp), "histogram_jit.so") + + flags = [ + "-fPIC", + "-shared", + "-xcce", + "--npu-arch=dav-2201", + "-DMEMORY_BASE", # here hardcoded for A2A3; TODO: expose this option to jit interface + "-O2", + "-std=c++17", + f"-I{PTO_LIB_PATH}/include", + ] + + command = ["bisheng", *flags, kernel_cpp, "-o", lib_path] + if verbose: + print(f"compile {kernel_cpp} with command: \n", command) + + try: + subprocess.run( + command, + timeout=timeout, + check=True, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + ) + except Exception as e: + output = e.stdout.decode("utf-8", errors="replace") if e.stdout else "" + raise RuntimeError( + f"Compile failed with exit code {e.returncode}:\n{output}" + ) from e + + if verbose: + print(f"generated {lib_path}") + return lib_path + + +def torch_to_ctypes(tensor): + return ctypes.c_void_p(tensor.data_ptr()) + + +def load_lib(lib_path, check_type=True): + lib_path = os.path.abspath(lib_path) + lib = ctypes.CDLL(lib_path) + + default_block_dim = torch.npu.get_device_properties().vector_core_num + + if check_type: + lib.histogram_fp32.argtypes = [ + ctypes.c_uint32, # blockDim + ctypes.c_void_p, # stream + ctypes.c_void_p, # x + ctypes.c_void_p, # z_local + ctypes.c_void_p, # z + ctypes.c_uint, # in_length + ctypes.c_int, # bins + ctypes.c_float, # min_val + ctypes.c_float, # max_val + ] + lib.histogram_fp32.restype = None + + def hist_func( + x, z, bins, min_val, max_val, block_dim=default_block_dim, stream_ptr=None + ): + if stream_ptr is None: + stream_ptr = torch.npu.current_stream()._as_parameter_ # noqa + N = x.numel() + z_local = torch.zeros((block_dim, bins), device=x.device, dtype=torch.float32) + lib.histogram_fp32( + block_dim, + stream_ptr, + torch_to_ctypes(x), + torch_to_ctypes(z_local), + torch_to_ctypes(z), + N, + bins, + ctypes.c_float(min_val), + ctypes.c_float(max_val), + ) + + return hist_func + + +def jit_compile(src_path, clean_up=True): + lib_path = compile_cpp(src_path, verbose=True) + func = load_lib(lib_path, check_type=False) + if clean_up: + os.remove(lib_path) + return func diff --git a/examples/jit_cpp/histogram/step2_double_buffering/kernel_histogram.cpp b/examples/jit_cpp/histogram/step2_double_buffering/kernel_histogram.cpp new file mode 100644 index 00000000..f74a1507 --- /dev/null +++ b/examples/jit_cpp/histogram/step2_double_buffering/kernel_histogram.cpp @@ -0,0 +1,294 @@ +/** +Copyright (c) 2026 Huawei Technologies Co., Ltd. +All rights reserved. + +See LICENSE in the root of the software repository: +https://github.com/huawei-csl/pto-kernels/ +for the full License text. +*/ + +#include + +#include "../kernel_utils.h" + +using namespace pto; + +constexpr uint32_t DEFAULT_TILE_SIZE = 512; +constexpr uint32_t MAX_BINS = 256; +constexpr uint32_t MAX_BLOCKS = 64; + +/** + * runTLocalHistogram - Local, per-core histogram calculation + */ +template +AICORE void runTLocalHistogram(__gm__ T *x, __gm__ float *z_local, + const uint32_t total_length, + const int32_t num_bins, const float min_val, + const float max_val) { +#if __CCE_AICORE__ == 220 && defined(__DAV_C220_VEC__) + set_mask_norm(); + set_vector_mask(-1, -1); + + // --- Define Global Tensors --- + using InputGlobalData = pto::GlobalTensor, + pto::Stride<1, 1, 1, 1, 1>>; + using HistGlobalData = + pto::GlobalTensor, + pto::Stride<1, 1, 1, 1, 1>>; + + // --- Work Distribution --- + const uint32_t block_idx = get_block_idx(); + const uint32_t block_num = get_block_num(); + const uint32_t num_tiles_total = + kernel_utils::CeilDiv(total_length, TILE_SIZE); + const uint32_t num_tiles_per_core = + kernel_utils::CeilDiv(num_tiles_total, block_num); + const uint32_t start_idx = block_idx * num_tiles_per_core; + const uint32_t end_idx = (start_idx + num_tiles_per_core > num_tiles_total) + ? num_tiles_total + : (start_idx + num_tiles_per_core); + + // --- Define UB Tiles and Memory Layout --- + uint32_t addr = 0; + const uint32_t UB_X_PING = addr; + addr += TILE_SIZE * sizeof(T); + const uint32_t UB_X_PONG = addr; + addr += TILE_SIZE * sizeof(T); + const uint32_t UB_CUR_MASK_ADDR = addr; + addr += TILE_SIZE * sizeof(uint8_t); + const uint32_t UB_CUR_F32_ADDR = addr; + addr += TILE_SIZE * sizeof(float); + const uint32_t UB_PREV_F32_ADDR = addr; + addr += TILE_SIZE * sizeof(float); + const uint32_t UB_BIN_F32_ADDR = addr; + addr += TILE_SIZE * sizeof(float); + const uint32_t UB_ONE_ADDR = addr; + addr += TILE_SIZE * sizeof(float); + const uint32_t UB_ZERO_ADDR = addr; + addr += TILE_SIZE * sizeof(float); + const uint32_t UB_REDUCE_TMP_ADDR = addr; + addr += TILE_SIZE * sizeof(float); + const uint32_t UB_COUNT_ADDR = addr; + addr += 8 * sizeof(float); + const uint32_t UB_LOCAL_HIST_ADDR = addr; + + InputGlobalData x_gm(x, {static_cast(total_length)}); + + using InputTileData = Tile; + + using MaskTileData = Tile; + MaskTileData current_mask; + TASSIGN(current_mask, UB_CUR_MASK_ADDR); + + // Float conversion tiles + using F32TileData = Tile; + F32TileData cur_f32; + TASSIGN(cur_f32, UB_CUR_F32_ADDR); + F32TileData prev_f32; + TASSIGN(prev_f32, UB_PREV_F32_ADDR); + F32TileData bin_mask_f32; + TASSIGN(bin_mask_f32, UB_BIN_F32_ADDR); + + F32TileData one_tile; + TASSIGN(one_tile, UB_ONE_ADDR); + TEXPANDS(one_tile, 1.0f); + F32TileData zero_tile; + TASSIGN(zero_tile, UB_ZERO_ADDR); + TEXPANDS(zero_tile, 0.0f); + + F32TileData reduce_tmp; + TASSIGN(reduce_tmp, UB_REDUCE_TMP_ADDR); + + using F32CountTile = + Tile; + F32CountTile count_f32_tile; + TASSIGN(count_f32_tile, UB_COUNT_ADDR); + + // Local histogram tile in UB + using HistTile = + Tile; + HistTile local_hist(num_bins); + TASSIGN(local_hist, UB_LOCAL_HIST_ADDR); + TEXPANDS(local_hist, 0.0f); + + const float bin_width = (max_val - min_val) / static_cast(num_bins); + + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID1); + + // --- Main Calculation Loop --- + for (uint32_t tile_idx = start_idx, ping = 1; tile_idx < end_idx; + ++tile_idx) { + int offset = tile_idx * TILE_SIZE; + TASSIGN(x_gm, x + offset); + + const event_t ev = ping ? (event_t)EVENT_ID0 : (event_t)EVENT_ID1; + const unsigned x_base = ping ? UB_X_PING : UB_X_PONG; + + InputTileData x_tile; + TASSIGN(x_tile, x_base); + + wait_flag(PIPE_V, PIPE_MTE2, ev); + TLOAD(x_tile, x_gm); + set_flag(PIPE_MTE2, PIPE_V, ev); + wait_flag(PIPE_MTE2, PIPE_V, ev); + + // Generate packed bit-mask + TCMPS(current_mask, x_tile, static_cast(min_val), CmpMode::LT); + // Select 1.0f or 0.0f based on the packed bit-mask + TSEL(prev_f32, current_mask, one_tile, zero_tile); + + for (int32_t j = 0; j < num_bins; ++j) { + float bin_upper_bound = min_val + (j + 1) * bin_width; + CmpMode mode = (j == num_bins - 1) ? CmpMode::LE : CmpMode::LT; + + TCMPS(current_mask, x_tile, static_cast(bin_upper_bound), mode); + TSEL(cur_f32, current_mask, one_tile, zero_tile); + TSUB(bin_mask_f32, cur_f32, prev_f32); + + // Reduce the selected tile to get the count of elements less than pivot + // in this tile + TEXPANDS(count_f32_tile, 0.0f); + TEXPANDS(reduce_tmp, 0.0f); + TROWSUM(count_f32_tile, bin_mask_f32, reduce_tmp); + + // Scalar move to update UB local histogram + set_flag(PIPE_V, PIPE_S, EVENT_ID2); + wait_flag(PIPE_V, PIPE_S, EVENT_ID2); + float f_count = count_f32_tile.GetValue(0); + if (f_count > 0.0f) { + local_hist.SetValue(j, local_hist.GetValue(j) + f_count); + } + set_flag(PIPE_S, PIPE_V, EVENT_ID2); + wait_flag(PIPE_S, PIPE_V, EVENT_ID2); + + TMOV(prev_f32, cur_f32); + } + + set_flag(PIPE_V, PIPE_MTE2, ev); + ping = 1 - ping; + } + + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID1); + + // --- Final Store to Global Memory --- + HistGlobalData z_gm(z_local + block_idx * num_bins, + {static_cast(block_num) * num_bins}); + + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(z_gm, local_hist); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + +#endif +} + +// Template parameter to avoid "no function" kernel launch error +template +AICORE void runTHistogramFinal(__gm__ float *z_local, __gm__ int32_t *z, + const int32_t num_bins, + const int32_t num_blocks) { +#if __CCE_AICORE__ == 220 && defined(__DAV_C220_VEC__) + set_mask_norm(); + set_vector_mask(-1, -1); + + if (get_block_idx() == 0) { + // --- Define Global Tensors --- + using InGlobalData = + pto::GlobalTensor, + pto::Stride<1, 1, 1, 1, 1>>; + using OutGlobalData = + pto::GlobalTensor, + pto::Stride<1, 1, 1, 1, 1>>; + + uint32_t addr = 0; + const uint32_t UB_IN_ADDR = addr; + addr += MAX_BLOCKS * MAX_BINS * sizeof(float); + const uint32_t UB_REDUCE_TMP_ADDR = addr; + addr += MAX_BINS * sizeof(float); + const uint32_t UB_FLOAT_OUT_ADDR = addr; + addr += MAX_BINS * sizeof(float); + const uint32_t UB_OUT_ADDR = addr; + addr += MAX_BINS * sizeof(int32_t); + + using InTile = Tile; + InTile in_tile( + {static_cast(num_blocks), static_cast(num_bins)}); + TASSIGN(in_tile, UB_IN_ADDR); + + using ReduceTmpTile = + Tile; + ReduceTmpTile reduce_tmp_tile(num_bins); + TASSIGN(reduce_tmp_tile, UB_REDUCE_TMP_ADDR); + TEXPANDS(reduce_tmp_tile, 0.0f); + + using FloatOutTile = + Tile; + FloatOutTile float_out_tile(num_bins); + TASSIGN(float_out_tile, UB_FLOAT_OUT_ADDR); + TEXPANDS(float_out_tile, 0.0f); + + using OutTile = Tile; + OutTile out_tile(num_bins); + TASSIGN(out_tile, UB_OUT_ADDR); + TEXPANDS(out_tile, static_cast(0)); + + // Load all block counts into UB row by row to match 2D tile padding + using InRowTile = + Tile; + InRowTile row_tile(static_cast(num_bins)); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + for (int32_t b = 0; b < num_blocks; ++b) { + InGlobalData z_local_gm(z_local + b * num_bins, {num_bins}); + TASSIGN(row_tile, UB_IN_ADDR + b * MAX_BINS * sizeof(float)); + TLOAD(row_tile, z_local_gm); + } + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + + // TCOLSUM reduces along the row dimension (num_blocks) + TCOLSUM(float_out_tile, in_tile, reduce_tmp_tile, true); + + TCVT(out_tile, float_out_tile, RoundMode::CAST_RINT); + + // --- Final Store to Global Memory --- + OutGlobalData z_gm(z, {num_bins}); + + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(z_gm, out_tile); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + } + +#endif +} + +__global__ AICORE void histogram_local_fp32( + __gm__ void *x, __gm__ void *z_local, const uint32_t in_length, + const int32_t num_bins, const float min_val, const float max_val) { + runTLocalHistogram( + (__gm__ float *)x, (__gm__ float *)z_local, in_length, num_bins, min_val, + max_val); +} + +__global__ AICORE void histogram_final(__gm__ void *z_local, __gm__ void *z, + const int32_t num_bins, + const int32_t num_blocks) { + runTHistogramFinal<0>((__gm__ float *)z_local, (__gm__ int32_t *)z, num_bins, + num_blocks); +} + +extern "C" void histogram_fp32(uint32_t num_blocks, void *stream, void *x, + void *z_local, void *z, const uint32_t in_length, + const int32_t num_bins, const float min_val, + const float max_val) { + histogram_local_fp32<<>>( + x, z_local, in_length, num_bins, min_val, max_val); + histogram_final<<<1, nullptr, stream>>>(z_local, z, num_bins, num_blocks); +} diff --git a/examples/jit_cpp/histogram/step2_double_buffering/run_histogram.py b/examples/jit_cpp/histogram/step2_double_buffering/run_histogram.py new file mode 100644 index 00000000..3aa94189 --- /dev/null +++ b/examples/jit_cpp/histogram/step2_double_buffering/run_histogram.py @@ -0,0 +1,53 @@ +import random +import torch +import torch_npu # noqa + +from jit_util_histogram import jit_compile + + +def test_histogram(size_mult=1, repeat_runs=20): + device = "npu:1" + dtype = torch.float32 + torch.npu.set_device(device) + + # Tile size is fixed in the kernel + tile_size = 512 + num_cores = 1 # torch.npu.get_device_properties().vector_core_num + num_tiles = num_cores * tile_size + total_len = num_tiles * size_mult + + bins = 256 + min_val = 0.0 + max_val = 256.0 + + # Create an input tensor bounded around our standard test range + x = torch.rand(size=(total_len,), device="npu", dtype=dtype).contiguous() * max_val + z = torch.zeros(bins, device="npu", dtype=torch.int32) + + # Golden PyTorch implementation + expected_hist = torch.histc(x.cpu(), bins, min=min_val, max=max_val).to(torch.int32) + + hist_func = jit_compile("kernel_histogram.cpp") + + # NPU kernel execution, test to see if any race conditions occur across multiple runs + actual_hist = [] + for _ in range(repeat_runs): + z.zero_() + hist_func(x, z, bins, min_val, max_val, block_dim=num_cores) + actual_hist.append(z.cpu().clone()) + + torch.npu.synchronize() + + # Check for consistency across runs and correctness against the expected count + for i, hist in enumerate(actual_hist): + assert torch.equal( + hist, actual_hist[0] + ), f"Inconsistent results across runs at run {i}" + + assert torch.equal( + expected_hist, actual_hist[0] + ), "Mismatch between expected and actual histogram" + + +if __name__ == "__main__": + test_histogram(size_mult=64, repeat_runs=20) From 1788f72c89e414dd0e36343412fe8fcaf0da58de Mon Sep 17 00:00:00 2001 From: Vladimir Loncar Date: Wed, 8 Apr 2026 09:28:27 +0000 Subject: [PATCH 09/17] Double buffering consistency fixes --- csrc/kernel/kernel_histogram.cpp | 57 ++++++++++++------- .../kernel_histogram.cpp | 6 ++ .../step1_naive_histogram/run_histogram.py | 4 +- .../kernel_histogram.cpp | 4 ++ .../step2_double_buffering/run_histogram.py | 6 +- 5 files changed, 53 insertions(+), 24 deletions(-) diff --git a/csrc/kernel/kernel_histogram.cpp b/csrc/kernel/kernel_histogram.cpp index 10949efd..6e68203c 100644 --- a/csrc/kernel/kernel_histogram.cpp +++ b/csrc/kernel/kernel_histogram.cpp @@ -27,7 +27,7 @@ constexpr uint32_t MAX_BLOCKS = 64; * runTLocalHistogram - Local, per-core histogram calculation */ template -AICORE void runTLocalHistogram(__gm__ T* x, __gm__ float* z_local, +AICORE void runTLocalHistogram(__gm__ T *x, __gm__ float *z_local, const uint32_t total_length, const int32_t num_bins, const float min_val, const float max_val) { @@ -55,7 +55,9 @@ AICORE void runTLocalHistogram(__gm__ T* x, __gm__ float* z_local, // --- Define UB Tiles and Memory Layout --- uint32_t addr = 0; - const uint32_t UB_X_ADDR = addr; + const uint32_t UB_X_PING = addr; + addr += TILE_SIZE * sizeof(T); + const uint32_t UB_X_PONG = addr; addr += TILE_SIZE * sizeof(T); const uint32_t UB_CUR_MASK_ADDR = addr; addr += TILE_SIZE * sizeof(uint8_t); @@ -78,8 +80,6 @@ AICORE void runTLocalHistogram(__gm__ T* x, __gm__ float* z_local, InputGlobalData x_gm(x, {static_cast(total_length)}); using InputTileData = Tile; - InputTileData x_tile; - TASSIGN(x_tile, UB_X_ADDR); using MaskTileData = Tile; MaskTileData current_mask; @@ -118,16 +118,27 @@ AICORE void runTLocalHistogram(__gm__ T* x, __gm__ float* z_local, const float bin_width = (max_val - min_val) / static_cast(num_bins); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID1); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID1); + // --- Main Calculation Loop --- - for (uint32_t tile_idx = start_idx; tile_idx < end_idx; ++tile_idx) { + for (uint32_t tile_idx = start_idx, ping = 1; tile_idx < end_idx; + ++tile_idx) { int offset = tile_idx * TILE_SIZE; TASSIGN(x_gm, x + offset); - set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); - wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + const event_t ev = ping ? (event_t)EVENT_ID0 : (event_t)EVENT_ID1; + const unsigned x_base = ping ? UB_X_PING : UB_X_PONG; + + InputTileData x_tile; + TASSIGN(x_tile, x_base); + + wait_flag(PIPE_V, PIPE_MTE2, ev); TLOAD(x_tile, x_gm); - set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); - wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + set_flag(PIPE_MTE2, PIPE_V, ev); + wait_flag(PIPE_MTE2, PIPE_V, ev); // Generate packed bit-mask TCMPS(current_mask, x_tile, static_cast(min_val), CmpMode::LT); @@ -149,19 +160,27 @@ AICORE void runTLocalHistogram(__gm__ T* x, __gm__ float* z_local, TROWSUM(count_f32_tile, bin_mask_f32, reduce_tmp); // Scalar move to update UB local histogram - set_flag(PIPE_V, PIPE_S, EVENT_ID0); - wait_flag(PIPE_V, PIPE_S, EVENT_ID0); + set_flag(PIPE_V, PIPE_S, EVENT_ID2); + wait_flag(PIPE_V, PIPE_S, EVENT_ID2); float f_count = count_f32_tile.GetValue(0); if (f_count > 0.0f) { local_hist.SetValue(j, local_hist.GetValue(j) + f_count); } - set_flag(PIPE_S, PIPE_V, EVENT_ID0); - wait_flag(PIPE_S, PIPE_V, EVENT_ID0); + set_flag(PIPE_S, PIPE_V, EVENT_ID2); + wait_flag(PIPE_S, PIPE_V, EVENT_ID2); TMOV(prev_f32, cur_f32); } + + set_flag(PIPE_V, PIPE_MTE2, ev); + ping = 1 - ping; } + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID1); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID1); + // --- Final Store to Global Memory --- HistGlobalData z_gm(z_local + block_idx * num_bins, {static_cast(block_num) * num_bins}); @@ -175,7 +194,7 @@ AICORE void runTLocalHistogram(__gm__ T* x, __gm__ float* z_local, // Template parameter to avoid "no function" kernel launch error template -AICORE void runTHistogramFinal(__gm__ float* z_local, __gm__ int32_t* z, +AICORE void runTHistogramFinal(__gm__ float *z_local, __gm__ int32_t *z, const int32_t num_bins, const int32_t num_blocks) { set_mask_norm(); @@ -259,9 +278,9 @@ extern "C" __global__ AICORE void histogram_fp16(GM_ADDR x, GM_ADDR z_local, const int32_t num_bins, const float min_val, const float max_val) { - runTLocalHistogram((__gm__ half*)x, - (__gm__ float*)z_local, in_length, - num_bins, min_val, max_val); + runTLocalHistogram( + (__gm__ half *)x, (__gm__ float *)z_local, in_length, num_bins, min_val, + max_val); } extern "C" __global__ AICORE void histogram_fp32(GM_ADDR x, GM_ADDR z_local, @@ -270,14 +289,14 @@ extern "C" __global__ AICORE void histogram_fp32(GM_ADDR x, GM_ADDR z_local, const float min_val, const float max_val) { runTLocalHistogram( - (__gm__ float*)x, (__gm__ float*)z_local, in_length, num_bins, min_val, + (__gm__ float *)x, (__gm__ float *)z_local, in_length, num_bins, min_val, max_val); } extern "C" __global__ AICORE void histogram_final(GM_ADDR z_local, GM_ADDR z, const int32_t num_bins, const int32_t num_blocks) { - runTHistogramFinal<0>((__gm__ float*)z_local, (__gm__ int32_t*)z, num_bins, + runTHistogramFinal<0>((__gm__ float *)z_local, (__gm__ int32_t *)z, num_bins, num_blocks); } diff --git a/examples/jit_cpp/histogram/step1_naive_histogram/kernel_histogram.cpp b/examples/jit_cpp/histogram/step1_naive_histogram/kernel_histogram.cpp index 71ca63a5..15df61fc 100644 --- a/examples/jit_cpp/histogram/step1_naive_histogram/kernel_histogram.cpp +++ b/examples/jit_cpp/histogram/step1_naive_histogram/kernel_histogram.cpp @@ -113,6 +113,9 @@ AICORE void runTLocalHistogram(__gm__ T *x, __gm__ float *z_local, const float bin_width = (max_val - min_val) / static_cast(num_bins); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + // --- Main Calculation Loop --- for (uint32_t tile_idx = start_idx; tile_idx < end_idx; ++tile_idx) { int offset = tile_idx * TILE_SIZE; @@ -157,6 +160,9 @@ AICORE void runTLocalHistogram(__gm__ T *x, __gm__ float *z_local, } } + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + // --- Final Store to Global Memory --- HistGlobalData z_gm(z_local + block_idx * num_bins, {static_cast(block_num) * num_bins}); diff --git a/examples/jit_cpp/histogram/step1_naive_histogram/run_histogram.py b/examples/jit_cpp/histogram/step1_naive_histogram/run_histogram.py index 37fa9908..91141c45 100644 --- a/examples/jit_cpp/histogram/step1_naive_histogram/run_histogram.py +++ b/examples/jit_cpp/histogram/step1_naive_histogram/run_histogram.py @@ -42,11 +42,11 @@ def test_histogram(size_mult=1, repeat_runs=20): for i, hist in enumerate(actual_hist): assert torch.equal( hist, actual_hist[0] - ), f"Inconsistent results across runs at run {i}" + ), f"Inconsistent results across runs at run {i}, expected\n {actual_hist[0]}\ngot\n{hist}\n" assert torch.equal( expected_hist, actual_hist[0] - ), "Mismatch between expected and actual histogram" + ), f"Mismatch between expected and actual histogram, expected\n {expected_hist}\ngot\n{actual_hist[0]}\n" if __name__ == "__main__": diff --git a/examples/jit_cpp/histogram/step2_double_buffering/kernel_histogram.cpp b/examples/jit_cpp/histogram/step2_double_buffering/kernel_histogram.cpp index f74a1507..3aff7361 100644 --- a/examples/jit_cpp/histogram/step2_double_buffering/kernel_histogram.cpp +++ b/examples/jit_cpp/histogram/step2_double_buffering/kernel_histogram.cpp @@ -115,6 +115,8 @@ AICORE void runTLocalHistogram(__gm__ T *x, __gm__ float *z_local, set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); set_flag(PIPE_V, PIPE_MTE2, EVENT_ID1); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID1); // --- Main Calculation Loop --- for (uint32_t tile_idx = start_idx, ping = 1; tile_idx < end_idx; @@ -171,6 +173,8 @@ AICORE void runTLocalHistogram(__gm__ T *x, __gm__ float *z_local, wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID1); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID1); // --- Final Store to Global Memory --- HistGlobalData z_gm(z_local + block_idx * num_bins, diff --git a/examples/jit_cpp/histogram/step2_double_buffering/run_histogram.py b/examples/jit_cpp/histogram/step2_double_buffering/run_histogram.py index 3aa94189..91141c45 100644 --- a/examples/jit_cpp/histogram/step2_double_buffering/run_histogram.py +++ b/examples/jit_cpp/histogram/step2_double_buffering/run_histogram.py @@ -12,7 +12,7 @@ def test_histogram(size_mult=1, repeat_runs=20): # Tile size is fixed in the kernel tile_size = 512 - num_cores = 1 # torch.npu.get_device_properties().vector_core_num + num_cores = torch.npu.get_device_properties().vector_core_num num_tiles = num_cores * tile_size total_len = num_tiles * size_mult @@ -42,11 +42,11 @@ def test_histogram(size_mult=1, repeat_runs=20): for i, hist in enumerate(actual_hist): assert torch.equal( hist, actual_hist[0] - ), f"Inconsistent results across runs at run {i}" + ), f"Inconsistent results across runs at run {i}, expected\n {actual_hist[0]}\ngot\n{hist}\n" assert torch.equal( expected_hist, actual_hist[0] - ), "Mismatch between expected and actual histogram" + ), f"Mismatch between expected and actual histogram, expected\n {expected_hist}\ngot\n{actual_hist[0]}\n" if __name__ == "__main__": From 908e8116162d1ab8af8ab8575484f716a3da4426 Mon Sep 17 00:00:00 2001 From: Vladimir Loncar Date: Wed, 8 Apr 2026 09:32:53 +0000 Subject: [PATCH 10/17] Make ruff happy --- .../histogram/step0_count_less_than/run_count_less_than.py | 1 - .../jit_cpp/histogram/step1_naive_histogram/run_histogram.py | 1 - .../jit_cpp/histogram/step2_double_buffering/run_histogram.py | 1 - 3 files changed, 3 deletions(-) diff --git a/examples/jit_cpp/histogram/step0_count_less_than/run_count_less_than.py b/examples/jit_cpp/histogram/step0_count_less_than/run_count_less_than.py index 7a27d5f5..2e6b6618 100644 --- a/examples/jit_cpp/histogram/step0_count_less_than/run_count_less_than.py +++ b/examples/jit_cpp/histogram/step0_count_less_than/run_count_less_than.py @@ -1,4 +1,3 @@ -import random import torch import torch_npu # noqa diff --git a/examples/jit_cpp/histogram/step1_naive_histogram/run_histogram.py b/examples/jit_cpp/histogram/step1_naive_histogram/run_histogram.py index 91141c45..7a16a81f 100644 --- a/examples/jit_cpp/histogram/step1_naive_histogram/run_histogram.py +++ b/examples/jit_cpp/histogram/step1_naive_histogram/run_histogram.py @@ -1,4 +1,3 @@ -import random import torch import torch_npu # noqa diff --git a/examples/jit_cpp/histogram/step2_double_buffering/run_histogram.py b/examples/jit_cpp/histogram/step2_double_buffering/run_histogram.py index 91141c45..7a16a81f 100644 --- a/examples/jit_cpp/histogram/step2_double_buffering/run_histogram.py +++ b/examples/jit_cpp/histogram/step2_double_buffering/run_histogram.py @@ -1,4 +1,3 @@ -import random import torch import torch_npu # noqa From c510f372178a0182ed5ff89c0cd05014b0d074b5 Mon Sep 17 00:00:00 2001 From: Vladimir Loncar Date: Wed, 8 Apr 2026 11:44:05 +0000 Subject: [PATCH 11/17] Refactor example run scripts --- .../jit_util_histogram.py | 0 .../run_histogram.py | 10 +- .../jit_util_histogram.py | 99 ------------------- .../step2_double_buffering/run_histogram.py | 52 ---------- 4 files changed, 7 insertions(+), 154 deletions(-) rename examples/jit_cpp/histogram/{step1_naive_histogram => }/jit_util_histogram.py (100%) rename examples/jit_cpp/histogram/{step1_naive_histogram => }/run_histogram.py (84%) delete mode 100644 examples/jit_cpp/histogram/step2_double_buffering/jit_util_histogram.py delete mode 100644 examples/jit_cpp/histogram/step2_double_buffering/run_histogram.py diff --git a/examples/jit_cpp/histogram/step1_naive_histogram/jit_util_histogram.py b/examples/jit_cpp/histogram/jit_util_histogram.py similarity index 100% rename from examples/jit_cpp/histogram/step1_naive_histogram/jit_util_histogram.py rename to examples/jit_cpp/histogram/jit_util_histogram.py diff --git a/examples/jit_cpp/histogram/step1_naive_histogram/run_histogram.py b/examples/jit_cpp/histogram/run_histogram.py similarity index 84% rename from examples/jit_cpp/histogram/step1_naive_histogram/run_histogram.py rename to examples/jit_cpp/histogram/run_histogram.py index 7a16a81f..76da56d5 100644 --- a/examples/jit_cpp/histogram/step1_naive_histogram/run_histogram.py +++ b/examples/jit_cpp/histogram/run_histogram.py @@ -3,8 +3,12 @@ from jit_util_histogram import jit_compile +algo_steps = { + 1: "step1_naive_histogram", + 2: "step2_double_buffering", +} -def test_histogram(size_mult=1, repeat_runs=20): +def test_histogram(algo_step=2, size_mult=1, repeat_runs=20): device = "npu:1" dtype = torch.float32 torch.npu.set_device(device) @@ -26,7 +30,7 @@ def test_histogram(size_mult=1, repeat_runs=20): # Golden PyTorch implementation expected_hist = torch.histc(x.cpu(), bins, min=min_val, max=max_val).to(torch.int32) - hist_func = jit_compile("kernel_histogram.cpp") + hist_func = jit_compile(f"{algo_steps[algo_step]}/kernel_histogram.cpp") # NPU kernel execution, test to see if any race conditions occur across multiple runs actual_hist = [] @@ -49,4 +53,4 @@ def test_histogram(size_mult=1, repeat_runs=20): if __name__ == "__main__": - test_histogram(size_mult=64, repeat_runs=20) + test_histogram(algo_step=2, size_mult=64, repeat_runs=20) diff --git a/examples/jit_cpp/histogram/step2_double_buffering/jit_util_histogram.py b/examples/jit_cpp/histogram/step2_double_buffering/jit_util_histogram.py deleted file mode 100644 index 6c35e5d0..00000000 --- a/examples/jit_cpp/histogram/step2_double_buffering/jit_util_histogram.py +++ /dev/null @@ -1,99 +0,0 @@ -import os -import subprocess -import ctypes - -import torch - -ASCEND_TOOLKIT_HOME = os.environ["ASCEND_TOOLKIT_HOME"] -PTO_LIB_PATH = os.environ["PTO_LIB_PATH"] - - -def compile_cpp(kernel_cpp: str, verbose: bool = False, timeout: int = 120) -> str: - lib_path = os.path.join(os.path.dirname(kernel_cpp), "histogram_jit.so") - - flags = [ - "-fPIC", - "-shared", - "-xcce", - "--npu-arch=dav-2201", - "-DMEMORY_BASE", # here hardcoded for A2A3; TODO: expose this option to jit interface - "-O2", - "-std=c++17", - f"-I{PTO_LIB_PATH}/include", - ] - - command = ["bisheng", *flags, kernel_cpp, "-o", lib_path] - if verbose: - print(f"compile {kernel_cpp} with command: \n", command) - - try: - subprocess.run( - command, - timeout=timeout, - check=True, - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - ) - except Exception as e: - output = e.stdout.decode("utf-8", errors="replace") if e.stdout else "" - raise RuntimeError( - f"Compile failed with exit code {e.returncode}:\n{output}" - ) from e - - if verbose: - print(f"generated {lib_path}") - return lib_path - - -def torch_to_ctypes(tensor): - return ctypes.c_void_p(tensor.data_ptr()) - - -def load_lib(lib_path, check_type=True): - lib_path = os.path.abspath(lib_path) - lib = ctypes.CDLL(lib_path) - - default_block_dim = torch.npu.get_device_properties().vector_core_num - - if check_type: - lib.histogram_fp32.argtypes = [ - ctypes.c_uint32, # blockDim - ctypes.c_void_p, # stream - ctypes.c_void_p, # x - ctypes.c_void_p, # z_local - ctypes.c_void_p, # z - ctypes.c_uint, # in_length - ctypes.c_int, # bins - ctypes.c_float, # min_val - ctypes.c_float, # max_val - ] - lib.histogram_fp32.restype = None - - def hist_func( - x, z, bins, min_val, max_val, block_dim=default_block_dim, stream_ptr=None - ): - if stream_ptr is None: - stream_ptr = torch.npu.current_stream()._as_parameter_ # noqa - N = x.numel() - z_local = torch.zeros((block_dim, bins), device=x.device, dtype=torch.float32) - lib.histogram_fp32( - block_dim, - stream_ptr, - torch_to_ctypes(x), - torch_to_ctypes(z_local), - torch_to_ctypes(z), - N, - bins, - ctypes.c_float(min_val), - ctypes.c_float(max_val), - ) - - return hist_func - - -def jit_compile(src_path, clean_up=True): - lib_path = compile_cpp(src_path, verbose=True) - func = load_lib(lib_path, check_type=False) - if clean_up: - os.remove(lib_path) - return func diff --git a/examples/jit_cpp/histogram/step2_double_buffering/run_histogram.py b/examples/jit_cpp/histogram/step2_double_buffering/run_histogram.py deleted file mode 100644 index 7a16a81f..00000000 --- a/examples/jit_cpp/histogram/step2_double_buffering/run_histogram.py +++ /dev/null @@ -1,52 +0,0 @@ -import torch -import torch_npu # noqa - -from jit_util_histogram import jit_compile - - -def test_histogram(size_mult=1, repeat_runs=20): - device = "npu:1" - dtype = torch.float32 - torch.npu.set_device(device) - - # Tile size is fixed in the kernel - tile_size = 512 - num_cores = torch.npu.get_device_properties().vector_core_num - num_tiles = num_cores * tile_size - total_len = num_tiles * size_mult - - bins = 256 - min_val = 0.0 - max_val = 256.0 - - # Create an input tensor bounded around our standard test range - x = torch.rand(size=(total_len,), device="npu", dtype=dtype).contiguous() * max_val - z = torch.zeros(bins, device="npu", dtype=torch.int32) - - # Golden PyTorch implementation - expected_hist = torch.histc(x.cpu(), bins, min=min_val, max=max_val).to(torch.int32) - - hist_func = jit_compile("kernel_histogram.cpp") - - # NPU kernel execution, test to see if any race conditions occur across multiple runs - actual_hist = [] - for _ in range(repeat_runs): - z.zero_() - hist_func(x, z, bins, min_val, max_val, block_dim=num_cores) - actual_hist.append(z.cpu().clone()) - - torch.npu.synchronize() - - # Check for consistency across runs and correctness against the expected count - for i, hist in enumerate(actual_hist): - assert torch.equal( - hist, actual_hist[0] - ), f"Inconsistent results across runs at run {i}, expected\n {actual_hist[0]}\ngot\n{hist}\n" - - assert torch.equal( - expected_hist, actual_hist[0] - ), f"Mismatch between expected and actual histogram, expected\n {expected_hist}\ngot\n{actual_hist[0]}\n" - - -if __name__ == "__main__": - test_histogram(size_mult=64, repeat_runs=20) From d0763ef05af1bb62d107564c6830f127e982e0f2 Mon Sep 17 00:00:00 2001 From: Vladimir Loncar Date: Wed, 8 Apr 2026 15:52:55 +0000 Subject: [PATCH 12/17] Add initial benchmark and plotting suite --- examples/jit_cpp/histogram/.gitignore | 1 + examples/jit_cpp/histogram/bench_kernels.py | 221 +++++++++++++++++ .../jit_cpp/histogram/jit_util_histogram.py | 14 +- examples/jit_cpp/histogram/plot_kernels.py | 227 ++++++++++++++++++ examples/jit_cpp/histogram/run_histogram.py | 13 +- .../kernel_histogram.cpp | 10 +- .../kernel_histogram.cpp | 10 +- 7 files changed, 478 insertions(+), 18 deletions(-) create mode 100644 examples/jit_cpp/histogram/.gitignore create mode 100644 examples/jit_cpp/histogram/bench_kernels.py create mode 100644 examples/jit_cpp/histogram/plot_kernels.py diff --git a/examples/jit_cpp/histogram/.gitignore b/examples/jit_cpp/histogram/.gitignore new file mode 100644 index 00000000..17aa483a --- /dev/null +++ b/examples/jit_cpp/histogram/.gitignore @@ -0,0 +1 @@ +outputs/ diff --git a/examples/jit_cpp/histogram/bench_kernels.py b/examples/jit_cpp/histogram/bench_kernels.py new file mode 100644 index 00000000..597d00b9 --- /dev/null +++ b/examples/jit_cpp/histogram/bench_kernels.py @@ -0,0 +1,221 @@ +import argparse +import os +import sys +from pathlib import Path + +import pandas as pd +import torch + +from jit_util_histogram import jit_compile + +DEVICE = os.environ.get("NPU_DEVICE", "npu:1") +DTYPE = torch.float32 + +N_REPEAT = 20 +N_WARMUP = 5 +N_ALLOC = N_REPEAT + N_WARMUP + +TILE_SIZES = [512, 1024, 2048, 4096, 8192] +BINS_LIST = [8, 32, 64, 128, 192, 256] +MIN_VAL = 0.0 +MAX_VAL = 255.0 + +DEFAULT_CSV_REL_PATH = Path("outputs") / "csv" / "histogram_timing.csv" + + +def _parse_args(): + parser = argparse.ArgumentParser( + description=( + "Benchmark torch and implementation histogram kernels and save the results to " + "a CSV file." + ) + ) + parser.add_argument( + "--csv", + type=str, + default=str(DEFAULT_CSV_REL_PATH), + help=f"Output CSV path (default: {DEFAULT_CSV_REL_PATH})", + ) + parser.add_argument( + "--implementation", + type=int, + nargs="*", + choices=[1, 2], + help="Select the implementation steps (1: naive, 2: double buffering). If not provided, benchmarks all.", + ) + parser.add_argument( + "--with-torch", + action="store_true", + help="Include torch baseline timing/throughput benchmarking.", + ) + return parser.parse_args() + + +IMPLEMENTATIONS = { + 1: "step1_naive_histogram", + 2: "step2_double_buffering", +} + + +def _bench_backend( + name, func, a_list, z_list, c_ref, processed_elements, total_bytes, bins +): + c = None + for a, z in zip(a_list[:N_WARMUP], z_list[:N_WARMUP]): + res = func(a, z, bins, MIN_VAL, MAX_VAL) + c = res if res is not None else z + + mean_diff = float(torch.mean(torch.abs(c.float() - c_ref.float())).cpu()) + abs_error = float(torch.max(torch.abs(c.float() - c_ref.float())).cpu()) + + start = torch.npu.Event(enable_timing=True) + end = torch.npu.Event(enable_timing=True) + start.record() + for a, z in zip( + a_list[N_WARMUP : N_WARMUP + N_REPEAT], z_list[N_WARMUP : N_WARMUP + N_REPEAT] + ): + func(a, z, bins, MIN_VAL, MAX_VAL) + end.record() + torch.npu.synchronize() + dur_us = start.elapsed_time(end) / N_REPEAT * 1e3 + + gmelem_s = processed_elements / dur_us / 1e3 + bw_gbs = total_bytes * 1e6 / dur_us / (1024**3) + + print( + f"{name} duration: {dur_us:.3f} us, GElem/s: {gmelem_s:.3f}, mean diff: {mean_diff}" + ) + + return { + f"{name}_time_us": dur_us, + f"{name}_gmelem_s": gmelem_s, + f"{name}_bandwidth_gbs": bw_gbs, + f"{name}_mean_diff": mean_diff, + f"{name}_abs_error": abs_error, + f"{name}_error": "", + } + + +def bench_n_elems(funcs_to_bench, num_elements, bins, tile_size): + print(f"\n=== N = {num_elements:_} | bins = {bins} | tile_size = {tile_size} ===") + + a_list = [ + torch.rand(num_elements, dtype=DTYPE, device=DEVICE) * (MAX_VAL - MIN_VAL) + + MIN_VAL + for _ in range(N_ALLOC) + ] + z_list = [ + torch.zeros(bins, device=DEVICE, dtype=torch.int32) for _ in range(N_ALLOC) + ] + + ref_a = a_list[N_WARMUP - 1] + c_ref = torch.histc(ref_a, bins=bins, min=MIN_VAL, max=MAX_VAL).to(torch.int32) + + processed_elements = num_elements + total_bytes = num_elements * int(ref_a.element_size()) + bins * int( + c_ref.element_size() + ) + + record = { + "N": num_elements, + "bins": bins, + "tile_size": tile_size, + } + + for name, func in funcs_to_bench.items(): + if func is None: + record.update( + { + f"{name}_time_us": float("nan"), + f"{name}_gmelem_s": float("nan"), + f"{name}_bandwidth_gbs": float("nan"), + f"{name}_mean_diff": float("nan"), + f"{name}_abs_error": float("nan"), + f"{name}_error": "backend not compiled/enabled", + } + ) + continue + + try: + stats = _bench_backend( + name, func, a_list, z_list, c_ref, processed_elements, total_bytes, bins + ) + record.update(stats) + except Exception as exc: + print(f"{name} unavailable: {exc}") + record.update( + { + f"{name}_time_us": float("nan"), + f"{name}_gmelem_s": float("nan"), + f"{name}_bandwidth_gbs": float("nan"), + f"{name}_mean_diff": float("nan"), + f"{name}_abs_error": float("nan"), + f"{name}_error": str(exc), + } + ) + + return [record] + + +def main(): + args = _parse_args() + include_torch = args.with_torch + impls_to_run = args.implementation if args.implementation else [1, 2] + + torch.npu.set_device(DEVICE) + base = Path(__file__).resolve().parent + + csv_path = Path(args.csv) + if not csv_path.is_absolute(): + csv_path = base / csv_path + csv_path.parent.mkdir(parents=True, exist_ok=True) + + print(f"Implementations selected: {[IMPLEMENTATIONS[i] for i in impls_to_run]}") + if include_torch: + print("Torch baseline: enabled") + + vector_cores = torch.npu.get_device_properties().vector_core_num + base_elements = max(TILE_SIZES) * vector_cores + # multiplier = max(1, (1024 * 1024) // base_elements) + multiplier = 1 + n_elements_list = [base_elements * multiplier * i for i in range(1, 33)] + + records = [] + for tile_size in TILE_SIZES: + funcs_to_bench = {} + + if include_torch: + + def torch_hist_bench(x, z, bins, min_val, max_val): + return torch.histc(x, bins=bins, min=min_val, max=max_val).to( + torch.int32 + ) + + funcs_to_bench["torch"] = torch_hist_bench + + for i in impls_to_run: + impl_dir = IMPLEMENTATIONS[i] + custom_path = base / impl_dir / "kernel_histogram.cpp" + name = f"step{i}" + try: + print( + f"Compiling {impl_dir}/kernel_histogram.cpp for backend '{name}' (tile_size={tile_size}) ..." + ) + funcs_to_bench[name] = jit_compile( + str(custom_path), tile_size=tile_size + ) + except Exception as exc: + print(f"[WARN] backend '{name}' unavailable: {exc}") + funcs_to_bench[name] = None + + for bins in BINS_LIST: + for n in n_elements_list: + records.extend(bench_n_elems(funcs_to_bench, n, bins, tile_size)) + + df = pd.DataFrame.from_records(records) + df.to_csv(csv_path, index=False) + print(f"\nSaved benchmark CSV: {csv_path}") + + +if __name__ == "__main__": + main() diff --git a/examples/jit_cpp/histogram/jit_util_histogram.py b/examples/jit_cpp/histogram/jit_util_histogram.py index 6c35e5d0..7721c655 100644 --- a/examples/jit_cpp/histogram/jit_util_histogram.py +++ b/examples/jit_cpp/histogram/jit_util_histogram.py @@ -5,11 +5,14 @@ import torch ASCEND_TOOLKIT_HOME = os.environ["ASCEND_TOOLKIT_HOME"] -PTO_LIB_PATH = os.environ["PTO_LIB_PATH"] +PTO_LIB_PATH = os.environ.get("PTO_LIB_PATH", ASCEND_TOOLKIT_HOME) -def compile_cpp(kernel_cpp: str, verbose: bool = False, timeout: int = 120) -> str: - lib_path = os.path.join(os.path.dirname(kernel_cpp), "histogram_jit.so") +def compile_cpp( + kernel_cpp: str, tile_size=512, verbose: bool = False, timeout: int = 120 +) -> str: + dirname = os.path.dirname(kernel_cpp) + lib_path = os.path.join(dirname, f"{dirname}_jit.so") flags = [ "-fPIC", @@ -17,6 +20,7 @@ def compile_cpp(kernel_cpp: str, verbose: bool = False, timeout: int = 120) -> s "-xcce", "--npu-arch=dav-2201", "-DMEMORY_BASE", # here hardcoded for A2A3; TODO: expose this option to jit interface + "-DHIST_TILE_SIZE=" + str(tile_size), "-O2", "-std=c++17", f"-I{PTO_LIB_PATH}/include", @@ -91,8 +95,8 @@ def hist_func( return hist_func -def jit_compile(src_path, clean_up=True): - lib_path = compile_cpp(src_path, verbose=True) +def jit_compile(src_path, tile_size=512, clean_up=True): + lib_path = compile_cpp(src_path, tile_size=tile_size, verbose=True) func = load_lib(lib_path, check_type=False) if clean_up: os.remove(lib_path) diff --git a/examples/jit_cpp/histogram/plot_kernels.py b/examples/jit_cpp/histogram/plot_kernels.py new file mode 100644 index 00000000..2f83723b --- /dev/null +++ b/examples/jit_cpp/histogram/plot_kernels.py @@ -0,0 +1,227 @@ +import argparse +from pathlib import Path + +import matplotlib.pyplot as plt +import pandas as pd + +DEFAULT_CSV_REL_PATH = Path("outputs") / "csv" / "histogram_timing.csv" +DEFAULT_PLOT_REL_DIR = Path("outputs") / "plots" + +BACKEND_STYLE = { + "torch": {"color": "#111111", "marker": "x", "linestyle": "--"}, + "step1": {"color": "#1f77b4", "marker": "o", "linestyle": "-"}, + "step2": {"color": "#ff7f0e", "marker": "s", "linestyle": "-"}, +} +CUSTOM_MARKERS = ["o", "s", "^", "v", "D", "P", "X", "*", "<", ">"] + + +def _parse_args(): + parser = argparse.ArgumentParser( + description="Plot benchmark figures from a benchmark CSV file." + ) + parser.add_argument( + "--csv", + type=str, + default=str(DEFAULT_CSV_REL_PATH), + help=f"Input benchmark CSV path (default: {DEFAULT_CSV_REL_PATH})", + ) + parser.add_argument( + "--plot-dir", + type=str, + default=str(DEFAULT_PLOT_REL_DIR), + help=f"Output plot directory (default: {DEFAULT_PLOT_REL_DIR})", + ) + parser.add_argument( + "--bins", + type=int, + default=256, + help="Number of bins to plot (default: 256)", + ) + parser.add_argument( + "--tile-size", + type=int, + action="append", + default=[], + help="Tile sizes to plot. Can be repeated. If not provided, defaults to 4096.", + ) + return parser.parse_args() + + +def _style(name: str) -> dict: + return BACKEND_STYLE.get( + name, {"color": "#2ca02c", "marker": "^", "linestyle": "-"} + ) + + +def _finalize_plot(title: str, xlabel: str, ylabel: str): + plt.xlabel(xlabel) + plt.ylabel(ylabel) + plt.title(title) + plt.xlim(left=0) + plt.ylim(bottom=0) + plt.grid(True, alpha=0.25) + handles, _ = plt.gca().get_legend_handles_labels() + if handles: + plt.legend(fontsize=8) + plt.tight_layout() + + +def _plot_backend(df: pd.DataFrame, backend: str, metric_col: str, style: dict): + if metric_col not in df.columns: + return + + if backend == "torch" or "tile_size" not in df.columns: + g = df[["N", metric_col]].dropna() + if g.empty: + return + g = g.groupby("N", as_index=False)[metric_col].mean().sort_values("N") + plt.plot( + g["N"], + g[metric_col], + marker=style["marker"], + linestyle=style["linestyle"], + color=style["color"], + label=backend, + ) + else: + ts_values = sorted(df["tile_size"].dropna().unique()) + grouped = df.dropna(subset=[metric_col]).groupby("tile_size", sort=True) + for idx, (ts, group) in enumerate(grouped): + g = group[["N", metric_col]].dropna() + if g.empty: + continue + g = g.groupby("N", as_index=False)[metric_col].mean().sort_values("N") + + if len(ts_values) > 1: + label = f"{backend} (ts={ts})" + marker = CUSTOM_MARKERS[idx % len(CUSTOM_MARKERS)] + alpha = max(0.4, 1.0 - (idx * 0.15)) + else: + label = backend + marker = style["marker"] + alpha = 1.0 + + plt.plot( + g["N"], + g[metric_col], + marker=marker, + linestyle=style["linestyle"], + color=style["color"], + alpha=alpha, + label=label, + ) + + +def plot_runtime(df: pd.DataFrame, out_dir: Path, bins: int) -> Path: + plt.figure(figsize=(10, 5)) + for backend in ["torch", "step1", "step2"]: + _plot_backend(df, backend, f"{backend}_time_us", _style(backend)) + + _finalize_plot( + title=f"Runtime vs N (bins={bins})", + xlabel="Number of Elements (N)", + ylabel="Runtime (us)", + ) + + out_path = out_dir / "duration.png" + plt.savefig(out_path, dpi=160) + plt.close() + return out_path + + +def plot_throughput(df: pd.DataFrame, out_dir: Path, bins: int) -> Path: + plt.figure(figsize=(10, 5)) + for backend in ["torch", "step1", "step2"]: + _plot_backend(df, backend, f"{backend}_gmelem_s", _style(backend)) + + _finalize_plot( + title=f"Throughput vs N (bins={bins})", + xlabel="Number of Elements (N)", + ylabel="GElem/s", + ) + + out_path = out_dir / "throughput.png" + plt.savefig(out_path, dpi=160) + plt.close() + return out_path + + +def plot_error(df: pd.DataFrame, out_dir: Path, bins: int) -> Path: + plt.figure(figsize=(10, 5)) + for backend in ["step1", "step2"]: + _plot_backend(df, backend, f"{backend}_mean_diff", _style(backend)) + + _finalize_plot( + title=f"Error vs N (bins={bins})", + xlabel="Number of Elements (N)", + ylabel="Mean Abs Error", + ) + + out_path = out_dir / "error.png" + plt.savefig(out_path, dpi=160) + plt.close() + return out_path + + +def main(): + args = _parse_args() + base = Path(__file__).resolve().parent + + csv_path = Path(args.csv) + if not csv_path.is_absolute(): + csv_path = base / csv_path + if not csv_path.exists(): + raise FileNotFoundError(f"Benchmark CSV not found: {csv_path}") + + plot_dir = Path(args.plot_dir) + if not plot_dir.is_absolute(): + plot_dir = base / plot_dir + plot_dir.mkdir(parents=True, exist_ok=True) + + df = pd.read_csv(csv_path) + required_columns = {"N"} + missing = required_columns - set(df.columns) + if missing: + raise ValueError( + f"CSV is missing required columns: {sorted(missing)} (file: {csv_path})" + ) + + plot_df = df + + if "bins" in plot_df.columns: + plot_df = plot_df[plot_df["bins"] == args.bins] + if plot_df.empty: + available_bins = sorted(df["bins"].dropna().unique()) + raise RuntimeError( + f"No rows found for bins={args.bins} in {csv_path}. " + f"Available bins: {available_bins}" + ) + + tile_sizes = args.tile_size if args.tile_size else [4096] + if "tile_size" in plot_df.columns: + plot_df = plot_df[plot_df["tile_size"].isin(tile_sizes)] + if plot_df.empty: + available_ts = sorted(df["tile_size"].dropna().unique()) + raise RuntimeError( + f"No rows found for tile_sizes={tile_sizes} in {csv_path} (with bins={args.bins}). " + f"Available tile sizes: {available_ts}" + ) + + if plot_df.empty: + raise RuntimeError(f"No data found in {csv_path}.") + + runtime_path = plot_runtime(plot_df, plot_dir, args.bins) + throughput_path = plot_throughput(plot_df, plot_dir, args.bins) + error_path = plot_error(plot_df, plot_dir, args.bins) + + print(f"Loaded CSV: {csv_path}") + print( + f"Filters applied: bins={args.bins}, tile_sizes={tile_sizes if 'tile_size' in df.columns else 'N/A'}" + ) + print(f"Saved runtime plot: {runtime_path}") + print(f"Saved throughput plot: {throughput_path}") + print(f"Saved error plot: {error_path}") + + +if __name__ == "__main__": + main() diff --git a/examples/jit_cpp/histogram/run_histogram.py b/examples/jit_cpp/histogram/run_histogram.py index 76da56d5..fc4ef1f9 100644 --- a/examples/jit_cpp/histogram/run_histogram.py +++ b/examples/jit_cpp/histogram/run_histogram.py @@ -3,18 +3,19 @@ from jit_util_histogram import jit_compile -algo_steps = { +IMPLEMENTATIONS = { 1: "step1_naive_histogram", 2: "step2_double_buffering", } -def test_histogram(algo_step=2, size_mult=1, repeat_runs=20): + +def test_histogram(impl=2, size_mult=1, repeat_runs=20): device = "npu:1" dtype = torch.float32 torch.npu.set_device(device) # Tile size is fixed in the kernel - tile_size = 512 + tile_size = 4096 num_cores = torch.npu.get_device_properties().vector_core_num num_tiles = num_cores * tile_size total_len = num_tiles * size_mult @@ -30,7 +31,9 @@ def test_histogram(algo_step=2, size_mult=1, repeat_runs=20): # Golden PyTorch implementation expected_hist = torch.histc(x.cpu(), bins, min=min_val, max=max_val).to(torch.int32) - hist_func = jit_compile(f"{algo_steps[algo_step]}/kernel_histogram.cpp") + hist_func = jit_compile( + f"{IMPLEMENTATIONS[impl]}/kernel_histogram.cpp", tile_size=tile_size + ) # NPU kernel execution, test to see if any race conditions occur across multiple runs actual_hist = [] @@ -53,4 +56,4 @@ def test_histogram(algo_step=2, size_mult=1, repeat_runs=20): if __name__ == "__main__": - test_histogram(algo_step=2, size_mult=64, repeat_runs=20) + test_histogram(impl=2, size_mult=64, repeat_runs=20) diff --git a/examples/jit_cpp/histogram/step1_naive_histogram/kernel_histogram.cpp b/examples/jit_cpp/histogram/step1_naive_histogram/kernel_histogram.cpp index 15df61fc..6e1b6ea9 100644 --- a/examples/jit_cpp/histogram/step1_naive_histogram/kernel_histogram.cpp +++ b/examples/jit_cpp/histogram/step1_naive_histogram/kernel_histogram.cpp @@ -13,7 +13,9 @@ for the full License text. using namespace pto; -constexpr uint32_t DEFAULT_TILE_SIZE = 512; +#ifndef HIST_TILE_SIZE +#define HIST_TILE_SIZE 1024 +#endif constexpr uint32_t MAX_BINS = 256; constexpr uint32_t MAX_BLOCKS = 64; @@ -263,9 +265,9 @@ AICORE void runTHistogramFinal(__gm__ float *z_local, __gm__ int32_t *z, __global__ AICORE void histogram_local_fp32( __gm__ void *x, __gm__ void *z_local, const uint32_t in_length, const int32_t num_bins, const float min_val, const float max_val) { - runTLocalHistogram( - (__gm__ float *)x, (__gm__ float *)z_local, in_length, num_bins, min_val, - max_val); + runTLocalHistogram((__gm__ float *)x, + (__gm__ float *)z_local, in_length, + num_bins, min_val, max_val); } __global__ AICORE void histogram_final(__gm__ void *z_local, __gm__ void *z, diff --git a/examples/jit_cpp/histogram/step2_double_buffering/kernel_histogram.cpp b/examples/jit_cpp/histogram/step2_double_buffering/kernel_histogram.cpp index 3aff7361..636e2b82 100644 --- a/examples/jit_cpp/histogram/step2_double_buffering/kernel_histogram.cpp +++ b/examples/jit_cpp/histogram/step2_double_buffering/kernel_histogram.cpp @@ -13,7 +13,9 @@ for the full License text. using namespace pto; -constexpr uint32_t DEFAULT_TILE_SIZE = 512; +#ifndef HIST_TILE_SIZE +#define HIST_TILE_SIZE 1024 +#endif constexpr uint32_t MAX_BINS = 256; constexpr uint32_t MAX_BLOCKS = 64; @@ -276,9 +278,9 @@ AICORE void runTHistogramFinal(__gm__ float *z_local, __gm__ int32_t *z, __global__ AICORE void histogram_local_fp32( __gm__ void *x, __gm__ void *z_local, const uint32_t in_length, const int32_t num_bins, const float min_val, const float max_val) { - runTLocalHistogram( - (__gm__ float *)x, (__gm__ float *)z_local, in_length, num_bins, min_val, - max_val); + runTLocalHistogram((__gm__ float *)x, + (__gm__ float *)z_local, in_length, + num_bins, min_val, max_val); } __global__ AICORE void histogram_final(__gm__ void *z_local, __gm__ void *z, From 9635717d1dc4492d39039d064ff83535d4acc934 Mon Sep 17 00:00:00 2001 From: Vladimir Loncar Date: Wed, 8 Apr 2026 16:02:36 +0000 Subject: [PATCH 13/17] Hope this satisfies the linter! --- csrc/kernel/kernel_histogram.cpp | 4 ++-- examples/jit_cpp/histogram/bench_kernels.py | 3 +-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/csrc/kernel/kernel_histogram.cpp b/csrc/kernel/kernel_histogram.cpp index 6e68203c..82543672 100644 --- a/csrc/kernel/kernel_histogram.cpp +++ b/csrc/kernel/kernel_histogram.cpp @@ -14,9 +14,9 @@ for the full License text. #include #include "kernel_utils.h" - +// clang-format off #define GM_ADDR __gm__ uint8_t* - +// clang-format on using namespace pto; constexpr uint32_t DEFAULT_TILE_SIZE = 512; diff --git a/examples/jit_cpp/histogram/bench_kernels.py b/examples/jit_cpp/histogram/bench_kernels.py index 597d00b9..0025d81f 100644 --- a/examples/jit_cpp/histogram/bench_kernels.py +++ b/examples/jit_cpp/histogram/bench_kernels.py @@ -1,6 +1,5 @@ import argparse import os -import sys from pathlib import Path import pandas as pd @@ -186,7 +185,7 @@ def main(): if include_torch: - def torch_hist_bench(x, z, bins, min_val, max_val): + def torch_hist_bench(x, _, bins, min_val, max_val): return torch.histc(x, bins=bins, min=min_val, max=max_val).to( torch.int32 ) From f90450a937a601e407639ea50f3076907455bd33 Mon Sep 17 00:00:00 2001 From: Vladimir Loncar Date: Wed, 8 Apr 2026 17:15:18 +0000 Subject: [PATCH 14/17] WIP implementation with MSCATTER --- examples/jit_cpp/histogram/run_histogram.py | 1 + .../kernel_histogram.cpp | 240 ++++++++++++++++++ 2 files changed, 241 insertions(+) create mode 100644 examples/jit_cpp/histogram/step3_scatter_index_to_gm/kernel_histogram.cpp diff --git a/examples/jit_cpp/histogram/run_histogram.py b/examples/jit_cpp/histogram/run_histogram.py index fc4ef1f9..4d21bca8 100644 --- a/examples/jit_cpp/histogram/run_histogram.py +++ b/examples/jit_cpp/histogram/run_histogram.py @@ -6,6 +6,7 @@ IMPLEMENTATIONS = { 1: "step1_naive_histogram", 2: "step2_double_buffering", + 3: "step3_scatter_index_to_gm", # Not working on A2/A3 } diff --git a/examples/jit_cpp/histogram/step3_scatter_index_to_gm/kernel_histogram.cpp b/examples/jit_cpp/histogram/step3_scatter_index_to_gm/kernel_histogram.cpp new file mode 100644 index 00000000..a5980052 --- /dev/null +++ b/examples/jit_cpp/histogram/step3_scatter_index_to_gm/kernel_histogram.cpp @@ -0,0 +1,240 @@ +/** +Copyright (c) 2026 Huawei Technologies Co., Ltd. +All rights reserved. + +See LICENSE in the root of the software repository: +https://github.com/huawei-csl/pto-kernels/ +for the full License text. +*/ + +#include + +#include "../kernel_utils.h" + +using namespace pto; + +#ifndef HIST_TILE_SIZE +#define HIST_TILE_SIZE 1024 +#endif +constexpr uint32_t MAX_BINS = 256; +constexpr uint32_t MAX_BLOCKS = 64; + +/** + * runTLocalHistogram - Local histogram using MSCATTER with AtomicAdd + */ +template +AICORE void runTLocalHistogram(__gm__ T *x, __gm__ float *z_local, + const uint32_t total_length, + const int32_t num_bins, const float min_val, + const float max_val) { +#if __CCE_AICORE__ == 220 && defined(__DAV_C220_VEC__) + set_mask_norm(); + set_vector_mask(-1, -1); + + // --- Define Global Tensors --- + using InputGlobalData = pto::GlobalTensor, + pto::Stride<1, 1, 1, 1, 1>>; + using HistGlobalData = + pto::GlobalTensor, + pto::Stride<1, 1, 1, 1, 1>>; + + // --- Work Distribution --- + const uint32_t block_idx = get_block_idx(); + const uint32_t block_num = get_block_num(); + const uint32_t num_tiles_total = + kernel_utils::CeilDiv(total_length, TILE_SIZE); + const uint32_t num_tiles_per_core = + kernel_utils::CeilDiv(num_tiles_total, block_num); + const uint32_t start_idx = block_idx * num_tiles_per_core; + const uint32_t end_idx = (start_idx + num_tiles_per_core > num_tiles_total) + ? num_tiles_total + : (start_idx + num_tiles_per_core); + + // --- Define UB Tiles and Memory Layout --- + uint32_t addr = 0; + const uint32_t UB_X_PING = addr; + addr += TILE_SIZE * sizeof(T); + const uint32_t UB_X_PONG = addr; + addr += TILE_SIZE * sizeof(T); + const uint32_t UB_F32_ADDR = addr; + addr += TILE_SIZE * sizeof(float); + const uint32_t UB_IDX_ADDR = addr; + addr += TILE_SIZE * sizeof(int32_t); + const uint32_t UB_ONES_ADDR = addr; + addr += TILE_SIZE * sizeof(float); + + InputGlobalData x_gm(x, {static_cast(total_length)}); + HistGlobalData z_gm(z_local + block_idx * num_bins, {num_bins}); + + using InputTileData = Tile; + using F32TileData = Tile; + using IdxTileData = Tile; + + F32TileData f32_tile; + TASSIGN(f32_tile, UB_F32_ADDR); + + IdxTileData idx_tile; + TASSIGN(idx_tile, UB_IDX_ADDR); + + F32TileData ones_tile; + TASSIGN(ones_tile, UB_ONES_ADDR); + TEXPANDS(ones_tile, 1.0f); + + const float inv_bin_width = + static_cast(num_bins) / (max_val - min_val); + + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID1); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID1); + + // --- Main Calculation Loop --- + for (uint32_t tile_idx = start_idx, ping = 1; tile_idx < end_idx; + ++tile_idx) { + int offset = tile_idx * TILE_SIZE; + TASSIGN(x_gm, x + offset); + + const event_t ev = ping ? (event_t)EVENT_ID0 : (event_t)EVENT_ID1; + const unsigned x_base = ping ? UB_X_PING : UB_X_PONG; + + InputTileData x_tile; + TASSIGN(x_tile, x_base); + + wait_flag(PIPE_V, PIPE_MTE2, ev); + TLOAD(x_tile, x_gm); + set_flag(PIPE_MTE2, PIPE_V, ev); + wait_flag(PIPE_MTE2, PIPE_V, ev); + + // Calculate bin index: idx = (val - min_val) * inv_bin_width + TADDS(f32_tile, x_tile, -min_val); + TMULS(f32_tile, f32_tile, inv_bin_width); + + // Convert to int32 index using Floor rounding + TCVT(idx_tile, f32_tile, RoundMode::CAST_FLOOR); + + // Clamp indices to [0, num_bins - 1] to handle edge cases and outliers + TMAXS(idx_tile, idx_tile, static_cast(0)); + TMINS(idx_tile, idx_tile, static_cast(num_bins - 1)); + + // Atomic Scatter-Add to global memory histogram + MSCATTER(z_gm, ones_tile, idx_tile); + + set_flag(PIPE_V, PIPE_MTE2, ev); + ping = 1 - ping; + } + + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID1); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID1); + +#endif +} + +// Template parameter to avoid "no function" kernel launch error +template +AICORE void runTHistogramFinal(__gm__ float *z_local, __gm__ int32_t *z, + const int32_t num_bins, + const int32_t num_blocks) { +#if __CCE_AICORE__ == 220 && defined(__DAV_C220_VEC__) + set_mask_norm(); + set_vector_mask(-1, -1); + + if (get_block_idx() == 0) { + // --- Define Global Tensors --- + using InGlobalData = + pto::GlobalTensor, + pto::Stride<1, 1, 1, 1, 1>>; + using OutGlobalData = + pto::GlobalTensor, + pto::Stride<1, 1, 1, 1, 1>>; + + uint32_t addr = 0; + const uint32_t UB_IN_ADDR = addr; + addr += MAX_BLOCKS * MAX_BINS * sizeof(float); + const uint32_t UB_REDUCE_TMP_ADDR = addr; + addr += MAX_BINS * sizeof(float); + const uint32_t UB_FLOAT_OUT_ADDR = addr; + addr += MAX_BINS * sizeof(float); + const uint32_t UB_OUT_ADDR = addr; + addr += MAX_BINS * sizeof(int32_t); + + using InTile = Tile; + InTile in_tile( + {static_cast(num_blocks), static_cast(num_bins)}); + TASSIGN(in_tile, UB_IN_ADDR); + + using ReduceTmpTile = + Tile; + ReduceTmpTile reduce_tmp_tile(num_bins); + TASSIGN(reduce_tmp_tile, UB_REDUCE_TMP_ADDR); + TEXPANDS(reduce_tmp_tile, 0.0f); + + using FloatOutTile = + Tile; + FloatOutTile float_out_tile(num_bins); + TASSIGN(float_out_tile, UB_FLOAT_OUT_ADDR); + TEXPANDS(float_out_tile, 0.0f); + + using OutTile = Tile; + OutTile out_tile(num_bins); + TASSIGN(out_tile, UB_OUT_ADDR); + TEXPANDS(out_tile, static_cast(0)); + + // Load all block counts into UB row by row to match 2D tile padding + using InRowTile = + Tile; + InRowTile row_tile(static_cast(num_bins)); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + for (int32_t b = 0; b < num_blocks; ++b) { + InGlobalData z_local_gm(z_local + b * num_bins, {num_bins}); + TASSIGN(row_tile, UB_IN_ADDR + b * MAX_BINS * sizeof(float)); + TLOAD(row_tile, z_local_gm); + } + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + + // TCOLSUM reduces along the row dimension (num_blocks) + TCOLSUM(float_out_tile, in_tile, reduce_tmp_tile, true); + + TCVT(out_tile, float_out_tile, RoundMode::CAST_RINT); + + // --- Final Store to Global Memory --- + OutGlobalData z_gm(z, {num_bins}); + + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(z_gm, out_tile); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + } + +#endif +} + +__global__ AICORE void histogram_local_fp32( + __gm__ void *x, __gm__ void *z_local, const uint32_t in_length, + const int32_t num_bins, const float min_val, const float max_val) { + runTLocalHistogram((__gm__ float *)x, + (__gm__ float *)z_local, in_length, + num_bins, min_val, max_val); +} + +__global__ AICORE void histogram_final(__gm__ void *z_local, __gm__ void *z, + const int32_t num_bins, + const int32_t num_blocks) { + runTHistogramFinal<0>((__gm__ float *)z_local, (__gm__ int32_t *)z, num_bins, + num_blocks); +} + +extern "C" void histogram_fp32(uint32_t num_blocks, void *stream, void *x, + void *z_local, void *z, const uint32_t in_length, + const int32_t num_bins, const float min_val, + const float max_val) { + histogram_local_fp32<<>>( + x, z_local, in_length, num_bins, min_val, max_val); + histogram_final<<<1, nullptr, stream>>>(z_local, z, num_bins, num_blocks); +} From 3a8284e292dab4bcb9901af1662f09ea653593a7 Mon Sep 17 00:00:00 2001 From: Vladimir Loncar Date: Thu, 9 Apr 2026 09:00:01 +0000 Subject: [PATCH 15/17] Add readme --- examples/jit_cpp/histogram/README.md | 58 ++++++++++++++++++++++++++++ 1 file changed, 58 insertions(+) create mode 100644 examples/jit_cpp/histogram/README.md diff --git a/examples/jit_cpp/histogram/README.md b/examples/jit_cpp/histogram/README.md new file mode 100644 index 00000000..f31c220d --- /dev/null +++ b/examples/jit_cpp/histogram/README.md @@ -0,0 +1,58 @@ +# Ascend PTO Histogram Implementation Examples + +This directory contains a series of implementations demonstrating the evolution and optimization of a histogram kernel using the PTO library for Ascend NPUs (specifically targeting the A2/910B architecture). + +## Implementation Evolution + +The implementations are organized into steps, each introducing new concepts or optimizations: + +- **Step 0: Count Less Than (`step0_count_less_than`)**: The foundational algorithm that counts how many elements in a tile are less than a given pivot value using vector comparisons and reductions. The algorithm is implemented using atomic operations and using a two phase kernels. The atomic operations didn't always behave as expected, see note below. +- **Step 1: Naive Histogram (`step1_naive_histogram`)**: Expands the logic to a full histogram by looping over all bins. For each bin, it calculates the count of elements falling within that range. +- **Step 2: Double Buffering (`step2_double_buffering`)**: Introduces double buffering (ping-pong) for data loading from Global Memory (GM) to Unified Buffer (UB), allowing computation and data movement to overlap. +- **Step 3: Scatter Index to GM (`step3_scatter_index_to_gm`)**: A significant algorithmic shift. Instead of looping over bins and performing scalar updates, this implementation calculates the bin index for each element in parallel using vector operations and uses `MSCATTER` with `AtomicAdd` to update the histogram directly in Global Memory. This avoids the slow scalar-to-vector synchronization. + +## Included Files + +- `bench_kernels.py`: A comprehensive benchmarking suite to compare the performance of different implementations. +- `plot_kernels.py`: A utility script to visualize the benchmarking results. +- `run_histogram.py`: A script for functional testing and verification of the kernels. +- `jit_util_histogram.py`: Utility functions for JIT-compiling the C++ kernels into Python-callable operators. +- `kernel_utils.h`: Common helper functions used across the different kernel implementations. + +## Usage + +### Testing +To verify the correctness of the implementations, run: +```bash +python run_histogram.py +``` +*Note: You can modify the parameters (such as `num_bins`, `total_length`, `tile_size` or which implementation to use) directly inside the script.* + +### Benchmarking +To compare the performance of the different steps: +```bash +python bench_kernels.py +``` +The script supports various arguments for configuring the benchmark range and parameters. Use `--help` to see all options. + +### Plotting +After running the benchmarks, you can generate plots using: +```bash +python plot_kernels.py +``` + +## Torch Operator Integration + +A production-ready Torch operator implementation is provided in the repository's core source tree: +- **Kernel Implementation**: `csrc/kernel/kernel_histogram.cpp` +- **Host/C++ Wrapper**: `csrc/host/torch_histogram.h` + +## Known Issues and Observations + +During development and optimization, several architectural and library-specific details were noted: + +- **`TCMPS` Ambiguity**: The documentation is ambiguous regarding whether the valid column number of the mask tile must exactly match the source tile or if it can be different. The bits are packed, the example shows allocation of a tile that is using reduced number of valid rows, however this is in contrast with the description of the operation and in fact doesn't work. +- **Temporary Tiles**: Several instructions such as `TSEL`, `TXOR`, `TGATHER` and others mention an extra temporary tile parameter that is not actually required. Operations that do require it (like `TROWSUM`) don't mention any constraints that this tile should have. +- **`AtomicAdd` in `TSTORE`**: The `AtomicAdd` parameter in the `TSTORE` instruction was found to be unreliable in some configurations, requiring the use of a two-phase algorithm (local reduction followed by a final global reduction) for stability. +- **`MSCATTER` on A2/A3**: The `MSCATTER` implementation is not supported on A2/A3 architectures, thus at the time of writing Step 3 is provided only as an illustration of the next direction of implementation. +- Occasionally the two-phase algorithm will not launch the first or second phase with no error provided. Whether this is an implementation issue or a driver/toolkit issue is still under investigation. From a74087f414f8374b37410f37b4d298640949d397 Mon Sep 17 00:00:00 2001 From: Vladimir Loncar Date: Wed, 15 Apr 2026 08:21:43 +0000 Subject: [PATCH 16/17] style consistency --- .../step0_count_less_than/jit_util_count_less_than.py | 7 +++++-- .../kernel_count_less_than_atomic.cpp | 3 ++- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/examples/jit_cpp/histogram/step0_count_less_than/jit_util_count_less_than.py b/examples/jit_cpp/histogram/step0_count_less_than/jit_util_count_less_than.py index d76054b4..eded6d5b 100644 --- a/examples/jit_cpp/histogram/step0_count_less_than/jit_util_count_less_than.py +++ b/examples/jit_cpp/histogram/step0_count_less_than/jit_util_count_less_than.py @@ -5,7 +5,7 @@ import torch ASCEND_TOOLKIT_HOME = os.environ["ASCEND_TOOLKIT_HOME"] -PTO_LIB_PATH = os.environ["PTO_LIB_PATH"] +PTO_LIB_PATH = os.environ.get("PTO_LIB_PATH", ASCEND_TOOLKIT_HOME) def compile_cpp(kernel_cpp: str, verbose: bool = False, timeout: int = 120) -> str: @@ -35,7 +35,10 @@ def compile_cpp(kernel_cpp: str, verbose: bool = False, timeout: int = 120) -> s stderr=subprocess.STDOUT, ) except Exception as e: - raise RuntimeError(f"Compile failed: {e}") from e + output = e.stdout.decode("utf-8", errors="replace") if e.stdout else "" + raise RuntimeError( + f"Compile failed with exit code {e.returncode}:\n{output}" + ) from e if verbose: print(f"generated {lib_path}") diff --git a/examples/jit_cpp/histogram/step0_count_less_than/kernel_count_less_than_atomic.cpp b/examples/jit_cpp/histogram/step0_count_less_than/kernel_count_less_than_atomic.cpp index 822c2b09..118af5c2 100644 --- a/examples/jit_cpp/histogram/step0_count_less_than/kernel_count_less_than_atomic.cpp +++ b/examples/jit_cpp/histogram/step0_count_less_than/kernel_count_less_than_atomic.cpp @@ -23,6 +23,7 @@ AICORE void runTCountLessThan(__gm__ T *x, __gm__ int32_t *z, #if __CCE_AICORE__ == 220 && defined(__DAV_C220_VEC__) set_mask_norm(); set_vector_mask(-1, -1); + set_atomic_add(); // --- Define Global Tensors --- using InputGlobalData = pto::GlobalTensor, @@ -132,7 +133,7 @@ AICORE void runTCountLessThan(__gm__ T *x, __gm__ int32_t *z, set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); - // DOesn't do atomic adds + // Doesn't do atomic adds TSTORE(z_gm, local_out); set_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); From 2ee9daee7ba2fb4f7ca4503fd835bb51af90a0df Mon Sep 17 00:00:00 2001 From: Vladimir Loncar Date: Wed, 15 Apr 2026 08:25:41 +0000 Subject: [PATCH 17/17] linting fix --- csrc/host/pybind11.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/host/pybind11.cpp b/csrc/host/pybind11.cpp index 487bb254..05ec6c86 100644 --- a/csrc/host/pybind11.cpp +++ b/csrc/host/pybind11.cpp @@ -11,8 +11,8 @@ for the full License text. #include "torch_abs.h" #include "torch_batch_matrix_square.h" -#include "torch_histogram.h" #include "torch_csr_gather.h" +#include "torch_histogram.h" #include "torch_simple_matmul.h" #include "torch_swiglu.h" #include "torch_tri_inv.h"