From c9790d25c1bbaead93e6e72b1e7174d11ef8b516 Mon Sep 17 00:00:00 2001 From: learning-chip Date: Tue, 31 Mar 2026 13:24:49 +0000 Subject: [PATCH 01/14] standalone fast inverse for quick hacking --- examples/jit_cpp/fast_inverse/README.md | 50 +++++ .../jit_cpp/fast_inverse/fast_inverse.cpp | 37 ++++ .../fast_inverse/jit_util_fast_inverse.py | 128 +++++++++++++ .../jit_cpp/fast_inverse/run_fast_inverse.py | 178 ++++++++++++++++++ 4 files changed, 393 insertions(+) create mode 100644 examples/jit_cpp/fast_inverse/README.md create mode 100644 examples/jit_cpp/fast_inverse/fast_inverse.cpp create mode 100644 examples/jit_cpp/fast_inverse/jit_util_fast_inverse.py create mode 100644 examples/jit_cpp/fast_inverse/run_fast_inverse.py diff --git a/examples/jit_cpp/fast_inverse/README.md b/examples/jit_cpp/fast_inverse/README.md new file mode 100644 index 00000000..f7c47fd7 --- /dev/null +++ b/examples/jit_cpp/fast_inverse/README.md @@ -0,0 +1,50 @@ +## fast_inverse — JIT triangular matrix inverse (recursive unroll) + +JIT-compiled example of `kernel_tri_inv_rec_unroll`, which inverts a batch of +upper-triangular fp16 matrices stored in a multi-dimensional tensor. + +### Algorithm + +Given an input tensor whose last two dimensions form an n×n upper-triangular +matrix U (off-diagonal part only; the diagonal is assumed to be all-ones), the +kernel computes the inverse of (U + I) for every matrix in the batch. + +The implementation uses a two-phase recursive approach on Ascend cube cores: + +1. **Inv-trick phase** – inverts each 16×16 diagonal fractal block via a + Neumann-series expansion (`X = (I − M) + (I − M)·M + …`). +2. **Unrolled recursion phase** – assembles partial inverses of progressively + larger sub-blocks until the full matrix is inverted. + +### Files + +| File | Purpose | +|------|---------| +| `fast_inverse.cpp` | Thin JIT wrapper: includes the kernel and exposes `call_kernel` | +| `jit_util_fast_inverse.py` | Compiles the kernel with `bisheng` and loads it via `ctypes` | +| `run_fast_inverse.py` | Correctness test suite (mirrors the pytest unit tests) | + +### Usage + +```bash +export PTO_LIB_PATH=/sources/pto-isa/ # need latest header, not CANN 8.5.0 default + +cd examples/jit_cpp/fast_inverse +python run_fast_inverse.py +``` + +The script compiles `fast_inverse.cpp` on first run (takes ~60 s), then +executes correctness checks across a range of matrix sizes (16, 32, 64, 128) +and batch configurations. + +### Supported matrix sizes + +`matrix_size` (last dimension of the input tensor) must be one of: **16, 32, +64, 128**. + +### Layout conventions + +| `num_bsnd_heads` | Memory layout | +|-----------------|---------------| +| `0` (default) | Each matrix stored consecutively in row-major order (`B × … × N × D × D`) | +| `> 0` | BSND layout: `(B, S, N, D)` where S is chunked into tiles of size D and N heads are interleaved | diff --git a/examples/jit_cpp/fast_inverse/fast_inverse.cpp b/examples/jit_cpp/fast_inverse/fast_inverse.cpp new file mode 100644 index 00000000..a82b5df9 --- /dev/null +++ b/examples/jit_cpp/fast_inverse/fast_inverse.cpp @@ -0,0 +1,37 @@ +/** +Copyright (c) 2026 Huawei Technologies Co., Ltd. +All rights reserved. + +See LICENSE in the root of the software repository: +https://github.com/huawei-csl/pto-kernels/ +for the full License text. +*/ + +// Include the triangular inverse kernel implementation. +// The build script adds csrc/kernel/ to the include path so that +// kernel_utils.h (included by kernel_tri_inv_rec_unroll.cpp) is found. +#include "kernel_tri_inv_rec_unroll.cpp" + +/** + * @brief JIT entry point for the triangular inverse (recursive unroll) kernel. + * + * @param blockDim Number of AI-Core blocks to launch. + * @param stream NPU stream handle. + * @param tensor_out fp32 output buffer (same element count as tensor_in). + * @param tensor_in fp16 input buffer holding the upper-triangular matrices + * (diagonal is assumed to be all-ones). + * @param minus_identity_in fp16 buffer of size matrix_size×matrix_size + * pre-filled with -I (negative identity). + * @param matrix_size Side length of each square matrix (16 / 32 / 64 / 128). + * @param num_matrices Total number of matrices to invert. + * @param num_bsnd_heads 0 for standard (B…ND) layout; + * N (number of heads) for BSND layout. + */ +extern "C" void call_kernel(uint32_t blockDim, void* stream, void* tensor_out, + void* tensor_in, void* minus_identity_in, + uint32_t matrix_size, uint32_t num_matrices, + uint32_t num_bsnd_heads) { + tri_inv_rec_unroll_fp16<<>>( + tensor_out, tensor_in, minus_identity_in, matrix_size, num_matrices, + num_bsnd_heads); +} diff --git a/examples/jit_cpp/fast_inverse/jit_util_fast_inverse.py b/examples/jit_cpp/fast_inverse/jit_util_fast_inverse.py new file mode 100644 index 00000000..6490ee66 --- /dev/null +++ b/examples/jit_cpp/fast_inverse/jit_util_fast_inverse.py @@ -0,0 +1,128 @@ +# -------------------------------------------------------------------------------- +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# All rights reserved. +# See LICENSE in the root of the software repository: +# https://github.com/huawei-csl/pto-kernels/ +# for the full License text. +# -------------------------------------------------------------------------------- + +import ctypes +import os +import subprocess + +import torch + +# --------------------------------------------------------------------------- +# Environment / paths +# --------------------------------------------------------------------------- +PTO_LIB_PATH = os.environ.get("PTO_LIB_PATH", os.environ["ASCEND_TOOLKIT_HOME"]) + +# Directory of this file → repo-root/examples/jit_cpp/fast_inverse +_THIS_DIR = os.path.dirname(os.path.abspath(__file__)) +# csrc/kernel lives three levels up from this file +_CSRC_KERNEL_DIR = os.path.abspath(os.path.join(_THIS_DIR, "../../../csrc/kernel")) + +BLOCK_DIM = int(getattr(torch.npu.get_device_properties("npu:0"), "cube_core_num", 20)) + + +# --------------------------------------------------------------------------- +# Compilation +# --------------------------------------------------------------------------- + +def compile_cpp(kernel_cpp: str, verbose: bool = False, timeout: int = 180) -> str: + """Compile *kernel_cpp* with bisheng and return the path to the .so.""" + lib_path = os.path.join(os.path.dirname(kernel_cpp), "fast_inverse_jit.so") + + flags = [ + "-fPIC", + "-shared", + "-xcce", + "-DMEMORY_BASE", + "-O2", + "-std=c++17", + # Resolve kernel_utils.h (included by kernel_tri_inv_rec_unroll.cpp) + f"-I{_CSRC_KERNEL_DIR}", + # PTO-ISA headers + f"-I{PTO_LIB_PATH}/include", + # Target the Ascend 910B cube core + "--cce-soc-version=Ascend910B4", + "--cce-soc-core-type=CubeCore", + ] + + command = ["bisheng", *flags, kernel_cpp, "-o", lib_path] + if verbose: + print("Compiling fast_inverse kernel:") + print(" ", " ".join(command)) + + try: + subprocess.run(command, timeout=timeout, check=True) + except Exception as exc: + raise RuntimeError(f"Compilation failed: {exc}") from exc + + if verbose: + print(f"Generated: {lib_path}") + return lib_path + + +# --------------------------------------------------------------------------- +# Loading +# --------------------------------------------------------------------------- + +def _torch_to_ctypes(tensor: torch.Tensor) -> ctypes.c_void_p: + return ctypes.c_void_p(tensor.data_ptr()) + + +def load_lib(lib_path: str): + """Load the compiled .so and return a Python callable for the kernel.""" + lib = ctypes.CDLL(os.path.abspath(lib_path)) + + lib.call_kernel.argtypes = [ + ctypes.c_uint32, # blockDim + ctypes.c_void_p, # stream + ctypes.c_void_p, # tensor_out (fp32) + ctypes.c_void_p, # tensor_in (fp16) + ctypes.c_void_p, # minus_identity_in (fp16) + ctypes.c_uint32, # matrix_size + ctypes.c_uint32, # num_matrices + ctypes.c_uint32, # num_bsnd_heads + ] + lib.call_kernel.restype = None + + def tri_inv_func( + tensor_out: torch.Tensor, + tensor_in: torch.Tensor, + minus_identity: torch.Tensor, + matrix_size: int, + num_matrices: int, + num_bsnd_heads: int = 0, + block_dim: int = BLOCK_DIM, + stream_ptr=None, + ): + if stream_ptr is None: + stream_ptr = torch.npu.current_stream()._as_parameter_ # noqa + effective_block_dim = min(block_dim, num_matrices) + lib.call_kernel( + effective_block_dim, + stream_ptr, + _torch_to_ctypes(tensor_out), + _torch_to_ctypes(tensor_in), + _torch_to_ctypes(minus_identity), + matrix_size, + num_matrices, + num_bsnd_heads, + ) + + return tri_inv_func + + +# --------------------------------------------------------------------------- +# Convenience: compile + load in one call +# --------------------------------------------------------------------------- + +def jit_compile(src_path: str, verbose: bool = True, clean_up: bool = False): + """Compile *src_path* and return the kernel callable.""" + lib_path = compile_cpp(src_path, verbose=verbose) + func = load_lib(lib_path) + if clean_up: + os.remove(lib_path) + return func diff --git a/examples/jit_cpp/fast_inverse/run_fast_inverse.py b/examples/jit_cpp/fast_inverse/run_fast_inverse.py new file mode 100644 index 00000000..345ffed7 --- /dev/null +++ b/examples/jit_cpp/fast_inverse/run_fast_inverse.py @@ -0,0 +1,178 @@ +# -------------------------------------------------------------------------------- +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# All rights reserved. +# See LICENSE in the root of the software repository: +# https://github.com/huawei-csl/pto-kernels/ +# for the full License text. +# -------------------------------------------------------------------------------- + +""" +Correctness tests for the JIT-compiled triangular inverse (recursive unroll) +kernel. Run from the fast_inverse/ directory: + + python run_fast_inverse.py +""" + +import numpy as np +import torch +import torch_npu # noqa: F401 – registers the NPU backend + +from jit_util_fast_inverse import jit_compile + +# --------------------------------------------------------------------------- +# Reproducibility +# --------------------------------------------------------------------------- +torch.manual_seed(42) +np.random.seed(42) + +# --------------------------------------------------------------------------- +# Matrix generators (identical to the unit-test suite) +# --------------------------------------------------------------------------- + +def random_triu_matrix(n, block_dim_x, block_dim_y, scale=0.1): + return scale * torch.triu(torch.rand((block_dim_x, block_dim_y, n, n)), diagonal=1) + + +def ones_triu_matrix(n, block_dim_x, block_dim_y): + return torch.triu(torch.ones((block_dim_x, block_dim_y, n, n)), diagonal=1) + + +def block_ones_triu_matrix(n, block_dim_x, block_dim_y): + U_ = np.ones((16, 16)) + n_blocks = n // 16 + U = np.zeros((block_dim_x, block_dim_y, n, n)) + for x in range(block_dim_x): + for y in range(block_dim_y): + for i in range(n_blocks): + s, e = i * 16, i * 16 + 16 + U[x, y, s:e, s:e] = U_ + return torch.from_numpy(np.triu(U, 1)) + + +def block_random_triu_matrix(n, block_dim_x, block_dim_y, scale=0.2): + U_ = np.triu(scale * np.random.rand(16, 16), k=1) + U = np.zeros((block_dim_x, block_dim_y, n, n)) + for x in range(block_dim_x): + for y in range(block_dim_y): + for i in range(0, n, 16): + U[x, y, i : i + 16, i : i + 16] = U_.copy() + return torch.from_numpy(U) + + +# --------------------------------------------------------------------------- +# Reference implementation (CPU / numpy) +# --------------------------------------------------------------------------- + +def linalg_inv_ref(U: torch.Tensor) -> torch.Tensor: + """Invert (U + I) for each matrix in the batch using numpy.""" + n = U.shape[-1] + identity = np.triu(np.tril(np.ones((n, n), dtype=np.double))) + out = np.zeros(U.shape) + for x in range(U.shape[0]): + for y in range(U.shape[1]): + out[x, y] = np.linalg.inv(U[x, y].numpy().astype(np.double) + identity) + return torch.from_numpy(out) + + +# --------------------------------------------------------------------------- +# Kernel helpers +# --------------------------------------------------------------------------- + +def _make_minus_identity(matrix_size: int, device: str) -> torch.Tensor: + I_neg = torch.zeros(matrix_size, matrix_size, dtype=torch.half, device=device) + I_neg.fill_diagonal_(-1) + return I_neg + + +def _run_kernel(tri_inv_func, U_fp16: torch.Tensor): + """ + Allocate output, build -I, run kernel, return fp64 CPU result. + + U_fp16 : (block_dim_x, block_dim_y, n, n) half tensor on NPU. + """ + matrix_size = U_fp16.shape[-1] + num_matrices = U_fp16.numel() // (matrix_size * matrix_size) + device = U_fp16.device + + tensor_out = torch.zeros_like(U_fp16, dtype=torch.float32) + I_neg = _make_minus_identity(matrix_size, str(device)) + + torch.npu.synchronize() + tri_inv_func(tensor_out, U_fp16, I_neg, matrix_size, num_matrices) + torch.npu.synchronize() + + return tensor_out.cpu().to(torch.float64) + + +# --------------------------------------------------------------------------- +# Single test +# --------------------------------------------------------------------------- + +def _test_case(tri_inv_func, U: torch.Tensor, atol: float, rtol: float, ftol: float, + label: str): + U_fp16 = U.to(torch.half) + golden = linalg_inv_ref(U_fp16) + + actual = _run_kernel(tri_inv_func, U_fp16.npu()) + + frob = torch.sqrt( + torch.sum((golden - actual) ** 2) / torch.sum(golden ** 2) + ).item() + + assert np.allclose( + actual.numpy(), golden.numpy(), atol=atol, rtol=rtol + ), f"[{label}] allclose failed — shape {U.shape}, rtol={rtol}" + assert frob <= ftol, f"[{label}] Frobenius error {frob:.2e} > {ftol:.2e}" + + print(f" PASS {label} frob={frob:.2e}") + + +# --------------------------------------------------------------------------- +# Test suite +# --------------------------------------------------------------------------- + +def run_tests(tri_inv_func): + cases = [ + ("block_ones", block_ones_triu_matrix, 0, 0, 0), + ("ones", ones_triu_matrix, 0, 0, 0), + ("block_random", block_random_triu_matrix, 5e-5, 0.1, 1e-4), + ("random", random_triu_matrix, 5e-5, 0.1, 1e-4), + ] + sizes = [16, 32, 64, 128] + x_dims = [1, 2, 4] + y_dims = [2, 4] + + total = passed = 0 + for n in sizes: + for bdx in x_dims: + for bdy in y_dims: + for name, gen, atol, rtol, ftol in cases: + total += 1 + label = f"n={n} x={bdx} y={bdy} [{name}]" + try: + U = gen(n, bdx, bdy) + _test_case(tri_inv_func, U, atol, rtol, ftol, label) + passed += 1 + except AssertionError as err: + print(f" FAIL {label}: {err}") + + print(f"\n{passed}/{total} tests passed.") + return passed == total + + +# --------------------------------------------------------------------------- +# Entry point +# --------------------------------------------------------------------------- + +if __name__ == "__main__": + import os + + torch.npu.set_device("npu:0") + + src = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fast_inverse.cpp") + print(f"Compiling {src} ...") + tri_inv_func = jit_compile(src) + print("Compilation successful.\n") + + ok = run_tests(tri_inv_func) + raise SystemExit(0 if ok else 1) From 06b6a29b01a8b11be9ffb595d9195e0297572d73 Mon Sep 17 00:00:00 2001 From: learning-chip Date: Tue, 31 Mar 2026 13:35:50 +0000 Subject: [PATCH 02/14] also test bsnd branch --- .../jit_cpp/fast_inverse/run_fast_inverse.py | 93 ++++++++++++++++++- 1 file changed, 89 insertions(+), 4 deletions(-) diff --git a/examples/jit_cpp/fast_inverse/run_fast_inverse.py b/examples/jit_cpp/fast_inverse/run_fast_inverse.py index 345ffed7..7022b67c 100644 --- a/examples/jit_cpp/fast_inverse/run_fast_inverse.py +++ b/examples/jit_cpp/fast_inverse/run_fast_inverse.py @@ -104,8 +104,33 @@ def _run_kernel(tri_inv_func, U_fp16: torch.Tensor): return tensor_out.cpu().to(torch.float64) +def _run_kernel_bsnd(tri_inv_func, U_bsnd_fp16: torch.Tensor): + """ + Run the kernel in BSND mode and return fp64 CPU result. + + U_bsnd_fp16 : (B, S, N, D) half tensor on NPU where each (D, D) block + along the S dimension is one matrix to invert. + """ + matrix_size = U_bsnd_fp16.shape[-1] # D + num_bsnd_heads = U_bsnd_fp16.shape[-2] # N + num_matrices = U_bsnd_fp16.numel() // (matrix_size * matrix_size) + device = U_bsnd_fp16.device + + tensor_out = torch.zeros_like(U_bsnd_fp16, dtype=torch.float32) + I_neg = _make_minus_identity(matrix_size, str(device)) + + torch.npu.synchronize() + tri_inv_func( + tensor_out, U_bsnd_fp16, I_neg, + matrix_size, num_matrices, num_bsnd_heads, + ) + torch.npu.synchronize() + + return tensor_out.cpu().to(torch.float64) + + # --------------------------------------------------------------------------- -# Single test +# Single test – standard layout # --------------------------------------------------------------------------- def _test_case(tri_inv_func, U: torch.Tensor, atol: float, rtol: float, ftol: float, @@ -127,6 +152,39 @@ def _test_case(tri_inv_func, U: torch.Tensor, atol: float, rtol: float, ftol: fl print(f" PASS {label} frob={frob:.2e}") +# --------------------------------------------------------------------------- +# Single test – BSND layout +# --------------------------------------------------------------------------- + +def _test_case_bsnd(tri_inv_func, U: torch.Tensor, B: int, S: int, N: int, D: int, + atol: float, rtol: float, ftol: float, label: str): + """ + U has shape (B*S//D, N, D, D) – the raw generator output. + It is converted to (B, S, N, D) before being fed to the kernel, mirroring + the original pytest test_tri_inv_rec_unroll_bsnd helper. + """ + U_fp16 = U.to(torch.half) + # Compute reference in (B*S//D, N, D, D) space, then reshape to (B, S, N, D) + golden = linalg_inv_ref(U_fp16) + golden = golden.transpose(1, 2).contiguous().reshape(B, S, N, D) + + # Transform input to BSND layout: (B*S//D, N, D, D) → (B, S, N, D) + U_bsnd = U_fp16.transpose(1, 2).contiguous().reshape(B, S, N, D) + + actual = _run_kernel_bsnd(tri_inv_func, U_bsnd.npu()) + + frob = torch.sqrt( + torch.sum((golden - actual) ** 2) / torch.sum(golden ** 2) + ).item() + + assert np.allclose( + actual.numpy(), golden.numpy(), atol=atol, rtol=rtol + ), f"[{label}] allclose failed — shape {U_bsnd.shape}, rtol={rtol}" + assert frob <= ftol, f"[{label}] Frobenius error {frob:.2e} > {ftol:.2e}" + + print(f" PASS {label} frob={frob:.2e}") + + # --------------------------------------------------------------------------- # Test suite # --------------------------------------------------------------------------- @@ -138,11 +196,15 @@ def run_tests(tri_inv_func): ("block_random", block_random_triu_matrix, 5e-5, 0.1, 1e-4), ("random", random_triu_matrix, 5e-5, 0.1, 1e-4), ] - sizes = [16, 32, 64, 128] - x_dims = [1, 2, 4] - y_dims = [2, 4] total = passed = 0 + + # -- Standard layout tests ----------------------------------------------- + print("=== Standard layout ===") + sizes = [16, 32, 64, 128] + x_dims = [1, 2, 4] + y_dims = [2, 4] + for n in sizes: for bdx in x_dims: for bdy in y_dims: @@ -156,6 +218,29 @@ def run_tests(tri_inv_func): except AssertionError as err: print(f" FAIL {label}: {err}") + # -- BSND layout tests --------------------------------------------------- + print("\n=== BSND layout ===") + # Keep a representative subset: S must be a multiple of D + bsnd_configs = [ + (B, S, N, D) + for B in [1, 4] + for S in [128, 256] + for N in [4, 8] + for D in [16, 32, 64, 128] + if S % D == 0 + ] + + for B, S, N, D in bsnd_configs: + for name, gen, atol, rtol, ftol in cases: + total += 1 + label = f"B={B} S={S} N={N} D={D} [{name}]" + try: + U = gen(D, B * S // D, N) + _test_case_bsnd(tri_inv_func, U, B, S, N, D, atol, rtol, ftol, label) + passed += 1 + except AssertionError as err: + print(f" FAIL {label}: {err}") + print(f"\n{passed}/{total} tests passed.") return passed == total From cf961b01ffe824423cf59058ba5b9db2d2c7d9a9 Mon Sep 17 00:00:00 2001 From: learning-chip Date: Tue, 31 Mar 2026 15:05:08 +0000 Subject: [PATCH 03/14] support varlen version of bsnd inverse --- examples/jit_cpp/fast_inverse/README.md | 11 +- .../jit_cpp/fast_inverse/fast_inverse.cpp | 7 +- .../fast_inverse/jit_util_fast_inverse.py | 10 + .../kernel_tri_inv_rec_unroll.cpp | 637 ++++++++++++++++++ .../jit_cpp/fast_inverse/run_fast_inverse.py | 254 ++++++- 5 files changed, 879 insertions(+), 40 deletions(-) create mode 100644 examples/jit_cpp/fast_inverse/kernel_tri_inv_rec_unroll.cpp diff --git a/examples/jit_cpp/fast_inverse/README.md b/examples/jit_cpp/fast_inverse/README.md index f7c47fd7..a2487dc3 100644 --- a/examples/jit_cpp/fast_inverse/README.md +++ b/examples/jit_cpp/fast_inverse/README.md @@ -22,7 +22,7 @@ The implementation uses a two-phase recursive approach on Ascend cube cores: |------|---------| | `fast_inverse.cpp` | Thin JIT wrapper: includes the kernel and exposes `call_kernel` | | `jit_util_fast_inverse.py` | Compiles the kernel with `bisheng` and loads it via `ctypes` | -| `run_fast_inverse.py` | Correctness test suite (mirrors the pytest unit tests) | +| `run_fast_inverse.py` | Correctness test suite, including aligned and varlen BSND coverage | ### Usage @@ -48,3 +48,12 @@ and batch configurations. |-----------------|---------------| | `0` (default) | Each matrix stored consecutively in row-major order (`B × … × N × D × D`) | | `> 0` | BSND layout: `(B, S, N, D)` where S is chunked into tiles of size D and N heads are interleaved | + +### Varlen BSND mode + +The standalone example also supports variable-length BSND inputs by padding each +sequence to the next multiple of `D` and passing a `chunk_indices` tensor to the +kernel. Each entry in `chunk_indices` is the padded row-start of one valid +chunk. The kernel still inverts dense `D x D` tiles; the Python harness pads +inputs before launch and slices the padded rows back away when validating the +result. diff --git a/examples/jit_cpp/fast_inverse/fast_inverse.cpp b/examples/jit_cpp/fast_inverse/fast_inverse.cpp index a82b5df9..ea79df4a 100644 --- a/examples/jit_cpp/fast_inverse/fast_inverse.cpp +++ b/examples/jit_cpp/fast_inverse/fast_inverse.cpp @@ -26,12 +26,15 @@ for the full License text. * @param num_matrices Total number of matrices to invert. * @param num_bsnd_heads 0 for standard (B…ND) layout; * N (number of heads) for BSND layout. + * @param chunk_indices Optional int32 pointer used only for varlen BSND. Each + * entry is the absolute row offset of one padded D x D + * chunk within the BSND tensor. */ extern "C" void call_kernel(uint32_t blockDim, void* stream, void* tensor_out, void* tensor_in, void* minus_identity_in, uint32_t matrix_size, uint32_t num_matrices, - uint32_t num_bsnd_heads) { + uint32_t num_bsnd_heads, void* chunk_indices) { tri_inv_rec_unroll_fp16<<>>( tensor_out, tensor_in, minus_identity_in, matrix_size, num_matrices, - num_bsnd_heads); + num_bsnd_heads, chunk_indices); } diff --git a/examples/jit_cpp/fast_inverse/jit_util_fast_inverse.py b/examples/jit_cpp/fast_inverse/jit_util_fast_inverse.py index 6490ee66..a6e205ec 100644 --- a/examples/jit_cpp/fast_inverse/jit_util_fast_inverse.py +++ b/examples/jit_cpp/fast_inverse/jit_util_fast_inverse.py @@ -85,6 +85,7 @@ def load_lib(lib_path: str): ctypes.c_uint32, # matrix_size ctypes.c_uint32, # num_matrices ctypes.c_uint32, # num_bsnd_heads + ctypes.c_void_p, # chunk_indices (optional int32 metadata) ] lib.call_kernel.restype = None @@ -95,11 +96,17 @@ def tri_inv_func( matrix_size: int, num_matrices: int, num_bsnd_heads: int = 0, + chunk_indices: torch.Tensor | None = None, block_dim: int = BLOCK_DIM, stream_ptr=None, ): if stream_ptr is None: stream_ptr = torch.npu.current_stream()._as_parameter_ # noqa + if chunk_indices is not None: + if chunk_indices.dtype != torch.int32: + raise TypeError("chunk_indices must be int32.") + if not chunk_indices.is_contiguous(): + raise ValueError("chunk_indices must be contiguous.") effective_block_dim = min(block_dim, num_matrices) lib.call_kernel( effective_block_dim, @@ -110,6 +117,9 @@ def tri_inv_func( matrix_size, num_matrices, num_bsnd_heads, + _torch_to_ctypes(chunk_indices) + if chunk_indices is not None + else ctypes.c_void_p(), ) return tri_inv_func diff --git a/examples/jit_cpp/fast_inverse/kernel_tri_inv_rec_unroll.cpp b/examples/jit_cpp/fast_inverse/kernel_tri_inv_rec_unroll.cpp new file mode 100644 index 00000000..654f54ab --- /dev/null +++ b/examples/jit_cpp/fast_inverse/kernel_tri_inv_rec_unroll.cpp @@ -0,0 +1,637 @@ +/** +Copyright (c) 2026 Huawei Technologies Co., Ltd. +All rights reserved. + +See LICENSE in the root of the software repository: +https://github.com/huawei-csl/pto-kernels/ +for the full License text. +*/ + +#ifndef MEMORY_BASE +#define MEMORY_BASE +#endif +#include + +#include "kernel_utils.h" + +#define GM_ADDR __gm__ uint8_t* // To avoid #include "kernel_operator.h" +using namespace pto; +using namespace kernel_utils; + +#define BSND_OFFSET(tile_id, N, S, D) \ + (((tile_id) / (N)) * (S) * (N) * (D) + ((tile_id) % (N)) * (D)) + +/* + * For varlen BSND, chunk_indices stores the absolute row offset (within the + * padded BSND tensor) of each D x D chunk. Each tile_id still enumerates + * chunk-major, then head-major. + */ +AICORE inline uint32_t GetBSNDTileOffset(uint32_t tile_id, + uint32_t num_bsnd_heads, + uint32_t matrix_size, + __gm__ int32_t* chunk_indices) { + const uint32_t head_idx = tile_id % num_bsnd_heads; + if (chunk_indices == nullptr) { + return BSND_OFFSET(tile_id, num_bsnd_heads, matrix_size, matrix_size); + } + const uint32_t chunk_idx = tile_id / num_bsnd_heads; + const uint32_t chunk_row_start = + static_cast(chunk_indices[chunk_idx]); + return chunk_row_start * num_bsnd_heads * matrix_size + + head_idx * matrix_size; +} + +/* + * @brief: Takes as input two matrices of size MatrixSize * MatrixSize each. + * The src matrix lies in L1, while the dst matrix lies either in L0A or L0B. + * This kernel copies only the diagonal blocks (fractals) of size FractalSize * + * FractalSize from the src matrix to the dst matrix. + * + * @tparam InputT Input data type (fp16). + * @tparam FractalSize Size of each fractal matrix (diagonal block). + * @tparam MatrixSize Size of the entire input/output matrices. + * @tparam SrcL1TileT The actual tile type of the src matrix. + * @tparam DstL0TileT The actual tile type of the dst matrix. + * + * @param src Tile in L1 memory. + * @param dst Tile in L0A or L0B memory. + */ +template +AICORE inline void CopyDiagonalFractalsL1ToL0(SrcL1TileT src, DstL0TileT dst) { + constexpr uint32_t NumFractals = MatrixSize / FractalSize; + constexpr bool is_left = + std::is_same_v>; + constexpr TileType LeftOrRight = is_left ? TileType::Left : TileType::Right; + constexpr SLayout InnerLayout = + is_left ? SLayout::RowMajor : SLayout::ColMajor; + + Tile + fractals[NumFractals]; + const std::uintptr_t starting_address = + reinterpret_cast(dst.data()); + for (uint32_t i = 0; i < NumFractals; ++i) { + TASSIGN(fractals[i], starting_address + i * FractalSize * + (MatrixSize + FractalSize) * + sizeof(InputT)); + TEXTRACT(fractals[i], src, i * FractalSize, i * FractalSize); + } +} + +/* + * @brief: Takes as input two matrices of size MatrixSize * MatrixSize each, + * and an integer block_size. The src matrix lies in L1, while the dst matrix + * either in L0A or L0B. This method copies some of the diagonal blocks from the + * input to the output as follows: + * - If dst is in L0A (left): copy even diagonal blocks 0, 2, 4, ... + * - If dst is in L0B (right): copy odd blocks 1, 3, 5, ... + * Important note: the dst matrix should be initialized to all-zeros before + * calling this method + * + * @tparam InputT Input data type (fp16). + * @tparam FractalSize Size of each fractal matrix (diagonal block). + * @tparam MatrixSize Size of the entire input/output matrices. + * @tparam SrcL1TileT The actual tile type of the src matrix. + * @tparam DstL0TileT The actual tile type of the dst matrix. + * + * @param src Tile in L1 memory. + * @param dst Tile in L0A or L0B memory. + * @param block_size Size of diagonal blocks. Needs: block_size >= FractalSize. + */ +template +AICORE inline void CopyOddOrEvenBlocksL1ToL0(SrcL1TileT src, DstL0TileT dst, + uint32_t block_size) { + constexpr bool is_left = + std::is_same_v>; + constexpr TileType LeftOrRight = is_left ? TileType::Left : TileType::Right; + constexpr SLayout InnerLayout = + is_left ? SLayout::RowMajor : SLayout::ColMajor; + + const uint32_t starting_block_index = is_left ? 0 : 1; + + const uint32_t num_blocks = MatrixSize / block_size; + const uint32_t num_fractals_per_block = block_size / FractalSize; + + Tile + fractals[MatrixSize / FractalSize]; + + const std::uintptr_t starting_address = + reinterpret_cast(dst.data()); + for (uint32_t i = 0; i < num_fractals_per_block; ++i) { + for (uint32_t j = 0; j < num_fractals_per_block; ++j) { + for (uint32_t b = starting_block_index; b < num_blocks; b += 2) { + const uint32_t offset = + b * (MatrixSize + FractalSize) * block_size + + i * MatrixSize * FractalSize + + j * FractalSize * FractalSize; + TASSIGN(fractals[b], starting_address + offset * sizeof(InputT)); + TEXTRACT(fractals[b], src, b * block_size + i * FractalSize, + b * block_size + j * FractalSize); + } + } + } +} + +/* + * @brief: Prepares Identity and Zeros matrix. + * + * @tparam TileL1AB The type of the input tiles in L1. + * @tparam TileL0A The type of the input tiles in L0A. + * @tparam TileL0B The type of the input tiles in L0B. + * @tparam TileL0C The type of the input tiles in L0C. + * + * @param I_neg_l1_tile Tile containing the -I (negative identity) matrix. + * @param Zero_l1_tile Tile to store the all-zero matrix. + * @param I_l1_tile Tile to store the identity matrix. + * @param a_l0_tile Tile in L0A for matmuls. + * @param b_l0_tile Tile in L0B for matmuls. + * @param c_l0_tile Tile in L0C for matmuls. + */ +template +AICORE inline void PrepareAuxiliaryMatrices( + TileL1AB I_neg_l1_tile, TileL1AB Zero_l1_tile, TileL1AB I_l1_tile, + TileL0A a_l0_tile, TileL0B b_l0_tile, TileL0C c_l0_tile) { + TMOV(a_l0_tile, I_neg_l1_tile); + TMOV(b_l0_tile, I_neg_l1_tile); + set_flag(PIPE_MTE1, PIPE_M, static_cast(0)); + wait_flag(PIPE_MTE1, PIPE_M, static_cast(0)); + + TMATMUL(c_l0_tile, a_l0_tile, b_l0_tile); + set_flag(PIPE_M, PIPE_FIX, static_cast(0)); + wait_flag(PIPE_M, PIPE_FIX, static_cast(0)); + + TMOV(I_l1_tile, c_l0_tile); + set_flag(PIPE_FIX, PIPE_MTE1, static_cast(0)); + wait_flag(PIPE_FIX, PIPE_MTE1, static_cast(0)); + + TMOV(b_l0_tile, I_l1_tile); + set_flag(PIPE_MTE1, PIPE_M, static_cast(0)); + wait_flag(PIPE_MTE1, PIPE_M, static_cast(0)); + + TMATMUL_ACC(c_l0_tile, c_l0_tile, a_l0_tile, b_l0_tile); + set_flag(PIPE_M, PIPE_FIX, static_cast(0)); + wait_flag(PIPE_M, PIPE_FIX, static_cast(0)); + + TMOV(Zero_l1_tile, c_l0_tile); + set_flag(PIPE_FIX, PIPE_MTE1, static_cast(0)); + wait_flag(PIPE_FIX, PIPE_MTE1, static_cast(0)); +} + +/* + * @brief: Inverts a single matrix / tile of the global tensor. + */ +template +AICORE inline void InvertSingleTile(TileL1AB X_l1_tile, TileL1AB I_l1_tile, + TileL1AB I_neg_l1_tile, + TileL1AB M_neg_l1_tile, + TileL1AB Zero_l1_tile, TileL1AB Y_l1_tile, + TileL0A* a_l0_tile, TileL0B* b_l0_tile, + TileL0C* c_l0_tile, + const uint32_t tile_id) { + const event_t event_0 = static_cast(tile_id); + const event_t event_1 = static_cast(tile_id + NumTilesPerCubeIter); + + TMOV(b_l0_tile[0], Y_l1_tile); + TMOV(a_l0_tile[0], I_neg_l1_tile); + set_flag(PIPE_MTE1, PIPE_M, event_0); + TMOV(a_l0_tile[1], Zero_l1_tile); + TMOV(b_l0_tile[1], Zero_l1_tile); + set_flag(PIPE_MTE1, PIPE_M, event_1); + wait_flag(PIPE_MTE1, PIPE_M, event_1); + set_flag(PIPE_M, PIPE_MTE1, event_1); + wait_flag(PIPE_M, PIPE_MTE1, event_1); + CopyDiagonalFractalsL1ToL0(Y_l1_tile, + a_l0_tile[1]); + CopyDiagonalFractalsL1ToL0(Y_l1_tile, + b_l0_tile[1]); + set_flag(PIPE_MTE1, PIPE_M, event_1); + + wait_flag(PIPE_MTE1, PIPE_M, event_0); + TMATMUL(c_l0_tile[0], a_l0_tile[0], b_l0_tile[0]); + set_flag(PIPE_M, PIPE_FIX, event_0); + set_flag(PIPE_M, PIPE_MTE1, event_0); + + wait_flag(PIPE_M, PIPE_FIX, event_0); + TMOV(M_neg_l1_tile, c_l0_tile[0]); + set_flag(PIPE_FIX, PIPE_M, event_0); + + wait_flag(PIPE_MTE1, PIPE_M, event_1); + set_flag(PIPE_MTE1, PIPE_M, event_1); + TMATMUL(c_l0_tile[1], a_l0_tile[1], b_l0_tile[1]); + set_flag(PIPE_M, PIPE_FIX, event_1); + wait_flag(PIPE_M, PIPE_FIX, event_1); + TMOV(Y_l1_tile, c_l0_tile[1]); + set_flag(PIPE_FIX, PIPE_M, event_1); + wait_flag(PIPE_FIX, PIPE_M, event_1); + + wait_flag(PIPE_M, PIPE_MTE1, event_0); + TMOV(b_l0_tile[0], I_neg_l1_tile); + TMOV(a_l0_tile[0], I_neg_l1_tile); + set_flag(PIPE_MTE1, PIPE_M, event_0); + + wait_flag(PIPE_MTE1, PIPE_M, event_0); + wait_flag(PIPE_FIX, PIPE_M, event_0); + wait_flag(PIPE_MTE1, PIPE_M, event_1); + TMATMUL(c_l0_tile[0], a_l0_tile[1], b_l0_tile[0]); + set_flag(PIPE_M, PIPE_FIX, event_0); + wait_flag(PIPE_M, PIPE_FIX, event_0); + set_flag(PIPE_FIX, PIPE_M, event_0); + wait_flag(PIPE_FIX, PIPE_M, event_0); + + TMATMUL_ACC(c_l0_tile[0], c_l0_tile[0], a_l0_tile[0], b_l0_tile[0]); + set_flag(PIPE_M, PIPE_FIX, event_1); + wait_flag(PIPE_M, PIPE_FIX, event_1); + TMOV(X_l1_tile, c_l0_tile[0]); + + set_flag(PIPE_FIX, PIPE_M, event_0); + set_flag(PIPE_M, PIPE_MTE1, event_0); + set_flag(PIPE_FIX, PIPE_MTE1, event_0); + set_flag(PIPE_FIX, PIPE_M, event_1); + set_flag(PIPE_M, PIPE_MTE1, event_1); + set_flag(PIPE_FIX, PIPE_MTE1, event_1); + for (uint32_t block_size = 1; block_size < FractalSize / 2; block_size *= 2) { + wait_flag(PIPE_M, PIPE_MTE1, event_0); + TMOV(b_l0_tile[0], I_l1_tile); + wait_flag(PIPE_FIX, PIPE_MTE1, event_0); + TMOV(a_l0_tile[0], X_l1_tile); + set_flag(PIPE_MTE1, PIPE_M, event_0); + + wait_flag(PIPE_FIX, PIPE_MTE1, event_1); + TMOV(b_l0_tile[1], Y_l1_tile); + set_flag(PIPE_MTE1, PIPE_M, event_1); + + wait_flag(PIPE_FIX, PIPE_M, event_0); + wait_flag(PIPE_MTE1, PIPE_M, event_0); + TMATMUL(c_l0_tile[0], a_l0_tile[0], b_l0_tile[0]); + set_flag(PIPE_M, PIPE_FIX, event_0); + wait_flag(PIPE_M, PIPE_FIX, event_0); + set_flag(PIPE_FIX, PIPE_M, event_0); + wait_flag(PIPE_FIX, PIPE_M, event_0); + + if (block_size < FractalSize / 4) { + wait_flag(PIPE_M, PIPE_MTE1, event_1); + TMOV(a_l0_tile[1], Y_l1_tile); + wait_flag(PIPE_MTE1, PIPE_M, event_1); + set_flag(PIPE_MTE1, PIPE_M, event_1); + + wait_flag(PIPE_MTE1, PIPE_M, event_1); + wait_flag(PIPE_FIX, PIPE_M, event_1); + TMATMUL(c_l0_tile[1], a_l0_tile[1], b_l0_tile[1]); + set_flag(PIPE_M, PIPE_MTE1, event_1); + set_flag(PIPE_M, PIPE_FIX, event_1); + set_flag(PIPE_MTE1, PIPE_M, event_1); + + wait_flag(PIPE_M, PIPE_FIX, event_1); + TMOV(Y_l1_tile, c_l0_tile[1]); + set_flag(PIPE_FIX, PIPE_M, event_1); + } + set_flag(PIPE_FIX, PIPE_MTE1, event_1); + + wait_flag(PIPE_MTE1, PIPE_M, event_1); + TMATMUL_ACC(c_l0_tile[0], c_l0_tile[0], a_l0_tile[0], b_l0_tile[1]); + set_flag(PIPE_M, PIPE_MTE1, event_0); + set_flag(PIPE_M, PIPE_FIX, event_0); + + wait_flag(PIPE_M, PIPE_FIX, event_0); + TMOV(X_l1_tile, c_l0_tile[0]); + set_flag(PIPE_FIX, PIPE_M, event_0); + set_flag(PIPE_FIX, PIPE_MTE1, event_0); + } + wait_flag(PIPE_FIX, PIPE_MTE1, event_1); + wait_flag(PIPE_M, PIPE_MTE1, event_1); + wait_flag(PIPE_FIX, PIPE_M, event_1); + wait_flag(PIPE_FIX, PIPE_MTE1, event_0); + wait_flag(PIPE_M, PIPE_MTE1, event_0); + wait_flag(PIPE_FIX, PIPE_M, event_0); + + TMOV(b_l0_tile[1], M_neg_l1_tile); + TMOV(a_l0_tile[0], I_l1_tile); + + if constexpr (MatrixSize > FractalSize) { + set_flag(PIPE_FIX, PIPE_M, event_1); + } + set_flag(PIPE_M, PIPE_MTE1, event_1); + set_flag(PIPE_M, PIPE_MTE1, event_0); + set_flag(PIPE_FIX, PIPE_MTE1, event_1); + set_flag(PIPE_FIX, PIPE_M, event_0); + for (uint32_t block_size = FractalSize; block_size < MatrixSize; + block_size *= 2) { + wait_flag(PIPE_M, PIPE_MTE1, event_0); + TMOV(a_l0_tile[1], Zero_l1_tile); + + wait_flag(PIPE_M, PIPE_MTE1, event_1); + TMOV(b_l0_tile[0], I_l1_tile); + set_flag(PIPE_MTE1, PIPE_M, event_0); + + wait_flag(PIPE_FIX, PIPE_MTE1, event_1); + CopyOddOrEvenBlocksL1ToL0(X_l1_tile, + a_l0_tile[1], + block_size); + set_flag(PIPE_MTE1, PIPE_M, event_1); + + wait_flag(PIPE_MTE1, PIPE_M, event_0); + wait_flag(PIPE_FIX, PIPE_M, event_0); + TMATMUL(c_l0_tile[0], a_l0_tile[0], b_l0_tile[0]); + + wait_flag(PIPE_MTE1, PIPE_M, event_1); + wait_flag(PIPE_FIX, PIPE_M, event_1); + TMATMUL(c_l0_tile[1], a_l0_tile[1], b_l0_tile[0]); + set_flag(PIPE_M, PIPE_MTE1, event_1); + + TMATMUL_ACC(c_l0_tile[0], c_l0_tile[0], a_l0_tile[1], b_l0_tile[1]); + set_flag(PIPE_M, PIPE_FIX, event_0); + set_flag(PIPE_M, PIPE_MTE1, event_0); + + wait_flag(PIPE_M, PIPE_FIX, event_0); + TMOV(Y_l1_tile, c_l0_tile[0]); + set_flag(PIPE_FIX, PIPE_MTE1, event_0); + set_flag(PIPE_FIX, PIPE_M, event_0); + + wait_flag(PIPE_M, PIPE_MTE1, event_1); + TMOV(b_l0_tile[0], Zero_l1_tile); + CopyOddOrEvenBlocksL1ToL0(X_l1_tile, + b_l0_tile[0], + block_size); + + wait_flag(PIPE_M, PIPE_MTE1, event_0); + wait_flag(PIPE_FIX, PIPE_MTE1, event_0); + TMOV(a_l0_tile[1], Y_l1_tile); + set_flag(PIPE_MTE1, PIPE_M, event_0); + + wait_flag(PIPE_MTE1, PIPE_M, event_0); + TMATMUL_ACC(c_l0_tile[1], c_l0_tile[1], a_l0_tile[1], b_l0_tile[0]); + set_flag(PIPE_M, PIPE_MTE1, event_0); + set_flag(PIPE_M, PIPE_MTE1, event_1); + set_flag(PIPE_M, PIPE_FIX, event_0); + wait_flag(PIPE_M, PIPE_FIX, event_0); + + if (block_size < MatrixSize / 2) { + TMOV(X_l1_tile, c_l0_tile[1]); + set_flag(PIPE_FIX, PIPE_M, event_1); + } + set_flag(PIPE_FIX, PIPE_MTE1, event_1); + } + wait_flag(PIPE_M, PIPE_MTE1, event_0); + wait_flag(PIPE_M, PIPE_MTE1, event_1); + wait_flag(PIPE_FIX, PIPE_M, event_0); + wait_flag(PIPE_FIX, PIPE_MTE1, event_1); +} + +/* + * @brief: Runs the main kernel (inverts all matrices in the tensor) + * + * When chunk_indices is non-null in BSND mode, it maps each chunk index + * (tile_id / num_bsnd_heads) to the absolute row offset in the padded BSND + * tensor. This lets the kernel support per-sequence padding without changing + * the per-tile inverse logic. + */ +template +AICORE inline void TriInvRecUnrollKernel(__gm__ OutputT* M_inv, + __gm__ InputT* M, __gm__ InputT* I_neg, + uint32_t total_tiles, + uint32_t num_bsnd_heads = 0, + __gm__ int32_t* chunk_indices = + nullptr) { + constexpr uint32_t TileLen = MatrixSize * MatrixSize; + constexpr uint32_t FractalSize = 16; + constexpr uint32_t NumL0Buffers = 2; + + if (get_block_idx() * NumTilesPerCubeIter >= total_tiles) { + return; + } + + using GlobalTileShapeIn = + TileShape2D; + using GlobalTileStridesIn = typename std::conditional< + !IsBSND, BaseShape2D, + Stride<1, 1, 1, -1, 1>>::type; + using GlobalTileIn = + GlobalTensor; + + using GlobalTileStridesINeg = + BaseShape2D; + using GlobalTileINeg = GlobalTensor; + + using GlobalTileShapeOut = + TileShape2D; + using GlobalTileStridesOut = typename std::conditional< + !IsBSND, BaseShape2D, + Stride<1, 1, 1, -1, 1>>::type; + using GlobalTileOut = GlobalTensor; + + using TileL1AB = + Tile; + + using TileL0A = TileLeft; + using TileL0B = TileRight; + using TileL0C = TileAcc; + + GlobalTileINeg I_neg_global_in(I_neg); + + TileL1AB X_l1_tile; + TileL1AB I_l1_tile; + TileL1AB I_neg_l1_tile; + TileL1AB M_neg_l1_tile; + TileL1AB Zero_l1_tile; + TileL1AB Y_l1_tile[NumTilesPerCubeIter]; + + TileL0A a_l0_tile[NumL0Buffers]; + TileL0B b_l0_tile[NumL0Buffers]; + TileL0C c_l0_tile[NumL0Buffers]; + + TASSIGN(I_l1_tile, 0x0); + TASSIGN(I_neg_l1_tile, 0x0 + TileLen * sizeof(InputT)); + TASSIGN(Zero_l1_tile, 0x0 + 2 * TileLen * sizeof(InputT)); + TASSIGN(M_neg_l1_tile, 0x0 + 3 * TileLen * sizeof(InputT)); + TASSIGN(X_l1_tile, 0x0 + 4 * TileLen * sizeof(InputT)); + for (uint32_t tile_id = 0; tile_id < NumTilesPerCubeIter; ++tile_id) { + TASSIGN(Y_l1_tile[tile_id], 0x0 + (5 + tile_id) * TileLen * sizeof(InputT)); + } + + for (uint32_t buffer_num = 0; buffer_num < NumL0Buffers; ++buffer_num) { + TASSIGN(a_l0_tile[buffer_num], 0x0 + buffer_num * TileLen * sizeof(InputT)); + TASSIGN(b_l0_tile[buffer_num], 0x0 + buffer_num * TileLen * sizeof(InputT)); + TASSIGN(c_l0_tile[buffer_num], + 0x0 + buffer_num * TileLen * sizeof(OutputT)); + } + TLOAD(I_neg_l1_tile, I_neg_global_in); + set_flag(PIPE_MTE2, PIPE_MTE1, static_cast(0)); + wait_flag(PIPE_MTE2, PIPE_MTE1, static_cast(0)); + + PrepareAuxiliaryMatrices( + I_neg_l1_tile, Zero_l1_tile, I_l1_tile, a_l0_tile[0], b_l0_tile[0], + c_l0_tile[0]); + + const uint32_t max_iters_per_aic = + CeilDiv(total_tiles, (uint32_t)(NumTilesPerCubeIter * get_block_num())); + + uint32_t next_tile_id_that_waits_for_pipe_fix_pipe_m = 0; + set_flag(PIPE_FIX, PIPE_M, + static_cast(next_tile_id_that_waits_for_pipe_fix_pipe_m)); + for (uint32_t tile_id = 0; tile_id < NumTilesPerCubeIter; ++tile_id) { + set_flag(PIPE_M, PIPE_MTE2, static_cast(tile_id)); + } + for (uint32_t cube_iter = 0; cube_iter < max_iters_per_aic; ++cube_iter) { + const uint32_t global_index = + (cube_iter * get_block_num() + get_block_idx()) * NumTilesPerCubeIter; + if (global_index >= total_tiles) { + break; + } + for (uint32_t tile_id = 0; (tile_id < NumTilesPerCubeIter) && + (global_index + tile_id < total_tiles); + ++tile_id) { + if constexpr (IsBSND) { + const uint32_t bsnd_offset = GetBSNDTileOffset( + global_index + tile_id, num_bsnd_heads, MatrixSize, chunk_indices); + GlobalTileIn M_global_in(M + bsnd_offset, {}, + {static_cast(MatrixSize * num_bsnd_heads)}); + wait_flag(PIPE_M, PIPE_MTE2, static_cast(tile_id)); + TLOAD(Y_l1_tile[tile_id], M_global_in); + } else { + GlobalTileIn M_global_in(M + (global_index + tile_id) * TileLen); + wait_flag(PIPE_M, PIPE_MTE2, static_cast(tile_id)); + TLOAD(Y_l1_tile[tile_id], M_global_in); + } + set_flag(PIPE_MTE2, PIPE_MTE1, static_cast(tile_id)); + } + + constexpr uint32_t final_c_buffer_index = MatrixSize > FractalSize ? 1 : 0; + for (uint32_t tile_id = 0; (tile_id < NumTilesPerCubeIter) && + (global_index + tile_id < total_tiles); + ++tile_id) { + wait_flag(PIPE_FIX, PIPE_M, static_cast(tile_id)); + wait_flag(PIPE_MTE2, PIPE_MTE1, static_cast(tile_id)); + + InvertSingleTile( + X_l1_tile, I_l1_tile, I_neg_l1_tile, M_neg_l1_tile, Zero_l1_tile, + Y_l1_tile[tile_id], a_l0_tile, b_l0_tile, c_l0_tile, tile_id); + + set_flag(PIPE_M, PIPE_MTE2, static_cast(tile_id)); + + if constexpr (IsBSND) { + const uint32_t bsnd_offset = GetBSNDTileOffset( + global_index + tile_id, num_bsnd_heads, MatrixSize, chunk_indices); + GlobalTileOut M_inv_global_out( + M_inv + bsnd_offset, {}, + {static_cast(MatrixSize * num_bsnd_heads)}); + TSTORE(M_inv_global_out, c_l0_tile[final_c_buffer_index]); + } else { + GlobalTileOut M_inv_global_out(M_inv + + (global_index + tile_id) * TileLen); + TSTORE(M_inv_global_out, c_l0_tile[final_c_buffer_index]); + } + next_tile_id_that_waits_for_pipe_fix_pipe_m = + (tile_id + 1) % NumTilesPerCubeIter; + set_flag( + PIPE_FIX, PIPE_M, + static_cast(next_tile_id_that_waits_for_pipe_fix_pipe_m)); + } + } + for (uint32_t tile_id = 0; tile_id < NumTilesPerCubeIter; ++tile_id) { + wait_flag(PIPE_M, PIPE_MTE2, static_cast(tile_id)); + } + wait_flag(PIPE_FIX, PIPE_M, + static_cast(next_tile_id_that_waits_for_pipe_fix_pipe_m)); +} + +template +AICORE void runKernelTriInvRecUnroll(__gm__ OutputT* M_inv, __gm__ InputT* M, + __gm__ InputT* I_neg, uint32_t total_tiles, + uint32_t num_bsnd_heads = 0, + __gm__ int32_t* chunk_indices = nullptr) { +#if (__CHECK_FEATURE_AT_PRECOMPILE) || \ + (__CCE_AICORE__ == 220 && defined(__DAV_C220_CUBE__)) + + TriInvRecUnrollKernel(M_inv, M, I_neg, total_tiles, num_bsnd_heads, + chunk_indices); +#else +// Nothing to do on AIV +#endif +} + +template +AICORE void run_tri_inv_rec_unroll(__gm__ float* tensor_out, + __gm__ InputT* tensor_in, + __gm__ InputT* minus_identity_in, + uint32_t matrix_size, uint32_t num_matrices, + uint32_t num_bsnd_heads, + __gm__ int32_t* chunk_indices) { + static_assert(std::is_same_v, + "tri_inv_rec_unroll supports only fp16."); + switch (matrix_size) { + case 16: + runKernelTriInvRecUnroll( + tensor_out, tensor_in, minus_identity_in, num_matrices, + num_bsnd_heads, chunk_indices); + break; + case 32: + runKernelTriInvRecUnroll( + tensor_out, tensor_in, minus_identity_in, num_matrices, + num_bsnd_heads, chunk_indices); + break; + case 64: + runKernelTriInvRecUnroll( + tensor_out, tensor_in, minus_identity_in, num_matrices, + num_bsnd_heads, chunk_indices); + break; + case 128: + runKernelTriInvRecUnroll( + tensor_out, tensor_in, minus_identity_in, num_matrices, + num_bsnd_heads, chunk_indices); + break; + } +} + +extern "C" __global__ AICORE void tri_inv_rec_unroll_fp16( + __gm__ void* tensor_out, __gm__ void* tensor_in, + __gm__ void* minus_identity_in, uint32_t matrix_size, uint32_t num_matrices, + uint32_t num_bsnd_heads, __gm__ void* chunk_indices) { + if (num_bsnd_heads == 0) { + if (num_matrices <= get_block_num()) { + run_tri_inv_rec_unroll( + (__gm__ float*)tensor_out, (__gm__ half*)tensor_in, + (__gm__ half*)minus_identity_in, matrix_size, num_matrices, + num_bsnd_heads, (__gm__ int32_t*)chunk_indices); + } else if (num_matrices <= 2 * get_block_num()) { + run_tri_inv_rec_unroll( + (__gm__ float*)tensor_out, (__gm__ half*)tensor_in, + (__gm__ half*)minus_identity_in, matrix_size, num_matrices, + num_bsnd_heads, (__gm__ int32_t*)chunk_indices); + } else { + run_tri_inv_rec_unroll( + (__gm__ float*)tensor_out, (__gm__ half*)tensor_in, + (__gm__ half*)minus_identity_in, matrix_size, num_matrices, + num_bsnd_heads, (__gm__ int32_t*)chunk_indices); + } + } else { + if (num_matrices <= get_block_num()) { + run_tri_inv_rec_unroll( + (__gm__ float*)tensor_out, (__gm__ half*)tensor_in, + (__gm__ half*)minus_identity_in, matrix_size, num_matrices, + num_bsnd_heads, (__gm__ int32_t*)chunk_indices); + } else if (num_matrices <= 2 * get_block_num()) { + run_tri_inv_rec_unroll( + (__gm__ float*)tensor_out, (__gm__ half*)tensor_in, + (__gm__ half*)minus_identity_in, matrix_size, num_matrices, + num_bsnd_heads, (__gm__ int32_t*)chunk_indices); + } else { + run_tri_inv_rec_unroll( + (__gm__ float*)tensor_out, (__gm__ half*)tensor_in, + (__gm__ half*)minus_identity_in, matrix_size, num_matrices, + num_bsnd_heads, (__gm__ int32_t*)chunk_indices); + } + } +} diff --git a/examples/jit_cpp/fast_inverse/run_fast_inverse.py b/examples/jit_cpp/fast_inverse/run_fast_inverse.py index 7022b67c..af0a9422 100644 --- a/examples/jit_cpp/fast_inverse/run_fast_inverse.py +++ b/examples/jit_cpp/fast_inverse/run_fast_inverse.py @@ -8,7 +8,7 @@ """ Correctness tests for the JIT-compiled triangular inverse (recursive unroll) -kernel. Run from the fast_inverse/ directory: +kernel. Run from the fast_inverse/ directory: python run_fast_inverse.py """ @@ -25,12 +25,15 @@ torch.manual_seed(42) np.random.seed(42) + # --------------------------------------------------------------------------- -# Matrix generators (identical to the unit-test suite) +# Matrix generators (identical to the unit-test suite) # --------------------------------------------------------------------------- - def random_triu_matrix(n, block_dim_x, block_dim_y, scale=0.1): - return scale * torch.triu(torch.rand((block_dim_x, block_dim_y, n, n)), diagonal=1) + return scale * torch.triu( + torch.rand((block_dim_x, block_dim_y, n, n)), + diagonal=1, + ) def ones_triu_matrix(n, block_dim_x, block_dim_y): @@ -60,24 +63,30 @@ def block_random_triu_matrix(n, block_dim_x, block_dim_y, scale=0.2): # --------------------------------------------------------------------------- -# Reference implementation (CPU / numpy) +# Reference implementation (CPU / numpy) # --------------------------------------------------------------------------- - def linalg_inv_ref(U: torch.Tensor) -> torch.Tensor: """Invert (U + I) for each matrix in the batch using numpy.""" n = U.shape[-1] - identity = np.triu(np.tril(np.ones((n, n), dtype=np.double))) - out = np.zeros(U.shape) + identity = np.eye(n, dtype=np.double) + out = np.zeros(U.shape, dtype=np.double) for x in range(U.shape[0]): for y in range(U.shape[1]): out[x, y] = np.linalg.inv(U[x, y].numpy().astype(np.double) + identity) return torch.from_numpy(out) +def invert_single_chunk_ref(U: torch.Tensor) -> torch.Tensor: + """Invert one upper-triangular chunk U where U is (..., m, m).""" + m = U.shape[-1] + return torch.from_numpy( + np.linalg.inv(U.numpy().astype(np.double) + np.eye(m, dtype=np.double)) + ) + + # --------------------------------------------------------------------------- # Kernel helpers # --------------------------------------------------------------------------- - def _make_minus_identity(matrix_size: int, device: str) -> torch.Tensor: I_neg = torch.zeros(matrix_size, matrix_size, dtype=torch.half, device=device) I_neg.fill_diagonal_(-1) @@ -104,37 +113,121 @@ def _run_kernel(tri_inv_func, U_fp16: torch.Tensor): return tensor_out.cpu().to(torch.float64) -def _run_kernel_bsnd(tri_inv_func, U_bsnd_fp16: torch.Tensor): +def _run_kernel_bsnd( + tri_inv_func, + U_bsnd_fp16: torch.Tensor, + chunk_indices: torch.Tensor | None = None, +): """ Run the kernel in BSND mode and return fp64 CPU result. U_bsnd_fp16 : (B, S, N, D) half tensor on NPU where each (D, D) block along the S dimension is one matrix to invert. + chunk_indices : optional int32 tensor containing the padded row start of + each valid chunk for varlen BSND inputs. """ - matrix_size = U_bsnd_fp16.shape[-1] # D - num_bsnd_heads = U_bsnd_fp16.shape[-2] # N + matrix_size = U_bsnd_fp16.shape[-1] + num_bsnd_heads = U_bsnd_fp16.shape[-2] num_matrices = U_bsnd_fp16.numel() // (matrix_size * matrix_size) device = U_bsnd_fp16.device tensor_out = torch.zeros_like(U_bsnd_fp16, dtype=torch.float32) I_neg = _make_minus_identity(matrix_size, str(device)) + if chunk_indices is not None: + chunk_indices = chunk_indices.to(device=device, dtype=torch.int32).contiguous() torch.npu.synchronize() tri_inv_func( - tensor_out, U_bsnd_fp16, I_neg, - matrix_size, num_matrices, num_bsnd_heads, + tensor_out, + U_bsnd_fp16, + I_neg, + matrix_size, + num_matrices, + num_bsnd_heads, + chunk_indices=chunk_indices, ) torch.npu.synchronize() return tensor_out.cpu().to(torch.float64) +def _build_varlen_bsnd_case( + gen, + cu_seqlens: list[int], + num_heads: int, + chunk_size: int, +): + """ + Build a padded BSND tensor plus reference output for varlen testing. + + Each sequence is padded independently to the next multiple of chunk_size. + chunk_indices records the padded row offset of every valid chunk. + """ + seq_lens = [ + cu_seqlens[i + 1] - cu_seqlens[i] + for i in range(len(cu_seqlens) - 1) + ] + print( + f" varlen sequence lengths: {seq_lens} " + f"(chunk_size={chunk_size}, num_heads={num_heads})" + ) + + total_tokens = cu_seqlens[-1] + num_chunks = sum( + (cu_seqlens[i + 1] - cu_seqlens[i] + chunk_size - 1) // chunk_size + for i in range(len(cu_seqlens) - 1) + ) + chunk_mats = gen(chunk_size, num_chunks, num_heads).to(torch.float64) + + padded_total = num_chunks * chunk_size + U_padded = torch.zeros((1, padded_total, num_heads, chunk_size), dtype=torch.float64) + golden = torch.zeros((1, total_tokens, num_heads, chunk_size), dtype=torch.float64) + + chunk_indices: list[int] = [] + chunk_infos: list[tuple[int, int, int]] = [] + chunk_idx = 0 + padded_row = 0 + + for seq_idx in range(len(cu_seqlens) - 1): + seq_start = cu_seqlens[seq_idx] + seq_end = cu_seqlens[seq_idx + 1] + for chunk_start in range(seq_start, seq_end, chunk_size): + actual_size = min(chunk_size, seq_end - chunk_start) + chunk = chunk_mats[chunk_idx] + for head_idx in range(num_heads): + U_valid = chunk[head_idx, :actual_size, :actual_size] + U_padded[ + 0, + padded_row : padded_row + actual_size, + head_idx, + :actual_size, + ] = U_valid + golden[ + 0, + chunk_start : chunk_start + actual_size, + head_idx, + :actual_size, + ] = invert_single_chunk_ref(U_valid) + + chunk_indices.append(padded_row) + chunk_infos.append((padded_row, chunk_start, actual_size)) + padded_row += chunk_size + chunk_idx += 1 + + return U_padded, golden, chunk_infos, torch.tensor(chunk_indices, dtype=torch.int32) + + # --------------------------------------------------------------------------- # Single test – standard layout # --------------------------------------------------------------------------- - -def _test_case(tri_inv_func, U: torch.Tensor, atol: float, rtol: float, ftol: float, - label: str): +def _test_case( + tri_inv_func, + U: torch.Tensor, + atol: float, + rtol: float, + ftol: float, + label: str, +): U_fp16 = U.to(torch.half) golden = linalg_inv_ref(U_fp16) @@ -145,7 +238,10 @@ def _test_case(tri_inv_func, U: torch.Tensor, atol: float, rtol: float, ftol: fl ).item() assert np.allclose( - actual.numpy(), golden.numpy(), atol=atol, rtol=rtol + actual.numpy(), + golden.numpy(), + atol=atol, + rtol=rtol, ), f"[{label}] allclose failed — shape {U.shape}, rtol={rtol}" assert frob <= ftol, f"[{label}] Frobenius error {frob:.2e} > {ftol:.2e}" @@ -155,22 +251,27 @@ def _test_case(tri_inv_func, U: torch.Tensor, atol: float, rtol: float, ftol: fl # --------------------------------------------------------------------------- # Single test – BSND layout # --------------------------------------------------------------------------- - -def _test_case_bsnd(tri_inv_func, U: torch.Tensor, B: int, S: int, N: int, D: int, - atol: float, rtol: float, ftol: float, label: str): +def _test_case_bsnd( + tri_inv_func, + U: torch.Tensor, + B: int, + S: int, + N: int, + D: int, + atol: float, + rtol: float, + ftol: float, + label: str, +): """ U has shape (B*S//D, N, D, D) – the raw generator output. - It is converted to (B, S, N, D) before being fed to the kernel, mirroring - the original pytest test_tri_inv_rec_unroll_bsnd helper. + It is converted to (B, S, N, D) before being fed to the kernel. """ U_fp16 = U.to(torch.half) - # Compute reference in (B*S//D, N, D, D) space, then reshape to (B, S, N, D) golden = linalg_inv_ref(U_fp16) golden = golden.transpose(1, 2).contiguous().reshape(B, S, N, D) - # Transform input to BSND layout: (B*S//D, N, D, D) → (B, S, N, D) U_bsnd = U_fp16.transpose(1, 2).contiguous().reshape(B, S, N, D) - actual = _run_kernel_bsnd(tri_inv_func, U_bsnd.npu()) frob = torch.sqrt( @@ -178,30 +279,83 @@ def _test_case_bsnd(tri_inv_func, U: torch.Tensor, B: int, S: int, N: int, D: in ).item() assert np.allclose( - actual.numpy(), golden.numpy(), atol=atol, rtol=rtol + actual.numpy(), + golden.numpy(), + atol=atol, + rtol=rtol, ), f"[{label}] allclose failed — shape {U_bsnd.shape}, rtol={rtol}" assert frob <= ftol, f"[{label}] Frobenius error {frob:.2e} > {ftol:.2e}" print(f" PASS {label} frob={frob:.2e}") +def _test_case_bsnd_varlen( + tri_inv_func, + gen, + cu_seqlens: list[int], + N: int, + D: int, + atol: float, + rtol: float, + ftol: float, + label: str, +): + U_padded, golden, chunk_infos, chunk_indices = _build_varlen_bsnd_case( + gen, + cu_seqlens, + N, + D, + ) + actual_padded = _run_kernel_bsnd( + tri_inv_func, + U_padded.to(torch.half).npu(), + chunk_indices=chunk_indices.npu(), + ) + + actual = torch.zeros_like(golden) + for padded_row, token_row, actual_size in chunk_infos: + actual[ + :, + token_row : token_row + actual_size, + :, + :actual_size, + ] = actual_padded[ + :, + padded_row : padded_row + actual_size, + :, + :actual_size, + ] + + frob = torch.sqrt( + torch.sum((golden - actual) ** 2) / torch.sum(golden ** 2) + ).item() + + assert np.allclose( + actual.numpy(), + golden.numpy(), + atol=atol, + rtol=rtol, + ), f"[{label}] allclose failed — shape {actual.shape}, rtol={rtol}" + assert frob <= ftol, f"[{label}] Frobenius error {frob:.2e} > {ftol:.2e}" + + print(f" PASS {label} frob={frob:.2e}") + + # --------------------------------------------------------------------------- # Test suite # --------------------------------------------------------------------------- - def run_tests(tri_inv_func): cases = [ - ("block_ones", block_ones_triu_matrix, 0, 0, 0), - ("ones", ones_triu_matrix, 0, 0, 0), - ("block_random", block_random_triu_matrix, 5e-5, 0.1, 1e-4), - ("random", random_triu_matrix, 5e-5, 0.1, 1e-4), + ("block_ones", block_ones_triu_matrix, 0, 0, 0), + ("ones", ones_triu_matrix, 0, 0, 0), + ("block_random", block_random_triu_matrix, 5e-5, 0.1, 1e-4), + ("random", random_triu_matrix, 5e-5, 0.1, 1e-4), ] total = passed = 0 - # -- Standard layout tests ----------------------------------------------- print("=== Standard layout ===") - sizes = [16, 32, 64, 128] + sizes = [16, 32, 64, 128] x_dims = [1, 2, 4] y_dims = [2, 4] @@ -218,9 +372,7 @@ def run_tests(tri_inv_func): except AssertionError as err: print(f" FAIL {label}: {err}") - # -- BSND layout tests --------------------------------------------------- print("\n=== BSND layout ===") - # Keep a representative subset: S must be a multiple of D bsnd_configs = [ (B, S, N, D) for B in [1, 4] @@ -241,6 +393,35 @@ def run_tests(tri_inv_func): except AssertionError as err: print(f" FAIL {label}: {err}") + print("\n=== BSND varlen layout ===") + varlen_configs = [ + (4, 16, [0, 15]), + (4, 32, [0, 256, 500, 1000]), + (4, 64, [0, 15, 100, 300, 1200, 2000]), + (4, 16, [0, 1, 100, 300, 1200, 2048]), + (4, 32, [0, 200, 512, 1200, 2048]), + ] + + for N, D, cu_seqlens in varlen_configs: + for name, gen, atol, rtol, ftol in cases: + total += 1 + label = f"N={N} D={D} cu={cu_seqlens} [{name}]" + try: + _test_case_bsnd_varlen( + tri_inv_func, + gen, + cu_seqlens, + N, + D, + atol, + rtol, + ftol, + label, + ) + passed += 1 + except AssertionError as err: + print(f" FAIL {label}: {err}") + print(f"\n{passed}/{total} tests passed.") return passed == total @@ -248,7 +429,6 @@ def run_tests(tri_inv_func): # --------------------------------------------------------------------------- # Entry point # --------------------------------------------------------------------------- - if __name__ == "__main__": import os From 0a61a7212deb803b3224ade974252153941bced8 Mon Sep 17 00:00:00 2001 From: learning-chip Date: Tue, 31 Mar 2026 15:53:13 +0000 Subject: [PATCH 04/14] add bandwidth benchmark for varlen inverse kernel --- examples/jit_cpp/fast_inverse/README.md | 24 + .../benchmark_bsnd_fast_inverse.py | 640 ++++++++++++++++++ 2 files changed, 664 insertions(+) create mode 100644 examples/jit_cpp/fast_inverse/benchmark_bsnd_fast_inverse.py diff --git a/examples/jit_cpp/fast_inverse/README.md b/examples/jit_cpp/fast_inverse/README.md index a2487dc3..b5a6cdfb 100644 --- a/examples/jit_cpp/fast_inverse/README.md +++ b/examples/jit_cpp/fast_inverse/README.md @@ -23,6 +23,7 @@ The implementation uses a two-phase recursive approach on Ascend cube cores: | `fast_inverse.cpp` | Thin JIT wrapper: includes the kernel and exposes `call_kernel` | | `jit_util_fast_inverse.py` | Compiles the kernel with `bisheng` and loads it via `ctypes` | | `run_fast_inverse.py` | Correctness test suite, including aligned and varlen BSND coverage | +| `benchmark_bsnd_fast_inverse.py` | Benchmarks fixed BSND vs varlen-uniform BSND and plots effective bandwidth | ### Usage @@ -57,3 +58,26 @@ kernel. Each entry in `chunk_indices` is the padded row-start of one valid chunk. The kernel still inverts dense `D x D` tiles; the Python harness pads inputs before launch and slices the padded rows back away when validating the result. + +### Benchmark + +To compare the original fixed-length BSND path against the new varlen path in a +matched-size sanity check: + +```bash +export PTO_LIB_PATH=/sources/pto-isa/ + +cd examples/jit_cpp/fast_inverse +python benchmark_bsnd_fast_inverse.py --chunk-size 64 +``` + +The benchmark script: + +- runs only the PTO-ISA BSND kernel +- compares `bsnd-fixed` against `bsnd-varlen-uniform` +- uses uniform `cu_seqlens=[0, T, 2T, ...]` so both paths process the same + total data size +- reports numerical agreement between the two outputs +- also generates a true-varlen benchmark that plots scattered bandwidth points + against aggregated sequence length +- writes all CSV and PNG artifacts into `examples/jit_cpp/fast_inverse/benchmark_results/` diff --git a/examples/jit_cpp/fast_inverse/benchmark_bsnd_fast_inverse.py b/examples/jit_cpp/fast_inverse/benchmark_bsnd_fast_inverse.py new file mode 100644 index 00000000..7fa4820c --- /dev/null +++ b/examples/jit_cpp/fast_inverse/benchmark_bsnd_fast_inverse.py @@ -0,0 +1,640 @@ +#!/usr/bin/env python3 +# -------------------------------------------------------------------------------- +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# All rights reserved. +# See LICENSE in the root of the software repository: +# https://github.com/huawei-csl/pto-kernels/ +# for the full License text. +# -------------------------------------------------------------------------------- + +""" +Benchmark the standalone BSND fast-inverse kernel. + +This script only benchmarks the PTO-ISA BSND kernel in two modes: + +1. `bsnd-fixed`: + Original aligned BSND layout with shape `(B, T, H, D)`. +2. `bsnd-varlen-uniform`: + The new varlen path using packed shape `(1, B*T, H, D)` with uniform + `cu_seqlens = [0, T, 2T, ...]`. + +The two modes use the same total token count and the same underlying chunk data, +so their latency / effective bandwidth can be compared directly. The script also +checks that both modes produce numerically matching results. +""" + +from __future__ import annotations + +import argparse +import csv +import math +import os +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +import torch +import torch_npu # noqa: F401 + +from jit_util_fast_inverse import jit_compile + + +DEFAULT_SEQLENS = (512, 1024, 2048, 4096, 8192, 16384) +DEFAULT_CACHE_SIZE = 256 * 1024 * 1024 +NPU_DEVICE = os.getenv("GDN_TRI_INVERSE_NPU_DEVICE", "npu:0") +THIS_DIR = Path(__file__).resolve().parent +RESULTS_DIR = THIS_DIR / "benchmark_results" +DEFAULT_TRUE_VARLEN_SAMPLES = 6 + + +def parse_int_list(spec: str) -> tuple[int, ...]: + parts = [p.strip() for p in spec.split(",") if p.strip()] + if not parts: + raise argparse.ArgumentTypeError("expected at least one integer") + try: + return tuple(int(p, 10) for p in parts) + except ValueError as exc: + raise argparse.ArgumentTypeError(f"invalid integer list {spec!r}: {exc}") from exc + + +def make_minus_identity(matrix_size: int, device: str) -> torch.Tensor: + minus_identity = torch.zeros( + matrix_size, + matrix_size, + dtype=torch.half, + device=device, + ) + minus_identity.fill_diagonal_(-1) + return minus_identity + + +def random_chunk_mats( + total_chunks: int, + num_heads: int, + chunk_size: int, + scale: float, + device: str, +) -> torch.Tensor: + return scale * torch.triu( + torch.rand( + (total_chunks, num_heads, chunk_size, chunk_size), + dtype=torch.half, + device=device, + ), + diagonal=1, + ) + + +def build_fixed_bsnd_input( + chunk_mats: torch.Tensor, + batch_size: int, + seqlen: int, + num_heads: int, + chunk_size: int, +) -> torch.Tensor: + return ( + chunk_mats.transpose(1, 2) + .contiguous() + .reshape(batch_size, seqlen, num_heads, chunk_size) + ) + + +def build_uniform_varlen_input( + fixed_input: torch.Tensor, + batch_size: int, + seqlen: int, + chunk_size: int, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + total_tokens = batch_size * seqlen + packed_input = fixed_input.reshape(1, total_tokens, fixed_input.shape[2], chunk_size).contiguous() + cu_seqlens = torch.arange( + 0, + total_tokens + 1, + seqlen, + dtype=torch.int32, + device=fixed_input.device, + ) + chunk_indices = torch.arange( + 0, + total_tokens, + chunk_size, + dtype=torch.int32, + device=fixed_input.device, + ) + return packed_input, cu_seqlens, chunk_indices + + +def sample_true_varlen_lengths( + batch_size: int, + aggregated_tokens: int, + rng: np.random.Generator, +) -> list[int]: + if aggregated_tokens < batch_size: + raise ValueError("aggregated_tokens must be >= batch_size.") + + remaining = aggregated_tokens - batch_size + while True: + weights = rng.dirichlet(np.ones(batch_size)) + extras = np.floor(weights * remaining).astype(np.int64) + deficit = remaining - int(extras.sum()) + if deficit > 0: + extras[:deficit] += 1 + lengths = (extras + 1).tolist() + if any(length != lengths[0] for length in lengths): + return lengths + + +def build_true_varlen_input( + seq_lens: list[int], + num_heads: int, + chunk_size: int, + scale: float, + device: str, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + cu_seqlens = np.cumsum([0, *seq_lens], dtype=np.int64) + num_chunks = sum((seq_len + chunk_size - 1) // chunk_size for seq_len in seq_lens) + chunk_mats = random_chunk_mats( + total_chunks=num_chunks, + num_heads=num_heads, + chunk_size=chunk_size, + scale=scale, + device=device, + ) + + padded_total = num_chunks * chunk_size + packed_input = torch.zeros( + (1, padded_total, num_heads, chunk_size), + dtype=torch.half, + device=device, + ) + chunk_indices: list[int] = [] + chunk_idx = 0 + padded_row = 0 + + for seq_len in seq_lens: + for chunk_start in range(0, seq_len, chunk_size): + actual_size = min(chunk_size, seq_len - chunk_start) + chunk = chunk_mats[chunk_idx] + for head_idx in range(num_heads): + packed_input[ + 0, + padded_row : padded_row + actual_size, + head_idx, + :actual_size, + ] = chunk[head_idx, :actual_size, :actual_size] + chunk_indices.append(padded_row) + padded_row += chunk_size + chunk_idx += 1 + + return ( + packed_input.contiguous(), + torch.tensor(cu_seqlens.tolist(), dtype=torch.int32, device=device), + torch.tensor(chunk_indices, dtype=torch.int32, device=device), + ) + + +def make_fixed_runner( + tri_inv_func, + tensor_in: torch.Tensor, +) -> tuple[callable, torch.Tensor]: + matrix_size = tensor_in.shape[-1] + num_bsnd_heads = tensor_in.shape[-2] + num_matrices = tensor_in.numel() // (matrix_size * matrix_size) + tensor_out = torch.empty_like(tensor_in, dtype=torch.float32) + minus_identity = make_minus_identity(matrix_size, str(tensor_in.device)) + + def run(): + tri_inv_func( + tensor_out, + tensor_in, + minus_identity, + matrix_size, + num_matrices, + num_bsnd_heads, + ) + + return run, tensor_out + + +def make_varlen_runner( + tri_inv_func, + tensor_in: torch.Tensor, + chunk_indices: torch.Tensor, +) -> tuple[callable, torch.Tensor]: + matrix_size = tensor_in.shape[-1] + num_bsnd_heads = tensor_in.shape[-2] + num_matrices = tensor_in.numel() // (matrix_size * matrix_size) + tensor_out = torch.empty_like(tensor_in, dtype=torch.float32) + minus_identity = make_minus_identity(matrix_size, str(tensor_in.device)) + + def run(): + tri_inv_func( + tensor_out, + tensor_in, + minus_identity, + matrix_size, + num_matrices, + num_bsnd_heads, + chunk_indices=chunk_indices, + ) + + return run, tensor_out + + +def benchmark_ms( + fn, + warmup_iters: int, + benchmark_iters: int, + device: str, +) -> list[float]: + start_events = [torch.npu.Event(enable_timing=True) for _ in range(benchmark_iters)] + end_events = [torch.npu.Event(enable_timing=True) for _ in range(benchmark_iters)] + + torch.npu.synchronize() + for _ in range(warmup_iters): + fn() + torch.npu.synchronize() + + cache = torch.ones(DEFAULT_CACHE_SIZE, dtype=torch.int8, device=device) + times_ms: list[float] = [] + for idx in range(benchmark_iters): + cache.zero_() + torch.npu.synchronize() + start_events[idx].record() + fn() + end_events[idx].record() + end_events[idx].synchronize() + times_ms.append(start_events[idx].elapsed_time(end_events[idx])) + return times_ms + + +def add_bandwidth_fields(row: dict[str, float | int | str], input_dtype_bytes: int = 2) -> None: + size_elems = int(row.get("valid_numel", row["numel"])) + mem_bytes = size_elems * (input_dtype_bytes + 4) + row["mem_bytes"] = mem_bytes + row["bw_gbs"] = (mem_bytes / 1e9) / (float(row["time_us"]) / 1e6) + + +def accuracy_metrics(reference: torch.Tensor, candidate: torch.Tensor) -> tuple[float, float]: + ref = reference.detach().cpu().to(torch.float64) + cand = candidate.detach().cpu().to(torch.float64) + diff = ref - cand + max_abs = diff.abs().max().item() + denom = torch.sum(ref * ref).item() + rel_frob = 0.0 if denom == 0 else math.sqrt(torch.sum(diff * diff).item() / denom) + return max_abs, rel_frob + + +def write_csv(csv_path: Path, rows: list[dict[str, float | int | str]]) -> None: + csv_path.parent.mkdir(parents=True, exist_ok=True) + fieldnames = [ + "inverse_type", + "dtype", + "B", + "T", + "aggregated_T", + "padded_T", + "H", + "numel", + "valid_numel", + "chunk_size", + "time_us", + "mem_bytes", + "bw_gbs", + "max_abs_diff_to_fixed", + "rel_frob_diff_to_fixed", + "sample_id", + "seq_lens", + ] + with csv_path.open("w", newline="", encoding="utf-8") as handle: + writer = csv.DictWriter(handle, fieldnames=fieldnames) + writer.writeheader() + for row in rows: + writer.writerow(row) + + +def plot_bandwidth(plot_path: Path, rows: list[dict[str, float | int | str]], batch_size: int, num_heads: int, chunk_size: int) -> None: + plot_path.parent.mkdir(parents=True, exist_ok=True) + fixed_rows = [row for row in rows if row["inverse_type"] == "bsnd-fixed"] + varlen_rows = [row for row in rows if row["inverse_type"] == "bsnd-varlen-uniform"] + + fig, ax = plt.subplots(figsize=(7.5, 5.0)) + ax.plot( + [int(row["T"]) / 1000.0 for row in fixed_rows], + [float(row["bw_gbs"]) for row in fixed_rows], + marker="o", + linewidth=2, + label="BSND fixed", + ) + ax.plot( + [int(row["T"]) / 1000.0 for row in varlen_rows], + [float(row["bw_gbs"]) for row in varlen_rows], + marker="s", + linewidth=2, + label="BSND varlen-uniform", + ) + ax.set_xlabel("Sequence length T (K)") + ax.set_ylabel("Effective bandwidth (GB/s)") + ax.set_title( + f"Fast inverse BSND bandwidth\n" + f"(batch={batch_size}, head={num_heads}, chunk_size={chunk_size})" + ) + ax.set_ylim(bottom=0) + ax.grid(alpha=0.25) + ax.legend() + fig.tight_layout() + fig.savefig(plot_path, dpi=150) + plt.close(fig) + + +def plot_true_varlen_scatter( + plot_path: Path, + rows: list[dict[str, float | int | str]], + batch_size: int, + num_heads: int, + chunk_size: int, +) -> None: + plot_path.parent.mkdir(parents=True, exist_ok=True) + fig, ax = plt.subplots(figsize=(7.5, 5.0)) + ax.scatter( + [int(row["aggregated_T"]) for row in rows], + [float(row["bw_gbs"]) for row in rows], + alpha=0.8, + s=32, + ) + ax.set_xlabel("Aggregated sequence length") + ax.set_ylabel("Effective bandwidth (GB/s)") + ax.set_title( + f"Fast inverse true-varlen BSND bandwidth\n" + f"(batch={batch_size}, head={num_heads}, chunk_size={chunk_size})" + ) + ax.set_ylim(bottom=0) + ax.grid(alpha=0.25) + fig.tight_layout() + fig.savefig(plot_path, dpi=150) + plt.close(fig) + + +def main() -> None: + parser = argparse.ArgumentParser(description="Benchmark standalone BSND fast-inverse kernel.") + parser.add_argument("--warmup", type=int, default=5) + parser.add_argument("--repeats", type=int, default=20) + parser.add_argument("--B", type=int, default=32, help="Dense BSND batch size.") + parser.add_argument("--H", type=int, default=4, help="Number of BSND heads.") + parser.add_argument("--chunk-size", type=int, default=64) + parser.add_argument( + "--seqlens", + type=parse_int_list, + default=DEFAULT_SEQLENS, + metavar="T[,T,...]", + help=( + "Comma-separated dense per-sequence lengths to benchmark " + f"(default: {','.join(map(str, DEFAULT_SEQLENS))})" + ), + ) + parser.add_argument("--scale", type=float, default=0.1) + parser.add_argument( + "--csv", + type=str, + default="", + help="Optional CSV output path. Defaults to bench_results_bsnd_fast_inverse_.csv", + ) + parser.add_argument( + "--plot", + type=str, + default="", + help="Optional plot output path. Defaults to bench_results_bsnd_fast_inverse_bw_.png", + ) + parser.add_argument( + "--true-varlen-csv", + type=str, + default="", + help="Optional CSV path for true-varlen benchmark points.", + ) + parser.add_argument( + "--true-varlen-plot", + type=str, + default="", + help="Optional scatter plot path for true-varlen benchmark points.", + ) + parser.add_argument( + "--true-varlen-samples", + type=int, + default=DEFAULT_TRUE_VARLEN_SAMPLES, + help="Number of random true-varlen batches per aggregated sequence length.", + ) + args = parser.parse_args() + + torch.npu.set_device(NPU_DEVICE) + + src = THIS_DIR / "fast_inverse.cpp" + print(f"Compiling {src} ...") + tri_inv_func = jit_compile(str(src)) + print("Compilation successful.\n") + + RESULTS_DIR.mkdir(parents=True, exist_ok=True) + csv_path = ( + Path(args.csv) + if args.csv + else RESULTS_DIR / f"bench_results_bsnd_fast_inverse_{args.chunk_size}.csv" + ) + plot_path = ( + Path(args.plot) + if args.plot + else RESULTS_DIR / f"bench_results_bsnd_fast_inverse_bw_{args.chunk_size}.png" + ) + true_varlen_csv_path = ( + Path(args.true_varlen_csv) + if args.true_varlen_csv + else RESULTS_DIR / f"bench_results_bsnd_fast_inverse_true_varlen_{args.chunk_size}.csv" + ) + true_varlen_plot_path = ( + Path(args.true_varlen_plot) + if args.true_varlen_plot + else RESULTS_DIR / f"bench_results_bsnd_fast_inverse_true_varlen_bw_{args.chunk_size}.png" + ) + + rows: list[dict[str, float | int | str]] = [] + true_varlen_rows: list[dict[str, float | int | str]] = [] + rng = np.random.default_rng(42) + + for seqlen in args.seqlens: + if seqlen % args.chunk_size != 0: + print( + f"Skipping T={seqlen}: requires T to be a multiple of chunk_size={args.chunk_size} " + "for matched fixed vs uniform-varlen comparison." + ) + continue + + total_chunks = args.B * seqlen // args.chunk_size + total_tokens = args.B * seqlen + print( + f"Profiling T={seqlen}, total_tokens={total_tokens}, " + f"B={args.B}, H={args.H}, chunk_size={args.chunk_size}" + ) + + chunk_mats = random_chunk_mats( + total_chunks=total_chunks, + num_heads=args.H, + chunk_size=args.chunk_size, + scale=args.scale, + device=NPU_DEVICE, + ) + fixed_input = build_fixed_bsnd_input( + chunk_mats, + batch_size=args.B, + seqlen=seqlen, + num_heads=args.H, + chunk_size=args.chunk_size, + ) + varlen_input, cu_seqlens, chunk_indices = build_uniform_varlen_input( + fixed_input, + batch_size=args.B, + seqlen=seqlen, + chunk_size=args.chunk_size, + ) + + print(f" uniform cu_seqlens: {cu_seqlens.cpu().tolist()}") + + fixed_run, fixed_out = make_fixed_runner(tri_inv_func, fixed_input) + varlen_run, varlen_out = make_varlen_runner(tri_inv_func, varlen_input, chunk_indices) + + fixed_run() + varlen_run() + torch.npu.synchronize() + + packed_fixed_out = fixed_out.reshape(1, total_tokens, args.H, args.chunk_size) + max_abs_diff, rel_frob_diff = accuracy_metrics(packed_fixed_out, varlen_out) + print( + f" accuracy vs fixed: max_abs_diff={max_abs_diff:.3e}, " + f"rel_frob_diff={rel_frob_diff:.3e}" + ) + + fixed_times_ms = benchmark_ms( + fixed_run, + warmup_iters=args.warmup, + benchmark_iters=args.repeats, + device=NPU_DEVICE, + ) + varlen_times_ms = benchmark_ms( + varlen_run, + warmup_iters=args.warmup, + benchmark_iters=args.repeats, + device=NPU_DEVICE, + ) + + fixed_row = { + "inverse_type": "bsnd-fixed", + "dtype": "fp16", + "B": args.B, + "T": seqlen, + "aggregated_T": total_tokens, + "padded_T": total_tokens, + "H": args.H, + "numel": fixed_input.numel(), + "valid_numel": fixed_input.numel(), + "chunk_size": args.chunk_size, + "time_us": int(round(np.mean(fixed_times_ms) * 1000.0)), + "max_abs_diff_to_fixed": 0.0, + "rel_frob_diff_to_fixed": 0.0, + "sample_id": "", + "seq_lens": "", + } + add_bandwidth_fields(fixed_row) + + varlen_row = { + "inverse_type": "bsnd-varlen-uniform", + "dtype": "fp16", + "B": args.B, + "T": seqlen, + "aggregated_T": total_tokens, + "padded_T": total_tokens, + "H": args.H, + "numel": varlen_input.numel(), + "valid_numel": total_tokens * args.H * args.chunk_size, + "chunk_size": args.chunk_size, + "time_us": int(round(np.mean(varlen_times_ms) * 1000.0)), + "max_abs_diff_to_fixed": max_abs_diff, + "rel_frob_diff_to_fixed": rel_frob_diff, + "sample_id": "", + "seq_lens": ",".join([str(seqlen)] * args.B), + } + add_bandwidth_fields(varlen_row) + + rows.extend([fixed_row, varlen_row]) + print( + f" fixed: time_us={fixed_row['time_us']}, bw_gbs={fixed_row['bw_gbs']:.2f} | " + f"varlen-uniform: time_us={varlen_row['time_us']}, bw_gbs={varlen_row['bw_gbs']:.2f}" + ) + + for sample_idx in range(args.true_varlen_samples): + seq_lens = sample_true_varlen_lengths(args.B, total_tokens, rng) + packed_input, cu_seqlens, chunk_indices = build_true_varlen_input( + seq_lens=seq_lens, + num_heads=args.H, + chunk_size=args.chunk_size, + scale=args.scale, + device=NPU_DEVICE, + ) + varlen_run_true, _ = make_varlen_runner( + tri_inv_func, + packed_input, + chunk_indices, + ) + times_ms = benchmark_ms( + varlen_run_true, + warmup_iters=args.warmup, + benchmark_iters=args.repeats, + device=NPU_DEVICE, + ) + row = { + "inverse_type": "bsnd-varlen-true", + "dtype": "fp16", + "B": args.B, + "T": seqlen, + "aggregated_T": total_tokens, + "padded_T": int(packed_input.shape[1]), + "H": args.H, + "numel": packed_input.numel(), + "valid_numel": total_tokens * args.H * args.chunk_size, + "chunk_size": args.chunk_size, + "time_us": int(round(np.mean(times_ms) * 1000.0)), + "max_abs_diff_to_fixed": "", + "rel_frob_diff_to_fixed": "", + "sample_id": sample_idx, + "seq_lens": ",".join(map(str, seq_lens)), + } + add_bandwidth_fields(row) + true_varlen_rows.append(row) + print( + f" true-varlen sample={sample_idx}: aggregated_T={total_tokens}, " + f"padded_T={row['padded_T']}, bw_gbs={row['bw_gbs']:.2f}" + ) + + if not rows: + raise RuntimeError("No benchmark rows were generated.") + + write_csv(csv_path, rows) + plot_bandwidth( + plot_path, + rows, + batch_size=args.B, + num_heads=args.H, + chunk_size=args.chunk_size, + ) + write_csv(true_varlen_csv_path, true_varlen_rows) + plot_true_varlen_scatter( + true_varlen_plot_path, + true_varlen_rows, + batch_size=args.B, + num_heads=args.H, + chunk_size=args.chunk_size, + ) + print(f"\nWrote CSV: {csv_path}") + print(f"Wrote plot: {plot_path}") + print(f"Wrote true-varlen CSV: {true_varlen_csv_path}") + print(f"Wrote true-varlen plot: {true_varlen_plot_path}") + + +if __name__ == "__main__": + main() From 04c33df9ca6d191d4cfa58a602472362cfa5b743 Mon Sep 17 00:00:00 2001 From: learning-chip Date: Tue, 31 Mar 2026 16:29:48 +0000 Subject: [PATCH 05/14] paritial load/store in kernel to avoid slow torch padding --- .../benchmark_bsnd_fast_inverse.py | 41 ++- .../jit_cpp/fast_inverse/fast_inverse.cpp | 11 +- .../fast_inverse/jit_util_fast_inverse.py | 10 + .../kernel_tri_inv_rec_unroll.cpp | 256 +++++++++++++++--- .../jit_cpp/fast_inverse/run_fast_inverse.py | 59 ++-- 5 files changed, 299 insertions(+), 78 deletions(-) diff --git a/examples/jit_cpp/fast_inverse/benchmark_bsnd_fast_inverse.py b/examples/jit_cpp/fast_inverse/benchmark_bsnd_fast_inverse.py index 7fa4820c..41fed753 100644 --- a/examples/jit_cpp/fast_inverse/benchmark_bsnd_fast_inverse.py +++ b/examples/jit_cpp/fast_inverse/benchmark_bsnd_fast_inverse.py @@ -104,7 +104,7 @@ def build_uniform_varlen_input( batch_size: int, seqlen: int, chunk_size: int, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: total_tokens = batch_size * seqlen packed_input = fixed_input.reshape(1, total_tokens, fixed_input.shape[2], chunk_size).contiguous() cu_seqlens = torch.arange( @@ -121,7 +121,8 @@ def build_uniform_varlen_input( dtype=torch.int32, device=fixed_input.device, ) - return packed_input, cu_seqlens, chunk_indices + chunk_valid_sizes = torch.full_like(chunk_indices, chunk_size) + return packed_input, cu_seqlens, chunk_indices, chunk_valid_sizes def sample_true_varlen_lengths( @@ -150,7 +151,7 @@ def build_true_varlen_input( chunk_size: int, scale: float, device: str, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: cu_seqlens = np.cumsum([0, *seq_lens], dtype=np.int64) num_chunks = sum((seq_len + chunk_size - 1) // chunk_size for seq_len in seq_lens) chunk_mats = random_chunk_mats( @@ -161,35 +162,37 @@ def build_true_varlen_input( device=device, ) - padded_total = num_chunks * chunk_size packed_input = torch.zeros( - (1, padded_total, num_heads, chunk_size), + (1, int(cu_seqlens[-1]), num_heads, chunk_size), dtype=torch.half, device=device, ) chunk_indices: list[int] = [] + chunk_valid_sizes: list[int] = [] chunk_idx = 0 - padded_row = 0 + token_row = 0 for seq_len in seq_lens: - for chunk_start in range(0, seq_len, chunk_size): - actual_size = min(chunk_size, seq_len - chunk_start) + for local_chunk_start in range(0, seq_len, chunk_size): + actual_size = min(chunk_size, seq_len - local_chunk_start) chunk = chunk_mats[chunk_idx] for head_idx in range(num_heads): packed_input[ 0, - padded_row : padded_row + actual_size, + token_row : token_row + actual_size, head_idx, :actual_size, ] = chunk[head_idx, :actual_size, :actual_size] - chunk_indices.append(padded_row) - padded_row += chunk_size + chunk_indices.append(token_row) + chunk_valid_sizes.append(actual_size) + token_row += actual_size chunk_idx += 1 return ( packed_input.contiguous(), torch.tensor(cu_seqlens.tolist(), dtype=torch.int32, device=device), torch.tensor(chunk_indices, dtype=torch.int32, device=device), + torch.tensor(chunk_valid_sizes, dtype=torch.int32, device=device), ) @@ -220,10 +223,11 @@ def make_varlen_runner( tri_inv_func, tensor_in: torch.Tensor, chunk_indices: torch.Tensor, + chunk_valid_sizes: torch.Tensor, ) -> tuple[callable, torch.Tensor]: matrix_size = tensor_in.shape[-1] num_bsnd_heads = tensor_in.shape[-2] - num_matrices = tensor_in.numel() // (matrix_size * matrix_size) + num_matrices = chunk_indices.numel() * num_bsnd_heads tensor_out = torch.empty_like(tensor_in, dtype=torch.float32) minus_identity = make_minus_identity(matrix_size, str(tensor_in.device)) @@ -236,6 +240,7 @@ def run(): num_matrices, num_bsnd_heads, chunk_indices=chunk_indices, + chunk_valid_sizes=chunk_valid_sizes, ) return run, tensor_out @@ -487,7 +492,7 @@ def main() -> None: num_heads=args.H, chunk_size=args.chunk_size, ) - varlen_input, cu_seqlens, chunk_indices = build_uniform_varlen_input( + varlen_input, cu_seqlens, chunk_indices, chunk_valid_sizes = build_uniform_varlen_input( fixed_input, batch_size=args.B, seqlen=seqlen, @@ -497,7 +502,12 @@ def main() -> None: print(f" uniform cu_seqlens: {cu_seqlens.cpu().tolist()}") fixed_run, fixed_out = make_fixed_runner(tri_inv_func, fixed_input) - varlen_run, varlen_out = make_varlen_runner(tri_inv_func, varlen_input, chunk_indices) + varlen_run, varlen_out = make_varlen_runner( + tri_inv_func, + varlen_input, + chunk_indices, + chunk_valid_sizes, + ) fixed_run() varlen_run() @@ -569,7 +579,7 @@ def main() -> None: for sample_idx in range(args.true_varlen_samples): seq_lens = sample_true_varlen_lengths(args.B, total_tokens, rng) - packed_input, cu_seqlens, chunk_indices = build_true_varlen_input( + packed_input, cu_seqlens, chunk_indices, chunk_valid_sizes = build_true_varlen_input( seq_lens=seq_lens, num_heads=args.H, chunk_size=args.chunk_size, @@ -580,6 +590,7 @@ def main() -> None: tri_inv_func, packed_input, chunk_indices, + chunk_valid_sizes, ) times_ms = benchmark_ms( varlen_run_true, diff --git a/examples/jit_cpp/fast_inverse/fast_inverse.cpp b/examples/jit_cpp/fast_inverse/fast_inverse.cpp index ea79df4a..12c81086 100644 --- a/examples/jit_cpp/fast_inverse/fast_inverse.cpp +++ b/examples/jit_cpp/fast_inverse/fast_inverse.cpp @@ -27,14 +27,17 @@ for the full License text. * @param num_bsnd_heads 0 for standard (B…ND) layout; * N (number of heads) for BSND layout. * @param chunk_indices Optional int32 pointer used only for varlen BSND. Each - * entry is the absolute row offset of one padded D x D - * chunk within the BSND tensor. + * entry is the absolute row offset of one chunk within the + * unpadded BSND tensor. + * @param chunk_valid_sizes Optional int32 pointer used only for varlen BSND. + * Each entry stores the runtime size of that chunk. */ extern "C" void call_kernel(uint32_t blockDim, void* stream, void* tensor_out, void* tensor_in, void* minus_identity_in, uint32_t matrix_size, uint32_t num_matrices, - uint32_t num_bsnd_heads, void* chunk_indices) { + uint32_t num_bsnd_heads, void* chunk_indices, + void* chunk_valid_sizes) { tri_inv_rec_unroll_fp16<<>>( tensor_out, tensor_in, minus_identity_in, matrix_size, num_matrices, - num_bsnd_heads, chunk_indices); + num_bsnd_heads, chunk_indices, chunk_valid_sizes); } diff --git a/examples/jit_cpp/fast_inverse/jit_util_fast_inverse.py b/examples/jit_cpp/fast_inverse/jit_util_fast_inverse.py index a6e205ec..90912373 100644 --- a/examples/jit_cpp/fast_inverse/jit_util_fast_inverse.py +++ b/examples/jit_cpp/fast_inverse/jit_util_fast_inverse.py @@ -86,6 +86,7 @@ def load_lib(lib_path: str): ctypes.c_uint32, # num_matrices ctypes.c_uint32, # num_bsnd_heads ctypes.c_void_p, # chunk_indices (optional int32 metadata) + ctypes.c_void_p, # chunk_valid_sizes (optional int32 metadata) ] lib.call_kernel.restype = None @@ -97,6 +98,7 @@ def tri_inv_func( num_matrices: int, num_bsnd_heads: int = 0, chunk_indices: torch.Tensor | None = None, + chunk_valid_sizes: torch.Tensor | None = None, block_dim: int = BLOCK_DIM, stream_ptr=None, ): @@ -107,6 +109,11 @@ def tri_inv_func( raise TypeError("chunk_indices must be int32.") if not chunk_indices.is_contiguous(): raise ValueError("chunk_indices must be contiguous.") + if chunk_valid_sizes is not None: + if chunk_valid_sizes.dtype != torch.int32: + raise TypeError("chunk_valid_sizes must be int32.") + if not chunk_valid_sizes.is_contiguous(): + raise ValueError("chunk_valid_sizes must be contiguous.") effective_block_dim = min(block_dim, num_matrices) lib.call_kernel( effective_block_dim, @@ -120,6 +127,9 @@ def tri_inv_func( _torch_to_ctypes(chunk_indices) if chunk_indices is not None else ctypes.c_void_p(), + _torch_to_ctypes(chunk_valid_sizes) + if chunk_valid_sizes is not None + else ctypes.c_void_p(), ) return tri_inv_func diff --git a/examples/jit_cpp/fast_inverse/kernel_tri_inv_rec_unroll.cpp b/examples/jit_cpp/fast_inverse/kernel_tri_inv_rec_unroll.cpp index 654f54ab..d998ab68 100644 --- a/examples/jit_cpp/fast_inverse/kernel_tri_inv_rec_unroll.cpp +++ b/examples/jit_cpp/fast_inverse/kernel_tri_inv_rec_unroll.cpp @@ -22,18 +22,25 @@ using namespace kernel_utils; (((tile_id) / (N)) * (S) * (N) * (D) + ((tile_id) % (N)) * (D)) /* - * For varlen BSND, chunk_indices stores the absolute row offset (within the - * padded BSND tensor) of each D x D chunk. Each tile_id still enumerates - * chunk-major, then head-major. + * For aligned BSND, tile_id enumerates chunk-major then head-major and maps to + * a fixed-stride address inside the dense BSND tensor. */ -AICORE inline uint32_t GetBSNDTileOffset(uint32_t tile_id, - uint32_t num_bsnd_heads, - uint32_t matrix_size, - __gm__ int32_t* chunk_indices) { +AICORE inline uint32_t GetBSNDFixedTileOffset(uint32_t tile_id, + uint32_t num_bsnd_heads, + uint32_t matrix_size) { + return BSND_OFFSET(tile_id, num_bsnd_heads, matrix_size, matrix_size); +} + +/* + * For varlen BSND, chunk_indices stores the absolute row offset of each chunk + * inside the unpadded BSND tensor. Each tile_id still enumerates chunk-major, + * then head-major. + */ +AICORE inline uint32_t GetBSNDVarlenTileOffset(uint32_t tile_id, + uint32_t num_bsnd_heads, + uint32_t matrix_size, + __gm__ int32_t* chunk_indices) { const uint32_t head_idx = tile_id % num_bsnd_heads; - if (chunk_indices == nullptr) { - return BSND_OFFSET(tile_id, num_bsnd_heads, matrix_size, matrix_size); - } const uint32_t chunk_idx = tile_id / num_bsnd_heads; const uint32_t chunk_row_start = static_cast(chunk_indices[chunk_idx]); @@ -385,20 +392,13 @@ AICORE inline void InvertSingleTile(TileL1AB X_l1_tile, TileL1AB I_l1_tile, /* * @brief: Runs the main kernel (inverts all matrices in the tensor) - * - * When chunk_indices is non-null in BSND mode, it maps each chunk index - * (tile_id / num_bsnd_heads) to the absolute row offset in the padded BSND - * tensor. This lets the kernel support per-sequence padding without changing - * the per-tile inverse logic. */ template AICORE inline void TriInvRecUnrollKernel(__gm__ OutputT* M_inv, __gm__ InputT* M, __gm__ InputT* I_neg, uint32_t total_tiles, - uint32_t num_bsnd_heads = 0, - __gm__ int32_t* chunk_indices = - nullptr) { + uint32_t num_bsnd_heads = 0) { constexpr uint32_t TileLen = MatrixSize * MatrixSize; constexpr uint32_t FractalSize = 16; constexpr uint32_t NumL0Buffers = 2; @@ -491,8 +491,8 @@ AICORE inline void TriInvRecUnrollKernel(__gm__ OutputT* M_inv, (global_index + tile_id < total_tiles); ++tile_id) { if constexpr (IsBSND) { - const uint32_t bsnd_offset = GetBSNDTileOffset( - global_index + tile_id, num_bsnd_heads, MatrixSize, chunk_indices); + const uint32_t bsnd_offset = GetBSNDFixedTileOffset( + global_index + tile_id, num_bsnd_heads, MatrixSize); GlobalTileIn M_global_in(M + bsnd_offset, {}, {static_cast(MatrixSize * num_bsnd_heads)}); wait_flag(PIPE_M, PIPE_MTE2, static_cast(tile_id)); @@ -520,8 +520,8 @@ AICORE inline void TriInvRecUnrollKernel(__gm__ OutputT* M_inv, set_flag(PIPE_M, PIPE_MTE2, static_cast(tile_id)); if constexpr (IsBSND) { - const uint32_t bsnd_offset = GetBSNDTileOffset( - global_index + tile_id, num_bsnd_heads, MatrixSize, chunk_indices); + const uint32_t bsnd_offset = GetBSNDFixedTileOffset( + global_index + tile_id, num_bsnd_heads, MatrixSize); GlobalTileOut M_inv_global_out( M_inv + bsnd_offset, {}, {static_cast(MatrixSize * num_bsnd_heads)}); @@ -545,18 +545,188 @@ AICORE inline void TriInvRecUnrollKernel(__gm__ OutputT* M_inv, static_cast(next_tile_id_that_waits_for_pipe_fix_pipe_m)); } +/* + * @brief: Varlen BSND kernel. + * + * The input/output tensors stay unpadded. For tail chunks with size + * `actual_size < MatrixSize`, the kernel: + * 1. loads only the valid `actual_size x actual_size` prefix via dynamic TLOAD + * 2. zero-fills the remaining rows/cols in-place via TFILLPAD_INPLACE + * 3. runs the original dense recursive inverse on the materialized full tile + * 4. stores only the valid `actual_size x actual_size` prefix back to GM + */ +template +AICORE inline void TriInvRecUnrollKernelBSNDVarlen( + __gm__ OutputT* M_inv, __gm__ InputT* M, __gm__ InputT* I_neg, + uint32_t total_tiles, uint32_t num_bsnd_heads, __gm__ int32_t* chunk_indices, + __gm__ int32_t* chunk_valid_sizes) { + constexpr uint32_t TileLen = MatrixSize * MatrixSize; + constexpr uint32_t FractalSize = 16; + constexpr uint32_t NumL0Buffers = 2; + + if (get_block_idx() * NumTilesPerCubeIter >= total_tiles) { + return; + } + + using GlobalTileShapeIn = + TileShape2D; + using GlobalTileStridesIn = Stride<1, 1, 1, -1, 1>; + using GlobalTileIn = + GlobalTensor; + + using GlobalTileDynShape = Shape<1, 1, 1, DYNAMIC, DYNAMIC>; + using GlobalTileDynStride = Stride<1, 1, 1, DYNAMIC, 1>; + using GlobalTileInDyn = + GlobalTensor; + using GlobalTileOutDyn = + GlobalTensor; + + using GlobalTileStridesINeg = + BaseShape2D; + using GlobalTileINeg = GlobalTensor; + + using GlobalTileShapeOut = + TileShape2D; + using GlobalTileStridesOut = Stride<1, 1, 1, -1, 1>; + using GlobalTileOut = GlobalTensor; + + using TileL1AB = + Tile; + using TileL1ABDyn = Tile; + + using TileL0A = TileLeft; + using TileL0B = TileRight; + using TileL0C = TileAcc; + using TileL0CDyn = TileAcc; + + GlobalTileINeg I_neg_global_in(I_neg); + + TileL1AB X_l1_tile; + TileL1AB I_l1_tile; + TileL1AB I_neg_l1_tile; + TileL1AB M_neg_l1_tile; + TileL1AB Zero_l1_tile; + TileL1AB Y_l1_tile[NumTilesPerCubeIter]; + + TileL0A a_l0_tile[NumL0Buffers]; + TileL0B b_l0_tile[NumL0Buffers]; + TileL0C c_l0_tile[NumL0Buffers]; + + TASSIGN(I_l1_tile, 0x0); + TASSIGN(I_neg_l1_tile, 0x0 + TileLen * sizeof(InputT)); + TASSIGN(Zero_l1_tile, 0x0 + 2 * TileLen * sizeof(InputT)); + TASSIGN(M_neg_l1_tile, 0x0 + 3 * TileLen * sizeof(InputT)); + TASSIGN(X_l1_tile, 0x0 + 4 * TileLen * sizeof(InputT)); + for (uint32_t tile_id = 0; tile_id < NumTilesPerCubeIter; ++tile_id) { + TASSIGN(Y_l1_tile[tile_id], 0x0 + (5 + tile_id) * TileLen * sizeof(InputT)); + } + + for (uint32_t buffer_num = 0; buffer_num < NumL0Buffers; ++buffer_num) { + TASSIGN(a_l0_tile[buffer_num], 0x0 + buffer_num * TileLen * sizeof(InputT)); + TASSIGN(b_l0_tile[buffer_num], 0x0 + buffer_num * TileLen * sizeof(InputT)); + TASSIGN(c_l0_tile[buffer_num], + 0x0 + buffer_num * TileLen * sizeof(OutputT)); + } + TLOAD(I_neg_l1_tile, I_neg_global_in); + set_flag(PIPE_MTE2, PIPE_MTE1, static_cast(0)); + wait_flag(PIPE_MTE2, PIPE_MTE1, static_cast(0)); + + PrepareAuxiliaryMatrices( + I_neg_l1_tile, Zero_l1_tile, I_l1_tile, a_l0_tile[0], b_l0_tile[0], + c_l0_tile[0]); + + const uint32_t max_iters_per_aic = + CeilDiv(total_tiles, (uint32_t)(NumTilesPerCubeIter * get_block_num())); + constexpr uint32_t final_c_buffer_index = MatrixSize > FractalSize ? 1 : 0; + + for (uint32_t cube_iter = 0; cube_iter < max_iters_per_aic; ++cube_iter) { + const uint32_t global_index = + (cube_iter * get_block_num() + get_block_idx()) * NumTilesPerCubeIter; + if (global_index >= total_tiles) { + break; + } + + for (uint32_t tile_id = 0; (tile_id < NumTilesPerCubeIter) && + (global_index + tile_id < total_tiles); + ++tile_id) { + const uint32_t global_tile_id = global_index + tile_id; + const uint32_t chunk_idx = global_tile_id / num_bsnd_heads; + const uint32_t valid_size = + static_cast(chunk_valid_sizes[chunk_idx]); + const uint32_t bsnd_offset = GetBSNDVarlenTileOffset( + global_tile_id, num_bsnd_heads, MatrixSize, chunk_indices); + const int row_stride = static_cast(MatrixSize * num_bsnd_heads); + + if (valid_size == MatrixSize) { + GlobalTileIn M_global_in(M + bsnd_offset, {}, {row_stride}); + TLOAD(Y_l1_tile[tile_id], M_global_in); + set_flag(PIPE_MTE2, PIPE_MTE1, static_cast(tile_id)); + wait_flag(PIPE_MTE2, PIPE_MTE1, static_cast(tile_id)); + } else { + TileL1ABDyn Y_dyn_l1_tile(valid_size, valid_size); + TASSIGN(Y_dyn_l1_tile, + 0x0 + (5 + tile_id) * TileLen * sizeof(InputT)); + GlobalTileInDyn M_global_in_dyn(M + bsnd_offset, + {1, 1, 1, valid_size, valid_size}, + {1, 1, 1, row_stride, 1}); + TLOAD(Y_dyn_l1_tile, M_global_in_dyn); + set_flag(PIPE_MTE2, PIPE_MTE1, static_cast(tile_id)); + wait_flag(PIPE_MTE2, PIPE_MTE1, static_cast(tile_id)); + TFILLPAD(Y_dyn_l1_tile, Y_dyn_l1_tile); + } + + InvertSingleTile( + X_l1_tile, I_l1_tile, I_neg_l1_tile, M_neg_l1_tile, Zero_l1_tile, + Y_l1_tile[tile_id], a_l0_tile, b_l0_tile, c_l0_tile, tile_id); + + if (valid_size == MatrixSize) { + GlobalTileOut M_inv_global_out(M_inv + bsnd_offset, {}, {row_stride}); + TSTORE(M_inv_global_out, c_l0_tile[final_c_buffer_index]); + } else { + TileL0CDyn c_l0_tail_tile(valid_size, valid_size); + TASSIGN(c_l0_tail_tile, + 0x0 + final_c_buffer_index * TileLen * sizeof(OutputT)); + GlobalTileOutDyn M_inv_global_out_dyn( + M_inv + bsnd_offset, {1, 1, 1, valid_size, valid_size}, + {1, 1, 1, row_stride, 1}); + TSTORE(M_inv_global_out_dyn, c_l0_tail_tile); + } + } + } +} + template AICORE void runKernelTriInvRecUnroll(__gm__ OutputT* M_inv, __gm__ InputT* M, __gm__ InputT* I_neg, uint32_t total_tiles, uint32_t num_bsnd_heads = 0, - __gm__ int32_t* chunk_indices = nullptr) { + __gm__ int32_t* chunk_indices = nullptr, + __gm__ int32_t* chunk_valid_sizes = + nullptr) { #if (__CHECK_FEATURE_AT_PRECOMPILE) || \ (__CCE_AICORE__ == 220 && defined(__DAV_C220_CUBE__)) - - TriInvRecUnrollKernel(M_inv, M, I_neg, total_tiles, num_bsnd_heads, - chunk_indices); + if constexpr (IsBSND) { + if (chunk_indices != nullptr && chunk_valid_sizes != nullptr) { + TriInvRecUnrollKernelBSNDVarlen( + M_inv, M, I_neg, total_tiles, num_bsnd_heads, chunk_indices, + chunk_valid_sizes); + } else { + TriInvRecUnrollKernel(M_inv, M, I_neg, total_tiles, + num_bsnd_heads); + } + } else { + TriInvRecUnrollKernel(M_inv, M, I_neg, total_tiles, num_bsnd_heads); + } #else // Nothing to do on AIV #endif @@ -568,29 +738,30 @@ AICORE void run_tri_inv_rec_unroll(__gm__ float* tensor_out, __gm__ InputT* minus_identity_in, uint32_t matrix_size, uint32_t num_matrices, uint32_t num_bsnd_heads, - __gm__ int32_t* chunk_indices) { + __gm__ int32_t* chunk_indices, + __gm__ int32_t* chunk_valid_sizes) { static_assert(std::is_same_v, "tri_inv_rec_unroll supports only fp16."); switch (matrix_size) { case 16: runKernelTriInvRecUnroll( tensor_out, tensor_in, minus_identity_in, num_matrices, - num_bsnd_heads, chunk_indices); + num_bsnd_heads, chunk_indices, chunk_valid_sizes); break; case 32: runKernelTriInvRecUnroll( tensor_out, tensor_in, minus_identity_in, num_matrices, - num_bsnd_heads, chunk_indices); + num_bsnd_heads, chunk_indices, chunk_valid_sizes); break; case 64: runKernelTriInvRecUnroll( tensor_out, tensor_in, minus_identity_in, num_matrices, - num_bsnd_heads, chunk_indices); + num_bsnd_heads, chunk_indices, chunk_valid_sizes); break; case 128: runKernelTriInvRecUnroll( tensor_out, tensor_in, minus_identity_in, num_matrices, - num_bsnd_heads, chunk_indices); + num_bsnd_heads, chunk_indices, chunk_valid_sizes); break; } } @@ -598,40 +769,47 @@ AICORE void run_tri_inv_rec_unroll(__gm__ float* tensor_out, extern "C" __global__ AICORE void tri_inv_rec_unroll_fp16( __gm__ void* tensor_out, __gm__ void* tensor_in, __gm__ void* minus_identity_in, uint32_t matrix_size, uint32_t num_matrices, - uint32_t num_bsnd_heads, __gm__ void* chunk_indices) { + uint32_t num_bsnd_heads, __gm__ void* chunk_indices, + __gm__ void* chunk_valid_sizes) { if (num_bsnd_heads == 0) { if (num_matrices <= get_block_num()) { run_tri_inv_rec_unroll( (__gm__ float*)tensor_out, (__gm__ half*)tensor_in, (__gm__ half*)minus_identity_in, matrix_size, num_matrices, - num_bsnd_heads, (__gm__ int32_t*)chunk_indices); + num_bsnd_heads, (__gm__ int32_t*)chunk_indices, + (__gm__ int32_t*)chunk_valid_sizes); } else if (num_matrices <= 2 * get_block_num()) { run_tri_inv_rec_unroll( (__gm__ float*)tensor_out, (__gm__ half*)tensor_in, (__gm__ half*)minus_identity_in, matrix_size, num_matrices, - num_bsnd_heads, (__gm__ int32_t*)chunk_indices); + num_bsnd_heads, (__gm__ int32_t*)chunk_indices, + (__gm__ int32_t*)chunk_valid_sizes); } else { run_tri_inv_rec_unroll( (__gm__ float*)tensor_out, (__gm__ half*)tensor_in, (__gm__ half*)minus_identity_in, matrix_size, num_matrices, - num_bsnd_heads, (__gm__ int32_t*)chunk_indices); + num_bsnd_heads, (__gm__ int32_t*)chunk_indices, + (__gm__ int32_t*)chunk_valid_sizes); } } else { if (num_matrices <= get_block_num()) { run_tri_inv_rec_unroll( (__gm__ float*)tensor_out, (__gm__ half*)tensor_in, (__gm__ half*)minus_identity_in, matrix_size, num_matrices, - num_bsnd_heads, (__gm__ int32_t*)chunk_indices); + num_bsnd_heads, (__gm__ int32_t*)chunk_indices, + (__gm__ int32_t*)chunk_valid_sizes); } else if (num_matrices <= 2 * get_block_num()) { run_tri_inv_rec_unroll( (__gm__ float*)tensor_out, (__gm__ half*)tensor_in, (__gm__ half*)minus_identity_in, matrix_size, num_matrices, - num_bsnd_heads, (__gm__ int32_t*)chunk_indices); + num_bsnd_heads, (__gm__ int32_t*)chunk_indices, + (__gm__ int32_t*)chunk_valid_sizes); } else { run_tri_inv_rec_unroll( (__gm__ float*)tensor_out, (__gm__ half*)tensor_in, (__gm__ half*)minus_identity_in, matrix_size, num_matrices, - num_bsnd_heads, (__gm__ int32_t*)chunk_indices); + num_bsnd_heads, (__gm__ int32_t*)chunk_indices, + (__gm__ int32_t*)chunk_valid_sizes); } } } diff --git a/examples/jit_cpp/fast_inverse/run_fast_inverse.py b/examples/jit_cpp/fast_inverse/run_fast_inverse.py index af0a9422..a4b6e607 100644 --- a/examples/jit_cpp/fast_inverse/run_fast_inverse.py +++ b/examples/jit_cpp/fast_inverse/run_fast_inverse.py @@ -117,24 +117,35 @@ def _run_kernel_bsnd( tri_inv_func, U_bsnd_fp16: torch.Tensor, chunk_indices: torch.Tensor | None = None, + chunk_valid_sizes: torch.Tensor | None = None, ): """ Run the kernel in BSND mode and return fp64 CPU result. U_bsnd_fp16 : (B, S, N, D) half tensor on NPU where each (D, D) block along the S dimension is one matrix to invert. - chunk_indices : optional int32 tensor containing the padded row start of + chunk_indices : optional int32 tensor containing the unpadded row start of each valid chunk for varlen BSND inputs. + chunk_valid_sizes : optional int32 tensor containing the runtime size of + each chunk for varlen BSND inputs. """ matrix_size = U_bsnd_fp16.shape[-1] num_bsnd_heads = U_bsnd_fp16.shape[-2] - num_matrices = U_bsnd_fp16.numel() // (matrix_size * matrix_size) + if chunk_indices is not None and chunk_valid_sizes is not None: + num_matrices = chunk_indices.numel() * num_bsnd_heads + else: + num_matrices = U_bsnd_fp16.numel() // (matrix_size * matrix_size) device = U_bsnd_fp16.device tensor_out = torch.zeros_like(U_bsnd_fp16, dtype=torch.float32) I_neg = _make_minus_identity(matrix_size, str(device)) if chunk_indices is not None: chunk_indices = chunk_indices.to(device=device, dtype=torch.int32).contiguous() + if chunk_valid_sizes is not None: + chunk_valid_sizes = chunk_valid_sizes.to( + device=device, + dtype=torch.int32, + ).contiguous() torch.npu.synchronize() tri_inv_func( @@ -145,6 +156,7 @@ def _run_kernel_bsnd( num_matrices, num_bsnd_heads, chunk_indices=chunk_indices, + chunk_valid_sizes=chunk_valid_sizes, ) torch.npu.synchronize() @@ -158,10 +170,11 @@ def _build_varlen_bsnd_case( chunk_size: int, ): """ - Build a padded BSND tensor plus reference output for varlen testing. + Build an unpadded BSND tensor plus reference output for varlen testing. - Each sequence is padded independently to the next multiple of chunk_size. - chunk_indices records the padded row offset of every valid chunk. + Each sequence contributes only its true rows in the packed BSND tensor. + chunk_indices records the unpadded row offset of every valid chunk and + chunk_valid_sizes stores each chunk's runtime size. """ seq_lens = [ cu_seqlens[i + 1] - cu_seqlens[i] @@ -179,14 +192,13 @@ def _build_varlen_bsnd_case( ) chunk_mats = gen(chunk_size, num_chunks, num_heads).to(torch.float64) - padded_total = num_chunks * chunk_size - U_padded = torch.zeros((1, padded_total, num_heads, chunk_size), dtype=torch.float64) + U = torch.zeros((1, total_tokens, num_heads, chunk_size), dtype=torch.float64) golden = torch.zeros((1, total_tokens, num_heads, chunk_size), dtype=torch.float64) chunk_indices: list[int] = [] + chunk_valid_sizes: list[int] = [] chunk_infos: list[tuple[int, int, int]] = [] chunk_idx = 0 - padded_row = 0 for seq_idx in range(len(cu_seqlens) - 1): seq_start = cu_seqlens[seq_idx] @@ -196,9 +208,9 @@ def _build_varlen_bsnd_case( chunk = chunk_mats[chunk_idx] for head_idx in range(num_heads): U_valid = chunk[head_idx, :actual_size, :actual_size] - U_padded[ + U[ 0, - padded_row : padded_row + actual_size, + chunk_start : chunk_start + actual_size, head_idx, :actual_size, ] = U_valid @@ -209,12 +221,18 @@ def _build_varlen_bsnd_case( :actual_size, ] = invert_single_chunk_ref(U_valid) - chunk_indices.append(padded_row) - chunk_infos.append((padded_row, chunk_start, actual_size)) - padded_row += chunk_size + chunk_indices.append(chunk_start) + chunk_valid_sizes.append(actual_size) + chunk_infos.append((chunk_start, chunk_start, actual_size)) chunk_idx += 1 - return U_padded, golden, chunk_infos, torch.tensor(chunk_indices, dtype=torch.int32) + return ( + U, + golden, + chunk_infos, + torch.tensor(chunk_indices, dtype=torch.int32), + torch.tensor(chunk_valid_sizes, dtype=torch.int32), + ) # --------------------------------------------------------------------------- @@ -300,28 +318,29 @@ def _test_case_bsnd_varlen( ftol: float, label: str, ): - U_padded, golden, chunk_infos, chunk_indices = _build_varlen_bsnd_case( + U_varlen, golden, chunk_infos, chunk_indices, chunk_valid_sizes = _build_varlen_bsnd_case( gen, cu_seqlens, N, D, ) - actual_padded = _run_kernel_bsnd( + actual_varlen = _run_kernel_bsnd( tri_inv_func, - U_padded.to(torch.half).npu(), + U_varlen.to(torch.half).npu(), chunk_indices=chunk_indices.npu(), + chunk_valid_sizes=chunk_valid_sizes.npu(), ) actual = torch.zeros_like(golden) - for padded_row, token_row, actual_size in chunk_infos: + for input_row, token_row, actual_size in chunk_infos: actual[ :, token_row : token_row + actual_size, :, :actual_size, - ] = actual_padded[ + ] = actual_varlen[ :, - padded_row : padded_row + actual_size, + input_row : input_row + actual_size, :, :actual_size, ] From dc1c5f4cf46e498624bb28c1d8ae661cbb558c85 Mon Sep 17 00:00:00 2001 From: learning-chip Date: Tue, 31 Mar 2026 16:51:42 +0000 Subject: [PATCH 06/14] fix kernel synchornization for large-size benchmarks --- .../kernel_tri_inv_rec_unroll.cpp | 129 ++++++++++++++---- 1 file changed, 102 insertions(+), 27 deletions(-) diff --git a/examples/jit_cpp/fast_inverse/kernel_tri_inv_rec_unroll.cpp b/examples/jit_cpp/fast_inverse/kernel_tri_inv_rec_unroll.cpp index d998ab68..9065b7fd 100644 --- a/examples/jit_cpp/fast_inverse/kernel_tri_inv_rec_unroll.cpp +++ b/examples/jit_cpp/fast_inverse/kernel_tri_inv_rec_unroll.cpp @@ -398,7 +398,11 @@ template >::type; using GlobalTileIn = GlobalTensor; + using GlobalTileDynShape = Shape<1, 1, 1, DYNAMIC, DYNAMIC>; + using GlobalTileDynStride = Stride<1, 1, 1, DYNAMIC, 1>; + using GlobalTileInDyn = + GlobalTensor; using GlobalTileStridesINeg = BaseShape2D; @@ -427,10 +435,16 @@ AICORE inline void TriInvRecUnrollKernel(__gm__ OutputT* M_inv, Stride<1, 1, 1, -1, 1>>::type; using GlobalTileOut = GlobalTensor; + using GlobalTileOutDyn = + GlobalTensor; using TileL1AB = Tile; + using TileL1ABDyn = Tile; + using TileL0CDyn = TileAcc; using TileL0A = TileLeft; using TileL0B = TileRight; @@ -491,12 +505,38 @@ AICORE inline void TriInvRecUnrollKernel(__gm__ OutputT* M_inv, (global_index + tile_id < total_tiles); ++tile_id) { if constexpr (IsBSND) { - const uint32_t bsnd_offset = GetBSNDFixedTileOffset( - global_index + tile_id, num_bsnd_heads, MatrixSize); - GlobalTileIn M_global_in(M + bsnd_offset, {}, - {static_cast(MatrixSize * num_bsnd_heads)}); + const uint32_t global_tile_id = global_index + tile_id; + const uint32_t bsnd_offset = + chunk_indices != nullptr + ? GetBSNDVarlenTileOffset(global_tile_id, num_bsnd_heads, + MatrixSize, chunk_indices) + : GetBSNDFixedTileOffset(global_tile_id, num_bsnd_heads, + MatrixSize); + const int row_stride = static_cast(MatrixSize * num_bsnd_heads); wait_flag(PIPE_M, PIPE_MTE2, static_cast(tile_id)); - TLOAD(Y_l1_tile[tile_id], M_global_in); + if (chunk_valid_sizes != nullptr) { + const uint32_t chunk_idx = global_tile_id / num_bsnd_heads; + const uint32_t valid_size = + static_cast(chunk_valid_sizes[chunk_idx]); + if (valid_size < MatrixSize) { + TileL1ABDyn Y_dyn_l1_tile(valid_size, valid_size); + TASSIGN(Y_dyn_l1_tile, + 0x0 + (5 + tile_id) * TileLen * sizeof(InputT)); + GlobalTileInDyn M_global_in_dyn( + M + bsnd_offset, {1, 1, 1, valid_size, valid_size}, + {1, 1, 1, row_stride, 1}); + TLOAD(Y_dyn_l1_tile, M_global_in_dyn); + set_flag(PIPE_MTE2, PIPE_MTE1, static_cast(tile_id)); + wait_flag(PIPE_MTE2, PIPE_MTE1, static_cast(tile_id)); + TFILLPAD(Y_dyn_l1_tile, Y_dyn_l1_tile); + } else { + GlobalTileIn M_global_in(M + bsnd_offset, {}, {row_stride}); + TLOAD(Y_l1_tile[tile_id], M_global_in); + } + } else { + GlobalTileIn M_global_in(M + bsnd_offset, {}, {row_stride}); + TLOAD(Y_l1_tile[tile_id], M_global_in); + } } else { GlobalTileIn M_global_in(M + (global_index + tile_id) * TileLen); wait_flag(PIPE_M, PIPE_MTE2, static_cast(tile_id)); @@ -520,12 +560,48 @@ AICORE inline void TriInvRecUnrollKernel(__gm__ OutputT* M_inv, set_flag(PIPE_M, PIPE_MTE2, static_cast(tile_id)); if constexpr (IsBSND) { - const uint32_t bsnd_offset = GetBSNDFixedTileOffset( - global_index + tile_id, num_bsnd_heads, MatrixSize); - GlobalTileOut M_inv_global_out( - M_inv + bsnd_offset, {}, - {static_cast(MatrixSize * num_bsnd_heads)}); - TSTORE(M_inv_global_out, c_l0_tile[final_c_buffer_index]); + const uint32_t global_tile_id = global_index + tile_id; + const uint32_t bsnd_offset = + chunk_indices != nullptr + ? GetBSNDVarlenTileOffset(global_tile_id, num_bsnd_heads, + MatrixSize, chunk_indices) + : GetBSNDFixedTileOffset(global_tile_id, num_bsnd_heads, + MatrixSize); + const int row_stride = static_cast(MatrixSize * num_bsnd_heads); + if (chunk_valid_sizes != nullptr) { + const uint32_t chunk_idx = global_tile_id / num_bsnd_heads; + const uint32_t valid_size = + static_cast(chunk_valid_sizes[chunk_idx]); + if (valid_size < MatrixSize) { + const event_t event_0 = static_cast(tile_id); + const event_t event_1 = + static_cast(tile_id + NumTilesPerCubeIter); + TileL0CDyn c_l0_tail_tile(valid_size, valid_size); + TASSIGN(c_l0_tail_tile, + 0x0 + final_c_buffer_index * TileLen * sizeof(OutputT)); + if constexpr (final_c_buffer_index == 1) { + set_flag(PIPE_M, PIPE_FIX, event_1); + wait_flag(PIPE_M, PIPE_FIX, event_1); + } else { + set_flag(PIPE_M, PIPE_FIX, event_0); + wait_flag(PIPE_M, PIPE_FIX, event_0); + } + set_flag(PIPE_FIX, PIPE_MTE3, static_cast(tile_id)); + wait_flag(PIPE_FIX, PIPE_MTE3, static_cast(tile_id)); + GlobalTileOutDyn M_inv_global_out_dyn( + M_inv + bsnd_offset, {1, 1, 1, valid_size, valid_size}, + {1, 1, 1, row_stride, 1}); + TSTORE(M_inv_global_out_dyn, c_l0_tail_tile); + } else { + GlobalTileOut M_inv_global_out(M_inv + bsnd_offset, {}, + {row_stride}); + TSTORE(M_inv_global_out, c_l0_tile[final_c_buffer_index]); + } + } else { + GlobalTileOut M_inv_global_out(M_inv + bsnd_offset, {}, + {row_stride}); + TSTORE(M_inv_global_out, c_l0_tile[final_c_buffer_index]); + } } else { GlobalTileOut M_inv_global_out(M_inv + (global_index + tile_id) * TileLen); @@ -690,9 +766,20 @@ AICORE inline void TriInvRecUnrollKernelBSNDVarlen( GlobalTileOut M_inv_global_out(M_inv + bsnd_offset, {}, {row_stride}); TSTORE(M_inv_global_out, c_l0_tile[final_c_buffer_index]); } else { + const event_t event_0 = static_cast(tile_id); + const event_t event_1 = static_cast(tile_id + NumTilesPerCubeIter); TileL0CDyn c_l0_tail_tile(valid_size, valid_size); TASSIGN(c_l0_tail_tile, 0x0 + final_c_buffer_index * TileLen * sizeof(OutputT)); + if constexpr (final_c_buffer_index == 1) { + set_flag(PIPE_M, PIPE_FIX, event_1); + wait_flag(PIPE_M, PIPE_FIX, event_1); + } else { + set_flag(PIPE_M, PIPE_FIX, event_0); + wait_flag(PIPE_M, PIPE_FIX, event_0); + } + set_flag(PIPE_FIX, PIPE_MTE3, static_cast(tile_id)); + wait_flag(PIPE_FIX, PIPE_MTE3, static_cast(tile_id)); GlobalTileOutDyn M_inv_global_out_dyn( M_inv + bsnd_offset, {1, 1, 1, valid_size, valid_size}, {1, 1, 1, row_stride, 1}); @@ -712,21 +799,9 @@ AICORE void runKernelTriInvRecUnroll(__gm__ OutputT* M_inv, __gm__ InputT* M, nullptr) { #if (__CHECK_FEATURE_AT_PRECOMPILE) || \ (__CCE_AICORE__ == 220 && defined(__DAV_C220_CUBE__)) - if constexpr (IsBSND) { - if (chunk_indices != nullptr && chunk_valid_sizes != nullptr) { - TriInvRecUnrollKernelBSNDVarlen( - M_inv, M, I_neg, total_tiles, num_bsnd_heads, chunk_indices, - chunk_valid_sizes); - } else { - TriInvRecUnrollKernel(M_inv, M, I_neg, total_tiles, - num_bsnd_heads); - } - } else { - TriInvRecUnrollKernel(M_inv, M, I_neg, total_tiles, num_bsnd_heads); - } + TriInvRecUnrollKernel(M_inv, M, I_neg, total_tiles, num_bsnd_heads, + chunk_indices, chunk_valid_sizes); #else // Nothing to do on AIV #endif From a43f974dd5d4be3558854f049f7e8b3a773fbdae Mon Sep 17 00:00:00 2001 From: learning-chip Date: Wed, 1 Apr 2026 20:58:27 +0000 Subject: [PATCH 07/14] compute chunk_metadata from cu_seqlens --- .../benchmark_bsnd_fast_inverse.py | 74 +++++++++----- .../jit_cpp/fast_inverse/run_fast_inverse.py | 99 +++++++++++-------- 2 files changed, 104 insertions(+), 69 deletions(-) diff --git a/examples/jit_cpp/fast_inverse/benchmark_bsnd_fast_inverse.py b/examples/jit_cpp/fast_inverse/benchmark_bsnd_fast_inverse.py index 41fed753..ac25d9e5 100644 --- a/examples/jit_cpp/fast_inverse/benchmark_bsnd_fast_inverse.py +++ b/examples/jit_cpp/fast_inverse/benchmark_bsnd_fast_inverse.py @@ -68,6 +68,39 @@ def make_minus_identity(matrix_size: int, device: str) -> torch.Tensor: return minus_identity +def chunk_metadata_from_cu_seqlens( + cu_seqlens: torch.Tensor, + chunk_size: int, +) -> tuple[torch.Tensor, torch.Tensor]: + cu_seqlens_np = cu_seqlens.detach().cpu().numpy().astype(np.int64, copy=False) + seq_starts = cu_seqlens_np[:-1] + seq_lens = cu_seqlens_np[1:] - seq_starts + seq_num_chunks = (seq_lens + chunk_size - 1) // chunk_size + total_chunks = int(seq_num_chunks.sum()) + + chunk_indices = np.empty(total_chunks, dtype=np.int32) + chunk_valid_sizes = np.empty(total_chunks, dtype=np.int32) + cursor = 0 + for seq_start, seq_len, num_chunks in zip(seq_starts, seq_lens, seq_num_chunks): + num_chunks_int = int(num_chunks) + local_offsets = np.arange(num_chunks_int, dtype=np.int64) * chunk_size + next_cursor = cursor + num_chunks_int + chunk_indices[cursor:next_cursor] = (seq_start + local_offsets).astype( + np.int32, + copy=False, + ) + chunk_valid_sizes[cursor:next_cursor] = np.minimum( + chunk_size, + seq_len - local_offsets, + ).astype(np.int32, copy=False) + cursor = next_cursor + + return ( + torch.from_numpy(chunk_indices).to(device=cu_seqlens.device), + torch.from_numpy(chunk_valid_sizes).to(device=cu_seqlens.device), + ) + + def random_chunk_mats( total_chunks: int, num_heads: int, @@ -104,7 +137,7 @@ def build_uniform_varlen_input( batch_size: int, seqlen: int, chunk_size: int, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, torch.Tensor]: total_tokens = batch_size * seqlen packed_input = fixed_input.reshape(1, total_tokens, fixed_input.shape[2], chunk_size).contiguous() cu_seqlens = torch.arange( @@ -114,15 +147,7 @@ def build_uniform_varlen_input( dtype=torch.int32, device=fixed_input.device, ) - chunk_indices = torch.arange( - 0, - total_tokens, - chunk_size, - dtype=torch.int32, - device=fixed_input.device, - ) - chunk_valid_sizes = torch.full_like(chunk_indices, chunk_size) - return packed_input, cu_seqlens, chunk_indices, chunk_valid_sizes + return packed_input, cu_seqlens def sample_true_varlen_lengths( @@ -151,7 +176,7 @@ def build_true_varlen_input( chunk_size: int, scale: float, device: str, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, torch.Tensor]: cu_seqlens = np.cumsum([0, *seq_lens], dtype=np.int64) num_chunks = sum((seq_len + chunk_size - 1) // chunk_size for seq_len in seq_lens) chunk_mats = random_chunk_mats( @@ -167,8 +192,6 @@ def build_true_varlen_input( dtype=torch.half, device=device, ) - chunk_indices: list[int] = [] - chunk_valid_sizes: list[int] = [] chunk_idx = 0 token_row = 0 @@ -183,16 +206,12 @@ def build_true_varlen_input( head_idx, :actual_size, ] = chunk[head_idx, :actual_size, :actual_size] - chunk_indices.append(token_row) - chunk_valid_sizes.append(actual_size) token_row += actual_size chunk_idx += 1 return ( packed_input.contiguous(), torch.tensor(cu_seqlens.tolist(), dtype=torch.int32, device=device), - torch.tensor(chunk_indices, dtype=torch.int32, device=device), - torch.tensor(chunk_valid_sizes, dtype=torch.int32, device=device), ) @@ -222,14 +241,19 @@ def run(): def make_varlen_runner( tri_inv_func, tensor_in: torch.Tensor, - chunk_indices: torch.Tensor, - chunk_valid_sizes: torch.Tensor, + cu_seqlens: torch.Tensor, ) -> tuple[callable, torch.Tensor]: matrix_size = tensor_in.shape[-1] num_bsnd_heads = tensor_in.shape[-2] - num_matrices = chunk_indices.numel() * num_bsnd_heads + seq_lens = cu_seqlens[1:].to(torch.int64) - cu_seqlens[:-1].to(torch.int64) + num_chunks = ((seq_lens + matrix_size - 1) // matrix_size).sum().item() + num_matrices = int(num_chunks) * num_bsnd_heads tensor_out = torch.empty_like(tensor_in, dtype=torch.float32) minus_identity = make_minus_identity(matrix_size, str(tensor_in.device)) + chunk_indices, chunk_valid_sizes = chunk_metadata_from_cu_seqlens( + cu_seqlens, + matrix_size, + ) def run(): tri_inv_func( @@ -492,7 +516,7 @@ def main() -> None: num_heads=args.H, chunk_size=args.chunk_size, ) - varlen_input, cu_seqlens, chunk_indices, chunk_valid_sizes = build_uniform_varlen_input( + varlen_input, cu_seqlens = build_uniform_varlen_input( fixed_input, batch_size=args.B, seqlen=seqlen, @@ -505,8 +529,7 @@ def main() -> None: varlen_run, varlen_out = make_varlen_runner( tri_inv_func, varlen_input, - chunk_indices, - chunk_valid_sizes, + cu_seqlens, ) fixed_run() @@ -579,7 +602,7 @@ def main() -> None: for sample_idx in range(args.true_varlen_samples): seq_lens = sample_true_varlen_lengths(args.B, total_tokens, rng) - packed_input, cu_seqlens, chunk_indices, chunk_valid_sizes = build_true_varlen_input( + packed_input, cu_seqlens = build_true_varlen_input( seq_lens=seq_lens, num_heads=args.H, chunk_size=args.chunk_size, @@ -589,8 +612,7 @@ def main() -> None: varlen_run_true, _ = make_varlen_runner( tri_inv_func, packed_input, - chunk_indices, - chunk_valid_sizes, + cu_seqlens, ) times_ms = benchmark_ms( varlen_run_true, diff --git a/examples/jit_cpp/fast_inverse/run_fast_inverse.py b/examples/jit_cpp/fast_inverse/run_fast_inverse.py index a4b6e607..75a247f8 100644 --- a/examples/jit_cpp/fast_inverse/run_fast_inverse.py +++ b/examples/jit_cpp/fast_inverse/run_fast_inverse.py @@ -93,6 +93,40 @@ def _make_minus_identity(matrix_size: int, device: str) -> torch.Tensor: return I_neg +def _chunk_metadata_from_cu_seqlens( + cu_seqlens: torch.Tensor | list[int], + chunk_size: int, +) -> tuple[torch.Tensor, torch.Tensor]: + if isinstance(cu_seqlens, torch.Tensor): + cu_seqlens_np = cu_seqlens.detach().cpu().numpy().astype(np.int64, copy=False) + else: + cu_seqlens_np = np.asarray(cu_seqlens, dtype=np.int64) + + seq_starts = cu_seqlens_np[:-1] + seq_lens = cu_seqlens_np[1:] - seq_starts + seq_num_chunks = (seq_lens + chunk_size - 1) // chunk_size + total_chunks = int(seq_num_chunks.sum()) + + chunk_indices = np.empty(total_chunks, dtype=np.int32) + chunk_valid_sizes = np.empty(total_chunks, dtype=np.int32) + cursor = 0 + for seq_start, seq_len, num_chunks in zip(seq_starts, seq_lens, seq_num_chunks): + num_chunks_int = int(num_chunks) + local_offsets = np.arange(num_chunks_int, dtype=np.int64) * chunk_size + next_cursor = cursor + num_chunks_int + chunk_indices[cursor:next_cursor] = (seq_start + local_offsets).astype( + np.int32, + copy=False, + ) + chunk_valid_sizes[cursor:next_cursor] = np.minimum( + chunk_size, + seq_len - local_offsets, + ).astype(np.int32, copy=False) + cursor = next_cursor + + return torch.from_numpy(chunk_indices), torch.from_numpy(chunk_valid_sizes) + + def _run_kernel(tri_inv_func, U_fp16: torch.Tensor): """ Allocate output, build -I, run kernel, return fp64 CPU result. @@ -116,36 +150,39 @@ def _run_kernel(tri_inv_func, U_fp16: torch.Tensor): def _run_kernel_bsnd( tri_inv_func, U_bsnd_fp16: torch.Tensor, - chunk_indices: torch.Tensor | None = None, - chunk_valid_sizes: torch.Tensor | None = None, + cu_seqlens: torch.Tensor | None = None, ): """ Run the kernel in BSND mode and return fp64 CPU result. U_bsnd_fp16 : (B, S, N, D) half tensor on NPU where each (D, D) block along the S dimension is one matrix to invert. - chunk_indices : optional int32 tensor containing the unpadded row start of - each valid chunk for varlen BSND inputs. - chunk_valid_sizes : optional int32 tensor containing the runtime size of - each chunk for varlen BSND inputs. + cu_seqlens : optional int32 tensor containing cumulative sequence lengths + for varlen BSND inputs. """ matrix_size = U_bsnd_fp16.shape[-1] num_bsnd_heads = U_bsnd_fp16.shape[-2] - if chunk_indices is not None and chunk_valid_sizes is not None: - num_matrices = chunk_indices.numel() * num_bsnd_heads + if cu_seqlens is not None: + seq_lens = cu_seqlens[1:].to(torch.int64) - cu_seqlens[:-1].to(torch.int64) + num_chunks = ((seq_lens + matrix_size - 1) // matrix_size).sum().item() + num_matrices = int(num_chunks) * num_bsnd_heads else: num_matrices = U_bsnd_fp16.numel() // (matrix_size * matrix_size) device = U_bsnd_fp16.device tensor_out = torch.zeros_like(U_bsnd_fp16, dtype=torch.float32) I_neg = _make_minus_identity(matrix_size, str(device)) - if chunk_indices is not None: - chunk_indices = chunk_indices.to(device=device, dtype=torch.int32).contiguous() - if chunk_valid_sizes is not None: - chunk_valid_sizes = chunk_valid_sizes.to( - device=device, - dtype=torch.int32, - ).contiguous() + if cu_seqlens is not None: + cu_seqlens = cu_seqlens.to(device=device, dtype=torch.int32).contiguous() + chunk_indices, chunk_valid_sizes = _chunk_metadata_from_cu_seqlens( + cu_seqlens, + matrix_size, + ) + chunk_indices = chunk_indices.to(device=device).contiguous() + chunk_valid_sizes = chunk_valid_sizes.to(device=device).contiguous() + else: + chunk_indices = None + chunk_valid_sizes = None torch.npu.synchronize() tri_inv_func( @@ -173,8 +210,6 @@ def _build_varlen_bsnd_case( Build an unpadded BSND tensor plus reference output for varlen testing. Each sequence contributes only its true rows in the packed BSND tensor. - chunk_indices records the unpadded row offset of every valid chunk and - chunk_valid_sizes stores each chunk's runtime size. """ seq_lens = [ cu_seqlens[i + 1] - cu_seqlens[i] @@ -195,9 +230,6 @@ def _build_varlen_bsnd_case( U = torch.zeros((1, total_tokens, num_heads, chunk_size), dtype=torch.float64) golden = torch.zeros((1, total_tokens, num_heads, chunk_size), dtype=torch.float64) - chunk_indices: list[int] = [] - chunk_valid_sizes: list[int] = [] - chunk_infos: list[tuple[int, int, int]] = [] chunk_idx = 0 for seq_idx in range(len(cu_seqlens) - 1): @@ -221,17 +253,12 @@ def _build_varlen_bsnd_case( :actual_size, ] = invert_single_chunk_ref(U_valid) - chunk_indices.append(chunk_start) - chunk_valid_sizes.append(actual_size) - chunk_infos.append((chunk_start, chunk_start, actual_size)) chunk_idx += 1 return ( U, golden, - chunk_infos, - torch.tensor(chunk_indices, dtype=torch.int32), - torch.tensor(chunk_valid_sizes, dtype=torch.int32), + torch.tensor(cu_seqlens, dtype=torch.int32), ) @@ -318,7 +345,7 @@ def _test_case_bsnd_varlen( ftol: float, label: str, ): - U_varlen, golden, chunk_infos, chunk_indices, chunk_valid_sizes = _build_varlen_bsnd_case( + U_varlen, golden, cu_seqlens_tensor = _build_varlen_bsnd_case( gen, cu_seqlens, N, @@ -327,23 +354,9 @@ def _test_case_bsnd_varlen( actual_varlen = _run_kernel_bsnd( tri_inv_func, U_varlen.to(torch.half).npu(), - chunk_indices=chunk_indices.npu(), - chunk_valid_sizes=chunk_valid_sizes.npu(), + cu_seqlens=cu_seqlens_tensor.npu(), ) - - actual = torch.zeros_like(golden) - for input_row, token_row, actual_size in chunk_infos: - actual[ - :, - token_row : token_row + actual_size, - :, - :actual_size, - ] = actual_varlen[ - :, - input_row : input_row + actual_size, - :, - :actual_size, - ] + actual = actual_varlen frob = torch.sqrt( torch.sum((golden - actual) ** 2) / torch.sum(golden ** 2) From 496e6c4bb17997f8cec47f95ef6c5571f778272f Mon Sep 17 00:00:00 2001 From: learning-chip Date: Wed, 1 Apr 2026 20:58:36 +0000 Subject: [PATCH 08/14] add gitignore for fast_inverse example --- examples/jit_cpp/fast_inverse/.gitignore | 2 ++ 1 file changed, 2 insertions(+) create mode 100644 examples/jit_cpp/fast_inverse/.gitignore diff --git a/examples/jit_cpp/fast_inverse/.gitignore b/examples/jit_cpp/fast_inverse/.gitignore new file mode 100644 index 00000000..a937d7e5 --- /dev/null +++ b/examples/jit_cpp/fast_inverse/.gitignore @@ -0,0 +1,2 @@ +benchmark_results +*.so From 16496b1de5c831521d450f364b5bc4cb1725c645 Mon Sep 17 00:00:00 2001 From: learning-chip Date: Wed, 1 Apr 2026 21:14:53 +0000 Subject: [PATCH 09/14] compute chunk metadata inside NPU kernel using scalar core unit --- examples/jit_cpp/fast_inverse/README.md | 11 +- .../benchmark_bsnd_fast_inverse.py | 45 +--- .../jit_cpp/fast_inverse/fast_inverse.cpp | 13 +- .../fast_inverse/jit_util_fast_inverse.py | 28 +-- .../kernel_tri_inv_rec_unroll.cpp | 219 +++++++++--------- .../jit_cpp/fast_inverse/run_fast_inverse.py | 52 +---- 6 files changed, 143 insertions(+), 225 deletions(-) diff --git a/examples/jit_cpp/fast_inverse/README.md b/examples/jit_cpp/fast_inverse/README.md index b5a6cdfb..669aa338 100644 --- a/examples/jit_cpp/fast_inverse/README.md +++ b/examples/jit_cpp/fast_inverse/README.md @@ -52,12 +52,11 @@ and batch configurations. ### Varlen BSND mode -The standalone example also supports variable-length BSND inputs by padding each -sequence to the next multiple of `D` and passing a `chunk_indices` tensor to the -kernel. Each entry in `chunk_indices` is the padded row-start of one valid -chunk. The kernel still inverts dense `D x D` tiles; the Python harness pads -inputs before launch and slices the padded rows back away when validating the -result. +The standalone example also supports variable-length BSND inputs with the same +external signature as the Triton reference path: callers provide packed BSND +data plus `cu_seqlens`, and the PTO kernel derives each chunk row-start and +tail size internally on NPU. The kernel still inverts dense `D x D` tiles, but +tail chunks load/store only their valid prefix. ### Benchmark diff --git a/examples/jit_cpp/fast_inverse/benchmark_bsnd_fast_inverse.py b/examples/jit_cpp/fast_inverse/benchmark_bsnd_fast_inverse.py index ac25d9e5..bf79cd3b 100644 --- a/examples/jit_cpp/fast_inverse/benchmark_bsnd_fast_inverse.py +++ b/examples/jit_cpp/fast_inverse/benchmark_bsnd_fast_inverse.py @@ -68,36 +68,14 @@ def make_minus_identity(matrix_size: int, device: str) -> torch.Tensor: return minus_identity -def chunk_metadata_from_cu_seqlens( +def count_varlen_chunks( cu_seqlens: torch.Tensor, chunk_size: int, -) -> tuple[torch.Tensor, torch.Tensor]: - cu_seqlens_np = cu_seqlens.detach().cpu().numpy().astype(np.int64, copy=False) - seq_starts = cu_seqlens_np[:-1] - seq_lens = cu_seqlens_np[1:] - seq_starts - seq_num_chunks = (seq_lens + chunk_size - 1) // chunk_size - total_chunks = int(seq_num_chunks.sum()) - - chunk_indices = np.empty(total_chunks, dtype=np.int32) - chunk_valid_sizes = np.empty(total_chunks, dtype=np.int32) - cursor = 0 - for seq_start, seq_len, num_chunks in zip(seq_starts, seq_lens, seq_num_chunks): - num_chunks_int = int(num_chunks) - local_offsets = np.arange(num_chunks_int, dtype=np.int64) * chunk_size - next_cursor = cursor + num_chunks_int - chunk_indices[cursor:next_cursor] = (seq_start + local_offsets).astype( - np.int32, - copy=False, - ) - chunk_valid_sizes[cursor:next_cursor] = np.minimum( - chunk_size, - seq_len - local_offsets, - ).astype(np.int32, copy=False) - cursor = next_cursor - - return ( - torch.from_numpy(chunk_indices).to(device=cu_seqlens.device), - torch.from_numpy(chunk_valid_sizes).to(device=cu_seqlens.device), +) -> int: + cu_seqlens_list = [int(x) for x in cu_seqlens.detach().cpu().tolist()] + return sum( + (cu_seqlens_list[i + 1] - cu_seqlens_list[i] + chunk_size - 1) // chunk_size + for i in range(len(cu_seqlens_list) - 1) ) @@ -245,15 +223,9 @@ def make_varlen_runner( ) -> tuple[callable, torch.Tensor]: matrix_size = tensor_in.shape[-1] num_bsnd_heads = tensor_in.shape[-2] - seq_lens = cu_seqlens[1:].to(torch.int64) - cu_seqlens[:-1].to(torch.int64) - num_chunks = ((seq_lens + matrix_size - 1) // matrix_size).sum().item() - num_matrices = int(num_chunks) * num_bsnd_heads + num_matrices = count_varlen_chunks(cu_seqlens, matrix_size) * num_bsnd_heads tensor_out = torch.empty_like(tensor_in, dtype=torch.float32) minus_identity = make_minus_identity(matrix_size, str(tensor_in.device)) - chunk_indices, chunk_valid_sizes = chunk_metadata_from_cu_seqlens( - cu_seqlens, - matrix_size, - ) def run(): tri_inv_func( @@ -263,8 +235,7 @@ def run(): matrix_size, num_matrices, num_bsnd_heads, - chunk_indices=chunk_indices, - chunk_valid_sizes=chunk_valid_sizes, + cu_seqlens=cu_seqlens, ) return run, tensor_out diff --git a/examples/jit_cpp/fast_inverse/fast_inverse.cpp b/examples/jit_cpp/fast_inverse/fast_inverse.cpp index 12c81086..ac77e8ba 100644 --- a/examples/jit_cpp/fast_inverse/fast_inverse.cpp +++ b/examples/jit_cpp/fast_inverse/fast_inverse.cpp @@ -26,18 +26,15 @@ for the full License text. * @param num_matrices Total number of matrices to invert. * @param num_bsnd_heads 0 for standard (B…ND) layout; * N (number of heads) for BSND layout. - * @param chunk_indices Optional int32 pointer used only for varlen BSND. Each - * entry is the absolute row offset of one chunk within the - * unpadded BSND tensor. - * @param chunk_valid_sizes Optional int32 pointer used only for varlen BSND. - * Each entry stores the runtime size of that chunk. + * @param cu_seqlens Optional int32 pointer used only for varlen BSND. Matches + * the Triton-style API and stores cumulative sequence + * boundaries for the packed BSND tensor. */ extern "C" void call_kernel(uint32_t blockDim, void* stream, void* tensor_out, void* tensor_in, void* minus_identity_in, uint32_t matrix_size, uint32_t num_matrices, - uint32_t num_bsnd_heads, void* chunk_indices, - void* chunk_valid_sizes) { + uint32_t num_bsnd_heads, void* cu_seqlens) { tri_inv_rec_unroll_fp16<<>>( tensor_out, tensor_in, minus_identity_in, matrix_size, num_matrices, - num_bsnd_heads, chunk_indices, chunk_valid_sizes); + num_bsnd_heads, cu_seqlens); } diff --git a/examples/jit_cpp/fast_inverse/jit_util_fast_inverse.py b/examples/jit_cpp/fast_inverse/jit_util_fast_inverse.py index 90912373..9e160141 100644 --- a/examples/jit_cpp/fast_inverse/jit_util_fast_inverse.py +++ b/examples/jit_cpp/fast_inverse/jit_util_fast_inverse.py @@ -85,8 +85,7 @@ def load_lib(lib_path: str): ctypes.c_uint32, # matrix_size ctypes.c_uint32, # num_matrices ctypes.c_uint32, # num_bsnd_heads - ctypes.c_void_p, # chunk_indices (optional int32 metadata) - ctypes.c_void_p, # chunk_valid_sizes (optional int32 metadata) + ctypes.c_void_p, # cu_seqlens (optional int32 metadata) ] lib.call_kernel.restype = None @@ -97,23 +96,17 @@ def tri_inv_func( matrix_size: int, num_matrices: int, num_bsnd_heads: int = 0, - chunk_indices: torch.Tensor | None = None, - chunk_valid_sizes: torch.Tensor | None = None, + cu_seqlens: torch.Tensor | None = None, block_dim: int = BLOCK_DIM, stream_ptr=None, ): if stream_ptr is None: stream_ptr = torch.npu.current_stream()._as_parameter_ # noqa - if chunk_indices is not None: - if chunk_indices.dtype != torch.int32: - raise TypeError("chunk_indices must be int32.") - if not chunk_indices.is_contiguous(): - raise ValueError("chunk_indices must be contiguous.") - if chunk_valid_sizes is not None: - if chunk_valid_sizes.dtype != torch.int32: - raise TypeError("chunk_valid_sizes must be int32.") - if not chunk_valid_sizes.is_contiguous(): - raise ValueError("chunk_valid_sizes must be contiguous.") + if cu_seqlens is not None: + if cu_seqlens.dtype != torch.int32: + raise TypeError("cu_seqlens must be int32.") + if not cu_seqlens.is_contiguous(): + raise ValueError("cu_seqlens must be contiguous.") effective_block_dim = min(block_dim, num_matrices) lib.call_kernel( effective_block_dim, @@ -124,11 +117,8 @@ def tri_inv_func( matrix_size, num_matrices, num_bsnd_heads, - _torch_to_ctypes(chunk_indices) - if chunk_indices is not None - else ctypes.c_void_p(), - _torch_to_ctypes(chunk_valid_sizes) - if chunk_valid_sizes is not None + _torch_to_ctypes(cu_seqlens) + if cu_seqlens is not None else ctypes.c_void_p(), ) diff --git a/examples/jit_cpp/fast_inverse/kernel_tri_inv_rec_unroll.cpp b/examples/jit_cpp/fast_inverse/kernel_tri_inv_rec_unroll.cpp index 9065b7fd..be8f43fb 100644 --- a/examples/jit_cpp/fast_inverse/kernel_tri_inv_rec_unroll.cpp +++ b/examples/jit_cpp/fast_inverse/kernel_tri_inv_rec_unroll.cpp @@ -31,21 +31,39 @@ AICORE inline uint32_t GetBSNDFixedTileOffset(uint32_t tile_id, return BSND_OFFSET(tile_id, num_bsnd_heads, matrix_size, matrix_size); } +struct BSNDVarlenTileInfo { + uint32_t bsnd_offset; + uint32_t valid_size; +}; + /* - * For varlen BSND, chunk_indices stores the absolute row offset of each chunk - * inside the unpadded BSND tensor. Each tile_id still enumerates chunk-major, - * then head-major. + * For cu_seqlens-based varlen BSND, tile_id still enumerates chunk-major then + * head-major. We recover the owning sequence by scanning cu_seqlens and + * counting chunks per sequence. */ -AICORE inline uint32_t GetBSNDVarlenTileOffset(uint32_t tile_id, - uint32_t num_bsnd_heads, - uint32_t matrix_size, - __gm__ int32_t* chunk_indices) { +AICORE inline BSNDVarlenTileInfo GetBSNDVarlenTileInfoFromCuSeqlens( + uint32_t tile_id, uint32_t num_bsnd_heads, uint32_t matrix_size, + __gm__ int32_t* cu_seqlens) { const uint32_t head_idx = tile_id % num_bsnd_heads; const uint32_t chunk_idx = tile_id / num_bsnd_heads; - const uint32_t chunk_row_start = - static_cast(chunk_indices[chunk_idx]); - return chunk_row_start * num_bsnd_heads * matrix_size + - head_idx * matrix_size; + + uint32_t seq_start = static_cast(cu_seqlens[0]); + uint32_t accumulated_chunks = 0; + for (uint32_t seq_idx = 0;; ++seq_idx) { + const uint32_t seq_end = static_cast(cu_seqlens[seq_idx + 1]); + const uint32_t seq_len = seq_end - seq_start; + const uint32_t seq_num_chunks = CeilDiv(seq_len, matrix_size); + if (chunk_idx < accumulated_chunks + seq_num_chunks) { + const uint32_t local_chunk_idx = chunk_idx - accumulated_chunks; + const uint32_t row_start = seq_start + local_chunk_idx * matrix_size; + const uint32_t valid_size = + min(static_cast(seq_end - row_start), matrix_size); + return {row_start * num_bsnd_heads * matrix_size + head_idx * matrix_size, + valid_size}; + } + accumulated_chunks += seq_num_chunks; + seq_start = seq_end; + } } /* @@ -399,10 +417,7 @@ AICORE inline void TriInvRecUnrollKernel(__gm__ OutputT* M_inv, __gm__ InputT* M, __gm__ InputT* I_neg, uint32_t total_tiles, uint32_t num_bsnd_heads = 0, - __gm__ int32_t* chunk_indices = - nullptr, - __gm__ int32_t* chunk_valid_sizes = - nullptr) { + __gm__ int32_t* cu_seqlens = nullptr) { constexpr uint32_t TileLen = MatrixSize * MatrixSize; constexpr uint32_t FractalSize = 16; constexpr uint32_t NumL0Buffers = 2; @@ -489,6 +504,9 @@ AICORE inline void TriInvRecUnrollKernel(__gm__ OutputT* M_inv, const uint32_t max_iters_per_aic = CeilDiv(total_tiles, (uint32_t)(NumTilesPerCubeIter * get_block_num())); + uint32_t bsnd_tile_offsets[NumTilesPerCubeIter] = {0}; + uint32_t bsnd_tile_valid_sizes[NumTilesPerCubeIter] = {0}; + uint32_t next_tile_id_that_waits_for_pipe_fix_pipe_m = 0; set_flag(PIPE_FIX, PIPE_M, static_cast(next_tile_id_that_waits_for_pipe_fix_pipe_m)); @@ -506,33 +524,32 @@ AICORE inline void TriInvRecUnrollKernel(__gm__ OutputT* M_inv, ++tile_id) { if constexpr (IsBSND) { const uint32_t global_tile_id = global_index + tile_id; - const uint32_t bsnd_offset = - chunk_indices != nullptr - ? GetBSNDVarlenTileOffset(global_tile_id, num_bsnd_heads, - MatrixSize, chunk_indices) - : GetBSNDFixedTileOffset(global_tile_id, num_bsnd_heads, - MatrixSize); + if (cu_seqlens != nullptr) { + const BSNDVarlenTileInfo tile_info = GetBSNDVarlenTileInfoFromCuSeqlens( + global_tile_id, num_bsnd_heads, MatrixSize, cu_seqlens); + bsnd_tile_offsets[tile_id] = tile_info.bsnd_offset; + bsnd_tile_valid_sizes[tile_id] = tile_info.valid_size; + } else { + bsnd_tile_offsets[tile_id] = + GetBSNDFixedTileOffset(global_tile_id, num_bsnd_heads, MatrixSize); + bsnd_tile_valid_sizes[tile_id] = MatrixSize; + } + const uint32_t bsnd_offset = bsnd_tile_offsets[tile_id]; + const uint32_t valid_size = bsnd_tile_valid_sizes[tile_id]; const int row_stride = static_cast(MatrixSize * num_bsnd_heads); wait_flag(PIPE_M, PIPE_MTE2, static_cast(tile_id)); - if (chunk_valid_sizes != nullptr) { - const uint32_t chunk_idx = global_tile_id / num_bsnd_heads; - const uint32_t valid_size = - static_cast(chunk_valid_sizes[chunk_idx]); - if (valid_size < MatrixSize) { - TileL1ABDyn Y_dyn_l1_tile(valid_size, valid_size); - TASSIGN(Y_dyn_l1_tile, - 0x0 + (5 + tile_id) * TileLen * sizeof(InputT)); - GlobalTileInDyn M_global_in_dyn( - M + bsnd_offset, {1, 1, 1, valid_size, valid_size}, - {1, 1, 1, row_stride, 1}); - TLOAD(Y_dyn_l1_tile, M_global_in_dyn); - set_flag(PIPE_MTE2, PIPE_MTE1, static_cast(tile_id)); - wait_flag(PIPE_MTE2, PIPE_MTE1, static_cast(tile_id)); - TFILLPAD(Y_dyn_l1_tile, Y_dyn_l1_tile); - } else { - GlobalTileIn M_global_in(M + bsnd_offset, {}, {row_stride}); - TLOAD(Y_l1_tile[tile_id], M_global_in); - } + if (valid_size < MatrixSize) { + TileL1ABDyn Y_dyn_l1_tile(valid_size, valid_size); + TASSIGN(Y_dyn_l1_tile, + 0x0 + (5 + tile_id) * TileLen * sizeof(InputT)); + GlobalTileInDyn M_global_in_dyn( + M + bsnd_offset, + {1, 1, 1, static_cast(valid_size), static_cast(valid_size)}, + {1, 1, 1, row_stride, 1}); + TLOAD(Y_dyn_l1_tile, M_global_in_dyn); + set_flag(PIPE_MTE2, PIPE_MTE1, static_cast(tile_id)); + wait_flag(PIPE_MTE2, PIPE_MTE1, static_cast(tile_id)); + TFILLPAD(Y_dyn_l1_tile, Y_dyn_l1_tile); } else { GlobalTileIn M_global_in(M + bsnd_offset, {}, {row_stride}); TLOAD(Y_l1_tile[tile_id], M_global_in); @@ -560,43 +577,30 @@ AICORE inline void TriInvRecUnrollKernel(__gm__ OutputT* M_inv, set_flag(PIPE_M, PIPE_MTE2, static_cast(tile_id)); if constexpr (IsBSND) { - const uint32_t global_tile_id = global_index + tile_id; - const uint32_t bsnd_offset = - chunk_indices != nullptr - ? GetBSNDVarlenTileOffset(global_tile_id, num_bsnd_heads, - MatrixSize, chunk_indices) - : GetBSNDFixedTileOffset(global_tile_id, num_bsnd_heads, - MatrixSize); + const uint32_t bsnd_offset = bsnd_tile_offsets[tile_id]; + const uint32_t valid_size = bsnd_tile_valid_sizes[tile_id]; const int row_stride = static_cast(MatrixSize * num_bsnd_heads); - if (chunk_valid_sizes != nullptr) { - const uint32_t chunk_idx = global_tile_id / num_bsnd_heads; - const uint32_t valid_size = - static_cast(chunk_valid_sizes[chunk_idx]); - if (valid_size < MatrixSize) { - const event_t event_0 = static_cast(tile_id); - const event_t event_1 = - static_cast(tile_id + NumTilesPerCubeIter); - TileL0CDyn c_l0_tail_tile(valid_size, valid_size); - TASSIGN(c_l0_tail_tile, - 0x0 + final_c_buffer_index * TileLen * sizeof(OutputT)); - if constexpr (final_c_buffer_index == 1) { - set_flag(PIPE_M, PIPE_FIX, event_1); - wait_flag(PIPE_M, PIPE_FIX, event_1); - } else { - set_flag(PIPE_M, PIPE_FIX, event_0); - wait_flag(PIPE_M, PIPE_FIX, event_0); - } - set_flag(PIPE_FIX, PIPE_MTE3, static_cast(tile_id)); - wait_flag(PIPE_FIX, PIPE_MTE3, static_cast(tile_id)); - GlobalTileOutDyn M_inv_global_out_dyn( - M_inv + bsnd_offset, {1, 1, 1, valid_size, valid_size}, - {1, 1, 1, row_stride, 1}); - TSTORE(M_inv_global_out_dyn, c_l0_tail_tile); + if (valid_size < MatrixSize) { + const event_t event_0 = static_cast(tile_id); + const event_t event_1 = + static_cast(tile_id + NumTilesPerCubeIter); + TileL0CDyn c_l0_tail_tile(valid_size, valid_size); + TASSIGN(c_l0_tail_tile, + 0x0 + final_c_buffer_index * TileLen * sizeof(OutputT)); + if constexpr (final_c_buffer_index == 1) { + set_flag(PIPE_M, PIPE_FIX, event_1); + wait_flag(PIPE_M, PIPE_FIX, event_1); } else { - GlobalTileOut M_inv_global_out(M_inv + bsnd_offset, {}, - {row_stride}); - TSTORE(M_inv_global_out, c_l0_tile[final_c_buffer_index]); + set_flag(PIPE_M, PIPE_FIX, event_0); + wait_flag(PIPE_M, PIPE_FIX, event_0); } + set_flag(PIPE_FIX, PIPE_MTE3, static_cast(tile_id)); + wait_flag(PIPE_FIX, PIPE_MTE3, static_cast(tile_id)); + GlobalTileOutDyn M_inv_global_out_dyn( + M_inv + bsnd_offset, + {1, 1, 1, static_cast(valid_size), static_cast(valid_size)}, + {1, 1, 1, row_stride, 1}); + TSTORE(M_inv_global_out_dyn, c_l0_tail_tile); } else { GlobalTileOut M_inv_global_out(M_inv + bsnd_offset, {}, {row_stride}); @@ -626,17 +630,17 @@ AICORE inline void TriInvRecUnrollKernel(__gm__ OutputT* M_inv, * * The input/output tensors stay unpadded. For tail chunks with size * `actual_size < MatrixSize`, the kernel: - * 1. loads only the valid `actual_size x actual_size` prefix via dynamic TLOAD - * 2. zero-fills the remaining rows/cols in-place via TFILLPAD_INPLACE - * 3. runs the original dense recursive inverse on the materialized full tile - * 4. stores only the valid `actual_size x actual_size` prefix back to GM + * 1. derives the chunk row-start and runtime size from `cu_seqlens` + * 2. loads only the valid `actual_size x actual_size` prefix via dynamic TLOAD + * 3. zero-fills the remaining rows/cols in-place via TFILLPAD_INPLACE + * 4. runs the original dense recursive inverse on the materialized full tile + * 5. stores only the valid `actual_size x actual_size` prefix back to GM */ template AICORE inline void TriInvRecUnrollKernelBSNDVarlen( __gm__ OutputT* M_inv, __gm__ InputT* M, __gm__ InputT* I_neg, - uint32_t total_tiles, uint32_t num_bsnd_heads, __gm__ int32_t* chunk_indices, - __gm__ int32_t* chunk_valid_sizes) { + uint32_t total_tiles, uint32_t num_bsnd_heads, __gm__ int32_t* cu_seqlens) { constexpr uint32_t TileLen = MatrixSize * MatrixSize; constexpr uint32_t FractalSize = 16; constexpr uint32_t NumL0Buffers = 2; @@ -732,11 +736,10 @@ AICORE inline void TriInvRecUnrollKernelBSNDVarlen( (global_index + tile_id < total_tiles); ++tile_id) { const uint32_t global_tile_id = global_index + tile_id; - const uint32_t chunk_idx = global_tile_id / num_bsnd_heads; - const uint32_t valid_size = - static_cast(chunk_valid_sizes[chunk_idx]); - const uint32_t bsnd_offset = GetBSNDVarlenTileOffset( - global_tile_id, num_bsnd_heads, MatrixSize, chunk_indices); + const BSNDVarlenTileInfo tile_info = GetBSNDVarlenTileInfoFromCuSeqlens( + global_tile_id, num_bsnd_heads, MatrixSize, cu_seqlens); + const uint32_t valid_size = tile_info.valid_size; + const uint32_t bsnd_offset = tile_info.bsnd_offset; const int row_stride = static_cast(MatrixSize * num_bsnd_heads); if (valid_size == MatrixSize) { @@ -794,14 +797,12 @@ template (M_inv, M, I_neg, total_tiles, num_bsnd_heads, - chunk_indices, chunk_valid_sizes); + cu_seqlens); #else // Nothing to do on AIV #endif @@ -813,30 +814,29 @@ AICORE void run_tri_inv_rec_unroll(__gm__ float* tensor_out, __gm__ InputT* minus_identity_in, uint32_t matrix_size, uint32_t num_matrices, uint32_t num_bsnd_heads, - __gm__ int32_t* chunk_indices, - __gm__ int32_t* chunk_valid_sizes) { + __gm__ int32_t* cu_seqlens) { static_assert(std::is_same_v, "tri_inv_rec_unroll supports only fp16."); switch (matrix_size) { case 16: runKernelTriInvRecUnroll( - tensor_out, tensor_in, minus_identity_in, num_matrices, - num_bsnd_heads, chunk_indices, chunk_valid_sizes); + tensor_out, tensor_in, minus_identity_in, num_matrices, num_bsnd_heads, + cu_seqlens); break; case 32: runKernelTriInvRecUnroll( - tensor_out, tensor_in, minus_identity_in, num_matrices, - num_bsnd_heads, chunk_indices, chunk_valid_sizes); + tensor_out, tensor_in, minus_identity_in, num_matrices, num_bsnd_heads, + cu_seqlens); break; case 64: runKernelTriInvRecUnroll( - tensor_out, tensor_in, minus_identity_in, num_matrices, - num_bsnd_heads, chunk_indices, chunk_valid_sizes); + tensor_out, tensor_in, minus_identity_in, num_matrices, num_bsnd_heads, + cu_seqlens); break; case 128: runKernelTriInvRecUnroll( - tensor_out, tensor_in, minus_identity_in, num_matrices, - num_bsnd_heads, chunk_indices, chunk_valid_sizes); + tensor_out, tensor_in, minus_identity_in, num_matrices, num_bsnd_heads, + cu_seqlens); break; } } @@ -844,47 +844,40 @@ AICORE void run_tri_inv_rec_unroll(__gm__ float* tensor_out, extern "C" __global__ AICORE void tri_inv_rec_unroll_fp16( __gm__ void* tensor_out, __gm__ void* tensor_in, __gm__ void* minus_identity_in, uint32_t matrix_size, uint32_t num_matrices, - uint32_t num_bsnd_heads, __gm__ void* chunk_indices, - __gm__ void* chunk_valid_sizes) { + uint32_t num_bsnd_heads, __gm__ void* cu_seqlens) { if (num_bsnd_heads == 0) { if (num_matrices <= get_block_num()) { run_tri_inv_rec_unroll( (__gm__ float*)tensor_out, (__gm__ half*)tensor_in, (__gm__ half*)minus_identity_in, matrix_size, num_matrices, - num_bsnd_heads, (__gm__ int32_t*)chunk_indices, - (__gm__ int32_t*)chunk_valid_sizes); + num_bsnd_heads, (__gm__ int32_t*)cu_seqlens); } else if (num_matrices <= 2 * get_block_num()) { run_tri_inv_rec_unroll( (__gm__ float*)tensor_out, (__gm__ half*)tensor_in, (__gm__ half*)minus_identity_in, matrix_size, num_matrices, - num_bsnd_heads, (__gm__ int32_t*)chunk_indices, - (__gm__ int32_t*)chunk_valid_sizes); + num_bsnd_heads, (__gm__ int32_t*)cu_seqlens); } else { run_tri_inv_rec_unroll( (__gm__ float*)tensor_out, (__gm__ half*)tensor_in, (__gm__ half*)minus_identity_in, matrix_size, num_matrices, - num_bsnd_heads, (__gm__ int32_t*)chunk_indices, - (__gm__ int32_t*)chunk_valid_sizes); + num_bsnd_heads, (__gm__ int32_t*)cu_seqlens); } } else { if (num_matrices <= get_block_num()) { run_tri_inv_rec_unroll( (__gm__ float*)tensor_out, (__gm__ half*)tensor_in, (__gm__ half*)minus_identity_in, matrix_size, num_matrices, - num_bsnd_heads, (__gm__ int32_t*)chunk_indices, - (__gm__ int32_t*)chunk_valid_sizes); + num_bsnd_heads, (__gm__ int32_t*)cu_seqlens); } else if (num_matrices <= 2 * get_block_num()) { run_tri_inv_rec_unroll( (__gm__ float*)tensor_out, (__gm__ half*)tensor_in, (__gm__ half*)minus_identity_in, matrix_size, num_matrices, - num_bsnd_heads, (__gm__ int32_t*)chunk_indices, - (__gm__ int32_t*)chunk_valid_sizes); + num_bsnd_heads, (__gm__ int32_t*)cu_seqlens); } else { run_tri_inv_rec_unroll( (__gm__ float*)tensor_out, (__gm__ half*)tensor_in, (__gm__ half*)minus_identity_in, matrix_size, num_matrices, - num_bsnd_heads, (__gm__ int32_t*)chunk_indices, - (__gm__ int32_t*)chunk_valid_sizes); + num_bsnd_heads, (__gm__ int32_t*)cu_seqlens); } } } diff --git a/examples/jit_cpp/fast_inverse/run_fast_inverse.py b/examples/jit_cpp/fast_inverse/run_fast_inverse.py index 75a247f8..b7dbad99 100644 --- a/examples/jit_cpp/fast_inverse/run_fast_inverse.py +++ b/examples/jit_cpp/fast_inverse/run_fast_inverse.py @@ -93,38 +93,18 @@ def _make_minus_identity(matrix_size: int, device: str) -> torch.Tensor: return I_neg -def _chunk_metadata_from_cu_seqlens( +def _count_varlen_chunks( cu_seqlens: torch.Tensor | list[int], chunk_size: int, -) -> tuple[torch.Tensor, torch.Tensor]: +) -> int: if isinstance(cu_seqlens, torch.Tensor): - cu_seqlens_np = cu_seqlens.detach().cpu().numpy().astype(np.int64, copy=False) + cu_seqlens_list = [int(x) for x in cu_seqlens.detach().cpu().tolist()] else: - cu_seqlens_np = np.asarray(cu_seqlens, dtype=np.int64) - - seq_starts = cu_seqlens_np[:-1] - seq_lens = cu_seqlens_np[1:] - seq_starts - seq_num_chunks = (seq_lens + chunk_size - 1) // chunk_size - total_chunks = int(seq_num_chunks.sum()) - - chunk_indices = np.empty(total_chunks, dtype=np.int32) - chunk_valid_sizes = np.empty(total_chunks, dtype=np.int32) - cursor = 0 - for seq_start, seq_len, num_chunks in zip(seq_starts, seq_lens, seq_num_chunks): - num_chunks_int = int(num_chunks) - local_offsets = np.arange(num_chunks_int, dtype=np.int64) * chunk_size - next_cursor = cursor + num_chunks_int - chunk_indices[cursor:next_cursor] = (seq_start + local_offsets).astype( - np.int32, - copy=False, - ) - chunk_valid_sizes[cursor:next_cursor] = np.minimum( - chunk_size, - seq_len - local_offsets, - ).astype(np.int32, copy=False) - cursor = next_cursor - - return torch.from_numpy(chunk_indices), torch.from_numpy(chunk_valid_sizes) + cu_seqlens_list = [int(x) for x in cu_seqlens] + return sum( + (cu_seqlens_list[i + 1] - cu_seqlens_list[i] + chunk_size - 1) // chunk_size + for i in range(len(cu_seqlens_list) - 1) + ) def _run_kernel(tri_inv_func, U_fp16: torch.Tensor): @@ -163,9 +143,7 @@ def _run_kernel_bsnd( matrix_size = U_bsnd_fp16.shape[-1] num_bsnd_heads = U_bsnd_fp16.shape[-2] if cu_seqlens is not None: - seq_lens = cu_seqlens[1:].to(torch.int64) - cu_seqlens[:-1].to(torch.int64) - num_chunks = ((seq_lens + matrix_size - 1) // matrix_size).sum().item() - num_matrices = int(num_chunks) * num_bsnd_heads + num_matrices = _count_varlen_chunks(cu_seqlens, matrix_size) * num_bsnd_heads else: num_matrices = U_bsnd_fp16.numel() // (matrix_size * matrix_size) device = U_bsnd_fp16.device @@ -174,15 +152,6 @@ def _run_kernel_bsnd( I_neg = _make_minus_identity(matrix_size, str(device)) if cu_seqlens is not None: cu_seqlens = cu_seqlens.to(device=device, dtype=torch.int32).contiguous() - chunk_indices, chunk_valid_sizes = _chunk_metadata_from_cu_seqlens( - cu_seqlens, - matrix_size, - ) - chunk_indices = chunk_indices.to(device=device).contiguous() - chunk_valid_sizes = chunk_valid_sizes.to(device=device).contiguous() - else: - chunk_indices = None - chunk_valid_sizes = None torch.npu.synchronize() tri_inv_func( @@ -192,8 +161,7 @@ def _run_kernel_bsnd( matrix_size, num_matrices, num_bsnd_heads, - chunk_indices=chunk_indices, - chunk_valid_sizes=chunk_valid_sizes, + cu_seqlens=cu_seqlens, ) torch.npu.synchronize() From 87d0a149f002241e69b8994f9823b8984e0d8409 Mon Sep 17 00:00:00 2001 From: learning-chip Date: Wed, 1 Apr 2026 21:47:01 +0000 Subject: [PATCH 10/14] unit test mirror FLA triton repo --- examples/jit_cpp/fast_inverse/README.md | 18 ++ .../run_fast_inverse_varlen_like_triton.py | 202 ++++++++++++++++++ 2 files changed, 220 insertions(+) create mode 100644 examples/jit_cpp/fast_inverse/run_fast_inverse_varlen_like_triton.py diff --git a/examples/jit_cpp/fast_inverse/README.md b/examples/jit_cpp/fast_inverse/README.md index 669aa338..41aeff5a 100644 --- a/examples/jit_cpp/fast_inverse/README.md +++ b/examples/jit_cpp/fast_inverse/README.md @@ -23,6 +23,7 @@ The implementation uses a two-phase recursive approach on Ascend cube cores: | `fast_inverse.cpp` | Thin JIT wrapper: includes the kernel and exposes `call_kernel` | | `jit_util_fast_inverse.py` | Compiles the kernel with `bisheng` and loads it via `ctypes` | | `run_fast_inverse.py` | Correctness test suite, including aligned and varlen BSND coverage | +| `run_fast_inverse_varlen_like_triton.py` | Standalone varlen runner that mirrors the Triton `test_solve_tril_varlen` input generation in pure PyTorch | | `benchmark_bsnd_fast_inverse.py` | Benchmarks fixed BSND vs varlen-uniform BSND and plots effective bandwidth | ### Usage @@ -38,6 +39,23 @@ The script compiles `fast_inverse.cpp` on first run (takes ~60 s), then executes correctness checks across a range of matrix sizes (16, 32, 64, 128) and batch configurations. +To run the standalone Triton-like varlen coverage: + +```bash +export PTO_LIB_PATH=/sources/pto-isa/ + +cd examples/jit_cpp/fast_inverse +python run_fast_inverse_varlen_like_triton.py +``` + +That script: + +- uses the same varlen case list and input-generation structure as + `flash-linear-attention/tests/ops/test_solve_tril.py::test_solve_tril_varlen` +- keeps PTO inputs in `float16` +- emulates `chunk_scaled_dot_kkt_fwd` in PyTorch because Triton is not available +- prints a simple pytest-like `PASS` / `FAIL` report plus a final summary + ### Supported matrix sizes `matrix_size` (last dimension of the input tensor) must be one of: **16, 32, diff --git a/examples/jit_cpp/fast_inverse/run_fast_inverse_varlen_like_triton.py b/examples/jit_cpp/fast_inverse/run_fast_inverse_varlen_like_triton.py new file mode 100644 index 00000000..5ef553b5 --- /dev/null +++ b/examples/jit_cpp/fast_inverse/run_fast_inverse_varlen_like_triton.py @@ -0,0 +1,202 @@ +""" +Standalone varlen BSND correctness runner that mirrors the Triton unit tests: +https://github.com/fla-org/flash-linear-attention/blob/v0.4.2/tests/ops/test_solve_tril.py + +But changes: +1. uses fp16 inputs because the PTO kernel currently supports fp16 only +2. emulates `chunk_scaled_dot_kkt_fwd` in PyTorch because Triton is unavailable + +Run from the fast_inverse/ directory: + + export PTO_LIB_PATH=/sources/pto-isa + python run_fast_inverse_varlen_like_triton.py +""" + +from __future__ import annotations + +import os + +import numpy as np +import torch +import torch.nn.functional as F +import torch_npu # noqa: F401 + +from jit_util_fast_inverse import jit_compile + + +torch.manual_seed(42) +np.random.seed(42) + + +def _make_minus_identity(matrix_size: int, device: torch.device) -> torch.Tensor: + minus_identity = torch.zeros( + (matrix_size, matrix_size), + dtype=torch.float16, + device=device, + ) + minus_identity.fill_diagonal_(-1) + return minus_identity + + +def _count_varlen_chunks(cu_seqlens: torch.Tensor, chunk_size: int) -> int: + return sum( + (int(eos) - int(bos) + chunk_size - 1) // chunk_size + for bos, eos in zip(cu_seqlens[:-1].tolist(), cu_seqlens[1:].tolist(), strict=False) + ) + + +def _chunk_scaled_dot_kkt_fwd_emulated( + k: torch.Tensor, + beta: torch.Tensor, + cu_seqlens: torch.Tensor, + chunk_size: int, +) -> torch.Tensor: + t_total = int(cu_seqlens[-1].item()) + num_heads = k.shape[2] + A = torch.zeros((1, t_total, num_heads, chunk_size), dtype=k.dtype, device=k.device) + + for bos, eos in zip(cu_seqlens[:-1].tolist(), cu_seqlens[1:].tolist(), strict=False): + for chunk_start in range(bos, eos, chunk_size): + chunk_end = min(chunk_start + chunk_size, eos) + actual_size = chunk_end - chunk_start + k_chunk = k[:, chunk_start:chunk_end].transpose(1, 2).to(torch.float32) + beta_chunk = ( + beta[:, chunk_start:chunk_end] + .transpose(1, 2) + .unsqueeze(-1) + .to(torch.float32) + ) + scores = torch.matmul(k_chunk, k_chunk.transpose(-1, -2)) + scores = torch.tril(scores * beta_chunk, diagonal=-1).to(k.dtype) + A[:, chunk_start:chunk_end, :, :actual_size] = scores.transpose(1, 2) + + return A + + +def _reference_inverse(A: torch.Tensor, cu_seqlens: torch.Tensor, chunk_size: int) -> torch.Tensor: + A_cpu = A.cpu().to(torch.float64) + ref = torch.zeros_like(A_cpu, dtype=torch.float64) + for bos, eos in zip(cu_seqlens[:-1].tolist(), cu_seqlens[1:].tolist(), strict=False): + for chunk_start in range(bos, eos, chunk_size): + actual_size = min(chunk_size, eos - chunk_start) + ref[:, chunk_start : chunk_start + actual_size, :, :actual_size] = torch.inverse( + A_cpu[:, chunk_start : chunk_start + actual_size, :, :actual_size].transpose(1, 2) + + torch.eye(actual_size, dtype=torch.float64)[None, None, ...] + ).transpose(1, 2) + return ref + + +def _transpose_valid_chunks( + A: torch.Tensor, + cu_seqlens: torch.Tensor, + chunk_size: int, +) -> torch.Tensor: + transposed = torch.zeros_like(A) + for bos, eos in zip(cu_seqlens[:-1].tolist(), cu_seqlens[1:].tolist(), strict=False): + for chunk_start in range(bos, eos, chunk_size): + actual_size = min(chunk_size, eos - chunk_start) + chunk = A[:, chunk_start : chunk_start + actual_size, :, :actual_size] + transposed[:, chunk_start : chunk_start + actual_size, :, :actual_size] = chunk.transpose( + 1, 3 + ) + return transposed + + +def _run_pto_varlen(tri_inv_func, A: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor: + chunk_size = A.shape[-1] + num_heads = A.shape[-2] + num_matrices = _count_varlen_chunks(cu_seqlens, chunk_size) * num_heads + tensor_out = torch.zeros_like(A, dtype=torch.float32) + minus_identity = _make_minus_identity(chunk_size, A.device) + + torch.npu.synchronize() + tri_inv_func( + tensor_out, + A, + minus_identity, + chunk_size, + num_matrices, + num_heads, + cu_seqlens=cu_seqlens, + ) + torch.npu.synchronize() + return tensor_out.cpu().to(torch.float64) + + +def _run_case( + tri_inv_func, + H: int, + D: int, + chunk_size: int, + cu_seqlens_list: list[int], + atol: float = 5e-4, + rtol: float = 5e-2, + ftol: float = 1e-4, +) -> None: + device = torch.device("npu:0") + T = cu_seqlens_list[-1] + cu_seqlens = torch.tensor(cu_seqlens_list, dtype=torch.int32, device=device) + + # Match the Triton varlen test structure, using fp16 instead of bf16. + k = F.normalize(torch.randn((1, T, H, D), dtype=torch.float16, device=device), dim=-1) + beta = torch.randn((1, T, H), dtype=torch.float16, device=device).sigmoid() + A = _chunk_scaled_dot_kkt_fwd_emulated( + k=k, + beta=beta, + cu_seqlens=cu_seqlens, + chunk_size=chunk_size, + ) + + ref = _reference_inverse(A, cu_seqlens, chunk_size) + tri = _run_pto_varlen( + tri_inv_func, + _transpose_valid_chunks(A, cu_seqlens, chunk_size), + cu_seqlens, + ) + tri = _transpose_valid_chunks(tri, cu_seqlens, chunk_size) + + frob = torch.sqrt(torch.sum((ref - tri) ** 2) / torch.sum(ref ** 2)).item() + torch.testing.assert_close(tri, ref, atol=atol, rtol=rtol) + assert frob <= ftol, f"Frobenius error {frob:.2e} > {ftol:.2e}" + + +def main() -> int: + if "PTO_LIB_PATH" not in os.environ: + fallback = "/sources/pto-isa" + if os.path.exists(fallback): + os.environ["PTO_LIB_PATH"] = fallback + + torch.npu.set_device("npu:0") + + src = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fast_inverse.cpp") + print(f"Compiling {src} ...") + tri_inv_func = jit_compile(src) + print("Compilation successful.\n") + + cases = [ + (4, 64, 16, [0, 15]), + (4, 64, 32, [0, 256, 500, 1000]), + (4, 100, 64, [0, 15, 100, 300, 1200, 2000]), + (4, 64, 16, [0, 1, 100, 300, 1200, 2048]), + (4, 128, 32, [0, 200, 512, 1200, 2048]), + ] + + total = 0 + passed = 0 + print("=== Varlen Like Triton ===") + for H, D, chunk_size, cu_seqlens in cases: + total += 1 + label = f"H={H} D={D} chunk_size={chunk_size} cu_seqlens={cu_seqlens}" + try: + _run_case(tri_inv_func, H, D, chunk_size, cu_seqlens) + print(f" PASS {label}") + passed += 1 + except Exception as err: + print(f" FAIL {label}: {err}") + + print(f"\n{passed}/{total} cases passed.") + return 0 if passed == total else 1 + + +if __name__ == "__main__": + raise SystemExit(main()) From bd5401739d3b925aa8cfb4cf990a2963ca0d602f Mon Sep 17 00:00:00 2001 From: learning-chip Date: Wed, 1 Apr 2026 22:06:35 +0000 Subject: [PATCH 11/14] also change benchmark script to use triton-like input preparation --- .../benchmark_bsnd_fast_inverse.py | 173 ++++++++++-------- 1 file changed, 101 insertions(+), 72 deletions(-) diff --git a/examples/jit_cpp/fast_inverse/benchmark_bsnd_fast_inverse.py b/examples/jit_cpp/fast_inverse/benchmark_bsnd_fast_inverse.py index bf79cd3b..de1aa956 100644 --- a/examples/jit_cpp/fast_inverse/benchmark_bsnd_fast_inverse.py +++ b/examples/jit_cpp/fast_inverse/benchmark_bsnd_fast_inverse.py @@ -10,7 +10,8 @@ """ Benchmark the standalone BSND fast-inverse kernel. -This script only benchmarks the PTO-ISA BSND kernel in two modes: +This script benchmarks the PTO-ISA BSND kernel in two modes using Triton-unit- +test-like inputs: 1. `bsnd-fixed`: Original aligned BSND layout with shape `(B, T, H, D)`. @@ -18,9 +19,12 @@ The new varlen path using packed shape `(1, B*T, H, D)` with uniform `cu_seqlens = [0, T, 2T, ...]`. -The two modes use the same total token count and the same underlying chunk data, -so their latency / effective bandwidth can be compared directly. The script also -checks that both modes produce numerically matching results. +The two modes use the same total token count and the same underlying `k` / `beta` +inputs. `A` is generated in eager PyTorch with an emulation of +`chunk_scaled_dot_kkt_fwd`, then each valid chunk is transposed before launch so +the PTO kernel still sees its expected upper-triangular layout. The script also +checks that both modes produce numerically matching results after transposing +outputs back to the lower-triangular convention used by the Triton tests. """ from __future__ import annotations @@ -34,6 +38,7 @@ import matplotlib.pyplot as plt import numpy as np import torch +import torch.nn.functional as F import torch_npu # noqa: F401 from jit_util_fast_inverse import jit_compile @@ -41,6 +46,7 @@ DEFAULT_SEQLENS = (512, 1024, 2048, 4096, 8192, 16384) DEFAULT_CACHE_SIZE = 256 * 1024 * 1024 +DEFAULT_FEATURE_DIM = 64 NPU_DEVICE = os.getenv("GDN_TRI_INVERSE_NPU_DEVICE", "npu:0") THIS_DIR = Path(__file__).resolve().parent RESULTS_DIR = THIS_DIR / "benchmark_results" @@ -79,35 +85,77 @@ def count_varlen_chunks( ) -def random_chunk_mats( - total_chunks: int, - num_heads: int, +def chunk_scaled_dot_kkt_fwd_emulated( + k: torch.Tensor, + beta: torch.Tensor, + cu_seqlens: torch.Tensor, chunk_size: int, - scale: float, - device: str, ) -> torch.Tensor: - return scale * torch.triu( - torch.rand( - (total_chunks, num_heads, chunk_size, chunk_size), - dtype=torch.half, - device=device, - ), - diagonal=1, - ) + total_tokens = int(cu_seqlens[-1].item()) + num_heads = k.shape[2] + A = torch.zeros((1, total_tokens, num_heads, chunk_size), dtype=k.dtype, device=k.device) + + for bos, eos in zip(cu_seqlens[:-1].tolist(), cu_seqlens[1:].tolist(), strict=False): + for chunk_start in range(bos, eos, chunk_size): + chunk_end = min(chunk_start + chunk_size, eos) + actual_size = chunk_end - chunk_start + k_chunk = k[:, chunk_start:chunk_end].transpose(1, 2).to(torch.float32) + beta_chunk = ( + beta[:, chunk_start:chunk_end] + .transpose(1, 2) + .unsqueeze(-1) + .to(torch.float32) + ) + scores = torch.matmul(k_chunk, k_chunk.transpose(-1, -2)) + scores = torch.tril(scores * beta_chunk, diagonal=-1).to(k.dtype) + A[:, chunk_start:chunk_end, :, :actual_size] = scores.transpose(1, 2) + + return A + + +def transpose_valid_chunks( + A: torch.Tensor, + cu_seqlens: torch.Tensor, + chunk_size: int, +) -> torch.Tensor: + transposed = torch.zeros_like(A) + for bos, eos in zip(cu_seqlens[:-1].tolist(), cu_seqlens[1:].tolist(), strict=False): + for chunk_start in range(bos, eos, chunk_size): + actual_size = min(chunk_size, eos - chunk_start) + chunk = A[:, chunk_start : chunk_start + actual_size, :, :actual_size] + transposed[:, chunk_start : chunk_start + actual_size, :, :actual_size] = chunk.transpose( + 1, 3 + ) + return transposed def build_fixed_bsnd_input( - chunk_mats: torch.Tensor, batch_size: int, seqlen: int, num_heads: int, chunk_size: int, -) -> torch.Tensor: - return ( - chunk_mats.transpose(1, 2) - .contiguous() - .reshape(batch_size, seqlen, num_heads, chunk_size) + feature_dim: int, + device: str, +) -> tuple[torch.Tensor, torch.Tensor]: + total_tokens = batch_size * seqlen + cu_seqlens = torch.arange( + 0, + total_tokens + 1, + seqlen, + dtype=torch.int32, + device=device, + ) + k = F.normalize( + torch.randn((1, total_tokens, num_heads, feature_dim), dtype=torch.float16, device=device), + dim=-1, + ) + beta = torch.randn((1, total_tokens, num_heads), dtype=torch.float16, device=device).sigmoid() + A = transpose_valid_chunks( + chunk_scaled_dot_kkt_fwd_emulated(k, beta, cu_seqlens, chunk_size), + cu_seqlens, + chunk_size, ) + return A.reshape(batch_size, seqlen, num_heads, chunk_size).contiguous(), cu_seqlens def build_uniform_varlen_input( @@ -152,45 +200,23 @@ def build_true_varlen_input( seq_lens: list[int], num_heads: int, chunk_size: int, - scale: float, + feature_dim: int, device: str, ) -> tuple[torch.Tensor, torch.Tensor]: cu_seqlens = np.cumsum([0, *seq_lens], dtype=np.int64) - num_chunks = sum((seq_len + chunk_size - 1) // chunk_size for seq_len in seq_lens) - chunk_mats = random_chunk_mats( - total_chunks=num_chunks, - num_heads=num_heads, - chunk_size=chunk_size, - scale=scale, - device=device, - ) - - packed_input = torch.zeros( - (1, int(cu_seqlens[-1]), num_heads, chunk_size), - dtype=torch.half, - device=device, + cu_seqlens_tensor = torch.tensor(cu_seqlens.tolist(), dtype=torch.int32, device=device) + total_tokens = int(cu_seqlens[-1]) + k = F.normalize( + torch.randn((1, total_tokens, num_heads, feature_dim), dtype=torch.float16, device=device), + dim=-1, ) - chunk_idx = 0 - token_row = 0 - - for seq_len in seq_lens: - for local_chunk_start in range(0, seq_len, chunk_size): - actual_size = min(chunk_size, seq_len - local_chunk_start) - chunk = chunk_mats[chunk_idx] - for head_idx in range(num_heads): - packed_input[ - 0, - token_row : token_row + actual_size, - head_idx, - :actual_size, - ] = chunk[head_idx, :actual_size, :actual_size] - token_row += actual_size - chunk_idx += 1 - - return ( - packed_input.contiguous(), - torch.tensor(cu_seqlens.tolist(), dtype=torch.int32, device=device), + beta = torch.randn((1, total_tokens, num_heads), dtype=torch.float16, device=device).sigmoid() + packed_input = transpose_valid_chunks( + chunk_scaled_dot_kkt_fwd_emulated(k, beta, cu_seqlens_tensor, chunk_size), + cu_seqlens_tensor, + chunk_size, ) + return packed_input.contiguous(), cu_seqlens_tensor def make_fixed_runner( @@ -382,6 +408,12 @@ def main() -> None: parser.add_argument("--B", type=int, default=32, help="Dense BSND batch size.") parser.add_argument("--H", type=int, default=4, help="Number of BSND heads.") parser.add_argument("--chunk-size", type=int, default=64) + parser.add_argument( + "--feature-dim", + type=int, + default=DEFAULT_FEATURE_DIM, + help="Feature dimension used to generate Triton-like `k` inputs.", + ) parser.add_argument( "--seqlens", type=parse_int_list, @@ -392,7 +424,6 @@ def main() -> None: f"(default: {','.join(map(str, DEFAULT_SEQLENS))})" ), ) - parser.add_argument("--scale", type=float, default=0.1) parser.add_argument( "--csv", type=str, @@ -466,26 +497,19 @@ def main() -> None: ) continue - total_chunks = args.B * seqlen // args.chunk_size total_tokens = args.B * seqlen print( f"Profiling T={seqlen}, total_tokens={total_tokens}, " - f"B={args.B}, H={args.H}, chunk_size={args.chunk_size}" + f"B={args.B}, H={args.H}, chunk_size={args.chunk_size}, feature_dim={args.feature_dim}" ) - chunk_mats = random_chunk_mats( - total_chunks=total_chunks, - num_heads=args.H, - chunk_size=args.chunk_size, - scale=args.scale, - device=NPU_DEVICE, - ) - fixed_input = build_fixed_bsnd_input( - chunk_mats, + fixed_input, uniform_cu_seqlens = build_fixed_bsnd_input( batch_size=args.B, seqlen=seqlen, num_heads=args.H, chunk_size=args.chunk_size, + feature_dim=args.feature_dim, + device=NPU_DEVICE, ) varlen_input, cu_seqlens = build_uniform_varlen_input( fixed_input, @@ -507,8 +531,13 @@ def main() -> None: varlen_run() torch.npu.synchronize() - packed_fixed_out = fixed_out.reshape(1, total_tokens, args.H, args.chunk_size) - max_abs_diff, rel_frob_diff = accuracy_metrics(packed_fixed_out, varlen_out) + packed_fixed_out = transpose_valid_chunks( + fixed_out.reshape(1, total_tokens, args.H, args.chunk_size), + uniform_cu_seqlens, + args.chunk_size, + ) + packed_varlen_out = transpose_valid_chunks(varlen_out, cu_seqlens, args.chunk_size) + max_abs_diff, rel_frob_diff = accuracy_metrics(packed_fixed_out, packed_varlen_out) print( f" accuracy vs fixed: max_abs_diff={max_abs_diff:.3e}, " f"rel_frob_diff={rel_frob_diff:.3e}" @@ -577,7 +606,7 @@ def main() -> None: seq_lens=seq_lens, num_heads=args.H, chunk_size=args.chunk_size, - scale=args.scale, + feature_dim=args.feature_dim, device=NPU_DEVICE, ) varlen_run_true, _ = make_varlen_runner( From de06b4a0e43ccd1cd4921e3c44f984049e19f711 Mon Sep 17 00:00:00 2001 From: learning-chip Date: Wed, 1 Apr 2026 22:29:02 +0000 Subject: [PATCH 12/14] compare host vs device-side chunk metadata computation --- .../benchmark_bsnd_fast_inverse.py | 208 ++++++++++++++++-- .../jit_cpp/fast_inverse/fast_inverse.cpp | 14 +- .../fast_inverse/host_chunk_metadata.cpp | 42 ++++ .../fast_inverse/host_metadata_util.py | 102 +++++++++ .../fast_inverse/jit_util_fast_inverse.py | 22 ++ .../kernel_tri_inv_rec_unroll.cpp | 67 ++++-- .../jit_cpp/fast_inverse/metadata_overhead.md | 122 ++++++++++ 7 files changed, 537 insertions(+), 40 deletions(-) create mode 100644 examples/jit_cpp/fast_inverse/host_chunk_metadata.cpp create mode 100644 examples/jit_cpp/fast_inverse/host_metadata_util.py create mode 100644 examples/jit_cpp/fast_inverse/metadata_overhead.md diff --git a/examples/jit_cpp/fast_inverse/benchmark_bsnd_fast_inverse.py b/examples/jit_cpp/fast_inverse/benchmark_bsnd_fast_inverse.py index de1aa956..488076fb 100644 --- a/examples/jit_cpp/fast_inverse/benchmark_bsnd_fast_inverse.py +++ b/examples/jit_cpp/fast_inverse/benchmark_bsnd_fast_inverse.py @@ -33,6 +33,7 @@ import csv import math import os +import time from pathlib import Path import matplotlib.pyplot as plt @@ -41,6 +42,7 @@ import torch.nn.functional as F import torch_npu # noqa: F401 +from host_metadata_util import build_varlen_chunk_metadata_cpp from jit_util_fast_inverse import jit_compile @@ -267,6 +269,33 @@ def run(): return run, tensor_out +def make_varlen_runner_host_metadata( + tri_inv_func, + tensor_in: torch.Tensor, + chunk_indices: torch.Tensor, + chunk_valid_sizes: torch.Tensor, +) -> tuple[callable, torch.Tensor]: + matrix_size = tensor_in.shape[-1] + num_bsnd_heads = tensor_in.shape[-2] + num_matrices = int(chunk_indices.numel()) * num_bsnd_heads + tensor_out = torch.empty_like(tensor_in, dtype=torch.float32) + minus_identity = make_minus_identity(matrix_size, str(tensor_in.device)) + + def run(): + tri_inv_func( + tensor_out, + tensor_in, + minus_identity, + matrix_size, + num_matrices, + num_bsnd_heads, + chunk_indices=chunk_indices, + chunk_valid_sizes=chunk_valid_sizes, + ) + + return run, tensor_out + + def benchmark_ms( fn, warmup_iters: int, @@ -294,6 +323,39 @@ def benchmark_ms( return times_ms +def build_host_metadata_on_npu( + cu_seqlens: torch.Tensor, + chunk_size: int, + device: str, +) -> tuple[torch.Tensor, torch.Tensor]: + chunk_indices_cpu, chunk_valid_sizes_cpu = build_varlen_chunk_metadata_cpp( + cu_seqlens, + chunk_size, + ) + return ( + chunk_indices_cpu.to(device=device).contiguous(), + chunk_valid_sizes_cpu.to(device=device).contiguous(), + ) + + +def benchmark_host_metadata_prep_ms( + cu_seqlens: torch.Tensor, + chunk_size: int, + benchmark_iters: int, + device: str, +) -> list[float]: + times_ms: list[float] = [] + cache = torch.ones(DEFAULT_CACHE_SIZE, dtype=torch.int8, device=device) + for _ in range(benchmark_iters): + cache.zero_() + torch.npu.synchronize() + start = time.perf_counter() + build_host_metadata_on_npu(cu_seqlens, chunk_size, device) + torch.npu.synchronize() + times_ms.append((time.perf_counter() - start) * 1000.0) + return times_ms + + def add_bandwidth_fields(row: dict[str, float | int | str], input_dtype_bytes: int = 2) -> None: size_elems = int(row.get("valid_numel", row["numel"])) mem_bytes = size_elems * (input_dtype_bytes + 4) @@ -315,6 +377,7 @@ def write_csv(csv_path: Path, rows: list[dict[str, float | int | str]]) -> None: csv_path.parent.mkdir(parents=True, exist_ok=True) fieldnames = [ "inverse_type", + "metadata_strategy", "dtype", "B", "T", @@ -325,6 +388,8 @@ def write_csv(csv_path: Path, rows: list[dict[str, float | int | str]]) -> None: "valid_numel", "chunk_size", "time_us", + "kernel_time_us", + "metadata_time_us", "mem_bytes", "bw_gbs", "max_abs_diff_to_fixed", @@ -342,7 +407,18 @@ def write_csv(csv_path: Path, rows: list[dict[str, float | int | str]]) -> None: def plot_bandwidth(plot_path: Path, rows: list[dict[str, float | int | str]], batch_size: int, num_heads: int, chunk_size: int) -> None: plot_path.parent.mkdir(parents=True, exist_ok=True) fixed_rows = [row for row in rows if row["inverse_type"] == "bsnd-fixed"] - varlen_rows = [row for row in rows if row["inverse_type"] == "bsnd-varlen-uniform"] + varlen_device_rows = [ + row + for row in rows + if row["inverse_type"] == "bsnd-varlen-uniform" + and row["metadata_strategy"] == "device-cu_seqlens" + ] + varlen_host_rows = [ + row + for row in rows + if row["inverse_type"] == "bsnd-varlen-uniform" + and row["metadata_strategy"] == "host-cpp" + ] fig, ax = plt.subplots(figsize=(7.5, 5.0)) ax.plot( @@ -353,11 +429,18 @@ def plot_bandwidth(plot_path: Path, rows: list[dict[str, float | int | str]], ba label="BSND fixed", ) ax.plot( - [int(row["T"]) / 1000.0 for row in varlen_rows], - [float(row["bw_gbs"]) for row in varlen_rows], + [int(row["T"]) / 1000.0 for row in varlen_device_rows], + [float(row["bw_gbs"]) for row in varlen_device_rows], marker="s", linewidth=2, - label="BSND varlen-uniform", + label="BSND varlen device metadata", + ) + ax.plot( + [int(row["T"]) / 1000.0 for row in varlen_host_rows], + [float(row["bw_gbs"]) for row in varlen_host_rows], + marker="^", + linewidth=2, + label="BSND varlen host metadata", ) ax.set_xlabel("Sequence length T (K)") ax.set_ylabel("Effective bandwidth (GB/s)") @@ -517,18 +600,31 @@ def main() -> None: seqlen=seqlen, chunk_size=args.chunk_size, ) + cu_seqlens_cpu = cu_seqlens.cpu() print(f" uniform cu_seqlens: {cu_seqlens.cpu().tolist()}") fixed_run, fixed_out = make_fixed_runner(tri_inv_func, fixed_input) - varlen_run, varlen_out = make_varlen_runner( + varlen_run_device, varlen_out_device = make_varlen_runner( tri_inv_func, varlen_input, cu_seqlens, ) + chunk_indices, chunk_valid_sizes = build_host_metadata_on_npu( + cu_seqlens_cpu, + args.chunk_size, + NPU_DEVICE, + ) + varlen_run_host, varlen_out_host = make_varlen_runner_host_metadata( + tri_inv_func, + varlen_input, + chunk_indices, + chunk_valid_sizes, + ) fixed_run() - varlen_run() + varlen_run_device() + varlen_run_host() torch.npu.synchronize() packed_fixed_out = transpose_valid_chunks( @@ -536,11 +632,29 @@ def main() -> None: uniform_cu_seqlens, args.chunk_size, ) - packed_varlen_out = transpose_valid_chunks(varlen_out, cu_seqlens, args.chunk_size) - max_abs_diff, rel_frob_diff = accuracy_metrics(packed_fixed_out, packed_varlen_out) + packed_varlen_out_device = transpose_valid_chunks( + varlen_out_device, + cu_seqlens, + args.chunk_size, + ) + packed_varlen_out_host = transpose_valid_chunks( + varlen_out_host, + cu_seqlens, + args.chunk_size, + ) + max_abs_diff_device, rel_frob_diff_device = accuracy_metrics( + packed_fixed_out, + packed_varlen_out_device, + ) + max_abs_diff_host, rel_frob_diff_host = accuracy_metrics( + packed_fixed_out, + packed_varlen_out_host, + ) print( - f" accuracy vs fixed: max_abs_diff={max_abs_diff:.3e}, " - f"rel_frob_diff={rel_frob_diff:.3e}" + f" accuracy vs fixed: device max_abs_diff={max_abs_diff_device:.3e}, " + f"device rel_frob_diff={rel_frob_diff_device:.3e}, " + f"host max_abs_diff={max_abs_diff_host:.3e}, " + f"host rel_frob_diff={rel_frob_diff_host:.3e}" ) fixed_times_ms = benchmark_ms( @@ -549,8 +663,20 @@ def main() -> None: benchmark_iters=args.repeats, device=NPU_DEVICE, ) - varlen_times_ms = benchmark_ms( - varlen_run, + varlen_device_times_ms = benchmark_ms( + varlen_run_device, + warmup_iters=args.warmup, + benchmark_iters=args.repeats, + device=NPU_DEVICE, + ) + host_metadata_times_ms = benchmark_host_metadata_prep_ms( + cu_seqlens_cpu, + args.chunk_size, + benchmark_iters=args.repeats, + device=NPU_DEVICE, + ) + varlen_host_kernel_times_ms = benchmark_ms( + varlen_run_host, warmup_iters=args.warmup, benchmark_iters=args.repeats, device=NPU_DEVICE, @@ -558,6 +684,7 @@ def main() -> None: fixed_row = { "inverse_type": "bsnd-fixed", + "metadata_strategy": "none", "dtype": "fp16", "B": args.B, "T": seqlen, @@ -568,6 +695,8 @@ def main() -> None: "valid_numel": fixed_input.numel(), "chunk_size": args.chunk_size, "time_us": int(round(np.mean(fixed_times_ms) * 1000.0)), + "kernel_time_us": int(round(np.mean(fixed_times_ms) * 1000.0)), + "metadata_time_us": 0, "max_abs_diff_to_fixed": 0.0, "rel_frob_diff_to_fixed": 0.0, "sample_id": "", @@ -575,8 +704,9 @@ def main() -> None: } add_bandwidth_fields(fixed_row) - varlen_row = { + varlen_device_row = { "inverse_type": "bsnd-varlen-uniform", + "metadata_strategy": "device-cu_seqlens", "dtype": "fp16", "B": args.B, "T": seqlen, @@ -586,18 +716,55 @@ def main() -> None: "numel": varlen_input.numel(), "valid_numel": total_tokens * args.H * args.chunk_size, "chunk_size": args.chunk_size, - "time_us": int(round(np.mean(varlen_times_ms) * 1000.0)), - "max_abs_diff_to_fixed": max_abs_diff, - "rel_frob_diff_to_fixed": rel_frob_diff, + "time_us": int(round(np.mean(varlen_device_times_ms) * 1000.0)), + "kernel_time_us": int(round(np.mean(varlen_device_times_ms) * 1000.0)), + "metadata_time_us": 0, + "max_abs_diff_to_fixed": max_abs_diff_device, + "rel_frob_diff_to_fixed": rel_frob_diff_device, "sample_id": "", "seq_lens": ",".join([str(seqlen)] * args.B), } - add_bandwidth_fields(varlen_row) + add_bandwidth_fields(varlen_device_row) - rows.extend([fixed_row, varlen_row]) + avg_host_metadata_us = int(round(np.mean(host_metadata_times_ms) * 1000.0)) + avg_host_kernel_us = int(round(np.mean(varlen_host_kernel_times_ms) * 1000.0)) + varlen_host_row = { + "inverse_type": "bsnd-varlen-uniform", + "metadata_strategy": "host-cpp", + "dtype": "fp16", + "B": args.B, + "T": seqlen, + "aggregated_T": total_tokens, + "padded_T": total_tokens, + "H": args.H, + "numel": varlen_input.numel(), + "valid_numel": total_tokens * args.H * args.chunk_size, + "chunk_size": args.chunk_size, + "time_us": avg_host_metadata_us + avg_host_kernel_us, + "kernel_time_us": avg_host_kernel_us, + "metadata_time_us": avg_host_metadata_us, + "max_abs_diff_to_fixed": max_abs_diff_host, + "rel_frob_diff_to_fixed": rel_frob_diff_host, + "sample_id": "", + "seq_lens": ",".join([str(seqlen)] * args.B), + } + add_bandwidth_fields(varlen_host_row) + + rows.extend([fixed_row, varlen_device_row, varlen_host_row]) print( f" fixed: time_us={fixed_row['time_us']}, bw_gbs={fixed_row['bw_gbs']:.2f} | " - f"varlen-uniform: time_us={varlen_row['time_us']}, bw_gbs={varlen_row['bw_gbs']:.2f}" + f"varlen-device: time_us={varlen_device_row['time_us']}, " + f"bw_gbs={varlen_device_row['bw_gbs']:.2f} | " + f"varlen-host: time_us={varlen_host_row['time_us']} " + f"(meta={varlen_host_row['metadata_time_us']}, kernel={varlen_host_row['kernel_time_us']}), " + f"bw_gbs={varlen_host_row['bw_gbs']:.2f}" + ) + device_metadata_overhead_us = ( + varlen_device_row["kernel_time_us"] - varlen_host_row["kernel_time_us"] + ) + print( + f" metadata overhead comparison: device_in_kernel_delta_us={device_metadata_overhead_us}, " + f"host_cpp_metadata_us={varlen_host_row['metadata_time_us']}" ) for sample_idx in range(args.true_varlen_samples): @@ -622,6 +789,7 @@ def main() -> None: ) row = { "inverse_type": "bsnd-varlen-true", + "metadata_strategy": "device-cu_seqlens", "dtype": "fp16", "B": args.B, "T": seqlen, @@ -632,6 +800,8 @@ def main() -> None: "valid_numel": total_tokens * args.H * args.chunk_size, "chunk_size": args.chunk_size, "time_us": int(round(np.mean(times_ms) * 1000.0)), + "kernel_time_us": int(round(np.mean(times_ms) * 1000.0)), + "metadata_time_us": 0, "max_abs_diff_to_fixed": "", "rel_frob_diff_to_fixed": "", "sample_id": sample_idx, diff --git a/examples/jit_cpp/fast_inverse/fast_inverse.cpp b/examples/jit_cpp/fast_inverse/fast_inverse.cpp index ac77e8ba..8155351e 100644 --- a/examples/jit_cpp/fast_inverse/fast_inverse.cpp +++ b/examples/jit_cpp/fast_inverse/fast_inverse.cpp @@ -26,15 +26,19 @@ for the full License text. * @param num_matrices Total number of matrices to invert. * @param num_bsnd_heads 0 for standard (B…ND) layout; * N (number of heads) for BSND layout. - * @param cu_seqlens Optional int32 pointer used only for varlen BSND. Matches - * the Triton-style API and stores cumulative sequence - * boundaries for the packed BSND tensor. + * @param cu_seqlens Optional int32 pointer used only for varlen BSND when the + * device kernel derives chunk metadata itself. + * @param chunk_indices Optional int32 pointer containing per-chunk row starts + * for the host-precomputed varlen path. + * @param chunk_valid_sizes Optional int32 pointer containing each chunk's + * runtime size for the host-precomputed varlen path. */ extern "C" void call_kernel(uint32_t blockDim, void* stream, void* tensor_out, void* tensor_in, void* minus_identity_in, uint32_t matrix_size, uint32_t num_matrices, - uint32_t num_bsnd_heads, void* cu_seqlens) { + uint32_t num_bsnd_heads, void* cu_seqlens, + void* chunk_indices, void* chunk_valid_sizes) { tri_inv_rec_unroll_fp16<<>>( tensor_out, tensor_in, minus_identity_in, matrix_size, num_matrices, - num_bsnd_heads, cu_seqlens); + num_bsnd_heads, cu_seqlens, chunk_indices, chunk_valid_sizes); } diff --git a/examples/jit_cpp/fast_inverse/host_chunk_metadata.cpp b/examples/jit_cpp/fast_inverse/host_chunk_metadata.cpp new file mode 100644 index 00000000..11fdd6c0 --- /dev/null +++ b/examples/jit_cpp/fast_inverse/host_chunk_metadata.cpp @@ -0,0 +1,42 @@ +/** +Copyright (c) 2026 Huawei Technologies Co., Ltd. +All rights reserved. + +See LICENSE in the root of the software repository: +https://github.com/huawei-csl/pto-kernels/ +for the full License text. +*/ + +#include +#include + +extern "C" uint32_t count_varlen_chunks_host_cpp(const int32_t* cu_seqlens, + uint32_t num_sequences, + uint32_t chunk_size) { + uint32_t total_chunks = 0; + for (uint32_t seq_idx = 0; seq_idx < num_sequences; ++seq_idx) { + const uint32_t seq_start = static_cast(cu_seqlens[seq_idx]); + const uint32_t seq_end = static_cast(cu_seqlens[seq_idx + 1]); + const uint32_t seq_len = seq_end - seq_start; + total_chunks += (seq_len + chunk_size - 1) / chunk_size; + } + return total_chunks; +} + +extern "C" void build_varlen_chunk_metadata_host_cpp( + const int32_t* cu_seqlens, uint32_t num_sequences, uint32_t chunk_size, + int32_t* chunk_indices, int32_t* chunk_valid_sizes) { + uint32_t chunk_idx = 0; + for (uint32_t seq_idx = 0; seq_idx < num_sequences; ++seq_idx) { + const uint32_t seq_start = static_cast(cu_seqlens[seq_idx]); + const uint32_t seq_end = static_cast(cu_seqlens[seq_idx + 1]); + for (uint32_t row_start = seq_start; row_start < seq_end; + row_start += chunk_size) { + const uint32_t valid_size = + std::min(chunk_size, static_cast(seq_end - row_start)); + chunk_indices[chunk_idx] = static_cast(row_start); + chunk_valid_sizes[chunk_idx] = static_cast(valid_size); + ++chunk_idx; + } + } +} diff --git a/examples/jit_cpp/fast_inverse/host_metadata_util.py b/examples/jit_cpp/fast_inverse/host_metadata_util.py new file mode 100644 index 00000000..bf6b4f27 --- /dev/null +++ b/examples/jit_cpp/fast_inverse/host_metadata_util.py @@ -0,0 +1,102 @@ +# -------------------------------------------------------------------------------- +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# All rights reserved. +# See LICENSE in the root of the software repository: +# https://github.com/huawei-csl/pto-kernels/ +# for the full License text. +# -------------------------------------------------------------------------------- + +from __future__ import annotations + +import ctypes +import os +import subprocess + +import torch + + +_THIS_DIR = os.path.dirname(os.path.abspath(__file__)) +_SRC = os.path.join(_THIS_DIR, "host_chunk_metadata.cpp") +_LIB = os.path.join(_THIS_DIR, "host_chunk_metadata.so") +_HOST_LIB = None + + +def _torch_to_ctypes(tensor: torch.Tensor) -> ctypes.c_void_p: + return ctypes.c_void_p(tensor.data_ptr()) + + +def compile_host_metadata_cpp(timeout: int = 60) -> str: + compiler = os.environ.get("CXX", "g++") + command = [ + compiler, + "-O3", + "-std=c++17", + "-shared", + "-fPIC", + _SRC, + "-o", + _LIB, + ] + try: + subprocess.run(command, timeout=timeout, check=True) + except Exception as exc: + raise RuntimeError(f"Host metadata compilation failed: {exc}") from exc + return _LIB + + +def load_host_metadata_lib(): + global _HOST_LIB + if _HOST_LIB is not None: + return _HOST_LIB + + lib_path = compile_host_metadata_cpp() + lib = ctypes.CDLL(os.path.abspath(lib_path)) + lib.count_varlen_chunks_host_cpp.argtypes = [ + ctypes.c_void_p, + ctypes.c_uint32, + ctypes.c_uint32, + ] + lib.count_varlen_chunks_host_cpp.restype = ctypes.c_uint32 + lib.build_varlen_chunk_metadata_host_cpp.argtypes = [ + ctypes.c_void_p, + ctypes.c_uint32, + ctypes.c_uint32, + ctypes.c_void_p, + ctypes.c_void_p, + ] + lib.build_varlen_chunk_metadata_host_cpp.restype = None + _HOST_LIB = lib + return lib + + +def build_varlen_chunk_metadata_cpp( + cu_seqlens: torch.Tensor | list[int], + chunk_size: int, +) -> tuple[torch.Tensor, torch.Tensor]: + lib = load_host_metadata_lib() + if isinstance(cu_seqlens, torch.Tensor): + cu_seqlens_cpu = cu_seqlens.detach().to(device="cpu", dtype=torch.int32).contiguous() + else: + cu_seqlens_cpu = torch.tensor(cu_seqlens, dtype=torch.int32) + + if cu_seqlens_cpu.numel() < 2: + raise ValueError("cu_seqlens must contain at least 2 entries.") + + num_sequences = cu_seqlens_cpu.numel() - 1 + num_chunks = int( + lib.count_varlen_chunks_host_cpp( + _torch_to_ctypes(cu_seqlens_cpu), + num_sequences, + chunk_size, + ) + ) + chunk_indices = torch.empty(num_chunks, dtype=torch.int32) + chunk_valid_sizes = torch.empty(num_chunks, dtype=torch.int32) + lib.build_varlen_chunk_metadata_host_cpp( + _torch_to_ctypes(cu_seqlens_cpu), + num_sequences, + chunk_size, + _torch_to_ctypes(chunk_indices), + _torch_to_ctypes(chunk_valid_sizes), + ) + return chunk_indices, chunk_valid_sizes diff --git a/examples/jit_cpp/fast_inverse/jit_util_fast_inverse.py b/examples/jit_cpp/fast_inverse/jit_util_fast_inverse.py index 9e160141..63236f79 100644 --- a/examples/jit_cpp/fast_inverse/jit_util_fast_inverse.py +++ b/examples/jit_cpp/fast_inverse/jit_util_fast_inverse.py @@ -86,6 +86,8 @@ def load_lib(lib_path: str): ctypes.c_uint32, # num_matrices ctypes.c_uint32, # num_bsnd_heads ctypes.c_void_p, # cu_seqlens (optional int32 metadata) + ctypes.c_void_p, # chunk_indices (optional int32 metadata) + ctypes.c_void_p, # chunk_valid_sizes (optional int32 metadata) ] lib.call_kernel.restype = None @@ -97,6 +99,8 @@ def tri_inv_func( num_matrices: int, num_bsnd_heads: int = 0, cu_seqlens: torch.Tensor | None = None, + chunk_indices: torch.Tensor | None = None, + chunk_valid_sizes: torch.Tensor | None = None, block_dim: int = BLOCK_DIM, stream_ptr=None, ): @@ -107,6 +111,18 @@ def tri_inv_func( raise TypeError("cu_seqlens must be int32.") if not cu_seqlens.is_contiguous(): raise ValueError("cu_seqlens must be contiguous.") + if chunk_indices is not None: + if chunk_indices.dtype != torch.int32: + raise TypeError("chunk_indices must be int32.") + if not chunk_indices.is_contiguous(): + raise ValueError("chunk_indices must be contiguous.") + if chunk_valid_sizes is not None: + if chunk_valid_sizes.dtype != torch.int32: + raise TypeError("chunk_valid_sizes must be int32.") + if not chunk_valid_sizes.is_contiguous(): + raise ValueError("chunk_valid_sizes must be contiguous.") + if (chunk_indices is None) != (chunk_valid_sizes is None): + raise ValueError("chunk_indices and chunk_valid_sizes must be provided together.") effective_block_dim = min(block_dim, num_matrices) lib.call_kernel( effective_block_dim, @@ -120,6 +136,12 @@ def tri_inv_func( _torch_to_ctypes(cu_seqlens) if cu_seqlens is not None else ctypes.c_void_p(), + _torch_to_ctypes(chunk_indices) + if chunk_indices is not None + else ctypes.c_void_p(), + _torch_to_ctypes(chunk_valid_sizes) + if chunk_valid_sizes is not None + else ctypes.c_void_p(), ) return tri_inv_func diff --git a/examples/jit_cpp/fast_inverse/kernel_tri_inv_rec_unroll.cpp b/examples/jit_cpp/fast_inverse/kernel_tri_inv_rec_unroll.cpp index be8f43fb..b5dafd6d 100644 --- a/examples/jit_cpp/fast_inverse/kernel_tri_inv_rec_unroll.cpp +++ b/examples/jit_cpp/fast_inverse/kernel_tri_inv_rec_unroll.cpp @@ -66,6 +66,18 @@ AICORE inline BSNDVarlenTileInfo GetBSNDVarlenTileInfoFromCuSeqlens( } } +AICORE inline BSNDVarlenTileInfo GetBSNDVarlenTileInfoFromChunkMetadata( + uint32_t tile_id, uint32_t num_bsnd_heads, uint32_t matrix_size, + __gm__ int32_t* chunk_indices, __gm__ int32_t* chunk_valid_sizes) { + const uint32_t head_idx = tile_id % num_bsnd_heads; + const uint32_t chunk_idx = tile_id / num_bsnd_heads; + const uint32_t row_start = static_cast(chunk_indices[chunk_idx]); + const uint32_t valid_size = + static_cast(chunk_valid_sizes[chunk_idx]); + return {row_start * num_bsnd_heads * matrix_size + head_idx * matrix_size, + valid_size}; +} + /* * @brief: Takes as input two matrices of size MatrixSize * MatrixSize each. * The src matrix lies in L1, while the dst matrix lies either in L0A or L0B. @@ -417,7 +429,10 @@ AICORE inline void TriInvRecUnrollKernel(__gm__ OutputT* M_inv, __gm__ InputT* M, __gm__ InputT* I_neg, uint32_t total_tiles, uint32_t num_bsnd_heads = 0, - __gm__ int32_t* cu_seqlens = nullptr) { + __gm__ int32_t* cu_seqlens = nullptr, + __gm__ int32_t* chunk_indices = nullptr, + __gm__ int32_t* chunk_valid_sizes = + nullptr) { constexpr uint32_t TileLen = MatrixSize * MatrixSize; constexpr uint32_t FractalSize = 16; constexpr uint32_t NumL0Buffers = 2; @@ -524,7 +539,14 @@ AICORE inline void TriInvRecUnrollKernel(__gm__ OutputT* M_inv, ++tile_id) { if constexpr (IsBSND) { const uint32_t global_tile_id = global_index + tile_id; - if (cu_seqlens != nullptr) { + if (chunk_indices != nullptr && chunk_valid_sizes != nullptr) { + const BSNDVarlenTileInfo tile_info = + GetBSNDVarlenTileInfoFromChunkMetadata( + global_tile_id, num_bsnd_heads, MatrixSize, chunk_indices, + chunk_valid_sizes); + bsnd_tile_offsets[tile_id] = tile_info.bsnd_offset; + bsnd_tile_valid_sizes[tile_id] = tile_info.valid_size; + } else if (cu_seqlens != nullptr) { const BSNDVarlenTileInfo tile_info = GetBSNDVarlenTileInfoFromCuSeqlens( global_tile_id, num_bsnd_heads, MatrixSize, cu_seqlens); bsnd_tile_offsets[tile_id] = tile_info.bsnd_offset; @@ -797,12 +819,16 @@ template (M_inv, M, I_neg, total_tiles, num_bsnd_heads, - cu_seqlens); + cu_seqlens, chunk_indices, + chunk_valid_sizes); #else // Nothing to do on AIV #endif @@ -814,29 +840,31 @@ AICORE void run_tri_inv_rec_unroll(__gm__ float* tensor_out, __gm__ InputT* minus_identity_in, uint32_t matrix_size, uint32_t num_matrices, uint32_t num_bsnd_heads, - __gm__ int32_t* cu_seqlens) { + __gm__ int32_t* cu_seqlens, + __gm__ int32_t* chunk_indices, + __gm__ int32_t* chunk_valid_sizes) { static_assert(std::is_same_v, "tri_inv_rec_unroll supports only fp16."); switch (matrix_size) { case 16: runKernelTriInvRecUnroll( tensor_out, tensor_in, minus_identity_in, num_matrices, num_bsnd_heads, - cu_seqlens); + cu_seqlens, chunk_indices, chunk_valid_sizes); break; case 32: runKernelTriInvRecUnroll( tensor_out, tensor_in, minus_identity_in, num_matrices, num_bsnd_heads, - cu_seqlens); + cu_seqlens, chunk_indices, chunk_valid_sizes); break; case 64: runKernelTriInvRecUnroll( tensor_out, tensor_in, minus_identity_in, num_matrices, num_bsnd_heads, - cu_seqlens); + cu_seqlens, chunk_indices, chunk_valid_sizes); break; case 128: runKernelTriInvRecUnroll( tensor_out, tensor_in, minus_identity_in, num_matrices, num_bsnd_heads, - cu_seqlens); + cu_seqlens, chunk_indices, chunk_valid_sizes); break; } } @@ -844,40 +872,47 @@ AICORE void run_tri_inv_rec_unroll(__gm__ float* tensor_out, extern "C" __global__ AICORE void tri_inv_rec_unroll_fp16( __gm__ void* tensor_out, __gm__ void* tensor_in, __gm__ void* minus_identity_in, uint32_t matrix_size, uint32_t num_matrices, - uint32_t num_bsnd_heads, __gm__ void* cu_seqlens) { + uint32_t num_bsnd_heads, __gm__ void* cu_seqlens, + __gm__ void* chunk_indices, __gm__ void* chunk_valid_sizes) { if (num_bsnd_heads == 0) { if (num_matrices <= get_block_num()) { run_tri_inv_rec_unroll( (__gm__ float*)tensor_out, (__gm__ half*)tensor_in, (__gm__ half*)minus_identity_in, matrix_size, num_matrices, - num_bsnd_heads, (__gm__ int32_t*)cu_seqlens); + num_bsnd_heads, (__gm__ int32_t*)cu_seqlens, + (__gm__ int32_t*)chunk_indices, (__gm__ int32_t*)chunk_valid_sizes); } else if (num_matrices <= 2 * get_block_num()) { run_tri_inv_rec_unroll( (__gm__ float*)tensor_out, (__gm__ half*)tensor_in, (__gm__ half*)minus_identity_in, matrix_size, num_matrices, - num_bsnd_heads, (__gm__ int32_t*)cu_seqlens); + num_bsnd_heads, (__gm__ int32_t*)cu_seqlens, + (__gm__ int32_t*)chunk_indices, (__gm__ int32_t*)chunk_valid_sizes); } else { run_tri_inv_rec_unroll( (__gm__ float*)tensor_out, (__gm__ half*)tensor_in, (__gm__ half*)minus_identity_in, matrix_size, num_matrices, - num_bsnd_heads, (__gm__ int32_t*)cu_seqlens); + num_bsnd_heads, (__gm__ int32_t*)cu_seqlens, + (__gm__ int32_t*)chunk_indices, (__gm__ int32_t*)chunk_valid_sizes); } } else { if (num_matrices <= get_block_num()) { run_tri_inv_rec_unroll( (__gm__ float*)tensor_out, (__gm__ half*)tensor_in, (__gm__ half*)minus_identity_in, matrix_size, num_matrices, - num_bsnd_heads, (__gm__ int32_t*)cu_seqlens); + num_bsnd_heads, (__gm__ int32_t*)cu_seqlens, + (__gm__ int32_t*)chunk_indices, (__gm__ int32_t*)chunk_valid_sizes); } else if (num_matrices <= 2 * get_block_num()) { run_tri_inv_rec_unroll( (__gm__ float*)tensor_out, (__gm__ half*)tensor_in, (__gm__ half*)minus_identity_in, matrix_size, num_matrices, - num_bsnd_heads, (__gm__ int32_t*)cu_seqlens); + num_bsnd_heads, (__gm__ int32_t*)cu_seqlens, + (__gm__ int32_t*)chunk_indices, (__gm__ int32_t*)chunk_valid_sizes); } else { run_tri_inv_rec_unroll( (__gm__ float*)tensor_out, (__gm__ half*)tensor_in, (__gm__ half*)minus_identity_in, matrix_size, num_matrices, - num_bsnd_heads, (__gm__ int32_t*)cu_seqlens); + num_bsnd_heads, (__gm__ int32_t*)cu_seqlens, + (__gm__ int32_t*)chunk_indices, (__gm__ int32_t*)chunk_valid_sizes); } } } diff --git a/examples/jit_cpp/fast_inverse/metadata_overhead.md b/examples/jit_cpp/fast_inverse/metadata_overhead.md new file mode 100644 index 00000000..5b939745 --- /dev/null +++ b/examples/jit_cpp/fast_inverse/metadata_overhead.md @@ -0,0 +1,122 @@ +# Metadata Overhead Comparison + +This note compares two varlen BSND metadata strategies for the fast-inverse PTO kernel. + +## Strategies + +### 1. Device-side metadata from `cu_seqlens` + +Files: +- `kernel_tri_inv_rec_unroll.cpp` +- `fast_inverse.cpp` +- `jit_util_fast_inverse.py` + +Behavior: +- Python passes only `cu_seqlens` for the varlen path. +- The NPU kernel derives each chunk's row offset and `valid_size` by scanning `cu_seqlens` inside `GetBSNDVarlenTileInfoFromCuSeqlens()`. + +Pros: +- Matches the Triton-style deployment API. +- No host-side metadata buffers to build or upload. +- Best end-to-end latency in the current measurements. + +Cons: +- Adds a small amount of device-side work per tile. + +### 2. Host-side C++ metadata precompute + +Files: +- `host_chunk_metadata.cpp` +- `host_metadata_util.py` +- `kernel_tri_inv_rec_unroll.cpp` +- `fast_inverse.cpp` +- `jit_util_fast_inverse.py` + +Behavior: +- A small host C++ helper builds `chunk_indices` and `chunk_valid_sizes` from `cu_seqlens`. +- Python uploads those buffers to NPU memory. +- The NPU kernel uses the precomputed metadata directly and skips the in-kernel `cu_seqlens` scan. + +Pros: +- Simpler varlen metadata lookup inside the kernel. +- Kernel-only time is slightly lower or roughly equal to the device-side scan path. + +Cons: +- Host metadata build plus host-to-device upload dominates the savings. +- Worse end-to-end latency in the current measurements. + +## Quick Perf Summary + +Benchmark setup: +- script: `benchmark_bsnd_fast_inverse.py` +- input style: Triton-unit-test-like `k` / `beta` generation +- config: `B=32`, `H=4`, `feature_dim=64` +- seqlens: `2048,8192` +- repeats: `10` +- warmup: `3` +- true-varlen samples: `0` + +### `chunk_size=64` + +| T | Device metadata total | Host metadata total | Host kernel only | Host metadata only | +|---|---:|---:|---:|---:| +| 2048 | 556 us | 862 us | 553 us | 309 us | +| 8192 | 2075 us | 2377 us | 2048 us | 329 us | + +Takeaway: +- Device-side metadata cost is only about `3-27 us` relative to the host-precomputed kernel-only time. +- Host-side metadata costs about `309-329 us`, so it loses badly end to end. + +### `chunk_size=128` + +| T | Device metadata total | Host metadata total | Host kernel only | Host metadata only | +|---|---:|---:|---:|---:| +| 2048 | 1088 us | 1378 us | 1089 us | 289 us | +| 8192 | 4074 us | 4372 us | 4058 us | 314 us | + +Takeaway: +- Device-side metadata overhead is effectively negligible here. +- Host-side metadata still adds about `289-314 us`, so end-to-end performance is worse. + +## Conclusion + +For the current implementation and tested shapes, the device-side `cu_seqlens` scan is the better overall strategy. + +Reason: +- The host-C++ path does reduce or nearly eliminate kernel-side metadata overhead. +- But the saved kernel time is much smaller than the cost of building and uploading host metadata. + +## How To Reproduce + +From `examples/jit_cpp/fast_inverse/`: + +```bash +export PTO_LIB_PATH=/sources/pto-isa + +python benchmark_bsnd_fast_inverse.py \ + --chunk-size 64 \ + --seqlens 2048,8192 \ + --repeats 10 \ + --warmup 3 \ + --true-varlen-samples 0 + +python benchmark_bsnd_fast_inverse.py \ + --chunk-size 128 \ + --seqlens 2048,8192 \ + --repeats 10 \ + --warmup 3 \ + --true-varlen-samples 0 +``` + +The benchmark writes: +- `benchmark_results/bench_results_bsnd_fast_inverse_64.csv` +- `benchmark_results/bench_results_bsnd_fast_inverse_128.csv` +- `benchmark_results/bench_results_bsnd_fast_inverse_bw_64.png` +- `benchmark_results/bench_results_bsnd_fast_inverse_bw_128.png` + +Relevant CSV fields: +- `metadata_strategy` +- `time_us` +- `kernel_time_us` +- `metadata_time_us` +- `bw_gbs` From aa9c09926d818487b6b8f30bb69da6ee8ade3316 Mon Sep 17 00:00:00 2001 From: learning-chip Date: Wed, 1 Apr 2026 22:35:58 +0000 Subject: [PATCH 13/14] use prefix trick to speed-up computation --- .../benchmark_bsnd_fast_inverse.py | 144 +++++++++++++++++- .../jit_cpp/fast_inverse/fast_inverse.cpp | 6 +- .../fast_inverse/host_chunk_metadata.cpp | 16 ++ .../fast_inverse/host_metadata_util.py | 31 ++++ .../fast_inverse/jit_util_fast_inverse.py | 10 ++ .../kernel_tri_inv_rec_unroll.cpp | 68 ++++++++- .../jit_cpp/fast_inverse/metadata_overhead.md | 60 +++++--- 7 files changed, 309 insertions(+), 26 deletions(-) diff --git a/examples/jit_cpp/fast_inverse/benchmark_bsnd_fast_inverse.py b/examples/jit_cpp/fast_inverse/benchmark_bsnd_fast_inverse.py index 488076fb..257101ab 100644 --- a/examples/jit_cpp/fast_inverse/benchmark_bsnd_fast_inverse.py +++ b/examples/jit_cpp/fast_inverse/benchmark_bsnd_fast_inverse.py @@ -42,7 +42,10 @@ import torch.nn.functional as F import torch_npu # noqa: F401 -from host_metadata_util import build_varlen_chunk_metadata_cpp +from host_metadata_util import ( + build_chunk_sequence_prefix_cpp, + build_varlen_chunk_metadata_cpp, +) from jit_util_fast_inverse import jit_compile @@ -296,6 +299,33 @@ def run(): return run, tensor_out +def make_varlen_runner_prefix_metadata( + tri_inv_func, + tensor_in: torch.Tensor, + cu_seqlens: torch.Tensor, + chunk_sequence_prefix: torch.Tensor, +) -> tuple[callable, torch.Tensor]: + matrix_size = tensor_in.shape[-1] + num_bsnd_heads = tensor_in.shape[-2] + num_matrices = count_varlen_chunks(cu_seqlens, matrix_size) * num_bsnd_heads + tensor_out = torch.empty_like(tensor_in, dtype=torch.float32) + minus_identity = make_minus_identity(matrix_size, str(tensor_in.device)) + + def run(): + tri_inv_func( + tensor_out, + tensor_in, + minus_identity, + matrix_size, + num_matrices, + num_bsnd_heads, + cu_seqlens=cu_seqlens, + chunk_sequence_prefix=chunk_sequence_prefix, + ) + + return run, tensor_out + + def benchmark_ms( fn, warmup_iters: int, @@ -338,6 +368,16 @@ def build_host_metadata_on_npu( ) +def build_prefix_metadata_on_npu( + cu_seqlens: torch.Tensor, + chunk_size: int, + device: str, +) -> torch.Tensor: + return build_chunk_sequence_prefix_cpp(cu_seqlens, chunk_size).to( + device=device + ).contiguous() + + def benchmark_host_metadata_prep_ms( cu_seqlens: torch.Tensor, chunk_size: int, @@ -356,6 +396,24 @@ def benchmark_host_metadata_prep_ms( return times_ms +def benchmark_prefix_metadata_prep_ms( + cu_seqlens: torch.Tensor, + chunk_size: int, + benchmark_iters: int, + device: str, +) -> list[float]: + times_ms: list[float] = [] + cache = torch.ones(DEFAULT_CACHE_SIZE, dtype=torch.int8, device=device) + for _ in range(benchmark_iters): + cache.zero_() + torch.npu.synchronize() + start = time.perf_counter() + build_prefix_metadata_on_npu(cu_seqlens, chunk_size, device) + torch.npu.synchronize() + times_ms.append((time.perf_counter() - start) * 1000.0) + return times_ms + + def add_bandwidth_fields(row: dict[str, float | int | str], input_dtype_bytes: int = 2) -> None: size_elems = int(row.get("valid_numel", row["numel"])) mem_bytes = size_elems * (input_dtype_bytes + 4) @@ -419,6 +477,12 @@ def plot_bandwidth(plot_path: Path, rows: list[dict[str, float | int | str]], ba if row["inverse_type"] == "bsnd-varlen-uniform" and row["metadata_strategy"] == "host-cpp" ] + varlen_prefix_rows = [ + row + for row in rows + if row["inverse_type"] == "bsnd-varlen-uniform" + and row["metadata_strategy"] == "device-chunk-prefix" + ] fig, ax = plt.subplots(figsize=(7.5, 5.0)) ax.plot( @@ -442,6 +506,13 @@ def plot_bandwidth(plot_path: Path, rows: list[dict[str, float | int | str]], ba linewidth=2, label="BSND varlen host metadata", ) + ax.plot( + [int(row["T"]) / 1000.0 for row in varlen_prefix_rows], + [float(row["bw_gbs"]) for row in varlen_prefix_rows], + marker="d", + linewidth=2, + label="BSND varlen prefix metadata", + ) ax.set_xlabel("Sequence length T (K)") ax.set_ylabel("Effective bandwidth (GB/s)") ax.set_title( @@ -610,6 +681,17 @@ def main() -> None: varlen_input, cu_seqlens, ) + chunk_sequence_prefix = build_prefix_metadata_on_npu( + cu_seqlens_cpu, + args.chunk_size, + NPU_DEVICE, + ) + varlen_run_prefix, varlen_out_prefix = make_varlen_runner_prefix_metadata( + tri_inv_func, + varlen_input, + cu_seqlens, + chunk_sequence_prefix, + ) chunk_indices, chunk_valid_sizes = build_host_metadata_on_npu( cu_seqlens_cpu, args.chunk_size, @@ -624,6 +706,7 @@ def main() -> None: fixed_run() varlen_run_device() + varlen_run_prefix() varlen_run_host() torch.npu.synchronize() @@ -637,6 +720,11 @@ def main() -> None: cu_seqlens, args.chunk_size, ) + packed_varlen_out_prefix = transpose_valid_chunks( + varlen_out_prefix, + cu_seqlens, + args.chunk_size, + ) packed_varlen_out_host = transpose_valid_chunks( varlen_out_host, cu_seqlens, @@ -650,9 +738,15 @@ def main() -> None: packed_fixed_out, packed_varlen_out_host, ) + max_abs_diff_prefix, rel_frob_diff_prefix = accuracy_metrics( + packed_fixed_out, + packed_varlen_out_prefix, + ) print( f" accuracy vs fixed: device max_abs_diff={max_abs_diff_device:.3e}, " f"device rel_frob_diff={rel_frob_diff_device:.3e}, " + f"prefix max_abs_diff={max_abs_diff_prefix:.3e}, " + f"prefix rel_frob_diff={rel_frob_diff_prefix:.3e}, " f"host max_abs_diff={max_abs_diff_host:.3e}, " f"host rel_frob_diff={rel_frob_diff_host:.3e}" ) @@ -669,6 +763,18 @@ def main() -> None: benchmark_iters=args.repeats, device=NPU_DEVICE, ) + prefix_metadata_times_ms = benchmark_prefix_metadata_prep_ms( + cu_seqlens_cpu, + args.chunk_size, + benchmark_iters=args.repeats, + device=NPU_DEVICE, + ) + varlen_prefix_kernel_times_ms = benchmark_ms( + varlen_run_prefix, + warmup_iters=args.warmup, + benchmark_iters=args.repeats, + device=NPU_DEVICE, + ) host_metadata_times_ms = benchmark_host_metadata_prep_ms( cu_seqlens_cpu, args.chunk_size, @@ -726,6 +832,30 @@ def main() -> None: } add_bandwidth_fields(varlen_device_row) + avg_prefix_metadata_us = int(round(np.mean(prefix_metadata_times_ms) * 1000.0)) + avg_prefix_kernel_us = int(round(np.mean(varlen_prefix_kernel_times_ms) * 1000.0)) + varlen_prefix_row = { + "inverse_type": "bsnd-varlen-uniform", + "metadata_strategy": "device-chunk-prefix", + "dtype": "fp16", + "B": args.B, + "T": seqlen, + "aggregated_T": total_tokens, + "padded_T": total_tokens, + "H": args.H, + "numel": varlen_input.numel(), + "valid_numel": total_tokens * args.H * args.chunk_size, + "chunk_size": args.chunk_size, + "time_us": avg_prefix_metadata_us + avg_prefix_kernel_us, + "kernel_time_us": avg_prefix_kernel_us, + "metadata_time_us": avg_prefix_metadata_us, + "max_abs_diff_to_fixed": max_abs_diff_prefix, + "rel_frob_diff_to_fixed": rel_frob_diff_prefix, + "sample_id": "", + "seq_lens": ",".join([str(seqlen)] * args.B), + } + add_bandwidth_fields(varlen_prefix_row) + avg_host_metadata_us = int(round(np.mean(host_metadata_times_ms) * 1000.0)) avg_host_kernel_us = int(round(np.mean(varlen_host_kernel_times_ms) * 1000.0)) varlen_host_row = { @@ -750,11 +880,14 @@ def main() -> None: } add_bandwidth_fields(varlen_host_row) - rows.extend([fixed_row, varlen_device_row, varlen_host_row]) + rows.extend([fixed_row, varlen_device_row, varlen_prefix_row, varlen_host_row]) print( f" fixed: time_us={fixed_row['time_us']}, bw_gbs={fixed_row['bw_gbs']:.2f} | " f"varlen-device: time_us={varlen_device_row['time_us']}, " f"bw_gbs={varlen_device_row['bw_gbs']:.2f} | " + f"varlen-prefix: time_us={varlen_prefix_row['time_us']} " + f"(meta={varlen_prefix_row['metadata_time_us']}, kernel={varlen_prefix_row['kernel_time_us']}), " + f"bw_gbs={varlen_prefix_row['bw_gbs']:.2f} | " f"varlen-host: time_us={varlen_host_row['time_us']} " f"(meta={varlen_host_row['metadata_time_us']}, kernel={varlen_host_row['kernel_time_us']}), " f"bw_gbs={varlen_host_row['bw_gbs']:.2f}" @@ -762,8 +895,13 @@ def main() -> None: device_metadata_overhead_us = ( varlen_device_row["kernel_time_us"] - varlen_host_row["kernel_time_us"] ) + prefix_metadata_overhead_us = ( + varlen_device_row["kernel_time_us"] - varlen_prefix_row["kernel_time_us"] + ) print( - f" metadata overhead comparison: device_in_kernel_delta_us={device_metadata_overhead_us}, " + f" metadata overhead comparison: device_vs_host_kernel_delta_us={device_metadata_overhead_us}, " + f"device_vs_prefix_kernel_delta_us={prefix_metadata_overhead_us}, " + f"prefix_metadata_us={varlen_prefix_row['metadata_time_us']}, " f"host_cpp_metadata_us={varlen_host_row['metadata_time_us']}" ) diff --git a/examples/jit_cpp/fast_inverse/fast_inverse.cpp b/examples/jit_cpp/fast_inverse/fast_inverse.cpp index 8155351e..42ba6cce 100644 --- a/examples/jit_cpp/fast_inverse/fast_inverse.cpp +++ b/examples/jit_cpp/fast_inverse/fast_inverse.cpp @@ -28,6 +28,8 @@ for the full License text. * N (number of heads) for BSND layout. * @param cu_seqlens Optional int32 pointer used only for varlen BSND when the * device kernel derives chunk metadata itself. + * @param chunk_sequence_prefix Optional int32 pointer containing a compact + * per-sequence cumulative chunk-count prefix. * @param chunk_indices Optional int32 pointer containing per-chunk row starts * for the host-precomputed varlen path. * @param chunk_valid_sizes Optional int32 pointer containing each chunk's @@ -37,8 +39,10 @@ extern "C" void call_kernel(uint32_t blockDim, void* stream, void* tensor_out, void* tensor_in, void* minus_identity_in, uint32_t matrix_size, uint32_t num_matrices, uint32_t num_bsnd_heads, void* cu_seqlens, + void* chunk_sequence_prefix, void* chunk_indices, void* chunk_valid_sizes) { tri_inv_rec_unroll_fp16<<>>( tensor_out, tensor_in, minus_identity_in, matrix_size, num_matrices, - num_bsnd_heads, cu_seqlens, chunk_indices, chunk_valid_sizes); + num_bsnd_heads, cu_seqlens, chunk_sequence_prefix, chunk_indices, + chunk_valid_sizes); } diff --git a/examples/jit_cpp/fast_inverse/host_chunk_metadata.cpp b/examples/jit_cpp/fast_inverse/host_chunk_metadata.cpp index 11fdd6c0..bfef9b29 100644 --- a/examples/jit_cpp/fast_inverse/host_chunk_metadata.cpp +++ b/examples/jit_cpp/fast_inverse/host_chunk_metadata.cpp @@ -40,3 +40,19 @@ extern "C" void build_varlen_chunk_metadata_host_cpp( } } } + +extern "C" void build_chunk_sequence_prefix_host_cpp( + const int32_t* cu_seqlens, uint32_t num_sequences, uint32_t chunk_size, + int32_t* chunk_sequence_prefix) { + chunk_sequence_prefix[0] = static_cast(num_sequences); + chunk_sequence_prefix[1] = 0; + + uint32_t total_chunks = 0; + for (uint32_t seq_idx = 0; seq_idx < num_sequences; ++seq_idx) { + const uint32_t seq_start = static_cast(cu_seqlens[seq_idx]); + const uint32_t seq_end = static_cast(cu_seqlens[seq_idx + 1]); + const uint32_t seq_len = seq_end - seq_start; + total_chunks += (seq_len + chunk_size - 1) / chunk_size; + chunk_sequence_prefix[seq_idx + 2] = static_cast(total_chunks); + } +} diff --git a/examples/jit_cpp/fast_inverse/host_metadata_util.py b/examples/jit_cpp/fast_inverse/host_metadata_util.py index bf6b4f27..b0d13383 100644 --- a/examples/jit_cpp/fast_inverse/host_metadata_util.py +++ b/examples/jit_cpp/fast_inverse/host_metadata_util.py @@ -65,6 +65,13 @@ def load_host_metadata_lib(): ctypes.c_void_p, ] lib.build_varlen_chunk_metadata_host_cpp.restype = None + lib.build_chunk_sequence_prefix_host_cpp.argtypes = [ + ctypes.c_void_p, + ctypes.c_uint32, + ctypes.c_uint32, + ctypes.c_void_p, + ] + lib.build_chunk_sequence_prefix_host_cpp.restype = None _HOST_LIB = lib return lib @@ -100,3 +107,27 @@ def build_varlen_chunk_metadata_cpp( _torch_to_ctypes(chunk_valid_sizes), ) return chunk_indices, chunk_valid_sizes + + +def build_chunk_sequence_prefix_cpp( + cu_seqlens: torch.Tensor | list[int], + chunk_size: int, +) -> torch.Tensor: + lib = load_host_metadata_lib() + if isinstance(cu_seqlens, torch.Tensor): + cu_seqlens_cpu = cu_seqlens.detach().to(device="cpu", dtype=torch.int32).contiguous() + else: + cu_seqlens_cpu = torch.tensor(cu_seqlens, dtype=torch.int32) + + if cu_seqlens_cpu.numel() < 2: + raise ValueError("cu_seqlens must contain at least 2 entries.") + + num_sequences = cu_seqlens_cpu.numel() - 1 + chunk_sequence_prefix = torch.empty(num_sequences + 2, dtype=torch.int32) + lib.build_chunk_sequence_prefix_host_cpp( + _torch_to_ctypes(cu_seqlens_cpu), + num_sequences, + chunk_size, + _torch_to_ctypes(chunk_sequence_prefix), + ) + return chunk_sequence_prefix diff --git a/examples/jit_cpp/fast_inverse/jit_util_fast_inverse.py b/examples/jit_cpp/fast_inverse/jit_util_fast_inverse.py index 63236f79..a8c86c7d 100644 --- a/examples/jit_cpp/fast_inverse/jit_util_fast_inverse.py +++ b/examples/jit_cpp/fast_inverse/jit_util_fast_inverse.py @@ -86,6 +86,7 @@ def load_lib(lib_path: str): ctypes.c_uint32, # num_matrices ctypes.c_uint32, # num_bsnd_heads ctypes.c_void_p, # cu_seqlens (optional int32 metadata) + ctypes.c_void_p, # chunk_sequence_prefix (optional int32 metadata) ctypes.c_void_p, # chunk_indices (optional int32 metadata) ctypes.c_void_p, # chunk_valid_sizes (optional int32 metadata) ] @@ -99,6 +100,7 @@ def tri_inv_func( num_matrices: int, num_bsnd_heads: int = 0, cu_seqlens: torch.Tensor | None = None, + chunk_sequence_prefix: torch.Tensor | None = None, chunk_indices: torch.Tensor | None = None, chunk_valid_sizes: torch.Tensor | None = None, block_dim: int = BLOCK_DIM, @@ -111,6 +113,11 @@ def tri_inv_func( raise TypeError("cu_seqlens must be int32.") if not cu_seqlens.is_contiguous(): raise ValueError("cu_seqlens must be contiguous.") + if chunk_sequence_prefix is not None: + if chunk_sequence_prefix.dtype != torch.int32: + raise TypeError("chunk_sequence_prefix must be int32.") + if not chunk_sequence_prefix.is_contiguous(): + raise ValueError("chunk_sequence_prefix must be contiguous.") if chunk_indices is not None: if chunk_indices.dtype != torch.int32: raise TypeError("chunk_indices must be int32.") @@ -136,6 +143,9 @@ def tri_inv_func( _torch_to_ctypes(cu_seqlens) if cu_seqlens is not None else ctypes.c_void_p(), + _torch_to_ctypes(chunk_sequence_prefix) + if chunk_sequence_prefix is not None + else ctypes.c_void_p(), _torch_to_ctypes(chunk_indices) if chunk_indices is not None else ctypes.c_void_p(), diff --git a/examples/jit_cpp/fast_inverse/kernel_tri_inv_rec_unroll.cpp b/examples/jit_cpp/fast_inverse/kernel_tri_inv_rec_unroll.cpp index b5dafd6d..3d093648 100644 --- a/examples/jit_cpp/fast_inverse/kernel_tri_inv_rec_unroll.cpp +++ b/examples/jit_cpp/fast_inverse/kernel_tri_inv_rec_unroll.cpp @@ -78,6 +78,40 @@ AICORE inline BSNDVarlenTileInfo GetBSNDVarlenTileInfoFromChunkMetadata( valid_size}; } +AICORE inline BSNDVarlenTileInfo GetBSNDVarlenTileInfoFromChunkPrefix( + uint32_t tile_id, uint32_t num_bsnd_heads, uint32_t matrix_size, + __gm__ int32_t* cu_seqlens, __gm__ int32_t* chunk_sequence_prefix) { + const uint32_t head_idx = tile_id % num_bsnd_heads; + const uint32_t chunk_idx = tile_id / num_bsnd_heads; + const uint32_t num_sequences = + static_cast(chunk_sequence_prefix[0]); + + uint32_t left = 0; + uint32_t right = num_sequences; + while (left < right) { + const uint32_t mid = (left + right) / 2; + const uint32_t chunk_end = + static_cast(chunk_sequence_prefix[mid + 2]); + if (chunk_idx < chunk_end) { + right = mid; + } else { + left = mid + 1; + } + } + + const uint32_t seq_idx = left; + const uint32_t chunk_base = + static_cast(chunk_sequence_prefix[seq_idx + 1]); + const uint32_t local_chunk_idx = chunk_idx - chunk_base; + const uint32_t seq_start = static_cast(cu_seqlens[seq_idx]); + const uint32_t seq_end = static_cast(cu_seqlens[seq_idx + 1]); + const uint32_t row_start = seq_start + local_chunk_idx * matrix_size; + const uint32_t valid_size = + min(static_cast(seq_end - row_start), matrix_size); + return {row_start * num_bsnd_heads * matrix_size + head_idx * matrix_size, + valid_size}; +} + /* * @brief: Takes as input two matrices of size MatrixSize * MatrixSize each. * The src matrix lies in L1, while the dst matrix lies either in L0A or L0B. @@ -430,6 +464,8 @@ AICORE inline void TriInvRecUnrollKernel(__gm__ OutputT* M_inv, uint32_t total_tiles, uint32_t num_bsnd_heads = 0, __gm__ int32_t* cu_seqlens = nullptr, + __gm__ int32_t* chunk_sequence_prefix = + nullptr, __gm__ int32_t* chunk_indices = nullptr, __gm__ int32_t* chunk_valid_sizes = nullptr) { @@ -546,6 +582,13 @@ AICORE inline void TriInvRecUnrollKernel(__gm__ OutputT* M_inv, chunk_valid_sizes); bsnd_tile_offsets[tile_id] = tile_info.bsnd_offset; bsnd_tile_valid_sizes[tile_id] = tile_info.valid_size; + } else if (chunk_sequence_prefix != nullptr && cu_seqlens != nullptr) { + const BSNDVarlenTileInfo tile_info = + GetBSNDVarlenTileInfoFromChunkPrefix( + global_tile_id, num_bsnd_heads, MatrixSize, cu_seqlens, + chunk_sequence_prefix); + bsnd_tile_offsets[tile_id] = tile_info.bsnd_offset; + bsnd_tile_valid_sizes[tile_id] = tile_info.valid_size; } else if (cu_seqlens != nullptr) { const BSNDVarlenTileInfo tile_info = GetBSNDVarlenTileInfoFromCuSeqlens( global_tile_id, num_bsnd_heads, MatrixSize, cu_seqlens); @@ -820,6 +863,8 @@ AICORE void runKernelTriInvRecUnroll(__gm__ OutputT* M_inv, __gm__ InputT* M, __gm__ InputT* I_neg, uint32_t total_tiles, uint32_t num_bsnd_heads = 0, __gm__ int32_t* cu_seqlens = nullptr, + __gm__ int32_t* chunk_sequence_prefix = + nullptr, __gm__ int32_t* chunk_indices = nullptr, __gm__ int32_t* chunk_valid_sizes = nullptr) { @@ -827,7 +872,8 @@ AICORE void runKernelTriInvRecUnroll(__gm__ OutputT* M_inv, __gm__ InputT* M, (__CCE_AICORE__ == 220 && defined(__DAV_C220_CUBE__)) TriInvRecUnrollKernel(M_inv, M, I_neg, total_tiles, num_bsnd_heads, - cu_seqlens, chunk_indices, + cu_seqlens, chunk_sequence_prefix, + chunk_indices, chunk_valid_sizes); #else // Nothing to do on AIV @@ -841,6 +887,7 @@ AICORE void run_tri_inv_rec_unroll(__gm__ float* tensor_out, uint32_t matrix_size, uint32_t num_matrices, uint32_t num_bsnd_heads, __gm__ int32_t* cu_seqlens, + __gm__ int32_t* chunk_sequence_prefix, __gm__ int32_t* chunk_indices, __gm__ int32_t* chunk_valid_sizes) { static_assert(std::is_same_v, @@ -849,22 +896,26 @@ AICORE void run_tri_inv_rec_unroll(__gm__ float* tensor_out, case 16: runKernelTriInvRecUnroll( tensor_out, tensor_in, minus_identity_in, num_matrices, num_bsnd_heads, - cu_seqlens, chunk_indices, chunk_valid_sizes); + cu_seqlens, chunk_sequence_prefix, chunk_indices, + chunk_valid_sizes); break; case 32: runKernelTriInvRecUnroll( tensor_out, tensor_in, minus_identity_in, num_matrices, num_bsnd_heads, - cu_seqlens, chunk_indices, chunk_valid_sizes); + cu_seqlens, chunk_sequence_prefix, chunk_indices, + chunk_valid_sizes); break; case 64: runKernelTriInvRecUnroll( tensor_out, tensor_in, minus_identity_in, num_matrices, num_bsnd_heads, - cu_seqlens, chunk_indices, chunk_valid_sizes); + cu_seqlens, chunk_sequence_prefix, chunk_indices, + chunk_valid_sizes); break; case 128: runKernelTriInvRecUnroll( tensor_out, tensor_in, minus_identity_in, num_matrices, num_bsnd_heads, - cu_seqlens, chunk_indices, chunk_valid_sizes); + cu_seqlens, chunk_sequence_prefix, chunk_indices, + chunk_valid_sizes); break; } } @@ -873,6 +924,7 @@ extern "C" __global__ AICORE void tri_inv_rec_unroll_fp16( __gm__ void* tensor_out, __gm__ void* tensor_in, __gm__ void* minus_identity_in, uint32_t matrix_size, uint32_t num_matrices, uint32_t num_bsnd_heads, __gm__ void* cu_seqlens, + __gm__ void* chunk_sequence_prefix, __gm__ void* chunk_indices, __gm__ void* chunk_valid_sizes) { if (num_bsnd_heads == 0) { if (num_matrices <= get_block_num()) { @@ -880,18 +932,21 @@ extern "C" __global__ AICORE void tri_inv_rec_unroll_fp16( (__gm__ float*)tensor_out, (__gm__ half*)tensor_in, (__gm__ half*)minus_identity_in, matrix_size, num_matrices, num_bsnd_heads, (__gm__ int32_t*)cu_seqlens, + (__gm__ int32_t*)chunk_sequence_prefix, (__gm__ int32_t*)chunk_indices, (__gm__ int32_t*)chunk_valid_sizes); } else if (num_matrices <= 2 * get_block_num()) { run_tri_inv_rec_unroll( (__gm__ float*)tensor_out, (__gm__ half*)tensor_in, (__gm__ half*)minus_identity_in, matrix_size, num_matrices, num_bsnd_heads, (__gm__ int32_t*)cu_seqlens, + (__gm__ int32_t*)chunk_sequence_prefix, (__gm__ int32_t*)chunk_indices, (__gm__ int32_t*)chunk_valid_sizes); } else { run_tri_inv_rec_unroll( (__gm__ float*)tensor_out, (__gm__ half*)tensor_in, (__gm__ half*)minus_identity_in, matrix_size, num_matrices, num_bsnd_heads, (__gm__ int32_t*)cu_seqlens, + (__gm__ int32_t*)chunk_sequence_prefix, (__gm__ int32_t*)chunk_indices, (__gm__ int32_t*)chunk_valid_sizes); } } else { @@ -900,18 +955,21 @@ extern "C" __global__ AICORE void tri_inv_rec_unroll_fp16( (__gm__ float*)tensor_out, (__gm__ half*)tensor_in, (__gm__ half*)minus_identity_in, matrix_size, num_matrices, num_bsnd_heads, (__gm__ int32_t*)cu_seqlens, + (__gm__ int32_t*)chunk_sequence_prefix, (__gm__ int32_t*)chunk_indices, (__gm__ int32_t*)chunk_valid_sizes); } else if (num_matrices <= 2 * get_block_num()) { run_tri_inv_rec_unroll( (__gm__ float*)tensor_out, (__gm__ half*)tensor_in, (__gm__ half*)minus_identity_in, matrix_size, num_matrices, num_bsnd_heads, (__gm__ int32_t*)cu_seqlens, + (__gm__ int32_t*)chunk_sequence_prefix, (__gm__ int32_t*)chunk_indices, (__gm__ int32_t*)chunk_valid_sizes); } else { run_tri_inv_rec_unroll( (__gm__ float*)tensor_out, (__gm__ half*)tensor_in, (__gm__ half*)minus_identity_in, matrix_size, num_matrices, num_bsnd_heads, (__gm__ int32_t*)cu_seqlens, + (__gm__ int32_t*)chunk_sequence_prefix, (__gm__ int32_t*)chunk_indices, (__gm__ int32_t*)chunk_valid_sizes); } } diff --git a/examples/jit_cpp/fast_inverse/metadata_overhead.md b/examples/jit_cpp/fast_inverse/metadata_overhead.md index 5b939745..d6b4ddb6 100644 --- a/examples/jit_cpp/fast_inverse/metadata_overhead.md +++ b/examples/jit_cpp/fast_inverse/metadata_overhead.md @@ -1,6 +1,6 @@ # Metadata Overhead Comparison -This note compares two varlen BSND metadata strategies for the fast-inverse PTO kernel. +This note compares three varlen BSND metadata strategies for the fast-inverse PTO kernel. ## Strategies @@ -23,7 +23,30 @@ Pros: Cons: - Adds a small amount of device-side work per tile. -### 2. Host-side C++ metadata precompute +### 2. Device-side compact chunk-prefix metadata + +Files: +- `host_chunk_metadata.cpp` +- `host_metadata_util.py` +- `kernel_tri_inv_rec_unroll.cpp` +- `fast_inverse.cpp` +- `jit_util_fast_inverse.py` + +Behavior: +- A small host C++ helper builds a compact per-sequence cumulative chunk-count prefix. +- Python uploads that prefix together with `cu_seqlens`. +- The NPU kernel uses the prefix to binary-search the owning sequence for each chunk, instead of walking all prior sequences. + +Pros: +- Reduces in-kernel metadata work compared with full `cu_seqlens` walking. +- Metadata payload is much smaller than full per-chunk host metadata. +- Better end-to-end than the full host metadata path. + +Cons: +- Still requires host preprocessing and one extra metadata upload. +- Still slower end-to-end than pure device-side `cu_seqlens` scanning in the current measurements. + +### 3. Host-side C++ metadata precompute Files: - `host_chunk_metadata.cpp` @@ -58,33 +81,36 @@ Benchmark setup: ### `chunk_size=64` -| T | Device metadata total | Host metadata total | Host kernel only | Host metadata only | -|---|---:|---:|---:|---:| -| 2048 | 556 us | 862 us | 553 us | 309 us | -| 8192 | 2075 us | 2377 us | 2048 us | 329 us | +| T | Device scan total | Prefix total | Prefix kernel | Prefix metadata | Host total | Host kernel | Host metadata | +|---|---:|---:|---:|---:|---:|---:|---:| +| 2048 | 564 us | 746 us | 559 us | 187 us | 836 us | 559 us | 277 us | +| 8192 | 2071 us | 2235 us | 2049 us | 186 us | 2340 us | 2048 us | 292 us | Takeaway: -- Device-side metadata cost is only about `3-27 us` relative to the host-precomputed kernel-only time. -- Host-side metadata costs about `309-329 us`, so it loses badly end to end. +- Prefix metadata cuts host metadata overhead from about `277-292 us` down to about `186-187 us`. +- Kernel-only time for prefix is slightly better than full device-side scanning, by about `5-22 us`. +- End to end, plain device-side `cu_seqlens` scanning is still best. ### `chunk_size=128` -| T | Device metadata total | Host metadata total | Host kernel only | Host metadata only | -|---|---:|---:|---:|---:| -| 2048 | 1088 us | 1378 us | 1089 us | 289 us | -| 8192 | 4074 us | 4372 us | 4058 us | 314 us | +| T | Device scan total | Prefix total | Prefix kernel | Prefix metadata | Host total | Host kernel | Host metadata | +|---|---:|---:|---:|---:|---:|---:|---:| +| 2048 | 1085 us | 1298 us | 1084 us | 214 us | 1363 us | 1080 us | 283 us | +| 8192 | 4065 us | 4253 us | 4056 us | 197 us | 4351 us | 4063 us | 288 us | Takeaway: -- Device-side metadata overhead is effectively negligible here. -- Host-side metadata still adds about `289-314 us`, so end-to-end performance is worse. +- Prefix metadata cuts host metadata overhead from about `283-288 us` down to about `197-214 us`. +- Kernel-only improvement versus device scan is tiny, around `1-9 us`. +- End to end, the pure device-side scan still wins. ## Conclusion -For the current implementation and tested shapes, the device-side `cu_seqlens` scan is the better overall strategy. +For the current implementation and tested shapes, the device-side `cu_seqlens` scan is still the best overall strategy. Reason: -- The host-C++ path does reduce or nearly eliminate kernel-side metadata overhead. -- But the saved kernel time is much smaller than the cost of building and uploading host metadata. +- The compact prefix path does reduce kernel-side metadata work and is clearly better than full host per-chunk metadata. +- But the saved kernel time is still much smaller than the cost of building and uploading the prefix. +- The full host per-chunk metadata path remains the slowest end-to-end option. ## How To Reproduce From 8da6f1326737e9b2d984ff736b7e2bc34b0cea58 Mon Sep 17 00:00:00 2001 From: Aleksandros Sobczyk <6952514+asobczyk@users.noreply.github.com> Date: Thu, 2 Apr 2026 14:09:37 +0000 Subject: [PATCH 14/14] Add description for BSND --- examples/jit_cpp/fast_inverse/README.md | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/examples/jit_cpp/fast_inverse/README.md b/examples/jit_cpp/fast_inverse/README.md index 41aeff5a..07621c5e 100644 --- a/examples/jit_cpp/fast_inverse/README.md +++ b/examples/jit_cpp/fast_inverse/README.md @@ -63,10 +63,21 @@ That script: ### Layout conventions +In general, the input to the `fast_inverse` kernels is a set of `D × D` sized triangular matrices. Depending on how these matrices are stored in memory, we might have `contiguous` layout, or the so-called `BSND` layout. The main input is a batch of sequences, and each sequence is then split in "chunks" of length `chunk_size`. This `chunk_size` is the same as the matrix size `D`. + +Both layouts depend on the following parameters: +- The parameter `B` denotes the batch-size (or batch-dimension). This is always the first dimension of the input tensor. +- The parameter `N` or `H` (used interchangeably) is the number of heads. +- `D` is equal to the `chunk_size`. +- `S` is the total sum of all sequence lengths combined. +`BSND` can be thought of as the "raw" input tensor. The `contiguous` layout can be obtained, for example, by transposing the `N` and `S` dimensions, and by "chunking" the `S` dimension to chunks of size `S`. The final tensor will be transformed from shape `(B,S,N,D)` to `->(B,N,S/D,D)`, where we assumed that `D` divides `S` for simplicity. + +The actual kernel can verify if the input is in `BSND` layout or in `contiguous` layout by specifying the input argument `num_bsnd_heads`. If it is equal to zero, then the format is assumed to be `contiguous` + | `num_bsnd_heads` | Memory layout | |-----------------|---------------| -| `0` (default) | Each matrix stored consecutively in row-major order (`B × … × N × D × D`) | -| `> 0` | BSND layout: `(B, S, N, D)` where S is chunked into tiles of size D and N heads are interleaved | +| `0` (default) | Each matrix stored consecutively in row-major order (`B × … × D × D`) | +| `> 0` | BSND layout: `(B, S, N, D)` where `S` is chunked into tiles of size D and N heads are interleaved | ### Varlen BSND mode