Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 18 additions & 12 deletions examples/maca/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import tilelang
import tilelang.language as T
from tilelang.utils.tensor import map_torch_type
from tilelang.utils.target import determine_target, target_is_maca

tilelang.testing.set_random_seed(42)

Expand All @@ -30,6 +31,7 @@ def tl_gemm(
group_size = 128
block_M = 128
block_K = 128
num_stages = 1 if target_is_maca(determine_target("auto", return_object=True)) else 4

A_shape = (M, K)
Scales_A_shape = (M, T.ceildiv(K, group_size))
Expand All @@ -50,7 +52,6 @@ def main(
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, in_dtype)
C_shared = T.alloc_shared(C_shared_shape, out_dtype)
Scale_C_shared = T.alloc_shared((block_M), T.float32)
C_local = T.alloc_fragment(C_shared_shape, accum_dtype)
C_local_accum = T.alloc_fragment(C_shared_shape, accum_dtype)
Expand All @@ -61,7 +62,7 @@ def main(
T.clear(C_local)
T.clear(C_local_accum)
K_iters = T.ceildiv(K, block_K)
for k in T.Pipelined(K_iters, num_stages=4):
for k in T.Pipelined(K_iters, num_stages=num_stages):
# Load A into shared memory
T.copy(A[by * block_M, k * block_K], A_shared)
# Load B into shared memory
Expand All @@ -76,9 +77,7 @@ def main(
for i, j in T.Parallel(block_M, block_N):
C_local_accum[i, j] += C_local[i, j] * Scale_C_shared[i]
T.clear(C_local)
# TMA store
T.copy(C_local_accum, C_shared)
T.copy(C_shared, C[by * block_M, bx * block_N])
T.copy(C_local_accum, C[by * block_M, bx * block_N])

return main

Expand Down Expand Up @@ -122,13 +121,20 @@ def ref_deepgemm_fp8(A_fp8, B_fp8, A_scale, B_scale, out_dtype):
for j in range(ceildiv(N, 128)):
c_acc.zero_()
for k in range(ceildiv(K, 128)):
c = torch._scaled_mm(
A_fp8[i * 128 : (i + 1) * 128, k * 128 : (k + 1) * 128],
B_fp8[j * 128 : (j + 1) * 128, k * 128 : (k + 1) * 128].T,
scale_a=A_scales[i, k].view(128, 1).contiguous(),
scale_b=B_scales[j, k].view(1, 128).contiguous(),
out_dtype=torch.bfloat16,
)
a_tile = A_fp8[i * 128 : (i + 1) * 128, k * 128 : (k + 1) * 128]
b_tile = B_fp8[j * 128 : (j + 1) * 128, k * 128 : (k + 1) * 128]
scale_a = A_scales[i, k].view(128, 1).contiguous()
scale_b = B_scales[j, k].view(128, 1).contiguous()
try:
c = torch._scaled_mm(
a_tile,
b_tile.T,
scale_a=scale_a,
scale_b=scale_b.view(1, 128),
out_dtype=torch.bfloat16,
)
except RuntimeError:
c = (a_tile.to(torch.float32) * scale_a) @ (b_tile.to(torch.float32) * scale_b).T
c_acc += c.to(torch.float32)
C[i * 128 : (i + 1) * 128, j * 128 : (j + 1) * 128] = c_acc.to(out_dtype)
return C
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from example_deepgemm_fp8_2xAcc import main


@tilelang.testing.pytest.mark.xfail
def test_deepgemm_fp8_2xAcc():
main()

Expand Down
52 changes: 24 additions & 28 deletions examples/maca/dequantize_gemm/dequantize_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
import torch


def reinterpret_u16_as_bfloat16(bits: torch.Tensor) -> torch.Tensor:
bits_i32 = (bits & 0xFFFF).to(torch.int32)
bits_i16 = torch.where(bits_i32 >= 0x8000, bits_i32 - 0x10000, bits_i32).to(torch.int16)
return bits_i16.view(torch.bfloat16)


def torch_convert_bit_twiddling(tensor):
"""
This function expects `tensor` to be a 2-D torch.Tensor of dtype `torch.uint8`. Each output element is produced by combining two input bytes and extracting a bf16-like 16-bit pattern according to one of four positional bit layouts (pos 0..3). The result is scaled by 2**126 to adjust the exponent bias and returned as dtype `torch.bfloat16`.
Expand Down Expand Up @@ -45,8 +51,7 @@ def torch_convert_bit_twiddling(tensor):
bf16 = torch.where(pos == 0, res0, torch.where(pos == 1, res1, torch.where(pos == 2, res2, res3)))

# Convert to uint16 for .view(torch.bfloat16)
bf16_uint16 = (bf16 & 0xFFFF).to(torch.uint16)
bf16_bf16 = bf16_uint16.view(torch.bfloat16)
bf16_bf16 = reinterpret_u16_as_bfloat16(bf16)

# Avoid integer overflow by using a float32 multiplier for the exponent scaling
bf16_new = bf16_bf16 * (2.0**126)
Expand All @@ -69,32 +74,23 @@ def torch_convert(tensor, scale_size=None, Scale=None):
torch.Tensor: A new tensor of shape (N, K*2) and dtype torch.bfloat16 containing the decoded bf16 values.
"""

def _convert(val, pos, scale=None):
assert val.dtype == torch.uint8
# val = val.view(torch.int8)
mask = (1 << 4) - 1
f4 = ((val >> (pos * 4)) & mask).to(torch.int16)
s = f4 >> 3
e_f4 = (f4 & 6) >> 1
e_f16 = e_f4 + 126
if scale is not None:
e_f16 = min(e_f16 + scale, (1 << 8) - 1)
m_f4 = f4 & 1
m_f16 = m_f4
val_f16 = (((e_f16 | (s << 8)) << 7) | (m_f16 << 6)) & 0xFFFF
lower_16_bits = (val_f16 & 0xFFFF).to(torch.uint16)
return lower_16_bits.view(torch.bfloat16)

N = tensor.shape[0]
K = tensor.shape[1]
new_tensor = torch.empty(N, K * 2, dtype=torch.bfloat16, device=tensor.device)
for i in range(new_tensor.shape[0]):
for j in range(new_tensor.shape[1]):
if scale_size is not None:
new_tensor[i][j] = _convert(tensor[i][j // 2], j % 2, Scale[i][j // scale_size])
else:
new_tensor[i][j] = _convert(tensor[i][j // 2], j % 2)
return new_tensor
assert tensor.dim() == 2 and tensor.dtype == torch.uint8

low = (tensor & 0x0F).to(torch.int16)
high = ((tensor >> 4) & 0x0F).to(torch.int16)
f4 = torch.stack((low, high), dim=-1).reshape(tensor.shape[0], tensor.shape[1] * 2)

sign = f4 >> 3
exponent = ((f4 & 0x6) >> 1) + 126
if scale_size is not None:
if Scale is None:
raise ValueError("Scale must be provided when scale_size is set")
scale_idx = torch.arange(f4.shape[1], device=tensor.device) // scale_size
exponent = torch.clamp(exponent + Scale[:, scale_idx].to(torch.int16), max=(1 << 8) - 1)

mantissa = f4 & 0x1
val_f16 = (((exponent | (sign << 8)) << 7) | (mantissa << 6)) & 0xFFFF
return reinterpret_u16_as_bfloat16(val_f16)


def print_bit(name, val):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from tvm import DataType
from tvm import tir
import torch
from tilelang.utils.target import determine_target, target_is_maca
from dequantize_utils import torch_convert_bit_twiddling, torch_convert


Expand Down Expand Up @@ -494,6 +495,12 @@ def main(m=256, n=256, k=256, scale_size=32, fast_dequant=True, with_bias=False,
"""
total_flops = 2 * m * n * k

if target_is_maca(determine_target("auto", return_object=True)):
fast_dequant = False
block_M, block_N, block_K, num_stages, threads, split = 64, 64, 64, 1, 128, 1
else:
block_M, block_N, block_K, num_stages, threads, split = 256, 128, 128, 2, 256, 1

if tune:
kernel = matmul(
m, n, k, T.bfloat16, T.bfloat16, T.float32, num_bits=4, scale_size=scale_size, fast_dequant=fast_dequant, with_bias=with_bias
Expand All @@ -508,12 +515,12 @@ def main(m=256, n=256, k=256, scale_size=32, fast_dequant=True, with_bias=False,
T.float32,
num_bits=4,
scale_size=scale_size,
block_M=256,
block_N=128,
block_K=128,
num_stages=2,
threads=256,
split=1,
block_M=block_M,
block_N=block_N,
block_K=block_K,
num_stages=num_stages,
threads=threads,
split=split,
fast_dequant=fast_dequant,
with_bias=with_bias,
)
Expand All @@ -537,6 +544,11 @@ def main(m=256, n=256, k=256, scale_size=32, fast_dequant=True, with_bias=False,


def run_regression_perf(m=4096, n=4096, k=4096, scale_size=32, fast_dequant=True, with_bias=False):
if target_is_maca(determine_target("auto", return_object=True)):
fast_dequant = False
block_M, block_N, block_K, num_stages, threads, split = 64, 64, 64, 1, 128, 1
else:
block_M, block_N, block_K, num_stages, threads, split = 256, 128, 128, 2, 256, 1
kernel = matmul(
m,
n,
Expand All @@ -546,12 +558,12 @@ def run_regression_perf(m=4096, n=4096, k=4096, scale_size=32, fast_dequant=True
"float32",
num_bits=4,
scale_size=scale_size,
block_M=256,
block_N=128,
block_K=128,
num_stages=2,
threads=256,
split=1,
block_M=block_M,
block_N=block_N,
block_K=block_K,
num_stages=num_stages,
threads=threads,
split=split,
fast_dequant=fast_dequant,
with_bias=with_bias,
)
Expand Down
32 changes: 28 additions & 4 deletions examples/maca/dequantize_gemm/example_dequant_gemm_fp4_hopper.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,13 @@
import itertools
import torch
import argparse
from tilelang.utils.target import determine_target, target_is_maca


def reinterpret_u16_as_float16(bits: torch.Tensor) -> torch.Tensor:
bits_i32 = (bits & 0xFFFF).to(torch.int32)
bits_i16 = torch.where(bits_i32 >= 0x8000, bits_i32 - 0x10000, bits_i32).to(torch.int16)
return bits_i16.view(torch.float16)


def _tir_u8_to_f4_to_f16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str):
Expand Down Expand Up @@ -45,8 +52,7 @@ def _convert(val, pos):
m_f4 = f4 & 1
m_f16 = m_f4
val_f16 = (((e_f16 | (s << 5)) << 10) | (m_f16 << 9)) & 0xFFFF
lower_16_bits = (val_f16 & 0xFFFF).to(torch.uint16)
return lower_16_bits.view(torch.float16)
return reinterpret_u16_as_float16(val_f16)

N = tensor.shape[0]
K = tensor.shape[1]
Expand Down Expand Up @@ -249,8 +255,17 @@ def main(m=256, n=256, k=256, tune=False):
total_flops = 2 * m * n * k

if not tune:
if target_is_maca(determine_target("auto", return_object=True)):
block_M, block_N, block_K, num_stages, threads, split = 64, 64, 64, 1, 128, 1
else:
block_M, block_N, block_K, num_stages, threads, split = 128, 128, 128, 2, 256, 1
kernel = matmul(m, n, k, T.float16, T.float16, T.float32, num_bits=4, tune=tune)(
block_M=128, block_N=128, block_K=128, num_stages=2, threads=256, split=1
block_M=block_M,
block_N=block_N,
block_K=block_K,
num_stages=num_stages,
threads=threads,
split=split,
)
profiler = kernel.get_profiler(tilelang.TensorSupplyType.Integer)
profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
Expand All @@ -271,8 +286,17 @@ def main(m=256, n=256, k=256, tune=False):


def run_regression_perf(m=4096, n=4096, k=4096):
if target_is_maca(determine_target("auto", return_object=True)):
block_M, block_N, block_K, num_stages, threads, split = 64, 64, 64, 1, 128, 1
else:
block_M, block_N, block_K, num_stages, threads, split = 128, 128, 128, 2, 256, 1
kernel = matmul(m, n, k, "float16", "float16", "float32", num_bits=4, tune=False)(
block_M=128, block_N=128, block_K=128, num_stages=2, threads=256, split=1
block_M=block_M,
block_N=block_N,
block_K=block_K,
num_stages=num_stages,
threads=threads,
split=split,
)
profiler = kernel.get_profiler(tilelang.TensorSupplyType.Integer)
return profiler.do_bench(backend="cupti")
Expand Down
7 changes: 5 additions & 2 deletions examples/maca/dequantize_gemm/example_dequant_gemm_w4a8.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,11 @@ def _convert(val, pos):
def ref_program(A, qB):
dtypeC = T.int32
B = torch_convert(qB)
C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
C = C.to(torch.__getattribute__(dtypeC))
# CUDA/MACA does not provide exact int32 matmul here, and float32 matmul
# followed by cast introduces many +/-1 mismatches. Compute the reference
# with exact integer accumulation on CPU instead.
C = torch.matmul(A.cpu().to(torch.int32), B.cpu().to(torch.int32).T)
C = C.to(torch.__getattribute__(dtypeC)).to(A.device)
return C.transpose(0, 1)


Expand Down
14 changes: 11 additions & 3 deletions examples/maca/dequantize_gemm/example_dequant_gemv_fp16xint4.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
from typing import Optional, Callable, Any
import torch
from tilelang import DataType
from tilelang.utils.target import determine_target, target_is_maca
from tilelang.quantize import (
_tir_packed_int_to_int_convert,
_tir_packed_to_unsigned_convert,
)


Expand Down Expand Up @@ -55,6 +57,12 @@ def dequantize_gemv(

import_source: Optional[str] = None
func_name: str = ""
if source_format == "uint":
convert_packed = _tir_packed_to_unsigned_convert(storage_type, storage_nbit)
elif source_format in {"int", "sint"}:
convert_packed = _tir_packed_int_to_int_convert(storage_type, storage_nbit)
else:
raise ValueError(f"Unsupported source_format: {source_format}")
if fast_decoding is True:
# Lazy import to decrease the startup time
# as intrin registry may take a while to load
Expand Down Expand Up @@ -119,7 +127,7 @@ def main(
)
else:
for ki in T.serial(micro_size_k):
B_dequantize_local[ki] = _tir_packed_int_to_int_convert(storage_type, storage_nbit)(
B_dequantize_local[ki] = convert_packed(
num_bits, B_quant_local[ki // num_elems_per_byte], ki % num_elems_per_byte, in_dtype
)

Expand Down Expand Up @@ -167,7 +175,7 @@ def main() -> None:
source_format = "uint"
n_partition = 4
reduce_thread = 32
fast_decoding = True
fast_decoding = not target_is_maca(determine_target("auto", return_object=True))
trans_A = False
trans_B = True
group_size = -1
Expand Down Expand Up @@ -229,7 +237,7 @@ def run_regression_perf():
source_format = "uint"
n_partition = 4
reduce_thread = 32
fast_decoding = True
fast_decoding = not target_is_maca(determine_target("auto", return_object=True))
trans_A = False
trans_B = True
group_size = -1
Expand Down
Loading
Loading