diff --git a/examples/gemm_fp4/example_gemm_a8w4_sm120.py b/examples/gemm_fp4/example_gemm_a8w4_sm120.py new file mode 100644 index 000000000..c18a59edd --- /dev/null +++ b/examples/gemm_fp4/example_gemm_a8w4_sm120.py @@ -0,0 +1,116 @@ +import torch +import tilelang +import tilelang.language as T + + +FP4_E2M1_TO_FLOAT = [ + 0.0, + 0.5, + 1.0, + 1.5, + 2.0, + 3.0, + 4.0, + 6.0, + -0.0, + -0.5, + -1.0, + -1.5, + -2.0, + -3.0, + -4.0, + -6.0, +] + + +def unpack_fp4_storage_to_float(packed: torch.Tensor, rows: int, cols: int) -> torch.Tensor: + packed_u8 = packed.to(torch.uint8).reshape(rows, cols // 2) + lo = packed_u8 & 0x0F + hi = (packed_u8 >> 4) & 0x0F + values = torch.stack([lo, hi], dim=-1).reshape(rows, cols).to(torch.int64) + lut = torch.tensor(FP4_E2M1_TO_FLOAT, dtype=torch.float32, device=packed.device) + return lut[values] + + +def require_sm120(): + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is not available") + major, _ = torch.cuda.get_device_capability() + if major < 12: + raise RuntimeError("SM120 A8W4 GEMM requires an SM120+ CUDA device") + + +def matmul_a8w4( + M, + N, + K, + block_M, + block_N, + block_K, + out_dtype, + accum_dtype, + num_stages=2, + threads=128, +): + if K % 32 != 0 or block_K % 32 != 0 or block_K > K: + raise ValueError("matmul_a8w4 requires K and block_K to be multiples of 32 and block_K <= K") + + A_shape = (M, K) + B_shape = (N, K) + + @T.prim_func + def main( + A: T.Tensor(A_shape, T.float8_e4m3fn), + B: T.Tensor(B_shape, T.float4_e2m1fn), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), T.float8_e4m3fn) + B_shared = T.alloc_shared((block_N, block_K), T.float4_e2m1fn) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + + T.clear(C_local) + for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + T.copy(A[by * block_M, ko * block_K], A_shared) + T.copy(B[bx * block_N, ko * block_K], B_shared) + T.gemm(A_shared, B_shared, C_local, transpose_B=True) + + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + +def main(): + require_sm120() + + M, N, K = 256, 256, 256 + block_M, block_N, block_K = 128, 128, 64 + func = matmul_a8w4(M, N, K, block_M, block_N, block_K, T.float32, T.float32) + kernel = tilelang.compile( + func, + out_idx=[2], + target="cuda", + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }, + ) + + torch.manual_seed(0) + a_f16 = torch.randn(M, K, device="cuda", dtype=torch.float16) + a = a_f16.to(torch.float8_e4m3fn) + b = torch.randint(0, 256, (N, K // 2), device="cuda", dtype=torch.uint8).view(torch.int8) + + c_zero = kernel(torch.zeros_like(a), torch.zeros_like(b)) + assert c_zero.abs().max().item() == 0.0 + + c = kernel(a, b) + ref = a.to(torch.float32) @ unpack_fp4_storage_to_float(b, N, K).T + diff = (c.float() - ref).abs() + rel_err = diff.sum().item() / (ref.abs().sum().item() + 1e-10) + assert diff.max().item() <= 1e-3 + print(f"max_abs_diff={diff.max().item():.6f}, rel_err={rel_err:.6f}") + + +if __name__ == "__main__": + main() diff --git a/examples/gemm_fp4/example_gemm_fp4_sm120.py b/examples/gemm_fp4/example_gemm_fp4_sm120.py new file mode 100644 index 000000000..d7925e2ce --- /dev/null +++ b/examples/gemm_fp4/example_gemm_fp4_sm120.py @@ -0,0 +1,116 @@ +import torch +import tilelang +import tilelang.language as T + + +FP4_E2M1_TO_FLOAT = [ + 0.0, + 0.5, + 1.0, + 
1.5, + 2.0, + 3.0, + 4.0, + 6.0, + -0.0, + -0.5, + -1.0, + -1.5, + -2.0, + -3.0, + -4.0, + -6.0, +] + + +def unpack_fp4_storage_to_float(packed: torch.Tensor, rows: int, cols: int) -> torch.Tensor: + # Host tensors use packed bytes; kernel tensors use T.float4_e2m1fn. + packed_u8 = packed.to(torch.uint8).reshape(rows, cols // 2) + lo = packed_u8 & 0x0F + hi = (packed_u8 >> 4) & 0x0F + values = torch.stack([lo, hi], dim=-1).reshape(rows, cols).to(torch.int64) + lut = torch.tensor(FP4_E2M1_TO_FLOAT, dtype=torch.float32, device=packed.device) + return lut[values] + + +def require_sm120(): + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is not available") + major, _ = torch.cuda.get_device_capability() + if major < 12: + raise RuntimeError("SM120 FP4 GEMM requires an SM120+ CUDA device") + + +def matmul_fp4( + M, + N, + K, + block_M, + block_N, + block_K, + out_dtype, + accum_dtype, + num_stages=2, + threads=128, +): + if K % 32 != 0 or block_K % 32 != 0 or block_K > K: + raise ValueError("matmul_fp4 requires K and block_K to be multiples of 32 and block_K <= K") + + A_shape = (M, K) + B_shape = (N, K) + + @T.prim_func + def main( + A: T.Tensor(A_shape, T.float4_e2m1fn), + B: T.Tensor(B_shape, T.float4_e2m1fn), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), T.float4_e2m1fn) + B_shared = T.alloc_shared((block_N, block_K), T.float4_e2m1fn) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + + T.clear(C_local) + for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + T.copy(A[by * block_M, ko * block_K], A_shared) + T.copy(B[bx * block_N, ko * block_K], B_shared) + T.gemm(A_shared, B_shared, C_local, transpose_B=True) + + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + +def main(): + require_sm120() + + M, N, K = 256, 256, 256 + block_M, block_N, block_K = 128, 128, 64 + func = matmul_fp4(M, N, K, block_M, block_N, block_K, T.float32, T.float32) + kernel = tilelang.compile( + func, + out_idx=[2], + target="cuda", + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }, + ) + + torch.manual_seed(0) + a = torch.randint(0, 256, (M, K // 2), device="cuda", dtype=torch.uint8).view(torch.int8) + b = torch.randint(0, 256, (N, K // 2), device="cuda", dtype=torch.uint8).view(torch.int8) + + zero = torch.zeros_like(a) + c_zero = kernel(zero, torch.zeros_like(b)) + assert c_zero.abs().max().item() == 0.0 + + c = kernel(a, b) + ref = unpack_fp4_storage_to_float(a, M, K) @ unpack_fp4_storage_to_float(b, N, K).T + max_diff = (c.float() - ref).abs().max().item() + assert max_diff <= 1e-3 + print(f"max_abs_diff={max_diff:.6f}") + + +if __name__ == "__main__": + main() diff --git a/src/backend/cuda/codegen/codegen_cuda.cc b/src/backend/cuda/codegen/codegen_cuda.cc index 832a8fba9..f7387942c 100644 --- a/src/backend/cuda/codegen/codegen_cuda.cc +++ b/src/backend/cuda/codegen/codegen_cuda.cc @@ -1875,11 +1875,9 @@ std::string CodeGenTileLangCUDA::GetBufferRef(DataType t, const VarNode *buffer_var = buffer->data.get(); std::ostringstream os; std::string vid = GetVarID(buffer_var); - // For fp4 packed buffers, use the packed buffer name for vector accesses - auto it = fp4_packed_buffers_.find(buffer_var); - if (it != fp4_packed_buffers_.end() && !t.is_scalar()) { - vid = it->second; - } + // FP4 storage is selected by scope, not by renaming the 
buffer. Global and + // legacy shared scopes use packed byte addresses, while local fragments keep + // the declared name so generated MMA operands match the fragment allocation. std::string scope; if (alloc_storage_scope_.count(buffer_var)) { scope = alloc_storage_scope_.at(buffer_var); @@ -1918,7 +1916,17 @@ std::string CodeGenTileLangCUDA::GetBufferRef(DataType t, return os.str(); } std::string index_str = PrintExpr(index); - if ((t.bits() == 4 && !t.is_float4()) || (t.bits() == 1 && t.is_int())) { + if (IsFp4PaddedSharedStorage(buffer_var, buffer_element_dtype) && + t.is_float4_e2m1fn() && t.lanes() == 1) { + // SM120 b4x16 ldmatrix consumes shared FP4 in 16-value rows padded to + // 32 logical slots. Convert the logical FP4 element index to the padded + // row index first, then to the packed byte offset. + PrimExpr padded_index = GetFp4PaddedSharedIndex(index); + index_str = + PrintExpr(arith::Analyzer().Simplify(truncdiv(padded_index, 2))); + os << buffer_str << "[" << index_str << "]"; + } else if ((t.bits() == 4 && !t.is_float4()) || + (t.bits() == 1 && t.is_int())) { // Scalar int4/uint4 storage is byte-packed (2 logical elements per byte). // Vector int4 loads/stores reinterpret the underlying packed bytes as the // requested vector type, so their index still advances by the vector lane @@ -1933,17 +1941,21 @@ std::string CodeGenTileLangCUDA::GetBufferRef(DataType t, << " + " << index_str << ")"; } else if (t == buffer_element_dtype) { int div_factor = 1; - if (buffer_element_dtype.is_float4() && buffer_element_dtype.lanes() == 1) { + if (IsFp4PackedStorage(buffer_var, buffer_element_dtype)) { + // Packed FP4 buffers are byte-addressed. Dividing by two maps two + // neighboring logical FP4 elements to one backing byte; scalar helpers + // still receive the original logical index for nibble selection. div_factor = 2; } index_str = PrintExpr(arith::Analyzer().Simplify(truncdiv(index, div_factor))); os << buffer_str << "[" << index_str << "]"; } else { - // Fix fp4 pointer arithmetic: fp4 elements are 4-bit packed 2 per byte. - // fp4* + n incorrectly advances n bytes (skipping 2n elements). int div_factor = 1; - if (buffer_element_dtype.is_float4() && buffer_element_dtype.lanes() == 1) { + if (IsFp4PackedStorage(buffer_var, buffer_element_dtype)) { + // Reinterpreted FP4 references start from the backing byte offset. C + // pointer arithmetic on fp4_e2_t would advance by whole storage objects + // and skip the neighboring nibble. div_factor = 2; } index_str = @@ -1958,6 +1970,75 @@ std::string CodeGenTileLangCUDA::GetVecLoad(DataType t, const BufferNode *buffer, PrimExpr base) { const VarNode *buffer_var = buffer->data.get(); + if (IsFp4SemanticLocalStorage(buffer_var, buffer->dtype) && + t.is_float4_e2m1fn() && t.lanes() > 1) { + // Local FP4 vectors are logical element arrays, not packed byte arrays. 
+ std::ostringstream os; + os << "make_fp4_e2_" << t.lanes() << "_t("; + for (int i = 0; i < t.lanes(); ++i) { + if (i != 0) { + os << ", "; + } + PrimExpr index = arith::Analyzer().Simplify( + base + IntImm(base.dtype(), static_cast(i))); + os << GetBufferRef(t.element_of(), buffer, index); + } + os << ")"; + return os.str(); + } + + if (IsFp4PaddedSharedStorage(buffer_var, buffer->dtype) && + t.is_float4_e2m1fn() && t.lanes() > 1) { + ICHECK(t.lanes() <= 32 && (t.lanes() & (t.lanes() - 1)) == 0) + << "Unsupported SM120 padded shared FP4 vector load: " << t; + arith::Analyzer analyzer; + bool row_aligned = is_zero(analyzer.Simplify(truncmod(base, 16))); + auto padded_index = [&](int logical_offset) { + PrimExpr logical_index = arith::Analyzer().Simplify( + base + IntImm(base.dtype(), logical_offset)); + return this->PrintExpr(GetFp4PaddedSharedIndex(logical_index)); + }; + auto byte_offset = [&](int logical_offset) { + PrimExpr logical_index = arith::Analyzer().Simplify( + base + IntImm(base.dtype(), logical_offset)); + PrimExpr padded_index_expr = GetFp4PaddedSharedIndex(logical_index); + return this->PrintExpr( + arith::Analyzer().Simplify(truncdiv(padded_index_expr, 2))); + }; + + std::string vid = GetVarID(buffer_var); + if (row_aligned && t.lanes() == 32) { + // A 32-wide FP4 vector spans two padded b4x16 rows, so materialize it + // from two contiguous 16-value row fragments. + std::ostringstream os; + os << "fp4_e2_32_t{*(fp4_e2_16_t*)((uint8_t*)" << vid << " + " + << byte_offset(0) << "), *(fp4_e2_16_t*)((uint8_t*)" << vid << " + " + << byte_offset(16) << ")}"; + return os.str(); + } + if (row_aligned) { + // Row-aligned vectors stay contiguous within the padded b4x16 row. + std::ostringstream vec_type; + PrintType(t, vec_type); + std::ostringstream os; + os << "*(" << vec_type.str() << "*)((uint8_t*)" << vid << " + " + << byte_offset(0) << ")"; + return os.str(); + } + + std::ostringstream os; + os << "make_fp4_e2_" << t.lanes() << "_t("; + for (int i = 0; i < t.lanes(); ++i) { + if (i != 0) { + os << ", "; + } + os << "tl_fp4_packed_load((fp4_e2_2_t*)" << vid << ", " << padded_index(i) + << ")"; + } + os << ")"; + return os.str(); + } + std::string scope; if (alloc_storage_scope_.count(buffer_var)) { scope = alloc_storage_scope_.at(buffer_var); @@ -1981,6 +2062,77 @@ void CodeGenTileLangCUDA::PrintVecStore(const BufferNode *buffer, DataType t, PrimExpr base, const std::string &value) { const VarNode *buffer_var = buffer->data.get(); + if (IsFp4SemanticLocalStorage(buffer_var, buffer->dtype) && + t.is_float4_e2m1fn() && t.lanes() > 1) { + // Store each FP4 lane so vector casts fill every semantic local element. 
+ std::ostringstream vec_type; + PrintType(t, vec_type); + std::string vid = GetVarID(buffer_var); + this->PrintIndent(); + this->stream << "{ " << vec_type.str() << " __tl_fp4_vec = " << value + << "; "; + for (int i = 0; i < t.lanes(); ++i) { + std::ostringstream elem; + PrintVecElemLoad("__tl_fp4_vec", t, i, elem); + PrimExpr index = arith::Analyzer().Simplify( + base + IntImm(base.dtype(), static_cast(i))); + this->stream << vid << "[" << PrintExpr(index) << "] = " << elem.str() + << "; "; + } + this->stream << "}\n"; + return; + } + + if (IsFp4PaddedSharedStorage(buffer_var, buffer->dtype) && + t.is_float4_e2m1fn() && t.lanes() > 1) { + ICHECK(t.lanes() <= 32 && (t.lanes() & (t.lanes() - 1)) == 0) + << "Unsupported SM120 padded shared FP4 vector store: " << t; + std::string vid = GetVarID(buffer_var); + arith::Analyzer analyzer; + bool row_aligned = is_zero(analyzer.Simplify(truncmod(base, 16))); + auto padded_index = [&](int logical_offset) { + PrimExpr logical_index = arith::Analyzer().Simplify( + base + IntImm(base.dtype(), logical_offset)); + return this->PrintExpr(GetFp4PaddedSharedIndex(logical_index)); + }; + auto byte_offset = [&](int logical_offset) { + PrimExpr logical_index = arith::Analyzer().Simplify( + base + IntImm(base.dtype(), logical_offset)); + PrimExpr padded_index_expr = GetFp4PaddedSharedIndex(logical_index); + return this->PrintExpr( + arith::Analyzer().Simplify(truncdiv(padded_index_expr, 2))); + }; + + this->PrintIndent(); + if (row_aligned && t.lanes() == 32) { + // A 32-wide FP4 vector spans two padded b4x16 rows, so store it as two + // contiguous 16-value row fragments. + this->stream << "{ fp4_e2_32_t __tl_fp4_vec = " << value << "; "; + this->stream << "*(fp4_e2_16_t*)((uint8_t*)" << vid << " + " + << byte_offset(0) << ") = __tl_fp4_vec.x; "; + this->stream << "*(fp4_e2_16_t*)((uint8_t*)" << vid << " + " + << byte_offset(16) << ") = __tl_fp4_vec.y; }\n"; + } else if (row_aligned) { + std::ostringstream vec_type; + PrintType(t, vec_type); + this->stream << "*(" << vec_type.str() << "*)((uint8_t*)" << vid << " + " + << byte_offset(0) << ") = " << value << ";\n"; + } else { + std::ostringstream vec_type; + PrintType(t, vec_type); + this->stream << "{ " << vec_type.str() << " __tl_fp4_vec = " << value + << "; "; + for (int i = 0; i < t.lanes(); ++i) { + std::ostringstream elem; + PrintVecElemLoad("__tl_fp4_vec", t, i, elem); + this->stream << "tl_fp4_packed_store((fp4_e2_2_t*)" << vid << ", " + << padded_index(i) << ", " << elem.str() << "); "; + } + this->stream << "}\n"; + } + return; + } + std::string scope; if (alloc_storage_scope_.count(buffer_var)) { scope = alloc_storage_scope_.at(buffer_var); @@ -2001,6 +2153,86 @@ void CodeGenTileLangCUDA::PrintVecStore(const BufferNode *buffer, DataType t, << ");\n"; } +// FP4 has three storage cases: +// - Global buffers use packed bytes. +// - SM120 shared buffers use packed bytes plus b4x16 padded rows. +// - Local/local.fragment buffers use semantic FP4 elements for MMA operands. 
+bool CodeGenTileLangCUDA::IsFp4PackedStorage(const VarNode *buffer_var, + DataType element_dtype) const { + if (!element_dtype.is_float4_e2m1fn() || !element_dtype.is_scalar()) { + return false; + } + + std::string scope; + auto it = alloc_storage_scope_.find(buffer_var); + if (it != alloc_storage_scope_.end()) { + scope = it->second; + } else if (const auto *ptr = + buffer_var->type_annotation.as()) { + scope = ptr->storage_scope; + } + + if (scope.empty()) { + return true; + } + if (scope == "global") { + return true; + } + if (scope == "shared" || scope == "shared.dyn") { + // Pre-SM120 shared FP4 keeps the packed-byte convention. SM120 shared FP4 + // is handled by IsFp4PaddedSharedStorage so b4x16 row padding is preserved. + Target cur_target = Target::Current(/*allow_not_defined=*/true); + return cur_target.defined() && tl::TargetHasSMVersionGE(cur_target, 100) && + !tl::TargetHasSMVersionGE(cur_target, 120); + } + return scope != "local" && scope != "local.var" && scope != "local.fragment"; +} + +bool CodeGenTileLangCUDA::IsFp4PaddedSharedStorage( + const VarNode *buffer_var, DataType element_dtype) const { + if (!element_dtype.is_float4_e2m1fn() || !element_dtype.is_scalar()) { + return false; + } + + std::string scope; + auto it = alloc_storage_scope_.find(buffer_var); + if (it != alloc_storage_scope_.end()) { + scope = it->second; + } else if (const auto *ptr = + buffer_var->type_annotation.as()) { + scope = ptr->storage_scope; + } + + if (scope != "shared" && scope != "shared.dyn") { + return false; + } + // SM120 b4x16 ldmatrix requires a padded shared-memory row layout. + Target cur_target = Target::Current(/*allow_not_defined=*/true); + return cur_target.defined() && tl::TargetHasSMVersionGE(cur_target, 120); +} + +bool CodeGenTileLangCUDA::IsFp4SemanticLocalStorage( + const VarNode *buffer_var, DataType element_dtype) const { + if (!element_dtype.is_float4_e2m1fn() || !element_dtype.is_scalar()) { + return false; + } + + std::string scope; + auto it = alloc_storage_scope_.find(buffer_var); + if (it != alloc_storage_scope_.end()) { + scope = it->second; + } else if (const auto *ptr = + buffer_var->type_annotation.as()) { + scope = ptr->storage_scope; + } + return scope == "local" || scope == "local.fragment"; +} + +PrimExpr CodeGenTileLangCUDA::GetFp4PaddedSharedIndex(PrimExpr index) const { + arith::Analyzer analyzer; + return analyzer.Simplify(truncdiv(index, 16) * 32 + truncmod(index, 16)); +} + /** * @brief Emit CUDA/TensorLib-specific code for a call expression. * @@ -2074,6 +2306,148 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { << PrintExpr(op->args[1]) << ")"; return; } + + bool matched_fp4_padded_cp_async = false; + auto is_fp4_row_aligned = [&](PrimExpr index) { + arith::Analyzer analyzer; + return is_zero(analyzer.Simplify(truncmod(index, 16))); + }; + auto try_print_fp4_padded_cp_async = [&](int num_segments) -> bool { + // Generic cp.async assumes one contiguous byte span. SM120 shared FP4 has a + // padding gap after each 16 logical elements, so copy one 8B b4x16 segment + // at a time and compute global/shared byte offsets independently. 
+ auto dst_elem_type = GetAccessPtrElementType(op->args[0]); + auto src_elem_type = GetAccessPtrElementType(op->args[1]); + if (!dst_elem_type.has_value() || !src_elem_type.has_value() || + !dst_elem_type->is_float4_e2m1fn() || + !src_elem_type->is_float4_e2m1fn()) { + return false; + } + + auto print_byte_ptr = [&](const PrimExpr &base, PrimExpr byte_offset) { + arith::Analyzer analyzer; + byte_offset = analyzer.Simplify(byte_offset); + return "(uint8_t*)" + this->PrintExpr(base) + " + " + + this->PrintExpr(byte_offset); + }; + + auto print_cp_async = [&](const std::string &dst, const std::string &src) { + this->PrintIndent(); + if (op->args.size() == 3) { + this->stream << "tl::cp_async_gs<8>(" << dst << ", " << src << ");\n"; + } else { + std::string condition = this->PrintExpr(op->args[3]); + this->stream << "tl::cp_async_gs_conditional<8>(" << dst << ", " << src + << ", " << condition << ");\n"; + } + }; + + auto emit_from_loads = [&](const BufferLoadNode *dst_load, + const BufferLoadNode *src_load) -> bool { + if (dst_load == nullptr || src_load == nullptr || + dst_load->indices.size() != 1U || src_load->indices.size() != 1U) { + return false; + } + const VarNode *dst_var = dst_load->buffer->data.get(); + const VarNode *src_var = src_load->buffer->data.get(); + if (!IsFp4PaddedSharedStorage(dst_var, dst_load->buffer->dtype) || + !IsFp4PackedStorage(src_var, src_load->buffer->dtype)) { + return false; + } + matched_fp4_padded_cp_async = true; + if (!is_fp4_row_aligned(dst_load->indices[0]) || + !is_fp4_row_aligned(src_load->indices[0])) { + return false; + } + + for (int segment = 0; segment < num_segments; ++segment) { + PrimExpr dst_logical_offset = IntImm( + dst_load->indices[0].dtype(), static_cast(segment * 16)); + PrimExpr src_logical_offset = IntImm( + src_load->indices[0].dtype(), static_cast(segment * 16)); + PrimExpr dst_byte_offset = truncdiv( + GetFp4PaddedSharedIndex(dst_load->indices[0] + dst_logical_offset), + 2); + PrimExpr src_byte_offset = + truncdiv(src_load->indices[0] + src_logical_offset, 2); + print_cp_async(print_byte_ptr(dst_load->buffer->data, dst_byte_offset), + print_byte_ptr(src_load->buffer->data, src_byte_offset)); + } + return true; + }; + + auto try_print_address_of = [&]() -> bool { + const auto *dst_addr = op->args[0].as(); + const auto *src_addr = op->args[1].as(); + if (dst_addr == nullptr || src_addr == nullptr || + !dst_addr->op.same_as(builtin::address_of()) || + !src_addr->op.same_as(builtin::address_of()) || + dst_addr->args.empty() || src_addr->args.empty()) { + return false; + } + return emit_from_loads(dst_addr->args[0].as(), + src_addr->args[0].as()); + }; + + if (try_print_address_of()) { + return true; + } + + auto try_print_tl_access_ptr = [&]() -> bool { + const auto *dst_call = op->args[0].as(); + const auto *src_call = op->args[1].as(); + if (dst_call == nullptr || src_call == nullptr || + !dst_call->op.same_as(tl::access_ptr()) || + !src_call->op.same_as(tl::access_ptr()) || + dst_call->args.size() != 3U || src_call->args.size() != 3U) { + return false; + } + return emit_from_loads(dst_call->args[0].as(), + src_call->args[0].as()); + }; + + if (try_print_tl_access_ptr()) { + return true; + } + + const auto *dst_call = op->args[0].as(); + const auto *src_call = op->args[1].as(); + if (dst_call == nullptr || src_call == nullptr || + !dst_call->op.same_as(builtin::tvm_access_ptr()) || + !src_call->op.same_as(builtin::tvm_access_ptr()) || + dst_call->args.size() < 5U || src_call->args.size() < 5U) { + return false; + } + const auto 
*dst_var = dst_call->args[1].as(); + const auto *src_var = src_call->args[1].as(); + DataType dst_storage_type = dst_elem_type->element_of(); + DataType src_storage_type = src_elem_type->element_of(); + if (dst_var == nullptr || src_var == nullptr || + !IsFp4PaddedSharedStorage(dst_var, dst_storage_type) || + !IsFp4PackedStorage(src_var, src_storage_type)) { + return false; + } + matched_fp4_padded_cp_async = true; + if (!is_fp4_row_aligned(dst_call->args[2]) || + !is_fp4_row_aligned(src_call->args[2])) { + return false; + } + + for (int segment = 0; segment < num_segments; ++segment) { + PrimExpr dst_logical_offset = + IntImm(dst_call->args[2].dtype(), static_cast(segment * 16)); + PrimExpr src_logical_offset = + IntImm(src_call->args[2].dtype(), static_cast(segment * 16)); + PrimExpr dst_byte_offset = truncdiv( + GetFp4PaddedSharedIndex(dst_call->args[2] + dst_logical_offset), 2); + PrimExpr src_byte_offset = + truncdiv(src_call->args[2] + src_logical_offset, 2); + print_cp_async(print_byte_ptr(dst_call->args[1], dst_byte_offset), + print_byte_ptr(src_call->args[1], src_byte_offset)); + } + return true; + }; + if (op->op.same_as(builtin::ptx_cp_async())) { // args[0] = dst_access_ptr, args[1] = src_access_ptr, args[2] = bytes, // args[3] = predicate (optional) @@ -2081,6 +2455,16 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { << "ptx_cp_async expects 3 or 4 arguments (dst_access_ptr, " "src_access_ptr, bytes, [predicate])"; + const auto *bytes_imm = op->args[2].as(); + matched_fp4_padded_cp_async = false; + if (bytes_imm != nullptr && + (bytes_imm->value == 8 || bytes_imm->value == 16) && + try_print_fp4_padded_cp_async(static_cast(bytes_imm->value / 8))) { + return; + } + ICHECK(!matched_fp4_padded_cp_async) + << "SM120 FP4 padded cp.async requires 16-element aligned offsets"; + std::string dst = this->PrintExpr(op->args[0]); std::string src = this->PrintExpr(op->args[1]); std::string size = this->PrintExpr(op->args[2]); @@ -2101,8 +2485,39 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { // args[2] = num_elems, args[3] = predicate (optional) int total_bytes = GetTileLangCPAsyncTransferBytes(op); - std::string dst = this->PrintExpr(op->args[0]); - std::string src = this->PrintExpr(op->args[1]); + const auto *num_elems_imm = op->args[2].as(); + matched_fp4_padded_cp_async = false; + if (num_elems_imm != nullptr && + (num_elems_imm->value == 16 || num_elems_imm->value == 32) && + try_print_fp4_padded_cp_async( + static_cast(num_elems_imm->value / 16))) { + return; + } + ICHECK(!matched_fp4_padded_cp_async) + << "SM120 FP4 padded cp.async requires 16-element aligned offsets"; + + auto print_access_ptr = [&](const PrimExpr &access_ptr) { + auto elem_type = GetAccessPtrElementType(access_ptr); + const auto *ptr_call = access_ptr.as(); + if (elem_type.has_value() && elem_type->is_float4_e2m1fn() && + ptr_call != nullptr && + ptr_call->op.same_as(builtin::tvm_access_ptr())) { + ICHECK_GE(ptr_call->args.size(), 5U); + std::string base = this->PrintExpr(ptr_call->args[1]); + PrimExpr offset_expr = ptr_call->args[2]; + if (const auto *base_var = ptr_call->args[1].as()) { + if (IsFp4PackedStorage(base_var, elem_type.value())) { + // Generic FP4 cp.async operands keep packed-byte addressing. 
+ offset_expr = arith::Analyzer().Simplify(truncdiv(offset_expr, 2)); + } + } + return base + " + " + this->PrintExpr(offset_expr); + } + return this->PrintExpr(access_ptr); + }; + + std::string dst = print_access_ptr(op->args[0]); + std::string src = print_access_ptr(op->args[1]); std::string size = std::to_string(total_bytes); this->PrintIndent(); @@ -2270,10 +2685,40 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { } else if (op->op.same_as(tl::ptx_ldmatrix())) { int trans = Downcast(op->args[0])->value; int num = Downcast(op->args[1])->value; - std::string func_name = "tl::ptx_ldmatrix_x" + std::to_string(num); - if (trans == 1) - func_name += "_trans"; - print_extern_call_stmt(func_name, 2); + auto dst_elem_type = GetAccessPtrElementType(op->args[3]); + bool is_fp4_ldmatrix = + dst_elem_type.has_value() && dst_elem_type->is_float4_e2m1fn(); + std::string func_name; + if (is_fp4_ldmatrix) { + Target cur_target = Target::Current(/*allow_not_defined=*/true); + ICHECK(cur_target.defined() && tl::TargetHasSMVersionGE(cur_target, 120)) + << "SM120 b4x16 ldmatrix requires SM120+"; + ICHECK_EQ(trans, 0) << "SM120 b4x16 ldmatrix does not support trans"; + enable_fp4_ = true; + // SM120 FP4 ldmatrix uses the b4x16_p64 shared-memory form. + func_name = "tl::ptx_ldmatrix_b4x16_x" + std::to_string(num); + + auto print_access_ptr = [&](const PrimExpr &access_ptr) { + const auto *ptr_call = access_ptr.as(); + if (ptr_call != nullptr && + ptr_call->op.same_as(builtin::tvm_access_ptr())) { + ICHECK_GE(ptr_call->args.size(), 5U); + std::string base = this->PrintExpr(ptr_call->args[1]); + return base + " + " + this->PrintExpr(ptr_call->args[2]); + } + return this->PrintExpr(access_ptr); + }; + + std::string src_ptr = print_access_ptr(op->args[2]); + std::string dst_ptr = print_access_ptr(op->args[3]); + this->PrintIndent(); + this->stream << func_name << "(" << src_ptr << ", " << dst_ptr << ");\n"; + } else { + func_name = "tl::ptx_ldmatrix_x" + std::to_string(num); + if (trans == 1) + func_name += "_trans"; + print_extern_call_stmt(func_name, 2); + } } else if (op->op.same_as(tl::ptx_stmatrix())) { int trans = Downcast(op->args[0])->value; int num = Downcast(op->args[1])->value; @@ -3039,6 +3484,7 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { std::string local_ptr = this->PrintExpr(op->args[3]); bool is_packed_int4 = op->dtype.bits() == 4 && (op->dtype.is_int() || op->dtype.is_uint()); + bool is_fp4_ldmatrix = op->dtype.is_float4_e2m1fn(); PrimExpr local_elem_offset_expr = op->args[4]; if (is_packed_int4) { local_elem_offset_expr = @@ -3066,8 +3512,19 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { arith::Analyzer().Simplify(truncdiv(smem_elem_offset_expr, 2)); } std::string smem_elem_offset = this->PrintExpr(smem_elem_offset_expr); - std::string func_name = "tl::ptx_ldmatrix_x" + std::to_string(num); - if (trans == 1) + std::string func_name; + if (is_fp4_ldmatrix) { + Target cur_target = Target::Current(/*allow_not_defined=*/true); + ICHECK(cur_target.defined() && + tl::TargetHasSMVersionGE(cur_target, 120)) + << "SM120 b4x16 ldmatrix requires SM120+"; + ICHECK(!trans) << "SM120 b4x16 ldmatrix does not support trans"; + enable_fp4_ = true; + func_name = "tl::ptx_ldmatrix_b4x16_x" + std::to_string(num); + } else { + func_name = "tl::ptx_ldmatrix_x" + std::to_string(num); + } + if (trans == 1 && !is_fp4_ldmatrix) func_name += "_trans"; this->PrintIndent(); this->stream << func_name << "(" << smem_ptr << " + 
" << smem_elem_offset @@ -4080,16 +4537,21 @@ void CodeGenTileLangCUDA::VisitStmt_(const AllocateNode *op) { } else if (scope == "local.descriptor.tcgen05_instr") { stream << "tl::Tcgen05InstrDescriptor " << vid << ";\n"; } else { - // For FP4 scalar local buffers, we use packed storage type, - // so skip type declaration here (will be handled in the local scope section - // below) - bool is_fp4_scalar_local = - op->dtype.is_float4() && op->dtype.is_scalar() && scope == "local"; + // Only int4 scalar locals use compact backing in this allocation path. + // FP4 locals are declared as semantic fp4_e2_t elements so local.fragment + // buffers and generated MMA operands share the same variable name. bool is_int4_scalar_local = (op->dtype == DataType::Int(4) || op->dtype == DataType::UInt(4)) && - op->dtype.is_scalar() && scope == "local"; - if (!is_fp4_scalar_local && !is_int4_scalar_local) { + op->dtype.is_scalar() && + (scope == "local" || scope == "local.fragment"); + if (!is_int4_scalar_local) { PrintStorageScope(scope, stream); + if (op->dtype.is_float4_e2m1fn() && op->dtype.is_scalar() && + (scope == "local" || scope == "local.fragment")) { + // Vectorized FP4 fragments reinterpret groups such as fp4_e2_16_t and + // fp4_e2_32_t from this local backing. + stream << "alignas(16) "; + } PrintType(op->dtype, stream); } } @@ -4118,25 +4580,16 @@ void CodeGenTileLangCUDA::VisitStmt_(const AllocateNode *op) { PrintIndent(); stream << "auto " << vid << " = reinterpret_cast<" << mbarrier_dtype_ << "*>(" << v_id_mem << ");\n"; - } else if (scope == "local") { + } else if (scope == "local" || scope == "local.fragment") { if (op->dtype == DataType::Int(4) || op->dtype == DataType::UInt(4)) { stream << "alignas(16) "; PrintType(op->dtype, stream); stream << ' ' << vid << '[' << (constant_size + 1) / 2 << "];\n"; } else { - // For FP4 types, use packed storage type to avoid wasting registers. - // fp4_e2_t uses int8 as storage but only needs 4 bits per element. - // By using fp4_e2_2_t (which stores 2 fp4 values in 1 byte), we halve - // the storage. - if (op->dtype.is_float4() && op->dtype.is_scalar()) { - auto vid_packed = vid + "_packed"; - stream << "fp4_e2_2_t " << vid_packed << '[' - << (constant_size + 1) / 2 << "];\n"; - // Record mapping from original buffer to packed buffer name - fp4_packed_buffers_[op->buffer_var.get()] = vid_packed; - } else { - stream << ' ' << vid << '[' << constant_size << "];\n"; - } + // Local FP4 is not byte-packed. Global/shared scopes own packing + // through GetBufferRef and the packed helpers; local arrays model the + // fragment elements consumed by ldmatrix and MMA. + stream << ' ' << vid << '[' << constant_size << "];\n"; } } else if (scope == "local.var") { PrimExpr init = tir::make_const(op->dtype, 0); @@ -4235,22 +4688,17 @@ void CodeGenTileLangCUDA::VisitExpr_(const BufferLoadNode *op, return; } - // Check if this is a fp4 packed buffer access - auto packed_it = fp4_packed_buffers_.find(buffer_var.get()); - if (packed_it != fp4_packed_buffers_.end() && value_dtype.is_scalar()) { - std::string idx_str = PrintExpr(index); - os << "tl_fp4_packed_load(" << packed_it->second << ", " << idx_str << ")"; - return; - } int lanes = op->dtype.lanes(); // declare type. 
if (value_dtype.lanes() == element_dtype.lanes()) { - // For scalar fp4 loads from non-packed buffers, use tl_fp4_packed_load - // to correctly extract the nibble at the given index (the /2 in - // GetBufferRef maps two consecutive fp4 elements to the same byte, but - // reading that byte only returns the low nibble — the odd-indexed element - // is lost). - if (element_dtype.is_float4() && element_dtype.lanes() == 1) { + // Scalar FP4 loads need nibble selection even when the backing storage is + // byte-indexed. Use the packed helper with a logical FP4 index; SM120 + // shared memory first maps that index through the b4x16 padded-row layout. + if (IsFp4PaddedSharedStorage(buffer_var.get(), element_dtype)) { + std::string idx_str = PrintExpr(GetFp4PaddedSharedIndex(index)); + std::string vid = GetVarID(buffer_var.get()); + os << "tl_fp4_packed_load((fp4_e2_2_t*)" << vid << ", " << idx_str << ")"; + } else if (IsFp4PackedStorage(buffer_var.get(), element_dtype)) { std::string idx_str = PrintExpr(index); std::string vid = GetVarID(buffer_var.get()); os << "tl_fp4_packed_load((fp4_e2_2_t*)" << vid << ", " << idx_str << ")"; @@ -4334,22 +4782,18 @@ void CodeGenTileLangCUDA::VisitStmt_(const BufferStoreNode *op) { return; } - // Check if this is a fp4 packed buffer access - auto packed_it = fp4_packed_buffers_.find(buffer_var.get()); - if (packed_it != fp4_packed_buffers_.end() && value_dtype.is_scalar()) { - std::string idx_str = PrintExpr(index_expr); - std::string value = this->PrintExpr(op->value); - this->PrintIndent(); - stream << "tl_fp4_packed_store(" << packed_it->second << ", " << idx_str - << ", " << value << ");\n"; - return; - } if (value_dtype.lanes() == element_dtype.lanes()) { - // For scalar fp4 stores to non-packed buffers, use tl_fp4_packed_store - // to correctly handle nibble-level writes. The /2 in GetBufferRef maps two - // consecutive fp4 elements to the same byte, and a plain assignment - // overwrites the entire byte — destroying the neighboring nibble. - if (element_dtype.is_float4() && element_dtype.lanes() == 1) { + // Scalar FP4 stores update one nibble at the logical index. A plain + // assignment to the backing byte would overwrite the neighboring element; + // SM120 shared memory first applies the b4x16 padded-row layout. + if (IsFp4PaddedSharedStorage(buffer_var.get(), element_dtype)) { + std::string idx_str = PrintExpr(GetFp4PaddedSharedIndex(index_expr)); + std::string value = this->PrintExpr(op->value); + std::string vid = GetVarID(buffer_var.get()); + this->PrintIndent(); + stream << "tl_fp4_packed_store((fp4_e2_2_t*)" << vid << ", " << idx_str + << ", " << value << ");\n"; + } else if (IsFp4PackedStorage(buffer_var.get(), element_dtype)) { std::string idx_str = PrintExpr(index_expr); std::string value = this->PrintExpr(op->value); std::string vid = GetVarID(buffer_var.get()); diff --git a/src/backend/cuda/codegen/codegen_cuda.h b/src/backend/cuda/codegen/codegen_cuda.h index 656e7dbf3..4a2450811 100644 --- a/src/backend/cuda/codegen/codegen_cuda.h +++ b/src/backend/cuda/codegen/codegen_cuda.h @@ -80,6 +80,15 @@ class CodeGenTileLangCUDA final : public CodeGenC { void HandleVolatileLoads(const std::string &value, const BufferLoadNode *op, std::ostream &os) final; bool HandleLateIntrinsicCall(const CallNode *op, std::ostream &os); + // FP4 address helpers distinguish packed global storage from SM120 padded + // shared storage; local fragments remain semantic FP4 arrays. 
+ bool IsFp4PackedStorage(const VarNode *buffer_var, + DataType element_dtype) const; + bool IsFp4PaddedSharedStorage(const VarNode *buffer_var, + DataType element_dtype) const; + bool IsFp4SemanticLocalStorage(const VarNode *buffer_var, + DataType element_dtype) const; + PrimExpr GetFp4PaddedSharedIndex(PrimExpr index) const; // Whether scope such as "__shared__" or "__constant__" is part of type. bool IsScopePartOfType() const final { return false; } @@ -156,8 +165,6 @@ class CodeGenTileLangCUDA final : public CodeGenC { std::unordered_map fragment_layouts; std::unordered_map unroll_factor; std::optional> cluster_dims; - // Map from VarNode to packed buffer variable name for fp4 packed storage - std::unordered_map fp4_packed_buffers_; friend void PrintConst(const FloatImmNode *op, std::ostream &os, CodeGenTileLangCUDA *p); void PrintWmmaScope(const std::string &scope, DataType t, diff --git a/src/backend/cuda/op/copy.cc b/src/backend/cuda/op/copy.cc index 14cbd6348..fdab925f9 100644 --- a/src/backend/cuda/op/copy.cc +++ b/src/backend/cuda/op/copy.cc @@ -473,6 +473,11 @@ Stmt Copy::LowerCPAsync(const CopyNode &op, const LowerArgs &T, return LowerNormal(op, T, analyzer); } + bool fp4_padded_shared_copy = + TargetIsSM120(T.target) && IsGlobalBuffer(op.src) && + IsSharedBuffer(op.dst) && op.src->dtype.is_float4_e2m1fn() && + op.dst->dtype.is_float4_e2m1fn(); + auto simt_loop = op.MakeSIMTLoop(analyzer); auto fused_loop = Downcast(ParallelLoopFuser::Fuse(simt_loop)); auto par_op = ParallelOp(fused_loop); @@ -494,13 +499,15 @@ Stmt Copy::LowerCPAsync(const CopyNode &op, const LowerArgs &T, LowerParallelLoop(par_op->GetRoot(), loop_layout, T.thread_var, analyzer, T.layout_map, par_op->GetPredicate(T.thread_var)); - auto inject_result = - InjectPTXAsyncCopy(lowered_loop, /*enable_auto_async_copy=*/true, - /*async_without_async_commit_wait=*/ - no_implicit_commit_wait || GetIsAsyncCopy(op)); + auto inject_result = InjectPTXAsyncCopy( + lowered_loop, /*enable_auto_async_copy=*/true, + /*async_without_async_commit_wait=*/ + no_implicit_commit_wait || GetIsAsyncCopy(op), fp4_padded_shared_copy); Stmt cp_async_loop = inject_result.stmt; if (!inject_result.injected_ptx_async_copy) { - DLOG(WARNING) << "cp.async rewrite miss for copy src=" << op.src->name + const char *copy_kind = + fp4_padded_shared_copy ? "SM120 FP4 padded cp.async" : "cp.async"; + DLOG(WARNING) << copy_kind << " rewrite miss for copy src=" << op.src->name << " (scope=" << op.src.scope() << ", dtype=" << op.src->dtype << "), dst=" << op.dst->name << " (scope=" << op.dst.scope() << ", dtype=" << op.dst->dtype @@ -514,8 +521,9 @@ Stmt Copy::LowerCPAsync(const CopyNode &op, const LowerArgs &T, return lowered_loop; } if (explicit_async_semantics) { - LOG(FATAL) << "Explicit async copy semantics require cp.async lowering, " - "but no eligible global->shared store was rewritten."; + LOG(FATAL) << "Explicit async copy semantics require " << copy_kind + << " lowering, but no eligible global->shared store was " + "rewritten."; } DLOG(WARNING) << "Fallback to normal copy because cp.async rewrite found " "no eligible global->shared store."; @@ -562,6 +570,10 @@ Stmt Copy::LowerLDSM(const CopyNode &op, const LowerArgs &T, Buffer shared_tensor = is_ldmatrix ? src : dst; Buffer local_tensor = is_ldmatrix ? dst : src; + bool is_fp4_ldmatrix = is_ldmatrix && shared_tensor->dtype.is_float4_e2m1fn(); + if (is_fp4_ldmatrix && !TargetIsSM120(T.target)) { + return LowerNormal(op, T, analyzer); + } Array local_region = is_ldmatrix ? 
src_range : dst_range; bool is_full_range = true; for (size_t i = 0; i < local_region.size(); i++) { @@ -611,12 +623,35 @@ Stmt Copy::LowerLDSM(const CopyNode &op, const LowerArgs &T, } else { return LowerNormal(op, T, analyzer); } - if (shared_tensor->dtype.bytes() != 2) { + if (is_fp4_ldmatrix && is_transposed) { return LowerNormal(op, T, analyzer); } + if (!is_fp4_ldmatrix && shared_tensor->dtype.bytes() != 2) { + return LowerNormal(op, T, analyzer); + } + + PrimExpr extent = local_tensor->shape[0]; + int num = 1; + if (is_fp4_ldmatrix) { + if (analyzer->CanProveEqual(FloorMod(extent, 16), 0)) + num = 4; + else if (analyzer->CanProveEqual(FloorMod(extent, 8), 0)) + num = 2; + } else { + if (analyzer->CanProveEqual(FloorMod(extent, 8), 0)) + num = 4; + else if (analyzer->CanProveEqual(FloorMod(extent, 4), 0)) + num = 2; + } + int elems_per_reg = is_fp4_ldmatrix ? 4 : 2; + int elems_per_inst = elems_per_reg * num; + // FP4 b4x16 ldmatrix returns four logical FP4 elements per 32-bit register, + // while the existing b16 path returns two 16-bit elements per register. + PrimExpr flattened_indice = shared_tensor.OffsetOf(shared_indices).back(); if (!IndicesCanVectorize(flattened_indice, loop_vars.back()->var, - loop_vars.back()->dom->extent, 8, analyzer)) { + loop_vars.back()->dom->extent, elems_per_inst, + analyzer)) { return LowerNormal(op, T, analyzer); } @@ -626,13 +661,6 @@ Stmt Copy::LowerLDSM(const CopyNode &op, const LowerArgs &T, return LowerNormal(op, T, analyzer); } - PrimExpr extent = local_tensor->shape[0]; - int num = 1; - if (analyzer->CanProveEqual(FloorMod(extent, 8), 0)) - num = 4; - else if (analyzer->CanProveEqual(FloorMod(extent, 4), 0)) - num = 2; - Array args; const Op ©_op = is_ldmatrix ? tl::ptx_ldmatrix() : tl::ptx_stmatrix(); args.push_back(static_cast(is_transposed)); @@ -644,7 +672,8 @@ Stmt Copy::LowerLDSM(const CopyNode &op, const LowerArgs &T, PrimExpr warp = FloorDiv(T.thread_var, 32) * 32; if (!is_transposed) { auto local_index = analyzer->Simplify( - local_iter * 2 * num + 2 * FloorMod(FloorDiv(T.thread_var, 8), num)); + local_iter * elems_per_inst + + elems_per_reg * FloorMod(FloorDiv(T.thread_var, 8), num)); auto thread_index = analyzer->Simplify(warp + FloorMod(T.thread_var, 8) * 4); shared_coords = inv->Forward({local_index, thread_index}); @@ -659,7 +688,7 @@ Stmt Copy::LowerLDSM(const CopyNode &op, const LowerArgs &T, shared_coords.pop_back(); PrimExpr shared_addr = Call(DataType::Handle(), tl::access_ptr(), - {BufferLoad(shared_tensor, shared_coords), PrimExpr(2 * num), + {BufferLoad(shared_tensor, shared_coords), PrimExpr(elems_per_inst), make_const(DataType::Int(32), is_ldmatrix ? 
1 : 2)}); args.push_back(shared_addr); @@ -669,8 +698,8 @@ Stmt Copy::LowerLDSM(const CopyNode &op, const LowerArgs &T, } PrimExpr local_addr = Call(DataType::Handle(), tl::access_ptr(), - {BufferLoad(local_tensor, {local_iter * 2 * num}), - PrimExpr(2 * num), make_const(DataType::Int(32), 2)}); + {BufferLoad(local_tensor, {local_iter * elems_per_inst}), + PrimExpr(elems_per_inst), make_const(DataType::Int(32), 2)}); args.push_back(local_addr); } else { for (int i = 0; i < num; i++) { @@ -689,8 +718,8 @@ Stmt Copy::LowerLDSM(const CopyNode &op, const LowerArgs &T, } auto body = Evaluate(Call(DataType::Handle(), copy_op, args)); - For for_node = - For(local_iter, 0, FloorDiv(extent, 2 * num), ForKind::kSerial, body); + For for_node = For(local_iter, 0, FloorDiv(extent, elems_per_inst), + ForKind::kSerial, body); for_node = PragmaUnrollLoop(for_node); auto range = T.thread_bounds; if (range.defined()) { diff --git a/src/backend/cuda/op/copy_analysis.cc b/src/backend/cuda/op/copy_analysis.cc index d43aa5539..3bd3ba587 100644 --- a/src/backend/cuda/op/copy_analysis.cc +++ b/src/backend/cuda/op/copy_analysis.cc @@ -258,11 +258,20 @@ bool CheckBulkStore1D(const CopyNode &op, Target target, } bool CheckLDSMCopy(const CopyNode &op, Target target) { - return TargetHasLdmatrix(target) && IsSharedBuffer(op.src) && - IsFragmentBuffer(op.dst); + if (!TargetHasLdmatrix(target) || !IsSharedBuffer(op.src) || + !IsFragmentBuffer(op.dst)) { + return false; + } + if (op.src->dtype.is_float4_e2m1fn() || op.dst->dtype.is_float4_e2m1fn()) { + return TargetIsSM120(target) && op.src->dtype == op.dst->dtype; + } + return true; } bool CheckSTSMCopy(const CopyNode &op, Target target) { + if (op.src->dtype.is_float4_e2m1fn() || op.dst->dtype.is_float4_e2m1fn()) { + return false; + } return TargetHasStmatrix(target) && IsFragmentBuffer(op.src) && IsSharedBuffer(op.dst); } diff --git a/src/tl_templates/cuda/cuda_fp4.h b/src/tl_templates/cuda/cuda_fp4.h index d2efef24a..30d682b9b 100644 --- a/src/tl_templates/cuda/cuda_fp4.h +++ b/src/tl_templates/cuda/cuda_fp4.h @@ -2,7 +2,8 @@ #include "common.h" -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) || \ + (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 1200)) #include // Wrapper for __nv_fp4_e2m1 with implicit conversions @@ -161,6 +162,30 @@ TL_DEVICE fp4_e2_32_t make_fp4_e2_32_t( return result; } +// Pack sixty-four fp4_e2_t values. +template +TL_DEVICE fp4_e2_64_t make_fp4_e2_64_t(Args... 
args) { + static_assert(sizeof...(Args) == 64, + "make_fp4_e2_64_t expects 64 fp4 values"); + fp4_e2_t values[64] = {fp4_e2_t(args)...}; + fp4_e2_64_t result; + result.x = make_fp4_e2_32_t( + values[0], values[1], values[2], values[3], values[4], values[5], + values[6], values[7], values[8], values[9], values[10], values[11], + values[12], values[13], values[14], values[15], values[16], values[17], + values[18], values[19], values[20], values[21], values[22], values[23], + values[24], values[25], values[26], values[27], values[28], values[29], + values[30], values[31]); + result.y = make_fp4_e2_32_t( + values[32], values[33], values[34], values[35], values[36], values[37], + values[38], values[39], values[40], values[41], values[42], values[43], + values[44], values[45], values[46], values[47], values[48], values[49], + values[50], values[51], values[52], values[53], values[54], values[55], + values[56], values[57], values[58], values[59], values[60], values[61], + values[62], values[63]); + return result; +} + // ============================================================================ // FP4 <-> Half Precision Conversions // ============================================================================ diff --git a/src/tl_templates/cuda/gemm_mma.h b/src/tl_templates/cuda/gemm_mma.h index b73369ecf..02286b4cf 100644 --- a/src/tl_templates/cuda/gemm_mma.h +++ b/src/tl_templates/cuda/gemm_mma.h @@ -42,9 +42,18 @@ using _X = Underscore; #ifdef __CUDA_ARCH_LIST__ #if __CUDA_ARCH_LIST__ >= 1200 +#include "cuda_fp4.h" #include "cuda_fp8.h" #include #include +namespace tl { +template <> struct to_cute_type { + using type = cute::float_e2m1_t; +}; +} // namespace tl +TL_DISPATCH_MMA_TEMPLATE(fp4_e2_t, fp4_e2_t, float, SM120_16x8x32_TN) +TL_DISPATCH_MMA_TEMPLATE(fp8_e4_t, fp4_e2_t, float, SM120_16x8x32_TN) +TL_DISPATCH_MMA_TEMPLATE(fp4_e2_t, fp8_e4_t, float, SM120_16x8x32_TN) TL_DISPATCH_MMA_TEMPLATE(fp8_e4_t, fp8_e4_t, float, SM120_16x8x32_TN) TL_DISPATCH_MMA_TEMPLATE(fp8_e5_t, fp8_e5_t, float, SM120_16x8x32_TN) TL_DISPATCH_MMA(half_t, half_t, half_t, SM80_16x8x16_F16F16F16F16_TN) diff --git a/src/tl_templates/cuda/instruction/mma.h b/src/tl_templates/cuda/instruction/mma.h index c4a276f3a..31bbcb8d9 100644 --- a/src/tl_templates/cuda/instruction/mma.h +++ b/src/tl_templates/cuda/instruction/mma.h @@ -1,6 +1,7 @@ #pragma once #include "../common.h" +#include #include #include @@ -146,6 +147,20 @@ TL_DEFINE_MMA_DISPATCHER(kTensorFloat32, kTensorFloat32, kFloat32, 16, 8, 8, TL_DEFINE_MMA_DISPATCHER(kFloat64, kFloat64, kFloat64, 8, 8, 4, false, true, false, cute::SM80_8x8x4_F64F64F64F64_TN) +// SM120 FP4/F8F6F4 inputs (k32) +using SM120_FP4_FP4_F32_TN = + cute::SM120_16x8x32_TN; +using SM120_FP8_FP4_F32_TN = + cute::SM120_16x8x32_TN; +using SM120_FP4_FP8_F32_TN = + cute::SM120_16x8x32_TN; +TL_DEFINE_MMA_DISPATCHER(kFloat4_e2m1fn, kFloat4_e2m1fn, kFloat32, 16, 8, 32, + false, true, false, SM120_FP4_FP4_F32_TN) +TL_DEFINE_MMA_DISPATCHER(kFloat8_e4m3, kFloat4_e2m1fn, kFloat32, 16, 8, 32, + false, true, false, SM120_FP8_FP4_F32_TN) +TL_DEFINE_MMA_DISPATCHER(kFloat4_e2m1fn, kFloat8_e4m3, kFloat32, 16, 8, 32, + false, true, false, SM120_FP4_FP8_F32_TN) + #undef TL_DEFINE_MMA_DISPATCHER } // namespace detail @@ -163,7 +178,37 @@ TL_DEVICE void mma_sync( TransB, Saturate>; static_assert(!std::is_void_v, "tl::mma_sync: unsupported configuration"); - Dispatcher::exec(c, a, b, c); + if constexpr (AType == DataType::kFloat4_e2m1fn || + BType == DataType::kFloat4_e2m1fn) { + // SM120 f8f6f4 MMA expects FP4 
operands in the same register placement as + // CuTe's b4x16 load path. Shift only FP4 operands; mixed FP8 operands keep + // their native register bits. + using AReg = typename Dispatcher::ARegType; + using BReg = typename Dispatcher::BRegType; + constexpr int nA = detail::MmaImplTraits::kARegs; + constexpr int nB = detail::MmaImplTraits::kBRegs; + AReg as[nA]; + BReg bs[nB]; +#pragma unroll + for (int i = 0; i < nA; ++i) { + if constexpr (AType == DataType::kFloat4_e2m1fn) { + as[i] = a[i] << 2; + } else { + as[i] = a[i]; + } + } +#pragma unroll + for (int i = 0; i < nB; ++i) { + if constexpr (BType == DataType::kFloat4_e2m1fn) { + bs[i] = b[i] << 2; + } else { + bs[i] = b[i]; + } + } + Dispatcher::exec(c, as, bs, c); + } else { + Dispatcher::exec(c, a, b, c); + } } } // namespace tl diff --git a/src/tl_templates/cuda/ldsm.h b/src/tl_templates/cuda/ldsm.h index a20746dff..d2f97989a 100644 --- a/src/tl_templates/cuda/ldsm.h +++ b/src/tl_templates/cuda/ldsm.h @@ -32,6 +32,50 @@ TL_DEVICE void ptx_ldmatrix_x4(void const *const smem_ptr, : "r"(smem_int_ptr)); } +TL_DEVICE void ptx_ldmatrix_b4x16_x1(void const *const smem_ptr, + void *const local_ptr) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1200) + uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); + int32_t *value = reinterpret_cast(local_ptr); + asm volatile( + "ldmatrix.sync.aligned.m8n16.x1.shared.b8x16.b4x16_p64 {%0}, [%1];\n" + : "=r"(value[0]) + : "r"(smem_int_ptr)); +#else + TILELANG_UNREACHABLE("b4x16 ldmatrix requires SM120+"); +#endif +} + +TL_DEVICE void ptx_ldmatrix_b4x16_x2(void const *const smem_ptr, + void *const local_ptr) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1200) + uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); + int32_t *value = reinterpret_cast(local_ptr); + asm volatile( + "ldmatrix.sync.aligned.m8n16.x2.shared.b8x16.b4x16_p64 {%0, %1}, " + "[%2];\n" + : "=r"(value[0]), "=r"(value[1]) + : "r"(smem_int_ptr)); +#else + TILELANG_UNREACHABLE("b4x16 ldmatrix requires SM120+"); +#endif +} + +TL_DEVICE void ptx_ldmatrix_b4x16_x4(void const *const smem_ptr, + void *const local_ptr) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1200) + uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); + int32_t *value = reinterpret_cast(local_ptr); + asm volatile( + "ldmatrix.sync.aligned.m8n16.x4.shared.b8x16.b4x16_p64 {%0, %1, %2, " + "%3}, [%4];\n" + : "=r"(value[0]), "=r"(value[1]), "=r"(value[2]), "=r"(value[3]) + : "r"(smem_int_ptr)); +#else + TILELANG_UNREACHABLE("b4x16 ldmatrix requires SM120+"); +#endif +} + TL_DEVICE void ptx_ldmatrix_x1_trans(void const *const smem_ptr, void *const local_ptr) { uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); diff --git a/src/transform/lower_ptx_async_copy.cc b/src/transform/lower_ptx_async_copy.cc index fc90b9ad4..5ce6f73ed 100644 --- a/src/transform/lower_ptx_async_copy.cc +++ b/src/transform/lower_ptx_async_copy.cc @@ -32,9 +32,11 @@ using namespace tir; class PTXAsyncCopyInjector : public StmtMutator { public: explicit PTXAsyncCopyInjector(bool enable_auto_async_copy, - bool async_without_async_commit_wait) + bool async_without_async_commit_wait, + bool fp4_padded_shared_copy) : enable_auto_async_copy_(enable_auto_async_copy), - async_without_async_commit_wait_(async_without_async_commit_wait) {} + async_without_async_commit_wait_(async_without_async_commit_wait), + fp4_padded_shared_copy_(fp4_padded_shared_copy) {} bool InjectedPTXAsyncCopy() const { return injected_ptx_async_copy_; } @@ -111,6 +113,16 @@ class PTXAsyncCopyInjector : public StmtMutator { 
return Optional(); } + if (fp4_padded_shared_copy_ && IsFp4GlobalToSharedCopy(load, store)) { + // FP4 shared destinations on SM120 are padded for b4x16_p64 ldmatrix. + // Emit 16-FP4/8B segments instead of a single contiguous byte span. + Optional fp4_cp_async = MakeFp4PaddedCPAsyncStmt( + load, store, *index_info, predicated, predicate_value); + if (fp4_cp_async.defined()) { + return fp4_cp_async; + } + } + if (index_info->index_lanes == 1) { if (current_vectorized_lanes_ > 1 && !HasContiguousVectorizedOffsets(index_info->src_index, @@ -507,6 +519,69 @@ class PTXAsyncCopyInjector : public StmtMutator { Call(store->buffer->dtype, tvm::tl::ptx_cp_async(), cp_async_args)); } + static PrimExpr GetFp4PaddedSharedIndex(PrimExpr index) { + arith::Analyzer analyzer; + return analyzer.Simplify(truncdiv(index, 16) * 32 + truncmod(index, 16)); + } + + static bool IsFp4GlobalToSharedCopy(const BufferLoadNode *load, + const BufferStoreNode *store) { + return IsGlobalBuffer(load->buffer) && IsSharedBuffer(store->buffer) && + load->buffer->dtype.is_float4_e2m1fn() && + store->buffer->dtype.is_float4_e2m1fn() && + load->buffer->dtype == store->buffer->dtype; + } + + Optional MakeFp4PaddedCPAsyncStmt(const BufferLoadNode *load, + const BufferStoreNode *store, + const CopyIndexInfo &index_info, + bool predicated, + const PrimExpr &predicate_value) { + if (current_vectorized_lanes_ != 1 || + (index_info.per_access_num_elems != 16 && + index_info.per_access_num_elems != 32)) { + return Optional(); + } + + PrimExpr src_base = ExtractVectorBase(index_info.src_index); + PrimExpr dst_base = ExtractVectorBase(index_info.dst_index); + if (!src_base.defined() || !dst_base.defined()) { + return Optional(); + } + + Buffer flat_src = load->buffer.GetFlattenedBuffer(); + Buffer flat_dst = store->buffer.GetFlattenedBuffer(); + + Array stmts; + int num_segments = index_info.per_access_num_elems / 16; + for (int segment = 0; segment < num_segments; ++segment) { + // Source FP4 is packed contiguously, while destination FP4 follows the + // padded shared layout. Each segment copies 16 FP4 values, i.e. 8 bytes. 
+ PrimExpr src_logical_offset = + IntImm(src_base.dtype(), static_cast(segment * 16)); + PrimExpr dst_logical_offset = + IntImm(dst_base.dtype(), static_cast(segment * 16)); + PrimExpr src_index = analyzer_.Simplify(src_base + src_logical_offset); + PrimExpr dst_logical_index = + analyzer_.Simplify(dst_base + dst_logical_offset); + PrimExpr dst_byte_index = analyzer_.Simplify( + truncdiv(GetFp4PaddedSharedIndex(dst_logical_index), 2)); + + BufferLoad src_base_load = BufferLoad(flat_src, {src_index}); + BufferLoad dst_base_load = BufferLoad(flat_dst, {dst_byte_index}); + Optional cp_async = MakeCPAsyncStmtFromLoads( + store, dst_base_load, src_base_load, /*num_elems=*/16, predicated, + predicate_value); + ICHECK(cp_async.defined()); + stmts.push_back(cp_async.value()); + } + + if (stmts.size() == 1) { + return stmts[0]; + } + return SeqStmt(stmts); + } + static Stmt MakeCommitGroupStmt() { return Evaluate(Call(DataType::Handle(), builtin::ptx_commit_group(), {})); } @@ -687,6 +762,7 @@ class PTXAsyncCopyInjector : public StmtMutator { bool enable_auto_async_copy_{true}; bool async_without_async_commit_wait_{false}; + bool fp4_padded_shared_copy_{false}; int explicit_async_scope_depth_{0}; int current_vectorized_lanes_{1}; std::vector active_vectorized_loops_; @@ -700,9 +776,11 @@ using namespace tir::transform; PTXAsyncCopyInjectResult InjectPTXAsyncCopy(const Stmt &body, bool enable_auto_async_copy, - bool async_without_async_commit_wait) { + bool async_without_async_commit_wait, + bool fp4_padded_shared_copy) { PTXAsyncCopyInjector injector(enable_auto_async_copy, - async_without_async_commit_wait); + async_without_async_commit_wait, + fp4_padded_shared_copy); Stmt injected = injector(body); return {injector.Finalize(injected), injector.InjectedPTXAsyncCopy()}; } diff --git a/src/transform/ptx_async_copy_injector.h b/src/transform/ptx_async_copy_injector.h index 80c642562..e51727da3 100644 --- a/src/transform/ptx_async_copy_injector.h +++ b/src/transform/ptx_async_copy_injector.h @@ -18,7 +18,8 @@ struct PTXAsyncCopyInjectResult { */ PTXAsyncCopyInjectResult InjectPTXAsyncCopy(const tvm::tir::Stmt &body, bool enable_auto_async_copy, - bool async_without_async_commit_wait = false); + bool async_without_async_commit_wait = false, + bool fp4_padded_shared_copy = false); } // namespace tl } // namespace tvm diff --git a/tilelang/cuda/intrinsics/layout/mma_layout.py b/tilelang/cuda/intrinsics/layout/mma_layout.py index 2eb575f0c..aa226f8f0 100644 --- a/tilelang/cuda/intrinsics/layout/mma_layout.py +++ b/tilelang/cuda/intrinsics/layout/mma_layout.py @@ -39,6 +39,20 @@ def ldmatrix_32x16_to_shared_16x32_layout_b(thread_id, local_id): return row, col +# FP4 keeps logical element coordinates here; CUDA lowering maps the shared +# address through the SM120 b4x16 padded-byte layout. 
+def ldmatrix_32x16_to_shared_16x32_fp4_layout_a(thread_id, local_id): + row = thread_id % 16 + col = local_id + (thread_id // 16) * 16 + return row, col + + +def ldmatrix_32x16_to_shared_16x32_fp4_layout_b(thread_id, local_id): + row = (thread_id // 16) * 8 + (thread_id % 8) + col = local_id + 16 * ((thread_id % 16) // 8) + return row, col + + def mma_store_32x8_to_shared_16x16_layout(thread_id, local_id): row = 8 * (local_id % 4 // 2) + (thread_id // 4) col = 8 * (local_id // 4) + (thread_id % 4) * 2 + (local_id % 2) diff --git a/tilelang/cuda/intrinsics/layout/utils.py b/tilelang/cuda/intrinsics/layout/utils.py index 050ad0932..c565f52de 100644 --- a/tilelang/cuda/intrinsics/layout/utils.py +++ b/tilelang/cuda/intrinsics/layout/utils.py @@ -7,6 +7,8 @@ ldmatrix_trans_32x8_to_shared_16x16_layout, ldmatrix_32x16_to_shared_16x32_layout_a, ldmatrix_32x16_to_shared_16x32_layout_b, + ldmatrix_32x16_to_shared_16x32_fp4_layout_a, + ldmatrix_32x16_to_shared_16x32_fp4_layout_b, mma_store_32x8_to_shared_16x16_layout, mma_store_32x2_to_shared_8x8_layout_fp64, ) @@ -22,11 +24,13 @@ def get_ldmatrix_offset( row_idx, col_idx, stride, - dtype: Literal["float16", "int8", "int4"] = "float16", + dtype: Literal["float16", "int8", "int4", "float4_e2m1fn"] = "float16", transposed: bool = False, ): assert matrix in ["A", "B"], "matrix should be either A or B" - dtype_bits = DataType(dtype).bits + dtype = DataType(dtype) + dtype_bits = dtype.bits + is_fp4_e2m1fn = dtype_bits == 4 and str(dtype) == "float4_e2m1fn" if dtype_bits == 32: if matrix == "B" and transposed: transform_func = ldmatrix_32x4_to_shared_16x8_layout_b @@ -47,6 +51,19 @@ def get_ldmatrix_offset( else: new_row_idx, new_col_idx = transform_func(row_idx, col_idx) return new_row_idx, new_col_idx + elif is_fp4_e2m1fn: + # FP4 uses the SM120 b4x16_p64 layout. Keep int4/uint4 on the generic + # sub-byte path below; they have different packed-byte offsets. 
+ if matrix == "B" and transposed: + transform_func = ldmatrix_32x16_to_shared_16x32_fp4_layout_b + new_row_idx, new_col_idx = transform_func(row_idx, col_idx) + return new_row_idx, new_col_idx + elif matrix == "A" and not transposed: + transform_func = ldmatrix_32x16_to_shared_16x32_fp4_layout_a + new_row_idx, new_col_idx = transform_func(row_idx, col_idx) + return new_row_idx, new_col_idx + else: + raise ValueError("ldmatrix only supports B transposed and A non-transposed for float4_e2m1fn") elif dtype_bits <= 8: if matrix == "B" and transposed: transform_func = ldmatrix_32x16_to_shared_16x32_layout_b diff --git a/tilelang/cuda/intrinsics/macro/mma_macro_generator.py b/tilelang/cuda/intrinsics/macro/mma_macro_generator.py index 6461dbd88..139a657d2 100644 --- a/tilelang/cuda/intrinsics/macro/mma_macro_generator.py +++ b/tilelang/cuda/intrinsics/macro/mma_macro_generator.py @@ -118,7 +118,12 @@ def __init__( def _initialize_k_dim(self, a_dtype=T.float16): if isinstance(a_dtype, str): a_dtype = DataType(a_dtype) - self.k_dim = min(256 // a_dtype.bits, self.chunk) + if str(a_dtype) == "float4_e2m1fn": + if self.chunk < 32: + raise ValueError(f"float4_e2m1fn MMA requires chunk >= 32, got chunk={self.chunk}") + self.k_dim = 32 + else: + self.k_dim = min(256 // a_dtype.bits, self.chunk) def _initialize_local_size(self, m_dim=16, n_dim=16, k_dim=16, warp_size=32): self.local_size_a = (m_dim * k_dim) // warp_size @@ -295,18 +300,19 @@ def _warp_ld_a_fp64( micro_size_k = self.micro_size_k local_size_a = self.local_size_a a_transposed = self.a_transposed + a_dtype_bits = DataType(a_dtype).bits # ldmatrix cannot be used for int8 + trans case. - ldmatrix_available = not (DataType(a_dtype).bits != 16 and a_transposed) + ldmatrix_available = not (a_dtype_bits != 16 and a_transposed) def mma_load_layout(i, j): return i, j if not ldmatrix_available: - if DataType(a_dtype).bits == 8: + if a_dtype_bits in (4, 8): mma_load_layout = mma_load_a_32x16_to_shared_16x32_layout - elif DataType(a_dtype).bits == 16: + elif a_dtype_bits == 16: mma_load_layout = mma_load_a_32x8_to_shared_16x16_layout - elif DataType(a_dtype).bits == 32: + elif a_dtype_bits == 32: mma_load_layout = mma_load_a_32x4_to_shared_16x8_layout else: raise ValueError(f"Unsupported dtype: {a_dtype}") @@ -337,6 +343,9 @@ def _warp_ldmatrix_a( wi, wk = warp_m * warp_row_tiles + i * micro_size_x, rk * chunk + ki * micro_size_k if ldmatrix_available: + num = 4 + is_fp4 = str(DataType(a_dtype)) == "float4_e2m1fn" + access_extent = 4 * num if is_fp4 else 2 * num row_off, col_off = get_ldmatrix_offset("A", tx, 0, stride, a_dtype, a_transposed) src_indices = ( tuple(A_other) + (A_base0 + wk + row_off, A_base1 + wi + col_off) @@ -345,9 +354,9 @@ def _warp_ldmatrix_a( ) T.ptx_ldmatrix( T.bool(trans), - 4, - T.access_ptr(A_buf[src_indices], "r", extent=8), - T.access_ptr(A_local_buf[i * local_size_a], "w", extent=8), + num, + T.access_ptr(A_buf[src_indices], "r", extent=access_extent), + T.access_ptr(A_local_buf[i * local_size_a], "w", extent=access_extent), ) else: for j in T.serial(local_size_a): @@ -407,6 +416,7 @@ def _warp_ld_b_fp64( micro_size_k = self.micro_size_k local_size_b = self.local_size_b b_transposed = self.b_transposed + b_dtype_bits = DataType(b_dtype).bits thread_binding = self.get_thread_binding() # legalize shared buffer to region @@ -418,17 +428,17 @@ def _warp_ld_b_fp64( B_stride_last = B_buf.shape[-1] replicate_b = self.n_dim == 16 # ldmatrix cannot be used for int8 + trans case. 
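+ # Non-16-bit B operands (including fp4 and int8) can use ldmatrix only when transposed; the non-transposed case falls back to per-element loads through mma_load_layout.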
- ldmatrix_available = not (DataType(b_dtype).bits != 16 and not b_transposed) + ldmatrix_available = not (b_dtype_bits != 16 and not b_transposed) def mma_load_layout(i, j): return i, j if not ldmatrix_available: - if DataType(b_dtype).bits == 8: + if b_dtype_bits in (4, 8): mma_load_layout = mma_load_b_32x16_to_shared_16x32_layout - elif DataType(b_dtype).bits == 16: + elif b_dtype_bits == 16: mma_load_layout = mma_load_b_32x8_to_shared_16x16_layout - elif DataType(b_dtype).bits == 32: + elif b_dtype_bits == 32: mma_load_layout = mma_load_b_32x4_to_shared_16x8_layout else: raise ValueError(f"Unsupported dtype: {b_dtype}") @@ -454,6 +464,8 @@ def _warp_ldmatrix_b( if ldmatrix_available: num = 4 if replicate_b else 2 + is_fp4 = str(DataType(b_dtype)) == "float4_e2m1fn" + access_extent = 4 * num if is_fp4 else 2 * num row_off, col_off = get_ldmatrix_offset("B", tx, 0, stride, b_dtype, b_transposed) src_indices = ( tuple(B_other) + (B_base0 + wi + row_off, B_base1 + wk + col_off) @@ -463,8 +475,8 @@ def _warp_ldmatrix_b( T.ptx_ldmatrix( T.bool(trans), num, - T.access_ptr(B_buf[src_indices], "r", extent=2 * num), - T.access_ptr(B_local_buf[i * local_size_b], "w", extent=2 * num), + T.access_ptr(B_buf[src_indices], "r", extent=access_extent), + T.access_ptr(B_local_buf[i * local_size_b], "w", extent=access_extent), ) else: @@ -637,7 +649,7 @@ def make_mma_load_layout(self, local_buf: Buffer, matrix: Literal["A", "B"] = "A elif dtype_bits == 16: transform_func_sr_a = shared_16x16_to_mma_32x8_layout_sr_a transform_func_sr_b = shared_16x16_to_mma_32x8_layout_sr_b - elif dtype_bits == 8: + elif dtype_bits in (4, 8): transform_func_sr_a = shared_16x32_to_mma_32x16_layout_sr_a transform_func_sr_b = shared_16x32_to_mma_32x16_layout_sr_b else: @@ -861,7 +873,10 @@ def __init__( self._initialize_transform_kind(transform_kind_a, transform_kind_b) def _initialize_k_dim(self, a_dtype=T.float16): - self.k_dim = 256 // DataType(a_dtype).bits + if str(DataType(a_dtype)) == "float4_e2m1fn": + self.k_dim = 32 + else: + self.k_dim = 256 // DataType(a_dtype).bits def _initialize_local_size(self, m_dim=16, n_dim=16, k_dim=16, warp_size=32): self.local_size_a = (m_dim * k_dim) // warp_size diff --git a/tilelang/cuda/op/gemm/gemm_mma.py b/tilelang/cuda/op/gemm/gemm_mma.py index bd572c407..5be8edc4d 100644 --- a/tilelang/cuda/op/gemm/gemm_mma.py +++ b/tilelang/cuda/op/gemm/gemm_mma.py @@ -18,13 +18,35 @@ class GemmMMA(GemmBase): + @staticmethod + def _is_fp8_e4m3(dtype: str) -> bool: + return str(dtype) in {"float8_e4m3", "float8_e4m3fn", "float8_e4m3fnuz"} + + @staticmethod + def _is_fp4_e2m1(dtype: str) -> bool: + return str(dtype) == "float4_e2m1fn" + + def _validate_mma_dtypes(self): + a_dtype = str(self.A.dtype) + b_dtype = str(self.B.dtype) + if a_dtype == b_dtype: + return + # Mixed A8W4 paths are selected only from semantic dtypes. Packed host + # storage such as uint8 is not treated as an FP4 GEMM dtype. 
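+ # For example, A=float8_e4m3fn with B=float4_e2m1fn is accepted, while A=float8_e4m3fn with B=uint8 (pre-packed FP4 storage) is rejected below.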
+ mixed_fp8_fp4 = (self._is_fp8_e4m3(a_dtype) and self._is_fp4_e2m1(b_dtype)) or ( + self._is_fp4_e2m1(a_dtype) and self._is_fp8_e4m3(b_dtype) + ) + if not mixed_fp8_fp4: + raise AssertionError(f"Unsupported mixed MMA dtypes: A={a_dtype}, B={b_dtype}") + def _make_mma_emitter(self, target: Target, thread_nums: int, thread_var: tir.Var | None = None): + self._validate_mma_dtypes() m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, GEMM_INST_MMA) warp_row_tiles = int(self.M // m_warp) warp_col_tiles = int(self.N // n_warp) emitter = TensorCoreIntrinEmitter( - a_dtype=self.in_dtype, - b_dtype=self.in_dtype, + a_dtype=self.A.dtype, + b_dtype=self.B.dtype, accum_dtype=self.accum_dtype, a_transposed=self.trans_A, b_transposed=self.trans_B, @@ -77,7 +99,8 @@ def lower( thread_nums = thread_bounds.extent mma_emitter = self._make_mma_emitter(target, thread_nums, thread_var=thread_var) - in_dtype = self.in_dtype + a_dtype = self.A.dtype + b_dtype = self.B.dtype warp_rows = mma_emitter.warp_rows warp_cols = mma_emitter.warp_cols local_size_a = mma_emitter.local_size_a @@ -109,8 +132,8 @@ def _gemm_ssr() -> None: B_shared into local fragments, then issues Tensor Core mma ops, accumulating into C_local. """ - A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) - B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) + A_local = T.alloc_local((warp_rows * local_size_a), a_dtype) + B_local = T.alloc_local((warp_cols * local_size_b), b_dtype) if clear_accum: T.clear(C_buf) for ki in T.serial(0, (block_K // micro_size_k)): @@ -144,7 +167,7 @@ def _gemm_srr() -> None: B_shared into local fragments, then issues Tensor Core mma ops, accumulating into C_local. """ - A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) + A_local = T.alloc_local((warp_rows * local_size_a), a_dtype) for ki in T.serial(0, (block_K // micro_size_k)): if clear_accum: @@ -174,7 +197,7 @@ def _gemm_rsr() -> None: B_shared into local fragments, then issues Tensor Core mma ops, accumulating into C_local. """ - B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) + B_local = T.alloc_local((warp_cols * local_size_b), b_dtype) if clear_accum: T.clear(C_buf) for ki in T.serial(0, (block_K // micro_size_k)):
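For reference, the per-segment addressing computed by the FP4 padded-shared cp.async path in ptx_async_copy_injector.cc above reduces to the sketch below. pad_fp4_index is a hypothetical stand-in for GetFp4PaddedSharedIndex, whose actual padding formula is defined elsewhere in this patch; everything else mirrors the arithmetic shown in that hunk.

def fp4_cp_async_offsets(src_base, dst_base, num_segments, pad_fp4_index):
    # Each cp.async segment moves 16 logical FP4 elements (64 bits = 8 bytes).
    for segment in range(num_segments):
        src_index = src_base + segment * 16          # logical FP4 element index on the source side
        dst_logical = dst_base + segment * 16        # logical FP4 element index in shared memory
        dst_byte = pad_fp4_index(dst_logical) // 2   # padding first, then two FP4 values per byte
        yield src_index, dst_byte

Note that the halving happens after the padding step: the padded index is still counted in FP4 elements, and only the final value is converted into a byte address.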