Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ ascendc_library(
csrc/kernel/kernel_csr_gather.cpp
csrc/kernel/kernel_simple_matmul.cpp
csrc/kernel/kernel_batch_matrix_square.cpp
csrc/kernel/kernel_tri_inv_ns.cpp
csrc/kernel/kernel_tri_inv_rec_unroll.cpp
csrc/kernel/kernel_tri_inv_trick.cpp
csrc/kernel/kernel_swiglu.cpp)
Expand Down
3 changes: 3 additions & 0 deletions csrc/host/pybind11.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ for the full License text.
#include "torch_simple_matmul.h"
#include "torch_swiglu.h"
#include "torch_tri_inv.h"
#include "torch_tri_inv_ns.h"
#include "torch_tri_inv_rec_unroll.h"
#include "torch_tri_inv_trick.h"

Expand Down Expand Up @@ -45,5 +46,7 @@ PYBIND11_MODULE(pto_kernels_ops, m) {
m.def("pto_tri_inv_rec_unroll", &pto_isa_ops::run_tri_inv_rec_unroll,
py::arg("M"), py::arg("is_bsnd_format") = false,
py::arg("cu_seqlens") = at::zeros({1}));
m.def("pto_tri_inv_ns", &pto_isa_ops::run_tri_inv_ns, py::arg("M"),
py::arg("num_iters") = 0, py::arg("scale_value") = 0.0f);
m.def("pto_tri_inv", &pto_isa_ops::run_tri_inv);
}
82 changes: 82 additions & 0 deletions csrc/host/torch_tri_inv_ns.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
/**
Copyright (c) 2026 Huawei Technologies Co., Ltd.
All rights reserved.

See LICENSE in the root of the software repository:
https://github.com/huawei-csl/pto-kernels/
for the full License text.
*/
#pragma once

#include <ATen/ATen.h>
#include <torch/library.h>

#include <cmath>

#include "aclrtlaunch_tri_inv_ns_fp16.h"
#include "utils.h"

namespace pto_isa_ops {

/**
 * @brief Triangular inverse using Newton–Schulz iterations.
 *
 * Implements the following algorithm:
 *   A = I + M
 *   X = I / scale_value
 *   for _ in range(num_iters):
 *     Y = A @ X
 *     X = X @ (2*I - Y)
 *   return X
 *
 * @param M Input tensor of strictly upper-triangular matrices (..., n, n),
 * dtype fp16. The full matrix inverted by the algorithm is A = I + M. The
 * matrix size n must be one of 16, 32, 64, 96 or 128.
 * @param num_iters Number of Newton–Schulz iterations. Zero (the default)
 * selects max(ceil(2 * log2(n)), 8).
 * @param scale_value Divisor applied to the identity to build the initial
 * guess X = I / scale_value. Zero (the default) selects 2 * n, where n is the
 * size of the matrices.
 * @return at::Tensor Tensor of approximate inverses in fp32, same shape as M.
 * @throws std::runtime_error If M is not fp16, has fewer than two dimensions,
 * is not square in its last two dimensions, or has an unsupported matrix size.
 */
inline at::Tensor run_tri_inv_ns(const at::Tensor& M, uint32_t num_iters = 0,
                                 float scale_value = 0) {
  const at::Device device = M.options().device();
  const auto dtype = M.options().dtype();
  const auto dtype_out = at::kFloat;

  if (!(dtype == at::kHalf)) {
    throw std::runtime_error(
        "Unsupported dtype for tri_inv_ns kernel. Supports only fp16");
  }
  if (M.dim() < 2) {
    throw std::runtime_error(
        "tri_inv_ns expects an input with at least two dimensions.");
  }
  const auto num_rows = M.size(-2);
  const auto num_cols = M.size(-1);
  if (num_rows != num_cols) {
    throw std::runtime_error("Only square matrices are supported.\n");
  }
  // The device-side dispatch in kernel_tri_inv_ns.cpp only instantiates these
  // sizes and does nothing for any other value, which would silently return
  // an all-zero result. Reject unsupported sizes here instead.
  switch (num_cols) {
    case 16:
    case 32:
    case 64:
    case 96:
    case 128:
      break;
    default:
      throw std::runtime_error(
          "Unsupported matrix size for tri_inv_ns kernel. Supported sizes: "
          "16, 32, 64, 96, 128");
  }
  const uint32_t n = static_cast<uint32_t>(num_cols);

  if (scale_value == 0) {
    scale_value = 2 * n;
  }

  if (num_iters == 0) {
    num_iters = static_cast<uint32_t>(std::ceil(2.0f * std::log2(n)));
    num_iters = std::max<uint32_t>(num_iters, 8);
  }

  // The kernel addresses matrix b at offset b * n * n, so it needs densely
  // packed row-major data.
  const at::Tensor M_contig = M.contiguous();

  // Divide in the 64-bit domain before narrowing so the element count is not
  // truncated.
  const uint32_t num_matrices = static_cast<uint32_t>(
      M_contig.numel() / (static_cast<decltype(num_cols)>(n) * n));

  const at::Tensor M_inv_raw = at::zeros_like(
      M_contig, at::TensorOptions().dtype(dtype_out).device(device));
  if (num_matrices == 0) {
    // Empty batch: nothing to launch.
    return M_inv_raw;
  }

  const auto opts_in = at::TensorOptions().dtype(dtype).device(device);
  const at::Tensor I_eye = at::eye(n, opts_in);
  // Initial guess X = I / scale_value, shared by every matrix in the batch.
  const at::Tensor I_scaled = (I_eye / scale_value).to(dtype).contiguous();
  const at::Tensor I_neg = -I_eye.contiguous();

  EXEC_KERNEL_CMD(tri_inv_ns_fp16, num_matrices, M_inv_raw, M_contig, I_neg,
                  I_scaled, n, num_iters);

  return M_inv_raw;
}
} // namespace pto_isa_ops
239 changes: 239 additions & 0 deletions csrc/kernel/kernel_tri_inv_ns.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,239 @@
/**
Copyright (c) 2026 Huawei Technologies Co., Ltd.
All rights reserved.

See LICENSE in the root of the software repository:
https://github.com/huawei-csl/pto-kernels/
for the full License text.
*/

