From 53bdd1117aa4fdea9c181e8fec6671bf9e38c8c1 Mon Sep 17 00:00:00 2001 From: v2yield Date: Fri, 3 Apr 2026 17:32:24 +0800 Subject: [PATCH 1/2] optimize _fp8_quantize/_int8_quantize/silu_and_mul/moe_sum --- src/flag_gems/fused/fused_moe.py | 727 ++++++++++++++++++++++++++++--- src/flag_gems/fused/moe_sum.py | 48 +- 2 files changed, 697 insertions(+), 78 deletions(-) diff --git a/src/flag_gems/fused/fused_moe.py b/src/flag_gems/fused/fused_moe.py index f460477308..1396990641 100644 --- a/src/flag_gems/fused/fused_moe.py +++ b/src/flag_gems/fused/fused_moe.py @@ -33,6 +33,10 @@ # OCP MX quantization helpers (requires amd-quark) OCP_MX_BLOCK_SIZE = 32 +_FP8_DTYPE = torch.float8_e4m3fn +_FP8_MAX = float(torch.finfo(_FP8_DTYPE).max) +_FP8_MIN = float(torch.finfo(_FP8_DTYPE).min) +_FP8_MIN_SCALE = 1.0 / (_FP8_MAX * 512.0) @functools.lru_cache(maxsize=1) @@ -317,6 +321,47 @@ def get_default_config( "num_warps": 4, "num_stages": 3, } + elif dtype == "int8_w8a8" or dtype == "fp8_w8a8": + if M <= 32: + block_m = 16 + elif M <= 96: + block_m = 32 + elif M <= 512: + block_m = 64 + else: + block_m = 128 + + block_n = 64 if M <= 64 else 128 + + # Small batches benefit from longer reduction (larger K tile), + # while large batches prefer more output parallelism. + # FP8 elements are half-width so larger K tiles are always cheap. + block_k = 128 if dtype == "fp8_w8a8" or M <= 64 else 64 + + # Grouping adjacent M-blocks lets them share weight tiles in L2. + # Only helps when there are enough M-blocks per expert to group; + # with many experts each one sees few tokens so grouping is useless. + tokens_per_expert = M // max(E, 1) + group_m = 16 if tokens_per_expert > 128 else 1 + + # Large batches have enough blocks to saturate the GPU, so we + # use more warps per block to increase arithmetic intensity. + num_warps = 4 if M <= 128 else 8 + + if M <= 32: + num_stages = 4 + else: + num_stages = 3 + + config = { + "BLOCK_SIZE_M": block_m, + "BLOCK_SIZE_N": block_n, + "BLOCK_SIZE_K": block_k, + "GROUP_SIZE_M": group_m, + "SPLIT_K": 1, + "num_warps": num_warps, + "num_stages": num_stages, + } else: # tokens_per_expert drives block_m: use M//E (not M*topk//E) to # estimate the actual per-expert token count after routing. @@ -376,6 +421,7 @@ def _get_config_dtype_str( dtype: Optional[torch.dtype] = None, use_fp8_w8a8: bool = False, use_fp8_w8a16: bool = False, + use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, ocp_mx_scheme: str | None = None, @@ -385,6 +431,8 @@ def _get_config_dtype_str( return "fp8_w8a8" elif use_fp8_w8a16: return "fp8_w8a16" + elif use_int8_w8a8: + return "int8_w8a8" elif use_int8_w8a16: return "int8_w8a16" elif use_int4_w4a16: @@ -459,11 +507,17 @@ def apply_moe_activation( f"{activation.value} expects equal sizes: " f"{output.size(-1)} vs {input.size(-1)}" ) - if activation in (MoEActivation.SILU, MoEActivation.SWIGLUOAI): N = output.size(-1) - x, y = input[:, :N], input[:, N:] - _silu_and_mul_kernel(x, y, out0=output) + if ( + input.is_contiguous() + and output.is_contiguous() + and input.size(0) <= _SMALL_M_SILU_MUL_THRESHOLD + ): + small_m_silu_and_mul_packed(output, input) + else: + x, y = input[:, :N], input[:, N:] + silu_and_mul_kernel(x, y, out0=output) elif activation == MoEActivation.GELU: N = output.size(-1) gate, up = input[:, :N], input[:, N:] @@ -490,63 +544,553 @@ def apply_moe_activation( return output +_SMALL_M_FP8_QUANT_THRESHOLD = 16 +_SINGLE_LAUNCH_FP8_QUANT_MAX_M = 8 +_SINGLE_LAUNCH_FP8_QUANT_MAX_NUMEL = 65536 +_SMALL_M_SILU_MUL_THRESHOLD = 32 + + +def _get_fp8_quant_2d_config( + M: int, + N: int, + *, + single_launch: bool, +) -> dict[str, int]: + if N <= 512: + block_n = 512 + num_warps = 4 + num_stages = 1 + elif N <= 1024: + block_n = 1024 + num_warps = 4 if single_launch or M <= 4 else 8 + num_stages = 2 + elif N <= 2048: + block_n = 2048 + num_warps = 8 + num_stages = 2 + else: + block_n = 4096 + num_warps = 8 + num_stages = 2 + + return { + "BLOCK_N": block_n, + "num_warps": num_warps, + "num_stages": num_stages, + } + + +def _get_fp8_quant_single_cta_config(numel: int) -> dict[str, int]: + if numel <= 1024: + block_size = 1024 + num_warps = 4 + elif numel <= 2048: + block_size = 2048 + num_warps = 8 + elif numel <= 4096: + block_size = 4096 + num_warps = 16 + elif numel <= 8192: + block_size = 8192 + num_warps = 16 + else: + block_size = 16384 + num_warps = 16 + + return { + "BLOCK_SIZE": block_size, + "num_warps": num_warps, + "num_stages": 2, + } + + +@triton.jit +def global_absmax_atomic_kernel( + inp_ptr, + absmax_ptr, + numel, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(0) + offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offs < numel + x = tl.load(inp_ptr + offs, mask=mask, other=0.0).to(tl.float32) + tl.atomic_max(absmax_ptr, tl.max(tl.abs(x), axis=0)) + + +@triton.jit +def dynamic_scaled_fp8_quant_single_launch_kernel( + out_ptr, + inp_ptr, + scale_ptr, + M: tl.constexpr, + N: tl.constexpr, + stride_om, + stride_on, + stride_im, + stride_in, + fp8_min, + fp8_max, + min_scale, + BLOCK_N: tl.constexpr, +): + pid = tl.program_id(0) + if pid != 0: + return + + global_absmax = 0.0 + + for pid_m in tl.static_range(0, M): + row_base = pid_m * stride_im + for off in tl.static_range(0, N, BLOCK_N): + offs_n = off + tl.arange(0, BLOCK_N) + mask = offs_n < N + x = tl.load( + inp_ptr + row_base + offs_n * stride_in, + mask=mask, + other=0.0, + ).to(tl.float32) + global_absmax = tl.maximum(global_absmax, tl.max(tl.abs(x), axis=0)) + + scale = tl.maximum(global_absmax / fp8_max, min_scale) + tl.store(scale_ptr, scale) + inv_scale = 1.0 / scale + + for pid_m in tl.static_range(0, M): + row_in_base = pid_m * stride_im + row_out_base = pid_m * stride_om + for off in tl.static_range(0, N, BLOCK_N): + offs_n = off + tl.arange(0, BLOCK_N) + mask = offs_n < N + x = tl.load( + inp_ptr + row_in_base + offs_n * stride_in, + mask=mask, + other=0.0, + ).to(tl.float32) + q = tl.clamp(x * inv_scale, fp8_min, fp8_max) + tl.store( + out_ptr + row_out_base + offs_n * stride_on, + q.to(out_ptr.type.element_ty), + mask=mask, + ) + + +@triton.jit +def dynamic_scaled_fp8_quant_large_m_kernel( + out_ptr, + inp_ptr, + scale_ptr, + M: tl.constexpr, + N: tl.constexpr, + stride_om, + stride_on, + stride_im, + stride_in, + fp8_min, + fp8_max, + min_scale, + BLOCK_N: tl.constexpr, +): + pid_m = tl.program_id(0) + if pid_m >= M: + return + + row_in_base = pid_m * stride_im + row_out_base = pid_m * stride_om + scale = tl.load(scale_ptr).to(tl.float32) + inv_scale = 1.0 / tl.maximum(scale, min_scale) + + for off in tl.static_range(0, N, BLOCK_N): + offs_n = off + tl.arange(0, BLOCK_N) + mask = offs_n < N + x = tl.load( + inp_ptr + row_in_base + offs_n * stride_in, + mask=mask, + other=0.0, + ).to(tl.float32) + q = tl.clamp(x * inv_scale, fp8_min, fp8_max) + tl.store( + out_ptr + row_out_base + offs_n * stride_on, + q.to(out_ptr.type.element_ty), + mask=mask, + ) + + +def dynamic_scaled_fp8_quant_small_m( + output: torch.Tensor, + input: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + scale_out = torch.empty((1,), device=input.device, dtype=torch.float32) + config = _get_fp8_quant_single_cta_config(input.shape[-1]) + dynamic_scaled_fp8_quant_single_launch_kernel[(1,)]( + output, + input, + scale_out, + input.shape[0], + input.shape[1], + output.stride(0), + output.stride(1), + input.stride(0), + input.stride(1), + fp8_min=_FP8_MIN, + fp8_max=_FP8_MAX, + min_scale=_FP8_MIN_SCALE, + BLOCK_N=config["BLOCK_SIZE"], + num_warps=config["num_warps"], + num_stages=config["num_stages"], + ) + return output, scale_out + + +def dynamic_scaled_fp8_quant_large_m( + output: torch.Tensor, + input: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + M, N = input.shape + config = _get_fp8_quant_2d_config(M, N, single_launch=False) + absmax = torch.zeros((1,), device=input.device, dtype=torch.float32) + + numel = input.numel() + BLOCK = 512 + global_absmax_atomic_kernel[(triton.cdiv(numel, BLOCK),)]( + input, + absmax, + numel, + BLOCK_SIZE=BLOCK, + num_warps=4, + num_stages=1, + ) + + scale_out = (absmax / _FP8_MAX).clamp_(min=_FP8_MIN_SCALE) + dynamic_scaled_fp8_quant_large_m_kernel[(M,)]( + output, + input, + scale_out, + M, + N, + output.stride(0), + output.stride(1), + input.stride(0), + input.stride(1), + fp8_min=_FP8_MIN, + fp8_max=_FP8_MAX, + min_scale=_FP8_MIN_SCALE, + BLOCK_N=config["BLOCK_N"], + num_warps=config["num_warps"], + num_stages=config["num_stages"], + ) + return output, scale_out + + +def dynamic_scaled_fp8_quant( + input: torch.Tensor, + output: Optional[torch.Tensor] = None, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Dynamic per-tensor FP8 quantization. + """ + assert input.ndim == 2 and input.stride(-1) == 1 + + if output is None: + output = torch.empty_like(input, dtype=torch.float8_e4m3fn) + else: + assert output.shape == input.shape and output.dtype == torch.float8_e4m3fn + + M, N = input.shape + use_single_launch_kernel = M <= 8 and input.numel() <= 65536 + if use_single_launch_kernel: + return dynamic_scaled_fp8_quant_small_m(output, input) + + return dynamic_scaled_fp8_quant_large_m(output, input) + + +@triton.jit +def per_token_group_quant_fp8_kernel( + y_ptr, + y_q_ptr, + y_s_ptr, + group_size, + y_num_columns, + y_row_stride, + eps, + fp8_min, + fp8_max, + BLOCK: tl.constexpr, +): + """ + This function converts the tensor values into float8 values. + """ + groups_per_row = y_num_columns // group_size + g_id = tl.program_id(0) + row = g_id // groups_per_row + row_g_id = g_id % groups_per_row + + y_ptr_offset = (row.to(tl.int64) * y_row_stride) + ( + row_g_id.to(tl.int64) * group_size + ) + y_ptr += y_ptr_offset + + y_q_ptr_offset = g_id.to(tl.int64) * group_size + y_q_ptr += y_q_ptr_offset + y_s_ptr += g_id + + cols = tl.arange(0, BLOCK) + mask = cols < group_size + + y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32) + # Quant + _absmax = tl.maximum(tl.max(tl.abs(y)), eps) + scale_raw = _absmax / fp8_max + y_s = scale_raw + y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) + + tl.store(y_q_ptr + cols, y_q, mask=mask) + tl.store(y_s_ptr, y_s) + + +def per_token_group_quant_fp8( + x: torch.Tensor, + group_size: int, + eps: float = 1e-10, + dtype: torch.dtype = torch.float8_e4m3fn, + out_q: torch.Tensor | None = None, +) -> tuple[torch.Tensor, torch.Tensor]: + """Function to perform per-token-group quantization on an input tensor `x`.""" + assert x.shape[-1] % group_size == 0, ( + f"the last dimension of `x` {x.shape[-1]} must be divisible " + f"by `group_size` {group_size}" + ) + assert x.stride(-1) == 1, "`x` groups must be contiguous" + + fp8_min, fp8_max = torch.finfo(torch.float8_e4m3fn) + + assert out_q is None or out_q.shape == x.shape + x_q = out_q + if x_q is None: + x_q = torch.empty(x.shape, device=x.device, dtype=dtype) + + shape = x.shape[:-1] + (x.shape[-1] // group_size,) + x_s = torch.empty(shape, device=x.device, dtype=torch.float32) + + M = x.numel() // group_size + N = group_size + BLOCK = triton.next_power_of_2(N) + + num_warps = min(max(BLOCK // 256, 1), 8) + num_stages = 1 + + per_token_group_quant_fp8_kernel[(M,)]( + x, + x_q, + x_s, + group_size, + x.shape[1], + x.stride(0), + eps, + fp8_min=fp8_min, + fp8_max=fp8_max, + BLOCK=BLOCK, + num_warps=num_warps, + num_stages=num_stages, + ) + + return x_q, x_s + + +@triton.jit +def per_token_quant_int8_kernel( + x_ptr, + xq_ptr, + scale_ptr, + stride_x, + stride_xq, + N, + BLOCK: tl.constexpr, +): + # row_id = tl.program_id(0) + # row_in_base = row_id * stride_x + # row_out_base = row_id * stride_xq + # row_absmax = 0.0 + + # for off in tl.range(0, N, BLOCK): + # cols = off + tl.arange(0, BLOCK) + # mask = cols < N + # x = tl.load(x_ptr + row_in_base + cols, mask=mask, other=0.0).to(tl.float32) + # row_absmax = tl.maximum(row_absmax, tl.max(tl.abs(x), axis=0)) + + # absmax = tl.maximum(row_absmax, 1e-10) + # scale_x = absmax / 127.0 + # inv_scale_x = 127.0 / absmax + + # for off in tl.range(0, N, BLOCK): + # cols = off + tl.arange(0, BLOCK) + # mask = cols < N + # x = tl.load(x_ptr + row_in_base + cols, mask=mask, other=0.0).to(tl.float32) + # x_q = tl.extra.cuda.libdevice.round(x * inv_scale_x) + # x_q = tl.clamp(x_q, -128.0, 127.0).to(xq_ptr.dtype.element_ty) + # tl.store(xq_ptr + row_out_base + cols, x_q, mask=mask) + + # tl.store(scale_ptr + row_id, scale_x) + row_id = tl.program_id(0) + + cols = tl.arange(0, BLOCK) + mask = cols < N + + x = tl.load(x_ptr + row_id * stride_x + cols, mask=mask, other=0.0).to(tl.float32) + absmax = tl.maximum(tl.max(tl.abs(x)), 1e-10) + scale_x = absmax / 127 + x_q = x * (127 / absmax) + x_q = tl.extra.cuda.libdevice.round(x_q) + + tl.store(xq_ptr + row_id * stride_xq + cols, x_q, mask=mask) + tl.store(scale_ptr + row_id, scale_x) + + +def per_token_quant_int8(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + original_shape = x.shape + if x.dim() > 2: + x = x.view(-1, original_shape[-1]) + M = x.numel() // x.shape[-1] + N = x.shape[-1] + x_q = torch.empty((M, N), device=x.device, dtype=torch.int8) + scales = torch.empty((M, 1), device=x.device, dtype=torch.float32) + BLOCK = triton.next_power_of_2(N) + num_warps = min(max(BLOCK // 256, 1), 8) + x = x.contiguous() + + per_token_quant_int8_kernel[(M,)]( + x, + x_q, + scales, + stride_x=x.stride(-2), + stride_xq=x_q.stride(-2), + N=N, + BLOCK=BLOCK, + num_warps=num_warps, + num_stages=1, + ) + x_q = x_q.view(*original_shape) + scales = scales.view(*original_shape[:-1], 1) + return x_q, scales + + +@triton.jit +def per_token_group_quant_int8_kernel( + y_ptr, + y_q_ptr, + y_s_ptr, + y_stride, + N, + eps, + int8_min, + int8_max, + BLOCK: tl.constexpr, +): + """ + This function converts the tensor values into int8 values. + """ + g_id = tl.program_id(0) + y_ptr += g_id * y_stride + y_q_ptr += g_id * y_stride + y_s_ptr += g_id + + cols = tl.arange(0, BLOCK) + mask = cols < N + + y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32) + # Quant + _absmax = tl.maximum(tl.max(tl.abs(y)), eps) + y_s = _absmax / int8_max + y_q = tl.clamp(y / y_s, int8_min, int8_max).to(y_q_ptr.dtype.element_ty) + + tl.store(y_q_ptr + cols, y_q, mask=mask) + tl.store(y_s_ptr, y_s) + + +def per_token_group_quant_int8( + x: torch.Tensor, + group_size: int, + eps: float = 1e-10, + dtype: torch.dtype = torch.int8, +) -> tuple[torch.Tensor, torch.Tensor]: + assert ( + x.shape[-1] % group_size == 0 + ), "the last dimension of `x` cannot be divisible by `group_size`" + assert x.is_contiguous(), "`x` is not contiguous" + + iinfo = torch.iinfo(dtype) + int8_max = iinfo.max + int8_min = iinfo.min + + x_q = torch.empty_like(x, device=x.device, dtype=dtype) + x_s = torch.empty( + x.shape[:-1] + (x.shape[-1] // group_size,), + device=x.device, + dtype=torch.float32, + ) + + M = x.numel() // group_size + N = group_size + + BLOCK = triton.next_power_of_2(N) + + num_warps = min(max(BLOCK // 256, 1), 8) + num_stages = 1 + per_token_group_quant_int8_kernel[(M,)]( + x, + x_q, + x_s, + group_size, + N, + eps, + int8_min=int8_min, + int8_max=int8_max, + BLOCK=BLOCK, + num_warps=num_warps, + num_stages=num_stages, + ) + + return x_q, x_s + + def _fp8_quantize( A: torch.Tensor, A_scale: Optional[torch.Tensor], per_act_token: bool, block_shape: Optional[list[int]] = None, ) -> tuple[torch.Tensor, torch.Tensor]: - """FP8 E4M3 quantization: per-tensor, per-token, or block-wise.""" - fp8_dtype = torch.float8_e4m3fn - finfo = torch.finfo(fp8_dtype) - fp8_max = finfo.max - fp8_min = finfo.min - eps = 1e-10 - + """FP8 E4M3 quantization: keep dispatch shallow, specialize the hot paths.""" if block_shape is not None: assert not per_act_token assert len(block_shape) == 2 - block_k = block_shape[1] - assert A.size(-1) % block_k == 0 - orig_shape = A.shape - A_flat = A.reshape(-1, A.size(-1)) - M, K = A_flat.shape - A_groups = A_flat.reshape(M * (K // block_k), block_k) - amax = ( - A_groups.abs().amax(dim=-1, keepdim=True).clamp(min=eps).to(torch.float32) - ) - scale = amax / fp8_max - A_q = (A_groups.float() / scale).clamp(fp8_min, fp8_max).to(fp8_dtype) - A_q = A_q.reshape(orig_shape) - scale = scale.reshape(M, K // block_k) - return A_q, scale + _, block_k = block_shape[0], block_shape[1] + return per_token_group_quant_fp8(A, block_k) elif per_act_token: A_flat = A.reshape(-1, A.size(-1)) - amax = A_flat.abs().amax(dim=-1, keepdim=True).clamp(min=eps).to(torch.float32) - scale = amax / fp8_max - min_scale = torch.tensor( - 1.0 / (fp8_max * 512.0), dtype=torch.float32, device=A.device + amax = ( + A_flat.abs().amax(dim=-1, keepdim=True).clamp(min=1e-10).to(torch.float32) ) + scale = amax / _FP8_MAX + min_scale = torch.tensor(_FP8_MIN_SCALE, dtype=torch.float32, device=A.device) scale = scale.clamp(min=min_scale) - A_q = (A_flat.float() / scale).clamp(fp8_min, fp8_max).to(fp8_dtype) + A_q = (A_flat.float() / scale).clamp(_FP8_MIN, _FP8_MAX).to(_FP8_DTYPE) A_q = A_q.reshape(A.shape) scale = scale.reshape(A.shape[:-1] + (1,)) return A_q, scale - else: if A_scale is not None: scale = ( - A_scale.float().view(1, 1) if A_scale.numel() == 1 else A_scale.float() + A_scale.float().reshape(1) if A_scale.numel() == 1 else A_scale.float() ) - A_q = (A.float() / scale).clamp(fp8_min, fp8_max).to(fp8_dtype) - return A_q, A_scale + A_q = (A.float() / scale).clamp(_FP8_MIN, _FP8_MAX).to(_FP8_DTYPE) + return A_q, scale else: - amax = A.abs().amax().clamp(min=eps).to(torch.float32) - scale = amax / fp8_max - iscale = 1.0 / scale - A_q = (A.float() * iscale).clamp(fp8_min, fp8_max).to(fp8_dtype) - return A_q, scale.view(1) + assert A.stride(-1) == 1, "last dimension must be contiguous" + orig_shape = A.shape + A_2d = A.reshape(-1, orig_shape[-1]) + A_q_2d, scale = dynamic_scaled_fp8_quant(A_2d) + return A_q_2d.reshape(orig_shape), scale def _int8_quantize( @@ -556,41 +1100,20 @@ def _int8_quantize( block_shape: Optional[list[int]] = None, ) -> tuple[torch.Tensor, torch.Tensor]: """INT8 quantization: per-tensor, per-token, or block-wise.""" - iinfo = torch.iinfo(torch.int8) - int8_max = iinfo.max - int8_min = iinfo.min - eps = 1e-10 - if block_shape is not None: assert not per_act_token assert len(block_shape) == 2 - block_k = block_shape[1] - assert A.size(-1) % block_k == 0 - orig_shape = A.shape - A_flat = A.reshape(-1, A.size(-1)) - M, K = A_flat.shape - A_groups = A_flat.reshape(M * (K // block_k), block_k) - amax = ( - A_groups.abs().amax(dim=-1, keepdim=True).clamp(min=eps).to(torch.float32) - ) - scale = amax / int8_max - A_q = ( - (A_groups.float() / scale).round().clamp(int8_min, int8_max).to(torch.int8) - ) - A_q = A_q.reshape(orig_shape) - scale = scale.reshape(M, K // block_k) - return A_q, scale + _, block_k = block_shape[0], block_shape[1] + A_q, A_scale_out = per_token_group_quant_int8(A.contiguous(), block_k) + return A_q, A_scale_out elif per_act_token: - A_flat = A.reshape(-1, A.size(-1)) - amax = A_flat.abs().amax(dim=-1, keepdim=True).clamp(min=eps).to(torch.float32) - scale = amax / int8_max - A_q = (A_flat.float() / scale).round().clamp(int8_min, int8_max).to(torch.int8) - A_q = A_q.reshape(A.shape) - scale = scale.reshape(A.shape[:-1] + (1,)) - return A_q, scale + return per_token_quant_int8(A) else: + iinfo = torch.iinfo(torch.int8) + int8_max = iinfo.max + int8_min = iinfo.min assert A_scale is not None, "int8 per-tensor requires A_scale" scale = A_scale.float().view(1, 1) if A_scale.numel() == 1 else A_scale.float() A_q = (A.float() / scale).round().clamp(int8_min, int8_max).to(torch.int8) @@ -645,12 +1168,81 @@ def _ensure_block_size_k_divisible( @pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")]) @triton.jit -def _silu_and_mul_kernel(x, y): +def silu_and_mul_kernel(x, y): x_fp32 = x.to(tl.float32) x_silu = tl.fdiv(x_fp32, (1.0 + tl.exp(-x_fp32))) return x_silu * y +@triton.jit +def small_m_silu_and_mul_kernel( + out_ptr, + inp_ptr, + M: tl.constexpr, + N: tl.constexpr, + stride_om, + stride_on, + stride_im, + stride_in, + BLOCK_N: tl.constexpr, +): + pid = tl.program_id(0) + num_pid_n = tl.cdiv(N, BLOCK_N) + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + + if pid_m >= M: + return + + row_out_base = pid_m * stride_om + row_in_base = pid_m * stride_im + + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + mask = offs_n < N + x = tl.load( + inp_ptr + row_in_base + offs_n * stride_in, + mask=mask, + other=0.0, + ).to(tl.float32) + y = tl.load( + inp_ptr + row_in_base + (offs_n + N) * stride_in, + mask=mask, + other=0.0, + ) + x_silu = tl.fdiv(x, 1.0 + tl.exp(-x)) + tl.store( + out_ptr + row_out_base + offs_n * stride_on, + (x_silu * y).to(out_ptr.type.element_ty), + mask=mask, + ) + + +def small_m_silu_and_mul_packed( + output: torch.Tensor, + input: torch.Tensor, +) -> None: + assert input.ndim == 2 and output.ndim == 2 + assert input.size(0) == output.size(0) + assert input.size(1) == output.size(1) * 2 + assert input.stride(-1) == 1 and output.stride(-1) == 1 + + M, N = output.shape + grid = lambda META: (M * triton.cdiv(N, META["BLOCK_N"]),) + small_m_silu_and_mul_kernel[grid]( + output, + input, + M, + N, + output.stride(0), + output.stride(1), + input.stride(0), + input.stride(1), + BLOCK_N=2048, + num_warps=8, + num_stages=1, + ) + + @triton.jit def write_zeros_to_output( c_ptr, @@ -1412,6 +2004,7 @@ def fused_experts_impl( config_dtype = _get_config_dtype_str( use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a8=use_int8_w8a8, use_int8_w8a16=use_int8_w8a16, use_int4_w4a16=use_int4_w4a16, ocp_mx_scheme=ocp_mx_scheme, diff --git a/src/flag_gems/fused/moe_sum.py b/src/flag_gems/fused/moe_sum.py index 185b766365..3a94118758 100644 --- a/src/flag_gems/fused/moe_sum.py +++ b/src/flag_gems/fused/moe_sum.py @@ -7,21 +7,43 @@ logger = logging.getLogger(__name__) -@triton.autotune( - configs=[ - triton.Config({"BLOCK_SIZE": 128}, num_warps=2), - triton.Config({"BLOCK_SIZE": 256}, num_warps=4), - triton.Config({"BLOCK_SIZE": 512}, num_warps=8), - triton.Config({"BLOCK_SIZE": 1024}, num_warps=8), - ], - key=["hidden_size", "topk"], -) +def _get_moe_sum_config( + num_tokens: int, + hidden_size: int, + topk: int, +): + if hidden_size <= 64: + block_size = 64 + num_warps = 1 + elif hidden_size <= 128: + block_size = 128 + num_warps = 2 + elif hidden_size <= 256: + block_size = 256 + num_warps = 2 if num_tokens < 8 and topk <= 2 else 4 + elif hidden_size <= 512: + block_size = 512 + num_warps = 4 + else: + block_size = 1024 + if num_tokens < 4 and topk <= 2: + num_warps = 4 + else: + num_warps = 8 + + return { + "BLOCK_SIZE": block_size, + "num_warps": num_warps, + "num_stages": 1, + } + + @triton.jit def moe_sum_kernel( input_ptr, output_ptr, num_tokens, - topk, + topk: tl.constexpr, hidden_size, input_stride_token, input_stride_topk, @@ -40,7 +62,7 @@ def moe_sum_kernel( acc = tl.zeros((BLOCK_SIZE,), dtype=tl.float32) input_base = input_ptr + token_idx * input_stride_token - for expert_idx in range(topk): + for expert_idx in tl.static_range(topk): expert_ptr = input_base + expert_idx * input_stride_topk expert_data = tl.load(expert_ptr + hidden_offsets, mask=hidden_mask, other=0.0) acc += expert_data @@ -61,6 +83,7 @@ def moe_sum( num_tokens, topk, hidden_size = input.shape input_strides = input.stride() output_strides = output.stride() + config = _get_moe_sum_config(num_tokens, hidden_size, topk) grid = lambda meta: (num_tokens, triton.cdiv(hidden_size, meta["BLOCK_SIZE"])) moe_sum_kernel[grid]( input, @@ -73,4 +96,7 @@ def moe_sum( input_strides[2], output_strides[0], output_strides[1], + BLOCK_SIZE=config["BLOCK_SIZE"], + num_warps=config["num_warps"], + num_stages=config["num_stages"], ) From ce664ed9a6718863163362cd8bb56f8c76ba6310 Mon Sep 17 00:00:00 2001 From: v2yield Date: Fri, 3 Apr 2026 18:58:01 +0800 Subject: [PATCH 2/2] remove unused variables --- src/flag_gems/fused/fused_moe.py | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/src/flag_gems/fused/fused_moe.py b/src/flag_gems/fused/fused_moe.py index 1396990641..da3d0846ba 100644 --- a/src/flag_gems/fused/fused_moe.py +++ b/src/flag_gems/fused/fused_moe.py @@ -509,11 +509,7 @@ def apply_moe_activation( ) if activation in (MoEActivation.SILU, MoEActivation.SWIGLUOAI): N = output.size(-1) - if ( - input.is_contiguous() - and output.is_contiguous() - and input.size(0) <= _SMALL_M_SILU_MUL_THRESHOLD - ): + if input.is_contiguous() and output.is_contiguous() and input.size(0) <= 32: small_m_silu_and_mul_packed(output, input) else: x, y = input[:, :N], input[:, N:] @@ -544,12 +540,6 @@ def apply_moe_activation( return output -_SMALL_M_FP8_QUANT_THRESHOLD = 16 -_SINGLE_LAUNCH_FP8_QUANT_MAX_M = 8 -_SINGLE_LAUNCH_FP8_QUANT_MAX_NUMEL = 65536 -_SMALL_M_SILU_MUL_THRESHOLD = 32 - - def _get_fp8_quant_2d_config( M: int, N: int,