diff --git a/CMakeLists.txt b/CMakeLists.txt index 46f02038..7e7eb901 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -113,6 +113,7 @@ ascendc_library( SHARED csrc/kernel/kernel_tri_inv_col_sweep.cpp csrc/kernel/kernel_abs.cpp + csrc/kernel/kernel_histogram.cpp csrc/kernel/kernel_csr_gather.cpp csrc/kernel/kernel_simple_matmul.cpp csrc/kernel/kernel_batch_matrix_square.cpp diff --git a/csrc/host/pybind11.cpp b/csrc/host/pybind11.cpp index 134195b7..05ec6c86 100644 --- a/csrc/host/pybind11.cpp +++ b/csrc/host/pybind11.cpp @@ -12,6 +12,7 @@ for the full License text. #include "torch_abs.h" #include "torch_batch_matrix_square.h" #include "torch_csr_gather.h" +#include "torch_histogram.h" #include "torch_simple_matmul.h" #include "torch_swiglu.h" #include "torch_tri_inv.h" @@ -36,6 +37,8 @@ PYBIND11_MODULE(pto_kernels_ops, m) { }, pybind11::arg("device_id") = 0); m.def("pto_abs", &pto_isa_ops::run_abs); + m.def("pto_histogram", &pto_isa_ops::run_histogram, py::arg("x"), + py::arg("bins") = 100, py::arg("min") = 0.0, py::arg("max") = 0.0); m.def("pto_batch_matrix_square", &pto_isa_ops::run_batch_matrix_square); m.def("pto_csr_gather", &pto_isa_ops::run_csr_gather); m.def("pto_simple_matmul", &pto_isa_ops::run_simple_matmul); diff --git a/csrc/host/torch_histogram.h b/csrc/host/torch_histogram.h new file mode 100644 index 00000000..590c09a8 --- /dev/null +++ b/csrc/host/torch_histogram.h @@ -0,0 +1,86 @@ +/** +Copyright (c) 2026 Huawei Technologies Co., Ltd. +All rights reserved. + +See LICENSE in the root of the software repository: +https://github.com/huawei-csl/pto-kernels/ +for the full License text. +*/ +#pragma once + +#include +#include + +#include "aclrtlaunch_histogram_final.h" +#include "aclrtlaunch_histogram_fp16.h" +#include "aclrtlaunch_histogram_fp32.h" +#include "utils.h" + +namespace pto_isa_ops { + +/** + * @brief Computes the histogram of a tensor. + * + * @param [in] x Input tensor of dtype fp16 or fp32. + * @param [in] bins Number of histogram bins. + * @param [in] min_val Lower bound of the range. + * @param [in] max_val Upper bound of the range. + * @return at::Tensor Computed histogram tensor. + */ +at::Tensor run_histogram(const at::Tensor& x, int64_t bins = 100, + double min_val = 0.0, double max_val = 0.0) { + const uint32_t total_len = x.numel(); + constexpr uint32_t TILE_SIZE = 512; + const uint32_t block_dim = GetNumVectorCores(); + + TORCH_CHECK(total_len % TILE_SIZE == 0, + "total number of elements must be divisible by TILE_SIZE"); + TORCH_CHECK(bins <= 256, "bins must be <= 256"); + TORCH_CHECK(x.is_contiguous(), "x must be contiguous"); + + const auto dtype = x.options().dtype(); + const auto device = x.options().device(); + + // Allocate a 1D tensor sized `[block_dim * bins]` for the local histogram + // counts. + auto z_local_opts = + at::TensorOptions() + .dtype( + at::kFloat) // Local (per-core) histogram counts will be floats + .device(device); + at::Tensor z_local = at::zeros({block_dim * bins}, z_local_opts); + + // Allocate a 1D tensor sized `[bins]` for the histogram. + auto z_opts = at::TensorOptions() + .dtype(at::kInt) // The final result will be int32 counts + .device(device); + at::Tensor z = at::zeros({bins}, z_opts); + + const auto num_bins = static_cast(bins); + + if (min_val == 0.0 && max_val == 0.0) { + const double tensor_min = x.min().item(); + const double tensor_max = x.max().item(); + min_val = tensor_min; + max_val = tensor_max; + } + + const auto f_min_val = static_cast(min_val); + const auto f_max_val = static_cast(max_val); + + if (dtype == at::kHalf) { + EXEC_KERNEL_CMD(histogram_fp16, block_dim, x, z_local, total_len, num_bins, + f_min_val, f_max_val); + } else if (dtype == at::kFloat) { + EXEC_KERNEL_CMD(histogram_fp32, block_dim, x, z_local, total_len, num_bins, + f_min_val, f_max_val); + } else { + throw std::runtime_error("Unsupported dtype for `pto_histogram` kernel"); + } + + const uint32_t reduce_dim = 1; + EXEC_KERNEL_CMD(histogram_final, reduce_dim, z_local, z, num_bins, block_dim); + + return z; +} +} // namespace pto_isa_ops diff --git a/csrc/kernel/kernel_histogram.cpp b/csrc/kernel/kernel_histogram.cpp new file mode 100644 index 00000000..82543672 --- /dev/null +++ b/csrc/kernel/kernel_histogram.cpp @@ -0,0 +1,303 @@ +/** +Copyright (c) 2026 Huawei Technologies Co., Ltd. +All rights reserved. + +See LICENSE in the root of the software repository: +https://github.com/huawei-csl/pto-kernels/ +for the full License text. +*/ + +#if __CCE_AICORE__ == 220 && defined(__DAV_C220_VEC__) + +#define MEMORY_BASE + +#include + +#include "kernel_utils.h" +// clang-format off +#define GM_ADDR __gm__ uint8_t* +// clang-format on +using namespace pto; + +constexpr uint32_t DEFAULT_TILE_SIZE = 512; +constexpr uint32_t MAX_BINS = 256; +constexpr uint32_t MAX_BLOCKS = 64; + +/** + * runTLocalHistogram - Local, per-core histogram calculation + */ +template +AICORE void runTLocalHistogram(__gm__ T *x, __gm__ float *z_local, + const uint32_t total_length, + const int32_t num_bins, const float min_val, + const float max_val) { + set_mask_norm(); + set_vector_mask(-1, -1); + + // --- Define Global Tensors --- + using InputGlobalData = pto::GlobalTensor, + pto::Stride<1, 1, 1, 1, 1>>; + using HistGlobalData = + pto::GlobalTensor, + pto::Stride<1, 1, 1, 1, 1>>; + + // --- Work Distribution --- + const uint32_t block_idx = get_block_idx(); + const uint32_t block_num = get_block_num(); + const uint32_t num_tiles_total = + kernel_utils::CeilDiv(total_length, TILE_SIZE); + const uint32_t num_tiles_per_core = + kernel_utils::CeilDiv(num_tiles_total, block_num); + const uint32_t start_idx = block_idx * num_tiles_per_core; + const uint32_t end_idx = (start_idx + num_tiles_per_core > num_tiles_total) + ? num_tiles_total + : (start_idx + num_tiles_per_core); + + // --- Define UB Tiles and Memory Layout --- + uint32_t addr = 0; + const uint32_t UB_X_PING = addr; + addr += TILE_SIZE * sizeof(T); + const uint32_t UB_X_PONG = addr; + addr += TILE_SIZE * sizeof(T); + const uint32_t UB_CUR_MASK_ADDR = addr; + addr += TILE_SIZE * sizeof(uint8_t); + const uint32_t UB_CUR_F32_ADDR = addr; + addr += TILE_SIZE * sizeof(float); + const uint32_t UB_PREV_F32_ADDR = addr; + addr += TILE_SIZE * sizeof(float); + const uint32_t UB_BIN_F32_ADDR = addr; + addr += TILE_SIZE * sizeof(float); + const uint32_t UB_ONE_ADDR = addr; + addr += TILE_SIZE * sizeof(float); + const uint32_t UB_ZERO_ADDR = addr; + addr += TILE_SIZE * sizeof(float); + const uint32_t UB_REDUCE_TMP_ADDR = addr; + addr += TILE_SIZE * sizeof(float); + const uint32_t UB_COUNT_ADDR = addr; + addr += 8 * sizeof(float); + const uint32_t UB_LOCAL_HIST_ADDR = addr; + + InputGlobalData x_gm(x, {static_cast(total_length)}); + + using InputTileData = Tile; + + using MaskTileData = Tile; + MaskTileData current_mask; + TASSIGN(current_mask, UB_CUR_MASK_ADDR); + + // Float conversion tiles + using F32TileData = Tile; + F32TileData cur_f32; + TASSIGN(cur_f32, UB_CUR_F32_ADDR); + F32TileData prev_f32; + TASSIGN(prev_f32, UB_PREV_F32_ADDR); + F32TileData bin_mask_f32; + TASSIGN(bin_mask_f32, UB_BIN_F32_ADDR); + + F32TileData one_tile; + TASSIGN(one_tile, UB_ONE_ADDR); + TEXPANDS(one_tile, 1.0f); + F32TileData zero_tile; + TASSIGN(zero_tile, UB_ZERO_ADDR); + TEXPANDS(zero_tile, 0.0f); + + F32TileData reduce_tmp; + TASSIGN(reduce_tmp, UB_REDUCE_TMP_ADDR); + + using F32CountTile = + Tile; + F32CountTile count_f32_tile; + TASSIGN(count_f32_tile, UB_COUNT_ADDR); + + // Local histogram tile in UB + using HistTile = + Tile; + HistTile local_hist(num_bins); + TASSIGN(local_hist, UB_LOCAL_HIST_ADDR); + TEXPANDS(local_hist, 0.0f); + + const float bin_width = (max_val - min_val) / static_cast(num_bins); + + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID1); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID1); + + // --- Main Calculation Loop --- + for (uint32_t tile_idx = start_idx, ping = 1; tile_idx < end_idx; + ++tile_idx) { + int offset = tile_idx * TILE_SIZE; + TASSIGN(x_gm, x + offset); + + const event_t ev = ping ? (event_t)EVENT_ID0 : (event_t)EVENT_ID1; + const unsigned x_base = ping ? UB_X_PING : UB_X_PONG; + + InputTileData x_tile; + TASSIGN(x_tile, x_base); + + wait_flag(PIPE_V, PIPE_MTE2, ev); + TLOAD(x_tile, x_gm); + set_flag(PIPE_MTE2, PIPE_V, ev); + wait_flag(PIPE_MTE2, PIPE_V, ev); + + // Generate packed bit-mask + TCMPS(current_mask, x_tile, static_cast(min_val), CmpMode::LT); + // Select 1.0f or 0.0f based on the packed bit-mask + TSEL(prev_f32, current_mask, one_tile, zero_tile); + + for (int32_t j = 0; j < num_bins; ++j) { + float bin_upper_bound = min_val + (j + 1) * bin_width; + CmpMode mode = (j == num_bins - 1) ? CmpMode::LE : CmpMode::LT; + + TCMPS(current_mask, x_tile, static_cast(bin_upper_bound), mode); + TSEL(cur_f32, current_mask, one_tile, zero_tile); + TSUB(bin_mask_f32, cur_f32, prev_f32); + + // Reduce the selected tile to get the count of elements less than pivot + // in this tile + TEXPANDS(count_f32_tile, 0.0f); + TEXPANDS(reduce_tmp, 0.0f); + TROWSUM(count_f32_tile, bin_mask_f32, reduce_tmp); + + // Scalar move to update UB local histogram + set_flag(PIPE_V, PIPE_S, EVENT_ID2); + wait_flag(PIPE_V, PIPE_S, EVENT_ID2); + float f_count = count_f32_tile.GetValue(0); + if (f_count > 0.0f) { + local_hist.SetValue(j, local_hist.GetValue(j) + f_count); + } + set_flag(PIPE_S, PIPE_V, EVENT_ID2); + wait_flag(PIPE_S, PIPE_V, EVENT_ID2); + + TMOV(prev_f32, cur_f32); + } + + set_flag(PIPE_V, PIPE_MTE2, ev); + ping = 1 - ping; + } + + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID1); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID1); + + // --- Final Store to Global Memory --- + HistGlobalData z_gm(z_local + block_idx * num_bins, + {static_cast(block_num) * num_bins}); + + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(z_gm, local_hist); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); +} + +// Template parameter to avoid "no function" kernel launch error +template +AICORE void runTHistogramFinal(__gm__ float *z_local, __gm__ int32_t *z, + const int32_t num_bins, + const int32_t num_blocks) { + set_mask_norm(); + set_vector_mask(-1, -1); + + if (get_block_idx() == 0) { + // --- Define Global Tensors --- + using InGlobalData = + pto::GlobalTensor, + pto::Stride<1, 1, 1, 1, 1>>; + using OutGlobalData = + pto::GlobalTensor, + pto::Stride<1, 1, 1, 1, 1>>; + + uint32_t addr = 0; + const uint32_t UB_IN_ADDR = addr; + addr += MAX_BLOCKS * MAX_BINS * sizeof(float); + const uint32_t UB_REDUCE_TMP_ADDR = addr; + addr += MAX_BINS * sizeof(float); + const uint32_t UB_FLOAT_OUT_ADDR = addr; + addr += MAX_BINS * sizeof(float); + const uint32_t UB_OUT_ADDR = addr; + addr += MAX_BINS * sizeof(int32_t); + + using InTile = Tile; + InTile in_tile( + {static_cast(num_blocks), static_cast(num_bins)}); + TASSIGN(in_tile, UB_IN_ADDR); + + using ReduceTmpTile = + Tile; + ReduceTmpTile reduce_tmp_tile(num_bins); + TASSIGN(reduce_tmp_tile, UB_REDUCE_TMP_ADDR); + TEXPANDS(reduce_tmp_tile, 0.0f); + + using FloatOutTile = + Tile; + FloatOutTile float_out_tile(num_bins); + TASSIGN(float_out_tile, UB_FLOAT_OUT_ADDR); + TEXPANDS(float_out_tile, 0.0f); + + using OutTile = Tile; + OutTile out_tile(num_bins); + TASSIGN(out_tile, UB_OUT_ADDR); + TEXPANDS(out_tile, static_cast(0)); + + // Load all block counts into UB row by row to match 2D tile padding + using InRowTile = + Tile; + InRowTile row_tile(static_cast(num_bins)); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + for (int32_t b = 0; b < num_blocks; ++b) { + InGlobalData z_local_gm(z_local + b * num_bins, {num_bins}); + TASSIGN(row_tile, UB_IN_ADDR + b * MAX_BINS * sizeof(float)); + TLOAD(row_tile, z_local_gm); + } + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + + // TCOLSUM reduces along the row dimension (num_blocks) + TCOLSUM(float_out_tile, in_tile, reduce_tmp_tile, true); + + TCVT(out_tile, float_out_tile, RoundMode::CAST_RINT); + + // --- Final Store to Global Memory --- + OutGlobalData z_gm(z, {num_bins}); + + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(z_gm, out_tile); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + } +} + +extern "C" __global__ AICORE void histogram_fp16(GM_ADDR x, GM_ADDR z_local, + const uint32_t in_length, + const int32_t num_bins, + const float min_val, + const float max_val) { + runTLocalHistogram( + (__gm__ half *)x, (__gm__ float *)z_local, in_length, num_bins, min_val, + max_val); +} + +extern "C" __global__ AICORE void histogram_fp32(GM_ADDR x, GM_ADDR z_local, + const uint32_t in_length, + const int32_t num_bins, + const float min_val, + const float max_val) { + runTLocalHistogram( + (__gm__ float *)x, (__gm__ float *)z_local, in_length, num_bins, min_val, + max_val); +} + +extern "C" __global__ AICORE void histogram_final(GM_ADDR z_local, GM_ADDR z, + const int32_t num_bins, + const int32_t num_blocks) { + runTHistogramFinal<0>((__gm__ float *)z_local, (__gm__ int32_t *)z, num_bins, + num_blocks); +} + +#endif diff --git a/examples/jit_cpp/histogram/.gitignore b/examples/jit_cpp/histogram/.gitignore new file mode 100644 index 00000000..17aa483a --- /dev/null +++ b/examples/jit_cpp/histogram/.gitignore @@ -0,0 +1 @@ +outputs/ diff --git a/examples/jit_cpp/histogram/README.md b/examples/jit_cpp/histogram/README.md new file mode 100644 index 00000000..f31c220d --- /dev/null +++ b/examples/jit_cpp/histogram/README.md @@ -0,0 +1,58 @@ +# Ascend PTO Histogram Implementation Examples + +This directory contains a series of implementations demonstrating the evolution and optimization of a histogram kernel using the PTO library for Ascend NPUs (specifically targeting the A2/910B architecture). + +## Implementation Evolution + +The implementations are organized into steps, each introducing new concepts or optimizations: + +- **Step 0: Count Less Than (`step0_count_less_than`)**: The foundational algorithm that counts how many elements in a tile are less than a given pivot value using vector comparisons and reductions. The algorithm is implemented using atomic operations and using a two phase kernels. The atomic operations didn't always behave as expected, see note below. +- **Step 1: Naive Histogram (`step1_naive_histogram`)**: Expands the logic to a full histogram by looping over all bins. For each bin, it calculates the count of elements falling within that range. +- **Step 2: Double Buffering (`step2_double_buffering`)**: Introduces double buffering (ping-pong) for data loading from Global Memory (GM) to Unified Buffer (UB), allowing computation and data movement to overlap. +- **Step 3: Scatter Index to GM (`step3_scatter_index_to_gm`)**: A significant algorithmic shift. Instead of looping over bins and performing scalar updates, this implementation calculates the bin index for each element in parallel using vector operations and uses `MSCATTER` with `AtomicAdd` to update the histogram directly in Global Memory. This avoids the slow scalar-to-vector synchronization. + +## Included Files + +- `bench_kernels.py`: A comprehensive benchmarking suite to compare the performance of different implementations. +- `plot_kernels.py`: A utility script to visualize the benchmarking results. +- `run_histogram.py`: A script for functional testing and verification of the kernels. +- `jit_util_histogram.py`: Utility functions for JIT-compiling the C++ kernels into Python-callable operators. +- `kernel_utils.h`: Common helper functions used across the different kernel implementations. + +## Usage + +### Testing +To verify the correctness of the implementations, run: +```bash +python run_histogram.py +``` +*Note: You can modify the parameters (such as `num_bins`, `total_length`, `tile_size` or which implementation to use) directly inside the script.* + +### Benchmarking +To compare the performance of the different steps: +```bash +python bench_kernels.py +``` +The script supports various arguments for configuring the benchmark range and parameters. Use `--help` to see all options. + +### Plotting +After running the benchmarks, you can generate plots using: +```bash +python plot_kernels.py +``` + +## Torch Operator Integration + +A production-ready Torch operator implementation is provided in the repository's core source tree: +- **Kernel Implementation**: `csrc/kernel/kernel_histogram.cpp` +- **Host/C++ Wrapper**: `csrc/host/torch_histogram.h` + +## Known Issues and Observations + +During development and optimization, several architectural and library-specific details were noted: + +- **`TCMPS` Ambiguity**: The documentation is ambiguous regarding whether the valid column number of the mask tile must exactly match the source tile or if it can be different. The bits are packed, the example shows allocation of a tile that is using reduced number of valid rows, however this is in contrast with the description of the operation and in fact doesn't work. +- **Temporary Tiles**: Several instructions such as `TSEL`, `TXOR`, `TGATHER` and others mention an extra temporary tile parameter that is not actually required. Operations that do require it (like `TROWSUM`) don't mention any constraints that this tile should have. +- **`AtomicAdd` in `TSTORE`**: The `AtomicAdd` parameter in the `TSTORE` instruction was found to be unreliable in some configurations, requiring the use of a two-phase algorithm (local reduction followed by a final global reduction) for stability. +- **`MSCATTER` on A2/A3**: The `MSCATTER` implementation is not supported on A2/A3 architectures, thus at the time of writing Step 3 is provided only as an illustration of the next direction of implementation. +- Occasionally the two-phase algorithm will not launch the first or second phase with no error provided. Whether this is an implementation issue or a driver/toolkit issue is still under investigation. diff --git a/examples/jit_cpp/histogram/bench_kernels.py b/examples/jit_cpp/histogram/bench_kernels.py new file mode 100644 index 00000000..0025d81f --- /dev/null +++ b/examples/jit_cpp/histogram/bench_kernels.py @@ -0,0 +1,220 @@ +import argparse +import os +from pathlib import Path + +import pandas as pd +import torch + +from jit_util_histogram import jit_compile + +DEVICE = os.environ.get("NPU_DEVICE", "npu:1") +DTYPE = torch.float32 + +N_REPEAT = 20 +N_WARMUP = 5 +N_ALLOC = N_REPEAT + N_WARMUP + +TILE_SIZES = [512, 1024, 2048, 4096, 8192] +BINS_LIST = [8, 32, 64, 128, 192, 256] +MIN_VAL = 0.0 +MAX_VAL = 255.0 + +DEFAULT_CSV_REL_PATH = Path("outputs") / "csv" / "histogram_timing.csv" + + +def _parse_args(): + parser = argparse.ArgumentParser( + description=( + "Benchmark torch and implementation histogram kernels and save the results to " + "a CSV file." + ) + ) + parser.add_argument( + "--csv", + type=str, + default=str(DEFAULT_CSV_REL_PATH), + help=f"Output CSV path (default: {DEFAULT_CSV_REL_PATH})", + ) + parser.add_argument( + "--implementation", + type=int, + nargs="*", + choices=[1, 2], + help="Select the implementation steps (1: naive, 2: double buffering). If not provided, benchmarks all.", + ) + parser.add_argument( + "--with-torch", + action="store_true", + help="Include torch baseline timing/throughput benchmarking.", + ) + return parser.parse_args() + + +IMPLEMENTATIONS = { + 1: "step1_naive_histogram", + 2: "step2_double_buffering", +} + + +def _bench_backend( + name, func, a_list, z_list, c_ref, processed_elements, total_bytes, bins +): + c = None + for a, z in zip(a_list[:N_WARMUP], z_list[:N_WARMUP]): + res = func(a, z, bins, MIN_VAL, MAX_VAL) + c = res if res is not None else z + + mean_diff = float(torch.mean(torch.abs(c.float() - c_ref.float())).cpu()) + abs_error = float(torch.max(torch.abs(c.float() - c_ref.float())).cpu()) + + start = torch.npu.Event(enable_timing=True) + end = torch.npu.Event(enable_timing=True) + start.record() + for a, z in zip( + a_list[N_WARMUP : N_WARMUP + N_REPEAT], z_list[N_WARMUP : N_WARMUP + N_REPEAT] + ): + func(a, z, bins, MIN_VAL, MAX_VAL) + end.record() + torch.npu.synchronize() + dur_us = start.elapsed_time(end) / N_REPEAT * 1e3 + + gmelem_s = processed_elements / dur_us / 1e3 + bw_gbs = total_bytes * 1e6 / dur_us / (1024**3) + + print( + f"{name} duration: {dur_us:.3f} us, GElem/s: {gmelem_s:.3f}, mean diff: {mean_diff}" + ) + + return { + f"{name}_time_us": dur_us, + f"{name}_gmelem_s": gmelem_s, + f"{name}_bandwidth_gbs": bw_gbs, + f"{name}_mean_diff": mean_diff, + f"{name}_abs_error": abs_error, + f"{name}_error": "", + } + + +def bench_n_elems(funcs_to_bench, num_elements, bins, tile_size): + print(f"\n=== N = {num_elements:_} | bins = {bins} | tile_size = {tile_size} ===") + + a_list = [ + torch.rand(num_elements, dtype=DTYPE, device=DEVICE) * (MAX_VAL - MIN_VAL) + + MIN_VAL + for _ in range(N_ALLOC) + ] + z_list = [ + torch.zeros(bins, device=DEVICE, dtype=torch.int32) for _ in range(N_ALLOC) + ] + + ref_a = a_list[N_WARMUP - 1] + c_ref = torch.histc(ref_a, bins=bins, min=MIN_VAL, max=MAX_VAL).to(torch.int32) + + processed_elements = num_elements + total_bytes = num_elements * int(ref_a.element_size()) + bins * int( + c_ref.element_size() + ) + + record = { + "N": num_elements, + "bins": bins, + "tile_size": tile_size, + } + + for name, func in funcs_to_bench.items(): + if func is None: + record.update( + { + f"{name}_time_us": float("nan"), + f"{name}_gmelem_s": float("nan"), + f"{name}_bandwidth_gbs": float("nan"), + f"{name}_mean_diff": float("nan"), + f"{name}_abs_error": float("nan"), + f"{name}_error": "backend not compiled/enabled", + } + ) + continue + + try: + stats = _bench_backend( + name, func, a_list, z_list, c_ref, processed_elements, total_bytes, bins + ) + record.update(stats) + except Exception as exc: + print(f"{name} unavailable: {exc}") + record.update( + { + f"{name}_time_us": float("nan"), + f"{name}_gmelem_s": float("nan"), + f"{name}_bandwidth_gbs": float("nan"), + f"{name}_mean_diff": float("nan"), + f"{name}_abs_error": float("nan"), + f"{name}_error": str(exc), + } + ) + + return [record] + + +def main(): + args = _parse_args() + include_torch = args.with_torch + impls_to_run = args.implementation if args.implementation else [1, 2] + + torch.npu.set_device(DEVICE) + base = Path(__file__).resolve().parent + + csv_path = Path(args.csv) + if not csv_path.is_absolute(): + csv_path = base / csv_path + csv_path.parent.mkdir(parents=True, exist_ok=True) + + print(f"Implementations selected: {[IMPLEMENTATIONS[i] for i in impls_to_run]}") + if include_torch: + print("Torch baseline: enabled") + + vector_cores = torch.npu.get_device_properties().vector_core_num + base_elements = max(TILE_SIZES) * vector_cores + # multiplier = max(1, (1024 * 1024) // base_elements) + multiplier = 1 + n_elements_list = [base_elements * multiplier * i for i in range(1, 33)] + + records = [] + for tile_size in TILE_SIZES: + funcs_to_bench = {} + + if include_torch: + + def torch_hist_bench(x, _, bins, min_val, max_val): + return torch.histc(x, bins=bins, min=min_val, max=max_val).to( + torch.int32 + ) + + funcs_to_bench["torch"] = torch_hist_bench + + for i in impls_to_run: + impl_dir = IMPLEMENTATIONS[i] + custom_path = base / impl_dir / "kernel_histogram.cpp" + name = f"step{i}" + try: + print( + f"Compiling {impl_dir}/kernel_histogram.cpp for backend '{name}' (tile_size={tile_size}) ..." + ) + funcs_to_bench[name] = jit_compile( + str(custom_path), tile_size=tile_size + ) + except Exception as exc: + print(f"[WARN] backend '{name}' unavailable: {exc}") + funcs_to_bench[name] = None + + for bins in BINS_LIST: + for n in n_elements_list: + records.extend(bench_n_elems(funcs_to_bench, n, bins, tile_size)) + + df = pd.DataFrame.from_records(records) + df.to_csv(csv_path, index=False) + print(f"\nSaved benchmark CSV: {csv_path}") + + +if __name__ == "__main__": + main() diff --git a/examples/jit_cpp/histogram/jit_util_histogram.py b/examples/jit_cpp/histogram/jit_util_histogram.py new file mode 100644 index 00000000..7721c655 --- /dev/null +++ b/examples/jit_cpp/histogram/jit_util_histogram.py @@ -0,0 +1,103 @@ +import os +import subprocess +import ctypes + +import torch + +ASCEND_TOOLKIT_HOME = os.environ["ASCEND_TOOLKIT_HOME"] +PTO_LIB_PATH = os.environ.get("PTO_LIB_PATH", ASCEND_TOOLKIT_HOME) + + +def compile_cpp( + kernel_cpp: str, tile_size=512, verbose: bool = False, timeout: int = 120 +) -> str: + dirname = os.path.dirname(kernel_cpp) + lib_path = os.path.join(dirname, f"{dirname}_jit.so") + + flags = [ + "-fPIC", + "-shared", + "-xcce", + "--npu-arch=dav-2201", + "-DMEMORY_BASE", # here hardcoded for A2A3; TODO: expose this option to jit interface + "-DHIST_TILE_SIZE=" + str(tile_size), + "-O2", + "-std=c++17", + f"-I{PTO_LIB_PATH}/include", + ] + + command = ["bisheng", *flags, kernel_cpp, "-o", lib_path] + if verbose: + print(f"compile {kernel_cpp} with command: \n", command) + + try: + subprocess.run( + command, + timeout=timeout, + check=True, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + ) + except Exception as e: + output = e.stdout.decode("utf-8", errors="replace") if e.stdout else "" + raise RuntimeError( + f"Compile failed with exit code {e.returncode}:\n{output}" + ) from e + + if verbose: + print(f"generated {lib_path}") + return lib_path + + +def torch_to_ctypes(tensor): + return ctypes.c_void_p(tensor.data_ptr()) + + +def load_lib(lib_path, check_type=True): + lib_path = os.path.abspath(lib_path) + lib = ctypes.CDLL(lib_path) + + default_block_dim = torch.npu.get_device_properties().vector_core_num + + if check_type: + lib.histogram_fp32.argtypes = [ + ctypes.c_uint32, # blockDim + ctypes.c_void_p, # stream + ctypes.c_void_p, # x + ctypes.c_void_p, # z_local + ctypes.c_void_p, # z + ctypes.c_uint, # in_length + ctypes.c_int, # bins + ctypes.c_float, # min_val + ctypes.c_float, # max_val + ] + lib.histogram_fp32.restype = None + + def hist_func( + x, z, bins, min_val, max_val, block_dim=default_block_dim, stream_ptr=None + ): + if stream_ptr is None: + stream_ptr = torch.npu.current_stream()._as_parameter_ # noqa + N = x.numel() + z_local = torch.zeros((block_dim, bins), device=x.device, dtype=torch.float32) + lib.histogram_fp32( + block_dim, + stream_ptr, + torch_to_ctypes(x), + torch_to_ctypes(z_local), + torch_to_ctypes(z), + N, + bins, + ctypes.c_float(min_val), + ctypes.c_float(max_val), + ) + + return hist_func + + +def jit_compile(src_path, tile_size=512, clean_up=True): + lib_path = compile_cpp(src_path, tile_size=tile_size, verbose=True) + func = load_lib(lib_path, check_type=False) + if clean_up: + os.remove(lib_path) + return func diff --git a/examples/jit_cpp/histogram/kernel_utils.h b/examples/jit_cpp/histogram/kernel_utils.h new file mode 100644 index 00000000..0b6ae4c0 --- /dev/null +++ b/examples/jit_cpp/histogram/kernel_utils.h @@ -0,0 +1,45 @@ +/** +Copyright (c) 2026 Huawei Technologies Co., Ltd. +All rights reserved. + +See LICENSE in the root of the software repository: +https://github.com/huawei-csl/pto-kernels/ +for the full License text. +*/ +#pragma once + +#include + +namespace kernel_utils { +/** + * @brief Do a sync step (set-wait flag) between two pipes. + * + * @tparam SrcPipe The pipe that sets the flag. + * @tparam DstPipe The pipe that waits for the flag. + * @param [in] id The event id to sync for. + */ +template +AICORE inline void SetWaitFlag(uint32_t id) { + set_flag(SrcPipe, DstPipe, static_cast(id)); + wait_flag(SrcPipe, DstPipe, static_cast(id)); +} + +/** + * @brief Performs a division on two integral numbers and rounds the result up + * to the nearest integer. + * + * @tparam T1 Data type of dividend. + * @tparam T2 Data type of divisor. + * @param [in] value Dividend. + * @param [in] divisor Divisor. + * @return Result of division. + */ +template ::value && + std::is_integral::value, + int>::type = 0> +AICORE inline T1 CeilDiv(T1 value, T2 divisor) { + return (value + divisor - 1) / divisor; +} + +} // namespace kernel_utils diff --git a/examples/jit_cpp/histogram/plot_kernels.py b/examples/jit_cpp/histogram/plot_kernels.py new file mode 100644 index 00000000..2f83723b --- /dev/null +++ b/examples/jit_cpp/histogram/plot_kernels.py @@ -0,0 +1,227 @@ +import argparse +from pathlib import Path + +import matplotlib.pyplot as plt +import pandas as pd + +DEFAULT_CSV_REL_PATH = Path("outputs") / "csv" / "histogram_timing.csv" +DEFAULT_PLOT_REL_DIR = Path("outputs") / "plots" + +BACKEND_STYLE = { + "torch": {"color": "#111111", "marker": "x", "linestyle": "--"}, + "step1": {"color": "#1f77b4", "marker": "o", "linestyle": "-"}, + "step2": {"color": "#ff7f0e", "marker": "s", "linestyle": "-"}, +} +CUSTOM_MARKERS = ["o", "s", "^", "v", "D", "P", "X", "*", "<", ">"] + + +def _parse_args(): + parser = argparse.ArgumentParser( + description="Plot benchmark figures from a benchmark CSV file." + ) + parser.add_argument( + "--csv", + type=str, + default=str(DEFAULT_CSV_REL_PATH), + help=f"Input benchmark CSV path (default: {DEFAULT_CSV_REL_PATH})", + ) + parser.add_argument( + "--plot-dir", + type=str, + default=str(DEFAULT_PLOT_REL_DIR), + help=f"Output plot directory (default: {DEFAULT_PLOT_REL_DIR})", + ) + parser.add_argument( + "--bins", + type=int, + default=256, + help="Number of bins to plot (default: 256)", + ) + parser.add_argument( + "--tile-size", + type=int, + action="append", + default=[], + help="Tile sizes to plot. Can be repeated. If not provided, defaults to 4096.", + ) + return parser.parse_args() + + +def _style(name: str) -> dict: + return BACKEND_STYLE.get( + name, {"color": "#2ca02c", "marker": "^", "linestyle": "-"} + ) + + +def _finalize_plot(title: str, xlabel: str, ylabel: str): + plt.xlabel(xlabel) + plt.ylabel(ylabel) + plt.title(title) + plt.xlim(left=0) + plt.ylim(bottom=0) + plt.grid(True, alpha=0.25) + handles, _ = plt.gca().get_legend_handles_labels() + if handles: + plt.legend(fontsize=8) + plt.tight_layout() + + +def _plot_backend(df: pd.DataFrame, backend: str, metric_col: str, style: dict): + if metric_col not in df.columns: + return + + if backend == "torch" or "tile_size" not in df.columns: + g = df[["N", metric_col]].dropna() + if g.empty: + return + g = g.groupby("N", as_index=False)[metric_col].mean().sort_values("N") + plt.plot( + g["N"], + g[metric_col], + marker=style["marker"], + linestyle=style["linestyle"], + color=style["color"], + label=backend, + ) + else: + ts_values = sorted(df["tile_size"].dropna().unique()) + grouped = df.dropna(subset=[metric_col]).groupby("tile_size", sort=True) + for idx, (ts, group) in enumerate(grouped): + g = group[["N", metric_col]].dropna() + if g.empty: + continue + g = g.groupby("N", as_index=False)[metric_col].mean().sort_values("N") + + if len(ts_values) > 1: + label = f"{backend} (ts={ts})" + marker = CUSTOM_MARKERS[idx % len(CUSTOM_MARKERS)] + alpha = max(0.4, 1.0 - (idx * 0.15)) + else: + label = backend + marker = style["marker"] + alpha = 1.0 + + plt.plot( + g["N"], + g[metric_col], + marker=marker, + linestyle=style["linestyle"], + color=style["color"], + alpha=alpha, + label=label, + ) + + +def plot_runtime(df: pd.DataFrame, out_dir: Path, bins: int) -> Path: + plt.figure(figsize=(10, 5)) + for backend in ["torch", "step1", "step2"]: + _plot_backend(df, backend, f"{backend}_time_us", _style(backend)) + + _finalize_plot( + title=f"Runtime vs N (bins={bins})", + xlabel="Number of Elements (N)", + ylabel="Runtime (us)", + ) + + out_path = out_dir / "duration.png" + plt.savefig(out_path, dpi=160) + plt.close() + return out_path + + +def plot_throughput(df: pd.DataFrame, out_dir: Path, bins: int) -> Path: + plt.figure(figsize=(10, 5)) + for backend in ["torch", "step1", "step2"]: + _plot_backend(df, backend, f"{backend}_gmelem_s", _style(backend)) + + _finalize_plot( + title=f"Throughput vs N (bins={bins})", + xlabel="Number of Elements (N)", + ylabel="GElem/s", + ) + + out_path = out_dir / "throughput.png" + plt.savefig(out_path, dpi=160) + plt.close() + return out_path + + +def plot_error(df: pd.DataFrame, out_dir: Path, bins: int) -> Path: + plt.figure(figsize=(10, 5)) + for backend in ["step1", "step2"]: + _plot_backend(df, backend, f"{backend}_mean_diff", _style(backend)) + + _finalize_plot( + title=f"Error vs N (bins={bins})", + xlabel="Number of Elements (N)", + ylabel="Mean Abs Error", + ) + + out_path = out_dir / "error.png" + plt.savefig(out_path, dpi=160) + plt.close() + return out_path + + +def main(): + args = _parse_args() + base = Path(__file__).resolve().parent + + csv_path = Path(args.csv) + if not csv_path.is_absolute(): + csv_path = base / csv_path + if not csv_path.exists(): + raise FileNotFoundError(f"Benchmark CSV not found: {csv_path}") + + plot_dir = Path(args.plot_dir) + if not plot_dir.is_absolute(): + plot_dir = base / plot_dir + plot_dir.mkdir(parents=True, exist_ok=True) + + df = pd.read_csv(csv_path) + required_columns = {"N"} + missing = required_columns - set(df.columns) + if missing: + raise ValueError( + f"CSV is missing required columns: {sorted(missing)} (file: {csv_path})" + ) + + plot_df = df + + if "bins" in plot_df.columns: + plot_df = plot_df[plot_df["bins"] == args.bins] + if plot_df.empty: + available_bins = sorted(df["bins"].dropna().unique()) + raise RuntimeError( + f"No rows found for bins={args.bins} in {csv_path}. " + f"Available bins: {available_bins}" + ) + + tile_sizes = args.tile_size if args.tile_size else [4096] + if "tile_size" in plot_df.columns: + plot_df = plot_df[plot_df["tile_size"].isin(tile_sizes)] + if plot_df.empty: + available_ts = sorted(df["tile_size"].dropna().unique()) + raise RuntimeError( + f"No rows found for tile_sizes={tile_sizes} in {csv_path} (with bins={args.bins}). " + f"Available tile sizes: {available_ts}" + ) + + if plot_df.empty: + raise RuntimeError(f"No data found in {csv_path}.") + + runtime_path = plot_runtime(plot_df, plot_dir, args.bins) + throughput_path = plot_throughput(plot_df, plot_dir, args.bins) + error_path = plot_error(plot_df, plot_dir, args.bins) + + print(f"Loaded CSV: {csv_path}") + print( + f"Filters applied: bins={args.bins}, tile_sizes={tile_sizes if 'tile_size' in df.columns else 'N/A'}" + ) + print(f"Saved runtime plot: {runtime_path}") + print(f"Saved throughput plot: {throughput_path}") + print(f"Saved error plot: {error_path}") + + +if __name__ == "__main__": + main() diff --git a/examples/jit_cpp/histogram/run_histogram.py b/examples/jit_cpp/histogram/run_histogram.py new file mode 100644 index 00000000..4d21bca8 --- /dev/null +++ b/examples/jit_cpp/histogram/run_histogram.py @@ -0,0 +1,60 @@ +import torch +import torch_npu # noqa + +from jit_util_histogram import jit_compile + +IMPLEMENTATIONS = { + 1: "step1_naive_histogram", + 2: "step2_double_buffering", + 3: "step3_scatter_index_to_gm", # Not working on A2/A3 +} + + +def test_histogram(impl=2, size_mult=1, repeat_runs=20): + device = "npu:1" + dtype = torch.float32 + torch.npu.set_device(device) + + # Tile size is fixed in the kernel + tile_size = 4096 + num_cores = torch.npu.get_device_properties().vector_core_num + num_tiles = num_cores * tile_size + total_len = num_tiles * size_mult + + bins = 256 + min_val = 0.0 + max_val = 256.0 + + # Create an input tensor bounded around our standard test range + x = torch.rand(size=(total_len,), device="npu", dtype=dtype).contiguous() * max_val + z = torch.zeros(bins, device="npu", dtype=torch.int32) + + # Golden PyTorch implementation + expected_hist = torch.histc(x.cpu(), bins, min=min_val, max=max_val).to(torch.int32) + + hist_func = jit_compile( + f"{IMPLEMENTATIONS[impl]}/kernel_histogram.cpp", tile_size=tile_size + ) + + # NPU kernel execution, test to see if any race conditions occur across multiple runs + actual_hist = [] + for _ in range(repeat_runs): + z.zero_() + hist_func(x, z, bins, min_val, max_val, block_dim=num_cores) + actual_hist.append(z.cpu().clone()) + + torch.npu.synchronize() + + # Check for consistency across runs and correctness against the expected count + for i, hist in enumerate(actual_hist): + assert torch.equal( + hist, actual_hist[0] + ), f"Inconsistent results across runs at run {i}, expected\n {actual_hist[0]}\ngot\n{hist}\n" + + assert torch.equal( + expected_hist, actual_hist[0] + ), f"Mismatch between expected and actual histogram, expected\n {expected_hist}\ngot\n{actual_hist[0]}\n" + + +if __name__ == "__main__": + test_histogram(impl=2, size_mult=64, repeat_runs=20) diff --git a/examples/jit_cpp/histogram/step0_count_less_than/jit_util_count_less_than.py b/examples/jit_cpp/histogram/step0_count_less_than/jit_util_count_less_than.py new file mode 100644 index 00000000..eded6d5b --- /dev/null +++ b/examples/jit_cpp/histogram/step0_count_less_than/jit_util_count_less_than.py @@ -0,0 +1,85 @@ +import os +import subprocess +import ctypes + +import torch + +ASCEND_TOOLKIT_HOME = os.environ["ASCEND_TOOLKIT_HOME"] +PTO_LIB_PATH = os.environ.get("PTO_LIB_PATH", ASCEND_TOOLKIT_HOME) + + +def compile_cpp(kernel_cpp: str, verbose: bool = False, timeout: int = 120) -> str: + lib_path = os.path.join(os.path.dirname(kernel_cpp), "count_less_than_jit.so") + + flags = [ + "-fPIC", + "-shared", + "-xcce", + "--npu-arch=dav-2201", + "-DMEMORY_BASE", # here hardcoded for A2A3; TODO: expose this option to jit interface + "-O2", + "-std=c++17", + f"-I{PTO_LIB_PATH}/include", + ] + + command = ["bisheng", *flags, kernel_cpp, "-o", lib_path] + if verbose: + print(f"compile {kernel_cpp} with command: \n", command) + + try: + subprocess.run( + command, + timeout=timeout, + check=True, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + ) + except Exception as e: + output = e.stdout.decode("utf-8", errors="replace") if e.stdout else "" + raise RuntimeError( + f"Compile failed with exit code {e.returncode}:\n{output}" + ) from e + + if verbose: + print(f"generated {lib_path}") + return lib_path + + +def torch_to_ctypes(tensor): + return ctypes.c_void_p(tensor.data_ptr()) + + +def load_lib(lib_path, check_type=True): + lib_path = os.path.abspath(lib_path) + lib = ctypes.CDLL(lib_path) + + default_block_dim = torch.npu.get_device_properties().vector_core_num + + if check_type: + lib.count_less_than_fp32.argtypes = [ + ctypes.c_uint32, # blockDim + ctypes.c_void_p, # stream + ctypes.c_void_p, # x + ctypes.c_void_p, # z + ctypes.c_uint, # in_length + ctypes.c_float, # pivot + ] + lib.count_less_than_fp32.restype = None + + def count_func(x, z, pivot, block_dim=default_block_dim, stream_ptr=None): + if stream_ptr is None: + stream_ptr = torch.npu.current_stream()._as_parameter_ # noqa + N = x.numel() + lib.count_less_than_fp32( + block_dim, stream_ptr, torch_to_ctypes(x), torch_to_ctypes(z), N, pivot + ) + + return count_func + + +def jit_compile(src_path, clean_up=True): + lib_path = compile_cpp(src_path, verbose=True) + func = load_lib(lib_path) + if clean_up: + os.remove(lib_path) + return func diff --git a/examples/jit_cpp/histogram/step0_count_less_than/kernel_count_less_than.cpp b/examples/jit_cpp/histogram/step0_count_less_than/kernel_count_less_than.cpp new file mode 100644 index 00000000..8d35f5e2 --- /dev/null +++ b/examples/jit_cpp/histogram/step0_count_less_than/kernel_count_less_than.cpp @@ -0,0 +1,241 @@ +/** +Copyright (c) 2026 Huawei Technologies Co., Ltd. +All rights reserved. + +See LICENSE in the root of the software repository: +https://github.com/huawei-csl/pto-kernels/ +for the full License text. +*/ + +#include + +#include "../kernel_utils.h" +#include "acl/acl.h" + +using namespace pto; + +/** + * runTCountLessThan - Local count calculation using TCMPS and TSEL + */ +template +AICORE void runTCountLessThan(__gm__ T *x, __gm__ float *z_local, + const uint32_t total_length, const float pivot) { +#if __CCE_AICORE__ == 220 && defined(__DAV_C220_VEC__) + + set_mask_norm(); + set_vector_mask(-1, -1); + + // --- Define Global Tensors --- + using InputGlobalData = pto::GlobalTensor, + pto::Stride<1, 1, 1, 1, 1>>; + using OutGlobalData = + pto::GlobalTensor, + pto::Stride<1, 1, 1, 1, 1>>; + + // --- Work Distribution --- + const uint32_t block_idx = get_block_idx(); + const uint32_t block_num = get_block_num(); + const uint32_t num_tiles_total = + kernel_utils::CeilDiv(total_length, TILE_SIZE); + const uint32_t num_tiles_per_core = + kernel_utils::CeilDiv(num_tiles_total, block_num); + const uint32_t start_idx = block_idx * num_tiles_per_core; + const uint32_t end_idx = (start_idx + num_tiles_per_core > num_tiles_total) + ? num_tiles_total + : (start_idx + num_tiles_per_core); + + if (start_idx < end_idx) { + // --- Define UB Tiles and Memory Layout --- + uint32_t addr = 0; + const uint32_t UB_X_ADDR = addr; + addr += TILE_SIZE * sizeof(T); + const uint32_t UB_CUR_MASK_ADDR = addr; + addr += TILE_SIZE * sizeof(uint8_t); + const uint32_t UB_ONES_ADDR = addr; + addr += TILE_SIZE * sizeof(float); + const uint32_t UB_ZEROS_ADDR = addr; + addr += TILE_SIZE * sizeof(float); + const uint32_t UB_TSEL_OUT_ADDR = addr; + addr += TILE_SIZE * sizeof(float); + const uint32_t UB_REDUCE_TMP_ADDR = addr; + addr += TILE_SIZE * sizeof(float); + const uint32_t UB_COUNT_ADDR = addr; + addr += 8 * sizeof(float); + const uint32_t UB_TOTAL_COUNT_ADDR = addr; + addr += 8 * sizeof(float); + + using InputTileData = Tile; + InputTileData x_tile; + TASSIGN(x_tile, UB_X_ADDR); + + using MaskTileData = Tile; + MaskTileData current_mask; + TASSIGN(current_mask, UB_CUR_MASK_ADDR); + + // Float conversion tiles + using F32TileData = Tile; + F32TileData ones_tile; + TASSIGN(ones_tile, UB_ONES_ADDR); + TEXPANDS(ones_tile, 1.0f); + + F32TileData zeros_tile; + TASSIGN(zeros_tile, UB_ZEROS_ADDR); + TEXPANDS(zeros_tile, 0.0f); + + F32TileData tsel_out_tile; + TASSIGN(tsel_out_tile, UB_TSEL_OUT_ADDR); + + F32TileData reduce_tmp; + TASSIGN(reduce_tmp, UB_REDUCE_TMP_ADDR); + + using F32CountTile = + Tile; + F32CountTile count_f32_tile; + TASSIGN(count_f32_tile, UB_COUNT_ADDR); + + F32CountTile total_count_f32_tile; + TASSIGN(total_count_f32_tile, UB_TOTAL_COUNT_ADDR); + TEXPANDS(total_count_f32_tile, 0.0f); + + // --- Main Calculation Loop --- + for (uint32_t tile_idx = start_idx; tile_idx < end_idx; ++tile_idx) { + const uint32_t offset = tile_idx * TILE_SIZE; + InputGlobalData x_gm(x + offset, {static_cast(total_length)}); + + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + TLOAD(x_tile, x_gm); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + + // Generate packed bit-mask + TCMPS(current_mask, x_tile, static_cast(pivot), CmpMode::LT); + // Select 1.0f or 0.0f based on the packed bit-mask + TSEL(tsel_out_tile, current_mask, ones_tile, zeros_tile); + + // Reduce the selected tile to get the count of elements less than pivot + // in this tile + TEXPANDS(count_f32_tile, 0.0f); + TEXPANDS(reduce_tmp, 0.0f); + TROWSUM(count_f32_tile, tsel_out_tile, reduce_tmp); + + // Accumulate the count from this tile into the total count for this block + TADD(total_count_f32_tile, total_count_f32_tile, count_f32_tile); + } + + // --- Final Store to Global Memory --- + OutGlobalData z_local_gm(z_local + block_idx, + {static_cast(block_num)}); + + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(z_local_gm, total_count_f32_tile); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + } + +#endif +} + +// Template parameter to avoid "no function" kernel launch error +template +AICORE void runTCountFinal(__gm__ float *z_local, __gm__ int32_t *z, + const int32_t num_blocks) { +#if __CCE_AICORE__ == 220 && defined(__DAV_C220_VEC__) + + set_mask_norm(); + set_vector_mask(-1, -1); + + if (get_block_idx() == 0) { + // --- Define Global Tensors --- + using InGlobalData = + pto::GlobalTensor, + pto::Stride<1, 1, 1, 1, 1>>; + using OutGlobalData = pto::GlobalTensor, + pto::Stride<1, 1, 1, 1, 1>>; + + constexpr uint32_t MAX_BLOCKS = + 256; // Must be larger than number of AIV cores + + uint32_t addr = 0; + const uint32_t UB_IN_ADDR = addr; + addr += MAX_BLOCKS * sizeof(float); + const uint32_t UB_REDUCE_TMP_ADDR = addr; + addr += 8 * sizeof(float); + const uint32_t UB_FLOAT_OUT_ADDR = addr; + addr += 8 * sizeof(float); + const uint32_t UB_OUT_ADDR = addr; + addr += 8 * sizeof(int32_t); + + using InTile = Tile; + InTile in_tile(static_cast(num_blocks)); + TASSIGN(in_tile, UB_IN_ADDR); + + using ReduceTmpTile = Tile; + ReduceTmpTile reduce_tmp_tile; + TASSIGN(reduce_tmp_tile, UB_REDUCE_TMP_ADDR); + + using FloatOutTile = + Tile; + FloatOutTile float_out_tile; + TASSIGN(float_out_tile, UB_FLOAT_OUT_ADDR); + TEXPANDS(float_out_tile, 0.0f); + + using OutTile = Tile; + OutTile out_tile; + TASSIGN(out_tile, UB_OUT_ADDR); + TEXPANDS(out_tile, (int32_t)0); + + // Load all block counts into UB + InGlobalData z_local_gm(z_local, {num_blocks}); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + TLOAD(in_tile, z_local_gm); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + + TROWSUM(float_out_tile, in_tile, reduce_tmp_tile); + TCVT(out_tile, float_out_tile, RoundMode::CAST_RINT); + + OutGlobalData z_gm(z); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(z_gm, out_tile); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + } + +#endif +} + +__global__ AICORE void count_less_than_local(__gm__ void *x, + __gm__ void *z_local, + const uint32_t in_length, + const float pivot) { + constexpr unsigned TILE_SIZE = 512; + runTCountLessThan( + (__gm__ float *)x, (__gm__ float *)z_local, in_length, pivot); +} + +__global__ AICORE void count_final(__gm__ void *z_local, __gm__ void *z, + const int32_t num_blocks) { + runTCountFinal<0>((__gm__ float *)z_local, (__gm__ int32_t *)z, num_blocks); +} + +extern "C" void count_less_than_fp32(uint32_t num_blocks, void *stream, + uint8_t *x, uint8_t *z, uint32_t in_length, + float pivot) { + // Could have been allocated in Torch and passed here + // This is not a suggested practice for production code, we use it here to + // keep the python side interface cleaner + void *z_local = nullptr; + size_t size = num_blocks * sizeof(float); + aclrtMalloc(&z_local, size, ACL_MEM_MALLOC_HUGE_FIRST); + + count_less_than_local<<>>(x, z_local, in_length, + pivot); + count_final<<<1, nullptr, stream>>>(z_local, z, num_blocks); + + aclrtFree(z_local); +} diff --git a/examples/jit_cpp/histogram/step0_count_less_than/kernel_count_less_than_atomic.cpp b/examples/jit_cpp/histogram/step0_count_less_than/kernel_count_less_than_atomic.cpp new file mode 100644 index 00000000..118af5c2 --- /dev/null +++ b/examples/jit_cpp/histogram/step0_count_less_than/kernel_count_less_than_atomic.cpp @@ -0,0 +1,157 @@ +/** +Copyright (c) 2026 Huawei Technologies Co., Ltd. +All rights reserved. + +See LICENSE in the root of the software repository: +https://github.com/huawei-csl/pto-kernels/ +for the full License text. +*/ + +#include + +#include "../kernel_utils.h" + +using namespace pto; + +/** + * runTCountLessThan - Local count calculation with Atomic Addition to Global + * Memory. + */ +template +AICORE void runTCountLessThan(__gm__ T *x, __gm__ int32_t *z, + const uint32_t total_length, const float pivot) { +#if __CCE_AICORE__ == 220 && defined(__DAV_C220_VEC__) + set_mask_norm(); + set_vector_mask(-1, -1); + set_atomic_add(); + + // --- Define Global Tensors --- + using InputGlobalData = pto::GlobalTensor, + pto::Stride<1, 1, 1, 1, 1>>; + using OutGlobalData = pto::GlobalTensor, + pto::Stride<1, 1, 1, 1, 1>>; + + // --- Work Distribution --- + const uint32_t block_idx = get_block_idx(); + const uint32_t block_num = get_block_num(); + const uint32_t num_tiles_total = + kernel_utils::CeilDiv(total_length, TILE_SIZE); + const uint32_t num_tiles_per_core = + kernel_utils::CeilDiv(num_tiles_total, block_num); + const uint32_t start_idx = block_idx * num_tiles_per_core; + const uint32_t end_idx = (start_idx + num_tiles_per_core > num_tiles_total) + ? num_tiles_total + : (start_idx + num_tiles_per_core); + + if (start_idx < end_idx) { + // --- Define UB Tiles and Memory Layout --- + uint32_t addr = 0; + const uint32_t UB_X_ADDR = addr; + addr += TILE_SIZE * sizeof(T); + const uint32_t UB_CUR_MASK_ADDR = addr; + addr += TILE_SIZE * sizeof(uint8_t); + const uint32_t UB_ONES_ADDR = addr; + addr += TILE_SIZE * sizeof(float); + const uint32_t UB_ZEROS_ADDR = addr; + addr += TILE_SIZE * sizeof(float); + const uint32_t UB_TSEL_OUT_ADDR = addr; + addr += TILE_SIZE * sizeof(float); + const uint32_t UB_REDUCE_TMP_ADDR = addr; + addr += TILE_SIZE * sizeof(float); + const uint32_t UB_COUNT_ADDR = addr; + addr += 8 * sizeof(float); + const uint32_t UB_TOTAL_COUNT_ADDR = addr; + addr += 8 * sizeof(float); + const uint32_t UB_LOCAL_OUT_ADDR = addr; + + using InputTileData = Tile; + InputTileData x_tile; + TASSIGN(x_tile, UB_X_ADDR); + + using MaskTileData = Tile; + MaskTileData current_mask; + TASSIGN(current_mask, UB_CUR_MASK_ADDR); + + // Float conversion tiles + using F32TileData = Tile; + F32TileData ones_tile; + TASSIGN(ones_tile, UB_ONES_ADDR); + TEXPANDS(ones_tile, 1.0f); + + F32TileData zeros_tile; + TASSIGN(zeros_tile, UB_ZEROS_ADDR); + TEXPANDS(zeros_tile, 0.0f); + + F32TileData tsel_out_tile; + TASSIGN(tsel_out_tile, UB_TSEL_OUT_ADDR); + + F32TileData reduce_tmp; + TASSIGN(reduce_tmp, UB_REDUCE_TMP_ADDR); + + using F32CountTile = + Tile; + F32CountTile count_f32_tile; + TASSIGN(count_f32_tile, UB_COUNT_ADDR); + + F32CountTile total_count_f32_tile; + TASSIGN(total_count_f32_tile, UB_TOTAL_COUNT_ADDR); + TEXPANDS(total_count_f32_tile, 0.0f); + + using OutTile = Tile; + OutTile local_out; + TASSIGN(local_out, UB_LOCAL_OUT_ADDR); + TEXPANDS(local_out, (int32_t)0); + + // --- Main Calculation Loop --- + for (uint32_t tile_idx = start_idx; tile_idx < end_idx; ++tile_idx) { + const uint32_t offset = tile_idx * TILE_SIZE; + InputGlobalData x_gm(x + offset, {static_cast(total_length)}); + + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + TLOAD(x_tile, x_gm); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + + // Generate packed bit-mask + TCMPS(current_mask, x_tile, static_cast(pivot), CmpMode::LT); + // Select 1.0f or 0.0f based on the packed bit-mask + TSEL(tsel_out_tile, current_mask, ones_tile, zeros_tile); + + TEXPANDS(count_f32_tile, 0.0f); + TEXPANDS(reduce_tmp, 0.0f); + TROWSUM(count_f32_tile, tsel_out_tile, reduce_tmp); + + TADD(total_count_f32_tile, total_count_f32_tile, count_f32_tile); + } + + // Convert accumulated total into our UB local count + TCVT(local_out, total_count_f32_tile, RoundMode::CAST_RINT); + + // --- Final Atomic Store to Global Memory --- + OutGlobalData z_gm(z); + + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + // Doesn't do atomic adds + TSTORE(z_gm, local_out); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + } + +#endif +} + +__global__ AICORE void count_less_than(__gm__ void *x, __gm__ void *z, + const uint32_t in_length, + const float pivot) { + constexpr unsigned TILE_SIZE = 512; + runTCountLessThan((__gm__ float *)x, (__gm__ int32_t *)z, + in_length, pivot); +} + +extern "C" void count_less_than_fp32(uint32_t num_blocks, void *stream, + uint8_t *x, uint8_t *z, uint32_t in_length, + float pivot) { + count_less_than<<>>(x, z, in_length, pivot); +} diff --git a/examples/jit_cpp/histogram/step0_count_less_than/run_count_less_than.py b/examples/jit_cpp/histogram/step0_count_less_than/run_count_less_than.py new file mode 100644 index 00000000..2e6b6618 --- /dev/null +++ b/examples/jit_cpp/histogram/step0_count_less_than/run_count_less_than.py @@ -0,0 +1,49 @@ +import torch +import torch_npu # noqa + +from jit_util_count_less_than import jit_compile + + +def test_count_less_than(size_mult=1, repeat_runs=20, use_atomic_impl=False): + device = "npu:1" + dtype = torch.float32 + torch.npu.set_device(device) + + # Tile size is fixed in the kernel + tile_size = 512 + num_cores = torch.npu.get_device_properties(0).vector_core_num + num_tiles = num_cores * tile_size + total_len = num_tiles * size_mult + + # Create an input tensor bounded around our standard test pivots + x = torch.rand(size=(total_len,), device="npu", dtype=dtype).contiguous() + z = torch.zeros(1, device="npu", dtype=torch.int32) + pivot = torch.rand(1).item() + + # Golden PyTorch implementation + expected_count = (x < pivot).sum().to(torch.int32).cpu().item() + + if use_atomic_impl: + count_func = jit_compile("kernel_count_less_than_atomic.cpp") + else: + count_func = jit_compile("kernel_count_less_than.cpp") + + # NPU kernel execution, test to see if any race conditions occur across multiple runs + actual_count = [] + for _ in range(repeat_runs): + count_func(x, z, pivot, block_dim=num_cores) + actual_count.append(z.item()) + + torch.npu.synchronize() + + # Check for consistency across runs and correctness against the expected count + assert len(set(actual_count)) == 1, "Inconsistent results across runs" + assert ( + expected_count == actual_count[0] + ), f"Mismatch: expected {expected_count}, got {actual_count}" + + +if __name__ == "__main__": + # Atomic implementation has problems + # test_count_less_than(size_mult=64, repeat_runs=20, use_atomic_impl=True) + test_count_less_than(size_mult=64, repeat_runs=20, use_atomic_impl=False) diff --git a/examples/jit_cpp/histogram/step1_naive_histogram/kernel_histogram.cpp b/examples/jit_cpp/histogram/step1_naive_histogram/kernel_histogram.cpp new file mode 100644 index 00000000..6e1b6ea9 --- /dev/null +++ b/examples/jit_cpp/histogram/step1_naive_histogram/kernel_histogram.cpp @@ -0,0 +1,287 @@ +/** +Copyright (c) 2026 Huawei Technologies Co., Ltd. +All rights reserved. + +See LICENSE in the root of the software repository: +https://github.com/huawei-csl/pto-kernels/ +for the full License text. +*/ + +#include + +#include "../kernel_utils.h" + +using namespace pto; + +#ifndef HIST_TILE_SIZE +#define HIST_TILE_SIZE 1024 +#endif +constexpr uint32_t MAX_BINS = 256; +constexpr uint32_t MAX_BLOCKS = 64; + +/** + * runTLocalHistogram - Local, per-core histogram calculation + */ +template +AICORE void runTLocalHistogram(__gm__ T *x, __gm__ float *z_local, + const uint32_t total_length, + const int32_t num_bins, const float min_val, + const float max_val) { +#if __CCE_AICORE__ == 220 && defined(__DAV_C220_VEC__) + set_mask_norm(); + set_vector_mask(-1, -1); + + // --- Define Global Tensors --- + using InputGlobalData = pto::GlobalTensor, + pto::Stride<1, 1, 1, 1, 1>>; + using HistGlobalData = + pto::GlobalTensor, + pto::Stride<1, 1, 1, 1, 1>>; + + // --- Work Distribution --- + const uint32_t block_idx = get_block_idx(); + const uint32_t block_num = get_block_num(); + const uint32_t num_tiles_total = + kernel_utils::CeilDiv(total_length, TILE_SIZE); + const uint32_t num_tiles_per_core = + kernel_utils::CeilDiv(num_tiles_total, block_num); + const uint32_t start_idx = block_idx * num_tiles_per_core; + const uint32_t end_idx = (start_idx + num_tiles_per_core > num_tiles_total) + ? num_tiles_total + : (start_idx + num_tiles_per_core); + + // --- Define UB Tiles and Memory Layout --- + uint32_t addr = 0; + const uint32_t UB_X_ADDR = addr; + addr += TILE_SIZE * sizeof(T); + const uint32_t UB_CUR_MASK_ADDR = addr; + addr += TILE_SIZE * sizeof(uint8_t); + const uint32_t UB_CUR_F32_ADDR = addr; + addr += TILE_SIZE * sizeof(float); + const uint32_t UB_PREV_F32_ADDR = addr; + addr += TILE_SIZE * sizeof(float); + const uint32_t UB_BIN_F32_ADDR = addr; + addr += TILE_SIZE * sizeof(float); + const uint32_t UB_ONE_ADDR = addr; + addr += TILE_SIZE * sizeof(float); + const uint32_t UB_ZERO_ADDR = addr; + addr += TILE_SIZE * sizeof(float); + const uint32_t UB_REDUCE_TMP_ADDR = addr; + addr += TILE_SIZE * sizeof(float); + const uint32_t UB_COUNT_ADDR = addr; + addr += 8 * sizeof(float); + const uint32_t UB_LOCAL_HIST_ADDR = addr; + + InputGlobalData x_gm(x, {static_cast(total_length)}); + + using InputTileData = Tile; + InputTileData x_tile; + TASSIGN(x_tile, UB_X_ADDR); + + using MaskTileData = Tile; + MaskTileData current_mask; + TASSIGN(current_mask, UB_CUR_MASK_ADDR); + + // Float conversion tiles + using F32TileData = Tile; + F32TileData cur_f32; + TASSIGN(cur_f32, UB_CUR_F32_ADDR); + F32TileData prev_f32; + TASSIGN(prev_f32, UB_PREV_F32_ADDR); + F32TileData bin_mask_f32; + TASSIGN(bin_mask_f32, UB_BIN_F32_ADDR); + + F32TileData one_tile; + TASSIGN(one_tile, UB_ONE_ADDR); + TEXPANDS(one_tile, 1.0f); + F32TileData zero_tile; + TASSIGN(zero_tile, UB_ZERO_ADDR); + TEXPANDS(zero_tile, 0.0f); + + F32TileData reduce_tmp; + TASSIGN(reduce_tmp, UB_REDUCE_TMP_ADDR); + + using F32CountTile = + Tile; + F32CountTile count_f32_tile; + TASSIGN(count_f32_tile, UB_COUNT_ADDR); + + // Local histogram tile in UB + using HistTile = + Tile; + HistTile local_hist(num_bins); + TASSIGN(local_hist, UB_LOCAL_HIST_ADDR); + TEXPANDS(local_hist, 0.0f); + + const float bin_width = (max_val - min_val) / static_cast(num_bins); + + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + + // --- Main Calculation Loop --- + for (uint32_t tile_idx = start_idx; tile_idx < end_idx; ++tile_idx) { + int offset = tile_idx * TILE_SIZE; + TASSIGN(x_gm, x + offset); + + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + TLOAD(x_tile, x_gm); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + + // Generate packed bit-mask + TCMPS(current_mask, x_tile, static_cast(min_val), CmpMode::LT); + // Select 1.0f or 0.0f based on the packed bit-mask + TSEL(prev_f32, current_mask, one_tile, zero_tile); + + for (int32_t j = 0; j < num_bins; ++j) { + float bin_upper_bound = min_val + (j + 1) * bin_width; + CmpMode mode = (j == num_bins - 1) ? CmpMode::LE : CmpMode::LT; + + TCMPS(current_mask, x_tile, static_cast(bin_upper_bound), mode); + TSEL(cur_f32, current_mask, one_tile, zero_tile); + TSUB(bin_mask_f32, cur_f32, prev_f32); + + // Reduce the selected tile to get the count of elements less than pivot + // in this tile + TEXPANDS(count_f32_tile, 0.0f); + TEXPANDS(reduce_tmp, 0.0f); + TROWSUM(count_f32_tile, bin_mask_f32, reduce_tmp); + + // Scalar move to update UB local histogram + set_flag(PIPE_V, PIPE_S, EVENT_ID0); + wait_flag(PIPE_V, PIPE_S, EVENT_ID0); + float f_count = count_f32_tile.GetValue(0); + if (f_count > 0.0f) { + local_hist.SetValue(j, local_hist.GetValue(j) + f_count); + } + set_flag(PIPE_S, PIPE_V, EVENT_ID0); + wait_flag(PIPE_S, PIPE_V, EVENT_ID0); + + TMOV(prev_f32, cur_f32); + } + } + + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + + // --- Final Store to Global Memory --- + HistGlobalData z_gm(z_local + block_idx * num_bins, + {static_cast(block_num) * num_bins}); + + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(z_gm, local_hist); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + +#endif +} + +// Template parameter to avoid "no function" kernel launch error +template +AICORE void runTHistogramFinal(__gm__ float *z_local, __gm__ int32_t *z, + const int32_t num_bins, + const int32_t num_blocks) { +#if __CCE_AICORE__ == 220 && defined(__DAV_C220_VEC__) + set_mask_norm(); + set_vector_mask(-1, -1); + + if (get_block_idx() == 0) { + // --- Define Global Tensors --- + using InGlobalData = + pto::GlobalTensor, + pto::Stride<1, 1, 1, 1, 1>>; + using OutGlobalData = + pto::GlobalTensor, + pto::Stride<1, 1, 1, 1, 1>>; + + uint32_t addr = 0; + const uint32_t UB_IN_ADDR = addr; + addr += MAX_BLOCKS * MAX_BINS * sizeof(float); + const uint32_t UB_REDUCE_TMP_ADDR = addr; + addr += MAX_BINS * sizeof(float); + const uint32_t UB_FLOAT_OUT_ADDR = addr; + addr += MAX_BINS * sizeof(float); + const uint32_t UB_OUT_ADDR = addr; + addr += MAX_BINS * sizeof(int32_t); + + using InTile = Tile; + InTile in_tile( + {static_cast(num_blocks), static_cast(num_bins)}); + TASSIGN(in_tile, UB_IN_ADDR); + + using ReduceTmpTile = + Tile; + ReduceTmpTile reduce_tmp_tile(num_bins); + TASSIGN(reduce_tmp_tile, UB_REDUCE_TMP_ADDR); + TEXPANDS(reduce_tmp_tile, 0.0f); + + using FloatOutTile = + Tile; + FloatOutTile float_out_tile(num_bins); + TASSIGN(float_out_tile, UB_FLOAT_OUT_ADDR); + TEXPANDS(float_out_tile, 0.0f); + + using OutTile = Tile; + OutTile out_tile(num_bins); + TASSIGN(out_tile, UB_OUT_ADDR); + TEXPANDS(out_tile, static_cast(0)); + + // Load all block counts into UB row by row to match 2D tile padding + using InRowTile = + Tile; + InRowTile row_tile(static_cast(num_bins)); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + for (int32_t b = 0; b < num_blocks; ++b) { + InGlobalData z_local_gm(z_local + b * num_bins, {num_bins}); + TASSIGN(row_tile, UB_IN_ADDR + b * MAX_BINS * sizeof(float)); + TLOAD(row_tile, z_local_gm); + } + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + + // TCOLSUM reduces along the row dimension (num_blocks) + TCOLSUM(float_out_tile, in_tile, reduce_tmp_tile, true); + + TCVT(out_tile, float_out_tile, RoundMode::CAST_RINT); + + // --- Final Store to Global Memory --- + OutGlobalData z_gm(z, {num_bins}); + + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(z_gm, out_tile); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + } + +#endif +} + +__global__ AICORE void histogram_local_fp32( + __gm__ void *x, __gm__ void *z_local, const uint32_t in_length, + const int32_t num_bins, const float min_val, const float max_val) { + runTLocalHistogram((__gm__ float *)x, + (__gm__ float *)z_local, in_length, + num_bins, min_val, max_val); +} + +__global__ AICORE void histogram_final(__gm__ void *z_local, __gm__ void *z, + const int32_t num_bins, + const int32_t num_blocks) { + runTHistogramFinal<0>((__gm__ float *)z_local, (__gm__ int32_t *)z, num_bins, + num_blocks); +} + +extern "C" void histogram_fp32(uint32_t num_blocks, void *stream, void *x, + void *z_local, void *z, const uint32_t in_length, + const int32_t num_bins, const float min_val, + const float max_val) { + histogram_local_fp32<<>>( + x, z_local, in_length, num_bins, min_val, max_val); + histogram_final<<<1, nullptr, stream>>>(z_local, z, num_bins, num_blocks); +} diff --git a/examples/jit_cpp/histogram/step2_double_buffering/kernel_histogram.cpp b/examples/jit_cpp/histogram/step2_double_buffering/kernel_histogram.cpp new file mode 100644 index 00000000..636e2b82 --- /dev/null +++ b/examples/jit_cpp/histogram/step2_double_buffering/kernel_histogram.cpp @@ -0,0 +1,300 @@ +/** +Copyright (c) 2026 Huawei Technologies Co., Ltd. +All rights reserved. + +See LICENSE in the root of the software repository: +https://github.com/huawei-csl/pto-kernels/ +for the full License text. +*/ + +#include + +#include "../kernel_utils.h" + +using namespace pto; + +#ifndef HIST_TILE_SIZE +#define HIST_TILE_SIZE 1024 +#endif +constexpr uint32_t MAX_BINS = 256; +constexpr uint32_t MAX_BLOCKS = 64; + +/** + * runTLocalHistogram - Local, per-core histogram calculation + */ +template +AICORE void runTLocalHistogram(__gm__ T *x, __gm__ float *z_local, + const uint32_t total_length, + const int32_t num_bins, const float min_val, + const float max_val) { +#if __CCE_AICORE__ == 220 && defined(__DAV_C220_VEC__) + set_mask_norm(); + set_vector_mask(-1, -1); + + // --- Define Global Tensors --- + using InputGlobalData = pto::GlobalTensor, + pto::Stride<1, 1, 1, 1, 1>>; + using HistGlobalData = + pto::GlobalTensor, + pto::Stride<1, 1, 1, 1, 1>>; + + // --- Work Distribution --- + const uint32_t block_idx = get_block_idx(); + const uint32_t block_num = get_block_num(); + const uint32_t num_tiles_total = + kernel_utils::CeilDiv(total_length, TILE_SIZE); + const uint32_t num_tiles_per_core = + kernel_utils::CeilDiv(num_tiles_total, block_num); + const uint32_t start_idx = block_idx * num_tiles_per_core; + const uint32_t end_idx = (start_idx + num_tiles_per_core > num_tiles_total) + ? num_tiles_total + : (start_idx + num_tiles_per_core); + + // --- Define UB Tiles and Memory Layout --- + uint32_t addr = 0; + const uint32_t UB_X_PING = addr; + addr += TILE_SIZE * sizeof(T); + const uint32_t UB_X_PONG = addr; + addr += TILE_SIZE * sizeof(T); + const uint32_t UB_CUR_MASK_ADDR = addr; + addr += TILE_SIZE * sizeof(uint8_t); + const uint32_t UB_CUR_F32_ADDR = addr; + addr += TILE_SIZE * sizeof(float); + const uint32_t UB_PREV_F32_ADDR = addr; + addr += TILE_SIZE * sizeof(float); + const uint32_t UB_BIN_F32_ADDR = addr; + addr += TILE_SIZE * sizeof(float); + const uint32_t UB_ONE_ADDR = addr; + addr += TILE_SIZE * sizeof(float); + const uint32_t UB_ZERO_ADDR = addr; + addr += TILE_SIZE * sizeof(float); + const uint32_t UB_REDUCE_TMP_ADDR = addr; + addr += TILE_SIZE * sizeof(float); + const uint32_t UB_COUNT_ADDR = addr; + addr += 8 * sizeof(float); + const uint32_t UB_LOCAL_HIST_ADDR = addr; + + InputGlobalData x_gm(x, {static_cast(total_length)}); + + using InputTileData = Tile; + + using MaskTileData = Tile; + MaskTileData current_mask; + TASSIGN(current_mask, UB_CUR_MASK_ADDR); + + // Float conversion tiles + using F32TileData = Tile; + F32TileData cur_f32; + TASSIGN(cur_f32, UB_CUR_F32_ADDR); + F32TileData prev_f32; + TASSIGN(prev_f32, UB_PREV_F32_ADDR); + F32TileData bin_mask_f32; + TASSIGN(bin_mask_f32, UB_BIN_F32_ADDR); + + F32TileData one_tile; + TASSIGN(one_tile, UB_ONE_ADDR); + TEXPANDS(one_tile, 1.0f); + F32TileData zero_tile; + TASSIGN(zero_tile, UB_ZERO_ADDR); + TEXPANDS(zero_tile, 0.0f); + + F32TileData reduce_tmp; + TASSIGN(reduce_tmp, UB_REDUCE_TMP_ADDR); + + using F32CountTile = + Tile; + F32CountTile count_f32_tile; + TASSIGN(count_f32_tile, UB_COUNT_ADDR); + + // Local histogram tile in UB + using HistTile = + Tile; + HistTile local_hist(num_bins); + TASSIGN(local_hist, UB_LOCAL_HIST_ADDR); + TEXPANDS(local_hist, 0.0f); + + const float bin_width = (max_val - min_val) / static_cast(num_bins); + + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID1); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID1); + + // --- Main Calculation Loop --- + for (uint32_t tile_idx = start_idx, ping = 1; tile_idx < end_idx; + ++tile_idx) { + int offset = tile_idx * TILE_SIZE; + TASSIGN(x_gm, x + offset); + + const event_t ev = ping ? (event_t)EVENT_ID0 : (event_t)EVENT_ID1; + const unsigned x_base = ping ? UB_X_PING : UB_X_PONG; + + InputTileData x_tile; + TASSIGN(x_tile, x_base); + + wait_flag(PIPE_V, PIPE_MTE2, ev); + TLOAD(x_tile, x_gm); + set_flag(PIPE_MTE2, PIPE_V, ev); + wait_flag(PIPE_MTE2, PIPE_V, ev); + + // Generate packed bit-mask + TCMPS(current_mask, x_tile, static_cast(min_val), CmpMode::LT); + // Select 1.0f or 0.0f based on the packed bit-mask + TSEL(prev_f32, current_mask, one_tile, zero_tile); + + for (int32_t j = 0; j < num_bins; ++j) { + float bin_upper_bound = min_val + (j + 1) * bin_width; + CmpMode mode = (j == num_bins - 1) ? CmpMode::LE : CmpMode::LT; + + TCMPS(current_mask, x_tile, static_cast(bin_upper_bound), mode); + TSEL(cur_f32, current_mask, one_tile, zero_tile); + TSUB(bin_mask_f32, cur_f32, prev_f32); + + // Reduce the selected tile to get the count of elements less than pivot + // in this tile + TEXPANDS(count_f32_tile, 0.0f); + TEXPANDS(reduce_tmp, 0.0f); + TROWSUM(count_f32_tile, bin_mask_f32, reduce_tmp); + + // Scalar move to update UB local histogram + set_flag(PIPE_V, PIPE_S, EVENT_ID2); + wait_flag(PIPE_V, PIPE_S, EVENT_ID2); + float f_count = count_f32_tile.GetValue(0); + if (f_count > 0.0f) { + local_hist.SetValue(j, local_hist.GetValue(j) + f_count); + } + set_flag(PIPE_S, PIPE_V, EVENT_ID2); + wait_flag(PIPE_S, PIPE_V, EVENT_ID2); + + TMOV(prev_f32, cur_f32); + } + + set_flag(PIPE_V, PIPE_MTE2, ev); + ping = 1 - ping; + } + + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID1); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID1); + + // --- Final Store to Global Memory --- + HistGlobalData z_gm(z_local + block_idx * num_bins, + {static_cast(block_num) * num_bins}); + + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(z_gm, local_hist); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + +#endif +} + +// Template parameter to avoid "no function" kernel launch error +template +AICORE void runTHistogramFinal(__gm__ float *z_local, __gm__ int32_t *z, + const int32_t num_bins, + const int32_t num_blocks) { +#if __CCE_AICORE__ == 220 && defined(__DAV_C220_VEC__) + set_mask_norm(); + set_vector_mask(-1, -1); + + if (get_block_idx() == 0) { + // --- Define Global Tensors --- + using InGlobalData = + pto::GlobalTensor, + pto::Stride<1, 1, 1, 1, 1>>; + using OutGlobalData = + pto::GlobalTensor, + pto::Stride<1, 1, 1, 1, 1>>; + + uint32_t addr = 0; + const uint32_t UB_IN_ADDR = addr; + addr += MAX_BLOCKS * MAX_BINS * sizeof(float); + const uint32_t UB_REDUCE_TMP_ADDR = addr; + addr += MAX_BINS * sizeof(float); + const uint32_t UB_FLOAT_OUT_ADDR = addr; + addr += MAX_BINS * sizeof(float); + const uint32_t UB_OUT_ADDR = addr; + addr += MAX_BINS * sizeof(int32_t); + + using InTile = Tile; + InTile in_tile( + {static_cast(num_blocks), static_cast(num_bins)}); + TASSIGN(in_tile, UB_IN_ADDR); + + using ReduceTmpTile = + Tile; + ReduceTmpTile reduce_tmp_tile(num_bins); + TASSIGN(reduce_tmp_tile, UB_REDUCE_TMP_ADDR); + TEXPANDS(reduce_tmp_tile, 0.0f); + + using FloatOutTile = + Tile; + FloatOutTile float_out_tile(num_bins); + TASSIGN(float_out_tile, UB_FLOAT_OUT_ADDR); + TEXPANDS(float_out_tile, 0.0f); + + using OutTile = Tile; + OutTile out_tile(num_bins); + TASSIGN(out_tile, UB_OUT_ADDR); + TEXPANDS(out_tile, static_cast(0)); + + // Load all block counts into UB row by row to match 2D tile padding + using InRowTile = + Tile; + InRowTile row_tile(static_cast(num_bins)); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + for (int32_t b = 0; b < num_blocks; ++b) { + InGlobalData z_local_gm(z_local + b * num_bins, {num_bins}); + TASSIGN(row_tile, UB_IN_ADDR + b * MAX_BINS * sizeof(float)); + TLOAD(row_tile, z_local_gm); + } + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + + // TCOLSUM reduces along the row dimension (num_blocks) + TCOLSUM(float_out_tile, in_tile, reduce_tmp_tile, true); + + TCVT(out_tile, float_out_tile, RoundMode::CAST_RINT); + + // --- Final Store to Global Memory --- + OutGlobalData z_gm(z, {num_bins}); + + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(z_gm, out_tile); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + } + +#endif +} + +__global__ AICORE void histogram_local_fp32( + __gm__ void *x, __gm__ void *z_local, const uint32_t in_length, + const int32_t num_bins, const float min_val, const float max_val) { + runTLocalHistogram((__gm__ float *)x, + (__gm__ float *)z_local, in_length, + num_bins, min_val, max_val); +} + +__global__ AICORE void histogram_final(__gm__ void *z_local, __gm__ void *z, + const int32_t num_bins, + const int32_t num_blocks) { + runTHistogramFinal<0>((__gm__ float *)z_local, (__gm__ int32_t *)z, num_bins, + num_blocks); +} + +extern "C" void histogram_fp32(uint32_t num_blocks, void *stream, void *x, + void *z_local, void *z, const uint32_t in_length, + const int32_t num_bins, const float min_val, + const float max_val) { + histogram_local_fp32<<>>( + x, z_local, in_length, num_bins, min_val, max_val); + histogram_final<<<1, nullptr, stream>>>(z_local, z, num_bins, num_blocks); +} diff --git a/examples/jit_cpp/histogram/step3_scatter_index_to_gm/kernel_histogram.cpp b/examples/jit_cpp/histogram/step3_scatter_index_to_gm/kernel_histogram.cpp new file mode 100644 index 00000000..a5980052 --- /dev/null +++ b/examples/jit_cpp/histogram/step3_scatter_index_to_gm/kernel_histogram.cpp @@ -0,0 +1,240 @@ +/** +Copyright (c) 2026 Huawei Technologies Co., Ltd. +All rights reserved. + +See LICENSE in the root of the software repository: +https://github.com/huawei-csl/pto-kernels/ +for the full License text. +*/ + +#include + +#include "../kernel_utils.h" + +using namespace pto; + +#ifndef HIST_TILE_SIZE +#define HIST_TILE_SIZE 1024 +#endif +constexpr uint32_t MAX_BINS = 256; +constexpr uint32_t MAX_BLOCKS = 64; + +/** + * runTLocalHistogram - Local histogram using MSCATTER with AtomicAdd + */ +template +AICORE void runTLocalHistogram(__gm__ T *x, __gm__ float *z_local, + const uint32_t total_length, + const int32_t num_bins, const float min_val, + const float max_val) { +#if __CCE_AICORE__ == 220 && defined(__DAV_C220_VEC__) + set_mask_norm(); + set_vector_mask(-1, -1); + + // --- Define Global Tensors --- + using InputGlobalData = pto::GlobalTensor, + pto::Stride<1, 1, 1, 1, 1>>; + using HistGlobalData = + pto::GlobalTensor, + pto::Stride<1, 1, 1, 1, 1>>; + + // --- Work Distribution --- + const uint32_t block_idx = get_block_idx(); + const uint32_t block_num = get_block_num(); + const uint32_t num_tiles_total = + kernel_utils::CeilDiv(total_length, TILE_SIZE); + const uint32_t num_tiles_per_core = + kernel_utils::CeilDiv(num_tiles_total, block_num); + const uint32_t start_idx = block_idx * num_tiles_per_core; + const uint32_t end_idx = (start_idx + num_tiles_per_core > num_tiles_total) + ? num_tiles_total + : (start_idx + num_tiles_per_core); + + // --- Define UB Tiles and Memory Layout --- + uint32_t addr = 0; + const uint32_t UB_X_PING = addr; + addr += TILE_SIZE * sizeof(T); + const uint32_t UB_X_PONG = addr; + addr += TILE_SIZE * sizeof(T); + const uint32_t UB_F32_ADDR = addr; + addr += TILE_SIZE * sizeof(float); + const uint32_t UB_IDX_ADDR = addr; + addr += TILE_SIZE * sizeof(int32_t); + const uint32_t UB_ONES_ADDR = addr; + addr += TILE_SIZE * sizeof(float); + + InputGlobalData x_gm(x, {static_cast(total_length)}); + HistGlobalData z_gm(z_local + block_idx * num_bins, {num_bins}); + + using InputTileData = Tile; + using F32TileData = Tile; + using IdxTileData = Tile; + + F32TileData f32_tile; + TASSIGN(f32_tile, UB_F32_ADDR); + + IdxTileData idx_tile; + TASSIGN(idx_tile, UB_IDX_ADDR); + + F32TileData ones_tile; + TASSIGN(ones_tile, UB_ONES_ADDR); + TEXPANDS(ones_tile, 1.0f); + + const float inv_bin_width = + static_cast(num_bins) / (max_val - min_val); + + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID1); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID1); + + // --- Main Calculation Loop --- + for (uint32_t tile_idx = start_idx, ping = 1; tile_idx < end_idx; + ++tile_idx) { + int offset = tile_idx * TILE_SIZE; + TASSIGN(x_gm, x + offset); + + const event_t ev = ping ? (event_t)EVENT_ID0 : (event_t)EVENT_ID1; + const unsigned x_base = ping ? UB_X_PING : UB_X_PONG; + + InputTileData x_tile; + TASSIGN(x_tile, x_base); + + wait_flag(PIPE_V, PIPE_MTE2, ev); + TLOAD(x_tile, x_gm); + set_flag(PIPE_MTE2, PIPE_V, ev); + wait_flag(PIPE_MTE2, PIPE_V, ev); + + // Calculate bin index: idx = (val - min_val) * inv_bin_width + TADDS(f32_tile, x_tile, -min_val); + TMULS(f32_tile, f32_tile, inv_bin_width); + + // Convert to int32 index using Floor rounding + TCVT(idx_tile, f32_tile, RoundMode::CAST_FLOOR); + + // Clamp indices to [0, num_bins - 1] to handle edge cases and outliers + TMAXS(idx_tile, idx_tile, static_cast(0)); + TMINS(idx_tile, idx_tile, static_cast(num_bins - 1)); + + // Atomic Scatter-Add to global memory histogram + MSCATTER(z_gm, ones_tile, idx_tile); + + set_flag(PIPE_V, PIPE_MTE2, ev); + ping = 1 - ping; + } + + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID1); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID1); + +#endif +} + +// Template parameter to avoid "no function" kernel launch error +template +AICORE void runTHistogramFinal(__gm__ float *z_local, __gm__ int32_t *z, + const int32_t num_bins, + const int32_t num_blocks) { +#if __CCE_AICORE__ == 220 && defined(__DAV_C220_VEC__) + set_mask_norm(); + set_vector_mask(-1, -1); + + if (get_block_idx() == 0) { + // --- Define Global Tensors --- + using InGlobalData = + pto::GlobalTensor, + pto::Stride<1, 1, 1, 1, 1>>; + using OutGlobalData = + pto::GlobalTensor, + pto::Stride<1, 1, 1, 1, 1>>; + + uint32_t addr = 0; + const uint32_t UB_IN_ADDR = addr; + addr += MAX_BLOCKS * MAX_BINS * sizeof(float); + const uint32_t UB_REDUCE_TMP_ADDR = addr; + addr += MAX_BINS * sizeof(float); + const uint32_t UB_FLOAT_OUT_ADDR = addr; + addr += MAX_BINS * sizeof(float); + const uint32_t UB_OUT_ADDR = addr; + addr += MAX_BINS * sizeof(int32_t); + + using InTile = Tile; + InTile in_tile( + {static_cast(num_blocks), static_cast(num_bins)}); + TASSIGN(in_tile, UB_IN_ADDR); + + using ReduceTmpTile = + Tile; + ReduceTmpTile reduce_tmp_tile(num_bins); + TASSIGN(reduce_tmp_tile, UB_REDUCE_TMP_ADDR); + TEXPANDS(reduce_tmp_tile, 0.0f); + + using FloatOutTile = + Tile; + FloatOutTile float_out_tile(num_bins); + TASSIGN(float_out_tile, UB_FLOAT_OUT_ADDR); + TEXPANDS(float_out_tile, 0.0f); + + using OutTile = Tile; + OutTile out_tile(num_bins); + TASSIGN(out_tile, UB_OUT_ADDR); + TEXPANDS(out_tile, static_cast(0)); + + // Load all block counts into UB row by row to match 2D tile padding + using InRowTile = + Tile; + InRowTile row_tile(static_cast(num_bins)); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + for (int32_t b = 0; b < num_blocks; ++b) { + InGlobalData z_local_gm(z_local + b * num_bins, {num_bins}); + TASSIGN(row_tile, UB_IN_ADDR + b * MAX_BINS * sizeof(float)); + TLOAD(row_tile, z_local_gm); + } + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + + // TCOLSUM reduces along the row dimension (num_blocks) + TCOLSUM(float_out_tile, in_tile, reduce_tmp_tile, true); + + TCVT(out_tile, float_out_tile, RoundMode::CAST_RINT); + + // --- Final Store to Global Memory --- + OutGlobalData z_gm(z, {num_bins}); + + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(z_gm, out_tile); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + } + +#endif +} + +__global__ AICORE void histogram_local_fp32( + __gm__ void *x, __gm__ void *z_local, const uint32_t in_length, + const int32_t num_bins, const float min_val, const float max_val) { + runTLocalHistogram((__gm__ float *)x, + (__gm__ float *)z_local, in_length, + num_bins, min_val, max_val); +} + +__global__ AICORE void histogram_final(__gm__ void *z_local, __gm__ void *z, + const int32_t num_bins, + const int32_t num_blocks) { + runTHistogramFinal<0>((__gm__ float *)z_local, (__gm__ int32_t *)z, num_bins, + num_blocks); +} + +extern "C" void histogram_fp32(uint32_t num_blocks, void *stream, void *x, + void *z_local, void *z, const uint32_t in_length, + const int32_t num_bins, const float min_val, + const float max_val) { + histogram_local_fp32<<>>( + x, z_local, in_length, num_bins, min_val, max_val); + histogram_final<<<1, nullptr, stream>>>(z_local, z, num_bins, num_blocks); +} diff --git a/tests/test_histogram.py b/tests/test_histogram.py new file mode 100644 index 00000000..21b82662 --- /dev/null +++ b/tests/test_histogram.py @@ -0,0 +1,39 @@ +# -------------------------------------------------------------------------------- +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# All rights reserved. +# See LICENSE in the root of the software repository: +# https://github.com/huawei-csl/pto-kernels/ +# for the full License text. +# -------------------------------------------------------------------------------- + +import torch +from pto_kernels import pto_histogram +import pytest + + +@pytest.mark.parametrize("size", [1, 8, 50, 256, 1000]) +@pytest.mark.parametrize("bins", [2, 4, 32, 100, 256]) +# torch.float16 requires different boundary calculation to match torch implementeation +@pytest.mark.parametrize("dtype", [torch.float32], ids=str) +def test_pto_histogram(size: int, bins: int, dtype: torch.dtype): + # Tile size is fixed in the kernel + tile_size = 512 + num_cores = torch.npu.get_device_properties(0).vector_core_num + num_tiles = num_cores * tile_size + total_len = num_tiles * size + + x = torch.randint(high=bins, size=(total_len,), device="cpu", dtype=dtype) + + # Golden PyTorch implementation + y_cpu = torch.histc(x, bins=bins).float() + + # NPU kernel execution, test to see if any race conditions occur across multiple runs + x_npu = x.npu().contiguous() + y_npu = [] + repeat_runs = 100 + for _ in range(repeat_runs): + y_npu.append(pto_histogram(x_npu, bins=bins).cpu().float()) + + torch.npu.synchronize() + + assert all(torch.equal(y_cpu, y) for y in y_npu)