From 41d05e5aba9b0f20076e19afa05d690725366a92 Mon Sep 17 00:00:00 2001 From: mouliangyu Date: Mon, 15 Jun 2026 00:26:31 +0800 Subject: [PATCH] Fix TileLang DSL textract and tsels tests --- .../a5/src/st/testcase/textract/gen_data.py | 4 +-- .../npu/a5/src/st/testcase/tsels/tsels.pto | 32 ++++++++----------- 2 files changed, 15 insertions(+), 21 deletions(-) diff --git a/test/tilelang_st/npu/a5/src/st/testcase/textract/gen_data.py b/test/tilelang_st/npu/a5/src/st/testcase/textract/gen_data.py index 5c2d5f802..a47a74a94 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/textract/gen_data.py +++ b/test/tilelang_st/npu/a5/src/st/testcase/textract/gen_data.py @@ -30,7 +30,7 @@ elif name.startswith("mat2right"): id_mat = np.eye(case["shape_id"][0], case["shape_id"][1], dtype=case["dtype_id"]) src = np.random.uniform(-1.0, 1.0, size=case["shape_src"]).astype(case["dtype_src"]) - golden = np.matmul(id_mat.astype(np.float32), src.astype(np.float32)).astype(np.float32) + golden = src.astype(np.float32).T.copy() save_case_data(name, {"input1": id_mat, "input2": src, "golden": golden}) - print(f"[INFO] gen_data: {name} done") \ No newline at end of file + print(f"[INFO] gen_data: {name} done") diff --git a/test/tilelang_st/npu/a5/src/st/testcase/tsels/tsels.pto b/test/tilelang_st/npu/a5/src/st/testcase/tsels/tsels.pto index 5ffbb6ce6..cac405aad 100644 --- a/test/tilelang_st/npu/a5/src/st/testcase/tsels/tsels.pto +++ b/test/tilelang_st/npu/a5/src/st/testcase/tsels/tsels.pto @@ -353,7 +353,7 @@ module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind, %src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i32) attributes {pto.kernel} { + func.func @TSELS_f32_uint8_2x8_2x32_2x8_2x8(%mask_ptr: !pto.ptr, %src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: f32) attributes {pto.kernel} { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c2 = arith.constant 2 : index @@ -377,13 +377,12 @@ module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind) outs(%mask_tile : !pto.tile_buf) pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x8xf32>) outs(%src_tile : !pto.tile_buf) - %scalar_f32 = arith.bitcast %scalar : i32 to f32 - pto.tsels ins(%mask_tile, %src_tile, %tmp_tile, %scalar_f32 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, f32) outs(%dst_tile : !pto.tile_buf) + pto.tsels ins(%mask_tile, %src_tile, %tmp_tile, %scalar : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, f32) outs(%dst_tile : !pto.tile_buf) pto.tstore ins(%dst_tile : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x8xf32>) return } - func.func @TSELS_f32_uint16_2x8_2x16_2x8_2x8(%mask_ptr: !pto.ptr, %src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i32) attributes {pto.kernel} { + func.func @TSELS_f32_uint16_2x8_2x16_2x8_2x8(%mask_ptr: !pto.ptr, %src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: f32) attributes {pto.kernel} { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c2 = arith.constant 2 : index @@ -406,13 +405,12 @@ module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind) outs(%mask_tile : !pto.tile_buf) pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x8xf32>) outs(%src_tile : !pto.tile_buf) - %scalar_f32 = arith.bitcast %scalar : i32 to f32 - pto.tsels ins(%mask_tile, %src_tile, %tmp_tile, %scalar_f32 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, f32) outs(%dst_tile : !pto.tile_buf) + pto.tsels ins(%mask_tile, %src_tile, %tmp_tile, %scalar : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, f32) outs(%dst_tile : !pto.tile_buf) pto.tstore ins(%dst_tile : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x8xf32>) return } - func.func @TSELS_f32_uint32_2x8_2x8_2x8_2x8(%mask_ptr: !pto.ptr, %src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i32) attributes {pto.kernel} { + func.func @TSELS_f32_uint32_2x8_2x8_2x8_2x8(%mask_ptr: !pto.ptr, %src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: f32) attributes {pto.kernel} { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c2 = arith.constant 2 : index @@ -434,8 +432,7 @@ module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind) outs(%mask_tile : !pto.tile_buf) pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x8xf32>) outs(%src_tile : !pto.tile_buf) - %scalar_f32 = arith.bitcast %scalar : i32 to f32 - pto.tsels ins(%mask_tile, %src_tile, %tmp_tile, %scalar_f32 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, f32) outs(%dst_tile : !pto.tile_buf) + pto.tsels ins(%mask_tile, %src_tile, %tmp_tile, %scalar : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, f32) outs(%dst_tile : !pto.tile_buf) pto.tstore ins(%dst_tile : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x8xf32>) return } @@ -501,7 +498,7 @@ module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind, %src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i32) attributes {pto.kernel} { + func.func @TSELS_f32_uint8_2x32_2x64_2x128_2x31(%mask_ptr: !pto.ptr, %src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: f32) attributes {pto.kernel} { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c2 = arith.constant 2 : index @@ -527,8 +524,7 @@ module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind) outs(%mask_tile : !pto.tile_buf) pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x2x128xf32>) outs(%src_tile : !pto.tile_buf) - %scalar_f32 = arith.bitcast %scalar : i32 to f32 - pto.tsels ins(%mask_tile, %src_tile, %tmp_tile, %scalar_f32 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, f32) outs(%dst_tile : !pto.tile_buf) + pto.tsels ins(%mask_tile, %src_tile, %tmp_tile, %scalar : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, f32) outs(%dst_tile : !pto.tile_buf) pto.tstore ins(%dst_tile : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x2x31xf32>) return } @@ -593,7 +589,7 @@ module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind, %src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i32) attributes {pto.kernel} { + func.func @TSELS_f32_uint8_32x672_32x96_32x672_32x666(%mask_ptr: !pto.ptr, %src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: f32) attributes {pto.kernel} { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c4 = arith.constant 4 : index @@ -618,13 +614,12 @@ module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind) outs(%mask_tile : !pto.tile_buf) pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x32x672xf32>) outs(%src_tile : !pto.tile_buf) - %scalar_f32 = arith.bitcast %scalar : i32 to f32 - pto.tsels ins(%mask_tile, %src_tile, %tmp_tile, %scalar_f32 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, f32) outs(%dst_tile : !pto.tile_buf) + pto.tsels ins(%mask_tile, %src_tile, %tmp_tile, %scalar : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, f32) outs(%dst_tile : !pto.tile_buf) pto.tstore ins(%dst_tile : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x32x666xf32>) return } - func.func @TSELS_f32_uint8_1x8192_1x4096_1x8192_1x8192(%mask_ptr: !pto.ptr, %src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: i32) attributes {pto.kernel} { + func.func @TSELS_f32_uint8_1x8192_1x4096_1x8192_1x8192(%mask_ptr: !pto.ptr, %src_ptr: !pto.ptr, %dst_ptr: !pto.ptr, %scalar: f32) attributes {pto.kernel} { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c4 = arith.constant 4 : index @@ -646,9 +641,8 @@ module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind) outs(%mask_tile : !pto.tile_buf) pto.tload ins(%src_part : !pto.partition_tensor_view<1x1x1x1x8192xf32>) outs(%src_tile : !pto.tile_buf) - %scalar_f32 = arith.bitcast %scalar : i32 to f32 - pto.tsels ins(%mask_tile, %src_tile, %tmp_tile, %scalar_f32 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, f32) outs(%dst_tile : !pto.tile_buf) + pto.tsels ins(%mask_tile, %src_tile, %tmp_tile, %scalar : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf, f32) outs(%dst_tile : !pto.tile_buf) pto.tstore ins(%dst_tile : !pto.tile_buf) outs(%dst_part : !pto.partition_tensor_view<1x1x1x1x8192xf32>) return } -} \ No newline at end of file +}