-
Notifications
You must be signed in to change notification settings - Fork 57
Open
Description
Thank you for your brilliant work! I tested the performance of grouped-gemm with some regular input shapes but found no acceleration compared with DeepGEMM. I'm not sure whether the scenario I chose is unsuitable — could you please help me, or recommend some suitable inputs? Many thanks!
env:
Cuda version: H100(81G)
Torch version:2.10.0a0+b558c986e8.nv25.11
My result:
The latency chart:
And below is my test script:
import torch
import os
import sys
from pathlib import Path
from typing import Generator, List, Optional, Tuple
# Make the locally built extension importable before `import hpc` below.
sys.path.insert(0, os.path.realpath(list(Path(__file__).parent.glob("../build/lib.*/"))[0]))
import math
import pytest
import torch  # NOTE(review): duplicate of the import above; harmless but removable.
import hpc
from utils import allclose
# Fix random seeds so the generated test tensors are reproducible run-to-run.
torch.manual_seed(41)
torch.cuda.manual_seed(41)
import deep_gemm
import enum
import random
import time
def ceil_div(x: int, y: int) -> int:
    """Ceiling of x / y for positive y, computed in pure integer arithmetic."""
    quotient, remainder = divmod(x, y)
    return quotient + (1 if remainder else 0)
def align(x: int, y: int) -> int:
    """Round x up to the nearest multiple of y."""
    return ((x + y - 1) // y) * y
def ceil_to_ue8m0(x: torch.Tensor):
    """Round the magnitude of each element of x up to the nearest power of two
    (the UE8M0 scale grid). Requires at least one positive element."""
    assert x.view(-1).amax().item() > 0
    exponent = torch.ceil(torch.log2(x.abs()))
    return torch.exp2(exponent)
def per_token_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool, gran_k: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize a 2-D tensor to FP8 (e4m3) with one scale per (row, gran_k-column block).

    Args:
        x: (m, n) tensor; n need not be a multiple of gran_k (zero padding is used).
        use_ue8m0: if True, round each scale up to a power of two (UE8M0 grid).
        gran_k: quantization granularity along the last dimension.

    Returns:
        Tuple of the (m, n) float8_e4m3fn tensor and the (m, ceil(n / gran_k))
        float32 scale tensor.
    """
    assert x.dim() == 2
    m, n = x.shape
    padded_n = align(n, gran_k)
    # torch.zeros is the idiomatic (single-kernel) equivalent of empty().fill_(0);
    # zero padding keeps a partial final block's amax unaffected.
    x_padded = torch.zeros((m, padded_n), dtype=x.dtype, device=x.device)
    x_padded[:, :n] = x
    x_view = x_padded.view(m, -1, gran_k)
    # Clamp keeps the scale away from zero for all-zero blocks (avoids div-by-zero).
    x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
    sf = x_amax / 448.0  # 448 is the largest finite magnitude of float8_e4m3fn
    sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf
    return (x_view * (1.0 / sf.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, padded_n)[:, :n].contiguous(), sf
def per_block_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool, gran_k: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize a 2-D tensor to FP8 (e4m3) with one scale per gran_k x gran_k tile.

    Both dimensions are zero-padded up to multiples of gran_k; returns the
    (m, n) float8_e4m3fn tensor and the per-tile scale tensor.
    """
    assert x.dim() == 2
    rows, cols = x.shape
    padded = torch.zeros((align(rows, gran_k), align(cols, gran_k)), dtype=x.dtype, device=x.device)
    padded[:rows, :cols] = x
    tiles = padded.view(-1, gran_k, padded.size(1) // gran_k, gran_k)
    # One amax per (row-tile, col-tile); clamp keeps the scale away from zero.
    amax = tiles.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
    scale = amax / 448.0
    if use_ue8m0:
        scale = ceil_to_ue8m0(scale)
    quantized = (tiles * (1.0 / scale)).to(torch.float8_e4m3fn)
    return quantized.view_as(padded)[:rows, :cols].contiguous(), scale.view(tiles.size(0), tiles.size(2))
def naive_group_gemm(x, w, seqlens, cu_seqlens, xscale, wscale):
    """Pure-PyTorch reference: dequantize x and w to bf16, then matmul per group.

    Args (shapes per the caller in this file):
        x:          (m, k) fp8 activations, rows laid out group-by-group.
        w:          (num_group, n, k) fp8 weights.
        seqlens:    int32 tensor of per-group row counts.
        cu_seqlens: prefix sums with a leading 0, so cu_seqlens[i] is the start
                    row of group i.
        xscale:     per-token scales, transposed by the caller before this call
                    -- presumably (k / 128, m); verify against the caller.
        wscale:     (num_group, n / 128, k / 128) per-block weight scales.

    Returns:
        (m, n) bfloat16 output; rows of empty groups stay zero.

    NOTE(review): the xscale reshape below only works when every group has
    exactly m // num_group rows -- confirm this matches the caller, which
    builds equal-sized, 128-aligned groups.
    """
    m, k = x.shape
    num_group, n, _ = w.shape
    m_pergroup = m // num_group
    y = torch.zeros((m, n), dtype=torch.bfloat16, device=x.device)
    # Expand per-token scales from one value per 128-column block to one per
    # column, then reorder to x's (m, k) layout.
    xscale = (xscale.repeat_interleave(128, dim=0).permute(1, 0).reshape((num_group, -1, k)))[
        :, :m_pergroup, :
    ].reshape(-1, k)
    # Expand per-block weight scales to one value per element; [:, :, :k] trims
    # the expansion back to k columns.
    wscale = wscale.repeat_interleave(128, dim=1).repeat_interleave(128, dim=2)[:, :, :k]
    # Dequantize both operands to bf16 before the matmul.
    x = (x.to(torch.bfloat16) * xscale).to(torch.bfloat16)
    w = (w.to(torch.bfloat16) * wscale).to(torch.bfloat16)
    for i in range(num_group):
        start_idx = int(cu_seqlens[i].item())
        end_idx = int(start_idx + seqlens[i].item())  # cu_seqlens[i + 1].item()
        if seqlens[i].item() == 0:
            continue  # empty group: leave its rows as zeros
        x_group = x[start_idx:end_idx]
        w_group = w[i]
        y[start_idx:end_idx] = x_group @ w_group.t()
    return y
def test_speed_and_precision(num_group, actual_m, n, k):
    """Benchmark DeepGEMM vs hpc grouped GEMM on num_group equal-sized groups,
    then check both outputs against the naive bf16 reference.

    Each group has actual_m rows, aligned up to a multiple of 128; padding rows
    get grouped_layout == -1 and zeroed activations. Returns
    (hpc_avg_ms_per_call, deepgemm_avg_ms_per_call).
    """
    dtype = torch.float8_e4m3fn
    warmup_time = 2
    repeat_time = 10
    # Equal-sized groups, each row count aligned up to a multiple of 128.
    actual_ms = [int(actual_m) for _ in range(num_group)]
    aligned_ms = [align(actual_m, 128) for actual_m in actual_ms]
    m = sum(aligned_ms)
    x_fp32 = (torch.randn((m, k), dtype=torch.float, device="cuda") / 10).to(torch.float32)
    w_fp32 = (torch.randn((num_group, n, k), dtype=torch.float, device="cuda") / 10).to(torch.float32)
    grouped_layout = torch.empty(m, device='cuda', dtype=torch.int32)
    dpgmm_out = torch.empty((m, n), device='cuda', dtype=torch.bfloat16)
    ref_d = torch.randn((m, n), device='cuda', dtype=torch.bfloat16)
    start = 0
    # Assign a group id per row; padding rows are marked -1 and zeroed so they
    # cannot contribute to the quantization amax.
    for i, (actual_m, aligned_m) in enumerate(zip(actual_ms, aligned_ms)):
        actual_end = start + actual_m
        aligned_end = start + aligned_m
        grouped_layout[start: actual_end] = i
        grouped_layout[actual_end: aligned_end] = -1
        x_fp32[actual_end: aligned_end] = 0
        ref_d[start: aligned_end] = x_fp32[start: aligned_end] @ w_fp32[i].t()
        start = aligned_end
    # Quantize activations per token (1x128) and weights per 128x128 block.
    x_fp8, xscale = per_token_cast_to_fp8(x_fp32, use_ue8m0=False, gran_k=128)
    w_fp8_v0 = (torch.empty_like(w_fp32, dtype=torch.float8_e4m3fn),
                torch.empty((num_group, ceil_div(n, 128), ceil_div(k, 128)), device='cuda', dtype=torch.float))
    for i in range(num_group):
        w_fp8_v0[0][i], w_fp8_v0[1][i] = per_block_cast_to_fp8(w_fp32[i], use_ue8m0=False, gran_k=128)
    w_fp8, wscale = w_fp8_v0
    # Warm-up
    for i in range(warmup_time):
        deep_gemm.m_grouped_fp8_fp4_gemm_nt_contiguous((x_fp8, xscale), (w_fp8, wscale), dpgmm_out, grouped_layout, disable_ue8m0_cast=True, use_psum_layout=False,
                                                       recipe=None, recipe_a=(1,128), recipe_b=(128, 128))
    torch.cuda.synchronize()
    # Record DeepGEMM latency. With repeat_time == 10, elapsed * 100 below is
    # the average per-call latency in milliseconds.
    st_time = time.time()
    for i in range(repeat_time):
        deep_gemm.m_grouped_fp8_fp4_gemm_nt_contiguous((x_fp8, xscale), (w_fp8, wscale), dpgmm_out, grouped_layout, disable_ue8m0_cast=True, use_psum_layout=False,
                                                       recipe=None, recipe_a=(1,128), recipe_b=(128, 128))
    torch.cuda.synchronize()
    deepgemm_latency = time.time() - st_time
    ########
    # hpc expects the per-token scales transposed -- presumably to
    # (k // 128, m); verify against hpc.group_gemm_blockwise_fp8's contract.
    xscale = xscale.T.contiguous()
    actual_ms = torch.tensor(actual_ms,dtype=torch.int32, device="cuda")
    cu_seqlens = torch.cumsum(torch.cat([torch.tensor([0], dtype=torch.int32, device="cuda"), actual_ms]), dim=0).to(torch.int32)
    mean_seq = int(torch.sum(actual_ms) / num_group)
    # Warm-up
    for i in range(warmup_time):
        my = hpc.group_gemm_blockwise_fp8(
            x_fp8, w_fp8, actual_ms, cu_seqlens, xscale, wscale, num_seq_per_group_avg=mean_seq
        )
    torch.cuda.synchronize()
    # Record hpc grouped-GEMM latency (same averaging as above).
    st_time = time.time()
    for i in range(repeat_time):
        my = hpc.group_gemm_blockwise_fp8(
            x_fp8, w_fp8, actual_ms, cu_seqlens, xscale, wscale, num_seq_per_group_avg=mean_seq
        )
    torch.cuda.synchronize()
    hpc_latency = time.time() - st_time
    ########
    # Check precision of both kernels against the naive bf16 reference.
    gt = naive_group_gemm(x_fp8, w_fp8, actual_ms, cu_seqlens, xscale, wscale)
    torch.testing.assert_close(my.to(torch.float), gt.to(torch.float), rtol=0.08, atol=0.1)
    torch.testing.assert_close(dpgmm_out.to(torch.float), gt.to(torch.float), rtol=0.08, atol=0.1)
    ########
    print(f"[num_group={num_group},m={m},n={n},k={k} ]hpc_latency:{hpc_latency*100:.3f}ms, deepgemm_latency:{deepgemm_latency*100:.3f}ms")
    return hpc_latency*100, deepgemm_latency*100
if __name__ == "__main__":
    # Sweep the benchmark over every (group, m, n, k) shape combination.
    for group_count in [1, 4, 8, 16]:
        for m_dim in [128, 1024, 2048, 4096, 8192, 12288, 16384]:
            for n_dim in [1024, 4096, 8192]:
                for k_dim in [4096, 8192]:
                    test_speed_and_precision(group_count, m_dim, n_dim, k_dim)
Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
No labels