
[Performance] GroupedGEMM seems not to be faster than DeepGEMM #28

@NiuMa-1234

Description

Thank you for your brilliant work! I tested the performance of grouped-gemm with some regular input shapes but found no speedup over DeepGEMM. I'm not sure whether the scenarios I chose are unsuitable; could you please help me or recommend some inputs? Many thanks~

env:

GPU: H100 (81G)
Torch version: 2.10.0a0+b558c986e8.nv25.11

My results:

[results table image]

The latency chart: [latency chart image]

And below is my test script:

import os
import sys
import time
from pathlib import Path
from typing import Tuple

import torch

sys.path.insert(0, os.path.realpath(list(Path(__file__).parent.glob("../build/lib.*/"))[0]))

import deep_gemm
import hpc

# Set random seed for reproducibility
torch.manual_seed(41)
torch.cuda.manual_seed(41)

def ceil_div(x: int, y: int) -> int:
    return (x + y - 1) // y

def align(x: int, y: int) -> int:
    return ceil_div(x, y) * y

def ceil_to_ue8m0(x: torch.Tensor):
    assert x.view(-1).amax().item() > 0
    return torch.pow(2.0, torch.ceil(torch.log2(x.abs())))

def per_token_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool, gran_k: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
    # 1 x gran_k quantization: one scale factor per row for each block of gran_k columns.
    assert x.dim() == 2
    m, n = x.shape
    padded_n = align(n, gran_k)
    x_padded = torch.empty((m, padded_n), dtype=x.dtype, device=x.device).fill_(0)
    x_padded[:, :n] = x
    x_view = x_padded.view(m, -1, gran_k)
    x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
    sf = x_amax / 448.0
    sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf
    return (x_view * (1.0 / sf.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, padded_n)[:, :n].contiguous(), sf


def per_block_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool, gran_k: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
    # gran_k x gran_k block quantization: one scale factor per tile.
    assert x.dim() == 2
    m, n = x.shape
    x_padded = torch.zeros((align(m, gran_k), align(n, gran_k)), dtype=x.dtype, device=x.device)
    x_padded[:m, :n] = x
    x_view = x_padded.view(-1, gran_k, x_padded.size(1) // gran_k, gran_k)
    x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
    sf = x_amax / 448.0
    sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf
    x_scaled = (x_view * (1.0 / sf)).to(torch.float8_e4m3fn)
    return x_scaled.view_as(x_padded)[:m, :n].contiguous(), sf.view(x_view.size(0), x_view.size(2))

def naive_group_gemm(x, w, seqlens, cu_seqlens, xscale, wscale):
    # Reference implementation: dequantize x and w back to bf16, then run one matmul per group.
    m, k = x.shape
    num_group, n, _ = w.shape
    m_pergroup = m // num_group
    y = torch.zeros((m, n), dtype=torch.bfloat16, device=x.device)
    xscale = (xscale.repeat_interleave(128, dim=0).permute(1, 0).reshape((num_group, -1, k)))[
        :, :m_pergroup, :
    ].reshape(-1, k)
    wscale = wscale.repeat_interleave(128, dim=1).repeat_interleave(128, dim=2)[:, :, :k]
    x = (x.to(torch.bfloat16) * xscale).to(torch.bfloat16)
    w = (w.to(torch.bfloat16) * wscale).to(torch.bfloat16)

    for i in range(num_group):
        start_idx = int(cu_seqlens[i].item())
        end_idx = int(start_idx + seqlens[i].item())  # cu_seqlens[i + 1].item()
        if seqlens[i].item() == 0:
            continue
        x_group = x[start_idx:end_idx]
        w_group = w[i]
        y[start_idx:end_idx] = x_group @ w_group.t()
    return y

def test_speed_and_precision(num_group, actual_m, n, k):
    dtype = torch.float8_e4m3fn
    warmup_time = 2
    repeat_time = 10
    actual_ms = [int(actual_m) for _ in range(num_group)]
    aligned_ms = [align(actual_m, 128) for actual_m in actual_ms]
    m = sum(aligned_ms)
    x_fp32 = (torch.randn((m, k), dtype=torch.float, device="cuda") / 10).to(torch.float32)
    w_fp32 = (torch.randn((num_group, n, k), dtype=torch.float, device="cuda") / 10).to(torch.float32)
    grouped_layout = torch.empty(m, device='cuda', dtype=torch.int32)
    dpgmm_out = torch.empty((m, n), device='cuda', dtype=torch.bfloat16)
    ref_d = torch.randn((m, n), device='cuda', dtype=torch.bfloat16)
    start = 0
    for i, (actual_m, aligned_m) in enumerate(zip(actual_ms, aligned_ms)):
        actual_end = start + actual_m
        aligned_end = start + aligned_m
        grouped_layout[start: actual_end] = i
        grouped_layout[actual_end: aligned_end] = -1
        x_fp32[actual_end: aligned_end] = 0
        ref_d[start: aligned_end] = x_fp32[start: aligned_end] @ w_fp32[i].t()
        start = aligned_end

    x_fp8, xscale = per_token_cast_to_fp8(x_fp32, use_ue8m0=False, gran_k=128)
    w_fp8_v0 = (torch.empty_like(w_fp32, dtype=torch.float8_e4m3fn),
                torch.empty((num_group, ceil_div(n, 128), ceil_div(k, 128)), device='cuda', dtype=torch.float))
    for i in range(num_group):
        w_fp8_v0[0][i], w_fp8_v0[1][i] = per_block_cast_to_fp8(w_fp32[i], use_ue8m0=False, gran_k=128)
    w_fp8, wscale = w_fp8_v0
    
    #warmup
    for i in range(warmup_time):
        deep_gemm.m_grouped_fp8_fp4_gemm_nt_contiguous((x_fp8, xscale), (w_fp8, wscale), dpgmm_out, grouped_layout, disable_ue8m0_cast=True, use_psum_layout=False,
                                            recipe=None, recipe_a=(1,128), recipe_b=(128, 128))
    torch.cuda.synchronize()
    #record latency of DeepGEMM
    st_time = time.time()
    for i in range(repeat_time):
        deep_gemm.m_grouped_fp8_fp4_gemm_nt_contiguous((x_fp8, xscale), (w_fp8, wscale), dpgmm_out, grouped_layout, disable_ue8m0_cast=True, use_psum_layout=False,
                                            recipe=None, recipe_a=(1,128), recipe_b=(128, 128))
    torch.cuda.synchronize()
    deepgemm_latency = time.time() - st_time
    ########

    xscale = xscale.T.contiguous()
    actual_ms = torch.tensor(actual_ms, dtype=torch.int32, device="cuda")
    cu_seqlens = torch.cumsum(torch.cat([torch.tensor([0], dtype=torch.int32, device="cuda"), actual_ms]), dim=0).to(torch.int32)
    mean_seq = int(torch.sum(actual_ms) / num_group)
    
    #warmup
    for i in range(warmup_time):
        my = hpc.group_gemm_blockwise_fp8(
            x_fp8, w_fp8, actual_ms, cu_seqlens, xscale, wscale, num_seq_per_group_avg=mean_seq
        )
    torch.cuda.synchronize()
    #record latency of hpc-GroupedGEMM
    st_time = time.time()
    for i in range(repeat_time):
        my = hpc.group_gemm_blockwise_fp8(
            x_fp8, w_fp8, actual_ms, cu_seqlens, xscale, wscale, num_seq_per_group_avg=mean_seq
        )
    torch.cuda.synchronize()
    hpc_latency = time.time() - st_time
    ########

    # check precision
    gt = naive_group_gemm(x_fp8, w_fp8, actual_ms, cu_seqlens, xscale, wscale)
    torch.testing.assert_close(my.to(torch.float), gt.to(torch.float), rtol=0.08, atol=0.1)
    torch.testing.assert_close(dpgmm_out.to(torch.float), gt.to(torch.float), rtol=0.08, atol=0.1)
    ########

    print(f"[num_group={num_group},m={m},n={n},k={k} ]hpc_latency:{hpc_latency*100:.3f}ms, deepgemm_latency:{deepgemm_latency*100:.3f}ms")
    return hpc_latency*100, deepgemm_latency*100

if __name__ == "__main__":
    group_list = [1,4,8,16]
    m_list = [128, 1024, 2048, 4096, 8192, 12288, 16384]
    n_list = [1024, 4096, 8192]
    k_list = [4096, 8192]

    for group in group_list:
        for m in m_list:
            for n in n_list:
                for k in k_list:
                    hpc_time, dpgemm_time = test_speed_and_precision(group, m, n, k)
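
In case the time.time() loop adds host-side noise to these numbers, a CUDA-event version of the same timing could be swapped in. This is only a minimal sketch around the existing calls (it assumes nothing beyond a callable that launches the kernel), not something I have benchmarked separately:

def time_kernel_ms(fn, warmup_time=2, repeat_time=10):
    # Average per-call latency in milliseconds, measured with CUDA events on the current stream.
    for _ in range(warmup_time):
        fn()
    torch.cuda.synchronize()
    start_evt = torch.cuda.Event(enable_timing=True)
    end_evt = torch.cuda.Event(enable_timing=True)
    start_evt.record()
    for _ in range(repeat_time):
        fn()
    end_evt.record()
    torch.cuda.synchronize()
    return start_evt.elapsed_time(end_evt) / repeat_time

# e.g. inside test_speed_and_precision:
# hpc_ms = time_kernel_ms(lambda: hpc.group_gemm_blockwise_fp8(
#     x_fp8, w_fp8, actual_ms, cu_seqlens, xscale, wscale, num_seq_per_group_avg=mean_seq))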
