add dequant_gemm and gemm_streamk examples #907
pbbb205 wants to merge 1 commit into tile-ai:ascendc_pto
Conversation
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run the repository's formatting checks before committing. We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
Code Review
This pull request introduces design documentation and example implementations for FP16xINT4 Dequantize GEMV, INT8xINT4 Dequantize GEMM, and StreamK GEMM tailored for the Ascend NPU. The review highlights several improvements for performance and correctness: vectorizing host-side INT4 unpacking to replace inefficient Python loops, correcting the grid size of the persistent kernel to match the physical core count and avoid redundant computation, and enabling pipelining and Ascend-specific compiler configurations to improve hardware utilization.
```python
        C: T.Tensor((M, N), dtype),  # type: ignore
):
    # Grid size is the total number of tiles; T.Persistent load-balances internally
    with T.Kernel(T.ceildiv(M, block_M) * T.ceildiv(N, block_N), is_npu=True) as (cid, _):
```
The grid size for a persistent kernel should be equal to the number of physical cores (core_num), not the total number of tiles. Using the total number of tiles as the grid size causes each block to redundantly check all tiles in the T.Persistent loop, and if grid_size > core_num, tiles will be computed multiple times. For a persistent kernel to work as intended, the grid size must match the wave_size parameter of T.Persistent.
Suggested change:
```diff
- with T.Kernel(T.ceildiv(M, block_M) * T.ceildiv(N, block_N), is_npu=True) as (cid, _):
+ with T.Kernel(core_num, is_npu=True) as (cid, _):
```
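To make the intended pattern concrete, here is a minimal sketch of the persistent loop the reviewer describes. The exact argument names of T.Persistent are assumptions; only the wave_size parameter is confirmed by the comment above:

```python
# Sketch only: T.Persistent's argument names are assumed, not taken from the PR.
num_tiles = T.ceildiv(M, block_M) * T.ceildiv(N, block_N)
with T.Kernel(core_num, is_npu=True) as (cid, _):
    # Launch exactly core_num blocks; each block walks tiles
    # cid, cid + core_num, cid + 2 * core_num, ... so that every tile
    # is computed exactly once when wave_size equals the grid size.
    for tile_id in T.Persistent(num_tiles, wave_size=core_num, index=cid):
        bx = tile_id // T.ceildiv(N, block_N)  # row-tile index
        by = tile_id % T.ceildiv(N, block_N)   # column-tile index
        ...  # per-tile GEMM body
```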
```python
B = torch.zeros(N, K, dtype=torch.float16, device=B_packed.device)
for j in range(K):
    shift = 4 * (j % 2)
    B[:, j] = ((B_packed[:, j // 2].int() >> shift) & 0xF).half()
```
The CPU-based unpack implementation uses a Python loop over the K dimension, which is extremely inefficient for large tensors. Using vectorized PyTorch operations will significantly speed up the host-side preprocessing.
```python
B = torch.empty(N, K, dtype=torch.float16, device=B_packed.device)
B[:, 0::2] = (B_packed & 0x0F).half()
B[:, 1::2] = ((B_packed.to(torch.uint8) >> 4) & 0x0F).half()
```
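As a sanity check, the loop and the vectorized rewrite can be compared on the host. This sketch assumes B_packed is an (N, K // 2) uint8 tensor with the low nibble stored first:

```python
import torch

N, K = 8, 16
B_packed = torch.randint(0, 256, (N, K // 2), dtype=torch.uint8)

# Reference: the original per-column loop
ref = torch.zeros(N, K, dtype=torch.float16)
for j in range(K):
    shift = 4 * (j % 2)
    ref[:, j] = ((B_packed[:, j // 2].int() >> shift) & 0xF).half()

# Vectorized: even columns take the low nibble, odd columns the high nibble
vec = torch.empty(N, K, dtype=torch.float16)
vec[:, 0::2] = (B_packed & 0x0F).half()
vec[:, 1::2] = ((B_packed >> 4) & 0x0F).half()

assert torch.equal(ref, vec)
```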
```python
B = torch.zeros(N, K, dtype=torch.int8, device=B_packed.device)
for j in range(K):
    shift = 4 * (j % 2)
    i4 = (B_packed[:, j // 2].to(torch.int32) >> shift) & 0xF
    # Sign extension
    i4_signed = (i4 << 28) >> 28
    B[:, j] = i4_signed.to(torch.int8)
```
The CPU-based unpack loop is inefficient. It can be replaced with vectorized operations that handle 4-bit sign extension more effectively using arithmetic shifts.
```python
B = torch.empty(N, K, dtype=torch.int8, device=B_packed.device)
# Extract low 4 bits and sign extend
B[:, 0::2] = (B_packed << 4).to(torch.int8) >> 4
# Extract high 4 bits and sign extend
B[:, 1::2] = B_packed >> 4
```
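The sign extension relies on arithmetic shifts: moving the low nibble up to the sign-bit position and back down reinterprets it as a signed 4-bit value. A small self-contained check, assuming the same low-nibble-first packing:

```python
import torch

# Two packed bytes: nibbles (low, high) = (0xF, 0x7) and (0x8, 0x9)
B_packed = torch.tensor([0x7F, 0x98], dtype=torch.uint8)

# Low nibble: shift into the sign-bit position, then arithmetic-shift back
low = (B_packed << 4).to(torch.int8) >> 4    # -> [-1, -8]
# High nibble: reinterpret the byte as signed, then arithmetic-shift
high = B_packed.to(torch.int8) >> 4          # -> [ 7, -7]

assert low.tolist() == [-1, -8] and high.tolist() == [7, -7]
```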
```python
return math.ceil(a / b)


@tl.jit(out_idx=[-1])
```
For Ascend NPU kernels, it is highly recommended to set pass_configs such as TL_ASCEND_AUTO_SYNC to ensure correct synchronization and optimal performance. This allows the compiler to automatically insert necessary synchronization primitives.
```python
@tl.jit(
    out_idx=[-1],
    pass_configs={
        tl.PassConfigKey.TL_ASCEND_AUTO_SYNC: True,
        tl.PassConfigKey.TL_ASCEND_AUTO_CV_COMBINE: True,
    }
)
```
```python
for k in T.serial(loop_k):
    T.copy(A[bx * block_M, k * block_K], A_L1)
    T.copy(B[k * block_K, by * block_N], B_L1)

    T.barrier_all()
    T.gemm_v0(A_L1, B_L1, C_L0, init=(k == 0))
    T.barrier_all()
```
The design document and the module docstring specify using T.Pipelined for K-dimension optimization. The current implementation uses T.serial, which does not overlap memory transfers with computation. Switching to T.Pipelined will improve performance.
Suggested change:
```diff
- for k in T.serial(loop_k):
-     T.copy(A[bx * block_M, k * block_K], A_L1)
-     T.copy(B[k * block_K, by * block_N], B_L1)
-     T.barrier_all()
-     T.gemm_v0(A_L1, B_L1, C_L0, init=(k == 0))
-     T.barrier_all()
+ for k in T.Pipelined(loop_k, num_stages=2):
+     T.copy(A[bx * block_M, k * block_K], A_L1)
+     T.copy(B[k * block_K, by * block_N], B_L1)
+     T.gemm_v0(A_L1, B_L1, C_L0, init=(k == 0))
```
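A note on this change: with num_stages=2 the L1 staging buffers are double-buffered, so the copy for iteration k + 1 can overlap the GEMM of iteration k. The explicit T.barrier_all() calls are dropped in the suggestion because T.Pipelined takes over inserting the synchronization between pipeline stages.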
No description provided.