-
Notifications
You must be signed in to change notification settings - Fork 111
[AscendNPU-IR][A5] A5 support for floorOp and floorDivOp #950
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
base: npuir-dev
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -88,6 +88,7 @@ | |
| #include "bishengir/Dialect/Utils/Util.h" | ||
| #include "mlir/IR/TypeUtilities.h" | ||
| #include "mlir/Dialect/Math/IR/Math.h" | ||
| #include "tvm/ir/op.h" | ||
| #include "tvm/runtime/logging.h" | ||
| #include "llvm/Support/Debug.h" | ||
|
|
||
|
|
@@ -759,15 +760,18 @@ mlir::Type CodeGenTileLangNPUIRDEV::DTypetoMLIRType(DataType t) { // NOLINT(*) | |
| mlir::Value CodeGenTileLangNPUIRDEV::VisitExpr_(const FloorDivNode *op) { | ||
| auto lhs = MakeValue(op->a); | ||
| auto rhs = MakeValue(op->b); | ||
| // FIXME: The floor div in python is not the same as arith.divsi in negative | ||
| // scenarios. | ||
| mlir::Value mlirVal; | ||
| if (op->dtype.is_int() || op->dtype.is_uint()) { | ||
| // FIXME: The floor div in python is not the same as arith.divsi in negative | ||
| // scenarios. | ||
| mlirVal = BinaryOpCodegen<mlir::arith::DivSIOp, std::nullptr_t>(op, nullptr, | ||
| lhs, rhs); | ||
| } else if (op->dtype.is_float()) { | ||
| mlirVal = BinaryOpCodegen<mlir::arith::DivFOp, std::nullptr_t>(op, nullptr, | ||
| lhs, rhs); | ||
| auto result = CheckPrimExprMap(op); | ||
| if (result.first) return result.second; | ||
| auto divResult = builder.create<mlir::arith::DivFOp>(builder.getUnknownLoc(), lhs, rhs); | ||
| mlirVal = builder.create<mlir::math::FloorOp>(builder.getUnknownLoc(), divResult); | ||
| UpdatePrimExprMap(op, mlirVal); | ||
| } | ||
| return mlirVal; | ||
| } | ||
|
|
@@ -3408,6 +3412,58 @@ void CodeGenTileLangNPUIRDEV::VtanhCodegen(const CallNode *op) { | |
| } | ||
| } | ||
|
|
||
| void CodeGenTileLangNPUIRDEV::VfloordivCodegen(const CallNode *op) { | ||
| tvm::tl::NpuirFloorDiv npuirop(op->args, this->vmap); | ||
| auto loc = builder.getUnknownLoc(); | ||
|
|
||
| Value src0 = GenExtractSliceFromRegion(npuirop.src0, npuirop.src0_range); | ||
| Value src1 = GenExtractSliceFromRegion(npuirop.src1, npuirop.src1_range); | ||
| const CallNode *region_node_dst = op->args[2].as<CallNode>(); | ||
|
|
||
| tvm::tl::RegionOp region_dst_tmp(region_node_dst->args, vmap); | ||
| Array<Range> dst_range = region_dst_tmp.GetRanges(); | ||
|
|
||
| mlir::Value insertBase = NeedGenInsertSlice(region_dst_tmp.GetBuffer(), dst_range, src0); | ||
| bool needInsertSlice = (insertBase != GetVarValue(region_node_dst)); | ||
|
|
||
| auto src0TensorTy = src0.getType().cast<mlir::TensorType>(); | ||
| auto src0Shape = src0TensorTy.getShape(); | ||
| auto src1TensorTy = src1.getType().cast<mlir::TensorType>(); | ||
| auto src1Shape = src1TensorTy.getShape(); | ||
| auto dstTensorTy = insertBase.getType().cast<mlir::TensorType>(); | ||
| auto dstShape = dstTensorTy.getShape(); | ||
|
|
||
| // transpose | ||
| mlir::DenseI64ArrayAttr transpose = builder.getDenseI64ArrayAttr({}); | ||
| // broadcast | ||
| ArrayRef<int64_t> shape; | ||
| if (auto shapedType = insertBase.getType().dyn_cast<ShapedType>()) { | ||
| shape = shapedType.getShape(); | ||
| } | ||
| auto dims0 = getBroadcastDim(src0Shape, dstShape); | ||
| auto brc0 = builder.getDenseI64ArrayAttr(dims0); | ||
| auto dims1 = getBroadcastDim(src1Shape, dstShape); | ||
| auto brc1 = builder.getDenseI64ArrayAttr(dims1); | ||
| llvm::SetVector<int64_t> dims(llvm::from_range_t(), dims0); | ||
| dims.insert_range(dims1); | ||
| mlir::DenseI64ArrayAttr broadcast = | ||
| builder.getDenseI64ArrayAttr(dims.takeVector()); | ||
| src0 = broadcastOrTranspose(src0, insertBase, brc0, transpose, builder); | ||
| src1 = broadcastOrTranspose(src1, insertBase, brc1, transpose, builder); | ||
| Value Op; | ||
| if (src1TensorTy.getElementType().isa<FloatType>()) { | ||
| auto divResult = builder.create<mlir::arith::DivFOp>(loc, src0, src1); | ||
| Op = builder.create<mlir::math::FloorOp>(loc, divResult); | ||
| } else { | ||
| Op = builder.create<mlir::arith::DivSIOp>(loc, src0, src1); | ||
| } | ||
| mlir::Value result = needInsertSlice | ||
| ? ReshapeCastAndInsertSlice(Op, GetVarValue(region_node_dst), dst_range) | ||
| : Op; | ||
|
|
||
| SetVarValue(region_node_dst, result); | ||
| } | ||
|
Comment on lines
+3415
to
+3465
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The implementation of
|
||
|
|
||
| /// Generate hivm.hir.vreduce for tl.npuir_reshape. | ||
| /// before: | ||
| /// T.npuir_reshape(A, B) | ||
|
|
@@ -3520,6 +3576,8 @@ mlir::Value CodeGenTileLangNPUIRDEV::VisitExpr_(const CallNode *op) { | |
| CreateHIVMBinaryVectorOp<ElemwiseOp<linalg::BinaryFn::add>>(op); | ||
| } else if (op->op.same_as(Op::Get("tl.npuir_exp"))) { | ||
| UnaryVecOpCodegen<tvm::tl::NpuirExp, ElemwiseOp<linalg::UnaryFn::exp>>(op); | ||
| } else if (op->op.same_as(Op::Get("tl.npuir_floor"))) { | ||
| UnaryVecOpCodegen<tvm::tl::NpuirFloor, ElemwiseOp<linalg::UnaryFn::floor>>(op); | ||
| } else if (op->op.same_as(Op::Get("tl.npuir_ln"))) { | ||
| UnaryVecOpCodegen<tvm::tl::NpuirLn, ElemwiseOp<linalg::UnaryFn::log>>(op); | ||
| } else if (op->op.same_as(Op::Get("tl.npuir_relu"))) { | ||
|
|
@@ -3619,6 +3677,8 @@ mlir::Value CodeGenTileLangNPUIRDEV::VisitExpr_(const CallNode *op) { | |
| VerfCodegen(op); | ||
| } else if (op->op.same_as(Op::Get("tl.npuir_vtanh"))) { | ||
| VtanhCodegen(op); | ||
| } else if (op->op.same_as(Op::Get("tl.npuir_floordiv"))) { | ||
| VfloordivCodegen(op); | ||
| } else if (op->op.same_as(Op::Get("tl.npuir_debug_print_var")) || | ||
| op->op.same_as(Op::Get("tl.npuir_debug_print_buffer_value"))) { | ||
| DebugPrintCodegen(op); | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,71 @@ | ||
| import os | ||
|
|
||
| import pytest | ||
| import torch | ||
| import torch_npu | ||
|
|
||
| import tilelang | ||
| import tilelang.language as T | ||
|
|
||
| tilelang.cache.clear_cache() | ||
| os.environ["TILELANG_ASCEND_MODE"] = "Developer" | ||
|
|
||
|
|
||
| @pytest.fixture( | ||
| params=[ | ||
| ((1024, 1024), "float16"), | ||
| ((1024, 4096), "float16"), | ||
| ((1024, 10240), "float32"), | ||
| ((1024, 16384), "float32"), | ||
| ] | ||
| ) | ||
| def floor_case(request): | ||
| return request.param | ||
|
|
||
| block_M = 16 | ||
| block_N = 1024 | ||
|
|
||
| def floor_kernel(M, N, dtype): | ||
| m_num = M // block_M | ||
| n_num = N // block_N | ||
|
|
||
| @T.prim_func | ||
| def floorKernel(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype)): | ||
| with T.Kernel(m_num * n_num, is_npu=True) as (cid, _): | ||
| bx = block_M * (cid // n_num) | ||
| by = block_N * (cid % n_num) | ||
|
|
||
| A_VEC = T.alloc_shared((block_M, block_N), dtype) | ||
| B_VEC = T.alloc_shared((block_M, block_N), dtype) | ||
|
|
||
| T.copy(A[bx:bx+block_M, by:by+block_N], A_VEC) | ||
| T.npuir_floor(A_VEC, B_VEC) | ||
| T.copy(B_VEC, B[bx:bx+block_M, by:by+block_N]) | ||
|
|
||
| return floorKernel | ||
|
|
||
|
|
||
| def generate_tensor(shape, dtype, clear=False, positive=False): | ||
| if clear: | ||
| return torch.zeros(shape, dtype=eval("torch." + dtype)) | ||
| if dtype in ("float32", "float16"): | ||
| t = torch.randn(size=shape, dtype=eval("torch." + dtype)) | ||
| if positive: | ||
| t = torch.abs(t) + 0.1 | ||
| return t | ||
| raise ValueError('Invalid parameter "dtype" is found : {}'.format(dtype)) | ||
|
|
||
|
|
||
| def test_floor_dev(floor_case): | ||
| shape, dtype = floor_case | ||
|
|
||
| func = floor_kernel(*shape, dtype) | ||
| compiled_kernel = tilelang.compile(func, target="npuir") | ||
|
|
||
| src = generate_tensor(shape, dtype).npu() | ||
| dst = generate_tensor(shape, dtype, clear=True).npu() | ||
|
|
||
| ref = torch.floor(src.cpu()) | ||
| compiled_kernel(src, dst) | ||
|
|
||
| assert torch.allclose(dst.cpu(), ref, rtol=1e-5, atol=1e-5) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,75 @@ | ||
| import os | ||
|
|
||
| import pytest | ||
| import torch | ||
| import torch_npu | ||
|
|
||
| import tilelang | ||
| import tilelang.language as T | ||
|
|
||
| tilelang.cache.clear_cache() | ||
| os.environ["TILELANG_ASCEND_MODE"] = "Developer" | ||
|
|
||
| @pytest.fixture( | ||
| params=[ | ||
| ((1024, 1024), "float16"), | ||
| ((1024, 4096), "float16"), | ||
| ((1024, 10240), "float32"), | ||
| ((1024, 16384), "float32"), | ||
| ] | ||
| ) | ||
| def floordiv_case(request): | ||
| return request.param | ||
|
|
||
| block_M = 16 | ||
| block_N = 1024 | ||
|
|
||
|
|
||
| def floordiv_kernel(M, N, dtype): | ||
| m_num = M // block_M | ||
| n_num = N // block_N | ||
|
|
||
| @T.prim_func | ||
| def floordivKernel(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype), C: T.Tensor((M, N), dtype)): | ||
| with T.Kernel(m_num * n_num, is_npu=True) as (cid, _): | ||
| bx = block_M * (cid // n_num) | ||
| by = block_N * (cid % n_num) | ||
|
|
||
| A_VEC = T.alloc_shared((block_M, block_N), dtype) | ||
| B_VEC = T.alloc_shared((block_M, block_N), dtype) | ||
| C_VEC = T.alloc_shared((block_M, block_N), dtype) | ||
|
|
||
| T.copy(A[bx:bx+block_M, by:by+block_N], A_VEC) | ||
| T.copy(B[bx:bx+block_M, by:by+block_N], B_VEC) | ||
| T.npuir_floordiv(A_VEC, B_VEC, C_VEC) | ||
| T.copy(C_VEC, C[bx:bx+block_M, by:by+block_N]) | ||
|
|
||
| return floordivKernel | ||
|
|
||
|
|
||
| def generate_tensor(shape, dtype, clear=False, positive=False): | ||
| if clear: | ||
| return torch.zeros(shape, dtype=eval("torch." + dtype)) | ||
| if dtype in ("float32", "float16"): | ||
| t = torch.randn(size=shape, dtype=eval("torch." + dtype)) | ||
| if positive: | ||
| t = torch.abs(t) + 0.1 | ||
| return t | ||
| raise ValueError('Invalid parameter "dtype" is found : {}'.format(dtype)) | ||
|
|
||
|
|
||
| def test_floordiv_dev(floordiv_case): | ||
| shape, dtype = floordiv_case | ||
|
|
||
| func = floordiv_kernel(*shape, dtype) | ||
| compiled_kernel = tilelang.compile(func, target="npuir") | ||
|
|
||
| a = generate_tensor(shape, dtype).npu() | ||
| b = generate_tensor(shape, dtype, positive=True).npu() | ||
| c = generate_tensor(shape, dtype, clear=True).npu() | ||
|
|
||
| ref = torch.floor(a.cpu() / b.cpu()) | ||
| compiled_kernel(a, b, c) | ||
|
|
||
| assert torch.allclose(c.cpu(), ref, rtol=1e-3, atol=1e-3) | ||
|
|
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -227,6 +227,13 @@ def npuir_rec(A, B): | |||||||||||||||||||||||||||
| def npuir_not(A, B): | ||||||||||||||||||||||||||||
| return AscendUnaryOp("not", A, B).buildTirCall() | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| """npuir floor at tile-level.""" | ||||||||||||||||||||||||||||
| def npuir_floor(A, B): | ||||||||||||||||||||||||||||
| return AscendUnaryOp("floor", A, B).buildTirCall() | ||||||||||||||||||||||||||||
| """npuir floordiv at tile-level.""" | ||||||||||||||||||||||||||||
| def npuir_floordiv(A, B, C): | ||||||||||||||||||||||||||||
| return AscendBinaryOp("floordiv", A, B, C).buildTirCall() | ||||||||||||||||||||||||||||
|
Comment on lines
+230
to
+235
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The docstrings for
Suggested change
|
||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| def npuir_exp2(A, B, Tmp): | ||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||
| npuir exp2 at tile-level. | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The removal of this FIXME is premature.
`mlir::arith::DivSIOp` performs truncation towards zero, which does not match Python's floor division semantics for negative integers (e.g., `-5 // 2` should be `-3`, but `arith.divsi` results in `-2`). While the float path now correctly uses `math::FloorOp`, the integer path remains incorrect for negative scenarios.