diff --git a/src/flag_gems/runtime/backend/_mthreads/ops/mm.py b/src/flag_gems/runtime/backend/_mthreads/ops/mm.py index b4e3ead152..4fcd94a207 100644 --- a/src/flag_gems/runtime/backend/_mthreads/ops/mm.py +++ b/src/flag_gems/runtime/backend/_mthreads/ops/mm.py @@ -10,11 +10,33 @@ from flag_gems.utils import libentry, libtuner from flag_gems.utils import triton_lang_extension as tle -from .utils import create_tma_device_descriptor, should_enable_sqmma +from .utils import create_tma_device_descriptor, get_cached_tma_device_descriptor + +logger = logging.getLogger("flag_gems.runtime.backend._mthreads.ops.mm") + + +def is_supported_sqmma_layout(tensor): + return tensor.is_contiguous() or ( + tensor.stride(0) == 1 and tensor.stride(1) == tensor.shape[0] + ) -logger = logging.getLogger( - f'flag_gems.runtime.backend._mthreads.ops.{__name__.split(".")[-1]}' -) + +def is_sqmma_compatible(a, b, N, K): + return ( + os.getenv("MUSA_ENABLE_SQMMA", "0") == "1" + and a.dim() == 2 + and b.dim() == 2 + and a.dtype == b.dtype + and a.dtype in (torch.float16, torch.bfloat16) + and is_supported_sqmma_layout(a) + and is_supported_sqmma_layout(b) + and N % 8 == 0 + and K % 8 == 0 + ) + + +def matmul_get_configs(): + return runtime.get_tuned_config("mm") @triton.jit @@ -25,9 +47,9 @@ def prev_multiple_of(a, b): @libentry() @libtuner( - configs=runtime.get_tuned_config("mm"), - key=["M", "N", "K"], - strategy=["align32", "align32", "align32"], + configs=matmul_get_configs(), + key=["M", "N", "K", "stride_am", "stride_bk"], + strategy=["align32", "align32", "align32", "align32", "align32"], ) @triton.jit def mm_kernel( @@ -43,6 +65,7 @@ def mm_kernel( stride_bn, stride_cm, stride_cn, + dtype: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, @@ -101,6 +124,58 @@ def mm_kernel( tl.store(C, acc, mask=mask) +def gemv_get_configs(): + return [triton.Config({"BLOCK_M": 64, "BLOCK_K": 64})] + + +@libentry() +@libtuner( + configs=gemv_get_configs(), + key=["M", "K", "stride_am", "stride_bk"], + strategy=["align32", "align32", "align32", "default"], +) +@triton.jit +def gemv_kernel( + A, + B, + C, + M, + K, + stride_am, + stride_ak, + stride_bk, + stride_cm, + BLOCK_M: tl.constexpr, + BLOCK_K: tl.constexpr, +): + pid = tle.program_id(0) + + row_start = pid * BLOCK_M + row_offset = row_start + tl.arange(0, BLOCK_M) + row_mask = row_offset < M + + acc = tl.zeros((BLOCK_M,), dtype=tl.float32) + + for k_start in range(0, K, BLOCK_K): + k_offset = k_start + tl.arange(0, BLOCK_K) + k_mask = k_offset < K + + a_ptrs = A + row_offset[:, None] * stride_am + k_offset[None, :] * stride_ak + a = tl.load(a_ptrs, mask=row_mask[:, None] & k_mask[None, :], other=0.0) + + b_ptrs = B + k_offset * stride_bk + b = tl.load(b_ptrs, mask=k_mask, other=0.0) + + # Keep the reduction in fp32 so N=1 GEMV matches the mm path more closely. + a = a.to(tl.float32) + b = b.to(tl.float32) + acc += tl.sum(a * b[None, :], axis=1) + + c_ptrs = C + row_offset * stride_cm + acc = acc.to(C.dtype.element_ty) + tl.store(c_ptrs, acc, mask=row_mask) + + _ordered_datatypes = [torch.float16, torch.bfloat16, torch.float32] @@ -151,11 +226,34 @@ def mm_fma(a, b): b.stride(1), c.stride(0), c.stride(1), + dtype=str(a.dtype).split(".")[-1], GROUP_M=8, ) return c +def gemv_mm(a, b, c, M, K): + logger.debug( + "GEMS_MTHREADS MM(GEMV), [shape info]: [%s, %s, 1](M, K, N)", + M, + K, + ) + grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]),) + with torch_device_fn.device(a.device): + gemv_kernel[grid]( + a, + b, + c, + M, + K, + a.stride(0), + a.stride(1), + b.stride(0), + c.stride(0), + ) + return c + + def mm_out(a, b, *, out): logger.debug("GEMS_MTHREADS MM_OUT") # handle non-contiguous inputs if necessary @@ -169,6 +267,8 @@ def mm_out(a, b, *, out): _, N = b.shape # allocates output c = out + if N == 1: + return gemv_mm(a, b, c, M, K) # launch kernel grid = lambda META: ( triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]), @@ -187,62 +287,106 @@ def mm_out(a, b, *, out): b.stride(1), c.stride(0), c.stride(1), + dtype=str(a.dtype).split(".")[-1], GROUP_M=8, ) return c +def sqmma_descriptor_pre_hook(nargs): + a = nargs["A"] + b = nargs["B"] + c = nargs["C"] + block_m = nargs["BLOCK_M"] + block_n = nargs["BLOCK_N"] + block_k = nargs["BLOCK_K"] + device = c.device + + nargs["a_desc_ptr"].copy_( + get_cached_tma_device_descriptor(a, block_m, block_k, device) + ) + nargs["b_desc_ptr"].copy_( + get_cached_tma_device_descriptor(b, block_k, block_n, device) + ) + nargs["c_desc_ptr"].copy_(create_tma_device_descriptor(c, block_m, block_n, device)) + + +def sqmma_get_configs(pre_hook=sqmma_descriptor_pre_hook): + return [ + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 64}, + num_stages=1, + num_warps=4, + pre_hook=pre_hook, + ) + ] + + +@libentry() +@libtuner( + configs=sqmma_get_configs(), + key=["M", "N", "K", "stride_am", "stride_bk", "dtype"], + strategy=["align32", "align32", "align32", "align32", "align32", "default"], +) @triton.jit def mm_sqmma_kernel( + A, + B, + C, a_desc_ptr, b_desc_ptr, c_desc_ptr, M, N, K, + stride_am, + stride_ak, + stride_bk, + stride_bn, + dtype: tl.constexpr, GROUP_M: tl.constexpr, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, ab_dtype: tl.constexpr, c_dtype: tl.constexpr, is_transpose_a: tl.constexpr = False, is_transpose_b: tl.constexpr = False, ): pid = tle.program_id(0) - grid_m = tl.cdiv(M, BLOCK_SIZE_M) - grid_n = tl.cdiv(N, BLOCK_SIZE_N) + grid_m = tl.cdiv(M, BLOCK_M) + grid_n = tl.cdiv(N, BLOCK_N) width = GROUP_M * grid_n group_id = pid // width group_size = min(grid_m - group_id * GROUP_M, GROUP_M) pid_m = group_id * GROUP_M + (pid % group_size) pid_n = (pid % width) // (group_size) - offs_am = pid_m * BLOCK_SIZE_M - offs_bn = pid_n * BLOCK_SIZE_N + offs_am = pid_m * BLOCK_M + offs_bn = pid_n * BLOCK_N offs_k = 0 offs_am = offs_am.to(tl.int32) offs_bn = offs_bn.to(tl.int32) offs_k = offs_k.to(tl.int32) - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) tme_load_ab_dtype = ab_dtype c_store_dtype = c_dtype - for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + for k in range(0, tl.cdiv(K, BLOCK_K)): a = tl._experimental_descriptor_load( a_desc_ptr, [offs_am, offs_k], - [BLOCK_SIZE_M, BLOCK_SIZE_K], + [BLOCK_M, BLOCK_K], tme_load_ab_dtype, is_transpose_a, ) b = tl._experimental_descriptor_load( b_desc_ptr, [offs_k, offs_bn], - [BLOCK_SIZE_K, BLOCK_SIZE_N], + [BLOCK_K, BLOCK_N], tme_load_ab_dtype, is_transpose_b, ) accumulator += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False) - offs_k += BLOCK_SIZE_K + offs_k += BLOCK_K accumulator = accumulator.to(c_store_dtype) tl._experimental_descriptor_store(c_desc_ptr, accumulator, [offs_am, offs_bn]) @@ -256,9 +400,9 @@ def get_triton_type(elem_type): return type_map.get(elem_type, None) -def mm_sqmma(A, B, M, N, K, GROUP_M, BLOCK_M, BLOCK_N, BLOCK_K, num_warps, num_stages): +def mm_sqmma(A, B, M, N, K, GROUP_M): logger.debug("GEMS_MTHREADS MM(SQMMA)") - device = "musa" + device = A.device # handle non-contiguous inputs if necessary is_transpose_a = False is_transpose_b = False @@ -277,24 +421,32 @@ def mm_sqmma(A, B, M, N, K, GROUP_M, BLOCK_M, BLOCK_N, BLOCK_K, num_warps, num_s assert a_type == b_type, "Mat A and Mat B should have the same dtype" c_dtype = get_higher_dtype(a_type, b_type) C = torch.empty((M, N), dtype=c_dtype, device=device) - desc_a = create_tma_device_descriptor(A, BLOCK_M, BLOCK_K, device) - desc_b = create_tma_device_descriptor(B, BLOCK_K, BLOCK_N, device) - desc_c = create_tma_device_descriptor(C, BLOCK_M, BLOCK_N, device) - mm_sqmma_kernel[(triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), 1, 1)]( + desc_a = torch.empty((64,), dtype=torch.int8, device=device) + desc_b = torch.empty((64,), dtype=torch.int8, device=device) + desc_c = torch.empty((64,), dtype=torch.int8, device=device) + grid = lambda META: ( + triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]), + 1, + 1, + ) + mm_sqmma_kernel[grid]( + A, + B, + C, desc_a, desc_b, desc_c, M, N, K, - GROUP_M, - BLOCK_M, - BLOCK_N, - BLOCK_K, - get_triton_type(a_type), - get_triton_type(c_dtype), - num_warps=num_warps, - num_stages=num_stages, + A.stride(0), + A.stride(1), + B.stride(0), + B.stride(1), + str(a_type).split(".")[-1], + GROUP_M=GROUP_M, + ab_dtype=get_triton_type(a_type), + c_dtype=get_triton_type(c_dtype), is_transpose_a=is_transpose_a, is_transpose_b=is_transpose_b, ) @@ -306,14 +458,12 @@ def mm(a, b): b_dtype = b.dtype M, K = a.shape _, N = b.shape - use_sqmma = should_enable_sqmma(a_dtype, b_dtype, M, N, K) - if use_sqmma: + if N == 1: + c_dtype = get_higher_dtype(a_dtype, b_dtype) + c = torch.empty((M, N), device=a.device, dtype=c_dtype) + return gemv_mm(a, b, c, M, K) + if is_sqmma_compatible(a, b, N, K): GROUP_M = 8 - BLOCK_M = 128 - BLOCK_N = BLOCK_M - BLOCK_K = 64 - num_warps = 16 if BLOCK_M == 256 else 4 - num_stages = 1 return mm_sqmma( a, b, @@ -321,15 +471,6 @@ def mm(a, b): N, K, GROUP_M, - BLOCK_M, - BLOCK_N, - BLOCK_K, - num_warps, - num_stages, ) else: - enable_sqmma = os.environ.pop("MUSA_ENABLE_SQMMA", None) - result = mm_fma(a, b) - if enable_sqmma: - os.environ["MUSA_ENABLE_SQMMA"] = enable_sqmma - return result + return mm_fma(a, b) diff --git a/src/flag_gems/runtime/backend/_mthreads/ops/utils.py b/src/flag_gems/runtime/backend/_mthreads/ops/utils.py index 58a628207d..87e640a9dc 100644 --- a/src/flag_gems/runtime/backend/_mthreads/ops/utils.py +++ b/src/flag_gems/runtime/backend/_mthreads/ops/utils.py @@ -1,10 +1,14 @@ import os +from collections import OrderedDict import numpy as np import torch import triton import triton.language as tl +_TMA_DESCRIPTOR_CACHE_MAXSIZE = 256 +_tma_descriptor_cache = OrderedDict() + def create_tma_device_descriptor(tensor, block_m, block_n, device): assert tensor.dim() == 2, "TMA descriptor only supports 2D tensors" @@ -29,6 +33,32 @@ def create_tma_device_descriptor(tensor, block_m, block_n, device): return desc +def _tma_descriptor_cache_key(tensor, block_m, block_n, device): + return ( + tensor.data_ptr(), + tuple(tensor.shape), + tuple(tensor.stride()), + str(tensor.dtype), + block_m, + block_n, + str(device), + ) + + +def get_cached_tma_device_descriptor(tensor, block_m, block_n, device): + key = _tma_descriptor_cache_key(tensor, block_m, block_n, device) + desc = _tma_descriptor_cache.get(key) + if desc is not None: + _tma_descriptor_cache.move_to_end(key) + return desc + + desc = create_tma_device_descriptor(tensor, block_m, block_n, device) + _tma_descriptor_cache[key] = desc + if len(_tma_descriptor_cache) > _TMA_DESCRIPTOR_CACHE_MAXSIZE: + _tma_descriptor_cache.popitem(last=False) + return desc + + def get_triton_dtype(dtype): dtype_map = { torch.float16: tl.float16, diff --git a/src/flag_gems/utils/libentry.py b/src/flag_gems/utils/libentry.py index 8c6d8360a1..60ffed06e3 100644 --- a/src/flag_gems/utils/libentry.py +++ b/src/flag_gems/utils/libentry.py @@ -724,6 +724,7 @@ def run(self, *args, **kwargs): constexprs = {} tune_constexprs = {} heur_constexprs = {} + launch_pre_hooks = [] while not isinstance(fn, triton.runtime.JITFunction): if isinstance(fn, triton.runtime.Autotuner): config = fn.best_config @@ -732,6 +733,10 @@ def run(self, *args, **kwargs): constexprs["num_ctas"] = config.num_ctas constexprs = {**constexprs, **config.kwargs} tune_constexprs = {**tune_constexprs, **config.kwargs} + if config.pre_hook is not None: + launch_pre_hooks.append( + (config.pre_hook, config.all_kwargs()) + ) elif isinstance(fn, triton.runtime.Heuristics): for v, heur in fn.values.items(): heur_constexprs[v] = heur( @@ -757,10 +762,17 @@ def run(self, *args, **kwargs): constexprs, tune_constexprs, heur_constexprs, + tuple(launch_pre_hooks), ) return kernel, constexprs - kernel, constexprs, tune_constexprs, heur_constexprs = cache[entry_key] + ( + kernel, + constexprs, + tune_constexprs, + heur_constexprs, + launch_pre_hooks, + ) = cache[entry_key] if callable(grid): # collect all arguments to the grid fn,ie: @@ -772,6 +784,11 @@ def run(self, *args, **kwargs): grid = grid(meta) grid = grid + (1, 1) + if launch_pre_hooks: + hook_nargs = {**dict(zip(self.arg_names, args)), **kwargs} + for pre_hook, hook_kwargs in launch_pre_hooks: + pre_hook({**hook_nargs, **hook_kwargs}) + if major_version == 3 and 3 <= minor_version <= 6: all_args = [] missing_keys = []