#ifndef MEMORY_BASE
#define MEMORY_BASE
#endif
#include <pto/pto-inst.hpp>

#define GM_ADDR __gm__ uint8_t* // To avoid #include "kernel_operator.h"
Comment thread
gioelegott marked this conversation as resolved.

using namespace pto;

/**
* @brief Triangular inverse using Newton–Schulz iterations.
*
* Implements the following algorithm:
* A = I + M
* X = I_over_n   (initial guess: scaled identity supplied by the caller)
* for _ in range(num_iters):
* Y = X @ (-A)
* X = Y @ X + 2 * X
* return X
*/
template <typename InputT, typename OutputT, uint32_t MatrixSize>
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nitpick for doxygen docs. I thought CI should complain about it :-(

AICORE void runKernelTriInvNS(__gm__ OutputT* M_inv, __gm__ InputT* M,
__gm__ InputT* I_neg, __gm__ InputT* I_over_n,
uint32_t num_iters) {
// Per-block Newton–Schulz iteration on one MatrixSize x MatrixSize matrix.
//
// M_inv     Output buffer; this block writes at element offset
//           get_block_idx() * MatrixSize * MatrixSize.
// M         Input buffer; this block reads at the same element offset.
// I_neg     One MatrixSize x MatrixSize tile shared by all blocks. The inline
//           comments below assume it holds -I.
// I_over_n  One MatrixSize x MatrixSize tile shared by all blocks, used as the
//           initial iterate X.
// num_iters Trip count of the iteration loop. If it is 0 the loop is skipped
//           and the final store writes whatever the setup phase left in
//           c_l0_tile[1] (2I according to the comments below).
//
// NOTE(review): set_flag(src, dst, id) / wait_flag(src, dst, id) appear to make
// the dst pipe wait for work issued earlier on the src pipe — confirm against
// the platform documentation before reordering any statement here.
#if (__CHECK_FEATURE_AT_PRECOMPILE) || \
(__CCE_AICORE__ == 220 && defined(__DAV_C220_CUBE__)) // Cube compilation

// Number of elements in one matrix; also the per-block stride into M / M_inv.
constexpr uint32_t TileLen = MatrixSize * MatrixSize;
const uint32_t global_index = get_block_idx() * TileLen;
// Two tiles of each L0 kind (left, right, accumulator) are used below.
constexpr uint32_t NumL0Buffers = 2;

/* Global Memory / Tensors */
using TensorShapeInND =
TileShape2D<InputT, MatrixSize, MatrixSize, Layout::ND>;
using TensorStridesInND =
BaseShape2D<InputT, MatrixSize, MatrixSize, Layout::ND>;
using GlobalTensorIn =
GlobalTensor<InputT, TensorShapeInND, TensorStridesInND, Layout::ND>;

using TensorShapeOut =
TileShape2D<OutputT, MatrixSize, MatrixSize, Layout::ND>;
using TensorStridesOut =
BaseShape2D<OutputT, MatrixSize, MatrixSize, Layout::ND>;
using GlobalTensorOut =
GlobalTensor<OutputT, TensorShapeOut, TensorStridesOut, Layout::ND>;

/* L1 Memory */
using TileL1AB =
Tile<TileType::Mat, InputT, MatrixSize, MatrixSize, BLayout::ColMajor,
MatrixSize, MatrixSize, SLayout::RowMajor, 512>;

/* L0 Memory */
using TileL0A = TileLeft<InputT, MatrixSize, MatrixSize>;
using TileL0B = TileRight<InputT, MatrixSize, MatrixSize>;
using TileL0C = TileAcc<OutputT, MatrixSize, MatrixSize>;

// Only M and M_inv are offset per block; I_neg and I_over_n are read from
// their base address by every block.
GlobalTensorIn M_global_in(M + global_index);
GlobalTensorIn I_neg_global_in(I_neg);
GlobalTensorIn I_over_n_global_in(I_over_n);
GlobalTensorOut M_inv_global_out(M_inv + global_index);

TileL1AB A_neg_l1_tile;
TileL1AB X_l1_tile;
TileL1AB Y_l1_tile;
TileL1AB I_neg_l1_tile;
TileL1AB I_l1_tile;
TileL1AB two_I_l1_tile;

TileL0A a_l0_tile[NumL0Buffers];
TileL0B b_l0_tile[NumL0Buffers];
TileL0C c_l0_tile[NumL0Buffers];

// The six L1 tiles are placed back to back, TileLen * sizeof(InputT) bytes
// apart, starting at address 0.
TASSIGN(A_neg_l1_tile, 0x0);
TASSIGN(X_l1_tile, 0x0 + TileLen * sizeof(InputT));
TASSIGN(Y_l1_tile, 0x0 + 2 * TileLen * sizeof(InputT));
TASSIGN(I_neg_l1_tile, 0x0 + 3 * TileLen * sizeof(InputT));
TASSIGN(I_l1_tile, 0x0 + 4 * TileLen * sizeof(InputT));
TASSIGN(two_I_l1_tile, 0x0 + 5 * TileLen * sizeof(InputT));

// Each L0 tile kind gets its buffers at consecutive offsets starting from 0.
// NOTE(review): a/b/c all start at 0x0, so they presumably live in separate
// address spaces — confirm.
for (uint32_t buffer_num = 0; buffer_num < NumL0Buffers; ++buffer_num) {
TASSIGN(a_l0_tile[buffer_num], 0x0 + buffer_num * TileLen * sizeof(InputT));
TASSIGN(b_l0_tile[buffer_num], 0x0 + buffer_num * TileLen * sizeof(InputT));
TASSIGN(c_l0_tile[buffer_num],
0x0 + buffer_num * TileLen * sizeof(OutputT));
}

// LOAD GM -> L1 (MTE2)
// I_neg is flagged on its own event (EVENT_ID0) so the moves that only need it
// do not have to wait for the M / X loads (EVENT_ID1).
TLOAD(I_neg_l1_tile, I_neg_global_in);
set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0);

// A_neg_l1_tile temporarily holds M itself; it is overwritten with -A once
// that has been computed further down.
TLOAD(A_neg_l1_tile, M_global_in);
TLOAD(X_l1_tile, I_over_n_global_in);
set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID1);

