diff --git a/examples/maca/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py b/examples/maca/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py index 18467a81..07f0bea6 100644 --- a/examples/maca/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py +++ b/examples/maca/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py @@ -5,6 +5,7 @@ import tilelang import tilelang.language as T from tilelang.utils.tensor import map_torch_type +from tilelang.utils.target import determine_target, target_is_maca tilelang.testing.set_random_seed(42) @@ -30,6 +31,7 @@ def tl_gemm( group_size = 128 block_M = 128 block_K = 128 + num_stages = 1 if target_is_maca(determine_target("auto", return_object=True)) else 4 A_shape = (M, K) Scales_A_shape = (M, T.ceildiv(K, group_size)) @@ -50,7 +52,6 @@ def main( with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) B_shared = T.alloc_shared(B_shared_shape, in_dtype) - C_shared = T.alloc_shared(C_shared_shape, out_dtype) Scale_C_shared = T.alloc_shared((block_M), T.float32) C_local = T.alloc_fragment(C_shared_shape, accum_dtype) C_local_accum = T.alloc_fragment(C_shared_shape, accum_dtype) @@ -61,7 +62,7 @@ def main( T.clear(C_local) T.clear(C_local_accum) K_iters = T.ceildiv(K, block_K) - for k in T.Pipelined(K_iters, num_stages=4): + for k in T.Pipelined(K_iters, num_stages=num_stages): # Load A into shared memory T.copy(A[by * block_M, k * block_K], A_shared) # Load B into shared memory @@ -76,9 +77,7 @@ def main( for i, j in T.Parallel(block_M, block_N): C_local_accum[i, j] += C_local[i, j] * Scale_C_shared[i] T.clear(C_local) - # TMA store - T.copy(C_local_accum, C_shared) - T.copy(C_shared, C[by * block_M, bx * block_N]) + T.copy(C_local_accum, C[by * block_M, bx * block_N]) return main @@ -122,13 +121,20 @@ def ref_deepgemm_fp8(A_fp8, B_fp8, A_scale, B_scale, out_dtype): for j in range(ceildiv(N, 128)): c_acc.zero_() for k in range(ceildiv(K, 128)): - c = torch._scaled_mm( - A_fp8[i * 128 : (i + 1) * 128, k * 128 : (k + 1) * 128], - B_fp8[j * 128 : (j + 1) * 128, k * 128 : (k + 1) * 128].T, - scale_a=A_scales[i, k].view(128, 1).contiguous(), - scale_b=B_scales[j, k].view(1, 128).contiguous(), - out_dtype=torch.bfloat16, - ) + a_tile = A_fp8[i * 128 : (i + 1) * 128, k * 128 : (k + 1) * 128] + b_tile = B_fp8[j * 128 : (j + 1) * 128, k * 128 : (k + 1) * 128] + scale_a = A_scales[i, k].view(128, 1).contiguous() + scale_b = B_scales[j, k].view(128, 1).contiguous() + try: + c = torch._scaled_mm( + a_tile, + b_tile.T, + scale_a=scale_a, + scale_b=scale_b.view(1, 128), + out_dtype=torch.bfloat16, + ) + except RuntimeError: + c = (a_tile.to(torch.float32) * scale_a) @ (b_tile.to(torch.float32) * scale_b).T c_acc += c.to(torch.float32) C[i * 128 : (i + 1) * 128, j * 128 : (j + 1) * 128] = c_acc.to(out_dtype) return C diff --git a/examples/maca/deepseek_deepgemm/test_example_deepgemm_fp8_2xAcc.py b/examples/maca/deepseek_deepgemm/test_example_deepgemm_fp8_2xAcc.py index 550ca9d8..d2cdeb02 100644 --- a/examples/maca/deepseek_deepgemm/test_example_deepgemm_fp8_2xAcc.py +++ b/examples/maca/deepseek_deepgemm/test_example_deepgemm_fp8_2xAcc.py @@ -3,7 +3,6 @@ from example_deepgemm_fp8_2xAcc import main -@tilelang.testing.pytest.mark.xfail def test_deepgemm_fp8_2xAcc(): main() diff --git a/examples/maca/dequantize_gemm/dequantize_utils.py b/examples/maca/dequantize_gemm/dequantize_utils.py index 90a6265f..4ad01361 100644 --- a/examples/maca/dequantize_gemm/dequantize_utils.py +++ b/examples/maca/dequantize_gemm/dequantize_utils.py @@ -1,6 +1,12 @@ import torch +def reinterpret_u16_as_bfloat16(bits: torch.Tensor) -> torch.Tensor: + bits_i32 = (bits & 0xFFFF).to(torch.int32) + bits_i16 = torch.where(bits_i32 >= 0x8000, bits_i32 - 0x10000, bits_i32).to(torch.int16) + return bits_i16.view(torch.bfloat16) + + def torch_convert_bit_twiddling(tensor): """ This function expects `tensor` to be a 2-D torch.Tensor of dtype `torch.uint8`. Each output element is produced by combining two input bytes and extracting a bf16-like 16-bit pattern according to one of four positional bit layouts (pos 0..3). The result is scaled by 2**126 to adjust the exponent bias and returned as dtype `torch.bfloat16`. @@ -45,8 +51,7 @@ def torch_convert_bit_twiddling(tensor): bf16 = torch.where(pos == 0, res0, torch.where(pos == 1, res1, torch.where(pos == 2, res2, res3))) # Convert to uint16 for .view(torch.bfloat16) - bf16_uint16 = (bf16 & 0xFFFF).to(torch.uint16) - bf16_bf16 = bf16_uint16.view(torch.bfloat16) + bf16_bf16 = reinterpret_u16_as_bfloat16(bf16) # Avoid integer overflow by using a float32 multiplier for the exponent scaling bf16_new = bf16_bf16 * (2.0**126) @@ -69,32 +74,23 @@ def torch_convert(tensor, scale_size=None, Scale=None): torch.Tensor: A new tensor of shape (N, K*2) and dtype torch.bfloat16 containing the decoded bf16 values. """ - def _convert(val, pos, scale=None): - assert val.dtype == torch.uint8 - # val = val.view(torch.int8) - mask = (1 << 4) - 1 - f4 = ((val >> (pos * 4)) & mask).to(torch.int16) - s = f4 >> 3 - e_f4 = (f4 & 6) >> 1 - e_f16 = e_f4 + 126 - if scale is not None: - e_f16 = min(e_f16 + scale, (1 << 8) - 1) - m_f4 = f4 & 1 - m_f16 = m_f4 - val_f16 = (((e_f16 | (s << 8)) << 7) | (m_f16 << 6)) & 0xFFFF - lower_16_bits = (val_f16 & 0xFFFF).to(torch.uint16) - return lower_16_bits.view(torch.bfloat16) - - N = tensor.shape[0] - K = tensor.shape[1] - new_tensor = torch.empty(N, K * 2, dtype=torch.bfloat16, device=tensor.device) - for i in range(new_tensor.shape[0]): - for j in range(new_tensor.shape[1]): - if scale_size is not None: - new_tensor[i][j] = _convert(tensor[i][j // 2], j % 2, Scale[i][j // scale_size]) - else: - new_tensor[i][j] = _convert(tensor[i][j // 2], j % 2) - return new_tensor + assert tensor.dim() == 2 and tensor.dtype == torch.uint8 + + low = (tensor & 0x0F).to(torch.int16) + high = ((tensor >> 4) & 0x0F).to(torch.int16) + f4 = torch.stack((low, high), dim=-1).reshape(tensor.shape[0], tensor.shape[1] * 2) + + sign = f4 >> 3 + exponent = ((f4 & 0x6) >> 1) + 126 + if scale_size is not None: + if Scale is None: + raise ValueError("Scale must be provided when scale_size is set") + scale_idx = torch.arange(f4.shape[1], device=tensor.device) // scale_size + exponent = torch.clamp(exponent + Scale[:, scale_idx].to(torch.int16), max=(1 << 8) - 1) + + mantissa = f4 & 0x1 + val_f16 = (((exponent | (sign << 8)) << 7) | (mantissa << 6)) & 0xFFFF + return reinterpret_u16_as_bfloat16(val_f16) def print_bit(name, val): diff --git a/examples/maca/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper.py b/examples/maca/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper.py index 0842e168..28eb79e0 100644 --- a/examples/maca/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper.py +++ b/examples/maca/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper.py @@ -4,6 +4,7 @@ from tvm import DataType from tvm import tir import torch +from tilelang.utils.target import determine_target, target_is_maca from dequantize_utils import torch_convert_bit_twiddling, torch_convert @@ -494,6 +495,12 @@ def main(m=256, n=256, k=256, scale_size=32, fast_dequant=True, with_bias=False, """ total_flops = 2 * m * n * k + if target_is_maca(determine_target("auto", return_object=True)): + fast_dequant = False + block_M, block_N, block_K, num_stages, threads, split = 64, 64, 64, 1, 128, 1 + else: + block_M, block_N, block_K, num_stages, threads, split = 256, 128, 128, 2, 256, 1 + if tune: kernel = matmul( m, n, k, T.bfloat16, T.bfloat16, T.float32, num_bits=4, scale_size=scale_size, fast_dequant=fast_dequant, with_bias=with_bias @@ -508,12 +515,12 @@ def main(m=256, n=256, k=256, scale_size=32, fast_dequant=True, with_bias=False, T.float32, num_bits=4, scale_size=scale_size, - block_M=256, - block_N=128, - block_K=128, - num_stages=2, - threads=256, - split=1, + block_M=block_M, + block_N=block_N, + block_K=block_K, + num_stages=num_stages, + threads=threads, + split=split, fast_dequant=fast_dequant, with_bias=with_bias, ) @@ -537,6 +544,11 @@ def main(m=256, n=256, k=256, scale_size=32, fast_dequant=True, with_bias=False, def run_regression_perf(m=4096, n=4096, k=4096, scale_size=32, fast_dequant=True, with_bias=False): + if target_is_maca(determine_target("auto", return_object=True)): + fast_dequant = False + block_M, block_N, block_K, num_stages, threads, split = 64, 64, 64, 1, 128, 1 + else: + block_M, block_N, block_K, num_stages, threads, split = 256, 128, 128, 2, 256, 1 kernel = matmul( m, n, @@ -546,12 +558,12 @@ def run_regression_perf(m=4096, n=4096, k=4096, scale_size=32, fast_dequant=True "float32", num_bits=4, scale_size=scale_size, - block_M=256, - block_N=128, - block_K=128, - num_stages=2, - threads=256, - split=1, + block_M=block_M, + block_N=block_N, + block_K=block_K, + num_stages=num_stages, + threads=threads, + split=split, fast_dequant=fast_dequant, with_bias=with_bias, ) diff --git a/examples/maca/dequantize_gemm/example_dequant_gemm_fp4_hopper.py b/examples/maca/dequantize_gemm/example_dequant_gemm_fp4_hopper.py index 2bdcbb06..f1ecefe6 100644 --- a/examples/maca/dequantize_gemm/example_dequant_gemm_fp4_hopper.py +++ b/examples/maca/dequantize_gemm/example_dequant_gemm_fp4_hopper.py @@ -5,6 +5,13 @@ import itertools import torch import argparse +from tilelang.utils.target import determine_target, target_is_maca + + +def reinterpret_u16_as_float16(bits: torch.Tensor) -> torch.Tensor: + bits_i32 = (bits & 0xFFFF).to(torch.int32) + bits_i16 = torch.where(bits_i32 >= 0x8000, bits_i32 - 0x10000, bits_i32).to(torch.int16) + return bits_i16.view(torch.float16) def _tir_u8_to_f4_to_f16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str): @@ -45,8 +52,7 @@ def _convert(val, pos): m_f4 = f4 & 1 m_f16 = m_f4 val_f16 = (((e_f16 | (s << 5)) << 10) | (m_f16 << 9)) & 0xFFFF - lower_16_bits = (val_f16 & 0xFFFF).to(torch.uint16) - return lower_16_bits.view(torch.float16) + return reinterpret_u16_as_float16(val_f16) N = tensor.shape[0] K = tensor.shape[1] @@ -249,8 +255,17 @@ def main(m=256, n=256, k=256, tune=False): total_flops = 2 * m * n * k if not tune: + if target_is_maca(determine_target("auto", return_object=True)): + block_M, block_N, block_K, num_stages, threads, split = 64, 64, 64, 1, 128, 1 + else: + block_M, block_N, block_K, num_stages, threads, split = 128, 128, 128, 2, 256, 1 kernel = matmul(m, n, k, T.float16, T.float16, T.float32, num_bits=4, tune=tune)( - block_M=128, block_N=128, block_K=128, num_stages=2, threads=256, split=1 + block_M=block_M, + block_N=block_N, + block_K=block_K, + num_stages=num_stages, + threads=threads, + split=split, ) profiler = kernel.get_profiler(tilelang.TensorSupplyType.Integer) profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01) @@ -271,8 +286,17 @@ def main(m=256, n=256, k=256, tune=False): def run_regression_perf(m=4096, n=4096, k=4096): + if target_is_maca(determine_target("auto", return_object=True)): + block_M, block_N, block_K, num_stages, threads, split = 64, 64, 64, 1, 128, 1 + else: + block_M, block_N, block_K, num_stages, threads, split = 128, 128, 128, 2, 256, 1 kernel = matmul(m, n, k, "float16", "float16", "float32", num_bits=4, tune=False)( - block_M=128, block_N=128, block_K=128, num_stages=2, threads=256, split=1 + block_M=block_M, + block_N=block_N, + block_K=block_K, + num_stages=num_stages, + threads=threads, + split=split, ) profiler = kernel.get_profiler(tilelang.TensorSupplyType.Integer) return profiler.do_bench(backend="cupti") diff --git a/examples/maca/dequantize_gemm/example_dequant_gemm_w4a8.py b/examples/maca/dequantize_gemm/example_dequant_gemm_w4a8.py index 2db3cd61..24dd11b6 100644 --- a/examples/maca/dequantize_gemm/example_dequant_gemm_w4a8.py +++ b/examples/maca/dequantize_gemm/example_dequant_gemm_w4a8.py @@ -87,8 +87,11 @@ def _convert(val, pos): def ref_program(A, qB): dtypeC = T.int32 B = torch_convert(qB) - C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) - C = C.to(torch.__getattribute__(dtypeC)) + # CUDA/MACA does not provide exact int32 matmul here, and float32 matmul + # followed by cast introduces many +/-1 mismatches. Compute the reference + # with exact integer accumulation on CPU instead. + C = torch.matmul(A.cpu().to(torch.int32), B.cpu().to(torch.int32).T) + C = C.to(torch.__getattribute__(dtypeC)).to(A.device) return C.transpose(0, 1) diff --git a/examples/maca/dequantize_gemm/example_dequant_gemv_fp16xint4.py b/examples/maca/dequantize_gemm/example_dequant_gemv_fp16xint4.py index b67d8165..fc8ec891 100644 --- a/examples/maca/dequantize_gemm/example_dequant_gemv_fp16xint4.py +++ b/examples/maca/dequantize_gemm/example_dequant_gemv_fp16xint4.py @@ -3,8 +3,10 @@ from typing import Optional, Callable, Any import torch from tilelang import DataType +from tilelang.utils.target import determine_target, target_is_maca from tilelang.quantize import ( _tir_packed_int_to_int_convert, + _tir_packed_to_unsigned_convert, ) @@ -55,6 +57,12 @@ def dequantize_gemv( import_source: Optional[str] = None func_name: str = "" + if source_format == "uint": + convert_packed = _tir_packed_to_unsigned_convert(storage_type, storage_nbit) + elif source_format in {"int", "sint"}: + convert_packed = _tir_packed_int_to_int_convert(storage_type, storage_nbit) + else: + raise ValueError(f"Unsupported source_format: {source_format}") if fast_decoding is True: # Lazy import to decrease the startup time # as intrin registry may take a while to load @@ -119,7 +127,7 @@ def main( ) else: for ki in T.serial(micro_size_k): - B_dequantize_local[ki] = _tir_packed_int_to_int_convert(storage_type, storage_nbit)( + B_dequantize_local[ki] = convert_packed( num_bits, B_quant_local[ki // num_elems_per_byte], ki % num_elems_per_byte, in_dtype ) @@ -167,7 +175,7 @@ def main() -> None: source_format = "uint" n_partition = 4 reduce_thread = 32 - fast_decoding = True + fast_decoding = not target_is_maca(determine_target("auto", return_object=True)) trans_A = False trans_B = True group_size = -1 @@ -229,7 +237,7 @@ def run_regression_perf(): source_format = "uint" n_partition = 4 reduce_thread = 32 - fast_decoding = True + fast_decoding = not target_is_maca(determine_target("auto", return_object=True)) trans_A = False trans_B = True group_size = -1 diff --git a/examples/maca/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py b/examples/maca/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py index 501fe11c..e98d1b7b 100644 --- a/examples/maca/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py +++ b/examples/maca/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py @@ -4,8 +4,9 @@ from tilelang import tvm as tvm from tvm import DataType import torch -from dequantize_utils import torch_convert_bit_twiddling, assert_similar +from dequantize_utils import torch_convert, torch_convert_bit_twiddling, assert_similar from tilelang.autotuner import set_autotune_inputs +from tilelang.utils.target import determine_target, target_is_maca import argparse @@ -246,7 +247,7 @@ def simple_dequant_bf16_fp4(B_shared, B_dequantize_shared, Scale_shared, k): i, k * block_K // scale_size + j // scale_size ], # Scale is the exponential part, within the representation of uint8 dtype=out_dtype, - ) * T.shift_left(1, (Scale_shared[i, k * block_K // scale_size + j // scale_size])) + ) T.copy(B_dequantize_local, B_dequantize_shared) return simple_dequant_bf16_fp4 @@ -343,45 +344,73 @@ def main( return main -def ref_moe(A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids, block_M=256): +def ref_moe(A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids, block_M=256, fast_dequant=True): dtypeC = T.bfloat16 M, K = A.shape E, N, QK = qB.shape topk = topk_weights.shape[0] // M scale_size = K // Scale.shape[2] assert scale_size == 32 # MXFP4 + maca_mode = target_is_maca(determine_target("auto", return_object=True)) + + work_A = A + work_qB = qB + work_Scale = Scale + work_Bias = Bias + work_topk_weights = topk_weights + work_sorted_token_ids = sorted_token_ids + work_expert_ids = expert_ids + accum_dtype = torch.bfloat16 + + if maca_mode: + # The MACA path is correctness-oriented here; keeping the reference on CPU + # avoids the heavy post-kernel GPU work that was getting the example stuck. + work_A = A.cpu() + work_qB = qB.cpu() + work_Scale = Scale.cpu() + work_Bias = Bias.cpu() + work_topk_weights = topk_weights.cpu() + work_sorted_token_ids = sorted_token_ids.cpu() + work_expert_ids = expert_ids.cpu() + accum_dtype = torch.float32 # Initialize output tensor - C = torch.ones((M, topk, N), dtype=getattr(torch, dtypeC), device="cuda") - - # Iterate over sorted_token_ids - for idx in range(len(sorted_token_ids)): # padding_M - token_id = sorted_token_ids[idx] - if token_id == -1: + C = torch.empty((M, topk, N), dtype=accum_dtype, device=work_A.device) + + expert_weights = [] + for expert in range(E): + if fast_dequant: + B = torch_convert_bit_twiddling(work_qB[expert]) # shape: (N, K) + else: + B = torch_convert(work_qB[expert]) # shape: (N, K) + B *= 2 ** (work_Scale[expert][:, (torch.arange(B.shape[1], device=B.device) // scale_size)].to(torch.bfloat16)) + expert_weights.append(B.to(accum_dtype)) + + block_count = work_expert_ids.numel() + block_idx = 0 + while block_idx < block_count: + expert_id = int(work_expert_ids[block_idx].item()) + group_end = block_idx + 1 + while group_end < block_count and int(work_expert_ids[group_end].item()) == expert_id: + group_end += 1 + + token_ids = work_sorted_token_ids[block_idx * block_M : group_end * block_M] + valid_mask = token_ids != -1 + token_ids = token_ids[valid_mask] + if token_ids.numel() == 0: + block_idx = group_end continue - expert_id = expert_ids[idx // block_M] - topk_idx = token_id % topk - - # Get the token embedding - token_embedding = A[token_id // topk] - - # Dequantize the expert weights - B = torch_convert_bit_twiddling(qB[expert_id]) # shape: (N, K) - B *= 2 ** (Scale[expert_id][:, (torch.arange(B.shape[1], device=B.device) // scale_size)].to(torch.bfloat16)) - # Compute the output for this token-expert pair - # token_embedding @ B.T + bias - output = torch.matmul(token_embedding.to(torch.bfloat16), B.T.to(torch.bfloat16)) + Bias[expert_id] - output = output.to(torch.__getattribute__(dtypeC)) + token_rows = torch.div(token_ids, topk, rounding_mode="floor") + topk_rows = torch.remainder(token_ids, topk) + token_embeddings = work_A[token_rows].to(accum_dtype) + output = torch.matmul(token_embeddings, expert_weights[expert_id].T) + output = output + work_Bias[expert_id].to(accum_dtype) + output = output * work_topk_weights[token_ids].to(accum_dtype).unsqueeze(1) + C[token_rows, topk_rows] = output + block_idx = group_end - # Apply the topk weight - weight = topk_weights[token_id] - output = output * weight - - # Store the result - C[token_id // topk, topk_idx] = output - - return C + return C.to(device=A.device, dtype=torch.__getattribute__(dtypeC)) def get_data(m, n, k, qk, scale_size, topk, E, block_M): @@ -428,10 +457,18 @@ def get_data(m, n, k, qk, scale_size, topk, E, block_M): def main(m=256, n=256, k=256, scale_size=32, topk=4, E=32, fast_dequant=True, with_bias=False, tune=False): # Tunable parameters - block_M, block_N, block_K = 128, 256, 128 # noqa: F841 - num_stages = 1 # noqa: F841 - threads = 512 # noqa: F841 - split = 1 # noqa: F841 + maca_mode = target_is_maca(determine_target("auto", return_object=True)) + if maca_mode: + fast_dequant = False + block_M, block_N, block_K = 64, 64, 64 # noqa: F841 + num_stages = 1 # noqa: F841 + threads = 128 # noqa: F841 + split = 1 # noqa: F841 + else: + block_M, block_N, block_K = 128, 256, 128 # noqa: F841 + num_stages = 1 # noqa: F841 + threads = 512 # noqa: F841 + split = 1 # noqa: F841 total_flops = 2 * m * n * k * topk num_bits = 4 @@ -491,11 +528,24 @@ def main(m=256, n=256, k=256, scale_size=32, topk=4, E=32, fast_dequant=True, wi ) print("Tilelang kernel run finished.") - ref_output = ref_moe(A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids, block_M=block_M) # Maybe a little bit slow... + ref_output = ref_moe( + A, + qB, + Scale, + Bias, + topk_weights, + sorted_token_ids, + expert_ids, + block_M=block_M, + fast_dequant=fast_dequant, + ) # Maybe a little bit slow... - latency = tilelang.profiler.do_bench(lambda: kernel(A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids), warmup=100) - print("Tilelang: {:.2f} ms".format(latency)) - print("Tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + if maca_mode: + print("Tilelang benchmark skipped on MACA in correctness mode.") + else: + latency = tilelang.profiler.do_bench(lambda: kernel(A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids), warmup=100) + print("Tilelang: {:.2f} ms".format(latency)) + print("Tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) diff = (output - ref_output).abs() max_val = diff.max() @@ -506,10 +556,17 @@ def main(m=256, n=256, k=256, scale_size=32, topk=4, E=32, fast_dequant=True, wi def run_regression_perf(m=4096, n=4096, k=4096, scale_size=32, topk=4, E=32, fast_dequant=True, with_bias=False, tune=False): - block_M, block_N, block_K = 128, 256, 128 - num_stages = 1 - threads = 512 - split = 1 + if target_is_maca(determine_target("auto", return_object=True)): + fast_dequant = False + block_M, block_N, block_K = 64, 64, 64 + num_stages = 1 + threads = 128 + split = 1 + else: + block_M, block_N, block_K = 128, 256, 128 + num_stages = 1 + threads = 512 + split = 1 num_bits = 4 num_elems_per_byte = 8 // num_bits qk = k // num_elems_per_byte diff --git a/examples/maca/dequantize_gemm/test_example_dequantize_gemm.py b/examples/maca/dequantize_gemm/test_example_dequantize_gemm.py index 83cfd972..d995ee93 100644 --- a/examples/maca/dequantize_gemm/test_example_dequantize_gemm.py +++ b/examples/maca/dequantize_gemm/test_example_dequantize_gemm.py @@ -7,22 +7,18 @@ import example_dequant_gemm_w4a8 -@tilelang.testing.pytest.mark.xfail def test_example_dequant_gemv_fp16xint4(): example_dequant_gemv_fp16xint4.main() -@tilelang.testing.pytest.mark.xfail def test_example_dequant_gemm_fp4_hopper(): example_dequant_gemm_fp4_hopper.main() -@tilelang.testing.pytest.mark.xfail def test_example_dequant_gemm_bf16_mxfp4_hopper(): example_dequant_gemm_bf16_mxfp4_hopper.main() -@tilelang.testing.pytest.mark.xfail def test_example_dequant_groupedgemm_bf16_mxfp4_hopper(): example_dequant_groupedgemm_bf16_mxfp4_hopper.main() diff --git a/examples/maca/gemm_fp8/example_tilelang_gemm_fp8_2xAcc.py b/examples/maca/gemm_fp8/example_tilelang_gemm_fp8_2xAcc.py index 39c6fc33..f6e74144 100644 --- a/examples/maca/gemm_fp8/example_tilelang_gemm_fp8_2xAcc.py +++ b/examples/maca/gemm_fp8/example_tilelang_gemm_fp8_2xAcc.py @@ -20,7 +20,6 @@ def gemm_fp8_2xAcc( with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): A_shared = T.alloc_shared((block_M, block_K), dtype) B_shared = T.alloc_shared((block_N, block_K), dtype) - C_shared = T.alloc_shared((block_M, block_N), accum_dtype) C_local = T.alloc_fragment((block_M, block_N), accum_dtype) C_local_accum = T.alloc_fragment((block_M, block_N), accum_dtype) @@ -40,9 +39,7 @@ def gemm_fp8_2xAcc( if K_iters % update_interval != 0: for i, j in T.Parallel(block_M, block_N): C_local_accum[i, j] += C_local[i, j] - # TMA store - T.copy(C_local_accum, C_shared) - T.copy(C_shared, C[by * block_M, bx * block_N]) + T.copy(C_local_accum, C[by * block_M, bx * block_N]) return gemm_fp8_2xAcc diff --git a/examples/maca/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py b/examples/maca/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py index 9cfd9782..b673ad3a 100644 --- a/examples/maca/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py +++ b/examples/maca/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py @@ -5,9 +5,11 @@ import tilelang.language as T from tilelang.intrinsics import get_swizzle_layout from tilelang.intrinsics.mma_macro_generator import TensorCoreIntrinEmitter +from tilelang.intrinsics.maca_mma_macro_generator import TensorCoreIntrinEmitter as MacaTensorCoreIntrinEmitter from tilelang.intrinsics.mfma_macro_generator import MatrixCoreIntrinEmitter from tilelang.utils.tensor import map_torch_type from tilelang.utils import determine_fp8_type +from tilelang.utils.target import determine_target, target_is_maca tilelang.testing.set_random_seed(0) @@ -70,6 +72,7 @@ def tl_matmul( A_shared_shape = (block_M, block_K) B_shared_shape = (block_N, block_K) is_hip = torch.version.hip is not None + is_maca = target_is_maca(determine_target("auto", return_object=True)) # MMA Wrapper to Auto Generate Code for MMA/MFMA if is_hip: mma_emitter = MatrixCoreIntrinEmitter( @@ -84,6 +87,19 @@ def tl_matmul( warp_col_tiles=warp_col_tiles, chunk=chunk, ) + elif is_maca: + mma_emitter = MacaTensorCoreIntrinEmitter( + a_dtype=in_dtype, + b_dtype=in_dtype, + accum_dtype=accum_dtype, + a_transposed=False, + b_transposed=True, + block_row_warps=block_row_warps, + block_col_warps=block_col_warps, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=chunk, + ) else: mma_emitter = TensorCoreIntrinEmitter( a_dtype=in_dtype, @@ -114,6 +130,7 @@ def tl_matmul( local_size_c = mma_emitter.local_size_out warp_rows = mma_emitter.warp_rows warp_cols = mma_emitter.warp_cols + local_in_dtype = getattr(mma_emitter, "mma_input_dtype", in_dtype) @T.prim_func def gemm_fp8_intrinsic( @@ -125,8 +142,8 @@ def gemm_fp8_intrinsic( A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) - A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) - B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) + A_local = T.alloc_local((warp_rows * local_size_a), local_in_dtype) + B_local = T.alloc_local((warp_cols * local_size_b), local_in_dtype) C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype) T.annotate_layout( diff --git a/examples/maca/gemm_fp8/test_example_gemm_fp8.py b/examples/maca/gemm_fp8/test_example_gemm_fp8.py index 5b7e625a..19a9ee00 100644 --- a/examples/maca/gemm_fp8/test_example_gemm_fp8.py +++ b/examples/maca/gemm_fp8/test_example_gemm_fp8.py @@ -4,17 +4,14 @@ import example_tilelang_gemm_fp8 -@tilelang.testing.pytest.mark.xfail def test_example_tilelang_gemm_fp8_2xAcc(): example_tilelang_gemm_fp8_2xAcc.main() -@tilelang.testing.pytest.mark.xfail def test_example_tilelang_gemm_fp8_intrinsic(): example_tilelang_gemm_fp8_intrinsic.main() -@tilelang.testing.pytest.mark.xfail def test_example_tilelang_gemm_fp8(): example_tilelang_gemm_fp8.main() diff --git a/src/target/codegen_maca.cc b/src/target/codegen_maca.cc index ee2e659e..b655ced9 100644 --- a/src/target/codegen_maca.cc +++ b/src/target/codegen_maca.cc @@ -48,9 +48,12 @@ static std::string GetTileLangFP8Type(DataType type) { << "Only support scalar and vector types of width (2, 4, 8, 16, 32) " "for FP8"; } - if (type.is_float8_e4m3() || type.is_float8_e4m3fn()) { + if (type.is_float8_e4m3() || type.is_float8_e4m3fn() || + type.is_float8_e4m3fnuz() || + type.code() == DataType::kFloat8_e4m3b11fnuz) { stream << "fp8_e4" << vec << "_t"; - } else if (type.is_float8_e5m2()) { + } else if (type.is_float8_e5m2() || type.is_float8_e5m2fnuz() || + type.code() == DataType::kFloat8_e5m2) { stream << "fp8_e5" << vec << "_t"; } else if (type.is_float8_e8m0fnu()) { stream << "fp8_e8" << vec << "_t"; @@ -362,6 +365,8 @@ void CodeGenTileLangMACA::PrintType(DataType t, std::ostream &os) { // NOLINT(*) ICHECK_EQ(lanes % 2, 0) << "only support even lane for float type with lanes > 4"; os << "ulonglong" << lanes / 2; + } else if (lanes == 16 || lanes == 32) { + os << "float32x" << lanes; } else { fail = true; } @@ -375,7 +380,8 @@ void CodeGenTileLangMACA::PrintType(DataType t, std::ostream &os) { // NOLINT(*) } if (!fail && (t.is_scalar() || t.bits() == 16)) return; - if (!fail && (lanes > 4 && lanes <= 8 && t.bits() == 32)) + if (!fail && t.bits() == 32 && + ((lanes > 4 && lanes <= 8) || lanes == 16 || lanes == 32)) return; if (!fail && (lanes >= 2 && lanes <= 4)) { os << lanes; @@ -629,7 +635,11 @@ void CodeGenTileLangMACA::PrintVecElemLoad(const std::string &vec, DataType t, } static const char access[] = {'x', 'y', 'z', 'w'}; - ICHECK(i >= 0 && i < 256 / t.bits()) + int max_lanes = 256 / t.bits(); + if (t.is_float() && t.bits() == 32 && (t.lanes() == 16 || t.lanes() == 32)) { + max_lanes = t.lanes(); + } + ICHECK(i >= 0 && i < max_lanes) << "i: " << i << " t: " << t << " t.bits(): " << t.bits() << " t.lanes(): " << t.lanes(); if (t.bits() == 8 && (t.is_int() || t.is_uint())) { @@ -692,6 +702,9 @@ void CodeGenTileLangMACA::PrintVecElemLoad(const std::string &vec, DataType t, os << "." << access[(i % 4) / 2]; // fp4_e2_2_t -> method call x() or y() os << "." << access[i % 2] << "()"; + } else if (t.is_float() && t.bits() == 32 && + (t.lanes() == 16 || t.lanes() == 32)) { + os << vec << "[" << i << "]"; } else if (t.lanes() > 4 && t.lanes() <= 8) { std::string type_name; if (t.bits() == 16) { @@ -721,7 +734,11 @@ void CodeGenTileLangMACA::PrintVecElemStore(const std::string &vec, DataType t, int i, const std::string &value) { this->PrintIndent(); static const char access[] = {'x', 'y', 'z', 'w'}; - ICHECK(i >= 0 && i < 256 / t.bits()); + int max_lanes = 256 / t.bits(); + if (t.is_float() && t.bits() == 32 && (t.lanes() == 16 || t.lanes() == 32)) { + max_lanes = t.lanes(); + } + ICHECK(i >= 0 && i < max_lanes); if (t.bits() == 8 && (t.is_int() || t.is_uint())) { if (t.lanes() == 2 || t.lanes() == 3) { stream << vec << '.' << access[i % t.lanes()] << "=" @@ -795,6 +812,9 @@ void CodeGenTileLangMACA::PrintVecElemStore(const std::string &vec, DataType t, ICHECK(!type_name.empty()); stream << "((" << type_name << "2*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2] << " = " << value << ";\n"; + } else if (t.is_float() && t.bits() == 32 && + (t.lanes() == 16 || t.lanes() == 32)) { + stream << vec << "[" << i << "] = " << value << ";\n"; } else if (t.is_float4_e2m1fn()) { stream << vec; // fp4_e2_64_t @@ -1791,9 +1811,20 @@ void CodeGenTileLangMACA::VisitExpr_(const CallNode *op, std::ostream &os) { {"float16x4", "float16x4"}, {"bfloat16x4", "bfloat16x4_vec"}, {"float32x4", "float32x4"}, + {"float8_e4m3x4", "fp8_e4_4_t"}, + {"float8_e4m3x8", "long"}, + {"float8_e4m3fnx4", "fp8_e4_4_t"}, + {"float8_e4m3fnx8", "long"}, {"float8_e4m3fnuzx4", "fp8_e4_4_t"}, {"float8_e4m3fnuzx8", "long"}, - {"float32x16", "float32x16"}}; + {"float8_e4m3b11fnuzx4", "fp8_e4_4_t"}, + {"float8_e4m3b11fnuzx8", "long"}, + {"float8_e5m2x4", "fp8_e5_4_t"}, + {"float8_e5m2x8", "long"}, + {"float8_e5m2fnuzx4", "fp8_e5_4_t"}, + {"float8_e5m2fnuzx8", "long"}, + {"float32x16", "float32x16"}, + {"float32x32", "float32x32"}}; std::string call_mfma_code = R"({ *((({C_dtype}*){c_ref}) + {c_bias}) = {mfma_buildin}(*((({A_dtype}*){a_ref}) + {a_bias}), *((({B_dtype}*){b_ref}) + {b_bias}), @@ -2083,6 +2114,8 @@ void CodeGenTileLangMACA::VisitExpr_(const CallNode *op, std::ostream &os) { os << ", " << PrintExpr(op->args[2]); } os << ")"; + } else if (op->op.same_as(tl::tl_shuffle_elect())) { + os << "tl::tl_shuffle_elect<" << PrintExpr(op->args[0]) << ">()"; } else if (op->op.same_as(tl::tl_gemm_sp())) { ICHECK(op->args.size() == 5) << "tl_gemm_sp expects 5 arguments value); - ICHECK(p); - int64_t v = *p & 0xFF; - v = (v << 24) | (v << 16) | (v << 8) | v; - if (op->dtype.is_uint()) { - os << "(uint)" << v; + if (p) { + int64_t v = *p & 0xFF; + v = (v << 24) | (v << 16) | (v << 8) | v; + if (op->dtype.is_uint()) { + os << "(uint)" << v; + } else { + os << "(int)" << v; + } + } else { + std::string scalar = PrintExpr(op->value); + std::string byte = "((" + scalar + ") & 0xFF)"; + os << "(" << (op->dtype.is_uint() ? "uint" : "int") << ")(" << byte + << " | (" << byte << " << 8) | (" << byte << " << 16) | (" << byte + << " << 24))"; + } + return; + } else if (lanes == 8 || lanes == 16) { + const int64_t *p = as_const_int(op->value); + std::string packed32; + if (p) { + int64_t v = *p & 0xFF; + v = (v << 24) | (v << 16) | (v << 8) | v; + std::ostringstream oss; + oss << "(" << (op->dtype.is_uint() ? "uint" : "int") << ")" << v; + packed32 = oss.str(); } else { - os << "(int)" << v; + std::string scalar = PrintExpr(op->value); + std::string byte = "((" + scalar + ") & 0xFF)"; + packed32 = "(" + byte + " | (" + byte + " << 8) | (" + byte + + " << 16) | (" + byte + " << 24))"; + packed32 = "(" + std::string(op->dtype.is_uint() ? "uint" : "int") + + ")" + packed32; + } + os << "make_"; + PrintType(op->dtype, os); + os << '('; + for (int i = 0; i < lanes / 4; ++i) { + if (i != 0) + os << ", "; + os << packed32; } + os << ')'; return; } else if (lanes == 32) { // make_int8x32 const int64_t *p = as_const_int(op->value); - ICHECK(p); - int64_t v = *p & 0xFF; - v = (v << 24) | (v << 16) | (v << 8) | v; - if (op->dtype.is_uint()) { - os << "make_ulonglong4(" << v << ", " << v << ", " << v << ", " << v - << ")"; + if (p) { + int64_t v = *p & 0xFF; + v = (v << 24) | (v << 16) | (v << 8) | v; + if (op->dtype.is_uint()) { + os << "make_ulonglong4(" << v << ", " << v << ", " << v << ", " << v + << ")"; + } else { + os << "make_longlong4(" << v << ", " << v << ", " << v << ", " << v + << ")"; + } } else { - os << "make_longlong4(" << v << ", " << v << ", " << v << ", " << v - << ")"; + std::string scalar = PrintExpr(op->value); + std::string byte = "((" + scalar + ") & 0xFF)"; + std::string packed32 = "(" + byte + " | (" + byte + " << 8) | (" + + byte + " << 16) | (" + byte + " << 24))"; + std::string packed64 = "(((unsigned long long)" + packed32 + + ") | (((unsigned long long)" + packed32 + + ") << 32))"; + if (op->dtype.is_uint()) { + os << "make_ulonglong4(" << packed64 << ", " << packed64 << ", " + << packed64 << ", " << packed64 << ")"; + } else { + os << "make_longlong4(" << packed64 << ", " << packed64 << ", " + << packed64 << ", " << packed64 << ")"; + } } return; } @@ -2529,7 +2612,7 @@ void CodeGenTileLangMACA::VisitExpr_(const BroadcastNode *op, for (int i = 0; i < 4; ++i) { if (i != 0) os << ", "; - os << "*(unsigned long long*)&make_float2(" << v << ", " << v << ")"; + os << "pack_float2(" << v << ", " << v << ")"; } os << ')'; return; @@ -2538,37 +2621,64 @@ void CodeGenTileLangMACA::VisitExpr_(const BroadcastNode *op, if ((op->dtype.is_int() || op->dtype.is_uint()) && op->dtype.bits() == 4) { bool fail = false; const int64_t *p = as_const_int(op->value); - ICHECK(p); - int64_t v = *p & 0xF; - - if (lanes == 4) { - v = (v << 12) | (v << 8) | (v << 4) | v; - if (op->dtype.is_uint()) { - os << "(uint16_t)" << v; - } else { - os << "(int16_t)" << v; - } - } else { - v = (v << 28) | (v << 24) | (v << 20) | (v << 16) | (v << 12) | (v << 8) | - (v << 4) | v; - if (lanes == 8) { + if (p) { + int64_t v = *p & 0xF; + if (lanes == 4) { + v = (v << 12) | (v << 8) | (v << 4) | v; if (op->dtype.is_uint()) { - os << "(uint)" << v; + os << "(uint16_t)" << v; } else { - os << "(int)" << v; + os << "(int16_t)" << v; } - } else if (lanes == 16 || lanes == 32) { + } else { + v = (v << 28) | (v << 24) | (v << 20) | (v << 16) | (v << 12) | + (v << 8) | (v << 4) | v; + if (lanes == 8) { + if (op->dtype.is_uint()) { + os << "(uint)" << v; + } else { + os << "(int)" << v; + } + } else if (lanes == 16 || lanes == 32 || lanes == 64) { + os << "make_"; + PrintType(op->dtype, os); + os << '('; + for (int i = 0; i < lanes / 8; ++i) { + if (i != 0) + os << ", "; + if (op->dtype.is_uint()) { + os << "(uint)" << v; + } else { + os << "(int)" << v; + } + } + os << ')'; + } else { + fail = true; + } + } + } else { + std::string scalar = PrintExpr(op->value); + std::string nibble = "((" + scalar + ") & 0xF)"; + std::string packed32 = "(" + nibble + " | (" + nibble + " << 4) | (" + + nibble + " << 8) | (" + nibble + " << 12) | (" + + nibble + " << 16) | (" + nibble + " << 20) | (" + + nibble + " << 24) | (" + nibble + " << 28))"; + if (lanes == 4) { + os << "(" << (op->dtype.is_uint() ? "uint16_t" : "int16_t") << ")(" + << nibble << " | (" << nibble << " << 4) | (" << nibble + << " << 8) | (" << nibble << " << 12))"; + } else if (lanes == 8) { + os << "(" << (op->dtype.is_uint() ? "uint" : "int") << ")" << packed32; + } else if (lanes == 16 || lanes == 32 || lanes == 64) { os << "make_"; PrintType(op->dtype, os); os << '('; for (int i = 0; i < lanes / 8; ++i) { if (i != 0) os << ", "; - if (op->dtype.is_uint()) { - os << "(uint)" << v; - } else { - os << "(int)" << v; - } + os << "(" << (op->dtype.is_uint() ? "uint" : "int") << ")" + << packed32; } os << ')'; } else { diff --git a/src/tl_templates/maca/common.h b/src/tl_templates/maca/common.h index 9b7392b4..89727e19 100644 --- a/src/tl_templates/maca/common.h +++ b/src/tl_templates/maca/common.h @@ -162,9 +162,43 @@ typedef using int32x4 = __attribute__((__vector_size__(4 * sizeof(int)))) int; using float32x4 = __attribute__((__vector_size__(4 * sizeof(float)))) float; using float32x16 = __attribute__((__vector_size__(16 * sizeof(float)))) float; +using float32x32 = __attribute__((__vector_size__(32 * sizeof(float)))) float; using float64x4 = __attribute__((__vector_size__(4 * sizeof(double)))) double; using int8x4 = __attribute__((__vector_size__(4 * sizeof(int8_t)))) int8_t; +TL_DEVICE float32x16 make_float32x16(float x0, float x1, float x2, float x3, + float x4, float x5, float x6, float x7, + float x8, float x9, float x10, float x11, + float x12, float x13, float x14, + float x15) { + return float32x16{x0, x1, x2, x3, x4, x5, x6, x7, + x8, x9, x10, x11, x12, x13, x14, x15}; +} + +TL_DEVICE float32x32 make_float32x32( + float x0, float x1, float x2, float x3, float x4, float x5, float x6, + float x7, float x8, float x9, float x10, float x11, float x12, float x13, + float x14, float x15, float x16, float x17, float x18, float x19, float x20, + float x21, float x22, float x23, float x24, float x25, float x26, float x27, + float x28, float x29, float x30, float x31) { + return float32x32{x0, x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, + x11, x12, x13, x14, x15, x16, x17, x18, x19, x20, x21, + x22, x23, x24, x25, x26, x27, x28, x29, x30, x31}; +} + +template TL_DEVICE bool tl_shuffle_elect() { + if constexpr (thread_extent == 0) { + return threadIdx.x == 0; + } else if constexpr (thread_extent <= 32) { + return (threadIdx.x % thread_extent) == 0; + } else if constexpr (thread_extent % 32 == 0) { + return ((threadIdx.x / 32) % (thread_extent / 32)) == 0 && + (threadIdx.x % 32) == 0; + } else { + return (threadIdx.x % thread_extent) == 0; + } +} + // Pack four char values. TL_DEVICE unsigned int make_uint(unsigned char x0, unsigned char x1, unsigned char x2, unsigned char x3) { @@ -196,6 +230,15 @@ TL_DEVICE unsigned __pack_maca_bfloat162(const bfloat16_t x, return (v1 << 16) | v0; } +TL_DEVICE unsigned long long pack_float2(const float x, const float y) { + union { + float2 f; + unsigned long long u64; + } bits; + bits.f = make_float2(x, y); + return bits.u64; +} + template TL_DEVICE void AtomicAdd(T1 *address, T2 val, int memory_order = 0) { (void)memory_order; diff --git a/tilelang/contrib/dlpack.py b/tilelang/contrib/dlpack.py index d80f0fdb..66adbc1e 100644 --- a/tilelang/contrib/dlpack.py +++ b/tilelang/contrib/dlpack.py @@ -38,10 +38,10 @@ def convert_func(tvm_func, tensor_type, to_dlpack_func): import torch float8_dtype_map = { - torch.float8_e4m3fn: "float8_e4m3", + torch.float8_e4m3fn: "float8_e4m3fn", torch.float8_e4m3fnuz: "float8_e4m3fnuz", torch.float8_e5m2: "float8_e5m2", - torch.float8_e5m2fnuz: "float8_e5m2", + torch.float8_e5m2fnuz: "float8_e5m2fnuz", } def adapt_tensor(arg): diff --git a/tilelang/intrinsics/maca_mma_macro_generator.py b/tilelang/intrinsics/maca_mma_macro_generator.py index 1ea41388..4e3e0b7f 100644 --- a/tilelang/intrinsics/maca_mma_macro_generator.py +++ b/tilelang/intrinsics/maca_mma_macro_generator.py @@ -57,6 +57,14 @@ class TensorCoreIntrinEmitter: k_pack = 1 # Represent the thread binding in the form of (tx, warp_n, warp_m) is_m_first = False + fp8_dtypes = { + "float8_e4m3", + "float8_e5m2", + "float8_e4m3fn", + "float8_e5m2fn", + "float8_e4m3fnuz", + "float8_e5m2fnuz", + } def __init__( self, @@ -88,6 +96,7 @@ def __init__( self.warp_row_tiles = warp_row_tiles self.warp_col_tiles = warp_col_tiles self.chunk = chunk + self.mma_input_dtype = self._resolve_mma_input_dtype(a_dtype) self._initialize_k_dim(a_dtype) self._initialize_abbrev(a_dtype, b_dtype, accum_dtype) self._initialize_local_size(self.M_DIM, self.N_DIM, self.k_dim, self.WARP_SIZE) @@ -106,8 +115,6 @@ def __init__( def _initialize_k_dim(self, a_dtype=T.float16): if isinstance(a_dtype, str): - if a_dtype in ["float8_e4m3fnuz", "float8_e5m2fnuz"]: - return a_dtype = DataType(a_dtype) if a_dtype.bits == 32: @@ -130,13 +137,21 @@ def _dtype_abbrv_lookup(self, dtype): raise KeyError(f"Unsupported dtype for MACA MMA: {dtype!r}") return self.dtype_abbrv[s] + def _resolve_mma_input_dtype(self, dtype): + s = str(dtype) + if s.startswith("dtype('") and s.endswith("')"): + s = s[7:-2] + if s in self.fp8_dtypes: + return T.float16 + return dtype + def _initialize_abbrev(self, a_dtype, b_dtype, accum_dtype): self.a_dtype_abbrv = self._dtype_abbrv_lookup(a_dtype) self.b_dtype_abbrv = self._dtype_abbrv_lookup(b_dtype) self.accum_dtype_abbrv = self._dtype_abbrv_lookup(accum_dtype) def _initialize_mma_prefix(self, k_dim=16): - in_dtype = self.a_dtype + in_dtype = self.mma_input_dtype M_DIM, N_DIM = self.M_DIM, self.N_DIM in_dtype_key = str(in_dtype) @@ -148,6 +163,8 @@ def _initialize_mma_prefix(self, k_dim=16): "float32": "f32", "int8": "i8", "int32": "i32", + "float8_e4m3": "f16", + "float8_e5m2": "f16", "float8_e4m3fnuz": "fp8", "float8_e5m2fnuz": "fp8", "float8_e4m3fn": "fp8", @@ -281,6 +298,7 @@ def ldmatrix_a(self, A_local_buf, A_shared_buf: Buffer | BufferRegion, ki, rk=0) # legalize shared buffer to region A_region = self._legalize_to_buffer_region(A_shared_buf) A_buf = A_region.buffer + A_prefix = [region.min for region in A_region.region[:-2]] A_base0 = A_region.region[-2].min A_base1 = A_region.region[-1].min @@ -298,13 +316,13 @@ def _warp_ldmatrix_a( for local_id in T.vectorized(k_pack * local_size_a): row, col = T.meta_var(reverse_index_map(tx, local_id)) l, r = (rk * chunk + ki * (k_pack * micro_size_k), warp_m * warp_row_tiles + i * micro_size_x) - A_local_buf[i * k_pack * local_size_a + local_id] = A_buf[A_base0 + l + row, A_base1 + r + col] + A_local_buf[i * k_pack * local_size_a + local_id] = A_buf[tuple(A_prefix + [A_base0 + l + row, A_base1 + r + col])] else: for i in T.serial(warp_rows): for local_id in T.vectorized(k_pack * local_size_a): row, col = T.meta_var(reverse_index_map(tx, local_id)) l, r = (warp_m * warp_row_tiles + i * micro_size_x, rk * chunk + ki * (k_pack * micro_size_k)) - A_local_buf[i * k_pack * local_size_a + local_id] = A_buf[A_base0 + l + row, A_base1 + r + col] + A_local_buf[i * k_pack * local_size_a + local_id] = A_buf[tuple(A_prefix + [A_base0 + l + row, A_base1 + r + col])] return _warp_ldmatrix_a(A_local_buf, A_shared_buf, ki, thread_binding, rk) @@ -323,6 +341,7 @@ def ldmatrix_b(self, B_local_buf, B_shared_buf: Buffer | BufferRegion, ki, rk=0) # legalize shared buffer to region B_region = self._legalize_to_buffer_region(B_shared_buf) B_buf = B_region.buffer + B_prefix = [region.min for region in B_region.region[:-2]] B_base0 = B_region.region[-2].min B_base1 = B_region.region[-1].min @@ -343,7 +362,7 @@ def _warp_ldmatrix_b( warp_n * warp_col_tiles + j * micro_size_y, rk * chunk + ki * (k_pack * micro_size_k), ) - B_local_buf[j * k_pack * local_size_b + local_id] = B_buf[B_base0 + l + row, B_base1 + r + col] + B_local_buf[j * k_pack * local_size_b + local_id] = B_buf[tuple(B_prefix + [B_base0 + l + row, B_base1 + r + col])] else: for j in T.serial(warp_cols): @@ -353,7 +372,7 @@ def _warp_ldmatrix_b( rk * chunk + ki * (k_pack * micro_size_k), warp_n * warp_col_tiles + j * micro_size_y, ) - B_local_buf[j * k_pack * local_size_b + local_id] = B_buf[B_base0 + l + row, B_base1 + r + col] + B_local_buf[j * k_pack * local_size_b + local_id] = B_buf[tuple(B_prefix + [B_base0 + l + row, B_base1 + r + col])] return _warp_ldmatrix_b(B_local_buf, B_shared_buf, ki, thread_binding, rk) @@ -365,9 +384,10 @@ def mma(self, A_local_buf: Buffer, B_local_buf: Buffer, C_local_buf: Buffer, k_i local_size_out = self.local_size_out k_pack = self.k_pack mma_suffix = self.mma_suffix - a_dtype, b_dtype, out_dtype = self.a_dtype, self.b_dtype, self.accum_dtype - compute_a_dtype = a_dtype if local_size_a == 1 else f"{a_dtype}x{local_size_a}" - compute_b_dtype = b_dtype if local_size_b == 1 else f"{b_dtype}x{local_size_b}" + out_dtype = self.accum_dtype + mma_input_dtype = self.mma_input_dtype + compute_a_dtype = mma_input_dtype if local_size_a == 1 else f"{mma_input_dtype}x{local_size_a}" + compute_b_dtype = mma_input_dtype if local_size_b == 1 else f"{mma_input_dtype}x{local_size_b}" compute_out_dtype = out_dtype if local_size_out == 1 else f"{out_dtype}x{local_size_out}" a_is_fragment = is_fragment(A_local_buf) diff --git a/tilelang/jit/adapter/tvm_ffi.py b/tilelang/jit/adapter/tvm_ffi.py index 63147954..2edfabf8 100644 --- a/tilelang/jit/adapter/tvm_ffi.py +++ b/tilelang/jit/adapter/tvm_ffi.py @@ -176,6 +176,27 @@ def _convert_torch_func(self) -> Callable[..., Any]: dynamic_symbolic_map = self._process_dynamic_symbolic() executable = self.executable + float8_dtype_map = { + getattr(torch, "float8_e4m3fn", None): "float8_e4m3fn", + getattr(torch, "float8_e4m3fnuz", None): "float8_e4m3fnuz", + getattr(torch, "float8_e5m2", None): "float8_e5m2", + getattr(torch, "float8_e5m2fnuz", None): "float8_e5m2fnuz", + } + float8_dtype_map = {k: v for k, v in float8_dtype_map.items() if k is not None} + + def adapt_tensor_for_tvm(arg: torch.Tensor | Any): + if not isinstance(arg, torch.Tensor): + return arg + + float8_dtype = float8_dtype_map.get(arg.dtype) + if float8_dtype is None: + return arg + + # tvm_ffi cannot ingest float8 tensors directly via DLPack today. + # Reuse the existing float8 bridge pattern: pass the storage as int8 + # and recover the logical dtype through a TVM tensor view. + return runtime.from_dlpack(torch.utils.dlpack.to_dlpack(arg.view(torch.int8)))._create_view(arg.shape, dtype=float8_dtype) + # Prepare helpers for friendly dtype error messages prim_func = self.prim_func buffer_map = prim_func.buffer_map @@ -241,7 +262,7 @@ def func(*inputs: torch.Tensor | Any): ins_idx += 1 tensor_list.append(tensor) - executable(*tensor_list) + executable(*(adapt_tensor_for_tvm(tensor) for tensor in tensor_list)) # Return outputs in the requested form if len(self.result_idx) == 1: diff --git a/tilelang/quantize/quantization.py b/tilelang/quantize/quantization.py index 74a545f2..6f84317c 100644 --- a/tilelang/quantize/quantization.py +++ b/tilelang/quantize/quantization.py @@ -63,9 +63,10 @@ def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale e_f4 = (f4 & tir.const(6, T.uint16)) >> tir.const(1, T.uint16) # Exponential bias between f4 and bf16 is 2^(8-1) - 2^(2-1) = 126 e_bf16 = e_f4 + tir.const(126, T.uint16) - # Scale is the exponential part, within the representation of uint8 - # To handle the overflow, we use the max function to limit the exponential part to 8 bits - e_bf16 = min(e_bf16 + scale, tir.const((1 << 8) - 1, T.uint16)) + # Scale is the exponent offset stored as uint8. Clamp the adjusted exponent to bf16 range. + tir_u16_max = tir.const((1 << 8) - 1, T.uint16) + scaled_e_bf16 = e_bf16 + tir.Cast(T.uint16, scale) + e_bf16 = tir.Select(scaled_e_bf16 > tir_u16_max, tir_u16_max, scaled_e_bf16) m_f4 = f4 & tir.const(1, T.uint16) val_bf16 = tir.reinterpret(T.bfloat16, ((((s << tir.const(8, T.uint16)) | e_bf16) << tir.const(7, T.uint16)) diff --git a/tilelang/tileop/gemm/gemm_maca_mma.py b/tilelang/tileop/gemm/gemm_maca_mma.py index 5887c957..3a5b7234 100644 --- a/tilelang/tileop/gemm/gemm_maca_mma.py +++ b/tilelang/tileop/gemm/gemm_maca_mma.py @@ -85,7 +85,7 @@ def lower( thread_var=thread_var, ) - in_dtype = self.in_dtype + mma_input_dtype = mma_emitter.mma_input_dtype warp_rows = mma_emitter.warp_rows warp_cols = mma_emitter.warp_cols local_size_a = mma_emitter.local_size_a @@ -117,8 +117,8 @@ def _gemm_ssr() -> None: B_shared into local fragments, then issues Tensor Core mma ops, accumulating into C_local. """ - A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) - B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) + A_local = T.alloc_local((warp_rows * local_size_a), mma_input_dtype) + B_local = T.alloc_local((warp_cols * local_size_b), mma_input_dtype) if clear_accum: T.clear(C_buf) for ki in T.serial(0, (block_K // micro_size_k)): @@ -152,7 +152,7 @@ def _gemm_srr() -> None: B_shared into local fragments, then issues Tensor Core mma ops, accumulating into C_local. """ - A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) + A_local = T.alloc_local((warp_rows * local_size_a), mma_input_dtype) for ki in T.serial(0, (block_K // micro_size_k)): if clear_accum: @@ -182,7 +182,7 @@ def _gemm_rsr() -> None: B_shared into local fragments, then issues Tensor Core mma ops, accumulating into C_local. """ - B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) + B_local = T.alloc_local((warp_cols * local_size_b), mma_input_dtype) if clear_accum: T.clear(C_buf) for ki in T.serial(0, (block_K // micro_size_k)):