diff --git a/benchmark/test_blas_perf.py b/benchmark/test_blas_perf.py index a20e95263d..64404cbdbe 100644 --- a/benchmark/test_blas_perf.py +++ b/benchmark/test_blas_perf.py @@ -350,3 +350,32 @@ def addr_input_fn(m, n, cur_dtype, device): dtypes=FLOAT_DTYPES, ) bench.run() + + +class GemmBenchmark(BlasBenchmark): + """ + benchmark for gemm + """ + + pass + + +@pytest.mark.gemm +def test_gemm_benchmark(): + def gemm_input_fn(b, m, n, k, cur_dtype, device, b_column_major): + inp1 = torch.randn([m, k], dtype=cur_dtype, device=device) + if b_column_major: + inp2 = torch.randn([n, k], dtype=cur_dtype, device=device) + yield inp1, inp2.t() + else: + inp2 = torch.randn([k, n], dtype=cur_dtype, device=device) + yield inp1, inp2 + + bench = GemmBenchmark( + input_fn=gemm_input_fn, + op_name="gemm", + torch_op=torch.Tensor.mm, + dtypes=FLOAT_DTYPES, + ) + bench.set_gems(flag_gems.ops.gemm) + bench.run() diff --git a/src/flag_gems/__init__.py b/src/flag_gems/__init__.py index 79de2d94a5..fbaace1510 100644 --- a/src/flag_gems/__init__.py +++ b/src/flag_gems/__init__.py @@ -237,6 +237,8 @@ def torch_ge(v): ("minimum", minimum), ("mm", mm), ("mm.out", mm_out), + ("gemm", gemm), + ("gemm_out", gemm_out), ("mse_loss", mse_loss), ("mul.Tensor", mul), ("mul_.Tensor", mul_), diff --git a/src/flag_gems/ops/__init__.py b/src/flag_gems/ops/__init__.py index 194668a005..7cd8ad867f 100644 --- a/src/flag_gems/ops/__init__.py +++ b/src/flag_gems/ops/__init__.py @@ -104,6 +104,7 @@ from flag_gems.ops.gather import gather, gather_backward from flag_gems.ops.ge import ge, ge_scalar from flag_gems.ops.gelu import gelu, gelu_, gelu_backward +from flag_gems.ops.gemm import gemm, gemm_out from flag_gems.ops.get_scheduler_metadata import get_scheduler_metadata from flag_gems.ops.glu import glu, glu_backward from flag_gems.ops.groupnorm import group_norm, group_norm_backward @@ -375,6 +376,8 @@ "gelu_", "gelu_backward", "get_scheduler_metadata", + "gemm", + "gemm_out", "glu", "glu_backward", "group_norm", diff --git a/src/flag_gems/ops/gemm.py b/src/flag_gems/ops/gemm.py new file mode 100644 index 0000000000..c385b5fc68 --- /dev/null +++ b/src/flag_gems/ops/gemm.py @@ -0,0 +1,409 @@ +import logging +from typing import Optional + +import torch +import triton +import triton.language as tl + +from flag_gems import runtime +from flag_gems.ops.mm_streamk import streamk_mm +from flag_gems.runtime import torch_device_fn +from flag_gems.utils import libentry, libtuner +from flag_gems.utils import triton_lang_extension as tle +from flag_gems.utils.device_info import get_device_capability, get_sm_count + +CACHE_USAGE_THRESHOLD = 0.8 + +logger = logging.getLogger(__name__) + + +def is_tma_compatible(a, b, N, K): + """ + Check if tensors are compatible with TMA (Tensor Memory Accelerator). + + TMA requires 128-bit (16-byte) alignment for memory access: + - For FP16/BF16 (2 bytes/element): N and K must be multiples of 8 + (8 elements × 2 bytes = 16 bytes) + - For FP32 (4 bytes/element): N and K must be multiples of 4 + (4 elements × 4 bytes = 16 bytes) + + Args: + a, b: Input tensors + N, K: Matrix dimensions + + Returns: + bool: True if compatible with TMA's 128-bit alignment requirement + """ + return ( + a.dtype in (torch.float16, torch.bfloat16) + and b.dtype in (torch.float16, torch.bfloat16) + and N % 8 == 0 + and K % 8 == 0 + ) or ( + a.dtype in (torch.float32,) + and b.dtype in (torch.float32,) + and N % 4 == 0 + and K % 4 == 0 + ) + + +@triton.jit +def prev_multiple_of(a, b): + # the largest x= "3.4" + and hasattr(triton.tools.tensor_descriptor, "TensorDescriptor") + and is_tma_compatible(a, b, N, K) + ): + logger.debug("Using TMA-optimized kernel") + a_row_major = a.stride(1) == 1 + b_row_major = b.stride(1) == 1 + dummy_block = [1, 1] + # triton 3.5.0 + from triton.tools.tensor_descriptor import TensorDescriptor + + if a_row_major: + a_desc = TensorDescriptor(a, a.shape, a.stride(), dummy_block) + else: + a_desc = TensorDescriptor(a, a.T.shape, a.T.stride(), dummy_block) + if b_row_major: + b_desc = TensorDescriptor(b, b.shape, b.stride(), dummy_block) + else: + b_desc = TensorDescriptor(b, b.T.shape, b.T.stride(), dummy_block) + c_desc = TensorDescriptor(c, c.shape, c.stride(), dummy_block) + + input_dtype = a.dtype + dtype_str = str(input_dtype).split(".")[-1] + + with torch_device_fn.device(a.device): + gemm_kernel_general_host_tma[grid]( + a_desc, + b_desc, + c_desc, + alpha, + beta, + M, + N, + K, + a.stride(0), + a.stride(1), + b.stride(0), + b.stride(1), + c.stride(0), + c.stride(1), + GROUP_M=8, + A_ROW_MAJOR=a_row_major, + B_ROW_MAJOR=b_row_major, + dtype=dtype_str, + ) + else: + logger.debug("Using regular kernel") + + def alloc_fn(size: int, align: int, stream: Optional[int]): + return torch.empty(size, dtype=torch.int8, device=a.device) + + triton.set_allocator(alloc_fn) + + with torch_device_fn.device(a.device): + gemm_kernel_general[grid]( + a, + b, + c, + alpha, + beta, + M, + N, + K, + a.stride(0), + a.stride(1), + b.stride(0), + b.stride(1), + c.stride(0), + c.stride(1), + GROUP_M=8, + ) + return c + + +def streamk_scenario(a, b, alpha, beta, M, N, K): + # TODO: this my change sometime according to the realbenchmark result + # Currently, the best configuration for streamk has only been tested on A100(capability[0] == 8). + # The optimal settings for other devices need to be determined through real testing. + capability = get_device_capability() + return ( + capability[0] == 8 + and alpha == 1 + and beta == 0 + and a.dtype in [torch.float16, torch.bfloat16] + and b.dtype in [torch.float16, torch.bfloat16] + and a.is_contiguous() + and b.is_contiguous() + and K > M * 5 + and K > N * 5 + ) + + +def gemm(a, b, beta=0, alpha=1): + device = a.device + # handle non-contiguous inputs if necessary + if a.stride(0) > 1 and a.stride(1) > 1: + a = a.contiguous() + if b.stride(0) > 1 and b.stride(1) > 1: + b = b.contiguous() + # checks constraints + assert a.shape[1] == b.shape[0], "incompatible dimensions" + M, K = a.shape + _, N = b.shape + # allocates output + c_dtype = get_higher_dtype(a.dtype, b.dtype) + c = torch.empty((M, N), device=device, dtype=c_dtype) + # l2_cache_size = get_l2_cache_size() + sm_count = get_sm_count() + if streamk_scenario(a, b, alpha, beta, M, N, K): + return streamk_mm(a, b, c, M, N, K, sm_count=sm_count) + else: + return general_gemm(a, b, c, alpha, beta, M, N, K) + + +def gemm_out(a, b, *, beta=0, alpha=1, out): + # handle non-contiguous inputs if necessary + if a.stride(0) > 1 and a.stride(1) > 1: + a = a.contiguous() + if b.stride(0) > 1 and b.stride(1) > 1: + b = b.contiguous() + # checks constraints + assert a.shape[1] == b.shape[0], "incompatible dimensions" + M, K = a.shape + _, N = b.shape + # l2_cache_size = get_l2_cache_size() + sm_count = get_sm_count() + if streamk_scenario(a, b, alpha, beta, M, N, K): + return streamk_mm(a, b, out, M, N, K, sm_count=sm_count) + else: + return general_gemm(a, b, out, alpha, beta, M, N, K) diff --git a/src/flag_gems/runtime/backend/_nvidia/heuristics_config_utils.py b/src/flag_gems/runtime/backend/_nvidia/heuristics_config_utils.py index 149d2a568f..ac58b0b200 100644 --- a/src/flag_gems/runtime/backend/_nvidia/heuristics_config_utils.py +++ b/src/flag_gems/runtime/backend/_nvidia/heuristics_config_utils.py @@ -183,6 +183,10 @@ def mm_heur_even_k(args): return args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0 +def gemm_heur_even_k(args): + return args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0 + + def rand_heur_block(args): if args["N"] <= 512: return 512 @@ -438,6 +442,9 @@ def mean_heur_one_tile_per_cta(args): "mm": { "EVEN_K": mm_heur_even_k, }, + "gemm": { + "EVEN_K": gemm_heur_even_k, + }, "rand": { "BLOCK": rand_heur_block, "num_warps": rand_heur_num_warps, diff --git a/src/flag_gems/runtime/backend/_nvidia/hopper/ops/__init__.py b/src/flag_gems/runtime/backend/_nvidia/hopper/ops/__init__.py index 0654a948de..7d5b7586a1 100644 --- a/src/flag_gems/runtime/backend/_nvidia/hopper/ops/__init__.py +++ b/src/flag_gems/runtime/backend/_nvidia/hopper/ops/__init__.py @@ -2,5 +2,4 @@ if triton.__version__ >= "3.4": from .mm import mm # noqa: F401 - __all__ = ["*"] diff --git a/src/flag_gems/runtime/backend/_nvidia/hopper/tune_configs.yaml b/src/flag_gems/runtime/backend/_nvidia/hopper/tune_configs.yaml index ede1cadeff..044adcc3a6 100644 --- a/src/flag_gems/runtime/backend/_nvidia/hopper/tune_configs.yaml +++ b/src/flag_gems/runtime/backend/_nvidia/hopper/tune_configs.yaml @@ -131,3 +131,136 @@ mm: BLOCK_K: 32 num_stages: 4 num_warps: 8 +gemm: + - META: + BLOCK_M: 64 + BLOCK_N: 64 + BLOCK_K: 64 + num_stages: 5 + num_warps: 4 + - META: + BLOCK_M: 64 + BLOCK_N: 128 + BLOCK_K: 64 + num_stages: 3 + num_warps: 4 + - META: + BLOCK_M: 64 + BLOCK_N: 64 + BLOCK_K: 128 + num_stages: 4 + num_warps: 8 + - META: + BLOCK_M: 64 + BLOCK_N: 128 + BLOCK_K: 128 + num_stages: 4 + num_warps: 8 + - META: + BLOCK_M: 64 + BLOCK_N: 64 + BLOCK_K: 64 + num_stages: 4 + num_warps: 4 + - META: + BLOCK_M: 64 + BLOCK_N: 64 + BLOCK_K: 128 + num_stages: 3 + num_warps: 4 + - META: + BLOCK_M: 64 + BLOCK_N: 128 + BLOCK_K: 64 + num_stages: 4 + num_warps: 4 + - META: + BLOCK_M: 128 + BLOCK_N: 128 + BLOCK_K: 32 + num_stages: 5 + num_warps: 8 + - META: + BLOCK_M: 128 + BLOCK_N: 128 + BLOCK_K: 32 + num_stages: 4 + num_warps: 8 + - META: + BLOCK_M: 32 + BLOCK_N: 64 + BLOCK_K: 256 + num_stages: 4 + num_warps: 4 + - META: + BLOCK_M: 16 + BLOCK_N: 64 + BLOCK_K: 256 + num_stages: 4 + num_warps: 4 + - META: + BLOCK_M: 16 + BLOCK_N: 64 + BLOCK_K: 64 + num_stages: 5 + num_warps: 2 + - META: + BLOCK_M: 16 + BLOCK_N: 64 + BLOCK_K: 128 + num_stages: 4 + num_warps: 2 + - META: + BLOCK_M: 128 + BLOCK_N: 128 + BLOCK_K: 64 + num_stages: 3 + num_warps: 8 + - META: + BLOCK_M: 128 + BLOCK_N: 256 + BLOCK_K: 64 + num_stages: 3 + num_warps: 8 + - META: + BLOCK_M: 256 + BLOCK_N: 128 + BLOCK_K: 64 + num_stages: 3 + num_warps: 8 + - META: + BLOCK_M: 128 + BLOCK_N: 256 + BLOCK_K: 32 + num_stages: 3 + num_warps: 8 + - META: + BLOCK_M: 256 + BLOCK_N: 128 + BLOCK_K: 32 + num_stages: 3 + num_warps: 8 + - META: + BLOCK_M: 128 + BLOCK_N: 256 + BLOCK_K: 64 + num_stages: 4 + num_warps: 8 + - META: + BLOCK_M: 256 + BLOCK_N: 128 + BLOCK_K: 64 + num_stages: 4 + num_warps: 8 + - META: + BLOCK_M: 128 + BLOCK_N: 256 + BLOCK_K: 32 + num_stages: 4 + num_warps: 8 + - META: + BLOCK_M: 256 + BLOCK_N: 128 + BLOCK_K: 32 + num_stages: 4 + num_warps: 8 diff --git a/src/flag_gems/runtime/backend/_nvidia/tune_configs.yaml b/src/flag_gems/runtime/backend/_nvidia/tune_configs.yaml index 9b34bac55c..7627dd4486 100644 --- a/src/flag_gems/runtime/backend/_nvidia/tune_configs.yaml +++ b/src/flag_gems/runtime/backend/_nvidia/tune_configs.yaml @@ -273,6 +273,139 @@ mm: BLOCK_K: 32 num_stages: 4 num_warps: 8 +gemm: + - META: + BLOCK_M: 64 + BLOCK_N: 64 + BLOCK_K: 64 + num_stages: 5 + num_warps: 4 + - META: + BLOCK_M: 64 + BLOCK_N: 128 + BLOCK_K: 64 + num_stages: 3 + num_warps: 4 + - META: + BLOCK_M: 64 + BLOCK_N: 64 + BLOCK_K: 128 + num_stages: 4 + num_warps: 8 + - META: + BLOCK_M: 64 + BLOCK_N: 128 + BLOCK_K: 128 + num_stages: 4 + num_warps: 8 + - META: + BLOCK_M: 64 + BLOCK_N: 64 + BLOCK_K: 64 + num_stages: 4 + num_warps: 4 + - META: + BLOCK_M: 64 + BLOCK_N: 64 + BLOCK_K: 128 + num_stages: 3 + num_warps: 4 + - META: + BLOCK_M: 64 + BLOCK_N: 128 + BLOCK_K: 64 + num_stages: 4 + num_warps: 4 + - META: + BLOCK_M: 128 + BLOCK_N: 128 + BLOCK_K: 32 + num_stages: 5 + num_warps: 8 + - META: + BLOCK_M: 128 + BLOCK_N: 128 + BLOCK_K: 32 + num_stages: 4 + num_warps: 8 + - META: + BLOCK_M: 32 + BLOCK_N: 64 + BLOCK_K: 256 + num_stages: 4 + num_warps: 4 + - META: + BLOCK_M: 16 + BLOCK_N: 64 + BLOCK_K: 256 + num_stages: 4 + num_warps: 4 + - META: + BLOCK_M: 16 + BLOCK_N: 64 + BLOCK_K: 64 + num_stages: 5 + num_warps: 2 + - META: + BLOCK_M: 16 + BLOCK_N: 64 + BLOCK_K: 128 + num_stages: 4 + num_warps: 2 + - META: + BLOCK_M: 128 + BLOCK_N: 128 + BLOCK_K: 64 + num_stages: 3 + num_warps: 8 + - META: + BLOCK_M: 128 + BLOCK_N: 256 + BLOCK_K: 64 + num_stages: 3 + num_warps: 8 + - META: + BLOCK_M: 256 + BLOCK_N: 128 + BLOCK_K: 64 + num_stages: 3 + num_warps: 8 + - META: + BLOCK_M: 128 + BLOCK_N: 256 + BLOCK_K: 32 + num_stages: 3 + num_warps: 8 + - META: + BLOCK_M: 256 + BLOCK_N: 128 + BLOCK_K: 32 + num_stages: 3 + num_warps: 8 + - META: + BLOCK_M: 128 + BLOCK_N: 256 + BLOCK_K: 64 + num_stages: 4 + num_warps: 8 + - META: + BLOCK_M: 256 + BLOCK_N: 128 + BLOCK_K: 64 + num_stages: 4 + num_warps: 8 + - META: + BLOCK_M: 128 + BLOCK_N: 256 + BLOCK_K: 32 + num_stages: 4 + num_warps: 8 + - META: + BLOCK_M: 256 + BLOCK_N: 128 + BLOCK_K: 32 + num_stages: 4 + num_warps: 8 baddbmm: - META: TILE_M: 32 diff --git a/tests/test_blas_ops.py b/tests/test_blas_ops.py index d1b88f2a1e..aa82b91525 100644 --- a/tests/test_blas_ops.py +++ b/tests/test_blas_ops.py @@ -289,6 +289,36 @@ def test_accuracy_mm(M, N, K, dtype, b_column_major): gems_assert_close(res_out, ref_out, dtype, reduce_dim=K) +@pytest.mark.gemm +@pytest.mark.parametrize("M, N, K", MNK_SHAPES) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +@pytest.mark.parametrize("b_column_major", [True, False]) +def test_accuracy_gemm(M, N, K, dtype, b_column_major): + if flag_gems.vendor_name == "tsingmicro" and dtype == torch.float32: + pytest.skip("Skiping fp32 mm test on tsingmicro platform") + + torch.manual_seed(0) + torch.cuda.manual_seed_all(0) + np.random.seed(0) + random.seed(0) + + alpha = 2.0 + beta = 0 + mat1 = torch.randn((M, K), dtype=dtype, device=flag_gems.device) + if b_column_major: + mat2 = torch.randn((N, K), dtype=dtype, device=flag_gems.device).t() + else: + mat2 = torch.randn((K, N), dtype=dtype, device=flag_gems.device) + ref_mat1 = to_reference(mat1, True) + ref_mat2 = to_reference(mat2, True) + + ref_out = torch.mm(ref_mat1, ref_mat2) * alpha + with flag_gems.use_gems(): + res_out = flag_gems.ops.gemm(mat1, mat2, beta, alpha) + + gems_assert_close(res_out, ref_out, dtype, reduce_dim=K) + + @pytest.mark.mv @pytest.mark.parametrize("M, N", MN_SHAPES) @pytest.mark.parametrize("dtype", FLOAT_DTYPES)