// Precompute I and store to L1
// Both matmul operands are I_neg, so the product is (-I) @ (-I) = I.
wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0);
TMOV(a_l0_tile[0], I_neg_l1_tile);
TMOV(b_l0_tile[0], I_neg_l1_tile);
set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0);

// a_l0[1] <- M (A_neg_l1_tile still holds the raw input here).
wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID1);
TMOV(a_l0_tile[1], A_neg_l1_tile);
set_flag(PIPE_MTE1, PIPE_M, EVENT_ID1);

wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0);
TMATMUL(c_l0_tile[1], a_l0_tile[0], b_l0_tile[0]); // c_l0[1] = I
set_flag(PIPE_M, PIPE_FIX, EVENT_ID0);

// Copy I from the accumulator into L1. Two flags are raised: one consumed
// before I_l1_tile is read back into L0, one consumed before c_l0[1] is
// accumulated into again.
wait_flag(PIPE_M, PIPE_FIX, EVENT_ID0);
TMOV(I_l1_tile, c_l0_tile[1]);
set_flag(PIPE_FIX, PIPE_MTE1, EVENT_ID0);
set_flag(PIPE_FIX, PIPE_M, EVENT_ID0);

// M @ (-I) = -M.
wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID1);
TMATMUL(c_l0_tile[0], a_l0_tile[1], b_l0_tile[0]); // c_l0[0] = -M
set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1);

