Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/op/ascend.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ NPUIR_BINARY_OP_CTOR(And, and)
NPUIR_BINARY_OP_CTOR(Xor, xor)
NPUIR_BINARY_OP_CTOR(Pow, pow)
NPUIR_BINARY_OP_CTOR(Shl, shl)
NPUIR_BINARY_OP_CTOR(FloorDiv, floordiv)

#define NPUIR_UNARY_OP_CTOR(OPNAME, opname) \
Npuir##OPNAME::Npuir##OPNAME(Array<PrimExpr> args, BufferMap vmap) { \
Expand Down Expand Up @@ -100,6 +101,7 @@ NPUIR_UNARY_OP_CTOR(Rsqrt, rsqrt)
NPUIR_UNARY_OP_CTOR(Abs, abs)
NPUIR_UNARY_OP_CTOR(Rec, rec)
NPUIR_UNARY_OP_CTOR(Not, not )
NPUIR_UNARY_OP_CTOR(Floor, floor)

NpuirBrc::NpuirBrc(Array<PrimExpr> args, BufferMap vmap) {
in = args[0], out = args[1];
Expand Down
12 changes: 12 additions & 0 deletions src/op/ascend.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,18 @@ NPUIR_UNARY_OP_CLASS(Rsqrt)
NPUIR_UNARY_OP_CLASS(Abs)
NPUIR_UNARY_OP_CLASS(Rec)
NPUIR_UNARY_OP_CLASS(Not)
NPUIR_UNARY_OP_CLASS(Floor)

class NpuirFloorDiv : public Operator {
public:
NpuirFloorDiv(Array<PrimExpr> args, BufferMap vmap);

static const Op &Get();

Buffer src0, src1, dst;

Array<Range> src0_range, src1_range, dst_range;
};

class NpuirDot : public Operator {
public:
Expand Down
68 changes: 64 additions & 4 deletions src/target/codegen_npuir_dev.cc
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@
#include "bishengir/Dialect/Utils/Util.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "tvm/ir/op.h"
#include "tvm/runtime/logging.h"
#include "llvm/Support/Debug.h"

Expand Down Expand Up @@ -759,15 +760,18 @@ mlir::Type CodeGenTileLangNPUIRDEV::DTypetoMLIRType(DataType t) { // NOLINT(*)
mlir::Value CodeGenTileLangNPUIRDEV::VisitExpr_(const FloorDivNode *op) {
auto lhs = MakeValue(op->a);
auto rhs = MakeValue(op->b);
// FIXME: The floor div in python is not the same as arith.divsi in negative
// scenarios.
mlir::Value mlirVal;
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The removal of this FIXME is premature. mlir::arith::DivSIOp performs truncation towards zero, which does not match Python's floor division semantics for negative integers (e.g., -5 // 2 should be -3, but arith.divsi results in -2). While the float path now correctly uses math::FloorOp, the integer path remains incorrect for negative scenarios.

if (op->dtype.is_int() || op->dtype.is_uint()) {
// FIXME: The floor div in python is not the same as arith.divsi in negative
// scenarios.
mlirVal = BinaryOpCodegen<mlir::arith::DivSIOp, std::nullptr_t>(op, nullptr,
lhs, rhs);
} else if (op->dtype.is_float()) {
mlirVal = BinaryOpCodegen<mlir::arith::DivFOp, std::nullptr_t>(op, nullptr,
lhs, rhs);
auto result = CheckPrimExprMap(op);
if (result.first) return result.second;
auto divResult = builder.create<mlir::arith::DivFOp>(builder.getUnknownLoc(), lhs, rhs);
mlirVal = builder.create<mlir::math::FloorOp>(builder.getUnknownLoc(), divResult);
UpdatePrimExprMap(op, mlirVal);
}
return mlirVal;
}
Expand Down Expand Up @@ -3408,6 +3412,58 @@ void CodeGenTileLangNPUIRDEV::VtanhCodegen(const CallNode *op) {
}
}

void CodeGenTileLangNPUIRDEV::VfloordivCodegen(const CallNode *op) {
tvm::tl::NpuirFloorDiv npuirop(op->args, this->vmap);
auto loc = builder.getUnknownLoc();

Value src0 = GenExtractSliceFromRegion(npuirop.src0, npuirop.src0_range);
Value src1 = GenExtractSliceFromRegion(npuirop.src1, npuirop.src1_range);
const CallNode *region_node_dst = op->args[2].as<CallNode>();

tvm::tl::RegionOp region_dst_tmp(region_node_dst->args, vmap);
Array<Range> dst_range = region_dst_tmp.GetRanges();

mlir::Value insertBase = NeedGenInsertSlice(region_dst_tmp.GetBuffer(), dst_range, src0);
bool needInsertSlice = (insertBase != GetVarValue(region_node_dst));

auto src0TensorTy = src0.getType().cast<mlir::TensorType>();
auto src0Shape = src0TensorTy.getShape();
auto src1TensorTy = src1.getType().cast<mlir::TensorType>();
auto src1Shape = src1TensorTy.getShape();
auto dstTensorTy = insertBase.getType().cast<mlir::TensorType>();
auto dstShape = dstTensorTy.getShape();

// transpose
mlir::DenseI64ArrayAttr transpose = builder.getDenseI64ArrayAttr({});
// broadcast
ArrayRef<int64_t> shape;
if (auto shapedType = insertBase.getType().dyn_cast<ShapedType>()) {
shape = shapedType.getShape();
}
auto dims0 = getBroadcastDim(src0Shape, dstShape);
auto brc0 = builder.getDenseI64ArrayAttr(dims0);
auto dims1 = getBroadcastDim(src1Shape, dstShape);
auto brc1 = builder.getDenseI64ArrayAttr(dims1);
llvm::SetVector<int64_t> dims(llvm::from_range_t(), dims0);
dims.insert_range(dims1);
mlir::DenseI64ArrayAttr broadcast =
builder.getDenseI64ArrayAttr(dims.takeVector());
src0 = broadcastOrTranspose(src0, insertBase, brc0, transpose, builder);
src1 = broadcastOrTranspose(src1, insertBase, brc1, transpose, builder);
Value Op;
if (src1TensorTy.getElementType().isa<FloatType>()) {
auto divResult = builder.create<mlir::arith::DivFOp>(loc, src0, src1);
Op = builder.create<mlir::math::FloorOp>(loc, divResult);
} else {
Op = builder.create<mlir::arith::DivSIOp>(loc, src0, src1);
}
mlir::Value result = needInsertSlice
? ReshapeCastAndInsertSlice(Op, GetVarValue(region_node_dst), dst_range)
: Op;

SetVarValue(region_node_dst, result);
}
Comment on lines +3415 to +3465
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The implementation of VfloordivCodegen is missing critical logic present in other vector operations:

  1. Broadcasting: MLIR arithmetic operations require operands to have matching shapes. If src0 and src1 have different shapes (e.g., one is a broadcasted dimension), this will fail. You should use broadcastOrTranspose to align operand shapes.
  2. Sliced Insertion: SetVarValue(npuirop.dst, result) overwrites the entire buffer variable mapping. If the operation is performed on a slice (as indicated by dst_range), it must use ReshapeCastAndInsertSlice to update only the relevant portion of the tensor and maintain SSA consistency.


/// Generate hivm.hir.vreduce for tl.npuir_reshape.
/// before:
/// T.npuir_reshape(A, B)
Expand Down Expand Up @@ -3520,6 +3576,8 @@ mlir::Value CodeGenTileLangNPUIRDEV::VisitExpr_(const CallNode *op) {
CreateHIVMBinaryVectorOp<ElemwiseOp<linalg::BinaryFn::add>>(op);
} else if (op->op.same_as(Op::Get("tl.npuir_exp"))) {
UnaryVecOpCodegen<tvm::tl::NpuirExp, ElemwiseOp<linalg::UnaryFn::exp>>(op);
} else if (op->op.same_as(Op::Get("tl.npuir_floor"))) {
UnaryVecOpCodegen<tvm::tl::NpuirFloor, ElemwiseOp<linalg::UnaryFn::floor>>(op);
} else if (op->op.same_as(Op::Get("tl.npuir_ln"))) {
UnaryVecOpCodegen<tvm::tl::NpuirLn, ElemwiseOp<linalg::UnaryFn::log>>(op);
} else if (op->op.same_as(Op::Get("tl.npuir_relu"))) {
Expand Down Expand Up @@ -3619,6 +3677,8 @@ mlir::Value CodeGenTileLangNPUIRDEV::VisitExpr_(const CallNode *op) {
VerfCodegen(op);
} else if (op->op.same_as(Op::Get("tl.npuir_vtanh"))) {
VtanhCodegen(op);
} else if (op->op.same_as(Op::Get("tl.npuir_floordiv"))) {
VfloordivCodegen(op);
} else if (op->op.same_as(Op::Get("tl.npuir_debug_print_var")) ||
op->op.same_as(Op::Get("tl.npuir_debug_print_buffer_value"))) {
DebugPrintCodegen(op);
Expand Down
1 change: 1 addition & 0 deletions src/target/codegen_npuir_dev.h
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,7 @@ class CodeGenTileLangNPUIRDEV final
void VsinCodegen(const CallNode *op);
void VerfCodegen(const CallNode *op);
void VtanhCodegen(const CallNode *op);
void VfloordivCodegen(const CallNode *op);
void DebugPrintCodegen(const CallNode *op);
void ReshapeCodegen(const CallNode *op);
template <typename T, typename U = void, typename V = void>
Expand Down
71 changes: 71 additions & 0 deletions testing/npuir/arith_ops/test_floor_dev.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import os

import pytest
import torch
import torch_npu

import tilelang
import tilelang.language as T

tilelang.cache.clear_cache()
os.environ["TILELANG_ASCEND_MODE"] = "Developer"


@pytest.fixture(
params=[
((1024, 1024), "float16"),
((1024, 4096), "float16"),
((1024, 10240), "float32"),
((1024, 16384), "float32"),
]
)
def floor_case(request):
return request.param

block_M = 16
block_N = 1024

def floor_kernel(M, N, dtype):
    """Build a tilelang kernel computing B = floor(A) over an (M, N) tensor.

    The (M, N) domain is tiled into (block_M, block_N) blocks; one NPU core
    (``cid``) handles one block. Assumes M % block_M == 0 and
    N % block_N == 0 (true for the fixture shapes above).
    """
    m_num = M // block_M
    n_num = N // block_N

    @T.prim_func
    def floorKernel(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype)):
        with T.Kernel(m_num * n_num, is_npu=True) as (cid, _):
            # Map the linear core id to a 2-D block origin (bx, by).
            bx = block_M * (cid // n_num)
            by = block_N * (cid % n_num)

            A_VEC = T.alloc_shared((block_M, block_N), dtype)
            B_VEC = T.alloc_shared((block_M, block_N), dtype)

            # Load tile, apply element-wise floor, store tile back.
            T.copy(A[bx:bx+block_M, by:by+block_N], A_VEC)
            T.npuir_floor(A_VEC, B_VEC)
            T.copy(B_VEC, B[bx:bx+block_M, by:by+block_N])

    return floorKernel


def generate_tensor(shape, dtype, clear=False, positive=False):
    """Create a test tensor of the given shape and dtype.

    ``clear=True`` returns an all-zeros tensor; ``positive=True`` shifts the
    random values to be strictly positive. Only float16/float32 are accepted
    for random generation.
    """
    if clear:
        return torch.zeros(shape, dtype=getattr(torch, dtype))
    if dtype not in ("float32", "float16"):
        raise ValueError('Invalid parameter "dtype" is found : {}'.format(dtype))
    data = torch.randn(size=shape, dtype=getattr(torch, dtype))
    return torch.abs(data) + 0.1 if positive else data


def test_floor_dev(floor_case):
    """Check T.npuir_floor against torch.floor on the NPU device."""
    (M, N), dtype = floor_case

    compiled_kernel = tilelang.compile(floor_kernel(M, N, dtype), target="npuir")

    src = generate_tensor((M, N), dtype).npu()
    dst = generate_tensor((M, N), dtype, clear=True).npu()

    # Reference computed on CPU before launching the device kernel.
    ref = torch.floor(src.cpu())
    compiled_kernel(src, dst)

    assert torch.allclose(dst.cpu(), ref, rtol=1e-5, atol=1e-5)
75 changes: 75 additions & 0 deletions testing/npuir/arith_ops/test_floordiv_dev.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import os

import pytest
import torch
import torch_npu

import tilelang
import tilelang.language as T

tilelang.cache.clear_cache()
os.environ["TILELANG_ASCEND_MODE"] = "Developer"

@pytest.fixture(
params=[
((1024, 1024), "float16"),
((1024, 4096), "float16"),
((1024, 10240), "float32"),
((1024, 16384), "float32"),
]
)
def floordiv_case(request):
return request.param

block_M = 16
block_N = 1024


def floordiv_kernel(M, N, dtype):
    """Build a tilelang kernel computing C = A // B element-wise over (M, N).

    The (M, N) domain is tiled into (block_M, block_N) blocks; one NPU core
    (``cid``) handles one block. Assumes M % block_M == 0 and
    N % block_N == 0 (true for the fixture shapes above).
    """
    m_num = M // block_M
    n_num = N // block_N

    @T.prim_func
    def floordivKernel(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype), C: T.Tensor((M, N), dtype)):
        with T.Kernel(m_num * n_num, is_npu=True) as (cid, _):
            # Map the linear core id to a 2-D block origin (bx, by).
            bx = block_M * (cid // n_num)
            by = block_N * (cid % n_num)

            A_VEC = T.alloc_shared((block_M, block_N), dtype)
            B_VEC = T.alloc_shared((block_M, block_N), dtype)
            C_VEC = T.alloc_shared((block_M, block_N), dtype)

            # Load both operand tiles, floor-divide, store the result tile.
            T.copy(A[bx:bx+block_M, by:by+block_N], A_VEC)
            T.copy(B[bx:bx+block_M, by:by+block_N], B_VEC)
            T.npuir_floordiv(A_VEC, B_VEC, C_VEC)
            T.copy(C_VEC, C[bx:bx+block_M, by:by+block_N])

    return floordivKernel


def generate_tensor(shape, dtype, clear=False, positive=False):
    """Create a test tensor of the given shape and dtype.

    ``clear=True`` returns an all-zeros tensor; ``positive=True`` shifts the
    random values to be strictly positive (safe as a divisor). Only
    float16/float32 are accepted for random generation.
    """
    if clear:
        return torch.zeros(shape, dtype=getattr(torch, dtype))
    if dtype not in ("float32", "float16"):
        raise ValueError('Invalid parameter "dtype" is found : {}'.format(dtype))
    data = torch.randn(size=shape, dtype=getattr(torch, dtype))
    return torch.abs(data) + 0.1 if positive else data


def test_floordiv_dev(floordiv_case):
    """Check T.npuir_floordiv against torch floor(a / b) on the NPU device."""
    (M, N), dtype = floordiv_case

    compiled_kernel = tilelang.compile(floordiv_kernel(M, N, dtype), target="npuir")

    a = generate_tensor((M, N), dtype).npu()
    # Strictly positive divisor avoids division by zero.
    b = generate_tensor((M, N), dtype, positive=True).npu()
    c = generate_tensor((M, N), dtype, clear=True).npu()

    # Reference computed on CPU before launching the device kernel.
    ref = torch.floor(a.cpu() / b.cpu())
    compiled_kernel(a, b, c)

    assert torch.allclose(c.cpu(), ref, rtol=1e-3, atol=1e-3)

4 changes: 4 additions & 0 deletions tilelang/language/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,10 @@
npuir_min as vmin,
npuir_mul,
npuir_mul as vmul,
npuir_floor,
npuir_floor as vfloor,
npuir_floordiv,
npuir_floordiv as vfloordiv,
npuir_div,
npuir_div as vdiv,
npuir_or,
Expand Down
7 changes: 7 additions & 0 deletions tilelang/language/customize_npuir.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,13 @@ def npuir_rec(A, B):
def npuir_not(A, B):
return AscendUnaryOp("not", A, B).buildTirCall()

"""npuir floor at tile-level."""
def npuir_floor(A, B):
return AscendUnaryOp("floor", A, B).buildTirCall()
"""npuir floordiv at tile-level."""
def npuir_floordiv(A, B, C):
return AscendBinaryOp("floordiv", A, B, C).buildTirCall()
Comment on lines +230 to +235
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The docstrings for npuir_floor and npuir_floordiv are placed outside the function definitions as standalone string literals. They should be moved inside the functions to be correctly recognized as docstrings.

Suggested change
"""npuir floor at tile-level."""
def npuir_floor(A, B):
return AscendUnaryOp("floor", A, B).buildTirCall()
"""npuir floordiv at tile-level."""
def npuir_floordiv(A, B, C):
return AscendBinaryOp("floordiv", A, B, C).buildTirCall()
def npuir_floor(A, B):
"""npuir floor at tile-level."""
return AscendUnaryOp("floor", A, B).buildTirCall()
def npuir_floordiv(A, B, C):
"""npuir floordiv at tile-level."""
return AscendBinaryOp("floordiv", A, B, C).buildTirCall()


def npuir_exp2(A, B, Tmp):
"""
npuir exp2 at tile-level.
Expand Down