diff --git a/src/op/ascend.cc b/src/op/ascend.cc index e37e93518..529579e35 100644 --- a/src/op/ascend.cc +++ b/src/op/ascend.cc @@ -69,6 +69,7 @@ NPUIR_BINARY_OP_CTOR(And, and) NPUIR_BINARY_OP_CTOR(Xor, xor) NPUIR_BINARY_OP_CTOR(Pow, pow) NPUIR_BINARY_OP_CTOR(Shl, shl) +NPUIR_BINARY_OP_CTOR(FloorDiv, floordiv) #define NPUIR_UNARY_OP_CTOR(OPNAME, opname) \ Npuir##OPNAME::Npuir##OPNAME(Array args, BufferMap vmap) { \ @@ -100,6 +101,7 @@ NPUIR_UNARY_OP_CTOR(Rsqrt, rsqrt) NPUIR_UNARY_OP_CTOR(Abs, abs) NPUIR_UNARY_OP_CTOR(Rec, rec) NPUIR_UNARY_OP_CTOR(Not, not ) +NPUIR_UNARY_OP_CTOR(Floor, floor) NpuirBrc::NpuirBrc(Array args, BufferMap vmap) { in = args[0], out = args[1]; diff --git a/src/op/ascend.h b/src/op/ascend.h index 08452780b..08e917460 100644 --- a/src/op/ascend.h +++ b/src/op/ascend.h @@ -75,6 +75,18 @@ NPUIR_UNARY_OP_CLASS(Rsqrt) NPUIR_UNARY_OP_CLASS(Abs) NPUIR_UNARY_OP_CLASS(Rec) NPUIR_UNARY_OP_CLASS(Not) +NPUIR_UNARY_OP_CLASS(Floor) + +class NpuirFloorDiv : public Operator { +public: + NpuirFloorDiv(Array args, BufferMap vmap); + + static const Op &Get(); + + Buffer src0, src1, dst; + + Array src0_range, src1_range, dst_range; +}; class NpuirDot : public Operator { public: diff --git a/src/target/codegen_npuir_dev.cc b/src/target/codegen_npuir_dev.cc index 27b49f70a..173513ac4 100644 --- a/src/target/codegen_npuir_dev.cc +++ b/src/target/codegen_npuir_dev.cc @@ -88,6 +88,7 @@ #include "bishengir/Dialect/Utils/Util.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Dialect/Math/IR/Math.h" +#include "tvm/ir/op.h" #include "tvm/runtime/logging.h" #include "llvm/Support/Debug.h" @@ -759,15 +760,18 @@ mlir::Type CodeGenTileLangNPUIRDEV::DTypetoMLIRType(DataType t) { // NOLINT(*) mlir::Value CodeGenTileLangNPUIRDEV::VisitExpr_(const FloorDivNode *op) { auto lhs = MakeValue(op->a); auto rhs = MakeValue(op->b); - // FIXME: The floor div in python is not the same as arith.divsi in negative - // scenarios. mlir::Value mlirVal; if (op->dtype.is_int() || op->dtype.is_uint()) { + // FIXME: The floor div in python is not the same as arith.divsi in negative + // scenarios. mlirVal = BinaryOpCodegen(op, nullptr, lhs, rhs); } else if (op->dtype.is_float()) { - mlirVal = BinaryOpCodegen(op, nullptr, - lhs, rhs); + auto result = CheckPrimExprMap(op); + if (result.first) return result.second; + auto divResult = builder.create(builder.getUnknownLoc(), lhs, rhs); + mlirVal = builder.create(builder.getUnknownLoc(), divResult); + UpdatePrimExprMap(op, mlirVal); } return mlirVal; } @@ -3408,6 +3412,58 @@ void CodeGenTileLangNPUIRDEV::VtanhCodegen(const CallNode *op) { } } +void CodeGenTileLangNPUIRDEV::VfloordivCodegen(const CallNode *op) { + tvm::tl::NpuirFloorDiv npuirop(op->args, this->vmap); + auto loc = builder.getUnknownLoc(); + + Value src0 = GenExtractSliceFromRegion(npuirop.src0, npuirop.src0_range); + Value src1 = GenExtractSliceFromRegion(npuirop.src1, npuirop.src1_range); + const CallNode *region_node_dst = op->args[2].as(); + + tvm::tl::RegionOp region_dst_tmp(region_node_dst->args, vmap); + Array dst_range = region_dst_tmp.GetRanges(); + + mlir::Value insertBase = NeedGenInsertSlice(region_dst_tmp.GetBuffer(), dst_range, src0); + bool needInsertSlice = (insertBase != GetVarValue(region_node_dst)); + + auto src0TensorTy = src0.getType().cast(); + auto src0Shape = src0TensorTy.getShape(); + auto src1TensorTy = src1.getType().cast(); + auto src1Shape = src1TensorTy.getShape(); + auto dstTensorTy = insertBase.getType().cast(); + auto dstShape = dstTensorTy.getShape(); + + // transpose + mlir::DenseI64ArrayAttr transpose = builder.getDenseI64ArrayAttr({}); + // broadcast + ArrayRef shape; + if (auto shapedType = insertBase.getType().dyn_cast()) { + shape = shapedType.getShape(); + } + auto dims0 = getBroadcastDim(src0Shape, dstShape); + auto brc0 = builder.getDenseI64ArrayAttr(dims0); + auto dims1 = getBroadcastDim(src1Shape, dstShape); + auto brc1 = builder.getDenseI64ArrayAttr(dims1); + llvm::SetVector dims(llvm::from_range_t(), dims0); + dims.insert_range(dims1); + mlir::DenseI64ArrayAttr broadcast = + builder.getDenseI64ArrayAttr(dims.takeVector()); + src0 = broadcastOrTranspose(src0, insertBase, brc0, transpose, builder); + src1 = broadcastOrTranspose(src1, insertBase, brc1, transpose, builder); + Value Op; + if (src1TensorTy.getElementType().isa()) { + auto divResult = builder.create(loc, src0, src1); + Op = builder.create(loc, divResult); + } else { + Op = builder.create(loc, src0, src1); + } + mlir::Value result = needInsertSlice + ? ReshapeCastAndInsertSlice(Op, GetVarValue(region_node_dst), dst_range) + : Op; + + SetVarValue(region_node_dst, result); +} + /// Generate hivm.hir.vreduce for tl.npuir_reshape. /// before: /// T.npuir_reshape(A, B) @@ -3520,6 +3576,8 @@ mlir::Value CodeGenTileLangNPUIRDEV::VisitExpr_(const CallNode *op) { CreateHIVMBinaryVectorOp>(op); } else if (op->op.same_as(Op::Get("tl.npuir_exp"))) { UnaryVecOpCodegen>(op); + } else if (op->op.same_as(Op::Get("tl.npuir_floor"))) { + UnaryVecOpCodegen>(op); } else if (op->op.same_as(Op::Get("tl.npuir_ln"))) { UnaryVecOpCodegen>(op); } else if (op->op.same_as(Op::Get("tl.npuir_relu"))) { @@ -3619,6 +3677,8 @@ mlir::Value CodeGenTileLangNPUIRDEV::VisitExpr_(const CallNode *op) { VerfCodegen(op); } else if (op->op.same_as(Op::Get("tl.npuir_vtanh"))) { VtanhCodegen(op); + } else if (op->op.same_as(Op::Get("tl.npuir_floordiv"))) { + VfloordivCodegen(op); } else if (op->op.same_as(Op::Get("tl.npuir_debug_print_var")) || op->op.same_as(Op::Get("tl.npuir_debug_print_buffer_value"))) { DebugPrintCodegen(op); diff --git a/src/target/codegen_npuir_dev.h b/src/target/codegen_npuir_dev.h index 77e85c23c..21706398e 100644 --- a/src/target/codegen_npuir_dev.h +++ b/src/target/codegen_npuir_dev.h @@ -269,6 +269,7 @@ class CodeGenTileLangNPUIRDEV final void VsinCodegen(const CallNode *op); void VerfCodegen(const CallNode *op); void VtanhCodegen(const CallNode *op); + void VfloordivCodegen(const CallNode *op); void DebugPrintCodegen(const CallNode *op); void ReshapeCodegen(const CallNode *op); template diff --git a/testing/npuir/arith_ops/test_floor_dev.py b/testing/npuir/arith_ops/test_floor_dev.py new file mode 100644 index 000000000..44f3b9e45 --- /dev/null +++ b/testing/npuir/arith_ops/test_floor_dev.py @@ -0,0 +1,71 @@ +import os + +import pytest +import torch +import torch_npu + +import tilelang +import tilelang.language as T + +tilelang.cache.clear_cache() +os.environ["TILELANG_ASCEND_MODE"] = "Developer" + + +@pytest.fixture( + params=[ + ((1024, 1024), "float16"), + ((1024, 4096), "float16"), + ((1024, 10240), "float32"), + ((1024, 16384), "float32"), + ] +) +def floor_case(request): + return request.param + +block_M = 16 +block_N = 1024 + +def floor_kernel(M, N, dtype): + m_num = M // block_M + n_num = N // block_N + + @T.prim_func + def floorKernel(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype)): + with T.Kernel(m_num * n_num, is_npu=True) as (cid, _): + bx = block_M * (cid // n_num) + by = block_N * (cid % n_num) + + A_VEC = T.alloc_shared((block_M, block_N), dtype) + B_VEC = T.alloc_shared((block_M, block_N), dtype) + + T.copy(A[bx:bx+block_M, by:by+block_N], A_VEC) + T.npuir_floor(A_VEC, B_VEC) + T.copy(B_VEC, B[bx:bx+block_M, by:by+block_N]) + + return floorKernel + + +def generate_tensor(shape, dtype, clear=False, positive=False): + if clear: + return torch.zeros(shape, dtype=eval("torch." + dtype)) + if dtype in ("float32", "float16"): + t = torch.randn(size=shape, dtype=eval("torch." + dtype)) + if positive: + t = torch.abs(t) + 0.1 + return t + raise ValueError('Invalid parameter "dtype" is found : {}'.format(dtype)) + + +def test_floor_dev(floor_case): + shape, dtype = floor_case + + func = floor_kernel(*shape, dtype) + compiled_kernel = tilelang.compile(func, target="npuir") + + src = generate_tensor(shape, dtype).npu() + dst = generate_tensor(shape, dtype, clear=True).npu() + + ref = torch.floor(src.cpu()) + compiled_kernel(src, dst) + + assert torch.allclose(dst.cpu(), ref, rtol=1e-5, atol=1e-5) diff --git a/testing/npuir/arith_ops/test_floordiv_dev.py b/testing/npuir/arith_ops/test_floordiv_dev.py new file mode 100644 index 000000000..7cb0c6552 --- /dev/null +++ b/testing/npuir/arith_ops/test_floordiv_dev.py @@ -0,0 +1,75 @@ +import os + +import pytest +import torch +import torch_npu + +import tilelang +import tilelang.language as T + +tilelang.cache.clear_cache() +os.environ["TILELANG_ASCEND_MODE"] = "Developer" + +@pytest.fixture( + params=[ + ((1024, 1024), "float16"), + ((1024, 4096), "float16"), + ((1024, 10240), "float32"), + ((1024, 16384), "float32"), + ] +) +def floordiv_case(request): + return request.param + +block_M = 16 +block_N = 1024 + + +def floordiv_kernel(M, N, dtype): + m_num = M // block_M + n_num = N // block_N + + @T.prim_func + def floordivKernel(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype), C: T.Tensor((M, N), dtype)): + with T.Kernel(m_num * n_num, is_npu=True) as (cid, _): + bx = block_M * (cid // n_num) + by = block_N * (cid % n_num) + + A_VEC = T.alloc_shared((block_M, block_N), dtype) + B_VEC = T.alloc_shared((block_M, block_N), dtype) + C_VEC = T.alloc_shared((block_M, block_N), dtype) + + T.copy(A[bx:bx+block_M, by:by+block_N], A_VEC) + T.copy(B[bx:bx+block_M, by:by+block_N], B_VEC) + T.npuir_floordiv(A_VEC, B_VEC, C_VEC) + T.copy(C_VEC, C[bx:bx+block_M, by:by+block_N]) + + return floordivKernel + + +def generate_tensor(shape, dtype, clear=False, positive=False): + if clear: + return torch.zeros(shape, dtype=eval("torch." + dtype)) + if dtype in ("float32", "float16"): + t = torch.randn(size=shape, dtype=eval("torch." + dtype)) + if positive: + t = torch.abs(t) + 0.1 + return t + raise ValueError('Invalid parameter "dtype" is found : {}'.format(dtype)) + + +def test_floordiv_dev(floordiv_case): + shape, dtype = floordiv_case + + func = floordiv_kernel(*shape, dtype) + compiled_kernel = tilelang.compile(func, target="npuir") + + a = generate_tensor(shape, dtype).npu() + b = generate_tensor(shape, dtype, positive=True).npu() + c = generate_tensor(shape, dtype, clear=True).npu() + + ref = torch.floor(a.cpu() / b.cpu()) + compiled_kernel(a, b, c) + + assert torch.allclose(c.cpu(), ref, rtol=1e-3, atol=1e-3) + diff --git a/tilelang/language/__init__.py b/tilelang/language/__init__.py index 60a2d1444..d931a5465 100644 --- a/tilelang/language/__init__.py +++ b/tilelang/language/__init__.py @@ -77,6 +77,10 @@ npuir_min as vmin, npuir_mul, npuir_mul as vmul, + npuir_floor, + npuir_floor as vfloor, + npuir_floordiv, + npuir_floordiv as vfloordiv, npuir_div, npuir_div as vdiv, npuir_or, diff --git a/tilelang/language/customize_npuir.py b/tilelang/language/customize_npuir.py index ef967aff1..767ac3030 100644 --- a/tilelang/language/customize_npuir.py +++ b/tilelang/language/customize_npuir.py @@ -227,6 +227,13 @@ def npuir_rec(A, B): def npuir_not(A, B): return AscendUnaryOp("not", A, B).buildTirCall() +"""npuir floor at tile-level.""" +def npuir_floor(A, B): + return AscendUnaryOp("floor", A, B).buildTirCall() +"""npuir floordiv at tile-level.""" +def npuir_floordiv(A, B, C): + return AscendBinaryOp("floordiv", A, B, C).buildTirCall() + def npuir_exp2(A, B, Tmp): """ npuir exp2 at tile-level.