// a_l0[1] is reused for I; the first wait orders this move after the matmul
// above that read M from a_l0[1].
wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1);
wait_flag(PIPE_FIX, PIPE_MTE1, EVENT_ID0);
TMOV(a_l0_tile[1], I_l1_tile);
set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0);

// Accumulate (-I) @ (-I) = I onto the I already in c_l0[1].
wait_flag(PIPE_FIX, PIPE_M, EVENT_ID0);
TMATMUL_ACC(c_l0_tile[1], c_l0_tile[1], a_l0_tile[0],
b_l0_tile[0]); // c_l0[1] <- 2I
set_flag(PIPE_M, PIPE_FIX, EVENT_ID0);
wait_flag(PIPE_M, PIPE_FIX, EVENT_ID0);
TMOV(two_I_l1_tile, c_l0_tile[1]);
set_flag(PIPE_FIX, PIPE_MTE1, EVENT_ID0);
wait_flag(PIPE_FIX, PIPE_MTE1, EVENT_ID0);

// This flag is consumed by the matching wait just before the L0 operands are
// reloaded for the iteration loop.
set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0);

// Accumulate I @ (-I) = -I onto -M.
wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0);
TMATMUL_ACC(c_l0_tile[0], c_l0_tile[0], a_l0_tile[1],
b_l0_tile[0]); // c_l0[0] = -M-I = -A
set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1);
set_flag(PIPE_M, PIPE_FIX, EVENT_ID1);

wait_flag(PIPE_M, PIPE_FIX, EVENT_ID1);
TMOV(A_neg_l1_tile, c_l0_tile[0]); // A_l1 = -A
set_flag(PIPE_FIX, PIPE_MTE1, EVENT_ID1);
set_flag(PIPE_FIX, PIPE_MTE1, EVENT_ID0);

// Drain every flag still outstanding from the setup phase, then load the two
// loop-invariant operands: a_l0[1] = 2I and b_l0[1] = -A.
wait_flag(PIPE_FIX, PIPE_MTE1, EVENT_ID1);
wait_flag(PIPE_FIX, PIPE_MTE1, EVENT_ID0);
wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0);
wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1);
TMOV(a_l0_tile[1], two_I_l1_tile); // a_l0[1] <- 2I
TMOV(b_l0_tile[1], A_neg_l1_tile); // b_l0[1] <- -A
set_flag(PIPE_MTE1, PIPE_M, EVENT_ID1);
wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID1);

// Each iteration computes X_new = (X @ (-A)) @ X + 2 * X.
for (uint32_t i = 0; i < num_iters; ++i) {
// The current X is needed both as a left operand (X @ (-A)) and as a right
// operand (2I @ X and Y @ X).
TMOV(b_l0_tile[0], X_l1_tile);
TMOV(a_l0_tile[0], X_l1_tile);
set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0);
wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0);

TMATMUL(c_l0_tile[0], a_l0_tile[0], b_l0_tile[1]); // c_l0[0] <- -XA
set_flag(PIPE_M, PIPE_FIX, EVENT_ID0);
wait_flag(PIPE_M, PIPE_FIX, EVENT_ID0);
// Y = -XA goes through L1 so it can be reloaded as a left operand.
TMOV(Y_l1_tile, c_l0_tile[0]);
set_flag(PIPE_FIX, PIPE_MTE1, EVENT_ID0);

