From e25073ee7fa58812a9cb6032bbe29f7e9be5cb18 Mon Sep 17 00:00:00 2001 From: VitalyR <24508452+VitalyAnkh@users.noreply.github.com> Date: Thu, 9 Apr 2026 18:56:09 +0000 Subject: [PATCH 1/3] [MetaxGPU][MACA] Fill backend dtype and codegen support gaps --- src/target/codegen_maca.cc | 196 ++++++++++++++---- src/tl_templates/maca/common.h | 43 ++++ tilelang/contrib/dlpack.py | 4 +- .../intrinsics/maca_mma_macro_generator.py | 46 +++- tilelang/jit/adapter/tvm_ffi.py | 25 ++- tilelang/quantize/quantization.py | 7 +- tilelang/tileop/gemm/gemm_maca_mma.py | 9 +- 7 files changed, 268 insertions(+), 62 deletions(-) diff --git a/src/target/codegen_maca.cc b/src/target/codegen_maca.cc index ee2e659e..c28de004 100644 --- a/src/target/codegen_maca.cc +++ b/src/target/codegen_maca.cc @@ -48,9 +48,12 @@ static std::string GetTileLangFP8Type(DataType type) { << "Only support scalar and vector types of width (2, 4, 8, 16, 32) " "for FP8"; } - if (type.is_float8_e4m3() || type.is_float8_e4m3fn()) { + if (type.is_float8_e4m3() || type.is_float8_e4m3fn() || + type.is_float8_e4m3fnuz() || + type.code() == DataType::kFloat8_e4m3b11fnuz) { stream << "fp8_e4" << vec << "_t"; - } else if (type.is_float8_e5m2()) { + } else if (type.is_float8_e5m2() || type.is_float8_e5m2fnuz() || + type.code() == DataType::kFloat8_e5m2) { stream << "fp8_e5" << vec << "_t"; } else if (type.is_float8_e8m0fnu()) { stream << "fp8_e8" << vec << "_t"; @@ -362,6 +365,8 @@ void CodeGenTileLangMACA::PrintType(DataType t, std::ostream &os) { // NOLINT(*) ICHECK_EQ(lanes % 2, 0) << "only support even lane for float type with lanes > 4"; os << "ulonglong" << lanes / 2; + } else if (lanes == 16 || lanes == 32) { + os << "float32x" << lanes; } else { fail = true; } @@ -375,7 +380,8 @@ void CodeGenTileLangMACA::PrintType(DataType t, std::ostream &os) { // NOLINT(*) } if (!fail && (t.is_scalar() || t.bits() == 16)) return; - if (!fail && (lanes > 4 && lanes <= 8 && t.bits() == 32)) + if (!fail && t.bits() == 32 && + ((lanes > 4 && lanes <= 8) || lanes == 16 || lanes == 32)) return; if (!fail && (lanes >= 2 && lanes <= 4)) { os << lanes; @@ -629,7 +635,11 @@ void CodeGenTileLangMACA::PrintVecElemLoad(const std::string &vec, DataType t, } static const char access[] = {'x', 'y', 'z', 'w'}; - ICHECK(i >= 0 && i < 256 / t.bits()) + int max_lanes = 256 / t.bits(); + if (t.is_float() && t.bits() == 32 && (t.lanes() == 16 || t.lanes() == 32)) { + max_lanes = t.lanes(); + } + ICHECK(i >= 0 && i < max_lanes) << "i: " << i << " t: " << t << " t.bits(): " << t.bits() << " t.lanes(): " << t.lanes(); if (t.bits() == 8 && (t.is_int() || t.is_uint())) { @@ -692,6 +702,9 @@ void CodeGenTileLangMACA::PrintVecElemLoad(const std::string &vec, DataType t, os << "." << access[(i % 4) / 2]; // fp4_e2_2_t -> method call x() or y() os << "." << access[i % 2] << "()"; + } else if (t.is_float() && t.bits() == 32 && + (t.lanes() == 16 || t.lanes() == 32)) { + os << vec << "[" << i << "]"; } else if (t.lanes() > 4 && t.lanes() <= 8) { std::string type_name; if (t.bits() == 16) { @@ -721,7 +734,11 @@ void CodeGenTileLangMACA::PrintVecElemStore(const std::string &vec, DataType t, int i, const std::string &value) { this->PrintIndent(); static const char access[] = {'x', 'y', 'z', 'w'}; - ICHECK(i >= 0 && i < 256 / t.bits()); + int max_lanes = 256 / t.bits(); + if (t.is_float() && t.bits() == 32 && (t.lanes() == 16 || t.lanes() == 32)) { + max_lanes = t.lanes(); + } + ICHECK(i >= 0 && i < max_lanes); if (t.bits() == 8 && (t.is_int() || t.is_uint())) { if (t.lanes() == 2 || t.lanes() == 3) { stream << vec << '.' << access[i % t.lanes()] << "=" @@ -795,6 +812,9 @@ void CodeGenTileLangMACA::PrintVecElemStore(const std::string &vec, DataType t, ICHECK(!type_name.empty()); stream << "((" << type_name << "2*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2] << " = " << value << ";\n"; + } else if (t.is_float() && t.bits() == 32 && + (t.lanes() == 16 || t.lanes() == 32)) { + stream << vec << "[" << i << "] = " << value << ";\n"; } else if (t.is_float4_e2m1fn()) { stream << vec; // fp4_e2_64_t @@ -1791,9 +1811,20 @@ void CodeGenTileLangMACA::VisitExpr_(const CallNode *op, std::ostream &os) { {"float16x4", "float16x4"}, {"bfloat16x4", "bfloat16x4_vec"}, {"float32x4", "float32x4"}, + {"float8_e4m3x4", "fp8_e4_4_t"}, + {"float8_e4m3x8", "long"}, + {"float8_e4m3fnx4", "fp8_e4_4_t"}, + {"float8_e4m3fnx8", "long"}, {"float8_e4m3fnuzx4", "fp8_e4_4_t"}, {"float8_e4m3fnuzx8", "long"}, - {"float32x16", "float32x16"}}; + {"float8_e4m3b11fnuzx4", "fp8_e4_4_t"}, + {"float8_e4m3b11fnuzx8", "long"}, + {"float8_e5m2x4", "fp8_e5_4_t"}, + {"float8_e5m2x8", "long"}, + {"float8_e5m2fnuzx4", "fp8_e5_4_t"}, + {"float8_e5m2fnuzx8", "long"}, + {"float32x16", "float32x16"}, + {"float32x32", "float32x32"}}; std::string call_mfma_code = R"({ *((({C_dtype}*){c_ref}) + {c_bias}) = {mfma_buildin}(*((({A_dtype}*){a_ref}) + {a_bias}), *((({B_dtype}*){b_ref}) + {b_bias}), @@ -2083,6 +2114,8 @@ void CodeGenTileLangMACA::VisitExpr_(const CallNode *op, std::ostream &os) { os << ", " << PrintExpr(op->args[2]); } os << ")"; + } else if (op->op.same_as(tl::tl_shuffle_elect())) { + os << "tl::tl_shuffle_elect<" << PrintExpr(op->args[0]) << ">()"; } else if (op->op.same_as(tl::tl_gemm_sp())) { ICHECK(op->args.size() == 5) << "tl_gemm_sp expects 5 arguments value); - ICHECK(p); - int64_t v = *p & 0xFF; - v = (v << 24) | (v << 16) | (v << 8) | v; - if (op->dtype.is_uint()) { - os << "(uint)" << v; + if (p) { + int64_t v = *p & 0xFF; + v = (v << 24) | (v << 16) | (v << 8) | v; + if (op->dtype.is_uint()) { + os << "(uint)" << v; + } else { + os << "(int)" << v; + } + } else { + std::string scalar = PrintExpr(op->value); + std::string byte = "((" + scalar + ") & 0xFF)"; + os << "(" << (op->dtype.is_uint() ? "uint" : "int") << ")(" + << byte << " | (" << byte << " << 8) | (" << byte << " << 16) | (" + << byte << " << 24))"; + } + return; + } else if (lanes == 8 || lanes == 16) { + const int64_t *p = as_const_int(op->value); + std::string packed32; + if (p) { + int64_t v = *p & 0xFF; + v = (v << 24) | (v << 16) | (v << 8) | v; + std::ostringstream oss; + oss << "(" << (op->dtype.is_uint() ? "uint" : "int") << ")" << v; + packed32 = oss.str(); } else { - os << "(int)" << v; + std::string scalar = PrintExpr(op->value); + std::string byte = "((" + scalar + ") & 0xFF)"; + packed32 = "(" + byte + " | (" + byte + " << 8) | (" + byte + + " << 16) | (" + byte + " << 24))"; + packed32 = "(" + std::string(op->dtype.is_uint() ? "uint" : "int") + + ")" + packed32; + } + os << "make_"; + PrintType(op->dtype, os); + os << '('; + for (int i = 0; i < lanes / 4; ++i) { + if (i != 0) + os << ", "; + os << packed32; } + os << ')'; return; } else if (lanes == 32) { // make_int8x32 const int64_t *p = as_const_int(op->value); - ICHECK(p); - int64_t v = *p & 0xFF; - v = (v << 24) | (v << 16) | (v << 8) | v; - if (op->dtype.is_uint()) { - os << "make_ulonglong4(" << v << ", " << v << ", " << v << ", " << v - << ")"; + if (p) { + int64_t v = *p & 0xFF; + v = (v << 24) | (v << 16) | (v << 8) | v; + if (op->dtype.is_uint()) { + os << "make_ulonglong4(" << v << ", " << v << ", " << v << ", " << v + << ")"; + } else { + os << "make_longlong4(" << v << ", " << v << ", " << v << ", " << v + << ")"; + } } else { - os << "make_longlong4(" << v << ", " << v << ", " << v << ", " << v - << ")"; + std::string scalar = PrintExpr(op->value); + std::string byte = "((" + scalar + ") & 0xFF)"; + std::string packed32 = "(" + byte + " | (" + byte + " << 8) | (" + + byte + " << 16) | (" + byte + " << 24))"; + std::string packed64 = + "(((unsigned long long)" + packed32 + + ") | (((unsigned long long)" + packed32 + ") << 32))"; + if (op->dtype.is_uint()) { + os << "make_ulonglong4(" << packed64 << ", " << packed64 << ", " + << packed64 << ", " << packed64 << ")"; + } else { + os << "make_longlong4(" << packed64 << ", " << packed64 << ", " + << packed64 << ", " << packed64 << ")"; + } } return; } @@ -2529,7 +2612,7 @@ void CodeGenTileLangMACA::VisitExpr_(const BroadcastNode *op, for (int i = 0; i < 4; ++i) { if (i != 0) os << ", "; - os << "*(unsigned long long*)&make_float2(" << v << ", " << v << ")"; + os << "pack_float2(" << v << ", " << v << ")"; } os << ')'; return; @@ -2538,37 +2621,64 @@ void CodeGenTileLangMACA::VisitExpr_(const BroadcastNode *op, if ((op->dtype.is_int() || op->dtype.is_uint()) && op->dtype.bits() == 4) { bool fail = false; const int64_t *p = as_const_int(op->value); - ICHECK(p); - int64_t v = *p & 0xF; - - if (lanes == 4) { - v = (v << 12) | (v << 8) | (v << 4) | v; - if (op->dtype.is_uint()) { - os << "(uint16_t)" << v; - } else { - os << "(int16_t)" << v; - } - } else { - v = (v << 28) | (v << 24) | (v << 20) | (v << 16) | (v << 12) | (v << 8) | - (v << 4) | v; - if (lanes == 8) { + if (p) { + int64_t v = *p & 0xF; + if (lanes == 4) { + v = (v << 12) | (v << 8) | (v << 4) | v; if (op->dtype.is_uint()) { - os << "(uint)" << v; + os << "(uint16_t)" << v; } else { - os << "(int)" << v; + os << "(int16_t)" << v; } - } else if (lanes == 16 || lanes == 32) { + } else { + v = (v << 28) | (v << 24) | (v << 20) | (v << 16) | (v << 12) | + (v << 8) | (v << 4) | v; + if (lanes == 8) { + if (op->dtype.is_uint()) { + os << "(uint)" << v; + } else { + os << "(int)" << v; + } + } else if (lanes == 16 || lanes == 32 || lanes == 64) { + os << "make_"; + PrintType(op->dtype, os); + os << '('; + for (int i = 0; i < lanes / 8; ++i) { + if (i != 0) + os << ", "; + if (op->dtype.is_uint()) { + os << "(uint)" << v; + } else { + os << "(int)" << v; + } + } + os << ')'; + } else { + fail = true; + } + } + } else { + std::string scalar = PrintExpr(op->value); + std::string nibble = "((" + scalar + ") & 0xF)"; + std::string packed32 = "(" + nibble + " | (" + nibble + " << 4) | (" + + nibble + " << 8) | (" + nibble + " << 12) | (" + + nibble + " << 16) | (" + nibble + " << 20) | (" + + nibble + " << 24) | (" + nibble + " << 28))"; + if (lanes == 4) { + os << "(" << (op->dtype.is_uint() ? "uint16_t" : "int16_t") << ")(" + << nibble << " | (" << nibble << " << 4) | (" << nibble + << " << 8) | (" << nibble << " << 12))"; + } else if (lanes == 8) { + os << "(" << (op->dtype.is_uint() ? "uint" : "int") << ")" << packed32; + } else if (lanes == 16 || lanes == 32 || lanes == 64) { os << "make_"; PrintType(op->dtype, os); os << '('; for (int i = 0; i < lanes / 8; ++i) { if (i != 0) os << ", "; - if (op->dtype.is_uint()) { - os << "(uint)" << v; - } else { - os << "(int)" << v; - } + os << "(" << (op->dtype.is_uint() ? "uint" : "int") << ")" + << packed32; } os << ')'; } else { diff --git a/src/tl_templates/maca/common.h b/src/tl_templates/maca/common.h index 9b7392b4..ed40b9f5 100644 --- a/src/tl_templates/maca/common.h +++ b/src/tl_templates/maca/common.h @@ -162,9 +162,43 @@ typedef using int32x4 = __attribute__((__vector_size__(4 * sizeof(int)))) int; using float32x4 = __attribute__((__vector_size__(4 * sizeof(float)))) float; using float32x16 = __attribute__((__vector_size__(16 * sizeof(float)))) float; +using float32x32 = __attribute__((__vector_size__(32 * sizeof(float)))) float; using float64x4 = __attribute__((__vector_size__(4 * sizeof(double)))) double; using int8x4 = __attribute__((__vector_size__(4 * sizeof(int8_t)))) int8_t; +TL_DEVICE float32x16 make_float32x16( + float x0, float x1, float x2, float x3, float x4, float x5, float x6, + float x7, float x8, float x9, float x10, float x11, float x12, float x13, + float x14, float x15) { + return float32x16{x0, x1, x2, x3, x4, x5, x6, x7, + x8, x9, x10, x11, x12, x13, x14, x15}; +} + +TL_DEVICE float32x32 make_float32x32( + float x0, float x1, float x2, float x3, float x4, float x5, float x6, + float x7, float x8, float x9, float x10, float x11, float x12, float x13, + float x14, float x15, float x16, float x17, float x18, float x19, + float x20, float x21, float x22, float x23, float x24, float x25, + float x26, float x27, float x28, float x29, float x30, float x31) { + return float32x32{x0, x1, x2, x3, x4, x5, x6, x7, + x8, x9, x10, x11, x12, x13, x14, x15, + x16, x17, x18, x19, x20, x21, x22, x23, + x24, x25, x26, x27, x28, x29, x30, x31}; +} + +template TL_DEVICE bool tl_shuffle_elect() { + if constexpr (thread_extent == 0) { + return threadIdx.x == 0; + } else if constexpr (thread_extent <= 32) { + return (threadIdx.x % thread_extent) == 0; + } else if constexpr (thread_extent % 32 == 0) { + return ((threadIdx.x / 32) % (thread_extent / 32)) == 0 && + (threadIdx.x % 32) == 0; + } else { + return (threadIdx.x % thread_extent) == 0; + } +} + // Pack four char values. TL_DEVICE unsigned int make_uint(unsigned char x0, unsigned char x1, unsigned char x2, unsigned char x3) { @@ -196,6 +230,15 @@ TL_DEVICE unsigned __pack_maca_bfloat162(const bfloat16_t x, return (v1 << 16) | v0; } +TL_DEVICE unsigned long long pack_float2(const float x, const float y) { + union { + float2 f; + unsigned long long u64; + } bits; + bits.f = make_float2(x, y); + return bits.u64; +} + template TL_DEVICE void AtomicAdd(T1 *address, T2 val, int memory_order = 0) { (void)memory_order; diff --git a/tilelang/contrib/dlpack.py b/tilelang/contrib/dlpack.py index d80f0fdb..66adbc1e 100644 --- a/tilelang/contrib/dlpack.py +++ b/tilelang/contrib/dlpack.py @@ -38,10 +38,10 @@ def convert_func(tvm_func, tensor_type, to_dlpack_func): import torch float8_dtype_map = { - torch.float8_e4m3fn: "float8_e4m3", + torch.float8_e4m3fn: "float8_e4m3fn", torch.float8_e4m3fnuz: "float8_e4m3fnuz", torch.float8_e5m2: "float8_e5m2", - torch.float8_e5m2fnuz: "float8_e5m2", + torch.float8_e5m2fnuz: "float8_e5m2fnuz", } def adapt_tensor(arg): diff --git a/tilelang/intrinsics/maca_mma_macro_generator.py b/tilelang/intrinsics/maca_mma_macro_generator.py index 1ea41388..f621ec37 100644 --- a/tilelang/intrinsics/maca_mma_macro_generator.py +++ b/tilelang/intrinsics/maca_mma_macro_generator.py @@ -57,6 +57,14 @@ class TensorCoreIntrinEmitter: k_pack = 1 # Represent the thread binding in the form of (tx, warp_n, warp_m) is_m_first = False + fp8_dtypes = { + "float8_e4m3", + "float8_e5m2", + "float8_e4m3fn", + "float8_e5m2fn", + "float8_e4m3fnuz", + "float8_e5m2fnuz", + } def __init__( self, @@ -88,6 +96,7 @@ def __init__( self.warp_row_tiles = warp_row_tiles self.warp_col_tiles = warp_col_tiles self.chunk = chunk + self.mma_input_dtype = self._resolve_mma_input_dtype(a_dtype) self._initialize_k_dim(a_dtype) self._initialize_abbrev(a_dtype, b_dtype, accum_dtype) self._initialize_local_size(self.M_DIM, self.N_DIM, self.k_dim, self.WARP_SIZE) @@ -106,8 +115,6 @@ def __init__( def _initialize_k_dim(self, a_dtype=T.float16): if isinstance(a_dtype, str): - if a_dtype in ["float8_e4m3fnuz", "float8_e5m2fnuz"]: - return a_dtype = DataType(a_dtype) if a_dtype.bits == 32: @@ -130,13 +137,21 @@ def _dtype_abbrv_lookup(self, dtype): raise KeyError(f"Unsupported dtype for MACA MMA: {dtype!r}") return self.dtype_abbrv[s] + def _resolve_mma_input_dtype(self, dtype): + s = str(dtype) + if s.startswith("dtype('") and s.endswith("')"): + s = s[7:-2] + if s in self.fp8_dtypes: + return T.float16 + return dtype + def _initialize_abbrev(self, a_dtype, b_dtype, accum_dtype): self.a_dtype_abbrv = self._dtype_abbrv_lookup(a_dtype) self.b_dtype_abbrv = self._dtype_abbrv_lookup(b_dtype) self.accum_dtype_abbrv = self._dtype_abbrv_lookup(accum_dtype) def _initialize_mma_prefix(self, k_dim=16): - in_dtype = self.a_dtype + in_dtype = self.mma_input_dtype M_DIM, N_DIM = self.M_DIM, self.N_DIM in_dtype_key = str(in_dtype) @@ -148,6 +163,8 @@ def _initialize_mma_prefix(self, k_dim=16): "float32": "f32", "int8": "i8", "int32": "i32", + "float8_e4m3": "f16", + "float8_e5m2": "f16", "float8_e4m3fnuz": "fp8", "float8_e5m2fnuz": "fp8", "float8_e4m3fn": "fp8", @@ -281,6 +298,7 @@ def ldmatrix_a(self, A_local_buf, A_shared_buf: Buffer | BufferRegion, ki, rk=0) # legalize shared buffer to region A_region = self._legalize_to_buffer_region(A_shared_buf) A_buf = A_region.buffer + A_prefix = [region.min for region in A_region.region[:-2]] A_base0 = A_region.region[-2].min A_base1 = A_region.region[-1].min @@ -298,13 +316,17 @@ def _warp_ldmatrix_a( for local_id in T.vectorized(k_pack * local_size_a): row, col = T.meta_var(reverse_index_map(tx, local_id)) l, r = (rk * chunk + ki * (k_pack * micro_size_k), warp_m * warp_row_tiles + i * micro_size_x) - A_local_buf[i * k_pack * local_size_a + local_id] = A_buf[A_base0 + l + row, A_base1 + r + col] + A_local_buf[i * k_pack * local_size_a + local_id] = A_buf[ + tuple(A_prefix + [A_base0 + l + row, A_base1 + r + col]) + ] else: for i in T.serial(warp_rows): for local_id in T.vectorized(k_pack * local_size_a): row, col = T.meta_var(reverse_index_map(tx, local_id)) l, r = (warp_m * warp_row_tiles + i * micro_size_x, rk * chunk + ki * (k_pack * micro_size_k)) - A_local_buf[i * k_pack * local_size_a + local_id] = A_buf[A_base0 + l + row, A_base1 + r + col] + A_local_buf[i * k_pack * local_size_a + local_id] = A_buf[ + tuple(A_prefix + [A_base0 + l + row, A_base1 + r + col]) + ] return _warp_ldmatrix_a(A_local_buf, A_shared_buf, ki, thread_binding, rk) @@ -323,6 +345,7 @@ def ldmatrix_b(self, B_local_buf, B_shared_buf: Buffer | BufferRegion, ki, rk=0) # legalize shared buffer to region B_region = self._legalize_to_buffer_region(B_shared_buf) B_buf = B_region.buffer + B_prefix = [region.min for region in B_region.region[:-2]] B_base0 = B_region.region[-2].min B_base1 = B_region.region[-1].min @@ -343,7 +366,9 @@ def _warp_ldmatrix_b( warp_n * warp_col_tiles + j * micro_size_y, rk * chunk + ki * (k_pack * micro_size_k), ) - B_local_buf[j * k_pack * local_size_b + local_id] = B_buf[B_base0 + l + row, B_base1 + r + col] + B_local_buf[j * k_pack * local_size_b + local_id] = B_buf[ + tuple(B_prefix + [B_base0 + l + row, B_base1 + r + col]) + ] else: for j in T.serial(warp_cols): @@ -353,7 +378,9 @@ def _warp_ldmatrix_b( rk * chunk + ki * (k_pack * micro_size_k), warp_n * warp_col_tiles + j * micro_size_y, ) - B_local_buf[j * k_pack * local_size_b + local_id] = B_buf[B_base0 + l + row, B_base1 + r + col] + B_local_buf[j * k_pack * local_size_b + local_id] = B_buf[ + tuple(B_prefix + [B_base0 + l + row, B_base1 + r + col]) + ] return _warp_ldmatrix_b(B_local_buf, B_shared_buf, ki, thread_binding, rk) @@ -366,8 +393,9 @@ def mma(self, A_local_buf: Buffer, B_local_buf: Buffer, C_local_buf: Buffer, k_i k_pack = self.k_pack mma_suffix = self.mma_suffix a_dtype, b_dtype, out_dtype = self.a_dtype, self.b_dtype, self.accum_dtype - compute_a_dtype = a_dtype if local_size_a == 1 else f"{a_dtype}x{local_size_a}" - compute_b_dtype = b_dtype if local_size_b == 1 else f"{b_dtype}x{local_size_b}" + mma_input_dtype = self.mma_input_dtype + compute_a_dtype = mma_input_dtype if local_size_a == 1 else f"{mma_input_dtype}x{local_size_a}" + compute_b_dtype = mma_input_dtype if local_size_b == 1 else f"{mma_input_dtype}x{local_size_b}" compute_out_dtype = out_dtype if local_size_out == 1 else f"{out_dtype}x{local_size_out}" a_is_fragment = is_fragment(A_local_buf) diff --git a/tilelang/jit/adapter/tvm_ffi.py b/tilelang/jit/adapter/tvm_ffi.py index 63147954..87374ac5 100644 --- a/tilelang/jit/adapter/tvm_ffi.py +++ b/tilelang/jit/adapter/tvm_ffi.py @@ -176,6 +176,29 @@ def _convert_torch_func(self) -> Callable[..., Any]: dynamic_symbolic_map = self._process_dynamic_symbolic() executable = self.executable + float8_dtype_map = { + getattr(torch, "float8_e4m3fn", None): "float8_e4m3fn", + getattr(torch, "float8_e4m3fnuz", None): "float8_e4m3fnuz", + getattr(torch, "float8_e5m2", None): "float8_e5m2", + getattr(torch, "float8_e5m2fnuz", None): "float8_e5m2fnuz", + } + float8_dtype_map = {k: v for k, v in float8_dtype_map.items() if k is not None} + + def adapt_tensor_for_tvm(arg: torch.Tensor | Any): + if not isinstance(arg, torch.Tensor): + return arg + + float8_dtype = float8_dtype_map.get(arg.dtype) + if float8_dtype is None: + return arg + + # tvm_ffi cannot ingest float8 tensors directly via DLPack today. + # Reuse the existing float8 bridge pattern: pass the storage as int8 + # and recover the logical dtype through a TVM tensor view. + return runtime.from_dlpack(torch.utils.dlpack.to_dlpack(arg.view(torch.int8)))._create_view( + arg.shape, dtype=float8_dtype + ) + # Prepare helpers for friendly dtype error messages prim_func = self.prim_func buffer_map = prim_func.buffer_map @@ -241,7 +264,7 @@ def func(*inputs: torch.Tensor | Any): ins_idx += 1 tensor_list.append(tensor) - executable(*tensor_list) + executable(*(adapt_tensor_for_tvm(tensor) for tensor in tensor_list)) # Return outputs in the requested form if len(self.result_idx) == 1: diff --git a/tilelang/quantize/quantization.py b/tilelang/quantize/quantization.py index 74a545f2..6f84317c 100644 --- a/tilelang/quantize/quantization.py +++ b/tilelang/quantize/quantization.py @@ -63,9 +63,10 @@ def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale e_f4 = (f4 & tir.const(6, T.uint16)) >> tir.const(1, T.uint16) # Exponential bias between f4 and bf16 is 2^(8-1) - 2^(2-1) = 126 e_bf16 = e_f4 + tir.const(126, T.uint16) - # Scale is the exponential part, within the representation of uint8 - # To handle the overflow, we use the max function to limit the exponential part to 8 bits - e_bf16 = min(e_bf16 + scale, tir.const((1 << 8) - 1, T.uint16)) + # Scale is the exponent offset stored as uint8. Clamp the adjusted exponent to bf16 range. + tir_u16_max = tir.const((1 << 8) - 1, T.uint16) + scaled_e_bf16 = e_bf16 + tir.Cast(T.uint16, scale) + e_bf16 = tir.Select(scaled_e_bf16 > tir_u16_max, tir_u16_max, scaled_e_bf16) m_f4 = f4 & tir.const(1, T.uint16) val_bf16 = tir.reinterpret(T.bfloat16, ((((s << tir.const(8, T.uint16)) | e_bf16) << tir.const(7, T.uint16)) diff --git a/tilelang/tileop/gemm/gemm_maca_mma.py b/tilelang/tileop/gemm/gemm_maca_mma.py index 5887c957..e7b3fd9b 100644 --- a/tilelang/tileop/gemm/gemm_maca_mma.py +++ b/tilelang/tileop/gemm/gemm_maca_mma.py @@ -86,6 +86,7 @@ def lower( ) in_dtype = self.in_dtype + mma_input_dtype = mma_emitter.mma_input_dtype warp_rows = mma_emitter.warp_rows warp_cols = mma_emitter.warp_cols local_size_a = mma_emitter.local_size_a @@ -117,8 +118,8 @@ def _gemm_ssr() -> None: B_shared into local fragments, then issues Tensor Core mma ops, accumulating into C_local. """ - A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) - B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) + A_local = T.alloc_local((warp_rows * local_size_a), mma_input_dtype) + B_local = T.alloc_local((warp_cols * local_size_b), mma_input_dtype) if clear_accum: T.clear(C_buf) for ki in T.serial(0, (block_K // micro_size_k)): @@ -152,7 +153,7 @@ def _gemm_srr() -> None: B_shared into local fragments, then issues Tensor Core mma ops, accumulating into C_local. """ - A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) + A_local = T.alloc_local((warp_rows * local_size_a), mma_input_dtype) for ki in T.serial(0, (block_K // micro_size_k)): if clear_accum: @@ -182,7 +183,7 @@ def _gemm_rsr() -> None: B_shared into local fragments, then issues Tensor Core mma ops, accumulating into C_local. """ - B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) + B_local = T.alloc_local((warp_cols * local_size_b), mma_input_dtype) if clear_accum: T.clear(C_buf) for ki in T.serial(0, (block_K // micro_size_k)): From bed61301888911485c6f717b8071a3232ede74ab Mon Sep 17 00:00:00 2001 From: VitalyR <24508452+VitalyAnkh@users.noreply.github.com> Date: Thu, 9 Apr 2026 20:14:59 +0000 Subject: [PATCH 2/3] [MetaxGPU][MACA] Apply pre-commit fixes for backend support --- src/target/codegen_maca.cc | 12 +++++----- src/tl_templates/maca/common.h | 24 +++++++++---------- .../intrinsics/maca_mma_macro_generator.py | 18 ++++---------- tilelang/jit/adapter/tvm_ffi.py | 4 +--- tilelang/tileop/gemm/gemm_maca_mma.py | 1 - 5 files changed, 24 insertions(+), 35 deletions(-) diff --git a/src/target/codegen_maca.cc b/src/target/codegen_maca.cc index c28de004..b655ced9 100644 --- a/src/target/codegen_maca.cc +++ b/src/target/codegen_maca.cc @@ -2494,9 +2494,9 @@ void CodeGenTileLangMACA::VisitExpr_(const BroadcastNode *op, } else { std::string scalar = PrintExpr(op->value); std::string byte = "((" + scalar + ") & 0xFF)"; - os << "(" << (op->dtype.is_uint() ? "uint" : "int") << ")(" - << byte << " | (" << byte << " << 8) | (" << byte << " << 16) | (" - << byte << " << 24))"; + os << "(" << (op->dtype.is_uint() ? "uint" : "int") << ")(" << byte + << " | (" << byte << " << 8) | (" << byte << " << 16) | (" << byte + << " << 24))"; } return; } else if (lanes == 8 || lanes == 16) { @@ -2544,9 +2544,9 @@ void CodeGenTileLangMACA::VisitExpr_(const BroadcastNode *op, std::string byte = "((" + scalar + ") & 0xFF)"; std::string packed32 = "(" + byte + " | (" + byte + " << 8) | (" + byte + " << 16) | (" + byte + " << 24))"; - std::string packed64 = - "(((unsigned long long)" + packed32 + - ") | (((unsigned long long)" + packed32 + ") << 32))"; + std::string packed64 = "(((unsigned long long)" + packed32 + + ") | (((unsigned long long)" + packed32 + + ") << 32))"; if (op->dtype.is_uint()) { os << "make_ulonglong4(" << packed64 << ", " << packed64 << ", " << packed64 << ", " << packed64 << ")"; diff --git a/src/tl_templates/maca/common.h b/src/tl_templates/maca/common.h index ed40b9f5..89727e19 100644 --- a/src/tl_templates/maca/common.h +++ b/src/tl_templates/maca/common.h @@ -166,24 +166,24 @@ using float32x32 = __attribute__((__vector_size__(32 * sizeof(float)))) float; using float64x4 = __attribute__((__vector_size__(4 * sizeof(double)))) double; using int8x4 = __attribute__((__vector_size__(4 * sizeof(int8_t)))) int8_t; -TL_DEVICE float32x16 make_float32x16( - float x0, float x1, float x2, float x3, float x4, float x5, float x6, - float x7, float x8, float x9, float x10, float x11, float x12, float x13, - float x14, float x15) { - return float32x16{x0, x1, x2, x3, x4, x5, x6, x7, +TL_DEVICE float32x16 make_float32x16(float x0, float x1, float x2, float x3, + float x4, float x5, float x6, float x7, + float x8, float x9, float x10, float x11, + float x12, float x13, float x14, + float x15) { + return float32x16{x0, x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, x11, x12, x13, x14, x15}; } TL_DEVICE float32x32 make_float32x32( float x0, float x1, float x2, float x3, float x4, float x5, float x6, float x7, float x8, float x9, float x10, float x11, float x12, float x13, - float x14, float x15, float x16, float x17, float x18, float x19, - float x20, float x21, float x22, float x23, float x24, float x25, - float x26, float x27, float x28, float x29, float x30, float x31) { - return float32x32{x0, x1, x2, x3, x4, x5, x6, x7, - x8, x9, x10, x11, x12, x13, x14, x15, - x16, x17, x18, x19, x20, x21, x22, x23, - x24, x25, x26, x27, x28, x29, x30, x31}; + float x14, float x15, float x16, float x17, float x18, float x19, float x20, + float x21, float x22, float x23, float x24, float x25, float x26, float x27, + float x28, float x29, float x30, float x31) { + return float32x32{x0, x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, + x11, x12, x13, x14, x15, x16, x17, x18, x19, x20, x21, + x22, x23, x24, x25, x26, x27, x28, x29, x30, x31}; } template TL_DEVICE bool tl_shuffle_elect() { diff --git a/tilelang/intrinsics/maca_mma_macro_generator.py b/tilelang/intrinsics/maca_mma_macro_generator.py index f621ec37..4e3e0b7f 100644 --- a/tilelang/intrinsics/maca_mma_macro_generator.py +++ b/tilelang/intrinsics/maca_mma_macro_generator.py @@ -316,17 +316,13 @@ def _warp_ldmatrix_a( for local_id in T.vectorized(k_pack * local_size_a): row, col = T.meta_var(reverse_index_map(tx, local_id)) l, r = (rk * chunk + ki * (k_pack * micro_size_k), warp_m * warp_row_tiles + i * micro_size_x) - A_local_buf[i * k_pack * local_size_a + local_id] = A_buf[ - tuple(A_prefix + [A_base0 + l + row, A_base1 + r + col]) - ] + A_local_buf[i * k_pack * local_size_a + local_id] = A_buf[tuple(A_prefix + [A_base0 + l + row, A_base1 + r + col])] else: for i in T.serial(warp_rows): for local_id in T.vectorized(k_pack * local_size_a): row, col = T.meta_var(reverse_index_map(tx, local_id)) l, r = (warp_m * warp_row_tiles + i * micro_size_x, rk * chunk + ki * (k_pack * micro_size_k)) - A_local_buf[i * k_pack * local_size_a + local_id] = A_buf[ - tuple(A_prefix + [A_base0 + l + row, A_base1 + r + col]) - ] + A_local_buf[i * k_pack * local_size_a + local_id] = A_buf[tuple(A_prefix + [A_base0 + l + row, A_base1 + r + col])] return _warp_ldmatrix_a(A_local_buf, A_shared_buf, ki, thread_binding, rk) @@ -366,9 +362,7 @@ def _warp_ldmatrix_b( warp_n * warp_col_tiles + j * micro_size_y, rk * chunk + ki * (k_pack * micro_size_k), ) - B_local_buf[j * k_pack * local_size_b + local_id] = B_buf[ - tuple(B_prefix + [B_base0 + l + row, B_base1 + r + col]) - ] + B_local_buf[j * k_pack * local_size_b + local_id] = B_buf[tuple(B_prefix + [B_base0 + l + row, B_base1 + r + col])] else: for j in T.serial(warp_cols): @@ -378,9 +372,7 @@ def _warp_ldmatrix_b( rk * chunk + ki * (k_pack * micro_size_k), warp_n * warp_col_tiles + j * micro_size_y, ) - B_local_buf[j * k_pack * local_size_b + local_id] = B_buf[ - tuple(B_prefix + [B_base0 + l + row, B_base1 + r + col]) - ] + B_local_buf[j * k_pack * local_size_b + local_id] = B_buf[tuple(B_prefix + [B_base0 + l + row, B_base1 + r + col])] return _warp_ldmatrix_b(B_local_buf, B_shared_buf, ki, thread_binding, rk) @@ -392,7 +384,7 @@ def mma(self, A_local_buf: Buffer, B_local_buf: Buffer, C_local_buf: Buffer, k_i local_size_out = self.local_size_out k_pack = self.k_pack mma_suffix = self.mma_suffix - a_dtype, b_dtype, out_dtype = self.a_dtype, self.b_dtype, self.accum_dtype + out_dtype = self.accum_dtype mma_input_dtype = self.mma_input_dtype compute_a_dtype = mma_input_dtype if local_size_a == 1 else f"{mma_input_dtype}x{local_size_a}" compute_b_dtype = mma_input_dtype if local_size_b == 1 else f"{mma_input_dtype}x{local_size_b}" diff --git a/tilelang/jit/adapter/tvm_ffi.py b/tilelang/jit/adapter/tvm_ffi.py index 87374ac5..2edfabf8 100644 --- a/tilelang/jit/adapter/tvm_ffi.py +++ b/tilelang/jit/adapter/tvm_ffi.py @@ -195,9 +195,7 @@ def adapt_tensor_for_tvm(arg: torch.Tensor | Any): # tvm_ffi cannot ingest float8 tensors directly via DLPack today. # Reuse the existing float8 bridge pattern: pass the storage as int8 # and recover the logical dtype through a TVM tensor view. - return runtime.from_dlpack(torch.utils.dlpack.to_dlpack(arg.view(torch.int8)))._create_view( - arg.shape, dtype=float8_dtype - ) + return runtime.from_dlpack(torch.utils.dlpack.to_dlpack(arg.view(torch.int8)))._create_view(arg.shape, dtype=float8_dtype) # Prepare helpers for friendly dtype error messages prim_func = self.prim_func diff --git a/tilelang/tileop/gemm/gemm_maca_mma.py b/tilelang/tileop/gemm/gemm_maca_mma.py index e7b3fd9b..3a5b7234 100644 --- a/tilelang/tileop/gemm/gemm_maca_mma.py +++ b/tilelang/tileop/gemm/gemm_maca_mma.py @@ -85,7 +85,6 @@ def lower( thread_var=thread_var, ) - in_dtype = self.in_dtype mma_input_dtype = mma_emitter.mma_input_dtype warp_rows = mma_emitter.warp_rows warp_cols = mma_emitter.warp_cols From b72cfe6c47f2fdad020812195fdda2d05a384f1f Mon Sep 17 00:00:00 2001 From: VitalyR <24508452+VitalyAnkh@users.noreply.github.com> Date: Thu, 9 Apr 2026 19:10:43 +0000 Subject: [PATCH 3/3] [MetaxGPU][Examples] Tune resource-bound MACA examples --- .../example_gqa_sink_bwd_varlen.py | 3 + .../example_gqa_sink_fwd_varlen.py | 21 ++++++- .../test_example_attention_sink.py | 3 - .../maca/convolution/example_convolution.py | 6 +- .../convolution/test_example_convolution.py | 2 - .../maca/deepseek_mla/example_mla_decode.py | 24 +++++--- .../deepseek_mla/test_example_mla_decode.py | 1 - .../example_gqa_bwd_tma_reduce_varlen.py | 34 ++++++++--- .../test_example_flash_attention.py | 1 - examples/maca/gdn/example_chunk_delta_bwd.py | 60 +++++++++++-------- examples/maca/gdn/example_chunk_delta_h.py | 23 +++---- examples/maca/gdn/example_chunk_o.py | 11 ++-- examples/maca/gdn/example_chunk_o_bwd.py | 29 ++++----- .../maca/gdn/example_chunk_scaled_dot_kkt.py | 7 ++- examples/maca/gdn/example_cumsum.py | 5 +- examples/maca/gdn/example_wy_fast.py | 11 ++-- .../maca/gdn/example_wy_fast_bwd_split.py | 29 ++++----- .../maca/gdn/test_example_gdn_compilation.py | 6 -- 18 files changed, 163 insertions(+), 113 deletions(-) diff --git a/examples/maca/attention_sink/example_gqa_sink_bwd_varlen.py b/examples/maca/attention_sink/example_gqa_sink_bwd_varlen.py index 64a5a39a..63ae707c 100644 --- a/examples/maca/attention_sink/example_gqa_sink_bwd_varlen.py +++ b/examples/maca/attention_sink/example_gqa_sink_bwd_varlen.py @@ -2,6 +2,7 @@ import tilelang from tilelang.profiler import do_bench import tilelang.language as T +from tilelang.utils.target import determine_target, target_is_maca import argparse from typing import Optional import sys @@ -12,6 +13,8 @@ def get_bwd_configs(): + if target_is_maca(determine_target("auto", return_object=True)): + return 32, 16, 1, 128 sm_major, sm_minor = torch.cuda.get_device_capability() sm_version = sm_major * 10 + sm_minor if sm_version == 80: diff --git a/examples/maca/attention_sink/example_gqa_sink_fwd_varlen.py b/examples/maca/attention_sink/example_gqa_sink_fwd_varlen.py index 16838dd8..817e5655 100644 --- a/examples/maca/attention_sink/example_gqa_sink_fwd_varlen.py +++ b/examples/maca/attention_sink/example_gqa_sink_fwd_varlen.py @@ -7,6 +7,7 @@ import tilelang.language as T import tilelang.testing from tilelang.profiler import do_bench +from tilelang.utils.target import determine_target, target_is_maca from typing import Optional import sys import os @@ -15,6 +16,12 @@ from varlen_utils import generate_random_padding_mask, generate_qkv +def get_fwd_configs(): + if target_is_maca(determine_target("auto", return_object=True)): + return 64, 32, 1, 128 + return 128, 128, 2, 256 + + @tilelang.jit( out_idx=[7], pass_configs={ @@ -352,8 +359,20 @@ def main( UQ = q_unpad.shape[0] UKV = k_unpad.shape[0] + block_M, block_N, num_stages, threads = get_fwd_configs() kernel = flashattn_sink( - batch, groups, UQ, UKV, heads, dim, is_causal, window_size=window_size, block_M=128, block_N=128, num_stages=2, threads=256 + batch, + groups, + UQ, + UKV, + heads, + dim, + is_causal, + window_size=window_size, + block_M=block_M, + block_N=block_N, + num_stages=num_stages, + threads=threads, ) out_unpad = kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, sinks) diff --git a/examples/maca/attention_sink/test_example_attention_sink.py b/examples/maca/attention_sink/test_example_attention_sink.py index d1e4db65..7b6bde88 100644 --- a/examples/maca/attention_sink/test_example_attention_sink.py +++ b/examples/maca/attention_sink/test_example_attention_sink.py @@ -15,12 +15,10 @@ def test_example_mha_sink_fwd_bhsd_sliding_window(): example_mha_sink_fwd_bhsd.main(window_size=128) -@tilelang.testing.pytest.mark.xfail def test_example_mha_sink_bwd_bhsd(): example_mha_sink_bwd_bhsd.main() -@tilelang.testing.pytest.mark.xfail def test_example_mha_sink_bwd_bhsd_sliding_window(): example_mha_sink_bwd_bhsd.main(window_size=128) @@ -33,7 +31,6 @@ def test_example_gqa_sink_bwd_bhsd_sliding_window(): example_gqa_sink_bwd_bhsd.main(window_size=128) -@tilelang.testing.pytest.mark.xfail def test_example_gqa_sink_varlen(): example_gqa_sink_fwd_varlen.main() # non-causal example_gqa_sink_bwd_varlen.main() # causal diff --git a/examples/maca/convolution/example_convolution.py b/examples/maca/convolution/example_convolution.py index 1599d346..4f71210d 100644 --- a/examples/maca/convolution/example_convolution.py +++ b/examples/maca/convolution/example_convolution.py @@ -83,8 +83,10 @@ def main(argv=None): args = parser.parse_args(argv) N, C, H, W, F, K, S, D, P = args.n, args.c, args.h, args.w, args.f, args.k, args.s, args.d, args.p - a = torch.randn(N, H, W, C).cuda().half() - b = torch.randn(K, K, C, F).cuda().half() + # MACA may default to non-standard 4D strides here; TileLang expects + # dense NHWC/HWCF tensors. + a = torch.randn(N, H, W, C).cuda().half().contiguous() + b = torch.randn(K, K, C, F).cuda().half().contiguous() block_m = 64 block_n = 128 diff --git a/examples/maca/convolution/test_example_convolution.py b/examples/maca/convolution/test_example_convolution.py index 86ba9bce..186b13b2 100644 --- a/examples/maca/convolution/test_example_convolution.py +++ b/examples/maca/convolution/test_example_convolution.py @@ -4,8 +4,6 @@ import example_convolution_autotune -# TODO(@cy): TMA with convolution must be fixed in future. -@tilelang.testing.pytest.mark.xfail def test_example_convolution(): example_convolution.main([]) diff --git a/examples/maca/deepseek_mla/example_mla_decode.py b/examples/maca/deepseek_mla/example_mla_decode.py index d6d76e54..b5b372c5 100644 --- a/examples/maca/deepseek_mla/example_mla_decode.py +++ b/examples/maca/deepseek_mla/example_mla_decode.py @@ -5,6 +5,7 @@ import tilelang.language as T from einops import rearrange, einsum import argparse +from tilelang.utils.target import determine_target, target_is_maca @tilelang.jit( @@ -20,6 +21,9 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ kv_group_num = heads // kv_head_num VALID_BLOCK_H = min(block_H, kv_group_num) assert kv_head_num == 1, "kv_head_num must be 1" + is_maca = target_is_maca(determine_target("auto", return_object=True)) + pipeline_stages = 1 if is_maca else 2 + main_threads = 64 if is_maca else 256 @T.prim_func def main_split( @@ -32,7 +36,7 @@ def main_split( Output: T.Tensor([batch, heads, dim], dtype), ): # flash_attn_split - with T.Kernel(batch, heads // min(block_H, kv_group_num), num_split, threads=256) as (bid, hid, bz): + with T.Kernel(batch, heads // min(block_H, kv_group_num), num_split, threads=main_threads) as (bid, hid, bz): Q_shared = T.alloc_shared([block_H, dim], dtype) S_shared = T.alloc_shared([block_H, block_N], dtype) Q_pe_shared = T.alloc_shared([block_H, pe_dim], dtype) @@ -58,7 +62,7 @@ def main_split( T.fill(scores_max, -T.infinity(accum_dtype)) loop_range = T.ceildiv((seqlen_kv // num_split), block_N) - for k in T.Pipelined(loop_range, num_stages=2): + for k in T.Pipelined(loop_range, num_stages=pipeline_stages): kv_start = (seqlen_kv // num_split) * bz + k * block_N kv_end = (seqlen_kv // num_split) * bz + (k + 1) * block_N T.copy(KV[bid, kv_start:kv_end, cur_kv_head, :], KV_shared) @@ -129,7 +133,7 @@ def main_no_split( Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), Output: T.Tensor([batch, heads, dim], dtype), ): - with T.Kernel(heads // min(block_H, kv_group_num), batch, threads=256) as (hid, bid): + with T.Kernel(heads // min(block_H, kv_group_num), batch, threads=main_threads) as (hid, bid): Q_shared = T.alloc_shared([block_H, dim], dtype) S_shared = T.alloc_shared([block_H, block_N], dtype) Q_pe_shared = T.alloc_shared([block_H, pe_dim], dtype) @@ -153,7 +157,7 @@ def main_no_split( T.fill(scores_max, -T.infinity(accum_dtype)) loop_range = T.ceildiv(seqlen_kv, block_N) - for k in T.Pipelined(loop_range, num_stages=2): + for k in T.Pipelined(loop_range, num_stages=pipeline_stages): T.copy(KV[bid, k * block_N : (k + 1) * block_N, cur_kv_head, :], KV_shared) T.copy(K_pe[bid, k * block_N : (k + 1) * block_N, cur_kv_head, :], K_pe_shared) T.gemm(Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol, clear_accum=True) @@ -232,8 +236,10 @@ def main( qk_flops = 2 * batch * heads * kv_ctx * (dim + pe_dim) pv_flops = 2 * batch * heads * kv_ctx * dim total_flops = qk_flops + pv_flops - BLOCK_N = 64 - BLOCK_H = min(64, heads // kv_heads) + target = determine_target("auto", return_object=True) + is_maca = target_is_maca(target) + BLOCK_N = 16 if is_maca else 64 + BLOCK_H = min(16 if is_maca else 64, heads // kv_heads) num_split = 1 softmax_scale = (dim + pe_dim) ** -0.5 @@ -253,8 +259,10 @@ def run_regression_perf( dim=512, pe_dim=64, ): - BLOCK_N = 64 - BLOCK_H = min(64, heads // kv_heads) + target = determine_target("auto", return_object=True) + is_maca = target_is_maca(target) + BLOCK_N = 16 if is_maca else 64 + BLOCK_H = min(16 if is_maca else 64, heads // kv_heads) num_split = 1 softmax_scale = (dim + pe_dim) ** -0.5 diff --git a/examples/maca/deepseek_mla/test_example_mla_decode.py b/examples/maca/deepseek_mla/test_example_mla_decode.py index 4e4e26ab..c9b7cb7f 100644 --- a/examples/maca/deepseek_mla/test_example_mla_decode.py +++ b/examples/maca/deepseek_mla/test_example_mla_decode.py @@ -4,7 +4,6 @@ import example_mla_decode -@tilelang.testing.pytest.mark.xfail def test_example_mla_decode(): example_mla_decode.main() diff --git a/examples/maca/flash_attention/example_gqa_bwd_tma_reduce_varlen.py b/examples/maca/flash_attention/example_gqa_bwd_tma_reduce_varlen.py index b09eec00..53b289c8 100644 --- a/examples/maca/flash_attention/example_gqa_bwd_tma_reduce_varlen.py +++ b/examples/maca/flash_attention/example_gqa_bwd_tma_reduce_varlen.py @@ -3,6 +3,7 @@ import tilelang import tilelang.language as T from tilelang.contrib import nvcc +from tilelang.utils.target import determine_target, target_is_maca import argparse from einops import rearrange, repeat from bert_padding import pad_input, unpad_input @@ -20,6 +21,22 @@ def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random"): return padding_mask +def is_maca_target(): + return target_is_maca(determine_target("auto", return_object=True)) + + +def get_varlen_fwd_configs(): + if is_maca_target(): + return 64, 32 + return 128, 64 + + +def get_varlen_bwd_configs(use_atomic): + if is_maca_target(): + return 32, 32, 128, 1, False + return 128, 32, 256, 2, use_atomic + + @tilelang.jit( out_idx=[5, 6], pass_configs={ @@ -491,8 +508,7 @@ def forward( ): BATCH, N_CTX, H, D_HEAD_QK = q.shape D_HEAD_V = v.shape[-1] - block_M = 128 - block_N = 64 + block_M, block_N = get_varlen_fwd_configs() q_unpad, indices_q, _, _ = unpad_input(q, (torch.arange(N_CTX, device=q.device).unsqueeze(0) < seqlens_q.unsqueeze(1))) k_unpad, indices_k, _, _ = unpad_input(k, (torch.arange(N_CTX, device=k.device).unsqueeze(0) < seqlens_k.unsqueeze(1))) v_unpad, _, _, _ = unpad_input(v, (torch.arange(N_CTX, device=v.device).unsqueeze(0) < seqlens_k.unsqueeze(1))) @@ -506,6 +522,7 @@ def forward( ctx.save_for_backward(q_unpad, k_unpad, v_unpad, o_unpad, lse, seqlens_q, seqlens_k, cu_seqlens_q, cu_seqlens_k) ctx.batch = BATCH ctx.causal = causal + _, _, _, _, use_atomic = get_varlen_bwd_configs(use_atomic) ctx.use_atomic = use_atomic ctx.max_seqlen_q = max_seqlen_q ctx.max_seqlen_k = max_seqlen_k @@ -530,12 +547,11 @@ def maybe_contiguous(x): return x do, q, k, v, o = [maybe_contiguous(x) for x in (do_unpad, q, k, v, o)] - block_M = 128 - block_N = 32 + block_M, block_N, threads, num_stages, use_atomic = get_varlen_bwd_configs(ctx.use_atomic) mod_prep = flashattn_bwd_preprocess(BATCH, H, total_q, N_CTX, ctx.max_seqlen_q, D_HEAD_V) delta = mod_prep(o, do, cu_seqlens_q) - if ctx.use_atomic: + if use_atomic: kernel = flashattn_bwd_atomic_add( BATCH, total_q, @@ -548,8 +564,8 @@ def maybe_contiguous(x): ctx.causal, block_M, block_N, - threads=256, - num_stages=2, + threads=threads, + num_stages=num_stages, groups=groups, ) dq = torch.zeros_like(q, dtype=torch.float32) @@ -569,8 +585,8 @@ def maybe_contiguous(x): ctx.causal, block_M, block_N, - threads=256, - num_stages=2, + threads=threads, + num_stages=num_stages, groups=groups, ) mod_post = flashattn_bwd_postprocess(total_q, total_kv, H, HEAD_KV, D_HEAD_QK, D_HEAD_V) diff --git a/examples/maca/flash_attention/test_example_flash_attention.py b/examples/maca/flash_attention/test_example_flash_attention.py index f90a6385..6e8bb664 100644 --- a/examples/maca/flash_attention/test_example_flash_attention.py +++ b/examples/maca/flash_attention/test_example_flash_attention.py @@ -11,7 +11,6 @@ import example_gqa_fwd_varlen -@tilelang.testing.pytest.mark.xfail def test_example_gqa_bwd_tma_reduce_varlen(): example_gqa_bwd_tma_reduce_varlen.main() diff --git a/examples/maca/gdn/example_chunk_delta_bwd.py b/examples/maca/gdn/example_chunk_delta_bwd.py index 466c4718..eda3e541 100644 --- a/examples/maca/gdn/example_chunk_delta_bwd.py +++ b/examples/maca/gdn/example_chunk_delta_bwd.py @@ -5,6 +5,7 @@ import tilelang import tilelang.language as T from tilelang.profiler import do_bench +from tilelang.utils.target import determine_target, target_is_maca print(tilelang.__file__, flush=True) @@ -16,9 +17,10 @@ print(fla.__file__, flush=True) from fla.ops.common.chunk_delta_h import chunk_gated_delta_rule_bwd_dhu -except ImportError: - print("fla not found, using tilelang implementation") +except Exception as exc: + print(f"fla unavailable, using tilelang implementation: {exc}") fla = None + chunk_gated_delta_rule_bwd_dhu = None import torch import torch.nn.functional as F @@ -42,10 +44,10 @@ def prepare_input( gate_dtype, state_dtype, ): - Q = torch.randn(B, S, H, DK, dtype=input_dtype).cuda() - K = torch.randn(B, S, H, DK, dtype=input_dtype).cuda() - K = F.normalize(K, dim=-1, p=2) - W = torch.randn(B, S, H, DK, dtype=input_dtype).cuda() + Q = torch.randn(B, S, H, DK, dtype=input_dtype).cuda().contiguous() + K = torch.randn(B, S, H, DK, dtype=input_dtype).cuda().contiguous() + K = F.normalize(K, dim=-1, p=2).contiguous() + W = torch.randn(B, S, H, DK, dtype=input_dtype).cuda().contiguous() # Note: G should be in logspace and do chunkwise cumsum G = torch.randn(B, S, H, dtype=gate_dtype).cuda() G = F.logsigmoid(G) @@ -53,13 +55,13 @@ def prepare_input( from fla.ops.utils.cumsum import chunk_local_cumsum G = chunk_local_cumsum(G, chunk_size) - except ImportError: - print("fla not found, skip cumsum") + except Exception as exc: + print(f"fla unavailable, skip cumsum: {exc}") - h0 = torch.randn(B, H, DK, DV, dtype=input_dtype).cuda() - dht = torch.randn(B, H, DK, DV, dtype=input_dtype).cuda() - dO = torch.randn(B, S, H, DV, dtype=input_dtype).cuda() - dv = torch.randn(B, S, H, DV, dtype=input_dtype).cuda() + h0 = torch.randn(B, H, DK, DV, dtype=input_dtype).cuda().contiguous() + dht = torch.randn(B, H, DK, DV, dtype=input_dtype).cuda().contiguous() + dO = torch.randn(B, S, H, DV, dtype=input_dtype).cuda().contiguous() + dv = torch.randn(B, S, H, DV, dtype=input_dtype).cuda().contiguous() return Q, K, W, G, h0, dht, dO, dv @@ -76,14 +78,14 @@ def prepare_input_fake( gate_dtype, state_dtype, ): - Q = torch.ones(B, S, H, DK, dtype=input_dtype).cuda() - K = torch.ones(B, S, H, DK, dtype=input_dtype).cuda() - W = torch.ones(B, S, H, DK, dtype=input_dtype).cuda() + Q = torch.ones(B, S, H, DK, dtype=input_dtype).cuda().contiguous() + K = torch.ones(B, S, H, DK, dtype=input_dtype).cuda().contiguous() + W = torch.ones(B, S, H, DK, dtype=input_dtype).cuda().contiguous() G = torch.ones(B, S, H, dtype=gate_dtype).cuda() - h0 = torch.ones(B, H, DK, DV, dtype=input_dtype).cuda() - dht = torch.ones(B, H, DK, DV, dtype=input_dtype).cuda() - dO = torch.ones(B, S, H, DV, dtype=input_dtype).cuda() - dv = torch.ones(B, S, H, DV, dtype=input_dtype).cuda() + h0 = torch.ones(B, H, DK, DV, dtype=input_dtype).cuda().contiguous() + dht = torch.ones(B, H, DK, DV, dtype=input_dtype).cuda().contiguous() + dO = torch.ones(B, S, H, DV, dtype=input_dtype).cuda().contiguous() + dv = torch.ones(B, S, H, DV, dtype=input_dtype).cuda().contiguous() return Q, K, W, G, h0, dht, dO, dv @@ -206,6 +208,10 @@ def tilelang_chunk_gated_delta_rule_bwd_dhu( threads=256, num_stages=0, ): + is_maca = target_is_maca(determine_target("auto", return_object=True)) + if is_maca: + block_DV = min(block_DV, 16) + block_S = chunk_size # Should support cu_seqlen BS = S // block_S @@ -265,14 +271,16 @@ def kernel( Q_fragment = T.alloc_fragment((block_S, DK), dtype=accum_dtype) Q_fragment_t = T.alloc_fragment((DK, block_S), dtype=accum_dtype) - T.use_swizzle(10) + if not is_maca: + T.use_swizzle(10) - T.annotate_layout( - { - dO_shared: tilelang.layout.make_swizzled_layout(dO_shared), - Q_shared: tilelang.layout.make_swizzled_layout(Q_shared), - } - ) + if not is_maca: + T.annotate_layout( + { + dO_shared: tilelang.layout.make_swizzled_layout(dO_shared), + Q_shared: tilelang.layout.make_swizzled_layout(Q_shared), + } + ) if use_final_state_gradient: T.copy(dht[bb, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV], b_dh_shared) diff --git a/examples/maca/gdn/example_chunk_delta_h.py b/examples/maca/gdn/example_chunk_delta_h.py index c34d9b53..71a0e5fb 100644 --- a/examples/maca/gdn/example_chunk_delta_h.py +++ b/examples/maca/gdn/example_chunk_delta_h.py @@ -14,9 +14,10 @@ print(fla.__file__) from fla.ops.common.chunk_delta_h import chunk_gated_delta_rule_fwd_h -except ImportError: - print("fla not found, using tilelang implementation") +except Exception as exc: + print(f"fla unavailable, using tilelang implementation: {exc}") fla = None + chunk_gated_delta_rule_fwd_h = None import torch import torch.nn.functional as F @@ -48,22 +49,22 @@ def prepare_input( accum_dtype, gate_dtype, ): - K = torch.randn(B, S, H, DK, dtype=input_dtype).cuda() - K = F.normalize(K, dim=-1, p=2) - W = torch.randn(B, S, H, DK, dtype=input_dtype).cuda() - W = F.normalize(W, dim=-1, p=2) - U = torch.randn(B, S, H, DV, dtype=input_dtype).cuda() - U = F.normalize(U, dim=-1, p=2) + K = torch.randn(B, S, H, DK, dtype=input_dtype).cuda().contiguous() + K = F.normalize(K, dim=-1, p=2).contiguous() + W = torch.randn(B, S, H, DK, dtype=input_dtype).cuda().contiguous() + W = F.normalize(W, dim=-1, p=2).contiguous() + U = torch.randn(B, S, H, DV, dtype=input_dtype).cuda().contiguous() + U = F.normalize(U, dim=-1, p=2).contiguous() G = torch.randn(B, S, H, dtype=gate_dtype).cuda() G = F.logsigmoid(G) try: from fla.ops.utils.cumsum import chunk_local_cumsum G = chunk_local_cumsum(G, chunk_size) - except ImportError: - print("fla not found, skip cumsum") + except Exception as exc: + print(f"fla unavailable, skip cumsum: {exc}") - initial_state = torch.randn(B, H, DK, DV, dtype=input_dtype).cuda() + initial_state = torch.randn(B, H, DK, DV, dtype=input_dtype).cuda().contiguous() return K, W, U, G, initial_state diff --git a/examples/maca/gdn/example_chunk_o.py b/examples/maca/gdn/example_chunk_o.py index bb95f555..dd78f343 100644 --- a/examples/maca/gdn/example_chunk_o.py +++ b/examples/maca/gdn/example_chunk_o.py @@ -12,9 +12,10 @@ print(fla.__file__) from fla.ops.common.chunk_o import chunk_fwd_o -except ImportError: - print("fla not found, using tilelang implementation") +except Exception as exc: + print(f"fla unavailable, using tilelang implementation: {exc}") fla = None + chunk_fwd_o = None import torch @@ -34,9 +35,9 @@ def prepare_input( gate_dtype, ): BS = chunk_size - Q = torch.randn(B, S, H, DK, dtype=input_dtype).cuda() - K = torch.randn(B, S, H, DK, dtype=input_dtype).cuda() - V = torch.randn(B, S, H, DV, dtype=input_dtype).cuda() + Q = torch.randn(B, S, H, DK, dtype=input_dtype).cuda().contiguous() + K = torch.randn(B, S, H, DK, dtype=input_dtype).cuda().contiguous() + V = torch.randn(B, S, H, DV, dtype=input_dtype).cuda().contiguous() HIDDEN = torch.randn(B, S // BS, H, DK, DV, dtype=input_dtype).cuda() G = torch.randn(B, S, H, dtype=gate_dtype).cuda() return Q, K, V, HIDDEN, G diff --git a/examples/maca/gdn/example_chunk_o_bwd.py b/examples/maca/gdn/example_chunk_o_bwd.py index 19233de6..d0744f1d 100644 --- a/examples/maca/gdn/example_chunk_o_bwd.py +++ b/examples/maca/gdn/example_chunk_o_bwd.py @@ -15,9 +15,10 @@ print(fla.__file__) from fla.ops.common.chunk_o import chunk_bwd_dqkwg -except ImportError: - print("fla not found, using tilelang implementation") +except Exception as exc: + print(f"fla unavailable, using tilelang implementation: {exc}") fla = None + chunk_bwd_dqkwg = None import torch from test_utils import assert_similar @@ -40,15 +41,15 @@ def prepare_input_fake( state_dtype, ): BS = S // chunk_size - Q = torch.ones(B, S, H, DK, dtype=input_dtype).cuda() - K = torch.ones(B, S, H, DK, dtype=input_dtype).cuda() - V = torch.ones(B, S, H, DV, dtype=input_dtype).cuda() + Q = torch.ones(B, S, H, DK, dtype=input_dtype).cuda().contiguous() + K = torch.ones(B, S, H, DK, dtype=input_dtype).cuda().contiguous() + V = torch.ones(B, S, H, DV, dtype=input_dtype).cuda().contiguous() h = torch.ones(B, BS, H, DK, DV, dtype=input_dtype).cuda() G = torch.ones(B, S, H, dtype=gate_dtype).cuda() - dO = torch.ones(B, S, H, DV, dtype=input_dtype).cuda() + dO = torch.ones(B, S, H, DV, dtype=input_dtype).cuda().contiguous() dh = torch.ones(B, BS, H, DK, DV, dtype=input_dtype).cuda() - dv = torch.ones(B, S, H, DV, dtype=output_dtype).cuda() - W = torch.ones(B, S, H, DK, dtype=input_dtype).cuda() + dv = torch.ones(B, S, H, DV, dtype=output_dtype).cuda().contiguous() + W = torch.ones(B, S, H, DK, dtype=input_dtype).cuda().contiguous() return Q, K, V, h, G, dO, dh, dv, W @@ -67,15 +68,15 @@ def prepare_input( ): BS = S // chunk_size - Q = torch.randn(B, S, H, DK, dtype=input_dtype).cuda() - K = torch.randn(B, S, H, DK, dtype=input_dtype).cuda() - V = torch.randn(B, S, H, DV, dtype=input_dtype).cuda() + Q = torch.randn(B, S, H, DK, dtype=input_dtype).cuda().contiguous() + K = torch.randn(B, S, H, DK, dtype=input_dtype).cuda().contiguous() + V = torch.randn(B, S, H, DV, dtype=input_dtype).cuda().contiguous() h = torch.randn(B, BS, H, DK, DV, dtype=input_dtype).cuda() G = torch.randn(B, S, H, dtype=gate_dtype).cuda() - dO = torch.randn(B, S, H, DV, dtype=input_dtype).cuda() + dO = torch.randn(B, S, H, DV, dtype=input_dtype).cuda().contiguous() dh = torch.randn(B, BS, H, DK, DV, dtype=input_dtype).cuda() - dv = torch.randn(B, S, H, DV, dtype=output_dtype).cuda() - W = torch.randn(B, S, H, DK, dtype=input_dtype).cuda() + dv = torch.randn(B, S, H, DV, dtype=output_dtype).cuda().contiguous() + W = torch.randn(B, S, H, DK, dtype=input_dtype).cuda().contiguous() return Q, K, V, h, G, dO, dh, dv, W diff --git a/examples/maca/gdn/example_chunk_scaled_dot_kkt.py b/examples/maca/gdn/example_chunk_scaled_dot_kkt.py index c16374fe..5fad17ed 100644 --- a/examples/maca/gdn/example_chunk_scaled_dot_kkt.py +++ b/examples/maca/gdn/example_chunk_scaled_dot_kkt.py @@ -12,9 +12,10 @@ print(fla.__file__) from fla.ops.common.chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd -except ImportError: - print("fla not found, using tilelang implementation") +except Exception as exc: + print(f"fla unavailable, using tilelang implementation: {exc}") fla = None + chunk_scaled_dot_kkt_fwd = None import torch @@ -31,7 +32,7 @@ def prepare_input( output_dtype, accum_dtype, ): - K = torch.randn(B, S, H, DK, dtype=input_dtype).cuda() + K = torch.randn(B, S, H, DK, dtype=input_dtype).cuda().contiguous() Beta = torch.randn(B, S, H, dtype=input_dtype).cuda() G = torch.randn(B, S, H, dtype=accum_dtype).cuda() return K, Beta, G diff --git a/examples/maca/gdn/example_cumsum.py b/examples/maca/gdn/example_cumsum.py index 0760b496..71ef83f3 100644 --- a/examples/maca/gdn/example_cumsum.py +++ b/examples/maca/gdn/example_cumsum.py @@ -13,9 +13,10 @@ print(fla.__file__) from fla.ops.utils.cumsum import chunk_local_cumsum_scalar -except ImportError: - print("fla not found, using tilelang implementation") +except Exception as exc: + print(f"fla unavailable, using tilelang implementation: {exc}") fla = None + chunk_local_cumsum_scalar = None import torch diff --git a/examples/maca/gdn/example_wy_fast.py b/examples/maca/gdn/example_wy_fast.py index d36dcf9b..004f4c7a 100644 --- a/examples/maca/gdn/example_wy_fast.py +++ b/examples/maca/gdn/example_wy_fast.py @@ -12,9 +12,10 @@ print(fla.__file__) from fla.ops.gated_delta_rule.wy_fast import recompute_w_u_fwd -except ImportError: - print("fla not found, using tilelang implementation") +except Exception as exc: + print(f"fla unavailable, using tilelang implementation: {exc}") fla = None + recompute_w_u_fwd = None import torch @@ -23,11 +24,11 @@ def prepare_input(B, S, H, DK, DV, chunk_size, input_dtype, output_dtype, gate_dtype=torch.float32): BS = chunk_size - K = torch.randn(B, S, H, DK, dtype=input_dtype).cuda() - V = torch.randn(B, S, H, DV, dtype=input_dtype).cuda() + K = torch.randn(B, S, H, DK, dtype=input_dtype).cuda().contiguous() + V = torch.randn(B, S, H, DV, dtype=input_dtype).cuda().contiguous() Beta = torch.randn(B, S, H, dtype=input_dtype).cuda() G = torch.randn(B, S, H, dtype=gate_dtype).cuda() - A = torch.randn(B, S, H, BS, dtype=output_dtype).cuda() + A = torch.randn(B, S, H, BS, dtype=output_dtype).cuda().contiguous() return K, V, Beta, G, A diff --git a/examples/maca/gdn/example_wy_fast_bwd_split.py b/examples/maca/gdn/example_wy_fast_bwd_split.py index 822f745f..13530240 100644 --- a/examples/maca/gdn/example_wy_fast_bwd_split.py +++ b/examples/maca/gdn/example_wy_fast_bwd_split.py @@ -13,9 +13,10 @@ print(fla.__file__) from fla.ops.gated_delta_rule.wy_fast import bwd_prepare_wy_repr -except ImportError: - print("fla not found, using tilelang implementation") +except Exception as exc: + print(f"fla unavailable, using tilelang implementation: {exc}") fla = None + bwd_prepare_wy_repr = None import torch import torch.nn.functional as F @@ -38,13 +39,13 @@ def prepare_input_fake( state_dtype, ): BS = chunk_size - K = torch.ones(B, S, H, DK, dtype=input_dtype).cuda() - V = torch.ones(B, S, H, DV, dtype=input_dtype).cuda() + K = torch.ones(B, S, H, DK, dtype=input_dtype).cuda().contiguous() + V = torch.ones(B, S, H, DV, dtype=input_dtype).cuda().contiguous() Beta = torch.ones(B, S, H, dtype=input_dtype).cuda() G = torch.ones(B, S, H, dtype=gate_dtype).cuda() - A = torch.ones(B, S, H, BS, dtype=input_dtype).cuda() - dw = torch.ones(B, S, H, DK, dtype=input_dtype).cuda() - du = torch.ones(B, S, H, DV, dtype=input_dtype).cuda() + A = torch.ones(B, S, H, BS, dtype=input_dtype).cuda().contiguous() + dw = torch.ones(B, S, H, DK, dtype=input_dtype).cuda().contiguous() + du = torch.ones(B, S, H, DV, dtype=input_dtype).cuda().contiguous() return K, V, Beta, G, A, dw, du @@ -62,15 +63,15 @@ def prepare_input( state_dtype, ): BS = chunk_size - K = torch.randn(B, S, H, DK, dtype=input_dtype).cuda() - K = F.normalize(K, dim=-1, p=2) - V = torch.randn(B, S, H, DV, dtype=input_dtype).cuda() - V = F.normalize(V, dim=-1, p=2) + K = torch.randn(B, S, H, DK, dtype=input_dtype).cuda().contiguous() + K = F.normalize(K, dim=-1, p=2).contiguous() + V = torch.randn(B, S, H, DV, dtype=input_dtype).cuda().contiguous() + V = F.normalize(V, dim=-1, p=2).contiguous() Beta = torch.randn(B, S, H, dtype=input_dtype).cuda() G = torch.randn(B, S, H, dtype=gate_dtype).cuda() - A = torch.randn(B, S, H, BS, dtype=input_dtype).cuda() - dw = torch.randn(B, S, H, DK, dtype=input_dtype).cuda() - du = torch.randn(B, S, H, DV, dtype=input_dtype).cuda() + A = torch.randn(B, S, H, BS, dtype=input_dtype).cuda().contiguous() + dw = torch.randn(B, S, H, DK, dtype=input_dtype).cuda().contiguous() + du = torch.randn(B, S, H, DV, dtype=input_dtype).cuda().contiguous() return K, V, Beta, G, A, dw, du diff --git a/examples/maca/gdn/test_example_gdn_compilation.py b/examples/maca/gdn/test_example_gdn_compilation.py index 9a7943c3..530452f6 100644 --- a/examples/maca/gdn/test_example_gdn_compilation.py +++ b/examples/maca/gdn/test_example_gdn_compilation.py @@ -24,7 +24,6 @@ num_stages = 0 -@tilelang.testing.pytest.mark.xfail def test_example_wy_fast_compilation(): from example_wy_fast import tilelang_recompute_w_u_fwd, prepare_input @@ -127,7 +126,6 @@ def test_example_wy_fast_bwd_split_compilation(): dg_tilelang = dg_tilelang + dg_tilelang_A_positive.sum(dim=-1) - dg_tilelang_A_negative.sum(dim=-1) -@tilelang.testing.pytest.mark.xfail def test_example_chunk_o_compilation(): from example_chunk_o import tilelang_chunk_fwd_o, prepare_input @@ -167,7 +165,6 @@ def test_example_chunk_o_compilation(): O_tilelang = kernel(Q, K, V, HIDDEN, G) # noqa: F841 -@tilelang.testing.pytest.mark.xfail def test_example_chunk_o_bwd_compilation(): from example_chunk_o_bwd import tilelang_chunk_o_bwd_dqkwg, prepare_input @@ -210,7 +207,6 @@ def test_example_chunk_o_bwd_compilation(): dg_tilelang = dg_tilelang.sum(dim=0) -@tilelang.testing.pytest.mark.xfail def test_example_chunk_scaled_dot_kkt_compilation(): from example_chunk_scaled_dot_kkt import tilelang_chunk_scaled_dot_kkt_fwd, prepare_input @@ -222,7 +218,6 @@ def test_example_chunk_scaled_dot_kkt_compilation(): A_tilelang = kernel(K, Beta, G) # noqa: F841 -@tilelang.testing.pytest.mark.xfail def test_example_cumsum_compilation(): from example_cumsum import tilelang_chunk_local_cumsum_scalar, prepare_cumsum_input, prepare_cumsum_output @@ -285,7 +280,6 @@ def test_example_chunk_delta_h_compilation(): h_tilelang, final_state_tilelang, V_new_tilelang = kernel(K, W, U, G, initial_state) # noqa: F841 -@tilelang.testing.pytest.mark.xfail def test_example_chunk_delta_bwd_compilation(): from example_chunk_delta_bwd import tilelang_chunk_gated_delta_rule_bwd_dhu, prepare_input