diff --git a/CMakeLists.txt b/CMakeLists.txt index 46f02038..0e5558f1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -116,6 +116,7 @@ ascendc_library( csrc/kernel/kernel_csr_gather.cpp csrc/kernel/kernel_simple_matmul.cpp csrc/kernel/kernel_batch_matrix_square.cpp + csrc/kernel/kernel_tri_inv_ns.cpp csrc/kernel/kernel_tri_inv_rec_unroll.cpp csrc/kernel/kernel_tri_inv_trick.cpp csrc/kernel/kernel_swiglu.cpp) diff --git a/csrc/host/pybind11.cpp b/csrc/host/pybind11.cpp index 134195b7..5091bbb6 100644 --- a/csrc/host/pybind11.cpp +++ b/csrc/host/pybind11.cpp @@ -15,6 +15,7 @@ for the full License text. #include "torch_simple_matmul.h" #include "torch_swiglu.h" #include "torch_tri_inv.h" +#include "torch_tri_inv_ns.h" #include "torch_tri_inv_rec_unroll.h" #include "torch_tri_inv_trick.h" @@ -45,5 +46,7 @@ PYBIND11_MODULE(pto_kernels_ops, m) { m.def("pto_tri_inv_rec_unroll", &pto_isa_ops::run_tri_inv_rec_unroll, py::arg("M"), py::arg("is_bsnd_format") = false, py::arg("cu_seqlens") = at::zeros({1})); + m.def("pto_tri_inv_ns", &pto_isa_ops::run_tri_inv_ns, py::arg("M"), + py::arg("num_iters") = 0, py::arg("scale_value") = 0.0f); m.def("pto_tri_inv", &pto_isa_ops::run_tri_inv); } diff --git a/csrc/host/torch_tri_inv_ns.h b/csrc/host/torch_tri_inv_ns.h new file mode 100644 index 00000000..fb9e9087 --- /dev/null +++ b/csrc/host/torch_tri_inv_ns.h @@ -0,0 +1,77 @@ +/** +Copyright (c) 2026 Huawei Technologies Co., Ltd. +All rights reserved. + +See LICENSE in the root of the software repository: +https://github.com/huawei-csl/pto-kernels/ +for the full License text. +*/ +#pragma once + +#include +#include + +#include + +#include "aclrtlaunch_tri_inv_ns_fp16.h" +#include "utils.h" + +namespace pto_isa_ops { + +/** + * @brief Triangular inverse using Newton–Schulz iterations. 
+ * + * Implements the following algorithm: + * A = I + M + * X = I * scale + * for _ in range(num_iters): + * Y = A @ X + * X = X @ (2*I - Y) + * return X + * + * @param M Input tensor of strictly upper-triangular matrices (..., n, + * n), dtype fp16. The full matrix inverted by the algorithm is A = I + M. + * @param num_iters Number of Newton–Schulz iterations (0 = auto). + * @param scale_value Value to scale the initial guess. Defaults to zero, which + * sets scale_value = 2 * n, where n is the size of the matrices. + * @return at::Tensor Tensor of approximate inverses in fp32, same batch shape + * as M. + */ +at::Tensor run_tri_inv_ns(const at::Tensor& M, uint32_t num_iters = 0, + float scale_value = 0) { + const at::Device device = M.options().device(); + const auto dtype = M.options().dtype(); + const auto dtype_out = at::kFloat; + + if (!(dtype == at::kHalf)) { + throw std::runtime_error( + "Unsupported dtype for tri_inv_ns kernel. Supports only fp16"); + } + const uint32_t n = static_cast(M.size(-1)); + if (n != static_cast(M.size(-2))) { + throw std::runtime_error("Only square matrices are supported.\n"); + } + + if (scale_value == 0) { + scale_value = 2 * n; + } + + const uint32_t num_matrices = static_cast(M.numel()) / (n * n); + + if (num_iters == 0) { + num_iters = static_cast(std::ceil(2.0f * std::log2(n))); + num_iters = std::max(num_iters, 8); + } + + const at::Tensor I_neg = -at::eye(n, M.options()); + const at::Tensor I_scaled = I_neg / (-scale_value); + + const at::Tensor M_inv = + at::zeros_like(M, at::TensorOptions().dtype(dtype_out).device(device)); + + EXEC_KERNEL_CMD(tri_inv_ns_fp16, num_matrices, M_inv, M, I_neg, I_scaled, n, + num_iters); + + return M_inv; +} +} // namespace pto_isa_ops diff --git a/csrc/kernel/kernel_tri_inv_ns.cpp b/csrc/kernel/kernel_tri_inv_ns.cpp new file mode 100644 index 00000000..a5e11950 --- /dev/null +++ b/csrc/kernel/kernel_tri_inv_ns.cpp @@ -0,0 +1,249 @@ +/** +Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+All rights reserved.
+
+See LICENSE in the root of the software repository:
+https://github.com/huawei-csl/pto-kernels/
+for the full License text.
+*/
+
+#ifndef MEMORY_BASE
+#define MEMORY_BASE
+#endif
+#include
+
+#define GM_ADDR __gm__ uint8_t*  // To avoid #include "kernel_operator.h"
+
+using namespace pto;
+
+/**
+ * @brief Triangular inverse using Newton–Schulz iterations.
+ *
+ * Implements the following algorithm:
+ *   A = I + M
+ *   X = I_scaled  (I / (2 * MatrixSize) when the host uses its default scale)
+ *   for _ in range(num_iters):
+ *     Y = X @ (-A)
+ *     X = Y @ X + 2 * X
+ *   return X
+ *
+ * Each block processes the one matrix found at element offset
+ * get_block_idx() * MatrixSize * MatrixSize of M and stores its result at the
+ * same offset of M_inv. Only -I, M and the initial X are loaded from global
+ * memory; I, 2I and -A are derived from them with matrix products:
+ *   I  = (-I) @ (-I)
+ *   2I = I + (-I) @ (-I)
+ *   -A = M @ (-I) + I @ (-I)
+ * The body is compiled only under the cube-side preprocessor condition below;
+ * otherwise the function is empty.
+ *
+ * @tparam InputT The type of the input elements.
+ * @tparam OutputT The type of the output elements.
+ * @tparam MatrixSize Size of the entire input/output matrices.
+ *
+ * @param M_inv pointer to the global memory to store the final inverse.
+ * @param M Pointer to the global tensor matrix in global memory.
+ * @param I_neg Pointer to global memory that contains the negative identity.
+ * @param I_scaled Pointer to global memory containing the initial iterate X.
+ * The kernel loads it unchanged; the host wrapper passes the identity scaled
+ * by 1 / scale_value, i.e. 1 / (2 * MatrixSize) by default.
+ * @param num_iters Number of Newton-Schulz iterations. NOTE(review): with
+ * num_iters == 0 the loop is skipped and the tile that gets stored is the 2I
+ * left in c_l0_tile[1] by the setup phase. The host wrapper in this change
+ * replaces 0 with an automatic count, so this should not occur — confirm for
+ * other callers.
+ */
+template
+AICORE void runKernelTriInvNS(__gm__ OutputT* M_inv, __gm__ InputT* M,
+                              __gm__ InputT* I_neg, __gm__ InputT* I_scaled,
+                              uint32_t num_iters) {
+#if (__CHECK_FEATURE_AT_PRECOMPILE) || \
+    (__CCE_AICORE__ == 220 && defined(__DAV_C220_CUBE__))  // Cube compilation
+
+  // Number of elements in one matrix; also the per-block stride into M/M_inv.
+  constexpr uint32_t TileLen = MatrixSize * MatrixSize;
+  const uint32_t global_index = get_block_idx() * TileLen;
+  // Two tiles of each L0 kind (a/b/c) are declared, indexed [0] and [1].
+  constexpr uint32_t NumL0Buffers = 2;
+
+  /* Global Memory / Tensors */
+  using TensorShapeInND =
+      TileShape2D;
+  using TensorStridesInND =
+      BaseShape2D;
+  using GlobalTensorIn =
+      GlobalTensor;
+
+  using TensorShapeOut =
+      TileShape2D;
+  using TensorStridesOut =
+      BaseShape2D;
+  using GlobalTensorOut =
+      GlobalTensor;
+
+  /* L1 Memory */
+  using TileL1AB =
+      Tile;
+
+  /* L0 Memory */
+  using TileL0A = TileLeft;
+  using TileL0B = TileRight;
+  using TileL0C = TileAcc;
+
+  // M and M_inv are offset to this block's matrix; -I and the initial X are
+  // read from the start of their buffers by every block.
+  GlobalTensorIn M_global_in(M + global_index);
+  GlobalTensorIn I_neg_global_in(I_neg);
+  GlobalTensorIn I_scaled_global_in(I_scaled);
+  GlobalTensorOut M_inv_global_out(M_inv + global_index);
+
+  // A_neg_l1_tile first receives M as loaded and is overwritten with -A once
+  // that has been computed.
+  TileL1AB A_neg_l1_tile;
+  TileL1AB X_l1_tile;
+  TileL1AB Y_l1_tile;
+  TileL1AB I_neg_l1_tile;
+  TileL1AB I_l1_tile;
+  TileL1AB two_I_l1_tile;
+
+  TileL0A a_l0_tile[NumL0Buffers];
+  TileL0B b_l0_tile[NumL0Buffers];
+  TileL0C c_l0_tile[NumL0Buffers];
+
+  // The six L1 tiles are placed back to back, TileLen * sizeof(InputT) bytes
+  // apart, starting at offset 0.
+  TASSIGN(A_neg_l1_tile, 0x0);
+  TASSIGN(X_l1_tile, 0x0 + TileLen * sizeof(InputT));
+  TASSIGN(Y_l1_tile, 0x0 + 2 * TileLen * sizeof(InputT));
+  TASSIGN(I_neg_l1_tile, 0x0 + 3 * TileLen * sizeof(InputT));
+  TASSIGN(I_l1_tile, 0x0 + 4 * TileLen * sizeof(InputT));
+  TASSIGN(two_I_l1_tile, 0x0 + 5 * TileLen * sizeof(InputT));
+
+  // a/b tiles are spaced by the input element size, c tiles by the output
+  // element size. All three kinds start at offset 0, so they are presumably
+  // in separate address spaces — TODO confirm against the tile type
+  // definitions.
+  for (uint32_t buffer_num = 0; buffer_num < NumL0Buffers; ++buffer_num) {
+    TASSIGN(a_l0_tile[buffer_num], 0x0 + buffer_num * TileLen * sizeof(InputT));
+    TASSIGN(b_l0_tile[buffer_num], 0x0 + buffer_num * TileLen * sizeof(InputT));
+    TASSIGN(c_l0_tile[buffer_num],
+            0x0 + buffer_num * TileLen * sizeof(OutputT));
+  }
+
+  // NOTE(review): the set_flag/wait_flag pairs below look like producer ->
+  // consumer ordering between the named pipes, one pair per event id; their
+  // exact semantics come from platform headers that are not part of this file.
+
+  // LOAD GM -> L1 (MTE2)
+  // -I is loaded first and signalled on its own event so the identity setup
+  // can start before M and the initial X have arrived.
+  TLOAD(I_neg_l1_tile, I_neg_global_in);
+  set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0);
+
+  TLOAD(A_neg_l1_tile, M_global_in);
+  TLOAD(X_l1_tile, I_scaled_global_in);
+  set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID1);
+
+  // Precompute I and store to L1
+  wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0);
+  TMOV(a_l0_tile[0], I_neg_l1_tile);
+  TMOV(b_l0_tile[0], I_neg_l1_tile);
+  set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0);
+
+  // a_l0[1] <- M (A_neg_l1_tile still holds the raw input here).
+  wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID1);
+  TMOV(a_l0_tile[1], A_neg_l1_tile);
+  set_flag(PIPE_MTE1, PIPE_M, EVENT_ID1);
+
+  // (-I) @ (-I)
+  wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0);
+  TMATMUL(c_l0_tile[1], a_l0_tile[0], b_l0_tile[0]);  // c_l0[1] = I
+  set_flag(PIPE_M, PIPE_FIX, EVENT_ID0);
+
+  wait_flag(PIPE_M, PIPE_FIX, EVENT_ID0);
+  TMOV(I_l1_tile, c_l0_tile[1]);
+  set_flag(PIPE_FIX, PIPE_MTE1, EVENT_ID0);
+  set_flag(PIPE_FIX, PIPE_M, EVENT_ID0);
+
+  // M @ (-I)
+  wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID1);
+  TMATMUL(c_l0_tile[0], a_l0_tile[1], b_l0_tile[0]);  // c_l0[0] = -M
+  set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1);
+
+  // a_l0[1] is overwritten with I only after the product that read M from it
+  // has been issued and I has been written to L1.
+  wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1);
+  wait_flag(PIPE_FIX, PIPE_MTE1, EVENT_ID0);
+  TMOV(a_l0_tile[1], I_l1_tile);
+  set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0);
+
+  // I + (-I) @ (-I)
+  wait_flag(PIPE_FIX, PIPE_M, EVENT_ID0);
+  TMATMUL_ACC(c_l0_tile[1], c_l0_tile[1], a_l0_tile[0],
+              b_l0_tile[0]);  // c_l0[1] <- 2I
+  set_flag(PIPE_M, PIPE_FIX, EVENT_ID0);
+  wait_flag(PIPE_M, PIPE_FIX, EVENT_ID0);
+  TMOV(two_I_l1_tile, c_l0_tile[1]);
+  set_flag(PIPE_FIX, PIPE_MTE1, EVENT_ID0);
+  wait_flag(PIPE_FIX, PIPE_MTE1, EVENT_ID0);
+
+  // This flag is consumed by the wait on (PIPE_M, PIPE_MTE1, EVENT_ID0) just
+  // before the loop operands are staged.
+  set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0);
+
+  // -M + I @ (-I)
+  wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0);
+  TMATMUL_ACC(c_l0_tile[0], c_l0_tile[0], a_l0_tile[1],
+              b_l0_tile[0]);  // c_l0[0] = -M-I = -A
+  set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1);
+  set_flag(PIPE_M, PIPE_FIX, EVENT_ID1);
+
+  wait_flag(PIPE_M, PIPE_FIX, EVENT_ID1);
+  TMOV(A_neg_l1_tile, c_l0_tile[0]);  // A_l1 = -A
+  set_flag(PIPE_FIX, PIPE_MTE1, EVENT_ID1);
+  set_flag(PIPE_FIX, PIPE_MTE1, EVENT_ID0);
+
+  // Stage the loop-invariant operands once: 2I as the left operand in slot 1
+  // and -A as the right operand in slot 1. Slot 0 of a/b is rewritten in
+  // every iteration.
+  wait_flag(PIPE_FIX, PIPE_MTE1, EVENT_ID1);
+  wait_flag(PIPE_FIX, PIPE_MTE1, EVENT_ID0);
+  wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0);
+  wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1);
+  TMOV(a_l0_tile[1], two_I_l1_tile);  // a_l0[1] <- 2I
+  TMOV(b_l0_tile[1], A_neg_l1_tile);  // b_l0[1] <- -A
+  set_flag(PIPE_MTE1, PIPE_M, EVENT_ID1);
+  wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID1);
+
+  // One Newton–Schulz step per pass: X_new = 2X + (X @ (-A)) @ X.
+  for (uint32_t i = 0; i < num_iters; ++i) {
+    // X is needed both as a left operand (X @ (-A)) and as a right operand
+    // (2I @ X and Y @ X).
+    TMOV(b_l0_tile[0], X_l1_tile);
+    TMOV(a_l0_tile[0], X_l1_tile);
+    set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0);
+    wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0);
+
+    TMATMUL(c_l0_tile[0], a_l0_tile[0], b_l0_tile[1]);  // c_l0[0] <- -XA
+    set_flag(PIPE_M, PIPE_FIX, EVENT_ID0);
+    wait_flag(PIPE_M, PIPE_FIX, EVENT_ID0);
+    // Y = -XA goes to L1 so that it can be reloaded as a left operand.
+    TMOV(Y_l1_tile, c_l0_tile[0]);
+    set_flag(PIPE_FIX, PIPE_MTE1, EVENT_ID0);
+
+    // Issued before waiting for Y to reach L1; this product reads only 2I
+    // and X.
+    TMATMUL(c_l0_tile[1], a_l0_tile[1], b_l0_tile[0]);  // c_l0[1] <- 2X
+    wait_flag(PIPE_FIX, PIPE_MTE1, EVENT_ID0);
+    TMOV(a_l0_tile[0], Y_l1_tile);
+    set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0);
+    wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0);
+    // c_l0[1] <- 2X + Y @ X
+    TMATMUL_ACC(c_l0_tile[1], c_l0_tile[1], a_l0_tile[0], b_l0_tile[0]);
+    set_flag(PIPE_M, PIPE_FIX, EVENT_ID0);
+    wait_flag(PIPE_M, PIPE_FIX, EVENT_ID0);
+
+    // The last iterate is stored straight from c_l0_tile[1] after the loop,
+    // so it is not copied back to L1.
+    if (i < num_iters - 1) {
+      TMOV(X_l1_tile, c_l0_tile[1]);  // X_l1 now contains X_new
+      set_flag(PIPE_FIX, PIPE_MTE1, EVENT_ID0);
+      wait_flag(PIPE_FIX, PIPE_MTE1, EVENT_ID0);
+    }
+  }
+  // Write this block's result to its slot of the output.
+  TSTORE(M_inv_global_out, c_l0_tile[1]);
+#else
+// Nothing to do on AIV
+#endif
+}
+
+/**
+ * @brief Dispatches on the runtime matrix size to runKernelTriInvNS.
+ *
+ * Handles matrix_size 16, 32, 64, 96 and 128. NOTE(review): the switch has no
+ * default case, so any other size returns without running the kernel and
+ * leaves tensor_out as the caller initialised it — callers must validate the
+ * size themselves.
+ *
+ * @param tensor_out Global-memory output, forwarded as M_inv.
+ * @param tensor_in Global-memory input, forwarded as M.
+ * @param identity_neg_in Global-memory negative identity, forwarded as I_neg.
+ * @param identity_over_n_in Global-memory initial iterate, forwarded as
+ * I_scaled.
+ * @param matrix_size Runtime size n of the square matrices.
+ * @param num_iters Number of Newton-Schulz iterations, forwarded unchanged.
+ */
+template
+AICORE void run_tri_inv_ns(__gm__ float* tensor_out, __gm__ InputT* tensor_in,
+                           __gm__ InputT* identity_neg_in,
+                           __gm__ InputT* identity_over_n_in,
+                           uint32_t matrix_size, uint32_t num_iters) {
+  static_assert(std::is_same_v, "tri_inv_ns supports only fp16.");
+  switch (matrix_size) {
+    case 16:
+      runKernelTriInvNS(tensor_out, tensor_in,
+                        identity_neg_in, identity_over_n_in,
+                        num_iters);
+      break;
+    case 32:
+      runKernelTriInvNS(tensor_out, tensor_in,
+                        identity_neg_in, identity_over_n_in,
+                        num_iters);
+      break;
+    case 64:
+      runKernelTriInvNS(tensor_out, tensor_in,
+                        identity_neg_in, identity_over_n_in,
+                        num_iters);
+      break;
+    case 96:
+      runKernelTriInvNS(tensor_out, tensor_in,
+                        identity_neg_in, identity_over_n_in,
+                        num_iters);
+      break;
+    case 128:
+      runKernelTriInvNS(tensor_out, tensor_in,
+                        identity_neg_in, identity_over_n_in,
+                        num_iters);
+      break;
+  }
+}
+
+/**
+ * @brief Kernel entry point for the fp16 Newton–Schulz triangular inverse.
+ *
+ * Casts the untyped global-memory pointers (the output to float, the three
+ * inputs to half) and forwards everything to run_tri_inv_ns.
+ */
+extern "C" __global__ AICORE void tri_inv_ns_fp16(
+    __gm__ void* tensor_out, __gm__ void* tensor_in,
+    __gm__ void* identity_neg_in, __gm__ void* identity_over_n_in,
+    uint32_t matrix_size, uint32_t num_iters) {
+  run_tri_inv_ns((__gm__ float*)tensor_out, (__gm__ half*)tensor_in,
+                 (__gm__ half*)identity_neg_in,
+                 (__gm__ half*)identity_over_n_in, matrix_size,
+                 num_iters);
+}
diff --git a/tests/test_tri_inv_ns.py b/tests/test_tri_inv_ns.py
new file mode 100644
index 00000000..80ce0b19
--- /dev/null
+++ b/tests/test_tri_inv_ns.py
@@ -0,0 +1,133 @@
+# --------------------------------------------------------------------------------
+# Copyright (c) 2026 Huawei Technologies Co., Ltd.
+# All rights reserved.
+# See LICENSE in the root of the software repository:
+# https://github.com/huawei-csl/pto-kernels/
+# for the full License text.
+# -------------------------------------------------------------------------------- + +import math +import random +from typing import Callable + +import numpy as np +import pytest +import torch + +from pto_kernels import pto_tri_inv_ns + +SEED = 42 +random.seed(SEED) +torch.manual_seed(SEED) +np.random.seed(SEED) + + +def random_triu_matrix(n, block_dim_x, block_dim_y, scale=0.1): + U = scale * torch.triu(torch.rand((block_dim_x, block_dim_y, n, n)), diagonal=1) + return U + + +def block_ones_matrix(n, block_dim_x, block_dim_y): + U_ = np.ones((16, 16)) + n_blocks = n // 16 + U = np.zeros((block_dim_x, block_dim_y, n, n)) + for x in range(block_dim_x): + for y in range(block_dim_y): + for i in range(n_blocks): + start = i * 16 + end = i * 16 + 16 + U[x, y, start:end, start:end] = U_ + return torch.from_numpy(np.triu(U, 1)) + + +def ones_matrix(n, block_dim_x, block_dim_y): + U = np.ones((block_dim_x, block_dim_y, n, n)) + return torch.from_numpy(np.triu(U, 1)) + + +def zeros_matrix(n, block_dim_x, block_dim_y): + return torch.zeros(block_dim_x, block_dim_y, n, n) + + +def block_random_matrix(n, block_dim_x, block_dim_y, scale=0.2): + U_ = scale * np.random.rand(16, 16) + U_ = np.triu(U_, k=1) + U = np.zeros((block_dim_x, block_dim_y, n, n)) + for x in range(block_dim_x): + for y in range(block_dim_y): + for i in range(0, n, 16): + U[x, y, i : i + 16, i : i + 16] = U_.copy() + return torch.from_numpy(U) + + +def linalg_inv(U: torch.Tensor) -> torch.Tensor: + n = U.shape[-1] + identity = np.eye(n, dtype=np.double) + golden_numpy = np.zeros(U.shape) + for x in range(U.shape[0]): + for y in range(U.shape[1]): + golden_numpy[x, y] = np.linalg.inv( + U[x, y].numpy().astype(np.double) + identity + ) + return torch.from_numpy(golden_numpy) + + +def default_num_iters(n: int) -> int: + return int(math.ceil(4.0 * math.log2(n))) + + +def _test_tri_inv_ns( + U: torch.Tensor, + atol: float, + rtol: float, + ftol: float, +): + U = U.to(torch.half) + golden_cpu = linalg_inv(U) + + 
U_npu = U.npu() + + torch.npu.synchronize() + num_iters = int(4.0 * math.ceil(math.log2(U.shape[-1]))) + actual = pto_tri_inv_ns(U_npu, num_iters=num_iters) + torch.npu.synchronize() + + actual_cpu = actual.cpu().to(torch.float64) + + frob_error = torch.sqrt( + torch.sum((golden_cpu - actual_cpu) * (golden_cpu - actual_cpu)) + / torch.sum(golden_cpu * golden_cpu) + ) + + actual_numpy = actual_cpu.numpy() + golden_numpy = golden_cpu.numpy() + assert np.allclose( + actual_numpy, golden_numpy, atol=atol, rtol=rtol + ), f"Error at allclose - tensor shape: {U.shape} - rtol: {rtol}." + assert frob_error <= ftol, f"frob_error: {frob_error}" + + +@pytest.mark.parametrize("n", [16, 32, 64, 96, 128]) +@pytest.mark.parametrize("block_dim_x", [1, 3, 7, 16]) +@pytest.mark.parametrize("block_dim_y", [1, 2, 4, 16]) +@pytest.mark.parametrize( + "matrix_gen,atol,rtol,ftol", + [ + (zeros_matrix, 5e-5, 0.1, 1e-2), + (ones_matrix, 5e-5, 0.1, 1e-2), + (block_ones_matrix, 5e-5, 0.1, 1e-2), + (block_random_matrix, 5e-5, 0.1, 1e-2), + (random_triu_matrix, 5e-5, 0.1, 1e-2), + ], +) +def test_tri_inv_ns( + n: int, + block_dim_x: int, + block_dim_y: int, + matrix_gen: Callable, + atol: float, + rtol: float, + ftol: float, +): + U = matrix_gen(n, block_dim_x, block_dim_y) + _test_tri_inv_ns(U, atol, rtol, ftol)