// 2I @ X is issued before waiting for Y to reach L1.
TMATMUL(c_l0_tile[1], a_l0_tile[1], b_l0_tile[0]); // c_l0[1] <- 2X
wait_flag(PIPE_FIX, PIPE_MTE1, EVENT_ID0);
TMOV(a_l0_tile[0], Y_l1_tile);
set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0);
wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0);
// c_l0[1] <- 2X + Y @ X, i.e. the next iterate.
TMATMUL_ACC(c_l0_tile[1], c_l0_tile[1], a_l0_tile[0], b_l0_tile[0]);
set_flag(PIPE_M, PIPE_FIX, EVENT_ID0);
wait_flag(PIPE_M, PIPE_FIX, EVENT_ID0);

// The write-back to L1 is skipped on the final iteration because the result
// is stored to global memory directly from c_l0_tile[1] after the loop.
if (i < num_iters - 1) {
TMOV(X_l1_tile, c_l0_tile[1]); // X_l1 now contains X_new
set_flag(PIPE_FIX, PIPE_MTE1, EVENT_ID0);
wait_flag(PIPE_FIX, PIPE_MTE1, EVENT_ID0);
}
}
// Write the final accumulator tile to this block's slot in M_inv.
TSTORE(M_inv_global_out, c_l0_tile[1]);
#else
// Nothing to do on AIV
#endif
}

/**
 * @brief Runtime-to-compile-time dispatch for the Newton–Schulz kernel.
 *
 * Selects the runKernelTriInvNS instantiation whose MatrixSize template
 * argument equals matrix_size. Instantiated sizes are 16, 32, 64, 96 and 128;
 * for any other value no kernel is run and tensor_out is left untouched.
 *
 * @tparam InputT Element type of the input tensors (must be half).
 * @param tensor_out Output buffer (fp32).
 * @param tensor_in Input matrices.
 * @param identity_neg_in Tile forwarded as the kernel's I_neg argument.
 * @param identity_over_n_in Tile forwarded as the kernel's I_over_n argument.
 * @param matrix_size Side length of each square matrix.
 * @param num_iters Number of iterations forwarded to the kernel.
 */
template <typename InputT>
AICORE void run_tri_inv_ns(__gm__ float* tensor_out, __gm__ InputT* tensor_in,
                           __gm__ InputT* identity_neg_in,
                           __gm__ InputT* identity_over_n_in,
                           uint32_t matrix_size, uint32_t num_iters) {
  static_assert(std::is_same_v<InputT, half>, "tri_inv_ns supports only fp16.");

  if (matrix_size == 16) {
    runKernelTriInvNS<InputT, float, 16>(
        tensor_out, tensor_in, identity_neg_in, identity_over_n_in, num_iters);
    return;
  }
  if (matrix_size == 32) {
    runKernelTriInvNS<InputT, float, 32>(
        tensor_out, tensor_in, identity_neg_in, identity_over_n_in, num_iters);
    return;
  }
  if (matrix_size == 64) {
    runKernelTriInvNS<InputT, float, 64>(
        tensor_out, tensor_in, identity_neg_in, identity_over_n_in, num_iters);
    return;
  }
  if (matrix_size == 96) {
    runKernelTriInvNS<InputT, float, 96>(
        tensor_out, tensor_in, identity_neg_in, identity_over_n_in, num_iters);
    return;
  }
  if (matrix_size == 128) {
    runKernelTriInvNS<InputT, float, 128>(
        tensor_out, tensor_in, identity_neg_in, identity_over_n_in, num_iters);
    return;
  }
  // Any other size: intentionally no work, matching the instantiated set above.
}

/**
 * @brief fp16 entry point of the Newton–Schulz triangular-inverse kernel.
 *
 * Reinterprets the untyped global-memory pointers (output as float, the three
 * inputs as half) and forwards everything to run_tri_inv_ns<half>.
 */
extern "C" __global__ AICORE void tri_inv_ns_fp16(
    __gm__ void* tensor_out, __gm__ void* tensor_in,
    __gm__ void* identity_neg_in, __gm__ void* identity_over_n_in,
    uint32_t matrix_size, uint32_t num_iters) {
  __gm__ float* out_fp32 = (__gm__ float*)tensor_out;
  __gm__ half* in_fp16 = (__gm__ half*)tensor_in;
  __gm__ half* neg_identity_fp16 = (__gm__ half*)identity_neg_in;
  __gm__ half* scaled_identity_fp16 = (__gm__ half*)identity_over_n_in;

  run_tri_inv_ns<half>(out_fp32, in_fp16, neg_identity_fp16,
                       scaled_identity_fp16, matrix_size, num_iters);
}
Loading
Loading